Skip to content

Commit

Permalink
[frontend][pytorch]Support aten::Tensor_split operator (#12871)
Browse files Browse the repository at this point in the history
Support aten::Tensor_split operator
  • Loading branch information
chengven027 authored Sep 26, 2022
1 parent cc6e01e commit 87085b0
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 0 deletions.
54 changes: 54 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,59 @@ def split_with_sizes(self, inputs, input_types):

return _op.split(data, indices, dim)

def tensor_split(self, inputs, input_types):
# Reference: https://pytorch.org/docs/stable/generated/torch.tensor_split.html
import torch

if not isinstance(inputs[1], (int, list, tuple, torch.Tensor)):
msg = "indices_or_sections type %s could not be parsed in tensor_split op" % (
type(inputs[1])
)
raise AssertionError(msg)

if isinstance(inputs[1], torch.Tensor) and not (
list(inputs[1].shape) == [] or list(inputs[1].shape) == 1
):
msg = "indices_or_sections must be a zero-dimensional or one-dimensional long tensor"
raise AssertionError(msg)

if isinstance(inputs[1], int) or (
isinstance(inputs[1], torch.Tensor) and list(inputs[1].shape) == []
):
data = inputs[0]
n = int(inputs[1])
dim = int(inputs[2])

split_size = int(self.infer_shape(data)[dim] / n)
split_rest = int(self.infer_shape(data)[dim] % n)

indices = []
split_index = split_size
if split_rest == 0:
for i in range(n - 1):
indices.append(split_index)
split_index += split_size
else:
for i in range(split_rest):
indices.append(split_index + 1)
split_index = (i + 1) * (split_index + 1)
for i in range(n - split_rest - 1):
split_index += split_size
indices.append(split_index)

return _op.split(data, indices, dim)
else:
data = inputs[0]
sections = inputs[1]
dim = int(inputs[2])

if isinstance(sections, tuple):
sections = list(sections)
elif isinstance(sections, torch.Tensor):
sections = sections.cpu().numpy().tolist()

return _op.split(data, sections, dim)

def select(self, inputs, input_types):
data = inputs[0]
dim = int(inputs[1])
Expand Down Expand Up @@ -3484,6 +3537,7 @@ def create_convert_map(self):
"aten::slice": self.slice,
"aten::narrow": self.narrow,
"aten::split": self.split,
"aten::tensor_split": self.tensor_split,
"aten::split_with_sizes": self.split_with_sizes,
"aten::select": self.select,
"aten::take": self.take,
Expand Down
22 changes: 22 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -959,6 +959,28 @@ def forward(self, *args):
verify_model(Split([2, 3, 5], 1).float().eval(), input_data=input_data)


@tvm.testing.uses_gpu
def test_forward_tensor_split():
"""test_forward_tensor_split"""
torch.set_grad_enabled(False)
input_shape = [4, 10]

class Tensor_Split(Module):
def __init__(self, split_size_or_sections, dim):
super().__init__()
self.split_size_or_sections = split_size_or_sections
self.dim = dim

def forward(self, *args):
return torch.tensor_split(args[0], self.split_size_or_sections, self.dim)

input_data = torch.rand(input_shape).float()
verify_model(Tensor_Split(2, 0).float().eval(), input_data=input_data)
verify_model(Tensor_Split(torch.tensor(3), 1).float().eval(), input_data=input_data)
verify_model(Tensor_Split([2, 3, 5], 1).float().eval(), input_data=input_data)
verify_model(Tensor_Split((2, 3, 5), 1).float().eval(), input_data=input_data)


@tvm.testing.uses_gpu
def test_forward_avgpool1d():
"""test_forward_avgpool1d"""
Expand Down

0 comments on commit 87085b0

Please sign in to comment.