[Sparse] add sparse tensor computation support #1289

Merged (39 commits, Sep 6, 2018). Changes shown from 34 commits.

Commits
19af82a  [SPARSE] adjust implement regarding to suggestions;  (liangfu, Jun 20, 2018)
6030a64  fix pylint;  (liangfu, Jun 20, 2018)
7732946  derive from PlaceholderOp;  (liangfu, Jun 20, 2018)
8b660e7  [Sparse] added CSRTensor and a placeholder for sparse tensors;  (liangfu, Jun 21, 2018)
c98dc77  trying to add buffers to be binded with sparse placeholders;  (liangfu, Jun 21, 2018)
033e446  avoid modifying original NDArray;  (liangfu, Jun 21, 2018)
4952e63  enable sparse buffer;  (liangfu, Jun 21, 2018)
12ea0bb  bug fix and unpack sparse tensor;  (liangfu, Jun 21, 2018)
52f8e48  first successful `cs_scatter`;  (liangfu, Jun 22, 2018)
13a40f5  bug fix;  (liangfu, Jun 22, 2018)
0e6bb1d  implemented topi.sparse.dense;  (liangfu, Jun 22, 2018)
9520196  bug fix;  (liangfu, Jun 22, 2018)
f948622  first successful csrmv implement;  (liangfu, Jun 25, 2018)
2b3a34a  test sparse tensor;  (liangfu, Jun 25, 2018)
f6e5073  enable dynamic memory allocation for sparse tensor placeholder;  (liangfu, Jun 26, 2018)
10cb79e  enable dynamic memory allocation for csrmv;  (liangfu, Jun 26, 2018)
2368b89  bug fix;  (liangfu, Jun 26, 2018)
5f9c139  improved code comment for documentation;  (liangfu, Jun 26, 2018)
7afe978  improved reliability by initializing output ptr to zero;  (liangfu, Jun 27, 2018)
1c75c68  implement csrmm with parallel for loop;  (liangfu, Jun 27, 2018)
5e67aaa  enable tensorize to speedup computation;  (liangfu, Jun 28, 2018)
cbdce65  trying implement sparse fully connected layer based on csr format;  (liangfu, Jun 28, 2018)
6926ca7  first successful dense layer in csr format;  (liangfu, Jun 29, 2018)
89c8835  support dense computation in csr format;  (liangfu, Jun 29, 2018)
42eaaa3  put test functions at the bottom;  (liangfu, Jun 29, 2018)
67a1e1f  convert to csr_matrix style input;  (liangfu, Jul 25, 2018)
59d2ed3  satisfy lint;  (liangfu, Jul 25, 2018)
27fd5ca  fix incorrect comment, and index type assignment problem;  (liangfu, Jul 26, 2018)
df165b4  initial support for dense operator with sparse weights;  (liangfu, Jul 27, 2018)
b1a1da5  bug fix in sparse-weight version of dense operator;  (liangfu, Jul 27, 2018)
a62050d  satisfy the linter;  (liangfu, Jul 27, 2018)
915d3bd  update according to the comments;  (liangfu, Aug 16, 2018)
c9a6e30  Merge remote-tracking branch 'upstream/master' into sparse  (liangfu, Aug 16, 2018)
743558b  Update sparse.py  (liangfu, Aug 17, 2018)
a3ea83b  Merge remote-tracking branch 'upstream/master' into sparse  (liangfu, Aug 21, 2018)
98207fb  remove register_node declaration and path assignment in testing code;  (liangfu, Aug 21, 2018)
6becfb9  satisfy the linter;  (liangfu, Aug 21, 2018)
727b32b  update regarding to the comments;  (liangfu, Aug 24, 2018)
20ead96  Merge branch 'sparse' of github.com:liangfu/tvm into sparse  (liangfu, Aug 24, 2018)
python/tvm/contrib/sparse.py  (new file, 158 additions)
"""Tensor and Operation class for computation declaration."""
# pylint: disable=invalid-name
from __future__ import absolute_import as _abs
import numpy as _np
from .._ffi.node import register_node
from .. import expr as _expr
from .. import api as _api
from .. import tensor as _tensor
from .. import ndarray as _nd

float32 = "float32"
csr = "csr"
itype = 'int32'

@register_node
class CSRNDArray(object):
    """Sparse tensor object in CSR format."""

[Review comment, Member]: Integration with external functions: DLPack only supports dense array blobs. How would this class support invoking cusparse/mkl-sparse-blas functions? Or is that impossible unless DLPack is extended to pack sparse arrays?

[Review comment, Member]: What would be the pros and cons if it inherited NDArrayBase and threw exceptions on unimplemented methods?

[Author reply, liangfu, Jul 20, 2018]: On invoking cusparse/mkl-sparse-blas functions: we should keep this in mind; let me give it a try and get back with a definite answer. There is no need to change dlpack at the moment; @tqchen suggested that discussion is needed before adding sparse matrix support to dlpack. Inheriting NDArrayBase and throwing exceptions on unimplemented methods is a good suggestion, but it would require too many changes right now, and csr_matrix is not really N-dimensional; let's discuss the proper way to make this change.
def __init__(self, arg1, ctx=None, shape=None):
"""Construct a sparse matrix in CSR format.

Parameters
----------
arg1 : numpy.ndarray or a tuple with (data, indices, indptr)
The corresponding a dense numpy array,
or a tuple for constructing a sparse matrix directly.

ctx: tvm.TVMContext
The corresponding context.

shape : tuple of int
The shape of the array
"""
if isinstance(arg1, tuple):
self.data, self.indices, self.indptr = arg1[0], arg1[1], arg1[2]
[Review comment, Member]: Better to assert len(arg1) == 3 and do self.data, self.indices, self.indptr = arg1. Also pick a better name for arg1.

[Author reply, liangfu]: Okay, I'll fix this. A better name didn't come to mind; the name arg1 was inspired by scipy.sparse.csr_matrix.
self.shape = shape
elif isinstance(arg1, _np.ndarray):
source_array = arg1
ridx, cidx = _np.nonzero(source_array)
data = source_array[ridx, cidx]
self.data = _nd.array(data, ctx)
indices = _np.nonzero(source_array)[1].astype(itype)
self.indices = _nd.array(indices, ctx)
indptr = [0]+_np.apply_along_axis(_np.count_nonzero, axis=1, arr=source_array).tolist()
indptr = _np.cumsum(_np.array(indptr, itype)).astype(itype)
self.indptr = _nd.array(indptr, ctx)
self.shape = source_array.shape
else:
raise RuntimeError("Construct CSRNDArray with either a tuple (data, indices, indptr) "
"or a numpy.array, can't handle type %s." % (type(arg1),))
self.stype = 'csr'
self.dtype = self.data.dtype
assert self.shape is not None
assert isinstance(self.data, _nd.NDArray)
assert isinstance(self.indices, _nd.NDArray)
assert str(self.indices.dtype) == 'int32' or \
str(self.indices.dtype) == 'int64', str(self.indices.dtype)
assert isinstance(self.indptr, _nd.NDArray)
assert str(self.indptr.dtype) == 'int32' or \
str(self.indptr.dtype) == 'int64', str(self.indptr.dtype)

def asnumpy(self):
"""Construct a full matrix and convert it to numpy array."""
full = _np.zeros(self.shape, self.dtype)
ridx = _np.diff(self.indptr.asnumpy())
ridx = _np.hstack((_np.ones((v,), itype)*i for i, v in enumerate(ridx)))
full[ridx, self.indices.asnumpy().astype(itype)] = self.data.asnumpy()
return full
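
To make the constructor's CSR bookkeeping concrete, here is a small numpy-only sketch of the same data/indices/indptr derivation (the 3x4 matrix and variable names are illustrative, not part of the patch):

import numpy as np

# A small dense matrix with 4 nonzeros.
dense = np.array([[0., 2., 0., 0.],
                  [1., 0., 0., 3.],
                  [0., 0., 4., 0.]], dtype='float32')

# The same derivation CSRNDArray.__init__ performs:
ridx, cidx = np.nonzero(dense)
data = dense[ridx, cidx]                # nonzero values: [2., 1., 3., 4.]
indices = cidx.astype('int32')          # column of each nonzero: [1, 0, 3, 2]
indptr = np.cumsum([0] + np.count_nonzero(dense, axis=1).tolist()).astype('int32')
# indptr == [0, 1, 3, 4]: row i owns data[indptr[i]:indptr[i+1]]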

def array(source_array, ctx=None, shape=None, stype='csr'):
"""Construct a sparse NDArray from numpy.ndarray"""
ret = None
if stype == 'csr':
ret = CSRNDArray(source_array, shape=shape, ctx=ctx)
else:
raise NotImplementedError('stype=%s is not supported yet.' % (stype,))
return ret
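
On the external-function question raised in the review thread above: since DLPack only describes dense blobs, one workaround is to hand each dense component of the CSR triple to a compiled function separately, which is exactly what the tests below do. A minimal sketch (the kernel call on the last line is hypothetical):

import numpy as np
import tvm
import tvm.contrib.sparse as tvmsp

ctx = tvm.cpu(0)
a = tvmsp.array(np.random.rand(3, 4).astype('float32'), ctx)
# a.data, a.indices and a.indptr are ordinary dense tvm.nd.NDArrays, so each
# can cross the FFI (and DLPack) boundary on its own:
# f(a.data, a.indices, a.indptr, out)   # hypothetical kernel signature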

[Review comment, Member]: register_node is not needed for now if this is not part of the node system.

@register_node
class SparsePlaceholderOp(object):
"""Placeholder class for sparse tensor representations."""
def __init__(self, shape, nonzeros, dtype, name):
[Review comment, Member]: Shall we store nonzeros?

[Author reply, liangfu]: I left it unused intentionally.
# pylint: disable=unused-argument
"""Contructing a bare bone structure for a sparse matrix

Parameters
----------
shape: Tuple of Expr
The shape of the tensor

dtype: str, optional
The data type of the tensor

name: str, optional
The name hint of the tensor
"""
self.shape = shape
self.dtype = dtype
self.name = name
self.stype = 'unknown'

@register_node
class CSRPlaceholderOp(SparsePlaceholderOp):
"""Placeholder class for CSR based sparse tensor representation."""
def __init__(self, shape, nonzeros, dtype, name):
"""Contructing a bare bone structure for a csr_matrix

Parameters
----------
shape: Tuple of Expr
The shape of the tensor

dtype: str, optional
The data type of the tensor

name: str, optional
The name hint of the tensor
"""
SparsePlaceholderOp.__init__(self, shape, nonzeros, dtype, name)
self.stype = 'csr'
self.data = _api.placeholder((nonzeros,), dtype=dtype, name=self.name+'_data')
self.indices = _api.placeholder((nonzeros,), dtype=itype, name=self.name+'_indices')
self.indptr = _api.placeholder((self.shape[0]+1,), dtype=itype, name=self.name+'_indptr')
assert isinstance(self.data, _tensor.Tensor)
assert isinstance(self.indices, _tensor.Tensor)
assert isinstance(self.indptr, _tensor.Tensor)

def placeholder(shape, nonzeros=None, dtype=None, name="placeholder", stype=None):
"""Construct an empty sparse tensor object.

Parameters
----------
    shape: Tuple of Expr
        The shape of the tensor

    nonzeros: int, optional
        The number of non-zero values

    dtype: str, optional
        The data type of the tensor

    name: str, optional
        The name hint of the tensor

    stype: str, optional
        The storage type of the sparse tensor (e.g. csr, coo, ell)
[Review comment, Member]: Missing doc for nonzeros; likewise in SparsePlaceholderOp and CSRPlaceholderOp.

Returns
-------
tensor: SparsePlaceholderOp
The created sparse tensor placeholder
"""
shape = (shape,) if isinstance(shape, _expr.Expr) else shape
nonzeros = 0 if nonzeros is None else nonzeros
dtype = float32 if dtype is None else dtype
stype = csr if stype is None else stype
ret = None
if stype == 'csr':
[Review comment, Member]: 'csr' -> csr. Or just remove the constant definitions and always use plain strings; I actually prefer the latter.
ret = CSRPlaceholderOp(shape=shape, nonzeros=nonzeros, dtype=dtype, name=name)
else:
raise NotImplementedError('stype=%s is not supported yet.' % (stype,))
return ret
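
As a quick illustration of what placeholder() produces, here is a sketch (the shapes follow directly from CSRPlaceholderOp.__init__ above; variable names are illustrative):

import tvm
import tvm.contrib.sparse as tvmsp

m, n, nnz = tvm.var('m'), tvm.var('n'), tvm.var('nnz')
A = tvmsp.placeholder(shape=(m, n), nonzeros=nnz, dtype='float32', name='A')
assert A.stype == 'csr'
# Three ordinary dense placeholders back the sparse one:
#   A.data    : shape (nnz,), dtype float32
#   A.indices : shape (nnz,), dtype int32
#   A.indptr  : shape (m+1,), dtype int32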
tests/python/contrib/test_sparse.py  (new file, 104 additions)
import os, sys
thisdir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.join(thisdir, '../../../python'))

import tvm
import tvm.contrib.sparse as tvmsp
import tvm.ndarray as _nd
import numpy as np
from collections import namedtuple

def test_static_tensor():
dtype = 'float32'
stype = 'csr'
target = 'llvm'
ctx = tvm.context(target, 0)
m = tvm.var('m')
n = tvm.var('n')
A = tvmsp.placeholder(shape=(m, n), name='A', dtype=dtype)
assert(A.stype == 'csr')
n = 3
a = np.maximum(np.random.uniform(size=(n,n)).astype(dtype)-.6, 0.)
a = tvmsp.array(a, ctx)
A.data = tvm.placeholder(a.data.shape, dtype, name='A_data')
Ab = tvm.decl_buffer(a.data.shape, dtype, name='A_data')
binds = {A.data: Ab}
C = tvm.compute(A.data.shape, lambda i: A.data[i] * 2., tag='cs_scatter')
s = tvm.create_schedule(C.op)
f = tvm.build(s, [A.data, C], target, binds=binds)
c = tvmsp.array(np.zeros((n,n), dtype), ctx)
c.data = tvm.nd.empty(a.data.shape, dtype)
c.indices = a.indices
c.indptr = a.indptr
f(a.data, c.data)
np.testing.assert_allclose(c.asnumpy(), a.asnumpy() * 2., rtol=1e-5)

def test_dynamic_tensor():
dtype = 'float32'
stype = 'csr'
target = 'llvm'
ctx = tvm.context(target, 0)
nr, nc, n = tvm.var('nr'), tvm.var('nc'), tvm.var('n')
A = tvmsp.placeholder(shape=(nr, nc), nonzeros=n, name='A', dtype=dtype)
assert(A.stype == 'csr')
C = tvm.compute(A.data.shape, lambda i: A.data[i] * 2., tag='cs_scatter')
s = tvm.create_schedule(C.op)
_nr, _nc = 3, 5
a = np.maximum(np.random.uniform(size=(_nr, _nc)).astype(dtype)-.6, 0.)
a = tvmsp.array(a, ctx)
assert a.data.dtype == a.dtype
Ab = namedtuple('CSRBuffer', ['data', 'indices', 'indptr'])
Ab.data = tvm.decl_buffer(a.data.shape, a.data.dtype, name='A_data')
Ab.indices = tvm.decl_buffer(a.data.shape, a.data.dtype, name='A_indices')
binds = {A.data: Ab.data, A.indices: Ab.indices}
f = tvm.build(s, [nr, A.data, C], target, binds=binds)
c = tvmsp.array(np.zeros((_nr, _nc), dtype), ctx)
c.data = tvm.nd.empty(a.data.shape, dtype)
c.indices = a.indices
c.indptr = a.indptr
f(a.data.shape[0], a.data, c.data)
np.testing.assert_allclose(c.asnumpy(), a.asnumpy() * 2., rtol=1e-5)

def test_sparse_array_tuple():
dtype, itype = 'float32', 'int32'
stype = 'csr'
target = 'llvm'
ctx = tvm.context(target, 0)
nr, nc, n = tvm.var('nr'), tvm.var('nc'), tvm.var('n')
A = tvmsp.placeholder(shape=(nr, nc), nonzeros=n, name='A', dtype=dtype)
assert(A.stype == 'csr')
C = tvm.compute(A.data.shape, lambda i: A.data[i] * 2., tag='cs_scatter')
s = tvm.create_schedule(C.op)
_nr, _nc = 3, 5
a = np.maximum(np.random.uniform(size=(_nr, _nc)).astype(dtype)-.6, 0.)
# convert to sparse array tuple
source_array = a
ridx, cidx = np.nonzero(source_array)
data = source_array[ridx, cidx]
a_data = _nd.array(data, ctx)
indices = np.nonzero(source_array)[1].astype(itype)
a_indices = _nd.array(indices, ctx)
indptr = [0]+np.apply_along_axis(np.count_nonzero, axis=1, arr=source_array).tolist()
indptr = np.cumsum(np.array(indptr, itype)).astype(itype)
a_indptr = _nd.array(indptr, ctx)
a_init = (a_data, a_indices, a_indptr)
# construct tvm sparse array with tuple
a = tvmsp.array(a_init, shape=source_array.shape, ctx=ctx)
assert a.data.dtype == a.dtype
Ab = namedtuple('CSRBuffer', ['data', 'indices', 'indptr'])
Ab.data = tvm.decl_buffer(a.data.shape, a.data.dtype, name='A_data')
Ab.indices = tvm.decl_buffer(a.data.shape, a.data.dtype, name='A_indices')
binds = {A.data: Ab.data, A.indices: Ab.indices}
f = tvm.build(s, [nr, A.data, C], target, binds=binds)
c = tvmsp.array(np.zeros((_nr, _nc), dtype), ctx)
c.data = tvm.nd.empty(a.data.shape, dtype)
c.indices = a.indices
c.indptr = a.indptr
f(a.data.shape[0], a.data, c.data)
np.testing.assert_allclose(c.asnumpy(), a.asnumpy() * 2., rtol=1e-5)

if __name__ == "__main__":
test_static_tensor()
test_dynamic_tensor()
test_sparse_array_tuple()

topi/python/topi/__init__.py  (1 addition)

@@ -32,6 +32,7 @@
from . import rocm
from . import vision
from . import image
from . import sparse
from . import hls
# not import testing by default
# because testing can have extra deps that are not necessary
topi/python/topi/sparse/__init__.py  (new file, 7 additions)
# pylint: disable=wildcard-import
"""Sparse operators"""
from __future__ import absolute_import as _abs

from .csrmv import csrmv
from .csrmm import csrmm
from .dense import dense
topi/python/topi/sparse/csrmm.py  (new file, 94 additions)
"""TVM operator compute SpMM in CSR format."""
from __future__ import absolute_import
import tvm
from .. import tag
from ..util import simplify

def csrmm_default(data, indices, indptr, weight, bias=None):
# pylint: disable=invalid-name
"""The default implementation of csrmm in topi.

Parameters
----------
data : tvm.Tensor
1-D with shape [nonzeros]

indices : tvm.Tensor
1-D with shape [nonzeros]

indptr : tvm.Tensor
1-D with shape [m+1]

weight : tvm.Tensor
2-D with shape [k, n]

bias : tvm.Tensor, optional
1-D with shape [m]

Returns
-------
output : tvm.Tensor
2-D with shape [m, n]
"""
assert len(data.shape) == 1 and len(indices.shape) == 1 and len(indptr.shape) == 1 \
and len(weight.shape) == 2, "only support 2-dim csrmm"
assert isinstance(weight, tvm.tensor.Tensor), \
"weight matrix is assumed to be tvm.Tensor, but weight is `%s`" % (type(weight))
if bias is not None:
assert len(bias.shape) == 1
M = simplify(indptr.shape[0]-1)
_, N = weight.shape
def csrmm_default_ir(data, indices, indptr, weight, out):
"""define ir for csrmm"""
irb = tvm.ir_builder.create()
data_ptr = irb.buffer_ptr(data)
indices_ptr = irb.buffer_ptr(indices)
indptr_ptr = irb.buffer_ptr(indptr)
weight_ptr = irb.buffer_ptr(weight)
out_ptr = irb.buffer_ptr(out)
M = simplify(indptr.shape[0]-1)
_, N = weight.shape
with irb.for_range(0, N, for_type="vectorize", name='n') as n:
with irb.for_range(0, M, for_type="parallel", name='row') as row:
dot = irb.allocate('float32', (1,), name='dot', scope='local')
out_ptr[row*N+n] = 0.
dot[0] = 0.
row_start = indptr_ptr[row]
row_end = indptr_ptr[row+1]
row_elems = row_end-row_start
with irb.for_range(0, row_elems, name='idx') as idx:
elem = row_start+idx
dot[0] += data_ptr[elem] * weight_ptr[indices_ptr[elem]*N+n]
out_ptr[row*N+n] += dot[0]
return irb.get()
oshape = (M, N)
matmul = tvm.extern(oshape, [data, indices, indptr, weight],
lambda ins, outs: csrmm_default_ir(ins[0], ins[1], ins[2], ins[3], outs[0]),
tag="csrmm", dtype='float32', name='out')
if bias is not None:
matmul = tvm.compute(oshape, lambda i, j: matmul[i, j] + bias[i], \
tag=tag.BROADCAST)
return matmul
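
The IR above is easiest to check against a plain-Python reference of the same loop nest; a sketch (the function name and the float32 assumption are illustrative):

import numpy as np

def csrmm_reference(data, indices, indptr, weight):
    """Dense reference for the csrmm IR: for every row, accumulate
    data[elem] * weight[indices[elem], :] over that row's nonzeros."""
    m = len(indptr) - 1
    out = np.zeros((m, weight.shape[1]), dtype='float32')
    for row in range(m):
        for elem in range(indptr[row], indptr[row + 1]):
            out[row, :] += data[elem] * weight[indices[elem], :]
    return out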


def csrmm(a, b, c=None):
"""The `csrmm` routine performs a matrix-matrix operation defined as :math:`C := A*B + C`,
where `B` and `C` are dense matrices, `A` is an m-by-k sparse matrix in the CSR format.

Parameters
----------
a : tvm.contrib.sparse.CSRNDArray
2-D sparse matrix with shape [m, k]

b : tvm.Tensor
2-D dense matrix with shape [k, n]

    c : tvm.Tensor, optional
        1-D dense vector with shape [m], added to each row of the result

Returns
-------
output : tvm.Tensor
2-D with shape [m, n]
"""
return csrmm_default(a.data, a.indices, a.indptr, b, c)
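
End to end, the routine is wired up the same way as the tests above: declare a CSR placeholder and a dense weight, then schedule the result. A sketch under the same 2018-era API (building and buffer binding are omitted):

import tvm
import tvm.contrib.sparse as tvmsp
import topi

m, k, n, nnz = tvm.var('m'), tvm.var('k'), tvm.var('n'), tvm.var('nnz')
A = tvmsp.placeholder(shape=(m, k), nonzeros=nnz, dtype='float32', name='A')
B = tvm.placeholder((k, n), dtype='float32', name='B')
C = topi.sparse.csrmm(A, B)        # 2-D dense result with shape (m, n)
s = tvm.create_schedule(C.op)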