Skip to content

Commit

Permalink
[Linalg] Add conversion between bf16 and f16 (llvm#3963)
Browse files Browse the repository at this point in the history
To fix issue llvm#3962:
'arith.extf' op operand type 'bf16' and result type 'f16' are cast
incompatible
  • Loading branch information
AmosLewis authored Jan 17, 2025
1 parent 5e1d68e commit f42c7e4
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 0 deletions.
4 changes: 4 additions & 0 deletions lib/Conversion/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,10 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,

if (auto dtypeFloat = dyn_cast<mlir::FloatType>(dtype)) {
if (auto scalarFloat = dyn_cast<mlir::FloatType>(scalarType)) {
if (scalarFloat.getWidth() == 16 && dtypeFloat.getWidth() == 16) {
auto scalarF32 = b.create<arith::ExtFOp>(loc, b.getF32Type(), scalar);
return b.create<arith::TruncFOp>(loc, dtype, scalarF32);
}
if (scalarFloat.getWidth() > dtypeFloat.getWidth())
return b.create<arith::TruncFOp>(loc, dtype, scalar);
// Only scalarFloat width < dtypeFloat width can reach here.
Expand Down
16 changes: 16 additions & 0 deletions test/Conversion/TorchToLinalg/elementwise.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,19 @@ func.func @elementwise_sinh(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3
%0 = torch.aten.sinh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
return %0 : !torch.vtensor<[3],f32>
}

// -----

// CHECK-LABEL: func.func @elementwise_todtype_bf162f16(
// CHECK: linalg.generic
// CHECK: arith.extf
// CHECK-SAME: bf16 to f32
// CHECK: arith.truncf
// CHECK-SAME: f32 to f16
// Regression test for bf16 -> f16 conversion: a direct arith.extf/truncf
// between bf16 and f16 is cast-incompatible (same 16-bit width, different
// formats), so the lowering is expected to round-trip through f32
// (arith.extf bf16 -> f32, then arith.truncf f32 -> f16), as pinned by the
// CHECK lines above.
func.func @elementwise_todtype_bf162f16(%arg0: !torch.vtensor<[1,?,32,128],bf16>) -> !torch.vtensor<[1,?,32,128],f16> {
// dtype code 5 selects f16 — presumably the torch ScalarType enum value for Half; it matches the declared f16 result type.
%int5 = torch.constant.int 5
// non_blocking = false, copy = false
%false = torch.constant.bool false
// memory_format = none
%none = torch.constant.none
%0 = torch.aten.to.dtype %arg0, %int5, %false, %false, %none : !torch.vtensor<[1,?,32,128],bf16>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,?,32,128],f16>
return %0 : !torch.vtensor<[1,?,32,128],f16>
}

0 comments on commit f42c7e4

Please sign in to comment.