Skip to content

Commit

Permalink
OnnxToTorch support for onnx.InstanceNormalization op
Browse files Browse the repository at this point in the history
  • Loading branch information
aldesilv committed Feb 5, 2024
1 parent 4c55784 commit c509fe3
Show file tree
Hide file tree
Showing 9 changed files with 357 additions and 0 deletions.
31 changes: 31 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -5973,6 +5973,37 @@ def Torch_AtenBatchNormOp : Torch_Op<"aten.batch_norm", [
}];
}

def Torch_AtenInstanceNormOp : Torch_Op<"aten.instance_norm", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::instance_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$input,
AnyTorchOptionalTensorType:$weight,
AnyTorchOptionalTensorType:$bias,
AnyTorchOptionalTensorType:$running_mean,
AnyTorchOptionalTensorType:$running_var,
Torch_BoolType:$use_input_stats,
Torch_FloatType:$momentum,
Torch_FloatType:$eps,
Torch_BoolType:$cudnn_enabled
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenInstanceNormOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 9, 1);
}
void AtenInstanceNormOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 9, 1);
}
}];
}

def Torch_AtenNativeGroupNormOp : Torch_Op<"aten.native_group_norm", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
29 changes: 29 additions & 0 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,35 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
binder.op, resultType, lhs, rhs);
return success();
});
patterns.onOp(
"InstanceNormalization", 6,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
llvm::SmallVector<Value> operands;
float eps;

if (binder.tensorOperands(operands, 3) ||
binder.tensorResultType(resultType) || operands.size() != 3 ||
binder.f32FloatAttr(eps, "epsilon", 1e-05f)) {
return failure();
}
Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
Value boolFalse =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
auto epsValue = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getF64FloatAttr(eps));

auto momentum = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getF64FloatAttr(0.0f));
rewriter.replaceOpWithNewOp<Torch::AtenInstanceNormOp>(
binder.op, resultType, /* input */ operands[0],
/* weight */ operands[1],
/* bias */ operands[2], /* running mean */ none,
/* running var */ none,
/* use input stats */ boolFalse, momentum, epsValue,
/* cudnn enabled */ boolFalse);
return success();
});
patterns.onOp(
"Max", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Expand Down
195 changes: 195 additions & 0 deletions lib/Conversion/TorchToLinalg/Uncategorized.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1826,6 +1826,199 @@ class ConvertAtenBatchNormOp : public OpConversionPattern<AtenBatchNormOp> {
};
} // namespace

namespace {
class ConvertAtenInstanceNormOp
: public OpConversionPattern<AtenInstanceNormOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenInstanceNormOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
MLIRContext *context = op->getContext();
Location loc = op->getLoc();
Value input = adaptor.getInput();
Value scale = adaptor.getWeight();
Value bias = adaptor.getBias();
Value eps = adaptor.getEps();

auto inputType = input.getType().cast<RankedTensorType>();
auto inputRank = inputType.getRank();

SmallVector<AffineExpr, 2> ncExpr;
ncExpr.push_back(mlir::getAffineDimExpr(0, context));
ncExpr.push_back(mlir::getAffineDimExpr(1, context));

auto ncIndexingMap = AffineMap::get(
/*dimCount=*/inputRank,
/*symbolCount=*/0, ncExpr, context);

SmallVector<AffineExpr, 1> cExpr;
cExpr.push_back(mlir::getAffineDimExpr(1, context));

auto cIndexingMap = AffineMap::get(
/*dimCount=*/inputRank,
/*symbolCount=*/0, cExpr, context);

SmallVector<AffineMap, 2> indexingMaps = {
rewriter.getMultiDimIdentityMap(inputRank), // input
ncIndexingMap, // output
};

Type resultElementType = inputType.getElementType();
auto inputSize = getTensorSizes(rewriter, loc, input);
SmallVector<Value> ncSize({inputSize[0], inputSize[1]});

Value meanTensor =
createZeroInitTensor(rewriter, loc, ncSize, resultElementType);
Value varTensor =
createZeroInitTensor(rewriter, loc, ncSize, resultElementType);

SmallVector<utils::IteratorType> iteratorTypes = {
utils::IteratorType::parallel, utils::IteratorType::parallel,
utils::IteratorType::reduction, utils::IteratorType::reduction};

Value sumPool2d =
rewriter
.create<linalg::GenericOp>(
loc, meanTensor.getType(), ValueRange{input}, meanTensor,
/*indexingMaps=*/indexingMaps,
/*iteratorTypes=*/iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value input = args[0], sum = args[1];
Value result = b.create<arith::AddFOp>(loc, input, sum);
b.create<linalg::YieldOp>(loc, result);
})
.getResult(0);

indexingMaps = {
rewriter.getMultiDimIdentityMap(2), // sumPool2d
rewriter.getMultiDimIdentityMap(2), // output
};

iteratorTypes = {utils::IteratorType::parallel,
utils::IteratorType::parallel};
Value mean =
rewriter
.create<linalg::GenericOp>(
loc, meanTensor.getType(), ValueRange{sumPool2d}, meanTensor,
indexingMaps, iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value input = args[0];
Value hw = b.create<arith::ConstantOp>(
loc, FloatAttr::get(resultElementType,
inputType.getShape()[2] *
inputType.getShape()[3]));
Value result = b.create<arith::DivFOp>(loc, input, hw);
b.create<linalg::YieldOp>(loc, result);
})
.getResult(0);

indexingMaps = {
rewriter.getMultiDimIdentityMap(inputRank), // input
ncIndexingMap, // mean
ncIndexingMap, // output
};

iteratorTypes = {
utils::IteratorType::parallel,
utils::IteratorType::parallel,
utils::IteratorType::reduction,
utils::IteratorType::reduction,
};
// (input - mean) ^ 2
Value varianceNumerator =
rewriter
.create<linalg::GenericOp>(
loc, varTensor.getType(), ValueRange{input, mean}, varTensor,
indexingMaps, iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value input = args[0], mean = args[1], output = args[2];
Value two = b.create<arith::ConstantOp>(
loc, FloatAttr::get(resultElementType, 2));
Value inputSubMean =
b.create<arith::SubFOp>(loc, input, mean);
Value squared =
b.create<math::PowFOp>(loc, inputSubMean, two);
Value sum = b.create<arith::AddFOp>(loc, squared, output);
b.create<linalg::YieldOp>(loc, sum);
})
.getResult(0);

indexingMaps = {
rewriter.getMultiDimIdentityMap(2), // sumPool2d
rewriter.getMultiDimIdentityMap(2), // output
};

iteratorTypes = {
utils::IteratorType::parallel,
utils::IteratorType::parallel,
};

Value variance =
rewriter
.create<linalg::GenericOp>(
loc, varTensor.getType(), ValueRange{varianceNumerator},
varTensor, indexingMaps, iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value numerator = args[0];
Value hw = b.create<arith::ConstantOp>(
loc, FloatAttr::get(resultElementType,
inputType.getShape()[2] *
inputType.getShape()[3]));
Value sum = b.create<arith::DivFOp>(loc, numerator, hw);
b.create<linalg::YieldOp>(loc, sum);
})
.getResult(0);

iteratorTypes = {
utils::IteratorType::parallel,
utils::IteratorType::parallel,
utils::IteratorType::parallel,
utils::IteratorType::parallel,
};
indexingMaps = {
rewriter.getMultiDimIdentityMap(inputRank), // input
ncIndexingMap, // mean
ncIndexingMap, // variance
cIndexingMap, // scale
cIndexingMap, // bias
rewriter.getMultiDimIdentityMap(inputRank), // output
};

Value outTensor =
createZeroInitTensor(rewriter, loc, inputSize, resultElementType);

Value instNorm =
rewriter
.create<linalg::GenericOp>(
loc, outTensor.getType(),
ValueRange{input, mean, variance, scale, bias}, outTensor,
indexingMaps, iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value input = args[0], mean = args[1], var = args[2],
scale = args[3], bias = args[4];
Value inputSubMean =
b.create<arith::SubFOp>(loc, input, mean);
Value truncatedEps =
b.create<arith::TruncFOp>(loc, var.getType(), eps);
Value varPlusEps =
b.create<arith::AddFOp>(loc, var, truncatedEps);
Value rSTD = b.create<math::RsqrtOp>(loc, varPlusEps);
Value temp = b.create<arith::MulFOp>(loc, inputSubMean, rSTD);
Value timesScale = b.create<arith::MulFOp>(loc, temp, scale);
Value plusBias =
b.create<arith::AddFOp>(loc, timesScale, bias);
b.create<linalg::YieldOp>(loc, plusBias);
})
.getResult(0);
Type newResultType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, instNorm);

return success();
}
};
} // namespace

namespace {
class ConvertAtenNllLossBackwardOp
: public OpConversionPattern<AtenNllLossBackwardOp> {
Expand Down Expand Up @@ -2367,6 +2560,8 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
patterns.add<ConvertAtenBatchNormOp>(typeConverter, context);
target.addIllegalOp<AtenLogitOp>();
patterns.add<ConvertLogitOp>(typeConverter, context);
target.addIllegalOp<AtenInstanceNormOp>();
patterns.add<ConvertAtenInstanceNormOp>(typeConverter, context);
target.addIllegalOp<PrimsCollapseOp>();
patterns.add<ConvertPrimsCollapseOp>(typeConverter, context);
target.addIllegalOp<PrimsSplitDimOp>();
Expand Down
8 changes: 8 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8744,6 +8744,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %3 = torch.prim.TupleConstruct %0, %1, %2 : !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>, list<int>>\n"
" return %3 : !torch.tuple<list<int>, list<int>, list<int>>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.instance_norm\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.optional<list<int>>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float, %arg8: !torch.bool) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.slice.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.int) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.slice(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
Expand Down Expand Up @@ -9588,6 +9592,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %3 = torch.prim.TupleConstruct %0#1, %0#1, %0#1 : !torch.int, !torch.int, !torch.int -> !torch.tuple<int, int, int>\n"
" return %3 : !torch.tuple<int, int, int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.instance_norm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<tuple<int, int>>, %arg2: !torch.optional<tuple<int, int>>, %arg3: !torch.optional<tuple<int, int>>, %arg4: !torch.optional<tuple<int, int>>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float, %arg8: !torch.bool) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.bernoulli_.float\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.float, %arg2: !torch.any) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1388,6 +1388,9 @@ def aten〇group_norm〡shape(input: List[int], num_groups: int, weight: Optiona
def aten〇native_group_norm〡shape(input: List[int], weight: Optional[List[int]], bias: Optional[List[int]], N: int, C: int, HxW: int, group: int, eps: float) -> Tuple[List[int], List[int], List[int]]:
return upstream_shape_functions.unary(input), [N, group], [N, group]

def aten〇instance_norm〡shape(input: List[int], weight: Optional[List[int]], bias: Optional[List[int]], running_mean: Optional[List[int]], running_var: Optional[List[int]], use_input_stats: bool, momentum: float, eps: float, cudnn_enabled: bool) -> List[int]:
return upstream_shape_functions.unary(input)

def aten〇slice〇Tensor〡shape(self: List[int], dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> List[int]:
return upstream_shape_functions.slice(self, dim, start, end, step)

Expand Down Expand Up @@ -2006,6 +2009,11 @@ def aten〇native_group_norm〡dtype(input_rank_dtype: Tuple[int, int], weight_r
assert not is_integer_dtype(input_dtype)
return input_dtype, input_dtype, input_dtype

# device is not supported hence unable to check the dtype function
def aten〇instance_norm〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Optional[Tuple[int, int]], bias_rank_dtype: Optional[Tuple[int, int]], running_mean_rank_dtype: Optional[Tuple[int, int]], running_var_rank_dtype: Optional[Tuple[int, int]], use_input_stats: bool, momentum: float, eps: float, cudnn_enabled: bool) -> int:
input_rank, input_dtype = input_rank_dtype
return input_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def aten〇bernoulli_〇float〡dtype(self_rank_dtype: Tuple[int, int], p: float = 0.5, generator: Any = None) -> int:
self_rank, self_dtype = self_rank_dtype
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,9 @@ def emit_with_mutating_variants(key, **kwargs):
emit(
"aten::batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)"
)
emit(
"aten::instance_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)"
)
emit(
"aten::native_group_norm : (Tensor, Tensor?, Tensor?, int, int, int, int, float) -> (Tensor, Tensor, Tensor)"
)
Expand Down
16 changes: 16 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/norm_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,3 +489,19 @@ def forward(self, x):
def LayerNormNormalizeOverAllDimsModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 2, 3))

class InstanceNormModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.inorm = torch.nn.InstanceNorm2d(100)

@export
@annotate_args([
None,
([20, 100, 35, 45], torch.float32, True),
])
def forward(self, x):
return self.inorm(x)

@register_test_case(module_factory=lambda: InstanceNormModule())
def InstanceNormModule_basic(module, tu: TestUtils):
module.forward(tu.rand(20, 100, 35, 45))
9 changes: 9 additions & 0 deletions test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,15 @@ func.func @test_globalaveragepool_precomputed(%arg0: !torch.vtensor<[1,1,3,3],f3

// -----

// CHECK-LABEL: func.func @test_instancenorm
func.func @test_instancenorm(%arg0: !torch.vtensor<[1,2,1,3],f32>, %arg1: !torch.vtensor<[2],f32>, %arg2: !torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2,1,3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: torch.aten.instance_norm %arg0, %arg1, %arg2, %none, %none, %false, %float0.000000e00, %float9.999990e-06, %false : !torch.vtensor<[1,2,1,3],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.none, !torch.none, !torch.bool, !torch.float, !torch.float, !torch.bool -> !torch.vtensor<[1,2,1,3],f32>
%0 = torch.operator "onnx.InstanceNormalization"(%arg0, %arg1, %arg2) : (!torch.vtensor<[1,2,1,3],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2,1,3],f32>
return %0 : !torch.vtensor<[1,2,1,3],f32>
}

// -----

// CHECK-LABEL: func.func @test_not_2d
func.func @test_not_2d(%arg0: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 1 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: torch.aten.bitwise_not %arg0 : !torch.vtensor<[3,4],i1> -> !torch.vtensor<[3,4],i1>
Expand Down
Loading

0 comments on commit c509fe3

Please sign in to comment.