Skip to content

Commit

Permalink
allow tosa.cast to convert from f32 to f16 (llvm#2934)
Browse files Browse the repository at this point in the history
According to the [official TOSA
spec](https://www.mlplatform.org/tosa/tosa_spec.html#_cast), `tosa.cast`
allows a cast from `fp32` to `fp16`. We were not previously accounting
for this in the `TorchToTosa` lowering.

Also did a tiny bit of cleanup in the code to make it easier to spot
which conversions are currently allowed.

---------

Co-authored-by: Srinath Avadhanula <srinath.avadhanula@getcruise.com>
  • Loading branch information
srinathava and Srinath Avadhanula authored Feb 20, 2024
1 parent 534b266 commit 0f80e75
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 5 deletions.
26 changes: 21 additions & 5 deletions lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -266,28 +266,44 @@ std::optional<Value> getConstTensor<float>(PatternRewriter &rewriter,
}

static LogicalResult checkValidityOfCast(Type src, Type dest) {
if ((src == dest) || (src.isInteger(64) && dest.isInteger(32)) ||
// clang-format off
if ((src == dest) ||
// int64 -> *
(src.isInteger(64) && dest.isInteger(32)) ||
(src.isInteger(64) && dest.isInteger(8)) ||
(src.isInteger(64) && dest.isInteger(1)) ||
(src.isInteger(64) && dest.isF32()) ||
// int32 -> *
(src.isInteger(32) && dest.isInteger(64)) ||
(src.isInteger(32) && dest.isInteger(1)) ||
(src.isInteger(32) && dest.isF32()) ||
(src.isInteger(32) && dest.isBF16()) ||
// int16 -> *
(src.isInteger(16) && dest.isBF16()) ||
// int8 -> *
(src.isInteger(8) && dest.isInteger(1)) ||
(src.isInteger(8) && dest.isBF16()) ||
// int1 -> *
(src.isInteger(1) && dest.isInteger(64)) ||
(src.isInteger(1) && dest.isF32()) || (src.isF32() && dest.isF64()) ||
(src.isF32() && dest.isBF16()) || (src.isF64() && dest.isF32()) ||
(src.isF64() && dest.isBF16()) || (src.isF32() && dest.isInteger(8)) ||
(src.isInteger(1) && dest.isF32()) ||
// f64 -> *
(src.isF64() && dest.isF32()) ||
(src.isF64() && dest.isBF16()) ||
// f32 -> *
(src.isF32() && dest.isF64()) ||
(src.isF32() && dest.isBF16()) ||
(src.isF32() && dest.isF16()) ||
(src.isF32() && dest.isInteger(8)) ||
(src.isF32() && dest.isInteger(64)) ||
(src.isF32() && dest.isInteger(1)) ||
// bf16 -> *
(src.isBF16() && dest.isInteger(8)) ||
(src.isBF16() && dest.isInteger(16)) ||
(src.isBF16() && dest.isInteger(32)) || (src.isBF16() && dest.isF32())) {
(src.isBF16() && dest.isInteger(32)) ||
(src.isBF16() && dest.isF32())) {
return success();
}
// clang-format on
return failure();
}

Expand Down
12 changes: 12 additions & 0 deletions test/Conversion/TorchToTosa/cast_fp32_to_fp16.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-tosa -split-input-file

// CHECK: %{{.*}} = tosa.cast %{{.*}} : (tensor<1x32x220x220xf32>) -> tensor<1x32x220x220xf16>
func.func @forward(%arg0: !torch.vtensor<[1,32,220,220],f32>) -> !torch.vtensor<[1,32,220,220],f16> {
%int5 = torch.constant.int 5
%false = torch.constant.bool false
%none = torch.constant.none
%out = torch.aten.to.dtype %arg0, %int5, %false, %false, %none : !torch.vtensor<[1,32,220,220],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,32,220,220],f16>
return %out : !torch.vtensor<[1,32,220,220],f16>
}


0 comments on commit 0f80e75

Please sign in to comment.