diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index c96b48d656a7..759d6235191b 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -460,6 +460,18 @@ struct CumsumAttrs : public tvm::AttrsNode { } }; +/*! \brief Attributes used in unique operator */ +struct UniqueAttrs : public tvm::AttrsNode { + bool sorted; + bool return_counts; + TVM_DECLARE_ATTRS(UniqueAttrs, "relay.attrs.UniqueAttrs") { + TVM_ATTR_FIELD(sorted).describe("Whether the unique elements are sorted").set_default(true); + TVM_ATTR_FIELD(return_counts) + .describe("Whether to return an additional tensor with counts of each unique elements") + .set_default(false); + } +}; // struct UniqueAttrs + } // namespace relay } // namespace tvm #endif // TVM_RELAY_ATTRS_TRANSFORM_H_ diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 931611274c20..679541051e75 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2157,6 +2157,24 @@ def is_floating_point(self, inputs, input_types): is_float = input_type in ["float32", "float64", "float16", "bfloat16"] return _expr.const(is_float) + def unique(self, inputs, input_types): + assert len(inputs) == 4 + [data, is_sorted, return_inverse, return_counts] = inputs + if not is_sorted: + logging.warning("TVM always assumes sorted=True for torch.unique") + is_sorted = True + if return_counts: + [unique, indices, num_uniq, counts] = _op.unique( + data, is_sorted=is_sorted, return_counts=True + ) + unique_sliced = _op.strided_slice(unique, begin=[0], end=num_uniq, slice_mode="size") + counts_sliced = _op.strided_slice(counts, begin=[0], end=num_uniq, slice_mode="size") + return (unique_sliced, indices, counts_sliced) + else: + [unique, indices, num_uniq] = _op.unique(data, is_sorted=is_sorted, return_counts=False) + unique_sliced = _op.strided_slice(unique, begin=[0], end=num_uniq, slice_mode="size") + return (unique_sliced, indices) + # Operator mappings def create_convert_map(self): self.convert_map = { @@ -2363,6 +2381,7 @@ def create_convert_map(self): "aten::masked_select": self.masked_select, "aten::argsort": self.argsort, "aten::sort": self.sort, + "aten::_unique2": self.unique, } def update_convert_map(self, custom_map): diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index ab98cddd3835..65f18c029441 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -2471,6 +2471,30 @@ def _impl(inputs, attr, params, mod): return _impl +def _unique(return_counts=True): + def _impl(inputs, attr, params, mod): + assert len(inputs) == 1 + data = inputs[0] + if return_counts: + [unique, indices, num_uniq, counts] = _op.unique( + data, is_sorted=False, return_counts=True + ) + unique_sliced = _op.strided_slice(unique, begin=[0], end=num_uniq, slice_mode="size") + counts_sliced = _op.strided_slice(counts, begin=[0], end=num_uniq, slice_mode="size") + return _expr.TupleWrapper( + _expr.Tuple([unique_sliced, indices, counts_sliced]), + 3, + ) + [unique, indices, num_uniq] = _op.unique(data, is_sorted=False, return_counts=False) + unique_sliced = _op.strided_slice(unique, begin=[0], end=num_uniq, slice_mode="size") + return _expr.TupleWrapper( + _expr.Tuple([unique_sliced, indices]), + 2, + ) + + return _impl + + # compatible operators that do NOT require any conversion. _identity_list = [] @@ -2650,6 +2674,8 @@ def _impl(inputs, attr, params, mod): "TopKV2": _topk(), "Transpose": _transpose(), "TruncateMod": _elemwise("mod"), + "Unique": _unique(False), + "UniqueWithCounts": _unique(True), "Unpack": _unpack(), "UnravelIndex": _unravel_index(), "Where": _where(), diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index c41f956493b5..98797f06c7af 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -151,6 +151,15 @@ def compute_cumsum(attrs, inputs, output_type): _reg.register_strategy("cumsum", strategy.cumsum_strategy) _reg.register_shape_func("cumsum", False, elemwise_shape_func) + +@_reg.register_compute("unique") +def compute_unique(attrs, inputs, output_type): + """Compute definition of unique""" + return topi.unique(inputs[0], attrs.sorted, attrs.return_counts) + + +_reg.register_strategy("unique", strategy.unique_strategy) + ##################### # Shape functions # ##################### @@ -957,3 +966,38 @@ def where_shape_func(attrs, inputs, _): out_shape = _broadcast_shape_tensors(bcast_shape, cond_shape) return [out_shape] + + +@script +def _unique_shape(data_shape): + unique_shape = output_tensor((1,), "int64") + indices_shape = output_tensor((1,), "int64") + num_unique_shape = output_tensor((1,), "int64") + unique_shape[0] = data_shape[0] + indices_shape[0] = data_shape[0] + num_unique_shape[0] = int64(1) + return (unique_shape, indices_shape, num_unique_shape) + + +@script +def _unique_with_counts_shape(data_shape): + unique_shape = output_tensor((1,), "int64") + indices_shape = output_tensor((1,), "int64") + num_unique_shape = output_tensor((1,), "int64") + counts_shape = output_tensor((1,), "int64") + unique_shape[0] = data_shape[0] + indices_shape[0] = data_shape[0] + num_unique_shape[0] = int64(1) + counts_shape[0] = data_shape[0] + return (unique_shape, indices_shape, num_unique_shape, counts_shape) + + +@_reg.register_shape_func("unique", False) +def unique_shape_func(attrs, inputs, _): + """ + Shape func for unique operator. + """ + if attrs.return_counts: + return _unique_with_counts_shape(inputs[0]) + else: + return _unique_shape(inputs[0]) diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index 20c5f03b9b0b..3abc9c42b659 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -1009,3 +1009,15 @@ def cumsum_strategy_cuda(attrs, inputs, out_type, target): name="cumsum.cuda", ) return strategy + + +@unique_strategy.register(["cuda", "gpu"]) +def unique_strategy_cuda(attrs, inputs, out_type, target): + """unique cuda strategy""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_unique(topi.cuda.unique), + wrap_topi_schedule(topi.cuda.schedule_scan), + name="unique.cuda", + ) + return strategy diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index da8107f961e8..a5773a4ada5f 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -1440,3 +1440,24 @@ def cumsum_strategy(attrs, inputs, out_type, target): name="cumsum.generic", ) return strategy + + +def wrap_compute_unique(topi_compute): + """Wrap unique topi compute""" + + def _compute_unique(attrs, inputs, _): + return topi_compute(inputs[0], attrs.sorted, attrs.return_counts) + + return _compute_unique + + +@override_native_generic_func("unique_strategy") +def unique_strategy(attrs, inputs, out_type, target): + """unique generic strategy""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_unique(topi.unique), + wrap_topi_schedule(topi.generic.schedule_unique), + name="unique.generic", + ) + return strategy diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 3583c5227db4..f34e062cd9e6 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -1502,3 +1502,57 @@ def cumsum(data, axis=None, dtype=None, exclusive=None): -> [1, 1, 2, 2, 3, 4, 4] """ return _make.cumsum(data, axis, dtype, exclusive) + + +def unique(data, is_sorted=True, return_counts=False): + """ + Find the unique elements of a 1-D tensor. Please note `output` and `counts` are all padded to + have the same length of `data` and element with index >= num_unique[0] has undefined value. + + Parameters + ---------- + data : relay.Expr + A 1-D tensor of integers. + + sorted : bool + Whether to sort the unique elements in ascending order before returning as output. + + return_counts : bool + Whether to return the count of each unique element. + + Returns + ------- + output : relay.Expr + A 1-D tensor containing the unique elements of the input data tensor. + + indices : relay.Expr + A 1-D tensor containing the index of each data element in the output tensor. + + num_unique : relay.Expr + A 1-D tensor with size=1 containing the number of unique elements in the input data tensor. + + counts (optional) : relay.Expr + A 1-D tensor containing the count of each unique element in the output. + + Examples + -------- + .. code-block:: python + [output, indices, num_unique] = unique([4, 5, 1, 2, 3, 3, 4, 5], False, False) + output = [4, 5, 1, 2, 3, ?, ?, ?] + indices = [0, 1, 2, 3, 4, 4, 0, 1] + num_unique = [5] + + [output, indices, num_unique, counts] = unique([4, 5, 1, 2, 3, 3, 4, 5], False, True) + output = [4, 5, 1, 2, 3, ?, ?, ?] + indices = [0, 1, 2, 3, 4, 4, 0, 1] + num_unique = [5] + counts = [2, 2, 1, 1, 2, ?, ?, ?] + + [output, indices, num_unique] = unique([4, 5, 1, 2, 3, 3, 4, 5], True) + output = [1, 2, 3, 4, 5, ?, ?, ?] + indices = [3, 4, 0, 1, 2, 2, 3, 4] + num_unique = [5] + """ + if return_counts: + return TupleWrapper(_make.unique(data, is_sorted, return_counts), 4) + return TupleWrapper(_make.unique(data, is_sorted, return_counts), 3) diff --git a/python/tvm/topi/__init__.py b/python/tvm/topi/__init__.py index f30e20c31281..b47facfafd2d 100644 --- a/python/tvm/topi/__init__.py +++ b/python/tvm/topi/__init__.py @@ -44,6 +44,7 @@ from .interpolate import * from .cumsum import * from .einsum import * +from .unique import * from . import generic from . import nn from . import x86 diff --git a/python/tvm/topi/cuda/__init__.py b/python/tvm/topi/cuda/__init__.py index bf3582c01d4f..df75c676fad3 100644 --- a/python/tvm/topi/cuda/__init__.py +++ b/python/tvm/topi/cuda/__init__.py @@ -58,3 +58,4 @@ from . import tensorcore_alter_op from .argwhere import * from .scan import * +from .unique import * diff --git a/python/tvm/topi/cuda/unique.py b/python/tvm/topi/cuda/unique.py new file mode 100644 index 000000000000..02a5cf3bc592 --- /dev/null +++ b/python/tvm/topi/cuda/unique.py @@ -0,0 +1,396 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Unique operator""" +import tvm +from tvm import te, tir +from ...te import hybrid +from .scan import cumsum +from .sort import sort, argsort +from ..utils import ceil_div + + +def _calc_adjacent_diff_ir(data, output, binop=tir.Sub): + """Low level IR to calculate adjacent difference in an 1-D array. + + Parameters + ---------- + data : Buffer + Input 1-D Buffer. + + output: Buffer + A buffer to store adjacent difference, of the same shape as data. The adjacent difference + is defined as: output[0] = 0, output[i] = binop(data[i], data[i-1]) + where i > 0 and i < len(data). + + binop: function, optional + A binary associative op to use for calculating adjacent difference. The function takes two + TIR expressions and produce a new TIR expression. By default it uses tvm.tir.Sub to + compute the adjacent difference. + """ + ib = tir.ir_builder.create() + data_ptr = ib.buffer_ptr(data) + output_ptr = ib.buffer_ptr(output) + batch_size = data.shape[0] + max_threads = tir.min(batch_size, tvm.target.Target.current(allow_none=False).max_num_threads) + with ib.new_scope(): + nthread_tx = max_threads + nthread_bx = ceil_div(batch_size, max_threads) + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * max_threads + tx + with ib.if_scope(tid < batch_size): + with ib.if_scope(tid == 0): + output_ptr[tid] = 0 + with ib.else_scope(): + output_ptr[tid] = tir.Cast(output.dtype, binop(data_ptr[tid], data_ptr[tid - 1])) + return ib.get() + + +def _calc_adjacent_diff(data, out_dtype="int32", binop=tir.Sub): + """Function calculate adjacent difference in an 1-D array. + + Parameters + ---------- + data : tvm.te.Tensor + Input 1-D tensor. + + output_dtype : str + The output tensor data type. + + binop: function, optional + A binary associative op to use for calculating difference. The function takes two + TIR expressions and produce a new TIR expression. By default it uses tvm.tir.Sub to + compute the adjacent difference. + + Returns + ------- + output : tvm.te.Tensor + 1-D tensor storing the adjacent difference of the input tensor. The adjacent difference + is defined as: output[0] = 0, output[i] = binop(data[i], data[i-1]) + where i > 0 and i < len(data). + """ + data_buf = tir.decl_buffer(data.shape, data.dtype, "sorted_data_buf", data_alignment=8) + output_buf = tir.decl_buffer(data.shape, out_dtype, "output_buf", data_alignment=8) + return te.extern( + [data.shape], + [data], + lambda ins, outs: _calc_adjacent_diff_ir(ins[0], outs[0], binop=binop), + dtype=[out_dtype], + in_buffers=[data_buf], + out_buffers=[output_buf], + name="_calc_adjacent_diff", + tag="_calc_adjacent_diff_gpu", + ) + + +@hybrid.script +def _calc_num_unique(inc_scan): + """Helper function to get the number of unique elements fron inc_scan tensor""" + output = output_tensor((1,), "int32") + for i in bind("threadIdx.x", 1): + output[i] = inc_scan[inc_scan.shape[0] - 1] + int32(1) + return output + + +def _calc_unique_ir( + data, argsorted_indices, inc_scan, index_converter, unique_elements, indices, counts +): + """Low level IR to calculate unique elements, inverse indices, and counts (optional) of + unique elements of 1-D array. + + Parameters + ---------- + data : Buffer + Input 1-D Buffer. + + argsorted_indices : Buffer + A buffer that stores the argsorted indices of the input data. + + inc_scan : Buffer + A buffer that stores the inclusive scan of the binary tir.NE adjacent difference + of the sorted data. + + index_converter (optional) : Buffer + An optional index converter that transforms the unique element index + such that new_idx = index_converter[old_idx]. + + unique_elements : Buffer + A buffer that stores the unique elements. + + indices : Buffer + A buffer that stores the the index of each input data element in the unique element array. + + counts (optional) : Buffer + A buffer that stores the count of each unique element. + """ + ib = tir.ir_builder.create() + data_ptr = ib.buffer_ptr(data) + argsorted_indices_ptr = ib.buffer_ptr(argsorted_indices) + inc_scan_ptr = ib.buffer_ptr(inc_scan) + unique_elements_ptr = ib.buffer_ptr(unique_elements) + indices_ptr = ib.buffer_ptr(indices) + + index_converter_ptr = None + if isinstance(index_converter, tir.Buffer): + index_converter_ptr = ib.buffer_ptr(index_converter) + + if isinstance(counts, tir.Buffer): + counts_ptr = ib.buffer_ptr(counts) + # use indices_ptr as a tmp buffer to store tids with inc_scan[tid] != inc_scan[tid-1] + unique_seq_indices_ptr = ib.buffer_ptr(indices) + + batch_size = data.shape[0] + max_threads = tir.min(batch_size, tvm.target.Target.current(allow_none=False).max_num_threads) + + # if need to return counts + if isinstance(counts, tir.Buffer): + num_unique = inc_scan_ptr[inc_scan.shape[0] - 1] + 1 + num_elements = data.shape[0] + with ib.new_scope(): + nthread_tx = max_threads + nthread_bx = ceil_div(batch_size, max_threads) + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * max_threads + tx + with ib.if_scope(tid < batch_size): + with ib.if_scope(tid == 0): + unique_seq_indices_ptr[num_unique - 1] = num_elements + with ib.else_scope(): + with ib.if_scope(inc_scan_ptr[tid] != inc_scan_ptr[tid - 1]): + unique_seq_indices_ptr[inc_scan_ptr[tid] - 1] = tid + with ib.new_scope(): + nthread_tx = max_threads + nthread_bx = ceil_div(batch_size, max_threads) + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * max_threads + tx + with ib.if_scope(tid < num_unique): + unique_idx = tid if not index_converter_ptr else index_converter_ptr[tid] + with ib.if_scope(tid == 0): + counts_ptr[unique_idx] = unique_seq_indices_ptr[tid] + with ib.else_scope(): + counts_ptr[unique_idx] = ( + unique_seq_indices_ptr[tid] - unique_seq_indices_ptr[tid - 1] + ) + # calculate unique elements and inverse indices + with ib.new_scope(): + nthread_tx = max_threads + nthread_bx = ceil_div(batch_size, max_threads) + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * max_threads + tx + with ib.if_scope(tid < batch_size): + data_idx = argsorted_indices_ptr[tid] + unique_idx = ( + inc_scan_ptr[tid] + if not index_converter_ptr + else index_converter_ptr[inc_scan_ptr[tid]] + ) + indices_ptr[data_idx] = unique_idx + with ib.if_scope(tid == 0): + unique_elements_ptr[unique_idx] = data_ptr[data_idx] + with ib.else_scope(): + with ib.if_scope(inc_scan_ptr[tid] != inc_scan_ptr[tid - 1]): + unique_elements_ptr[unique_idx] = data_ptr[data_idx] + return ib.get() + + +def _calc_first_occurence_ir(argsorted_indices, inc_scan, first_occurence): + """Low level IR to calculate the first occurence of each unique element in the input data. + + Parameters + ---------- + argsorted_indices : Buffer + A buffer that stores the argsorted indices of the input data. + + inc_scan : Buffer + A buffer that stores the inclusive scan of the binary tir.NE adjacent difference + of the sorted data. + + first_occurence : Buffer + A buffer that stores the first occurence of each unique element in the input data. + """ + ib = tir.ir_builder.create() + argsorted_indices_ptr = ib.buffer_ptr(argsorted_indices) + inc_scan_ptr = ib.buffer_ptr(inc_scan) + first_occurence_ptr = ib.buffer_ptr(first_occurence) + batch_size = argsorted_indices.shape[0] + max_threads = tir.min(batch_size, tvm.target.Target.current(allow_none=False).max_num_threads) + with ib.new_scope(): + nthread_tx = max_threads + nthread_bx = ceil_div(batch_size, max_threads) + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * max_threads + tx + with ib.if_scope(tid < batch_size): + first_occurence_ptr[tid] = batch_size + with ib.new_scope(): + nthread_tx = max_threads + nthread_bx = ceil_div(batch_size, max_threads) + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * max_threads + tx + with ib.if_scope(tid < batch_size): + with ib.if_scope(tid == 0): + first_occurence_ptr[inc_scan_ptr[tid]] = argsorted_indices_ptr[tid] + with ib.else_scope(): + with ib.if_scope(inc_scan_ptr[tid] != inc_scan_ptr[tid - 1]): + first_occurence_ptr[inc_scan_ptr[tid]] = argsorted_indices_ptr[tid] + return ib.get() + + +def unique(data, is_sorted=True, return_counts=False): + """ + Find the unique elements of a 1-D tensor. Please note `output` and `counts` are all padded to + have the same length of `data` and element with index >= num_unique[0] has undefined value. + + Parameters + ---------- + data : tvm.te.Tensor + A 1-D tensor of integers. + + sorted : bool + Whether to sort the unique elements in ascending order before returning as output. + + return_counts : bool + Whether to return the count of each unique element. + + Returns + ------- + output : tvm.te.Tensor + A 1-D tensor containing the unique elements of the input data tensor. + + indices : tvm.te.Tensor + A 1-D tensor containing the index of each data element in the output tensor. + + num_unique : tvm.te.Tensor + A 1-D tensor with size=1 containing the number of unique elements in the input data tensor. + + counts (optional) : tvm.te.Tensor + A 1-D tensor containing the count of each unique element in the output. + + Examples + -------- + .. code-block:: python + [output, indices, num_unique] = unique([4, 5, 1, 2, 3, 3, 4, 5], False, False) + output = [4, 5, 1, 2, 3, ?, ?, ?] + indices = [0, 1, 2, 3, 4, 4, 0, 1] + num_unique = [5] + + [output, indices, num_unique, counts] = unique([4, 5, 1, 2, 3, 3, 4, 5], False, True) + output = [4, 5, 1, 2, 3, ?, ?, ?] + indices = [0, 1, 2, 3, 4, 4, 0, 1] + num_unique = [5] + counts = [2, 2, 1, 1, 2, ?, ?, ?] + + [output, indices, num_unique] = unique([4, 5, 1, 2, 3, 3, 4, 5], True) + output = [1, 2, 3, 4, 5, ?, ?, ?] + indices = [3, 4, 0, 1, 2, 2, 3, 4] + num_unique = [5] + """ + sorted_data = sort(data) + argsorted_indices = argsort(data, dtype="int32") + # adjacent difference + adjacent_diff = _calc_adjacent_diff(sorted_data, out_dtype="int32", binop=tir.NE) + # inclusive scan + inc_scan = cumsum(adjacent_diff, dtype="int32", exclusive=0) + # total number of unique elements + num_unique_elements = _calc_num_unique(inc_scan) + # buffers + data_buf = tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) + argsorted_indices_buf = tir.decl_buffer( + data.shape, "int32", "argsorted_indices_buf", data_alignment=8 + ) + inc_scan_buf = tvm.tir.decl_buffer(data.shape, "int32", "inc_scan_buf", data_alignment=8) + unique_elements_buf = tir.decl_buffer( + data.shape, data.dtype, "unique_elements_buf", data_alignment=8 + ) + inverse_indices_buf = tvm.tir.decl_buffer( + data.shape, "int32", "inverse_indices_buf", data_alignment=8 + ) + # prepare outputs + if return_counts: + counts_buf = tir.decl_buffer(data.shape, "int32", "counts_buf", data_alignment=8) + out_data_shape = [data.shape] * 3 + out_buffers = [unique_elements_buf, inverse_indices_buf, counts_buf] + out_dtypes = [data.dtype, "int32", "int32"] + else: + out_data_shape = [data.shape] * 2 + out_buffers = [unique_elements_buf, inverse_indices_buf] + out_dtypes = [data.dtype, "int32"] + # prepare inputs and fcompute + if is_sorted: + in_data = [data, argsorted_indices, inc_scan] + in_buffers = [data_buf, argsorted_indices_buf, inc_scan_buf] + if return_counts: + fcompute = lambda ins, outs: _calc_unique_ir(*ins, None, *outs) + else: + fcompute = lambda ins, outs: _calc_unique_ir(*ins, None, *outs, None) + else: + # calculate the index converter if the unique elements should not be sorted + # calculate first occurence + first_occurence_buf = tir.decl_buffer( + data.shape, "int32", "first_occurence_buf", data_alignment=8 + ) + first_occurence = te.extern( + [data.shape], + [argsorted_indices, inc_scan], + lambda ins, outs: _calc_first_occurence_ir(ins[0], ins[1], outs[0]), + dtype=["int32"], + in_buffers=[argsorted_indices_buf, inc_scan_buf], + out_buffers=[first_occurence_buf], + name="_calc_first_occurence", + tag="_calc_first_occurence_gpu", + ) + # calculate index converter by sorting unique elements by their first occurence + argsorted_first_occurence = argsort(first_occurence, dtype="int32") + index_converter = argsort(argsorted_first_occurence, dtype="int32") + index_converter_buf = tir.decl_buffer( + data.shape, "int32", "index_converter_buf", data_alignment=8 + ) + in_data = [data, argsorted_indices, inc_scan, index_converter] + in_buffers = [data_buf, argsorted_indices_buf, inc_scan_buf, index_converter_buf] + if return_counts: + fcompute = lambda ins, outs: _calc_unique_ir(*ins, *outs) + else: + fcompute = lambda ins, outs: _calc_unique_ir(*ins, *outs, None) + outs = te.extern( + out_data_shape, + in_data, + fcompute, + dtype=out_dtypes, + in_buffers=in_buffers, + out_buffers=out_buffers, + name="_calc_unique", + tag="_calc_unique_gpu", + ) + if return_counts: + return [outs[0], outs[1], num_unique_elements, outs[2]] + return [*outs, num_unique_elements] diff --git a/python/tvm/topi/generic/search.py b/python/tvm/topi/generic/search.py index 603f79330ebb..4f88c3e3a2b8 100644 --- a/python/tvm/topi/generic/search.py +++ b/python/tvm/topi/generic/search.py @@ -86,3 +86,19 @@ def schedule_interpolate(outs): def schedule_sparse_fill_empty_rows(outs): return _default_schedule(outs, False) + + +def schedule_unique(outs): + """Schedule for unique operator. + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of unique. + + Returns + ------- + s: Schedule + The computation schedule for the op. + """ + return _default_schedule(outs, False) diff --git a/python/tvm/topi/unique.py b/python/tvm/topi/unique.py new file mode 100644 index 000000000000..b4f27b38f65f --- /dev/null +++ b/python/tvm/topi/unique.py @@ -0,0 +1,297 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Unique operator""" +from tvm import te, tir +from ..te import hybrid +from .cumsum import cumsum +from .sort import sort, argsort + + +def _calc_adjacent_diff_ir(data, output, binop=tir.Sub): + """Low level IR to calculate adjacent difference in an 1-D array. + + Parameters + ---------- + data : Buffer + Input 1-D Buffer. + + output: Buffer + A buffer to store adjacent difference, of the same shape as data. The adjacent difference + is defined as: output[0] = 0, output[i] = binop(data[i], data[i-1]) + where i > 0 and i < len(data). + + binop: function, optional + A binary associative op to use for calculating adjacent difference. The function takes two + TIR expressions and produce a new TIR expression. By default it uses tvm.tir.Sub to + compute the adjacent difference. + """ + ib = tir.ir_builder.create() + data_ptr = ib.buffer_ptr(data) + output_ptr = ib.buffer_ptr(output) + with ib.for_range(0, data.shape[0], kind="parallel") as i: + with ib.if_scope(i == 0): + output_ptr[0] = 0 + with ib.else_scope(): + output_ptr[i] = tir.Cast(output.dtype, binop(data_ptr[i], data_ptr[i - 1])) + return ib.get() + + +def _calc_adjacent_diff(data, out_dtype="int32", binop=tir.Sub): + """Function calculate adjacent difference in an 1-D array. + + Parameters + ---------- + data : tvm.te.Tensor + Input 1-D tensor. + + output_dtype : str + The output tensor data type. + + binop: function, optional + A binary associative op to use for calculating difference. The function takes two + TIR expressions and produce a new TIR expression. By default it uses tvm.tir.Sub to + compute the adjacent difference. + + Returns + ------- + output : tvm.te.Tensor + 1-D tensor storing the adjacent difference of the input tensor. The adjacent difference + is defined as: output[0] = 0, output[i] = binop(data[i], data[i-1]) + where i > 0 and i < len(data). + """ + return te.extern( + [data.shape], + [data], + lambda ins, outs: _calc_adjacent_diff_ir(ins[0], outs[0], binop=binop), + dtype=[out_dtype], + name="_calc_adjacent_diff", + tag="_calc_adjacent_diff_cpu", + ) + + +@hybrid.script +def _calc_num_unique(inc_scan): + """Helper function to get the number of unique elements fron inc_scan tensor""" + output = output_tensor((1,), "int32") + output[0] = inc_scan[inc_scan.shape[0] - 1] + int32(1) + return output + + +def _calc_unique_ir( + data, argsorted_indices, inc_scan, index_converter, unique_elements, indices, counts +): + """Low level IR to calculate unique elements, inverse indices, and counts (optional) of + unique elements of 1-D array. + + Parameters + ---------- + data : Buffer + Input 1-D Buffer. + + argsorted_indices : Buffer + A buffer that stores the argsorted indices of the input data. + + inc_scan : Buffer + A buffer that stores the inclusive scan of the binary tir.NE adjacent difference + of the sorted data. + + index_converter (optional) : Buffer + An optional index converter that transforms the unique element index + such that new_idx = index_converter[old_idx]. + + unique_elements : Buffer + A buffer that stores the unique elements. + + indices : Buffer + A buffer that stores the the index of each input data element in the unique element array. + + counts (optional) : Buffer + A buffer that stores the count of each unique element. + """ + ib = tir.ir_builder.create() + data_ptr = ib.buffer_ptr(data) + argsorted_indices_ptr = ib.buffer_ptr(argsorted_indices) + inc_scan_ptr = ib.buffer_ptr(inc_scan) + unique_elements_ptr = ib.buffer_ptr(unique_elements) + indices_ptr = ib.buffer_ptr(indices) + + index_converter_ptr = None + if isinstance(index_converter, tir.Buffer): + index_converter_ptr = ib.buffer_ptr(index_converter) + + if isinstance(counts, tir.Buffer): + counts_ptr = ib.buffer_ptr(counts) + # use indices_ptr as a tmp buffer to store tids with inc_scan[tid] != inc_scan[tid-1] + unique_seq_indices_ptr = ib.buffer_ptr(indices) + + data_length = data.shape[0] + + # if need to return counts + if isinstance(counts, tir.Buffer): + num_unique = inc_scan_ptr[inc_scan.shape[0] - 1] + 1 + num_elements = data.shape[0] + unique_seq_indices_ptr[num_unique - 1] = num_elements + with ib.new_scope(): + with ib.for_range(0, data_length, kind="parallel") as i: + with ib.if_scope(i > 0): + with ib.if_scope(inc_scan_ptr[i] != inc_scan_ptr[i - 1]): + unique_seq_indices_ptr[inc_scan_ptr[i] - 1] = i + with ib.new_scope(): + with ib.for_range(0, num_unique, kind="parallel") as i: + unique_idx = i if not index_converter_ptr else index_converter_ptr[i] + with ib.if_scope(i == 0): + counts_ptr[unique_idx] = unique_seq_indices_ptr[i] + with ib.else_scope(): + counts_ptr[unique_idx] = ( + unique_seq_indices_ptr[i] - unique_seq_indices_ptr[i - 1] + ) + # calculate unique elements and inverse indices + with ib.new_scope(): + with ib.for_range(0, data_length, kind="parallel") as i: + data_idx = argsorted_indices_ptr[i] + unique_idx = ( + inc_scan_ptr[i] if not index_converter_ptr else index_converter_ptr[inc_scan_ptr[i]] + ) + indices_ptr[data_idx] = unique_idx + with ib.if_scope(i == 0): + unique_elements_ptr[unique_idx] = data_ptr[data_idx] + with ib.else_scope(): + with ib.if_scope(inc_scan_ptr[i] != inc_scan_ptr[i - 1]): + unique_elements_ptr[unique_idx] = data_ptr[data_idx] + return ib.get() + + +@hybrid.script +def _calc_first_occurence(argsorted_indices, inc_scan): + """Hybrid script to calculate the first occurence of each unique element in the input data. + + Parameters + ---------- + argsorted_indices : tvm.te.Tensor + A tensor that stores the argsorted indices of the input data. + + inc_scan : tvm.te.Tensor + A tensor that stores the inclusive scan of the binary tir.NE adjacent difference + of the sorted data. + + first_occurence : tvm.te.Tensor + A tensor that stores the first occurence of each unique element in the input data. + """ + first_occurence = output_tensor(argsorted_indices.shape, "int32") + for i in parallel(argsorted_indices.shape[0]): + first_occurence[i] = argsorted_indices.shape[0] + for i in parallel(argsorted_indices.shape[0]): + if i == 0 or inc_scan[i] != inc_scan[i - 1]: + first_occurence[inc_scan[i]] = argsorted_indices[i] + return first_occurence + + +def unique(data, is_sorted=True, return_counts=False): + """ + Find the unique elements of a 1-D tensor. Please note `output` and `counts` are all padded to + have the same length of `data` and element with index >= num_unique[0] has undefined value. + + Parameters + ---------- + data : tvm.te.Tensor + A 1-D tensor of integers. + + sorted : bool + Whether to sort the unique elements in ascending order before returning as output. + + return_counts : bool + Whether to return the count of each unique element. + + Returns + ------- + output : tvm.te.Tensor + A 1-D tensor containing the unique elements of the input data tensor. + + indices : tvm.te.Tensor + A 1-D tensor containing the index of each data element in the output tensor. + + num_unique : tvm.te.Tensor + A 1-D tensor with size=1 containing the number of unique elements in the input data tensor. + + counts (optional) : tvm.te.Tensor + A 1-D tensor containing the count of each unique element in the output. + + Examples + -------- + .. code-block:: python + [output, indices, num_unique] = unique([4, 5, 1, 2, 3, 3, 4, 5], False, False) + output = [4, 5, 1, 2, 3, ?, ?, ?] + indices = [0, 1, 2, 3, 4, 4, 0, 1] + num_unique = [5] + + [output, indices, num_unique, counts] = unique([4, 5, 1, 2, 3, 3, 4, 5], False, True) + output = [4, 5, 1, 2, 3, ?, ?, ?] + indices = [0, 1, 2, 3, 4, 4, 0, 1] + num_unique = [5] + counts = [2, 2, 1, 1, 2, ?, ?, ?] + + [output, indices, num_unique] = unique([4, 5, 1, 2, 3, 3, 4, 5], True) + output = [1, 2, 3, 4, 5, ?, ?, ?] + indices = [3, 4, 0, 1, 2, 2, 3, 4] + num_unique = [5] + """ + sorted_data = sort(data) + argsorted_indices = argsort(data, dtype="int32") + # adjacent difference + adjacent_diff = _calc_adjacent_diff(sorted_data, "int32", tir.NE) + # inclusive scan + inc_scan = cumsum(adjacent_diff, dtype="int32", exclusive=0) + # total number of unique elements + num_unique_elements = _calc_num_unique(inc_scan) + # prepare outputs + if return_counts: + out_data_shape = [data.shape] * 3 + out_dtypes = [data.dtype, "int32", "int32"] + else: + out_data_shape = [data.shape] * 2 + out_dtypes = [data.dtype, "int32"] + # prepare inputs and fcompute + if is_sorted: + in_data = [data, argsorted_indices, inc_scan] + if return_counts: + fcompute = lambda ins, outs: _calc_unique_ir(*ins, None, *outs) + else: + fcompute = lambda ins, outs: _calc_unique_ir(*ins, None, *outs, None) + else: + # calculate the index converter if the unique elements should not be sorted + # calculate first occurence + first_occurence = _calc_first_occurence(argsorted_indices, inc_scan) + # calculate index converter by sorting unique elements by their first occurence + argsorted_first_occurence = argsort(first_occurence, dtype="int32") + index_converter = argsort(argsorted_first_occurence, dtype="int32") + in_data = [data, argsorted_indices, inc_scan, index_converter] + if return_counts: + fcompute = lambda ins, outs: _calc_unique_ir(*ins, *outs) + else: + fcompute = lambda ins, outs: _calc_unique_ir(*ins, *outs, None) + outs = te.extern( + out_data_shape, + in_data, + fcompute, + dtype=out_dtypes, + name="_calc_unique", + tag="_calc_unique_cpu", + ) + if return_counts: + return [outs[0], outs[1], num_unique_elements, outs[2]] + return [*outs, num_unique_elements] diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 4876598c342c..65e655d2e701 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -3824,5 +3824,52 @@ RELAY_REGISTER_OP("cumsum") .add_type_rel("Cumsum", CumsumRel) .set_attr("TOpPattern", kOpaque); +TVM_REGISTER_NODE_TYPE(UniqueAttrs); + +bool UniqueRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + // types: [data, result] + ICHECK_EQ(types.size(), 2) << "Unique: expect 2 types but " << types.size() << " provided"; + ICHECK_EQ(num_inputs, 1) << "Unique: expect 1 inputs but " << num_inputs << " provided"; + auto data = types[0].as(); + if (data == nullptr) { + ICHECK(types[0].as()) + << "Unique: expect input type to be TensorType but get " << types[0]; + return false; + } + const int ndim = static_cast(data->shape.size()); + ICHECK_EQ(ndim, 1) << "Unique: input must be 1-D tensor"; + ICHECK_EQ(data->dtype.is_int(), true) << "Unique: input must have int32 or int64 dtype"; + std::vector fields; + fields.push_back(TensorType(data->shape, data->dtype)); // unique + fields.push_back(TensorType(data->shape, DataType::Int(32))); // indices + fields.push_back(TensorType(Array{1}, DataType::Int(32))); // num_unique + const auto* param = attrs.as(); + if (param->return_counts) { + fields.push_back(TensorType(data->shape, DataType::Int(32))); // counts + } + reporter->Assign(types[1], TupleType(Array(fields))); + return true; +} + +Expr MakeUnique(Expr data, bool sorted, bool return_counts) { + auto attrs = make_object(); + attrs->sorted = sorted; + attrs->return_counts = return_counts; + static const Op& op = Op::Get("unique"); + return Call(op, {data}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op._make.unique").set_body_typed(MakeUnique); + +RELAY_REGISTER_OP("unique") + .describe( + R"code(This operation returns the unique elements and the new index of each item in a given 1-D array. + )code" TVM_ADD_FILELINE) + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor") + .add_type_rel("unique", UniqueRel) + .set_support_level(3) + .set_attr("TOpPattern", kOpaque); } // namespace relay } // namespace tvm diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index aa42b0fb84e4..0cf4839c6ebb 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -2064,7 +2064,12 @@ def verify_model_vm(input_model, ishapes, idtype=None, idata=None, targets=["llv pt_result = input_model(*input_data) # Verify the accuracy - if not isinstance(pt_result, torch.Tensor): + if isinstance(pt_result, tuple): + # handle multiple outputs + for i in range(len(pt_result)): + tvm_res = vm_res[i].asnumpy() + tvm.testing.assert_allclose(tvm_res, pt_result[i].numpy(), rtol=1e-5, atol=1e-5) + elif not isinstance(pt_result, torch.Tensor): tvm_res = vm_res.asnumpy().item() assert pt_result == tvm_res else: @@ -3654,6 +3659,23 @@ def test_fn(x, mask): verify_trace_model(test_fn, [x, mask], ["llvm", "cuda", "nvptx"]) +def test_unique(): + def test_fn(is_sorted, return_inverse, return_counts): + return lambda x: torch.unique(x, is_sorted, return_inverse, return_counts) + + in_data = torch.randint(0, 20, (10,), dtype=torch.int32) + targets = ["llvm", "cuda", "nvptx"] + verify_trace_model(test_fn(True, True, True), [in_data], targets) + verify_trace_model(test_fn(True, False, True), [in_data], targets) + verify_trace_model(test_fn(True, True, False), [in_data], targets) + verify_trace_model(test_fn(True, False, True), [in_data], targets) + in_data = torch.randint(0, 20, (20,), dtype=torch.int64) + verify_trace_model(test_fn(True, True, True), [in_data], targets) + verify_trace_model(test_fn(True, False, True), [in_data], targets) + verify_trace_model(test_fn(True, True, False), [in_data], targets) + verify_trace_model(test_fn(True, False, True), [in_data], targets) + + if __name__ == "__main__": # some structural tests test_forward_traced_function() @@ -3789,6 +3811,7 @@ def test_fn(x, mask): test_argsort() test_logical_and() test_masked_select() + test_unique() # Model tests test_resnet18() diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index da8e26018f8c..58ba3561c9df 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -5004,5 +5004,70 @@ def lstm_cell(): tvm.testing.assert_allclose(tf_output[i], tvm_output[i], atol=1e-5, rtol=1e-5) +####################################################################### +# Unique +# ------------ + + +def _test_unique(n, dtype, is_dyn): + tf.reset_default_graph() + np_data = np.random.randint(100, size=n).astype(dtype) + with tf.Graph().as_default(): + if is_dyn: + in_data = tf.placeholder(dtype, [n], name="in_data") + else: + in_data = tf.constant(np_data, dtype, name="in_data") + tf.unique(in_data) + if is_dyn: + compare_tf_with_tvm(np_data, "in_data:0", ["Unique:0", "Unique:1"], mode="vm") + else: + compare_tf_with_tvm(None, "", ["Unique:0", "Unique:1"]) + + +def test_forward_unique(): + """test Unique""" + + for dtype in ["int32", "int64"]: + for is_dyn in [False, True]: + _test_unique(50, dtype, is_dyn) + _test_unique(100, dtype, is_dyn) + + +####################################################################### +# Unique with counts +# ------------ + + +def _test_unique_with_counts(n, dtype, is_dyn): + tf.reset_default_graph() + np_data = np.random.randint(100, size=n).astype(dtype) + with tf.Graph().as_default(): + if is_dyn: + in_data = tf.placeholder(dtype, [n], name="in_data") + else: + in_data = tf.constant(np_data, dtype, name="in_data") + tf.unique_with_counts(in_data) + if is_dyn: + compare_tf_with_tvm( + np_data, + "in_data:0", + ["UniqueWithCounts:0", "UniqueWithCounts:1", "UniqueWithCounts:2"], + mode="vm", + ) + else: + compare_tf_with_tvm( + None, "", ["UniqueWithCounts:0", "UniqueWithCounts:1", "UniqueWithCounts:2"] + ) + + +def test_forward_unique_with_counts(): + """test UniqueWithCounts""" + + for dtype in ["int32", "int64"]: + for is_dyn in [False, True]: + _test_unique_with_counts(10, dtype, is_dyn) + _test_unique_with_counts(20, dtype, is_dyn) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index e5e4e2f8c77f..d78becfd190a 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -1485,5 +1485,58 @@ def verify_scatter_nd_with_stack(data_np, indices_np, shape, ref_res, rtol=1e-5, verify_scatter_nd_with_stack(data, indices, shape, out) +def test_unique(): + def calc_numpy_unique(data, is_sorted=False): + uniq, index, inverse, counts = np.unique( + data, return_index=True, return_inverse=True, return_counts=True + ) + num_uniq = np.array([len(uniq)]).astype("int32") + if not is_sorted: + order = np.argsort(index) + reverse_order = np.argsort(order) + uniq = uniq[order].astype(data.dtype) + inverse = np.array([reverse_order[i] for i in inverse]).astype("int32") + counts = counts[order].astype("int32") + return [uniq.astype(data.dtype), inverse.astype("int32"), counts, num_uniq] + + def verify_unique(n, dtype, is_dyn=False, is_sorted=False, return_counts=False): + if is_dyn: + x = relay.var("x", relay.TensorType([relay.Any()], dtype)) + else: + x = relay.var("x", relay.TensorType([n], dtype)) + outs = relay.unique(x, is_sorted, return_counts) + outs = outs.astuple() + func = relay.Function([x], outs) + x_data = np.random.randint(50, size=n).astype(dtype) + + if is_dyn: + backends = ["vm", "debug"] + else: + backends = ["graph", "debug"] + + for target, ctx in tvm.testing.enabled_targets(): + for kind in backends: + mod = tvm.ir.IRModule.from_expr(func) + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + tvm_res = intrp.evaluate()(x_data) + np_res = calc_numpy_unique(x_data, is_sorted) + num_unique = np_res[3][0] + assert num_unique == tvm_res[2].asnumpy()[0] + # unique + tvm.testing.assert_allclose(tvm_res[0].asnumpy()[:num_unique], np_res[0], rtol=1e-5) + # inverse_indices + tvm.testing.assert_allclose(tvm_res[1].asnumpy(), np_res[1], rtol=1e-5) + # counts + if return_counts: + tvm.testing.assert_allclose( + tvm_res[3].asnumpy()[:num_unique], np_res[2], rtol=1e-5 + ) + + for dtype in ["int32", "int64"]: + for i in range(8): + is_dyn, is_sorted, return_counts = bool(i & 1), bool(i & 2), bool(i & 4) + verify_unique(10, dtype, is_dyn, is_sorted, return_counts) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/topi/python/test_topi_unique.py b/tests/python/topi/python/test_topi_unique.py new file mode 100644 index 000000000000..d7ee74282922 --- /dev/null +++ b/tests/python/topi/python/test_topi_unique.py @@ -0,0 +1,111 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import numpy as np +import tvm +import tvm.testing +from tvm import topi +import tvm.topi.testing + + +@tvm.testing.parametrize_targets +def test_unique(ctx, target): + def calc_numpy_unique(data, is_sorted=False): + uniq, index, inverse, counts = np.unique( + data, return_index=True, return_inverse=True, return_counts=True + ) + num_uniq = np.array([len(uniq)]).astype("int32") + if not is_sorted: + order = np.argsort(index) + reverse_order = np.argsort(order) + uniq = uniq[order].astype(data.dtype) + inverse = np.array([reverse_order[i] for i in inverse]).astype("int32") + counts = counts[order].astype("int32") + return [uniq.astype(data.dtype), inverse.astype("int32"), counts, num_uniq] + + def check_unique(data, is_sorted=False): + # numpy reference + np_unique, np_indices, np_counts, np_num_unique = calc_numpy_unique(data, is_sorted) + num_unique = np_num_unique[0] + + implementations = { + "generic": ( + lambda x, return_counts: topi.unique(x, is_sorted, return_counts), + topi.generic.schedule_unique, + ), + "cuda": ( + lambda x, return_counts: topi.cuda.unique(x, is_sorted, return_counts), + topi.cuda.schedule_scan, + ), + "nvptx": ( + lambda x, return_counts: topi.cuda.unique(x, is_sorted, return_counts), + topi.cuda.schedule_scan, + ), + } + fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations) + tvm_data = tvm.nd.array(data, ctx=ctx) + tvm_unique = tvm.nd.array(np.zeros(data.shape).astype(data.dtype), ctx=ctx) + tvm_indices = tvm.nd.array(np.zeros(data.shape).astype("int32"), ctx=ctx) + tvm_num_unique = tvm.nd.array(np.zeros([1]).astype("int32"), ctx=ctx) + + # without counts + with tvm.target.Target(target): + te_input = tvm.te.placeholder(shape=data.shape, dtype=str(data.dtype)) + outs = fcompute(te_input, False) + s = fschedule(outs) + func = tvm.build(s, [te_input, *outs]) + func(tvm_data, tvm_unique, tvm_indices, tvm_num_unique) + + assert tvm_num_unique.asnumpy()[0] == np_num_unique + np.testing.assert_allclose( + tvm_unique.asnumpy()[:num_unique], np_unique, atol=1e-5, rtol=1e-5 + ) + np.testing.assert_allclose(tvm_indices.asnumpy(), np_indices, atol=1e-5, rtol=1e-5) + + # with counts + tvm_counts = tvm.nd.array(np.zeros(data.shape).astype("int32"), ctx=ctx) + with tvm.target.Target(target): + te_input = tvm.te.placeholder(shape=data.shape, dtype=str(data.dtype)) + outs = fcompute(te_input, True) + s = fschedule(outs) + func = tvm.build(s, [te_input, *outs]) + func(tvm_data, tvm_unique, tvm_indices, tvm_num_unique, tvm_counts) + + np_unique, np_indices, _, np_num_unique = calc_numpy_unique(data, is_sorted) + num_unique = np_num_unique[0] + assert tvm_num_unique.asnumpy()[0] == np_num_unique + np.testing.assert_allclose( + tvm_unique.asnumpy()[:num_unique], np_unique, atol=1e-5, rtol=1e-5 + ) + np.testing.assert_allclose(tvm_indices.asnumpy(), np_indices, atol=1e-5, rtol=1e-5) + np.testing.assert_allclose( + tvm_counts.asnumpy()[:num_unique], np_counts, atol=1e-5, rtol=1e-5 + ) + + for in_dtype in ["int32", "int64"]: + for is_sorted in [True, False]: + data = np.random.randint(0, 100, size=(1)).astype(in_dtype) + check_unique(data, is_sorted) + data = np.random.randint(0, 10, size=(10)).astype(in_dtype) + check_unique(data, is_sorted) + data = np.random.randint(0, 100, size=(10000)).astype(in_dtype) + check_unique(data, is_sorted) + + +if __name__ == "__main__": + test_unique(tvm.context("cpu"), tvm.target.Target("llvm")) + test_unique(tvm.context("cuda"), tvm.target.Target("cuda")) + test_unique(tvm.context("nvptx"), tvm.target.Target("nvptx"))