Add MicroLlama training support #1457

Merged · 7 commits · Jun 4, 2024
47 changes: 24 additions & 23 deletions README.md
@@ -73,30 +73,31 @@ LitGPT has 🤯 **custom, from-scratch implementations** of [20+ LLMs](tutorials

| Model | Model size | Author | Reference |
|----|----|----|----|
| CodeGemma | 7B | Google | [Google Team, Google Deepmind](https://ai.google.dev/gemma/docs/codegemma) |
| Code Llama | 7B, 13B, 34B, 70B | Meta AI | [Rozière et al. 2023](https://arxiv.org/abs/2308.12950) |
| Danube2 | 1.8B | H2O.ai | [H2O.ai](https://h2o.ai/platform/danube-1-8b/) |
| Dolly | 3B, 7B, 12B | Databricks | [Conover et al. 2023](https://www.databricks.com/blog/2023/04/12/dolly-first-open-commercially-viable-instruction-tuned-llm) |
| Falcon | 7B, 40B, 180B | TII UAE | [TII 2023](https://falconllm.tii.ae) |
| FreeWilly2 (Stable Beluga 2) | 70B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stable-beluga-large-instruction-fine-tuned-models) |
| Function Calling Llama 2 | 7B | Trelis | [Trelis et al. 2023](https://huggingface.co/Trelis/Llama-2-7b-chat-hf-function-calling-v2) |
| Gemma | 2B, 7B | Google | [Google Team, Google Deepmind](https://storage.googleapis.com/deepmind-media/gemma/gemma-report.pdf) |
| Llama 2 | 7B, 13B, 70B | Meta AI | [Touvron et al. 2023](https://arxiv.org/abs/2307.09288) |
| Llama 3 | 8B, 70B | Meta AI | [Meta AI 2024](https://github.com/meta-llama/llama3) |
| LongChat | 7B, 13B | LMSYS | [LongChat Team 2023](https://lmsys.org/blog/2023-06-29-longchat/) |
| MicroLlama | 300M | Ken Wang | [MicroLlama repo](https://github.com/keeeeenw/MicroLlama) |
| Mixtral MoE | 8x7B | Mistral AI | [Mistral AI 2023](https://mistral.ai/news/mixtral-of-experts/) |
| Mistral | 7B | Mistral AI | [Mistral AI 2023](https://mistral.ai/news/announcing-mistral-7b/) |
| Nous-Hermes | 7B, 13B, 70B | NousResearch | [Org page](https://huggingface.co/NousResearch) |
| OpenLLaMA | 3B, 7B, 13B | OpenLM Research | [Geng & Liu 2023](https://github.com/openlm-research/open_llama) |
| Phi | 1.3B, 2.7B | Microsoft Research | [Li et al. 2023](https://arxiv.org/abs/2309.05463) |
| Platypus | 7B, 13B, 70B | Lee et al. | [Lee, Hunter, and Ruiz 2023](https://arxiv.org/abs/2308.07317) |
| Pythia | {14,31,70,160,410}M, {1,1.4,2.8,6.9,12}B | EleutherAI | [Biderman et al. 2023](https://arxiv.org/abs/2304.01373) |
| RedPajama-INCITE | 3B, 7B | Together | [Together 2023](https://together.ai/blog/redpajama-models-v1) |
| StableCode | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) |
| StableLM | 3B, 7B | Stability AI | [Stability AI 2023](https://github.com/Stability-AI/StableLM) |
| StableLM Zephyr | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) |
| TinyLlama | 1.1B | Zhang et al. | [Zhang et al. 2023](https://github.com/jzhang38/TinyLlama) |
| Vicuna | 7B, 13B, 33B | LMSYS | [Li et al. 2023](https://lmsys.org/blog/2023-03-30-vicuna/) |

**Tip**: You can list all available models by running the `litgpt download list` command.
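Beyond the CLI listing, the model registry that backs this table can be inspected directly from Python. The snippet below is a minimal sketch (not part of the PR), assuming litgpt is installed and using the `name_to_config` mapping defined in `litgpt/config.py` further down:

```python
# Sketch: confirm the MicroLlama entry is registered in litgpt's config registry.
from litgpt.config import name_to_config

print("micro-llama-300M" in name_to_config)                      # True once this PR is merged
print([n for n in name_to_config if "llama" in n.lower()][:5])   # a few of the Llama-family names
```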

115 changes: 115 additions & 0 deletions config_hub/pretrain/microllama.yaml
@@ -0,0 +1,115 @@

# The name of the model to pretrain. Choose from names in ``litgpt.config``. Mutually exclusive with
# ``model_config``. (type: Optional[str], default: null)
model_name: micro-llama-300M

# A ``litgpt.Config`` object to define the model architecture. Mutually exclusive with
# ``model_name``. (type: Optional[Config], default: null)
model_config:

# Directory in which to save checkpoints and logs. If running in a Lightning Studio Job, look for it in
# /teamspace/jobs/<job-name>/share. (type: <class 'Path'>, default: out/pretrain)
out_dir: out/pretrain/micro-llama

# The precision to use for pretraining. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
precision: bf16-mixed

# Optional path to a checkpoint directory to initialize the model from.
# Useful for continued pretraining. Mutually exclusive with ``resume``. (type: Optional[Path], default: null)
initial_checkpoint_dir:

# Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume
# from the latest checkpoint in ``out_dir``. (type: Union[bool, Path], default: False)
resume: false

# Data-related arguments. If not provided, the default is ``litgpt.data.TinyLlama``.
data: MicroLlama

# Training-related arguments. See ``litgpt.args.TrainArgs`` for details
train:

  # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)
  save_interval: 1000

  # Number of iterations between logging calls (type: int, default: 1)
  log_interval: 1

  # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 48)
  # Scale this number according to the number of GPUs and the memory size per GPU.
  # For example, we used 48 for 4 x 24 GB RTX 4090 GPUs.
  global_batch_size: 48

  # Number of samples per data-parallel rank (type: int, default: 12)
  # Scale this number according to the memory size per GPU.
  # For example, we used 12 for a single 24 GB RTX 4090.
  micro_batch_size: 12

  # Number of iterations with learning rate warmup active (type: int, default: 2000)
  lr_warmup_steps: 2000

  # Number of epochs to train on (type: Optional[int], default: null)
  epochs:

  # Total number of tokens to train on (type: Optional[int], default: 3000000000000)
  max_tokens: 3000000000000

  # Limits the number of optimizer steps to run. (type: Optional[int], default: null)
  max_steps:

  # Limits the length of samples. Off by default (type: Optional[int], default: null)
  max_seq_length: 2048

  # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: False)
  tie_embeddings:

  # (type: Optional[float], default: 1.0)
  max_norm: 1.0

  # (type: float, default: 4e-05)
  min_lr: 4.0e-05

# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details
eval:

  # Number of optimizer steps between evaluation calls (type: int, default: 1000)
  interval: 1000

  # Number of tokens to generate (type: Optional[int], default: null)
  max_new_tokens:

  # Number of iterations (type: int, default: 100)
  max_iters: 100

  # Whether to evaluate on the validation set at the beginning of the training
  initial_validation: false

# Optimizer-related arguments
optimizer:

  class_path: torch.optim.AdamW

  init_args:

    # (type: float, default: 0.001)
    lr: 4e-4

    # (type: float, default: 0.01)
    weight_decay: 0.1

    # (type: tuple, default: (0.9,0.999))
    betas:
      - 0.9
      - 0.95

# How many devices/GPUs to use. Uses all GPUs by default. (type: Union[int, str], default: auto)
devices: auto

# Optional path to the tokenizer dir that was used for preprocessing the dataset. Only some data
# modules require this. (type: Optional[Path], default: null)
tokenizer_dir: checkpoints/meta-llama/Llama-2-7b-hf

# The name of the logger to send metrics to. (type: Literal['wandb', 'tensorboard', 'csv'], default: tensorboard)
logger_name: tensorboard

# The random seed to use for reproducibility. (type: int, default: 42)
seed: 42
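The comments on `global_batch_size` and `micro_batch_size` imply a gradient-accumulation factor of `global_batch_size / (micro_batch_size * devices)`. The sketch below just loads the YAML and does that arithmetic; it assumes PyYAML is installed, that the file is saved at `config_hub/pretrain/microllama.yaml`, and a 4-GPU setup as in the comments above.

```python
# Minimal sketch: inspect the pretraining config and derive the gradient-accumulation factor.
import yaml

with open("config_hub/pretrain/microllama.yaml") as f:
    cfg = yaml.safe_load(f)

devices = 4  # assumption: 4 x 24 GB RTX 4090, as in the comments above
global_bs = cfg["train"]["global_batch_size"]   # 48 samples per optimizer step
micro_bs = cfg["train"]["micro_batch_size"]     # 12 samples per forward/backward pass per rank
grad_accum = global_bs // (micro_bs * devices)  # 48 // (12 * 4) = 1
print(f"gradient accumulation steps per rank: {grad_accum}")
```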
28 changes: 27 additions & 1 deletion litgpt/config.py
@@ -1548,7 +1548,7 @@ def norm_class(self) -> Type:
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        norm_class_name="RMSNorm",  # original TinyLlama uses FusedRMSNorm
        norm_eps=1e-5,
        mlp_class_name="LLaMAMLP",
        intermediate_size=5632,
@@ -1563,6 +1563,32 @@ def norm_class(self) -> Type:
    configs.append(copy)


############
# MicroLlama
############
micro_llama = [
    dict(
        name="micro-llama-300M",
        hf_config=dict(org="keeeeenw", name="MicroLlama"),
        block_size=2048,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=12,
        n_head=16,
        n_embd=1024,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        norm_class_name="RMSNorm",  # original TinyLlama and MicroLlama use FusedRMSNorm
        norm_eps=1e-5,
        mlp_class_name="LLaMAMLP",
        intermediate_size=5632,
        n_query_groups=4,
    )
]
configs.extend(micro_llama)


##########################
# Trelis Function Calling
##########################
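As a quick sanity check on the new entry, the model can be instantiated by name and its parameters counted. This is an illustrative sketch rather than part of the PR; it assumes litgpt is installed and reuses `Config` and `GPT` exactly as `litgpt/pretrain.py` imports them below.

```python
# Sketch: build the MicroLlama architecture from the new config entry and count its parameters.
from litgpt.model import GPT, Config

config = Config.from_name("micro-llama-300M")
model = GPT(config)
num_params = sum(p.numel() for p in model.parameters())
print(f"{num_params / 1e6:.0f}M parameters")  # expected to land in the ballpark of 300M
```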
2 changes: 2 additions & 0 deletions litgpt/data/__init__.py
@@ -15,6 +15,7 @@
from litgpt.data.tinyllama import TinyLlama
from litgpt.data.tinystories import TinyStories
from litgpt.data.openwebtext import OpenWebText
from litgpt.data.microllama import MicroLlama


__all__ = [
@@ -34,5 +35,6 @@
    "TextFiles",
    "TinyLlama",
    "TinyStories",
    "MicroLlama",
    "get_sft_collate_fn",
]
13 changes: 13 additions & 0 deletions litgpt/data/microllama.py
@@ -0,0 +1,13 @@
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
from dataclasses import dataclass
from pathlib import Path
from typing import Union

from litgpt.data import TinyLlama

@dataclass
class MicroLlama(TinyLlama):
    """The MicroLlama data module is composed of only SlimPajama data."""

    def __init__(self, data_path: Union[str, Path] = Path("data/"), seed: int = 42, num_workers: int = 8):
        super().__init__(data_path=data_path, seed=seed, num_workers=num_workers, use_starcoder=False)
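A short usage sketch for the new data module follows. It is illustrative only and assumes SlimPajama has already been preprocessed into `data/slimpajama/train` and `data/slimpajama/val`, and that the Llama 2 tokenizer has been downloaded to the directory used in the config above.

```python
# Sketch: wire up the MicroLlama data module outside the trainer.
from pathlib import Path

from litgpt import Tokenizer
from litgpt.data import MicroLlama

data = MicroLlama(data_path=Path("data/"))  # only SlimPajama; no starcoder/ directory required
tokenizer = Tokenizer(Path("checkpoints/meta-llama/Llama-2-7b-hf"))  # assumed to be downloaded
data.connect(tokenizer=tokenizer, batch_size=12, max_seq_length=2048)
data.prepare_data()                  # verifies data/slimpajama/train and data/slimpajama/val exist
train_loader = data.train_dataloader()
```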
2 changes: 2 additions & 0 deletions litgpt/data/prepare_slimpajama.py
@@ -11,6 +11,8 @@


class SlimPajamaDataRecipe(DataChunkRecipe):
    is_generator = True

    def __init__(self, tokenizer: Tokenizer, chunk_size: int):
        super().__init__(chunk_size)
        self.tokenizer = tokenizer
2 changes: 2 additions & 0 deletions litgpt/data/prepare_starcoder.py
@@ -18,6 +18,8 @@


class StarcoderDataRecipe(DataChunkRecipe):
    is_generator = True

    def __init__(self, tokenizer: Tokenizer, chunk_size: int):
        super().__init__(chunk_size)
        self.tokenizer = tokenizer
55 changes: 33 additions & 22 deletions litgpt/data/tinyllama.py
@@ -24,6 +24,8 @@ class TinyLlama(DataModule):
    """The random seed for shuffling the dataset."""
    num_workers: int = 8
    """How many DataLoader processes to use for loading."""
    use_starcoder: bool = True
    """Toggle for using Starcoder data."""

    batch_size: int = field(init=False, repr=False, default=1)
    seq_length: int = field(init=False, repr=False, default=2048)
@@ -32,7 +34,11 @@ def __post_init__(self):
        # Could be a remote path (s3://) or a local path
        self.slimpajama_train = str(self.data_path).rstrip("/") + "/slimpajama/train"
        self.slimpajama_val = str(self.data_path).rstrip("/") + "/slimpajama/val"
        self.required_paths = [self.slimpajama_train, self.slimpajama_val]

        if self.use_starcoder:
            self.starcoder_train = str(self.data_path).rstrip("/") + "/starcoder"
            self.required_paths += [self.starcoder_train]

    def connect(
        self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = None
@@ -41,7 +47,7 @@ def connect(
        self.seq_length = max_seq_length + 1  # Increase by one because we need the next token as well

    def prepare_data(self) -> None:
        for path in self.required_paths:
            if not path.startswith("s3://") and not Path(path).is_dir():
                raise FileNotFoundError(
                    "The data path for TinyLlama is expected to be the directory containing these subdirectories:"
@@ -52,28 +58,33 @@ def prepare_data(self) -> None:
    def train_dataloader(self) -> DataLoader:
        from litdata.streaming import CombinedStreamingDataset, StreamingDataLoader, StreamingDataset, TokensLoader

        slim_train_data = StreamingDataset(
            input_dir=self.slimpajama_train,
            item_loader=TokensLoader(block_size=self.seq_length),
            shuffle=True,
            drop_last=True,
        )
        train_data = slim_train_data

        if self.use_starcoder:
            train_datasets = [
                slim_train_data,
                StreamingDataset(
                    input_dir=self.starcoder_train,
                    item_loader=TokensLoader(block_size=self.seq_length),
                    shuffle=True,
                    drop_last=True,
                ),
            ]

            # Mix SlimPajama data and Starcoder data with these proportions:
            weights = (0.693584, 0.306416)
            train_data = CombinedStreamingDataset(
                datasets=train_datasets, seed=self.seed, weights=weights, iterate_over_all=False
            )

        train_dataloader = StreamingDataLoader(
            train_data, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, drop_last=True
        )
        return train_dataloader

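With the new flag in place, the MicroLlama data module added earlier reduces to `TinyLlama(use_starcoder=False)`; when Starcoder is enabled, the 0.693584/0.306416 weights give roughly a 69/31 SlimPajama-to-Starcoder sampling mix. A small sketch, assuming only that litgpt is installed, illustrates the equivalence:

```python
# Sketch: the new use_starcoder flag makes MicroLlama a thin wrapper over TinyLlama.
# No dataset files are touched until prepare_data() is called.
from litgpt.data import MicroLlama, TinyLlama

micro = MicroLlama(data_path="data/")
slim_only = TinyLlama(data_path="data/", use_starcoder=False)
print(micro.required_paths == slim_only.required_paths)  # True: only the SlimPajama train/val dirs
print(TinyLlama(data_path="data/").required_paths)       # default also includes data/starcoder
```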
2 changes: 1 addition & 1 deletion litgpt/pretrain.py
@@ -20,7 +20,7 @@
from litgpt import Tokenizer
from litgpt.args import EvalArgs, TrainArgs
from litgpt.config import name_to_config
from litgpt.data import DataModule, TinyLlama, MicroLlama
from litgpt.model import GPT, Block, CausalSelfAttention, Config, LLaMAMLP
from litgpt.utils import (
    CycleIterator,
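To tie the pieces together, here is a hedged sketch of kicking off pretraining programmatically, roughly the equivalent of running the config above through `litgpt pretrain`. The keyword names mirror the top-level keys of `config_hub/pretrain/microllama.yaml`; `litgpt.pretrain.setup` is assumed to be the entry point behind the CLI, and the data and tokenizer paths are assumptions that must already exist on disk.

```python
# Sketch: programmatic pretraining run for MicroLlama.
# Assumptions: litgpt.pretrain.setup is the function the CLI dispatches to, and its parameters
# match the top-level keys of the YAML config shown earlier.
from pathlib import Path

from litgpt.data import MicroLlama
from litgpt.pretrain import setup

setup(
    model_name="micro-llama-300M",
    data=MicroLlama(data_path=Path("data/")),
    tokenizer_dir=Path("checkpoints/meta-llama/Llama-2-7b-hf"),
    out_dir=Path("out/pretrain/micro-llama"),
    precision="bf16-mixed",
)
```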