Merge pull request #7 from /issues/3/litcli
Issues/3/litcli
tomogwen authored Mar 7, 2024
2 parents 04075a6 + 054c509 commit 6b9d637
Showing 11 changed files with 243 additions and 168 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -4,3 +4,4 @@ lightning_logs/
.ruff_cache/
.DS_Store
tokenised.pt
slurm*
46 changes: 21 additions & 25 deletions README.md
@@ -5,41 +5,37 @@ This repo contains my efforts to learn how to create a (better than research cod

**MWE:** The code here grew out of a minimal example of multi-node, multi-GPU training with PyTorch Lightning on a slurm cluster - if you're interested in that, please see the [slurmformer branch](https://github.com/tomogwen/LitGPT/tree/slurmformer).

## Goal
## 🔧 Installation

A non-exhaustive list of skills I'd like to learn or practice with this repo is below.

Machine learning engineering:
- [ ] Dealing with hyperparams nicely
  - Config files + CLI
  - Use an args object or pass around many hparams?
- [ ] Dealing with different accelerators nicely
  - Should run easily on CPU, MPS, or (multi-)GPU.
To install dependencies and activate the conda environment:
```
conda env create -f env.yml
conda activate litgpt
```
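
If the environment already exists and `env.yml` changes later, it can be updated in place with standard conda tooling (nothing repo-specific):
```
conda env update -f env.yml --prune
```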

Software development:
- [ ] Doc strings and type hints
- [X] Setting up github actions.
- [X] Writing tests.
- [X] Setting up pre-commit checks.
- [X] 'Packagify'-ing code.
- [X] Having good repo structure.
If developing, install pre-commit checks:
```
pre-commit install
```
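
The hooks can also be run against the whole repo before pushing (standard pre-commit usage):
```
pre-commit run --all-files
```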

## Installation
## 📈 Training

To install dependencies and activate the conda environment:
To train the model (whilst in the conda environment):
```
> conda env create -f env.yml
> conda activate litgpt
litgpt fit --config configs/default.yaml
```

If developing, install pre-commit checks:
You can override and extend the config file using the CLI. Arguments like `--optimizer` and `--lr_scheduler` accept Torch classes. For example:
```
> pre-commit install
litgpt fit --config configs/default.yaml --optimizer Adam
```

## Usage
This uses the [LightningCLI](https://lightning.ai/docs/pytorch/stable/cli/lightning_cli_intermediate.html#). All options can be seen by running `litgpt fit --help`.
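
The CLI accepts nested arguments for class-valued options, and the resolved config can be printed without launching a run. For example (values here are illustrative):
```
# set the optimizer and one of its init args (example values)
litgpt fit --config configs/default.yaml --optimizer AdamW --optimizer.lr 0.0001
# dump the fully-resolved config that would be used
litgpt fit --config configs/default.yaml --print_config
```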

### 🚀 HPC

To train the model locally (whilst in the conda environment):
A script for [DDP training](https://pytorch.org/tutorials/beginner/ddp_series_theory.html) on Slurm-managed HPC is provided. Update the [shell script](scripts/slurm.sh) where required, make it executable (with `chmod +x scripts/slurm.sh`), and run it:
```
> train
scripts/slurm.sh
```
This script will generate and submit a slurm job using `sbatch`. Generating the script dynamically allows resource requests to be set once at the top of the file, then passed to both slurm (to allocate resources) and Lightning (to utilise them).
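
Once submitted, the job can be inspected with standard Slurm commands (the log filename below assumes Slurm's default `slurm-<jobid>.out` naming):
```
squeue -u $USER            # check the job's position/state in the queue
tail -f slurm-<jobid>.out  # follow training output; <jobid> is a placeholder
scancel <jobid>            # cancel the job if needed
```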
19 changes: 19 additions & 0 deletions configs/default.yaml
@@ -0,0 +1,19 @@
# lightning.pytorch==2.2.0.post0
trainer:
  max_epochs: 10
model:
  vocab_size: 65
  n_embd: 384
  n_heads: 6
  num_blocks: 3
  batch_size: 64
  block_size: 256
  dropout: 0.2
  lr: 0.0003
data:
  dataset_path: data/tinyshakespeare.txt
  batch_size: 64
  train_test_split: 0.95
  train_dataloader_workers: 10
  val_dataloader_workers: 10
  block_size: 256
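
Every key in this file maps to a `--<section>.<key>` flag, so single values can be overridden at launch without editing the YAML. For example (the path and epoch count below are illustrative only):
```
# illustrative overrides: data/my_corpus.txt is a placeholder path
litgpt fit --config configs/default.yaml --trainer.max_epochs 20 --data.dataset_path data/my_corpus.txt
```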
28 changes: 28 additions & 0 deletions configs/slurm.yaml
@@ -0,0 +1,28 @@
# lightning.pytorch==2.2.0.post0
trainer:
  max_epochs: 10
  accelerator: gpu
  num_nodes: 2
  devices: 2 # devices per node
  strategy: ddp
  logger:
    class_path: pytorch_lightning.loggers.WandbLogger
    init_args:
      log_model: all
      project: litgpt
model:
  vocab_size: 65
  n_embd: 384
  n_heads: 6
  num_blocks: 3
  batch_size: 64
  block_size: 256
  dropout: 0.2
  lr: 0.0003
data:
  dataset_path: data/tinyshakespeare.txt
  batch_size: 64
  train_test_split: 0.95
  train_dataloader_workers: 10
  val_dataloader_workers: 10
  block_size: 256
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -9,7 +9,7 @@ description = "Minimal GPT Implementation in PyTorch Lightning"
readme = "README.md"

[project.scripts]
train = "litgpt.train:main"
litgpt = "litgpt.main:main"

[project.urls]
Source = "https://github.com/tomogwen/litgpt"
67 changes: 52 additions & 15 deletions scripts/slurm.sh
100644 → 100755
@@ -1,31 +1,68 @@
#!/bin/bash
############################
# -- Set the following! --
############################
# ---- Account Details ----
QOS=your_qos
ACCOUNT=your_account
# ---- Time Requested ----
hours=1
mins=0
seconds=0
# -- Resources Requested --
NODES=2
GPUS_PER_NODE=2
CPUS_PER_NODE=4
# ------ Conda Setup ------
CONDA_ENVS_DIR=/path/to/dir/for/envs/
CONDA_ENV_PATH="$CONDA_ENVS_DIR""litgpt/"
CONDA_PACKAGES_DIR=/path/to/store/conda-packages/
############################
# If you want to use wandb, run
# > wandb login
# to add creds to your .netrc
############################

sed \
-e "s|\$QOS|$QOS|" \
-e "s|\$ACCOUNT|$ACCOUNT|" \
-e "s|\$hours|$hours|" \
-e "s|\$mins|$mins|" \
-e "s|\$seconds|$seconds|" \
-e "s|\$NODES|$NODES|" \
-e "s|\$GPUS_PER_NODE|$GPUS_PER_NODE|" \
-e "s|\$CPUS_PER_NODE|$CPUS_PER_NODE|g" \
-e "s|\$CONDA_ENVS_DIR|$CONDA_ENVS_DIR|" \
-e "s|\$CONDA_ENV_PATH|$CONDA_ENV_PATH|" \
-e "s|\$CONDA_PACKAGES_DIR|$CONDA_PACKAGES_DIR|" << 'EOF' | sbatch
#!/bin/bash
#SBATCH --qos $QOS
#SBATCH --account $ACCOUNT
#SBATCH --time $H:$M:$S
#SBATCH --time $hours:$mins:$seconds
#SBATCH --nodes $NODES
#SBATCH --gpus-per-node $GPUS_PER_NODE
#SBATCH --cpus-per-gpu 36
#SBATCH --cpus-per-gpu $CPUS_PER_NODE
#SBATCH --ntasks-per-node $GPUS_PER_NODE
# Enable shell debugging
set -x
# Load modules if present on cluster, e.g.:
# module purge
# module load torchvision
# Load conda
module purge
module load Miniconda3/4.10.3
# Set up venv
python -m venv --system-site-packages min-gpt-train
source min-gpt-train/bin/activate
# Setup conda
export CONDA_PKGS_DIRS=$CONDA_PACKAGES_DIR
eval "$(${EBROOTMINICONDA3}/bin/conda shell.bash hook)"
# do pip installs
pip install torchvision
pip install lightning
pip install wandb
# Install env if required
if [ ! -d "$CONDA_ENV_PATH" ]; then
conda env create -f env.yml --prefix=$CONDA_ENV_PATH
fi
# init wandb
wandb login $WANDB_API_KEY
# Activate env
conda activate ${CONDA_ENV_PATH}
# run train script
srun train
srun litgpt fit --config configs/slurm.yaml --trainer.num_nodes $NODES --trainer.devices $GPUS_PER_NODE --data.train_dataloader_workers $CPUS_PER_NODE --data.val_dataloader_workers $CPUS_PER_NODE
EOF
47 changes: 23 additions & 24 deletions src/litgpt/data.py
@@ -31,28 +31,23 @@ def __getitem__(self, index):
class TinyShakespeareDataModule(L.LightningDataModule):
    def __init__(
        self,
        dataset_path,
        train_dataloader_workers=10,
        val_dataloader_workers=10,
        batch_size=32,
        block_size=256,
        train_test_split=0.95,
        dataset_path: str = "data/tinyshakespeare.txt",
        batch_size: int = 64,
        train_test_split: float = 0.95,
        train_dataloader_workers: int = 10,
        val_dataloader_workers: int = 10,
        block_size: int = 256,
    ):
        super().__init__()
        self.dataset_path = os.path.abspath(dataset_path)
        self.data_dir = os.path.dirname(self.dataset_path)
        self.tokenised_path = os.path.join(self.data_dir, "tokenised.pt")
        self.batch_size = batch_size
        self.block_size = block_size
        self.train_test_split = train_test_split
        self.train_dataloader_workers = train_dataloader_workers
        self.val_dataloader_workers = val_dataloader_workers
        self.save_hyperparameters()
        data_dir = os.path.dirname(self.hparams.dataset_path)
        self.hparams.tokenised_path = os.path.join(data_dir, "tokenised.pt")

    def prepare_data(self, tokenised_path=None):
        # runs once, called from main process
        # tokenise data here

        with open(self.dataset_path, "r", encoding="utf-8") as f:
        with open(self.hparams.dataset_path, "r", encoding="utf-8") as f:
            text = f.read()
        chars = sorted(list(set(text)))

@@ -62,30 +57,34 @@ def encode(s):
            return [stoi[c] for c in s]  # encoder: maps strings to list of ints

        data = torch.tensor(encode(text), dtype=torch.long)
        torch.save(data, self.tokenised_path)
        torch.save(data, self.hparams.tokenised_path)

    def setup(self, stage):
        # runs on every GPU
        # stage is e.g., "fit", "test"
        data = torch.load(self.tokenised_path)
        data = torch.load(self.hparams.tokenised_path)

        n = int(self.train_test_split * len(data))
        self.train_data = TinyShakespeareDataSet(data[:n], block_size=self.block_size)
        self.val_data = TinyShakespeareDataSet(data[n:], block_size=self.block_size)
        n = int(self.hparams.train_test_split * len(data))
        self.train_data = TinyShakespeareDataSet(
            data[:n], block_size=self.hparams.block_size
        )
        self.val_data = TinyShakespeareDataSet(
            data[n:], block_size=self.hparams.block_size
        )

    def train_dataloader(self):
        # lightning should auto-add DistributedSampler for these dataloaders when required
        return DataLoader(
            self.train_data,
            batch_size=self.batch_size,
            num_workers=self.train_dataloader_workers,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.train_dataloader_workers,
            persistent_workers=True,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_data,
            batch_size=self.batch_size,
            num_workers=self.val_dataloader_workers,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.val_dataloader_workers,
            persistent_workers=True,
        )
18 changes: 18 additions & 0 deletions src/litgpt/main.py
@@ -0,0 +1,18 @@
# LitGPT
# Minimal GPT Implementation in PyTorch Lightning
# https://github.com/tomogwen/litgpt

import torch
from lightning.pytorch.cli import LightningCLI

from litgpt.data import TinyShakespeareDataModule
from litgpt.model import LitMinGPT


def main():
    torch.set_float32_matmul_precision("high")
    LightningCLI(LitMinGPT, TinyShakespeareDataModule)


if __name__ == "__main__":
    main()
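
Since `main()` delegates everything to `LightningCLI`, the installed `litgpt` entry point exposes the standard Lightning subcommands, not just `fit`. A sketch (the checkpoint path is a placeholder):
```
litgpt fit --config configs/default.yaml
litgpt validate --config configs/default.yaml --ckpt_path path/to/checkpoint.ckpt
litgpt fit --help    # list all options for a subcommand
```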