Skip to content

Commit

Permalink
[torch-mlir][sparse] recognize to_dense primitive (llvm#3308)
Browse files Browse the repository at this point in the history
This op also maps simply to sparse_tensor.convert;
the sparsity types do the rest!
  • Loading branch information
aartbik authored May 9, 2024
1 parent 89bb740 commit a033bbf
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 3 deletions.
4 changes: 2 additions & 2 deletions lib/Conversion/TorchToLinalg/DataMovement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2451,8 +2451,8 @@ class ConvertSparseOperatorOp : public OpConversionPattern<OperatorOp> {
};
// Static initializer.
// Static initializer: the torch.operator names that
// ConvertSparseOperatorOp is allowed to legalize. Each of these layout
// conversions lowers to a single sparse_tensor.convert.
// NOTE(review): the scraped diff had concatenated the pre-image lines with
// the post-image lines, duplicating every entry; this is the post-commit
// list (which adds "torch.aten._to_dense").
SmallVector<StringRef> ConvertSparseOperatorOp::legalizedNames = {
    "torch.aten._to_dense", "torch.aten._to_sparse", "torch.aten._to_csr",
    "torch.aten._to_csc",   "torch.aten._to_bsr",    "torch.aten._to_bsc",
};
} // namespace

Expand Down
34 changes: 33 additions & 1 deletion test/Conversion/TorchToLinalg/sparse.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func.func @SpMM(%arg0: !torch.vtensor<[8,16],f32,#CSR>,
// CHECK-SAME: %[[A:.*]]: !torch.vtensor<[128,64,30,30,6],f32>)
// CHECK: %[[D:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[128,64,30,30,6],f32> -> tensor<128x64x30x30x6xf32>
// CHECK: %[[C:.*]] = sparse_tensor.convert %0 : tensor<128x64x30x30x6xf32> to tensor<128x64x30x30x6xf32, #[[$ST]]>
// CHECK: %[[R:.*]] = torch_c.from_builtin_tensor %[[C]] : tensor<128x64x30x30x6xf32, #[[$ST]]> -> !torch.vtensor<[128,64,30,30,6],f32,#[[$ST]]>
// CHECK: return %[[R]] : !torch.vtensor<[128,64,30,30,6],f32,#[[$ST]]>
func.func @activate(%arg0: !torch.vtensor<[128,64,30,30,6],f32>)
-> !torch.vtensor<[128,64,30,30,6],f32,#sparse> {
Expand All @@ -66,3 +66,35 @@ func.func @activate(%arg0: !torch.vtensor<[128,64,30,30,6],f32>)
-> !torch.vtensor<[128,64,30,30,6],f32,#sparse>
return %result : !torch.vtensor<[128,64,30,30,6],f32,#sparse>
}

// -----

// 5-D sparse encoding used by the @deactivate test below: dim 0 is
// compressed (nonunique) and dims 1-4 are singleton levels stored
// struct-of-arrays, with 64-bit position and coordinate widths.
#sparse = #sparse_tensor.encoding<{
  map = (d0, d1, d2, d3, d4) ->
  (d0 : compressed(nonunique),
   d1 : singleton(nonunique, soa),
   d2 : singleton(nonunique, soa),
   d3 : singleton(nonunique, soa),
   d4 : singleton(soa)
  ),
  posWidth = 64,
  crdWidth = 64
}>

// Test that the sparse-to-dense primitive torch.aten._to_dense lowers to a
// single sparse_tensor.convert from the encoded tensor type to the
// unencoded (dense) tensor type; the three none operands are ignored.
// CHECK: #[[$ST:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2, d3, d4) -> (d0 : compressed(nonunique), d1 : singleton(nonunique, soa), d2 : singleton(nonunique, soa), d3 : singleton(nonunique, soa), d4 : singleton(soa)), posWidth = 64, crdWidth = 64 }>
// CHECK-LABEL: func.func @deactivate(
// CHECK-SAME: %[[A:.*]]: !torch.vtensor<[128,64,30,30,6],f32,#[[$ST]]>)
// CHECK: %[[D:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[128,64,30,30,6],f32,#[[$ST]]> -> tensor<128x64x30x30x6xf32, #[[$ST]]>
// CHECK: %[[C:.*]] = sparse_tensor.convert %0 : tensor<128x64x30x30x6xf32, #[[$ST]]> to tensor<128x64x30x30x6xf32>
// CHECK: %[[R:.*]] = torch_c.from_builtin_tensor %[[C]] : tensor<128x64x30x30x6xf32> -> !torch.vtensor<[128,64,30,30,6],f32>
// CHECK: return %[[R]] : !torch.vtensor<[128,64,30,30,6],f32>
func.func @deactivate(%arg0: !torch.vtensor<[128,64,30,30,6],f32,#sparse>)
  -> !torch.vtensor<[128,64,30,30,6],f32> {
  // The aten._to_dense signature carries dtype/layout/etc. operands that are
  // all none here.
  %none_0 = torch.constant.none
  %none_1 = torch.constant.none
  %none_2 = torch.constant.none
  %result = torch.operator "torch.aten._to_dense"(%arg0, %none_0, %none_1, %none_2)
    : (!torch.vtensor<[128,64,30,30,6],f32,#sparse>, !torch.none, !torch.none, !torch.none)
    -> !torch.vtensor<[128,64,30,30,6],f32>
  return %result : !torch.vtensor<[128,64,30,30,6],f32>
}

0 comments on commit a033bbf

Please sign in to comment.