Skip to content

Commit

Permalink
[sparse] fix double free due to incompatibility between buffer-deallo… (
Browse files Browse the repository at this point in the history
llvm#3303)

…cation and sparse tensors.

**NOTE**: This PR _doges_ the issue in buffer-deallocation pass instead
of resolving it. In the future, we need to fix the bug in
buffer-deallocation pass when handling code generated by sparse
compiler.
  • Loading branch information
Peiming Liu authored May 9, 2024
1 parent 5213557 commit cff144b
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,8 @@ def invoke(*args):
"sparse-assembler{direct-out}",
"sparsification-and-bufferization",
"sparse-storage-specifier-to-llvm",
"inline", # inline sparse helper methods where useful
# Buffer deallocation pass does not know how to handle realloc.
"func.func(expand-realloc)",
# Bufferize.
"func.func(scf-bufferize)",
"func.func(tm-tensor-bufferize)",
Expand All @@ -167,6 +168,9 @@ def invoke(*args):
"func.func(tensor-bufferize)",
"func.func(finalizing-bufferize)",
"func.func(buffer-deallocation)",
# Buffer-deallocation does not work with the inlined code generated
# by sparse tensor dialect.
"inline", # inline sparse helper methods where useful
# Munge to make it ExecutionEngine compatible.
# Specifically, we rewrite calling convention boundaries to be in terms
# of unranked memref, and we rewrite the return to actually be a
Expand All @@ -180,7 +184,6 @@ def invoke(*args):
"func.func(tm-tensor-to-loops)",
"func.func(refback-munge-memref-copy)",
"func.func(convert-linalg-to-loops)",
"func.func(expand-realloc)",
"func.func(lower-affine)",
"convert-scf-to-cf",
"func.func(refback-expand-ops-for-llvm)",
Expand Down
31 changes: 19 additions & 12 deletions test/python/fx_importer/sparse_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,26 +364,30 @@ def forward(self, x, y):
# CHECK-LABEL: test_sparse_eltwise
# CHECK: #[[$CSRD:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : compressed, d2 : dense), posWidth = 64, crdWidth = 64 }>
# CHECK: func.func @main(
# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[8,4,2],f32,#[[$CSRD]]>) -> !torch.vtensor<[8,4,2],f32,#[[$CSRD]]> {
# CHECK: %[[R:.*]] = torch.aten.neg %[[A]] : !torch.vtensor<[8,4,2],f32,#[[$CSRD]]> -> !torch.vtensor<[8,4,2],f32,#[[$CSRD]]>
# CHECK: return %[[R]] : !torch.vtensor<[8,4,2],f32,#[[$CSRD]]>
# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[4,2,2],f32,#[[$CSRD]]>) -> !torch.vtensor<[4,2,2],f32,#[[$CSRD]]> {
# CHECK: %[[R:.*]] = torch.aten.neg %[[A]] : !torch.vtensor<[4,2,2],f32,#[[$CSRD]]> -> !torch.vtensor<[4,2,2],f32,#[[$CSRD]]>
# CHECK: return %[[R]] : !torch.vtensor<[4,2,2],f32,#[[$CSRD]]>
# CHECK: }
# CHECK: #[[$BCSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : batch, d1 : dense, d2 : compressed), posWidth = 64, crdWidth = 64 }>
# CHECK: func.func @main(
# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[8,4,2],f32,#[[$BCSR]]>) -> !torch.vtensor<[8,4,2],f32,#[[$BCSR]]> {
# CHECK: %[[R:.*]] = torch.aten.neg %[[A]] : !torch.vtensor<[8,4,2],f32,#[[$BCSR]]> -> !torch.vtensor<[8,4,2],f32,#[[$BCSR]]>
# CHECK: return %[[R]] : !torch.vtensor<[8,4,2],f32,#[[$BCSR]]>
# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[4,2,2],f32,#[[$BCSR]]>) -> !torch.vtensor<[4,2,2],f32,#[[$BCSR]]> {
# CHECK: %[[R:.*]] = torch.aten.neg %[[A]] : !torch.vtensor<[4,2,2],f32,#[[$BCSR]]> -> !torch.vtensor<[4,2,2],f32,#[[$BCSR]]>
# CHECK: return %[[R]] : !torch.vtensor<[4,2,2],f32,#[[$BCSR]]>
# CHECK: }
#
# CHECK: torch.sparse
# CHECK: tensor(crow_indices=tensor([ 0, 4, 8, 12, 16, 20, 24, 28, 32]),
# CHECK: col_indices=tensor([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1,
# CHECK: 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]),
# CHECK: tensor(crow_indices=tensor([0, 2, 4, 6, 8]),
# CHECK: col_indices=tensor([0, 1, 0, 1, 0, 1, 0, 1]),
# CHECK: values=tensor({{\[}}[ -1., -2.],
# ...
# CHECK: [-63., -64.]{{\]}}), size=(8, 4, 2), nnz=32,
# CHECK: [-15., -16.]{{\]}}), size=(4, 2, 2), nnz=8,
# CHECK: layout=torch.sparse_csr)
#
# CHECK: torch.mlir
# CHECK: [0 2 4 6 8]
# CHECK: [0 1 0 1 0 1 0 1]
# CHECK: [ -1. -2. -3. -4. -5. -6. -7. -8. -9. -10. -11. -12. -13. -14.
# CHECK: -15. -16.]
# CHECK: torch.mlir.batch
#
def test_sparse_eltwise():
Expand All @@ -396,7 +400,7 @@ def forward(self, x):

net = EltNet()
dense_input = torch.reshape(
torch.arange(1, 65, dtype=torch.float32), shape=(8, 4, 2)
torch.arange(1, 17, dtype=torch.float32), shape=(4, 2, 2)
)

# This yields a plain CSR with dense **sub**tensor
Expand All @@ -411,12 +415,15 @@ def forward(self, x):

# Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
res1 = net(sparse_input)
res2 = sparse_jit(net, sparse_input)
# TODO: make these work
# res2 = sparse_jit(net, sparse_input)
# res3 = sparse_jit(net, batch_input)
print("torch.sparse")
print(res1)
print("torch.mlir")
print(res2[0])
print(res2[1])
print(res2[2])
print("torch.mlir.batch")


Expand Down

0 comments on commit cff144b

Please sign in to comment.