Rework repo structure (Lightning-AI#41)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
lantiga and carmocca authored Mar 28, 2023
1 parent 6d13704 commit 1d726f9
Showing 17 changed files with 35 additions and 204 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/cpu-tests.yml
@@ -21,9 +21,9 @@ jobs:
       fail-fast: false
       matrix:
         include:
-          - {os: "macOS-11", python-version: "3.10", pytorch-version: "2.0"}
-          - {os: "ubuntu-20.04", python-version: "3.10", pytorch-version: "2.0"}
-          - {os: "windows-2022", python-version: "3.10", pytorch-version: "2.0"}
+          - {os: "macOS-11", python-version: "3.10"}
+          - {os: "ubuntu-20.04", python-version: "3.10"}
+          - {os: "windows-2022", python-version: "3.10"}
     timeout-minutes: 10
 
     steps:
@@ -36,7 +36,7 @@ jobs:
 
       - name: Install dependencies
         run: |
-          pip install pytest -r requirements.txt
+          pip install pytest .
           pip list
       - name: Run tests
39 changes: 0 additions & 39 deletions .github/workflows/mypy.yml

This file was deleted.

1 change: 0 additions & 1 deletion .gitignore
@@ -1,7 +1,6 @@
 __pycache__
 .idea
 .DS_Store
-.mypy_cache
 *.egg-info
 
 # data
63 changes: 0 additions & 63 deletions .pre-commit-config.yaml

This file was deleted.

10 changes: 5 additions & 5 deletions README.md
@@ -1,5 +1,5 @@
 <div align="center">
-<img src="assets/Lit_LLaMA_Badge3x.png" alt="Lit-LLaMA" width="128"/>
+<img src="https://pl-public-data.s3.amazonaws.com/assets_lightning/Lit_LLaMA_Badge3x.png" alt="Lit-LLaMA" width="128"/>
 
 # ⚡ Lit-LLaMA ️
 
@@ -88,7 +88,7 @@ python scripts/convert_checkpoint.py \
 Run inference:
 
 ```bash
-python scripts/generate.py --prompt "Hello, my name is"
+python generate.py --prompt "Hello, my name is"
 ```
 
 This will run the 7B model and require ~26 GB of GPU memory (A100 GPU).
@@ -99,18 +99,18 @@ For GPUs with less memory, enable quantization (`--quantize true`). This will ta
 This can run on any consumer GPU.
 
 ```bash
-python scripts/generate.py --quantize true --prompt "Hello, my name is"
+python generate.py --quantize true --prompt "Hello, my name is"
 ```
 
-See `python scripts/generate.py --help` for more options.
+See `python generate.py --help` for more options.
 
 &nbsp;
 
 ## Get involved!
 
 We're in a quest towards fully open source AI, especially focusing on models in the 5-20B range, trained using the LLaMA approach (smaller models trained for longer).
 
-<img align="right" src="assets/Lit_LLaMA_Illustration3x.png" alt="Lit-LLaMA" width="128"/>
+<img align="right" src="https://pl-public-data.s3.amazonaws.com/assets_lightning/Lit_LLaMA_Illustration3x.png" alt="Lit-LLaMA" width="128"/>
 
 Join us and start contributing, especially on the following areas:
Binary file removed assets/Lit_LLaMA_Badge3x.png
Binary file removed assets/Lit_LLaMA_Illustration3x.png
9 changes: 5 additions & 4 deletions generate.py
@@ -6,9 +6,8 @@
 import lightning as L
 import torch
 
-from model import LLaMA
-from quantization.bnb import quantize as quantize_model
-from tokenizer import Tokenizer
+from lit_llama.model import LLaMA
+from lit_llama.tokenizer import Tokenizer
 
 
 @torch.no_grad()
@@ -106,11 +105,13 @@ def main(
     fabric = L.Fabric(accelerator=accelerator, devices=1)
 
     if quantize:
+        from lit_llama.quantization import quantize
+
         print("Running quantization. This may take a minute ...")
         # TODO: Initializing the model directly on the device does not work with quantization
         model = LLaMA.from_name(model_size)
         # The output layer can be sensitive to quantization, we keep it in default precision
-        model = quantize_model(model, skip=("lm_head", "output"))
+        model = quantize(model, skip=("lm_head", "output"))
         checkpoint = torch.load(checkpoint_path)
         model.load_state_dict(checkpoint)
     else:
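The `quantize` helper that generate.py now imports lazily lives in `lit_llama/quantization.py`, whose contents are outside this diff. As a rough, hypothetical sketch of what `quantize(model, skip=...)` plausibly does — assuming it swaps `nn.Linear` modules for bitsandbytes 8-bit layers, which is what the `bitsandbytes` dependency suggests — it might look like:

```python
import bitsandbytes as bnb
import torch.nn as nn


def quantize(model: nn.Module, skip: tuple = ()) -> nn.Module:
    """Hypothetical sketch: replace nn.Linear layers with 8-bit layers.

    Layers whose attribute name appears in `skip` (e.g. "lm_head") are
    left in default precision, matching the call in generate.py.
    """
    for name, module in model.named_children():
        if name in skip:
            continue
        if isinstance(module, nn.Linear):
            replacement = bnb.nn.Linear8bitLt(
                module.in_features,
                module.out_features,
                bias=module.bias is not None,
                has_fp16_weights=False,
                threshold=6.0,
            )
            # The int8 quantization itself happens when the module is
            # moved to the GPU; here we only carry the weights over.
            replacement.weight.data = module.weight.data
            if module.bias is not None:
                replacement.bias.data = module.bias.data
            setattr(model, name, replacement)
        else:
            quantize(module, skip=skip)  # recurse into blocks, MLPs, ...
    return model
```

Importing it inside the `if quantize:` branch keeps `bitsandbytes` (a CUDA-only dependency) from being required on the unquantized path.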
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
85 changes: 0 additions & 85 deletions pyproject.toml

This file was deleted.

2 changes: 1 addition & 1 deletion requirements.txt
@@ -4,4 +4,4 @@ sentencepiece
 tqdm # convert_checkpoint.py
 numpy # train.py dataset memmap
 jsonargparse # generate.py, convert_checkpoint.py CLI
-bitsandbytes # quantization/bnb.py
+bitsandbytes # quantization.py
Empty file removed scripts/__init__.py
16 changes: 16 additions & 0 deletions setup.py
@@ -0,0 +1,16 @@
+from setuptools import setup, find_packages
+
+setup(
+    name='lit-llama',
+    version='0.1.0',
+    description='Implementation of the LLaMA language model',
+    author='Lightning AI',
+    url='https://github.com/lightning-AI/lit-llama',
+    install_requires=[
+        "torch>=2.0.0",
+        "lightning>=2.0.0",
+        "sentencepiece",
+        "bitsandbytes",
+    ],
+    packages=find_packages(),
+)
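With this file in place, `pip install .` installs the repository as a proper package, which is what the reworked CI step (`pip install pytest .`) relies on and what makes the new `lit_llama.*` imports in generate.py and train.py resolvable. For local development, the standard `pip install -e .` gives an editable install.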
6 changes: 4 additions & 2 deletions train.py
@@ -4,12 +4,14 @@
 from typing import Tuple
 
 import lightning as L
+import torch
 from lightning.fabric.strategies import FSDPStrategy
 
-import torch
 from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
 
 import numpy as np
-from model import Block, LLaMA, LLaMAConfig
+
+from lit_llama.model import Block, LLaMA, LLaMAConfig
+
 out_dir = "out"
 eval_interval = 2000
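The two FSDP imports kept here typically meet in the Fabric setup: `transformer_auto_wrap_policy` tells `FSDPStrategy` which submodule class to shard as a unit. The actual call site is outside the diff context, so the following is a hedged sketch of the usual torch 2.0 / Fabric 2.0 pattern rather than this file's exact code:

```python
import functools

import lightning as L
from lightning.fabric.strategies import FSDPStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

from lit_llama.model import Block

# Wrap each transformer Block as its own FSDP unit, so parameters,
# gradients, and optimizer state are sharded block by block.
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy, transformer_layer_cls={Block}
)
strategy = FSDPStrategy(auto_wrap_policy=auto_wrap_policy)
fabric = L.Fabric(
    accelerator="cuda", devices=4, strategy=strategy  # device count illustrative
)
```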
