A llama3 implementation in Flax, which aims to be compatible with the official torch version up to floating point error.

llamax

A basic Flax library containing reference implementations of various LLMs. Currently, just Llama 3.

The code is very much a work in progress; use at your own risk.

TODO

  • Verify that the Flax code matches the reference implementation to 6 significant digits in float64 on CPU.
  • Support passing in kv_mask and pairwise_mask to handle padded sequences, and test that this works properly.
  • Implement sharding, at least model-parallel + data-parallel.
  • Change the main transformer loop to use jax.lax.scan instead of a Python loop.
  • Add throughput benchmarks.
  • Test against the actual llama3 weights.
  • Add code that can run tests using Modal Labs.
  • (Stretch goal) Implement some of the fancier sharding methods listed in Meta's Movie Gen paper.
  • Add tests on GPU.
  • Pin dependency versions.
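The scan item in the list above can be sketched as follows. This is a hypothetical illustration, not llamax's actual API: `block_fn` here is a stand-in for a real transformer block, and the stacked-parameter layout is an assumption. The idea is to replace a Python for-loop over identical layers with a single `jax.lax.scan` over parameters stacked along a leading layer axis, which keeps the traced graph (and compile time) independent of depth.

```python
# Hypothetical sketch: scanning over transformer blocks instead of
# looping in Python. `block_fn` and `stacked_params` are illustrative
# stand-ins, not names from llamax.
import jax
import jax.numpy as jnp

def block_fn(x, layer_params):
    # Stand-in for one transformer block: a single residual linear layer.
    w = layer_params["w"]
    return x + x @ w, None  # scan carries the activations; no per-step output

num_layers, d_model = 4, 8
key = jax.random.PRNGKey(0)
# Parameters stacked along a leading layer axis, as lax.scan expects.
stacked_params = {"w": jax.random.normal(key, (num_layers, d_model, d_model)) * 0.01}

x = jnp.ones((2, d_model))  # (batch, d_model)
y, _ = jax.lax.scan(block_fn, x, stacked_params)
print(y.shape)  # (2, 8)
```

`lax.scan` slices each leaf of `stacked_params` along its leading axis, so the real per-layer weights would need to be stacked the same way before scanning.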

Usage

Run tests with `docker build -t llamax . && docker run -e JAX_ENABLE_X64=True -e HF_TOKEN=$HF_TOKEN llamax pytest`.

Get your Huggingface token by following their instructions. If you want to run the integration test, download the weights as per the instructions below, then mount them via `-v $WEIGHTS_DIR:/data` in the `docker run` command above.
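The `JAX_ENABLE_X64` flag matters because JAX defaults to 32-bit: without it, float64 values are silently downcast to float32, which would defeat any 6-significant-digit comparison against a float64 torch reference. A minimal check of what the flag does:

```python
# JAX downcasts float64 to float32 unless x64 mode is enabled.
import jax
jax.config.update("jax_enable_x64", True)  # same effect as JAX_ENABLE_X64=True

import jax.numpy as jnp
x = jnp.asarray(1.0, dtype=jnp.float64)
print(x.dtype)  # float64 (would be float32 without x64 enabled)
```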

Weights

I don't include any of the weights needed to test the implementations. You'll need to get them yourself. Here's where I got them:

  1. Llama 3.2 1B: copy the `original` folder from Huggingface.
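For step 1, one way to fetch just that folder is with the `huggingface_hub` CLI; this is a sketch, assuming you have the CLI installed, have accepted the model's license on Huggingface, and want the weights in `$WEIGHTS_DIR` (the directory is your choice):

```shell
# Downloads only the "original" (torch-native) folder of Llama 3.2 1B.
# Requires `pip install -U "huggingface_hub[cli]"` and an HF_TOKEN with
# access to the gated meta-llama repo.
huggingface-cli download meta-llama/Llama-3.2-1B \
  --include "original/*" \
  --local-dir "$WEIGHTS_DIR"
```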
