Checkpointing #8

Merged 12 commits on Mar 11, 2024
2 changes: 2 additions & 0 deletions .gitignore
@@ -2,6 +2,8 @@ __pycache__/
wandb/
lightning_logs/
.ruff_cache/
checkpoints/*
!checkpoints/.gitkeep
.DS_Store
tokenised.pt
slurm*
16 changes: 10 additions & 6 deletions README.md
@@ -1,9 +1,7 @@
[![Pre-commit](https://github.com/tomogwen/LitGPT/actions/workflows/pre-commit.yml/badge.svg)](https://github.com/tomogwen/LitGPT/actions/workflows/pre-commit.yml)  [![Tests](https://github.com/tomogwen/LitGPT/actions/workflows/tests.yml/badge.svg)](https://github.com/tomogwen/LitGPT/actions/workflows/tests.yml)
# ⚡️ Lightning Minimal GPT

This repo contains my efforts to learn how to create a (better than research code, aspiring to production quality) deep learning repository. It trains an implementation of Karpathy's [minGPT](https://github.com/karpathy/minGPT) in PyTorch Lightning.

**MWE:** The code here grew from a minimal example of distributed training on a Slurm cluster. If you're interested in that, please see the [slurmformer branch](https://github.com/tomogwen/LitGPT/tree/slurmformer).
This repo trains a PyTorch implementation of [minGPT](https://github.com/karpathy/minGPT) using PyTorch Lightning. MinGPT is a minimal version of a [GPT language model](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) as taught in Karpathy's [zero-to-hero course](https://www.youtube.com/watch?v=kCc8FmEb1nY&ab_channel=AndrejKarpathy). This codebase is a 'playground' repository where I can practice writing (hopefully!) better deep learning code.

## 🔧 Installation

@@ -25,12 +23,18 @@ To train the model (whilst in the conda environment):
litgpt fit --config configs/default.yaml
```

You can override and extend the config file using the CLI. Arguments like `--optimizer` and `--lr_scheduler` accept Torch classes. For example:
You can override and extend the config file using the CLI. Arguments like `--optimizer` and `--lr_scheduler` accept Torch classes. Run `litgpt fit --help` or read the [LightningCLI docs](https://lightning.ai/docs/pytorch/stable/cli/lightning_cli.html) for all options.
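
For example, to train with the `Adam` optimiser:

```
litgpt fit --config configs/default.yaml --optimizer Adam
```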


### 👀 Logging

We provide config files for [Tensorboard](https://www.tensorflow.org/tensorboard) and [Weights & Biases](https://wandb.ai/) monitoring. Training with the default config (as above) uses Tensorboard. You can monitor training by running:

```
litgpt fit --config configs/default.yaml --optimizer Adam
tensorboard --logdir=checkpoints/
```

This uses the [LightningCLI](https://lightning.ai/docs/pytorch/stable/cli/lightning_cli_intermediate.html#). All options can be seen by running `litgpt fit --help`.
To log with Weights & Biases, use the `default_wandb.yaml` or `ddp.yaml` config files. The first time you do, you will need to authenticate with `wandb login`.
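
For example (a minimal sketch, assuming you already have a Weights & Biases account; `wandb login` will prompt for an API key):

```
wandb login
litgpt fit --config configs/default_wandb.yaml
```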

### 🚀 HPC

Empty file added checkpoints/.gitkeep
Empty file.
6 changes: 6 additions & 0 deletions configs/ddp.yaml
@@ -5,6 +5,12 @@ trainer:
  num_nodes: 2
  devices: 2 # devices per node
  strategy: ddp
  logger:
    class_path: WandbLogger
    init_args:
      log_model: all
      project: LitGPT
      save_dir: checkpoints/
model:
  vocab_size: 65
  n_embd: 384
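
For reference, a minimal sketch of launching training with this config on a Slurm cluster, assuming two GPUs on each of two nodes (matching `num_nodes` and `devices` above); the repo's actual submission scripts may differ:

```
#!/bin/bash
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=2
#SBATCH --gres=gpu:2

srun litgpt fit --config configs/ddp.yaml
```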
5 changes: 5 additions & 0 deletions configs/default.yaml
@@ -1,6 +1,11 @@
# lightning.pytorch==2.2.0.post0
trainer:
  max_epochs: 10
  logger:
    class_path: TensorBoardLogger
    init_args:
      save_dir: checkpoints/
      name: tensorboard
model:
  vocab_size: 65
  n_embd: 384
25 changes: 25 additions & 0 deletions configs/default_wandb.yaml
@@ -0,0 +1,25 @@
# lightning.pytorch==2.2.0.post0
trainer:
  max_epochs: 10
  logger:
    class_path: WandbLogger
    init_args:
      log_model: all
      project: LitGPT
      save_dir: checkpoints/
model:
  vocab_size: 65
  n_embd: 384
  n_heads: 6
  num_blocks: 3
  batch_size: 64
  block_size: 256
  dropout: 0.2
  lr: 0.0003
data:
  dataset_path: data/tinyshakespeare.txt
  batch_size: 64
  train_test_split: 0.95
  train_dataloader_workers: 10
  val_dataloader_workers: 10
  block_size: 256
3 changes: 3 additions & 0 deletions env.yml
@@ -6,9 +6,12 @@ dependencies:
  - hatch
  - pip
  - pre-commit
  - protobuf
  - pytest
  - python=3.12.2
  - pytorch
  - pip: # pip install <x>
      - "lightning[pytorch-extra]"
      - tensorboard
      - wandb
      - -e .
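
A minimal sketch of creating the environment from this file, assuming conda is installed (the environment name comes from the `name:` field of `env.yml`, which is not shown in this hunk):

```
conda env create -f env.yml
conda activate <name-from-env.yml>
```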
2 changes: 2 additions & 0 deletions src/litgpt/data.py
@@ -51,11 +51,13 @@ def prepare_data(self, tokenised_path=None):
            text = f.read()
        chars = sorted(list(set(text)))

        # tokeniser
        stoi = {ch: i for i, ch in enumerate(chars)} # string to int

        def encode(s):
            return [stoi[c] for c in s] # encoder: maps strings to list of ints

        # tokenise data
        data = torch.tensor(encode(text), dtype=torch.long)
        torch.save(data, self.hparams.tokenised_path)

2 changes: 1 addition & 1 deletion src/litgpt/main.py
@@ -11,7 +11,7 @@

def main():
    torch.set_float32_matmul_precision("high")
    LightningCLI(LitMinGPT, TinyShakespeareDataModule)
    LightningCLI(LitMinGPT, TinyShakespeareDataModule, save_config_callback=None)


if __name__ == "__main__":
7 changes: 5 additions & 2 deletions src/litgpt/model.py
@@ -179,15 +179,18 @@ def __init__(
        self.save_hyperparameters()
        self.decoder = TransformerDecoder(self.hparams)

    def forward(self, inputs, target):
        return self.decoder(inputs, target)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits, loss = self.decoder(x, y)
        logits, loss = self(x, y)
        self.log("train_loss", loss, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits, loss = self.decoder(x, y)
        logits, loss = self(x, y)
        self.log("val_loss", loss, sync_dist=True)
        return loss
