Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add loop pragmas #31376

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions base/Base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,9 @@ include("asyncmap.jl")
# experimental API's
include("experimental.jl")

# various loop pragmas
include("pragma.jl")

# deprecated functions
include("deprecated.jl")

Expand Down
93 changes: 93 additions & 0 deletions base/pragma.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
module Pragma

export @unroll

##
# Uses the loopinfo expr node to attach LLVM loopinfo to loops
# the full list of supported metadata nodes is available at
# https://llvm.org/docs/LangRef.html#llvm-loop
# TODO:
# - Figure out how to deal with compile-time constants in `@unroll(N, expr)`
# so constants that come from `Val{N}` but are not parse time constant.
# - Difference between `unroll_enable` and `unroll_full`
# - ? Expose `unroll_disable`
# - ? Expose `jam_disable`
##

module MD
disable_nonforced() = (Symbol("llvm.loop.disable_nonforced"),)
interleave(n) = (Symbol("llvm.loop.interleave.count"), convert(Int, n))
vectorize_enable(flag) = (Symbol("llvm.loop.vectorize.enable"), convert(Bool, flag))
vectorize_width(n) = (Symbol("llvm.loop.vectorize.width"), convert(Int, n))
# ‘llvm.loop.vectorize.followup_vectorized’
# ‘llvm.loop.vectorize.followup_epilogue’
# ‘llvm.loop.vectorize.followup_all’
unroll_count(n) = (Symbol("llvm.loop.unroll.count"), convert(Int, n))
unroll_disable() = (Symbol("llvm.loop.unroll.disable"),)
unroll_enable() = (Symbol("llvm.loop.unroll.enable"),)
unroll_full() = (Symbol("llvm.loop.unroll.full"),)
# ‘llvm.loop.unroll.followup’
# ‘llvm.loop.unroll.followup_remainder’
jam_count(n) = (Symbol("llvm.loop.unroll_and_jam.count"), convert(Int, n))
jam_disable() = (Symbol("llvm.loop.unroll_and_jam.disable"),)
jam_enable() = (Symbol("llvm.loop.unroll_and_jam.enable"),)
# ‘llvm.loop.unroll_and_jam.followup_outer’
# ‘llvm.loop.unroll_and_jam.followup_inner’
# ‘llvm.loop.unroll_and_jam.followup_remainder_outer’
# ‘llvm.loop.unroll_and_jam.followup_remainder_inner’
# ‘llvm.loop.unroll_and_jam.followup_all’
# ‘llvm.loop.licm_versioning.disable’
# ‘llvm.loop.distribute.enable’
# ‘llvm.loop.distribute.followup_coincident’
# ‘llvm.loop.distribute.followup_sequential’
# ‘llvm.loop.distribute.followup_fallback’
# ‘llvm.loop.distribute.followup_all’
end

function loopinfo(name, expr, nodes...)
if expr.head != :for
error("Syntax error: pragma $name needs a for loop")
end
push!(expr.args[2].args, Expr(:loopinfo, nodes...))
return expr
end

"""
@unroll expr

Takes a for loop as `expr` and informs the LLVM unroller to fully unroll it, if
it is safe to do so and the loop count is known.
"""
macro unroll(expr)
expr = loopinfo("@unroll", expr, MD.unroll_full())
return esc(expr)
end

"""
@unroll N expr

Takes a for loop as `expr` and informs the LLVM unroller to unroll it `N` times,
if it is safe to do so.
"""
macro unroll(N, expr)
if !(N isa Integer)
error("Syntax error: `@unroll N expr` needs a constant integer N")
end
expr = loopinfo("@unroll", expr, MD.unroll_count(N))
return esc(expr)
end

macro jam(N, expr)
if !(N isa Integer)
error("Syntax error: `@jam N expr` needs a constant integer N")
end
expr = loopinfo("@jam", expr, MD.jam_count(N))
return esc(expr)
end

macro jam(expr)
expr = loopinfo("@jam", expr, MD.jam_enable())
return esc(expr)
end

end #module
8 changes: 4 additions & 4 deletions test/llvmpasses/loopinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
# RUN: julia --startup-file=no %s %t -O && llvm-link -S %t/* -o %t/module.ll
# RUN: cat %t/module.ll | FileCheck %s -check-prefix=FINAL

using Base.Pragma

## Notes:
# This script uses the `emit` function (defined llvmpasses.jl) to emit either
# optimized or unoptimized LLVM IR. Each function is emitted individually and
Expand Down Expand Up @@ -57,9 +59,8 @@ end
# LOWER-LABEL: @julia_loop_unroll
# FINAL-LABEL: @julia_loop_unroll
@eval function loop_unroll(N)
for i in 1:N
@unroll 3 for i in 1:N
iteration(i)
$(Expr(:loopinfo, (Symbol("llvm.loop.unroll.count"), 3)))
# CHECK: call void @julia.loopinfo_marker(), {{.*}}, !julia.loopinfo [[LOOPINFO3:![0-9]+]]
# LOWER-NOT: call void @julia.loopinfo_marker()
# LOWER: br {{.*}}, !llvm.loop [[LOOPID3:![0-9]+]]
Expand All @@ -79,13 +80,12 @@ end
# LOWER-LABEL: @julia_loop_unroll2
# FINAL-LABEL: @julia_loop_unroll2
@eval function loop_unroll2(J, I)
for i in 1:10
@unroll for i in 1:10
for j in J
1 <= j <= I && continue
@show (i,j)
iteration(i)
end
$(Expr(:loopinfo, (Symbol("llvm.loop.unroll.full"),)))
# CHECK: call void @julia.loopinfo_marker(), {{.*}}, !julia.loopinfo [[LOOPINFO4:![0-9]+]]
# LOWER-NOT: call void @julia.loopinfo_marker()
# LOWER: br {{.*}}, !llvm.loop [[LOOPID4:![0-9]+]]
Expand Down