Lernapparat

Exploring Python fallback for the JIT

Jan. 26, 2021

One of the key difficulties of the "almost everything can be scripted" promise is what to do with functions the JIT doesn't understand. Rather than re-implementing all of Python, we need to fall back selectively to the Python we have. Join me today in looking at how that can be done.

The JIT has awesome optimizations, but they only work when our models, or at least the parts that are to be optimized, are JITed. This is easy enough if our models trace well, but scripting often gets messy because we need to make our entire model JIT-compatible before we have something we can run. This is what a fallback is designed to make easier.

Acknowledgement: This work was made possible by grid.ai, the people known for PyTorch Lightning. Thank you!

Fallbacks

Obviously, the JIT cannot implement everything - all of Python's standard library plus all other packages. But then what if my model needs to use something the JIT cannot handle?

Because I lack imagination and I like to work from examples, I took a very simple example with a built-in function. It wasn't all that easy to come up with one, because many useful functions are already implemented. So I don't necessarily want to imply that every model should need to read random files from disk.

import torch
@torch.jit.script
def fn(x : str):
    return open(x).read()

print(fn.graph)
print(fn(__file__))

But this program doesn't work on today's PyTorch:

Traceback (most recent call last):
  File "/home/tv/pytorch/pytorch/../scripts/fallback.py", line 4, in <module>
    def fn(x : str) -> str:
  File "/usr/local/lib/python3.9/dist-packages/torch/jit/_script.py", line 939, in script
    fn = torch._C._jit_script_compile(
RuntimeError: 
Python builtin <built-in function open> is currently not supported in Torchscript:
  File "/home/tv/pytorch/pytorch/../scripts/fallback.py", line 5
@torch.jit.script
def fn(x : str) -> str:
    return open(x).read()
           ~~~~ <--- HERE

Depending on which slide you take, you get the TorchScript interpreter or the Python fallback. Kids can take any slide they want! It's their playground I borrowed here.

Golem

Before we can fix it, we need to know a bit about how functions get into TorchScript. I touched briefly on this in my post on fusers, but now we look at it in (selectively) more detail.

How the JIT makes a graph from source code

When scripting a function from Python, the JIT grabs the Python source code (via the inspect module of the Python standard library) and then runs the Python parser from the ast (for Abstract Syntax Tree) module. It then transforms the Python AST into a TorchScript AST (implemented in C++). The TorchScript AST is very similar to what you would expect; it is defined in torch/csrc/jit/frontend/tree_views.h. Every node type has a class. Thankfully it also has a dump() function that shows a Lisp-like representation of the tree - every opening parenthesis followed by a name is a node of that class and the further elements in the list until the closing parenthesis are the children.
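The Python half of this pipeline can be tried with the standard library alone. The following is a minimal sketch, not what the JIT itself runs: the source is given as a string (a stand-in for what inspect.getsource would return, chosen so the snippet runs anywhere), and the variable names are mine.

```python
import ast

# Stand-in for the text that inspect.getsource(fn) would return.
SOURCE = """
def fn(x: str) -> str:
    return open(x).read()
"""

module = ast.parse(SOURCE)
fn_def = module.body[0]        # ast.FunctionDef for fn
return_stmt = fn_def.body[0]   # ast.Return
read_call = return_stmt.value  # ast.Call: (...).read()
read_attr = read_call.func     # ast.Attribute: open(x).read
open_call = read_attr.value    # ast.Call: open(x)

print(ast.dump(fn_def))
print(open_call.func.id, read_attr.attr)
```

The nesting (a call whose function is an attribute lookup on the result of another call) is the shape to look for in the TorchScript AST, which carries the same information in its own node classes.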

For our little function, the tree looks like this:

(def
  (ident fn)
  (decl
    (list
      (param
        (ident x)
        (option (variable (ident str)))
        (option)
        (False)))
    (option (variable (ident str))))
  (list
    (return
      (apply
        (.
          (apply
            (variable (ident open))
            (list (variable (ident x)))
            (list))
          (ident read))
        (list)
        (list)))))

This is what gets handed to CompilationUnit::define in torch/csrc/jit/frontend/ir_emitter.cpp. (Avid readers of this blog may recall that we met compilation units in two previous posts. In the JIT runtime overview we saw that they were holding script functions. In the exploration of graph manipulation in Python, we created functions from graphs with their create_function method, though I didn't talk about it much.)

So define calls (well, instantiates, but the gist is a call) to_ir to get a graph from this tree. This conversion is done in three steps:

  • With a fresh graph instance, a set of node visitors is called along the tree structure, starting with emitDef. It produces an initial graph that "looks like Python" and is in non-SSA form. (The defining feature of static single assignment (SSA) form is that it eliminates conventional variables; instead, values are only set once. This makes it easier for optimization passes to reason about them and to generate code. SSA has other aspects, like how loops are handled, but the variable bit will be the most important part for us.)

  • This is passed to a ConvertToSSA pass, which does what the name insinuates.

  • Then some normalization is carried out in CanonicalizeModifiedLoops and NormalizeOps followed by some initial cleanup passes (e.g. in simple cases split tuples into separate values).

One thing to know about the graph visitation in the first step is that it has to deal with "external" (to our function) references. It uses a resolver passed to define to find the matching Python objects or - for particular functions - the TorchScript overrides. Now these things can vary strongly in their nature - from Tensor variables (or, to us, constants) to other functions to classes - and we do not want to have all of that show up in our TorchScript graphs.

To deal with the discrepancy between all things that might be and the things that can actually be values in a graph, PyTorch defines a data structure Environment (in ir_emitter.cpp) which captures the lexical scopes and local variables and all. Things are stored in the environment in SugaredValues, which we look at in detail below. Then to process any "top-level" identifier (as opposed to an attribute lookup), the graph visitors call Environment::getSugaredVar. This method

  • first checks all local scopes using Environment::findInAnyFrame,
  • if that didn't find anything, it checks a table of magic global sugared values (defined in the static table global in the getSugaredVar method),
  • if these didn't return anything, it calls a resolver that looks up things in the Python environment outside TorchScript and tries to convert them into a SugaredValue. More precisely, it calls the resolver three times, first trying to find NamedTuple types, then arbitrary values, and then classes. The resolveValue call will return a SugaredValue, while in the other two cases, we get a type and instantiate specialized SugaredValue subclasses.
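The lookup order can be modelled in a few lines of Python. This is a toy sketch under loud assumptions: the Frame class, the one-entry MAGIC_GLOBALS table, the string placeholders for sugared values, and the python_resolver helper are all made up for illustration, and the three separate resolver calls are collapsed into one.

```python
import builtins

# Illustrative stand-in for the table of magic globals; the real table
# lives in C++ and maps names to SugaredValue instances.
MAGIC_GLOBALS = {"print": "PrintValue()"}


class Frame:
    """One lexical scope; `parent` is the enclosing scope."""

    def __init__(self, parent=None):
        self.values = {}
        self.parent = parent

    def find_in_any_frame(self, name):
        frame = self
        while frame is not None:
            if name in frame.values:
                return frame.values[name]
            frame = frame.parent
        return None


def python_resolver(name):
    """Hypothetical resolver: look the name up among Python's builtins."""
    return getattr(builtins, name, None)


def get_sugared_var(frame, name, resolver=python_resolver):
    found = frame.find_in_any_frame(name)  # 1. local scopes
    if found is not None:
        return found
    if name in MAGIC_GLOBALS:              # 2. magic globals
        return MAGIC_GLOBALS[name]
    resolved = resolver(name)              # 3. the Python environment
    if resolved is None:
        raise NameError(f"undefined value {name}")
    return resolved


outer = Frame()
outer.values["x"] = "SimpleValue(%x.1)"
inner = Frame(parent=outer)
```

With this, get_sugared_var(inner, "x") is found in the enclosing frame, "print" is served from the magic table even though Python has a builtin of that name, and "open" only turns up in the last step, via the resolver.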

The PythonResolver is defined in python/script_init.cpp. Its resolveValue method calls into a Python-defined resolution callback to get the desired Python object and calls toSugaredValue (python/python_sugared_value.cpp) to convert the Python object it found to a SugaredValue.

But so what is a SugaredValue?

Sugar, Value!

As mentioned above SugaredValues bridge the gap between anything that can be referenced by name in our programs and what is a Value in the narrow sense of TorchScript graphs.

At the sugared value level, things you can do with the variable call into the SugaredValue object to accomplish things:

  • Sometimes you just want to get a (JIT) Value for the sugared value; this can be obtained by the asValue method. Then you can use the value as an argument in function calls etc.

  • For things you can do with references, but not generally with values, the compiler can call into methods on the SugaredValue. Notable examples include call (for calling functions, constructors, etc.) and attr for attribute lookup (à la getattr).

  • SugaredValues are subclassed to define the different effects of various methods and help the compiler distinguish valid uses of values (e.g. call a function, add something to a simple value) from invalid ones (like calling a string literal, adding something to a function,...). These things potentially insert things into the graph we're building (e.g. for calling functions) and return a SugaredValue representing the result.
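To make the idea concrete, here is a toy Python model of such a hierarchy. It is a sketch, not the C++ interface: the method signatures are simplified, the "graph" is a plain list of tuples, and graph values are just strings. Only the class names SugaredValue, SimpleValue and FunctionValue are borrowed from the JIT.

```python
class SugaredValue:
    """Base class: every operation is an error unless a subclass allows it."""

    def kind(self):
        return "value"

    def as_value(self):
        raise TypeError(f"{self.kind()} cannot be used as a value")

    def call(self, graph, args):
        raise TypeError(f"cannot call a {self.kind()}")

    def attr(self, graph, field):
        raise TypeError(f"{self.kind()} does not support attribute lookup")


class SimpleValue(SugaredValue):
    """Wraps a plain graph value (here just its name as a string)."""

    def __init__(self, value):
        self.value = value

    def as_value(self):
        return self.value


class FunctionValue(SugaredValue):
    """A callable: calling it inserts a node and yields a SimpleValue."""

    def __init__(self, name):
        self.name = name

    def kind(self):
        return "function"

    def call(self, graph, args):
        output = f"%{len(graph)}"
        graph.append((output, self.name, [a.as_value() for a in args]))
        return SimpleValue(output)


graph = []
x = SimpleValue("%x")
result = FunctionValue("aten::relu").call(graph, [x])
```

Calling the function appends a node to the graph and hands back a SimpleValue for its output, whereas x.call(graph, []) or FunctionValue("f").as_value() raise, which is how invalid uses get rejected at compile time.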

To give you a taste of the richness of this "type system", these are the subclasses of SugaredValue in the JIT (from frontend/sugared_value.h, python/python_sugared_value.h, and - a bit special - three from serialization/import_source.cpp): BooleanDispatchValue, BuiltinFunction, BuiltinModule, ClassNamespaceValue, ClassNamespaceValue, ClassValue, ClosureValue, ConstantParameterList, ConstantTableValue, ExceptionMessageValue, ExceptionValue, FunctionValue, IterableTree, MagicMethod, MethodValue, ModuleDictMethod, ModuleValue, NamedTupleConstructor, NoneValue, OpsValue, PrintValue, PythonSliceClass, PythonValue, RangeValue, SimpleValue, SliceValue, SpecialFormValue, SugaredDict, SugaredEnumClass, SugaredTupleValue, TensorCastValue.

As you can imagine, the toSugaredValue function mentioned above uses many of these to represent the various things, notably constants (represented as SimpleValue), functions from Python (e.g. those marked torch.jit.ignore, represented as PythonValue), and all sorts of special things the JIT knows more about.

In the end, all things we do on the sugared values will be translated into graph operations and values, so we end with a graph that only contains the runtime universe of JIT types. It still contains loads and stores to variables (prim::Load and prim::Store) and similar things that the ConvertToSSA pass then eliminates (by again building up an environment, but this time either of only types or of only values, as those are separated in TorchScript). This conversion is done in frontend/convert_to_ssa.cpp; for loads and stores in a block (so after loops and such have been dealt with), it is done in EraseLoadStores::eraseBlockLoadStores. This will be important to us later.
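The load/store elimination for a single straight-line block can be sketched in Python as follows. This is a toy model under my own assumptions (instructions as tuples, values as strings, no nested blocks or control flow), not a transcription of the C++ pass.

```python
def erase_load_stores(block):
    """Remove store/load pairs from a straight-line block.

    Each instruction is (output, op, inputs, name); `name` is the variable
    for store/load and unused otherwise.
    """
    env = {}     # variable name -> value most recently stored in it
    loaded = {}  # output of an erased load -> the value it stood for
    result = []
    for output, op, inputs, name in block:
        # Redirect uses of erased loads to the values they stood for.
        inputs = [loaded.get(i, i) for i in inputs]
        if op == "store":
            env[name] = inputs[0]
        elif op == "load":
            loaded[output] = env[name]
        else:
            result.append((output, op, inputs, name))
    return result


# a = 1; a = a + a; return a
block = [
    ("%1", "constant", [], ""),
    ("", "store", ["%1"], "a"),
    ("%2", "load", [], "a"),
    ("%3", "add", ["%2", "%2"], ""),
    ("", "store", ["%3"], "a"),
    ("%4", "load", [], "a"),
    ("", "return", ["%4"], ""),
]
ssa = erase_load_stores(block)
for instruction in ssa:
    print(instruction)
```

After the pass the variable a is gone: the add consumes %1 directly and the return consumes %3, so every value is defined exactly once.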

Our fallback in the compiler frontend

So what does all that mean for our function above?

It shouldn't be much of a surprise - but we can also fire up the debugger to see it - that the error message above stems from toSugaredValue noticing that it does not want to deal with our case. We can fix that.

--- a/torch/csrc/jit/python/python_sugared_value.cpp
+++ b/torch/csrc/jit/python/python_sugared_value.cpp
@@ -1018,8 +1072,7 @@ std::shared_ptr<SugaredValue> toSugaredValue(

   if (py::isinstance<py::function>(obj)) {
     if (typeString(obj) == "builtin_function_or_method") {
-      throw ErrorReport(loc) << "Python builtin " << py::str(obj)
-                             << " is currently not supported in Torchscript";
+      return bindPythonObjectValue(obj, m, loc);
     }
   }

We now have two options: We could try to make an existing SugaredValue class do what we want (likely, this would be PythonValue). Or, and this is the route we take today (a decision to be revisited later, but to me it looks like the semantics of the result are different enough to want a separate type), we make a new PythonObjectValue. As we want it to look like something that becomes a JIT Value, we make it a subclass of SimpleValue.

We define this new PythonObjectValue sugared value class in python_sugared_value.h (the signatures are given, as they are overrides).

--- a/torch/csrc/jit/python/python_sugared_value.h
+++ b/torch/csrc/jit/python/python_sugared_value.h
@@ -345,5 +345,25 @@ struct VISIBILITY_HIDDEN PythonSliceClass : public SugaredValue {
       size_t n_binders) override;
 };

+struct VISIBILITY_HIDDEN PythonObjectValue : public SimpleValue {
+  PythonObjectValue(Value* v) : SimpleValue(v) {}
+
+  std::shared_ptr<SugaredValue> attr(
+      const SourceRange& loc,
+      Function& m,
+      const std::string& field) override;
+
+  std::string kind() const override {
+    return "computed Python value";
+  }
+
+  std::shared_ptr<SugaredValue> call(
+      const SourceRange& loc,
+      Function& caller,
+      at::ArrayRef<NamedValue> args,
+      at::ArrayRef<NamedValue> kwargs,
+      size_t n_binders) override;
+};
+
 } // namespace jit
 } // namespace torch

We introduce a helper function bindPythonObjectValue in python_sugared_value.cpp that inserts a node of the new kind prim::PyConstant that binds a Python value and returns the new PythonObjectValue sugared value. (Actually it isn't quite constant, as changes in mutable Python objects will be reflected.) prim::PyConstant mimics prim::Constant, but is a separate node kind, as the separation of Python and Python-less parts is crucial to PyTorch (this also gives us the sugared_values and python_sugared_values distinction). Happily, there already is a PyObject JIT type and IValues can be made from Python objects through toIValue (as we met in our JIT runtime overview).

--- a/torch/csrc/jit/python/python_sugared_value.cpp
+++ b/torch/csrc/jit/python/python_sugared_value.cpp
@@ -905,6 +905,60 @@
   return std::make_shared<SliceValue>(start, stop, step);
 }

+std::shared_ptr<PythonObjectValue> bindPythonObjectValue(
+    py::object obj,
+    Function& m,
+    SourceRange loc) {
+  Node* n = m.graph()->insertNode(m.graph()->create(prim::PyConstant));
+  n->ival_(attr::value, toIValue(obj, PyObjectType::get()));
+  n->setSourceRange(loc);
+  n->output()->setType(PyObjectType::get());
+  return std::make_shared<PythonObjectValue>(n->output());
+}
+
....

In contrast to SimpleValue we want to enable calls and attribute lookups (and more might be added in the future), so we implement call and attr in python_sugared_value.cpp. As there already is a prim::PythonOp to launch any Python function call, we hook into this, but in contrast to the conventional use, our call will use the first input as the function to call, so we make a note that we need to change the implementation of prim::PythonOp later. (To see prim::PythonOp in action, create a simple method (e.g. multiplying a Tensor by 2), fully type annotate it, and decorate the definition with @torch.jit.ignore.) We said that attr is similar to Python's getattr and indeed, we'll just use a prim::PythonOp calling getattr as the graph created by attr. Both these methods again return PythonObjectValues as sugared values and the graph Values they create are of type PyObject.

...
+std::shared_ptr<SugaredValue> PythonObjectValue::attr(
+    const SourceRange& loc,
+    Function& m,
+    const std::string& field) {
+  // using prim::GetAttr would look nicer in the graph, but we would need
+  // to implement it in the interpreter or move to replacing prim::GetAttr
+  // on PythonObjects
+  // later as a pass
+  std::string cconv(2, 'd');
+  Value* v_field = insertConstant(*m.graph(), field, loc);
+  py::object getattr = py::module::import("builtins").attr("getattr");
+  Node* n = m.graph()->insertNode(m.graph()->createPythonOp(
+      THPObjectPtr(getattr.release().ptr()), cconv, {}));
+  n->setSourceRange(loc);
+  n->addInput(getValue());
+  n->addInput(v_field);
+  n->addOutput()->setType(PyObjectType::get());
+  return std::make_shared<PythonObjectValue>(n->output());
+}
+
+std::shared_ptr<SugaredValue> PythonObjectValue::call(
+    const SourceRange& loc,
+    Function& m,
+    at::ArrayRef<NamedValue> args,
+    at::ArrayRef<NamedValue> kwargs,
+    size_t /*n_binders*/) {
+  auto inputs = toValues(*m.graph(), args);
+  std::string cconv(inputs.size(), 'd');
+  if (!kwargs.empty()) {
+    throw ErrorReport(loc) << "KWARGS currently not supported";
+  }
+  Node* new_node =
+      m.graph()->insertNode(m.graph()->createPythonOp({}, cconv, {}));
+
+  new_node->setSourceRange(loc);
+  new_node->addInput(getValue());
+  for (auto& i : inputs)
+    new_node->addInput(i);
+
+  Value* output = new_node->addOutput()->setType(PyObjectType::get());
+  return std::make_shared<PythonObjectValue>(output);
+}
+
 std::shared_ptr<SugaredValue> toSugaredValue(
     py::object obj,
     Function& m,

In order to print graphs with prim::PyConstant nodes, which have a PyObject IValue as the value attribute (compare to prim::Constant in graphs), we need to add a case to IValue::repr in aten/src/ATen/core/ivalue.cpp. The prim::PyConstant itself is defined - like a prim::CastFromPython that we need in a bit - in aten/src/ATen/core/interned_strings.h.

--- a/aten/src/ATen/core/interned_strings.h
+++ b/aten/src/ATen/core/interned_strings.h
@@ -66,6 +66,8 @@ namespace c10 {
   _(prim, AutogradAllNonZero)        \
   _(prim, AutogradAllZero)           \
   _(prim, Starred)                   \
+  _(prim, PyConstant)                \
+  _(prim, CastFromPython)            \
   _(prim, TupleConstruct)            \
   _(prim, TupleUnpack)               \
   _(prim, TupleIndex)                \

--- a/aten/src/ATen/core/ivalue.cpp
+++ b/aten/src/ATen/core/ivalue.cpp
@@ -532,6 +532,9 @@ std::ostream& IValue::repr(
       return out << enum_holder->qualifiedClassName() << "." <<
           enum_holder->name();
     }
+    case IValue::Tag::PyObject: {
+      return out << "<python object>";
+    }
     case IValue::Tag::Object: {
       TORCH_INTERNAL_ASSERT(false, "repr() not defined on: ", v.tagKind(), ". Perhaps you've frozen a module with custom classes?");
     }

Rising forward?

It seems dubious that rising forward might be an antonym of falling back, but we need a reversal. Above we said that operations on PyObject inputs yield PyObject results. If we do not want to use the JIT merely as a particularly indirect line-by-line Python interpreter, we also need a plan to get our beloved JIT typed objects - such as Tensors, ints, floats, and strs - from these PyObjects.

At the technical level, we introduce a second new operator, the prim::CastFromPython mentioned above. It carries a types attribute (slightly misnamed, because it only holds one type) with a JIT Type. When executed, it will cast the Python object to the desired type just as passing it to a JIT function called from Python would. On error it raises a runtime ValueError.
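In plain Python terms, the runtime behaviour we are after looks roughly like this. It is only a sketch: Python classes stand in for JIT types, and the function name cast_from_python is made up here, whereas the real conversion goes through the same utilities that convert arguments when a script function is called from Python.

```python
def cast_from_python(obj, expected_type):
    """Return obj unchanged if it has the expected type, else raise ValueError."""
    if not isinstance(obj, expected_type):
        raise ValueError(
            f"cannot cast Python object of type {type(obj).__name__} "
            f"to type {expected_type.__name__}"
        )
    return obj


text = cast_from_python("hello", str)
```

A call such as cast_from_python(3, str) raises the ValueError, which is the failure mode we want when an annotation promises more than the Python code delivers.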

But how would we do this in our program? We use type annotations on the variable assignment to express our desire to have a certain type. In our example we might say that the result is a str:

import torch
@torch.jit.script
def fn(x : str):
    res: str = open(x).read()
    return res

If we do this, we get an error. (The program before would run, and PyObject would show up as the return type in the schema, the PyTorch JIT signature for a function.) The reason is that the tree visiting in ir_emitter.cpp dispatches the assignment to Environment::setSugaredVar, which checks if the annotated type matches the assigned value. We tell it not to throw an exception when the assigned type is PyObject and add a check_type argument to Environment::insertStore which, if true, causes insertStore to add a types attribute with the desired type to the prim::Store node it creates and to set the input Value's type to PyObject.

--- a/torch/csrc/jit/frontend/ir_emitter.cpp
+++ b/torch/csrc/jit/frontend/ir_emitter.cpp
@@ -280,9 +280,14 @@ struct Environment {
       const std::string& name,
       const SourceRange& loc,
       Value* v,
-      TypePtr type) {
+      TypePtr type,
+      bool check_type) {
     auto g = b->owningGraph();
-    g->insertNode(g->createStore(name, v))->setSourceRange(loc);
+    auto n = g->insertNode(g->createStore(name, v))->setSourceRange(loc);
+    if (check_type) {
+      v->setType(PyObjectType::get());
+      n->ty_(attr::types, type);
+    }
     type_table[name] = std::move(type);
   }

@@ -399,14 +404,20 @@ struct Environment {
       if (!annotated_type) {
         annotated_type = as_simple_value->type();
       }
-      if (!as_simple_value->type()->isSubtypeOf(annotated_type)) {
+      if (as_simple_value->type() == PyObjectType::get()) {
+      } else if (!as_simple_value->type()->isSubtypeOf(annotated_type)) {
         throw ErrorReport(loc)
             << "Variable '" << name << "' is annotated with type "
             << annotated_type->repr_str()
             << " but is being assigned to a value of type "
             << as_simple_value->type()->repr_str();
       }
-      insertStore(name, loc, as_simple_value, annotated_type);
+      insertStore(
+          name,
+          loc,
+          as_simple_value,
+          annotated_type,
+          as_simple_value->type() == PyObjectType::get());
     } else {
       value_table[name] = std::move(value);
     }

Then in EraseLoadStores::eraseBlockLoadStores in convert_to_ssa.cpp we insert the prim::CastFromPython node if the prim::Store we process has the types attribute set.

--- a/torch/csrc/jit/frontend/convert_to_ssa.cpp
+++ b/torch/csrc/jit/frontend/convert_to_ssa.cpp
@@ -194,7 +194,18 @@ struct EraseLoadStores {

       switch (n->kind()) {
         case prim::Store: {
-          environment_stack->setVar(n->s(attr::name), n->input());
+          auto v = n->input();
+          if (n->hasAttribute(attr::types)) {
+            auto ty = n->ty(attr::types);
+            auto ta = n->owningGraph()
+                          ->create(prim::CastFromPython)
+                          ->setSourceRange(n->sourceRange())
+                          ->ty_(attr::types, ty)
+                          ->insertAfter(n);
+            ta->addInput(v);
+            v = ta->output()->setType(ty);
+          }
+          environment_stack->setVar(n->s(attr::name), v);
           n->destroy();
         } break;
         case prim::Load: {
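The effect of this change can be illustrated with a small Python model. Again this is a sketch under my own assumptions: the Node dataclass, the string-typed values and the ".cast" naming of the new output are invented for the illustration, and only the node kinds and the types attribute follow the diff.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class Node:
    kind: str
    inputs: List[str] = field(default_factory=list)
    output: Optional[str] = None
    attrs: Dict[str, str] = field(default_factory=dict)


def erase_stores_with_casts(block):
    """Erase loads and stores; a store with a `types` attribute leaves a cast."""
    env = {}       # variable name -> value currently bound to it
    replaced = {}  # output of an erased load -> value it stood for
    result = []
    for n in block:
        inputs = [replaced.get(i, i) for i in n.inputs]
        if n.kind == "store":
            value = inputs[0]
            if "types" in n.attrs:
                # The annotated store becomes a cast; the variable is bound
                # to the cast's output rather than to the raw PyObject.
                cast_out = f"{value}.cast"
                result.append(Node("CastFromPython", [value], cast_out,
                                   {"types": n.attrs["types"]}))
                value = cast_out
            env[n.attrs["name"]] = value
        elif n.kind == "load":
            replaced[n.output] = env[n.attrs["name"]]
        else:
            result.append(Node(n.kind, inputs, n.output, dict(n.attrs)))
    return result


# res: str = <some PyObject %6>; return res
block = [
    Node("PythonOp", ["%5"], "%6"),
    Node("store", ["%6"], attrs={"name": "res", "types": "str"}),
    Node("load", output="%7", attrs={"name": "res"}),
    Node("return", ["%7"]),
]
converted = erase_stores_with_casts(block)
```

The converted block has the kinds PythonOp, CastFromPython, return, and the return now consumes the output of the cast, so everything after the annotated assignment sees a str rather than a PyObject.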

We are not quite done here, because we also want to enable casting by return type annotations, so

import torch
@torch.jit.script
def fn(x : str) -> str:
    return open(x).read()

does the right thing. For this, we go back to ir_emitter.cpp, find emitReturn, and insert a prim::CastFromPython node if we find a type annotation and a return value of type PyObject.

--- a/torch/csrc/jit/frontend/ir_emitter.cpp
+++ b/torch/csrc/jit/frontend/ir_emitter.cpp
@@ -1007,7 +1018,16 @@ struct to_ir::emitReturn
             /*allow_conversions=*/true);
       }

-      if (!result->type()->isSubtypeOf(result_type)) {
+      if ((result->type() == PyObjectType::get()) &&
+          (result_type != AnyType::get())) {
+        auto n = graph->insertNode(graph->create(prim::CastFromPython)
+                                       ->setSourceRange(stmt.range())
+                                       ->ty_(attr::types, result_type));
+        n->addInput(result);
+        result = n->output()->setType(result_type);
+      }
+
+      if (!(result->type()->isSubtypeOf(result_type))) {
         throw ErrorReport(stmt.range())
             << "Return value was annotated as having type "
             << result_type->repr_str() << " but is actually of type "

There might be cases missing yet when you want to work with tuples, but that is how it is.

So yay, the JIT frontend can deal with our fallback and we can get back.

The middle ages, er, layers

As will be no surprise to readers of this blog, there are a number of analysis and optimization passes the JIT runs before actually executing code. Most of them are happy to ignore our two new node kinds prim::PyConstant and prim::CastFromPython, but there are two exceptions: The alias analysis and the printing mechanism want to know more. I ran gdb again to find the stack trace for the exceptions being raised, and so we need to modify two and a half places:

  • In runtime/operator.cpp, we need to add our new operators to the lists in printerHasSpecialCaseFor and aliasAnalysisHasSpecialCaseFor.
--- a/torch/csrc/jit/runtime/operator.cpp
+++ b/torch/csrc/jit/runtime/operator.cpp
@@ -214,7 +214,7 @@ bool printerHasSpecialCaseFor(Symbol sym) {
       prim::CreateObject,  prim::GetAttr,       prim::SetAttr,
       prim::CallFunction,  prim::isinstance,    prim::unchecked_cast,
       prim::tolist,        prim::rpc_async,     prim::rpc_sync,
-      prim::rpc_remote};
+      prim::rpc_remote,    prim::PyConstant,    prim::CastFromPython};

   // WARNING: by adding a value to this set, you are asserting that your
   // primitive is only ever added during optimization and does not need
@@ -324,6 +324,8 @@ bool aliasAnalysisHasSpecialCaseFor(Symbol symbol) {
       prim::Enter,
       prim::Exit,
       prim::FallbackGraph,
+      prim::CastFromPython,
+      prim::PyConstant,
   };

   // Operators that should not be used by alias analysis
  • For the actual alias analysis we need to specify the aliasing relations in AliasDb::analyzeImpl. I chose to add prim::PyConstant to the creator case and prim::CastFromPython to link input and output (though I don't know how well that works with wherever the PyObject might have come from; wildcard might be an alternative).
--- a/torch/csrc/jit/ir/alias_analysis.cpp
+++ b/torch/csrc/jit/ir/alias_analysis.cpp
@@ -498,6 +498,7 @@ void AliasDb::analyzeImpl(Node* node) {
     case prim::Closure:
     case prim::CreateObject:
     case prim::tolist:
+    case prim::PyConstant:
       return analyzeCreator(node);
     case prim::TupleConstruct:
     case prim::DictConstruct:
@@ -516,6 +517,7 @@ void AliasDb::analyzeImpl(Node* node) {
         }
       }
       return analyzeExtractor(node);
+    case prim::CastFromPython:
     case prim::unchecked_cast:
       return makePointerTo(node->output(), node->input());
     case prim::ConstantChunk:

And these are all the purely administrative things we needed. I should say that it was only when I started investigating fallbacks that I noticed that most of the infrastructure (Types, Values, IValues,...) is already in place.

Extending and implementing the operators

Actually, there is a small bit of administrative stuff left. The existing prim::PythonOp that we wanted to extend to use the first parameter as the function instead of a fixed one bound directly from Python is implemented via a ConcretePythonOp class in python/python_ir.cpp.

This has two places where we need to deal with the possibility that the Python object that would be a function is "empty" (i.e. a null object), so we add two quick if cases in name (for printing, just emitting a default <PyObjectCall> for now) and for copying (setting the target Python object to null when the source is).

--- a/torch/csrc/jit/python/python_ir.cpp
+++ b/torch/csrc/jit/python/python_ir.cpp
@@ -129,7 +129,9 @@ Node* findNode(Block* block, Symbol kind, bool recurse = true) {

 std::string ConcretePythonOp::name() const {
   pybind11::gil_scoped_acquire gil;
-  if (auto autograd = autogradFunction()) {
+  if (!pyobj) {
+    return "<PyObjectCall>";
+  } else if (auto autograd = autogradFunction()) {
     return getPythonName(autograd->get());
   } else {
     return getPythonName(pyobj.get());
@@ -140,8 +142,12 @@ void ConcretePythonOp::cloneFrom(Node* other_) {
   Node::cloneFrom(other_);
   auto other = other_->cast<ConcretePythonOp>();
   this->cconv = other->cconv;
-  Py_INCREF(other->pyobj.get());
-  this->pyobj = THPObjectPtr(other->pyobj.get());
+  if (other->pyobj) {
+    Py_INCREF(other->pyobj.get());
+    this->pyobj = THPObjectPtr(other->pyobj.get());
+  } else {
+    this->pyobj = {};
+  }
   for (auto& sa : other->scalar_args) {
     Py_INCREF(sa.get());
     this->scalar_args.emplace_back(sa.get());

But now, all that is left is to extend the implementation of prim::PythonOp and write the new prim::PyConstant and prim::CastFromPython.

The functions implementing them are defined and registered as JIT operators in python/python_interpreter.cpp. They take a Node as an argument and return an Operation, a lambda that takes a Stack of IValues as input (actually there might be additional values below the ones for the operator) and leaves the stack with the inputs replaced by the outputs.

Adapting prim::PythonOp is straightforward, thanks to the excellent PyBind11 tooling. On creation we just check the truth value of the pyobj member of the ConcretePythonOp instance behind the node to decide if the local func object should be initialized from it or left empty. On execution we check if func is empty and if so take the first input from the stack and cast it to a Python object to get a function to call.

--- a/torch/csrc/jit/python/python_interpreter.cpp
+++ b/torch/csrc/jit/python/python_interpreter.cpp
@@ -12,6 +12,7 @@
 #include <torch/csrc/jit/runtime/graph_executor.h>
 #include <torch/csrc/jit/runtime/operator.h>

+#include <sstream>
 #include <typeinfo>

 #include <pybind11/pybind11.h>
@@ -31,11 +32,14 @@ namespace {
 Operation createPythonOperation(const Node* op_) {
   pybind11::gil_scoped_acquire gil;
   const ConcretePythonOp* op = static_cast<const ConcretePythonOp*>(op_);
-  const py::function func = py::reinterpret_borrow<const py::function>(
-      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
-      py::handle(const_cast<ConcretePythonOp*>(op)->pyobj.get()));
+  const py::function func =
+      (op->pyobj
+           ? py::reinterpret_borrow<const py::function>(
+                 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
+                 py::handle(const_cast<ConcretePythonOp*>(op)->pyobj.get()))
+           : py::function{});

-  size_t num_inputs = 0;
+  size_t num_inputs = op->pyobj ? 0 : 1;
   for (auto arg_type : op->cconv) {
     if (arg_type == 'd')
       num_inputs++;
@@ -49,6 +53,12 @@ Operation createPythonOperation(const Node* op_) {
     size_t i = 0;
     size_t next_scalar = 0;
     size_t next_tensor = 0;
+    py::function func_from_val;
+    if (!func) {
+      func_from_val =
+          toPyObject(std::move(peek(stack, next_tensor, num_inputs)));
+      next_tensor++;
+    }
     for (auto arg_type : op->cconv) {
       if (arg_type == 'c') {
         py_inputs[i] = py::reinterpret_borrow<const py::object>(
@@ -65,7 +75,7 @@ Operation createPythonOperation(const Node* op_) {
     }
     drop(stack, num_inputs);
     try {
-      py::object py_output(func(*py_inputs));
+      py::object py_output((func ? func : func_from_val)(*py_inputs));
       stack->push_back(returnToIValue(op->output()->type(), py_output));
     } catch (py::error_already_set& e) {
       throw std::runtime_error(e.what());

The PyConstant Operation is even simpler: just put the IValue saved in the value attribute on the stack. We need to take the GIL because copying the IValue here likely needs to increase the Python object's reference count. I think. I might want to double check.

Similarly, the cast operation implementing CastFromPython takes a Python IValue from the top of the stack, casts with the utility functions also used when invoking ScriptFunctions from Python and returns the result. On error, the cast will throw an exception, but we catch it to translate it to a more Pythonic ValueError.

--- a/torch/csrc/jit/python/python_interpreter.cpp
+++ b/torch/csrc/jit/python/python_interpreter.cpp
@@ -77,10 +87,47 @@
   return AliasAnalysisKind::INTERNAL_SPECIAL_CASE;
 }

-RegisterOperators reg({Operator(
-    prim::PythonOp,
-    createPythonOperation,
-    aliasAnalysisIsSpecialCase())});
+Operation createPyConstantOperation(const Node* node) {
+  pybind11::gil_scoped_acquire gil;
+  auto val = node->ival(attr::value);
+  return [=](Stack* stack) {
+    pybind11::gil_scoped_acquire gil;
+    stack->push_back(val);
+  };
+}
+
+Operation createCastFromPythonOperation(const Node* node) {
+  TypePtr typ = node->ty(attr::types);
+  return [=](Stack* stack) {
+    pybind11::gil_scoped_acquire gil;
+
+    py::object pyobj = toPyObject(std::move(pop(stack)));
+    try {
+      stack->push_back(toIValue(pyobj, typ));
+    } catch (py::cast_error& e) {
+      std::stringstream msg;
+      py::object pytype =
+          py::module::import("builtins").attr("type")(pyobj).attr("__name__");
+      msg << "ValueError: cannot cast Python object of type " << pytype
+          << " to TorchScript type " << *typ;
+      throw std::runtime_error(msg.str());
+    }
+  };
+}
+
+RegisterOperators reg(
+    {Operator(
+         prim::PythonOp,
+         createPythonOperation,
+         aliasAnalysisIsSpecialCase()),
+     Operator(
+         prim::PyConstant,
+         createPyConstantOperation,
+         aliasAnalysisIsSpecialCase()),
+     Operator(
+         prim::CastFromPython,
+         createCastFromPythonOperation,
+         aliasAnalysisIsSpecialCase())});

 } // namespace
 } // namespace jit
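To see how the extended prim::PythonOp and prim::PyConstant play together, here is a toy stack machine in Python. It is a sketch under stated assumptions: the helper names py_constant, python_op and run_open_read are made up, plain Python objects stand in for IValues, constants and inputs are pushed by hand, and the final cast is left out.

```python
import os
import tempfile


def py_constant(obj):
    """Operation that pushes a bound Python object (cf. prim::PyConstant)."""
    def op(stack):
        stack.append(obj)
    return op


def python_op(num_args, func=None):
    """Operation that calls func, or the first input if func is None."""
    num_inputs = num_args + (0 if func is not None else 1)

    def op(stack):
        start = len(stack) - num_inputs
        inputs = stack[start:]
        del stack[start:]
        callee = func if func is not None else inputs.pop(0)
        stack.append(callee(*inputs))
    return op


def run_open_read(path):
    """Mimic the graph for open(x).read() as a sequence of stack operations."""
    stack = []
    py_constant(open)(stack)           # push the Python object for open
    stack.append(path)                 # the function input x
    python_op(1)(stack)                # call the first input: open(x)
    stack.append("read")               # the constant field name
    python_op(2, func=getattr)(stack)  # fixed function: getattr(file, "read")
    python_op(0)(stack)                # call the bound method
    return stack.pop()


with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f:
    f.write("hello fallback")
    temp_path = f.name
try:
    content = run_open_read(temp_path)
finally:
    os.unlink(temp_path)
print(content)
```

The two flavours of python_op correspond to the two cases in the patched createPythonOperation: with a bound function (here getattr) it behaves as before, and without one it consumes one extra input and calls that.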

And that is it!

Now our little TorchScript function runs:

$ PYTHONPATH=build/lib.linux-x86_64-3.8/ python3.8  ../scripts/fallback.py 
graph(%x.1 : str):
  %4 : str = prim::Constant[value="read"]() # ../scripts/fallback.py:5:11
  %1 : PyObject = prim::PyConstant[value=<python object>]() # ../scripts/fallback.py:5:11
  %3 : PyObject = ^<PyObjectCall>()(%1, %x.1) # ../scripts/fallback.py:5:11
  %5 : PyObject = ^getattr()(%3, %4) # ../scripts/fallback.py:5:11
  %6 : PyObject = ^<PyObjectCall>()(%5) # ../scripts/fallback.py:5:11
  %7 : str = prim::CastFromPython[types=str](%6) # ../scripts/fallback.py:5:4
  return (%7)

#!/usr/bin/python
import torch
@torch.jit.script
def fn(x : str) -> str:
    return open(x).read()

print(fn.graph)
print(fn(__file__))

Our patches clock in at a little under 200 lines. It appears that at least in this case, a relatively simple fallback is feasible. We should not kid ourselves: There will be many more cases to handle in the frontend around the SugaredValue and we need tests, too!

Conclusion

Our exploration into a SugaredValue based fallback mechanism worked surprisingly (or embarrassingly, I might have forgotten something) well. If you want to play with it without copy-pasting the diff snippets, you can also use my ScriptTorch git branch. It will have to be seen how much of the problem it solves; some open points remain:

  • We might have cases where we discover during partial de-sugaring (i.e. after attribute lookups) that we cannot complete it and might have needed to invoke the fallback mechanism earlier (i.e. before we knew). This will likely be a tough one to analyse.

  • We did not look at syntax constructs currently not handled by the JIT.

  • We need to work on "standard functions" (the operators and magic functions) working on PyObjects. These would need to be re-routed to PythonOps.

I hope you enjoyed this little expedition into the internals of the JIT, with a view towards implementing a fallback. The JIT already provided us with most of the infrastructure, so this was easy.

But here is your part: In addition to helping out with the code, what is your favourite model bit that you cannot yet script? I look forward to hearing from you at tv@lernapparat.de.