PyTorch best practices

Sept. 11, 2020

We look at some best practices and also try to shed some light on the rationale behind them. Whether this becomes a series or an updated blog post, we will see.

Are you an intermediate PyTorch programmer? Are you following documented best practices? Do you form opinions on which ones to adhere to and which ones can be forgone without trouble?

I will freely admit that I sometimes have a hard time following best practices when the thing they advise against seems to work and I don't fully understand their rationale. So here is a little thing that happened to me.

A story of me doing stuff (Quantization)

After building PyTorch for the Raspberry Pi, I've been looking to do some fun projects with it. Sure enough, I found a model that I wanted to adapt to running on the Pi. I had the thing running soon enough, but it was not as fast as I would like it to be. So I started looking at quantizing it.

Quantization makes any operation stateful (temporarily)

Now if you think of a PyTorch computation as a set of (Tensor) values linked by operations, quantization consists of taking each operation and forming an opinion about what range of values its output Tensors will take, so that numbers in that range can be approximated by integers of the quantized element type via an affine transformation. Don't worry if that sounds complicated; the main thing is that each operation now needs to be associated "with an opinion", or more precisely with an observer that records the minimum and maximum values seen during some exemplary use of the model. But this means that during quantization, all operations become stateful (more precisely, they become stateful when preparing for quantization and stay so until the quantization is done).
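To make the affine mapping concrete, here is a toy sketch in plain Python. MinMaxRecorder, quantize and dequantize are hypothetical names, not PyTorch APIs; the formulas follow the usual affine scheme for an unsigned 8-bit range of 0 to 255 and mirror, in simplified form, what PyTorch's min/max observers compute.

```python
class MinMaxRecorder:
    """A toy observer that tracks the running min and max of what it sees."""

    def __init__(self):
        self.min_val = float("inf")
        self.max_val = float("-inf")

    def __call__(self, values):
        self.min_val = min(self.min_val, min(values))
        self.max_val = max(self.max_val, max(values))
        return values  # observers pass the values through unchanged

    def qparams(self, qmin=0, qmax=255):
        # Widen the range to include 0 so that 0.0 is exactly representable.
        lo, hi = min(self.min_val, 0.0), max(self.max_val, 0.0)
        scale = (hi - lo) / (qmax - qmin) or 1.0  # avoid a zero scale
        zero_point = round(qmin - lo / scale)
        return scale, max(qmin, min(qmax, zero_point))


def quantize(x, scale, zero_point, qmin=0, qmax=255):
    """Affine map from float to integer: q = round(x / scale) + zero_point."""
    return max(qmin, min(qmax, round(x / scale) + zero_point))


def dequantize(q, scale, zero_point):
    """Inverse affine map: x is approximately (q - zero_point) * scale."""
    return (q - zero_point) * scale


obs = MinMaxRecorder()
obs([-1.0, 0.5, 2.0])  # "exemplary use": a couple of batches of values
obs([1.5, 3.0])
scale, zero_point = obs.qparams()  # scale = 4/255, zero_point = 64
q = quantize(1.0, scale, zero_point)
print(q, dequantize(q, scale, zero_point))
```

The observer is the state: the quantization parameters depend on everything it has seen so far, which is exactly why each operation needs its own.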

I often mention this when I advocate not declaring an activation function once and reusing it several times. At the various points in the computation where the function is used, the observer would, in general, see different values, so each use needs its own observer; a single shared one would lump all those ranges together.
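A small sketch of why sharing matters, using PyTorch's MinMaxObserver directly (torch.quantization.MinMaxObserver; in recent releases it also lives under torch.ao.quantization). The two input tensors are made-up stand-ins for what two call sites of a shared activation might see.

```python
import torch
from torch.quantization import MinMaxObserver

a = torch.tensor([0.0, 1.0])    # values seen at the first call site
b = torch.tensor([0.0, 100.0])  # values seen at the second call site

# One observer reused at both call sites (what a shared activation gets).
shared = MinMaxObserver()
shared(a)
shared(b)

# One observer per call site (what separately declared activations get).
first, second = MinMaxObserver(), MinMaxObserver()
first(a)
second(b)

print(shared.max_val.item(), first.max_val.item(), second.max_val.item())
# 100.0 1.0 100.0

# The resulting scales differ by a factor of about 100.
print(shared.calculate_qparams()[0].item(), first.calculate_qparams()[0].item())
```

With the shared observer, the values from the first site (between 0 and 1) would be quantized with a step of roughly 100/255, so they collapse onto just a handful of integer levels.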

This new stateful nature also applies to simple things like adding tensors, usually just expressed as a + b. For this, PyTorch provides the torch.nn.quantized.FloatFunctional module. It is a usual Module with the twist that instead of calling forward in the computation, you use one of several methods corresponding to basic operations, in our case .add.
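In float mode (before prepare and convert), FloatFunctional.add simply computes the sum; prepare later attaches an observer for the result, and convert swaps the module for a quantized counterpart. A minimal check of the float behavior, with random example tensors:

```python
import torch
from torch.nn.quantized import FloatFunctional

ff = FloatFunctional()
a, b = torch.randn(3), torch.randn(3)

# Instead of `a + b`, call the method on the (stateful) module.
out = ff.add(a, b)
print(torch.allclose(out, a + b))  # True: no quantization happens yet
```

Other basic operations are available as methods in the same way, for example ff.mul and ff.cat.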

So I took the residual module, which looked roughly like this (note how it declares the activations separately, which is a good thing!):

class ResBlock(torch.nn.Module):
    def __init__(self, ...):
        super().__init__()
        self.conv1 = ...
        self.act1 = ...
        self.conv2 = ...
        self.act2 = ...

    def forward(self, x):
        return self.act2(x + self.conv2(self.act1(self.conv1(x))))

And I added self.add = torch.nn.quantized.FloatFunctional() to __init__ and replaced the x + ... with self.add.add(x, ...). Done!
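For reference, here is a runnable sketch of the modified block. The concrete layers (3x3 convolutions with padding 1 and ReLU activations) and the channels argument are assumptions chosen for illustration, since the original's layer details are elided above.

```python
import torch
from torch.nn.quantized import FloatFunctional


class ResBlock(torch.nn.Module):
    def __init__(self, channels):  # hypothetical signature
        super().__init__()
        self.conv1 = torch.nn.Conv2d(channels, channels, 3, padding=1)
        self.act1 = torch.nn.ReLU()
        self.conv2 = torch.nn.Conv2d(channels, channels, 3, padding=1)
        self.act2 = torch.nn.ReLU()
        # Stateful replacement for `+`, so the sum gets its own observer.
        self.add = FloatFunctional()

    def forward(self, x):
        return self.act2(self.add.add(x, self.conv2(self.act1(self.conv1(x)))))


block = ResBlock(4)
x = torch.randn(1, 4, 8, 8)
y = block(x)
print(y.shape)  # torch.Size([1, 4, 8, 8])
```

Note that the new attribute only exists on objects whose __init__ actually ran with this code, which is what the rest of the story hinges on.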

With the model thus prepared, I could add the quantization itself, which is simple enough following the PyTorch tutorial. At the bottom of the evaluation script, with the model loaded and set to eval mode, I added the following, restarted the notebook kernel I was working with, and ran everything.

model.qconfig = torch.quantization.get_default_qconfig('qnnpack')
torch.backends.quantized.engine = 'qnnpack'
# wrap in quantization of inputs / de-quantization of outputs
model = torch.quantization.QuantWrapper(model)
# insert observers
torch.quantization.prepare(model, inplace=True)

and so later (after running the model a bit to get observations), I would call

torch.quantization.convert(model, inplace=True)

to get the quantized model. Easy!
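To see the whole flow end to end, here is a minimal sketch on a tiny stand-in model (a single convolution plus ReLU, not the model from the story). The random calibration inputs and the fallback to the fbgemm engine when qnnpack is not available are assumptions for running it away from the Pi.

```python
import torch

# Pick a quantized engine: qnnpack as in the text, fbgemm as a fallback.
engine = "qnnpack" if "qnnpack" in torch.backends.quantized.supported_engines else "fbgemm"
torch.backends.quantized.engine = engine

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
model.qconfig = torch.quantization.get_default_qconfig(engine)
# QuantStub before / DeQuantStub after the wrapped model.
model = torch.quantization.QuantWrapper(model)
# Insert observers.
torch.quantization.prepare(model, inplace=True)

# Calibration: run some (here random) inputs so the observers record ranges.
with torch.no_grad():
    for _ in range(4):
        model(torch.randn(2, 3, 16, 16))

# Replace observed float modules by quantized ones.
torch.quantization.convert(model, inplace=True)

out = model(torch.randn(1, 3, 16, 16))
print(out.shape, out.dtype)  # torch.Size([1, 8, 14, 14]) torch.float32
```

The output is a float tensor again because the wrapper's DeQuantStub turns the quantized result back into floats.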

An unexpected error

And now I just had to run through a few batches of input.

preds = model(inp)

But what happened was this:

ModuleAttributeError: 'ResBlock' object has no attribute 'add'


What went wrong? Maybe I had a typo in ResBlock?

In Jupyter, you can check this very easily using ?? model.resblock1. But the code was all right: no typos.

So this is where the PyTorch best practices come in.

Serialization best practices

The PyTorch documentation has a note on serialization that contains - or consists of - a best practices section. It starts with

There are two main approaches for serializing and restoring a model. The first (recommended) saves and loads only the model parameters:

and then shows how this works using the state_dict() and load_state_dict() methods. The second method is to save and load the entire model.

The note provides the following rationale for preferring serializing parameters only:

However in [the case of saving the model], the serialized data is bound to the specific classes and the exact directory structure used, so it can break in various ways when used in other projects, or after some serious refactors.

It turns out that this is quite an understatement: even our very modest modification, hardly a serious refactor, ran into the problem it alludes to.

What went wrong?

To get to the core of what went wrong, we have to think about what an object is in Python. In a gross oversimplification, it is completely defined by its __dict__ attribute, holding all the ("data") members, and its __class__ attribute, pointing to its type (so, e.g., for a ResBlock instance this is ResBlock, and for a class such as Module itself it is type). When we call a method, it typically is not in the __dict__ (it could be if we tried hard), but Python will automatically consult the __class__ to find the methods (or other things it could not find in the __dict__).
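A quick illustration of this split between instance data and class, using a hypothetical Greeter class:

```python
class Greeter:
    def __init__(self, name):
        self.name = name  # data attribute: stored in the instance __dict__

    def greet(self):  # method: stored on the class
        return f"Hello, {self.name}!"


g = Greeter("Pi")
print(g.__dict__)                   # {'name': 'Pi'}
print(g.__class__ is Greeter)       # True
print("greet" in g.__dict__)        # False: found via g.__class__
print("greet" in Greeter.__dict__)  # True
print(type(Greeter) is type)        # True: classes are instances of type
print(g.greet())                    # Hello, Pi!
```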

When deserializing the model (the author of the model I used had saved the whole Module rather than following the best-practice advice), Python constructs an object by looking up the type for the __class__ and combining it with the deserialized __dict__. But the thing it (rightfully) does not do is call __init__ to set up the object. It should not, not least because things might have been modified between __init__ and serialization, or __init__ might have side effects we do not want. This means that when we call the module, we are using the new forward but get the __dict__ as prepared by the original author's __init__ and subsequent training, without the new attribute add that our modified __init__ creates.
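The following pure-Python sketch mimics what unpickling does for a plain object: it creates the instance with __new__ (skipping __init__) and then restores the saved __dict__. Block, weight and offset are hypothetical stand-ins for ResBlock and its new add attribute; real unpickling additionally looks up the class by module and name, which is what makes it pick up the edited class.

```python
# Version 1 of the class, as it was when the object was saved.
class Block:
    def __init__(self):
        self.weight = 2.0

    def forward(self, x):
        return self.weight * x


# For a plain object, pickle essentially stores a reference to the class
# plus a copy of the instance __dict__.
saved_state = dict(Block().__dict__)  # {'weight': 2.0}


# Version 2 of the class, edited after saving: __init__ adds an attribute.
class Block:
    def __init__(self):
        self.weight = 2.0
        self.offset = 1.0  # plays the role of the new `add` module

    def forward(self, x):
        return self.weight * x + self.offset


# Mimic unpickling: make an instance WITHOUT calling __init__, restore __dict__.
restored = Block.__new__(Block)
restored.__dict__.update(saved_state)

message = None
try:
    restored.forward(3.0)  # new code, old data
except AttributeError as err:
    message = str(err)
print(message)  # 'Block' object has no attribute 'offset'
```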

So, in a nutshell, this is why serializing PyTorch modules, or objects in Python generally, is dangerous: you very easily end up with something where the data attributes and the code are out of sync.

Maintaining compatibility

An obvious consequence here, a drawback if you wish, is that we need to keep track of the configuration used to set up the model (e.g. its constructor arguments) in addition to the state dictionary. But you can easily serialize all the configuration parameters along with the state dict if you want: just stick them into a joint dictionary.
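A minimal sketch of that joint dictionary, assuming a hypothetical Net whose only configuration is a hidden width; an in-memory buffer stands in for a checkpoint file. Passing weights_only=True (available in recent PyTorch releases) restricts torch.load to plain data such as dicts and tensors, which fits this approach.

```python
import io

import torch


class Net(torch.nn.Module):
    """Hypothetical model whose only configuration is a hidden width."""

    def __init__(self, hidden):
        super().__init__()
        self.fc1 = torch.nn.Linear(4, hidden)
        self.fc2 = torch.nn.Linear(hidden, 2)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


config = {"hidden": 16}
model = Net(**config)

# Save configuration and parameters together in one plain dictionary.
buf = io.BytesIO()  # stands in for a checkpoint file
torch.save({"config": config, "state_dict": model.state_dict()}, buf)

# Restore: rebuild the module with the *current* code, then load the tensors.
buf.seek(0)
checkpoint = torch.load(buf, weights_only=True)
restored = Net(**checkpoint["config"])
restored.load_state_dict(checkpoint["state_dict"])
```

Because Net(**config) runs the current __init__, attributes added since the checkpoint was written (like the add module in the story) are created as usual; only the tensors come from the file.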

But there are other advantages to not serializing the modules themselves:

The obvious thing is that we can work with the state dictionary. We can load the state dictionary without having the Modules and we can inspect and modify the state dictionary if we changed something important.
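For instance, here is a sketch of adapting a checkpoint to a model whose output layer changed size (both Sequential models are hypothetical stand-ins): we inspect the stored shapes, keep only the entries that still fit, and load non-strictly.

```python
import torch

old = torch.nn.Sequential(torch.nn.Linear(8, 4), torch.nn.Linear(4, 10))
state = old.state_dict()  # a plain (ordered) dict mapping names to tensors

# Inspect the state dict without needing the original module code.
for name, tensor in state.items():
    print(name, tuple(tensor.shape))
# 0.weight (4, 8)
# 0.bias (4,)
# 1.weight (10, 4)
# 1.bias (10,)

# New model: same body, but an output layer with 3 instead of 10 outputs.
new = torch.nn.Sequential(torch.nn.Linear(8, 4), torch.nn.Linear(4, 3))
new_shapes = {k: v.shape for k, v in new.state_dict().items()}

# Modify: drop entries whose shapes no longer match, then load what fits.
filtered = {k: v for k, v in state.items() if new_shapes.get(k) == v.shape}
result = new.load_state_dict(filtered, strict=False)
print(result.missing_keys)  # ['1.weight', '1.bias']
```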

The not quite as obvious thing is that the implementor or user can customize how modules process the state dict. There are two ways to do this:

  • For users, there are hooks. Well, they are not very official, but there is _register_load_state_dict_pre_hook, which you can use to register hooks that process the state dict before it is used to update the model, and there is _register_state_dict_hook to register hooks that are called after the state dict has been collected and before it is returned from state_dict().

  • More importantly, though, implementors can override _load_from_state_dict. When the class has an attribute _version, this is saved as the version metadata in the state dict. With this, you can add conversions from older state dictionaries. BatchNorm provides an example of how to do this; it roughly looks like this:

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        version = local_metadata.get('version', None)

        # older state dicts lack the new entry: fill in a default before loading
        if (version is None or version < 2) and self.have_new_thing:
            new_key = prefix + 'new_thing_param'
            if new_key not in state_dict:
                state_dict[new_key] = ... # some default here

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

So here, if the version is old and we need a new key, we add it before handing the state dict to the usual processing of the superclass (typically torch.nn.Module).
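Here is a self-contained sketch putting both mechanisms together. Scaler, its shift parameter and the legacy key name scale are hypothetical; _load_from_state_dict, _register_load_state_dict_pre_hook and the _metadata attribute of state dicts are PyTorch internals as used above, so their details may change between versions. To fake an old checkpoint, we delete the new key and set the stored version to 1.

```python
import torch


class Scaler(torch.nn.Module):
    """Hypothetical module; version 2 added the `shift` parameter."""

    _version = 2

    def __init__(self, n):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(n))
        self.shift = torch.nn.Parameter(torch.zeros(n))  # new in version 2

    def forward(self, x):
        return self.weight * x + self.shift

    # Implementor side: upgrade old state dicts based on the stored version.
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        version = local_metadata.get("version", None)
        if version is None or version < 2:
            key = prefix + "shift"
            if key not in state_dict:
                state_dict[key] = torch.zeros_like(self.shift)  # neutral default
        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                      missing_keys, unexpected_keys, error_msgs)


# Fake a version-1 checkpoint: drop the new key and mark it as version 1.
old_sd = Scaler(3).state_dict()
del old_sd["shift"]
old_sd._metadata[""]["version"] = 1

model = Scaler(3)
with torch.no_grad():
    model.shift.fill_(5.0)
model.load_state_dict(old_sd)  # succeeds: shift is reset to the default zeros
shift_after_upgrade = model.shift.detach().clone()


# User side: a pre-hook renaming a key from a hypothetical older naming scheme.
def rename_legacy_keys(state_dict, prefix, local_metadata, strict,
                       missing_keys, unexpected_keys, error_msgs):
    legacy_key = prefix + "scale"
    if legacy_key in state_dict:
        state_dict[prefix + "weight"] = state_dict.pop(legacy_key)


model._register_load_state_dict_pre_hook(rename_legacy_keys)
model.load_state_dict({"scale": torch.full((3,), 2.0), "shift": torch.zeros(3)})
print(model.weight)  # now filled from the legacy "scale" entry
```

The version check lives with the module's code, so every old checkpoint is upgraded on load; the hook is a per-instance fix a user can add without touching the class.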


So here we saw in great detail what went wrong when we saved the model rather than following the best practice of saving just the parameters. My personal takeaway is that the pitfall of saving whole models is rather large and easy to fall into, so we should really take care to save models only as parameters (state dicts) and not as Module objects.

I hope you enjoyed this little deep dive into a PyTorch best practice. More of this can be found in Piotr's and my imaginary PyTorch book and, until it materializes, in my no-nonsense PyTorch workshops. A special shout-out and thank you to Piotr, I couldn't do half the PyTorch things I do without him! Do send me an e-mail at tv@lernapparat.de if you want to become a better PyTorch programmer or if I can help you with PyTorch and ML consulting.

I appreciate your comments and feedback at tv@lernapparat.de.