Skip to content

Commit

Permalink
[onnx] Support onnx.OneHot lowering to torch (llvm#3196)
Browse files Browse the repository at this point in the history
[onnx] Support `onnx.OneHot` lowering to `torch`

Leverage the `aten.onehot` implementation along with `aten.transpose`
and `aten.where.scalar`.
  • Loading branch information
rsuderman authored and archana-ramalingam committed May 8, 2024
1 parent 77dae58 commit 8f0a679
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 5 deletions.
101 changes: 101 additions & 0 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1566,6 +1566,107 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
binder.op, resultType, input);
return success();
});
patterns.onOp(
"OneHot", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
llvm::SmallVector<Value> inputs;
Torch::ValueTensorType resultType;
if (binder.tensorOperandsList(inputs) ||
binder.tensorResultType(resultType))
return failure();

if (inputs.size() != 3)
return rewriter.notifyMatchFailure(binder.op, "expected 3 operands");

int64_t axis;
if (binder.s64IntegerAttr(axis, "axis", -1))
return rewriter.notifyMatchFailure(binder.op,
"`axis` attr not found");

auto loc = binder.getLoc();
Value indices = inputs[0];
Value depth = inputs[1];
Value values = inputs[2];

auto indicesTy = cast<Torch::ValueTensorType>(indices.getType());
auto valuesTy = cast<Torch::ValueTensorType>(values.getType());
auto depthTy = cast<Torch::ValueTensorType>(depth.getType());

axis = axis < 0 ? axis + indicesTy.getSizes().size() + 1 : axis;

bool depthIsInt = isa<IntegerType>(depthTy.getDtype());
Type intTy = rewriter.getType<Torch::IntType>();
Type floatTy = rewriter.getType<Torch::FloatType>();
Type depthETy = depthIsInt ? intTy : floatTy;
depth = rewriter.create<Torch::AtenItemOp>(loc, depthETy, depth);

if (!depthIsInt)
depth = rewriter.create<Torch::AtenIntScalarOp>(
loc, rewriter.getType<Torch::IntType>(), depth);

auto selectTy = rewriter.getType<Torch::ValueTensorType>(
llvm::SmallVector<int64_t>{1}, valuesTy.getDtype());

Value zero = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
Value one = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));

Value off = rewriter.create<Torch::AtenSelectIntOp>(loc, selectTy,
values, zero, zero);
off = rewriter.create<Torch::AtenItemOp>(
loc, rewriter.getType<Torch::IntType>(), off);

Value on = rewriter.create<Torch::AtenSelectIntOp>(loc, selectTy,
values, zero, one);
on = rewriter.create<Torch::AtenItemOp>(
loc, rewriter.getType<Torch::IntType>(), on);

auto i32Ty = rewriter.getIntegerType(32, true);
llvm::SmallVector<int64_t> onehotShape(indicesTy.getSizes());
onehotShape.push_back(Torch::kUnknownSize);
auto onehotTy =
rewriter.getType<Torch::ValueTensorType>(onehotShape, i32Ty);

Value onehot = rewriter.create<Torch::AtenOneHotOp>(
binder.getLoc(), onehotTy, indices, depth);

for (int i = valuesTy.getSizes().size(); i > axis; ++i) {
std::swap(onehotShape[i - 1], onehotShape[i]);
Value iv0 = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(i));
Value iv1 = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(i - 1));

onehotTy =
rewriter.getType<Torch::ValueTensorType>(onehotShape, i32Ty);
onehot = rewriter.create<Torch::AtenTransposeIntOp>(loc, onehotTy,
onehot, iv1, iv0);
}

// Change one hot to an array of booleans to select value:
auto i1Ty = rewriter.getI1Type();
auto torchqTy = Torch::getScalarTypeForType(i1Ty);
Value tyConst = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64),
static_cast<int64_t>(torchqTy)));

onehotTy = rewriter.getType<Torch::ValueTensorType>(onehotShape, i1Ty);
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
onehot = rewriter.create<Torch::AtenToDtypeOp>(
loc, onehotTy, onehot, tyConst,
/*non_blocking=*/cstFalse, /*copy=*/cstFalse,
/*memory_format=*/none);

onehotTy = rewriter.getType<Torch::ValueTensorType>(
onehotShape, resultType.getDtype());
onehot = rewriter.create<Torch::AtenWhereScalarOp>(loc, onehotTy,
onehot, on, off);

rewriter.replaceOp(binder.op, onehot);
return success();
});
patterns.onOp("HardSwish", 14,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Expand Down
7 changes: 2 additions & 5 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2644,9 +2644,6 @@
"MaxPool2dWithIndicesAllNegativeValuesModule_basic",
"MaxPool2dWithIndicesNonDefaultPaddingModule_basic",
"MaxPool2dWithIndicesStaticModule_basic",

# Failure - onnx_lowering: onnx.OneHot
"OneHotModule_basic",

# Failure - onnx_lowering: onnx.ReduceProd
"ReduceProdFloatModule_basic",
Expand All @@ -2655,7 +2652,7 @@
"ReduceProdUnsignedIntModule_basic",
"ReduceProdSignedIntModule_basic",
"ReduceProdDtypeIntModule_basic",

# ERROR: dtype (torch.float32) is not equal to golden dtype (torch.float64)
"RandnDtypeDeviceModule_basic",
"RandnGeneratorF64Module_basic",
Expand All @@ -2679,7 +2676,7 @@
"ScatterReduceIntMaxModuleIncludeSelf",
"ScatterReduceIntMinModuleIncludeSelf",
"ScatterValueFloatModule_basic",

# Failure - onnx_lowering: onnx.ScatterND
"IndexPut1DFloatAccumulateModule_basic",
"IndexPut1DFloatNonAccumulateModule_basic",
Expand Down

0 comments on commit 8f0a679

Please sign in to comment.