PyTorch style advice: Avoid inplace
Every now and then, people wonder whether using inplace operations is a good way to reduce the memory consumption of PyTorch models. Let us look at this in some detail.
Recently, I discussed the reuse of stateless modules on Twitter, and Andrej Karpathy, whose minGPT code was one of the examples, chimed in and noted:
I've always been very sketched out by module re-use and in-place operations. These can both lead to very subtle and hard-to-find issues, and would always advise against their use, esp early in development.
So what is it about inplace?
Inplace operations
Many PyTorch operations come with inplace variants. The convention is that these are suffixed with a single _ and return the object they have manipulated (the same object as the corresponding input). For example, given a Tensor t, the function t.mul(2) will return a new tensor, the same as we would get from 2 * t. The inplace operation t.mul_(2) will multiply the elements of t by 2 without making a copy and return t. While the effect is the same as t *= 2, returning the tensor has the advantage that we can conveniently chain these operations, e.g. t.mul_(2).add_(1).
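A minimal sketch of this convention (the values in the comments follow from the arange input used here):

```python
import torch

t = torch.arange(4.0)                   # tensor([0., 1., 2., 3.])

out = t.mul(2)                          # out-of-place: a new tensor
print(out.data_ptr() == t.data_ptr())   # False: different memory

same = t.mul_(2)                        # inplace: t itself is modified and returned
print(same is t)                        # True: the very same Python object

t.mul_(2).add_(1)                       # returning t makes chaining convenient
print(t)                                # tensor([ 1.,  5.,  9., 13.])
```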
At the higher-level module interface, some nn.Modules for activations provide the option to work inplace, e.g. ReLU(inplace=True) will use the inplace torch.nn.functional.relu_ instead of torch.nn.functional.relu.
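A small sketch of both variants:

```python
import torch
import torch.nn.functional as F
from torch import nn

x = torch.randn(3)
y = F.relu(x)       # out-of-place: x is unchanged
z = F.relu_(x)      # inplace: x itself is clamped at 0 and returned
print(z is x)       # True

act = nn.ReLU(inplace=True)   # module-level way to ask for the inplace variant
h = act(torch.randn(2, 3))    # overwrites its input's memory instead of allocating
```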
Inplace and autograd
But there are caveats to inplace operations. For PyTorch, the two main ones are:
- inplace operations sometimes don't mix well with autograd, and
- inplace operations make it very hard to (automatically) reason about code.
Let us look at the first. The following code snippet will produce what is perhaps one of the most infamous PyTorch error messages, RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation:
```python
import torch

a = torch.randn(5, requires_grad=True)
b = 2 * a
c = b ** 2  # replace this with c = b + 2 and the autograd error will go away
b += 1      # inplace operation!
c.sum().backward()
```
The reason is that the PowBackward backward operation attached to c wants to compute grad_b = grad_c * 2 * (b ** 1), because $f(b) = b^\alpha$ has the derivative $f'(b) = \alpha b^{\alpha - 1}$, and autograd detects that b has been modified.
How does autograd detect that? Tensors come with a version counter, and autograd keeps a copy of the counter when it saves tensors in the autograd graph and compares the current version to the saved one during the backward. (Of course, one might wonder whether autograd could just mark those tensors as read-only to give us an error message during the forward when they are modified.) We can print b._version before and after the b += 1 to see that it indeed bumps the version.
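For example (note that _version is an internal attribute rather than a stable public API, so this is just for poking around):

```python
import torch

a = torch.randn(5, requires_grad=True)
b = 2 * a
print(b._version)   # 0
b += 1              # the inplace add bumps the version counter ...
print(b._version)   # 1 ... which autograd compares against during the backward
```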
We will return to the autograd interaction in a moment, but for completeness, let us look briefly at the second issue with inplace operations.
When PyTorch or other libraries analyze PyTorch programs, for example in the JIT, they will consider reordering and fusion opportunities. But there is a difficult question when it comes to inplace operations: which tensors does a given inplace operation change? Obviously the one we operate on, but other tensors might be views of the same underlying memory area. For example, t.view(-1) will give a flattened view of t (and fail if the strides make this impossible), and t.permute or t.transpose will also give new tensors sharing memory with their inputs. The PyTorch JIT keeps an alias database to track these things, but it is one of the most complex aspects and sometimes cannot work perfectly: for example, PyTorch cannot know whether the input tensors a given (scripted) function receives are views of each other. Also, some functions like .reshape return views sometimes (when it is possible) but not always.
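A small sketch of this aliasing behaviour (the data_ptr comparisons show which tensors share storage):

```python
import torch

t = torch.zeros(2, 3)
v = t.view(-1)          # flattened view, shares memory with t
p = t.permute(1, 0)     # also a view into the same storage
print(v.data_ptr() == t.data_ptr(), p.data_ptr() == t.data_ptr())  # True True

v[0] = 42               # writing through the view ...
print(t[0, 0])          # ... changes t as well: tensor(42.)

r = t.t().reshape(-1)   # .reshape copies when a view is impossible
print(r.data_ptr() == t.data_ptr())  # False: the transposed t cannot be flattened as a view
```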
Often, the PyTorch JIT or other libraries will pass on optimizing code that uses inplace operations rather than risk silently making assumptions that can fail with sufficiently creative use. This is why inplace can sometimes hinder optimization.
Saving memory with inplace
So far so good; this is what I always say when advising to be careful with inplace, along with the note that it should not be the first optimization.
But recently it occurred to me that the optimization is even more dubious:
Let us look at memory use in neural network training in some detail. Roughly speaking, there are three kinds of tensors consuming (mostly GPU) memory:
- Global state, i.e. our model's parameters, momentum or other statistics for the optimizer. I would morally count gradients in here, too. (Stas Bekman points out that I did not initially mention gradients, and these use quite a bit of memory, especially in transformer-like architectures. To me, gradients are somewhat in between global state and intermediate computation: in some respects they are not quite global state, as they are cleared after the gradient step and are not needed during the forward but instead allocated bit by bit during the backward, rather than being collected in the forward pass. If you accumulate over several forward computations (for large, memory-hungry models where all this is important), they are even more like global state. And they are what is left over "after autograd" finishes. So at this level of abstraction, I'd say they are very similar to parameters. Thank you, Stas, for this important comment and for pointing out some other typos in the article!)
- Intermediate computation results saved for the backward, i.e. things collected during the forward pass that autograd keeps around for the backward (see the sketch after this list). This is (in addition to the global state) the reason why batch sizes typically need to be smaller in training than in validation to fit into GPU memory.
- Temporary computation results that are not saved for the backward. These are local to the functions they are computed in and can be deallocated once the variable goes out of scope. (Python sometimes seems to not deallocate memory right when it is no longer needed, but only after a while or when the garbage collector is run, via gc.collect() from the standard library gc module. This can happen in particular when there are reference cycles, i.e. a points to b and b, directly or indirectly, back to a, keeping the reference counts from dropping to 0.)
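If you want to see the second category in action, here is a minimal sketch using torch.autograd.graph.saved_tensors_hooks (available in reasonably recent PyTorch versions) to record what autograd stashes during the forward; the tensor sizes are made up for illustration:

```python
import torch

saved_shapes = []

def pack(t):                        # called whenever autograd saves a tensor
    saved_shapes.append(tuple(t.shape))
    return t                        # keep the tensor itself (no offloading here)

def unpack(t):
    return t

x = torch.randn(64, 128, requires_grad=True)
w = torch.randn(128, 256, requires_grad=True)

with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    y = torch.relu(x @ w)           # matmul saves its inputs, relu saves its result
    loss = y.pow(2).sum()           # pow saves its input y

loss.backward()
print(saved_shapes)                 # shapes of the intermediates kept alive until the backward
```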
So what does that mean for inplace? As discussed above, we cannot use inplace operations when autograd needs the tensor we would modify for the backward. So we can only use inplace for the third kind of tensors, the temporary ones. But this means that inplace operations don't help us save on the more costly long- and medium-term use of memory; they only optimize memory use in short-lived tensors. (Note that this is a bit different for inference, when there are only global and temporary tensors. But there, PyTorch is developing things like a static graph engine that will automatically identify fusion and memory reuse opportunities, and engines like ONNXRuntime and TensorRT will likely do similar things. So manually prescribing inplace does not seem that great a win in any situation.)
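To make this concrete, here is a minimal sketch where the inplace operation is harmless because it only touches a temporary that autograd does not need (multiplying by a constant only requires the constant for the backward). Replacing x * 2 with, say, x.exp(), which saves its result for the backward, would again trigger the familiar RuntimeError:

```python
import torch

x = torch.randn(5, requires_grad=True)
tmp = x * 2          # the backward of "* 2" only needs the constant 2, not tmp
h = tmp.add_(1)      # so modifying tmp inplace is fine ...
h.sum().backward()   # ... and the backward goes through without complaint
print(x.grad)        # tensor([2., 2., 2., 2., 2.])
```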
Conclusion
Avoid inplace, unless you have a good reason to want it. The gains are very limited, and sometimes it blows up or silently prevents optimizations; if it does, it probably was not the best use of your time.
PyTorch courses
You can geek out much more about autograd details in my all about autograd online course. I have also published a set of slides covering some autograd material for a PyTorch tutorial I gave this summer.
I'm currently working on the courses Introduction to deep learning with PyTorch and Transformers with PyTorch. Drop me a line if you want to be notified when they are ready or when the next batch starts.
As always I appreciate your comments and feedback at tv@lernapparat.de.