From 7a7be60dcfb5dbc4b8baf2b91726c879269cda2c Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Thu, 14 Sep 2023 10:23:44 -0700 Subject: [PATCH 01/41] Fix python package install instructions (#2464) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 1d6d448cbdaf..c5fa561bcd15 100644 --- a/README.md +++ b/README.md @@ -61,7 +61,7 @@ python -m pip install --upgrade pip Then, we can install torch-mlir with the corresponding torch and torchvision nightlies. ``` pip install --pre torch-mlir torchvision \ - -f https://llvm.github.io/torch-mlir/package-index/ + -f https://llvm.github.io/torch-mlir/package-index/ \ --extra-index-url https://download.pytorch.org/whl/nightly/cpu ``` From b03efdf2e47b0effbf66d84d9ed75cc90ab7ee2d Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Tue, 19 Sep 2023 13:18:27 +0000 Subject: [PATCH 02/41] build: manually update PyTorch version Set PyTorch and TorchVision version to nightly release 2023-09-18. Signed-Off By: Vivek Khandelwal --- pytorch-hash.txt | 2 +- pytorch-requirements.txt | 2 +- torchvision-requirements.txt | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch-hash.txt b/pytorch-hash.txt index b45361a5173e..f665e35e6884 100644 --- a/pytorch-hash.txt +++ b/pytorch-hash.txt @@ -1 +1 @@ -2baa4c49288efeded2fad677b2f28570b0ce858b +ba087c0903f1c59eb993614f46602d39fcec2dfd diff --git a/pytorch-requirements.txt b/pytorch-requirements.txt index 76e15ba4dc62..0bf75dc90a30 100644 --- a/pytorch-requirements.txt +++ b/pytorch-requirements.txt @@ -1,3 +1,3 @@ -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre -torch==2.2.0.dev20230913 +torch==2.2.0.dev20230918 diff --git a/torchvision-requirements.txt b/torchvision-requirements.txt index 3de56bb10b07..d09ae0e0eaf0 100644 --- a/torchvision-requirements.txt +++ b/torchvision-requirements.txt @@ -1,3 +1,3 @@ -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre -torchvision==0.17.0.dev20230913 +torchvision==0.17.0.dev20230918 From 278c41e9388c6b20a92c8a7dc735578424b395a1 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Tue, 19 Sep 2023 10:50:53 -0700 Subject: [PATCH 03/41] Bump llvm-project to f66cd9e9556a53142a26a5c21a72e21f1579217c. (#2466) Picks up DenseResourceElementsAttr python support and fixes minf/maxf C++ rename. 
--- .gitmodules | 2 +- .../lib/Dialect/TMTensor/IR/TMTensorOps.cpp | 2 +- externals/llvm-project | 2 +- lib/Conversion/TorchToLinalg/Reduction.cpp | 8 ++++---- lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp | 4 ++-- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/.gitmodules b/.gitmodules index 8b46098d9615..3c50187b65a4 100644 --- a/.gitmodules +++ b/.gitmodules @@ -3,4 +3,4 @@ url = https://github.com/llvm/llvm-project.git [submodule "externals/stablehlo"] path = externals/stablehlo - url = https://github.com/openxla/stablehlo.git + url = https://github.com/shark-infra/stablehlo.git diff --git a/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/IR/TMTensorOps.cpp b/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/IR/TMTensorOps.cpp index ba7ed76c81cf..dcb2f4215891 100644 --- a/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/IR/TMTensorOps.cpp +++ b/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/IR/TMTensorOps.cpp @@ -233,7 +233,7 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b, loc, init, [&](OpBuilder &b, Location loc, Value elem, Value acc) { Value x = b.create(loc, weight, localIVs); - Value max = b.create(loc, x, acc); + Value max = b.create(loc, x, acc); b.create(loc, max); }); }) diff --git a/externals/llvm-project b/externals/llvm-project index 4acc3ffbb0af..f66cd9e9556a 160000 --- a/externals/llvm-project +++ b/externals/llvm-project @@ -1 +1 @@ -Subproject commit 4acc3ffbb0af5631bc7916aeff3570f448899647 +Subproject commit f66cd9e9556a53142a26a5c21a72e21f1579217c diff --git a/lib/Conversion/TorchToLinalg/Reduction.cpp b/lib/Conversion/TorchToLinalg/Reduction.cpp index 4078fbaa342c..641f1ef8cc1c 100644 --- a/lib/Conversion/TorchToLinalg/Reduction.cpp +++ b/lib/Conversion/TorchToLinalg/Reduction.cpp @@ -176,8 +176,8 @@ class ConvertAtenMaxDimOp : public OpConversionPattern { Value resultMax, predicate; if (inElementType.isa()) { - resultMax = - rewriter.create(nestedLoc, newValue, oldValue); + resultMax = rewriter.create(nestedLoc, newValue, + oldValue); predicate = rewriter.create( nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue); } else { @@ -280,7 +280,7 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc, convertScalarToDtype(b, loc, payloadArgs[0], resultElementType); Value result = payloadArgs[1]; if (resultElementType.isa()) - return b.create(loc, self, result); + return b.create(loc, self, result); else if (resultElementType.isa()) { IntegerType intType = max.getSelf() .getType() @@ -297,7 +297,7 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc, convertScalarToDtype(b, loc, payloadArgs[0], resultElementType); Value result = payloadArgs[1]; if (resultElementType.isa()) - return b.create(loc, self, result); + return b.create(loc, self, result); else if (resultElementType.isa()) { IntegerType intType = min.getSelf() .getType() diff --git a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp index ff27287649ad..d11a5524af7d 100644 --- a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp +++ b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp @@ -1332,7 +1332,7 @@ class ConvertAtenScatterReduceTwoOp if (update.getType().isa()) { result = b.create(loc, update, current); } else if (update.getType().isa()) { - result = b.create(loc, update, current); + result = b.create(loc, update, current); } else { llvm_unreachable("Only 
integer/float types supported!"); } @@ -1340,7 +1340,7 @@ class ConvertAtenScatterReduceTwoOp if (update.getType().isa()) { result = b.create(loc, update, current); } else if (update.getType().isa()) { - result = b.create(loc, update, current); + result = b.create(loc, update, current); } else { llvm_unreachable("Only integer/float types supported!"); } From 20ea1c9e9159483cd14ca8141c4968845a23dea8 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Tue, 19 Sep 2023 23:05:52 -0700 Subject: [PATCH 04/41] Revert accidental change to submodule origin. (#2477) --- .gitmodules | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitmodules b/.gitmodules index 3c50187b65a4..8b46098d9615 100644 --- a/.gitmodules +++ b/.gitmodules @@ -3,4 +3,4 @@ url = https://github.com/llvm/llvm-project.git [submodule "externals/stablehlo"] path = externals/stablehlo - url = https://github.com/shark-infra/stablehlo.git + url = https://github.com/openxla/stablehlo.git From 023fc9007234c5b28a1fb28076e24ae44ae77c86 Mon Sep 17 00:00:00 2001 From: David Gens Date: Wed, 20 Sep 2023 10:47:08 -0700 Subject: [PATCH 05/41] [Torch Dialect] add avg_pool 2d and 3d op variants (#2473) Adds ODS for `avg_pool2d` and `avg_pool3d`, including their backward and `adaptive_` variants. --- e2e_testing/xfail_sets.py | 4 - .../Dialect/Torch/IR/GeneratedTorchOps.td | 231 +++++++++++++++++- .../jit_ir/build_tools/torch_ods_gen.py | 18 +- 3 files changed, 236 insertions(+), 17 deletions(-) diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py index eb2cb68b30d1..2257e83ff339 100644 --- a/e2e_testing/xfail_sets.py +++ b/e2e_testing/xfail_sets.py @@ -1317,10 +1317,6 @@ "_ConvolutionDeprecated2DBenchmarkModule_basic", "_ConvolutionDeprecated2DCudnnModule_basic", "_ConvolutionDeprecated2DDeterministicModule_basic", - "AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic", - "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic", - "AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic", - "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", "AddIntModule_basic", "AtenIntBoolOpModule_basic", "BernoulliTensorModule_basic", diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index f0d0a238a129..efe896410e39 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -5563,6 +5563,34 @@ def Torch_AtenMaxPool3dWithIndicesBackwardOp : Torch_Op<"aten.max_pool3d_with_in }]; } +def Torch_AtenAvgPool1dOp : Torch_Op<"aten.avg_pool1d", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::avg_pool1d : (Tensor, int[], int[], int[], bool, bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$kernel_size, + AnyTorchListOfTorchIntType:$stride, + AnyTorchListOfTorchIntType:$padding, + Torch_BoolType:$ceil_mode, + Torch_BoolType:$count_include_pad + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAvgPool1dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 6, 1); + } + void AtenAvgPool1dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 6, 1); + } + }]; +} + def Torch_AtenAvgPool2dOp : Torch_Op<"aten.avg_pool2d", [ AllowsTypeRefinement, HasValueSemantics, @@ -5592,30 +5620,91 @@ def Torch_AtenAvgPool2dOp : 
Torch_Op<"aten.avg_pool2d", [ }]; } -def Torch_AtenAvgPool1dOp : Torch_Op<"aten.avg_pool1d", [ +def Torch_AtenAvgPool2dBackwardOp : Torch_Op<"aten.avg_pool2d_backward", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::avg_pool1d : (Tensor, int[], int[], int[], bool, bool) -> (Tensor)`"; + let summary = "Generated op for `aten::avg_pool2d_backward : (Tensor, Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)`"; let arguments = (ins + AnyTorchTensorType:$grad_output, AnyTorchTensorType:$self, AnyTorchListOfTorchIntType:$kernel_size, AnyTorchListOfTorchIntType:$stride, AnyTorchListOfTorchIntType:$padding, Torch_BoolType:$ceil_mode, - Torch_BoolType:$count_include_pad + Torch_BoolType:$count_include_pad, + AnyTorchOptionalIntType:$divisor_override ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenAvgPool1dOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 6, 1); + ParseResult AtenAvgPool2dBackwardOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 8, 1); } - void AtenAvgPool1dOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 6, 1); + void AtenAvgPool2dBackwardOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 8, 1); + } + }]; +} + +def Torch_AtenAvgPool3dOp : Torch_Op<"aten.avg_pool3d", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::avg_pool3d : (Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$kernel_size, + AnyTorchListOfTorchIntType:$stride, + AnyTorchListOfTorchIntType:$padding, + Torch_BoolType:$ceil_mode, + Torch_BoolType:$count_include_pad, + AnyTorchOptionalIntType:$divisor_override + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAvgPool3dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 7, 1); + } + void AtenAvgPool3dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 7, 1); + } + }]; +} + +def Torch_AtenAvgPool3dBackwardOp : Torch_Op<"aten.avg_pool3d_backward", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::avg_pool3d_backward : (Tensor, Tensor, int[], int[], int[], bool, bool, int?) 
-> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$grad_output, + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$kernel_size, + AnyTorchListOfTorchIntType:$stride, + AnyTorchListOfTorchIntType:$padding, + Torch_BoolType:$ceil_mode, + Torch_BoolType:$count_include_pad, + AnyTorchOptionalIntType:$divisor_override + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAvgPool3dBackwardOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 8, 1); + } + void AtenAvgPool3dBackwardOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 8, 1); } }]; } @@ -5846,6 +5935,30 @@ def Torch_AtenMaskedScatter_Op : Torch_Op<"aten.masked_scatter_", [ }]; } +def Torch_AtenAdaptiveAvgPool1dOp : Torch_Op<"aten.adaptive_avg_pool1d", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::adaptive_avg_pool1d : (Tensor, int[]) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$output_size + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAdaptiveAvgPool1dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenAdaptiveAvgPool1dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_AtenAdaptiveAvgPool2dOp : Torch_Op<"aten.adaptive_avg_pool2d", [ AllowsTypeRefinement, HasValueSemantics, @@ -5870,12 +5983,12 @@ def Torch_AtenAdaptiveAvgPool2dOp : Torch_Op<"aten.adaptive_avg_pool2d", [ }]; } -def Torch_AtenAdaptiveAvgPool1dOp : Torch_Op<"aten.adaptive_avg_pool1d", [ +def Torch_Aten_AdaptiveAvgPool2dOp : Torch_Op<"aten._adaptive_avg_pool2d", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::adaptive_avg_pool1d : (Tensor, int[]) -> (Tensor)`"; + let summary = "Generated op for `aten::_adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, AnyTorchListOfTorchIntType:$output_size @@ -5885,10 +5998,106 @@ def Torch_AtenAdaptiveAvgPool1dOp : Torch_Op<"aten.adaptive_avg_pool1d", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenAdaptiveAvgPool1dOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult Aten_AdaptiveAvgPool2dOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenAdaptiveAvgPool1dOp::print(OpAsmPrinter &printer) { + void Aten_AdaptiveAvgPool2dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_Aten_AdaptiveAvgPool2dBackwardOp : Torch_Op<"aten._adaptive_avg_pool2d_backward", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::_adaptive_avg_pool2d_backward : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$grad_output, + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult Aten_AdaptiveAvgPool2dBackwardOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void Aten_AdaptiveAvgPool2dBackwardOp::print(OpAsmPrinter &printer) { + 
printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenAdaptiveAvgPool3dOp : Torch_Op<"aten.adaptive_avg_pool3d", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::adaptive_avg_pool3d : (Tensor, int[]) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$output_size + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAdaptiveAvgPool3dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenAdaptiveAvgPool3dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_Aten_AdaptiveAvgPool3dOp : Torch_Op<"aten._adaptive_avg_pool3d", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::_adaptive_avg_pool3d : (Tensor, int[]) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$output_size + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult Aten_AdaptiveAvgPool3dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void Aten_AdaptiveAvgPool3dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_Aten_AdaptiveAvgPool3dBackwardOp : Torch_Op<"aten._adaptive_avg_pool3d_backward", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::_adaptive_avg_pool3d_backward : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$grad_output, + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult Aten_AdaptiveAvgPool3dBackwardOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void Aten_AdaptiveAvgPool3dBackwardOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index 9945db7e0768..a7d5bbba43e9 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -426,11 +426,20 @@ def emit_with_mutating_variants(key, **kwargs): emit( "aten::max_pool3d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)" ) + emit( + "aten::avg_pool1d : (Tensor, int[], int[], int[], bool, bool) -> (Tensor)" + ) emit( "aten::avg_pool2d : (Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)" ) emit( - "aten::avg_pool1d : (Tensor, int[], int[], int[], bool, bool) -> (Tensor)" + "aten::avg_pool2d_backward : (Tensor, Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)" + ) + emit( + "aten::avg_pool3d : (Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)" + ) + emit( + "aten::avg_pool3d_backward : (Tensor, Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)" ) emit( "aten::softmax.int : (Tensor, int, int?) 
-> (Tensor)" @@ -444,8 +453,13 @@ def emit_with_mutating_variants(key, **kwargs): emit_with_mutating_variants("aten::scatter.src : (Tensor, int, Tensor, Tensor) -> (Tensor)") emit_with_mutating_variants("aten::scatter.value : (Tensor, int, Tensor, Scalar) -> (Tensor)") emit_with_mutating_variants("aten::masked_scatter : (Tensor, Tensor, Tensor) -> (Tensor)") - emit("aten::adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)") emit("aten::adaptive_avg_pool1d : (Tensor, int[]) -> (Tensor)") + emit("aten::adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)") + emit("aten::_adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)") + emit("aten::_adaptive_avg_pool2d_backward : (Tensor, Tensor) -> (Tensor)") + emit("aten::adaptive_avg_pool3d : (Tensor, int[]) -> (Tensor)") + emit("aten::_adaptive_avg_pool3d : (Tensor, int[]) -> (Tensor)") + emit("aten::_adaptive_avg_pool3d_backward : (Tensor, Tensor) -> (Tensor)") emit("aten::topk : (Tensor, int, int, bool, bool) -> (Tensor, Tensor)") emit("aten::transpose.int : (Tensor, int, int) -> (Tensor)") emit("aten::permute : (Tensor, int[]) -> (Tensor)") From b9847b19043a70a7896b81f2fcd7d0251a71348f Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Wed, 20 Sep 2023 10:48:40 -0700 Subject: [PATCH 06/41] Fixing implicit double to float casts. (#2476) MSVC (and other compilers with implicit narrowing warnings) don't like this type mismatch. --- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 75 +++++++++++++--------- 1 file changed, 43 insertions(+), 32 deletions(-) diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 31e8292452a9..1e71f51b8598 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -121,8 +121,8 @@ static bool isInValidRange(bool isFloat, const double &doubleValue, bool isInt, return (doubleValue == static_cast(static_cast(doubleValue))); } else { assert(isInt); - return (intValue >= std::numeric_limits::min()) && - (intValue <= std::numeric_limits::max()); + return (intValue >= static_cast(std::numeric_limits::min())) && + (intValue <= static_cast(std::numeric_limits::max())); } return true; } @@ -145,8 +145,9 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter, "Unable to extract the scalar constant"); if (dtype.isa()) { - tosaTensor = tosa::getConstTensor( - rewriter, op, (isFloat ? doubleValue : intValue), dshape, dtype) + tosaTensor = tosa::getConstTensor(rewriter, op, + (isFloat ? doubleValue : intValue), + dshape, dtype) .value(); } else if (auto intType = dtype.dyn_cast()) { auto w = intType.getWidth(); @@ -162,7 +163,7 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter, "of destination type"); } bool d = isFloat ? 
static_cast(doubleValue) - : static_cast(intValue); + : static_cast(intValue); tosaTensor = tosa::getConstTensor(rewriter, op, {d}, dshape).value(); } else if (w == 32) { @@ -616,7 +617,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op, "Negative slope needs to be a scalar constant for conversion to " "TOSA LeakyReLU operation"); - auto zero = tosa::getConstTensor(rewriter, op, 0, {}, selfTy.getElementType()).value(); + auto zero = + tosa::getConstTensor(rewriter, op, 0, {}, selfTy.getElementType()) + .value(); auto cond = rewriter.create( op->getLoc(), RankedTensorType::get(selfTy.getShape(), rewriter.getIntegerType(1)), @@ -2253,11 +2256,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (!matchPattern(op.getEps(), m_TorchConstantFloat(&eps))) return rewriter.notifyMatchFailure(op, "eps must be a scalar constant"); - auto epsilonConst = - tosa::getConstTensor(rewriter, op.getOperation(), - {static_cast(eps)}, {}, - meanType.getElementType()) - .value(); + auto epsilonConst = tosa::getConstTensor(rewriter, op.getOperation(), + {static_cast(eps)}, {}, + meanType.getElementType()) + .value(); auto batchNorm = computeBatchNorm(op, rewriter, outType, adaptor.getInput(), varianceVal, @@ -2571,7 +2573,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Constant value of ln2. SmallVector ln2Shape(selfType.getRank(), 1); - auto ln2Op = tosa::getConstTensor(rewriter, op, {0.69314718056}, + auto ln2Op = tosa::getConstTensor(rewriter, op, {0.69314718056f}, ln2Shape, selfType.getElementType()) .value(); auto rcpOp = @@ -2802,21 +2804,25 @@ static Value approximateErfOp(ConversionPatternRewriter &rewriter, auto zero = tosa::getConstTensor(rewriter, op, 0, {}, dtype).value(); auto one = tosa::getConstTensor(rewriter, op, 1, {}, dtype).value(); - auto a1 = tosa::getConstTensor(rewriter, op, 0.278393, {}, dtype).value(); + auto a1 = + tosa::getConstTensor(rewriter, op, 0.278393f, {}, dtype).value(); auto a1X = rewriter.create(loc, outType, a1, absX, /*shift=*/0); auto sum = rewriter.create(loc, outType, a1X, one); - auto a2 = tosa::getConstTensor(rewriter, op, 0.230389, {}, dtype).value(); + auto a2 = + tosa::getConstTensor(rewriter, op, 0.230389f, {}, dtype).value(); auto x2 = rewriter.create(loc, outType, absX, absX, /*shift=*/0); auto a2X = rewriter.create(loc, outType, a2, x2, /*shift=*/0); sum = rewriter.create(loc, outType, sum, a2X); - auto a3 = tosa::getConstTensor(rewriter, op, 0.000972, {}, dtype).value(); + auto a3 = + tosa::getConstTensor(rewriter, op, 0.000972f, {}, dtype).value(); auto x3 = rewriter.create(loc, outType, x2, absX, /*shift=*/0); auto a3X = rewriter.create(loc, outType, a3, x3, /*shift=*/0); sum = rewriter.create(loc, outType, sum, a3X); - auto a4 = tosa::getConstTensor(rewriter, op, 0.078108, {}, dtype).value(); + auto a4 = + tosa::getConstTensor(rewriter, op, 0.078108f, {}, dtype).value(); auto x4 = rewriter.create(loc, outType, x3, absX, /*shift=*/0); auto a4X = rewriter.create(loc, outType, a4, x4, /*shift=*/0); sum = rewriter.create(loc, outType, sum, a4X); @@ -2851,13 +2857,14 @@ static Value buildUnitNormalCdf(ConversionPatternRewriter &rewriter, Value xMinusMean = rewriter.create(loc, outType, x, mean); // rsqrt of 2 Value rsqrt2 = - tosa::getConstTensor(rewriter, op, 0.70710678, {}, dtype).value(); + tosa::getConstTensor(rewriter, op, 0.70710678f, {}, dtype).value(); Value erfArg = rewriter.create(loc, outType, xMinusMean, rsqrt2, /*shift=*/0); Value erf = approximateErfOp(rewriter, op, erfArg, dtype); Value erfPlus1 = rewriter.create(loc, outType, one, 
erf); - Value oneHalf = tosa::getConstTensor(rewriter, op, 0.5, {}, dtype).value(); + Value oneHalf = + tosa::getConstTensor(rewriter, op, 0.5, {}, dtype).value(); Value normalCdf = rewriter.create(loc, outType, oneHalf, erfPlus1, /*shift=*/0); @@ -2891,7 +2898,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Value cdf = buildUnitNormalCdf(rewriter, op, adaptor.getSelf(), selfElemTy); cdf = rewriter.createOrFold( - op->getLoc(), cast(cdf.getType()).cloneWith({}, selfElemTy), cdf); + op->getLoc(), + cast(cdf.getType()).cloneWith({}, selfElemTy), cdf); rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(), cdf, @@ -2927,15 +2935,16 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto loc = op->getLoc(); - const double cstAlpha0 = 1.12837916709551257390; - const double cstAlpha1 = 0.70710678118654752440; - const double oneHalf = 0.5; - const double kAlpha = cstAlpha0 * cstAlpha1; + const float cstAlpha0 = 1.12837916709551257390f; + const float cstAlpha1 = 0.70710678118654752440f; + const float oneHalf = 0.5f; + const float kAlpha = cstAlpha0 * cstAlpha1; - Value kAlphaHalf = - tosa::getConstTensor(rewriter, op, kAlpha * oneHalf, {}, selfElemTy).value(); + Value kAlphaHalf = tosa::getConstTensor(rewriter, op, kAlpha * oneHalf, + {}, selfElemTy) + .value(); Value negOneHalf = - tosa::getConstTensor(rewriter, op, -0.5, {}, selfElemTy).value(); + tosa::getConstTensor(rewriter, op, -0.5f, {}, selfElemTy).value(); Value inputSquared = rewriter.create( loc, selfType, adaptor.getSelf(), adaptor.getSelf(), /*shift=*/0); Value negHalfInputSquared = rewriter.create( @@ -3006,7 +3015,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure(op, "Only scalar constant is supported"); } - Value replace = tosa::getConstTensor(rewriter, op, 0, {}, selfElemTy).value(); + Value replace = + tosa::getConstTensor(rewriter, op, 0, {}, selfElemTy).value(); Type outType = getTypeConverter()->convertType(op.getType()); Value lesser = rewriter.create( @@ -3553,7 +3563,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // convert None to [0,0,0] auto indexNext = indexTensors[i + 1]; auto indexNextTorch = tensorsTorchType[i + 1]; - if (indexNextTorch.getType().isa()){ + if (indexNextTorch.getType().isa()) { return rewriter.notifyMatchFailure( op, "Multiple None index is not support for now."); } @@ -3620,12 +3630,13 @@ LogicalResult ConvertAtenOp::matchAndRewrite( indicesTfConcatTensors, lastDim); if (!indicesTf) { - return rewriter.notifyMatchFailure( - op, "Convert TorchIndex To TfIndices fail."); + return rewriter.notifyMatchFailure(op, + "Convert TorchIndex To TfIndices fail."); } - // do the tf scatterNd algorithm with tf style indices as input, algorithm mostly take from convertGatherNdOp. + // do the tf scatterNd algorithm with tf style indices as input, algorithm + // mostly take from convertGatherNdOp. 
auto result = tosa::convertScatterNdOp(rewriter, op, outType, input, - indicesTf.getResult(), fillValues); + indicesTf.getResult(), fillValues); if (!result) { return rewriter.notifyMatchFailure( From 059041e0fe1d81948e87e2ef0928f13052e78a1c Mon Sep 17 00:00:00 2001 From: Gleb Kazantaev Date: Thu, 21 Sep 2023 13:25:14 -0400 Subject: [PATCH 07/41] [LTC] Support torch.ones/zeros/arange ops (#2440) --- build_tools/autogen_ltc_backend.py | 3 +- build_tools/autogen_ltc_backend.yaml | 35 ++---- e2e_testing/xfail_sets.py | 1 - .../Dialect/Torch/IR/GeneratedTorchOps.td | 75 +++++++++++++ .../csrc/base_lazy_backend/CMakeLists.txt | 1 + .../mlir_native_functions.cpp | 101 ++++-------------- .../base_lazy_backend/shape_inference.cpp | 78 ++++++++++++++ .../csrc/base_lazy_backend/tensor.cpp | 29 +++++ .../csrc/base_lazy_backend/tensor.h | 24 +++++ .../reference_lazy_backend/backend_impl.cpp | 25 ++++- .../jit_ir/build_tools/torch_ods_gen.py | 3 + 11 files changed, 262 insertions(+), 113 deletions(-) create mode 100644 python/torch_mlir/csrc/base_lazy_backend/tensor.cpp create mode 100644 python/torch_mlir/csrc/base_lazy_backend/tensor.h diff --git a/build_tools/autogen_ltc_backend.py b/build_tools/autogen_ltc_backend.py index 5af371d56ef9..4444015805bd 100644 --- a/build_tools/autogen_ltc_backend.py +++ b/build_tools/autogen_ltc_backend.py @@ -467,7 +467,8 @@ def gen_fallback_code(*args, **kwargs): node_base="torch::lazy::TorchMlirNode", node_base_hdr=str(self.backend_path.joinpath("mlir_node.h")), tensor_class=self.tensor_class, - tensor_class_hdr="torch/csrc/lazy/core/tensor.h", + tensor_class_hdr="torch_mlir/csrc/base_lazy_backend/tensor.h", + create_aten_from_ltc_tensor="CreateFunctionalizedAtenFromLtcTensor", shape_inference_hdr=str(self.generated_path.joinpath("shape_inference.h")), lazy_ir_generator=GenMlirLazyIr, ) diff --git a/build_tools/autogen_ltc_backend.yaml b/build_tools/autogen_ltc_backend.yaml index d57f693cc433..bfc4641640aa 100644 --- a/build_tools/autogen_ltc_backend.yaml +++ b/build_tools/autogen_ltc_backend.yaml @@ -3,12 +3,6 @@ blacklist: # It also doesn't have confusing `unsafe` argument. 
- _index_put_impl -# Ops with list of tensors output -- split.Tensor -- split_with_sizes -- unbind.int -- chunk - # Additional ops which autogen is supported for but don't compile yet - _convolution - detach @@ -18,42 +12,28 @@ blacklist: # Disabled for consistency with TS backend - lift_fresh_copy -- new_empty - rsub -- slice.Tensor # Disabled in favour of slice_copy.Tensor -- zeros -- ones -- arange -- arange.start -- arange.start_step -- fill.Scalar -- scalar_tensor # Disabled in favour of functionalized alternatives - _reshape_alias -- expand - permute - select.int -- squeeze - squeeze.dim -- t - transpose.int +- expand +- squeeze - unsqueeze - view +- slice.Tensor +- split.Tensor +- split_with_sizes +- unbind.int -whitelist: -# Enabled for consistency with TS backend -- arange.start_out # List of supported ops that we don't want to do the full codegen for supported: -# - bernoulli -# - bernoulli_ - _to_copy - clone -- empty.memory_format -- empty_strided -- fill_.Scalar - _unsafe_view - unbind_copy.int - split_copy.Tensor @@ -80,10 +60,10 @@ supported: - _trilinear - linalg_pinv.atol_rtol_tensor - logsumexp.out +- t # List of ops that will take in symints for the size instead of ints symint: -- empty.memory_format - new_empty_strided - expand_copy - narrow_copy @@ -91,7 +71,6 @@ symint: - slice_copy.Tensor - split_copy.Tensor - slice_scatter -- view - view_copy - as_strided_copy - as_strided_scatter diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py index 2257e83ff339..6446a085ff67 100644 --- a/e2e_testing/xfail_sets.py +++ b/e2e_testing/xfail_sets.py @@ -1384,7 +1384,6 @@ "ConvolutionBackwardModule2DPadded_basic", "VarMeanCorrectionModule_basic", "VarMeanCorrectionNoneModule_basic", - "PrimsConvertElementTypeModule_basic", "ElementwisePreluModule_basic", "VarMeanBiasedModule_basic", "VarMeanUnbiasedModule_basic", diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index efe896410e39..0c96a91b5b49 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -4490,6 +4490,56 @@ def Torch_AtenRandnLikeOp : Torch_Op<"aten.randn_like", [ }]; } +def Torch_AtenRandomOp : Torch_Op<"aten.random", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::random : (Tensor, Generator?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchOptionalGeneratorType:$generator + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenRandomOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenRandomOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenRandomFromOp : Torch_Op<"aten.random.from", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::random.from : (Tensor, int, int?, Generator?) 
-> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$from, + AnyTorchOptionalIntType:$to, + AnyTorchOptionalGeneratorType:$generator + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenRandomFromOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenRandomFromOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + def Torch_AtenTriuOp : Torch_Op<"aten.triu", [ AllowsTypeRefinement, HasValueSemantics, @@ -8934,6 +8984,31 @@ def Torch_Aten_ReshapeAliasOp : Torch_Op<"aten._reshape_alias", [ }]; } +def Torch_AtenResizeOp : Torch_Op<"aten.resize", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::resize : (Tensor, int[], int?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$size, + AnyTorchOptionalIntType:$memory_format + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenResizeOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenResizeOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + def Torch_AtenResize_Op : Torch_Op<"aten.resize_", [ AllowsTypeRefinement ]> { diff --git a/python/torch_mlir/csrc/base_lazy_backend/CMakeLists.txt b/python/torch_mlir/csrc/base_lazy_backend/CMakeLists.txt index ad8380612edd..81a8383949c7 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/CMakeLists.txt +++ b/python/torch_mlir/csrc/base_lazy_backend/CMakeLists.txt @@ -69,6 +69,7 @@ add_library(torch_mlir_ltc_backend SHARED backend_impl.cpp dynamic_ir.cpp mlir_node.cpp + tensor.cpp ops/device_data.cpp ops/generic.cpp ops/index.cpp diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_native_functions.cpp b/python/torch_mlir/csrc/base_lazy_backend/mlir_native_functions.cpp index 540f02ae606c..d06ad5963919 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/mlir_native_functions.cpp +++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_native_functions.cpp @@ -30,6 +30,7 @@ #include #include +#include "generated/LazyIr.h" #include "generated/LazyNativeFunctions.h" #include "generated/shape_inference.h" #include "ops/to_copy.h" @@ -143,32 +144,6 @@ void copy_(torch::lazy::LazyTensorPtr& input, torch::lazy::LazyTensorPtr& src) { } // namespace -// at::Tensor LazyNativeFunctions::bernoulli( -// const at::Tensor& self, c10::optional generator) { -// TORCH_LAZY_FN_COUNTER("lazy::"); -// if (generator.has_value() && generator->defined()) { -// UNSUPPORTED_ERROR("LazyNativeFunctions::bernoulli has generator value"); -// } -// auto self_tensor = torch::lazy::TryGetLtcTensor(self); - -// UNIMPLEMENTED_FUNCTION_ERROR(); -// // return torch::lazy::CreateAtenFromLtcTensor( -// // torch::lazy::bernoulli(self_tensor)); -// } - -// at::Tensor& LazyNativeFunctions::bernoulli_( -// at::Tensor& self, double p, c10::optional generator) { -// TORCH_LAZY_FN_COUNTER("lazy::"); -// if (generator.has_value() && generator->defined()) { -// UNSUPPORTED_ERROR("LazyNativeFunctions::bernoulli_ has generator value"); -// } -// auto self_tensor = torch::lazy::TryGetLtcTensor(self); - -// UNIMPLEMENTED_FUNCTION_ERROR(); -// // torch::lazy::bernoulli_(self_tensor, p); -// // return self; -// } - // clone is 
special in LT because we make it a no-op. // This should be safe to do, because every operator in the LT is functional. at::Tensor LazyNativeFunctions::clone( @@ -352,64 +327,17 @@ at::Tensor LazyNativeFunctions::_to_copy( } }; -at::Tensor LazyNativeFunctions::empty_symint( - at::SymIntArrayRef sym_size, - c10::optional dtype, - c10::optional layout, - c10::optional device, - c10::optional pin_memory, - c10::optional memory_format) { - // TODO: support this directly - auto size = C10_AS_INTARRAYREF_SLOW(sym_size); - const auto device_type = torch::lazy::getBackend()->EagerFallbackDeviceType(); - at::TensorOptions options = at::TensorOptions() - .device(c10::Device(device_type)) - .layout(layout) - .pinned_memory(pin_memory) - .dtype(dtype); - auto x_result = at::empty(size, options, memory_format); - auto tensor = CreateLtcTensor(x_result, GetLtcDevice(device)); - // See Note [Lazy Tensor Functionalization] - if (c10::impl::tls_local_dispatch_key_set().excluded_.has( - c10::DispatchKey::Functionalize)) { - // Invariant: if the functionalization key is in the exclude set, then we're expected - // to return an ordinary tensor, which will be "lifted" into a functional wrapper later. - return tensor; - } else { - auto wrapped = at::functionalization::impl::to_functional_tensor(tensor); - return wrapped; - } -} - -at::Tensor LazyNativeFunctions::empty_strided( - at::IntArrayRef size, at::IntArrayRef stride, - c10::optional dtype, c10::optional layout, - c10::optional device, c10::optional pin_memory) { - TORCH_LAZY_FN_COUNTER("lazy::"); - at::Tensor t = empty_symint( - c10::fromIntArrayRefSlow(size), - dtype, layout, device, pin_memory, c10::nullopt); - return t.as_strided(size, stride, /*storage_offset=*/0); -} - -at::Tensor& -LazyNativeFunctions::fill_(at::Tensor& self, const at::Scalar& value) { - TORCH_LAZY_FN_COUNTER("lazy::"); - auto self_tensor = torch::lazy::TryGetLtcTensor(self); - - torch::lazy::Value constant = - torch::lazy::LazyGraphExecutor::Get()->GetIrValueForExpandedScalar( - value, self_tensor->shape(), self_tensor->GetDevice()); - self_tensor->SetInPlaceIrValue(std::move(constant)); - return self; -} - at::Tensor LazyNativeFunctions::_unsafe_view( const at::Tensor& self, at::IntArrayRef size) { TORCH_LAZY_FN_COUNTER("lazy::"); return LazyNativeFunctions::view_copy_symint(self, c10::fromIntArrayRefSlow(size)); } +at::Tensor LazyNativeFunctions::t(const at::Tensor& self) { + TORCH_LAZY_FN_COUNTER("lazy::"); + return at::functionalization::functionalize_aten_op::call(self); +} + std::vector LazyNativeFunctions::unbind_copy(const at::Tensor & self, int64_t dim) { TORCH_LAZY_FN_COUNTER("lazy::"); auto common_device = torch::lazy::GetBackendDevice(self); @@ -643,9 +571,18 @@ at::Tensor LazyNativeFunctions::new_empty_strided_symint( c10::optional layout, c10::optional device, c10::optional pin_memory) { - return at::functionalization:: - functionalize_aten_op_symint::call( - self, size, stride, dtype, layout, device, pin_memory); + if (!device || device->type() == c10::DeviceType::Lazy) { + return at::functionalization::functionalize_aten_op_symint< + ATEN_OP(new_empty_strided)>::call(self, size, stride, dtype, layout, + device, pin_memory); + } + // For cases when device != lazy, for example: lazy_tensor.new_empty_strided(..., "cpu") + // we need to avoid explicit functionalization. To do that we create regular cpu tensors. + at::Tensor t = at::empty_symint( + size, (dtype ? dtype : c10::optional(self.scalar_type())), + (layout ? 
layout : c10::optional(self.layout())), device, + pin_memory, c10::nullopt); + return t.as_strided_symint(size, stride, /*storage_offset=*/0); } at::Tensor LazyNativeFunctions::narrow_copy_symint( @@ -729,4 +666,4 @@ at::Tensor& LazyNativeFunctions::logsumexp_out( void InitializeAtenBindings() {} } // namespace lazy -} // namespace torch +} // namespace torch \ No newline at end of file diff --git a/python/torch_mlir/csrc/base_lazy_backend/shape_inference.cpp b/python/torch_mlir/csrc/base_lazy_backend/shape_inference.cpp index f8d03449877d..043094c67e0a 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/shape_inference.cpp +++ b/python/torch_mlir/csrc/base_lazy_backend/shape_inference.cpp @@ -265,6 +265,33 @@ std::vector compute_shape_eye( return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; } +std::vector compute_shape_arange( + const at::Scalar& end, c10::optional dtype, + c10::optional layout, c10::optional device, + c10::optional pin_memory) { + auto out_meta = + at::arange(end, dtype, layout, c10::Device(c10::kMeta), pin_memory); + return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; +} + +std::vector compute_shape_arange( + const at::Scalar& start, const at::Scalar& end, + c10::optional dtype, c10::optional layout, + c10::optional device, c10::optional pin_memory) { + auto out_meta = at::arange(start, end, dtype, layout, c10::Device(c10::kMeta), + pin_memory); + return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; +} + +std::vector compute_shape_arange( + const at::Scalar& start, const at::Scalar& end, const at::Scalar& step, + c10::optional dtype, c10::optional layout, + c10::optional device, c10::optional pin_memory) { + auto out_meta = at::arange(start, end, step, dtype, layout, + c10::Device(c10::kMeta), pin_memory); + return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; +} + std::vector compute_shape_full( at::IntArrayRef size, const at::Scalar& fill_value, c10::optional dtype, c10::optional layout, @@ -273,6 +300,44 @@ std::vector compute_shape_full( Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())}; } +std::vector compute_shape_ones( + at::IntArrayRef size, c10::optional dtype, + c10::optional layout, c10::optional device, + c10::optional pin_memory) { + return { + Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())}; +} + +std::vector compute_shape_zeros( + at::IntArrayRef size, c10::optional dtype, + c10::optional layout, c10::optional device, + c10::optional pin_memory) { + return { + Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())}; +} + +std::vector compute_shape_empty( + at::IntArrayRef size, c10::optional dtype, + c10::optional layout, c10::optional device, + c10::optional pin_memory, + c10::optional memory_format) { + return { + Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())}; +} + +std::vector compute_shape_empty_strided( + at::IntArrayRef size, at::IntArrayRef stride, + c10::optional dtype, c10::optional layout, + c10::optional device, c10::optional pin_memory) { + return { + Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())}; +} + +std::vector compute_shape_fill(const at::Tensor& self, + const at::Scalar& value) { + return {Shape(self.scalar_type(), self.sizes().vec())}; +} + std::vector compute_shape_fill(const at::Tensor& self, const at::Tensor& value) { return {Shape(self.scalar_type(), self.sizes().vec())}; @@ -302,11 +367,24 @@ std::vector compute_shape_randint( 
Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())}; } +std::vector compute_shape_resize( + const at::Tensor & self, at::IntArrayRef size, + c10::optional memory_format) { + return {Shape(self.scalar_type(), size.vec())}; +} + std::vector compute_shape_bernoulli( const at::Tensor& self, const at::Tensor &p, c10::optional generator) { return {Shape(self.scalar_type(), self.sizes().vec())}; } +std::vector compute_shape_scalar_tensor( + const at::Scalar & s, c10::optional dtype, + c10::optional layout, c10::optional device, + c10::optional pin_memory) { + return {Shape(dtype.value_or(s.type()), c10::ArrayRef{})}; +} + } // namespace lazy } // namespace torch \ No newline at end of file diff --git a/python/torch_mlir/csrc/base_lazy_backend/tensor.cpp b/python/torch_mlir/csrc/base_lazy_backend/tensor.cpp new file mode 100644 index 000000000000..82ae6cc27f4a --- /dev/null +++ b/python/torch_mlir/csrc/base_lazy_backend/tensor.cpp @@ -0,0 +1,29 @@ +//===- tensor.cpp ---------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include + +#include "tensor.h" + +namespace torch { +namespace lazy { + +at::Tensor CreateFunctionalizedAtenFromLtcTensor( + const LazyTensorPtr& ltc_tensor) { + at::Tensor tensor = CreateAtenFromLtcTensor(ltc_tensor); + if (!c10::impl::tls_is_dispatch_key_excluded( + c10::DispatchKey::Functionalize) && + !at::functionalization::impl::isFunctionalTensor(tensor)) { + return at::functionalization::impl::to_functional_tensor(tensor); + } + return tensor; +} + +} // namespace lazy +} // namespace torch diff --git a/python/torch_mlir/csrc/base_lazy_backend/tensor.h b/python/torch_mlir/csrc/base_lazy_backend/tensor.h new file mode 100644 index 000000000000..4e39dd095aa5 --- /dev/null +++ b/python/torch_mlir/csrc/base_lazy_backend/tensor.h @@ -0,0 +1,24 @@ +//===- tensor.h -----------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +namespace torch { +namespace lazy { + +// Ops like torch.ones/zeros etc. which produce new tensor as an output +// should have explicit tensor functinoalization. Otherwise we can get +// unfanctionalized primitives or in the worst case if we apply inplace +// operations to unfunctionalized tensor it won't be captured in LTC graph. 
+TORCH_API at::Tensor CreateFunctionalizedAtenFromLtcTensor(const LazyTensorPtr& ltc_tensor); + +} // namespace lazy +} // namespace torch diff --git a/python/torch_mlir/csrc/reference_lazy_backend/backend_impl.cpp b/python/torch_mlir/csrc/reference_lazy_backend/backend_impl.cpp index 3bc8465eafc1..1064a3d1e1ac 100644 --- a/python/torch_mlir/csrc/reference_lazy_backend/backend_impl.cpp +++ b/python/torch_mlir/csrc/reference_lazy_backend/backend_impl.cpp @@ -28,6 +28,11 @@ using namespace torch::lazy; namespace torch { namespace lazy { +/// Returns true if a string begins with another. +inline bool beginswith(const std::string& s, const std::string& t) { + return s.size() >= t.size() && s.compare(0, t.size(), t) == 0; +} + struct ReferenceLazyBackendDeviceType : public BackendDeviceType { ReferenceLazyBackendDeviceType(c10::DeviceType device_type) : device_type_(device_type) {} @@ -104,7 +109,25 @@ class ReferenceLazyBackendImpl : public torch::lazy::TorchMlirBackendImpl { // // JIT Execution adopted from: // https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_backend_impl.cpp - torch::jit::GraphExecutor graph_executor(mlir_computation->graph(), ""); + std::shared_ptr graph = mlir_computation->graph(); + for (auto* node : graph->nodes()) { + // Convert any lazy devices to cpu devices to ensure + // that the values are actually computed + if (node->outputs().size() == 1 && + node->output()->type()->kind() == + c10::TypeKind::DeviceObjType) { + auto value_sym = torch::jit::Symbol::attr("value"); + TORCH_CHECK(node->hasAttribute(value_sym), + "Expected node to have 'value' attribute."); + TORCH_CHECK(node->kindOf(value_sym) == torch::jit::AttributeKind::s, + "Expected 'value' attribute to be a string."); + if (beginswith(node->s(value_sym), "lazy")) { + node->s_(value_sym, "cpu"); + } + } + } + + torch::jit::GraphExecutor graph_executor(graph, ""); std::vector stack; for (const auto& argument : arguments) { const auto mlir_data = diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index a7d5bbba43e9..3dab8eabecb6 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -359,6 +359,8 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::randn : (int[], int?, int?, Device?, bool?) -> (Tensor)") emit("aten::randn.generator : (int[], Generator?, int?, int?, Device?, bool?) -> (Tensor)") emit("aten::randn_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)") + emit("aten::random : (Tensor, Generator?) -> (Tensor)") + emit("aten::random.from : (Tensor, int, int?, Generator?) -> (Tensor)") emit_with_mutating_variants("aten::triu : (Tensor, int) -> (Tensor)") emit_with_mutating_variants("aten::tril : (Tensor, int) -> (Tensor)") @@ -571,6 +573,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::tile : (Tensor, int[]) -> (Tensor)") emit("aten::reshape : (Tensor, int[]) -> (Tensor)") emit("aten::_reshape_alias : (Tensor, int[], int[]) -> (Tensor)") + emit("aten::resize : (Tensor, int[], int?) -> (Tensor)") emit("aten::resize_ : (Tensor, int[], int?) 
-> (Tensor)") emit("aten::select.int : (Tensor, int, int) -> (Tensor)") emit("aten::size.int : (Tensor, int) -> (int)", has_folder=True) From 6699cbcc7484650c4a3554602713a7ec9f58e42b Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Sat, 23 Sep 2023 02:55:18 +0530 Subject: [PATCH 08/41] build: manually update PyTorch version (#2480) Set PyTorch and TorchVision version to nightly release 2023-09-22. Signed-Off By: Vivek Khandelwal --- pytorch-hash.txt | 2 +- pytorch-requirements.txt | 2 +- torchvision-requirements.txt | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch-hash.txt b/pytorch-hash.txt index f665e35e6884..754078490fe0 100644 --- a/pytorch-hash.txt +++ b/pytorch-hash.txt @@ -1 +1 @@ -ba087c0903f1c59eb993614f46602d39fcec2dfd +90c406a3a198b8f45682a9979b4c091ec5dc647e diff --git a/pytorch-requirements.txt b/pytorch-requirements.txt index 0bf75dc90a30..4c3d409ecb4c 100644 --- a/pytorch-requirements.txt +++ b/pytorch-requirements.txt @@ -1,3 +1,3 @@ -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre -torch==2.2.0.dev20230918 +torch==2.2.0.dev20230922 diff --git a/torchvision-requirements.txt b/torchvision-requirements.txt index d09ae0e0eaf0..a63225b58911 100644 --- a/torchvision-requirements.txt +++ b/torchvision-requirements.txt @@ -1,3 +1,3 @@ -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre -torchvision==0.17.0.dev20230918 +torchvision==0.17.0.dev20230922 From 5f772e8cb4abe0e134b1a974a6893e6452b9c656 Mon Sep 17 00:00:00 2001 From: Ashay Rane Date: Sat, 23 Sep 2023 09:00:16 -0500 Subject: [PATCH 09/41] CI: reconcile differences between RollPyTorch and pre-merge checks (#2482) --- .github/workflows/RollPyTorch.yml | 19 +++++++++++++++---- .../python_deploy/build_linux_packages.sh | 6 ++++++ 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/.github/workflows/RollPyTorch.yml b/.github/workflows/RollPyTorch.yml index 51f3f874b065..5c8d74ee0941 100644 --- a/.github/workflows/RollPyTorch.yml +++ b/.github/workflows/RollPyTorch.yml @@ -24,9 +24,21 @@ jobs: - name: Get torch-mlir uses: actions/checkout@v3 with: - submodules: 'true' + submodules: 'false' token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} + - name: Get LLVM and StableHlo submodules + run: | + set -eo pipefail + cd ${GITHUB_WORKSPACE} + + # Fetching the submodules concurrently may cause problems, so we fetch + # them one after another. 
+ rm -f .git/modules/externals/llvm-project/index.lock + rm -f .git/modules/externals/stablehlo/index.lock + git submodule update --init --recursive externals/llvm-project + git submodule update --init --recursive externals/stablehlo + - name: Setup ccache uses: ./.github/actions/setup-build with: @@ -71,15 +83,14 @@ jobs: echo "PTVISION_RELEASE=${VISION_RELEASE}" >> ${GITHUB_ENV} echo "PT_HASH_CHANGED=${PT_HASH_CHANGED}" >> ${GITHUB_ENV} - - name: Build and test (in-tree), also update ODS and abstract interpretation library + - name: Build and test (out-of-tree), also update ODS and abstract interpretation library if: env.PT_HASH_CHANGED != '0' run: | cd ${GITHUB_WORKSPACE} - TM_PACKAGES="in-tree" TM_USE_PYTORCH_BINARY="OFF" \ + TM_PACKAGES="out-of-tree" TM_USE_PYTORCH_BINARY="OFF" \ TORCH_MLIR_SRC_PYTORCH_BRANCH="${{ env.PT_HASH }}" \ TORCH_MLIR_SRC_PYTORCH_RELEASE="${{ env.PT_RELEASE }}" \ TM_UPDATE_ODS_AND_ABSTRACT_INTERP_LIB="ON" \ - TM_PYTHON_VERSIONS="cp311-cp311" \ ./build_tools/python_deploy/build_linux_packages.sh - name: Post issue comment on build failure diff --git a/build_tools/python_deploy/build_linux_packages.sh b/build_tools/python_deploy/build_linux_packages.sh index 733c79fa3eab..c64119d348d1 100755 --- a/build_tools/python_deploy/build_linux_packages.sh +++ b/build_tools/python_deploy/build_linux_packages.sh @@ -178,6 +178,12 @@ function run_in_docker() { out-of-tree) setup_venv "$python_version" "$TM_TORCH_VERSION" build_out_of_tree "$TM_USE_PYTORCH_BINARY" "$python_version" "$TM_TORCH_VERSION" + if [ "${TM_UPDATE_ODS_AND_ABSTRACT_INTERP_LIB}" == "ON" ]; then + pushd /main_checkout/torch-mlir + TORCH_MLIR_BUILD_DIR=/main_checkout/torch-mlir/build_oot ./build_tools/update_torch_ods.sh + TORCH_MLIR_BUILD_DIR=/main_checkout/torch-mlir/build_oot ./build_tools/update_abstract_interp_lib.sh + popd + fi if [ "${TM_SKIP_TESTS}" == "OFF" ]; then test_out_of_tree fi From a520d39f84f159838cc2f4f88a08c9cd611635fd Mon Sep 17 00:00:00 2001 From: Bruce Kim <92174982+brucekimrokcmu@users.noreply.github.com> Date: Mon, 25 Sep 2023 10:00:19 -0400 Subject: [PATCH 10/41] [MLIR][TORCH] Add device "cpu" support for aten.to.dtype_layout op (#2481) This PR adds device="cpu" support for `aten.to_dtypeLayout` op and corresponding e2e test suit. (refer: PR https://github.com/llvm/torch-mlir/pull/812/) --- .../Torch/Transforms/DecomposeComplexOps.cpp | 11 +++++++--- .../test_suite/type_conversion.py | 22 +++++++++++++++++++ 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 63ce4f837e85..6136db09221d 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -3380,10 +3380,15 @@ class DecomposeAtenToDtypeLayoutOp op, "unimplemented: pinMemory is expected to be false"); } - // TODO: Add support for non-None device arg. + // TODO: Add support for device arg other than cpu. if (!op.getDevice().getType().isa()) { - return rewriter.notifyMatchFailure( - op, "unimplemented: device arg must be None"); + std::string device; + if (!matchPattern(op.getDevice(), m_TorchConstantDevice(device))) + return rewriter.notifyMatchFailure( + op, "unimplemented: device must be a constant str"); + else if (device != "cpu") + return rewriter.notifyMatchFailure( + op, "unimplemented: device is expected to be cpu"); } // TODO: Add support for non-strided layout. 
diff --git a/python/torch_mlir_e2e_test/test_suite/type_conversion.py b/python/torch_mlir_e2e_test/test_suite/type_conversion.py index 5a4a19c50aba..6e04c5fa8700 100644 --- a/python/torch_mlir_e2e_test/test_suite/type_conversion.py +++ b/python/torch_mlir_e2e_test/test_suite/type_conversion.py @@ -169,6 +169,28 @@ def forward(self, x): def ToDtypeLayoutNoneModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 5)) +class ToDtypeLayoutCPUModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([-1, -1], torch.float32, True)]) + def forward(self, x): + return torch.ops.aten.to(x, + dtype=torch.float64, + layout=None, + device="cpu", + pin_memory=None, + non_blocking=False, + copy=False, + memory_format=None) + + +@register_test_case(module_factory=lambda: ToDtypeLayoutCPUModule()) +def ToDtypeLayoutCPUModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 5)) + class ToDtypeLayoutStridedModule(torch.nn.Module): From c9fd78988e16fb97359b8cb10a27f6bf8d377495 Mon Sep 17 00:00:00 2001 From: Ramiro Leal-Cavazos Date: Tue, 26 Sep 2023 09:20:01 -0700 Subject: [PATCH 11/41] [NFC] Clean-up `ConvertAtenViewOp` in linalg backend (#2470) While trying to fix a bug in the `ConvertAtenViewOp` pattern in the linalg backend, I realized that the pattern had become quite complex and had accumulated some dead code, making it hard to reason about. This commit simplifies the pattern quite a bit. The main changes are: 1. All the static helper functions in the `ConvertAtenViewOp` class have been simplified, both in their signature and their body. Each one now performs simple calculations on arrays, and take the least number of arguments necessary. 2. The body of [the `while` loop](https://github.com/ramiro050/torch-mlir/blob/9fce566b0cb64ff2b198693d1f6ee9580b8fa01f/lib/Conversion/TorchToLinalg/DataMovement.cpp#L407) inside the main pattern has been changed to work on `MutableArrayRef` slices, to avoid having to keep track of `start` and `end` indices for the input and output shape arrays. 3. All the heuristics used to determine the mapping between the input and output dimensions are now in [this relatively short `if-else` section](https://github.com/ramiro050/torch-mlir/blob/9fce566b0cb64ff2b198693d1f6ee9580b8fa01f/lib/Conversion/TorchToLinalg/DataMovement.cpp#L428-L460), making it easy to see what is going on. 4. Dead code was eliminated + updates to some of the documentation comments This commit does not add any new functionality to the `ConvertAtenViewOp` pattern. 
--- lib/Conversion/TorchToLinalg/DataMovement.cpp | 507 +++++++----------- 1 file changed, 199 insertions(+), 308 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index 041581d2a18b..9ec6a6006be7 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -34,6 +34,10 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +static int64_t productReduce(ArrayRef a) { + return accumulate(a.begin(), a.end(), /*init=*/1, std::multiplies()); +} + template LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, @@ -177,144 +181,131 @@ namespace { class ConvertAtenViewOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; + // If one of the two dims arrays has size 1, a mapping is created from the one + // dimension of the size-1 array to all the dimensions of the other array. For + // example for inputs: xDims = [6], yDims = [2, 3] the result in the indices + // arrays will be: xIndices = [0], yIndices = [0, 1]. + // + // An error is returned if the dimension size of the size-1 array is not equal + // to the product of all the dimension sizes in the other array, or if neither + // of the arrays is size-1. + static LogicalResult mapAllDimsToSingleDim(ArrayRef xDims, + ArrayRef yDims, + SmallVector &xIndices, + SmallVector &yIndices) { + auto isValidReduction = [](int64_t expectedReductionProduct, + ArrayRef arrayToReduce) -> bool { + if (llvm::count(arrayToReduce, kUnknownSize) > 0 || + expectedReductionProduct == kUnknownSize) + return true; + return productReduce(arrayToReduce) == expectedReductionProduct; + }; - // Helper for filling in remaining un-collapsed dims when the - // input/output dim is next to the next boundary dim. Additionally - // computes the size of a collapsed dynamic dim if necessary. - static LogicalResult - collapseToSingleDimHelper(AtenViewOp op, ConversionPatternRewriter &rewriter, - int64_t collapseDim, int64_t maxCollapseDim, - int64_t startExpandDim, int64_t maxExpandDim, - SmallVector &collapseShape, - const SmallVector &expandShape, - ReassociationIndices &expandIndices) { - int64_t collapseDimSize = 1; - for (auto i : llvm::seq(startExpandDim, maxExpandDim)) { - expandIndices.push_back(i); - if (collapseDimSize == kUnknownSize) - continue; - - int64_t expandedDimSize = expandShape[i]; - if (expandedDimSize == kUnknownSize) { - collapseDimSize = kUnknownSize; - continue; - } - collapseDimSize *= expandedDimSize; - } - int64_t rawCollapseDimSize = collapseShape[collapseDim]; - if (rawCollapseDimSize != kUnknownSize && collapseDimSize != kUnknownSize && - collapseDimSize != rawCollapseDimSize) { - return rewriter.notifyMatchFailure( - op, "desired size is not compatible with the input tensor size"); + if (xDims.size() == 1) { + if (!isValidReduction(xDims[0], yDims)) + return failure(); + xIndices.assign({0}); + yIndices.assign(llvm::to_vector(llvm::seq(0, yDims.size()))); + return success(); + } else if (yDims.size() == 1) { + if (!isValidReduction(yDims[0], xDims)) + return failure(); + yIndices.assign({0}); + xIndices.assign(llvm::to_vector(llvm::seq(0, xDims.size()))); + return success(); } - collapseShape[collapseDim] = collapseDimSize; - return success(); + return failure(); } - // Helper to find the minimum set of dims to collapse with the - // same number of elements as that of collapseDim. 
This function assumes - // the size of the collapsed dim is never dynamic. - static LogicalResult minimallyCollapseDimHelper( - AtenViewOp op, ConversionPatternRewriter &rewriter, int64_t collapseDim, - int64_t maxCollapseDim, int64_t startExpandDim, int64_t maxExpandDim, - SmallVector &collapseShape, SmallVector &expandShape, - ReassociationIndices &collapseIndices, - ReassociationIndices &expandIndices) { - - int64_t collapseDimSize = collapseShape[collapseDim]; - - int64_t expandedSize = 1; - int64_t collapsedSize = collapseDimSize; - - int64_t expandIndex = startExpandDim; - int64_t collapseIndex = collapseDim + 1; - - if (collapseDimSize == kUnknownSize) { - if (llvm::all_of(collapseShape, - [](int64_t value) { return value == kUnknownSize; }) && - llvm::all_of(expandShape, - [](int64_t value) { return value == kUnknownSize; })) { - - for (size_t i = 0; i < collapseShape.size(); i++) { - collapseIndices.push_back(i); - } - - for (size_t i = 0; i < expandShape.size(); i++) { - expandIndices.push_back(i); - } - - return success(); + // Starting from the beginning of the dims arrays, this helper finds the + // smallest set of consecutive dims in each array such that the product of the + // dim sizes in the two subsets is equal. The indices arrays are populated + // with the indices of the dims arrays that correspond to the subsets found. + // + // An error is returned if two subsets of dims with total number of elements + // equal to each other is not found. + static LogicalResult mapStaticallyKnownDims(ArrayRef xDims, + ArrayRef yDims, + SmallVector &xIndices, + SmallVector &yIndices) { + if (xDims.empty() || yDims.empty()) + return failure(); + int64_t xTotalSize = xDims[0]; + int64_t yTotalSize = yDims[0]; + SmallVector xIndicesResult({0}); + SmallVector yIndicesResult({0}); + size_t nextXIndex = 1; + size_t nextYIndex = 1; + while (xTotalSize != yTotalSize) { + if (xTotalSize < yTotalSize) { + if (nextXIndex == xDims.size() || xDims[nextXIndex] == kUnknownSize) + return failure(); + xTotalSize *= xDims[nextXIndex]; + xIndicesResult.push_back(nextXIndex++); + } else { + if (nextYIndex == yDims.size() || yDims[nextYIndex] == kUnknownSize) + return failure(); + yTotalSize *= yDims[nextYIndex]; + yIndicesResult.push_back(nextYIndex++); } } - while (expandIndex != maxExpandDim || collapseIndex != maxCollapseDim) { - if (expandIndex != maxExpandDim && expandedSize <= collapsedSize) { - int64_t expandDimSize = expandShape[expandIndex]; - if (expandDimSize != kUnknownSize) { - expandedSize *= expandDimSize; - } - expandIndices.push_back(expandIndex); - expandIndex++; - - } else if (collapseIndex != maxCollapseDim && - collapsedSize < expandedSize) { - collapseDimSize = collapseShape[collapseIndex]; - if (collapseDimSize != kUnknownSize) { - collapsedSize *= collapseDimSize; - } - collapseIndices.push_back(collapseIndex); - collapseIndex++; - } - - if (expandedSize == collapsedSize) - return success(); - } - return rewriter.notifyMatchFailure( - op, "total number of elements mismatch in the expansion"); + xIndices.assign(std::move(xIndicesResult)); + yIndices.assign(std::move(yIndicesResult)); + return success(); } - static void solveDynamicSize(SmallVector &inputShape, - SmallVector &outputShape) { - int64_t inputProduct = 1; - int64_t outputProduct = 1; - - int64_t inputDynamicValues = 0; - int64_t outputDynamicValues = 0; - - for (int64_t value : inputShape) { - if (value == -1) { - ++inputDynamicValues; - } else { - inputProduct *= value; - } - } - for (int64_t value : outputShape) { - if 
(value == -1) { - ++outputDynamicValues; - } else { - outputProduct *= value; - } + // Calculates the size of a dynamic dimension if all other dimensions are + // statically known, and rewrites that dynamic dimension with the static size. + // + // Note: this function assumes that all the dimensions in `inputShape` map to + // all the dimensions in `outputShape`. + static void calculateSingleDynamicSize(MutableArrayRef inputShape, + MutableArrayRef outputShape) { + int64_t inputDynamicDimCount = llvm::count(inputShape, kUnknownSize); + int64_t outputDynamicDimCount = llvm::count(outputShape, kUnknownSize); + if (inputDynamicDimCount + outputDynamicDimCount != 1) + return; + + int64_t inputProduct = productReduce(inputShape); + int64_t outputProduct = productReduce(outputShape); + + if (inputDynamicDimCount == 1) { + inputProduct /= kUnknownSize; + *llvm::find(inputShape, kUnknownSize) = outputProduct / inputProduct; + } else { + outputProduct /= kUnknownSize; + *llvm::find(outputShape, kUnknownSize) = inputProduct / outputProduct; } + } - if (inputDynamicValues + outputDynamicValues == 1) { - if (inputDynamicValues) { - int64_t missingValue = outputProduct / inputProduct; - for (size_t i = 0; i < inputShape.size(); i++) { - if (inputShape[i] == -1) { - inputShape[i] = missingValue; - break; - } - } - } else { - int64_t missingValue = inputProduct / outputProduct; - for (size_t i = 0; i < outputShape.size(); i++) { - if (outputShape[i] == -1) { - outputShape[i] = missingValue; - break; - } + // Gets the shapes of the input and output tensors, making a best-effort + // attempt to extract static shape information given the inputs to + // `aten.view`. + static std::pair, SmallVector> + getInputAndOutputShape(Value inputTorchTensor, + SmallVector outputSizeTorchInt) { + SmallVector inputShape( + inputTorchTensor.getType().cast().getSizes()); + SmallVector outputShape(outputSizeTorchInt.size(), kUnknownSize); + for (auto [outputDim, outputDimSize] : + llvm::enumerate(outputSizeTorchInt)) { + int64_t inputDim; + int64_t outputDimSizeInt; + // Match torch.aten.size.int(inputTensor, inputDim) with constant inputDim + if (matchPattern(outputDimSize, + m_TorchTensorSizeInt(inputTorchTensor, &inputDim))) { + outputShape[outputDim] = inputShape[inputDim]; + } else if (matchPattern(outputDimSize, + m_TorchConstantInt(&outputDimSizeInt))) { + if (outputDimSizeInt != -1) { + outputShape[outputDim] = outputDimSizeInt; } } } + + calculateSingleDynamicSize(inputShape, outputShape); + return std::make_pair(inputShape, outputShape); } LogicalResult @@ -325,8 +316,7 @@ class ConvertAtenViewOp : public OpConversionPattern { Location loc = op.getLoc(); Value input = adaptor.getSelf(); auto inputType = input.getType().cast(); - SmallVector inputShape = - makeShapeTorchCompatible(inputType.getShape()); + SmallVector inputSize = getTensorSizes(rewriter, loc, input); int64_t inputRank = inputType.getRank(); const TypeConverter *typeConverter = getTypeConverter(); auto resultType = @@ -349,6 +339,15 @@ class ConvertAtenViewOp : public OpConversionPattern { "unimplemented: the target size is " "not constructed from ListConstruct"); } + if (llvm::count_if(outputSizeTorchInt, [](Value size) -> bool { + int64_t sizeInt; + if (matchPattern(size, m_TorchConstantInt(&sizeInt))) + return sizeInt == -1; + return false; + }) > 1) { + return rewriter.notifyMatchFailure( + op, "at most one element in size list is allowed to be -1"); + } SmallVector outputSizeInt = getTypeConvertedValues( rewriter, loc, typeConverter, 
outputSizeTorchInt); if (resultRank != (int64_t)outputSizeInt.size()) { @@ -356,6 +355,9 @@ class ConvertAtenViewOp : public OpConversionPattern { op, "desired size list length mismatches with the result type rank"); } + auto [inputShape, outputShape] = + getInputAndOutputShape(op.getSelf(), outputSizeTorchInt); + // Currently, we only handle the cases where each dimension is either // being expanded or collapsed. We do not handle cases where it's neither // collapsing nor expanding like view of [2,3] for 3x2 tensor. @@ -364,90 +366,24 @@ class ConvertAtenViewOp : public OpConversionPattern { // [6] => [3, 2]. // Iterate through the view op size list to do the following: - // - // 1. Combine output size list and input tensor type info to get the most - // static outputShape. - // - // 2. Mark dims in unchangedDims for size list items where the output dim + // Mark dims in unchangedDims for size list items where the output dim // size comes from a `torch.aten.size.int(inputTensor, inputDim)`. We // naively assume this means the corresponding dimension is not expanded or // collapsed. Note this may technically not always be true. // TODO: think of a way better way to at least detect when this assumption // is violated for the cases of dynamic dimensions. - SmallVector outputShape(resultRank, kUnknownSize); - SmallVector unchangedDims; - std::optional inferredDimension; - for (auto en : llvm::enumerate(outputSizeTorchInt)) { + SmallVector> unchangedDims; + for (auto [outputDim, outputDimSize] : + llvm::enumerate(outputSizeTorchInt)) { int64_t inputDim; - int64_t size; - int64_t outputDim = en.index(); // Match torch.aten.size.int(inputTensor, inputDim) with constant inputDim - if (matchPattern(en.value(), + if (matchPattern(outputDimSize, m_TorchTensorSizeInt(op.getSelf(), &inputDim))) { - unchangedDims.emplace_back(); - unchangedDims.back().push_back(inputDim); - unchangedDims.back().push_back(outputDim); - if (!inputType.isDynamicDim(inputDim)) { - outputShape[outputDim] = inputShape[inputDim]; - continue; - } - } else if (matchPattern(en.value(), m_TorchConstantInt(&size))) { - if (size != -1) { - outputShape[outputDim] = size; - continue; - } - - if (inferredDimension.has_value()) { - return rewriter.notifyMatchFailure( - op, "at most one element in size list is allowed to be -1"); - } - inferredDimension = outputDim; + unchangedDims.push_back(std::make_pair(inputDim, outputDim)); } } - // Mark the end of the input/output shapes - unchangedDims.emplace_back(); - unchangedDims.back().push_back(inputRank); - unchangedDims.back().push_back(resultRank); - - // Use static information of input tensor to determine size of inferred - // dimension in output shape. - // - // If there is an inferred dimension and that is the only dimension - // in the output shape (i.e. the tensor is getting fully flattened), - // then we don't need to analyze the static information of the input - // shape since the reassociation of dimensions only requires rank - // information. 
- if (inferredDimension.has_value() && outputShape.size() > 1) { - if (llvm::count(outputShape, kUnknownSize) != 1 || - llvm::count(inputShape, kUnknownSize) != 0) { - return rewriter.notifyMatchFailure( - op, - "unimplemented: an inferred dimension is only supported when there " - "is enough static shape information to determine its size, or when " - "the input tensor is being flattened to a single dimension"); - } - auto productReduceKnownSizes = [](const ArrayRef sizes) { - auto knownSizes = llvm::make_filter_range( - sizes, [](int64_t val) { return val != kUnknownSize; }); - return std::accumulate(knownSizes.begin(), knownSizes.end(), /*init=*/1, - std::multiplies()); - }; - - int64_t numOfElements = productReduceKnownSizes(inputShape); - int64_t outputKnownNumOfElements = productReduceKnownSizes(outputShape); - if (numOfElements % outputKnownNumOfElements != 0) { - return rewriter.notifyMatchFailure( - op, "number of elements in input tensor must be divisible by " - "product of non-inferred dimensions in size list"); - } - outputShape[*inferredDimension] = - numOfElements / outputKnownNumOfElements; - } - - SmallVector inputSize = getTensorSizes(rewriter, loc, input); - ArrayRef outputShapeInt = llvm::ArrayRef(outputSizeInt); - ArrayRef inputShapeInt = llvm::ArrayRef(inputSize); + unchangedDims.push_back(std::make_pair(inputRank, resultRank)); // Association indices for expand/collapse ops. These two vectors // are populated such that two entries at the same index corresponds @@ -463,10 +399,6 @@ class ConvertAtenViewOp : public OpConversionPattern { SmallVector inputAssociations; SmallVector outputAssociations; - SmallVector inputShapeVec = llvm::to_vector(inputShape); - - solveDynamicSize(inputShapeVec, outputShape); - // The for loop does the following: // 1. Attempt to match the indices from inputDim and outputDim to the next // boundary found from `torch.aten.size.int(inputTensor, inputDim)`, or @@ -482,119 +414,78 @@ class ConvertAtenViewOp : public OpConversionPattern { // the dynamic dimension with the one across from it and give up if we can't // reason about how the dimensions are associated. // e.g. [-1, -1] -> [2, 3, 4] - // 3. Set inputShapeVec and outputShape following the requirements by - // tensor.expand_shape verification code: - // a. As long as one or more of the related dimensions in the expanded - // shape is dynamic the collapsed dimension is dynamic. - // b. If all of the related dimensions are static, the collapsed - // dimension must be static. In other words, if a collapsed dimension is - // dynamic, at least one of the related dimensions need to be dynamic. + // For more information, see description of helper functions used in the + // `if-else` cases inside the while loop. 
int64_t inputDim = 0, outputDim = 0; - for (auto boundary : unchangedDims) { - // We assume dims specified by AtenSizeInt ops are unchanged - int64_t nextUnchangedInput = boundary[0]; - int64_t nextUnchangedOutput = boundary[1]; - - bool hasDynamic = false; + for (auto [nextUnchangedInput, nextUnchangedOutput] : unchangedDims) { + // Used for ensuring that we don't have an ambiguous expansion + bool assumedDynamicDimNotSplit = false; while (inputDim < nextUnchangedInput && outputDim < nextUnchangedOutput) { - - inputAssociations.emplace_back(); - outputAssociations.emplace_back(); - - // outputDim is next to the boundary - if (outputDim == nextUnchangedOutput - 1) { - - if (hasDynamic && inputDim != nextUnchangedInput - 1) { - return rewriter.notifyMatchFailure( - op, "found ambiguous collapse of dynamic input sizes (e.g. " - "[-1, -1, -1] -> [-1, -1])"); - } - outputAssociations.back().push_back(outputDim); - if (failed(collapseToSingleDimHelper( - op, rewriter, outputDim, nextUnchangedOutput, inputDim, - nextUnchangedInput, outputShape, inputShapeVec, - inputAssociations.back()))) - return failure(); - outputDim = nextUnchangedOutput; - inputDim = nextUnchangedInput; - continue; - } - - // inputDim is next to the boundary - if (inputDim == nextUnchangedInput - 1) { - - if (hasDynamic && inputShape[inputDim] == kUnknownSize) { - return rewriter.notifyMatchFailure( - op, "found ambiguous expand of dynamic sizes (e.g. [-1, -1] -> " - "[-1, -1, -1])"); - } - inputAssociations.back().push_back(inputDim); - if (failed(collapseToSingleDimHelper( - op, rewriter, inputDim, nextUnchangedInput, outputDim, - nextUnchangedOutput, inputShapeVec, outputShape, - outputAssociations.back()))) - return failure(); - - outputDim = nextUnchangedOutput; - inputDim = nextUnchangedInput; - continue; - } - - int64_t inputMatchingDimSize = inputShapeVec[inputDim]; - int64_t outputMatchingDimSize = outputShape[outputDim]; - - // If the input is dynamic, first assume it is not split - if (inputMatchingDimSize == kUnknownSize) { - - checkDimEqualHelper(rewriter, loc, inputShapeInt[inputDim], - outputShapeInt[outputDim]); - outputShape[outputDim] = kUnknownSize; - inputAssociations.back().push_back(inputDim++); - outputAssociations.back().push_back(outputDim++); - hasDynamic = true; - continue; + auto inputShapeSlice = + MutableArrayRef(inputShape) + .slice(inputDim, nextUnchangedInput - inputDim); + auto outputShapeSlice = + MutableArrayRef(outputShape) + .slice(outputDim, nextUnchangedOutput - outputDim); + SmallVector inputSliceIndices; + SmallVector outputSliceIndices; + + // TODO: this can be removed by replacing it with a checkDimEqualHelper + // that takes into account the product of all the dimensions being + // reduced + if (assumedDynamicDimNotSplit && inputShapeSlice.size() == 1 && + outputShapeSlice.size() != 1 && + inputShapeSlice[0] == kUnknownSize) { + return rewriter.notifyMatchFailure( + op, "found ambiguous expand of dynamic input sizes " + "(e.g. 
[-1, -1] -> [-1, -1, -1])"); } - // inputDim size is larger; try to collapse onto it - if (inputMatchingDimSize >= outputMatchingDimSize) { - - inputAssociations.back().push_back(inputDim); - if (failed(minimallyCollapseDimHelper( - op, rewriter, inputDim, nextUnchangedInput, outputDim, - nextUnchangedOutput, inputShapeVec, outputShape, - inputAssociations.back(), outputAssociations.back()))) { - return failure(); + if (succeeded(mapAllDimsToSingleDim(inputShapeSlice, outputShapeSlice, + inputSliceIndices, + outputSliceIndices))) { + calculateSingleDynamicSize(inputShapeSlice, outputShapeSlice); + // Update shape to pass the tensor.expand_shape and + // tensor.collapse_shape verifiers. If one of the dimensions of the + // tensor being flattened is dynamic, the size of the flattened tensor + // must also be dynamic. + if (inputShapeSlice.size() == 1 && + llvm::count(outputShapeSlice, kUnknownSize) > 0) { + inputShapeSlice[0] = kUnknownSize; + } else if (outputShapeSlice.size() == 1 && + llvm::count(inputShapeSlice, kUnknownSize) > 0) { + outputShapeSlice[0] = kUnknownSize; } - hasDynamic = false; - outputDim = outputAssociations.back().back() + 1; - inputDim = inputAssociations.back().back() + 1; - continue; + } else if (succeeded(mapStaticallyKnownDims( + inputShapeSlice, outputShapeSlice, inputSliceIndices, + outputSliceIndices))) { + /// `mapStaticallyKnownDims` maps the smallest number of + /// input and output dimensions in the slice statically + /// known to have the same number of elements. + } else if (inputShapeSlice[0] == kUnknownSize) { + // If the input is dynamic, assume it is not split + checkDimEqualHelper(rewriter, loc, inputSize[inputDim], + outputSizeInt[outputDim]); + // If output dimension is not dynamic, improve static information of + // input + inputShape[inputDim] = outputShape[outputDim]; + inputSliceIndices.push_back(0); + outputSliceIndices.push_back(0); + assumedDynamicDimNotSplit = true; + } else { + return rewriter.notifyMatchFailure( + op, "unimplemented: found unhandled case of expansion/collapse " + "in `aten.view`"); } - // outputDim is larger; try to collapse onto it - outputAssociations.back().push_back(outputDim); - if (failed(minimallyCollapseDimHelper( - op, rewriter, outputDim, nextUnchangedOutput, inputDim, - nextUnchangedInput, outputShape, inputShapeVec, - outputAssociations.back(), inputAssociations.back()))) { - - return failure(); - } - hasDynamic = false; + inputAssociations.emplace_back(); + outputAssociations.emplace_back(); + for (int64_t inputSliceIndex : inputSliceIndices) + inputAssociations.back().push_back(inputSliceIndex + inputDim); + for (int64_t outputSliceIndex : outputSliceIndices) + outputAssociations.back().push_back(outputSliceIndex + outputDim); inputDim = inputAssociations.back().back() + 1; outputDim = outputAssociations.back().back() + 1; - continue; - } - - if (inputDim != nextUnchangedInput) { - hasDynamic = true; - if (inputAssociations.size() < 1) { - inputAssociations.emplace_back(); - outputAssociations.emplace_back(); - } - inputAssociations.back().push_back(inputDim++); - outputAssociations.back().push_back(outputDim++); - continue; } // Append the associations for the dims matching `aten.size.int` @@ -624,7 +515,7 @@ class ConvertAtenViewOp : public OpConversionPattern { Type adjustedResultType = RankedTensorType::get( makeShapeLLVMCompatible(outputShape), resultType.getElementType()); Type adjustedInputType = RankedTensorType::get( - makeShapeLLVMCompatible(inputShapeVec), resultType.getElementType()); + 
makeShapeLLVMCompatible(inputShape), resultType.getElementType());
     Value castedInput = rewriter.create(loc, adjustedInputType, input);
     std::optional expandedInput;

From ff7f8b21dcc842a4f70209a6d255d54c4ef6e39b Mon Sep 17 00:00:00 2001
From: Daniel Garvey <34486624+dan-garvey@users.noreply.github.com>
Date: Tue, 26 Sep 2023 16:15:55 -0500
Subject: [PATCH 12/41] update llvm-project to d13da154a7c7eff77df8686b2de1cfdfa7cc7029 (#2483)

---
 externals/llvm-project                   |  2 +-
 lib/Conversion/TorchToSCF/TorchToSCF.cpp | 10 +++++-----
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/externals/llvm-project b/externals/llvm-project
index f66cd9e9556a..d13da154a7c7 160000
--- a/externals/llvm-project
+++ b/externals/llvm-project
@@ -1 +1 @@
-Subproject commit f66cd9e9556a53142a26a5c21a72e21f1579217c
+Subproject commit d13da154a7c7eff77df8686b2de1cfdfa7cc7029
diff --git a/lib/Conversion/TorchToSCF/TorchToSCF.cpp b/lib/Conversion/TorchToSCF/TorchToSCF.cpp
index 146959151240..96e14f0fdd6e 100644
--- a/lib/Conversion/TorchToSCF/TorchToSCF.cpp
+++ b/lib/Conversion/TorchToSCF/TorchToSCF.cpp
@@ -237,17 +237,17 @@ class ConvertTorchPrimLoopForLikeOp : public OpConversionPattern {
     SmallVector regionArgTypes;
     SmallVector regionArgLocs;
-    for (Value value : scfForOp.getLoopBody().front().getArguments()) {
+    for (Value value : scfForOp.getRegion().front().getArguments()) {
       regionArgTypes.push_back(value.getType());
       regionArgLocs.push_back(value.getLoc());
     }

     // Populate the loop body region.
-    if (!scfForOp.getLoopBody().empty())
-      rewriter.eraseBlock(&scfForOp.getLoopBody().back());
+    if (!scfForOp.getRegion().empty())
+      rewriter.eraseBlock(&scfForOp.getRegion().back());

-    auto *block = rewriter.createBlock(&scfForOp.getLoopBody(),
-                                       scfForOp.getLoopBody().begin(),
+    auto *block = rewriter.createBlock(&scfForOp.getRegion(),
+                                       scfForOp.getRegion().begin(),
                                        regionArgTypes, regionArgLocs);

     // Rewrite uses of the torch loop block arguments to the new for-loop

From 7760bda8ee6244837ec76cedbee7e518127a4feb Mon Sep 17 00:00:00 2001
From: Vivek Khandelwal
Date: Wed, 27 Sep 2023 06:47:15 +0000
Subject: [PATCH 13/41] build: manually update PyTorch version

Set PyTorch and TorchVision version to nightly release 2023-09-26.

The aten._convolution.deprecated changes are needed because upstream
PyTorch has now added support for fp16 native convolution on CPU.
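As a concrete (hypothetical) illustration of the upstream change: on a nightly at or after the pin below, this now runs natively on CPU, where older CPU builds raised a "not implemented for 'Half'" error.

```python
import torch

x = torch.randn(1, 1, 4, 4, dtype=torch.float16)  # fp16 input on CPU
w = torch.randn(1, 1, 2, 2, dtype=torch.float16)  # fp16 weight on CPU
y = torch.nn.functional.conv2d(x, w)              # works on recent nightlies
print(y.dtype)  # torch.float16
```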
Refer: https://github.com/pytorch/pytorch/commit/7c9052165a5358266a6c8fe614a203c70587cc49 Signed-Off By: Vivek Khandelwal --- .../Transforms/AbstractInterpLibrary.cpp | 96 ++++++++++--------- .../build_tools/abstract_interp_lib_gen.py | 14 +-- pytorch-hash.txt | 2 +- pytorch-requirements.txt | 2 +- torchvision-requirements.txt | 2 +- 5 files changed, 61 insertions(+), 55 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 697ad6bbd7ef..de6d287e3443 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -9547,94 +9547,98 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %8 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten._convolution\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list, %arg8: !torch.int, %arg9: !torch.bool, %arg10: !torch.bool, %arg11: !torch.bool, %arg12: !torch.bool) -> !torch.int {\n" -" %none = torch.constant.none\n" -" %str = torch.constant.str \"AssertionError: \"\n" " %false = torch.constant.bool false\n" -" %int5 = torch.constant.int 5\n" " %int11 = torch.constant.int 11\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" -" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" -" %4 = torch.prim.If %3 -> (!torch.bool) {\n" -" %11 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list\n" -" %12 = torch.aten.__contains__.int_list %11, %0#1 : !torch.list, !torch.int -> !torch.bool\n" -" %13 = torch.aten.__not__ %12 : !torch.bool -> !torch.bool\n" -" torch.prim.If.yield %13 : !torch.bool\n" +" %2 = torch.aten.eq.int %0#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" +" %5 = torch.prim.If %4 -> (!torch.bool) {\n" +" %12 = torch.aten.__isnot__ %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %12 : !torch.bool\n" " } else {\n" " torch.prim.If.yield %false : !torch.bool\n" " }\n" -" torch.prim.If %4 -> () {\n" +" torch.prim.If %5 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" -" %6 = torch.aten.__not__ %5 : !torch.bool -> !torch.bool\n" -" %7 = torch.prim.If %6 -> (!torch.bool) {\n" -" %11 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list\n" -" %12 = torch.aten.__contains__.int_list %11, %1#1 : !torch.list, !torch.int -> !torch.bool\n" -" %13 = torch.aten.__not__ %12 : 
!torch.bool -> !torch.bool\n" -" torch.prim.If.yield %13 : !torch.bool\n" +" %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool\n" +" %8 = torch.prim.If %7 -> (!torch.bool) {\n" +" %12 = torch.aten.__isnot__ %1#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %12 : !torch.bool\n" " } else {\n" " torch.prim.If.yield %false : !torch.bool\n" " }\n" -" torch.prim.If %7 -> () {\n" +" torch.prim.If %8 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %8 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" -" %9 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %10 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%8, %9) : (!torch.list>, !torch.list) -> !torch.int\n" -" return %10 : !torch.int\n" +" %9 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %10 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %11 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%9, %10) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %11 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten._convolution.deprecated\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list, %arg8: !torch.int, %arg9: !torch.bool, %arg10: !torch.bool, %arg11: !torch.bool) -> !torch.int {\n" -" %none = torch.constant.none\n" -" %str = torch.constant.str \"AssertionError: \"\n" " %false = torch.constant.bool false\n" -" %int5 = torch.constant.int 5\n" " %int11 = torch.constant.int 11\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" -" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" -" %4 = torch.prim.If %3 -> (!torch.bool) {\n" -" %11 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list\n" -" %12 = torch.aten.__contains__.int_list %11, %0#1 : !torch.list, !torch.int -> !torch.bool\n" -" %13 = torch.aten.__not__ %12 : !torch.bool -> !torch.bool\n" -" torch.prim.If.yield %13 : !torch.bool\n" +" %2 = torch.aten.eq.int %0#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" +" %5 = torch.prim.If %4 -> (!torch.bool) {\n" +" %12 = torch.aten.__isnot__ %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %12 : !torch.bool\n" " } else {\n" " torch.prim.If.yield %false : !torch.bool\n" " }\n" -" 
torch.prim.If %4 -> () {\n" +" torch.prim.If %5 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" -" %6 = torch.aten.__not__ %5 : !torch.bool -> !torch.bool\n" -" %7 = torch.prim.If %6 -> (!torch.bool) {\n" -" %11 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list\n" -" %12 = torch.aten.__contains__.int_list %11, %1#1 : !torch.list, !torch.int -> !torch.bool\n" -" %13 = torch.aten.__not__ %12 : !torch.bool -> !torch.bool\n" -" torch.prim.If.yield %13 : !torch.bool\n" +" %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool\n" +" %8 = torch.prim.If %7 -> (!torch.bool) {\n" +" %12 = torch.aten.__isnot__ %1#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %12 : !torch.bool\n" " } else {\n" " torch.prim.If.yield %false : !torch.bool\n" " }\n" -" torch.prim.If %7 -> () {\n" +" torch.prim.If %8 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %8 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" -" %9 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %10 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%8, %9) : (!torch.list>, !torch.list) -> !torch.int\n" -" return %10 : !torch.int\n" +" %9 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %10 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %11 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%9, %10) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %11 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.conv2d\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py index d6f064f745ed..c2e24e93cf7c 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py @@ -2461,7 +2461,7 @@ def aten〇threshold_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], s _check_tensors_with_the_same_dtype( tensor_shapes=[(1, 1, 1, 1), (1, 1, 1, 1), (1,)], tensor_device="cpu", - error_types={torch.bool, torch.float16, torch.complex64, torch.complex128}, **_convolution_kwargs) + + error_types={torch.bool, torch.complex64, torch.complex128}, **_convolution_kwargs) + [ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bool, device="cpu"), TensorOfShape(1, 1, 1, 1, dtype=torch.float32, device="cpu"), TensorOfShape(1, dtype=torch.float32, device="cpu"), **_convolution_kwargs), ErrorInvocation(TensorOfShape(1, 1, 1, 1, 
dtype=torch.float32, device="cpu"), TensorOfShape(1, 1, 1, 1, dtype=torch.bool, device="cpu"), @@ -2473,8 +2473,9 @@ def aten〇threshold_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], s def aten〇_convolution〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool, allow_tf32: bool) -> int: input_rank, input_dtype = input_rank_dtype weight_rank, weight_dtype = weight_rank_dtype - assert not is_complex_dtype(input_dtype) and input_dtype not in [torch.bool, torch.float16] - assert not is_complex_dtype(weight_dtype) and weight_dtype not in [torch.bool, torch.float16] + assert input_dtype == weight_dtype + assert not is_complex_dtype(input_dtype) and input_dtype is not torch.bool + assert not is_complex_dtype(weight_dtype) and weight_dtype is not torch.bool ranks: List[Optional[int]] = [input_rank, weight_rank] dtypes = [input_dtype, weight_dtype] return promote_dtypes(ranks, dtypes) @@ -2494,7 +2495,7 @@ def aten〇_convolution〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_d _check_tensors_with_the_same_dtype( tensor_shapes=[(1, 1, 1, 1), (1, 1, 1, 1), (1,)], tensor_device="cpu", - error_types={torch.bool, torch.float16, torch.complex64, torch.complex128}, **_convolution_deprecated_kwargs) + + error_types={torch.bool, torch.complex64, torch.complex128}, **_convolution_deprecated_kwargs) + [ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bool, device="cpu"), TensorOfShape(1, 1, 1, 1, dtype=torch.float32, device="cpu"), TensorOfShape(1, dtype=torch.float32, device="cpu"), **_convolution_deprecated_kwargs), ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32, device="cpu"), TensorOfShape(1, 1, 1, 1, dtype=torch.bool, device="cpu"), @@ -2507,8 +2508,9 @@ def aten〇_convolution〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_d def aten〇_convolution〇deprecated〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool) -> int: input_rank, input_dtype = input_rank_dtype weight_rank, weight_dtype = weight_rank_dtype - assert not is_complex_dtype(input_dtype) and input_dtype not in [torch.bool, torch.float16] - assert not is_complex_dtype(weight_dtype) and weight_dtype not in [torch.bool, torch.float16] + assert input_dtype == weight_dtype + assert not is_complex_dtype(input_dtype) and input_dtype is not torch.bool + assert not is_complex_dtype(weight_dtype) and weight_dtype is not torch.bool ranks: List[Optional[int]] = [input_rank, weight_rank] dtypes = [input_dtype, weight_dtype] return promote_dtypes(ranks, dtypes) diff --git a/pytorch-hash.txt b/pytorch-hash.txt index 754078490fe0..a5e99e920ab6 100644 --- a/pytorch-hash.txt +++ b/pytorch-hash.txt @@ -1 +1 @@ -90c406a3a198b8f45682a9979b4c091ec5dc647e +ab61acc20ccd35835b9cd7f587f6a909839cf57f diff --git a/pytorch-requirements.txt b/pytorch-requirements.txt index 4c3d409ecb4c..583012d29dac 100644 --- a/pytorch-requirements.txt +++ b/pytorch-requirements.txt @@ -1,3 +1,3 @@ -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre -torch==2.2.0.dev20230922 +torch==2.2.0.dev20230926 diff --git a/torchvision-requirements.txt 
b/torchvision-requirements.txt
index a63225b58911..d73b26f643d4 100644
--- a/torchvision-requirements.txt
+++ b/torchvision-requirements.txt
@@ -1,3 +1,3 @@
 -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
 --pre
-torchvision==0.17.0.dev20230922
+torchvision==0.17.0.dev20230926

From e69266a936427184dd3c887e517c7098f7f1bf4e Mon Sep 17 00:00:00 2001
From: Stella Laurenzo
Date: Wed, 27 Sep 2023 08:45:35 -0700
Subject: [PATCH 14/41] update PyTorch version to 2.2.0.dev20230927 (#2489)

torch version: 2.2.0.dev20230927
torch commit hash: d7520d8668dc08f7bed27a64f006c909006e653a
torchvision version: 0.17.0.dev20230927

Co-authored-by: Roll PyTorch Action
---
 pytorch-hash.txt             | 2 +-
 pytorch-requirements.txt     | 2 +-
 torchvision-requirements.txt | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/pytorch-hash.txt b/pytorch-hash.txt
index a5e99e920ab6..fb206a10a32d 100644
--- a/pytorch-hash.txt
+++ b/pytorch-hash.txt
@@ -1 +1 @@
-ab61acc20ccd35835b9cd7f587f6a909839cf57f
+d7520d8668dc08f7bed27a64f006c909006e653a
diff --git a/pytorch-requirements.txt b/pytorch-requirements.txt
index 583012d29dac..3dfe294db805 100644
--- a/pytorch-requirements.txt
+++ b/pytorch-requirements.txt
@@ -1,3 +1,3 @@
 -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
 --pre
-torch==2.2.0.dev20230926
+torch==2.2.0.dev20230927
diff --git a/torchvision-requirements.txt b/torchvision-requirements.txt
index d73b26f643d4..ae783e410baa 100644
--- a/torchvision-requirements.txt
+++ b/torchvision-requirements.txt
@@ -1,3 +1,3 @@
 -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
 --pre
-torchvision==0.17.0.dev20230926
+torchvision==0.17.0.dev20230927

From 7c6b9d2445288cea7140b94a8586ce7873b50ccb Mon Sep 17 00:00:00 2001
From: Ramiro Leal-Cavazos
Date: Wed, 27 Sep 2023 09:09:30 -0700
Subject: [PATCH 15/41] [linalg] Fix handling of trailing size-1 dimensions in aten.view (#2474)

This commit adds handling to the `aten.view` lowering for the following
cases:
- `(..., a.size(i))` -> `(..., a.size(i), 1, ..., 1)`
- `(..., a.size(i), 1, ..., 1)` -> `(..., a.size(i))`

Fixes: https://github.com/llvm/torch-mlir/issues/2448
---
 lib/Conversion/TorchToLinalg/DataMovement.cpp | 51 ++++++++++++++++---
 .../test_suite/reshape_like.py                | 36 ++++++++++++-
 2 files changed, 79 insertions(+), 8 deletions(-)

diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp
index 9ec6a6006be7..1f54af1f6589 100644
--- a/lib/Conversion/TorchToLinalg/DataMovement.cpp
+++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -193,6 +193,9 @@ class ConvertAtenViewOp : public OpConversionPattern {
                                             ArrayRef yDims,
                                             SmallVector &xIndices,
                                             SmallVector &yIndices) {
+    if (xDims.empty() || yDims.empty())
+      return failure();
+
    auto isValidReduction = [](int64_t expectedReductionProduct,
                               ArrayRef arrayToReduce) -> bool {
      if (llvm::count(arrayToReduce, kUnknownSize) > 0 ||
@@ -255,6 +258,25 @@ class ConvertAtenViewOp : public OpConversionPattern {
    return success();
  }

+  // If one of the two dims arrays has size 0 and the other array only
+  // has dims of size 1, a mapping is created from no dimensions to
+  // all the dimensions of the other array.
+ static LogicalResult mapTrailingSizeOneDims(ArrayRef xDims, + ArrayRef yDims, + SmallVector &xIndices, + SmallVector &yIndices) { + SmallVector ignoredIndices; + if (xDims.empty()) { + return mapAllDimsToSingleDim(ArrayRef({1}), yDims, + ignoredIndices, yIndices); + } else if (yDims.empty()) { + return mapAllDimsToSingleDim(xDims, ArrayRef({1}), xIndices, + ignoredIndices); + } else { + return failure(); + } + } + // Calculates the size of a dynamic dimension if all other dimensions are // statically known, and rewrites that dynamic dimension with the static size. // @@ -262,6 +284,8 @@ class ConvertAtenViewOp : public OpConversionPattern { // all the dimensions in `outputShape`. static void calculateSingleDynamicSize(MutableArrayRef inputShape, MutableArrayRef outputShape) { + if (inputShape.empty() || outputShape.empty()) + return; int64_t inputDynamicDimCount = llvm::count(inputShape, kUnknownSize); int64_t outputDynamicDimCount = llvm::count(outputShape, kUnknownSize); if (inputDynamicDimCount + outputDynamicDimCount != 1) @@ -420,7 +444,7 @@ class ConvertAtenViewOp : public OpConversionPattern { for (auto [nextUnchangedInput, nextUnchangedOutput] : unchangedDims) { // Used for ensuring that we don't have an ambiguous expansion bool assumedDynamicDimNotSplit = false; - while (inputDim < nextUnchangedInput && outputDim < nextUnchangedOutput) { + while (inputDim < nextUnchangedInput || outputDim < nextUnchangedOutput) { auto inputShapeSlice = MutableArrayRef(inputShape) .slice(inputDim, nextUnchangedInput - inputDim); @@ -441,9 +465,15 @@ class ConvertAtenViewOp : public OpConversionPattern { "(e.g. [-1, -1] -> [-1, -1, -1])"); } - if (succeeded(mapAllDimsToSingleDim(inputShapeSlice, outputShapeSlice, - inputSliceIndices, - outputSliceIndices))) { + if (succeeded(mapTrailingSizeOneDims(inputShapeSlice, outputShapeSlice, + inputSliceIndices, + outputSliceIndices))) { + } else if (outputShapeSlice.empty()) { + inputSliceIndices.assign( + llvm::to_vector(llvm::seq(0, inputShapeSlice.size()))); + } else if (succeeded(mapAllDimsToSingleDim( + inputShapeSlice, outputShapeSlice, inputSliceIndices, + outputSliceIndices))) { calculateSingleDynamicSize(inputShapeSlice, outputShapeSlice); // Update shape to pass the tensor.expand_shape and // tensor.collapse_shape verifiers. If one of the dimensions of the @@ -462,7 +492,8 @@ class ConvertAtenViewOp : public OpConversionPattern { /// `mapStaticallyKnownDims` maps the smallest number of /// input and output dimensions in the slice statically /// known to have the same number of elements. - } else if (inputShapeSlice[0] == kUnknownSize) { + } else if (inputShapeSlice.size() > 0 && + inputShapeSlice[0] == kUnknownSize) { // If the input is dynamic, assume it is not split checkDimEqualHelper(rewriter, loc, inputSize[inputDim], outputSizeInt[outputDim]); @@ -478,8 +509,14 @@ class ConvertAtenViewOp : public OpConversionPattern { "in `aten.view`"); } - inputAssociations.emplace_back(); - outputAssociations.emplace_back(); + // If one of the slices is empty, this means we are handling + // the case of trailing dimensions, which does not require a + // new reassociation; the trailing dimensions get added to the + // last reassociation created. 
+ if (inputShapeSlice.size() > 0 && outputShapeSlice.size() > 0) { + inputAssociations.emplace_back(); + outputAssociations.emplace_back(); + } for (int64_t inputSliceIndex : inputSliceIndices) inputAssociations.back().push_back(inputSliceIndex + inputDim); for (int64_t outputSliceIndex : outputSliceIndices) diff --git a/python/torch_mlir_e2e_test/test_suite/reshape_like.py b/python/torch_mlir_e2e_test/test_suite/reshape_like.py index 1c2d810c13de..f91b90582d24 100644 --- a/python/torch_mlir_e2e_test/test_suite/reshape_like.py +++ b/python/torch_mlir_e2e_test/test_suite/reshape_like.py @@ -672,6 +672,40 @@ def forward(self, a): def ViewNegativeStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(1, 128)) +class ViewSizeDimFollowedByExpandedOnesModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1], torch.float32, True), + ]) + + def forward(self, a): + return a.view(a.size(0), 1, 1, 1) + +@register_test_case(module_factory=lambda: ViewSizeDimFollowedByExpandedOnesModule()) +def ViewSizeDimFollowedByExpandedOnesModule_basic(module, tu: TestUtils): + module.forward(tu.rand(128)) + +class ViewSizeDimFollowedByCollapsedOnesModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, 1, 1, 1], torch.float32, True), + ]) + + def forward(self, a): + return a.view(a.size(0)) + +@register_test_case(module_factory=lambda: ViewSizeDimFollowedByCollapsedOnesModule()) +def ViewSizeDimFollowedByCollapsedOnesModule_basic(module, tu: TestUtils): + module.forward(tu.rand(128, 1, 1, 1)) + # ============================================================================== class ReshapeAliasExpandModule(torch.nn.Module): @@ -710,4 +744,4 @@ def forward(self, a): @register_test_case(module_factory=lambda: ReshapeAliasCollapseModule()) def ReshapeAliasCollapseModule_basic(module, tu: TestUtils): - module.forward(tu.rand(2, 4)) \ No newline at end of file + module.forward(tu.rand(2, 4)) From 8abfa5b19613114bdf1841b6bc09db7ee4924a2f Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Wed, 27 Sep 2023 22:10:32 +0530 Subject: [PATCH 16/41] Use PyTorch nightly for Arm release build (#2488) The LTC backend has drifted from being able to pass tests on the stable PyTorch version, so pinning to nightly on ARM. Signed-Off By: Vivek Khandelwal --- .github/workflows/buildRelease.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/buildRelease.yml b/.github/workflows/buildRelease.yml index 1af748879e43..1e732f4b3732 100644 --- a/.github/workflows/buildRelease.yml +++ b/.github/workflows/buildRelease.yml @@ -115,7 +115,7 @@ jobs: cd $GITHUB_WORKSPACE TM_PACKAGE_VERSION=${{ github.event.inputs.python_package_version }} printf "TORCH_MLIR_PYTHON_PACKAGE_VERSION=%s\n" $TM_PACKAGE_VERSION > ./torch_mlir_package_version - TM_PYTHON_VERSIONS=${{ matrix.py_version }} TM_PACKAGES=${{ matrix.package }} TM_TORCH_VERSION="stable" ./build_tools/python_deploy/build_linux_packages.sh + TM_PYTHON_VERSIONS=${{ matrix.py_version }} TM_PACKAGES=${{ matrix.package }} ./build_tools/python_deploy/build_linux_packages.sh # If we were given a release_id, then upload the package we just built # to the github releases page. 
From 4e1dd3bf10cc5d3b07251de9593b02f0673e4f1b Mon Sep 17 00:00:00 2001 From: saienduri <77521230+saienduri@users.noreply.github.com> Date: Thu, 28 Sep 2023 10:17:03 -0700 Subject: [PATCH 17/41] add e2e support for torch.log10 (#2479) --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 45 +++++++++++++++++++ .../TorchToLinalg/Uncategorized.cpp | 8 +++- .../Transforms/AbstractInterpLibrary.cpp | 9 ++++ .../build_tools/abstract_interp_lib_gen.py | 8 ++++ .../jit_ir/build_tools/torch_ods_gen.py | 1 + .../test_suite/elementwise.py | 42 +++++++++++++++++ 6 files changed, 111 insertions(+), 2 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 0c96a91b5b49..673fab897be0 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -2527,6 +2527,51 @@ def Torch_AtenLog2_Op : Torch_Op<"aten.log2_", [ }]; } +def Torch_AtenLog10Op : Torch_Op<"aten.log10", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::log10 : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenLog10Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenLog10Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenLog10_Op : Torch_Op<"aten.log10_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::log10_ : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenLog10_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenLog10_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + def Torch_AtenSqrtOp : Torch_Op<"aten.sqrt", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 1d25d22720d2..69634a610e35 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -235,6 +235,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return createCalculationForMathOpWithDtypeConversion( b, converter, payloadArgs[0], op); } + if (isa(op)) { + return createCalculationForMathOpWithDtypeConversion( + b, converter, payloadArgs[0], op); + } if (isa(op)) { return createCalculationForMathOpWithDtypeConversion( b, converter, payloadArgs[0], op); @@ -1177,7 +1181,7 @@ class ConvertElementwiseOp : public ConversionPattern { AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp, AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenPowScalarOp, AtenPowTensorScalarOp, - AtenPowTensorTensorOp, AtenLog2Op, AtenLog1pOp, AtenRsqrtOp, + AtenPowTensorTensorOp, AtenLog2Op, AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, AtenDivScalarOp, AtenRemainderScalarOp, AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp, AtenGtScalarOp, AtenGeScalarOp, @@ -1712,7 +1716,7 @@ void 
mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality( AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp, AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenCeilOp, AtenPreluOp, AtenPowScalarOp, AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op, - AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp, + AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp, AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp, AtenGeTensorOp, diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index de6d287e3443..de99e529ca5d 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -6322,6 +6322,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.log10\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.log1p\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -8291,6 +8295,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" " return %1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.log10\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.log1p\"(%arg0: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py index c2e24e93cf7c..d142770b1810 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py @@ -122,6 +122,9 @@ def aten〇detach〡shape(self: List[int]) -> List[int]: def aten〇log2〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇log10〡shape(self: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇log1p〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) @@ -1438,6 +1441,11 @@ def aten〇log2〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype return _get_dtype_of_floating_point_op(self_dtype) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇log10〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return _get_dtype_of_floating_point_op(self_dtype) + 
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
 def aten〇log1p〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
     self_rank, self_dtype = self_rank_dtype
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index 3dab8eabecb6..4c2b30d817db 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -294,6 +294,7 @@ def emit_with_mutating_variants(key, **kwargs):
         "aten::clamp_max : (Tensor, Scalar) -> (Tensor)",
         "aten::clamp_max.Tensor : (Tensor, Tensor) -> (Tensor)",
         "aten::log2 : (Tensor) -> (Tensor)",
+        "aten::log10 : (Tensor) -> (Tensor)",
         "aten::sqrt : (Tensor) -> (Tensor)",
         "aten::log1p : (Tensor) -> (Tensor)",
         "aten::rsqrt : (Tensor) -> (Tensor)",
diff --git a/python/torch_mlir_e2e_test/test_suite/elementwise.py b/python/torch_mlir_e2e_test/test_suite/elementwise.py
index a2e3e8e29608..a2603bf8e7dd 100644
--- a/python/torch_mlir_e2e_test/test_suite/elementwise.py
+++ b/python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -1683,6 +1683,48 @@ def ElementwiseLog2IntModule_basic(module, tu: TestUtils):
     module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32))


+# ==============================================================================
+
+class ElementwiseLog10Module(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.log10(a)
+
+
+@register_test_case(module_factory=lambda: ElementwiseLog10Module())
+def ElementwiseLog10Module_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4))
+
+
+# ==============================================================================
+
+class ElementwiseLog10IntModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int32, True),
+    ])
+    def forward(self, a):
+        return torch.log10(a)
+
+
+@register_test_case(module_factory=lambda: ElementwiseLog10IntModule())
+def ElementwiseLog10IntModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32))
+
+
 # ==============================================================================

From 860be09a3908a169b4d37801eaba88ad3bf72a5b Mon Sep 17 00:00:00 2001
From: Stella Laurenzo
Date: Fri, 29 Sep 2023 16:45:48 -0700
Subject: [PATCH 18/41] Elide dynamic broadcast checks when in strict symbolic
 shapes mode. (#2496)

When importing dynamically shaped programs from Dynamo, via torch.compile
or torch.export, we can assume that strict symbolic shape checks have
been done prior to generating torch IR. Among other shape checks, this
eliminates the case where an unknown dimension can be dynamically '1'
in a way that signals a broadcast.

Adds an `isAssumingStrictSymbolicShapes` utility which consults a
`torch.assume_strict_symbolic_shapes` attribute on an enclosing scope
and returns true if present. In the linalg pipeline, many runtime checks
are elided when this returns true.
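The lookup itself is simple; a toy Python sketch of the logic (illustrative only, over a made-up op/parent structure; the real utility is the C++ added below) is:

```python
class Op:
    """Toy stand-in for an MLIR operation carrying unit attributes."""

    def __init__(self, attributes=(), parent=None):
        self.attributes = set(attributes)
        self.parent = parent

def is_assuming_strict_symbolic_shapes(op):
    # Walk outward through enclosing scopes; any ancestor carrying the
    # unit attribute puts the whole region into strict mode.
    while op is not None:
        if "torch.assume_strict_symbolic_shapes" in op.attributes:
            return True
        op = op.parent
    return False

module = Op(attributes=["torch.assume_strict_symbolic_shapes"])
func = Op(parent=module)
print(is_assuming_strict_symbolic_shapes(func))  # True
```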
--- .../torch-mlir/Dialect/Torch/Utils/Utils.h | 18 ++++++++ lib/Conversion/TorchToLinalg/DataMovement.cpp | 21 +++++---- .../TorchToLinalg/IndirectDataMovement.cpp | 18 ++++---- lib/Conversion/TorchToLinalg/Linear.cpp | 19 ++++---- .../TorchToLinalg/TensorScalarInterop.cpp | 4 +- .../TorchToLinalg/Uncategorized.cpp | 10 +++-- lib/Conversion/TorchToLinalg/Utils.cpp | 31 +++++++------ .../Torch/Transforms/DecomposeComplexOps.cpp | 43 +++++++++++-------- lib/Dialect/Torch/Utils/Utils.cpp | 9 ++++ test/Conversion/TorchToLinalg/basic.mlir | 17 ++++++-- 10 files changed, 126 insertions(+), 64 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/Utils/Utils.h b/include/torch-mlir/Dialect/Torch/Utils/Utils.h index 37aaed9cd704..f913e70345f4 100644 --- a/include/torch-mlir/Dialect/Torch/Utils/Utils.h +++ b/include/torch-mlir/Dialect/Torch/Utils/Utils.h @@ -86,6 +86,24 @@ FailureOr squeezeTensor(PatternRewriter &rewriter, Operation *op, FailureOr unsqueezeTensor(PatternRewriter &rewriter, Operation *op, Value input, Value dim); +// In Dynamo import paths, we can assume that dynamic dimensions are strictly +// quantities and are not ambiguous with '1' symbols that can be interpreted +// to signal an expansion in various broadcasting scenarios. In the +// torch.compile eager path, this precondition is assured by guards on 0/1 +// dimension values, and on the torch.export graph-capture path, the shape +// solver guarantees this. +// +// We let lowerings assume this on a per-scope basis if the +// torch.assume_strict_symbolic_shapes unit attribute is present on any parent +// of the block. +bool isAssumingStrictSymbolicShapes(Block *scope); + +// Helper that uses the block from an OpBuilder for determining whether we +// are assuming strict symbolic shapes. +inline bool isAssumingStrictSymbolicShapes(OpBuilder &builder) { + return isAssumingStrictSymbolicShapes(builder.getBlock()); +} + } // namespace Torch } // namespace torch } // namespace mlir diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index 1f54af1f6589..74b7badd8253 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -656,20 +656,23 @@ class ConvertAtenSqueezeOp : public OpConversionPattern { reassociation[0].push_back(headOnesCount++); } - // TODO: Add support for size-1 dynamic dimensions. Value one = rewriter.create( loc, rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); int64_t j = -1; + bool elideDynamicBroadcastDimCheck = + isAssumingStrictSymbolicShapes(rewriter); for (auto i : llvm::seq(headOnesCount, inputRank)) { if (inputType.isDynamicDim(i)) { - // Make sure that size-1 dynamic dimension does not exist. - Value dimSize = getDimOp(rewriter, loc, input, i); - Value dimSizeNotOne = rewriter.create( - loc, arith::CmpIPredicate::ne, dimSize, one); - rewriter.create( - loc, dimSizeNotOne, - rewriter.getStringAttr( - "unimplemented: size 1 dynamic dimension is not supported")); + if (!elideDynamicBroadcastDimCheck) { + // Make sure that size-1 dynamic dimension does not exist. 
+ Value dimSize = getDimOp(rewriter, loc, input, i); + Value dimSizeNotOne = rewriter.create( + loc, arith::CmpIPredicate::ne, dimSize, one); + rewriter.create( + loc, dimSizeNotOne, + rewriter.getStringAttr( + "unimplemented: size 1 dynamic dimension is not supported")); + } ++j; } else if (inputType.getDimSize(i) != 1) { ++j; diff --git a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp index cfbac2632a28..0e89d822669f 100644 --- a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp @@ -644,14 +644,16 @@ class ConvertAtenIndexTensorHackedTwinOp : public OpConversionPattern 1) { - Value cstStaticDimSize = getConstant(rewriter, loc, staticDimSize, - rewriter.getIndexType()); - auto equalToRunning = rewriter.create( - loc, arith::CmpIPredicate::eq, cstStaticDimSize, - dynamicDims[0]); - rewriter.create(loc, equalToRunning, - "mismatched size for broadcast"); + if (!isAssumingStrictSymbolicShapes(rewriter)) { + if (staticDimSize > 1) { + Value cstStaticDimSize = getConstant(rewriter, loc, staticDimSize, + rewriter.getIndexType()); + auto equalToRunning = rewriter.create( + loc, arith::CmpIPredicate::eq, cstStaticDimSize, + dynamicDims[0]); + rewriter.create(loc, equalToRunning, + "mismatched size for broadcast"); + } } broadcastedIndexShape.push_back(dynamicDims[0]); } else { diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index 23528bb01f80..66380dea9a89 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -58,15 +58,18 @@ class ConvertAtenMmOp : public OpConversionPattern { } Value lhsDim0 = rewriter.create(loc, lhs, 0); - Value lhsDim1 = rewriter.create(loc, lhs, 1); - Value rhsDim0 = rewriter.create(loc, rhs, 0); Value rhsDim1 = rewriter.create(loc, rhs, 1); - Value contractingDimEqual = rewriter.create( - loc, arith::CmpIPredicate::eq, lhsDim1, rhsDim0); - rewriter.create( - loc, contractingDimEqual, - rewriter.getStringAttr( - "mismatching contracting dimension for torch.aten.mm")); + + if (!isAssumingStrictSymbolicShapes(rewriter)) { + Value lhsDim1 = rewriter.create(loc, lhs, 1); + Value rhsDim0 = rewriter.create(loc, rhs, 0); + Value contractingDimEqual = rewriter.create( + loc, arith::CmpIPredicate::eq, lhsDim1, rhsDim0); + rewriter.create( + loc, contractingDimEqual, + rewriter.getStringAttr( + "mismatching contracting dimension for torch.aten.mm")); + } Type newResultType = getTypeConverter()->convertType(op.getType()); Type elementType = newResultType.cast().getElementType(); diff --git a/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp b/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp index 262d3cf62e54..a1e8e5fb72d9 100644 --- a/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp +++ b/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp @@ -42,7 +42,9 @@ class ConvertAtenSizeIntOp : public OpConversionPattern { Value inputRank = rewriter.create( loc, rewriter.getI64IntegerAttr(type.getRank())); Value dimPositive = toPositiveDimDynamic(rewriter, loc, dim, inputRank); - assertIsValidDim(rewriter, loc, dimPositive, inputRank); + if (!isAssumingStrictSymbolicShapes(rewriter)) { + assertIsValidDim(rewriter, loc, dimPositive, inputRank); + } Value size = rewriter.create( loc, adaptor.getSelf(), castIntToIndex(rewriter, loc, dimPositive)); rewriter.replaceOp(op, castIndexToInt64(rewriter, loc, size)); diff --git 
a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 69634a610e35..082680b3ecb5 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -1481,10 +1481,12 @@ class ConvertAtenBatchNormOp : public OpConversionPattern { rewriter.getStringAttr( "expect the size of dim 0 equal to the number of features")); }; - contractingDim0EqualsNumFeatures(weight); - contractingDim0EqualsNumFeatures(bias); - contractingDim0EqualsNumFeatures(runningMean); - contractingDim0EqualsNumFeatures(runningVar); + if (!isAssumingStrictSymbolicShapes(rewriter)) { + contractingDim0EqualsNumFeatures(weight); + contractingDim0EqualsNumFeatures(bias); + contractingDim0EqualsNumFeatures(runningMean); + contractingDim0EqualsNumFeatures(runningVar); + } auto indexingMap = AffineMap::get( /*dimCount=*/inputRank, diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp index 42c5d0b441cc..99b86027b8e5 100644 --- a/lib/Conversion/TorchToLinalg/Utils.cpp +++ b/lib/Conversion/TorchToLinalg/Utils.cpp @@ -231,7 +231,8 @@ Value torch_to_linalg::createElementwiseLinalgGeneric( // if this is the first tensor operand that didn't continue above: // take its dimension size as the size of the non-broadcasted // traversal along this dimension (this may include a dynamic size-1, - // **non-broadcasted** traversal!) + // **non-broadcasted** traversal unless if + // isAssumingStrictSymbolicShapes!) // emit error check "if the size does not match the non-broadcasted // traversal size along this dimension, error" // ``` @@ -251,6 +252,7 @@ Value torch_to_linalg::createElementwiseLinalgGeneric( auto c1 = b.create(loc, /*value=*/1); SmallVector resultShape(resultRank, c1); SmallVector indexingMaps; + bool elideDynamicBroadcastCheck = isAssumingStrictSymbolicShapes(b); for (Value tensorOperand : tensorOperands) { SmallVector exprs; auto type = tensorOperand.getType().cast(); @@ -294,11 +296,13 @@ Value torch_to_linalg::createElementwiseLinalgGeneric( // This is the check which protects against the undefined behavior of // the generated linalg op in the case of iterating two operands with // dimensions sizes that are expected to match. - auto equalToRunning = - b.create(loc, arith::CmpIPredicate::eq, - resultShape[resultDim], currentDimSize); - b.create(loc, equalToRunning, - "mismatched size for broadcast"); + if (!elideDynamicBroadcastCheck) { + auto equalToRunning = + b.create(loc, arith::CmpIPredicate::eq, + resultShape[resultDim], currentDimSize); + b.create(loc, equalToRunning, + "mismatched size for broadcast"); + } } indexingMaps.push_back(AffineMap::get( /*dimCount=*/resultRank, /*symbolCount=*/0, exprs, b.getContext())); @@ -337,6 +341,7 @@ LogicalResult torch_to_linalg::broadcastToGivenShape( Type elementType = inputType.getElementType(); Location loc = op->getLoc(); SmallVector outShape; + bool elideDynamicBroadcastCheck = isAssumingStrictSymbolicShapes(rewriter); // Create affine map and shapes for tensor initialization. 
SmallVector outExpr; @@ -351,12 +356,14 @@ LogicalResult torch_to_linalg::broadcastToGivenShape( Value shapeValue = broadcastToShape[i]; size_t j = i - diff; if (i < diff) { - Value isValid = rewriter.create( - loc, arith::CmpIPredicate::sge, shapeValue, zero); - rewriter.create( - loc, isValid, - rewriter.getStringAttr( - "negative values not allowed in new dimensions")); + if (!elideDynamicBroadcastCheck) { + Value isValid = rewriter.create( + loc, arith::CmpIPredicate::sge, shapeValue, zero); + rewriter.create( + loc, isValid, + rewriter.getStringAttr( + "negative values not allowed in new dimensions")); + } outShape.push_back(castIntToIndex(rewriter, loc, shapeValue)); continue; } diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 6136db09221d..0bdfca26ddc1 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -3484,11 +3484,13 @@ class DecomposeAtenAdaptiveAvgPool1dOp : rewriter.create( loc, rewriter.getI64IntegerAttr(inputShape[rank - 1]))); } else { - Value cond = rewriter.create(loc, inputSize, outputSize); - rewriter.create( - loc, cond, - "unimplemented: only support cases where input and output size are " - "equal for non-unit output size"); + if (!isAssumingStrictSymbolicShapes(rewriter)) { + Value cond = rewriter.create(loc, inputSize, outputSize); + rewriter.create( + loc, cond, + "unimplemented: only support cases where input and output size are " + "equal for non-unit output size"); + } kernelSize.push_back(constantOne); } @@ -3586,13 +3588,14 @@ class DecomposeAtenAdaptiveAvgPool2dOp loc, rewriter.getI64IntegerAttr( inputShape[rank - 2 + i]))); } else { - Value cond = rewriter.create(loc, inputHW[i], - outputShapeSizesTorchInt[i]); - rewriter.create( - loc, cond, - "unimplemented: only support cases where input and output size are " - "equal for non-unit output size"); - + if (!isAssumingStrictSymbolicShapes(rewriter)) { + Value cond = rewriter.create( + loc, inputHW[i], outputShapeSizesTorchInt[i]); + rewriter.create(loc, cond, + "unimplemented: only support cases " + "where input and output size are " + "equal for non-unit output size"); + } Value outMinusOne = rewriter.create( loc, outputShapeSizesTorchInt[i], constantOne); kernelSize.push_back( @@ -3822,13 +3825,15 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter, loc, rewriter.getF64FloatAttr(correction)); // The `correction` value should be less than or equal to `productDimSize + // 1`. 
- Value productDimSizePlusOne = rewriter.create( - loc, productDimSize.getType(), productDimSize, constantOne); - Value cond = - rewriter.create(loc, productDimSizePlusOne, cstCorrection); - rewriter.create( - loc, cond, - "correction value should be less than or equal to productDimSize + 1"); + if (!isAssumingStrictSymbolicShapes(rewriter)) { + Value productDimSizePlusOne = rewriter.create( + loc, productDimSize.getType(), productDimSize, constantOne); + Value cond = rewriter.create(loc, productDimSizePlusOne, + cstCorrection); + rewriter.create( + loc, cond, + "correction value should be less than or equal to productDimSize + 1"); + } Value productDimSizeSubCorrection = rewriter.create(loc, productDimSize, cstCorrection); Value result = rewriter.create(loc, newOutputType, squareSum, diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp index 10c4bea67dc0..5de777763ea5 100644 --- a/lib/Dialect/Torch/Utils/Utils.cpp +++ b/lib/Dialect/Torch/Utils/Utils.cpp @@ -324,3 +324,12 @@ FailureOr Torch::unsqueezeTensor(PatternRewriter &rewriter, op->getLoc(), unsqueezedType, input, dim); return unsqueezed; } + +bool Torch::isAssumingStrictSymbolicShapes(Block *block) { + for (Operation *parentOp = block->getParentOp(); parentOp; + parentOp = parentOp->getParentOp()) { + if (parentOp->hasAttr("torch.assume_strict_symbolic_shapes")) + return true; + } + return false; +} diff --git a/test/Conversion/TorchToLinalg/basic.mlir b/test/Conversion/TorchToLinalg/basic.mlir index d95b7e1d87cf..470962e2494d 100644 --- a/test/Conversion/TorchToLinalg/basic.mlir +++ b/test/Conversion/TorchToLinalg/basic.mlir @@ -8,11 +8,11 @@ // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[LHS_DIM_0:.*]] = tensor.dim %[[LHS]], %[[C0]] : tensor // CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[RHS_DIM_1:.*]] = tensor.dim %[[RHS]], %[[C1]] : tensor +// CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK: %[[LHS_DIM_1:.*]] = tensor.dim %[[LHS]], %[[C1]] : tensor // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[RHS_DIM_0:.*]] = tensor.dim %[[RHS]], %[[C0]] : tensor -// CHECK: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[RHS_DIM_1:.*]] = tensor.dim %[[RHS]], %[[C1]] : tensor // CHECK: %[[EQ:.*]] = arith.cmpi eq, %[[LHS_DIM_1]], %[[RHS_DIM_0]] : index // CHECK: assert %[[EQ]], "mismatching contracting dimension for torch.aten.mm" // CHECK: %[[INIT_TENSOR:.*]] = tensor.empty(%[[LHS_DIM_0]], %[[RHS_DIM_1]]) : tensor @@ -29,6 +29,17 @@ func.func @torch.aten.mm$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.v // ----- +// CHECK-LABEL: func.func @torch.aten.mm$basic_strict( +// CHECK-NOT: assert +func.func @torch.aten.mm$basic_strict(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,2],f32> + attributes {torch.assume_strict_symbolic_shapes} +{ + %0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,2],f32> + return %0 : !torch.vtensor<[?,2],f32> +} + +// ----- + // If the operands are missing dtype, we cannot lower it. 
func.func @torch.aten.mm$no_convert$missing_dtype(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor { // expected-error@+1 {{failed to legalize}} @@ -264,4 +275,4 @@ func.func @torch.aten.neg.bf16(%arg0: !torch.vtensor<[?,?],bf16>) -> !torch.vten func.func @torch.aten.neg.f16(%arg0: !torch.vtensor<[?,?],f16>) -> !torch.vtensor<[?,?],f16> { %0 = torch.aten.neg %arg0 : !torch.vtensor<[?,?],f16> -> !torch.vtensor<[?,?],f16> return %0 : !torch.vtensor<[?,?],f16> -} \ No newline at end of file +} From 71ac62f3a89f751a2750e922757789ff0cff489e Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Fri, 29 Sep 2023 05:09:31 +0000 Subject: [PATCH 19/41] build: manually update PyTorch version Set PyTorch and TorchVision version to nightly release 2023-09-28. aten.baddbmm changes done because upstream PyTorch has now added support for fp16 gemm on CPU. Refer: https://github.com/pytorch/pytorch/commit/9399e0b1ff743d2c968196a3be129307ac360823 --- .../Transforms/AbstractInterpLibrary.cpp | 25 ++++++++----------- .../build_tools/abstract_interp_lib_gen.py | 6 ++--- pytorch-hash.txt | 2 +- pytorch-requirements.txt | 2 +- torchvision-requirements.txt | 2 +- 5 files changed, 16 insertions(+), 21 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index de99e529ca5d..2fffbd313927 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -9950,39 +9950,34 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " func.func @\"__torch_mlir_dtype_fn.aten.baddbmm\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.number, %arg4: !torch.number) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" -" %int5 = torch.constant.int 5\n" " %int11 = torch.constant.int 11\n" " %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple -> !torch.int, !torch.int\n" -" %2 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list\n" -" %3 = torch.aten.__contains__.int_list %2, %0#1 : !torch.list, !torch.int -> !torch.bool\n" -" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" -" torch.prim.If %4 -> () {\n" +" %2 = torch.aten.__isnot__ %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %5 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list\n" -" %6 = torch.aten.__contains__.int_list %5, %1#1 : !torch.list, !torch.int -> !torch.bool\n" -" %7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool\n" -" torch.prim.If %7 -> () {\n" +" %3 = torch.aten.__isnot__ %1#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %8 = torch.aten.eq.int %0#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %8 -> () {\n" +" %4 = torch.aten.eq.int %0#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %4 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %9 = 
torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" -" %10 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %11 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%9, %10) : (!torch.list>, !torch.list) -> !torch.int\n" -" return %11 : !torch.int\n" +" %5 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %6 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %7 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%5, %6) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %7 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.where.self\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py index d142770b1810..7c507f53b70e 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py @@ -2822,7 +2822,7 @@ def aten〇remainder〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: U # TODO: This should be fixed by switching to FakeTensor instead of Meta tensor @check_dtype_function( - _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1, 1), (1, 1, 1), (1, 1, 1)], tensor_device="cpu", error_types={torch.bool, torch.float16}) + + _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1, 1), (1, 1, 1), (1, 1, 1)], tensor_device="cpu", error_types={torch.bool}) + [ErrorInvocation(TensorOfShape( 1, 1, 1, dtype=torch.float64, device="cpu"), TensorOfShape(1, 1, 1, dtype=torch.int16, device="cpu"), TensorOfShape(1, 1, 1, dtype=torch.int32, device="cpu")), ErrorInvocation( @@ -2834,8 +2834,8 @@ def aten〇remainder〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: U def aten〇baddbmm〡dtype(self_rank_dtype: Tuple[int, int], batch1_rank_dtype: Tuple[int, int], batch2_rank_dtype: Tuple[int, int], beta: Union[int, float, complex] = 1, alpha: Union[int, float, complex] = 1) -> int: batch1_rank, batch1_dtype = batch1_rank_dtype batch2_rank, batch2_dtype = batch2_rank_dtype - assert batch1_dtype not in [torch.bool, torch.float16] - assert batch2_dtype not in [torch.bool, torch.float16] + assert batch1_dtype is not torch.bool + assert batch2_dtype is not torch.bool assert batch1_dtype == batch2_dtype ranks: List[Optional[int]] = [batch1_rank, batch2_rank] dtypes = [batch1_dtype, batch2_dtype] diff --git a/pytorch-hash.txt b/pytorch-hash.txt index fb206a10a32d..c0a7e8b3e824 100644 --- a/pytorch-hash.txt +++ b/pytorch-hash.txt @@ -1 +1 @@ -d7520d8668dc08f7bed27a64f006c909006e653a +fecde478ac83edf78e7d0e9d11ab73cb1580f6cf diff --git a/pytorch-requirements.txt b/pytorch-requirements.txt index 3dfe294db805..4e3a9152f677 100644 --- a/pytorch-requirements.txt +++ b/pytorch-requirements.txt @@ -1,3 +1,3 @@ -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre -torch==2.2.0.dev20230927 +torch==2.2.0.dev20230928 diff --git a/torchvision-requirements.txt b/torchvision-requirements.txt index ae783e410baa..9c653bbcd577 100644 --- a/torchvision-requirements.txt +++ b/torchvision-requirements.txt @@ -1,3 +1,3 @@ -f 
https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre -torchvision==0.17.0.dev20230927 +torchvision==0.17.0.dev20230928 From c434736ee9289cf4f0eb4f6a530eebfb56e9e16e Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Fri, 29 Sep 2023 12:19:18 +0000 Subject: [PATCH 20/41] [MLIR][TORCH] Add support for conversion to int8 dtype Signed-Off By: Vivek Khandelwal --- e2e_testing/xfail_sets.py | 3 ++ include/torch-mlir/Conversion/Utils/Utils.h | 3 +- .../TorchToLinalg/Uncategorized.cpp | 18 ++++++++- lib/Conversion/Utils/Utils.cpp | 23 +++++++---- .../test_suite/__init__.py | 1 + .../test_suite/elementwise.py | 38 +++++++++++++++++++ 6 files changed, 76 insertions(+), 10 deletions(-) diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py index 6446a085ff67..8b237cf50b43 100644 --- a/e2e_testing/xfail_sets.py +++ b/e2e_testing/xfail_sets.py @@ -288,6 +288,9 @@ # AssertionError: Unregistered operation: torch.aten._embedding_bag_forward_only "AtenEmbeddingBagStaticModule_basic", + + # Lowering not present for this case + "ElementwiseToDtypeI64ToUI8Module_basic", } if torch_version_for_comparison() < version.parse("2.1.0.dev"): diff --git a/include/torch-mlir/Conversion/Utils/Utils.h b/include/torch-mlir/Conversion/Utils/Utils.h index 8795974a395c..d561c2101173 100644 --- a/include/torch-mlir/Conversion/Utils/Utils.h +++ b/include/torch-mlir/Conversion/Utils/Utils.h @@ -87,7 +87,8 @@ mlir::RankedTensorType GetTypeFromTensorShape(llvm::ArrayRef shape, // from a tensor or a scalar in the pytorch dialect. Both the scalar and dtype // should be converted builtin types. Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype, - std::optional srcOriginalDtype = std::nullopt); + std::optional srcOriginalDtype = std::nullopt, + std::optional dstOriginalDtype = std::nullopt); Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc, Value torchOptionalInt, Value builtinInt, diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 082680b3ecb5..8a6366990b94 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -988,7 +988,23 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Type dtype = converter->convertType(atenToDtype.getType()) .cast() .getElementType(); - Value result = convertScalarToDtype(b, loc, input, dtype); + Type resultElementType; + int64_t dtypeInt; + if (!matchPattern(atenToDtype.getDtype(), m_TorchConstantInt(&dtypeInt))) { + atenToDtype.emitError("unimplemented: dtype must be a constant integer"); + return nullptr; + } + FailureOr maybeResultElementType = getTypeForScalarType( + atenToDtype->getContext(), (torch_upstream::ScalarType)dtypeInt, + IntegerType::Signless); + if (failed(maybeResultElementType)) { + atenToDtype.emitError("unable to convert `dtypeInt` to builtin type"); + return nullptr; + } + resultElementType = *maybeResultElementType; + Value result = convertScalarToDtype(b, loc, input, dtype, + /*srcOriginalDtype=*/std::nullopt, + /*dstOriginalDtype=*/resultElementType); return result; } if (auto divScalar = dyn_cast(op)) { diff --git a/lib/Conversion/Utils/Utils.cpp b/lib/Conversion/Utils/Utils.cpp index c192ff33a25f..89b17a50be99 100644 --- a/lib/Conversion/Utils/Utils.cpp +++ b/lib/Conversion/Utils/Utils.cpp @@ -249,7 +249,8 @@ mlir::RankedTensorType GetTypeFromTensorShape(llvm::ArrayRef shape, // from a tensor or a scalar in the pytorch dialect. 
Both the scalar and dtype // should be converted builtin types. Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype, - std::optional srcOriginalDtype) { + std::optional srcOriginalDtype, + std::optional dstOriginalDtype) { Type scalarType = scalar.getType(); if (scalarType == dtype) return scalar; @@ -261,14 +262,20 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype, return false; }; - // We only support conversion from Byte or Char scalarType not to Byte or Char - // dtype. + // We don't support conversion to Byte dtype. if (isByteOrChar(dtype)) { - mlir::emitError(loc) << "unsupported: conversion to byte or char type for " - "convertScalarToDtype " - << scalarType << "(scalar type) -> " << dtype - << "(dtype)"; - return nullptr; + if (!dstOriginalDtype.has_value()) { + mlir::emitError(loc) + << "unimplemented: for conversion to byte or char type " + "dstOriginalDtype has to be passed to convertScalarToDtype"; + return nullptr; + } + if (dstOriginalDtype->isUnsignedInteger()) { + mlir::emitError(loc) + << "unsupported: conversion to byte type for convertScalarToDtype " + << scalarType << "(scalar type) -> " << dtype << "(dtype)"; + return nullptr; + } } // If the dtype is i1, i.e., a boolean type. diff --git a/python/torch_mlir_e2e_test/test_suite/__init__.py b/python/torch_mlir_e2e_test/test_suite/__init__.py index 7add08a3ecee..dce83de277ad 100644 --- a/python/torch_mlir_e2e_test/test_suite/__init__.py +++ b/python/torch_mlir_e2e_test/test_suite/__init__.py @@ -14,6 +14,7 @@ "NativeGroupNormBackwardModule_basic", "QuantizedMLP_basic", "ReduceMaxAlongDimUnsignedInt_basic", + "ElementwiseToDtypeI64ToUI8Module_basic", } # TODO: Delete once torch 2.1.0 is released diff --git a/python/torch_mlir_e2e_test/test_suite/elementwise.py b/python/torch_mlir_e2e_test/test_suite/elementwise.py index a2603bf8e7dd..2df0a5513d4a 100644 --- a/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -1642,6 +1642,44 @@ def ElementwiseToDtypeIdentityModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseToDtypeI64ToI8Module(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([-1, -1], torch.int64, True)]) + def forward(self, x): + return x.to(torch.int8) + + +@register_test_case(module_factory=lambda: ElementwiseToDtypeI64ToI8Module()) +def ElementwiseToDtypeI64ToI8Module_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=-100, high=100)) + + +# ============================================================================== + + +class ElementwiseToDtypeI64ToUI8Module(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([-1, -1], torch.int64, True)]) + def forward(self, x): + return x.to(torch.uint8) + + +@register_test_case(module_factory=lambda: ElementwiseToDtypeI64ToUI8Module()) +def ElementwiseToDtypeI64ToUI8Module_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=-100, high=100)) + + +# ============================================================================== + + class ElementwiseLog2Module(torch.nn.Module): def __init__(self): From 9293326e1eb09a99cdec4ff7e08bfa1b47fbcd5f Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Thu, 28 Sep 2023 12:53:02 +0000 Subject: [PATCH 21/41] [MLIR][TORCH] Add support for bitwise_right_shift and bitwise_and.Scalar op Signed-Off By: Vivek
Khandelwal --- e2e_testing/xfail_sets.py | 2 + .../Dialect/Torch/IR/GeneratedTorchOps.td | 94 ++++++++++++++++ .../TorchToLinalg/Uncategorized.cpp | 68 +++++++++--- .../Transforms/AbstractInterpLibrary.cpp | 25 +++++ .../build_tools/abstract_interp_lib_gen.py | 22 ++++ .../jit_ir/build_tools/torch_ods_gen.py | 2 + .../test_suite/elementwise.py | 104 ++++++++++++++++++ 7 files changed, 299 insertions(+), 18 deletions(-) diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py index 8b237cf50b43..15cf9bb79f91 100644 --- a/e2e_testing/xfail_sets.py +++ b/e2e_testing/xfail_sets.py @@ -1421,4 +1421,6 @@ "UniformStaticShapeModule_basic", "AtenEmbeddingBagStaticModule_basic", "EmptyStridedModule_basic", + "ElementwiseBitwiseAndScalarInt64Module_basic", + "ElementwiseBitwiseAndScalarInt32Module_basic", } diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 673fab897be0..4ecad92c662b 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -2844,6 +2844,53 @@ def Torch_AtenBitwiseAnd_TensorOp : Torch_Op<"aten.bitwise_and_.Tensor", [ }]; } +def Torch_AtenBitwiseAndScalarOp : Torch_Op<"aten.bitwise_and.Scalar", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::bitwise_and.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenBitwiseAndScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenBitwiseAndScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenBitwiseAnd_ScalarOp : Torch_Op<"aten.bitwise_and_.Scalar", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::bitwise_and_.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenBitwiseAnd_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenBitwiseAnd_ScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_AtenBitwiseOrTensorOp : Torch_Op<"aten.bitwise_or.Tensor", [ AllowsTypeRefinement, HasValueSemantics, @@ -2938,6 +2985,53 @@ def Torch_AtenBitwiseXor_TensorOp : Torch_Op<"aten.bitwise_xor_.Tensor", [ }]; } +def Torch_AtenBitwiseRightShiftTensorOp : Torch_Op<"aten.bitwise_right_shift.Tensor", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::bitwise_right_shift.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenBitwiseRightShiftTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenBitwiseRightShiftTensorOp::print(OpAsmPrinter &printer) { + 
printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenBitwiseRightShift_TensorOp : Torch_Op<"aten.bitwise_right_shift_.Tensor", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::bitwise_right_shift_.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenBitwiseRightShift_TensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenBitwiseRightShift_TensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_AtenThresholdOp : Torch_Op<"aten.threshold", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 8a6366990b94..b47e13c8619e 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -300,6 +300,19 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); return b.create(loc, lhs, rhs); } + if (auto bitwiseAndScalar = dyn_cast(op)) { + Type dtype = converter->convertType(bitwiseAndScalar.getType()) + .cast() + .getElementType(); + if (!dtype.isa()) { + bitwiseAndScalar.emitError( + "bitwise_and.Scalar does not support non-integer input dtype."); + return nullptr; + } + Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype); + Value other = convertScalarToDtype(b, loc, operands[1], dtype); + return b.create(loc, self, other); + } if (auto bitwiseOrTensor = dyn_cast(op)) { if (bitwiseOrTensor.getType() .cast() @@ -332,6 +345,20 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); return b.create(loc, lhs, rhs); } + if (auto bitwiseRightShiftTensor = + dyn_cast(op)) { + Type dtype = converter->convertType(bitwiseRightShiftTensor.getType()) + .cast() + .getElementType(); + if (!dtype.isa()) { + bitwiseRightShiftTensor.emitError( + "Bitwise_Right_Shift op does not support non-integer input dtype."); + return nullptr; + } + Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); + Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); + return b.create(loc, lhs, rhs); + } if (isa(op)) { MLIRContext *context = op->getContext(); Type floatDtype = mlir::FloatType::getF64(context); @@ -571,7 +598,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); if (dtype.isa()) { return b.create(loc, lhs, rhs); - } else if(dtype.isa()) { + } else if (dtype.isa()) { return b.create(loc, lhs, rhs); } else { return b.create(loc, lhs, rhs); @@ -1066,7 +1093,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp( .getElementType(); Value self = payloadArgs[0]; - Value threshold = convertScalarToDtype(b, loc, adaptor.getThreshold(), dtype); + Value threshold = + convertScalarToDtype(b, loc, adaptor.getThreshold(), dtype); Value value = convertScalarToDtype(b, loc, adaptor.getValue(), dtype); Value predicate; @@ -1088,7 +1116,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Value grad = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value self = convertScalarToDtype(b, 
loc, payloadArgs[1], dtype); - Value threshold = convertScalarToDtype(b, loc, adaptor.getThreshold(), dtype); + Value threshold = + convertScalarToDtype(b, loc, adaptor.getThreshold(), dtype); Value constantZero = b.create(loc, b.getZeroAttr(dtype)); Value predicate; @@ -1197,10 +1226,11 @@ class ConvertElementwiseOp : public ConversionPattern { AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp, AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenPowScalarOp, AtenPowTensorScalarOp, - AtenPowTensorTensorOp, AtenLog2Op, AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, - AtenDivScalarOp, AtenRemainderScalarOp, AtenAbsOp, - AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseOrTensorOp, - AtenBitwiseXorTensorOp, AtenGtScalarOp, AtenGeScalarOp, + AtenPowTensorTensorOp, AtenLog2Op, AtenLog10Op, AtenLog1pOp, + AtenRsqrtOp, AtenDivScalarOp, AtenRemainderScalarOp, AtenAbsOp, + AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp, + AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp, + AtenBitwiseRightShiftTensorOp, AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp, AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp, AtenSubScalarOp, @@ -1699,7 +1729,8 @@ class ConvertAtenDetachOp : public OpConversionPattern { return failure(); Type resultType = getTypeConverter()->convertType(op.getType()); - rewriter.replaceOpWithNewOp(op, resultType, adaptor.getSelf()); + rewriter.replaceOpWithNewOp(op, resultType, + adaptor.getSelf()); return success(); } }; @@ -1735,16 +1766,17 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality( AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenCeilOp, AtenPreluOp, AtenPowScalarOp, AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op, AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp, - AtenBitwiseAndTensorOp, AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp, - AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, - AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp, AtenGeTensorOp, - AtenEqTensorOp, AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp, - AtenThresholdOp, AtenThresholdBackwardOp, AtenHardtanhBackwardOp, - AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillTensorOp, - AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp, AtenLogicalXorOp, - AtenLogicalNotOp, AtenTriuOp, AtenTrilOp, AtenRemainderScalarOp, - AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp, - AtenRealOp, AtenImagOp>(); + AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp, + AtenBitwiseXorTensorOp, AtenBitwiseRightShiftTensorOp, AtenGtScalarOp, + AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, + AtenWhereSelfOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp, + AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp, AtenThresholdOp, + AtenThresholdBackwardOp, AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, + AtenCosOp, AtenNeScalarOp, AtenMaskedFillTensorOp, AtenLogicalOrOp, + AtenLogicalAndOp, AtenAtanOp, AtenLogicalXorOp, AtenLogicalNotOp, + AtenTriuOp, AtenTrilOp, AtenRemainderScalarOp, AtenBitwiseNotOp, + AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp, AtenRealOp, + AtenImagOp>(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 2fffbd313927..e8f5aa568f59 100644 --- 
a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -7410,10 +7410,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.bitwise_and.Scalar\"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.bitwise_xor.Tensor\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.bitwise_right_shift.Tensor\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.bitwise_not\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -9201,6 +9209,15 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.bitwise_and.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" +" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.bitwise_or.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" @@ -9217,6 +9234,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.bitwise_right_shift.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call 
@__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.bmm\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py index 7c507f53b70e..f74895d9dad6 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py @@ -796,9 +796,15 @@ def aten〇bitwise_or〇Tensor〡shape(self: List[int], other: List[int]) -> Lis def aten〇bitwise_and〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]: return upstream_shape_functions.broadcast(self, other) +def aten〇bitwise_and〇Scalar〡shape(self: List[int], other: float) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇bitwise_xor〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]: return upstream_shape_functions.broadcast(self, other) +def aten〇bitwise_right_shift〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]: + return upstream_shape_functions.broadcast(self, other) + def aten〇bitwise_not〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) @@ -2265,6 +2271,14 @@ def aten〇bitwise_and〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_ dtypes = [self_dtype, other_dtype] return promote_dtypes(ranks, dtypes) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1.0)) +def aten〇bitwise_and〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int: + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, None] + dtypes = [self_dtype, get_dtype_of_scalar(other)] + return promote_dtypes(ranks, dtypes) + @check_dtype_function(_check_two_tensor_op()) def aten〇bitwise_or〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: other_rank, other_dtype = other_rank_dtype @@ -2281,6 +2295,14 @@ def aten〇bitwise_xor〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_ dtypes = [self_dtype, other_dtype] return promote_dtypes(ranks, dtypes) +@check_dtype_function(_check_two_tensor_op()) +def aten〇bitwise_right_shift〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + return promote_dtypes(ranks, dtypes) + @check_dtype_function( _check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 4), (2, 4, 3)]) + # Different width diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index 4c2b30d817db..56d18d3847d1 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -301,8 +301,10 @@ def 
emit_with_mutating_variants(key, **kwargs): "aten::abs : (Tensor) -> (Tensor)", "aten::reciprocal : (Tensor) -> (Tensor)", "aten::bitwise_and.Tensor : (Tensor, Tensor) -> (Tensor)", + "aten::bitwise_and.Scalar : (Tensor, Scalar) -> (Tensor)", "aten::bitwise_or.Tensor : (Tensor, Tensor) -> (Tensor)", "aten::bitwise_xor.Tensor : (Tensor, Tensor) -> (Tensor)", + "aten::bitwise_right_shift.Tensor : (Tensor, Tensor) -> (Tensor)", "aten::threshold : (Tensor, Scalar, Scalar) -> (Tensor)", "aten::square : (Tensor) -> (Tensor)", "aten::unsqueeze : (Tensor, int) -> (Tensor)", diff --git a/python/torch_mlir_e2e_test/test_suite/elementwise.py b/python/torch_mlir_e2e_test/test_suite/elementwise.py index 2df0a5513d4a..a0137f23e71b 100644 --- a/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -3515,3 +3515,107 @@ def forward(self, a, b): @register_test_case(module_factory=lambda: TupleModule()) def TupleModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 2), tu.rand(2, 2)) + + +# ============================================================================== + + +class ElementwiseBitwiseRightShiftInt64Module(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int64, True), + ([-1, -1], torch.int64, True), + ]) + def forward(self, lhs, rhs): + return torch.bitwise_right_shift(lhs, rhs) + + +@register_test_case(module_factory=lambda: ElementwiseBitwiseRightShiftInt64Module()) +def ElementwiseBitwiseRightShiftInt64Module_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=-1000, high=1000), tu.randint(3, 4, low=0, high=64)) + + +class ElementwiseBitwiseRightShiftInt32Module(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, 4], torch.int32, True), + ([-1, 1], torch.int32, True), + ]) + def forward(self, lhs, rhs): + return torch.bitwise_right_shift(lhs, rhs) + + +@register_test_case(module_factory=lambda: ElementwiseBitwiseRightShiftInt32Module()) +def ElementwiseBitwiseRightShiftInt32Module_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=-1000, high=1000).to(torch.int32), tu.randint(3, 1, low=0, high=32).to(torch.int32)) + + +class ElementwiseBitwiseRightShiftInt8Module(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int8, True), + ([-1, -1], torch.int8, True), + ]) + def forward(self, lhs, rhs): + return torch.bitwise_right_shift(lhs, rhs) + + +@register_test_case(module_factory=lambda: ElementwiseBitwiseRightShiftInt8Module()) +def ElementwiseBitwiseRightShiftInt8Module_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=-100, high=100).to(torch.int8), tu.randint(3, 4, low=0, high=8).to(torch.int8)) + + +# ============================================================================== + + +class ElementwiseBitwiseAndScalarInt64Module(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int64, True), + ]) + def forward(self, x): + return torch.bitwise_and(x, 15) + + +@register_test_case(module_factory=lambda: ElementwiseBitwiseAndScalarInt64Module()) +def ElementwiseBitwiseAndScalarInt64Module_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=-1000, high=1000)) + + +class ElementwiseBitwiseAndScalarInt32Module(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + 
@annotate_args([ + None, + ([-1, -1], torch.int32, True), + ]) + def forward(self, x): + return torch.bitwise_and(x, 100) + + +@register_test_case(module_factory=lambda: ElementwiseBitwiseAndScalarInt32Module()) +def ElementwiseBitwiseAndScalarInt32Module_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=-1000, high=1000).to(torch.int32)) From b75c208f4e3343c997bc995a6284ffdd92476851 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Mon, 2 Oct 2023 08:02:15 -0700 Subject: [PATCH 22/41] update PyTorch version to 2.2.0.dev20231002 (#2497) torch version: 2.2.0.dev20231002 torch commit hash: 4dae8b49630d2784f6a5d8726db30923e2d1e077 torchvision version: 0.17.0.dev20231002 Co-authored-by: Roll PyTorch Action --- pytorch-hash.txt | 2 +- pytorch-requirements.txt | 2 +- torchvision-requirements.txt | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch-hash.txt b/pytorch-hash.txt index c0a7e8b3e824..27152016b797 100644 --- a/pytorch-hash.txt +++ b/pytorch-hash.txt @@ -1 +1 @@ -fecde478ac83edf78e7d0e9d11ab73cb1580f6cf +4dae8b49630d2784f6a5d8726db30923e2d1e077 diff --git a/pytorch-requirements.txt b/pytorch-requirements.txt index 4e3a9152f677..011435e57235 100644 --- a/pytorch-requirements.txt +++ b/pytorch-requirements.txt @@ -1,3 +1,3 @@ -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre -torch==2.2.0.dev20230928 +torch==2.2.0.dev20231002 diff --git a/torchvision-requirements.txt b/torchvision-requirements.txt index 9c653bbcd577..069ca40553ea 100644 --- a/torchvision-requirements.txt +++ b/torchvision-requirements.txt @@ -1,3 +1,3 @@ -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre -torchvision==0.17.0.dev20230928 +torchvision==0.17.0.dev20231002 From d10a86f51c1224200e9044a3f95be5b85c6f3a81 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Thu, 28 Sep 2023 13:43:14 +0000 Subject: [PATCH 23/41] Disable LTC for arm release Also, revert https://github.com/llvm/torch-mlir/pull/2488. Disabling LTC based on the discussion here: https://discord.com/channels/636084430946959380/742573221882364009/1156272667813494824 --- .github/workflows/buildRelease.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/buildRelease.yml b/.github/workflows/buildRelease.yml index 1e732f4b3732..1a9ce3fb3ca9 100644 --- a/.github/workflows/buildRelease.yml +++ b/.github/workflows/buildRelease.yml @@ -115,7 +115,7 @@ jobs: cd $GITHUB_WORKSPACE TM_PACKAGE_VERSION=${{ github.event.inputs.python_package_version }} printf "TORCH_MLIR_PYTHON_PACKAGE_VERSION=%s\n" $TM_PACKAGE_VERSION > ./torch_mlir_package_version - TM_PYTHON_VERSIONS=${{ matrix.py_version }} TM_PACKAGES=${{ matrix.package }} ./build_tools/python_deploy/build_linux_packages.sh + TM_PYTHON_VERSIONS=${{ matrix.py_version }} TM_PACKAGES=${{ matrix.package }} TORCH_MLIR_ENABLE_LTC='0' ./build_tools/python_deploy/build_linux_packages.sh # If we were given a release_id, then upload the package we just built # to the github releases page. From 32d9b20bdef9ced97014499989a77b50ab8408ed Mon Sep 17 00:00:00 2001 From: "Jae Hoon (Antonio) Kim" <17433012+antoniojkim@users.noreply.github.com> Date: Tue, 3 Oct 2023 11:01:07 -0400 Subject: [PATCH 24/41] Add linspace/cumprod/roll ops (#2498) Add linspace/cumprod/roll ops to ODS and add shape inference functions to make it work with LTC. Also, add some tensor utils to LTC library for searching for non-detach copy nodes. 
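As a quick reference for the semantics the new `compute_shape_*` functions
below encode, here is a minimal eager-PyTorch sketch (not part of this
patch): `linspace` produces a 1-D tensor of `steps` elements, while `cumprod`
and `roll` preserve the input shape.

    import torch

    x = torch.arange(12, dtype=torch.float32).reshape(3, 4)

    # linspace: 1-D result with exactly `steps` elements.
    assert torch.linspace(0.0, 1.0, steps=5).shape == (5,)
    # cumprod: a scan along `dim`, so the shape is unchanged.
    assert torch.cumprod(x, dim=1).shape == (3, 4)
    # roll: rotates elements, optionally per-dim, shape unchanged.
    assert torch.roll(x, shifts=(1, 2), dims=(0, 1)).shape == (3, 4)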
--- e2e_testing/xfail_sets.py | 2 - .../Dialect/Torch/IR/GeneratedTorchOps.td | 54 +++++++++++++++ .../csrc/base_lazy_backend/mlir_node.cpp | 2 +- .../csrc/base_lazy_backend/mlir_node.h | 2 +- .../base_lazy_backend/shape_inference.cpp | 12 ++++ .../base_lazy_backend/utils/tensor_utils.cpp | 66 +++++++++++++++---- .../base_lazy_backend/utils/tensor_utils.h | 9 ++- .../jit_ir/build_tools/torch_ods_gen.py | 2 + 8 files changed, 129 insertions(+), 20 deletions(-) diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py index 15cf9bb79f91..e46459f1a91b 100644 --- a/e2e_testing/xfail_sets.py +++ b/e2e_testing/xfail_sets.py @@ -830,7 +830,6 @@ "ReshapeAliasCollapseModule_basic", "ReshapeAliasExpandModule_basic", "ReshapeExpandModule_basic", - "RollModule_basic", "TestMultipleTensorReturn_basic", "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic", "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic", @@ -1355,7 +1354,6 @@ "NeFloatIntModule_basic", "NeIntModule_basic", "QuantizedMLP_basic", - "RollModule_basic", "ScalarImplicitFloatModule_basic", "ScalarImplicitIntModule_basic", "SliceEndSleStartModule_basic", diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 4ecad92c662b..f1338142d197 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -6440,6 +6440,31 @@ def Torch_AtenCumsumOp : Torch_Op<"aten.cumsum", [ }]; } +def Torch_AtenCumprodOp : Torch_Op<"aten.cumprod", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::cumprod : (Tensor, int, int?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$dim, + AnyTorchOptionalIntType:$dtype + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenCumprodOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenCumprodOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + def Torch_AtenFloorDivideScalarOp : Torch_Op<"aten.floor_divide.Scalar", [ AllowsTypeRefinement, HasValueSemantics, @@ -10464,6 +10489,35 @@ def Torch_AtenUniqueConsecutiveOp : Torch_Op<"aten.unique_consecutive", [ }]; } +def Torch_AtenLinspaceOp : Torch_Op<"aten.linspace", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::linspace : (Scalar, Scalar, int, int?, int?, Device?, bool?) 
-> (Tensor)`"; + let arguments = (ins + AnyTorchScalarType:$start, + AnyTorchScalarType:$end, + Torch_IntType:$steps, + AnyTorchOptionalIntType:$dtype, + AnyTorchOptionalIntType:$layout, + AnyTorchOptionalDeviceType:$device, + AnyTorchOptionalBoolType:$pin_memory + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenLinspaceOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 7, 1); + } + void AtenLinspaceOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 7, 1); + } + }]; +} + def Torch_AtenAliasCopyOp : Torch_Op<"aten.alias_copy", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_node.cpp b/python/torch_mlir/csrc/base_lazy_backend/mlir_node.cpp index e4b75e5d53d1..39dc1ad0cd58 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/mlir_node.cpp +++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_node.cpp @@ -83,7 +83,7 @@ hash_t TorchMlirNode::hash() const { return dag_hash_; } hash_t TorchMlirNode::shapeHash() const { return shape_hash_; } -TorchMlirNode* TorchMlirNode::mlir_node(int index) { +TorchMlirNode* TorchMlirNode::mlir_node(int index) const { return dynamic_cast(operands_.at(index).get()); } diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_node.h b/python/torch_mlir/csrc/base_lazy_backend/mlir_node.h index dbf3117dbb13..4b5e196beb20 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/mlir_node.h +++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_node.h @@ -51,7 +51,7 @@ class TORCH_API TorchMlirNode : public torch::lazy::Node { hash_t shapeHash() const override; - TorchMlirNode* mlir_node(int index); + TorchMlirNode* mlir_node(int index) const; virtual TorchMlirOpVector Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const; diff --git a/python/torch_mlir/csrc/base_lazy_backend/shape_inference.cpp b/python/torch_mlir/csrc/base_lazy_backend/shape_inference.cpp index 043094c67e0a..d5458f9c4ea6 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/shape_inference.cpp +++ b/python/torch_mlir/csrc/base_lazy_backend/shape_inference.cpp @@ -386,5 +386,17 @@ std::vector compute_shape_scalar_tensor( return {Shape(dtype.value_or(s.type()), c10::ArrayRef{})}; } +std::vector compute_shape_roll( + const at::Tensor& self, at::IntArrayRef shifts, at::IntArrayRef dims) { + return {Shape(self.scalar_type(), self.sizes().vec())}; +} + +std::vector compute_shape_linspace(const at::Scalar & start, const at::Scalar & end, int64_t steps, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory) { + auto out_meta = + at::linspace(start, end, steps, dtype, layout, c10::Device(c10::kMeta), pin_memory); + return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; +} + + } // namespace lazy } // namespace torch \ No newline at end of file diff --git a/python/torch_mlir/csrc/base_lazy_backend/utils/tensor_utils.cpp b/python/torch_mlir/csrc/base_lazy_backend/utils/tensor_utils.cpp index 7131e9a66a2b..cdd97168031b 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/utils/tensor_utils.cpp +++ b/python/torch_mlir/csrc/base_lazy_backend/utils/tensor_utils.cpp @@ -7,28 +7,66 @@ namespace torch { namespace lazy { +bool is_detach_copy(const torch::lazy::Node* node) { + return node && node->op() == torch::lazy::DetachCopy::ClassOpKind(); +} bool is_detach_copy(const torch::lazy::Value& value) { - return value->op() == 
torch::lazy::DetachCopy::ClassOpKind(); + return is_detach_copy(value.node.get()); } -torch::lazy::DeviceData* device_data_cast(const torch::lazy::Value& value) { - if (!value) { +torch::lazy::Node* extract_non_detach_copy_node(torch::lazy::Node* node) { + if (!node) { return nullptr; } + + torch::lazy::TorchMlirNode* mlir_node = dynamic_cast(node); + while(mlir_node && is_detach_copy(mlir_node)) { + mlir_node = mlir_node->mlir_node(0); + } + if (!mlir_node) { + return node; + } + return mlir_node; +} + +const torch::lazy::Node* extract_non_detach_copy_node(const torch::lazy::Node* node) { + if (!node) { return nullptr; } + + const torch::lazy::TorchMlirNode* mlir_node = dynamic_cast(node); + while(mlir_node && is_detach_copy(mlir_node)) { + mlir_node = mlir_node->mlir_node(0); + } + if (!mlir_node) { + return node; + } + return mlir_node; +} + + +torch::lazy::DeviceData* device_data_cast(torch::lazy::Node* node) { + if (!node) { + return nullptr; + } + node = extract_non_detach_copy_node(node); + if (node && node->op() == torch::lazy::DeviceData::ClassOpKind()) { + return dynamic_cast(node); + } + return nullptr; +} +const torch::lazy::DeviceData* device_data_cast(const torch::lazy::Node* node) { + if (!node) { return nullptr; } - torch::lazy::TorchMlirNode* node = dynamic_cast(value.node.get()); - while(node) { - if (node->op() == torch::lazy::DeviceData::ClassOpKind()) { - return dynamic_cast(node); - } - else if (node->op() == torch::lazy::DetachCopy::ClassOpKind()) { - node = node->mlir_node(0); - } - else { - break; - } + node = extract_non_detach_copy_node(node); + if (node && node->op() == torch::lazy::DeviceData::ClassOpKind()) { + return dynamic_cast(node); } return nullptr; } +torch::lazy::DeviceData* device_data_cast(const torch::lazy::Value& value) { + if (!value) { + return nullptr; + } + return device_data_cast(value.node.get()); +} torch::lazy::DeviceData* device_data_cast( const at::Tensor& tensor, c10::optional device diff --git a/python/torch_mlir/csrc/base_lazy_backend/utils/tensor_utils.h b/python/torch_mlir/csrc/base_lazy_backend/utils/tensor_utils.h index 717173e9a8fc..745be78c35d2 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/utils/tensor_utils.h +++ b/python/torch_mlir/csrc/base_lazy_backend/utils/tensor_utils.h @@ -8,10 +8,15 @@ namespace torch { namespace lazy { -TORCH_API bool is_detach_copy(const torch::lazy::Value& value); +TORCH_API bool is_detach_copy(const torch::lazy::Node*); +TORCH_API bool is_detach_copy(const torch::lazy::Value&); -TORCH_API torch::lazy::DeviceData* device_data_cast(const torch::lazy::Value& value); +TORCH_API torch::lazy::Node* extract_non_detach_copy_node(torch::lazy::Node*); +TORCH_API const torch::lazy::Node* extract_non_detach_copy_node(const torch::lazy::Node*); +TORCH_API torch::lazy::DeviceData* device_data_cast(torch::lazy::Node*); +TORCH_API const torch::lazy::DeviceData* device_data_cast(const torch::lazy::Node*); +TORCH_API torch::lazy::DeviceData* device_data_cast(const torch::lazy::Value& value); TORCH_API torch::lazy::DeviceData* device_data_cast( const at::Tensor& tensor, c10::optional device = c10::nullopt ); diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index 56d18d3847d1..f540a1ad2a7d 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -471,6 +471,7 @@ def 
emit_with_mutating_variants(key, **kwargs): emit("aten::movedim.int : (Tensor, int, int) -> (Tensor)") emit("aten::bmm : (Tensor, Tensor) -> (Tensor)") emit("aten::cumsum : (Tensor, int, int?) -> (Tensor)") + emit("aten::cumprod : (Tensor, int, int?) -> (Tensor)") emit("aten::floor_divide.Scalar : (Tensor, Scalar) -> (Tensor)") emit("aten::logsumexp : (Tensor, int[], bool) -> (Tensor)") emit("aten::mean.dim : (Tensor, int[]?, bool, int?) -> (Tensor)") @@ -625,6 +626,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::fft_fft : (Tensor, int?, int, str?) -> (Tensor)") emit("aten::fmod.Tensor : (Tensor, Tensor) -> (Tensor)") emit("aten::unique_consecutive : (Tensor, bool, bool, int?) -> (Tensor, Tensor, Tensor)") + emit("aten::linspace : (Scalar, Scalar, int, int?, int?, Device?, bool?) -> (Tensor)") # Functionalization ops emit("aten::alias_copy : (Tensor) -> (Tensor)") From ca6ce8974f4ae6a5980aaa099b70b3fd16bb5d63 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Tue, 3 Oct 2023 11:59:56 +0000 Subject: [PATCH 25/41] [MLIR][TORCH] Add support for int8 dtype for sub, add, and bitwise_and op Signed-Off By: Vivek Khandelwal --- e2e_testing/xfail_sets.py | 6 ++ .../TorchToLinalg/Uncategorized.cpp | 37 ++++++++--- .../test_suite/elementwise.py | 66 +++++++++++++++++++ 3 files changed, 101 insertions(+), 8 deletions(-) diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py index e46459f1a91b..fd9827772547 100644 --- a/e2e_testing/xfail_sets.py +++ b/e2e_testing/xfail_sets.py @@ -291,6 +291,9 @@ # Lowering not present for this case "ElementwiseToDtypeI64ToUI8Module_basic", + + # torch._dynamo.exc.TorchRuntimeError: Failed running call_function (*(FakeTensor(..., size=(3, 4), dtype=torch.int8), 3, 2), **{}): Tensor with dtype torch.int64 is not the expected dtype of torch.int8! 
+ "ElementwiseAddScalarInt8Module_basic", } if torch_version_for_comparison() < version.parse("2.1.0.dev"): @@ -1261,6 +1264,8 @@ "SoftmaxIntNegDimModule_basic", "_LogSoftmaxModule_basic", "_SoftmaxModule_basic", + "ElementwiseAddScalarInt8Module_basic", + "ElementwiseSubTensorInt8Module_basic", } MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | { @@ -1421,4 +1426,5 @@ "EmptyStridedModule_basic", "ElementwiseBitwiseAndScalarInt64Module_basic", "ElementwiseBitwiseAndScalarInt32Module_basic", + "ElementwiseBitwiseAndScalarInt8Module_basic", } diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index b47e13c8619e..9c862e410994 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -309,8 +309,14 @@ static Value createLinalgPayloadCalculationForElementwiseOp( "bitwise_and.Scalar does not support non-integer input dtype."); return nullptr; } - Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype); - Value other = convertScalarToDtype(b, loc, operands[1], dtype); + Type resultElementType = + bitwiseAndScalar.getType().cast().getDtype(); + Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype, + /*srcOriginalDtype=*/std::nullopt, + /*dstOriginalDtype=*/resultElementType); + Value other = convertScalarToDtype(b, loc, operands[1], dtype, + /*srcOriginalDtype=*/std::nullopt, + /*dstOriginalDtype=*/resultElementType); return b.create(loc, self, other); } if (auto bitwiseOrTensor = dyn_cast(op)) { @@ -542,9 +548,16 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Type dtype = converter->convertType(sub.getType()) .cast() .getElementType(); - Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); - Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); - Value alpha = convertScalarToDtype(b, loc, adaptor.getAlpha(), dtype); + Type resultElementType = sub.getType().cast().getDtype(); + Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype, + /*srcOriginalDtype=*/std::nullopt, + /*dstOriginalDtype=*/resultElementType); + Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype, + /*srcOriginalDtype=*/std::nullopt, + /*dstOriginalDtype=*/resultElementType); + Value alpha = convertScalarToDtype(b, loc, adaptor.getAlpha(), dtype, + /*srcOriginalDtype=*/std::nullopt, + /*dstOriginalDtype=*/resultElementType); if (dtype.isa()) { Value scaled = b.create(loc, rhs, alpha); return b.create(loc, lhs, scaled); @@ -575,9 +588,17 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Type dtype = converter->convertType(addScalar.getType()) .cast() .getElementType(); - Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype); - Value other = convertScalarToDtype(b, loc, operands[1], dtype); - Value alpha = convertScalarToDtype(b, loc, operands[2], dtype); + Type resultElementType = + addScalar.getType().cast().getDtype(); + Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype, + /*srcOriginalDtype=*/std::nullopt, + /*dstOriginalDtype=*/resultElementType); + Value other = convertScalarToDtype(b, loc, operands[1], dtype, + /*srcOriginalDtype=*/std::nullopt, + /*dstOriginalDtype=*/resultElementType); + Value alpha = convertScalarToDtype(b, loc, operands[2], dtype, + /*srcOriginalDtype=*/std::nullopt, + /*dstOriginalDtype=*/resultElementType); if (dtype.isa()) { Value mult = b.create(loc, other, alpha); return b.create(loc, self, mult); diff --git a/python/torch_mlir_e2e_test/test_suite/elementwise.py 
b/python/torch_mlir_e2e_test/test_suite/elementwise.py index a0137f23e71b..3b2997c3e482 100644 --- a/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -2316,6 +2316,31 @@ def ElementwiseBitwiseNotInt32Module_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseSubTensorInt8Module(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int8, True), + ([-1, -1], torch.int8, True), + ]) + def forward(self, x, y): + return torch.sub(x, y, alpha=2) + + +@register_test_case(module_factory=lambda: ElementwiseSubTensorInt8Module()) +def ElementwiseSubTensorInt8Module_basic(module, tu: TestUtils): + module.forward( + tu.randint(3, 4, high=10).to(dtype=torch.int8), + tu.randint(3, 4, high=10).to(dtype=torch.int8)) + + +# ============================================================================== + + class ElementwiseSubScalarIntModule(torch.nn.Module): def __init__(self): @@ -2472,6 +2497,28 @@ def ElementwiseAddScalar_TensorLiteralInt32_Module_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseAddScalarInt8Module(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int8, True), + ]) + def forward(self, x): + return torch.add(x, 3, 2) + + +@register_test_case(module_factory=lambda: ElementwiseAddScalarInt8Module()) +def ElementwiseAddScalarInt8Module_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, high=10).to(torch.int8)) + + +# ============================================================================== + + class ElementwiseCloneModule(torch.nn.Module): def __init__(self): @@ -3619,3 +3666,22 @@ def forward(self, x): @register_test_case(module_factory=lambda: ElementwiseBitwiseAndScalarInt32Module()) def ElementwiseBitwiseAndScalarInt32Module_basic(module, tu: TestUtils): module.forward(tu.randint(3, 4, low=-1000, high=1000).to(torch.int32)) + + +class ElementwiseBitwiseAndScalarInt8Module(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int8, True), + ]) + def forward(self, x): + return torch.bitwise_and(x, 100) + + +@register_test_case(module_factory=lambda: ElementwiseBitwiseAndScalarInt8Module()) +def ElementwiseBitwiseAndScalarInt8Module_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=-1000, high=1000).to(torch.int8)) From 4892ed433f4b396d5b6fccfcc0fc74596fa1034f Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Tue, 3 Oct 2023 10:02:55 -0700 Subject: [PATCH 26/41] update PyTorch version to 2.2.0.dev20231003 (#2500) torch version: 2.2.0.dev20231003 torch commit hash: 4e30fa82315208dcd38fa16a0ed9851fa8e98bc9 torchvision version: 0.17.0.dev20231003 Co-authored-by: Roll PyTorch Action --- pytorch-hash.txt | 2 +- pytorch-requirements.txt | 2 +- torchvision-requirements.txt | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch-hash.txt b/pytorch-hash.txt index 27152016b797..ccb5a1db2f3c 100644 --- a/pytorch-hash.txt +++ b/pytorch-hash.txt @@ -1 +1 @@ -4dae8b49630d2784f6a5d8726db30923e2d1e077 +4e30fa82315208dcd38fa16a0ed9851fa8e98bc9 diff --git a/pytorch-requirements.txt b/pytorch-requirements.txt index 011435e57235..54b76cc34d30 100644 --- a/pytorch-requirements.txt +++ b/pytorch-requirements.txt @@ -1,3 +1,3 @@ -f 
https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre -torch==2.2.0.dev20231002 +torch==2.2.0.dev20231003 diff --git a/torchvision-requirements.txt b/torchvision-requirements.txt index 069ca40553ea..46790e4e95e3 100644 --- a/torchvision-requirements.txt +++ b/torchvision-requirements.txt @@ -1,3 +1,3 @@ -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre -torchvision==0.17.0.dev20231002 +torchvision==0.17.0.dev20231003 From 1c508af0ba011b667d24adbefc393b88af4a7f85 Mon Sep 17 00:00:00 2001 From: Ramiro Leal-Cavazos Date: Tue, 3 Oct 2023 18:49:41 +0000 Subject: [PATCH 27/41] Revert "[linalg] Fix handling of trailing size-1 dimensions in aten.view (#2474)" This reverts commit 7c6b9d2445288cea7140b94a8586ce7873b50ccb. --- lib/Conversion/TorchToLinalg/DataMovement.cpp | 51 +++---------------- .../test_suite/reshape_like.py | 36 +------------ 2 files changed, 8 insertions(+), 79 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index 74b7badd8253..6ed9d369e8e5 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -193,9 +193,6 @@ class ConvertAtenViewOp : public OpConversionPattern { ArrayRef yDims, SmallVector &xIndices, SmallVector &yIndices) { - if (xDims.empty() || yDims.empty()) - return failure(); - auto isValidReduction = [](int64_t expectedReductionProduct, ArrayRef arrayToReduce) -> bool { if (llvm::count(arrayToReduce, kUnknownSize) > 0 || @@ -258,25 +255,6 @@ class ConvertAtenViewOp : public OpConversionPattern { return success(); } - // If one of the two dims arrays has size 0 and the other array only - // has dims of size 1, a mapping is created from no dimensions to - // all the dimensions of the other array. - static LogicalResult mapTrailingSizeOneDims(ArrayRef xDims, - ArrayRef yDims, - SmallVector &xIndices, - SmallVector &yIndices) { - SmallVector ignoredIndices; - if (xDims.empty()) { - return mapAllDimsToSingleDim(ArrayRef({1}), yDims, - ignoredIndices, yIndices); - } else if (yDims.empty()) { - return mapAllDimsToSingleDim(xDims, ArrayRef({1}), xIndices, - ignoredIndices); - } else { - return failure(); - } - } - // Calculates the size of a dynamic dimension if all other dimensions are // statically known, and rewrites that dynamic dimension with the static size. // @@ -284,8 +262,6 @@ class ConvertAtenViewOp : public OpConversionPattern { // all the dimensions in `outputShape`. static void calculateSingleDynamicSize(MutableArrayRef inputShape, MutableArrayRef outputShape) { - if (inputShape.empty() || outputShape.empty()) - return; int64_t inputDynamicDimCount = llvm::count(inputShape, kUnknownSize); int64_t outputDynamicDimCount = llvm::count(outputShape, kUnknownSize); if (inputDynamicDimCount + outputDynamicDimCount != 1) @@ -444,7 +420,7 @@ class ConvertAtenViewOp : public OpConversionPattern { for (auto [nextUnchangedInput, nextUnchangedOutput] : unchangedDims) { // Used for ensuring that we don't have an ambiguous expansion bool assumedDynamicDimNotSplit = false; - while (inputDim < nextUnchangedInput || outputDim < nextUnchangedOutput) { + while (inputDim < nextUnchangedInput && outputDim < nextUnchangedOutput) { auto inputShapeSlice = MutableArrayRef(inputShape) .slice(inputDim, nextUnchangedInput - inputDim); @@ -465,15 +441,9 @@ class ConvertAtenViewOp : public OpConversionPattern { "(e.g. 
[-1, -1] -> [-1, -1, -1])"); } - if (succeeded(mapTrailingSizeOneDims(inputShapeSlice, outputShapeSlice, - inputSliceIndices, - outputSliceIndices))) { - } else if (outputShapeSlice.empty()) { - inputSliceIndices.assign( - llvm::to_vector(llvm::seq(0, inputShapeSlice.size()))); - } else if (succeeded(mapAllDimsToSingleDim( - inputShapeSlice, outputShapeSlice, inputSliceIndices, - outputSliceIndices))) { + if (succeeded(mapAllDimsToSingleDim(inputShapeSlice, outputShapeSlice, + inputSliceIndices, + outputSliceIndices))) { calculateSingleDynamicSize(inputShapeSlice, outputShapeSlice); // Update shape to pass the tensor.expand_shape and // tensor.collapse_shape verifiers. If one of the dimensions of the @@ -492,8 +462,7 @@ class ConvertAtenViewOp : public OpConversionPattern { /// `mapStaticallyKnownDims` maps the smallest number of /// input and output dimensions in the slice statically /// known to have the same number of elements. - } else if (inputShapeSlice.size() > 0 && - inputShapeSlice[0] == kUnknownSize) { + } else if (inputShapeSlice[0] == kUnknownSize) { // If the input is dynamic, assume it is not split checkDimEqualHelper(rewriter, loc, inputSize[inputDim], outputSizeInt[outputDim]); @@ -509,14 +478,8 @@ class ConvertAtenViewOp : public OpConversionPattern { "in `aten.view`"); } - // If one of the slices is empty, this means we are handling - // the case of trailing dimensions, which does not require a - // new reassociation; the trailing dimensions get added to the - // last reassociation created. - if (inputShapeSlice.size() > 0 && outputShapeSlice.size() > 0) { - inputAssociations.emplace_back(); - outputAssociations.emplace_back(); - } + inputAssociations.emplace_back(); + outputAssociations.emplace_back(); for (int64_t inputSliceIndex : inputSliceIndices) inputAssociations.back().push_back(inputSliceIndex + inputDim); for (int64_t outputSliceIndex : outputSliceIndices) diff --git a/python/torch_mlir_e2e_test/test_suite/reshape_like.py b/python/torch_mlir_e2e_test/test_suite/reshape_like.py index f91b90582d24..1c2d810c13de 100644 --- a/python/torch_mlir_e2e_test/test_suite/reshape_like.py +++ b/python/torch_mlir_e2e_test/test_suite/reshape_like.py @@ -672,40 +672,6 @@ def forward(self, a): def ViewNegativeStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(1, 128)) -class ViewSizeDimFollowedByExpandedOnesModule(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ([-1], torch.float32, True), - ]) - - def forward(self, a): - return a.view(a.size(0), 1, 1, 1) - -@register_test_case(module_factory=lambda: ViewSizeDimFollowedByExpandedOnesModule()) -def ViewSizeDimFollowedByExpandedOnesModule_basic(module, tu: TestUtils): - module.forward(tu.rand(128)) - -class ViewSizeDimFollowedByCollapsedOnesModule(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ([-1, 1, 1, 1], torch.float32, True), - ]) - - def forward(self, a): - return a.view(a.size(0)) - -@register_test_case(module_factory=lambda: ViewSizeDimFollowedByCollapsedOnesModule()) -def ViewSizeDimFollowedByCollapsedOnesModule_basic(module, tu: TestUtils): - module.forward(tu.rand(128, 1, 1, 1)) - # ============================================================================== class ReshapeAliasExpandModule(torch.nn.Module): @@ -744,4 +710,4 @@ def forward(self, a): @register_test_case(module_factory=lambda: ReshapeAliasCollapseModule()) def ReshapeAliasCollapseModule_basic(module, tu: TestUtils): - 
module.forward(tu.rand(2, 4)) + module.forward(tu.rand(2, 4)) \ No newline at end of file From 2e5d65064c6284d683d4c3d87ecf08647625bda1 Mon Sep 17 00:00:00 2001 From: Ramiro Leal-Cavazos Date: Tue, 3 Oct 2023 19:24:01 +0000 Subject: [PATCH 28/41] [linalg] Add handling for leading and trailing size-1 dims in ViewOp This commit adds to the lowering of `aten.view` handling for the following cases: - `(..., a.size(i))` -> `(..., a.size(i), 1, ..., 1)` - `(..., a.size(i), 1, ..., 1)` -> `(..., a.size(i))` - `(a.size(i), ...)` -> `(1, ..., 1, a.size(i), ...)` - `(1, ..., 1, a.size(i), ...)` -> `(a.size(i), ...)` --- lib/Conversion/TorchToLinalg/DataMovement.cpp | 28 ++++- .../test_suite/reshape_like.py | 104 +++++++++++++++++- 2 files changed, 128 insertions(+), 4 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index 6ed9d369e8e5..2897ff3423e9 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -193,6 +193,9 @@ class ConvertAtenViewOp : public OpConversionPattern { ArrayRef yDims, SmallVector &xIndices, SmallVector &yIndices) { + if (xDims.empty() || yDims.empty()) + return failure(); + auto isValidReduction = [](int64_t expectedReductionProduct, ArrayRef arrayToReduce) -> bool { if (llvm::count(arrayToReduce, kUnknownSize) > 0 || @@ -262,6 +265,8 @@ class ConvertAtenViewOp : public OpConversionPattern { // all the dimensions in `outputShape`. static void calculateSingleDynamicSize(MutableArrayRef inputShape, MutableArrayRef outputShape) { + if (inputShape.empty() || outputShape.empty()) + return; int64_t inputDynamicDimCount = llvm::count(inputShape, kUnknownSize); int64_t outputDynamicDimCount = llvm::count(outputShape, kUnknownSize); @@ -488,12 +493,29 @@ class ConvertAtenViewOp : public OpConversionPattern { outputDim = outputAssociations.back().back() + 1; } - // Append the associations for the dims matching `aten.size.int` - if (nextUnchangedInput != inputRank && - nextUnchangedOutput != resultRank) { + // Handle any leading or trailing size-1 dimensions and append the + // associations for the dims matching `aten.size.int`.
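+ // For example, a view from [128] to [128, 1, 1, 1] only appends static + // unit dims and [1, 1, 1, 128] -> [128] only strips them; both fold into + // the association of the matched `aten.size.int` dim.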
+ if (nextUnchangedInput != inputRank) { + assert(nextUnchangedOutput != resultRank && + "`nextUnchangedInput` and `nextUnchangedOutput` should equal " + "the respective input and output rank at the same time"); inputAssociations.emplace_back(); outputAssociations.emplace_back(); + } + while (inputDim <= nextUnchangedInput && inputDim < inputRank) { + if (inputDim != nextUnchangedInput && inputShape[inputDim] != 1) { + return rewriter.notifyMatchFailure( + op, "unimplemented: only collapsing of static size-1 into " + "unchanged dim supported"); + } inputAssociations.back().push_back(inputDim++); + } + while (outputDim <= nextUnchangedOutput && outputDim < resultRank) { + if (outputDim != nextUnchangedOutput && outputShape[outputDim] != 1) { + return rewriter.notifyMatchFailure( + op, "unimplemented: only expanding of static size-1 out of " + "unchanged dim supported"); + } outputAssociations.back().push_back(outputDim++); } } diff --git a/python/torch_mlir_e2e_test/test_suite/reshape_like.py b/python/torch_mlir_e2e_test/test_suite/reshape_like.py index 1c2d810c13de..304d3025eb8d 100644 --- a/python/torch_mlir_e2e_test/test_suite/reshape_like.py +++ b/python/torch_mlir_e2e_test/test_suite/reshape_like.py @@ -672,6 +672,108 @@ def forward(self, a): def ViewNegativeStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(1, 128)) +class ViewSizeDimFollowedByExpandedOnesModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1], torch.float32, True), + ]) + + def forward(self, a): + return a.view(a.size(0), 1, 1, 1) + +@register_test_case(module_factory=lambda: ViewSizeDimFollowedByExpandedOnesModule()) +def ViewSizeDimFollowedByExpandedOnesModule_basic(module, tu: TestUtils): + module.forward(tu.rand(128)) + +class ViewSizeDimFollowedByCollapsedOnesModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, 1, 1, 1], torch.float32, True), + ]) + + def forward(self, a): + return a.view(a.size(0)) + +@register_test_case(module_factory=lambda: ViewSizeDimFollowedByCollapsedOnesModule()) +def ViewSizeDimFollowedByCollapsedOnesModule_basic(module, tu: TestUtils): + module.forward(tu.rand(128, 1, 1, 1)) + +class ViewSizeDimLedByExpandedOnesModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1], torch.float32, True), + ]) + + def forward(self, a): + return a.view(1, 1, 1, a.size(0)) + +@register_test_case(module_factory=lambda: ViewSizeDimLedByExpandedOnesModule()) +def ViewSizeDimLedByExpandedOnesModule_basic(module, tu: TestUtils): + module.forward(tu.rand(128)) + +class ViewSizeDimLedByCollapsedOnesModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([1, 1, 1, -1], torch.float32, True), + ]) + + def forward(self, a): + return a.view(a.size(3)) + +@register_test_case(module_factory=lambda: ViewSizeDimLedByCollapsedOnesModule()) +def ViewSizeDimLedByCollapsedOnesModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 1, 128)) + +class ViewSizeDimLedAndFollowedByExpandedOnesModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1], torch.float32, True), + ]) + + def forward(self, a): + return a.view(1, 1, 1, a.size(0), 1, 1, 1) + +@register_test_case(module_factory=lambda: ViewSizeDimLedAndFollowedByExpandedOnesModule()) +def ViewSizeDimLedAndFollowedByExpandedOnesModule_basic(module, tu: 
TestUtils): + module.forward(tu.rand(128)) + +class ViewSizeDimLedAndFollowedByCollapsedOnesModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([1, 1, 1, -1, 1, 1, 1], torch.float32, True), + ]) + + def forward(self, a): + return a.view(a.size(3)) + +@register_test_case(module_factory=lambda: ViewSizeDimLedAndFollowedByCollapsedOnesModule()) +def ViewSizeDimLedAndFollowedByCollapsedOnesModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 1, 128, 1, 1, 1)) + # ============================================================================== class ReshapeAliasExpandModule(torch.nn.Module): @@ -710,4 +812,4 @@ def forward(self, a): @register_test_case(module_factory=lambda: ReshapeAliasCollapseModule()) def ReshapeAliasCollapseModule_basic(module, tu: TestUtils): - module.forward(tu.rand(2, 4)) \ No newline at end of file + module.forward(tu.rand(2, 4)) From 14e6da8588ba28fb202c4adfbdb0980a87c74524 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Wed, 4 Oct 2023 07:55:21 -0700 Subject: [PATCH 29/41] update PyTorch version to 2.2.0.dev20231004 (#2502) torch version: 2.2.0.dev20231004 torch commit hash: 56af607c0437ed7321da4b96a4dbccdbd8b5a98b torchvision version: 0.17.0.dev20231004 Co-authored-by: Roll PyTorch Action --- pytorch-hash.txt | 2 +- pytorch-requirements.txt | 2 +- torchvision-requirements.txt | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch-hash.txt b/pytorch-hash.txt index ccb5a1db2f3c..43dc87bfe85f 100644 --- a/pytorch-hash.txt +++ b/pytorch-hash.txt @@ -1 +1 @@ -4e30fa82315208dcd38fa16a0ed9851fa8e98bc9 +56af607c0437ed7321da4b96a4dbccdbd8b5a98b diff --git a/pytorch-requirements.txt b/pytorch-requirements.txt index 54b76cc34d30..ea65818e4460 100644 --- a/pytorch-requirements.txt +++ b/pytorch-requirements.txt @@ -1,3 +1,3 @@ -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre -torch==2.2.0.dev20231003 +torch==2.2.0.dev20231004 diff --git a/torchvision-requirements.txt b/torchvision-requirements.txt index 46790e4e95e3..58da02803beb 100644 --- a/torchvision-requirements.txt +++ b/torchvision-requirements.txt @@ -1,3 +1,3 @@ -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre -torchvision==0.17.0.dev20231003 +torchvision==0.17.0.dev20231004 From ae72eec224e57989f610c6f1d45308125a310f7f Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Thu, 5 Oct 2023 09:02:10 -0400 Subject: [PATCH 30/41] Improve aten.broadcast_to folder when in strict symbol mode (#2504) Strict symbolic shapes allow us to assume numpy-style dynamic broadcasts never occur. This allows us to strengthen the folder for broadcasts to cases where the rank is the same and all shapes match (including dynamic sentinel values). 
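For illustration, a minimal Python sketch of the strengthened folding condition (a hypothetical helper, not the actual C++ folder in the diff below; -1 stands for the dynamic size sentinel):

    # Sketch: aten.broadcast_to folds away only when it provably changes nothing.
    def can_fold_broadcast_to(in_sizes, out_sizes, strict_symbolic_shapes):
        if len(in_sizes) != len(out_sizes):
            return False
        # Without the strict-shapes assumption, a dynamic dim could still be a
        # numpy-style size-1 broadcast at runtime, so all sizes must be static.
        if not strict_symbolic_shapes and -1 in in_sizes + out_sizes:
            return False
        # Every size, static or dynamic sentinel, must match exactly.
        return all(i == o for i, o in zip(in_sizes, out_sizes))

    assert can_fold_broadcast_to([-1, 4], [-1, 4], strict_symbolic_shapes=True)
    assert not can_fold_broadcast_to([-1, 4], [-1, 4], strict_symbolic_shapes=False)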
--- lib/Dialect/Torch/IR/TorchOps.cpp | 3 ++- test/Dialect/Torch/canonicalize.mlir | 9 +++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index aed453a62da4..bf930d68bc39 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -2371,7 +2371,8 @@ OpFoldResult AtenBroadcastToOp::fold(FoldAdaptor adaptor) { if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes()) return nullptr; if (inType.getSizes().size() != outType.getSizes().size() || - !inType.areAllSizesKnown() || !outType.areAllSizesKnown()) + (!isAssumingStrictSymbolicShapes((*this)->getBlock()) && + (!inType.areAllSizesKnown() || !outType.areAllSizesKnown()))) return nullptr; for (size_t i = 0; i < inType.getSizes().size(); ++i) { if (inType.getSizes()[i] != outType.getSizes()[i]) diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index 21e0500f4eb5..b66bd24e1bb3 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -1983,6 +1983,15 @@ func.func @torch.aten.broadcast_to$fold(%arg0: !torch.vtensor<[3,4,2],f32>) -> ! return %0 : !torch.vtensor<[3,4,2],f32> } +// CHECK-LABEL: func.func @torch.aten.broadcast_to_strict$fold( +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?],f32>, {{.*}}) -> !torch.vtensor<[?],f32> +// CHECK-NEXT: return %[[ARG]] : !torch.vtensor<[?],f32> +func.func @torch.aten.broadcast_to_strict$fold(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.int) -> !torch.vtensor<[?],f32> attributes {torch.assume_strict_symbolic_shapes} { + %list = torch.prim.ListConstruct %arg1 : (!torch.int) -> !torch.list + %0 = torch.aten.broadcast_to %arg0, %list : !torch.vtensor<[?],f32>, !torch.list -> !torch.vtensor<[?],f32> + return %0 : !torch.vtensor<[?],f32> +} + // CHECK-LABEL: @torch.aten.slice.tensor$fold_full_domain_slice // CHECK-SAME: %[[ARG0:.+]]: !torch.vtensor<[4],f32> // CHECK: return %[[ARG0]] : !torch.vtensor<[4],f32> From 42b6c0a14a13b1349b4e2e4fa9a94b62c9ac93cb Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Thu, 5 Oct 2023 09:45:53 -0700 Subject: [PATCH 31/41] update PyTorch version to 2.2.0.dev20231005 (#2506) torch version: 2.2.0.dev20231005 torch commit hash: 439cba92777ff61b49d24096edfaf128fbd742ea torchvision version: 0.17.0.dev20231005 Co-authored-by: Roll PyTorch Action --- pytorch-hash.txt | 2 +- pytorch-requirements.txt | 2 +- torchvision-requirements.txt | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch-hash.txt b/pytorch-hash.txt index 43dc87bfe85f..943397b0a254 100644 --- a/pytorch-hash.txt +++ b/pytorch-hash.txt @@ -1 +1 @@ -56af607c0437ed7321da4b96a4dbccdbd8b5a98b +439cba92777ff61b49d24096edfaf128fbd742ea diff --git a/pytorch-requirements.txt b/pytorch-requirements.txt index ea65818e4460..ae5ae6af0a34 100644 --- a/pytorch-requirements.txt +++ b/pytorch-requirements.txt @@ -1,3 +1,3 @@ -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre -torch==2.2.0.dev20231004 +torch==2.2.0.dev20231005 diff --git a/torchvision-requirements.txt b/torchvision-requirements.txt index 58da02803beb..05791c8ba8bc 100644 --- a/torchvision-requirements.txt +++ b/torchvision-requirements.txt @@ -1,3 +1,3 @@ -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre -torchvision==0.17.0.dev20231004 +torchvision==0.17.0.dev20231005 From 6f81ad72938deb56c6d43bbc01388c1f8f1253c1 Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Thu, 5 Oct 2023 15:15:26 -0400 
Subject: [PATCH 32/41] [TorchToLinalg] Improve broadcast lowerings in strict symbolic modes (#2505) With strict symbolic shapes, we can assume numpy-style dynamic broadcasts never occur. This improves the lowering in the presence of this assumption. --- lib/Conversion/TorchToLinalg/DataMovement.cpp | 30 ++-- lib/Conversion/TorchToLinalg/Linear.cpp | 15 +- lib/Conversion/TorchToLinalg/Utils.cpp | 131 +++++++++++++++--- lib/Conversion/TorchToLinalg/Utils.h | 10 +- test/Conversion/TorchToLinalg/broadcast.mlir | 90 ++++++++++++ 5 files changed, 237 insertions(+), 39 deletions(-) create mode 100644 test/Conversion/TorchToLinalg/broadcast.mlir diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index 2897ff3423e9..662a1379bb41 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -1095,31 +1095,35 @@ class ConvertAtenBroadcastToOp : public OpConversionPattern { // which in this case is `inShapeConverted` because this shape will yield // us the dimension size of the output. SmallVector useBroadcastToShape; - for (auto x : inShape) { + int64_t inputRank = self.getType().cast().getRank(); + for (size_t i = inShape.size() - inputRank, e = inShape.size(); i < e; + ++i) { int64_t dim; - if (!matchPattern(x, m_TorchConstantInt(&dim))) { - Operation *defOp = x.getDefiningOp(); - if (isa(defOp)) - useBroadcastToShape.push_back(true); - else + if (matchPattern(inShape[i], m_TorchConstantInt(&dim))) { + if (dim < 0) { useBroadcastToShape.push_back(false); + } else { + useBroadcastToShape.push_back(true); + } } else { - useBroadcastToShape.push_back(false); + // Note: Dynamic -1 (inferred) broadcast shapes are unimplemented. + useBroadcastToShape.push_back(true); } } SmallVector inShapeConverted = getTypeConvertedValues( rewriter, op.getLoc(), getTypeConverter(), inShape); + auto newResultType = + getTypeConverter()->convertType(op.getType()).cast(); Value result; - if (failed(torch_to_linalg::broadcastToGivenShape(op, rewriter, self, - inShapeConverted, result, - useBroadcastToShape))) { + if (failed(torch_to_linalg::broadcastToGivenShape( + op, rewriter, self, inShapeConverted, newResultType, result, + useBroadcastToShape))) { return rewriter.notifyMatchFailure( op, "unable to perform broadcast operation"); } - Type newResultType = getTypeConverter()->convertType(op.getType()); - rewriter.replaceOpWithNewOp(op, newResultType, result); + rewriter.replaceOp(op, result); return success(); } }; @@ -1177,7 +1181,7 @@ class ConvertAtenCopyOp : public OpConversionPattern { selfSizes[i] = castIndexToInt64(rewriter, loc, selfSizes[i]); Value broadcastedSrc; if (failed(torch_to_linalg::broadcastToGivenShape( - op, rewriter, src, selfSizes, broadcastedSrc))) { + op, rewriter, src, selfSizes, selfType, broadcastedSrc))) { return rewriter.notifyMatchFailure( op, "unable to perform broadcast operation"); } diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index 66380dea9a89..bbf53162d6a1 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -295,13 +295,24 @@ class ConvertAtenMatmulOp : public OpConversionPattern { // Broadcast the batch dimensions of both the matrices. Value broadcastedLhs, broadcastedRhs; + // TODO: Improve usage of static shape information. 
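+ // For now, broadcast each operand to a fully dynamic shape of the target + // rank and let the broadcast lowering reconcile the batch dimensions.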
+ SmallVector lhsTargetShape(lhsBroadcastToShape.size(), + ShapedType::kDynamic); + auto lhsBroadcastType = + RankedTensorType::get(lhsTargetShape, lhsType.getElementType()); if (failed(torch_to_linalg::broadcastToGivenShape( - op, rewriter, lhs, lhsBroadcastToShape, broadcastedLhs))) { + op, rewriter, lhs, lhsBroadcastToShape, lhsBroadcastType, + broadcastedLhs))) { return rewriter.notifyMatchFailure( op, "unable to perform broadcast operation"); } + SmallVector rhsTargetShape(rhsBroadcastToShape.size(), + ShapedType::kDynamic); + auto rhsBroadcastType = + RankedTensorType::get(rhsTargetShape, rhsType.getElementType()); if (failed(torch_to_linalg::broadcastToGivenShape( - op, rewriter, rhs, rhsBroadcastToShape, broadcastedRhs))) { + op, rewriter, rhs, rhsBroadcastToShape, rhsBroadcastType, + broadcastedRhs))) { return rewriter.notifyMatchFailure( op, "unable to perform broadcast operation"); } diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp index 99b86027b8e5..a666ca30b02f 100644 --- a/lib/Conversion/TorchToLinalg/Utils.cpp +++ b/lib/Conversion/TorchToLinalg/Utils.cpp @@ -327,12 +327,15 @@ Value torch_to_linalg::createElementwiseLinalgGeneric( // Broadcasts input tensor based on the broadcastToShape. LogicalResult torch_to_linalg::broadcastToGivenShape( Operation *op, PatternRewriter &rewriter, Value input, - SmallVector broadcastToShape, Value &result, - SmallVector useBroadcastToShape) { + SmallVector broadcastToShape, RankedTensorType broadcastType, + Value &result, SmallVector useBroadcastToShape) { RankedTensorType inputType = input.getType().cast(); + int64_t inputRank = inputType.getRank(); + int64_t outputRank = broadcastToShape.size(); + ArrayRef outputShape = broadcastType.getShape(); SmallVector inputShape = makeShapeTorchCompatible(inputType.getShape()); - if (broadcastToShape.size() < inputShape.size()) { + if (outputRank < inputRank) { return rewriter.notifyMatchFailure( op, "invalid shape: broadcastToShape size must not be smaller than the " "size of the input shape"); @@ -340,9 +343,12 @@ LogicalResult torch_to_linalg::broadcastToGivenShape( Type elementType = inputType.getElementType(); Location loc = op->getLoc(); - SmallVector outShape; + SmallVector outShape; bool elideDynamicBroadcastCheck = isAssumingStrictSymbolicShapes(rewriter); + // Vector indicating broadcasted status when assuming strict symbolic shapes. + SmallVector broadcastedStatus; + // Create affine map and shapes for tensor initialization. SmallVector outExpr; Value zero = @@ -351,10 +357,39 @@ LogicalResult torch_to_linalg::broadcastToGivenShape( rewriter.create(loc, rewriter.getIndexAttr(0)); Value oneIndex = rewriter.create(loc, rewriter.getIndexAttr(1)); - size_t diff = broadcastToShape.size() - inputShape.size(); - for (size_t i = 0; i < broadcastToShape.size(); i++) { + size_t diff = outputRank - inputRank; + bool hasDynamicNumpyBroadcast = false; + for (size_t i = 0, e = outputRank; i < e; i++) { Value shapeValue = broadcastToShape[i]; size_t j = i - diff; + bool isDynamic = i >= diff && inputShape[j] == kUnknownSize; + + // Inherit static output shapes if present. 
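+ // A static output dim must either match the input dim, broadcast from a + // static size-1 input dim, or pair with a dynamic input dim (which takes + // the numpy-style dynamic broadcast path below).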
+ if (outputShape[i] != ShapedType::kDynamic) { + outShape.push_back(rewriter.getIndexAttr(outputShape[i])); + if (i < diff) { + if (outputShape[i] < 0) { + return rewriter.notifyMatchFailure( + op, "invalid shape: negative values not allowed in new broadcast " + "dimensions"); + } + continue; + } + if (isDynamic) { + hasDynamicNumpyBroadcast = true; + } else if (inputShape[j] != outputShape[i] && inputShape[j] != 1) { + return rewriter.notifyMatchFailure( + op, "invalid shape: static mismatch in input and output broadcast " + "shapes"); + } + + // If strict symbolic shapes are assumed and the input shape is dynamic, + // we can assume that dim is not broadcasted. + broadcastedStatus.push_back(inputShape[j] != outputShape[i] && + !isDynamic); + continue; + } + if (i < diff) { if (!elideDynamicBroadcastCheck) { Value isValid = rewriter.create( @@ -374,24 +409,80 @@ LogicalResult torch_to_linalg::broadcastToGivenShape( Value select = rewriter.create( loc, isNegative, oneIndex, castIntToIndex(rewriter, loc, shapeValue)); outShape.push_back(select); - } else { - // Case of dynamic input dimension wherein the shape to broadcast will - // yield us the dimension size of the output. - Value dim = getDimOp(rewriter, loc, input, j); - if (!useBroadcastToShape.empty()) { - if (useBroadcastToShape[i]) - dim = castIntToIndex(rewriter, loc, broadcastToShape[j]); + broadcastedStatus.push_back(true); + continue; + } + + // Case of dynamic input dimension wherein the shape to broadcast will + // yield us the dimension size of the output. + Value dim; + if (!useBroadcastToShape.empty() && useBroadcastToShape[j]) { + dim = castIntToIndex(rewriter, loc, broadcastToShape[i]); + if (isDynamic) { + hasDynamicNumpyBroadcast = true; } - outShape.push_back(dim); + if (!elideDynamicBroadcastCheck) { + Value isValid = rewriter.create( + loc, arith::CmpIPredicate::sge, shapeValue, zero); + rewriter.create( + loc, isValid, + rewriter.getStringAttr( + "unimplemented: dynamic negative broadcast sizes")); + } + } else { + dim = getDimOp(rewriter, loc, input, j); } + // We can safely assume this dimension is not broadcasted with strict + // symbols. + broadcastedStatus.push_back(false); + outShape.push_back(dim); } - Value outTensor = rewriter.create( - loc, getAsOpFoldResult(outShape), elementType); + Value outTensor = + rewriter.create(loc, outShape, elementType); + + // If we know there are no ? -> ? broadcasted dims, or we are assuming + // strict symbols, we can safely use standard linalg style broadcasting + // semantics. + if (!hasDynamicNumpyBroadcast || elideDynamicBroadcastCheck) { + // If no dims are broadcasted and the rank doesn't change, we can just fold + // the op away entirely. 
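+ // In the fold case the input value is reused directly, retyped to the + // computed result type.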
+ if (!llvm::any_of(broadcastedStatus, [](bool b) { return b; }) && + inputRank == outputRank) { + result = rewriter.create(loc, outTensor.getType(), input); + return success(); + } + + SmallVector inputExprs; + for (int64_t i = 0, e = inputRank; i < e; ++i) { + if (broadcastedStatus[i]) { + inputExprs.push_back(rewriter.getAffineConstantExpr(0)); + continue; + } + inputExprs.push_back(rewriter.getAffineDimExpr(i + diff)); + } + + SmallVector indexingMaps = { + AffineMap::get(outputRank, 0, inputExprs, rewriter.getContext()), + rewriter.getMultiDimIdentityMap(outputRank)}; + SmallVector iteratorTypes( + outputRank, utils::IteratorType::parallel); + result = rewriter + .create( + loc, outTensor.getType(), input, outTensor, indexingMaps, + iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + b.create(loc, args[0]); + }) + .getResult(0); + return success(); + } + // Fall back to numpy-style dynamic broadcasting in the form of a single + // linalg op. SmallVector indexingMaps = { - rewriter.getMultiDimIdentityMap(broadcastToShape.size())}; - SmallVector iteratorTypes(broadcastToShape.size(), + rewriter.getMultiDimIdentityMap(outputRank)}; + SmallVector iteratorTypes(outputRank, utils::IteratorType::parallel); result = rewriter .create( @@ -402,7 +493,7 @@ LogicalResult torch_to_linalg::broadcastToGivenShape( // would be used to extract values from the input tensor // later on. SmallVector loopIndices; - for (size_t i = 0; i < broadcastToShape.size(); ++i) { + for (size_t i = 0, e = outputRank; i < e; ++i) { if (i < diff) continue; loopIndices.push_back(b.create(loc, i)); @@ -411,7 +502,7 @@ LogicalResult torch_to_linalg::broadcastToGivenShape( // the i-th input dimension is not 1, else it contains a // zero index. SmallVector inputIndicesToExtract; - for (size_t i = 0, n = inputShape.size(); i < n; i++) { + for (size_t i = 0, n = inputRank; i < n; i++) { if (inputShape[i] == 1) { inputIndicesToExtract.push_back(zeroIndex); } else { diff --git a/lib/Conversion/TorchToLinalg/Utils.h b/lib/Conversion/TorchToLinalg/Utils.h index 354012028b01..3bee8d642533 100644 --- a/lib/Conversion/TorchToLinalg/Utils.h +++ b/lib/Conversion/TorchToLinalg/Utils.h @@ -73,10 +73,12 @@ Value createElementwiseLinalgGeneric( function_ref bodyBuild); // Broadcasts input tensor based on the broadcastToShape. -LogicalResult -broadcastToGivenShape(Operation *op, PatternRewriter &rewriter, Value input, - SmallVector broadcastToShape, Value &result, - SmallVector useBroadcastToShape = {}); +LogicalResult broadcastToGivenShape(Operation *op, PatternRewriter &rewriter, + Value input, + SmallVector broadcastToShape, + RankedTensorType broadcastType, + Value &result, + SmallVector useBroadcastToShape = {}); // Cast a tensor to a rank-equivalent tensor of unknown size, i.e. 
<1x2xf32> -> // diff --git a/test/Conversion/TorchToLinalg/broadcast.mlir b/test/Conversion/TorchToLinalg/broadcast.mlir new file mode 100644 index 000000000000..8841ba704328 --- /dev/null +++ b/test/Conversion/TorchToLinalg/broadcast.mlir @@ -0,0 +1,90 @@ +// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -canonicalize -split-input-file -mlir-print-local-scope -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: func.func @torch.aten.broadcast_to$simple_static( +// CHECK: %[[INIT_TENSOR:.*]] = tensor.empty() : tensor<3x4x2xf32> +// CHECK: %[[GENERIC:.*]] = linalg.generic +// CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>] +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]} +// CHECK-SAME: ins({{.*}} : tensor<4x2xf32>) outs({{.*}} : tensor<3x4x2xf32>) { +// CHECK: ^bb0(%[[IN:.*]]: f32, %{{.*}}: f32): +// CHECK: linalg.yield %[[IN]] : f32 +// CHECK: } -> tensor<3x4x2xf32> +func.func @torch.aten.broadcast_to$simple_static(%arg0: !torch.vtensor<[4,2],f32>) -> !torch.vtensor<[3,4,2],f32> { + %int3 = torch.constant.int 3 + %int4 = torch.constant.int 4 + %int2 = torch.constant.int 2 + %list = torch.prim.ListConstruct %int3, %int4, %int2 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %0 = torch.aten.broadcast_to %arg0, %list : !torch.vtensor<[4,2],f32>, !torch.list -> !torch.vtensor<[3,4,2],f32> + return %0 : !torch.vtensor<[3,4,2],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.broadcast_to$static_numpy_broadcast( +// CHECK: %[[INIT_TENSOR:.*]] = tensor.empty() : tensor<1x4x2xf32> +// CHECK: %[[GENERIC:.*]] = linalg.generic +// CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2) -> (d0, 0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>] +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]} +// CHECK-SAME: ins({{.*}} : tensor<1x1x2xf32>) outs({{.*}} : tensor<1x4x2xf32>) { +// CHECK: ^bb0(%[[IN:.*]]: f32, %{{.*}}: f32): +// CHECK: linalg.yield %[[IN]] : f32 +// CHECK: } -> tensor<1x4x2xf32> +func.func @torch.aten.broadcast_to$static_numpy_broadcast(%arg0: !torch.vtensor<[1,1,2],f32>) -> !torch.vtensor<[1,4,2],f32> { + %int1 = torch.constant.int 1 + %int4 = torch.constant.int 4 + %int2 = torch.constant.int 2 + %list = torch.prim.ListConstruct %int1, %int4, %int2 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %0 = torch.aten.broadcast_to %arg0, %list : !torch.vtensor<[1,1,2],f32>, !torch.list -> !torch.vtensor<[1,4,2],f32> + return %0 : !torch.vtensor<[1,4,2],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.broadcast_to$empty_input( +// CHECK: %[[INIT_TENSOR:.*]] = tensor.empty({{.*}}) : tensor +// CHECK: %[[GENERIC:.*]] = linalg.generic +// CHECK-SAME: indexing_maps = [affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>] +// CHECK-SAME: iterator_types = ["parallel"]} +// CHECK-SAME: ins({{.*}} : tensor) outs({{.*}} : tensor) { +// CHECK: ^bb0(%[[IN:.*]]: f32, %{{.*}}: f32): +// CHECK: linalg.yield %[[IN]] : f32 +// CHECK: } -> tensor +func.func @torch.aten.broadcast_to$empty_input(%arg0: !torch.vtensor<[],f32>, %arg1: !torch.int) -> !torch.vtensor<[?],f32> { + %list = torch.prim.ListConstruct %arg1 : (!torch.int) -> !torch.list + %0 = torch.aten.broadcast_to %arg0, %list : !torch.vtensor<[],f32>, !torch.list -> !torch.vtensor<[?],f32> + return %0 : !torch.vtensor<[?],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.broadcast_to$strict_dynamic_broadcast( +// CHECK: %[[INIT_TENSOR:.*]] = tensor.empty({{.*}}) : tensor +// CHECK: %[[GENERIC:.*]] = 
linalg.generic +// CHECK-SAME: indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>] +// CHECK-SAME: iterator_types = ["parallel", "parallel"]} +// CHECK-SAME: ins({{.*}} : tensor) outs({{.*}} : tensor) { +// CHECK: ^bb0(%[[IN:.*]]: f32, %{{.*}}: f32): +// CHECK: linalg.yield %[[IN]] : f32 +// CHECK: } -> tensor +func.func @torch.aten.broadcast_to$strict_dynamic_broadcast(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.int, %arg2: !torch.int) -> !torch.vtensor<[?,?],f32> attributes {torch.assume_strict_symbolic_shapes} { + %list = torch.prim.ListConstruct %arg1, %arg2 : (!torch.int, !torch.int) -> !torch.list + %0 = torch.aten.broadcast_to %arg0, %list : !torch.vtensor<[?],f32>, !torch.list -> !torch.vtensor<[?,?],f32> + return %0 : !torch.vtensor<[?,?],f32> +} + +// ----- + +/// Nothing we can do; verify we hit the fall back path. +// CHECK-LABEL: func.func @torch.aten.broadcast_to$pure_dynamic_broadcast( +// CHECK: %[[INIT_TENSOR:.*]] = tensor.empty({{.*}}) : tensor +// CHECK: %[[GENERIC:.*]] = linalg.generic +// CHECK-SAME: indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>] +// CHECK-SAME: iterator_types = ["parallel", "parallel"]} +// CHECK-SAME: outs({{.*}} : tensor) { +// CHECK: ^bb0(%[[OUT:.+]]: f32): +// CHECK: tensor.extract +func.func @torch.aten.broadcast_to$pure_dynamic_broadcast(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.int, %arg2: !torch.int) -> !torch.vtensor<[?,?],f32> { + %list = torch.prim.ListConstruct %arg1, %arg2 : (!torch.int, !torch.int) -> !torch.list + %0 = torch.aten.broadcast_to %arg0, %list : !torch.vtensor<[?],f32>, !torch.list -> !torch.vtensor<[?,?],f32> + return %0 : !torch.vtensor<[?,?],f32> +} From 26ea13ddf544321d7a58e06ce0e58a78a85ca1f6 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Fri, 6 Oct 2023 07:27:45 -0700 Subject: [PATCH 33/41] update PyTorch version to 2.2.0.dev20231006 (#2507) torch version: 2.2.0.dev20231006 torch commit hash: 20217d1426d99d0caa70e1473d89e0c834b7f35e torchvision version: 0.17.0.dev20231006 Co-authored-by: Roll PyTorch Action --- pytorch-hash.txt | 2 +- pytorch-requirements.txt | 2 +- torchvision-requirements.txt | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch-hash.txt b/pytorch-hash.txt index 943397b0a254..5a74e57f8cef 100644 --- a/pytorch-hash.txt +++ b/pytorch-hash.txt @@ -1 +1 @@ -439cba92777ff61b49d24096edfaf128fbd742ea +20217d1426d99d0caa70e1473d89e0c834b7f35e diff --git a/pytorch-requirements.txt b/pytorch-requirements.txt index ae5ae6af0a34..6f00db5ca012 100644 --- a/pytorch-requirements.txt +++ b/pytorch-requirements.txt @@ -1,3 +1,3 @@ -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre -torch==2.2.0.dev20231005 +torch==2.2.0.dev20231006 diff --git a/torchvision-requirements.txt b/torchvision-requirements.txt index 05791c8ba8bc..4a54d8b07666 100644 --- a/torchvision-requirements.txt +++ b/torchvision-requirements.txt @@ -1,3 +1,3 @@ -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre -torchvision==0.17.0.dev20231005 +torchvision==0.17.0.dev20231006 From 9b5a4afadd5df5b7f770e2524f9214741707f546 Mon Sep 17 00:00:00 2001 From: Ramiro Leal-Cavazos Date: Tue, 10 Oct 2023 11:54:54 -0500 Subject: [PATCH 34/41] Update README to include new meeting schedule (#2503) --- README.md | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index c5fa561bcd15..06c314f4423e 100644 --- a/README.md +++ b/README.md @@ -38,8 +38,17 @@ We have few paths to lower down to the Torch MLIR 
Dialect. - `#torch-mlir` channel on the LLVM [Discord](https://discord.gg/xS7Z362) - this is the most active communication channel - Github issues [here](https://github.com/llvm/torch-mlir/issues) - [`torch-mlir` section](https://llvm.discourse.group/c/projects-that-want-to-become-official-llvm-projects/torch-mlir/41) of LLVM Discourse -- Weekly meetings on Mondays 9AM PST. See [here](https://discourse.llvm.org/t/community-meeting-developer-hour-refactoring-recurring-meetings/62575) for more information. -- Weekly op office hours on Thursdays 8:30-9:30AM PST. See [here](https://discourse.llvm.org/t/announcing-torch-mlir-office-hours/63973/2) for more information. + +### Meetings + +Community Meeting / Developer Hour: +- 1st and 3rd Monday of the month at 9 am PST +- 2nd and 4th Monday of the month at 5 pm PST + +Office Hours: +- Every Thursday at 8:30 am PST + +Meeting links can be found [here](https://discourse.llvm.org/t/new-community-meeting-developer-hour-schedule/73868). ## Install torch-mlir snapshot From e649e06b7b7a74200354ce49c5863122f3287bdb Mon Sep 17 00:00:00 2001 From: Ze Zhang Date: Fri, 13 Oct 2023 18:39:41 -0700 Subject: [PATCH 35/41] Add aten.unflatten.int support and its torch-to-tosa lowering (#2509) Add aten.unflatten.int op Add its torch-to-tosa lowering Update the TorchToTosa/basic.mlir tests To test e2e tosa lowering: `python -m e2e_testing.main -v -c=tosa` --------- Co-authored-by: Ze Zhang --- e2e_testing/xfail_sets.py | 2 + .../Dialect/Torch/IR/GeneratedTorchOps.td | 24 ++++++++ lib/Conversion/TorchToTosa/TorchToTosa.cpp | 55 +++++++++++++++++++ .../Transforms/AbstractInterpLibrary.cpp | 14 +++++ lib/Dialect/Torch/Utils/Utils.cpp | 8 +-- .../build_tools/abstract_interp_lib_gen.py | 8 +++ .../jit_ir/build_tools/torch_ods_gen.py | 1 + .../torch_mlir_e2e_test/test_suite/basic.py | 22 ++++++++ test/Conversion/TorchToTosa/basic.mlir | 22 ++++++++ 9 files changed, 152 insertions(+), 4 deletions(-) diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py index fd9827772547..cce4aa8a6350 100644 --- a/e2e_testing/xfail_sets.py +++ b/e2e_testing/xfail_sets.py @@ -17,6 +17,7 @@ # Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR failed # 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8 "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", + "UnflattenStaticModule_basic", } TORCHDYNAMO_XFAIL_SET = { @@ -1056,6 +1057,7 @@ "BatchNorm3DModule_basic", "BatchNorm1DStaticShapeModule_basic", "FlattenStaticModule_basic", + "UnflattenStaticModule_basic", "FlattenRank0Module_basic", "ElementwiseFlattenBroadcastModule_basic", "SquareModule_basic", diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index f1338142d197..4f4fa561fbba 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -7537,6 +7537,30 @@ def Torch_AtenFlattenUsingIntsOp : Torch_Op<"aten.flatten.using_ints", [ }]; } +def Torch_AtenUnflattenIntOp : Torch_Op<"aten.unflatten.int", [ + AllowsTypeRefinement, + ReadOnly + ]> { + let summary = "Generated op for `aten::unflatten.int : (Tensor, int, int[]) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$dim, + AnyTorchListOfTorchIntType:$sizes + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult 
AtenUnflattenIntOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenUnflattenIntOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + def Torch_AtenDimOp : Torch_Op<"aten.dim", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 1e71f51b8598..d2adefc4d3c7 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -2525,6 +2525,60 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenUnflattenIntOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + + // Not a ranked tensor type + auto selfType = adaptor.getSelf().getType().dyn_cast(); + if (!selfType || !selfType.hasStaticShape()) + return rewriter.notifyMatchFailure( + op, + "Only ranked tensor types with static shapes are currently supported"); + + int64_t selfRank = selfType.getRank(); + int64_t dim; + + if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) + return rewriter.notifyMatchFailure(op, "dim must be a Scalar constant"); + + SmallVector sizes; + if (!matchPattern(op.getSizes(), m_TorchListOfConstantInts(sizes))) + return rewriter.notifyMatchFailure( + op, "Only constant sizes are currently supported"); + + if (selfRank > 0 && !isValidDim(dim, selfRank)) + return rewriter.notifyMatchFailure(op, "dim is statically invalid"); + + SmallVector newShape; + for (auto s : + llvm::enumerate(makeShapeTorchCompatible(selfType.getShape()))) { + int64_t idx = s.index(); + if (idx < dim || idx > dim) { + newShape.push_back(s.value()); + } else { + auto sum = 1; + for (auto newDims : sizes) { + newShape.push_back(newDims); + sum *= newDims; + } + if (sum != s.value()) + return rewriter.notifyMatchFailure(op, + "sizes mismatch with original dim"); + } + } + + auto newType = RankedTensorType::get(makeShapeLLVMCompatible(newShape), + selfType.getElementType()); + + rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(newType), adaptor.getSelf(), + rewriter.getDenseI64ArrayAttr(newShape)); + + return success(); +} + template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenPermuteOp op, OpAdaptor adaptor, @@ -5050,6 +5104,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase { INSERT_ATENOP_PATTERN(AtenBatchNormOp); INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp); INSERT_ATENOP_PATTERN(AtenFlattenUsingIntsOp); + INSERT_ATENOP_PATTERN(AtenUnflattenIntOp); INSERT_ATENOP_PATTERN(AtenPermuteOp); INSERT_ATENOP_PATTERN(AtenLog2Op); INSERT_ATENOP_PATTERN(AtenThresholdOp); diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index e8f5aa568f59..513d7b018d46 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -7205,6 +7205,16 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.flatten(%arg0, %arg1, %arg2) : (!torch.list, !torch.int, !torch.int) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.unflatten.int\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.list) -> !torch.list {\n" +" %none = torch.constant.none\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.aten.slice.t %arg0, %none, %arg1, %int1 : 
!torch.list, !torch.none, !torch.int, !torch.int -> !torch.list\n" +" %1 = torch.aten.add.t %0, %arg2 : !torch.list, !torch.list -> !torch.list\n" +" %2 = torch.aten.add.int %arg1, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %3 = torch.aten.slice.t %arg0, %2, %none, %int1 : !torch.list, !torch.int, !torch.none, !torch.int -> !torch.list\n" +" %4 = torch.aten.add.t %1, %3 : !torch.list, !torch.list -> !torch.list\n" +" return %4 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.linear\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.linear(%arg0, %arg1, %arg2) : (!torch.list, !torch.list, !torch.optional>) -> !torch.list\n" " return %0 : !torch.list\n" @@ -8580,6 +8590,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.unflatten.int\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.flip\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp index 5de777763ea5..ddc95bd4b2fd 100644 --- a/lib/Dialect/Torch/Utils/Utils.cpp +++ b/lib/Dialect/Torch/Utils/Utils.cpp @@ -199,10 +199,10 @@ bool Torch::isViewLikeOp(Operation *op) { // that it does not return a view and treat those as having value // semantics. 
return isa List[int]: return upstream_shape_functions.flatten(self, start_dim, end_dim) +def aten〇unflatten〇int〡shape(self: List[int], dim: int, sizes: List[int]) -> List[int]: + return self[:dim] + sizes + self[dim + 1:] + def aten〇linear〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None) -> List[int]: return upstream_shape_functions.linear(input, weight, bias) @@ -1656,6 +1659,11 @@ def aten〇flatten〇using_ints〡dtype(self_rank_dtype: Tuple[int, int], start_ self_rank, self_dtype = self_rank_dtype return self_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, sizes=[1])) +def aten〇unflatten〇int〡dtype(self_rank_dtype: Tuple[int, int], dim: int, sizes: List[int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dims=[0])) def aten〇flip〡dtype(self_rank_dtype: Tuple[int, int], dims: List[int]) -> int: self_rank, self_dtype = self_rank_dtype diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index f540a1ad2a7d..3916f313620b 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -516,6 +516,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::squeeze.dim : (Tensor, int) -> (Tensor)", has_folder=True) emit("aten::squeeze : (Tensor) -> (Tensor)", has_folder=True) emit("aten::flatten.using_ints : (Tensor, int, int) -> (Tensor)") + emit("aten::unflatten.int : (Tensor, int, int[]) -> (Tensor)") emit("aten::dim : (Tensor) -> (int)", has_folder=True) emit("aten::size : (Tensor) -> (int[])", has_canonicalizer=True) emit("aten::Bool.Tensor : (Tensor) -> (bool)") diff --git a/python/torch_mlir_e2e_test/test_suite/basic.py b/python/torch_mlir_e2e_test/test_suite/basic.py index 1b5f62715a30..e0269e68ce33 100644 --- a/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/python/torch_mlir_e2e_test/test_suite/basic.py @@ -304,6 +304,28 @@ def AddmmModule_differentRankBroadcastable(module, tu: TestUtils): # ============================================================================== +class UnflattenStaticModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([1, 6, 4], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.unflatten(x, 1, (2, 3)) + + +@register_test_case(module_factory=lambda: UnflattenStaticModule()) +def UnflattenStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 6, 4)) + + +# ============================================================================== + + class FlattenStaticModule(torch.nn.Module): def __init__(self): diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index 49907a98a56d..dc4e4793a67d 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -556,6 +556,28 @@ func.func @forward(%arg0: !torch.vtensor<[10,3,8,9,3,4],f32> ) -> !torch.vtensor // ----- +// CHECK-LABEL: func.func @forward( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,6,4],f32>) -> !torch.vtensor<[1,2,3,4],f32> { +// CHECK: %[[VAL:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,6,4],f32> -> tensor<1x6x4xf32> +// CHECK: %[[VAL_1:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_2:.*]] = torch.constant.int 2 +// 
CHECK: %[[VAL_3:.*]] = torch.constant.int 3 +// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %[[VAL_2]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL]] {new_shape = array} : (tensor<1x6x4xf32>) -> tensor<1x2x3x4xf32> +// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<1x2x3x4xf32> -> !torch.vtensor<[1,2,3,4],f32> +// CHECK: return %[[VAL_6]] : !torch.vtensor<[1,2,3,4],f32> +// CHECK: } +func.func @forward(%arg0: !torch.vtensor<[1,6,4],f32> ) -> !torch.vtensor<[1,2,3,4],f32> { + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %int3 = torch.constant.int 3 + %0 = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list + %1 = torch.aten.unflatten.int %arg0, %int1, %0 : !torch.vtensor<[1,6,4],f32>, !torch.int, !torch.list -> !torch.vtensor<[1,2,3,4],f32> + return %1 : !torch.vtensor<[1,2,3,4],f32> +} + +// ----- + // CHECK-LABEL: func.func @forward( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,2,2,3],f32>, // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2,2,3],f32>, From f2c53b8ca5389fc63c38a66892b0d393718c3db4 Mon Sep 17 00:00:00 2001 From: Ze Zhang Date: Mon, 16 Oct 2023 09:44:53 -0700 Subject: [PATCH 36/41] Add aten.isclose support and its torch-to-tosa lowering (#2512) Add aten.isclose op Add its torch-to-tosa lowering Update the TorchToTosa/basic.mlir tests To test e2e tosa lowering: `python -m e2e_testing.main -v -c=tosa` --------- Co-authored-by: Ze Zhang --- e2e_testing/xfail_sets.py | 4 ++ .../Dialect/Torch/IR/GeneratedTorchOps.td | 27 ++++++++++ lib/Conversion/TorchToTosa/TorchToTosa.cpp | 54 +++++++++++++++++++ .../Transforms/AbstractInterpLibrary.cpp | 8 +++ .../build_tools/abstract_interp_lib_gen.py | 7 +++ .../jit_ir/build_tools/torch_ods_gen.py | 1 + .../torch_mlir_e2e_test/test_suite/basic.py | 45 ++++++++++++++++ test/Conversion/TorchToTosa/basic.mlir | 29 ++++++++++ 8 files changed, 175 insertions(+) diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py index cce4aa8a6350..d268a31ddf9e 100644 --- a/e2e_testing/xfail_sets.py +++ b/e2e_testing/xfail_sets.py @@ -18,6 +18,8 @@ # 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8 "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", "UnflattenStaticModule_basic", + "IscloseStaticModule_basic", + "IscloseStaticModuleTrue_basic", } TORCHDYNAMO_XFAIL_SET = { @@ -928,6 +930,8 @@ # Write the TOSA set as a "passing" set as it is very early in development # and very few tests work yet. 
TOSA_PASS_SET = { + "IscloseStaticModule_basic", + "IscloseStaticModuleTrue_basic", "TileBigDimsSizeModule_basic", "TileSmallDimsSizeModule_basic", "IndexPutImpl2DNoneIndexStaticModule_basic", diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 4f4fa561fbba..0530e3082418 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -4162,6 +4162,33 @@ def Torch_AtenViewAsRealOp : Torch_Op<"aten.view_as_real", [ }]; } +def Torch_AtenIscloseOp : Torch_Op<"aten.isclose", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::isclose : (Tensor, Tensor, float, float, bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$other, + Torch_FloatType:$rtol, + Torch_FloatType:$atol, + Torch_BoolType:$equal_nan + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenIscloseOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 1); + } + void AtenIscloseOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 1); + } + }]; +} + def Torch_AtenUnbindCopyIntOp : Torch_Op<"aten.unbind_copy.int", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index d2adefc4d3c7..970ef15d8d9a 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -3920,6 +3920,59 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenIscloseOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // check args + double rtol, atol; + bool equalNan; + if (!matchPattern(op.getRtol(), m_TorchConstantFloat(&rtol))) + return rewriter.notifyMatchFailure(op, "rtol must be a scalar constant"); + if (!matchPattern(op.getAtol(), m_TorchConstantFloat(&atol))) + return rewriter.notifyMatchFailure(op, "atol must be a scalar constant"); + if (!matchPattern(op.getEqualNan(), m_TorchConstantBool(&equalNan))) + return rewriter.notifyMatchFailure( + op, "unimplemented: equal_nan is expected to be false"); + + // check tensor type. 
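+  // (Overview, as a reading aid: the arithmetic below implements
+  //  torch.isclose elementwise as
+  //    |self - other| <= atol + rtol * |other|
+  //  via tosa.sub/abs/mul/add followed by a final tosa.greater_equal.)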
+ auto selfType = adaptor.getSelf().getType().dyn_cast(); + auto otherType = adaptor.getOther().getType().dyn_cast(); + if (!selfType || !otherType) + return rewriter.notifyMatchFailure( + op, "Only tensor types input are currently supported"); + if (!selfType.hasStaticShape() || !otherType.hasStaticShape()) + return rewriter.notifyMatchFailure( + op, "Only tensor types with static shape are supported"); + if (!selfType.getElementType().isa() || + !otherType.getElementType().isa()) { + return rewriter.notifyMatchFailure( + op, "unimplemented: only FP element type is supported"); + } + + auto rhsSubOp = rewriter.create( + op->getLoc(), selfType, adaptor.getSelf(), adaptor.getOther()); + auto rhsAbsOp = + rewriter.create(op->getLoc(), selfType, rhsSubOp); + + auto lhsAbsOp = + rewriter.create(op->getLoc(), otherType, adaptor.getOther()); + auto rtolConstOp = + tosa::getTosaConstTensorSingleF32(rewriter, op, static_cast(rtol)); + auto mulOp = rewriter.create(op->getLoc(), otherType, + rtolConstOp, lhsAbsOp, /*shift=*/0); + auto atolConstOp = + tosa::getTosaConstTensorSingleF32(rewriter, op, static_cast(atol)); + auto addOp = + rewriter.create(op->getLoc(), otherType, atolConstOp, mulOp); + + auto outType = getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, outType, addOp, + rhsAbsOp); + + return success(); +} + template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenClampOp op, OpAdaptor adaptor, @@ -5134,6 +5187,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase { INSERT_ATENOP_PATTERN(AtenRemainderScalarOp); INSERT_ATENOP_PATTERN(AtenCatOp); INSERT_ATENOP_PATTERN(AtenSqrtOp); + INSERT_ATENOP_PATTERN(AtenIscloseOp); #undef INSERT_ATENOP_PATTERN #define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \ diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 513d7b018d46..47d76219f199 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -7480,6 +7480,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.isclose\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.float, %arg3: !torch.float, %arg4: !torch.bool) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.unsqueeze\"(%arg0: !torch.list, %arg1: !torch.int) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unsqueeze(%arg0, %arg1) : (!torch.list, !torch.int) -> !torch.list\n" " return %0 : !torch.list\n" @@ -9093,6 +9097,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %int11 = torch.constant.int 11\n" " return %int11 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.isclose\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.float, %arg3: !torch.float, %arg4: !torch.bool) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" return %int11 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.scaled_dot_product_attention\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.optional>, %arg4: !torch.float, %arg5: !torch.bool, %arg6: !torch.optional) -> !torch.int {\n" " %0:2 = 
torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py index 00e752f01d95..958df70d575a 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py @@ -844,6 +844,9 @@ def aten〇lt〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]: def aten〇le〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]: return upstream_shape_functions.broadcast(self, other) +def aten〇isclose〡shape(self: List[int], other: List[int], rtol: float = 1.0000000000000001e-05, atol: float = 1e-08, equal_nan: bool = False) -> List[int]: + return upstream_shape_functions.broadcast(self, other) + def aten〇unsqueeze〡shape(self: List[int], dim: int) -> List[int]: return upstream_shape_functions.unsqueeze(self, dim) @@ -2171,6 +2174,10 @@ def aten〇logical_and〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtyp def aten〇logical_not〡dtype(self_rank_dtype: Tuple[int, int]) -> int: return torch.bool +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=2)) +def aten〇isclose〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int], rtol: float = 1.0000000000000001e-05, atol: float = 1e-08, equal_nan: bool = False) -> int: + return torch.bool + @check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(3, 4, 32, 16), (3, 4, 32, 16), (3, 4, 32, 16)])) def aten〇scaled_dot_product_attention〡dtype(query_rank_dtype: Tuple[int, int], key_rank_dtype: Tuple[int, int], value_rank_dtype: Tuple[int, int], attn_mask_rank_dtype: Optional[Tuple[int, int]] = None, dropout_p: float = 0., is_causal: bool = False, scale: Optional[float] = None) -> int: _, query_dtype = query_rank_dtype diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index 3916f313620b..e473603ff6da 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -342,6 +342,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::imag : (Tensor) -> (Tensor)") emit("aten::view_as_complex : (Tensor) -> (Tensor)") emit("aten::view_as_real : (Tensor) -> (Tensor)") + emit("aten::isclose : (Tensor, Tensor, float, float, bool) -> (Tensor)") # Ops with dynamic number of outputs emit("aten::unbind_copy.int : (Tensor, int) -> (Tensor[])") diff --git a/python/torch_mlir_e2e_test/test_suite/basic.py b/python/torch_mlir_e2e_test/test_suite/basic.py index e0269e68ce33..d78253a58fc3 100644 --- a/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/python/torch_mlir_e2e_test/test_suite/basic.py @@ -4580,3 +4580,48 @@ def forward(self, x): @register_test_case(module_factory=lambda: Add_Module()) def Add_Module_basic(module, tu: TestUtils): module.forward(tu.rand(2, 3)) + + +# ============================================================================== + + +class IscloseStaticModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([5, 5], torch.float32, True), + ([5, 5], torch.float32, True), + ]) + def forward(self, x, y): + return torch.isclose(x, y) + + 
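+# Note: torch.isclose above runs with its default tolerances
+# (rtol=1e-05, atol=1e-08), matching the defaults in the shape and
+# dtype functions registered earlier in this patch.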
+@register_test_case(module_factory=lambda: IscloseStaticModule()) +def IscloseStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 5), tu.rand(5, 5)) + + +# ============================================================================== + + +class IscloseStaticModuleTrue(torch.nn.Module): + + def __init__(self): + super().__init__() + self.register_buffer('tensor', torch.ones(1)) + + @export + @annotate_args([ + None, + ([5, 5], torch.float32, True), + ]) + def forward(self, x): + return torch.isclose(x, self.tensor) + +@register_test_case(module_factory=lambda: IscloseStaticModuleTrue()) +def IscloseStaticModuleTrue_basic(module, tu: TestUtils): + module.forward(torch.ones(5, 5)) diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index dc4e4793a67d..46023598c423 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -1155,3 +1155,32 @@ func.func @torch.aten.remainder.Scalar(%arg0: !torch.vtensor<[2, 4],f32>) -> !to %0 = torch.aten.remainder.Scalar %arg0, %int2 : !torch.vtensor<[2, 4],f32>, !torch.int -> !torch.vtensor<[2, 4],f32> return %0 : !torch.vtensor<[2, 4],f32> } + +// ----- + +// CHECK-LABEL: func.func @forward( +// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[5,5],f32>, +// CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[5,5],f32>) -> !torch.vtensor<[5,5],i1> { +// CHECK: %[[VAL_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[5,5],f32> -> tensor<5x5xf32> +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[5,5],f32> -> tensor<5x5xf32> +// CHECK: %[[ATOL:.*]] = torch.constant.float 1.000000e-08 +// CHECK: %[[RTOL:.*]] = torch.constant.float 1.000000e-05 +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: %[[VAL_2:.*]] = tosa.sub %[[VAL_0]], %[[VAL_1]] : (tensor<5x5xf32>, tensor<5x5xf32>) -> tensor<5x5xf32> +// CHECK: %[[VAL_3:.*]] = tosa.abs %[[VAL_2]] : (tensor<5x5xf32>) -> tensor<5x5xf32> +// CHECK: %[[VAL_4:.*]] = tosa.abs %[[VAL_1]] : (tensor<5x5xf32>) -> tensor<5x5xf32> +// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<9.99999974E-6> : tensor}> : () -> tensor +// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_5]], %[[VAL_4]] {shift = 0 : i32} : (tensor, tensor<5x5xf32>) -> tensor<5x5xf32> +// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<9.99999993E-9> : tensor}> : () -> tensor +// CHECK: %[[VAL_8:.*]] = tosa.add %[[VAL_7]], %[[VAL_6]] : (tensor, tensor<5x5xf32>) -> tensor<5x5xf32> +// CHECK: %[[VAL_9:.*]] = tosa.greater_equal %[[VAL_8]], %[[VAL_3]] : (tensor<5x5xf32>, tensor<5x5xf32>) -> tensor<5x5xi1> +// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<5x5xi1> -> !torch.vtensor<[5,5],i1> +// CHECK: return %[[VAL_10]] : !torch.vtensor<[5,5],i1> +// CHECK: } +func.func @forward(%arg0: !torch.vtensor<[5,5],f32>, %arg1: !torch.vtensor<[5,5],f32>) -> !torch.vtensor<[5,5],i1> { + %float1.000000e-08 = torch.constant.float 1.000000e-08 + %float1.000000e-05 = torch.constant.float 1.000000e-05 + %false = torch.constant.bool false + %0 = torch.aten.isclose %arg0, %arg1, %float1.000000e-05, %float1.000000e-08, %false : !torch.vtensor<[5,5],f32>, !torch.vtensor<[5,5],f32>, !torch.float, !torch.float, !torch.bool -> !torch.vtensor<[5,5],i1> + return %0 : !torch.vtensor<[5,5],i1> +} From 14a4da923bae4db80db32e5a48a7c6cf501709cd Mon Sep 17 00:00:00 2001 From: Chi_Liu Date: Mon, 16 Oct 2023 19:29:48 -0700 Subject: [PATCH 37/41] Update llvm-project to b44b3494f60296db6aca38a14cab061d9b747a0a (#2511) The main 
purpose is to bring in the new mesh dialect change. https://github.com/llvm/llvm-project/pull/68007 --- externals/llvm-project | 2 +- lib/RefBackend/CMakeLists.txt | 1 + test/Conversion/TorchToTosa/basic.mlir | 162 +++++++++--------- ...orch-backend-to-tosa-backend-pipeline.mlir | 24 +-- 4 files changed, 98 insertions(+), 91 deletions(-) diff --git a/externals/llvm-project b/externals/llvm-project index d13da154a7c7..b44b3494f602 160000 --- a/externals/llvm-project +++ b/externals/llvm-project @@ -1 +1 @@ -Subproject commit d13da154a7c7eff77df8686b2de1cfdfa7cc7029 +Subproject commit b44b3494f60296db6aca38a14cab061d9b747a0a diff --git a/lib/RefBackend/CMakeLists.txt b/lib/RefBackend/CMakeLists.txt index 2ef5dab3ae8d..a8ed0439d815 100644 --- a/lib/RefBackend/CMakeLists.txt +++ b/lib/RefBackend/CMakeLists.txt @@ -15,6 +15,7 @@ add_mlir_library(TorchMLIRRefBackend MLIRIR MLIRTransforms MLIRMathTransforms + MLIRLinalgTransforms ) mlir_check_all_link_libraries(TorchMLIRRefBackend) diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index 46023598c423..f04109873336 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -41,16 +41,16 @@ func.func @torch.aten.relu$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vte // ----- // CHECK-LABEL: func.func @torch.aten.leaky_relu$basic( -// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { -// CHECK: %[[VAL_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[VAL_1:.*]] = torch.constant.float 1.000000e-01 -// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<1.000000e-01> : tensor}> : () -> tensor -// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_4:.*]] = tosa.greater_equal %[[VAL_0]], %[[VAL_3]] : (tensor, tensor) -> tensor -// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_0]], %[[VAL_2]] {shift = 0 : i32} : (tensor, tensor) -> tensor -// CHECK: %[[VAL_6:.*]] = tosa.select %[[VAL_4]], %[[VAL_0]], %[[VAL_5]] : (tensor, tensor, tensor) -> tensor -// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor -> !torch.vtensor<[?,?],f32> -// CHECK: return %[[VAL_7]] : !torch.vtensor<[?,?],f32> +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_2:.*]] = torch.constant.float 1.000000e-01 +// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<1.000000e-01> : tensor}> : () -> tensor +// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor}> : () -> tensor +// CHECK: %[[VAL_5:.*]] = tosa.greater_equal %[[VAL_1]], %[[VAL_4]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_1]], %[[VAL_3]] {shift = 0 : i8} : (tensor, tensor) -> tensor +// CHECK: %[[VAL_7:.*]] = tosa.select %[[VAL_5]], %[[VAL_1]], %[[VAL_6]] : (tensor, tensor, tensor) -> tensor +// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32> // CHECK: } func.func @torch.aten.leaky_relu$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { %fp0 = torch.constant.float 1.000000e-01 @@ -155,13 +155,13 @@ func.func @torch.aten.reciprocal$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !tor // ----- // CHECK-LABEL: func.func @torch.aten.add$basic( -// 
CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>, -// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_4:.*]] = torch.constant.int 1 // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i32} : (tensor, tensor) -> tensor +// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i8} : (tensor, tensor) -> tensor // CHECK: %[[VAL_7:.*]] = tosa.add %[[VAL_2]], %[[VAL_6]] : (tensor, tensor) -> tensor // CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32> @@ -175,13 +175,13 @@ func.func @torch.aten.add$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch // ----- // CHECK-LABEL: func.func @torch.aten.sub$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>, -// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_4:.*]] = torch.constant.int 1 // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i32} : (tensor, tensor) -> tensor +// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i8} : (tensor, tensor) -> tensor // CHECK: %[[VAL_7:.*]] = tosa.sub %[[VAL_2]], %[[VAL_6]] : (tensor, tensor) -> tensor // CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32> @@ -195,13 +195,14 @@ func.func @torch.aten.sub$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch // ----- // CHECK-LABEL: func.func @torch.aten.mul$basic( -// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, -// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { -// CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[ARG1_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[RESULT_BUILTIN:.*]] = tosa.mul %[[ARG0_BUILTIN]], %[[ARG1_BUILTIN]] {shift = 0 : i32} : (tensor, tensor) -> tensor -// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor -> !torch.vtensor<[?,?],f32> -// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32> +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_4:.*]] = 
tosa.mul %[[VAL_2]], %[[VAL_3]] {shift = 0 : i8} : (tensor, tensor) -> tensor +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],f32> +// CHECK: } func.func @torch.aten.mul$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> { %0 = torch.aten.mul.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],f32> -> !torch.vtensor<[?, ?],f32> return %0 : !torch.vtensor<[?, ?],f32> @@ -210,14 +211,15 @@ func.func @torch.aten.mul$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch // ----- // CHECK-LABEL: func.func @torch.aten.div$basic( -// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, -// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { -// CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[ARG1_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[RCP:.*]] = tosa.reciprocal %[[ARG1_BUILTIN]] : (tensor) -> tensor -// CHECK: %[[RESULT_BUILTIN:.*]] = tosa.mul %[[ARG0_BUILTIN]], %[[RCP]] {shift = 0 : i32} : (tensor, tensor) -> tensor -// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor -> !torch.vtensor<[?,?],f32> -// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32> +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_4:.*]] = tosa.reciprocal %[[VAL_3]] : (tensor) -> tensor +// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_2]], %[[VAL_4]] {shift = 0 : i8} : (tensor, tensor) -> tensor +// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_6]] : !torch.vtensor<[?,?],f32> +// CHECK: } func.func @torch.aten.div$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> { %0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],f32> -> !torch.vtensor<[?, ?],f32> return %0 : !torch.vtensor<[?, ?],f32> @@ -394,13 +396,13 @@ func.func @torch.aten.pow.Tensor_Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) // ----- // CHECK-LABEL: func.func @torch.aten.rsub.Scalar$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_2:.*]] = torch.constant.float 3.123400e+00 // CHECK: %[[VAL_3:.*]] = torch.constant.float 6.432100e+00 // CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<3.123400e+00> : tensor}> : () -> tensor // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<6.432100e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_1]], %[[VAL_5]] {shift = 0 : i32} : (tensor, tensor) -> tensor +// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_1]], %[[VAL_5]] {shift = 0 : i8} : (tensor, tensor) -> tensor // CHECK: %[[VAL_7:.*]] = tosa.sub %[[VAL_4]], %[[VAL_6]] : (tensor, tensor) -> tensor // CHECK: %[[VAL_8:.*]] = 
torch_c.from_builtin_tensor %[[VAL_7]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32> @@ -415,13 +417,13 @@ func.func @torch.aten.rsub.Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !to // ----- // CHECK-LABEL: func.func @torch.aten.rsub.Scalar$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_2:.*]] = torch.constant.float 3.123400e+00 // CHECK: %[[VAL_3:.*]] = torch.constant.int 1 // CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<3.123400e+00> : tensor}> : () -> tensor // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_1]], %[[VAL_5]] {shift = 0 : i32} : (tensor, tensor) -> tensor +// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_1]], %[[VAL_5]] {shift = 0 : i8} : (tensor, tensor) -> tensor // CHECK: %[[VAL_7:.*]] = tosa.sub %[[VAL_4]], %[[VAL_6]] : (tensor, tensor) -> tensor // CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32> @@ -502,7 +504,7 @@ func.func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !to // ----- // CHECK-LABEL: func.func @torch.aten.native_batch_norm$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[10,4,3],f32>) -> !torch.vtensor<[10,4,3],f32> { +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[10,4,3],f32>) -> !torch.vtensor<[10,4,3],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[10,4,3],f32> -> tensor<10x4x3xf32> // CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<[5.000000e-01, 4.000000e-01, 3.000000e-01, 6.000000e-01]> : tensor<4xf32>}> : () -> tensor<4xf32> // CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<[3.000000e+00, 2.000000e+00, 4.000000e+00, 5.000000e+00]> : tensor<4xf32>}> : () -> tensor<4xf32> @@ -518,8 +520,8 @@ func.func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !to // CHECK: %[[VAL_13:.*]] = tosa.sub %[[VAL_1]], %[[VAL_8]] : (tensor<10x4x3xf32>, tensor<4x1xf32>) -> tensor<10x4x3xf32> // CHECK: %[[VAL_14:.*]] = tosa.add %[[VAL_9]], %[[VAL_12]] : (tensor<4x1xf32>, tensor) -> tensor<4x1xf32> // CHECK: %[[VAL_15:.*]] = tosa.rsqrt %[[VAL_14]] : (tensor<4x1xf32>) -> tensor<4x1xf32> -// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_13]], %[[VAL_15]] {shift = 0 : i32} : (tensor<10x4x3xf32>, tensor<4x1xf32>) -> tensor<10x4x3xf32> -// CHECK: %[[VAL_17:.*]] = tosa.mul %[[VAL_16]], %[[VAL_10]] {shift = 0 : i32} : (tensor<10x4x3xf32>, tensor<4x1xf32>) -> tensor<10x4x3xf32> +// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_13]], %[[VAL_15]] {shift = 0 : i8} : (tensor<10x4x3xf32>, tensor<4x1xf32>) -> tensor<10x4x3xf32> +// CHECK: %[[VAL_17:.*]] = tosa.mul %[[VAL_16]], %[[VAL_10]] {shift = 0 : i8} : (tensor<10x4x3xf32>, tensor<4x1xf32>) -> tensor<10x4x3xf32> // CHECK: %[[VAL_18:.*]] = tosa.add %[[VAL_17]], %[[VAL_11]] : (tensor<10x4x3xf32>, tensor<4x1xf32>) -> tensor<10x4x3xf32> // CHECK: %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<10x4x3xf32> -> !torch.vtensor<[10,4,3],f32> // CHECK: return %[[VAL_19]] : !torch.vtensor<[10,4,3],f32> @@ -538,7 +540,7 @@ func.func @torch.aten.native_batch_norm$basic(%arg0: !torch.vtensor<[10,4,3],f32 // ----- // CHECK-LABEL: func.func 
@forward( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[10,3,8,9,3,4],f32>) -> !torch.vtensor<[10,3,?,4],f32> { +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[10,3,8,9,3,4],f32>) -> !torch.vtensor<[10,3,?,4],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[10,3,8,9,3,4],f32> -> tensor<10x3x8x9x3x4xf32> // CHECK: %[[VAL_2:.*]] = torch.constant.int 4 // CHECK: %[[VAL_3:.*]] = torch.constant.int 2 @@ -579,9 +581,9 @@ func.func @forward(%arg0: !torch.vtensor<[1,6,4],f32> ) -> !torch.vtensor<[1,2,3 // ----- // CHECK-LABEL: func.func @forward( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,2,2,3],f32>, -// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2,2,3],f32>, -// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[2,2,3],f32>) -> !torch.vtensor<[5,2,2,3],f32> { +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,2,2,3],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2,2,3],f32>, +// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[2,2,3],f32>) -> !torch.vtensor<[5,2,2,3],f32> { // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[5,2,2,3],f32> -> tensor<5x2x2x3xf32> // CHECK: %[[VAL_4:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2,2,3],f32> -> tensor<2x2x3xf32> // CHECK: %[[VAL_5:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[2,2,3],f32> -> tensor<2x2x3xf32> @@ -595,22 +597,22 @@ func.func @forward(%arg0: !torch.vtensor<[1,6,4],f32> ) -> !torch.vtensor<[1,2,3 // CHECK: %[[VAL_13:.*]] = tosa.reduce_sum %[[VAL_12]] {axis = 2 : i32} : (tensor<5x2x2x1xf32>) -> tensor<5x2x1x1xf32> // CHECK: %[[VAL_14:.*]] = tosa.reduce_sum %[[VAL_13]] {axis = 1 : i32} : (tensor<5x2x1x1xf32>) -> tensor<5x1x1x1xf32> // CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_14]] {new_shape = array} : (tensor<5x1x1x1xf32>) -> tensor<5x1x1x1xf32> -// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_15]], %[[VAL_11]] {shift = 0 : i32} : (tensor<5x1x1x1xf32>, tensor<1xf32>) -> tensor<5x1x1x1xf32> +// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_15]], %[[VAL_11]] {shift = 0 : i8} : (tensor<5x1x1x1xf32>, tensor<1xf32>) -> tensor<5x1x1x1xf32> // CHECK: %[[VAL_17:.*]] = tosa.sub %[[VAL_3]], %[[VAL_16]] : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32> -// CHECK: %[[VAL_18:.*]] = tosa.mul %[[VAL_17]], %[[VAL_17]] {shift = 0 : i32} : (tensor<5x2x2x3xf32>, tensor<5x2x2x3xf32>) -> tensor<5x2x2x3xf32> +// CHECK: %[[VAL_18:.*]] = tosa.mul %[[VAL_17]], %[[VAL_17]] {shift = 0 : i8} : (tensor<5x2x2x3xf32>, tensor<5x2x2x3xf32>) -> tensor<5x2x2x3xf32> // CHECK: %[[VAL_19:.*]] = tosa.reduce_sum %[[VAL_18]] {axis = 3 : i32} : (tensor<5x2x2x3xf32>) -> tensor<5x2x2x1xf32> // CHECK: %[[VAL_20:.*]] = tosa.reduce_sum %[[VAL_19]] {axis = 2 : i32} : (tensor<5x2x2x1xf32>) -> tensor<5x2x1x1xf32> // CHECK: %[[VAL_21:.*]] = tosa.reduce_sum %[[VAL_20]] {axis = 1 : i32} : (tensor<5x2x1x1xf32>) -> tensor<5x1x1x1xf32> // CHECK: %[[VAL_22:.*]] = tosa.reshape %[[VAL_21]] {new_shape = array} : (tensor<5x1x1x1xf32>) -> tensor<5x1x1x1xf32> -// CHECK: %[[VAL_23:.*]] = tosa.mul %[[VAL_22]], %[[VAL_11]] {shift = 0 : i32} : (tensor<5x1x1x1xf32>, tensor<1xf32>) -> tensor<5x1x1x1xf32> +// CHECK: %[[VAL_23:.*]] = tosa.mul %[[VAL_22]], %[[VAL_11]] {shift = 0 : i8} : (tensor<5x1x1x1xf32>, tensor<1xf32>) -> tensor<5x1x1x1xf32> // CHECK: %[[VAL_24:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor<2x2x3xf32>) -> tensor<1x2x2x3xf32> // CHECK: %[[VAL_25:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array} : (tensor<2x2x3xf32>) -> tensor<1x2x2x3xf32> // CHECK: %[[VAL_26:.*]] = 
"tosa.const"() <{value = dense<5.000000e-01> : tensor}> : () -> tensor // CHECK: %[[VAL_27:.*]] = tosa.sub %[[VAL_3]], %[[VAL_16]] : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32> // CHECK: %[[VAL_28:.*]] = tosa.add %[[VAL_23]], %[[VAL_26]] : (tensor<5x1x1x1xf32>, tensor) -> tensor<5x1x1x1xf32> // CHECK: %[[VAL_29:.*]] = tosa.rsqrt %[[VAL_28]] : (tensor<5x1x1x1xf32>) -> tensor<5x1x1x1xf32> -// CHECK: %[[VAL_30:.*]] = tosa.mul %[[VAL_27]], %[[VAL_29]] {shift = 0 : i32} : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32> -// CHECK: %[[VAL_31:.*]] = tosa.mul %[[VAL_30]], %[[VAL_24]] {shift = 0 : i32} : (tensor<5x2x2x3xf32>, tensor<1x2x2x3xf32>) -> tensor<5x2x2x3xf32> +// CHECK: %[[VAL_30:.*]] = tosa.mul %[[VAL_27]], %[[VAL_29]] {shift = 0 : i8} : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32> +// CHECK: %[[VAL_31:.*]] = tosa.mul %[[VAL_30]], %[[VAL_24]] {shift = 0 : i8} : (tensor<5x2x2x3xf32>, tensor<1x2x2x3xf32>) -> tensor<5x2x2x3xf32> // CHECK: %[[VAL_32:.*]] = tosa.add %[[VAL_31]], %[[VAL_25]] : (tensor<5x2x2x3xf32>, tensor<1x2x2x3xf32>) -> tensor<5x2x2x3xf32> // CHECK: %[[VAL_33:.*]] = torch_c.from_builtin_tensor %[[VAL_32]] : tensor<5x2x2x3xf32> -> !torch.vtensor<[5,2,2,3],f32> // CHECK: return %[[VAL_33]] : !torch.vtensor<[5,2,2,3],f32> @@ -683,12 +685,12 @@ func.func @torch.aten.bitwise_and.Tensor$basic(%arg0: !torch.vtensor<[?,?],si32> // ----- // CHECK-LABEL: func.func @torch.aten.log2$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<0.693147182> : tensor<1x1xf32>}> : () -> tensor<1x1xf32> // CHECK: %[[VAL_3:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor<1x1xf32>) -> tensor<1x1xf32> // CHECK: %[[VAL_4:.*]] = tosa.log %[[VAL_1]] : (tensor) -> tensor -// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_4]], %[[VAL_3]] {shift = 0 : i32} : (tensor, tensor<1x1xf32>) -> tensor +// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_4]], %[[VAL_3]] {shift = 0 : i8} : (tensor, tensor<1x1xf32>) -> tensor // CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[VAL_6]] : !torch.vtensor<[?,?],f32> // CHECK: } @@ -996,7 +998,7 @@ func.func @torch.aten.to.dtype(%arg0: !torch.vtensor<[1,128],i1>) -> !torch.vten // CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor<1x4x3xf32>) -> tensor<1x12x1xf32> // CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_10]] {new_shape = array} : (tensor<1x4x2x3xi32>) -> tensor<8x3xi32> // CHECK: %[[VAL_13:.*]] = "tosa.const"() <{value = dense<[12, 3, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> -// CHECK: %[[VAL_14:.*]] = tosa.mul %[[VAL_12]], %[[VAL_13]] {shift = 0 : i32} : (tensor<8x3xi32>, tensor<3xi32>) -> tensor<8x3xi32> +// CHECK: %[[VAL_14:.*]] = tosa.mul %[[VAL_12]], %[[VAL_13]] {shift = 0 : i8} : (tensor<8x3xi32>, tensor<3xi32>) -> tensor<8x3xi32> // CHECK: %[[VAL_15:.*]] = tosa.reduce_sum %[[VAL_14]] {axis = 1 : i32} : (tensor<8x3xi32>) -> tensor<8x1xi32> // CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_15]] {new_shape = array} : (tensor<8x1xi32>) -> tensor<1x8xi32> // CHECK: %[[VAL_17:.*]] = tosa.gather %[[VAL_11]], %[[VAL_16]] : (tensor<1x12x1xf32>, tensor<1x8xi32>) -> tensor<1x8x1xf32> @@ -1019,7 +1021,7 @@ func.func @torch.aten.gather(%arg0: 
!torch.vtensor<[1,4,3],f32>, %arg1: !torch.v // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2,2],si32> -> tensor<2x2xi32> // CHECK: %[[VAL_4:.*]] = torch.constant.int 1 // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<1> : tensor}> : () -> tensor -// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i32} : (tensor<2x2xi32>, tensor) -> tensor<2x2xi32> +// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i8} : (tensor<2x2xi32>, tensor) -> tensor<2x2xi32> // CHECK: %[[VAL_7:.*]] = tosa.add %[[VAL_2]], %[[VAL_6]] : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> // CHECK: %[[VAL_8:.*]] = tosa.cast %[[VAL_7]] : (tensor<2x2xi32>) -> tensor<2x2xi64> // CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<2x2xi64> -> !torch.vtensor<[2,2],si64> @@ -1039,7 +1041,7 @@ func.func @torch.aten.add$basic(%arg0: !torch.vtensor<[2, 2],si32>, %arg1: !torc // CHECK: %[[VAL_3:.*]] = torch.constant.int 256 // CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<256> : tensor}> : () -> tensor // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<1> : tensor}> : () -> tensor -// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_4]], %[[VAL_5]] {shift = 0 : i32} : (tensor, tensor) -> tensor +// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_4]], %[[VAL_5]] {shift = 0 : i8} : (tensor, tensor) -> tensor // CHECK: %[[VAL_7:.*]] = tosa.cast %[[VAL_1]] : (tensor<1x1x128x128xi64>) -> tensor<1x1x128x128xi32> // CHECK: %[[VAL_8:.*]] = tosa.add %[[VAL_7]], %[[VAL_6]] : (tensor<1x1x128x128xi32>, tensor) -> tensor<1x1x128x128xi32> // CHECK: %[[VAL_9:.*]] = tosa.cast %[[VAL_8]] : (tensor<1x1x128x128xi32>) -> tensor<1x1x128x128xi64> @@ -1138,17 +1140,17 @@ func.func @torch.aten.where.self(%arg0: !torch.vtensor<[1,1,5,5],i1>, %arg1: !to // ----- // CHECK-LABEL: func.func @torch.aten.remainder.Scalar( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,4],f32>) -> !torch.vtensor<[2,4],f32> { -// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,4],f32> -> tensor<2x4xf32> -// CHECK: %[[VAL_4:.*]] = torch.constant.int 2 -// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_6:.*]] = tosa.reciprocal %[[VAL_5:.*]] : (tensor) -> tensor -// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_3:.*]], %[[VAL_6:.*]] {shift = 0 : i32} : (tensor<2x4xf32>, tensor) -> tensor<2x4xf32> -// CHECK: %[[VAL_8:.*]] = tosa.floor %[[VAL_7]] : (tensor<2x4xf32>) -> tensor<2x4xf32> -// CHECK: %[[VAL_9:.*]] = tosa.mul %[[VAL_5]], %[[VAL_8]] {shift = 0 : i32} : (tensor, tensor<2x4xf32>) -> tensor<2x4xf32> -// CHECK: %[[VAL_10:.*]] = tosa.sub %[[VAL_3]], %[[VAL_9]] : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> -// CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_10]] : tensor<2x4xf32> -> !torch.vtensor<[2,4],f32> -// CHECK: return %[[VAL_11]] : !torch.vtensor<[2,4],f32> +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,4],f32>) -> !torch.vtensor<[2,4],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,4],f32> -> tensor<2x4xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 2 +// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor}> : () -> tensor +// CHECK: %[[VAL_4:.*]] = tosa.reciprocal %[[VAL_3]] : (tensor) -> tensor +// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_1]], %[[VAL_4]] {shift = 0 : i8} : (tensor<2x4xf32>, tensor) -> tensor<2x4xf32> +// CHECK: %[[VAL_6:.*]] = tosa.floor %[[VAL_5]] : (tensor<2x4xf32>) -> 
tensor<2x4xf32> +// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_3]], %[[VAL_6]] {shift = 0 : i8} : (tensor, tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_8:.*]] = tosa.sub %[[VAL_1]], %[[VAL_7]] : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<2x4xf32> -> !torch.vtensor<[2,4],f32> +// CHECK: return %[[VAL_9]] : !torch.vtensor<[2,4],f32> // CHECK: } func.func @torch.aten.remainder.Scalar(%arg0: !torch.vtensor<[2, 4],f32>) -> !torch.vtensor<[2, 4],f32> { %int2 = torch.constant.int 2 @@ -1159,23 +1161,23 @@ func.func @torch.aten.remainder.Scalar(%arg0: !torch.vtensor<[2, 4],f32>) -> !to // ----- // CHECK-LABEL: func.func @forward( -// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[5,5],f32>, -// CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[5,5],f32>) -> !torch.vtensor<[5,5],i1> { -// CHECK: %[[VAL_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[5,5],f32> -> tensor<5x5xf32> -// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[5,5],f32> -> tensor<5x5xf32> -// CHECK: %[[ATOL:.*]] = torch.constant.float 1.000000e-08 -// CHECK: %[[RTOL:.*]] = torch.constant.float 1.000000e-05 -// CHECK: %[[FALSE:.*]] = torch.constant.bool false -// CHECK: %[[VAL_2:.*]] = tosa.sub %[[VAL_0]], %[[VAL_1]] : (tensor<5x5xf32>, tensor<5x5xf32>) -> tensor<5x5xf32> -// CHECK: %[[VAL_3:.*]] = tosa.abs %[[VAL_2]] : (tensor<5x5xf32>) -> tensor<5x5xf32> -// CHECK: %[[VAL_4:.*]] = tosa.abs %[[VAL_1]] : (tensor<5x5xf32>) -> tensor<5x5xf32> -// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<9.99999974E-6> : tensor}> : () -> tensor -// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_5]], %[[VAL_4]] {shift = 0 : i32} : (tensor, tensor<5x5xf32>) -> tensor<5x5xf32> -// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<9.99999993E-9> : tensor}> : () -> tensor -// CHECK: %[[VAL_8:.*]] = tosa.add %[[VAL_7]], %[[VAL_6]] : (tensor, tensor<5x5xf32>) -> tensor<5x5xf32> -// CHECK: %[[VAL_9:.*]] = tosa.greater_equal %[[VAL_8]], %[[VAL_3]] : (tensor<5x5xf32>, tensor<5x5xf32>) -> tensor<5x5xi1> -// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<5x5xi1> -> !torch.vtensor<[5,5],i1> -// CHECK: return %[[VAL_10]] : !torch.vtensor<[5,5],i1> +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,5],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[5,5],f32>) -> !torch.vtensor<[5,5],i1> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[5,5],f32> -> tensor<5x5xf32> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[5,5],f32> -> tensor<5x5xf32> +// CHECK: %[[VAL_4:.*]] = torch.constant.float 1.000000e-08 +// CHECK: %[[VAL_5:.*]] = torch.constant.float 1.000000e-05 +// CHECK: %[[VAL_6:.*]] = torch.constant.bool false +// CHECK: %[[VAL_7:.*]] = tosa.sub %[[VAL_2]], %[[VAL_3]] : (tensor<5x5xf32>, tensor<5x5xf32>) -> tensor<5x5xf32> +// CHECK: %[[VAL_8:.*]] = tosa.abs %[[VAL_7]] : (tensor<5x5xf32>) -> tensor<5x5xf32> +// CHECK: %[[VAL_9:.*]] = tosa.abs %[[VAL_3]] : (tensor<5x5xf32>) -> tensor<5x5xf32> +// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<9.99999974E-6> : tensor}> : () -> tensor +// CHECK: %[[VAL_11:.*]] = tosa.mul %[[VAL_10]], %[[VAL_9]] {shift = 0 : i8} : (tensor, tensor<5x5xf32>) -> tensor<5x5xf32> +// CHECK: %[[VAL_12:.*]] = "tosa.const"() <{value = dense<9.99999993E-9> : tensor}> : () -> tensor +// CHECK: %[[VAL_13:.*]] = tosa.add %[[VAL_12]], %[[VAL_11]] : (tensor, tensor<5x5xf32>) -> tensor<5x5xf32> +// CHECK: 
%[[VAL_14:.*]] = tosa.greater_equal %[[VAL_13]], %[[VAL_8]] : (tensor<5x5xf32>, tensor<5x5xf32>) -> tensor<5x5xi1> +// CHECK: %[[VAL_15:.*]] = torch_c.from_builtin_tensor %[[VAL_14]] : tensor<5x5xi1> -> !torch.vtensor<[5,5],i1> +// CHECK: return %[[VAL_15]] : !torch.vtensor<[5,5],i1> // CHECK: } func.func @forward(%arg0: !torch.vtensor<[5,5],f32>, %arg1: !torch.vtensor<[5,5],f32>) -> !torch.vtensor<[5,5],i1> { %float1.000000e-08 = torch.constant.float 1.000000e-08 diff --git a/test/Conversion/TorchToTosa/torch-backend-to-tosa-backend-pipeline.mlir b/test/Conversion/TorchToTosa/torch-backend-to-tosa-backend-pipeline.mlir index c1d1d915b017..5813cd4351ac 100644 --- a/test/Conversion/TorchToTosa/torch-backend-to-tosa-backend-pipeline.mlir +++ b/test/Conversion/TorchToTosa/torch-backend-to-tosa-backend-pipeline.mlir @@ -1,9 +1,11 @@ // RUN: torch-mlir-opt -pass-pipeline='builtin.module(torch-backend-to-tosa-backend-pipeline)' -split-input-file -verify-diagnostics %s | FileCheck %s -// CHECK-LABEL: torch.aten.mul.Scalar$mixed_type -// CHECK-SAME: %[[VAL_0:.*]]: tensor<5xbf16> -// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor<1xbf16>}> : () -> tensor<1xbf16> -// CHECK: %[[VAL_2:.*]] = tosa.mul %[[VAL_0]], %[[VAL_1]] {shift = 0 : i32} : (tensor<5xbf16>, tensor<1xbf16>) -> tensor<5xbf16> +// CHECK-LABEL: func.func @torch.aten.mul.Scalar$mixed_type( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<5xbf16>) -> tensor<5xbf16> { +// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor<1xbf16>}> : () -> tensor<1xbf16> +// CHECK: %[[VAL_2:.*]] = tosa.mul %[[VAL_0]], %[[VAL_1]] {shift = 0 : i8} : (tensor<5xbf16>, tensor<1xbf16>) -> tensor<5xbf16> +// CHECK: return %[[VAL_2]] : tensor<5xbf16> +// CHECK: } func.func @torch.aten.mul.Scalar$mixed_type(%arg0: !torch.vtensor<[5],bf16>) -> !torch.vtensor<[5],bf16> { %float2.000000e00 = torch.constant.float 2.000000e+00 %0 = torch.aten.mul.Scalar %arg0, %float2.000000e00 : !torch.vtensor<[5],bf16>, !torch.float -> !torch.vtensor<[5],bf16> @@ -88,12 +90,14 @@ func.func @torch.aten.bitwise_and.Tensor$mixed_type(%arg0: !torch.vtensor<[?,?], // ----- -// CHECK-LABEL: torch.aten.div.Tensor$mixed_type_fp -// CHECK-SAME: %[[VAL_0:.*]]: tensor, -// CHECK-SAME: %[[VAL_1:.*]]: tensor -// CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_1]] : (tensor) -> tensor -// CHECK: %[[VAL_3:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor) -> tensor -// CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_0]], %[[VAL_3]] {shift = 0 : i32} : (tensor, tensor) -> tensor +// CHECK-LABEL: func.func @torch.aten.div.Tensor$mixed_type_fp( +// CHECK-SAME: %[[VAL_0:.*]]: tensor, +// CHECK-SAME: %[[VAL_1:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_1]] : (tensor) -> tensor +// CHECK: %[[VAL_3:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor) -> tensor +// CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_0]], %[[VAL_3]] {shift = 0 : i8} : (tensor, tensor) -> tensor +// CHECK: return %[[VAL_4]] : tensor +// CHECK: } func.func @torch.aten.div.Tensor$mixed_type_fp(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],si32>) -> !torch.vtensor<[?, ?],f32> { %0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],si32> -> !torch.vtensor<[?, ?],f32> return %0 : !torch.vtensor<[?, ?],f32> From 4279b750da10b4ded10ca6ccb1c120d7a4187a51 Mon Sep 17 00:00:00 2001 From: Ze Zhang Date: Tue, 17 Oct 2023 14:49:47 -0700 Subject: [PATCH 38/41] update AtenClampOp in torch-to-tosa to handle fp inputs (#2516) As titled. 
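For illustration, a minimal sketch of the float path this enables (input values here are arbitrary): float-typed tensors clamped with constant float bounds now lower to `tosa.clamp` instead of failing the integer-constant match.

```python
import torch

x = torch.randn(1, 1, 128, 128)
# Float bounds populate the min_fp/max_fp attributes of tosa.clamp; the
# companion min_int/max_int attributes are derived by truncating the same
# bounds (e.g. 6.4321 -> 6), as exercised by the new lit test below.
y = torch.clamp(x, min=3.1234, max=6.4321)
```

The existing integer-constant path is kept unchanged for non-float element types.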
--------- Co-authored-by: Ze Zhang --- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 40 ++++++++++++++++------ test/Conversion/TorchToTosa/basic.mlir | 17 +++++++++ 2 files changed, 46 insertions(+), 11 deletions(-) diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 970ef15d8d9a..c2c73708d79d 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -3984,19 +3984,37 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "only tensor types input are currently supported"); - int64_t int_min, int_max; - if (!matchPattern(op.getMin(), m_TorchConstantInt(&int_min))) - return rewriter.notifyMatchFailure( - op, "unimplemented: value `int_min` should be a torch constant int"); + IntegerAttr min_int, max_int; + FloatAttr min_fp, max_fp; + if (selfType.getElementType().isa()) { + double fp_min, fp_max; + if (!matchPattern(op.getMin(), m_TorchConstantFloat(&fp_min))) + return rewriter.notifyMatchFailure( + op, "unimplemented: value `fp_min` should be a torch constant float"); - if (!matchPattern(op.getMax(), m_TorchConstantInt(&int_max))) - return rewriter.notifyMatchFailure( - op, "unimplemented: value `int_max` should be a torch constant int"); + if (!matchPattern(op.getMax(), m_TorchConstantFloat(&fp_max))) + return rewriter.notifyMatchFailure( + op, "unimplemented: value `fp_max` should be a torch constant float"); + + min_int = rewriter.getI64IntegerAttr(static_cast(fp_min)); + max_int = rewriter.getI64IntegerAttr(static_cast(fp_max)); + min_fp = rewriter.getF32FloatAttr(static_cast(fp_min)); + max_fp = rewriter.getF32FloatAttr(static_cast(fp_max)); + } else { + int64_t int_min, int_max; + if (!matchPattern(op.getMin(), m_TorchConstantInt(&int_min))) + return rewriter.notifyMatchFailure( + op, "unimplemented: value `int_min` should be a torch constant int"); - IntegerAttr min_int = rewriter.getI64IntegerAttr(int_min); - IntegerAttr max_int = rewriter.getI64IntegerAttr(int_max); - FloatAttr min_fp = rewriter.getF32FloatAttr(float(int_min)); - FloatAttr max_fp = rewriter.getF32FloatAttr(float(int_max)); + if (!matchPattern(op.getMax(), m_TorchConstantInt(&int_max))) + return rewriter.notifyMatchFailure( + op, "unimplemented: value `int_max` should be a torch constant int"); + + min_int = rewriter.getI64IntegerAttr(int_min); + max_int = rewriter.getI64IntegerAttr(int_max); + min_fp = rewriter.getF32FloatAttr(static_cast(int_min)); + max_fp = rewriter.getF32FloatAttr(static_cast(int_max)); + } auto outType = getTypeConverter()->convertType(op.getType()); rewriter.replaceOpWithNewOp(op, outType, adaptor.getSelf(), diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index f04109873336..180f48bcef2b 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -1072,6 +1072,23 @@ func.func @torch.aten.clamp(%arg0: !torch.vtensor<[1,1,128,128],si64>) -> !torch return %0 : !torch.vtensor<[1,1,128,128],si64> } +// ----- +// CHECK-LABEL: func.func @torch.aten.clamp.float( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,128,128],f32>) -> !torch.vtensor<[1,1,128,128],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1,128,128],f32> -> tensor<1x1x128x128xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.float 3.123400e+00 +// CHECK: %[[VAL_3:.*]] = torch.constant.float 6.432100e+00 +// CHECK: %[[VAL_4:.*]] = tosa.clamp %[[VAL_1]] {max_fp = 6.432100e+00 : 
f32, max_int = 6 : i64, min_fp = 3.123400e+00 : f32, min_int = 3 : i64} : (tensor<1x1x128x128xf32>) -> tensor<1x1x128x128xf32> +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<1x1x128x128xf32> -> !torch.vtensor<[1,1,128,128],f32> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[1,1,128,128],f32> +// CHECK: } +func.func @torch.aten.clamp.float(%arg0: !torch.vtensor<[1,1,128,128],f32>) -> !torch.vtensor<[1,1,128,128],f32> { + %fp_min = torch.constant.float 3.123400e+00 + %fp_max = torch.constant.float 6.432100e+00 + %0 = torch.aten.clamp %arg0, %fp_min, %fp_max : !torch.vtensor<[1,1,128,128],f32>, !torch.float, !torch.float -> !torch.vtensor<[1,1,128,128],f32> + return %0 : !torch.vtensor<[1,1,128,128],f32> +} + // ----- // CHECK-LABEL: func.func @torch.aten.masked_fill.Scalar( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,12,128,128],f32>, From 52abae1526e51ae8c415ca98ce4a56b00782b68b Mon Sep 17 00:00:00 2001 From: Sambhav Jain Date: Tue, 17 Oct 2023 22:00:26 -0700 Subject: [PATCH 39/41] Bump LLVM to get bazel fixes (#2517) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The last llvm bump in https://github.com/llvm/torch-mlir/pull/2511 pointed to https://github.com/llvm/llvm-project/commit/b44b3494f60296db6aca38a14cab061d9b747a0a, however the bazel build upstream was not clean at this point: ``` ERROR: /root/.cache/bazel/_bazel_root/b89349c08f7224396763d14fe35cba11/external/llvm-project/mlir/BUILD.bazel:5837:18: TdGenerate external/llvm-project/mlir/include/mlir/Dialect/LLVMIR/NVVMOpsInterface.h.inc failed: (Exit 1): mlir-tblgen failed: error executing command ... external/llvm-project/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td:20:9: error: Could not find include file 'mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td' include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td" ^ external/llvm-project/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td:20:9: error: Unexpected token at top level include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td" ^ ``` The bazel fixes followed in a subsequent commit at https://github.com/llvm/llvm-project/commit/28b27c1b10ae8d1f5b4fb9df691e8cf0da9be3f6. This PR bumps LLVM by a few more commits (to include the bazel fixes) which helps restore Torch-MLIR's bazel build back to 🟢 . 
From 52abae1526e51ae8c415ca98ce4a56b00782b68b Mon Sep 17 00:00:00 2001
From: Sambhav Jain
Date: Tue, 17 Oct 2023 22:00:26 -0700
Subject: [PATCH 39/41] Bump LLVM to get bazel fixes (#2517)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

The last llvm bump in https://github.com/llvm/torch-mlir/pull/2511 pointed to
https://github.com/llvm/llvm-project/commit/b44b3494f60296db6aca38a14cab061d9b747a0a;
however, the bazel build upstream was not clean at that point:

```
ERROR: /root/.cache/bazel/_bazel_root/b89349c08f7224396763d14fe35cba11/external/llvm-project/mlir/BUILD.bazel:5837:18: TdGenerate external/llvm-project/mlir/include/mlir/Dialect/LLVMIR/NVVMOpsInterface.h.inc failed: (Exit 1): mlir-tblgen failed: error executing command ...
external/llvm-project/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td:20:9: error: Could not find include file 'mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td'
include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td"
        ^
external/llvm-project/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td:20:9: error: Unexpected token at top level
include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td"
        ^
```

The bazel fixes followed in a subsequent commit at
https://github.com/llvm/llvm-project/commit/28b27c1b10ae8d1f5b4fb9df691e8cf0da9be3f6.
This PR bumps LLVM by a few more commits (to include the bazel fixes), which
helps restore Torch-MLIR's bazel build back to 🟢.

GHA workflow to test bazel build:
https://github.com/sjain-stanford/torch-mlir/actions/runs/6555101471/job/17803082508
---
 externals/llvm-project | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/externals/llvm-project b/externals/llvm-project
index b44b3494f602..28b27c1b10ae 160000
--- a/externals/llvm-project
+++ b/externals/llvm-project
@@ -1 +1 @@
-Subproject commit b44b3494f60296db6aca38a14cab061d9b747a0a
+Subproject commit 28b27c1b10ae8d1f5b4fb9df691e8cf0da9be3f6

From b846437e40dfe9d678a072751b7ceff45c854013 Mon Sep 17 00:00:00 2001
From: Raghavan Raman
Date: Tue, 17 Oct 2023 22:42:14 -0700
Subject: [PATCH 40/41] Fix the names of arith MaximumF and MinimumF ops

---
 .../lib/Conversion/TcpToLinalg/Elementwise.cpp | 4 ++--
 .../test/Conversion/TcpToLinalg/unary.mlir     | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/externals/llvm-external-projects/torch-mlir-dialects/lib/Conversion/TcpToLinalg/Elementwise.cpp b/externals/llvm-external-projects/torch-mlir-dialects/lib/Conversion/TcpToLinalg/Elementwise.cpp
index dd495f932bd0..a476669625f6 100644
--- a/externals/llvm-external-projects/torch-mlir-dialects/lib/Conversion/TcpToLinalg/Elementwise.cpp
+++ b/externals/llvm-external-projects/torch-mlir-dialects/lib/Conversion/TcpToLinalg/Elementwise.cpp
@@ -72,11 +72,11 @@ createLinalgPayloadForElementwiseOp(Operation *op,
     auto minFloat = clampOp.getMinFloat();
     auto maxFloat = clampOp.getMaxFloat();
     if (minFloat)
-      result = b.create<arith::MaxFOp>(
+      result = b.create<arith::MaximumFOp>(
           loc, result,
           b.create<arith::ConstantFloatOp>(loc, *minFloat, b.getF32Type()));
     if (maxFloat)
-      result = b.create<arith::MinFOp>(
+      result = b.create<arith::MinimumFOp>(
           loc, result,
           b.create<arith::ConstantFloatOp>(loc, *maxFloat, b.getF32Type()));
   } else if (elemType.isa<mlir::IntegerType>()) {
diff --git a/externals/llvm-external-projects/torch-mlir-dialects/test/Conversion/TcpToLinalg/unary.mlir b/externals/llvm-external-projects/torch-mlir-dialects/test/Conversion/TcpToLinalg/unary.mlir
index 14a50e7d2475..ab7a58eeb9ae 100644
--- a/externals/llvm-external-projects/torch-mlir-dialects/test/Conversion/TcpToLinalg/unary.mlir
+++ b/externals/llvm-external-projects/torch-mlir-dialects/test/Conversion/TcpToLinalg/unary.mlir
@@ -43,9 +43,9 @@ func.func @tanh(%arg0 : tensor<?xf32>) -> tensor<?xf32> {
 // CHECK-SAME: outs(%[[EMPTY_TENSOR]] : tensor<?xf32>) {
 // CHECK: ^bb0(%[[BBARG0:.*]]: f32, %{{.*}}: f32):
 // CHECK: %[[CST0:.*]] = arith.constant 1.000000e-01 : f32
-// CHECK: %[[MAX:.*]] = arith.maxf %[[BBARG0]], %[[CST0]] : f32
+// CHECK: %[[MAX:.*]] = arith.maximumf %[[BBARG0]], %[[CST0]] : f32
 // CHECK: %[[CST1:.*]] = arith.constant 1.024000e+03 : f32
-// CHECK: %[[MIN:.*]] = arith.minf %[[MAX]], %[[CST1]] : f32
+// CHECK: %[[MIN:.*]] = arith.minimumf %[[MAX]], %[[CST1]] : f32
 // CHECK: linalg.yield %[[MIN]] : f32
 // CHECK: } -> tensor<?xf32>
 // CHECK: return %[[GENERIC]] : tensor<?xf32>
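Note on the rename above: upstream LLVM split the old `arith.maxf`/`arith.minf` into a NaN-propagating pair, `arith.maximumf`/`arith.minimumf`, and a NaN-ignoring pair, `arith.maxnumf`/`arith.minnumf`; this patch moves the Tcp lowering onto the propagating pair, the direct successor of the removed names. A standalone C++ sketch contrasting the two maximum semantics, separate from the patch (the functions are plain stand-ins for the MLIR ops, and signed-zero ordering is omitted for brevity):

```cpp
// Standalone contrast of the two floating-point "maximum" semantics that
// replaced arith.maxf after the upstream split.
#include <cmath>
#include <cstdio>
#include <limits>

// NaN-propagating maximum: the documented semantics of arith.maximumf.
float maximumf(float a, float b) {
  if (std::isnan(a) || std::isnan(b))
    return std::numeric_limits<float>::quiet_NaN();
  return a > b ? a : b;
}

// NaN-ignoring maximum: the documented semantics of arith.maxnumf, which
// matches C's fmax (the non-NaN operand wins).
float maxnumf(float a, float b) { return std::fmax(a, b); }

int main() {
  float nan = std::numeric_limits<float>::quiet_NaN();
  std::printf("maximumf(1.0, NaN) = %f\n", maximumf(1.0f, nan)); // nan
  std::printf("maxnumf(1.0, NaN)  = %f\n", maxnumf(1.0f, nan));  // 1.000000
  return 0;
}
```

For the finite, non-NaN values exercised by the clamp test above, the two semantics agree, so the only visible change is the op name in the CHECK lines.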
"ElementwiseMulScalarModule_int", @@ -1047,6 +1049,8 @@ "BatchNorm1DStaticShapeModule_basic", "BroadcastListConstructWithMinusOneModule_basic", "ElementwiseAtenDivIntScalarModule_basic", + "ElementwiseToDtypeI64ToI8Module_basic", + "ToDtypeLayoutCPUModule_basic", "TupleModule_basic", "TypeAsDifferentModule_basic", }