This repo contains an unofficial PyTorch implementation of the Mamba
model as introduced in Gu & Dao (2023). One can find the official (CUDA) implementation here or the nice alxndrTL alternative Python implementation. This repo is developed mainly for didactic purposes, to spell out the details of State Space Models.
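For reference, the recurrence that a selective State Space Model evaluates can be written as below. This follows the formulation in Gu & Dao (2023), with zero-order-hold discretization; the code in this repo may organize or approximate these terms differently, so treat it as a reading aid rather than a description of the implementation. Here $x_t$ is the input at step $t$, $h_t$ the latent state, $y_t$ the output, and $\Delta_t$, $B_t$, $C_t$ are computed from the input.

```latex
\begin{aligned}
h_t &= \bar{A}_t \, h_{t-1} + \bar{B}_t \, x_t, \\
y_t &= C_t \, h_t, \\
\bar{A}_t &= \exp(\Delta_t A), \\
\bar{B}_t &= (\Delta_t A)^{-1} \bigl( \exp(\Delta_t A) - I \bigr) \, \Delta_t B_t .
\end{aligned}
```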
The main contribution of this repo is the implementation of the Mamba
architecture (upon which a basic LLM is built, see 📂 mamba.llm) using the Lightning framework, which abstracts away the difficulty of training in distributed settings.
The basic usage is to instantiate a Mamba
model and run it on the input of choice. The model expects the input tensor to have shape (batch_size, seq_len, d_input)
and outputs a similarly shaped tensor.
import torch
from mamba import Mamba
d_input = 16
seq_len = 32
batch_size = 4
model = Mamba(
    num_layers = 6, # Number of layers of the full model
    d_input = d_input, # Dimension of each vector in the input sequence (i.e. token size)
    d_model = 64, # Dimension of the visible state space
    d_state = 64, # Dimension of the latent hidden states
    d_discr = 16, # Rank of the discretization matrix Δ
    ker_size = 4, # Kernel size of the convolution in the MambaBlock
    parallel = False, # Whether to use the sequential or the parallel implementation
)
# Mockup input for example purposes
inp = torch.randn(batch_size, seq_len, d_input)
# Compute the output using the Mamba architecture
out = model(inp) # (batch_size, seq_len, d_input)
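The parallel flag switches between a sequential and a parallel evaluation of the same linear recurrence. The sketch below illustrates the idea on plain Python scalars; it is not the repo's tensor code, and the helper names (sequential_scan, combine, parallel_scan) are made up for this example. Each step h -> a_t * h + b_t is an affine map, and composing affine maps is associative, so the prefix results can be computed in about log2(T) rounds in which every position can be updated independently.

```python
def sequential_scan(a, b, h0=0.0):
    """Compute h_t = a_t * h_{t-1} + b_t one step at a time."""
    h, out = h0, []
    for a_t, b_t in zip(a, b):
        h = a_t * h + b_t
        out.append(h)
    return out


def combine(left, right):
    """Compose two affine maps h -> a * h + b, applying `left` first."""
    a_l, b_l = left
    a_r, b_r = right
    return (a_r * a_l, a_r * b_l + b_r)


def parallel_scan(a, b, h0=0.0):
    """Inclusive scan in the Hillis-Steele style (illustrative only)."""
    elems = list(zip(a, b))
    step = 1
    while step < len(elems):
        # Within a round, every position only reads values from the
        # previous round, so all positions could be updated at once.
        elems = [
            combine(elems[i - step], elems[i]) if i >= step else elems[i]
            for i in range(len(elems))
        ]
        step *= 2
    # elems[t] now holds the composed map taking h0 directly to h_t
    return [a_t * h0 + b_t for a_t, b_t in elems]


a = [0.5, 0.9, 0.1, 0.7, 0.3]
b = [1.0, 2.0, 3.0, 4.0, 5.0]
seq = sequential_scan(a, b)
par = parallel_scan(a, b)
```

Both functions return the same states up to floating-point rounding. On real hardware the list comprehension inside each round would be a single vectorized tensor operation, which is where the speed-up comes from.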
This repo also offers a didactic implementation of an LLM built using Lightning, which makes training on multiple GPUs easy. To use it, one can simply run the following example:
from lightning import Trainer
from transformers import AutoTokenizer
from mamba import MambaLLM
from mamba.tiny import TinyStoriesLightning
config = ... # path to YAML configuration file
# Load an off-the-shelf tokenizer from HF
tokenizer = AutoTokenizer.from_pretrained('openai-community/gpt2')
# Load the Mamba model from a config file
model = MambaLLM.from_config(config, key='mamba')
# Load the dataset
dataset = TinyStoriesLightning.from_config(
    config,
    tokenizer,
    key='dataset'
)
trainer = Trainer(
    max_epochs = 500,
    accelerator = 'gpu',
    devices = 4, # Piece of cake multi-gpu support!
    strategy = 'ddp_find_unused_parameters_false',
)
# Train the model
trainer.fit(model, dataset)
Alternatively, one can run the training script train.py directly. It expects the path to the configuration file and accepts all the Trainer arguments.
python train.py --config <path_to_config_file> \
    --max_epochs 500 \
    --accelerator gpu \
    --devices 4
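How a script can take its own --config flag and still pass arbitrary Trainer arguments through is worth a quick illustration. The sketch below is one possible way to do it with the standard library and is not taken from train.py; the function name parse_cli is hypothetical, and it assumes every extra argument comes as a `--name value` pair.

```python
import argparse
import ast


def parse_cli(argv):
    """Split `--config` from the remaining `--name value` pairs (hypothetical helper)."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', required=True)
    # Unknown flags are not rejected; they are returned untouched in `extra`
    args, extra = parser.parse_known_args(argv)

    trainer_kwargs = {}
    for flag, raw in zip(extra[::2], extra[1::2]):
        try:
            value = ast.literal_eval(raw)  # numeric strings become numbers
        except (ValueError, SyntaxError):
            value = raw                    # anything else stays a string
        trainer_kwargs[flag.lstrip('-')] = value
    return args.config, trainer_kwargs


config_path, trainer_kwargs = parse_cli([
    '--config', 'config.yaml',
    '--max_epochs', '500',
    '--accelerator', 'gpu',
    '--devices', '4',
])
```

The resulting dictionary could then be unpacked into the Trainer constructor, for example Trainer(**trainer_kwargs).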
A cool feature of the MambaLLM
implementation is the lazy (batched) inference implemented via a generator. One can thus print tokens on screen as they are streamed by the model, without waiting for the whole inference to finish. A mock-up script would look like the following.
from mamba import MambaLLM
from transformers import AutoTokenizer
# Get an off-the-shelf tokenizer
tokenizer = AutoTokenizer.from_pretrained('openai-community/gpt2')
tokenizer.add_special_tokens({'pad_token': '<|pad|>'})
# Parameters for the LLM
vocab_size = tokenizer.vocab_size + 1
num_layers = 6
d_input = 16
d_model = 64
d_state = 64
d_discr = 16
ker_size = 4
parallel = True
model = MambaLLM(
    vocab_size = vocab_size,
    num_layers = num_layers,
    d_input = d_input,
    d_model = d_model,
    d_state = d_state,
    d_discr = d_discr,
    ker_size = ker_size,
    parallel = parallel
)
# Parameters for the inference
token_lim = 16
use_top_k = 50
temperature = 0.7
# Generate text
stream = model.generate(
    # We can provide more than one prompt!
    prompt=[
        'Once upon a time',
        'In a galaxy far far away',
    ],
    tokenizer=tokenizer,
    token_lim=token_lim,
    use_top_k=use_top_k,
    temperature=temperature,
)
for token in stream:
    # Each token is a dictionary indexed by the
    # batch-id and contains the produced string
    # as value, so we can print the first batch as:
    print(token[0], end='', flush=True)
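To make the lazy generation idea concrete, here is a minimal single-sequence sketch of top-k sampling with temperature wrapped in a Python generator. It uses only the standard library and is not the MambaLLM.generate implementation: the names sample_top_k and stream_tokens, and the next_logits callback standing in for the model, are assumptions made for this example. It also yields raw token ids rather than the per-batch dictionaries of decoded strings described above.

```python
import math
import random


def sample_top_k(logits, k, temperature, rng):
    """Sample an index among the k largest logits after temperature scaling."""
    # Keep only the k highest-scoring candidates
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    scaled = [logits[i] / temperature for i in top]
    # Softmax weights over the survivors (shifted by the max for stability)
    peak = max(scaled)
    weights = [math.exp(s - peak) for s in scaled]
    return rng.choices(top, weights=weights, k=1)[0]


def stream_tokens(next_logits, prompt_ids, token_lim, k, temperature, seed=0):
    """Lazily yield one sampled token id at a time (hypothetical helper)."""
    rng = random.Random(seed)
    ids = list(prompt_ids)
    for _ in range(token_lim):
        token = sample_top_k(next_logits(ids), k, temperature, rng)
        ids.append(token)
        yield token  # the caller receives this before the next step is computed


def toy_logits(ids):
    """Stand-in for a model: fixed scores over a three-token vocabulary."""
    return [0.1, 2.0, 0.3]


greedy = list(stream_tokens(toy_logits, [0], token_lim=4, k=1, temperature=0.7))
sampled = list(stream_tokens(toy_logits, [0], token_lim=8, k=2, temperature=0.7))
```

With k=1 the sampler always picks the highest-scoring token, while larger k and higher temperature spread the choice over more candidates. Because stream_tokens is a generator, nothing is computed until the consumer asks for the next token.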
- Put all the essential pieces together
- Add functioning parallel implementation (p-scan) (🙏🏻 @ Zeta Project & alxndrTL)
- Add functioning training script (Lightning)
- Show some results
This repo builds on top of some cool libraries. One can easily install the dependencies via pip install -r requirements.txt
einops==0.7.0
lightning==2.2.2
PyYAML==6.0.1
torch==2.0.1
transformers==4.40.0
Some of the code is inspired by the great repo LLM From Scratch, which was used as a guide to clear up several doubts along the way.
@article{gu2023mamba,
title={Mamba: Linear-time sequence modeling with selective state spaces},
author={Gu, Albert and Dao, Tri},
journal={arXiv preprint arXiv:2312.00752},
year={2023}
}