Alpaca finetuning with LoRA 1/n (Lightning-AI#52)
Co-authored-by: Luca Antiga <luca@lightning.ai>
awaelchli and lantiga authored Mar 29, 2023
1 parent 330de5a commit f808df1
Showing 7 changed files with 368 additions and 2 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -7,6 +7,7 @@ build
# data
data
checkpoints
+out
!data/shakespeare/prepare.py

# downloaded by our tests
20 changes: 20 additions & 0 deletions README.md
@@ -115,6 +115,26 @@ See `python generate.py --help` for more options.

&nbsp;

## Finetune the model

We provide a simple training script, `finetune.py`, that instruction-tunes a pretrained model on the [Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset using [LoRA](https://arxiv.org/abs/2106.09685), so only a small set of low-rank adapter weights needs to be trained.
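
The script builds Alpaca-style prompts via `generate_prompt` from `scripts/prepare_alpaca.py`; for an instruction without an input field, the template looks roughly like this (a sketch of the standard Alpaca format, not quoted verbatim from the repository):

```text
Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Response:
```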

1. Download the data and generate an instruction-tuning dataset:

   ```bash
   python scripts/prepare_alpaca.py
   ```
2. Run the finetuning script:

   ```bash
   python finetune.py
   ```

This assumes you have already downloaded the pretrained weights as described above.
Finetuning requires a machine with at least 4 GPUs, each with 24 GB of memory.
Coming soon: LoRA + quantization for training on a single GPU!


## Get involved!

We're in a quest towards fully open source AI.
199 changes: 199 additions & 0 deletions finetune.py
@@ -0,0 +1,199 @@
"""
Instruction-tuning with LoRA on the Alpaca dataset.
"""
import os
import time

import lightning as L
import numpy as np
import torch
from lightning.fabric.strategies import DeepSpeedStrategy

from generate import generate
from lit_llama.lora import mark_only_lora_as_trainable, with_lora
from lit_llama.model import LLaMA, LLaMAConfig
from lit_llama.tokenizer import Tokenizer
from scripts.prepare_alpaca import generate_prompt

out_dir = "out"
eval_interval = 100
eval_iters = 100
log_interval = 1

# Hyperparameters
learning_rate = 2e-5
batch_size = 32
micro_batch_size = 4
gradient_accumulation_steps = batch_size // micro_batch_size
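# each optimizer step accumulates micro_batch_size * gradient_accumulation_steps = 4 * 8 = 32
# samples per device; with the 4 data-parallel devices used below, that is a global batch of 128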

# TODO: Limit to 3 epochs
max_iters = 100000000 # 50000 * 3 // 4 // batch_size
weight_decay = 0.0
block_size = 256

# TODO: LR scheduling
warmup_steps = 100

ds_config = {
    "gradient_accumulation_steps": gradient_accumulation_steps,
    "train_micro_batch_size_per_gpu": micro_batch_size,
    "bf16": {
        "enabled": True
    },
    "gradient_clipping": 1,
    "zero_optimization": {
        "stage": 2,
        "offload_param": {
            "device": "none"
        },
        "offload_optimizer": {
            "device": "none"
        },
        "allgather_partitions": True,
        "allgather_bucket_size": 5e8,
        "contiguous_gradients": True
    }
}
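# Note: ZeRO stage 2 shards optimizer state and gradients across the data-parallel ranks
# while keeping a full copy of the parameters on each GPU; "device": "none" disables CPU offload.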


def main() -> None:
    fabric = L.Fabric(
        accelerator="cuda",
        devices=4,
        strategy=DeepSpeedStrategy(config=ds_config)
    )
    fabric.launch()
    fabric.seed_everything(1337 + fabric.global_rank)

    if fabric.global_rank == 0:
        os.makedirs(out_dir, exist_ok=True)

    train_data, val_data = load_datasets()

    config = LLaMAConfig.from_name("7B")
    config.block_size = block_size

    with fabric.device, with_lora(r=8, alpha=32, dropout=0.1, enabled=True):
        model = LLaMA(config)

    checkpoint = torch.load("checkpoints/lit-llama/7B/state_dict.pth")

    # strict=False because the checkpoint does not contain the (newly initialized) LoRA weights
    model.load_state_dict(checkpoint, strict=False)
    mark_only_lora_as_trainable(model)

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    model, optimizer = fabric.setup(model, optimizer)
    # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, max_iters, last_epoch=-1)

    train(fabric, model, optimizer, train_data, val_data)


def train(
    fabric: L.Fabric,
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    train_data: np.ndarray,
    val_data: np.ndarray,
) -> None:
    """The training loop.

    Loosely based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT.
    """
    for iter_num in range(max_iters):

        # evaluate the loss on train/val sets and write checkpoints
        if iter_num % eval_interval == 0:
            val_loss = validate(fabric, model, val_data)
            fabric.print(f"step {iter_num}: val loss {val_loss:.4f}")
            # TODO: Save with Fabric
            # print(f"Saving checkpoint to {out_dir}")
            # checkpoint = {"model": model, "optimizer": optimizer, "iter": iter, "val_loss": val_loss}
            # fabric.save(os.path.join(out_dir, f"iter-{iter_num:06d}-ckpt.pt"), checkpoint)
            fabric.barrier()

        t0 = time.time()

        input_ids, targets = get_batch(fabric, train_data)

        logits = model(input_ids)
        loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)

        fabric.backward(loss)

        if (iter_num + 1) % gradient_accumulation_steps == 0:
            optimizer.step()
            # scheduler.step()
            optimizer.zero_grad()

        dt = time.time() - t0
        if iter_num % log_interval == 0:
            fabric.print(f"iter {iter_num}: loss {loss.item():.4f}, time: {dt*1000:.2f}ms")


def generate_response(model, instruction):
    tokenizer = Tokenizer("checkpoints/lit-llama/tokenizer.model")
    sample = {"instruction": instruction, "input": ""}
    prompt = generate_prompt(sample)
    # do not append <eos>: the model should continue the prompt, not treat it as finished
    encoded = tokenizer.encode(prompt, bos=True, eos=False)
    encoded = encoded[None, :]  # add batch dimension
    encoded = encoded.to(model.device)

    output = generate(
        model,
        idx=encoded,
        max_seq_length=block_size,
        max_new_tokens=100,
    )
    output = tokenizer.decode(output[0].cpu())
    return output.split("### Response:")[1].strip()


@torch.no_grad()
def validate(fabric: L.Fabric, model: torch.nn.Module, val_data: np.ndarray) -> torch.Tensor:
    fabric.print("Validating ...")
    model.eval()
    losses = torch.zeros(eval_iters)
    for k in range(eval_iters):
        input_ids, targets = get_batch(fabric, val_data)
        logits = model(input_ids)
        loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        losses[k] = loss.item()
    out = losses.mean()

    # produce an example:
    instruction = "Recommend a movie for me to watch during the weekend and explain the reason."
    fabric.print(instruction)
    fabric.print(generate_response(model, instruction))

    model.train()
    return out


def get_batch(fabric: L.Fabric, data: list, pad_id: int = 0):
    ix = torch.randint(len(data), (micro_batch_size,))

    def pad(x):
        # TODO: optimize this to pad to the next multiple of 8 or so?
        n = block_size - len(x)
        return torch.cat((x, torch.full((n,), pad_id, dtype=x.dtype)))

    def shift_right(x):
        # drop the first token so that labels[t] becomes the target for input_ids[t]
        return x[1:]

    x = torch.stack([pad(data[i]["input_ids"]) for i in ix]).type(torch.int64)
    y = torch.stack([pad(shift_right(data[i]["labels"])) for i in ix]).type(torch.int64)
    x, y = fabric.to_device((x.pin_memory(), y.pin_memory()))
    return x, y


def load_datasets(data_dir: str = "data/alpaca"):
    train_data = torch.load(os.path.join(data_dir, "train.pt"))
    val_data = torch.load(os.path.join(data_dir, "test.pt"))
    return train_data, val_data


if __name__ == "__main__":
    torch.set_float32_matmul_precision("high")
    main()
5 changes: 4 additions & 1 deletion lit_llama/lora.py
@@ -214,9 +214,12 @@ def __init__(self, config: llama.LLaMAConfig, rope_cache: torch.Tensor) -> None:


@contextmanager
-def with_lora(r, alpha: float, dropout: float):
+def with_lora(r, alpha, dropout, enabled: bool = True):
    """A context manager under which you can instantiate the model with LoRA.
    """
+    if not enabled:
+        yield
+        return

    CausalSelfAttention.lora_config = LoRAConfig(r=r, alpha=alpha, dropout=dropout)
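For context, LoRA freezes the pretrained weight matrix and learns a low-rank update ΔW = B·A scaled by alpha / r, so only A and B receive gradients (this is what `mark_only_lora_as_trainable` exploits). A minimal sketch of such a layer, for illustration only — the actual implementation lives in `lit_llama/lora.py`:

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """A frozen linear layer plus a trainable low-rank update."""

    def __init__(self, in_features, out_features, r=8, alpha=32.0, dropout=0.1):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias=False)
        self.linear.weight.requires_grad = False              # frozen pretrained weight
        self.lora_A = nn.Parameter(torch.randn(r, in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))  # zero init: no change at start
        self.scaling = alpha / r
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # y = x W^T + (alpha / r) * dropout(x) A^T B^T
        return self.linear(x) + self.scaling * (self.dropout(x) @ self.lora_A.T @ self.lora_B.T)
```

With r=8 on a 4096x4096 projection, the update adds only 2 * 8 * 4096 = 65,536 trainable parameters instead of 16.7M, which is why the full 7B model can be finetuned on modest hardware.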
13 changes: 12 additions & 1 deletion lit_llama/tokenizer.py
@@ -20,13 +20,24 @@ def vocab_size(self) -> int:
        return self.processor.vocab_size()

    def encode(
-        self, string: str, bos: bool = True, eos: bool = False, device: Optional[torch.device] = None
+        self,
+        string: str,
+        bos: bool = True,
+        eos: bool = False,
+        max_length: int = -1,
+        pad: bool = False,
+        device: Optional[torch.device] = None
    ) -> torch.Tensor:
        tokens = self.processor.encode(string)
        if bos:
            tokens = [self.bos_id] + tokens
        if eos:
            tokens = tokens + [self.eos_id]
+        if max_length > 0:
+            tokens = tokens[:max_length]
+        if pad and len(tokens) < max_length:
+            tokens += [self.pad_id] * (max_length - len(tokens))

        return torch.tensor(tokens, dtype=torch.int, device=device)

    def decode(self, tokens: torch.Tensor) -> str:
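A usage sketch of the extended `encode` signature (hypothetical snippet, assuming the tokenizer checkpoint downloaded during setup):

```python
from lit_llama.tokenizer import Tokenizer

tokenizer = Tokenizer("checkpoints/lit-llama/tokenizer.model")

# truncate to at most 256 tokens and pad shorter sequences up to that length
tokens = tokenizer.encode("Recommend a movie.", bos=True, eos=True, max_length=256, pad=True)
print(tokens.shape)  # torch.Size([256])
```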
(2 more changed files not shown)
