From 62ba3abe7a7b8de9977a79f071966f258d7ec188 Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Fri, 7 Feb 2025 12:33:43 -0400 Subject: [PATCH] Use GPUToolbox --- Project.toml | 2 + lib/utils/APIUtils.jl | 4 +- lib/utils/call.jl | 142 +------------------- src/CUDA.jl | 2 + src/device/intrinsics/cooperative_groups.jl | 2 + src/device/intrinsics/version.jl | 47 ------- 6 files changed, 10 insertions(+), 189 deletions(-) diff --git a/Project.toml b/Project.toml index 9656d4a6bd..2e63b643fb 100644 --- a/Project.toml +++ b/Project.toml @@ -15,6 +15,7 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55" +GPUToolbox = "096a3bc2-3ced-46d0-87f4-dd12716f4bfc" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LLVM = "929cbde3-209d-540e-8aea-75f648917ca0" LLVMLoopInfo = "8b046642-f1f6-4319-8d3c-209ddc03c586" @@ -61,6 +62,7 @@ EnzymeCore = "0.8.2" ExprTools = "0.1" GPUArrays = "11.2.1" GPUCompiler = "0.24, 0.25, 0.26, 0.27, 1" +GPUToolbox = "0.1" KernelAbstractions = "0.9.2" LLVM = "9.1" LLVMLoopInfo = "1" diff --git a/lib/utils/APIUtils.jl b/lib/utils/APIUtils.jl index 420ddce8aa..58547d8feb 100644 --- a/lib/utils/APIUtils.jl +++ b/lib/utils/APIUtils.jl @@ -6,7 +6,9 @@ using LLVM using LLVM.Interop # helpers that facilitate working with CUDA APIs -include("call.jl") +using GPUToolbox: @checked, @debug_ccall, @gcsafe_ccall +export @checked, @debug_ccall, @gcsafe_ccall + include("enum.jl") include("threading.jl") include("cache.jl") diff --git a/lib/utils/call.jl b/lib/utils/call.jl index 8e7341451e..7571801994 100644 --- a/lib/utils/call.jl +++ b/lib/utils/call.jl @@ -1,50 +1,6 @@ # utilities for calling foreign functionality more conveniently -export @checked, with_workspace, with_workspaces, - @debug_ccall, @gcsafe_ccall - - -## function wrapper for checking the return value of a function - -""" - @checked function foo(...) - rv = ... - return rv - end - -Macro for wrapping a function definition returning a status code. Two versions of the -function will be generated: `foo`, with the function execution wrapped by an invocation of -the `check` function (to be implemented by the caller of this macro), and `unchecked_foo` -where no such invocation is present and the status code is returned to the caller. -""" -macro checked(ex) - # parse the function definition - @assert Meta.isexpr(ex, :function) - sig = ex.args[1] - @assert Meta.isexpr(sig, :call) - body = ex.args[2] - @assert Meta.isexpr(body, :block) - - # make sure these functions are inlined - pushfirst!(body.args, Expr(:meta, :inline)) - - # generate a "safe" version that performs a check - safe_body = quote - @inline - check() do - $body - end - end - safe_sig = Expr(:call, sig.args[1], sig.args[2:end]...) - safe_def = Expr(:function, safe_sig, safe_body) - - # generate a "unchecked" version that returns the error code instead - unchecked_sig = Expr(:call, Symbol("unchecked_", sig.args[1]), sig.args[2:end]...) - unchecked_def = Expr(:function, unchecked_sig, body) - - return esc(:($safe_def, $unchecked_def)) -end - +export with_workspace, with_workspaces ## wrapper for foreign functionality that requires a workspace buffer @@ -138,99 +94,3 @@ function with_workspaces(f::Base.Callable, end end end - - -## version of ccall that prints the ccall, its arguments and its return value - -macro debug_ccall(ex) - @assert Meta.isexpr(ex, :(::)) - call, ret = ex.args - @assert Meta.isexpr(call, :call) - target, argexprs... = call.args - args = map(argexprs) do argexpr - @assert Meta.isexpr(argexpr, :(::)) - argexpr.args[1] - end - - ex = Expr(:macrocall, Symbol("@ccall"), __source__, ex) - - # avoid task switches - io = :(Core.stdout) - - quote - print($io, $(string(target)), '(') - for (i, arg) in enumerate(($(map(esc, args)...),)) - i > 1 && print($io, ", ") - render_arg($io, arg) - end - print($io, ')') - - rv = $(esc(ex)) - - println($io, " = ", rv) - for (i, arg) in enumerate(($(map(esc, args)...),)) - if arg isa Base.RefValue - println($io, " $i: ", arg[]) - end - end - rv - end -end - -render_arg(io, arg) = print(io, arg) -render_arg(io, arg::AbstractArray) = summary(io, arg) -render_arg(io, arg::Base.RefValue{T}) where {T} = print(io, "Ref{", T, "}") - - -## version of ccall that calls jl_gc_safe_enter|leave around the inner ccall - -# TODO: replace with JuliaLang/julia#49933 once merged - -function ccall_macro_lower(func, rettype, types, args, nreq) - # instead of re-using ccall or Expr(:foreigncall) to perform argument conversion, - # we need to do so ourselves in order to insert a jl_gc_safe_enter|leave - # just around the inner ccall - - cconvert_exprs = [] - cconvert_args = [] - for (typ, arg) in zip(types, args) - var = gensym("$(func)_cconvert") - push!(cconvert_args, var) - push!(cconvert_exprs, :($var = Base.cconvert($(esc(typ)), $(esc(arg))))) - end - - unsafe_convert_exprs = [] - unsafe_convert_args = [] - for (typ, arg) in zip(types, cconvert_args) - var = gensym("$(func)_unsafe_convert") - push!(unsafe_convert_args, var) - push!(unsafe_convert_exprs, :($var = Base.unsafe_convert($(esc(typ)), $arg))) - end - - call = quote - $(unsafe_convert_exprs...) - - gc_state = @ccall(jl_gc_safe_enter()::Int8) - ret = ccall($(esc(func)), $(esc(rettype)), $(Expr(:tuple, map(esc, types)...)), - $(unsafe_convert_args...)) - @ccall(jl_gc_safe_leave(gc_state::Int8)::Cvoid) - ret - end - - quote - @inline - $(cconvert_exprs...) - GC.@preserve $(cconvert_args...) $(call) - end -end - -""" - @gcsafe_ccall ... - -Call a foreign function just like `@ccall`, but marking it safe for the GC to run. This is -useful for functions that may block, so that the GC isn't blocked from running, but may also -be required to prevent deadlocks (see JuliaGPU/CUDA.jl#2261). -""" -macro gcsafe_ccall(expr) - ccall_macro_lower(Base.ccall_macro_parse(expr)...) -end diff --git a/src/CUDA.jl b/src/CUDA.jl index a72e030575..29c840b8fc 100644 --- a/src/CUDA.jl +++ b/src/CUDA.jl @@ -4,6 +4,8 @@ using GPUCompiler using GPUArrays +using GPUToolbox: SimpleVersion, @sv_str + using LLVM using LLVM.Interop using Core: LLVMPtr diff --git a/src/device/intrinsics/cooperative_groups.jl b/src/device/intrinsics/cooperative_groups.jl index eb9d297c05..b1b5b215ec 100644 --- a/src/device/intrinsics/cooperative_groups.jl +++ b/src/device/intrinsics/cooperative_groups.jl @@ -32,6 +32,8 @@ using ..LLVMLoopInfo using Core: LLVMPtr +using GPUToolbox: @sv_str + const cg_debug = false if cg_debug cg_assert(x) = @cuassert x diff --git a/src/device/intrinsics/version.jl b/src/device/intrinsics/version.jl index ac74dff779..97c00515a6 100644 --- a/src/device/intrinsics/version.jl +++ b/src/device/intrinsics/version.jl @@ -1,52 +1,5 @@ # device intrinsics for querying the compute SimpleVersion and PTX ISA version - -## a GPU-compatible version number - -export SimpleVersion, @sv_str - -struct SimpleVersion - major::UInt32 - minor::UInt32 - - SimpleVersion(major, minor=0) = new(major, minor) -end - -function Base.tryparse(::Type{SimpleVersion}, v::AbstractString) - parts = split(v, ".") - 1 <= length(parts) <= 2 || return nothing - - int_parts = map(parts) do part - tryparse(Int, part) - end - any(isnothing, int_parts) && return nothing - - SimpleVersion(int_parts...) -end - -function Base.parse(::Type{SimpleVersion}, v::AbstractString) - ver = tryparse(SimpleVersion, v) - ver === nothing && throw(ArgumentError("invalid SimpleVersion string: '$v'")) - return ver -end - -SimpleVersion(v::AbstractString) = parse(SimpleVersion, v) - -@inline function Base.isless(a::SimpleVersion, b::SimpleVersion) - (a.major < b.major) && return true - (a.major > b.major) && return false - (a.minor < b.minor) && return true - (a.minor > b.minor) && return false - return false -end - -macro sv_str(str) - SimpleVersion(str) -end - - -## accessors for the compute SimpleVersion and PTX ISA version - export compute_capability, ptx_isa_version for var in ["sm_major", "sm_minor", "ptx_major", "ptx_minor"]