diff --git a/plugins/attention-and-distributed-packing/.pylintrc b/plugins/attention-and-distributed-packing/.pylintrc index 31cb902c..abbb1c11 100644 --- a/plugins/attention-and-distributed-packing/.pylintrc +++ b/plugins/attention-and-distributed-packing/.pylintrc @@ -53,7 +53,7 @@ ignore=CVS,protobufs # format. Because '\\' represents the directory delimiter on Windows systems, # it can't be used as an escape character. # NOTE: do not lint code imported from unsloth -ignore-paths=.*fused_ops/unsloth_lora.*,.*kernels/unsloth* +ignore-paths=.*multipack_sampler.py # Files or directories matching the regular expression patterns are skipped. # The regex matches against base names, not paths. The default value ignores diff --git a/plugins/attention-and-distributed-packing/README.md b/plugins/attention-and-distributed-packing/README.md index cad6ec63..0ff6f2e4 100644 --- a/plugins/attention-and-distributed-packing/README.md +++ b/plugins/attention-and-distributed-packing/README.md @@ -3,19 +3,31 @@ This library contains plugins to accelerate finetuning with the following optimizations: 1. Padding-Free Flash Attention Computation +2. Multipack Distributed Sampling ## Plugins Plugin | Description | Depends | Loading | Augmentation | Callbacks --|--|--|--|--|-- -[padding_free](./src/fms_acceleration_ilab/framework_plugin_padding_free.py) | Padding-Free Flash Attention Computation | flash_attn | | ✅ | ✅ +[padding_free](./src/fms_acceleration_aadp/framework_plugin_padding_free.py) | Padding-Free Flash Attention Computation | flash_attn | | ✅ | +[multipack sampler](./src/fms_acceleration_aadp/framework_plugin_multipack.py) | Multipack Distributed Sampling | numba | | ✅ | ## Native Transformers Support from v4.44.0 Transformers natively supports padding-free from v4.44.0 [see here](https://github.com/huggingface/transformers/pull/31629). The padding-free plugin will use the transformers library if compatible, otherwise if `transformers < v4.44.0` the plugin will use an internal implementation instead. +## Running Benchmarks + +To reproduce the benchmarks, run the following commands: + +Reproduce [Padding Free on A100 80GB](scripts/benchmarks/refs_orca/a100_80gb_pf.csv) +`bash scripts/run_benchmarks.sh "1 2" "4 8" benchmark_outputs scenarios-orca.yaml "none"` + +Reproduce [MultiPack on A100 80GB](scripts/benchmarks/refs_orca/a100_80gb_mp.csv) +`bash scripts/run_benchmarks.sh "2 4 8" "16 32 64" benchmark_outputs scenarios-orca.yaml "padding-free"` + ## Known Issues ### Currently Only Supports Pre-Tokenized Dataset @@ -32,3 +44,9 @@ In the meantime, the plugin expects the user to provide a pretokenized dataset t - is tokenized - has template labels that are masked to exclude from loss computation - has eos token appended + +### Currently Only Supports Multipack with Padding-Free + +The multipack plugin currently also requires the padding-free plugin to work. +This may change in the future if there is demand for multipack to work standalone without padding free.
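As noted in the Known Issues section of the README above, the padding-free plugin expects a pre-tokenized dataset. A minimal sketch of what a single record could look like is shown below; the token ids and the eos id (assumed to be 2) are made-up values for illustration, but the structure follows the three bullets above: already tokenized, template labels masked to -100, eos appended.

```python
# Hypothetical pre-tokenized record (token ids are illustrative only).
# - "input_ids" already holds token ids, not raw text
# - prompt/template positions in "labels" are masked with -100 so they are
#   excluded from the loss computation
# - the eos token id (assumed to be 2 here) is appended to both fields
record = {
    "input_ids": [101, 7592, 2129, 2024, 2017, 102, 2204, 4283, 2],
    "labels":    [-100, -100, -100, -100, -100, -100, 2204, 4283, 2],
}
```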
+ diff --git a/plugins/attention-and-distributed-packing/configs/multipack.yaml b/plugins/attention-and-distributed-packing/configs/multipack.yaml new file mode 100644 index 00000000..6c4418c2 --- /dev/null +++ b/plugins/attention-and-distributed-packing/configs/multipack.yaml @@ -0,0 +1,11 @@ +# Configurations to accelerate data packing/padding in training +training: + + # dataloader configurations + dataloader: + + # multipack dataloader + multipack: + + # number of processes used to calculate dataset lengths + num_processes: 16 \ No newline at end of file diff --git a/plugins/attention-and-distributed-packing/configs/aadp.yaml b/plugins/attention-and-distributed-packing/configs/padding_free.yaml similarity index 100% rename from plugins/attention-and-distributed-packing/configs/aadp.yaml rename to plugins/attention-and-distributed-packing/configs/padding_free.yaml diff --git a/plugins/attention-and-distributed-packing/pyproject.toml b/plugins/attention-and-distributed-packing/pyproject.toml index 00f1a155..3675fa26 100644 --- a/plugins/attention-and-distributed-packing/pyproject.toml +++ b/plugins/attention-and-distributed-packing/pyproject.toml @@ -13,7 +13,7 @@ authors = [ license = {text = "Apache-2.0"} readme = "README.md" requires-python = "~=3.9" -keywords = ['fms-hf-tuning', 'acceleration', 'padding-free'] +keywords = ['fms-hf-tuning', 'acceleration', 'padding-free', 'multipack'] classifiers=[ "License :: OSI Approved :: Apache Software License", "Development Status :: 4 - Beta", @@ -23,6 +23,8 @@ classifiers=[ "Programming Language :: Python :: 3.11", ] +dependencies = ["numba"] + [tool.hatch.build.targets.wheel] only-include = ["src/fms_acceleration_aadp"] diff --git a/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/__init__.py b/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/__init__.py index 12e86c4a..ef335739 100644 --- a/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/__init__.py +++ b/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/__init__.py @@ -13,4 +13,5 @@ # limitations under the License. # Local +from .framework_plugin_multipack import MultipackDataloaderAccelerationPlugin from .framework_plugin_padding_free import PaddingFreeAccelerationPlugin diff --git a/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/aadp_utils.py b/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/aadp_utils.py index 330bf5eb..10d6e93a 100644 --- a/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/aadp_utils.py +++ b/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/aadp_utils.py @@ -12,9 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
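The `multipack.yaml` configuration above only carries the dataloader settings (`num_processes` used to compute dataset lengths); multipack is meant to be activated together with the padding-free config. Below is a hedged sketch of how the two configs might be merged, mirroring the pattern used in the tests added later in this diff; the relative config paths are assumptions for illustration.

```python
# Sketch: activate multipack together with padding-free by merging both plugin
# configs under the same "training" key (multipack currently requires
# padding-free, per the README).
import fms_acceleration_aadp  # noqa: F401  importing the package registers its plugins

from fms_acceleration.utils import instantiate_framework, read_configuration

pf_config = read_configuration("configs/padding_free.yaml")  # assumed path
mp_config = read_configuration("configs/multipack.yaml")     # assumed path
config = {"training": {**pf_config["training"], **mp_config["training"]}}

with instantiate_framework(config, require_packages_check=False) as framework:
    for _, plugin in framework.active_plugins:
        print(type(plugin).__name__)  # expect both plugins to be active
```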
+# Standard from dataclasses import dataclass import warnings + +# Third Party from transformers import DefaultDataCollator, default_data_collator +import numpy as np @dataclass @@ -52,3 +56,13 @@ def __call__(self, features, return_tensors=None): else: ret["labels"] += [-100] + feature["input_ids"][1:] return default_data_collator([ret], return_tensors) + + +def calculate_token_lengths(dataset, num_processes): + return np.array( + dataset.map( + lambda x: {"len": len(x["input_ids"])}, + num_proc=num_processes, + load_from_cache_file=True, + )["len"] + ) diff --git a/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/flash_attn.py b/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/flash_attn.py index 782145cc..60834b9e 100644 --- a/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/flash_attn.py +++ b/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/flash_attn.py @@ -12,21 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os -import inspect +# Standard from functools import partial -import torch +from typing import Optional +import inspect +import os +# Third Party # pylint: disable=no-name-in-module from transformers.utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal -from typing import Optional +import torch if is_flash_attn_2_available(): # pylint: disable=import-error - from flash_attn import ( - flash_attn_func, - flash_attn_varlen_func, - ) + # Third Party + from flash_attn import flash_attn_func, flash_attn_varlen_func _flash_supports_window_size = "window_size" in list( inspect.signature(flash_attn_func).parameters diff --git a/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/framework_plugin_multipack.py b/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/framework_plugin_multipack.py new file mode 100644 index 00000000..aa9134a6 --- /dev/null +++ b/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/framework_plugin_multipack.py @@ -0,0 +1,183 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
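The `calculate_token_lengths` helper added to `aadp_utils.py` above maps over the dataset once to record each example's `input_ids` length as a numpy array; these lengths later drive the multipack bin packing. A minimal usage sketch on made-up data (mirroring how the tests in this diff exercise it) follows.

```python
# Minimal sketch of calculate_token_lengths on a toy dataset.
from datasets import Dataset
import numpy as np

from fms_acceleration_aadp.aadp_utils import calculate_token_lengths

ds = Dataset.from_list(
    [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5, 6, 7, 8]}]
)
lengths = calculate_token_lengths(ds, num_processes=2)

assert isinstance(lengths, np.ndarray)
print(lengths)         # -> [3 5]
print(lengths.mean())  # -> 4.0, later used to size the multipack batch budget
```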
+ +# Standard +from types import MethodType +from typing import Dict, Tuple +import warnings + +# Third Party +from accelerate import Accelerator +from fms_acceleration import AccelerationPlugin +from peft import LoraConfig +from torch.utils.data import DataLoader +from transformers import TrainingArguments + +# from accelerate.data_loader import DataLoaderShard +import torch + + +class MultipackDataloaderAccelerationPlugin(AccelerationPlugin): + + require_packages = {"numba"} + + def __init__( + self, + configurations: Dict[str, Dict], + seed: int = 42, + ): + super().__init__(configurations) + + self.num_processes = self._check_config_and_maybe_check_values( + key="training.dataloader.multipack.num_processes", + ) + + # see about the collator + attention = self._check_config_and_maybe_check_values( + key="training.attention", + ) + + # internal flags + self._seed = seed + self._padding_free = False + self._pad_token_id = None + + if "padding_free" in attention: + # for padding free the multipack preparation will ignore the padding tokens + self._padding_free = True + else: + # NOTE: need to get this from somewhere + assert self._pad_token_id is not None, "need to get pad token id" + + @property + def requires_agumentation(self): + return True + + def augmentation( + self, + model, + train_args: TrainingArguments, + modifiable_args: Tuple[LoraConfig], + ): + + # guarded because multipack has numba dependencies + # Third Party + # pylint: disable=import-outside-toplevel + from fms_acceleration.accelerator_patcher import ( + AcceleratorPatcher, + AcceleratorPatcherComponent, + ) + + # Local + from .aadp_utils import ( # pylint: disable=import-outside-toplevel + calculate_token_lengths, + ) + from .multipack_sampler import ( # pylint: disable=import-outside-toplevel + MultipackDistributedBatchSampler, + ) + + rank, num_bins = 0, 1 + if torch.distributed.is_initialized(): + num_bins = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + else: + # NOTE: or should we do a silent fallback + raise AssertionError( + "Multipack dataloader only works for distributed training." + ) + + # some checks + def _prereq(dataloader: DataLoader): + return hasattr(dataloader, "dataset") + + def _build_multipack_dataloader( + dataloader: DataLoader, accelerator: Accelerator + ): + + # NOTE: for now we disable support for deepspeed, but can be added in + # future if needed + assert ( + not accelerator.state.deepspeed_plugin + ), "Currently, multipack not supported for deepspeed" + + # get the dataset + dataset = dataloader.dataset + if torch.distributed.get_rank() > 0: + warnings.warn( + "Waiting for main process to perform the mapping." + "If the dataset is large, some processes might time out," + "You may need to increase the timeout limit or the number " + f"of workers processing the dataset > {self.num_processes}." 
+ ) + torch.distributed.barrier() + + lengths = calculate_token_lengths(dataset, num_processes=self.num_processes) + + if torch.distributed.get_rank() == 0: + torch.distributed.barrier() + + self._max_number_tokens = ( + train_args.per_device_train_batch_size * lengths.mean() + ) + + # prepare the multipack distributed batch sampler + sampler = MultipackDistributedBatchSampler( + batch_max_length=self._max_number_tokens, + lengths=lengths, + num_replicas=num_bins, + rank=rank, + seed=self._seed, + padding=not self._padding_free, + ) + + # wanted to use this but its abit annoying, + # from accelerate.data_loader import DataLoaderShard + # - so will just patch for now, but lets have a better + # solution later + dataloader = DataLoader( + dataset, + batch_sampler=sampler, + num_workers=dataloader.num_workers, + collate_fn=dataloader.collate_fn, + ) + + # patch a set epoch function to delegate the call to the + # batch_sampler + def _set_epoch(self, epoch: int): + self.batch_sampler.set_epoch(epoch) + + dataloader.set_epoch = MethodType(_set_epoch, dataloader) + return dataloader + + AcceleratorPatcher.replace( + "multipack", + AcceleratorPatcherComponent.data_loader, + replacement_builder=_build_multipack_dataloader, + pre_requisite_check=_prereq, + skip_prepare=True, + ) + + # take a pointer to train args + self._train_args = train_args + return model, modifiable_args + + +# register +AccelerationPlugin.register_plugin( + MultipackDataloaderAccelerationPlugin, + configuration_and_paths=[ + "training.dataloader.multipack", # activate if multipack config + "training.attention", # currently we require multipack to work with padding free + ], +) diff --git a/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/framework_plugin_padding_free.py b/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/framework_plugin_padding_free.py index 4680313e..dc60e5aa 100644 --- a/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/framework_plugin_padding_free.py +++ b/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/framework_plugin_padding_free.py @@ -19,14 +19,8 @@ # Third Party from fms_acceleration import AccelerationPlugin from peft import LoraConfig -from transformers import ( - TrainingArguments, - DataCollatorForSeq2Seq, -) -from accelerate import Accelerator +from transformers import DataCollatorForSeq2Seq, TrainingArguments import torch -from types import MethodType -from torch.utils.data import DataLoader class PaddingFreeAccelerationPlugin(AccelerationPlugin): @@ -56,34 +50,64 @@ def augmentation( modifiable_args: Tuple[LoraConfig], ): # guarded + # Standard + from functools import partial # pylint: disable=import-outside-toplevel + + # Third Party + # pylint: disable=import-outside-toplevel + from fms_acceleration.accelerator_patcher import ( + AcceleratorPatcher, + AcceleratorPatcherComponent, + ) from fms_acceleration.model_patcher import ( # pylint: disable=import-outside-toplevel ModelPatcher, ModelPatcherRule, ModelPatcherTrigger, ) - from functools import partial # pylint: disable=import-outside-toplevel + + def _collator_check_seq2seq(collate_fn): + # "The padding-free plugin currently only works with a + # `DataCollatorForSeq2Seq` collate_fn, + # otherwise the collation can be unreliable" + return isinstance(collate_fn, DataCollatorForSeq2Seq) # This check is done here to only patch the attention forward # the PR was merged here # https://github.com/huggingface/transformers/pull/31629 + _native = False try: # if 
this is importable, it means the PR # has been merged, and there is no more need to # pylint: disable=import-outside-toplevel,no-name-in-module,unused-import - from transformers import ( - DataCollatorWithFlattening, - ) + # Third Party + from transformers import DataCollatorWithFlattening + + _native = True - # - if import successful this will print and return + except ImportError: + + # Otherwise, use the locally implemented DataCollatorWithFlattening + # pylint: disable=import-outside-toplevel + # Local + from .aadp_utils import DataCollatorWithFlattening + + # setup the collator + AcceleratorPatcher.replace( + "flattening-collator", + AcceleratorPatcherComponent.data_collator, + replacement=DataCollatorWithFlattening(), + pre_requisite_check=_collator_check_seq2seq, + ) + + if _native: + # - if natively supported, then no more need for patch the model + # - so print and return warnings.warn( "transformers version supports padding free natively in various models." ) return model, modifiable_args - except ImportError: - pass - # Otherwise patching is required: # 1. a custom forward has to be registered on the backbone # to intercept the position ids @@ -93,6 +117,7 @@ def _is_backbone(module: torch.nn.Module): # - patch backbone model_type = model.config.model_type # pylint: disable=import-outside-toplevel + # Local from .flash_attn import build_backbone_forward ModelPatcher.register( @@ -116,10 +141,12 @@ def _is_backbone(module: torch.nn.Module): # - this is an old version that does not have logic to handle the flattened batch # pylint: disable=import-outside-toplevel + # Third Party from transformers.modeling_flash_attention_utils import ( _flash_attention_forward, ) + # Local from .flash_attn import _flash_attention_forward_with_posids ModelPatcher.register( @@ -137,9 +164,10 @@ def _is_backbone(module: torch.nn.Module): # attached to the model classes # - for similar reasons as Case I, they need to be patched on the # FA2 modules - from .flash_attn import ( + # Local + from .flash_attn import ( # pylint: disable=import-outside-toplevel build_fa_forward, - ) # pylint: disable=import-outside-toplevel + ) def is_flash_attn_2(module): if module.__class__.__name__.endswith("FlashAttention2"): @@ -159,58 +187,6 @@ def is_flash_attn_2(module): return model, modifiable_args - def get_callbacks_and_ready_for_train( - self, model: torch.nn.Module = None, accelerator: Accelerator = None - ): - # patch the dataloader on the accelerator - self._patch_dataloader(accelerator) - return [] - - def _patch_dataloader( - self, - accelerator: Accelerator, - ): - """ - Hijacks the accelorator prepare inside `Trainer.train` - - If it is a single argument. 
it is assumed to be the prepare call on the dataloader - - we replace the collate function in the dataloader to flatten the batch into a long - sequence with special tokens to define the attention computation boundaries - """ - try: - # Check if transformers already supports a collator that flattens the batch - # pylint: disable=import-outside-toplevel,no-name-in-module - from transformers import ( - DataCollatorWithFlattening, - ) - except ImportError: - # Otherwise, use the locally implemented DataCollatorWithFlattening - # pylint: disable=import-outside-toplevel - from .aadp_utils import ( - DataCollatorWithFlattening, - ) - - # hijack the dataloader in accelerator.prepare to replace the collate_fn - _old_prepare = accelerator.prepare - - def prepare(self, *args, device_placement=None): - if len(args) > 1 or not isinstance(args[0], DataLoader): - return _old_prepare(*args, device_placement=device_placement) - dataloader = args[0] - - if not isinstance(dataloader.collate_fn, DataCollatorForSeq2Seq): - raise TypeError( - "The padding-free plugin currently only works with a \ - `DataCollatorForSeq2Seq` collate_fn, \ - otherwise the collation can be unreliable" - ) - - # Replace the collate_fn in dataloader - dataloader.collate_fn = DataCollatorWithFlattening() - - return _old_prepare(dataloader) - - accelerator.prepare = MethodType(prepare, accelerator) - # register AccelerationPlugin.register_plugin( diff --git a/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/multipack_sampler.py b/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/multipack_sampler.py new file mode 100644 index 00000000..481557e4 --- /dev/null +++ b/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/multipack_sampler.py @@ -0,0 +1,485 @@ +""" +MIT License + +Copyright (c) 2023 One + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
+taken from https://github.com/imoneoi/multipack_sampler with some modifications +taken from https://github.com/instructlab/training/blob/main/src/instructlab/training/multipack_sampler.py +""" + +# Standard +from typing import List, Optional +import os + +# Third Party +from torch.utils.data import DataLoader, Sampler +import numba +import numpy as np +import torch.distributed as dist + + +def guess_starting_avg_padding(base_avg, goal, num_gpus, grad_accum, sorted_lengths): + """ + Return a starting middle point for the binary search + (to find optimal addition to packing_max_batch_len + to account for padding) + + Uses the largest initial bucket to approximate an + upper-bound for average padding, should overshoot. + """ + addition = 0 + packing_max_batch_len = int( + (base_avg + addition) * ((goal / num_gpus) / grad_accum) + ) + + bucket_zero = [] + max = sorted_lengths[0] + sum = 0 + for length in sorted_lengths: + if sum + max <= packing_max_batch_len: + sum += max + bucket_zero.append(length) + else: + break + + total_pad = 0 + for length in bucket_zero: + total_pad += max - length + addition = round(total_pad / len(bucket_zero)) + return addition + + +def simulate_buckets( + base_avg, + goal, + num_gpus, + grad_accum, + pad_id, + max_batch_len, + lengths, + seed, + dataset, + addition, +): + """ + Given an addition to packing_max_batch_len, simulate the + packing to find the updated average effective batch size. + """ + packing_max_batch_len = int( + (base_avg + addition) * ((goal / num_gpus) / grad_accum) + ) + + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + + sampler = MultipackDistributedBatchSampler( + batch_max_length=packing_max_batch_len, + lengths=lengths, + num_replicas=world_size, + rank=rank, + seed=seed, + padding=True, + ) + # NOTE: removed the collate_fn here, as it should not be + # required just to take the length + simulation_loader = DataLoader( + dataset, + batch_sampler=sampler, + num_workers=8, + ) + + avg_ebs = len(dataset) / len(simulation_loader) + return avg_ebs + + +def find_padding_max_batch_len_addition( + base_avg, goal, dataset, num_gpus, grad_accum, pad_id, max_batch_len, seed +): + """ + Do a modified binary search to find optimal padding addition for + packing_maximum_batch_len. Starts with an upper-bound guess, and + increases upper-bound until guess overshoots. Then perform standard + binary search until within a threshold for average effective batch + size. 
+ """ + lengths = dataset.get_lengths() + sorted_lengths = list(lengths) + sorted_lengths.sort(reverse=True) + + # Use first default bucket avg padding as starting value for addition + addition = guess_starting_avg_padding( + base_avg, goal, num_gpus, grad_accum, sorted_lengths + ) + + # binary search correct addition value from starting value + first_over_hit = False + l = 0 + r = 2 * addition + while r - l > 1: + avg_ebs = simulate_buckets( + base_avg, + goal, + num_gpus, + grad_accum, + pad_id, + max_batch_len, + lengths, + seed, + dataset, + addition, + ) + + # check if simulation resulted in batch sizes close enough to goal and adjust if needed + if abs(avg_ebs - goal) <= max(10, round(goal * 0.02)): + break + + if avg_ebs > goal: + first_over_hit = True + r = addition + elif avg_ebs < goal: + if not first_over_hit: + # If the starting midpoint failed to overshoot, increase the bounds of the search + r = r * 2 + else: + l = addition + addition = l + ((r - l) // 2) + + return addition + + +def find_packing_max_batch_len_and_grad_accum( + num_gpus, + avg_sample_len, + effective_batch_size, + max_batch_len_per_gpu, + is_padding, + dataset, + pad_id, + seed, +): + """ + Calculate the minimum gradient accumulation steps required and the corresponding maximum batch length. + + This function determines the minimum number of gradient accumulation steps needed to process the + effective batch size within the constraints of the maximum batch length per GPU. It starts with + the assumption of a single step (no accumulation) and increases the number of steps until the + calculated batch length does not exceed the maximum allowed per GPU. The goal is to find the + lowest gradient accumulation that allows fitting the batch within GPU limits, ensuring efficient + utilization of computational resources. + + Parameters: + - num_gpus (int): The number of GPUs over which the batch is distributed. + - avg_sample_len (int): The average length of samples in the dataset, used to estimate batch length. + - effective_batch_size (int): The total batch size intended to be processed across all GPUs and + accumulation steps. + - max_batch_len_per_gpu (int): The maximum permissible number of tokens on each GPU to avoid memory overflow. + + Returns: + - Tuple[int, int]: A tuple where the first element is the maximum batch length that can be achieved + without exceeding the per-GPU limit, and the second element is the minimum number of gradient + accumulation steps required to maintain the effective batch size. + """ + + packing_max_batch_len = max_batch_len_per_gpu + 1 + grad_accum = 0 + while packing_max_batch_len > max_batch_len_per_gpu: + grad_accum += 1 + total_micro_batch = (effective_batch_size / grad_accum) / num_gpus + + # NOTE: remove this check for now + # if int(avg_sample_len * total_micro_batch) < dataset.get_lengths().max(): + # raise RuntimeError( + # f"Effective batch size is too low for multipack sampling, max sample length={dataset.get_lengths().max()} and min packing length={int(avg_sample_len * total_micro_batch)}. " + # "Switching to naive distributed sampling." 
+ # ) + if is_padding: + addition = find_padding_max_batch_len_addition( + avg_sample_len, + effective_batch_size, + dataset, + num_gpus, + grad_accum, + pad_id, + max_batch_len_per_gpu, + seed, + ) + else: + addition = 0 + packing_max_batch_len = int((avg_sample_len + addition) * total_micro_batch) + + return packing_max_batch_len, grad_accum + + +@numba.njit +def ffd_check(a: np.ndarray, c: int, n: int): + # First-fit-decreasing bin packing + # Check if a[] could fit in n bins with capacity c + # https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing + + a = np.sort(a)[::-1] + bins = np.full((n,), c, dtype=a.dtype) + for size in a: + not_found = True + for idx in range(n): + if bins[idx] >= size: + bins[idx] -= size + not_found = False + break + + if not_found: + return False + + return True + + +@numba.njit +def ffd_check_padding(a: np.ndarray, c: int, n: int): + # First-fit-decreasing bin packing + # Check if a[] could fit in n bins with capacity c + # https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing + + a = np.sort(a)[::-1] + bins_max_lengths = np.zeros( + (n,), dtype=a.dtype + ) # Track the maximum length in each bin + bins_num_samples = np.zeros( + (n,), dtype=np.int_ + ) # Track the number of samples in each bin + + for size in a: + not_found = True + for idx in range(n): + # Calculate the new capacity if size is added to the bin + new_capacity = max(bins_max_lengths[idx], size) * ( + bins_num_samples[idx] + 1 + ) + if new_capacity <= c: + bins_max_lengths[idx] = max(bins_max_lengths[idx], size) + bins_num_samples[idx] += 1 + not_found = False + break + + if not_found: + return False + + return True + + +@numba.njit +def ffd_with_result(a: np.ndarray, c: int, start_index: int): + # First-fit-decreasing bin packing (with result return) + + indices = np.argsort(a)[::-1] + a = a[indices] + + bins = [] + bins_result = [] + for a_id, size in enumerate(a): + add_new = True + for idx in range(len(bins)): + if bins[idx] >= size: + bins[idx] -= size + bins_result[idx].append(indices[a_id] + start_index) + add_new = False + break + + if add_new: + bins.append(c - size) + bins_result.append([indices[a_id] + start_index]) + + return bins_result + + +@numba.njit +def ffd_with_result_padding(a: np.ndarray, c: int, start_index: int): + # First-fit-decreasing bin packing (with result return) + + indices = np.argsort(a)[::-1] + a = a[indices] + + bins_max_lengths = [] # Track the maximum length in each bin + bins_num_samples = [] # Track the number of samples in each bin + bins_result = [] # Track the indices of the samples in each bin + + for a_id, size in enumerate(a): + add_new = True + for idx in range(len(bins_max_lengths)): + # Calculate the new capacity if size is added to the bin + new_capacity = max(bins_max_lengths[idx], size) * ( + bins_num_samples[idx] + 1 + ) + if new_capacity <= c: + bins_max_lengths[idx] = max(bins_max_lengths[idx], size) + bins_num_samples[idx] += 1 + bins_result[idx].append(indices[a_id] + start_index) + add_new = False + break + + if add_new: + bins_max_lengths.append(size) + bins_num_samples.append(1) + bins_result.append([indices[a_id] + start_index]) + + return bins_result + + +@numba.njit +def allocate( + lengths: np.ndarray, + lengths_cumsum: np.ndarray, + rank: int, + c: int, + n: int, + padding: bool = True, +): + # Dynamic batch allocator, similar to Multifit + # https://en.wikipedia.org/wiki/Multifit_algorithm + # ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len) + + s = 0 + start_index = 0 + result = [] + + while True: 
+ # binary search [l, r) + l = 1 + r = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right") + + while r - l > 1: + m = (l + r) // 2 + if padding: + check = ffd_check_padding(lengths[start_index : start_index + m], c, n) + else: + check = ffd_check(lengths[start_index : start_index + m], c, n) + if check: + l = m + else: + r = m + + # use length l + if padding: + batch = ffd_with_result_padding( + lengths[start_index : start_index + l], c, start_index + ) + else: + batch = ffd_with_result( + lengths[start_index : start_index + l], c, start_index + ) + assert len(batch) <= n + if len(batch) < n: + break + + start_index += l + s = lengths_cumsum[start_index - 1] + + # add local rank + result.append(batch[rank]) + + return result, s, len(result) * c * n + + +class MultipackDistributedBatchSampler(Sampler): + """Unpadded length sampling using Multipack. + Approximate (at most ~1.22x) the optimal solution of the identical-machines scheduling problem, which is NP-hard. + """ + + def __init__( + self, + batch_max_length: int, + lengths: List[int], + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + seed: int = 0, + padding: bool = True, + ): + # Get rank + if num_replicas is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + num_replicas = dist.get_world_size() + if rank is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + rank = dist.get_rank() + + self.num_replicas = num_replicas + self.rank = rank + self.seed = seed + + self.batch_max_length = batch_max_length + self.lengths = lengths + assert isinstance(self.lengths, np.ndarray) + + self.epoch = 0 + + # statistics + self.eff_total_used = 0 + self.eff_total_slots = 0 + self.padding = padding + + def set_epoch(self, epoch: int): + self.epoch = epoch + + def generate_batches(self, set_stats=False): + indices = np.random.default_rng(seed=self.seed + self.epoch).permutation( + len(self.lengths) + ) + + # remove indices where the entries are longer than batch max length + indices = indices[self.lengths[indices] <= self.batch_max_length] + if len(indices) < len(self.lengths): + print( + f"\033[33mDropping {len(self.lengths) - len(indices)} samples longer than batch_max_length. 
Ensure that the right max_batch_length is used during data processing.\033[0m" + ) + + lengths = self.lengths[indices] + lengths_cumsum = np.cumsum(lengths) + + batches, total_used, total_slots = allocate( + lengths=lengths, + lengths_cumsum=lengths_cumsum, + rank=self.rank, + c=self.batch_max_length, + n=self.num_replicas, + padding=self.padding, + ) + + batches = [indices[batch] for batch in batches] + + # statistics + if set_stats: + self.eff_total_used += total_used + self.eff_total_slots += total_slots + + return batches + + def __iter__(self): + batches = self.generate_batches(set_stats=True) + return iter(batches) + + def __len__(self): + return self.num_batches() + + def num_batches(self): + batches = self.generate_batches() + return len(batches) + + def efficiency(self): + return self.eff_total_used / self.eff_total_slots diff --git a/plugins/attention-and-distributed-packing/tests/test_aadp_plugin.py b/plugins/attention-and-distributed-packing/tests/test_aadp_plugin.py index ea38158b..397bf95f 100644 --- a/plugins/attention-and-distributed-packing/tests/test_aadp_plugin.py +++ b/plugins/attention-and-distributed-packing/tests/test_aadp_plugin.py @@ -12,20 +12,129 @@ # See the License for the specific language governing permissions and # limitations under the License. +# Standard import os -from fms_acceleration.utils import ( - instantiate_framework, - read_configuration, + +# Third Party +from datasets import Dataset # pylint: disable=import-error +from fms_acceleration.utils import instantiate_framework, read_configuration +import numpy as np + +# First Party +from fms_acceleration_aadp import ( + MultipackDataloaderAccelerationPlugin, + PaddingFreeAccelerationPlugin, ) -from fms_acceleration_aadp import PaddingFreeAccelerationPlugin +from fms_acceleration_aadp.aadp_utils import calculate_token_lengths +from fms_acceleration_aadp.multipack_sampler import MultipackDistributedBatchSampler # configuration DIRNAME = os.path.dirname(__file__) -CONFIG_PATH_ILAB = os.path.join(DIRNAME, "../configs/aadp.yaml") +CONFIG_PATH_PADDINGFREE = os.path.join(DIRNAME, "../configs/padding_free.yaml") +CONFIG_PATH_MULTIPACK = os.path.join(DIRNAME, "../configs/multipack.yaml") + def test_framework_installs_aadp_padding_free_plugin(): + """ + Test framework successfully installs paddingfree plugin + """ with instantiate_framework( - read_configuration(CONFIG_PATH_ILAB), require_packages_check=False + read_configuration(CONFIG_PATH_PADDINGFREE), require_packages_check=False ) as framework: for plugin in framework.active_plugins: assert isinstance(plugin[1], PaddingFreeAccelerationPlugin) + + +def test_framework_installs_aadp_multipack_and_paddingfree_plugins(): + """ + Test framework installs both multipack and paddingfree plugins + """ + pf_config = read_configuration(CONFIG_PATH_PADDINGFREE) + mp_config = read_configuration(CONFIG_PATH_MULTIPACK) + config = {"training": {**pf_config["training"], **mp_config["training"]}} + with instantiate_framework(config, require_packages_check=False) as framework: + assert len(framework.active_plugins) == 2 + for plugin in framework.active_plugins: + assert isinstance( + plugin[1], + (MultipackDataloaderAccelerationPlugin, PaddingFreeAccelerationPlugin), + ) + + +def test_multipack_sampler_assigns_balanced_tokens(): + """ + Ensure that the multipack sampler load balances the tokens amongst the GPUS + """ + num_gpus = 8 + batch_size_per_device = 32 + num_samples = 10000 + seed = 42 + num_processes = 4 + min_token = 0 + max_token = 1000 + min_seq_len = 256 + 
max_seq_len = 1024 + rng = np.random.default_rng(seed=seed) + + # 1. Build a test dataset + dataset = Dataset.from_list( + [ + { + "input_ids": rng.integers( + low=min_token, + high=max_token, + size=(rng.integers(min_seq_len, max_seq_len)), + ) + } + for _ in range(num_samples) + ] + ) + lengths = calculate_token_lengths(dataset, num_processes=num_processes) + + # 2. generate a multipack subset of indices + max_batch_len = batch_size_per_device * lengths.mean() + mean_tokens_per_rank_multipack = [] + for rank in range(num_gpus): + sampler = MultipackDistributedBatchSampler( + batch_max_length=max_batch_len, + lengths=lengths, + num_replicas=num_gpus, + rank=rank, + seed=seed, + padding=False, + ) + batches = sampler.generate_batches() + tokens_across_batches = [] + for batch in batches: + # count all the tokens in the batch + num_tokens_across_one_batch = sum(lengths[idx] for idx in batch) + tokens_across_batches.append(num_tokens_across_one_batch) + # take average number of tokens across the batches + average_tokens_across_batches = np.ceil(np.mean(tokens_across_batches)) + mean_tokens_per_rank_multipack.append(average_tokens_across_batches) + + # 3. generate a random sampled subset of indices + mean_tokens_per_rank_random = [] + perm_indices = rng.permutation(len(dataset)) + # bin indices to respective ranks + # this is a list of list where each item is a list of indices + # assigned to the respective rank + split_indices_to_ranks = np.array_split(perm_indices, num_gpus) + # bin indices in each rank to respective batches + # the result should be a List[List[List]] where + # dim = 0 corresponds to the number of ranks + # dim = 1 corresponds to the number of batches + # dim = 2 corresponds to the number of indices in a batch + split_indices_to_batches = [ + np.array_split(split, batch_size_per_device) for split in split_indices_to_ranks + ] + for indices_per_rank in split_indices_to_batches: + # count all the tokens in the batch + token_length_in_batch = [ + sum(lengths[idx] for idx in batch) for batch in indices_per_rank + ] + # take average number of tokens across the batches + mean_tokens_per_rank_random.append(np.ceil(np.mean(token_length_in_batch))) + + # expect std from multipack to be smaller + assert np.std(mean_tokens_per_rank_multipack) < np.std(mean_tokens_per_rank_random) diff --git a/plugins/attention-and-distributed-packing/tox.ini b/plugins/attention-and-distributed-packing/tox.ini index 7dfd370e..2eff9357 100644 --- a/plugins/attention-and-distributed-packing/tox.ini +++ b/plugins/attention-and-distributed-packing/tox.ini @@ -11,6 +11,8 @@ commands = # install the dependencies here to ensure # the order pip install -e {toxinidir}/../framework + + pip install datasets # for the multipack tests pytest {posargs:tests} [testenv:lint] diff --git a/plugins/framework/src/fms_acceleration/accelerator_patcher.py b/plugins/framework/src/fms_acceleration/accelerator_patcher.py new file mode 100644 index 00000000..6864e1a1 --- /dev/null +++ b/plugins/framework/src/fms_acceleration/accelerator_patcher.py @@ -0,0 +1,277 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Standard +from dataclasses import dataclass +from enum import Enum +from types import MethodType +from typing import Any, Callable, Dict, List + +# Third Party +from accelerate import Accelerator +from torch.utils.data import DataLoader + +# AcceleratorPatcher allows various modifications to the accelerator object: +# - includes replacements of various components, and other things (e.g., inserting) +# additional metrics in the outputs of a model forward +# - AcceleratorPatcherComponent regards the components that can be replaced +# - the AcceleratorRule abstracts logic for modifying AcceleratorPatcherComponent +# NOTE: currently only AcceleratorRuleReplace is implemented + +# ---------------------------------- CLASSES ----------------------------------- + + +# Components that can be modified / replaced via the patching of the accelerator +class AcceleratorPatcherComponent(Enum): + + # The dataloader can be replaced + data_loader = 1 + + # The data collator within the dataloader can be replaced + data_collator = 2 + + +# Components that are replaceable +# DataCollator is a typing.NewType and does not work well with isinstance +# - so we type data_collator as a Callable +REPLACEABLE_COMPONENTS = { + AcceleratorPatcherComponent.data_loader.value: DataLoader, + AcceleratorPatcherComponent.data_collator.value: Callable, +} + + +# History of all the patching that has been performed +@dataclass +class AcceleratorPatcherHistory: + + # component that is patched + component: AcceleratorPatcherComponent + + # type of rule (see RULE below) + kind: str + + # id of the rule that was applied + rule_id: str + + +# ---------------------------------- RULE ----------------------------------- + +RULE_KIND_REPLACEMENT = "replacement" + +# List of special kwargs that may affect behavior of specific rules +RULE_SPECIAL_KWARGS = {RULE_KIND_REPLACEMENT: {"skip_prepare"}} + + +@dataclass +class AcceleratorRuleReplace: + + # id, must be unique + rule_id: str + + # component that is patched + component: AcceleratorPatcherComponent + + # replacement: + replacement: Any = None + + # replacement builder + replacement_builder: Callable[[Any, Accelerator], Any] = None + + # pre-req check on the object to be replaced + pre_req: Callable = None + + # additional kwargs that can be used for special behaviors based on component + # e.g, skip_repare for dataloader will skip running the old prepare + kwargs: Dict = None + + def __post_init__(self): + + assert (self.replacement is None and self.replacement_builder is not None) or ( + self.replacement is not None and self.replacement_builder is None + ), "either replacement or replacement_builder should be specified" + + if self.kwargs is None: + self.kwargs = {} + + assert all( + k in RULE_SPECIAL_KWARGS[RULE_KIND_REPLACEMENT] for k in self.kwargs + ), f"Invalid special behavior kwargs in '{self.kwargs.keys()}'" + + def pre_req_check(self, to_be_replaced: Any): + if self.pre_req is None: + return + + assert self.pre_req( + to_be_replaced + ), f"Rule '{self.rule_id}' failed pre-requisite check for type '{type(to_be_replaced)}'." 
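To make the replacement flow concrete, the sketch below condenses how a plugin registers a rule and how the framework later applies it; it re-uses calls that already appear in this diff (the padding-free plugin's collator replacement and the framework's `AcceleratorPatcher.patch`), with the pre-requisite check inlined as a lambda for brevity.

```python
# Sketch of the AcceleratorPatcher replacement flow (condensed from this diff).
from transformers import DataCollatorForSeq2Seq

from fms_acceleration.accelerator_patcher import (
    AcceleratorPatcher,
    AcceleratorPatcherComponent,
)
from fms_acceleration_aadp.aadp_utils import DataCollatorWithFlattening

# 1. During plugin augmentation: register a replacement for the data collator.
#    The rule id must be unique, and the component being replaced must pass the
#    pre-requisite check, otherwise patching raises.
AcceleratorPatcher.replace(
    "flattening-collator",
    AcceleratorPatcherComponent.data_collator,
    replacement=DataCollatorWithFlattening(),
    pre_requisite_check=lambda collate_fn: isinstance(
        collate_fn, DataCollatorForSeq2Seq
    ),
)

# 2. Just before training, the framework calls
#    AcceleratorPatcher.patch(accelerator), which hooks accelerator.prepare so
#    the replacement is applied when the Trainer prepares its dataloader;
#    AcceleratorPatcher.summary() then reports the applied rules.
```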
+ + +# Sigleton AcceleratorPatcher +class AcceleratorPatcher: + + # singleton history of patches + history: List[AcceleratorPatcherHistory] = [] + + # singleton collection of replacements + replacement_rules: Dict[str, AcceleratorRuleReplace] = {} + + @staticmethod + def replace( + rule_id: str, + component: AcceleratorPatcherComponent, + replacement: Any = None, + replacement_builder: Callable = None, + pre_requisite_check: Callable = None, + **kwargs, + ): + """replace a component. Note that once this method is called, the replacement + is expected to occur, that is there is no fallback behavior + - if the pre_requisite_check fails will raise. + - if there are two replace calls on the same component will raise. + + replacement: the replacement object, if not specified, then replacement builder + must be specified. + replacement_builder (callable): the replacement builder object. If not specified, + then replacement must be specified. + pre_requisite_check (callable): the component to be replaced is expected to + pass this check, otherwise raises. + kwargs (dict): These control special behaviors of the replacement rules, see + RULE_SPECIAL_KWARGS above. + """ + + # - ensure that rule has not been added before + assert not any( + h.rule_id == rule_id for h in AcceleratorPatcher.history + ), f"Rule '{rule_id}' has already been added" + + assert ( + component.value not in AcceleratorPatcher.replacement_rules + ), f"replace has already been called once on component '{component.name}'" + + # handle the replacement + # - if replacement is not None, ensure replacement object is of the correct type + if replacement is not None: + comp_cls = REPLACEABLE_COMPONENTS.get(component.value) + if comp_cls: + assert isinstance(replacement, comp_cls), ( + f"Rule '{rule_id}' replacing component '{component}' with wrong ", + f"type '{type(replacement)}'", + ) + elif replacement_builder is not None: + # NOTE: there is no class check for the replacement builder pattern + pass + + # - register the replacement rule + AcceleratorPatcher.replacement_rules[component.value] = AcceleratorRuleReplace( + rule_id, + component, + replacement, + replacement_builder, + pre_requisite_check, + kwargs=kwargs, + ) + + # - record the history. This is done in advance for replacements even + # the pre-req check has not been run. + # - This advanced registration simplifies logic in the patch, and its ok + # because we will raise in the pre-req if fails, as the semantics for + # replace is that it is expected to occur once called. + AcceleratorPatcher.history.append( + AcceleratorPatcherHistory(component, RULE_KIND_REPLACEMENT, rule_id) + ) + + @staticmethod + def patch(accelerator: Accelerator): + + # some rules will require patching the prepare function + # - e.g., if replacements are required. 
+ if any( + key + in ( + AcceleratorPatcherComponent.data_collator.value, + AcceleratorPatcherComponent.data_loader.value, + ) + for key in AcceleratorPatcher.replacement_rules + ): + AcceleratorPatcher._patch_prepare(accelerator) + + # function to patch the accelerator prepare + @staticmethod + def _patch_prepare(accelerator: Accelerator): + + # hijack the dataloader in accelerator.prepare to replace the collate_fn + _old_prepare = accelerator.prepare + + def prepare(self, *args, device_placement=None): + if len(args) > 1 or not isinstance(args[0], DataLoader): + return _old_prepare(*args, device_placement=device_placement) + + # if there is dataloader replacment + dataloader_replacement_rule = AcceleratorPatcher.replacement_rules.get( + AcceleratorPatcherComponent.data_loader.value + ) + # the original dataloader + dataloader = args[0] + + if dataloader_replacement_rule: + dataloader_replacement_rule.pre_req_check(dataloader) + if dataloader_replacement_rule.replacement is not None: + dataloader = dataloader_replacement_rule.replacement + else: + dataloader = dataloader_replacement_rule.replacement_builder( + dataloader, accelerator + ) + + # if there is dataloader replacment + collator_replacement_rule = AcceleratorPatcher.replacement_rules.get( + AcceleratorPatcherComponent.data_collator.value + ) + + if collator_replacement_rule: + # - first we run the check on the rule (if any) + # - then we replace + collator_replacement_rule.pre_req_check(dataloader.collate_fn) + + # FIXME: for now we just disable the replacement_builder + assert ( + collator_replacement_rule.replacement_builder is None + ), "Currently, replacement_builder not allowed for data collator" + + # Replace the collate_fn in dataloader + dataloader.collate_fn = collator_replacement_rule.replacement + + # - special behavior for dataloader replacements + # - need to know if we run the original prepare + if ( + dataloader_replacement_rule is not None + and dataloader_replacement_rule.kwargs.get("skip_prepare", False) + ): + return dataloader + + return _old_prepare(dataloader) + + accelerator.prepare = MethodType(prepare, accelerator) + + @staticmethod + def summary(): + result = [] + result.append("***************** Accelerator Patching *************") + for x in AcceleratorPatcher.history: + result.append( + "Rule: {0:25s} Kind: {1:10s} Component: {2:20s}".format( + x.rule_id, x.kind, x.component.name + ) + ) + + return "\n".join(result) diff --git a/plugins/framework/src/fms_acceleration/framework.py b/plugins/framework/src/fms_acceleration/framework.py index 1e6ecb44..3a393815 100644 --- a/plugins/framework/src/fms_acceleration/framework.py +++ b/plugins/framework/src/fms_acceleration/framework.py @@ -38,20 +38,15 @@ logger.setLevel(logging._get_default_logging_level()) logger.addHandler(logging._default_handler) + def log_patch_summary( + summary: List[str], logging_func: Callable = None, ): if logging_func is None: logging_func = print - # this is a guarded import, because the model rule registration - # does not need to be loaded unless patch_model is required - # Local - from .model_patcher import ( # pylint: disable=import-outside-toplevel - patch_model_summary, - ) - - for line in patch_model_summary().split("\n"): + for line in summary: logging_func(line) @@ -226,11 +221,23 @@ def get_callbacks_and_ready_for_train( self, model: torch.nn.Module = None, accelerator: Accelerator = None ): - from .model_patcher import ModelPatcher # pylint: disable=import-outside-toplevel + # Local + from .model_patcher import ( # 
pylint: disable=import-outside-toplevel + ModelPatcher, + ) + if model is not None: # Finally apply all registered patches to the model ModelPatcher.patch(model) + # do the accelerator patching + # Local + # pylint: disable=import-outside-toplevel + from .accelerator_patcher import AcceleratorPatcher + + if accelerator is not None: + AcceleratorPatcher.patch(accelerator) + # show the initialized message if accelerator is not None and accelerator.is_main_process: log_initialization_message( @@ -240,7 +247,13 @@ def get_callbacks_and_ready_for_train( ) # if patching is done, print patch summary to logger - log_patch_summary(logging_func=logger.info) + log_patch_summary( + ModelPatcher.summary(raw=False).split("\n"), logging_func=logger.info + ) + + log_patch_summary( + AcceleratorPatcher.summary().split("\n"), logging_func=logger.info + ) cbks = [] for _, plugin in self.active_plugins: diff --git a/plugins/framework/src/fms_acceleration/framework_plugin.py b/plugins/framework/src/fms_acceleration/framework_plugin.py index 0f24597e..2ea6f2ec 100644 --- a/plugins/framework/src/fms_acceleration/framework_plugin.py +++ b/plugins/framework/src/fms_acceleration/framework_plugin.py @@ -24,6 +24,7 @@ from transformers import TrainingArguments import torch + @dataclass class PluginRegistration: plugin: "AccelerationPlugin" diff --git a/plugins/framework/src/fms_acceleration/model_patcher.py b/plugins/framework/src/fms_acceleration/model_patcher.py index 5f0655fb..118c1026 100644 --- a/plugins/framework/src/fms_acceleration/model_patcher.py +++ b/plugins/framework/src/fms_acceleration/model_patcher.py @@ -118,21 +118,22 @@ def __post_init__(self): self.type = ModelPatcherTriggerType.module else: # ensure check conforms with self.type - assert self.type == ModelPatcherTriggerType.module, \ - "type argument passed but `check` argument does not match type specified" + assert ( + self.type == ModelPatcherTriggerType.module + ), "type argument passed but `check` argument does not match type specified" # if check is a callable elif callable(self.check): if self.type is None: self.type = ModelPatcherTriggerType.callable else: # ensure check conforms with self.type - assert self.type == ModelPatcherTriggerType.callable, \ - "type argument passed but `check` argument does not match type specified" + assert ( + self.type == ModelPatcherTriggerType.callable + ), "type argument passed but `check` argument does not match type specified" else: raise TypeError("check argument needs to be torch.nn.Module or Callable") - # type for model forward ModelForward = Callable @@ -175,11 +176,16 @@ class ModelPatcherRule: ] = None def __post_init__(self): - if sum([ - self.forward is not None, - self.forward_builder is not None, - self.import_and_maybe_reload is not None, - ]) != 1: + if ( + sum( + [ + self.forward is not None, + self.forward_builder is not None, + self.import_and_maybe_reload is not None, + ] + ) + != 1 + ): raise ValueError( f"Rule '{self.rule_id}' must only have only one of forward, " "foward builder, or import_and_maybe_reload, specified." @@ -197,6 +203,7 @@ def __post_init__(self): "forward_builder." ) + # helpful to keep a history of all patching that has been done @dataclass class ModelPatcherHistory: @@ -278,8 +285,10 @@ def did_rule_trigger(module: torch.nn.Module, module_name: str): # for simple forward patches. 
forward_builder args are handled # when they are decomposed into new simple forward rules elif rule.forward is not None: - warnings.warn(f"rule {rule.rule_id} is ignored on {module_name} as an \ - earlier rule {active_rule.rule_id} has been applied") + warnings.warn( + f"rule {rule.rule_id} is ignored on {module_name} as an \ + earlier rule {active_rule.rule_id} has been applied" + ) return active_rule_name, active_rule @@ -337,15 +346,16 @@ def _import_and_reload(model: torch.nn.Module): _with_reload = sorted( _with_reload, key=lambda _rule: len(_rule.import_and_maybe_reload[2]), - reverse=False + reverse=False, ) for rule_s in _with_reload: for rule_l in _with_reload[1:]: # if target paths in rule s is a prefix of rule l, raise an error _, _, _path_s = rule_s.import_and_maybe_reload _, _, _path_l = rule_l.import_and_maybe_reload - assert not _path_l.startswith(_path_s), \ - f"Attempting to reload same path `{_path_s}` multiple times in \ + assert not _path_l.startswith( + _path_s + ), f"Attempting to reload same path `{_path_s}` multiple times in \ {rule_s.rule_id} and {rule_l.rule_id}" # handle those with reload first @@ -502,6 +512,7 @@ def summary(raw: bool = False): # ------------------------ function ----------------------- + def patch_model(model: torch.nn.Module, **kwargs): ModelPatcher.patch(model, **kwargs) return model @@ -512,7 +523,10 @@ def patch_model_summary(): def combine_triggers(*triggers: ModelPatcherTrigger, logic: str = "OR"): - assert logic in ["AND", "OR"], "Only `AND`, `OR` logic implemented for combining triggers" + assert logic in [ + "AND", + "OR", + ], "Only `AND`, `OR` logic implemented for combining triggers" # NOTE: this can be probably simplified def _or_logic(*args, **kwargs): @@ -533,6 +547,7 @@ def _and_logic(*args, **kwargs): return ModelPatcherTrigger(check=_logic) + def combine_functions(*funcs: Callable, logic: str = "APPEND"): assert logic == "APPEND", "only APPEND logic implemented for combining functions" diff --git a/plugins/framework/src/fms_acceleration/utils/test_utils.py b/plugins/framework/src/fms_acceleration/utils/test_utils.py index 929c61e3..bc80ec8e 100644 --- a/plugins/framework/src/fms_acceleration/utils/test_utils.py +++ b/plugins/framework/src/fms_acceleration/utils/test_utils.py @@ -181,10 +181,34 @@ def dummy_custom_loader(self, model_name, **kwargs): "dummy custom loader returning dummy model" return create_noop_model_with_archs(archs=["DummyModel"]) # + @contextmanager def instantiate_model_patcher(): - from fms_acceleration.model_patcher import ModelPatcher # pylint: disable=import-outside-toplevel + # First Party + from fms_acceleration.model_patcher import ( # pylint: disable=import-outside-toplevel + ModelPatcher, + ) + old_registrations = ModelPatcher.rules + old_history = ModelPatcher.history ModelPatcher.rules = {} + ModelPatcher.history = [] yield ModelPatcher.rules = old_registrations + ModelPatcher.history = old_history + + +@contextmanager +def instantiate_accel_patcher(): + # First Party + from fms_acceleration.accelerator_patcher import ( # pylint: disable=import-outside-toplevel + AcceleratorPatcher, + ) + + old_registrations = AcceleratorPatcher.replacement_rules + old_history = AcceleratorPatcher.history + AcceleratorPatcher.replacement_rules = {} + AcceleratorPatcher.history = [] + yield + AcceleratorPatcher.replacement_rules = old_registrations + AcceleratorPatcher.history = old_history diff --git a/plugins/framework/tests/model_patcher_fixtures/module1/__init__.py 
b/plugins/framework/tests/model_patcher_fixtures/module1/__init__.py index 46e6ebb6..733cfbd8 100644 --- a/plugins/framework/tests/model_patcher_fixtures/module1/__init__.py +++ b/plugins/framework/tests/model_patcher_fixtures/module1/__init__.py @@ -12,5 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .module3 import Module3Class +# Local from .module1_1 import Module1Class, mod_1_function +from .module3 import Module3Class diff --git a/plugins/framework/tests/model_patcher_fixtures/module1/module1_1.py b/plugins/framework/tests/model_patcher_fixtures/module1/module1_1.py index 07a5b86a..d3404ddd 100644 --- a/plugins/framework/tests/model_patcher_fixtures/module1/module1_1.py +++ b/plugins/framework/tests/model_patcher_fixtures/module1/module1_1.py @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +# Local from ..module2 import Module2Class + class Module1Class: def __init__(self) -> None: self.attribute = Module2Class() + def mod_1_function(): return "unpatched_mod_function" - \ No newline at end of file diff --git a/plugins/framework/tests/model_patcher_fixtures/module1/module3/__init__.py b/plugins/framework/tests/model_patcher_fixtures/module1/module3/__init__.py index eb882843..a663ab2c 100644 --- a/plugins/framework/tests/model_patcher_fixtures/module1/module3/__init__.py +++ b/plugins/framework/tests/model_patcher_fixtures/module1/module3/__init__.py @@ -12,4 +12,5 @@ # See the License for the specific language governing permissions and # limitations under the License. +# Local from .module3_1 import Module3Class diff --git a/plugins/framework/tests/model_patcher_fixtures/module1/module3/module3_1.py b/plugins/framework/tests/model_patcher_fixtures/module1/module3/module3_1.py index 09108981..0f2e254a 100644 --- a/plugins/framework/tests/model_patcher_fixtures/module1/module3/module3_1.py +++ b/plugins/framework/tests/model_patcher_fixtures/module1/module3/module3_1.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +# Local from ..module1_1 import mod_1_function + class Module3Class: def __init__(self) -> None: self.attribute = mod_1_function diff --git a/plugins/framework/tests/model_patcher_fixtures/module2.py b/plugins/framework/tests/model_patcher_fixtures/module2.py index d866ac51..cac907b9 100644 --- a/plugins/framework/tests/model_patcher_fixtures/module2.py +++ b/plugins/framework/tests/model_patcher_fixtures/module2.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. + class Module2Class: def __init__(self) -> None: pass diff --git a/plugins/framework/tests/model_patcher_fixtures/module4/__init__.py b/plugins/framework/tests/model_patcher_fixtures/module4/__init__.py index d1a1e40c..176292d0 100644 --- a/plugins/framework/tests/model_patcher_fixtures/module4/__init__.py +++ b/plugins/framework/tests/model_patcher_fixtures/module4/__init__.py @@ -12,9 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+# Third Party +import torch + +# Local from .module4_1 import mod_4_function from .module5 import Module5Class, mod_5_function -import torch + class Module4Class(torch.nn.Module): def __init__(self) -> None: diff --git a/plugins/framework/tests/model_patcher_fixtures/module4/module4_1.py b/plugins/framework/tests/model_patcher_fixtures/module4/module4_1.py index 3b302c1f..5f7a7d29 100644 --- a/plugins/framework/tests/model_patcher_fixtures/module4/module4_1.py +++ b/plugins/framework/tests/model_patcher_fixtures/module4/module4_1.py @@ -12,6 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. + def mod_4_function(): return "unpatched_mod_function" - \ No newline at end of file diff --git a/plugins/framework/tests/model_patcher_fixtures/module4/module5/__init__.py b/plugins/framework/tests/model_patcher_fixtures/module4/module5/__init__.py index cf0bb8e2..d7288c1d 100644 --- a/plugins/framework/tests/model_patcher_fixtures/module4/module5/__init__.py +++ b/plugins/framework/tests/model_patcher_fixtures/module4/module5/__init__.py @@ -12,4 +12,5 @@ # See the License for the specific language governing permissions and # limitations under the License. +# Local from .module5_1 import Module5Class, mod_5_function diff --git a/plugins/framework/tests/model_patcher_fixtures/module4/module5/module5_1.py b/plugins/framework/tests/model_patcher_fixtures/module4/module5/module5_1.py index fbfa408b..f62e812d 100644 --- a/plugins/framework/tests/model_patcher_fixtures/module4/module5/module5_1.py +++ b/plugins/framework/tests/model_patcher_fixtures/module4/module5/module5_1.py @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +# Third Party import torch + class Module5Class(torch.nn.Module): def __init__(self) -> None: super().__init__() + def mod_5_function(): return "unpatched_mod_function" - \ No newline at end of file diff --git a/plugins/framework/tests/model_patcher_test_utils.py b/plugins/framework/tests/model_patcher_test_utils.py index 30ccb6cb..695c5771 100644 --- a/plugins/framework/tests/model_patcher_test_utils.py +++ b/plugins/framework/tests/model_patcher_test_utils.py @@ -12,32 +12,34 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import torch +# Standard +from contextlib import contextmanager +from typing import Any, Dict, Type +import importlib import os import sys -import importlib -from contextlib import contextmanager -from typing import Dict, Any, Type -ROOT = 'tests.model_patcher_fixtures' +# Third Party +import torch + +ROOT = "tests.model_patcher_fixtures" MODULE_PATHS = [] -for root, dirs, files in os.walk(ROOT.replace('.', os.path.sep)): +for root, dirs, files in os.walk(ROOT.replace(".", os.path.sep)): for f in files: filename, ext = os.path.splitext(f) if ext != ".py": continue - if filename != '__init__': + if filename != "__init__": p = os.path.join(root, filename) else: p = root - MODULE_PATHS.append(p.replace(os.path.sep, '.')) + MODULE_PATHS.append(p.replace(os.path.sep, ".")) + @contextmanager def isolate_test_module_fixtures(): - old_mod = { - k: sys.modules[k] for k in MODULE_PATHS if k in sys.modules - } + old_mod = {k: sys.modules[k] for k in MODULE_PATHS if k in sys.modules} yield # Reload only reloads the speicified module, but makes not attempt to reload @@ -58,7 +60,7 @@ def isolate_test_module_fixtures(): def create_module_class( class_name: str, namespaces: Dict[str, Any] = None, - parent_class: Type = torch.nn.Module + parent_class: Type = torch.nn.Module, ): if namespaces is None: namespaces = {} diff --git a/plugins/framework/tests/test_accel_patcher.py b/plugins/framework/tests/test_accel_patcher.py new file mode 100644 index 00000000..2a3e1521 --- /dev/null +++ b/plugins/framework/tests/test_accel_patcher.py @@ -0,0 +1,151 @@ +# Copyright The IBM Tuning Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Third Party +from accelerate import Accelerator +import pytest # pylint: disable=import-error +import torch + +# First Party +from fms_acceleration.accelerator_patcher import ( + AcceleratorPatcher, + AcceleratorPatcherComponent, + AcceleratorRuleReplace, +) +from fms_acceleration.utils.test_utils import instantiate_accel_patcher + + +def test_AP_rule_raises_correct_errors(): + # not specifying any replacement objects will throw an error + with pytest.raises( + AssertionError, + match="either replacement or replacement_builder should be specified", + ): + AcceleratorRuleReplace( + rule_id="bad-rule-empty-builders", + component=AcceleratorPatcherComponent.data_loader, + replacement=None, + replacement_builder=None, + ) + + # Ensure that rule registration throws error when attempting + # to specify an unknown flag for an unsupported behaviour + # handling of the component by AP + with pytest.raises( + AssertionError, + match=r"Invalid special behavior kwargs in '.*'", + ): + AcceleratorRuleReplace( + rule_id="invalid-special-kwargs", + component=AcceleratorPatcherComponent.data_loader, + replacement=torch.utils.data.DataLoader( + torch.utils.data.Dataset(), + ), + kwargs={"unsupported_kwarg": True}, + ) + + +def test_AP_failing_prereq_check_raises_error(): + # 1. register AP rule + # 2. instantiate accelerator + # 3. 
attempt to patch accelerator prepare function w a pre-req check + # 4. call accelerator prepare + # 5. ensure that pre-req check raises error when condition not satisfied + pre_req_error_message = "pre-requisite check failed" + + def pre_req_check(dataloader): + raise ValueError(pre_req_error_message) + + with pytest.raises( + ValueError, + match=pre_req_error_message, + ): + with instantiate_accel_patcher(): + dummy_dataloader = torch.utils.data.DataLoader(torch.utils.data.Dataset()) + + # register the replacement rule + AcceleratorPatcher.replace( + rule_id="pre-req-check-raises-error", + component=AcceleratorPatcherComponent.data_loader, + replacement=dummy_dataloader, + pre_requisite_check=pre_req_check, + ) + # instantiate an accelerator object + accelerator = Accelerator() + # patch the prepare function + AcceleratorPatcher.patch(accelerator) + # call accelerator prepare + accelerator.prepare(dummy_dataloader) + + +def test_AP_patch_correctly_with_simple_replacement(): + # 1. register rule to replace collate fn + # 2. patch the accelerator + # 3. call accelerator.prepare with a dataloader + # 4. verify that the dataloader's collate fn behaviour has updated + message = "replacement successful" + + def replaced_collater(): + return message + + with instantiate_accel_patcher(): + dataloader = torch.utils.data.DataLoader(torch.utils.data.Dataset()) + # register the replacement rule for new collate fn + AcceleratorPatcher.replace( + rule_id="simple-replacement-successful", + component=AcceleratorPatcherComponent.data_collator, + replacement=replaced_collater, + ) + # instantiate an accelerator object + accelerator = Accelerator() + # patch the prepare function + AcceleratorPatcher.patch(accelerator) + # call accelerator prepare + dataloader = accelerator.prepare(dataloader) + assert dataloader.collate_fn() == "replacement successful" + + +def test_AP_patch_correctly_with_replacement_builder(): + # 1. Create a builder function for a new dataloader class + # 2. Register a replacement rule to take in the builder function + # 3. Instantiate and patch accelerator + # 4. call accelerator.prepare on a standard dataloader + # 5. 
verify that the dataloader has been replaced + class NewDataLoader(torch.utils.data.DataLoader): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def build_new_dataloader( + dataloader: torch.utils.data.DataLoader, accelerator: Accelerator + ): + return NewDataLoader( + torch.utils.data.Dataset(), + ) + + with instantiate_accel_patcher(): + original_dataloader = torch.utils.data.DataLoader(torch.utils.data.Dataset()) + # register the replacement rule + AcceleratorPatcher.replace( + rule_id="replacement-builder-successful", + component=AcceleratorPatcherComponent.data_loader, + replacement_builder=build_new_dataloader, + skip_prepare=True, + ) + # instantiate an accelerator object + accelerator = Accelerator() + # patch the prepare function + AcceleratorPatcher.patch(accelerator) + # call accelerator prepare + replaced_dataloader = accelerator.prepare(original_dataloader) + assert isinstance(replaced_dataloader, NewDataLoader) diff --git a/plugins/framework/tests/test_model_patcher.py b/plugins/framework/tests/test_model_patcher.py index 13d01cf3..f9be7447 100644 --- a/plugins/framework/tests/test_model_patcher.py +++ b/plugins/framework/tests/test_model_patcher.py @@ -22,13 +22,15 @@ ModelPatcherTrigger, patch_target_module, ) +from fms_acceleration.utils.test_utils import instantiate_model_patcher -from .model_patcher_test_utils import create_module_class, isolate_test_module_fixtures +# Local from .model_patcher_fixtures import module4 -from fms_acceleration.utils.test_utils import instantiate_model_patcher +from .model_patcher_test_utils import create_module_class, isolate_test_module_fixtures from .test_model_patcher_helpers import DUMMY_RULE_ID -#Test patching of model attribute + +# Test patching of model attribute def test_simple_forward_rule_with_mp_replaces_old_forward(): """ Ensure that a child submodule forward function @@ -71,8 +73,8 @@ def patched_forward_function(X): with instantiate_model_patcher(): model = module4.Module4Class() SubModule1 = create_module_class( - "SubModule1", - namespaces={"forward": lambda self: "unpatched_forward_function"} + "SubModule1", + namespaces={"forward": lambda self: "unpatched_forward_function"}, ) model.add_module("submodule_1", SubModule1()) rule = ModelPatcherRule( @@ -85,6 +87,7 @@ def patched_forward_function(X): assert model.submodule_1.forward() == "patched_forward_function" + def test_import_and_maybe_reload_rule_with_mp_replaces_old_attribute(): """ Module4Class has an attribute from Module5Class, @@ -119,6 +122,7 @@ def test_import_and_maybe_reload_rule_with_mp_replaces_old_attribute(): ModelPatcher.patch(model) assert isinstance(module4.Module4Class().attribute, PatchedModuleClass) + def test_mp_throws_error_with_multiple_reloads_on_same_target(): """ Simulate a case where two rules attempt to reload on the same target prefix @@ -176,7 +180,7 @@ def patched_mod_function(): patch_target_module( "tests.model_patcher_fixtures.module4.module5.module5_1.Module5Class", PatchedModuleClass, - "tests.model_patcher_fixtures.module4.module5" + "tests.model_patcher_fixtures.module4.module5", ) assert isinstance(module4.module5.Module5Class(), PatchedModuleClass) @@ -223,6 +227,7 @@ def patched_mod_function(): # longer target path ModelPatcher.patch(model) + def test_mp_throws_warning_with_multiple_patches(): """ Ensure for each module, only one forward patch is implemented on it. 
@@ -247,21 +252,21 @@ def test_mp_throws_warning_with_multiple_patches(): model = module4.Module4Class() SubModule1 = create_module_class( - "SubModule1", - namespaces={"forward": lambda self: "unpatched_forward_function"} + "SubModule1", + namespaces={"forward": lambda self: "unpatched_forward_function"}, ) model.add_module("submodule_1", SubModule1()) ModelPatcher.register( ModelPatcherRule( - rule_id=DUMMY_RULE_ID+".1", + rule_id=DUMMY_RULE_ID + ".1", trigger=ModelPatcherTrigger(check=SubModule1), forward=lambda self: "patched_forward_function", ) ) ModelPatcher.register( ModelPatcherRule( - rule_id=DUMMY_RULE_ID+".2", + rule_id=DUMMY_RULE_ID + ".2", trigger=ModelPatcherTrigger(check=SubModule1), forward=lambda self: "patched_forward_function_2", ) @@ -274,6 +279,7 @@ def test_forward_builder_rule_with_mp_replaces_old_forward(): Ensure that patching a model with a rule using forward_builder argument will replace the children module forwards """ + def is_module_type_B(module): if hasattr(module, "B"): return True @@ -296,7 +302,8 @@ def patched_forward_function(X): # 4. Ensure all submodule forwards are patched SubModule1 = create_module_class( - "SubModule1", namespaces={"forward": lambda X: "unpatched_forward_function"} + "SubModule1", + namespaces={"forward": lambda X: "unpatched_forward_function"}, ) SubModule1A = create_module_class( "SubModule1A", parent_class=SubModule1, namespaces={"A": "attributeA"} @@ -305,8 +312,11 @@ def patched_forward_function(X): "SubModule1B", parent_class=SubModule1, namespaces={"B": "attributeB"} ) SubModule2 = create_module_class( - "SubModule2", - namespaces={"C": "attributeC", "forward": lambda X: "unpatched_forward_function"} + "SubModule2", + namespaces={ + "C": "attributeC", + "forward": lambda X: "unpatched_forward_function", + }, ) model = module4.module5.Module5Class() @@ -320,8 +330,14 @@ def build_list_of_triggers( ): return [ (ModelPatcherTrigger(check=SubModule1A), patched_forward_function), - (ModelPatcherTrigger(check=is_module_type_B), patched_forward_function), - (ModelPatcherTrigger(check=is_module_type_C), patched_forward_function), + ( + ModelPatcherTrigger(check=is_module_type_B), + patched_forward_function, + ), + ( + ModelPatcherTrigger(check=is_module_type_C), + patched_forward_function, + ), ] ModelPatcher.register( @@ -329,7 +345,7 @@ def build_list_of_triggers( rule_id=DUMMY_RULE_ID, trigger=ModelPatcherTrigger(check=module4.module5.Module5Class), forward_builder=build_list_of_triggers, - ) + ) ) ModelPatcher.patch(model) diff --git a/plugins/framework/tests/test_model_patcher_helpers.py b/plugins/framework/tests/test_model_patcher_helpers.py index 6713aced..0660f577 100644 --- a/plugins/framework/tests/test_model_patcher_helpers.py +++ b/plugins/framework/tests/test_model_patcher_helpers.py @@ -23,36 +23,40 @@ from fms_acceleration.model_patcher import ( ModelPatcherRule, ModelPatcherTrigger, - patch_target_module, ModelPatcherTriggerType, combine_triggers, + patch_target_module, ) -from .model_patcher_test_utils import create_module_class, isolate_test_module_fixtures +# Local from .model_patcher_fixtures import module1 +from .model_patcher_test_utils import create_module_class, isolate_test_module_fixtures MOD_CLS_A = create_module_class("MOD_CLS_A") MOD_SUBCLS_A = create_module_class("MOD_SUBCLS_A", parent_class=MOD_CLS_A) MOD_CLS_B = create_module_class("MOD_CLS_B") + def returns_false(*args, **kwargs): "falsy function" return False + def returns_true(*args, **kwargs): "truthy function" return True + DUMMY_RULE_ID = 
"test_patch" # | ------------------ Test ModelPatcherTrigger ----------------------- | + def test_mp_trigger_constructs_with_check_arg_only(): "Test construction of trigger with check argument" # Test that error is raised when check is not of accepted type with pytest.raises( - TypeError, - match = "check argument needs to be torch.nn.Module or Callable" + TypeError, match="check argument needs to be torch.nn.Module or Callable" ): ModelPatcherTrigger(check=None) @@ -64,6 +68,7 @@ def test_mp_trigger_constructs_with_check_arg_only(): trigger = ModelPatcherTrigger(check=returns_true) assert trigger.type == ModelPatcherTriggerType.callable + def test_mp_trigger_constructs_with_check_and_trigger_type_args(): "Test construction of trigger with check and type arguments" # check that trigger constructs successfully as check conforms to specified type @@ -80,7 +85,7 @@ def test_mp_trigger_constructs_with_check_and_trigger_type_args(): # Ensure an error is raised when check is callable but type is module with pytest.raises( AssertionError, - match = "type argument passed but `check` argument does not match type specified", + match="type argument passed but `check` argument does not match type specified", ): ModelPatcherTrigger( check=returns_true, @@ -90,13 +95,14 @@ def test_mp_trigger_constructs_with_check_and_trigger_type_args(): # Ensure an error is raised when check is module but type is callable with pytest.raises( AssertionError, - match = "type argument passed but `check` argument does not match type specified", + match="type argument passed but `check` argument does not match type specified", ): ModelPatcherTrigger( check=torch.nn.Module, type=ModelPatcherTriggerType.callable, ) + def test_mp_trigger_correctly_triggers(): "Test for correctnness of trigger behaviour" @@ -126,13 +132,19 @@ def check_module(module): return True return False - assert ModelPatcherTrigger(check=check_module).is_triggered( - ModClassA(), - ) is True + assert ( + ModelPatcherTrigger(check=check_module).is_triggered( + ModClassA(), + ) + is True + ) - assert ModelPatcherTrigger(check=check_module).is_triggered( - ModClassB(), - ) is False + assert ( + ModelPatcherTrigger(check=check_module).is_triggered( + ModClassB(), + ) + is False + ) # Scenario 2: # Ensure return True, if is not an instance of ModelPatcherTrigger.check @@ -140,18 +152,27 @@ def check_module(module): # 2. create an instance of ModClassA and check is_triggered returns True # 3. create a subclass instance of ModClassA and check is_triggered returns True # 4. 
create an instance of ModClassB and check is_triggered returns False - assert ModelPatcherTrigger(check=ModClassA).is_triggered( - ModClassA(), - ) is True + assert ( + ModelPatcherTrigger(check=ModClassA).is_triggered( + ModClassA(), + ) + is True + ) - assert ModelPatcherTrigger(check=ModClassA).is_triggered( - ModSubClassA(), - ) is True + assert ( + ModelPatcherTrigger(check=ModClassA).is_triggered( + ModSubClassA(), + ) + is True + ) # Ensure returns False, if is not an instance of ModelPatcherTrigger.check - assert ModelPatcherTrigger(check=ModClassA).is_triggered( - ModClassB(), - ) is False + assert ( + ModelPatcherTrigger(check=ModClassA).is_triggered( + ModClassB(), + ) + is False + ) # Scenario 3: # Static check to ensure additional constraint is checked @@ -177,6 +198,7 @@ def check_module(module): # assert that is_triggered otherwise returns false assert trigger.is_triggered(module, name) is False + # Each test instance has # - target_module, # - tuple of trigger check arguments @@ -187,23 +209,27 @@ def check_module(module): # 3. if expected_result is a tuple, ensure an error is raised upon constructing the trigger # 4. Otherwise, ensure that the combined_trigger returns the expected result on the target module @pytest.mark.parametrize( - "target_module,trigger_checks,logic,expected_result", [ - [MOD_SUBCLS_A(), (returns_true, MOD_CLS_B), "OR", True], - [MOD_SUBCLS_A(), (MOD_CLS_B, returns_false), "OR", False], - [MOD_SUBCLS_A(), (MOD_CLS_A, returns_true), "OR", True], - [MOD_CLS_B(), (returns_false, MOD_CLS_A), "AND", False], - [MOD_CLS_B(), (MOD_CLS_B, returns_false), "AND", False], - [MOD_CLS_B(), (MOD_CLS_B, returns_true), "AND", True], + "target_module,trigger_checks,logic,expected_result", [ - MOD_SUBCLS_A(), (MOD_CLS_B, MOD_CLS_A), "NOR", - (AssertionError, "Only `AND`, `OR` logic implemented for combining triggers") + [MOD_SUBCLS_A(), (returns_true, MOD_CLS_B), "OR", True], + [MOD_SUBCLS_A(), (MOD_CLS_B, returns_false), "OR", False], + [MOD_SUBCLS_A(), (MOD_CLS_A, returns_true), "OR", True], + [MOD_CLS_B(), (returns_false, MOD_CLS_A), "AND", False], + [MOD_CLS_B(), (MOD_CLS_B, returns_false), "AND", False], + [MOD_CLS_B(), (MOD_CLS_B, returns_true), "AND", True], + [ + MOD_SUBCLS_A(), + (MOD_CLS_B, MOD_CLS_A), + "NOR", + ( + AssertionError, + "Only `AND`, `OR` logic implemented for combining triggers", + ), + ], ], -]) +) def test_combine_mp_triggers_produces_correct_output( - target_module, - trigger_checks, - logic, - expected_result + target_module, trigger_checks, logic, expected_result ): triggers = [ModelPatcherTrigger(check=check) for check in trigger_checks] @@ -217,11 +243,14 @@ def test_combine_mp_triggers_produces_correct_output( *triggers, logic=logic, ) - else: # otherwise ensure is_triggered output returns the expected_result - assert combine_triggers( - *triggers, - logic=logic, - ).is_triggered(target_module) is expected_result + else: # otherwise ensure is_triggered output returns the expected_result + assert ( + combine_triggers( + *triggers, + logic=logic, + ).is_triggered(target_module) + is expected_result + ) def test_mp_rule_raises_error_when_arguments_incorrectly_configured(): @@ -229,8 +258,8 @@ def test_mp_rule_raises_error_when_arguments_incorrectly_configured(): # Test mp rule construction raises with multiple arguments with pytest.raises( ValueError, - match="must only have only one of forward, " \ - "foward builder, or import_and_maybe_reload, specified." 
+ match="must only have only one of forward, " + "foward builder, or import_and_maybe_reload, specified.", ): ModelPatcherRule( rule_id=DUMMY_RULE_ID, @@ -242,8 +271,7 @@ def test_mp_rule_raises_error_when_arguments_incorrectly_configured(): # Test mp rule construction raises with trigger and import_and_reload with pytest.raises( ValueError, - match="has import_and_maybe_reload specified, " \ - "and trigger must be None." + match="has import_and_maybe_reload specified, " "and trigger must be None.", ): ModelPatcherRule( rule_id=DUMMY_RULE_ID, @@ -255,16 +283,13 @@ def test_mp_rule_raises_error_when_arguments_incorrectly_configured(): # without a forward_builder, this can be the case when user passes in a # forward instead of forward_builder with pytest.raises( - ValueError, - match="has forward_builder_args but no " \ - "forward_builder." + ValueError, match="has forward_builder_args but no " "forward_builder." ): ModelPatcherRule( - rule_id=DUMMY_RULE_ID, - forward = lambda self, X: X, - forward_builder_args=[] + rule_id=DUMMY_RULE_ID, forward=lambda self, X: X, forward_builder_args=[] ) + def test_patch_target_module_replaces_module_or_function_correctly(): """ Test patching of standalone file functions @@ -319,7 +344,7 @@ def patched_mod_function(): patch_target_module( "tests.model_patcher_fixtures.module2.Module2Class", PatchedModuleClass, - "tests.model_patcher_fixtures.module1.module1_1" + "tests.model_patcher_fixtures.module1.module1_1", ) assert isinstance(module1.Module1Class().attribute, PatchedModuleClass) @@ -332,8 +357,8 @@ def patched_mod_function(): # - this test shows that a replacement only affects the EXACT PATH that was patched with isolate_test_module_fixtures(): patch_target_module( - "tests.model_patcher_fixtures.module1.module3.module3_1.Module3Class", - PatchedModuleClass, + "tests.model_patcher_fixtures.module1.module3.module3_1.Module3Class", + PatchedModuleClass, ) # - this is the exact module path that was patched, so it will reflect the patched class @@ -350,16 +375,18 @@ def patched_mod_function(): # for reload with isolate_test_module_fixtures(): patch_target_module( - "tests.model_patcher_fixtures.module1.module3.module3_1.Module3Class", - PatchedModuleClass, - "tests.model_patcher_fixtures.module1", + "tests.model_patcher_fixtures.module1.module3.module3_1.Module3Class", + PatchedModuleClass, + "tests.model_patcher_fixtures.module1", ) # - the reload of the top level module path module1, will NOT replace module1.module3 # with the original version # - reloading top-level paths is tricky due to caching of the modules # - the reload of a top-level module does not cascade down to children modules. - assert not isinstance(module1.module3.module3_1.Module3Class(), PatchedModuleClass) + assert not isinstance( + module1.module3.module3_1.Module3Class(), PatchedModuleClass + ) # S3.3 - module1.module3 is a submodule of module1 # 1. 
Replace module1.module3.module3_1.Module3Class with a new class @@ -368,14 +395,16 @@ def patched_mod_function(): # for reload with isolate_test_module_fixtures(): patch_target_module( - "tests.model_patcher_fixtures.module1.module3.Module3Class", - PatchedModuleClass, - "tests.model_patcher_fixtures.module1", + "tests.model_patcher_fixtures.module1.module3.Module3Class", + PatchedModuleClass, + "tests.model_patcher_fixtures.module1", ) # - the reload of the top level module path module1, will replace module1.module3 # with the original version - assert not isinstance(module1.module3.module3_1.Module3Class(), PatchedModuleClass) + assert not isinstance( + module1.module3.module3_1.Module3Class(), PatchedModuleClass + ) # S4 - module1.module3 submodule has a dependency on # module1.module1_1.mod_1_function diff --git a/sample-configurations/CONTENTS.yaml b/sample-configurations/CONTENTS.yaml index fec756de..09301193 100644 --- a/sample-configurations/CONTENTS.yaml +++ b/sample-configurations/CONTENTS.yaml @@ -38,6 +38,11 @@ framework_configs: - attention-and-distributed-packing filename: aadp-padding-free-sample-configuration.yaml + - shortname: aadp-padding-free-multipack + plugins: + - attention-and-distributed-packing + filename: aadp-padding-free-multipack-sample-configuration.yaml + - shortname: accelerated-peft-bnb-padding-free plugins: - accelerated-peft diff --git a/sample-configurations/aadp-padding-free-multipack-sample-configuration.yaml b/sample-configurations/aadp-padding-free-multipack-sample-configuration.yaml new file mode 100644 index 00000000..a0a9f6c6 --- /dev/null +++ b/sample-configurations/aadp-padding-free-multipack-sample-configuration.yaml @@ -0,0 +1,22 @@ +# FMS Acceleration Plugin Configuration. +# +# Each stanza incorporates various configurations for +# different fine-tuning / training tasks. +plugins: + # Configurations to accelerate data packing/padding in training + training: + + # attention module configurations + # e.g. padding-free modifications to attention layer + attention: + + # this controls the confgurations for padding free computation of flash attention + padding_free: + method: huggingface + dataloader: + + # multipack dataloader + multipack: + + # number of processes used to calculate dataset lengths + num_processes: 16 diff --git a/sample-configurations/accelerated-peft-autogptq-foak-sample-configuration.yaml b/sample-configurations/accelerated-peft-autogptq-foak-sample-configuration.yaml index 443cfca4..3ca13131 100644 --- a/sample-configurations/accelerated-peft-autogptq-foak-sample-configuration.yaml +++ b/sample-configurations/accelerated-peft-autogptq-foak-sample-configuration.yaml @@ -3,29 +3,29 @@ # Each stanza incorporates various configurations for # different fine-tuning / training tasks. plugins: - # PEFT-related acceleration + # PEFT-related acceleration peft: - # quantization-releated acceleration - # e.g., kernels for quantized base weights + # quantization-releated acceleration + # e.g., kernels for quantized base weights quantization: - # AutoGPTQ quantized base weights. + # AutoGPTQ quantized base weights. auto_gptq: - # Kernel to be used for GPTQ linear laeyer - # NOTE: Not all kernels are suitable for PEFT training; need to use - # kernels that support autograd forward / backward. The best - # recommendation at the moment is "triton_v2". + # Kernel to be used for GPTQ linear laeyer + # NOTE: Not all kernels are suitable for PEFT training; need to use + # kernels that support autograd forward / backward. 
The best + # recommendation at the moment is "triton_v2". kernel: triton_v2 - # If true, then will already expect quantized checkpoint - # passed into TrainingArguments.model_name_or_path + # If true, then will already expect quantized checkpoint + # passed into TrainingArguments.model_name_or_path from_quantized: true - # Setting to false, will create GPTQ-LORA using the local autogptq package. - # if true, will create legacy implementation of GPTQ-LORA using external - # `auto_gptq`. Refer to README for installation instructions + # Setting to false, will create GPTQ-LORA using the local autogptq package. + # if true, will create legacy implementation of GPTQ-LORA using external + # `auto_gptq`. Refer to README for installation instructions use_external_lib: false fused_ops_and_kernels: diff --git a/sample-configurations/accelerated-peft-bnb-nf4-foak-sample-configuration.yaml b/sample-configurations/accelerated-peft-bnb-nf4-foak-sample-configuration.yaml index fcb9bb14..f3f8741a 100644 --- a/sample-configurations/accelerated-peft-bnb-nf4-foak-sample-configuration.yaml +++ b/sample-configurations/accelerated-peft-bnb-nf4-foak-sample-configuration.yaml @@ -3,24 +3,24 @@ # Each stanza incorporates various configurations for # different fine-tuning / training tasks. plugins: - # PEFT-related acceleration + # PEFT-related acceleration peft: - # quantization-releated acceleration - # e.g., kernels for quantized base weights + # quantization-releated acceleration + # e.g., kernels for quantized base weights quantization: - # For loading BitsAndBytes quantized layers - # to serve as 4bit base-weights for LoRA PEFT-tuning. - # NOTE: currently AutoGPTQ is not properly integrated into huggingface / - # bitsandbytes, thus recommended quant_type to be either "nf4" - # or "fp4". - # bitsandbytes: + # For loading BitsAndBytes quantized layers + # to serve as 4bit base-weights for LoRA PEFT-tuning. + # NOTE: currently AutoGPTQ is not properly integrated into huggingface / + # bitsandbytes, thus recommended quant_type to be either "nf4" + # or "fp4". + # bitsandbytes: bitsandbytes: quant_type: nf4 - # If True, then no get_peft_model and prepare_model_for_kbit_training - # will be called. + # If True, then no get_peft_model and prepare_model_for_kbit_training + # will be called. 
no_peft_model: false fused_ops_and_kernels: diff --git a/scripts/benchmarks/refs_orca/a100_80gb_mp.csv b/scripts/benchmarks/refs_orca/a100_80gb_mp.csv new file mode 100644 index 00000000..315d7220 --- /dev/null +++ b/scripts/benchmarks/refs_orca/a100_80gb_mp.csv @@ -0,0 +1,25 @@ +epoch,framework_config,mem_nvidia_mem_reserved,mem_peak_torch_mem_alloc_in_bytes,mem_torch_mem_alloc_in_bytes,num_gpus,per_device_train_batch_size,torch_dtype,train_loss,train_runtime,train_samples_per_second,train_steps_per_second,train_tokens_per_second +0.99,aadp-padding-free,57757.0,39178506752,28985203712,2,8,float16,0.3572933577722119,126.1795,15.85,0.491,2196.443 +0.99,aadp-padding-free,36992.0,28042799104,14501724672,4,4,float16,0.3559871873547954,92.9618,21.514,0.667,1521.163 +0.99,aadp-padding-free,27808.0,22495587328,7258869248,8,2,float16,0.369813451843877,80.6181,24.808,0.769,829.479 +0.98,aadp-padding-free,59847.0,41375433216,28984412672,2,16,float16,0.38225665976924283,106.7912,18.728,0.29,2550.876 +0.98,aadp-padding-free,40257.0,28841733632,14502653440,4,8,float16,0.36831535447028374,67.6621,29.559,0.458,2081.609 +0.98,aadp-padding-free,28417.75,22607035392,7258897920,8,4,float16,0.3666688684494265,50.3504,39.722,0.616,1418.042 +1.0,aadp-padding-free,68532.0,43333948928,28984432640,2,32,float16,0.429565966129303,97.3702,20.54,0.164,2764.615 +1.0,aadp-padding-free,42162.5,29314396160,14501763584,4,16,float16,0.40204794704914093,58.3195,34.294,0.274,2279.048 +1.0,aadp-padding-free,30864.75,22761271296,7260338176,8,8,float16,0.3985021114349365,37.2954,53.626,0.429,1804.189 +1.0,aadp-padding-free,73184.0,51346142720,28984764416,2,64,float16,0.46318376064300537,91.1072,21.952,0.088,3152.439 +1.0,aadp-padding-free,44459.5,32477526016,14501908480,4,32,float16,0.4389604926109314,50.8064,39.365,0.157,2707.945 +1.0,aadp-padding-free,33015.25,24019485696,7258975232,8,16,float16,0.42861950397491455,31.1059,64.296,0.257,2176.948 +1.0,aadp-padding-free-multipack,54208.0,38979052032,28985210880,2,8,float16,0.3598624064372136,123.3006,16.221,0.527,2277.158 +0.99,aadp-padding-free-multipack,36105.5,27772687360,14501707776,4,4,float16,0.36395706763634317,89.9515,22.234,0.723,1551.302 +0.99,aadp-padding-free-multipack,27499.25,22185776128,7259904512,8,2,float16,0.35203131789066755,73.3563,27.264,0.832,867.069 +1.0,aadp-padding-free-multipack,57615.0,39973031424,28984641536,2,16,float16,0.3937762454152107,104.3178,19.172,0.307,2675.209 +1.0,aadp-padding-free-multipack,36844.5,28088351232,14502502400,4,8,float16,0.3769803624600172,62.2141,32.147,0.514,2232.001 +1.0,aadp-padding-free-multipack,27506.0,22341642240,7259199488,8,4,float16,0.371551351621747,47.017,42.538,0.681,1463.427 +0.97,aadp-padding-free-multipack,60039.0,42691656704,28984248320,2,32,float16,0.4251274108886719,87.5607,22.841,0.171,3094.163 +0.97,aadp-padding-free-multipack,38446.5,29109392384,14502311936,4,16,float16,0.41165995597839355,49.5359,40.375,0.303,2733.896 +0.97,aadp-padding-free-multipack,28621.0,22654561792,7259273216,8,8,float16,0.4068267265955607,30.0739,66.503,0.499,2239.884 +1.0,aadp-padding-free-multipack,65043.0,48633901056,28984248320,2,64,float16,0.43941691517829895,88.0836,22.706,0.091,3177.945 +0.93,aadp-padding-free-multipack,39892.5,31826244608,14501666816,4,32,float16,0.4199257918766567,41.1059,48.655,0.17,3189.563 +0.93,aadp-padding-free-multipack,30494.0,23677434368,7258887168,8,16,float16,0.4178782190595354,23.6115,84.704,0.296,2774.236 diff --git a/scripts/benchmarks/refs_orca/benchmarks.csv 
b/scripts/benchmarks/refs_orca/a100_80gb_pf.csv similarity index 100% rename from scripts/benchmarks/refs_orca/benchmarks.csv rename to scripts/benchmarks/refs_orca/a100_80gb_pf.csv diff --git a/scripts/benchmarks/refs_orca/requirements.txt b/scripts/benchmarks/refs_orca/requirements.txt index b021ef7c..f3f04288 100644 --- a/scripts/benchmarks/refs_orca/requirements.txt +++ b/scripts/benchmarks/refs_orca/requirements.txt @@ -26,10 +26,10 @@ fastparquet==2024.5.0 filelock==3.15.4 fire==0.6.0 flash-attn==2.6.3 --e git+https://github.com/foundation-model-stack/fms-acceleration.git@0fe0867656a01c9e030d77d8007c70fa775e5668#egg=fms_acceleration&subdirectory=plugins/framework --e git+https://github.com/foundation-model-stack/fms-acceleration.git@0fe0867656a01c9e030d77d8007c70fa775e5668#egg=fms_acceleration_aadp&subdirectory=plugins/attention-and-distributed-packing --e git+https://github.com/foundation-model-stack/fms-acceleration.git@0fe0867656a01c9e030d77d8007c70fa775e5668#egg=fms_acceleration_foak&subdirectory=plugins/fused-ops-and-kernels --e git+https://github.com/foundation-model-stack/fms-acceleration.git@0fe0867656a01c9e030d77d8007c70fa775e5668#egg=fms_acceleration_peft&subdirectory=plugins/accelerated-peft +-e git+https://github.com/foundation-model-stack/fms-acceleration.git@8bc65cd3e0f83a7786e4716216fbe1bea702313b#egg=fms_acceleration&subdirectory=plugins/framework +-e git+https://github.com/foundation-model-stack/fms-acceleration.git@8bc65cd3e0f83a7786e4716216fbe1bea702313b#egg=fms_acceleration_aadp&subdirectory=plugins/attention-and-distributed-packing +-e git+https://github.com/foundation-model-stack/fms-acceleration.git@8bc65cd3e0f83a7786e4716216fbe1bea702313b#egg=fms_acceleration_foak&subdirectory=plugins/fused-ops-and-kernels +-e git+https://github.com/foundation-model-stack/fms-acceleration.git@8bc65cd3e0f83a7786e4716216fbe1bea702313b#egg=fms_acceleration_peft&subdirectory=plugins/accelerated-peft -e git+https://github.com/foundation-model-stack/fms-hf-tuning.git@a8ab68ffaa0d3b49aeb6753bccfdf807672eba69#egg=fms_hf_tuning fonttools==4.53.1 frozenlist==1.4.1 @@ -46,6 +46,7 @@ joblib==1.4.2 jupyter_client==8.6.2 jupyter_core==5.7.2 kiwisolver==1.4.5 +llvmlite==0.43.0 markdown-it-py==3.0.0 MarkupSafe==2.1.5 matplotlib==3.9.1.post1 @@ -57,6 +58,7 @@ multiprocess==0.70.16 nest-asyncio==1.6.0 networkx==3.3 ninja==1.11.1.1 +numba==0.60.0 numpy==1.26.4 nvidia-cublas-cu12==12.1.3.1 nvidia-cuda-cupti-cu12==12.1.105 @@ -112,7 +114,7 @@ torch==2.4.0 tornado==6.4.1 tqdm==4.66.5 traitlets==5.14.3 -transformers==4.42.4 +transformers==4.44.0 triton==3.0.0 trl==0.9.6 typing_extensions==4.12.2 diff --git a/scripts/benchmarks/scenarios-orca.yaml b/scripts/benchmarks/scenarios-orca.yaml index 5c435852..df6ac097 100644 --- a/scripts/benchmarks/scenarios-orca.yaml +++ b/scripts/benchmarks/scenarios-orca.yaml @@ -51,6 +51,7 @@ scenarios: - name: padding-free framework_config: - aadp-padding-free + - aadp-padding-free-multipack arguments: learning_rate: 2e-5 torch_dtype: float16 diff --git a/scripts/generate_sample_configurations.py b/scripts/generate_sample_configurations.py index b778c3c7..c72c62eb 100644 --- a/scripts/generate_sample_configurations.py +++ b/scripts/generate_sample_configurations.py @@ -146,6 +146,7 @@ def read_configuration(path: str) -> Dict: KEY_AUTO_GPTQ_FOAK = "auto-gptq-foak" KEY_BNB_NF4_FOAK = "bnb-nf4-foak" KEY_AADP_PADDING_FREE = "aadp-padding-free" +KEY_AADP_MULTIPACK = "aadp-multipack" CONFIGURATIONS = { 
KEY_AUTO_GPTQ: "plugins/accelerated-peft/configs/autogptq.yaml", @@ -168,7 +169,8 @@ def read_configuration(path: str) -> Dict: "plugins/fused-ops-and-kernels/configs/fast_quantized_peft.yaml", [("peft.quantization.fused_ops_and_kernels.base_layer", "bitsandbytes")], ), - KEY_AADP_PADDING_FREE: "plugins/attention-and-distributed-packing/configs/aadp.yaml", + KEY_AADP_PADDING_FREE: "plugins/attention-and-distributed-packing/configs/padding_free.yaml", + KEY_AADP_MULTIPACK: "plugins/attention-and-distributed-packing/configs/multipack.yaml", } # list of (tag, combi) tuples @@ -187,6 +189,7 @@ def read_configuration(path: str) -> Dict: ("accelerated-peft-bnb-nf4-padding-free", (KEY_AADP_PADDING_FREE,KEY_BNB_NF4)), ("accelerated-peft-autogptq-foak-padding-free", (KEY_AADP_PADDING_FREE,KEY_AUTO_GPTQ, KEY_AUTO_GPTQ_FOAK)), ("accelerated-peft-bnb-nf4-foak-padding-free", (KEY_AADP_PADDING_FREE,KEY_BNB_NF4, KEY_BNB_NF4_FOAK)), + ("aadp-padding-free-multipack", (KEY_AADP_PADDING_FREE, KEY_AADP_MULTIPACK)), ] diff --git a/scripts/run_benchmarks.sh b/scripts/run_benchmarks.sh index 71c68fc0..f1d5d172 100644 --- a/scripts/run_benchmarks.sh +++ b/scripts/run_benchmarks.sh @@ -41,10 +41,11 @@ NO_OVERWRITE=${NO_OVERWRITE:-"false"} MEMORY_LOGGING=${MEMORY_LOGGING:-"all"} # inputs -NUM_GPUS_MATRIX=${1-"1 2"} -RESULT_DIR=${2:-"benchmark_outputs"} -SCENARIOS_CONFIG=${3:-$SCENARIOS_CONFIG} -SCENARIOS_FILTER=${4-$SCNTAG_PEFT_AUTOGPTQ} +NUM_GPUS_MATRIX=${1:-"1 2"} +EFFECTIVE_BS_MATRIX=${2:-"4 8"} +RESULT_DIR=${3:-"benchmark_outputs"} +SCENARIOS_CONFIG=${4:-$SCENARIOS_CONFIG} +SCENARIOS_FILTER=${5-$SCNTAG_PEFT_AUTOGPTQ} echo "NUM_GPUS_MATRIX: $NUM_GPUS_MATRIX" echo "RESULT_DIR: $RESULT_DIR" @@ -108,6 +109,7 @@ fi PYTHONPATH=. \ python $WORKING_DIR/benchmark.py \ --num_gpus $NUM_GPUS_MATRIX \ + --effective_batch_size_matrix $EFFECTIVE_BS_MATRIX \ --scenarios_config_path $SCENARIOS_CONFIG \ --accelerate_config $ACCELERATE_CONFIG \ --defaults_config_path $DEFAULTS_CONFIG \ diff --git a/tox.ini b/tox.ini index f512639c..52f9bdb3 100644 --- a/tox.ini +++ b/tox.ini @@ -41,7 +41,7 @@ commands = python -m fms_acceleration.cli install -e {toxinidir}/plugins/attention_and_distributed_packing # run the benchmark script - bash scripts/run_benchmarks.sh {posargs:"1 2" benchmark_outputs} + bash scripts/run_benchmarks.sh {posargs:"1 2" "4 8" benchmark_outputs} allowlist_externals = bash
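Taken together, the new tests in `plugins/framework/tests/test_accel_patcher.py` exercise two registration patterns on `AcceleratorPatcher`: a direct component replacement (the data collator) and a builder-based replacement (the dataloader, with `skip_prepare`). The sketch below restates that flow outside the test harness, assuming only the API surface visible above (`AcceleratorPatcher.replace`, `AcceleratorPatcher.patch`, `AcceleratorPatcherComponent`); the rule ids, `ToyDataset`, `custom_collator`, and `build_dataloader` are placeholder names for illustration, not part of the PR.

```python
# Minimal sketch of the AcceleratorPatcher flow exercised by the new
# tests in plugins/framework/tests/test_accel_patcher.py.
# Assumptions: the fms_acceleration framework plugin, accelerate and torch
# are installed; ToyDataset, the rule ids and helpers are placeholders.

# Third Party
from accelerate import Accelerator
import torch

# First Party
from fms_acceleration.accelerator_patcher import (
    AcceleratorPatcher,
    AcceleratorPatcherComponent,
)


class ToyDataset(torch.utils.data.Dataset):
    "tiny map-style dataset used only for illustration"

    def __len__(self):
        return 4

    def __getitem__(self, idx):
        return torch.zeros(2)


def custom_collator(features):
    "placeholder collator; a real one would pack or pad the features"
    return torch.stack(features)


# Pattern 1: simple replacement of a component (here, the data collator)
AcceleratorPatcher.replace(
    rule_id="example-collator-replacement",
    component=AcceleratorPatcherComponent.data_collator,
    replacement=custom_collator,
)


# Pattern 2: replacement via a builder that receives the original dataloader
# and the accelerator, with prepare skipped for the dataloader it builds
def build_dataloader(dataloader, accelerator):
    return torch.utils.data.DataLoader(
        dataloader.dataset, batch_size=2, collate_fn=dataloader.collate_fn
    )


AcceleratorPatcher.replace(
    rule_id="example-dataloader-builder",
    component=AcceleratorPatcherComponent.data_loader,
    replacement_builder=build_dataloader,
    skip_prepare=True,
)

# patch the accelerator once, then call prepare as usual; the registered
# rules take effect inside the patched prepare
accelerator = Accelerator()
AcceleratorPatcher.patch(accelerator)
dataloader = accelerator.prepare(torch.utils.data.DataLoader(ToyDataset()))
```

Registering both rules in one session mirrors the intended padding-free plus multipack combination, where the collator and the dataloader are each swapped by their own rule.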
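Likewise, the parametrized cases in `test_combine_mp_triggers_produces_correct_output` reduce to the following use of `combine_triggers`; this is a sketch, and the lambda check plus the modules passed to `is_triggered` are illustrative stand-ins rather than fixtures from the PR. Only `"AND"` and `"OR"` logic are accepted.

```python
# Sketch of combining ModelPatcherTriggers with OR logic, as exercised by
# test_combine_mp_triggers_produces_correct_output. The lambda check is an
# illustrative stand-in for a custom callable check.

# Third Party
import torch

# First Party
from fms_acceleration.model_patcher import ModelPatcherTrigger, combine_triggers

combined = combine_triggers(
    ModelPatcherTrigger(check=torch.nn.Linear),               # module-type check
    ModelPatcherTrigger(check=lambda m: hasattr(m, "bias")),  # callable check
    logic="OR",
)

assert combined.is_triggered(torch.nn.Linear(4, 4))    # satisfies the first check
assert not combined.is_triggered(torch.nn.Identity())  # satisfies neither check
```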