Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Frontend][Tensorflow] Add unique operator #7441

Merged
merged 16 commits into from
Feb 26, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,14 @@ struct CumsumAttrs : public tvm::AttrsNode<CumsumAttrs> {
}
};

/*! \brief Attributes used in unique operator */
struct UniqueAttrs : public tvm::AttrsNode<UniqueAttrs> {
  bool sorted;  // whether unique elements are returned in ascending order (default: true)
  TVM_DECLARE_ATTRS(UniqueAttrs, "relay.attrs.UniqueAttrs") {
    TVM_ATTR_FIELD(sorted).describe("Whether the unique elements are sorted").set_default(true);
  }
};  // struct UniqueAttrs

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_TRANSFORM_H_
15 changes: 15 additions & 0 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2324,6 +2324,20 @@ def _impl(inputs, attr, params, mod):
return _impl


def _unique():
    """Convert TensorFlow's Unique op to relay unique plus a dynamic slice."""

    def _impl(inputs, attr, params, mod):
        # tf.unique takes exactly one 1-D input tensor.
        assert len(inputs) == 1
        data = inputs[0]
        # relay unique returns full-length buffers plus the valid count;
        # is_sorted=False keeps first-occurrence order, matching TF semantics.
        unique, indices, num_uniq = _op.unique(data, is_sorted=False)
        # Trim the padded unique buffer down to the actual number of uniques.
        trimmed = _op.strided_slice(unique, begin=[0], end=num_uniq, slice_mode="size")
        outputs = _expr.Tuple([trimmed, indices])
        return _expr.TupleWrapper(outputs, 2)

    return _impl


# compatible operators that do NOT require any conversion.
_identity_list = []

Expand Down Expand Up @@ -2502,6 +2516,7 @@ def _impl(inputs, attr, params, mod):
"TopKV2": _topk(),
"Transpose": _transpose(),
"TruncateMod": _elemwise("mod"),
"Unique": _unique(),
"Unpack": _unpack(),
"UnravelIndex": _unravel_index(),
"Where": _where(),
Expand Down
28 changes: 28 additions & 0 deletions python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,15 @@ def compute_cumsum(attrs, inputs, output_type):
_reg.register_strategy("cumsum", strategy.cumsum_strategy)
_reg.register_shape_func("cumsum", False, elemwise_shape_func)


@_reg.register_compute("unique")
def compute_unique(attrs, inputs, output_type):
"""Compute definition of cumsum"""
return topi.unique(inputs[0], attrs.sorfted)


_reg.register_strategy("unique", strategy.unique_strategy)

#####################
# Shape functions #
#####################
Expand Down Expand Up @@ -946,3 +955,22 @@ def where_shape_func(attrs, inputs, _):
out_shape = _broadcast_shape_tensors(bcast_shape, cond_shape)

return [out_shape]


# Shape function for unique (hybrid script). All three output shapes are
# computed statically from the 1-D input shape: the unique and indices
# buffers are over-allocated to the full input length because the true
# number of unique elements is only known at runtime (via num_unique).
@script
def _unique_shape(data_shape):
    unique_shape = output_tensor((1,), "int64")
    indices_shape = output_tensor((1,), "int64")
    num_unique_shape = output_tensor((1,), "int64")
    # worst case: every element is unique
    unique_shape[0] = data_shape[0]
    indices_shape[0] = data_shape[0]
    # num_unique is a 1-element tensor holding the runtime count
    num_unique_shape[0] = int64(1)
    return (unique_shape, indices_shape, num_unique_shape)


@_reg.register_shape_func("unique", False)
def unique_shape_func(attrs, inputs, _):
"""
Shape func for unique operator.
"""
return _unique_shape(inputs[0])
21 changes: 21 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1432,3 +1432,24 @@ def cumsum_strategy(attrs, inputs, out_type, target):
name="cumsum.generic",
)
return strategy


def wrap_compute_unique(topi_compute):
    """Wrap a topi unique compute into the strategy compute signature.

    The returned callable ignores the output-type argument and forwards the
    single data tensor together with the ``sorted`` attribute.
    """

    def _compute_unique(attrs, inputs, _):
        data = inputs[0]
        return topi_compute(data, attrs.sorted)

    return _compute_unique


@override_native_generic_func("unique_strategy")
def unique_strategy(attrs, inputs, out_type, target):
"""unique generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_unique(topi.unique),
wrap_topi_schedule(topi.generic.schedule_unique),
name="unique.generic",
)
return strategy
33 changes: 33 additions & 0 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1463,3 +1463,36 @@ def cumsum(data, axis=None, dtype=None, exclusive=None):
-> [1, 1, 2, 2, 3, 4, 4]
"""
return _make.cumsum(data, axis, dtype, exclusive)


def unique(data, is_sorted=True):
    """
    Find the unique elements of a 1-D tensor.

    Parameters
    ----------
    data : relay.Expr
        A 1-D tensor of integers.

    is_sorted : bool
        Whether to sort the unique elements in ascending order before
        returning as output.

    Returns
    -------
    output : relay.Expr
        A 1-D tensor containing the unique elements of the input data tensor.
        It has the same length as the input; only the first ``num_unique``
        entries are valid, the remainder are undefined.

    indices : relay.Expr
        A 1-D tensor containing the index of each data element in the output
        tensor.

    num_unique : relay.Expr
        A 1-D tensor with one element, containing the number of unique
        elements in the input data tensor.

    Examples
    --------
    .. code-block:: python

        [output, indices, num_unique] = unique([4, 5, 1, 2, 3, 3, 4, 5], is_sorted=False)
        output = [4, 5, 1, 2, 3, ?, ?, ?]
        indices = [0, 1, 2, 3, 4, 4, 0, 1]
        num_unique = [5]

        [output, indices, num_unique] = unique([4, 5, 1, 2, 3, 3, 4, 5], is_sorted=True)
        output = [1, 2, 3, 4, 5, ?, ?, ?]
        indices = [3, 4, 0, 1, 2, 2, 3, 4]
        num_unique = [5]
    """
    return TupleWrapper(_make.unique(data, is_sorted), 3)
1 change: 1 addition & 0 deletions python/tvm/topi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from .argwhere import *
from .cumsum import *
from .einsum import *
from .unique import *
from . import generic
from . import nn
from . import x86
Expand Down
16 changes: 16 additions & 0 deletions python/tvm/topi/generic/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,19 @@ def schedule_scatter_add(outs):

def schedule_sparse_fill_empty_rows(outs):
return _default_schedule(outs, False)


def schedule_unique(outs):
    """Schedule for unique operator.

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of unique.

    Returns
    -------
    s: Schedule
        The computation schedule for the op.
    """
    # Generic fallback schedule; no target-specific optimization here.
    return _default_schedule(outs, False)
118 changes: 118 additions & 0 deletions python/tvm/topi/unique.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Unique operator"""
from ..te import hybrid
from .cumsum import cumsum
from .sort import sort, argsort


# Mark positions where the (sorted) input changes value: output[i] is 1 when
# data[i] differs from data[i-1], else 0; output[0] is always 0. A prefix sum
# of this mask assigns every element its unique-group id.
@hybrid.script
def _calc_adjacent_diff(data):
    output = output_tensor(data.shape, "int32")
    output[0] = int32(0)
    for i in range(1, data.shape[0]):
        output[i] = int32(1) if data[i] != data[i - 1] else int32(0)
    return output


# The last entry of the inclusive scan is the highest unique-group id, so
# adding one yields the number of unique elements (as a 1-element tensor).
@hybrid.script
def _calc_num_unique(data):
    output = output_tensor((1,), "int32")
    output[0] = data[data.shape[0] - 1] + 1
    return output


# Scatter unique values (in ascending order) and the inverse indices:
#   indices[j]          = unique-group id of original element j
#   unique_elements[g]  = value of group g
# Only the first num_unique entries of unique_elements are meaningful; the
# rest of the full-length buffer is undefined.
@hybrid.script
def _calc_unique_sorted(data, argsorted_indices, inc_scan):
    unique_elements = output_tensor(data.shape, data.dtype)
    indices = output_tensor(data.shape, "int32")
    for i in range(data.shape[0]):
        indices[argsorted_indices[i]] = inc_scan[i]
        unique_elements[inc_scan[i]] = data[argsorted_indices[i]]
    return unique_elements, indices


# For each unique-group id, record the smallest original index at which that
# value occurs. Unused slots keep the sentinel value len(data), which sorts
# after every real index in the later argsort.
@hybrid.script
def _calc_first_occurence(argsorted_indices, inc_scan):
    first_occurence = output_tensor(argsorted_indices.shape, "int32")
    for i in range(argsorted_indices.shape[0]):
        first_occurence[i] = argsorted_indices.shape[0]
    for i in range(argsorted_indices.shape[0]):
        if i == 0 or inc_scan[i] != inc_scan[i - 1]:
            first_occurence[inc_scan[i]] = argsorted_indices[i]
    return first_occurence


# Same scatter as the sorted variant, but each sorted-order group id is first
# remapped through index_converter so that unique elements appear in order of
# first occurrence in the original data.
@hybrid.script
def _calc_unique_unsorted(data, argsorted_indices, inc_scan, index_converter):
    unique_elements = output_tensor(data.shape, data.dtype)
    indices = output_tensor(data.shape, "int32")
    for i in range(data.shape[0]):
        new_unique_idx = index_converter[inc_scan[i]]
        new_data_idx = argsorted_indices[i]
        unique_elements[new_unique_idx] = data[new_data_idx]
        indices[new_data_idx] = new_unique_idx
    return unique_elements, indices


def unique(data, is_sorted=True):
    """
    Find the unique elements of a 1-D tensor.

    Parameters
    ----------
    data : tvm.te.Tensor
        A 1-D tensor of integers.

    is_sorted : bool
        Whether to sort the unique elements in ascending order before
        returning as output.

    Returns
    -------
    output : tvm.te.Tensor
        A 1-D tensor containing the unique elements of the input data tensor.
        Same length as the input; only the first ``num_unique`` entries are
        valid, the remainder are undefined.

    indices : tvm.te.Tensor
        A 1-D tensor containing the index of each data element in the output
        tensor.

    num_unique : tvm.te.Tensor
        A 1-D tensor with one element, containing the number of unique
        elements in the input data tensor.

    Examples
    --------
    .. code-block:: python

        [output, indices, num_unique] = unique([4, 5, 1, 2, 3, 3, 4, 5], is_sorted=False)
        output = [4, 5, 1, 2, 3, ?, ?, ?]
        indices = [0, 1, 2, 3, 4, 4, 0, 1]
        num_unique = [5]

        [output, indices, num_unique] = unique([4, 5, 1, 2, 3, 3, 4, 5], is_sorted=True)
        output = [1, 2, 3, 4, 5, ?, ?, ?]
        indices = [3, 4, 0, 1, 2, 2, 3, 4]
        num_unique = [5]
    """

    # Sort the values and remember where each sorted element came from.
    sorted_data = sort(data)
    argsorted_indices = argsort(data, dtype="int32")
    # 0/1 mask marking where the sorted sequence changes value.
    adjacent_diff = _calc_adjacent_diff(sorted_data)
    # Inclusive scan of the mask gives each element its unique-group id.
    inc_scan = cumsum(adjacent_diff, dtype="int32", exclusive=0)
    num_unique_elements = _calc_num_unique(inc_scan)
    if is_sorted:
        unique_elements, inverse_indices = _calc_unique_sorted(data, argsorted_indices, inc_scan)
    else:
        # Reorder groups by first occurrence in the original data:
        # argsort of an argsort yields the rank of each group id.
        first_occurence = _calc_first_occurence(argsorted_indices, inc_scan)
        argsorted_first_occurence = argsort(first_occurence, dtype="int32")
        index_converter = argsort(argsorted_first_occurence, dtype="int32")
        unique_elements, inverse_indices = _calc_unique_unsorted(
            data, argsorted_indices, inc_scan, index_converter
        )
    return [unique_elements, inverse_indices, num_unique_elements]
42 changes: 42 additions & 0 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3772,5 +3772,47 @@ RELAY_REGISTER_OP("cumsum")
.add_type_rel("Cumsum", CumsumRel)
.set_attr<TOpPattern>("TOpPattern", kOpaque);

// Register UniqueAttrs with the node system so it can be reflected.
TVM_REGISTER_NODE_TYPE(UniqueAttrs);

// Type relation for unique: input is a 1-D integer tensor; the result is a
// tuple (unique, indices, num_unique). unique and indices are typed at the
// full input length since the true count is only known at runtime.
bool UniqueRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
               const TypeReporter& reporter) {
  // types: [data, result]
  ICHECK_EQ(types.size(), 2) << "Unique: expect 2 types but " << types.size() << " provided";
  ICHECK_EQ(num_inputs, 1) << "Unique: expect 1 inputs but " << num_inputs << " provided";
  auto data = types[0].as<TensorTypeNode>();
  if (data == nullptr) {
    ICHECK(types[0].as<IncompleteTypeNode>())
        << "Unique: expect input type to be TensorType but get " << types[0];
    return false;  // input type not inferred yet; relation will be retried
  }
  const int ndim = static_cast<int>(data->shape.size());
  ICHECK_EQ(ndim, 1) << "Unique: input must be 1-D tensor";
  ICHECK_EQ(data->dtype.is_int(), true) << "Unique: input must have int32 or int64 dtype";
  std::vector<Type> fields;
  fields.push_back(TensorType(data->shape, data->dtype));               // unique
  fields.push_back(TensorType(data->shape, DataType::Int(32)));        // indices
  fields.push_back(TensorType(Array<PrimExpr>{1}, DataType::Int(32))); // num_unique
  reporter->Assign(types[1], TupleType(Array<Type>(fields)));
  return true;
}

// Construct a relay Call node for the unique op with the given `sorted` flag.
Expr MakeUnique(Expr data, bool sorted) {
  auto attrs = make_object<UniqueAttrs>();
  attrs->sorted = sorted;
  static const Op& op = Op::Get("unique");
  return Call(op, {data}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op._make.unique").set_body_typed(MakeUnique);

RELAY_REGISTER_OP("unique")
.describe(
R"code(This operation returns the unique elements and the new index of each item in a given 1-D array.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor")
.add_type_rel("unique", UniqueRel)
.set_support_level(3)
.set_attr<TOpPattern>("TOpPattern", kOpaque);
} // namespace relay
} // namespace tvm
31 changes: 31 additions & 0 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -4839,5 +4839,36 @@ def lstm_cell():
tvm.testing.assert_allclose(tf_output[i], tvm_output[i], atol=1e-5, rtol=1e-5)


#######################################################################
# Unique
# ------------


def _test_unique(n, dtype, is_dyn):
    """One iteration of tf.unique: compare TF and TVM outputs on a random 1-D tensor."""

    tf.reset_default_graph()
    np_data = np.random.randint(100, size=n).astype(dtype)
    with tf.Graph().as_default():
        if is_dyn:
            # dynamic shape path: feed data through a placeholder at runtime
            in_data = tf.placeholder(dtype, [n], name="in_data")
        else:
            in_data = tf.constant(np_data, dtype, name="in_data")
        tf.unique(in_data)
        if is_dyn:
            # dynamic output size requires the VM executor
            compare_tf_with_tvm(np_data, "in_data:0", ["Unique:0", "Unique:1"], mode="vm")
        else:
            compare_tf_with_tvm(None, "", ["Unique:0", "Unique:1"])


def test_forward_unique():
    """Exercise the Unique converter over both dtypes, static/dynamic shapes, and two sizes."""
    for elem_type in ("int32", "int64"):
        for dynamic in (False, True):
            for length in (50, 100):
                _test_unique(length, elem_type, dynamic)


if __name__ == "__main__":
pytest.main([__file__])
Loading