Skip to content

Commit

Permalink
Fix flops device. Add typing
Browse files Browse the repository at this point in the history
  • Loading branch information
gabikadlecova committed Feb 12, 2025
1 parent f4dbd67 commit 6cfed4b
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
6 changes: 4 additions & 2 deletions whittle/evaluate_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import json
from pathlib import Path
from typing import Any

import torch
from litgpt import Config
Expand Down Expand Up @@ -53,7 +54,7 @@ def setup(
metrics_path.parent.mkdir(parents=True, exist_ok=True)

# sub-network saved as a config instead of the extracted lit_model.pth (to save memory)
subnet_config = None
subnet_config: dict[str, Any] | None = None
if is_sub_network:
ckp = torch.load(checkpoint_dir / "sub_network.pkl")
subnet_config = ckp["config"]
Expand All @@ -74,7 +75,7 @@ def setup(
metrics["flops"] = compute_latency(model)
if measure_latency:
metrics["latency"] = compute_flops(
model, batch_size=latency_batch_size, device=device
model, batch_size=latency_batch_size, previous_device=device
)

metrics_path.write_text(json.dumps(metrics, indent=2))
Expand All @@ -86,6 +87,7 @@ def setup(
del ckp

if is_sub_network:
assert subnet_config is not None
model.select_sub_network(subnet_config)

convert_and_evaluate(
Expand Down
4 changes: 2 additions & 2 deletions whittle/search_sub_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
from lightning.fabric.strategies import FSDPStrategy
from litgpt import Tokenizer
from litgpt.args import EvalArgs, TrainArgs
from litgpt.data import DataModule, TinyStories, Alpaca
from litgpt.data import Alpaca, DataModule, TinyStories
from litgpt.finetune.lora import validate as finetune_validate
from litgpt.model import Config
from litgpt.pretrain import get_dataloaders, validate
from litgpt.finetune.lora import validate as finetune_validate
from litgpt.utils import (
auto_download_checkpoint,
check_nvlink_connectivity,
Expand Down

0 comments on commit 6cfed4b

Please sign in to comment.