diff --git a/src/ThreadsX.jl b/src/ThreadsX.jl index ded807ad..9a1c5c58 100644 --- a/src/ThreadsX.jl +++ b/src/ThreadsX.jl @@ -82,6 +82,7 @@ else foldxt(rf, xs; kw...) = reduce(rf, Map(identity), xs; kw...) end +include("debug.jl") include("utils.jl") include("basesizes.jl") include("reduce.jl") diff --git a/src/debug.jl b/src/debug.jl new file mode 100644 index 00000000..1e0843ff --- /dev/null +++ b/src/debug.jl @@ -0,0 +1,13 @@ +""" + @DBG expression_to_run_only_when_debugging +""" +macro DBG(ex) + quote + isdebugging() && $(esc(ex)) + nothing + end +end + +isdebugging() = false +enable_debug() = (@eval isdebugging() = true; nothing) +disable_debug() = (@eval isdebugging() = false; nothing) diff --git a/src/quicksort.jl b/src/quicksort.jl index 6e54c9fb..88b361cd 100644 --- a/src/quicksort.jl +++ b/src/quicksort.jl @@ -62,6 +62,7 @@ function _quicksort!( xs[end], ), ) + chunksize = alg.basesize # TODO: Calculate extrema during the first pass if it's possible # to use counting sort. @@ -69,33 +70,56 @@ function _quicksort!( # first pass. # Compute sizes of each partition for each chunks. - chunks = zip(_partition(xs, alg.basesize), _partition(cs, alg.basesize)) - results = maptasks(partition_sizes!(pivot, order), chunks) - nbelows::Vector{Int} = map(first, results) - nequals::Vector{Int} = map(last, results) - naboves::Vector{Int} = - [length(c) - (b + e) for (b, e, (c, _)) in zip(nbelows, nequals, chunks)] - @check length(chunks) == length(nbelows) == length(nequals) == length(naboves) - @check all(>=(0), naboves) - singleton_chunkid = map(nbelows, nequals, naboves) do nb, ne, na - if (nb > 0) + (ne > 0) + (na > 0) == 1 - return 1 * (nb > 0) + 2 * (ne > 0) + 3 * (na > 0) - else - return 0 - end + xs_chunk_list = _partition(xs, chunksize) + cs_chunk_list = _partition(cs, chunksize) + nchunks = cld(length(xs), chunksize) + nbelows = Vector{Int}(undef, nchunks) + nequals = Vector{Int}(undef, nchunks) + naboves = Vector{Int}(undef, nchunks) + @DBG begin + VERSION >= v"1.4" && + @check length(xs_chunk_list) == length(cs_chunk_list) == nchunks + fill!(nbelows, -1) + fill!(nequals, -1) + fill!(naboves, -1) + end + @sync for (nb, ne, na, xs_chunk, cs_chunk) in zip( + referenceable(nbelows), + referenceable(nequals), + referenceable(naboves), + xs_chunk_list, + cs_chunk_list, + ) + @spawn partition_sizes!(nb, ne, na, xs_chunk, cs_chunk, pivot, order) + end + @DBG begin + @check all(>=(0), nbelows) + @check all(>=(0), nequals) + @check all(>=(0), naboves) end - below_offsets = copy(nbelows) - equal_offsets = copy(nequals) - above_offsets = copy(naboves) + below_offsets = nbelows + equal_offsets = nequals + above_offsets = naboves acc = exclusive_cumsum!(below_offsets) acc = exclusive_cumsum!(equal_offsets, acc) acc = exclusive_cumsum!(above_offsets, acc) @check acc == length(xs) + @inline function singleton_chunkid(i) + nb = @inbounds get(below_offsets, i + 1, equal_offsets[1]) - below_offsets[i] + ne = @inbounds get(equal_offsets, i + 1, above_offsets[1]) - equal_offsets[i] + na = @inbounds get(above_offsets, i + 1, length(ys)) - above_offsets[i] + if (nb > 0) + (ne > 0) + (na > 0) == 1 + return 1 * (nb > 0) + 2 * (ne > 0) + 3 * (na > 0) + else + return 0 + end + end + @sync begin - for (i, (xs_chunk, cs_chunk)) in enumerate(chunks) - singleton_chunkid[i] > 0 && continue + for (i, (xs_chunk, cs_chunk)) in enumerate(zip(xs_chunk_list, cs_chunk_list)) + singleton_chunkid(i) > 0 && continue @spawn unsafe_quicksort_scatter!( ys, xs_chunk, @@ -105,13 +129,14 @@ function _quicksort!( above_offsets[i], ) end - for (i, (xs_chunk, _)) in enumerate(chunks) - singleton_chunkid[i] > 0 || continue + for (i, xs_chunk) in enumerate(xs_chunk_list) + sid = singleton_chunkid(i) + sid > 0 || continue idx = ( below_offsets[i]+1:get(below_offsets, i + 1, equal_offsets[1]), equal_offsets[i]+1:get(equal_offsets, i + 1, above_offsets[1]), above_offsets[i]+1:get(above_offsets, i + 1, length(ys)), - )[singleton_chunkid[i]] + )[sid] # There is only one partition. Short-circuit scattering. ys_chunk = view(ys, idx) copyto!(ys_chunk, xs_chunk) @@ -226,7 +251,13 @@ function _quicksort_serial!( return ys_is_result ? ys : xs end -partition_sizes!(pivot, order) = ((xs, cs),) -> partition_sizes!(xs, cs, pivot, order) +function partition_sizes!(nbelows, nequals, naboves, xs, cs, pivot, order) + (nb, ne) = partition_sizes!(xs, cs, pivot, order) + nbelows[] = nb + nequals[] = ne + naboves[] = length(xs) - (nb + ne) + return +end function partition_sizes!(xs, cs, pivot, order) nbelows = 0