Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OnnxToTorch support for onnx.InstanceNormalization op #2710

Merged
merged 1 commit into from
Feb 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -6203,6 +6203,37 @@ def Torch_AtenBatchNormOp : Torch_Op<"aten.batch_norm", [
}];
}

def Torch_AtenInstanceNormOp : Torch_Op<"aten.instance_norm", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::instance_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$input,
AnyTorchOptionalTensorType:$weight,
AnyTorchOptionalTensorType:$bias,
AnyTorchOptionalTensorType:$running_mean,
AnyTorchOptionalTensorType:$running_var,
Torch_BoolType:$use_input_stats,
Torch_FloatType:$momentum,
Torch_FloatType:$eps,
Torch_BoolType:$cudnn_enabled
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenInstanceNormOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 9, 1);
}
void AtenInstanceNormOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 9, 1);
}
}];
}

def Torch_AtenNativeGroupNormOp : Torch_Op<"aten.native_group_norm", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
31 changes: 31 additions & 0 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,37 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
binder.op, resultType, lhs, rhs);
return success();
});
patterns.onOp(
"InstanceNormalization", 6,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
llvm::SmallVector<Value> operands;
float eps;

if (binder.tensorOperands(operands, 3) ||
binder.tensorResultType(resultType) || operands.size() != 3 ||
binder.f32FloatAttr(eps, "epsilon", 1e-05f)) {
return failure();
}
Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
Value boolTrue =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
Value boolFalse =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
auto epsValue = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getF64FloatAttr(eps));

auto momentum = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getF64FloatAttr(0.0f));
rewriter.replaceOpWithNewOp<Torch::AtenInstanceNormOp>(
binder.op, resultType, /* input */ operands[0],
/* weight */ operands[1],
/* bias */ operands[2], /* running mean */ none,
/* running var */ none,
/* use input stats */ boolTrue, momentum, epsValue,
/* cudnn enabled */ boolFalse);
return success();
});
patterns.onOp(
"Max", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Expand Down
8 changes: 8 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8784,6 +8784,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %3 = torch.prim.TupleConstruct %0, %1, %2 : !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>, list<int>>\n"
" return %3 : !torch.tuple<list<int>, list<int>, list<int>>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.instance_norm\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.optional<list<int>>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float, %arg8: !torch.bool) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.slice.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.int) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.slice(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
Expand Down Expand Up @@ -9643,6 +9647,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %3 = torch.prim.TupleConstruct %0#1, %0#1, %0#1 : !torch.int, !torch.int, !torch.int -> !torch.tuple<int, int, int>\n"
" return %3 : !torch.tuple<int, int, int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.instance_norm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<tuple<int, int>>, %arg2: !torch.optional<tuple<int, int>>, %arg3: !torch.optional<tuple<int, int>>, %arg4: !torch.optional<tuple<int, int>>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float, %arg8: !torch.bool) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.bernoulli_.float\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.float, %arg2: !torch.any) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
Expand Down
146 changes: 146 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3962,6 +3962,151 @@ class DecomposeAtenLayerNormOp : public OpRewritePattern<AtenLayerNormOp> {
};
} // namespace

namespace {
class DecomposeAtenInstanceNormOp
: public OpRewritePattern<AtenInstanceNormOp> {
using OpRewritePattern<AtenInstanceNormOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenInstanceNormOp op,
PatternRewriter &rewriter) const override {

Location loc = op.getLoc();
auto context = op.getContext();

auto inputTy = op.getInput().getType().cast<BaseTensorType>();
int64_t inputRank = inputTy.getSizes().size();
auto reduceDimInts =
llvm::SmallVector<int64_t>({inputRank - 2, inputRank - 1});

SmallVector<int64_t> reducedShape(inputTy.getSizes());
reducedShape[inputRank - 1] = 1;
reducedShape[inputRank - 2] = 1;

Type dtype = inputTy.getOptionalDtype();
Type reducedTy = ValueTensorType::get(op.getContext(),
llvm::ArrayRef(reducedShape), dtype);

auto sizeListType = ListType::get(IntType::get(context));
SmallVector<Value> reduceDimVals;
reduceDimVals.reserve(reduceDimInts.size());
std::transform(reduceDimInts.begin(), reduceDimInts.end(),
std::back_inserter(reduceDimVals), [&](int64_t d) {
return rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(d));
});
Value reduceDimList =
rewriter.create<PrimListConstructOp>(loc, sizeListType, reduceDimVals);
Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(loc, true);
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);

Value one = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));

// mean(x)
Value inputMean = rewriter.create<AtenMeanDimOp>(
loc, reducedTy, op.getInput(), reduceDimList, cstTrue, none);

// x - mean(x)
Value inputMeanExpanded =
rewriter.create<AtenExpandAsOp>(loc, inputTy, inputMean, op.getInput());
Value inputSubMean = rewriter.create<AtenSubTensorOp>(
loc, inputTy, op.getInput(), inputMeanExpanded, one);
// (x - mean(x))^2
Value inputSubMeanSquare = rewriter.create<AtenMulTensorOp>(
loc, inputTy, inputSubMean, inputSubMean);

Value variancesum = rewriter.create<AtenSumDimIntListOp>(
loc, reducedTy, inputSubMeanSquare, reduceDimList, cstTrue,
/*dtype=*/none);

Value hw = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(inputTy.getSizes()[inputRank - 1] *
inputTy.getSizes()[inputRank - 2]));
Value inputVar =
rewriter.create<AtenDivScalarOp>(loc, reducedTy, variancesum, hw);

// rsqrt(var(x) + eps)
Value inputVarPlusEps = rewriter.create<AtenAddScalarOp>(
loc, reducedTy, inputVar, op.getEps(), one);
Value inputRsqrtVar =
rewriter.create<AtenRsqrtOp>(loc, reducedTy, inputVarPlusEps);

// (x - mean(x)) * rsqrt(var(x) + eps)
Value inputRsqrtVarExpanded = rewriter.create<AtenExpandAsOp>(
loc, inputTy, inputRsqrtVar, op.getInput());
Value inputNormalized = rewriter.create<AtenMulTensorOp>(
loc, inputTy, inputSubMean, inputRsqrtVarExpanded);
Value out = rewriter.create<TensorStaticInfoCastOp>(
loc, op.getResult().getType(), inputNormalized);

Value weight = op.getWeight();
auto weightTy = weight.getType().cast<BaseTensorType>();
dtype = weightTy.getOptionalDtype();

SmallVector<int64_t> weightShape(weightTy.getSizes());
SmallVector<int64_t> newWeightShape;
newWeightShape.push_back(1);
newWeightShape.append(weightShape);

Value zero = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
Type newWeightTy = ValueTensorType::get(
op.getContext(), llvm::ArrayRef(newWeightShape), dtype);
weight = rewriter.create<AtenUnsqueezeOp>(loc, newWeightTy, weight, zero);

Value two = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(2));
newWeightShape.push_back(1);
newWeightTy = ValueTensorType::get(op.getContext(),
llvm::ArrayRef(newWeightShape), dtype);
weight = rewriter.create<AtenUnsqueezeOp>(loc, newWeightTy, weight, two);

Value three = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(3));
newWeightShape.push_back(1);
newWeightTy = ValueTensorType::get(op.getContext(),
llvm::ArrayRef(newWeightShape), dtype);
weight = rewriter.create<AtenUnsqueezeOp>(loc, newWeightTy, weight, three);

Value weightExpanded =
rewriter.create<AtenExpandAsOp>(loc, inputTy, weight, op.getInput());

Value bias = op.getBias();
auto biasTy = bias.getType().cast<BaseTensorType>();
dtype = biasTy.getOptionalDtype();

SmallVector<int64_t> biasShape(biasTy.getSizes());
SmallVector<int64_t> newBiasShape;
newBiasShape.push_back(1);
newBiasShape.append(biasShape);

Type newBiasTy = ValueTensorType::get(op.getContext(),
llvm::ArrayRef(newBiasShape), dtype);
bias = rewriter.create<AtenUnsqueezeOp>(loc, newBiasTy, bias, zero);

newBiasShape.push_back(1);
newBiasTy = ValueTensorType::get(op.getContext(),
llvm::ArrayRef(newBiasShape), dtype);
bias = rewriter.create<AtenUnsqueezeOp>(loc, newBiasTy, bias, two);

newBiasShape.push_back(1);
newBiasTy = ValueTensorType::get(op.getContext(),
llvm::ArrayRef(newBiasShape), dtype);
bias = rewriter.create<AtenUnsqueezeOp>(loc, newBiasTy, bias, three);

Value biasExpanded =
rewriter.create<AtenExpandAsOp>(loc, inputTy, bias, op.getInput());

out = rewriter.create<AtenMulTensorOp>(loc, out.getType(), out,
weightExpanded);
out = rewriter.create<AtenAddTensorOp>(loc, out.getType(), out,
biasExpanded, one);

rewriter.replaceOp(op, out);
return success();
}
};
} // namespace

namespace {
class DecomposeAtenNativeLayerNormOp
: public OpRewritePattern<AtenNativeLayerNormOp> {
Expand Down Expand Up @@ -6733,6 +6878,7 @@ class DecomposeComplexOpsPass
DecomposeAtenAddCLikeOp<AtenAddcmulOp, AtenMulTensorOp>>(patterns);
addPatternIfTargetOpIsIllegal<
DecomposeAtenAddCLikeOp<AtenAddcdivOp, AtenDivTensorOp>>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenInstanceNormOp>(patterns);
aldesilv marked this conversation as resolved.
Show resolved Hide resolved
addPatternIfTargetOpIsIllegal<DecomposeAtenLayerNormOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNativeLayerNormOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenGroupNormOp>(patterns);
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
});
target.addIllegalOp<AtenAddcmulOp>();
target.addIllegalOp<AtenAddcdivOp>();
target.addIllegalOp<AtenInstanceNormOp>();
target.addIllegalOp<AtenLayerNormOp>();
target.addIllegalOp<AtenNativeLayerNormOp>();
target.addIllegalOp<AtenGroupNormOp>();
Expand Down
4 changes: 4 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@
# END tests failing due to: 'torch.aten.mul.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'

# START tests failing due to: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
"AtenInstanceNormModule_basic",
"BatchNorm1DModule_basic",
"BatchNorm1DWith2DInputModule_basic",
"BatchNorm2DModule_basic",
Expand Down Expand Up @@ -1012,6 +1013,7 @@
"AtenEyeModuleFalsePinMemory_basic",
"AtenEyeModuleFloat2D_basic",
"AtenRoundIntModule_basic",
"AtenInstanceNormModule_basic",
"AtenToDeviceModule_basic",
"BaddbmmBroadcast1DInputModule_basic",
"BaddbmmBroadcast2DInputModule_basic",
Expand Down Expand Up @@ -1420,6 +1422,8 @@
"Conv2dNoPaddingModule_basic",
"Conv2dWithPaddingDilationStrideModule_basic",
"Conv2dWithPaddingModule_basic",

"AtenInstanceNormModule_basic",
}

LTC_CRASHING_SET = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1415,6 +1415,9 @@ def aten〇group_norm〡shape(input: List[int], num_groups: int, weight: Optiona
def aten〇native_group_norm〡shape(input: List[int], weight: Optional[List[int]], bias: Optional[List[int]], N: int, C: int, HxW: int, group: int, eps: float) -> Tuple[List[int], List[int], List[int]]:
return upstream_shape_functions.unary(input), [N, group], [N, group]

def aten〇instance_norm〡shape(input: List[int], weight: Optional[List[int]], bias: Optional[List[int]], running_mean: Optional[List[int]], running_var: Optional[List[int]], use_input_stats: bool, momentum: float, eps: float, cudnn_enabled: bool) -> List[int]:
return upstream_shape_functions.unary(input)

def aten〇slice〇Tensor〡shape(self: List[int], dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> List[int]:
return upstream_shape_functions.slice(self, dim, start, end, step)

Expand Down Expand Up @@ -2048,6 +2051,11 @@ def aten〇native_group_norm〡dtype(input_rank_dtype: Tuple[int, int], weight_r
assert not is_integer_dtype(input_dtype)
return input_dtype, input_dtype, input_dtype

# device is not supported hence unable to check the dtype function
def aten〇instance_norm〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Optional[Tuple[int, int]], bias_rank_dtype: Optional[Tuple[int, int]], running_mean_rank_dtype: Optional[Tuple[int, int]], running_var_rank_dtype: Optional[Tuple[int, int]], use_input_stats: bool, momentum: float, eps: float, cudnn_enabled: bool) -> int:
input_rank, input_dtype = input_rank_dtype
return input_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def aten〇bernoulli_〇float〡dtype(self_rank_dtype: Tuple[int, int], p: float = 0.5, generator: Any = None) -> int:
self_rank, self_dtype = self_rank_dtype
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,9 @@ def emit_with_mutating_variants(key, **kwargs):
emit(
"aten::batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)"
)
emit(
"aten::instance_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)"
aldesilv marked this conversation as resolved.
Show resolved Hide resolved
)
emit(
"aten::native_group_norm : (Tensor, Tensor?, Tensor?, int, int, int, int, float) -> (Tensor, Tensor, Tensor)"
)
Expand Down
18 changes: 18 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/norm_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,3 +489,21 @@ def forward(self, x):
def LayerNormNormalizeOverAllDimsModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 2, 3))

class AtenInstanceNormModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([1, 2, 1, 3], torch.float32, True),
([2], torch.float32, True),
([2], torch.float32, True)
])
def forward(self, x, w, b):
return torch.ops.aten.instance_norm(x, w, b, None,
None, True, 0.0, 1e-05, False)

@register_test_case(module_factory=lambda: AtenInstanceNormModule())
def AtenInstanceNormModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 2, 1, 3), tu.rand(2), tu.rand(2))
9 changes: 9 additions & 0 deletions test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,15 @@ func.func @test_globalaveragepool_precomputed(%arg0: !torch.vtensor<[1,1,3,3],f3

// -----

// CHECK-LABEL: func.func @test_instancenorm
func.func @test_instancenorm(%arg0: !torch.vtensor<[1,2,1,3],f32>, %arg1: !torch.vtensor<[2],f32>, %arg2: !torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2,1,3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: torch.aten.instance_norm %arg0, %arg1, %arg2, %none, %none, %true, %float0.000000e00, %float9.999990e-06, %false : !torch.vtensor<[1,2,1,3],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.none, !torch.none, !torch.bool, !torch.float, !torch.float, !torch.bool -> !torch.vtensor<[1,2,1,3],f32>
%0 = torch.operator "onnx.InstanceNormalization"(%arg0, %arg1, %arg2) : (!torch.vtensor<[1,2,1,3],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2,1,3],f32>
return %0 : !torch.vtensor<[1,2,1,3],f32>
}

// -----

// CHECK-LABEL: func.func @test_not_2d
func.func @test_not_2d(%arg0: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 1 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: torch.aten.bitwise_not %arg0 : !torch.vtensor<[3,4],i1> -> !torch.vtensor<[3,4],i1>
Expand Down
Loading