Skip to content

Commit

Permalink
OnnxToTorch bicubic interpolation (#3802)
Browse files Browse the repository at this point in the history
(nod-ai/SHARK-TestSuite#391)
Repro (using SHARK TestSuite):
1. `python run.py --torchtolinalg -m cl-onnx-iree -t cubic_test`

---------

Co-authored-by: zjgarvey <zjgarvey@gmail.com>
  • Loading branch information
aldesilv and zjgarvey authored Nov 12, 2024
1 parent 17c1985 commit 889a836
Show file tree
Hide file tree
Showing 3 changed files with 292 additions and 31 deletions.
17 changes: 13 additions & 4 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2922,7 +2922,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
llvm::SmallVector<Value> operands;
std::string mode, nearest_mode, coordTfMode;
int64_t antialias, exclude_outside;
float extrapolation_value;
float extrapolation_value, cubic_coeff_a;
Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());

if (auto attr = binder.op->getAttr("torch.onnx.axes")) {
Expand All @@ -2947,7 +2947,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
binder.f32FloatAttr(extrapolation_value, "extrapolation_value",
0.0) ||
binder.customOpNameStringAttr(nearest_mode, "nearest_mode",
"round_prefer_floor"))
"round_prefer_floor") ||
binder.f32FloatAttr(cubic_coeff_a, "cubic_coeff_a", -0.75))
return failure();
if (antialias != 0) {
return rewriter.notifyMatchFailure(
Expand Down Expand Up @@ -2976,6 +2977,11 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
"except asymmetric and half_pixel");
}

if (mode == "cubic" && cubic_coeff_a != -0.75) {
return rewriter.notifyMatchFailure(
binder.op, "unimplemented: cubic coeff must be -0.75");
}

unsigned rank = dyn_cast<Torch::ValueTensorType>(operands[0].getType())
.getSizes()
.size();
Expand All @@ -2991,8 +2997,11 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
Value alignCorners =
coordTfMode == "align_corners" ? cstTrue : cstFalse;
if (mode == "cubic") {
return rewriter.notifyMatchFailure(binder.op,
"unimplemented: bicubic mode");
std::string modeStr = "cubic";
if (coordTfMode != "half_pixel")
modeStr = modeStr + "_" + coordTfMode;
modeStrValue =
rewriter.create<Torch::ConstantStrOp>(binder.getLoc(), modeStr);
}
// supported modes:
// bilinear (half_pixel), bilinear with align_corners,
Expand Down
255 changes: 230 additions & 25 deletions lib/Conversion/TorchToLinalg/Uncategorized.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2683,7 +2683,7 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern<AtenGridSamplerOp> {
};
} // namespace

static Value NearestInterpolate(OpBuilder &b, Location loc,
static Value nearestInterpolate(OpBuilder &b, Location loc,
SmallVector<Value> outputSizes, Value input,
SmallVector<Value> inputSizes,
SmallVector<Value> scaleValues,
Expand Down Expand Up @@ -2771,12 +2771,12 @@ static Value NearestInterpolate(OpBuilder &b, Location loc,
return retVal;
}

static Value BilinearInterpolate(OpBuilder &b,
Aten__InterpolateSizeListScaleListOp op,
Location loc, SmallVector<Value> outputSizes,
Value input, SmallVector<Value> inputSizes,
SmallVector<Value> scaleValues,
std::string coordStr) {
static SmallVector<Value> coordinateTransform(
OpBuilder &b, Aten__InterpolateSizeListScaleListOp op, Location loc,
SmallVector<Value> outputSizes, Value input, SmallVector<Value> inputSizes,
SmallVector<Value> scaleValues, std::string coordStr, bool alignCornersBool,
SmallVector<Value> indices, bool clip) {

unsigned dimOffset = 2;
auto inputType = cast<RankedTensorType>(input.getType());
auto inputRank = inputType.getRank();
Expand All @@ -2785,15 +2785,7 @@ static Value BilinearInterpolate(OpBuilder &b,
Value cstHalf = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(0.5));
Value zero = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(0.0));

bool alignCornersBool;
matchPattern(op.getAlignCorners(), m_TorchConstantBool(&alignCornersBool));

SmallVector<Value> indices;
for (unsigned i = 0; i < inputRank; i++) {
indices.push_back(b.create<linalg::IndexOp>(loc, i));
}

SmallVector<Value> proj, projEps, high, low, highFP, lowFP;
SmallVector<Value> proj;
for (unsigned i = 0; i < inputRank - dimOffset; i++) {
// length_original
Value inputFP =
Expand Down Expand Up @@ -2856,13 +2848,50 @@ static Value BilinearInterpolate(OpBuilder &b,
outputSizeFP, cstOneFloat);
preClip = b.create<arith::SelectOp>(loc, cmp, zero, preClip);
}
// preClip is the fp position inside the input image to extract from.
// clip to [0,inf)
Value max = b.create<arith::MaximumFOp>(loc, preClip, zero);
if (clip) {
// preClip is the fp position inside the input image to extract from.
// clip to [0,inf)
Value max = b.create<arith::MaximumFOp>(loc, preClip, zero);
Value inputSubOne = b.create<arith::SubFOp>(loc, inputFP, cstOneFloat);
// clip to [0,length_original - 1].
// proj is properly within the input image.
proj.push_back(b.create<arith::MinimumFOp>(loc, max, inputSubOne));
} else {
proj.push_back(preClip);
}
}
return proj;
}

static Value bilinearInterpolate(OpBuilder &b,
Aten__InterpolateSizeListScaleListOp op,
Location loc, SmallVector<Value> outputSizes,
Value input, SmallVector<Value> inputSizes,
SmallVector<Value> scaleValues,
std::string coordStr) {
unsigned dimOffset = 2;
auto inputType = cast<RankedTensorType>(input.getType());
auto inputRank = inputType.getRank();

Value cstOneFloat = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(1.0));

bool alignCornersBool;
matchPattern(op.getAlignCorners(), m_TorchConstantBool(&alignCornersBool));

SmallVector<Value> indices;
for (unsigned i = 0; i < inputRank; i++) {
indices.push_back(b.create<linalg::IndexOp>(loc, i));
}

SmallVector<Value> proj, high, low, highFP, lowFP;
proj = coordinateTransform(b, op, loc, outputSizes, input, inputSizes,
scaleValues, coordStr, alignCornersBool, indices,
true);
for (unsigned i = 0; i < inputRank - dimOffset; i++) {
// length_original
Value inputFP =
b.create<arith::SIToFPOp>(loc, b.getF32Type(), inputSizes[i]);
Value inputSubOne = b.create<arith::SubFOp>(loc, inputFP, cstOneFloat);
// clip to [0,length_original - 1].
// proj is properly within the input image.
proj.push_back(b.create<arith::MinimumFOp>(loc, max, inputSubOne));

// for bilinear interpolation, we look for the nearest indices below and
// above proj
Expand Down Expand Up @@ -2926,6 +2955,176 @@ static Value BilinearInterpolate(OpBuilder &b,
return b.create<arith::AddFOp>(loc, left, right);
}

static Value bicubicInterpolate(OpBuilder &b,
Aten__InterpolateSizeListScaleListOp op,
Location loc, SmallVector<Value> outputSizes,
Value input, SmallVector<Value> inputSizes,
SmallVector<Value> scaleValues,
std::string coordStr) {
unsigned dimOffset = 2;
auto inputType = cast<RankedTensorType>(input.getType());
auto inputRank = inputType.getRank();

Value inputFPH =
b.create<arith::SIToFPOp>(loc, b.getF32Type(), inputSizes[0]);
Value inputFPW =
b.create<arith::SIToFPOp>(loc, b.getF32Type(), inputSizes[1]);

Value a = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(-0.75));
Value zero = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(0.0));
Value cstOneFloat = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(1.0));
Value cstTwoFloat = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(2.0));
Value cstThreeFloat =
b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(3.0));
Value cstFourFloat = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(4.0));
Value cstFiveFloat = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(5.0));
Value cstEightFloat =
b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(8.0));

// (a+2)|x|^3 - (a+3)|x|^2 + 1 for xDistance (|x| <= 1)
auto WeightLessThanEqualOne = [&](Value xDistance) -> Value {
Value xDistanceSquared = b.create<arith::MulFOp>(loc, xDistance, xDistance);
Value xDistanceCubed =
b.create<arith::MulFOp>(loc, xDistanceSquared, xDistance);

Value lessEqualOne = b.create<arith::AddFOp>(loc, a, cstTwoFloat);
lessEqualOne = b.create<arith::MulFOp>(loc, xDistanceCubed, lessEqualOne);
Value aPlusThree = b.create<arith::AddFOp>(loc, a, cstThreeFloat);
aPlusThree = b.create<arith::MulFOp>(loc, xDistanceSquared, aPlusThree);
lessEqualOne = b.create<arith::SubFOp>(loc, lessEqualOne, aPlusThree);
lessEqualOne = b.create<arith::AddFOp>(loc, lessEqualOne, cstOneFloat);

return lessEqualOne;
};

// a|x|^3 - 5a|x|^2 + 8a|x| - 4a for xDistance (1 < |x| < 2)
auto WeightLessThanTwo = [&](Value xDistance) -> Value {
Value xDistanceSquared = b.create<arith::MulFOp>(loc, xDistance, xDistance);
Value xDistanceCubed =
b.create<arith::MulFOp>(loc, xDistanceSquared, xDistance);
// a|x|^3
Value lessThanTwo = b.create<arith::MulFOp>(loc, xDistanceCubed, a);

Value fiveA = b.create<arith::MulFOp>(loc, xDistanceSquared, a);
fiveA = b.create<arith::MulFOp>(loc, fiveA, cstFiveFloat);
// a|x|^3 - 5a|x|^2
lessThanTwo = b.create<arith::SubFOp>(loc, lessThanTwo, fiveA);

Value eightA = b.create<arith::MulFOp>(loc, a, xDistance);
eightA = b.create<arith::MulFOp>(loc, eightA, cstEightFloat);
// a|x|^3 - 5a|x|^2 + 8a|x|
lessThanTwo = b.create<arith::AddFOp>(loc, eightA, lessThanTwo);

Value fourA = b.create<arith::MulFOp>(loc, a, cstFourFloat);
// a|x|^3 - 5a|x|^2 + 8a|x| - 4a
lessThanTwo = b.create<arith::SubFOp>(loc, lessThanTwo, fourA);
return lessThanTwo;
};

bool alignCornersBool;
matchPattern(op.getAlignCorners(), m_TorchConstantBool(&alignCornersBool));

SmallVector<Value> indices;
for (unsigned i = 0; i < inputRank; i++) {
indices.push_back(b.create<linalg::IndexOp>(loc, i));
}

SmallVector<Value> proj;

proj = coordinateTransform(b, op, loc, outputSizes, input, inputSizes,
scaleValues, coordStr, alignCornersBool, indices,
false);

// get the nearest neighbors of proj
Value x1 = b.create<math::CeilOp>(loc, proj[1]);
Value x_1 = b.create<arith::SubFOp>(loc, x1, cstOneFloat);
Value x_2 = b.create<arith::SubFOp>(loc, x_1, cstOneFloat);
Value x2 = b.create<arith::AddFOp>(loc, x1, cstOneFloat);

Value y1 = b.create<math::CeilOp>(loc, proj[0]);
Value y_1 = b.create<arith::SubFOp>(loc, y1, cstOneFloat);
Value y_2 = b.create<arith::SubFOp>(loc, y_1, cstOneFloat);
Value y2 = b.create<arith::AddFOp>(loc, y1, cstOneFloat);

// calculate the distance of nearest neighbors x and y to proj
Value y2Distance = b.create<arith::SubFOp>(loc, proj[0], y2);
y2Distance = b.create<math::AbsFOp>(loc, y2Distance);
Value y1Distance = b.create<arith::SubFOp>(loc, proj[0], y1);
y1Distance = b.create<math::AbsFOp>(loc, y1Distance);
Value y_1Distance = b.create<arith::SubFOp>(loc, proj[0], y_1);
y_1Distance = b.create<math::AbsFOp>(loc, y_1Distance);
Value y_2Distance = b.create<arith::SubFOp>(loc, proj[0], y_2);
y_2Distance = b.create<math::AbsFOp>(loc, y_2Distance);

Value x2Distance = b.create<arith::SubFOp>(loc, proj[1], x2);
x2Distance = b.create<math::AbsFOp>(loc, x2Distance);
Value x1Distance = b.create<arith::SubFOp>(loc, proj[1], x1);
x1Distance = b.create<math::AbsFOp>(loc, x1Distance);
Value x_1Distance = b.create<arith::SubFOp>(loc, proj[1], x_1);
x_1Distance = b.create<math::AbsFOp>(loc, x_1Distance);
Value x_2Distance = b.create<arith::SubFOp>(loc, proj[1], x_2);
x_2Distance = b.create<math::AbsFOp>(loc, x_2Distance);

SmallVector<Value> y{y_2, y_1, y1, y2};
SmallVector<Value> x{x_2, x_1, x1, x2};

SmallVector<Value> wys{
WeightLessThanTwo(y_2Distance), WeightLessThanEqualOne(y_1Distance),
WeightLessThanEqualOne(y1Distance), WeightLessThanTwo(y2Distance)};
SmallVector<Value> wxs{
WeightLessThanTwo(x_2Distance), WeightLessThanEqualOne(x_1Distance),
WeightLessThanEqualOne(x1Distance), WeightLessThanTwo(x2Distance)};

// clip the nearest neighbors points to inside the original image
for (int k = 0; k < 4; k++) {
Value yClipped = b.create<arith::MaximumFOp>(loc, y[k], zero);
Value inputHSubOne = b.create<arith::SubFOp>(loc, inputFPH, cstOneFloat);
yClipped = b.create<arith::MinimumFOp>(loc, yClipped, inputHSubOne);
Value yInt = b.create<arith::FPToSIOp>(loc, b.getI64Type(), yClipped);
y[k] = b.create<arith::IndexCastOp>(loc, b.getIndexType(), yInt);

Value xClipped = b.create<arith::MaximumFOp>(loc, x[k], zero);
Value inputWSubOne = b.create<arith::SubFOp>(loc, inputFPW, cstOneFloat);
xClipped = b.create<arith::MinimumFOp>(loc, xClipped, inputWSubOne);
Value xInt = b.create<arith::FPToSIOp>(loc, b.getI64Type(), xClipped);
x[k] = b.create<arith::IndexCastOp>(loc, b.getIndexType(), xInt);
}
// 1. Compute x_original and y_original (proj)
// 2. Compute nearest x and y neighbors
// 3. Compute Wx Wy
// 4. Extract inputs at nearest neighbors (inputExtracts)
// 5. Compute weighted sum (yield this)

// 4 nearest x neighbors : [x_2, x_1, x1, x2] of x_original
// 4 nearest y neighbors : [y_2, y_1, y1, y2] of y_original
// Sum_x is over 4 nearest x neighbors (similar for Sum_y)
// f(x_original, y_original) = Sum_y Sum_x W(x_original - x)*input[x,y]
// * W(y_original - y)
Value fxy = zero;

for (int j = 0; j < 4; j++) {
Value wy = wys[j];
Value xInterpy = zero;

indices[dimOffset] = y[j];

for (int i = 0; i < 4; i++) {
Value wx = wxs[i];

indices[dimOffset + 1] = x[i];

Value p = b.create<tensor::ExtractOp>(loc, input, indices);

Value wxp = b.create<arith::MulFOp>(loc, wx, p);
xInterpy = b.create<arith::AddFOp>(loc, xInterpy, wxp);
}
Value wyXInterpy = b.create<arith::MulFOp>(loc, wy, xInterpy);
fxy = b.create<arith::AddFOp>(loc, fxy, wyXInterpy);
}

return fxy;
}

namespace {
class ConvertInterpolateOp
: public OpConversionPattern<Aten__InterpolateSizeListScaleListOp> {
Expand All @@ -2941,7 +3140,8 @@ class ConvertInterpolateOp
// coordinate_transformation_mode="asymmetric" will lower to an interpolate
// op with the non-standard mode="bilinear_asymmetric".
matchPattern(op.getMode(), m_TorchConstantStr(mode));
if (mode.substr(0, 8) != "bilinear" && mode.substr(0, 7) != "nearest") {
if (mode.substr(0, 8) != "bilinear" && mode.substr(0, 7) != "nearest" &&
mode.substr(0, 5) != "cubic") {
return failure();
}

Expand Down Expand Up @@ -3023,13 +3223,18 @@ class ConvertInterpolateOp
(mode.find(",") == std::string::npos)
? ""
: mode.substr(mode.find(",") + 1);
retVal = NearestInterpolate(
retVal = nearestInterpolate(
b, loc, outputSizeIntValues, input, inputSizes,
ScaleFactorFloatValues, coordTfMode, nearestMode);
} else if (mode.substr(0, 8) == "bilinear") {
retVal = BilinearInterpolate(
retVal = bilinearInterpolate(
b, op, loc, outputSizeIntValues, input, inputSizes,
ScaleFactorFloatValues, mode.substr(8));
} else if (mode.substr(0, 5) == "cubic") {

retVal = bicubicInterpolate(
b, op, loc, outputSizeIntValues, input, inputSizes,
ScaleFactorFloatValues, mode.substr(5));
}
b.create<linalg::YieldOp>(loc, retVal);
})
Expand Down
Loading

0 comments on commit 889a836

Please sign in to comment.