Skip to content

Commit

Permalink
add einsum in pytorch frontend (apache#9651)
Browse files Browse the repository at this point in the history
* add einsum in pytorch frontend

* add einsum in pytorch frontend
  • Loading branch information
Meteorix authored and ylc committed Jan 7, 2022
1 parent e079a74 commit 78daccf
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 0 deletions.
5 changes: 5 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2819,6 +2819,10 @@ def slide_axes(inp, shape, ax):

return out

def einsum(self, inputs, input_types):
equation, data = inputs
return _op.einsum(data, equation)

# Operator mappings
def create_convert_map(self):
self.convert_map = {
Expand Down Expand Up @@ -3063,6 +3067,7 @@ def create_convert_map(self):
"aten::searchsorted": self.searchsorted,
"aten::bucketize": self.bucketize,
"aten::roll": self.roll,
"aten::einsum": self.einsum,
}

def update_convert_map(self, custom_map):
Expand Down
12 changes: 12 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -4020,5 +4020,17 @@ def test_fn(shifts, dims):
verify_model(test_fn(shifts=(2, 1), dims=(0, 1)), [x])


@tvm.testing.uses_gpu
def test_einsum():
def test_fn(equation):
return lambda *x: torch.einsum(equation, *x)

x = torch.ones([2, 3])
y = torch.ones([3, 4])
z = torch.ones([4, 5])
verify_model(test_fn("ij,jk"), [x, y])
verify_model(test_fn("ij,jk,km->im"), [x, y, z])


if __name__ == "__main__":
pytest.main([__file__])

0 comments on commit 78daccf

Please sign in to comment.