diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py
index 89a3caa16843..aea76c621d46 100644
--- a/python/torch_mlir/extras/fx_importer.py
+++ b/python/torch_mlir/extras/fx_importer.py
@@ -216,14 +216,20 @@
 @dataclass(frozen=True)
 class SparsityMeta:
-    """Class for keeping track of sparsity meta data."""
+    """
+    Class for keeping track of sparsity meta data.
+
+    NOTE: this will be fully replaced by
+    torch.fx.passes.shape_prop.SparseTensorMetadata
+    """
 
     layout: torch.layout
     batch_dim: int
     sparse_dim: int
    dense_dim: int
-    pos_width: int
-    crd_width: int
+    blocksize: Optional[tuple[int, int]]
+    pos_dtype: torch.dtype
+    crd_dtype: torch.dtype
 
 
 def sparsity_encoding(shape: torch.Size, sparsity: SparsityMeta) -> str:
@@ -240,21 +246,31 @@ def sparsity_encoding(shape: torch.Size, sparsity: SparsityMeta) -> str:
     )
     dim = batch_dim + sparse_dim + dense_dim
     assert dim == len(shape)
+    blocksize = sparsity.blocksize
 
     dims = ",".join(f"d{d}" for d in range(0, dim))
 
     if sparsity.layout is torch.sparse_coo:
-        assert sparse_dim == 2  # TODO: deeper sparse dims
+        assert sparse_dim == 2 and blocksize is None  # TODO: deeper sparse dims
         lvls = f"d{batch_dim}:compressed(nonunique),d{batch_dim+1}:singleton"
     elif sparsity.layout is torch.sparse_csr:
-        assert sparse_dim == 2
+        assert sparse_dim == 2 and blocksize is None
         lvls = f"d{batch_dim}:dense,d{batch_dim+1}:compressed"
     elif sparsity.layout is torch.sparse_csc:
-        assert sparse_dim == 2
+        assert sparse_dim == 2 and blocksize is None
         lvls = f"d{batch_dim+1}:dense,d{batch_dim}:compressed"
     else:
-        # TODO: block format (derive block size!)
-        raise RuntimeError(f"Unsupported sparse layout {sparse_layout}")
+        assert sparse_dim == 2 and blocksize is not None
+        if sparsity.layout is torch.sparse_bsr:
+            i, j = batch_dim, batch_dim + 1
+        else:
+            assert sparsity.layout is torch.sparse_bsc
+            j, i = batch_dim, batch_dim + 1
+        m, n = blocksize
+        lvls = (
+            f"d{i} floordiv {m}:dense,d{j} floordiv {n}:compressed,"
+            f"d{i} mod {m}:dense,d{j} mod {n}:dense"
+        )
 
     if batch_dim > 0:
         batch = ",".join(f"d{d}:dense" for d in range(0, batch_dim))
@@ -264,7 +280,8 @@ def sparsity_encoding(shape: torch.Size, sparsity: SparsityMeta) -> str:
         dense = ",".join(f"d{d}:dense" for d in range(batch_dim + sparse_dim, dim))
         lvls = f"{lvls},{dense}"
 
-    posw, crdw = sparsity.pos_width, sparsity.crd_width
+    posw = torch.iinfo(sparsity.pos_dtype).bits
+    crdw = torch.iinfo(sparsity.crd_dtype).bits
     return f"#sparse_tensor.encoding<{{map=({dims})->({lvls}),posWidth={posw},crdWidth={crdw}}}>"
diff --git a/test/python/fx_importer/sparse_test.py b/test/python/fx_importer/sparse_test.py
index e936e40cb039..87eecb2977d5 100644
--- a/test/python/fx_importer/sparse_test.py
+++ b/test/python/fx_importer/sparse_test.py
@@ -31,50 +31,49 @@
 ]
 
 
-def sparse_overhead_width(d: torch.dtype) -> int:
-    """Returns bit-width for admissible overhead type."""
-    if d is torch.int64:
-        return 64
-    if d is torch.int32:
-        return 32
-    if d is torch.int16:
-        return 16
-    if d is torch.int8:
-        return 8
-    raise RuntimeError(f"Unsupported overhead type {d}")
-
-
 def sparse_metadata(a: torch.Tensor) -> SparsityMeta:
-    """Returns a meta data tuple for the given sparse tensor."""
+    """
+    Returns a meta data tuple for the given sparse tensor.
+
+    NOTE: this will be fully replaced by fx graph SparseTensorMetadata
+    """
     sparse_dim = a.sparse_dim()
     dense_dim = a.dense_dim()
     batch_dim = a.ndim - dense_dim - sparse_dim
+    blocksize = None
     if a.layout is torch.sparse_coo:
         return SparsityMeta(
             a.layout,
             batch_dim,
             sparse_dim,
             dense_dim,
-            sparse_overhead_width(a.indices().dtype),
-            sparse_overhead_width(a.indices().dtype),
+            blocksize,
+            a.indices().dtype,
+            a.indices().dtype,
         )
     elif a.layout is torch.sparse_csr or a.layout is torch.sparse_bsr:
+        if a.layout is torch.sparse_bsr:
+            blocksize = a.values().shape[batch_dim + 1 : batch_dim + 3]
         return SparsityMeta(
             a.layout,
             batch_dim,
             sparse_dim,
             dense_dim,
-            sparse_overhead_width(a.crow_indices().dtype),
-            sparse_overhead_width(a.col_indices().dtype),
+            blocksize,
+            a.crow_indices().dtype,
+            a.col_indices().dtype,
         )
     elif a.layout is torch.sparse_csc or a.layout is torch.sparse_bsc:
+        if a.layout is torch.sparse_bsc:
+            blocksize = a.values().shape[batch_dim + 1 : batch_dim + 3]
         return SparsityMeta(
             a.layout,
             batch_dim,
             sparse_dim,
             dense_dim,
-            sparse_overhead_width(a.ccol_indices().dtype),
-            sparse_overhead_width(a.row_indices().dtype),
+            blocksize,
+            a.ccol_indices().dtype,
+            a.row_indices().dtype,
         )
     else:
         raise RuntimeError(f"Unsupported sparse layout for {a}")
@@ -214,6 +213,30 @@ def forward(self, x):
     print("torch.mlir =", res2)
 
 
+@run
+# CHECK-LABEL: test_sparse_SpMV
+# CHECK:       #[[$BSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 floordiv 2 : dense, d1 floordiv 2 : compressed, d0 mod 2 : dense, d1 mod 2 : dense), posWidth = 64, crdWidth = 64 }>
+# CHECK:       func.func @main(
+# CHECK-SAME:    %[[A:.*0]]: !torch.vtensor<[10,10],f32,#[[$BSR]]>,
+# CHECK-SAME:    %[[B:.*1]]: !torch.vtensor<[10],f32>) -> !torch.vtensor<[10],f32> {
+# CHECK:         %[[R:.*]] = torch.aten.mv %[[A]], %[[B]] : !torch.vtensor<[10,10],f32,#[[$BSR]]>, !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
+# CHECK:         return %[[R]] : !torch.vtensor<[10],f32>
+# CHECK:       }
+def test_sparse_SpMV():
+    class SpMVNet(torch.nn.Module):
+        def __init__(self):
+            super(SpMVNet, self).__init__()
+
+        def forward(self, x, v):
+            return torch.mv(x, v)
+
+    dense_vector = torch.ones(10)
+    dense_input = torch.ones(10, 10)
+    sparse_input = dense_input.to_sparse_bsr(blocksize=(2, 2))
+    m = export_and_import(SpMVNet(), sparse_input, dense_vector)
+    print(m)
+
+
 @run
 # CHECK-LABEL: test_sparse_SpMM
 # CHECK:       #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton), posWidth = 64, crdWidth = 64 }>
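A quick way to sanity-check the new block-sparse path by hand (an illustrative sketch, not part of the patch; it assumes the patched fx_importer module is importable and mirrors the 10x10, blocksize (2, 2) input of test_sparse_SpMV):

import torch

from torch_mlir.extras.fx_importer import SparsityMeta, sparsity_encoding

# Build a BSR tensor; PyTorch uses int64 index tensors by default.
x = torch.ones(10, 10).to_sparse_bsr(blocksize=(2, 2))

# Hand-construct the metadata that the test helper sparse_metadata()
# would derive from the tensor.
meta = SparsityMeta(
    layout=x.layout,                   # torch.sparse_bsr
    batch_dim=0,
    sparse_dim=2,
    dense_dim=0,
    blocksize=(2, 2),                  # the helper reads this off x.values().shape
    pos_dtype=x.crow_indices().dtype,  # torch.int64 -> posWidth=64
    crd_dtype=x.col_indices().dtype,   # torch.int64 -> crdWidth=64
)

print(sparsity_encoding(x.shape, meta))
# #sparse_tensor.encoding<{map=(d0,d1)->(d0 floordiv 2:dense,d1 floordiv 2:compressed,d0 mod 2:dense,d1 mod 2:dense),posWidth=64,crdWidth=64}>

The raw encoding string above has no padding spaces; the spaced form in the FileCheck lines is what the MLIR printer emits after the attribute round-trips through parsing.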