diff --git a/test/test_evaluate.py b/test/test_evaluate.py new file mode 100644 index 0000000..ac58413 --- /dev/null +++ b/test/test_evaluate.py @@ -0,0 +1,106 @@ +from __future__ import annotations + +import json +import pathlib +import shutil + +import pytest +import torch +from litgpt.scripts.download import download_from_hub +from litgpt.utils import ( + copy_config_files as copy_config_files_func, +) + +from whittle import evaluate_network +from whittle.models.gpt import GPT + + +@pytest.fixture(scope="session") +def checkpoint_dir(tmp_path_factory): + checkpoint_dir = tmp_path_factory.getbasetemp() + download_from_hub(repo_id="EleutherAI/pythia-14m", checkpoint_dir=checkpoint_dir) + return pathlib.Path(checkpoint_dir) / "EleutherAI" / "pythia-14m" + + +def convert_and_evaluate_mock( + model, out_dir, tasks, seed, num_fewshot, batch_size, device, limit, **kwargs +): + assert tasks == "arc_easy" + assert seed == 42 + assert num_fewshot == 1 + assert batch_size == 1 + assert device == "cpu" + assert limit == 1 + + assert isinstance(model, GPT) + + +def setup_checkpoint_dir(checkpoint_dir, sub_network_dir, checkpoint_mode): + # test all supported checkpoint formats + if checkpoint_mode == "litgpt": + copy_config_files_func(checkpoint_dir, sub_network_dir) + shutil.copy(checkpoint_dir / "lit_model.pth", sub_network_dir / "lit_model.pth") + + elif checkpoint_mode == "whittle": + copy_config_files_func(checkpoint_dir, sub_network_dir) + ckp = torch.load(checkpoint_dir / "lit_model.pth") + ckp = {"model": ckp, "parent_dir": checkpoint_dir} + torch.save(ckp, sub_network_dir / "lit_model.pth") + + elif checkpoint_mode == "whittle-minimalistic": + ckp = torch.load(checkpoint_dir / "lit_model.pth") + ckp = {"model": ckp, "parent_dir": checkpoint_dir} + torch.save(ckp, sub_network_dir / "lit_model.pth") + shutil.copy( + checkpoint_dir / "model_config.yaml", sub_network_dir / "model_config.yaml" + ) + + elif checkpoint_mode == "whittle-sub-network": + sub_network_config = { + "embed_dim": 2, + "mlp_ratio": 1.5, + "num_heads": 1, + "depth": 1, + } + torch.save( + {"sub_network_config": sub_network_config, "parent_dir": checkpoint_dir}, + sub_network_dir / "lit_model.pth", + ) + + +@pytest.mark.parametrize("measure_latency", [True, False]) +@pytest.mark.parametrize("measure_flops", [True, False]) +@pytest.mark.parametrize( + "checkpoint_mode", + ["litgpt", "whittle", "whittle-minimalistic", "whittle-sub-network"], +) +def test_evaluate(checkpoint_dir, checkpoint_mode, measure_flops, measure_latency): + sub_network_dir = pathlib.Path(checkpoint_dir) / "sub_network" + sub_network_dir.mkdir(parents=True, exist_ok=True) + + setup_checkpoint_dir(checkpoint_dir, sub_network_dir, checkpoint_mode) + + evaluate_network.convert_and_evaluate = convert_and_evaluate_mock + + evaluate_network.setup( + sub_network_dir, + measure_latency=measure_latency, + measure_flops=measure_flops, + tasks="arc_easy", + seed=42, + num_fewshot=1, + batch_size=1, + device="cpu", + limit=1, + ) + + metrics_path = sub_network_dir / "eval" / "metrics.json" + with open(metrics_path) as f: + metrics = json.load(f) + + assert "parameters" in metrics + assert "latency" in metrics if measure_latency else "latency" not in metrics + assert "flops" in metrics if measure_flops else "flops" not in metrics + + for v in metrics.values(): + assert isinstance(v, (int, float)) diff --git a/test/test_extract.py b/test/test_extract.py index cb3badb..7003f2d 100644 --- a/test/test_extract.py +++ b/test/test_extract.py @@ -1,7 +1,10 @@ from __future__ 
import annotations +import pathlib + import torch from litgpt.config import Config +from litgpt.utils import save_config from whittle.models.gpt import GPT from whittle.models.gpt.extract import extract_current_sub_network, extract_sub_network @@ -87,6 +90,47 @@ def test_extract_sub_network_llamamlp() -> None: ) +def test_save_config() -> None: + config = Config.from_name("micro-llama-300M") + config.fix_head_size = False + + super_network = GPT(config) + + # set norm weights to random + # the default is unit/zero vector which does not test the extract + make_norm_weights_random(super_network.transformer.ln_f) + for i in range(config.n_layer): + block = super_network.transformer.h[i] + make_norm_weights_random(block.norm_1) + make_norm_weights_random(block.post_attention_norm) + make_norm_weights_random(block.norm_2) + make_norm_weights_random(block.post_mlp_norm) + + # simulate a smaller network + n_embd = 128 + intermediate_size = 1024 + n_layer = 6 + n_head = 12 + + super_network.eval() + super_network.set_sub_network( + sub_network_n_embd=n_embd, + sub_network_intermediate_size=intermediate_size, + sub_network_num_heads=n_head, + sub_network_n_layers=n_layer, + ) + + # instantiate a new model + sub_network = extract_current_sub_network(super_network) + save_dir = pathlib.Path("microllama") + save_dir.mkdir(parents=True, exist_ok=True) + save_config(sub_network.config, save_dir) + + cfg = Config.from_file(save_dir / "model_config.yaml") + for key in cfg.__dict__: + assert cfg.__dict__[key] == sub_network.config.__dict__[key] + + def make_norm_weights_random(norm): if norm is None or isinstance(norm, torch.nn.Identity): return diff --git a/test/test_search.py b/test/test_search.py index a80662c..2727d7b 100644 --- a/test/test_search.py +++ b/test/test_search.py @@ -4,6 +4,7 @@ import pytest from syne_tune.config_space import randint +from whittle.sampling.param_bins import ParamBins from whittle.search import multi_objective_search from whittle.search.baselines import methods @@ -15,14 +16,40 @@ def objective(config, **kwargs): search_space = {"a": randint(0, 10), "b": randint(0, 100)} +def param_bins(bin_n, bin_s, bin_t, max_config=100): + def params_estimator(config): + return config["a"] + + min_config = {"a": 0} + max_config = {"a": max_config} + + bins = ParamBins( + min_config, + max_config, + params_estimator, + num_bins=bin_n, + log_bins=False, + start_bin_size=bin_s, + empty_bin_tolerance=bin_t, + ) + return bins + + @pytest.mark.parametrize("search_strategy", methods) def test_multi_objective_search(search_strategy, num_samples=5): + bins = ( + param_bins(10, 2, 1, max_config=10) + if search_strategy == "stratified_random_search" + else (None, None) + ) + results = multi_objective_search( objective=objective, search_strategy=search_strategy, search_space=search_space, objective_kwargs={}, num_samples=num_samples, + param_bins=bins, ) assert all( @@ -46,3 +73,36 @@ def test_multi_objective_search(search_strategy, num_samples=5): hp_name: (hp.upper - hp.lower) // 2 for hp_name, hp in search_space.items() } assert results["configs"][2] == mid_point + + +bin_tolerance = [0, 1] +bin_size = [1, 2] +num_bins = [3, 10] + + +@pytest.mark.parametrize("bin_t", bin_tolerance) +@pytest.mark.parametrize("bin_s", bin_size) +@pytest.mark.parametrize("bin_n", num_bins) +def test_param_bins(bin_t, bin_s, bin_n): + bin_width = 10 + bins = param_bins(bin_n, bin_s, bin_t, max_config=bin_n * bin_width) + + # fill up to bin_n - 1 bins + for j in range(bin_s): + for i in range(bin_n - bin_t - 1): + 
assert bins.put_in_bin({"a": 1 + i * bin_width}) + + # fill the last one unless it'd be filled fully (leave 1 not full) + if j < bin_s - 1: + assert bins.put_in_bin({"a": 1 + (bin_n - bin_t - 1) * bin_width}) + + assert bins.current_bin_length == bin_s + + # last bin is not filled fully -> this should be false + assert not bins.put_in_bin({"a": 3}) + assert bins.current_bin_length == bin_s + + # last bin is filled fully -> this should be true + assert bins.put_in_bin({"a": 2 + bin_width * (bin_n - bin_t - 1)}) + assert bins.put_in_bin({"a": 3}) + assert bins.current_bin_length == bin_s + 1 diff --git a/test/test_search_sub_network.py b/test/test_search_sub_network.py index f06c711..38d1a49 100644 --- a/test/test_search_sub_network.py +++ b/test/test_search_sub_network.py @@ -12,6 +12,7 @@ from litgpt.args import EvalArgs from litgpt.config import Config from litgpt.scripts.download import download_from_hub +from litgpt.utils import lazy_load from torch.utils.data import DataLoader from whittle import search_sub_networks @@ -32,7 +33,7 @@ def test_objective(): sub_network_config = {"embed_dim": 2, "mlp_ratio": 1, "num_heads": 1, "depth": 1} x, y = search_sub_networks._objective( config=sub_network_config, - fabric=Fabric(), + fabric=Fabric(devices=1), model=model, val_dataloader=dataloader, eval=EvalArgs(interval=1, max_iters=1, final_validation=False), @@ -50,7 +51,25 @@ def checkpoint_dir(tmp_path_factory): return pathlib.Path(checkpoint_dir) / "EleutherAI" / "pythia-14m" -def test_checkpoints(tmp_path, checkpoint_dir): +def get_checkpoint_contents(copy_config_files, save_checkpoints): + # model_config.yaml if in the parent super-net directory + if not save_checkpoints: + return {"lit_model.pth"} + # we copied config files, model_config.yaml and tokenizer configs are the most important + if copy_config_files: + return { + "lit_model.pth", + "model_config.yaml", + "tokenizer.json", + "tokenizer_config.json", + } + # other config files are in the parent super-net directory + return {"lit_model.pth", "model_config.yaml"} + + +@pytest.mark.parametrize("copy_config_files", [True, False]) +@pytest.mark.parametrize("save_checkpoints", [True, False]) +def test_checkpoints(tmp_path, checkpoint_dir, copy_config_files, save_checkpoints): dataset = torch.tensor([[0, 1, 2], [3, 4, 5], [0, 1, 2]]) dataloader = DataLoader(dataset, batch_size=3) search_sub_networks.get_dataloaders = Mock(return_value=(dataloader, dataloader)) @@ -64,6 +83,9 @@ def test_checkpoints(tmp_path, checkpoint_dir): out_dir=out_dir, search=SearchArgs(iterations=3), eval=EvalArgs(interval=1, max_iters=1, final_validation=False), + save_checkpoints=save_checkpoints, + copy_config_files=copy_config_files, + verbose=False, ) out_dir_contents = set(os.listdir(out_dir)) @@ -76,6 +98,18 @@ def test_checkpoints(tmp_path, checkpoint_dir): assert checkpoint_dirs.issubset(out_dir_contents) assert all((out_dir / p).is_dir() for p in checkpoint_dirs) for checkpoint_dir in checkpoint_dirs: - assert set(os.listdir(out_dir / checkpoint_dir)) == { - "lit_model.pth", - } + # Check that the checkpoint directory contains the expected files + contents = get_checkpoint_contents(copy_config_files, save_checkpoints) + assert contents.issubset(set(os.listdir(out_dir / checkpoint_dir))) + + # Check the contents of lit_model.pth + checkpoint = lazy_load(out_dir / checkpoint_dir / "lit_model.pth") + if save_checkpoints: + assert "model" in checkpoint + if not copy_config_files: + assert "parent_dir" in checkpoint + else: + assert "sub_network_config" in 
checkpoint
+            assert "parent_dir" in checkpoint
+
+    assert os.path.exists(out_dir / "pareto_optimal_paths.json")
diff --git a/whittle/args.py b/whittle/args.py
index 3419de2..796310d 100644
--- a/whittle/args.py
+++ b/whittle/args.py
@@ -15,3 +15,17 @@ class SearchArgs:
     """Multi-objective search strategy"""
     iterations: int = 100
     """Number of iterations for the multi-objective search"""
+
+
+@dataclass
+class ParamBinArgs:
+    """Parameter bin-related arguments, used to limit which sub-networks are sampled"""
+
+    num_bins: int = 20
+    """Number of parameter bins to use"""
+    log_bins: bool = False
+    """Whether to use log-spaced bins"""
+    start_bin_size: int = 1
+    """Starting size of the bins (how many configs must fall into each bin before the per-bin limit is increased)"""
+    empty_bin_tolerance: int = 4
+    """Increase the per-bin limit even if this many bins are still not full (some parameter counts may have only a few networks)"""
diff --git a/whittle/evaluate_network.py b/whittle/evaluate_network.py
new file mode 100644
index 0000000..e3f528f
--- /dev/null
+++ b/whittle/evaluate_network.py
@@ -0,0 +1,110 @@
+from __future__ import annotations
+
+import json
+from pathlib import Path
+from typing import Any
+
+from litgpt import Config
+from litgpt.utils import lazy_load
+
+from whittle.eval.utils import convert_and_evaluate
+from whittle.metrics import compute_flops, compute_latency, compute_parameters
+from whittle.models.gpt import GPT
+
+
+def setup(
+    checkpoint_dir: Path,
+    out_dir: Path | None = None,
+    tasks: str | None = None,
+    seed: int = 1337,
+    num_fewshot: int | None = None,
+    batch_size: int | str = 1,
+    latency_batch_size: int = 8,
+    device: str | None = None,
+    limit: float | None = None,
+    measure_flops: bool = False,
+    measure_latency: bool = False,
+) -> None:
+    """
+    Evaluate a model with the LM Evaluation Harness and, optionally, measure its inference latency and FLOPs.
+
+    Arguments:
+        checkpoint_dir: The path to the model's checkpoint directory to load for evaluation.
+        out_dir: Directory in which to save evaluation results. If not provided, saves to `checkpoint_dir/eval` by default.
+        tasks: Task names to evaluate. Example: "hellaswag,mmlu"
+        seed: The random seed to use for reproducibility.
+        num_fewshot: Number of examples in few-shot context.
+        batch_size: Batch size configuration as positive integer value (default: 1),
+            "auto", in the format 'auto:N', where 'auto:4' recomputes the batch size 4 times.
+        latency_batch_size: Batch size for latency computation.
+        device: Device to use for evaluation, for example, "cuda" or "cuda:0".
+        limit: Limit on number of examples per task.
+        measure_flops: Whether to compute FLOPs. Default is False.
+        measure_latency: Whether to compute latency. Default is False.
+ """ + if out_dir is None: + out_dir = checkpoint_dir / "eval" + + metrics_path = out_dir / "metrics.json" + + metrics_path.parent.mkdir(parents=True, exist_ok=True) + + # sub-network saved as a config instead of the extracted lit_model.pth (to save memory) + sub_network_config: dict[str, Any] | None = None + ckp = lazy_load(checkpoint_dir / "lit_model.pth") + + # sub-network config loading (contains the config and checkpoint path of the parent) + sub_network_config = ckp.get("sub_network_config", None) + parent_dir = ckp.get("parent_dir", None) + # it's either a standalone litgpt model or a sub-network (depending on if there is also a parent_dir) + if "model" not in ckp: + # not None: sub-network, None: raw state dict + if parent_dir is not None: + checkpoint_dir = Path(parent_dir) + ckp = lazy_load(checkpoint_dir / "lit_model.pth") + + config = Config.from_file(checkpoint_dir / "model_config.yaml") + config.fix_head_size = True + config.model_type = "gpt" + config.tie_embeddings = False + + model = GPT(config) + # WhittleLM loads AutoTokenizer inside - either we copied it to checkpoint_dir, or it is referenced in parent_dir + model.name_or_path = checkpoint_dir if parent_dir is None else parent_dir + + model.load_state_dict(ckp["model"] if "model" in ckp else ckp) + del ckp + + # if the checkpoint was a sub-network, set it at this point + if sub_network_config is not None: + model.select_sub_network(sub_network_config) + + # compute metrics + metrics = {} + metrics["parameters"] = compute_parameters(model) + if measure_latency: + metrics["latency"] = compute_latency(model) + if measure_flops: + metrics["flops"] = compute_flops( + model, batch_size=latency_batch_size, previous_device=device + ) + metrics_path.write_text(json.dumps(metrics, indent=2)) + + # downstream task evaluation + model.to(device) + convert_and_evaluate( + model, + out_dir=out_dir, + tasks=tasks, + seed=seed, + num_fewshot=num_fewshot, + batch_size=batch_size, + device=device, + limit=limit, + ) + + +if __name__ == "__main__": + from jsonargparse import CLI + + CLI(setup) diff --git a/whittle/metrics/flops.py b/whittle/metrics/flops.py index 2b10ab0..4336fe9 100644 --- a/whittle/metrics/flops.py +++ b/whittle/metrics/flops.py @@ -22,6 +22,7 @@ def compute_flops( batch_size: int = 1, sequence_length: int = 512, metric: Literal["flops", "macs"] = "flops", + previous_device: str | None = None, ) -> float: """ Estimates the number of floating-point operations (FLOPs) or multiply-accumulate operations (MACs) for a GPT model. @@ -33,6 +34,7 @@ def compute_flops( batch_size: The batch size for the input tensor. Defaults to 1. sequence_length: The sequence length for the input tensor. Defaults to 512. metric: The metric to return. Either "flops" for floating-point operations or "macs" for multiply-accumulate operations. Defaults to "flops". + previous_device: The device to cast to after profiling. If None, the device is not changed. Defaults to None. Returns: The estimated number of floating-point operations (FLOPs) or multiply-accumulate operations (MACs) for the model's forward pass, depending on the specified metric. 
@@ -43,6 +45,7 @@ def compute_flops( ) model.eval() + model.to("cpu") os.environ["DS_ACCELERATOR"] = "CPU" deepspeed.accelerator.set_accelerator(CPU_Accelerator()) @@ -56,6 +59,9 @@ def compute_flops( as_string=False, ) + if previous_device is not None: + model.to(previous_device) + if metric == "flops": return flops else: diff --git a/whittle/metrics/latency.py b/whittle/metrics/latency.py index 123f16e..e22cb08 100644 --- a/whittle/metrics/latency.py +++ b/whittle/metrics/latency.py @@ -41,6 +41,7 @@ def compute_latency( use_cuda: bool = False, batch_size: int = 8, n_samples: int = 10, + device: str | None = None, ) -> float: """ Profiles the latency of a PyTorch model for inference. @@ -53,6 +54,7 @@ def compute_latency( use_cuda (bool, optional): If True and CUDA is available, the model will be moved to the GPU for profiling. Defaults to False. batch_size (int, optional): The batch size for the input tensor. Defaults to 8. n_samples (int, optional): The number of samples to profile after the warm-up phase. Defaults to 10. + device (Optional[str], optional): The device to use for profiling. If None, the device is inferred based on use_cuda. Defaults to None. Returns: float: The average inference time per sample in milliseconds. @@ -60,9 +62,13 @@ def compute_latency( input_tensor = torch.randint( 0, model.config.padded_vocab_size, (batch_size, model.max_seq_length) ) - if use_cuda and torch.cuda.is_available(): - model = model.cuda() - input_tensor = input_tensor.cuda() + if device is None: + if use_cuda and torch.cuda.is_available(): + model = model.cuda() + input_tensor = input_tensor.cuda() + else: + model = model.to(device) + input_tensor = input_tensor.to(device) # Use PyTorch profiler to record compute_latency model.eval() @@ -84,5 +90,4 @@ def compute_latency( # Convert time to milliseconds total_time_ms = (cpu_time_us + cuda_time_us) / 1000 - model = model.cpu() return total_time_ms / n_samples diff --git a/whittle/models/gpt/model.py b/whittle/models/gpt/model.py index be8e654..21ea0b5 100644 --- a/whittle/models/gpt/model.py +++ b/whittle/models/gpt/model.py @@ -266,9 +266,10 @@ def select_sub_network(self, config: dict[str, Any]) -> None: """ self.set_sub_network( config["embed_dim"], - config["mlp_ratio"] * config["embed_dim"], + int(config["mlp_ratio"] * config["embed_dim"]), config["num_heads"], config["depth"], + sub_network_head_size=config.get("head_size", None), ) def reset_super_network(self): diff --git a/whittle/sampling/param_bins.py b/whittle/sampling/param_bins.py new file mode 100644 index 0000000..4166844 --- /dev/null +++ b/whittle/sampling/param_bins.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +from collections.abc import Callable +from typing import Any + +import numpy as np + +from whittle.metrics.parameters import ( + compute_parameters, +) +from whittle.models.gpt import GPT + + +class ParamsEstimator: + def __init__(self, model: GPT): + self.model = model + + def get_params(self, config: dict[str, Any]) -> float: + self.model.select_sub_network(config) + params = compute_parameters(self.model) + self.model.reset_super_network() + return params + + def __call__(self, config: dict[str, Any]) -> float: + return self.get_params(config) + + +class ParamBins: + def __init__( + self, + min_config: dict[str, Any], + max_config: dict[str, Any], + params_func: Callable, + num_bins: int = 20, + log_bins: bool = False, + start_bin_size: int = 1, + empty_bin_tolerance: int = 4, + ): + self.params_func = params_func + self.min_params = 
self.get_params(min_config)
+        self.max_params = self.get_params(max_config)
+
+        # get evenly spaced / log-spaced bin boundaries between min_params and max_params
+        if log_bins:
+            self.values = np.logspace(
+                np.log10(self.min_params), np.log10(self.max_params), num=num_bins + 1
+            )
+        else:
+            self.values = np.linspace(self.min_params, self.max_params, num=num_bins + 1)
+
+        self.bins = [0 for _ in self.values[1:]]  # one bin for every lower bound
+        self.current_bin_length = start_bin_size
+        self.empty_bin_tolerance = empty_bin_tolerance
+
+    def get_params(self, config: dict[str, Any]) -> float:
+        return self.params_func(config)
+
+    def put_in_bin(self, config: dict[str, Any]) -> bool:
+        params = self.get_params(config)
+
+        found = False
+        placed = False
+        at_max_length = 0
+        for i, value in enumerate(self.values):
+            # find the first bin whose upper bound covers params
+            if not found and params <= value:
+                found = True
+                # place into a bin with space left
+                if self.bins[i - 1] < self.current_bin_length:
+                    self.bins[i - 1] += 1
+                    placed = True
+
+            # found a bin with space left, don't increase bin length
+            if self.bins[i - 1] == self.current_bin_length:
+                at_max_length += 1
+
+        # increase bin length if almost all bins are full
+        if (at_max_length + self.empty_bin_tolerance) >= len(self.bins):
+            self.current_bin_length += 1
+
+        return placed
diff --git a/whittle/sampling/random_sampler.py b/whittle/sampling/random_sampler.py
index b814cf0..c449ce5 100644
--- a/whittle/sampling/random_sampler.py
+++ b/whittle/sampling/random_sampler.py
@@ -1,11 +1,15 @@
 from __future__ import annotations
 
 import warnings
+from collections.abc import Callable
 from typing import Any
 
 import numpy as np
 from syne_tune.config_space import Categorical, Domain
 
+from whittle.args import ParamBinArgs
+from whittle.sampling.param_bins import ParamBins
+
 
 class RandomSampler:
     """
@@ -22,10 +26,10 @@ def __init__(self, config_space: dict, seed: int | None = None):
 
     def sample(self) -> dict[str, Any]:
         """
-        Gets the smallest sub-network configuration from the search space.
+        Gets a random sub-network configuration from the search space.
 
         Returns:
-            The smallest sub-network configuration.
+            A random sub-network configuration.
         """
         config = {}
         for hp_name, hparam in self.config_space.items():
@@ -77,3 +81,51 @@ def get_largest_sub_network(self) -> dict[str, Any]:
             else:
                 config[k] = v.upper
         return config
+
+
+class StratifiedRandomSampler(RandomSampler):
+    """
+    StratifiedRandomSampler samples configurations from a given search space using a random state.
+    It maintains a set of bins to ensure that the configurations are sampled uniformly based on their parameter count.
+
+    Args:
+        config_space: The search space from which to sample.
+        params_estimator: Callable that estimates the parameter count of a given configuration.
+        seed: Seed for the random number generator. Defaults to None.
+        param_bins: The parameter bins that limit the sub-network params in the search.
+    """
+
+    def __init__(
+        self,
+        config_space: dict,
+        params_estimator: Callable,
+        seed: int | None = None,
+        param_bins: ParamBinArgs | None = None,
+    ):
+        param_bins = param_bins if param_bins is not None else ParamBinArgs()
+        super().__init__(config_space, seed=seed)
+        self.param_bins = ParamBins(
+            self.get_smallest_sub_network(),
+            self.get_largest_sub_network(),
+            params_estimator,
+            num_bins=param_bins.num_bins,
+            log_bins=param_bins.log_bins,
+            start_bin_size=param_bins.start_bin_size,
+            empty_bin_tolerance=param_bins.empty_bin_tolerance,
+        )
+
+    def sample(self) -> dict[str, Any]:
+        """
+        Gets a random sub-network configuration whose parameter count fits into a bin that is not yet full.
+
+        Returns:
+            A random sub-network configuration that fits into a bin that is not yet full.
+        """
+        while True:
+            config = super().sample()
+
+            # find a bin for the config; if all matching bins are full, continue sampling
+            if self.param_bins.put_in_bin(config):
+                break
+
+        return config
diff --git a/whittle/search/baselines.py b/whittle/search/baselines.py
index d1d51f6..0bb6cdd 100644
--- a/whittle/search/baselines.py
+++ b/whittle/search/baselines.py
@@ -15,7 +15,9 @@
     LinearScalarizedScheduler,
 )
 
+from whittle.sampling.param_bins import ParamBins
 from whittle.search.local_search import LS
+from whittle.search.stratified_search import StratifiedRandomSearch
 
 
 def get_random(config_space):
@@ -68,6 +70,7 @@ class MethodArguments:
     metrics: list
     mode: list
     random_seed: int
+    param_bins: ParamBins | None = None
 
 
 def initial_design(config_space):
@@ -90,6 +93,7 @@ class Methods:
     RSBO = "rsbo"
     EHVI = "ehvi"
     MOASHA = "moasha"
+    SRS = "stratified_random_search"
 
 
 methods = {
@@ -100,6 +104,14 @@ class Methods:
         random_seed=method_arguments.random_seed,
         points_to_evaluate=initial_design(method_arguments.config_space),
     ),
+    Methods.SRS: lambda method_arguments: StratifiedRandomSearch(
+        config_space=method_arguments.config_space,
+        metric=method_arguments.metrics[0],
+        mode=method_arguments.mode[0],
+        random_seed=method_arguments.random_seed,
+        points_to_evaluate=initial_design(method_arguments.config_space),
+        param_bins=method_arguments.param_bins,
+    ),
     Methods.MOREA: lambda method_arguments: MOREA(
         config_space=method_arguments.config_space,
         metric=method_arguments.metrics,
diff --git a/whittle/search/search.py b/whittle/search/search.py
index 9514b0a..f91fed9 100644
--- a/whittle/search/search.py
+++ b/whittle/search/search.py
@@ -6,7 +6,9 @@
 
 import numpy as np
 from lightning.fabric.loggers import Logger
+from tqdm import tqdm
 
+from whittle.sampling.param_bins import ParamBins
 from whittle.search.ask_tell_scheduler import AskTellScheduler
 from whittle.search.baselines import MethodArguments, methods
 from whittle.search.multi_objective import get_pareto_optimal
@@ -20,6 +22,10 @@ def multi_objective_search(
     objective_kwargs: dict[str, Any] | None = None,
     logger: Logger | None = None,
     seed: int | None = None,
+    param_bins: ParamBins | None = None,
+    objective_1_name: str = "objective_1",
+    objective_2_name: str = "objective_2",
+    verbose: bool = True,
 ) -> dict[str, Any]:
     """
     Search for the Pareto-optimal sub-networks using the specified strategy.
@@ -37,7 +43,15 @@ def multi_objective_search(
             Defaults to None.
         seed: The random seed for reproducibility.
             Defaults to None.
-
+        param_bins: The parameter bins that limit the sub-network params in the search.
+            Configurations returned by ask() are rejected if they fall into a bin that is already full.
+            The per-bin size is increased once all bins are full.
+            Defaults to None.
+        objective_1_name: The name of the first objective.
+            Defaults to "objective_1".
+        objective_2_name: The name of the second objective.
+            Defaults to "objective_2".
+        verbose: Whether to show a tqdm progress bar during the search.
 
     Returns:
         The results of the search, including Pareto-optimal solutions.
@@ -54,17 +68,20 @@ def multi_objective_search( metrics=metrics, mode=["min", "min"], random_seed=seed, + param_bins=param_bins, ) ) scheduler = AskTellScheduler(base_scheduler=base_scheduler) costs = np.empty((num_samples, 2)) - runtime = [] - configs = [] + runtime: list[float] = [] + configs: list[dict[str, Any]] = [] start_time = time.time() - for i in range(num_samples): + + for i in tqdm(range(num_samples), disable=not verbose): trial_suggestion = scheduler.ask() + objective_1, objective_2 = objective( trial_suggestion.config, **(objective_kwargs or {}) ) @@ -80,12 +97,12 @@ def multi_objective_search( runtime.append(time.time() - start_time) - observation = dict( - iteration=i, - objective_1=float(objective_1), - objective_2=float(objective_2), - runtime=runtime[-1], - ) + observation = { + "iteration": i, + objective_1_name: float(objective_1), + objective_2_name: float(objective_2), + "runtime": runtime[-1], + } if logger is not None: logger.log_metrics(observation) diff --git a/whittle/search/stratified_search.py b/whittle/search/stratified_search.py new file mode 100644 index 0000000..c204a7f --- /dev/null +++ b/whittle/search/stratified_search.py @@ -0,0 +1,133 @@ +from __future__ import annotations + +from typing import Any + +from syne_tune.optimizer.schedulers import FIFOScheduler +from syne_tune.optimizer.schedulers.searchers.random_grid_searcher import RandomSearcher + +from whittle.sampling.param_bins import ParamBins + + +class StratifiedRandomSearch(FIFOScheduler): + """ + Stratified Random Search (SRS) is a search strategy that samples configurations + uniformly at random from the search space. It is a simple baseline for + hyperparameter optimization. + + Args: + config_space: The configuration space to sample from. + metric: The metric to optimize. + param_bins: The parameter bins that limit the sub-network params in the search. + The configs from ask() are rejected if they fit into a bin that is full. + The bin size is increased if all bins are full. + mode: The optimization mode for the metric. + Defaults to "min". + start_point: Optional. The starting point for the search. + Defaults to None. + random_seed: Optional. The random seed for reproducibility. + Defaults to None. + points_to_evaluate: Optional. The initial configurations to evaluate. + Defaults to None. + **kwargs: Additional arguments for the scheduler. + """ + + def __init__( + self, + config_space: dict[str, Any], + metric: list[str], + param_bins: ParamBins, + mode: list[str] | str = "min", + start_point: dict[str, Any] | None = None, + random_seed: int | None = None, + points_to_evaluate: list[dict] | None = None, + **kwargs: Any, + ): + super().__init__( + config_space=config_space, + metric=metric, + mode=mode, + searcher=StratifiedRandomSearcher( + config_space=config_space, + metric=metric, + start_point=start_point, + mode=mode, + random_seed=random_seed, + points_to_evaluate=points_to_evaluate, + param_bins=param_bins, + ), + random_seed=random_seed, + **kwargs, + ) + + +class StratifiedRandomSearcher(RandomSearcher): + """ + Searcher which randomly samples configurations to try next. If a configuration + gets in a full bin (we already sampled enough configurations with a similar parameter count), + it is rejected and a new configuration is sampled. + """ + + def __init__( + self, + config_space: dict[str, Any], + metric: list[str] | str, + param_bins: ParamBins, + sample_patience: int = 10000, + **kwargs, + ): + """ + Args: + config_space: The configuration space to sample from. 
+ metric: The metric to optimize. + param_bins: The parameter bins that limit the sub-network params in the search. + The configs from ask() are rejected if they fit into a bin that is full. + The bin size is increased if all bins are full. + sample_patience: The number of rejected samples to try before raising an error. + Defaults to 10000. + **kwargs: Additional arguments for the searcher. + """ + super().__init__( + config_space, + metric=metric, + **kwargs, + ) + self.param_bins = param_bins + self.sample_patience = sample_patience + + def _get_config(self, **kwargs) -> dict | None: + """Sample a new configuration at random. If it doesn't fit into bins of + already sampled configurations, continue sampling until a valid config is found. + + If ``allow_duplicates == False``, this is done without replacement, so + previously returned configs are not suggested again. + + :param trial_id: Optional. Used for ``debug_log`` + :return: New configuration, or None + """ + i = 0 + while True: + config = super()._get_config(**kwargs) + + # find a bin for the config, if not found, continue sampling + if self.param_bins.put_in_bin(config): + break + i += 1 + if i >= self.sample_patience: + raise ValueError( + f"Could not find a valid configuration after {self.sample_patience} samples. Try increasing the tolerance for parameter bins not filled to the max." + ) + + return config + + def clone_from_state(self, state: dict[str, Any]): + new_searcher = StratifiedRandomSearcher( + self.config_space, + metric=self._metric, + points_to_evaluate=[], + debug_log=self._debug_log, + allow_duplicates=self._allow_duplicates, + param_bins=self.param_bins, + ) + new_searcher._resource_attr = self._resource_attr + new_searcher._restore_from_state(state) + return new_searcher diff --git a/whittle/search_sub_networks.py b/whittle/search_sub_networks.py index 1379aa4..dce67dc 100644 --- a/whittle/search_sub_networks.py +++ b/whittle/search_sub_networks.py @@ -1,5 +1,6 @@ from __future__ import annotations +import json import os import time from pathlib import Path @@ -10,7 +11,8 @@ from lightning.fabric.strategies import FSDPStrategy from litgpt import Tokenizer from litgpt.args import EvalArgs, TrainArgs -from litgpt.data import Alpaca, DataModule +from litgpt.data import Alpaca, DataModule, TinyStories +from litgpt.finetune.lora import validate as finetune_validate from litgpt.model import Config from litgpt.pretrain import get_dataloaders, validate from litgpt.utils import ( @@ -18,21 +20,25 @@ check_nvlink_connectivity, check_valid_checkpoint_dir, choose_logger, + copy_config_files as copy_config_files_func, find_resume_path, get_default_supported_precision, init_out_dir, load_checkpoint, parse_devices, + save_config, ) from torch.utils.data import DataLoader -from whittle.args import SearchArgs -from whittle.metrics import compute_parameters +from whittle.args import ParamBinArgs, SearchArgs +from whittle.metrics import compute_flops, compute_latency, compute_parameters from whittle.models.gpt import GPT from whittle.models.gpt.blocks import Block from whittle.models.gpt.extract import extract_current_sub_network from whittle.pretrain_super_network import get_search_space +from whittle.sampling.param_bins import ParamBins, ParamsEstimator from whittle.search import multi_objective_search +from whittle.search.baselines import Methods def setup( @@ -54,6 +60,15 @@ def setup( logger_name: Literal["wandb", "tensorboard", "csv"] | None = "csv", seed: int | None = 1337, access_token: str | None = None, + param_bins: 
ParamBinArgs = ParamBinArgs(), + performance_metric: str | None = "val_loss", + efficiency_metric: str | None = "parameters", + log_objective_names: bool | None = True, + save_checkpoints: bool = True, + fine_tuned: bool = False, + copy_config_files: bool = True, + verbose: bool = True, + num_workers: int = 4, ) -> None: """ Multi-objective search to select Pareto optimal set of sub-networks from trained super-network. @@ -68,19 +83,55 @@ def setup( resume: Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume from the latest checkpoint in ``out_dir``. An error will be raised if no checkpoint is found. Passing ``'auto'`` will resume from the latest checkpoint but not error if no checkpoint exists. - data: Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. + data: Data-related arguments. If not provided, the default is ``litgpt.data.TinyStories`` or ``litgpt.data.Alpaca`` for fine-tuned models. train: Training-related arguments. See ``litgpt.args.TrainArgs`` for details. eval: Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details. search: Search-related arguments. See ``whittle.args.SearchArgs`` for details. logger_name: The name of the logger to send metrics to. seed: The random seed to use for reproducibility. access_token: Optional API token to access models with restrictions. + param_bins: The parameter bins that limit the sub-network params in the search. + performance_metric: The name of the first objective to optimize (possible - val_loss, perplexity). Defaults to "val_loss". + efficiency_metric: The name of the second objective to optimize (possible - parameters, latency, flops). Defaults to "parameters". + log_objective_names: Whether to log the names of the objectives in the logger, or log as objective_1 and objective_2. Defaults to True. + save_checkpoints: Whether to save checkpoints of the sub-networks, or config + path to super-network. Defaults to True. + If False, `lit_model.pth` will have the following format: + `{'sub_network_config': sub_network_config, 'parent_dir': checkpoint_dir}` + fine_tuned: Whether the model is fine-tuned. Defaults to False. + This flag determines the dataset to use if `data` is not provided. Additionally, it changes the validation function to use for evaluation. + fine_tuned=True: litgpt.finetune.lora.validate, fine_tuned=False: litgpt.pretrain.validate. + copy_config_files: Whether to copy the config files from the super-network to the sub-networks. Defaults to True. + If set to False, we save `parent_dir` to `lit_model.pth`. If save_checkpoints is False, this argument is ignored. + verbose: Whether to print verbose output. Defaults to True. + num_workers: Number of workers to use for data loading. Defaults to 4. 
""" + assert performance_metric in [ + "val_loss", + "perplexity", + ], f"Invalid objective_1: {performance_metric}, must be 'val_loss' or 'perplexity'" + assert efficiency_metric in [ + "parameters", + "latency", + "flops", + ], ( + f"Invalid objective_2: {efficiency_metric}, must be 'parameters', 'latency' or 'flops'" + ) + checkpoint_dir = auto_download_checkpoint( model_name=checkpoint_dir, access_token=access_token ) - data = Alpaca() if data is None else data + if data is None: + # import sys + # sys.path.append('../do-not-touch/compressing_llms') + # from datasets_custom.llamamini import LLaMaMini + # data = LLaMaMini() if fine_tuned else TinyStories() + data = ( + Alpaca(num_workers=num_workers) + if fine_tuned + else TinyStories(num_workers=num_workers) + ) + num_devices = int(parse_devices(devices)) out_dir = init_out_dir(out_dir) @@ -131,6 +182,14 @@ def setup( train, eval, search, + param_bins if search.search_strategy == Methods.SRS else None, + performance_metric, + efficiency_metric, + log_objective_names, + save_checkpoints, + fine_tuned, + copy_config_files, + verbose, ) @@ -143,13 +202,32 @@ def _objective( val_dataloader: DataLoader, eval: EvalArgs, verbose: bool | None = True, + objective_1: str = "val_loss", + objective_2: str = "parameters", + fine_tuned: bool = False, ) -> tuple[float, float]: model.select_sub_network(config) - val_loss = validate( - fabric, model, val_dataloader, max_iters=eval.max_iters, verbose=verbose - ) - num_params = compute_parameters(model) - return float(val_loss), num_params + + if fine_tuned: + val_loss = finetune_validate(fabric, model, val_dataloader, eval, verbose=verbose) + else: + val_loss = validate( + fabric, model, val_dataloader, max_iters=eval.max_iters, verbose=verbose + ) + + if objective_1 == "perplexity": + val_loss = torch.exp(val_loss) + + if objective_2 == "parameters": + obj_2 = compute_parameters(model) + elif objective_2 == "latency": + obj_2 = compute_latency(model, device=model.lm_head.weight.device) + elif objective_2 == "flops": + obj_2 = compute_flops(model, previous_device=model.lm_head.weight.device) + else: + raise ValueError(f"Invalid objective_2: {objective_2}") + + return float(val_loss), obj_2 def main( @@ -164,6 +242,14 @@ def main( train: TrainArgs, eval: EvalArgs, search: SearchArgs, + param_bins: ParamBinArgs | None = None, + performance_metric: str = "val_loss", + efficiency_metric: str = "parameters", + log_objective_names: bool = True, + save_checkpoints: bool = True, + fine_tuned: bool = False, + copy_config_files: bool = True, + verbose: bool = True, ) -> None: fabric.seed_everything(seed) @@ -172,6 +258,7 @@ def main( train_dataloader, val_dataloader = get_dataloaders( fabric, data, tokenizer, train, train.max_seq_length ) + train_dataloader, val_dataloader = fabric.setup_dataloaders( train_dataloader, val_dataloader ) @@ -192,7 +279,7 @@ def main( fabric.print(f"Resuming training from {resume}") fabric.load(resume, state) else: - load_checkpoint(fabric, state["model"], checkpoint_path) + load_checkpoint(fabric, model, checkpoint_path) train_time = time.perf_counter() @@ -203,6 +290,24 @@ def main( fabric.print("Start multi-objective search") + bins = None + if param_bins is not None: + from whittle.sampling.random_sampler import RandomSampler + + sampler = RandomSampler(search_space, seed=seed) + + # get bins limited by the smallest/largest config + params_estimator = ParamsEstimator(model) + bins = ParamBins( + sampler.get_smallest_sub_network(), + sampler.get_largest_sub_network(), + 
params_estimator, + num_bins=param_bins.num_bins, + log_bins=param_bins.log_bins, + start_bin_size=param_bins.start_bin_size, + ) + + # fabric.is_global_zero search_results = multi_objective_search( _objective, search_space, @@ -211,29 +316,60 @@ def main( "model": model, "val_dataloader": val_dataloader, "eval": eval, + "objective_1": performance_metric, + "objective_2": efficiency_metric, + "fine_tuned": fine_tuned, }, search_strategy=search.search_strategy, num_samples=search.iterations, seed=seed, logger=fabric.logger, + param_bins=bins, + objective_1_name=performance_metric if log_objective_names else "objective_1", + objective_2_name=efficiency_metric if log_objective_names else "objective_2", + verbose=verbose and fabric.is_global_zero, ) training_time = time.perf_counter() - train_time fabric.print(f"Total search time: {training_time:.02f}.") fabric.print( - f"Found {len(search_results['configs'])} Pareto optimal sub-networks. Save checkpoints to {out_dir}." + f"Found {len(search_results['configs'])} sub-networks ({sum(i for i in search_results['is_pareto_optimal'])} Pareto optimal). Save checkpoints to {out_dir}." ) + pareto_optimal_paths = [] for i, sub_network_dict in enumerate(search_results["configs"]): save_path = out_dir / f"sub_network_{i}" / "lit_model.pth" save_path.parent.mkdir(parents=True, exist_ok=True) - model.select_sub_network(sub_network_dict) - sub_network = extract_current_sub_network(model) - - model.reset_super_network() - - fabric.save(save_path, {"model": sub_network}) + if search_results["is_pareto_optimal"][i]: + pareto_optimal_paths.append(str(save_path.absolute())) + + # either save the extracted checkpoint, or the config + path to super-network + if save_checkpoints: + model.select_sub_network(sub_network_dict) + sub_network = extract_current_sub_network(model) + model.reset_super_network() + + # either save everything including config files, or only model_config.yaml and the weights + if copy_config_files: + copy_config_files_func(checkpoint_dir, save_path.parent) + fabric.save(save_path, {"model": sub_network}) + else: + fabric.save( + save_path, {"model": sub_network, "parent_dir": checkpoint_dir} + ) + # the new model_config.yaml is different from the original one, so we rewrite it + save_config(sub_network.config, save_path.parent) + else: + # minimalistic checkpoint - only sub-network config and path to super-network + fabric.save( + save_path, + {"sub_network_config": sub_network_dict, "parent_dir": checkpoint_dir}, + ) + + # save all paths to pareto optimal sub-networks + with open(out_dir / "pareto_optimal_paths.json", "w") as f: + json.dump(pareto_optimal_paths, f) if __name__ == "__main__":