Skip to content

Commit

Permalink
[MLIR][TORCH] Add onnx.cast cases used by OPT-1.25M (#2787)
Browse files Browse the repository at this point in the history
  • Loading branch information
newling authored Jan 23, 2024
1 parent c9d8ffb commit dc056e5
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 67 deletions.
148 changes: 81 additions & 67 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,30 +10,39 @@
#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::onnx_c;

static int64_t onnxDtypeIntToTorchDtypeInt(int64_t dtypeIntOnnx) {
int64_t dtypeIntTorch;
// TODO: Add complete mapping.
switch (dtypeIntOnnx) {
case 1:
dtypeIntTorch = 6; // float
break;
case 10:
dtypeIntTorch = 5; // half
break;
case 11:
dtypeIntTorch = 7; // double
break;
case 16:
dtypeIntTorch = 15; // bfloat16
break;
default:
dtypeIntTorch = -1; // No dtype
}
// Where are the ONNX and PyTorch dtype enums defined?
// ONNX:
// https://github.com/shouxieai/tensorRT_Pro/blob/main/onnx/onnx-ml.proto
// PyTorch:
// https://github.com/llvm/torch-mlir/blob/main/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h#L88

int64_t dtypeIntTorch = [dtypeIntOnnx]() {
switch (dtypeIntOnnx) {
case 1:
return 6; // float
case 7:
return 5; // int64
case 9:
return 11; // bool
case 10:
return 5; // half
case 11:
return 7; // double
case 16:
return 15; // bfloat16
default:
return -1; // No dtype
}
}();

return dtypeIntTorch;
}

Expand Down Expand Up @@ -415,30 +424,30 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
}
return success();
});
patterns.onOp(
"BitwiseAnd", 18, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value lhs, rhs;
std::string direction;
if (binder.tensorOperands(lhs, rhs) ||
binder.tensorResultType(resultType))
return failure();
rewriter.replaceOpWithNewOp<Torch::AtenBitwiseAndTensorOp>(
binder.op, resultType, lhs, rhs);
return success();
});
patterns.onOp(
"BitwiseOr", 18, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value lhs, rhs;
std::string direction;
if (binder.tensorOperands(lhs, rhs) ||
binder.tensorResultType(resultType))
return failure();
rewriter.replaceOpWithNewOp<Torch::AtenBitwiseOrTensorOp>(
binder.op, resultType, lhs, rhs);
return success();
});
patterns.onOp("BitwiseAnd", 18,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value lhs, rhs;
std::string direction;
if (binder.tensorOperands(lhs, rhs) ||
binder.tensorResultType(resultType))
return failure();
rewriter.replaceOpWithNewOp<Torch::AtenBitwiseAndTensorOp>(
binder.op, resultType, lhs, rhs);
return success();
});
patterns.onOp("BitwiseOr", 18,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value lhs, rhs;
std::string direction;
if (binder.tensorOperands(lhs, rhs) ||
binder.tensorResultType(resultType))
return failure();
rewriter.replaceOpWithNewOp<Torch::AtenBitwiseOrTensorOp>(
binder.op, resultType, lhs, rhs);
return success();
});
patterns.onOp("BitwiseNot", 18,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Expand All @@ -450,18 +459,18 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
binder.op, resultType, operand);
return success();
});
patterns.onOp(
"BitwiseXor", 18, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value lhs, rhs;
std::string direction;
if (binder.tensorOperands(lhs, rhs) ||
binder.tensorResultType(resultType))
return failure();
rewriter.replaceOpWithNewOp<Torch::AtenBitwiseXorTensorOp>(
binder.op, resultType, lhs, rhs);
return success();
});
patterns.onOp("BitwiseXor", 18,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value lhs, rhs;
std::string direction;
if (binder.tensorOperands(lhs, rhs) ||
binder.tensorResultType(resultType))
return failure();
rewriter.replaceOpWithNewOp<Torch::AtenBitwiseXorTensorOp>(
binder.op, resultType, lhs, rhs);
return success();
});
patterns.onOp(
"Cast", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Expand All @@ -474,9 +483,13 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(

dtypeIntTorch = onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx);
if (dtypeIntTorch == -1) {
return rewriter.notifyMatchFailure(
binder.op,
"unimplemented support for the given dtype conversion");
auto message = llvm::formatv("unimplemented support for the given "
"dtype conversion (onnx 'type' = {0})",
dtypeIntOnnx);
llvm::errs() << message << "\n";
auto y = rewriter.notifyMatchFailure(binder.op, message);

return y;
}
Value constDtype = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
Expand Down Expand Up @@ -864,7 +877,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
unsigned rank = *maybeRank;

SmallVector<int64_t> padding, strides, dilations, outputPadding;
SmallVector<int64_t> defaultPadding, defaultStrides, defaultDilations, defaultOutputPadding;
SmallVector<int64_t> defaultPadding, defaultStrides, defaultDilations,
defaultOutputPadding;
for (unsigned i = 0; i < rank - 2; i++) {
defaultPadding.push_back(0);
defaultStrides.push_back(1);
Expand Down Expand Up @@ -1018,30 +1032,30 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
cast<Torch::ValueTensorType>(operand.getType()).getSizes().size();
Value rankVal = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64),
rank));
rewriter.getIntegerAttr(rewriter.getIntegerType(64), rank));
Value zero = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));

Value axisScalar = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(), axisTensor);
Value isNegative =
rewriter.create<Torch::AtenLtIntOp>(binder.getLoc(), axisScalar, zero);
isNegative = rewriter.create<Torch::AtenIntBoolOp>(binder.getLoc(),
isNegative);
Value isNegative = rewriter.create<Torch::AtenLtIntOp>(
binder.getLoc(), axisScalar, zero);
isNegative =
rewriter.create<Torch::AtenIntBoolOp>(binder.getLoc(), isNegative);
Value finalOffset = rewriter.create<Torch::AtenMulIntOp>(
binder.getLoc(), isNegative, rankVal);
Value dim = rewriter.create<Torch::AtenAddIntOp>(
binder.getLoc(), axisScalar, finalOffset);

Torch::BaseTensorType resultTensorType = resultType.cast<Torch::BaseTensorType>();
Torch::BaseTensorType resultTensorType =
resultType.cast<Torch::BaseTensorType>();
if (!resultTensorType.hasDtype()) {
return rewriter.notifyMatchFailure(
binder.op, "expected result type to have a dtype");
}
// resultTensorType.print(llvm::outs());
Value resultDType =
Torch::getDtypeIntValueForType(rewriter, loc, resultTensorType.getDtype());
Value resultDType = Torch::getDtypeIntValueForType(
rewriter, loc, resultTensorType.getDtype());

rewriter.replaceOpWithNewOp<Torch::AtenCumsumOp>(
binder.op, resultType, operand, dim, resultDType);
Expand Down
10 changes: 10 additions & 0 deletions test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,16 @@ func.func @test_cast_FLOAT16_to_DOUBLE(%arg0: !torch.vtensor<[3,4],f16>) -> !tor
return %0 : !torch.vtensor<[3,4],f64>
}

// CHECK-LABEL: @test_cast_FLOAT_to_BOOL
func.func @test_cast_FLOAT_to_BOOL(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT:.*]] = torch.constant.int 11
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: torch.aten.to.dtype %arg0, %[[INT]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,4],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4],i1>
%0 = torch.operator "onnx.Cast"(%arg0) {torch.onnx.to = 9 : si64} : (!torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],i1>
return %0 : !torch.vtensor<[3,4],i1>
}

// CHECK-LABEL: @test_cast_FLOAT16_to_FLOAT
func.func @test_cast_FLOAT16_to_FLOAT(%arg0: !torch.vtensor<[3,4],f16>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT:.*]] = torch.constant.int 6
Expand Down

0 comments on commit dc056e5

Please sign in to comment.