Add MicroLlama training support #1457

Merged · 7 commits · Jun 4, 2024
Changes from 3 commits
52 changes: 26 additions & 26 deletions README.md
@@ -71,32 +71,32 @@ LitGPT has 🤯 **custom, from-scratch implementations** of [20+ LLMs](tutorials

#### All models

| Model | Model size | Author | Reference |
|----|----|----|----|
| CodeGemma | 7B | Google | [Google Team, Google Deepmind](https://ai.google.dev/gemma/docs/codegemma) |
| Code Llama | 7B, 13B, 34B, 70B | Meta AI | [Rozière et al. 2023](https://arxiv.org/abs/2308.12950) |
| Danube2 | 1.8B | H2O.ai | [H2O.ai](https://h2o.ai/platform/danube-1-8b/) |
| Dolly | 3B, 7B, 12B | Databricks | [Conover et al. 2023](https://www.databricks.com/blog/2023/04/12/dolly-first-open-commercially-viable-instruction-tuned-llm) |
| Falcon | 7B, 40B, 180B | TII UAE | [TII 2023](https://falconllm.tii.ae) |
| FreeWilly2 (Stable Beluga 2) | 70B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stable-beluga-large-instruction-fine-tuned-models) |
| Function Calling Llama 2 | 7B | Trelis | [Trelis et al. 2023](https://huggingface.co/Trelis/Llama-2-7b-chat-hf-function-calling-v2) |
| Gemma | 2B, 7B | Google | [Google Team, Google Deepmind](https://storage.googleapis.com/deepmind-media/gemma/gemma-report.pdf) |
| Llama 2 | 7B, 13B, 70B | Meta AI | [Touvron et al. 2023](https://arxiv.org/abs/2307.09288) |
| Llama 3 | 8B, 70B | Meta AI | [Meta AI 2024](https://github.com/meta-llama/llama3) |
| LongChat | 7B, 13B | LMSYS | [LongChat Team 2023](https://lmsys.org/blog/2023-06-29-longchat/) |
| Mixtral MoE | 8x7B | Mistral AI | [Mistral AI 2023](https://mistral.ai/news/mixtral-of-experts/) |
| Mistral | 7B | Mistral AI | [Mistral AI 2023](https://mistral.ai/news/announcing-mistral-7b/) |
| Nous-Hermes | 7B, 13B, 70B | NousResearch | [Org page](https://huggingface.co/NousResearch) |
| OpenLLaMA | 3B, 7B, 13B | OpenLM Research | [Geng & Liu 2023](https://github.com/openlm-research/open_llama) |
| Phi | 1.3B, 2.7B | Microsoft Research | [Li et al. 2023](https://arxiv.org/abs/2309.05463) |
| Platypus | 7B, 13B, 70B | Lee et al. | [Lee, Hunter, and Ruiz 2023](https://arxiv.org/abs/2308.07317) |
| Pythia | {14,31,70,160,410}M, {1,1.4,2.8,6.9,12}B | EleutherAI | [Biderman et al. 2023](https://arxiv.org/abs/2304.01373) |
| RedPajama-INCITE | 3B, 7B | Together | [Together 2023](https://together.ai/blog/redpajama-models-v1) |
| StableCode | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) |
| StableLM | 3B, 7B | Stability AI | [Stability AI 2023](https://github.com/Stability-AI/StableLM) |
| StableLM Zephyr | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) |
| TinyLlama | 1.1B | Zhang et al. | [Zhang et al. 2023](https://github.com/jzhang38/TinyLlama) |
| Vicuna | 7B, 13B, 33B | LMSYS | [Li et al. 2023](https://lmsys.org/blog/2023-03-30-vicuna/)
| Model | Model size | Reference |
|----------------------------------------------|-----------------------------------------|--------------------------------------------------------------------------------------------------------------------------|
| CodeGemma by Google | 7B | [Google Team, Google Deepmind](https://ai.google.dev/gemma/docs/codegemma) |
| Code Llama by Meta AI | 7B, 13B, 34B, 70B | [Rozière et al. 2023](https://arxiv.org/abs/2308.12950) |
| Danube2 by H2O.ai                            | 1.8B                                     | [H2O.ai](https://h2o.ai/platform/danube-1-8b/)                                                                               |
| Dolly by Databricks | 3B, 7B, 12B | [Conover et al. 2023](https://www.databricks.com/blog/2023/04/12/dolly-first-open-commercially-viable-instruction-tuned-llm) |
| Falcon by TII UAE | 7B, 40B, 180B | [TII 2023](https://falconllm.tii.ae) |
| FreeWilly2 (Stable Beluga 2) by Stability AI | 70B | [Stability AI 2023](https://stability.ai/blog/stable-beluga-large-instruction-fine-tuned-models) |
| Function Calling Llama 2 by Trelis | 7B | [Trelis et al. 2023](https://huggingface.co/Trelis/Llama-2-7b-chat-hf-function-calling-v2) |
| Gemma by Google | 2B, 7B | [Google Team, Google Deepmind](https://storage.googleapis.com/deepmind-media/gemma/gemma-report.pdf) |
| Llama 2 by Meta AI | 7B, 13B, 70B | [Touvron et al. 2023](https://arxiv.org/abs/2307.09288) |
| Llama 3 by Meta AI                           | 8B, 70B                                  | [Meta AI 2024](https://github.com/meta-llama/llama3)                                                                         |
| LongChat by LMSYS | 7B, 13B | [LongChat Team 2023](https://lmsys.org/blog/2023-06-29-longchat/) |
| MicroLlama by Ken Wang                       | 300M                                     | [MicroLlama repository](https://github.com/keeeeenw/MicroLlama)                                                              |
| Mistral and Mixtral by Mistral AI            | 7B, 8x7B                                 | [Mistral website](https://mistral.ai/)                                                                                       |
| Nous-Hermes by NousResearch | 7B, 13B, 70B | [Org page](https://huggingface.co/NousResearch) |
| OpenLLaMA by OpenLM Research                 | 3B, 7B, 13B                              | [Geng & Liu 2023](https://github.com/openlm-research/open_llama)                                                             |
| Phi by Microsoft Research | 1.3B, 2.7B | [Li et al. 2023](https://arxiv.org/abs/2309.05463) |
| Platypus by Lee et al.                       | 7B, 13B, 70B                             | [Lee, Hunter, and Ruiz 2023](https://arxiv.org/abs/2308.07317)                                                               |
| Pythia by EleutherAI | {14,31,70,160,410}M, {1,1.4,2.8,6.9,12}B | [Biderman et al. 2023](https://arxiv.org/abs/2304.01373) |
| RedPajama-INCITE by Together | 3B, 7B | [Together 2023](https://together.ai/blog/redpajama-models-v1) |
| StableCode by Stability AI | 3B | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) |
| StableLM by Stability AI                     | 3B, 7B                                   | [Stability AI 2023](https://github.com/Stability-AI/StableLM)                                                                |
| StableLM Zephyr by Stability AI | 3B | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) |
| TinyLlama by Zhang et al.                    | 1.1B                                     | [Zhang et al. 2023](https://github.com/jzhang38/TinyLlama)                                                                   |
| Vicuna by LMSYS | 7B, 13B, 33B | [Li et al. 2023](https://lmsys.org/blog/2023-03-30-vicuna/) |

**Tip**: You can list all available models by running the `litgpt download list` command.
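
If you prefer checking from Python rather than the CLI, the short sketch below lists the registered model names; it only assumes the module-level `configs` list in `litgpt.config`, the same list this PR appends the MicroLlama entry to, so treat the filtering as illustrative.

```python
# Minimal sketch: list the registered model names and confirm the new MicroLlama entry.
# Assumes the module-level `configs` list of config dicts in litgpt.config.
from litgpt.config import configs

names = sorted(cfg["name"] for cfg in configs)
print(f"{len(names)} model configs registered")
print([n for n in names if "micro-llama" in n])  # expected: ['micro-llama-300M']
```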

115 changes: 115 additions & 0 deletions config_hub/pretrain/microllama.yaml
@@ -0,0 +1,115 @@

# The name of the model to pretrain. Choose from names in ``litgpt.config``. Mutually exclusive with
# ``model_config``. (type: Optional[str], default: null)
model_name: micro-llama-300M

# A ``litgpt.Config`` object to define the model architecture. Mutually exclusive with
# ``model_name``. (type: Optional[Config], default: null)
model_config:

# Directory in which to save checkpoints and logs. If running in a Lightning Studio Job, look for it in
# /teamspace/jobs/<job-name>/share. (type: <class 'Path'>, default: out/pretrain)
out_dir: out/pretrain/micro-llama

# The precision to use for pretraining. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
precision: bf16-mixed

# Optional path to a checkpoint directory to initialize the model from.
# Useful for continued pretraining. Mutually exclusive with ``resume``. (type: Optional[Path], default: null)
initial_checkpoint_dir:

# Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume
# from the latest checkpoint in ``out_dir``. (type: Union[bool, Path], default: False)
resume: false

# Data-related arguments. If not provided, the default is ``litgpt.data.TinyLlama``.
data: MicroLlama

# Training-related arguments. See ``litgpt.args.TrainArgs`` for details
train:

  # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)
  save_interval: 1000

  # Number of iterations between logging calls (type: int, default: 1)
  log_interval: 1

  # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 48)
  # Scale this number according to the number of GPUs and the memory size per GPU.
  # For example, we used 48 for 4 x 24 GB RTX 4090 GPUs.
  global_batch_size: 48

  # Number of samples per data-parallel rank (type: int, default: 12)
  # Scale this number according to the memory size per GPU.
  # For example, we used 12 for a single 24 GB RTX 4090.
  micro_batch_size: 12

  # Number of iterations with learning rate warmup active (type: int, default: 2000)
  lr_warmup_steps: 2000

  # Number of epochs to train on (type: Optional[int], default: null)
  epochs:

  # Total number of tokens to train on (type: Optional[int], default: 3000000000000)
  max_tokens: 3000000000000

  # Limits the number of optimizer steps to run. (type: Optional[int], default: null)
  max_steps:

  # Limits the length of samples. Off by default (type: Optional[int], default: null)
  max_seq_length: 2048

  # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: False)
  tie_embeddings:

  # (type: Optional[float], default: 1.0)
  max_norm: 1.0

  # (type: float, default: 4e-05)
  min_lr: 4.0e-05

# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details
eval:

  # Number of optimizer steps between evaluation calls (type: int, default: 1000)
  interval: 1000

  # Number of tokens to generate (type: Optional[int], default: null)
  max_new_tokens:

  # Number of iterations (type: int, default: 100)
  max_iters: 100

  # Whether to evaluate on the validation set at the beginning of the training
  initial_validation: false

# Optimizer-related arguments
optimizer:

  class_path: torch.optim.AdamW

  init_args:

    # (type: float, default: 0.001)
    lr: 4e-4

    # (type: float, default: 0.01)
    weight_decay: 0.1

    # (type: tuple, default: (0.9,0.999))
    betas:
      - 0.9
      - 0.95

# How many devices/GPUs to use. Uses all GPUs by default. (type: Union[int, str], default: auto)
devices: auto

# Optional path to the tokenizer dir that was used for preprocessing the dataset. Only some data
# modules require this. (type: Optional[Path], default: null)
tokenizer_dir: checkpoints/meta-llama/Llama-2-7b-hf

# The name of the logger to send metrics to. (type: Literal['wandb', 'tensorboard', 'csv'], default: tensorboard)
logger_name: tensorboard

# The random seed to use for reproducibility. (type: int, default: 42)
seed: 42
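
As a sanity check on the batch settings above, here is a hedged worked example of how `global_batch_size`, `micro_batch_size`, and the device count relate; the gradient-accumulation formula is an assumption based on the usual LitGPT `TrainArgs` behavior, not something defined in this YAML file.

```python
# Hedged sketch: relationship between the batch settings in microllama.yaml.
# Assumes gradient accumulation = (global_batch_size / devices) / micro_batch_size.
devices = 4             # e.g. 4 x 24 GB RTX 4090, as in the comments above
global_batch_size = 48  # samples per optimizer step, summed across data-parallel ranks
micro_batch_size = 12   # samples per forward/backward pass on a single rank

batch_size_per_device = global_batch_size // devices           # 12
grad_accum_iters = batch_size_per_device // micro_batch_size   # 1, i.e. no accumulation needed
print(batch_size_per_device, grad_accum_iters)
```

With a single 24 GB GPU you would instead keep `micro_batch_size: 12` and the accumulation factor would grow to 4 to cover the 48 samples per optimizer step.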
31 changes: 31 additions & 0 deletions litgpt/config.py
@@ -1563,6 +1563,37 @@ def norm_class(self) -> Type:
configs.append(copy)


############
# MicroLlama
############
micro_llama = [
    dict(
        name="micro-llama-300M{}",
        hf_config=dict(org="keeeeenw", name="MicroLlama{}"),
        block_size=2048,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=12,
        n_head=16,
        n_embd=1024,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        norm_class_name="RMSNorm",  # the original TinyLlama and MicroLlama use FusedRMSNorm
        norm_eps=1e-5,
        mlp_class_name="LLaMAMLP",
        intermediate_size=5632,
        n_query_groups=4,
    )
]
for c in micro_llama:
    for kind, hf_postfix in [("", "")]:
        copy = deepcopy(c)
        copy["name"] = c["name"].format(kind)
        copy["hf_config"]["name"] = c["hf_config"]["name"].format(hf_postfix)
        configs.append(copy)


##########################
# Trelis Function Calling
##########################
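
For reviewers who want to sanity-check the new entry above, a hedged sketch of building the architecture from it; it assumes the `Config.from_name` and `GPT` exports of the `litgpt` package, and the printed parameter count is an expectation rather than a verified number.

```python
# Hedged sketch: instantiate the MicroLlama architecture from the config added above.
# Assumes `Config` and `GPT` are exported at the top level of the litgpt package.
from litgpt import Config, GPT

config = Config.from_name("micro-llama-300M")
model = GPT(config)

n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params / 1e6:.1f}M parameters")  # expected to land roughly around 300M
```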
4 changes: 4 additions & 0 deletions litgpt/data/__init__.py
@@ -11,10 +11,12 @@
from litgpt.data.lima import LIMA
from litgpt.data.lit_data import LitData
from litgpt.data.longform import LongForm
from litgpt.data.llama_data import LlamaDataModule
from litgpt.data.text_files import TextFiles
from litgpt.data.tinyllama import TinyLlama
from litgpt.data.tinystories import TinyStories
from litgpt.data.openwebtext import OpenWebText
from litgpt.data.microllama import MicroLlama


__all__ = [
@@ -27,12 +29,14 @@
"JSON",
"LIMA",
"LitData",
"LlamaDataModule",
"DataModule",
"LongForm",
"OpenWebText",
"SFTDataset",
"TextFiles",
"TinyLlama",
"TinyStories",
"MicroLlama"
"get_sft_collate_fn",
]
104 changes: 104 additions & 0 deletions litgpt/data/llama_data.py
@@ -0,0 +1,104 @@
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, Union

from torch.utils.data import DataLoader

from litgpt import Tokenizer
from litgpt.data import DataModule


@dataclass
class LlamaDataModule(DataModule):
"""The Llama data module is composed of a mix of SlimPajama and optional Starcoder data.

Provides training and validation streaming dataloaders that return batches of tokens.
"""

data_path: Union[str, Path] = Path("data/")
"""The path to the data directory, containing two folders 'slimpajama' and 'starcoder'
which are the output of the preprocessing step done in advance. See the `tutorial/pretrain_tinyllama.md`
for instructions. The path can also be a remote path (e.g., s3://)."""
seed: int = 42
"""The random seed for shuffling the dataset."""
num_workers: int = 8
"""How many DataLoader processes to use for loading."""
use_starcoder: bool = False
"""Toggle for using Starcoder data."""

batch_size: int = field(init=False, repr=False, default=1)
seq_length: int = field(init=False, repr=False, default=2048)

    def __post_init__(self):
        # Could be a remote path (s3://) or a local path
        self.slimpajama_train = str(self.data_path).rstrip("/") + "/slimpajama/train"
        self.slimpajama_val = str(self.data_path).rstrip("/") + "/slimpajama/val"
        self.required_paths = [self.slimpajama_train, self.slimpajama_val]

        if self.use_starcoder:
            self.starcoder_train = str(self.data_path).rstrip("/") + "/starcoder"
            self.required_paths += [self.starcoder_train]

    def connect(
        self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = None
    ) -> None:
        self.batch_size = batch_size
        self.seq_length = max_seq_length + 1  # Increase by one because we need the next token as well

    def prepare_data(self) -> None:
        for path in self.required_paths:
            if not path.startswith("s3://") and not Path(path).is_dir():
                raise FileNotFoundError(
                    "The data path for Llama is expected to be the directory containing these subdirectories:"
                    f" `slimpajama/train`, `slimpajama/val`, `starcoder` (only if `use_starcoder` is enabled). The directory {path} does not exist."
                    " Set it via `--data.data_path=...`"
                )

    def train_dataloader(self) -> DataLoader:
        from litdata.streaming import CombinedStreamingDataset, StreamingDataLoader, StreamingDataset, TokensLoader

        slim_train_data = StreamingDataset(
            input_dir=self.slimpajama_train,
            item_loader=TokensLoader(block_size=self.seq_length),
            shuffle=True,
            drop_last=True,
        )
        train_data = slim_train_data

        if self.use_starcoder:
            train_datasets = [
                slim_train_data,
                StreamingDataset(
                    input_dir=self.starcoder_train,
                    item_loader=TokensLoader(block_size=self.seq_length),
                    shuffle=True,
                    drop_last=True,
                ),
            ]

            # Mix SlimPajama data and Starcoder data with these proportions:
            weights = (0.693584, 0.306416)
            train_data = CombinedStreamingDataset(
                datasets=train_datasets, seed=self.seed, weights=weights, iterate_over_all=False
            )

        train_dataloader = StreamingDataLoader(
            train_data, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, drop_last=True
        )
        return train_dataloader

    def val_dataloader(self) -> DataLoader:
        from litdata.streaming import StreamingDataset, TokensLoader

        val_dataset = StreamingDataset(
            input_dir=self.slimpajama_val,
            item_loader=TokensLoader(block_size=self.seq_length),
            shuffle=True,
            # Consider setting to False, but we would lose some samples due to truncation when world size > 1
            drop_last=True,
        )
        val_dataloader = DataLoader(
            val_dataset, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, drop_last=True
        )
        return val_dataloader
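
To show how the new data module is meant to be wired together, here is a hedged usage sketch based only on the fields and methods defined above; the tokenizer path and batch sizes mirror `config_hub/pretrain/microllama.yaml` and are assumptions, and actually iterating the loaders requires the preprocessed LitData shards plus the `litdata` package.

```python
# Hedged sketch: wiring up LlamaDataModule the way a pretraining run would.
# Paths and sizes mirror config_hub/pretrain/microllama.yaml and are assumptions.
from pathlib import Path

from litgpt import Tokenizer
from litgpt.data import LlamaDataModule

data = LlamaDataModule(data_path="data/", use_starcoder=False)

# The tokenizer dir should match the one used to preprocess the SlimPajama shards.
tokenizer = Tokenizer(Path("checkpoints/meta-llama/Llama-2-7b-hf"))
data.connect(tokenizer=tokenizer, batch_size=12, max_seq_length=2048)

data.prepare_data()                     # raises FileNotFoundError if data/slimpajama/{train,val} is missing
train_loader = data.train_dataloader()  # streaming loader over fixed-size token blocks
val_loader = data.val_dataloader()
```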