Skip to content

Commit

Permalink
[onnx] Support onnx.OneHot lowering to torch
Browse files Browse the repository at this point in the history
Leverage the `aten.onehot` implementation along with `aten.transpose` and
`aten.where.scalar`.
  • Loading branch information
Robert Suderman authored and Robert Suderman committed Apr 20, 2024
1 parent b01245c commit 167bf0d
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 3 deletions.
93 changes: 93 additions & 0 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1557,6 +1557,99 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
binder.op, resultType, input);
return success();
});
patterns.onOp(
"OneHot", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
llvm::SmallVector<Value> inputs;
Torch::ValueTensorType resultType;
if (binder.tensorOperandsList(inputs) ||
binder.tensorResultType(resultType))
return failure();

if (inputs.size() != 3)
return rewriter.notifyMatchFailure(binder.op, "expected 3 operands");

int64_t axis;
if (binder.s64IntegerAttr(axis, "axis", -1))
return rewriter.notifyMatchFailure(binder.op,
"`axis` attr not found");

auto loc = binder.getLoc();
Value indices = inputs[0];
Value depth = inputs[1];
Value values = inputs[2];

auto indicesTy = cast<Torch::ValueTensorType>(indices.getType());
auto valuesTy = cast<Torch::ValueTensorType>(values.getType());

axis = axis < 0 ? axis + indicesTy.getSizes().size() + 1 : axis;

depth = rewriter.create<Torch::AtenItemOp>(
loc, rewriter.getType<Torch::IntType>(), depth);

auto selectTy = rewriter.getType<Torch::ValueTensorType>(
llvm::SmallVector<int64_t>{1}, valuesTy.getDtype());

Value zero = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
Value one = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));

Value off = rewriter.create<Torch::AtenSelectIntOp>(loc, selectTy,
values, zero, zero);
off = rewriter.create<Torch::AtenItemOp>(
loc, rewriter.getType<Torch::IntType>(), off);

Value on = rewriter.create<Torch::AtenSelectIntOp>(loc, selectTy,
values, zero, one);
on = rewriter.create<Torch::AtenItemOp>(
loc, rewriter.getType<Torch::IntType>(), on);

auto i32Ty = rewriter.getIntegerType(32, true);
llvm::SmallVector<int64_t> onehotShape(indicesTy.getSizes());
onehotShape.push_back(Torch::kUnknownSize);
auto onehotTy =
rewriter.getType<Torch::ValueTensorType>(onehotShape, i32Ty);

Value onehot = rewriter.create<Torch::AtenOneHotOp>(
binder.getLoc(), onehotTy, indices, depth);

for (int i = valuesTy.getSizes().size(); i > axis; ++i) {
std::swap(onehotShape[i - 1], onehotShape[i]);
Value iv0 = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(i));
Value iv1 = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(i - 1));

onehotTy =
rewriter.getType<Torch::ValueTensorType>(onehotShape, i32Ty);
onehot = rewriter.create<Torch::AtenTransposeIntOp>(loc, onehotTy,
onehot, iv1, iv0);
}

// Change one hot to an array of booleans to select value:
auto i1Ty = rewriter.getI1Type();
auto torchqTy = Torch::getScalarTypeForType(i1Ty);
Value tyConst = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64),
static_cast<int64_t>(torchqTy)));

onehotTy = rewriter.getType<Torch::ValueTensorType>(onehotShape, i1Ty);
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
onehot = rewriter.create<Torch::AtenToDtypeOp>(
loc, onehotTy, onehot, tyConst,
/*non_blocking=*/cstFalse, /*copy=*/cstFalse,
/*memory_format=*/none);

onehotTy = rewriter.getType<Torch::ValueTensorType>(
onehotShape, resultType.getDtype());
onehot = rewriter.create<Torch::AtenWhereScalarOp>(loc, onehotTy,
onehot, on, off);

rewriter.replaceOp(binder.op, onehot);
return success();
});
patterns.onOp("HardSwish", 14,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Expand Down
4 changes: 1 addition & 3 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2600,9 +2600,6 @@
"MaxPool2dWithIndicesNonDefaultPaddingModule_basic",
"MaxPool2dWithIndicesStaticModule_basic",

# Failure - onnx_lowering: onnx.OneHot
"OneHotModule_basic",

# Failure - onnx_lowering: onnx.RandomNormal
"RandnDtypeDeviceModule_basic",
"RandnGeneratorF64Module_basic",
Expand Down Expand Up @@ -2651,6 +2648,7 @@
"ScatterSrcStaticModule_basic",
"ScatterValueFloatModule_basic",
"ScatterValueIntModule_basic",


# Failure - onnx_lowering: onnx.ScatterND
"IndexPut1DFloatAccumulateModule_basic",
Expand Down

0 comments on commit 167bf0d

Please sign in to comment.