Fix galore lr display with schedulers #31710
Conversation
Failing tests seem unrelated to me: TF and hub issues.
Thanks, overall this makes sense. Can you add a test in trainer_utils for this by chance? https://github.com/huggingface/transformers/blob/main/tests/trainer/test_trainer_utils.py
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@muellerzr Wouldn't it make more sense over here? (transformers/tests/trainer/test_trainer.py, line 1452 at 3345ae7)
I would add two tests:
Is that reasonable? I'm not sure when I'll have the time, though.
```python
# reach given learning rate peak and end with 0 lr
self.assertTrue(logs[num_warmup_steps - 2]["learning_rate"] == learning_rate)
self.assertTrue(logs[-1]["learning_rate"] == 0)

# increasing and decreasing pattern of lrs
increasing_lrs = [
    logs[i]["learning_rate"] < logs[i + 1]["learning_rate"]
    for i in range(len(logs))
    if i < num_warmup_steps - 2
]
decreasing_lrs = [
    logs[i]["learning_rate"] > logs[i + 1]["learning_rate"]
    for i in range(len(logs) - 1)
    if i >= num_warmup_steps - 2
]

self.assertTrue(all(increasing_lrs))
self.assertTrue(all(decreasing_lrs))

# warm up steps << total steps
self.assertTrue(len(decreasing_lrs) > len(increasing_lrs))
```
Just checking for the general patterns of the cosine scheduler. We could hardcode the exact values, but I don't think that's necessary. A standalone sketch of the expected lr pattern follows below.
Moved the tests into the general trainer tests, but they could also be moved elsewhere. I thought it was more appropriate over here.
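For readers unfamiliar with the shape of a cosine-with-warmup schedule, here is a small standalone sketch (toy optimizer, illustrative lr and step counts, not the values used in the actual test) showing the same increase-then-decrease pattern the assertions check, using `get_cosine_schedule_with_warmup` from transformers:

```python
import torch
from transformers import get_cosine_schedule_with_warmup

# toy parameter/optimizer; lr and step counts are illustrative only
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=1e-4)
scheduler = get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps=10, num_training_steps=100
)

lrs = []
for _ in range(100):
    lrs.append(scheduler.get_last_lr()[0])
    optimizer.step()
    scheduler.step()

# linear warmup up to the configured peak lr, then cosine decay towards 0
assert max(lrs) == 1e-4
assert all(a < b for a, b in zip(lrs[:10], lrs[1:11]))   # warmup: increasing
assert all(a > b for a, b in zip(lrs[10:-1], lrs[11:]))  # decay: decreasing
```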
Good job with the test!
cc @amyeroberts for final review
Thanks for fixing!
Just a comment on the default LR
src/transformers/optimization.py (outdated)
```diff
@@ -519,7 +519,7 @@ def scheduler_hook(param):
         if param.requires_grad:
             param.register_post_accumulate_grad_hook(scheduler_hook)

-    return LayerWiseDummyScheduler()
+    return LayerWiseDummyScheduler(optimizer_dict=optimizer_dict, lr=optimizer.defaults.get("lr", 1e-3))
```
Where does the 1e-3 come from here?
It's like a double fallback; it shouldn't be necessary, since the dummy optimizer is guaranteed to have a value.
The 1e-3 itself comes from torch galore, as it's their specific default.
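To spell out the fallback chain (the numbers here are illustrative): `optimizer.defaults` is populated by PyTorch with whatever lr the optimizer was constructed with, so the second argument of `.get("lr", 1e-3)` is only reached if no lr default exists at all.

```python
import torch

param = torch.nn.Parameter(torch.zeros(1))

# stand-in for the dummy optimizer, which (per the discussion above)
# is always constructed with an lr, so defaults["lr"] is set
optimizer = torch.optim.AdamW([param], lr=2e-4)
print(optimizer.defaults.get("lr", 1e-3))  # 2e-4 -- the 1e-3 fallback is never used
```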
src/transformers/trainer_pt_utils.py (outdated)
```diff
         last_epoch = -1
         verbose = False
         super().__init__(optimizer, last_epoch, verbose)

     def get_lr(self):
-        return [group["lr"] for group in self.optimizer.param_groups]
+        # default value
+        lrs = [1e-3]
```
I think we should move the 1e-3 value out to a constant that both get_lr and get_scheduler reference, so that we only need to update it in one place.
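A sketch of what that suggestion could look like; the constant name and helper are hypothetical, purely to show the "single place to update" idea, and are not code from the PR:

```python
# hypothetical shared constant, referenced by both get_scheduler and get_lr
GALORE_DEFAULT_LR = 1e-3

def _fallback_lr(optimizer):
    # hypothetical helper: one place to change the fallback value
    return optimizer.defaults.get("lr", GALORE_DEFAULT_LR)
```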
The default value is in the dummy optimizer, so I'll just save the lrs on the initial creation of the dummy scheduler. This way we won't have the hardcoded value.
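A minimal sketch of that approach, assuming PyTorch ≥ 2.0 for the `LRScheduler` base class; the class name and signature are hypothetical approximations of the idea, not the merged code:

```python
from torch.optim.lr_scheduler import LRScheduler

class LayerWiseLrReportingScheduler(LRScheduler):
    """Sketch only: report the per-layer optimizers' lrs instead of a hardcoded default."""

    def __init__(self, optimizer_dict, optimizer):
        # capture the lrs from the per-parameter optimizers once, at creation time,
        # so get_lr never needs a hardcoded fallback like 1e-3
        self.default_lrs = [
            group["lr"]
            for opt in optimizer_dict.values()
            for group in opt.param_groups
        ]
        super().__init__(optimizer, last_epoch=-1)

    def get_lr(self):
        # values were saved at construction, nothing hardcoded here
        return self.default_lrs
```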
LGTM - thanks for iterating!
Rebasing on main should resolve any timeout issues on the CI runs
Force-pushed from 55f9d8f to 230adf6.
@amyeroberts One timeout didn't make it through. Is it just my luck? 😆
@vasqu Just bad luck - although we'll need to look into why these flaky failures are happening on our side. Thankfully some re-runs worked. Thanks for your patience!
What does this PR do?
See #31707 for a detailed rundown. Fixes #31707
Tl;dr: Galore still has issues displaying the correct lr, this time due to the lr scheduler.
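For reference, a rough sketch of the kind of configuration where the displayed lr was off (argument values are illustrative, and this assumes galore_torch is installed):

```python
from transformers import TrainingArguments

# layer-wise GaLore builds one optimizer per target parameter, plus a dummy
# optimizer/scheduler pair that the Trainer uses for bookkeeping and logging
args = TrainingArguments(
    output_dir="out",
    optim="galore_adamw_layerwise",
    optim_target_modules=["attn", "mlp"],
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    warmup_steps=10,
    logging_steps=1,
)
# before this fix, the "learning_rate" entries in the Trainer's log history
# did not follow the schedule applied to the per-layer optimizers
```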
Who can review?
@muellerzr @SunMarc @amyeroberts @Minami-su