diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 51a710d940e9..736d66544e2d 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -4025,25 +4025,20 @@ class DecomposeAtenInstanceNormOp auto inputTy = op.getInput().getType().cast(); int64_t inputRank = inputTy.getSizes().size(); - auto reduceDimInts = - llvm::SmallVector({inputRank - 2, inputRank - 1}); - SmallVector reducedShape(inputTy.getSizes()); - reducedShape[inputRank - 1] = 1; - reducedShape[inputRank - 2] = 1; + SmallVector reduceDimInts; + SmallVector reduceDimVals; + for (int i = 2; i < inputRank; ++i) { + reducedShape[i] = 1; + reduceDimVals.push_back(rewriter.create( + loc, rewriter.getI64IntegerAttr(i))); + } Type dtype = inputTy.getOptionalDtype(); Type reducedTy = ValueTensorType::get(op.getContext(), llvm::ArrayRef(reducedShape), dtype); auto sizeListType = ListType::get(IntType::get(context)); - SmallVector reduceDimVals; - reduceDimVals.reserve(reduceDimInts.size()); - std::transform(reduceDimInts.begin(), reduceDimInts.end(), - std::back_inserter(reduceDimVals), [&](int64_t d) { - return rewriter.create( - loc, rewriter.getI64IntegerAttr(d)); - }); Value reduceDimList = rewriter.create(loc, sizeListType, reduceDimVals); Value cstTrue = rewriter.create(loc, true); @@ -4069,9 +4064,12 @@ class DecomposeAtenInstanceNormOp loc, reducedTy, inputSubMeanSquare, reduceDimList, cstTrue, /*dtype=*/none); + int64_t elemCount = 1; + for (int i = 2; i < inputRank; ++i) + elemCount *= inputTy.getSizes()[i]; + Value hw = rewriter.create( - loc, rewriter.getI64IntegerAttr(inputTy.getSizes()[inputRank - 1] * - inputTy.getSizes()[inputRank - 2])); + loc, rewriter.getI64IntegerAttr(elemCount)); Value inputVar = rewriter.create(loc, reducedTy, variancesum, hw); @@ -4104,19 +4102,14 @@ class DecomposeAtenInstanceNormOp op.getContext(), llvm::ArrayRef(newWeightShape), dtype); weight = rewriter.create(loc, newWeightTy, weight, zero); - Value two = rewriter.create( - loc, rewriter.getI64IntegerAttr(2)); - newWeightShape.push_back(1); - newWeightTy = ValueTensorType::get(op.getContext(), - llvm::ArrayRef(newWeightShape), dtype); - weight = rewriter.create(loc, newWeightTy, weight, two); - - Value three = rewriter.create( - loc, rewriter.getI64IntegerAttr(3)); - newWeightShape.push_back(1); - newWeightTy = ValueTensorType::get(op.getContext(), - llvm::ArrayRef(newWeightShape), dtype); - weight = rewriter.create(loc, newWeightTy, weight, three); + while (static_cast(newWeightShape.size()) < inputRank) { + Value i = rewriter.create( + loc, rewriter.getI64IntegerAttr(newWeightShape.size())); + newWeightShape.push_back(1); + newWeightTy = ValueTensorType::get(op.getContext(), + llvm::ArrayRef(newWeightShape), dtype); + weight = rewriter.create(loc, newWeightTy, weight, i); + } Value weightExpanded = rewriter.create(loc, inputTy, weight, op.getInput()); @@ -4134,15 +4127,14 @@ class DecomposeAtenInstanceNormOp llvm::ArrayRef(newBiasShape), dtype); bias = rewriter.create(loc, newBiasTy, bias, zero); - newBiasShape.push_back(1); - newBiasTy = ValueTensorType::get(op.getContext(), - llvm::ArrayRef(newBiasShape), dtype); - bias = rewriter.create(loc, newBiasTy, bias, two); - - newBiasShape.push_back(1); - newBiasTy = ValueTensorType::get(op.getContext(), - llvm::ArrayRef(newBiasShape), dtype); - bias = rewriter.create(loc, newBiasTy, bias, three); + while (static_cast(newBiasShape.size()) < inputRank) { + Value i = rewriter.create( + loc, rewriter.getI64IntegerAttr(newBiasShape.size())); + newBiasShape.push_back(1); + newBiasTy = ValueTensorType::get(op.getContext(), + llvm::ArrayRef(newBiasShape), dtype); + bias = rewriter.create(loc, newBiasTy, bias, i); + } Value biasExpanded = rewriter.create(loc, inputTy, bias, op.getInput());