Skip to content

Commit

Permalink
[torch] Add folder for prim.NumToTensor.Scalar (llvm#2921)
Browse files Browse the repository at this point in the history
Useful for `slice` lowerings that depend on tensors made from scalars.
  • Loading branch information
rsuderman authored Feb 19, 2024
1 parent e80054a commit 135c81a
Show file tree
Hide file tree
Showing 7 changed files with 77 additions and 99 deletions.
1 change: 1 addition & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -15070,6 +15070,7 @@ def Torch_PrimNumToTensorScalarOp : Torch_Op<"prim.NumToTensor.Scalar", [
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
let hasFolder = 1;
}

def Torch_PrimMinSelfIntOp : Torch_Op<"prim.min.self_int", [
Expand Down
1 change: 0 additions & 1 deletion lib/Conversion/TorchToLinalg/Uncategorized.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,6 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
auto exp = b.create<math::ExpOp>(loc, negate);
auto added = b.create<arith::AddFOp>(loc, exp, one);
auto div = b.create<arith::DivFOp>(loc, one, added);
outTy.dump();
return convertScalarToDtype(b, loc, div, outTy, std::nullopt, outTTy);
}
if (auto relu = dyn_cast<AtenReluOp>(op)) {
Expand Down
24 changes: 24 additions & 0 deletions lib/Dialect/Torch/IR/TorchOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3641,6 +3641,30 @@ OpFoldResult PrimMaxIntOp::fold(FoldAdaptor adaptor) {
std::max(lhs.getValue().getSExtValue(), rhs.getValue().getSExtValue()));
}

//===----------------------------------------------------------------------===//
// PrimNumToTensorScalarOp
//===----------------------------------------------------------------------===//

// Fold `prim.NumToTensor.Scalar` of a constant scalar into a splat
// elements attribute, so downstream lowerings (e.g. `slice`) can see the
// constant tensor directly.
OpFoldResult PrimNumToTensorScalarOp::fold(FoldAdaptor adaptor) {
  // Only fold when the scalar operand is a known constant attribute.
  Attribute scalar = adaptor.getA();
  if (!scalar)
    return {};

  // A constant tensor can only be materialized when the result type is
  // fully specified (both element dtype and static sizes are present).
  auto tensorTy = cast<BaseTensorType>(getType());
  if (!tensorTy.hasSizes() || !tensorTy.hasDtype())
    return {};

  // Re-type the scalar attribute so it matches the tensor's element dtype
  // (the torch-level scalar may carry a wider/default storage type).
  auto elemTy = tensorTy.getDtype();
  if (auto intAttr = dyn_cast<IntegerAttr>(scalar)) {
    scalar = IntegerAttr::get(elemTy, intAttr.getInt());
  } else if (auto floatAttr = dyn_cast<FloatAttr>(scalar)) {
    scalar = FloatAttr::get(elemTy, floatAttr.getValueAsDouble());
  }

  // Build the splat over the equivalent builtin ranked tensor type.
  auto builtinTensorTy = RankedTensorType::get(tensorTy.getSizes(), elemTy);
  return SplatElementsAttr::get(builtinTensorTy, scalar);
}

//===----------------------------------------------------------------------===//
// PrimMinSelfIntOp
//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,7 @@
"AtenEyeModuleFloat2D_basic",
"AtenEyeModuleInt2D_basic",
"AtenFloatScalarModule_basic",
"AtenInstanceNormModule_basic",
"AtenIntBoolOpConstFalseModule_basic",
"AtenIntBoolOpConstTrueModule_basic",
"AtenIntBoolOpModule_basic",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -846,7 +846,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("prim::device : (Tensor) -> (Device)", has_canonicalizer=True)
emit("prim::dtype : (Tensor) -> (int)", has_folder=True)
emit("prim::TupleUnpack : (Any) -> (...)", has_canonicalizer=True)
emit("prim::NumToTensor.Scalar : (Scalar) -> (Tensor)")
emit("prim::NumToTensor.Scalar : (Scalar) -> (Tensor)", has_folder=True)
emit("prim::min.self_int : (int[]) -> (int)", has_folder=True)
emit("prim::min.int : (int, int) -> (int)", has_folder=True)
emit("prim::max.self_int : (int[]) -> (int)")
Expand Down
10 changes: 3 additions & 7 deletions test/Conversion/TorchToStablehlo/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,9 @@ func.func @torch.vtensor.literal$signed() -> !torch.vtensor<[2],si64> {

// CHECK-LABEL: func.func @torch.prim.NumToTensor.Scalar$basic(
// CHECK-SAME: ) -> !torch.vtensor<[],si64> {
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[T0:.*]] = torch_c.to_i64 %[[INT1]]
// CHECK: %[[T1:.*]] = tensor.from_elements %[[T0]] : tensor<1xi64>
// CHECK: %[[T2:.*]] = stablehlo.convert %[[T1]] : tensor<1xi64>
// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xi64>) -> tensor<i64>
// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<i64> -> !torch.vtensor<[],si64>
// CHECK: return %[[T4]] : !torch.vtensor<[],si64>
// CHECK: %[[CST:.*]] = stablehlo.constant dense<1> : tensor<i64>
// CHECK: %[[FROM:.*]] = torch_c.from_builtin_tensor %[[CST]] : tensor<i64> -> !torch.vtensor<[],si64>
// CHECK: return %[[FROM]] : !torch.vtensor<[],si64>
func.func @torch.prim.NumToTensor.Scalar$basic() -> !torch.vtensor<[], si64> {
%int1 = torch.constant.int 1
%0 = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[], si64>
Expand Down
Loading

0 comments on commit 135c81a

Please sign in to comment.