Skip to content

Commit

Permalink
implement bitonic sorting network for SVectors
Browse files Browse the repository at this point in the history
  • Loading branch information
stev47 committed Mar 23, 2020
1 parent 95f2578 commit 482af35
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/StaticArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ include("abstractarray.jl")
include("indexing.jl")
include("broadcast.jl")
include("mapreduce.jl")
include("sort.jl")
include("arraymath.jl")
include("linalg.jl")
include("matrix_multiply.jl")
Expand Down
71 changes: 71 additions & 0 deletions src/sort.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import Base.@_inline_meta
import Base.Order: Ordering, Forward, ReverseOrdering, ord
import Base.Sort: Algorithm, defalg, lt, sort


struct BitonicSortAlg <: Algorithm end

const BitonicSort = BitonicSortAlg()

# BitonicSort has non-optimal asymptotic behaviour, so we define a cutoff length.
# This also prevents compilation time to skyrocket for larger vectors.
defalg(a::StaticVector) = isimmutable(a) && length(a) <= 20 ? BitonicSort : QuickSort

@inline function sort(a::StaticVector;
alg::Algorithm = defalg(a),
lt = isless,
by = identity,
rev::Union{Bool,Nothing} = nothing,
order::Ordering = Forward)
length(a) <= 1 && return a
ordr = ord(lt, by, rev, order)
return _sort(Size(a), alg, ordr, a)
end

@inline _sort(_, alg, order, a::StaticVector) = sort!(Base.copymutable(a); alg=alg, order=order)
@inline _sort(_, alg::BitonicSortAlg, order, a::StaticVector) = similar_type(a)(_sort(Tuple(a), alg, order))

# Implementation loosely following
# https://www.inf.hs-flensburg.de/lang/algorithmen/sortieren/bitonic/oddn.htm
@generated function _sort(a::NTuple{N}, ::BitonicSortAlg, order) where N
function swap_expr(i, j, rev)
ai = Symbol('a', i)
aj = Symbol('a', j)
order = rev ? :revorder : :order
return :( ($ai, $aj) = lt($order, $ai, $aj) ? ($ai, $aj) : ($aj, $ai) )
end

function merge_exprs(idx, rev)
exprs = Expr[]
length(idx) == 1 && return exprs

ci = 2^(ceil(Int, log2(length(idx))) - 1)
# TODO: generate simd code for these swaps
for i in first(idx):last(idx)-ci
push!(exprs, swap_expr(i, i+ci, rev))
end
append!(exprs, merge_exprs(idx[1:ci], rev))
append!(exprs, merge_exprs(idx[ci+1:end], rev))
return exprs
end

function sort_exprs(idx, rev=false)
exprs = Expr[]
length(idx) == 1 && return exprs

append!(exprs, sort_exprs(idx[1:end÷2], !rev))
append!(exprs, sort_exprs(idx[end÷2+1:end], rev))
append!(exprs, merge_exprs(idx, rev))
return exprs
end

idx = 1:N
symlist = (Symbol('a', i) for i in idx)
return quote
@_inline_meta
revorder = Base.Order.ReverseOrdering(order)
($(symlist...),) = a
($(sort_exprs(idx)...);)
return ($(symlist...),)
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ include("abstractarray.jl")
include("indexing.jl")
include("initializers.jl")
Random.seed!(42); include("mapreduce.jl")
Random.seed!(42); include("sort.jl")
Random.seed!(42); include("accumulate.jl")
Random.seed!(42); include("arraymath.jl")
include("broadcast.jl")
Expand Down
20 changes: 20 additions & 0 deletions test/sort.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
using StaticArrays, Test

@testset "sort" begin

@testset "basics" for T in (Int, Float64)
for N in (0, 1, 2, 3, 10, 20)
v = rand(SVector{N,T})
vs = sort!(Base.copymutable(v))

@test vs == @inferred sort(v)
@test 0 == @allocated sort(v)
end
end

@testset "fallbacks" begin
@test @inferred(sort(rand(SVector{3}), alg=QuickSort)) isa MVector
@test @inferred(sort(rand(SVector{21}))) isa MVector
end

end

0 comments on commit 482af35

Please sign in to comment.