Skip to content

Commit

Permalink
[StableHLO] Fix aten.clamp.Tensor in FxImporter2StableHLO (llvm#3190)
Browse files Browse the repository at this point in the history
The FX importer will pass static shapes to the Torch dialect, so it
needs to generate a StableHLO that satisfies shape inference.
  • Loading branch information
penguin-wwy authored Apr 19, 2024
1 parent 0a60734 commit 5a98c72
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
4 changes: 4 additions & 0 deletions lib/Conversion/TorchToStablehlo/Basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1500,6 +1500,10 @@ LogicalResult ConvertAtenOp<AtenClampTensorOp>::matchAndRewrite(
}
maxValue = *maxInfo;
}
if (inputType.hasStaticShape()) {
minValue = hlo::promoteAndBroadcast(rewriter, minValue, inputType);
maxValue = hlo::promoteAndBroadcast(rewriter, maxValue, inputType);
}
rewriter.replaceOpWithNewOp<stablehlo::ClampOp>(op, minValue, input,
maxValue);
return success();
Expand Down
6 changes: 1 addition & 5 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,17 +642,12 @@
"ElementwiseBitwiseRightShiftInt32Module_basic",
"ElementwiseBitwiseRightShiftInt64Module_basic",
"ElementwiseBitwiseRightShiftInt8Module_basic",
"ElementwiseClampMinTensorFloatModule_basic",
"ElementwiseClampMinTensorIntModule_basic",
"ElementwiseClampTensorFloatModule_basic",
"ElementwiseClampTensorIntModule_basic",
"ElementwiseCosIntModule_basic",
"ElementwiseCoshIntModule_basic",
"ElementwiseCoshModule_basic",
"ElementwiseDequantizePerChannelModule_basic",
"ElementwiseDequantizePerTensorModule_basic",
"ElementwiseErfIntModule_basic",
"ElementwiseExpIntModule_basic",
"ElementwiseExpm1IntModule_basic",
"ElementwiseExpm1Module_basic",
"ElementwiseFmodTensor_Float_basic",
Expand Down Expand Up @@ -734,6 +729,7 @@
"IndexPutImpl3DFloatAccumulateModule_basic",
"IndexPutImpl3DFloatNonAccumulateModule_basic",
"IndexPutImplIndexWithNoneModule_basic",
"IndexSelectRank0IdxModule_basic",
"IndexTensorNegativeIndexModule_basic",
"IntFloatModule_basic",
"IntImplicitModule_basic",
Expand Down

0 comments on commit 5a98c72

Please sign in to comment.