Today we explore another option for optimization in the JIT. We will see an alternative that also makes the LSTM backward fast, but tries to be smarter about the bookkeeping.
When you use the PyTorch JIT with TorchScript, you write Python, and the JIT works out how to make it run fast. This happens in two stages. First, the JIT converts the TorchScript (i.e. Python) code to a graph that more or less literally represents your Python code. For your function fn, you can get this graph with fn.graph (fn.graph.pretty_print() will give you a TorchScript-like representation). If this is the parsing stage, there is a second compilation phase, in which the graph is transformed. This happens when you run the function, say with a, b, c as arguments. You can see the graph that was created and run with fn.graph_for(a, b, c).[1] As you might expect from the interface, different arguments might lead to different graphs being executed. And this is exactly what we want to look at.
Given that the JIT is a Just-in-Time compiler, one of the key questions is when to compile anew and when to use a cached compilation result. Whether or not two tensors look the same for this purpose is captured in the ArgumentSpec structure. At the time of writing, the criteria are the scalar type, whether a tensor is defined, the device, and the number of dimensions. These are then captured in the DimensionedTensorType JIT type. Undefined tensors are a special case: they are a "special zero" in the context of backward passes, where a differentiable output tensor does not get a gradient because it is not connected to the loss that is differentiated.
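To make the caching criterion concrete, here is a toy model in plain Python. It is a sketch under assumptions, not PyTorch's implementation: FakeTensor, arg_spec and SpecializingCache are hypothetical names, tensors carry only metadata, and the spec records just the four properties listed above.

```python
from dataclasses import dataclass
from typing import Callable, Dict, Optional, Tuple


@dataclass(frozen=True)
class FakeTensor:
    """Hypothetical stand-in for a tensor: metadata only, no data."""
    dtype: str
    device: str
    shape: Tuple[int, ...]


# One entry per argument: (defined, dtype, device, number of dimensions).
SpecEntry = Tuple[bool, Optional[str], Optional[str], Optional[int]]
Spec = Tuple[SpecEntry, ...]


def arg_spec(args: Tuple[Optional[FakeTensor], ...]) -> Spec:
    """Summarize the arguments by the properties that decide specialization."""
    entries = []
    for t in args:
        if t is None:  # models an undefined tensor
            entries.append((False, None, None, None))
        else:
            entries.append((True, t.dtype, t.device, len(t.shape)))
    return tuple(entries)


class SpecializingCache:
    """Compile once per distinct spec and reuse the result afterwards."""

    def __init__(self, compile_fn: Callable[[Spec], str]):
        self.compile_fn = compile_fn
        self.plans: Dict[Spec, str] = {}
        self.compilations = 0

    def plan_for(self, *args: Optional[FakeTensor]) -> str:
        spec = arg_spec(args)
        if spec not in self.plans:  # the dict lookup hashes the spec tuple
            self.compilations += 1
            self.plans[spec] = self.compile_fn(spec)
        return self.plans[spec]


cache = SpecializingCache(lambda spec: f"plan for {spec}")
a = FakeTensor("float32", "cuda:0", (8, 16))
b = FakeTensor("float32", "cuda:0", (32, 64))    # other sizes, same ndim
c = FakeTensor("float32", "cuda:0", (8, 16, 4))  # one more dimension

cache.plan_for(a)
cache.plan_for(b)     # same spec as `a`: cache hit
cache.plan_for(c)     # different number of dimensions: new compilation
cache.plan_for(None)  # undefined tensor: new compilation
print(cache.compilations)  # 3
```

Note that the sizes themselves are deliberately not part of the key in this sketch, so a and b share one compiled plan.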
Now, one difficult problem is what to do with structured data types. The PyTorch JIT will look into tuples for that purpose, but largely not into lists. As this decision is in the critical path - it is made on every invocation of the function - it has to be fast. So the PyTorch JIT encodes the analysis based on the graph inputs (once) and then executes the encoded analysis on each call to get a graph specialization (along with a hash code for fast lookups). This is done in the ArgumentSpecCreator, defined just below ArgumentSpec and implemented nearby.
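The "analyze once, execute on every call" idea can be sketched as a tiny instruction list. This is an illustration only: the instruction names, the string encoding of formal types and the (dtype, device, shape) tuples standing in for tensors are all assumptions of this sketch, not the actual ArgumentSpecCreator code.

```python
from typing import Any, List, Tuple

# Hypothetical encoding of formal input types:
#   "Tensor", "int", ("Tuple", [element types...]), ("List", element type)
ENTER_TUPLE, LEAVE_TUPLE, SPECIALIZE_TENSOR, SKIP = range(4)


def build_program(formal_types: List[Any]) -> List[int]:
    """Run once per graph: turn the formal input types into instructions."""
    program: List[int] = []
    for ty in formal_types:
        if ty == "Tensor":
            program.append(SPECIALIZE_TENSOR)
        elif isinstance(ty, tuple) and ty[0] == "Tuple":
            program.append(ENTER_TUPLE)
            program.extend(build_program(ty[1]))
            program.append(LEAVE_TUPLE)
        else:
            # Lists and scalars are not inspected in this sketch.
            program.append(SKIP)
    return program


def create_spec(program: List[int], inputs: Tuple[Any, ...]) -> Tuple[Any, ...]:
    """Run on every call: execute the instructions, no type analysis needed."""
    spec = []
    stack = []  # (container, next position) of the enclosing tuples
    current, pos = inputs, 0
    for instr in program:
        if instr == ENTER_TUPLE:
            stack.append((current, pos + 1))
            current, pos = current[pos], 0
        elif instr == LEAVE_TUPLE:
            current, pos = stack.pop()
        elif instr == SPECIALIZE_TENSOR:
            dtype, device, shape = current[pos]
            spec.append((dtype, device, len(shape)))
            pos += 1
        else:  # SKIP
            pos += 1
    return tuple(spec)


formal = ["Tensor", ("Tuple", ["Tensor", "int"]), ("List", "Tensor")]
program = build_program(formal)

x = ("float32", "cpu", (4, 5))
h = ("float64", "cpu", (4,))
spec = create_spec(program, (x, (h, 3), [x, x]))
key = hash(spec)  # used for the fast cache lookup
print(spec)  # (('float32', 'cpu', 2), ('float64', 'cpu', 1))
```

The tensor inside the tuple contributes to the spec, while the list of tensors is skipped, which mirrors the tuple-versus-list distinction described above.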
So how can we use this specialization step for optimization? One thing we can do - and I proposed this in a PR currently under review - is to specialize inputs of formal type Optional[T] to either T or None, depending on whether they are given. This allows us to statically prune if t is None: branches.
This optimization is interesting not so much because it lets us skip evaluating the if at execution time, but because it reduces the number of blocks (e.g. the "if" and "else" branches of the if statement). There are optimizations, notably the fuser compiling custom CUDA kernels, that work only within a single block, but not across blocks. So if we can eliminate the if, we can move the block we need into the main graph and fuse operations that were formerly in the branch with operations in the main program.
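A minimal sketch of such a pruning pass, on a made-up toy IR rather than the real JIT graph: a block is a list of nodes, and the hypothetical ("if_none", var, then_block, else_block) node models an if var is None: statement. Once the argument spec says whether var is None, the taken branch is inlined into the parent block.

```python
from typing import Dict, List, Tuple

# Toy IR (hypothetical): a block is a list of nodes. A node is either
#   ("op", name)                               a plain operation, or
#   ("if_none", var, then_block, else_block)   models `if var is None:`
Node = Tuple
Block = List[Node]


def prune_none_checks(block: Block, is_none: Dict[str, bool]) -> Block:
    """Inline the taken branch of every None check decided by the spec."""
    out: Block = []
    for node in block:
        if node[0] == "if_none":
            _, var, then_block, else_block = node
            if var in is_none:
                taken = then_block if is_none[var] else else_block
                out.extend(prune_none_checks(taken, is_none))
            else:
                # Nothing known about `var`: keep the check, prune inside.
                out.append(("if_none", var,
                            prune_none_checks(then_block, is_none),
                            prune_none_checks(else_block, is_none)))
        else:
            out.append(node)
    return out


def count_blocks(block: Block) -> int:
    """Count this block plus all nested branch blocks."""
    return 1 + sum(count_blocks(n[2]) + count_blocks(n[3])
                   for n in block if n[0] == "if_none")


graph = [
    ("op", "mul"),
    ("if_none", "bias", [("op", "identity")], [("op", "add")]),
    ("op", "tanh"),
]
with_bias = prune_none_checks(graph, {"bias": False})
without_bias = prune_none_checks(graph, {"bias": True})

print(count_blocks(graph), count_blocks(with_bias))  # 3 1
print(with_bias)  # [('op', 'mul'), ('op', 'add'), ('op', 'tanh')]
```

After pruning, mul, add and tanh sit in one block, which is the precondition for a block-local pass such as the fuser to combine them.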
We do not know the exact shapes at compilation time, so we do not know beforehand whether broadcasting will happen. But - and this is the crucial observation - it is the broadcasting in the forward pass of a network that causes grad_sum_to_size to exist in the backward. So if we have a way to signal "no broadcasting happened in the forward" to the backward in a way that changes the argument spec of the backward (so that the backward gets its own set of optimizations compared to when broadcasting happened), we can eliminate those grad_sum_to_size operations that we know have no effect. What we thus do is make the target shape an Optional[List[int]] and return None if no broadcasting happened. If we then specialize on that None in the argument specification, we can remove the grad_sum_to_size it feeds into. This is done in a second PR - which is remarkably small if you subtract the preparatory argument specialization PR above.
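The optional target shape can be illustrated with nested Python lists. This is a toy 2-D sketch under assumptions: target_size is a hypothetical helper for the forward-side bookkeeping, and this grad_sum_to_size is a plain Python function that only mimics the idea of the JIT operation of the same name.

```python
from typing import List, Optional

Matrix = List[List[float]]


def target_size(input_shape: List[int],
                output_shape: List[int]) -> Optional[List[int]]:
    """Forward-side bookkeeping: None signals 'this input was not broadcast'."""
    return None if input_shape == output_shape else list(input_shape)


def grad_sum_to_size(grad: Matrix, size: Optional[List[int]]) -> Matrix:
    """Toy 2-D version: sum over the dimensions that were expanded from 1."""
    if size is None:
        return grad  # identity: a specialized graph can drop this call
    rows, cols = size
    if rows == 1:
        grad = [[sum(col) for col in zip(*grad)]]
    if cols == 1:
        grad = [[sum(row)] for row in grad]
    return grad


# Forward of out = x + bias with x of shape [2, 3] and bias of shape [1, 3].
out_shape = [2, 3]
x_size = target_size([2, 3], out_shape)     # None: x was not broadcast
bias_size = target_size([1, 3], out_shape)  # [1, 3]: bias was broadcast

grad_out = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
grad_x = grad_sum_to_size(grad_out, x_size)
grad_bias = grad_sum_to_size(grad_out, bias_size)

print(grad_x is grad_out)  # True
print(grad_bias)           # [[5.0, 7.0, 9.0]]
```

Because None versus a list is visible in the argument spec of the backward, the None case can get its own specialized graph in which the call is removed entirely rather than checked at run time.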
A very small unrelated optimization with quite a bit of impact
Another very simple optimization addresses a different aspect of broadcasting: As fused kernels typically combine relatively simple arithmetic operations, the execution time is dominated by the memory accesses rather than the computation itself. When analyzing the speed of a point-wise operation kernel from a fusion-enabled batch-norm, I noticed that it was much slower than the batch-norm kernel. After spending quite a bit of time looking for the cause (and experimentally redoing the indexing to minimize the number of integer division/modulo operations), I found that explicitly telling NVRTC to load the values into the read-only cache helps a great deal. In hindsight, this makes sense - the batch-norm kernel minimizes the reads of the broadcasted mean, variance, weight and bias tensors, so in the fused kernel, which does not, caching at least helps. It turned out to be a two-line patch that has been added to PyTorch.
In the future, it might be useful to use the cache more efficiently by only caching broadcasted tensors. This would require emitting code that varies depending on which tensors are broadcasted. One would have to add a corresponding flag to the fuser's argument specification (which works similarly to that for graphs, but is separate). On the other hand, one might use the above indexing (which is uniform over all tensors and also used in PyTorch's TensorIterator framework underlying many pointwise and reducing operations), instead of distinguishing contractible dimensions for all tensors separately.
So we found a couple of not terribly complicated optimizations. While the removal of grad_sum_to_size is less generic than the rearrangement from the last blog post, it covers many important cases like the LSTM. The optimizations give very good performance for some of the most common neural network operations, bringing the performance of JITed Python code very close to that of hand-optimized CUDA kernels.
If you are close to Munich (say, day-travel-close), and want to learn PyTorch / AI for your own projects from one of the leading PyTorch experts, check out my workshop offering. I also do consultancy and bespoke development through my company, MathInf GmbH.
I would love to hear from you at email@example.com.
[1] Some graph_for bits might be missing with pretty_print, so you need to get by with the regular text representation if you see differentiable graphs or fusion groups being used.
[2] And the complexity of the bookkeeping necessitated in the implementation of sum-to-size in the fuser understandably did not exactly spark joy for the reviewer.