Add MicroLlama training support #1457

Merged · 7 commits · Jun 4, 2024

Changes from 1 commit
115 changes: 115 additions & 0 deletions config_hub/pretrain/microllama.yaml
@@ -0,0 +1,115 @@

# The name of the model to pretrain. Choose from names in ``litgpt.config``. Mutually exclusive with
# ``model_config``. (type: Optional[str], default: null)
model_name: micro-llama-300M

# A ``litgpt.Config`` object to define the model architecture. Mutually exclusive with
# ``model_name``. (type: Optional[Config], default: null)
model_config:

# Directory in which to save checkpoints and logs. If running in a Lightning Studio Job, look for it in
# /teamspace/jobs/<job-name>/share. (type: <class 'Path'>, default: out/pretrain)
out_dir: out/pretrain/micro-llama

# The precision to use for pretraining. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
precision: bf16-mixed

# Optional path to a checkpoint directory to initialize the model from.
# Useful for continued pretraining. Mutually exclusive with ``resume``. (type: Optional[Path], default: null)
initial_checkpoint_dir:

# Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume
# from the latest checkpoint in ``out_dir``. (type: Union[bool, Path], default: False)
resume: false

# Data-related arguments. If not provided, the default is ``litgpt.data.TinyLlama``.
data: MicroLlama

# Training-related arguments. See ``litgpt.args.TrainArgs`` for details
train:

  # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)
  save_interval: 1000

  # Number of iterations between logging calls (type: int, default: 1)
  log_interval: 1

  # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 48)
  # Scale this number according to the number of GPUs and the memory size per GPU.
  # For example, we used 48 for 4 x 24 GB RTX 4090 GPUs.
  global_batch_size: 48

  # Number of samples per data-parallel rank (type: int, default: 12)
  # Scale this number according to the memory size per GPU.
  # For example, we used 12 for a 24 GB RTX 4090.
  micro_batch_size: 12

  # Number of iterations with learning rate warmup active (type: int, default: 2000)
  lr_warmup_steps: 2000

  # Number of epochs to train on (type: Optional[int], default: null)
  epochs:

  # Total number of tokens to train on (type: Optional[int], default: 3000000000000)
  max_tokens: 3000000000000

  # Limits the number of optimizer steps to run. (type: Optional[int], default: null)
  max_steps:

  # Limits the length of samples. Off by default (type: Optional[int], default: null)
  max_seq_length: 2048

  # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: False)
  tie_embeddings:

  # (type: Optional[float], default: 1.0)
  max_norm: 1.0

  # (type: float, default: 4e-05)
  min_lr: 4.0e-05

# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details
eval:

  # Number of optimizer steps between evaluation calls (type: int, default: 1000)
  interval: 1000

  # Number of tokens to generate (type: Optional[int], default: null)
  max_new_tokens:

  # Number of iterations (type: int, default: 100)
  max_iters: 100

  # Whether to evaluate on the validation set at the beginning of training
  initial_validation: false

# Optimizer-related arguments
optimizer:

  class_path: torch.optim.AdamW

  init_args:

    # (type: float, default: 0.001)
    lr: 4e-4

    # (type: float, default: 0.01)
    weight_decay: 0.1

    # (type: tuple, default: (0.9,0.999))
    betas:
    - 0.9
    - 0.95

# How many devices/GPUs to use. Uses all GPUs by default. (type: Union[int, str], default: auto)
devices: auto

# Optional path to the tokenizer dir that was used for preprocessing the dataset. Only some data
# module require this. (type: Optional[Path], default: null)
tokenizer_dir: checkpoints/meta-llama/Llama-2-7b-hf

# The name of the logger to send metrics to. (type: Literal['wandb', 'tensorboard', 'csv'], default: tensorboard)
logger_name: tensorboard

# The random seed to use for reproducibility. (type: int, default: 42)
seed: 42
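
For a quick sanity check of the hyperparameters above, here is a minimal sketch (illustrative only; `model` is a stand-in module, not the MicroLlama model) of the optimizer the `optimizer:` block describes, plus the gradient-accumulation arithmetic implied by the batch-size settings, assuming the usual relation global_batch_size = devices * micro_batch_size * accumulation steps:

import torch

# AdamW exactly as configured in the `optimizer:` section above
model = torch.nn.Linear(1024, 32000, bias=False)  # stand-in for the real model
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=4e-4,
    weight_decay=0.1,
    betas=(0.9, 0.95),
)

# Batch-size bookkeeping for the values above
devices = 4                # e.g. 4 x 24 GB RTX 4090
global_batch_size = 48
micro_batch_size = 12
accumulation_steps = global_batch_size // (devices * micro_batch_size)
print(accumulation_steps)  # 1 -> each optimizer step uses a single micro-batch per GPU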
31 changes: 31 additions & 0 deletions litgpt/config.py
@@ -1563,6 +1563,37 @@ def norm_class(self) -> Type:
        configs.append(copy)


############
# MicroLlama
############
micro_llama = [
    dict(
        name="micro-llama-300M{}",
        hf_config=dict(org="keeeeenw", name="MicroLlama{}"),
        block_size=2048,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=12,
        n_head=16,
        n_embd=1024,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        norm_class_name="RMSNorm",  # the original TinyLlama and MicroLlama use FusedRMSNorm
        norm_eps=1e-5,
        mlp_class_name="LLaMAMLP",
        intermediate_size=5632,
        n_query_groups=4,
    )
]
for c in micro_llama:
    for kind, hf_postfix in [("", "")]:
        copy = deepcopy(c)
        copy["name"] = c["name"].format(kind)
        copy["hf_config"]["name"] = c["hf_config"]["name"].format(hf_postfix)
        configs.append(copy)
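
Once registered, the new entry should be reachable through the normal config lookup; a minimal sketch (not part of this diff):

from litgpt.model import GPT, Config

config = Config.from_name("micro-llama-300M")
print(config.n_layer, config.n_embd, config.n_query_groups)  # 12 1024 4

model = GPT(config)  # roughly 300M-parameter LLaMA-style decoder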


##########################
# Trelis Function Calling
##########################
Expand Down
2 changes: 2 additions & 0 deletions litgpt/data/__init__.py
@@ -15,6 +15,7 @@
from litgpt.data.tinyllama import TinyLlama
from litgpt.data.tinystories import TinyStories
from litgpt.data.openwebtext import OpenWebText
from litgpt.data.microllama import MicroLlama


__all__ = [
@@ -34,5 +35,6 @@
"TextFiles",
"TinyLlama",
"TinyStories",
"MicroLlama"
"get_sft_collate_fn",
]
80 changes: 80 additions & 0 deletions litgpt/data/microllama.py
@@ -0,0 +1,80 @@
# Licensed under the Apache License 2.0, see LICENSE file.
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, Union

from torch.utils.data import DataLoader

from litgpt import Tokenizer
from litgpt.data import DataModule


@dataclass
class MicroLlama(DataModule):
"""The MicroLlama data module is based on TinyLlama but composed of only SlimPajama data.

Provides training and validation streaming dataloaders that return batches of tokens.
"""

data_path: Union[str, Path] = Path("data/")
"""The path to the data directory, containing folder 'slimpajama'
which is the output of the preprocessing step done in advance. See the `tutorial/pretrain_tinyllama.md`
for TinyLlama for instructions. There is no need to process starcoder because MicroLlama only uses
SlimPajama data. The path can also be a remote path (e.g., s3://)."""
seed: int = 42
"""The random seed for shuffling the dataset."""
num_workers: int = 8
"""How many DataLoader processes to use for loading."""

batch_size: int = field(init=False, repr=False, default=1)
seq_length: int = field(init=False, repr=False, default=2048)

def __post_init__(self):
# Could be a remote path (s3://) or a local path
self.slimpajama_train = str(self.data_path).rstrip("/") + "/slimpajama/train"
self.slimpajama_val = str(self.data_path).rstrip("/") + "/slimpajama/val"

def connect(
self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = None
) -> None:
self.batch_size = batch_size
self.seq_length = max_seq_length + 1 # Increase by one because we need the next token as well

def prepare_data(self) -> None:
for path in (self.slimpajama_train, self.slimpajama_val):
if not path.startswith("s3://") and not Path(path).is_dir():
raise FileNotFoundError(
"The data path for TinyLlama is expected to be the directory containing these subdirectories:"
f" `slimpajama/train`, `slimpajama/val`. The directory {path} does not exist."
" Set it via `--data.data_path=...`"
)

def train_dataloader(self) -> DataLoader:
from litdata.streaming import CombinedStreamingDataset, StreamingDataLoader, StreamingDataset, TokensLoader

train_dataset = StreamingDataset(
input_dir=self.slimpajama_train,
item_loader=TokensLoader(block_size=self.seq_length),
shuffle=True,
drop_last=True,
)

train_dataloader = StreamingDataLoader(
train_dataset, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, drop_last=True
)
return train_dataloader

def val_dataloader(self) -> DataLoader:
from litdata.streaming import StreamingDataset, TokensLoader

val_dataset = StreamingDataset(
input_dir=self.slimpajama_val,
item_loader=TokensLoader(block_size=self.seq_length),
shuffle=True,
# Consider setting to False, but we would lose some samples due to truncation when world size > 1
drop_last=True,
)
val_dataloader = DataLoader(
val_dataset, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, drop_last=True
)
return val_dataloader
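
As a usage illustration (not part of the diff), the module would typically be driven like this; the tokenizer directory is just an example and must match whatever was used for preprocessing:

from pathlib import Path

from litgpt import Tokenizer
from litgpt.data import MicroLlama

data = MicroLlama(data_path="data/", num_workers=8)
tokenizer = Tokenizer(Path("checkpoints/meta-llama/Llama-2-7b-hf"))  # example path
data.connect(tokenizer=tokenizer, batch_size=12, max_seq_length=2048)
data.prepare_data()                       # checks that slimpajama/train and slimpajama/val exist
train_loader = data.train_dataloader()
batch = next(iter(train_loader))          # token ids of shape (batch_size, max_seq_length + 1)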
2 changes: 2 additions & 0 deletions litgpt/data/prepare_slimpajama.py
@@ -11,6 +11,8 @@


class SlimPajamaDataRecipe(DataChunkRecipe):
    is_generator = True  # prepare_item yields its results, so litdata needs the recipe marked as a generator

    def __init__(self, tokenizer: Tokenizer, chunk_size: int):
        super().__init__(chunk_size)
        self.tokenizer = tokenizer
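
For context on the `is_generator` flag: litdata distinguishes recipes whose `prepare_item` returns a single value from those that yield many. Below is a rough, hypothetical sketch of a generator-style recipe; the hook names (`prepare_structure`, `prepare_item`) are assumed from the existing SlimPajama recipe and may differ across litdata versions:

import json
from pathlib import Path

from litdata.processing.data_processor import DataChunkRecipe


class JsonlTextRecipe(DataChunkRecipe):  # hypothetical example recipe
    is_generator = True  # prepare_item yields instead of returning a single item

    def __init__(self, tokenizer, chunk_size: int):
        super().__init__(chunk_size)
        self.tokenizer = tokenizer

    def prepare_structure(self, input_dir: str):
        # One work item per input shard
        return [str(p) for p in Path(input_dir).glob("*.jsonl")]

    def prepare_item(self, filepath: str):
        with open(filepath) as f:
            for line in f:
                text = json.loads(line)["text"]
                # Yielding here is exactly why `is_generator = True` is needed
                yield self.tokenizer.encode(text, eos=True)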
8 changes: 6 additions & 2 deletions litgpt/pretrain.py
@@ -20,7 +20,7 @@
from litgpt import Tokenizer
from litgpt.args import EvalArgs, TrainArgs
from litgpt.config import name_to_config
-from litgpt.data import DataModule, TinyLlama
+from litgpt.data import DataModule, TinyLlama, MicroLlama
from litgpt.model import GPT, Block, CausalSelfAttention, Config, LLaMAMLP
from litgpt.utils import (
    CycleIterator,
@@ -105,7 +105,11 @@ def setup(
        quit()

    hparams = capture_hparams()
-   data = TinyLlama() if data is None else data
+   data = data
+   if data is None or data == 'TinyLlama':
+       data = TinyLlama()
+   elif data == 'MicroLlama':
+       data = MicroLlama()

    config = Config.from_name(model_name) if model_config is None else model_config
    precision = precision or get_default_supported_precision(training=True)
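
Taken together, the changes above let MicroLlama be selected either from the YAML config or programmatically. A hedged sketch of a direct call (the keyword names mirror the YAML fields; the exact `setup` signature is assumed from litgpt's pretraining entry point rather than shown in this diff):

from pathlib import Path

from litgpt.data import MicroLlama
from litgpt.pretrain import setup

setup(
    model_name="micro-llama-300M",
    data=MicroLlama(data_path="data/"),
    tokenizer_dir=Path("checkpoints/meta-llama/Llama-2-7b-hf"),
    out_dir=Path("out/pretrain/micro-llama"),
    precision="bf16-mixed",
)

Equivalently, the shipped config should be usable from the CLI, e.g. something like `litgpt pretrain --config config_hub/pretrain/microllama.yaml`, though the exact invocation depends on the litgpt version.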