Making model initialization faster

March 4, 2023

Recently, largeish pretrained models are all the rage. But they take a long time to initialize. What can we do?

One great thing about hanging out online with people in the PyTorch community is that they keep asking the right questions. Recently, the inimitable Stas Bekman asked about skipping module initialization because for the models they (at Hugging Face) work with, just instantiating the module can take a long time. So let's look into this. (Stas mentioned this bug report for HuggingFace transformers on the issue. Likely, people will find a much nicer solution than we do here. But nothing like a quick hack!)

One of my favourite models to play with is Andrej Karpathy's great NanoGPT.

A typical first step is (of course, I'm taking the XL model, because the larger your model, the more important you are; and larger models get nicer numbers here, too):

import nanogpt.model as nanogpt
model = nanogpt.GPT.from_pretrained("gpt2-xl", {}).to("cuda")

but this is kind of slow. How slow?

%timeit model = nanogpt.GPT.from_pretrained("gpt2-xl", {}).to("cuda")


13.8 s ± 130 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

on my relatively nice computer (albeit loading the weights from a hard drive, not some fancy SSD).

More than ten seconds? That is too slow! Anyone reading my blog regularly (or just stopping by) knows I am a great fan of fast when it comes to PyTorch.

So what is the problem? When we create a model, PyTorch allocates the parameters on the CPU by default. Then it initializes the parameters on the CPU. And then we move them to CUDA. This is wired relatively deeply into PyTorch, so it is hard to change directly.
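In miniature, the default flow looks like this (a sketch; the final .to step needs a GPU, so it is only shown as a comment):

```python
import torch

# Parameters are allocated and initialized on the CPU first...
layer = torch.nn.Linear(1600, 1600)
print(layer.weight.device)  # cpu
# ...and only afterwards copied over, one more full pass over the memory:
# layer = layer.to("cuda")
```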

However, over time, PyTorch has created all sorts of ways to hook into how things get executed. In our case, we use the __torch_function__ mechanism, which allows us to see (and divert) all calls into PyTorch. Happily, PyTorch makes this easy by providing a TorchFunctionMode context manager class that we can subclass to get at __torch_function__ directly. In our __torch_function__ method, we get the function that is called along with its positional and keyword arguments.
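To see the mechanism in isolation, here is a minimal sketch (the class name RecordCalls is mine) that records every function PyTorch routes through __torch_function__:

```python
import torch
from torch.overrides import TorchFunctionMode

class RecordCalls(TorchFunctionMode):
    """Record the name of every function PyTorch routes through here."""
    def __init__(self):
        self.calls = []

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        self.calls.append(getattr(func, '__name__', str(func)))
        return func(*args, **kwargs)  # fall through to the real call

mode = RecordCalls()
with mode:
    t = torch.ones(2, 2)
print(mode.calls)  # e.g. ['ones']
```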

What do we do? PyTorch modules (from the standard torch.nn or other well-written modules) initialize their weights by calling the functions in torch.nn.init. So we check whether the function being called has a __module__ attribute of 'torch.nn.init' and if so, we skip calling the function and (because they are all in-place functions) return the tensor keyword argument or the first argument instead. For anything else, we check whether the function creates a tensor on a given device; happily, PyTorch has a context manager for that, which we can adapt to do the skipping, and thus it has a list of tensor-constructing operations for us. (This is an adaptation of PyTorch's torch.utils._device.DeviceContext; please see the PyTorch license for details of the BSD-style license.)
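You can convince yourself of the __module__ check quickly:

```python
import torch

# The in-place initializers in torch.nn.init all carry this module name,
# which is what we key the skipping on.
print(torch.nn.init.kaiming_uniform_.__module__)  # torch.nn.init
print(torch.nn.init.zeros_.__module__)            # torch.nn.init
```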

We wrap this into a nice context manager:

import torch
import torch.overrides
import torch.utils._device

class EmptyInitOnDevice(torch.overrides.TorchFunctionMode):
    def __init__(self, device=None):
        self.device = device

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if getattr(func, '__module__', None) == 'torch.nn.init':
            # skip the (in-place) initialization, just hand back the tensor
            if 'tensor' in kwargs:
                return kwargs['tensor']
            return args[0]
        if self.device is not None and func in torch.utils._device._device_constructors() and kwargs.get('device') is None:
            # construct tensors directly on the requested device
            kwargs['device'] = self.device
        return func(*args, **kwargs)

And then we can call

with EmptyInitOnDevice("cuda"):
    model = nanogpt.GPT.from_pretrained("gpt2-xl", {})

This feels much faster. But my motto is that it's not an optimization until we measure, so:

with EmptyInitOnDevice("cuda"):
    %timeit model = nanogpt.GPT.from_pretrained("gpt2-xl", {})


2.26 s ± 10.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

So we are more than 6 times faster. Yay!

Note that this only works when your models use torch.nn.init to initialize weights. If you want this to work with models using e.g. torch.randn directly, you need to divert those calls to torch.empty.
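A sketch of what that diversion could look like (the class and the mapping are my own invention, not part of PyTorch; it only makes sense when the values are overwritten anyway, e.g. by loading pretrained weights right afterwards):

```python
import torch
from torch.overrides import TorchFunctionMode

class SkipRandomInit(TorchFunctionMode):
    """Replace random-init constructors with torch.empty (uninitialized)."""
    DIVERT = {torch.randn: torch.empty, torch.rand: torch.empty}

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # swap e.g. torch.randn(3, 3) for torch.empty(3, 3)
        return self.DIVERT.get(func, func)(*args, **kwargs)

with SkipRandomInit():
    w = torch.randn(3, 3)  # allocated via torch.empty, values are garbage
```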

N.B. Maybe this will be obsolete in a short while if PyTorch incorporates something like this in the standard tooling. torch.utils._device._device_constructors() gives you the list of functions that construct tensors in PyTorch (including some obscure ones), so if you can handle the ones listed there, you are all set.
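For reference, that (private, so subject to change) list can be inspected directly:

```python
import torch
import torch.utils._device

# _device_constructors() returns the set of tensor-constructing functions
# that PyTorch's device context knows about (private API, may change).
ctors = torch.utils._device._device_constructors()
print(torch.empty in ctors)  # True
```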

Also note that much of the time difference is from the init running on the CPU. Using the torch.utils._device.DeviceContext, you can get the full init but construct your tensors directly on the GPU for very similar performance gains.
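If you are on a recent PyTorch (2.0 or later, if I am not mistaken), the device context is even exposed publicly: torch.device works as a context manager. A quick sketch on the CPU:

```python
import torch

# Tensors constructed inside the block land on the given device by default;
# in the setting above you would use "cuda" instead of "cpu".
with torch.device("cpu"):
    t = torch.ones(3)
print(t.device)  # cpu
```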

One thing this does not solve is spreading initialization across multiple GPUs for models that do not fit on a single GPU. But that is for another day.

If you like your models fast, I can help you with training and consulting via my company, MathInf GmbH. As always, I hope this is useful to you, please send questions and comments to tv@lernapparat.de.