Relative paths and model configuration #28

Merged: 26 commits, Mar 27, 2023

1 change: 1 addition & 0 deletions .gitignore
@@ -5,6 +5,7 @@ __pycache__
 
 # data
 data
+checkpoints
 !data/shakespeare/prepare.py
 
 # downloaded by scripts/compare.py
5 changes: 1 addition & 4 deletions README.md
@@ -74,10 +74,7 @@ python scripts/convert_checkpoint.py \
 You can now run inference:
 
 ```bash
-python scripts/generate.py \
-    --prompt "Hello, my name is" \
-    --checkpoint_path checkpoints/lit-llama/7B/state_dict.pt \
-    --tokenizer_path checkpoints/lit-llama/tokenizer.model
+python scripts/generate.py --prompt "Hello, my name is"
 ```
 
 This will run using the 7B model and will require roughly 26 GB of GPU memory (A100 GPU).
18 changes: 13 additions & 5 deletions generate.py
@@ -1,11 +1,13 @@
 import os
 import sys
 import time
-import torch
+from typing import Optional
 
 import lightning as L
+import torch
 
-from model import LLaMA, LLaMAConfig
+from model import LLaMA
 from quantization.bnb import quantize as quantize_model
 from tokenizer import Tokenizer
 
@@ -66,8 +68,9 @@ def main(
     # compilation fails as it does not support torch.complex64 for RoPE
     # compile: bool = False,
     accelerator: str = "auto",
-    checkpoint_path: str = "/srv/data/checkpoints/llama/converted_nano/7B/state_dict.pth",
-    tokenizer_path: str = "/srv/data/checkpoints/llama/converted_nano/tokenizer.model",
+    checkpoint_path: Optional[str] = None,
+    tokenizer_path: Optional[str] = None,
+    model_size: str = "7B",
     quantize: bool = False,
 ):
     """Generates text samples based on a pre-trained LLaMA model and tokenizer.
@@ -86,6 +89,11 @@ def main(
         tokenizer_path: The tokenizer path to load.
         quantize: Whether to quantize the model using the `LLM.int8()` method
     """
+    if not checkpoint_path:
+        checkpoint_path = f"./checkpoints/lit-llama/{model_size}/state_dict.pth"
+    if not tokenizer_path:
+        tokenizer_path = "./checkpoints/lit-llama/tokenizer.model"
+
     assert os.path.isfile(checkpoint_path)
     assert os.path.isfile(tokenizer_path)
 
@@ -94,14 +102,14 @@ def main(
     if quantize:
         print("Running quantization. This may take a minute ...")
         # TODO: Initializing the model directly on the device does not work with quantization
-        model = LLaMA(LLaMAConfig())
+        model = LLaMA.from_name(model_size)
         # The output layer can be sensitive to quantization, we keep it in default precision
         model = quantize_model(model, skip=("lm_head", "output"))
         checkpoint = torch.load(checkpoint_path)
         model.load_state_dict(checkpoint)
     else:
         with fabric.device:
-            model = LLaMA(LLaMAConfig())
+            model = LLaMA.from_name(model_size)
             checkpoint = torch.load(checkpoint_path)
             model.load_state_dict(checkpoint)
 
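For readers skimming the diff: `checkpoint_path` and `tokenizer_path` are now optional and, when omitted, are derived from `model_size`. A minimal sketch of the new fallback logic, mirroring the hunk above (the helper function name is illustrative, not part of the PR):

```python
from typing import Optional


def resolve_default_paths(
    checkpoint_path: Optional[str],
    tokenizer_path: Optional[str],
    model_size: str = "7B",
) -> tuple:
    """Mirror the fallback added to generate.main(): paths are relative to the CWD."""
    if not checkpoint_path:
        checkpoint_path = f"./checkpoints/lit-llama/{model_size}/state_dict.pth"
    if not tokenizer_path:
        tokenizer_path = "./checkpoints/lit-llama/tokenizer.model"
    return checkpoint_path, tokenizer_path


# Leaving both paths unset selects the 13B checkpoint under ./checkpoints/lit-llama:
print(resolve_default_paths(None, None, model_size="13B"))
# ('./checkpoints/lit-llama/13B/state_dict.pth', './checkpoints/lit-llama/tokenizer.model')
```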
18 changes: 17 additions & 1 deletion model.py
@@ -140,14 +140,26 @@ def forward(self, x):
         return x
 
 
+llama_configs = {
+    "7B": dict(n_layer=32, n_head=32, n_embd=4096),
+    "13B": dict(n_layer=40, n_head=40, n_embd=5120),
+    "30B": dict(n_layer=60, n_head=52, n_embd=6656),
+    "65B": dict(n_layer=80, n_head=64, n_embd=8192),
+}
+
+
 @dataclass
 class LLaMAConfig:
-    block_size: int = 4096  # 7B
+    block_size: int = 4096
     vocab_size: int = 32000
     n_layer: int = 32
     n_head: int = 32
     n_embd: int = 4096
 
+    @classmethod
+    def from_name(cls, name: str):
+        return cls(**llama_configs[name])
+
 
 class LLaMA(nn.Module):
     def __init__(self, config):
@@ -200,3 +212,7 @@ def step(self, idx, targets):
         logits = self(idx)
         loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
         return loss
+
+    @classmethod
+    def from_name(cls, name: str):
+        return cls(LLaMAConfig.from_name(name))
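A quick sketch of how the new `from_name` helpers are intended to be used (illustrative usage, not part of the diff; only the config is instantiated here, since building the full model allocates real weights):

```python
from model import LLaMAConfig

# Look up a preset by name; an unknown name raises KeyError from the llama_configs dict.
config = LLaMAConfig.from_name("13B")
print(config.n_layer, config.n_head, config.n_embd)  # 40 40 5120

# LLaMA.from_name("13B") performs the same lookup and returns LLaMA(config),
# which is what generate.py now does instead of hard-coding LLaMA(LLaMAConfig()).
```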
6 changes: 3 additions & 3 deletions scripts/convert_checkpoint.py
@@ -55,9 +55,9 @@ def convert_state_dict(state_dict):
 
 def meta_weights_for_nano_model(
     *,
-    output_dir: Path,
-    ckpt_dir: Path = Path("/srv/data/checkpoints/llama/raw"),
-    tokenizer_path: Path = Path("/srv/data/checkpoints/llama/raw/tokenizer.model"),
+    output_dir: Path = Path("checkpoints/lit-llama"),
+    ckpt_dir: Path = Path("checkpoints/llama/"),
+    tokenizer_path: Path = Path("checkpoints/llama/tokenizer.model"),
     model_size: str = "7B",
 ):
     output_dir = output_dir / model_size
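For orientation, a sketch of how the new relative defaults compose; the directory layout in the comments is an assumption based on the defaults in this diff, not something the diff itself spells out:

```python
from pathlib import Path

# Defaults taken from meta_weights_for_nano_model above.
output_dir = Path("checkpoints/lit-llama")
ckpt_dir = Path("checkpoints/llama")
model_size = "7B"

# The original weights and tokenizer are expected under ckpt_dir ...
print(ckpt_dir / "tokenizer.model")  # checkpoints/llama/tokenizer.model
# ... while the converted output goes under output_dir / model_size,
# which lines up with the checkpoint path generate.py now falls back to:
print(output_dir / model_size / "state_dict.pth")  # checkpoints/lit-llama/7B/state_dict.pth
```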
2 changes: 1 addition & 1 deletion scripts/prepare_shakespeare.py
@@ -28,7 +28,7 @@
 
 
 def prepare(
-    tokenizer_path: str = "/srv/data/checkpoints/llama/converted_meta/tokenizer.model",
+    tokenizer_path: str = "checkpoints/llama/tokenizer.model",
    destination_path: str = "data/shakespeare",
 ):
     os.makedirs(destination_path, exist_ok=True)
2 changes: 1 addition & 1 deletion train.py
@@ -44,7 +44,7 @@ def main():
 
     train_data, val_data = load_datasets()
 
-    config = LLaMAConfig
+    config = LLaMAConfig.from_name("7B")
     config.block_size = block_size
 
     with fabric.device:
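Since `LLaMAConfig.from_name` returns a fresh config instance, train.py can now override fields such as `block_size` on that instance, instead of assigning attributes on the `LLaMAConfig` class itself as the old `config = LLaMAConfig` line did. A small illustrative sketch (values come from the `7B` preset shown earlier):

```python
from model import LLaMAConfig

config = LLaMAConfig.from_name("7B")
print(config.n_layer, config.n_head, config.n_embd)  # 32 32 4096

# Overriding block_size only touches this instance; a later from_name("7B")
# call still returns the unmodified preset default.
config.block_size = 1024
print(config.block_size)                       # 1024
print(LLaMAConfig.from_name("7B").block_size)  # 4096
```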