Skip to content

Commit

Permalink
Step by step guides (#144)
Browse files Browse the repository at this point in the history
Co-authored-by: Luca Antiga <luca@lightning.ai>
  • Loading branch information
awaelchli and lantiga authored Apr 18, 2023
1 parent 016f7d1 commit 33ef184
Show file tree
Hide file tree
Showing 7 changed files with 309 additions and 22 deletions.
13 changes: 10 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,12 @@ python generate.py --prompt "Hello, my name is"

This will run the 7B model and require ~26 GB of GPU memory (A100 GPU).

[Full guide for generating samples from the model](howto/inference.md).

### Run Lit-LLaMA on consumer devices

For GPUs with less memory, enable quantization (`--quantize llm.int8`) or use bfloat16 (`--dtype bfloat16`). Quantization will take longer to load but require ~8GB of memory. bfloat16 is closer to the "full deal" and runs on ~10GB of GPU memory.
This can run on any consumer GPU.
On GPUs with `bfloat16` support, the `generate.py` script will automatically convert the weights and consume about ~14 GB.
For GPUs with less memory, or ones that don't support `bfloat16`, enable quantization (`--quantize llm.int8`):

```bash
python generate.py --quantize llm.int8 --prompt "Hello, my name is"
Expand All @@ -121,7 +123,7 @@ python quantize.py --checkpoint_path lit-llama.pth --tokenizer_path tokenizer.mo

With the generated quantized checkpoint generation works as usual with `--quantize gptq.int4`, bringing GPU usage to about ~5GB. As only the weights of the Linear layers are quantized, it is useful to use `--dtype bfloat16` even with the quantization enabled.

&nbsp;
[Full guide for generating samples from the model](howto/inference.md).

## Finetune the model

Expand All @@ -147,6 +149,11 @@ It is expected that you have downloaded the pretrained weights as described abov
The finetuning requires at least one GPU with ~24 GB memory (GTX 3090). Follow the instructions in the script to efficiently fit your GPU memory.
Note: For some GPU models you might need to set `torch.backends.cuda.enable_flash_sdp(False)` (see comments at the top of the script).

More details about each finetuning method and how you can apply it to your own data can be found in our how-to guides:

- [Finetune with LoRA](howto/finetune_lora.md)
- [Finetune with Adapters](howto/finetune_adapter.md)

## Get involved!

We're in a quest towards fully open source AI.
Expand Down
31 changes: 19 additions & 12 deletions finetune_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
LLaMA-Adapter: Efficient Fine-tuning of Language Models with Zero-init Attention
https://arxiv.org/abs/2303.16199
This script uses DeepSpeed Zero-2 to train efficiently on 8 A100 GPUs within 1 hour as done in the original paper.
If you have fewer GPUs, you can adjust the devices variable to e.g. `devices = 1` and tune the
`micro_batch_size` to fit your GPU memory.
This script runs on a single GPU by default. You can adjust the `micro_batch_size` to fit your GPU memory.
You can finetune within 1 hour as done in the original paper using DeepSpeed Zero-2 on 8 A100 GPUs by setting the
devices variable to `devices = 8` and `micro_batch_size = 8` (or higher).
Note: If you run into a CUDA error "Expected is_sm80 to be true, but got false", uncomment the line
`torch.backends.cuda.enable_flash_sdp(False)` in the script below (see https://github.com/Lightning-AI/lit-llama/issues/101).
Expand All @@ -27,18 +27,16 @@
from lightning.fabric.strategies import DeepSpeedStrategy


pretrained_path = "checkpoints/lit-llama/7B/lit-llama.pth"
out_dir = "out/adapter/alpaca"
eval_interval = 600
save_interval = 1000
eval_iters = 100
log_interval = 1
devices = 8
devices = 1

# Hyperparameters
learning_rate = 9e-3
batch_size = 64 / devices
micro_batch_size = 8
micro_batch_size = 4
gradient_accumulation_steps = batch_size // micro_batch_size
epoch_size = 50000 # train dataset size
num_epochs = 5
Expand All @@ -54,7 +52,12 @@
}


def main():
def main(
data_dir: str = "data/alpaca",
pretrained_path: str = "checkpoints/lit-llama/7B/lit-llama.pth",
out_dir: str = "out/adapter/alpaca",
):

fabric = L.Fabric(
accelerator="cuda",
devices=devices,
Expand All @@ -67,7 +70,7 @@ def main():
if fabric.global_rank == 0:
os.makedirs(out_dir, exist_ok=True)

train_data, val_data = load_datasets()
train_data, val_data = load_datasets(data_dir=data_dir)

config = LLaMAConfig()
config.block_size = block_size
Expand All @@ -93,7 +96,7 @@ def main():

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
model, optimizer = fabric.setup(model, optimizer)
train(fabric, model, optimizer, train_data, val_data)
train(fabric, model, optimizer, train_data, val_data, out_dir)

# Save the final checkpoint at the end of training
save_model_checkpoint(fabric, model, os.path.join(out_dir, "lit-llama-adapter-finetuned.pth"))
Expand All @@ -105,6 +108,7 @@ def train(
optimizer: torch.optim.Optimizer,
train_data: np.ndarray,
val_data: np.ndarray,
out_dir: str,
) -> None:
"""The training loop.
Expand Down Expand Up @@ -215,7 +219,7 @@ def pad_right(x, pad_id):
return x, y


def load_datasets(data_dir: str = "data/alpaca"):
def load_datasets(data_dir):
train_data = torch.load(os.path.join(data_dir, "train.pt"))
val_data = torch.load(os.path.join(data_dir, "test.pt"))
return train_data, val_data
Expand Down Expand Up @@ -248,4 +252,7 @@ def save_model_checkpoint(fabric, model, file_path):
# Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false"
# torch.backends.cuda.enable_flash_sdp(False)
torch.set_float32_matmul_precision("high")
main()

from jsonargparse.cli import CLI

CLI(main)
22 changes: 15 additions & 7 deletions finetune_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from scripts.prepare_alpaca import generate_prompt


out_dir = "out/lora/alpaca"
eval_interval = 100
save_interval = 100
eval_iters = 100
Expand All @@ -39,20 +38,25 @@
warmup_steps = 100


def main():
def main(
data_dir: str = "data/alpaca",
pretrained_path: str = "checkpoints/lit-llama/7B/lit-llama.pth",
out_dir: str = "out/lora/alpaca",
):

fabric = L.Fabric(accelerator="cuda", devices=1, precision="bf16-mixed")
fabric.launch()
fabric.seed_everything(1337 + fabric.global_rank)

if fabric.global_rank == 0:
os.makedirs(out_dir, exist_ok=True)

train_data, val_data = load_datasets()
train_data, val_data = load_datasets(data_dir=data_dir)

config = LLaMAConfig.from_name("7B")
config.block_size = block_size

checkpoint = torch.load("checkpoints/lit-llama/7B/lit-llama.pth")
checkpoint = torch.load(pretrained_path)

with fabric.device, lora(r=lora_r, alpha=lora_alpha, dropout=lora_dropout, enabled=True):
torch.set_default_tensor_type(torch.HalfTensor)
Expand All @@ -65,7 +69,7 @@ def main():

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
model, optimizer = fabric.setup(model, optimizer)
train(fabric, model, optimizer, train_data, val_data)
train(fabric, model, optimizer, train_data, val_data, out_dir)

# Save the final LoRA checkpoint at the end of training
checkpoint = lora_state_dict(model)
Expand All @@ -78,6 +82,7 @@ def train(
optimizer: torch.optim.Optimizer,
train_data: np.ndarray,
val_data: np.ndarray,
out_dir: str,
) -> None:
"""The training loop.
Expand Down Expand Up @@ -189,7 +194,7 @@ def pad_right(x, pad_id):
return x, y


def load_datasets(data_dir: str = "data/alpaca"):
def load_datasets(data_dir):
train_data = torch.load(os.path.join(data_dir, "train.pt"))
val_data = torch.load(os.path.join(data_dir, "test.pt"))
return train_data, val_data
Expand All @@ -199,4 +204,7 @@ def load_datasets(data_dir: str = "data/alpaca"):
# Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false"
# torch.backends.cuda.enable_flash_sdp(False)
torch.set_float32_matmul_precision("high")
main()

from jsonargparse.cli import CLI

CLI(main)
44 changes: 44 additions & 0 deletions howto/download_weights.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
## Downloading pretrained weights

Except for when you are training from scratch, you will need the pretrained weights from Meta.
Download the model weights following the instructions on the official [LLaMA repository](https://github.com/facebookresearch/llama).

Once downloaded, you should have a folder like this:

```text
checkpoints/llama
├── 7B
│ ├── checklist.chk
│ ├── consolidated.00.pth
│ └── params.json
├── 13B
│ ...
├── tokenizer_checklist.chk
└── tokenizer.model
```

Convert the weights to the Lit-LLaMA format:

```bash
python scripts/convert_checkpoint.py \
--output_dir checkpoints/lit-llama \
--ckpt_dir checkpoints/llama \
--tokenizer_path checkpoints/llama/tokenizer.model \
--model_size 7B
```

You are all set. Now you can continue with inference or finetuning.

## Convert from HuggingFace

It is also possible to import weights in the format of the HuggingFace [LLaMA](https://huggingface.co/docs/transformers/main/en/model_doc/llama#transformers.LlamaForCausalLM) model.
Run this script to convert the weights for loading into Lit-LLaMA:

```bash
python scripts/convert_hf_checkpoint.py \
--hf_checkpoint_path path/to/hf/checkpoint/folder \
--lit_checkpoint checkpoints/lit-llama.pth
--model_size 7B
```

You can now run [`generate.py` to test the imported weights](inference.md).
97 changes: 97 additions & 0 deletions howto/finetune_adapter.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# Finetuning with Adapter

[LLaMA-Adapter](https://arxiv.org/abs/2303.16199) is a form of prefix-tuning that prepends a learnable adaption-prompt to the inputs of the attention blocks in LLaMA. In total, there are only 1.2M parameters to update during finetuning, which significantly reduces the memory footprint and speeds up training.

We are able to demonstrate instruction-finetuning Lit-LLaMA 7B on the [Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset on a **single GTX 3090 (24GB) GPU**. If using 8 GPUs, finetuning can be completed in under 1 hour.

## Preparation

The steps here only need to be done once:

1. Follow the instructions in the [README](README.md) to install the dependencies.
2. Download and convert the weights and save them in the `./checkpoints` folder as described [here](download_weights.md).
3. If you want to utilize more than one GPU, you should `pip install deepspeed`.
4. Download the data and generate the Alpaca instruction tuning dataset:

```bash
python scripts/prepare_alpaca.py
```

or [prepare your own dataset](#tune-on-your-own-dataset).

## Running the finetuning

```bash
python finetune_adapter.py
```

The finetuning requires at least one GPU with ~24 GB memory (GTX 3090).
You can speed up training by setting the `devices` variable in the script to utilize more GPUs if available.
Depending on the available GPU memory, you can also tune the `micro_batch_size` parameter to utilize the GPU efficiently.

For example, the follwing settings will let you finetune the model in under 1 hour using DeepSpeed Zero-2:
```python
devices = 8
micro_batch_size = 8
```

This script will save checkpoints periodically to the folder `out/`.

## Test the model

You can test the finetuned model with your own instructions by running:

```bash
python generate_adapter.py \
--prompt "Recommend a movie to watch on the weekend." \
--quantize llm.int8
```
Output:
```
A good movie to watch on the weekend would be The Lion King, since it's a classic family film that everyone can enjoy...
```
If your GPU supports `bfloat16`, the script will automatically use it. Together with `--quantize llm.int8`, this brings the memory consumption down to ~8 GB.

## Tune on your dataset

With only a few modifications, you can prepare and train on your own instruction dataset.

1. Create a json file in which each row holds one instruction-response pair.
A row has an entry for 'instruction', 'input', and 'output', where 'input' is optional an can be
the empty string if the instruction doesn't require a context. Below is an example json file:

```
[
{
"instruction": "Arrange the given numbers in ascending order.",
"input": "2, 4, 0, 8, 3",
"output": "0, 2, 3, 4, 8"
},
...
]
```
2. Make a copy of `scripts/prepare_alpaca.py` and name it what you want:
```bash
cp scripts/prepare_alpaca.py scripts/prepare_mydata.py
```
3. Modify `scripts/prepare_mydata.py` to read the json data file.
4. Run the script to generate the preprocessed, tokenized train-val split:
```bash
python scripts/prepare_mydata.py --destination_path data/mydata/
```
5. Run `finetune_adapter.py` by passing in the location of your data (and optionally other parameters):
```bash
python finetune_adapter.py --data_dir data/mydata/ --out_dir out/myexperiment
```
## Troubleshooting
If you run into a CUDA error "Expected is_sm80 to be true, but got false", uncomment the line
`torch.backends.cuda.enable_flash_sdp(False)` in the script below (see https://github.com/Lightning-AI/lit-llama/issues/101).
Loading

0 comments on commit 33ef184

Please sign in to comment.