[RELAY,TOPI] Add scatter_nd op #6854

Merged · 13 commits · Dec 1, 2020
8 changes: 8 additions & 0 deletions include/tvm/relay/attrs/transform.h
@@ -129,6 +129,14 @@ struct ScatterAddAttrs : public tvm::AttrsNode<ScatterAddAttrs> {
}
};

struct ScatterNDAttrs : public tvm::AttrsNode<ScatterNDAttrs> {
Array<Integer> out_shape;

TVM_DECLARE_ATTRS(ScatterNDAttrs, "relay.attrs.ScatterNDAttrs") {
TVM_ATTR_FIELD(out_shape).describe("Output shape of the scatter.");
}
};

struct GatherAttrs : public tvm::AttrsNode<GatherAttrs> {
Integer axis;

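The new ScatterNDAttrs carries the statically known output shape. A minimal sketch (not part of the diff; shapes and values are illustrative assumptions) of how the attribute surfaces once the Relay op below is registered:

from tvm import relay

data = relay.var("data", shape=(3,), dtype="float32")
indices = relay.var("indices", shape=(2, 3), dtype="int64")
call = relay.scatter_nd(data, indices, (2, 2))  # out_shape is stored in ScatterNDAttrs
print(call.attrs.out_shape)  # [2, 2]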
5 changes: 4 additions & 1 deletion python/tvm/relay/backend/compile_engine.py
@@ -121,7 +121,10 @@ def get_valid_implementations(op, attrs, inputs, out_type, target):
The list of all valid op implementations.
"""
fstrategy = op.get_attr("FTVMStrategy")
assert fstrategy is not None, "%s doesn't have FTVMStrategy registered" % op.name
assert fstrategy is not None, (
"%s doesn't have an FTVMStrategy registered. You can register "
"one in python with `tvm.relay.op.register_strategy`." % op.name
)
with target:
strategy = fstrategy(attrs, inputs, out_type, target)
analyzer = tvm.arith.Analyzer()
7 changes: 7 additions & 0 deletions python/tvm/relay/op/_tensor_grad.py
@@ -62,6 +62,7 @@
squeeze,
strided_set,
arange,
scatter_nd,
)


@@ -803,3 +804,9 @@ def arange_grad(orig, grad):
grad_step = cast_like(_sum(grad_step), step)

return [grad_start, grad_stop, grad_step]


@register_gradient("gather_nd")
def gather_nd_grad(orig, grad):
data, indices = orig.args
return [scatter_nd(grad, indices, data.checked_type.concrete_shape), zeros_like(indices)]
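The gradient of gather_nd scatters the incoming gradient back to the positions it was gathered from, which is what the new scatter_nd op provides; the indices input receives a zero gradient. A hedged sketch of exercising it with check_grad (shapes, dtypes, and values are illustrative; float64 keeps the numeric comparison stable, and test_inputs restricts checking to the differentiable data argument):

import numpy as np
from tvm import relay
from tvm.relay.testing import check_grad

data = relay.var("data", shape=(2, 3), dtype="float64")
indices = relay.var("indices", shape=(2, 4), dtype="int64")
fwd = relay.Function([data, indices], relay.gather_nd(data, indices))

data_np = np.random.rand(2, 3)
indices_np = np.array([[0, 1, 1, 0], [0, 1, 0, 0]], dtype="int64")
check_grad(fwd, inputs=[data_np, indices_np], test_inputs=[data_np])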
9 changes: 9 additions & 0 deletions python/tvm/relay/op/_transform.py
@@ -115,6 +115,15 @@ def compute_scatter_add(attrs, inputs, output_type):

_reg.register_strategy("scatter_add", strategy.scatter_add_strategy)

# scatter_nd
@_reg.register_compute("scatter_nd")
def compute_scatter_nd(attrs, inputs, output_type):
"""Compute definition of scatter_nd"""
return [topi.scatter_nd(inputs[0], inputs[1], attrs.out_shape)]


_reg.register_strategy("scatter_nd", strategy.scatter_nd_strategy)

#####################
# Shape functions #
#####################
13 changes: 13 additions & 0 deletions python/tvm/relay/op/strategy/cuda.py
@@ -712,6 +712,19 @@ def scatter_add_cuda(attrs, inputs, out_type, target):
return strategy


@scatter_nd_strategy.register(["cuda", "gpu"])
def scatter_nd_cuda(attrs, inputs, out_type, target):
"""scatter_nd cuda strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_scatter_nd(topi.cuda.scatter_nd),
wrap_topi_schedule(topi.generic.schedule_extern),
name="scatter_nd.cuda",
plevel=10,
)
return strategy


@argsort_strategy.register(["cuda", "gpu"])
def argsort_strategy_cuda(attrs, inputs, out_type, target):
"""argsort cuda strategy"""
22 changes: 22 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
@@ -1063,6 +1063,28 @@ def scatter_add_strategy(attrs, outs, out_type, target):
return strategy


# scatter_nd
@override_native_generic_func("scatter_nd_strategy")
def scatter_nd_strategy(attrs, inputs, out_type, target):
"""scatter_nd generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_scatter_nd(topi.scatter_nd),
wrap_topi_schedule(topi.generic.schedule_extern),
name="scatter_nd.generic",
)
return strategy


def wrap_compute_scatter_nd(topi_compute):
"""Wrap scatter_nd topi compute"""

def _compute_scatter_nd(attrs, inputs, _):
return [topi_compute(inputs[0], inputs[1], attrs.out_shape)]

return _compute_scatter_nd


# bitserial_conv2d
def wrap_compute_bitserial_conv2d(topi_compute):
"""wrap bitserial_conv2d topi compute"""
13 changes: 13 additions & 0 deletions python/tvm/relay/op/strategy/x86.py
@@ -446,3 +446,16 @@ def bitserial_dense_strategy_cpu(attrs, inputs, out_type, target):
name="bitserial_dense.x86",
)
return strategy


@scatter_nd_strategy.register("cpu")
def scatter_nd_strategy_cpu(attrs, inputs, out_type, target):
"""scatter_nd x86 strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_scatter_nd(topi.x86.scatter_nd),
wrap_topi_schedule(topi.generic.schedule_extern),
name="scatter_nd.x86",
plevel=10,
)
return strategy
24 changes: 24 additions & 0 deletions python/tvm/relay/op/transform.py
@@ -308,6 +308,30 @@ def scatter_add(data, indices, updates, axis):
return _make.scatter_add(data, indices, updates, axis)


def scatter_nd(data, indices, out_shape):
"""Scatter values from an array.

    See :py:func:`tvm.topi.scatter_nd` for how data is scattered.

Parameters
----------
data : relay.Expr
The input data to the operator.

indices : relay.Expr
The index locations to update.

    out_shape : Sequence[int]
Output shape of the scatter.

Returns
-------
ret : relay.Expr
The computed result.
"""
return _make.scatter_nd(data, indices, out_shape)


def reshape_like(data, shape_like, lhs_begin=0, lhs_end=None, rhs_begin=0, rhs_end=None):
"""Reshapes the input tensor by the size of another tensor.
For an input tensor with shape ``(d0, d1, ..., d(k-1))``, `reshape_like` operation reshapes
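A usage sketch (illustrative, not part of the diff; executor API details can vary slightly between TVM versions) that builds and runs the new Relay op:

import numpy as np
from tvm import relay

d = relay.var("data", shape=(3,), dtype="float32")
i = relay.var("indices", shape=(2, 3), dtype="int64")
func = relay.Function([d, i], relay.scatter_nd(d, i, (2, 2)))

data_np = np.array([2, 3, 0], dtype="float32")
indices_np = np.array([[1, 1, 0], [0, 1, 0]], dtype="int64")

res = relay.create_executor("debug", target="llvm").evaluate(func)(data_np, indices_np)
# By the definition in topi.scatter_nd, out[indices[0, y], indices[1, y]] += data[y],
# so res should be [[0, 0], [2, 3]].
print(res)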
2 changes: 2 additions & 0 deletions python/tvm/relay/testing/__init__.py
@@ -143,6 +143,8 @@ def check_grad(
break
grads = tmp

assert len(grads) > 0, "You must test at least one gradient."

# Get numeric gradients for each dimension of each param, using two-sided approximation.
approx_grads = []
for x in test_inputs:
6 changes: 5 additions & 1 deletion python/tvm/te/operation.py
@@ -317,7 +317,11 @@ def extern(
if isinstance(body, tvm.tir.PrimExpr):
body = tvm.tir.Evaluate(body)
if not isinstance(body, tvm.tir.Stmt):
raise ValueError("Function '{}' should return PrimExpr or Stmt".format(fcompute.__name__))
raise ValueError(
"Function '{}' should return PrimExpr or Stmt, but it returned '{}'".format(
fcompute.__name__, type(body)
)
)

op = _ffi_api.ExternOp(name, tag, attrs, inputs, input_placeholders, output_placeholders, body)
res = [op.output(i) for i in range(len(output_placeholders))]
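For context, a tiny hedged sketch of an extern compute that satisfies this check by returning a Stmt built with the IR builder (names and shapes are made up for illustration):

import tvm
from tvm import te

A = te.placeholder((16,), name="A", dtype="float32")

def copy_ir(ins, outs):
    ib = tvm.tir.ir_builder.create()
    a = ib.buffer_ptr(ins[0])
    c = ib.buffer_ptr(outs[0])
    with ib.for_range(0, 16) as i:
        c[i] = a[i]
    return ib.get()  # a tvm.tir.Stmt; returning anything else triggers the ValueError above

C = te.extern([(16,)], [A], lambda ins, outs: copy_ir(ins, outs), dtype="float32", name="copy")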
106 changes: 106 additions & 0 deletions python/tvm/topi/cuda/scatter.py
@@ -18,6 +18,7 @@
"""Scatter operator """
import tvm
from tvm import te
from ..scatter import _verify_scatter_nd_inputs


def ceil_div(a, b):
@@ -522,3 +523,108 @@ def update_func(dst_ptr, dst_index, update):
)

return out


def scatter_nd(data, indices, shape):
"""Scatter elements from a n-dimension array.

Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), indices with shape
(M, Y_0, ..., Y_{K-1}), and output with shape (X_0, X_1, ..., X_{N-1}), scatter_nd computes

.. code-block::

output[indices[0, y_0, ..., y_{K-1}],
...,
indices[M-1, y_0, ..., y_{K-1}],
x_M,
...,
x_{N-1}
] = data[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}]

all other entries in the output are 0. Repeated indices are summed.

Parameters
----------
data : tvm.te.Tensor
The source array.

indices : tvm.te.Tensor
        The index locations to update.

shape : Sequence[int]
The output shape. This must be specified because it cannot be inferred.

Returns
-------
ret : tvm.te.Tensor
"""
_verify_scatter_nd_inputs(data, indices, shape)

def gen_ir(data_ptr, indices_ptr, out_ptr):
ib = tvm.tir.ir_builder.create()

data = ib.buffer_ptr(data_ptr)
indices = ib.buffer_ptr(indices_ptr)
out = ib.buffer_ptr(out_ptr)

# We combine all the indices dimensions but the first one into a single
        # dimension so we can iterate it in a single loop instead of an arbitrary
# number of loops. We do the same thing for all the data dimensions.
fused_indices_dimension = 1
for i in indices_ptr.shape[1:]:
fused_indices_dimension *= i

fused_data_dimension = 1
for i in data_ptr.shape[len(indices_ptr.shape) - 1 :]:
fused_data_dimension *= i

fused_shape = 1
for i in shape:
fused_shape *= i

        # For now we avoid parallelizing over dimensions indexed by `indices` as
        # there may be repeated indices and handling parallel accumulation can
# be hard. So we parallelize over X_M .. X_{N-1} instead. This will
# work well when these dimensions are large enough to saturate memory
# bandwidth, but performance will be bad when these dimensions are
# small.
bx = te.thread_axis("blockIdx.x")
tx = te.thread_axis("threadIdx.x")
max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
tdim = min(max_threads, fused_data_dimension)
ib.scope_attr(tx, "thread_extent", tdim)
bdim = ceil_div(fused_data_dimension, tdim)
ib.scope_attr(bx, "thread_extent", bdim)

# zero data
# TODO(tkonolige): could we use topi.full to zero it instead?
with ib.for_range(0, ceil_div(fused_shape, bdim)) as i:
index = i * fused_data_dimension + bx * tdim + tx
with ib.if_scope(index < fused_shape):
out[index] = tvm.tir.Cast(data_ptr.dtype, 0)

with ib.for_range(0, fused_indices_dimension) as i:
j = bx * tdim + tx
with ib.if_scope(j < fused_data_dimension):
offset = fused_data_dimension
index = j # This is x_M, .. x_{N-1} part of the index into out.
# Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}] part
# of the index into out.
for l in reversed(range(indices_ptr.shape[0].value)):
                    # indices[i + l * fused_indices_dimension] = indices[l, y_0, ..., y_{K-1}]
index += offset * indices[i + l * fused_indices_dimension]
offset *= shape[l]
out[index] += data[i * fused_data_dimension + j]

return ib.get()

out_buf = tvm.tir.decl_buffer(shape, data.dtype, "out_buf")
return te.extern(
[shape],
[data, indices],
lambda ins, outs: gen_ir(ins[0], ins[1], outs[0]),
dtype=data.dtype,
out_buffers=[out_buf],
name="scatter_nd_cuda",
tag="scatter_nd_cuda",
)
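For reference, a NumPy sketch (not part of the diff) of the semantics the IR above implements, following the docstring's shape convention; it is a useful mental model for the fused-dimension indexing:

import numpy as np

def scatter_nd_ref(data, indices, shape):
    # data: (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}); indices: (M, Y_0, ..., Y_{K-1})
    out = np.zeros(shape, dtype=data.dtype)
    for y in np.ndindex(*indices.shape[1:]):
        # indices[:, y] gives the first M output coordinates; the trailing data
        # axes are copied through, and repeated indices accumulate.
        out[tuple(indices[(slice(None),) + y])] += data[y]
    return out

print(scatter_nd_ref(np.array([2.0, 3.0, 0.0]),
                     np.array([[1, 1, 0], [0, 1, 0]]),
                     (2, 2)))
# -> [[0. 0.]
#     [2. 3.]]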