-
Notifications
You must be signed in to change notification settings - Fork 80
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
transformations: New test-add-timers-to-top-level-funcs pass #3407
Changes from 6 commits
54c270c
b482170
b17120a
09f604e
3f4f051
22bd044
cc38fb8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
// RUN: xdsl-opt %s -p test-add-timers-to-top-level-funcs --split-input-file | filecheck %s | ||
|
||
builtin.module { | ||
|
||
// CHECK: builtin.module { | ||
// CHECK-NEXT: func.func @has_timers(%arg0 : i32, %timers : !llvm.ptr) -> i32 { | ||
// CHECK-NEXT: %start = func.call @timer_start() : () -> f64 | ||
// CHECK-NEXT: "test.op"() : () -> () | ||
// CHECK-NEXT: %end = func.call @timer_end(%start) : (f64) -> f64 | ||
// CHECK-NEXT: "llvm.store"(%end, %timers) <{"ordering" = 0 : i64}> : (f64, !llvm.ptr) -> () | ||
// CHECK-NEXT: func.return %arg0 : i32 | ||
// CHECK-NEXT: } | ||
// CHECK-NEXT: func.func private @timer_start() -> f64 | ||
// CHECK-NEXT: func.func private @timer_end(f64) -> f64 | ||
// CHECK-NEXT: } | ||
|
||
func.func @has_timers(%arg0 : i32, %timers : !llvm.ptr) -> i32 { | ||
%start = func.call @timer_start() : () -> f64 | ||
"test.op"() : () -> () | ||
%end = func.call @timer_end(%start) : (f64) -> f64 | ||
"llvm.store"(%end, %timers) <{"ordering" = 0 : i64}> : (f64, !llvm.ptr) -> () | ||
func.return %arg0 : i32 | ||
} | ||
func.func private @timer_start() -> f64 | ||
func.func private @timer_end(f64) -> f64 | ||
} | ||
|
||
// ----- | ||
|
||
builtin.module { | ||
|
||
// CHECK: builtin.module { | ||
// CHECK-NEXT: func.func @has_no_timers(%arg0 : i32, %arg1 : i32, %timers : !llvm.ptr) -> i32 { | ||
// CHECK-NEXT: %timestamp = func.call @timer_start() : () -> f64 | ||
// CHECK-NEXT: %res = arith.addi %arg0, %arg1 : i32 | ||
// CHECK-NEXT: %timediff = func.call @timer_end(%timestamp) : (f64) -> f64 | ||
// CHECK-NEXT: "llvm.store"(%timediff, %timers) <{"ordering" = 0 : i64}> : (f64, !llvm.ptr) -> () | ||
// CHECK-NEXT: func.return %res : i32 | ||
// CHECK-NEXT: } | ||
// CHECK-NEXT: func.func @also_has_no_timers(%timers : !llvm.ptr) { | ||
// CHECK-NEXT: %timestamp = func.call @timer_start() : () -> f64 | ||
// CHECK-NEXT: func.func @nested_should_not_get_timers() { | ||
// CHECK-NEXT: func.return | ||
// CHECK-NEXT: } | ||
// CHECK-NEXT: "test.op"() : () -> () | ||
// CHECK-NEXT: %timediff = func.call @timer_end(%timestamp) : (f64) -> f64 | ||
// CHECK-NEXT: "llvm.store"(%timediff, %timers) <{"ordering" = 0 : i64}> : (f64, !llvm.ptr) -> () | ||
// CHECK-NEXT: func.return | ||
// CHECK-NEXT: } | ||
// CHECK-NEXT: func.func @timer_start() -> f64 | ||
// CHECK-NEXT: func.func @timer_end(f64) -> f64 | ||
// CHECK-NEXT: } | ||
|
||
func.func @has_no_timers(%arg0 : i32, %arg1 : i32) -> i32 { | ||
%res = arith.addi %arg0, %arg1 : i32 | ||
func.return %res : i32 | ||
} | ||
|
||
func.func @also_has_no_timers() { | ||
func.func @nested_should_not_get_timers() { | ||
func.return | ||
} | ||
"test.op"() : () -> () | ||
func.return | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,17 @@ | ||
from dataclasses import dataclass | ||
|
||
from xdsl.context import MLContext | ||
from xdsl.dialects import builtin, func | ||
from xdsl.dialects import builtin, func, llvm | ||
from xdsl.dialects.builtin import ArrayAttr, DictionaryAttr, StringAttr | ||
from xdsl.ir import Region | ||
from xdsl.passes import ModulePass | ||
from xdsl.pattern_rewriter import ( | ||
PatternRewriter, | ||
PatternRewriteWalker, | ||
RewritePattern, | ||
op_type_rewrite_pattern, | ||
) | ||
from xdsl.rewriter import InsertPoint | ||
|
||
|
||
class ArgNamesToArgAttrsPass(RewritePattern): | ||
|
@@ -36,6 +40,70 @@ def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter, /): | |
rewriter.has_done_action = True | ||
|
||
|
||
TIMER_START = "timer_start" | ||
TIMER_END = "timer_end" | ||
|
||
|
||
@dataclass | ||
class AddBenchTimersPattern(RewritePattern): | ||
start_func_t: func.FunctionType | ||
end_func_t: func.FunctionType | ||
|
||
@op_type_rewrite_pattern | ||
def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter, /): | ||
if ( | ||
not (top_level := op.parent_op()) | ||
or not isinstance(top_level, builtin.ModuleOp) | ||
or top_level.parent | ||
): | ||
return | ||
|
||
ptr = op.body.block.insert_arg(llvm.LLVMPointerType.opaque(), len(op.args)) | ||
start_call = func.Call(TIMER_START, [], tuple(self.start_func_t.outputs)) | ||
end_call = func.Call(TIMER_END, start_call.res, tuple(self.end_func_t.outputs)) | ||
store_time = llvm.StoreOp(end_call.res[0], ptr) | ||
|
||
ptr.name_hint = "timers" | ||
start_call.res[0].name_hint = "timestamp" | ||
end_call.res[0].name_hint = "timediff" | ||
|
||
assert op.body.block.last_op | ||
rewriter.insert_op(start_call, InsertPoint.at_start(op.body.block)) | ||
rewriter.insert_op( | ||
[end_call, store_time], InsertPoint.before(op.body.block.last_op) | ||
) | ||
op.update_function_type() | ||
|
||
|
||
class TestAddBenchTimersToTopLevelFunctions(ModulePass): | ||
""" | ||
Adds timers to top-level functions, by adding `timer_start() -> f64` and `timer_end(f64) -> f64` | ||
to the start and end of each module-level function. The time is stored in an `llvm.ptr` passed in | ||
as a function arg. | ||
""" | ||
|
||
name = "test-add-timers-to-top-level-funcs" | ||
|
||
def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: | ||
all_funcs = [f for f in op.body.block.ops if isinstance(f, func.FuncOp)] | ||
func_names = [f.sym_name.data for f in all_funcs] | ||
if TIMER_START in func_names or TIMER_END in func_names: | ||
return | ||
|
||
start_func_t = func.FunctionType.from_lists([], [builtin.Float64Type()]) | ||
end_func_t = func.FunctionType.from_lists( | ||
[builtin.Float64Type()], [builtin.Float64Type()] | ||
) | ||
start_func = func.FuncOp(TIMER_START, start_func_t, Region([])) | ||
end_func = func.FuncOp(TIMER_END, end_func_t, Region([])) | ||
|
||
PatternRewriteWalker( | ||
AddBenchTimersPattern(start_func_t, end_func_t), apply_recursively=False | ||
).rewrite_module(op) | ||
|
||
op.body.block.add_ops((start_func, end_func)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would recommend using SymbolTable APIs to look for functions and insert them: xdsl/backend/riscv/lowering/convert_memref_to_riscv.py There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sweet, I like this |
||
|
||
|
||
class FunctionPersistArgNames(ModulePass): | ||
""" | ||
Persists func.func arg name hints to arg_attrs. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why not all functions?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We currently don't have a use case for this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We're looking at annotating top-level functions that are the equivalent of a main function. Should prob add a check that it's not being called. Don't think the SymbolTable API supports this though by any chance?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would the name of the function to annotate be worth adding as a pass parameter? If it works with your flow it might be the easiest way to make it safe by default.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good suggestion, let me come back to that after the deadline if alright