
FEAT / Optim: Add GaLore optimizer #29588

Merged Mar 19, 2024 · 44 commits

Changes from 1 commit

Commits
b31ce79
add galore v1
younesbelkada Mar 11, 2024
58169f1
add import
younesbelkada Mar 11, 2024
9032635
add tests and doc
younesbelkada Mar 11, 2024
136f104
fix doctest
younesbelkada Mar 11, 2024
a5483b3
forward contrib credits from discussions
Mar 11, 2024
887d3ad
forward contrib credits from discussions
Mar 11, 2024
d6f119f
Apply suggestions from code review
younesbelkada Mar 11, 2024
3fae229
Merge remote-tracking branch 'upstream/main' into HEAD
younesbelkada Mar 11, 2024
c8c50f8
fix failing tests'
younesbelkada Mar 11, 2024
2bdda68
Merge remote-tracking branch 'upstream/main' into add-galore-optimizer
younesbelkada Mar 13, 2024
630bd13
switch to `optim_target_modules` and clarify docs
younesbelkada Mar 13, 2024
a871b75
more clarification
younesbelkada Mar 13, 2024
cb6cd7e
Merge remote-tracking branch 'upstream/main' into add-galore-optimizer
younesbelkada Mar 13, 2024
51b7b29
enhance lookup logic
younesbelkada Mar 13, 2024
3da3b90
update a test to add peak memory
younesbelkada Mar 13, 2024
9115c94
add regex, all-linear and single string support
younesbelkada Mar 13, 2024
0b4ba83
add layer-wise optimization through DummyOptimizers and LRSchedulers
younesbelkada Mar 13, 2024
3e5930e
forward contrib credits from discussions and original idea
hiyouga Mar 13, 2024
a16d3a8
add a section about DDP not supported in layerwise
younesbelkada Mar 13, 2024
29e7e94
Update src/transformers/trainer.py
younesbelkada Mar 13, 2024
18ea144
fix self
younesbelkada Mar 13, 2024
7800bf1
check only if layer_wise
younesbelkada Mar 13, 2024
e022bdd
Update src/transformers/training_args.py
younesbelkada Mar 14, 2024
830c68d
oops
younesbelkada Mar 14, 2024
b640e98
make use of intervals
younesbelkada Mar 14, 2024
14a89b2
clarify comment
younesbelkada Mar 14, 2024
6f7102d
add matching tests
younesbelkada Mar 14, 2024
c11cb63
GaLoRe -> GaLore
younesbelkada Mar 14, 2024
3678201
move to `get_scheduler`
younesbelkada Mar 14, 2024
fdc4b2a
add note on docs
younesbelkada Mar 14, 2024
e7ce9b7
add a warning
younesbelkada Mar 14, 2024
91d6436
adapt a bit the docs
younesbelkada Mar 15, 2024
b9e338a
update docstring
younesbelkada Mar 15, 2024
6ff3762
support original API
younesbelkada Mar 17, 2024
0d0440a
Update docs/source/en/trainer.md
younesbelkada Mar 17, 2024
832f2be
slightly refactor
younesbelkada Mar 18, 2024
898a3c5
Update docs/source/en/trainer.md
younesbelkada Mar 18, 2024
ed3ad4a
Update src/transformers/training_args.py
younesbelkada Mar 19, 2024
57e7096
fix args parsing and add tests
younesbelkada Mar 19, 2024
64ccfa6
remove warning for regex
younesbelkada Mar 19, 2024
4413f07
Merge remote-tracking branch 'upstream/main' into add-galore-optimizer
younesbelkada Mar 19, 2024
73dcabb
fix type hint
younesbelkada Mar 19, 2024
1987b7a
add note about extra args
younesbelkada Mar 19, 2024
db2bf21
make `is_regex` return optional
younesbelkada Mar 19, 2024
51 changes: 51 additions & 0 deletions docs/source/en/trainer.md
@@ -252,6 +252,57 @@ trainer = Trainer(..., args=training_args)

NEFTune is disabled after training in order to restore the original embedding layer and avoid any unexpected behavior.

## GaLore

Gradient Low-Rank Projection (GaLore) is a memory-efficient low-rank training strategy that allows full-parameter learning while being more memory-efficient than common low-rank adaptation methods such as LoRA.

First, make sure to install GaLore from its official repository:

```bash
pip install git+https://github.com/jiaweizzhao/GaLore
```

> Review comment (Contributor): GaLore has released an official package: `pip install galore-torch`
> https://github.com/jiaweizzhao/GaLore?tab=readme-ov-file#install-galore-optimizer

Then simply pass one of `["galore_adamw", "galore_adafactor", "galore_adamw_8bit"]` as `optim`, together with `galore_target_modules`, a list of strings corresponding to the target module names you want to adapt. Below is an end-to-end example script (make sure to `pip install trl datasets`):

```python
import torch
import datasets
import trl

from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM

train_dataset = datasets.load_dataset('imdb', split='train')

args = TrainingArguments(
output_dir="./test-galore",
max_steps=100,
per_device_train_batch_size=2,
optim="galore_adamw",
galore_target_modules=["attn", "mlp"]
)

model_id = "google/gemma-2b"

config = AutoConfig.from_pretrained(model_id)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_config(config).to(0)

trainer = trl.SFTTrainer(
model=model,
args=args,
train_dataset=train_dataset,
dataset_text_field='text',
max_seq_length=512,
)

trainer.train()
```

You can read more about the method in the [original repository](https://github.com/jiaweizzhao/GaLore) or the [paper](https://arxiv.org/abs/2403.03507).

Note that it will take a bit of time before training starts (~3 minutes for a 2B model on an NVIDIA A100), but training should proceed smoothly afterwards.
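
For intuition, here is a minimal sketch of the update rule described in the paper (not part of this diff and not the `galore-torch` implementation): the optimizer keeps its Adam moments in a low-rank subspace of the gradient, refreshes the projection from an SVD of the gradient every `update_proj_gap` steps, and projects the low-rank update back before applying it. Bias correction and the paper's scale factor are omitted for brevity, and all names here are illustrative.

```python
import torch


def galore_adam_step(param, grad, state, lr=1e-3, rank=4,
                     betas=(0.9, 0.999), eps=1e-8, update_proj_gap=200):
    """Illustrative GaLore-style Adam step for a single 2D weight (sketch only)."""
    step = state.get("step", 0)

    # Periodically refresh the low-rank projection from an SVD of the current gradient.
    if "proj" not in state or step % update_proj_gap == 0:
        U, _, _ = torch.linalg.svd(grad, full_matrices=False)
        state["proj"] = U[:, :rank]  # orthonormal basis, shape (m, rank)
    P = state["proj"]

    # Project the gradient into the low-rank subspace: shape (rank, n) instead of (m, n).
    g_lr = P.T @ grad

    # Adam moments live only in the low-rank space; this is where the memory is saved.
    m = state.setdefault("m", torch.zeros_like(g_lr))
    v = state.setdefault("v", torch.zeros_like(g_lr))
    m.mul_(betas[0]).add_(g_lr, alpha=1 - betas[0])
    v.mul_(betas[1]).addcmul_(g_lr, g_lr, value=1 - betas[1])

    # Project the low-rank update back to the full shape and apply it.
    param.data.add_(P @ (m / (v.sqrt() + eps)), alpha=-lr)
    state["step"] = step + 1
```

Because the moments are stored with shape `(rank, n)` instead of `(m, n)`, the optimizer-state memory for an `m x n` weight shrinks roughly by a factor of `m / rank`.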

## Accelerate and Trainer

The [`Trainer`] class is powered by [Accelerate](https://hf.co/docs/accelerate), a library for easily training PyTorch models in distributed environments with support for integrations such as [FullyShardedDataParallel (FSDP)](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/) and [DeepSpeed](https://www.deepspeed.ai/).
9 changes: 9 additions & 0 deletions src/transformers/testing_utils.py
@@ -70,6 +70,7 @@
    is_fsdp_available,
    is_ftfy_available,
    is_g2p_en_available,
    is_galore_torch_available,
    is_ipex_available,
    is_jieba_available,
    is_jinja_available,
@@ -324,6 +325,14 @@ def require_bs4(test_case):
    return unittest.skipUnless(is_bs4_available(), "test requires BeautifulSoup4")(test_case)


def require_galore_torch(test_case):
"""
Decorator marking a test that requires Galore. These tests are skipped when Galore isn't installed.
https://github.com/jiaweizzhao/GaLore
"""
return unittest.skipUnless(is_galore_torch_available(), "test requires Galore")(test_case)


def require_cv2(test_case):
"""
Decorator marking a test that requires OpenCV.
3 changes: 3 additions & 0 deletions src/transformers/trainer.py
@@ -1241,6 +1241,9 @@ def get_optimizer_cls_and_kwargs(
            ]

            optimizer_kwargs.update({"params": param_groups})

            if args.optim == OptimizerNames.GALORE_ADAFACTOR:
                optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
        else:
            raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
        return optimizer_cls, optimizer_kwargs
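
For context on the `param_groups` consumed in the hunk above: GaLore optimizers accept per-group low-rank hyperparameters, so one group holds the GaLore-targeted weights and another holds everything else. Below is a rough sketch following the pattern shown in the galore-torch README; the tiny model, the name filter, and the hyperparameter values are illustrative, not this PR's exact code.

```python
from transformers import LlamaConfig, LlamaForCausalLM

# Tiny stand-in model, mirroring the tests below; in the Trainer this would be the user's model.
model = LlamaForCausalLM(LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4))

# GaLore is applied to 2D weights whose names match the requested target modules;
# the name filter here is deliberately simplified.
galore_params = [
    p for n, p in model.named_parameters()
    if p.ndim == 2 and any(target in n for target in ["attn", "mlp"])
]
galore_ids = {id(p) for p in galore_params}
regular_params = [p for p in model.parameters() if id(p) not in galore_ids]

# The GaLore group carries the low-rank hyperparameters (rank, update_proj_gap,
# scale, proj_type), per the galore-torch README; the other group is optimized as usual.
param_groups = [
    {"params": regular_params},
    {"params": galore_params, "rank": 128, "update_proj_gap": 200, "scale": 0.25, "proj_type": "std"},
]
```

The actual implementation resolves the targeted modules from `galore_target_modules` (later renamed to `optim_target_modules` in this PR) rather than a hard-coded filter.
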
66 changes: 66 additions & 0 deletions tests/trainer/test_trainer.py
@@ -60,6 +60,7 @@
    require_accelerate,
    require_bitsandbytes,
    require_deepspeed,
    require_galore_torch,
    require_intel_extension_for_pytorch,
    require_optuna,
    require_peft,
@@ -114,6 +115,8 @@
    GPT2Config,
    GPT2LMHeadModel,
    LineByLineTextDataset,
    LlamaConfig,
    LlamaForCausalLM,
    PreTrainedModel,
    Trainer,
    TrainerState,
@@ -1069,6 +1072,69 @@ def test_dataloader_without_dataset(self):
        trainer.train()
        trainer.evaluate()

    @require_galore_torch
    def test_galore(self):
        config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
        tiny_llama = LlamaForCausalLM(config)
        x = torch.randint(0, 100, (128,))
        train_dataset = RepeatDataset(x)

        with tempfile.TemporaryDirectory() as tmpdir:
            # Trainer without inf/nan filter
            args = TrainingArguments(
                tmpdir,
                learning_rate=1e-9,
                logging_steps=5,
                optim="galore_adamw",
                galore_target_modules=["attn", "mlp"],
            )
            trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)

            # Check this works
            _ = trainer.train()

    @require_galore_torch
    def test_galore_adamw_8bit(self):
        config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
        tiny_llama = LlamaForCausalLM(config)
        x = torch.randint(0, 100, (128,))
        train_dataset = RepeatDataset(x)

        with tempfile.TemporaryDirectory() as tmpdir:
            # Trainer without inf/nan filter
            args = TrainingArguments(
                tmpdir,
                learning_rate=1e-9,
                logging_steps=5,
                optim="galore_adamw_8bit",
                galore_target_modules=["attn", "mlp"],
            )
            trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)

            # Check this works
            _ = trainer.train()

    @require_galore_torch
    def test_galore_adafactor(self):
        config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
        tiny_llama = LlamaForCausalLM(config)
        x = torch.randint(0, 100, (128,))
        train_dataset = RepeatDataset(x)

        with tempfile.TemporaryDirectory() as tmpdir:
            # Trainer without inf/nan filter
            args = TrainingArguments(
                tmpdir,
                learning_rate=1e-9,
                logging_steps=5,
                optim="galore_adafactor",
                galore_target_modules=["attn", "mlp"],
            )
            trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)

            # Check this works
            _ = trainer.train()

    @require_torch_multi_accelerator
    def test_data_is_not_parallelized_when_model_is_parallel(self):
        model = RegressionModel()