Skip to content

Commit

Permalink
[torch] Fix DecomposeAtenInstanceNorm decomposition (llvm#2960)
Browse files Browse the repository at this point in the history
The decomposition only supports an NCHW lowering; however, the operation can
support arbitrary spatial dimensions. Updated the lowering to better
support spatial dimensions.
  • Loading branch information
rsuderman authored Feb 28, 2024
1 parent dd673cf commit 73b6df9
Showing 1 changed file with 28 additions and 36 deletions.
64 changes: 28 additions & 36 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4025,25 +4025,20 @@ class DecomposeAtenInstanceNormOp

auto inputTy = op.getInput().getType().cast<BaseTensorType>();
int64_t inputRank = inputTy.getSizes().size();
auto reduceDimInts =
llvm::SmallVector<int64_t>({inputRank - 2, inputRank - 1});

SmallVector<int64_t> reducedShape(inputTy.getSizes());
reducedShape[inputRank - 1] = 1;
reducedShape[inputRank - 2] = 1;
SmallVector<int64_t> reduceDimInts;
SmallVector<Value> reduceDimVals;
for (int i = 2; i < inputRank; ++i) {
reducedShape[i] = 1;
reduceDimVals.push_back(rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(i)));
}

Type dtype = inputTy.getOptionalDtype();
Type reducedTy = ValueTensorType::get(op.getContext(),
llvm::ArrayRef(reducedShape), dtype);

auto sizeListType = ListType::get(IntType::get(context));
SmallVector<Value> reduceDimVals;
reduceDimVals.reserve(reduceDimInts.size());
std::transform(reduceDimInts.begin(), reduceDimInts.end(),
std::back_inserter(reduceDimVals), [&](int64_t d) {
return rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(d));
});
Value reduceDimList =
rewriter.create<PrimListConstructOp>(loc, sizeListType, reduceDimVals);
Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(loc, true);
Expand All @@ -4069,9 +4064,12 @@ class DecomposeAtenInstanceNormOp
loc, reducedTy, inputSubMeanSquare, reduceDimList, cstTrue,
/*dtype=*/none);

int64_t elemCount = 1;
for (int i = 2; i < inputRank; ++i)
elemCount *= inputTy.getSizes()[i];

Value hw = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(inputTy.getSizes()[inputRank - 1] *
inputTy.getSizes()[inputRank - 2]));
loc, rewriter.getI64IntegerAttr(elemCount));
Value inputVar =
rewriter.create<AtenDivScalarOp>(loc, reducedTy, variancesum, hw);

Expand Down Expand Up @@ -4104,19 +4102,14 @@ class DecomposeAtenInstanceNormOp
op.getContext(), llvm::ArrayRef(newWeightShape), dtype);
weight = rewriter.create<AtenUnsqueezeOp>(loc, newWeightTy, weight, zero);

Value two = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(2));
newWeightShape.push_back(1);
newWeightTy = ValueTensorType::get(op.getContext(),
llvm::ArrayRef(newWeightShape), dtype);
weight = rewriter.create<AtenUnsqueezeOp>(loc, newWeightTy, weight, two);

Value three = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(3));
newWeightShape.push_back(1);
newWeightTy = ValueTensorType::get(op.getContext(),
llvm::ArrayRef(newWeightShape), dtype);
weight = rewriter.create<AtenUnsqueezeOp>(loc, newWeightTy, weight, three);
while (static_cast<int64_t>(newWeightShape.size()) < inputRank) {
Value i = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(newWeightShape.size()));
newWeightShape.push_back(1);
newWeightTy = ValueTensorType::get(op.getContext(),
llvm::ArrayRef(newWeightShape), dtype);
weight = rewriter.create<AtenUnsqueezeOp>(loc, newWeightTy, weight, i);
}

Value weightExpanded =
rewriter.create<AtenExpandAsOp>(loc, inputTy, weight, op.getInput());
Expand All @@ -4134,15 +4127,14 @@ class DecomposeAtenInstanceNormOp
llvm::ArrayRef(newBiasShape), dtype);
bias = rewriter.create<AtenUnsqueezeOp>(loc, newBiasTy, bias, zero);

newBiasShape.push_back(1);
newBiasTy = ValueTensorType::get(op.getContext(),
llvm::ArrayRef(newBiasShape), dtype);
bias = rewriter.create<AtenUnsqueezeOp>(loc, newBiasTy, bias, two);

newBiasShape.push_back(1);
newBiasTy = ValueTensorType::get(op.getContext(),
llvm::ArrayRef(newBiasShape), dtype);
bias = rewriter.create<AtenUnsqueezeOp>(loc, newBiasTy, bias, three);
while (static_cast<int64_t>(newBiasShape.size()) < inputRank) {
Value i = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(newBiasShape.size()));
newBiasShape.push_back(1);
newBiasTy = ValueTensorType::get(op.getContext(),
llvm::ArrayRef(newBiasShape), dtype);
bias = rewriter.create<AtenUnsqueezeOp>(loc, newBiasTy, bias, i);
}

Value biasExpanded =
rewriter.create<AtenExpandAsOp>(loc, inputTy, bias, op.getInput());
Expand Down

0 comments on commit 73b6df9

Please sign in to comment.