-
Notifications
You must be signed in to change notification settings - Fork 80
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
transformations: New linalg-fuse-multiply-add pass (#3347)
Introduces a new pass to fuse `linalg.mul` and `linalg.add` into `linalg.generic` with the functionality of an FMA op. --------- Co-authored-by: n-io <n-io@users.noreply.github.com>
- Loading branch information
Showing
3 changed files
with
189 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
// RUN: xdsl-opt %s -p linalg-fuse-multiply-add | filecheck %s
// RUN: xdsl-opt %s -p linalg-fuse-multiply-add{require_scalar_factor=true} | filecheck %s --check-prefix=SCALAR
// RUN: xdsl-opt %s -p linalg-fuse-multiply-add{require_erasable_mul=true} | filecheck %s --check-prefix=FOLD-MUL

// Input: two mul->add chains — %0/%2 multiplies two non-constant tensors,
// %1/%3 multiplies by the splat constant %c — plus a linalg.sub (%4) that
// keeps the result of %1 alive with a second, non-add user.
builtin.module {
  %t0, %t1, %t2, %t3 = "test.op"() : () -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  %c = arith.constant dense<2.997925e+08> : tensor<8xf32>
  %0 = linalg.mul ins(%t0, %t1 : tensor<8xf32>, tensor<8xf32>) outs(%t0 : tensor<8xf32>) -> tensor<8xf32>
  %1 = linalg.mul ins(%c, %t1 : tensor<8xf32>, tensor<8xf32>) outs(%c : tensor<8xf32>) -> tensor<8xf32>
  %2 = linalg.add ins(%0, %t2 : tensor<8xf32>, tensor<8xf32>) outs(%0 : tensor<8xf32>) -> tensor<8xf32>
  %3 = linalg.add ins(%1, %t3 : tensor<8xf32>, tensor<8xf32>) outs(%1 : tensor<8xf32>) -> tensor<8xf32>
  %4 = linalg.sub ins(%1, %t3 : tensor<8xf32>, tensor<8xf32>) outs(%1 : tensor<8xf32>) -> tensor<8xf32>
}

// Default options: both mul->add pairs are fused into linalg.generic FMA
// bodies (mulf + addf + yield). The constant mul survives (as %0 below)
// because the linalg.sub still consumes its result.
// CHECK-NEXT: builtin.module {
// CHECK-NEXT: %t0, %t1, %t2, %t3 = "test.op"() : () -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
// CHECK-NEXT: %c = arith.constant dense<2.997925e+08> : tensor<8xf32>
// CHECK-NEXT: %0 = linalg.mul ins(%c, %t1 : tensor<8xf32>, tensor<8xf32>) outs(%c : tensor<8xf32>) -> tensor<8xf32>
// CHECK-NEXT: %1 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%t0, %t1, %t2 : tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) outs(%t0 : tensor<8xf32>) {
// CHECK-NEXT: ^0(%2 : f32, %3 : f32, %4 : f32, %5 : f32):
// CHECK-NEXT: %6 = arith.mulf %2, %3 : f32
// CHECK-NEXT: %7 = arith.addf %6, %4 : f32
// CHECK-NEXT: linalg.yield %7 : f32
// CHECK-NEXT: } -> tensor<8xf32>
// CHECK-NEXT: %8 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%c, %t1, %t3 : tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) outs(%c : tensor<8xf32>) {
// CHECK-NEXT: ^1(%9 : f32, %10 : f32, %11 : f32, %12 : f32):
// CHECK-NEXT: %13 = arith.mulf %9, %10 : f32
// CHECK-NEXT: %14 = arith.addf %13, %11 : f32
// CHECK-NEXT: linalg.yield %14 : f32
// CHECK-NEXT: } -> tensor<8xf32>
// CHECK-NEXT: %15 = linalg.sub ins(%0, %t3 : tensor<8xf32>, tensor<8xf32>) outs(%0 : tensor<8xf32>) -> tensor<8xf32>
// CHECK-NEXT: }

// require_scalar_factor=true: only the chain multiplying by the splat
// constant %c is fused; the all-tensor mul/add pair (%0/%2) is untouched.
// SCALAR-NEXT: builtin.module {
// SCALAR-NEXT: %t0, %t1, %t2, %t3 = "test.op"() : () -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
// SCALAR-NEXT: %c = arith.constant dense<2.997925e+08> : tensor<8xf32>
// SCALAR-NEXT: %0 = linalg.mul ins(%t0, %t1 : tensor<8xf32>, tensor<8xf32>) outs(%t0 : tensor<8xf32>) -> tensor<8xf32>
// SCALAR-NEXT: %1 = linalg.mul ins(%c, %t1 : tensor<8xf32>, tensor<8xf32>) outs(%c : tensor<8xf32>) -> tensor<8xf32>
// SCALAR-NEXT: %2 = linalg.add ins(%0, %t2 : tensor<8xf32>, tensor<8xf32>) outs(%0 : tensor<8xf32>) -> tensor<8xf32>
// SCALAR-NEXT: %3 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%c, %t1, %t3 : tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) outs(%c : tensor<8xf32>) {
// SCALAR-NEXT: ^0(%4 : f32, %5 : f32, %6 : f32, %7 : f32):
// SCALAR-NEXT: %8 = arith.mulf %4, %5 : f32
// SCALAR-NEXT: %9 = arith.addf %8, %6 : f32
// SCALAR-NEXT: linalg.yield %9 : f32
// SCALAR-NEXT: } -> tensor<8xf32>
// SCALAR-NEXT: %10 = linalg.sub ins(%1, %t3 : tensor<8xf32>, tensor<8xf32>) outs(%1 : tensor<8xf32>) -> tensor<8xf32>
// SCALAR-NEXT: }

// require_erasable_mul=true: only the mul whose sole user is its add (%0) is
// fused — and erased. The constant mul has a second user (the sub), so its
// mul/add pair is left alone.
// FOLD-MUL-NEXT: builtin.module {
// FOLD-MUL-NEXT: %t0, %t1, %t2, %t3 = "test.op"() : () -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
// FOLD-MUL-NEXT: %c = arith.constant dense<2.997925e+08> : tensor<8xf32>
// FOLD-MUL-NEXT: %0 = linalg.mul ins(%c, %t1 : tensor<8xf32>, tensor<8xf32>) outs(%c : tensor<8xf32>) -> tensor<8xf32>
// FOLD-MUL-NEXT: %1 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%t0, %t1, %t2 : tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) outs(%t0 : tensor<8xf32>) {
// FOLD-MUL-NEXT: ^0(%2 : f32, %3 : f32, %4 : f32, %5 : f32):
// FOLD-MUL-NEXT: %6 = arith.mulf %2, %3 : f32
// FOLD-MUL-NEXT: %7 = arith.addf %6, %4 : f32
// FOLD-MUL-NEXT: linalg.yield %7 : f32
// FOLD-MUL-NEXT: } -> tensor<8xf32>
// FOLD-MUL-NEXT: %8 = linalg.add ins(%0, %t3 : tensor<8xf32>, tensor<8xf32>) outs(%0 : tensor<8xf32>) -> tensor<8xf32>
// FOLD-MUL-NEXT: %9 = linalg.sub ins(%0, %t3 : tensor<8xf32>, tensor<8xf32>) outs(%0 : tensor<8xf32>) -> tensor<8xf32>
// FOLD-MUL-NEXT: }
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
from dataclasses import dataclass | ||
|
||
from xdsl.builder import Builder | ||
from xdsl.context import MLContext | ||
from xdsl.dialects import arith, linalg | ||
from xdsl.dialects.builtin import AffineMapAttr, DenseIntOrFPElementsAttr, ModuleOp | ||
from xdsl.ir import BlockArgument, OpResult, SSAValue | ||
from xdsl.ir.affine import AffineMap | ||
from xdsl.passes import ModulePass | ||
from xdsl.pattern_rewriter import ( | ||
PatternRewriter, | ||
PatternRewriteWalker, | ||
RewritePattern, | ||
op_type_rewrite_pattern, | ||
) | ||
|
||
|
||
def build_generic_fma(
    mul_op1: SSAValue, mul_op2: SSAValue, add_op: SSAValue, out: SSAValue
) -> linalg.Generic:
    """
    Build a `linalg.generic` op that computes `mul_op1 * mul_op2 + add_op`
    element-wise into `out`, i.e. a fused multiply-add.
    """
    ins = (mul_op1, mul_op2, add_op)
    outs = (out,)

    block_arg_types = linalg.NamedOpBase.body_arg_types(ins + outs)

    @Builder.implicit_region(block_arg_types)
    def fma_body(args: tuple[BlockArgument, ...]) -> None:
        # args[0..2] are the scalar inputs; the trailing output arg is unused.
        product = arith.Mulf(args[0], args[1])
        total = arith.Addf(product, args[2])
        linalg.YieldOp(total)

    # One identity indexing map per operand (3 inputs + 1 output), iterating
    # a single parallel dimension.
    identity = AffineMapAttr(AffineMap.from_callable(lambda i,: (i,)))
    return linalg.Generic(
        ins,
        outs,
        fma_body,
        [identity] * 4,
        [linalg.IteratorTypeAttr.parallel()],
        [out.type],
    )
|
||
|
||
@dataclass(frozen=True)
class FuseMultiplyAddPass(RewritePattern):
    """
    Rewrite pattern that fuses a `linalg.mul` followed by a `linalg.add` on
    its result into a single `linalg.generic` performing a multiply-add.
    """

    # When True, only fuse if one factor of the mul is a (splat) constant.
    require_scalar_factor: bool
    # When True, only fuse if the add is the mul's sole user, so the mul can
    # be erased afterwards.
    require_erasable_mul: bool

    @op_type_rewrite_pattern
    def match_and_rewrite(self, mul: linalg.MulOp, rewriter: PatternRewriter, /):
        # Bail out unless the mul has exactly one result; with
        # `require_erasable_mul`, additionally require that all uses of that
        # result come from a single operation (precedence: `or` binds the
        # whole `and` clause, so this check only applies when the flag is set).
        if (
            len(mul.res) != 1
            or self.require_erasable_mul
            and len(set(use.operation for use in mul.res[0].uses)) != 1
        ):
            return

        # Collect the distinct `linalg.add` users that consume the mul result
        # as an *input* (not merely as an out-parameter) before mutating IR.
        for add in set(
            use.operation
            for use in mul.res[0].uses
            if isinstance(use.operation, linalg.AddOp)
            and mul.res[0] in use.operation.inputs
        ):
            # if the `require_scalar_factor` flag is set, check if either operand of `mul` is a scalar
            if (
                self.require_scalar_factor
                and not self.is_scalar_constant(mul.inputs[0])
                and not self.is_scalar_constant(mul.inputs[1])
            ):
                # NOTE: this condition is independent of `add`, so returning
                # (rather than continuing) rejects the whole mul at once.
                return

            # the operand of `add` that is not the `mul` result
            add_operand = (
                add.inputs[0] if mul.res[0] == add.inputs[1] else add.inputs[1]
            )

            # build fma op
            fma = build_generic_fma(
                mul.inputs[0], mul.inputs[1], add_operand, mul.outputs[0]
            )

            # replace in position of the add op
            rewriter.replace_op(add, fma)
        # Once all adds are replaced, the mul is dead if nothing else uses it.
        if len(mul.res[0].uses) == 0:
            rewriter.erase_matched_op()

    @staticmethod
    def is_scalar_constant(op: SSAValue) -> bool:
        """
        Returns if the value is a scalar. This currently checks for scalar constants, and could
        in the future be extended to check for dynamically provided scalar values expanded via linalg.fill
        """
        # A dense-elements constant counts as "scalar" when it is a splat,
        # i.e. every element equals the first; non-dense constants pass.
        return (
            isinstance(op, OpResult)
            and isinstance(op.op, arith.Constant)
            and (
                not isinstance(v := op.op.value, DenseIntOrFPElementsAttr)
                or v.data.data.count(v.data.data[0]) == len(v.data.data)
            )
        )
|
||
|
||
@dataclass(frozen=True)
class LinalgFuseMultiplyAddPass(ModulePass):
    """
    Pass that fuses linalg multiply and add ops into a `generic` fma.
    """

    name = "linalg-fuse-multiply-add"

    require_scalar_factor: bool = False
    """Set to require one of the mul factors to be a scalar constant"""

    require_erasable_mul: bool = False
    """Set to only fuse ops if the multiply has no other use and can be erased"""

    def apply(self, ctx: MLContext, op: ModuleOp) -> None:
        """Run the fusion pattern over the module in a single, non-recursive walk."""
        pattern = FuseMultiplyAddPass(
            self.require_scalar_factor,
            self.require_erasable_mul,
        )
        walker = PatternRewriteWalker(pattern, apply_recursively=False)
        walker.rewrite_module(op)