Memory-efficient loading of model weights
I have been intrigued by what computers can do for most of my life, and these days that includes large language models (LLMs) running on my own computer. It seems that I am not the only one. But one thing about these LLMs is that they are not kidding when they say large. The amount of GPU and CPU memory needed to run them is huge. And, to be perfectly honest, I have a hunch that PyTorch, my favourite modelling library, is not always as economical with memory as it could be.
But then, we are not here to complain, but to investigate and improve things. (Well, I am, and you don't have a comment function beyond sending me an email. Speaking of which, I love hearing from you at tv@lernapparat.de.)
The problem
So what is going on when we load a model to run it?
The typical thing is
- we instantiate the model (I recently looked at initialization in this context),
- we load a state dictionary with the model weights from disk,
- we copy the state dictionary into the model weights.
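In code, this pattern looks roughly like the following sketch (a small stand-in model and a made-up checkpoint path; for a real LLM, both allocations are many gigabytes):

import torch

model = torch.nn.Linear(5, 3)           # allocates memory for the model's weights
state_dict = torch.load('model.pth')    # allocates memory for all the weights again
model.load_state_dict(state_dict)       # copies the loaded values into the model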
And there is the problem in plain sight: We first create the model, allocating memory for its weights, and then load the state dictionary with all the weights, so we allocate memory for each weight twice. That does not seem very frugal!
Granted, if you are using my instantiation improvement to initialize the weights on the GPU directly, one will be GPU memory and the other CPU memory, but really, we would like our weights to not take up extra memory while going from disk to GPU, right?
So what could we do about it? Well, there could be two ways: (There could be more, but these are the ones I came up with.)
- don't allocate memory for the weights in the model and instead move the weights from the state dict into the model, or
- don't load all the weights of a model into a giant state dict, just to copy them from there.
While the first is attractive (in particular if you might run the model on different GPUs, and I may have more to say on this in a blog post, but not this one), let us look at what is up with the state dict and if we can do things better there.
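For reference, the first option could look roughly like this sketch (an assumption on my part, not what this post builds), using the meta device that we will also meet below:

import torch

# create the model without allocating real memory for its weights (they live on the
# "meta" device); the real weights then have to be swapped in from the state dict later
with torch.device("meta"):
    model = torch.nn.Linear(5, 3)
print(model.weight.device)  # meta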
What does loading state dicts do, anyways?
So what happens when we load the state dict with torch.load? (We could just look at the code, but how much fun would that be?)
Well, obviously things are read from the file and put into a Python structure, most typically a dictionary, and, of particular interest to us, the tensors are reconstructed along the way.
One thing that you might remember if you have been hanging around PyTorch for long enough is that PyTorch once switched from a "legacy" file format to a more modern one. And in fact, it is a zipfile:
unzip -l 7B/state_dict.pth
on a lit-llama 7B checkpoint gives us
Archive: 7B/state_dict.pth
Length Date Time Name
--------- ---------- ----- ----
26829 1980-00-00 00:00 state_dict/data.pkl
262144000 1980-00-00 00:00 state_dict/data/0
262144000 1980-00-00 00:00 state_dict/data/1
100663296 1980-00-00 00:00 state_dict/data/10
8192 1980-00-00 00:00 state_dict/data/100
100663296 1980-00-00 00:00 state_dict/data/101
33554432 1980-00-00 00:00 state_dict/data/102
90177536 1980-00-00 00:00 state_dict/data/103
90177536 1980-00-00 00:00 state_dict/data/104
90177536 1980-00-00 00:00 state_dict/data/105
...
8192 1980-00-00 00:00 state_dict/data/99
2 1980-00-00 00:00 state_dict/version
--------- -------
13476858063 229 files
Ha! So what do we see?
- There is a directory name in there (which is the filename, minus its extension, that whoever saved the file chose).
- A data.pkl file. This points to the fact that PyTorch uses the pickle module for serializing Python data structures. Ha!
- Many data/<number> files of considerable size. These could be our tensors, right?
So these observations can inform our plan:
- Grab the data.pkl thing and get back the state dictionary, but
- replace getting the Tensors with something that looks similar enough to PyTorch but does not take up memory,
- load the tensor from data/<number> only as needed.
Cool, let's get going!
Tensor-like
Let us start with ... the first? No, let us go for the middle one first. To get an idea, let's do a small experiment:
We create a Linear layer without bias.
m = torch.nn.Linear(3, 5, bias=False)
As you might expect, the state_dict of this contains only a weight item.
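We can check quickly:

print(m.state_dict().keys())   # odict_keys(['weight'])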
So we can try to trick PyTorch by just passing some other object:
m.load_state_dict({'weight': object()})
this gives us an error
While copying the parameter named "weight", expected torch.Tensor or Tensor-like object from checkpoint but received <class 'object'>
So we have to use torch.Tensor or a Tensor-like object. What is a Tensor-like object?
We can look for where this error message is generated; it uses the function torch.overrides.is_tensor_like to check. It turns out that an object is Tensor-like if it implements __torch_function__.
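A minimal check (not part of the loading code, just an illustration) confirms this:

import torch

class Dummy:
    # defining __torch_function__ is what makes a class "Tensor-like"
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        return NotImplemented

print(torch.overrides.is_tensor_like(Dummy()))   # True
print(torch.overrides.is_tensor_like(object()))  # False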
Avid readers will notice that last time we used a __torch_function__ to globally hook into all calls to PyTorch. This is related to, but not entirely the same as, what we do here, where we only use it on our new type.
In fact, the PyTorch documentation mainly covers the use we aim for today. Looking at the documentation, we learn that __torch_function__ should be a class method with the signature
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
So here is the plan for something to load into state dicts:
- use objects of a new class, say NotYetLoadedTensor,
- implement __torch_function__ so that it loads our tensors from the archive when they are passed to a PyTorch function,
- maybe add a few other attributes (like shape/size); we can make these up as we try loading our state dict.
OK, so if we ignore details hard enough, that is easy. But how can we build our fake state dict?
It would be neat if we could just look at the zip file contents and then infer what is which, but, obviously, that's not how it works, because we don't know which of the 227 data/<number> files belongs to which state dict member. Also, people (in particular when building libraries on top of PyTorch) invent all sorts of structure, and we don't know where exactly those tensors live. So it is best to do some Python unpickling.
Unpickler classes
It turns out that PyTorch uses a Python-provided extension mechanism for loading tensors while letting Python do the "usual" unpickling. This works with the pickle.Unpickler class. We can subclass that, and then two methods are of particular interest to us:
- A method persistent_load(pid) that is called when "external" objects are loaded. Writing a dummy class that just prints the parameters shows us what these look like (see below).
- Another method find_class(module, name) that looks for a class (or factory method) for name in module (passed as two strings).
Let's build a dummy to look at what kind of parameters they get:
import zipfile
import pickle

class MyUnpickler(pickle.Unpickler):
    def find_class(self, module, name):
        print("find_class", module, name)
        return super().find_class(module, name)

    def persistent_load(self, pid):
        print("persistent_load", pid)
        return super().persistent_load(pid)

zf = zipfile.ZipFile('/tmp/x.pt')
pklname, = [fn for fn in zf.namelist() if fn.endswith('data.pkl')]
with zf.open(pklname, 'r') as pkl:
    mup = MyUnpickler(pkl)
    sd = mup.load()

print(zf.namelist())
This throws an exception soon, because the default persistent_load raises one (PyTorch itself overrides it to load the tensors), but we get the following output before that:
find_class torch._utils _rebuild_tensor_v2
find_class torch FloatStorage
persistent_load ('storage', <class 'torch.FloatStorage'>, '0', 'cpu', 25)
So this is what seems to be going on:
- Tensors get loaded by the torch._utils._rebuild_tensor_v2 function; we can inspect the source to see that it takes the parameters storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None.
- So the Tensors don't hold their memory directly, but use a storage object for that. (In fact, PyTorch wants to get rid of some of this, and we get deprecation warnings for looking too closely at FloatStorage and friends, indicating that only an UntypedStorage will remain. This might remind you of FloatTensor if you are really old. Oh well.)
- The persistent_load sees a class (as item number 1), a "filename" in the zip archive ('0'), a device, and a size (the example uses a 5x5 tensor I saved).
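To make the first point a bit more tangible, here is a rough illustration of calling that rebuild function by hand (it is an internal API, so take the exact signature with a grain of salt):

import collections
import torch

# a 15-element storage rebuilt into a 3x5 tensor (offset 0, contiguous strides)
storage = torch.FloatStorage(15)
t = torch._utils._rebuild_tensor_v2(storage, 0, (3, 5), (5, 1), False, collections.OrderedDict())
print(t.shape)   # torch.Size([3, 5])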
So we likely want to return a storage in persistent_load, but without memory. Easy: we'll instantiate a storage of the class given to us with the right size, but on the meta device, a special dummy device that has no backing data at all but allows us to store and propagate metadata (hence the name).
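A tiny illustration of the meta device:

t = torch.empty(3, 5, device="meta")
print(t.shape, t.dtype, t.device)   # torch.Size([3, 5]) torch.float32 meta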
To make things concrete, what would we need to get this snippet to run using MyUnpickler?
m = torch.nn.Linear(5, 3)
torch.save(m.state_dict(), '/tmp/x.pt')

zf = zipfile.ZipFile('/tmp/x.pt')
pklname, = [fn for fn in zf.namelist() if fn.endswith('data.pkl')]
with zf.open(pklname, 'r') as pkl:
    mup = MyUnpickler(pkl)
    sd = mup.load()

m.load_state_dict(sd)
Replacing the persistent_load is easy, if a bit tedious:
def persistent_load(self, pid):
    _, cls, fn, dev, size = pid
    return torch.storage.TypedStorage(size, dtype=cls().dtype, device='meta')
but now we get the following error
While copying the parameter named "weight", whose dimensions in the model are torch.Size([3, 5]) and whose
dimensions in the checkpoint are torch.Size([3, 5]), an exception occurred : ('Cannot copy out of meta tensor; no
data!',).
This is because we now have meta tensors in the state dict (do look at it instead of loading it into the module).
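Indeed, inspecting the loaded dict directly (reusing zf and pklname from the snippet above) shows tensors on the meta device:

with zf.open(pklname, 'r') as pkl:
    sd = MyUnpickler(pkl).load()
print({k: (tuple(v.shape), str(v.device)) for k, v in sd.items()})
# {'weight': ((3, 5), 'meta'), 'bias': ((3,), 'meta')}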
We can divert the tensor-building to return a NotYetLoadedTensor:
def my_rebuild_tensor(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None):
    meta_tensor = torch._utils._rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata)
    return NotYetLoadedTensor(meta_tensor)

class MyUnpickler(pickle.Unpickler):
    def find_class(self, module, name):
        if module == 'torch._utils' and name == '_rebuild_tensor_v2':
            return my_rebuild_tensor
        print("find_class", module, name)
        return super().find_class(module, name)

    ...
Conveniently, we keep the meta tensor and pass it to our new Tensor-like class.
There we just implement __torch_function__ and - because loading will complain - also __getattr__ for getting the shape (which load_state_dict uses).
In our __torch_function__ we'll replace NotYetLoadedTensor arguments with "materialized" versions. But we just fill things with random numbers for now. (We need some serious bookkeeping for getting the tensors, which we put off for a paragraph or two...)
So here we go:
class NotYetLoadedTensor:
    def __init__(self, meta_tensor):
        self.meta_tensor = meta_tensor

    def _load_tensor(self):
        # here I am cheating, we should read the right thing from the zip file
        return torch.randn_like(self.meta_tensor, device="cpu")

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        margs = [a._load_tensor() if isinstance(a, NotYetLoadedTensor) else a for a in args]
        return func(*margs, **kwargs)

    def __getattr__(self, name):
        if name == "shape":
            return self.meta_tensor.shape
So with this, we can load the state dict of our module! Supercool!
But wait, we don't get the proper weights yet, just random stuff.
Let us fix this. For this we need to keep track of the zip file (and also keep it open...) and the file name. Originally I thought this was a problem, but as my_rebuild_tensor gets the very Python storage object we create, we can just tack these bits onto it. (This seems like no big deal, but in my initial attempts, I tried to get the storage from the tensor, and that is more difficult, as you will get a new Python storage object pointing to the same C++ storage object if you use t.storage() on a tensor object. Alban mentioned that PyTorch devs are working on round-tripping Python objects in those cases - thank you!)
While we are at it, we make things into a lazy_load function. So here we go:
import zipfile
import pickle

import torch

class NotYetLoadedTensor:
    def __init__(self, meta_tensor, zip_file, zip_prefix, pid, rebuild_args):
        self.meta_tensor = meta_tensor
        self.zip_file = zip_file
        self.zip_prefix = zip_prefix
        self.pid = pid
        self.rebuild_args = rebuild_args

    def _load_tensor(self):
        # read the tensor's bytes from the zip archive and rebuild the real tensor
        _, cls, fn, dev, size = self.pid
        buffer = self.zip_file.read(f"{self.zip_prefix}/data/{fn}")
        storage = cls.from_buffer(buffer, "native")
        tensor = torch._utils._rebuild_tensor_v2(storage, *self.rebuild_args)
        return tensor

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # materialize NotYetLoadedTensor arguments before calling into PyTorch
        margs = [a._load_tensor() if isinstance(a, NotYetLoadedTensor) else a for a in args]
        return func(*margs, **kwargs)

    def __getattr__(self, name):
        if name == "shape":
            return self.meta_tensor.shape

def my_rebuild_tensor(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None):
    # build a meta tensor for the metadata and remember how to rebuild the real one later
    meta_tensor = torch._utils._rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata)
    rebuild_args = storage_offset, size, stride, requires_grad, backward_hooks, metadata
    return NotYetLoadedTensor(meta_tensor, storage.zip_file, storage.zip_prefix, storage.original_pid, rebuild_args)

class MyUnpickler(pickle.Unpickler):
    def __init__(self, pkl_file, zip_file, prefix):
        super().__init__(pkl_file)
        self.zip_file = zip_file
        self.prefix = prefix

    def find_class(self, module, name):
        if module == 'torch._utils' and name == '_rebuild_tensor_v2':
            return my_rebuild_tensor
        print("find_class", module, name)
        return super().find_class(module, name)

    def persistent_load(self, pid):
        # return a storage on the meta device and tack on what we need for lazy loading
        _, cls, fn, dev, size = pid
        s = torch.storage.TypedStorage(size, dtype=cls().dtype, device='meta')
        s.original_pid = pid
        s.zip_file = self.zip_file
        s.zip_prefix = self.prefix
        return s

def lazy_load(fn):
    zf = zipfile.ZipFile(fn)
    pklname, = [fn for fn in zf.namelist() if fn.endswith('data.pkl')]
    prefix, _ = pklname.rsplit('/', 1)
    with zf.open(pklname, 'r') as pkl:
        mup = MyUnpickler(pkl, zf, prefix)
        sd = mup.load()
    return sd
Let's see if it works:
m = torch.nn.Linear(5, 3)
torch.save(m.state_dict(), '/tmp/x.pt')
sd = lazy_load('/tmp/x.pt')
m2 = torch.nn.Linear(5, 3)
m2.load_state_dict(sd)
x = torch.randn(2, 5)
torch.testing.assert_close(m(x), m2(x))
Awesome!
Conclusion
So we wondered about the memory waste of loading the state dict into memory and then copying it into the model, and ended up with a way to get the state dict but load its members into memory only as they are needed. Neat!
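In practice, with the lit-llama checkpoint from the beginning, usage would look roughly like this (a sketch; llama stands in for a matching model instance):

sd = lazy_load('7B/state_dict.pth')
llama.load_state_dict(sd)   # each tensor is read from the checkpoint only when it is copied into the model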
For those who are in a hurry (and thought to skip to the conclusion), I packaged this up in a small package, torchhacks (on GitHub and available via pip), where I want to collect a few of these bits worth having. I'll use it myself, so hopefully I'll keep fixing and improving things (for example, there are things I noticed while writing this).
If you need PyTorch superpowers, you can hire my consulting firm, MathInf, either to help you out with mine as a service or to book a training to build your own. Do reach out for this - or any feedback - at tv@lernapparat.de. Thank you!