add RAdamScheduleFree optimizer #35313

Merged 6 commits on Feb 12, 2025. The diff below shows changes from 5 commits.
docs/source/en/trainer.md (15 changes: 10 additions & 5 deletions)

````diff
@@ -544,13 +544,17 @@ trainer = Trainer(
 trainer.train()
 ```
 
-This script demonstrates how to fine-tune the `google/gemma-2b` model on the IMDB dataset using the GrokAdamW optimizer. The `TrainingArguments` are configured to use GrokAdamW, and the dataset is passed to the `Trainer` for training.
+This script demonstrates how to fine-tune the [google/gemma-2b](https://huggingface.co/google/gemma-2b) model on the IMDB dataset using the GrokAdamW optimizer. The `TrainingArguments` are configured to use GrokAdamW, and the dataset is passed to the `Trainer` for training.
 
-### Schedule Free Optimizer
+### Schedule-Free Optimizer
 
+The Schedule-Free optimizers were introduced in [The Road Less Scheduled](https://hf.co/papers/2405.15682).
+Supported optimizers for Schedule-Free are `schedule_free_radam`, `schedule_free_adamw`, and `schedule_free_sgd`. First install schedulefree from PyPI: `pip install schedulefree`.
+
-The Schedule Free optimizers have been introduced in [The Road Less Scheduled](https://hf.co/papers/2405.15682).
 Schedule-Free learning replaces the momentum of the base optimizer with a combination of averaging and interpolation, to completely remove the need to anneal the learning rate with a traditional schedule.
-Supported optimizers for SFO are `"schedule_free_adamw"` and `"schedule_free_sgd"`. First install schedulefree from pypi `pip install schedulefree`.
+Additionally, neither the `warmup_steps` nor the `warmup_ratio` parameter is required when using `schedule_free_radam`.
+
+By default, we recommend setting `lr_scheduler_type="constant"` in the `TrainingArguments`. Other `lr_scheduler_type` values would also work, but combining Schedule-Free with other learning rate schedules is not well studied in research or in practice, as it may affect the optimizer's intended behavior and performance guarantees.
 
 Below is a simple script to demonstrate how to fine-tune [google/gemma-2b](https://huggingface.co/google/gemma-2b) on the IMDB dataset in full precision:
 
@@ -566,7 +570,8 @@ args = TrainingArguments(
     output_dir="./test-schedulefree",
     max_steps=1000,
     per_device_train_batch_size=4,
-    optim="schedule_free_adamw",
+    optim="schedule_free_radam",
+    lr_scheduler_type="constant",
     gradient_checkpointing=True,
     logging_strategy="steps",
     logging_steps=1,
````
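For readers who want to run the updated docs snippet end to end, here is a self-contained sketch of the fine-tuning script the hunk above modifies. The `TrainingArguments` mirror the visible part of the diff; the dataset loading, tokenization, and data-collator lines are assumptions added so the example actually runs (they are not part of the PR), and it requires `schedulefree>=1.4.0` plus a transformers build that includes this change.

```python
import datasets
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

model_id = "google/gemma-2b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Tokenize IMDB reviews for causal-LM fine-tuning (assumed preprocessing, not from the PR).
train_dataset = datasets.load_dataset("imdb", split="train")
train_dataset = train_dataset.map(
    lambda batch: tokenizer(batch["text"], truncation=True, max_length=512),
    batched=True,
    remove_columns=train_dataset.column_names,
)
collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

args = TrainingArguments(
    output_dir="./test-schedulefree",
    max_steps=1000,
    per_device_train_batch_size=4,
    optim="schedule_free_radam",   # the optimizer added by this PR
    lr_scheduler_type="constant",  # let the schedule-free optimizer manage the schedule
    gradient_checkpointing=True,
    logging_strategy="steps",
    logging_steps=1,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    data_collator=collator,
)
trainer.train()
```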
src/transformers/trainer.py (24 changes: 20 additions & 4 deletions)

```diff
@@ -1624,28 +1624,44 @@ def optimizer_hook(param):
             optimizer_cls = AdamW4bit
             optimizer_kwargs.update(adam_kwargs)
         elif args.optim in [
+            OptimizerNames.SCHEDULE_FREE_RADAM,
             OptimizerNames.SCHEDULE_FREE_ADAMW,
             OptimizerNames.SCHEDULE_FREE_SGD,
         ]:
             if not is_schedulefree_available():
                 raise ImportError(
-                    "You need to install `schedulefree` in order to use schedulefree optimizers"
-                    " install it with `pip install schedulefree`"
+                    "You need to install `schedulefree` in order to use schedulefree optimizers. "
+                    "Install it with `pip install schedulefree`."
                 )
             if not is_accelerate_available("0.30.0"):
                 raise ImportError("You need to have `accelerate>=0.30.0` to be able to use schedulefree optimizers")
             from schedulefree import AdamWScheduleFree, SGDScheduleFree
 
             additional_optim_kwargs = {}
-            if args.optim == OptimizerNames.SCHEDULE_FREE_ADAMW:
+            require_warmup = True
+
+            if args.optim == OptimizerNames.SCHEDULE_FREE_RADAM:
+                if not is_schedulefree_available("1.4.0"):
+                    raise ImportError(
+                        "You need to install `schedulefree>=1.4.0` in order to use the RAdamScheduleFree optimizer. "
+                        "Install it with `pip install schedulefree`."
+                    )
+                from schedulefree import RAdamScheduleFree
+
+                optimizer_cls = RAdamScheduleFree
+                additional_optim_kwargs = adam_kwargs
+                require_warmup = False
+            elif args.optim == OptimizerNames.SCHEDULE_FREE_ADAMW:
                 optimizer_cls = AdamWScheduleFree
                 additional_optim_kwargs = adam_kwargs
             elif args.optim == OptimizerNames.SCHEDULE_FREE_SGD:
                 optimizer_cls = SGDScheduleFree
             else:
                 raise ValueError("Invalid schedulefree optimizer")
+
             additional_optim_kwargs["weight_decay"] = args.weight_decay
-            additional_optim_kwargs["warmup_steps"] = args.warmup_steps
+            if require_warmup:
+                additional_optim_kwargs["warmup_steps"] = args.warmup_steps
             additional_optim_kwargs.update(
                 {
                     "weight_lr_power": float(optim_args.get("weight_lr_power", 2.0)),
```
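The `Trainer` change above only selects the optimizer class and its kwargs; it may help to see what the schedulefree package expects at the training-loop level. Below is a minimal sketch of using `RAdamScheduleFree` directly, assuming `schedulefree>=1.4.0`, with a toy model, data, and learning rate as placeholders. Note the `optimizer.train()` / `optimizer.eval()` calls that schedule-free optimizers require to switch between the training and averaged parameter versions; when going through `Trainer`, that switching is handled for you.

```python
# Minimal sketch of RAdamScheduleFree outside of Trainer (assumes schedulefree>=1.4.0).
import torch
from schedulefree import RAdamScheduleFree

model = torch.nn.Linear(10, 2)
optimizer = RAdamScheduleFree(model.parameters(), lr=2.5e-3)  # no warmup_steps, no LR scheduler

optimizer.train()  # put the optimizer's parameters in training mode
for _ in range(100):
    x = torch.randn(8, 10)
    y = torch.randint(0, 2, (8,))
    loss = torch.nn.functional.cross_entropy(model(x), y)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

optimizer.eval()  # switch to the averaged parameters before evaluation or saving
print(loss.item())
```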
src/transformers/training_args.py (1 change: 1 addition & 0 deletions)

```diff
@@ -181,6 +181,7 @@ class OptimizerNames(ExplicitEnum):
     LOMO = "lomo"
     ADALOMO = "adalomo"
     GROKADAMW = "grokadamw"
+    SCHEDULE_FREE_RADAM = "schedule_free_radam"
     SCHEDULE_FREE_ADAMW = "schedule_free_adamw"
     SCHEDULE_FREE_SGD = "schedule_free_sgd"
 
```
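As a quick sanity check of the new enum member, the string users pass via `TrainingArguments(optim=...)` resolves to the value added above. A small sketch, assuming a transformers build that includes this PR:

```python
from transformers.training_args import OptimizerNames

# The string accepted by `optim=` maps onto the new enum member.
assert OptimizerNames("schedule_free_radam") is OptimizerNames.SCHEDULE_FREE_RADAM
print(OptimizerNames.SCHEDULE_FREE_RADAM.value)  # schedule_free_radam
```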
src/transformers/utils/import_utils.py (7 changes: 4 additions & 3 deletions)

```diff
@@ -89,6 +89,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
 TORCH_FX_REQUIRED_VERSION = version.parse("1.10")
 
 ACCELERATE_MIN_VERSION = "0.26.0"
+SCHEDULEFREE_MIN_VERSION = "1.2.6"
 FSDP_MIN_VERSION = "1.12.0"
 GGUF_MIN_VERSION = "0.10.0"
 XLA_FSDPV2_MIN_VERSION = "2.2.0"
@@ -107,7 +108,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
 _galore_torch_available = _is_package_available("galore_torch")
 _lomo_available = _is_package_available("lomo_optim")
 _grokadamw_available = _is_package_available("grokadamw")
-_schedulefree_available = _is_package_available("schedulefree")
+_schedulefree_available, _schedulefree_version = _is_package_available("schedulefree", return_version=True)
 # `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
 _bs4_available = importlib.util.find_spec("bs4") is not None
 _coloredlogs_available = _is_package_available("coloredlogs")
@@ -400,8 +401,8 @@ def is_grokadamw_available():
     return _grokadamw_available
 
 
-def is_schedulefree_available():
-    return _schedulefree_available
+def is_schedulefree_available(min_version: str = SCHEDULEFREE_MIN_VERSION):
+    return _schedulefree_available and version.parse(_schedulefree_version) >= version.parse(min_version)
 
 
 def is_pyctcdecode_available():
```
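This version-gated availability check is what lets the Trainer require `schedulefree>=1.4.0` only for the RAdam variant while keeping the older minimum version for the other schedule-free optimizers. Below is a standalone sketch of the same pattern using the standard library plus `packaging`; the helper name `package_available` is illustrative and not part of the transformers internals.

```python
import importlib.metadata
import importlib.util
from typing import Optional

from packaging import version


def package_available(pkg_name: str, min_version: Optional[str] = None) -> bool:
    """Return True if `pkg_name` is importable and at least `min_version` (when given)."""
    if importlib.util.find_spec(pkg_name) is None:
        return False
    if min_version is None:
        return True
    try:
        installed = importlib.metadata.version(pkg_name)
    except importlib.metadata.PackageNotFoundError:
        return False
    return version.parse(installed) >= version.parse(min_version)


# Mirrors how the new Trainer branch gates the RAdamScheduleFree import:
if package_available("schedulefree", "1.4.0"):
    from schedulefree import RAdamScheduleFree
```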
tests/trainer/test_trainer.py (23 changes: 23 additions & 0 deletions)

```diff
@@ -1785,6 +1785,29 @@ def test_schedulefree_adam(self):
                 learning_rate=1e-9,
                 logging_steps=5,
                 optim="schedule_free_adamw",
+                lr_scheduler_type="constant",
             )
             trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
 
             # Check this works
             _ = trainer.train()
 
+    @require_schedulefree
+    @require_torch_gpu
+    def test_schedulefree_radam(self):
+        config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
+        tiny_llama = LlamaForCausalLM(config)
+        x = torch.randint(0, 100, (128,))
+        train_dataset = RepeatDataset(x)
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            # Trainer without inf/nan filter
+            args = TrainingArguments(
+                tmpdir,
+                learning_rate=1e-9,
+                logging_steps=5,
+                lr_scheduler_type="constant",
+                optim="schedule_free_radam",
+            )
+            trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
+
```