-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
Issues/3/litcli
- Loading branch information
Showing
11 changed files
with
243 additions
and
168 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,3 +4,4 @@ lightning_logs/ | |
.ruff_cache/ | ||
.DS_Store | ||
tokenised.pt | ||
slurm* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# lightning.pytorch==2.2.0.post0 | ||
trainer: | ||
max_epochs: 10 | ||
model: | ||
vocab_size: 65 | ||
n_embd: 384 | ||
n_heads: 6 | ||
num_blocks: 3 | ||
batch_size: 64 | ||
block_size: 256 | ||
dropout: 0.2 | ||
lr: 0.0003 | ||
data: | ||
dataset_path: data/tinyshakespeare.txt | ||
batch_size: 64 | ||
train_test_split: 0.95 | ||
train_dataloader_workers: 10 | ||
val_dataloader_workers: 10 | ||
block_size: 256 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
# lightning.pytorch==2.2.0.post0 | ||
trainer: | ||
max_epochs: 10 | ||
accelerator: gpu | ||
num_nodes: 2 | ||
devices: 2 # devices per node | ||
strategy: ddp | ||
logger: | ||
class_path: pytorch_lightning.loggers.WandbLogger | ||
init_args: | ||
log_model: all | ||
project: litgpt | ||
model: | ||
vocab_size: 65 | ||
n_embd: 384 | ||
n_heads: 6 | ||
num_blocks: 3 | ||
batch_size: 64 | ||
block_size: 256 | ||
dropout: 0.2 | ||
lr: 0.0003 | ||
data: | ||
dataset_path: data/tinyshakespeare.txt | ||
batch_size: 64 | ||
train_test_split: 0.95 | ||
train_dataloader_workers: 10 | ||
val_dataloader_workers: 10 | ||
block_size: 256 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,31 +1,68 @@ | ||
#!/bin/bash | ||
############################ | ||
# -- Set the following! -- | ||
############################ | ||
# ---- Account Details ---- | ||
QOS=your_qos | ||
ACCOUNT=your_account | ||
# ---- Time Requested ---- | ||
hours=1 | ||
mins=0 | ||
seconds=0 | ||
# -- Resources Requested -- | ||
NODES=2 | ||
GPUS_PER_NODE=2 | ||
CPUS_PER_NODE=4 | ||
# ------ Conda Setup ------ | ||
CONDA_ENVS_DIR=/path/to/dir/for/envs/ | ||
CONDA_ENV_PATH="$CONDA_ENVS_DIR""litgpt/" | ||
CONDA_PACKAGES_DIR=/path/to/store/conda-packages/ | ||
############################ | ||
# If you want to use wanbd run | ||
# > wanbd login | ||
# to add creds to your .netrc | ||
############################ | ||
|
||
sed \ | ||
-e "s|\$QOS|$QOS|" \ | ||
-e "s|\$ACCOUNT|$ACCOUNT|" \ | ||
-e "s|\$hours|$hours|" \ | ||
-e "s|\$mins|$mins|" \ | ||
-e "s|\$seconds|$seconds|" \ | ||
-e "s|\$NODES|$NODES|" \ | ||
-e "s|\$GPUS_PER_NODE|$GPUS_PER_NODE|" \ | ||
-e "s|\$CPUS_PER_NODE|$CPUS_PER_NODE|g" \ | ||
-e "s|\$CONDA_ENVS_DIR|$CONDA_ENVS_DIR|" \ | ||
-e "s|\$CONDA_ENV_PATH|$CONDA_ENV_PATH|" \ | ||
-e "s|\$CONDA_PACKAGES_DIR|$CONDA_PACKAGES_DIR|" << 'EOF' | sbatch | ||
#!/bin/bash | ||
#SBATCH --qos $QOS | ||
#SBATCH --account $ACCOUNT | ||
#SBATCH --time $H:$M:$S | ||
#SBATCH --time $hours:$mins:$seconds | ||
#SBATCH --nodes $NODES | ||
#SBATCH --gpus-per-node $GPUS_PER_NODE | ||
#SBATCH --cpus-per-gpu 36 | ||
#SBATCH --cpus-per-gpu $CPUS_PER_NODE | ||
#SBATCH --ntasks-per-node $GPUS_PER_NODE | ||
# Enable shell debugging | ||
set -x | ||
# Load modules if present on cluster, e.g.: | ||
# module purge | ||
# module load torchvision | ||
# Load conda | ||
module purge | ||
module load Miniconda3/4.10.3 | ||
# Set up venv | ||
python -m venv --system-site-packages min-gpt-train | ||
source min-gpt-train/bin/activate | ||
# Setup conda | ||
export CONDA_PKGS_DIRS=$CONDA_PACKAGES_DIR | ||
eval "$(${EBROOTMINICONDA3}/bin/conda shell.bash hook)" | ||
# do pip installs | ||
pip install torchvision | ||
pip install lightning | ||
pip install wandb | ||
# Install env if required | ||
if [ ! -d "$CONDA_ENV_PATH" ]; then | ||
conda env create -f env.yml --prefix=$CONDA_ENV_PATH | ||
fi | ||
# init wandb | ||
wandb login $WANDB_API_KEY | ||
# Activate env | ||
conda activate ${CONDA_ENVS_DIR} | ||
# run train script | ||
srun train | ||
srun litgpt fit --config configs/slurm.yaml --trainer.devices $NODES --trainer.devices $GPUS_PER_NODE --data.train_dataloader_workers $CPUS_PER_NODE --data.val_dataloader_workers $CPUS_PER_NODE | ||
EOF |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# LitGPT | ||
# Minimal GPT Implementation in PyTorch Lightning | ||
# https://github.com/tomogwen/litgpt | ||
|
||
import torch | ||
from lightning.pytorch.cli import LightningCLI | ||
|
||
from litgpt.data import TinyShakespeareDataModule | ||
from litgpt.model import LitMinGPT | ||
|
||
|
||
def main(): | ||
torch.set_float32_matmul_precision("high") | ||
LightningCLI(LitMinGPT, TinyShakespeareDataModule) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
Oops, something went wrong.