Skip to content

Commit

Permalink
[onnx] Fix onnx.Not for non-bool inputs (llvm#3187)
Browse files Browse the repository at this point in the history
Need to perform a bool cast to support `onnx.Not` on non-bool inputs.
  • Loading branch information
rsuderman authored Apr 19, 2024
1 parent 790a697 commit b01245c
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 3 deletions.
24 changes: 24 additions & 0 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -636,6 +636,30 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
binder.tensorResultType(resultType)) {
return failure();
}

auto loc = binder.getLoc();
auto operandTy =
cast<Torch::ValueTensorType>(operand.getType());
auto eTy = operandTy.getDtype();

if (!eTy.isInteger(1)) {
auto i1ty = rewriter.getI1Type();
auto ty = rewriter.getType<Torch::ValueTensorType>(
operandTy.getSizes(), i1ty);
auto torchqTy = Torch::getScalarTypeForType(i1ty);
Value tyConst = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(
rewriter.getIntegerType(64),
static_cast<int64_t>(torchqTy)));
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
Value cstFalse =
rewriter.create<Torch::ConstantBoolOp>(loc, false);
operand = rewriter.create<Torch::AtenToDtypeOp>(
loc, ty, operand, tyConst,
/*non_blocking=*/cstFalse, /*copy=*/cstFalse,
/*memory_format=*/none);
}
rewriter.replaceOpWithNewOp<Torch::AtenBitwiseNotOp>(
binder.op, resultType, operand);
return success();
Expand Down
3 changes: 0 additions & 3 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2723,9 +2723,6 @@
"MaskedFillTensorFloatValueModule_basic",
"NativeDropoutTrainModule_basic",
"NativeDropoutTrainStaticShapeModule_basic",
"ReduceAllDimEmpty_basic",
"ReduceAllDimFloat_basic",
"ReduceAllDimInt_basic",
"ReduceMaxAlongDimUnsignedInt_basic",
"ReduceMinAlongDimUnsignedInt_basic",
"TensorsStackNegativeDimModule_basic",
Expand Down

0 comments on commit b01245c

Please sign in to comment.