Save checkpoint in train.py (Lightning-AI#89)
awaelchli authored Apr 4, 2023
1 parent ac83160 commit f422a48
Showing 2 changed files with 30 additions and 5 deletions.
24 changes: 24 additions & 0 deletions lit_llama/utils.py
@@ -0,0 +1,24 @@
"""Utility functions for training and inference."""

import torch
from lightning.fabric.strategies import FSDPStrategy
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType, FullStateDictConfig


def save_model_checkpoint(fabric, model, file_path):
"""Handles boilerplate logic for retrieving and saving the state_dict.
This will be upstreamed to Fabric soon.
"""

if isinstance(fabric.strategy, FSDPStrategy):
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
state_dict = model._forward_module.state_dict()
else:
state_dict = model.state_dict()

if fabric.global_rank == 0:
torch.save(state_dict, file_path)
fabric.barrier()
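
For context, the sketch below shows how this helper could be called from a Fabric script. The single-device CPU strategy and the tiny placeholder model are illustrative assumptions, not part of this commit; under an FSDPStrategy the helper instead gathers the full, CPU-offloaded state_dict on rank 0 before writing it.

import torch
import lightning as L

from lit_llama.utils import save_model_checkpoint

fabric = L.Fabric(accelerator="cpu", devices=1)  # illustrative strategy
fabric.launch()
model = fabric.setup(torch.nn.Linear(8, 8))  # placeholder model

# Non-FSDP strategies fall through to a plain state_dict(); rank 0 writes
# the file and every rank waits on the barrier before continuing.
save_model_checkpoint(fabric, model, "model-ckpt.pt")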
11 changes: 6 additions & 5 deletions train.py
@@ -17,8 +17,10 @@
 import numpy as np
 
 from lit_llama.model import Block, LLaMA, LLaMAConfig
+from lit_llama.utils import save_model_checkpoint
 
-out_dir = "out"
+
+out_dir = "out/training"
 eval_interval = 2000
 eval_iters = 200
 log_interval = 1
@@ -87,12 +89,11 @@ def train(
         # TODO: add learning rate scheduling
 
         # evaluate the loss on train/val sets and write checkpoints
-        if iter_num > 0 and iter_num % eval_interval == 0 and fabric.global_rank == 0:
+        if iter_num > 0 and iter_num % eval_interval == 0:
             val_loss = validate(fabric, model, val_data)
             fabric.print(f"step {iter_num}: val loss {val_loss:.4f}")
-            # TODO: Save with Fabric
-            # print(f"saving checkpoint to {out_dir}")
-            # torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
+            fabric.print(f"Saving checkpoint to {out_dir}")
+            save_model_checkpoint(fabric, model, os.path.join(out_dir, f"iter-{iter_num:06d}-ckpt.pt"))
 
         t0 = time.time()
 
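Since the checkpoint is a plain consolidated state_dict rather than a Fabric- or FSDP-sharded file, loading it back for evaluation would look roughly like the sketch below. The "7B" config name and the iteration number are illustrative assumptions; use whatever config/checkpoint pair the training run actually produced.

import torch

from lit_llama.model import LLaMA, LLaMAConfig

# Rebuild the same architecture the checkpoint was trained with.
model = LLaMA(LLaMAConfig.from_name("7B"))  # illustrative config name

# Matches f"iter-{iter_num:06d}-ckpt.pt" at iter_num=2000 under out_dir="out/training".
state_dict = torch.load("out/training/iter-002000-ckpt.pt", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()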
