
Modify search and eval so that the checkpoint loading is seamless for both litgpt-like checkpoints and whittle subnets. Fix bins in test.
gabikadlecova committed Feb 13, 2025
1 parent c843b3c commit 52b5555
Showing 3 changed files with 29 additions and 24 deletions.
4 changes: 3 additions & 1 deletion test/test_search.py
@@ -39,7 +39,9 @@ def params_estimator(config):
 @pytest.mark.parametrize("search_strategy", methods)
 def test_multi_objective_search(search_strategy, num_samples=5):
     bins, _ = (
-        param_bins(10, 2, 1) if search_strategy == "stratified_random_search" else None
+        param_bins(10, 2, 1)
+        if search_strategy == "stratified_random_search"
+        else (None, None)
     )
 
     results = multi_objective_search(
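The change to the `else` branch matters because the left-hand side unpacks two values: a bare `None` raises `TypeError: cannot unpack non-iterable NoneType` for every strategy other than stratified random search. A minimal standalone sketch of the failure mode (`make_bins` is a hypothetical stand-in for `param_bins(10, 2, 1)`, which returns a pair):

```python
def make_bins():
    # stand-in for param_bins(10, 2, 1), which returns a two-element tuple
    return ["bin-a", "bin-b"], {"meta": "data"}

strategy = "random_search"

# old code: raises TypeError: cannot unpack non-iterable NoneType
# bins, _ = make_bins() if strategy == "stratified_random_search" else None

# fixed code: the else branch also yields a pair, so unpacking always succeeds
bins, _ = make_bins() if strategy == "stratified_random_search" else (None, None)
print(bins)  # None for every strategy except stratified_random_search
```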
41 changes: 20 additions & 21 deletions whittle/evaluate_network.py
@@ -6,6 +6,7 @@
 
 import torch
 from litgpt import Config
+from litgpt.utils import lazy_load
 
 from whittle.eval.utils import convert_and_evaluate
 from whittle.metrics import compute_flops, compute_latency, compute_parameters
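litgpt's `lazy_load` is a memory-friendly alternative to a plain `torch.load`: it lets the code below peek at small metadata entries such as `parent_dir` without eagerly materializing all weight tensors, which helps because a sub-network checkpoint can cause the file to be opened twice. A minimal usage sketch (the path is hypothetical):

```python
from litgpt.utils import lazy_load

ckp = lazy_load("checkpoints/my-model/lit_model.pth")  # hypothetical path
# dict-style access, as used in setup() below
print(ckp.get("parent_dir", None))
```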
@@ -22,8 +23,6 @@ def setup(
     latency_batch_size: int = 8,
     device: str | None = None,
     limit: float | None = None,
-    tokenizer_name_or_path: str | None = None,
-    is_sub_network: bool = False,
     measure_flops: bool = False,
     measure_latency: bool = False,
 ) -> None:
@@ -41,8 +40,6 @@ def setup(
         latency_batch_size: Batch size for latency computation.
         device: Device to use for evaluation, for example, "cuda" or "cuda:0".
         limit: Limit on number of examples per task.
-        tokenizer_name_or_path: Name or path to the tokenizer file to use for the model. Default is checkpoint_dir.
-        is_sub_network: Whether the model is a sub-network config or a whittle model. Default is False.
         measure_flops: Whether to compute FLOPs. Default is False.
         measure_latency: Whether to compute latency. Default is False.
     """
@@ -54,42 +51,44 @@ def setup(
     metrics_path.parent.mkdir(parents=True, exist_ok=True)
 
     # sub-network saved as a config instead of the extracted lit_model.pth (to save memory)
-    subnet_config: dict[str, Any] | None = None
-    if is_sub_network:
-        ckp = torch.load(checkpoint_dir / "sub_network.pkl")
-        subnet_config = ckp["config"]
-        checkpoint_dir = ckp["parent_dir"]
+    sub_network_config: dict[str, Any] | None = None
+    ckp = lazy_load(checkpoint_dir / "lit_model.pth")
+
+    # sub-network config loading (contains the config and checkpoint path of the parent)
+    sub_network_config = ckp.get("sub_network_config", None)
+    parent_dir = ckp.get("parent_dir", None)
+    if parent_dir is not None:
+        checkpoint_dir = Path(parent_dir)
+        ckp = lazy_load(checkpoint_dir / "lit_model.pth")
 
     config = Config.from_file(checkpoint_dir / "model_config.yaml")
     config.fix_head_size = True
     config.model_type = "gpt"
     config.tie_embeddings = False
 
     model = GPT(config)
-    model.name_or_path = tokenizer_name_or_path  # WhittleLM loads AutoTokenizer inside
+    model.name_or_path = checkpoint_dir  # WhittleLM loads AutoTokenizer inside
+
+    model.load_state_dict(ckp["model"] if "model" in ckp else ckp)
+    del ckp
+
+    # if the checkpoint was a sub-network, set it at this point
+    if sub_network_config is not None:
+        model.select_sub_network(sub_network_config)
 
     # compute metrics
     metrics = {}
     metrics["parameters"] = compute_parameters(model)
 
     if measure_flops:
         metrics["flops"] = compute_flops(model)
     if measure_latency:
         metrics["latency"] = compute_latency(
             model, batch_size=latency_batch_size, previous_device=device
         )
 
     metrics_path.write_text(json.dumps(metrics, indent=2))
 
     # downstream task evaluation
     model.to(device)
 
-    ckp = torch.load(checkpoint_dir / "lit_model.pth")
-    model.load_state_dict(ckp["model"] if "model" in ckp else ckp)
-    del ckp
-
-    if is_sub_network:
-        assert subnet_config is not None
-        model.select_sub_network(subnet_config)
-
     convert_and_evaluate(
         model,
         out_dir=out_dir,
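The net effect is that `setup` no longer needs to be told what kind of checkpoint it receives: every directory is opened the same way, and sub-network metadata is discovered from keys inside the checkpoint itself. A minimal standalone sketch of the dispatch logic (the `load_any_checkpoint` helper is hypothetical; the real code uses litgpt's `lazy_load` and then builds a whittle `GPT`):

```python
from pathlib import Path
from typing import Any

import torch


def load_any_checkpoint(checkpoint_dir: Path) -> tuple[dict, dict[str, Any] | None, Path]:
    """Return (state_dict, sub_network_config, resolved_checkpoint_dir).

    Handles both layouts transparently:
      * litgpt-like: lit_model.pth holds the state dict (optionally under "model")
      * whittle sub-network: lit_model.pth additionally holds "sub_network_config"
        and "parent_dir", which points at the super-network weights
    """
    ckp = torch.load(checkpoint_dir / "lit_model.pth", map_location="cpu")

    sub_network_config = ckp.get("sub_network_config", None)
    parent_dir = ckp.get("parent_dir", None)
    if parent_dir is not None:
        # follow the pointer and load the parent (super-network) weights instead
        checkpoint_dir = Path(parent_dir)
        ckp = torch.load(checkpoint_dir / "lit_model.pth", map_location="cpu")

    state_dict = ckp["model"] if "model" in ckp else ckp
    return state_dict, sub_network_config, checkpoint_dir
```

The caller loads `state_dict` into the full super-network and calls `model.select_sub_network(sub_network_config)` only when the config is present, so plain litgpt checkpoints take the same code path with the two extra keys simply absent.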
8 changes: 6 additions & 2 deletions whittle/search_sub_networks.py
@@ -26,7 +26,9 @@
     load_checkpoint,
     parse_devices,
     save_config,
+    copy_config_files,
 )
+
 from torch.utils.data import DataLoader
 
 from whittle.args import ParamBinArgs, SearchArgs
@@ -324,12 +326,14 @@ def main(
         sub_network = extract_current_sub_network(model)
         model.reset_super_network()
 
-        fabric.save(save_path, {"model": sub_network, "parent_dir": checkpoint_dir})
+        fabric.save(save_path, {"model": sub_network})
         save_config(sub_network.config, out_dir / f"sub_network_{i}")
+        copy_config_files(checkpoint_dir, save_path.parent)
     else:
         save_path = save_path.parent / "sub_network.pkl"
         torch.save(
-            {"config": sub_network_dict, "parent_dir": checkpoint_dir}, save_path
+            {"sub_network_config": sub_network_dict, "parent_dir": checkpoint_dir},
+            save_path,
         )
 
     # save all paths to pareto optimal sub-networks
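The two save branches now line up with the loader: an extracted sub-network is written as a self-contained, litgpt-like directory (weights via `fabric.save` plus the config files copied from the parent), while the lightweight branch stores only the sub-network config and a pointer to the parent checkpoint, using the same `sub_network_config` key that the loader in `evaluate_network.py` looks for. A hypothetical round-trip of the lightweight format (the path follows the `f"sub_network_{i}"` pattern above):

```python
from pathlib import Path
import torch

# hypothetical search output directory
pkl_path = Path("out/search/sub_network_0/sub_network.pkl")

ckp = torch.load(pkl_path)
print(sorted(ckp.keys()))  # ['parent_dir', 'sub_network_config']

# a consumer rebuilds the super-network from the parent checkpoint and
# then shrinks it, mirroring select_sub_network in evaluate_network.py:
#   model = GPT(Config.from_file(Path(ckp["parent_dir"]) / "model_config.yaml"))
#   model.load_state_dict(...)
#   model.select_sub_network(ckp["sub_network_config"])
```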
