diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py index c1b3708c92..f0bca786bc 100644 --- a/tests/py/dynamo/conversion/harness.py +++ b/tests/py/dynamo/conversion/harness.py @@ -247,10 +247,8 @@ def run_test( trt_inputs = inputs for num_input in range(num_inputs): input = inputs[num_input] - if input.dtype in (torch.int64, torch.float64): - dtype_32bit = ( - torch.int32 if (input.dtype == torch.int64) else torch.float32 - ) + if input.dtype is torch.float64: + dtype_32bit = torch.float32 # should we modify graph here to insert clone nodes? # ideally not required trt_inputs = (