From cdcac786a86ad43d31cdfe3655c2f82502b52aef Mon Sep 17 00:00:00 2001 From: ryohei-a-shimizu Date: Wed, 18 Dec 2024 09:45:34 +0900 Subject: [PATCH 1/4] add RAdamScheduleFree optimizer --- docs/source/en/trainer.md | 8 +++++-- setup.py | 2 +- src/transformers/dependency_versions_table.py | 2 +- src/transformers/trainer.py | 15 +++++++++--- src/transformers/training_args.py | 1 + tests/trainer/test_trainer.py | 23 +++++++++++++++++++ 6 files changed, 44 insertions(+), 7 deletions(-) diff --git a/docs/source/en/trainer.md b/docs/source/en/trainer.md index e3a66f420424..d6f8269fe82a 100644 --- a/docs/source/en/trainer.md +++ b/docs/source/en/trainer.md @@ -550,7 +550,10 @@ This script demonstrates how to fine-tune the `google/gemma-2b` model on the IMD The Schedule Free optimizers have been introduced in [The Road Less Scheduled](https://hf.co/papers/2405.15682). Schedule-Free learning replaces the momentum of the base optimizer with a combination of averaging and interpolation, to completely remove the need to anneal the learning rate with a traditional schedule. -Supported optimizers for SFO are `"schedule_free_adamw"` and `"schedule_free_sgd"`. First install schedulefree from pypi `pip install schedulefree`. +Supported optimizers for SFO are `"schedule_free_radam"`, `"schedule_free_adamw"` and `"schedule_free_sgd"`. First install schedulefree from pypi `pip install schedulefree`. + +Schedule-Free family eliminates the need for learning rate schedules, so by default, we recommend setting `lr_scheduler_type="constant"` in the `TrainingArguments`. +Setting other `lr_scheduler_type` would also work, but combining Schedule-Free with other learning rate schedules is not well-studied both in research and in practice, as it may affect the optimizer's intended behavior and performance guarantees. 
Below is a simple script to demonstrate how to fine-tune [google/gemma-2b](https://huggingface.co/google/gemma-2b) on IMDB dataset in full precision: @@ -566,7 +569,8 @@ args = TrainingArguments( output_dir="./test-schedulefree", max_steps=1000, per_device_train_batch_size=4, - optim="schedule_free_adamw", + optim="schedule_free_radam", + lr_scheduler_type="constant", gradient_checkpointing=True, logging_strategy="steps", logging_steps=1, diff --git a/setup.py b/setup.py index c2c0048d6913..04bcdb16bdcd 100644 --- a/setup.py +++ b/setup.py @@ -162,7 +162,7 @@ "sacremoses", "safetensors>=0.4.1", "sagemaker>=2.31.0", - "schedulefree>=1.2.6", + "schedulefree>=1.4", "scikit-learn", "scipy<1.13.0", # SciPy >= 1.13.0 is not supported with the current jax pin (`jax>=0.4.1,<=0.4.13`) "sentencepiece>=0.1.91,!=0.1.92", diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py index 85345cc8e588..0273ac965890 100644 --- a/src/transformers/dependency_versions_table.py +++ b/src/transformers/dependency_versions_table.py @@ -68,7 +68,7 @@ "sacremoses": "sacremoses", "safetensors": "safetensors>=0.4.1", "sagemaker": "sagemaker>=2.31.0", - "schedulefree": "schedulefree>=1.2.6", + "schedulefree": "schedulefree>=1.4", "scikit-learn": "scikit-learn", "scipy": "scipy<1.13.0", "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92", diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 4d90c13df825..be6049f84f2d 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1625,6 +1625,7 @@ def optimizer_hook(param): optimizer_cls = AdamW4bit optimizer_kwargs.update(adam_kwargs) elif args.optim in [ + OptimizerNames.SCHEDULE_FREE_RADAM, OptimizerNames.SCHEDULE_FREE_ADAMW, OptimizerNames.SCHEDULE_FREE_SGD, ]: @@ -1635,18 +1636,26 @@ def optimizer_hook(param): ) if not is_accelerate_available("0.30.0"): raise ImportError("You need to have `accelerate>=0.30.0` to be able to use schedulefree optimizers") - from schedulefree import AdamWScheduleFree, SGDScheduleFree + from schedulefree import AdamWScheduleFree, RAdamScheduleFree, SGDScheduleFree additional_optim_kwargs = {} - if args.optim == OptimizerNames.SCHEDULE_FREE_ADAMW: + require_warmup = True + + if args.optim == OptimizerNames.SCHEDULE_FREE_RADAM: + optimizer_cls = RAdamScheduleFree + additional_optim_kwargs = adam_kwargs + require_warmup = False + elif args.optim == OptimizerNames.SCHEDULE_FREE_ADAMW: optimizer_cls = AdamWScheduleFree additional_optim_kwargs = adam_kwargs elif args.optim == OptimizerNames.SCHEDULE_FREE_SGD: optimizer_cls = SGDScheduleFree else: raise ValueError("Invalid schedulefree optimizer") + additional_optim_kwargs["weight_decay"] = args.weight_decay - additional_optim_kwargs["warmup_steps"] = args.warmup_steps + if require_warmup: + additional_optim_kwargs["warmup_steps"] = args.warmup_steps additional_optim_kwargs.update( { "weight_lr_power": float(optim_args.get("weight_lr_power", 2.0)), diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 6b141cff39e1..ddaa7af192dc 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -183,6 +183,7 @@ class OptimizerNames(ExplicitEnum): LOMO = "lomo" ADALOMO = "adalomo" GROKADAMW = "grokadamw" + SCHEDULE_FREE_RADAM = "schedule_free_radam" SCHEDULE_FREE_ADAMW = "schedule_free_adamw" SCHEDULE_FREE_SGD = "schedule_free_sgd" diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index d33be2789761..519682e591c6 100644 --- 
a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1785,6 +1785,29 @@ def test_schedulefree_adam(self): learning_rate=1e-9, logging_steps=5, optim="schedule_free_adamw", + lr_scheduler_type="constant", + ) + trainer = Trainer(tiny_llama, args, train_dataset=train_dataset) + + # Check this works + _ = trainer.train() + + @require_schedulefree + @require_torch_gpu + def test_schedulefree_radam(self): + config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4) + tiny_llama = LlamaForCausalLM(config) + x = torch.randint(0, 100, (128,)) + train_dataset = RepeatDataset(x) + + with tempfile.TemporaryDirectory() as tmpdir: + # Trainer without inf/nan filter + args = TrainingArguments( + tmpdir, + learning_rate=1e-9, + logging_steps=5, + lr_scheduler_type="constant", + optim="schedule_free_radam", ) trainer = Trainer(tiny_llama, args, train_dataset=train_dataset) From d4a1830f2a4099e5f8db07bc7545d83914e9185d Mon Sep 17 00:00:00 2001 From: nhamanasu Date: Wed, 25 Dec 2024 23:44:16 +0900 Subject: [PATCH 2/4] revert schedulefree version to the minimum requirement --- setup.py | 2 +- src/transformers/dependency_versions_table.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index d1113ad14620..a78bb20dd0a4 100644 --- a/setup.py +++ b/setup.py @@ -163,7 +163,7 @@ "sacremoses", "safetensors>=0.4.1", "sagemaker>=2.31.0", - "schedulefree>=1.4", + "schedulefree>=1.2.6", "scikit-learn", "scipy<1.13.0", # SciPy >= 1.13.0 is not supported with the current jax pin (`jax>=0.4.1,<=0.4.13`) "sentencepiece>=0.1.91,!=0.1.92", diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py index ddf7582c8725..6a737b805a45 100644 --- a/src/transformers/dependency_versions_table.py +++ b/src/transformers/dependency_versions_table.py @@ -69,7 +69,7 @@ "sacremoses": "sacremoses", "safetensors": "safetensors>=0.4.1", "sagemaker": "sagemaker>=2.31.0", - "schedulefree": "schedulefree>=1.4", + "schedulefree": "schedulefree>=1.2.6", "scikit-learn": "scikit-learn", "scipy": "scipy<1.13.0", "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92", From 7b4b86ca40a19b50b9a66c1388cc4563a4c8f4da Mon Sep 17 00:00:00 2001 From: nhamanasu Date: Wed, 25 Dec 2024 23:54:11 +0900 Subject: [PATCH 3/4] refine is_schedulefree_available so that it can take min_version --- src/transformers/trainer.py | 13 ++++++++++--- src/transformers/utils/import_utils.py | 7 ++++--- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 04b6e42fa47e..fb082f4b9367 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1630,17 +1630,24 @@ def optimizer_hook(param): ]: if not is_schedulefree_available(): raise ImportError( - "You need to install `schedulefree` in order to use schedulefree optimizers" - " install it with `pip install schedulefree`" + "You need to install `schedulefree` in order to use schedulefree optimizers. 
" + "Install it with `pip install schedulefree.`" ) if not is_accelerate_available("0.30.0"): raise ImportError("You need to have `accelerate>=0.30.0` to be able to use schedulefree optimizers") - from schedulefree import AdamWScheduleFree, RAdamScheduleFree, SGDScheduleFree + from schedulefree import AdamWScheduleFree, SGDScheduleFree additional_optim_kwargs = {} require_warmup = True if args.optim == OptimizerNames.SCHEDULE_FREE_RADAM: + if not is_schedulefree_available("1.4.0"): + raise ImportError( + "You need to install `schedulefree>=1.4.0` in order to use RAdamScheduleFree optimizer. " + "Install it with `pip install schedulefree.`" + ) + from schedulefree import RAdamScheduleFree + optimizer_cls = RAdamScheduleFree additional_optim_kwargs = adam_kwargs require_warmup = False diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index f880535dd6fe..809371df5f40 100755 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -89,6 +89,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ TORCH_FX_REQUIRED_VERSION = version.parse("1.10") ACCELERATE_MIN_VERSION = "0.26.0" +SCHEDULEFREE_MIN_VERSION = "1.2.6" FSDP_MIN_VERSION = "1.12.0" GGUF_MIN_VERSION = "0.10.0" XLA_FSDPV2_MIN_VERSION = "2.2.0" @@ -107,7 +108,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ _galore_torch_available = _is_package_available("galore_torch") _lomo_available = _is_package_available("lomo_optim") _grokadamw_available = _is_package_available("grokadamw") -_schedulefree_available = _is_package_available("schedulefree") +_schedulefree_available, _schedulefree_version = _is_package_available("schedulefree", return_version=True) # `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed. _bs4_available = importlib.util.find_spec("bs4") is not None _coloredlogs_available = _is_package_available("coloredlogs") @@ -400,8 +401,8 @@ def is_grokadamw_available(): return _grokadamw_available -def is_schedulefree_available(): - return _schedulefree_available +def is_schedulefree_available(min_version: str = SCHEDULEFREE_MIN_VERSION): + return _schedulefree_available and version.parse(_schedulefree_version) >= version.parse(min_version) def is_pyctcdecode_available(): From 81b634d63fc9a73f364cf0baf6418f58be82d6e1 Mon Sep 17 00:00:00 2001 From: nhamanasu Date: Thu, 26 Dec 2024 00:10:58 +0900 Subject: [PATCH 4/4] refine documents --- docs/source/en/trainer.md | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/docs/source/en/trainer.md b/docs/source/en/trainer.md index d6f8269fe82a..016c7b434b55 100644 --- a/docs/source/en/trainer.md +++ b/docs/source/en/trainer.md @@ -544,16 +544,17 @@ trainer = Trainer( trainer.train() ``` -This script demonstrates how to fine-tune the `google/gemma-2b` model on the IMDB dataset using the GrokAdamW optimizer. The `TrainingArguments` are configured to use GrokAdamW, and the dataset is passed to the `Trainer` for training. +This script demonstrates how to fine-tune the [google/gemma-2b](https://huggingface.co/google/gemma-2b) model on the IMDB dataset using the GrokAdamW optimizer. The `TrainingArguments` are configured to use GrokAdamW, and the dataset is passed to the `Trainer` for training. -### Schedule Free Optimizer +### Schedule-Free Optimizer + +The Schedule-Free optimizers have been introduced in [The Road Less Scheduled](https://hf.co/papers/2405.15682). 
+Supported optimizers for Schedule-Free are `schedule_free_radam`, `schedule_free_adamw` and `schedule_free_sgd`. First install schedulefree from pypi `pip install schedulefree`. -The Schedule Free optimizers have been introduced in [The Road Less Scheduled](https://hf.co/papers/2405.15682). Schedule-Free learning replaces the momentum of the base optimizer with a combination of averaging and interpolation, to completely remove the need to anneal the learning rate with a traditional schedule. -Supported optimizers for SFO are `"schedule_free_radam"`, `"schedule_free_adamw"` and `"schedule_free_sgd"`. First install schedulefree from pypi `pip install schedulefree`. +Additionally, neither `warmup_steps` nor `warmup_ratio` parameters are required when using `schedule_free_radam`. -Schedule-Free family eliminates the need for learning rate schedules, so by default, we recommend setting `lr_scheduler_type="constant"` in the `TrainingArguments`. -Setting other `lr_scheduler_type` would also work, but combining Schedule-Free with other learning rate schedules is not well-studied both in research and in practice, as it may affect the optimizer's intended behavior and performance guarantees. +By default, we recommend setting `lr_scheduler_type="constant"` in the `TrainingArguments`. Setting other `lr_scheduler_type` would also work, but combining Schedule-Free with other learning rate schedules is not well-studied both in research and in practice, as it may affect the optimizer's intended behavior and performance guarantees. Below is a simple script to demonstrate how to fine-tune [google/gemma-2b](https://huggingface.co/google/gemma-2b) on IMDB dataset in full precision:
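
For anyone who wants to try the new option end to end once this series is applied, here is a minimal, self-contained sketch of a run with `optim="schedule_free_radam"`. The tiny Llama config and the synthetic repeat-dataset mirror the new `test_schedulefree_radam` test case; the output directory, step count, and learning rate are illustrative assumptions rather than values taken from the patches.

```python
# Minimal sketch: requires this patch series plus `schedulefree>=1.4.0` and `accelerate>=0.30.0`.
import torch
from torch.utils.data import Dataset

from transformers import LlamaConfig, LlamaForCausalLM, Trainer, TrainingArguments


class RepeatDataset(Dataset):
    """Repeats one token sequence; an illustrative stand-in for a real training set."""

    def __init__(self, x, length=64):
        self.x, self.length = x, length

    def __len__(self):
        return self.length

    def __getitem__(self, i):
        return {"input_ids": self.x, "labels": self.x}


config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
model = LlamaForCausalLM(config)
train_dataset = RepeatDataset(torch.randint(0, 100, (128,)))

args = TrainingArguments(
    output_dir="./test-schedulefree-radam",
    max_steps=20,
    per_device_train_batch_size=4,
    optim="schedule_free_radam",    # new optimizer name added by this series
    lr_scheduler_type="constant",   # recommended: Schedule-Free needs no LR schedule
    learning_rate=2e-6,
    logging_steps=5,
    save_strategy="no",
    report_to="none",
    # no warmup_steps / warmup_ratio: schedule_free_radam does not use warmup
)

trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
trainer.train()
```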
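
For completeness, a rough sketch of the optimizer the patched `Trainer` ends up constructing for `optim="schedule_free_radam"`, written as a hand-rolled loop. The constructor arguments mirror the kwargs `trainer.py` forwards in this series (`adam_kwargs`, `weight_decay`, `weight_lr_power`, and notably no `warmup_steps`, since `require_warmup` is set to `False` for this variant); the concrete values shown are assumptions, not taken from the patches. The explicit `optimizer.train()` / `optimizer.eval()` calls follow the schedulefree package's usage pattern and are only needed outside the `Trainer`, whose integration (together with `accelerate>=0.30.0`) is expected to handle the mode switching for you.

```python
# Rough equivalent of the optimizer the patched Trainer builds for "schedule_free_radam".
# Assumes schedulefree>=1.4.0; the hyperparameter values are illustrative, not from the patch.
import torch
from schedulefree import RAdamScheduleFree

model = torch.nn.Linear(16, 1)  # illustrative stand-in for a real model

optimizer = RAdamScheduleFree(
    model.parameters(),
    lr=2e-6,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0.0,
    weight_lr_power=2.0,  # forwarded from optim_args in trainer.py (default 2.0)
    # no warmup_steps: RAdam's rectification replaces warmup, which is why the patch
    # sets require_warmup = False for SCHEDULE_FREE_RADAM
)

optimizer.train()  # schedule-free optimizers must be put in train mode before stepping
for _ in range(10):
    loss = model(torch.randn(4, 16)).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
optimizer.eval()  # and in eval mode before evaluation or checkpointing
```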