Skip to content

Commit

Permalink
Use GPUToolbox.jl (#2646)
Browse files Browse the repository at this point in the history
  • Loading branch information
christiangnrd authored Feb 17, 2025
1 parent 236643b commit 3d42ca2
Show file tree
Hide file tree
Showing 6 changed files with 10 additions and 188 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55"
GPUToolbox = "096a3bc2-3ced-46d0-87f4-dd12716f4bfc"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
LLVMLoopInfo = "8b046642-f1f6-4319-8d3c-209ddc03c586"
Expand Down Expand Up @@ -61,6 +62,7 @@ EnzymeCore = "0.8.2"
ExprTools = "0.1"
GPUArrays = "11.2.1"
GPUCompiler = "0.24, 0.25, 0.26, 0.27, 1"
GPUToolbox = "0.1"
KernelAbstractions = "0.9.2"
LLVM = "9.1"
LLVMLoopInfo = "1"
Expand Down
3 changes: 3 additions & 0 deletions lib/utils/APIUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ using LLVM
using LLVM.Interop

# helpers that facilitate working with CUDA APIs
using GPUToolbox: @checked, @debug_ccall, @gcsafe_ccall
export @checked, @debug_ccall, @gcsafe_ccall

include("call.jl")
include("enum.jl")
include("threading.jl")
Expand Down
142 changes: 1 addition & 141 deletions lib/utils/call.jl
Original file line number Diff line number Diff line change
@@ -1,50 +1,6 @@
# utilities for calling foreign functionality more conveniently

export @checked, with_workspace, with_workspaces,
@debug_ccall, @gcsafe_ccall


## function wrapper for checking the return value of a function

"""
@checked function foo(...)
rv = ...
return rv
end
Macro for wrapping a function definition returning a status code. Two versions of the
function will be generated: `foo`, with the function execution wrapped by an invocation of
the `check` function (to be implemented by the caller of this macro), and `unchecked_foo`
where no such invocation is present and the status code is returned to the caller.
"""
macro checked(ex)
# parse the function definition
@assert Meta.isexpr(ex, :function)
sig = ex.args[1]
@assert Meta.isexpr(sig, :call)
body = ex.args[2]
@assert Meta.isexpr(body, :block)

# make sure these functions are inlined
pushfirst!(body.args, Expr(:meta, :inline))

# generate a "safe" version that performs a check
safe_body = quote
@inline
check() do
$body
end
end
safe_sig = Expr(:call, sig.args[1], sig.args[2:end]...)
safe_def = Expr(:function, safe_sig, safe_body)

# generate a "unchecked" version that returns the error code instead
unchecked_sig = Expr(:call, Symbol("unchecked_", sig.args[1]), sig.args[2:end]...)
unchecked_def = Expr(:function, unchecked_sig, body)

return esc(:($safe_def, $unchecked_def))
end

export with_workspace, with_workspaces

## wrapper for foreign functionality that requires a workspace buffer

Expand Down Expand Up @@ -138,99 +94,3 @@ function with_workspaces(f::Base.Callable,
end
end
end


## version of ccall that prints the ccall, its arguments and its return value

macro debug_ccall(ex)
@assert Meta.isexpr(ex, :(::))
call, ret = ex.args
@assert Meta.isexpr(call, :call)
target, argexprs... = call.args
args = map(argexprs) do argexpr
@assert Meta.isexpr(argexpr, :(::))
argexpr.args[1]
end

ex = Expr(:macrocall, Symbol("@ccall"), __source__, ex)

# avoid task switches
io = :(Core.stdout)

quote
print($io, $(string(target)), '(')
for (i, arg) in enumerate(($(map(esc, args)...),))
i > 1 && print($io, ", ")
render_arg($io, arg)
end
print($io, ')')

rv = $(esc(ex))

println($io, " = ", rv)
for (i, arg) in enumerate(($(map(esc, args)...),))
if arg isa Base.RefValue
println($io, " $i: ", arg[])
end
end
rv
end
end

render_arg(io, arg) = print(io, arg)
render_arg(io, arg::AbstractArray) = summary(io, arg)
render_arg(io, arg::Base.RefValue{T}) where {T} = print(io, "Ref{", T, "}")


## version of ccall that calls jl_gc_safe_enter|leave around the inner ccall

# TODO: replace with JuliaLang/julia#49933 once merged

function ccall_macro_lower(func, rettype, types, args, nreq)
# instead of re-using ccall or Expr(:foreigncall) to perform argument conversion,
# we need to do so ourselves in order to insert a jl_gc_safe_enter|leave
# just around the inner ccall

cconvert_exprs = []
cconvert_args = []
for (typ, arg) in zip(types, args)
var = gensym("$(func)_cconvert")
push!(cconvert_args, var)
push!(cconvert_exprs, :($var = Base.cconvert($(esc(typ)), $(esc(arg)))))
end

unsafe_convert_exprs = []
unsafe_convert_args = []
for (typ, arg) in zip(types, cconvert_args)
var = gensym("$(func)_unsafe_convert")
push!(unsafe_convert_args, var)
push!(unsafe_convert_exprs, :($var = Base.unsafe_convert($(esc(typ)), $arg)))
end

call = quote
$(unsafe_convert_exprs...)

gc_state = @ccall(jl_gc_safe_enter()::Int8)
ret = ccall($(esc(func)), $(esc(rettype)), $(Expr(:tuple, map(esc, types)...)),
$(unsafe_convert_args...))
@ccall(jl_gc_safe_leave(gc_state::Int8)::Cvoid)
ret
end

quote
@inline
$(cconvert_exprs...)
GC.@preserve $(cconvert_args...) $(call)
end
end

"""
@gcsafe_ccall ...
Call a foreign function just like `@ccall`, but marking it safe for the GC to run. This is
useful for functions that may block, so that the GC isn't blocked from running, but may also
be required to prevent deadlocks (see JuliaGPU/CUDA.jl#2261).
"""
macro gcsafe_ccall(expr)
ccall_macro_lower(Base.ccall_macro_parse(expr)...)
end
2 changes: 2 additions & 0 deletions src/CUDA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ using GPUCompiler

using GPUArrays

using GPUToolbox: SimpleVersion, @sv_str

using LLVM
using LLVM.Interop
using Core: LLVMPtr
Expand Down
2 changes: 2 additions & 0 deletions src/device/intrinsics/cooperative_groups.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ using ..LLVMLoopInfo

using Core: LLVMPtr

using GPUToolbox: @sv_str

const cg_debug = false
if cg_debug
cg_assert(x) = @cuassert x
Expand Down
47 changes: 0 additions & 47 deletions src/device/intrinsics/version.jl
Original file line number Diff line number Diff line change
@@ -1,52 +1,5 @@
# device intrinsics for querying the compute SimpleVersion and PTX ISA version


## a GPU-compatible version number

export SimpleVersion, @sv_str

struct SimpleVersion
major::UInt32
minor::UInt32

SimpleVersion(major, minor=0) = new(major, minor)
end

function Base.tryparse(::Type{SimpleVersion}, v::AbstractString)
parts = split(v, ".")
1 <= length(parts) <= 2 || return nothing

int_parts = map(parts) do part
tryparse(Int, part)
end
any(isnothing, int_parts) && return nothing

SimpleVersion(int_parts...)
end

function Base.parse(::Type{SimpleVersion}, v::AbstractString)
ver = tryparse(SimpleVersion, v)
ver === nothing && throw(ArgumentError("invalid SimpleVersion string: '$v'"))
return ver
end

SimpleVersion(v::AbstractString) = parse(SimpleVersion, v)

@inline function Base.isless(a::SimpleVersion, b::SimpleVersion)
(a.major < b.major) && return true
(a.major > b.major) && return false
(a.minor < b.minor) && return true
(a.minor > b.minor) && return false
return false
end

macro sv_str(str)
SimpleVersion(str)
end


## accessors for the compute SimpleVersion and PTX ISA version

export compute_capability, ptx_isa_version

for var in ["sm_major", "sm_minor", "ptx_major", "ptx_minor"]
Expand Down

0 comments on commit 3d42ca2

Please sign in to comment.