From c5a1da1910f8e1a5dac748eb2806833bd4f1b0c2 Mon Sep 17 00:00:00 2001 From: ptrifunovic98 <156185835+ptrifunovic98@users.noreply.github.com> Date: Mon, 26 Feb 2024 17:46:56 +0100 Subject: [PATCH 01/12] Implement lowering of torch.aten.norm.Scalar (#2899) Closes [nod-ai/SHARK-Turbine#365](https://github.com/nod-ai/SHARK-Turbine/issues/365) --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 25 +++++++++ lib/Conversion/TorchToLinalg/Reduction.cpp | 53 ++++++++++++++++--- lib/Dialect/Torch/IR/TorchOps.cpp | 36 +++++++++++++ .../Transforms/AbstractInterpLibrary.cpp | 32 +++++++++++ projects/pt1/e2e_testing/xfail_sets.py | 1 + .../build_tools/abstract_interp_lib_gen.py | 18 +++++++ .../build_tools/torch_ods_gen.py | 1 + .../test_suite/reduction.py | 19 +++++++ 8 files changed, 177 insertions(+), 8 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index cc8be7c6910b..dc1203de9471 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -6325,6 +6325,31 @@ def Torch_AtenLayerNormOp : Torch_Op<"aten.layer_norm", [ }]; } +def Torch_AtenNormScalarOp : Torch_Op<"aten.norm.Scalar", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::norm.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$p + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenNormScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenNormScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; + let hasVerifier = 1; +} + def Torch_AtenNormScalarOptDimOp : Torch_Op<"aten.norm.ScalarOpt_dim", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchToLinalg/Reduction.cpp b/lib/Conversion/TorchToLinalg/Reduction.cpp index a21615ad84c4..e050764993e6 100644 --- a/lib/Conversion/TorchToLinalg/Reduction.cpp +++ b/lib/Conversion/TorchToLinalg/Reduction.cpp @@ -275,7 +275,8 @@ static Value createInitElementForReduceOp(OpBuilder &b, Location loc, elementType.getIntOrFloatBitWidth()))); } - if (isa(op) || isa(op)) + if (isa(op) || isa(op) || + isa(op)) return b.create(loc, b.getZeroAttr(elementType)); if (isa(op)) { @@ -341,6 +342,26 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc, if (intType.isSigned()) return b.create(loc, self, result); } + } else if (isa(op)) { + // This creates payload for only the first of the two linalg.generic ops. + // TODO: Short-circuit operations if `p` is zero or one. + Value elem = payloadArgs[0]; + Value result = payloadArgs[1]; + + // TODO: Fix this part to support complex elements. + if (elem.getType().isa()) { + op->emitError("lowering of complex input type for torch.aten.norm.Scalar " + "is currently unimplemented"); + return nullptr; + } + + Value self = convertScalarToDtype(b, loc, elem, resultElementType); + + auto abs = b.create(loc, self); + AtenNormScalarOp::Adaptor adaptor(operands); + Value p = convertScalarToDtype(b, loc, adaptor.getP(), resultElementType); + auto pow = b.create(loc, abs, p); + return b.create(loc, pow, result); } else if (isa(op)) { // This creates payload for only the first of the two linalg.generic ops. 
// TODO: Short-circuit operations if `ord` is zero or one. @@ -433,7 +454,7 @@ class ConvertReductionOp : public ConversionPattern { ConversionPatternRewriter &rewriter) const { auto opInfo = torch_to_linalg::ReductionOpInfo{false, Value{}, {}}; - if (isa(op)) { + if (isa(op)) { opInfo.tensorOperand = operands[0]; auto inputType = opInfo.tensorOperand.getType().cast(); @@ -484,10 +505,12 @@ class ConvertReductionOp : public ConversionPattern { return err ? Value{} : powOp; } - FailureOr createSecondReductionForVectorNormOp( - Location loc, Type elemType, AtenLinalgVectorNormOp op, Value ordOp, - Value firstReduction, const torch_to_linalg::ReductionOpInfo &opInfo, - ConversionPatternRewriter &rewriter) const { + template + FailureOr + createSecondReductionForNormOp(Location loc, Type elemType, TOp op, + Value ordOp, Value firstReduction, + const torch_to_linalg::ReductionOpInfo &opInfo, + ConversionPatternRewriter &rewriter) const { // Cast `ord` to float so that we can readily pass it math.powf. Value ordValue = convertScalarToDtype(rewriter, loc, ordOp, elemType); @@ -544,13 +567,15 @@ class ConvertReductionOp : public ConversionPattern { LogicalResult validateReductionElementType(Operation *op, Type elemType, ConversionPatternRewriter &rewriter) const { - if ((isa(op) || isa(op)) && + if ((isa(op) || isa(op) || + isa(op)) && !elemType.isa()) return rewriter.notifyMatchFailure( op, "only float types are valid for vector norm ops"); if (isa(op) && elemType.isa() && elemType.getIntOrFloatBitWidth() == 8) return rewriter.notifyMatchFailure(op, "uint8 is not supported"); + // No checks for all other reduction operations return success(); } @@ -587,11 +612,22 @@ class ConvertReductionOp : public ConversionPattern { return rewriter.notifyMatchFailure( op, "failed to create linalg.generic operation for reduction"); + // If this is aten.norm.Scalar op, then we need to generate another + // linalg.generic op that references the first linalg.generic op. + if (isa(op)) { + AtenNormScalarOp::Adaptor adaptor(operands); + FailureOr secondReduceOp = createSecondReductionForNormOp( + loc, elemType, op, adaptor.getP(), reduceOp, *opInfo, rewriter); + if (failed(secondReduceOp)) + return secondReduceOp; + reduceOp = *secondReduceOp; + } + // If this is aten.linalg_vector_norm op, then we need to generate another // linalg.generic op that references the first linalg.generic op. 
if (auto normOp = dyn_cast(op)) { AtenLinalgVectorNormOp::Adaptor adaptor(operands); - FailureOr secondReduceOp = createSecondReductionForVectorNormOp( + FailureOr secondReduceOp = createSecondReductionForNormOp( loc, elemType, normOp, adaptor.getOrd(), reduceOp, *opInfo, rewriter); if (failed(secondReduceOp)) return secondReduceOp; @@ -627,6 +663,7 @@ void mlir::torch::torch_to_linalg::populateReductionPatternsAndLegality( target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); patterns.add(typeConverter, context); diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index da6f71015942..ef3098eb1c12 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -3767,6 +3767,42 @@ LogicalResult ShapeCalculateYieldShapesOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// AtenNormScalarOp +//===----------------------------------------------------------------------===// + +LogicalResult AtenNormScalarOp::verify() { + + // Verificaion of input type for torch.aten.norm.Scalar. + // Per PyTorch docs, only float and complex types are valid for norm + // operation. + + auto inTensor = getSelf().getType().cast(); + + // If no dtype is specified, it will default to a float one. + if (!inTensor.hasDtype()) { + return success(); + } + + auto inTensorDtype = inTensor.getDtype(); + + // Check if dtype is one of those supported by norm operation. + // ComplexType will match any torch complex types, but each float must be + // checked individually. + if (!inTensorDtype.isa()) { + return emitOpError( + "expected a float or complex type for input tensor, but got ") + << inTensorDtype; + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// AtenPermuteOp +//===----------------------------------------------------------------------===// + LogicalResult AtenPermuteOp::verify() { // Verification of the permute op for input & output dimensions with diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index bfc2fc6a1d0c..a8327b0e0da6 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -9339,6 +9339,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %2 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %0, %arg2, %1) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list\n" " return %2 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.norm.Scalar\"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list {\n" +" %false = torch.constant.bool false\n" +" %none = torch.constant.none\n" +" %0 = torch.derefine %none : !torch.none to !torch.optional>\n" +" %1 = torch.derefine %none : !torch.none to !torch.any\n" +" %2 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %0, %false, %1) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list\n" +" return %2 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.norm.ScalarOpt_dim\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.list, %arg3: !torch.bool) -> !torch.list {\n" " %int0 = torch.constant.int 0\n" " %0 = torch.derefine %arg2 : !torch.list to !torch.optional>\n" @@ -12038,6 +12046,30 @@ StringRef 
mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %4 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.norm.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" +" %true = torch.constant.bool true\n" +" %int5 = torch.constant.int 5\n" +" %int8 = torch.constant.int 8\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.aten.eq.int %0#1, %int8 : !torch.int, !torch.int -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" torch.prim.If.yield %int5 : !torch.int\n" +" } else {\n" +" %5 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" +" torch.prim.If.yield %5 : !torch.int\n" +" }\n" +" return %4 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.tensor.float\"(%arg0: !torch.float, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.bool) -> !torch.int {\n" " %int6 = torch.constant.int 6\n" " %none = torch.constant.none\n" diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index e749b5834cc6..70f26fe421e0 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1667,6 +1667,7 @@ "NllLossModule_ignore_index_out_of_bounds_basic", "NllLossModule_mean_basic", "NllLossModule_sum_basic", + "NormScalarModule_basic", "NormScalarOptDimKeepDimModule_basic", "NormScalarOptDimModule_basic", "NormalFunctionalModule_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 403d124ad927..99f4f2200d35 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -1722,6 +1722,9 @@ def aten〇linalg_vector_norm〡shape(self: List[int], ord: float = 2, dim: Opti def aten〇frobenius_norm〇dim〡shape(self: List[int], dim: List[int], keepdim: bool = False) -> List[int]: return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, 0) +def aten〇norm〇Scalar〡shape(self: List[int], p: float = 2) -> List[int]: + return upstream_shape_functions.sum_mean_dim(self, None, False, None) + def aten〇norm〇ScalarOpt_dim〡shape(self: List[int], p: Optional[float], dim: List[int], keepdim: bool = False) -> List[int]: return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, 0) @@ -3924,6 +3927,21 @@ def aten〇linalg_vector_norm〡dtype(self_rank_dtype: Tuple[int, int], ord: Uni return dtype return aten〇std〡dtype(self_rank_dtype) +@check_dtype_function( + _check_tensors_with_the_same_dtype( + num_of_tensors=1, + error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64})) +def aten〇norm〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], p: Union[int, float, complex] = 2) -> int: + self_rank, self_dtype = self_rank_dtype + assert not is_integer_dtype(self_dtype) + # The following check is added because 
aten〇std〡dtype + # does not handle complex32 transformation to float, + # so it is done manually (torch.half == torch.float16). + # Should possibly be added to aten〇std〡dtype. + if self_dtype == torch.complex32: + return torch.half + return aten〇std〡dtype(self_rank_dtype) + @check_dtype_function([Invocation(0.0), Invocation(0.0, dtype=torch.int32), Invocation(0.0, dtype=torch.float16), diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 51c196421b78..cc41a99be228 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -449,6 +449,7 @@ def emit_with_mutating_variants(key, **kwargs): emit( "aten::layer_norm : (Tensor, int[], Tensor?, Tensor?, float, bool) -> (Tensor)" ) + emit("aten::norm.Scalar : (Tensor, Scalar) -> (Tensor)", has_verifier=True) emit( "aten::norm.ScalarOpt_dim : (Tensor, Scalar?, int[], bool) -> (Tensor)" ) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py index 2c61524bd797..d0d6c2ea2dfa 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py @@ -1100,6 +1100,25 @@ def ReduceL3NormKeepDimModule_basic(module, tu: TestUtils): # ============================================================================== +class NormScalarModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.p = 3.0 + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.ops.aten.norm(a, self.p) + +@register_test_case(module_factory=lambda: NormScalarModule()) +def NormScalarModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5)) + +# ============================================================================== + class NormScalarOptDimModule(torch.nn.Module): def __init__(self) -> None: super().__init__() From 3cbe6c98ec9a67964ecb5947f7664e34e9ba4b5b Mon Sep 17 00:00:00 2001 From: Sambhav Jain Date: Mon, 26 Feb 2024 10:08:14 -0800 Subject: [PATCH 02/12] Expose `func_name` to the main fx import API (#2949) As titled. --- python/torch_mlir/extras/fx_importer.py | 4 ++-- python/torch_mlir/fx.py | 5 +++-- test/python/fx_importer/basic_test.py | 21 +++++++++++++++++++++ 3 files changed, 26 insertions(+), 4 deletions(-) diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index 91f3c27ee263..e6d0f03deda4 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -623,7 +623,7 @@ def import_program( node_importer.return_node_values(loc, user_outputs) self.symbol_table.insert(func_op) - def import_frozen_program(self, prog: torch.export.ExportedProgram): + def import_frozen_program(self, prog: torch.export.ExportedProgram, func_name: str = "main"): """Imports a consolidated torch.export.ExportedProgram instance. 
If using the new torch.export path (vs a lower level precursor), then this is @@ -702,7 +702,7 @@ def import_frozen_program(self, prog: torch.export.ExportedProgram): node.replace_all_uses_with(replacement) g.erase_node(node) - self.import_stateless_graph(g) + self.import_stateless_graph(g, func_name) def import_graph_module(self, gm: GraphModule): """Low-level import of a GraphModule assuming that it has been functionalized. diff --git a/python/torch_mlir/fx.py b/python/torch_mlir/fx.py index 1f5aa8f74add..76cd91f82e0a 100644 --- a/python/torch_mlir/fx.py +++ b/python/torch_mlir/fx.py @@ -23,6 +23,7 @@ def export_and_import( constraints: Optional[torch.export.Constraint] = None, experimental_support_mutation: bool = False, hooks: Optional[FxImporterHooks] = None, + func_name: str = "main", **kwargs, ): context = ir.Context() @@ -36,8 +37,8 @@ def export_and_import( if experimental_support_mutation: if torch.__version__ < "2.3.0.dev20240207": warnings.warn("Mutable program import only supported on PyTorch 2.3+") - fx_importer.import_program(prog) + fx_importer.import_program(prog, func_name=func_name) else: - fx_importer.import_frozen_program(prog) + fx_importer.import_frozen_program(prog, func_name=func_name) return fx_importer.module_op diff --git a/test/python/fx_importer/basic_test.py b/test/python/fx_importer/basic_test.py index 36c554862506..fc5b2030b648 100644 --- a/test/python/fx_importer/basic_test.py +++ b/test/python/fx_importer/basic_test.py @@ -56,3 +56,24 @@ def forward(self, x): m = fx.export_and_import(Basic(), torch.randn(3, 4)) print(m) + + +@run +# CHECK-LABEL: test_import_frozen_exported_program_with_func_name +# CHECK: func.func @test_net(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> +def test_import_frozen_exported_program_with_func_name(): + @torch._dynamo.assume_constant_result + def get_a(): + return torch.randn(1, 4) + + class Basic(nn.Module): + def __init__(self): + super().__init__() + self.b = torch.randn(3, 1) + self.p = nn.Parameter(torch.randn(1, 1)) + + def forward(self, x): + return torch.tanh(x) * get_a() * self.b * self.p + + m = fx.export_and_import(Basic(), torch.randn(3, 4), func_name="test_net") + print(m) From 0ee752bc688c49d3b55995bb85db9262c8fdaad0 Mon Sep 17 00:00:00 2001 From: Abhishek-TyRnT Date: Sat, 13 Jan 2024 08:44:23 +0530 Subject: [PATCH 03/12] ADDED SUPPORT FLOAT VALUE IN ARANGE --- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 59 ++++-- projects/pt1/e2e_testing/xfail_sets.py | 214 +++++++++++++++++++++ 2 files changed, 260 insertions(+), 13 deletions(-) diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index b49c9af8adce..0b5819631f52 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -26,6 +26,7 @@ #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" +#include using namespace mlir; using namespace mlir::torch; @@ -4067,28 +4068,60 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op, "unimplemented: pin_memory must be either None or false"); } - int64_t start, step, end; - if (!matchPattern(op.getStart(), m_TorchConstantInt(&start))) + double start, step, end; + int64_t start_int, step_int, end_int; + bool is_all_inp_int; //Flag to check whether all inputs are integer + is_all_inp_int = op.getStart().getType().isa() && op.getEnd().getType().isa() 
&& op.getStep().getType().isa(); + + if (matchPattern(op.getStart(), m_TorchConstantInt(&start_int))) + { + start = (double)(start_int); + } + + else if(!matchPattern(op.getStart(), m_TorchConstantFloat(&start))) return rewriter.notifyMatchFailure( - op, "unimplemented: value `start` should be a torch constant int"); + op, "unimplemented: value `start` should be a torch constant int or float"); - if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end))) + if (matchPattern(op.getEnd(), m_TorchConstantInt(&end_int))) + { + end = (double)(end_int); + } + else if (!matchPattern(op.getEnd(), m_TorchConstantFloat(&end))) return rewriter.notifyMatchFailure( - op, "unimplemented: value `end` should be a torch constant int"); + op, "unimplemented: value `end` should be a torch constant int or float"); - if (!matchPattern(op.getStep(), m_TorchConstantInt(&step))) + if (matchPattern(op.getStep(), m_TorchConstantInt(&step_int))) + { + + step = (double)(step_int); + } + + else if (!matchPattern(op.getStep(), m_TorchConstantFloat(&step))) return rewriter.notifyMatchFailure( - op, "unimplemented: value `step` should be a torch constant int"); + op, "unimplemented: value `step` should be a torch constant int or float"); // The result will always be a 1-d tensor. // The size of the result is calculated as follows: // ceil((end - start)/step) - int64_t resultShape = ceil((float)(end - start) / (float)step); - SmallVector values(resultShape, start); - for (unsigned i = 1; i < resultShape; i++) - values[i] += i * step; - Value result = - tosa::getConstTensor(rewriter, op, values, resultShape).value(); + int64_t resultShape = ceil((end - start) / step); + Value result; + if (is_all_inp_int) + { + SmallVector values(resultShape, start); + for (unsigned i = 1; i < resultShape; i++) + values[i] += i * step; + + result = tosa::getConstTensor(rewriter, op, values, resultShape).value(); + } + + else + { + SmallVector values(resultShape, start); + for (unsigned i = 1; i < resultShape; i++) + values[i] += (i * step); + + result = tosa::getConstTensor(rewriter, op, values, resultShape).value(); + } rewriter.replaceOpWithNewOp(op, resultType, result); return success(); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 70f26fe421e0..c9768152da25 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -871,6 +871,212 @@ # Write the TOSA set as a "passing" set as it is very early in development # and very few tests work yet. 
TOSA_PASS_SET = { + "Convolution2DStridedModule_basic", + "IscloseStaticModule_basic", + "IscloseStaticModuleTrue_basic", + "TileBigDimsSizeModule_basic", + "TileSmallDimsSizeModule_basic", + "IndexPutImpl2DNoneIndexStaticModule_basic", + "AliasModule_basic", + "MaxPool2dEmptyStrideStaticModule_basic", + "ConstantBoolParameterModule_basic", + "ElementwiseCloneContiguousModule_basic", + "ElementwiseCloneChannelsLastMemoryFormatModule_basic", + "ElementwiseCloneModule_basic", + "ElementwiseUnaryModule_basic", + "ElementwiseBinaryModule_basic", + "ElementwiseSigmoidModule_basic", + "ElementwiseExpModule_basic", + "ElementwiseReluModule_basic", + "ElementwiseLeakyReluModule_basic", + "ElementwiseEluModule_basic", + "ElementwiseEluNonDefaultModule_basic", + "ElementwiseFloorModule_basic", + "ElementwiseFloorIntModule_basic", + "ElementwiseLogModule_basic", + "ElementwiseBinaryStaticShapeModule_basic", + "ElementwiseMinimumModule_basic", + "ElementwiseMinimumIntModule_basic", + "ElementwiseMinOtherIntModule_basic", + "ElementwiseMinOtherModule_basic", + "ElementwiseMaximumModule_basic", + "ElementwiseMaximumIntModule_basic", + "ElementwiseMaxOtherIntModule_basic", + "ElementwiseMaxOtherModule_basic", + "GluStaticModule_basic", + "ViewDoubleMergeStaticModule_basic", + "ViewCollapseOnesMiddleModule_basic", + "ViewFiveTestStaticModule_basic", + "ViewOffsetTestStaticModule_basic", + "ViewTwoFiveThreeStaticModule_basic", + "ViewTwoToThreeStaticModule_basic", + "ViewExpandOnesMiddleOppModule_basic", + "ViewOffsetBackwardTestStaticModule_basic", + "TanhBackward_basic", + "HardtanhBackward_basic", + "ElementwiseAddModule_basic", + "ReturnThreeTensorFloat32_basic", + "AddCMulModule_basic", + "AddCDivModule_basic", + "SqueezeModule_broadcast", + "BoolTensorReturnFalseModule_basic", + "BoolTensorReturnTrueModule_basic", + "BoolTensorReturnMixedModule_basic", + "BoolTensorHandleSignless_basic", + "ElementwiseRsqrtModule_basic", + "SelectIntNegativeDimAndIndexStaticModule_basic", + "SqueezeModule_static", + "SqueezeModule_noUnitDim", + "SqueezeModule_allUnitDim", + "TModuleRank1_basic", + "TModuleRank0_basic", + "ElementwiseToDtypeIdentityModule_basic", + "AtenToDeviceModule_basic", + "View1DFoldModule_basic", + "UnsafeView1DFoldModule_basic", + "UnflattenIntStaticModule_basic", + "UnflattenIntNegativeOneDimStaticModule_basic", + "UnflattenIntNegativeOneSizeStaticModule_basic", + "SqueezeDimModule_static", + "SqueezeDimModule_identity", + "SqueezeDimModule_unitDim", + "ReturnTwoTensorF32I64_basic", + "ElementwiseSignModule_basic", + "ElementwisePowModule_basic", + "BmmFloatModule_basic", + "MmDagModule_basic", + "Matmul4dStatic_basic", + "Matmul_dot", + "Matmul_3d", + "RsubFloatModule_basic", + "RsubFloatModule_noalpha_basic", + "RsubInt0d_NumToTensor_Module_basic", + "ElementwiseBitwiseAndModule_basic", + "ElementwiseBitwiseAndStaticShapeModule_basic", + "ElementwiseBitwiseNotInt32Module_basic", + "ElementwiseBitwiseNotInt64Module_basic", + "ElementwiseOrTensorStaticShapeModule_basic", + "ElementwiseOrTensorModule_basic", + "ElementwiseBitwiseOrModule_basic", + "ElementwiseBitwiseOrStaticShapeModule_basic", + "ElementwiseBitwiseXorModule_basic", + "ElementwiseBitwiseXorStaticShapeModule_basic", + "ElementwiseGeFloatIntScalarModule_basic", + "ElementwiseGeFloatScalarModule_basic", + "ElementwiseGeIntScalarModule_basic", + "ElementwiseGeMixedIntScalarModule_basic", + "ElementwiseGtFloatScalarModule_basic", + "ElementwiseGtIntScalarModule_basic", + "ElementwiseGtMixed2ScalarModule_basic", + 
"ElementwiseGtFloatTensorModule_basic", + "ElementwiseGtIntTensorModule_basic", + "ElementwiseLtFloatScalarModule_basic", + "ElementwiseLtIntScalarModule_basic", + "ElementwiseLtDiffWidthScalarModule_basic", + "ElementwiseLtFloatTensorModule_basic", + "ElementwiseLtIntTensorModule_basic", + "ElementwiseEqFloatScalarModule_basic", + "ElementwiseEqIntScalarModule_basic", + "ElementwiseEqBoolScalarModule_basic", + "ElementwiseEqDiffWidthScalarModule_basic", + "ElementwiseEqFloatTensorModule_basic", + "ElementwiseEqIntTensorModule_basic", + "ElementwiseNeFloatScalarModule_basic", + "ElementwiseNeFloatTensorModule_basic", + "ElementwiseNeFloatTensorStaticModule_basic", + "ElementwiseNeIntTensorModule_basic", + "ElementwiseNeIntTensorStaticModule_basic", + "ElementwiseMulScalarModule_int", + "ElementwiseMulScalarModule_float", + "ElementwiseMulTensorIntModule_basic", + "ElementwiseDivScalarModule_basic", + "ElementwiseAtenDivIntScalarModule_basic", + "ElementwiseSubScalarFloatModule_basic", + "ElementwiseAddScalarFloatModule_basic", + "ElementwiseAddScalar_TensorLiteralInt32_Module_basic", + "ElementwiseMulScalarModule_float", + "ElementwiseCeilModule_basic", + "ElementwiseReciprocalModule_basic", + "ElementwiseIsnanModule_basic", + "ElementwiseIsinfModule_basic", + "TypePromotionAlphaWiderModule_basic", + "Conv2dWithPaddingDilationStrideStaticModule_basic", + "Conv2dWithPaddingDilationStrideStaticModule_depthwise", + "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", + "BatchNorm1DModule_basic", + "BatchNorm1DWith2DInputModule_basic", + "BatchNorm2DModule_basic", + "BatchNorm3DModule_basic", + "BatchNorm1DStaticShapeModule_basic", + "FlattenStaticModule_basic", + "UnflattenStaticModule_basic", + "FlattenRank0Module_basic", + "ElementwiseFlattenBroadcastModule_basic", + "SquareModule_basic", + "MaxPool2dStaticModule_basic", + "MaxPool2dStaticCeilModeTrueModule_basic", + "ResNet18StaticModule_basic", + "ReduceAmaxKeepDim_basic", + "NativeLayerNormModule4D_basic", + "LayerNormNormalizeOverAllDimsModule_basic", + "PermuteModule_basic", + "PermuteNegativeIndexModule_basic", + "ElementwiseLog2Module_basic", + "Threshold1dIntI32Module_basic", + "Threshold1dFloatModule_basic", + "Threshold2dFloatModule_basic", + "Threshold3dFloatModule_basic", + "ElementwiseSubScalarIntModule_basic", + "ElementwiseAddScalarIntModule_basic", + "ElementwiseMulScalarModule_basic", + "ZerosModuleDefaultDtype_basic", + "ZerosModuleInt2D_basic", + "ZerosModuleInt3D_basic", + "ZerosModuleFloat2D_basic", + "ZerosModuleFloat3D_basic", + "ZerosModuleFalsePinMemory_basic", + "OnesModuleDefaultDtype_basic", + "OnesModuleInt_basic", + "OnesModuleFloat_basic", + "OnesModuleFalsePinMemory_basic", + "OnesModuleCPUDevice_basic", + "NewZerosModuleDefaultDtype_basic", + "NewZerosModuleInt2D_basic", + "NewZerosModuleInt3D_basic", + "NewZerosModuleFloat2D_basic", + "NewZerosModuleFloat3D_basic", + "NewZerosModuleFalsePinMemory_basic", + "NewOnesModuleDefaultDtype_basic", + "NewOnesModuleInt2D_basic", + "NewOnesModuleInt3D_basic", + "NewOnesModuleFloat2D_basic", + "NewOnesModuleFloat3D_basic", + "NewOnesModuleFalsePinMemory_basic", + "SiluModule_basic", + "DropoutEvalIntModule_basic", + "DropoutEvalFloatModule_basic", + "ContiguousModule_basic", + "DropoutModule_basic", + "ViewExpandModule_basic", + "ViewExpandOnesModule_basic", + "ViewExpandOnesBeforeAndAfterModule_basic", + "ViewExpandOnesMiddleModule_basic", + "ViewExpandCollapseModule_basic", + "ViewExpandCollapseWithOnesModule_basic", + 
"ViewCollapseInferredDimModule_basic", + "ViewExpandInferredDimModule_basic", + "ViewNegativeStaticModule_basic", + "ViewNoChangeStaticModule_basic", + "UnsafeViewExpandModule_basic", + "ReshapeCollapseModule_basic", + "ReshapeAsModule_basic", + "ElementwiseGeluModule_basic", + "GeluBackwardModule_basic", + "ElementwiseNeIntScalarModule_basic", + "Convolution2DStaticModule_basic", + "ElementwiseNegModule_basic", + "TestMultipleTensorReturn_basic", + "TypeAsSameModule_basic", "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic", "AddCDivModule_basic", @@ -888,6 +1094,14 @@ "ArangeStartOutViewModule_basic", "ArangeStartStepIntModule_basic", "ArangeZeroElementOutputModule_basic", + "ArangeDtypeIntModule_basic", + "ArangeFalsePinMemoryModule_basic", + "ArangeFloatModule_basic", + "ArangeNegativeStartFloatModule_basic", + "ArangeStartFloatModule_basic", + "ArangeStartNegativeStepFloatModule_basic", + "ArangeStartOutDtypeModule_basic", + "ArangeStartStepFloatModule_basic", "ArgmaxModule_keepDim", "ArgmaxModule_with_dim", "AtenComplex64Module_basic", From 6b26100dc6710f71c222298d9956cd4cc6eb8b08 Mon Sep 17 00:00:00 2001 From: Abhishek-TyRnT Date: Sat, 13 Jan 2024 08:58:07 +0530 Subject: [PATCH 04/12] got rid of extra tosa tests --- projects/pt1/e2e_testing/xfail_sets.py | 206 ------------------------- 1 file changed, 206 deletions(-) diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index c9768152da25..0e92f61b89c8 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -871,212 +871,6 @@ # Write the TOSA set as a "passing" set as it is very early in development # and very few tests work yet. TOSA_PASS_SET = { - "Convolution2DStridedModule_basic", - "IscloseStaticModule_basic", - "IscloseStaticModuleTrue_basic", - "TileBigDimsSizeModule_basic", - "TileSmallDimsSizeModule_basic", - "IndexPutImpl2DNoneIndexStaticModule_basic", - "AliasModule_basic", - "MaxPool2dEmptyStrideStaticModule_basic", - "ConstantBoolParameterModule_basic", - "ElementwiseCloneContiguousModule_basic", - "ElementwiseCloneChannelsLastMemoryFormatModule_basic", - "ElementwiseCloneModule_basic", - "ElementwiseUnaryModule_basic", - "ElementwiseBinaryModule_basic", - "ElementwiseSigmoidModule_basic", - "ElementwiseExpModule_basic", - "ElementwiseReluModule_basic", - "ElementwiseLeakyReluModule_basic", - "ElementwiseEluModule_basic", - "ElementwiseEluNonDefaultModule_basic", - "ElementwiseFloorModule_basic", - "ElementwiseFloorIntModule_basic", - "ElementwiseLogModule_basic", - "ElementwiseBinaryStaticShapeModule_basic", - "ElementwiseMinimumModule_basic", - "ElementwiseMinimumIntModule_basic", - "ElementwiseMinOtherIntModule_basic", - "ElementwiseMinOtherModule_basic", - "ElementwiseMaximumModule_basic", - "ElementwiseMaximumIntModule_basic", - "ElementwiseMaxOtherIntModule_basic", - "ElementwiseMaxOtherModule_basic", - "GluStaticModule_basic", - "ViewDoubleMergeStaticModule_basic", - "ViewCollapseOnesMiddleModule_basic", - "ViewFiveTestStaticModule_basic", - "ViewOffsetTestStaticModule_basic", - "ViewTwoFiveThreeStaticModule_basic", - "ViewTwoToThreeStaticModule_basic", - "ViewExpandOnesMiddleOppModule_basic", - "ViewOffsetBackwardTestStaticModule_basic", - "TanhBackward_basic", - "HardtanhBackward_basic", - "ElementwiseAddModule_basic", - "ReturnThreeTensorFloat32_basic", - "AddCMulModule_basic", - "AddCDivModule_basic", - "SqueezeModule_broadcast", - "BoolTensorReturnFalseModule_basic", - 
"BoolTensorReturnTrueModule_basic", - "BoolTensorReturnMixedModule_basic", - "BoolTensorHandleSignless_basic", - "ElementwiseRsqrtModule_basic", - "SelectIntNegativeDimAndIndexStaticModule_basic", - "SqueezeModule_static", - "SqueezeModule_noUnitDim", - "SqueezeModule_allUnitDim", - "TModuleRank1_basic", - "TModuleRank0_basic", - "ElementwiseToDtypeIdentityModule_basic", - "AtenToDeviceModule_basic", - "View1DFoldModule_basic", - "UnsafeView1DFoldModule_basic", - "UnflattenIntStaticModule_basic", - "UnflattenIntNegativeOneDimStaticModule_basic", - "UnflattenIntNegativeOneSizeStaticModule_basic", - "SqueezeDimModule_static", - "SqueezeDimModule_identity", - "SqueezeDimModule_unitDim", - "ReturnTwoTensorF32I64_basic", - "ElementwiseSignModule_basic", - "ElementwisePowModule_basic", - "BmmFloatModule_basic", - "MmDagModule_basic", - "Matmul4dStatic_basic", - "Matmul_dot", - "Matmul_3d", - "RsubFloatModule_basic", - "RsubFloatModule_noalpha_basic", - "RsubInt0d_NumToTensor_Module_basic", - "ElementwiseBitwiseAndModule_basic", - "ElementwiseBitwiseAndStaticShapeModule_basic", - "ElementwiseBitwiseNotInt32Module_basic", - "ElementwiseBitwiseNotInt64Module_basic", - "ElementwiseOrTensorStaticShapeModule_basic", - "ElementwiseOrTensorModule_basic", - "ElementwiseBitwiseOrModule_basic", - "ElementwiseBitwiseOrStaticShapeModule_basic", - "ElementwiseBitwiseXorModule_basic", - "ElementwiseBitwiseXorStaticShapeModule_basic", - "ElementwiseGeFloatIntScalarModule_basic", - "ElementwiseGeFloatScalarModule_basic", - "ElementwiseGeIntScalarModule_basic", - "ElementwiseGeMixedIntScalarModule_basic", - "ElementwiseGtFloatScalarModule_basic", - "ElementwiseGtIntScalarModule_basic", - "ElementwiseGtMixed2ScalarModule_basic", - "ElementwiseGtFloatTensorModule_basic", - "ElementwiseGtIntTensorModule_basic", - "ElementwiseLtFloatScalarModule_basic", - "ElementwiseLtIntScalarModule_basic", - "ElementwiseLtDiffWidthScalarModule_basic", - "ElementwiseLtFloatTensorModule_basic", - "ElementwiseLtIntTensorModule_basic", - "ElementwiseEqFloatScalarModule_basic", - "ElementwiseEqIntScalarModule_basic", - "ElementwiseEqBoolScalarModule_basic", - "ElementwiseEqDiffWidthScalarModule_basic", - "ElementwiseEqFloatTensorModule_basic", - "ElementwiseEqIntTensorModule_basic", - "ElementwiseNeFloatScalarModule_basic", - "ElementwiseNeFloatTensorModule_basic", - "ElementwiseNeFloatTensorStaticModule_basic", - "ElementwiseNeIntTensorModule_basic", - "ElementwiseNeIntTensorStaticModule_basic", - "ElementwiseMulScalarModule_int", - "ElementwiseMulScalarModule_float", - "ElementwiseMulTensorIntModule_basic", - "ElementwiseDivScalarModule_basic", - "ElementwiseAtenDivIntScalarModule_basic", - "ElementwiseSubScalarFloatModule_basic", - "ElementwiseAddScalarFloatModule_basic", - "ElementwiseAddScalar_TensorLiteralInt32_Module_basic", - "ElementwiseMulScalarModule_float", - "ElementwiseCeilModule_basic", - "ElementwiseReciprocalModule_basic", - "ElementwiseIsnanModule_basic", - "ElementwiseIsinfModule_basic", - "TypePromotionAlphaWiderModule_basic", - "Conv2dWithPaddingDilationStrideStaticModule_basic", - "Conv2dWithPaddingDilationStrideStaticModule_depthwise", - "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", - "BatchNorm1DModule_basic", - "BatchNorm1DWith2DInputModule_basic", - "BatchNorm2DModule_basic", - "BatchNorm3DModule_basic", - "BatchNorm1DStaticShapeModule_basic", - "FlattenStaticModule_basic", - "UnflattenStaticModule_basic", - "FlattenRank0Module_basic", - "ElementwiseFlattenBroadcastModule_basic", - 
"SquareModule_basic", - "MaxPool2dStaticModule_basic", - "MaxPool2dStaticCeilModeTrueModule_basic", - "ResNet18StaticModule_basic", - "ReduceAmaxKeepDim_basic", - "NativeLayerNormModule4D_basic", - "LayerNormNormalizeOverAllDimsModule_basic", - "PermuteModule_basic", - "PermuteNegativeIndexModule_basic", - "ElementwiseLog2Module_basic", - "Threshold1dIntI32Module_basic", - "Threshold1dFloatModule_basic", - "Threshold2dFloatModule_basic", - "Threshold3dFloatModule_basic", - "ElementwiseSubScalarIntModule_basic", - "ElementwiseAddScalarIntModule_basic", - "ElementwiseMulScalarModule_basic", - "ZerosModuleDefaultDtype_basic", - "ZerosModuleInt2D_basic", - "ZerosModuleInt3D_basic", - "ZerosModuleFloat2D_basic", - "ZerosModuleFloat3D_basic", - "ZerosModuleFalsePinMemory_basic", - "OnesModuleDefaultDtype_basic", - "OnesModuleInt_basic", - "OnesModuleFloat_basic", - "OnesModuleFalsePinMemory_basic", - "OnesModuleCPUDevice_basic", - "NewZerosModuleDefaultDtype_basic", - "NewZerosModuleInt2D_basic", - "NewZerosModuleInt3D_basic", - "NewZerosModuleFloat2D_basic", - "NewZerosModuleFloat3D_basic", - "NewZerosModuleFalsePinMemory_basic", - "NewOnesModuleDefaultDtype_basic", - "NewOnesModuleInt2D_basic", - "NewOnesModuleInt3D_basic", - "NewOnesModuleFloat2D_basic", - "NewOnesModuleFloat3D_basic", - "NewOnesModuleFalsePinMemory_basic", - "SiluModule_basic", - "DropoutEvalIntModule_basic", - "DropoutEvalFloatModule_basic", - "ContiguousModule_basic", - "DropoutModule_basic", - "ViewExpandModule_basic", - "ViewExpandOnesModule_basic", - "ViewExpandOnesBeforeAndAfterModule_basic", - "ViewExpandOnesMiddleModule_basic", - "ViewExpandCollapseModule_basic", - "ViewExpandCollapseWithOnesModule_basic", - "ViewCollapseInferredDimModule_basic", - "ViewExpandInferredDimModule_basic", - "ViewNegativeStaticModule_basic", - "ViewNoChangeStaticModule_basic", - "UnsafeViewExpandModule_basic", - "ReshapeCollapseModule_basic", - "ReshapeAsModule_basic", - "ElementwiseGeluModule_basic", - "GeluBackwardModule_basic", - "ElementwiseNeIntScalarModule_basic", - "Convolution2DStaticModule_basic", - "ElementwiseNegModule_basic", - "TestMultipleTensorReturn_basic", - "TypeAsSameModule_basic", "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic", "AddCDivModule_basic", From ef559c5d97864faabde4bccbee2d72fe60378ffb Mon Sep 17 00:00:00 2001 From: Abhishek-TyRnT Date: Tue, 16 Jan 2024 08:55:37 +0530 Subject: [PATCH 05/12] git rid of iostream import --- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 0b5819631f52..a46fdc5f549c 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -26,7 +26,6 @@ #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" -#include using namespace mlir; using namespace mlir::torch; From 08a289f326bb139914388b1e309cdabde411c4ae Mon Sep 17 00:00:00 2001 From: Abhishek-TyRnT Date: Sat, 3 Feb 2024 13:50:02 +0530 Subject: [PATCH 06/12] using int in result shape --- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index a46fdc5f549c..001b9ad981a7 100644 --- 
a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -4102,11 +4102,12 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // The result will always be a 1-d tensor. // The size of the result is calculated as follows: // ceil((end - start)/step) - int64_t resultShape = ceil((end - start) / step); + Value result; if (is_all_inp_int) { - SmallVector values(resultShape, start); + int64_t resultShape = ceil((float)(end_int - start_int) / (float)(step_int)); + SmallVector values(resultShape, start_int); for (unsigned i = 1; i < resultShape; i++) values[i] += i * step; @@ -4115,6 +4116,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( else { + int64_t resultShape = ceil((end - start) / step); SmallVector values(resultShape, start); for (unsigned i = 1; i < resultShape; i++) values[i] += (i * step); From 7f3caa87ab5b7047946eadb911ff2af138261d4b Mon Sep 17 00:00:00 2001 From: Abhishek-TyRnT Date: Mon, 5 Feb 2024 22:25:27 +0530 Subject: [PATCH 07/12] got rid of resultshape for int case --- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 001b9ad981a7..dbab3d40d14c 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -4106,22 +4106,21 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Value result; if (is_all_inp_int) { - int64_t resultShape = ceil((float)(end_int - start_int) / (float)(step_int)); - SmallVector values(resultShape, start_int); - for (unsigned i = 1; i < resultShape; i++) - values[i] += i * step; + SmallVector values(start_int); + for (int64_t i = start_int; i < end_int; i += step_int) + values.push_back(i); - result = tosa::getConstTensor(rewriter, op, values, resultShape).value(); + result = tosa::getConstTensor(rewriter, op, values, values.size()).value(); } else { - int64_t resultShape = ceil((end - start) / step); - SmallVector values(resultShape, start); - for (unsigned i = 1; i < resultShape; i++) + int64_t resultSize = ceil((end - start) / step); + SmallVector values(resultSize, start); + for (unsigned i = 1; i < resultSize; i++) values[i] += (i * step); - result = tosa::getConstTensor(rewriter, op, values, resultShape).value(); + result = tosa::getConstTensor(rewriter, op, values, resultSize).value(); } rewriter.replaceOpWithNewOp(op, resultType, result); From 0f6ef1fc5f9cc91daf8d951d692a67391ac1ba6e Mon Sep 17 00:00:00 2001 From: Abhishek-TyRnT Date: Tue, 6 Feb 2024 22:35:54 +0530 Subject: [PATCH 08/12] got rid of result shape in all int case --- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index dbab3d40d14c..cb2b7fe2b8e5 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -4106,9 +4106,18 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Value result; if (is_all_inp_int) { - SmallVector values(start_int); - for (int64_t i = start_int; i < end_int; i += step_int) - values.push_back(i); + SmallVector values; + if (step_int >= 0) + { + for (int64_t i = start_int; i < end_int; i += step_int) + values.push_back(i); + } + + else + { + for (int64_t i = start_int; i > end_int; i += step_int) + values.push_back(i); + } result = tosa::getConstTensor(rewriter, op, values, values.size()).value(); } 
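Before the next revision, a plain-Python restatement may help: the sketch below (illustrative only, not part of the patch series; arange_int and arange_float are names invented here) mirrors the two branches above, where the all-integer path walks from start toward end honoring the sign of step, and the float path sizes the result as ceil((end - start) / step).

    import math

    def arange_int(start: int, end: int, step: int) -> list:
        # Mirrors the all-integer branch: iterate with a possibly negative step.
        assert step != 0
        values = []
        i = start
        if step > 0:
            while i < end:
                values.append(i)
                i += step
        else:
            while i > end:
                values.append(i)
                i += step
        return values

    def arange_float(start: float, end: float, step: float) -> list:
        # Mirrors the float branch: element count is ceil((end - start) / step).
        size = max(math.ceil((end - start) / step), 0)
        return [start + i * step for i in range(size)]

    assert arange_int(5, 0, -2) == [5, 3, 1]
    assert arange_float(0.0, 1.0, 0.25) == [0.0, 0.25, 0.5, 0.75]

For a negative float step the same count formula still works, since the numerator and the step are then both negative.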
From 8b57a512b597d5e5f01f265e3b93a2b246882cc8 Mon Sep 17 00:00:00 2001
From: Abhishek-TyRnT
Date: Thu, 15 Feb 2024 22:35:25 +0530
Subject: [PATCH 09/12] use static_cast instead of C-style casts

---
 lib/Conversion/TorchToTosa/TorchToTosa.cpp | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index cb2b7fe2b8e5..24ad70235607 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -4069,12 +4069,13 @@ LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
   double start, step, end;
   int64_t start_int, step_int, end_int;
-  bool is_all_inp_int; //Flag to check whether all inputs are integer
-  is_all_inp_int = op.getStart().getType().isa<Torch::IntType>() &&
-                   op.getEnd().getType().isa<Torch::IntType>() &&
-                   op.getStep().getType().isa<Torch::IntType>();
+  auto isInteger = [=](Value v) { return v.getType().isa<Torch::IntType>(); };
+  //Flag to check whether all inputs are integer
+  bool integer_range = isInteger(op.getStart()) && isInteger(op.getEnd()) &&
+                       isInteger(op.getStep());
 
   if (matchPattern(op.getStart(), m_TorchConstantInt(&start_int)))
   {
-    start = (double)(start_int);
+    start = static_cast<double>(start_int);
   }
 
   else if(!matchPattern(op.getStart(), m_TorchConstantFloat(&start)))
@@ -4083,7 +4084,7 @@ LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
 
   if (matchPattern(op.getEnd(), m_TorchConstantInt(&end_int)))
   {
-    end = (double)(end_int);
+    end = static_cast<double>(end_int);
   }
   else if (!matchPattern(op.getEnd(), m_TorchConstantFloat(&end)))
     return rewriter.notifyMatchFailure(
@@ -4092,7 +4093,7 @@ LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
 
   if (matchPattern(op.getStep(), m_TorchConstantInt(&step_int)))
   {
-    step = (double)(step_int);
+    step = static_cast<double>(step_int);
   }
 
   else if (!matchPattern(op.getStep(), m_TorchConstantFloat(&step)))
@@ -4104,7 +4105,7 @@ LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
   // ceil((end - start)/step)
 
   Value result;
-  if (is_all_inp_int)
+  if (integer_range)
   {
     SmallVector<int64_t> values;
     if (step_int >= 0)

From 3140ab1a5990507b696c088d597f8c6c06c5fa80 Mon Sep 17 00:00:00 2001
From: Abhishek-TyRnT
Date: Mon, 19 Feb 2024 20:19:09 +0530
Subject: [PATCH 10/12] typecasting for int64 output type

---
 lib/Conversion/TorchToTosa/TorchToTosa.cpp | 22 ++++++++++++++++++++++
 1 file changed, 22 insertions(+)

diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index 24ad70235607..5eb22637b12a 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -4070,6 +4070,16 @@ LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
   double start, step, end;
   int64_t start_int, step_int, end_int;
   auto isInteger = [=](Value v) { return v.getType().isa<Torch::IntType>(); };
+  bool isOutputInt64=false;
+  auto intType = resultType.getElementType().dyn_cast_or_null<mlir::IntegerType>();
+
+  if(intType)
+  {
+    if(intType.getWidth() == 64)
+    {
+      isOutputInt64 = true;
+    }
+  }
   //Flag to check whether all inputs are integer
   bool integer_range = isInteger(op.getStart()) && isInteger(op.getEnd()) &&
                        isInteger(op.getStep());
@@ -4123,6 +4133,18 @@ LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
     result = tosa::getConstTensor<int64_t>(rewriter, op, values, values.size()).value();
   }
 
+  // Typecasting from float32 or float64 to int64 produces seemingly garbage
+  // values, so the cast to int64 is done here instead.
+ else if(isOutputInt64) + { + int64_t resultSize = ceil((end - start) / step); + SmallVector values(resultSize, start); + for (unsigned i = 1; i < resultSize; i++) + values[i] += static_cast(i * step); + + result = tosa::getConstTensor(rewriter, op, values, resultSize).value(); + } + else { int64_t resultSize = ceil((end - start) / step); From 4c185db1cfef941d458e178e079112db8447e05e Mon Sep 17 00:00:00 2001 From: James Newling Date: Mon, 26 Feb 2024 11:49:43 -0800 Subject: [PATCH 11/12] git format, add some stylistic changes --- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 181 ++++++++++++--------- 1 file changed, 103 insertions(+), 78 deletions(-) diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 5eb22637b12a..06aaf8965acd 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -8,24 +8,22 @@ //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h" -#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h" -#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h" -#include "torch-mlir/Conversion/Utils/Utils.h" - #include "../PassDetail.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" -#include "mlir/Dialect/Traits.h" #include "mlir/IR/Matchers.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" +#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h" +#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h" +#include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" -#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" +#include using namespace mlir; using namespace mlir::torch; @@ -4067,93 +4065,120 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op, "unimplemented: pin_memory must be either None or false"); } - double start, step, end; - int64_t start_int, step_int, end_int; - auto isInteger = [=](Value v) { return v.getType().isa(); }; - bool isOutputInt64=false; - auto intType = resultType.getElementType().dyn_cast_or_null(); - - if(intType) - { - if(intType.getWidth() == 64) - { - isOutputInt64 = true; + // Stores a range value (start / end / step) and whether it was initiated with + // a constant integer, an constant float or neither. 
+ class ConstRangeValue { + public: + explicit ConstRangeValue(double v) + : vDouble(v), fromDouble(true), vInt(static_cast(v)), + fromInt(false) {} + + explicit ConstRangeValue(int64_t v) + : vDouble(static_cast(v)), fromDouble(false), vInt(v), + fromInt(true) {} + + ConstRangeValue() + : vDouble(0), fromDouble(false), vInt(0), fromInt(false) {} + + bool hasConstInt() const { return fromInt; } + bool hasConstDouble() const { return fromDouble; } + bool hasConst() const { return fromInt || fromDouble; } + double getDouble() const { return vDouble; } + int64_t getInt() const { return vInt; } + + private: + double vDouble; + bool fromDouble; + int64_t vInt; + bool fromInt; + }; + + auto setConstantIntOrFloat = [](Value v) -> ConstRangeValue { + int64_t intVal{0}; + double floatVal{0.0}; + if (matchPattern(v, m_TorchConstantFloat(&floatVal))) { + return ConstRangeValue(floatVal); + } else if (matchPattern(v, m_TorchConstantInt(&intVal))) { + return ConstRangeValue(intVal); } - } - //Flag to check whether all inputs are integer - bool integer_range = isInteger(op.getStart()) && isInteger(op.getEnd()) && isInteger(op.getStep()); - - if (matchPattern(op.getStart(), m_TorchConstantInt(&start_int))) - { - start = static_cast(start_int); - } + return ConstRangeValue(); + }; - else if(!matchPattern(op.getStart(), m_TorchConstantFloat(&start))) + auto start = setConstantIntOrFloat(op.getStart()); + if (!start.hasConst()) { return rewriter.notifyMatchFailure( - op, "unimplemented: value `start` should be a torch constant int or float"); - - if (matchPattern(op.getEnd(), m_TorchConstantInt(&end_int))) - { - end = static_cast(end_int); + op, "unimplemented: case where `start` is not a constant int or float"); } - else if (!matchPattern(op.getEnd(), m_TorchConstantFloat(&end))) + + auto end = setConstantIntOrFloat(op.getEnd()); + if (!end.hasConst()) { return rewriter.notifyMatchFailure( - op, "unimplemented: value `end` should be a torch constant int or float"); + op, + "unimplemented: case where value `end` is not a constant int or float"); + } - if (matchPattern(op.getStep(), m_TorchConstantInt(&step_int))) - { - - step = static_cast(step_int); + auto step = setConstantIntOrFloat(op.getStep()); + if (!step.hasConst()) { + return rewriter.notifyMatchFailure(op, + "unimplemented: case where value `step` " + "is not a constant int or float"); } - else if (!matchPattern(op.getStep(), m_TorchConstantFloat(&step))) - return rewriter.notifyMatchFailure( - op, "unimplemented: value `step` should be a torch constant int or float"); + auto getRange = [](auto start, auto end, auto step) { + // Initialize a small vector of the same type as start: + using T = decltype(start); + SmallVector values; - // The result will always be a 1-d tensor. 
- // The size of the result is calculated as follows: - // ceil((end - start)/step) - - Value result; - if (integer_range) - { - SmallVector values; - if (step_int >= 0) - { - for (int64_t i = start_int; i < end_int; i += step_int) - values.push_back(i); + uint64_t counter{0}; + if (start == end) { + return values; } - - else - { - for (int64_t i = start_int; i > end_int; i += step_int) - values.push_back(i); + assert(step != T(0)); + values.reserve( + 1 + static_cast(std::abs((end - start) / std::abs(step)))); + if (step > 0) { + while (start + T(counter) * step < end) { + values.push_back(start + counter * step); + counter++; + } + } else { + while (start + T(counter) * step > end) { + values.push_back(start + counter * step); + counter++; + } } + return values; + }; - result = tosa::getConstTensor(rewriter, op, values, values.size()).value(); - } + const auto intType = + resultType.getElementType().dyn_cast_or_null(); - //Since typecasting from float32 or float64 to int64 results in, seemingly - //garbage values. Therefore typecasting here itself. - else if(isOutputInt64) - { - int64_t resultSize = ceil((end - start) / step); - SmallVector values(resultSize, start); - for (unsigned i = 1; i < resultSize; i++) - values[i] += static_cast(i * step); - - result = tosa::getConstTensor(rewriter, op, values, resultSize).value(); - } + auto maybeResult = [&]() -> std::optional { + if (intType && start.hasConstInt() && end.hasConstInt() && + step.hasConstInt()) { + auto values = getRange(start.getInt(), end.getInt(), step.getInt()); + return tosa::getConstTensor(rewriter, op, values, values.size()); + } - else - { - int64_t resultSize = ceil((end - start) / step); - SmallVector values(resultSize, start); - for (unsigned i = 1; i < resultSize; i++) - values[i] += (i * step); - - result = tosa::getConstTensor(rewriter, op, values, resultSize).value(); + auto values = + getRange(start.getDouble(), end.getDouble(), step.getDouble()); + if (intType) { + SmallVector values_i64; + values_i64.reserve(values.size()); + for (auto v : values) { + values_i64.push_back(static_cast(v)); + } + return tosa::getConstTensor(rewriter, op, values_i64, + values.size()); + } + return tosa::getConstTensor(rewriter, op, values, values.size()); + }(); + + if (!maybeResult.has_value()) { + return rewriter.notifyMatchFailure( + op, "failed to generate constant tensor for arange"); } + auto result = maybeResult.value(); rewriter.replaceOpWithNewOp(op, resultType, result); return success(); From 9b4ae1e06805ce2453c10e4c9b517bea5a9a1476 Mon Sep 17 00:00:00 2001 From: James Newling Date: Mon, 26 Feb 2024 12:21:58 -0800 Subject: [PATCH 12/12] update --- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 59 ++++++++++++++-------- 1 file changed, 39 insertions(+), 20 deletions(-) diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 06aaf8965acd..ce0a1af2f834 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -22,6 +22,7 @@ #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" +#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" #include @@ -4065,8 +4066,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op, "unimplemented: pin_memory must be either None or false"); } - // Stores a range value (start / end / step) and whether it 
was initiated with - // a constant integer, an constant float or neither. + // Stores a range value (a start, end, or step value) and whether or not it + // was initiated with a constant integer, an constant float or neither. class ConstRangeValue { public: explicit ConstRangeValue(double v) @@ -4077,9 +4078,21 @@ LogicalResult ConvertAtenOp::matchAndRewrite( : vDouble(static_cast(v)), fromDouble(false), vInt(v), fromInt(true) {} + // Constructor for the case where there is no constant value to use. ConstRangeValue() : vDouble(0), fromDouble(false), vInt(0), fromInt(false) {} + static ConstRangeValue fromValue(Value v) { + int64_t intVal{0}; + double floatVal{0.0}; + if (matchPattern(v, m_TorchConstantFloat(&floatVal))) { + return ConstRangeValue(floatVal); + } else if (matchPattern(v, m_TorchConstantInt(&intVal))) { + return ConstRangeValue(intVal); + } + return ConstRangeValue(); + } + bool hasConstInt() const { return fromInt; } bool hasConstDouble() const { return fromDouble; } bool hasConst() const { return fromInt || fromDouble; } @@ -4093,31 +4106,20 @@ LogicalResult ConvertAtenOp::matchAndRewrite( bool fromInt; }; - auto setConstantIntOrFloat = [](Value v) -> ConstRangeValue { - int64_t intVal{0}; - double floatVal{0.0}; - if (matchPattern(v, m_TorchConstantFloat(&floatVal))) { - return ConstRangeValue(floatVal); - } else if (matchPattern(v, m_TorchConstantInt(&intVal))) { - return ConstRangeValue(intVal); - } - return ConstRangeValue(); - }; - - auto start = setConstantIntOrFloat(op.getStart()); + auto start = ConstRangeValue::fromValue(op.getStart()); if (!start.hasConst()) { return rewriter.notifyMatchFailure( op, "unimplemented: case where `start` is not a constant int or float"); } - auto end = setConstantIntOrFloat(op.getEnd()); + auto end = ConstRangeValue::fromValue(op.getEnd()); if (!end.hasConst()) { return rewriter.notifyMatchFailure( op, "unimplemented: case where value `end` is not a constant int or float"); } - auto step = setConstantIntOrFloat(op.getStep()); + auto step = ConstRangeValue::fromValue(op.getStep()); if (!step.hasConst()) { return rewriter.notifyMatchFailure(op, "unimplemented: case where value `step` " @@ -4150,19 +4152,24 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return values; }; - const auto intType = + const auto isIntType = resultType.getElementType().dyn_cast_or_null(); + const auto isDoubleType = + resultType.getElementType().dyn_cast_or_null(); + auto maybeResult = [&]() -> std::optional { - if (intType && start.hasConstInt() && end.hasConstInt() && + // Integer output type, and start / end / range are all integers. + if (isIntType && start.hasConstInt() && end.hasConstInt() && step.hasConstInt()) { auto values = getRange(start.getInt(), end.getInt(), step.getInt()); return tosa::getConstTensor(rewriter, op, values, values.size()); } + // Get a double range. 
       auto values =
           getRange(start.getDouble(), end.getDouble(), step.getDouble());
-      if (intType) {
+      if (isIntType) {
         SmallVector<int64_t> values_i64;
         values_i64.reserve(values.size());
         for (auto v : values) {
           values_i64.push_back(static_cast<int64_t>(v));
         }
         return tosa::getConstTensor<int64_t>(rewriter, op, values_i64,
                                              values.size());
       }
-      return tosa::getConstTensor<double>(rewriter, op, values, values.size());
+
+      if (!isDoubleType) {
+        return {};
+      }
+
+      SmallVector<float> values_f32;
+      values_f32.reserve(values.size());
+      for (auto v : values) {
+        values_f32.push_back(static_cast<float>(v));
+      }
+      auto vs = tosa::getConstTensor<float>(rewriter, op, values_f32,
+                                            values_f32.size());
+      return vs;
     }();
 
     if (!maybeResult.has_value()) {
       return rewriter.notifyMatchFailure(
           op, "failed to generate constant tensor for arange");
     }
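Circling back to PATCH 01: the two chained linalg.generic ops emitted for aten.norm.Scalar compute sum(|x|^p) followed by the 1/p power, which is the standard vector p-norm. A small eager-mode reference (illustrative only, not part of the patch series; norm_scalar_reference is a name invented here) matching what the NormScalarModule_basic e2e test exercises:

    import torch

    def norm_scalar_reference(a: torch.Tensor, p: float) -> torch.Tensor:
        # First reduction: sum of |x|**p; second reduction: the 1/p root.
        return a.abs().pow(p).sum().pow(1.0 / p)

    a = torch.rand(3, 4, 5)
    assert torch.allclose(torch.ops.aten.norm(a, 3.0), norm_scalar_reference(a, 3.0))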