Lernapparat

Memory-efficient loading of model weights

April 18, 2023

I have been intrigued by what computers can do for most of my life, and these days that includes large language models (LLMs) running on my own computer. It seems that I am not the only one. But one thing about these LLMs is that they are not kidding when they say large. The amount of GPU and CPU memory needed to run them is huge. And, to be perfectly honest, I have a hunch that PyTorch, my favourite modelling library, is not always as economical with memory as it could be.

But then, we are not here to complain, but to investigate and improve things. (Well, I am, and you don't have a comment function beyond sending me an email. Speaking of which, I love hearing from you at tv@lernapparat.de.)

The problem

So what is going on when we load a model to run it?

The typical thing is
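
A minimal sketch of that pattern, with a tiny Linear layer standing in for a real model and a temporary file standing in for the checkpoint (both are placeholders of mine, not the original snippet):

```python
import os
import tempfile

import torch

# Placeholder checkpoint: a real one would be a multi-gigabyte file on disk.
path = os.path.join(tempfile.mkdtemp(), "state_dict.pth")
torch.save(torch.nn.Linear(5, 3).state_dict(), path)

model = torch.nn.Linear(5, 3)       # allocates memory for the weights
state_dict = torch.load(path)       # allocates memory for every weight again
model.load_state_dict(state_dict)   # copies from the second copy into the first
```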

And there is the problem in plain sight: We first create the model, allocating memory for its weights, and then load the state dictionary with all the weights, so we allocate memory for each weight twice. That seems not so frugal!

Granted, if you are using my instantiation improvement to initialize the weights on the GPU directly, one will be GPU memory and the other CPU memory, but really, we would like our weights to not take up memory while going from disk to GPU, right?

So what could we do about it? Well, there could be two ways (there could be more, but these are the ones I came up with):

  • don't allocate memory for the weights in the model and instead move the weights over from the state dict, or
  • don't load all the weights of a model into a giant state dict, just to copy them from there.

While the first is attractive (in particular if you might run the model on different GPUs, and I may have more to say on this in a blog post, but not this one), let us look at what is up with the state dict and if we can do things better there.

What does loading state dicts do, anyways?

So what happens when we load the state dict with torch.load? (We could just look at the code, but how much fun would that be?) Well, obviously the things are read from the file and then put into a Python structure, most typically a dictionary, with, and this is of particular interest to us, the tensors being reconstructed.

One thing that you might remember if you have been hanging around PyTorch for long enough is that PyTorch once switched from a "legacy" file format to a more modern one. And in fact, the modern one is a zipfile:

unzip -l 7B/state_dict.pth

on a lit-llama 7B checkpoint gives us

Archive:  7B/state_dict.pth
  Length      Date    Time    Name
---------  ---------- -----   ----
    26829  1980-00-00 00:00   state_dict/data.pkl
262144000  1980-00-00 00:00   state_dict/data/0
262144000  1980-00-00 00:00   state_dict/data/1
100663296  1980-00-00 00:00   state_dict/data/10
     8192  1980-00-00 00:00   state_dict/data/100
100663296  1980-00-00 00:00   state_dict/data/101
 33554432  1980-00-00 00:00   state_dict/data/102
 90177536  1980-00-00 00:00   state_dict/data/103
 90177536  1980-00-00 00:00   state_dict/data/104
 90177536  1980-00-00 00:00   state_dict/data/105
...
     8192  1980-00-00 00:00   state_dict/data/99
        2  1980-00-00 00:00   state_dict/version
---------                     -------
13476858063                     229 files

Ha! So what do we see?

  • There is a directory name in there (which is the filename, minus the extension, that whoever saved the file chose).
  • A data.pkl file. This points to the fact that PyTorch uses the pickle module for serializing Python data structures. Ha!
  • Many data/<number> files of considerable size. These could be our tensors, right?
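
The same kind of listing can be produced from Python with the standard zipfile module. The sketch below builds a toy archive in memory that only mimics the layout above (the entry contents are made up), and summarize_archive is a helper name of my choosing:

```python
import io
import zipfile

def summarize_archive(file):
    """Return (name of the pickle entry, {data entry name: size in bytes})."""
    with zipfile.ZipFile(file) as zf:
        infos = zf.infolist()
    (pkl_name,) = [i.filename for i in infos if i.filename.endswith("data.pkl")]
    prefix = pkl_name.rsplit("/", 1)[0]
    sizes = {i.filename: i.file_size for i in infos
             if i.filename.startswith(prefix + "/data/")}
    return pkl_name, sizes

# Toy stand-in for a checkpoint; a real one would come from torch.save.
buf = io.BytesIO()
with zipfile.ZipFile(buf, "w") as zf:
    zf.writestr("state_dict/data.pkl", b"placeholder")
    zf.writestr("state_dict/data/0", bytes(100))
    zf.writestr("state_dict/data/1", bytes(8))
    zf.writestr("state_dict/version", b"placeholder")

pkl_name, sizes = summarize_archive(buf)
print(pkl_name, sizes)
```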

So these observations can inform our plan:

  • Grab the data.pkl thing and get back the state dictionary, but
  • replace getting the Tensors with something that looks similar enough to PyTorch but does not take up memory,
  • load the tensor from data/<number> only as needed.
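
The "only as needed" part of this plan can be sketched without PyTorch at all. Below, LazyEntry (a hypothetical name) remembers where its bytes live in a zip archive and reads them only when asked; the archive is a toy stand-in, with 32-bit floats packed via the array module:

```python
import array
import io
import zipfile

class LazyEntry:
    """Remembers where the data lives and reads it only when asked."""
    def __init__(self, zip_file, name):
        self.zip_file = zip_file
        self.name = name
        self.reads = 0  # counts how often we touched the archive

    def load(self):
        self.reads += 1
        values = array.array("f")  # 32-bit floats
        values.frombytes(self.zip_file.read(self.name))
        return values

# Toy archive with a single data entry.
buf = io.BytesIO()
with zipfile.ZipFile(buf, "w") as zf:
    zf.writestr("toy/data/0", array.array("f", [1.0, 2.0, 3.0]).tobytes())

archive = zipfile.ZipFile(buf)  # must stay open while entries are in use
entry = LazyEntry(archive, "toy/data/0")
print("reads before use:", entry.reads)
print(list(entry.load()))
```

Note that the archive has to stay open for as long as any entry might still be loaded, which is a bit of bookkeeping the real thing needs, too.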

Cool, let's get going!

Tensor-like

Let us start with ... the first? No, let us go for the middle one first. To get an idea, let's do a small experiment:

We create a Linear layer without bias.

m = torch.nn.Linear(3, 5, bias=False)

As you might expect, the state_dict of this contains only a weight item. So we can try to trick PyTorch by just passing some other object:

m.load_state_dict({'weight': object()})

this gives us an error

While copying the parameter named "weight", expected torch.Tensor or Tensor-like object from checkpoint but received <class 'object'>

So we have to use torch.Tensor or a Tensor-like object. What is a Tensor-like object?

We can look for where this error message is generated: it uses the function torch.overrides.is_tensor_like to check. It turns out that an object is Tensor-like if it implements __torch_function__.

Avid readers will notice that last time we used a __torch_function__ to globally hook into all calls to PyTorch. This is related to, but not entirely the same as, what we do here, where we only use it on our new type. In fact, the PyTorch documentation mainly covers the use we aim for today. Looking at the documentation, we learn that __torch_function__ should be a class method with the signature

@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
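
As a quick check that implementing this method is indeed what makes an object Tensor-like, here is a small experiment. Dummy is a throwaway class of mine whose __torch_function__ simply declines to handle anything:

```python
import torch
import torch.overrides

class Dummy:
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # Decline every operation; we only care about being "Tensor-like".
        return NotImplemented

plain = torch.overrides.is_tensor_like(object())         # plain Python object
dummy = torch.overrides.is_tensor_like(Dummy())          # has __torch_function__
real = torch.overrides.is_tensor_like(torch.zeros(1))    # an actual tensor
print(plain, dummy, real)
```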

So here is the plan for something to load into state dicts:

  • use objects of a new class, say NotYetLoadedTensor,
  • implement a __torch_function__ that loads our tensors from the archive when they are passed to a PyTorch function,
  • maybe add a few other attributes (like shape / size); we can make these up as we try loading our state dict.

OK, so if we ignore details hard enough, that is easy. But how can we build our fake state dict? It would be neat if we could just look at the zip file contents and then infer what is which, but, obviously, that's not how it works, because we don't know which of the 227 data/<number> files belongs to which state dict member. Also, people (in particular those building libraries on top of PyTorch) invent all sorts of structure and we don't know where exactly those tensors live. So it is best to do some Python unpickling.

Unpickler classes

It turns out that PyTorch uses a Python-provided extension mechanism for loading tensors while letting Python do the "usual" unpickling. This works with the pickle.Unpickler class. We can subclass that and then two methods are of particular interest to us:

  • A method persistent_load(pid) is called when "external" objects are loaded.
  • Another method find_class(module, name) that looks for a class (or factory method) for name in module (passed as two strings).
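
To see the persistent_load mechanism in isolation, here is a PyTorch-free sketch using only the standard pickle module. Blob, SavingPickler and LoadingUnpickler are names I made up; the saving side uses persistent_id, the counterpart that writes the "external" reference in the first place:

```python
import io
import pickle

class Blob:
    """Stands in for data too big to embed in the pickle itself."""
    def __init__(self, key):
        self.key = key

class SavingPickler(pickle.Pickler):
    def persistent_id(self, obj):
        # A non-None return value is stored in the pickle as a
        # "persistent id" instead of the object itself.
        if isinstance(obj, Blob):
            return ("blob", obj.key)
        return None

class LoadingUnpickler(pickle.Unpickler):
    def persistent_load(self, pid):
        # Called with whatever persistent_id returned at save time.
        print("persistent_load", pid)
        tag, key = pid
        return Blob(key)

buf = io.BytesIO()
SavingPickler(buf).dump({"weight": Blob("0"), "steps": 3})
buf.seek(0)
restored = LoadingUnpickler(buf).load()
```

The pickle itself only contains the small tuple, and it is entirely up to persistent_load what to hand back for it.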

Let's build a dummy to look at what kind of parameters they get:

import zipfile
import pickle

class MyUnpickler(pickle.Unpickler):
    def find_class(self, module, name):
        print("find_class", module, name)
        return super().find_class(module, name)
    def persistent_load(self, pid):
        print("persistent_load", pid)
        return super().persistent_load(pid)

zf = zipfile.ZipFile('/tmp/x.pt')
pklname, = [fn for fn in zf.namelist() if fn.endswith('data.pkl')]
with zf.open(pklname, 'r') as pkl:
    mup = MyUnpickler(pkl)
    sd = mup.load()

print(zf.namelist())

This throws an exception soon because the default persistent_load raises an error (PyTorch itself overrides it to load the tensors), but we get the following output:

find_class torch._utils _rebuild_tensor_v2
find_class torch FloatStorage
persistent_load ('storage', <class 'torch.FloatStorage'>, '0', 'cpu', 25)

So this is what seems to be going on:

  • Tensors get loaded by the torch._utils._rebuild_tensor_v2 function. We can inspect the source to see that it takes the parameters storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None.
  • So the Tensors don't hold their memory directly, but use a storage object for that. (In fact, PyTorch wants to get rid of some of this, and we get deprecation warnings for looking too closely at FloatStorage and friends, indicating that only an UntypedStorage will remain. This might remind you of FloatTensor if you are really old. Oh well...)
  • The persistent_load sees a class (as item number 1), a "filename" in the zip archive ('0'), a device, and a size (the example uses a 5x5 tensor I saved).

So we likely want to return a storage in persistent_load, but without memory. Easy: We'll instantiate a storage of the class given to us with the right size, but on the meta device, a special dummy device that has exactly no backing data but allows us to store and propagate metadata (hence the name).
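
To get a feeling for the meta device, here is a small experiment of my own, using only torch.empty and ordinary tensor operations:

```python
import torch

t = torch.empty(5, 5, device="meta")
print(t.shape, t.dtype, t.device)  # the metadata is all there

u = t.t() @ t                      # shapes propagate through operations
print(u.shape, u.is_meta)

try:
    t.cpu()                        # but there is no data to copy anywhere
except RuntimeError as e:          # also covers NotImplementedError
    print("error:", e)
```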

To make things concrete, what would we need to get this snippet to run using MyUnpickler?

m = torch.nn.Linear(5, 3)
torch.save(m.state_dict(), '/tmp/x.pt')

zf = zipfile.ZipFile('/tmp/x.pt')
pklname, = [fn for fn in zf.namelist() if fn.endswith('data.pkl')]
with zf.open(pklname, 'r') as pkl:
    mup = MyUnpickler(pkl)
    sd = mup.load()

m.load_state_dict(sd)

Replacing the persistent_load is easy, if a bit tedious:

    def persistent_load(self, pid):
        _, cls, fn, dev, size = pid
        return torch.storage.TypedStorage(size, dtype=cls().dtype, device='meta')

but now we get the following error

While copying the parameter named "weight", whose dimensions in the model are torch.Size([3, 5]) and whose 
dimensions in the checkpoint are torch.Size([3, 5]), an exception occurred : ('Cannot copy out of meta tensor; no 
data!',).

This is because we now have meta tensors in the state dict (do look at it instead of loading it into the module).

We can divert the tensor-building to return a NotYetLoadedTensor:

def my_rebuild_tensor(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None):
    meta_tensor = torch._utils._rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata)
    return NotYetLoadedTensor(meta_tensor)

class MyUnpickler(pickle.Unpickler):
    def find_class(self, module, name):
        if module == 'torch._utils' and name == '_rebuild_tensor_v2':
            return my_rebuild_tensor
        print("find_class", module, name)
        return super().find_class(module, name)
    ...
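
The find_class diversion works for any pickle, not just PyTorch's. Here is a PyTorch-free sketch of the same trick; describe_fraction and DivertingUnpickler are names I made up for the illustration:

```python
import io
import pickle
from fractions import Fraction

def describe_fraction(*args):
    """Stand-in that gets called instead of the real constructor."""
    return ("would build Fraction from", args)

class DivertingUnpickler(pickle.Unpickler):
    def find_class(self, module, name):
        # Hand out our stand-in for one specific name, defer otherwise.
        if module == "fractions" and name == "Fraction":
            return describe_fraction
        return super().find_class(module, name)

data = pickle.dumps({"ratio": Fraction(3, 4)})
result = DivertingUnpickler(io.BytesIO(data)).load()
print(result["ratio"])
```

The stand-in receives exactly the arguments the pickle recorded for the original constructor, which is what lets us keep them around and build the real object later.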

Conveniently, we keep the meta tensor and pass it to our new Tensor-like class. There we just implement __torch_function__ and, because loading will complain otherwise, also __getattr__ for getting the shape (which load_state_dict uses). In our __torch_function__ we'll replace NotYetLoadedTensor arguments with "materialized" versions. But we just fill things with random numbers for now. (We need some serious bookkeeping for getting the tensors, which we put off for a paragraph or two...)

So here we go:

class NotYetLoadedTensor:
    def __init__(self, meta_tensor):
        self.meta_tensor = meta_tensor

    def _load_tensor(self):
        # here I am cheating, we should read the right thing from the zip file
        return torch.randn_like(self.meta_tensor, device="cpu")

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        margs = [a._load_tensor() if isinstance(a, NotYetLoadedTensor) else a for a in args]
        return func(*margs, **kwargs)

    def __getattr__(self, name):
        if name == "shape":
            return self.meta_tensor.shape

So with this, we can load the state dict of our module! Supercool!

But wait, we don't get the proper weights yet, just random stuff.

Let us fix this. For this we need to keep track of the zip file (and also keep it open...) and the file name. Originally I thought this was a problem, but as my_rebuild_tensor gets the very Python storage object we create, we can just tack these bits onto it. (This seems like no big deal, but in my initial attempts, I tried to get the storage from the tensor, and that is more difficult, as you will get a new Python storage object pointing to the same C++ storage object if you use t.storage() on a tensor object. Alban mentioned that PyTorch devs are working on round-tripping Python objects in those cases (thank you!).)

While we are at it, we make things into a lazy_load function. So here we go:

import pickle
import zipfile

import torch

class NotYetLoadedTensor:
    def __init__(self, meta_tensor, zip_file, zip_prefix, pid, rebuild_args):
        self.meta_tensor = meta_tensor
        self.zip_file = zip_file
        self.zip_prefix = zip_prefix
        self.pid = pid
        self.rebuild_args = rebuild_args

    def _load_tensor(self):
        _, cls, fn, dev, size = self.pid
        buffer = self.zip_file.read(f"{self.zip_prefix}/data/{fn}")
        storage = cls.from_buffer(buffer, "native")
        tensor = torch._utils._rebuild_tensor_v2(storage, *self.rebuild_args)
        return tensor

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        margs = [a._load_tensor() if isinstance(a, NotYetLoadedTensor) else a for a in args]
        return func(*margs, **kwargs)

    def __getattr__(self, name):
        if name == "shape":
            return self.meta_tensor.shape

def my_rebuild_tensor(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None):
    meta_tensor = torch._utils._rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata)
    rebuild_args = storage_offset, size, stride, requires_grad, backward_hooks, metadata
    return NotYetLoadedTensor(meta_tensor, storage.zip_file, storage.zip_prefix, storage.original_pid, rebuild_args)

class MyUnpickler(pickle.Unpickler):
    def __init__(self, pkl_file, zip_file, prefix):
        super().__init__(pkl_file)
        self.zip_file = zip_file
        self.prefix = prefix

    def find_class(self, module, name):
        if module == 'torch._utils' and name == '_rebuild_tensor_v2':
            return my_rebuild_tensor
        print("find_class", module, name)
        return super().find_class(module, name)

    def persistent_load(self, pid):
        _, cls, fn, dev, size = pid
        s = torch.storage.TypedStorage(size, dtype=cls().dtype, device='meta')
        s.original_pid = pid
        s.zip_file = self.zip_file
        s.zip_prefix = self.prefix
        return s

def lazy_load(fn):
    zf = zipfile.ZipFile(fn)
    pklname, = [fn for fn in zf.namelist() if fn.endswith('data.pkl')]
    prefix, _ = pklname.rsplit('/', 1)
    with zf.open(pklname, 'r') as pkl:
        mup = MyUnpickler(pkl, zf, prefix)
        sd = mup.load()
    return sd

Let's see if it works:

m = torch.nn.Linear(5, 3)
torch.save(m.state_dict(), '/tmp/x.pt')
sd = lazy_load('/tmp/x.pt')
m2 = torch.nn.Linear(5, 3)
m2.load_state_dict(sd)

x = torch.randn(2, 5)
torch.testing.assert_close(m(x), m2(x))

Awesome!

Conclusion

So we wondered about the memory waste of loading the state dict into memory and then copying into the model and ended up with a way to get the state dict but only load the members into memory as they are needed. Neat!

For those who are in a hurry (and thought to skip to the conclusion), I packaged this up in a small package, torchhacks (on github and available via pip), where I want to collect a few of these bits worth having and which I'll use and so hopefully fix and improve (for example, there are things I noticed while writing this).

If you need PyTorch superpowers, you can hire my consulting firm, MathInf, either to help you out with mine as a service or to book a training to build your own. Do reach out for this - or any feedback - at tv@lernapparat.de. Thank you!