diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 16eb5565bedde..30fbec3107b6e 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -5813,6 +5813,65 @@ def Torch_AtenBatchNormOp : Torch_Op<"aten.batch_norm", [ }]; } +def Torch_AtenInstanceNormOp : Torch_Op<"aten.instance_norm", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::instance_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$input, + AnyTorchOptionalTensorType:$weight, + AnyTorchOptionalTensorType:$bias, + AnyTorchOptionalTensorType:$running_mean, + AnyTorchOptionalTensorType:$running_var, + Torch_BoolType:$use_input_stats, + Torch_FloatType:$momentum, + Torch_FloatType:$eps, + Torch_BoolType:$cudnn_enabled + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenInstanceNormOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 9, 1); + } + void AtenInstanceNormOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 9, 1); + } + }]; +} + +def Torch_QuantizedInstanceNormOp : Torch_Op<"quantized.instance_norm", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `quantized::instance_norm : (Tensor, Tensor?, Tensor?, float, float, int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$input, + AnyTorchOptionalTensorType:$weight, + AnyTorchOptionalTensorType:$bias, + Torch_FloatType:$eps, + Torch_FloatType:$output_scale, + Torch_IntType:$output_zero_point + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult QuantizedInstanceNormOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 6, 1); + } + void QuantizedInstanceNormOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 6, 1); + } + }]; +} + def Torch_AtenNativeGroupNormOp : Torch_Op<"aten.native_group_norm", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index d154edb1ab750..32cd5d64f2dc8 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -172,6 +172,39 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.op, resultType, lhs, rhs); return success(); }); + patterns.onOp("InstanceNormalization", 6, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + llvm::SmallVector operands; + if (binder.tensorOperands(operands, 3) || + binder.tensorResultType(resultType) || + operands.size() != 3) { + return failure(); + } + + SmallString<64> name("torch.onnx."); + name.append("epsilon"); + + auto attr = binder.op->getAttr(name); + float eps; + if (attr) { + auto epsAttr = dyn_cast(attr); + eps = epsAttr.getValue().convertToFloat(); + } else { + eps = 1e-05f; + } + auto epsValue = rewriter.create(binder.getLoc(), + rewriter.getF64FloatAttr(eps)); + + auto outputScale = rewriter.create(binder.getLoc(), + rewriter.getF64FloatAttr(1.0f)); + auto outputZeroPoint = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(0)); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operands[0], operands[1], operands[2], + epsValue, outputScale, outputZeroPoint); + return success(); + }); patterns.onOp("Max", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index e947ae73ace0c..16e7a7c7b2f68 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -1678,6 +1678,192 @@ class ConvertAtenBatchNormOp : public OpConversionPattern { }; } // namespace +namespace { +class ConvertQuantizedInstanceNormOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(QuantizedInstanceNormOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + MLIRContext *context = op->getContext(); + Location loc = op->getLoc(); + Value input = adaptor.getInput(); + Value scale = adaptor.getWeight(); + Value bias = adaptor.getBias(); + Value eps = adaptor.getEps(); + + auto inputType = input.getType().cast(); + auto inputRank = inputType.getRank(); + + SmallVector ncExpr; + ncExpr.push_back(mlir::getAffineDimExpr(0, context)); + ncExpr.push_back(mlir::getAffineDimExpr(1, context)); + + auto ncIndexingMap = AffineMap::get( + /*dimCount=*/inputRank, + /*symbolCount=*/0, ncExpr, context); + + SmallVector cExpr; + cExpr.push_back(mlir::getAffineDimExpr(1, context)); + + auto cIndexingMap = AffineMap::get( + /*dimCount=*/inputRank, + /*symbolCount=*/0, cExpr, context); + + SmallVector indexingMaps = { + rewriter.getMultiDimIdentityMap(inputRank), // input + ncIndexingMap, // output + }; + + Type resultElementType = inputType.getElementType(); + auto inputSize = getTensorSizes(rewriter, loc, input); + SmallVector ncSize({inputSize[0], inputSize[1]}); + + Value meanTensor = + createZeroInitTensor(rewriter, loc, ncSize, resultElementType); + Value varTensor = + createZeroInitTensor(rewriter, loc, ncSize, resultElementType); + + SmallVector iteratorTypes(inputRank, utils::IteratorType::parallel); + + Value sumPool2d = + rewriter + .create( + loc, meanTensor.getType(), + ValueRange{input}, meanTensor, + /*indexingMaps=*/indexingMaps, + /*iteratorTypes=*/iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value input = args[0], sum = args[1]; + Value result = b.create(loc, input, sum); + b.create(loc, result); + }) + .getResult(0); + + indexingMaps = { + rewriter.getMultiDimIdentityMap(2), // sumPool2d + rewriter.getMultiDimIdentityMap(2), // output + }; + + iteratorTypes = {utils::IteratorType::parallel, utils::IteratorType::parallel}; + Value mean = + rewriter + .create( + loc, meanTensor.getType(), + ValueRange{sumPool2d}, meanTensor, + indexingMaps, + iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value input = args[0]; + Value hw = + b.create(loc, + FloatAttr::get(resultElementType, inputType.getShape()[2] * + inputType.getShape()[3])); + Value result = b.create(loc, input, hw); + b.create(loc, result); + }) + .getResult(0); + + indexingMaps = { + rewriter.getMultiDimIdentityMap(inputRank), // input + ncIndexingMap, // mean + ncIndexingMap, // output + }; + + iteratorTypes = {utils::IteratorType::parallel, + utils::IteratorType::parallel, + utils::IteratorType::parallel, + utils::IteratorType::parallel,}; + // (input - mean) ^ 2 + Value varianceNumerator = + rewriter + .create( + loc, varTensor.getType(), + ValueRange{input, mean}, varTensor, + indexingMaps, + iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value input = args[0], mean = args[1], output = args[2]; + Value two = + b.create(loc, + FloatAttr::get(resultElementType, 2)); + Value inputSubMean = b.create(loc, input, mean); + Value squared = b.create(loc, inputSubMean, two); + Value sum = b.create(loc, squared, output); + b.create(loc, sum); + }) + .getResult(0); + + indexingMaps = { + ncIndexingMap, // numerator + ncIndexingMap, // output + }; + + iteratorTypes = {utils::IteratorType::parallel, + utils::IteratorType::parallel,}; + + Value variance = + rewriter + .create( + loc, varTensor.getType(), + ValueRange{varianceNumerator}, varTensor, + indexingMaps, + iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value numerator = args[0]; + Value hw = + b.create(loc, + FloatAttr::get(resultElementType, inputType.getShape()[2] * + inputType.getShape()[3])); + Value sum = b.create(loc, numerator, hw); + b.create(loc, sum); + }) + .getResult(0); + + iteratorTypes = {utils::IteratorType::parallel, + utils::IteratorType::parallel, + utils::IteratorType::parallel, + utils::IteratorType::parallel,}; + indexingMaps = { + rewriter.getMultiDimIdentityMap(inputRank), // input + ncIndexingMap, // mean + ncIndexingMap, // variance + cIndexingMap, // scale + cIndexingMap, // bias + rewriter.getMultiDimIdentityMap(inputRank), // output + }; + + Value outTensor = + createZeroInitTensor(rewriter, loc, inputSize, resultElementType); + + Value instNorm = + rewriter + .create( + loc, outTensor.getType(), + ValueRange{input, mean, variance, scale, bias}, outTensor, + indexingMaps, + iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value input = args[0], mean = args[1], var = args[2], + scale = args[3], bias = args[4]; + Value inputSubMean = b.create(loc, input, mean); + Value varPlusEps = b.create(loc, var, eps); + Value rSTD = b.create(loc, varPlusEps); + Value temp = b.create(loc, inputSubMean, rSTD); + Value timesScale = b.create(loc, temp, scale); + Value plusBias = b.create(loc, timesScale, bias); + b.create(loc, plusBias); + }) + .getResult(0); + Type newResultType = getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, newResultType, instNorm); + + return success(); + + } +}; +} // namespace + namespace { class ConvertAtenNllLossBackwardOp : public OpConversionPattern { @@ -2002,6 +2188,8 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality( patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index fb458f6a5d912..9eb7242fa2ea2 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -423,6 +423,12 @@ def emit_with_mutating_variants(key, **kwargs): emit( "aten::batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)" ) + emit( + "aten::instance_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)" + ) + emit( + "quantized::instance_norm : (Tensor, Tensor?, Tensor?, float, float, int) -> (Tensor)" + ) emit( "aten::native_group_norm : (Tensor, Tensor?, Tensor?, int, int, int, int, float) -> (Tensor, Tensor, Tensor)" ) diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index e224ddfa2944c..6a25c9e7cf500 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -347,6 +347,13 @@ func.func @test_globalaveragepool_precomputed(%arg0: !torch.vtensor<[1,1,3,3],f3 return %0 : !torch.vtensor<[3,4,5],f32> } +// CHECK-LABEL: func.func @test_instancenorm + func.func @test_instancenorm(%arg0: !torch.vtensor<[1,2,1,3],f32>, %arg1: !torch.vtensor<[2],f32>, %arg2: !torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2,1,3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.quantized.instance_norm %arg0, %arg1, %arg2, %float9.999990e-06, %float1.000000e00, %int0 : !torch.vtensor<[1,2,1,3],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.float, !torch.float, !torch.int -> !torch.vtensor<[1,2,1,3],f32> + %0 = torch.operator "onnx.InstanceNormalization"(%arg0, %arg1, %arg2) : (!torch.vtensor<[1,2,1,3],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2,1,3],f32> + return %0 : !torch.vtensor<[1,2,1,3],f32> + } + // CHECK-LABEL: func.func @test_not_2d func.func @test_not_2d(%arg0: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 1 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.bitwise_not %arg0 : !torch.vtensor<[3,4],i1> -> !torch.vtensor<[3,4],i1>