From 4eb677565c5192a9fec6afc511b334e997dfc524 Mon Sep 17 00:00:00 2001 From: Max Manainen Date: Thu, 19 Dec 2024 01:37:01 +0100 Subject: [PATCH 1/7] tensor donation --- .../transforms/jax-use-donated-arguments.mlir | 45 +++++++++++++ uv.lock | 2 +- xdsl/transforms/__init__.py | 6 ++ xdsl/transforms/jax_use_donated_arguments.py | 64 +++++++++++++++++++ 4 files changed, 116 insertions(+), 1 deletion(-) create mode 100644 tests/filecheck/transforms/jax-use-donated-arguments.mlir create mode 100644 xdsl/transforms/jax_use_donated_arguments.py diff --git a/tests/filecheck/transforms/jax-use-donated-arguments.mlir b/tests/filecheck/transforms/jax-use-donated-arguments.mlir new file mode 100644 index 0000000000..50beb4610b --- /dev/null +++ b/tests/filecheck/transforms/jax-use-donated-arguments.mlir @@ -0,0 +1,45 @@ +// RUN: xdsl-opt %s -p jax-use-donated-arguments --split-input-file --verify-diagnostics | filecheck %s + +func.func public @one_donation(%arg0: tensor<2x3xf32>, %arg1: tensor<3x4xf32>, %arg2: tensor<2x4xf32> {tf.aliasing_output = 0 : i32}) -> (tensor<2x4xf32>) { + %res = "test.op"() : () -> tensor<2x4xf32> + return %res : tensor<2x4xf32> + } + +// CHECK: builtin.module { +// CHECK-NEXT: func.func public @one_donation(%arg0 : tensor<2x3xf32>, %arg1 : tensor<3x4xf32>, %arg2 : tensor<2x4xf32> {"tf.aliasing_output" = 0 : i32}) -> tensor<2x4xf32> { +// CHECK-NEXT: %res = "test.op"() : () -> tensor<2x4xf32> +// CHECK-NEXT: %0 = bufferization.materialize_in_destination %res in %arg2 : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK-NEXT: func.return %0 : tensor<2x4xf32> +// CHECK-NEXT: } + +func.func public @same_type_donation(%arg0: tensor<2x3xf32> {tf.aliasing_output = 0 : i32}, %arg1: tensor<2x3xf32> {tf.aliasing_output = 0 : i32}) -> (tensor<2x3xf32>, tensor<2x3xf32>) { + %res1 = "test.op"() : () -> tensor<2x3xf32> + %res2 = "test.op"() : () -> tensor<2x3xf32> + return %res1, %res2 : tensor<2x3xf32>, tensor<2x3xf32> + } + +// CHECK-NEXT: func.func public @same_type_donation(%arg0 : tensor<2x3xf32> {"tf.aliasing_output" = 0 : i32}, %arg1 : tensor<2x3xf32> {"tf.aliasing_output" = 0 : i32}) -> (tensor<2x3xf32>, tensor<2x3xf32>) { +// CHECK-NEXT: %res1 = "test.op"() : () -> tensor<2x3xf32> +// CHECK-NEXT: %res2 = "test.op"() : () -> tensor<2x3xf32> +// CHECK-NEXT: %0 = bufferization.materialize_in_destination %res1 in %arg0 : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32> +// CHECK-NEXT: %1 = bufferization.materialize_in_destination %res2 in %arg1 : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32> +// CHECK-NEXT: func.return %0, %1 : tensor<2x3xf32>, tensor<2x3xf32> +// CHECK-NEXT: } + +func.func public @non_trivial_donation(%arg0: tensor<4x5xf32> {tf.aliasing_output = 0 : i32}, %arg1: tensor<2x3xf32> {tf.aliasing_output = 0 : i32}, %arg2: tensor<2x3xf32>) -> (tensor<2x3xf32>, tensor<2x3xf32>, tensor<4x5xf32>) { + %res1 = "test.op"() : () -> tensor<2x3xf32> + %res2 = "test.op"() : () -> tensor<2x3xf32> + %res3 = "test.op"() : () -> tensor<4x5xf32> + return %res1, %res2, %res3 : tensor<2x3xf32>, tensor<2x3xf32>, tensor<4x5xf32> + } + +// CHECK-NEXT: func.func public @non_trivial_donation(%arg0 : tensor<4x5xf32> {"tf.aliasing_output" = 0 : i32}, %arg1 : tensor<2x3xf32> {"tf.aliasing_output" = 0 : i32}, %arg2 : tensor<2x3xf32>) -> (tensor<2x3xf32>, tensor<2x3xf32>, tensor<4x5xf32>) { +// CHECK-NEXT: %res1 = "test.op"() : () -> tensor<2x3xf32> +// CHECK-NEXT: %res2 = "test.op"() : () -> tensor<2x3xf32> +// CHECK-NEXT: %res3 = "test.op"() : () -> tensor<4x5xf32> +// CHECK-NEXT: %0 = bufferization.materialize_in_destination %res1 in %arg1 : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32> +// CHECK-NEXT: %1 = bufferization.materialize_in_destination %res3 in %arg0 : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32> +// CHECK-NEXT: func.return %0, %res2, %1 : tensor<2x3xf32>, tensor<2x3xf32>, tensor<4x5xf32> +// CHECK-NEXT: } + +// CHECK-NEXT: } diff --git a/uv.lock b/uv.lock index 4f1bc27255..077fa6d608 100644 --- a/uv.lock +++ b/uv.lock @@ -2457,7 +2457,7 @@ wheels = [ [[package]] name = "xdsl" -version = "0+dynamic" +version = "0+untagged.3077.gb394263.dirty" source = { editable = "." } dependencies = [ { name = "immutabledict" }, diff --git a/xdsl/transforms/__init__.py b/xdsl/transforms/__init__.py index ed0a62faef..81a12e06e1 100644 --- a/xdsl/transforms/__init__.py +++ b/xdsl/transforms/__init__.py @@ -173,6 +173,11 @@ def get_convert_varith_to_arith(): return varith_transformations.ConvertVarithToArithPass + def get_jax_use_donated_arguments(): + from xdsl.transforms import jax_use_donated_arguments + + return jax_use_donated_arguments.JaxUseDonatedArguments + def get_cse(): from xdsl.transforms import common_subexpression_elimination @@ -524,6 +529,7 @@ def get_varith_fuse_repeated_operands(): "convert-stencil-to-csl-stencil": get_convert_stencil_to_csl_stencil, "convert-stencil-to-ll-mlir": get_convert_stencil_to_ll_mlir, "convert-varith-to-arith": get_convert_varith_to_arith, + "jax-use-donated-arguments": get_jax_use_donated_arguments, "cse": get_cse, "csl-stencil-bufferize": get_csl_stencil_bufferize, "csl-stencil-handle-async-flow": get_csl_stencil_handle_async_flow, diff --git a/xdsl/transforms/jax_use_donated_arguments.py b/xdsl/transforms/jax_use_donated_arguments.py new file mode 100644 index 0000000000..bb60d9b94b --- /dev/null +++ b/xdsl/transforms/jax_use_donated_arguments.py @@ -0,0 +1,64 @@ +from dataclasses import dataclass + +from xdsl.context import MLContext +from xdsl.dialects import builtin +from xdsl.dialects.bufferization import MaterializeInDestinationOp +from xdsl.dialects.builtin import TensorType +from xdsl.dialects.func import FuncOp, ReturnOp +from xdsl.ir import Operation, SSAValue +from xdsl.passes import ModulePass +from xdsl.pattern_rewriter import ( + GreedyRewritePatternApplier, + PatternRewriter, + PatternRewriteWalker, + RewritePattern, + op_type_rewrite_pattern, +) + + +@dataclass +class SubstituteDonatedTensors(RewritePattern): + @op_type_rewrite_pattern + def match_and_rewrite(self, op: ReturnOp, rewriter: PatternRewriter, /): + func_op = op.parent_op() + assert isinstance(func_op, FuncOp) + + if func_op.arg_attrs is None: + return + + donated_inputs = [ + inp + for inp, attr in zip(func_op.args, func_op.arg_attrs, strict=True) + if isinstance(inp.type, TensorType) and "tf.aliasing_output" in attr.data + ] + + value_mapper: dict[SSAValue, SSAValue] = {} + new_ops: list[Operation] = [] + for output in op.arguments: + for i, arg in enumerate(donated_inputs): + if arg.type == output.type: + new_ops.append( + MaterializeInDestinationOp( + operands=[output, donated_inputs.pop(i)], + result_types=[output.type], + ) + ) + value_mapper[output] = new_ops[-1].results[0] + break + + new_ops.append(op.clone(value_mapper)) + rewriter.replace_matched_op(new_ops) + + +@dataclass(frozen=True) +class JaxUseDonatedArguments(ModulePass): + name = "jax-use-donated-arguments" + + def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: + the_one_pass = PatternRewriteWalker( + GreedyRewritePatternApplier([SubstituteDonatedTensors()]), + apply_recursively=False, + walk_reverse=True, + walk_regions_first=True, + ) + the_one_pass.rewrite_module(op) From 63831fe7fe407c928d314db9947138036963c2ee Mon Sep 17 00:00:00 2001 From: Max Manainen Date: Thu, 19 Dec 2024 14:22:07 +0100 Subject: [PATCH 2/7] init --- .../frontend/jax-snitch-compilation.py | 88 +++++++++++++++++++ 1 file changed, 88 insertions(+) create mode 100644 tests/filecheck/frontend/jax-snitch-compilation.py diff --git a/tests/filecheck/frontend/jax-snitch-compilation.py b/tests/filecheck/frontend/jax-snitch-compilation.py new file mode 100644 index 0000000000..c0febf9907 --- /dev/null +++ b/tests/filecheck/frontend/jax-snitch-compilation.py @@ -0,0 +1,88 @@ +# RUN: python %s | filecheck %s + +import jax +import jax.numpy as jnp +from jax._src.interpreters import mlir +from jaxlib.mlir.dialects import stablehlo +from jaxlib.mlir.ir import Context, Module +from jaxlib.mlir.passmanager import PassManager + + +def get_linalg_str(func_jit, args): + lowered = func_jit.lower(*args) + module = lowered.compiler_ir(dialect="stablehlo") + module_str = str(module) + + with Context() as ctx: + ctx.append_dialect_registry(mlir.upstream_dialects) + stablehlo.register_dialect(ctx) + stablehlo.register_stablehlo_passes() + + module = Module.parse(module_str) + + pm = PassManager.parse( + "builtin.module(func.func(" + "shape-legalize-to-stablehlo," + "stablehlo-aggressive-folder," + "stablehlo-aggressive-simplification," + "stablehlo-legalize-to-linalg" + "))" + ) + + pm.run(module.operation) + + return str(module) + + +key = jax.random.key(42) + + +def scale(x: jnp.ndarray, alpha: float): + return x * alpha + + +# print(get_linalg_str(jax.jit(scale), (jax.random.uniform(key, [10]), 0.1))) + +original = """ +#map = affine_map<(d0) -> ()> +#map1 = affine_map<(d0) -> (d0)> +module @jit_scale attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<10xf32>, %arg1: tensor) -> (tensor<10xf32> {jax.result_info = ""}) { + %0 = tensor.empty() : tensor<10xf32> + %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel"]} ins(%arg1 : tensor) outs(%0 : tensor<10xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor<10xf32> + %2 = tensor.empty() : tensor<10xf32> + %3 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel"]} ins(%arg0, %1 : tensor<10xf32>, tensor<10xf32>) outs(%2 : tensor<10xf32>) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %4 = arith.mulf %in, %in_0 : f32 + linalg.yield %4 : f32 + } -> tensor<10xf32> + return %3 : tensor<10xf32> + } +} +""" + +changed = """ +#map = affine_map<(d0) -> ()> +#map1 = affine_map<(d0) -> (d0)> +module attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<10xf32>, %arg1: tensor) -> tensor<10xf32> { + %0 = tensor.empty() : tensor<10xf32> + %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel"]} ins(%arg1 : tensor) outs(%0 : tensor<10xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor<10xf32> + %2 = tensor.empty() : tensor<10xf32> + %3 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel"]} ins(%arg0, %1 : tensor<10xf32>, tensor<10xf32>) outs(%2 : tensor<10xf32>) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %4 = arith.mulf %in, %in_0 : f32 + linalg.yield %4 : f32 + } -> tensor<10xf32> + return %3 : tensor<10xf32> + } +} +""" + +print(changed) From a44ab3662a42324248317ae79f5fe24483935728 Mon Sep 17 00:00:00 2001 From: Max Manainen Date: Mon, 23 Dec 2024 01:45:47 +0100 Subject: [PATCH 3/7] kinda works --- .../frontend/jax-snitch-compilation.py | 57 ++++++------------- 1 file changed, 16 insertions(+), 41 deletions(-) diff --git a/tests/filecheck/frontend/jax-snitch-compilation.py b/tests/filecheck/frontend/jax-snitch-compilation.py index c0febf9907..01f5978c3d 100644 --- a/tests/filecheck/frontend/jax-snitch-compilation.py +++ b/tests/filecheck/frontend/jax-snitch-compilation.py @@ -7,6 +7,8 @@ from jaxlib.mlir.ir import Context, Module from jaxlib.mlir.passmanager import PassManager +jax.config.update("jax_enable_x64", True) + def get_linalg_str(func_jit, args): lowered = func_jit.lower(*args) @@ -37,50 +39,23 @@ def get_linalg_str(func_jit, args): key = jax.random.key(42) -def scale(x: jnp.ndarray, alpha: float): - return x * alpha - - -# print(get_linalg_str(jax.jit(scale), (jax.random.uniform(key, [10]), 0.1))) - -original = """ -#map = affine_map<(d0) -> ()> -#map1 = affine_map<(d0) -> (d0)> -module @jit_scale attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<10xf32>, %arg1: tensor) -> (tensor<10xf32> {jax.result_info = ""}) { - %0 = tensor.empty() : tensor<10xf32> - %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel"]} ins(%arg1 : tensor) outs(%0 : tensor<10xf32>) { - ^bb0(%in: f32, %out: f32): - linalg.yield %in : f32 - } -> tensor<10xf32> - %2 = tensor.empty() : tensor<10xf32> - %3 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel"]} ins(%arg0, %1 : tensor<10xf32>, tensor<10xf32>) outs(%2 : tensor<10xf32>) { - ^bb0(%in: f32, %in_0: f32, %out: f32): - %4 = arith.mulf %in, %in_0 : f32 - linalg.yield %4 : f32 - } -> tensor<10xf32> - return %3 : tensor<10xf32> - } -} -""" +def matadd(A: jnp.ndarray, B: jnp.ndarray, C: jnp.ndarray): + return A + B + + +# print(get_linalg_str(jax.jit(matadd, donate_argnames="C", keep_unused=True), (jax.random.uniform(key, [8, 16], dtype=np.float64), jax.random.uniform(key, [8, 16], dtype=np.float64), jax.random.uniform(key, [8, 16], dtype=np.float64)))) changed = """ -#map = affine_map<(d0) -> ()> -#map1 = affine_map<(d0) -> (d0)> +#map = affine_map<(d0, d1) -> (d0, d1)> module attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<10xf32>, %arg1: tensor) -> tensor<10xf32> { - %0 = tensor.empty() : tensor<10xf32> - %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel"]} ins(%arg1 : tensor) outs(%0 : tensor<10xf32>) { - ^bb0(%in: f32, %out: f32): - linalg.yield %in : f32 - } -> tensor<10xf32> - %2 = tensor.empty() : tensor<10xf32> - %3 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel"]} ins(%arg0, %1 : tensor<10xf32>, tensor<10xf32>) outs(%2 : tensor<10xf32>) { - ^bb0(%in: f32, %in_0: f32, %out: f32): - %4 = arith.mulf %in, %in_0 : f32 - linalg.yield %4 : f32 - } -> tensor<10xf32> - return %3 : tensor<10xf32> + func.func public @main(%arg0: tensor<8x16xf64>, %arg1: tensor<8x16xf64>, %arg2: tensor<8x16xf64> {tf.aliasing_output = 0 : i32}) -> tensor<8x16xf64> { + %0 = tensor.empty() : tensor<8x16xf64> + %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<8x16xf64>, tensor<8x16xf64>) outs(%0 : tensor<8x16xf64>) { + ^bb0(%in: f64, %in_0: f64, %out: f64): + %2 = arith.addf %in, %in_0 : f64 + linalg.yield %2 : f64 + } -> tensor<8x16xf64> + return %1 : tensor<8x16xf64> } } """ From dec5b773c6fa8fc862e390fa74dfe806fa1f1ea3 Mon Sep 17 00:00:00 2001 From: Max Manainen Date: Mon, 23 Dec 2024 21:00:48 +0100 Subject: [PATCH 4/7] more examples --- .../frontend/jax-snitch-compilation.py | 79 ++++++++++++++++--- 1 file changed, 68 insertions(+), 11 deletions(-) diff --git a/tests/filecheck/frontend/jax-snitch-compilation.py b/tests/filecheck/frontend/jax-snitch-compilation.py index 01f5978c3d..ff9b16c6c1 100644 --- a/tests/filecheck/frontend/jax-snitch-compilation.py +++ b/tests/filecheck/frontend/jax-snitch-compilation.py @@ -2,6 +2,8 @@ import jax import jax.numpy as jnp +import numpy as np +from jax import lax from jax._src.interpreters import mlir from jaxlib.mlir.dialects import stablehlo from jaxlib.mlir.ir import Context, Module @@ -26,8 +28,8 @@ def get_linalg_str(func_jit, args): "builtin.module(func.func(" "shape-legalize-to-stablehlo," "stablehlo-aggressive-folder," - "stablehlo-aggressive-simplification," - "stablehlo-legalize-to-linalg" + "stablehlo-aggressive-simplification" + # "stablehlo-legalize-to-linalg" "))" ) @@ -43,21 +45,76 @@ def matadd(A: jnp.ndarray, B: jnp.ndarray, C: jnp.ndarray): return A + B +# breaks because of memref +def dot(x: jnp.ndarray, y: jnp.ndarray): + return jnp.dot(x, y) + + +def matmul(A: jnp.ndarray, B: jnp.ndarray, C: jnp.ndarray): + return A @ B + + +def relu(A: jnp.ndarray, B: jnp.ndarray): + return jnp.maximum(A, 0) + + +# breaks because of memref +def fill(val: np.float64, A: jnp.ndarray): + return jnp.full(A.shape, val) + + +def conv(A: jnp.ndarray, B: jnp.ndarray, C: jnp.ndarray): + return lax.conv_general_dilated(A, B, (1, 1), "VALID") + + # print(get_linalg_str(jax.jit(matadd, donate_argnames="C", keep_unused=True), (jax.random.uniform(key, [8, 16], dtype=np.float64), jax.random.uniform(key, [8, 16], dtype=np.float64), jax.random.uniform(key, [8, 16], dtype=np.float64)))) +# print(get_linalg_str(jax.jit(dot), (jax.random.uniform(key, [100], dtype=np.float64), jax.random.uniform(key, [100], dtype=np.float64)))) + +# print(get_linalg_str(jax.jit(matmul, donate_argnames="C", keep_unused=True), (jax.random.uniform(key, [8, 8], dtype=np.float64), jax.random.uniform(key, [8, 8], dtype=np.float64), jax.random.uniform(key, [8, 8], dtype=np.float64)))) + +# print(get_linalg_str(jax.jit(relu, donate_argnames="B", keep_unused=True), (jax.random.uniform(key, [16, 16], dtype=np.float64), jax.random.uniform(key, [16, 16], dtype=np.float64)))) + +# print(get_linalg_str(jax.jit(fill, donate_argnames="A", keep_unused=True), (150., jax.random.uniform(key, [16, 16], dtype=np.float64)))) + +print( + get_linalg_str( + jax.jit(conv, donate_argnames="C", keep_unused=True), + ( + jax.random.uniform(key, [1, 1, 10, 10], dtype=np.float64), + jax.random.uniform(key, [1, 1, 3, 3], dtype=np.float64), + jax.random.uniform(key, [1, 1, 8, 8], dtype=np.float64), + ), + ) +) + changed = """ -#map = affine_map<(d0, d1) -> (d0, d1)> +#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d6, d0, d2 + d3, d4 + d5)> +#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d0, d3, d5)> +#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d6, d1, d2, d4)> module attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<8x16xf64>, %arg1: tensor<8x16xf64>, %arg2: tensor<8x16xf64> {tf.aliasing_output = 0 : i32}) -> tensor<8x16xf64> { - %0 = tensor.empty() : tensor<8x16xf64> - %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<8x16xf64>, tensor<8x16xf64>) outs(%0 : tensor<8x16xf64>) { + func.func public @main(%arg0: tensor<1x1x10x10xf64>, %arg1: tensor<1x1x3x3xf64>, %arg2: tensor<1x1x8x8xf64> {tf.aliasing_output = 0 : i32}) -> tensor<1x1x8x8xf64> { + %0 = tensor.empty() : tensor<1x1x3x3xf64> + %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg1 : tensor<1x1x3x3xf64>) outs(%0 : tensor<1x1x3x3xf64>) { + ^bb0(%in: f64, %out: f64): + linalg.yield %in : f64 + } -> tensor<1x1x3x3xf64> + %2 = tensor.empty() : tensor<1x1x8x8xf64> + %cst = arith.constant 0.000000e+00 : f64 + %3 = linalg.fill ins(%cst : f64) outs(%2 : tensor<1x1x8x8xf64>) -> tensor<1x1x8x8xf64> + %4 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["reduction", "parallel", "parallel", "reduction", "parallel", "reduction", "parallel"]} ins(%arg0, %1 : tensor<1x1x10x10xf64>, tensor<1x1x3x3xf64>) outs(%3 : tensor<1x1x8x8xf64>) { ^bb0(%in: f64, %in_0: f64, %out: f64): - %2 = arith.addf %in, %in_0 : f64 - linalg.yield %2 : f64 - } -> tensor<8x16xf64> - return %1 : tensor<8x16xf64> + %5 = arith.mulf %in, %in_0 : f64 + %6 = arith.addf %out, %5 : f64 + linalg.yield %6 : f64 + } -> tensor<1x1x8x8xf64> + %collapsed = tensor.collapse_shape %4 [[0, 1, 2, 3]] : tensor<1x1x8x8xf64> into tensor<64xf64> + %cast = tensor.cast %collapsed : tensor<64xf64> to tensor<64xf64> + %expanded = tensor.expand_shape %cast [[0, 1, 2, 3]] output_shape [1, 1, 8, 8] : tensor<64xf64> into tensor<1x1x8x8xf64> + return %expanded : tensor<1x1x8x8xf64> } } """ -print(changed) +# print(changed) From 2d04f8f162c2431a6275ec876b5e549811ccee1e Mon Sep 17 00:00:00 2001 From: Max Manainen Date: Tue, 24 Dec 2024 15:45:42 +0100 Subject: [PATCH 5/7] start --- .../frontend/jax-snitch-compilation.py | 78 +++++++------------ 1 file changed, 26 insertions(+), 52 deletions(-) diff --git a/tests/filecheck/frontend/jax-snitch-compilation.py b/tests/filecheck/frontend/jax-snitch-compilation.py index ff9b16c6c1..38ed33e910 100644 --- a/tests/filecheck/frontend/jax-snitch-compilation.py +++ b/tests/filecheck/frontend/jax-snitch-compilation.py @@ -28,8 +28,8 @@ def get_linalg_str(func_jit, args): "builtin.module(func.func(" "shape-legalize-to-stablehlo," "stablehlo-aggressive-folder," - "stablehlo-aggressive-simplification" - # "stablehlo-legalize-to-linalg" + "stablehlo-aggressive-simplification," + "stablehlo-legalize-to-linalg" "))" ) @@ -45,76 +45,50 @@ def matadd(A: jnp.ndarray, B: jnp.ndarray, C: jnp.ndarray): return A + B -# breaks because of memref -def dot(x: jnp.ndarray, y: jnp.ndarray): - return jnp.dot(x, y) +# print(get_linalg_str(jax.jit(matadd, donate_argnames="C", keep_unused=True), (jax.random.uniform(key, [8, 16], dtype=np.float64), jax.random.uniform(key, [8, 16], dtype=np.float64), jax.random.uniform(key, [8, 16], dtype=np.float64)))) def matmul(A: jnp.ndarray, B: jnp.ndarray, C: jnp.ndarray): return A @ B +# print(get_linalg_str(jax.jit(matmul, donate_argnames="C", keep_unused=True), (jax.random.uniform(key, [8, 8], dtype=np.float64), jax.random.uniform(key, [8, 8], dtype=np.float64), jax.random.uniform(key, [8, 8], dtype=np.float64)))) + + def relu(A: jnp.ndarray, B: jnp.ndarray): return jnp.maximum(A, 0) +# print(get_linalg_str(jax.jit(relu, donate_argnames="B", keep_unused=True), (jax.random.uniform(key, [16, 16], dtype=np.float64), jax.random.uniform(key, [16, 16], dtype=np.float64)))) + + +# breaks because of memref +def dot(x: jnp.ndarray, y: jnp.ndarray): + return jnp.dot(x, y) + + +# print(get_linalg_str(jax.jit(dot), (jax.random.uniform(key, [100], dtype=np.float64), jax.random.uniform(key, [100], dtype=np.float64)))) + + # breaks because of memref def fill(val: np.float64, A: jnp.ndarray): return jnp.full(A.shape, val) +# print(get_linalg_str(jax.jit(fill, donate_argnames="A", keep_unused=True), (150., jax.random.uniform(key, [16, 16], dtype=np.float64)))) + + +# a weird copy is inserted def conv(A: jnp.ndarray, B: jnp.ndarray, C: jnp.ndarray): return lax.conv_general_dilated(A, B, (1, 1), "VALID") -# print(get_linalg_str(jax.jit(matadd, donate_argnames="C", keep_unused=True), (jax.random.uniform(key, [8, 16], dtype=np.float64), jax.random.uniform(key, [8, 16], dtype=np.float64), jax.random.uniform(key, [8, 16], dtype=np.float64)))) - -# print(get_linalg_str(jax.jit(dot), (jax.random.uniform(key, [100], dtype=np.float64), jax.random.uniform(key, [100], dtype=np.float64)))) +# print(get_linalg_str(jax.jit(conv, donate_argnames="C", keep_unused=True), (jax.random.uniform(key, [1, 1, 10, 10], dtype=np.float64),jax.random.uniform(key, [1, 1, 3, 3], dtype=np.float64),jax.random.uniform(key, [1, 1, 8, 8], dtype=np.float64),),)) -# print(get_linalg_str(jax.jit(matmul, donate_argnames="C", keep_unused=True), (jax.random.uniform(key, [8, 8], dtype=np.float64), jax.random.uniform(key, [8, 8], dtype=np.float64), jax.random.uniform(key, [8, 8], dtype=np.float64)))) -# print(get_linalg_str(jax.jit(relu, donate_argnames="B", keep_unused=True), (jax.random.uniform(key, [16, 16], dtype=np.float64), jax.random.uniform(key, [16, 16], dtype=np.float64)))) +# one of the reduction dimensions is f32 => it can't be streamed and it breaks +def max_pool(A: jnp.ndarray, B: jnp.ndarray): + return lax.reduce_window(A, -10000.0, lax.max, [1, 1, 3, 3], [1, 1, 2, 2], "VALID") -# print(get_linalg_str(jax.jit(fill, donate_argnames="A", keep_unused=True), (150., jax.random.uniform(key, [16, 16], dtype=np.float64)))) -print( - get_linalg_str( - jax.jit(conv, donate_argnames="C", keep_unused=True), - ( - jax.random.uniform(key, [1, 1, 10, 10], dtype=np.float64), - jax.random.uniform(key, [1, 1, 3, 3], dtype=np.float64), - jax.random.uniform(key, [1, 1, 8, 8], dtype=np.float64), - ), - ) -) - -changed = """ -#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d6, d0, d2 + d3, d4 + d5)> -#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d0, d3, d5)> -#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d6, d1, d2, d4)> -module attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<1x1x10x10xf64>, %arg1: tensor<1x1x3x3xf64>, %arg2: tensor<1x1x8x8xf64> {tf.aliasing_output = 0 : i32}) -> tensor<1x1x8x8xf64> { - %0 = tensor.empty() : tensor<1x1x3x3xf64> - %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg1 : tensor<1x1x3x3xf64>) outs(%0 : tensor<1x1x3x3xf64>) { - ^bb0(%in: f64, %out: f64): - linalg.yield %in : f64 - } -> tensor<1x1x3x3xf64> - %2 = tensor.empty() : tensor<1x1x8x8xf64> - %cst = arith.constant 0.000000e+00 : f64 - %3 = linalg.fill ins(%cst : f64) outs(%2 : tensor<1x1x8x8xf64>) -> tensor<1x1x8x8xf64> - %4 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["reduction", "parallel", "parallel", "reduction", "parallel", "reduction", "parallel"]} ins(%arg0, %1 : tensor<1x1x10x10xf64>, tensor<1x1x3x3xf64>) outs(%3 : tensor<1x1x8x8xf64>) { - ^bb0(%in: f64, %in_0: f64, %out: f64): - %5 = arith.mulf %in, %in_0 : f64 - %6 = arith.addf %out, %5 : f64 - linalg.yield %6 : f64 - } -> tensor<1x1x8x8xf64> - %collapsed = tensor.collapse_shape %4 [[0, 1, 2, 3]] : tensor<1x1x8x8xf64> into tensor<64xf64> - %cast = tensor.cast %collapsed : tensor<64xf64> to tensor<64xf64> - %expanded = tensor.expand_shape %cast [[0, 1, 2, 3]] output_shape [1, 1, 8, 8] : tensor<64xf64> into tensor<1x1x8x8xf64> - return %expanded : tensor<1x1x8x8xf64> - } -} -""" - -# print(changed) +# print(get_linalg_str(jax.jit(max_pool, donate_argnames="B", keep_unused=True), (jax.random.uniform(key, [1, 1, 18, 18], dtype=np.float64), jax.random.uniform(key, [1, 1, 8, 8], dtype=np.float64)))) From d00ee3b46e129966682efa20f6943613d3afd460 Mon Sep 17 00:00:00 2001 From: Max Manainen Date: Tue, 24 Dec 2024 16:29:41 +0100 Subject: [PATCH 6/7] init test --- .../frontend/jax-snitch-compilation.py | 168 +++++++++++++++++- xdsl/dialects/builtin.py | 1 + xdsl/dialects/utils/format.py | 4 +- 3 files changed, 163 insertions(+), 10 deletions(-) diff --git a/tests/filecheck/frontend/jax-snitch-compilation.py b/tests/filecheck/frontend/jax-snitch-compilation.py index 38ed33e910..e6956cfd52 100644 --- a/tests/filecheck/frontend/jax-snitch-compilation.py +++ b/tests/filecheck/frontend/jax-snitch-compilation.py @@ -1,4 +1,6 @@ -# RUN: python %s | filecheck %s +# RUN: python %s | mlir-opt --split-input-file --allow-unregistered-dialect --linalg-generalize-named-ops | xdsl-opt --split-input-file -p jax-use-donated-arguments | \ +# RUN: mlir-opt --split-input-file --allow-unregistered-dialect --eliminate-empty-tensors --one-shot-bufferize="bufferize-function-boundaries function-boundary-type-conversion=identity-layout-map" --canonicalize --mlir-print-op-generic | \ +# RUN: xdsl-opt --split-input-file -p test-lower-linalg-to-snitch -t riscv-asm | filecheck %s import jax import jax.numpy as jnp @@ -41,25 +43,175 @@ def get_linalg_str(func_jit, args): key = jax.random.key(42) +# CHECK: .text +# CHECK-NEXT: .globl main +# CHECK-NEXT: .p2align 2 +# CHECK-NEXT: # Regalloc stats: {"preallocated_float": ["ft0", "ft1", "ft2"], "preallocated_int": ["a0", "a1", "a2", "zero"], "allocated_float": ["ft0", "ft1", "ft2"], "allocated_int": ["a0", "a1", "a2", "t0", "t1", "t2", "t3", "zero"]} +# CHECK-NEXT: main: +# CHECK-NEXT: mv t2, a0 +# CHECK-NEXT: mv t1, a1 +# CHECK-NEXT: mv t0, a2 +# CHECK-NEXT: li t3, 127 +# CHECK-NEXT: scfgwi t3, 95 # dm 31 dim 0 bound +# CHECK-NEXT: li t3, 8 +# CHECK-NEXT: scfgwi t3, 223 # dm 31 dim 0 stride +# CHECK-NEXT: scfgwi zero, 63 # dm 31 repeat +# CHECK-NEXT: scfgwi t2, 768 # dm 0 dim 0 source +# CHECK-NEXT: scfgwi t1, 769 # dm 1 dim 0 source +# CHECK-NEXT: scfgwi t0, 898 # dm 2 dim 0 destination +# CHECK-NEXT: csrrsi zero, 1984, 1 # SSR enable +# CHECK-NEXT: li t1, 127 +# CHECK-NEXT: frep.o t1, 1, 0, 0 +# CHECK-NEXT: fadd.d ft2, ft0, ft1 +# CHECK-NEXT: csrrci zero, 1984, 1 # SSR disable +# CHECK-NEXT: mv a0, t0 +# CHECK-NEXT: ret def matadd(A: jnp.ndarray, B: jnp.ndarray, C: jnp.ndarray): return A + B -# print(get_linalg_str(jax.jit(matadd, donate_argnames="C", keep_unused=True), (jax.random.uniform(key, [8, 16], dtype=np.float64), jax.random.uniform(key, [8, 16], dtype=np.float64), jax.random.uniform(key, [8, 16], dtype=np.float64)))) - - +print( + get_linalg_str( + jax.jit(matadd, donate_argnames="C", keep_unused=True), + ( + jax.random.uniform(key, [8, 16], dtype=np.float64), + jax.random.uniform(key, [8, 16], dtype=np.float64), + jax.random.uniform(key, [8, 16], dtype=np.float64), + ), + ) +) +print("// -----") + + +# CHECK: .text +# CHECK-NEXT: .globl main +# CHECK-NEXT: .p2align 2 +# CHECK-NEXT: # Regalloc stats: {{.*}} +# CHECK-NEXT: main: +# CHECK-NEXT: mv t2, a0 +# CHECK-NEXT: mv t3, a1 +# CHECK-NEXT: mv t0, a2 +# CHECK-NEXT: fcvt.d.w ft3, zero +# CHECK-NEXT: li t1, 7 +# CHECK-NEXT: scfgwi t1, 64 # dm 0 dim 0 bound +# CHECK-NEXT: li t1, 1 +# CHECK-NEXT: scfgwi t1, 96 # dm 0 dim 1 bound +# CHECK-NEXT: li t1, 7 +# CHECK-NEXT: scfgwi t1, 128 # dm 0 dim 2 bound +# CHECK-NEXT: li t1, 8 +# CHECK-NEXT: scfgwi t1, 192 # dm 0 dim 0 stride +# CHECK-NEXT: li t1, -56 +# CHECK-NEXT: scfgwi t1, 224 # dm 0 dim 1 stride +# CHECK-NEXT: li t1, 8 +# CHECK-NEXT: scfgwi t1, 256 # dm 0 dim 2 stride +# CHECK-NEXT: li t1, 3 +# CHECK-NEXT: scfgwi t1, 32 # dm 0 repeat +# CHECK-NEXT: li t1, 3 +# CHECK-NEXT: scfgwi t1, 65 # dm 1 dim 0 bound +# CHECK-NEXT: li t1, 7 +# CHECK-NEXT: scfgwi t1, 97 # dm 1 dim 1 bound +# CHECK-NEXT: li t1, 1 +# CHECK-NEXT: scfgwi t1, 129 # dm 1 dim 2 bound +# CHECK-NEXT: li t1, 7 +# CHECK-NEXT: scfgwi t1, 161 # dm 1 dim 3 bound +# CHECK-NEXT: li t1, 8 +# CHECK-NEXT: scfgwi t1, 193 # dm 1 dim 0 stride +# CHECK-NEXT: li t1, 40 +# CHECK-NEXT: scfgwi t1, 225 # dm 1 dim 1 stride +# CHECK-NEXT: li t1, -440 +# CHECK-NEXT: scfgwi t1, 257 # dm 1 dim 2 stride +# CHECK-NEXT: li t1, -504 +# CHECK-NEXT: scfgwi t1, 289 # dm 1 dim 3 stride +# CHECK-NEXT: scfgwi zero, 33 # dm 1 repeat +# CHECK-NEXT: li t1, 63 +# CHECK-NEXT: scfgwi t1, 66 # dm 2 dim 0 bound +# CHECK-NEXT: li t1, 8 +# CHECK-NEXT: scfgwi t1, 194 # dm 2 dim 0 stride +# CHECK-NEXT: scfgwi zero, 34 # dm 2 repeat +# CHECK-NEXT: scfgwi t2, 832 # dm 0 dim 2 source +# CHECK-NEXT: scfgwi t3, 865 # dm 1 dim 3 source +# CHECK-NEXT: scfgwi t0, 898 # dm 2 dim 0 destination +# CHECK-NEXT: csrrsi zero, 1984, 1 # SSR enable +# CHECK-NEXT: li t2, 16 +# CHECK-NEXT: mv t1, zero +# CHECK-NEXT: # Constant folded riscv_cf.bge +# CHECK-NEXT: scf_body_0_for: +# CHECK-NEXT: fmv.d ft7, ft3 +# CHECK-NEXT: fmv.d ft6, ft3 +# CHECK-NEXT: fmv.d ft5, ft3 +# CHECK-NEXT: fmv.d ft4, ft3 +# CHECK-NEXT: li t4, 7 +# CHECK-NEXT: frep.o t4, 8, 0, 0 +# CHECK-NEXT: fmul.d ft11, ft0, ft1 +# CHECK-NEXT: fmul.d ft10, ft0, ft1 +# CHECK-NEXT: fmul.d ft9, ft0, ft1 +# CHECK-NEXT: fmul.d ft8, ft0, ft1 +# CHECK-NEXT: fadd.d ft7, ft7, ft11 +# CHECK-NEXT: fadd.d ft6, ft6, ft10 +# CHECK-NEXT: fadd.d ft5, ft5, ft9 +# CHECK-NEXT: fadd.d ft4, ft4, ft8 +# CHECK-NEXT: fmv.d ft2, ft7 +# CHECK-NEXT: fmv.d ft2, ft6 +# CHECK-NEXT: fmv.d ft2, ft5 +# CHECK-NEXT: fmv.d ft2, ft4 +# CHECK-NEXT: addi t1, t1, 1 +# CHECK-NEXT: blt t1, t2, scf_body_0_for +# CHECK-NEXT: scf_body_end_0_for: +# CHECK-NEXT: csrrci zero, 1984, 1 # SSR disable +# CHECK-NEXT: mv a0, t0 +# CHECK-NEXT: ret def matmul(A: jnp.ndarray, B: jnp.ndarray, C: jnp.ndarray): return A @ B -# print(get_linalg_str(jax.jit(matmul, donate_argnames="C", keep_unused=True), (jax.random.uniform(key, [8, 8], dtype=np.float64), jax.random.uniform(key, [8, 8], dtype=np.float64), jax.random.uniform(key, [8, 8], dtype=np.float64)))) - - +print( + get_linalg_str( + jax.jit(matmul, donate_argnames="C", keep_unused=True), + ( + jax.random.uniform(key, [8, 8], dtype=np.float64), + jax.random.uniform(key, [8, 8], dtype=np.float64), + jax.random.uniform(key, [8, 8], dtype=np.float64), + ), + ) +) +print("// -----") + + +# CHECK: .text +# CHECK-NEXT: .globl main +# CHECK-NEXT: .p2align 2 +# CHECK-NEXT: # Regalloc stats: {"preallocated_float": ["ft0", "ft1", "ft2"], "preallocated_int": ["a0", "a1", "zero"], "allocated_float": ["ft0", "ft1", "ft3"], "allocated_int": ["a0", "a1", "t0", "t1", "t2", "zero"]} +# CHECK-NEXT: main: +# CHECK-NEXT: mv t1, a0 +# CHECK-NEXT: mv t0, a1 +# CHECK-NEXT: fcvt.d.w ft3, zero +# CHECK-NEXT: li t2, 255 +# CHECK-NEXT: scfgwi t2, 95 # dm 31 dim 0 bound +# CHECK-NEXT: li t2, 8 +# CHECK-NEXT: scfgwi t2, 223 # dm 31 dim 0 stride +# CHECK-NEXT: scfgwi zero, 63 # dm 31 repeat +# CHECK-NEXT: scfgwi t1, 768 # dm 0 dim 0 source +# CHECK-NEXT: scfgwi t0, 897 # dm 1 dim 0 destination +# CHECK-NEXT: csrrsi zero, 1984, 1 # SSR enable +# CHECK-NEXT: li t1, 255 +# CHECK-NEXT: frep.o t1, 1, 0, 0 +# CHECK-NEXT: fmax.d ft1, ft0, ft3 +# CHECK-NEXT: csrrci zero, 1984, 1 # SSR disable +# CHECK-NEXT: mv a0, t0 +# CHECK-NEXT: ret def relu(A: jnp.ndarray, B: jnp.ndarray): return jnp.maximum(A, 0) -# print(get_linalg_str(jax.jit(relu, donate_argnames="B", keep_unused=True), (jax.random.uniform(key, [16, 16], dtype=np.float64), jax.random.uniform(key, [16, 16], dtype=np.float64)))) +print( + get_linalg_str( + jax.jit(relu, donate_argnames="B", keep_unused=True), + ( + jax.random.uniform(key, [16, 16], dtype=np.float64), + jax.random.uniform(key, [16, 16], dtype=np.float64), + ), + ) +) # breaks because of memref diff --git a/xdsl/dialects/builtin.py b/xdsl/dialects/builtin.py index e4bf295f56..23532a63c2 100644 --- a/xdsl/dialects/builtin.py +++ b/xdsl/dialects/builtin.py @@ -1658,6 +1658,7 @@ def ops(self) -> BlockOps: @classmethod def parse(cls, parser: Parser) -> ModuleOp: + _ = parser.parse_optional_symbol_name() attributes = parser.parse_optional_attr_dict_with_keyword() if attributes is not None: attributes = attributes.data diff --git a/xdsl/dialects/utils/format.py b/xdsl/dialects/utils/format.py index 3b0a5c91c2..92364e5631 100644 --- a/xdsl/dialects/utils/format.py +++ b/xdsl/dialects/utils/format.py @@ -61,7 +61,7 @@ def print_func_op_like( printer.print(") ") if function_type.outputs: printer.print("-> ") - if len(function_type.outputs) > 1: + if len(function_type.outputs) > 1 or res_attrs is not None: printer.print("(") if res_attrs is not None: printer.print_list( @@ -72,7 +72,7 @@ def print_func_op_like( ) else: printer.print_list(function_type.outputs, printer.print_attribute) - if len(function_type.outputs) > 1: + if len(function_type.outputs) > 1 or res_attrs is not None: printer.print(")") printer.print(" ") else: From c682d058eb9ffbb255d29ded9bbf14b460488136 Mon Sep 17 00:00:00 2001 From: Max Manainen Date: Wed, 25 Dec 2024 20:58:28 +0100 Subject: [PATCH 7/7] remove outputs --- ...-use-donated-arguments-remove-outputs.mlir | 42 +++++++++++++++++++ xdsl/transforms/jax_use_donated_arguments.py | 37 +++++++++++++--- 2 files changed, 73 insertions(+), 6 deletions(-) create mode 100644 tests/filecheck/transforms/jax-use-donated-arguments-remove-outputs.mlir diff --git a/tests/filecheck/transforms/jax-use-donated-arguments-remove-outputs.mlir b/tests/filecheck/transforms/jax-use-donated-arguments-remove-outputs.mlir new file mode 100644 index 0000000000..71ae615892 --- /dev/null +++ b/tests/filecheck/transforms/jax-use-donated-arguments-remove-outputs.mlir @@ -0,0 +1,42 @@ +// RUN: xdsl-opt %s -p jax-use-donated-arguments{remove_matched_outputs=true} --split-input-file --verify-diagnostics | filecheck %s + +func.func public @one_donation(%arg0: tensor<2x3xf32>, %arg1: tensor<3x4xf32>, %arg2: tensor<2x4xf32> {tf.aliasing_output = 0 : i32}) -> (tensor<2x4xf32>) { + %res = "test.op"() : () -> tensor<2x4xf32> + return %res : tensor<2x4xf32> +} + +// CHECK: func.func public @one_donation(%arg0 : tensor<2x3xf32>, %arg1 : tensor<3x4xf32>, %arg2 : tensor<2x4xf32> {"tf.aliasing_output" = 0 : i32}) { +// CHECK-NEXT: %res = "test.op"() : () -> tensor<2x4xf32> +// CHECK-NEXT: %0 = bufferization.materialize_in_destination %res in %arg2 : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK-NEXT: func.return +// CHECK-NEXT: } + +func.func public @same_type_donation(%arg0: tensor<2x3xf32> {tf.aliasing_output = 0 : i32}, %arg1: tensor<2x3xf32> {tf.aliasing_output = 0 : i32}) -> (tensor<2x3xf32>, tensor<2x3xf32>) { + %res1 = "test.op"() : () -> tensor<2x3xf32> + %res2 = "test.op"() : () -> tensor<2x3xf32> + return %res1, %res2 : tensor<2x3xf32>, tensor<2x3xf32> +} + +// CHECK: func.func public @same_type_donation(%arg0 : tensor<2x3xf32> {"tf.aliasing_output" = 0 : i32}, %arg1 : tensor<2x3xf32> {"tf.aliasing_output" = 0 : i32}) { +// CHECK-NEXT: %res1 = "test.op"() : () -> tensor<2x3xf32> +// CHECK-NEXT: %res2 = "test.op"() : () -> tensor<2x3xf32> +// CHECK-NEXT: %0 = bufferization.materialize_in_destination %res1 in %arg0 : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32> +// CHECK-NEXT: %1 = bufferization.materialize_in_destination %res2 in %arg1 : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32> +// CHECK-NEXT: func.return +// CHECK-NEXT: } + +func.func public @non_trivial_donation(%arg0: tensor<4x5xf32> {tf.aliasing_output = 0 : i32}, %arg1: tensor<2x3xf32> {tf.aliasing_output = 0 : i32}, %arg2: tensor<2x3xf32>) -> (tensor<2x3xf32>, tensor<2x3xf32>, tensor<4x5xf32>) { + %res1 = "test.op"() : () -> tensor<2x3xf32> + %res2 = "test.op"() : () -> tensor<2x3xf32> + %res3 = "test.op"() : () -> tensor<4x5xf32> + return %res1, %res2, %res3 : tensor<2x3xf32>, tensor<2x3xf32>, tensor<4x5xf32> +} + +// CHECK: func.func public @non_trivial_donation(%arg0 : tensor<4x5xf32> {"tf.aliasing_output" = 0 : i32}, %arg1 : tensor<2x3xf32> {"tf.aliasing_output" = 0 : i32}, %arg2 : tensor<2x3xf32>) -> tensor<2x3xf32> { +// CHECK-NEXT: %res1 = "test.op"() : () -> tensor<2x3xf32> +// CHECK-NEXT: %res2 = "test.op"() : () -> tensor<2x3xf32> +// CHECK-NEXT: %res3 = "test.op"() : () -> tensor<4x5xf32> +// CHECK-NEXT: %0 = bufferization.materialize_in_destination %res1 in %arg1 : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32> +// CHECK-NEXT: %1 = bufferization.materialize_in_destination %res3 in %arg0 : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32> +// CHECK-NEXT: func.return %res2 : tensor<2x3xf32> +// CHECK-NEXT: } diff --git a/xdsl/transforms/jax_use_donated_arguments.py b/xdsl/transforms/jax_use_donated_arguments.py index bb60d9b94b..fba21bccc7 100644 --- a/xdsl/transforms/jax_use_donated_arguments.py +++ b/xdsl/transforms/jax_use_donated_arguments.py @@ -3,7 +3,7 @@ from xdsl.context import MLContext from xdsl.dialects import builtin from xdsl.dialects.bufferization import MaterializeInDestinationOp -from xdsl.dialects.builtin import TensorType +from xdsl.dialects.builtin import Attribute, FunctionType, TensorType from xdsl.dialects.func import FuncOp, ReturnOp from xdsl.ir import Operation, SSAValue from xdsl.passes import ModulePass @@ -18,6 +18,8 @@ @dataclass class SubstituteDonatedTensors(RewritePattern): + remove_matched_outputs: bool = False + @op_type_rewrite_pattern def match_and_rewrite(self, op: ReturnOp, rewriter: PatternRewriter, /): func_op = op.parent_op() @@ -32,9 +34,12 @@ def match_and_rewrite(self, op: ReturnOp, rewriter: PatternRewriter, /): if isinstance(inp.type, TensorType) and "tf.aliasing_output" in attr.data ] - value_mapper: dict[SSAValue, SSAValue] = {} new_ops: list[Operation] = [] - for output in op.arguments: + new_outputs: list[SSAValue] = [] + matched_output_idxes: set[int] = set() + + for output_idx, output in enumerate(op.arguments): + final_output = output for i, arg in enumerate(donated_inputs): if arg.type == output.type: new_ops.append( @@ -43,10 +48,26 @@ def match_and_rewrite(self, op: ReturnOp, rewriter: PatternRewriter, /): result_types=[output.type], ) ) - value_mapper[output] = new_ops[-1].results[0] + final_output = new_ops[-1].results[0] + matched_output_idxes.add(output_idx) break + new_outputs.append(final_output) + + output_types = list(func_op.function_type.outputs.data) - new_ops.append(op.clone(value_mapper)) + if self.remove_matched_outputs: + new_outputs_trimmed: list[SSAValue] = [] + output_types_trimmed: list[Attribute] = [] + for i in range(len(new_outputs)): + if i not in matched_output_idxes: + new_outputs_trimmed.append(new_outputs[i]) + output_types_trimmed.append(output_types[i]) + new_outputs, output_types = new_outputs_trimmed, output_types_trimmed + + func_op.function_type = FunctionType.from_lists( + func_op.function_type.inputs.data, output_types + ) + new_ops.append(ReturnOp(*new_outputs)) rewriter.replace_matched_op(new_ops) @@ -54,9 +75,13 @@ def match_and_rewrite(self, op: ReturnOp, rewriter: PatternRewriter, /): class JaxUseDonatedArguments(ModulePass): name = "jax-use-donated-arguments" + remove_matched_outputs: bool = False + def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: the_one_pass = PatternRewriteWalker( - GreedyRewritePatternApplier([SubstituteDonatedTensors()]), + GreedyRewritePatternApplier( + [SubstituteDonatedTensors(self.remove_matched_outputs)] + ), apply_recursively=False, walk_reverse=True, walk_regions_first=True,