Skip to content

Commit

Permalink
OnnxToTorch support for onnx.InstanceNormalization op
Browse files Browse the repository at this point in the history
  • Loading branch information
aldesilv committed Feb 14, 2024
1 parent 4c55784 commit f8d3396
Show file tree
Hide file tree
Showing 11 changed files with 275 additions and 12 deletions.
31 changes: 31 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -5973,6 +5973,37 @@ def Torch_AtenBatchNormOp : Torch_Op<"aten.batch_norm", [
}];
}

// Op definition for `aten::instance_norm`: nine operands (input, optional
// affine weight/bias, optional running stats, and the bool/float hyper-
// parameters) producing a single tensor result.
// NOTE(review): this appears to live in GeneratedTorchOps.td, which is
// emitted by the ODS generator (see the torch_ods_gen emit() call in this
// commit) — prefer regenerating over hand edits; confirm against the
// generator output.
def Torch_AtenInstanceNormOp : Torch_Op<"aten.instance_norm", [
  AllowsTypeRefinement,
  HasValueSemantics,
  ReadOnly
]> {
  let summary = "Generated op for `aten::instance_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$input,
    AnyTorchOptionalTensorType:$weight,
    AnyTorchOptionalTensorType:$bias,
    AnyTorchOptionalTensorType:$running_mean,
    AnyTorchOptionalTensorType:$running_var,
    Torch_BoolType:$use_input_stats,
    Torch_FloatType:$momentum,
    Torch_FloatType:$eps,
    Torch_BoolType:$cudnn_enabled
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    // 9 operands / 1 result, matching the JIT schema in the summary above.
    ParseResult AtenInstanceNormOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 9, 1);
    }
    void AtenInstanceNormOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 9, 1);
    }
  }];
}

def Torch_AtenNativeGroupNormOp : Torch_Op<"aten.native_group_norm", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
31 changes: 31 additions & 0 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,37 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
binder.op, resultType, lhs, rhs);
return success();
});
// onnx.InstanceNormalization (opset >= 6) --> aten.instance_norm.
// ONNX carries no running statistics, so running_mean/running_var are
// bound to `none` and use_input_stats is set to true; in that mode the
// momentum value is irrelevant, so a constant 0.0 is supplied.
patterns.onOp(
    "InstanceNormalization", 6,
    [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      Torch::ValueTensorType resultType;
      llvm::SmallVector<Value> operands;
      float eps;

      // Bind (input, scale, B) and the result type; `epsilon` defaults
      // to 1e-5 per the ONNX operator spec.
      // NOTE(review): `operands.size() != 3` looks redundant with
      // `tensorOperands(operands, 3)` — confirm the binder's semantics.
      if (binder.tensorOperands(operands, 3) ||
          binder.tensorResultType(resultType) || operands.size() != 3 ||
          binder.f32FloatAttr(eps, "epsilon", 1e-05f)) {
        return failure();
      }
      Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
      Value boolTrue =
          rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
      Value boolFalse =
          rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
      // Widen the f32 attribute to the f64 the torch op expects.
      auto epsValue = rewriter.create<Torch::ConstantFloatOp>(
          binder.getLoc(), rewriter.getF64FloatAttr(eps));

      // Unused (no running stats), but required by the op signature.
      auto momentum = rewriter.create<Torch::ConstantFloatOp>(
          binder.getLoc(), rewriter.getF64FloatAttr(0.0f));
      rewriter.replaceOpWithNewOp<Torch::AtenInstanceNormOp>(
          binder.op, resultType, /* input */ operands[0],
          /* weight */ operands[1],
          /* bias */ operands[2], /* running mean */ none,
          /* running var */ none,
          /* use input stats */ boolTrue, momentum, epsValue,
          /* cudnn enabled */ boolFalse);
      return success();
    });
patterns.onOp(
"Max", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Expand Down
28 changes: 16 additions & 12 deletions lib/Conversion/TorchToLinalg/Uncategorized.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1342,11 +1342,20 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
auto valueTy = value.getType();
auto qtensor = op->getOperand(0);
auto qtensorTy = qtensor.getType().cast<ValueTensorType>().getDtype();
auto makeQTensor =
qtensor.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>();
if (!makeQTensor) {
op->emitWarning(
"unimplemented: dequantizing tensor of unknown scale / zero-point");

Value zp, scale;
if (auto makeQTensor =
qtensor.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>()) {
zp = makeQTensor.getZeroPoint();
scale = makeQTensor.getScale();
}

if (auto quant = qtensor.getDefiningOp<AtenQuantizePerTensorOp>()) {
zp = quant.getZeroPoint();
scale = quant.getScale();
}

if (!zp || !scale) {
return nullptr;
}

Expand All @@ -1362,10 +1371,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
}
}

Value zp = makeQTensor.getZeroPoint();
zp = converter->materializeTargetConversion(
b, loc, converter->convertType(zp.getType()),
makeQTensor.getZeroPoint());
b, loc, converter->convertType(zp.getType()), zp);
auto zpTy = zp.getType();

if (zpTy != outIntTy) {
Expand All @@ -1380,10 +1387,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
value = b.create<arith::SIToFPOp>(loc, outFpTy, value);
}

Value scale = makeQTensor.getScale();
scale = converter->materializeTargetConversion(
b, loc, converter->convertType(scale.getType()),
makeQTensor.getScale());
b, loc, converter->convertType(scale.getType()), scale);
if (scale.getType() != value.getType()) {
scale = b.create<arith::TruncFOp>(loc, value.getType(), scale);
}
Expand Down Expand Up @@ -2233,7 +2238,6 @@ class ConvertDequantizePerChannel
auto qoperand = op.getOperand();
auto make = qoperand.getDefiningOp<Aten_MakePerChannelQuantizedTensorOp>();
if (!make) {
llvm::errs() << "Did not find make per channel\n";
return rewriter.notifyMatchFailure(op, "did not find per channel qint");
}

Expand Down
8 changes: 8 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8744,6 +8744,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %3 = torch.prim.TupleConstruct %0, %1, %2 : !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>, list<int>>\n"
" return %3 : !torch.tuple<list<int>, list<int>, list<int>>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.instance_norm\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.optional<list<int>>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float, %arg8: !torch.bool) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.slice.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.int) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.slice(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
Expand Down Expand Up @@ -9588,6 +9592,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %3 = torch.prim.TupleConstruct %0#1, %0#1, %0#1 : !torch.int, !torch.int, !torch.int -> !torch.tuple<int, int, int>\n"
" return %3 : !torch.tuple<int, int, int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.instance_norm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<tuple<int, int>>, %arg2: !torch.optional<tuple<int, int>>, %arg3: !torch.optional<tuple<int, int>>, %arg4: !torch.optional<tuple<int, int>>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float, %arg8: !torch.bool) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.bernoulli_.float\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.float, %arg2: !torch.any) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
Expand Down
146 changes: 146 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3885,6 +3885,151 @@ class DecomposeAtenLayerNormOp : public OpRewritePattern<AtenLayerNormOp> {
};
} // namespace

namespace {
// Decompose `aten.instance_norm` into primitive tensor ops:
//
//   out = (x - mean(x)) * rsqrt(var(x) + eps) * weight + bias
//
// Statistics are computed per (N, C) over the spatial dims using the biased
// variance estimator (divide by H*W). Running statistics are ignored — the
// decomposition always normalizes with the input's own statistics, which
// matches the use_input_stats=true path this pattern is fed from.
class DecomposeAtenInstanceNormOp
    : public OpRewritePattern<AtenInstanceNormOp> {
  using OpRewritePattern<AtenInstanceNormOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenInstanceNormOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto context = op.getContext();

    auto inputTy = op.getInput().getType().cast<BaseTensorType>();
    if (!inputTy.hasSizes())
      return rewriter.notifyMatchFailure(op, "expected input to have sizes");
    int64_t inputRank = inputTy.getSizes().size();
    // The reductions and weight/bias unsqueezes below are written for NCHW
    // input; fail the match on any other rank instead of emitting wrong IR.
    if (inputRank != 4)
      return rewriter.notifyMatchFailure(
          op, "unimplemented: only rank-4 (NCHW) input is supported");
    // H*W is folded into a constant below, so the spatial dims must be
    // statically known (kUnknownSize would silently poison the constant).
    if (inputTy.getSizes()[inputRank - 2] == kUnknownSize ||
        inputTy.getSizes()[inputRank - 1] == kUnknownSize)
      return rewriter.notifyMatchFailure(
          op, "unimplemented: expected static spatial dimensions");
    // weight/bias are optional in the op schema, but the affine step below
    // requires both; fail the match rather than crash on an invalid cast
    // when either is `none` or unshaped.
    auto weightTy = op.getWeight().getType().dyn_cast<BaseTensorType>();
    auto biasTy = op.getBias().getType().dyn_cast<BaseTensorType>();
    if (!weightTy || !biasTy || !weightTy.hasSizes() || !biasTy.hasSizes())
      return rewriter.notifyMatchFailure(
          op, "unimplemented: expected weight and bias tensors with sizes");

    // Reduce over the spatial dims, keeping them as size-1 so the per-(N,C)
    // statistics broadcast against the input.
    auto reduceDimInts =
        llvm::SmallVector<int64_t>({inputRank - 2, inputRank - 1});
    SmallVector<int64_t> reducedShape(inputTy.getSizes());
    reducedShape[inputRank - 1] = 1;
    reducedShape[inputRank - 2] = 1;

    Type dtype = inputTy.getOptionalDtype();
    Type reducedTy =
        ValueTensorType::get(context, llvm::ArrayRef(reducedShape), dtype);

    auto sizeListType = ListType::get(IntType::get(context));
    SmallVector<Value> reduceDimVals;
    reduceDimVals.reserve(reduceDimInts.size());
    std::transform(reduceDimInts.begin(), reduceDimInts.end(),
                   std::back_inserter(reduceDimVals), [&](int64_t d) {
                     return rewriter.create<Torch::ConstantIntOp>(
                         loc, rewriter.getI64IntegerAttr(d));
                   });
    Value reduceDimList =
        rewriter.create<PrimListConstructOp>(loc, sizeListType, reduceDimVals);
    Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(loc, true);
    Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
    Value one = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(1));

    // mean(x) over (H, W).
    Value inputMean = rewriter.create<AtenMeanDimOp>(
        loc, reducedTy, op.getInput(), reduceDimList, cstTrue, none);

    // x - mean(x)
    Value inputMeanExpanded =
        rewriter.create<AtenExpandAsOp>(loc, inputTy, inputMean, op.getInput());
    Value inputSubMean = rewriter.create<AtenSubTensorOp>(
        loc, inputTy, op.getInput(), inputMeanExpanded, one);
    // (x - mean(x))^2
    Value inputSubMeanSquare = rewriter.create<AtenMulTensorOp>(
        loc, inputTy, inputSubMean, inputSubMean);

    // Biased variance: sum((x - mean(x))^2) / (H * W).
    Value varianceSum = rewriter.create<AtenSumDimIntListOp>(
        loc, reducedTy, inputSubMeanSquare, reduceDimList, cstTrue,
        /*dtype=*/none);
    Value hw = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(inputTy.getSizes()[inputRank - 1] *
                                        inputTy.getSizes()[inputRank - 2]));
    Value inputVar =
        rewriter.create<AtenDivScalarOp>(loc, reducedTy, varianceSum, hw);

    // rsqrt(var(x) + eps)
    Value inputVarPlusEps = rewriter.create<AtenAddScalarOp>(
        loc, reducedTy, inputVar, op.getEps(), one);
    Value inputRsqrtVar =
        rewriter.create<AtenRsqrtOp>(loc, reducedTy, inputVarPlusEps);

    // (x - mean(x)) * rsqrt(var(x) + eps)
    Value inputRsqrtVarExpanded = rewriter.create<AtenExpandAsOp>(
        loc, inputTy, inputRsqrtVar, op.getInput());
    Value inputNormalized = rewriter.create<AtenMulTensorOp>(
        loc, inputTy, inputSubMean, inputRsqrtVarExpanded);
    Value out = rewriter.create<TensorStaticInfoCastOp>(
        loc, op.getResult().getType(), inputNormalized);

    // Unsqueeze an affine parameter (shape (C) for a 4-D input) with leading
    // dim 0 and trailing dims up to the input rank, so that it can be
    // expanded against the NCHW input. Shared by weight and bias.
    auto unsqueezeToInputRank = [&](Value v, BaseTensorType vTy) -> Value {
      Type paramDtype = vTy.getOptionalDtype();
      SmallVector<int64_t> shape;
      shape.push_back(1);
      shape.append(vTy.getSizes().begin(), vTy.getSizes().end());
      Value zero = rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr(0));
      Type newTy =
          ValueTensorType::get(context, llvm::ArrayRef(shape), paramDtype);
      v = rewriter.create<AtenUnsqueezeOp>(loc, newTy, v, zero);
      for (int64_t d = 2; d < inputRank; ++d) {
        shape.push_back(1);
        newTy =
            ValueTensorType::get(context, llvm::ArrayRef(shape), paramDtype);
        Value dim = rewriter.create<Torch::ConstantIntOp>(
            loc, rewriter.getI64IntegerAttr(d));
        v = rewriter.create<AtenUnsqueezeOp>(loc, newTy, v, dim);
      }
      return v;
    };

    Value weight = unsqueezeToInputRank(op.getWeight(), weightTy);
    Value weightExpanded =
        rewriter.create<AtenExpandAsOp>(loc, inputTy, weight, op.getInput());
    Value bias = unsqueezeToInputRank(op.getBias(), biasTy);
    Value biasExpanded =
        rewriter.create<AtenExpandAsOp>(loc, inputTy, bias, op.getInput());

    // out * weight + bias
    out = rewriter.create<AtenMulTensorOp>(loc, out.getType(), out,
                                           weightExpanded);
    out = rewriter.create<AtenAddTensorOp>(loc, out.getType(), out,
                                           biasExpanded, one);

    rewriter.replaceOp(op, out);
    return success();
  }
};
} // namespace

namespace {
class DecomposeAtenNativeLayerNormOp
: public OpRewritePattern<AtenNativeLayerNormOp> {
Expand Down Expand Up @@ -6656,6 +6801,7 @@ class DecomposeComplexOpsPass
DecomposeAtenAddCLikeOp<AtenAddcmulOp, AtenMulTensorOp>>(patterns);
addPatternIfTargetOpIsIllegal<
DecomposeAtenAddCLikeOp<AtenAddcdivOp, AtenDivTensorOp>>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenInstanceNormOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenLayerNormOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNativeLayerNormOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenGroupNormOp>(patterns);
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
});
target.addIllegalOp<AtenAddcmulOp>();
target.addIllegalOp<AtenAddcdivOp>();
target.addIllegalOp<AtenInstanceNormOp>();
target.addIllegalOp<AtenLayerNormOp>();
target.addIllegalOp<AtenNativeLayerNormOp>();
target.addIllegalOp<AtenGroupNormOp>();
Expand Down
4 changes: 4 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@
# END tests failing due to: 'torch.aten.mul.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'

# START tests failing due to: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
"AtenInstanceNormModule_basic",
"BatchNorm1DModule_basic",
"BatchNorm1DWith2DInputModule_basic",
"BatchNorm2DModule_basic",
Expand Down Expand Up @@ -997,6 +998,7 @@
"AtenEyeModuleFalsePinMemory_basic",
"AtenEyeModuleFloat2D_basic",
"AtenRoundIntModule_basic",
"AtenInstanceNormModule_basic",
"AtenToDeviceModule_basic",
"BaddbmmBroadcast1DInputModule_basic",
"BaddbmmBroadcast2DInputModule_basic",
Expand Down Expand Up @@ -1401,6 +1403,8 @@
"Conv2dNoPaddingModule_basic",
"Conv2dWithPaddingDilationStrideModule_basic",
"Conv2dWithPaddingModule_basic",

"AtenInstanceNormModule_basic",
}

LTC_CRASHING_SET = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1388,6 +1388,9 @@ def aten〇group_norm〡shape(input: List[int], num_groups: int, weight: Optiona
def aten〇native_group_norm〡shape(input: List[int], weight: Optional[List[int]], bias: Optional[List[int]], N: int, C: int, HxW: int, group: int, eps: float) -> Tuple[List[int], List[int], List[int]]:
return upstream_shape_functions.unary(input), [N, group], [N, group]

def aten〇instance_norm〡shape(input: List[int], weight: Optional[List[int]], bias: Optional[List[int]], running_mean: Optional[List[int]], running_var: Optional[List[int]], use_input_stats: bool, momentum: float, eps: float, cudnn_enabled: bool) -> List[int]:
    # Instance norm is shape-preserving: the result has the input's shape.
    result_shape = upstream_shape_functions.unary(input)
    return result_shape

def aten〇slice〇Tensor〡shape(self: List[int], dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> List[int]:
return upstream_shape_functions.slice(self, dim, start, end, step)

Expand Down Expand Up @@ -2006,6 +2009,11 @@ def aten〇native_group_norm〡dtype(input_rank_dtype: Tuple[int, int], weight_r
assert not is_integer_dtype(input_dtype)
return input_dtype, input_dtype, input_dtype

# `device` is not supported by the dtype-check harness, hence this function
# carries no @check_dtype_function decorator.
def aten〇instance_norm〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Optional[Tuple[int, int]], bias_rank_dtype: Optional[Tuple[int, int]], running_mean_rank_dtype: Optional[Tuple[int, int]], running_var_rank_dtype: Optional[Tuple[int, int]], use_input_stats: bool, momentum: float, eps: float, cudnn_enabled: bool) -> int:
    # The result dtype is simply the input tensor's dtype.
    return input_rank_dtype[1]

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def aten〇bernoulli_〇float〡dtype(self_rank_dtype: Tuple[int, int], p: float = 0.5, generator: Any = None) -> int:
self_rank, self_dtype = self_rank_dtype
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,9 @@ def emit_with_mutating_variants(key, **kwargs):
emit(
"aten::batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)"
)
emit(
"aten::instance_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)"
)
emit(
"aten::native_group_norm : (Tensor, Tensor?, Tensor?, int, int, int, int, float) -> (Tensor, Tensor, Tensor)"
)
Expand Down
Loading

0 comments on commit f8d3396

Please sign in to comment.