diff --git a/tests/filecheck/dialects/varith/invalid.mlir b/tests/filecheck/dialects/varith/invalid.mlir new file mode 100644 index 0000000000..568100e46f --- /dev/null +++ b/tests/filecheck/dialects/varith/invalid.mlir @@ -0,0 +1,15 @@ +// RUN: xdsl-opt --parsing-diagnostics --verify-diagnostics --split-input-file + + +%i, %f, %t1, %t2 = "test.op"() : () -> (i32, f32, tensor<10xf32>, tensor<5xf32>) +varith.add %i, %f : i32 +// CHECK: operand is used with type i32, but has been previously used or defined with type f32 + + +// ----- +// CHECK: ----- + + +%i, %f, %t1, %t2 = "test.op"() : () -> (i32, f32, tensor<10xf32>, tensor<5xf32>) +varith.add %t1, %t2 : tensor<10xf32> +// CHECK: operand is used with type tensor<10xf32>, but has been previously used or defined with type tensor<5xf32> diff --git a/tests/filecheck/dialects/varith/varith_ops.mlir b/tests/filecheck/dialects/varith/varith_ops.mlir new file mode 100644 index 0000000000..461ab861f6 --- /dev/null +++ b/tests/filecheck/dialects/varith/varith_ops.mlir @@ -0,0 +1,31 @@ +// RUN: XDSL_ROUNDTRIP +// RUN: XDSL_GENERIC_ROUNDTRIP + +%ia, %ib, %ic, %id = "test.op"() : () -> (i32, i32, i32, i32) +%fa, %fb, %fc, %fd = "test.op"() : () -> (f32, f32, f32, f32) +%ta, %tb, %tc, %td = "test.op"() : () -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) + +%x1 = "varith.add"(%ia, %ib, %ic, %id) : (i32, i32, i32, i32) -> i32 +// CHECK: %x1 = varith.add %ia, %ib, %ic, %id : i32 +// CHECK-GENERIC: %x1 = "varith.add"(%ia, %ib, %ic, %id) : (i32, i32, i32, i32) -> i32 + +%x2 = "varith.add"(%fa, %fb, %fc, %fd) : (f32, f32, f32, f32) -> f32 +// CHECK: %x2 = varith.add %fa, %fb, %fc, %fd : f32 +// CHECK-GENERIC: %x2 = "varith.add"(%fa, %fb, %fc, %fd) : (f32, f32, f32, f32) -> f32 + +%x3 = "varith.add"(%ta, %tb, %tc, %td) : (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> +// CHECK: %x3 = varith.add %ta, %tb, %tc, %td : tensor<10xf32> +// CHECK-GENERIC: %x3 = "varith.add"(%ta, %tb, %tc, %td) : (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> + + +%x4 = "varith.mul"(%ia, %ib, %ic, %id) : (i32, i32, i32, i32) -> i32 +// CHECK: %x4 = varith.mul %ia, %ib, %ic, %id : i32 +// CHECK-GENERIC: %x4 = "varith.mul"(%ia, %ib, %ic, %id) : (i32, i32, i32, i32) -> i32 + +%x5 = "varith.mul"(%fa, %fb, %fc, %fd) : (f32, f32, f32, f32) -> f32 +// CHECK: %x5 = varith.mul %fa, %fb, %fc, %fd : f32 +// CHECK-GENERIC: %x5 = "varith.mul"(%fa, %fb, %fc, %fd) : (f32, f32, f32, f32) -> f32 + +%x6 = "varith.mul"(%ta, %tb, %tc, %td) : (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> +// CHECK: %x6 = varith.mul %ta, %tb, %tc, %td : tensor<10xf32> +// CHECK-GENERIC: %x6 = "varith.mul"(%ta, %tb, %tc, %td) : (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> diff --git a/xdsl/dialects/__init__.py b/xdsl/dialects/__init__.py index 1c19b6aacf..314be16645 100644 --- a/xdsl/dialects/__init__.py +++ b/xdsl/dialects/__init__.py @@ -293,6 +293,11 @@ def get_tosa(): return TOSA + def get_varith(): + from xdsl.dialects.varith import Varith + + return Varith + def get_vector(): from xdsl.dialects.vector import Vector @@ -371,6 +376,7 @@ def get_transform(): "tensor": get_tensor, "test": get_test, "tosa": get_tosa, + "varith": get_varith, "vector": get_vector, "wasm": get_wasm, "x86": get_x86, diff --git a/xdsl/dialects/varith.py b/xdsl/dialects/varith.py new file mode 100644 index 0000000000..a10dc28996 --- /dev/null +++ b/xdsl/dialects/varith.py @@ -0,0 +1,76 @@ +from typing import Annotated + +from xdsl.dialects.builtin import ( + BFloat16Type, + ContainerOf, + Float16Type, + Float32Type, + Float64Type, + Float80Type, + Float128Type, + IndexType, + IntegerType, +) +from xdsl.ir import Attribute, Dialect, Operation, SSAValue +from xdsl.irdl import ( + AnyOf, + ConstraintVar, + IRDLOperation, + irdl_op_definition, + result_def, + var_operand_def, +) +from xdsl.traits import Pure + +integerOrFloatLike: ContainerOf = ContainerOf( + AnyOf( + [ + IntegerType, + IndexType, + BFloat16Type, + Float16Type, + Float32Type, + Float64Type, + Float80Type, + Float128Type, + ] + ) +) + + +class VarithOp(IRDLOperation): + """ + Variadic arithmetic operation + """ + + T = Annotated[Attribute, ConstraintVar("T"), integerOrFloatLike] + + args = var_operand_def(T) + res = result_def(T) + + traits = frozenset((Pure(),)) + + assembly_format = "$args attr-dict `:` type($res)" + + def __init__(self, *args: SSAValue | Operation): + assert len(args) > 0 + super().__init__(operands=[args], result_types=[SSAValue.get(args[-1]).type]) + + +@irdl_op_definition +class VarithAddOp(VarithOp): + name = "varith.add" + + +@irdl_op_definition +class VarithMulOp(VarithOp): + name = "varith.mul" + + +Varith = Dialect( + "varith", + [ + VarithAddOp, + VarithMulOp, + ], +)