diff --git a/whittle/evaluate_network.py b/whittle/evaluate_network.py index 17e72c7..e2d995b 100644 --- a/whittle/evaluate_network.py +++ b/whittle/evaluate_network.py @@ -2,6 +2,7 @@ import json from pathlib import Path +from typing import Any import torch from litgpt import Config @@ -53,7 +54,7 @@ def setup( metrics_path.parent.mkdir(parents=True, exist_ok=True) # sub-network saved as a config instead of the extracted lit_model.pth (to save memory) - subnet_config = None + subnet_config: dict[str, Any] | None = None if is_sub_network: ckp = torch.load(checkpoint_dir / "sub_network.pkl") subnet_config = ckp["config"] @@ -74,7 +75,7 @@ def setup( metrics["flops"] = compute_latency(model) if measure_latency: metrics["latency"] = compute_flops( - model, batch_size=latency_batch_size, device=device + model, batch_size=latency_batch_size, previous_device=device ) metrics_path.write_text(json.dumps(metrics, indent=2)) @@ -86,6 +87,7 @@ def setup( del ckp if is_sub_network: + assert subnet_config is not None model.select_sub_network(subnet_config) convert_and_evaluate( diff --git a/whittle/search_sub_networks.py b/whittle/search_sub_networks.py index 06e362d..3f22a5b 100644 --- a/whittle/search_sub_networks.py +++ b/whittle/search_sub_networks.py @@ -11,10 +11,10 @@ from lightning.fabric.strategies import FSDPStrategy from litgpt import Tokenizer from litgpt.args import EvalArgs, TrainArgs -from litgpt.data import DataModule, TinyStories, Alpaca +from litgpt.data import Alpaca, DataModule, TinyStories +from litgpt.finetune.lora import validate as finetune_validate from litgpt.model import Config from litgpt.pretrain import get_dataloaders, validate -from litgpt.finetune.lora import validate as finetune_validate from litgpt.utils import ( auto_download_checkpoint, check_nvlink_connectivity,