Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CUFFT] Preallocate a buffer for complex-to-real FFT #2578

Merged
merged 8 commits into from
Dec 14, 2024
112 changes: 71 additions & 41 deletions lib/cufft/fft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,32 +25,34 @@ Base.:(*)(p::ScaledPlan, x::DenseCuArray) = rmul!(p.p * x, p.scale)

# N is the number of dimensions

mutable struct CuFFTPlan{T<:cufftNumber,S<:cufftNumber,K,inplace,N} <: Plan{S}
mutable struct CuFFTPlan{T<:cufftNumber,S<:cufftNumber,K,inplace,N,R,B} <: Plan{S}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A dedicated typevar instead of simply Union{Nothing,CuArray} seems overkill, but I guess you want to eliminate the partial type information?

Copy link
Member Author

@amontoison amontoison Dec 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, and I am also taking this PR as an opportunity to give region a concrete type in the structure.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I trust you have good reasons for this, but it's not always advantageous to do so. The CUTENSOR wrappers, for example, cause excessive specialization for little gain, since these calls are very coarse-grained anyway (yet the specialization leads to very long test times). In this case, though, the specialization is already there (since B doesn't introduce any information that isn't already contained in inplace and T).

Copy link
Member Author

@amontoison amontoison Dec 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't changed the dispatch for the high-level FFT routines, so I didn't expect any slowdown. I simply made the low-level structure fully inferable.
I'll run a quick regression test for fft/ifft before and after this PR to check if there's any noticeable difference.

I didn't test it without the type B in the structure, but I observed a nice speed-up in my code after adding the buffer to the plan for rfft/irfft.
Julia also stopped crashing due to OOM issues.

# Handle to the low-level CUDA plan. Note that this plan sometimes has fewer dimensions
# in order to handle more transform cases, such as transforms along individual directions.
handle::cufftHandle
ctx::CuContext
stream::CuStream
input_size::NTuple{N,Int} # Julia size of input array
output_size::NTuple{N,Int} # Julia size of output array
region::Any
region::NTuple{R,Int}
buffer::B # buffer for out-of-place complex-to-real FFT
pinv::ScaledPlan{T} # required by AbstractFFTs API, will be defined by AbstractFFTs if needed

function CuFFTPlan{T,S,K,inplace,N}(handle::cufftHandle,
input_size::NTuple{N,Int}, output_size::NTuple{N,Int}, region
) where {T<:cufftNumber,S<:cufftNumber,K,inplace,N}
function CuFFTPlan{T,S,K,inplace,N,R,B}(handle::cufftHandle,
input_size::NTuple{N,Int}, output_size::NTuple{N,Int},
region::NTuple{R,Int}, buffer::B
) where {T<:cufftNumber,S<:cufftNumber,K,inplace,N,R,B}
abs(K) == 1 || throw(ArgumentError("FFT direction must be either -1 (forward) or +1 (inverse)"))
inplace isa Bool || throw(ArgumentError("FFT inplace argument must be a Bool"))
p = new{T,S,K,inplace,N}(handle, context(), stream(), input_size, output_size, region)
p = new{T,S,K,inplace,N,R,B}(handle, context(), stream(), input_size, output_size, region, buffer)
finalizer(unsafe_free!, p)
p
end
end

function CuFFTPlan{T,S,K,inplace,N}(handle::cufftHandle, X::DenseCuArray{S,N},
sizey::NTuple{N,Int}, region,
) where {T<:cufftNumber,S<:cufftNumber,K,inplace,N}
CuFFTPlan{T,S,K,inplace,N}(handle, size(X), sizey, region)
function CuFFTPlan{T,S,K,inplace,N,R,B}(handle::cufftHandle, X::DenseCuArray{S,N},
sizey::NTuple{N,Int}, region::NTuple{R,Int}, buffer::B
) where {T<:cufftNumber,S<:cufftNumber,K,inplace,N,R,B}
CuFFTPlan{T,S,K,inplace,N,R,B}(handle, size(X), sizey, region, buffer)
end

function CUDA.unsafe_free!(plan::CuFFTPlan)
Expand All @@ -60,6 +62,9 @@ function CUDA.unsafe_free!(plan::CuFFTPlan)
end
plan.handle = C_NULL
end
if !isnothing(plan.buffer)
CUDA.unsafe_free!(plan.buffer)
end
end

function showfftdims(io, sz, T)
Expand Down Expand Up @@ -151,103 +156,116 @@ end
function plan_fft!(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N}
K = CUFFT_FORWARD
inplace = true
region = Tuple(region)
R = length(region)
region = NTuple{R,Int}(region)

md = plan_max_dims(region, size(X))
sizex = size(X)[1:md]
handle = cufftGetPlan(T, T, sizex, region)

CuFFTPlan{T,T,K,inplace,N}(handle, X, size(X), region)
CuFFTPlan{T,T,K,inplace,N,R,Nothing}(handle, X, size(X), region, nothing)
end


function plan_bfft!(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N}
K = CUFFT_INVERSE
inplace = true
region = Tuple(region)
R = length(region)
region = NTuple{R,Int}(region)

md = plan_max_dims(region, size(X))
sizex = size(X)[1:md]
handle = cufftGetPlan(T, T, sizex, region)

CuFFTPlan{T,T,K,inplace,N}(handle, X, size(X), region)
CuFFTPlan{T,T,K,inplace,N,R,Nothing}(handle, X, size(X), region, nothing)
end

# out-of-place complex
function plan_fft(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N}
K = CUFFT_FORWARD
inplace = false
region = Tuple(region)
R = length(region)
region = NTuple{R,Int}(region)

md = plan_max_dims(region,size(X))
sizex = size(X)[1:md]
handle = cufftGetPlan(T, T, sizex, region)

CuFFTPlan{T,T,K,inplace,N}(handle, X, size(X), region)
CuFFTPlan{T,T,K,inplace,N,R,Nothing}(handle, X, size(X), region, nothing)
end

function plan_bfft(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N}
K = CUFFT_INVERSE
inplace = false
region = Tuple(region)
R = length(region)
region = NTuple{R,Int}(region)

md = plan_max_dims(region,size(X))
sizex = size(X)[1:md]
handle = cufftGetPlan(T, T, sizex, region)

CuFFTPlan{T,T,K,inplace,N}(handle, size(X), size(X), region)
CuFFTPlan{T,T,K,inplace,N,R,Nothing}(handle, size(X), size(X), region, nothing)
end

# out-of-place real-to-complex
function plan_rfft(X::DenseCuArray{T,N}, region) where {T<:cufftReals,N}
K = CUFFT_FORWARD
inplace = false
region = Tuple(region)
R = length(region)
region = NTuple{R,Int}(region)

md = plan_max_dims(region,size(X))
# X = front_view(X, md)
sizex = size(X)[1:md]

handle = cufftGetPlan(complex(T), T, sizex, region)

ydims = collect(size(X))
ydims[region[1]] = div(ydims[region[1]],2)+1
ydims[region[1]] = div(ydims[region[1]], 2) + 1

CuFFTPlan{complex(T),T,K,inplace,N}(handle, size(X), (ydims...,), region)
# The buffer is not needed for real-to-complex (`mul!`),
# but it’s required for complex-to-real (`ldiv!`).
buffer = CuArray{complex(T)}(undef, ydims...)
B = typeof(buffer)

CuFFTPlan{complex(T),T,K,inplace,N,R,B}(handle, size(X), (ydims...,), region, buffer)
end

function plan_brfft(X::DenseCuArray{T,N}, d::Integer, region::Any) where {T<:cufftComplexes,N}
# out-of-place complex-to-real
function plan_brfft(X::DenseCuArray{T,N}, d::Integer, region) where {T<:cufftComplexes,N}
K = CUFFT_INVERSE
inplace = false
region = Tuple(region)
R = length(region)
region = NTuple{R,Int}(region)

ydims = collect(size(X))
ydims[region[1]] = d

handle = cufftGetPlan(real(T), T, (ydims...,), region)

CuFFTPlan{real(T),T,K,inplace,N}(handle, size(X), (ydims...,), region)
buffer = CuArray{T}(undef, size(X))
B = typeof(buffer)

CuFFTPlan{real(T),T,K,inplace,N,R,B}(handle, size(X), (ydims...,), region, buffer)
end


# FIXME: plan_inv methods allocate needlessly (to provide type parameters)
# Perhaps use FakeArray types to avoid this.

function plan_inv(p::CuFFTPlan{T,S,CUFFT_INVERSE,inplace,N}
) where {T<:cufftNumber,S<:cufftNumber,N,inplace}
function plan_inv(p::CuFFTPlan{T,S,CUFFT_INVERSE,inplace,N,R,B}
) where {T<:cufftNumber,S<:cufftNumber,inplace,N,R,B}
md_osz = plan_max_dims(p.region, p.output_size)
sz_X = p.output_size[1:md_osz]
handle = cufftGetPlan(S, T, sz_X, p.region)
ScaledPlan(CuFFTPlan{S,T,CUFFT_FORWARD,inplace,N}(handle, p.output_size, p.input_size, p.region),
ScaledPlan(CuFFTPlan{S,T,CUFFT_FORWARD,inplace,N,R,B}(handle, p.output_size, p.input_size, p.region, p.buffer),
normalization(real(T), p.output_size, p.region))
end

function plan_inv(p::CuFFTPlan{T,S,CUFFT_FORWARD,inplace,N}
) where {T<:cufftNumber,S<:cufftNumber,N,inplace}
function plan_inv(p::CuFFTPlan{T,S,CUFFT_FORWARD,inplace,N,R,B}
) where {T<:cufftNumber,S<:cufftNumber,inplace,N,R,B}
md_isz = plan_max_dims(p.region, p.input_size)
sz_Y = p.input_size[1:md_isz]
handle = cufftGetPlan(S, T, sz_Y, p.region)
ScaledPlan(CuFFTPlan{S,T,CUFFT_INVERSE,inplace,N}(handle, p.output_size, p.input_size, p.region),
ScaledPlan(CuFFTPlan{S,T,CUFFT_INVERSE,inplace,N,R,B}(handle, p.output_size, p.input_size, p.region, p.buffer),
normalization(real(S), p.input_size, p.region))
end

Expand Down Expand Up @@ -309,10 +327,14 @@ function LinearAlgebra.mul!(y::DenseCuArray{T}, p::CuFFTPlan{T,S,K,inplace}, x::
) where {T,S,K,inplace}
assert_applicable(p, x, y)
if !inplace && T<:Real
# Out-of-place complex-to-real FFT will always overwrite input buffer.
x = copy(x)
# Out-of-place complex-to-real FFT will always overwrite input x.
# We copy the input x into an auxiliary buffer.
z = p.buffer
copyto!(z, x)
else
z = x
end
unsafe_execute_trailing!(p, x, y)
unsafe_execute_trailing!(p, z, y)
y
end

Expand All @@ -323,13 +345,21 @@ function Base.:(*)(p::CuFFTPlan{T,S,K,true}, x::DenseCuArray{S}) where {T,S,K}
end

function Base.:(*)(p::CuFFTPlan{T,S,K,false}, x::DenseCuArray{S1,M}) where {T,S,K,S1,M}
if S1 != S || T<:Real
# Convert to the expected input type. Also,
# Out-of-place complex-to-real FFT will always overwrite input buffer.
x = copy1(S, x)
if T<:Real
# Out-of-place complex-to-real FFT will always overwrite input x.
# We copy the input x into an auxiliary buffer.
z = p.buffer
copyto!(z, x)
else
if S1 != S
# Convert to the expected input type.
z = copy1(S, x)
else
z = x
end
end
assert_applicable(p, x)
assert_applicable(p, z)
y = CuArray{T,M}(undef, p.output_size)
unsafe_execute_trailing!(p, x, y)
unsafe_execute_trailing!(p, z, y)
y
end