From 0f80e75c2eb6dfed00bf051644a5e3fb97207bb8 Mon Sep 17 00:00:00 2001 From: Srinath Avadhanula Date: Tue, 20 Feb 2024 17:22:38 -0500 Subject: [PATCH] allow tosa.cast to convert from f32 to f16 (#2934) According to the [official TOSA spec](https://www.mlplatform.org/tosa/tosa_spec.html#_cast), `tosa.cast` allows a cast from `fp32` to `fp16`. We were not previously accounting for this in the `TorchToTosa` lowering. Also did a tiny bit of cleanup in the code to make it easier to spot which conversions are currently allowed. --------- Co-authored-by: Srinath Avadhanula --- .../TorchToTosa/TosaLegalizeUtils.cpp | 26 +++++++++++++++---- .../TorchToTosa/cast_fp32_to_fp16.mlir | 12 +++++++++ 2 files changed, 33 insertions(+), 5 deletions(-) create mode 100644 test/Conversion/TorchToTosa/cast_fp32_to_fp16.mlir diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp index 781a5912d83c..9259fdacff24 100644 --- a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp +++ b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp @@ -266,28 +266,44 @@ std::optional getConstTensor(PatternRewriter &rewriter, } static LogicalResult checkValidityOfCast(Type src, Type dest) { - if ((src == dest) || (src.isInteger(64) && dest.isInteger(32)) || + // clang-format off + if ((src == dest) || + // int64 -> * + (src.isInteger(64) && dest.isInteger(32)) || (src.isInteger(64) && dest.isInteger(8)) || (src.isInteger(64) && dest.isInteger(1)) || (src.isInteger(64) && dest.isF32()) || + // int32 -> * (src.isInteger(32) && dest.isInteger(64)) || (src.isInteger(32) && dest.isInteger(1)) || (src.isInteger(32) && dest.isF32()) || (src.isInteger(32) && dest.isBF16()) || + // int16 -> * (src.isInteger(16) && dest.isBF16()) || + // int8 -> * (src.isInteger(8) && dest.isInteger(1)) || (src.isInteger(8) && dest.isBF16()) || + // int1 -> * (src.isInteger(1) && dest.isInteger(64)) || - (src.isInteger(1) && dest.isF32()) || (src.isF32() && dest.isF64()) || - (src.isF32() && dest.isBF16()) || (src.isF64() && dest.isF32()) || - (src.isF64() && dest.isBF16()) || (src.isF32() && dest.isInteger(8)) || + (src.isInteger(1) && dest.isF32()) || + // f64 -> * + (src.isF64() && dest.isF32()) || + (src.isF64() && dest.isBF16()) || + // f32 -> * + (src.isF32() && dest.isF64()) || + (src.isF32() && dest.isBF16()) || + (src.isF32() && dest.isF16()) || + (src.isF32() && dest.isInteger(8)) || (src.isF32() && dest.isInteger(64)) || (src.isF32() && dest.isInteger(1)) || + // bf16 -> * (src.isBF16() && dest.isInteger(8)) || (src.isBF16() && dest.isInteger(16)) || - (src.isBF16() && dest.isInteger(32)) || (src.isBF16() && dest.isF32())) { + (src.isBF16() && dest.isInteger(32)) || + (src.isBF16() && dest.isF32())) { return success(); } + // clang-format on return failure(); } diff --git a/test/Conversion/TorchToTosa/cast_fp32_to_fp16.mlir b/test/Conversion/TorchToTosa/cast_fp32_to_fp16.mlir new file mode 100644 index 000000000000..5504ac0e4002 --- /dev/null +++ b/test/Conversion/TorchToTosa/cast_fp32_to_fp16.mlir @@ -0,0 +1,12 @@ +// RUN: torch-mlir-opt <%s -convert-torch-to-tosa -split-input-file + +// CHECK: %{{.*}} = tosa.cast %{{.*}} : (tensor<1x32x220x220xf32>) -> tensor<1x32x220x220xf16> +func.func @forward(%arg0: !torch.vtensor<[1,32,220,220],f32>) -> !torch.vtensor<[1,32,220,220],f16> { + %int5 = torch.constant.int 5 + %false = torch.constant.bool false + %none = torch.constant.none + %out = torch.aten.to.dtype %arg0, %int5, %false, %false, %none : !torch.vtensor<[1,32,220,220],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,32,220,220],f16> + return %out : !torch.vtensor<[1,32,220,220],f16> +} + +