From 135c81a4165f9e4c9070d72c485efece887d64f8 Mon Sep 17 00:00:00 2001
From: Rob Suderman
Date: Mon, 19 Feb 2024 11:55:54 -0800
Subject: [PATCH] [torch] Add folder for `prim.NumToTensor.Scalar` (#2921)

Useful for `slice` lowerings that depend on tensors made from scalars.
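An illustrative before/after of the new fold (this mirrors the
`fold_prim_numtotensor_scalar` test added in this patch; the snippet is
explanatory only and not part of the diff):

    %int42 = torch.constant.int 42
    %0 = torch.prim.NumToTensor.Scalar %int42 : !torch.int -> !torch.vtensor<[1],si64>

    // folds to:
    %0 = torch.vtensor.literal(dense<42> : tensor<1xsi64>) : !torch.vtensor<[1],si64>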
---
 .../Dialect/Torch/IR/GeneratedTorchOps.td    |   1 +
 .../TorchToLinalg/Uncategorized.cpp          |   1 -
 lib/Dialect/Torch/IR/TorchOps.cpp            |  24 +++
 projects/pt1/e2e_testing/xfail_sets.py       |   1 +
 .../build_tools/torch_ods_gen.py             |   2 +-
 test/Conversion/TorchToStablehlo/basic.mlir  |  10 +-
 test/Dialect/Torch/canonicalize.mlir         | 137 ++++++------------
 7 files changed, 77 insertions(+), 99 deletions(-)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 5e4662369caf..c5fec66913b0 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -15070,6 +15070,7 @@ def Torch_PrimNumToTensorScalarOp : Torch_Op<"prim.NumToTensor.Scalar", [
       printDefaultTorchOp(printer, *this, 1, 1);
     }
   }];
+  let hasFolder = 1;
 }
 
 def Torch_PrimMinSelfIntOp : Torch_Op<"prim.min.self_int", [
diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
index e8e671955835..08d69ca718b9 100644
--- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp
+++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -452,7 +452,6 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     auto exp = b.create<math::ExpOp>(loc, negate);
     auto added = b.create<arith::AddFOp>(loc, exp, one);
     auto div = b.create<arith::DivFOp>(loc, one, added);
-    outTy.dump();
     return convertScalarToDtype(b, loc, div, outTy, std::nullopt, outTTy);
   }
   if (auto relu = dyn_cast<AtenReluOp>(op)) {
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index 18c8501df3d0..36e089fb28d3 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -3641,6 +3641,30 @@ OpFoldResult PrimMaxIntOp::fold(FoldAdaptor adaptor) {
       std::max(lhs.getValue().getSExtValue(), rhs.getValue().getSExtValue()));
 }
 
+//===----------------------------------------------------------------------===//
+// PrimNumToTensorScalarOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult PrimNumToTensorScalarOp::fold(FoldAdaptor adaptor) {
+  Attribute a = adaptor.getA();
+  auto resultTy = cast<ValueTensorType>(getType());
+  if (!a)
+    return {};
+  if (!resultTy.hasDtype() || !resultTy.hasSizes())
+    return {};
+
+  auto dty = resultTy.getDtype();
+  if (auto iattr = dyn_cast<IntegerAttr>(a)) {
+    a = IntegerAttr::get(dty, iattr.getInt());
+  } else if (auto fattr = dyn_cast<FloatAttr>(a)) {
+    a = FloatAttr::get(dty, fattr.getValueAsDouble());
+  }
+
+  auto mlirTensorType =
+      RankedTensorType::get(resultTy.getSizes(), resultTy.getDtype());
+  return SplatElementsAttr::get(mlirTensorType, a);
+}
+
 //===----------------------------------------------------------------------===//
 // PrimMinSelfIntOp
 //===----------------------------------------------------------------------===//
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 7de8047fa98a..52e1ea3321b8 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -419,6 +419,7 @@
     "AtenEyeModuleFloat2D_basic",
     "AtenEyeModuleInt2D_basic",
     "AtenFloatScalarModule_basic",
+    "AtenInstanceNormModule_basic",
     "AtenIntBoolOpConstFalseModule_basic",
     "AtenIntBoolOpConstTrueModule_basic",
     "AtenIntBoolOpModule_basic",
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index bfbebf86be0b..64f03add759e 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -846,7 +846,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("prim::device : (Tensor) -> (Device)", has_canonicalizer=True)
     emit("prim::dtype : (Tensor) -> (int)", has_folder=True)
     emit("prim::TupleUnpack : (Any) -> (...)", has_canonicalizer=True)
-    emit("prim::NumToTensor.Scalar : (Scalar) -> (Tensor)")
+    emit("prim::NumToTensor.Scalar : (Scalar) -> (Tensor)", has_folder=True)
     emit("prim::min.self_int : (int[]) -> (int)", has_folder=True)
     emit("prim::min.int : (int, int) -> (int)", has_folder=True)
     emit("prim::max.self_int : (int[]) -> (int)")
diff --git a/test/Conversion/TorchToStablehlo/basic.mlir b/test/Conversion/TorchToStablehlo/basic.mlir
index b502d3ffcce9..5f096205ea8c 100644
--- a/test/Conversion/TorchToStablehlo/basic.mlir
+++ b/test/Conversion/TorchToStablehlo/basic.mlir
@@ -27,13 +27,9 @@ func.func @torch.vtensor.literal$signed() -> !torch.vtensor<[2],si64> {
 
 // CHECK-LABEL: func.func @torch.prim.NumToTensor.Scalar$basic(
 // CHECK-SAME: ) -> !torch.vtensor<[],si64> {
-// CHECK: %[[INT1:.*]] = torch.constant.int 1
-// CHECK: %[[T0:.*]] = torch_c.to_i64 %[[INT1]]
-// CHECK: %[[T1:.*]] = tensor.from_elements %[[T0]] : tensor<1xi64>
-// CHECK: %[[T2:.*]] = stablehlo.convert %[[T1]] : tensor<1xi64>
-// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xi64>) -> tensor<i64>
-// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<i64> -> !torch.vtensor<[],si64>
-// CHECK: return %[[T4]] : !torch.vtensor<[],si64>
+// CHECK: %[[CST:.*]] = stablehlo.constant dense<1> : tensor<i64>
+// CHECK: %[[FROM:.*]] = torch_c.from_builtin_tensor %[[CST]] : tensor<i64> -> !torch.vtensor<[],si64>
+// CHECK: return %[[FROM]] : !torch.vtensor<[],si64>
 func.func @torch.prim.NumToTensor.Scalar$basic() -> !torch.vtensor<[], si64> {
   %int1 = torch.constant.int 1
   %0 = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[], si64>
diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir
index bb57135075bd..4df52cfb174b 100644
--- a/test/Dialect/Torch/canonicalize.mlir
+++ b/test/Dialect/Torch/canonicalize.mlir
@@ -1687,13 +1687,8 @@ func.func @torch.aten.Bool.int$fold_cst() -> !torch.bool {
 }
 
 // CHECK-LABEL: func.func @torch.aten.add.Tensor$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
-// CHECK: %[[INT6:.*]] = torch.constant.int 6
-// CHECK: %[[INT0:.*]] = torch.constant.int 0
-// CHECK: %[[INT2:.*]] = torch.constant.int 2
-// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT0]] : !torch.int -> !torch.vtensor<[],si64>
-// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT2]] : !torch.int -> !torch.vtensor<[],si64>
-// CHECK: %[[PR3:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64>
-// CHECK: return %[[PR3]] : !torch.vtensor<[],si64>
+// CHECK: %[[CST:.*]] = torch.vtensor.literal(dense<6> : tensor<si64>) : !torch.vtensor<[],si64>
+// CHECK: return %[[CST]] : !torch.vtensor<[],si64>
 func.func @torch.aten.add.Tensor$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
   %int0 = torch.constant.int 0
   %int2 = torch.constant.int 2
@@ -1705,11 +1700,8 @@ func.func @torch.aten.add.Tensor$canonicalize_numtotensor_0d() -> !torch.vtensor
 }
 
 // CHECK-LABEL: @torch.aten.add.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
-// CHECK: %[[INT6:.*]] = torch.constant.int 6
-// CHECK: %[[INT2:.*]] = torch.constant.int 2
-// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT2]] : !torch.int -> !torch.vtensor<[],si64>
-// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64>
-// CHECK: return %[[PR2]] : !torch.vtensor<[],si64>
+// CHECK: %[[CST:.*]] = torch.vtensor.literal(dense<6> : tensor<si64>) : !torch.vtensor<[],si64>
+// CHECK: return %[[CST]] : !torch.vtensor<[],si64>
 func.func @torch.aten.add.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
   %int0 = torch.constant.int 0
   %int2 = torch.constant.int 2
@@ -1760,11 +1752,8 @@ func.func @prim.ListUnpack$fold_list(%arg0: !torch.vtensor<[2,3],f32>, %arg1: !t
 }
 
 // CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
-// CHECK: %[[INT3:.*]] = torch.constant.int 3
-// CHECK: %[[INT6:.*]] = torch.constant.int 6
-// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64>
-// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT3]] : !torch.int -> !torch.vtensor<[],si64>
-// CHECK: return %[[PR2]] : !torch.vtensor<[],si64>
+// CHECK: %[[CST:.*]] = torch.vtensor.literal(dense<3> : tensor<si64>) : !torch.vtensor<[],si64>
+// CHECK: return %[[CST]] : !torch.vtensor<[],si64>
 func.func @torch.aten.div.Tensor_mode$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
   %int6 = torch.constant.int 6
   %str = torch.constant.str "floor"
@@ -1775,13 +1764,8 @@ func.func @torch.aten.div.Tensor_mode$canonicalize_literal_0d() -> !torch.vtenso
 }
 
 // CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
-// CHECK: %[[INT3:.*]] = torch.constant.int 3
-// CHECK: %[[INT6:.*]] = torch.constant.int 6
-// CHECK: %[[INT2:.*]] = torch.constant.int 2
-// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT2]] : !torch.int -> !torch.vtensor<[],si64>
-// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64>
-// CHECK: %[[PR3:.*]] = torch.prim.NumToTensor.Scalar %[[INT3]] : !torch.int -> !torch.vtensor<[],si64>
-// CHECK: return %[[PR3]] : !torch.vtensor<[],si64>
+// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<3> : tensor<si64>) : !torch.vtensor<[],si64>
+// CHECK: return %[[CST]] : !torch.vtensor<[],si64>
 func.func @torch.aten.div.Tensor_mode$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
   %int6 = torch.constant.int 6
   %int2 = torch.constant.int 2
@@ -1793,11 +1777,8 @@ func.func @torch.aten.div.Tensor_mode$canonicalize_numtotensor_0d() -> !torch.vt
 }
 
 // CHECK-LABEL: func.func @torch.aten.add.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
-// CHECK: %[[INT6:.*]] = torch.constant.int 6
-// CHECK: %[[INT0:.*]] = torch.constant.int 0
-// CHECK: %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT0]] : !torch.int -> !torch.vtensor<[],si64>
-// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64>
-// CHECK: return %[[PR1]] : !torch.vtensor<[],si64>
+// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<6> : tensor<si64>) : !torch.vtensor<[],si64>
+// CHECK: return %[[CST]] : !torch.vtensor<[],si64>
 func.func @torch.aten.add.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
   %int0 = torch.constant.int 0
   %int2 = torch.constant.int 2
@@ -1808,9 +1789,8 @@ func.func @torch.aten.add.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor
 }
 
 // CHECK-LABEL: func.func @torch.aten.add.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
-// CHECK: %[[INT6:.*]] = torch.constant.int 6
-// CHECK: %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64>
-// CHECK: return %[[PR0]] : !torch.vtensor<[],si64>
+// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<6> : tensor<si64>) : !torch.vtensor<[],si64>
+// CHECK: return %[[CST]] : !torch.vtensor<[],si64>
 func.func @torch.aten.add.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
   %int2 = torch.constant.int 2
   %int3 = torch.constant.int 3
@@ -1820,13 +1800,8 @@ func.func @torch.aten.add.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],
 }
 
 // CHECK-LABEL: func.func @torch.aten.sub.Tensor$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
-// CHECK: %[[INT_6:.*]] = torch.constant.int -6
-// CHECK: %[[INT0:.*]] = torch.constant.int 0
-// CHECK: %[[INT2:.*]] = torch.constant.int 2
-// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT0]] : !torch.int -> !torch.vtensor<[],si64>
-// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT2]] : !torch.int -> !torch.vtensor<[],si64>
-// CHECK: %[[PR3:.*]] = torch.prim.NumToTensor.Scalar %[[INT_6]] : !torch.int -> !torch.vtensor<[],si64>
-// CHECK: return %[[PR3]] : !torch.vtensor<[],si64>
+// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<-6> : tensor<si64>) : !torch.vtensor<[],si64>
+// CHECK: return %[[CST]] : !torch.vtensor<[],si64>
 func.func @torch.aten.sub.Tensor$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
   %int0 = torch.constant.int 0
   %int2 = torch.constant.int 2
@@ -1838,11 +1813,8 @@ func.func @torch.aten.sub.Tensor$canonicalize_numtotensor_0d() -> !torch.vtensor
 }
 
 // CHECK-LABEL: @torch.aten.sub.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
-// CHECK: %[[INT_6:.*]] = torch.constant.int -6
-// CHECK: %[[INT2:.*]] = torch.constant.int 2
-// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT2]] : !torch.int -> !torch.vtensor<[],si64>
-// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT_6]] : !torch.int -> !torch.vtensor<[],si64>
-// CHECK: return %[[PR2]] : !torch.vtensor<[],si64>
+// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<-6> : tensor<si64>) : !torch.vtensor<[],si64>
+// CHECK: return %[[CST]]
 func.func @torch.aten.sub.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
   %int0 = torch.constant.int 0
   %int2 = torch.constant.int 2
@@ -1854,11 +1826,8 @@ func.func @torch.aten.sub.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],
 }
 
 // CHECK-LABEL: func.func @torch.aten.sub.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
-// CHECK: %[[INT_6:.*]] = torch.constant.int -6
-// CHECK: %[[INT0:.*]] = torch.constant.int 0
-// CHECK: %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT0]] : !torch.int -> !torch.vtensor<[],si64>
-// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT_6]] : !torch.int -> !torch.vtensor<[],si64>
-// CHECK: return %[[PR1]] : !torch.vtensor<[],si64>
+// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<-6> : tensor<si64>) : !torch.vtensor<[],si64>
+// CHECK: return %[[CST]] : !torch.vtensor<[],si64>
 func.func @torch.aten.sub.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
   %int0 = torch.constant.int 0
   %int2 = torch.constant.int 2
@@ -1869,9 +1838,8 @@ func.func @torch.aten.sub.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor
 }
 
 // CHECK-LABEL: func.func @torch.aten.sub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
-// CHECK: %[[INT_6:.*]] = torch.constant.int -6
-// CHECK: %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT_6]] : !torch.int -> !torch.vtensor<[],si64>
-// CHECK: return %[[PR0]] : !torch.vtensor<[],si64>
+// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<-6> : tensor<si64>) : !torch.vtensor<[],si64>
+// CHECK: return %[[CST]] : !torch.vtensor<[],si64>
 func.func @torch.aten.sub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
   %int2 = torch.constant.int 2
   %int3 = torch.constant.int 3
@@ -1891,9 +1859,8 @@ func.func @torch.aten.sub.float$fold() -> !torch.float {
 }
 
 // CHECK-LABEL: func.func @torch.aten.mul.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
-// CHECK: %[[INT6]] = torch.constant.int 6
-// CHECK: %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64>
-// CHECK: return %[[PR0]] : !torch.vtensor<[],si64>
+// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<6> : tensor<si64>) : !torch.vtensor<[],si64>
+// CHECK: return %[[CST]] : !torch.vtensor<[],si64>
 func.func @torch.aten.mul.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
   %int3 = torch.constant.int 3
   %0 = torch.vtensor.literal(dense<2> : tensor<si64>) : !torch.vtensor<[],si64>
@@ -1902,11 +1869,8 @@ func.func @torch.aten.mul.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],
 }
 
 // CHECK-LABEL: func.func @torch.aten.mul.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
-// CHECK: %[[INT6:.*]] = torch.constant.int 6
-// CHECK: %[[INT2:.*]] = torch.constant.int 2
-// CHECK: %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT2]] : !torch.int -> !torch.vtensor<[],si64>
-// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64>
-// CHECK: return %[[PR1]] : !torch.vtensor<[],si64>
+// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<6> : tensor<si64>) : !torch.vtensor<[],si64>
+// CHECK: return %[[CST]] : !torch.vtensor<[],si64>
 func.func @torch.aten.mul.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
   %int2 = torch.constant.int 2
   %int3 = torch.constant.int 3
@@ -1916,8 +1880,8 @@ func.func @torch.aten.mul.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor
 }
 
 // CHECK-LABEL: func.func @torch.aten.mul.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
-// CHECK: %[[INT6:.+]] = torch.vtensor.literal(dense<6> : tensor<si64>) : !torch.vtensor<[],si64>
-// CHECK: return %[[INT6]] : !torch.vtensor<[],si64>
+// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<6> : tensor<si64>) : !torch.vtensor<[],si64>
+// CHECK: return %[[CST]] : !torch.vtensor<[],si64>
 func.func @torch.aten.mul.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
   %0 = torch.vtensor.literal(dense<2> : tensor<si64>) : !torch.vtensor<[],si64>
   %1 = torch.vtensor.literal(dense<3> : tensor<si64>) : !torch.vtensor<[],si64>
@@ -1926,13 +1890,8 @@ func.func @torch.aten.mul.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],
 }
 
 // CHECK-LABEL: func.func @torch.aten.mul.Tensor$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
-// CHECK: %[[INT6:.*]] = torch.constant.int 6
-// CHECK: %[[INT2:.*]] = torch.constant.int 2
-// CHECK: %[[INT3:.*]] = torch.constant.int 3
-// CHECK: %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT2]] : !torch.int -> !torch.vtensor<[],si64>
-// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT3]] : !torch.int -> !torch.vtensor<[],si64>
-// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64>
-// CHECK: return %[[PR2]] : !torch.vtensor<[],si64>
+// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<6> : tensor<si64>) : !torch.vtensor<[],si64>
+// CHECK: return %[[CST]] : !torch.vtensor<[],si64>
 func.func @torch.aten.mul.Tensor$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
   %int2 = torch.constant.int 2
   %int3 = torch.constant.int 3
@@ -1943,13 +1902,8 @@ func.func @torch.aten.mul.Tensor$canonicalize_numtotensor_0d() -> !torch.vtensor
 }
 
 // CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$canonicalize_numtotensor_0d_trunc() -> !torch.vtensor<[],si64> {
-// CHECK: %[[INT3:.*]] = torch.constant.int 3
-// CHECK: %[[INT6:.*]] = torch.constant.int 6
-// CHECK: %[[INT2:.*]] = torch.constant.int 2
-// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT2]] : !torch.int -> !torch.vtensor<[],si64>
-// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64>
-// CHECK: %[[PR3:.*]] = torch.prim.NumToTensor.Scalar %[[INT3]] : !torch.int -> !torch.vtensor<[],si64>
-// CHECK: return %[[PR3]] : !torch.vtensor<[],si64>
+// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<3> : tensor<si64>) : !torch.vtensor<[],si64>
+// CHECK: return %[[CST]] : !torch.vtensor<[],si64>
 func.func @torch.aten.div.Tensor_mode$canonicalize_numtotensor_0d_trunc() -> !torch.vtensor<[],si64> {
   %int6 = torch.constant.int 6
   %int2 = torch.constant.int 2
@@ -1961,11 +1915,8 @@ func.func @torch.aten.div.Tensor_mode$canonicalize_numtotensor_0d_trunc() -> !to
 }
 
 // CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$canonicalize_literal_0d_trunc() -> !torch.vtensor<[],si64> {
-// CHECK: %[[INT3:.*]] = torch.constant.int 3
-// CHECK: %[[INT6:.*]] = torch.constant.int 6
-// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64>
-// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT3]] : !torch.int -> !torch.vtensor<[],si64>
-// CHECK: return %[[PR2]] : !torch.vtensor<[],si64>
+// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<3> : tensor<si64>) : !torch.vtensor<[],si64>
+// CHECK: return %[[CST]] : !torch.vtensor<[],si64>
 func.func @torch.aten.div.Tensor_mode$canonicalize_literal_0d_trunc() -> !torch.vtensor<[],si64> {
   %int6 = torch.constant.int 6
   %str = torch.constant.str "trunc"
@@ -2151,9 +2102,8 @@ func.func @torch.aten.slice.tensor$fold_dim_0() -> (!torch.vtensor<[1, 1],f32>,
 
 // CHECK-LABEL: func.func @torch.aten.rsub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
-// CHECK: %int-1 = torch.constant.int -1
-// CHECK: %[[VAL_0:.*]] = torch.prim.NumToTensor.Scalar %int-1 : !torch.int -> !torch.vtensor<[],si64>
-// CHECK: return %[[VAL_0]] : !torch.vtensor<[],si64>
+// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<-1> : tensor<si64>) : !torch.vtensor<[],si64>
+// CHECK: return %[[CST]] : !torch.vtensor<[],si64>
 func.func @torch.aten.rsub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
   %int2 = torch.constant.int 2
   %int3 = torch.constant.int 3
@@ -2163,11 +2113,8 @@ func.func @torch.aten.rsub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[]
 }
 
 // CHECK-LABEL: func.func @torch.aten.rsub.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
-// CHECK: %int-1 = torch.constant.int -1
-// CHECK: %int1 = torch.constant.int 1
-// CHECK: %[[VAL_0:.*]] = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[],si64>
-// CHECK: %[[VAL_1:.*]] = torch.prim.NumToTensor.Scalar %int-1 : !torch.int -> !torch.vtensor<[],si64>
-// CHECK: return %[[VAL_1]] : !torch.vtensor<[],si64>
+// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<-1> : tensor<si64>) : !torch.vtensor<[],si64>
+// CHECK: return %[[CST]] : !torch.vtensor<[],si64>
 func.func @torch.aten.rsub.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
   %int1 = torch.constant.int 1
   %int2 = torch.constant.int 2
@@ -2179,7 +2126,6 @@ func.func @torch.aten.rsub.Scalar$canonicalize_numtotensor_0d() -> !torch.vtenso
 
 // CHECK-LABEL: func.func @torch.aten.ScalarImplicit$canonicalize_numtotensor_0d() -> !torch.number {
 // CHECK: %int1 = torch.constant.int 1
-// CHECK: %[[VAL_0:.*]] = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[],si64>
 // CHECK: %[[VAL_1:.*]] = torch.derefine %int1 : !torch.int to !torch.number
 // CHECK: return %[[VAL_1]] : !torch.number
 func.func @torch.aten.ScalarImplicit$canonicalize_numtotensor_0d() -> !torch.number {
@@ -2347,6 +2293,17 @@ func.func @fold_aten_where_true_attr() -> !torch.vtensor<[4],si64> {
 
 // -----
 
+// CHECK-LABEL: @fold_prim_numtotensor_scalar
+func.func @fold_prim_numtotensor_scalar() -> !torch.vtensor<[1],si64> {
+  %int42 = torch.constant.int 42
+  // CHECK: %[[TENSOR:.+]] = torch.vtensor.literal(dense<42> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
+  // CHECK: return %[[TENSOR]]
+  %0 = torch.prim.NumToTensor.Scalar %int42 : !torch.int -> !torch.vtensor<[1],si64>
+  return %0 : !torch.vtensor<[1],si64>
+}
+
+// -----
+
 // CHECK-LABEL: @fold_aten_where_false_attr
 func.func @fold_aten_where_false_attr() -> !torch.vtensor<[4],si64> {
   // CHECK: %[[RET:.+]] = torch.vtensor.literal(dense<11> : tensor<4xsi64>) : !torch.vtensor<[4],si64>