Skip to content

Commit

Permalink
implement bitonic sorting network for SVectors
Browse files Browse the repository at this point in the history
  • Loading branch information
stev47 committed Mar 21, 2020
1 parent 95f2578 commit 2ca8a3e
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/StaticArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ include("abstractarray.jl")
include("indexing.jl")
include("broadcast.jl")
include("mapreduce.jl")
include("sort.jl")
include("arraymath.jl")
include("linalg.jl")
include("matrix_multiply.jl")
Expand Down
104 changes: 104 additions & 0 deletions src/sort.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import Base.@_inline_meta
import Base.Order: Ordering, Forward, ReverseOrdering, ord
import Base.Sort: Algorithm, defalg, lt, sort


struct BitonicSortAlg <: Algorithm end
struct MinSizeSortAlg <: Algorithm end
struct MinDepthSortAlg <: Algorithm end
const MinSortAlg = Union{MinSizeSortAlg,MinDepthSortAlg}

const BitonicSort = BitonicSortAlg()
const MinSizeSort = MinSizeSortAlg()
const MinDepthSort = MinDepthSortAlg()

defalg(::SVector) = BitonicSort

function sort(a::SVector;
alg::Algorithm = defalg(a),
lt = isless,
by = identity,
rev::Union{Bool,Nothing} = nothing,
order::Ordering = Forward)
ordr = ord(lt, by, rev, order)
_sort(Size(a), alg, ordr, a)
end

_sort(::Size{T}, alg, _, _) where T =
error("sorting algorithm $alg unimplemented for static array of size $T")


@inline _cmpswap(order, a, b) = lt(order, a, b) ? (a, b) : (b, a)

@inline _sort(::Size{(1,)}, _, _, a) = a
@inline _sort(::Size{(2,)}, _, order, (a1, a2)) = SVector(_cmpswap(order, a1, a2))

@inline _sort(::Size{(1,)}, ::BitonicSortAlg, _, a) = a
@inline _sort(s::Size{(2,)}, ::BitonicSortAlg, order, (a1, a2)) = SVector(_cmpswap(order, a1, a2))
@generated function _sort(::Size{S}, ::BitonicSortAlg, order, a) where {S}
function swap_expr(i, j, dir)
ai = Symbol('a', i)
aj = Symbol('a', j)
order = dir ? :revorder : :order
return :( ($ai, $aj) = _cmpswap($order, $ai, $aj) )
end

function merge_exprs(idx, dir)
exprs = Expr[]
length(idx) == 1 && return exprs

ci = 2^(ceil(Int, log2(length(idx))) - 1)
# TODO: generate simd code for these swaps
for i in first(idx):last(idx)-ci
push!(exprs, swap_expr(i, i+ci, dir))
end
append!(exprs, merge_exprs(idx[1:ci], dir))
append!(exprs, merge_exprs(idx[ci+1:end], dir))
return exprs
end

function sort_exprs(idx, dir)
exprs = Expr[]
length(idx) == 1 && return exprs

append!(exprs, sort_exprs(idx[1:end÷2], !dir))
append!(exprs, sort_exprs(idx[end÷2+1:end], dir))
append!(exprs, merge_exprs(idx, dir))
return exprs
end

idx = 1:prod(S)
symlist = (Symbol('a', i) for i in idx)
sym_exprs = (:( $ai = a[$i] ) for (i, ai) in enumerate(symlist))
return quote
@_inline_meta
revorder = Base.Order.ReverseOrdering(order)
@inbounds ($(sym_exprs...);)
($(sort_exprs(idx, false)...);)
return SVector(($(symlist...)))
end
end


## TODO: manually implementing minimal sorting networks for small lengths might
## be worthwhile
#
#macro _cmpswap(order, a, b)
# return esc(:( ($a, $b) = _cmpswap(order, $a, $b) ))
#end
#
#@inline function _sort(::Size{(3,)}, ::MinSortAlg, order, (a1, a2, a3))
# @_cmpswap order a1 a3
# @_cmpswap order a1 a2
# @_cmpswap order a2 a3
# return SVector(a1, a2, a3)
#end
#
#@inline function _sort(::Size{(4,)}, ::MinSortAlg, order, (a1, a2, a3, a4))
# @_cmpswap order a1 a3
# @_cmpswap order a2 a4
# @_cmpswap order a1 a2
# @_cmpswap order a3 a4
# @_cmpswap order a2 a3
# return SVector(a1, a2, a3, a4)
#end

0 comments on commit 2ca8a3e

Please sign in to comment.