FEAT / Optim: Add GaLore optimizer (#29588)
* add galore v1

* add import

* add tests and doc

* fix doctest

* forward contrib credits from discussions

* forward contrib credits from discussions

* Apply suggestions from code review

Co-authored-by: Zach Mueller <muellerzr@gmail.com>

* fix failing tests

* switch to `optim_target_modules` and clarify docs

* more clarification

* enhance lookup logic

* update a test to add peak memory

* add regex, all-linear and single string support

* add layer-wise optimization through DummyOptimizers and LRSchedulers

* forward contrib credits from discussions and original idea

* add a section about DDP not supported in layerwise

* Update src/transformers/trainer.py

Co-authored-by: Zach Mueller <muellerzr@gmail.com>

* fix self

* check only if layer_wise

* Update src/transformers/training_args.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* oops

* make use of intervals

* clarify comment

* add matching tests

* GaLoRe -> GaLore

* move to `get_scheduler`

* add note on docs

* add a warning

* adapt a bit the docs

* update docstring

* support original API

* Update docs/source/en/trainer.md

* slightly refactor

* Update docs/source/en/trainer.md

Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>

* Update src/transformers/training_args.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* fix args parsing and add tests

* remove warning for regex

* fix type hint

* add note about extra args

* make `is_regex` return optional

---------

Co-authored-by: Maxime <maximegmd@users.noreply.github.com>
Co-authored-by: Wing Lian <winglian@users.noreply.github.com>
Co-authored-by: Zach Mueller <muellerzr@gmail.com>
Co-authored-by: hiyouga <hiyouga@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
7 people authored Mar 19, 2024
1 parent 484e10f commit f6261d7
Showing 10 changed files with 734 additions and 3 deletions.
130 changes: 130 additions & 0 deletions docs/source/en/trainer.md
@@ -252,6 +252,136 @@ trainer = Trainer(..., args=training_args)

NEFTune is disabled after training to restore the original embedding layer to avoid any unexpected behavior.

## GaLore

Gradient Low-Rank Projection (GaLore) is a memory-efficient low-rank training strategy that allows full-parameter learning while being more memory-efficient than common low-rank adaptation methods such as LoRA.

First make sure to install GaLore from its official repository:

```bash
pip install galore-torch
```

Then simply pass one of `["galore_adamw", "galore_adafactor", "galore_adamw_8bit"]` as `optim`, together with `optim_target_modules`, which can be a list of strings, a regex, or a full path corresponding to the target module names you want to adapt. Below is an end-to-end example script (make sure to `pip install trl datasets`):

```python
import torch
import datasets
import trl

from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM

train_dataset = datasets.load_dataset('imdb', split='train')

args = TrainingArguments(
output_dir="./test-galore",
max_steps=100,
per_device_train_batch_size=2,
optim="galore_adamw",
optim_target_modules=["attn", "mlp"]
)

model_id = "google/gemma-2b"

config = AutoConfig.from_pretrained(model_id)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_config(config).to(0)

trainer = trl.SFTTrainer(
model=model,
args=args,
train_dataset=train_dataset,
dataset_text_field='text',
max_seq_length=512,
)

trainer.train()
```
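
The values shown above match module names by substring. This commit also adds regex, single-string, and `all-linear` matching for `optim_target_modules`; the snippet below is a hedged sketch of those variants (the regex pattern itself is illustrative), not additional examples from the documentation:

```python
from transformers import TrainingArguments

# Match module names against a regex instead of a list of substrings
# (the exact pattern here is illustrative).
args_regex = TrainingArguments(
    output_dir="./test-galore",
    max_steps=100,
    per_device_train_batch_size=2,
    optim="galore_adamw",
    optim_target_modules=r".*\.attn\..*|.*\.mlp\..*",
)

# Treat every linear layer in the model as a GaLore layer.
args_all_linear = TrainingArguments(
    output_dir="./test-galore",
    max_steps=100,
    per_device_train_batch_size=2,
    optim="galore_adamw",
    optim_target_modules="all-linear",
)
```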

To pass extra arguments supported by GaLore, pass them through `optim_args`, for example:

```python
import torch
import datasets
import trl

from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM

train_dataset = datasets.load_dataset('imdb', split='train')

args = TrainingArguments(
output_dir="./test-galore",
max_steps=100,
per_device_train_batch_size=2,
optim="galore_adamw",
optim_target_modules=["attn", "mlp"],
optim_args="rank=64, update_proj_gap=100, scale=0.10",
)

model_id = "google/gemma-2b"

config = AutoConfig.from_pretrained(model_id)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_config(config).to(0)

trainer = trl.SFTTrainer(
model=model,
args=args,
train_dataset=train_dataset,
dataset_text_field='text',
max_seq_length=512,
)

trainer.train()
```

You can read more about the method in the [original repository](https://github.com/jiaweizzhao/GaLore) or the [paper](https://arxiv.org/abs/2403.03507).

Currently only Linear layers are considered GaLore layers and trained with low-rank decomposition, while the remaining layers are optimized in the conventional manner.

Note it will take a bit of time before the training starts (~3 minutes for a 2B model on an NVIDIA A100), but training should go smoothly afterwards.

You can also perform layer-wise optimization by appending `layerwise` to the optimizer name, as below:

```python
import torch
import datasets
import trl

from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM

train_dataset = datasets.load_dataset('imdb', split='train')

args = TrainingArguments(
output_dir="./test-galore",
max_steps=100,
per_device_train_batch_size=2,
optim="galore_adamw_layerwise",
optim_target_modules=["attn", "mlp"]
)

model_id = "google/gemma-2b"

config = AutoConfig.from_pretrained(model_id)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_config(config).to(0)

trainer = trl.SFTTrainer(
model=model,
args=args,
train_dataset=train_dataset,
dataset_text_field='text',
max_seq_length=512,
)

trainer.train()
```

Note that layer-wise optimization is somewhat experimental and does not support DDP (Distributed Data Parallel), so you can run the training script only on a single GPU. Please see [the appropriate section of the GaLore README](https://github.com/jiaweizzhao/GaLore?tab=readme-ov-file#train-7b-model-with-a-single-gpu-with-24gb-memory) for more details. Other features such as gradient clipping, DeepSpeed, etc. might not be supported out of the box. Please [raise an issue on GitHub](https://github.com/huggingface/transformers/issues) if you encounter such an issue.

## Accelerate and Trainer

The [`Trainer`] class is powered by [Accelerate](https://hf.co/docs/accelerate), a library for easily training PyTorch models in distributed environments with support for integrations such as [FullyShardedDataParallel (FSDP)](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/) and [DeepSpeed](https://www.deepspeed.ai/).
27 changes: 27 additions & 0 deletions src/transformers/optimization.py
@@ -24,6 +24,7 @@
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau

from .trainer_pt_utils import LayerWiseDummyOptimizer, LayerWiseDummyScheduler
from .trainer_utils import SchedulerType
from .utils import logging
from .utils.versions import require_version
@@ -362,6 +363,32 @@ def get_scheduler(
"""
name = SchedulerType(name)
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]

# If a `LayerWiseDummyOptimizer` is passed we extract the optimizer dict and
# recursively call `get_scheduler` to get the proper schedulers on each parameter
if optimizer is not None and isinstance(optimizer, LayerWiseDummyOptimizer):
optimizer_dict = optimizer.optimizer_dict
scheduler_dict = {}

for param in optimizer_dict.keys():
scheduler_dict[param] = get_scheduler(
name,
optimizer=optimizer_dict[param],
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps,
)

def scheduler_hook(param):
# Since the optimizer hook has been already attached we only need to
# attach the scheduler hook
if param.grad is not None:
scheduler_dict[param].step()

for param in optimizer_dict.keys():
param.register_post_accumulate_grad_hook(scheduler_hook)

return LayerWiseDummyScheduler()

if name == SchedulerType.CONSTANT:
return schedule_func(optimizer)

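The hunk above attaches one scheduler per parameter through post-accumulate-grad hooks and hands the `Trainer` a `LayerWiseDummyScheduler`. The snippet below is a minimal, self-contained sketch of that hook mechanism (not part of the diff; it uses a plain `AdamW` and a constant `LambdaLR` schedule purely for illustration and assumes PyTorch >= 2.1 for `register_post_accumulate_grad_hook`):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))

# One real optimizer and scheduler per parameter.
optimizer_dict = {p: torch.optim.AdamW([p], lr=1e-3) for p in model.parameters()}
scheduler_dict = {
    p: torch.optim.lr_scheduler.LambdaLR(opt, lambda _: 1.0)
    for p, opt in optimizer_dict.items()
}

def hook(param):
    # Runs right after the gradient for `param` has been accumulated,
    # so each parameter is stepped independently during the backward pass.
    if param.grad is not None:
        optimizer_dict[param].step()
        optimizer_dict[param].zero_grad()
        scheduler_dict[param].step()

for param in model.parameters():
    param.register_post_accumulate_grad_hook(hook)

# A single backward pass now updates every parameter; the training loop only
# ever sees dummy optimizer/scheduler objects, as in the diff above.
loss = model(torch.randn(4, 8)).sum()
loss.backward()
```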
9 changes: 9 additions & 0 deletions src/transformers/testing_utils.py
@@ -70,6 +70,7 @@
is_fsdp_available,
is_ftfy_available,
is_g2p_en_available,
is_galore_torch_available,
is_ipex_available,
is_jieba_available,
is_jinja_available,
@@ -325,6 +326,14 @@ def require_bs4(test_case):
return unittest.skipUnless(is_bs4_available(), "test requires BeautifulSoup4")(test_case)


def require_galore_torch(test_case):
"""
Decorator marking a test that requires GaLore. These tests are skipped when GaLore isn't installed.
https://github.com/jiaweizzhao/GaLore
"""
return unittest.skipUnless(is_galore_torch_available(), "test requires GaLore")(test_case)


def require_cv2(test_case):
"""
Decorator marking a test that requires OpenCV.
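A hypothetical usage of the new decorator (the test class and its body are illustrative, not the actual tests added by this commit):

```python
import unittest

from transformers.testing_utils import require_galore_torch


@require_galore_torch
class GaLoreOptimizerTest(unittest.TestCase):
    def test_galore_import(self):
        # The whole class is skipped automatically when `galore_torch`
        # is not installed.
        import galore_torch  # noqa: F401
```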
145 changes: 143 additions & 2 deletions src/transformers/trainer.py
@@ -83,6 +83,7 @@
DistributedTensorGatherer,
IterableDatasetShard,
LabelSmoother,
LayerWiseDummyOptimizer,
LengthGroupedSampler,
SequentialDistributedSampler,
distributed_broadcast_scalars,
@@ -111,6 +112,7 @@
RemoveColumnsCollator,
TrainerMemoryTracker,
TrainOutput,
check_target_module_exists,
default_compute_objective,
denumpify_detensorize,
enable_full_determinism,
@@ -141,6 +143,7 @@
is_apex_available,
is_bitsandbytes_available,
is_datasets_available,
is_galore_torch_available,
is_in_notebook,
is_ipex_available,
is_peft_available,
@@ -1010,7 +1013,17 @@ def create_optimizer(self):
},
]

optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args, opt_model)

# Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
# e.g. for GaLore optimizer.
if "params" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop("params")

# For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
# to avoid arguments conflicts.
if "optimizer_dict" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop("optimizer_dict")

self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
if optimizer_cls.__name__ == "Adam8bit":
@@ -1033,7 +1046,9 @@ return self.optimizer
return self.optimizer

@staticmethod
def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:
def get_optimizer_cls_and_kwargs(
args: TrainingArguments, model: Optional[PreTrainedModel] = None
) -> Tuple[Any, Any]:
"""
Returns the optimizer class and optimizer parameters based on the training arguments.
@@ -1171,6 +1186,132 @@ def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:
optimizer_cls = torch.optim.Adagrad
elif args.optim == OptimizerNames.RMSPROP:
optimizer_cls = torch.optim.RMSprop
elif args.optim in [
OptimizerNames.GALORE_ADAMW,
OptimizerNames.GALORE_ADAMW_8BIT,
OptimizerNames.GALORE_ADAFACTOR,
OptimizerNames.GALORE_ADAMW_LAYERWISE,
OptimizerNames.GALORE_ADAMW_8BIT_LAYERWISE,
OptimizerNames.GALORE_ADAFACTOR_LAYERWISE,
]:
if not is_galore_torch_available():
raise ImportError(
"You need to install `galore_torch` in order to use GaLore optimizers"
" install it with `pip install git+https://github.com/jiaweizzhao/GaLore`"
)
from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit

is_layerwise = args.optim.lower().endswith("layerwise")
if is_layerwise and args.parallel_mode == ParallelMode.DISTRIBUTED:
raise NotImplementedError("Layer-wise GaLore does not support DDP at this time")

optimizer_mapping = {
OptimizerNames.GALORE_ADAMW: GaLoreAdamW,
OptimizerNames.GALORE_ADAMW_8BIT: GaLoreAdamW8bit,
OptimizerNames.GALORE_ADAFACTOR: GaLoreAdafactor,
OptimizerNames.GALORE_ADAMW_LAYERWISE: GaLoreAdamW,
OptimizerNames.GALORE_ADAMW_8BIT_LAYERWISE: GaLoreAdamW8bit,
OptimizerNames.GALORE_ADAFACTOR_LAYERWISE: GaLoreAdafactor,
}

optimizer_cls = optimizer_mapping[args.optim]

if args.optim_target_modules is None:
raise ValueError(
"You need to define a `optim_target_modules` in order to properly use GaLore optimizers"
)

if not isinstance(args.optim_target_modules, (list, str)):
raise ValueError(
f"`optim_target_modules` has to be a list of strings, a string corresponding to a regex, or a specific module or 'all-linear', you passed {args.optim_target_modules}"
)

if model is None:
raise ValueError("You need to pass a model in order to correctly initialize a GaLore optimizer.")

logger.warning(
"Activated GaLoRE fine-tuning, depending on your model size and hardware, the training might take a while before starting. Please be patient !"
)

all_linear = (
isinstance(args.optim_target_modules, str)
and args.optim_target_modules.replace("_", "-") == "all-linear"
)

galore_params = []
galore_params_names = []
for module_name, module in model.named_modules():
target_module_exists, is_regex = check_target_module_exists(
args.optim_target_modules, module_name, return_is_regex=True
)

if not isinstance(module, nn.Linear):
# Warn in case we match but it's not a linear layer
if target_module_exists and not is_regex:
logger.warning(
f"{module_name} has been matched but ignored as GaLore only supports linear layers. Please double check your `optim_target_modules`!"
)

continue

if not target_module_exists and not all_linear:
continue

galore_params.append(module.weight)
galore_params_names.append(module_name + ".weight")

if len(galore_params) == 0:
raise ValueError(
f"None of the target modules were found! ({args.optim_target_modules}). Please make sure to pass a valid `target_modules`."
)

non_galore_params = [p for n, p in model.named_parameters() if n not in galore_params_names]

galore_optim_kwargs = {
"rank": int(optim_args.pop("rank", 128)),
"update_proj_gap": int(optim_args.pop("update_proj_gap", 200)),
"scale": float(optim_args.pop("scale", 0.25)),
"proj_type": optim_args.pop("proj_type", "std"),
}

# The default args are from the official repository: https://github.com/jiaweizzhao/GaLore
param_groups = [
{"params": non_galore_params},
{"params": galore_params, **galore_optim_kwargs},
]

if is_layerwise:
# For layer-wise optimizers, the optimization step is done through post accumulation
# gradient hooks. The trick is to first attach these hooks to the model parameters then
# create a dummy optimizer that will perform no-ops in the Trainer.
# See the original implementation or the nice implementation from @hiyouga
# here: https://github.com/hiyouga/LLaMA-Factory/commit/8664262cde3919e10eaecbd66e8c5d356856362e#diff-ebe08ab14496dfb9e06075f0fdd36799ef6d1535cc4dd4715b74c4e3e06fe3ba
if args.gradient_accumulation_steps != 1:
raise ValueError("Layerwise GaLoRE optimizer do not support gradient accumulation !")

optimizer_dict = {}
for param in non_galore_params:
param_groups = [{"params": [param]}]
optimizer_dict[param] = optimizer_cls(param_groups, **optimizer_kwargs)
for param in galore_params:
param_groups = [{"params": [param], **galore_optim_kwargs}]
optimizer_dict[param] = optimizer_cls(param_groups, **optimizer_kwargs)

def optimizer_hook(param):
if param.grad is not None:
optimizer_dict[param].step()
optimizer_dict[param].zero_grad()

for param in model.parameters():
param.register_post_accumulate_grad_hook(optimizer_hook)

optimizer_cls = LayerWiseDummyOptimizer
optimizer_kwargs.update({"optimizer_dict": optimizer_dict})

optimizer_kwargs.update({"params": param_groups})

if args.optim == OptimizerNames.GALORE_ADAFACTOR:
optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
else:
raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
return optimizer_cls, optimizer_kwargs
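The GaLore defaults above are popped from an `optim_args` dict that `get_optimizer_cls_and_kwargs` builds earlier from the `optim_args` string (e.g. `"rank=64, update_proj_gap=100, scale=0.10"` in the docs); that parsing is outside this diff, so the sketch below is only a rough equivalent of what it is expected to do:

```python
def parse_optim_args(optim_args_str: str) -> dict:
    # Rough equivalent of the Trainer's `optim_args` parsing (illustrative only).
    optim_args = {}
    if optim_args_str:
        for mapping in optim_args_str.replace(" ", "").split(","):
            key, value = mapping.split("=")
            optim_args[key] = value
    return optim_args


print(parse_optim_args("rank=64, update_proj_gap=100, scale=0.10"))
# {'rank': '64', 'update_proj_gap': '100', 'scale': '0.10'}; the GaLore branch
# then casts these with int()/float() when building `galore_optim_kwargs`.
```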