Add lazyload from torchhacks (Lightning-AI#164)
t-vi authored Apr 19, 2023
1 parent 07c022b commit 88ba008
Showing 5 changed files with 165 additions and 26 deletions.
generate.py: 13 changes (7 additions, 6 deletions)
@@ -8,7 +8,7 @@
 import torch
 
 from lit_llama import LLaMA, Tokenizer
-from lit_llama.utils import EmptyInitOnDevice
+from lit_llama.utils import EmptyInitOnDevice, lazy_load
 
 
 @torch.no_grad()
@@ -108,15 +108,16 @@ def main(
     fabric = L.Fabric(accelerator="cuda", devices=1)
     dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
 
+    print("Loading model ...", file=sys.stderr)
+    t0 = time.time()
     with EmptyInitOnDevice(
         device=fabric.device, dtype=dtype, quantization_mode=quantize
     ):
-        print("Loading model ...", file=sys.stderr)
-        t0 = time.time()
         model = LLaMA.from_name(model_size)
-        checkpoint = torch.load(checkpoint_path)
-        model.load_state_dict(checkpoint)
-        print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
+
+    checkpoint = lazy_load(checkpoint_path)
+    model.load_state_dict(checkpoint)
+    print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
 
     model.eval()
     model = fabric.setup_module(model)
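For orientation, the resulting loading path in generate.py looks roughly like the sketch below. This is a simplified reading of the diff above, not a verbatim copy: the checkpoint path and model size are illustrative, and quantization handling is omitted.

import sys
import time

import torch
import lightning as L

from lit_llama import LLaMA
from lit_llama.utils import EmptyInitOnDevice, lazy_load

fabric = L.Fabric(accelerator="cuda", devices=1)
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32

print("Loading model ...", file=sys.stderr)
t0 = time.time()
with EmptyInitOnDevice(device=fabric.device, dtype=dtype):
    # Build the model on the target device/dtype; no checkpoint data is read here.
    model = LLaMA.from_name("7B")  # illustrative model size

# lazy_load opens the checkpoint archive but defers reading tensor data;
# load_state_dict then materializes each tensor as it is copied in.
checkpoint = lazy_load("checkpoints/lit-llama/7B/lit-llama.pth")  # illustrative path
model.load_state_dict(checkpoint)
print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)

model.eval()
model = fabric.setup_module(model)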
generate_adapter.py: 20 changes (10 additions, 10 deletions)
@@ -10,7 +10,7 @@
 from generate import generate
 from lit_llama import Tokenizer
 from lit_llama.adapter import LLaMA, LLaMAConfig
-from lit_llama.utils import EmptyInitOnDevice
+from lit_llama.utils import EmptyInitOnDevice, lazy_load
 from scripts.prepare_alpaca import generate_prompt
 
 
@@ -60,22 +60,22 @@ def main(
 
     dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
 
+    print("Loading model ...", file=sys.stderr)
+    t0 = time.time()
     with EmptyInitOnDevice(
         device=fabric.device, dtype=dtype, quantization_mode=quantize
     ):
-        print("Loading model ...", file=sys.stderr)
-        t0 = time.time()
         model = LLaMA(LLaMAConfig())  # TODO: Support different model sizes
 
-        # 1. Load the pretrained weights
-        pretrained_checkpoint = torch.load(pretrained_path, map_location=torch.device("cpu"))
-        model.load_state_dict(pretrained_checkpoint, strict=False)
+    # 1. Load the pretrained weights
+    pretrained_checkpoint = lazy_load(pretrained_path, map_location=torch.device("cpu"))
+    model.load_state_dict(pretrained_checkpoint, strict=False)
 
-        # 2. Load the fine-tuned adapter weights
-        adapter_checkpoint = torch.load(adapter_path)
-        model.load_state_dict(adapter_checkpoint, strict=False)
+    # 2. Load the fine-tuned adapter weights
+    adapter_checkpoint = lazy_load(adapter_path)
+    model.load_state_dict(adapter_checkpoint, strict=False)
 
-        print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
+    print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
 
     model.eval()
     model = fabric.setup_module(model)
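The adapter script applies the same idea twice: the base checkpoint and the adapter checkpoint are both opened lazily and merged into one module with strict=False. A minimal sketch of that pattern, with illustrative paths and the default LLaMAConfig standing in for the real CLI arguments:

import torch

from lit_llama.adapter import LLaMA, LLaMAConfig
from lit_llama.utils import EmptyInitOnDevice, lazy_load

pretrained_path = "checkpoints/lit-llama/7B/lit-llama.pth"            # illustrative
adapter_path = "out/adapter/alpaca/lit-llama-adapter-finetuned.pth"   # illustrative

with EmptyInitOnDevice(device="cpu", dtype=torch.bfloat16):
    model = LLaMA(LLaMAConfig())

# 1. Base weights: strict=False ignores the adapter parameters that the
#    pretrained checkpoint does not contain.
model.load_state_dict(lazy_load(pretrained_path), strict=False)

# 2. Adapter weights: the fine-tuned checkpoint fills in those parameters.
model.load_state_dict(lazy_load(adapter_path), strict=False)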
generate_lora.py: 20 changes (10 additions, 10 deletions)
@@ -10,7 +10,7 @@
 from generate import generate
 from lit_llama import Tokenizer, LLaMA, LLaMAConfig
 from lit_llama.lora import lora
-from lit_llama.utils import EmptyInitOnDevice
+from lit_llama.utils import EmptyInitOnDevice, lazy_load
 from scripts.prepare_alpaca import generate_prompt
 
 lora_r = 8
@@ -74,22 +74,22 @@ def main(
             raise ValueError(f"{dtype} is not a valid dtype.")
         dtype = dt
 
+    print("Loading model ...", file=sys.stderr)
+    t0 = time.time()
     with EmptyInitOnDevice(
         device=fabric.device, dtype=dtype, quantization_mode=quantize
     ), lora(r=lora_r, alpha=lora_alpha, dropout=lora_dropout, enabled=True):
-        print("Loading model ...", file=sys.stderr)
-        t0 = time.time()
         model = LLaMA(LLaMAConfig())  # TODO: Support different model sizes
 
-        # 1. Load the pretrained weights
-        pretrained_checkpoint = torch.load(pretrained_path)
-        model.load_state_dict(pretrained_checkpoint, strict=False)
+    # 1. Load the pretrained weights
+    pretrained_checkpoint = lazy_load(pretrained_path)
+    model.load_state_dict(pretrained_checkpoint, strict=False)
 
-        # 2. Load the fine-tuned LoRA weights
-        lora_checkpoint = torch.load(lora_path)
-        model.load_state_dict(lora_checkpoint, strict=False)
+    # 2. Load the fine-tuned LoRA weights
+    lora_checkpoint = lazy_load(lora_path)
+    model.load_state_dict(lora_checkpoint, strict=False)
 
-        print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
+    print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
 
     model.eval()
     model = fabric.setup_module(model)
lit_llama/utils.py: 137 changes (137 additions, 0 deletions)
@@ -2,6 +2,9 @@
 
 import functools
 from pathlib import Path
+import pickle
+import warnings
+from io import BytesIO
 
 import torch
 import torch.utils._device
@@ -50,6 +53,7 @@ def __init__(self, device=None, dtype=None, quantization_mode=None):
             dtype: `torch.dtype` to work with
             quantization_mode: optional string, quantization mode to work with, default `None`.
                 Available modes: `llm.int8` bitsnbytes LLM.int8 quantization (only on GPU)
+                `qptq.int4`, `gptq.int8`: GPTQ pre-quantized models
         Example::
             with EmptyInitOnDevice("cuda", dtype=torch.bfloat16):
@@ -105,3 +109,136 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
         ):
             kwargs["dtype"] = self.dtype
         return func(*args, **kwargs)
+
+
+# this is taken from torchhacks https://github.com/lernapparat/torchhacks
+
+
+class NotYetLoadedTensor:
+    def __init__(self, metatensor, archiveinfo, storageinfo, rebuild_args):
+        self.metatensor = metatensor
+        self.archiveinfo = archiveinfo
+        self.storageinfo = storageinfo
+        self.rebuild_args = rebuild_args
+
+    @classmethod
+    def rebuild(
+        cls,
+        storage,
+        storage_offset,
+        size,
+        stride,
+        requires_grad,
+        backward_hooks,
+        metadata=None,
+        archiveinfo=None,
+    ):
+        rebuild_args = (
+            storage_offset,
+            size,
+            stride,
+            requires_grad,
+            backward_hooks,
+            metadata,
+        )
+        metatensor = torch._utils._rebuild_tensor_v2(
+            storage,
+            storage_offset,
+            size,
+            stride,
+            requires_grad,
+            backward_hooks,
+            metadata,
+        )
+        storageinfo = storage.archiveinfo
+        return NotYetLoadedTensor(metatensor, archiveinfo, storageinfo, rebuild_args)
+
+    def _load_tensor(self):
+        name, storage_cls, fn, device, size = self.storageinfo
+        dtype = self.metatensor.dtype
+
+        uts = (
+            self.archiveinfo.zipfile.get_storage_from_record(
+                f"data/{fn}",
+                size * torch._utils._element_size(dtype),
+                torch.UntypedStorage,
+            )
+            ._typed_storage()
+            ._untyped_storage
+        )
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            storage = torch.storage.TypedStorage(
+                wrap_storage=uts, dtype=self.metatensor.dtype, _internal=True
+            )
+        tensor = torch._utils._rebuild_tensor_v2(storage, *self.rebuild_args)
+        return tensor
+
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        if kwargs is None:
+            kwargs = {}
+        loaded_args = [
+            (a._load_tensor() if isinstance(a, NotYetLoadedTensor) else a) for a in args
+        ]
+        res = func(*loaded_args, **kwargs)
+        # gc.collect would be costly here, maybe do it optionally
+        return res
+
+    def __getattr__(self, name):
+        # properties
+        ## TODO: device, is_...??
+        ## TODO: mH, mT, H, T, data, imag, real
+        ## name ???
+        if name in {
+            "dtype",
+            "grad",
+            "grad_fn",
+            "layout",
+            "names",
+            "ndim",
+            "output_nr",
+            "requires_grad",
+            "retains_grad",
+            "shape",
+            "volatile",
+        }:
+            return getattr(self.metatensor, name)
+        if name in {"size"}:
+            return getattr(self.metatensor, name)
+        # materializing with contiguous is needed for quantization
+        if name in {"contiguous"}:
+            return getattr(self._load_tensor(), name)
+
+        raise AttributeError(f"{type(self)} does not have {name}")
+
+    def __repr__(self):
+        return f"NotYetLoadedTensor({repr(self.metatensor)})"
+
+
+class LazyLoadingUnpickler(pickle.Unpickler):
+    def __init__(self, file, zipfile):
+        super().__init__(file)
+        self.zipfile = zipfile
+
+    def find_class(self, module, name):
+        if module == "torch._utils" and name == "_rebuild_tensor_v2":
+            res = super().find_class(module, name)
+            return functools.partial(NotYetLoadedTensor.rebuild, archiveinfo=self)
+        return super().find_class(module, name)
+
+    def persistent_load(self, pid):
+        name, cls, fn, device, size = pid
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            s = torch.storage.TypedStorage(dtype=cls().dtype, device="meta")
+        s.archiveinfo = pid
+        return s
+
+
+def lazy_load(fn):
+    zf = torch._C.PyTorchFileReader(str(fn))
+    with BytesIO(zf.get_record("data.pkl")) as pkl:
+        mup = LazyLoadingUnpickler(pkl, zf)
+        sd = mup.load()
+    return sd
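In short, lazy_load opens the checkpoint's zip archive, unpickles only the metadata, and wraps every tensor in a NotYetLoadedTensor whose storage is fetched from the archive the first time a torch operation touches it. A minimal usage sketch (the checkpoint path is illustrative, and the keys depend on the model):

import torch

from lit_llama.utils import lazy_load

sd = lazy_load("checkpoints/lit-llama/7B/lit-llama.pth")  # illustrative path; no tensor data read yet

# Entries are NotYetLoadedTensor placeholders; shape and dtype come from the
# underlying meta tensor, so inspecting them costs no checkpoint I/O.
first_key = next(iter(sd))
print(first_key, sd[first_key].shape, sd[first_key].dtype)

# A torch operation on an entry goes through __torch_function__, which
# materializes just that tensor from its record in the archive:
w = torch.clone(sd[first_key])

# model.load_state_dict(sd) works the same way: parameters are read and
# copied one at a time instead of loading the whole checkpoint up front.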
tests/test_generate.py: 1 change (1 addition, 0 deletions)
@@ -71,6 +71,7 @@ def device(self):
     monkeypatch.setattr(generate.LLaMA, "from_name", model_mock)
     load_mock = Mock()
     monkeypatch.setattr(generate.torch, "load", load_mock)
+    monkeypatch.setattr(generate, "lazy_load", load_mock)
     tokenizer_mock = Mock()
     tokenizer_mock.return_value.encode.return_value = torch.tensor([[1, 2, 3]])
    tokenizer_mock.return_value.decode.return_value = "foo bar baz"
