
Commit

Merge pull request #11 from /issues/10/typehints
issues/10/typehints
tomogwen authored Jul 11, 2024
2 parents 7a94ffa + d3b990d commit 59fc365
Showing 4 changed files with 231 additions and 113 deletions.
10 changes: 6 additions & 4 deletions env.yml
@@ -11,7 +11,9 @@ dependencies:
   - python=3.12.2
   - pytorch
   - pip: # pip install <x>
-    - "lightning[pytorch-extra]"
-    - tensorboard
-    - wandb
-    - -e .
+    - beartype
+    - jaxtyping
+    - "lightning[pytorch-extra]"
+    - tensorboard
+    - wandb
+    - -e .
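The two new dependencies, beartype and jaxtyping, are commonly paired for runtime checking of tensor-valued annotations. A minimal sketch of how they could back up the type hints introduced in this commit (the decorator, function name, and shapes below are illustrative assumptions, not part of the diff):

# Illustrative only: runtime-checked tensor annotations using the newly added deps.
import torch
from beartype import beartype
from jaxtyping import Int, jaxtyped


@jaxtyped(typechecker=beartype)
def make_blocks(
    tokens: Int[torch.Tensor, "num_tokens"], block_size: int = 256
) -> Int[torch.Tensor, "num_blocks block_size"]:
    # Stack overlapping windows of block_size tokens, mirroring TinyShakespeareDataSet.
    return torch.stack(
        [tokens[i : i + block_size] for i in range(len(tokens) - block_size)]
    )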
74 changes: 47 additions & 27 deletions src/litgpt/data.py
@@ -1,34 +1,46 @@
 # Torch Dataset and LightningDataModule wrapper for Tiny Shakespeare dataset
 
 import os
 
 import lightning as L
 import torch
+from torch import Tensor
 from torch.utils.data import DataLoader, Dataset
 
 
 class TinyShakespeareDataSet(Dataset):
-    def __init__(self, raw_text, block_size=256):
+    """
+    Torch Dataset for the Tiny Shakespeare dataset.
+    """
+
+    def __init__(self, raw_text: Tensor, block_size: int = 256):
+        """
+        Args:
+            raw_text: Tensor of tokens, processed in TinyShakespeareDataModule.
+            block_size: number of tokens in each training sample.
+        """
         super().__init__()
-        self.raw_text = raw_text
-        self.xs = torch.stack(
+        self.raw_text: Tensor = raw_text
+        self.xs: Tensor = torch.stack(
             [raw_text[i : i + block_size] for i in range(len(raw_text) - block_size)]
         )
-        self.ys = torch.stack(
+        self.ys: Tensor = torch.stack(
             [
                 raw_text[i + 1 : i + block_size + 1]
                 for i in range(len(raw_text) - block_size)
             ]
         )
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.xs)
 
-    def __getitem__(self, index):
+    def __getitem__(self, index) -> int:
         return self.xs[index], self.ys[index]
 
 
 class TinyShakespeareDataModule(L.LightningDataModule):
+    """
+    Lightning DataModule for the Tiny Shakespeare dataset.
+    """
+
     def __init__(
         self,
         dataset_path: str = "data/tinyshakespeare.txt",
@@ -38,52 +50,60 @@ def __init__(
         val_dataloader_workers: int = 10,
         block_size: int = 256,
     ):
+        """
+        Args:
+            dataset_path: path to a text file (typically) containing the complete works of Shakespeare.
+            batch_size: number of datapoints given on each sample from the dataset.
+            train_test_split: number in [0, 1] representing what proportion of the data is used as training data.
+            train_dataloader_workers: passed to the train dataloader num_workers arg.
+            val_dataloader_workers: passed to the val dataloader num_workers arg.
+            block_size: number of tokens in each training sample.
+        """
         super().__init__()
         self.save_hyperparameters()
-        data_dir = os.path.dirname(self.hparams.dataset_path)
-        self.hparams.tokenised_path = os.path.join(data_dir, "tokenised.pt")
+        data_dir: str = os.path.dirname(self.hparams.dataset_path)
+        self.hparams.tokenised_path: str = os.path.join(data_dir, "tokenised.pt")
 
     def prepare_data(self, tokenised_path=None):
-        # runs once, called from main process
-        # tokenise data here
-
+        """Loads data from txt file, tokenises it, then saves as a Tensor. Runs once from parent process."""
         with open(self.hparams.dataset_path, "r", encoding="utf-8") as f:
-            text = f.read()
-        chars = sorted(list(set(text)))
+            text: str = f.read()
+        chars: list[str] = sorted(list(set(text)))
 
         # tokeniser
-        stoi = {ch: i for i, ch in enumerate(chars)}  # string to int
+        stoi: dict = {ch: i for i, ch in enumerate(chars)}  # string to int
 
-        def encode(s):
+        def encode(s: str) -> list[int]:
            return [stoi[c] for c in s]  # encoder: maps strings to list of ints
 
         # tokenise data
-        data = torch.tensor(encode(text), dtype=torch.long)
+        data: torch.tensor = torch.tensor(encode(text), dtype=torch.long)
         torch.save(data, self.hparams.tokenised_path)
 
     def setup(self, stage):
-        # runs on every GPU
-        # stage is e.g., "fit", "test"
-        data = torch.load(self.hparams.tokenised_path)
+        """Loads Tensor of Tokens and splits into train/val datasets. Runs on each GPU if using DDP."""
+        data: torch.tensor = torch.load(self.hparams.tokenised_path)
 
-        n = int(self.hparams.train_test_split * len(data))
-        self.train_data = TinyShakespeareDataSet(
+        n: int = int(self.hparams.train_test_split * len(data))
+        self.train_data: TinyShakespeareDataSet = TinyShakespeareDataSet(
             data[:n], block_size=self.hparams.block_size
         )
-        self.val_data = TinyShakespeareDataSet(
+        self.val_data: TinyShakespeareDataSet = TinyShakespeareDataSet(
             data[n:], block_size=self.hparams.block_size
         )
 
-    def train_dataloader(self):
-        # lightning should auto-add DistributedSampler for these dataloaders when required
+    def train_dataloader(self) -> DataLoader:
+        """Returns a dataloader for the training dataset."""
+        # lightning auto-adds DistributedSampler for these dataloaders when required
         return DataLoader(
             self.train_data,
             batch_size=self.hparams.batch_size,
             num_workers=self.hparams.train_dataloader_workers,
             persistent_workers=True,
         )
 
-    def val_dataloader(self):
+    def val_dataloader(self) -> DataLoader:
+        """Returns a dataloader for the validation dataset."""
         return DataLoader(
             self.val_data,
             batch_size=self.hparams.batch_size,
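For context, a typical way to exercise the now-typed DataModule outside the LightningCLI flow is sketched below; the import path is assumed from the src/litgpt layout, and the paths and batch size are illustrative, not taken from the commit:

# Illustrative usage sketch; assumes data/tinyshakespeare.txt exists on disk.
from litgpt.data import TinyShakespeareDataModule

dm = TinyShakespeareDataModule(dataset_path="data/tinyshakespeare.txt", batch_size=64)
dm.prepare_data()  # tokenises the text file and saves data/tokenised.pt
dm.setup("fit")    # builds the train/val TinyShakespeareDataSet splits
xs, ys = next(iter(dm.train_dataloader()))
print(xs.shape, ys.shape)  # expected: (batch_size, block_size) for each tensor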
2 changes: 1 addition & 1 deletion src/litgpt/main.py
@@ -1,4 +1,3 @@
-# LitGPT
 # Minimal GPT Implementation in PyTorch Lightning
 # https://github.com/tomogwen/litgpt
 
@@ -10,6 +9,7 @@
 
 
 def main():
+    """LightningCLI entry point."""
     torch.set_float32_matmul_precision("high")
     LightningCLI(LitMinGPT, TinyShakespeareDataModule, save_config_callback=None)
 
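Since main() just wraps LightningCLI around LitMinGPT and TinyShakespeareDataModule, the DataModule's typed __init__ arguments become --data.* CLI options automatically. A hypothetical invocation, following standard LightningCLI conventions (the module path, script name, and values are assumptions, not taken from the repo):

# Hypothetical CLI-style invocation driven from Python for illustration.
import sys
from litgpt.main import main

sys.argv = [
    "litgpt",
    "fit",                    # LightningCLI subcommand
    "--data.batch_size=64",   # maps to TinyShakespeareDataModule.__init__
    "--data.block_size=256",
    "--trainer.max_epochs=1",
]
main()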