Lernapparat

A selective excursion into the internals of PyTorch

July 28, 2018

The beauty of PyTorch is that it makes its magic so conveniently accessible from Python. But how does it do so? We take a peek inside the gears that make PyTorch tick. (Note that this is a work in progress. I'd be happy to hear your suggestions for additions or corrections.)

Planning and Packing

We wish to take a selective tour of PyTorch. I will ignore some things that are not as useful for knowing how to, say, implement a new function - unless I find them interesting in themselves. For the Python bits, using IPython and ? and ?? is a great way to see what's going on.

We assume that you have a git tree checked out and built. I can generally recommend having one of those in order to search for things, as sometimes what you are looking for sits in automatically generated files. If you want GDB debugging, build with the environment variable DEBUG=1.

Parts of PyTorch

As you may know, the PyTorch repository also hosts Caffe2. We will largely ignore that. To us, the most important directories are:

  • torch, the Python modules,
  • torch/csrc, the C++ bits of PyTorch,
  • aten, the Tensor library,
  • tools, the magic for autogenerating important bits of PyTorch's functions without programmers having to spell out all the boilerplate. In addition to the scripts, it also hosts templates and definition lists.
  • test, the unit tests,
  • docs, the documentation.

Tensors

In PyTorch (0.4+), we have Tensor as the central datatype in Python. (It appears to differentiate into FloatTensor (=Tensor), DoubleTensor, cuda.FloatTensor etc., but that's a trick: while Tensor is a type just like any class in Python, the others are objects of type tensortype.) Simple enough, they are defined in torch/tensor.py. But from there we quickly get into C land: they are derived from torch._C._TensorBase. By the way, nn.Parameter is itself a subclass of Tensor, mainly designed to signal to nn.Modules that it is learnable.
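A quick way to check these relationships from the Python side (a small sketch of mine; the reprs in the comments are from a 0.4-era build and may differ in other versions):

import torch
import torch.nn as nn

# torch.Tensor is an ordinary Python class deriving from the C-level base
print(torch.Tensor.__bases__)                  # (<class 'torch._C._TensorBase'>,)

# nn.Parameter is just a Tensor subclass that nn.Module treats as learnable
print(issubclass(nn.Parameter, torch.Tensor))  # True

# the dtype/device-specific names are not subclasses of Tensor but objects
# of the special metatype mentioned above
print(type(torch.FloatTensor))                 # <class 'torch.tensortype'>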

Now _TensorBase is defined in torch/csrc/autograd/python_variable.cpp - it is the Python side of a data structure THPVariable (see python_variable.h) that wraps a torch::autograd::Variable and keeps track of the Python _backward_hooks. Variable here indicates that it provides autograd tracing (if you remember the 2017 PyTorch, the distinction between Variable and Tensor in C++ will be familiar to you).

torch::autograd::Variable is in turn defined in torch/csrc/autograd/variable.cpp and provides much of what we are used to from the Python Tensor. It is a subclass of at::Tensor from ATen.

ATen mainly lives in aten/src/ATen, with some (many) legacy functions from Torch in the aten/src/TH* directories. It knows nothing about derivatives and autograd, but defines the basic at::Tensor type (an "array" on the CPU or GPU) and many of the functions that operate on them and are eventually exported to work on Python Tensors for us.

Functions

Arguably the central interfaces of PyTorch are the various Modules in torch.nn and their functional counterparts in torch.nn.functional (often imported as F), as well as the functions in torch that operate on Tensors. Tensor member functions are also important, but usually they are glue for torch functions and it is OK to think of them as such for the time being. (There are many more important corners, but we'll leave it at that.)

Now, torch.nn and torch.nn.functional are pure Python - torch/nn/functional.py has the functional bits and the files in torch/nn/modules define the main module class and the various network modules and losses that we know.

Now, functions showing up in torch are more interesting - let's take torch.bilinear - the function behind torch.nn.functional.bilinear - as an example. If you evaluate it (rather than call it) at the IPython prompt, you'll see <function _VariableFunctions.bilinear>, and print (or ?) says it is built-in, i.e. defined in C. Sure enough, there is a torch._C._VariableFunctions.bilinear. (Note: that bilinear is exported as torch.bilinear is somewhat accidental. Do use the documented interface, here torch.nn.functional.bilinear, whenever you can!)
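A short IPython session tracing this (output from my build, so take the exact reprs with a grain of salt):

In [1]: import torch

In [2]: torch.bilinear
Out[2]: <function _VariableFunctions.bilinear>

In [3]: torch._C._VariableFunctions.bilinear
Out[3]: <function _VariableFunctions.bilinear>

In [4]: torch.bilinear?   # reports a built-in function, i.e. defined in C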

Functions in ATen

This time, we will go from the bottom up and start at the ATen level. In aten/src/ATen/native - the directory that hosts the source code of most "new" functions not inherited from Torch - there is a file native_functions.yaml. For bilinear it contains the declaration

- func: bilinear(Tensor input1, Tensor input2, Tensor weight, Tensor? bias) -> Tensor
  variants: function

The README.md in the same directory has the details, but this tells you the parameter and output types and that bilinear is available as a function only, rather than also being a method of the Tensor type.

(Note: in aten/src/ATen there are Declarations.cwrap and nn.yaml that have similar information and name translations for the functions from aten/src/TH(C) and aten/src/TH(CU)NN.)

Note that it doesn't tell us anything about the backward yet.

If we grep for bilinear in aten/src/ATen/native (I use rgrep to search in directories), we find its definition in the file Linear.cpp. We see that it is mainly an interface to _trilinear in the same file. Checking native_functions.yaml again, we see that _trilinear is there, too. (Note that there isn't much documentation for ATen's API at the moment - some is in aten/README.md and there are header files in aten/doc, but mostly it is "assume it is similar to PyTorch and figure it out".) Sure enough, _trilinear also shows up as torch._trilinear, but no, don't use it.
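As a reminder of what bilinear computes - a contraction of the two inputs with a third (weight) tensor - here is a quick numerical check in Python (my own sketch, not PyTorch's code):

import torch
import torch.nn.functional as F

x1 = torch.randn(4, 3)        # (batch, in1_features)
x2 = torch.randn(4, 5)        # (batch, in2_features)
w = torch.randn(2, 3, 5)      # (out_features, in1_features, in2_features)
b = torch.randn(2)

out = F.bilinear(x1, x2, w, b)
# spelled out: out[n, o] = sum_{i, j} x1[n, i] * w[o, i, j] * x2[n, j] + b[o]
ref = torch.einsum('ni,oij,nj->no', x1, w, x2) + b
print(torch.allclose(out, ref, atol=1e-6))    # True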

But I said before that ATen is unaware of gradients and backwards. Where do those come from?

Derivatives definitions

So we are looking for the backward of bilinear.

Autograd involves a lot of magic code generation. Our first stop in the search is the derivatives definition file, tools/autograd/derivatives.yaml. This one has explanations at the top. The gist is that it contains the mapping from forward functions to their backwards.

Searching for bilinear, however, turns up upsample_bilinear2d_forward and friends, but not bilinear itself. A dead end? No - we recall that bilinear was mostly _trilinear. And sure enough, that has an entry:

- name: _trilinear(Tensor i1, Tensor i2, Tensor i3, IntList expand1, IntList expand2, IntList expand3, IntList sumdim, int64_t unroll_dim)
  i1, i2, i3: _trilinear_backward(grad, i1, i2, i3, expand1, expand2, expand3, sumdim, unroll_dim, grad_input_mask)

So simple enough, there is a function _trilinear_backward providing the derivative. Let's search for it.

Like many backwards that are not defined elsewhere, it is implemented in tools/autograd/templates/Functions.cpp (possibly the oddest location of all the PyTorch functionality we look at today). And there we see that _trilinear_backward mostly consists of three calls to _trilinear (as you would expect for a product of three matrices), one for the gradient of each input tensor.
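The underlying pattern - the gradient of a product of three factors with respect to one of them is the product of the other two with the incoming gradient - is easy to check from Python with a toy contraction (my own example, not the actual _trilinear code):

import torch

i1 = torch.randn(4, 3, requires_grad=True)
i2 = torch.randn(4, 3)
i3 = torch.randn(4, 3)

out = (i1 * i2 * i3).sum(dim=1)          # a toy "trilinear" contraction
grad_out = torch.randn(4)
g1, = torch.autograd.grad(out, i1, grad_out)

# d out[n] / d i1[n, k] = i2[n, k] * i3[n, k], so the gradient is:
print(torch.allclose(g1, grad_out[:, None] * i2 * i3))   # True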

But how does autograd know the backward of bilinear? How can it do higher order derivatives?

Automatic differentiability of ATen functions

Here is the interesting bit: when ATen functions are used through torch::autograd::Variables (or Python Tensors, see above), autograd traces the function and constructs the computation graph it needs for backward. Sometimes it is advisable to write an explicit backward for efficiency. In other cases (as with _trilinear), the forward involves in-place operations, custom kernels or other things that autograd cannot follow; then you need an explicit backward. But often, derivatives work magically, and a common pattern is to have some technical bits (such as _trilinear) with explicit derivatives and a user-facing function (such as bilinear) that adds a few operations on top but is differentiable through tracing.
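For instance, any function composed of differentiable ATen operations gets its backward for free - autograd records the graph while the forward runs (a toy example of mine):

import torch

def fancy_activation(x):
    # built entirely from differentiable ops - no hand-written backward anywhere
    return torch.tanh(x) * torch.sigmoid(2 * x)

x = torch.randn(5, requires_grad=True)
fancy_activation(x).sum().backward()
print(x.grad)    # populated, even though we never specified a derivative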

(I should add something about how this works.)

Custom kernels and other bits

If you want custom kernels - i.e. computing directly on the data instead of using ATen functions - there is a bit more to consider. (But you can skip this if you are happy with just the native functions.)

The first thing is that you need to define a backward (see above) to make your function differentiable. The backward typically also uses a custom kernel, which you would then define just like the forward, i.e. in aten/src/ATen/native.

Let's pick an example again: we randomly choose hardshrink, which we can follow to an ATen C++ function of the same name. In native_functions.yaml we see

- func: hardshrink(Tensor self, Scalar lambd=0.5) -> Tensor
  dispatch:
    CPU: hardshrink_cpu
    CUDA: hardshrink_cuda

- func: hardshrink_backward(Tensor grad_out, Tensor self, Scalar lambd) -> Tensor
  dispatch:
    CPU: hardshrink_backward_cpu
    CUDA: hardshrink_backward_cuda

So we have forward and backward (tied together in tools/autograd/derivatives.yaml as described above), and each dispatches to a _cpu and a _cuda variant, respectively. (Sometimes - for example if you also want to consider a CuDNN-based implementation - you would write a manual dispatch, usually in a C++ function in the aten/src/ATen/native directory, typically below the CPU functions.)

This dispatching is one of the convenient bits of magic in ATen. The dispatch is based on the type of the argument named self if it is a Tensor or, if there is no such argument, on the first Tensor or TensorList argument. (There are a bunch of *.py scripts in aten/src/ATen doing the generation, with templates in the template subdirectory.)
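From the Python side the dispatch is invisible: the same call ends up in the CPU or CUDA kernel depending on where the self tensor lives (a small illustration, assuming a CUDA-enabled build):

import torch
import torch.nn.functional as F

x = torch.randn(5)
y = F.hardshrink(x, lambd=0.5)                   # dispatches to hardshrink_cpu

if torch.cuda.is_available():
    y_gpu = F.hardshrink(x.cuda(), lambd=0.5)    # same call, dispatches to hardshrink_cuda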

But back to our main objective, the kernel functions; we focus on the forward. We find hardshrink_cpu and hardshrink_cuda under aten/src/ATen/native/ in Activation.cpp and cuda/Activation.cu, respectively. (As usual, .cu means that the file is compiled by nvcc.)

We look at the CUDA version first, because it looks a bit more prototypical. The function hardshrink_cuda is very short. It reserves the output tensor out_tensor. (Alternatively, you could just return a Tensor from the templated function.) Then it uses the AT_DISPATCH_FLOATING_TYPES_AND_HALF macro from ATen/Dispatch.h with a short C++ lambda that calls the templated hardshrink_cuda_kernel function. The template argument is scalar_t (that name being hardcoded in the macro), which is the basic data type (e.g. double or float) given by the argument self.type() to the macro. There are various other dispatching macros (like AT_DISPATCH_ALL_TYPES and AT_DISPATCH_INTEGRAL_TYPES), too. A little further up is the templated function, hardshrink_cuda_kernel (still a host function in CUDA terminology - I sometimes prefer to name these _template instead of _kernel and keep _kernel for device functions).

It generally is a good idea to apply the various input checks from ATen/TensorUtils.h (do grep for examples - or I might add a section below).

As hardshrink is a pointwise operation on two tensors (self and out_tensor), it uses at::cuda::CUDA_tensor_apply2 from ATen/cuda/CUDAApplyUtils.cuh. The function applied could be any __device__ function executed on the GPU; here it is a C++ lambda.
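To see what that pointwise computation amounts to from the Python side - entries with absolute value at most lambd are zeroed, everything else passes through (a quick check of mine):

import torch
import torch.nn.functional as F

x = torch.tensor([-1.0, -0.3, 0.0, 0.2, 2.0])
print(F.hardshrink(x, lambd=0.5))
# roughly tensor([-1., 0., 0., 0., 2.]) - one output value per input element,
# which is why a single pointwise pass over self and out_tensor suffices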

For more complex, non-pointwise functions, you could launch CUDA kernels on the various tensor arguments' .data<scalar_t>() pointers, keeping track of .size() and .stride(), using the usual CUDA <<< >>> launch syntax.

The CPU side skips the short templated function and does the calculation directly inside AT_DISPATCH_FLOATING_TYPES. Here, at::CPU_tensor_apply2 from ATen/CPUApplyUtils.h is used.

Testing and documenting your function

Note: both tests and documentation work with the torch module they find with the usual Python mechanisms, i.e. what is installed rather than what is in your checkout.

Now you need to add tests to test/*.py. Don't forget to call flake8 on them before you submit your PR. You only need the errors; I call python3 -m flake8 test/ | grep ': E' or so.

The documentation can be built by changing into the docs directory and calling make html. It will build the documentation and put it into docs/build/html.

Some odds and ends

Some functions don't fit the scheme above and are defined more "manually". For example, the torch.tensor function takes more or less arbitrary data to make a tensor of, so it needs to be much more flexible than the "usual" functions. We find a function THPVariable_tensor in torch/csrc/autograd/generated/python_torch_functions.cpp, which calls tensor_ctor in torch/csrc/utils/tensor_new.cpp. Finally, in internal_new_from_data in the same file, we can see how torch goes through the various types you can pass to tensor. For example, you'll be able to spot the "To copy construct from a tensor" warning that you get when you pass a Tensor to torch.tensor.
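For instance (the exact warning text varies between versions):

import torch

t = torch.ones(3)
t2 = torch.tensor(t)   # copies the data and emits the "To copy construct
                       # from a tensor ..." UserWarning, suggesting something
                       # like t.clone().detach() instead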

Conclusion

We took a very brief tour through some bits of PyTorch that are good to know when implementing new functions or just wanting to know how things work.

I hope you enjoyed the read, please do send me your feedback. Have fun hacking PyTorch!