Using numba to speed up mean-squared displacement calculations
You can download this whole thing as a Jupyter notebook here
Writing a faster mean-squared-displacement function
I'm trying to calculate the Mean Squared Displacement (MSD) of a particle trajectory. In reality, I have an array of positions for numtrj particles, numpts timepoints and dim dimensions: pos[numtrj, numpts, dim]. I think my question has the same answer if I just have the x trajectory of a single particle, though.
In case you haven't done MSD calculations before, there's one cute way in which you get extra information out of a trajectory. Say I have just five time points, and my positions are:

```
In [83]: x
Out[83]: array([ 0. , 1.74528704, 1.59639865, 2.59976219, 3.70852457])
```
You could just get the squared displacement from the starting point (here x[0] is 0) by looking at x**2. However, you could also say that you have four different values for the displacement at one timestep:
x[1] - x[0], x[2] - x[1], x[3] - x[2], x[4] - x[3]
Similarly, there are three values for the displacement at two timesteps, x[2:] - x[:-2], and so on. So, the way I'm calculating MSD at the moment is:
```python
import numpy as np

def msd_1d(x):
    result = np.zeros_like(x)
    for i in range(1, len(x)):
        # all len(x) - i displacements that are i steps apart
        result[i] = np.average((x[i:] - x[:-i])**2)
    return result
```
or
```python
def get_msd_traj(pos):
    result = np.zeros_like(pos)
    for i in range(1, pos.shape[1]):
        result[:, i, :] = np.average((pos[:, i:, :] - pos[:, :-i, :])**2, axis=1)
    return result
```
(Side note: for molecular dynamics trajectories the data often comes indexed like pos[numpts, numtrj, dim], but that doesn't change anything here.)
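Written out, what msd_1d computes at a lag of Δ steps, for positions x_0, ..., x_{N-1}, is the average over the N − Δ pairs of points that are Δ steps apart:

$$
\mathrm{MSD}(\Delta) = \frac{1}{N-\Delta} \sum_{i=\Delta}^{N-1} \left(x_i - x_{i-\Delta}\right)^2, \qquad \Delta = 1, \dots, N-1,
$$

with MSD(0) = 0. For the five-point x above, N = 5, so lags 1 through 4 average over 4, 3, 2 and 1 pairs respectively. get_msd_traj applies the same formula separately to every trajectory and every dimension.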
So, I asked Joshua Adelman if he had any quick thoughts.
Josh cleared up some basic misconceptions for me and gave a couple of suggestions:
From Josh
Hi Michael,
In general, numba is not going to speed up calculations that involve numpy built-in methods (unless they are things like np.exp, np.sin, np.cos) or array slicing. It works best if you have unrolled any vectorization into explicit nested loops and operate on arrays element-by-element. If you go that route, check to see if you can get the code to compile using @jit(nopython=True), which would be indicative of numba being able to translate the code to LLVM IR without using Python objects (which tend to slow things down). I've also had a lot of success lately parallelizing numba code using threading if you aren't dealing with any Python objects. If you can compile in nopython mode, then you can also specify:
@jit(nopython=True, nogil=True)
and use a ThreadPool to chop up the work, using something like:
http://numba.pydata.org/numba-doc/0.19.1/user/examples.html?highlight=nogil#multi-threading
It's undocumented, but for simple stuff you can use:

```python
from multiprocessing.pool import ThreadPool
pool = ThreadPool(processes=nthreads)
results = pool.map(func, arg)
```

It has the same API as the multiprocessing pool, and for embarrassingly parallel tasks that don't have possible race conditions, it works quite nicely, depending on the overhead involved in spawning off the task relative to the cost of the task.
Also, make sure you're using the latest version of numba since it's getting better rapidly. I'm not sure if you'll be able to call np.zeros_like in a numba-ized method in nopython mode, although you might. If not, you can pass in the results array as an argument.
Hope that helps. I can't think of any sneaky vectorization tricks in pure-numpy off the top of my head to make your calculation faster. Let me know if you have any other questions.
Josh
Back to me
So, I tried several things. You have to do some fiddling around to get nopython=True to work. In addition to not knowing about np.average, numba (at least the version I was using) treated things like x[1:] - x[:-1] as doomed to failure. That last part isn't really trouble, though, because one of Josh's points was that I should write out all of my loops explicitly anyway. Here is the original 1D function along with two numba versions:
```python
import numpy as np
from numba import jit

def msd_1d(x):
    # Pure-numpy version, for comparison.
    result = np.zeros_like(x)
    for i in range(1, len(x)):
        result[i] = np.average((x[i:] - x[:-i])**2)
    return result

@jit(nopython=True)
def msd_1d_nb1(x):
    # Explicit loops; each lag accumulates into a local scalar.
    result = np.zeros_like(x)
    for delta in range(1, len(x)):
        thisresult = 0.0
        for i in range(delta, len(x)):
            thisresult += (x[i] - x[i-delta])**2
        result[delta] = thisresult / (len(x) - delta)
    return result

@jit(nopython=True)
def msd_1d_nb2(x):
    # Same loops, but accumulating directly into the result array.
    result = np.zeros_like(x)
    for delta in range(1, len(x)):
        for i in range(delta, len(x)):
            result[delta] += (x[i] - x[i-delta])**2
        result[delta] = result[delta] / (len(x) - delta)
    return result
```
Note that the decorated definitions above will succeed almost no matter what you write: without an explicit type signature, numba compiles lazily, so you don't see any nopython-mode problems until you actually call the functions. Here are some timings:
```python
%timeit msd_1d(np.random.randn(10))
%timeit msd_1d_nb1(np.random.randn(10))
%timeit msd_1d_nb2(np.random.randn(10))
%timeit msd_1d(np.random.randn(100))
%timeit msd_1d_nb1(np.random.randn(100))
%timeit msd_1d_nb2(np.random.randn(100))
%timeit msd_1d(np.random.randn(1000))
%timeit msd_1d_nb1(np.random.randn(1000))
%timeit msd_1d_nb2(np.random.randn(1000))
```
So, pretty big win there! Approximately two orders of magnitude. It looks like it scales well too.
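One caveat on those numbers: each %timeit line also times np.random.randn, and the first call of a numba function includes its compilation. Since the randn cost is the same for all three functions, it mostly narrows the apparent gap. A slightly fairer harness is sketched below; time_fn is a hypothetical helper (not from the original notebook) that takes pre-generated input and makes one untimed warm-up call before measuring.

```python
import timeit


def time_fn(fn, x, number=100, repeat=5):
    """Best per-call time of fn(x), in seconds (hypothetical helper).

    The caller builds x once, so input generation is not timed, and one
    untimed warm-up call absorbs any numba compilation before measurement.
    """
    fn(x)  # warm-up call, not timed
    totals = timeit.repeat(lambda: fn(x), number=number, repeat=repeat)
    return min(totals) / number
```

For example, with x = np.random.randn(1000), comparing time_fn(msd_1d, x) against time_fn(msd_1d_nb1, x) times both versions on identical input.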
The larger version
There's not much of a difference between versions 1 and 2, though version 1 does look a bit faster, so that's the one I extended to the full pos[numtrj, numpts, dim] array.
```python
def get_msd_traj(pos):
    result = np.zeros_like(pos)
    for i in range(1, pos.shape[1]):
        result[:, i, :] = np.average((pos[:, i:, :] - pos[:, :-i, :])**2, axis=1)
    return result

@jit(nopython=True)
def get_msd_traj_nb1(pos):
    result = np.zeros_like(pos)
    deltastop = pos.shape[1]  # numpts
    for traj in range(pos.shape[0]):
        for dim in range(pos.shape[2]):
            for delta in range(1, deltastop):
                thisresult = 0.0
                for i in range(delta, deltastop):
                    thisresult += (pos[traj, i, dim] - pos[traj, i-delta, dim])**2
                result[traj, delta, dim] = thisresult / (deltastop - delta)
    return result
```
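As with the 1D functions, get_msd_traj_nb1 won't complain until it's first called. If you'd rather see nopython-mode failures immediately, numba compiles eagerly when you give jit an explicit type signature. The sketch below assumes float64 input; get_msd_traj_eager is a hypothetical name, and the try/except stand-in (also an assumption, not numba API) just lets the logic run slowly when numba isn't installed.

```python
import numpy as np

try:
    from numba import jit
except ImportError:
    # Hypothetical no-op stand-in so the sketch still runs without numba.
    def jit(*args, **kwargs):
        def wrap(func):
            return func
        return wrap


# Explicit signature: 3D float64 array in, 3D float64 array out. With numba
# installed, compilation happens when the decorator runs, not on first call.
@jit("float64[:, :, :](float64[:, :, :])", nopython=True)
def get_msd_traj_eager(pos):
    result = np.zeros_like(pos)
    deltastop = pos.shape[1]
    for traj in range(pos.shape[0]):
        for dim in range(pos.shape[2]):
            for delta in range(1, deltastop):
                thisresult = 0.0
                for i in range(delta, deltastop):
                    thisresult += (pos[traj, i, dim] - pos[traj, i - delta, dim]) ** 2
                result[traj, delta, dim] = thisresult / (deltastop - delta)
    return result
```

The trade-off is that this compiled version is tied to 3D float64 input, whereas the lazily compiled get_msd_traj_nb1 compiles a new specialization for whatever argument types it's called with.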
First, let's do some due diligence and make sure the results are equivalent. Then, timings.
```python
a = np.random.randn(5, 3, 2)
# Compare with a tolerance rather than ==, since the two versions may sum in a different order.
np.allclose(get_msd_traj(a), get_msd_traj_nb1(a))
%timeit get_msd_traj(np.random.randn(10, 10, 2))
%timeit get_msd_traj_nb1(np.random.randn(10, 10, 2))
%timeit get_msd_traj(np.random.randn(100, 100, 2))
%timeit get_msd_traj_nb1(np.random.randn(100, 100, 2))
%timeit get_msd_traj(np.random.randn(512, 2001, 2))
%timeit get_msd_traj_nb1(np.random.randn(512, 2001, 2))
```
So, for the actual data sizes I care about, I'm probably down to "only" an order of magnitude speedup. That's still pretty awesome.
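With 512 independent trajectories, Josh's nogil + ThreadPool suggestion looks like a natural next step. Here's a rough, unbenchmarked sketch of how it might go; it's my illustration rather than code from Josh or the numba docs, and msd_traj_chunk and msd_traj_threaded are hypothetical names. Following Josh's other suggestion, the result array is allocated outside the jitted function and passed in, and each thread gets its own block of trajectories, so no element is written by two threads. The try/except stand-in is an assumption that keeps the sketch runnable without numba (in that case the threads give no speedup).

```python
from multiprocessing.pool import ThreadPool

import numpy as np

try:
    from numba import jit
except ImportError:
    # Hypothetical no-op stand-in: logic still runs, but the GIL serializes it.
    def jit(*args, **kwargs):
        def wrap(func):
            return func
        return wrap


@jit(nopython=True, nogil=True)
def msd_traj_chunk(pos, out, start, stop):
    """Fill out[traj, :, :] for start <= traj < stop (hypothetical helper)."""
    numpts = pos.shape[1]
    for traj in range(start, stop):
        for dim in range(pos.shape[2]):
            for delta in range(1, numpts):
                acc = 0.0
                for i in range(delta, numpts):
                    diff = pos[traj, i, dim] - pos[traj, i - delta, dim]
                    acc += diff * diff
                out[traj, delta, dim] = acc / (numpts - delta)


def msd_traj_threaded(pos, nthreads=4):
    """Per-trajectory, per-dimension MSD for a float array pos[numtrj, numpts, dim]."""
    out = np.zeros_like(pos)  # allocated here and passed into the jitted code
    msd_traj_chunk(pos, out, 0, 0)  # empty range: just triggers compilation up front
    edges = np.linspace(0, pos.shape[0], nthreads + 1).astype(np.int64)
    chunks = [(int(edges[k]), int(edges[k + 1])) for k in range(nthreads)]
    # Disjoint trajectory blocks per thread, so there are no race conditions.
    with ThreadPool(processes=nthreads) as pool:
        pool.map(lambda se: msd_traj_chunk(pos, out, se[0], se[1]), chunks)
    return out
```

Whether this pays off depends, as Josh says, on the per-task overhead relative to the work; splitting by trajectory keeps each task large.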
Comments from old blog

One response, from mglerner on July 3, 2015 at 3:21 pm:
On twitter, Konrad Hinsen pointed out (https://twitter.com/khinsen/status/616853234041384961) that I could get a bigger speedup by using a better algorithm: likely a factor of 1000, and since it's O(N log N) instead of O(N^2), it only gets better for longer trajectories. Several years ago, I did try an FFT-based method, but I was sketched out by the fact that the results I obtained were similar to the exact results, but not identical. It sounds likely that his method is just better. The paper claims it's exact, and a follow-up conversation indicates that they get better accuracy by doing fewer calculations (thus less accumulated error). Filed away for next time!
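For the record, here is one well-known FFT-based formulation. I'm assuming it's the general kind of algorithm Konrad meant; it is not taken from his paper, and msd_fft_1d is a hypothetical name. It splits the MSD at lag m into a sum-of-squares term S1(m), computed with a running sum, minus twice the position autocorrelation S2(m), computed with a zero-padded FFT. Because the arithmetic differs from the direct sum, it should agree with msd_1d to within floating-point round-off rather than bit-for-bit.

```python
import numpy as np


def msd_fft_1d(x):
    """MSD at every lag in O(N log N); should match msd_1d up to round-off.

    Uses MSD(m) = S1(m) - 2*S2(m), where
      S1(m) = sum_{k=0}^{N-m-1} (x[k+m]**2 + x[k]**2) / (N - m)
      S2(m) = sum_{k=0}^{N-m-1} x[k] * x[k+m] / (N - m)
    """
    x = np.asarray(x, dtype=float)
    n = x.shape[0]
    counts = n - np.arange(n)  # number of pairs available at each lag

    # S2: zero-pad to 2n so the FFT's circular correlation equals the linear one.
    f = np.fft.rfft(x, n=2 * n)
    s2 = np.fft.irfft(f * np.conj(f), n=2 * n)[:n] / counts

    # S1: start from 2*sum(x**2) and peel off the two end terms at each lag.
    d = np.append(x * x, 0.0)  # d[-1] == d[n] == 0, so the m == 0 step removes nothing
    s1 = np.empty(n)
    q = 2.0 * d.sum()
    for m in range(n):
        q -= d[m - 1] + d[n - m]
        s1[m] = q / counts[m]
    return s1 - 2.0 * s2
```

The remaining Python loop is O(N), so the FFTs dominate the cost; applying this per trajectory and per dimension would give the same layout as get_msd_traj.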