The JIT runtime - Calling a ScriptFunction

Dec. 28, 2020

In the first post of our series of PyTorch JIT blog posts, we had a good overview of how the PyTorch JIT works when we wanted to look at how it optimizes models. Today we take a close look at what happens behind the scenes when we call a TorchScript function from Python.

What happens when you call a TorchScript function?

This is structured a bit differently from the last post in that I will take you by the hand and we hop into the rabbit hole of the PyTorch source code. To do this, I have added links to a copy of the source code below, and they will open the source code at the corresponding line on the right when you move the mouse over them or click them (and sorry, if you read this on mobile, you're probably out of luck). As I want you to click on all the internal links, I have marked the very few external links with (external link).

Let us try:

import torch

@torch.jit.script
def fn(x):
    return x * 2 + x

Scripting turns fn into a torch.jit.ScriptFunction, so we have to look at what happens when a ScriptFunction is called.

This will be quite a journey, so here is a map: overview

Off we go!

Functions from Python to C++

ScriptFunction is a class defined via PyBind in torch/csrc/jit/python/script_init.cpp. It wraps a StrongFunctionPtr and defines a __call__ method.

The StrongFunctionPtr has a shared pointer to a CompilationUnit and a torch::jit::Function*. The CompilationUnit owns the function (put there e.g. by scripting our function) and potentially several others, but to adapt to Python we want something that is specific to this one function, and that is the StrongFunctionPtr.

So for calling, we need the Function. This is passed with the (Python) args to invokeScriptFunctionFromPython, defined in pybind_utils.h, which passes them on to runAndInsertCall in the same file. There we use the function createStackForSchema to create a stack of IValues from the Python arguments using the function's Schema - TorchScript's way of keeping track of the signature of a function. We skip the details of that. Then the Function's run method is called to execute the function. Afterwards, the stack contains the results as IValues; these are converted to Python values (toPyObject, same file) and returned. If we were in tracing mode (i.e. running torch.jit.trace on something calling our script function), we would record the call in the tracing graph.
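The stack-based calling convention can be mimicked in a few lines of plain Python. This is only a toy model under loud assumptions: create_stack_for_schema, run_double_plus_self and invoke_from_python are hypothetical names that mirror the roles of createStackForSchema, Function::run and invokeScriptFunctionFromPython, the schema is reduced to a list of argument names, and a list of Python values stands in for the stack of IValues.

```python
# Toy model of the calling convention; all names are illustrative only
# and this is not the actual PyTorch implementation.

def create_stack_for_schema(schema, args, kwargs):
    """Order positional and keyword arguments according to the schema."""
    stack = list(args)
    for name in schema[len(args):]:
        if name not in kwargs:
            raise TypeError(f"missing argument: {name}")
        stack.append(kwargs[name])
    if len(stack) != len(schema):
        raise TypeError("too many arguments")
    return stack

def run_double_plus_self(stack):
    """Stand-in for Function::run: consume inputs, leave outputs on the stack."""
    x = stack.pop()
    stack.append(x * 2 + x)

def invoke_from_python(run, schema, *args, **kwargs):
    stack = create_stack_for_schema(schema, args, kwargs)
    run(stack)          # the inputs on the stack are replaced by the results
    return stack.pop()  # the conversion back to Python is trivial in this toy

result = invoke_from_python(run_double_plus_self, ["x"], x=5)
```

The point of the design is that the callee never sees Python objects: everything it needs is on the stack when run is entered, and everything it produces is on the stack when it returns.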

So we are looking for Function::run. A Function is an abstract class defined in aten/src/ATen/core/function.h. The interesting specialization is the GraphFunction from torch/csrc/jit/api/function_impl.h and .cpp. It has a GraphExecutor attribute executor_ that can be obtained through (and is instantiated in) its get_executor() (from the .h). With that, all GraphFunction::run does is call the run method of the GraphExecutor with the IValue-stack as the argument. The GraphExecutor is instantiated with a TorchScript Graph from the GraphFunction's optimized_graph() method and the function's name. The optimized_graph is basically the graph with some initial optimizations (applied by GraphFunction::preoptimizeGraph in the .cpp): PeepholeOptimize, ConstantPropagationImmutableTypes and ConstantPooling. I will have to do a separate blog post on all the optimizations and passes, so let's skip the details for now.
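From Python, we can peek at two of the ingredients mentioned so far. The following sketch assumes a reasonably recent PyTorch install and relies only on the schema and graph attributes of the scripted function; the exact printed form of both varies between versions, so no output is shown.

```python
import torch

@torch.jit.script
def fn(x):
    return x * 2 + x

# The schema is what the Python arguments are matched against
# when the stack of IValues is created.
print(fn.schema)

# The TorchScript IR graph of the function. The executor works on a
# pre-optimized version of it, not on this object directly.
print(fn.graph)

out = fn(torch.ones(3))
```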

Graph Executors

Now we are in the depths of the JIT runtime and in torch/csrc/jit/runtime! The GraphExecutor is a wrapper for GraphExecutorImpl (from graph_executor_impl.h and graph_executor.cpp) or the newer ProfilingGraphExecutorImpl (from profiling_graph_executor_impl.h and .cpp) that both extend the GraphExecutorImplBase (also graph_executor_impl.h and graph_executor.cpp). The wrapper mostly handles instantiation and forwards a few method calls, including, of course, run. Which executor implementation gets used is decided in the GraphExecutor constructor based on three things: the environment variable TORCH_JIT_DISABLE_NEW_EXECUTOR, the C++ command line flag torch_jit_enable_new_executor, and - most importantly for us - a flag obtained through getExecutorMode(), which is exposed to Python via torch._C._jit_set_profiling_executor and set via the torch.jit.fuser context. The default nowadays is the profiling one.

The graph executor is the thing that handles the optimization and running of things in TorchScript. It works at the level of TorchScript IR graphs and will itself call into the bytecode interpreter for the actual execution, as we will see.

Functions (torch::jit::Function) are not the only things that get graph executors; sometimes we also want to have executors for other graphs. This will be important to us below, when we use a new executor for some things in order to have specialized optimizations.

(We'll drop the Impl from the graph executors when referring to them below.)

Execution Plans

The run method (GraphExecutorImplBase::run in graph_executor.cpp) is relatively simple: It gets an ExecutionPlan (defined in graph_executor.h) using the getPlanFor method, which takes our stack (for the types) and the remaining bailout depth as arguments (more on the latter below). This execution plan has a member code holding the bytecode in a Code object. The GraphExecutor's run method then instantiates an InterpreterState with this code and calls its run method with the stack.

But this means the optimization magic is in getPlanFor, which is specific to the two executors.

The Profiling Graph Executor

The Profiling Graph Executor's getPlanFor is again very easy: If the optimized_plan_ is initialized, that is what it returns (so we only have one optimized execution plan which is used regardless of the input types). If not, it calls getOptimizedPlanFor to make one.

When not disabled through a flag that can be queried and set in Python through torch._C._get/set_graph_executor_optimize, our getOptimizedPlanFor goes through the following:

  • If the bailout depth is $0$, it uses the graph after only the profiling insensitive optimizations (quite a list called from the function runProfilingInsensitiveOptimizations). In particular, no profiling or specialization takes place.

  • If not, it creates a profiling plan (if it has not done so yet) that is used to record the shape information in a ProfilingRecord (assigned to class member pr_; ProfilingRecord is defined in profiling_record.h/cpp). The profiling graph is created from the graph we got by running the profiling insensitive optimizations and then instrumenting the graph (through the ProfilingRecord::instrumentGraph factory function). This graph then contains lots of prim::profile nodes. The ProfilingRecord has a counter for the number of remaining profiling runs until it is ready (ready()). It starts from getNumProfiledRuns(), which can be controlled from Python by torch._C._jit_get/set_num_profiled_runs; the default value is $1$. If the profiling graph is there, but the profiling record is not ready yet, this profiling graph is run (through an ExecutionPlan created from it).

  • Once ready, it creates the optimized graph by running the profiling optimizations (runProfilingOptimizations) on a copy of the profiling graph. This is what we instantiate the ExecutionPlan that is our optimized_plan_ with. When we call fn.get_debug_state() from our Python script function, the debug state's execution_plans member is a dictionary that will have this execution plan as the only member. Currently, we get an internal assert error when it doesn't exist (e.g. because we're still in the profiling phase).
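The three cases can be condensed into a small state machine. The following is a toy model in plain Python, not the actual implementation: plans are just strings, the class and method names only echo the C++ ones, and two simplifications are assumptions of mine, namely that the profiling counter is decremented at plan lookup rather than by running the instrumented graph, and that the plan for bailout depth 0 is cached like the optimized one.

```python
# Toy model of the plan selection in the profiling executor.
# Names mirror the C++ ones, but this is a simplified illustration.

class ToyProfilingExecutor:
    def __init__(self, graph, num_profiled_runs=1):
        self.graph = graph
        self.remaining_profile_runs = num_profiled_runs
        self.profiling_plan = None
        self.optimized_plan = None

    def get_plan_for(self, remaining_bailout_depth=1):
        if self.optimized_plan is not None:
            return self.optimized_plan  # one plan, whatever the inputs are
        return self.get_optimized_plan_for(remaining_bailout_depth)

    def get_optimized_plan_for(self, remaining_bailout_depth):
        insensitive = f"insensitive({self.graph})"
        if remaining_bailout_depth == 0:
            # no profiling and no specialization (caching is an assumption)
            self.optimized_plan = insensitive
            return self.optimized_plan
        if self.profiling_plan is None:
            self.profiling_plan = f"profiled({insensitive})"
        if self.remaining_profile_runs > 0:
            # simplification: count down here instead of while running
            self.remaining_profile_runs -= 1
            return self.profiling_plan
        self.optimized_plan = f"optimized({self.profiling_plan})"
        return self.optimized_plan

ex = ToyProfilingExecutor("g")
plans = [ex.get_plan_for() for _ in range(3)]
```

With one profiling run, the first lookup yields the profiling plan and every later lookup yields the same optimized plan, which matches the behaviour we see when calling a script function twice.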

And this is really all there is to the Profiling Graph Executor - because we left out all the details about the two important bits that are at the core of the executor: the profiling mechanism and the optimizations, in particular those after profiling.

We can look at the profiling executor in action:

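import torch

@torch.jit.script
def fn(x):
    return torch.sin(x) * 2 + x

x = torch.randn(5, device="cuda")

fn(x) # call the function once
print(torch.jit.last_executed_optimized_graph())  # show the graph that was just run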

This gives us the profiled graph (see the prim::profile nodes):

graph(%x.1 : Tensor):
  %1 : int = prim::Constant[value=1]()
  %2 : int = prim::Constant[value=2]() # <ipython-input-11-43e780f89194>:3:26
  %6 : Tensor = prim::profile[profiled_type=Float(5, strides=[1], requires_grad=0, device=cuda:0)](%x.1)
  %3 : Tensor = aten::sin(%6) # <ipython-input-11-43e780f89194>:3:11
  %7 : Tensor = prim::profile[profiled_type=Float(5, strides=[1], requires_grad=0, device=cuda:0)](%3)
  %4 : Tensor = aten::mul(%7, %2) # <ipython-input-11-43e780f89194>:3:11
  %8 : Tensor = prim::profile[profiled_type=Float(5, strides=[1], requires_grad=0, device=cuda:0)](%4)
  %9 : Tensor = prim::profile[profiled_type=Float(5, strides=[1], requires_grad=0, device=cuda:0)](%x.1)
  %5 : Tensor = aten::add(%8, %9, %1) # <ipython-input-11-43e780f89194>:3:11
  %10 : Tensor = prim::profile[profiled_type=Float(5, strides=[1], requires_grad=0, device=cuda:0)](%5)
   = prim::profile()
  return (%10)

As there isn't an optimized graph yet, getting the debug state throws an exception:



RuntimeError                              Traceback (most recent call last)
<ipython-input-12-22ea0456b2eb> in <module>
      1 # as there isn't an optimized graph yet, this throws an exception
----> 2 print(fn.get_debug_state().execution_plans)

RuntimeError: optimized_plan_ INTERNAL ASSERT FAILED at "../torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp":556, please report a bug to PyTorch. 

The internal assert fires because there is no optimized plan yet.

But when we run the function another time (and because the number of profiling runs is 1 by default), we get the optimized graph with a TensorExpr fusion group:

fn(x)  # run a second time

Note the TypeCheck and the bailout function we met in the introduction to how the JIT optimizes functions (external link). There we also saw a trick to get at the fallback function's graph.

With the optimized plan defined, we can now also get the debug state.


Regardless of how often we called the function and with what arguments, we only ever have one execution plan in the dictionary, as this is just a dummy mapping to the debug state designed for the traditional Graph Executor:

{<torch._C.ArgumentSpec object at 0x7f92021cb470>: <torch._C.ExecutionPlan object at 0x7f92021cbd30>}

Having seen the profiling graph executor in code and in action, let us now look briefly at the traditional one.

The traditional Graph Executor

The traditional GraphExecutor isn't as interesting to us because it's on its way to retirement (well, maybe slowly), so here is an equally brief overview, with some details that are interesting to contrast with the profiling executor:

The difference starts when the getPlanFor method is called. Depending on whether optimizations are enabled, it calls getOrCompile or getOrCompileFallback, the interesting bit being the getOrCompile. (The non-optimized parts still have some passes they need to apply, but they basically leave out most things that are optional.)

The traditional executor distinguishes between different input configurations (shapes, requires grad, whether optionals are None etc.) and creates distinct optimized graphs for them. This was a key ingredient to optimizing LSTM backwards (external link) because it allowed making the information whether broadcasting has happened in the forward "static" for the backward. It does so by having a minified version of the information listed above and a hash table (the execution_plans dictionary in the debug state). So in getOrCompile the JIT creates an ArgumentSpec that is the key to the plan cache. If it finds a plan, it returns that, else it compiles a new one for this spec in compileSpec.
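The plan cache is essentially a dictionary keyed by a hashable summary of the inputs. Here is a minimal sketch in plain Python: ToyArgumentSpec and ToyGraphExecutor are hypothetical stand-ins, the choice of fields is my assumption (the real ArgumentSpec covers all inputs of a call, not just one tensor), and a "plan" is only a string.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class ToyArgumentSpec:
    """Minified, hashable summary of one input (hypothetical field choice)."""
    dtype: str
    ndim: int
    requires_grad: bool
    device: str

class ToyGraphExecutor:
    def __init__(self, graph):
        self.graph = graph
        self.plan_cache = {}  # plays the role of execution_plans
        self.compile_count = 0

    def compile_spec(self, spec):
        # stand-in for the specialization and optimization for this spec
        self.compile_count += 1
        return f"plan for {self.graph} specialized to {spec}"

    def get_or_compile(self, spec):
        if spec not in self.plan_cache:
            self.plan_cache[spec] = self.compile_spec(spec)
        return self.plan_cache[spec]

ex = ToyGraphExecutor("fn")
no_grad = ToyArgumentSpec("float32", 1, False, "cuda:0")
with_grad = ToyArgumentSpec("float32", 1, True, "cuda:0")
ex.get_or_compile(no_grad)
ex.get_or_compile(with_grad)
ex.get_or_compile(no_grad)  # cache hit, nothing is compiled
```

Two inputs that differ only in requires_grad end up with two separate plans, while repeating a configuration reuses the cached one.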

As the traditional executor relies on shape propagation to apply optimizations, it seeds the input's shape information. Then it applies the optimizations (inlining functions etc.) that always work (similar to the profiling insensitive ones in the profiling executor), followed by the differentiation mechanism and optimizations that can only be executed when things do not require gradients (namely fusion with the traditional fuser) either inside the differentiable graph's forward or for graphs that don't need gradients.

And this is really all we need to know about the traditional executor. To see it in action, we can switch to it and run script functions as in the following (note that you want to re-define and re-script the function to not get cached results):

import torch

@torch.jit.script
def fn(x):
    return torch.sin(x) * 2 + x

with torch.jit.fuser("fuser0"):
    old_pe = torch._C._jit_set_profiling_executor(False)
    gr1 = fn.graph_for(torch.randn(5, device="cuda", requires_grad=False))
    gr2 = fn.graph_for(torch.randn(5, device="cuda", requires_grad=True))

# we find two execution plans, but there isn't a way to see the argspec in Python
print(fn.get_debug_state().execution_plans)

print(gr1)  # it seems the fuser0 is already gone here...

We see that there are two execution plans:

    {<torch._C.ArgumentSpec object at 0x7f402564a370>: <torch._C.ExecutionPlan object at 0x7f4023119eb0>, <torch._C.ArgumentSpec object at 0x7f40257bc530>: <torch._C.ExecutionPlan object at 0x7f40257bc2b0>}

But no sign of fusion:

    graph(%x.1 : Float(*, requires_grad=0, device=cuda:0)):
      %1 : int = prim::Constant[value=2]() # <ipython-input-67-7eb141224258>:3:26
      %2 : int = prim::Constant[value=1]()
      %3 : Float(*, requires_grad=0, device=cuda:0) = aten::sin(%x.1) # <ipython-input-67-7eb141224258>:3:11
      %4 : Float(*, requires_grad=0, device=cuda:0) = aten::mul(%3, %1) # <ipython-input-67-7eb141224258>:3:11
      %5 : Float(*, requires_grad=0, device=cuda:0) = aten::add(%4, %x.1, %2) # <ipython-input-67-7eb141224258>:3:11
      return (%5)

The interpreter

The main part of the interpreter is in the Code class, or rather the CodeImpl one in torch/csrc/jit/runtime/interpreter.cpp.

This has two main parts:

  • Translating graphs to "bytecode" sequences, this is done on instantiation.
  • Running the bytecode.

For the translation, the constructor of CodeImpl calls emitCodeForBlock on the input graph's main block. emitCodeForBlock then has a typical recursive visitor pattern that produces bytecode for the various things. The regular bits like calls to PyTorch functions are done by emitOperator that calls the JIT ops (aten::... and custom ops), but anything control flow is handled by its own code generator function, dispatched from emitNode in a large switch statement on the various special node types, with some amendments valid inside blocks in emitNodeAtBlockLevel. Note that some interesting things, like differentiable graphs from Autodiff, are also implemented via "regular" operators, so they don't show up here.
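To get a feeling for the recursive visitor, here is a toy translator in plain Python. Everything about it is made up for illustration: the node layout (dictionaries), the instruction names LOAD, OP, STORE, JF and JMP, and the use of absolute jump targets are my assumptions, and the real CodeImpl handles registers, constants and many more node kinds.

```python
# Toy translation of a tiny graph with nested blocks into a flat
# instruction list. Not the real CodeImpl, only the visitor idea.

def emit_code_for_block(block, code):
    for node in block:
        emit_node(node, code)

def emit_node(node, code):
    # the "switch" on special node kinds; everything else is an operator
    if node["kind"] == "prim::If":
        emit_if(node, code)
    else:
        emit_operator(node, code)

def emit_operator(node, code):
    for name in node["inputs"]:
        code.append(("LOAD", name))
    code.append(("OP", node["kind"], len(node["inputs"])))
    code.append(("STORE", node["output"]))

def emit_if(node, code):
    code.append(("LOAD", node["cond"]))
    jf_pos = len(code)
    code.append(None)                   # placeholder, patched below
    emit_code_for_block(node["then"], code)
    jmp_pos = len(code)
    code.append(None)                   # placeholder, patched below
    code[jf_pos] = ("JF", len(code))    # condition false: go to else block
    emit_code_for_block(node["else"], code)
    code[jmp_pos] = ("JMP", len(code))  # end of then block: skip else block

graph = [
    {"kind": "aten::sin", "inputs": ["x"], "output": "a"},
    {"kind": "prim::If", "cond": "flag",
     "then": [{"kind": "aten::mul", "inputs": ["a", "two"], "output": "r"}],
     "else": [{"kind": "aten::neg", "inputs": ["a"], "output": "r"}]},
]
code = []
emit_code_for_block(graph, code)
```

The control flow node is the only one that needs to know about instruction positions, which is why it gets its own emitter while all plain operators share one.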

As is a familiar pattern now, the InterpreterState is a forwarding wrapper for InterpreterStateImpl, which holds the TorchScript VM execution state (like call stack of frames, registers etc.) and does the actual execution. The execution goes from InterpreterStateImpl::run (or runAsync, as the graph executors have, too) to runImpl, which is a state machine with a large switch statement for the various instructions. Of special note are CALL and INTERFACE_CALL as well as (context-manager) EXIT, which run functions (builtin or graph ones). Graph functions are run in runGraphFunction just like they are when called from Python above: We call the function's get_executor() method and the executor's getPlanFor and take its code member. In a new interpreter frame, we run that. Note how this way, the function's graph will go through the optimization as needed.
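The shape of such an interpreter loop, including calls that run in a new frame, fits into a short toy in plain Python. The instruction set (LOADC, DUP, OP, CALL, RET), the FUNCTIONS table and the shared value stack are assumptions for this sketch; in particular, the table lookup at CALL only stands in for obtaining the callee's code through get_executor() and getPlanFor.

```python
# Toy interpreter: one loop, a chain of if/elif as the "switch", and an
# explicit stack of frames. A simplification, not the real runImpl.

OPS = {
    "mul": (2, lambda a, b: a * b),
    "add": (2, lambda a, b: a + b),
}

# "Compiled" functions: name -> instruction list (hypothetical bytecode).
FUNCTIONS = {
    "double": [("LOADC", 2.0), ("OP", "mul"), ("RET",)],
    "main": [("DUP",), ("CALL", "double"), ("OP", "add"), ("RET",)],
}

def run(stack, entry="main"):
    frames = [[FUNCTIONS[entry], 0]]  # each frame is [code, program counter]
    while frames:
        frame = frames[-1]
        inst = frame[0][frame[1]]
        frame[1] += 1
        op = inst[0]
        if op == "LOADC":
            stack.append(inst[1])
        elif op == "DUP":
            stack.append(stack[-1])
        elif op == "OP":
            arity, func = OPS[inst[1]]
            args = [stack.pop() for _ in range(arity)][::-1]
            stack.append(func(*args))
        elif op == "CALL":
            # look up the callee's code and continue in a new frame
            frames.append([FUNCTIONS[inst[1]], 0])
        elif op == "RET":
            frames.pop()
        else:
            raise RuntimeError(f"unknown instruction {op}")
    return stack

stack = [5.0]
run(stack)  # main computes x * 2 + x, with x * 2 done by the callee
```

Because the callee's code is fetched only when the CALL instruction is reached, a function that is never called never needs a plan, which is the toy version of optimizing graphs as needed.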


So that is what happens when you run a TorchScript function. I hope you enjoyed the technical dive into the parts of the JIT runtime that execute bits and the new (to me) form of walking through the code. So far we left out the compilation parts; these are for another time, as is a detailed look at the frontend, i.e. how we get script functions in the first place. This will be very important to us at a later point, too.

I appreciate your feedback and comments - mail me at tv@lernapparat.de.