diff --git a/.github/workflows/RollPyTorch.yml b/.github/workflows/RollPyTorch.yml index 51f3f874b065..5c8d74ee0941 100644 --- a/.github/workflows/RollPyTorch.yml +++ b/.github/workflows/RollPyTorch.yml @@ -24,9 +24,21 @@ jobs: - name: Get torch-mlir uses: actions/checkout@v3 with: - submodules: 'true' + submodules: 'false' token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} + - name: Get LLVM and StableHlo submodules + run: | + set -eo pipefail + cd ${GITHUB_WORKSPACE} + + # Fetching the submodules concurrently may cause problems, so we fetch + # them one after another. + rm -f .git/modules/externals/llvm-project/index.lock + rm -f .git/modules/externals/stablehlo/index.lock + git submodule update --init --recursive externals/llvm-project + git submodule update --init --recursive externals/stablehlo + - name: Setup ccache uses: ./.github/actions/setup-build with: @@ -71,15 +83,14 @@ jobs: echo "PTVISION_RELEASE=${VISION_RELEASE}" >> ${GITHUB_ENV} echo "PT_HASH_CHANGED=${PT_HASH_CHANGED}" >> ${GITHUB_ENV} - - name: Build and test (in-tree), also update ODS and abstract interpretation library + - name: Build and test (out-of-tree), also update ODS and abstract interpretation library if: env.PT_HASH_CHANGED != '0' run: | cd ${GITHUB_WORKSPACE} - TM_PACKAGES="in-tree" TM_USE_PYTORCH_BINARY="OFF" \ + TM_PACKAGES="out-of-tree" TM_USE_PYTORCH_BINARY="OFF" \ TORCH_MLIR_SRC_PYTORCH_BRANCH="${{ env.PT_HASH }}" \ TORCH_MLIR_SRC_PYTORCH_RELEASE="${{ env.PT_RELEASE }}" \ TM_UPDATE_ODS_AND_ABSTRACT_INTERP_LIB="ON" \ - TM_PYTHON_VERSIONS="cp311-cp311" \ ./build_tools/python_deploy/build_linux_packages.sh - name: Post issue comment on build failure diff --git a/.github/workflows/buildRelease.yml b/.github/workflows/buildRelease.yml index 1af748879e43..1a9ce3fb3ca9 100644 --- a/.github/workflows/buildRelease.yml +++ b/.github/workflows/buildRelease.yml @@ -115,7 +115,7 @@ jobs: cd $GITHUB_WORKSPACE TM_PACKAGE_VERSION=${{ github.event.inputs.python_package_version }} printf "TORCH_MLIR_PYTHON_PACKAGE_VERSION=%s\n" $TM_PACKAGE_VERSION > ./torch_mlir_package_version - TM_PYTHON_VERSIONS=${{ matrix.py_version }} TM_PACKAGES=${{ matrix.package }} TM_TORCH_VERSION="stable" ./build_tools/python_deploy/build_linux_packages.sh + TM_PYTHON_VERSIONS=${{ matrix.py_version }} TM_PACKAGES=${{ matrix.package }} TORCH_MLIR_ENABLE_LTC='0' ./build_tools/python_deploy/build_linux_packages.sh # If we were given a release_id, then upload the package we just built # to the github releases page. diff --git a/README.md b/README.md index 1d6d448cbdaf..06c314f4423e 100644 --- a/README.md +++ b/README.md @@ -38,8 +38,17 @@ We have few paths to lower down to the Torch MLIR Dialect. - `#torch-mlir` channel on the LLVM [Discord](https://discord.gg/xS7Z362) - this is the most active communication channel - Github issues [here](https://github.com/llvm/torch-mlir/issues) - [`torch-mlir` section](https://llvm.discourse.group/c/projects-that-want-to-become-official-llvm-projects/torch-mlir/41) of LLVM Discourse -- Weekly meetings on Mondays 9AM PST. See [here](https://discourse.llvm.org/t/community-meeting-developer-hour-refactoring-recurring-meetings/62575) for more information. -- Weekly op office hours on Thursdays 8:30-9:30AM PST. See [here](https://discourse.llvm.org/t/announcing-torch-mlir-office-hours/63973/2) for more information. 
+ +### Meetings + +Community Meeting / Developer Hour: +- 1st and 3rd Monday of the month at 9 am PST +- 2nd and 4th Monday of the month at 5 pm PST + +Office Hours: +- Every Thursday at 8:30 am PST + +Meeting links can be found [here](https://discourse.llvm.org/t/new-community-meeting-developer-hour-schedule/73868). ## Install torch-mlir snapshot @@ -61,7 +70,7 @@ python -m pip install --upgrade pip Then, we can install torch-mlir with the corresponding torch and torchvision nightlies. ``` pip install --pre torch-mlir torchvision \ - -f https://llvm.github.io/torch-mlir/package-index/ + -f https://llvm.github.io/torch-mlir/package-index/ \ --extra-index-url https://download.pytorch.org/whl/nightly/cpu ``` diff --git a/build_tools/autogen_ltc_backend.py b/build_tools/autogen_ltc_backend.py index 5af371d56ef9..4444015805bd 100644 --- a/build_tools/autogen_ltc_backend.py +++ b/build_tools/autogen_ltc_backend.py @@ -467,7 +467,8 @@ def gen_fallback_code(*args, **kwargs): node_base="torch::lazy::TorchMlirNode", node_base_hdr=str(self.backend_path.joinpath("mlir_node.h")), tensor_class=self.tensor_class, - tensor_class_hdr="torch/csrc/lazy/core/tensor.h", + tensor_class_hdr="torch_mlir/csrc/base_lazy_backend/tensor.h", + create_aten_from_ltc_tensor="CreateFunctionalizedAtenFromLtcTensor", shape_inference_hdr=str(self.generated_path.joinpath("shape_inference.h")), lazy_ir_generator=GenMlirLazyIr, ) diff --git a/build_tools/autogen_ltc_backend.yaml b/build_tools/autogen_ltc_backend.yaml index d57f693cc433..bfc4641640aa 100644 --- a/build_tools/autogen_ltc_backend.yaml +++ b/build_tools/autogen_ltc_backend.yaml @@ -3,12 +3,6 @@ blacklist: # It also doesn't have confusing `unsafe` argument. - _index_put_impl -# Ops with list of tensors output -- split.Tensor -- split_with_sizes -- unbind.int -- chunk - # Additional ops which autogen is supported for but don't compile yet - _convolution - detach @@ -18,42 +12,28 @@ blacklist: # Disabled for consistency with TS backend - lift_fresh_copy -- new_empty - rsub -- slice.Tensor # Disabled in favour of slice_copy.Tensor -- zeros -- ones -- arange -- arange.start -- arange.start_step -- fill.Scalar -- scalar_tensor # Disabled in favour of functionalized alternatives - _reshape_alias -- expand - permute - select.int -- squeeze - squeeze.dim -- t - transpose.int +- expand +- squeeze - unsqueeze - view +- slice.Tensor +- split.Tensor +- split_with_sizes +- unbind.int -whitelist: -# Enabled for consistency with TS backend -- arange.start_out # List of supported ops that we don't want to do the full codegen for supported: -# - bernoulli -# - bernoulli_ - _to_copy - clone -- empty.memory_format -- empty_strided -- fill_.Scalar - _unsafe_view - unbind_copy.int - split_copy.Tensor @@ -80,10 +60,10 @@ supported: - _trilinear - linalg_pinv.atol_rtol_tensor - logsumexp.out +- t # List of ops that will take in symints for the size instead of ints symint: -- empty.memory_format - new_empty_strided - expand_copy - narrow_copy @@ -91,7 +71,6 @@ symint: - slice_copy.Tensor - split_copy.Tensor - slice_scatter -- view - view_copy - as_strided_copy - as_strided_scatter diff --git a/build_tools/python_deploy/build_linux_packages.sh b/build_tools/python_deploy/build_linux_packages.sh index 019764f34b92..6bda469e7db3 100755 --- a/build_tools/python_deploy/build_linux_packages.sh +++ b/build_tools/python_deploy/build_linux_packages.sh @@ -178,6 +178,12 @@ function run_in_docker() { out-of-tree) setup_venv "$python_version" "$TM_TORCH_VERSION" build_out_of_tree 
"$TM_USE_PYTORCH_BINARY" "$python_version" "$TM_TORCH_VERSION" + if [ "${TM_UPDATE_ODS_AND_ABSTRACT_INTERP_LIB}" == "ON" ]; then + pushd /main_checkout/torch-mlir + TORCH_MLIR_BUILD_DIR=/main_checkout/torch-mlir/build_oot ./build_tools/update_torch_ods.sh + TORCH_MLIR_BUILD_DIR=/main_checkout/torch-mlir/build_oot ./build_tools/update_abstract_interp_lib.sh + popd + fi if [ "${TM_SKIP_TESTS}" == "OFF" ]; then test_out_of_tree fi diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py index 5aa0b89f636c..3a553b86ab7a 100644 --- a/e2e_testing/xfail_sets.py +++ b/e2e_testing/xfail_sets.py @@ -17,6 +17,9 @@ # Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR failed # 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8 "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", + "UnflattenStaticModule_basic", + "IscloseStaticModule_basic", + "IscloseStaticModuleTrue_basic", } TORCHDYNAMO_XFAIL_SET = { @@ -288,6 +291,12 @@ # AssertionError: Unregistered operation: torch.aten._embedding_bag_forward_only "AtenEmbeddingBagStaticModule_basic", + + # Lowering not present for this case + "ElementwiseToDtypeI64ToUI8Module_basic", + + # torch._dynamo.exc.TorchRuntimeError: Failed running call_function (*(FakeTensor(..., size=(3, 4), dtype=torch.int8), 3, 2), **{}): Tensor with dtype torch.int64 is not the expected dtype of torch.int8! + "ElementwiseAddScalarInt8Module_basic", } if torch_version_for_comparison() < version.parse("2.1.0.dev"): @@ -827,7 +836,6 @@ "ReshapeAliasCollapseModule_basic", "ReshapeAliasExpandModule_basic", "ReshapeExpandModule_basic", - "RollModule_basic", "TestMultipleTensorReturn_basic", "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic", "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic", @@ -1022,6 +1030,8 @@ "TypePromotionSameCategoryDifferentWidthModule_basic", "TypePromotionSameCategoryZeroRankWider_basic", "TypePromotionZeroRankHigherCategoryModule_basic", + "ElementwiseAddScalarInt8Module_basic", + "ElementwiseSubTensorInt8Module_basic", "ElementwiseMulScalarModule_basic", "ElementwiseMulScalarModule_float", "ElementwiseMulScalarModule_int", @@ -1039,6 +1049,8 @@ "BatchNorm1DStaticShapeModule_basic", "BroadcastListConstructWithMinusOneModule_basic", "ElementwiseAtenDivIntScalarModule_basic", + "ElementwiseToDtypeI64ToI8Module_basic", + "ToDtypeLayoutCPUModule_basic", "TupleModule_basic", "TypeAsDifferentModule_basic", } @@ -1046,6 +1058,8 @@ # Write the TOSA set as a "passing" set as it is very early in development # and very few tests work yet. 
TOSA_PASS_SET = { + "IscloseStaticModule_basic", + "IscloseStaticModuleTrue_basic", "TileBigDimsSizeModule_basic", "TileSmallDimsSizeModule_basic", "IndexPutImpl2DNoneIndexStaticModule_basic", @@ -1175,6 +1189,7 @@ "BatchNorm3DModule_basic", "BatchNorm1DStaticShapeModule_basic", "FlattenStaticModule_basic", + "UnflattenStaticModule_basic", "FlattenRank0Module_basic", "ElementwiseFlattenBroadcastModule_basic", "SquareModule_basic", @@ -1383,6 +1398,8 @@ "SoftmaxIntNegDimModule_basic", "_LogSoftmaxModule_basic", "_SoftmaxModule_basic", + "ElementwiseAddScalarInt8Module_basic", + "ElementwiseSubTensorInt8Module_basic", } MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | { @@ -1441,10 +1458,6 @@ "_ConvolutionDeprecated2DBenchmarkModule_basic", "_ConvolutionDeprecated2DCudnnModule_basic", "_ConvolutionDeprecated2DDeterministicModule_basic", - "AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic", - "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic", - "AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic", - "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", "AddIntModule_basic", "AtenIntBoolOpModule_basic", "BernoulliTensorModule_basic", @@ -1480,7 +1493,6 @@ "NeFloatIntModule_basic", "NeIntModule_basic", "QuantizedMLP_basic", - "RollModule_basic", "ScalarImplicitFloatModule_basic", "ScalarImplicitIntModule_basic", "SliceEndSleStartModule_basic", @@ -1512,7 +1524,6 @@ "ConvolutionBackwardModule2DPadded_basic", "VarMeanCorrectionModule_basic", "VarMeanCorrectionNoneModule_basic", - "PrimsConvertElementTypeModule_basic", "ElementwisePreluModule_basic", "VarMeanBiasedModule_basic", "VarMeanUnbiasedModule_basic", @@ -1547,4 +1558,7 @@ "UniformStaticShapeModule_basic", "AtenEmbeddingBagStaticModule_basic", "EmptyStridedModule_basic", + "ElementwiseBitwiseAndScalarInt64Module_basic", + "ElementwiseBitwiseAndScalarInt32Module_basic", + "ElementwiseBitwiseAndScalarInt8Module_basic", } diff --git a/externals/llvm-external-projects/torch-mlir-dialects/lib/Conversion/TcpToLinalg/Elementwise.cpp b/externals/llvm-external-projects/torch-mlir-dialects/lib/Conversion/TcpToLinalg/Elementwise.cpp index dd495f932bd0..a476669625f6 100644 --- a/externals/llvm-external-projects/torch-mlir-dialects/lib/Conversion/TcpToLinalg/Elementwise.cpp +++ b/externals/llvm-external-projects/torch-mlir-dialects/lib/Conversion/TcpToLinalg/Elementwise.cpp @@ -72,11 +72,11 @@ createLinalgPayloadForElementwiseOp(Operation *op, auto minFloat = clampOp.getMinFloat(); auto maxFloat = clampOp.getMaxFloat(); if (minFloat) - result = b.create( + result = b.create( loc, result, b.create(loc, *minFloat, b.getF32Type())); if (maxFloat) - result = b.create( + result = b.create( loc, result, b.create(loc, *maxFloat, b.getF32Type())); } else if (elemType.isa()) { diff --git a/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/IR/TMTensorOps.cpp b/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/IR/TMTensorOps.cpp index ba7ed76c81cf..dcb2f4215891 100644 --- a/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/IR/TMTensorOps.cpp +++ b/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/IR/TMTensorOps.cpp @@ -233,7 +233,7 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b, loc, init, [&](OpBuilder &b, Location loc, Value elem, Value acc) { Value x = b.create(loc, weight, localIVs); - Value max = b.create(loc, x, acc); + Value max = b.create(loc, x, acc); b.create(loc, max); }); }) diff --git 
a/externals/llvm-external-projects/torch-mlir-dialects/test/Conversion/TcpToLinalg/unary.mlir b/externals/llvm-external-projects/torch-mlir-dialects/test/Conversion/TcpToLinalg/unary.mlir index 14a50e7d2475..ab7a58eeb9ae 100644 --- a/externals/llvm-external-projects/torch-mlir-dialects/test/Conversion/TcpToLinalg/unary.mlir +++ b/externals/llvm-external-projects/torch-mlir-dialects/test/Conversion/TcpToLinalg/unary.mlir @@ -43,9 +43,9 @@ func.func @tanh(%arg0 : tensor) -> tensor { // CHECK-SAME: outs(%[[EMPTY_TENSOR]] : tensor) { // CHECK: ^bb0(%[[BBARG0:.*]]: f32, %{{.*}}: f32): // CHECK: %[[CST0:.*]] = arith.constant 1.000000e-01 : f32 -// CHECK: %[[MAX:.*]] = arith.maxf %[[BBARG0]], %[[CST0]] : f32 +// CHECK: %[[MAX:.*]] = arith.maximumf %[[BBARG0]], %[[CST0]] : f32 // CHECK: %[[CST1:.*]] = arith.constant 1.024000e+03 : f32 -// CHECK: %[[MIN:.*]] = arith.minf %[[MAX]], %[[CST1]] : f32 +// CHECK: %[[MIN:.*]] = arith.minimumf %[[MAX]], %[[CST1]] : f32 // CHECK: linalg.yield %[[MIN]] : f32 // CHECK: } -> tensor // CHECK: return %[[GENERIC]] : tensor diff --git a/externals/llvm-project b/externals/llvm-project index 4acc3ffbb0af..28b27c1b10ae 160000 --- a/externals/llvm-project +++ b/externals/llvm-project @@ -1 +1 @@ -Subproject commit 4acc3ffbb0af5631bc7916aeff3570f448899647 +Subproject commit 28b27c1b10ae8d1f5b4fb9df691e8cf0da9be3f6 diff --git a/include/torch-mlir/Conversion/Utils/Utils.h b/include/torch-mlir/Conversion/Utils/Utils.h index 8795974a395c..d561c2101173 100644 --- a/include/torch-mlir/Conversion/Utils/Utils.h +++ b/include/torch-mlir/Conversion/Utils/Utils.h @@ -87,7 +87,8 @@ mlir::RankedTensorType GetTypeFromTensorShape(llvm::ArrayRef shape, // from a tensor or a scalar in the pytorch dialect. Both the scalar and dtype // should be converted builtin types. 
Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype, - std::optional srcOriginalDtype = std::nullopt); + std::optional srcOriginalDtype = std::nullopt, + std::optional dstOriginalDtype = std::nullopt); Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc, Value torchOptionalInt, Value builtinInt, diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index f0d0a238a129..0530e3082418 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -2527,6 +2527,51 @@ def Torch_AtenLog2_Op : Torch_Op<"aten.log2_", [ }]; } +def Torch_AtenLog10Op : Torch_Op<"aten.log10", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::log10 : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenLog10Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenLog10Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenLog10_Op : Torch_Op<"aten.log10_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::log10_ : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenLog10_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenLog10_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + def Torch_AtenSqrtOp : Torch_Op<"aten.sqrt", [ AllowsTypeRefinement, HasValueSemantics, @@ -2799,6 +2844,53 @@ def Torch_AtenBitwiseAnd_TensorOp : Torch_Op<"aten.bitwise_and_.Tensor", [ }]; } +def Torch_AtenBitwiseAndScalarOp : Torch_Op<"aten.bitwise_and.Scalar", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::bitwise_and.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenBitwiseAndScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenBitwiseAndScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenBitwiseAnd_ScalarOp : Torch_Op<"aten.bitwise_and_.Scalar", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::bitwise_and_.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenBitwiseAnd_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenBitwiseAnd_ScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); 
+ } + }]; +} + def Torch_AtenBitwiseOrTensorOp : Torch_Op<"aten.bitwise_or.Tensor", [ AllowsTypeRefinement, HasValueSemantics, @@ -2893,6 +2985,53 @@ def Torch_AtenBitwiseXor_TensorOp : Torch_Op<"aten.bitwise_xor_.Tensor", [ }]; } +def Torch_AtenBitwiseRightShiftTensorOp : Torch_Op<"aten.bitwise_right_shift.Tensor", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::bitwise_right_shift.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenBitwiseRightShiftTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenBitwiseRightShiftTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenBitwiseRightShift_TensorOp : Torch_Op<"aten.bitwise_right_shift_.Tensor", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::bitwise_right_shift_.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenBitwiseRightShift_TensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenBitwiseRightShift_TensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_AtenThresholdOp : Torch_Op<"aten.threshold", [ AllowsTypeRefinement, HasValueSemantics, @@ -4023,6 +4162,33 @@ def Torch_AtenViewAsRealOp : Torch_Op<"aten.view_as_real", [ }]; } +def Torch_AtenIscloseOp : Torch_Op<"aten.isclose", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::isclose : (Tensor, Tensor, float, float, bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$other, + Torch_FloatType:$rtol, + Torch_FloatType:$atol, + Torch_BoolType:$equal_nan + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenIscloseOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 1); + } + void AtenIscloseOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 1); + } + }]; +} + def Torch_AtenUnbindCopyIntOp : Torch_Op<"aten.unbind_copy.int", [ AllowsTypeRefinement, HasValueSemantics, @@ -4490,6 +4656,56 @@ def Torch_AtenRandnLikeOp : Torch_Op<"aten.randn_like", [ }]; } +def Torch_AtenRandomOp : Torch_Op<"aten.random", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::random : (Tensor, Generator?) 
-> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchOptionalGeneratorType:$generator + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenRandomOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenRandomOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenRandomFromOp : Torch_Op<"aten.random.from", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::random.from : (Tensor, int, int?, Generator?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$from, + AnyTorchOptionalIntType:$to, + AnyTorchOptionalGeneratorType:$generator + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenRandomFromOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenRandomFromOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + def Torch_AtenTriuOp : Torch_Op<"aten.triu", [ AllowsTypeRefinement, HasValueSemantics, @@ -5563,6 +5779,34 @@ def Torch_AtenMaxPool3dWithIndicesBackwardOp : Torch_Op<"aten.max_pool3d_with_in }]; } +def Torch_AtenAvgPool1dOp : Torch_Op<"aten.avg_pool1d", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::avg_pool1d : (Tensor, int[], int[], int[], bool, bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$kernel_size, + AnyTorchListOfTorchIntType:$stride, + AnyTorchListOfTorchIntType:$padding, + Torch_BoolType:$ceil_mode, + Torch_BoolType:$count_include_pad + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAvgPool1dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 6, 1); + } + void AtenAvgPool1dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 6, 1); + } + }]; +} + def Torch_AtenAvgPool2dOp : Torch_Op<"aten.avg_pool2d", [ AllowsTypeRefinement, HasValueSemantics, @@ -5592,30 +5836,91 @@ def Torch_AtenAvgPool2dOp : Torch_Op<"aten.avg_pool2d", [ }]; } -def Torch_AtenAvgPool1dOp : Torch_Op<"aten.avg_pool1d", [ +def Torch_AtenAvgPool2dBackwardOp : Torch_Op<"aten.avg_pool2d_backward", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::avg_pool1d : (Tensor, int[], int[], int[], bool, bool) -> (Tensor)`"; + let summary = "Generated op for `aten::avg_pool2d_backward : (Tensor, Tensor, int[], int[], int[], bool, bool, int?) 
-> (Tensor)`"; let arguments = (ins + AnyTorchTensorType:$grad_output, AnyTorchTensorType:$self, AnyTorchListOfTorchIntType:$kernel_size, AnyTorchListOfTorchIntType:$stride, AnyTorchListOfTorchIntType:$padding, Torch_BoolType:$ceil_mode, - Torch_BoolType:$count_include_pad + Torch_BoolType:$count_include_pad, + AnyTorchOptionalIntType:$divisor_override ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenAvgPool1dOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 6, 1); + ParseResult AtenAvgPool2dBackwardOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 8, 1); } - void AtenAvgPool1dOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 6, 1); + void AtenAvgPool2dBackwardOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 8, 1); + } + }]; +} + +def Torch_AtenAvgPool3dOp : Torch_Op<"aten.avg_pool3d", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::avg_pool3d : (Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$kernel_size, + AnyTorchListOfTorchIntType:$stride, + AnyTorchListOfTorchIntType:$padding, + Torch_BoolType:$ceil_mode, + Torch_BoolType:$count_include_pad, + AnyTorchOptionalIntType:$divisor_override + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAvgPool3dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 7, 1); + } + void AtenAvgPool3dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 7, 1); + } + }]; +} + +def Torch_AtenAvgPool3dBackwardOp : Torch_Op<"aten.avg_pool3d_backward", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::avg_pool3d_backward : (Tensor, Tensor, int[], int[], int[], bool, bool, int?) 
-> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$grad_output, + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$kernel_size, + AnyTorchListOfTorchIntType:$stride, + AnyTorchListOfTorchIntType:$padding, + Torch_BoolType:$ceil_mode, + Torch_BoolType:$count_include_pad, + AnyTorchOptionalIntType:$divisor_override + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAvgPool3dBackwardOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 8, 1); + } + void AtenAvgPool3dBackwardOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 8, 1); } }]; } @@ -5846,6 +6151,30 @@ def Torch_AtenMaskedScatter_Op : Torch_Op<"aten.masked_scatter_", [ }]; } +def Torch_AtenAdaptiveAvgPool1dOp : Torch_Op<"aten.adaptive_avg_pool1d", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::adaptive_avg_pool1d : (Tensor, int[]) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$output_size + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAdaptiveAvgPool1dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenAdaptiveAvgPool1dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_AtenAdaptiveAvgPool2dOp : Torch_Op<"aten.adaptive_avg_pool2d", [ AllowsTypeRefinement, HasValueSemantics, @@ -5870,12 +6199,12 @@ def Torch_AtenAdaptiveAvgPool2dOp : Torch_Op<"aten.adaptive_avg_pool2d", [ }]; } -def Torch_AtenAdaptiveAvgPool1dOp : Torch_Op<"aten.adaptive_avg_pool1d", [ +def Torch_Aten_AdaptiveAvgPool2dOp : Torch_Op<"aten._adaptive_avg_pool2d", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::adaptive_avg_pool1d : (Tensor, int[]) -> (Tensor)`"; + let summary = "Generated op for `aten::_adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, AnyTorchListOfTorchIntType:$output_size @@ -5885,10 +6214,106 @@ def Torch_AtenAdaptiveAvgPool1dOp : Torch_Op<"aten.adaptive_avg_pool1d", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenAdaptiveAvgPool1dOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult Aten_AdaptiveAvgPool2dOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenAdaptiveAvgPool1dOp::print(OpAsmPrinter &printer) { + void Aten_AdaptiveAvgPool2dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_Aten_AdaptiveAvgPool2dBackwardOp : Torch_Op<"aten._adaptive_avg_pool2d_backward", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::_adaptive_avg_pool2d_backward : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$grad_output, + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult Aten_AdaptiveAvgPool2dBackwardOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void Aten_AdaptiveAvgPool2dBackwardOp::print(OpAsmPrinter &printer) { + 
printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenAdaptiveAvgPool3dOp : Torch_Op<"aten.adaptive_avg_pool3d", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::adaptive_avg_pool3d : (Tensor, int[]) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$output_size + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAdaptiveAvgPool3dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenAdaptiveAvgPool3dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_Aten_AdaptiveAvgPool3dOp : Torch_Op<"aten._adaptive_avg_pool3d", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::_adaptive_avg_pool3d : (Tensor, int[]) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$output_size + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult Aten_AdaptiveAvgPool3dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void Aten_AdaptiveAvgPool3dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_Aten_AdaptiveAvgPool3dBackwardOp : Torch_Op<"aten._adaptive_avg_pool3d_backward", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::_adaptive_avg_pool3d_backward : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$grad_output, + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult Aten_AdaptiveAvgPool3dBackwardOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void Aten_AdaptiveAvgPool3dBackwardOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; @@ -6042,6 +6467,31 @@ def Torch_AtenCumsumOp : Torch_Op<"aten.cumsum", [ }]; } +def Torch_AtenCumprodOp : Torch_Op<"aten.cumprod", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::cumprod : (Tensor, int, int?) 
-> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$dim, + AnyTorchOptionalIntType:$dtype + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenCumprodOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenCumprodOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + def Torch_AtenFloorDivideScalarOp : Torch_Op<"aten.floor_divide.Scalar", [ AllowsTypeRefinement, HasValueSemantics, @@ -7114,6 +7564,30 @@ def Torch_AtenFlattenUsingIntsOp : Torch_Op<"aten.flatten.using_ints", [ }]; } +def Torch_AtenUnflattenIntOp : Torch_Op<"aten.unflatten.int", [ + AllowsTypeRefinement, + ReadOnly + ]> { + let summary = "Generated op for `aten::unflatten.int : (Tensor, int, int[]) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$dim, + AnyTorchListOfTorchIntType:$sizes + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenUnflattenIntOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenUnflattenIntOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + def Torch_AtenDimOp : Torch_Op<"aten.dim", [ AllowsTypeRefinement, HasValueSemantics, @@ -8725,6 +9199,31 @@ def Torch_Aten_ReshapeAliasOp : Torch_Op<"aten._reshape_alias", [ }]; } +def Torch_AtenResizeOp : Torch_Op<"aten.resize", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::resize : (Tensor, int[], int?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$size, + AnyTorchOptionalIntType:$memory_format + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenResizeOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenResizeOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + def Torch_AtenResize_Op : Torch_Op<"aten.resize_", [ AllowsTypeRefinement ]> { @@ -10041,6 +10540,35 @@ def Torch_AtenUniqueConsecutiveOp : Torch_Op<"aten.unique_consecutive", [ }]; } +def Torch_AtenLinspaceOp : Torch_Op<"aten.linspace", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::linspace : (Scalar, Scalar, int, int?, int?, Device?, bool?) 
-> (Tensor)`"; + let arguments = (ins + AnyTorchScalarType:$start, + AnyTorchScalarType:$end, + Torch_IntType:$steps, + AnyTorchOptionalIntType:$dtype, + AnyTorchOptionalIntType:$layout, + AnyTorchOptionalDeviceType:$device, + AnyTorchOptionalBoolType:$pin_memory + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenLinspaceOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 7, 1); + } + void AtenLinspaceOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 7, 1); + } + }]; +} + def Torch_AtenAliasCopyOp : Torch_Op<"aten.alias_copy", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/include/torch-mlir/Dialect/Torch/Utils/Utils.h b/include/torch-mlir/Dialect/Torch/Utils/Utils.h index 37aaed9cd704..f913e70345f4 100644 --- a/include/torch-mlir/Dialect/Torch/Utils/Utils.h +++ b/include/torch-mlir/Dialect/Torch/Utils/Utils.h @@ -86,6 +86,24 @@ FailureOr squeezeTensor(PatternRewriter &rewriter, Operation *op, FailureOr unsqueezeTensor(PatternRewriter &rewriter, Operation *op, Value input, Value dim); +// In Dynamo import paths, we can assume that dynamic dimensions are strictly +// quantities and are not ambiguous with '1' symbols that can be interpreted +// to signal an expansion in various broadcasting scenarios. In the +// torch.compile eager path, this precondition is assured by guards on 0/1 +// dimension values, and on the torch.export graph-capture path, the shape +// solver guarantees this. +// +// We let lowerings assume this on a per-scope basis if the +// torch.assume_strict_symbolic_shapes unit attribute is present on any parent +// of the block. +bool isAssumingStrictSymbolicShapes(Block *scope); + +// Helper that uses the block from an OpBuilder for determining whether we +// are assuming strict symbolic shapes. +inline bool isAssumingStrictSymbolicShapes(OpBuilder &builder) { + return isAssumingStrictSymbolicShapes(builder.getBlock()); +} + } // namespace Torch } // namespace torch } // namespace mlir diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index 041581d2a18b..662a1379bb41 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -34,6 +34,10 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +static int64_t productReduce(ArrayRef a) { + return accumulate(a.begin(), a.end(), /*init=*/1, std::multiplies()); +} + template LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, @@ -177,144 +181,136 @@ namespace { class ConvertAtenViewOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; + // If one of the two dims arrays has size 1, a mapping is created from the one + // dimension of the size-1 array to all the dimensions of the other array. For + // example for inputs: xDims = [6], yDims = [2, 3] the result in the indices + // arrays will be: xIndices = [0], yIndices = [0, 1]. + // + // An error is returned if the dimension size of the size-1 array is not equal + // to the product of all the dimension sizes in the other array, or if neither + // of the arrays is size-1. 
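The comment above reduces the reassociation question to a product comparison. An illustrative, self-contained sketch of that check (hypothetical helper name, not part of this patch; dynamic sizes, `kUnknownSize` in the patch, are simply skipped here):

```c++
#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

// Hypothetical stand-in for the validity check described above: a single
// dim can absorb a group of dims only when the group's sizes multiply to
// the single dim's size.
static bool groupMatchesSingleDim(int64_t single,
                                  const std::vector<int64_t> &group) {
  int64_t product = std::accumulate(group.begin(), group.end(), int64_t{1},
                                    std::multiplies<int64_t>());
  return product == single;
}

int main() {
  // Mirrors the example in the comment: xDims = [6], yDims = [2, 3].
  assert(groupMatchesSingleDim(6, {2, 3}));  // 2 * 3 == 6 -> mapping succeeds
  assert(!groupMatchesSingleDim(6, {2, 4})); // 8 != 6 -> failure()
  return 0;
}
```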
+ static LogicalResult mapAllDimsToSingleDim(ArrayRef xDims, + ArrayRef yDims, + SmallVector &xIndices, + SmallVector &yIndices) { + if (xDims.empty() || yDims.empty()) + return failure(); - // Helper for filling in remaining un-collapsed dims when the - // input/output dim is next to the next boundary dim. Additionally - // computes the size of a collapsed dynamic dim if necessary. - static LogicalResult - collapseToSingleDimHelper(AtenViewOp op, ConversionPatternRewriter &rewriter, - int64_t collapseDim, int64_t maxCollapseDim, - int64_t startExpandDim, int64_t maxExpandDim, - SmallVector &collapseShape, - const SmallVector &expandShape, - ReassociationIndices &expandIndices) { - int64_t collapseDimSize = 1; - for (auto i : llvm::seq(startExpandDim, maxExpandDim)) { - expandIndices.push_back(i); - if (collapseDimSize == kUnknownSize) - continue; - - int64_t expandedDimSize = expandShape[i]; - if (expandedDimSize == kUnknownSize) { - collapseDimSize = kUnknownSize; - continue; - } - collapseDimSize *= expandedDimSize; - } - int64_t rawCollapseDimSize = collapseShape[collapseDim]; - if (rawCollapseDimSize != kUnknownSize && collapseDimSize != kUnknownSize && - collapseDimSize != rawCollapseDimSize) { - return rewriter.notifyMatchFailure( - op, "desired size is not compatible with the input tensor size"); + auto isValidReduction = [](int64_t expectedReductionProduct, + ArrayRef arrayToReduce) -> bool { + if (llvm::count(arrayToReduce, kUnknownSize) > 0 || + expectedReductionProduct == kUnknownSize) + return true; + return productReduce(arrayToReduce) == expectedReductionProduct; + }; + + if (xDims.size() == 1) { + if (!isValidReduction(xDims[0], yDims)) + return failure(); + xIndices.assign({0}); + yIndices.assign(llvm::to_vector(llvm::seq(0, yDims.size()))); + return success(); + } else if (yDims.size() == 1) { + if (!isValidReduction(yDims[0], xDims)) + return failure(); + yIndices.assign({0}); + xIndices.assign(llvm::to_vector(llvm::seq(0, xDims.size()))); + return success(); } - collapseShape[collapseDim] = collapseDimSize; - return success(); + return failure(); } - // Helper to find the minimum set of dims to collapse with the - // same number of elements as that of collapseDim. This function assumes - // the size of the collapsed dim is never dynamic. - static LogicalResult minimallyCollapseDimHelper( - AtenViewOp op, ConversionPatternRewriter &rewriter, int64_t collapseDim, - int64_t maxCollapseDim, int64_t startExpandDim, int64_t maxExpandDim, - SmallVector &collapseShape, SmallVector &expandShape, - ReassociationIndices &collapseIndices, - ReassociationIndices &expandIndices) { - - int64_t collapseDimSize = collapseShape[collapseDim]; - - int64_t expandedSize = 1; - int64_t collapsedSize = collapseDimSize; - - int64_t expandIndex = startExpandDim; - int64_t collapseIndex = collapseDim + 1; - - if (collapseDimSize == kUnknownSize) { - if (llvm::all_of(collapseShape, - [](int64_t value) { return value == kUnknownSize; }) && - llvm::all_of(expandShape, - [](int64_t value) { return value == kUnknownSize; })) { - - for (size_t i = 0; i < collapseShape.size(); i++) { - collapseIndices.push_back(i); - } - - for (size_t i = 0; i < expandShape.size(); i++) { - expandIndices.push_back(i); - } - - return success(); + // Starting from the beginning of the dims arrays, this helper finds the + // smallest set of consecutive dims in each array such that the product of the + // dim sizes in the two subsets is equal. 
The indices arrays are populated + // with the indices of the dims arrays that correspond to the subsets found. + // + // An error is returned if two subsets of dims with total number of elements + // equal to each other is not found. + static LogicalResult mapStaticallyKnownDims(ArrayRef xDims, + ArrayRef yDims, + SmallVector &xIndices, + SmallVector &yIndices) { + if (xDims.empty() || yDims.empty()) + return failure(); + int64_t xTotalSize = xDims[0]; + int64_t yTotalSize = yDims[0]; + SmallVector xIndicesResult({0}); + SmallVector yIndicesResult({0}); + size_t nextXIndex = 1; + size_t nextYIndex = 1; + while (xTotalSize != yTotalSize) { + if (xTotalSize < yTotalSize) { + if (nextXIndex == xDims.size() || xDims[nextXIndex] == kUnknownSize) + return failure(); + xTotalSize *= xDims[nextXIndex]; + xIndicesResult.push_back(nextXIndex++); + } else { + if (nextYIndex == yDims.size() || yDims[nextYIndex] == kUnknownSize) + return failure(); + yTotalSize *= yDims[nextYIndex]; + yIndicesResult.push_back(nextYIndex++); } } - while (expandIndex != maxExpandDim || collapseIndex != maxCollapseDim) { - if (expandIndex != maxExpandDim && expandedSize <= collapsedSize) { - int64_t expandDimSize = expandShape[expandIndex]; - if (expandDimSize != kUnknownSize) { - expandedSize *= expandDimSize; - } - expandIndices.push_back(expandIndex); - expandIndex++; - - } else if (collapseIndex != maxCollapseDim && - collapsedSize < expandedSize) { - collapseDimSize = collapseShape[collapseIndex]; - if (collapseDimSize != kUnknownSize) { - collapsedSize *= collapseDimSize; - } - collapseIndices.push_back(collapseIndex); - collapseIndex++; - } - - if (expandedSize == collapsedSize) - return success(); - } - return rewriter.notifyMatchFailure( - op, "total number of elements mismatch in the expansion"); + xIndices.assign(std::move(xIndicesResult)); + yIndices.assign(std::move(yIndicesResult)); + return success(); } - static void solveDynamicSize(SmallVector &inputShape, - SmallVector &outputShape) { - int64_t inputProduct = 1; - int64_t outputProduct = 1; - - int64_t inputDynamicValues = 0; - int64_t outputDynamicValues = 0; - - for (int64_t value : inputShape) { - if (value == -1) { - ++inputDynamicValues; - } else { - inputProduct *= value; - } - } - for (int64_t value : outputShape) { - if (value == -1) { - ++outputDynamicValues; - } else { - outputProduct *= value; - } + // Calculates the size of a dynamic dimension if all other dimensions are + // statically known, and rewrites that dynamic dimension with the static size. + // + // Note: this function assumes that all the dimensions in `inputShape` map to + // all the dimensions in `outputShape`. 
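A minimal sketch of the inference the comment above describes, under the stated assumption that the two shapes map to each other and contain exactly one unknown dimension (hypothetical helper, not part of the patch):

```c++
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

constexpr int64_t kUnknown = -1; // stand-in for the patch's kUnknownSize

// Hypothetical stand-in for the rewrite described above: the single unknown
// dimension is the other shape's element count divided by the product of the
// known sizes on its own side.
static void inferSingleUnknown(std::vector<int64_t> &shape,
                               const std::vector<int64_t> &otherShape) {
  auto it = std::find(shape.begin(), shape.end(), kUnknown);
  if (it == shape.end())
    return;
  int64_t knownProduct = 1, otherProduct = 1;
  for (int64_t d : shape)
    if (d != kUnknown)
      knownProduct *= d;
  for (int64_t d : otherShape)
    otherProduct *= d;
  *it = otherProduct / knownProduct;
}

int main() {
  std::vector<int64_t> inputShape = {2, kUnknown};
  inferSingleUnknown(inputShape, /*otherShape=*/{2, 3, 4});
  assert(inputShape[1] == 12); // 24 total elements / known factor 2
  return 0;
}
```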
+ static void calculateSingleDynamicSize(MutableArrayRef inputShape, + MutableArrayRef outputShape) { + if (inputShape.empty() || outputShape.empty()) + return; + int64_t inputDynamicDimCount = llvm::count(inputShape, kUnknownSize); + int64_t outputDynamicDimCount = llvm::count(outputShape, kUnknownSize); + if (inputDynamicDimCount + outputDynamicDimCount != 1) + return; + + int64_t inputProduct = productReduce(inputShape); + int64_t outputProduct = productReduce(outputShape); + + if (inputDynamicDimCount == 1) { + inputProduct /= kUnknownSize; + *llvm::find(inputShape, kUnknownSize) = outputProduct / inputProduct; + } else { + outputProduct /= kUnknownSize; + *llvm::find(outputShape, kUnknownSize) = inputProduct / outputProduct; } + } - if (inputDynamicValues + outputDynamicValues == 1) { - if (inputDynamicValues) { - int64_t missingValue = outputProduct / inputProduct; - for (size_t i = 0; i < inputShape.size(); i++) { - if (inputShape[i] == -1) { - inputShape[i] = missingValue; - break; - } - } - } else { - int64_t missingValue = inputProduct / outputProduct; - for (size_t i = 0; i < outputShape.size(); i++) { - if (outputShape[i] == -1) { - outputShape[i] = missingValue; - break; - } + // Gets the shapes of the input and output tensors, making a best-effort + // attempt to extract static shape information given the inputs to + // `aten.view`. + static std::pair, SmallVector> + getInputAndOutputShape(Value inputTorchTensor, + SmallVector outputSizeTorchInt) { + SmallVector inputShape( + inputTorchTensor.getType().cast().getSizes()); + SmallVector outputShape(outputSizeTorchInt.size(), kUnknownSize); + for (auto [outputDim, outputDimSize] : + llvm::enumerate(outputSizeTorchInt)) { + int64_t inputDim; + int64_t outputDimSizeInt; + // Match torch.aten.size.int(inputTensor, inputDim) with constant inputDim + if (matchPattern(outputDimSize, + m_TorchTensorSizeInt(inputTorchTensor, &inputDim))) { + outputShape[outputDim] = inputShape[inputDim]; + } else if (matchPattern(outputDimSize, + m_TorchConstantInt(&outputDimSizeInt))) { + if (outputDimSizeInt != -1) { + outputShape[outputDim] = outputDimSizeInt; } } } + + calculateSingleDynamicSize(inputShape, outputShape); + return std::make_pair(inputShape, outputShape); } LogicalResult @@ -325,8 +321,7 @@ class ConvertAtenViewOp : public OpConversionPattern { Location loc = op.getLoc(); Value input = adaptor.getSelf(); auto inputType = input.getType().cast(); - SmallVector inputShape = - makeShapeTorchCompatible(inputType.getShape()); + SmallVector inputSize = getTensorSizes(rewriter, loc, input); int64_t inputRank = inputType.getRank(); const TypeConverter *typeConverter = getTypeConverter(); auto resultType = @@ -349,6 +344,15 @@ class ConvertAtenViewOp : public OpConversionPattern { "unimplemented: the target size is " "not constructed from ListConstruct"); } + if (llvm::count_if(outputSizeTorchInt, [](Value size) -> bool { + int64_t sizeInt; + if (matchPattern(size, m_TorchConstantInt(&sizeInt))) + return sizeInt == -1; + return false; + }) > 1) { + return rewriter.notifyMatchFailure( + op, "at most one element in size list is allowed to be -1"); + } SmallVector outputSizeInt = getTypeConvertedValues( rewriter, loc, typeConverter, outputSizeTorchInt); if (resultRank != (int64_t)outputSizeInt.size()) { @@ -356,6 +360,9 @@ class ConvertAtenViewOp : public OpConversionPattern { op, "desired size list length mismatches with the result type rank"); } + auto [inputShape, outputShape] = + getInputAndOutputShape(op.getSelf(), outputSizeTorchInt); + // 
Currently, we only handle the cases where each dimension is either // being expanded or collapsed. We do not handle cases where it's neither // collapsing nor expanding like view of [2,3] for 3x2 tensor. @@ -364,90 +371,24 @@ class ConvertAtenViewOp : public OpConversionPattern { // [6] => [3, 2]. // Iterate through the view op size list to do the following: - // - // 1. Combine output size list and input tensor type info to get the most - // static outputShape. - // - // 2. Mark dims in unchangedDims for size list items where the output dim + // Mark dims in unchangedDims for size list items where the output dim // size comes from a `torch.aten.size.int(inputTensor, inputDim)`. We // naively assume this means the corresponding dimension is not expanded or // collapsed. Note this may technically not always be true. // TODO: think of a way better way to at least detect when this assumption // is violated for the cases of dynamic dimensions. - SmallVector outputShape(resultRank, kUnknownSize); - SmallVector unchangedDims; - std::optional inferredDimension; - for (auto en : llvm::enumerate(outputSizeTorchInt)) { + SmallVector> unchangedDims; + for (auto [outputDim, outputDimSize] : + llvm::enumerate(outputSizeTorchInt)) { int64_t inputDim; - int64_t size; - int64_t outputDim = en.index(); // Match torch.aten.size.int(inputTensor, inputDim) with constant inputDim - if (matchPattern(en.value(), + if (matchPattern(outputDimSize, m_TorchTensorSizeInt(op.getSelf(), &inputDim))) { - unchangedDims.emplace_back(); - unchangedDims.back().push_back(inputDim); - unchangedDims.back().push_back(outputDim); - if (!inputType.isDynamicDim(inputDim)) { - outputShape[outputDim] = inputShape[inputDim]; - continue; - } - } else if (matchPattern(en.value(), m_TorchConstantInt(&size))) { - if (size != -1) { - outputShape[outputDim] = size; - continue; - } - - if (inferredDimension.has_value()) { - return rewriter.notifyMatchFailure( - op, "at most one element in size list is allowed to be -1"); - } - inferredDimension = outputDim; + unchangedDims.push_back(std::make_pair(inputDim, outputDim)); } } - // Mark the end of the input/output shapes - unchangedDims.emplace_back(); - unchangedDims.back().push_back(inputRank); - unchangedDims.back().push_back(resultRank); - - // Use static information of input tensor to determine size of inferred - // dimension in output shape. - // - // If there is an inferred dimension and that is the only dimension - // in the output shape (i.e. the tensor is getting fully flattened), - // then we don't need to analyze the static information of the input - // shape since the reassociation of dimensions only requires rank - // information. 
- if (inferredDimension.has_value() && outputShape.size() > 1) { - if (llvm::count(outputShape, kUnknownSize) != 1 || - llvm::count(inputShape, kUnknownSize) != 0) { - return rewriter.notifyMatchFailure( - op, - "unimplemented: an inferred dimension is only supported when there " - "is enough static shape information to determine its size, or when " - "the input tensor is being flattened to a single dimension"); - } - auto productReduceKnownSizes = [](const ArrayRef sizes) { - auto knownSizes = llvm::make_filter_range( - sizes, [](int64_t val) { return val != kUnknownSize; }); - return std::accumulate(knownSizes.begin(), knownSizes.end(), /*init=*/1, - std::multiplies()); - }; - - int64_t numOfElements = productReduceKnownSizes(inputShape); - int64_t outputKnownNumOfElements = productReduceKnownSizes(outputShape); - if (numOfElements % outputKnownNumOfElements != 0) { - return rewriter.notifyMatchFailure( - op, "number of elements in input tensor must be divisible by " - "product of non-inferred dimensions in size list"); - } - outputShape[*inferredDimension] = - numOfElements / outputKnownNumOfElements; - } - - SmallVector inputSize = getTensorSizes(rewriter, loc, input); - ArrayRef outputShapeInt = llvm::ArrayRef(outputSizeInt); - ArrayRef inputShapeInt = llvm::ArrayRef(inputSize); + unchangedDims.push_back(std::make_pair(inputRank, resultRank)); // Association indices for expand/collapse ops. These two vectors // are populated such that two entries at the same index corresponds @@ -463,10 +404,6 @@ class ConvertAtenViewOp : public OpConversionPattern { SmallVector inputAssociations; SmallVector outputAssociations; - SmallVector inputShapeVec = llvm::to_vector(inputShape); - - solveDynamicSize(inputShapeVec, outputShape); - // The for loop does the following: // 1. Attempt to match the indices from inputDim and outputDim to the next // boundary found from `torch.aten.size.int(inputTensor, inputDim)`, or @@ -482,127 +419,103 @@ class ConvertAtenViewOp : public OpConversionPattern { // the dynamic dimension with the one across from it and give up if we can't // reason about how the dimensions are associated. // e.g. [-1, -1] -> [2, 3, 4] - // 3. Set inputShapeVec and outputShape following the requirements by - // tensor.expand_shape verification code: - // a. As long as one or more of the related dimensions in the expanded - // shape is dynamic the collapsed dimension is dynamic. - // b. If all of the related dimensions are static, the collapsed - // dimension must be static. In other words, if a collapsed dimension is - // dynamic, at least one of the related dimensions need to be dynamic. + // For more information, see description of helper functions used in the + // `if-else` cases inside the while loop. 
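For intuition, here is a hedged, self-contained sketch of the greedy product matching that the loop below and `mapStaticallyKnownDims` rely on: starting at the front of both shapes, grow whichever group currently has the smaller running product until the products agree, then emit one reassociation group. It assumes fully static shapes with equal element counts; the patch additionally handles dynamic dims and `aten.size.int` boundaries.

```c++
#include <cassert>
#include <cstdint>
#include <vector>

using Groups = std::vector<std::vector<int64_t>>;

// Greedy grouping of consecutive dims with equal products (static shapes
// only; assumes both shapes describe the same total number of elements).
static void matchReassociationGroups(const std::vector<int64_t> &in,
                                     const std::vector<int64_t> &out,
                                     Groups &inGroups, Groups &outGroups) {
  size_t i = 0, j = 0;
  while (i < in.size() && j < out.size()) {
    int64_t inProduct = in[i], outProduct = out[j];
    std::vector<int64_t> inGroup, outGroup;
    inGroup.push_back(i++);
    outGroup.push_back(j++);
    while (inProduct != outProduct) {
      if (inProduct < outProduct) {
        inProduct *= in[i];
        inGroup.push_back(i++);
      } else {
        outProduct *= out[j];
        outGroup.push_back(j++);
      }
    }
    inGroups.push_back(inGroup);
    outGroups.push_back(outGroup);
  }
}

int main() {
  Groups inGroups, outGroups;
  // aten.view of a [2, 3, 4, 5] tensor as [6, 20]:
  matchReassociationGroups({2, 3, 4, 5}, {6, 20}, inGroups, outGroups);
  assert((inGroups == Groups{{0, 1}, {2, 3}})); // 2*3 = 6, 4*5 = 20
  assert((outGroups == Groups{{0}, {1}}));
  return 0;
}
```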
int64_t inputDim = 0, outputDim = 0; - for (auto boundary : unchangedDims) { - // We assume dims specified by AtenSizeInt ops are unchanged - int64_t nextUnchangedInput = boundary[0]; - int64_t nextUnchangedOutput = boundary[1]; - - bool hasDynamic = false; + for (auto [nextUnchangedInput, nextUnchangedOutput] : unchangedDims) { + // Used for ensuring that we don't have an ambiguous expansion + bool assumedDynamicDimNotSplit = false; while (inputDim < nextUnchangedInput && outputDim < nextUnchangedOutput) { - - inputAssociations.emplace_back(); - outputAssociations.emplace_back(); - - // outputDim is next to the boundary - if (outputDim == nextUnchangedOutput - 1) { - - if (hasDynamic && inputDim != nextUnchangedInput - 1) { - return rewriter.notifyMatchFailure( - op, "found ambiguous collapse of dynamic input sizes (e.g. " - "[-1, -1, -1] -> [-1, -1])"); - } - outputAssociations.back().push_back(outputDim); - if (failed(collapseToSingleDimHelper( - op, rewriter, outputDim, nextUnchangedOutput, inputDim, - nextUnchangedInput, outputShape, inputShapeVec, - inputAssociations.back()))) - return failure(); - outputDim = nextUnchangedOutput; - inputDim = nextUnchangedInput; - continue; - } - - // inputDim is next to the boundary - if (inputDim == nextUnchangedInput - 1) { - - if (hasDynamic && inputShape[inputDim] == kUnknownSize) { - return rewriter.notifyMatchFailure( - op, "found ambiguous expand of dynamic sizes (e.g. [-1, -1] -> " - "[-1, -1, -1])"); - } - inputAssociations.back().push_back(inputDim); - if (failed(collapseToSingleDimHelper( - op, rewriter, inputDim, nextUnchangedInput, outputDim, - nextUnchangedOutput, inputShapeVec, outputShape, - outputAssociations.back()))) - return failure(); - - outputDim = nextUnchangedOutput; - inputDim = nextUnchangedInput; - continue; - } - - int64_t inputMatchingDimSize = inputShapeVec[inputDim]; - int64_t outputMatchingDimSize = outputShape[outputDim]; - - // If the input is dynamic, first assume it is not split - if (inputMatchingDimSize == kUnknownSize) { - - checkDimEqualHelper(rewriter, loc, inputShapeInt[inputDim], - outputShapeInt[outputDim]); - outputShape[outputDim] = kUnknownSize; - inputAssociations.back().push_back(inputDim++); - outputAssociations.back().push_back(outputDim++); - hasDynamic = true; - continue; + auto inputShapeSlice = + MutableArrayRef(inputShape) + .slice(inputDim, nextUnchangedInput - inputDim); + auto outputShapeSlice = + MutableArrayRef(outputShape) + .slice(outputDim, nextUnchangedOutput - outputDim); + SmallVector inputSliceIndices; + SmallVector outputSliceIndices; + + // TODO: this can be removed by replacing it with a checkDimEqualHelper + // that takes into account the product of all the dimensions being + // reduced + if (assumedDynamicDimNotSplit && inputShapeSlice.size() == 1 && + outputShapeSlice.size() != 1 && + inputShapeSlice[0] == kUnknownSize) { + return rewriter.notifyMatchFailure( + op, "found ambiguous expand of dynamic input sizes " + "(e.g. 
[-1, -1] -> [-1, -1, -1])"); } - // inputDim size is larger; try to collapse onto it - if (inputMatchingDimSize >= outputMatchingDimSize) { - - inputAssociations.back().push_back(inputDim); - if (failed(minimallyCollapseDimHelper( - op, rewriter, inputDim, nextUnchangedInput, outputDim, - nextUnchangedOutput, inputShapeVec, outputShape, - inputAssociations.back(), outputAssociations.back()))) { - return failure(); + if (succeeded(mapAllDimsToSingleDim(inputShapeSlice, outputShapeSlice, + inputSliceIndices, + outputSliceIndices))) { + calculateSingleDynamicSize(inputShapeSlice, outputShapeSlice); + // Update shape to pass the tensor.expand_shape and + // tensor.collapse_shape verifiers. If one of the dimensions of the + // tensor being flattened is dynamic, the size of the flattened tensor + // must also be dynamic. + if (inputShapeSlice.size() == 1 && + llvm::count(outputShapeSlice, kUnknownSize) > 0) { + inputShapeSlice[0] = kUnknownSize; + } else if (outputShapeSlice.size() == 1 && + llvm::count(inputShapeSlice, kUnknownSize) > 0) { + outputShapeSlice[0] = kUnknownSize; } - hasDynamic = false; - outputDim = outputAssociations.back().back() + 1; - inputDim = inputAssociations.back().back() + 1; - continue; + } else if (succeeded(mapStaticallyKnownDims( + inputShapeSlice, outputShapeSlice, inputSliceIndices, + outputSliceIndices))) { + /// `mapStaticallyKnownDims` maps the smallest number of + /// input and output dimensions in the slice statically + /// known to have the same number of elements. + } else if (inputShapeSlice[0] == kUnknownSize) { + // If the input is dynamic, assume it is not split + checkDimEqualHelper(rewriter, loc, inputSize[inputDim], + outputSizeInt[outputDim]); + // If output dimension is not dynamic, improve static information of + // input + inputShape[inputDim] = outputShape[outputDim]; + inputSliceIndices.push_back(0); + outputSliceIndices.push_back(0); + assumedDynamicDimNotSplit = true; + } else { + return rewriter.notifyMatchFailure( + op, "unimplemented: found unhandled case of expansion/collapse " + "in `aten.view`"); } - // outputDim is larger; try to collapse onto it - outputAssociations.back().push_back(outputDim); - if (failed(minimallyCollapseDimHelper( - op, rewriter, outputDim, nextUnchangedOutput, inputDim, - nextUnchangedInput, outputShape, inputShapeVec, - outputAssociations.back(), inputAssociations.back()))) { - - return failure(); - } - hasDynamic = false; + inputAssociations.emplace_back(); + outputAssociations.emplace_back(); + for (int64_t inputSliceIndex : inputSliceIndices) + inputAssociations.back().push_back(inputSliceIndex + inputDim); + for (int64_t outputSliceIndex : outputSliceIndices) + outputAssociations.back().push_back(outputSliceIndex + outputDim); inputDim = inputAssociations.back().back() + 1; outputDim = outputAssociations.back().back() + 1; - continue; } - if (inputDim != nextUnchangedInput) { - hasDynamic = true; - if (inputAssociations.size() < 1) { - inputAssociations.emplace_back(); - outputAssociations.emplace_back(); - } - inputAssociations.back().push_back(inputDim++); - outputAssociations.back().push_back(outputDim++); - continue; - } - - // Append the associations for the dims matching `aten.size.int` - if (nextUnchangedInput != inputRank && - nextUnchangedOutput != resultRank) { + // Handle any leading or trailing size-1 dimensions and append the + // associations for the dims matching `aten.size.int`. 
+ if (nextUnchangedInput != inputRank) { + assert(nextUnchangedOutput != resultRank && + "`nextUnchangedInput` and `nextUnchangedOutput` should equal " + "the respective input and output rank at the same time"); inputAssociations.emplace_back(); outputAssociations.emplace_back(); + } + while (inputDim <= nextUnchangedInput && inputDim < inputRank) { + if (inputDim != nextUnchangedInput && inputShape[inputDim] != 1) { + return rewriter.notifyMatchFailure( + op, "unimplemented: only collapsing of static size-1 into " + "unchanged dim supported"); + } inputAssociations.back().push_back(inputDim++); + } + while (outputDim <= nextUnchangedOutput && outputDim < resultRank) { + if (outputDim != nextUnchangedOutput && outputShape[outputDim] != 1) { + return rewriter.notifyMatchFailure( + op, "unimplemented: only expanding of static size-1 out of " + "unchanged dim supported"); + } outputAssociations.back().push_back(outputDim++); } } @@ -624,7 +537,7 @@ class ConvertAtenViewOp : public OpConversionPattern { Type adjustedResultType = RankedTensorType::get( makeShapeLLVMCompatible(outputShape), resultType.getElementType()); Type adjustedInputType = RankedTensorType::get( - makeShapeLLVMCompatible(inputShapeVec), resultType.getElementType()); + makeShapeLLVMCompatible(inputShape), resultType.getElementType()); Value castedInput = rewriter.create(loc, adjustedInputType, input); std::optional expandedInput; @@ -728,20 +641,23 @@ class ConvertAtenSqueezeOp : public OpConversionPattern { reassociation[0].push_back(headOnesCount++); } - // TODO: Add support for size-1 dynamic dimensions. Value one = rewriter.create( loc, rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); int64_t j = -1; + bool elideDynamicBroadcastDimCheck = + isAssumingStrictSymbolicShapes(rewriter); for (auto i : llvm::seq(headOnesCount, inputRank)) { if (inputType.isDynamicDim(i)) { - // Make sure that size-1 dynamic dimension does not exist. - Value dimSize = getDimOp(rewriter, loc, input, i); - Value dimSizeNotOne = rewriter.create( - loc, arith::CmpIPredicate::ne, dimSize, one); - rewriter.create( - loc, dimSizeNotOne, - rewriter.getStringAttr( - "unimplemented: size 1 dynamic dimension is not supported")); + if (!elideDynamicBroadcastDimCheck) { + // Make sure that size-1 dynamic dimension does not exist. + Value dimSize = getDimOp(rewriter, loc, input, i); + Value dimSizeNotOne = rewriter.create( + loc, arith::CmpIPredicate::ne, dimSize, one); + rewriter.create( + loc, dimSizeNotOne, + rewriter.getStringAttr( + "unimplemented: size 1 dynamic dimension is not supported")); + } ++j; } else if (inputType.getDimSize(i) != 1) { ++j; @@ -1179,31 +1095,35 @@ class ConvertAtenBroadcastToOp : public OpConversionPattern { // which in this case is `inShapeConverted` because this shape will yield // us the dimension size of the output. SmallVector useBroadcastToShape; - for (auto x : inShape) { + int64_t inputRank = self.getType().cast().getRank(); + for (size_t i = inShape.size() - inputRank, e = inShape.size(); i < e; + ++i) { int64_t dim; - if (!matchPattern(x, m_TorchConstantInt(&dim))) { - Operation *defOp = x.getDefiningOp(); - if (isa(defOp)) - useBroadcastToShape.push_back(true); - else + if (matchPattern(inShape[i], m_TorchConstantInt(&dim))) { + if (dim < 0) { useBroadcastToShape.push_back(false); + } else { + useBroadcastToShape.push_back(true); + } } else { - useBroadcastToShape.push_back(false); + // Note: Dynamic -1 (inferred) broadcast shapes are unimplemented. 
+ useBroadcastToShape.push_back(true); } } SmallVector inShapeConverted = getTypeConvertedValues( rewriter, op.getLoc(), getTypeConverter(), inShape); + auto newResultType = + getTypeConverter()->convertType(op.getType()).cast(); Value result; - if (failed(torch_to_linalg::broadcastToGivenShape(op, rewriter, self, - inShapeConverted, result, - useBroadcastToShape))) { + if (failed(torch_to_linalg::broadcastToGivenShape( + op, rewriter, self, inShapeConverted, newResultType, result, + useBroadcastToShape))) { return rewriter.notifyMatchFailure( op, "unable to perform broadcast operation"); } - Type newResultType = getTypeConverter()->convertType(op.getType()); - rewriter.replaceOpWithNewOp(op, newResultType, result); + rewriter.replaceOp(op, result); return success(); } }; @@ -1261,7 +1181,7 @@ class ConvertAtenCopyOp : public OpConversionPattern { selfSizes[i] = castIndexToInt64(rewriter, loc, selfSizes[i]); Value broadcastedSrc; if (failed(torch_to_linalg::broadcastToGivenShape( - op, rewriter, src, selfSizes, broadcastedSrc))) { + op, rewriter, src, selfSizes, selfType, broadcastedSrc))) { return rewriter.notifyMatchFailure( op, "unable to perform broadcast operation"); } diff --git a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp index cfbac2632a28..0e89d822669f 100644 --- a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp @@ -644,14 +644,16 @@ class ConvertAtenIndexTensorHackedTwinOp : public OpConversionPattern 1) { - Value cstStaticDimSize = getConstant(rewriter, loc, staticDimSize, - rewriter.getIndexType()); - auto equalToRunning = rewriter.create( - loc, arith::CmpIPredicate::eq, cstStaticDimSize, - dynamicDims[0]); - rewriter.create(loc, equalToRunning, - "mismatched size for broadcast"); + if (!isAssumingStrictSymbolicShapes(rewriter)) { + if (staticDimSize > 1) { + Value cstStaticDimSize = getConstant(rewriter, loc, staticDimSize, + rewriter.getIndexType()); + auto equalToRunning = rewriter.create( + loc, arith::CmpIPredicate::eq, cstStaticDimSize, + dynamicDims[0]); + rewriter.create(loc, equalToRunning, + "mismatched size for broadcast"); + } } broadcastedIndexShape.push_back(dynamicDims[0]); } else { diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index 23528bb01f80..bbf53162d6a1 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -58,15 +58,18 @@ class ConvertAtenMmOp : public OpConversionPattern { } Value lhsDim0 = rewriter.create(loc, lhs, 0); - Value lhsDim1 = rewriter.create(loc, lhs, 1); - Value rhsDim0 = rewriter.create(loc, rhs, 0); Value rhsDim1 = rewriter.create(loc, rhs, 1); - Value contractingDimEqual = rewriter.create( - loc, arith::CmpIPredicate::eq, lhsDim1, rhsDim0); - rewriter.create( - loc, contractingDimEqual, - rewriter.getStringAttr( - "mismatching contracting dimension for torch.aten.mm")); + + if (!isAssumingStrictSymbolicShapes(rewriter)) { + Value lhsDim1 = rewriter.create(loc, lhs, 1); + Value rhsDim0 = rewriter.create(loc, rhs, 0); + Value contractingDimEqual = rewriter.create( + loc, arith::CmpIPredicate::eq, lhsDim1, rhsDim0); + rewriter.create( + loc, contractingDimEqual, + rewriter.getStringAttr( + "mismatching contracting dimension for torch.aten.mm")); + } Type newResultType = getTypeConverter()->convertType(op.getType()); Type elementType = newResultType.cast().getElementType(); @@ -292,13 +295,24 @@ class 
ConvertAtenMatmulOp : public OpConversionPattern { // Broadcast the batch dimensions of both the matrices. Value broadcastedLhs, broadcastedRhs; + // TODO: Improve usage of static shape information. + SmallVector lhsTargetShape(lhsBroadcastToShape.size(), + ShapedType::kDynamic); + auto lhsBroadcastType = + RankedTensorType::get(lhsTargetShape, lhsType.getElementType()); if (failed(torch_to_linalg::broadcastToGivenShape( - op, rewriter, lhs, lhsBroadcastToShape, broadcastedLhs))) { + op, rewriter, lhs, lhsBroadcastToShape, lhsBroadcastType, + broadcastedLhs))) { return rewriter.notifyMatchFailure( op, "unable to perform broadcast operation"); } + SmallVector rhsTargetShape(rhsBroadcastToShape.size(), + ShapedType::kDynamic); + auto rhsBroadcastType = + RankedTensorType::get(rhsTargetShape, rhsType.getElementType()); if (failed(torch_to_linalg::broadcastToGivenShape( - op, rewriter, rhs, rhsBroadcastToShape, broadcastedRhs))) { + op, rewriter, rhs, rhsBroadcastToShape, rhsBroadcastType, + broadcastedRhs))) { return rewriter.notifyMatchFailure( op, "unable to perform broadcast operation"); } diff --git a/lib/Conversion/TorchToLinalg/Reduction.cpp b/lib/Conversion/TorchToLinalg/Reduction.cpp index 4078fbaa342c..641f1ef8cc1c 100644 --- a/lib/Conversion/TorchToLinalg/Reduction.cpp +++ b/lib/Conversion/TorchToLinalg/Reduction.cpp @@ -176,8 +176,8 @@ class ConvertAtenMaxDimOp : public OpConversionPattern { Value resultMax, predicate; if (inElementType.isa()) { - resultMax = - rewriter.create(nestedLoc, newValue, oldValue); + resultMax = rewriter.create(nestedLoc, newValue, + oldValue); predicate = rewriter.create( nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue); } else { @@ -280,7 +280,7 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc, convertScalarToDtype(b, loc, payloadArgs[0], resultElementType); Value result = payloadArgs[1]; if (resultElementType.isa()) - return b.create(loc, self, result); + return b.create(loc, self, result); else if (resultElementType.isa()) { IntegerType intType = max.getSelf() .getType() @@ -297,7 +297,7 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc, convertScalarToDtype(b, loc, payloadArgs[0], resultElementType); Value result = payloadArgs[1]; if (resultElementType.isa()) - return b.create(loc, self, result); + return b.create(loc, self, result); else if (resultElementType.isa()) { IntegerType intType = min.getSelf() .getType() diff --git a/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp b/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp index 262d3cf62e54..a1e8e5fb72d9 100644 --- a/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp +++ b/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp @@ -42,7 +42,9 @@ class ConvertAtenSizeIntOp : public OpConversionPattern { Value inputRank = rewriter.create( loc, rewriter.getI64IntegerAttr(type.getRank())); Value dimPositive = toPositiveDimDynamic(rewriter, loc, dim, inputRank); - assertIsValidDim(rewriter, loc, dimPositive, inputRank); + if (!isAssumingStrictSymbolicShapes(rewriter)) { + assertIsValidDim(rewriter, loc, dimPositive, inputRank); + } Value size = rewriter.create( loc, adaptor.getSelf(), castIntToIndex(rewriter, loc, dimPositive)); rewriter.replaceOp(op, castIndexToInt64(rewriter, loc, size)); diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 1d25d22720d2..9c862e410994 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ 
b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -235,6 +235,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return createCalculationForMathOpWithDtypeConversion( b, converter, payloadArgs[0], op); } + if (isa(op)) { + return createCalculationForMathOpWithDtypeConversion( + b, converter, payloadArgs[0], op); + } if (isa(op)) { return createCalculationForMathOpWithDtypeConversion( b, converter, payloadArgs[0], op); @@ -296,6 +300,25 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); return b.create(loc, lhs, rhs); } + if (auto bitwiseAndScalar = dyn_cast(op)) { + Type dtype = converter->convertType(bitwiseAndScalar.getType()) + .cast() + .getElementType(); + if (!dtype.isa()) { + bitwiseAndScalar.emitError( + "bitwise_and.Scalar does not support non-integer input dtype."); + return nullptr; + } + Type resultElementType = + bitwiseAndScalar.getType().cast().getDtype(); + Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype, + /*srcOriginalDtype=*/std::nullopt, + /*dstOriginalDtype=*/resultElementType); + Value other = convertScalarToDtype(b, loc, operands[1], dtype, + /*srcOriginalDtype=*/std::nullopt, + /*dstOriginalDtype=*/resultElementType); + return b.create(loc, self, other); + } if (auto bitwiseOrTensor = dyn_cast(op)) { if (bitwiseOrTensor.getType() .cast() @@ -328,6 +351,20 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); return b.create(loc, lhs, rhs); } + if (auto bitwiseRightShiftTensor = + dyn_cast(op)) { + Type dtype = converter->convertType(bitwiseRightShiftTensor.getType()) + .cast() + .getElementType(); + if (!dtype.isa()) { + bitwiseRightShiftTensor.emitError( + "Bitwise_Right_Shift op does not support non-integer input dtype."); + return nullptr; + } + Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); + Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); + return b.create(loc, lhs, rhs); + } if (isa(op)) { MLIRContext *context = op->getContext(); Type floatDtype = mlir::FloatType::getF64(context); @@ -511,9 +548,16 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Type dtype = converter->convertType(sub.getType()) .cast() .getElementType(); - Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); - Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); - Value alpha = convertScalarToDtype(b, loc, adaptor.getAlpha(), dtype); + Type resultElementType = sub.getType().cast().getDtype(); + Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype, + /*srcOriginalDtype=*/std::nullopt, + /*dstOriginalDtype=*/resultElementType); + Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype, + /*srcOriginalDtype=*/std::nullopt, + /*dstOriginalDtype=*/resultElementType); + Value alpha = convertScalarToDtype(b, loc, adaptor.getAlpha(), dtype, + /*srcOriginalDtype=*/std::nullopt, + /*dstOriginalDtype=*/resultElementType); if (dtype.isa()) { Value scaled = b.create(loc, rhs, alpha); return b.create(loc, lhs, scaled); @@ -544,9 +588,17 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Type dtype = converter->convertType(addScalar.getType()) .cast() .getElementType(); - Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype); - Value other = convertScalarToDtype(b, loc, operands[1], dtype); - Value alpha = convertScalarToDtype(b, loc, operands[2], dtype); + Type resultElementType = + 
addScalar.getType().cast().getDtype(); + Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype, + /*srcOriginalDtype=*/std::nullopt, + /*dstOriginalDtype=*/resultElementType); + Value other = convertScalarToDtype(b, loc, operands[1], dtype, + /*srcOriginalDtype=*/std::nullopt, + /*dstOriginalDtype=*/resultElementType); + Value alpha = convertScalarToDtype(b, loc, operands[2], dtype, + /*srcOriginalDtype=*/std::nullopt, + /*dstOriginalDtype=*/resultElementType); if (dtype.isa()) { Value mult = b.create(loc, other, alpha); return b.create(loc, self, mult); @@ -567,7 +619,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); if (dtype.isa()) { return b.create(loc, lhs, rhs); - } else if(dtype.isa()) { + } else if (dtype.isa()) { return b.create(loc, lhs, rhs); } else { return b.create(loc, lhs, rhs); @@ -984,7 +1036,23 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Type dtype = converter->convertType(atenToDtype.getType()) .cast() .getElementType(); - Value result = convertScalarToDtype(b, loc, input, dtype); + Type resultElementType; + int64_t dtypeInt; + if (!matchPattern(atenToDtype.getDtype(), m_TorchConstantInt(&dtypeInt))) { + atenToDtype.emitError("unimplemented: dtype must be a constant integer"); + return nullptr; + } + FailureOr maybeResultElementType = getTypeForScalarType( + atenToDtype->getContext(), (torch_upstream::ScalarType)dtypeInt, + IntegerType::Signless); + if (failed(maybeResultElementType)) { + atenToDtype.emitError("unable to convert `dtypeInt` to builtin type"); + return nullptr; + } + resultElementType = *maybeResultElementType; + Value result = convertScalarToDtype(b, loc, input, dtype, + /*srcOriginalDtype=*/std::nullopt, + /*dstOriginalDtype=*/resultElementType); return result; } if (auto divScalar = dyn_cast(op)) { @@ -1046,7 +1114,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp( .getElementType(); Value self = payloadArgs[0]; - Value threshold = convertScalarToDtype(b, loc, adaptor.getThreshold(), dtype); + Value threshold = + convertScalarToDtype(b, loc, adaptor.getThreshold(), dtype); Value value = convertScalarToDtype(b, loc, adaptor.getValue(), dtype); Value predicate; @@ -1068,7 +1137,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Value grad = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value self = convertScalarToDtype(b, loc, payloadArgs[1], dtype); - Value threshold = convertScalarToDtype(b, loc, adaptor.getThreshold(), dtype); + Value threshold = + convertScalarToDtype(b, loc, adaptor.getThreshold(), dtype); Value constantZero = b.create(loc, b.getZeroAttr(dtype)); Value predicate; @@ -1177,10 +1247,11 @@ class ConvertElementwiseOp : public ConversionPattern { AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp, AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenPowScalarOp, AtenPowTensorScalarOp, - AtenPowTensorTensorOp, AtenLog2Op, AtenLog1pOp, AtenRsqrtOp, - AtenDivScalarOp, AtenRemainderScalarOp, AtenAbsOp, - AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseOrTensorOp, - AtenBitwiseXorTensorOp, AtenGtScalarOp, AtenGeScalarOp, + AtenPowTensorTensorOp, AtenLog2Op, AtenLog10Op, AtenLog1pOp, + AtenRsqrtOp, AtenDivScalarOp, AtenRemainderScalarOp, AtenAbsOp, + AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp, + AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp, + AtenBitwiseRightShiftTensorOp, AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, 
AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp, AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp, AtenSubScalarOp, @@ -1477,10 +1548,12 @@ class ConvertAtenBatchNormOp : public OpConversionPattern { rewriter.getStringAttr( "expect the size of dim 0 equal to the number of features")); }; - contractingDim0EqualsNumFeatures(weight); - contractingDim0EqualsNumFeatures(bias); - contractingDim0EqualsNumFeatures(runningMean); - contractingDim0EqualsNumFeatures(runningVar); + if (!isAssumingStrictSymbolicShapes(rewriter)) { + contractingDim0EqualsNumFeatures(weight); + contractingDim0EqualsNumFeatures(bias); + contractingDim0EqualsNumFeatures(runningMean); + contractingDim0EqualsNumFeatures(runningVar); + } auto indexingMap = AffineMap::get( /*dimCount=*/inputRank, @@ -1677,7 +1750,8 @@ class ConvertAtenDetachOp : public OpConversionPattern { return failure(); Type resultType = getTypeConverter()->convertType(op.getType()); - rewriter.replaceOpWithNewOp(op, resultType, adaptor.getSelf()); + rewriter.replaceOpWithNewOp(op, resultType, + adaptor.getSelf()); return success(); } }; @@ -1712,17 +1786,18 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality( AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp, AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenCeilOp, AtenPreluOp, AtenPowScalarOp, AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op, - AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp, - AtenBitwiseAndTensorOp, AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp, - AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, - AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp, AtenGeTensorOp, - AtenEqTensorOp, AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp, - AtenThresholdOp, AtenThresholdBackwardOp, AtenHardtanhBackwardOp, - AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillTensorOp, - AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp, AtenLogicalXorOp, - AtenLogicalNotOp, AtenTriuOp, AtenTrilOp, AtenRemainderScalarOp, - AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp, - AtenRealOp, AtenImagOp>(); + AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp, + AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp, + AtenBitwiseXorTensorOp, AtenBitwiseRightShiftTensorOp, AtenGtScalarOp, + AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, + AtenWhereSelfOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp, + AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp, AtenThresholdOp, + AtenThresholdBackwardOp, AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, + AtenCosOp, AtenNeScalarOp, AtenMaskedFillTensorOp, AtenLogicalOrOp, + AtenLogicalAndOp, AtenAtanOp, AtenLogicalXorOp, AtenLogicalNotOp, + AtenTriuOp, AtenTrilOp, AtenRemainderScalarOp, AtenBitwiseNotOp, + AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp, AtenRealOp, + AtenImagOp>(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp index 42c5d0b441cc..a666ca30b02f 100644 --- a/lib/Conversion/TorchToLinalg/Utils.cpp +++ b/lib/Conversion/TorchToLinalg/Utils.cpp @@ -231,7 +231,8 @@ Value torch_to_linalg::createElementwiseLinalgGeneric( // if this is the first tensor operand that didn't continue above: // take its dimension size as the size of the non-broadcasted // traversal along this dimension (this may include a dynamic size-1, - // 
**non-broadcasted** traversal!) + // **non-broadcasted** traversal unless if + // isAssumingStrictSymbolicShapes!) // emit error check "if the size does not match the non-broadcasted // traversal size along this dimension, error" // ``` @@ -251,6 +252,7 @@ Value torch_to_linalg::createElementwiseLinalgGeneric( auto c1 = b.create(loc, /*value=*/1); SmallVector resultShape(resultRank, c1); SmallVector indexingMaps; + bool elideDynamicBroadcastCheck = isAssumingStrictSymbolicShapes(b); for (Value tensorOperand : tensorOperands) { SmallVector exprs; auto type = tensorOperand.getType().cast(); @@ -294,11 +296,13 @@ Value torch_to_linalg::createElementwiseLinalgGeneric( // This is the check which protects against the undefined behavior of // the generated linalg op in the case of iterating two operands with // dimensions sizes that are expected to match. - auto equalToRunning = - b.create(loc, arith::CmpIPredicate::eq, - resultShape[resultDim], currentDimSize); - b.create(loc, equalToRunning, - "mismatched size for broadcast"); + if (!elideDynamicBroadcastCheck) { + auto equalToRunning = + b.create(loc, arith::CmpIPredicate::eq, + resultShape[resultDim], currentDimSize); + b.create(loc, equalToRunning, + "mismatched size for broadcast"); + } } indexingMaps.push_back(AffineMap::get( /*dimCount=*/resultRank, /*symbolCount=*/0, exprs, b.getContext())); @@ -323,12 +327,15 @@ Value torch_to_linalg::createElementwiseLinalgGeneric( // Broadcasts input tensor based on the broadcastToShape. LogicalResult torch_to_linalg::broadcastToGivenShape( Operation *op, PatternRewriter &rewriter, Value input, - SmallVector broadcastToShape, Value &result, - SmallVector useBroadcastToShape) { + SmallVector broadcastToShape, RankedTensorType broadcastType, + Value &result, SmallVector useBroadcastToShape) { RankedTensorType inputType = input.getType().cast(); + int64_t inputRank = inputType.getRank(); + int64_t outputRank = broadcastToShape.size(); + ArrayRef outputShape = broadcastType.getShape(); SmallVector inputShape = makeShapeTorchCompatible(inputType.getShape()); - if (broadcastToShape.size() < inputShape.size()) { + if (outputRank < inputRank) { return rewriter.notifyMatchFailure( op, "invalid shape: broadcastToShape size must not be smaller than the " "size of the input shape"); @@ -336,7 +343,11 @@ LogicalResult torch_to_linalg::broadcastToGivenShape( Type elementType = inputType.getElementType(); Location loc = op->getLoc(); - SmallVector outShape; + SmallVector outShape; + bool elideDynamicBroadcastCheck = isAssumingStrictSymbolicShapes(rewriter); + + // Vector indicating broadcasted status when assuming strict symbolic shapes. + SmallVector broadcastedStatus; // Create affine map and shapes for tensor initialization. SmallVector outExpr; @@ -346,17 +357,48 @@ LogicalResult torch_to_linalg::broadcastToGivenShape( rewriter.create(loc, rewriter.getIndexAttr(0)); Value oneIndex = rewriter.create(loc, rewriter.getIndexAttr(1)); - size_t diff = broadcastToShape.size() - inputShape.size(); - for (size_t i = 0; i < broadcastToShape.size(); i++) { + size_t diff = outputRank - inputRank; + bool hasDynamicNumpyBroadcast = false; + for (size_t i = 0, e = outputRank; i < e; i++) { Value shapeValue = broadcastToShape[i]; size_t j = i - diff; + bool isDynamic = i >= diff && inputShape[j] == kUnknownSize; + + // Inherit static output shapes if present. 
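+      // e.g. for a result type of tensor<4x?xf32>, dim 0 becomes the constant
+      // index attribute 4 instead of a value recomputed from broadcastToShape.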
+ if (outputShape[i] != ShapedType::kDynamic) { + outShape.push_back(rewriter.getIndexAttr(outputShape[i])); + if (i < diff) { + if (outputShape[i] < 0) { + return rewriter.notifyMatchFailure( + op, "invalid shape: negative values not allowed in new broadcast " + "dimensions"); + } + continue; + } + if (isDynamic) { + hasDynamicNumpyBroadcast = true; + } else if (inputShape[j] != outputShape[i] && inputShape[j] != 1) { + return rewriter.notifyMatchFailure( + op, "invalid shape: static mismatch in input and output broadcast " + "shapes"); + } + + // If strict symbolic shapes are assumed and the input shape is dynamic, + // we can assume that dim is not broadcasted. + broadcastedStatus.push_back(inputShape[j] != outputShape[i] && + !isDynamic); + continue; + } + if (i < diff) { - Value isValid = rewriter.create( - loc, arith::CmpIPredicate::sge, shapeValue, zero); - rewriter.create( - loc, isValid, - rewriter.getStringAttr( - "negative values not allowed in new dimensions")); + if (!elideDynamicBroadcastCheck) { + Value isValid = rewriter.create( + loc, arith::CmpIPredicate::sge, shapeValue, zero); + rewriter.create( + loc, isValid, + rewriter.getStringAttr( + "negative values not allowed in new dimensions")); + } outShape.push_back(castIntToIndex(rewriter, loc, shapeValue)); continue; } @@ -367,24 +409,80 @@ LogicalResult torch_to_linalg::broadcastToGivenShape( Value select = rewriter.create( loc, isNegative, oneIndex, castIntToIndex(rewriter, loc, shapeValue)); outShape.push_back(select); - } else { - // Case of dynamic input dimension wherein the shape to broadcast will - // yield us the dimension size of the output. - Value dim = getDimOp(rewriter, loc, input, j); - if (!useBroadcastToShape.empty()) { - if (useBroadcastToShape[i]) - dim = castIntToIndex(rewriter, loc, broadcastToShape[j]); + broadcastedStatus.push_back(true); + continue; + } + + // Case of dynamic input dimension wherein the shape to broadcast will + // yield us the dimension size of the output. + Value dim; + if (!useBroadcastToShape.empty() && useBroadcastToShape[j]) { + dim = castIntToIndex(rewriter, loc, broadcastToShape[i]); + if (isDynamic) { + hasDynamicNumpyBroadcast = true; + } + if (!elideDynamicBroadcastCheck) { + Value isValid = rewriter.create( + loc, arith::CmpIPredicate::sge, shapeValue, zero); + rewriter.create( + loc, isValid, + rewriter.getStringAttr( + "unimplemented: dynamic negative broadcast sizes")); } - outShape.push_back(dim); + } else { + dim = getDimOp(rewriter, loc, input, j); } + // We can safely assume this dimension is not broadcasted with strict + // symbols. + broadcastedStatus.push_back(false); + outShape.push_back(dim); } - Value outTensor = rewriter.create( - loc, getAsOpFoldResult(outShape), elementType); + Value outTensor = + rewriter.create(loc, outShape, elementType); + + // If we know there are no ? -> ? broadcasted dims, or we are assuming + // strict symbols, we can safely use standard linalg style broadcasting + // semantics. + if (!hasDynamicNumpyBroadcast || elideDynamicBroadcastCheck) { + // If no dims are broadcasted and the rank doesn't change, we can just fold + // the op away entirely. 
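+    // Folding here means casting the input directly to the result tensor type
+    // instead of emitting a linalg.generic copy.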
+ if (!llvm::any_of(broadcastedStatus, [](bool b) { return b; }) && + inputRank == outputRank) { + result = rewriter.create(loc, outTensor.getType(), input); + return success(); + } + + SmallVector inputExprs; + for (int64_t i = 0, e = inputRank; i < e; ++i) { + if (broadcastedStatus[i]) { + inputExprs.push_back(rewriter.getAffineConstantExpr(0)); + continue; + } + inputExprs.push_back(rewriter.getAffineDimExpr(i + diff)); + } + + SmallVector indexingMaps = { + AffineMap::get(outputRank, 0, inputExprs, rewriter.getContext()), + rewriter.getMultiDimIdentityMap(outputRank)}; + SmallVector iteratorTypes( + outputRank, utils::IteratorType::parallel); + result = rewriter + .create( + loc, outTensor.getType(), input, outTensor, indexingMaps, + iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + b.create(loc, args[0]); + }) + .getResult(0); + return success(); + } + // Fall back to numpy-style dynamic broadcasting in the form of a single + // linalg op. SmallVector indexingMaps = { - rewriter.getMultiDimIdentityMap(broadcastToShape.size())}; - SmallVector iteratorTypes(broadcastToShape.size(), + rewriter.getMultiDimIdentityMap(outputRank)}; + SmallVector iteratorTypes(outputRank, utils::IteratorType::parallel); result = rewriter .create( @@ -395,7 +493,7 @@ LogicalResult torch_to_linalg::broadcastToGivenShape( // would be used to extract values from the input tensor // later on. SmallVector loopIndices; - for (size_t i = 0; i < broadcastToShape.size(); ++i) { + for (size_t i = 0, e = outputRank; i < e; ++i) { if (i < diff) continue; loopIndices.push_back(b.create(loc, i)); @@ -404,7 +502,7 @@ LogicalResult torch_to_linalg::broadcastToGivenShape( // the i-th input dimension is not 1, else it contains a // zero index. SmallVector inputIndicesToExtract; - for (size_t i = 0, n = inputShape.size(); i < n; i++) { + for (size_t i = 0, n = inputRank; i < n; i++) { if (inputShape[i] == 1) { inputIndicesToExtract.push_back(zeroIndex); } else { diff --git a/lib/Conversion/TorchToLinalg/Utils.h b/lib/Conversion/TorchToLinalg/Utils.h index 354012028b01..3bee8d642533 100644 --- a/lib/Conversion/TorchToLinalg/Utils.h +++ b/lib/Conversion/TorchToLinalg/Utils.h @@ -73,10 +73,12 @@ Value createElementwiseLinalgGeneric( function_ref bodyBuild); // Broadcasts input tensor based on the broadcastToShape. -LogicalResult -broadcastToGivenShape(Operation *op, PatternRewriter &rewriter, Value input, - SmallVector broadcastToShape, Value &result, - SmallVector useBroadcastToShape = {}); +LogicalResult broadcastToGivenShape(Operation *op, PatternRewriter &rewriter, + Value input, + SmallVector broadcastToShape, + RankedTensorType broadcastType, + Value &result, + SmallVector useBroadcastToShape = {}); // Cast a tensor to a rank-equivalent tensor of unknown size, i.e. <1x2xf32> -> // diff --git a/lib/Conversion/TorchToSCF/TorchToSCF.cpp b/lib/Conversion/TorchToSCF/TorchToSCF.cpp index 146959151240..96e14f0fdd6e 100644 --- a/lib/Conversion/TorchToSCF/TorchToSCF.cpp +++ b/lib/Conversion/TorchToSCF/TorchToSCF.cpp @@ -237,17 +237,17 @@ class ConvertTorchPrimLoopForLikeOp : public OpConversionPattern { SmallVector regionArgTypes; SmallVector regionArgLocs; - for (Value value : scfForOp.getLoopBody().front().getArguments()) { + for (Value value : scfForOp.getRegion().front().getArguments()) { regionArgTypes.push_back(value.getType()); regionArgLocs.push_back(value.getLoc()); } // Populate the loop body region. 
- if (!scfForOp.getLoopBody().empty()) - rewriter.eraseBlock(&scfForOp.getLoopBody().back()); + if (!scfForOp.getRegion().empty()) + rewriter.eraseBlock(&scfForOp.getRegion().back()); - auto *block = rewriter.createBlock(&scfForOp.getLoopBody(), - scfForOp.getLoopBody().begin(), + auto *block = rewriter.createBlock(&scfForOp.getRegion(), + scfForOp.getRegion().begin(), regionArgTypes, regionArgLocs); // Rewrite uses of the torch loop block arguments to the new for-loop diff --git a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp index ff27287649ad..d11a5524af7d 100644 --- a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp +++ b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp @@ -1332,7 +1332,7 @@ class ConvertAtenScatterReduceTwoOp if (update.getType().isa()) { result = b.create(loc, update, current); } else if (update.getType().isa()) { - result = b.create(loc, update, current); + result = b.create(loc, update, current); } else { llvm_unreachable("Only integer/float types supported!"); } @@ -1340,7 +1340,7 @@ class ConvertAtenScatterReduceTwoOp if (update.getType().isa()) { result = b.create(loc, update, current); } else if (update.getType().isa()) { - result = b.create(loc, update, current); + result = b.create(loc, update, current); } else { llvm_unreachable("Only integer/float types supported!"); } diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 31e8292452a9..c2c73708d79d 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -121,8 +121,8 @@ static bool isInValidRange(bool isFloat, const double &doubleValue, bool isInt, return (doubleValue == static_cast(static_cast(doubleValue))); } else { assert(isInt); - return (intValue >= std::numeric_limits::min()) && - (intValue <= std::numeric_limits::max()); + return (intValue >= static_cast(std::numeric_limits::min())) && + (intValue <= static_cast(std::numeric_limits::max())); } return true; } @@ -145,8 +145,9 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter, "Unable to extract the scalar constant"); if (dtype.isa()) { - tosaTensor = tosa::getConstTensor( - rewriter, op, (isFloat ? doubleValue : intValue), dshape, dtype) + tosaTensor = tosa::getConstTensor(rewriter, op, + (isFloat ? doubleValue : intValue), + dshape, dtype) .value(); } else if (auto intType = dtype.dyn_cast()) { auto w = intType.getWidth(); @@ -162,7 +163,7 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter, "of destination type"); } bool d = isFloat ? 
static_cast(doubleValue) - : static_cast(intValue); + : static_cast(intValue); tosaTensor = tosa::getConstTensor(rewriter, op, {d}, dshape).value(); } else if (w == 32) { @@ -616,7 +617,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op, "Negative slope needs to be a scalar constant for conversion to " "TOSA LeakyReLU operation"); - auto zero = tosa::getConstTensor(rewriter, op, 0, {}, selfTy.getElementType()).value(); + auto zero = + tosa::getConstTensor(rewriter, op, 0, {}, selfTy.getElementType()) + .value(); auto cond = rewriter.create( op->getLoc(), RankedTensorType::get(selfTy.getShape(), rewriter.getIntegerType(1)), @@ -2253,11 +2256,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (!matchPattern(op.getEps(), m_TorchConstantFloat(&eps))) return rewriter.notifyMatchFailure(op, "eps must be a scalar constant"); - auto epsilonConst = - tosa::getConstTensor(rewriter, op.getOperation(), - {static_cast(eps)}, {}, - meanType.getElementType()) - .value(); + auto epsilonConst = tosa::getConstTensor(rewriter, op.getOperation(), + {static_cast(eps)}, {}, + meanType.getElementType()) + .value(); auto batchNorm = computeBatchNorm(op, rewriter, outType, adaptor.getInput(), varianceVal, @@ -2523,6 +2525,60 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenUnflattenIntOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + + // Not a ranked tensor type + auto selfType = adaptor.getSelf().getType().dyn_cast(); + if (!selfType || !selfType.hasStaticShape()) + return rewriter.notifyMatchFailure( + op, + "Only ranked tensor types with static shapes are currently supported"); + + int64_t selfRank = selfType.getRank(); + int64_t dim; + + if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) + return rewriter.notifyMatchFailure(op, "dim must be a Scalar constant"); + + SmallVector sizes; + if (!matchPattern(op.getSizes(), m_TorchListOfConstantInts(sizes))) + return rewriter.notifyMatchFailure( + op, "Only constant sizes are currently supported"); + + if (selfRank > 0 && !isValidDim(dim, selfRank)) + return rewriter.notifyMatchFailure(op, "dim is statically invalid"); + + SmallVector newShape; + for (auto s : + llvm::enumerate(makeShapeTorchCompatible(selfType.getShape()))) { + int64_t idx = s.index(); + if (idx < dim || idx > dim) { + newShape.push_back(s.value()); + } else { + auto sum = 1; + for (auto newDims : sizes) { + newShape.push_back(newDims); + sum *= newDims; + } + if (sum != s.value()) + return rewriter.notifyMatchFailure(op, + "sizes mismatch with original dim"); + } + } + + auto newType = RankedTensorType::get(makeShapeLLVMCompatible(newShape), + selfType.getElementType()); + + rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(newType), adaptor.getSelf(), + rewriter.getDenseI64ArrayAttr(newShape)); + + return success(); +} + template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenPermuteOp op, OpAdaptor adaptor, @@ -2571,7 +2627,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Constant value of ln2. 
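+    // ln 2 ~= 0.693147; the lowering below rebuilds log2(x) as
+    // log(x) * reciprocal(ln 2).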
SmallVector ln2Shape(selfType.getRank(), 1); - auto ln2Op = tosa::getConstTensor(rewriter, op, {0.69314718056}, + auto ln2Op = tosa::getConstTensor(rewriter, op, {0.69314718056f}, ln2Shape, selfType.getElementType()) .value(); auto rcpOp = @@ -2802,21 +2858,25 @@ static Value approximateErfOp(ConversionPatternRewriter &rewriter, auto zero = tosa::getConstTensor(rewriter, op, 0, {}, dtype).value(); auto one = tosa::getConstTensor(rewriter, op, 1, {}, dtype).value(); - auto a1 = tosa::getConstTensor(rewriter, op, 0.278393, {}, dtype).value(); + auto a1 = + tosa::getConstTensor(rewriter, op, 0.278393f, {}, dtype).value(); auto a1X = rewriter.create(loc, outType, a1, absX, /*shift=*/0); auto sum = rewriter.create(loc, outType, a1X, one); - auto a2 = tosa::getConstTensor(rewriter, op, 0.230389, {}, dtype).value(); + auto a2 = + tosa::getConstTensor(rewriter, op, 0.230389f, {}, dtype).value(); auto x2 = rewriter.create(loc, outType, absX, absX, /*shift=*/0); auto a2X = rewriter.create(loc, outType, a2, x2, /*shift=*/0); sum = rewriter.create(loc, outType, sum, a2X); - auto a3 = tosa::getConstTensor(rewriter, op, 0.000972, {}, dtype).value(); + auto a3 = + tosa::getConstTensor(rewriter, op, 0.000972f, {}, dtype).value(); auto x3 = rewriter.create(loc, outType, x2, absX, /*shift=*/0); auto a3X = rewriter.create(loc, outType, a3, x3, /*shift=*/0); sum = rewriter.create(loc, outType, sum, a3X); - auto a4 = tosa::getConstTensor(rewriter, op, 0.078108, {}, dtype).value(); + auto a4 = + tosa::getConstTensor(rewriter, op, 0.078108f, {}, dtype).value(); auto x4 = rewriter.create(loc, outType, x3, absX, /*shift=*/0); auto a4X = rewriter.create(loc, outType, a4, x4, /*shift=*/0); sum = rewriter.create(loc, outType, sum, a4X); @@ -2851,13 +2911,14 @@ static Value buildUnitNormalCdf(ConversionPatternRewriter &rewriter, Value xMinusMean = rewriter.create(loc, outType, x, mean); // rsqrt of 2 Value rsqrt2 = - tosa::getConstTensor(rewriter, op, 0.70710678, {}, dtype).value(); + tosa::getConstTensor(rewriter, op, 0.70710678f, {}, dtype).value(); Value erfArg = rewriter.create(loc, outType, xMinusMean, rsqrt2, /*shift=*/0); Value erf = approximateErfOp(rewriter, op, erfArg, dtype); Value erfPlus1 = rewriter.create(loc, outType, one, erf); - Value oneHalf = tosa::getConstTensor(rewriter, op, 0.5, {}, dtype).value(); + Value oneHalf = + tosa::getConstTensor(rewriter, op, 0.5, {}, dtype).value(); Value normalCdf = rewriter.create(loc, outType, oneHalf, erfPlus1, /*shift=*/0); @@ -2891,7 +2952,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Value cdf = buildUnitNormalCdf(rewriter, op, adaptor.getSelf(), selfElemTy); cdf = rewriter.createOrFold( - op->getLoc(), cast(cdf.getType()).cloneWith({}, selfElemTy), cdf); + op->getLoc(), + cast(cdf.getType()).cloneWith({}, selfElemTy), cdf); rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(), cdf, @@ -2927,15 +2989,16 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto loc = op->getLoc(); - const double cstAlpha0 = 1.12837916709551257390; - const double cstAlpha1 = 0.70710678118654752440; - const double oneHalf = 0.5; - const double kAlpha = cstAlpha0 * cstAlpha1; + const float cstAlpha0 = 1.12837916709551257390f; + const float cstAlpha1 = 0.70710678118654752440f; + const float oneHalf = 0.5f; + const float kAlpha = cstAlpha0 * cstAlpha1; - Value kAlphaHalf = - tosa::getConstTensor(rewriter, op, kAlpha * oneHalf, {}, selfElemTy).value(); + Value kAlphaHalf = tosa::getConstTensor(rewriter, op, kAlpha * oneHalf, + {}, 
selfElemTy) + .value(); Value negOneHalf = - tosa::getConstTensor(rewriter, op, -0.5, {}, selfElemTy).value(); + tosa::getConstTensor(rewriter, op, -0.5f, {}, selfElemTy).value(); Value inputSquared = rewriter.create( loc, selfType, adaptor.getSelf(), adaptor.getSelf(), /*shift=*/0); Value negHalfInputSquared = rewriter.create( @@ -3006,7 +3069,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure(op, "Only scalar constant is supported"); } - Value replace = tosa::getConstTensor(rewriter, op, 0, {}, selfElemTy).value(); + Value replace = + tosa::getConstTensor(rewriter, op, 0, {}, selfElemTy).value(); Type outType = getTypeConverter()->convertType(op.getType()); Value lesser = rewriter.create( @@ -3553,7 +3617,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // convert None to [0,0,0] auto indexNext = indexTensors[i + 1]; auto indexNextTorch = tensorsTorchType[i + 1]; - if (indexNextTorch.getType().isa()){ + if (indexNextTorch.getType().isa()) { return rewriter.notifyMatchFailure( op, "Multiple None index is not support for now."); } @@ -3620,12 +3684,13 @@ LogicalResult ConvertAtenOp::matchAndRewrite( indicesTfConcatTensors, lastDim); if (!indicesTf) { - return rewriter.notifyMatchFailure( - op, "Convert TorchIndex To TfIndices fail."); + return rewriter.notifyMatchFailure(op, + "Convert TorchIndex To TfIndices fail."); } - // do the tf scatterNd algorithm with tf style indices as input, algorithm mostly take from convertGatherNdOp. + // do the tf scatterNd algorithm with tf style indices as input, algorithm + // mostly take from convertGatherNdOp. auto result = tosa::convertScatterNdOp(rewriter, op, outType, input, - indicesTf.getResult(), fillValues); + indicesTf.getResult(), fillValues); if (!result) { return rewriter.notifyMatchFailure( @@ -3855,6 +3920,59 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenIscloseOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // check args + double rtol, atol; + bool equalNan; + if (!matchPattern(op.getRtol(), m_TorchConstantFloat(&rtol))) + return rewriter.notifyMatchFailure(op, "rtol must be a scalar constant"); + if (!matchPattern(op.getAtol(), m_TorchConstantFloat(&atol))) + return rewriter.notifyMatchFailure(op, "atol must be a scalar constant"); + if (!matchPattern(op.getEqualNan(), m_TorchConstantBool(&equalNan))) + return rewriter.notifyMatchFailure( + op, "unimplemented: equal_nan is expected to be false"); + + // check tensor type. 
+ auto selfType = adaptor.getSelf().getType().dyn_cast(); + auto otherType = adaptor.getOther().getType().dyn_cast(); + if (!selfType || !otherType) + return rewriter.notifyMatchFailure( + op, "Only tensor types input are currently supported"); + if (!selfType.hasStaticShape() || !otherType.hasStaticShape()) + return rewriter.notifyMatchFailure( + op, "Only tensor types with static shape are supported"); + if (!selfType.getElementType().isa() || + !otherType.getElementType().isa()) { + return rewriter.notifyMatchFailure( + op, "unimplemented: only FP element type is supported"); + } + + auto rhsSubOp = rewriter.create( + op->getLoc(), selfType, adaptor.getSelf(), adaptor.getOther()); + auto rhsAbsOp = + rewriter.create(op->getLoc(), selfType, rhsSubOp); + + auto lhsAbsOp = + rewriter.create(op->getLoc(), otherType, adaptor.getOther()); + auto rtolConstOp = + tosa::getTosaConstTensorSingleF32(rewriter, op, static_cast(rtol)); + auto mulOp = rewriter.create(op->getLoc(), otherType, + rtolConstOp, lhsAbsOp, /*shift=*/0); + auto atolConstOp = + tosa::getTosaConstTensorSingleF32(rewriter, op, static_cast(atol)); + auto addOp = + rewriter.create(op->getLoc(), otherType, atolConstOp, mulOp); + + auto outType = getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, outType, addOp, + rhsAbsOp); + + return success(); +} + template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenClampOp op, OpAdaptor adaptor, @@ -3866,19 +3984,37 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "only tensor types input are currently supported"); - int64_t int_min, int_max; - if (!matchPattern(op.getMin(), m_TorchConstantInt(&int_min))) - return rewriter.notifyMatchFailure( - op, "unimplemented: value `int_min` should be a torch constant int"); + IntegerAttr min_int, max_int; + FloatAttr min_fp, max_fp; + if (selfType.getElementType().isa()) { + double fp_min, fp_max; + if (!matchPattern(op.getMin(), m_TorchConstantFloat(&fp_min))) + return rewriter.notifyMatchFailure( + op, "unimplemented: value `fp_min` should be a torch constant float"); - if (!matchPattern(op.getMax(), m_TorchConstantInt(&int_max))) - return rewriter.notifyMatchFailure( - op, "unimplemented: value `int_max` should be a torch constant int"); + if (!matchPattern(op.getMax(), m_TorchConstantFloat(&fp_max))) + return rewriter.notifyMatchFailure( + op, "unimplemented: value `fp_max` should be a torch constant float"); - IntegerAttr min_int = rewriter.getI64IntegerAttr(int_min); - IntegerAttr max_int = rewriter.getI64IntegerAttr(int_max); - FloatAttr min_fp = rewriter.getF32FloatAttr(float(int_min)); - FloatAttr max_fp = rewriter.getF32FloatAttr(float(int_max)); + min_int = rewriter.getI64IntegerAttr(static_cast(fp_min)); + max_int = rewriter.getI64IntegerAttr(static_cast(fp_max)); + min_fp = rewriter.getF32FloatAttr(static_cast(fp_min)); + max_fp = rewriter.getF32FloatAttr(static_cast(fp_max)); + } else { + int64_t int_min, int_max; + if (!matchPattern(op.getMin(), m_TorchConstantInt(&int_min))) + return rewriter.notifyMatchFailure( + op, "unimplemented: value `int_min` should be a torch constant int"); + + if (!matchPattern(op.getMax(), m_TorchConstantInt(&int_max))) + return rewriter.notifyMatchFailure( + op, "unimplemented: value `int_max` should be a torch constant int"); + + min_int = rewriter.getI64IntegerAttr(int_min); + max_int = rewriter.getI64IntegerAttr(int_max); + min_fp = rewriter.getF32FloatAttr(static_cast(int_min)); + max_fp = 
rewriter.getF32FloatAttr(static_cast(int_max)); + } auto outType = getTypeConverter()->convertType(op.getType()); rewriter.replaceOpWithNewOp(op, outType, adaptor.getSelf(), @@ -5039,6 +5175,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase { INSERT_ATENOP_PATTERN(AtenBatchNormOp); INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp); INSERT_ATENOP_PATTERN(AtenFlattenUsingIntsOp); + INSERT_ATENOP_PATTERN(AtenUnflattenIntOp); INSERT_ATENOP_PATTERN(AtenPermuteOp); INSERT_ATENOP_PATTERN(AtenLog2Op); INSERT_ATENOP_PATTERN(AtenThresholdOp); @@ -5068,6 +5205,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase { INSERT_ATENOP_PATTERN(AtenRemainderScalarOp); INSERT_ATENOP_PATTERN(AtenCatOp); INSERT_ATENOP_PATTERN(AtenSqrtOp); + INSERT_ATENOP_PATTERN(AtenIscloseOp); #undef INSERT_ATENOP_PATTERN #define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \ diff --git a/lib/Conversion/Utils/Utils.cpp b/lib/Conversion/Utils/Utils.cpp index c192ff33a25f..89b17a50be99 100644 --- a/lib/Conversion/Utils/Utils.cpp +++ b/lib/Conversion/Utils/Utils.cpp @@ -249,7 +249,8 @@ mlir::RankedTensorType GetTypeFromTensorShape(llvm::ArrayRef shape, // from a tensor or a scalar in the pytorch dialect. Both the scalar and dtype // should be converted builtin types. Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype, - std::optional srcOriginalDtype) { + std::optional srcOriginalDtype, + std::optional dstOriginalDtype) { Type scalarType = scalar.getType(); if (scalarType == dtype) return scalar; @@ -261,14 +262,20 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype, return false; }; - // We only support conversion from Byte or Char scalarType not to Byte or Char - // dtype. + // We don't support conversion to Byte dtype. if (isByteOrChar(dtype)) { - mlir::emitError(loc) << "unsupported: conversion to byte or char type for " - "convertScalarToDtype " - << scalarType << "(scalar type) -> " << dtype - << "(dtype)"; - return nullptr; + if (!dstOriginalDtype.has_value()) { + mlir::emitError(loc) + << "unimplemented: for conversion to byte or char type " + "dstOriginalDtype has to be passed to convertScalarToDtype"; + return nullptr; + } + if (dstOriginalDtype->isUnsignedInteger()) { + mlir::emitError(loc) + << "unsupported: conversion to byte type for convertScalarToDtype " + << scalarType << "(scalar type) -> " << dtype << "(dtype)"; + return nullptr; + } } // If the dtype is i1, i.e., a boolean type. 
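Illustrative sketch (not part of the patch): the recurring `isAssumingStrictSymbolicShapes(rewriter)` guard added in the hunks above elides runtime shape checks whenever the surrounding function is assumed to have strict symbolic shapes. A minimal guard looks roughly like the following; the `arith::CmpIOp` and `cf::AssertOp` spellings and the operand names are assumptions for illustration only.

    // Hypothetical sketch of the guard pattern used throughout this patch:
    // emit the dynamic size check only when strict symbolic shapes cannot be
    // assumed; otherwise skip the assert entirely.
    if (!isAssumingStrictSymbolicShapes(rewriter)) {
      Value sizesEqual = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq, expectedSize, actualSize);
      rewriter.create<cf::AssertOp>(
          loc, sizesEqual,
          rewriter.getStringAttr("mismatched size for broadcast"));
    }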
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index aed453a62da4..bf930d68bc39 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -2371,7 +2371,8 @@ OpFoldResult AtenBroadcastToOp::fold(FoldAdaptor adaptor) { if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes()) return nullptr; if (inType.getSizes().size() != outType.getSizes().size() || - !inType.areAllSizesKnown() || !outType.areAllSizesKnown()) + (!isAssumingStrictSymbolicShapes((*this)->getBlock()) && + (!inType.areAllSizesKnown() || !outType.areAllSizesKnown()))) return nullptr; for (size_t i = 0; i < inType.getSizes().size(); ++i) { if (inType.getSizes()[i] != outType.getSizes()[i]) diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 697ad6bbd7ef..47d76219f199 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -6322,6 +6322,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.log10\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.log1p\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -7201,6 +7205,16 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.flatten(%arg0, %arg1, %arg2) : (!torch.list, !torch.int, !torch.int) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.unflatten.int\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.list) -> !torch.list {\n" +" %none = torch.constant.none\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.aten.slice.t %arg0, %none, %arg1, %int1 : !torch.list, !torch.none, !torch.int, !torch.int -> !torch.list\n" +" %1 = torch.aten.add.t %0, %arg2 : !torch.list, !torch.list -> !torch.list\n" +" %2 = torch.aten.add.int %arg1, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %3 = torch.aten.slice.t %arg0, %2, %none, %int1 : !torch.list, !torch.int, !torch.none, !torch.int -> !torch.list\n" +" %4 = torch.aten.add.t %1, %3 : !torch.list, !torch.list -> !torch.list\n" +" return %4 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.linear\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.linear(%arg0, %arg1, %arg2) : (!torch.list, !torch.list, !torch.optional>) -> !torch.list\n" " return %0 : !torch.list\n" @@ -7406,10 +7420,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.bitwise_and.Scalar\"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.bitwise_xor.Tensor\"(%arg0: 
!torch.list, %arg1: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.bitwise_right_shift.Tensor\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.bitwise_not\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -7458,6 +7480,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.isclose\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.float, %arg3: !torch.float, %arg4: !torch.bool) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.unsqueeze\"(%arg0: !torch.list, %arg1: !torch.int) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unsqueeze(%arg0, %arg1) : (!torch.list, !torch.int) -> !torch.list\n" " return %0 : !torch.list\n" @@ -8291,6 +8317,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" " return %1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.log10\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.log1p\"(%arg0: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" @@ -8563,6 +8594,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.unflatten.int\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.flip\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" @@ -9062,6 +9097,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %int11 = torch.constant.int 11\n" " return %int11 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.isclose\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.float, %arg3: !torch.float, %arg4: !torch.bool) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" return %int11 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.scaled_dot_product_attention\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: 
!torch.tuple, %arg3: !torch.optional>, %arg4: !torch.float, %arg5: !torch.bool, %arg6: !torch.optional) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" @@ -9192,6 +9231,15 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.bitwise_and.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" +" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.bitwise_or.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" @@ -9208,6 +9256,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.bitwise_right_shift.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.bmm\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" @@ -9547,94 +9603,98 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %8 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten._convolution\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list, %arg8: !torch.int, %arg9: !torch.bool, %arg10: !torch.bool, %arg11: !torch.bool, %arg12: !torch.bool) -> !torch.int {\n" -" %none = torch.constant.none\n" -" %str = torch.constant.str \"AssertionError: \"\n" " %false = torch.constant.bool false\n" -" %int5 = torch.constant.int 5\n" " %int11 = torch.constant.int 11\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str 
\"AssertionError: \"\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" -" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" -" %4 = torch.prim.If %3 -> (!torch.bool) {\n" -" %11 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list\n" -" %12 = torch.aten.__contains__.int_list %11, %0#1 : !torch.list, !torch.int -> !torch.bool\n" -" %13 = torch.aten.__not__ %12 : !torch.bool -> !torch.bool\n" -" torch.prim.If.yield %13 : !torch.bool\n" +" %2 = torch.aten.eq.int %0#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" +" %5 = torch.prim.If %4 -> (!torch.bool) {\n" +" %12 = torch.aten.__isnot__ %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %12 : !torch.bool\n" " } else {\n" " torch.prim.If.yield %false : !torch.bool\n" " }\n" -" torch.prim.If %4 -> () {\n" +" torch.prim.If %5 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" -" %6 = torch.aten.__not__ %5 : !torch.bool -> !torch.bool\n" -" %7 = torch.prim.If %6 -> (!torch.bool) {\n" -" %11 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list\n" -" %12 = torch.aten.__contains__.int_list %11, %1#1 : !torch.list, !torch.int -> !torch.bool\n" -" %13 = torch.aten.__not__ %12 : !torch.bool -> !torch.bool\n" -" torch.prim.If.yield %13 : !torch.bool\n" +" %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool\n" +" %8 = torch.prim.If %7 -> (!torch.bool) {\n" +" %12 = torch.aten.__isnot__ %1#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %12 : !torch.bool\n" " } else {\n" " torch.prim.If.yield %false : !torch.bool\n" " }\n" -" torch.prim.If %7 -> () {\n" +" torch.prim.If %8 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %8 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" -" %9 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %10 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%8, %9) : (!torch.list>, !torch.list) -> !torch.int\n" -" return %10 : !torch.int\n" +" %9 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %10 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %11 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%9, 
%10) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %11 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten._convolution.deprecated\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list, %arg8: !torch.int, %arg9: !torch.bool, %arg10: !torch.bool, %arg11: !torch.bool) -> !torch.int {\n" -" %none = torch.constant.none\n" -" %str = torch.constant.str \"AssertionError: \"\n" " %false = torch.constant.bool false\n" -" %int5 = torch.constant.int 5\n" " %int11 = torch.constant.int 11\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" -" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" -" %4 = torch.prim.If %3 -> (!torch.bool) {\n" -" %11 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list\n" -" %12 = torch.aten.__contains__.int_list %11, %0#1 : !torch.list, !torch.int -> !torch.bool\n" -" %13 = torch.aten.__not__ %12 : !torch.bool -> !torch.bool\n" -" torch.prim.If.yield %13 : !torch.bool\n" +" %2 = torch.aten.eq.int %0#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" +" %5 = torch.prim.If %4 -> (!torch.bool) {\n" +" %12 = torch.aten.__isnot__ %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %12 : !torch.bool\n" " } else {\n" " torch.prim.If.yield %false : !torch.bool\n" " }\n" -" torch.prim.If %4 -> () {\n" +" torch.prim.If %5 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" -" %6 = torch.aten.__not__ %5 : !torch.bool -> !torch.bool\n" -" %7 = torch.prim.If %6 -> (!torch.bool) {\n" -" %11 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list\n" -" %12 = torch.aten.__contains__.int_list %11, %1#1 : !torch.list, !torch.int -> !torch.bool\n" -" %13 = torch.aten.__not__ %12 : !torch.bool -> !torch.bool\n" -" torch.prim.If.yield %13 : !torch.bool\n" +" %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool\n" +" %8 = torch.prim.If %7 -> (!torch.bool) {\n" +" %12 = torch.aten.__isnot__ %1#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %12 : !torch.bool\n" " } else {\n" " torch.prim.If.yield %false : !torch.bool\n" " }\n" -" torch.prim.If %7 -> () {\n" +" torch.prim.If %8 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " 
torch.prim.If.yield\n" " }\n" -" %8 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" -" %9 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %10 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%8, %9) : (!torch.list>, !torch.list) -> !torch.int\n" -" return %10 : !torch.int\n" +" %9 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %10 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %11 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%9, %10) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %11 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.conv2d\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" @@ -9937,39 +9997,34 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " func.func @\"__torch_mlir_dtype_fn.aten.baddbmm\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.number, %arg4: !torch.number) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" -" %int5 = torch.constant.int 5\n" " %int11 = torch.constant.int 11\n" " %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple -> !torch.int, !torch.int\n" -" %2 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list\n" -" %3 = torch.aten.__contains__.int_list %2, %0#1 : !torch.list, !torch.int -> !torch.bool\n" -" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" -" torch.prim.If %4 -> () {\n" +" %2 = torch.aten.__isnot__ %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %5 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list\n" -" %6 = torch.aten.__contains__.int_list %5, %1#1 : !torch.list, !torch.int -> !torch.bool\n" -" %7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool\n" -" torch.prim.If %7 -> () {\n" +" %3 = torch.aten.__isnot__ %1#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %8 = torch.aten.eq.int %0#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %8 -> () {\n" +" %4 = torch.aten.eq.int %0#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %4 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %9 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" -" %10 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %11 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%9, %10) : (!torch.list>, !torch.list) -> !torch.int\n" -" return %11 : !torch.int\n" +" %5 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> 
!torch.list>\n" +" %6 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %7 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%5, %6) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %7 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.where.self\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 63ce4f837e85..0bdfca26ddc1 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -3380,10 +3380,15 @@ class DecomposeAtenToDtypeLayoutOp op, "unimplemented: pinMemory is expected to be false"); } - // TODO: Add support for non-None device arg. + // TODO: Add support for device arg other than cpu. if (!op.getDevice().getType().isa()) { - return rewriter.notifyMatchFailure( - op, "unimplemented: device arg must be None"); + std::string device; + if (!matchPattern(op.getDevice(), m_TorchConstantDevice(device))) + return rewriter.notifyMatchFailure( + op, "unimplemented: device must be a constant str"); + else if (device != "cpu") + return rewriter.notifyMatchFailure( + op, "unimplemented: device is expected to be cpu"); } // TODO: Add support for non-strided layout. @@ -3479,11 +3484,13 @@ class DecomposeAtenAdaptiveAvgPool1dOp : rewriter.create( loc, rewriter.getI64IntegerAttr(inputShape[rank - 1]))); } else { - Value cond = rewriter.create(loc, inputSize, outputSize); - rewriter.create( - loc, cond, - "unimplemented: only support cases where input and output size are " - "equal for non-unit output size"); + if (!isAssumingStrictSymbolicShapes(rewriter)) { + Value cond = rewriter.create(loc, inputSize, outputSize); + rewriter.create( + loc, cond, + "unimplemented: only support cases where input and output size are " + "equal for non-unit output size"); + } kernelSize.push_back(constantOne); } @@ -3581,13 +3588,14 @@ class DecomposeAtenAdaptiveAvgPool2dOp loc, rewriter.getI64IntegerAttr( inputShape[rank - 2 + i]))); } else { - Value cond = rewriter.create(loc, inputHW[i], - outputShapeSizesTorchInt[i]); - rewriter.create( - loc, cond, - "unimplemented: only support cases where input and output size are " - "equal for non-unit output size"); - + if (!isAssumingStrictSymbolicShapes(rewriter)) { + Value cond = rewriter.create( + loc, inputHW[i], outputShapeSizesTorchInt[i]); + rewriter.create(loc, cond, + "unimplemented: only support cases " + "where input and output size are " + "equal for non-unit output size"); + } Value outMinusOne = rewriter.create( loc, outputShapeSizesTorchInt[i], constantOne); kernelSize.push_back( @@ -3817,13 +3825,15 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter, loc, rewriter.getF64FloatAttr(correction)); // The `correction` value should be less than or equal to `productDimSize + // 1`. 
- Value productDimSizePlusOne = rewriter.create( - loc, productDimSize.getType(), productDimSize, constantOne); - Value cond = - rewriter.create(loc, productDimSizePlusOne, cstCorrection); - rewriter.create( - loc, cond, - "correction value should be less than or equal to productDimSize + 1"); + if (!isAssumingStrictSymbolicShapes(rewriter)) { + Value productDimSizePlusOne = rewriter.create( + loc, productDimSize.getType(), productDimSize, constantOne); + Value cond = rewriter.create(loc, productDimSizePlusOne, + cstCorrection); + rewriter.create( + loc, cond, + "correction value should be less than or equal to productDimSize + 1"); + } Value productDimSizeSubCorrection = rewriter.create(loc, productDimSize, cstCorrection); Value result = rewriter.create(loc, newOutputType, squareSum, diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp index 10c4bea67dc0..ddc95bd4b2fd 100644 --- a/lib/Dialect/Torch/Utils/Utils.cpp +++ b/lib/Dialect/Torch/Utils/Utils.cpp @@ -199,10 +199,10 @@ bool Torch::isViewLikeOp(Operation *op) { // that it does not return a view and treat those as having value // semantics. return isa Torch::unsqueezeTensor(PatternRewriter &rewriter, op->getLoc(), unsqueezedType, input, dim); return unsqueezed; } + +bool Torch::isAssumingStrictSymbolicShapes(Block *block) { + for (Operation *parentOp = block->getParentOp(); parentOp; + parentOp = parentOp->getParentOp()) { + if (parentOp->hasAttr("torch.assume_strict_symbolic_shapes")) + return true; + } + return false; +} diff --git a/lib/RefBackend/CMakeLists.txt b/lib/RefBackend/CMakeLists.txt index 2ef5dab3ae8d..a8ed0439d815 100644 --- a/lib/RefBackend/CMakeLists.txt +++ b/lib/RefBackend/CMakeLists.txt @@ -15,6 +15,7 @@ add_mlir_library(TorchMLIRRefBackend MLIRIR MLIRTransforms MLIRMathTransforms + MLIRLinalgTransforms ) mlir_check_all_link_libraries(TorchMLIRRefBackend) diff --git a/python/torch_mlir/csrc/base_lazy_backend/CMakeLists.txt b/python/torch_mlir/csrc/base_lazy_backend/CMakeLists.txt index ad8380612edd..81a8383949c7 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/CMakeLists.txt +++ b/python/torch_mlir/csrc/base_lazy_backend/CMakeLists.txt @@ -69,6 +69,7 @@ add_library(torch_mlir_ltc_backend SHARED backend_impl.cpp dynamic_ir.cpp mlir_node.cpp + tensor.cpp ops/device_data.cpp ops/generic.cpp ops/index.cpp diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_native_functions.cpp b/python/torch_mlir/csrc/base_lazy_backend/mlir_native_functions.cpp index 540f02ae606c..d06ad5963919 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/mlir_native_functions.cpp +++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_native_functions.cpp @@ -30,6 +30,7 @@ #include #include +#include "generated/LazyIr.h" #include "generated/LazyNativeFunctions.h" #include "generated/shape_inference.h" #include "ops/to_copy.h" @@ -143,32 +144,6 @@ void copy_(torch::lazy::LazyTensorPtr& input, torch::lazy::LazyTensorPtr& src) { } // namespace -// at::Tensor LazyNativeFunctions::bernoulli( -// const at::Tensor& self, c10::optional generator) { -// TORCH_LAZY_FN_COUNTER("lazy::"); -// if (generator.has_value() && generator->defined()) { -// UNSUPPORTED_ERROR("LazyNativeFunctions::bernoulli has generator value"); -// } -// auto self_tensor = torch::lazy::TryGetLtcTensor(self); - -// UNIMPLEMENTED_FUNCTION_ERROR(); -// // return torch::lazy::CreateAtenFromLtcTensor( -// // torch::lazy::bernoulli(self_tensor)); -// } - -// at::Tensor& LazyNativeFunctions::bernoulli_( -// at::Tensor& self, double p, c10::optional 
generator) { -// TORCH_LAZY_FN_COUNTER("lazy::"); -// if (generator.has_value() && generator->defined()) { -// UNSUPPORTED_ERROR("LazyNativeFunctions::bernoulli_ has generator value"); -// } -// auto self_tensor = torch::lazy::TryGetLtcTensor(self); - -// UNIMPLEMENTED_FUNCTION_ERROR(); -// // torch::lazy::bernoulli_(self_tensor, p); -// // return self; -// } - // clone is special in LT because we make it a no-op. // This should be safe to do, because every operator in the LT is functional. at::Tensor LazyNativeFunctions::clone( @@ -352,64 +327,17 @@ at::Tensor LazyNativeFunctions::_to_copy( } }; -at::Tensor LazyNativeFunctions::empty_symint( - at::SymIntArrayRef sym_size, - c10::optional dtype, - c10::optional layout, - c10::optional device, - c10::optional pin_memory, - c10::optional memory_format) { - // TODO: support this directly - auto size = C10_AS_INTARRAYREF_SLOW(sym_size); - const auto device_type = torch::lazy::getBackend()->EagerFallbackDeviceType(); - at::TensorOptions options = at::TensorOptions() - .device(c10::Device(device_type)) - .layout(layout) - .pinned_memory(pin_memory) - .dtype(dtype); - auto x_result = at::empty(size, options, memory_format); - auto tensor = CreateLtcTensor(x_result, GetLtcDevice(device)); - // See Note [Lazy Tensor Functionalization] - if (c10::impl::tls_local_dispatch_key_set().excluded_.has( - c10::DispatchKey::Functionalize)) { - // Invariant: if the functionalization key is in the exclude set, then we're expected - // to return an ordinary tensor, which will be "lifted" into a functional wrapper later. - return tensor; - } else { - auto wrapped = at::functionalization::impl::to_functional_tensor(tensor); - return wrapped; - } -} - -at::Tensor LazyNativeFunctions::empty_strided( - at::IntArrayRef size, at::IntArrayRef stride, - c10::optional dtype, c10::optional layout, - c10::optional device, c10::optional pin_memory) { - TORCH_LAZY_FN_COUNTER("lazy::"); - at::Tensor t = empty_symint( - c10::fromIntArrayRefSlow(size), - dtype, layout, device, pin_memory, c10::nullopt); - return t.as_strided(size, stride, /*storage_offset=*/0); -} - -at::Tensor& -LazyNativeFunctions::fill_(at::Tensor& self, const at::Scalar& value) { - TORCH_LAZY_FN_COUNTER("lazy::"); - auto self_tensor = torch::lazy::TryGetLtcTensor(self); - - torch::lazy::Value constant = - torch::lazy::LazyGraphExecutor::Get()->GetIrValueForExpandedScalar( - value, self_tensor->shape(), self_tensor->GetDevice()); - self_tensor->SetInPlaceIrValue(std::move(constant)); - return self; -} - at::Tensor LazyNativeFunctions::_unsafe_view( const at::Tensor& self, at::IntArrayRef size) { TORCH_LAZY_FN_COUNTER("lazy::"); return LazyNativeFunctions::view_copy_symint(self, c10::fromIntArrayRefSlow(size)); } +at::Tensor LazyNativeFunctions::t(const at::Tensor& self) { + TORCH_LAZY_FN_COUNTER("lazy::"); + return at::functionalization::functionalize_aten_op::call(self); +} + std::vector LazyNativeFunctions::unbind_copy(const at::Tensor & self, int64_t dim) { TORCH_LAZY_FN_COUNTER("lazy::"); auto common_device = torch::lazy::GetBackendDevice(self); @@ -643,9 +571,18 @@ at::Tensor LazyNativeFunctions::new_empty_strided_symint( c10::optional layout, c10::optional device, c10::optional pin_memory) { - return at::functionalization:: - functionalize_aten_op_symint::call( - self, size, stride, dtype, layout, device, pin_memory); + if (!device || device->type() == c10::DeviceType::Lazy) { + return at::functionalization::functionalize_aten_op_symint< + ATEN_OP(new_empty_strided)>::call(self, size, stride, dtype, 
layout, + device, pin_memory); + } + // For cases when device != lazy, for example: lazy_tensor.new_empty_strided(..., "cpu") + // we need to avoid explicit functionalization. To do that we create regular cpu tensors. + at::Tensor t = at::empty_symint( + size, (dtype ? dtype : c10::optional(self.scalar_type())), + (layout ? layout : c10::optional(self.layout())), device, + pin_memory, c10::nullopt); + return t.as_strided_symint(size, stride, /*storage_offset=*/0); } at::Tensor LazyNativeFunctions::narrow_copy_symint( @@ -729,4 +666,4 @@ at::Tensor& LazyNativeFunctions::logsumexp_out( void InitializeAtenBindings() {} } // namespace lazy -} // namespace torch +} // namespace torch \ No newline at end of file diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_node.cpp b/python/torch_mlir/csrc/base_lazy_backend/mlir_node.cpp index e4b75e5d53d1..39dc1ad0cd58 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/mlir_node.cpp +++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_node.cpp @@ -83,7 +83,7 @@ hash_t TorchMlirNode::hash() const { return dag_hash_; } hash_t TorchMlirNode::shapeHash() const { return shape_hash_; } -TorchMlirNode* TorchMlirNode::mlir_node(int index) { +TorchMlirNode* TorchMlirNode::mlir_node(int index) const { return dynamic_cast(operands_.at(index).get()); } diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_node.h b/python/torch_mlir/csrc/base_lazy_backend/mlir_node.h index dbf3117dbb13..4b5e196beb20 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/mlir_node.h +++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_node.h @@ -51,7 +51,7 @@ class TORCH_API TorchMlirNode : public torch::lazy::Node { hash_t shapeHash() const override; - TorchMlirNode* mlir_node(int index); + TorchMlirNode* mlir_node(int index) const; virtual TorchMlirOpVector Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const; diff --git a/python/torch_mlir/csrc/base_lazy_backend/shape_inference.cpp b/python/torch_mlir/csrc/base_lazy_backend/shape_inference.cpp index f8d03449877d..d5458f9c4ea6 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/shape_inference.cpp +++ b/python/torch_mlir/csrc/base_lazy_backend/shape_inference.cpp @@ -265,6 +265,33 @@ std::vector compute_shape_eye( return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; } +std::vector compute_shape_arange( + const at::Scalar& end, c10::optional dtype, + c10::optional layout, c10::optional device, + c10::optional pin_memory) { + auto out_meta = + at::arange(end, dtype, layout, c10::Device(c10::kMeta), pin_memory); + return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; +} + +std::vector compute_shape_arange( + const at::Scalar& start, const at::Scalar& end, + c10::optional dtype, c10::optional layout, + c10::optional device, c10::optional pin_memory) { + auto out_meta = at::arange(start, end, dtype, layout, c10::Device(c10::kMeta), + pin_memory); + return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; +} + +std::vector compute_shape_arange( + const at::Scalar& start, const at::Scalar& end, const at::Scalar& step, + c10::optional dtype, c10::optional layout, + c10::optional device, c10::optional pin_memory) { + auto out_meta = at::arange(start, end, step, dtype, layout, + c10::Device(c10::kMeta), pin_memory); + return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; +} + std::vector compute_shape_full( at::IntArrayRef size, const at::Scalar& fill_value, c10::optional dtype, c10::optional layout, @@ -273,6 +300,44 @@ std::vector compute_shape_full( 
Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())}; } +std::vector compute_shape_ones( + at::IntArrayRef size, c10::optional dtype, + c10::optional layout, c10::optional device, + c10::optional pin_memory) { + return { + Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())}; +} + +std::vector compute_shape_zeros( + at::IntArrayRef size, c10::optional dtype, + c10::optional layout, c10::optional device, + c10::optional pin_memory) { + return { + Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())}; +} + +std::vector compute_shape_empty( + at::IntArrayRef size, c10::optional dtype, + c10::optional layout, c10::optional device, + c10::optional pin_memory, + c10::optional memory_format) { + return { + Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())}; +} + +std::vector compute_shape_empty_strided( + at::IntArrayRef size, at::IntArrayRef stride, + c10::optional dtype, c10::optional layout, + c10::optional device, c10::optional pin_memory) { + return { + Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())}; +} + +std::vector compute_shape_fill(const at::Tensor& self, + const at::Scalar& value) { + return {Shape(self.scalar_type(), self.sizes().vec())}; +} + std::vector compute_shape_fill(const at::Tensor& self, const at::Tensor& value) { return {Shape(self.scalar_type(), self.sizes().vec())}; @@ -302,11 +367,36 @@ std::vector compute_shape_randint( Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())}; } +std::vector compute_shape_resize( + const at::Tensor & self, at::IntArrayRef size, + c10::optional memory_format) { + return {Shape(self.scalar_type(), size.vec())}; +} + std::vector compute_shape_bernoulli( const at::Tensor& self, const at::Tensor &p, c10::optional generator) { return {Shape(self.scalar_type(), self.sizes().vec())}; } +std::vector compute_shape_scalar_tensor( + const at::Scalar & s, c10::optional dtype, + c10::optional layout, c10::optional device, + c10::optional pin_memory) { + return {Shape(dtype.value_or(s.type()), c10::ArrayRef{})}; +} + +std::vector compute_shape_roll( + const at::Tensor& self, at::IntArrayRef shifts, at::IntArrayRef dims) { + return {Shape(self.scalar_type(), self.sizes().vec())}; +} + +std::vector compute_shape_linspace(const at::Scalar & start, const at::Scalar & end, int64_t steps, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory) { + auto out_meta = + at::linspace(start, end, steps, dtype, layout, c10::Device(c10::kMeta), pin_memory); + return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; +} + + } // namespace lazy } // namespace torch \ No newline at end of file diff --git a/python/torch_mlir/csrc/base_lazy_backend/tensor.cpp b/python/torch_mlir/csrc/base_lazy_backend/tensor.cpp new file mode 100644 index 000000000000..82ae6cc27f4a --- /dev/null +++ b/python/torch_mlir/csrc/base_lazy_backend/tensor.cpp @@ -0,0 +1,29 @@ +//===- tensor.cpp ---------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. 
+// +//===----------------------------------------------------------------------===// + +#include + +#include "tensor.h" + +namespace torch { +namespace lazy { + +at::Tensor CreateFunctionalizedAtenFromLtcTensor( + const LazyTensorPtr& ltc_tensor) { + at::Tensor tensor = CreateAtenFromLtcTensor(ltc_tensor); + if (!c10::impl::tls_is_dispatch_key_excluded( + c10::DispatchKey::Functionalize) && + !at::functionalization::impl::isFunctionalTensor(tensor)) { + return at::functionalization::impl::to_functional_tensor(tensor); + } + return tensor; +} + +} // namespace lazy +} // namespace torch diff --git a/python/torch_mlir/csrc/base_lazy_backend/tensor.h b/python/torch_mlir/csrc/base_lazy_backend/tensor.h new file mode 100644 index 000000000000..4e39dd095aa5 --- /dev/null +++ b/python/torch_mlir/csrc/base_lazy_backend/tensor.h @@ -0,0 +1,24 @@ +//===- tensor.h -----------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +namespace torch { +namespace lazy { + +// Ops like torch.ones/zeros etc. which produce new tensor as an output +// should have explicit tensor functinoalization. Otherwise we can get +// unfanctionalized primitives or in the worst case if we apply inplace +// operations to unfunctionalized tensor it won't be captured in LTC graph. +TORCH_API at::Tensor CreateFunctionalizedAtenFromLtcTensor(const LazyTensorPtr& ltc_tensor); + +} // namespace lazy +} // namespace torch diff --git a/python/torch_mlir/csrc/base_lazy_backend/utils/tensor_utils.cpp b/python/torch_mlir/csrc/base_lazy_backend/utils/tensor_utils.cpp index 7131e9a66a2b..cdd97168031b 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/utils/tensor_utils.cpp +++ b/python/torch_mlir/csrc/base_lazy_backend/utils/tensor_utils.cpp @@ -7,28 +7,66 @@ namespace torch { namespace lazy { +bool is_detach_copy(const torch::lazy::Node* node) { + return node && node->op() == torch::lazy::DetachCopy::ClassOpKind(); +} bool is_detach_copy(const torch::lazy::Value& value) { - return value->op() == torch::lazy::DetachCopy::ClassOpKind(); + return is_detach_copy(value.node.get()); } -torch::lazy::DeviceData* device_data_cast(const torch::lazy::Value& value) { - if (!value) { +torch::lazy::Node* extract_non_detach_copy_node(torch::lazy::Node* node) { + if (!node) { return nullptr; } + + torch::lazy::TorchMlirNode* mlir_node = dynamic_cast(node); + while(mlir_node && is_detach_copy(mlir_node)) { + mlir_node = mlir_node->mlir_node(0); + } + if (!mlir_node) { + return node; + } + return mlir_node; +} + +const torch::lazy::Node* extract_non_detach_copy_node(const torch::lazy::Node* node) { + if (!node) { return nullptr; } + + const torch::lazy::TorchMlirNode* mlir_node = dynamic_cast(node); + while(mlir_node && is_detach_copy(mlir_node)) { + mlir_node = mlir_node->mlir_node(0); + } + if (!mlir_node) { + return node; + } + return mlir_node; +} + + +torch::lazy::DeviceData* device_data_cast(torch::lazy::Node* node) { + if (!node) { + return nullptr; + } + node = extract_non_detach_copy_node(node); + if (node && node->op() == torch::lazy::DeviceData::ClassOpKind()) { + return dynamic_cast(node); + } + return nullptr; +} +const torch::lazy::DeviceData* 
device_data_cast(const torch::lazy::Node* node) { + if (!node) { return nullptr; } - torch::lazy::TorchMlirNode* node = dynamic_cast(value.node.get()); - while(node) { - if (node->op() == torch::lazy::DeviceData::ClassOpKind()) { - return dynamic_cast(node); - } - else if (node->op() == torch::lazy::DetachCopy::ClassOpKind()) { - node = node->mlir_node(0); - } - else { - break; - } + node = extract_non_detach_copy_node(node); + if (node && node->op() == torch::lazy::DeviceData::ClassOpKind()) { + return dynamic_cast(node); } return nullptr; } +torch::lazy::DeviceData* device_data_cast(const torch::lazy::Value& value) { + if (!value) { + return nullptr; + } + return device_data_cast(value.node.get()); +} torch::lazy::DeviceData* device_data_cast( const at::Tensor& tensor, c10::optional device diff --git a/python/torch_mlir/csrc/base_lazy_backend/utils/tensor_utils.h b/python/torch_mlir/csrc/base_lazy_backend/utils/tensor_utils.h index 717173e9a8fc..745be78c35d2 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/utils/tensor_utils.h +++ b/python/torch_mlir/csrc/base_lazy_backend/utils/tensor_utils.h @@ -8,10 +8,15 @@ namespace torch { namespace lazy { -TORCH_API bool is_detach_copy(const torch::lazy::Value& value); +TORCH_API bool is_detach_copy(const torch::lazy::Node*); +TORCH_API bool is_detach_copy(const torch::lazy::Value&); -TORCH_API torch::lazy::DeviceData* device_data_cast(const torch::lazy::Value& value); +TORCH_API torch::lazy::Node* extract_non_detach_copy_node(torch::lazy::Node*); +TORCH_API const torch::lazy::Node* extract_non_detach_copy_node(const torch::lazy::Node*); +TORCH_API torch::lazy::DeviceData* device_data_cast(torch::lazy::Node*); +TORCH_API const torch::lazy::DeviceData* device_data_cast(const torch::lazy::Node*); +TORCH_API torch::lazy::DeviceData* device_data_cast(const torch::lazy::Value& value); TORCH_API torch::lazy::DeviceData* device_data_cast( const at::Tensor& tensor, c10::optional device = c10::nullopt ); diff --git a/python/torch_mlir/csrc/reference_lazy_backend/backend_impl.cpp b/python/torch_mlir/csrc/reference_lazy_backend/backend_impl.cpp index 3bc8465eafc1..1064a3d1e1ac 100644 --- a/python/torch_mlir/csrc/reference_lazy_backend/backend_impl.cpp +++ b/python/torch_mlir/csrc/reference_lazy_backend/backend_impl.cpp @@ -28,6 +28,11 @@ using namespace torch::lazy; namespace torch { namespace lazy { +/// Returns true if a string begins with another. 
+inline bool beginswith(const std::string& s, const std::string& t) { + return s.size() >= t.size() && s.compare(0, t.size(), t) == 0; +} + struct ReferenceLazyBackendDeviceType : public BackendDeviceType { ReferenceLazyBackendDeviceType(c10::DeviceType device_type) : device_type_(device_type) {} @@ -104,7 +109,25 @@ class ReferenceLazyBackendImpl : public torch::lazy::TorchMlirBackendImpl { // // JIT Execution adopted from: // https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_backend_impl.cpp - torch::jit::GraphExecutor graph_executor(mlir_computation->graph(), ""); + std::shared_ptr graph = mlir_computation->graph(); + for (auto* node : graph->nodes()) { + // Convert any lazy devices to cpu devices to ensure + // that the values are actually computed + if (node->outputs().size() == 1 && + node->output()->type()->kind() == + c10::TypeKind::DeviceObjType) { + auto value_sym = torch::jit::Symbol::attr("value"); + TORCH_CHECK(node->hasAttribute(value_sym), + "Expected node to have 'value' attribute."); + TORCH_CHECK(node->kindOf(value_sym) == torch::jit::AttributeKind::s, + "Expected 'value' attribute to be a string."); + if (beginswith(node->s(value_sym), "lazy")) { + node->s_(value_sym, "cpu"); + } + } + } + + torch::jit::GraphExecutor graph_executor(graph, ""); std::vector stack; for (const auto& argument : arguments) { const auto mlir_data = diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py index d6f064f745ed..958df70d575a 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py @@ -122,6 +122,9 @@ def aten〇detach〡shape(self: List[int]) -> List[int]: def aten〇log2〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇log10〡shape(self: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇log1p〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) @@ -620,6 +623,9 @@ def aten〇adaptive_avg_pool2d〡shape(self: List[int], output_size: List[int]) def aten〇flatten〇using_ints〡shape(self: List[int], start_dim: int = 0, end_dim: int = -1) -> List[int]: return upstream_shape_functions.flatten(self, start_dim, end_dim) +def aten〇unflatten〇int〡shape(self: List[int], dim: int, sizes: List[int]) -> List[int]: + return self[:dim] + sizes + self[dim + 1:] + def aten〇linear〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None) -> List[int]: return upstream_shape_functions.linear(input, weight, bias) @@ -793,9 +799,15 @@ def aten〇bitwise_or〇Tensor〡shape(self: List[int], other: List[int]) -> Lis def aten〇bitwise_and〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]: return upstream_shape_functions.broadcast(self, other) +def aten〇bitwise_and〇Scalar〡shape(self: List[int], other: float) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇bitwise_xor〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]: return upstream_shape_functions.broadcast(self, other) +def aten〇bitwise_right_shift〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]: + return upstream_shape_functions.broadcast(self, other) + def aten〇bitwise_not〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) @@ -832,6 +844,9 @@ def aten〇lt〇Tensor〡shape(self: 
List[int], other: List[int]) -> List[int]: def aten〇le〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]: return upstream_shape_functions.broadcast(self, other) +def aten〇isclose〡shape(self: List[int], other: List[int], rtol: float = 1.0000000000000001e-05, atol: float = 1e-08, equal_nan: bool = False) -> List[int]: + return upstream_shape_functions.broadcast(self, other) + def aten〇unsqueeze〡shape(self: List[int], dim: int) -> List[int]: return upstream_shape_functions.unsqueeze(self, dim) @@ -1438,6 +1453,11 @@ def aten〇log2〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype return _get_dtype_of_floating_point_op(self_dtype) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇log10〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return _get_dtype_of_floating_point_op(self_dtype) + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇log1p〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype @@ -1642,6 +1662,11 @@ def aten〇flatten〇using_ints〡dtype(self_rank_dtype: Tuple[int, int], start_ self_rank, self_dtype = self_rank_dtype return self_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, sizes=[1])) +def aten〇unflatten〇int〡dtype(self_rank_dtype: Tuple[int, int], dim: int, sizes: List[int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dims=[0])) def aten〇flip〡dtype(self_rank_dtype: Tuple[int, int], dims: List[int]) -> int: self_rank, self_dtype = self_rank_dtype @@ -2149,6 +2174,10 @@ def aten〇logical_and〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtyp def aten〇logical_not〡dtype(self_rank_dtype: Tuple[int, int]) -> int: return torch.bool +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=2)) +def aten〇isclose〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int], rtol: float = 1.0000000000000001e-05, atol: float = 1e-08, equal_nan: bool = False) -> int: + return torch.bool + @check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(3, 4, 32, 16), (3, 4, 32, 16), (3, 4, 32, 16)])) def aten〇scaled_dot_product_attention〡dtype(query_rank_dtype: Tuple[int, int], key_rank_dtype: Tuple[int, int], value_rank_dtype: Tuple[int, int], attn_mask_rank_dtype: Optional[Tuple[int, int]] = None, dropout_p: float = 0., is_causal: bool = False, scale: Optional[float] = None) -> int: _, query_dtype = query_rank_dtype @@ -2257,6 +2286,14 @@ def aten〇bitwise_and〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_ dtypes = [self_dtype, other_dtype] return promote_dtypes(ranks, dtypes) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1.0)) +def aten〇bitwise_and〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int: + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, None] + dtypes = [self_dtype, get_dtype_of_scalar(other)] + return promote_dtypes(ranks, dtypes) + @check_dtype_function(_check_two_tensor_op()) def aten〇bitwise_or〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: other_rank, other_dtype = other_rank_dtype @@ -2273,6 +2310,14 @@ def aten〇bitwise_xor〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_ 
dtypes = [self_dtype, other_dtype] return promote_dtypes(ranks, dtypes) +@check_dtype_function(_check_two_tensor_op()) +def aten〇bitwise_right_shift〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + return promote_dtypes(ranks, dtypes) + @check_dtype_function( _check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 4), (2, 4, 3)]) + # Different width @@ -2461,7 +2506,7 @@ def aten〇threshold_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], s _check_tensors_with_the_same_dtype( tensor_shapes=[(1, 1, 1, 1), (1, 1, 1, 1), (1,)], tensor_device="cpu", - error_types={torch.bool, torch.float16, torch.complex64, torch.complex128}, **_convolution_kwargs) + + error_types={torch.bool, torch.complex64, torch.complex128}, **_convolution_kwargs) + [ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bool, device="cpu"), TensorOfShape(1, 1, 1, 1, dtype=torch.float32, device="cpu"), TensorOfShape(1, dtype=torch.float32, device="cpu"), **_convolution_kwargs), ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32, device="cpu"), TensorOfShape(1, 1, 1, 1, dtype=torch.bool, device="cpu"), @@ -2473,8 +2518,9 @@ def aten〇threshold_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], s def aten〇_convolution〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool, allow_tf32: bool) -> int: input_rank, input_dtype = input_rank_dtype weight_rank, weight_dtype = weight_rank_dtype - assert not is_complex_dtype(input_dtype) and input_dtype not in [torch.bool, torch.float16] - assert not is_complex_dtype(weight_dtype) and weight_dtype not in [torch.bool, torch.float16] + assert input_dtype == weight_dtype + assert not is_complex_dtype(input_dtype) and input_dtype is not torch.bool + assert not is_complex_dtype(weight_dtype) and weight_dtype is not torch.bool ranks: List[Optional[int]] = [input_rank, weight_rank] dtypes = [input_dtype, weight_dtype] return promote_dtypes(ranks, dtypes) @@ -2494,7 +2540,7 @@ def aten〇_convolution〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_d _check_tensors_with_the_same_dtype( tensor_shapes=[(1, 1, 1, 1), (1, 1, 1, 1), (1,)], tensor_device="cpu", - error_types={torch.bool, torch.float16, torch.complex64, torch.complex128}, **_convolution_deprecated_kwargs) + + error_types={torch.bool, torch.complex64, torch.complex128}, **_convolution_deprecated_kwargs) + [ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bool, device="cpu"), TensorOfShape(1, 1, 1, 1, dtype=torch.float32, device="cpu"), TensorOfShape(1, dtype=torch.float32, device="cpu"), **_convolution_deprecated_kwargs), ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32, device="cpu"), TensorOfShape(1, 1, 1, 1, dtype=torch.bool, device="cpu"), @@ -2507,8 +2553,9 @@ def aten〇_convolution〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_d def aten〇_convolution〇deprecated〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: 
bool, cudnn_enabled: bool) -> int: input_rank, input_dtype = input_rank_dtype weight_rank, weight_dtype = weight_rank_dtype - assert not is_complex_dtype(input_dtype) and input_dtype not in [torch.bool, torch.float16] - assert not is_complex_dtype(weight_dtype) and weight_dtype not in [torch.bool, torch.float16] + assert input_dtype == weight_dtype + assert not is_complex_dtype(input_dtype) and input_dtype is not torch.bool + assert not is_complex_dtype(weight_dtype) and weight_dtype is not torch.bool ranks: List[Optional[int]] = [input_rank, weight_rank] dtypes = [input_dtype, weight_dtype] return promote_dtypes(ranks, dtypes) @@ -2812,7 +2859,7 @@ def aten〇remainder〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: U # TODO: This should be fixed by switching to FakeTensor instead of Meta tensor @check_dtype_function( - _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1, 1), (1, 1, 1), (1, 1, 1)], tensor_device="cpu", error_types={torch.bool, torch.float16}) + + _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1, 1), (1, 1, 1), (1, 1, 1)], tensor_device="cpu", error_types={torch.bool}) + [ErrorInvocation(TensorOfShape( 1, 1, 1, dtype=torch.float64, device="cpu"), TensorOfShape(1, 1, 1, dtype=torch.int16, device="cpu"), TensorOfShape(1, 1, 1, dtype=torch.int32, device="cpu")), ErrorInvocation( @@ -2824,8 +2871,8 @@ def aten〇remainder〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: U def aten〇baddbmm〡dtype(self_rank_dtype: Tuple[int, int], batch1_rank_dtype: Tuple[int, int], batch2_rank_dtype: Tuple[int, int], beta: Union[int, float, complex] = 1, alpha: Union[int, float, complex] = 1) -> int: batch1_rank, batch1_dtype = batch1_rank_dtype batch2_rank, batch2_dtype = batch2_rank_dtype - assert batch1_dtype not in [torch.bool, torch.float16] - assert batch2_dtype not in [torch.bool, torch.float16] + assert batch1_dtype is not torch.bool + assert batch2_dtype is not torch.bool assert batch1_dtype == batch2_dtype ranks: List[Optional[int]] = [batch1_rank, batch2_rank] dtypes = [batch1_dtype, batch2_dtype] diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index 9945db7e0768..e473603ff6da 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -294,14 +294,17 @@ def emit_with_mutating_variants(key, **kwargs): "aten::clamp_max : (Tensor, Scalar) -> (Tensor)", "aten::clamp_max.Tensor : (Tensor, Tensor) -> (Tensor)", "aten::log2 : (Tensor) -> (Tensor)", + "aten::log10 : (Tensor) -> (Tensor)", "aten::sqrt : (Tensor) -> (Tensor)", "aten::log1p : (Tensor) -> (Tensor)", "aten::rsqrt : (Tensor) -> (Tensor)", "aten::abs : (Tensor) -> (Tensor)", "aten::reciprocal : (Tensor) -> (Tensor)", "aten::bitwise_and.Tensor : (Tensor, Tensor) -> (Tensor)", + "aten::bitwise_and.Scalar : (Tensor, Scalar) -> (Tensor)", "aten::bitwise_or.Tensor : (Tensor, Tensor) -> (Tensor)", "aten::bitwise_xor.Tensor : (Tensor, Tensor) -> (Tensor)", + "aten::bitwise_right_shift.Tensor : (Tensor, Tensor) -> (Tensor)", "aten::threshold : (Tensor, Scalar, Scalar) -> (Tensor)", "aten::square : (Tensor) -> (Tensor)", "aten::unsqueeze : (Tensor, int) -> (Tensor)", @@ -339,6 +342,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::imag : (Tensor) -> (Tensor)") emit("aten::view_as_complex : (Tensor) -> (Tensor)") emit("aten::view_as_real : (Tensor) -> (Tensor)") + 
emit("aten::isclose : (Tensor, Tensor, float, float, bool) -> (Tensor)") # Ops with dynamic number of outputs emit("aten::unbind_copy.int : (Tensor, int) -> (Tensor[])") @@ -359,6 +363,8 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::randn : (int[], int?, int?, Device?, bool?) -> (Tensor)") emit("aten::randn.generator : (int[], Generator?, int?, int?, Device?, bool?) -> (Tensor)") emit("aten::randn_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)") + emit("aten::random : (Tensor, Generator?) -> (Tensor)") + emit("aten::random.from : (Tensor, int, int?, Generator?) -> (Tensor)") emit_with_mutating_variants("aten::triu : (Tensor, int) -> (Tensor)") emit_with_mutating_variants("aten::tril : (Tensor, int) -> (Tensor)") @@ -426,11 +432,20 @@ def emit_with_mutating_variants(key, **kwargs): emit( "aten::max_pool3d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)" ) + emit( + "aten::avg_pool1d : (Tensor, int[], int[], int[], bool, bool) -> (Tensor)" + ) emit( "aten::avg_pool2d : (Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)" ) emit( - "aten::avg_pool1d : (Tensor, int[], int[], int[], bool, bool) -> (Tensor)" + "aten::avg_pool2d_backward : (Tensor, Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)" + ) + emit( + "aten::avg_pool3d : (Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)" + ) + emit( + "aten::avg_pool3d_backward : (Tensor, Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)" ) emit( "aten::softmax.int : (Tensor, int, int?) -> (Tensor)" @@ -444,14 +459,20 @@ def emit_with_mutating_variants(key, **kwargs): emit_with_mutating_variants("aten::scatter.src : (Tensor, int, Tensor, Tensor) -> (Tensor)") emit_with_mutating_variants("aten::scatter.value : (Tensor, int, Tensor, Scalar) -> (Tensor)") emit_with_mutating_variants("aten::masked_scatter : (Tensor, Tensor, Tensor) -> (Tensor)") - emit("aten::adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)") emit("aten::adaptive_avg_pool1d : (Tensor, int[]) -> (Tensor)") + emit("aten::adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)") + emit("aten::_adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)") + emit("aten::_adaptive_avg_pool2d_backward : (Tensor, Tensor) -> (Tensor)") + emit("aten::adaptive_avg_pool3d : (Tensor, int[]) -> (Tensor)") + emit("aten::_adaptive_avg_pool3d : (Tensor, int[]) -> (Tensor)") + emit("aten::_adaptive_avg_pool3d_backward : (Tensor, Tensor) -> (Tensor)") emit("aten::topk : (Tensor, int, int, bool, bool) -> (Tensor, Tensor)") emit("aten::transpose.int : (Tensor, int, int) -> (Tensor)") emit("aten::permute : (Tensor, int[]) -> (Tensor)") emit("aten::movedim.int : (Tensor, int, int) -> (Tensor)") emit("aten::bmm : (Tensor, Tensor) -> (Tensor)") emit("aten::cumsum : (Tensor, int, int?) -> (Tensor)") + emit("aten::cumprod : (Tensor, int, int?) -> (Tensor)") emit("aten::floor_divide.Scalar : (Tensor, Scalar) -> (Tensor)") emit("aten::logsumexp : (Tensor, int[], bool) -> (Tensor)") emit("aten::mean.dim : (Tensor, int[]?, bool, int?) 
-> (Tensor)") @@ -496,6 +517,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::squeeze.dim : (Tensor, int) -> (Tensor)", has_folder=True) emit("aten::squeeze : (Tensor) -> (Tensor)", has_folder=True) emit("aten::flatten.using_ints : (Tensor, int, int) -> (Tensor)") + emit("aten::unflatten.int : (Tensor, int, int[]) -> (Tensor)") emit("aten::dim : (Tensor) -> (int)", has_folder=True) emit("aten::size : (Tensor) -> (int[])", has_canonicalizer=True) emit("aten::Bool.Tensor : (Tensor) -> (bool)") @@ -557,6 +579,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::tile : (Tensor, int[]) -> (Tensor)") emit("aten::reshape : (Tensor, int[]) -> (Tensor)") emit("aten::_reshape_alias : (Tensor, int[], int[]) -> (Tensor)") + emit("aten::resize : (Tensor, int[], int?) -> (Tensor)") emit("aten::resize_ : (Tensor, int[], int?) -> (Tensor)") emit("aten::select.int : (Tensor, int, int) -> (Tensor)") emit("aten::size.int : (Tensor, int) -> (int)", has_folder=True) @@ -605,6 +628,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::fft_fft : (Tensor, int?, int, str?) -> (Tensor)") emit("aten::fmod.Tensor : (Tensor, Tensor) -> (Tensor)") emit("aten::unique_consecutive : (Tensor, bool, bool, int?) -> (Tensor, Tensor, Tensor)") + emit("aten::linspace : (Scalar, Scalar, int, int?, int?, Device?, bool?) -> (Tensor)") # Functionalization ops emit("aten::alias_copy : (Tensor) -> (Tensor)") diff --git a/python/torch_mlir_e2e_test/test_suite/__init__.py b/python/torch_mlir_e2e_test/test_suite/__init__.py index 7add08a3ecee..dce83de277ad 100644 --- a/python/torch_mlir_e2e_test/test_suite/__init__.py +++ b/python/torch_mlir_e2e_test/test_suite/__init__.py @@ -14,6 +14,7 @@ "NativeGroupNormBackwardModule_basic", "QuantizedMLP_basic", "ReduceMaxAlongDimUnsignedInt_basic", + "ElementwiseToDtypeI64ToUI8Module_basic", } # TODO: Delete once torch 2.1.0 is released diff --git a/python/torch_mlir_e2e_test/test_suite/basic.py b/python/torch_mlir_e2e_test/test_suite/basic.py index 1b5f62715a30..d78253a58fc3 100644 --- a/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/python/torch_mlir_e2e_test/test_suite/basic.py @@ -304,6 +304,28 @@ def AddmmModule_differentRankBroadcastable(module, tu: TestUtils): # ============================================================================== +class UnflattenStaticModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([1, 6, 4], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.unflatten(x, 1, (2, 3)) + + +@register_test_case(module_factory=lambda: UnflattenStaticModule()) +def UnflattenStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 6, 4)) + + +# ============================================================================== + + class FlattenStaticModule(torch.nn.Module): def __init__(self): @@ -4558,3 +4580,48 @@ def forward(self, x): @register_test_case(module_factory=lambda: Add_Module()) def Add_Module_basic(module, tu: TestUtils): module.forward(tu.rand(2, 3)) + + +# ============================================================================== + + +class IscloseStaticModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([5, 5], torch.float32, True), + ([5, 5], torch.float32, True), + ]) + def forward(self, x, y): + return torch.isclose(x, y) + + +@register_test_case(module_factory=lambda: IscloseStaticModule()) +def IscloseStaticModule_basic(module, tu: TestUtils): + 
module.forward(tu.rand(5, 5), tu.rand(5, 5)) + + +# ============================================================================== + + +class IscloseStaticModuleTrue(torch.nn.Module): + + def __init__(self): + super().__init__() + self.register_buffer('tensor', torch.ones(1)) + + @export + @annotate_args([ + None, + ([5, 5], torch.float32, True), + ]) + def forward(self, x): + return torch.isclose(x, self.tensor) + +@register_test_case(module_factory=lambda: IscloseStaticModuleTrue()) +def IscloseStaticModuleTrue_basic(module, tu: TestUtils): + module.forward(torch.ones(5, 5)) diff --git a/python/torch_mlir_e2e_test/test_suite/elementwise.py b/python/torch_mlir_e2e_test/test_suite/elementwise.py index a2e3e8e29608..3b2997c3e482 100644 --- a/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -1642,6 +1642,44 @@ def ElementwiseToDtypeIdentityModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseToDtypeI64ToI8Module(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([-1, -1], torch.int64, True)]) + def forward(self, x): + return x.to(torch.int8) + + +@register_test_case(module_factory=lambda: ElementwiseToDtypeI64ToI8Module()) +def ElementwiseToDtypeI64ToI8Module_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=-100, high=100)) + + +# ============================================================================== + + +class ElementwiseToDtypeI64ToUI8Module(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([-1, -1], torch.int64, True)]) + def forward(self, x): + return x.to(torch.uint8) + + +@register_test_case(module_factory=lambda: ElementwiseToDtypeI64ToUI8Module()) +def ElementwiseToDtypeI64ToUI8Module_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=-100, high=100)) + + +# ============================================================================== + + class ElementwiseLog2Module(torch.nn.Module): def __init__(self): @@ -1683,6 +1721,48 @@ def ElementwiseLog2IntModule_basic(module, tu: TestUtils): module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32)) +# ============================================================================== + +class ElementwiseLog10Module(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.log10(a) + + +@register_test_case(module_factory=lambda: ElementwiseLog10Module()) +def ElementwiseLog10Module_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4)) + + +# ============================================================================== + +class ElementwiseLog10IntModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int32, True), + ]) + def forward(self, a): + return torch.log10(a) + + +@register_test_case(module_factory=lambda: ElementwiseLog10IntModule()) +def ElementwiseLog10IntModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32)) + + # ============================================================================== @@ -2236,6 +2316,31 @@ def ElementwiseBitwiseNotInt32Module_basic(module, tu: TestUtils): # ============================================================================== +class 
ElementwiseSubTensorInt8Module(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int8, True), + ([-1, -1], torch.int8, True), + ]) + def forward(self, x, y): + return torch.sub(x, y, alpha=2) + + +@register_test_case(module_factory=lambda: ElementwiseSubTensorInt8Module()) +def ElementwiseSubTensorInt8Module_basic(module, tu: TestUtils): + module.forward( + tu.randint(3, 4, high=10).to(dtype=torch.int8), + tu.randint(3, 4, high=10).to(dtype=torch.int8)) + + +# ============================================================================== + + class ElementwiseSubScalarIntModule(torch.nn.Module): def __init__(self): @@ -2392,6 +2497,28 @@ def ElementwiseAddScalar_TensorLiteralInt32_Module_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseAddScalarInt8Module(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int8, True), + ]) + def forward(self, x): + return torch.add(x, 3, 2) + + +@register_test_case(module_factory=lambda: ElementwiseAddScalarInt8Module()) +def ElementwiseAddScalarInt8Module_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, high=10).to(torch.int8)) + + +# ============================================================================== + + class ElementwiseCloneModule(torch.nn.Module): def __init__(self): @@ -3435,3 +3562,126 @@ def forward(self, a, b): @register_test_case(module_factory=lambda: TupleModule()) def TupleModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 2), tu.rand(2, 2)) + + +# ============================================================================== + + +class ElementwiseBitwiseRightShiftInt64Module(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int64, True), + ([-1, -1], torch.int64, True), + ]) + def forward(self, lhs, rhs): + return torch.bitwise_right_shift(lhs, rhs) + + +@register_test_case(module_factory=lambda: ElementwiseBitwiseRightShiftInt64Module()) +def ElementwiseBitwiseRightShiftInt64Module_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=-1000, high=1000), tu.randint(3, 4, low=0, high=64)) + + +class ElementwiseBitwiseRightShiftInt32Module(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, 4], torch.int32, True), + ([-1, 1], torch.int32, True), + ]) + def forward(self, lhs, rhs): + return torch.bitwise_right_shift(lhs, rhs) + + +@register_test_case(module_factory=lambda: ElementwiseBitwiseRightShiftInt32Module()) +def ElementwiseBitwiseRightShiftInt32Module_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=-1000, high=1000).to(torch.int32), tu.randint(3, 1, low=0, high=32).to(torch.int32)) + + +class ElementwiseBitwiseRightShiftInt8Module(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int8, True), + ([-1, -1], torch.int8, True), + ]) + def forward(self, lhs, rhs): + return torch.bitwise_right_shift(lhs, rhs) + + +@register_test_case(module_factory=lambda: ElementwiseBitwiseRightShiftInt8Module()) +def ElementwiseBitwiseRightShiftInt8Module_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=-100, high=100).to(torch.int8), tu.randint(3, 4, low=0, high=8).to(torch.int8)) + + +# 
============================================================================== + + +class ElementwiseBitwiseAndScalarInt64Module(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int64, True), + ]) + def forward(self, x): + return torch.bitwise_and(x, 15) + + +@register_test_case(module_factory=lambda: ElementwiseBitwiseAndScalarInt64Module()) +def ElementwiseBitwiseAndScalarInt64Module_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=-1000, high=1000)) + + +class ElementwiseBitwiseAndScalarInt32Module(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int32, True), + ]) + def forward(self, x): + return torch.bitwise_and(x, 100) + + +@register_test_case(module_factory=lambda: ElementwiseBitwiseAndScalarInt32Module()) +def ElementwiseBitwiseAndScalarInt32Module_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=-1000, high=1000).to(torch.int32)) + + +class ElementwiseBitwiseAndScalarInt8Module(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int8, True), + ]) + def forward(self, x): + return torch.bitwise_and(x, 100) + + +@register_test_case(module_factory=lambda: ElementwiseBitwiseAndScalarInt8Module()) +def ElementwiseBitwiseAndScalarInt8Module_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=-1000, high=1000).to(torch.int8)) diff --git a/python/torch_mlir_e2e_test/test_suite/reshape_like.py b/python/torch_mlir_e2e_test/test_suite/reshape_like.py index 1c2d810c13de..304d3025eb8d 100644 --- a/python/torch_mlir_e2e_test/test_suite/reshape_like.py +++ b/python/torch_mlir_e2e_test/test_suite/reshape_like.py @@ -672,6 +672,108 @@ def forward(self, a): def ViewNegativeStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(1, 128)) +class ViewSizeDimFollowedByExpandedOnesModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1], torch.float32, True), + ]) + + def forward(self, a): + return a.view(a.size(0), 1, 1, 1) + +@register_test_case(module_factory=lambda: ViewSizeDimFollowedByExpandedOnesModule()) +def ViewSizeDimFollowedByExpandedOnesModule_basic(module, tu: TestUtils): + module.forward(tu.rand(128)) + +class ViewSizeDimFollowedByCollapsedOnesModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, 1, 1, 1], torch.float32, True), + ]) + + def forward(self, a): + return a.view(a.size(0)) + +@register_test_case(module_factory=lambda: ViewSizeDimFollowedByCollapsedOnesModule()) +def ViewSizeDimFollowedByCollapsedOnesModule_basic(module, tu: TestUtils): + module.forward(tu.rand(128, 1, 1, 1)) + +class ViewSizeDimLedByExpandedOnesModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1], torch.float32, True), + ]) + + def forward(self, a): + return a.view(1, 1, 1, a.size(0)) + +@register_test_case(module_factory=lambda: ViewSizeDimLedByExpandedOnesModule()) +def ViewSizeDimLedByExpandedOnesModule_basic(module, tu: TestUtils): + module.forward(tu.rand(128)) + +class ViewSizeDimLedByCollapsedOnesModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([1, 1, 1, -1], torch.float32, True), + ]) + + def forward(self, a): + return a.view(a.size(3)) + +@register_test_case(module_factory=lambda: 
ViewSizeDimLedByCollapsedOnesModule()) +def ViewSizeDimLedByCollapsedOnesModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 1, 128)) + +class ViewSizeDimLedAndFollowedByExpandedOnesModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1], torch.float32, True), + ]) + + def forward(self, a): + return a.view(1, 1, 1, a.size(0), 1, 1, 1) + +@register_test_case(module_factory=lambda: ViewSizeDimLedAndFollowedByExpandedOnesModule()) +def ViewSizeDimLedAndFollowedByExpandedOnesModule_basic(module, tu: TestUtils): + module.forward(tu.rand(128)) + +class ViewSizeDimLedAndFollowedByCollapsedOnesModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([1, 1, 1, -1, 1, 1, 1], torch.float32, True), + ]) + + def forward(self, a): + return a.view(a.size(3)) + +@register_test_case(module_factory=lambda: ViewSizeDimLedAndFollowedByCollapsedOnesModule()) +def ViewSizeDimLedAndFollowedByCollapsedOnesModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 1, 128, 1, 1, 1)) + # ============================================================================== class ReshapeAliasExpandModule(torch.nn.Module): @@ -710,4 +812,4 @@ def forward(self, a): @register_test_case(module_factory=lambda: ReshapeAliasCollapseModule()) def ReshapeAliasCollapseModule_basic(module, tu: TestUtils): - module.forward(tu.rand(2, 4)) \ No newline at end of file + module.forward(tu.rand(2, 4)) diff --git a/python/torch_mlir_e2e_test/test_suite/type_conversion.py b/python/torch_mlir_e2e_test/test_suite/type_conversion.py index 5a4a19c50aba..6e04c5fa8700 100644 --- a/python/torch_mlir_e2e_test/test_suite/type_conversion.py +++ b/python/torch_mlir_e2e_test/test_suite/type_conversion.py @@ -169,6 +169,28 @@ def forward(self, x): def ToDtypeLayoutNoneModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 5)) +class ToDtypeLayoutCPUModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([-1, -1], torch.float32, True)]) + def forward(self, x): + return torch.ops.aten.to(x, + dtype=torch.float64, + layout=None, + device="cpu", + pin_memory=None, + non_blocking=False, + copy=False, + memory_format=None) + + +@register_test_case(module_factory=lambda: ToDtypeLayoutCPUModule()) +def ToDtypeLayoutCPUModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 5)) + class ToDtypeLayoutStridedModule(torch.nn.Module): diff --git a/pytorch-hash.txt b/pytorch-hash.txt index b45361a5173e..5a74e57f8cef 100644 --- a/pytorch-hash.txt +++ b/pytorch-hash.txt @@ -1 +1 @@ -2baa4c49288efeded2fad677b2f28570b0ce858b +20217d1426d99d0caa70e1473d89e0c834b7f35e diff --git a/pytorch-requirements.txt b/pytorch-requirements.txt index 76e15ba4dc62..6f00db5ca012 100644 --- a/pytorch-requirements.txt +++ b/pytorch-requirements.txt @@ -1,3 +1,3 @@ -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre -torch==2.2.0.dev20230913 +torch==2.2.0.dev20231006 diff --git a/test/Conversion/TorchToLinalg/basic.mlir b/test/Conversion/TorchToLinalg/basic.mlir index d95b7e1d87cf..470962e2494d 100644 --- a/test/Conversion/TorchToLinalg/basic.mlir +++ b/test/Conversion/TorchToLinalg/basic.mlir @@ -8,11 +8,11 @@ // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[LHS_DIM_0:.*]] = tensor.dim %[[LHS]], %[[C0]] : tensor // CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[RHS_DIM_1:.*]] = tensor.dim %[[RHS]], %[[C1]] : tensor +// CHECK: %[[C1:.*]] = 
arith.constant 1 : index // CHECK: %[[LHS_DIM_1:.*]] = tensor.dim %[[LHS]], %[[C1]] : tensor // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[RHS_DIM_0:.*]] = tensor.dim %[[RHS]], %[[C0]] : tensor -// CHECK: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[RHS_DIM_1:.*]] = tensor.dim %[[RHS]], %[[C1]] : tensor // CHECK: %[[EQ:.*]] = arith.cmpi eq, %[[LHS_DIM_1]], %[[RHS_DIM_0]] : index // CHECK: assert %[[EQ]], "mismatching contracting dimension for torch.aten.mm" // CHECK: %[[INIT_TENSOR:.*]] = tensor.empty(%[[LHS_DIM_0]], %[[RHS_DIM_1]]) : tensor @@ -29,6 +29,17 @@ func.func @torch.aten.mm$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.v // ----- +// CHECK-LABEL: func.func @torch.aten.mm$basic_strict( +// CHECK-NOT: assert +func.func @torch.aten.mm$basic_strict(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,2],f32> + attributes {torch.assume_strict_symbolic_shapes} +{ + %0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,2],f32> + return %0 : !torch.vtensor<[?,2],f32> +} + +// ----- + // If the operands are missing dtype, we cannot lower it. func.func @torch.aten.mm$no_convert$missing_dtype(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor { // expected-error@+1 {{failed to legalize}} @@ -264,4 +275,4 @@ func.func @torch.aten.neg.bf16(%arg0: !torch.vtensor<[?,?],bf16>) -> !torch.vten func.func @torch.aten.neg.f16(%arg0: !torch.vtensor<[?,?],f16>) -> !torch.vtensor<[?,?],f16> { %0 = torch.aten.neg %arg0 : !torch.vtensor<[?,?],f16> -> !torch.vtensor<[?,?],f16> return %0 : !torch.vtensor<[?,?],f16> -} \ No newline at end of file +} diff --git a/test/Conversion/TorchToLinalg/broadcast.mlir b/test/Conversion/TorchToLinalg/broadcast.mlir new file mode 100644 index 000000000000..8841ba704328 --- /dev/null +++ b/test/Conversion/TorchToLinalg/broadcast.mlir @@ -0,0 +1,90 @@ +// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -canonicalize -split-input-file -mlir-print-local-scope -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: func.func @torch.aten.broadcast_to$simple_static( +// CHECK: %[[INIT_TENSOR:.*]] = tensor.empty() : tensor<3x4x2xf32> +// CHECK: %[[GENERIC:.*]] = linalg.generic +// CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>] +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]} +// CHECK-SAME: ins({{.*}} : tensor<4x2xf32>) outs({{.*}} : tensor<3x4x2xf32>) { +// CHECK: ^bb0(%[[IN:.*]]: f32, %{{.*}}: f32): +// CHECK: linalg.yield %[[IN]] : f32 +// CHECK: } -> tensor<3x4x2xf32> +func.func @torch.aten.broadcast_to$simple_static(%arg0: !torch.vtensor<[4,2],f32>) -> !torch.vtensor<[3,4,2],f32> { + %int3 = torch.constant.int 3 + %int4 = torch.constant.int 4 + %int2 = torch.constant.int 2 + %list = torch.prim.ListConstruct %int3, %int4, %int2 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %0 = torch.aten.broadcast_to %arg0, %list : !torch.vtensor<[4,2],f32>, !torch.list -> !torch.vtensor<[3,4,2],f32> + return %0 : !torch.vtensor<[3,4,2],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.broadcast_to$static_numpy_broadcast( +// CHECK: %[[INIT_TENSOR:.*]] = tensor.empty() : tensor<1x4x2xf32> +// CHECK: %[[GENERIC:.*]] = linalg.generic +// CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2) -> (d0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>] +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]} +// CHECK-SAME: ins({{.*}} : 
tensor<1x1x2xf32>) outs({{.*}} : tensor<1x4x2xf32>) { +// CHECK: ^bb0(%[[IN:.*]]: f32, %{{.*}}: f32): +// CHECK: linalg.yield %[[IN]] : f32 +// CHECK: } -> tensor<1x4x2xf32> +func.func @torch.aten.broadcast_to$static_numpy_broadcast(%arg0: !torch.vtensor<[1,1,2],f32>) -> !torch.vtensor<[1,4,2],f32> { + %int1 = torch.constant.int 1 + %int4 = torch.constant.int 4 + %int2 = torch.constant.int 2 + %list = torch.prim.ListConstruct %int1, %int4, %int2 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %0 = torch.aten.broadcast_to %arg0, %list : !torch.vtensor<[1,1,2],f32>, !torch.list -> !torch.vtensor<[1,4,2],f32> + return %0 : !torch.vtensor<[1,4,2],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.broadcast_to$empty_input( +// CHECK: %[[INIT_TENSOR:.*]] = tensor.empty({{.*}}) : tensor +// CHECK: %[[GENERIC:.*]] = linalg.generic +// CHECK-SAME: indexing_maps = [affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>] +// CHECK-SAME: iterator_types = ["parallel"]} +// CHECK-SAME: ins({{.*}} : tensor) outs({{.*}} : tensor) { +// CHECK: ^bb0(%[[IN:.*]]: f32, %{{.*}}: f32): +// CHECK: linalg.yield %[[IN]] : f32 +// CHECK: } -> tensor +func.func @torch.aten.broadcast_to$empty_input(%arg0: !torch.vtensor<[],f32>, %arg1: !torch.int) -> !torch.vtensor<[?],f32> { + %list = torch.prim.ListConstruct %arg1 : (!torch.int) -> !torch.list + %0 = torch.aten.broadcast_to %arg0, %list : !torch.vtensor<[],f32>, !torch.list -> !torch.vtensor<[?],f32> + return %0 : !torch.vtensor<[?],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.broadcast_to$strict_dynamic_broadcast( +// CHECK: %[[INIT_TENSOR:.*]] = tensor.empty({{.*}}) : tensor +// CHECK: %[[GENERIC:.*]] = linalg.generic +// CHECK-SAME: indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>] +// CHECK-SAME: iterator_types = ["parallel", "parallel"]} +// CHECK-SAME: ins({{.*}} : tensor) outs({{.*}} : tensor) { +// CHECK: ^bb0(%[[IN:.*]]: f32, %{{.*}}: f32): +// CHECK: linalg.yield %[[IN]] : f32 +// CHECK: } -> tensor +func.func @torch.aten.broadcast_to$strict_dynamic_broadcast(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.int, %arg2: !torch.int) -> !torch.vtensor<[?,?],f32> attributes {torch.assume_strict_symbolic_shapes} { + %list = torch.prim.ListConstruct %arg1, %arg2 : (!torch.int, !torch.int) -> !torch.list + %0 = torch.aten.broadcast_to %arg0, %list : !torch.vtensor<[?],f32>, !torch.list -> !torch.vtensor<[?,?],f32> + return %0 : !torch.vtensor<[?,?],f32> +} + +// ----- + +/// Nothing we can do; verify we hit the fall back path. 
+// CHECK-LABEL: func.func @torch.aten.broadcast_to$pure_dynamic_broadcast( +// CHECK: %[[INIT_TENSOR:.*]] = tensor.empty({{.*}}) : tensor +// CHECK: %[[GENERIC:.*]] = linalg.generic +// CHECK-SAME: indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>] +// CHECK-SAME: iterator_types = ["parallel", "parallel"]} +// CHECK-SAME: outs({{.*}} : tensor) { +// CHECK: ^bb0(%[[OUT:.+]]: f32): +// CHECK: tensor.extract +func.func @torch.aten.broadcast_to$pure_dynamic_broadcast(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.int, %arg2: !torch.int) -> !torch.vtensor<[?,?],f32> { + %list = torch.prim.ListConstruct %arg1, %arg2 : (!torch.int, !torch.int) -> !torch.list + %0 = torch.aten.broadcast_to %arg0, %list : !torch.vtensor<[?],f32>, !torch.list -> !torch.vtensor<[?,?],f32> + return %0 : !torch.vtensor<[?,?],f32> +} diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index 49907a98a56d..180f48bcef2b 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -41,16 +41,16 @@ func.func @torch.aten.relu$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vte // ----- // CHECK-LABEL: func.func @torch.aten.leaky_relu$basic( -// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { -// CHECK: %[[VAL_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[VAL_1:.*]] = torch.constant.float 1.000000e-01 -// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<1.000000e-01> : tensor}> : () -> tensor -// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_4:.*]] = tosa.greater_equal %[[VAL_0]], %[[VAL_3]] : (tensor, tensor) -> tensor -// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_0]], %[[VAL_2]] {shift = 0 : i32} : (tensor, tensor) -> tensor -// CHECK: %[[VAL_6:.*]] = tosa.select %[[VAL_4]], %[[VAL_0]], %[[VAL_5]] : (tensor, tensor, tensor) -> tensor -// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor -> !torch.vtensor<[?,?],f32> -// CHECK: return %[[VAL_7]] : !torch.vtensor<[?,?],f32> +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_2:.*]] = torch.constant.float 1.000000e-01 +// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<1.000000e-01> : tensor}> : () -> tensor +// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor}> : () -> tensor +// CHECK: %[[VAL_5:.*]] = tosa.greater_equal %[[VAL_1]], %[[VAL_4]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_1]], %[[VAL_3]] {shift = 0 : i8} : (tensor, tensor) -> tensor +// CHECK: %[[VAL_7:.*]] = tosa.select %[[VAL_5]], %[[VAL_1]], %[[VAL_6]] : (tensor, tensor, tensor) -> tensor +// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32> // CHECK: } func.func @torch.aten.leaky_relu$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { %fp0 = torch.constant.float 1.000000e-01 @@ -155,13 +155,13 @@ func.func @torch.aten.reciprocal$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !tor // ----- // CHECK-LABEL: func.func @torch.aten.add$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>, -// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK-SAME: %[[VAL_0:.*]]: 
!torch.vtensor<[?,?],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_4:.*]] = torch.constant.int 1 // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i32} : (tensor, tensor) -> tensor +// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i8} : (tensor, tensor) -> tensor // CHECK: %[[VAL_7:.*]] = tosa.add %[[VAL_2]], %[[VAL_6]] : (tensor, tensor) -> tensor // CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32> @@ -175,13 +175,13 @@ func.func @torch.aten.add$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch // ----- // CHECK-LABEL: func.func @torch.aten.sub$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>, -// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_4:.*]] = torch.constant.int 1 // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i32} : (tensor, tensor) -> tensor +// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i8} : (tensor, tensor) -> tensor // CHECK: %[[VAL_7:.*]] = tosa.sub %[[VAL_2]], %[[VAL_6]] : (tensor, tensor) -> tensor // CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32> @@ -195,13 +195,14 @@ func.func @torch.aten.sub$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch // ----- // CHECK-LABEL: func.func @torch.aten.mul$basic( -// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, -// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { -// CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[ARG1_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[RESULT_BUILTIN:.*]] = tosa.mul %[[ARG0_BUILTIN]], %[[ARG1_BUILTIN]] {shift = 0 : i32} : (tensor, tensor) -> tensor -// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor -> !torch.vtensor<[?,?],f32> -// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32> +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_2]], %[[VAL_3]] {shift = 0 : i8} : (tensor, tensor) -> tensor +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor -> 
!torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],f32> +// CHECK: } func.func @torch.aten.mul$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> { %0 = torch.aten.mul.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],f32> -> !torch.vtensor<[?, ?],f32> return %0 : !torch.vtensor<[?, ?],f32> @@ -210,14 +211,15 @@ func.func @torch.aten.mul$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch // ----- // CHECK-LABEL: func.func @torch.aten.div$basic( -// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, -// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { -// CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[ARG1_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[RCP:.*]] = tosa.reciprocal %[[ARG1_BUILTIN]] : (tensor) -> tensor -// CHECK: %[[RESULT_BUILTIN:.*]] = tosa.mul %[[ARG0_BUILTIN]], %[[RCP]] {shift = 0 : i32} : (tensor, tensor) -> tensor -// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor -> !torch.vtensor<[?,?],f32> -// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32> +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_4:.*]] = tosa.reciprocal %[[VAL_3]] : (tensor) -> tensor +// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_2]], %[[VAL_4]] {shift = 0 : i8} : (tensor, tensor) -> tensor +// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_6]] : !torch.vtensor<[?,?],f32> +// CHECK: } func.func @torch.aten.div$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> { %0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],f32> -> !torch.vtensor<[?, ?],f32> return %0 : !torch.vtensor<[?, ?],f32> @@ -394,13 +396,13 @@ func.func @torch.aten.pow.Tensor_Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) // ----- // CHECK-LABEL: func.func @torch.aten.rsub.Scalar$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_2:.*]] = torch.constant.float 3.123400e+00 // CHECK: %[[VAL_3:.*]] = torch.constant.float 6.432100e+00 // CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<3.123400e+00> : tensor}> : () -> tensor // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<6.432100e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_1]], %[[VAL_5]] {shift = 0 : i32} : (tensor, tensor) -> tensor +// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_1]], %[[VAL_5]] {shift = 0 : i8} : (tensor, tensor) -> tensor // CHECK: %[[VAL_7:.*]] = tosa.sub %[[VAL_4]], %[[VAL_6]] : (tensor, tensor) -> tensor // CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32> @@ -415,13 +417,13 @@ func.func 
@torch.aten.rsub.Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !to // ----- // CHECK-LABEL: func.func @torch.aten.rsub.Scalar$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_2:.*]] = torch.constant.float 3.123400e+00 // CHECK: %[[VAL_3:.*]] = torch.constant.int 1 // CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<3.123400e+00> : tensor}> : () -> tensor // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_1]], %[[VAL_5]] {shift = 0 : i32} : (tensor, tensor) -> tensor +// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_1]], %[[VAL_5]] {shift = 0 : i8} : (tensor, tensor) -> tensor // CHECK: %[[VAL_7:.*]] = tosa.sub %[[VAL_4]], %[[VAL_6]] : (tensor, tensor) -> tensor // CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32> @@ -502,7 +504,7 @@ func.func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !to // ----- // CHECK-LABEL: func.func @torch.aten.native_batch_norm$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[10,4,3],f32>) -> !torch.vtensor<[10,4,3],f32> { +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[10,4,3],f32>) -> !torch.vtensor<[10,4,3],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[10,4,3],f32> -> tensor<10x4x3xf32> // CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<[5.000000e-01, 4.000000e-01, 3.000000e-01, 6.000000e-01]> : tensor<4xf32>}> : () -> tensor<4xf32> // CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<[3.000000e+00, 2.000000e+00, 4.000000e+00, 5.000000e+00]> : tensor<4xf32>}> : () -> tensor<4xf32> @@ -518,8 +520,8 @@ func.func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !to // CHECK: %[[VAL_13:.*]] = tosa.sub %[[VAL_1]], %[[VAL_8]] : (tensor<10x4x3xf32>, tensor<4x1xf32>) -> tensor<10x4x3xf32> // CHECK: %[[VAL_14:.*]] = tosa.add %[[VAL_9]], %[[VAL_12]] : (tensor<4x1xf32>, tensor) -> tensor<4x1xf32> // CHECK: %[[VAL_15:.*]] = tosa.rsqrt %[[VAL_14]] : (tensor<4x1xf32>) -> tensor<4x1xf32> -// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_13]], %[[VAL_15]] {shift = 0 : i32} : (tensor<10x4x3xf32>, tensor<4x1xf32>) -> tensor<10x4x3xf32> -// CHECK: %[[VAL_17:.*]] = tosa.mul %[[VAL_16]], %[[VAL_10]] {shift = 0 : i32} : (tensor<10x4x3xf32>, tensor<4x1xf32>) -> tensor<10x4x3xf32> +// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_13]], %[[VAL_15]] {shift = 0 : i8} : (tensor<10x4x3xf32>, tensor<4x1xf32>) -> tensor<10x4x3xf32> +// CHECK: %[[VAL_17:.*]] = tosa.mul %[[VAL_16]], %[[VAL_10]] {shift = 0 : i8} : (tensor<10x4x3xf32>, tensor<4x1xf32>) -> tensor<10x4x3xf32> // CHECK: %[[VAL_18:.*]] = tosa.add %[[VAL_17]], %[[VAL_11]] : (tensor<10x4x3xf32>, tensor<4x1xf32>) -> tensor<10x4x3xf32> // CHECK: %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<10x4x3xf32> -> !torch.vtensor<[10,4,3],f32> // CHECK: return %[[VAL_19]] : !torch.vtensor<[10,4,3],f32> @@ -538,7 +540,7 @@ func.func @torch.aten.native_batch_norm$basic(%arg0: !torch.vtensor<[10,4,3],f32 // ----- // CHECK-LABEL: func.func @forward( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[10,3,8,9,3,4],f32>) -> !torch.vtensor<[10,3,?,4],f32> { +// CHECK-SAME: %[[VAL_0:.*]]: 
!torch.vtensor<[10,3,8,9,3,4],f32>) -> !torch.vtensor<[10,3,?,4],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[10,3,8,9,3,4],f32> -> tensor<10x3x8x9x3x4xf32> // CHECK: %[[VAL_2:.*]] = torch.constant.int 4 // CHECK: %[[VAL_3:.*]] = torch.constant.int 2 @@ -557,9 +559,31 @@ func.func @forward(%arg0: !torch.vtensor<[10,3,8,9,3,4],f32> ) -> !torch.vtensor // ----- // CHECK-LABEL: func.func @forward( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,2,2,3],f32>, -// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2,2,3],f32>, -// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[2,2,3],f32>) -> !torch.vtensor<[5,2,2,3],f32> { +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,6,4],f32>) -> !torch.vtensor<[1,2,3,4],f32> { +// CHECK: %[[VAL:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,6,4],f32> -> tensor<1x6x4xf32> +// CHECK: %[[VAL_1:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_2:.*]] = torch.constant.int 2 +// CHECK: %[[VAL_3:.*]] = torch.constant.int 3 +// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %[[VAL_2]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL]] {new_shape = array} : (tensor<1x6x4xf32>) -> tensor<1x2x3x4xf32> +// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<1x2x3x4xf32> -> !torch.vtensor<[1,2,3,4],f32> +// CHECK: return %[[VAL_6]] : !torch.vtensor<[1,2,3,4],f32> +// CHECK: } +func.func @forward(%arg0: !torch.vtensor<[1,6,4],f32> ) -> !torch.vtensor<[1,2,3,4],f32> { + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %int3 = torch.constant.int 3 + %0 = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list + %1 = torch.aten.unflatten.int %arg0, %int1, %0 : !torch.vtensor<[1,6,4],f32>, !torch.int, !torch.list -> !torch.vtensor<[1,2,3,4],f32> + return %1 : !torch.vtensor<[1,2,3,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @forward( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,2,2,3],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2,2,3],f32>, +// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[2,2,3],f32>) -> !torch.vtensor<[5,2,2,3],f32> { // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[5,2,2,3],f32> -> tensor<5x2x2x3xf32> // CHECK: %[[VAL_4:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2,2,3],f32> -> tensor<2x2x3xf32> // CHECK: %[[VAL_5:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[2,2,3],f32> -> tensor<2x2x3xf32> @@ -573,22 +597,22 @@ func.func @forward(%arg0: !torch.vtensor<[10,3,8,9,3,4],f32> ) -> !torch.vtensor // CHECK: %[[VAL_13:.*]] = tosa.reduce_sum %[[VAL_12]] {axis = 2 : i32} : (tensor<5x2x2x1xf32>) -> tensor<5x2x1x1xf32> // CHECK: %[[VAL_14:.*]] = tosa.reduce_sum %[[VAL_13]] {axis = 1 : i32} : (tensor<5x2x1x1xf32>) -> tensor<5x1x1x1xf32> // CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_14]] {new_shape = array} : (tensor<5x1x1x1xf32>) -> tensor<5x1x1x1xf32> -// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_15]], %[[VAL_11]] {shift = 0 : i32} : (tensor<5x1x1x1xf32>, tensor<1xf32>) -> tensor<5x1x1x1xf32> +// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_15]], %[[VAL_11]] {shift = 0 : i8} : (tensor<5x1x1x1xf32>, tensor<1xf32>) -> tensor<5x1x1x1xf32> // CHECK: %[[VAL_17:.*]] = tosa.sub %[[VAL_3]], %[[VAL_16]] : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32> -// CHECK: %[[VAL_18:.*]] = tosa.mul %[[VAL_17]], %[[VAL_17]] {shift = 0 : i32} : (tensor<5x2x2x3xf32>, tensor<5x2x2x3xf32>) -> tensor<5x2x2x3xf32> +// CHECK: %[[VAL_18:.*]] = 
tosa.mul %[[VAL_17]], %[[VAL_17]] {shift = 0 : i8} : (tensor<5x2x2x3xf32>, tensor<5x2x2x3xf32>) -> tensor<5x2x2x3xf32> // CHECK: %[[VAL_19:.*]] = tosa.reduce_sum %[[VAL_18]] {axis = 3 : i32} : (tensor<5x2x2x3xf32>) -> tensor<5x2x2x1xf32> // CHECK: %[[VAL_20:.*]] = tosa.reduce_sum %[[VAL_19]] {axis = 2 : i32} : (tensor<5x2x2x1xf32>) -> tensor<5x2x1x1xf32> // CHECK: %[[VAL_21:.*]] = tosa.reduce_sum %[[VAL_20]] {axis = 1 : i32} : (tensor<5x2x1x1xf32>) -> tensor<5x1x1x1xf32> // CHECK: %[[VAL_22:.*]] = tosa.reshape %[[VAL_21]] {new_shape = array} : (tensor<5x1x1x1xf32>) -> tensor<5x1x1x1xf32> -// CHECK: %[[VAL_23:.*]] = tosa.mul %[[VAL_22]], %[[VAL_11]] {shift = 0 : i32} : (tensor<5x1x1x1xf32>, tensor<1xf32>) -> tensor<5x1x1x1xf32> +// CHECK: %[[VAL_23:.*]] = tosa.mul %[[VAL_22]], %[[VAL_11]] {shift = 0 : i8} : (tensor<5x1x1x1xf32>, tensor<1xf32>) -> tensor<5x1x1x1xf32> // CHECK: %[[VAL_24:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor<2x2x3xf32>) -> tensor<1x2x2x3xf32> // CHECK: %[[VAL_25:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array} : (tensor<2x2x3xf32>) -> tensor<1x2x2x3xf32> // CHECK: %[[VAL_26:.*]] = "tosa.const"() <{value = dense<5.000000e-01> : tensor}> : () -> tensor // CHECK: %[[VAL_27:.*]] = tosa.sub %[[VAL_3]], %[[VAL_16]] : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32> // CHECK: %[[VAL_28:.*]] = tosa.add %[[VAL_23]], %[[VAL_26]] : (tensor<5x1x1x1xf32>, tensor) -> tensor<5x1x1x1xf32> // CHECK: %[[VAL_29:.*]] = tosa.rsqrt %[[VAL_28]] : (tensor<5x1x1x1xf32>) -> tensor<5x1x1x1xf32> -// CHECK: %[[VAL_30:.*]] = tosa.mul %[[VAL_27]], %[[VAL_29]] {shift = 0 : i32} : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32> -// CHECK: %[[VAL_31:.*]] = tosa.mul %[[VAL_30]], %[[VAL_24]] {shift = 0 : i32} : (tensor<5x2x2x3xf32>, tensor<1x2x2x3xf32>) -> tensor<5x2x2x3xf32> +// CHECK: %[[VAL_30:.*]] = tosa.mul %[[VAL_27]], %[[VAL_29]] {shift = 0 : i8} : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32> +// CHECK: %[[VAL_31:.*]] = tosa.mul %[[VAL_30]], %[[VAL_24]] {shift = 0 : i8} : (tensor<5x2x2x3xf32>, tensor<1x2x2x3xf32>) -> tensor<5x2x2x3xf32> // CHECK: %[[VAL_32:.*]] = tosa.add %[[VAL_31]], %[[VAL_25]] : (tensor<5x2x2x3xf32>, tensor<1x2x2x3xf32>) -> tensor<5x2x2x3xf32> // CHECK: %[[VAL_33:.*]] = torch_c.from_builtin_tensor %[[VAL_32]] : tensor<5x2x2x3xf32> -> !torch.vtensor<[5,2,2,3],f32> // CHECK: return %[[VAL_33]] : !torch.vtensor<[5,2,2,3],f32> @@ -661,12 +685,12 @@ func.func @torch.aten.bitwise_and.Tensor$basic(%arg0: !torch.vtensor<[?,?],si32> // ----- // CHECK-LABEL: func.func @torch.aten.log2$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<0.693147182> : tensor<1x1xf32>}> : () -> tensor<1x1xf32> // CHECK: %[[VAL_3:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor<1x1xf32>) -> tensor<1x1xf32> // CHECK: %[[VAL_4:.*]] = tosa.log %[[VAL_1]] : (tensor) -> tensor -// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_4]], %[[VAL_3]] {shift = 0 : i32} : (tensor, tensor<1x1xf32>) -> tensor +// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_4]], %[[VAL_3]] {shift = 0 : i8} : (tensor, tensor<1x1xf32>) -> tensor // CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[VAL_6]] : !torch.vtensor<[?,?],f32> 
// CHECK: } @@ -974,7 +998,7 @@ func.func @torch.aten.to.dtype(%arg0: !torch.vtensor<[1,128],i1>) -> !torch.vten // CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor<1x4x3xf32>) -> tensor<1x12x1xf32> // CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_10]] {new_shape = array} : (tensor<1x4x2x3xi32>) -> tensor<8x3xi32> // CHECK: %[[VAL_13:.*]] = "tosa.const"() <{value = dense<[12, 3, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> -// CHECK: %[[VAL_14:.*]] = tosa.mul %[[VAL_12]], %[[VAL_13]] {shift = 0 : i32} : (tensor<8x3xi32>, tensor<3xi32>) -> tensor<8x3xi32> +// CHECK: %[[VAL_14:.*]] = tosa.mul %[[VAL_12]], %[[VAL_13]] {shift = 0 : i8} : (tensor<8x3xi32>, tensor<3xi32>) -> tensor<8x3xi32> // CHECK: %[[VAL_15:.*]] = tosa.reduce_sum %[[VAL_14]] {axis = 1 : i32} : (tensor<8x3xi32>) -> tensor<8x1xi32> // CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_15]] {new_shape = array} : (tensor<8x1xi32>) -> tensor<1x8xi32> // CHECK: %[[VAL_17:.*]] = tosa.gather %[[VAL_11]], %[[VAL_16]] : (tensor<1x12x1xf32>, tensor<1x8xi32>) -> tensor<1x8x1xf32> @@ -997,7 +1021,7 @@ func.func @torch.aten.gather(%arg0: !torch.vtensor<[1,4,3],f32>, %arg1: !torch.v // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2,2],si32> -> tensor<2x2xi32> // CHECK: %[[VAL_4:.*]] = torch.constant.int 1 // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<1> : tensor}> : () -> tensor -// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i32} : (tensor<2x2xi32>, tensor) -> tensor<2x2xi32> +// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i8} : (tensor<2x2xi32>, tensor) -> tensor<2x2xi32> // CHECK: %[[VAL_7:.*]] = tosa.add %[[VAL_2]], %[[VAL_6]] : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> // CHECK: %[[VAL_8:.*]] = tosa.cast %[[VAL_7]] : (tensor<2x2xi32>) -> tensor<2x2xi64> // CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<2x2xi64> -> !torch.vtensor<[2,2],si64> @@ -1017,7 +1041,7 @@ func.func @torch.aten.add$basic(%arg0: !torch.vtensor<[2, 2],si32>, %arg1: !torc // CHECK: %[[VAL_3:.*]] = torch.constant.int 256 // CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<256> : tensor}> : () -> tensor // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<1> : tensor}> : () -> tensor -// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_4]], %[[VAL_5]] {shift = 0 : i32} : (tensor, tensor) -> tensor +// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_4]], %[[VAL_5]] {shift = 0 : i8} : (tensor, tensor) -> tensor // CHECK: %[[VAL_7:.*]] = tosa.cast %[[VAL_1]] : (tensor<1x1x128x128xi64>) -> tensor<1x1x128x128xi32> // CHECK: %[[VAL_8:.*]] = tosa.add %[[VAL_7]], %[[VAL_6]] : (tensor<1x1x128x128xi32>, tensor) -> tensor<1x1x128x128xi32> // CHECK: %[[VAL_9:.*]] = tosa.cast %[[VAL_8]] : (tensor<1x1x128x128xi32>) -> tensor<1x1x128x128xi64> @@ -1048,6 +1072,23 @@ func.func @torch.aten.clamp(%arg0: !torch.vtensor<[1,1,128,128],si64>) -> !torch return %0 : !torch.vtensor<[1,1,128,128],si64> } +// ----- +// CHECK-LABEL: func.func @torch.aten.clamp.float( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,128,128],f32>) -> !torch.vtensor<[1,1,128,128],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1,128,128],f32> -> tensor<1x1x128x128xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.float 3.123400e+00 +// CHECK: %[[VAL_3:.*]] = torch.constant.float 6.432100e+00 +// CHECK: %[[VAL_4:.*]] = tosa.clamp %[[VAL_1]] {max_fp = 6.432100e+00 : f32, max_int = 6 : i64, min_fp = 3.123400e+00 : f32, min_int = 3 : i64} : 
(tensor<1x1x128x128xf32>) -> tensor<1x1x128x128xf32> +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<1x1x128x128xf32> -> !torch.vtensor<[1,1,128,128],f32> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[1,1,128,128],f32> +// CHECK: } +func.func @torch.aten.clamp.float(%arg0: !torch.vtensor<[1,1,128,128],f32>) -> !torch.vtensor<[1,1,128,128],f32> { + %fp_min = torch.constant.float 3.123400e+00 + %fp_max = torch.constant.float 6.432100e+00 + %0 = torch.aten.clamp %arg0, %fp_min, %fp_max : !torch.vtensor<[1,1,128,128],f32>, !torch.float, !torch.float -> !torch.vtensor<[1,1,128,128],f32> + return %0 : !torch.vtensor<[1,1,128,128],f32> +} + // ----- // CHECK-LABEL: func.func @torch.aten.masked_fill.Scalar( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,12,128,128],f32>, @@ -1116,20 +1157,49 @@ func.func @torch.aten.where.self(%arg0: !torch.vtensor<[1,1,5,5],i1>, %arg1: !to // ----- // CHECK-LABEL: func.func @torch.aten.remainder.Scalar( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,4],f32>) -> !torch.vtensor<[2,4],f32> { -// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,4],f32> -> tensor<2x4xf32> -// CHECK: %[[VAL_4:.*]] = torch.constant.int 2 -// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_6:.*]] = tosa.reciprocal %[[VAL_5:.*]] : (tensor) -> tensor -// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_3:.*]], %[[VAL_6:.*]] {shift = 0 : i32} : (tensor<2x4xf32>, tensor) -> tensor<2x4xf32> -// CHECK: %[[VAL_8:.*]] = tosa.floor %[[VAL_7]] : (tensor<2x4xf32>) -> tensor<2x4xf32> -// CHECK: %[[VAL_9:.*]] = tosa.mul %[[VAL_5]], %[[VAL_8]] {shift = 0 : i32} : (tensor, tensor<2x4xf32>) -> tensor<2x4xf32> -// CHECK: %[[VAL_10:.*]] = tosa.sub %[[VAL_3]], %[[VAL_9]] : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> -// CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_10]] : tensor<2x4xf32> -> !torch.vtensor<[2,4],f32> -// CHECK: return %[[VAL_11]] : !torch.vtensor<[2,4],f32> +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,4],f32>) -> !torch.vtensor<[2,4],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,4],f32> -> tensor<2x4xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 2 +// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor}> : () -> tensor +// CHECK: %[[VAL_4:.*]] = tosa.reciprocal %[[VAL_3]] : (tensor) -> tensor +// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_1]], %[[VAL_4]] {shift = 0 : i8} : (tensor<2x4xf32>, tensor) -> tensor<2x4xf32> +// CHECK: %[[VAL_6:.*]] = tosa.floor %[[VAL_5]] : (tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_3]], %[[VAL_6]] {shift = 0 : i8} : (tensor, tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_8:.*]] = tosa.sub %[[VAL_1]], %[[VAL_7]] : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<2x4xf32> -> !torch.vtensor<[2,4],f32> +// CHECK: return %[[VAL_9]] : !torch.vtensor<[2,4],f32> // CHECK: } func.func @torch.aten.remainder.Scalar(%arg0: !torch.vtensor<[2, 4],f32>) -> !torch.vtensor<[2, 4],f32> { %int2 = torch.constant.int 2 %0 = torch.aten.remainder.Scalar %arg0, %int2 : !torch.vtensor<[2, 4],f32>, !torch.int -> !torch.vtensor<[2, 4],f32> return %0 : !torch.vtensor<[2, 4],f32> } + +// ----- + +// CHECK-LABEL: func.func @forward( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,5],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[5,5],f32>) -> 
!torch.vtensor<[5,5],i1> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[5,5],f32> -> tensor<5x5xf32> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[5,5],f32> -> tensor<5x5xf32> +// CHECK: %[[VAL_4:.*]] = torch.constant.float 1.000000e-08 +// CHECK: %[[VAL_5:.*]] = torch.constant.float 1.000000e-05 +// CHECK: %[[VAL_6:.*]] = torch.constant.bool false +// CHECK: %[[VAL_7:.*]] = tosa.sub %[[VAL_2]], %[[VAL_3]] : (tensor<5x5xf32>, tensor<5x5xf32>) -> tensor<5x5xf32> +// CHECK: %[[VAL_8:.*]] = tosa.abs %[[VAL_7]] : (tensor<5x5xf32>) -> tensor<5x5xf32> +// CHECK: %[[VAL_9:.*]] = tosa.abs %[[VAL_3]] : (tensor<5x5xf32>) -> tensor<5x5xf32> +// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<9.99999974E-6> : tensor}> : () -> tensor +// CHECK: %[[VAL_11:.*]] = tosa.mul %[[VAL_10]], %[[VAL_9]] {shift = 0 : i8} : (tensor, tensor<5x5xf32>) -> tensor<5x5xf32> +// CHECK: %[[VAL_12:.*]] = "tosa.const"() <{value = dense<9.99999993E-9> : tensor}> : () -> tensor +// CHECK: %[[VAL_13:.*]] = tosa.add %[[VAL_12]], %[[VAL_11]] : (tensor, tensor<5x5xf32>) -> tensor<5x5xf32> +// CHECK: %[[VAL_14:.*]] = tosa.greater_equal %[[VAL_13]], %[[VAL_8]] : (tensor<5x5xf32>, tensor<5x5xf32>) -> tensor<5x5xi1> +// CHECK: %[[VAL_15:.*]] = torch_c.from_builtin_tensor %[[VAL_14]] : tensor<5x5xi1> -> !torch.vtensor<[5,5],i1> +// CHECK: return %[[VAL_15]] : !torch.vtensor<[5,5],i1> +// CHECK: } +func.func @forward(%arg0: !torch.vtensor<[5,5],f32>, %arg1: !torch.vtensor<[5,5],f32>) -> !torch.vtensor<[5,5],i1> { + %float1.000000e-08 = torch.constant.float 1.000000e-08 + %float1.000000e-05 = torch.constant.float 1.000000e-05 + %false = torch.constant.bool false + %0 = torch.aten.isclose %arg0, %arg1, %float1.000000e-05, %float1.000000e-08, %false : !torch.vtensor<[5,5],f32>, !torch.vtensor<[5,5],f32>, !torch.float, !torch.float, !torch.bool -> !torch.vtensor<[5,5],i1> + return %0 : !torch.vtensor<[5,5],i1> +} diff --git a/test/Conversion/TorchToTosa/torch-backend-to-tosa-backend-pipeline.mlir b/test/Conversion/TorchToTosa/torch-backend-to-tosa-backend-pipeline.mlir index c1d1d915b017..5813cd4351ac 100644 --- a/test/Conversion/TorchToTosa/torch-backend-to-tosa-backend-pipeline.mlir +++ b/test/Conversion/TorchToTosa/torch-backend-to-tosa-backend-pipeline.mlir @@ -1,9 +1,11 @@ // RUN: torch-mlir-opt -pass-pipeline='builtin.module(torch-backend-to-tosa-backend-pipeline)' -split-input-file -verify-diagnostics %s | FileCheck %s -// CHECK-LABEL: torch.aten.mul.Scalar$mixed_type -// CHECK-SAME: %[[VAL_0:.*]]: tensor<5xbf16> -// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor<1xbf16>}> : () -> tensor<1xbf16> -// CHECK: %[[VAL_2:.*]] = tosa.mul %[[VAL_0]], %[[VAL_1]] {shift = 0 : i32} : (tensor<5xbf16>, tensor<1xbf16>) -> tensor<5xbf16> +// CHECK-LABEL: func.func @torch.aten.mul.Scalar$mixed_type( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<5xbf16>) -> tensor<5xbf16> { +// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor<1xbf16>}> : () -> tensor<1xbf16> +// CHECK: %[[VAL_2:.*]] = tosa.mul %[[VAL_0]], %[[VAL_1]] {shift = 0 : i8} : (tensor<5xbf16>, tensor<1xbf16>) -> tensor<5xbf16> +// CHECK: return %[[VAL_2]] : tensor<5xbf16> +// CHECK: } func.func @torch.aten.mul.Scalar$mixed_type(%arg0: !torch.vtensor<[5],bf16>) -> !torch.vtensor<[5],bf16> { %float2.000000e00 = torch.constant.float 2.000000e+00 %0 = torch.aten.mul.Scalar %arg0, %float2.000000e00 : !torch.vtensor<[5],bf16>, !torch.float -> !torch.vtensor<[5],bf16> @@ 
-88,12 +90,14 @@ func.func @torch.aten.bitwise_and.Tensor$mixed_type(%arg0: !torch.vtensor<[?,?], // ----- -// CHECK-LABEL: torch.aten.div.Tensor$mixed_type_fp -// CHECK-SAME: %[[VAL_0:.*]]: tensor, -// CHECK-SAME: %[[VAL_1:.*]]: tensor -// CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_1]] : (tensor) -> tensor -// CHECK: %[[VAL_3:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor) -> tensor -// CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_0]], %[[VAL_3]] {shift = 0 : i32} : (tensor, tensor) -> tensor +// CHECK-LABEL: func.func @torch.aten.div.Tensor$mixed_type_fp( +// CHECK-SAME: %[[VAL_0:.*]]: tensor, +// CHECK-SAME: %[[VAL_1:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_1]] : (tensor) -> tensor +// CHECK: %[[VAL_3:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor) -> tensor +// CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_0]], %[[VAL_3]] {shift = 0 : i8} : (tensor, tensor) -> tensor +// CHECK: return %[[VAL_4]] : tensor +// CHECK: } func.func @torch.aten.div.Tensor$mixed_type_fp(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],si32>) -> !torch.vtensor<[?, ?],f32> { %0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],si32> -> !torch.vtensor<[?, ?],f32> return %0 : !torch.vtensor<[?, ?],f32> diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index 21e0500f4eb5..b66bd24e1bb3 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -1983,6 +1983,15 @@ func.func @torch.aten.broadcast_to$fold(%arg0: !torch.vtensor<[3,4,2],f32>) -> ! return %0 : !torch.vtensor<[3,4,2],f32> } +// CHECK-LABEL: func.func @torch.aten.broadcast_to_strict$fold( +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?],f32>, {{.*}}) -> !torch.vtensor<[?],f32> +// CHECK-NEXT: return %[[ARG]] : !torch.vtensor<[?],f32> +func.func @torch.aten.broadcast_to_strict$fold(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.int) -> !torch.vtensor<[?],f32> attributes {torch.assume_strict_symbolic_shapes} { + %list = torch.prim.ListConstruct %arg1 : (!torch.int) -> !torch.list + %0 = torch.aten.broadcast_to %arg0, %list : !torch.vtensor<[?],f32>, !torch.list -> !torch.vtensor<[?],f32> + return %0 : !torch.vtensor<[?],f32> +} + // CHECK-LABEL: @torch.aten.slice.tensor$fold_full_domain_slice // CHECK-SAME: %[[ARG0:.+]]: !torch.vtensor<[4],f32> // CHECK: return %[[ARG0]] : !torch.vtensor<[4],f32> diff --git a/torchvision-requirements.txt b/torchvision-requirements.txt index 3de56bb10b07..4a54d8b07666 100644 --- a/torchvision-requirements.txt +++ b/torchvision-requirements.txt @@ -1,3 +1,3 @@ -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre -torchvision==0.17.0.dev20230913 +torchvision==0.17.0.dev20231006
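
As a quick sanity check of the newly registered `aten.unflatten.int` lowering, something like the following can be run against the nightly packages referenced above. This is only an illustrative sketch, not part of the patch: it assumes the `torch_mlir.compile` entry point and `OutputType.TOSA` from the `torch-mlir` Python package, and the module and tensor shapes simply mirror the `UnflattenStaticModule` e2e test added here.

import torch
import torch_mlir


class UnflattenExample(torch.nn.Module):
    # Mirrors UnflattenStaticModule: split dim 1 (size 6) into (2, 3).
    def forward(self, x):
        return torch.ops.aten.unflatten(x, 1, (2, 3))


# Lower through the TOSA backend; the resulting IR is expected to contain the
# tosa.reshape that the new TorchToTosa FileCheck test above verifies.
compiled = torch_mlir.compile(
    UnflattenExample(),
    torch.ones(1, 6, 4),
    output_type=torch_mlir.OutputType.TOSA,
)
print(compiled)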