Skip to content

Commit

Permalink
Update FFTW multi-threading
Browse files Browse the repository at this point in the history
  • Loading branch information
kamesy committed Jul 22, 2022
1 parent 88dd622 commit 357b313
Show file tree
Hide file tree
Showing 12 changed files with 72 additions and 63 deletions.
10 changes: 7 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ authors = ["kamesy <ckames@physics.ubc.ca>"]
version = "0.2.0"

[deps]
CPUSummary = "2a0fbf3d-bb9c-48f3-b0a9-814d99fd7ab9"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
FastPow = "c0e83750-1142-43a8-81cf-6c956b72b4d1"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -13,18 +14,21 @@ Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588"
PolyesterWeave = "1d0040c9-8b98-4ee7-8388-3f51789ca0ad"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
SLEEFPirates = "476501e8-09a2-5ece-8869-fb82de89a1fa"
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
ThreadingUtilities = "8290d209-cae3-49c0-8002-c8c24d57dab5"
TiledIteration = "06e1c1a7-607b-532d-9fad-de7d9aa2abac"

[compat]
FFTW = "1.4.6"
CPUSummary = "0.1.23"
FFTW = "1.5.0"
FastPow = "0.1.0"
LinearMaps = "3.6.1"
NIfTI = "0.5.7"
Polyester = "0.6.8"
PolyesterWeave = "0.1.5"
Polyester = "0.6.14"
PolyesterWeave = "0.1.7"
SLEEFPirates = "0.6.31"
Static = "0.7.6"
StaticArrays = "1"
ThreadingUtilities = "0.5.0"
TiledIteration = "0.3.1"
Expand Down
63 changes: 37 additions & 26 deletions src/QSM.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@ module QSM


using Base.Threads: nthreads
using CPUSummary: num_cores
using FastPow: @fastpow
using LinearMaps: LinearMap
using NIfTI: NIVolume, niread, niwrite
using Polyester: @batch, num_cores
using Polyester: @batch
using PolyesterWeave: reset_workers!
using Printf: @printf
using SLEEFPirates: pow, sincos_fast
using Static: known
using StaticArrays: SVector
using ThreadingUtilities: initialize_task
using TiledIteration: EdgeIterator, TileIterator, padded_tilesize
Expand Down Expand Up @@ -38,9 +40,8 @@ include("inversion/inversion.jl")

function __init__()
@static if FFTW.fftw_provider == "fftw"
fftw_set_threading(:Polyester)
fftw_set_threading(:FFTW)
end
FFTW.set_num_threads(num_cores())
return nothing
end

Expand All @@ -60,6 +61,9 @@ end
#####
##### FFTW.jl
#####
const FFTW_NTHREADS = Ref{Int}(known(num_cores()))


@static if FFTW.fftw_provider == "fftw"
# modified `FFTW.spawnloop` to use Polyester for multi-threading
# https://github.com/JuliaMath/FFTW.jl/blob/v1.4.5/src/providers.jl#L49
Expand All @@ -70,32 +74,39 @@ end
return nothing
end

function fftw_set_threading(lib::Symbol = :Polyester)
if nthreads() > 1
if lib == :Polyester
cspawnloop = @cfunction(
_fftw_spawnloop,
Cvoid, (Ptr{Cvoid}, Ptr{Cvoid}, Csize_t, Cint, Ptr{Cvoid})
)
elseif lib == :Threads
cspawnloop = @cfunction(
FFTW.spawnloop,
Cvoid, (Ptr{Cvoid}, Ptr{Cvoid}, Csize_t, Cint, Ptr{Cvoid})
)
else
throw(ArgumentError("lib must be one of :Polyester or :Threads"))
end

ccall(
(:fftw_threads_set_callback, FFTW.libfftw3[]),
Cvoid, (Ptr{Cvoid}, Ptr{Cvoid}), cspawnloop, C_NULL
)
function fftw_set_threading(lib::Symbol = :FFTW)
lib (:FFTW, :Polyester, :Threads) ||
throw(ArgumentError("lib must be one of :FFTW, :Polyester or :Threads, got :$(lib)"))

if lib (:Polyester, :Threads) && nthreads() < 2
@warn "Cannot use $lib with FFTW. Defaulting to FFTW multi-threading" Threads.nthreads()
lib = :FFTW
end

ccall(
(:fftwf_threads_set_callback, FFTW.libfftw3f[]),
Cvoid, (Ptr{Cvoid}, Ptr{Cvoid}), cspawnloop, C_NULL
if lib == :Polyester
cspawnloop = @cfunction(
_fftw_spawnloop,
Cvoid, (Ptr{Cvoid}, Ptr{Cvoid}, Csize_t, Cint, Ptr{Cvoid})
)
elseif lib == :Threads
cspawnloop = @cfunction(
FFTW.spawnloop,
Cvoid, (Ptr{Cvoid}, Ptr{Cvoid}, Csize_t, Cint, Ptr{Cvoid})
)
else
cspawnloop = C_NULL
end

ccall(
(:fftw_threads_set_callback, FFTW.libfftw3[]),
Cvoid, (Ptr{Cvoid}, Ptr{Cvoid}), cspawnloop, C_NULL
)

ccall(
(:fftwf_threads_set_callback, FFTW.libfftw3f[]),
Cvoid, (Ptr{Cvoid}, Ptr{Cvoid}), cspawnloop, C_NULL
)

return nothing
end
end
Expand Down
2 changes: 1 addition & 1 deletion src/bgremove/ismv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ function _ismv!(
S = Array{T}(undef, sz_)
= Array{complex(T)}(undef, sz_)

FFTW.set_num_threads(num_cores())
FFTW.set_num_threads(FFTW_NTHREADS[])
P = plan_rfft(fp)
iP = inv(P)

Expand Down
2 changes: 1 addition & 1 deletion src/bgremove/pdf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ function _pdf!(
D = Array{T, 3}(undef, sz_)
= Array{complex(T), 3}(undef, sz_)

FFTW.set_num_threads(num_cores())
FFTW.set_num_threads(FFTW_NTHREADS[])
P = plan_rfft(fp)
iP = inv(P)

Expand Down
6 changes: 3 additions & 3 deletions src/bgremove/sharp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ function _sharp!(
S = Array{T}(undef, sz_)
= Array{complex(T)}(undef, sz_)

FFTW.set_num_threads(num_cores())
FFTW.set_num_threads(FFTW_NTHREADS[])
P = plan_rfft(fp)
iP = inv(P)

Expand Down Expand Up @@ -212,7 +212,7 @@ function _sharp!(

m = tfill!(m, 0)

FFTW.set_num_threads(num_cores())
FFTW.set_num_threads(FFTW_NTHREADS[])
P = plan_rfft(s)
iP = inv(P)

Expand Down Expand Up @@ -321,7 +321,7 @@ function _sharp!(

m = tfill!(m, 0)

FFTW.set_num_threads(num_cores())
FFTW.set_num_threads(FFTW_NTHREADS[])
P = plan_rfft(s)
iP = inv(P)

Expand Down
2 changes: 1 addition & 1 deletion src/inversion/direct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ function _kdiv!(
D = Array{T, 3}(undef, sz_)
= Array{complex(T), 3}(undef, sz_)

FFTW.set_num_threads(num_cores())
FFTW.set_num_threads(FFTW_NTHREADS[])
P = plan_rfft(fp)
iP = inv(P)

Expand Down
2 changes: 1 addition & 1 deletion src/inversion/nltv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ function _nltv!(
= Array{complex(T)}(undef, sz_) # in-place rfft var
= Array{complex(T)}(undef, sz_) # pre-computed rhs

FFTW.set_num_threads(num_cores())
FFTW.set_num_threads(FFTW_NTHREADS[])
P = plan_rfft(xp)
iP = inv(P)

Expand Down
2 changes: 1 addition & 1 deletion src/inversion/rts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ function _rts!(
M = Array{T}(undef, sz_) # abs(D) > δ
iA = Array{T}(undef, sz_) # 1 / (mu*M - rho*Δ)

FFTW.set_num_threads(num_cores())
FFTW.set_num_threads(FFTW_NTHREADS[])
P = plan_rfft(xp)
iP = inv(P)

Expand Down
2 changes: 1 addition & 1 deletion src/inversion/tv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ function _tv!(
= Array{complex(T)}(undef, sz_) # in-place rfft var
= Array{complex(T)}(undef, sz_) # pre-computed rhs

FFTW.set_num_threads(num_cores())
FFTW.set_num_threads(FFTW_NTHREADS[])
P = plan_rfft(xp)
iP = inv(P)

Expand Down
12 changes: 6 additions & 6 deletions src/utils/kernels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ function dipole_kernel(
return _dipole_kernel!(d, dsz, vsz, bdir, :i; shift=shift)

elseif transform == :fft
FFTW.set_num_threads(num_cores())
FFTW.set_num_threads(FFTW_NTHREADS[])
D = Array{T, 3}(undef, sz)
= Array{complex(T), 3}(undef, sz)
P = plan_fft!(D̂)
Expand All @@ -89,7 +89,7 @@ function dipole_kernel(
return shift ? fftshift(D) : D

elseif transform == :rfft
FFTW.set_num_threads(num_cores())
FFTW.set_num_threads(FFTW_NTHREADS[])
D = Array{T, 3}(undef, (sz[1]>>1 + 1, sz[2], sz[3]))
= Array{complex(T), 3}(undef, (sz[1]>>1 + 1, sz[2], sz[3]))
d = Array{T, 3}(undef, sz)
Expand Down Expand Up @@ -289,7 +289,7 @@ function smv_kernel(
return _smv_kernel!(s, vsz, r, shift=shift)

elseif transform == :fft
FFTW.set_num_threads(num_cores())
FFTW.set_num_threads(FFTW_NTHREADS[])
S = Array{T, 3}(undef, sz)
= Array{complex(T), 3}(undef, sz)
P = plan_fft!(Ŝ)
Expand All @@ -298,7 +298,7 @@ function smv_kernel(
return shift ? fftshift(S) : S

elseif transform == :rfft
FFTW.set_num_threads(num_cores())
FFTW.set_num_threads(FFTW_NTHREADS[])
S = Array{T, 3}(undef, (sz[1]>>1 + 1, sz[2], sz[3]))
= Array{complex(T), 3}(undef, (sz[1]>>1 + 1, sz[2], sz[3]))
s = Array{T, 3}(undef, sz)
Expand Down Expand Up @@ -465,7 +465,7 @@ function laplace_kernel(
return _laplace_kernel!(Δ, vsz, negative=negative, shift=shift)

elseif transform == :fft
FFTW.set_num_threads(num_cores())
FFTW.set_num_threads(FFTW_NTHREADS[])
L = Array{T, 3}(undef, sz)
= Array{complex(T), 3}(undef, sz)
P = plan_fft!(L̂)
Expand All @@ -474,7 +474,7 @@ function laplace_kernel(
return shift ? fftshift(L) : L

elseif transform == :rfft
FFTW.set_num_threads(num_cores())
FFTW.set_num_threads(FFTW_NTHREADS[])
L = Array{T, 3}(undef, (sz[1]>>1 + 1, sz[2], sz[3]))
= Array{complex(T), 3}(undef, (sz[1]>>1 + 1, sz[2], sz[3]))
Δ = Array{T, 3}(undef, sz)
Expand Down
30 changes: 12 additions & 18 deletions src/utils/poisson_solver/poisson_solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ function solve_poisson_dct(
dx::NTuple{3, Real}
) where {N}
N (3, 4) || throw(ArgumentError("arrays must be 3d or 4d, got $(N)d"))
return solve_poisson_dct!(tcopy(d2u), dx)
return solve_poisson_dct!(similar(d2u), d2u, dx)
end

function solve_poisson_dct!(
Expand All @@ -91,34 +91,26 @@ function solve_poisson_dct!(
) where {N}
N (3, 4) || throw(ArgumentError("arrays must be 3d or 4d, got $(N)d"))
size(u) == size(d2u) || throw(DimensionMismatch())
return solve_poisson_dct!(_tcopyto!(u, d2u), dx)
end

function solve_poisson_dct!(
d2u::AbstractArray{<:AbstractFloat, N},
dx::NTuple{3, Real}
) where {N}
N (3, 4) || throw(ArgumentError("arrays must be 3d or 4d, got $(N)d"))

nx, ny, nz = size(d2u)[1:3]
idx2 = inv(dx[1].*dx[1])
idy2 = inv(dx[2].*dx[2])
idz2 = inv(dx[3].*dx[3])

u = d2u

FFTW.set_num_threads(num_cores())
P = plan_dct!(u, 1:3)
# extreme slowdown for certain sizes with lots of threads
# even worse for in-place, ie dct!
FFTW.set_num_threads(max(1, FFTW_NTHREADS[]÷2))
P = plan_dct(u, 1:3)
iP = inv(P)

u = P*u
d2û = P*d2u

X = [2*(cospi(i)-1)*idx2 for i in range(0, step=1/nx, length=nx)]
Y = [2*(cospi(j)-1)*idy2 for j in range(0, step=1/ny, length=ny)]
Z = [2*(cospi(k)-1)*idz2 for k in range(0, step=1/nz, length=nz)]

@inbounds for t in axes(u, 4)
d2ût = @view(u[:,:,:,t])
@inbounds for t in axes(d2û, 4)
d2ût = @view(d2û[:,:,:,t])
@batch for k in 1:nz
for j in 1:ny
for i in 1:nx
Expand All @@ -131,7 +123,9 @@ function solve_poisson_dct!(
end

# inverse dct
return iP*u
u = mul!(u, iP, d2û)

return u
end


Expand All @@ -157,7 +151,7 @@ function solve_poisson_fft!(
idz2 = inv(dx[3].*dx[3])

_rfft = iseven(nx)
FFTW.set_num_threads(num_cores())
FFTW.set_num_threads(FFTW_NTHREADS[])

# FFTW's rfft is extremely slow with some odd lengths in the first dim
if _rfft
Expand Down
2 changes: 1 addition & 1 deletion src/utils/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ function psf2otf(
kp = circshift!(tzero(_kp), _kp, .-szk2)

# fft
FFTW.set_num_threads(num_cores())
FFTW.set_num_threads(FFTW_NTHREADS[])
P = T <: Real && rfft ? plan_rfft(kp) : plan_fft(kp)

K = P*kp
Expand Down

0 comments on commit 357b313

Please sign in to comment.