diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index c1bf69502ba8..1b86b120dfcc 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -559,6 +559,59 @@ def split_with_sizes(self, inputs, input_types): return _op.split(data, indices, dim) + def tensor_split(self, inputs, input_types): + # Reference: https://pytorch.org/docs/stable/generated/torch.tensor_split.html + import torch + + if not isinstance(inputs[1], (int, list, tuple, torch.Tensor)): + msg = "indices_or_sections type %s could not be parsed in tensor_split op" % ( + type(inputs[1]) + ) + raise AssertionError(msg) + + if isinstance(inputs[1], torch.Tensor) and not ( + list(inputs[1].shape) == [] or list(inputs[1].shape) == 1 + ): + msg = "indices_or_sections must be a zero-dimensional or one-dimensional long tensor" + raise AssertionError(msg) + + if isinstance(inputs[1], int) or ( + isinstance(inputs[1], torch.Tensor) and list(inputs[1].shape) == [] + ): + data = inputs[0] + n = int(inputs[1]) + dim = int(inputs[2]) + + split_size = int(self.infer_shape(data)[dim] / n) + split_rest = int(self.infer_shape(data)[dim] % n) + + indices = [] + split_index = split_size + if split_rest == 0: + for i in range(n - 1): + indices.append(split_index) + split_index += split_size + else: + for i in range(split_rest): + indices.append(split_index + 1) + split_index = (i + 1) * (split_index + 1) + for i in range(n - split_rest - 1): + split_index += split_size + indices.append(split_index) + + return _op.split(data, indices, dim) + else: + data = inputs[0] + sections = inputs[1] + dim = int(inputs[2]) + + if isinstance(sections, tuple): + sections = list(sections) + elif isinstance(sections, torch.Tensor): + sections = sections.cpu().numpy().tolist() + + return _op.split(data, sections, dim) + def select(self, inputs, input_types): data = inputs[0] dim = int(inputs[1]) @@ -3484,6 +3537,7 @@ def create_convert_map(self): "aten::slice": self.slice, "aten::narrow": self.narrow, "aten::split": self.split, + "aten::tensor_split": self.tensor_split, "aten::split_with_sizes": self.split_with_sizes, "aten::select": self.select, "aten::take": self.take, diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 33c70a4d74a4..3c8bd5efd80d 100755 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -959,6 +959,28 @@ def forward(self, *args): verify_model(Split([2, 3, 5], 1).float().eval(), input_data=input_data) +@tvm.testing.uses_gpu +def test_forward_tensor_split(): + """test_forward_tensor_split""" + torch.set_grad_enabled(False) + input_shape = [4, 10] + + class Tensor_Split(Module): + def __init__(self, split_size_or_sections, dim): + super().__init__() + self.split_size_or_sections = split_size_or_sections + self.dim = dim + + def forward(self, *args): + return torch.tensor_split(args[0], self.split_size_or_sections, self.dim) + + input_data = torch.rand(input_shape).float() + verify_model(Tensor_Split(2, 0).float().eval(), input_data=input_data) + verify_model(Tensor_Split(torch.tensor(3), 1).float().eval(), input_data=input_data) + verify_model(Tensor_Split([2, 3, 5], 1).float().eval(), input_data=input_data) + verify_model(Tensor_Split((2, 3, 5), 1).float().eval(), input_data=input_data) + + @tvm.testing.uses_gpu def test_forward_avgpool1d(): """test_forward_avgpool1d"""