From fc71dc611a43af1f9001fe349c44c9d52db4ad39 Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Tue, 27 Feb 2024 11:11:15 -0800 Subject: [PATCH] [torch-mlir][sparse] add JIT test for block sparse SpMV This required adding a "decompose" pass to the torch lowering, since torch.mv was not directly handled by lowering to linalg --- test/python/fx_importer/sparse_test.py | 36 +++++++++++++++++++------- 1 file changed, 26 insertions(+), 10 deletions(-) diff --git a/test/python/fx_importer/sparse_test.py b/test/python/fx_importer/sparse_test.py index 87eecb2977d5..138942b07092 100644 --- a/test/python/fx_importer/sparse_test.py +++ b/test/python/fx_importer/sparse_test.py @@ -139,7 +139,11 @@ def sparse_jit(f, *args, **kwargs): module = export_and_import(f, *args, *kwargs) run_pipeline_with_repro_report( module, - "builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline)", + ( + "builtin.module(" + "func.func(torch-decompose-complex-ops)," + "torch-backend-to-linalg-on-tensors-backend-pipeline)" + ), "Lowering TorchFX IR -> Linalg IR", enable_ir_printing=False, ) @@ -200,13 +204,13 @@ def __init__(self): def forward(self, x): return x.sum() + net = SumNet() dense_input = torch.ones(64, 64) sparse_input = dense_input.to_sparse_csr() - m = export_and_import(SumNet(), sparse_input) + m = export_and_import(net, sparse_input) print(m) # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit. - net = SumNet() res1 = net(sparse_input) res2 = sparse_jit(net, sparse_input) print("torch.sparse =", res1) @@ -222,6 +226,10 @@ def forward(self, x): # CHECK: %[[R:.*]] = torch.aten.mv %[[A]], %[[B]] : !torch.vtensor<[10,10],f32,#[[$BSR]]>, !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32> # CHECK: return %[[R]] : !torch.vtensor<[10],f32> # CHECK: } +# +# CHECK: torch.sparse = tensor([55., 55., 55., 55., 55., 55., 55., 55., 55., 55.]) +# CHECK: torch.mlir = [55. 55. 55. 55. 55. 55. 55. 55. 55. 55.] +# def test_sparse_SpMV(): class SpMVNet(torch.nn.Module): def __init__(self): @@ -230,12 +238,19 @@ def __init__(self): def forward(self, x, v): return torch.mv(x, v) - dense_vector = torch.ones(10) + net = SpMVNet() + dense_vector = torch.arange(1, 11, dtype=torch.float32) dense_input = torch.ones(10, 10) sparse_input = dense_input.to_sparse_bsr(blocksize=(2, 2)) - m = export_and_import(SpMVNet(), sparse_input, dense_vector) + m = export_and_import(net, sparse_input, dense_vector) print(m) + # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit. + res1 = net(sparse_input, dense_vector) + res2 = sparse_jit(net, sparse_input, dense_vector) + print("torch.sparse =", res1) + print("torch.mlir =", res2) + @run # CHECK-LABEL: test_sparse_SpMM @@ -264,15 +279,15 @@ def __init__(self): def forward(self, x, y): return torch.matmul(x, y) + net = MatMulNet() dense_input = torch.ones(8, 8) sparse_input = dense_input.to_sparse_coo() - m = export_and_import(MatMulNet(), sparse_input, dense_input) + m = export_and_import(net, sparse_input, dense_input) print(m) # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit. # TODO: run with COO, right now only CSR works sparse_input = dense_input.to_sparse_csr() - net = MatMulNet() res1 = net(sparse_input, dense_input) res2 = sparse_jit(net, sparse_input, dense_input) print("torch.sparse") @@ -311,6 +326,7 @@ def forward(self, x, y): # ... # CHECK: [-61. -62.] # CHECK: [-63. -64.]{{\]\]}} +# def test_sparse_eltwise(): class EltNet(torch.nn.Module): def __init__(self): @@ -319,18 +335,19 @@ def __init__(self): def forward(self, x): return -x + net = EltNet() dense_input = torch.reshape( torch.arange(1, 65, dtype=torch.float32), shape=(8, 4, 2) ) # This yields a **batched** CSR. sparse_input = dense_input.to_sparse_csr(dense_dim=0) - m = export_and_import(EltNet(), sparse_input) + m = export_and_import(net, sparse_input) print(m) # This yields a plain CSR with dense **sub**tensor sparse_input = dense_input.to_sparse_csr(dense_dim=1) - m = export_and_import(EltNet(), sparse_input) + m = export_and_import(net, sparse_input) print(m) # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit. @@ -339,7 +356,6 @@ def forward(self, x): # (1) since we do not propagate sparsity into elt-wise, MLIR returns dense result # (2) for dense_dim=0, this will need a dense(batched) property sparse_input = dense_input.to_sparse_csr(dense_dim=1) - net = EltNet() res1 = net(sparse_input) res2 = sparse_jit(net, sparse_input) print("torch.sparse")