diff --git a/docs/source/en/trainer.md b/docs/source/en/trainer.md
index 65bfa4176dd2..3d57220fe827 100644
--- a/docs/source/en/trainer.md
+++ b/docs/source/en/trainer.md
@@ -252,6 +252,136 @@ trainer = Trainer(..., args=training_args)
 
 NEFTune is disabled after training to restore the original embedding layer to avoid any unexpected behavior.
 
+## GaLore
+
+Gradient Low-Rank Projection (GaLore) is a memory-efficient low-rank training strategy that allows full-parameter learning while using less memory than common low-rank adaptation methods, such as LoRA.
+
+First make sure to install the official GaLore package:
+
+```bash
+pip install galore-torch
+```
+
+Then simply add one of `["galore_adamw", "galore_adafactor", "galore_adamw_8bit"]` in `optim` together with `optim_target_modules`, which can be a list of strings, regexes, or full paths corresponding to the target module names you want to adapt. Below is an end-to-end example script (make sure to `pip install trl datasets`):
+
+```python
+import torch
+import datasets
+import trl
+
+from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM
+
+train_dataset = datasets.load_dataset('imdb', split='train')
+
+args = TrainingArguments(
+    output_dir="./test-galore",
+    max_steps=100,
+    per_device_train_batch_size=2,
+    optim="galore_adamw",
+    optim_target_modules=["attn", "mlp"]
+)
+
+model_id = "google/gemma-2b"
+
+config = AutoConfig.from_pretrained(model_id)
+
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+model = AutoModelForCausalLM.from_config(config).to(0)
+
+trainer = trl.SFTTrainer(
+    model=model,
+    args=args,
+    train_dataset=train_dataset,
+    dataset_text_field='text',
+    max_seq_length=512,
+)
+
+trainer.train()
+```
+
+To pass extra arguments supported by GaLore, pass them through `optim_args`, for example:
+
+```python
+import torch
+import datasets
+import trl
+
+from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM
+
+train_dataset = datasets.load_dataset('imdb', split='train')
+
+args = TrainingArguments(
+    output_dir="./test-galore",
+    max_steps=100,
+    per_device_train_batch_size=2,
+    optim="galore_adamw",
+    optim_target_modules=["attn", "mlp"],
+    optim_args="rank=64, update_proj_gap=100, scale=0.10",
+)
+
+model_id = "google/gemma-2b"
+
+config = AutoConfig.from_pretrained(model_id)
+
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+model = AutoModelForCausalLM.from_config(config).to(0)
+
+trainer = trl.SFTTrainer(
+    model=model,
+    args=args,
+    train_dataset=train_dataset,
+    dataset_text_field='text',
+    max_seq_length=512,
+)
+
+trainer.train()
+```
+
+You can read more about the method in the [original repository](https://github.com/jiaweizzhao/GaLore) or the [paper](https://arxiv.org/abs/2403.03507).
+
+Currently, only linear layers are considered GaLore layers and are trained with low-rank decomposition, while the remaining layers are optimized in the conventional manner.
+
+Note that it will take a bit of time before training starts (~3 minutes for a 2B model on an NVIDIA A100), but training should go smoothly afterwards.
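The examples above select target modules with plain substrings. As a minimal sketch (based on the patterns exercised by this PR's tests rather than the documentation above), `optim_target_modules` can also be given regexes or, per the Trainer changes further down, the special string `"all-linear"` to target every linear layer:

```python
# Sketch only: alternative ways to select GaLore target modules, mirroring the
# regex and "all-linear" patterns used in this PR's tests.
from transformers import TrainingArguments

# 1) Regexes: each pattern is matched against the full module name.
args_regex = TrainingArguments(
    output_dir="./test-galore",
    max_steps=100,
    per_device_train_batch_size=2,
    optim="galore_adamw",
    optim_target_modules=[r".*attn.*", r".*mlp.*"],
)

# 2) "all-linear": shorthand assumed to target every nn.Linear layer in the model.
args_all_linear = TrainingArguments(
    output_dir="./test-galore",
    max_steps=100,
    per_device_train_batch_size=2,
    optim="galore_adafactor",
    optim_target_modules="all-linear",
)
```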
+
+You can also perform layer-wise optimization by appending `layerwise` to the optimizer name, as shown below:
+
+```python
+import torch
+import datasets
+import trl
+
+from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM
+
+train_dataset = datasets.load_dataset('imdb', split='train')
+
+args = TrainingArguments(
+    output_dir="./test-galore",
+    max_steps=100,
+    per_device_train_batch_size=2,
+    optim="galore_adamw_layerwise",
+    optim_target_modules=["attn", "mlp"]
+)
+
+model_id = "google/gemma-2b"
+
+config = AutoConfig.from_pretrained(model_id)
+
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+model = AutoModelForCausalLM.from_config(config).to(0)
+
+trainer = trl.SFTTrainer(
+    model=model,
+    args=args,
+    train_dataset=train_dataset,
+    dataset_text_field='text',
+    max_seq_length=512,
+)
+
+trainer.train()
+```
+
+Note that layer-wise optimization is somewhat experimental and does not support DDP (Distributed Data Parallel), so you can only run the training script on a single GPU. Please see [this section of the GaLore repository](https://github.com/jiaweizzhao/GaLore?tab=readme-ov-file#train-7b-model-with-a-single-gpu-with-24gb-memory) for more details. Other features such as gradient clipping, DeepSpeed, etc. might not be supported out of the box. Please [raise an issue on GitHub](https://github.com/huggingface/transformers/issues) if you encounter such an issue.
+
 ## Accelerate and Trainer
 
 The [`Trainer`] class is powered by [Accelerate](https://hf.co/docs/accelerate), a library for easily training PyTorch models in distributed environments with support for integrations such as [FullyShardedDataParallel (FSDP)](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/) and [DeepSpeed](https://www.deepspeed.ai/).
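For the layer-wise variants, the Trainer changes below implement the update by creating one optimizer per parameter and stepping it from a post-accumulate-grad hook, with `LayerWiseDummyOptimizer` and `LayerWiseDummyScheduler` standing in for the usual objects. The snippet below is a minimal, standalone sketch of that mechanism on a toy model, assuming PyTorch >= 2.1 for `register_post_accumulate_grad_hook`; it is illustrative and not the actual Trainer code:

```python
# Illustrative sketch of the layer-wise trick used in this PR: one optimizer per
# parameter, stepped from a hook that fires after that parameter's gradient has
# been fully accumulated. Requires PyTorch >= 2.1.
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))

# One optimizer per parameter; in the real integration the GaLore parameters get
# extra kwargs (rank, update_proj_gap, scale, proj_type) in their param group.
optimizer_dict = {p: torch.optim.AdamW([p], lr=1e-3) for p in model.parameters()}

def optimizer_hook(param):
    # Step and reset this parameter's optimizer as soon as its gradient is ready,
    # so no global optimizer.step() is needed at the end of the iteration.
    if param.grad is not None:
        optimizer_dict[param].step()
        optimizer_dict[param].zero_grad()

for param in model.parameters():
    param.register_post_accumulate_grad_hook(optimizer_hook)

# Training step: only backward() is called explicitly; the hooks do the updates.
inputs = torch.randn(8, 16)
targets = torch.randint(0, 4, (8,))
loss = nn.functional.cross_entropy(model(inputs), targets)
loss.backward()
```

Because each parameter is updated as soon as its own gradient is ready, there is no single global `optimizer.step()`, which is also why gradient accumulation and DDP are rejected for the layer-wise optimizers in the Trainer changes below.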
diff --git a/src/transformers/optimization.py b/src/transformers/optimization.py index 65a41d1b1a44..1c76ddda9f0b 100644 --- a/src/transformers/optimization.py +++ b/src/transformers/optimization.py @@ -24,6 +24,7 @@ from torch.optim import Optimizer from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau +from .trainer_pt_utils import LayerWiseDummyOptimizer, LayerWiseDummyScheduler from .trainer_utils import SchedulerType from .utils import logging from .utils.versions import require_version @@ -362,6 +363,32 @@ def get_scheduler( """ name = SchedulerType(name) schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] + + # If a `LayerWiseDummyOptimizer` is passed we extract the optimizer dict and + # recursively call `get_scheduler` to get the proper schedulers on each parameter + if optimizer is not None and isinstance(optimizer, LayerWiseDummyOptimizer): + optimizer_dict = optimizer.optimizer_dict + scheduler_dict = {} + + for param in optimizer_dict.keys(): + scheduler_dict[param] = get_scheduler( + name, + optimizer=optimizer_dict[param], + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + ) + + def scheduler_hook(param): + # Since the optimizer hook has been already attached we only need to + # attach the scheduler hook + if param.grad is not None: + scheduler_dict[param].step() + + for param in optimizer_dict.keys(): + param.register_post_accumulate_grad_hook(scheduler_hook) + + return LayerWiseDummyScheduler() + if name == SchedulerType.CONSTANT: return schedule_func(optimizer) diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 5caf23f6ddff..8b7814163739 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -70,6 +70,7 @@ is_fsdp_available, is_ftfy_available, is_g2p_en_available, + is_galore_torch_available, is_ipex_available, is_jieba_available, is_jinja_available, @@ -325,6 +326,14 @@ def require_bs4(test_case): return unittest.skipUnless(is_bs4_available(), "test requires BeautifulSoup4")(test_case) +def require_galore_torch(test_case): + """ + Decorator marking a test that requires GaLore. These tests are skipped when GaLore isn't installed. + https://github.com/jiaweizzhao/GaLore + """ + return unittest.skipUnless(is_galore_torch_available(), "test requires GaLore")(test_case) + + def require_cv2(test_case): """ Decorator marking a test that requires OpenCV. diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index f40645d3ac5f..bef4b24c517c 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -83,6 +83,7 @@ DistributedTensorGatherer, IterableDatasetShard, LabelSmoother, + LayerWiseDummyOptimizer, LengthGroupedSampler, SequentialDistributedSampler, distributed_broadcast_scalars, @@ -111,6 +112,7 @@ RemoveColumnsCollator, TrainerMemoryTracker, TrainOutput, + check_target_module_exists, default_compute_objective, denumpify_detensorize, enable_full_determinism, @@ -141,6 +143,7 @@ is_apex_available, is_bitsandbytes_available, is_datasets_available, + is_galore_torch_available, is_in_notebook, is_ipex_available, is_peft_available, @@ -1010,7 +1013,17 @@ def create_optimizer(self): }, ] - optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args) + optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args, opt_model) + + # Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs` + # e.g. for GaLore optimizer. 
+ if "params" in optimizer_kwargs: + optimizer_grouped_parameters = optimizer_kwargs.pop("params") + + # For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict` + # to avoid arguments conflicts. + if "optimizer_dict" in optimizer_kwargs: + optimizer_grouped_parameters = optimizer_kwargs.pop("optimizer_dict") self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) if optimizer_cls.__name__ == "Adam8bit": @@ -1033,7 +1046,9 @@ def create_optimizer(self): return self.optimizer @staticmethod - def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]: + def get_optimizer_cls_and_kwargs( + args: TrainingArguments, model: Optional[PreTrainedModel] = None + ) -> Tuple[Any, Any]: """ Returns the optimizer class and optimizer parameters based on the training arguments. @@ -1171,6 +1186,132 @@ def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]: optimizer_cls = torch.optim.Adagrad elif args.optim == OptimizerNames.RMSPROP: optimizer_cls = torch.optim.RMSprop + elif args.optim in [ + OptimizerNames.GALORE_ADAMW, + OptimizerNames.GALORE_ADAMW_8BIT, + OptimizerNames.GALORE_ADAFACTOR, + OptimizerNames.GALORE_ADAMW_LAYERWISE, + OptimizerNames.GALORE_ADAMW_8BIT_LAYERWISE, + OptimizerNames.GALORE_ADAFACTOR_LAYERWISE, + ]: + if not is_galore_torch_available(): + raise ImportError( + "You need to install `galore_torch` in order to use GaLore optimizers" + " install it with `pip install git+https://github.com/jiaweizzhao/GaLore`" + ) + from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit + + is_layerwise = args.optim.lower().endswith("layerwise") + if is_layerwise and args.parallel_mode == ParallelMode.DISTRIBUTED: + raise NotImplementedError("Layer-wise GaLore does not support DDP at this time") + + optimizer_mapping = { + OptimizerNames.GALORE_ADAMW: GaLoreAdamW, + OptimizerNames.GALORE_ADAMW_8BIT: GaLoreAdamW8bit, + OptimizerNames.GALORE_ADAFACTOR: GaLoreAdafactor, + OptimizerNames.GALORE_ADAMW_LAYERWISE: GaLoreAdamW, + OptimizerNames.GALORE_ADAMW_8BIT_LAYERWISE: GaLoreAdamW8bit, + OptimizerNames.GALORE_ADAFACTOR_LAYERWISE: GaLoreAdafactor, + } + + optimizer_cls = optimizer_mapping[args.optim] + + if args.optim_target_modules is None: + raise ValueError( + "You need to define a `optim_target_modules` in order to properly use GaLore optimizers" + ) + + if not isinstance(args.optim_target_modules, (list, str)): + raise ValueError( + f"`optim_target_modules` has to be a list of strings, a string corresponding to a regex, or a specific module or 'all-linear', you passed {args.optim_target_modules}" + ) + + if model is None: + raise ValueError("You need to pass a model in order to correctly initialize a GaLore optimizer.") + + logger.warning( + "Activated GaLoRE fine-tuning, depending on your model size and hardware, the training might take a while before starting. Please be patient !" 
+ ) + + all_linear = ( + isinstance(args.optim_target_modules, str) + and args.optim_target_modules.replace("_", "-") == "all-linear" + ) + + galore_params = [] + galore_params_names = [] + for module_name, module in model.named_modules(): + target_module_exists, is_regex = check_target_module_exists( + args.optim_target_modules, module_name, return_is_regex=True + ) + + if not isinstance(module, nn.Linear): + # Warn in case we match but it's not a linear layer + if target_module_exists and not is_regex: + logger.warning( + f"{module_name} has been matched but ignored as GaLore only supports linear layers. Please double check your `optim_target_modules`!" + ) + + continue + + if not target_module_exists and not all_linear: + continue + + galore_params.append(module.weight) + galore_params_names.append(module_name + ".weight") + + if len(galore_params) == 0: + raise ValueError( + f"None of the target modules were found! ({args.optim_target_modules}). Please make sure to pass a valid `target_modules`." + ) + + non_galore_params = [p for n, p in model.named_parameters() if n not in galore_params_names] + + galore_optim_kwargs = { + "rank": int(optim_args.pop("rank", 128)), + "update_proj_gap": int(optim_args.pop("update_proj_gap", 200)), + "scale": float(optim_args.pop("scale", 0.25)), + "proj_type": optim_args.pop("proj_type", "std"), + } + + # The default args are from the official repository: https://github.com/jiaweizzhao/GaLore + param_groups = [ + {"params": non_galore_params}, + {"params": galore_params, **galore_optim_kwargs}, + ] + + if is_layerwise: + # For layer-wise optimizers, the optimization step is done through post accumulation + # gradient hooks. The trick is to first attach these hooks to the model parameters then + # create a dummy optimizer that will perform no-ops in the Trainer. 
+ # See the original implementation or the nice implementation from @hiyouga + # here: https://github.com/hiyouga/LLaMA-Factory/commit/8664262cde3919e10eaecbd66e8c5d356856362e#diff-ebe08ab14496dfb9e06075f0fdd36799ef6d1535cc4dd4715b74c4e3e06fe3ba + if args.gradient_accumulation_steps != 1: + raise ValueError("Layerwise GaLoRE optimizer do not support gradient accumulation !") + + optimizer_dict = {} + for param in non_galore_params: + param_groups = [{"params": [param]}] + optimizer_dict[param] = optimizer_cls(param_groups, **optimizer_kwargs) + for param in galore_params: + param_groups = [{"params": [param], **galore_optim_kwargs}] + optimizer_dict[param] = optimizer_cls(param_groups, **optimizer_kwargs) + + def optimizer_hook(param): + if param.grad is not None: + optimizer_dict[param].step() + optimizer_dict[param].zero_grad() + + for param in model.parameters(): + param.register_post_accumulate_grad_hook(optimizer_hook) + + optimizer_cls = LayerWiseDummyOptimizer + optimizer_kwargs.update({"optimizer_dict": optimizer_dict}) + + optimizer_kwargs.update({"params": param_groups}) + + if args.optim == OptimizerNames.GALORE_ADAFACTOR: + optimizer_kwargs.update({"scale_parameter": False, "relative_step": False}) else: raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}") return optimizer_cls, optimizer_kwargs diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py index 34d2c8416b59..394d29411d4d 100644 --- a/src/transformers/trainer_pt_utils.py +++ b/src/transformers/trainer_pt_utils.py @@ -34,6 +34,7 @@ import torch import torch.distributed as dist from torch import nn +from torch.optim.lr_scheduler import LRScheduler from torch.utils.data import Dataset, IterableDataset, RandomSampler, Sampler from torch.utils.data.distributed import DistributedSampler @@ -1226,3 +1227,47 @@ def from_json_file(cls, json_file): def to_dict(self): return copy.deepcopy(self.__dict__) + + +class LayerWiseDummyOptimizer(torch.optim.Optimizer): + """ + For Layer-wise optimizers such as GaLoRE optimizer, the optimization + step is already done through the post gradient hooks. Therefore + the trick is to create a dummy optimizer that can take arbitrary + args and kwargs and return a no-op during training. + + Initial idea from @hiyouga in LLaMA-Factory: + https://github.com/hiyouga/LLaMA-Factory/commit/8664262cde3919e10eaecbd66e8c5d356856362e#diff-ebe08ab14496dfb9e06075f0fdd36799ef6d1535cc4dd4715b74c4e3e06fe3ba + """ + + def __init__(self, optimizer_dict=None, *args, **kwargs): + dummy_tensor = torch.randn(1, 1) + self.optimizer_dict = optimizer_dict + super().__init__([dummy_tensor], {"lr": 1e-03}) + + def zero_grad(self, set_to_none: bool = True) -> None: + pass + + def step(self, closure=None) -> Optional[float]: + pass + + +class LayerWiseDummyScheduler(LRScheduler): + """ + For Layer-wise optimizers such as GaLoRE optimizer, the optimization and scheduling step + are already done through the post gradient hooks. Therefore + the trick is to create a dummy scheduler that can take arbitrary + args and kwargs and return a no-op during training. 
+ """ + + def __init__(self, *args, **kwargs): + optimizer = LayerWiseDummyOptimizer() + last_epoch = -1 + verbose = False + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + return [group["lr"] for group in self.optimizer.param_groups] + + def _get_closed_form_lr(self): + return self.base_lrs diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py index 5d528317e54f..0faf657387ba 100644 --- a/src/transformers/trainer_utils.py +++ b/src/transformers/trainer_utils.py @@ -785,3 +785,42 @@ def _remove_columns(self, feature: dict) -> dict: def __call__(self, features: List[dict]): features = [self._remove_columns(feature) for feature in features] return self.data_collator(features) + + +def check_target_module_exists(optim_target_modules, key: str, return_is_regex: bool = False): + """A helper method to check if the passed module's key name matches any of the target modules in the optim_target_modules. + + Args: + optim_target_modules (`Union[str, List[str]]`): + A list of strings to try to match. Can be also a full string. + key (`str`): + A key to search any matches in optim_target_modules + return_is_regex (`bool`): + If set to `True`, the method will return whether the passed `optim_target_modules` + is a regex or not. + + Returns: + `bool` : True of match object if key matches any target modules from config, False or + None if no match found + `bool` : If the matched target module is a regex to silence out the warnings in Trainer + for extra modules being found (only if `target_module_found=True` for an array of regex). + """ + target_module_found = False + is_regex = False + + if isinstance(optim_target_modules, str): + target_module_found = bool(re.fullmatch(optim_target_modules, key)) + is_regex = True if not optim_target_modules == key else False + elif key in optim_target_modules: # from here, target_module_found must be a list of str + # this module is specified directly in target_modules + target_module_found = True + elif any(target_key in key for target_key in optim_target_modules): + target_module_found = True + elif any(bool(re.fullmatch(optim_target_module, key)) for optim_target_module in optim_target_modules): + target_module_found = True + is_regex = True + + if return_is_regex: + return target_module_found, is_regex + + return target_module_found diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 54cd045b2005..a52a77e9a766 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -164,6 +164,12 @@ class OptimizerNames(ExplicitEnum): RMSPROP_BNB = "rmsprop_bnb" RMSPROP_8BIT = "rmsprop_bnb_8bit" RMSPROP_32BIT = "rmsprop_bnb_32bit" + GALORE_ADAMW = "galore_adamw" + GALORE_ADAMW_8BIT = "galore_adamw_8bit" + GALORE_ADAFACTOR = "galore_adafactor" + GALORE_ADAMW_LAYERWISE = "galore_adamw_layerwise" + GALORE_ADAMW_8BIT_LAYERWISE = "galore_adamw_8bit_layerwise" + GALORE_ADAFACTOR_LAYERWISE = "galore_adafactor_layerwise" # TODO: `TrainingArguments` users rely on it being fully mutable. In the future see if we can narrow this to a few keys: https://github.com/huggingface/transformers/pull/25903 @@ -696,6 +702,12 @@ class TrainingArguments: for instruction fine-tuning. Check out the [original paper](https://arxiv.org/abs/2310.05914) and the [original code](https://github.com/neelsjain/NEFTune). Support transformers `PreTrainedModel` and also `PeftModel` from peft. 
+ optim_target_modules (`Union[str, List[str]]`, *optional*): + The target modules to optimize, i.e. the module names that you would like to train, right now this is used only for GaLore algorithm + https://arxiv.org/abs/2403.03507 + See: https://github.com/jiaweizzhao/GaLore for more details. You need to make sure to pass a valid GaloRe + optimizer, e.g. one of: "galore_adamw", "galore_adamw_8bit", "galore_adafactor" and make sure that the target modules are `nn.Linear` modules + only. """ framework = "pt" @@ -1354,6 +1366,13 @@ class TrainingArguments: }, ) + optim_target_modules: Union[None, str, List[str]] = field( + default=None, + metadata={ + "help": "Target modules for the optimizer defined in the `optim` argument. Only used for the GaLore optimizer at the moment." + }, + ) + def __post_init__(self): # expand paths, if not os.makedirs("~/bar") will make directory # in the current directory instead of the actual home diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index d6b170bcd87d..b8da221a8c91 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -125,6 +125,7 @@ is_fsdp_available, is_ftfy_available, is_g2p_en_available, + is_galore_torch_available, is_in_notebook, is_ipex_available, is_jieba_available, diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 894ac00c6df2..3835831e88a4 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -95,6 +95,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ _apex_available = _is_package_available("apex") _aqlm_available = _is_package_available("aqlm") _bitsandbytes_available = _is_package_available("bitsandbytes") +_galore_torch_available = _is_package_available("galore_torch") # `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed. 
_bs4_available = importlib.util.find_spec("bs4") is not None _coloredlogs_available = _is_package_available("coloredlogs") @@ -309,6 +310,10 @@ def is_torchvision_available(): return _torchvision_available +def is_galore_torch_available(): + return _galore_torch_available + + def is_pyctcdecode_available(): return _pyctcdecode_available diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index bd704bc8b59e..ebc628146b96 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -60,6 +60,7 @@ require_accelerate, require_bitsandbytes, require_deepspeed, + require_galore_torch, require_intel_extension_for_pytorch, require_optuna, require_peft, @@ -84,7 +85,7 @@ slow, torch_device, ) -from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, HPSearchBackend +from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, HPSearchBackend, check_target_module_exists from transformers.training_args import OptimizerNames from transformers.utils import ( SAFE_WEIGHTS_INDEX_NAME, @@ -114,6 +115,8 @@ GPT2Config, GPT2LMHeadModel, LineByLineTextDataset, + LlamaConfig, + LlamaForCausalLM, PreTrainedModel, Trainer, TrainerState, @@ -146,6 +149,31 @@ def __getitem__(self, i): return result +# Converting Bytes to Megabytes +def bytes2megabytes(x): + return int(x / 2**20) + + +# Copied from acclerate: https://github.com/huggingface/accelerate/blob/ee163b66fb7848892519e804688cb4ae981aacbe/src/accelerate/test_utils/scripts/external_deps/test_peak_memory_usage.py#L40C1-L73C68 +class TorchTracemalloc: + def __enter__(self): + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() # reset the peak gauge to zero + self.begin = torch.cuda.memory_allocated() + return self + + def __exit__(self, *exc): + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + self.end = torch.cuda.memory_allocated() + self.peak = torch.cuda.max_memory_allocated() + self.used = bytes2megabytes(self.end - self.begin) + self.peaked = bytes2megabytes(self.peak - self.begin) + + @dataclasses.dataclass class RegressionTrainingArguments(TrainingArguments): a: float = 0.0 @@ -1069,6 +1097,293 @@ def test_dataloader_without_dataset(self): trainer.train() trainer.evaluate() + def test_galore_matched_modules(self): + regex_patterns = [r".*.attn.*", r".*.mlp.*"] + + module_names = [ + "model.transformer.h.0.ln_1", + "model.transformer.h.0.attn.q_proj", + "model.lm_head", + "model.transformer.h.0.mlp.up_proj", + ] + expected_values = [False, True, False, True] + + for expected_value, module_name in zip(expected_values, module_names): + is_module_matched, is_regex = check_target_module_exists(regex_patterns, module_name, return_is_regex=True) + self.assertTrue(is_module_matched == expected_value) + if is_module_matched: + self.assertTrue(is_regex) + + exact_patterns = ["q_proj", "up_proj"] + + module_names = [ + "model.transformer.h.0.ln_1", + "model.transformer.h.0.attn.q_proj", + "model.lm_head", + "model.transformer.h.0.mlp.up_proj", + ] + expected_values = [False, True, False, True] + + for expected_value, module_name in zip(expected_values, module_names): + is_module_matched, is_regex = check_target_module_exists(exact_patterns, module_name, return_is_regex=True) + self.assertTrue(is_module_matched == expected_value) + if is_module_matched: + self.assertFalse(is_regex) + + simple_regex = r".*.attn.*" + + module_names = [ + "model.transformer.h.0.ln_1", + "model.transformer.h.0.attn.q_proj", + "model.lm_head", + 
"model.transformer.h.0.mlp.up_proj", + ] + expected_values = [False, True, False, False] + + for expected_value, module_name in zip(expected_values, module_names): + is_module_matched, is_regex = check_target_module_exists(simple_regex, module_name, return_is_regex=True) + self.assertTrue(is_module_matched == expected_value) + if is_module_matched: + self.assertTrue(is_regex) + + simple_regex = "model.transformer.h.0.attn.q_proj" + + module_names = [ + "model.transformer.h.0.ln_1", + "model.transformer.h.0.attn.q_proj", + "model.lm_head", + "model.transformer.h.0.mlp.up_proj", + ] + expected_values = [False, True, False, False] + + for expected_value, module_name in zip(expected_values, module_names): + is_module_matched, is_regex = check_target_module_exists(simple_regex, module_name, return_is_regex=True) + self.assertTrue(is_module_matched == expected_value) + if is_module_matched: + self.assertFalse(is_regex) + + target_modules = ["attn", "mlp"] + + module_names = [ + "model.transformer.h.0.ln_1", + "model.transformer.h.0.attn.q_proj", + "model.lm_head", + "model.transformer.h.0.mlp.up_proj", + ] + expected_values = [False, True, False, True] + + for expected_value, module_name in zip(expected_values, module_names): + is_module_matched, is_regex = check_target_module_exists(target_modules, module_name, return_is_regex=True) + self.assertTrue(is_module_matched == expected_value) + if is_module_matched: + self.assertFalse(is_regex) + + @require_galore_torch + @require_torch_gpu + def test_galore(self): + config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4) + tiny_llama = LlamaForCausalLM(config) + x = torch.randint(0, 100, (128,)) + train_dataset = RepeatDataset(x) + + with tempfile.TemporaryDirectory() as tmpdir: + # Trainer without inf/nan filter + args = TrainingArguments( + tmpdir, + learning_rate=1e-9, + logging_steps=5, + optim="galore_adamw", + optim_target_modules=[r".*attn.*", r".*mlp.*"], + ) + trainer = Trainer(tiny_llama, args, train_dataset=train_dataset) + + # Check this works + _ = trainer.train() + + @require_galore_torch + @require_torch_gpu + def test_galore_extra_args(self): + config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4) + tiny_llama = LlamaForCausalLM(config) + x = torch.randint(0, 100, (128,)) + train_dataset = RepeatDataset(x) + + with tempfile.TemporaryDirectory() as tmpdir: + # Trainer without inf/nan filter + args = TrainingArguments( + tmpdir, + learning_rate=1e-9, + logging_steps=5, + optim="galore_adamw", + optim_args="rank=64, update_proj_gap=100, scale=0.10", + optim_target_modules=[r".*attn.*", r".*mlp.*"], + ) + trainer = Trainer(tiny_llama, args, train_dataset=train_dataset) + + # Check this works + _ = trainer.train() + + @require_galore_torch + @require_torch_gpu + def test_galore_layerwise(self): + config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4) + tiny_llama = LlamaForCausalLM(config) + x = torch.randint(0, 100, (128,)) + train_dataset = RepeatDataset(x) + + with tempfile.TemporaryDirectory() as tmpdir: + # Trainer without inf/nan filter + args = TrainingArguments( + tmpdir, + learning_rate=1e-9, + logging_steps=5, + optim="galore_adamw_layerwise", + optim_target_modules=[r".*attn.*", r".*mlp.*"], + ) + trainer = Trainer(tiny_llama, args, train_dataset=train_dataset) + + # Check this works + _ = trainer.train() + + @require_galore_torch + @require_torch_gpu + def test_galore_layerwise_with_scheduler(self): + 
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4) + tiny_llama = LlamaForCausalLM(config) + x = torch.randint(0, 100, (128,)) + train_dataset = RepeatDataset(x) + + with tempfile.TemporaryDirectory() as tmpdir: + # Trainer without inf/nan filter + args = TrainingArguments( + tmpdir, + learning_rate=1e-9, + logging_steps=5, + optim="galore_adamw_layerwise", + lr_scheduler_type="cosine", + optim_target_modules=[r".*attn.*", r".*mlp.*"], + ) + trainer = Trainer(tiny_llama, args, train_dataset=train_dataset) + + # Check this works + _ = trainer.train() + + @require_galore_torch + @require_torch_gpu + def test_galore_adamw_8bit(self): + config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4) + tiny_llama = LlamaForCausalLM(config) + x = torch.randint(0, 100, (128,)) + train_dataset = RepeatDataset(x) + + with tempfile.TemporaryDirectory() as tmpdir: + # Trainer without inf/nan filter + args = TrainingArguments( + tmpdir, + learning_rate=1e-9, + logging_steps=5, + optim="galore_adamw_8bit", + optim_target_modules=[r".*attn.*", r".*mlp.*"], + ) + trainer = Trainer(tiny_llama, args, train_dataset=train_dataset) + + # Check this works + _ = trainer.train() + + @require_galore_torch + @require_torch_gpu + def test_galore_adafactor(self): + # These are the intervals of the peak memory usage of training such a tiny model + # if the peak memory goes outside that range, then we know there might be a bug somewhere + upper_bound_pm = 700 + lower_bound_pm = 650 + + config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4) + tiny_llama = LlamaForCausalLM(config) + x = torch.randint(0, 100, (128,)) + train_dataset = RepeatDataset(x) + + with tempfile.TemporaryDirectory() as tmpdir, TorchTracemalloc() as tracemalloc: + # Trainer without inf/nan filter + args = TrainingArguments( + tmpdir, + learning_rate=1e-9, + logging_steps=5, + optim="galore_adafactor", + optim_target_modules=[r".*attn.*", r".*mlp.*"], + ) + trainer = Trainer(tiny_llama, args, train_dataset=train_dataset) + + # Check this works + _ = trainer.train() + + galore_peak_memory = tracemalloc.peaked + bytes2megabytes(tracemalloc.begin) + + self.assertTrue(galore_peak_memory < upper_bound_pm) + self.assertTrue(lower_bound_pm < galore_peak_memory) + + @require_galore_torch + @require_torch_gpu + def test_galore_adafactor_attention_only(self): + # These are the intervals of the peak memory usage of training such a tiny model + # if the peak memory goes outside that range, then we know there might be a bug somewhere + upper_bound_pm = 700 + lower_bound_pm = 650 + + config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4) + tiny_llama = LlamaForCausalLM(config) + x = torch.randint(0, 100, (128,)) + train_dataset = RepeatDataset(x) + + with tempfile.TemporaryDirectory() as tmpdir, TorchTracemalloc() as tracemalloc: + # Trainer without inf/nan filter + args = TrainingArguments( + tmpdir, + learning_rate=1e-9, + logging_steps=5, + optim="galore_adafactor", + optim_target_modules=["q_proj", "k_proj", "v_proj"], + ) + trainer = Trainer(tiny_llama, args, train_dataset=train_dataset) + + # Check this works + _ = trainer.train() + + galore_peak_memory = tracemalloc.peaked + bytes2megabytes(tracemalloc.begin) + self.assertTrue(galore_peak_memory < upper_bound_pm) + self.assertTrue(lower_bound_pm < galore_peak_memory) + + @require_galore_torch + @require_torch_gpu + def 
test_galore_adafactor_all_linear(self): + # These are the intervals of the peak memory usage of training such a tiny model + # if the peak memory goes outside that range, then we know there might be a bug somewhere + upper_bound_pm = 700 + lower_bound_pm = 650 + + config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4) + tiny_llama = LlamaForCausalLM(config) + x = torch.randint(0, 100, (128,)) + train_dataset = RepeatDataset(x) + + with tempfile.TemporaryDirectory() as tmpdir, TorchTracemalloc() as tracemalloc: + # Trainer without inf/nan filter + args = TrainingArguments( + tmpdir, + learning_rate=1e-9, + logging_steps=5, + optim="galore_adafactor", + optim_target_modules="all-linear", + ) + trainer = Trainer(tiny_llama, args, train_dataset=train_dataset) + + # Check this works + _ = trainer.train() + + galore_peak_memory = tracemalloc.peaked + bytes2megabytes(tracemalloc.begin) + self.assertTrue(galore_peak_memory < upper_bound_pm) + self.assertTrue(lower_bound_pm < galore_peak_memory) + @require_torch_multi_accelerator def test_data_is_not_parallelized_when_model_is_parallel(self): model = RegressionModel()