From 167bf0d9b216c48388aa3255090b06351387b001 Mon Sep 17 00:00:00 2001 From: Robert Suderman Date: Fri, 19 Apr 2024 17:27:50 -0700 Subject: [PATCH] [onnx] Support `onnx.OneHot` lowering to `torch` Leverage the `aten.onehot` implementation along with `aten.transpose` and `aten.where.scalar`. --- .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 93 +++++++++++++++++++ projects/pt1/e2e_testing/xfail_sets.py | 4 +- 2 files changed, 94 insertions(+), 3 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index c7d0710791193..73c1f58cc7319 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -1557,6 +1557,99 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.op, resultType, input); return success(); }); + patterns.onOp( + "OneHot", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + llvm::SmallVector inputs; + Torch::ValueTensorType resultType; + if (binder.tensorOperandsList(inputs) || + binder.tensorResultType(resultType)) + return failure(); + + if (inputs.size() != 3) + return rewriter.notifyMatchFailure(binder.op, "expected 3 operands"); + + int64_t axis; + if (binder.s64IntegerAttr(axis, "axis", -1)) + return rewriter.notifyMatchFailure(binder.op, + "`axis` attr not found"); + + auto loc = binder.getLoc(); + Value indices = inputs[0]; + Value depth = inputs[1]; + Value values = inputs[2]; + + auto indicesTy = cast(indices.getType()); + auto valuesTy = cast(values.getType()); + + axis = axis < 0 ? axis + indicesTy.getSizes().size() + 1 : axis; + + depth = rewriter.create( + loc, rewriter.getType(), depth); + + auto selectTy = rewriter.getType( + llvm::SmallVector{1}, valuesTy.getDtype()); + + Value zero = rewriter.create( + loc, rewriter.getI64IntegerAttr(0)); + Value one = rewriter.create( + loc, rewriter.getI64IntegerAttr(1)); + + Value off = rewriter.create(loc, selectTy, + values, zero, zero); + off = rewriter.create( + loc, rewriter.getType(), off); + + Value on = rewriter.create(loc, selectTy, + values, zero, one); + on = rewriter.create( + loc, rewriter.getType(), on); + + auto i32Ty = rewriter.getIntegerType(32, true); + llvm::SmallVector onehotShape(indicesTy.getSizes()); + onehotShape.push_back(Torch::kUnknownSize); + auto onehotTy = + rewriter.getType(onehotShape, i32Ty); + + Value onehot = rewriter.create( + binder.getLoc(), onehotTy, indices, depth); + + for (int i = valuesTy.getSizes().size(); i > axis; ++i) { + std::swap(onehotShape[i - 1], onehotShape[i]); + Value iv0 = rewriter.create( + loc, rewriter.getI64IntegerAttr(i)); + Value iv1 = rewriter.create( + loc, rewriter.getI64IntegerAttr(i - 1)); + + onehotTy = + rewriter.getType(onehotShape, i32Ty); + onehot = rewriter.create(loc, onehotTy, + onehot, iv1, iv0); + } + + // Change one hot to an array of booleans to select value: + auto i1Ty = rewriter.getI1Type(); + auto torchqTy = Torch::getScalarTypeForType(i1Ty); + Value tyConst = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), + static_cast(torchqTy))); + + onehotTy = rewriter.getType(onehotShape, i1Ty); + Value none = rewriter.create(loc); + Value cstFalse = rewriter.create(loc, false); + onehot = rewriter.create( + loc, onehotTy, onehot, tyConst, + /*non_blocking=*/cstFalse, /*copy=*/cstFalse, + /*memory_format=*/none); + + onehotTy = rewriter.getType( + onehotShape, resultType.getDtype()); + onehot = rewriter.create(loc, onehotTy, + onehot, on, off); + + rewriter.replaceOp(binder.op, onehot); + return success(); + }); patterns.onOp("HardSwish", 14, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 037771b7494b6..4691a8ef76f98 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -2600,9 +2600,6 @@ "MaxPool2dWithIndicesNonDefaultPaddingModule_basic", "MaxPool2dWithIndicesStaticModule_basic", - # Failure - onnx_lowering: onnx.OneHot - "OneHotModule_basic", - # Failure - onnx_lowering: onnx.RandomNormal "RandnDtypeDeviceModule_basic", "RandnGeneratorF64Module_basic", @@ -2651,6 +2648,7 @@ "ScatterSrcStaticModule_basic", "ScatterValueFloatModule_basic", "ScatterValueIntModule_basic", + # Failure - onnx_lowering: onnx.ScatterND "IndexPut1DFloatAccumulateModule_basic",