Optimization using Specialization of Arguments in the PyTorch JIT
Today we explore another option for optimization in the JIT. We will see how it presents an alternative that also allows us to make the LSTM backward fast, but tries to be smarter about the bookkeeping aspects.
When you use the PyTorch JIT with TorchScript, you write Python, but the JIT will figure out how to make things run fast. This happens in two stages: first, the JIT converts the TorchScript (i.e. Python) code to a graph that more or less literally represents your Python code. For your `torch.jit.script`ed function `fn`, you can get this graph with `fn.graph` (and `fn.graph.pretty_print()` will give you a TorchScript-like representation). If this is the parsing stage, there is a second, compilation phase, in which the graph is transformed. This happens when you run the function, say with `a, b, c` as arguments. You can see the graph that was created and run with `fn.graph_for(a, b, c)`.^{1} As you might expect from the interface, different arguments might lead to different graphs being executed. And this is exactly what we want to look at.
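To make the two stages concrete, here is a minimal sketch (the function `fn` and its arguments are made up for illustration) of scripting a small function and inspecting both the parsed graph and the graph that was actually run:

```python
import torch

# Script a small toy function; the JIT parses it into a graph.
@torch.jit.script
def fn(a, b, c):
    return a * b + c

a, b, c = torch.randn(4), torch.randn(4), torch.randn(4)
fn(a, b, c)  # run once so a specialized graph exists for these arguments

print(fn.graph)               # the graph parsed from the Python source
print(fn.graph_for(a, b, c))  # the graph created and run for a, b, c
```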
Given that the JIT is a Just-in-Time compiler, one of the key aspects is when to compile anew and when to use a cached compilation result. Whether or not two sets of tensors look the same for this purpose is captured in the `ArgumentSpec` structure. At the time of writing, the criteria are the scalar type, whether a tensor is defined, the device, and the number of dimensions. These are then captured in the `DimensionedTensorType` JIT type. A special case are undefined tensors, which serve as a "special zero" in the context of backward passes, where a differentiable output tensor does not get a gradient because it is not connected to the loss being differentiated.
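As a mental model, here is a toy Python sketch of such a specialization key (this is not the actual C++ `ArgumentSpec`; the names `TensorKey` and `arg_key` are made up for illustration):

```python
from typing import NamedTuple, Optional, Tuple

# Hypothetical per-tensor key mirroring the criteria named above:
# definedness, scalar type, device, and number of dimensions.
class TensorKey(NamedTuple):
    defined: bool
    dtype: Optional[str]
    device: Optional[str]
    dim: Optional[int]  # the rank, not the concrete sizes

def arg_key(args):
    """Build a hashable key from (shape, dtype, device) triples;
    None stands in for an undefined tensor (the "special zero" gradient)."""
    keys = []
    for t in args:
        if t is None:
            keys.append(TensorKey(False, None, None, None))
        else:
            shape, dtype, device = t
            keys.append(TensorKey(True, dtype, device, len(shape)))
    return tuple(keys)

# Different concrete sizes, same dtype/device/rank: the calls share a graph.
k1 = arg_key([((2, 3), "float", "cpu")])
k2 = arg_key([((5, 7), "float", "cpu")])
assert k1 == k2
# A different rank gets its own specialization.
k3 = arg_key([((2, 3, 4), "float", "cpu")])
assert k3 != k1
```

Since the key is hashable, it can index a cache of compiled graphs, which is exactly the lookup the JIT must perform cheaply on every invocation.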
Now, one difficult problem is what to do with structured data types. The PyTorch JIT will look into tuples for that purpose, but not much into lists. As this decision is on the critical path (it happens on every invocation of the function), it has to be fast. So the PyTorch JIT encodes the analysis based on the graph inputs once and then executes the encoded analysis on each call to get a graph specialization (along with a hash code for fast lookups). This is done in the `ArgumentSpecCreator` defined just below `ArgumentSpec` and implemented nearby.
So how can we use this specialization step for optimization? One thing we can do, and I proposed this in a PR currently under review, is to specialize inputs of formal type `Optional[T]` to either `T` (or a `DimensionedTensorType` for `Optional[Tensor]`) or `None`, depending on whether they are given. This allows us to statically prune `if t is None:` branches.
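A toy sketch of the pruning idea (plain Python, with lists of op names standing in for graph blocks; nothing here is actual JIT code):

```python
# Hypothetical sketch: once the argument spec has specialized an
# Optional input to None or to a concrete type, the `if t is None:`
# check has a statically known outcome and only one branch survives.

def prune_none_check(branch_if_none, branch_if_some, specialized_to_none):
    """Statically pick one branch; the other never enters the graph."""
    return branch_if_none if specialized_to_none else branch_if_some

# "Graphs" are just lists of op names in this sketch.
if_none_branch = ["zeros_like"]
if_some_branch = ["mul", "add"]

# Caller passed None: only the None branch survives.
assert prune_none_check(if_none_branch, if_some_branch, True) == ["zeros_like"]

# Caller passed a tensor: the other branch is inlined into the main
# graph, where later passes (e.g. the fuser) can now see its ops.
main_graph = ["sub"] + prune_none_check(if_none_branch, if_some_branch, False)
assert main_graph == ["sub", "mul", "add"]
```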
This optimization is interesting not so much because it allows us to skip the `if` evaluation during execution, but because it reduces the number of blocks (e.g. the "if" and "else" branches of an `if` statement). There are optimizations, notably the fuser compiling custom CUDA kernels, that work only within a single block, not across blocks. So if we can eliminate the `if`, we can move the block we need into the main graph and fuse things that were formerly in the branch with things in the main program.
But can we also use this in a clever way to help with the issue of broadcasting and the `grad_sum_to_size` operations it inserts, which we discussed in the last post?^{2}
We do not know the exact shapes at compilation time, so we do not know beforehand whether broadcasting will happen. But, and this is the crucial observation, it is broadcasting in the forward pass of a network that causes `grad_sum_to_size` to appear in the backward. This means that if we have a way to signal "no broadcasting happened in the forward" to the backward in a way that changes the argument spec of the backward (so that the backward gets its own set of optimizations compared to when broadcasting happened), we can eliminate those `grad_sum_to_size` operations that we know have no effect. What we do, then, is make the target shape an `Optional[List[int]]` and return `None` if no broadcasting happened. If we then specialize that `Optional` to `None` in the argument specification, we can simply remove the `grad_sum_to_size` operations it feeds into. This is done in a second PR, which is remarkably small if you subtract the preparatory argument specialization PR above.
A very small unrelated optimization with quite a bit of impact
Another very simple optimization addressed a different broadcasting issue: as fused kernels typically combine relatively simple arithmetic operations, their execution time is dominated by memory accesses rather than by the computation itself. When analyzing the speed of a pointwise-operation kernel from a fusion-enabled batchnorm, I noticed that it was much slower than the batchnorm kernel. After spending quite a bit of time looking for the cause (and experimentally redoing the indexing to minimize the number of integer division/modulo operations), I found that explicitly telling NVRTC to load the values into the read-only cache helps a great deal. In hindsight, this is understandable: the batchnorm kernel minimizes the reads of the broadcast mean and variance, weight and bias tensors, so at the very least caching helps. It turned out to be a two-line patch that has been added to PyTorch.
In the future, it might be useful to use the cache more efficiently by only caching broadcast tensors. This would require emitting code that varies depending on which tensors are broadcast. One would have to add a corresponding flag to the fuser's argument specification (which works similarly to that for graphs, but is separate). On the other hand, one might use the above indexing (which is uniform over all tensors and is also used in PyTorch's `TensorIterator` framework underlying many pointwise and reduction operations) instead of determining contractible dimensions for each tensor separately.
Conclusion
So we found a couple of not terribly complicated optimizations. While the removal of `grad_sum_to_size` is less generic than the rearrangement from the last blog post, it covers many important cases, such as the LSTM. The optimizations achieve very good performance for some of the most common neural network operations, bringing JITed Python code's performance very close to that of hand-optimized CUDA kernels.
If you are close to Munich (say, day-travel-close) and want to learn PyTorch / AI for your own projects from one of the leading PyTorch experts, check out my workshop offering. I also do consultancy and bespoke development through my company, MathInf GmbH.
I love to hear from you at tv@lernapparat.de.

With `graph_for`, bits might be missing with `pretty_print`, so you need to get by with the regular text representation if you see differentiable graphs or fusion groups being used. ↩
And the complexity of the bookkeeping necessitated in the implementation of sum-to-size in the fuser understandably did not exactly spark joy for the reviewer. ↩