[torch-mlir][sparse] add JIT test for block sparse SpMV (#2955)
This required adding a "decompose" pass to the torch lowering, since
torch.mv was not handled directly by the lowering to linalg.
aartbik authored Feb 27, 2024
1 parent e30a083 commit 3021254
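
Editor's note on the change: the decompose pass is needed because aten.mv has no direct linalg lowering in this pipeline, but it can be expressed through matmul, which does lower. A minimal PyTorch-level sketch of that algebraic equivalence (illustrative only; this is not the MLIR pass itself):

    import torch

    x = torch.ones(10, 10)
    v = torch.arange(1, 11, dtype=torch.float32)

    # For a 2-D matrix and a 1-D vector, matmul computes exactly the
    # matrix-vector product, so mv can be decomposed away.
    assert torch.equal(torch.mv(x, v), torch.matmul(x, v))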
Showing 1 changed file with 26 additions and 10 deletions:
test/python/fx_importer/sparse_test.py
@@ -139,7 +139,11 @@ def sparse_jit(f, *args, **kwargs):
    module = export_and_import(f, *args, **kwargs)
    run_pipeline_with_repro_report(
        module,
-        "builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline)",
+        (
+            "builtin.module("
+            "func.func(torch-decompose-complex-ops),"
+            "torch-backend-to-linalg-on-tensors-backend-pipeline)"
+        ),
        "Lowering TorchFX IR -> Linalg IR",
        enable_ir_printing=False,
    )
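
Note that the new pipeline argument relies on Python's implicit concatenation of adjacent string literals; the pass manager receives a single pipeline string. A quick check of what is actually passed (a sketch, with no torch-mlir dependency):

    pipeline = (
        "builtin.module("
        "func.func(torch-decompose-complex-ops),"
        "torch-backend-to-linalg-on-tensors-backend-pipeline)"
    )
    # Adjacent string literals fuse into one pipeline specification.
    assert pipeline == (
        "builtin.module(func.func(torch-decompose-complex-ops),"
        "torch-backend-to-linalg-on-tensors-backend-pipeline)"
    )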
@@ -200,13 +204,13 @@ def __init__(self):
        def forward(self, x):
            return x.sum()

+    net = SumNet()
    dense_input = torch.ones(64, 64)
    sparse_input = dense_input.to_sparse_csr()
-    m = export_and_import(SumNet(), sparse_input)
+    m = export_and_import(net, sparse_input)
    print(m)

    # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
-    net = SumNet()
    res1 = net(sparse_input)
    res2 = sparse_jit(net, sparse_input)
    print("torch.sparse =", res1)
@@ -222,6 +226,10 @@ def forward(self, x):
# CHECK: %[[R:.*]] = torch.aten.mv %[[A]], %[[B]] : !torch.vtensor<[10,10],f32,#[[$BSR]]>, !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
# CHECK: return %[[R]] : !torch.vtensor<[10],f32>
# CHECK: }
+#
+# CHECK: torch.sparse = tensor([55., 55., 55., 55., 55., 55., 55., 55., 55., 55.])
+# CHECK: torch.mlir = [55. 55. 55. 55. 55. 55. 55. 55. 55. 55.]
+#
def test_sparse_SpMV():
    class SpMVNet(torch.nn.Module):
        def __init__(self):
@@ -230,12 +238,19 @@ def __init__(self):
        def forward(self, x, v):
            return torch.mv(x, v)

-    dense_vector = torch.ones(10)
+    net = SpMVNet()
+    dense_vector = torch.arange(1, 11, dtype=torch.float32)
    dense_input = torch.ones(10, 10)
    sparse_input = dense_input.to_sparse_bsr(blocksize=(2, 2))
-    m = export_and_import(SpMVNet(), sparse_input, dense_vector)
+    m = export_and_import(net, sparse_input, dense_vector)
    print(m)

+    # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
+    res1 = net(sparse_input, dense_vector)
+    res2 = sparse_jit(net, sparse_input, dense_vector)
+    print("torch.sparse =", res1)
+    print("torch.mlir =", res2)


@run
# CHECK-LABEL: test_sparse_SpMM
@@ -264,15 +279,15 @@ def __init__(self):
        def forward(self, x, y):
            return torch.matmul(x, y)

+    net = MatMulNet()
    dense_input = torch.ones(8, 8)
    sparse_input = dense_input.to_sparse_coo()
-    m = export_and_import(MatMulNet(), sparse_input, dense_input)
+    m = export_and_import(net, sparse_input, dense_input)
    print(m)

    # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
    # TODO: run with COO, right now only CSR works
    sparse_input = dense_input.to_sparse_csr()
-    net = MatMulNet()
    res1 = net(sparse_input, dense_input)
    res2 = sparse_jit(net, sparse_input, dense_input)
    print("torch.sparse")
@@ -311,6 +326,7 @@ def forward(self, x, y):
# ...
# CHECK: [-61. -62.]
# CHECK: [-63. -64.]{{\]\]}}
+#
def test_sparse_eltwise():
    class EltNet(torch.nn.Module):
        def __init__(self):
@@ -319,18 +335,19 @@ def __init__(self):
        def forward(self, x):
            return -x

+    net = EltNet()
    dense_input = torch.reshape(
        torch.arange(1, 65, dtype=torch.float32), shape=(8, 4, 2)
    )

    # This yields a **batched** CSR.
    sparse_input = dense_input.to_sparse_csr(dense_dim=0)
-    m = export_and_import(EltNet(), sparse_input)
+    m = export_and_import(net, sparse_input)
    print(m)

    # This yields a plain CSR with dense **sub**tensor
    sparse_input = dense_input.to_sparse_csr(dense_dim=1)
-    m = export_and_import(EltNet(), sparse_input)
+    m = export_and_import(net, sparse_input)
    print(m)

    # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
@@ -339,7 +356,6 @@ def forward(self, x):
    # (1) since we do not propagate sparsity into elt-wise, MLIR returns dense result
    # (2) for dense_dim=0, this will need a dense(batched) property
    sparse_input = dense_input.to_sparse_csr(dense_dim=1)
-    net = EltNet()
    res1 = net(sparse_input)
    res2 = sparse_jit(net, sparse_input)
    print("torch.sparse")
