diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp index 781a5912d83c..9259fdacff24 100644 --- a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp +++ b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp @@ -266,28 +266,44 @@ std::optional getConstTensor(PatternRewriter &rewriter, } static LogicalResult checkValidityOfCast(Type src, Type dest) { - if ((src == dest) || (src.isInteger(64) && dest.isInteger(32)) || + // clang-format off + if ((src == dest) || + // int64 -> * + (src.isInteger(64) && dest.isInteger(32)) || (src.isInteger(64) && dest.isInteger(8)) || (src.isInteger(64) && dest.isInteger(1)) || (src.isInteger(64) && dest.isF32()) || + // int32 -> * (src.isInteger(32) && dest.isInteger(64)) || (src.isInteger(32) && dest.isInteger(1)) || (src.isInteger(32) && dest.isF32()) || (src.isInteger(32) && dest.isBF16()) || + // int16 -> * (src.isInteger(16) && dest.isBF16()) || + // int8 -> * (src.isInteger(8) && dest.isInteger(1)) || (src.isInteger(8) && dest.isBF16()) || + // int1 -> * (src.isInteger(1) && dest.isInteger(64)) || - (src.isInteger(1) && dest.isF32()) || (src.isF32() && dest.isF64()) || - (src.isF32() && dest.isBF16()) || (src.isF64() && dest.isF32()) || - (src.isF64() && dest.isBF16()) || (src.isF32() && dest.isInteger(8)) || + (src.isInteger(1) && dest.isF32()) || + // f64 -> * + (src.isF64() && dest.isF32()) || + (src.isF64() && dest.isBF16()) || + // f32 -> * + (src.isF32() && dest.isF64()) || + (src.isF32() && dest.isBF16()) || + (src.isF32() && dest.isF16()) || + (src.isF32() && dest.isInteger(8)) || (src.isF32() && dest.isInteger(64)) || (src.isF32() && dest.isInteger(1)) || + // bf16 -> * (src.isBF16() && dest.isInteger(8)) || (src.isBF16() && dest.isInteger(16)) || - (src.isBF16() && dest.isInteger(32)) || (src.isBF16() && dest.isF32())) { + (src.isBF16() && dest.isInteger(32)) || + (src.isBF16() && dest.isF32())) { return success(); } + // clang-format on return failure(); } diff --git a/test/Conversion/TorchToTosa/cast_fp32_to_fp16.mlir b/test/Conversion/TorchToTosa/cast_fp32_to_fp16.mlir new file mode 100644 index 000000000000..5504ac0e4002 --- /dev/null +++ b/test/Conversion/TorchToTosa/cast_fp32_to_fp16.mlir @@ -0,0 +1,12 @@ +// RUN: torch-mlir-opt <%s -convert-torch-to-tosa -split-input-file + +// CHECK: %{{.*}} = tosa.cast %{{.*}} : (tensor<1x32x220x220xf32>) -> tensor<1x32x220x220xf16> +func.func @forward(%arg0: !torch.vtensor<[1,32,220,220],f32>) -> !torch.vtensor<[1,32,220,220],f16> { + %int5 = torch.constant.int 5 + %false = torch.constant.bool false + %none = torch.constant.none + %out = torch.aten.to.dtype %arg0, %int5, %false, %false, %none : !torch.vtensor<[1,32,220,220],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,32,220,220],f16> + return %out : !torch.vtensor<[1,32,220,220],f16> +} + +