Standardize checkpoint filename and extension (Lightning-AI#142)
lantiga authored Apr 16, 2023
1 parent f135ed8 commit 528fa0d
Showing 11 changed files with 16 additions and 16 deletions.
README.md (2 changes: 1 addition & 1 deletion)
@@ -116,7 +116,7 @@ See `python generate.py --help` for more options.
You can also use GPTQ-style int4 quantization, but this requires converting the weights first:

```bash
-python quantize.py --checkpoint_path state_dict.pth --tokenizer_path tokenizer.model --output_path llama-7b-gptq.4bit.pt --dtype bfloat16 --quantize gptq.int4
+python quantize.py --checkpoint_path lit-llama.pth --tokenizer_path tokenizer.model --output_path llama-7b-gptq.4bit.pth --dtype bfloat16 --quantize gptq.int4
```

With the generated quantized checkpoint, generation works as usual with `--quantize gptq.int4`, bringing GPU usage down to about 5 GB. Since only the weights of the Linear layers are quantized, it is useful to use `--dtype bfloat16` even with quantization enabled.
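
For reference, a follow-up generation call with the quantized checkpoint might look roughly like the sketch below. The flag names follow the `quantize.py` invocation and the `--dtype` note above; the checkpoint filename assumes the `--output_path` used in the quantize step:

```bash
python generate.py \
  --quantize gptq.int4 \
  --checkpoint_path llama-7b-gptq.4bit.pth \
  --tokenizer_path tokenizer.model \
  --dtype bfloat16
```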
evaluate.py (2 changes: 1 addition & 1 deletion)
@@ -65,7 +65,7 @@ def main(
``"gptq.int4"``: GPTQ 4-bit mode.
"""
if not checkpoint_path:
checkpoint_path = Path(f"./checkpoints/lit-llama/{model_size}/state_dict.pth")
checkpoint_path = Path(f"./checkpoints/lit-llama/{model_size}/lit-llama.pth")
if not tokenizer_path:
tokenizer_path = Path("./checkpoints/lit-llama/tokenizer.model")
assert checkpoint_path.is_file()
finetune_adapter.py (6 changes: 3 additions & 3 deletions)
@@ -26,7 +26,7 @@
 from lightning.fabric.strategies import DeepSpeedStrategy


-pretrained_path = "checkpoints/lit-llama/7B/state_dict.pth"
+pretrained_path = "checkpoints/lit-llama/7B/lit-llama.pth"
 out_dir = "out/adapter/alpaca"
 eval_interval = 600
 save_interval = 1000
@@ -95,7 +95,7 @@ def main():
     train(fabric, model, optimizer, train_data, val_data)

     # Save the final checkpoint at the end of training
-    save_model_checkpoint(fabric, model, os.path.join(out_dir, "alpaca-adapter-finetuned.ckpt"))
+    save_model_checkpoint(fabric, model, os.path.join(out_dir, "alpaca-adapter-finetuned.pth"))


 def train(
@@ -140,7 +140,7 @@ def train(
             if step_count % save_interval == 0:
                 print(f"Saving adapter weights to {out_dir}")
                 # TODO: Provide a function/script to merge the adapter weights with pretrained weights
-                save_model_checkpoint(fabric, model, os.path.join(out_dir, f"iter-{iter_num:06d}.ckpt"))
+                save_model_checkpoint(fabric, model, os.path.join(out_dir, f"iter-{iter_num:06d}.pth"))

         dt = time.time() - t0
         if iter_num % log_interval == 0:
finetune_lora.py (4 changes: 2 additions & 2 deletions)
@@ -51,7 +51,7 @@ def main():
     config = LLaMAConfig.from_name("7B")
     config.block_size = block_size

-    checkpoint = torch.load("checkpoints/lit-llama/7B/state_dict.pth")
+    checkpoint = torch.load("checkpoints/lit-llama/7B/lit-llama.pth")

     with fabric.device, lora(r=lora_r, alpha=lora_alpha, dropout=lora_dropout, enabled=True):
         torch.set_default_tensor_type(torch.HalfTensor)
@@ -110,7 +110,7 @@ def train(
             # We are only saving the LoRA weights
             # TODO: Provide a function/script to merge the LoRA weights with pretrained weights
             checkpoint = lora_state_dict(model)
-            fabric.save(os.path.join(out_dir, f"iter-{iter_num:06d}-ckpt.pt"), checkpoint)
+            fabric.save(os.path.join(out_dir, f"iter-{iter_num:06d}-ckpt.pth"), checkpoint)

         dt = time.time() - t0
         if iter_num % log_interval == 0:
generate.py (2 changes: 1 addition & 1 deletion)
@@ -99,7 +99,7 @@ def main(
``"gptq.int4"``: GPTQ 4-bit mode.
"""
if not checkpoint_path:
checkpoint_path = Path(f"./checkpoints/lit-llama/{model_size}/state_dict.pth")
checkpoint_path = Path(f"./checkpoints/lit-llama/{model_size}/lit-llama.pth")
if not tokenizer_path:
tokenizer_path = Path("./checkpoints/lit-llama/tokenizer.model")
assert checkpoint_path.is_file()
generate_adapter.py (4 changes: 2 additions & 2 deletions)
@@ -49,9 +49,9 @@ def main(
``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``.
"""
if not adapter_path:
adapter_path = Path("out/adapter/alpaca/alpaca-adapter-finetuned.pt")
adapter_path = Path("out/adapter/alpaca/alpaca-adapter-finetuned.pth")
if not pretrained_path:
pretrained_path = Path(f"./checkpoints/lit-llama/7B/state_dict.pth")
pretrained_path = Path(f"./checkpoints/lit-llama/7B/lit-llama.pth")
if not tokenizer_path:
tokenizer_path = Path("./checkpoints/lit-llama/tokenizer.model")

lit_llama/utils.py (4 changes: 2 additions & 2 deletions)
@@ -25,7 +25,7 @@ def save_model_checkpoint(fabric, model, file_path):
         fabric.barrier()
         if fabric.global_rank == 0:
             # Create a consolidated checkpoint with the same name next to the deepspeed checkpoint
-            convert_zero_checkpoint_to_fp32_state_dict(file_path, file_path.with_suffix(".pt"))
+            convert_zero_checkpoint_to_fp32_state_dict(file_path, file_path.with_suffix(".pth"))
         return

     if isinstance(fabric.strategy, FSDPStrategy):
@@ -54,7 +54,7 @@ def __init__(self, device=None, dtype=None, quantization_mode=None):
         Example::
             with EmptyInitOnDevice("cuda", dtype=torch.bfloat16):
                 model = LLaMA.from_name('7B')
-                model.load_state_dict(torch.load('llama-lit/7B/state_dict.pth'))"""
+                model.load_state_dict(torch.load('llama-lit/7B/lit-llama.pth'))"""

         self.quantization_mode = quantization_mode
         self.quantized_linear_cls = None
quantize.py (2 changes: 1 addition & 1 deletion)
@@ -163,7 +163,7 @@ def main(
         Note that ``"llm.int8"`` does not need a quantization step.
     """
     if not checkpoint_path:
-        checkpoint_path = Path(f"./checkpoints/lit-llama/{model_size}/state_dict.pth")
+        checkpoint_path = Path(f"./checkpoints/lit-llama/{model_size}/lit-llama.pth")
     if not tokenizer_path:
         tokenizer_path = Path("./checkpoints/lit-llama/tokenizer.model")
     assert checkpoint_path.is_file()
scripts/convert_checkpoint.py (2 changes: 1 addition & 1 deletion)
@@ -138,7 +138,7 @@ def meta_weights_for_nano_model(
         del attn
         gc.collect()

-    torch.save(combined, Path(output_dir, "state_dict.pth"))
+    torch.save(combined, Path(output_dir, "lit-llama.pth"))


 if __name__ == "__main__":
scripts/convert_hf_checkpoint.py (2 changes: 1 addition & 1 deletion)
@@ -14,7 +14,7 @@
 def convert_hf_checkpoint(
     model_size: str = "7B",
     hf_checkpoint_path: Path = Path("checkpoints/llama-7b-hf"),
-    lit_checkpoint: Path = Path("checkpoints/lit-llama.ckpt"),
+    lit_checkpoint: Path = Path("checkpoints/lit-llama.pth"),
     verify: bool = False,
 ) -> None:
     """
train.py (2 changes: 1 addition & 1 deletion)
@@ -93,7 +93,7 @@ def train(
             val_loss = validate(fabric, model, val_data)
             fabric.print(f"step {iter_num}: val loss {val_loss:.4f}")
             fabric.print(f"Saving checkpoint to {out_dir}")
-            save_model_checkpoint(fabric, model, os.path.join(out_dir, f"iter-{iter_num:06d}-ckpt.pt"))
+            save_model_checkpoint(fabric, model, os.path.join(out_dir, f"iter-{iter_num:06d}-ckpt.pth"))

         t0 = time.time()

