Commit

Introduce basic configuration setup (#26)

carmocca authored Mar 27, 2023
1 parent dccfc89 commit 83d1372
Showing 12 changed files with 254 additions and 126 deletions.
63 changes: 63 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,63 @@
default_language_version:
  python: python3

ci:
  autofix_prs: true
  autoupdate_commit_msg: '[pre-commit.ci] pre-commit suggestions'
  autoupdate_schedule: quarterly

repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.4.0
    hooks:
      - id: end-of-file-fixer
      - id: trailing-whitespace
      - id: check-yaml
      - id: check-executables-have-shebangs
      - id: check-toml
      - id: check-case-conflict
      - id: check-added-large-files
        args: ['--maxkb=350', '--enforce-all']

  - repo: https://github.com/asottile/pyupgrade
    rev: v3.3.1
    hooks:
      - id: pyupgrade
        args: [--py38-plus]
        name: Upgrade code

  - repo: https://github.com/PyCQA/docformatter
    rev: v1.5.1
    hooks:
      - id: docformatter
        args: [--in-place, --wrap-summaries=115, --wrap-descriptions=120]

  - repo: https://github.com/asottile/yesqa
    rev: v1.4.0
    hooks:
      - id: yesqa
        name: Unused noqa

  - repo: https://github.com/PyCQA/isort
    rev: 5.12.0
    hooks:
      - id: isort
        name: Format imports

  - repo: https://github.com/psf/black
    rev: 23.1.0
    hooks:
      - id: black
        name: Format code

  - repo: https://github.com/asottile/blacken-docs
    rev: 1.13.0
    hooks:
      - id: blacken-docs
        args: [--line-length=120]

  - repo: https://github.com/charliermarsh/ruff-pre-commit
    rev: 'v0.0.259'
    hooks:
      - id: ruff
        args: ["--fix"]
1 change: 0 additions & 1 deletion LICENSE
@@ -199,4 +199,3 @@
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

1 change: 0 additions & 1 deletion README.md
@@ -1 +0,0 @@

28 changes: 13 additions & 15 deletions generate.py
@@ -1,16 +1,17 @@
import os
import sys
import time
import torch
from tokenizer import Tokenizer

import lightning as L
import torch

from quantization.bnb import quantize as quantize_model
import sys
from tokenizer import Tokenizer


@torch.no_grad()
def generate(model, idx, max_new_tokens, max_seq_length, temperature=1.0, top_k=None):
"""
Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
"""Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
The implementation of this function is modified from A. Karpathy's nanoGPT.
@@ -57,13 +58,13 @@ def generate(model, idx, max_new_tokens, max_seq_length, temperature=1.0, top_k=
def get_model(original: bool = False):
if original:
try:
from original_model import Transformer, ModelArgs
from original_model import ModelArgs, Transformer
except ModuleNotFoundError:
from scripts.download import download_original

download_original(os.path.dirname(__file__))

from original_model import Transformer, ModelArgs
from original_model import ModelArgs, Transformer

config = ModelArgs(dim=4096, n_layers=32, n_heads=32, vocab_size=32000, max_batch_size=1) # 7B config
return Transformer(config), config.max_seq_len
@@ -89,8 +90,7 @@ def main(
original_model: bool = False,
quantize: bool = False,
):
"""
Generates text samples based on a pre-trained LLaMA model and tokenizer.
"""Generates text samples based on a pre-trained LLaMA model and tokenizer.
Args:
prompt: The prompt string to use for generating the samples.
@@ -118,7 +118,7 @@ def main(
model, max_seq_length = get_model(original_model)

# The output layer can be sensitive to quantization, we keep it in default precision
model = quantize_model(model, skip=("lm_head", "output", ))
model = quantize_model(model, skip=("lm_head", "output"))
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint, strict=(not original_model))
else:
@@ -128,7 +128,7 @@ def main(
model.load_state_dict(checkpoint, strict=(not original_model))

model.eval()

# if compile:
# model = torch.compile(model)

@@ -141,9 +141,7 @@ def main(
L.seed_everything(1234)
t0 = time.time()
for _ in range(num_samples):
y = generate(
model, encoded_prompt, max_new_tokens, max_seq_length, temperature=temperature, top_k=top_k
)
y = generate(model, encoded_prompt, max_new_tokens, max_seq_length, temperature=temperature, top_k=top_k)
print(tokenizer.decode(y[0]))

print(f"Time for inference: {time.time() - t0:.02f} seconds", file=sys.stderr)
@@ -153,5 +151,5 @@ def main(
if __name__ == "__main__":
from jsonargparse import CLI

torch.set_float32_matmul_precision('high')
torch.set_float32_matmul_precision("high")
CLI(main)
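
As a reading aid for the hunks above: the heart of `generate` is an autoregressive loop that repeatedly feeds the running sequence to the model and samples the next token using temperature scaling and optional top-k filtering. The sketch below illustrates that sampling step only; it is a simplified stand-in under assumed shapes, not the repository's exact implementation, and `sample_next_token` is a hypothetical helper name.

from typing import Optional

import torch


def sample_next_token(logits: torch.Tensor, temperature: float = 1.0, top_k: Optional[int] = None) -> torch.Tensor:
    """Sample one token id from the logits of the last position, assumed shape (1, vocab_size)."""
    logits = logits / temperature  # <1.0 sharpens the distribution, >1.0 flattens it
    if top_k is not None:
        # keep only the k most likely tokens; everything below the k-th value gets probability zero
        kth_value = torch.topk(logits, min(top_k, logits.size(-1)))[0][..., -1, None]
        logits = logits.masked_fill(logits < kth_value, float("-inf"))
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)  # (1, 1) tensor holding the sampled token id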
73 changes: 37 additions & 36 deletions model.py
@@ -1,5 +1,5 @@
"""
Full definition of a LLaMA Language Model, all of it in this single file.
"""Full definition of a LLaMA Language Model, all of it in this single file.
Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT.
"""

@@ -12,12 +12,14 @@


def build_rope_cache(seq_len, n_elem, dtype, base=10000):
"""
Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/transformers/rope/__init__.py
MIT License: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license
"""Enhanced Transformer with Rotary Position Embedding.
Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
transformers/rope/__init__.py. MIT License:
https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
"""
# $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
theta = 1. / (base ** (torch.arange(0, n_elem, 2, dtype=dtype) / n_elem))
theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype) / n_elem))

# Create position indexes `[0, 1, ..., seq_len - 1]`
seq_idx = torch.arange(seq_len, dtype=dtype)
@@ -32,21 +34,22 @@ def build_rope_cache(seq_len, n_elem, dtype, base=10000):

def apply_rope(x: torch.Tensor, rope_cache: torch.Tensor):
x = x.transpose(1, 2)

# truncate to support variable sizes
T = x.size(1)
rope_cache = rope_cache[:T]

xc = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2))
rope_cache = rope_cache.view(1, xc.size(1), 1, xc.size(3))
x_out = torch.view_as_real(xc * rope_cache).flatten(3)
return x_out.transpose(1, 2).type_as(x)
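
# Illustrative usage of the two helpers above (not part of the diff). Shapes follow the
# attention code later in this file, (batch, n_head, seq_len, head_size), and the cache
# returned by build_rope_cache is assumed to be directly consumable by apply_rope.
example_q = torch.randn(1, 4, 8, 16)  # head_size must be even so value pairs can be rotated
example_cache = build_rope_cache(seq_len=8, n_elem=16, dtype=example_q.dtype)
example_q_rot = apply_rope(example_q, example_cache)
assert example_q_rot.shape == example_q.shape  # rotation changes values, not shapes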


class RMSNorm(nn.Module):
"""
Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py
BSD 3-Clause License: https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE
"""Root Mean Square Layer Normalization.
Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.
"""

def __init__(self, size, dim=-1, eps=1e-5):
@@ -60,13 +63,12 @@ def forward(self, x):
# norm_x = x.norm(2, dim=self.dim, keepdim=True)
# rms_x = norm_x * d_x ** (-1. / 2)
# x_normed = x / (rms_x + self.eps)
norm_x = torch.mean(x*x, dim=self.dim, keepdim=True)
norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)
x_normed = x * torch.rsqrt(norm_x + self.eps)
return self.scale * x_normed
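
# Quick numeric illustration of the forward pass above (not part of the diff), assuming the
# learned `scale` parameter starts at ones: each vector is divided by its root mean square.
example_x = torch.tensor([[3.0, 4.0]])  # mean(x * x) = 12.5, rms = sqrt(12.5) ≈ 3.5355
example_norm = RMSNorm(size=2)
print(example_norm(example_x))  # ≈ tensor([[0.8485, 1.1314]]) when scale is all ones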


class CausalSelfAttention(nn.Module):

def __init__(self, config, rope_cache):
super().__init__()
assert config.n_embd % config.n_head == 0
@@ -81,23 +83,23 @@ def __init__(self, config, rope_cache):
self.register_buffer("rope_cache", rope_cache, persistent=False)

def forward(self, x):
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

# calculate query, key, values for all heads in batch and move head forward to be the batch dim
q, k, v = self.c_attn(x).split(self.n_embd, dim=2)

head_size = C // self.n_head
k = k.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)
q = q.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)
v = v.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)
k = k.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)
q = q.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)
v = v.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)

q = apply_rope(q, self.rope_cache)
k = apply_rope(k, self.rope_cache)

# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
# efficient attention using Flash Attention CUDA kernels
y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)
y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

# output projection
y = self.c_proj(y)
@@ -106,7 +108,6 @@ def forward(self, x):


class MLP(nn.Module):

def __init__(self, config):
super().__init__()
hidden_dim = 4 * config.n_embd
@@ -115,9 +116,9 @@ def __init__(self, config):
# ensure n_hidden is multiple of N
n_hidden = ((n_hidden - 1) // N) * N + N

self.c_fc1 = nn.Linear(config.n_embd, n_hidden, bias=False)
self.c_fc2 = nn.Linear(config.n_embd, n_hidden, bias=False)
self.c_proj = nn.Linear(n_hidden, config.n_embd, bias=False)
self.c_fc1 = nn.Linear(config.n_embd, n_hidden, bias=False)
self.c_fc2 = nn.Linear(config.n_embd, n_hidden, bias=False)
self.c_proj = nn.Linear(n_hidden, config.n_embd, bias=False)

def forward(self, x):
x = F.silu(self.c_fc1(x)) * self.c_fc2(x)
@@ -126,7 +127,6 @@ def forward(self, x):


class Block(nn.Module):

def __init__(self, config, rope_cache):
super().__init__()
self.rms_1 = RMSNorm(config.n_embd)
@@ -150,7 +150,6 @@ class LLaMAConfig:


class LLaMA(nn.Module):

def __init__(self, config):
super().__init__()
assert config.vocab_size is not None
@@ -160,32 +159,34 @@ def __init__(self, config):
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

rope_cache = build_rope_cache(
seq_len=config.block_size,
n_elem=config.n_embd // config.n_head,
dtype=self.lm_head.weight.dtype,
seq_len=config.block_size, n_elem=config.n_embd // config.n_head, dtype=self.lm_head.weight.dtype
)

self.transformer = nn.ModuleDict(dict(
wte = nn.Embedding(config.vocab_size, config.n_embd),
h = nn.ModuleList([Block(config, rope_cache) for _ in range(config.n_layer)]),
ln_f = RMSNorm(config.n_embd),
))
self.transformer = nn.ModuleDict(
dict(
wte=nn.Embedding(config.vocab_size, config.n_embd),
h=nn.ModuleList([Block(config, rope_cache) for _ in range(config.n_layer)]),
ln_f=RMSNorm(config.n_embd),
)
)

# init all weights
self.apply(self._init_weights)

def _init_weights(self, module):
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02/math.sqrt(2 * self.config.n_layer))
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer))
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02/math.sqrt(2 * self.config.n_layer))
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer))

def forward(self, idx):
_, t = idx.size()
assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
assert (
t <= self.config.block_size
), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"

# forward the LLaMA model itself
x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)

for block in self.transformer.h:
x = block(x)
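
Putting the pieces of model.py together, the sketch below shows how the model would typically be instantiated and queried. It is illustrative only: the configuration values are made up, the field names mirror the attributes referenced in the diff (block_size, vocab_size, n_layer, n_head, n_embd), and the forward pass is assumed to return one row of logits per input position.

import torch

# Hypothetical tiny configuration for illustration; the real 7B configuration is much larger.
config = LLaMAConfig(block_size=128, vocab_size=32000, n_layer=2, n_head=4, n_embd=128)
model = LLaMA(config)
model.eval()

idx = torch.randint(0, config.vocab_size, (1, 16))  # a batch of one sequence of 16 token ids
with torch.no_grad():
    logits = model(idx)  # assumed shape: (1, 16, vocab_size)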
