More memory-efficiency: conversion of weights for large models
Working with large language models has us wanting to make the most of invariably limited hardware. In the last two posts, we looked at initialization and loading. Today we consider converting weights and improving memory efficiency there by getting a better grip on how models are saved. So here is how I reduced the memory needed to convert a 30 billion parameter LLM by a factor of 45.
We need a better way to save things
In the olden days - and for this article that is 7 years ago, when I started using PyTorch - the relation between GPU and CPU hardware capabilities was quite different from today's: back then, it seems, CPU memory was much more ample relative to GPU memory and model size. This is probably why PyTorch chose to keep the API for saving and loading tensors simple - you load or save the state dictionary with all model weights at once.
As we previously saw, with LLMs it is advantageous to load weights incrementally. But what about saving? One instance where this becomes relevant is converting weights from one format to another, and my good friend and co-author Luca Antiga mentioned the need to reduce memory usage for that the other day. (An improved version of the lazy loading can be found in the lit-llama repository under lit_llama/utils.py, and the incremental saving was initially done for Lightning, thank you!)
So we want approximately the following:
- Save tensors when they are ready instead of keeping them in memory,
- Eventually save a state dictionary that looks just like the one saved by torch.save,
- Keep it all in one file. (I remember the days when things came split across dozens of floppy disks, but I don't really need to go back.)
How to go about saving
When we looked at lazy loading, we saw that PyTorch uses ZIP files to store tensors along with pickled data. Just like for loading, we can easily implement writing the tensor contents (or rather the storages) separately, but we must think about the API.
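To make the "pickle plus separate members" idea concrete, here is a small stdlib-only sketch, assuming raw bytes stand in for tensor storages. It loosely mirrors the torch.save layout (one ZIP member per storage under data/, plus a data.pkl describing the structure) but is not the exact format; write_archive, produce_blobs and the archive/ prefix are illustrative names of my own.

```python
import io
import pickle
import zipfile


def write_archive(fileobj, blobs):
    """Write each (name, bytes) blob as its own ZIP member, then a pickled index.

    `blobs` may be a generator, so only one blob needs to be alive at a time.
    """
    keys = {}
    with zipfile.ZipFile(fileobj, "w", compression=zipfile.ZIP_STORED) as zf:
        for i, (name, blob) in enumerate(blobs):
            # Stream this blob into its own member before the next one is produced.
            with zf.open(f"archive/data/{i}", "w") as member:
                member.write(blob)
            keys[name] = str(i)
        # The structure itself only refers to the members by key.
        zf.writestr("archive/data.pkl", pickle.dumps(keys))
    return keys


def produce_blobs():
    # Hypothetical producer: each "weight" is created only when requested.
    for i in range(3):
        yield f"layer{i}.weight", bytes([i]) * 4


buf = io.BytesIO()
keys = write_archive(buf, produce_blobs())
```

Because the data members are written while iterating, peak memory is bounded by the largest single blob rather than the sum of all of them.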
Context managers are a great way to handle files that stay open beyond a single API call. (In fact, for lit_llama we have improved the lazy_load implementation to use a context manager; this had also been pointed out by Vivek Kalyan on Twitter.) So we will use a context manager to handle the file: with incremental_save(filename) as output_file:
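The file-handling pattern behind such a context manager can be sketched as follows. ToyIncrementalSave and write_record are hypothetical, simplified names, not the Lit-LLaMA API; the point is only that the archive is opened on entry and finalized on exit, even when the body of the with statement raises.

```python
import zipfile


class ToyIncrementalSave:
    """Hypothetical stand-in for an incremental_save-style context manager."""

    def __init__(self, file):
        self.file = file  # a filename or a seekable binary file object
        self.zf = None
        self.next_key = 0

    def __enter__(self):
        self.zf = zipfile.ZipFile(self.file, "w")
        return self

    def write_record(self, data: bytes) -> str:
        """Write one blob as its own member and return the key referring to it."""
        key = str(self.next_key)
        self.next_key += 1
        self.zf.writestr(f"archive/data/{key}", data)
        return key

    def __exit__(self, exc_type, exc_value, traceback):
        # Always write the ZIP central directory and release the file.
        self.zf.close()
        return False  # do not swallow exceptions raised in the with body
```

Finalizing in __exit__ means that whatever was written before an error is still readable as a valid ZIP file.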
Next we need to provide a way to store tensors early and free the memory, but let them keep their place in the data structure. The latter is rather situation-dependent, so we let the users do that themselves. output_file.store_early(t) stores a tensor's data in the ZIP file and returns a proxy for the tensor (and its storage) that does the right thing when pickled.
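Here is a stdlib-only sketch of the store_early idea, with the caller putting the proxy back into its own dictionary. FakeTensor, TensorProxy and ToySaver are hypothetical stand-ins, not PyTorch or Lit-LLaMA classes; the weakref check relies on CPython freeing objects as soon as their last reference goes away.

```python
import io
import weakref
import zipfile


class FakeTensor:
    """Hypothetical stand-in for a tensor: just a blob of bytes."""

    def __init__(self, data: bytes):
        self.data = data


class TensorProxy:
    """Small placeholder that remembers where the data went, not the data itself."""

    def __init__(self, key: str, nbytes: int):
        self.key = key
        self.nbytes = nbytes


class ToySaver:
    """Hypothetical simplified saver with a store_early-like method."""

    def __init__(self, fileobj):
        self.zf = zipfile.ZipFile(fileobj, "w")
        self.next_key = 0

    def store_early(self, tensor: FakeTensor) -> TensorProxy:
        key = str(self.next_key)
        self.next_key += 1
        self.zf.writestr(f"archive/data/{key}", tensor.data)  # data now lives in the file
        return TensorProxy(key, len(tensor.data))

    def close(self):
        self.zf.close()


# Usage: the caller decides where each proxy goes in its own data structure.
buf = io.BytesIO()
saver = ToySaver(buf)
state_dict = {"a.weight": FakeTensor(b"\x01" * 8), "b.weight": FakeTensor(b"\x02" * 4)}
watch = weakref.ref(state_dict["a.weight"])
for name in list(state_dict):
    # Replacing the entry drops the last reference to the original tensor.
    state_dict[name] = saver.store_early(state_dict[name])
saver.close()
```

After the loop, the dictionary has the same keys as before but only holds small proxies, so the memory of each original tensor can be reclaimed right after it was written.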
Finally, we need the equivalent of torch.save for just the pickled structure. We call this output_file.save(obj). It lets the pickler do its thing, and the proxy objects pretend they were regular Tensor and Storage objects being pickled.
That is all we need.
Details of pickling
For the technical implementation, there are a few details (skip this section if you only want to use incremental saving in your code):
In PyTorch, Tensors are pickled in the usual pickle way. The key is the __reduce_ex__ method, which returns a callable (used to re-create the object) and the arguments to pass to that callable (which are themselves pickled). In this argument tuple, the first element is a Storage object (TypedStorage is currently being deprecated, so it will likely become an UntypedStorage then), which we replace with a StorageProxy.
Upon instantiation, the StorageProxy writes the storage content to a separate member of the ZIP archive that is the .pt file and captures the information used to pickle the Storage.
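The swap of the first reduce argument can be illustrated with toy classes. This is a sketch under simplifying assumptions: ToyStorage, ToyTensor and rebuild_toy_tensor are hypothetical stand-ins for PyTorch's Storage, Tensor and rebuild function, and a plain dict plays the role of the ZIP archive.

```python
class ToyStorage:
    """Hypothetical stand-in for a PyTorch Storage: a flat buffer of bytes."""

    def __init__(self, data: bytes):
        self.data = data


class ToyTensor:
    """Hypothetical stand-in for a Tensor that pickles "the usual way"."""

    def __init__(self, storage, shape):
        self.storage = storage
        self.shape = shape

    def __reduce_ex__(self, protocol):
        # (callable to re-create the object, arguments to pass to it)
        return rebuild_toy_tensor, (self.storage, self.shape)


def rebuild_toy_tensor(storage, shape):
    return ToyTensor(storage, shape)


class StorageProxy:
    """Writes the storage content out on instantiation and keeps only metadata."""

    def __init__(self, storage, archive: dict):
        self.key = str(len(archive))
        archive[self.key] = storage.data  # stands in for a separate ZIP member
        self.nbytes = len(storage.data)


class TensorProxy:
    """Reuses the tensor's reduce callable, but with the storage swapped out."""

    def __init__(self, tensor, archive: dict, protocol: int = 5):
        self.protocol = protocol
        rebuild, (storage, *other_args) = tensor.__reduce_ex__(protocol)
        self.rebuild = rebuild
        self.args = (StorageProxy(storage, archive), *other_args)

    def __reduce_ex__(self, protocol):
        if protocol != self.protocol:
            raise RuntimeError(f"expected protocol {self.protocol}, got {protocol}")
        return self.rebuild, self.args
```

Once the proxy exists, the original tensor can be dropped: to the pickler, the proxy reduces to the same callable and the same trailing arguments, and only the storage slot differs.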
As we learned when implementing lazy loading, pickling a Storage uses the pickle protocol's facility for referring to data stored outside the pickle (for PyTorch, in separate ZIP file members). Thus, we subclass pickle.Pickler and define a persistent_id method that causes each StorageProxy to be written as a reference to the pre-saved data.
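The persistent_id mechanism, and its counterpart persistent_load on the loading side, can be sketched with the standard library alone. StoredBlob, IncrementalPickler, ArchiveUnpickler, save and load are hypothetical names, and the ("storage", key, nbytes) tuple is only in the spirit of torch.save's records, not its exact format.

```python
import io
import pickle
import zipfile


class StoredBlob:
    """Hypothetical placeholder for data that was already written to the archive."""

    def __init__(self, key: str, nbytes: int):
        self.key = key
        self.nbytes = nbytes


class IncrementalPickler(pickle.Pickler):
    """Writes StoredBlob objects as references instead of pickling them by value."""

    def persistent_id(self, obj):
        if isinstance(obj, StoredBlob):
            # Non-string persistent IDs like this tuple need a binary protocol.
            return ("storage", obj.key, obj.nbytes)
        return None  # everything else is pickled as usual


class ArchiveUnpickler(pickle.Unpickler):
    """Resolves the references by reading the matching ZIP member."""

    def __init__(self, file, zf: zipfile.ZipFile):
        super().__init__(file)
        self.zf = zf

    def persistent_load(self, pid):
        typename, key, nbytes = pid
        if typename != "storage":
            raise pickle.UnpicklingError(f"unknown persistent id type {typename!r}")
        data = self.zf.read(f"archive/data/{key}")
        if len(data) != nbytes:
            raise pickle.UnpicklingError(f"size mismatch for member {key}")
        return data


def save(zf: zipfile.ZipFile, obj):
    buf = io.BytesIO()
    IncrementalPickler(buf, protocol=5).dump(obj)
    zf.writestr("archive/data.pkl", buf.getvalue())


def load(fileobj):
    with zipfile.ZipFile(fileobj) as zf:
        pickled = zf.read("archive/data.pkl")
        return ArchiveUnpickler(io.BytesIO(pickled), zf).load()


# Usage: the data member is written "early", the structure afterwards.
buf = io.BytesIO()
with zipfile.ZipFile(buf, "w") as zf:
    zf.writestr("archive/data/0", b"\x00\x01")
    save(zf, {"w": StoredBlob("0", 2), "step": 7})
result = load(buf)
```

The pickled structure stays tiny: it only contains the reference tuple, while the bytes themselves were written before save was ever called.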
Of course, the details of the handling need to be compatible with torch.save, so we take the exact pickling logic from the persistent_id method that torch.save overrides in torch/serialization.py.
Just give me the source
The need for incremental saving came up while working with Luca and my friends at Lightning AI, and they actually let me work on this as part of a project we are doing - thank you!
You can find the source code in Lit-LLaMA's utils.py. When converting OpenAssistant's 30B LLaMA model to bfloat16 with the weight conversion script, the memory used by the conversion (RSS, for the experts) went from 71.9GB to 1.6GB. For 32-bit weights, it makes the difference between triggering the OOM killer and working on my machine.
I hope to put out a new version of TorchHacks soon, which will include incremental loading and saving.
Conclusion
Large models make us care more about how we deal with the parameters, from initialization to loading to saving models. Happily, PyTorch's use of standard APIs and formats (Python's pickle and ZIP) lets us hack up ways to work with the saved checkpoints more efficiently.
If you need PyTorch training or consulting, don't hesitate to reach out at tv@lernapparat.de. Also, all feedback is deeply appreciated.