Fixed batchnorm bug #3170
Conversation
Force-pushed from be7eeb6 to d3b2c04
@zewenli98 what are your thoughts here? Seems like a simple enough change. @cehongwang, after this is in, would there be any other failure modes for fast refit?
Is this PR still a draft or ready for review (rebase to main is needed)?
    else get_trt_tensor(ctx, weight, f"{name}_weight")
)
bias = (
    get_trt_tensor(ctx, 1.0, f"{name}_bias")
Should this be 0?
    else get_trt_tensor(ctx, bias, f"{name}_bias")
)
running_mean = (
    get_trt_tensor(ctx, 1.0, f"{name}_running_mean")
Should this be 0?
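For context, the conventional batch norm defaults are weight = 1 and running_var = 1, but bias = 0 and running_mean = 0, which is what the question above is pointing at. A minimal plain-PyTorch sketch of those defaults (illustrative only, not the converter code; the function name and signature are assumptions):

```python
import torch

def batch_norm_inference(x, weight=None, bias=None, running_mean=None,
                         running_var=None, eps=1e-5):
    # Fallbacks when the tensors are absent: weight/running_var -> ones,
    # bias/running_mean -> zeros (a 1.0 default for bias or running_mean
    # would shift the output).
    c = x.shape[1]
    weight = torch.ones(c) if weight is None else weight
    bias = torch.zeros(c) if bias is None else bias
    running_mean = torch.zeros(c) if running_mean is None else running_mean
    running_var = torch.ones(c) if running_var is None else running_var
    shape = (1, c) + (1,) * (x.dim() - 2)
    return ((x - running_mean.view(shape))
            / torch.sqrt(running_var.view(shape) + eps)
            * weight.view(shape) + bias.view(shape))

x = torch.randn(2, 3, 4, 4)
print(torch.allclose(
    batch_norm_inference(x),
    torch.nn.functional.batch_norm(x, torch.zeros(3), torch.ones(3), eps=1e-5)
))  # True
```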
Since batchnorm was refactored previously, renaming may work for fast refit.
I am setting up the naming trace for fast refit by giving the TRT weight the same name it has in the state dict. This is still a draft; I am having trouble installing the code linter on my current computer. The code runs and passes the tests, though. I will convert this to a PR after fixing the lint.
I will make similar changes to other applicable converters to make the naming trace more robust. There should not be other failures after that, because we have the value trace to guarantee correctness.
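To make the idea concrete, a hypothetical sketch of name-based matching (not the actual refit code; the helper and all names below are illustrative): if the constants recorded on the TensorRT side reuse the state_dict keys, refit can pair weights by name even when several tensors hold identical values.

```python
import torch

def build_refit_name_map(trt_constant_names, state_dict):
    # Pair TRT-side constant names with state_dict keys by exact name match.
    return {name: name for name in trt_constant_names if name in state_dict}

trt_names = ["bn1.weight", "bn1.bias", "bn1.running_mean", "bn1.running_var"]
sd = {k: torch.ones(8) for k in trt_names}  # identical values everywhere
print(build_refit_name_map(trt_names, sd))
# {'bn1.weight': 'bn1.weight', 'bn1.bias': 'bn1.bias', ...}
```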
@cehongwang @narendasan Can you take a look at issue #3200, which may be related to this PR?
Force-pushed from d3b2c04 to 2734ebc
LGTM
Rebase and lgtm
@@ -69,7 +69,8 @@
    debug=debug,
    min_block_size=min_block_size,
    torch_executed_ops=torch_executed_ops,
    make_refittable=True,
    make_refitable=True,
Should be two t's now; you might want to rebase this branch.
rebased
@@ -477,12 +477,18 @@ def _save_weight_mapping(self) -> None:
    # Retrieve each weight name(s) in state_dict
    if layer_type == "CONSTANT":
        if "embedding" in suffix:
            sd_weight_name = f"{sd_weight_name}.{torch_attr[0]}"
@zewenli98 keep track of this; it seems like there could be a lot of possible names we need to handle, so we might want to look at a generic solution later.
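One possible shape for a generic solution (a sketch only; the table contents and helper name are assumptions, not the repo's mapping) would be a lookup from converter suffixes to state_dict attribute names instead of special-casing strings like "embedding":

```python
# Hypothetical suffix -> state_dict attribute table; extended as converters
# adopt the naming convention.
SUFFIX_TO_SD_ATTR = {
    "embedding": "weight",
    "running_mean": "running_mean",
    "running_var": "running_var",
}

def resolve_sd_weight_name(sd_weight_name: str, suffix: str) -> str:
    # Fall back to the bare name when no known suffix matches.
    for key, attr in SUFFIX_TO_SD_ATTR.items():
        if key in suffix:
            return f"{sd_weight_name}.{attr}"
    return sd_weight_name

print(resolve_sd_weight_name("model.embed", "embedding_constant"))  # model.embed.weight
```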
LGTM!
Description
Batch norm value trace fails when the initial weights are identical. This pull request fixes the batch norm fast refit case by using the name trace.
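A toy illustration of the failure mode (not repo code; names are made up): with identical initial values, a value-keyed lookup collapses distinct parameters into one entry, while a name-keyed lookup keeps them apart.

```python
import torch

weight = torch.ones(4)       # BN affine weight initializes to ones
running_var = torch.ones(4)  # running_var also initializes to ones

# Value-keyed lookup: identical bytes, so the second entry overwrites the first.
by_value = {weight.numpy().tobytes(): "bn1.weight"}
by_value[running_var.numpy().tobytes()] = "bn1.running_var"
print(len(by_value))  # 1 -> ambiguous

# Name-keyed lookup keeps the two parameters distinct.
by_name = {"bn1.weight": weight, "bn1.running_var": running_var}
print(len(by_name))   # 2
```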
Fixes # (issue)
Type of change
Please delete options that are not relevant and/or add your own.
Checklist: