Standardize checkpoint filename and extension #142

Merged · 1 commit · Apr 16, 2023

README.md (1 addition, 1 deletion)

@@ -116,7 +116,7 @@ See `python generate.py --help` for more options.
You can also use GPTQ-style int4 quantization, but this needs conversions of the weights first:

```bash
-python quantize.py --checkpoint_path state_dict.pth --tokenizer_path tokenizer.model --output_path llama-7b-gptq.4bit.pt --dtype bfloat16 --quantize gptq.int4
+python quantize.py --checkpoint_path lit-llama.pth --tokenizer_path tokenizer.model --output_path llama-7b-gptq.4bit.pth --dtype bfloat16 --quantize gptq.int4
```

With the generated quantized checkpoint, generation works as usual with `--quantize gptq.int4`, bringing GPU usage down to about 5 GB. Since only the weights of the Linear layers are quantized, it is still useful to pass `--dtype bfloat16` even with quantization enabled.
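For context on the change above: once the quantized checkpoint has been produced, the README states that generation works as usual with `--quantize gptq.int4`. A minimal usage sketch follows; the `generate.py` flag names are assumptions inferred from the parameters visible elsewhere in this diff (`checkpoint_path`, `tokenizer_path`, `quantize`) and are not confirmed against the full script.

```bash
# Hypothetical invocation: the flag names are assumed; only the file names and
# the gptq.int4 mode come from this PR.
python generate.py \
  --checkpoint_path llama-7b-gptq.4bit.pth \
  --tokenizer_path tokenizer.model \
  --quantize gptq.int4
```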

evaluate.py (1 addition, 1 deletion)

@@ -65,7 +65,7 @@ def main(
``"gptq.int4"``: GPTQ 4-bit mode.
"""
if not checkpoint_path:
-checkpoint_path = Path(f"./checkpoints/lit-llama/{model_size}/state_dict.pth")
+checkpoint_path = Path(f"./checkpoints/lit-llama/{model_size}/lit-llama.pth")
if not tokenizer_path:
tokenizer_path = Path("./checkpoints/lit-llama/tokenizer.model")
assert checkpoint_path.is_file()

finetune_adapter.py (3 additions, 3 deletions)

@@ -26,7 +26,7 @@
from lightning.fabric.strategies import DeepSpeedStrategy


-pretrained_path = "checkpoints/lit-llama/7B/state_dict.pth"
+pretrained_path = "checkpoints/lit-llama/7B/lit-llama.pth"
out_dir = "out/adapter/alpaca"
eval_interval = 600
save_interval = 1000
@@ -95,7 +95,7 @@ def main():
train(fabric, model, optimizer, train_data, val_data)

# Save the final checkpoint at the end of training
-save_model_checkpoint(fabric, model, os.path.join(out_dir, "alpaca-adapter-finetuned.ckpt"))
+save_model_checkpoint(fabric, model, os.path.join(out_dir, "alpaca-adapter-finetuned.pth"))


def train(
@@ -140,7 +140,7 @@ def train(
if step_count % save_interval == 0:
print(f"Saving adapter weights to {out_dir}")
# TODO: Provide a function/script to merge the adapter weights with pretrained weights
-save_model_checkpoint(fabric, model, os.path.join(out_dir, f"iter-{iter_num:06d}.ckpt"))
+save_model_checkpoint(fabric, model, os.path.join(out_dir, f"iter-{iter_num:06d}.pth"))

dt = time.time() - t0
if iter_num % log_interval == 0:

finetune_lora.py (2 additions, 2 deletions)

@@ -51,7 +51,7 @@ def main():
config = LLaMAConfig.from_name("7B")
config.block_size = block_size

-checkpoint = torch.load("checkpoints/lit-llama/7B/state_dict.pth")
+checkpoint = torch.load("checkpoints/lit-llama/7B/lit-llama.pth")

with fabric.device, lora(r=lora_r, alpha=lora_alpha, dropout=lora_dropout, enabled=True):
torch.set_default_tensor_type(torch.HalfTensor)
@@ -110,7 +110,7 @@ def train(
# We are only saving the LoRA weights
# TODO: Provide a function/script to merge the LoRA weights with pretrained weights
checkpoint = lora_state_dict(model)
-fabric.save(os.path.join(out_dir, f"iter-{iter_num:06d}-ckpt.pt"), checkpoint)
+fabric.save(os.path.join(out_dir, f"iter-{iter_num:06d}-ckpt.pth"), checkpoint)

dt = time.time() - t0
if iter_num % log_interval == 0:

generate.py (1 addition, 1 deletion)

@@ -99,7 +99,7 @@ def main(
``"gptq.int4"``: GPTQ 4-bit mode.
"""
if not checkpoint_path:
-checkpoint_path = Path(f"./checkpoints/lit-llama/{model_size}/state_dict.pth")
+checkpoint_path = Path(f"./checkpoints/lit-llama/{model_size}/lit-llama.pth")
if not tokenizer_path:
tokenizer_path = Path("./checkpoints/lit-llama/tokenizer.model")
assert checkpoint_path.is_file()

generate_adapter.py (2 additions, 2 deletions)

@@ -49,9 +49,9 @@ def main(
``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``.
"""
if not adapter_path:
-adapter_path = Path("out/adapter/alpaca/alpaca-adapter-finetuned.pt")
+adapter_path = Path("out/adapter/alpaca/alpaca-adapter-finetuned.pth")
if not pretrained_path:
-pretrained_path = Path(f"./checkpoints/lit-llama/7B/state_dict.pth")
+pretrained_path = Path(f"./checkpoints/lit-llama/7B/lit-llama.pth")
if not tokenizer_path:
tokenizer_path = Path("./checkpoints/lit-llama/tokenizer.model")

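Usage note: with the standardized names, the default `adapter_path` above now matches the file written by finetune_adapter.py, and the default `pretrained_path` matches the converted base checkpoint. A hedged sketch of an explicit invocation follows; the flag names simply mirror the parameter names shown in this hunk and are assumptions, not confirmed against the full script.

```bash
# Hypothetical invocation: flag names are assumed from the parameters above.
python generate_adapter.py \
  --adapter_path out/adapter/alpaca/alpaca-adapter-finetuned.pth \
  --pretrained_path checkpoints/lit-llama/7B/lit-llama.pth \
  --tokenizer_path checkpoints/lit-llama/tokenizer.model
```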

lit_llama/utils.py (2 additions, 2 deletions)

@@ -25,7 +25,7 @@ def save_model_checkpoint(fabric, model, file_path):
fabric.barrier()
if fabric.global_rank == 0:
# Create a consolidated checkpoint with the same name next to the deepspeed checkpoint
-convert_zero_checkpoint_to_fp32_state_dict(file_path, file_path.with_suffix(".pt"))
+convert_zero_checkpoint_to_fp32_state_dict(file_path, file_path.with_suffix(".pth"))
return

if isinstance(fabric.strategy, FSDPStrategy):
@@ -54,7 +54,7 @@ def __init__(self, device=None, dtype=None, quantization_mode=None):
Example::
with EmptyInitOnDevice("cuda", dtype=torch.bfloat16):
model = LLaMA.from_name('7B')
-model.load_state_dict(torch.load('llama-lit/7B/state_dict.pth'))"""
+model.load_state_dict(torch.load('llama-lit/7B/lit-llama.pth'))"""

self.quantization_mode = quantization_mode
self.quantized_linear_cls = None

quantize.py (1 addition, 1 deletion)

@@ -163,7 +163,7 @@ def main(
Note that ``"llm.int8"`` does not need a quantization step.
"""
if not checkpoint_path:
-checkpoint_path = Path(f"./checkpoints/lit-llama/{model_size}/state_dict.pth")
+checkpoint_path = Path(f"./checkpoints/lit-llama/{model_size}/lit-llama.pth")
if not tokenizer_path:
tokenizer_path = Path("./checkpoints/lit-llama/tokenizer.model")
assert checkpoint_path.is_file()

scripts/convert_checkpoint.py (1 addition, 1 deletion)

@@ -138,7 +138,7 @@ def meta_weights_for_nano_model(
del attn
gc.collect()

-torch.save(combined, Path(output_dir, "state_dict.pth"))
+torch.save(combined, Path(output_dir, "lit-llama.pth"))


if __name__ == "__main__":
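For orientation, the file written here is what the default paths in evaluate.py, generate.py and quantize.py expect to find. Assuming `output_dir` points at `checkpoints/lit-llama/7B` (an assumption; the invocation is not shown in this diff), the resulting layout would look like this, with `tokenizer.model` at the location used by the `tokenizer_path` defaults above:

```
checkpoints/
└── lit-llama/
    ├── 7B/
    │   └── lit-llama.pth
    └── tokenizer.model
```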

scripts/convert_hf_checkpoint.py (1 addition, 1 deletion)

@@ -14,7 +14,7 @@
def convert_hf_checkpoint(
model_size: str = "7B",
hf_checkpoint_path: Path = Path("checkpoints/llama-7b-hf"),
-lit_checkpoint: Path = Path("checkpoints/lit-llama.ckpt"),
+lit_checkpoint: Path = Path("checkpoints/lit-llama.pth"),
verify: bool = False,
) -> None:
"""

train.py (1 addition, 1 deletion)

@@ -93,7 +93,7 @@ def train(
val_loss = validate(fabric, model, val_data)
fabric.print(f"step {iter_num}: val loss {val_loss:.4f}")
fabric.print(f"Saving checkpoint to {out_dir}")
-save_model_checkpoint(fabric, model, os.path.join(out_dir, f"iter-{iter_num:06d}-ckpt.pt"))
+save_model_checkpoint(fabric, model, os.path.join(out_dir, f"iter-{iter_num:06d}-ckpt.pth"))

t0 = time.time()
