From a033bbfe6c2c3a7bc66102a03ae760136d0112aa Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Wed, 8 May 2024 22:50:17 -0700 Subject: [PATCH] [torch-mlir][sparse] recognize to_dense primitive (#3308) also maps simply to sparse_tensor.convert the sparsity types do the rest! --- lib/Conversion/TorchToLinalg/DataMovement.cpp | 4 +-- test/Conversion/TorchToLinalg/sparse.mlir | 34 ++++++++++++++++++- 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index d034a8293463..67d13c5fb644 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -2451,8 +2451,8 @@ class ConvertSparseOperatorOp : public OpConversionPattern { }; // Static initializer. SmallVector ConvertSparseOperatorOp::legalizedNames = { - "torch.aten._to_sparse", "torch.aten._to_csr", "torch.aten._to_csc", - "torch.aten._to_bsr", "torch.aten._to_bsc", + "torch.aten._to_dense", "torch.aten._to_sparse", "torch.aten._to_csr", + "torch.aten._to_csc", "torch.aten._to_bsr", "torch.aten._to_bsc", }; } // namespace diff --git a/test/Conversion/TorchToLinalg/sparse.mlir b/test/Conversion/TorchToLinalg/sparse.mlir index 4dc580ea3164..f343aedf5545 100644 --- a/test/Conversion/TorchToLinalg/sparse.mlir +++ b/test/Conversion/TorchToLinalg/sparse.mlir @@ -54,7 +54,7 @@ func.func @SpMM(%arg0: !torch.vtensor<[8,16],f32,#CSR>, // CHECK-SAME: %[[A:.*]]: !torch.vtensor<[128,64,30,30,6],f32>) // CHECK: %[[D:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[128,64,30,30,6],f32> -> tensor<128x64x30x30x6xf32> // CHECK: %[[C:.*]] = sparse_tensor.convert %0 : tensor<128x64x30x30x6xf32> to tensor<128x64x30x30x6xf32, #[[$ST]]> -// CHECK: %[[R:.*]] = torch_c.from_builtin_tensor %[[C]] : tensor<128x64x30x30x6xf32, #[[$ST]]> +// CHECK: %[[R:.*]] = torch_c.from_builtin_tensor %[[C]] : tensor<128x64x30x30x6xf32, #[[$ST]]> -> !torch.vtensor<[128,64,30,30,6],f32,#[[$ST]]> // CHECK: return %[[R]] : !torch.vtensor<[128,64,30,30,6],f32,#[[$ST]]> func.func @activate(%arg0: !torch.vtensor<[128,64,30,30,6],f32>) -> !torch.vtensor<[128,64,30,30,6],f32,#sparse> { @@ -66,3 +66,35 @@ func.func @activate(%arg0: !torch.vtensor<[128,64,30,30,6],f32>) -> !torch.vtensor<[128,64,30,30,6],f32,#sparse> return %result : !torch.vtensor<[128,64,30,30,6],f32,#sparse> } + +// ----- + +#sparse = #sparse_tensor.encoding<{ + map = (d0, d1, d2, d3, d4) -> + (d0 : compressed(nonunique), + d1 : singleton(nonunique, soa), + d2 : singleton(nonunique, soa), + d3 : singleton(nonunique, soa), + d4 : singleton(soa) + ), + posWidth = 64, + crdWidth = 64 +}> + +// CHECK: #[[$ST:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2, d3, d4) -> (d0 : compressed(nonunique), d1 : singleton(nonunique, soa), d2 : singleton(nonunique, soa), d3 : singleton(nonunique, soa), d4 : singleton(soa)), posWidth = 64, crdWidth = 64 }> +// CHECK-LABEL: func.func @deactivate( +// CHECK-SAME: %[[A:.*]]: !torch.vtensor<[128,64,30,30,6],f32,#[[$ST]]>) +// CHECK: %[[D:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[128,64,30,30,6],f32,#[[$ST]]> -> tensor<128x64x30x30x6xf32, #[[$ST]]> +// CHECK: %[[C:.*]] = sparse_tensor.convert %0 : tensor<128x64x30x30x6xf32, #[[$ST]]> to tensor<128x64x30x30x6xf32> +// CHECK: %[[R:.*]] = torch_c.from_builtin_tensor %[[C]] : tensor<128x64x30x30x6xf32> -> !torch.vtensor<[128,64,30,30,6],f32> +// CHECK: return %[[R]] : !torch.vtensor<[128,64,30,30,6],f32> +func.func @deactivate(%arg0: !torch.vtensor<[128,64,30,30,6],f32,#sparse>) + -> !torch.vtensor<[128,64,30,30,6],f32> { + %none_0 = torch.constant.none + %none_1 = torch.constant.none + %none_2 = torch.constant.none + %result = torch.operator "torch.aten._to_dense"(%arg0, %none_0, %none_1, %none_2) + : (!torch.vtensor<[128,64,30,30,6],f32,#sparse>, !torch.none, !torch.none, !torch.none) + -> !torch.vtensor<[128,64,30,30,6],f32> + return %result : !torch.vtensor<[128,64,30,30,6],f32> +}