From 88533b1968267169b654611f13550e5d42e288c9 Mon Sep 17 00:00:00 2001 From: Yuanqiang Liu Date: Thu, 11 Apr 2024 15:55:56 +0800 Subject: [PATCH] [Stablehlo] fix aten.arange's lowering to stablehlo (#3138) * promote to f64 to do division, avoid division on i64 (floor div) * refactor torch-to-stablehlo-pipeline --- lib/Conversion/TorchToStablehlo/Basic.cpp | 20 ++++++++++--------- .../TorchConversion/Transforms/Passes.cpp | 17 ++++++++++------ .../stablehlo_backends/linalg_on_tensors.py | 1 - 3 files changed, 22 insertions(+), 16 deletions(-) diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp index 3a6c5396b3f8..4d6c8d194554 100644 --- a/lib/Conversion/TorchToStablehlo/Basic.cpp +++ b/lib/Conversion/TorchToStablehlo/Basic.cpp @@ -1492,15 +1492,17 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Get length of the 1-d output tensor Value subOut = rewriter.create(loc, end, start); - Value divOut = rewriter.create(loc, subOut, step); - - Value resultLength = rewriter.create( - loc, RankedTensorType::get({1}, dtype), divOut); - if (dtype.isa()) { - resultLength = rewriter.create(loc, resultLength); - resultLength = rewriter.create( - loc, RankedTensorType::get({1}, rewriter.getI64Type()), resultLength); - } + // promote div to f64 + Type divType = RankedTensorType::get({}, rewriter.getF64Type()); + Value divOut = rewriter.create( + loc, rewriter.create(loc, divType, subOut), + rewriter.create(loc, divType, step)); + // ceil to i64 + Value resultLength = rewriter.create( + loc, RankedTensorType::get({}, rewriter.getI64Type()), + rewriter.create(loc, divOut)); + resultLength = rewriter.create( + loc, RankedTensorType::get({1}, rewriter.getI64Type()), resultLength); Value window = rewriter.create(loc, outType, resultLength, 0); diff --git a/lib/Dialect/TorchConversion/Transforms/Passes.cpp b/lib/Dialect/TorchConversion/Transforms/Passes.cpp index 7a887abca67a..5209e6683db3 100644 --- a/lib/Dialect/TorchConversion/Transforms/Passes.cpp +++ b/lib/Dialect/TorchConversion/Transforms/Passes.cpp @@ -142,11 +142,6 @@ void TorchConversion::createTorchBackendToStablehloBackendPipeline( // Lowering Chlo ops to Stablehlo pm.addNestedPass( stablehlo::createChloLegalizeToStablehloPass()); - // Canonicalize Stablehlo dynamic ops to static ops - pm.addNestedPass( - stablehlo::createStablehloCanonicalizeDynamismPass()); - pm.addNestedPass(createCanonicalizerPass()); - // Lowering remained ops to Arith pm.addNestedPass(createConvertTorchToArithPass()); @@ -162,7 +157,17 @@ void TorchConversion::createTorchBackendToStablehloBackendPipeline( pm.addNestedPass( TorchConversion::createFinalizingBackendTypeConversionPass()); - // Verify that we have lowered to Stablehlo and Chlo ops. + // Verify that we have lowered to Stablehlo ops. pm.addPass(TorchConversion::createVerifyStablehloBackendContractPass()); + + // Canonicalize Stablehlo dynamic ops to static ops + pm.addNestedPass( + stablehlo::createStablehloCanonicalizeDynamismPass()); + pm.addNestedPass(createCanonicalizerPass()); + pm.addPass(stablehlo::createStablehloRefineShapesPass()); + pm.addNestedPass(createCanonicalizerPass()); + pm.addNestedPass( + stablehlo::createStablehloCanonicalizeDynamismPass()); + pm.addNestedPass(createCanonicalizerPass()); } #endif diff --git a/projects/pt1/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py b/projects/pt1/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py index 4899549a8969..d9627a352c51 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py +++ b/projects/pt1/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py @@ -18,7 +18,6 @@ # The pipeline of func.func passes that lower the STABLEHLO backend contract to the # Linalg-on-Tensors backend contract accepted by RefBackend. STABLEHLO_TO_LINALG_FUNC_PIPELINE = ",".join([ - "canonicalize", "func.func(stablehlo-aggressive-simplification)", "stablehlo-legalize-to-linalg", "canonicalize"