Allow `Trainer.get_optimizer_cls_and_kwargs` to be overridden #31875
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Hi @apoorvkh, thanks for opening a PR! Could you give some more details about how you'd like to use this method? I think it should be possible to override a staticmethod in a child class:

```python
In [1]: class Foo:
   ...:     @staticmethod
   ...:     def foo(a, b):
   ...:         return a * b
   ...:
   ...:     def bar(self, c, d):
   ...:         return c ** d
   ...:
   ...:     def baz(self, e, f):
   ...:         return self.foo(e, f)
   ...:
   ...:
   ...: class Bar(Foo):
   ...:     @staticmethod
   ...:     def foo(a, b):
   ...:         return a + b
   ...:

In [2]: Foo.foo(2, 3)
Out[2]: 6

In [3]: Bar.foo(2, 3)
Out[3]: 5

In [4]: Foo().bar(2, 3)
Out[4]: 8

In [5]: Bar().bar(2, 3)
Out[5]: 8

In [6]: Foo().baz(2, 3)
Out[6]: 6

In [7]: Bar().baz(2, 3)
Out[7]: 5
```

but I might be missing what you're trying to do
Yes, can definitely elaborate: say I want to use the HF `Trainer` with an arbitrary PyTorch optimizer (e.g. `torch.optim.AdamW` with custom hyperparameters). I would subclass `Trainer` and override `get_optimizer_cls_and_kwargs`:

```python
class CustomOptimizerTrainer(Trainer):
    @staticmethod
    def get_optimizer_cls_and_kwargs(
        args: HfTrainingArguments, model=None
    ) -> tuple[type[torch.optim.Optimizer], dict[str, Any]]:
        optimizer = torch.optim.AdamW
        optimizer_kwargs = {
            "lr": 4e-3,
            "betas": (0.9, 0.999),
            "weight_decay": 0.05,
        }
        return optimizer, optimizer_kwargs
```

However, this won't take effect, because `Trainer.create_optimizer` calls `Trainer.get_optimizer_cls_and_kwargs(args)` on the class itself (transformers/src/transformers/trainer.py, Line 1076 in 6c1d0b0). This is not `self.get_optimizer_cls_and_kwargs`, so the subclass's override is never picked up. I also made these changes in this PR. Please let me know if that's clearer and if you agree! Thanks!
@apoorvkh Thanks for taking the time for the detailed explanation! Yes, I think switching the call site to `self.get_optimizer_cls_and_kwargs` makes sense. I'd rather we didn't change the method to be an instance method - this is a breaking change which might affect many users downstream.
Okay, sounds good then! That makes this a very simple PR. Made those changes and all tests pass :)
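For illustration, a minimal sketch (not the actual `transformers` source, which handles many more optimizer types and parameter groups) of what the agreed-upon call-site change amounts to:

```python
import torch


class Trainer:
    @staticmethod
    def get_optimizer_cls_and_kwargs(args, model=None):
        # Simplified stand-in for the logic that maps args.optim to an optimizer
        # class and its keyword arguments.
        return torch.optim.AdamW, {"lr": args.learning_rate}

    def create_optimizer(self):
        # Before: Trainer.get_optimizer_cls_and_kwargs(self.args)  -> always the base class
        # After:  self.get_optimizer_cls_and_kwargs(self.args)     -> respects subclass
        #         overrides, while the method itself can stay a @staticmethod
        optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(self.args)
        self.optimizer = optimizer_cls(self.model.parameters(), **optimizer_kwargs)
        return self.optimizer
```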
Thanks for this change and making our objects more flexible!
All LGTM - let's just get a second 👍 from @muellerzr or @SunMarc to confirm this is all OK in trainer-land
Thanks! I am also considering making another (simple, no breaking changes) PR to support generic PyTorch optimizers via `TrainingArguments`.
Makes sense to me as well, adds some nice flexibility :)
I was going to add support for something like

```python
TrainingArguments(
    optim=torch.optim.AdamW,
    optim_args={
        "betas": (0.9, 0.999),
        "eps": 1e-08,
        "weight_decay": 0.01,
    },
)
```

(in addition to the existing functionality of those arguments). But it looks like `TrainingArguments` objects must be JSON serializable, so passing an optimizer class directly won't work. I don't have another approach in mind that is as elegant. We could allow ...
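For context, a small illustration (not from the thread) of the serialization constraint mentioned above: a class object such as `torch.optim.AdamW` cannot be dumped to JSON, whereas the string identifiers that `optim` and `optim_args` currently accept can.

```python
import json

import torch

# A plain dict mirroring what would need to be serialized if optim held a class.
config = {"optim": torch.optim.AdamW, "optim_args": {"weight_decay": 0.01}}

try:
    json.dumps(config)
except TypeError as err:
    # TypeError: Object of type type is not JSON serializable
    print(f"Serialization fails: {err}")

# String identifiers, by contrast, serialize fine.
print(json.dumps({"optim": "adamw_torch", "optim_args": "weight_decay=0.01"}))
```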
Big thanks for fixing this! This should save me a massive headache in Sentence Transformers.
What does this PR do?

Currently, `Trainer` builds an optimizer by loading the optimizer class and arguments from `Trainer.get_optimizer_cls_and_kwargs` in `Trainer.create_optimizer`:

transformers/src/transformers/trainer.py, Line 1076 in 6c1d0b0

However, this prevents the `get_optimizer_cls_and_kwargs()` function from being overridden. As a solution, in this PR I've changed it into an instance method (instead of a `@staticmethod`) and changed the call from `Trainer.get_optimizer_cls_and_kwargs(args)` to `self.get_optimizer_cls_and_kwargs()`. All existing functionality should remain as is, but this should now be extensible (if you subclass `Trainer`).

Note: I think this breaks the current tests, which expect `get_optimizer_cls_and_kwargs` to be a static method, e.g.

transformers/tests/trainer/test_trainer.py, Line 4182 in ad35309
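As a hedged, minimal sanity check of the intended behavior (the subclass name, the tiny `nn.Linear` model, and the `TrainingArguments` values below are illustrative assumptions, not code from this PR):

```python
import torch
from transformers import Trainer, TrainingArguments


class CustomOptimizerTrainer(Trainer):
    @staticmethod
    def get_optimizer_cls_and_kwargs(args, model=None):
        # Ignore args.optim and always return plain AdamW with custom hyperparameters.
        return torch.optim.AdamW, {"lr": 4e-3, "betas": (0.9, 0.999), "weight_decay": 0.05}


trainer = CustomOptimizerTrainer(
    model=torch.nn.Linear(4, 2),  # any nn.Module is enough for this check
    args=TrainingArguments(output_dir="out"),
)
trainer.create_optimizer()

# With the fix, create_optimizer() resolves the optimizer through the subclass override.
assert isinstance(trainer.optimizer, torch.optim.AdamW)
print(trainer.optimizer.defaults["lr"])  # 0.004
```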
Who can review?
@muellerzr and @SunMarc