diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 680c180f52a5..d9b8f903b534 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2819,6 +2819,10 @@ def slide_axes(inp, shape, ax): return out + def einsum(self, inputs, input_types): + equation, data = inputs + return _op.einsum(data, equation) + # Operator mappings def create_convert_map(self): self.convert_map = { @@ -3063,6 +3067,7 @@ def create_convert_map(self): "aten::searchsorted": self.searchsorted, "aten::bucketize": self.bucketize, "aten::roll": self.roll, + "aten::einsum": self.einsum, } def update_convert_map(self, custom_map): diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 5057f0d2b6b8..b30b0af20064 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -4020,5 +4020,17 @@ def test_fn(shifts, dims): verify_model(test_fn(shifts=(2, 1), dims=(0, 1)), [x]) +@tvm.testing.uses_gpu +def test_einsum(): + def test_fn(equation): + return lambda *x: torch.einsum(equation, *x) + + x = torch.ones([2, 3]) + y = torch.ones([3, 4]) + z = torch.ones([4, 5]) + verify_model(test_fn("ij,jk"), [x, y]) + verify_model(test_fn("ij,jk,km->im"), [x, y, z]) + + if __name__ == "__main__": pytest.main([__file__])