From e5f616ba669f86e6e3100159707e0df480804ca5 Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Tue, 3 Nov 2020 16:53:06 -0800 Subject: [PATCH 01/13] [RELAY,TOPI] Add scatter_nd op Scatter_nd is the inverse of gather_nd and also happens to be its gradient. The implementation here is not optimized. There are no cpu or gpu specific implementations. --- include/tvm/relay/attrs/transform.h | 8 ++ python/tvm/relay/backend/compile_engine.py | 3 +- python/tvm/relay/op/_tensor_grad.py | 14 +++ python/tvm/relay/op/_transform.py | 9 ++ python/tvm/relay/op/strategy/generic.py | 19 ++++ python/tvm/relay/op/transform.py | 24 +++++ python/tvm/te/operation.py | 2 +- python/tvm/testing.py | 29 ++++++ python/tvm/topi/scatter.py | 99 ++++++++++++++++++- src/relay/analysis/type_solver.cc | 9 +- src/relay/op/tensor/transform.cc | 68 +++++++++++++ tests/python/relay/test_op_grad_level3.py | 18 ++++ tests/python/topi/python/test_topi_scatter.py | 48 +++++++++ 13 files changed, 343 insertions(+), 7 deletions(-) create mode 100644 tests/python/topi/python/test_topi_scatter.py diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index a7830cf61647..3ed6b8352845 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -129,6 +129,14 @@ struct ScatterAddAttrs : public tvm::AttrsNode { } }; +struct ScatterNDAttrs : public tvm::AttrsNode { + Array out_shape; + + TVM_DECLARE_ATTRS(ScatterNDAttrs, "relay.attrs.ScatterNDAttrs") { + TVM_ATTR_FIELD(out_shape).describe("Output shape of the scatter."); + } +}; + struct GatherAttrs : public tvm::AttrsNode { Integer axis; diff --git a/python/tvm/relay/backend/compile_engine.py b/python/tvm/relay/backend/compile_engine.py index d874732d6fa0..14e1f5d85d9f 100644 --- a/python/tvm/relay/backend/compile_engine.py +++ b/python/tvm/relay/backend/compile_engine.py @@ -121,7 +121,8 @@ def get_valid_implementations(op, attrs, inputs, out_type, target): The list of all valid op implementations. """ fstrategy = op.get_attr("FTVMStrategy") - assert fstrategy is not None, "%s doesn't have FTVMStrategy registered" % op.name + assert fstrategy is not None, "%s doesn't have an FTVMStrategy registered. You can register " \ + "one in python with `tvm.relay.op.register_strategy`." 
% op.name with target: strategy = fstrategy(attrs, inputs, out_type, target) analyzer = tvm.arith.Analyzer() diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py index b070d9f5b3ff..b200aa125270 100644 --- a/python/tvm/relay/op/_tensor_grad.py +++ b/python/tvm/relay/op/_tensor_grad.py @@ -62,6 +62,8 @@ squeeze, strided_set, arange, + gather_nd, + scatter_nd, ) @@ -803,3 +805,15 @@ def arange_grad(orig, grad): grad_step = cast_like(_sum(grad_step), step) return [grad_start, grad_stop, grad_step] + + +@register_gradient("gather_nd") +def gather_nd_grad(orig, grad): + data, indices = orig.args + return [scatter_nd(grad, indices, data.checked_type.concrete_shape), zeros_like(indices)] + + +# @register_gradient("scatter_nd") +# def scatter_nd_grad(orig, grad): +# data, indices = orig.args +# return [gather_nd(grad, indices), zeros_like(indices)] diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index e42b8bbae814..61e3dd7a10d4 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -115,6 +115,15 @@ def compute_scatter_add(attrs, inputs, output_type): _reg.register_strategy("scatter_add", strategy.scatter_add_strategy) +# scatter +@_reg.register_compute("scatter_nd") +def compute_scatter_nd(attrs, inputs, output_type): + """Compute definition of scatter_nd""" + return [topi.scatter_nd(inputs[0], inputs[1], attrs.out_shape)] + + +_reg.register_strategy("scatter_nd", strategy.scatter_nd_strategy) + ##################### # Shape functions # ##################### diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index e49135c4d1bf..54e1f9a9fe2b 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -1062,6 +1062,25 @@ def scatter_add_strategy(attrs, outs, out_type, target): ) return strategy +# scatter_nd +@override_native_generic_func("scatter_nd_strategy") +def scatter_nd_strategy(attrs, inputs, out_type, target): + """scatter_nd generic strategy""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_scatter_nd(topi.scatter_nd), wrap_topi_schedule(topi.generic.schedule_extern), + name="scatter_nd.generic", + ) + return strategy + +def wrap_compute_scatter_nd(topi_compute): + """Wrap scatter_nd topi compute""" + + def _compute_scatter_nd(attrs, inputs, _): + return [topi_compute(inputs[0], inputs[1], attrs.out_shape)] + + return _compute_scatter_nd + # bitserial_conv2d def wrap_compute_bitserial_conv2d(topi_compute): diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index a3f97392e36e..7ee26e56c435 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -308,6 +308,30 @@ def scatter_add(data, indices, updates, axis): return _make.scatter_add(data, indices, updates, axis) +def scatter_nd(data, indices, out_shape): + """Scatter values from an array. + + See :py:func:`tvm.topi.scatter` for how data is scattered. + + Parameters + ---------- + data : relay.Expr + The input data to the operator. + + indices : relay.Expr + The index locations to update. + + out_shape : relay.Expr + Output shape of the scatter. + + Returns + ------- + ret : relay.Expr + The computed result. + """ + return _make.scatter_nd(data, indices, out_shape) + + def reshape_like(data, shape_like, lhs_begin=0, lhs_end=None, rhs_begin=0, rhs_end=None): """Reshapes the input tensor by the size of another tensor. 
For an input tensor with shape ``(d0, d1, ..., d(k-1))``, `reshape_like` operation reshapes diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py index 30d0df382c27..a924c8b0c0db 100644 --- a/python/tvm/te/operation.py +++ b/python/tvm/te/operation.py @@ -317,7 +317,7 @@ def extern( if isinstance(body, tvm.tir.PrimExpr): body = tvm.tir.Evaluate(body) if not isinstance(body, tvm.tir.Stmt): - raise ValueError("Function '{}' should return PrimExpr or Stmt".format(fcompute.__name__)) + raise ValueError("Function '{}' should return PrimExpr or Stmt, but it returned '{}'".format(fcompute.__name__, type(body))) op = _ffi_api.ExternOp(name, tag, attrs, inputs, input_placeholders, output_placeholders, body) res = [op.output(i) for i in range(len(output_placeholders))] diff --git a/python/tvm/testing.py b/python/tvm/testing.py index e5b17f3d7b53..7286aa4bbbd9 100644 --- a/python/tvm/testing.py +++ b/python/tvm/testing.py @@ -714,4 +714,33 @@ def func(f): return wrap(args) +def compare_numpy_tvm(inputs, output, target, ctx, compute, schedule): + """Compare a numpy inputs and output of a function to the results of the TVM version. + + Parameters + ---------- + inputs : Sequence[numpy.nd.array] + List of input numpy arrays to pass to the function. + output : numpy.nd.array + Verified correct function output. + target : tvm.target.Target + Target to run on. + ctx : tvm.TVMContext + Context to run on. + compute : callable + Topi compute function to test against. + schedule : callable + Topi scheduling function to test against. + """ + te_inputs = [tvm.te.placeholder(shape=i.shape, dtype=str(i.dtype)) for i in inputs] + te_out = tvm.nd.array(np.zeros(output.shape).astype(output.dtype), ctx=ctx) + with tvm.target.Target(target): + out = compute(*te_inputs) + s = schedule([out]) + func = tvm.build(s, te_inputs + [out]) + arys = [tvm.nd.array(x, ctx=ctx) for x in inputs] + func(*(arys + [te_out])) + assert_allclose(output, te_out.asnumpy(), atol=1e-4, rtol=1e-4) + + tvm._ffi._init_api("testing", __name__) diff --git a/python/tvm/topi/scatter.py b/python/tvm/topi/scatter.py index f1c307a43a44..382a91e790c6 100644 --- a/python/tvm/topi/scatter.py +++ b/python/tvm/topi/scatter.py @@ -16,7 +16,9 @@ # under the License. # pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks """Scatter operator""" -from tvm.te import hybrid +from ..tir import decl_buffer, ir_builder, Cast, AssertStmt, StringImm, Evaluate +from ..te import extern, hybrid +from . import full @hybrid.script @@ -196,3 +198,98 @@ def scatter(data, indices, updates, axis=0): if len(data.shape) == 4: return _scatter_4d(data, indices, updates, axis) raise ValueError("scatter only support for 1-4 dimensions") + + +def scatter_nd(data, indices, shape): + """Scatter elements from a n-dimension array. + + Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), indices with shape + (M, Y_0, ..., Y_{K-1}), and output with shape (X_0, X_1, ..., X_{N-1}), scatter_nd computes + + .. code-block:: + + output[indices[0, y_0, ..., y_{K-1}], + ..., + indices[M-1, y_0, ..., y_{K-1}], + x_M, + ..., + x_{N-1} + ] = data[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}] + + all other entries in the output are 0. Repeated indices are summed. + + Parameters + ---------- + data : tvm.te.Tensor + The source array. + + indices : tvm.te.Tensor + The indices of the values to extract. + + shape : Sequence[int] + The output shape. This must be specified because it cannot be inferred. 
+ + Returns + ------- + ret : tvm.te.Tensor + """ + assert indices.shape[0] <= len(shape), f"The first dimension of the indices ({indices.shape[0]}) must be less than or equal to the length of the shape of the output ({len(shape)})." + for i in range(len(indices.shape)-1): + assert indices.shape[i+1] == data.shape[i], f"Dimension of indices[{i+1}] ({indices.shape[i+1]}) must equal dimension of data[{i}] ({data.shape[i]})." + for i in range(int(indices.shape[0]), len(shape)): + assert data.shape[i] == out_shape[i], f"Dimension of data[{i}] must equal dimension of out_shape[{i}]" + + assert "int" in indices.dtype, f"Indices must be a tensor of integers, but its elements are {indices.dtype}" + + + def gen_ir(data_ptr, indices_ptr, out_ptr): + ib = ir_builder.create() + + data = ib.buffer_ptr(data_ptr) + indices = ib.buffer_ptr(indices_ptr) + out = ib.buffer_ptr(out_ptr) + + # zero data + # TODO(tkonolige): could we use topi.full to zero it instead? + fused_shape = 1 + for i in shape: + fused_shape *= i + with ib.for_range(0, fused_shape): + out[i] = Cast(data_ptr.dtype, 0) + + # We combine all the indices dimensions but the first one into a single + # dimension so we can iterate it in single loop instead of an arbitrary + # number of loops. We do the same thing for all the data dimensions. + fused_indices_dimension = 1 + for i in indices_ptr.shape[1:]: + fused_indices_dimension *= i + + fused_data_dimension = 1 + for i in data_ptr.shape[indices_ptr.shape[0].value :]: + fused_data_dimension *= i + + with ib.for_range(0, fused_indices_dimension) as i: + with ib.for_range(0, fused_data_dimension) as j: + offset = fused_data_dimension + index = j # This is x_M, .. x_{N-1} part of the index into out. + # Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}] part + # of the index into out. + for l in reversed(range(indices_ptr.shape[0].value)): + # indices[i * l * fused_indices_dimension] = indices[l, y_0, ... y_{k-1}] + index += offset * indices[i + l * fused_indices_dimension] + ib.emit(AssertStmt(indices[i + l * fused_indices_dimension] < shape[l], StringImm("index out of bounds"), Evaluate(0))) + offset *= shape[l] + out[index] = data[i * fused_data_dimension + j] + + return ib.get() + + out_buf = decl_buffer(shape, data.dtype, "out_buf") + return extern( + [shape], + [data, indices], + lambda ins, outs: gen_ir(ins[0], ins[1], outs[0]), + dtype=data.dtype, + out_buffers=[out_buf], + name="scatter_nd_generic", + tag="scatter_nd_generic", + ) diff --git a/src/relay/analysis/type_solver.cc b/src/relay/analysis/type_solver.cc index 8f14b557dc54..64db13acbac0 100644 --- a/src/relay/analysis/type_solver.cc +++ b/src/relay/analysis/type_solver.cc @@ -246,7 +246,7 @@ class TypeSolver::Unifier : public TypeFunctor { for (size_t i = 0; i < tt1->shape.size(); i++) { auto dim = UnifyDim(tt1->shape[i], tt2->shape[i]); if (!dim.defined()) { - // NB: We push an arbitrary dimension here so we can continue error propogation. + // NB: We push an arbitrary dimension here so we can continue error propagation. 
shape.push_back(tt1->shape[i]); tvm::PrimExpr shape1 = tt1->shape[i]; tvm::PrimExpr shape2 = tt2->shape[i]; @@ -259,10 +259,11 @@ class TypeSolver::Unifier : public TypeFunctor { if (mismatches.size() != 0) { auto err = Diagnostic::Error(this->span); - err << "in particular "; + err << "The Relay type checker is unable to show the following types match.\n"; + err << "In particular "; for (auto mismatch : mismatches) { - err << "dimension " << std::get<0>(mismatch) << " conflicts " << std::get<1>(mismatch) - << " does not match " << std::get<2>(mismatch); + err << "dimension " << std::get<0>(mismatch) << " conflicts: " << std::get<1>(mismatch) + << " does not match " << std::get<2>(mismatch) << "."; } this->solver_->diag_ctx_.Emit(err); return Type(nullptr); diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 71f88b2f258e..a7a8024999d9 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -977,6 +977,74 @@ RELAY_REGISTER_OP("scatter_add") .set_attr("TOpPattern", kOpaque) .set_support_level(10); +// scatter_nd operator +TVM_REGISTER_NODE_TYPE(ScatterNDAttrs); + +bool ScatterNDRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + // `types` contains: [data, indices, result] + ICHECK_EQ(types.size(), 3); + const auto* data = types[0].as(); + const auto* indices = types[1].as(); + if (data == nullptr) { + ICHECK(types[0].as()) + << "ScatterND: expect input data type to be TensorType but got " << types[0]; + return false; + } + if (indices == nullptr) { + ICHECK(types[1].as()) + << "ScatterND: expect indices type to be TensorType but got " << types[1]; + return false; + } + ICHECK(indices->dtype.is_int()) << "ScatterND: indices must be a tensor of integers."; + const auto out_shape = attrs.as()->out_shape; + const IntImmNode* mdim = indices->shape[0].as(); + const size_t kdim = indices->shape.size() - 1; + const size_t ndim = out_shape.size(); + ICHECK_LE(size_t(mdim->value), ndim) + << "ScatterND: Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), and indices " + "with shape (M, Y_0, ..., Y_{K-1}), M must be less than or equal to N."; + // Indices: (M, Y_0, .. Y_{K-1}) data: (Y_0, .. Y_{K-1}, ...), verify Y's. + for (size_t i = 0; i < kdim; i++) { + reporter->AssertEQ(indices->shape[i + 1], data->shape[i]); + } + + std::vector oshape; + for (auto& x : out_shape) { + oshape.push_back(x); + } + + // data: (Y_0, .. Y_{K-1}, X_M, .. X_{N-1}) out: (X_0, .. X_{N-1}), verify X_M to X_{N-1} + for (size_t i = mdim->value; i < ndim; i++) { + reporter->AssertEQ(data->shape[i], oshape[i]); + } + + reporter->Assign(types[2], TensorType(oshape, data->dtype)); + return true; +} + +Expr MakeScatterND(Expr data, Expr indices, const Array out_shape) { + auto attrs = make_object(); + attrs->out_shape = out_shape; + static const Op& op = Op::Get("scatter_nd"); + return Call(op, {data, indices}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op._make.scatter_nd").set_body_typed(MakeScatterND); + +RELAY_REGISTER_OP("scatter_nd") + .describe(R"code(Scatter elements or slices from data and store to a tensor +whose shape is defined by indices. + +Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}) and indices with shape +(M, Y_0, ..., Y_{K-1}), the output will have shape (X_0, X_1, ..., X_{N-1}). 
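For reference, the scatter semantics described above can be written out in a few lines of NumPy. This is an illustrative sketch only (`scatter_nd_ref` is a hypothetical helper, not part of the patch); it reproduces the first expected output used in the topi tests below:

    import numpy as np

    def scatter_nd_ref(data, indices, shape):
        # out[indices[0, y], ..., indices[M-1, y], x_M, ..., x_{N-1}] = data[y, x_M, ..., x_{N-1}]
        out = np.zeros(shape, dtype=data.dtype)
        m = indices.shape[0]
        for y in np.ndindex(*indices.shape[1:]):  # iterate over Y_0, ..., Y_{K-1}
            idx = tuple(int(indices[(l,) + y]) for l in range(m))
            out[idx] += data[y]  # repeated indices are summed
        return out

    print(scatter_nd_ref(np.array([2, 3, 0]), np.array([[1, 1, 0], [0, 1, 0]]), (2, 2)))
    # [[0 0]
    #  [2 3]]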
+)code" TVM_ADD_FILELINE) + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(3) + .add_type_rel("ScatterND", ScatterNDRel) + .set_attr("TOpPattern", kInjective); + // Take TVM_REGISTER_NODE_TYPE(TakeAttrs); diff --git a/tests/python/relay/test_op_grad_level3.py b/tests/python/relay/test_op_grad_level3.py index 9c27afd87205..358441738ece 100644 --- a/tests/python/relay/test_op_grad_level3.py +++ b/tests/python/relay/test_op_grad_level3.py @@ -117,5 +117,23 @@ def test_arange_grad(): check_grad(fwd_func, inputs=values) +def test_gather_nd_grad(): + data = relay.var("data", relay.TensorType((2, 3), "float64")) + indices = relay.var("indices", relay.TensorType((2, 4), "int64")) + fwd = relay.Function([data, indices], relay.gather_nd(data, indices)) + data_np = np.random.rand(2, 3) + indices_np = np.array([[0, 2, 1, 0], [0, 1, 2, 1]]) + check_grad(fwd, inputs=[data_np, indices_np], test_inputs=[indices_np]) + + +# def test_scatter_nd_grad(): +# data = relay.var("data", relay.TensorType((2, 2), "float64")) +# indices = relay.var("indices", relay.TensorType((2, 2), "int64")) +# fwd = relay.Function([data, indices], relay.scatter_nd(data, indices, (2, 2))) +# data_np = np.array([[0, 1], [2, 3]]).astype("float64") +# indices_np = np.array([[1, 0], [1, 1]]) +# check_grad(fwd, inputs=[data_np, indices_np], test_inputs=[indices_np]) + + if __name__ == "__main__": pytest.main() diff --git a/tests/python/topi/python/test_topi_scatter.py b/tests/python/topi/python/test_topi_scatter.py new file mode 100644 index 000000000000..c8c5d0f40004 --- /dev/null +++ b/tests/python/topi/python/test_topi_scatter.py @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import numpy as np +import tvm +import tvm.testing +from tvm import topi +import tvm.topi.testing + + +@tvm.testing.parametrize_targets +def test_scatter_nd(ctx, target): + def check_scatter_nd(data, indices, shape, out): + implementations = { + "generic": (lambda x,y: topi.scatter_nd(x,y,shape), topi.generic.schedule_extern), + } + fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations) + tvm.testing.compare_numpy_tvm( + [data, indices], out, target, ctx, fcompute, fschedule + ) + + data = np.array([2, 3, 0]) + indices = np.array([[1, 1, 0], [0, 1, 0]]) + shape = (2, 2) + out = np.array([[0, 0], [2, 3]]) + check_scatter_nd(data, indices, shape, out) + + data = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + indices = np.array([[0, 1], [1, 1]]) + shape = (2, 2, 2, 2) + out = np.array([[[[0, 0], [1, 2]], [[0, 0], [3, 4]]], [[[0, 0], [0, 0]], [[0, 0], [0, 0]]]]) + check_scatter_nd(data, indices, shape, out) + +if __name__ == "__main__": + test_scatter_nd(tvm.context("cpu"), tvm.target.Target("llvm")) From 3aa5bdec90ddb93b94b15380d2afe497f84cccea Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Thu, 5 Nov 2020 09:57:07 -0800 Subject: [PATCH 02/13] formatting --- python/tvm/relay/backend/compile_engine.py | 6 ++-- python/tvm/relay/op/_tensor_grad.py | 7 ---- python/tvm/relay/op/strategy/generic.py | 5 ++- python/tvm/te/operation.py | 6 +++- python/tvm/testing.py | 29 ----------------- python/tvm/topi/scatter.py | 32 +++++++++++++------ python/tvm/topi/testing/common.py | 29 +++++++++++++++++ tests/python/relay/test_op_grad_level3.py | 9 ------ tests/python/topi/python/test_topi_scatter.py | 7 ++-- 9 files changed, 68 insertions(+), 62 deletions(-) diff --git a/python/tvm/relay/backend/compile_engine.py b/python/tvm/relay/backend/compile_engine.py index 14e1f5d85d9f..43643a7be745 100644 --- a/python/tvm/relay/backend/compile_engine.py +++ b/python/tvm/relay/backend/compile_engine.py @@ -121,8 +121,10 @@ def get_valid_implementations(op, attrs, inputs, out_type, target): The list of all valid op implementations. """ fstrategy = op.get_attr("FTVMStrategy") - assert fstrategy is not None, "%s doesn't have an FTVMStrategy registered. You can register " \ - "one in python with `tvm.relay.op.register_strategy`." % op.name + assert fstrategy is not None, ( + "%s doesn't have an FTVMStrategy registered. You can register " + "one in python with `tvm.relay.op.register_strategy`." 
% op.name + ) with target: strategy = fstrategy(attrs, inputs, out_type, target) analyzer = tvm.arith.Analyzer() diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py index b200aa125270..9c84411352f2 100644 --- a/python/tvm/relay/op/_tensor_grad.py +++ b/python/tvm/relay/op/_tensor_grad.py @@ -62,7 +62,6 @@ squeeze, strided_set, arange, - gather_nd, scatter_nd, ) @@ -811,9 +810,3 @@ def arange_grad(orig, grad): def gather_nd_grad(orig, grad): data, indices = orig.args return [scatter_nd(grad, indices, data.checked_type.concrete_shape), zeros_like(indices)] - - -# @register_gradient("scatter_nd") -# def scatter_nd_grad(orig, grad): -# data, indices = orig.args -# return [gather_nd(grad, indices), zeros_like(indices)] diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 54e1f9a9fe2b..ac9d3b157ec4 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -1062,17 +1062,20 @@ def scatter_add_strategy(attrs, outs, out_type, target): ) return strategy + # scatter_nd @override_native_generic_func("scatter_nd_strategy") def scatter_nd_strategy(attrs, inputs, out_type, target): """scatter_nd generic strategy""" strategy = _op.OpStrategy() strategy.add_implementation( - wrap_compute_scatter_nd(topi.scatter_nd), wrap_topi_schedule(topi.generic.schedule_extern), + wrap_compute_scatter_nd(topi.scatter_nd), + wrap_topi_schedule(topi.generic.schedule_extern), name="scatter_nd.generic", ) return strategy + def wrap_compute_scatter_nd(topi_compute): """Wrap scatter_nd topi compute""" diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py index a924c8b0c0db..0f3457af0f10 100644 --- a/python/tvm/te/operation.py +++ b/python/tvm/te/operation.py @@ -317,7 +317,11 @@ def extern( if isinstance(body, tvm.tir.PrimExpr): body = tvm.tir.Evaluate(body) if not isinstance(body, tvm.tir.Stmt): - raise ValueError("Function '{}' should return PrimExpr or Stmt, but it returned '{}'".format(fcompute.__name__, type(body))) + raise ValueError( + "Function '{}' should return PrimExpr or Stmt, but it returned '{}'".format( + fcompute.__name__, type(body) + ) + ) op = _ffi_api.ExternOp(name, tag, attrs, inputs, input_placeholders, output_placeholders, body) res = [op.output(i) for i in range(len(output_placeholders))] diff --git a/python/tvm/testing.py b/python/tvm/testing.py index 7286aa4bbbd9..e5b17f3d7b53 100644 --- a/python/tvm/testing.py +++ b/python/tvm/testing.py @@ -714,33 +714,4 @@ def func(f): return wrap(args) -def compare_numpy_tvm(inputs, output, target, ctx, compute, schedule): - """Compare a numpy inputs and output of a function to the results of the TVM version. - - Parameters - ---------- - inputs : Sequence[numpy.nd.array] - List of input numpy arrays to pass to the function. - output : numpy.nd.array - Verified correct function output. - target : tvm.target.Target - Target to run on. - ctx : tvm.TVMContext - Context to run on. - compute : callable - Topi compute function to test against. - schedule : callable - Topi scheduling function to test against. 
- """ - te_inputs = [tvm.te.placeholder(shape=i.shape, dtype=str(i.dtype)) for i in inputs] - te_out = tvm.nd.array(np.zeros(output.shape).astype(output.dtype), ctx=ctx) - with tvm.target.Target(target): - out = compute(*te_inputs) - s = schedule([out]) - func = tvm.build(s, te_inputs + [out]) - arys = [tvm.nd.array(x, ctx=ctx) for x in inputs] - func(*(arys + [te_out])) - assert_allclose(output, te_out.asnumpy(), atol=1e-4, rtol=1e-4) - - tvm._ffi._init_api("testing", __name__) diff --git a/python/tvm/topi/scatter.py b/python/tvm/topi/scatter.py index 382a91e790c6..6c1d1ab39176 100644 --- a/python/tvm/topi/scatter.py +++ b/python/tvm/topi/scatter.py @@ -18,7 +18,6 @@ """Scatter operator""" from ..tir import decl_buffer, ir_builder, Cast, AssertStmt, StringImm, Evaluate from ..te import extern, hybrid -from . import full @hybrid.script @@ -233,14 +232,23 @@ def scatter_nd(data, indices, shape): ------- ret : tvm.te.Tensor """ - assert indices.shape[0] <= len(shape), f"The first dimension of the indices ({indices.shape[0]}) must be less than or equal to the length of the shape of the output ({len(shape)})." - for i in range(len(indices.shape)-1): - assert indices.shape[i+1] == data.shape[i], f"Dimension of indices[{i+1}] ({indices.shape[i+1]}) must equal dimension of data[{i}] ({data.shape[i]})." + assert indices.shape[0] <= len(shape), ( + f"The first dimension of the indices ({indices.shape[0]}) must be less than or equal to " + f"the length of the shape of the output ({len(shape)})." + ) + for i in range(len(indices.shape) - 1): + assert indices.shape[i + 1] == data.shape[i], ( + f"Dimension of indices[{i+1}] ({indices.shape[i+1]}) must equal dimension of " + f"data[{i}] ({data.shape[i]})." + ) for i in range(int(indices.shape[0]), len(shape)): - assert data.shape[i] == out_shape[i], f"Dimension of data[{i}] must equal dimension of out_shape[{i}]" - - assert "int" in indices.dtype, f"Indices must be a tensor of integers, but its elements are {indices.dtype}" + assert ( + data.shape[i] == out_shape[i] + ), f"Dimension of data[{i}] must equal dimension of out_shape[{i}]" + assert ( + "int" in indices.dtype + ), f"Indices must be a tensor of integers, but its elements are {indices.dtype}" def gen_ir(data_ptr, indices_ptr, out_ptr): ib = ir_builder.create() @@ -254,7 +262,7 @@ def gen_ir(data_ptr, indices_ptr, out_ptr): fused_shape = 1 for i in shape: fused_shape *= i - with ib.for_range(0, fused_shape): + with ib.for_range(0, fused_shape) as i: out[i] = Cast(data_ptr.dtype, 0) # We combine all the indices dimensions but the first one into a single @@ -277,7 +285,13 @@ def gen_ir(data_ptr, indices_ptr, out_ptr): for l in reversed(range(indices_ptr.shape[0].value)): # indices[i * l * fused_indices_dimension] = indices[l, y_0, ... 
y_{k-1}] index += offset * indices[i + l * fused_indices_dimension] - ib.emit(AssertStmt(indices[i + l * fused_indices_dimension] < shape[l], StringImm("index out of bounds"), Evaluate(0))) + ib.emit( + AssertStmt( + indices[i + l * fused_indices_dimension] < shape[l], + StringImm("index out of bounds"), + Evaluate(0), + ) + ) offset *= shape[l] out[index] = data[i * fused_data_dimension + j] diff --git a/python/tvm/topi/testing/common.py b/python/tvm/topi/testing/common.py index 51ea19afe7ce..35a6040fa25a 100644 --- a/python/tvm/topi/testing/common.py +++ b/python/tvm/topi/testing/common.py @@ -77,3 +77,32 @@ def get_reduce_schedule(target): def get_conv2d_nchw_implement(target): return dispatch(target, _conv2d_nchw_implement) + + +def compare_numpy_tvm(inputs, output, target, ctx, compute, schedule): + """Compare a numpy inputs and output of a function to the results of the TVM version. + + Parameters + ---------- + inputs : Sequence[numpy.nd.array] + List of input numpy arrays to pass to the function. + output : numpy.nd.array + Verified correct function output. + target : tvm.target.Target + Target to run on. + ctx : tvm.TVMContext + Context to run on. + compute : callable + Topi compute function to test against. + schedule : callable + Topi scheduling function to test against. + """ + te_inputs = [tvm.te.placeholder(shape=i.shape, dtype=str(i.dtype)) for i in inputs] + te_out = tvm.nd.array(np.zeros(output.shape).astype(output.dtype), ctx=ctx) + with tvm.target.Target(target): + out = compute(*te_inputs) + s = schedule([out]) + func = tvm.build(s, te_inputs + [out]) + arys = [tvm.nd.array(x, ctx=ctx) for x in inputs] + func(*(arys + [te_out])) + assert_allclose(output, te_out.asnumpy(), atol=1e-4, rtol=1e-4) diff --git a/tests/python/relay/test_op_grad_level3.py b/tests/python/relay/test_op_grad_level3.py index 358441738ece..a5cb916da613 100644 --- a/tests/python/relay/test_op_grad_level3.py +++ b/tests/python/relay/test_op_grad_level3.py @@ -126,14 +126,5 @@ def test_gather_nd_grad(): check_grad(fwd, inputs=[data_np, indices_np], test_inputs=[indices_np]) -# def test_scatter_nd_grad(): -# data = relay.var("data", relay.TensorType((2, 2), "float64")) -# indices = relay.var("indices", relay.TensorType((2, 2), "int64")) -# fwd = relay.Function([data, indices], relay.scatter_nd(data, indices, (2, 2))) -# data_np = np.array([[0, 1], [2, 3]]).astype("float64") -# indices_np = np.array([[1, 0], [1, 1]]) -# check_grad(fwd, inputs=[data_np, indices_np], test_inputs=[indices_np]) - - if __name__ == "__main__": pytest.main() diff --git a/tests/python/topi/python/test_topi_scatter.py b/tests/python/topi/python/test_topi_scatter.py index c8c5d0f40004..ef8f94609471 100644 --- a/tests/python/topi/python/test_topi_scatter.py +++ b/tests/python/topi/python/test_topi_scatter.py @@ -25,12 +25,10 @@ def test_scatter_nd(ctx, target): def check_scatter_nd(data, indices, shape, out): implementations = { - "generic": (lambda x,y: topi.scatter_nd(x,y,shape), topi.generic.schedule_extern), + "generic": (lambda x, y: topi.scatter_nd(x, y, shape), topi.generic.schedule_extern), } fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations) - tvm.testing.compare_numpy_tvm( - [data, indices], out, target, ctx, fcompute, fschedule - ) + tvm.topi.testing.compare_numpy_tvm([data, indices], out, target, ctx, fcompute, fschedule) data = np.array([2, 3, 0]) indices = np.array([[1, 1, 0], [0, 1, 0]]) @@ -44,5 +42,6 @@ def check_scatter_nd(data, indices, shape, out): out = np.array([[[[0, 0], [1, 2]], [[0, 
0], [3, 4]]], [[[0, 0], [0, 0]], [[0, 0], [0, 0]]]]) check_scatter_nd(data, indices, shape, out) + if __name__ == "__main__": test_scatter_nd(tvm.context("cpu"), tvm.target.Target("llvm")) From dff36447bab8db356df91791d5ce8674299c32ef Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Thu, 5 Nov 2020 13:46:58 -0800 Subject: [PATCH 03/13] Fix tests --- python/tvm/topi/scatter.py | 5 +++-- python/tvm/topi/testing/__init__.py | 1 + python/tvm/topi/testing/common.py | 3 +++ tests/python/relay/test_any.py | 5 ++++- tests/python/relay/test_op_grad_level3.py | 4 ++-- 5 files changed, 13 insertions(+), 5 deletions(-) diff --git a/python/tvm/topi/scatter.py b/python/tvm/topi/scatter.py index 6c1d1ab39176..e8c87729f946 100644 --- a/python/tvm/topi/scatter.py +++ b/python/tvm/topi/scatter.py @@ -241,9 +241,10 @@ def scatter_nd(data, indices, shape): f"Dimension of indices[{i+1}] ({indices.shape[i+1]}) must equal dimension of " f"data[{i}] ({data.shape[i]})." ) - for i in range(int(indices.shape[0]), len(shape)): + mdim = int(indices.shape[0]) + for i in range(mdim, len(shape)): assert ( - data.shape[i] == out_shape[i] + data.shape[i-mdim] == shape[i] ), f"Dimension of data[{i}] must equal dimension of out_shape[{i}]" assert ( diff --git a/python/tvm/topi/testing/__init__.py b/python/tvm/topi/testing/__init__.py index 5b23e8f4600e..18a46b17bb0a 100644 --- a/python/tvm/topi/testing/__init__.py +++ b/python/tvm/topi/testing/__init__.py @@ -57,6 +57,7 @@ from .space_to_depth import space_to_depth_python from .crop_and_resize_python import crop_and_resize_python from .common import ( + compare_numpy_tvm, get_injective_schedule, get_reduce_schedule, get_broadcast_schedule, diff --git a/python/tvm/topi/testing/common.py b/python/tvm/topi/testing/common.py index 35a6040fa25a..5639662d5a9d 100644 --- a/python/tvm/topi/testing/common.py +++ b/python/tvm/topi/testing/common.py @@ -19,6 +19,9 @@ import tvm from tvm import topi +from tvm.testing import assert_allclose + +import numpy as np _injective_schedule = { "generic": topi.generic.schedule_injective, diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 546973704fea..eec6aa21c69b 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -989,7 +989,10 @@ def _body(i, st): body = loop(start, relay.op.reshape(relay.const(0), newshape=(1, 1))) func = relay.Function([start], relay.TupleGetItem(body, 1)) with DiagnosticTesting() as diagnostics: - diagnostics.assert_message("in particular dimension 0 conflicts 2 does not match 1") + diagnostics.assert_message( + "The Relay type checker is unable to show the following types " + "match.\nIn particular dimension 0 conflicts: 2 does not match 1." 
+ ) func = infer_type(func) diff --git a/tests/python/relay/test_op_grad_level3.py b/tests/python/relay/test_op_grad_level3.py index a5cb916da613..a7443c65ac8c 100644 --- a/tests/python/relay/test_op_grad_level3.py +++ b/tests/python/relay/test_op_grad_level3.py @@ -122,8 +122,8 @@ def test_gather_nd_grad(): indices = relay.var("indices", relay.TensorType((2, 4), "int64")) fwd = relay.Function([data, indices], relay.gather_nd(data, indices)) data_np = np.random.rand(2, 3) - indices_np = np.array([[0, 2, 1, 0], [0, 1, 2, 1]]) - check_grad(fwd, inputs=[data_np, indices_np], test_inputs=[indices_np]) + indices_np = np.array([[0, 1, 1, 0], [0, 1, 2, 1]]) + check_grad(fwd, inputs=[data_np, indices_np], test_inputs=indices_np) if __name__ == "__main__": From 4373ebab5bff07b36b324ec538e534537d1ae3d1 Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Thu, 5 Nov 2020 13:54:41 -0800 Subject: [PATCH 04/13] formatting --- python/tvm/topi/scatter.py | 2 +- python/tvm/topi/testing/common.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/python/tvm/topi/scatter.py b/python/tvm/topi/scatter.py index e8c87729f946..d39e98d3419a 100644 --- a/python/tvm/topi/scatter.py +++ b/python/tvm/topi/scatter.py @@ -244,7 +244,7 @@ def scatter_nd(data, indices, shape): mdim = int(indices.shape[0]) for i in range(mdim, len(shape)): assert ( - data.shape[i-mdim] == shape[i] + data.shape[i - mdim] == shape[i] ), f"Dimension of data[{i}] must equal dimension of out_shape[{i}]" assert ( diff --git a/python/tvm/topi/testing/common.py b/python/tvm/topi/testing/common.py index 5639662d5a9d..e97cd3c34428 100644 --- a/python/tvm/topi/testing/common.py +++ b/python/tvm/topi/testing/common.py @@ -17,12 +17,11 @@ # pylint: disable=invalid-name """Common utility for topi test""" +import numpy as np import tvm from tvm import topi from tvm.testing import assert_allclose -import numpy as np - _injective_schedule = { "generic": topi.generic.schedule_injective, "cpu": topi.x86.schedule_injective, From 1d9d1f9d6e3f6acf871cbb4deb32c045ea28367b Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Thu, 5 Nov 2020 15:37:14 -0800 Subject: [PATCH 05/13] specify types on test --- tests/python/relay/test_op_grad_level3.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/relay/test_op_grad_level3.py b/tests/python/relay/test_op_grad_level3.py index a7443c65ac8c..182de0a9fea1 100644 --- a/tests/python/relay/test_op_grad_level3.py +++ b/tests/python/relay/test_op_grad_level3.py @@ -121,8 +121,8 @@ def test_gather_nd_grad(): data = relay.var("data", relay.TensorType((2, 3), "float64")) indices = relay.var("indices", relay.TensorType((2, 4), "int64")) fwd = relay.Function([data, indices], relay.gather_nd(data, indices)) - data_np = np.random.rand(2, 3) - indices_np = np.array([[0, 1, 1, 0], [0, 1, 2, 1]]) + data_np = np.random.rand(2, 3).astype("float64") + indices_np = np.array([[0, 1, 1, 0], [0, 1, 2, 1]]).astype("int64") check_grad(fwd, inputs=[data_np, indices_np], test_inputs=indices_np) From 288873f8ee60628f375792361748ad6155bb75d8 Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Fri, 6 Nov 2020 11:52:30 -0800 Subject: [PATCH 06/13] Fix grad test --- python/tvm/relay/testing/__init__.py | 2 ++ python/tvm/topi/scatter.py | 2 +- tests/python/relay/test_op_grad_level3.py | 4 ++-- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/testing/__init__.py b/python/tvm/relay/testing/__init__.py index 9c87f2795e5c..93110e313642 100644 --- 
a/python/tvm/relay/testing/__init__.py +++ b/python/tvm/relay/testing/__init__.py @@ -143,6 +143,8 @@ def check_grad( break grads = tmp + assert len(grads) > 0, "You must test at least one gradient." + # Get numeric gradients for each dimension of each param, using two-sided approximation. approx_grads = [] for x in test_inputs: diff --git a/python/tvm/topi/scatter.py b/python/tvm/topi/scatter.py index d39e98d3419a..21439c9fb537 100644 --- a/python/tvm/topi/scatter.py +++ b/python/tvm/topi/scatter.py @@ -294,7 +294,7 @@ def gen_ir(data_ptr, indices_ptr, out_ptr): ) ) offset *= shape[l] - out[index] = data[i * fused_data_dimension + j] + out[index] += data[i * fused_data_dimension + j] return ib.get() diff --git a/tests/python/relay/test_op_grad_level3.py b/tests/python/relay/test_op_grad_level3.py index 182de0a9fea1..98ff62ed75d4 100644 --- a/tests/python/relay/test_op_grad_level3.py +++ b/tests/python/relay/test_op_grad_level3.py @@ -122,8 +122,8 @@ def test_gather_nd_grad(): indices = relay.var("indices", relay.TensorType((2, 4), "int64")) fwd = relay.Function([data, indices], relay.gather_nd(data, indices)) data_np = np.random.rand(2, 3).astype("float64") - indices_np = np.array([[0, 1, 1, 0], [0, 1, 2, 1]]).astype("int64") - check_grad(fwd, inputs=[data_np, indices_np], test_inputs=indices_np) + indices_np = np.array([[0, 1, 1, 0], [0, 1, 0, 0]], dtype="int64") + check_grad(fwd, inputs=[data_np, indices_np], test_inputs=[data_np]) if __name__ == "__main__": From 60d94d73717a2f145c404c93ca6a070ca76964c7 Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Fri, 6 Nov 2020 11:58:27 -0800 Subject: [PATCH 07/13] scatter_nd cuda impl --- python/tvm/relay/op/strategy/cuda.py | 13 ++ python/tvm/topi/cuda/scatter.py | 120 ++++++++++++++++++ tests/python/topi/python/test_topi_scatter.py | 1 + 3 files changed, 134 insertions(+) diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index 1229a71569d0..4959bc88f4db 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -712,6 +712,19 @@ def scatter_add_cuda(attrs, inputs, out_type, target): return strategy +@scatter_nd_strategy.register(["cuda", "gpu"]) +def scatter_nd_cuda(attrs, inputs, out_type, target): + """scatter_nd cuda strategy""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_scatter_nd(topi.cuda.scatter_nd), + wrap_topi_schedule(topi.generic.schedule_extern), + name="scatter_nd.cuda", + plevel=10, + ) + return strategy + + @argsort_strategy.register(["cuda", "gpu"]) def argsort_strategy_cuda(attrs, inputs, out_type, target): """argsort cuda strategy""" diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py index 0a3e96f4be30..8dffdfce887f 100644 --- a/python/tvm/topi/cuda/scatter.py +++ b/python/tvm/topi/cuda/scatter.py @@ -522,3 +522,123 @@ def update_func(dst_ptr, dst_index, update): ) return out + + +def scatter_nd(data, indices, shape): + """Scatter elements from a n-dimension array. + + Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), indices with shape + (M, Y_0, ..., Y_{K-1}), and output with shape (X_0, X_1, ..., X_{N-1}), scatter_nd computes + + .. code-block:: + + output[indices[0, y_0, ..., y_{K-1}], + ..., + indices[M-1, y_0, ..., y_{K-1}], + x_M, + ..., + x_{N-1} + ] = data[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}] + + all other entries in the output are 0. Repeated indices are summed. + + Parameters + ---------- + data : tvm.te.Tensor + The source array. 
+ + indices : tvm.te.Tensor + The indices of the values to extract. + + shape : Sequence[int] + The output shape. This must be specified because it cannot be inferred. + + Returns + ------- + ret : tvm.te.Tensor + """ + assert indices.shape[0] <= len(shape), ( + f"The first dimension of the indices ({indices.shape[0]}) must be less than or equal to " + f"the length of the shape of the output ({len(shape)})." + ) + for i in range(len(indices.shape) - 1): + assert indices.shape[i + 1] == data.shape[i], ( + f"Dimension of indices[{i+1}] ({indices.shape[i+1]}) must equal dimension of " + f"data[{i}] ({data.shape[i]})." + ) + mdim = int(indices.shape[0]) + for i in range(mdim, len(shape)): + assert ( + data.shape[i - mdim] == shape[i] + ), f"Dimension of data[{i}] must equal dimension of out_shape[{i}]" + + assert ( + "int" in indices.dtype + ), f"Indices must be a tensor of integers, but its elements are {indices.dtype}" + + def gen_ir(data_ptr, indices_ptr, out_ptr): + ib = ir_builder.create() + + data = ib.buffer_ptr(data_ptr) + indices = ib.buffer_ptr(indices_ptr) + out = ib.buffer_ptr(out_ptr) + + # zero data + # TODO(tkonolige): could we use topi.full to zero it instead? + fused_shape = 1 + for i in shape: + fused_shape *= i + with ib.for_range(0, fused_shape) as i: + out[i] = Cast(data_ptr.dtype, 0) + + # We combine all the indices dimensions but the first one into a single + # dimension so we can iterate it in single loop instead of an arbitrary + # number of loops. We do the same thing for all the data dimensions. + fused_indices_dimension = 1 + for i in indices_ptr.shape[1:]: + fused_indices_dimension *= i + + fused_data_dimension = 1 + for i in data_ptr.shape[indices_ptr.shape[0].value :]: + fused_data_dimension *= i + + bx = te.thread_axis("blockIdx.x") + tx = te.thread_axis("threadIdx.x") + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) + tdim = min(max_threads, fused_data_dimension) + ib.scope_attr(tx, "thread_extent", tdim) + bdim = ceil_div(fused_data_dimension, tdim) + ib.scope_attr(bx, "thread_extent", bdim) + + + with ib.for_range(0, fused_indices_dimension) as i: + j = bx * tdim + tx + offset = fused_data_dimension + index = j # This is x_M, .. x_{N-1} part of the index into out. + # Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}] part + # of the index into out. + for l in reversed(range(indices_ptr.shape[0].value)): + # indices[i * l * fused_indices_dimension] = indices[l, y_0, ... 
y_{k-1}] + index += offset * indices[i + l * fused_indices_dimension] + ib.emit( + AssertStmt( + indices[i + l * fused_indices_dimension] < shape[l], + StringImm("index out of bounds"), + Evaluate(0), + ) + ) + offset *= shape[l] + out[index] += data[i * fused_data_dimension + j] + + return ib.get() + + out_buf = decl_buffer(shape, data.dtype, "out_buf") + return extern( + [shape], + [data, indices], + lambda ins, outs: gen_ir(ins[0], ins[1], outs[0]), + dtype=data.dtype, + out_buffers=[out_buf], + name="scatter_nd_cuda", + tag="scatter_nd_cuda", + ) diff --git a/tests/python/topi/python/test_topi_scatter.py b/tests/python/topi/python/test_topi_scatter.py index ef8f94609471..ae26e1bb0f6b 100644 --- a/tests/python/topi/python/test_topi_scatter.py +++ b/tests/python/topi/python/test_topi_scatter.py @@ -26,6 +26,7 @@ def test_scatter_nd(ctx, target): def check_scatter_nd(data, indices, shape, out): implementations = { "generic": (lambda x, y: topi.scatter_nd(x, y, shape), topi.generic.schedule_extern), + "cuda": (lambda x, y: topi.cuda.scatter_nd(x, y, shape), topi.generic.schedule_extern), } fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations) tvm.topi.testing.compare_numpy_tvm([data, indices], out, target, ctx, fcompute, fschedule) From 557b915d9def68971558fcf0782e55678a024421 Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Fri, 6 Nov 2020 13:46:06 -0800 Subject: [PATCH 08/13] cuda impl --- python/tvm/topi/cuda/scatter.py | 76 ++++++++----------- python/tvm/topi/scatter.py | 41 +++++----- tests/python/topi/python/test_topi_scatter.py | 9 +++ 3 files changed, 63 insertions(+), 63 deletions(-) diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py index 8dffdfce887f..ac6bf830da3e 100644 --- a/python/tvm/topi/cuda/scatter.py +++ b/python/tvm/topi/cuda/scatter.py @@ -18,6 +18,7 @@ """Scatter operator """ import tvm from tvm import te +from ..scatter import _verify_scatter_nd_inputs def ceil_div(a, b): @@ -557,40 +558,15 @@ def scatter_nd(data, indices, shape): ------- ret : tvm.te.Tensor """ - assert indices.shape[0] <= len(shape), ( - f"The first dimension of the indices ({indices.shape[0]}) must be less than or equal to " - f"the length of the shape of the output ({len(shape)})." - ) - for i in range(len(indices.shape) - 1): - assert indices.shape[i + 1] == data.shape[i], ( - f"Dimension of indices[{i+1}] ({indices.shape[i+1]}) must equal dimension of " - f"data[{i}] ({data.shape[i]})." - ) - mdim = int(indices.shape[0]) - for i in range(mdim, len(shape)): - assert ( - data.shape[i - mdim] == shape[i] - ), f"Dimension of data[{i}] must equal dimension of out_shape[{i}]" - - assert ( - "int" in indices.dtype - ), f"Indices must be a tensor of integers, but its elements are {indices.dtype}" + _verify_scatter_nd_inputs(data, indices, shape) def gen_ir(data_ptr, indices_ptr, out_ptr): - ib = ir_builder.create() + ib = tvm.tir.ir_builder.create() data = ib.buffer_ptr(data_ptr) indices = ib.buffer_ptr(indices_ptr) out = ib.buffer_ptr(out_ptr) - # zero data - # TODO(tkonolige): could we use topi.full to zero it instead? - fused_shape = 1 - for i in shape: - fused_shape *= i - with ib.for_range(0, fused_shape) as i: - out[i] = Cast(data_ptr.dtype, 0) - # We combine all the indices dimensions but the first one into a single # dimension so we can iterate it in single loop instead of an arbitrary # number of loops. We do the same thing for all the data dimensions. 
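One note on the fused indexing used throughout these `gen_ir` bodies (the generic one earlier in the series and the CUDA one in the hunks around this point): the loop that accumulates `index` from `offset` is just row-major flattening of (indices[0, y], ..., indices[M-1, y], x_M, ..., x_{N-1}) into a single offset into `out`. A small NumPy check of that equivalence, with made-up sizes:

    import numpy as np

    shape = (4, 5, 6, 7)       # out shape X_0..X_3, with M = 2 indexed dimensions
    scattered_idx = (2, 3)     # stands in for indices[0, y], indices[1, y]
    j = 11                     # flattened (x_2, x_3) part, 0 <= j < 6 * 7

    # Same accumulation as gen_ir: start from the trailing part, then add the
    # indexed dimensions from innermost (l = M - 1) to outermost (l = 0).
    offset, index = 6 * 7, j
    for l in reversed(range(2)):
        index += offset * scattered_idx[l]
        offset *= shape[l]

    x2, x3 = divmod(j, 7)
    assert index == np.ravel_multi_index((2, 3, x2, x3), shape)  # both are 557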
@@ -602,6 +578,16 @@ def gen_ir(data_ptr, indices_ptr, out_ptr): for i in data_ptr.shape[indices_ptr.shape[0].value :]: fused_data_dimension *= i + fused_shape = 1 + for i in shape: + fused_shape *= i + + # For now we avoid parallizing over dimensions indexed by `indices` as + # there may be repeated indices and hadling parallel accumulation can + # be hard. So we parallelize over X_M .. X_{N-1} instead. This will + # work well when these dimensions are large enough to saturate memory + # bandwidth, but performance will be bad when these dimensions are + # small. bx = te.thread_axis("blockIdx.x") tx = te.thread_axis("threadIdx.x") max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) @@ -610,30 +596,30 @@ def gen_ir(data_ptr, indices_ptr, out_ptr): bdim = ceil_div(fused_data_dimension, tdim) ib.scope_attr(bx, "thread_extent", bdim) + # zero data + # TODO(tkonolige): could we use topi.full to zero it instead? + with ib.for_range(0, ceil_div(fused_shape, bdim)) as i: + index = i * fused_data_dimension + bx * tdim + tx + with ib.if_scope(index < fused_shape): + out[index] = tvm.tir.Cast(data_ptr.dtype, 0) with ib.for_range(0, fused_indices_dimension) as i: j = bx * tdim + tx - offset = fused_data_dimension - index = j # This is x_M, .. x_{N-1} part of the index into out. - # Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}] part - # of the index into out. - for l in reversed(range(indices_ptr.shape[0].value)): - # indices[i * l * fused_indices_dimension] = indices[l, y_0, ... y_{k-1}] - index += offset * indices[i + l * fused_indices_dimension] - ib.emit( - AssertStmt( - indices[i + l * fused_indices_dimension] < shape[l], - StringImm("index out of bounds"), - Evaluate(0), - ) - ) - offset *= shape[l] - out[index] += data[i * fused_data_dimension + j] + with ib.if_scope(j < fused_data_dimension): + offset = fused_data_dimension + index = j # This is x_M, .. x_{N-1} part of the index into out. + # Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}] part + # of the index into out. + for l in reversed(range(indices_ptr.shape[0].value)): + # indices[i * l * fused_indices_dimension] = indices[l, y_0, ... y_{k-1}] + index += offset * indices[i + l * fused_indices_dimension] + offset *= shape[l] + out[index] += data[i * fused_data_dimension + j] return ib.get() - out_buf = decl_buffer(shape, data.dtype, "out_buf") - return extern( + out_buf = tvm.tir.decl_buffer(shape, data.dtype, "out_buf") + return te.extern( [shape], [data, indices], lambda ins, outs: gen_ir(ins[0], ins[1], outs[0]), diff --git a/python/tvm/topi/scatter.py b/python/tvm/topi/scatter.py index 21439c9fb537..4f708645bcff 100644 --- a/python/tvm/topi/scatter.py +++ b/python/tvm/topi/scatter.py @@ -199,6 +199,28 @@ def scatter(data, indices, updates, axis=0): raise ValueError("scatter only support for 1-4 dimensions") +def _verify_scatter_nd_inputs(data, indices, shape): + mdim = int(indices.shape[0]) + assert mdim <= len(shape), ( + f"The first dimension of the indices ({mdim}) must be less than or equal to " + f"the length of the shape of the output ({len(shape)})." + ) + for i in range(len(indices.shape) - 1): + assert indices.shape[i + 1] == data.shape[i], ( + f"Dimension of indices[{i+1}] ({indices.shape[i+1]}) must equal dimension of " + f"data[{i}] ({data.shape[i]})." 
+ ) + for i in range(mdim, len(shape)): + data_ind = i - mdim + len(indices.shape) - 1 + assert ( + data.shape[data_ind] == shape[i] + ), f"Dimension of data[{data_ind}] ({data.shape[data_ind]}) must equal dimension of out_shape[{i}] ({shape[i]})." + + assert ( + "int" in indices.dtype + ), f"Indices must be a tensor of integers, but its elements are {indices.dtype}." + + def scatter_nd(data, indices, shape): """Scatter elements from a n-dimension array. @@ -232,24 +254,7 @@ def scatter_nd(data, indices, shape): ------- ret : tvm.te.Tensor """ - assert indices.shape[0] <= len(shape), ( - f"The first dimension of the indices ({indices.shape[0]}) must be less than or equal to " - f"the length of the shape of the output ({len(shape)})." - ) - for i in range(len(indices.shape) - 1): - assert indices.shape[i + 1] == data.shape[i], ( - f"Dimension of indices[{i+1}] ({indices.shape[i+1]}) must equal dimension of " - f"data[{i}] ({data.shape[i]})." - ) - mdim = int(indices.shape[0]) - for i in range(mdim, len(shape)): - assert ( - data.shape[i - mdim] == shape[i] - ), f"Dimension of data[{i}] must equal dimension of out_shape[{i}]" - - assert ( - "int" in indices.dtype - ), f"Indices must be a tensor of integers, but its elements are {indices.dtype}" + _verify_scatter_nd_inputs(data, indices, shape) def gen_ir(data_ptr, indices_ptr, out_ptr): ib = ir_builder.create() diff --git a/tests/python/topi/python/test_topi_scatter.py b/tests/python/topi/python/test_topi_scatter.py index ae26e1bb0f6b..5845cc63965d 100644 --- a/tests/python/topi/python/test_topi_scatter.py +++ b/tests/python/topi/python/test_topi_scatter.py @@ -43,6 +43,15 @@ def check_scatter_nd(data, indices, shape, out): out = np.array([[[[0, 0], [1, 2]], [[0, 0], [3, 4]]], [[[0, 0], [0, 0]], [[0, 0], [0, 0]]]]) check_scatter_nd(data, indices, shape, out) + data = np.reshape(np.arange(1560*3), (3, 1560)).astype("float32") + indices = np.array([[1, 0, 0]]) + shape = (2, 1560) + out = np.zeros(shape).astype("float32") + out[1, :] += data[0, :] + out[0, :] += data[1, :] + out[0, :] += data[2, :] + check_scatter_nd(data, indices, shape, out) + if __name__ == "__main__": test_scatter_nd(tvm.context("cpu"), tvm.target.Target("llvm")) From 2431807ae32819a57001094cd6ceaed542d2572b Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Fri, 6 Nov 2020 14:43:50 -0800 Subject: [PATCH 09/13] x86 impl --- python/tvm/relay/op/strategy/x86.py | 12 ++ python/tvm/topi/x86/__init__.py | 1 + python/tvm/topi/x86/scatter.py | 108 ++++++++++++++++++ tests/python/topi/python/test_topi_scatter.py | 1 + 4 files changed, 122 insertions(+) create mode 100644 python/tvm/topi/x86/scatter.py diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index 3c5735b17aa5..3f48e5d35b9c 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -446,3 +446,15 @@ def bitserial_dense_strategy_cpu(attrs, inputs, out_type, target): name="bitserial_dense.x86", ) return strategy + +@scatter_nd_strategy.register("cpu") +def scatter_nd_strategy_cpu(attrs, inputs, out_type, target): + """scatter_nd x86 strategy""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_scatter_nd(topi.x86.scatter_nd), + wrap_topi_schedule(topi.generic.schedule_extern), + name="scatter_nd.x86", + plevel=10, + ) + return strategy diff --git a/python/tvm/topi/x86/__init__.py b/python/tvm/topi/x86/__init__.py index 659668cbbe4c..154511010a1c 100644 --- a/python/tvm/topi/x86/__init__.py +++ 
b/python/tvm/topi/x86/__init__.py @@ -39,3 +39,4 @@ from .conv3d_transpose import * from .sparse import * from .conv2d_alter_op import * +from .scatter import * diff --git a/python/tvm/topi/x86/scatter.py b/python/tvm/topi/x86/scatter.py new file mode 100644 index 000000000000..4c4e63a26567 --- /dev/null +++ b/python/tvm/topi/x86/scatter.py @@ -0,0 +1,108 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Scatter operators for x86""" +import tvm +from tvm import te +from ..scatter import _verify_scatter_nd_inputs + + +def scatter_nd(data, indices, shape): + """Scatter elements from a n-dimension array. + + Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), indices with shape + (M, Y_0, ..., Y_{K-1}), and output with shape (X_0, X_1, ..., X_{N-1}), scatter_nd computes + + .. code-block:: + + output[indices[0, y_0, ..., y_{K-1}], + ..., + indices[M-1, y_0, ..., y_{K-1}], + x_M, + ..., + x_{N-1} + ] = data[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}] + + all other entries in the output are 0. Repeated indices are summed. + + Parameters + ---------- + data : tvm.te.Tensor + The source array. + + indices : tvm.te.Tensor + The indices of the values to extract. + + shape : Sequence[int] + The output shape. This must be specified because it cannot be inferred. + + Returns + ------- + ret : tvm.te.Tensor + """ + _verify_scatter_nd_inputs(data, indices, shape) + + def gen_ir(data_ptr, indices_ptr, out_ptr): + ib = tvm.tir.ir_builder.create() + + data = ib.buffer_ptr(data_ptr) + indices = ib.buffer_ptr(indices_ptr) + out = ib.buffer_ptr(out_ptr) + + # We combine all the indices dimensions but the first one into a single + # dimension so we can iterate it in single loop instead of an arbitrary + # number of loops. We do the same thing for all the data dimensions. + fused_indices_dimension = 1 + for i in indices_ptr.shape[1:]: + fused_indices_dimension *= i + + fused_data_dimension = 1 + for i in data_ptr.shape[indices_ptr.shape[0].value :]: + fused_data_dimension *= i + + fused_shape = 1 + for i in shape: + fused_shape *= i + + # zero data + # TODO(tkonolige): could we use topi.full to zero it instead? + with ib.for_range(0, fused_shape) as i: + out[i] = tvm.tir.Cast(data_ptr.dtype, 0) + + with ib.for_range(0, fused_indices_dimension) as i: + with ib.for_range(0, fused_data_dimension, for_type="parallel") as j: + offset = fused_data_dimension + index = j # This is x_M, .. x_{N-1} part of the index into out. + # Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}] part + # of the index into out. + for l in reversed(range(indices_ptr.shape[0].value)): + # indices[i * l * fused_indices_dimension] = indices[l, y_0, ... 
y_{k-1}] + index += offset * indices[i + l * fused_indices_dimension] + offset *= shape[l] + out[index] += data[i * fused_data_dimension + j] + + return ib.get() + + out_buf = tvm.tir.decl_buffer(shape, data.dtype, "out_buf") + return te.extern( + [shape], + [data, indices], + lambda ins, outs: gen_ir(ins[0], ins[1], outs[0]), + dtype=data.dtype, + out_buffers=[out_buf], + name="scatter_nd_x86", + tag="scatter_nd_x86", + ) diff --git a/tests/python/topi/python/test_topi_scatter.py b/tests/python/topi/python/test_topi_scatter.py index 5845cc63965d..3f64644d1b46 100644 --- a/tests/python/topi/python/test_topi_scatter.py +++ b/tests/python/topi/python/test_topi_scatter.py @@ -27,6 +27,7 @@ def check_scatter_nd(data, indices, shape, out): implementations = { "generic": (lambda x, y: topi.scatter_nd(x, y, shape), topi.generic.schedule_extern), "cuda": (lambda x, y: topi.cuda.scatter_nd(x, y, shape), topi.generic.schedule_extern), + "llvm": (lambda x, y: topi.x86.scatter_nd(x, y, shape), topi.generic.schedule_extern), } fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations) tvm.topi.testing.compare_numpy_tvm([data, indices], out, target, ctx, fcompute, fschedule) From 0593227560dd744f8c1e3fdacea0df043dec3705 Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Fri, 6 Nov 2020 16:08:19 -0800 Subject: [PATCH 10/13] formatting --- python/tvm/relay/op/strategy/x86.py | 1 + python/tvm/topi/scatter.py | 7 ++++--- python/tvm/topi/x86/scatter.py | 1 + tests/python/topi/python/test_topi_scatter.py | 2 +- 4 files changed, 7 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index 3f48e5d35b9c..3f129c471faf 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -447,6 +447,7 @@ def bitserial_dense_strategy_cpu(attrs, inputs, out_type, target): ) return strategy + @scatter_nd_strategy.register("cpu") def scatter_nd_strategy_cpu(attrs, inputs, out_type, target): """scatter_nd x86 strategy""" diff --git a/python/tvm/topi/scatter.py b/python/tvm/topi/scatter.py index 4f708645bcff..848bdff1d470 100644 --- a/python/tvm/topi/scatter.py +++ b/python/tvm/topi/scatter.py @@ -212,9 +212,10 @@ def _verify_scatter_nd_inputs(data, indices, shape): ) for i in range(mdim, len(shape)): data_ind = i - mdim + len(indices.shape) - 1 - assert ( - data.shape[data_ind] == shape[i] - ), f"Dimension of data[{data_ind}] ({data.shape[data_ind]}) must equal dimension of out_shape[{i}] ({shape[i]})." + assert data.shape[data_ind] == shape[i], ( + f"Dimension of data[{data_ind}] ({data.shape[data_ind]}) must equal dimension " + f"of out_shape[{i}] ({shape[i]})." 
+ ) assert ( "int" in indices.dtype diff --git a/python/tvm/topi/x86/scatter.py b/python/tvm/topi/x86/scatter.py index 4c4e63a26567..dcd875f68a4a 100644 --- a/python/tvm/topi/x86/scatter.py +++ b/python/tvm/topi/x86/scatter.py @@ -56,6 +56,7 @@ def scatter_nd(data, indices, shape): _verify_scatter_nd_inputs(data, indices, shape) def gen_ir(data_ptr, indices_ptr, out_ptr): + # pylint: disable=invalid-name ib = tvm.tir.ir_builder.create() data = ib.buffer_ptr(data_ptr) diff --git a/tests/python/topi/python/test_topi_scatter.py b/tests/python/topi/python/test_topi_scatter.py index 3f64644d1b46..cf6b3fc23b36 100644 --- a/tests/python/topi/python/test_topi_scatter.py +++ b/tests/python/topi/python/test_topi_scatter.py @@ -44,7 +44,7 @@ def check_scatter_nd(data, indices, shape, out): out = np.array([[[[0, 0], [1, 2]], [[0, 0], [3, 4]]], [[[0, 0], [0, 0]], [[0, 0], [0, 0]]]]) check_scatter_nd(data, indices, shape, out) - data = np.reshape(np.arange(1560*3), (3, 1560)).astype("float32") + data = np.reshape(np.arange(1560 * 3), (3, 1560)).astype("float32") indices = np.array([[1, 0, 0]]) shape = (2, 1560) out = np.zeros(shape).astype("float32") From 8cf300f0349b0e3c4360f1b7876cf4b8ce2669fc Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Wed, 11 Nov 2020 10:30:02 -0800 Subject: [PATCH 11/13] fix shape rel --- src/relay/op/tensor/transform.cc | 2 +- tests/python/topi/python/test_topi_scatter.py | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index a7a8024999d9..09f0a35365d5 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -1016,7 +1016,7 @@ bool ScatterNDRel(const Array& types, int num_inputs, const Attrs& attrs, // data: (Y_0, .. Y_{K-1}, X_M, .. X_{N-1}) out: (X_0, .. 
X_{N-1}), verify X_M to X_{N-1} for (size_t i = mdim->value; i < ndim; i++) { - reporter->AssertEQ(data->shape[i], oshape[i]); + reporter->AssertEQ(data->shape[i - mdim + kdim], oshape[i]); } reporter->Assign(types[2], TensorType(oshape, data->dtype)); diff --git a/tests/python/topi/python/test_topi_scatter.py b/tests/python/topi/python/test_topi_scatter.py index cf6b3fc23b36..bb1b039c19b2 100644 --- a/tests/python/topi/python/test_topi_scatter.py +++ b/tests/python/topi/python/test_topi_scatter.py @@ -53,6 +53,14 @@ def check_scatter_nd(data, indices, shape, out): out[0, :] += data[2, :] check_scatter_nd(data, indices, shape, out) + data = np.random.rand((40, 768)) + indices = np.stack((np.random.randint(40, size=40), np.random.randint(768, size=40))) + shape = (8, 50, 768) + out = np.zeros(shape).astype("float32") + for i in range(40): + out[indices[0, i], indices[1, i], :] += data[i, :] + check_scatter_nd(data, indices, shape, out) + if __name__ == "__main__": test_scatter_nd(tvm.context("cpu"), tvm.target.Target("llvm")) From 9ec74c63cd08be4c0c215e351a82280ad68e082c Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Wed, 11 Nov 2020 11:58:17 -0800 Subject: [PATCH 12/13] fix tests --- python/tvm/topi/cuda/scatter.py | 2 +- python/tvm/topi/scatter.py | 6 +++--- python/tvm/topi/testing/common.py | 2 +- python/tvm/topi/x86/scatter.py | 2 +- src/relay/op/tensor/transform.cc | 2 +- tests/python/topi/python/test_topi_scatter.py | 19 ++++++++++--------- 6 files changed, 17 insertions(+), 16 deletions(-) diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py index ac6bf830da3e..4bbbc1a7f919 100644 --- a/python/tvm/topi/cuda/scatter.py +++ b/python/tvm/topi/cuda/scatter.py @@ -575,7 +575,7 @@ def gen_ir(data_ptr, indices_ptr, out_ptr): fused_indices_dimension *= i fused_data_dimension = 1 - for i in data_ptr.shape[indices_ptr.shape[0].value :]: + for i in data_ptr.shape[len(indices_ptr.shape)-1 :]: fused_data_dimension *= i fused_shape = 1 diff --git a/python/tvm/topi/scatter.py b/python/tvm/topi/scatter.py index 848bdff1d470..dcb4477fea66 100644 --- a/python/tvm/topi/scatter.py +++ b/python/tvm/topi/scatter.py @@ -280,11 +280,11 @@ def gen_ir(data_ptr, indices_ptr, out_ptr): fused_indices_dimension *= i fused_data_dimension = 1 - for i in data_ptr.shape[indices_ptr.shape[0].value :]: + for i in data_ptr.shape[len(indices_ptr.shape)-1 :]: fused_data_dimension *= i - with ib.for_range(0, fused_indices_dimension) as i: - with ib.for_range(0, fused_data_dimension) as j: + with ib.for_range(0, fused_indices_dimension, name='i') as i: + with ib.for_range(0, fused_data_dimension, name='j') as j: offset = fused_data_dimension index = j # This is x_M, .. x_{N-1} part of the index into out. # Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. 
y_{K-1}] part diff --git a/python/tvm/topi/testing/common.py b/python/tvm/topi/testing/common.py index e97cd3c34428..e4e5e811ab18 100644 --- a/python/tvm/topi/testing/common.py +++ b/python/tvm/topi/testing/common.py @@ -107,4 +107,4 @@ def compare_numpy_tvm(inputs, output, target, ctx, compute, schedule): func = tvm.build(s, te_inputs + [out]) arys = [tvm.nd.array(x, ctx=ctx) for x in inputs] func(*(arys + [te_out])) - assert_allclose(output, te_out.asnumpy(), atol=1e-4, rtol=1e-4) + assert_allclose(te_out.asnumpy(), output, atol=1e-4, rtol=1e-4) diff --git a/python/tvm/topi/x86/scatter.py b/python/tvm/topi/x86/scatter.py index dcd875f68a4a..4d2bef5369d7 100644 --- a/python/tvm/topi/x86/scatter.py +++ b/python/tvm/topi/x86/scatter.py @@ -71,7 +71,7 @@ def gen_ir(data_ptr, indices_ptr, out_ptr): fused_indices_dimension *= i fused_data_dimension = 1 - for i in data_ptr.shape[indices_ptr.shape[0].value :]: + for i in data_ptr.shape[len(indices_ptr.shape)-1 :]: fused_data_dimension *= i fused_shape = 1 diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 09f0a35365d5..4b374d881947 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -1016,7 +1016,7 @@ bool ScatterNDRel(const Array& types, int num_inputs, const Attrs& attrs, // data: (Y_0, .. Y_{K-1}, X_M, .. X_{N-1}) out: (X_0, .. X_{N-1}), verify X_M to X_{N-1} for (size_t i = mdim->value; i < ndim; i++) { - reporter->AssertEQ(data->shape[i - mdim + kdim], oshape[i]); + reporter->AssertEQ(data->shape[i - mdim->value + kdim], oshape[i]); } reporter->Assign(types[2], TensorType(oshape, data->dtype)); diff --git a/tests/python/topi/python/test_topi_scatter.py b/tests/python/topi/python/test_topi_scatter.py index bb1b039c19b2..2e701e2903d9 100644 --- a/tests/python/topi/python/test_topi_scatter.py +++ b/tests/python/topi/python/test_topi_scatter.py @@ -26,8 +26,8 @@ def test_scatter_nd(ctx, target): def check_scatter_nd(data, indices, shape, out): implementations = { "generic": (lambda x, y: topi.scatter_nd(x, y, shape), topi.generic.schedule_extern), - "cuda": (lambda x, y: topi.cuda.scatter_nd(x, y, shape), topi.generic.schedule_extern), - "llvm": (lambda x, y: topi.x86.scatter_nd(x, y, shape), topi.generic.schedule_extern), + "gpu": (lambda x, y: topi.cuda.scatter_nd(x, y, shape), topi.generic.schedule_extern), + "cpu": (lambda x, y: topi.x86.scatter_nd(x, y, shape), topi.generic.schedule_extern), } fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations) tvm.topi.testing.compare_numpy_tvm([data, indices], out, target, ctx, fcompute, fschedule) @@ -41,7 +41,7 @@ def check_scatter_nd(data, indices, shape, out): data = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) indices = np.array([[0, 1], [1, 1]]) shape = (2, 2, 2, 2) - out = np.array([[[[0, 0], [1, 2]], [[0, 0], [3, 4]]], [[[0, 0], [0, 0]], [[0, 0], [0, 0]]]]) + out = np.array([[[[0, 0], [0, 0]], [[1, 2], [3, 4]]], [[[0, 0], [0, 0]], [[5, 6], [7, 8]]]]) check_scatter_nd(data, indices, shape, out) data = np.reshape(np.arange(1560 * 3), (3, 1560)).astype("float32") @@ -53,12 +53,13 @@ def check_scatter_nd(data, indices, shape, out): out[0, :] += data[2, :] check_scatter_nd(data, indices, shape, out) - data = np.random.rand((40, 768)) - indices = np.stack((np.random.randint(40, size=40), np.random.randint(768, size=40))) - shape = (8, 50, 768) - out = np.zeros(shape).astype("float32") - for i in range(40): - out[indices[0, i], indices[1, i], :] += data[i, :] + data = np.ones((5, 3)).astype("float64") + 
indices = np.stack((np.random.randint(2, size=5), np.random.randint(7, size=5))).astype("int64") + shape = (2, 7, 3) + out = np.zeros(shape).astype("float64") + for i in range(indices.shape[1]): + for j in range(data.shape[1]): + out[indices[0, i], indices[1, i], j] += data[i, j] check_scatter_nd(data, indices, shape, out) From b212a70b49044055fa73e58df51868ea05ba1d34 Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Wed, 11 Nov 2020 12:02:30 -0800 Subject: [PATCH 13/13] formatting --- python/tvm/topi/cuda/scatter.py | 2 +- python/tvm/topi/scatter.py | 6 +++--- python/tvm/topi/x86/scatter.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py index 4bbbc1a7f919..5e03fafcfb58 100644 --- a/python/tvm/topi/cuda/scatter.py +++ b/python/tvm/topi/cuda/scatter.py @@ -575,7 +575,7 @@ def gen_ir(data_ptr, indices_ptr, out_ptr): fused_indices_dimension *= i fused_data_dimension = 1 - for i in data_ptr.shape[len(indices_ptr.shape)-1 :]: + for i in data_ptr.shape[len(indices_ptr.shape) - 1 :]: fused_data_dimension *= i fused_shape = 1 diff --git a/python/tvm/topi/scatter.py b/python/tvm/topi/scatter.py index dcb4477fea66..a376963aa55a 100644 --- a/python/tvm/topi/scatter.py +++ b/python/tvm/topi/scatter.py @@ -280,11 +280,11 @@ def gen_ir(data_ptr, indices_ptr, out_ptr): fused_indices_dimension *= i fused_data_dimension = 1 - for i in data_ptr.shape[len(indices_ptr.shape)-1 :]: + for i in data_ptr.shape[len(indices_ptr.shape) - 1 :]: fused_data_dimension *= i - with ib.for_range(0, fused_indices_dimension, name='i') as i: - with ib.for_range(0, fused_data_dimension, name='j') as j: + with ib.for_range(0, fused_indices_dimension, name="i") as i: + with ib.for_range(0, fused_data_dimension, name="j") as j: offset = fused_data_dimension index = j # This is x_M, .. x_{N-1} part of the index into out. # Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}] part diff --git a/python/tvm/topi/x86/scatter.py b/python/tvm/topi/x86/scatter.py index 4d2bef5369d7..8147d3a00135 100644 --- a/python/tvm/topi/x86/scatter.py +++ b/python/tvm/topi/x86/scatter.py @@ -71,7 +71,7 @@ def gen_ir(data_ptr, indices_ptr, out_ptr): fused_indices_dimension *= i fused_data_dimension = 1 - for i in data_ptr.shape[len(indices_ptr.shape)-1 :]: + for i in data_ptr.shape[len(indices_ptr.shape) - 1 :]: fused_data_dimension *= i fused_shape = 1
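
Editor's note (illustrative only, not part of the patch series): the scatter_nd semantics exercised by these tests can be summarized in a short NumPy sketch. The output is zero-initialized and repeated indices are summed, which np.add.at provides directly. The helper name scatter_nd_ref is hypothetical; the tests themselves build the expected output with explicit += loops.

import numpy as np

def scatter_nd_ref(data, indices, shape):
    """NumPy reference: data has shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}),
    indices has shape (M, Y_0, ..., Y_{K-1}), output has shape (X_0, ..., X_{N-1})."""
    out = np.zeros(shape, dtype=data.dtype)
    # One index array per scattered axis; np.add.at accumulates duplicates
    # instead of overwriting them, matching "repeated indices are summed".
    np.add.at(out, tuple(indices), data)
    return out

# Mirrors the (2, 2, 2, 2) case from test_topi_scatter.py after the "fix tests"
# patch: out[0, 1] receives data[0] and out[1, 1] receives data[1].
data = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
indices = np.array([[0, 1], [1, 1]])
print(scatter_nd_ref(data, indices, (2, 2, 2, 2)))

The TE kernels above reach the same result by fusing all but the first indices dimension into one flattened loop and the trailing data dimensions into another, which is why gen_ir computes fused_indices_dimension, fused_data_dimension, and fused_shape before iterating.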