From b1422fa468d3b13d3250dba91c55ff7468c58a87 Mon Sep 17 00:00:00 2001
From: Alexander de Silva
Date: Thu, 28 Dec 2023 19:56:35 +0000
Subject: [PATCH] OnnxToTorch support for onnx.InstanceNormalization op

---
 .../Dialect/Torch/IR/GeneratedTorchOps.td     |  31 +++
 .../TorchOnnxToTorch/DefaultDomainGtoP.cpp    |  31 +++
 .../TorchToLinalg/Uncategorized.cpp           | 195 ++++++++++++++++++
 .../Transforms/AbstractInterpLibrary.cpp      |   8 +
 .../Torch/Transforms/DecomposeComplexOps.cpp  | 146 +++++++++++++
 projects/pt1/e2e_testing/xfail_sets.py        |   5 +
 .../build_tools/abstract_interp_lib_gen.py    |   8 +
 .../build_tools/torch_ods_gen.py              |   3 +
 .../test_suite/norm_like.py                   |  37 ++++
 .../TorchOnnxToTorch/simple_ops_g_to_p.mlir   |   9 +
 10 files changed, 473 insertions(+)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index c09900ce8eccf..8a1aa4724d59e 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -5973,6 +5973,37 @@ def Torch_AtenBatchNormOp : Torch_Op<"aten.batch_norm", [
   }];
 }
 
+def Torch_AtenInstanceNormOp : Torch_Op<"aten.instance_norm", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::instance_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$input,
+    AnyTorchOptionalTensorType:$weight,
+    AnyTorchOptionalTensorType:$bias,
+    AnyTorchOptionalTensorType:$running_mean,
+    AnyTorchOptionalTensorType:$running_var,
+    Torch_BoolType:$use_input_stats,
+    Torch_FloatType:$momentum,
+    Torch_FloatType:$eps,
+    Torch_BoolType:$cudnn_enabled
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenInstanceNormOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 9, 1);
+    }
+    void AtenInstanceNormOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 9, 1);
+    }
+  }];
+}
+
 def Torch_AtenNativeGroupNormOp : Torch_Op<"aten.native_group_norm", [
   AllowsTypeRefinement,
   HasValueSemantics,
diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
index df20a83515bf3..7c47c375d0d3f 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -335,6 +335,37 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
             binder.op, resultType, lhs, rhs);
         return success();
       });
+  patterns.onOp(
+      "InstanceNormalization", 6,
+      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Torch::ValueTensorType resultType;
+        llvm::SmallVector<Value> operands;
+        float eps;
+
+        if (binder.tensorOperands(operands, 3) ||
+            binder.tensorResultType(resultType) || operands.size() != 3 ||
+            binder.f32FloatAttr(eps, "epsilon", 1e-05f)) {
+          return failure();
+        }
+        Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
+        Value boolTrue =
+            rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
+        Value boolFalse =
+            rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
+        auto epsValue = rewriter.create<Torch::ConstantFloatOp>(
+            binder.getLoc(), rewriter.getF64FloatAttr(eps));
+
+        auto momentum = rewriter.create<Torch::ConstantFloatOp>(
+            binder.getLoc(), rewriter.getF64FloatAttr(0.0f));
+        rewriter.replaceOpWithNewOp<Torch::AtenInstanceNormOp>(
+            binder.op, resultType, /* input */ operands[0],
+            /* weight */ operands[1],
+            /* bias */ operands[2], /* running mean */ none,
+            /* running var */ none,
+            /* use input stats */ boolTrue, momentum, epsValue,
+            /* cudnn enabled */ boolFalse);
+        return success();
+      });
   patterns.onOp(
       "Max", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
         Torch::ValueTensorType resultType;
diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
index 54317979353d9..0f325981471c2 100644
--- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp
+++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -1826,6 +1826,199 @@ class ConvertAtenBatchNormOp : public OpConversionPattern<AtenBatchNormOp> {
 };
 } // namespace
 
+namespace {
+class ConvertAtenInstanceNormOp
+    : public OpConversionPattern<AtenInstanceNormOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenInstanceNormOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    MLIRContext *context = op->getContext();
+    Location loc = op->getLoc();
+    Value input = adaptor.getInput();
+    Value scale = adaptor.getWeight();
+    Value bias = adaptor.getBias();
+    Value eps = adaptor.getEps();
+
+    auto inputType = input.getType().cast<RankedTensorType>();
+    auto inputRank = inputType.getRank();
+
+    SmallVector<AffineExpr> ncExpr;
+    ncExpr.push_back(mlir::getAffineDimExpr(0, context));
+    ncExpr.push_back(mlir::getAffineDimExpr(1, context));
+
+    auto ncIndexingMap = AffineMap::get(
+        /*dimCount=*/inputRank,
+        /*symbolCount=*/0, ncExpr, context);
+
+    SmallVector<AffineExpr> cExpr;
+    cExpr.push_back(mlir::getAffineDimExpr(1, context));
+
+    auto cIndexingMap = AffineMap::get(
+        /*dimCount=*/inputRank,
+        /*symbolCount=*/0, cExpr, context);
+
+    SmallVector<AffineMap> indexingMaps = {
+        rewriter.getMultiDimIdentityMap(inputRank), // input
+        ncIndexingMap,                              // output
+    };
+
+    Type resultElementType = inputType.getElementType();
+    auto inputSize = getTensorSizes(rewriter, loc, input);
+    SmallVector<Value> ncSize({inputSize[0], inputSize[1]});
+
+    Value meanTensor =
+        createZeroInitTensor(rewriter, loc, ncSize, resultElementType);
+    Value varTensor =
+        createZeroInitTensor(rewriter, loc, ncSize, resultElementType);
+
+    SmallVector<utils::IteratorType> iteratorTypes = {
+        utils::IteratorType::parallel, utils::IteratorType::parallel,
+        utils::IteratorType::reduction, utils::IteratorType::reduction};
+
+    Value sumPool2d =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, meanTensor.getType(), ValueRange{input}, meanTensor,
+                /*indexingMaps=*/indexingMaps,
+                /*iteratorTypes=*/iteratorTypes,
+                [&](OpBuilder &b, Location loc, ValueRange args) {
+                  Value input = args[0], sum = args[1];
+                  Value result = b.create<arith::AddFOp>(loc, input, sum);
+                  b.create<linalg::YieldOp>(loc, result);
+                })
+            .getResult(0);
+
+    indexingMaps = {
+        rewriter.getMultiDimIdentityMap(2), // sumPool2d
+        rewriter.getMultiDimIdentityMap(2), // output
+    };
+
+    iteratorTypes = {utils::IteratorType::parallel,
+                     utils::IteratorType::parallel};
+    Value mean =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, meanTensor.getType(), ValueRange{sumPool2d}, meanTensor,
+                indexingMaps, iteratorTypes,
+                [&](OpBuilder &b, Location loc, ValueRange args) {
+                  Value input = args[0];
+                  Value hw = b.create<arith::ConstantOp>(
+                      loc, FloatAttr::get(resultElementType,
+                                          inputType.getShape()[2] *
+                                              inputType.getShape()[3]));
+                  Value result = b.create<arith::DivFOp>(loc, input, hw);
+                  b.create<linalg::YieldOp>(loc, result);
+                })
+            .getResult(0);
+
+    indexingMaps = {
+        rewriter.getMultiDimIdentityMap(inputRank), // input
+        ncIndexingMap,                              // mean
+        ncIndexingMap,                              // output
+    };
+
+    iteratorTypes = {
+        utils::IteratorType::parallel,
+        utils::IteratorType::parallel,
+        utils::IteratorType::reduction,
+        utils::IteratorType::reduction,
+    };
+    // (input - mean) ^ 2
+    Value varianceNumerator =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, varTensor.getType(), ValueRange{input, mean}, varTensor,
+                indexingMaps, iteratorTypes,
+                [&](OpBuilder &b, Location loc, ValueRange args) {
+                  Value input = args[0], mean = args[1], output = args[2];
+                  Value two = b.create<arith::ConstantOp>(
+                      loc, FloatAttr::get(resultElementType, 2));
+                  Value inputSubMean =
+                      b.create<arith::SubFOp>(loc, input, mean);
+                  Value squared =
+                      b.create<math::PowFOp>(loc, inputSubMean, two);
+                  Value sum = b.create<arith::AddFOp>(loc, squared, output);
+                  b.create<linalg::YieldOp>(loc, sum);
+                })
+            .getResult(0);
+
+    indexingMaps = {
+        rewriter.getMultiDimIdentityMap(2), // sumPool2d
+        rewriter.getMultiDimIdentityMap(2), // output
+    };
+
+    iteratorTypes = {
+        utils::IteratorType::parallel,
+        utils::IteratorType::parallel,
+    };
+
+    Value variance =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, varTensor.getType(), ValueRange{varianceNumerator},
+                varTensor, indexingMaps, iteratorTypes,
+                [&](OpBuilder &b, Location loc, ValueRange args) {
+                  Value numerator = args[0];
+                  Value hw = b.create<arith::ConstantOp>(
+                      loc, FloatAttr::get(resultElementType,
+                                          inputType.getShape()[2] *
+                                              inputType.getShape()[3]));
+                  Value sum = b.create<arith::DivFOp>(loc, numerator, hw);
+                  b.create<linalg::YieldOp>(loc, sum);
+                })
+            .getResult(0);
+
+    iteratorTypes = {
+        utils::IteratorType::parallel,
+        utils::IteratorType::parallel,
+        utils::IteratorType::parallel,
+        utils::IteratorType::parallel,
+    };
+    indexingMaps = {
+        rewriter.getMultiDimIdentityMap(inputRank), // input
+        ncIndexingMap,                              // mean
+        ncIndexingMap,                              // variance
+        cIndexingMap,                               // scale
+        cIndexingMap,                               // bias
+        rewriter.getMultiDimIdentityMap(inputRank), // output
+    };
+
+    Value outTensor =
+        createZeroInitTensor(rewriter, loc, inputSize, resultElementType);
+
+    Value instNorm =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, outTensor.getType(),
+                ValueRange{input, mean, variance, scale, bias}, outTensor,
+                indexingMaps, iteratorTypes,
+                [&](OpBuilder &b, Location loc, ValueRange args) {
+                  Value input = args[0], mean = args[1], var = args[2],
+                        scale = args[3], bias = args[4];
+                  Value inputSubMean =
+                      b.create<arith::SubFOp>(loc, input, mean);
+                  Value truncatedEps =
+                      b.create<arith::TruncFOp>(loc, var.getType(), eps);
+                  Value varPlusEps =
+                      b.create<arith::AddFOp>(loc, var, truncatedEps);
+                  Value rSTD = b.create<math::RsqrtOp>(loc, varPlusEps);
+                  Value temp = b.create<arith::MulFOp>(loc, inputSubMean, rSTD);
+                  Value timesScale = b.create<arith::MulFOp>(loc, temp, scale);
+                  Value plusBias =
+                      b.create<arith::AddFOp>(loc, timesScale, bias);
+                  b.create<linalg::YieldOp>(loc, plusBias);
+                })
+            .getResult(0);
+    Type newResultType = getTypeConverter()->convertType(op.getType());
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, instNorm);
+
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class ConvertAtenNllLossBackwardOp
     : public OpConversionPattern<AtenNllLossBackwardOp> {
@@ -2367,6 +2560,8 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
   patterns.add(typeConverter, context);
   target.addIllegalOp();
   patterns.add(typeConverter, context);
+  target.addIllegalOp<AtenInstanceNormOp>();
+  patterns.add<ConvertAtenInstanceNormOp>(typeConverter, context);
   target.addIllegalOp();
   patterns.add(typeConverter, context);
   target.addIllegalOp();
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index bb9717303e6ba..d338d672de9d9 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -8744,6 +8744,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %3 = torch.prim.TupleConstruct %0, %1, %2 : !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>, list<int>>\n"
 "    return %3 : !torch.tuple<list<int>, list<int>, list<int>>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.instance_norm\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.optional<list<int>>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float, %arg8: !torch.bool) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.slice.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.int) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.slice(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
@@ -9588,6 +9592,10 @@
 "    %3 = torch.prim.TupleConstruct %0#1, %0#1, %0#1 : !torch.int, !torch.int, !torch.int -> !torch.tuple<int, int, int>\n"
 "    return %3 : !torch.tuple<int, int, int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.instance_norm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<tuple<int, int>>, %arg2: !torch.optional<tuple<int, int>>, %arg3: !torch.optional<tuple<int, int>>, %arg4: !torch.optional<tuple<int, int>>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float, %arg8: !torch.bool) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.bernoulli_.float\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.float, %arg2: !torch.any) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index d1794de930b4e..4925b3fcc4687 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -3885,6 +3885,151 @@ class DecomposeAtenLayerNormOp : public OpRewritePattern<AtenLayerNormOp> {
 };
 } // namespace
 
+namespace {
+class DecomposeAtenInstanceNormOp
+    : public OpRewritePattern<AtenInstanceNormOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenInstanceNormOp op,
+                                PatternRewriter &rewriter) const override {
+
+    Location loc = op.getLoc();
+    auto context = op.getContext();
+
+    auto inputTy = op.getInput().getType().cast<BaseTensorType>();
+    int64_t inputRank = inputTy.getSizes().size();
+    auto reduceDimInts =
+        llvm::SmallVector<int64_t>({inputRank - 2, inputRank - 1});
+
+    SmallVector<int64_t> reducedShape(inputTy.getSizes());
+    reducedShape[inputRank - 1] = 1;
+    reducedShape[inputRank - 2] = 1;
+
+    Type dtype = inputTy.getOptionalDtype();
+    Type reducedTy = ValueTensorType::get(op.getContext(),
+                                          llvm::ArrayRef(reducedShape), dtype);
+
+    auto sizeListType = ListType::get(IntType::get(context));
+    SmallVector<Value> reduceDimVals;
+    reduceDimVals.reserve(reduceDimInts.size());
+    std::transform(reduceDimInts.begin(), reduceDimInts.end(),
+                   std::back_inserter(reduceDimVals), [&](int64_t d) {
+                     return rewriter.create<Torch::ConstantIntOp>(
+                         loc, rewriter.getI64IntegerAttr(d));
+                   });
+    Value reduceDimList =
+        rewriter.create<PrimListConstructOp>(loc, sizeListType, reduceDimVals);
+    Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(loc, true);
+    Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
+
+    Value one = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(1));
+
+    // mean(x)
+    Value inputMean = rewriter.create<AtenMeanDimOp>(
+        loc, reducedTy, op.getInput(), reduceDimList, cstTrue, none);
+
+    // x - mean(x)
+    Value inputMeanExpanded =
+        rewriter.create<AtenExpandAsOp>(loc, inputTy, inputMean, op.getInput());
+    Value inputSubMean = rewriter.create<AtenSubTensorOp>(
+        loc, inputTy, op.getInput(), inputMeanExpanded, one);
+    // (x - mean(x))^2
+    Value inputSubMeanSquare = rewriter.create<AtenMulTensorOp>(
+        loc, inputTy, inputSubMean, inputSubMean);
+
+    Value variancesum = rewriter.create<AtenSumDimIntListOp>(
+        loc, reducedTy, inputSubMeanSquare, reduceDimList, cstTrue,
+        /*dtype=*/none);
+
+    Value hw = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(inputTy.getSizes()[inputRank - 1] *
+                                        inputTy.getSizes()[inputRank - 2]));
+    Value inputVar =
+        rewriter.create<AtenDivScalarOp>(loc, reducedTy, variancesum, hw);
+
+    // rsqrt(var(x) + eps)
+    Value inputVarPlusEps = rewriter.create<AtenAddScalarOp>(
+        loc, reducedTy, inputVar, op.getEps(), one);
+    Value inputRsqrtVar =
+        rewriter.create<AtenRsqrtOp>(loc, reducedTy, inputVarPlusEps);
+
+    // (x - mean(x)) * rsqrt(var(x) + eps)
+    Value inputRsqrtVarExpanded = rewriter.create<AtenExpandAsOp>(
+        loc, inputTy, inputRsqrtVar, op.getInput());
+    Value inputNormalized = rewriter.create<AtenMulTensorOp>(
+        loc, inputTy, inputSubMean, inputRsqrtVarExpanded);
+    Value out = rewriter.create<TensorStaticInfoCastOp>(
+        loc, op.getResult().getType(), inputNormalized);
+
+    Value weight = op.getWeight();
+    auto weightTy = weight.getType().cast<BaseTensorType>();
+    dtype = weightTy.getOptionalDtype();
+
+    SmallVector<int64_t> weightShape(weightTy.getSizes());
+    SmallVector<int64_t> newWeightShape;
+    newWeightShape.push_back(1);
+    newWeightShape.append(weightShape);
+
+    Value zero = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(0));
+    Type newWeightTy = ValueTensorType::get(
+        op.getContext(), llvm::ArrayRef(newWeightShape), dtype);
+    weight = rewriter.create<AtenUnsqueezeOp>(loc, newWeightTy, weight, zero);
+
+    Value two = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(2));
+    newWeightShape.push_back(1);
+    newWeightTy = ValueTensorType::get(op.getContext(),
+                                       llvm::ArrayRef(newWeightShape), dtype);
+    weight = rewriter.create<AtenUnsqueezeOp>(loc, newWeightTy, weight, two);
+
+    Value three = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(3));
+    newWeightShape.push_back(1);
+    newWeightTy = ValueTensorType::get(op.getContext(),
+                                       llvm::ArrayRef(newWeightShape), dtype);
+    weight = rewriter.create<AtenUnsqueezeOp>(loc, newWeightTy, weight, three);
+
+    Value weightExpanded =
+        rewriter.create<AtenExpandAsOp>(loc, inputTy, weight, op.getInput());
+
+    Value bias = op.getBias();
+    auto biasTy = bias.getType().cast<BaseTensorType>();
+    dtype = biasTy.getOptionalDtype();
+
+    SmallVector<int64_t> biasShape(biasTy.getSizes());
+    SmallVector<int64_t> newBiasShape;
+    newBiasShape.push_back(1);
+    newBiasShape.append(biasShape);
+
+    Type newBiasTy = ValueTensorType::get(op.getContext(),
+                                          llvm::ArrayRef(newBiasShape), dtype);
+    bias = rewriter.create<AtenUnsqueezeOp>(loc, newBiasTy, bias, zero);
+
+    newBiasShape.push_back(1);
+    newBiasTy = ValueTensorType::get(op.getContext(),
+                                     llvm::ArrayRef(newBiasShape), dtype);
+    bias = rewriter.create<AtenUnsqueezeOp>(loc, newBiasTy, bias, two);
+
+    newBiasShape.push_back(1);
+    newBiasTy = ValueTensorType::get(op.getContext(),
+                                     llvm::ArrayRef(newBiasShape), dtype);
+    bias = rewriter.create<AtenUnsqueezeOp>(loc, newBiasTy, bias, three);
+
+    Value biasExpanded =
+        rewriter.create<AtenExpandAsOp>(loc, inputTy, bias, op.getInput());
+
+    out = rewriter.create<AtenMulTensorOp>(loc, out.getType(), out,
+                                           weightExpanded);
+    out = rewriter.create<AtenAddTensorOp>(loc, out.getType(), out,
+                                           biasExpanded, one);
+
+    rewriter.replaceOp(op, out);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeAtenNativeLayerNormOp
     : public OpRewritePattern<AtenNativeLayerNormOp> {
@@ -6656,6 +6801,7 @@ class DecomposeComplexOpsPass
         DecomposeAtenAddCLikeOp>(patterns);
     addPatternIfTargetOpIsIllegal<
         DecomposeAtenAddCLikeOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenInstanceNormOp>(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index f43c325069ceb..cba859a040771 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -997,6 +997,7 @@
     "AtenEyeModuleFalsePinMemory_basic",
     "AtenEyeModuleFloat2D_basic",
     "AtenRoundIntModule_basic",
+    "AtenInstanceNormModule_basic",
     "AtenToDeviceModule_basic",
     "BaddbmmBroadcast1DInputModule_basic",
     "BaddbmmBroadcast2DInputModule_basic",
@@ -1188,6 +1189,7 @@
     "IndexPutImpl2DNoneIndexStaticModule_basic",
     "IndexTensorMultiIndexStaticModule_basic",
     "IndexTensorStaticModule_basic",
+    "InstanceNormModule_basic",
     "IscloseStaticModuleTrue_basic",
     "IscloseStaticModule_basic",
     "LayerNormNormalizeOverAllDimsModule_basic",
@@ -1401,6 +1403,9 @@
     "Conv2dNoPaddingModule_basic",
     "Conv2dWithPaddingDilationStrideModule_basic",
     "Conv2dWithPaddingModule_basic",
+
+    "AtenInstanceNormModule_basic",
+    "InstanceNormModule_basic",
 }
 
 LTC_CRASHING_SET = {
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
index 91e98d99c9ffe..8ca8d8071ab3f 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -1388,6 +1388,9 @@ def aten〇group_norm〡shape(input: List[int], num_groups: int, weight: Optiona
 def aten〇native_group_norm〡shape(input: List[int], weight: Optional[List[int]], bias: Optional[List[int]], N: int, C: int, HxW: int, group: int, eps: float) -> Tuple[List[int], List[int], List[int]]:
     return upstream_shape_functions.unary(input), [N, group], [N, group]
 
+def aten〇instance_norm〡shape(input: List[int], weight: Optional[List[int]], bias: Optional[List[int]], running_mean: Optional[List[int]], running_var: Optional[List[int]], use_input_stats: bool, momentum: float, eps: float, cudnn_enabled: bool) -> List[int]:
+    return upstream_shape_functions.unary(input)
+
 def aten〇slice〇Tensor〡shape(self: List[int], dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> List[int]:
     return upstream_shape_functions.slice(self, dim, start, end, step)
 
@@ -2006,6 +2009,11 @@ def aten〇native_group_norm〡dtype(input_rank_dtype: Tuple[int, int], weight_r
     assert not is_integer_dtype(input_dtype)
     return input_dtype, input_dtype, input_dtype
 
+# device is not supported, hence we are unable to check the dtype function
+def aten〇instance_norm〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Optional[Tuple[int, int]], bias_rank_dtype: Optional[Tuple[int, int]], running_mean_rank_dtype: Optional[Tuple[int, int]], running_var_rank_dtype: Optional[Tuple[int, int]], use_input_stats: bool, momentum: float, eps: float, cudnn_enabled: bool) -> int:
+    input_rank, input_dtype = input_rank_dtype
+    return input_dtype
+
 @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
 def aten〇bernoulli_〇float〡dtype(self_rank_dtype: Tuple[int, int], p: float = 0.5, generator: Any = None) -> int:
     self_rank, self_dtype = self_rank_dtype
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index 3b930c20e79d5..a9506d5376572 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -432,6 +432,9 @@ def emit_with_mutating_variants(key, **kwargs):
     emit(
         "aten::batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)"
     )
+    emit(
+        "aten::instance_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)"
+    )
     emit(
         "aten::native_group_norm : (Tensor, Tensor?, Tensor?, int, int, int, int, float) -> (Tensor, Tensor, Tensor)"
     )
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/norm_like.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/norm_like.py
index 3b17f516f9e57..2c58ad7f639e3 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/norm_like.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/norm_like.py
@@ -489,3 +489,40 @@ def forward(self, x):
 def LayerNormNormalizeOverAllDimsModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 2, 3))
 
+class InstanceNormModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.inorm = torch.nn.InstanceNorm2d(2)
+        self.inorm.weight = torch.nn.Parameter(torch.tensor([1.0, 1.5]))
+        self.inorm.bias = torch.nn.Parameter(torch.tensor([0.0, 1.0]))
+
+    @export
+    @annotate_args([
+        None,
+        ([1, 2, 1, 3], torch.float32, True),
+    ])
+    def forward(self, x):
+        return self.inorm(x)
+
+@register_test_case(module_factory=lambda: InstanceNormModule())
+def InstanceNormModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 2, 1, 3))
+
+class AtenInstanceNormModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([1, 2, 1, 3], torch.float32, True),
+        ([2], torch.float32, True),
+        ([2], torch.float32, True)
+    ])
+    def forward(self, x, w, b):
+        return torch.ops.aten.instance_norm(x, w, b, None,
+                                            None, True, 0.0, 1e-05, False)
+
+@register_test_case(module_factory=lambda: AtenInstanceNormModule())
+def AtenInstanceNormModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 2, 1, 3), tu.rand(2), tu.rand(2))
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
index 449b7e4feb32d..866bfc0c3c819 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
@@ -555,6 +555,15 @@ func.func @test_globalaveragepool_precomputed(%arg0: !torch.vtensor<[1,1,3,3],f3
 
 // -----
 
+// CHECK-LABEL: func.func @test_instancenorm
+  func.func @test_instancenorm(%arg0: !torch.vtensor<[1,2,1,3],f32>, %arg1: !torch.vtensor<[2],f32>, %arg2: !torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2,1,3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+    // CHECK: torch.aten.instance_norm %arg0, %arg1, %arg2, %none, %none, %true, %float0.000000e00, %float9.999990e-06, %false : !torch.vtensor<[1,2,1,3],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.none, !torch.none, !torch.bool, !torch.float, !torch.float, !torch.bool -> !torch.vtensor<[1,2,1,3],f32>
+    %0 = torch.operator "onnx.InstanceNormalization"(%arg0, %arg1, %arg2) : (!torch.vtensor<[1,2,1,3],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2,1,3],f32>
+    return %0 : !torch.vtensor<[1,2,1,3],f32>
+  }
+
+// -----
+
 // CHECK-LABEL: func.func @test_not_2d
 func.func @test_not_2d(%arg0: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 1 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: torch.aten.bitwise_not %arg0 : !torch.vtensor<[3,4],i1> -> !torch.vtensor<[3,4],i1>