feat: search results - evaluation and results processing #249

Merged
merged 33 commits into from
Feb 18, 2025
Changes from 26 commits
Commits (33)
f62a1d3
Implement binning of nets by param count during search
gabikadlecova Feb 5, 2025
c09d528
Add test for binning, fix an off-by-1 error. Formatting.
gabikadlecova Feb 7, 2025
99efc41
Workflow for downstream eval. Save pareto front in search. Plotting o…
gabikadlecova Feb 9, 2025
da068f7
Modify search to use other objectives than val_loss and parameters. E…
gabikadlecova Feb 10, 2025
168d35f
Formatting
gabikadlecova Feb 11, 2025
698a5ac
Add typing
gabikadlecova Feb 11, 2025
4eaef06
Add model_config.yaml to the search checkpoint test. Add some typing
gabikadlecova Feb 11, 2025
f046d7e
Test format
gabikadlecova Feb 11, 2025
80660ae
Fix perplexity calculation in search
gabikadlecova Feb 11, 2025
a9fe3eb
Fix perplexity computation
gabikadlecova Feb 11, 2025
cad9b2a
format
gabikadlecova Feb 11, 2025
cf53449
Remove plotting and results from the library. Convert param bins reje…
gabikadlecova Feb 12, 2025
9a85591
Merge changes from main
gabikadlecova Feb 12, 2025
077fd7f
Format, add annotations
gabikadlecova Feb 12, 2025
edb2071
Import sorting
gabikadlecova Feb 12, 2025
47b986c
Import sorting tests
gabikadlecova Feb 12, 2025
17fdb68
Format
gabikadlecova Feb 12, 2025
8f5fc56
Fix stratified strategy test
gabikadlecova Feb 12, 2025
d92ee72
Reformat test
gabikadlecova Feb 12, 2025
4858167
Adapt the eval and search code to finetuned models. Allow head size i…
gabikadlecova Feb 12, 2025
f4dbd67
Format
gabikadlecova Feb 12, 2025
6cfed4b
Fix flops device. Add typing
gabikadlecova Feb 12, 2025
805761a
Fix names in docstrings
gabikadlecova Feb 12, 2025
c843b3c
Merge branch 'main' into search-results-comparison
gabikadlecova Feb 12, 2025
52b5555
Modify search and eval so that the checkpoint loading is seamless for…
gabikadlecova Feb 13, 2025
ffd56ae
Make ruff happy
gabikadlecova Feb 13, 2025
37d03c5
Fix search test and param bins so that it's not in an inf loop
gabikadlecova Feb 14, 2025
06dfa1e
Check for intersection instead of whole dir contents in checkpoint
gabikadlecova Feb 14, 2025
b875c27
Swap latency and flops to be correct
gabikadlecova Feb 14, 2025
3e3293a
Modify how models are checkpointed in search. Add patience to param b…
gabikadlecova Feb 14, 2025
77b6a00
Control search verbosity via tqdm. Write tests for eval, modify searc…
gabikadlecova Feb 17, 2025
70e328f
Swap args in fabric.save
gabikadlecova Feb 17, 2025
5efa912
Small docs changes, rename objectives, move check of objectives to se…
gabikadlecova Feb 18, 2025
44 changes: 44 additions & 0 deletions test/test_extract.py
@@ -1,7 +1,10 @@
from __future__ import annotations

import pathlib

import torch
from litgpt.config import Config
from litgpt.utils import save_config

from whittle.models.gpt import GPT
from whittle.models.gpt.extract import extract_current_sub_network, extract_sub_network
@@ -87,6 +90,47 @@ def test_extract_sub_network_llamamlp() -> None:
)


def test_save_config() -> None:
config = Config.from_name("micro-llama-300M")
config.fix_head_size = False

super_network = GPT(config)

# set norm weights to random values
# (the default unit/zero weights would not exercise the extraction)
make_norm_weights_random(super_network.transformer.ln_f)
for i in range(config.n_layer):
block = super_network.transformer.h[i]
make_norm_weights_random(block.norm_1)
make_norm_weights_random(block.post_attention_norm)
make_norm_weights_random(block.norm_2)
make_norm_weights_random(block.post_mlp_norm)

# simulate a smaller network
n_embd = 128
intermediate_size = 1024
n_layer = 6
n_head = 12

super_network.eval()
super_network.set_sub_network(
sub_network_n_embd=n_embd,
sub_network_intermediate_size=intermediate_size,
sub_network_num_heads=n_head,
sub_network_n_layers=n_layer,
)

# instantiate a new model
sub_network = extract_current_sub_network(super_network)
save_dir = pathlib.Path("microllama")
save_dir.mkdir(parents=True, exist_ok=True)
save_config(sub_network.config, save_dir)

cfg = Config.from_file(save_dir / "model_config.yaml")
for key in cfg.__dict__:
assert cfg.__dict__[key] == sub_network.config.__dict__[key]


def make_norm_weights_random(norm):
if norm is None or isinstance(norm, torch.nn.Identity):
return
60 changes: 60 additions & 0 deletions test/test_search.py
@@ -4,6 +4,7 @@
import pytest
from syne_tune.config_space import randint

from whittle.sampling.param_bins import ParamBins
from whittle.search import multi_objective_search
from whittle.search.baselines import methods

@@ -15,14 +16,41 @@ def objective(config, **kwargs):
search_space = {"a": randint(0, 10), "b": randint(0, 100)}


def param_bins(bin_n, bin_s, bin_t):
def params_estimator(config):
return config["a"]

bin_width = 10
min_config = {"a": 0}
max_config = {"a": bin_n * bin_width}

bins = ParamBins(
min_config,
max_config,
params_estimator,
num_bins=bin_n,
log_bins=False,
start_bin_size=bin_s,
empty_bin_tolerance=bin_t,
)
return bins, bin_width


@pytest.mark.parametrize("search_strategy", methods)
def test_multi_objective_search(search_strategy, num_samples=5):
bins, _ = (
param_bins(10, 2, 1)
if search_strategy == "stratified_random_search"
else (None, None)
)

results = multi_objective_search(
objective=objective,
search_strategy=search_strategy,
search_space=search_space,
objective_kwargs={},
num_samples=num_samples,
param_bins=bins,
)

assert all(
@@ -46,3 +74,35 @@ def test_multi_objective_search(search_strategy, num_samples=5):
hp_name: (hp.upper - hp.lower) // 2 for hp_name, hp in search_space.items()
}
assert results["configs"][2] == mid_point


bin_tolerance = [0, 1]
bin_size = [1, 2]
num_bins = [3, 10]


@pytest.mark.parametrize("bin_t", bin_tolerance)
@pytest.mark.parametrize("bin_s", bin_size)
@pytest.mark.parametrize("bin_n", num_bins)
def test_param_bins(bin_t, bin_s, bin_n):
bins, bin_width = param_bins(bin_n, bin_s, bin_t)

# fill up to bin_n - 1 bins
for j in range(bin_s):
for i in range(bin_n - bin_t - 1):
assert bins.put_in_bin({"a": 1 + i * bin_width})

# fill the last one unless it'd be filled fully (leave 1 not full)
if j < bin_s - 1:
assert bins.put_in_bin({"a": 1 + (bin_n - bin_t - 1) * bin_width})

assert bins.current_bin_length == bin_s

# last bin is not filled fully -> this should be false
assert not bins.put_in_bin({"a": 3})
assert bins.current_bin_length == bin_s

# last bin is filled fully -> this should be true
assert bins.put_in_bin({"a": 2 + bin_width * (bin_n - bin_t - 1)})
assert bins.put_in_bin({"a": 3})
assert bins.current_bin_length == bin_s + 1
1 change: 1 addition & 0 deletions test/test_search_sub_network.py
@@ -78,4 +78,5 @@ def test_checkpoints(tmp_path, checkpoint_dir):
for checkpoint_dir in checkpoint_dirs:
assert set(os.listdir(out_dir / checkpoint_dir)) == {
"lit_model.pth",
"model_config.yaml",
}
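
The checkpoint layout asserted above pairs with the loading logic in whittle/evaluate_network.py further down: a sub-network checkpoint may carry only a sub_network_config plus a parent_dir that points at the super-network weights. A hedged sketch of such a minimal lit_model.pth (the key names are taken from the loading code; paths and config values are illustrative):

import torch

# Minimal sub-network checkpoint: no weights of its own, just a pointer to the
# parent super-network plus the sub-network architecture to select.
torch.save(
    {
        "sub_network_config": {"embed_dim": 256, "mlp_ratio": 2.5, "num_heads": 8, "depth": 6},
        "parent_dir": "out/pretrained/super_network",
    },
    "out/search/sub_network_0/lit_model.pth",
)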
14 changes: 14 additions & 0 deletions whittle/args.py
@@ -15,3 +15,17 @@ class SearchArgs:
"""Multi-objective search strategy"""
iterations: int = 100
"""Number of iterations for the multi-objective search"""


@dataclass
class ParamBinArgs:
Collaborator: Shouldn't that be part of SearchArgs?

Collaborator Author: I can imagine that you'll use binning outside of search (e.g. for pre-selection of networks to evaluate). But I can add it to SearchArgs.

"""parameter bin-related arguments - to limit what networks are sampled"""

"""Number of parameter bins to use"""
num_bins: int = 20
"""Whether to use log spaced bins"""
log_bins: bool = False
"""Starting size of the bins (how many configs must be in each bin until the total limit is increased)"""
start_bin_size: int = 1
"""The total limit will be increased even if K bins are not full yet (some param counts may have only few nets)"""
empty_bin_tolerance: int = 4
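
A minimal sketch of how these arguments could be wired into the ParamBins sampler used for stratified search (the constructor signature mirrors the one exercised in test_search.py above; min_config, max_config and params_estimator are placeholders the caller supplies):

from whittle.args import ParamBinArgs
from whittle.sampling.param_bins import ParamBins


def build_param_bins(args: ParamBinArgs, min_config, max_config, params_estimator) -> ParamBins:
    # Bins cover the parameter-count range between the smallest and largest sub-network.
    return ParamBins(
        min_config,
        max_config,
        params_estimator,
        num_bins=args.num_bins,
        log_bins=args.log_bins,
        start_bin_size=args.start_bin_size,
        empty_bin_tolerance=args.empty_bin_tolerance,
    )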
106 changes: 106 additions & 0 deletions whittle/evaluate_network.py
@@ -0,0 +1,106 @@
from __future__ import annotations

import json
from pathlib import Path
from typing import Any

from litgpt import Config
from litgpt.utils import lazy_load

from whittle.eval.utils import convert_and_evaluate
from whittle.metrics import compute_flops, compute_latency, compute_parameters
from whittle.models.gpt import GPT


def setup(
checkpoint_dir: Path,
out_dir: Path | None = None,
tasks: str | None = None,
seed: int = 1337,
num_fewshot: int | None = None,
batch_size: int | str = 1,
latency_batch_size: int = 8,
device: str | None = None,
limit: float | None = None,
measure_flops: bool = False,
measure_latency: bool = False,
) -> None:
"""
Evaluate a model with the LM Evaluation Harness. Compute the latency of a PyTorch model for inference, and FLOPs.

Arguments:
checkpoint_dir: The path to the model's checkpoint directory to load for evaluation.
out_dir: Directory in which to save evaluation results. If not provided, defaults to `checkpoint_dir/eval`.
tasks: Task names to evaluate. Example: "hellaswag,mmlu"
seed: The random seed to use for reproducibility.
num_fewshot: Number of examples in few-shot context.
batch_size: Batch size as a positive integer (default: 1), "auto", or 'auto:N',
where e.g. 'auto:4' recomputes the batch size 4 times.
latency_batch_size: Batch size for latency computation.
device: Device to use for evaluation, for example, "cuda" or "cuda:0".
limit: Limit on number of examples per task.
measure_flops: Whether to compute FLOPs. Default is False.
measure_latency: Whether to compute latency. Default is False.
"""
if out_dir is None:
out_dir = checkpoint_dir / "eval"

metrics_path = out_dir / "metrics.json"

metrics_path.parent.mkdir(parents=True, exist_ok=True)

# sub-network saved as a config instead of the extracted lit_model.pth (to save memory)
sub_network_config: dict[str, Any] | None = None
ckp = lazy_load(checkpoint_dir / "lit_model.pth")

# sub-network config loading (contains the config and checkpoint path of the parent)
sub_network_config = ckp.get("sub_network_config", None)
parent_dir = ckp.get("parent_dir", None)
if parent_dir is not None:
checkpoint_dir = Path(parent_dir)
ckp = lazy_load(checkpoint_dir / "lit_model.pth")

config = Config.from_file(checkpoint_dir / "model_config.yaml")
config.fix_head_size = True
config.model_type = "gpt"
config.tie_embeddings = False

model = GPT(config)
model.name_or_path = checkpoint_dir # WhittleLM loads AutoTokenizer inside

model.load_state_dict(ckp["model"] if "model" in ckp else ckp)
del ckp

# if the checkpoint was a sub-network, set it at this point
if sub_network_config is not None:
model.select_sub_network(sub_network_config)

# compute metrics
metrics = {}
metrics["parameters"] = compute_parameters(model)
if measure_flops:
metrics["flops"] = compute_flops(model, previous_device=device)
if measure_latency:
metrics["latency"] = compute_latency(
model, batch_size=latency_batch_size, device=device
)
metrics_path.write_text(json.dumps(metrics, indent=2))

# downstream task evaluation
model.to(device)
convert_and_evaluate(
model,
out_dir=out_dir,
tasks=tasks,
seed=seed,
num_fewshot=num_fewshot,
batch_size=batch_size,
device=device,
limit=limit,
)


if __name__ == "__main__":
from jsonargparse import CLI

CLI(setup)
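
Since setup() is wrapped by jsonargparse's CLI, its keyword arguments map directly to command-line flags, and it can equally be called from Python. A hypothetical invocation (paths, tasks and device are placeholders, not values from this PR):

from pathlib import Path

from whittle.evaluate_network import setup

# Evaluate a (sub-)network checkpoint on two harness tasks and also record
# parameter count, FLOPs and latency in metrics.json.
setup(
    checkpoint_dir=Path("out/search/sub_network_0"),
    tasks="hellaswag,mmlu",
    batch_size=4,
    device="cuda:0",
    measure_flops=True,
    measure_latency=True,
)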
6 changes: 6 additions & 0 deletions whittle/metrics/flops.py
@@ -22,6 +22,7 @@ def compute_flops(
batch_size: int = 1,
sequence_length: int = 512,
metric: Literal["flops", "macs"] = "flops",
previous_device: str | None = None,
) -> float:
"""
Estimates the number of floating-point operations (FLOPs) or multiply-accumulate operations (MACs) for a GPT model.
@@ -33,6 +34,7 @@
batch_size: The batch size for the input tensor. Defaults to 1.
sequence_length: The sequence length for the input tensor. Defaults to 512.
metric: The metric to return. Either "flops" for floating-point operations or "macs" for multiply-accumulate operations. Defaults to "flops".
previous_device: The device to cast to after profiling. If None, the device is not changed. Defaults to None.

Returns:
The estimated number of floating-point operations (FLOPs) or multiply-accumulate operations (MACs) for the model's forward pass, depending on the specified metric.
@@ -43,6 +45,7 @@
)

model.eval()
model.to("cpu")

os.environ["DS_ACCELERATOR"] = "CPU"
deepspeed.accelerator.set_accelerator(CPU_Accelerator())
@@ -56,6 +59,9 @@
as_string=False,
)

if previous_device is not None:
model.to(previous_device)

if metric == "flops":
return flops
else:
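
A short usage sketch for the new previous_device argument: compute_flops always profiles on the CPU (via DeepSpeed's CPU accelerator), so the argument is used to move the model back afterwards. Model construction below is illustrative:

from litgpt import Config

from whittle.metrics import compute_flops
from whittle.models.gpt import GPT

config = Config.from_name("micro-llama-300M")
config.fix_head_size = True
model = GPT(config).to("cuda:0")

# Profiling moves the model to CPU; previous_device restores it to "cuda:0" afterwards.
flops = compute_flops(model, batch_size=1, sequence_length=512, previous_device="cuda:0")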
13 changes: 9 additions & 4 deletions whittle/metrics/latency.py
@@ -41,6 +41,7 @@ def compute_latency(
use_cuda: bool = False,
batch_size: int = 8,
n_samples: int = 10,
device: str | None = None,
) -> float:
"""
Profiles the latency of a PyTorch model for inference.
@@ -53,16 +54,21 @@
use_cuda (bool, optional): If True and CUDA is available, the model will be moved to the GPU for profiling. Defaults to False.
batch_size (int, optional): The batch size for the input tensor. Defaults to 8.
n_samples (int, optional): The number of samples to profile after the warm-up phase. Defaults to 10.
device (Optional[str], optional): The device to use for profiling. If None, the device is inferred based on use_cuda. Defaults to None.

Returns:
float: The average inference time per sample in milliseconds.
"""
input_tensor = torch.randint(
0, model.config.padded_vocab_size, (batch_size, model.max_seq_length)
)
if use_cuda and torch.cuda.is_available():
model = model.cuda()
input_tensor = input_tensor.cuda()
if device is None:
if use_cuda and torch.cuda.is_available():
model = model.cuda()
input_tensor = input_tensor.cuda()
else:
model = model.to(device)
input_tensor = input_tensor.to(device)

# Use PyTorch profiler to record compute_latency
model.eval()
@@ -84,5 +90,4 @@

# Convert time to milliseconds
total_time_ms = (cpu_time_us + cuda_time_us) / 1000
model = model.cpu()
return total_time_ms / n_samples
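
With the new device argument the caller can pin latency profiling to a specific device instead of relying on use_cuda, and the model is no longer moved back to the CPU at the end. A small sketch, reusing the model from the previous example:

from whittle.metrics import compute_latency

# If device is None, the old use_cuda behaviour applies; otherwise both the model
# and the random input batch are moved to the requested device before profiling.
latency_ms = compute_latency(model, batch_size=8, n_samples=10, device="cuda:0")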
3 changes: 2 additions & 1 deletion whittle/models/gpt/model.py
@@ -266,9 +266,10 @@ def select_sub_network(self, config: dict[str, Any]) -> None:
"""
self.set_sub_network(
config["embed_dim"],
config["mlp_ratio"] * config["embed_dim"],
int(config["mlp_ratio"] * config["embed_dim"]),
config["num_heads"],
config["depth"],
sub_network_head_size=config.get("head_size", None),
)

def reset_super_network(self):
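
For reference, a config of the shape consumed by select_sub_network, assuming model is a whittle GPT super-network; with this change mlp_ratio may be a float (the intermediate size is cast with int(...)) and head_size is optional. The values below are illustrative only:

sub_network_config = {
    "embed_dim": 256,
    "mlp_ratio": 2.5,  # intermediate size becomes int(2.5 * 256) = 640
    "num_heads": 8,
    "depth": 6,
    "head_size": 32,  # optional; select_sub_network passes None if the key is absent
}
model.select_sub_network(sub_network_config)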