diff --git a/test/python/fx_importer/sparse_test.py b/test/python/fx_importer/sparse_test.py index 4cb372ceab57..95a859359ca4 100644 --- a/test/python/fx_importer/sparse_test.py +++ b/test/python/fx_importer/sparse_test.py @@ -123,6 +123,11 @@ def sparse_export( # Zero preserving elt-wise unary op. if node.name in {"abs", "neg", "relu", "sin"}: node.meta["sparsity"] = node.args[0].meta.get("sparsity", None) + elif node.name == "_to_sparse": + dim = len(node.meta.get("val").shape) + node.meta["sparsity"] = SparsityMeta( + torch.sparse_coo, 0, dim, 0, None, torch.int64, torch.int64 + ) return prog @@ -458,3 +463,51 @@ def forward(self, x): print("torch.sparse") print(res1) print("torch.mlir") + + +@run +# CHECK-LABEL: test_sparse_activation +# CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed(nonunique), d1 : singleton(nonunique, soa), d2 : singleton(soa)), posWidth = 64, crdWidth = 64 }> +# CHECK: func.func @main( +# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[2,2,2],f32,#[[$COO]]> { +# CHECK: %[[N1:.*]] = torch.constant.none +# CHECK: %[[N2:.*]] = torch.constant.none +# CHECK: %[[N3:.*]] = torch.constant.none +# CHECK: %[[R:.*]] = torch.operator "torch.aten._to_sparse"(%[[A]], %[[N1]], %[[N2]], %[[N3]]) : (!torch.vtensor<[2,2,2],f32>, !torch.none, !torch.none, !torch.none) -> !torch.vtensor<[2,2,2],f32,#[[$COO]]> +# CHECK: return %[[R]] : !torch.vtensor<[2,2,2],f32,#[[$COO]]> +# CHECK: } +# +# CHECK: torch.sparse +# CHECK: tensor(indices=tensor({{\[}}[0, 0, 0, 0, 1, 1, 1, 1], +# CHECK: [0, 0, 1, 1, 0, 0, 1, 1], +# CHECK: [0, 1, 0, 1, 0, 1, 0, 1]{{\]}}), +# CHECK: values=tensor([1., 1., 1., 1., 1., 1., 1., 1.]), +# CHECK: size=(2, 2, 2), nnz=8, layout=torch.sparse_coo) +# CHECK: torch.mlir +# CHECK: [0 8] +# CHECK: [0 0 0 0 1 1 1 1] +# CHECK: [0 0 1 1 0 0 1 1] +# CHECK: [0 1 0 1 0 1 0 1] +# CHECK: [1. 1. 1. 1. 1. 1. 1. 1.] +# +def test_sparse_activation(): + class SparseActivationCOO(torch.nn.Module): + def forward(self, x): + return x.to_sparse() + + net = SparseActivationCOO() + x = torch.ones(2, 2, 2) + m = export_and_import(net, x) + print(m) + + # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit. + res1 = net(x) + res2 = sparse_jit(net, x) + print("torch.sparse") + print(res1) + print("torch.mlir") + print(res2[0]) + print(res2[1]) + print(res2[2]) + print(res2[3]) + print(res2[4])