diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index c09900ce8eccf..8a1aa4724d59e 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -5973,6 +5973,37 @@ def Torch_AtenBatchNormOp : Torch_Op<"aten.batch_norm", [ }]; } +def Torch_AtenInstanceNormOp : Torch_Op<"aten.instance_norm", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::instance_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$input, + AnyTorchOptionalTensorType:$weight, + AnyTorchOptionalTensorType:$bias, + AnyTorchOptionalTensorType:$running_mean, + AnyTorchOptionalTensorType:$running_var, + Torch_BoolType:$use_input_stats, + Torch_FloatType:$momentum, + Torch_FloatType:$eps, + Torch_BoolType:$cudnn_enabled + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenInstanceNormOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 9, 1); + } + void AtenInstanceNormOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 9, 1); + } + }]; +} + def Torch_AtenNativeGroupNormOp : Torch_Op<"aten.native_group_norm", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index df20a83515bf3..c1fb4c224d50b 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -335,38 +335,66 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.op, resultType, lhs, rhs); return success(); }); - patterns.onOp( - "Max", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - llvm::SmallVector operands; - if (binder.tensorOperandsList(operands) || - binder.tensorResultType(resultType) || operands.size() == 0) { - return failure(); - } - Value result = operands[0]; - for (uint64_t i = 1; i < operands.size(); i++) { - result = rewriter.create( - binder.getLoc(), resultType, result, operands[i]); - } - rewriter.replaceOp(binder.op, result.getDefiningOp()); - return success(); - }); - patterns.onOp( - "Min", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - llvm::SmallVector operands; - if (binder.tensorOperandsList(operands) || - binder.tensorResultType(resultType) || operands.size() == 0) { - return failure(); - } - Value result = operands[0]; - for (uint64_t i = 1; i < operands.size(); i++) { - result = rewriter.create( - binder.getLoc(), resultType, result, operands[i]); - } - rewriter.replaceOp(binder.op, result.getDefiningOp()); - return success(); - }); + patterns.onOp("InstanceNormalization", 6, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + llvm::SmallVector operands; + float eps; + + if (binder.tensorOperands(operands, 3) || + binder.tensorResultType(resultType) || + operands.size() != 3 || + binder.f32FloatAttr(eps, "epsilon", 1e-05f)) { + return failure(); + } + Value none = rewriter.create(binder.getLoc()); + Value boolFalse = rewriter.create(binder.getLoc(), false); + auto epsValue = rewriter.create(binder.getLoc(), + rewriter.getF64FloatAttr(eps)); + + auto momentum = rewriter.create(binder.getLoc(), + rewriter.getF64FloatAttr(0.0f)); + rewriter.replaceOpWithNewOp( + binder.op, resultType, /* input */ operands[0], /* weight */ operands[1], + /* bias */ operands[2], /* running mean */ none, /* running var */ none, + /* use input stats */ boolFalse, momentum, epsValue, /* cudnn enabled */ boolFalse); + return success(); + }); + patterns.onOp("Max", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + llvm::SmallVector operands; + if (binder.tensorOperandsList(operands) || + binder.tensorResultType(resultType) || + operands.size() == 0) { + return failure(); + } + Value result = operands[0]; + for (uint64_t i = 1; i < operands.size(); i++) { + result = rewriter.create( + binder.getLoc(), resultType, result, operands[i]); + } + rewriter.replaceOp(binder.op, result.getDefiningOp()); + return success(); + }); + patterns.onOp("Min", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + llvm::SmallVector operands; + if (binder.tensorOperandsList(operands) || + binder.tensorResultType(resultType) || + operands.size() == 0) { + return failure(); + } + Value result = operands[0]; + for (uint64_t i = 1; i < operands.size(); i++) { + result = rewriter.create( + binder.getLoc(), resultType, result, operands[i]); + } + rewriter.replaceOp( + binder.op, result.getDefiningOp()); + return success(); + }); patterns.onOp("Neg", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 54317979353d9..f0dfd3b939f7f 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -1826,6 +1826,196 @@ class ConvertAtenBatchNormOp : public OpConversionPattern { }; } // namespace +namespace { +class ConvertAtenInstanceNormOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenInstanceNormOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + MLIRContext *context = op->getContext(); + Location loc = op->getLoc(); + Value input = adaptor.getInput(); + Value scale = adaptor.getWeight(); + Value bias = adaptor.getBias(); + Value eps = adaptor.getEps(); + + auto inputType = input.getType().cast(); + auto inputRank = inputType.getRank(); + + SmallVector ncExpr; + ncExpr.push_back(mlir::getAffineDimExpr(0, context)); + ncExpr.push_back(mlir::getAffineDimExpr(1, context)); + + auto ncIndexingMap = AffineMap::get( + /*dimCount=*/inputRank, + /*symbolCount=*/0, ncExpr, context); + + SmallVector cExpr; + cExpr.push_back(mlir::getAffineDimExpr(1, context)); + + auto cIndexingMap = AffineMap::get( + /*dimCount=*/inputRank, + /*symbolCount=*/0, cExpr, context); + + SmallVector indexingMaps = { + rewriter.getMultiDimIdentityMap(inputRank), // input + ncIndexingMap, // output + }; + + Type resultElementType = inputType.getElementType(); + auto inputSize = getTensorSizes(rewriter, loc, input); + SmallVector ncSize({inputSize[0], inputSize[1]}); + + Value meanTensor = + createZeroInitTensor(rewriter, loc, ncSize, resultElementType); + Value varTensor = + createZeroInitTensor(rewriter, loc, ncSize, resultElementType); + + SmallVector iteratorTypes = {utils::IteratorType::parallel, + utils::IteratorType::parallel, + utils::IteratorType::reduction, + utils::IteratorType::reduction}; + + Value sumPool2d = + rewriter + .create( + loc, meanTensor.getType(), + ValueRange{input}, meanTensor, + /*indexingMaps=*/indexingMaps, + /*iteratorTypes=*/iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value input = args[0], sum = args[1]; + Value result = b.create(loc, input, sum); + b.create(loc, result); + }) + .getResult(0); + + indexingMaps = { + rewriter.getMultiDimIdentityMap(2), // sumPool2d + rewriter.getMultiDimIdentityMap(2), // output + }; + + iteratorTypes = {utils::IteratorType::parallel, utils::IteratorType::parallel}; + Value mean = + rewriter + .create( + loc, meanTensor.getType(), + ValueRange{sumPool2d}, meanTensor, + indexingMaps, + iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value input = args[0]; + Value hw = + b.create(loc, + FloatAttr::get(resultElementType, inputType.getShape()[2] * + inputType.getShape()[3])); + Value result = b.create(loc, input, hw); + b.create(loc, result); + }) + .getResult(0); + + indexingMaps = { + rewriter.getMultiDimIdentityMap(inputRank), // input + ncIndexingMap, // mean + ncIndexingMap, // output + }; + + iteratorTypes = {utils::IteratorType::parallel, + utils::IteratorType::parallel, + utils::IteratorType::reduction, + utils::IteratorType::reduction,}; + // (input - mean) ^ 2 + Value varianceNumerator = + rewriter + .create( + loc, varTensor.getType(), + ValueRange{input, mean}, varTensor, + indexingMaps, + iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value input = args[0], mean = args[1], output = args[2]; + Value two = + b.create(loc, + FloatAttr::get(resultElementType, 2)); + Value inputSubMean = b.create(loc, input, mean); + Value squared = b.create(loc, inputSubMean, two); + Value sum = b.create(loc, squared, output); + b.create(loc, sum); + }) + .getResult(0); + + indexingMaps = { + rewriter.getMultiDimIdentityMap(2), // sumPool2d + rewriter.getMultiDimIdentityMap(2), // output + }; + + iteratorTypes = {utils::IteratorType::parallel, + utils::IteratorType::parallel,}; + + Value variance = + rewriter + .create( + loc, varTensor.getType(), + ValueRange{varianceNumerator}, varTensor, + indexingMaps, + iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value numerator = args[0]; + Value hw = + b.create(loc, + FloatAttr::get(resultElementType, inputType.getShape()[2] * + inputType.getShape()[3])); + Value sum = b.create(loc, numerator, hw); + b.create(loc, sum); + }) + .getResult(0); + + iteratorTypes = {utils::IteratorType::parallel, + utils::IteratorType::parallel, + utils::IteratorType::parallel, + utils::IteratorType::parallel,}; + indexingMaps = { + rewriter.getMultiDimIdentityMap(inputRank), // input + ncIndexingMap, // mean + ncIndexingMap, // variance + cIndexingMap, // scale + cIndexingMap, // bias + rewriter.getMultiDimIdentityMap(inputRank), // output + }; + + Value outTensor = + createZeroInitTensor(rewriter, loc, inputSize, resultElementType); + + Value instNorm = + rewriter + .create( + loc, outTensor.getType(), + ValueRange{input, mean, variance, scale, bias}, outTensor, + indexingMaps, + iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value input = args[0], mean = args[1], var = args[2], + scale = args[3], bias = args[4]; + Value inputSubMean = b.create(loc, input, mean); + Value truncatedEps = b.create(loc, var.getType(), eps); + Value varPlusEps = b.create(loc, var, truncatedEps); + Value rSTD = b.create(loc, varPlusEps); + Value temp = b.create(loc, inputSubMean, rSTD); + Value timesScale = b.create(loc, temp, scale); + Value plusBias = b.create(loc, timesScale, bias); + b.create(loc, plusBias); + }) + .getResult(0); + Type newResultType = getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, newResultType, instNorm); + + return success(); + + } +}; +} // namespace + namespace { class ConvertAtenNllLossBackwardOp : public OpConversionPattern { @@ -2367,6 +2557,8 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality( patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index bb9717303e6ba..d338d672de9d9 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -8744,6 +8744,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %3 = torch.prim.TupleConstruct %0, %1, %2 : !torch.list, !torch.list, !torch.list -> !torch.tuple, list, list>\n" " return %3 : !torch.tuple, list, list>\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.instance_norm\"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.optional>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float, %arg8: !torch.bool) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.slice.Tensor\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.int) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.slice(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.list, !torch.int, !torch.optional, !torch.optional, !torch.int) -> !torch.list\n" " return %0 : !torch.list\n" @@ -9588,6 +9592,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %3 = torch.prim.TupleConstruct %0#1, %0#1, %0#1 : !torch.int, !torch.int, !torch.int -> !torch.tuple\n" " return %3 : !torch.tuple\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.instance_norm\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.optional>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float, %arg8: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.bernoulli_.float\"(%arg0: !torch.tuple, %arg1: !torch.float, %arg2: !torch.any) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 91e98d99c9ffe..8ca8d8071ab3f 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -1388,6 +1388,9 @@ def aten〇group_norm〡shape(input: List[int], num_groups: int, weight: Optiona def aten〇native_group_norm〡shape(input: List[int], weight: Optional[List[int]], bias: Optional[List[int]], N: int, C: int, HxW: int, group: int, eps: float) -> Tuple[List[int], List[int], List[int]]: return upstream_shape_functions.unary(input), [N, group], [N, group] +def aten〇instance_norm〡shape(input: List[int], weight: Optional[List[int]], bias: Optional[List[int]], running_mean: Optional[List[int]], running_var: Optional[List[int]], use_input_stats: bool, momentum: float, eps: float, cudnn_enabled: bool) -> List[int]: + return upstream_shape_functions.unary(input) + def aten〇slice〇Tensor〡shape(self: List[int], dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> List[int]: return upstream_shape_functions.slice(self, dim, start, end, step) @@ -2006,6 +2009,11 @@ def aten〇native_group_norm〡dtype(input_rank_dtype: Tuple[int, int], weight_r assert not is_integer_dtype(input_dtype) return input_dtype, input_dtype, input_dtype +# device is not supported hence unable to check the dtype function +def aten〇instance_norm〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Optional[Tuple[int, int]], bias_rank_dtype: Optional[Tuple[int, int]], running_mean_rank_dtype: Optional[Tuple[int, int]], running_var_rank_dtype: Optional[Tuple[int, int]], use_input_stats: bool, momentum: float, eps: float, cudnn_enabled: bool) -> int: + input_rank, input_dtype = input_rank_dtype + return input_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇bernoulli_〇float〡dtype(self_rank_dtype: Tuple[int, int], p: float = 0.5, generator: Any = None) -> int: self_rank, self_dtype = self_rank_dtype diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 3b930c20e79d5..a9506d5376572 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -432,6 +432,9 @@ def emit_with_mutating_variants(key, **kwargs): emit( "aten::batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)" ) + emit( + "aten::instance_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)" + ) emit( "aten::native_group_norm : (Tensor, Tensor?, Tensor?, int, int, int, int, float) -> (Tensor, Tensor, Tensor)" ) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/norm_like.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/norm_like.py index 3b17f516f9e57..2b2dfefb3990a 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/norm_like.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/norm_like.py @@ -489,3 +489,19 @@ def forward(self, x): def LayerNormNormalizeOverAllDimsModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 2, 3)) +class InstanceNormModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.inorm = torch.nn.InstanceNorm2d(100) + + @export + @annotate_args([ + None, + ([20, 100, 35, 45], torch.float32, True), + ]) + def forward(self, x): + return self.inorm(x) + +@register_test_case(module_factory=lambda: InstanceNormModule()) +def InstanceNormModule_basic(module, tu: TestUtils): + module.forward(tu.rand(20, 100, 35, 45)) diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index 449b7e4feb32d..42858fc16b1b1 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -555,6 +555,15 @@ func.func @test_globalaveragepool_precomputed(%arg0: !torch.vtensor<[1,1,3,3],f3 // ----- +// CHECK-LABEL: func.func @test_instancenorm + func.func @test_instancenorm(%arg0: !torch.vtensor<[1,2,1,3],f32>, %arg1: !torch.vtensor<[2],f32>, %arg2: !torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2,1,3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.instance_norm %arg0, %arg1, %arg2, %none, %none, %false, %float0.000000e00, %float9.999990e-06, %false : !torch.vtensor<[1,2,1,3],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.none, !torch.none, !torch.bool, !torch.float, !torch.float, !torch.bool -> !torch.vtensor<[1,2,1,3],f32> + %0 = torch.operator "onnx.InstanceNormalization"(%arg0, %arg1, %arg2) : (!torch.vtensor<[1,2,1,3],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2,1,3],f32> + return %0 : !torch.vtensor<[1,2,1,3],f32> + } + +// ----- + // CHECK-LABEL: func.func @test_not_2d func.func @test_not_2d(%arg0: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 1 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.bitwise_not %arg0 : !torch.vtensor<[3,4],i1> -> !torch.vtensor<[3,4],i1> diff --git a/test/Conversion/TorchToLinalg/instancenorm.mlir b/test/Conversion/TorchToLinalg/instancenorm.mlir new file mode 100644 index 0000000000000..e5dfbfee022a1 --- /dev/null +++ b/test/Conversion/TorchToLinalg/instancenorm.mlir @@ -0,0 +1,58 @@ +// RUN: torch-mlir-opt <%s -convert-torch-to-linalg --split-input-file | FileCheck %s + +// CHECK-LABEL: func.func @test_instancenorm + func.func @test_instancenorm(%arg0: !torch.vtensor<[1,2,1,3],f32>, %arg1: !torch.vtensor<[2],f32>, %arg2: !torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2,1,3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[GENERIC:.*]] = linalg.generic + // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "reduction"]} + // CHECK-SAME: ins({{.*}} : tensor<1x2x1x3xf32>) outs({{.*}} : tensor<1x2xf32>) { + // CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32): + // CHECK: %[[ADD:.*]] = arith.addf %[[IN]], %[[OUT]] : f32 + // CHECK: linalg.yield %[[ADD]] : f32 + // CHECK: } -> tensor<1x2xf32> + // CHECK: %[[GENERIC:.*]] = linalg.generic + // CHECK-SAME: iterator_types = ["parallel", "parallel"]} + // CHECK-SAME: ins({{.*}} : tensor<1x2xf32>) outs({{.*}} : tensor<1x2xf32>) { + // CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32): + // CHECK: %[[CST_6:.*]] = arith.constant 3.000000e+00 : f32 + // CHECK: %[[DIV:.*]] = arith.divf %[[IN]], %[[CST_6]] : f32 + // CHECK: linalg.yield %[[DIV]] : f32 + // CHECK: } -> tensor<1x2xf32> + // CHECK: %[[GENERIC:.*]] = linalg.generic + // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "reduction"]} + // CHECK-SAME: ins({{.*}}, {{.*}} : tensor<1x2x1x3xf32>, tensor<1x2xf32>) outs({{.*}} : tensor<1x2xf32>) { + // CHECK: ^bb0(%[[IN:.*]]: f32, %[[IN_6:.*]]: f32, %[[OUT:.*]]: f32): + // CHECK: %[[CST_7:.*]] = arith.constant 2.000000e+00 : f32 + // CHECK: %[[SUB:.*]] = arith.subf %[[IN]], %[[IN_6]] : f32 + // CHECK: %[[POW:.*]] = math.powf %[[SUB]], %[[CST_7]] : f32 + // CHECK: %[[ADD:.*]] = arith.addf %[[POW]], %[[OUT]] : f32 + // CHECK: linalg.yield %[[ADD]] : f32 + // CHECK: } -> tensor<1x2xf32> + // CHECK: %[[GENERIC:.*]] = linalg.generic + // CHECK-SAME: iterator_types = ["parallel", "parallel"]} + // CHECK-SAME: ins({{.*}} : tensor<1x2xf32>) outs({{.*}} : tensor<1x2xf32>) { + // CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32): + // CHECK: %[[CST_6:.*]] = arith.constant 3.000000e+00 : f32 + // CHECK: %[[DIV:.*]] = arith.divf %[[IN]], %[[CST_6]] : f32 + // CHECK: linalg.yield %[[DIV]] : f32 + // CHECK: } -> tensor<1x2xf32> + // CHECK: %[[GENERIC:.*]] = linalg.generic + // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]} + // CHECK-SAME: ins({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} : tensor<1x2x1x3xf32>, tensor<1x2xf32>, tensor<1x2xf32>, tensor<2xf32>, tensor<2xf32>) outs(%13 : tensor<1x2x1x3xf32>) { + // CHECK: ^bb0(%[[IN:.*]]: f32, %[[IN_6:.*]]: f32, %[[IN_7:.*]]: f32, %[[IN_8:.*]]: f32, %[[IN_9:.*]]: f32, %[[OUT:.*]]: f32): + // CHECK: %[[SUB:.*]] = arith.subf %[[IN]], %[[IN_6]] : f32 + // CHECK: %[[TRUNC:.*]] = arith.truncf %3 : f64 to f32 + // CHECK: %[[ADD_0:.*]] = arith.addf %[[IN_7]], %[[TRUNC]] : f32 + // CHECK: %[[RSQRT:.*]] = math.rsqrt %[[ADD_0]] : f32 + // CHECK: %[[MUL_0:.*]] = arith.mulf %[[SUB]], %[[RSQRT]] : f32 + // CHECK: %[[MUL_1:.*]] = arith.mulf %[[MUL_0]], %[[IN_8]] : f32 + // CHECK: %[[ADD_1:.*]] = arith.addf %[[MUL_1]], %[[IN_9]] : f32 + // CHECK: linalg.yield %[[ADD_1]] : f32 + // CHECK: } -> tensor<1x2x1x3xf32> + %none = torch.constant.none + %false = torch.constant.bool false + %float9.999990e-06 = torch.constant.float 9.9999997473787516E-6 + %float0.000000e00 = torch.constant.float 0.000000e+00 + %int0 = torch.constant.int 0 + %0 = torch.aten.instance_norm %arg0, %arg1, %arg2, %none, %none, %false, %float0.000000e00, %float9.999990e-06, %false : !torch.vtensor<[1,2,1,3],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.none, !torch.none, !torch.bool, !torch.float, !torch.float, !torch.bool -> !torch.vtensor<[1,2,1,3],f32> + return %0 : !torch.vtensor<[1,2,1,3],f32> + }