Skip to content

Commit

Permalink
[Stablehlo] fix aten.arange's lowering to stablehlo (llvm#3138)
Browse files Browse the repository at this point in the history
* promote to f64 to do division, avoid division on i64 (floor div)
* refactor torch-to-stablehlo-pipeline
  • Loading branch information
qingyunqu authored Apr 11, 2024
1 parent aa5e150 commit 88533b1
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 16 deletions.
20 changes: 11 additions & 9 deletions lib/Conversion/TorchToStablehlo/Basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1492,15 +1492,17 @@ LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(

// Get length of the 1-d output tensor
Value subOut = rewriter.create<stablehlo::SubtractOp>(loc, end, start);
Value divOut = rewriter.create<stablehlo::DivOp>(loc, subOut, step);

Value resultLength = rewriter.create<stablehlo::ReshapeOp>(
loc, RankedTensorType::get({1}, dtype), divOut);
if (dtype.isa<mlir::FloatType>()) {
resultLength = rewriter.create<stablehlo::CeilOp>(loc, resultLength);
resultLength = rewriter.create<stablehlo::ConvertOp>(
loc, RankedTensorType::get({1}, rewriter.getI64Type()), resultLength);
}
// promote div to f64
Type divType = RankedTensorType::get({}, rewriter.getF64Type());
Value divOut = rewriter.create<stablehlo::DivOp>(
loc, rewriter.create<stablehlo::ConvertOp>(loc, divType, subOut),
rewriter.create<stablehlo::ConvertOp>(loc, divType, step));
// ceil in f64, then convert the result to i64
Value resultLength = rewriter.create<stablehlo::ConvertOp>(
loc, RankedTensorType::get({}, rewriter.getI64Type()),
rewriter.create<stablehlo::CeilOp>(loc, divOut));
resultLength = rewriter.create<stablehlo::ReshapeOp>(
loc, RankedTensorType::get({1}, rewriter.getI64Type()), resultLength);

Value window =
rewriter.create<stablehlo::DynamicIotaOp>(loc, outType, resultLength, 0);
Expand Down
17 changes: 11 additions & 6 deletions lib/Dialect/TorchConversion/Transforms/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,11 +142,6 @@ void TorchConversion::createTorchBackendToStablehloBackendPipeline(
// Lowering Chlo ops to Stablehlo
pm.addNestedPass<func::FuncOp>(
stablehlo::createChloLegalizeToStablehloPass());
// Canonicalize Stablehlo dynamic ops to static ops
pm.addNestedPass<func::FuncOp>(
stablehlo::createStablehloCanonicalizeDynamismPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());

// Lowering remaining ops to Arith
pm.addNestedPass<func::FuncOp>(createConvertTorchToArithPass());

Expand All @@ -162,7 +157,17 @@ void TorchConversion::createTorchBackendToStablehloBackendPipeline(
pm.addNestedPass<func::FuncOp>(
TorchConversion::createFinalizingBackendTypeConversionPass());

// Verify that we have lowered to Stablehlo and Chlo ops.
// Verify that we have lowered to Stablehlo ops.
pm.addPass(TorchConversion::createVerifyStablehloBackendContractPass());

// Canonicalize Stablehlo dynamic ops to static ops
pm.addNestedPass<func::FuncOp>(
stablehlo::createStablehloCanonicalizeDynamismPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
pm.addPass(stablehlo::createStablehloRefineShapesPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>(
stablehlo::createStablehloCanonicalizeDynamismPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
}
#endif
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
# The pipeline of func.func passes that lower the STABLEHLO backend contract to the
# Linalg-on-Tensors backend contract accepted by RefBackend.
STABLEHLO_TO_LINALG_FUNC_PIPELINE = ",".join([
"canonicalize",
"func.func(stablehlo-aggressive-simplification)",
"stablehlo-legalize-to-linalg",
"canonicalize"
Expand Down

0 comments on commit 88533b1

Please sign in to comment.