Optimization using Specialization of Arguments in the PyTorch JIT

April 8, 2019

Today we explore another option for optimization in the JIT. We will see that it offers an alternative that also makes the LSTM backward fast, but tries to be smarter about the bookkeeping.

When you use the PyTorch JIT with TorchScript, you write Python, and the JIT works out how to make it run fast. This happens in two stages: First, the JIT converts the TorchScript (i.e. Python) code to a graph that more or less literally represents your Python code. For your torch.jit.scripted function fn, you can get this graph with fn.graph (and fn.graph.pretty_print() will give you a TorchScript-like representation). If we think of this as the parsing stage, there is a second, compilation stage in which the graph is transformed. This happens when you run the function, say with a, b, c as arguments. You can see the graph that was created and run with fn.graph_for(a, b, c) [1]. As you might expect from the interface, different arguments might lead to different graphs being executed. And this is exactly what we want to look at.

Given that the JIT is a Just-in-Time compiler, one of the key questions is when to compile anew and when to use a cached compilation result. Whether or not two tensors look the same for this purpose is captured in the ArgumentSpec structure. At the time of writing, the criteria are the scalar type, whether a tensor is defined, the device, and the number of dimensions. These are then captured in the DimensionedTensorType JIT type. Undefined tensors are a special case: they are a "special zero" in the context of backward passes, where a differentiable output tensor does not get a gradient because it is not connected to the loss that is differentiated. Now, one difficult problem is what to do with structured data types. The PyTorch JIT will look into tuples for that purpose, but mostly not into lists. As this decision is in the critical path (it happens on every invocation of the function), it has to be fast. So the PyTorch JIT encodes the analysis based on the graph inputs (once) and then executes the encoded analysis to get a graph specialization (along with a hash code for fast lookups). This is done in the ArgumentSpecCreator defined just below ArgumentSpec and implemented nearby.
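To make the caching criterion concrete, here is a toy model in plain Python. It is only a sketch of the idea, not the JIT's actual data structures: FakeTensor, argument_spec and SpecializingCache are names I made up for illustration, and None stands in for an undefined tensor. The spec records the scalar type, the device and the number of dimensions, but not the sizes, and it is used as the (hashed) key of a cache of compiled plans.

```python
from dataclasses import dataclass
from typing import Dict, Optional, Tuple


@dataclass(frozen=True)
class FakeTensor:
    """Hypothetical stand-in for a tensor: only the metadata matters here."""
    dtype: str
    device: str
    shape: Tuple[int, ...]


def argument_spec(args: Tuple[Optional[FakeTensor], ...]) -> tuple:
    """Summarize the arguments the way the post describes ArgumentSpec:
    defined-ness, scalar type, device and number of dimensions (not the sizes)."""
    spec = []
    for a in args:
        if a is None:  # None plays the role of an undefined tensor
            spec.append(("undefined",))
        else:
            spec.append(("tensor", a.dtype, a.device, len(a.shape)))
    return tuple(spec)


class SpecializingCache:
    """Compile once per distinct spec, then reuse the cached result."""

    def __init__(self) -> None:
        self.plans: Dict[tuple, str] = {}
        self.compilations = 0

    def plan_for(self, *args: Optional[FakeTensor]) -> str:
        spec = argument_spec(args)
        if spec not in self.plans:  # dict lookup hashes the spec
            self.compilations += 1
            self.plans[spec] = f"plan #{self.compilations} for {spec}"
        return self.plans[spec]


cache = SpecializingCache()
a = FakeTensor("float32", "cpu", (2, 3))
b = FakeTensor("float32", "cpu", (5, 7))   # other sizes, same number of dims
c = FakeTensor("float32", "cpu", (5,))     # different number of dims
cache.plan_for(a)
cache.plan_for(b)   # reuses the plan compiled for a
cache.plan_for(c)   # triggers a second compilation
```

In this model, a and b share one plan because they differ only in their sizes, while c gets its own.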

So how can we use this specialization step for optimization? One thing we can do - and I proposed this in a PR currently under review - is to specialize inputs of formal type Optional[T] to either T (or a DimensionedTensorType for Optional[Tensor]) or None, depending on whether they are given. This allows us to statically prune if t is None: branches. This optimization is interesting, not so much because it lets us skip evaluating the if during execution, but because it reduces the number of blocks (e.g. the "if" and "else" branches of the if statement). There are optimizations, notably the fuser compiling custom CUDA kernels, that work only within a single block, but not across blocks. So if we can eliminate the if, we can move the block we need into the main graph and fuse operations that used to be in the branch with operations in the main program.
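The pruning can be imitated at the level of Python source with the standard ast module. This is only an analogy: the JIT works on its own graph representation, not on Python syntax trees, and PruneNoneChecks and specialize are hypothetical names for this sketch. It also assumes that the argument is never reassigned inside the function. Given the knowledge that an argument is (or is not) None, each matching if is replaced by the branch that would be taken, so the remaining statements end up in a single block.

```python
import ast


class PruneNoneChecks(ast.NodeTransformer):
    """Replace `if x is None:` / `if x is not None:` by the branch that is taken,
    for every argument name x whose None-ness is known."""

    def __init__(self, is_none):
        # is_none maps argument names to True (specialized to None)
        # or False (a value is given)
        self.is_none = is_none

    def visit_If(self, node):
        self.generic_visit(node)  # handle nested ifs first
        test = node.test
        if (isinstance(test, ast.Compare)
                and len(test.ops) == 1
                and isinstance(test.ops[0], (ast.Is, ast.IsNot))
                and isinstance(test.left, ast.Name)
                and test.left.id in self.is_none
                and isinstance(test.comparators[0], ast.Constant)
                and test.comparators[0].value is None):
            condition = self.is_none[test.left.id]
            if isinstance(test.ops[0], ast.IsNot):
                condition = not condition
            taken = node.body if condition else node.orelse
            # splice the taken branch into the enclosing block
            return taken or [ast.Pass()]
        return node


def specialize(source, is_none):
    """Return the source with the statically decidable None checks removed."""
    tree = PruneNoneChecks(is_none).visit(ast.parse(source))
    ast.fix_missing_locations(tree)
    return ast.unparse(tree)  # needs Python 3.9 or later


SOURCE = '''
def scale(x, bias):
    if bias is None:
        y = x * 2.0
    else:
        y = x * 2.0 + bias
    return y
'''

without_bias = specialize(SOURCE, {"bias": True})
with_bias = specialize(SOURCE, {"bias": False})
print(without_bias)
print(with_bias)
```

Both specialized versions consist of a single straight-line block, which is the situation in which a block-local optimization like the fuser can see all the arithmetic at once.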

But can we also use this in a clever way to help with the issue of broadcasting and the resulting insertion of grad_sum_to_size that we discussed in the last post? [2]

We do not know the exact shapes at compilation time, so we do not know beforehand whether broadcasting will happen. But - and this is the crucial observation - it is the broadcasting in the forward pass of a network that causes grad_sum_to_size to exist in the backward. This means that if we have a way to signal "no broadcasting happened in the forward" to the backward in a way that changes the argument spec of the backward (so that the backward gets its own set of optimizations compared to when broadcasting happened), we can simply eliminate those grad_sum_to_size operations that we know have no effect. So what we do is make the target shape an Optional[List[int]] and return None if no broadcasting happened. If we then specialize that Optional to None in the argument specification, we can remove the grad_sum_to_size it feeds into. This is done in a second PR - which is remarkably small if you subtract the preparatory argument specialization PR above.
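Here is a small pure-Python sketch of that protocol. It is a toy, not the code from the PR: matrices are lists of rows, only broadcasting along the first dimension is supported, and add_forward, add_backward and target_size are names invented for this illustration. The forward records, per input, either None (no broadcasting) or the input's shape, and the backward's grad_sum_to_size is the identity whenever it is handed None.

```python
from typing import List, Optional, Tuple

Matrix = List[List[float]]
Size = Optional[Tuple[int, int]]


def shape(t: Matrix) -> Tuple[int, int]:
    return (len(t), len(t[0]))


def target_size(in_shape: Tuple[int, int], out_shape: Tuple[int, int]) -> Size:
    """None signals 'this input was not broadcast'."""
    return None if in_shape == out_shape else in_shape


def add_forward(a: Matrix, b: Matrix):
    """a + b where an input with a single row is broadcast along dim 0 (toy restriction)."""
    rows = max(len(a), len(b))
    ea = a * rows if len(a) == 1 else a  # repeat the single row
    eb = b * rows if len(b) == 1 else b
    out = [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(ea, eb)]
    sizes = (target_size(shape(a), shape(out)), target_size(shape(b), shape(out)))
    return out, sizes


def grad_sum_to_size(grad: Matrix, size: Size) -> Matrix:
    if size is None:
        # Nothing was broadcast: the operation is a no-op, so a graph
        # specialized to size=None could drop it entirely.
        return grad
    # Toy restriction: undo broadcasting along dim 0 by summing over the rows.
    assert size[0] == 1 and size[1] == len(grad[0])
    return [[sum(column) for column in zip(*grad)]]


def add_backward(grad_out: Matrix, sizes):
    size_a, size_b = sizes
    return grad_sum_to_size(grad_out, size_a), grad_sum_to_size(grad_out, size_b)


a = [[1.0, 2.0], [3.0, 4.0]]
b = [[10.0, 20.0]]                       # one row, broadcast over both rows of a
out, sizes = add_forward(a, b)           # sizes == (None, (1, 2))
grad_out = [[1.0, 1.0], [1.0, 1.0]]
grad_a, grad_b = add_backward(grad_out, sizes)
```

For a, the gradient is passed through untouched, while for b it is summed back to a single row. In an LSTM-like setting where most operands have matching shapes, most of the sizes would be None.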

A very small unrelated optimization with quite a bit of impact

Another very simple optimization addresses a different issue with broadcasting: As fused kernels typically combine relatively simple arithmetic operations, the execution time is dominated by the memory accesses rather than the computation itself. When analyzing the speed of a point-wise operation kernel from a fusion-enabled batch-norm, I noticed that it was much slower than the batch-norm kernel. After spending quite a bit of time looking for the cause (and experimentally redoing the indexing to minimize the number of integer division/modulo operations), I found that explicitly telling NVRTC to load the values into the read-only cache helps a great deal. In hindsight, this makes sense: the batch-norm kernel minimizes the reads of the broadcasted mean and variance, weight and bias tensors, so the least we can do is to cache them. It turned out to be a two-line patch that has been added to PyTorch.

In the future, it might be useful to use the cache more efficiently by only caching broadcasted tensors. This would require emitting code that varies depending on which tensors are broadcasted. One would have to add a corresponding flag to the fuser's argument specification (which works similarly to that for graphs, but is separate). On the other hand, one might use the above indexing (which is uniform over all tensors and also used in PyTorch's TensorIterator framework underlying many pointwise and reducing operations), instead of distinguishing contractible dimensions for all tensors separately.
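To illustrate what "uniform over all tensors" can mean, here is a rough Python sketch of stride-based indexing, under the assumption that a broadcasted dimension is represented by a stride of 0. The function names are made up, and this is neither the fuser's generated code nor TensorIterator itself. The linear index is decomposed once with one division/modulo per dimension, and every tensor then gets its offset as a dot product with its own strides.

```python
def unravel(linear, out_shape):
    """Decompose a linear index into a multi-index: one divmod per dimension."""
    index = []
    for size in reversed(out_shape):
        linear, i = divmod(linear, size)
        index.append(i)
    return tuple(reversed(index))


def broadcast_strides(in_shape, out_shape):
    """Contiguous strides of in_shape, with stride 0 in broadcasted dimensions."""
    padded = (1,) * (len(out_shape) - len(in_shape)) + tuple(in_shape)
    strides = []
    step = 1
    for size, out_size in zip(reversed(padded), reversed(out_shape)):
        if size not in (1, out_size):
            raise ValueError("shapes are not broadcastable")
        strides.append(0 if size == 1 else step)
        step *= size
    return tuple(reversed(strides))


def offsets(linear, out_shape, all_strides):
    """Offsets into each tensor's storage for one output element."""
    index = unravel(linear, out_shape)  # shared by all tensors
    return [sum(i * s for i, s in zip(index, strides)) for strides in all_strides]


out_shape = (2, 3)
x_strides = broadcast_strides((2, 3), out_shape)   # (3, 1): a full-size input
w_strides = broadcast_strides((3,), out_shape)     # (0, 1): broadcast over dim 0
reads = [offsets(k, out_shape, [x_strides, w_strides]) for k in range(6)]
```

In this example the full-size input is read at six different offsets, while the broadcasted one only ever touches three, which is why caching pays off most for the broadcasted tensors.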

Conclusion

So we found a couple of not terribly complicated optimizations. While the removal of grad_sum_to_size is less generic than the rearrangement from the last blog post, it covers many important cases like the LSTM. The optimizations get very good performance for some of the most common neural network operations, bringing JITed Python code's performance very close to that of hand-optimized CUDA kernels.

If you are close to Munich (say, day-travel-close), and want to learn PyTorch / AI for your own projects from one of the leading PyTorch experts, check out my workshop offering. I also do consultancy and bespoke development through my company, MathInf GmbH.

I love to hear from you at tv@lernapparat.de.

1. For graphs obtained from graph_for, bits might be missing in the pretty_print output, so you need to get by with the regular text representation if you see differentiable graphs or fusion groups being used.

2. And the complexity of the bookkeeping necessitated in the implementation of sum-to-size in the fuser understandably did not exactly spark joy for the reviewer.