diff --git a/fx/passes/pass_utils.py b/fx/passes/pass_utils.py index bba820fb3f..44c9d843a9 100644 --- a/fx/passes/pass_utils.py +++ b/fx/passes/pass_utils.py @@ -43,7 +43,7 @@ def pass_with_validation( ) -> fx.GraphModule: res0 = module(*input) processed_module = pass_(module, input) - res1 = module(*input) + res1 = processed_module(*input) tensor_res_0 = _collect_tensors(res0) tensor_res_1 = _collect_tensors(res1) @@ -58,7 +58,7 @@ def pass_with_validation( if not accuracy_check: if suppress_accuracy_check_failure: _LOGGER.error(f"pass {pass_} failed correctness check due to output {kk}, escape current pass.") - return module + return processed_module else: raise RuntimeError(f"pass {pass_} failed correctness check due to output {kk}") return processed_module diff --git a/test/trt_lower/test_observer_gpu.py b/test/trt_lower/test_observer_gpu.py index 40281484f5..266a7c23e0 100644 --- a/test/trt_lower/test_observer_gpu.py +++ b/test/trt_lower/test_observer_gpu.py @@ -24,8 +24,9 @@ class Model(nn.Module): def forward(self, x, y): return x + y - mod = Model() + mod = Model().cuda() inp = [torch.rand(1, 10), torch.rand(1, 10)] + inp = [i.cuda() for i in inp] mod(*inp) with execution_verifier() as verify_execution: