Skip to content

Commit

Permalink
[Op][Topi] Gather, GatherND, Take can accept unsigned integers as ind…
Browse files Browse the repository at this point in the history
…ices (#10080)

* take rel

* gather and more tests

* gathernd case

* lint

* remove test which invalidates take preconditions

* re-add test

* fix dumb test failure oopsie
  • Loading branch information
AndrewZhaoLuo authored Jan 29, 2022
1 parent 21154c2 commit 0fb5ae2
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 42 deletions.
4 changes: 2 additions & 2 deletions include/tvm/topi/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -1321,7 +1321,7 @@ inline Tensor gather(const Tensor& data, int axis, const Tensor& indices,
size_t indices_dim_i = static_cast<size_t>(GetConstInt(indices->shape[axis]));
ICHECK_GE(indices_dim_i, 1);
}
ICHECK(indices->dtype.is_int());
ICHECK(indices->dtype.is_int() || indices->dtype.is_uint());

Array<PrimExpr> out_shape;
for (size_t i = 0; i < ndim_i; ++i) {
Expand Down Expand Up @@ -1388,7 +1388,7 @@ inline Tensor gather_nd(const Tensor& data, const Tensor& indices, int batch_dim
}
for (size_t i = 0; i < indices_dim0; ++i) {
indices_position.Set(0, make_const(DataType::Int(32), i));
if (indices->dtype.is_int()) {
if (indices->dtype.is_int() || indices->dtype.is_uint()) {
real_indices.push_back(indices(indices_position));
} else {
real_indices.push_back(tvm::cast(tvm::DataType::Int(32), indices(indices_position)));
Expand Down
3 changes: 2 additions & 1 deletion src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1276,7 +1276,8 @@ bool TakeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
if (indices == nullptr) {
return false;
}
ICHECK(indices->dtype.is_int()) << "indices of take must be tensor of integer";
ICHECK(indices->dtype.is_int() || indices->dtype.is_uint())
<< "indices of take must be tensor of integer";
const auto param = attrs.as<TakeAttrs>();
ICHECK(param != nullptr);

Expand Down
75 changes: 44 additions & 31 deletions tests/python/relay/test_op_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,15 @@

import numpy as np
import pytest

import tvm
import tvm.testing

from tvm import relay, te
from tvm.error import TVMError
from tvm.relay import create_executor, transform
from tvm.relay.testing import check_grad, run_infer_type

from utils import ref_funcs


executor_kind = tvm.testing.parameter("graph", "debug")


Expand Down Expand Up @@ -426,31 +423,36 @@ def test_take(self, dshape, indices_shape, oshape, axis):


class TestTake:
src_shape, indices_src, axis, mode = tvm.testing.parameters(
((4,), [1], None, "clip"),
((4,), [[0, 1, 2, 3]], None, "clip"),
((3, 3, 3), [[11, 25]], None, "clip"),
((4,), [[0, 1], [2, 3]], None, "clip"),
((4,), [1], 0, "clip"),
((2, 2), [[[1, 0], [0, 1]]], 0, "clip"),
((2, 2), [[[1, 0], [0, 1]]], 1, "clip"),
((4, 3, 5, 6), [[2, 1, 0, 0]], -2, "clip"),
((3, 4), [-5, 20], None, "clip"),
((3, 4), [-5, 20], None, "wrap"),
((3, 4), [-1, 2], 0, "clip"),
((3, 4), [-1, 2], 0, "wrap"),
((3, 4), [-1, 2], 1, "clip"),
((3, 4), [-1, 2], 1, "wrap"),
((3, 3, 3), [[11, 25]], None, "fast"),
((3, 4), [0, 2], 0, "fast"),
((3, 4), [0, 2], 1, "fast"),
src_shape, indices_src, axis, mode, indices_dtype = tvm.testing.parameters(
((4,), [1], None, "clip", "int32"),
((4,), [[0, 1, 2, 3]], None, "clip", "int32"),
((3, 3, 3), [[11, 25]], None, "clip", "int32"),
((4,), [[0, 1], [2, 3]], None, "clip", "int32"),
((4,), [1], 0, "clip", "int32"),
((2, 2), [[[1, 0], [0, 1]]], 0, "clip", "int32"),
((2, 2), [[[1, 0], [0, 1]]], 1, "clip", "int32"),
((4, 3, 5, 6), [[2, 1, 0, 0]], -2, "clip", "int32"),
((3, 4), [-5, 20], None, "clip", "int32"),
((3, 4), [-5, 20], None, "wrap", "int32"),
((3, 4), [-1, 2], 0, "clip", "int32"),
((3, 4), [-1, 2], 0, "wrap", "int32"),
((3, 4), [-1, 2], 1, "clip", "int32"),
((3, 4), [-1, 2], 1, "wrap", "int32"),
((3, 3, 3), [[11, 25]], None, "fast", "int32"),
((3, 4), [0, 2], 0, "fast", "int32"),
((3, 4), [0, 2], 1, "fast", "int32"),
((3, 4), [1, 2], 1, "clip", "uint32"),
((3, 4), [1, 2], 1, "wrap", "uint16"),
((3, 3, 3), [1, 2], None, "fast", "uint16"),
((3, 4), [0, 2], 0, "fast", "uint8"),
)

# Incorrect numeric output in some cases on vulkan
@tvm.testing.known_failing_targets("vulkan")
def test_take(self, target, dev, executor_kind, src_shape, indices_src, axis, mode):
def test_take(
self, target, dev, executor_kind, src_shape, indices_src, axis, mode, indices_dtype
):
src_dtype = "float32"
indices_dtype = "int32"
indices_src = np.array(indices_src, dtype=indices_dtype)
x = relay.var("x", relay.TensorType(src_shape, src_dtype))
indices = relay.var("indices", relay.TensorType(indices_src.shape, indices_dtype))
Expand All @@ -459,11 +461,16 @@ def test_take(self, target, dev, executor_kind, src_shape, indices_src, axis, mo
func = relay.Function([x, indices], z)
x_data = np.random.uniform(low=-1, high=1, size=src_shape).astype(src_dtype)
np_mode = "raise" if mode == "fast" else mode
ref_res = np.take(x_data, indices=indices_src, axis=axis, mode=np_mode)

op_res = relay.create_executor(executor_kind, device=dev, target=target).evaluate(func)(
x_data, indices_src
)

# Old versions of numpy has take internally cast inside take which may violate
# safety rules. We have such version in i386 CI image.
indices_src = indices_src.astype("int32")
ref_res = np.take(x_data, indices=indices_src, axis=axis, mode=np_mode)

tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5)


Expand Down Expand Up @@ -1267,12 +1274,12 @@ def test_scatter_add(self, target, dev, ref_data, dshape, ishape, axis, dtype):
],
)
def test_gather(target, dev, executor_kind, data, axis, indices, ref_res):
def verify_gather(data, axis, indices, ref_res):
def verify_gather(data, axis, indices, ref_res, indices_dtype="int32"):
data = np.asarray(data, dtype="float32")
indices = np.asarray(indices, dtype="int32")
indices = np.asarray(indices, dtype=indices_dtype)
ref_res = np.asarray(ref_res)
d = relay.var("x", relay.TensorType(data.shape, "float32"))
i = relay.var("y", relay.TensorType(indices.shape, "int32"))
i = relay.var("y", relay.TensorType(indices.shape, indices_dtype))
z = relay.gather(d, axis, i)

func = relay.Function([d, i], z)
Expand All @@ -1283,22 +1290,25 @@ def verify_gather(data, axis, indices, ref_res):
tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5)

verify_gather(data, axis, indices, ref_res)
verify_gather(data, axis, indices, ref_res, indices_dtype="uint32")

verify_gather(data, axis, indices, ref_res)


def test_gather_nd(target, dev, executor_kind):
def verify_gather_nd(xshape, yshape, y_data, batch_dims=0):
def verify_gather_nd(xshape, yshape, y_data, batch_dims=0, indices_dtype="int32"):
x = relay.var("x", relay.TensorType(xshape, "float32"))
y = relay.var("y", relay.TensorType(yshape, "int32"))
y = relay.var("y", relay.TensorType(yshape, indices_dtype))
z = relay.gather_nd(x, y, batch_dims)

func = relay.Function([x, y], z)

x_data = np.random.uniform(size=xshape).astype("float32")

if y_data:
y_data = np.array(y_data, dtype="int32")
y_data = np.array(y_data, dtype=indices_dtype)
else:
y_data = np.random.randint(low=0, high=2, size=yshape, dtype="int32")
y_data = np.random.randint(low=0, high=2, size=yshape, dtype=indices_dtype)

ref_res = ref_funcs.gather_nd(x_data, y_data, batch_dims)

Expand Down Expand Up @@ -1335,6 +1345,9 @@ def verify_gather_nd(xshape, yshape, y_data, batch_dims=0):
verify_gather_nd((3, 2, 2, 3, 4), (2, 3, 2, 2), None, 2)
verify_gather_nd((3, 2, 2, 3, 4), (1, 3, 2, 3), None, 2)

verify_gather_nd((3, 2, 2, 3, 4), (1, 3, 2, 3), None, 2, indices_dtype="uint8")
verify_gather_nd((2, 2, 2), (2, 2, 1), [[[1], [0]], [[0], [1]]], 1, indices_dtype="uint32")


def _verify_infiniteness_ops(relay_op, ref_op):
for dtype in ["float32", "float16", "float16", "int32", "int16"]:
Expand Down
25 changes: 17 additions & 8 deletions tests/python/topi/python/test_topi_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,11 @@
import numpy as np
import pytest
import tvm
from tvm import te
from tvm import topi
from tvm import relay
import tvm.testing
import tvm.topi.testing
from tvm import relay, te, topi
from tvm.contrib.nvcc import have_fp16

import tvm.testing


def verify_expand_dims(in_shape, out_shape, axis, num_newaxis):
A = te.placeholder(shape=in_shape, name="A")
Expand Down Expand Up @@ -356,9 +353,8 @@ def check_device(target, dev):
)


def verify_take(src_shape, indices_src, axis=None, mode="clip"):
def verify_take(src_shape, indices_src, axis=None, mode="clip", indices_dtype="int32"):
src_dtype = "float32"
indices_dtype = "int32"
indices_src = np.array(indices_src, dtype=indices_dtype)
A = te.placeholder(shape=src_shape, dtype=src_dtype, name="A")
indices = te.placeholder(shape=indices_src.shape, dtype=indices_dtype, name="indices")
Expand Down Expand Up @@ -999,6 +995,9 @@ def test_take():
verify_take((3, 3, 3), [[11, 25]], mode="fast")
verify_take((3, 4), [0, 2], axis=0, mode="fast")
verify_take((3, 4), [0, 2], axis=1, mode="fast")
verify_take((3, 4), [1, 2], axis=1, indices_dtype="uint32")
verify_take((3, 4), [1, 2], axis=1, mode="wrap", indices_dtype="uint16")
verify_take((3, 3, 3), [[11, 20]], mode="fast", indices_dtype="uint8")


@tvm.testing.uses_gpu
Expand All @@ -1010,11 +1009,21 @@ def test_gather():
verify_gather(np.random.randn(4, 7, 5), 1, np.random.randint(low=0, high=7, size=(4, 10, 5)))
verify_gather(np.random.randn(4, 7, 5), 2, np.random.randint(low=0, high=5, size=(4, 7, 2)))
verify_gather(np.random.randn(4, 7, 5), 2, np.random.randint(low=0, high=5, size=(4, 7, 10)))
verify_gather(
np.random.randn(4, 7, 5),
2,
np.random.randint(low=0, high=5, size=(4, 7, 10)).astype("uint32"),
)
verify_gather(
np.random.randn(4, 7, 5),
2,
np.random.randint(low=0, high=5, size=(4, 7, 10)).astype("uint8"),
)


@tvm.testing.uses_gpu
def test_gather_nd():
for indices_dtype in ["int32", "float32"]:
for indices_dtype in ["int32", "float32", "uint8"]:
verify_gather_nd((4,), [[1.8]], indices_dtype)
verify_gather_nd((4,), [[1, 3, 2]], indices_dtype)
verify_gather_nd((2, 3), [[1]], indices_dtype)
Expand Down

0 comments on commit 0fb5ae2

Please sign in to comment.