Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

transformations: New test-add-timers-to-top-level-funcs pass #3407

Merged
merged 7 commits into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions tests/filecheck/transforms/test-add-timers-to-top-level-funcs.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// RUN: xdsl-opt %s -p test-add-timers-to-top-level-funcs --split-input-file | filecheck %s

builtin.module {

// CHECK: builtin.module {
// CHECK-NEXT: func.func @has_timers(%arg0 : i32, %timers : !llvm.ptr) -> i32 {
// CHECK-NEXT: %start = func.call @timer_start() : () -> f64
// CHECK-NEXT: "test.op"() : () -> ()
// CHECK-NEXT: %end = func.call @timer_end(%start) : (f64) -> f64
// CHECK-NEXT: "llvm.store"(%end, %timers) <{"ordering" = 0 : i64}> : (f64, !llvm.ptr) -> ()
// CHECK-NEXT: func.return %arg0 : i32
// CHECK-NEXT: }
// CHECK-NEXT: func.func private @timer_start() -> f64
// CHECK-NEXT: func.func private @timer_end(f64) -> f64
// CHECK-NEXT: }

func.func @has_timers(%arg0 : i32, %timers : !llvm.ptr) -> i32 {
%start = func.call @timer_start() : () -> f64
"test.op"() : () -> ()
%end = func.call @timer_end(%start) : (f64) -> f64
"llvm.store"(%end, %timers) <{"ordering" = 0 : i64}> : (f64, !llvm.ptr) -> ()
func.return %arg0 : i32
}
func.func private @timer_start() -> f64
func.func private @timer_end(f64) -> f64
}

// -----

builtin.module {

// CHECK: builtin.module {
// CHECK-NEXT: func.func @has_no_timers(%arg0 : i32, %arg1 : i32, %timers : !llvm.ptr) -> i32 {
// CHECK-NEXT: %timestamp = func.call @timer_start() : () -> f64
// CHECK-NEXT: %res = arith.addi %arg0, %arg1 : i32
// CHECK-NEXT: %timediff = func.call @timer_end(%timestamp) : (f64) -> f64
// CHECK-NEXT: "llvm.store"(%timediff, %timers) <{"ordering" = 0 : i64}> : (f64, !llvm.ptr) -> ()
// CHECK-NEXT: func.return %res : i32
// CHECK-NEXT: }
// CHECK-NEXT: func.func @also_has_no_timers(%timers : !llvm.ptr) {
// CHECK-NEXT: %timestamp = func.call @timer_start() : () -> f64
// CHECK-NEXT: func.func @nested_should_not_get_timers() {
// CHECK-NEXT: func.return
// CHECK-NEXT: }
// CHECK-NEXT: "test.op"() : () -> ()
// CHECK-NEXT: %timediff = func.call @timer_end(%timestamp) : (f64) -> f64
// CHECK-NEXT: "llvm.store"(%timediff, %timers) <{"ordering" = 0 : i64}> : (f64, !llvm.ptr) -> ()
// CHECK-NEXT: func.return
// CHECK-NEXT: }
// CHECK-NEXT: func.func @timer_start() -> f64
// CHECK-NEXT: func.func @timer_end(f64) -> f64
// CHECK-NEXT: }

func.func @has_no_timers(%arg0 : i32, %arg1 : i32) -> i32 {
%res = arith.addi %arg0, %arg1 : i32
func.return %res : i32
}

func.func @also_has_no_timers() {
func.func @nested_should_not_get_timers() {
func.return
}
"test.op"() : () -> ()
func.return
}
}
8 changes: 8 additions & 0 deletions xdsl/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,13 @@ def get_stencil_unroll():

return stencil_unroll.StencilUnrollPass

def get_test_add_timers_to_top_level_funcs():
from xdsl.transforms.function_transformations import (
TestAddBenchTimersToTopLevelFunctions,
)

return TestAddBenchTimersToTopLevelFunctions

def get_test_lower_linalg_to_snitch():
from xdsl.transforms import test_lower_linalg_to_snitch

Expand Down Expand Up @@ -542,6 +549,7 @@ def get_varith_fuse_repeated_operands():
"stencil-unroll": get_stencil_unroll,
"stencil-bufferize": get_stencil_bufferize,
"stencil-shape-minimize": get_stencil_shape_minimize,
"test-add-timers-to-top-level-funcs": get_test_add_timers_to_top_level_funcs,
"test-lower-linalg-to-snitch": get_test_lower_linalg_to_snitch,
"eqsat-create-eclasses": get_eqsat_create_eclasses,
"varith-fuse-repeated-operands": get_varith_fuse_repeated_operands,
Expand Down
16 changes: 8 additions & 8 deletions xdsl/transforms/csl_stencil_to_csl_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,13 @@
)
from xdsl.rewriter import InsertPoint
from xdsl.transforms import csl_stencil_bufferize
from xdsl.transforms.function_transformations import (
TIMER_END,
TIMER_START,
)
from xdsl.utils.hints import isa
from xdsl.utils.isattr import isattr

_TIMER_START = "timer_start"
_TIMER_END = "timer_end"
_TIMER_FUNC_NAMES = [_TIMER_START, _TIMER_END]


def _get_module_wrapper(op: Operation) -> csl_wrapper.ModuleOp | None:
"""
Expand Down Expand Up @@ -64,7 +64,7 @@ class ConvertStencilFuncToModuleWrappedPattern(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter, /):
# erase timer stubs
if op.is_declaration and op.sym_name.data in _TIMER_FUNC_NAMES:
if op.is_declaration and op.sym_name.data in [TIMER_START, TIMER_END]:
rewriter.erase_matched_op()
return
# find csl_stencil.apply ops, abort if there are none
Expand Down Expand Up @@ -250,7 +250,7 @@ def _translate_function_args(
isinstance(u.operation, llvm.StoreOp)
and isinstance(u.operation.value, OpResult)
and isinstance(u.operation.value.op, func.Call)
and u.operation.value.op.callee.string_value() == _TIMER_END
and u.operation.value.op.callee.string_value() == TIMER_END
for u in arg.uses
):
start_end_size = 3
Expand Down Expand Up @@ -394,9 +394,9 @@ class LowerTimerFuncCall(RewritePattern):
def match_and_rewrite(self, op: llvm.StoreOp, rewriter: PatternRewriter, /):
if (
not isinstance(end_call := op.value.owner, func.Call)
or not end_call.callee.string_value() == _TIMER_END
or not end_call.callee.string_value() == TIMER_END
or not (isinstance(start_call := end_call.arguments[0].owner, func.Call))
or not start_call.callee.string_value() == _TIMER_START
or not start_call.callee.string_value() == TIMER_START
or not (wrapper := _get_module_wrapper(op))
or not isa(op.ptr.type, AnyMemRefType)
):
Expand Down
70 changes: 69 additions & 1 deletion xdsl/transforms/function_transformations.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from dataclasses import dataclass

from xdsl.context import MLContext
from xdsl.dialects import builtin, func
from xdsl.dialects import builtin, func, llvm
from xdsl.dialects.builtin import ArrayAttr, DictionaryAttr, StringAttr
from xdsl.ir import Region
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
PatternRewriter,
PatternRewriteWalker,
RewritePattern,
op_type_rewrite_pattern,
)
from xdsl.rewriter import InsertPoint


class ArgNamesToArgAttrsPass(RewritePattern):
Expand Down Expand Up @@ -36,6 +40,70 @@ def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter, /):
rewriter.has_done_action = True


TIMER_START = "timer_start"
TIMER_END = "timer_end"


@dataclass
class AddBenchTimersPattern(RewritePattern):
start_func_t: func.FunctionType
end_func_t: func.FunctionType

@op_type_rewrite_pattern
def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter, /):
if (
not (top_level := op.parent_op())
or not isinstance(top_level, builtin.ModuleOp)
or top_level.parent
):
return
Comment on lines +55 to +60
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not all functions?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We currently don't have a use case for this.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're looking at annotating top-level functions that are the equivalent of a main function. Should prob add a check that it's not being called. Don't think the SymbolTable API supports this though by any chance?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would the name of the function to annotate be worth adding as a pass parameter? If it works with your flow it might be the easiest way to make it safe by default.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good suggestion, let me come back to that after the deadline if alright


ptr = op.body.block.insert_arg(llvm.LLVMPointerType.opaque(), len(op.args))
start_call = func.Call(TIMER_START, [], tuple(self.start_func_t.outputs))
end_call = func.Call(TIMER_END, start_call.res, tuple(self.end_func_t.outputs))
store_time = llvm.StoreOp(end_call.res[0], ptr)

ptr.name_hint = "timers"
start_call.res[0].name_hint = "timestamp"
end_call.res[0].name_hint = "timediff"

assert op.body.block.last_op
rewriter.insert_op(start_call, InsertPoint.at_start(op.body.block))
rewriter.insert_op(
[end_call, store_time], InsertPoint.before(op.body.block.last_op)
)
op.update_function_type()


class TestAddBenchTimersToTopLevelFunctions(ModulePass):
"""
Adds timers to top-level functions, by adding `timer_start() -> f64` and `timer_end(f64) -> f64`
to the start and end of each module-level function. The time is stored in an `llvm.ptr` passed in
as a function arg.
"""

name = "test-add-timers-to-top-level-funcs"

def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
all_funcs = [f for f in op.body.block.ops if isinstance(f, func.FuncOp)]
func_names = [f.sym_name.data for f in all_funcs]
if TIMER_START in func_names or TIMER_END in func_names:
return

start_func_t = func.FunctionType.from_lists([], [builtin.Float64Type()])
end_func_t = func.FunctionType.from_lists(
[builtin.Float64Type()], [builtin.Float64Type()]
)
start_func = func.FuncOp(TIMER_START, start_func_t, Region([]))
end_func = func.FuncOp(TIMER_END, end_func_t, Region([]))

PatternRewriteWalker(
AddBenchTimersPattern(start_func_t, end_func_t), apply_recursively=False
).rewrite_module(op)

op.body.block.add_ops((start_func, end_func))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would recommend using SymbolTable APIs to look for functions and insert them:

xdsl/backend/riscv/lowering/convert_memref_to_riscv.py

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sweet, I like this



class FunctionPersistArgNames(ModulePass):
"""
Persists func.func arg name hints to arg_attrs.
Expand Down
Loading