Skip to content

Commit

Permalink
Fix small issue in pass_utils validation (#56)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: https://github.com/pytorch/fx2trt/pull/56

ATT

Reviewed By: khabinov

Differential Revision: D35649693

fbshipit-source-id: b3ef48142bdbeeca24cf12791f5946d8362d0564
  • Loading branch information
Shirong Wu authored and Wei Wei committed Jun 4, 2022
1 parent 2e3b265 commit 2cdaedb
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
4 changes: 2 additions & 2 deletions fx/passes/pass_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def pass_with_validation(
) -> fx.GraphModule:
res0 = module(*input)
processed_module = pass_(module, input)
res1 = module(*input)
res1 = processed_module(*input)

tensor_res_0 = _collect_tensors(res0)
tensor_res_1 = _collect_tensors(res1)
Expand All @@ -58,7 +58,7 @@ def pass_with_validation(
if not accuracy_check:
if suppress_accuracy_check_failure:
_LOGGER.error(f"pass {pass_} failed correctness check due to output {kk}, escape current pass.")
return module
return processed_module
else:
raise RuntimeError(f"pass {pass_} failed correctness check due to output {kk}")
return processed_module
Expand Down
3 changes: 2 additions & 1 deletion test/trt_lower/test_observer_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ class Model(nn.Module):
def forward(self, x, y):
return x + y

mod = Model()
mod = Model().cuda()
inp = [torch.rand(1, 10), torch.rand(1, 10)]
inp = [i.cuda() for i in inp]
mod(*inp)

with execution_verifier() as verify_execution:
Expand Down

0 comments on commit 2cdaedb

Please sign in to comment.