A basic Flax library containing reference implementations of various LLMs. Currently only Llama3 is implemented.
The code is very much a work in progress; use at your own risk.
- Verify that the Flax code matches the reference implementation to 6 significant digits in float64 on CPU.
- Support passing in `kv_mask` and `pairwise_mask` to handle padded sequences. Test that this works properly.
- Implement sharding, at least model-parallel + data-parallel.
- Change the main transformer loop to use a scan instead.
- Add throughput benchmarks.
- Test against the actual llama3 weights.
- Add code that can run tests using Modal Labs.
- (Stretch goal) Implement some of the fancier sharding methods listed in Meta's Movie Gen paper.
- Add tests on GPU.
- Pin dependency versions.
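The padded-sequence masking item above can be sketched as follows. This is a minimal illustration, not the library's actual attention code: it assumes `kv_mask` marks which key/value positions are real tokens (the name is taken from the TODO; the function name and shapes are hypothetical).

```python
import jax
import jax.numpy as jnp

def masked_attention_weights(scores, kv_mask):
    # scores: [batch, heads, q_len, kv_len] raw attention logits.
    # kv_mask: [batch, kv_len] boolean, True for real tokens, False for padding.
    # Broadcast the mask over heads and query positions, and add a large
    # negative bias so padded keys receive ~zero weight after the softmax.
    bias = jnp.where(kv_mask[:, None, None, :], 0.0, -1e9)
    return jax.nn.softmax(scores + bias, axis=-1)
```

A `pairwise_mask` would work the same way but with shape `[batch, q_len, kv_len]`, masking individual query/key pairs rather than whole key positions.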
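The scan item refers to replacing a Python-level loop over transformer layers with `jax.lax.scan`, which compiles the loop body once instead of unrolling it per layer. A toy sketch under assumed shapes (the real layers would carry full parameter pytrees, not a single matrix):

```python
import jax
import jax.numpy as jnp

def run_layers_scanned(x, stacked_weights):
    # x: [seq, d_model] activations.
    # stacked_weights: [n_layers, d_model, d_model], one weight matrix per
    # layer stacked along the leading axis, which scan iterates over.
    def layer(carry, w):
        # Toy stand-in for a transformer block.
        return jnp.tanh(carry @ w), None
    out, _ = jax.lax.scan(layer, x, stacked_weights)
    return out
```

Besides faster compilation, scanning over stacked parameters is also what makes pipeline-style sharding of the layer axis straightforward later.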
Run the tests with `docker build -t llamax . && docker run -e JAX_ENABLE_X64=True -e HF_TOKEN=$HF_TOKEN llamax pytest .`
Get your Huggingface token by following their instructions. If you want to run the integration test, download the weights as per the instructions below, then pass them in via `-v $WEIGHTS_DIR:/data` in the Docker command above.
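Putting the pieces together, a full invocation with weights mounted might look like this (assuming `$WEIGHTS_DIR` points at your downloaded weights directory; the image name and environment variables are from the command above):

```shell
# Build the image, then run the test suite with float64 enabled,
# the Hugging Face token forwarded, and the weights mounted at /data.
docker build -t llamax .
docker run \
  -e JAX_ENABLE_X64=True \
  -e HF_TOKEN=$HF_TOKEN \
  -v $WEIGHTS_DIR:/data \
  llamax pytest .
```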
I don't include any of the weights needed to test the implementations. You'll need to get them yourself. Here's where I got them:
- Llama 3.2 1B: Copy the "original" folder from Huggingface.