
Faster sparse_dense on GPUs #6580

Merged (14 commits, Oct 9, 2020)
16 changes: 16 additions & 0 deletions python/tvm/relay/op/nn/_nn.py
@@ -75,6 +75,22 @@ def compute_sparse_dense(attrs, inputs, out_type):
reg.register_pattern("nn.sparse_dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)


@reg.register_alter_op_layout("nn.sparse_dense")
def alter_op_layout_sparse_dense(attrs, inputs, tinfos, out_type):
"""Alternate the layout of sparse_dense"""
return topi.nn.sparse_dense_alter_layout(attrs, inputs, tinfos, out_type)


@reg.register_compute("nn.internal.sparse_dense_padded")
def compute_sparse_dense_padded(attrs, inputs, out_type):
"""Compute definition of sparse_dense_padded"""
raise NotImplementedError("nn.internal.sparse_dense_padded is only available on cuda")


reg.register_strategy("nn.internal.sparse_dense_padded", strategy.sparse_dense_padded_strategy)
reg.register_pattern("nn.internal.sparse_dense_padded", reg.OpPattern.OUT_ELEMWISE_FUSABLE)


# sparse_transpose
@reg.register_compute("nn.sparse_transpose")
def compute_sparse_transpose(attrs, inputs, out_type):
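This hunk registers two hooks: an alter-op-layout handler for nn.sparse_dense, and the internal op nn.internal.sparse_dense_padded, whose compute deliberately raises because only the CUDA strategy implements it. A minimal sketch of how the layout hook is exercised, assuming a hypothetical IRModule mod containing an nn.sparse_dense call (the actual rewrite into the padded op lives in topi's sparse_dense_alter_layout):

import tvm
from tvm import relay

# mod is a hypothetical tvm.IRModule whose main function calls
# relay.nn.sparse_dense.
mod = relay.transform.InferType()(mod)
with tvm.target.Target("cuda"):
    # AlterOpLayout invokes the hook registered above, which may rewrite
    # nn.sparse_dense into nn.internal.sparse_dense_padded on CUDA.
    mod = relay.transform.AlterOpLayout()(mod)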
6 changes: 4 additions & 2 deletions python/tvm/relay/op/nn/nn.py
@@ -2016,15 +2016,17 @@ def sparse_dense(data, weight):
data : tvm.relay.Expr
The input data for the matrix multiplication

weight : namedtuple.
weight : Union[namedtuple, Tuple[ndarray, ndarray, ndarray]].
The sparse weight matrix for the matrix multiplication.

Returns
-------
result: tvm.relay.Expr
The computed result.
"""
return _make.sparse_dense(data, weight.data, weight.indices, weight.indptr)
if hasattr(weight, "indices"):
return _make.sparse_dense(data, weight.data, weight.indices, weight.indptr)
return _make.sparse_dense(data, weight[0], weight[1], weight[2])


def sparse_transpose(x):
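With the change above, sparse_dense accepts the weight either as a namedtuple exposing data/indices/indptr fields or as a plain 3-tuple of those arrays. A hedged usage sketch (variable names and shapes are illustrative):

import collections
from tvm import relay

CSR = collections.namedtuple("CSR", ["data", "indices", "indptr"])

data = relay.var("data", shape=(4, 8), dtype="float32")
w_data = relay.var("w_data", shape=(6,), dtype="float32")
w_indices = relay.var("w_indices", shape=(6,), dtype="int32")
w_indptr = relay.var("w_indptr", shape=(5,), dtype="int32")

# Both calls build the same nn.sparse_dense expression.
y1 = relay.nn.sparse_dense(data, CSR(w_data, w_indices, w_indptr))
y2 = relay.nn.sparse_dense(data, (w_data, w_indices, w_indptr))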
13 changes: 13 additions & 0 deletions python/tvm/relay/op/strategy/cuda.py
@@ -633,6 +633,19 @@ def sparse_dense_strategy_cuda(attrs, inputs, out_type, target):
return strategy


@sparse_dense_padded_strategy.register(["cuda", "gpu"])
def sparse_dense_padded_strategy_cuda(attrs, inputs, out_type, target):
"""sparse dense cuda strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_sparse_dense(topi.cuda.sparse_dense_padded),
wrap_topi_schedule(topi.cuda.schedule_sparse_dense_padded),
name="sparse_dense_padded.cuda",
plevel=10,
)
return strategy


@argsort_strategy.register(["cuda", "gpu"])
def argsort_strategy_cuda(attrs, inputs, out_type, target):
"""argsort cuda strategy"""
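The CUDA strategy pairs the topi compute with its schedule via wrap_compute_sparse_dense and wrap_topi_schedule. Since only cuda/gpu register an implementation, target selection behaves roughly as in this sketch (mod is a hypothetical IRModule using the padded op):

from tvm import relay

# Compiling for CUDA dispatches to the sparse_dense_padded.cuda
# implementation registered above.
lib = relay.build(mod, target="cuda")

# Compiling the same module for CPU would fall through to the generic
# strategy in generic.py, which raises NotImplementedError.
# lib = relay.build(mod, target="llvm")  # raises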
6 changes: 6 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
@@ -724,6 +724,12 @@ def sparse_dense_strategy(attrs, inputs, out_type, target):
return strategy


@override_native_generic_func("sparse_dense_padded_strategy")
def sparse_dense_padded_strategy(attrs, inputs, out_type, target):
"""sparse dense padded generic strategy"""
raise NotImplementedError("sparse_dense_padded is only implemented for cuda")


# sparse_transpose
@generic_func
def schedule_sparse_transpose(attrs, outs, target):
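override_native_generic_func creates a target-dispatched function: the decorated body is the fallback for any target, and per-target overrides attach via .register, which is exactly how the CUDA hunk above plugs in. A minimal self-contained sketch of the pattern, with hypothetical names:

from tvm.target import override_native_generic_func

@override_native_generic_func("my_op_strategy")
def my_op_strategy(attrs, inputs, out_type, target):
    """Fallback for targets without a registered override."""
    raise NotImplementedError("my_op is only implemented for cuda")

@my_op_strategy.register(["cuda", "gpu"])
def my_op_strategy_cuda(attrs, inputs, out_type, target):
    """Chosen whenever compilation targets cuda or gpu."""
    # build and return an _op.OpStrategy() here
    ...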
39 changes: 34 additions & 5 deletions python/tvm/tir/ir_builder.py
@@ -42,6 +42,9 @@ class BufferVar(ObjectGeneric):

Do not create it directly; create it via IRBuilder.

BufferVars support array access either via a linear index, or, if given a
shape, via a multidimensional index.

Examples
--------
In the following example, x is a BufferVar.
@@ -55,16 +58,23 @@
x = ib.pointer("float32")
x[0] = x[10] + 1

y = ib.allocate("float32", (32, 32))
# Array access using a linear index
y[(2*32) + 31] = 0.
# The same array access using a multidimensional index
y[2, 31] = 0.

See Also
--------
IRBuilder.pointer
IRBuilder.buffer_ptr
IRBuilder.allocate
"""

def __init__(self, builder, buffer_var, content_type):
def __init__(self, builder, buffer_var, shape, content_type):
self._builder = builder
self._buffer_var = buffer_var
self._shape = shape
self._content_type = content_type

def asobject(self):
@@ -74,8 +84,23 @@ def asobject(self):
def dtype(self):
return self._content_type

def _linear_index(self, index):
if not isinstance(index, tuple) or self._shape is None:
return index
assert len(index) == len(self._shape), "Index size (%s) does not match shape size (%s)" % (
len(index),
len(self._shape),
)
dim_size = 1
lidx = 0
for dim, idx in zip(reversed(self._shape), reversed(index)):
lidx += idx * dim_size
dim_size *= dim
return lidx

def __getitem__(self, index):
t = DataType(self._content_type)
index = self._linear_index(index)
if t.lanes > 1:
base = index * t.lanes
index = _expr.Ramp(base, const(1, base.dtype), t.lanes)
@@ -87,6 +112,7 @@ def __setitem__(self, index, value):
raise ValueError(
"data type does not match content type %s vs %s" % (value.dtype, self._content_type)
)
index = self._linear_index(index)
t = DataType(self._content_type)
if t.lanes > 1:
base = index * t.lanes
@@ -341,7 +367,7 @@ def allocate(self, dtype, shape, name="buf", scope=None):
if scope:
self.scope_attr(buffer_var, "storage_scope", scope)
self.emit(lambda x: _stmt.Allocate(buffer_var, dtype, shape, const(1, dtype="uint1"), x))
return BufferVar(self, buffer_var, dtype)
return BufferVar(self, buffer_var, shape, dtype)

def pointer(self, content_type, name="ptr"):
"""Create pointer variable with content type.
@@ -360,22 +386,25 @@ def pointer(self, content_type, name="ptr"):
The buffer var representing the buffer.
"""
buffer_var = _expr.Var(name, dtype="handle")
return BufferVar(self, buffer_var, content_type)
return BufferVar(self, buffer_var, None, content_type)

def buffer_ptr(self, buf):
def buffer_ptr(self, buf, shape=None):
"""Create pointer variable corresponds to buffer ptr.

Parameters
----------
buf : Buffer
The buffer to be extracted.

shape : Tuple, optional
The shape used for multidimensional indexing; overrides the buffer's own shape.

Returns
-------
ptr : BufferVar
The buffer var representing the buffer.
"""
return BufferVar(self, buf.data, buf.dtype)
return BufferVar(self, buf.data, buf.shape if shape is None else shape, buf.dtype)

def likely(self, expr):
"""Add likely tag for expression.
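The ir_builder changes thread the allocation shape into BufferVar so tuple indices can be linearized row-major: with shape (32, 32), the index (2, 31) becomes 2 * 32 + 31 = 95. buffer_ptr gains the same capability through its new optional shape argument. A short sketch against the new API (the scope name is illustrative):

import tvm

ib = tvm.tir.ir_builder.create()
y = ib.allocate("float32", (32, 32), name="y", scope="local")
y[2, 31] = 0.0           # multidimensional index, linearized to 2*32 + 31
y[(2 * 32) + 31] = 1.0   # equivalent linear index into the same element
stmt = ib.get()          # materialize the generated TIR statement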