Optimizing models using the PyTorch JIT
Today we look at TorchScript, the language implemented by the PyTorch JIT ("Just in Time compiler"), PyTorch's solution for deployment and model optimization. We can use it to export models to work beyond Python, e.g. on mobile or embedded platforms, or just to escape the infamous Python Global Interpreter Lock during computation. This is possibly the more well-known application. (Luca Antiga and Christian Sarofeen shared many insightful comments and suggestions after the article first went live. Stas Bekman pointed out improvements. Thank you! All errors are still mine.)
But the JIT also lends itself to the implementation of holistic optimizations that consider several operations at once. This is as opposed to just writing a better implementation of any given PyTorch operation, although the JIT works for these, too, as we will see.
We will start with a high-level overview of how PyTorch and the JIT work and then dive into how the JIT enables compiling fused kernels to optimize models at run time. (If you want to take a look at exporting models, do check out Chapter 15 of our book, from which I also took some diagrams below. There we introduce the JIT with a view towards running the model in C++ and on mobile. The book also has a comprehensive introduction, from everything PyTorch to how to represent data, and a detailed account of a project to build an AI detecting cancerous lung nodules.)
This tutorial has been prepared in the context of work I did for AMD. Thank you!
The overall structure of PyTorch
The first thing we want to do when considering how the JIT works is consider the structure of PyTorch.
(image from Deep Learning with PyTorch)
PyTorch most prominently is a Python library; I call this part classic PyTorch. Some parts are implemented in Python (e.g. the torch.nn modules and the optimizers), but the compute functions (like torch.matmul) are provided as a Python C++ extension.
Looking a bit closer, this Python C++ extension is a thin wrapper around PyTorch's C++ library LibTorch. That in turn uses the ATen tensor library which itself dispatches into various backends.
The PyTorch JIT now implements a virtual machine that takes in TorchScript programs (typically created through torch.jit) and runs them by calling into LibTorch itself, circumventing the Python parts.
The JIT is also extendable by defining Custom Ops; we'll get back to this. To run PyTorch-exported programs in Torch Mobile or Torch Serving, the typical thing is to implement a wrapper around the JIT API to load and run modules.
TorchScript
Now that we know that we want to run our model in the JIT execution, we should see how to get our model into TorchScript, the form the JIT can process. (TorchScript is used simultaneously for the language, mostly a typed subset of Python, and the representation, the intermediate representation or IR.)
There are two main ways of achieving this (but they can be mixed), scripting and tracing. Let's look at them.
Scripting
Scripting compiles (mostly) a subset of Python. It takes the Python source code and transforms it. "Here is what the function should do", just like normal programming. (To run the code, check out the Notebook of this tutorial.)
import torch
@torch.jit.script
def fn(x):
    return x * 2
fn, fn.graph
(<torch.jit.ScriptFunction at 0x7fcd5709a310>,
 graph(%x.1 : Tensor):
   %2 : int = prim::Constant[value=2]() # <ipython-input-3-6dcc6c3c1c8e>:3:15
   %3 : Tensor = aten::mul(%x.1, %2) # <ipython-input-3-6dcc6c3c1c8e>:3:11
   return (%3))
Tracing
Tracing runs the code and observes the calls into PyTorch with some sample input.
"Watch me, now you know how to do the same." (The avid reader will note that I sometimes leave the evaluation last, and have the result displayed by Jupyter if we are in a notebook, and sometimes use print. The main difference is how strings like .code below are handled: Jupyter will use "the representation" as output, showing multiline strings with newline characters as \n, while printing actually starts a new line when seeing newline characters.)
def fn(x):
    return x * 2
fn = torch.jit.trace(fn, [torch.randn(5)])
print(fn.graph, fn.code)
graph(%x : Float(5, strides=[1], requires_grad=0, device=cpu)):
  %1 : Long(requires_grad=0, device=cpu) = prim::Constant[value={2}]() # <ipython-input-5-7d43d0af2e80>:2:0
  %2 : Float(5, strides=[1], requires_grad=0, device=cpu) = aten::mul(%x, %1) # <ipython-input-5-7d43d0af2e80>:2:0
  return (%2)
 def fn(x: Tensor) -> Tensor:
  return torch.mul(x, CONSTANTS.c0)
By the way: The specialization for the Tensor shape isn't relevant here and will be erased e.g. during saving of the model.
What is TorchScript?
Now that we have had a glimpse of TorchScript, what is it?
One important difference between TorchScript and Python is that in TorchScript everything is typed. Important types are
- bool, int, long, double for numbers (int = 32 bit integer, long = 64 bit integer),
- Tensor for tensors (of arbitrary shape, dtype, ...),
- List[T] for a list with elements of type T (one of the above),
- Tuples, which are of fixed size with arbitrary but fixed element types, so e.g. Tuple[Tensor, int],
- Optional[T] for things that can be None.
None always is of type Optional[T] for some specific T (except in the rarest circumstances).
PyTorch will mostly infer the intermediate and return types, but you need to annotate any non-Tensor inputs.
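As a small illustration of the annotation rule, here is a minimal sketch. It uses torch.jit.CompilationUnit, which compiles TorchScript from a source string, so the snippet does not depend on where it is run from; the decorator form shown earlier works the same way. The function name scale and its arguments are made up for this example.

```python
import torch

# Hypothetical example function: arguments without an annotation are taken
# to be Tensors, the non-Tensor inputs carry explicit type annotations.
cu = torch.jit.CompilationUnit('''
def scale(x, factor: float, repeats: int):
    for i in range(repeats):
        x = x * factor
    return x
''')

out = cu.scale(torch.ones(3), 2.0, 3)
print(cu.scale.code)  # the compiled function, including the types
print(out)
```

Dropping the annotations would make TorchScript treat factor and repeats as Tensors, and the range(repeats) call would then be rejected at compile time.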
Another important difference is the binding behaviour, i.e. when a given variable name is looked up to find the associated variable. Python uses late binding. If we write a function that calls torch.matmul, the Python interpreter will look up what torch.matmul is when it executes the statement in which it is used.
This is in contrast to many other languages, which use early binding, as (you guessed it) TorchScript does: When we compile a function to TorchScript, the JIT looks it up then and there and puts it into our function (it even inlines the commands, but that is not the point here). (While functions are looked up at compile time, the PyTorch JIT virtual machine executing the bytecode looks up the operators during runtime.)
Tracing vs. Scripting
Scripting will process all of the code, but it may not understand all of it. This means it captures all constructs it supports (e.g. some control flow), but it will fail if there is something it doesn't understand.
Tracing cannot see anything that is not a direct call into PyTorch and will happily ignore it (e.g. control flow). This is also the reason why it will loudly complain if you have non-tensor inputs.
def fn(x):
    for i in range(x.dim()):
        x = x * x
    return x
script_fn = torch.jit.script(fn)
trace_fn = torch.jit.trace(fn, [torch.randn(5, 5)])
print(script_fn.code)
def fn(x: Tensor) -> Tensor:
  x0 = x
  for i in range(torch.dim(x)):
    x0 = torch.mul(x0, x0)
  return x0
print(trace_fn.code)
def fn(x: Tensor) -> Tensor:
  x0 = torch.mul(x, x)
  return torch.mul(x0, x0)
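To see what the unrolled loop means in practice, the following sketch feeds a 3-dimensional input to a function that was traced with a 2-dimensional sample. The input x3 is an assumption chosen so that the numbers are easy to check by hand; a scripted version would keep the loop and agree with the eager result.

```python
import torch

def fn(x):
    for i in range(x.dim()):
        x = x * x
    return x

# Trace with a 2-d sample input: the loop is unrolled into two multiplications.
trace_fn = torch.jit.trace(fn, [torch.randn(5, 5)])

x3 = torch.full((2, 2, 2), 2.0)  # a 3-d input, so eager Python loops three times
eager = fn(x3)                   # ((2^2)^2)^2 = 256
traced = trace_fn(x3)            # still two multiplications: (2^2)^2 = 16
print(eager[0, 0, 0].item(), traced[0, 0, 0].item())
```

The traced function silently returns a different result, which is why control flow that depends on the input is a case for scripting.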
Tracing and Scripting Modules
But our models often are not functions. What now?
With tracing, we can work with Modules just like we work with functions. We get a ScriptModule subclass that behaves much like a Module with parameters, state dict etc.
model = torch.nn.Sequential(
    torch.nn.Linear(1, 10),
    torch.nn.ReLU(),
    torch.nn.Linear(10, 1))
traced_model = torch.jit.trace(model, [torch.randn(8, 1)])
type(traced_model), traced_model
(torch.jit._trace.TopLevelTracedModule,
 Sequential(
   original_name=Sequential
   (0): Linear(original_name=Linear)
   (1): ReLU(original_name=ReLU)
   (2): Linear(original_name=Linear)
 ))
Saving is a bit different, here we include the model on purpose:
traced_model.save('./traced_model.pt')
loaded_model = torch.jit.load('./traced_model.pt')
loaded_model(torch.randn(8,1))
tensor([[ 0.2657],
[ 0.1038],
[ 0.1530],
[ 0.3578],
[ 0.0223],
[ 0.0336],
[-0.0023],
[ 0.1440]], grad_fn=<AddBackward0>)
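A quick way to convince oneself that the saved file really contains the model, and not just the weights, is a round trip. The sketch below uses an in-memory buffer instead of a file on disk (an assumption made only to keep the example free of side effects) and compares the loaded module with the original.

```python
import io
import torch

model = torch.nn.Sequential(
    torch.nn.Linear(1, 10),
    torch.nn.ReLU(),
    torch.nn.Linear(10, 1))
traced_model = torch.jit.trace(model, [torch.randn(8, 1)])

buffer = io.BytesIO()                 # in-memory stand-in for a file on disk
torch.jit.save(traced_model, buffer)
buffer.seek(0)
loaded_model = torch.jit.load(buffer)  # no Python class definition needed here

inp = torch.randn(8, 1)
same = torch.allclose(loaded_model(inp), model(inp))
print(same)
```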
Scripting Modules
Scripting modules is ... a bit more tricky. The class in its entirety is not scripted, instead we need to script an instance of the module. (In particular we do not want to script the __init__ method, but instead inspect the instance after it has been initialized.) Then all data members will be collected and methods will be processed in a way similar to how script functions work.
scripted_model = torch.jit.script(model)
print(scripted_model.code)
def forward(self,
    input: Tensor) -> Tensor:
  _0 = getattr(self, "0")
  _1 = getattr(self, "1")
  _2 = getattr(self, "2")
  input0 = (_0).forward(input, )
  input1 = (_1).forward(input0, )
  return (_2).forward(input1, )
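As a sanity check that scripting the instance collects the data members, this sketch compares the scripted module with the eager one. It rebuilds the same Sequential model so that it runs on its own; the variable names outputs_match and same_keys are just for this example.

```python
import torch

model = torch.nn.Sequential(
    torch.nn.Linear(1, 10),
    torch.nn.ReLU(),
    torch.nn.Linear(10, 1))
scripted_model = torch.jit.script(model)  # scripts the instance, not the class

inp = torch.randn(8, 1)
outputs_match = torch.allclose(scripted_model(inp), model(inp))
# The parameters of the initialized instance show up in the scripted module.
same_keys = sorted(scripted_model.state_dict().keys()) == sorted(model.state_dict().keys())
print(outputs_match, same_keys)
```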
We can also look at the graph including submodules, but it gets unwieldy rather fast:
scripted_model.forward.inlined_graph
graph(%self : __torch__.torch.nn.modules.container.___torch_mangle_13.Sequential,
%input.1 : Tensor):
%2 : __torch__.torch.nn.modules.linear.___torch_mangle_10.Linear = prim::GetAttr[name="0"](%self)
%3 : __torch__.torch.nn.modules.activation.___torch_mangle_11.ReLU = prim::GetAttr[name="1"](%self)
%4 : __torch__.torch.nn.modules.linear.___torch_mangle_12.Linear = prim::GetAttr[name="2"](%self)
%8 : int = prim::Constant[value=1]()
%9 : int = prim::Constant[value=2]() # /usr/local/lib/python3.9/distpackages/torch/nn/functional.py:1663:22
%10 : Tensor = prim::GetAttr[name="weight"](%2)
%11 : Tensor = prim::GetAttr[name="bias"](%2)
%12 : int = aten::dim(%input.1) # /usr/local/lib/python3.9/distpackages/torch/nn/functional.py:1663:7
%13 : bool = aten::eq(%12, %9) # /usr/local/lib/python3.9/distpackages/torch/nn/functional.py:1663:7
%input.3 : Tensor = prim::If(%13) # /usr/local/lib/python3.9/distpackages/torch/nn/functional.py:1663:4
block0():
%15 : Tensor = aten::t(%10) # /usr/local/lib/python3.9/distpackages/torch/nn/functional.py:1665:39
%ret.2 : Tensor = aten::addmm(%11, %input.1, %15, %8, %8) # /usr/local/lib/python3.9/distpackages/torch/nn/functional.py:1665:14
-> (%ret.2)
block1():
%17 : Tensor = aten::t(%10) # /usr/local/lib/python3.9/distpackages/torch/nn/functional.py:1667:30
%output.2 : Tensor = aten::matmul(%input.1, %17) # /usr/local/lib/python3.9/distpackages/torch/nn/functional.py:1667:17
%output.4 : Tensor = aten::add_(%output.2, %11, %8) # /usr/local/lib/python3.9/distpackages/torch/nn/functional.py:1669:12
-> (%output.4)
%input.5 : Tensor = aten::relu(%input.3) # /usr/local/lib/python3.9/distpackages/torch/nn/functional.py:1111:17
%21 : int = prim::Constant[value=1]()
%22 : int = prim::Constant[value=2]() # /usr/local/lib/python3.9/distpackages/torch/nn/functional.py:1663:22
%23 : Tensor = prim::GetAttr[name="weight"](%4)
%24 : Tensor = prim::GetAttr[name="bias"](%4)
%25 : int = aten::dim(%input.5) # /usr/local/lib/python3.9/distpackages/torch/nn/functional.py:1663:7
%26 : bool = aten::eq(%25, %22) # /usr/local/lib/python3.9/distpackages/torch/nn/functional.py:1663:7
%input.7 : Tensor = prim::If(%26) # /usr/local/lib/python3.9/distpackages/torch/nn/functional.py:1663:4
block0():
%28 : Tensor = aten::t(%23) # /usr/local/lib/python3.9/distpackages/torch/nn/functional.py:1665:39
%ret.1 : Tensor = aten::addmm(%24, %input.5, %28, %21, %21) # /usr/local/lib/python3.9/distpackages/torch/nn/functional.py:1665:14
-> (%ret.1)
block1():
%30 : Tensor = aten::t(%23) # /usr/local/lib/python3.9/distpackages/torch/nn/functional.py:1667:30
%output.1 : Tensor = aten::matmul(%input.5, %30) # /usr/local/lib/python3.9/distpackages/torch/nn/functional.py:1667:17
%output.3 : Tensor = aten::add_(%output.1, %24, %21) # /usr/local/lib/python3.9/distpackages/torch/nn/functional.py:1669:12
-> (%output.3)
return (%input.7)
What can you do with scripted modules?
- Run them as is, bypassing Python:
  - not as much speedup as is often expected (maybe 5%-10% for some models I tested),
  - but, sometimes crucially, it avoids the dreaded Python Global Interpreter Lock (GIL), so it is useful e.g. for multithreaded things like serving PyTorch models.
- Export and run in C++ / Mobile / ..., export to other frameworks like TVM.
- Apply holistic optimizations (this is what a submodule, the JIT fuser, does).
How the JIT works at a very high level
For the in-depth discussion of fusers it will be useful to look closer at how the JIT works under the hood. The JIT has several phases to get us from a function to running our programs. For our purposes, we think of the following three stages:
- The first thing is to go from tracing or source to a graph.
- Then there are a number of compiler passes through the graph to go from .graph to an optimized graph (that can be retrieved with .graph_for(*inputs)). We will meet some of them in detail below.
- Finally, the .graph is compiled to a form of bytecode that is then executed by a virtual machine. We might hope to not meet the bytecode too often, but clearly we want this part to be fast, too. The virtual machine maintains the operands on a stack and then dispatches to the various operators registered by LibTorch or the custom operators that extend the JIT.
The unoptimized .graph is the "household" format here; in particular, this is what is serialized, and when it is loaded the optimizations will have to be redone.
Tracing or scripting to a .graph
When tracing a function, the LibTorch dispatcher will call a special function (found in torch/csrc/autograd/generated/TraceTypeEverything.cpp after you have built PyTorch) for every call of a LibTorch function.
Before redispatching to the LibTorch operation, this special function will record a graph node (the ones that show up in .graph) including function calls, source location and type information. (For more on the dispatcher, see Ed Yang's excellent blog post Let's talk about the PyTorch Dispatcher.)

When tracing modules, the tracer will also hook into the module __call__ method to record the current module as the scope to capture the module structure. This is done at the Python level in the torch.nn.Module class, see the _slow_forward method there.
When scripting a function from Python, the JIT grabs the Python source code (via the inspect module of the standard Python library) and then runs the Python parser from the ast (for Abstract Syntax Tree) module. It then transforms the Python AST into TorchScript AST (implemented in C++), which is an initial graph form that looks a lot like Python. Any name lookup is also done at this stage, so TorchScript is (mostly) statically binding rather than dynamically binding like Python. After the lookup, it represents objects as Sugared Values. (Sugared values represent objects that do not map directly to the Values that TorchScript operates on, but can appear in the Python code and hence the pre-SSA form. These are things like self, also supporting attribute lookups like self.weight, references to Python functions, etc. They are desugared while converting to SSA.) Finally, the JIT transforms the graph into the static single assignment (SSA) form that you can see with .graph.
There is a variant of scripting that can be called directly from C++ and does not use the Python ast but parses Python on its own. This is used internally by AutoDiff but is also a neat trick to use from C++.
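The first step of scripting from Python can be imitated with the standard library alone. The sketch below parses a source string with ast and lists the statement kinds that a frontend would then have to translate; the source string is the loop example from earlier, and using a string instead of inspect is a simplification for this illustration.

```python
import ast

source = '''
def fn(x):
    for i in range(x.dim()):
        x = x * x
    return x
'''

tree = ast.parse(source)      # the Python parser from the ast module
func = tree.body[0]           # the FunctionDef node for fn
statement_kinds = [type(s).__name__ for s in func.body]
node_kinds = [type(n).__name__ for n in ast.walk(func)]
print(func.name, statement_kinds)
```

The For and Return statements and the BinOp for x * x are exactly the kind of Python AST nodes that get turned into their TorchScript counterparts.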
Optimization passes
The JIT compiler gets us from .graph to what we see with .graph_for above by running a series of optimization (and some other) passes. This is done by the JIT's GraphExecutor (actually there are two, the "regular" one and the profiling one) on the first run, or the first few runs in the case of the profiling executor. The optimized graphs are cached along with the bytecode.
There are a number of passes that work and do not affect the automatic differentiation, like the following (this list is by no means complete):
- Eliminating dead code and common subexpressions, precomputing things that only involve constants,
- Pooling redundant constants into single values, and some simple "pattern matching" optimizations (like eliminating .t().t()),
- Unrolling small loops and batching matrix multiplications that result from unrolling loops.
If the last one looks highly specialized, it is, but it is quite commonly used in recurrent networks such as LSTMs with the input weights.
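To make one of these passes tangible, here is a minimal sketch that runs common subexpression elimination by hand. Note the assumptions: torch._C._jit_pass_cse is an internal binding rather than a public API and may change between versions, and the function fn and the helper count are made up for this example.

```python
import torch

cu = torch.jit.CompilationUnit('''
def fn(x):
    a = x * 2
    b = x * 2
    return a + b
''')

def count(graph, kind):
    # Count the nodes of a given kind, e.g. "aten::mul", in the top-level block.
    return sum(1 for n in graph.nodes() if n.kind() == kind)

graph = cu.fn.graph
muls_before = count(graph, "aten::mul")
torch._C._jit_pass_cse(graph)      # internal binding: common subexpression elimination
muls_after = count(graph, "aten::mul")
result = cu.fn(torch.ones(2))      # the function still computes 2 * (x * 2)
print(muls_before, muls_after, result)
```

In normal use you never call such passes yourself; the GraphExecutor applies them when the function is first run.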
As you might have guessed with the introduction, there are also some passes that would mess up autodifferentiation and can only be done if gradients are not required, or the differentiation has already been performed.
Bytecode and execution
Finally, the optimized graph is lowered to bytecode and run by the virtual machine. The virtual machine can also do function calls, this is used e.g. by the fallback mechanisms of the fusers. We will not deal much with this part.
So this gives you a very high-level overview of what goes on in the JIT. As usual, things get complicated quickly, and the JIT is actively being developed, making this a bit of a moving target.
Excursion: GPU, efficiency, measurement
Before we discuss optimization through the JIT we have to discuss measurement. In fact, one of my many informal mottos is "It's not optimization until you measure". Although PyTorch offers a capable profiling facility, I'll only discuss the most basic measurement here.
When code seems slow, it's important to figure out how slow it really is, and why. To my mind, a lot of measurement can be done with very basic tools, e.g. IPython's %timeit magic.
GPU computations are asynchronous and should be so. It is important to avoid unneeded synchronization points, requiring the CPU to wait for the GPU work to finish to inspect results.
Synchronizations happen because the program needs to know something about the computation (e.g. sizes of tensors depending on the input). These are often unavoidable; however, typical sources of spurious synchronizations are simple operators like .to(device="cpu"), .item(), .tolist(), and print.
If we want to time GPU kernels, we want to be sure to synchronize before taking the start and end times. Typically, we also want to run some "warmup" iterations, i.e. run the measured function a few times before timing it.
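Outside of a notebook, the same recipe can be written as a small helper. This is a sketch under stated assumptions, not the measurement code used for the numbers in this article: the names sync and time_fn are made up, and the helper falls back to the CPU (where no synchronization is needed) when no GPU is present.

```python
import time
import torch

def sync():
    # Only needed (and only possible) when a GPU is present.
    if torch.cuda.is_available():
        torch.cuda.synchronize()

def time_fn(fn, *args, warmup=3, iters=20):
    for _ in range(warmup):      # warm-up runs are not timed
        fn(*args)
    sync()                       # earlier work must not leak into the timing
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    sync()                       # wait until all timed work has finished
    return (time.perf_counter() - start) / iters

device = "cuda" if torch.cuda.is_available() else "cpu"
a = torch.randn(256, 256, device=device)
seconds = time_fn(torch.mm, a, a)
print(f"{seconds * 1e6:.1f} µs per call on {device}")
```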
As an example, let us take the uniformity loss from Wang and Isola: Understanding Contrastive Representation Learning through Alignment and Uniformity on the Hypersphere (a great paper!).
The Uniformity loss is defined as a function of the pairwise distances over a largish set of vectors.
def lunif(x, t=2): # copied from the paper
    sq_pdist = torch.pdist(x, p=2).pow(2)
    return sq_pdist.mul(-t).exp().mean().log()
x = torch.randn(1024, 128, device="cuda")
x /= x.norm(p=2, dim=1, keepdim=True).requires_grad_()
lunif(x)
tensor(-3.9375, device='cuda:0', grad_fn=<LogBackward>)
One would think that the specialised pdist function is the right tool for the job.
But is it? Let's time it. (Note what happens with the synchronization here. In order to measure the time the GPU takes to compute something we need to synchronize before starting the clock, so all previous work does not run into our timing, and before stopping the clock, so all our work is included in the timing. We do this here at the end of the totime function. We only need to synchronize at the end of the function because the synchronization in the last warmup call is the synchronization before starting the timing. And a synchronization at the beginning of the function would not happen before the clock is started.)
def totime(fn):
    l = fn(x)
    g, = torch.autograd.grad(l, x)
    torch.cuda.synchronize()
totime(lunif) # warmup
%timeit totime(lunif)
18.6 ms ± 14.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Let's use $|x-y|^2 = |x|^2 + |y|^2 - 2 \langle x, y\rangle$ and compare. (We reuse the totime wrapper function that takes care of the synchronizations.)
def lunif2(x, t=2):
    xnorm = torch.norm(x, p=2, dim=1).pow(2)
    # squared pairwise distances from the norms and the inner products
    sq_pdist = xnorm[None] + xnorm[:, None] - 2 * torch.mm(x, x.t())
    # keep only the strictly lower triangle, i.e. each pair once
    exp = sq_pdist.mul(-t).exp().tril(diagonal=-1)
    N = x.size(0)
    res = exp.sum().mul(2/(N*N-N)).log()
    return res
print((lunif2(x.to(torch.double)) - lunif(x.to(torch.double))).item())
totime(lunif2)
%timeit totime(lunif2)
4.440892098500626e-16
2.23 ms ± 15.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Even though we have stark inefficiencies (like taking tril and taking a copy to do so), this is almost an order of magnitude faster!
This is largely due to the implementation of the backward of pdist.
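For reference, the two ingredients of the faster version can be written out. The indices $i, j$ and the symbol $L_{\mathrm{unif}}$ are notation introduced here: the first formula is the identity used for the squared distances, the second is the loss as a mean over the $N(N-1)/2$ distinct pairs, which is where the factor $2/(N^2 - N)$ and the strictly lower triangle come from.

```latex
\|x_i - x_j\|^2 = \|x_i\|^2 + \|x_j\|^2 - 2\langle x_i, x_j\rangle,
\qquad
L_{\mathrm{unif}} = \log\Bigl(\frac{2}{N^2 - N}\sum_{i > j} e^{-t\,\|x_i - x_j\|^2}\Bigr)
```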
Optimization
But Python is slow...
Uniformity loss: "what formulas you use" is the real bottleneck (unless you optimize pdist).
The "what do we compute" typically should be the first optimization target.
But when we fix the task ("what"), how can we optimize?
Conventional wisdom: Python is slow
- certainly, Python isn't fast (for loop vs C++ for loop),
- but, if the GPU is saturated $\Rightarrow$ Python isn't the bottleneck.
How PyTorch programs spend their time
At a very high level, you can divide time spent into these parts:
- Python program flow,
- Data "administrative overhead" (creating Tensor data structures, autograd Nodes etc.),
- Data acquisition (I/O),
- Computation, roughly as
  - fixed overhead (kernel launches etc.),
  - reading / writing memory,
  - "real computation".
Thomas' rule of thumb: As long as your operands are reasonably large (say 100s of elements, not single elements), Python and data "administrative overhead" probably isn't your main problem.
So while the JIT takes away some Python overhead, this is not a spectacular optimization. With this out of the way, let us get back to how the JIT helps us optimize things.
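The rule of thumb is easy to probe on your own machine. This rough sketch (CPU only, with the made-up helper seconds_per_call and a deliberately crude timing loop) compares the cost of one pointwise operation on a single element and on a million elements.

```python
import time
import torch

def seconds_per_call(n, iters=200):
    x = torch.randn(n)
    x + 1.0                                  # warm-up
    start = time.perf_counter()
    for _ in range(iters):
        x + 1.0
    return (time.perf_counter() - start) / iters

tiny = seconds_per_call(1)
large = seconds_per_call(1_000_000)
print(f"1 element:   {tiny * 1e6:8.2f} µs per call")
print(f"1M elements: {large * 1e6:8.2f} µs per call, "
      f"{large / 1_000_000 * 1e9:.3f} ns per element")
```

For the single element, essentially all of the time is Python and "administrative" overhead; per element, the large tensor is several orders of magnitude cheaper.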
An ad-hoc graph plotter (skip this on first reading)
It will be handy to draw some graphs, so here is a function that plots our graphs. It's not complete by any means, but it helps us here.
def make_graph(gr):
    import graphviz
    dot = graphviz.Digraph(format='svg', graph_attr={'labelloc': 't'})

    nodes = {}
    for i in gr.inputs():
        nname = i.debugName()
        label = nname.split('.')[0]
        nodes[nname] = (nname, dot)
        dot.node(nname, label, color='blue')

    # ops that are not drawn (mostly shape and indexing bookkeeping)
    unseen_ops = {'prim::ListConstruct', 'aten::index',
                  'aten::size', 'aten::slice', 'aten::unsqueeze', 'aten::squeeze',
                  'aten::to', 'aten::view', 'aten::permute', 'aten::transpose', 'aten::contiguous',
                  'aten::permute', 'aten::Int', 'prim::TupleUnpack', 'prim::ListUnpack', 'aten::unbind',
                  'aten::select', 'aten::detach', 'aten::stack', 'aten::reshape', 'aten::split_with_sizes',
                  'aten::cat', 'aten::expand', 'aten::expand_as', 'aten::_shape_as_tensor',
                  'aten::_size_if_not_equal', 'prim::BroadcastSizes',
                  'prim::Constant',
                  }

    def process_block(nodeit, dot):
        firstnode = None
        lastnode = None
        for n in nodeit:
            k = n.kind()
            outs = list(n.outputs())
            inps = list(n.inputs())
            type_outs = [o.type().kind() for o in outs]
            type_inps = [o.type().kind() for o in inps]
            if k == 'prim::If':
                label = 'If'
                nname = outs[0].debugName()
                for i in inps:
                    src, srcdot = nodes.get(i.debugName(), (None, None))
                    if src is not None:
                        srcdot.edge(src, nname + '_in')
                dot.node(nname + '_in', 'If', shape='diamond')
                dot.node(nname, '', width='0.1', height='0.1')
                dot.edge(nname + '_in', nname, style='invis')
                nodes[nname] = (nname, dot)
                bl = list(n.blocks())
                for i, b in enumerate(bl):
                    with dot.subgraph(name=f"cluster_{nname}_{i}", graph_attr={'label':''}) as sub_dot:
                        firstnode, lastnode = process_block(b.nodes(), sub_dot)
                    dot.edge(nname + '_in', firstnode, label="yn"[i])
                    dot.edge(lastnode, nname)
                if firstnode is None:
                    firstnode = nname + '_in'
                lastnode = nname
            elif k == 'prim::DifferentiableGraph':
                label = 'DifferentiableGraph'
                nname = outs[0].debugName()
                nodes[nname] = (nname, dot)
                sg = n.g('Subgraph')
                nis = list(n.inputs())
                sgis = list(sg.inputs())
                assert len(nis) == len(sgis)
                for ni, sgi in zip(nis, sgis):
                    if ni.debugName() in nodes:
                        nodes[sgi.debugName()] = nodes[ni.debugName()]
                with dot.subgraph(name=f"cluster_{nname}", graph_attr={
                        'label': 'DifferentiableGraph', 'labelloc':'b', 'labeljust':'r'}) as sub_dot:
                    firstnode, lastnode = process_block(sg.nodes(), sub_dot)
                nos = list(n.outputs())
                sgos = list(sg.outputs())
                assert len(nos) <= len(sgos)
                for no, sgo in zip(nos, sgos):
                    if sgo.debugName() in nodes:
                        nodes[no.debugName()] = (nodes[sgo.debugName()][0], dot)
            elif k not in unseen_ops:
                if k == 'prim::CallFunction':
                    label = 'call ' + next(n.inputs()).node().s("name")
                else:
                    label = k.replace('aten::', '').replace('prim::', '')
                nname = outs[0].debugName()
                dot.node(nname, label, shape='box', style='rounded')
                for o in outs:
                    nodes[o.debugName()] = (nname, dot)
                for i in inps:
                    src, srcdot = nodes.get(i.debugName(), (None, None))
                    if src is not None:
                        srcdot.edge(src, nname)
                if firstnode is None:
                    firstnode = nname
                lastnode = nname
        return firstnode, lastnode

    process_block(gr.nodes(), dot)
    dot.node('.outputs', 'outputs', color='blue')
    for i, o in enumerate(gr.outputs()):
        src, srcdot = nodes.get(o.debugName(), (None, None))
        if src is not None:
            dot.edge(src, '.outputs')

    return dot
Holistic Optimizations  JIT fusers
So currently the fuser is a hotspot of development, and PyTorch has no fewer than three fusers:
help(torch.jit.fuser)
Help on function fuser in module torch.jit._fuser:

fuser(name)
    A context manager that facilitates switching between
    backend fusers.

    Valid names:
    * ``fuser0`` - enables only legacy fuser
    * ``fuser1`` - enables only NNC
    * ``fuser2`` - enables only nvFuser
How the JIT optimizes pointwise operations
To get a taste of how the JIT fuser works, let us look at the intersection over union ratio for detection models. (Another prominent example of pointwise operations is in LSTMs: They can be thought of as two matrix multiplications followed by a series of pointwise operations for the gates. The case of LSTMs has been a showcase for JIT optimizations.) We have two lists of rectangles given by the top left corner (as x and y coordinates) and width and height. We want to measure the pairwise agreement of the $i$th rectangle in the first list and the $i$th rectangle in the second list. We do this by the intersection over union ratio, which computes the areas of the intersection and the union of the two rectangles. The quotient of the two is between 0 (no agreement at all) and 1 (perfect agreement).
def ratio_iou(x1, y1, w1, h1, x2, y2, w2, h2):
    xi = torch.max(x1, x2) # Intersection left
    yi = torch.max(y1, y2) # Intersection top
    wi = torch.clamp(torch.min(x1+w1, x2+w2) - xi, min=0.) # Intersection width
    hi = torch.clamp(torch.min(y1+h1, y2+h2) - yi, min=0.) # Intersection height
    area_i = wi * hi # Area Intersection
    area_u = w1 * h1 + w2 * h2 - wi * hi # Area Union
    return area_i / torch.clamp(area_u, min=1e-5) # Intersection over Union
# we make a scripted function
ratio_iou_scripted = torch.jit.script(ratio_iou)
This is a simple enough function with elementwise computation. Let us look at the function graph.
make_graph(ratio_iou_scripted.graph)
It is not complex as code, but it has quite a few operations. Now, in terms of execution, each of these ops launches a kernel (a function run on the GPU) that does three things:
- load the inputs (from the incoming edges) from memory,
- compute the output,
- store the result.
So this function loads 37 tensors and stores 20 outputs with only trivial computation. Clearly this is heavily limited by the memory transfers, even if the cache can help.
What if we could make it all into one large kernel and have 8 loads and 1 store?
This is exactly what a fuser does and it does give us a good speedup:
x1, y1, w1, h1, x2, y2, w2, h2 = torch.randn(8, 100, 1000, device='cuda').exp()
def take_time(fn):
    _ = fn(x1, y1, w1, h1, x2, y2, w2, h2)
    torch.cuda.synchronize()
take_time(ratio_iou) # warmup
%timeit take_time(ratio_iou)
for i in range(2): # warmup
    take_time(ratio_iou_scripted)
%timeit take_time(ratio_iou_scripted)
155 µs ± 957 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
37.4 µs ± 117 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
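The memory traffic argument can be put into numbers with a little arithmetic. This sketch counts bytes under simplifying assumptions that are mine: every load and store moves one full float32 tensor of the benchmark's shape, and caches are ignored. The helper traffic_mb is made up for the illustration.

```python
elements = 100 * 1000          # shape of each input in the benchmark above
bytes_per_element = 4          # float32

def traffic_mb(loads, stores):
    # total bytes moved to and from memory, in megabytes
    return (loads + stores) * elements * bytes_per_element / 1e6

unfused = traffic_mb(loads=37, stores=20)
fused = traffic_mb(loads=8, stores=1)
print(f"unfused: {unfused:.1f} MB, fused: {fused:.1f} MB, ratio: {unfused / fused:.2f}")
# → unfused: 22.8 MB, fused: 3.6 MB, ratio: 6.33
```

The measured speedup of roughly 4x stays below this ratio, which is plausible given that caching helps the unfused version and that launching even a single kernel has a fixed cost.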
We can see in the graph specialised for the inputs which operations are fused:
make_graph(ratio_iou_scripted.graph_for(x1, y1, w1, h1, x2, y2, w2, h2))
Aha, so we do some type check and if that returns OK, we run a TensorExprGroup, which will be executed as one kernel. We keep a fallback just in case.
In the text representation, we can actually see the TensorExprGroup and we can see which operations are fused:
ratio_iou_scripted.graph_for(x1, y1, w1, h1, x2, y2, w2, h2)
graph(%x1.1 : Tensor,
%y1.1 : Tensor,
%w1.1 : Tensor,
%h1.1 : Tensor,
%x2.1 : Tensor,
%y2.1 : Tensor,
%w2.1 : Tensor,
%h2.1 : Tensor):
%112 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0), %113 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0), %114 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0), %115 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0), %116 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0), %117 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0), %118 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0), %119 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0), %120 : bool = prim::TypeCheck(%w2.1, %h2.1, %w1.1, %h1.1, %y2.1, %y1.1, %x2.1, %x1.1)
%121 : Tensor = prim::If(%120)
block0():
%68 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = prim::TensorExprGroup_0(%112, %113, %114, %115, %116, %117, %118, %119)
-> (%68)
block1():
%146 : Function = prim::Constant[name="fallback_function", fallback=1]()
%147 : (Tensor) = prim::CallFunction(%146, %w2.1, %h2.1, %w1.1, %h1.1, %y2.1, %y1.1, %x2.1, %x1.1)
%148 : Tensor = prim::TupleUnpack(%147)
-> (%148)
return (%121)
with prim::TensorExprGroup_0 = graph(%14 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0),
%15 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0),
%17 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0),
%18 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0),
%34 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0),
%37 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0),
%51 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0),
%54 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0)):
%4 : float = prim::Constant[value=1.0000000000000001e-05]()
%42 : None = prim::Constant()
%41 : float = prim::Constant[value=0.]()
%55 : int = prim::Constant[value=1]()
%xi.2 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::max(%54, %51) # <ipythoninput185cfef179cef6>:2:9
%yi.2 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::max(%37, %34) # <ipythoninput185cfef179cef6>:3:9
%56 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::add(%54, %17, %55) # <ipythoninput185cfef179cef6>:4:31
%53 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::add(%51, %14, %55) # <ipythoninput185cfef179cef6>:4:38
%50 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::min(%56, %53) # <ipythoninput185cfef179cef6>:4:21
%47 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::sub(%50, %xi.2, %55) # <ipythoninput185cfef179cef6>:4:21
%wi.2 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::clamp(%47, %41, %42) # <ipythoninput185cfef179cef6>:4:9
%39 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::add(%37, %18, %55) # <ipythoninput185cfef179cef6>:5:31
%36 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::add(%34, %15, %55) # <ipythoninput185cfef179cef6>:5:38
%33 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::min(%39, %36) # <ipythoninput185cfef179cef6>:5:21
%30 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::sub(%33, %yi.2, %55) # <ipythoninput185cfef179cef6>:5:21
%hi.2 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::clamp(%30, %41, %42) # <ipythoninput185cfef179cef6>:5:9
%area_i.2 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::mul(%wi.2, %hi.2) # <ipythoninput185cfef179cef6>:6:13
%19 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::mul(%17, %18) # <ipythoninput185cfef179cef6>:7:13
%16 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::mul(%14, %15) # <ipythoninput185cfef179cef6>:7:23
%13 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::add(%19, %16, %55) # <ipythoninput185cfef179cef6>:7:13
%area_u.2 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::sub(%13, %area_i.2, %55) # <ipythoninput185cfef179cef6>:7:13
%6 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::clamp(%area_u.2, %4, %42) # <ipythoninput185cfef179cef6>:8:20
%2 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::div(%area_i.2, %6) # <ipythoninput185cfef179cef6>:8:11
return (%2)
We will look in some detail at how these things work, but the core idea is that the operations of the TensorExprGroup here will be compiled into a single kernel that then computes the result from the inputs in one go.
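To make the "in one go" idea a bit more tangible, here is a toy model in plain Python (no PyTorch, no GPU). Lists stand in for tensors and a Python loop stands in for a kernel; the function names are made up for this illustration. The unfused variant materializes every intermediate as a full list, while the fused variant makes a single pass and keeps the intermediates in local variables.

```python
# Toy model of operator fusion: lists stand in for tensors, loops for kernels.
# All names here are illustrative and not part of PyTorch.

def pointwise(fn, *tensors):
    """One 'kernel launch': read all inputs, write one full output."""
    return [fn(*vals) for vals in zip(*tensors)]

def overlap_unfused(x1, w1, x2, w2):
    # Every line is a separate pass over memory with a materialized result.
    xi = pointwise(max, x1, x2)
    r1 = pointwise(lambda a, b: a + b, x1, w1)
    r2 = pointwise(lambda a, b: a + b, x2, w2)
    r = pointwise(min, r1, r2)
    d = pointwise(lambda a, b: a - b, r, xi)
    return pointwise(lambda a: max(a, 0.0), d)

def overlap_fused(x1, w1, x2, w2):
    # One pass: intermediates live in local variables ("registers").
    out = []
    for a, wa, b, wb in zip(x1, w1, x2, w2):
        xi = max(a, b)
        out.append(max(min(a + wa, b + wb) - xi, 0.0))
    return out

x1, w1 = [0.0, 2.0, 5.0], [2.0, 1.0, 1.0]
x2, w2 = [1.0, 0.0, 7.0], [3.0, 5.0, 1.0]
assert overlap_unfused(x1, w1, x2, w2) == overlap_fused(x1, w1, x2, w2)
print(overlap_fused(x1, w1, x2, w2))
```

Both variants compute the same intersection width, but the unfused one does six passes with five temporary lists. On a GPU, each of those passes would be a kernel launch plus a round trip through memory.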
How the Fusers Work at a High Level
At a high level, PyTorch's fusers work in three parts:

In a fusion JIT compiler pass, the operations that can be fused are arranged in a fusion group. By looking at which operations can be fused, we get a good glimpse of what the fusers (think they) can achieve. The classic (or legacy) PyTorch fuser only considers pointwise operations (like the IOU above, see isSimpleMap in torch/csrc/jit/passes/graph_fuser.cpp). The CUDA fuser (or fuser2/nvFuser above), which is conceptually somewhat close to but much more elaborate than the classic fuser, also handles sum (see IRParser's registerJitOperator in torch/csrc/jit/codegen/cuda/parser.cpp). The TensorExpr fuser (fuser1, the default) fuses pointwise operations, softmax and log_softmax, plus sum if reduction support is enabled (see isSupported in torch/csrc/jit/passes/tensorexpr_graph_fuser.cpp). The pass generates a fusion group node of some sort but, in the case of the newer two fusers, also inserts a check (TypeCheck or ...) and an explicit fallback. Interestingly, the fusers also support rand_like, which is very useful functionality for things like dropout.
At some point (typically the first invocation of the fusion group), it compiles a kernel for the computation. Typically this is specific to (some aspects of) the type and shape of the inputs. For the GPU, the fusers emit HIP/CUDA C code and compile using the GPU RTC (run time compile) library. For the CPU the classic fuser would also use C but the TensorExpr fuser uses an LLVM backend (but note that the CPU is much less of a target and the main use case is the GPU). These kernels are cached.

When running a fusion group (the fuser registers an operator with the JIT that is then called), the fuser needs to launch the kernel. For the newer fusers, checking whether the inputs match expectations is done outside this node, but the classic fuser would do the fallback itself if needed.
One thing to know about the fallback is that it itself will be optimized by the PyTorch JIT. So when we run a function that has been optimized with fusions with incompatible parameters (e.g. we change whether we want gradients), the failing type check would cause the JIT to call the fallback, and that would then get the optimizations for these parameters (and another level of check and fallback).
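The nesting of checks and fallbacks can be sketched with a toy model in plain Python. This is not how the JIT implements it: the class names, the log, and the guard key (just length and requires_grad) are assumptions made for the illustration. What it does show is that a failing check hands over to a fallback which then specializes itself on the new input type.

```python
# Toy model of "type check, specialized kernel, fallback". Illustrative only:
# the names and the guard key are assumptions, not PyTorch internals.

class FakeTensor:
    def __init__(self, data, requires_grad=False):
        self.data = list(data)
        self.requires_grad = requires_grad

    def type_key(self):
        # What the guard inspects: here just size and requires_grad.
        return (len(self.data), self.requires_grad)

class GuardedFunction:
    def __init__(self, fn, depth=0):
        self.fn = fn
        self.depth = depth
        self.expected = None   # input types recorded on the first call
        self.fallback = None   # created lazily, and itself guarded
        self.log = []

    def __call__(self, *args):
        key = tuple(a.type_key() for a in args)
        if self.expected is None:
            self.expected = key            # "compile" for these types
        if key == self.expected:           # the type check succeeded
            self.log.append(f"level {self.depth}: specialized")
            return self.fn(*args)
        if self.fallback is None:          # the type check failed
            self.fallback = GuardedFunction(self.fn, self.depth + 1)
        self.log.append(f"level {self.depth}: fallback")
        return self.fallback(*args)

def add(a, b):
    return [x + y for x, y in zip(a.data, b.data)]

f = GuardedFunction(add)
f(FakeTensor([1.0, 2.0]), FakeTensor([3.0, 4.0]))
out = f(FakeTensor([1.0, 2.0], True), FakeTensor([3.0, 4.0], True))
print(out, f.log, f.fallback.log)
```

The second call fails the level 0 check and lands in the fallback, which records the gradient-requiring types as its own expectation. A third kind of input would add yet another level.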
Code generation from TorchScript IR to GPU kernel
In addition to the operator support, the code generation is where each fuser has a different approach.
The CUDA fuser first transforms the TorchScript IR in the CudaFusionGroup to a Fusion IR. This is then further lowered to the Kernel IR and finally translated to C++ code from which the runtime compiler generates the kernel. The approach is conceptually relatively straightforward: there are optimizations related to how the data is accessed, and then pointwise operators are just loading, computing and storing. For reductions, there is a heuristic for how to deal with the reduction axes (this is somewhat similar to TensorIterators in ATen and, indeed, the use case is quite similar but with the compile-time vs. run-time distinction). But, as these things go, to get good results, there are quite a few things to take care of.
The TensorExpr fuser (which is inspired by the lower levels of Apache TVM) translates the TorchScript IR into a sequence of LoopNest statements (this is done in torch/csrc/jit/tensorexpr/kernel.cpp, which implements the operator processing the TensorExprGroup TorchScript IR node). This is the TensorExpr IR (the quickest overview of the IR node types can maybe be had by looking at torch/csrc/jit/tensorexpr/ir_visitor.h). These statements are then optimized and lowered before they are passed to the code generators (CUDA source code for the GPU or LLVM for the CPU) that write kernel functions and then compile and run them (again, with caching).
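The common pattern of both fusers (IR in, source text out, compile at run time, cache the result) can be mimicked in a few lines of plain Python. In this sketch, Python source and exec() stand in for CUDA C and the GPU runtime compiler, and the tiny tuple-based IR, the templates and the function names are all invented for the illustration. The generated loop body has the load, compute, store structure described above.

```python
# Toy "fuser" code generation: pointwise IR -> source text -> compiled kernel.
# Python source and exec() stand in for CUDA C and the GPU runtime compiler.

# Each statement is (output name, operator, input names).
IR = [
    ("xi", "max", ("x1", "x2")),
    ("r1", "add", ("x1", "w1")),
    ("r2", "add", ("x2", "w2")),
    ("r", "min", ("r1", "r2")),
    ("wi", "sub", ("r", "xi")),
]
TEMPLATES = {
    "max": "max({0}, {1})",
    "min": "min({0}, {1})",
    "add": "({0} + {1})",
    "sub": "({0} - {1})",
}
_kernel_cache = {}

def generate_source(ir, inputs, output):
    args = ", ".join(name + "_t" for name in inputs)
    lines = [
        f"def kernel({args}):",
        "    out = []",
        f"    for ({', '.join(inputs)},) in zip({args}):",  # load
    ]
    for name, op, operands in ir:                            # compute
        lines.append(f"        {name} = {TEMPLATES[op].format(*operands)}")
    lines.append(f"        out.append({output})")           # store
    lines.append("    return out")
    return "\n".join(lines)

def get_kernel(ir, inputs, output):
    key = (tuple(ir), tuple(inputs), output)
    if key not in _kernel_cache:
        namespace = {}
        exec(generate_source(ir, inputs, output), namespace)  # "runtime compile"
        _kernel_cache[key] = namespace["kernel"]
    return _kernel_cache[key]

inputs = ("x1", "w1", "x2", "w2")
print(generate_source(IR, inputs, "wi"))
kernel = get_kernel(IR, inputs, "wi")
print(kernel([0.0, 2.0], [2.0, 1.0], [1.0, 0.0], [3.0, 5.0]))
```

The real fusers of course do a lot more between the IR and the source text (scheduling, indexing, vectorization), which is where most of their complexity lives.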
Automatic Differentiation in TorchScript
Things are a bit more complicated if we need gradients. The default mode of the JIT is to execute the LibTorch operations and they will build an autograd graph just like in classic PyTorch. But when we want to fuse operators, this no longer works as is. The problem here is that AutoGrad needs intermediate results to compute the backward. This is OK, but our express purpose here is to skip storing and loading the intermediate results. This is mitigated by the PyTorch JIT's own automatic differentiation mechanism, AutoDiff (as opposed to AutoGrad in non-JIT PyTorch execution).
We can see it in action when we redefine our function and run it with gradient-requiring inputs: we get a DifferentiableGraph in there and the TensorExprGroup is inside that (usually this would be created as part of the fallback function, but to start fresh and see this better we have to redefine the function here; just re-scripting isn't enough to clear the script):
def ratio_iou(x1, y1, w1, h1, x2, y2, w2, h2):
    xi = torch.max(x1, x2)  # Intersection left
    yi = torch.max(y1, y2)  # Intersection top
    wi = torch.clamp(torch.min(x1+w1, x2+w2) - xi, min=0.)  # Intersection width
    hi = torch.clamp(torch.min(y1+h1, y2+h2) - yi, min=0.)  # Intersection height
    area_i = wi * hi  # Area Intersection
    area_u = w1 * h1 + w2 * h2 - wi * hi  # Area Union
    return area_i / torch.clamp(area_u, min=1e-5)  # Intersection over Union

ratio_iou_scripted = torch.jit.script(ratio_iou)
x1, y1, w1, h1, x2, y2, w2, h2 = torch.randn(8, 100, 1000, device='cuda', requires_grad=True).exp()
for i in range(10):
    ratio_iou_scripted.graph_for(x1, y1, w1, h1, x2, y2, w2, h2)
print(ratio_iou_scripted.graph_for(x1, y1, w1, h1, x2, y2, w2, h2))
make_graph(ratio_iou_scripted.graph_for(x1, y1, w1, h1, x2, y2, w2, h2))
graph(%x1.1 : Tensor,
%y1.1 : Tensor,
%w1.1 : Tensor,
%h1.1 : Tensor,
%x2.1 : Tensor,
%y2.1 : Tensor,
%w2.1 : Tensor,
%h2.1 : Tensor):
%68 : Tensor = prim::DifferentiableGraph_0(%h2.1, %h1.1, %w2.1, %w1.1, %y2.1, %y1.1, %x2.1, %x1.1)
return (%68)
with prim::DifferentiableGraph_0 = graph(%65 : Tensor,
%70 : Tensor,
%96 : Tensor,
%101 : Tensor,
%104 : Tensor,
%106 : Tensor,
%109 : Tensor,
%111 : Tensor):
%617 : int[] = aten::size(%111) # <string>:3:44
%620 : int[] = aten::size(%109) # <string>:3:93
%624 : int[] = aten::size(%106) # <string>:3:44
%627 : int[] = aten::size(%104) # <string>:3:93
%634 : int[] = aten::size(%101) # <string>:3:93
%641 : int[] = aten::size(%96) # <string>:3:93
%655 : int[] = aten::size(%70) # <string>:3:93
%662 : int[] = aten::size(%65) # <string>:3:93
%903 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0), %904 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0), %905 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0), %906 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0), %907 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0), %908 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0), %909 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0), %910 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0), %911 : bool = prim::TypeCheck(%96, %65, %101, %70, %104, %106, %109, %111)
%912 : Tensor, %913 : Tensor, %914 : Tensor, %915 : Tensor, %916 : Tensor, %917 : Tensor, %918 : Tensor, %919 : Tensor, %920 : Tensor, %921 : Tensor, %922 : Tensor, %923 : Tensor = prim::If(%911)
block0():
%830 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0), %832 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0), %area_u.4 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0), %area_i.4 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0), %hi.4 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0), %846 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0), %850 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0), %852 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0), %wi.4 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0), %856 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0), %860 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0), %862 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = prim::TensorExprGroup_0(%903, %904, %905, %906, %907, %908, %909, %910)
-> (%830, %832, %area_u.4, %area_i.4, %hi.4, %846, %850, %852, %wi.4, %856, %860, %862)
block1():
%959 : Function = prim::Constant[name="fallback_function", fallback=1]()
%960 : (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) = prim::CallFunction(%959, %96, %65, %101, %70, %104, %106, %109, %111)
%961 : Tensor, %962 : Tensor, %963 : Tensor, %964 : Tensor, %965 : Tensor, %966 : Tensor, %967 : Tensor, %968 : Tensor, %969 : Tensor, %970 : Tensor, %971 : Tensor, %972 : Tensor = prim::TupleUnpack(%960)
-> (%961, %962, %963, %964, %965, %966, %967, %968, %969, %970, %971, %972)
%875 : int[] = aten::size(%912)
%876 : int[] = aten::size(%913)
%877 : int[] = aten::size(%914)
%878 : int[] = aten::size(%915)
%879 : int[] = aten::size(%916)
%880 : int[] = aten::size(%917)
%881 : int[] = aten::size(%918)
%882 : int[] = aten::size(%919)
%883 : int[] = aten::size(%920)
%884 : int[] = aten::size(%921)
%885 : int[] = aten::size(%922)
%886 : int[] = aten::size(%923)
%887 : int[] = prim::BroadcastSizes(%617, %620)
%888 : int[] = prim::BroadcastSizes(%624, %627)
%891 : int[] = prim::BroadcastSizes(%886, %885)
%895 : int[] = prim::BroadcastSizes(%882, %881)
%898 : int[] = prim::BroadcastSizes(%634, %655)
%899 : int[] = prim::BroadcastSizes(%641, %662)
%900 : int[] = prim::BroadcastSizes(%898, %899)
%619 : int[]? = aten::_size_if_not_equal(%617, %887) # <string>:3:19
%622 : int[]? = aten::_size_if_not_equal(%620, %887) # <string>:3:68
%626 : int[]? = aten::_size_if_not_equal(%624, %888) # <string>:3:19
%629 : int[]? = aten::_size_if_not_equal(%627, %888) # <string>:3:68
%633 : int[]? = aten::_size_if_not_equal(%617, %886) # <string>:3:19
%636 : int[]? = aten::_size_if_not_equal(%634, %886) # <string>:3:68
%640 : int[]? = aten::_size_if_not_equal(%620, %885) # <string>:3:19
%643 : int[]? = aten::_size_if_not_equal(%641, %885) # <string>:3:68
%647 : int[]? = aten::_size_if_not_equal(%891, %884) # <string>:3:19
%650 : int[]? = aten::_size_if_not_equal(%887, %884) # <string>:3:68
%654 : int[]? = aten::_size_if_not_equal(%624, %882) # <string>:3:19
%657 : int[]? = aten::_size_if_not_equal(%655, %882) # <string>:3:68
%661 : int[]? = aten::_size_if_not_equal(%627, %881) # <string>:3:19
%664 : int[]? = aten::_size_if_not_equal(%662, %881) # <string>:3:68
%668 : int[]? = aten::_size_if_not_equal(%895, %880) # <string>:3:19
%671 : int[]? = aten::_size_if_not_equal(%888, %880) # <string>:3:68
%675 : int[]? = aten::_size_if_not_equal(%883, %878) # <string>:3:19
%678 : int[]? = aten::_size_if_not_equal(%879, %878) # <string>:3:68
%682 : int[]? = aten::_size_if_not_equal(%634, %898) # <string>:3:19
%685 : int[]? = aten::_size_if_not_equal(%655, %898) # <string>:3:68
%689 : int[]? = aten::_size_if_not_equal(%641, %899) # <string>:3:19
%692 : int[]? = aten::_size_if_not_equal(%662, %899) # <string>:3:68
%696 : int[]? = aten::_size_if_not_equal(%898, %900) # <string>:3:19
%699 : int[]? = aten::_size_if_not_equal(%899, %900) # <string>:3:68
%703 : int[]? = aten::_size_if_not_equal(%900, %877) # <string>:3:19
%706 : int[]? = aten::_size_if_not_equal(%878, %877) # <string>:3:68
%710 : int[]? = aten::_size_if_not_equal(%878, %875) # <string>:3:19
%713 : int[]? = aten::_size_if_not_equal(%876, %875) # <string>:3:68
return (%912, %111, %109, %619, %622, %106, %104, %626, %629, %101, %633, %636, %96, %640, %643, %923, %922, %647, %650, %921, %70, %654, %657, %65, %661, %664, %919, %918, %668, %671, %917, %920, %916, %675, %678, %682, %685, %689, %692, %696, %699, %915, %703, %706, %914, %913, %710, %713)
with prim::TensorExprGroup_0 = graph(%14 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0),
%15 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0),
%17 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0),
%18 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0),
%34 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0),
%37 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0),
%51 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0),
%54 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0)):
%4 : float = prim::Constant[value=1.0000000000000001e-05]()
%42 : None = prim::Constant()
%41 : float = prim::Constant[value=0.]()
%55 : int = prim::Constant[value=1]()
%xi.3 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::max(%54, %51) # <ipythoninput23f16be0da5a84>:2:9
%yi.3 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::max(%37, %34) # <ipythoninput23f16be0da5a84>:3:9
%56 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::add(%54, %17, %55) # <ipythoninput23f16be0da5a84>:4:31
%53 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::add(%51, %14, %55) # <ipythoninput23f16be0da5a84>:4:38
%50 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::min(%56, %53) # <ipythoninput23f16be0da5a84>:4:21
%47 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::sub(%50, %xi.3, %55) # <ipythoninput23f16be0da5a84>:4:21
%wi.3 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::clamp(%47, %41, %42) # <ipythoninput23f16be0da5a84>:4:9
%39 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::add(%37, %18, %55) # <ipythoninput23f16be0da5a84>:5:31
%36 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::add(%34, %15, %55) # <ipythoninput23f16be0da5a84>:5:38
%33 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::min(%39, %36) # <ipythoninput23f16be0da5a84>:5:21
%30 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::sub(%33, %yi.3, %55) # <ipythoninput23f16be0da5a84>:5:21
%hi.3 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::clamp(%30, %41, %42) # <ipythoninput23f16be0da5a84>:5:9
%area_i.3 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::mul(%wi.3, %hi.3) # <ipythoninput23f16be0da5a84>:6:13
%19 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::mul(%17, %18) # <ipythoninput23f16be0da5a84>:7:13
%16 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::mul(%14, %15) # <ipythoninput23f16be0da5a84>:7:23
%13 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::add(%19, %16, %55) # <ipythoninput23f16be0da5a84>:7:13
%area_u.3 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::sub(%13, %area_i.3, %55) # <ipythoninput23f16be0da5a84>:7:13
%6 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::clamp(%area_u.3, %4, %42) # <ipythoninput23f16be0da5a84>:8:20
%2 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::div(%area_i.3, %6) # <ipythoninput23f16be0da5a84>:8:11
return (%2, %6, %area_u.3, %area_i.3, %hi.3, %30, %36, %39, %wi.3, %47, %53, %56)
To understand why this is, we need to look at how AutoDiff works. It has roughly three stages:

The first part of AutoDiff is a pass that creates these differentiable graphs (in the optimizations, notably before the fusing). AutoDiff has a catalogue of operations for which it can compute backwards (it has its own derivative definitions, which could potentially differ from the AutoGrad ones), and it will move those into the DifferentiableGraph.

Then, when we run a graph containing DifferentiableGraph nodes (i.e. during the forward pass), the second part of AutoDiff will compute the gradient by going through the nodes of the forward graph. This is a form of source-to-source differentiation (but in contrast to classic symbolic differentiation, it is specialized to autograd-style Jacobian-vector products). This can amend the forward to output intermediates that are then captured for the backward, similar to the save_for_backward mechanism in an autograd.Function: you can see that the TensorExprGroup now returns a lot more values and the DifferentiableGraph itself adds all these sizes.

Finally, the PyTorch AutoGrad(!) mechanism is used by making a DifferentiableGraphBackward node that holds on to the intermediate values and, when backward is called, runs the backward graph constructed in the previous step (including letting the JIT optimize it, potentially fusing operations etc.).
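To get a feeling for the amended forward and the separate backward, here is what such a pair could look like if we wrote it by hand for a tiny function, out = clamp(a - b, min=0) * a. This is a sketch in plain Python with floats standing in for tensors; the function names and the comments borrowing the GradOf and AutogradAdd node names are only meant as an analogy, not as AutoDiff's actual output.

```python
# Hand-written sketch of a forward/backward pair for
#   out = clamp(a - b, min=0) * a
# Plain Python floats stand in for tensors; function names are illustrative.

def forward(a, b):
    d = a - b
    wi = max(d, 0.0)
    out = wi * a
    # The amended forward also returns what the backward needs.
    return out, (a, d, wi)

def backward(saved, grad_out):
    a, d, wi = saved
    # GradOf(mul): gradients for both factors of wi * a
    grad_wi = grad_out * a
    grad_a_mul = grad_out * wi
    # GradOf(clamp): the gradient passes only where the clamp was inactive
    grad_d = grad_wi if d > 0.0 else 0.0
    # GradOf(sub): +1 for a, -1 for b
    grad_a_sub, grad_b = grad_d, -grad_d
    # AutogradAdd: a was used twice, so its gradients are summed
    return grad_a_mul + grad_a_sub, grad_b

out, saved = forward(3.0, 1.0)
grad_a, grad_b = backward(saved, 1.0)
print(out, grad_a, grad_b)
```

For a = 3 and b = 1 the function is a² - ab, so the gradients 2a - b = 5 and -a = -3 are easy to check by hand. Note that one could just as well recompute d and wi in the backward instead of saving them, a point we will come back to.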
What is it with these sizes then? The convenient broadcasting semantics cause PyTorch to implicitly expand operands to (mostly) binary operations. But these expansions have a gradient operation associated with them: a summation over any broadcast dimensions. These size operations check whether broadcasting has happened (i.e. the output shape is larger than the input for a binary operation) and if so record the target size for the summation (and, thanks to the aten::_size_if_not_equal operation, None if no summation is needed).
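The size bookkeeping can be modelled in plain Python. The helper names below mirror prim::BroadcastSizes and aten::_size_if_not_equal, but the implementations are my own simplified sketch (shapes are lists of ints, gradients are lists of rows, and sum_to_size only handles the 2-d case).

```python
# Plain-Python model of the size bookkeeping for broadcasting gradients.
# The helper names mirror the TorchScript operators, but this is a sketch.

def broadcast_sizes(a, b):
    """Shape that results from broadcasting shapes a and b against each other."""
    n = max(len(a), len(b))
    a = [1] * (n - len(a)) + list(a)
    b = [1] * (n - len(b)) + list(b)
    out = []
    for x, y in zip(a, b):
        if x != y and 1 not in (x, y):
            raise ValueError(f"cannot broadcast {x} and {y}")
        out.append(y if x == 1 else x)
    return out

def size_if_not_equal(self_size, other_size):
    """None if nothing was broadcast, else the size to sum the gradient to."""
    return None if list(self_size) == list(other_size) else list(self_size)

def sum_to_size(grad, size):
    """Sum a 2-d gradient (list of rows) down to `size`; None means no-op."""
    if size is None:
        return grad
    rows, cols = size
    if rows == 1:
        grad = [[sum(col) for col in zip(*grad)]]
    if cols == 1:
        grad = [[sum(row)] for row in grad]
    return grad

x_size, bias_size = [2, 3], [1, 3]
out_size = broadcast_sizes(x_size, bias_size)
grad_out = [[1, 2, 3], [4, 5, 6]]
print(sum_to_size(grad_out, size_if_not_equal(x_size, out_size)))     # unchanged
print(sum_to_size(grad_out, size_if_not_equal(bias_size, out_size)))  # summed
```

The operand that was not broadcast gets None and its gradient passes through untouched, while the gradient for the bias-like operand is summed over the broadcast dimension.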
There is another thing to note here: The JIT currently does not have a terribly smart logic to decide which things to capture and which things might be as well recomputed (e.g. done manually, one might well choose to recompute all the intermediates of our little function instead of capturing the values), but will mimic what AutoGrad does (defined by the AutoDiff backward specifications).
x1, y1, w1, h1, x2, y2, w2, h2 = torch.randn(8, 100, 1000, device='cuda', requires_grad=True).exp()

def take_time(fn):
    _ = fn(x1, y1, w1, h1, x2, y2, w2, h2)
    torch.cuda.synchronize()

take_time(ratio_iou)  # warmup
%timeit take_time(ratio_iou)
for i in range(2):
    take_time(ratio_iou_scripted)
%timeit take_time(ratio_iou_scripted)
222 µs ± 1.47 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
94 µs ± 216 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
Displaying backwards graphs
Sometimes we want to see what actually happens in the backward. This is not completely trivial. It is relatively easy to get the non-optimized graph, as it is an attribute of the prim::DifferentiableGraph node:
gr = ratio_iou_scripted.graph_for(x1, y1, w1, h1, x2, y2, w2, h2)
n = gr.findAllNodes('prim::DifferentiableGraph')[0]
make_graph(n.g('ReverseSubgraph'))
We see lots of GradOf (the backward of something in the forward) and AutogradAdd (the "backward" of using a value several times). We also see that our make_graph function is rather incomplete. If you want to play with this more, add GradOf handling to it: we would need to process the block of GradOf nodes. The information about what has been differentiated to get this grad is available in the string attribute name (node.s("name")).
But getting the optimized graph requires us to dig through things like executor states and execution plans and their code object. These then have the gradient executor states (or know how to find them, because technically, the gradient executors are attached to the concrete DifferentiableGraph TorchScript operator), which then have an execution plan with the optimized graph. While we will not be looking at differentiable graphs, we will see a bit of how these things work in the next blog post, where we will look at what happens behind the curtain when we call a TorchScript function from Python.
for i in range(10):
    res = ratio_iou_scripted(x1, y1, w1, h1, x2, y2, w2, h2)
    torch.autograd.grad(res.sum(), [x1, y1, w1, h1, x2, y2, w2, h2])
fw_execution_plan = list(ratio_iou_scripted.get_debug_state().execution_plans.values())[0]
# the gradient executor states hang off the code object of the forward execution plan
bw_execution_plan = list(fw_execution_plan.code.grad_executor_states()[0].execution_plans.values())[0]
make_graph(bw_execution_plan.graph)
These graphs look quite unwieldy. So now that we know how to get them if we have to, we better move on.
The Profiling Executor
We mentioned that the JIT fusers will specialize on detailed tensor type information. How does the JIT get this information? It is through the Profiling Executor that is in charge of running the JITed graphs.
The profiling executor will record tensor type information (dtype, strides, sizes, requires gradient) in its profiling phase (the first few invocations). (Lest you should be thinking of taking time measurements when hearing profiling, as I sure did: this does not seem to be done here, currently.) The recording is done by inserting special prim::profile nodes into the graph which then run an operator collecting and aggregating this information. Currently, it runs one profiling run, but this is configurable. Then it uses this information to implement optimizations.
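As a rough mental model of "collecting and aggregating", consider the following plain-Python sketch. The class and function names are made up, and the merge rule (entries that differ between runs become None, i.e. unknown) is a simplifying assumption of mine rather than a description of what the profiling operator does in detail.

```python
# Toy profiling: record a tensor "type" per value over the profiling runs and
# merge the observations. The merge rule (differing entries become None) is a
# simplified assumption, not a description of PyTorch's implementation.

def observe(shape, strides, dtype="float32", requires_grad=False):
    return {"dtype": dtype, "shape": tuple(shape),
            "strides": tuple(strides), "requires_grad": requires_grad}

def merge_entry(a, b):
    if isinstance(a, tuple) and isinstance(b, tuple) and len(a) == len(b):
        return tuple(x if x == y else None for x, y in zip(a, b))
    return a if a == b else None

class ProfileNode:
    """Stands in for a profiling node attached to one value in the graph."""
    def __init__(self):
        self.profiled_type = None

    def __call__(self, observation):
        if self.profiled_type is None:
            self.profiled_type = dict(observation)
        else:
            self.profiled_type = {
                key: merge_entry(self.profiled_type[key], observation[key])
                for key in self.profiled_type
            }

node = ProfileNode()
node(observe((100, 1000), (1000, 1)))
node(observe((50, 1000), (1000, 1)))
print(node.profiled_type)
```

With a single profiling run (the current default mentioned above), the recorded type is simply that of the first invocation, which is why the graphs we saw carry fully concrete shapes.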
Traditionally, the same thing (getting tensor type information attached to every value) has been done by propagating the types from the inputs through the graph. While this works great in general, it soon hits limitations, e.g. for convolutions the output shape (and thus precise type information) depends on the value (not even just the type) of e.g. the padding input. This means that unless we detect that the output-shaping inputs are constants and have some way of accessing type propagation, we do not know the output shape. (The same topic is also addressed by people interested in type-checking tensor programs, who coordinate on the Python typing-sig mailing list.) This, and the fact that the rules of which things can be handled by the same optimized graph are quite complex and differ from one optimization to the other, is the reason we instead observe shapes during runtime (my impression is that PyTorch operations would ideally provide type propagation information, but that may just be me).
So when the JIT fuser passes mentioned above go to work, they find these typing annotations on all tensor values and can adjust.
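The convolution example is easy to make concrete. The output length along one dimension follows the formula from the PyTorch Conv1d/Conv2d documentation; the helper name conv_out_size is mine. Note how the result changes with the value of padding, which is an ordinary int as far as the type system is concerned.

```python
# Output length of a convolution along one dimension, following the formula
# in the PyTorch Conv1d/Conv2d documentation. The helper name is illustrative.

def conv_out_size(n, kernel_size, stride=1, padding=0, dilation=1):
    return (n + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

for padding in (0, 1, 2):
    # Same input length and kernel size, three different output lengths.
    print(padding, conv_out_size(32, kernel_size=3, padding=padding))
```

So a type propagation that only sees "padding is an int" cannot tell whether the output has length 30, 32 or 34, whereas a profiling run just observes the answer.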
One interesting aspect concerns the type expectations encoded by TypeCheck for the TensorExpr fuser and CudaFuserGuard for the CUDA fuser. (Interestingly, TypeCheck is wired into the JIT interpreter and JIT type system, while the CudaFuserGuard is implemented as a regular operator and implemented "manually" in a function complyWith in torch/csrc/jit/codegen/cuda/interface.cpp.) While they both nail the tensor shape and layout, the CUDA fuser will use the same kernel on tensors of different sizes as long as the contiguity pattern (i.e. that there are no gaps in the storage between the values of the tensor, e.g. from slicing) is the same.
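What a "contiguity pattern" might look like can be sketched from sizes and strides alone. The function below is a simplified model written for this post (one flag per dimension, row-major layout assumed); how the CUDA fuser encodes and checks this internally is not shown here.

```python
# Simplified model of a per-dimension "contiguity pattern". How the CUDA fuser
# encodes this internally is not shown here; this is an illustrative sketch.

def contiguity(sizes, strides):
    """flags[i] is True if dimension i has no gap to the next-inner dimension."""
    flags = []
    expected = 1  # stride a gap-free dimension would have
    for size, stride in zip(reversed(sizes), reversed(strides)):
        flags.append(stride == expected)
        expected = stride * size
    return list(reversed(flags))

print(contiguity((100, 1000), (1000, 1)))  # our inputs
print(contiguity((50, 200), (200, 1)))     # different size, same pattern
print(contiguity((100, 500), (1000, 1)))   # like t[:, :500]: gaps between rows
print(contiguity((100, 500), (1000, 2)))   # like t[:, ::2]: gaps within rows
```

The first two tensors differ in size but share the pattern, so under this model they could share a kernel, while the two sliced variants each have a different pattern.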
Looking at fallback graphs
We mentioned the importance of fallbacks and how fallbacks are again optimized. But we have yet to see it. Sadly, the JIT's Python interface is lacking or, hopefully, lagging a bit.
But we can hack around this by building our own little PyTorch extension that provides the missing functionality. Again, I recommend skipping this bit on first reading and revisiting it if you really want to know about types in the JIT (that would be another tutorial).
csrc = """
#include <torch/extension.h>

using ::c10::Type;
using ::torch::jit::FunctionType;

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  py::class_<FunctionType, Type, std::shared_ptr<FunctionType>>(m, "FunctionType")
      .def("name", [](const std::shared_ptr<FunctionType>& self) {
        return self->function()->name();
      })
      .def(
          "get_debug_state",
          [](const std::shared_ptr<FunctionType>& self) {
            return self->function()->get_executor().getDebugState();
          })
      .def("optimized_graph", [](const std::shared_ptr<FunctionType>& self) {
        return self->function()->optimized_graph();
      });
}
"""
import torch.utils.cpp_extension
ext = torch.utils.cpp_extension.load_inline("functiontype_ext", [csrc], verbose=True)

def find_function_types(graph_or_block, function_types=None):
    # collect the types of all function-valued constants, recursing into blocks and subgraphs
    if function_types is None:
        function_types = []
    for n in graph_or_block.nodes():
        if n.kind() == 'prim::Constant':
            t = n.output().type()
            if t.kind() == 'FunctionType':
                function_types.append(t)
        else:
            for b in n.blocks():
                find_function_types(b, function_types=function_types)
            if n.hasAttribute('Subgraph'):
                find_function_types(n.g('Subgraph'), function_types=function_types)
    return function_types

def get_function_graphs(gr):
    return {t.name(): list(t.get_debug_state().execution_plans.values())[0].graph for t in find_function_types(gr)}
Using /home/tv/.cache/torch_extensions as PyTorch extensions root...
Emitting ninja build file /home/tv/.cache/torch_extensions/functiontype_ext/build.ninja...
Building extension module functiontype_ext...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
Loading extension module functiontype_ext...
With this, we can now extract the fallback. Let us run our function a few times, first without needing gradients and then with needing gradients.
The original graph is the part that doesn't need gradients, as could be expected.
def ratio_iou(x1, y1, w1, h1, x2, y2, w2, h2):
    xi = torch.max(x1, x2)  # Intersection left
    yi = torch.max(y1, y2)  # Intersection top
    wi = torch.clamp(torch.min(x1+w1, x2+w2) - xi, min=0.)  # Intersection width
    hi = torch.clamp(torch.min(y1+h1, y2+h2) - yi, min=0.)  # Intersection height
    area_i = wi * hi  # Area Intersection
    area_u = w1 * h1 + w2 * h2 - wi * hi  # Area Union
    return area_i / torch.clamp(area_u, min=1e-5)  # Intersection over Union

ratio_iou_scripted = torch.jit.script(ratio_iou)
# first without gradients...
x1, y1, w1, h1, x2, y2, w2, h2 = torch.randn(8, 100, 1000, device='cuda').exp()
for i in range(10):
    ratio_iou_scripted.graph_for(x1, y1, w1, h1, x2, y2, w2, h2)
# ...then with gradients, so that the type check fails and the fallback is used
x1, y1, w1, h1, x2, y2, w2, h2 = torch.randn(8, 100, 1000, device='cuda', requires_grad=True).exp()
for i in range(10):
    ratio_iou_scripted.graph_for(x1, y1, w1, h1, x2, y2, w2, h2)
gr = torch.jit.last_executed_optimized_graph()
make_graph(gr)
gr_fb1 = get_function_graphs(gr)['fallback_function']
make_graph(gr_fb1)
We can take this to several levels, but when we get an "internal assert failed" error regarding a missing optimized plan, it means that we have reached the end of the optimized fallback passes.
gr_fb2 = get_function_graphs(gr_fb1)['fallback_function']

RuntimeError Traceback (most recent call last)
<ipythoninput28229c5d87e1d1> in <module>
----> 1 gr_fb2 = get_function_graphs(gr_fb1)['fallback_function']
<ipythoninput25efbba7d523b5> in get_function_graphs(gr)
     40
     41 def get_function_graphs(gr):
---> 42     return {t.name(): list(t.get_debug_state().execution_plans.values())[0].graph for t in find_function_types(gr)}
<ipythoninput25efbba7d523b5> in <dictcomp>(.0)
     40
     41 def get_function_graphs(gr):
---> 42     return {t.name(): list(t.get_debug_state().execution_plans.values())[0].graph for t in find_function_types(gr)}
RuntimeError: optimized_plan_ INTERNAL ASSERT FAILED at "../torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp":551, please report a bug to PyTorch.
We can inspect the unoptimized fallback, even if it may seem counterintuitive to the uninitiated like us that the unoptimized graph should be accessed via optimized_graph. (In a later blog post we will take a deep dive into how script functions are executed and will find that there are some very elementary optimizations applied at the function level.) Here it is (also note that the type annotations in the fallback branch are bogus, oh well):
find_function_types(gr_fb1)[0].optimized_graph()
graph(%0 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0),
%1 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0),
%2 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0),
%3 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0),
%4 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0),
%5 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0),
%6 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0),
%7 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0)):
%11 : int = prim::Constant[value=1]()
%10 : float = prim::Constant[value=0.]()
%9 : None = prim::Constant()
%8 : float = prim::Constant[value=1.0000000000000001e-05]()
%xi.4 : Tensor = aten::max(%7, %6) # <ipythoninput26071038d7fdab>:2:9
%yi.4 : Tensor = aten::max(%5, %4) # <ipythoninput26071038d7fdab>:3:9
%14 : Tensor = aten::add(%7, %2, %11) # <ipythoninput26071038d7fdab>:4:31
%15 : Tensor = aten::add(%6, %0, %11) # <ipythoninput26071038d7fdab>:4:38
%16 : Tensor = aten::min(%14, %15) # <ipythoninput26071038d7fdab>:4:21
%17 : Tensor = aten::sub(%16, %xi.4, %11) # <ipythoninput26071038d7fdab>:4:21
%wi.4 : Tensor = aten::clamp(%17, %10, %9) # <ipythoninput26071038d7fdab>:4:9
%19 : Tensor = aten::add(%5, %3, %11) # <ipythoninput26071038d7fdab>:5:31
%20 : Tensor = aten::add(%4, %1, %11) # <ipythoninput26071038d7fdab>:5:38
%21 : Tensor = aten::min(%19, %20) # <ipythoninput26071038d7fdab>:5:21
%22 : Tensor = aten::sub(%21, %yi.4, %11) # <ipythoninput26071038d7fdab>:5:21
%hi.4 : Tensor = aten::clamp(%22, %10, %9) # <ipythoninput26071038d7fdab>:5:9
%area_i.4 : Tensor = aten::mul(%wi.4, %hi.4) # <ipythoninput26071038d7fdab>:6:13
%25 : Tensor = aten::mul(%2, %3) # <ipythoninput26071038d7fdab>:7:13
%26 : Tensor = aten::mul(%0, %1) # <ipythoninput26071038d7fdab>:7:23
%27 : Tensor = aten::add(%25, %26, %11) # <ipythoninput26071038d7fdab>:7:13
%area_u.4 : Tensor = aten::sub(%27, %area_i.4, %11) # <ipythoninput26071038d7fdab>:7:13
%29 : Tensor = aten::clamp(%area_u.4, %8, %9) # <ipythoninput26071038d7fdab>:8:20
%30 : Tensor = aten::div(%area_i.4, %29) # <ipythoninput26071038d7fdab>:8:11
%31 : (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) = prim::TupleConstruct(%30, %29, %area_u.4, %area_i.4, %hi.4, %22, %20, %19, %wi.4, %17, %15, %14)
return (%31)
How we could go at benchmarking
We can now pit the various fusers against each other if we want. We abuse the context manager in a non-context-manager-y way (we call __enter__ by hand and never exit). Note that we do not time the backward here, but it would be straightforward to do, too.
for rq in [False, True]:
    for fuser in [None, "fuser1", "fuser2"]:
        if fuser is not None:
            c = torch.jit.fuser(fuser)
            c.__enter__()

        def ratio_iou(x1, y1, w1, h1, x2, y2, w2, h2):
            xi = torch.max(x1, x2)  # Intersection left
            yi = torch.max(y1, y2)  # Intersection top
            wi = torch.clamp(torch.min(x1+w1, x2+w2) - xi, min=0.)  # Intersection width
            hi = torch.clamp(torch.min(y1+h1, y2+h2) - yi, min=0.)  # Intersection height
            area_i = wi * hi  # Area Intersection
            area_u = w1 * h1 + w2 * h2 - wi * hi  # Area Union
            return area_i / torch.clamp(area_u, min=1e-5)  # Intersection over Union

        ratio_iou_scripted = torch.jit.script(ratio_iou) if fuser is not None else ratio_iou
        x1, y1, w1, h1, x2, y2, w2, h2 = torch.randn(8, 100, 1000, device='cuda', requires_grad=rq).exp()
        print(f"fuser: {fuser}, requires gradient: {rq}")
        for i in range(10):
            take_time(ratio_iou_scripted)
        %timeit take_time(ratio_iou_scripted)
fuser: None, requires gradient: False
152 µs ± 330 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
fuser: fuser1, requires gradient: False
37.6 µs ± 52 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
fuser: fuser2, requires gradient: False
47.5 µs ± 36.4 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
fuser: None, requires gradient: True
218 µs ± 2.73 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
fuser: fuser1, requires gradient: True
92.8 µs ± 38.3 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
fuser: fuser2, requires gradient: True
106 µs ± 43 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
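The timings above cover the forward pass only. As a rough sketch of how the backward pass could be included, the following uses a regular with block around torch.jit.fuser and a plain time.perf_counter loop instead of %timeit. The helpers make_inputs, time_forward_backward and compare_eager_and_scripted are hypothetical names made up for this illustration, and the numbers it prints include Python overhead, so they are not directly comparable to the table above.

```python
import time

import torch


def ratio_iou(x1, y1, w1, h1, x2, y2, w2, h2):
    xi = torch.max(x1, x2)  # Intersection left
    yi = torch.max(y1, y2)  # Intersection top
    wi = torch.clamp(torch.min(x1 + w1, x2 + w2) - xi, min=0.)  # Intersection width
    hi = torch.clamp(torch.min(y1 + h1, y2 + h2) - yi, min=0.)  # Intersection height
    area_i = wi * hi  # Area Intersection
    area_u = w1 * h1 + w2 * h2 - wi * hi  # Area Union
    return area_i / torch.clamp(area_u, min=1e-5)  # Intersection over Union


def make_inputs(device):
    # Hypothetical helper: eight positive leaf tensors that require gradients.
    base = torch.randn(8, 100, 1000, device=device).exp()
    return [t.clone().requires_grad_() for t in base]


def time_forward_backward(fn, inputs, n_warmup=10, n_iter=20):
    """Hypothetical helper: mean seconds per forward + backward call."""
    def step():
        for t in inputs:
            t.grad = None  # do not accumulate gradients across runs
        fn(*inputs).sum().backward()
        if inputs[0].is_cuda:
            torch.cuda.synchronize()  # wait for the GPU before reading the clock

    for _ in range(n_warmup):  # give the JIT a chance to profile and compile
        step()
    start = time.perf_counter()
    for _ in range(n_iter):
        step()
    return (time.perf_counter() - start) / n_iter


def compare_eager_and_scripted(device):
    inputs = make_inputs(device)
    print("eager: ", time_forward_backward(ratio_iou, inputs))
    with torch.jit.fuser("fuser1"):  # a with block instead of calling __enter__()
        scripted = torch.jit.script(ratio_iou)
        print("fuser1:", time_forward_backward(scripted, inputs))


if __name__ == "__main__":
    compare_eager_and_scripted("cuda" if torch.cuda.is_available() else "cpu")
```

Summing the output before calling backward is just one convenient way to get a scalar to differentiate; any other reduction would do for timing purposes.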
Doing funny things to kick the tires a bit
If you followed along, you will have noticed that the order of kernels to try depends on how we have called our scripted function before. This can lead to somewhat funny effects.
One thing is that whether we end up running a DifferentiableGraph (and computing the intermediates) depends on what we did during the profiling and on the fallback mechanisms for the fusion groups.
In fact, there are bugs to be found (reported as #49299) where whether the outputs require gradients does not match what we feed into the scripted function:
for fuser in ["fuser1", "fuser2"]:
    for rq in [True, False]:
        c = torch.jit.fuser(fuser)
        c.__enter__()
        def ratio_iou(x1, y1, w1, h1, x2, y2, w2, h2):
            xi = torch.max(x1, x2)  # Intersection left
            yi = torch.max(y1, y2)  # Intersection top
            wi = torch.clamp(torch.min(x1+w1, x2+w2) - xi, min=0.)  # Intersection width
            hi = torch.clamp(torch.min(y1+h1, y2+h2) - yi, min=0.)  # Intersection height
            area_i = wi * hi  # Area Intersection
            area_u = w1 * h1 + w2 * h2 - wi * hi  # Area Union
            return area_i / torch.clamp(area_u, min=1e-5)  # Intersection over Union
        ratio_iou_scripted = torch.jit.script(ratio_iou)
        x1, y1, w1, h1, x2, y2, w2, h2 = torch.randn(8, 100, 1000, device='cuda', requires_grad=not rq).exp()
        for i in range(10):
            ratio_iou_scripted.graph_for(x1, y1, w1, h1, x2, y2, w2, h2)
        x1, y1, w1, h1, x2, y2, w2, h2 = torch.randn(8, 100, 1000, device='cuda', requires_grad=rq).exp()
        print("fuser:", fuser, "input requires grad:", x1.requires_grad, "output requires grad:", ratio_iou_scripted(x1, y1, w1, h1, x2, y2, w2, h2).requires_grad)
fuser: fuser1 input requires grad: True output requires grad: True
fuser: fuser1 input requires grad: False output requires grad: True
fuser: fuser2 input requires grad: True output requires grad: True
fuser: fuser2 input requires grad: False output requires grad: True
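The second and fourth lines of this output are the mismatch: the inputs do not require gradients, but the output claims to. A small check along the following lines could make such cases easier to spot in one's own code. This is only a sketch: output_requires_grad_matches is a hypothetical helper, and scaled_sum is a stand-in eager function so that the snippet runs without a GPU; in the experiment above one would pass ratio_iou_scripted instead.

```python
import torch


def output_requires_grad_matches(fn, inputs):
    """Hypothetical helper: does fn's output agree with its inputs on requires_grad?"""
    expected = any(t.requires_grad for t in inputs)
    return fn(*inputs).requires_grad == expected


def scaled_sum(a, b):
    # Stand-in for the function under test (e.g. ratio_iou_scripted).
    return 2 * a + b


if __name__ == "__main__":
    for rq in [True, False]:
        a, b = torch.randn(2, 3, requires_grad=rq).exp()
        ok = output_requires_grad_matches(scaled_sum, (a, b))
        print("input requires grad:", rq, "consistent:", ok)
```

For an eager function this should report consistent results in both cases; the point of the helper is that the same call applied to a scripted function after mismatched profiling runs may not.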
Another fun thing to try is what happens when the profiling runs see different tensor sizes (this is a real thing, e.g. for Neural Machine Translation or other NLP applications).
Do change the fuser between fuser1 and fuser2 here. We see that the CUDA fuser can handle both sizes with the same kernel while the TensorExpr fuser decides to not optimize this path.
c = torch.jit.fuser("fuser1")
c.__enter__()
torch._C._jit_set_num_profiled_runs(2)
def ratio_iou(x1, y1, w1, h1, x2, y2, w2, h2):
    xi = torch.max(x1, x2)  # Intersection left
    yi = torch.max(y1, y2)  # Intersection top
    wi = torch.clamp(torch.min(x1+w1, x2+w2) - xi, min=0.)  # Intersection width
    hi = torch.clamp(torch.min(y1+h1, y2+h2) - yi, min=0.)  # Intersection height
    area_i = wi * hi  # Area Intersection
    area_u = w1 * h1 + w2 * h2 - wi * hi  # Area Union
    return area_i / torch.clamp(area_u, min=1e-5)  # Intersection over Union
ratio_iou_scripted = torch.jit.script(ratio_iou)
inputs1 = torch.randn(8, 100, 1000, device='cuda').exp()
inputs2 = torch.randn(8, 101, 1000, device='cuda').exp()
for i in range(10):
    ratio_iou_scripted.graph_for(*inputs1)
    ratio_iou_scripted.graph_for(*inputs2)
make_graph(ratio_iou_scripted.graph_for(*inputs1))
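Instead of reading the picture, one could also check programmatically whether a fusion group made it into the optimized graph. The sketch below walks a graph, including subgraphs and nested blocks, and collects the node kinds. collect_node_kinds and has_fusion_group are hypothetical helpers, and the two kind names in FUSION_KINDS are my assumption of what the TensorExpr and CUDA fusers emit in the PyTorch version at hand; adjust them to what you see in your own graphs.

```python
import torch

# Assumed node kinds for the fusion groups of the TensorExpr and CUDA fusers.
FUSION_KINDS = {"prim::TensorExprGroup", "prim::CudaFusionGroup"}


def collect_node_kinds(graph):
    """Hypothetical helper: set of node kinds in a graph, descending into subgraphs."""
    kinds = set()

    def visit(nodes):
        for node in nodes:
            kinds.add(node.kind())
            if node.hasAttribute("Subgraph"):  # e.g. fusion groups, DifferentiableGraph
                visit(node.g("Subgraph").nodes())
            for block in node.blocks():  # e.g. bodies of prim::If and prim::Loop
                visit(block.nodes())

    visit(graph.nodes())
    return kinds


def has_fusion_group(graph):
    # True if any node, at any nesting level, is one of the assumed fusion kinds.
    return bool(FUSION_KINDS & collect_node_kinds(graph))
```

With the setup above, has_fusion_group(ratio_iou_scripted.graph_for(*inputs1)) would then be expected to differ between the two fusers, in line with what the rendered graph shows.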
Getting more debug output
When we run the JIT on the command line, we can make use of its debug logging facility to watch its parts in action more closely.
The fusers also have various debugging facilities. The TensorExpr one uses the debug logging facility (grep for GRAPH_ in torch/csrc/jit/tensorexpr/) and the CUDA one uses environment variables starting with PYTORCH_CUDA_FUSER (grep for that in torch/csrc/jit/codegen/cuda/).
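As a sketch of how the debug logging could be switched on from Python, the snippet below builds an environment with PYTORCH_JIT_LOG_LEVEL set and runs a script in a child process. To my understanding the variable takes a colon-separated list of C++ source file names without extension, and each leading > asks for a more verbose level; do check the comments in torch/csrc/jit/jit_log.h for your PyTorch version. jit_log_env and run_with_jit_logging are hypothetical helper names, and the script name in the usage note is a placeholder.

```python
import os
import subprocess
import sys


def jit_log_env(source_files, verbosity=0, base_env=None):
    """Hypothetical helper: copy of the environment with PYTORCH_JIT_LOG_LEVEL set.

    source_files: C++ file names without extension whose logging should be enabled.
    verbosity: number of '>' characters to prepend to each name (assumed syntax).
    """
    env = dict(os.environ if base_env is None else base_env)
    env["PYTORCH_JIT_LOG_LEVEL"] = ":".join(">" * verbosity + name for name in source_files)
    return env


def run_with_jit_logging(script, source_files, verbosity=0):
    # Run the script in a fresh interpreter so the variable is set before PyTorch loads.
    return subprocess.run([sys.executable, script], env=jit_log_env(source_files, verbosity))
```

For example, run_with_jit_logging("my_experiment.py", ["tensorexpr_fuser"], verbosity=1) would run a (hypothetical) script my_experiment.py with more verbose logging for the TensorExpr fuser pass.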
Conclusion
In this piece, we saw a bit of how the JIT works, with a focus on the parts that make fusion optimizations possible, and took a dive from a very high level down to experiments that try to show how some internals work. (There is also a more general technical overview in the file torch/csrc/jit/OVERVIEW.md in the JIT directory of the PyTorch source code, and various bits of documentation in .md files throughout the source as well as in comments in the source.) In the next blog post, we will dive a bit deeper into how JIT functions are executed and follow the source. Stay tuned!
I hope you enjoyed this tour. As always your feedback is appreciated: tv@lernapparat.de.