-
Notifications
You must be signed in to change notification settings - Fork 80
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
dialects: (varith) Add varith (variadic arithmetic) dialect
- Loading branch information
1 parent
4e3e7a9
commit 25e65c3
Showing
4 changed files
with
128 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
// RUN: xdsl-opt --parsing-diagnostics --verify-diagnostics --split-input-file | ||
|
||
|
||
%i, %f, %t1, %t2 = "test.op"() : () -> (i32, f32, tensor<10xf32>, tensor<5xf32>) | ||
varith.add %i, %f : i32 | ||
// CHECK: operand is used with type i32, but has been previously used or defined with type f32 | ||
|
||
|
||
// ----- | ||
// CHECK: ----- | ||
|
||
|
||
%i, %f, %t1, %t2 = "test.op"() : () -> (i32, f32, tensor<10xf32>, tensor<5xf32>) | ||
varith.add %t1, %t2 : tensor<10xf32> | ||
// CHECK: operand is used with type tensor<10xf32>, but has been previously used or defined with type tensor<5xf32> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
// RUN: XDSL_ROUNDTRIP | ||
// RUN: XDSL_GENERIC_ROUNDTRIP | ||
|
||
%ia, %ib, %ic, %id = "test.op"() : () -> (i32, i32, i32, i32) | ||
%fa, %fb, %fc, %fd = "test.op"() : () -> (f32, f32, f32, f32) | ||
%ta, %tb, %tc, %td = "test.op"() : () -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) | ||
|
||
%x1 = "varith.add"(%ia, %ib, %ic, %id) : (i32, i32, i32, i32) -> i32 | ||
// CHECK: %x1 = varith.add %ia, %ib, %ic, %id : i32 | ||
// CHECK-GENERIC: %x1 = "varith.add"(%ia, %ib, %ic, %id) : (i32, i32, i32, i32) -> i32 | ||
|
||
%x2 = "varith.add"(%fa, %fb, %fc, %fd) : (f32, f32, f32, f32) -> f32 | ||
// CHECK: %x2 = varith.add %fa, %fb, %fc, %fd : f32 | ||
// CHECK-GENERIC: %x2 = "varith.add"(%fa, %fb, %fc, %fd) : (f32, f32, f32, f32) -> f32 | ||
|
||
%x3 = "varith.add"(%ta, %tb, %tc, %td) : (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> | ||
// CHECK: %x3 = varith.add %ta, %tb, %tc, %td : tensor<10xf32> | ||
// CHECK-GENERIC: %x3 = "varith.add"(%ta, %tb, %tc, %td) : (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> | ||
|
||
|
||
%x4 = "varith.mul"(%ia, %ib, %ic, %id) : (i32, i32, i32, i32) -> i32 | ||
// CHECK: %x4 = varith.mul %ia, %ib, %ic, %id : i32 | ||
// CHECK-GENERIC: %x4 = "varith.mul"(%ia, %ib, %ic, %id) : (i32, i32, i32, i32) -> i32 | ||
|
||
%x5 = "varith.mul"(%fa, %fb, %fc, %fd) : (f32, f32, f32, f32) -> f32 | ||
// CHECK: %x5 = varith.mul %fa, %fb, %fc, %fd : f32 | ||
// CHECK-GENERIC: %x5 = "varith.mul"(%fa, %fb, %fc, %fd) : (f32, f32, f32, f32) -> f32 | ||
|
||
%x6 = "varith.mul"(%ta, %tb, %tc, %td) : (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> | ||
// CHECK: %x6 = varith.mul %ta, %tb, %tc, %td : tensor<10xf32> | ||
// CHECK-GENERIC: %x6 = "varith.mul"(%ta, %tb, %tc, %td) : (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
from typing import Annotated | ||
|
||
from xdsl.dialects.builtin import ( | ||
BFloat16Type, | ||
ContainerOf, | ||
Float16Type, | ||
Float32Type, | ||
Float64Type, | ||
Float80Type, | ||
Float128Type, | ||
IndexType, | ||
IntegerType, | ||
) | ||
from xdsl.ir import Attribute, Dialect, Operation, SSAValue | ||
from xdsl.irdl import ( | ||
AnyOf, | ||
ConstraintVar, | ||
IRDLOperation, | ||
irdl_op_definition, | ||
result_def, | ||
var_operand_def, | ||
) | ||
from xdsl.traits import Pure | ||
|
||
integerOrFloatLike: ContainerOf = ContainerOf( | ||
AnyOf( | ||
[ | ||
IntegerType, | ||
IndexType, | ||
BFloat16Type, | ||
Float16Type, | ||
Float32Type, | ||
Float64Type, | ||
Float80Type, | ||
Float128Type, | ||
] | ||
) | ||
) | ||
|
||
|
||
class VarithOp(IRDLOperation): | ||
""" | ||
Variadic arithmetic operation | ||
""" | ||
|
||
T = Annotated[Attribute, ConstraintVar("T"), integerOrFloatLike] | ||
|
||
args = var_operand_def(T) | ||
res = result_def(T) | ||
|
||
traits = frozenset((Pure(),)) | ||
|
||
assembly_format = "$args attr-dict `:` type($res)" | ||
|
||
def __init__(self, *args: SSAValue | Operation): | ||
assert len(args) > 0 | ||
super().__init__(operands=[args], result_types=[SSAValue.get(args[-1]).type]) | ||
|
||
|
||
@irdl_op_definition | ||
class VarithAddOp(VarithOp): | ||
name = "varith.add" | ||
|
||
|
||
@irdl_op_definition | ||
class VarithMulOp(VarithOp): | ||
name = "varith.mul" | ||
|
||
|
||
Varith = Dialect( | ||
"varith", | ||
[ | ||
VarithAddOp, | ||
VarithMulOp, | ||
], | ||
) |