From 987701182538e10b62a5ade8b8f1dab410194c36 Mon Sep 17 00:00:00 2001
From: Alexander de Silva
Date: Thu, 28 Dec 2023 19:56:35 +0000
Subject: [PATCH] OnnxToTorch support for onnx.InstanceNormalization op

---
 .../Dialect/Torch/IR/GeneratedTorchOps.td     |  59 ++++++
 .../TorchOnnxToTorch/DefaultDomainGtoP.cpp    |  33 +++
 .../TorchToLinalg/Uncategorized.cpp           | 191 ++++++++++++++++++
 .../Transforms/AbstractInterpLibrary.cpp      |   8 +
 .../build_tools/abstract_interp_lib_gen.py    |   8 +
 .../build_tools/torch_ods_gen.py              |   6 +
 .../TorchOnnxToTorch/simple_ops_g_to_p.mlir   |   9 +
 7 files changed, 314 insertions(+)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 12a2bf4a86e2a..99665ace647df 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -5860,6 +5860,65 @@ def Torch_AtenBatchNormOp : Torch_Op<"aten.batch_norm", [
   }];
 }
 
+def Torch_AtenInstanceNormOp : Torch_Op<"aten.instance_norm", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::instance_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$input,
+    AnyTorchOptionalTensorType:$weight,
+    AnyTorchOptionalTensorType:$bias,
+    AnyTorchOptionalTensorType:$running_mean,
+    AnyTorchOptionalTensorType:$running_var,
+    Torch_BoolType:$use_input_stats,
+    Torch_FloatType:$momentum,
+    Torch_FloatType:$eps,
+    Torch_BoolType:$cudnn_enabled
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenInstanceNormOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 9, 1);
+    }
+    void AtenInstanceNormOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 9, 1);
+    }
+  }];
+}
+
+def Torch_QuantizedInstanceNormOp : Torch_Op<"quantized.instance_norm", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `quantized::instance_norm : (Tensor, Tensor?, Tensor?, float, float, int) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$input,
+    AnyTorchOptionalTensorType:$weight,
+    AnyTorchOptionalTensorType:$bias,
+    Torch_FloatType:$eps,
+    Torch_FloatType:$output_scale,
+    Torch_IntType:$output_zero_point
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult QuantizedInstanceNormOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 6, 1);
+    }
+    void QuantizedInstanceNormOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 6, 1);
+    }
+  }];
+}
+
 def Torch_AtenNativeGroupNormOp : Torch_Op<"aten.native_group_norm", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
index c0a7473e4601b..3310a2ac02756 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -251,6 +251,39 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
                       binder.op, resultType, lhs, rhs);
                   return success();
                 });
+  patterns.onOp("InstanceNormalization", 6,
+                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+                  Torch::ValueTensorType resultType;
+                  llvm::SmallVector<Value> operands;
+                  if (binder.tensorOperands(operands, 3) ||
+                      binder.tensorResultType(resultType) ||
+                      operands.size() != 3) {
+                    return failure();
+                  }
+
+                  SmallString<64> name("torch.onnx.");
+                  name.append("epsilon");
+
+                  auto attr = binder.op->getAttr(name);
+                  float eps;
+                  if (attr) {
+                    auto epsAttr = dyn_cast<FloatAttr>(attr);
+                    eps = epsAttr.getValue().convertToFloat();
+                  } else {
+                    eps = 1e-05f;
+                  }
+                  auto epsValue = rewriter.create<Torch::ConstantFloatOp>(
+                      binder.getLoc(), rewriter.getF64FloatAttr(eps));
+
+                  auto outputScale = rewriter.create<Torch::ConstantFloatOp>(
+                      binder.getLoc(), rewriter.getF64FloatAttr(1.0f));
+                  auto outputZeroPoint = rewriter.create<Torch::ConstantIntOp>(
+                      binder.getLoc(), rewriter.getI64IntegerAttr(0));
+                  rewriter.replaceOpWithNewOp<Torch::QuantizedInstanceNormOp>(
+                      binder.op, resultType, operands[0], operands[1],
+                      operands[2], epsValue, outputScale, outputZeroPoint);
+                  return success();
+                });
   patterns.onOp("Max", 1,
                 [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                   Torch::ValueTensorType resultType;
diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
index 749945dee6e24..70c0e8216b18e 100644
--- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp
+++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -1703,6 +1703,195 @@ class ConvertAtenBatchNormOp : public OpConversionPattern<AtenBatchNormOp> {
 };
 } // namespace
 
+namespace {
+// Lowers quantized.instance_norm to a sequence of linalg.generic ops that
+// compute the per-(N, C) mean and variance over the spatial dimensions and
+// then normalize, scale, and shift the input.
+class ConvertQuantizedInstanceNormOp
+    : public OpConversionPattern<QuantizedInstanceNormOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(QuantizedInstanceNormOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    MLIRContext *context = op->getContext();
+    Location loc = op->getLoc();
+    Value input = adaptor.getInput();
+    Value scale = adaptor.getWeight();
+    Value bias = adaptor.getBias();
+    Value eps = adaptor.getEps();
+
+    auto inputType = input.getType().cast<RankedTensorType>();
+    auto inputRank = inputType.getRank();
+
+    SmallVector<AffineExpr> ncExpr;
+    ncExpr.push_back(mlir::getAffineDimExpr(0, context));
+    ncExpr.push_back(mlir::getAffineDimExpr(1, context));
+
+    auto ncIndexingMap = AffineMap::get(
+        /*dimCount=*/inputRank,
+        /*symbolCount=*/0, ncExpr, context);
+
+    SmallVector<AffineExpr> cExpr;
+    cExpr.push_back(mlir::getAffineDimExpr(1, context));
+
+    auto cIndexingMap = AffineMap::get(
+        /*dimCount=*/inputRank,
+        /*symbolCount=*/0, cExpr, context);
+
+    SmallVector<AffineMap> indexingMaps = {
+        rewriter.getMultiDimIdentityMap(inputRank), // input
+        ncIndexingMap,                              // output
+    };
+
+    Type resultElementType = inputType.getElementType();
+    auto inputSize = getTensorSizes(rewriter, loc, input);
+    SmallVector<Value> ncSize({inputSize[0], inputSize[1]});
+
+    Value meanTensor =
+        createZeroInitTensor(rewriter, loc, ncSize, resultElementType);
+    Value varTensor =
+        createZeroInitTensor(rewriter, loc, ncSize, resultElementType);
+
+    // The spatial dimensions are summed away into the (N, C) accumulator, so
+    // they must be reduction iterators.
+    SmallVector<utils::IteratorType> iteratorTypes(
+        2, utils::IteratorType::parallel);
+    iteratorTypes.append(inputRank - 2, utils::IteratorType::reduction);
+
+    Value sumPool2d =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, meanTensor.getType(), ValueRange{input}, meanTensor,
+                /*indexingMaps=*/indexingMaps,
+                /*iteratorTypes=*/iteratorTypes,
+                [&](OpBuilder &b, Location loc, ValueRange args) {
+                  Value input = args[0], sum = args[1];
+                  Value result = b.create<arith::AddFOp>(loc, input, sum);
+                  b.create<linalg::YieldOp>(loc, result);
+                })
+            .getResult(0);
+
+    indexingMaps = {
+        rewriter.getMultiDimIdentityMap(2), // sumPool2d
+        rewriter.getMultiDimIdentityMap(2), // output
+    };
+
+    iteratorTypes = {utils::IteratorType::parallel,
+                     utils::IteratorType::parallel};
+    Value mean =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, meanTensor.getType(), ValueRange{sumPool2d}, meanTensor,
+                indexingMaps, iteratorTypes,
+                [&](OpBuilder &b, Location loc, ValueRange args) {
+                  Value input = args[0];
+                  Value hw = b.create<arith::ConstantOp>(
+                      loc, FloatAttr::get(resultElementType,
+                                          inputType.getShape()[2] *
+                                              inputType.getShape()[3]));
+                  Value result = b.create<arith::DivFOp>(loc, input, hw);
+                  b.create<linalg::YieldOp>(loc, result);
+                })
+            .getResult(0);
+
+    indexingMaps = {
+        rewriter.getMultiDimIdentityMap(inputRank), // input
+        ncIndexingMap,                              // mean
+        ncIndexingMap,                              // output
+    };
+
+    // As above, the spatial dimensions are reduced while accumulating the
+    // squared differences.
+    iteratorTypes = {utils::IteratorType::parallel,
+                     utils::IteratorType::parallel,
+                     utils::IteratorType::reduction,
+                     utils::IteratorType::reduction};
+    // (input - mean) ^ 2
+    Value varianceNumerator =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, varTensor.getType(), ValueRange{input, mean}, varTensor,
+                indexingMaps, iteratorTypes,
+                [&](OpBuilder &b, Location loc, ValueRange args) {
+                  Value input = args[0], mean = args[1], output = args[2];
+                  Value two = b.create<arith::ConstantOp>(
+                      loc, FloatAttr::get(resultElementType, 2));
+                  Value inputSubMean =
+                      b.create<arith::SubFOp>(loc, input, mean);
+                  Value squared =
+                      b.create<math::PowFOp>(loc, inputSubMean, two);
+                  Value sum = b.create<arith::AddFOp>(loc, squared, output);
+                  b.create<linalg::YieldOp>(loc, sum);
+                })
+            .getResult(0);
+
+    indexingMaps = {
+        rewriter.getMultiDimIdentityMap(2), // varianceNumerator
+        rewriter.getMultiDimIdentityMap(2), // output
+    };
+
+    iteratorTypes = {utils::IteratorType::parallel,
+                     utils::IteratorType::parallel};
+
+    Value variance =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, varTensor.getType(), ValueRange{varianceNumerator},
+                varTensor, indexingMaps, iteratorTypes,
+                [&](OpBuilder &b, Location loc, ValueRange args) {
+                  Value numerator = args[0];
+                  Value hw = b.create<arith::ConstantOp>(
+                      loc, FloatAttr::get(resultElementType,
+                                          inputType.getShape()[2] *
+                                              inputType.getShape()[3]));
+                  Value result = b.create<arith::DivFOp>(loc, numerator, hw);
+                  b.create<linalg::YieldOp>(loc, result);
+                })
+            .getResult(0);
+
+    // Elementwise normalization: every loop dimension is parallel.
+    iteratorTypes = {utils::IteratorType::parallel,
+                     utils::IteratorType::parallel,
+                     utils::IteratorType::parallel,
+                     utils::IteratorType::parallel};
+    indexingMaps = {
+        rewriter.getMultiDimIdentityMap(inputRank), // input
+        ncIndexingMap,                              // mean
+        ncIndexingMap,                              // variance
+        cIndexingMap,                               // scale
+        cIndexingMap,                               // bias
+        rewriter.getMultiDimIdentityMap(inputRank), // output
+    };
+
+    Value outTensor =
+        createZeroInitTensor(rewriter, loc, inputSize, resultElementType);
+
+    Value instNorm =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, outTensor.getType(),
+                ValueRange{input, mean, variance, scale, bias}, outTensor,
+                indexingMaps, iteratorTypes,
+                [&](OpBuilder &b, Location loc, ValueRange args) {
+                  Value input = args[0], mean = args[1], var = args[2],
+                        scale = args[3], bias = args[4];
+                  Value inputSubMean =
+                      b.create<arith::SubFOp>(loc, input, mean);
+                  Value truncatedEps =
+                      b.create<arith::TruncFOp>(loc, var.getType(), eps);
+                  Value varPlusEps =
+                      b.create<arith::AddFOp>(loc, var, truncatedEps);
+                  Value rSTD = b.create<math::RsqrtOp>(loc, varPlusEps);
+                  Value temp =
+                      b.create<arith::MulFOp>(loc, inputSubMean, rSTD);
+                  Value timesScale = b.create<arith::MulFOp>(loc, temp, scale);
+                  Value plusBias =
+                      b.create<arith::AddFOp>(loc, timesScale, bias);
+                  b.create<linalg::YieldOp>(loc, plusBias);
+                })
+            .getResult(0);
+    Type newResultType = getTypeConverter()->convertType(op.getType());
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, instNorm);
+
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class ConvertAtenNllLossBackwardOp
     : public OpConversionPattern<AtenNllLossBackwardOp> {
@@ -2114,6 +2303,8 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
   patterns.add(typeConverter, context);
   target.addIllegalOp();
   patterns.add(typeConverter, context);
+  target.addIllegalOp<QuantizedInstanceNormOp>();
+  patterns.add<ConvertQuantizedInstanceNormOp>(typeConverter, context);
   target.addIllegalOp();
   patterns.add(typeConverter, context);
   target.addIllegalOp();
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 3901cd34a4aaa..61bdbadbcd4a8 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -8104,6 +8104,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %3 = torch.prim.TupleConstruct %0, %1, %2 : !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>, list<int>>\n"
 "    return %3 : !torch.tuple<list<int>, list<int>, list<int>>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.quantized.instance_norm\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<list<int>>, %arg3: !torch.float, %arg4: !torch.float, %arg5: !torch.int) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.slice.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.int) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.slice(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
@@ -8915,6 +8919,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %3 = torch.prim.TupleConstruct %0#1, %0#1, %0#1 : !torch.int, !torch.int, !torch.int -> !torch.tuple<int, int, int>\n"
 "    return %3 : !torch.tuple<int, int, int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.quantized.instance_norm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<tuple<int, int>>, %arg2: !torch.optional<tuple<int, int>>, %arg3: !torch.float, %arg4: !torch.float, %arg5: !torch.int) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.bernoulli_.float\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.float, %arg2: !torch.any) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
index 211023a9deec0..761336afd32be 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -1151,6 +1151,9 @@ def aten〇group_norm〡shape(input: List[int], num_groups: int, weight: Optiona
 def aten〇native_group_norm〡shape(input: List[int], weight: Optional[List[int]], bias: Optional[List[int]], N: int, C: int, HxW: int, group: int, eps: float) -> Tuple[List[int], List[int], List[int]]:
     return upstream_shape_functions.unary(input), [N, group], [N, group]
 
+def quantized〇instance_norm〡shape(input: List[int], weight: Optional[List[int]], bias: Optional[List[int]], eps: float, output_scale: float, output_zero_point: int) -> List[int]:
+    return upstream_shape_functions.unary(input)
+
 def aten〇slice〇Tensor〡shape(self: List[int], dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> List[int]:
     return upstream_shape_functions.slice(self, dim, start, end, step)
 
@@ -1755,6 +1758,11 @@ def aten〇native_group_norm〡dtype(input_rank_dtype: Tuple[int, int], weight_r
     assert not is_integer_dtype(input_dtype)
     return input_dtype, input_dtype, input_dtype
 
+# Device is not supported, hence we are unable to check the dtype function.
+def quantized〇instance_norm〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Optional[Tuple[int, int]], bias_rank_dtype: Optional[Tuple[int, int]], eps: float, output_scale: float, output_zero_point: int) -> int:
+    input_rank, input_dtype = input_rank_dtype
+    return input_dtype
+
 @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
 def aten〇bernoulli_〇float〡dtype(self_rank_dtype: Tuple[int, int], p: float = 0.5, generator: Any = None) -> int:
     self_rank, self_dtype = self_rank_dtype
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index a9f9ed96dce2d..8c9bb11467e3f 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -424,6 +424,12 @@ def emit_with_mutating_variants(key, **kwargs):
     emit(
         "aten::batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)"
     )
+    emit(
+        "aten::instance_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)"
+    )
+    emit(
+        "quantized::instance_norm : (Tensor, Tensor?, Tensor?, float, float, int) -> (Tensor)"
+    )
     emit(
         "aten::native_group_norm : (Tensor, Tensor?, Tensor?, int, int, int, int, float) -> (Tensor, Tensor, Tensor)"
     )
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
index c85659c25aa8f..750f750b1828d 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
@@ -441,6 +441,15 @@ func.func @test_globalaveragepool_precomputed(%arg0: !torch.vtensor<[1,1,3,3],f3
 
 // -----
 
+// CHECK-LABEL: func.func @test_instancenorm
+func.func @test_instancenorm(%arg0: !torch.vtensor<[1,2,1,3],f32>, %arg1: !torch.vtensor<[2],f32>, %arg2: !torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2,1,3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: torch.quantized.instance_norm %arg0, %arg1, %arg2, %float9.999990e-06, %float1.000000e00, %int0 : !torch.vtensor<[1,2,1,3],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.float, !torch.float, !torch.int -> !torch.vtensor<[1,2,1,3],f32>
+  %0 = torch.operator "onnx.InstanceNormalization"(%arg0, %arg1, %arg2) : (!torch.vtensor<[1,2,1,3],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2,1,3],f32>
+  return %0 : !torch.vtensor<[1,2,1,3],f32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @test_not_2d
 func.func @test_not_2d(%arg0: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 1 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: torch.aten.bitwise_not %arg0 : !torch.vtensor<[3,4],i1> -> !torch.vtensor<[3,4],i1>
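
Reviewer note (not part of the patch): the TorchToLinalg lowering above computes per-(N, C) statistics over the spatial dimensions and then normalizes, scales, and shifts the input. A minimal NumPy sketch of that computation, assuming a float NCHW input and the same population-variance convention as the lowering, is:

# Illustrative sketch only: mirrors the math of ConvertQuantizedInstanceNormOp,
# assuming a float NCHW input. Function and variable names are hypothetical.
import numpy as np

def instance_norm_reference(x, scale, bias, eps=1e-5):
    # x: (N, C, H, W); scale, bias: (C,)
    mean = x.mean(axis=(2, 3), keepdims=True)                 # sumPool2d / (H * W)
    var = ((x - mean) ** 2).mean(axis=(2, 3), keepdims=True)  # varianceNumerator / (H * W)
    rstd = 1.0 / np.sqrt(var + eps)                           # rSTD in the lowering
    return (x - mean) * rstd * scale[None, :, None, None] + bias[None, :, None, None]

x = np.random.rand(1, 2, 1, 3).astype(np.float32)
out = instance_norm_reference(x, np.ones(2, np.float32), np.zeros(2, np.float32))
print(out.shape)  # (1, 2, 1, 3), matching the shapes in the lit test above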