Fast LSTMs in PyTorch
Implementing fast recurrent neural networks is a challenging task. This is not only a hassle for training existing architectures (sometimes optimized implementations such as CuDNN's LSTM help there); more gravely, it also limits experimentation with new architectures.
For example, fast.ai's Jeremy Howard writes:
For instance, we’ve been doing lots of research in to different types of recurrent neural network architectures and normalization layers. In both cases, we haven’t been able to get the same level of performance that we see in pure CUDA C implementations, even when using PyTorch’s fantastic new JIT compiler.
 Jeremy Howard
Jeremy does awesome big learning research and teaching and I cannot recommend his courses highly enough (best complemented by my own workshops to really jump-start your AI projects, of course), so I use him as my reference user here.
His comment hints at the traditional way to solve this: write (better) CUDA kernels. I've done this for Batch Norm and CTC Loss (which in PyTorch are now comparable in speed to CuDNN on my GPU). But that is a lot of effort, and, in particular for RNNs, it doesn't really lend itself to rapid experimentation. You also end up writing two kernels (forward and backward) because you lose the ability to use PyTorch's automatic differentiation.
Jeremy's conclusion is to write a new framework in a new language. He isn't too fond of Python, and I'll liberally admit that some of his reasons (e.g. typing) are very high on my list of Python shortcomings, too, even if I still think that Python is a great language. The promise is that the new Swift for TensorFlow will eventually come with a new intermediate representation, MLIR, which makes all sorts of optimizations easy.
But PyTorch already has a great JIT IR, and it is what enables the optimizations we see today. So instead of inventing a new language, we look at making the JIT compiler^{1} even more fantastic.
This post is about all the technical details I looked at in an attempt to make PyTorch-JITed LSTMs faster.
A brief overview of how the JIT speeds up PyTorch models
Before we begin, let's briefly discuss how the PyTorch JIT (Just-In-Time compiler) brings speed to models.
The most obvious thing is that it takes Python out of the equation. In my experience this typically gives a speedup of 10%. That is nice, but based on a poll during my talk in December and discussions on the PyTorch forum, it seems to be much less than people generally expect.
But the JIT's internal representation, the TorchScript IR, also lends itself to the kinds of optimizations any self-respecting compiler performs. One of the most interesting optimizations here is the fuser.
The fuser takes several operations from the IR of your model and produces a single CPU or GPU kernel that performs them in one go. To see why this is desperately needed, consider what usually happens when you have a tensor $x$ and compute $y = 2 \cdot x + 1$. The computer sees $2 \cdot x$ and computes that, meaning it reads $x$ from memory into its registers, applies the multiplication, and stores the result in a new tensor. It then goes on to add $1$, reading $2 \cdot x$ from memory and writing back $y$. The tedious bits (not only in the description) are the memory reads and writes. As you would expect, it gets much faster if you load $x$ from memory, multiply by $2$, add $1$ and then write back $y$: just about twice as fast, as the arithmetic is of negligible complexity here.
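To make the memory argument concrete, here is a toy sketch in plain Python with no PyTorch involved. It counts element reads and writes for the unfused and the fused way of computing $y = 2 \cdot x + 1$. The CountingBuffer class and the two functions are hypothetical helpers made up for this illustration. The loops stand in for GPU kernels, and real timings also depend on caches and bandwidth.

```python
class CountingBuffer:
    """A list wrapper that counts element reads and writes (illustration only)."""

    def __init__(self, data):
        self.data = list(data)
        self.reads = 0
        self.writes = 0

    def __getitem__(self, i):
        self.reads += 1
        return self.data[i]

    def __setitem__(self, i, value):
        self.writes += 1
        self.data[i] = value

    def __len__(self):
        return len(self.data)


def unfused(x):
    # Two "kernels": each reads its full input and writes a full output tensor.
    tmp = CountingBuffer([0.0] * len(x))
    for i in range(len(x)):
        tmp[i] = 2 * x[i]
    y = CountingBuffer([0.0] * len(x))
    for i in range(len(x)):
        y[i] = tmp[i] + 1
    return y, x.reads + tmp.reads, tmp.writes + y.writes


def fused(x):
    # One "kernel": read x once, do all arithmetic in registers, write y once.
    y = CountingBuffer([0.0] * len(x))
    for i in range(len(x)):
        y[i] = 2 * x[i] + 1
    return y, x.reads, y.writes


n = 1000
y1, r1, w1 = unfused(CountingBuffer(range(n)))
y2, r2, w2 = fused(CountingBuffer(range(n)))
print(r1 + w1, r2 + w2)  # 4000 2000
```

The fused version touches memory half as often, which matches the "about twice as fast" estimate when arithmetic is negligible.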
Now, the modular nature of typical neural network architectures means there are quite a few of these pointwise operations in a row in certain places. Normalization layers typically end with an affine transformation $y = a \cdot (x - \mathrm{mean}(x)) / \mathrm{std}(x) + b$, where $a$ and $b$ are also (broadcast) tensors, and they are followed by an activation function such as ReLU. While a good normalization layer implementation will cover the affine operation, the activation remains separate.
In modern RNNs, the situation is even more pronounced due to the gates, internal activations used to modulate inputs, updates, and outputs. In the LSTM cell in PyTorch we have
$$
\begin{aligned}
i_t &= \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{(t-1)} + b_{hi}) \\
f_t &= \sigma(W_{if} x_t + b_{if} + W_{hf} h_{(t-1)} + b_{hf}) \\
g_t &= \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{(t-1)} + b_{hg}) \\
o_t &= \sigma(W_{io} x_t + b_{io} + W_{ho} h_{(t-1)} + b_{ho}) \\
c_t &= f_t c_{(t-1)} + i_t g_t \\
h_t &= o_t \tanh(c_t),
\end{aligned}
$$
see the linked PyTorch documentation for details (and I don't know why we have two biases, except because we can and CuDNN has it). One big part is the matrix multiplications. It's efficient to combine those so we would get
$$
\begin{aligned}
\hat i_t, \hat f_t, \hat g_t, \hat o_t &= \operatorname{chunk}(W_{i} x_t + b_i + W_{h} h_{(t-1)} + b_{h}) \\
i_t &= \sigma(\hat i_t) \\
f_t &= \sigma(\hat f_t) \\
g_t &= \tanh(\hat g_t) \\
o_t &= \sigma(\hat o_t) \\
c_t &= f_t c_{(t-1)} + i_t g_t \\
h_t &= o_t \tanh(c_t).
\end{aligned}
$$
Now we clearly see that after the first line, we only have pointwise functions. Ideal for fusing!
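To see the split between the one matrix step and the pointwise tail, here is a minimal plain-Python sketch of a single LSTM step in the combined form. It assumes unbatched vectors as lists, with W_i and W_h stacking the four gate matrices row-wise into 4*H rows. The helper names matvec, sigmoid and lstm_cell are made up for this example and are not PyTorch functions.

```python
import math


def matvec(W, v):
    # Dense matrix-vector product: the only non-pointwise step of the cell.
    return [sum(w_ij * v_j for w_ij, v_j in zip(row, v)) for row in W]


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def lstm_cell(x, h, c, W_i, b_i, W_h, b_h):
    """One LSTM step; W_i and W_h stack the four gate matrices (4*H rows each)."""
    H = len(h)
    # Combined matrix multiplications plus both biases for all four gates at once.
    gates = [a + ba + b + bb for a, ba, b, bb in
             zip(matvec(W_i, x), b_i, matvec(W_h, h), b_h)]
    # chunk: split the 4*H pre-activations into four pieces of size H.
    i_hat, f_hat, g_hat, o_hat = (gates[k * H:(k + 1) * H] for k in range(4))
    # Everything below is pointwise, i.e. a candidate for one fused kernel.
    h_new, c_new = [], []
    for ih, fh, gh, oh, c_prev in zip(i_hat, f_hat, g_hat, o_hat, c):
        i, f, g, o = sigmoid(ih), sigmoid(fh), math.tanh(gh), sigmoid(oh)
        c_t = f * c_prev + i * g
        c_new.append(c_t)
        h_new.append(o * math.tanh(c_t))
    return h_new, c_new
```

With all-zero weights and biases, every gate pre-activation is 0. So $i_t = f_t = o_t = 0.5$ and $g_t = 0$, which gives $c_t = 0.5\, c_{(t-1)}$, an easy sanity check.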
A bit on how the fuser works
When optimizing a graph, one of the things the JIT does is the fusion pass. There it looks for a series of pointwise operations to combine into one. It does this by inspecting the IR, trying out reorderings of independent computations, and checking whether each operation is fusible^{2}. The aim is to find the largest possible groups of operations to fuse, the fusion groups. If you take a script function, say,
```python
import torch


@torch.jit.script
def cell_end(ingate, forgetgate, cellgate, outgate, cx):
    # Pointwise tail of an LSTM cell: gate activations and state updates.
    ingate = torch.sigmoid(ingate)
    forgetgate = torch.sigmoid(forgetgate)
    cellgate = torch.tanh(cellgate)
    outgate = torch.sigmoid(outgate)
    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * torch.tanh(cy)
    return hy, cy
```
you can actually print cell_end.graph to see what the JIT thinks it does before optimization. After running it with random inputs (say, inp = torch.randn(5, 10, 4); cell_end(*inp)), you can use cell_end.graph_for(*inp) to see the optimized graph with the fusion groups. (I have a detailed example for the IoU function in a notebook prepared for the talk linked above.)
What then happens is that when running inputs through the model, those fusion groups are handed to the fuser's execution part. It decides whether the inputs are compatible enough for the fusion, looks up a kernel or compiles one, prepares inputs, allocates outputs and runs the custom kernel. Compilation here means that C/CUDA code is generated and compiled with NVIDIA's runtime compiler NVRTC or (if CPU fusion is enabled) GCC. These automatically generated kernels thus save you the work of writing your own.
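The execution flow can be pictured with a toy model: check compatibility, look up or "compile" a kernel, allocate outputs, run. This is a plain-Python sketch under loud assumptions. ToyFusionExecutor and its methods are hypothetical, and "compiling" just builds a Python closure instead of generating C/CUDA code. Inputs are assumed to be non-empty, equal-length lists. Kernels are specialized on element type and length, and mixed element types fall back to unfused execution, roughly like the early fuser's same-scalar-type restriction.

```python
class ToyFusionExecutor:
    """Caches one 'compiled' kernel per input specification of a fusion group.

    Illustration only: 'compiling' builds a Python closure, whereas the real
    fuser generates C/CUDA code and compiles it with NVRTC or GCC.
    Inputs are assumed to be non-empty lists of equal length.
    """

    def __init__(self, pointwise_fn):
        self.pointwise_fn = pointwise_fn  # elementwise body of the fusion group
        self.cache = {}
        self.compilations = 0

    def _spec(self, inputs):
        # Specialize on the element type and length of each input list.
        return tuple((type(x[0]), len(x)) for x in inputs)

    def _compile(self, spec):
        self.compilations += 1
        fn = self.pointwise_fn

        def kernel(inputs, out):
            for i in range(len(out)):
                out[i] = fn(*(x[i] for x in inputs))

        return kernel

    def run(self, *inputs):
        spec = self._spec(inputs)
        if len({t for t, _ in spec}) != 1:
            # Mixed element types (unsupported by the early fuser): run unfused.
            return [self.pointwise_fn(*args) for args in zip(*inputs)]
        if spec not in self.cache:
            self.cache[spec] = self._compile(spec)
        out = [None] * len(inputs[0])  # allocate the output, then run the kernel
        self.cache[spec](inputs, out)
        return out


ex = ToyFusionExecutor(lambda a, b: a * b + 1)
print(ex.run([1.0, 2.0], [3.0, 4.0]), ex.compilations)  # [4.0, 9.0] 1
print(ex.run([5.0, 6.0], [7.0, 8.0]), ex.compilations)  # [36.0, 49.0] 1
```

The second call reuses the cached kernel, which is why compilation cost is paid once per input specification rather than on every execution.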
In addition to pointwise operations, it can also fuse chunking (torch.chunk) and concatenation (torch.cat), the latter at the end of a fusion group. These are dealt with by taking views of the inputs and allocating the concatenated outputs, so that part isn't in the kernels themselves.
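Here is a plain-Python sketch of the concatenation idea. The concatenated output is allocated once, and the "kernel" writes each result into a view of its slot, so no separate copy is needed. The function name and the choice of (x + y, x * y) as the fused group are hypothetical. array and memoryview merely stand in for tensors and tensor views.

```python
from array import array


def fused_kernel_with_concat(x, y):
    """Hypothetical fused group computing x + y and x * y, ending in a cat.

    The concatenated output is allocated up front; the kernel only sees
    views into it and writes each result directly into its slot.
    """
    n = len(x)
    out = array("d", [0.0] * (2 * n))       # concatenated output buffer
    view = memoryview(out)
    out_sum, out_prod = view[:n], view[n:]  # views share memory with `out`
    for i in range(n):                      # the "kernel" knows nothing about cat
        out_sum[i] = x[i] + y[i]
        out_prod[i] = x[i] * y[i]
    return out
```

For example, fused_kernel_with_concat([1.0, 2.0], [3.0, 4.0]) yields the buffer [4.0, 6.0, 3.0, 8.0]. A chunk at the start of a group works the same way in reverse, with each piece being a view of the input.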
In summary, the fuser is great at dealing with sequences of pointwise operations, but it does not currently support reductions.^{3} With what it currently does, you can get the LSTM forward almost as fast as the CuDNN implementation.
New tricks
In the beginning, the fuser only dealt with cases where all inputs were of the same scalar type. For many calculations, such as those above, that is sufficient, but sometimes it can be a limitation. I had discussed useful JIT enhancements with Francisco Massa, who created maskrcnn-benchmark. Specifically, I wanted to fuse the intersection-over-union (IoU) function and looked at what was needed to do that. So one of my first larger contributions to the fuser was to add support for mixed scalar types and for fusing torch.where, which, of course, needs boolean support.
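Here is a plain-Python sketch of why IoU wants both mixed types and a where. The comparisons produce booleans, and the select then mixes them with floating-point values, all pointwise. The (x1, y1, x2, y2) box layout and the helper names are assumptions for this example, not the maskrcnn-benchmark code.

```python
def where(cond, a, b):
    # Elementwise select in the spirit of torch.where; needs a boolean input.
    return [ai if ci else bi for ci, ai, bi in zip(cond, a, b)]


def pairwise_iou(boxes_a, boxes_b):
    """IoU of corresponding boxes given as (x1, y1, x2, y2) tuples (assumed layout)."""
    x1a, y1a, x2a, y2a = zip(*boxes_a)
    x1b, y1b, x2b, y2b = zip(*boxes_b)
    # Width and height of the intersection; a negative value means "no overlap".
    w = [min(p, q) - max(r, s) for p, q, r, s in zip(x2a, x2b, x1a, x1b)]
    h = [min(p, q) - max(r, s) for p, q, r, s in zip(y2a, y2b, y1a, y1b)]
    overlaps = [wi > 0 and hi > 0 for wi, hi in zip(w, h)]  # boolean "tensor"
    inter = where(overlaps, [wi * hi for wi, hi in zip(w, h)], [0.0] * len(w))
    area_a = [(p - r) * (q - s) for p, r, q, s in zip(x2a, x1a, y2a, y1a)]
    area_b = [(p - r) * (q - s) for p, r, q, s in zip(x2b, x1b, y2b, y1b)]
    return [i / (a + b - i) for i, a, b in zip(inter, area_a, area_b)]
```

The select matters for disjoint boxes. For (0, 0, 1, 1) and (2, 2, 3, 3), both extents are -1, and their product would be a spurious positive "area" of 1.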
One of the immensely useful optimizations the JIT has learned post-1.0 is dealing with normalization layers; I believe this was implemented mainly by Adam Paszke. Here, the fuser could already fuse the parts that should be fused, but it didn't get to see them. Batch Norm (and similar layers) is presented to the user as one function, but in training it really has two parts: the gathering of statistics, also known as the reduction part, and the affine transformation shown above. So the key to getting the latter fused is to split the batch norm in the IR. It is a seemingly simple trick (when you leave out all the details), but enough to make the JIT even better for Jeremy; at least I think this was part of his excitement.
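The split can be sketched in plain Python for a single channel. A reduction computes the statistics, and a separate pointwise function applies normalization, the affine transform and ReLU in one pass. That second function is the kind of chain the fuser can turn into one kernel. The function names and the eps default are illustrative assumptions, not the JIT's actual decomposition.

```python
import math


def batch_stats(x):
    # Reduction part (not fusable by the fuser at this point): mean and biased variance.
    mean = sum(x) / len(x)
    var = sum((xi - mean) ** 2 for xi in x) / len(x)
    return mean, var


def affine_relu(x, mean, var, weight, bias, eps=1e-5):
    # Pointwise part: normalize, apply the affine transform and ReLU in one pass.
    inv_std = 1.0 / math.sqrt(var + eps)
    return [max(0.0, weight * (xi - mean) * inv_std + bias) for xi in x]


x = [1.0, 2.0, 3.0, 4.0]
mean, var = batch_stats(x)
print(mean, var)  # 2.5 1.25
y = affine_relu(x, mean, var, weight=2.0, bias=0.5)
```

Once the statistics are separate inputs, the activation no longer sits behind an opaque Batch Norm call and can join the same fusion group.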
But what about derivatives?
Now, one of the defining features of PyTorch is its automatic differentiation support. It works by keeping track of the computation as it happens and then going backwards through the recorded calculation graph to compute derivatives. When several operations now happen en bloc in a single fused kernel, this doesn't work anymore.
So what do we do about that? Well, we have the IR representation, so if we can differentiate that... And sure enough, PyTorch has (perhaps not widely known) a second automatic differentiation mechanism, one that works at the IR level. It is sometimes called source-to-source differentiation, because it doesn't require execution. This is called the autodiff JIT component in PyTorch.^{4} This is what, to reuse Jeremy's words, "incorporates differentiable programming deep in to the heart of", well, TorchScript.^{5}
Initially all (and still large parts) of autodiff's derivatives were implemented in C++, but PyTorch is now getting a way to express functions and their derivatives in Python / TorchScript.
So with autodiff, we can symbolically differentiate the IR representation. This means we can apply fusion to the forward and get the backward by running autodiff and then executing (with fusion) the graph for the derivative. (Again, see the notebook linked above for an example of graphs that can be auto-differentiated and the corresponding backward graphs.)
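Here is a toy illustration of source-to-source differentiation. It takes a straight-line program and emits backward instructions without running anything, then executes both with a tiny interpreter. The instruction format (output, op, inputs) and all names are invented for this sketch. PyTorch's autodiff works on TorchScript graphs in C++ and looks nothing like this in detail.

```python
import math

# Toy straight-line IR: (output_name, op, input_names). Computes y = tanh(x * w) * x.
forward = [
    ("a", "mul", ("x", "w")),
    ("b", "tanh", ("a",)),
    ("y", "mul", ("b", "x")),
]

# Interpreter for both forward and backward instructions.
OPS = {
    "mul": lambda p, q: p * q,
    "tanh": math.tanh,
    "add": lambda *parts: sum(parts),
    "ones_like": lambda v: 1.0,
    "tanh_backward": lambda g, t: g * (1.0 - t * t),  # t is the saved tanh output
}


def run(program, env):
    env = dict(env)
    for out, op, args in program:
        env[out] = OPS[op](*(env[a] for a in args))
    return env


def differentiate(program, output, inputs):
    """Emit backward instructions for `program` symbolically; nothing is executed."""
    backward = [("grad_seed", "ones_like", (output,))]
    contributions = {output: ["grad_seed"]}  # gradient pieces collected per variable

    def total(var):
        # A variable used several times gets an explicit add of its gradient pieces.
        parts = contributions[var]
        if len(parts) == 1:
            return parts[0]
        name = "grad_" + var
        backward.append((name, "add", tuple(parts)))
        return name

    def emit(var, op, args):
        name = f"g{len(backward)}"  # unique, since backward only grows
        backward.append((name, op, args))
        contributions.setdefault(var, []).append(name)

    for out, op, args in reversed(program):
        g = total(out)
        if op == "mul":
            p, q = args
            emit(p, "mul", (g, q))  # d(p*q)/dp = q
            emit(q, "mul", (g, p))  # d(p*q)/dq = p
        elif op == "tanh":
            emit(args[0], "tanh_backward", (g, out))
        else:
            raise NotImplementedError(op)
    return backward, {v: total(v) for v in inputs}


backward, grad_names = differentiate(forward, "y", ["x", "w"])
env = run(backward, run(forward, {"x": 0.5, "w": 2.0}))
grads = {v: env[name] for v, name in grad_names.items()}
```

The backward is itself just another straight-line program of mostly pointwise ops, which is what makes it a candidate for fusion. Because x is used twice, the generated backward also contains an add of its two gradient pieces.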
Why optimizing backwards this way is a hard problem
One of the convenient features of PyTorch (and NumPy before it) that makes our life hard in automatic differentiation is broadcasting. We are used to adding a [1, channel, 1, 1]-shaped bias to a [bs, channel, w, h]-shaped image, and to similar operations, and to having it implicitly treated as constant along the singleton axes. But for automatic differentiation, we need to bring the broadcasting to light, because it does have a derivative: summation.^{6} This means that between the derivatives of consecutive ops, we might have a summation. And, as written above, the fuser does not deal with summation at this point.
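Here is a minimal plain-Python sketch of that summation, restricted to 2-D lists for readability. The gradient of a broadcast operand is the out-gradient summed over every axis along which the operand was broadcast. The name sum_to_size is borrowed from the IR op discussed below, but this function and its 2-D restriction are illustrative assumptions.

```python
def sum_to_size(grad, size):
    """Sum a 2-D gradient (list of rows) down to `size`, undoing broadcasting.

    `size` is (rows, cols) of the operand before broadcasting; each entry must
    either match the gradient's extent or be 1 (a broadcast axis).
    """
    n, m = len(grad), len(grad[0])
    rows, cols = size
    assert rows in (1, n) and cols in (1, m), "not a broadcast of the given size"
    if rows == 1 and n != 1:
        grad = [[sum(grad[i][j] for i in range(n)) for j in range(m)]]
    if cols == 1 and m != 1:
        grad = [[sum(row)] for row in grad]
    return grad
```

For a bias of shape [1, 3] added to activations of shape [2, 3], an out-gradient [[1, 2, 3], [4, 5, 6]] gives the bias gradient [[5, 7, 9]]. The gradient for the activations is the out-gradient itself.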
This is the situation in PyTorch 1.0. PyTorch's LSTM benchmarks showed the JITed LSTM backward taking about 3 times as long as CuDNN's.^{7}
So what to do?
So in early December I had prepared my talk on speeding up PyTorch and fusing IoU. But then after updating PyTorch the backward optimization wouldn't work anymore. Would I have anything to talk about? I briefly and unsuccessfully tried to sell Adam^{8} a flag to disable broadcasting. But then I thought a bit more about the problem.
These summations are called SumToSize^{9}, because you hand in the size of the tensor before broadcasting and get the gradient summed to that shape. So if it isn't easy to teach the fuser to execute them, can we avoid having them in those parts of the graph we want to fuse?
Just as the chain rule enables backpropagation, it comes to the rescue here, too. The SumToSizes aren't in arbitrary places in the graph; they are only applied to gradients of intermediate values of our function. Recall the basic structure of the chain rule: for $f(x) = a \cdot x$, $y = f(x)$ and $z = g(y)$, it gives $$ \frac{\partial z}{\partial x} = \nabla_x f(x) * \nabla_y g(y) = a \cdot \nabla_y g(y), $$ where $*$ is matrix multiplication, while $\cdot$ is pointwise multiplication.
Crucially, no matter how complicated the derivative formula is, it is conceptually linear in the out-gradient, i.e. something matrix-multiplied (well, tensor-multiplied) with the out-gradient. This implies that we will never see products of two SumToSized variables, because that would not be linear in the gradients. This is important, as otherwise we would have to check the distributive law. Write the SumToSize as $\sum_i a_i$ for some unspecified index $i$. You cannot express $(\sum_i a_i) \cdot (\sum_i b_i)$ in terms of just the elements $a_i \cdot b_i$, so you could not commute the SumToSize with the multiplication. But we now know that SumToSize only comes in the form $(\sum_i a_i) \cdot b$, which we can write as $\sum_i (a_i \cdot b)$ if we want, commuting the multiplication by $b$ with the SumToSize.
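Here is a quick numerical check of this argument in plain Python. A gradient broadcast from shape [1, 2] to [3, 2] is summed back over its rows. Multiplying by a factor $b$ of the small shape commutes with that summation. A product of two summed gradients does not. The helper names are made up for this example.

```python
def sum_over_broadcast_axis(t):
    # SumToSize for a tensor broadcast from shape [1, m] to [n, m]: sum over rows.
    return [sum(col) for col in zip(*t)]


def times(t, b):
    # Pointwise multiply every row of t by b (b has the small shape [m]).
    return [[x * y for x, y in zip(row, b)] for row in t]


grad = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # out-gradient, shape [3, 2]
b = [10.0, -1.0]                              # chain-rule factor, shape [2]

# Linear case: (sum_i a_i) * b == sum_i (a_i * b), so SumToSize can move past the mul.
lhs = [s * y for s, y in zip(sum_over_broadcast_axis(grad), b)]
rhs = sum_over_broadcast_axis(times(grad, b))
print(lhs, rhs)  # [90.0, -12.0] [90.0, -12.0]

# Product of two summed gradients: no such rewrite exists in general.
other = [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]
prod_of_sums = [s * t for s, t in zip(sum_over_broadcast_axis(grad),
                                      sum_over_broadcast_axis(other))]
sum_of_prods = sum_over_broadcast_axis(
    [[x * y for x, y in zip(r, q)] for r, q in zip(grad, other)])
print(prod_of_sums, sum_of_prods)  # [27.0, 36.0] [9.0, 12.0]
```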
Also, we will never apply a non-linear function (say $\tanh$, ReLU or anything else) to a gradient, because the chain rule never asks for that.
There is a small caveat: when a variable is used in more than one place, e.g. $y = f(x, x)$ with $f : (a,b) \mapsto f(a,b)$, we will have an addition $\frac{\partial y}{\partial x} = \nabla_a f + \nabla_b f$. So we might have two SumToSizes entering the operation. Is this a problem? No, because we know that this particular addition (at some point called AutogradAdd in PyTorch's IR) is non-broadcasting. So if SumToSize hits both arguments, it will automatically be the same.
So this means we can commute the SumToSize with the gradient-relevant pointwise operations, and we only need to cover those. The interesting gradients aren't written as matrix multiplications; instead, they use mul, div (as numerator), neg (the unary minus, when something is multiplied by $-1$), the special add we met above, and where (when we only use part of the gradient and mimic a multiplication by 0 this way). We might also find type_as casting gradients.
Finally, if we have two consecutive SumToSizes, we only need to keep the latter. It will necessarily be to a smaller (or equal) shape, because broadcasting only makes tensors larger.
This allows us to move SumToSizes outside our prospective fusion groups. Yay!
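The rewriting can be pictured on a single gradient path, modelled as a list of ops applied in order. Every SumToSize is pushed past the commutable ops listed above and ends up at the end. Once they are adjacent, only the last (smallest) target size is kept. This is a hypothetical simplification: the lowercase op names stand in for the aten ops, and the real pass works on TorchScript graphs inside the fuser.

```python
# Ops a SumToSize may be moved past (linear in the gradient, per the argument above).
COMMUTABLE = {"mul", "div", "neg", "autograd_add", "where", "type_as"}


def move_sum_to_size_out(chain):
    """Push the sum_to_size steps of one gradient path to its end.

    `chain` is a list of (op, arg) pairs applied in order to a gradient; for
    ("sum_to_size", size), arg is the target size. Returns the ops that stay
    in the fusion group and the single final target size (or None).
    """
    kept, pending = [], None
    for op, arg in chain:
        if op == "sum_to_size":
            # After commuting, SumToSizes become adjacent; the last one is the smallest.
            pending = arg
        elif pending is None or op in COMMUTABLE:
            kept.append((op, arg))
        else:
            raise ValueError(f"cannot move sum_to_size past {op!r}")
    return kept, pending


kept, size = move_sum_to_size_out([("mul", "b"), ("sum_to_size", (1, 3)),
                                   ("neg", None), ("sum_to_size", (1, 1)),
                                   ("div", "c")])
```

Here the fusion group keeps mul, neg and div, and a single summation to (1, 1) remains to be done after the kernel. A non-linear op such as tanh after a SumToSize is rejected, but by the chain-rule argument it never shows up on a gradient path anyway.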
Now we have tackled the mathematical side. But there is an organisational aspect, too. We don't know whether the fusion execution might decide to fall back to regular execution. In that case, moving the SumToSize further down in the computation would make things inefficient, because the pointwise ops would then run unfused on the larger, unsummed tensors. To avoid that, we do the rearranging in the fusion compiler. But this in turn means that we cannot move the SumToSizes outside the fusion group, because the fusion group is now in the hands of the fuser and we need to process it completely. We solve this simply by doing the SumToSize within the fuser after running the fused kernel.
Another bookkeeping complication was pointed out by Natalia Gimelshein, NVIDIA engineer and PyTorch core team member, from whom I have learned a tremendous amount about CUDA programming (thanks!). It can happen that one input gradient (an output of the backward) is used several times with different broadcasting. This means that after moving the SumToSize out of the kernel compilation, we need to deduplicate the outputs again to avoid allocating memory for each of the identical instances.
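The last two steps can be sketched together in plain Python. The fused kernel runs on full-size tensors, the moved-out SumToSizes are applied afterwards, and identical requests share a single result. The function name, the dict-of-outputs convention and the restriction to 1-D outputs, where "summed" means summed to length 1, are all hypothetical simplifications.

```python
def run_group_with_sum_to_size(kernel, inputs, requested):
    """Run a fused kernel, then apply the SumToSizes moved out of the group.

    `kernel(inputs)` returns a dict of full-size 1-D outputs (lists).
    `requested` lists (output_name, target_len) pairs; target_len is either
    the full length or 1 (summed over the broadcast axis). Identical requests
    are deduplicated so each reduced result is computed and stored only once.
    """
    raw = kernel(inputs)
    cache = {}
    results = []
    for name, target_len in requested:
        key = (name, target_len)
        if key not in cache:
            out = raw[name]
            assert target_len in (1, len(out)), "only full or fully summed sizes here"
            cache[key] = out if target_len == len(out) else [sum(out)]
        results.append(cache[key])
    return results


kernel = lambda inp: {"g": [2 * v for v in inp]}  # stand-in for a fused kernel
res = run_group_with_sum_to_size(kernel, [1.0, 2.0, 3.0],
                                 [("g", 3), ("g", 1), ("g", 1)])
```

In this example, the two identical ("g", 1) requests end up as the very same list object, so only one reduced result is allocated.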
So with that, IoU forward and backward are completely fused again. Also JITed LSTM backward performance is on par with the native ATen implementation at about 2.25x the time that CuDNN takes.
This is the state of the fuser in PyTorch master today.
Concatenation and SumToSize
If you look at the discussion above, one interaction has not been mentioned: that of FusedConcat and SumToSize.
One typical way for FusedConcat to appear in a backward (and, as you know by now, only backwards have SumToSize) is as the backward of torch.chunk. Indeed, the LSTM above has a chunk just before the pointwise operations, which will have pointwise backwards to fuse.
The trouble is that torch.chunk has overly broad semantics. When you chunk into four pieces and your input has size $7$ in the dimension in question, it will happily produce four outputs of size $2$, $2$, $2$, $1$. The last output can then be broadcast back to size $2$ later. So there are indeed cases where we have to sum to different sizes for the different kernel outputs feeding into one concat. A simpler case is when we have broadcasting in the non-concat dimension. There we can do the concat first and a single SumToSize for the entire thing. Keeping track of all that while trying to be somewhat efficient (as it's run on every execution rather than just on compilation) is a bit tedious and comes in at some 200 lines of additional code.
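Here is a small plain-Python sketch of those chunk semantics along one dimension, plus the matching backward for the case just described. chunk_sizes mirrors the documented behaviour of torch.chunk: pieces of ceil(length / chunks), with a shorter last piece, and possibly fewer pieces than requested. chunk_backward is a hypothetical helper that only handles pieces broadcast from size 1.

```python
import math


def chunk_sizes(length, chunks):
    """Piece sizes torch.chunk produces along one dimension (length > 0 assumed).

    Each piece has ceil(length / chunks) elements except possibly the last,
    so fewer than `chunks` pieces can come back.
    """
    step = math.ceil(length / chunks)
    return [min(step, length - start) for start in range(0, length, step)]


def chunk_backward(piece_grads, piece_sizes):
    # The gradient of chunk is a concatenation, but a piece that was broadcast
    # later (here: from size 1) must first be summed back to its own size.
    grad = []
    for g, size in zip(piece_grads, piece_sizes):
        grad.extend(g if len(g) == size else [sum(g)])
    return grad
```

With chunk_sizes(7, 4) giving [2, 2, 2, 1], the gradient of the last piece arrives with size 2 and must be summed to size 1 before concatenation, unlike the other three.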
Finally, we needed to deal with SumToSizes that undo the addition of dimensions, in combination with FusedConcat. These can happen, for example, when you add a bias and don't spell out the batch size (or sequence length) as an explicit singleton dimension. Thus we needed to teach the fuser to change the dimension along which the concatenation is carried out. That was more bookkeeping than I could do in a single day, but as it might be a quite common case, I fixed it a few days later.
Current speed
With these patches (note that they haven't been merged into master yet^{10}), the PyTorch LSTM benchmark has the jit_premul LSTM backward at about 1.33x the wall-clock time that CuDNN takes. Taking forward and backward together, we're about $25\%$ slower than CuDNN. And that's with an LSTM cell implemented in Python / PyTorch.
Compared with the roughly 3x of PyTorch 1.0, we sped up the backward by about 2.25x. One of my favourite professors when I studied computer science was Arnold Schönhage, most famous for fast multiplication of large integers. In his Fast Algorithms book, he postulates golden rules; one of them is rule 2: "Never waste a factor of 2!" So it would seem that we did useful work here. There is also the iron rule, number 8: "The development of fast algorithms is slow!" The same applies to optimizations like the above.
I'm looking forward to seeing what great things Jeremy and other people can do with something like this. I appreciate your feedback on where it is useful to you and where you see room for further optimization.
Learning about PyTorch
I hope you have enjoyed this all-detail technical write-up of speeding up AI by extending PyTorch JIT fusion.
If you are close to Munich (say, day-travel-close) and want to learn PyTorch / AI for your own projects from one of the leading PyTorch experts, check out my workshop offering. I also do consultancy and bespoke development through my company, MathInf GmbH.
I'd love to hear from you at tv@lernapparat.de.

While I'm proud and happy that the bits I added to the JIT are useful, I only have a very modest contribution here, namely extending the fuser a bit and adding new optimizations, but these would not be possible if there had not been dozens of great and skilled people building all this in the first place. When I try to add things to the JIT, I'm always amazed by how little code I have to add because apparently everything had been designed to make my additions possible. ↩

Actually it's (mis)spelled fusable in the source code. ↩

One interesting thing is that PyTorch's ATen tensor operations library has a great tool, TensorIterator, to deal with the boilerplate of pointwise functions. This includes optimizing the distribution to threads, using vectorized CPU instructions etc. It would be great to also use it in the fuser, if we could adapt it to the compilation done there. Interestingly, it also supports reductions, such as sum but also norm. ↩
Actually, I don't know if it's called that, but it lives in a file called autodiff.cpp. ↩
But I won't claim that TorchScript is a widely used language, and obviously it's not all of Python. So Swift for TensorFlow is certainly set to push the envelope here, similar to some of the AD implemented in Julia. And in terms of language design, Swift and Julia certainly seem to have much in common in some of the aspects relevant here. It'll be interesting to see whether one of the more modern languages will displace Python before Python can itself make the (breaking) changes needed to alleviate some of its direst pain points. So far, I'm not sure that any of the proposed replacements does a significantly better job with how the language looks to new users. If I were to create a new language, I'd try hard to make it look as much like Python as I can. The surface/user experience/use philosophy/whatever you want to call it seems to be something that Python has gotten so much more right than others that it has been hugely successful despite its "compiler-architectural" shortcomings. ↩

And here it would be great to have better typing support from Python, so that one could reason about broadcasting by treating the shape information as part of the type. This would need to allow the interplay between size variables and constants. The JIT does this to some extent, keeping track of tensor dimensions, but one might always wish for more. ↩

Previously, the JIT autodiff didn't support implicit broadcasting. Back then, the JITed LSTM was very close to CuDNN in speed. But not having broadcasting (and silently ignoring it) in autodiff was very bug-prone for users, so the PyTorch team did the right thing and amended the autodiff. ↩

Adam had a lot of great advice throughout my PyTorch and JIT hacking. He also provided great suggestions for improving this article. Thank you! ↩

In order to do the reasoning that follows, we must know that the SumToSize does come from gradients. Thus TorchScript IR has an internal aten::_grad_sum_to_size IR node that is only used when inserting sum_to_size into backwards. ↩
If you are curious, it is in my git repository. ↩