Fixed batchnorm bug #3170

Merged: 5 commits, Oct 14, 2024
1 change: 1 addition & 0 deletions examples/dynamo/refit_engine_example.py
@@ -70,6 +70,7 @@
min_block_size=min_block_size,
torch_executed_ops=torch_executed_ops,
make_refittable=True,
reuse_cached_engines=False,
) # Output is a torch.fx.GraphModule

# Save the graph module as an exported program
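The one-line change above disables the engine cache for this example. For context, here is a minimal sketch of the compile-then-refit flow the example walks through; only `make_refittable=True` and `reuse_cached_engines=False` are taken from the diff, while the other parameter values and the `refit_module_weights` keyword names are illustrative assumptions based on the dynamo refit API used in the tests below.

```python
# Minimal sketch of the compile-then-refit flow (illustrative values, not the full example).
import torch
import torch_tensorrt as torch_trt
import torchvision.models as models

inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]

# Build a refittable TensorRT graph module from an untrained ResNet-18.
model = models.resnet18(pretrained=False).eval().to("cuda")
exp_program = torch.export.export(model, tuple(inputs))
trt_gm = torch_trt.dynamo.compile(
    exp_program,
    tuple(inputs),
    enabled_precisions={torch.float},
    min_block_size=1,
    make_refittable=True,        # engine must be built as refittable
    reuse_cached_engines=False,  # skip cached engines so a fresh weight map is produced
)  # Output is a torch.fx.GraphModule

# Refit the compiled module with the weights of a pretrained copy, without recompiling.
model2 = models.resnet18(pretrained=True).eval().to("cuda")
exp_program2 = torch.export.export(model2, tuple(inputs))
new_trt_gm = torch_trt.dynamo.refit_module_weights(
    compiled_module=trt_gm,
    new_weight_module=exp_program2,
    arg_inputs=inputs,
)
```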
12 changes: 9 additions & 3 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -477,12 +477,18 @@ def _save_weight_mapping(self) -> None:
# Retrieve each weight name(s) in state_dict
if layer_type == "CONSTANT":
if "embedding" in suffix:
sd_weight_name = f"{sd_weight_name}.{torch_attr[0]}"
Review comment (Collaborator): @zewenli98 keep track of this; it seems there could be a lot of possible names we need to handle, so we might want to look at a generic solution later.
sd_weight_name = f"{sd_weight_name}.weight"
elif "weight" in suffix or "mm_other" in suffix:
# Linear layer weight
sd_weight_name = f"{sd_weight_name}.{torch_attr[0]}"
sd_weight_name = f"{sd_weight_name}.weight"
elif "running_mean" in suffix:
# Linear layer weight
sd_weight_name = f"{sd_weight_name}.running_mean"
elif "running_var" in suffix:
# Linear layer weight
sd_weight_name = f"{sd_weight_name}.running_var"
else:
sd_weight_name = f"{sd_weight_name}.{torch_attr[1]}"
sd_weight_name = f"{sd_weight_name}.bias"
elif layer_type == "SCALE":
# Batch norm needs all weights to calculate scale and shift
sd_weight_name = [f"{sd_weight_name}.{n}" for n in torch_attr]
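The new branches extend the CONSTANT-layer name mapping so that batch norm running statistics resolve to their own state_dict keys instead of falling through to `.bias`. A standalone sketch of that mapping is below for illustration; the helper function name and the example layer names are hypothetical, only the branch logic mirrors the diff.

```python
# Sketch of the suffix -> state_dict key mapping added above (hypothetical helper name).
def constant_suffix_to_sd_key(sd_weight_name: str, suffix: str) -> str:
    if "embedding" in suffix or "weight" in suffix or "mm_other" in suffix:
        return f"{sd_weight_name}.weight"
    elif "running_mean" in suffix:
        return f"{sd_weight_name}.running_mean"
    elif "running_var" in suffix:
        return f"{sd_weight_name}.running_var"
    else:
        return f"{sd_weight_name}.bias"


# Constants emitted for a batch norm layer now resolve to the running-stat
# entries instead of falling through to ".bias".
assert constant_suffix_to_sd_key("layer1.0.bn1", "running_mean") == "layer1.0.bn1.running_mean"
assert constant_suffix_to_sd_key("layer1.0.bn1", "running_var") == "layer1.0.bn1.running_var"
assert constant_suffix_to_sd_key("fc", "weight") == "fc.weight"
```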
29 changes: 21 additions & 8 deletions py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
@@ -50,14 +50,27 @@ def batch_norm(
# Save the original output shape for later use
output_shape = input.shape

if weight is None:
weight = get_trt_tensor(ctx, 1.0, f"{name}_weight")
if bias is None:
bias = get_trt_tensor(ctx, 0.0, f"{name}_bias")
if running_mean is None:
running_mean = get_trt_tensor(ctx, 0.0, f"{name}_running_mean")
if running_var is None:
running_var = get_trt_tensor(ctx, 1.0, f"{name}_running_var")
# We name the weight here according to the state_dict name
weight = (
get_trt_tensor(ctx, 1.0, f"{name}_weight")
if weight is None
else get_trt_tensor(ctx, weight, f"{name}_weight")
)
bias = (
get_trt_tensor(ctx, 0.0, f"{name}_bias")
if bias is None
else get_trt_tensor(ctx, bias, f"{name}_bias")
)
running_mean = (
get_trt_tensor(ctx, 0.0, f"{name}_running_mean")
if running_mean is None
else get_trt_tensor(ctx, running_mean, f"{name}_running_mean")
)
running_var = (
get_trt_tensor(ctx, 1.0, f"{name}_running_var")
if running_var is None
else get_trt_tensor(ctx, running_var, f"{name}_running_var")
)

# eps_tensor for numerical stability
eps_tensor = get_trt_tensor(ctx, eps, f"{name}_eps")
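The converter now builds a named TRT constant for each of weight, bias, running_mean, and running_var, falling back to identity defaults (scale 1.0, shift 0.0, mean 0.0, variance 1.0) when an input is absent. The `f"{name}_..."` suffixes line up with the keys a batch norm module exposes in its state_dict, which is what the weight mapping in `_save_weight_mapping` keys on. A quick, PyTorch-only check of those keys:

```python
import torch

# The state_dict keys of a batch norm module are exactly the suffixes the
# converter now uses when naming its TRT constants.
bn = torch.nn.BatchNorm2d(8)
print(sorted(bn.state_dict().keys()))
# ['bias', 'num_batches_tracked', 'running_mean', 'running_var', 'weight']
```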
33 changes: 23 additions & 10 deletions tests/py/dynamo/models/test_model_refit.py
@@ -35,8 +35,8 @@
@pytest.mark.unit
def test_mapping():

model = models.resnet18(pretrained=True).eval().to("cuda")
model2 = models.resnet18(pretrained=False).eval().to("cuda")
model = models.resnet18(pretrained=False).eval().to("cuda")
model2 = models.resnet18(pretrained=True).eval().to("cuda")
inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
trt_input = [
torchtrt.Input(i.shape, dtype=torch.float, format=torch.contiguous_format)
@@ -58,6 +58,7 @@ def test_mapping():
debug=debug,
min_block_size=min_block_size,
make_refittable=True,
reuse_cached_engines=False,
)
settings = trt_gm._run_on_acc_0.settings
runtime = trt.Runtime(TRT_LOGGER)
@@ -110,6 +111,7 @@ def test_refit_one_engine_with_weightmap():
debug=debug,
min_block_size=min_block_size,
make_refittable=True,
reuse_cached_engines=False,
)

new_trt_gm = refit_module_weights(
@@ -141,8 +143,8 @@ def test_refit_one_engine_with_weightmap():
@pytest.mark.unit
def test_refit_one_engine_no_map_with_weightmap():

model = models.resnet18(pretrained=True).eval().to("cuda")
model2 = models.resnet18(pretrained=False).eval().to("cuda")
model = models.resnet18(pretrained=False).eval().to("cuda")
model2 = models.resnet18(pretrained=True).eval().to("cuda")
inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
enabled_precisions = {torch.float}
debug = False
@@ -160,6 +162,7 @@ def test_refit_one_engine_no_map_with_weightmap():
debug=debug,
min_block_size=min_block_size,
make_refittable=True,
reuse_cached_engines=False,
)

trt_gm._run_on_acc_0.weight_name_map = None
@@ -192,8 +195,8 @@ def test_refit_one_engine_no_map_with_weightmap():
@pytest.mark.unit
def test_refit_one_engine_with_wrong_weightmap():

model = models.resnet18(pretrained=True).eval().to("cuda")
model2 = models.resnet18(pretrained=False).eval().to("cuda")
model = models.resnet18(pretrained=False).eval().to("cuda")
model2 = models.resnet18(pretrained=True).eval().to("cuda")
inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
enabled_precisions = {torch.float}
debug = False
@@ -211,6 +214,7 @@ def test_refit_one_engine_with_wrong_weightmap():
debug=debug,
min_block_size=min_block_size,
make_refittable=True,
reuse_cached_engines=False,
)
# Manually delete all batch norm layers. This is supposed to make the fast refit fail
trt_gm._run_on_acc_0.weight_name_map = {
@@ -268,6 +272,7 @@ def test_refit_one_engine_bert_with_weightmap():
debug=debug,
min_block_size=min_block_size,
make_refittable=True,
reuse_cached_engines=False,
)

new_trt_gm = refit_module_weights(
@@ -302,8 +307,8 @@ def test_refit_one_engine_bert_with_weightmap():
@pytest.mark.unit
def test_refit_one_engine_inline_runtime__with_weightmap():
trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
model = models.resnet18(pretrained=True).eval().to("cuda")
model2 = models.resnet18(pretrained=False).eval().to("cuda")
model = models.resnet18(pretrained=False).eval().to("cuda")
model2 = models.resnet18(pretrained=True).eval().to("cuda")
inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
enabled_precisions = {torch.float}
debug = False
@@ -321,6 +326,7 @@ def test_refit_one_engine_inline_runtime__with_weightmap():
debug=debug,
min_block_size=min_block_size,
make_refittable=True,
reuse_cached_engines=False,
)
torchtrt.save(trt_gm, trt_ep_path, inputs=inputs)
trt_gm = torch.export.load(trt_ep_path)
@@ -348,8 +354,8 @@ def test_refit_one_engine_inline_runtime__with_weightmap():
@pytest.mark.unit
def test_refit_one_engine_python_runtime_with_weightmap():

model = models.resnet18(pretrained=True).eval().to("cuda")
model2 = models.resnet18(pretrained=False).eval().to("cuda")
model = models.resnet18(pretrained=False).eval().to("cuda")
model2 = models.resnet18(pretrained=True).eval().to("cuda")
inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
enabled_precisions = {torch.float}
debug = False
@@ -367,6 +373,7 @@ def test_refit_one_engine_python_runtime_with_weightmap():
debug=debug,
min_block_size=min_block_size,
make_refittable=True,
reuse_cached_engines=False,
)

new_trt_gm = refit_module_weights(
@@ -438,6 +445,7 @@ def forward(self, x):
min_block_size=min_block_size,
make_refittable=True,
torch_executed_ops=torch_executed_ops,
reuse_cached_engines=False,
)

new_trt_gm = refit_module_weights(
@@ -487,6 +495,7 @@ def test_refit_one_engine_without_weightmap():
debug=debug,
min_block_size=min_block_size,
make_refittable=True,
reuse_cached_engines=False,
)

new_trt_gm = refit_module_weights(
@@ -538,6 +547,7 @@ def test_refit_one_engine_bert_without_weightmap():
debug=debug,
min_block_size=min_block_size,
make_refittable=True,
reuse_cached_engines=False,
)

new_trt_gm = refit_module_weights(
@@ -591,6 +601,7 @@ def test_refit_one_engine_inline_runtime_without_weightmap():
debug=debug,
min_block_size=min_block_size,
make_refittable=True,
reuse_cached_engines=False,
)
torchtrt.save(trt_gm, trt_ep_path, inputs=inputs)
trt_gm = torch.export.load(trt_ep_path)
@@ -637,6 +648,7 @@ def test_refit_one_engine_python_runtime_without_weightmap():
debug=debug,
min_block_size=min_block_size,
make_refittable=True,
reuse_cached_engines=False,
)

new_trt_gm = refit_module_weights(
@@ -708,6 +720,7 @@ def forward(self, x):
min_block_size=min_block_size,
make_refittable=True,
torch_executed_ops=torch_executed_ops,
reuse_cached_engines=False,
)

new_trt_gm = refit_module_weights(
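The test changes follow one pattern throughout: compile the untrained model with `make_refittable=True` and `reuse_cached_engines=False`, refit it with the weights of the pretrained copy via `refit_module_weights`, and compare the refitted TensorRT outputs against the new weights run in eager mode. Below is a hedged sketch of that final verification step; the function and variable names are hypothetical and the tolerances are illustrative, not the tests' own helpers.

```python
import torch

def check_refit(exp_program2, new_trt_gm, inputs):
    # exp_program2: torch.export of the model whose weights were refit in (hypothetical name)
    # new_trt_gm:   refitted TensorRT graph module returned by refit_module_weights
    expected = exp_program2.module()(*inputs)
    refitted = new_trt_gm(*inputs)
    # ResNet-18 returns a single tensor; normalize to tuples so the loop also
    # covers models with multiple outputs.
    if isinstance(expected, torch.Tensor):
        expected, refitted = (expected,), (refitted,)
    for e, r in zip(expected, refitted):
        assert torch.allclose(e, r, atol=1e-2, rtol=1e-2), "Refitted outputs diverge from eager"
```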