diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index f7538f0837c6..16dd0c447124 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1972,10 +1972,7 @@ def stack(self, inputs, input_types): return self.tensor_array_stack(inputs, input_types) def rsub(self, inputs, input_types): - data0, data1 = self.pytorch_promote_types(inputs[:2], input_types[:2]) - - # TODO (t-vi): should this also be part of the type promotion? - alpha = _expr.const(float(inputs[2])) + data0, data1, alpha = self.pytorch_promote_types(inputs, input_types) # note: rsub means data0 and data1 swap places return get_relay_op("subtract")(data1, alpha * data0) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 3fbef494f16d..a02701b5278a 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -2691,6 +2691,13 @@ def forward(self, *args): verify_model(Rsub2().float().eval(), input_data=[d1, d2]) verify_model(Rsub2().float().eval(), input_data=[d1, d3]) + d1 = torch.rand([1, 3]).half() + d2 = torch.rand([1, 3]).half() + verify_model(Rsub1().half().eval(), input_data=[d1, d2]) + verify_model(Rsub1().half().eval(), input_data=[d1, d3]) + verify_model(Rsub2().half().eval(), input_data=[d1, d2]) + verify_model(Rsub2().half().eval(), input_data=[d1, d3]) + @tvm.testing.uses_gpu def test_forward_embedding():