Skip to content

Commit

Permalink
[torch-mlir][sparse] add block sparsity to mlir lowering (llvm#2942)
Browse files Browse the repository at this point in the history
Also note that we are in the process of proposing SparseTensorMetadata
to PyTorch FX graph export (see
pytorch/pytorch#117907). This will hopefully
eventually replace the current data structures in torch-mlir.
  • Loading branch information
aartbik authored Feb 23, 2024
1 parent 55dc8de commit 4147b28
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 29 deletions.
35 changes: 26 additions & 9 deletions python/torch_mlir/extras/fx_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,14 +216,20 @@

@dataclass(frozen=True)
class SparsityMeta:
"""Class for keeping track of sparsity meta data."""
"""
Class for keeping track of sparsity meta data.
NOTE: this will be fully replaced by
torch.fx.passes.shape_prop.SparseTensorMetadata
"""

layout: torch.layout
batch_dim: int
sparse_dim: int
dense_dim: int
pos_width: int
crd_width: int
blocksize: Optional[tuple[int, int]]
pos_dtype: torch.dtype
crd_dtype: torch.dtype


def sparsity_encoding(shape: torch.Size, sparsity: SparsityMeta) -> str:
Expand All @@ -240,21 +246,31 @@ def sparsity_encoding(shape: torch.Size, sparsity: SparsityMeta) -> str:
)
dim = batch_dim + sparse_dim + dense_dim
assert dim == len(shape)
blocksize = sparsity.blocksize

dims = ",".join(f"d{d}" for d in range(0, dim))

if sparsity.layout is torch.sparse_coo:
assert sparse_dim == 2 # TODO: deeper sparse dims
assert sparse_dim == 2 and blocksize is None # TODO: deeper sparse dims
lvls = f"d{batch_dim}:compressed(nonunique),d{batch_dim+1}:singleton"
elif sparsity.layout is torch.sparse_csr:
assert sparse_dim == 2
assert sparse_dim == 2 and blocksize is None
lvls = f"d{batch_dim}:dense,d{batch_dim+1}:compressed"
elif sparsity.layout is torch.sparse_csc:
assert sparse_dim == 2
assert sparse_dim == 2 and blocksize is None
lvls = f"d{batch_dim+1}:dense,d{batch_dim}:compressed"
else:
# TODO: block format (derive block size!)
raise RuntimeError(f"Unsupported sparse layout {sparse_layout}")
assert sparse_dim == 2 and blocksize is not None
if sparsity.layout is torch.sparse_bsr:
i, j = batch_dim, batch_dim + 1
else:
assert sparsity.layout is torch.sparse_bsc
j, i = batch_dim, batch_dim + 1
m, n = blocksize
lvls = (
f"d{i} floordiv {m}:dense,d{j} floordiv {n}:compressed,"
f"d{i} mod {m}:dense,d{j} mod {n}:dense"
)

if batch_dim > 0:
batch = ",".join(f"d{d}:dense" for d in range(0, batch_dim))
Expand All @@ -264,7 +280,8 @@ def sparsity_encoding(shape: torch.Size, sparsity: SparsityMeta) -> str:
dense = ",".join(f"d{d}:dense" for d in range(batch_dim + sparse_dim, dim))
lvls = f"{lvls},{dense}"

posw, crdw = sparsity.pos_width, sparsity.crd_width
posw = torch.iinfo(sparsity.pos_dtype).bits
crdw = torch.iinfo(sparsity.crd_dtype).bits
return f"#sparse_tensor.encoding<{{map=({dims})->({lvls}),posWidth={posw},crdWidth={crdw}}}>"


Expand Down
63 changes: 43 additions & 20 deletions test/python/fx_importer/sparse_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,50 +31,49 @@
]


def sparse_overhead_width(d: torch.dtype) -> int:
"""Returns bit-width for admissible overhead type."""
if d is torch.int64:
return 64
if d is torch.int32:
return 32
if d is torch.int16:
return 16
if d is torch.int8:
return 8
raise RuntimeError(f"Unsupported overhead type {d}")


def sparse_metadata(a: torch.Tensor) -> SparsityMeta:
"""Returns a meta data tuple for the given sparse tensor."""
"""
Returns a meta data tuple for the given sparse tensor.
NOTE: this will be fully replaced by fx graph SparseTensorMetadata
"""
sparse_dim = a.sparse_dim()
dense_dim = a.dense_dim()
batch_dim = a.ndim - dense_dim - sparse_dim
blocksize = None
if a.layout is torch.sparse_coo:
return SparsityMeta(
a.layout,
batch_dim,
sparse_dim,
dense_dim,
sparse_overhead_width(a.indices().dtype),
sparse_overhead_width(a.indices().dtype),
blocksize,
a.indices().dtype,
a.indices().dtype,
)
elif a.layout is torch.sparse_csr or a.layout is torch.sparse_bsr:
if a.layout is torch.sparse_bsr:
blocksize = a.values().shape[batch_dim + 1 : batch_dim + 3]
return SparsityMeta(
a.layout,
batch_dim,
sparse_dim,
dense_dim,
sparse_overhead_width(a.crow_indices().dtype),
sparse_overhead_width(a.col_indices().dtype),
blocksize,
a.crow_indices().dtype,
a.col_indices().dtype,
)
elif a.layout is torch.sparse_csc or a.layout is torch.sparse_bsc:
if a.layout is torch.sparse_bsc:
blocksize = a.values().shape[batch_dim + 1 : batch_dim + 3]
return SparsityMeta(
a.layout,
batch_dim,
sparse_dim,
dense_dim,
sparse_overhead_width(a.ccol_indices().dtype),
sparse_overhead_width(a.row_indices().dtype),
blocksize,
a.ccol_indices().dtype,
a.row_indices().dtype,
)
else:
raise RuntimeError(f"Unsupported sparse layout for {a}")
Expand Down Expand Up @@ -214,6 +213,30 @@ def forward(self, x):
print("torch.mlir =", res2)


@run
# CHECK-LABEL: test_sparse_SpMV
# CHECK: #[[$BSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 floordiv 2 : dense, d1 floordiv 2 : compressed, d0 mod 2 : dense, d1 mod 2 : dense), posWidth = 64, crdWidth = 64 }>
# CHECK: func.func @main(
# CHECK-SAME: %[[A:.*0]]: !torch.vtensor<[10,10],f32,#[[$BSR]]>,
# CHECK-SAME: %[[B:.*1]]: !torch.vtensor<[10],f32>) -> !torch.vtensor<[10],f32> {
# CHECK: %[[R:.*]] = torch.aten.mv %[[A]], %[[B]] : !torch.vtensor<[10,10],f32,#[[$BSR]]>, !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
# CHECK: return %[[R]] : !torch.vtensor<[10],f32>
# CHECK: }
def test_sparse_SpMV():
class SpMVNet(torch.nn.Module):
def __init__(self):
super(SpMVNet, self).__init__()

def forward(self, x, v):
return torch.mv(x, v)

dense_vector = torch.ones(10)
dense_input = torch.ones(10, 10)
sparse_input = dense_input.to_sparse_bsr(blocksize=(2, 2))
m = export_and_import(SpMVNet(), sparse_input, dense_vector)
print(m)


@run
# CHECK-LABEL: test_sparse_SpMM
# CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton), posWidth = 64, crdWidth = 64 }>
Expand Down

0 comments on commit 4147b28

Please sign in to comment.