Use AcceleratedKernels for sorting #688

Merged 2 commits on Oct 16, 2024
4 changes: 3 additions & 1 deletion Project.toml
@@ -1,10 +1,11 @@
name = "AMDGPU"
uuid = "21141c5a-9bdb-4563-92ae-f87d6854732e"
authors = ["Julian P Samaroo <jpsamaroo@jpsamaroo.me>", "Valentin Churavy <v.churavy@gmail.com>", "Anton Smirnov <tonysmn97@gmail.com>"]
version = "1.0.4"
version = "1.0.5"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
AcceleratedKernels = "6a4ca0a5-0e36-4168-a932-d9be78d558f1"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Atomix = "a9b6321e-bd34-4604-b9c9-b65b8de01458"
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
@@ -34,6 +35,7 @@ UnsafeAtomicsLLVM = "d80eeb9a-aca5-4d75-85e5-170c8b632249"

[compat]
AbstractFFTs = "1.0"
AcceleratedKernels = "0.1.0"
Adapt = "4"
Atomix = "0.1"
CEnum = "0.4, 0.5"
1 change: 1 addition & 0 deletions src/AMDGPU.jl
@@ -9,6 +9,7 @@ using LLVM, LLVM.Interop
using Preferences
using Printf

import AcceleratedKernels as AK
import UnsafeAtomics
import UnsafeAtomicsLLVM
import Atomix
300 changes: 3 additions & 297 deletions src/kernels/sorting.jl
@@ -1,297 +1,3 @@
# Ported from CUDA.jl.
# Originally developed by @xaellison (Alex Ellison).

Base.sort!(x::AnyROCArray; kwargs...) = bitonic_sort!(x; kwargs...)

function Base.sortperm!(
ix::AnyROCArray, x::AnyROCArray;
initialized::Bool = false, kwargs...,
)
axes(ix) == axes(x) || throw(ArgumentError(
"Index array must have the same size as the source array, instead: " *
"$(size(ix)) vs $(size(x))."))

initialized || (ix .= LinearIndices(x);)
bitonic_sort!((x, ix); kwargs...)
return ix
end

function Base.sortperm(x::AnyROCArray; kwargs...)
sortperm!(ROCArray(1:length(x)), x; initialized=true, kwargs...)
end

# TODO dims
function bitonic_sort!(X; lt = isless, by = identity, rev::Bool = false)
_shmem(x::Tuple, groupsize) = prod(groupsize) * sum(sizeof.(eltype.(x)))
_shmem(x::AbstractArray, groupsize) = prod(groupsize) * sizeof(eltype(x))

len_x = typeof(X) <: Tuple ? length(X[1]) : length(X)
I = len_x ≤ typemax(Int32) ? Int32 : Int64
threads = min(256, prevpow(2, len_x))

# Compile kernels.
ker_1 = @roc launch=false cmp_small_kern!(
X, I(len_x), one(I), one(I), one(I), by, lt, Val(rev))
ker_2 = @roc launch=false cmp_ker!(
X, I(len_x), one(I), one(I), by, lt, Val(rev))

# Cutoff for when to use `ker_1` vs `ker_2`.
log_threads = Int(log2(threads))

k₀ = ceil(Int, log2(len_x))
for k in k₀:-1:1
j_end = k₀ - k + 1
for j in 1:j_end
if k₀ - k - j + 2 ≤ log_threads
pseudo_block_len = 1 << abs(j_end + 1 - j)
n_pseudo_blocks = nextpow(2, len_x) ÷ pseudo_block_len
pseudo_blocks_per_block = threads ÷ pseudo_block_len

gridsize = max(1, n_pseudo_blocks ÷ pseudo_blocks_per_block)
groupsize = (pseudo_block_len, threads ÷ pseudo_block_len)
ker_1(
X, I(len_x), I(k), I(j), I(j_end), by, lt, Val(rev);
gridsize, groupsize, shmem=_shmem(X, groupsize))
else
gridsize = cld(len_x, threads)
ker_2(
X, I(len_x), I(k), I(j), by, lt, Val(rev);
gridsize, groupsize=threads)
end
end
end
return X
end

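# Global-memory comparator kernel: each workitem handles one index of the
# current (k, j) bitonic stage, swapping elements idx and idx + m if out of order.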
function cmp_ker!(x, lenₓ::I, k::I, j::I, by, lt, rev) where I
idx::I = workgroupDim().x * (workgroupIdx().x - 0x1) + workitemIdx().x - 0x1
lo, n, dir = get_range(lenₓ, idx, k, j)

if !(lo < 0x0 || n < 0x0) && !(idx ≥ lenₓ)
m = gp2lt(n)
if lo ≤ idx < lo + n - m
i1, i2 = idx, idx + m
cmp!(x, i1, i2, dir, by, lt, rev)
end
end
return
end

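# Shared-memory variant: a block is staged into LDS (`swap`) and stages j₀ through jₑ
# run locally before the result is written back to global memory.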
function cmp_small_kern!(x, lenₓ::I, k::I, j₀::I, jₑ::I, by, lt, rev) where I
bidx::I = (workgroupIdx().x - 0x1) * workgroupDim().y + workitemIdx().y - 0x1
_lo, _n, dir = block_range(lenₓ, bidx, k, j₀)

idx = _lo + I(workitemIdx().x) - 0x1
in_range = workitemIdx().x ≤ _n && _lo ≥ 0x0
swap = init_shmem(x, idx, in_range)

lo, n = _lo, _n
for j in j₀:jₑ
if in_range && !(lo < 0x0 || n < 0x0)
m = gp2lt(n)
if lo ≤ idx < lo + n - m
i1, i2 = idx - _lo, idx - _lo + m
cmp_small!(swap, i1, i2, dir, by, lt, rev)
end
end
lo, n = bisect_range(idx, lo, n)
sync_workgroup()
end
finalize_shmem!(x, swap, idx, in_range)
return
end

function bisect_range(idx::I, lo::I, n::I) where I
n ≤ 0x1 && return -one(I), -one(I)

m = gp2lt(n)
if idx < lo + m
n = m
else
lo += m
n -= m
end
lo, n
end

function cmp!(
x::AbstractArray, i1::I, i2::I, dir::Bool, by, lt, rev,
) where I
i1, i2 = i1 + one(I), i2 + one(I)
@inbounds if dir != _lt_fn(by(x[i1]), by(x[i2]), lt, rev)
x[i1], x[i2] = x[i2], x[i1]
end
end

function cmp!(
X::Tuple, i1::I, i2::I, dir::Bool, by, lt, rev,
) where I
i1, i2 = i1 + one(I), i2 + one(I)
x, ix = X
cmp_res = _lt_fn(
(by(x[ix[i1]]), ix[i1]),
(by(x[ix[i2]]), ix[i2]), lt, rev)
@inbounds if dir != cmp_res
ix[i1], ix[i2] = ix[i2], ix[i1]
end
end

function cmp_small!(
swap::AbstractArray, i1::I, i2::I, dir::Bool, by, lt, rev,
) where I
i1, i2 = i1 + one(I), i2 + one(I)
@inbounds if dir != _lt_fn(by(swap[i1]), by(swap[i2]), lt, rev)
swap[i1], swap[i2] = swap[i2], swap[i1]
end
end

function cmp_small!(
swap::Tuple, i1::I, i2::I, dir::Bool, by, lt, rev,
) where I
i1, i2 = i1 + one(I), i2 + one(I)
x, ix = swap
cmp_res = _lt_fn(
(by(x[i1]), ix[i1]),
(by(x[i2]), ix[i2]), lt, rev)
@inbounds if dir != cmp_res
x[i1], x[i2] = x[i2], x[i1]
ix[i1], ix[i2] = ix[i2], ix[i1]
end
end

@inline function _lt_fn(a::T, b::T, lt, rev::Val{R}) where {T, R}
if R
lt(b, a)
else
lt(a, b)
end
end

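# (value, index) pairs: with `rev`, reverse the value comparison but break ties by
# ascending index; otherwise the plain tuple comparison (lexicographic under the
# default `isless`) already falls back to the index.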
@inline function _lt_fn(a::Tuple{T, J}, b::Tuple{T, J}, lt, rev::Val{R}) where {T, J, R}
if R
if a[1] == b[1]
return a[2] < b[2] # Compare indices.
else
return lt(b[1], a[1])
end
else
return lt(a, b)
end
end

function init_shmem(x::AbstractArray{T}, idx, in_range::Bool, offset=0) where T
swap = @ROCDynamicLocalArray(
T, (workgroupDim().x, workgroupDim().y), false, offset)
if in_range
@inbounds swap[workitemIdx().x, workitemIdx().y] = x[idx + 0x1]
end
sync_workgroup()
@inbounds @view(swap[:, workitemIdx().y])
end

function init_shmem(
X::Tuple{AbstractArray{T}, AbstractArray{J}}, idx, in_range::Bool,
) where {T, J}
x, ix = X
idx_swap = init_shmem(ix, idx, in_range)
offset = (workgroupDim().x * workgroupDim().y) * sizeof(J)
swap = init_shmem(x, idx_swap[workitemIdx().x] - 0x1, in_range, offset)
swap, idx_swap
end

"""
Copy `swap` back into global memory `x`.
"""
function finalize_shmem!(
x::AbstractArray, swap::AbstractArray, idx, in_range::Bool,
)
if in_range
@inbounds x[idx + 0x1] = swap[workitemIdx().x]
end
end

function finalize_shmem!(X::Tuple, swap::Tuple, idx, in_range::Bool)
x, ix = X
x_swap, idx_swap = swap
finalize_shmem!(ix, idx_swap, idx, in_range)
end

function get_range_part1(n::I, index::I, k::I) where I
lo = zero(I)
dir = true
for iter in one(I):(k - one(I))
if n ≤ one(I)
return -one(I), -one(I), false
end

if index < lo + n ÷ 0x2
n = n ÷ 0x2
dir = !dir
else
lo = lo + n ÷ 0x2
n = n - n ÷ 0x2
end
end
lo, n, dir
end

function get_range_part2(lo::I, n::I, index::I, j::I) where I
for iter in one(I):(j - one(I))
lo, n = bisect_range(index, lo, n)
end
lo, n
end

# Determine parameters for swapping.
function get_range(n, idx, k, j)
lo, n, dir = get_range_part1(n, idx, k)
lo, n = get_range_part2(lo, n, idx, j)
lo, n, dir
end

function block_range(n::I, bidx::I, k::I, j::I) where I
lo = zero(I)
dir = true
tmp = bidx * I(2)

# Part 1.
for i in one(I):(k - one(I))
tmp ÷= I(2)
n ≤ one(I) && return -one(I), -one(I), false

if tmp % I(2) == zero(I)
n ÷= I(2)
dir = !dir
else
lo += n ÷ I(2)
n -= n ÷ I(2)
end
end

# Part 2.
for i in one(I):(j - one(I))
tmp ÷= I(2)
n ≤ one(I) && return -one(I), -one(I), false

m = gp2lt(n)
if tmp % I(2) == zero(I)
n = m
else
lo += m
n -= m
end
end

(zero(I) ≤ n ≤ one(I)) && return -one(I), -one(I), false
lo, n, dir
end

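# Greatest power of two strictly less than `x`, via bit smearing
# (e.g. gp2lt(8) == 4, gp2lt(5) == 4).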
@inline function gp2lt(x::I)::I where I
x -= one(I)
x |= x >> 0x1
x |= x >> 0x2
x |= x >> 0x4
x |= x >> 0x8
x |= x >> 0x10
x ⊻ (x >> 0x1)
end
Base.sort!(x::AnyROCArray; kwargs...) = (AK.sort!(x; kwargs...); return x)
Base.sortperm!(ix::AnyROCArray, x::AnyROCArray; kwargs...) = (AK.sortperm!(ix, x; kwargs...); return ix)
Base.sortperm(x::AnyROCArray; kwargs...) = sortperm!(ROCArray(1:length(x)), x; kwargs...)
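
For context, a minimal usage sketch of the new AcceleratedKernels-backed entry points. It assumes a working ROCm device; the array contents are illustrative, and the `lt`/`by`/`rev` keywords are assumed to be whatever `AK.sort!`/`AK.sortperm!` accept, since they are forwarded unchanged.

using AMDGPU

x = ROCArray(rand(Float32, 1024))
sort!(x)                       # dispatches to AK.sort! on the GPU
@assert issorted(Array(x))

ix = sortperm(x; rev = true)   # index permutation via AK.sortperm!
@assert Array(x)[Array(ix)] == sort(Array(x); rev = true)
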
2 changes: 0 additions & 2 deletions test/core_tests.jl
@@ -84,10 +84,8 @@ end
end

include("codegen/codegen.jl")

include("rocarray/base.jl")
include("rocarray/broadcast.jl")

include("tls.jl")

end