Happy New Year!
The PyTorch JIT and its TorchScript language cover important parts of PyTorch's core goals. But are we integrating them into PyTorch in the right way? As my wish and plan for 2021, here are my thoughts on why and how they should take a more central role.
Why the JIT is important to PyTorch
The PyTorch JIT and its TorchScript language serve three main purposes:
- Deployment: When we want to deploy our model on mobile, in a C++ application, or generally where we don't have a full Python installation, the JIT gives us a way to run models without Python. Even from within Python, the greater separation from the interpreter can be advantageous: JITed computation does not need the Global Interpreter Lock (GIL) except for returning the result.
- Optimization: While we can usually get PyTorch to saturate our GPU compute facilities, executing operations one by one often is inefficient. By symbolically representing our models, the JIT enables holistic optimizations, in particular through the fusion of operators. It does so seamlessly in the sense that once our model is in TorchScript, the optimizations will be deployed automatically where applicable. This avoids the need to write specialized kernels.
- Extensibility: For advanced uses, including bespoke acceleration of computation, the JIT offers an interface that allows third parties to offer the PyTorch user experience while adding their capabilities. Examples range from TorchVision's ops for detection models to TRTorch to Apache TVM.
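To make the deployment point concrete, here is a minimal sketch of the script, save and load round trip. The toy module TinyModel is hypothetical; torch.jit.script, torch.jit.save and torch.jit.load are the standard TorchScript entry points, and the saved archive is what a C++ application would load with torch::jit::load.

```python
import io

import torch


class TinyModel(torch.nn.Module):
    """Hypothetical stand-in for a real network."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.linear(x)) * 2.0


model = TinyModel().eval()

# Compile the module (including its submodules) to TorchScript.
scripted = torch.jit.script(model)

# Serialize to an in-memory buffer; a file path works the same way.
buffer = io.BytesIO()
torch.jit.save(scripted, buffer)
buffer.seek(0)

# Load it back; the Python source of TinyModel is not needed at this point.
loaded = torch.jit.load(buffer)

x = torch.randn(3, 4)
print(torch.allclose(loaded(x), model(x)))  # True
```

The same archive is what the C++ side consumes, which is why the deployment story hinges on getting the model into TorchScript in the first place.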
The JIT and TorchScript serve these purposes well. The JIT is an amazing piece of software, and I believe it is mostly ready for general consumption, even if it has more sharp edges than PyTorch proper (in maturity, it is perhaps comparable to PyTorch ~0.3).
Why the JIT user experience isn't ideal
But there are important shortcomings in the overall user experience. These greatly limit adoption and, as a consequence, also keep us from eliminating the remaining sharp edges.
(While I will talk very frankly about how I think the PyTorch JIT currently leaves something to be desired, I should emphasize that I do this in order to paint a picture of how the PyTorch world should be and what needs to be done to get there. The JIT and other parts of PyTorch are cool software, and I admire the terrific job that the people developing them have done already. The direction I am proposing here just hasn't been anyone's focus so far, so this isn't «someone should have done ...» but «let's do this!». I gave a brief talk on some aspects at the PyTorch dev day, but through discussions with lots of great people the theme has been distilled into this blog post.)
JITing as a step in the workflow
Working with the JIT currently has three steps:
- Develop the model (and train it, if you want to use the JIT for deployment).
- Trace or script the model.
- Deploy!
The part that isn't ideal is the conversion step:
- Users like and understand tracing, but it can be brittle because the JIT can literally only see calls into PyTorch. This limitation is a substantial pitfall for users and fundamentally cannot be overcome.
- Scripting, on the other hand, is limited to the TorchScript subset of Python. This isn't too bad in itself, but it is hard on users because they have to convert the entire model at once.
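The tracing pitfall is easy to reproduce. In the hypothetical function below, the branch depends on tensor values: torch.jit.trace records only the PyTorch calls made for the example input, so the if is frozen, while torch.jit.script compiles the Python source and keeps the control flow. This is a sketch for illustration, not a recommended model pattern.

```python
import warnings

import torch


def clip_or_negate(x: torch.Tensor) -> torch.Tensor:
    # Data-dependent control flow: the branch depends on tensor values.
    if x.sum() > 0:
        return x.clamp(max=1.0)
    return -x


positive = torch.tensor([2.0, 0.5])
negative = torch.tensor([-2.0, -0.5])

# Tracing only sees the PyTorch calls made for this example input,
# so the branch taken here is hardwired into the trace.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # silence the TracerWarning for the demo
    traced = torch.jit.trace(clip_or_negate, (positive,))

# Scripting compiles the source, so the `if` survives as real control flow.
scripted = torch.jit.script(clip_or_negate)

print(clip_or_negate(negative))  # tensor([2.0000, 0.5000])
print(scripted(negative))        # tensor([2.0000, 0.5000])
print(traced(negative))          # tensor([-2.0000, -0.5000]), the wrong branch
```

The traced version silently computes the wrong result for inputs that would take the other branch; PyTorch only emits a TracerWarning at trace time.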
We believe that the best approach to making things easier for the users is to eliminate the conversion step entirely by enabling them to work in TorchScript from the start. Just as a key innovation PyTorch brought to a broad audience (and that was pioneered by frameworks such as HIPS autograd and chainer) is doing away with the separation of model definition and execution, we would take away JITing as a separate step. I believe that this can be as important a step for enabling production as eager execution was for deep learning frameworks in general.
Given the limitations discussed above, we cannot rely on tracing to achieve this. What we need instead is a seamless fallback mechanism to Python for things the PyTorch JIT cannot understand.
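TorchScript already has a partial escape hatch: torch.jit.ignore leaves a function as a plain Python call that compiled code dispatches back to the interpreter. The sketch below (the function names are made up) uses it for a body the compiler never looks at. Per the PyTorch documentation, however, models with ignored functions cannot be exported (it suggests torch.jit.unused instead), which is exactly the gap a seamless fallback would need to close.

```python
import torch


@torch.jit.ignore
def count_positive(x: torch.Tensor) -> int:
    # Plain Python body: the compiler skips it and only uses the
    # type annotations to build the call signature.
    return sum(1 for v in x.tolist() if v > 0)


@torch.jit.script
def scale_by_positive_count(x: torch.Tensor) -> torch.Tensor:
    # At runtime, this call is dispatched back to the Python interpreter.
    n = count_positive(x)
    return x * n


print(scale_by_positive_count(torch.tensor([1.0, -2.0, 3.0])))
```

This works while Python is around, but it is an explicit, per-function opt-out rather than the automatic fallback argued for above.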
Divergence between TorchScript and PyTorch
TorchScript's existence as just another auxiliary part of PyTorch while essentially needing to cover all of PyTorch's features leads to a constant game of eliminating inconsistencies and catching up with new features.
This is frustrating to users because they usually cannot anticipate the divergences. Some recent examples include:
- AMP (automatic mixed precision) does not seem to fully work in the JIT. It partially does: with tracing, the AMP manipulations are traced and hardwired into the network, but with scripting, people run into bad interactions with Autodiff.
- The lack of an easy fallback-to-Python mechanism for scripting makes it hard to write code. As Simon Wang put it, the JIT should not force users to write ugly code.
- In general, Autodiff and Autograd seem to do mostly the same thing, until they surprisingly do not.
Why the JIT developer experience isn't ideal
The PyTorch JIT falls short in another way that I think is detrimental to widespread adoption. The JIT is well extensible through custom operators, C++-defined classes, and custom (optimization) passes, and there are very successful examples of this, like TorchVision's specialized detection operators and the TRTorch extension (as well as the fusers in PyTorch itself).
It would be much simpler, though, if this extensibility could be had at a Python level.
- With FX, PyTorch seems to be getting a second JIT, roughly equivalent to (a subset of) TorchScript's nearly invisible pre-SSA (Static Single Assignment) form. Quite a bit of FX's appeal seems to be that it works in Python.
Ideally, one JIT would be plenty for PyTorch. At the very least, we should recognize that FX demonstrates a relatively important use case that the JIT is not serving well.
- Similarly, it would be neat if we could use Python for those extension use cases where one currently has to drop to C++. This would allow people to experiment with creative transformations and optimizations.
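For comparison, this is the kind of Python-level graph rewriting that makes FX attractive. The pass below is a hypothetical toy (it swaps torch.relu for torch.sigmoid), written in the style of the examples in the torch.fx documentation; note that it uses FX, not the TorchScript JIT.

```python
import torch
import torch.fx


class Net(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x) + 1.0


# Capture the model as an FX GraphModule: a Python data structure of Nodes.
gm = torch.fx.symbolic_trace(Net())

# A tiny "pass" written entirely in Python: swap relu for sigmoid.
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target == torch.relu:
        node.target = torch.sigmoid

gm.graph.lint()  # sanity-check the modified graph
gm.recompile()   # regenerate forward() from the graph

x = torch.tensor([-1.0, 0.0, 2.0])
print(torch.allclose(gm(x), torch.sigmoid(x) + 1.0))  # True
```

Everything here is ordinary Python objects and loops, which is precisely the accessibility the TorchScript pass infrastructure currently lacks.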
What to do about this
To me, a solid TorchScript story is crucial to PyTorch, as it is at the core of its answer to deployment and the need for holistic optimization. To a large extent, deficits in TorchScript user experience are deficits in PyTorch user experience.
Close Integration of TorchScript
I believe that in addition to closing remaining technical gaps we need to put TorchScript more at the center of PyTorch.
The goal should be that models are in TorchScript by default, with seamless fallback to Python for types that TorchScript cannot handle.
Note that this is not intended to preclude models in Python or other pure Python use of PyTorch. It's just about changing the default. One of the larger things we are giving up with this is lazy binding (here, the characteristic behaviour of Python that all lookups of attributes etc. are done at runtime; the JIT, in contrast, looks up functions and objects at compile time and inlines functions). But other frameworks (like Cython) have made this trade, too, and it seems to be a good trade-off.
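The difference in binding can be seen in a few lines. The sketch below (all names hypothetical) scripts a function that calls a global helper and then rebinds that global: plain Python picks up the new helper on the next call, while the scripted function keeps the helper it resolved at compile time.

```python
import torch


def activation(x: torch.Tensor) -> torch.Tensor:
    return torch.relu(x)


def eager_model(x: torch.Tensor) -> torch.Tensor:
    # Plain Python looks up the global name `activation` on every call.
    return activation(x)


# Scripting resolves and compiles `activation` right now.
scripted_model = torch.jit.script(eager_model)


def new_activation(x: torch.Tensor) -> torch.Tensor:
    return torch.tanh(x)


# Rebind the global name after compilation.
activation = new_activation

x = torch.tensor([-1.0, 2.0])
print(torch.equal(eager_model(x), torch.tanh(x)))     # True: late binding
print(torch.equal(scripted_model(x), torch.relu(x)))  # True: bound at compile time
```

Code that monkey-patches helpers at runtime therefore behaves differently once it is compiled, which is the price paid for ahead-of-time resolution and inlining.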
One aspect of the immediate appeal of the Julia Language is that it allows users to integrate both "standard library operations" and custom computation. The PyTorch JIT may well allow us to do so, too, and we should if we want to maintain and expand PyTorch's lead in being both intuitive and flexible.
Even more lofty ideas
- Autograd should get its derivatives from Autodiff. This would eliminate the divergence between Autograd and Autodiff.
Is there a plan?
Yes. But it is my personal plan rather than one derived/approved by Facebook as the PyTorch owner. It is going to take some work:
- The JIT (legitimately) started out cutting some corners. But to make it the general-purpose execution engine we want, we need some correctness fixes.
- Low-level features:
  - If in Autodiff
  - Views in Autodiff
  - Incremental profiling
  - A somewhat more formal spec? There is quite a bit of documentation, but sometimes there seem to be gaps between what one pass in the cascade produces and what the next one expects (e.g. shape profiling vs. profile consumption).
- User experience:
  - Python fallback (there already is PyObjectType, so we're close)
  - Scripted Autograd functions
- Developer experience to extend the JIT:
  - A robust Python interface for manipulating graphs (e.g. PyTorch Issue #49969)
  - Python pass infrastructure
  - Non-SSA form exposure to Python (for fallback, too?)
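Some Python access to TorchScript graphs already exists. The sketch below only reads a scripted function's graph through the graph attribute and the Node.kind() binding; these come from the partly internal torch._C layer, and a robust, documented interface for modifying graphs is what the plan above asks for. The function name is hypothetical.

```python
import torch


@torch.jit.script
def double_relu(x: torch.Tensor) -> torch.Tensor:
    return torch.relu(x) * 2.0


graph = double_relu.graph  # a torch._C.Graph
print(graph)  # textual TorchScript IR

# Walk the nodes from Python; each node reports its operator via kind().
kinds = [node.kind() for node in graph.nodes()]
print(kinds)
```

Reading the IR this way is handy for debugging, but writing passes against these bindings is far less comfortable than the FX equivalent.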
If you want to help out, give me a shout.
Risks: We might not reach our goals, either because we fail to implement the features or because the PyTorch owner rejects them.
We have seen how the integration of TorchScript into PyTorch and the usability of the JIT leave something to be desired. While PyTorch is still rapidly evolving and gaining features, it would seem that we are currently evolving more than disrupting, even in areas like the fusers where we are heavily investing. Moving to TorchScript by default is a much bigger effort, but it also offers a huge leap in making the JIT's advanced features accessible.
So here is to more TorchScript in 2021 - the year of ScriptTorch!