diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..0a0c6ae --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,63 @@ +default_language_version: + python: python3 + +ci: + autofix_prs: true + autoupdate_commit_msg: '[pre-commit.ci] pre-commit suggestions' + autoupdate_schedule: quarterly + +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 + hooks: + - id: end-of-file-fixer + - id: trailing-whitespace + - id: check-yaml + - id: check-executables-have-shebangs + - id: check-toml + - id: check-case-conflict + - id: check-added-large-files + args: ['--maxkb=350', '--enforce-all'] + + - repo: https://github.com/asottile/pyupgrade + rev: v3.3.1 + hooks: + - id: pyupgrade + args: [--py38-plus] + name: Upgrade code + + - repo: https://github.com/PyCQA/docformatter + rev: v1.5.1 + hooks: + - id: docformatter + args: [--in-place, --wrap-summaries=115, --wrap-descriptions=120] + + - repo: https://github.com/asottile/yesqa + rev: v1.4.0 + hooks: + - id: yesqa + name: Unused noqa + + - repo: https://github.com/PyCQA/isort + rev: 5.12.0 + hooks: + - id: isort + name: Format imports + + - repo: https://github.com/psf/black + rev: 23.1.0 + hooks: + - id: black + name: Format code + + - repo: https://github.com/asottile/blacken-docs + rev: 1.13.0 + hooks: + - id: blacken-docs + args: [--line-length=120] + + - repo: https://github.com/charliermarsh/ruff-pre-commit + rev: 'v0.0.259' + hooks: + - id: ruff + args: ["--fix"] diff --git a/LICENSE b/LICENSE index fe4eb47..fe60df9 100644 --- a/LICENSE +++ b/LICENSE @@ -199,4 +199,3 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. - diff --git a/README.md b/README.md index 8b13789..e69de29 100644 --- a/README.md +++ b/README.md @@ -1 +0,0 @@ - diff --git a/generate.py b/generate.py index 0fef5df..020f5e8 100644 --- a/generate.py +++ b/generate.py @@ -1,16 +1,17 @@ import os +import sys import time -import torch -from tokenizer import Tokenizer + import lightning as L +import torch + from quantization.bnb import quantize as quantize_model -import sys +from tokenizer import Tokenizer @torch.no_grad() def generate(model, idx, max_new_tokens, max_seq_length, temperature=1.0, top_k=None): - """ - Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested. + """Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested. The implementation of this function is modified from A. Karpathy's nanoGPT. @@ -57,13 +58,13 @@ def generate(model, idx, max_new_tokens, max_seq_length, temperature=1.0, top_k= def get_model(original: bool = False): if original: try: - from original_model import Transformer, ModelArgs + from original_model import ModelArgs, Transformer except ModuleNotFoundError: from scripts.download import download_original download_original(os.path.dirname(__file__)) - from original_model import Transformer, ModelArgs + from original_model import ModelArgs, Transformer config = ModelArgs(dim=4096, n_layers=32, n_heads=32, vocab_size=32000, max_batch_size=1) # 7B config return Transformer(config), config.max_seq_len @@ -89,8 +90,7 @@ def main( original_model: bool = False, quantize: bool = False, ): - """ - Generates text samples based on a pre-trained LLaMA model and tokenizer. 
+ """Generates text samples based on a pre-trained LLaMA model and tokenizer. Args: prompt: The prompt string to use for generating the samples. @@ -118,7 +118,7 @@ def main( model, max_seq_length = get_model(original_model) # The output layer can be sensitive to quantization, we keep it in default precision - model = quantize_model(model, skip=("lm_head", "output", )) + model = quantize_model(model, skip=("lm_head", "output")) checkpoint = torch.load(checkpoint_path) model.load_state_dict(checkpoint, strict=(not original_model)) else: @@ -128,7 +128,7 @@ def main( model.load_state_dict(checkpoint, strict=(not original_model)) model.eval() - + # if compile: # model = torch.compile(model) @@ -141,9 +141,7 @@ def main( L.seed_everything(1234) t0 = time.time() for _ in range(num_samples): - y = generate( - model, encoded_prompt, max_new_tokens, max_seq_length, temperature=temperature, top_k=top_k - ) + y = generate(model, encoded_prompt, max_new_tokens, max_seq_length, temperature=temperature, top_k=top_k) print(tokenizer.decode(y[0])) print(f"Time for inference: {time.time() - t0:.02f} seconds", file=sys.stderr) @@ -153,5 +151,5 @@ def main( if __name__ == "__main__": from jsonargparse import CLI - torch.set_float32_matmul_precision('high') + torch.set_float32_matmul_precision("high") CLI(main) diff --git a/model.py b/model.py index 64f0aa9..6268d0e 100644 --- a/model.py +++ b/model.py @@ -1,5 +1,5 @@ -""" -Full definition of a LLaMA Language Model, all of it in this single file. +"""Full definition of a LLaMA Language Model, all of it in this single file. + Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT. """ @@ -12,12 +12,14 @@ def build_rope_cache(seq_len, n_elem, dtype, base=10000): - """ - Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/transformers/rope/__init__.py - MIT License: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license + """Enhanced Transformer with Rotary Position Embedding. + + Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/ + transformers/rope/__init__.py. MIT License: + https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license. """ # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ - theta = 1. / (base ** (torch.arange(0, n_elem, 2, dtype=dtype) / n_elem)) + theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype) / n_elem)) # Create position indexes `[0, 1, ..., seq_len - 1]` seq_idx = torch.arange(seq_len, dtype=dtype) @@ -32,11 +34,11 @@ def build_rope_cache(seq_len, n_elem, dtype, base=10000): def apply_rope(x: torch.Tensor, rope_cache: torch.Tensor): x = x.transpose(1, 2) - + # truncate to support variable sizes T = x.size(1) rope_cache = rope_cache[:T] - + xc = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2)) rope_cache = rope_cache.view(1, xc.size(1), 1, xc.size(3)) x_out = torch.view_as_real(xc * rope_cache).flatten(3) @@ -44,9 +46,10 @@ def apply_rope(x: torch.Tensor, rope_cache: torch.Tensor): class RMSNorm(nn.Module): - """ - Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py - BSD 3-Clause License: https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE + """Root Mean Square Layer Normalization. + + Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. 
BSD 3-Clause License: + https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE. """ def __init__(self, size, dim=-1, eps=1e-5): @@ -60,13 +63,12 @@ def forward(self, x): # norm_x = x.norm(2, dim=self.dim, keepdim=True) # rms_x = norm_x * d_x ** (-1. / 2) # x_normed = x / (rms_x + self.eps) - norm_x = torch.mean(x*x, dim=self.dim, keepdim=True) + norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) x_normed = x * torch.rsqrt(norm_x + self.eps) return self.scale * x_normed class CausalSelfAttention(nn.Module): - def __init__(self, config, rope_cache): super().__init__() assert config.n_embd % config.n_head == 0 @@ -81,15 +83,15 @@ def __init__(self, config, rope_cache): self.register_buffer("rope_cache", rope_cache, persistent=False) def forward(self, x): - B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) # calculate query, key, values for all heads in batch and move head forward to be the batch dim q, k, v = self.c_attn(x).split(self.n_embd, dim=2) head_size = C // self.n_head - k = k.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs) - q = q.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs) - v = v.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs) + k = k.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs) + q = q.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs) + v = v.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs) q = apply_rope(q, self.rope_cache) k = apply_rope(k, self.rope_cache) @@ -97,7 +99,7 @@ def forward(self, x): # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) # efficient attention using Flash Attention CUDA kernels y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True) - y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side # output projection y = self.c_proj(y) @@ -106,7 +108,6 @@ def forward(self, x): class MLP(nn.Module): - def __init__(self, config): super().__init__() hidden_dim = 4 * config.n_embd @@ -115,9 +116,9 @@ def __init__(self, config): # ensure n_hidden is multiple of N n_hidden = ((n_hidden - 1) // N) * N + N - self.c_fc1 = nn.Linear(config.n_embd, n_hidden, bias=False) - self.c_fc2 = nn.Linear(config.n_embd, n_hidden, bias=False) - self.c_proj = nn.Linear(n_hidden, config.n_embd, bias=False) + self.c_fc1 = nn.Linear(config.n_embd, n_hidden, bias=False) + self.c_fc2 = nn.Linear(config.n_embd, n_hidden, bias=False) + self.c_proj = nn.Linear(n_hidden, config.n_embd, bias=False) def forward(self, x): x = F.silu(self.c_fc1(x)) * self.c_fc2(x) @@ -126,7 +127,6 @@ def forward(self, x): class Block(nn.Module): - def __init__(self, config, rope_cache): super().__init__() self.rms_1 = RMSNorm(config.n_embd) @@ -150,7 +150,6 @@ class LLaMAConfig: class LLaMA(nn.Module): - def __init__(self, config): super().__init__() assert config.vocab_size is not None @@ -160,32 +159,34 @@ def __init__(self, config): self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) rope_cache = build_rope_cache( - seq_len=config.block_size, - n_elem=config.n_embd // config.n_head, - dtype=self.lm_head.weight.dtype, + seq_len=config.block_size, n_elem=config.n_embd // config.n_head, dtype=self.lm_head.weight.dtype ) - self.transformer = 
nn.ModuleDict(dict( - wte = nn.Embedding(config.vocab_size, config.n_embd), - h = nn.ModuleList([Block(config, rope_cache) for _ in range(config.n_layer)]), - ln_f = RMSNorm(config.n_embd), - )) + self.transformer = nn.ModuleDict( + dict( + wte=nn.Embedding(config.vocab_size, config.n_embd), + h=nn.ModuleList([Block(config, rope_cache) for _ in range(config.n_layer)]), + ln_f=RMSNorm(config.n_embd), + ) + ) # init all weights self.apply(self._init_weights) def _init_weights(self, module): if isinstance(module, nn.Linear): - torch.nn.init.normal_(module.weight, mean=0.0, std=0.02/math.sqrt(2 * self.config.n_layer)) + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer)) elif isinstance(module, nn.Embedding): - torch.nn.init.normal_(module.weight, mean=0.0, std=0.02/math.sqrt(2 * self.config.n_layer)) + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer)) def forward(self, idx): _, t = idx.size() - assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" + assert ( + t <= self.config.block_size + ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" # forward the LLaMA model itself - x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) + x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) for block in self.transformer.h: x = block(x) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..2a63048 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,79 @@ +[metadata] +name = "lit-llama" +author = "Lightning-AI et al." +url = "https://github.com/Lightning-AI/lit-llama" + +[build-system] +requires = [ + "setuptools", + "wheel", +] + + +[tool.isort] +known_first_party = [ + "quantization", + "scripts", + "*.py", +] +profile = "black" +line_length = 120 +force_sort_within_sections = "False" +order_by_type = "False" + + +[tool.black] +line-length = 120 + + +[tool.ruff] +line-length = 120 +# Enable Pyflakes `E` and `F` codes by default. +select = [ + "E", "W", # see: https://pypi.org/project/pycodestyle + "F", # see: https://pypi.org/project/pyflakes +] +ignore = [ + "E731", # Do not assign a lambda expression, use a def +] +# Exclude a variety of commonly ignored directories. +exclude = [ + ".git", +] +ignore-init-module-imports = true + + +#[tool.mypy] +#files = [ +# "quantization", +# "scripts", +# "*.py", +#] +#install_types = "True" +#non_interactive = "True" +#disallow_untyped_defs = "True" +#ignore_missing_imports = "True" +#show_error_codes = "True" +#warn_redundant_casts = "True" +#warn_unused_configs = "True" +#warn_unused_ignores = "True" +#allow_redefinition = "True" +## style choices +#warn_no_return = "False" + + +#[tool.pytest.ini_options] +#norecursedirs = [ +# ".git", +# ".github", +#] +#addopts = [ +# "--strict-markers", +# "--doctest-modules", +# "--color=yes", +# "--disable-pytest-warnings", +#] +#filterwarnings = [ +# "error::FutureWarning", +#] +#junit_duration_report = "call" diff --git a/quantization/bnb.py b/quantization/bnb.py index 0551756..a8eb86a 100644 --- a/quantization/bnb.py +++ b/quantization/bnb.py @@ -1,19 +1,17 @@ -import torch.nn as nn -from typing import Tuple import os +from typing import Tuple + +import torch.nn as nn + os.environ["BITSANDBYTES_NOWELCOME"] = "1" -import bitsandbytes as bnb +import bitsandbytes as bnb # noqa: E402 -def quantize(model: nn.Module, threshold: float = 6.0, skip: Tuple[str, ...] 
= ( )) -> nn.Module: +def quantize(model: nn.Module, threshold: float = 6.0, skip: Tuple[str, ...] = ()) -> nn.Module: for name, module in model.named_children(): if isinstance(module, nn.Linear) and name not in skip: model._modules[name] = bnb.nn.Linear8bitLt( - module.in_features, - module.out_features, - bias=module.bias, - has_fp16_weights=False, - threshold=threshold, + module.in_features, module.out_features, bias=module.bias, has_fp16_weights=False, threshold=threshold ) if module.children(): diff --git a/requirements.txt b/requirements.txt index 057c164..f1aef71 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,4 +3,4 @@ lightning>=2.0.0 sentencepiece tqdm # convert_checkpoint.py numpy # train.py dataset memmap -jsonargparse # generate.py, convert_checkpoint.py CLI \ No newline at end of file +jsonargparse # generate.py, convert_checkpoint.py CLI diff --git a/scripts/compare.py b/scripts/compare.py index bba8784..af87815 100644 --- a/scripts/compare.py +++ b/scripts/compare.py @@ -5,14 +5,14 @@ def build_rope_cache_old(seq_len, n_elem, dtype, base=10000): - """This is the `build_rope_cache` implementation we initially intended to use, but it is - numerically not exactly equivalent to the one in the Meta model. We keep it here for posterity. + """This is the `build_rope_cache` implementation we initially intended to use, but it is numerically not + exactly equivalent to the one in the Meta model. We keep it here for posterity. - Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/transformers/rope/__init__.py - MIT License: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license - """ + Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/transformers/rope/__init__.py + MIT License: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license + """ # noqa: E501 # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ - theta = 1. / (base ** (torch.arange(0, n_elem, 2, dtype=dtype) / n_elem)) + theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype) / n_elem)) # Create position indexes `[0, 1, ..., seq_len - 1]` seq_idx = torch.arange(seq_len, dtype=dtype) @@ -34,13 +34,15 @@ def build_rope_cache_old(seq_len, n_elem, dtype, base=10000): def rotate_neg_half(x: torch.Tensor): # $\frac{d}{2}$ d_2 = x.shape[-1] // 2 - # Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$ + # Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$ # noqa: E501 return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1) def apply_rope_old(x: torch.Tensor, rope_cache): - """This is the `apply_rope` implementation we initially intended to use, but it is - numerically not exactly equivalent to the one in the Meta model. We keep it here for posterity. + """This is the `apply_rope` implementation we initially intended to use, but it is numerically not exactly + equivalent to the one in the Meta model. + + We keep it here for posterity.
""" neg_half_x = rotate_neg_half(x) cos, sin = rope_cache @@ -127,24 +129,17 @@ def compare_to_orig_llama(): n_embd = 32 llama_config = llama.LLaMAConfig( - block_size=block_size, - vocab_size=vocab_size, - n_layer=n_layer, - n_head=n_head, - n_embd=n_embd + block_size=block_size, vocab_size=vocab_size, n_layer=n_layer, n_head=n_head, n_embd=n_embd ) orig_llama_config = orig_llama.ModelArgs( - dim=n_embd, - n_layers=n_layer, - n_heads=n_head, - vocab_size=vocab_size, - norm_eps=1e-5, - max_seq_len=block_size + dim=n_embd, n_layers=n_layer, n_heads=n_head, vocab_size=vocab_size, norm_eps=1e-5, max_seq_len=block_size ) batch_size = 3 - token_sample = torch.randint(0, orig_llama_config.vocab_size, size=(batch_size, orig_llama_config.max_seq_len), dtype=torch.int64) + token_sample = torch.randint( + 0, orig_llama_config.vocab_size, size=(batch_size, orig_llama_config.max_seq_len), dtype=torch.int64 + ) llama_model = llama.LLaMA(llama_config) orig_llama_model = orig_llama.Transformer(orig_llama_config) @@ -160,7 +155,7 @@ def compare_to_orig_llama(): seq_len = token_sample.shape[1] mask = torch.full((1, 1, seq_len, seq_len), float("-inf")) mask = torch.triu(mask, diagonal=1) - orig_llama_block_out = orig_llama_model.layers[0](orig_llama_embed, 0, orig_llama_model.freqs_cis[: seq_len], mask) + orig_llama_block_out = orig_llama_model.layers[0](orig_llama_embed, 0, orig_llama_model.freqs_cis[:seq_len], mask) llama_block_out = llama_model.transformer.h[0](llama_embed) block_matches = torch.allclose(orig_llama_block_out, llama_block_out) diff --git a/scripts/convert_checkpoint.py b/scripts/convert_checkpoint.py index 37f08f4..9dc1974 100644 --- a/scripts/convert_checkpoint.py +++ b/scripts/convert_checkpoint.py @@ -1,8 +1,9 @@ +import os +import shutil from pathlib import Path + import torch from tqdm import tqdm -import os -import shutil """ Sample usage: @@ -14,6 +15,7 @@ ``` """ + def convert_state_dict(state_dict): converted = {} converted["transformer.wte.weight"] = state_dict["tok_embeddings.weight"] @@ -25,16 +27,26 @@ def convert_state_dict(state_dict): # attention # the wq, wk, wv from the FB model are stacked in our model as c_attn - converted[f"transformer.h.{layer_idx}.attn.c_attn.weight"] = torch.cat(( - state_dict[f"layers.{layer_idx}.attention.wq.weight"], - state_dict[f"layers.{layer_idx}.attention.wk.weight"], - state_dict[f"layers.{layer_idx}.attention.wv.weight"], - )) - converted[f"transformer.h.{layer_idx}.attn.c_proj.weight"] = state_dict[f"layers.{layer_idx}.attention.wo.weight"] + converted[f"transformer.h.{layer_idx}.attn.c_attn.weight"] = torch.cat( + ( + state_dict[f"layers.{layer_idx}.attention.wq.weight"], + state_dict[f"layers.{layer_idx}.attention.wk.weight"], + state_dict[f"layers.{layer_idx}.attention.wv.weight"], + ) + ) + converted[f"transformer.h.{layer_idx}.attn.c_proj.weight"] = state_dict[ + f"layers.{layer_idx}.attention.wo.weight" + ] # mlp - converted[f"transformer.h.{layer_idx}.mlp.c_fc1.weight"] = state_dict[f"layers.{layer_idx}.feed_forward.w1.weight"] - converted[f"transformer.h.{layer_idx}.mlp.c_proj.weight"] = state_dict[f"layers.{layer_idx}.feed_forward.w2.weight"] - converted[f"transformer.h.{layer_idx}.mlp.c_fc2.weight"] = state_dict[f"layers.{layer_idx}.feed_forward.w3.weight"] + converted[f"transformer.h.{layer_idx}.mlp.c_fc1.weight"] = state_dict[ + f"layers.{layer_idx}.feed_forward.w1.weight" + ] + converted[f"transformer.h.{layer_idx}.mlp.c_proj.weight"] = state_dict[ + f"layers.{layer_idx}.feed_forward.w2.weight" + ] + 
converted[f"transformer.h.{layer_idx}.mlp.c_fc2.weight"] = state_dict[ + f"layers.{layer_idx}.feed_forward.w3.weight" + ] # rms norm converted[f"transformer.h.{layer_idx}.rms_1.scale"] = state_dict[f"layers.{layer_idx}.attention_norm.weight"] converted[f"transformer.h.{layer_idx}.rms_2.scale"] = state_dict[f"layers.{layer_idx}.ffn_norm.weight"] @@ -57,8 +69,7 @@ def meta_weights_for_nano_model( shutil.copy(tokenizer_path, output_dir.parent) checkpoint_files = sorted(ckpt_dir.glob("*.pth")) - - + # for the bigger models, there are multiple model-parallel checkpoints # and we combine them into one single file combined = {} @@ -66,7 +77,7 @@ def meta_weights_for_nano_model( checkpoint = torch.load(file, map_location="cpu") converted = convert_state_dict(checkpoint) combined.update(converted) - + torch.save(combined, Path(output_dir, "state_dict.pth")) diff --git a/scripts/prepare_shakespeare.py b/scripts/prepare_shakespeare.py index 63d6818..a8342e5 100644 --- a/scripts/prepare_shakespeare.py +++ b/scripts/prepare_shakespeare.py @@ -23,6 +23,7 @@ import sys import requests + import numpy as np @@ -38,7 +39,7 @@ def prepare( with open(input_file_path, "w") as f: f.write(requests.get(data_url).text) - with open(input_file_path, "r") as f: + with open(input_file_path) as f: data = f.read() n = len(data) train_data = data[: int(n * 0.9)] diff --git a/train.py b/train.py index f360cc9..58b6c55 100644 --- a/train.py +++ b/train.py @@ -3,14 +3,13 @@ from functools import partial import lightning as L -import numpy as np import torch import torch.nn.functional as F from lightning.fabric.strategies import FSDPStrategy from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy -from model import LLaMA, LLaMAConfig, Block - +import numpy as np +from model import Block, LLaMA, LLaMAConfig out_dir = "out" eval_interval = 2000 @@ -33,20 +32,10 @@ def main(): - auto_wrap_policy = partial( - transformer_auto_wrap_policy, transformer_layer_cls={Block} - ) - strategy = FSDPStrategy( - auto_wrap_policy=auto_wrap_policy, - activation_checkpointing=Block, - ) - - fabric = L.Fabric( - accelerator="cuda", - devices=4, - precision="bf16-mixed", - strategy=strategy, - ) + auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block}) + strategy = FSDPStrategy(auto_wrap_policy=auto_wrap_policy, activation_checkpointing=Block) + + fabric = L.Fabric(accelerator="cuda", devices=4, precision="bf16-mixed", strategy=strategy) fabric.launch() fabric.seed_everything(1337 + fabric.global_rank) @@ -66,12 +55,7 @@ def main(): model = fabric.setup_module(model) - optimizer = torch.optim.AdamW( - model.parameters(), - lr=learning_rate, - weight_decay=weight_decay, - betas=(beta1, beta2), - ) + optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2)) optimizer = fabric.setup_optimizers(optimizer) train(fabric, model, optimizer, train_data, val_data) @@ -79,10 +63,10 @@ def main(): def train(fabric, model, optimizer, train_data, val_data): """The training loop. - + Loosely based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT. """ - + iter_num = 0 while True: @@ -97,7 +81,7 @@ def train(fabric, model, optimizer, train_data, val_data): # torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt')) t0 = time.time() - + input_ids, targets = get_batch(fabric, train_data, block_size=model.config.block_size) logits = model(input_ids) loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)