Lernapparat

Traceable and Differentiable Extensions with PyTorch

June 26, 2019

Editorial note: This post, from the end of June 2019, described how to do manually what is now, three months later, offered painlessly by the torch::autograd::Function mechanism. In fact, you can see how we used autograd::Function for the function discussed below in the TorchVision source code. Of course, what autograd::Function does is provide a nice wrapper for the things I describe below.

Three of the most liked features of PyTorch are the extensible autograd mechanism, the ability to extend PyTorch with C++ efficiently, and the tracing/scripting mechanism, the PyTorch JIT. This leads to the natural question: can we have all three at the same time? In this post, we dive into the autograd internals and come out with a solution. (This is not the @torch.jit.script-decorated autograd.Function which I am hacking on, but something that is all available today and perhaps not even breaking news. By my estimate, though, only 3-5 people are aware that it works now. We will have an exciting followup.)

Let us see. While C++ extensions are arguably the most popular way of extending PyTorch, there are also C++ custom ops that extend TorchScript, i.e. that are traceable / scriptable. The price you pay is that you are limited to the data types that TorchScript provides, but with Tensor, float, int, string and even lists and dicts of them, you are pretty much set. All you need to do is to use RegisterOperators in your module, like this:

static auto registry =
  torch::RegisterOperators().op("mylib::something", &do_something);

You can do this in addition to the C++ extension (PyBind11) bindings for your functions or instead of them. If you remove the extension bindings, you become independent of Python and can load your library in C++ programs, too, but you might have to think about how to load it from Python.
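To make the static registration idiom a little less magical, here is a toy sketch in plain C++. It is not the real torch::RegisterOperators, and ToyRegisterOperators, toy_op_table and do_something are hypothetical names. The registration happens as a side effect of initializing a namespace-scope static, so it runs as soon as the binary or library is loaded.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <string>
#include <utility>

// Hypothetical toy stand-in for torch::RegisterOperators (NOT the real API):
// it maps a qualified name like "mylib::something" to a callable.
using ToyOp = std::function<double(double)>;

// Function-local static, so the table exists before anyone registers into it
// (this sidesteps the static initialization order problem).
std::map<std::string, ToyOp>& toy_op_table() {
  static std::map<std::string, ToyOp> table;
  return table;
}

struct ToyRegisterOperators {
  // Register one op; returning *this allows chaining .op(...).op(...).
  ToyRegisterOperators& op(const std::string& name, ToyOp fn) {
    toy_op_table()[name] = std::move(fn);
    return *this;
  }
};

double do_something(double x) { return 2.0 * x; }

// Initialized during static initialization, i.e. when the binary or shared
// library is loaded, before any caller looks the op up by name.
static auto registry =
    ToyRegisterOperators().op("mylib::something", &do_something);
```

A caller would then look the operator up by its qualified name, e.g. toy_op_table().at("mylib::something")(21.0). The real registry also records the schema and integrates with TorchScript, which this sketch ignores.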

We have C++ and scripting. But now comes the difficulty, how do we get differentiability? The standard way of doing this - e.g. suggested in the C++ extension tutorial - is to implement forward and backward and then wrap them using a torch.autograd.Function in Python. But here is the problem: those are not scriptable.

So if a C++-implemented operator wrapped in an autograd.Function is not scriptable, maybe we need an "autograd.Function" wrapped in a C++ operator. The only problem is that we do not have (or want) Python inside C++.

So how do Functions work at the C++ level? If you have read my selective excursion into PyTorch internals, you know that I did not go into any detail there. Edward Yang's great blog post and slides do cover it, but alas, in his words, he "skipped the following seven slides [in his NY meetup talk and is] also going to delay writeup for them; you'll have to wait for the sequel for some text." But we had a glimpse at the Function slide, and that is the thing we need. I cannot recommend Edward's blog post highly enough: it is a very gentle introduction and gets you up to speed to the point where I learn a lot from reading it. Thank you, Edward!

If you have not worked with Python's torch.autograd.Function or want to refresh your memory, I recommend checking out the documentation as well as the Autograd mechanics chapter of the PyTorch documentation. As this Python side of autograd is well known, we will use that as a reference when diving into the internals.

I should add a rather large caveat here: Just as PyTorch 0.4 brought the great Variable/Tensor merge on the Python side, Will Feng is currently (June 2019, with master showing 1.2 as the next version) working on merging them in C++, too. This has led to some breaking changes (for me it was .data() going away in the code we are going to develop, just as you should not use .data in Python these days).

Dissecting how PyTorch builds the computational graph

So let us look at how PyTorch's own functions manage to work with autograd. We pick the simplest possible function, a (pointwise, unary) function that takes one tensor and produces one of the same shape, say atan. After you compile PyTorch, there is a file torch/csrc/autograd/generated/VariableTypeEverything.cpp. (I use rgrep a lot to find my way through the PyTorch code. After forming some basic opinion about where what I am looking for might be, say torch/csrc/autograd, I use rgrep '\batan\b' torch/csrc/autograd. VariableTypeEverything.cpp has also been split up into pieces, so we get hits in the pieces, too.) Note that we want atan, not the inplace atan_ nor the atan_out variant that writes into a given output tensor.

Here is what it looks like for me:

Tensor VariableType::atan(const Tensor & self) const {
  RECORD_FUNCTION("atan", std::vector<c10::IValue>({self}), Function::peek_at_next_sequence_nr());
  auto& self_ = unpack(self, "self", 0);
  std::shared_ptr<AtanBackward> grad_fn;
  if (compute_requires_grad( self )) {
    grad_fn = std::shared_ptr<AtanBackward>(new AtanBackward(), deleteFunction);
    grad_fn->set_next_edges(collect_next_edges( self ));
    grad_fn->self_ = SavedVariable(self, false);
  }
  torch::jit::Node* node = nullptr;
  std::shared_ptr<jit::tracer::TracingState> tracer_state;
  if (jit::tracer::isTracing()) {
    tracer_state = jit::tracer::getTracingState();
    at::Symbol op_name;
    op_name = jit::Symbol::fromQualString("aten::atan");
    node = tracer_state->graph->create(op_name, /*num_outputs=*/0);
    jit::tracer::recordSourceLocation(node);
    jit::tracer::addInputs(node, "self", self);
    tracer_state->graph->insertNode(node);

    jit::tracer::setTracingState(nullptr);
  }
  #ifndef NDEBUG
  c10::optional<Storage> self__storage_saved =
    self_.has_storage() ? c10::optional<Storage>(self_.storage()) : c10::nullopt;
  c10::intrusive_ptr<TensorImpl> self__impl_saved;
  if (self_.defined()) self__impl_saved = self_.getIntrusivePtr();
  #endif
  auto tmp = ([&]() {
    at::AutoNonVariableTypeMode non_var_type_mode(true);
    return baseType->atan(self_);
  })();
  auto result = as_variable(tmp);
  #ifndef NDEBUG
  if (self__storage_saved.has_value())
    AT_ASSERT(self__storage_saved.value().is_alias_of(self_.storage()));
  if (self__impl_saved) AT_ASSERT(self__impl_saved == self_.getIntrusivePtr());
  #endif
  if (grad_fn) {
      set_history(flatten_tensor_args( result ), grad_fn);
  }
  if (tracer_state) {
    jit::tracer::setTracingState(std::move(tracer_state));
    jit::tracer::addOutput(node, result);
  }
  return result;
}

Uh. That is a lot to digest! We need to focus, so we will first ignore everything that is in those #ifndef NDEBUG or looks jit::tracer-related (we want to use this in a C++ custom op, so we already have tracing from that). Let us delete those lines. We are left with

Tensor VariableType::atan(const Tensor & self) const {
  RECORD_FUNCTION("atan", std::vector<c10::IValue>({self}), Function::peek_at_next_sequence_nr());
  auto& self_ = unpack(self, "self", 0);
  std::shared_ptr<AtanBackward> grad_fn;
  if (compute_requires_grad( self )) {
    grad_fn = std::shared_ptr<AtanBackward>(new AtanBackward(), deleteFunction);
    grad_fn->set_next_edges(collect_next_edges( self ));
    grad_fn->self_ = SavedVariable(self, false);
  }

  auto tmp = ([&]() {
    at::AutoNonVariableTypeMode non_var_type_mode(true);
    return baseType->atan(self_);
  })();
  auto result = as_variable(tmp);

  if (grad_fn) {
      set_history(flatten_tensor_args( result ), grad_fn);
  }

  return result;
}

Well, still large, but much more manageable! Let us go through it in pieces and see what we find. First, some very good news. Looking at the signature

Tensor VariableType::atan(const Tensor & self) const {

we see that while we won't be part of a class like VariableType, we can perfectly relate to Tensor atan(const Tensor & self) - we would use something like that in our extension function or custom op implementation, too, with Tensor being torch::Tensor. Awesome!

The next line

  RECORD_FUNCTION("atan", std::vector<c10::IValue>({self}), Function::peek_at_next_sequence_nr());

is for the purposes of the PyTorch profiler (did you know that it exists? It is great!). (As you might have guessed, I used rgrep RECORD_FUNCTION torch/csrc/ to find the definition.)

The next line is unpack, and we look at it together with as_variable near the end:

  auto& self_ = unpack(self, "self", 0);
  ...
  auto result = as_variable(tmp);

So these bits are the effect of still having both Tensor and Variable: the incoming Tensor self is a Variable in disguise, and we unpack it to a pure (non-Variable) Tensor. We then wrap the result tmp into a Variable-in-disguise Tensor result. Unsurprisingly, unpack is not a very unique name in the PyTorch codebase; here it is VariableType::unpack. This part is quite likely to change in the near future with Will's work mentioned above. For now we need these functions and will come back to them later.

In the middle we have the actual calculation:

 auto tmp = ([&]() {
    at::AutoNonVariableTypeMode non_var_type_mode(true);
    return baseType->atan(self_);
  })();

This uses a lambda in a creative way to avoid spelling out the type of tmp while still getting a scope. What happens is that at::AutoNonVariableTypeMode is a guard variable that makes the remainder of the scope behave similarly to a with torch.no_grad(): block. Indeed, it is a functional equivalent of the more familiar torch::NoGradGuard and will be merged into that in due course. (Just as with statements are often handy in PyTorch for backend flags, default types and autograd mode, PyTorch C++ uses such guards both for user-facing things and internally.) The line return baseType->atan(self_) is the actual calculation. Of course, we would do our own instead.
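Here is a minimal sketch of the guard idea, assuming a single thread-local flag. The names toy_grad_enabled and ToyNoGradGuard are hypothetical, and the real guards manage PyTorch's internal state rather than a global like this. The constructor switches recording off and the destructor restores the previous value. The immediately invoked lambda limits the guard's lifetime while still letting auto deduce the type of tmp.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical thread-local "grad mode" flag: true means record graphs.
thread_local bool toy_grad_enabled = true;

// Toy RAII guard in the spirit of torch::NoGradGuard: disable recording for
// the rest of the scope, restore the previous state on exit (also when an
// exception unwinds the scope).
class ToyNoGradGuard {
 public:
  ToyNoGradGuard() : prev_(toy_grad_enabled) { toy_grad_enabled = false; }
  ~ToyNoGradGuard() { toy_grad_enabled = prev_; }
  ToyNoGradGuard(const ToyNoGradGuard&) = delete;
  ToyNoGradGuard& operator=(const ToyNoGradGuard&) = delete;

 private:
  bool prev_;
};

// Same shape as the generated code: the guard lives only inside the lambda.
double compute_without_recording(double x) {
  auto tmp = ([&]() {
    ToyNoGradGuard guard;
    assert(!toy_grad_enabled);  // recording is off in here
    return std::atan(x);        // stand-in for baseType->atan(self_)
  })();
  return tmp;  // guard already destroyed, recording is back on
}
```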

The remaining bits of code now are the actual graph recording bits:

  std::shared_ptr<AtanBackward> grad_fn;
  if (compute_requires_grad( self )) {
    grad_fn = std::shared_ptr<AtanBackward>(new AtanBackward(), deleteFunction);
    grad_fn->set_next_edges(collect_next_edges( self ));
    grad_fn->self_ = SavedVariable(self, false);
  }
  ....
  if (grad_fn) {
      set_history(flatten_tensor_args( result ), grad_fn);
  }

Note carefully that while the actual calculation worked with the unpacked objects self_ and tmp, these bits only work on the wrapped objects self and result. (Some of you asked for a talk or video on this topic; we will see how good I am at keeping self_ and self apart when speaking.)

By checking compute_requires_grad( self ) we only record a graph when something needs a gradient (i.e. requires_grad is set and we're not in a with no_grad() block). The function takes any number of tensors as arguments. The second if will be true precisely when the first was, because we initialize the shared pointer there.
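compute_requires_grad is variadic. Below is a toy version of the logic just described, written over a hypothetical ToyTensor type with a hypothetical toy_grad_enabled flag. It uses a C++17 fold expression and is a sketch, not PyTorch's implementation.

```cpp
#include <cassert>

thread_local bool toy_grad_enabled = true;  // hypothetical grad-mode flag

struct ToyTensor {  // hypothetical stand-in for a Variable
  double value;
  bool requires_grad;
};

// True iff grad mode is on and at least one argument requires grad.
// Variadic like the real helper, so it accepts any number of inputs.
template <typename... Ts>
bool toy_compute_requires_grad(const Ts&... tensors) {
  return toy_grad_enabled && (false || ... || tensors.requires_grad);
}
```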

The AtanBackward object declared in the first line (using a shared pointer; thanks for not needing raw pointers!) is what you also see in Python when you check x.grad_fn for some calculation result. It is a node in the graph recording the calculation. AtanBackward is a subclass of torch::autograd::Function, which does the required bookkeeping for calling backward later. (It may be surprising that the superclass of all Backward classes is called Function. In a way, this is the PyTorch 0.1.2 way from before the Python-level new-style torch.autograd.Functions were introduced, when the information that we now save in the ctx context was stored in instances of the torch.autograd.Function object.) After instantiating it (the deleteFunction deleter was not declared as part of the public API, which we fixed while writing this, so you need a very recent nightly/master for things to work), we hook it up to the existing graph by calling set_next_edges with the edges from collect_next_edges. collect_next_edges again takes a variable number of arguments, so you would hook up the sub graphs of all your inputs. (Alas, I think collect_next_edges might not be public API, but I missed this in the first version of the draft because I was working off a feature branch for other work. Hopefully, we will get an official API soon.)

Then grad_fn->self_ is set to the SavedVariable-wrapped self. (grad_fn->self_ has nothing to do with self_ other than that both are related to self: the former is more wrapped, the latter is unwrapped.) This is similar to Python torch.autograd.Functions calling ctx.save_for_backward, which wraps all of its arguments into SavedVariables. The purpose of SavedVariable is to sanity check that nothing bad happened to our tensor in between. The most (in)famous problem caught this way is the "one of the variables needed for gradient computation has been modified by an inplace operation" exception. (This is also why you should save inputs and outputs using ctx.save_for_backward but do not necessarily need to do so for intermediate results that have no references outside your torch.autograd.Function: if no one knows your intermediates, no one can drive you crazy by modifying them without you noticing.) The name self_ is just an arbitrary member of AtanBackward; we will define our own subclass of torch::autograd::Function later.
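PyTorch tracks in-place modifications with a version counter (visible as x._version in Python). The check can be sketched with a toy version counter. Everything here (ToyStorage, toy_add_, ToySavedVariable) is hypothetical and only illustrates the idea: remember the version when saving, and refuse to unpack if it has changed since.

```cpp
#include <cassert>
#include <memory>
#include <stdexcept>
#include <utility>

struct ToyStorage {
  double value = 0.0;
  unsigned version = 0;  // bumped by every in-place operation
};

// A toy in-place op (think x.add_(v)): changes the data, bumps the version.
void toy_add_(ToyStorage& t, double v) {
  t.value += v;
  ++t.version;
}

class ToySavedVariable {
 public:
  ToySavedVariable() = default;
  explicit ToySavedVariable(std::shared_ptr<ToyStorage> t)
      : data_(std::move(t)), saved_version_(data_->version) {}

  // Sanity checks before handing the data back to the backward.
  std::shared_ptr<ToyStorage> unpack() const {
    if (!data_) throw std::runtime_error("saved variable was already released");
    if (data_->version != saved_version_)
      throw std::runtime_error(
          "one of the variables needed for gradient computation has been "
          "modified by an inplace operation");
    return data_;
  }

  // Counterpart of reset_data() as called from release_variables().
  void reset_data() { data_.reset(); }

 private:
  std::shared_ptr<ToyStorage> data_;  // declared first: initialized first
  unsigned saved_version_ = 0;
};
```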

Finally, we need to attach our graph to our results, which is done in the second if block by

      set_history(flatten_tensor_args( result ), grad_fn);

As autograd needs to know how many and which outputs we have, we pass them all in one go, flatten_tensor_args takes the results (varargs again) and hands a list to set_history, which then connects grad_fn to each of them.
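To see the whole forward recipe in one place, here is a toy scalar version of the pattern. The hypothetical ToyVar, ToyNode and ToyEdge types stand in for Variable, Function and Edge. The sketch only mirrors the structure (decide whether to record, create the node, connect next edges, save inputs, compute, attach history), not PyTorch's actual implementation.

```cpp
#include <cassert>
#include <cmath>
#include <memory>
#include <string>
#include <vector>

struct ToyNode;
struct ToyEdge {
  std::shared_ptr<ToyNode> function;  // node the gradient flows to next
  int input_nr;                       // which input of that node
};
struct ToyNode {
  std::vector<ToyEdge> next_edges;
  virtual ~ToyNode() = default;
  virtual std::string name() const = 0;
};
struct ToyVar {
  double value = 0.0;
  bool requires_grad = false;
  std::shared_ptr<ToyNode> grad_fn;  // null for leaves and no-grad results
  int output_nr = 0;
};
struct ToyAtanBackward : ToyNode {
  double self_ = 0.0;  // saved input (stand-in for a SavedVariable)
  std::string name() const override { return "AtanBackward"; }
};

ToyVar toy_atan(const ToyVar& self) {
  std::shared_ptr<ToyAtanBackward> grad_fn;
  if (self.requires_grad) {  // cf. compute_requires_grad(self)
    grad_fn = std::make_shared<ToyAtanBackward>();
    // cf. set_next_edges(collect_next_edges(self)). A leaf has no grad_fn
    // here; PyTorch would point the edge at a gradient accumulator instead.
    grad_fn->next_edges.push_back({self.grad_fn, self.output_nr});
    grad_fn->self_ = self.value;  // cf. grad_fn->self_ = SavedVariable(...)
  }
  ToyVar result;
  result.value = std::atan(self.value);  // the actual computation
  if (grad_fn) {  // cf. set_history(flatten_tensor_args(result), grad_fn)
    result.requires_grad = true;
    result.grad_fn = grad_fn;
    result.output_nr = 0;  // our only output
  }
  return result;
}
```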

Phew. So this is what needs to happen in our forward, but how does AtanBackward work?

The backward

So the AtanBackward is defined in torch/csrc/autograd/generated/Functions.h, with details in the Functions.cpp.

Let us first look at the declaration in the Functions.h

struct TORCH_API AtanBackward : public TraceableFunction {
  using TraceableFunction::TraceableFunction;
  variable_list apply(variable_list&& grads) override;
  std::string name() const override { return "AtanBackward"; }
  void release_variables() override {
    self_.reset_data();
    self_.reset_grad_function();
  }

  SavedVariable self_;
};

We will conveniently ignore that it is a TraceableFunction subclass and pretend that it is only a Function subclass. It defines overrides for three methods: apply, which is the actual backward computation; name, which returns the name; and release_variables, which cleans up after the SavedVariable members by calling their reset_data and reset_grad_function methods. The latter two are straightforward to adapt when we define our own: return our own name from name and do the right thing for all our SavedVariable members (here we only have self_).

This leaves the apply method defined in Functions.cpp, which needs a closer look:

variable_list AtanBackward::apply(variable_list&& grads) {
  IndexRangeGenerator gen;
  auto self_ix = gen.range(1);
  variable_list grad_inputs(gen.size());
  auto& grad = grads[0];
  auto self = self_.unpack();
  if (should_compute_output({ self_ix })) {
    auto grad_result = grad / (self * self + 1);
    copy_range(grad_inputs, self_ix, grad_result);
  }
  return grad_inputs;
}

Looking at the signature

variable_list AtanBackward::apply(variable_list&& grads) {

we get a variable list of gradients (grad_out if you wish) and need to produce a variable list of input gradients (grad_ins). So far so good. A variable_list is just a std::vector<Variable> and we can ignore the Variable vs. Tensor bit for a moment.

The IndexRangeGenerator business looks daunting. Remember that we called collect_next_edges with our inputs? The index ranges now help us map the list of inbound edges (and the grad_inputs in the backward results) back to the various arguments. For Tensor arguments this is fairly boring; it would get more interesting for lists of tensors and such. After we have all those index ranges, we know how many inputs we had in total and can allocate grad_inputs to the right size (note that by default, the gradients are set to undefined Tensors here; these map to None in Python). Similar to the inputs, we used flatten_tensor_args on the outputs, and grads holds the gradients for those outputs in the same order. In

  auto& grad = grads[0];

grad is the gradient of our only output. We only need a reference, not a copy (for efficiency reasons). Next we unpack our SavedVariables:

  auto self = self_.unpack();

Note that this SavedVariable::unpack is strictly unrelated to the unpack function we used in the forward. It does the aforementioned sanity checks.

Next we go through our inputs (but we only have one), and if should_compute_output says we should, we compute the gradient of the input and copy it into grad_inputs. If you look at the definitions of should_compute_output, you see that there is another variant that foregoes the IndexRange business and takes just a simple integer index.
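The index range bookkeeping is easier to see in isolation. The toy sketch below assumes an op whose arguments are one Tensor followed by a list of three tensors. ToyIndexRangeGenerator and toy_copy_range are hypothetical, modeled on the description above rather than copied from PyTorch, and std::nullopt stands in for an undefined gradient.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <utility>
#include <vector>

// Half-open range [first, second) of positions in the flat gradient list.
using ToyIndexRange = std::pair<std::size_t, std::size_t>;

// Hand out consecutive ranges, one per forward argument, in argument order.
class ToyIndexRangeGenerator {
 public:
  ToyIndexRange range(std::size_t range_size) {
    i_ += range_size;
    return {i_ - range_size, i_};
  }
  std::size_t size() const { return i_; }  // total number of input slots

 private:
  std::size_t i_ = 0;
};

// Put one gradient per slot of the range into the flat list.
void toy_copy_range(std::vector<std::optional<double>>& out,
                    ToyIndexRange range, const std::vector<double>& grads) {
  assert(grads.size() == range.second - range.first);
  for (std::size_t i = range.first; i < range.second; ++i)
    out[i] = grads[i - range.first];
}
```

Slots that are never written stay std::nullopt, which corresponds to returning None for that input in Python.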

The actual gradient computation is

    auto grad_result = grad / (self * self + 1);

Now, this is generated from a template during compilation by clever tools in the PyTorch tools/autograd/ directory. If we did it manually, we could rip out the IndexRangeGenerator and write our AtanBackward::apply as

variable_list AtanBackward::apply(variable_list&& grads) {
  variable_list grad_inputs(1);
  auto& grad = grads[0];
  auto self = self_.unpack();
  if (should_compute_output(0)) {
    grad_inputs[0] = grad / (self * self + 1);
  }
  return grad_inputs;
}

which has a bit less baggage.
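Stripped of all PyTorch types, the simplified apply boils down to the following toy. It uses scalar inputs, with std::optional standing in for possibly undefined gradients, and all names are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>

// Toy stand-in for variable_list: std::nullopt plays the role of an
// undefined Tensor (None in Python).
using toy_variable_list = std::vector<std::optional<double>>;

struct ToyAtanBackward {
  double self_ = 0.0;                        // saved forward input
  std::vector<bool> needs_input_grad{true};  // one flag per next edge

  bool should_compute_output(std::size_t i) const {
    return needs_input_grad.at(i);
  }

  toy_variable_list apply(toy_variable_list&& grads) const {
    toy_variable_list grad_inputs(1);  // every slot starts out undefined
    const auto& grad = grads[0];       // gradient w.r.t. our only output
    // Toy choice: an undefined incoming gradient yields no input gradient.
    if (grad && should_compute_output(0)) {
      // d/dx atan(x) = 1 / (1 + x^2), chained with the incoming gradient.
      grad_inputs[0] = *grad / (self_ * self_ + 1);
    }
    return grad_inputs;
  }
};
```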

Rolling our own

So is our new knowledge sufficient to conquer the quest of differentiable, traceable C++ operators? Let us try our hand at a practical example. My fellow PyTorch developer and online friend Francisco Massa introduced a C++ extension for MaskRCNN functions in TorchVision 0.3 and my extensive discussions around the topic with him have been part of the inspiration for this blog post. We pick one, say, roi_align and try.

As recommended by the official extension tutorial, TorchVision wraps the forward and backward C++ functions into a torch.autograd.Function in torchvision/ops/roi_align.py. I will not copy the full thing here, but it takes two Tensor inputs, input and rois, and the parameters output_size (two ints, pool_h and pool_w), a float spatial_scale and another int sampling_ratio. It computes a gradient only for input (rois might be integral, but I have not checked). For this, the backward needs the parameters and the rois input Tensor, but only the shape of the input tensor, not the entire tensor. It saves those in the context. The backward does not need the output. Other than these administrative things, the function merely calls roi_align_forward and roi_align_backward from the C++ extension. So let us create a C++ custom operator equivalent of the function. (Recall that we prefer the custom op over an extension because it is traceable. We could also just wrap the function for the op in the extension interface and get a differentiable but not traceable extension function.)

For simplicity, we will put our changes directly into torchvision, in the main module file vision.cpp. (I would not expect dumping this in there to be up to TorchVision standards, but we want to get a prototype fast.)

So first we need all the utility functions mentioned above, so we include the world. (Maybe it would be beneficial to streamline this and bless a more official way of defining these. I hope this blog post can facilitate a discussion here.)

#include <torch/csrc/autograd/VariableTypeUtils.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/functions/utils.h>
#include <torch/csrc/autograd/saved_variable.h>
#include <torch/csrc/autograd/variable.h>
#include "torch/script.h"

We also want the fancy shortcuts Tensor and variable_list, so

using torch::Tensor;
using torch::autograd::variable_list;

So we declare our Backward object:

struct ROIAlignBackward : public torch::autograd::TraceableFunction {
  variable_list apply(variable_list&& grads) override;
  std::string name() const override {
    return "ROIAlignBackward";
  }
  void release_variables() override {
    rois_.reset_data();
    rois_.reset_grad_function();
  }
  torch::autograd::SavedVariable rois_;
  double spatial_scale;
  int64_t pooled_height;
  int64_t pooled_width;
  int64_t sampling_ratio;
  int64_t batch_size, channels, height, width;
};

Nothing surprising here: what the Python function saved into the ctx context is now declared as fields. The rois_ are a SavedVariable. Note that Python floats and ints are mapped to double and int64_t. While the extension mechanism (via PyBind11) is lenient here (maybe it should not be), this mapping is a necessity for successfully using custom ops and dealing with TorchScript in general.

Our ROIAlignBackward::apply method also mimics the simplified version of AtanBackward:

variable_list ROIAlignBackward::apply(variable_list&& grads) {
  variable_list grad_inputs(1);
  auto& grad = grads[0];
  auto rois = rois_.unpack();
  if (should_compute_output(0)) {
    grad_inputs[0] = ROIAlign_backward(
        grad,
        rois,
        spatial_scale,
        pooled_height,
        pooled_width,
        batch_size,
        channels,
        height,
        width,
        sampling_ratio);
  }
  return grad_inputs;
}

Even though the forward has two tensor inputs (and a bunch of other arguments), we only return a one-element grad_inputs (for input). This is different from torch.autograd.Function, whose backward returns one gradient (possibly None) per forward argument. Here it means that we only pass input to collect_next_edges in the forward. Just like the Python function, we simply hand off the calculation to the ROIAlign_backward function.

That was not too bad! Let us do the forward. Here, we hit a small stumbling block: VariableType::unpack is a private static method! Looking at it, it only calls VariableType::checked_cast_variable, so we might just substitute that. But alas, that, too, is a private static method. So what does it do? It checks t.defined() and t.is_variable(), raising an exception if either is false, and then returns as_variable_ref(t). Happily, that is a function and available to us. As we are not in the torch::autograd namespace, we need to prefix the functions with the namespace.

With this in mind, we write our function roi_align. It turns out a bit long because of the many values we need to store in grad_fn, but it is very straightforward otherwise.

Tensor roi_align(
    const Tensor& input,
    const Tensor& rois,
    const double spatial_scale,
    const int64_t pooled_height,
    const int64_t pooled_width,
    const int64_t sampling_ratio) {
  // checks from VariableType::unpack
  TORCH_CHECK(input.defined() && input.is_variable(), "invalid argument input");
  TORCH_CHECK(rois.defined() && rois.is_variable(), "invalid argument rois");
  // we might error if rois requires grad...
  auto& input_ = torch::autograd::as_variable_ref(input);
  auto& rois_ = torch::autograd::as_variable_ref(rois);
  std::shared_ptr<ROIAlignBackward> grad_fn;
  if (torch::autograd::compute_requires_grad(input, rois)) {
    grad_fn = std::shared_ptr<ROIAlignBackward>(
        new ROIAlignBackward(), torch::autograd::deleteFunction);
    grad_fn->set_next_edges(torch::autograd::collect_next_edges(input)); // note, only input!
    grad_fn->rois_ = torch::autograd::SavedVariable(rois, false);
    // extra bookkeeping
    grad_fn->spatial_scale = spatial_scale;
    grad_fn->pooled_height = pooled_height;
    grad_fn->pooled_width = pooled_width;
    grad_fn->sampling_ratio = sampling_ratio;
    grad_fn->batch_size = input.size(0);
    grad_fn->channels = input.size(1);
    grad_fn->height = input.size(2);
    grad_fn->width = input.size(3);
  }
  auto tmp = ([&]() {
    at::AutoNonVariableTypeMode non_var_type_mode(true);
    return ROIAlign_forward(
        input_,
        rois_,
        spatial_scale,
        pooled_height,
        pooled_width,
        sampling_ratio);
  })();
  auto result = torch::autograd::as_variable(tmp);
  if (grad_fn) {
    torch::autograd::set_history(torch::autograd::flatten_tensor_args(result), grad_fn);
  }
  return result;
}

Note how we pass only input to collect_next_edges, and how the inputs we calculate gradients for are separate from the tensors we store for the backward.
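The separation between edges and saved data can be shown with a hypothetical two-input toy op y = input * weight. Here weight plays the role of rois: it is saved for the backward but gets no edge, so the backward returns exactly one gradient. ToyScaleBackward, ToyScaleResult and toy_scale are made-up names for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

using Vec = std::vector<double>;

struct ToyScaleBackward {
  Vec weight_;                // saved data (cf. rois_)
  std::size_t num_edges = 0;  // inputs hooked up (cf. collect_next_edges)

  std::vector<Vec> apply(const std::vector<Vec>& grads) const {
    std::vector<Vec> grad_inputs(num_edges);  // one slot per edge, not per arg
    const Vec& grad = grads[0];
    Vec g(grad.size());
    // d(input * weight)/d input = weight, chained with the incoming gradient.
    for (std::size_t i = 0; i < grad.size(); ++i) g[i] = grad[i] * weight_[i];
    grad_inputs[0] = g;  // gradient for `input`, the only differentiable arg
    return grad_inputs;
  }
};

struct ToyScaleResult {
  Vec out;
  std::shared_ptr<ToyScaleBackward> grad_fn;
};

ToyScaleResult toy_scale(const Vec& input, const Vec& weight) {
  auto grad_fn = std::make_shared<ToyScaleBackward>();
  grad_fn->num_edges = 1;     // only `input` gets an edge
  grad_fn->weight_ = weight;  // but `weight` is what the backward needs
  Vec out(input.size());
  for (std::size_t i = 0; i < input.size(); ++i) out[i] = input[i] * weight[i];
  return {out, grad_fn};
}
```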

With that we are all set. All that is left to do is export our operator. (We could just do static auto registry = torch::RegisterOperators().op("torchvision::roi_align", &roi_align); similar to the tutorial examples, but for a widely used library like torchvision, it might be good to have argument names show up in error messages etc.)

static auto registry = torch::RegisterOperators()
  .op("torchvision::roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio) -> Tensor", &roi_align);

Done, here we have our own differentiable, traceable/scriptable custom operator torch.ops.torchvision.roi_align!

Even for this little exercise, we include one test case; proper testing would cover much more. In test/test_ops.py we find a test test_roi_align_gradient_cpu testing the gradient. We add a test for our shiny new op:

        x2 = x.detach().requires_grad_()
        y2 = torch.ops.torchvision.roi_align(x2, rois,
                                             roi_align.spatial_scale,
                                             roi_align.output_size[0],
                                             roi_align.output_size[1],
                                             roi_align.sampling_ratio)
        y2.sum().backward()

        assert torch.allclose(x2.grad, gt_grad), 'gradient incorrect for RoIAlign CPU in custom op'

This can be run using python3 test/test_ops.py RoIAlignTester.test_roi_align_gradient_cpu after installing torchvision. It works!
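Besides comparing against a known-good gradient as in the test above, a common way to check a hand-written backward is to compare it with finite differences, which is what torch.autograd.gradcheck does. Here is a scalar sketch for the atan formula from earlier. numerical_grad, atan_backward and check_atan_grad are hypothetical helpers, and the step size and tolerance are arbitrary choices.

```cpp
#include <cassert>
#include <cmath>
#include <functional>
#include <initializer_list>

// Central finite difference: (f(x + h) - f(x - h)) / (2h).
double numerical_grad(const std::function<double(double)>& f, double x,
                      double h = 1e-6) {
  return (f(x + h) - f(x - h)) / (2 * h);
}

// The analytic backward formula from AtanBackward.
double atan_backward(double grad, double self) {
  return grad / (self * self + 1);
}

// Compare analytic (with incoming grad 1) and numerical gradients.
bool check_atan_grad(double tol = 1e-6) {
  for (double x : {-2.0, -0.5, 0.0, 0.5, 3.0}) {
    double num = numerical_grad([](double v) { return std::atan(v); }, x);
    if (std::abs(num - atan_backward(1.0, x)) > tol) return false;
  }
  return true;
}
```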

Conclusion and Outlook

We inspected PyTorch's autograd mechanism in great detail to uncover how it works in C++. We then used this pattern to add ready differentiability into a C++ custom op. (Note that the backward will not be differentiable unless you use similar methods as we did here.)

As we saw, getting there had some rough edges, and the Tensor/Variable merge might bring changes that break this (but probably also make it simpler, which is good). In order to make this ready for mainstream consumption, we should bless some variant as an official way to implement Functions in C++, so that this blog post could become a tutorial on the PyTorch web site.

My code is here, but it is more an example than PR material.

Is this the best way to do differentiability in a JIT-compatible way? It is the one that works. But would it not be neat to define forward and backward custom ops and just @torch.jit.script our autograd.Function? That would let us tie into the JIT's source-to-source differentiation capabilities and be much easier for implementors. But it needs quite a bit of hacking in PyTorch. I am very proud that I have made a prototype of that, and we will discuss it next time.

PyTorch Training

As you can tell, I like PyTorch internals. I also like to talk and write about them and about how to use PyTorch efficiently. I offer in-house and public workshops for beginner, intermediate and PyTorch expert levels. If you are near Munich (say, in Europe) and need PyTorch training, I would love to hear from you! I also do bespoke development.

I hope this blog post is useful to you. I appreciate and read every mail you send to tv@lernapparat.de.