Rework repo structure #41

Merged · 16 commits · Mar 28, 2023
8 changes: 4 additions & 4 deletions .github/workflows/cpu-tests.yml
@@ -21,9 +21,9 @@ jobs:
fail-fast: false
matrix:
include:
- {os: "macOS-11", python-version: "3.10", pytorch-version: "2.0"}
- {os: "ubuntu-20.04", python-version: "3.10", pytorch-version: "2.0"}
- {os: "windows-2022", python-version: "3.10", pytorch-version: "2.0"}
- {os: "macOS-11", python-version: "3.10"}
- {os: "ubuntu-20.04", python-version: "3.10"}
- {os: "windows-2022", python-version: "3.10"}
timeout-minutes: 10

steps:
@@ -36,7 +36,7 @@ jobs:

- name: Install dependencies
run: |
- pip install pytest -r requirements.txt
+ pip install pytest .
pip list

- name: Run tests
39 changes: 0 additions & 39 deletions .github/workflows/mypy.yml

This file was deleted.

1 change: 0 additions & 1 deletion .gitignore
@@ -1,7 +1,6 @@
__pycache__
.idea
.DS_Store
- .mypy_cache
*.egg-info

# data
63 changes: 0 additions & 63 deletions .pre-commit-config.yaml

This file was deleted.

10 changes: 5 additions & 5 deletions README.md
@@ -1,5 +1,5 @@
<div align="center">
<img src="assets/Lit_LLaMA_Badge3x.png" alt="Lit-LLaMA" width="128"/>
<img src="https://pl-public-data.s3.amazonaws.com/assets_lightning/Lit_LLaMA_Badge3x.png" alt="Lit-LLaMA" width="128"/>

# ⚡ Lit-LLaMA

@@ -88,7 +88,7 @@ python scripts/convert_checkpoint.py \
Run inference:

```bash
- python scripts/generate.py --prompt "Hello, my name is"
+ python generate.py --prompt "Hello, my name is"
```

This will run the 7B model and require ~26 GB of GPU memory (A100 GPU).
@@ -99,18 +99,18 @@ For GPUs with less memory, enable quantization (`--quantize true`). This will ta
This can run on any consumer GPU.

```bash
- python scripts/generate.py --quantize true --prompt "Hello, my name is"
+ python generate.py --quantize true --prompt "Hello, my name is"
```

- See `python scripts/generate.py --help` for more options.
+ See `python generate.py --help` for more options.

&nbsp;

## Get involved!

We're on a quest towards fully open source AI, especially focusing on models in the 5-20B range, trained using the LLaMA approach (smaller models trained for longer).

<img align="right" src="assets/Lit_LLaMA_Illustration3x.png" alt="Lit-LLaMA" width="128"/>
<img align="right" src="https://pl-public-data.s3.amazonaws.com/assets_lightning/Lit_LLaMA_Illustration3x.png" alt="Lit-LLaMA" width="128"/>

Join us and start contributing, especially on the following areas:

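As a quick sanity check on the ~26 GB figure quoted above: the weights of a 7B-parameter model stored in float32 alone account for roughly that much memory (activations and inference buffers add more). A back-of-the-envelope calculation:

```python
# Weights-only memory for a 7B-parameter model in float32 (4 bytes/param).
n_params = 7e9
weight_bytes = n_params * 4
print(f"{weight_bytes / 2**30:.1f} GiB")  # ~26.1 GiB
```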
Binary file removed assets/Lit_LLaMA_Badge3x.png
Binary file removed assets/Lit_LLaMA_Illustration3x.png
9 changes: 5 additions & 4 deletions generate.py
@@ -6,9 +6,8 @@
import lightning as L
import torch

- from model import LLaMA
- from quantization.bnb import quantize as quantize_model
- from tokenizer import Tokenizer
+ from lit_llama.model import LLaMA
+ from lit_llama.tokenizer import Tokenizer


@torch.no_grad()
@@ -106,11 +105,13 @@ def main(
fabric = L.Fabric(accelerator=accelerator, devices=1)

if quantize:
+ from lit_llama.quantization import quantize
+
print("Running quantization. This may take a minute ...")
# TODO: Initializing the model directly on the device does not work with quantization
model = LLaMA.from_name(model_size)
# The output layer can be sensitive to quantization, we keep it in default precision
- model = quantize_model(model, skip=("lm_head", "output"))
+ model = quantize(model, skip=("lm_head", "output"))
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint)
else:
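The `quantize` helper imported above now lives in `lit_llama/quantization.py` (the renamed `quantization/bnb.py`), whose contents this diff does not show. As a rough illustration of the `skip` mechanism only, here is a hypothetical bitsandbytes-based sketch, not the actual module:

```python
import bitsandbytes as bnb
import torch.nn as nn


def quantize(model: nn.Module, skip: tuple = ()) -> nn.Module:
    """Swap nn.Linear layers for 8-bit ones, except modules named in `skip`."""
    for name, child in model.named_children():
        if name in skip:
            continue  # e.g. keep "lm_head"/"output" in default precision
        if isinstance(child, nn.Linear):
            replacement = bnb.nn.Linear8bitLt(
                child.in_features,
                child.out_features,
                bias=child.bias is not None,
                has_fp16_weights=False,
            )
            # Int8Params performs the actual int8 conversion once the
            # module is moved to a CUDA device.
            replacement.weight = bnb.nn.Int8Params(
                child.weight.data, requires_grad=False
            )
            if child.bias is not None:
                replacement.bias = child.bias
            setattr(model, name, replacement)
        else:
            quantize(child, skip=skip)  # recurse into submodules
    return model
```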
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
85 changes: 0 additions & 85 deletions pyproject.toml

This file was deleted.

2 changes: 1 addition & 1 deletion requirements.txt
@@ -4,4 +4,4 @@ sentencepiece
tqdm # convert_checkpoint.py
numpy # train.py dataset memmap
jsonargparse # generate.py, convert_checkpoint.py CLI
- bitsandbytes # quantization/bnb.py
+ bitsandbytes # quantization.py
Empty file removed scripts/__init__.py
16 changes: 16 additions & 0 deletions setup.py
@@ -0,0 +1,16 @@
from setuptools import setup, find_packages

setup(
name='lit-llama',
version='0.1.0',
description='Implementation of the LLaMA language model',
author='Lightning AI',
url='https://github.com/lightning-AI/lit-llama',
install_requires=[
"torch>=2.0.0",
"lightning>=2.0.0",
"sentencepiece",
"bitsandbytes",
],
packages=find_packages(),
)
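With `setup.py` in place, `pip install .` (as the updated CI workflow now runs) or an editable `pip install -e .` during development exposes the sources as the `lit_llama` package, which the reworked imports in `generate.py` and `train.py` rely on:

```python
# After installing the package, the former root-level modules are
# imported from the lit_llama namespace:
from lit_llama.model import LLaMA, LLaMAConfig
from lit_llama.tokenizer import Tokenizer
```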
6 changes: 4 additions & 2 deletions train.py
@@ -4,12 +4,14 @@
from typing import Tuple

import lightning as L
+ import torch
from lightning.fabric.strategies import FSDPStrategy

- import torch
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

import numpy as np
- from model import Block, LLaMA, LLaMAConfig

+ from lit_llama.model import Block, LLaMA, LLaMAConfig

out_dir = "out"
eval_interval = 2000
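For context, the FSDP imports gathered above are typically combined along these lines. A minimal sketch, assuming `Block` is the transformer layer class to shard; not necessarily train.py's exact Fabric configuration:

```python
import functools

import lightning as L
from lightning.fabric.strategies import FSDPStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

from lit_llama.model import Block

# Wrap each transformer Block in its own FSDP unit so parameters are
# sharded per block across devices.
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy, transformer_layer_cls={Block}
)
strategy = FSDPStrategy(auto_wrap_policy=auto_wrap_policy)
fabric = L.Fabric(accelerator="cuda", devices=4, strategy=strategy)
fabric.launch()
```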