Skip to content

Commit

Permalink
[TOPI] Minor perf improvement for GPU scatter (apache#7233)
Browse files Browse the repository at this point in the history
* improve scatter 4d init

* do not launch sorting based scatter for small input

* do not use hard coded num threads

* separate sort based implementation

* register scatter as autotvm task

* add missing import

* fix strategy

* add dedicated schedule and dummy flop

* add test tuning script

* try adding dummy knob

* skip random_fill when a tuning workload is from scatter

This reverts commit 1fed883.

* cleanup memcpy ir

* remove scatter tuning script

* make sure arguments are zero-initialized

* add comment on why random init is skipped for scatter

* restore ctx sync

Co-authored-by: masa <masa@pop-os.localdomain>
  • Loading branch information
2 people authored and Tushar Dey committed Jan 20, 2021
1 parent 2eba240 commit e457f3b
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 82 deletions.
9 changes: 6 additions & 3 deletions python/tvm/autotvm/measure/measure_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from random import getrandbits
from collections import namedtuple
import tempfile
import numpy as np

import tvm._ffi
import tvm.ir.transform
Expand Down Expand Up @@ -560,9 +561,11 @@ def run_through_rpc(
raise AttributeError(
"Please make sure USE_RANDOM is ON in the config.cmake " "on the remote devices"
)
args = [nd.empty(x[0], dtype=x[1], ctx=ctx) for x in build_result.arg_info]
for arg in args:
random_fill(arg)
args = [nd.array(np.zeros(x[0], dtype=x[1]), ctx=ctx) for x in build_result.arg_info]
if "scatter" not in measure_input.task.name:
# the index tensor of scatter op cannot be randomly initialized
for arg in args:
random_fill(arg)
ctx.sync()

costs = time_f(*args).results
Expand Down
15 changes: 14 additions & 1 deletion python/tvm/relay/op/strategy/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,10 +783,23 @@ def scatter_cuda(attrs, inputs, out_type, target):
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_scatter(topi.cuda.scatter),
wrap_topi_schedule(topi.generic.schedule_extern),
wrap_topi_schedule(topi.cuda.schedule_scatter),
name="scatter.cuda",
plevel=10,
)

rank = len(inputs[0].shape)

with SpecializedCondition(rank == 1):
if target.kind.name == "cuda" and get_global_func(
"tvm.contrib.thrust.stable_sort_by_key", allow_missing=True
):
strategy.add_implementation(
wrap_compute_scatter(topi.cuda.scatter_via_sort),
wrap_topi_schedule(topi.cuda.schedule_scatter_via_sort),
name="scatter_via_sort.cuda",
plevel=9, # use the sequential version by default
)
return strategy


Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1123,7 +1123,7 @@ def wrap_compute_scatter(topi_compute):
"""Wrap scatter topi compute"""

def _compute_scatter(attrs, inputs, _):
return [topi_compute(inputs[0], inputs[1], inputs[2], axis=attrs.axis)]
return [topi_compute(inputs[0], inputs[1], inputs[2], attrs.axis)]

return _compute_scatter

Expand Down
179 changes: 102 additions & 77 deletions python/tvm/topi/cuda/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,33 @@
# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument
"""Scatter operator """
import tvm
from tvm import te
from tvm import te, autotvm
from ..scatter import _verify_scatter_nd_inputs
from ..generic import schedule_extern
from .nms import atomic_add
from .sort import stable_sort_by_key_thrust, is_thrust_available
from ..utils import prod


def ceil_div(a, b):
    """Return ``a`` divided by ``b``, rounded up, using floor division only.

    Computed as ``(a + b - 1) // b``, so it relies solely on the ``+``, ``-``
    and ``//`` operators of its operands.
    """
    padded = a + b - 1
    return padded // b


def _memcpy_ir(ib, out_ptr, data_ptr, shape):
    """Emit IR that copies ``data_ptr`` into ``out_ptr`` one element per thread.

    The copy runs over a flat index covering ``prod(shape)`` elements. It is
    launched on ``blockIdx.x`` / ``threadIdx.x`` with the block size taken from
    the current target's ``max_num_threads``.

    Parameters
    ----------
    ib : ir builder
        The builder the copy is emitted into (inside its own ``new_scope``).

    out_ptr : buffer pointer
        Destination of the copy.

    data_ptr : buffer pointer
        Source of the copy.

    shape : sequence
        Shape whose product gives the number of elements to copy.
    """
    total = prod(shape)
    with ib.new_scope():
        # allow_none=False is passed, so presumably an active target is required here.
        target = tvm.target.Target.current(allow_none=False)
        threads_per_block = int(target.max_num_threads)

        block_axis = te.thread_axis("blockIdx.x")
        # Round the block count up so every element gets a thread.
        ib.scope_attr(block_axis, "thread_extent", ceil_div(total, threads_per_block))
        lane_axis = te.thread_axis("threadIdx.x")
        ib.scope_attr(lane_axis, "thread_extent", threads_per_block)

        flat_idx = block_axis * threads_per_block + lane_axis

        # The rounded-up launch can cover more indices than elements; guard the copy.
        with ib.if_scope(flat_idx < total):
            out_ptr[flat_idx] = data_ptr[flat_idx]


def gen_ir_1d(data, indices, updates, axis, out, update_func):
"""Generate scatter ir for 1d inputs
Expand Down Expand Up @@ -63,10 +80,7 @@ def gen_ir_1d(data, indices, updates, axis, out, update_func):
out_ptr = ib.buffer_ptr(out)
data_ptr = ib.buffer_ptr(data)

with ib.new_scope():
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(bx, "thread_extent", n)
out_ptr[bx] = data_ptr[bx]
_memcpy_ir(ib, out_ptr, data_ptr, data.shape)

indices_ptr = ib.buffer_ptr(indices)
updates_ptr = ib.buffer_ptr(updates)
Expand Down Expand Up @@ -114,8 +128,6 @@ def gen_ir_2d(data, indices, updates, axis, out, update_func):
ret : tir
The computational ir.
"""
warp_size = tvm.target.Target.current(False).thread_warp_size

n = data.shape[0]
c = data.shape[1]

Expand All @@ -124,16 +136,7 @@ def gen_ir_2d(data, indices, updates, axis, out, update_func):
out_ptr = ib.buffer_ptr(out)
data_ptr = ib.buffer_ptr(data)

with ib.new_scope():
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(bx, "thread_extent", n)
tx = te.thread_axis("threadIdx.x")
ib.scope_attr(tx, "thread_extent", warp_size)
with ib.for_range(0, ceil_div(c, warp_size), name="j") as j_:
j = j_ * warp_size + tx
with ib.if_scope(j < c):
idx = bx * c + j
out_ptr[idx] = data_ptr[idx]
_memcpy_ir(ib, out_ptr, data_ptr, data.shape)

indices_ptr = ib.buffer_ptr(indices)
updates_ptr = ib.buffer_ptr(updates)
Expand Down Expand Up @@ -205,18 +208,7 @@ def gen_ir_3d(data, indices, updates, axis, out, update_func):
out_ptr = ib.buffer_ptr(out)
data_ptr = ib.buffer_ptr(data)

with ib.new_scope():
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(bx, "thread_extent", n)
by = te.thread_axis("blockIdx.y")
ib.scope_attr(by, "thread_extent", c)
tx = te.thread_axis("threadIdx.x")
ib.scope_attr(tx, "thread_extent", warp_size)
with ib.for_range(0, ceil_div(h, warp_size), name="k") as k_:
k = k_ * warp_size + tx
with ib.if_scope(k < h):
idx = (bx * c + by) * h + k
out_ptr[idx] = data_ptr[idx]
_memcpy_ir(ib, out_ptr, data_ptr, data.shape)

indices_ptr = ib.buffer_ptr(indices)
updates_ptr = ib.buffer_ptr(updates)
Expand Down Expand Up @@ -311,20 +303,7 @@ def gen_ir_4d(data, indices, updates, axis, out, update_func):

out_ptr = ib.buffer_ptr(out)
data_ptr = ib.buffer_ptr(data)
with ib.new_scope():
i = te.thread_axis("blockIdx.x")
ib.scope_attr(i, "thread_extent", n)
j = te.thread_axis("blockIdx.y")
ib.scope_attr(j, "thread_extent", c)
k = te.thread_axis("blockIdx.z")
ib.scope_attr(k, "thread_extent", h)
tx = te.thread_axis("threadIdx.x")
ib.scope_attr(tx, "thread_extent", warp_size)
with ib.for_range(0, ceil_div(w, warp_size), name="l") as l_:
l = l_ * warp_size + tx
with ib.if_scope(l < w):
idx = ((i * c + j) * h + k) * w + l
out_ptr[idx] = data_ptr[idx]
_memcpy_ir(ib, out_ptr, data_ptr, data.shape)

indices_ptr = ib.buffer_ptr(indices)
updates_ptr = ib.buffer_ptr(updates)
Expand Down Expand Up @@ -417,7 +396,71 @@ def gen_ir_4d(data, indices, updates, axis, out, update_func):
return ib.get()


def gen_scatter_1d_thrust(data, indices_sorted, updates_sorted, axis, out, _):
@autotvm.register_topi_compute("scatter.cuda")
def scatter(cfg, data, indices, updates, axis=0):
    """Scatter ``updates`` into a copy of ``data`` at the positions in ``indices``.

    Parameters
    ----------
    cfg : autotvm config
        Only used to register a placeholder FLOP count.

    data : tensor
        The input data to the operator. Must have between 1 and 4 dimensions.

    indices : tensor
        The index locations to update.

    updates : tensor
        The values to update.

    axis : int
        The axis to scatter on. A negative value is offset by the rank of ``data``.

    Returns
    -------
    ret : tensor
        The extern op output, declared with the shape and dtype of ``data``.
    """
    ndim = len(data.shape)
    if axis < 0:
        axis += ndim
    assert 0 <= axis
    assert axis < ndim
    assert 1 <= ndim <= 4, "scatter only supports 1-4 dimensions"

    # One IR generator per supported rank.
    ir_by_rank = {
        1: gen_ir_1d,
        2: gen_ir_2d,
        3: gen_ir_3d,
        4: gen_ir_4d,
    }

    def overwrite(dst_ptr, dst_index, update):
        # Plain scatter: the update replaces whatever is stored at the destination.
        dst_ptr[dst_index] = update

    def emit_ir(ins, outs):
        src, idx, upd = ins[0], ins[1], ins[2]
        return ir_by_rank[ndim](src, idx, upd, axis, outs[0], overwrite)

    result_shape = data.shape
    result_buf = tvm.tir.decl_buffer(result_shape, data.dtype, "out_buf")

    cfg.add_flop(1)  # A dummy value to satisfy AutoTVM

    return te.extern(
        [result_shape],
        [data, indices, updates],
        emit_ir,
        dtype=data.dtype,
        out_buffers=[result_buf],
        name="scatter_gpu",
        tag="scatter_gpu",
    )


@autotvm.register_topi_schedule("scatter.cuda")
def schedule_scatter(_, outs):
    """Schedule for the "scatter.cuda" task.

    The first argument is ignored; ``outs`` is passed straight to
    ``schedule_extern`` and its result is returned.
    """
    extern_schedule = schedule_extern(outs)
    return extern_schedule


def gen_scatter_1d_thrust(data, indices_sorted, updates_sorted, out):
"""Generate scatter ir for 1d inputs, using a sorting based approach.
By sorting indices and comparing neighboring two indices, we can tell which
of elements in the indices tensor can scatter its update value into the output.
Expand All @@ -438,9 +481,6 @@ def gen_scatter_1d_thrust(data, indices_sorted, updates_sorted, axis, out, _):
updates : tir.Tensor
The values to update, sorted by indices.
axis : int
The axis to scatter on. It must be 0 for this function.
out : tir.Tensor
The output tensor.
Expand All @@ -449,7 +489,6 @@ def gen_scatter_1d_thrust(data, indices_sorted, updates_sorted, axis, out, _):
ret : tir
The computational ir.
"""
assert axis == 0
n = data.shape[0]

ib = tvm.tir.ir_builder.create()
Expand Down Expand Up @@ -504,7 +543,8 @@ def gen_scatter_1d_thrust(data, indices_sorted, updates_sorted, axis, out, _):
return ib.get()


def scatter(data, indices, updates, axis=0):
@autotvm.register_topi_compute("scatter_via_sort.cuda")
def scatter_via_sort(cfg, data, indices, updates, axis=0):
"""Update data at positions defined by indices with values in updates
Parameters
Expand All @@ -528,49 +568,34 @@ def scatter(data, indices, updates, axis=0):
"""
if axis < 0:
axis += len(data.shape)
assert axis >= 0
assert axis < len(data.shape)
assert axis == 0 and len(data.shape) == 1, "sorting based scatter only supported for 1d input"
assert is_thrust_available(), "Thrust is required for this op"

rank = len(data.shape)
assert 1 <= rank <= 4, "scatter only supports 1-4 dimensions"

ir_funcs = {
1: gen_ir_1d,
2: gen_ir_2d,
3: gen_ir_3d,
4: gen_ir_4d,
}

def update_func(dst_ptr, dst_index, update):
dst_ptr[dst_index] = update
cfg.add_flop(1) # A dummy value to satisfy AutoTVM

out_shape = data.shape
out_buf = tvm.tir.decl_buffer(out_shape, data.dtype, "out_buf")

in_bufs = [data]

if rank == 1 and is_thrust_available():
ir_funcs[1] = gen_scatter_1d_thrust
indices_sorted, updates_sorted = stable_sort_by_key_thrust(
indices, updates, for_scatter=True
)
in_bufs += [indices_sorted, updates_sorted]
else:
in_bufs += [indices, updates]
indices_sorted, updates_sorted = stable_sort_by_key_thrust(indices, updates, for_scatter=True)

out = te.extern(
[out_shape],
in_bufs,
lambda ins, outs: ir_funcs[rank](ins[0], ins[1], ins[2], axis, outs[0], update_func),
[data, indices_sorted, updates_sorted],
lambda ins, outs: gen_scatter_1d_thrust(ins[0], ins[1], ins[2], outs[0]),
dtype=data.dtype,
out_buffers=[out_buf],
name="scatter_gpu",
tag="scatter_gpu",
name="scatter_via_sort_gpu",
tag="scatter_via_sort_gpu",
)

return out


@autotvm.register_topi_schedule("scatter_via_sort.cuda")
def schedule_scatter_via_sort(_, outs):
    """Schedule for the "scatter_via_sort.cuda" task.

    The first argument is ignored; ``outs`` is passed straight to
    ``schedule_extern`` and its result is returned.
    """
    extern_schedule = schedule_extern(outs)
    return extern_schedule


def gen_scatter_add_1d_atomic(data, indices, updates, axis, out, _):
"""Generate scatter add ir for 1d inputs, using atomic_add instruction
Expand Down

0 comments on commit e457f3b

Please sign in to comment.