Skip to content

Commit

Permalink
Changed test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
cehongwang committed Sep 20, 2024
1 parent 414d972 commit d3b2c04
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 11 deletions.
1 change: 1 addition & 0 deletions examples/dynamo/refit_engine_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
min_block_size=min_block_size,
torch_executed_ops=torch_executed_ops,
make_refitable=True,
reuse_cached_engines=False,
) # Output is a torch.fx.GraphModule

# Save the graph module as an exported program
Expand Down
35 changes: 24 additions & 11 deletions tests/py/dynamo/models/test_model_refit.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@
@pytest.mark.unit
def test_mapping():

model = models.resnet18(pretrained=True).eval().to("cuda")
model2 = models.resnet18(pretrained=False).eval().to("cuda")
model = models.resnet18(pretrained=False).eval().to("cuda")
model2 = models.resnet18(pretrained=True).eval().to("cuda")
inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
trt_input = [
torchtrt.Input(i.shape, dtype=torch.float, format=torch.contiguous_format)
Expand All @@ -58,6 +58,7 @@ def test_mapping():
debug=debug,
min_block_size=min_block_size,
make_refitable=True,
reuse_cached_engines=False,
)
settings = trt_gm._run_on_acc_0.settings
runtime = trt.Runtime(TRT_LOGGER)
Expand Down Expand Up @@ -110,6 +111,7 @@ def test_refit_one_engine_with_weightmap():
debug=debug,
min_block_size=min_block_size,
make_refitable=True,
reuse_cached_engines=False,
)

new_trt_gm = refit_module_weights(
Expand Down Expand Up @@ -141,8 +143,8 @@ def test_refit_one_engine_with_weightmap():
@pytest.mark.unit
def test_refit_one_engine_no_map_with_weightmap():

model = models.resnet18(pretrained=True).eval().to("cuda")
model2 = models.resnet18(pretrained=False).eval().to("cuda")
model = models.resnet18(pretrained=False).eval().to("cuda")
model2 = models.resnet18(pretrained=True).eval().to("cuda")
inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
enabled_precisions = {torch.float}
debug = False
Expand All @@ -160,6 +162,7 @@ def test_refit_one_engine_no_map_with_weightmap():
debug=debug,
min_block_size=min_block_size,
make_refitable=True,
reuse_cached_engines=False,
)

trt_gm._run_on_acc_0.weight_name_map = None
Expand Down Expand Up @@ -192,8 +195,8 @@ def test_refit_one_engine_no_map_with_weightmap():
@pytest.mark.unit
def test_refit_one_engine_with_wrong_weightmap():

model = models.resnet18(pretrained=True).eval().to("cuda")
model2 = models.resnet18(pretrained=False).eval().to("cuda")
model = models.resnet18(pretrained=False).eval().to("cuda")
model2 = models.resnet18(pretrained=True).eval().to("cuda")
inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
enabled_precisions = {torch.float}
debug = False
Expand All @@ -211,6 +214,7 @@ def test_refit_one_engine_with_wrong_weightmap():
debug=debug,
min_block_size=min_block_size,
make_refitable=True,
reuse_cached_engines=False,
)
# Manually deleted all batch norm layers. This is supposed to make the fast refit fail
trt_gm._run_on_acc_0.weight_name_map = {
Expand Down Expand Up @@ -268,6 +272,7 @@ def test_refit_one_engine_bert_with_weightmap():
debug=debug,
min_block_size=min_block_size,
make_refitable=True,
reuse_cached_engines=False,
)

new_trt_gm = refit_module_weights(
Expand Down Expand Up @@ -302,8 +307,8 @@ def test_refit_one_engine_bert_with_weightmap():
@pytest.mark.unit
def test_refit_one_engine_inline_runtime__with_weightmap():
trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
model = models.resnet18(pretrained=True).eval().to("cuda")
model2 = models.resnet18(pretrained=False).eval().to("cuda")
model = models.resnet18(pretrained=False).eval().to("cuda")
model2 = models.resnet18(pretrained=True).eval().to("cuda")
inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
enabled_precisions = {torch.float}
debug = False
Expand All @@ -321,6 +326,7 @@ def test_refit_one_engine_inline_runtime__with_weightmap():
debug=debug,
min_block_size=min_block_size,
make_refitable=True,
reuse_cached_engines=False,
)
torchtrt.save(trt_gm, trt_ep_path, inputs=inputs)
trt_gm = torch.export.load(trt_ep_path)
Expand Down Expand Up @@ -348,8 +354,8 @@ def test_refit_one_engine_inline_runtime__with_weightmap():
@pytest.mark.unit
def test_refit_one_engine_python_runtime_with_weightmap():

model = models.resnet18(pretrained=True).eval().to("cuda")
model2 = models.resnet18(pretrained=False).eval().to("cuda")
model = models.resnet18(pretrained=False).eval().to("cuda")
model2 = models.resnet18(pretrained=True).eval().to("cuda")
inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
enabled_precisions = {torch.float}
debug = False
Expand All @@ -367,6 +373,7 @@ def test_refit_one_engine_python_runtime_with_weightmap():
debug=debug,
min_block_size=min_block_size,
make_refitable=True,
reuse_cached_engines=False,
)

new_trt_gm = refit_module_weights(
Expand Down Expand Up @@ -438,6 +445,7 @@ def forward(self, x):
min_block_size=min_block_size,
make_refitable=True,
torch_executed_ops=torch_executed_ops,
reuse_cached_engines=False,
)

new_trt_gm = refit_module_weights(
Expand Down Expand Up @@ -487,10 +495,11 @@ def test_refit_one_engine_without_weightmap():
debug=debug,
min_block_size=min_block_size,
make_refitable=True,
reuse_cached_engines=False,
)

new_trt_gm = refit_module_weights(
compiled_module=trt_gm,
compiled_module=trt_gm,
new_weight_module=exp_program2,
arg_inputs=inputs,
use_weight_map_cache=False,
Expand Down Expand Up @@ -538,6 +547,7 @@ def test_refit_one_engine_bert_without_weightmap():
debug=debug,
min_block_size=min_block_size,
make_refitable=True,
reuse_cached_engines=False,
)

new_trt_gm = refit_module_weights(
Expand Down Expand Up @@ -591,6 +601,7 @@ def test_refit_one_engine_inline_runtime_without_weightmap():
debug=debug,
min_block_size=min_block_size,
make_refitable=True,
reuse_cached_engines=False,
)
torchtrt.save(trt_gm, trt_ep_path, inputs=inputs)
trt_gm = torch.export.load(trt_ep_path)
Expand Down Expand Up @@ -637,6 +648,7 @@ def test_refit_one_engine_python_runtime_without_weightmap():
debug=debug,
min_block_size=min_block_size,
make_refitable=True,
reuse_cached_engines=False,
)

new_trt_gm = refit_module_weights(
Expand Down Expand Up @@ -708,6 +720,7 @@ def forward(self, x):
min_block_size=min_block_size,
make_refitable=True,
torch_executed_ops=torch_executed_ops,
reuse_cached_engines=False,
)

new_trt_gm = refit_module_weights(
Expand Down

0 comments on commit d3b2c04

Please sign in to comment.