From 510dce676641b723b473960582a3540574215c1e Mon Sep 17 00:00:00 2001 From: Bruno Roustant <33934988+bruno-roustant@users.noreply.github.com> Date: Wed, 3 Nov 2021 09:50:16 +0100 Subject: [PATCH] GH-34: Optimize QuickSort inner loops. --- .../carrotsearch/hppc/sorting/QuickSort.java | 123 +++++++++--------- 1 file changed, 60 insertions(+), 63 deletions(-) diff --git a/hppc/src/main/java/com/carrotsearch/hppc/sorting/QuickSort.java b/hppc/src/main/java/com/carrotsearch/hppc/sorting/QuickSort.java index 6e8e0b51..ccfe7308 100644 --- a/hppc/src/main/java/com/carrotsearch/hppc/sorting/QuickSort.java +++ b/hppc/src/main/java/com/carrotsearch/hppc/sorting/QuickSort.java @@ -22,11 +22,8 @@ public final class QuickSort { /** Below this size threshold, the sub-range is sorted using Insertion sort. */ static final int INSERTION_SORT_THRESHOLD = 16; - /** - * Below this size threshold, the partition selection is simplified to taking the middle as the - * pivot. - */ - static final int MIDDLE_PIVOT_THRESHOLD = 40; + /** /** Below this size threshold, the partition selection is simplified to a single median. */ + static final int SINGLE_MEDIAN_THRESHOLD = 40; /** No instantiation. */ private QuickSort() { @@ -65,107 +62,107 @@ public static void sort(int[] array, int fromIndex, int toIndex, IntBinaryOperat */ public static void sort( int fromIndex, int toIndex, IntBinaryOperator comparator, IntBinaryOperator swapper) { - sortInner(fromIndex, toIndex - 1, comparator, swapper); - } - - /** - * @param start Start index, inclusive. - * @param end End index, inclusive. - * @param c Compares elements based on their indices. - * @param s Swaps the elements in the array at the given indices. - */ - private static void sortInner(int start, int end, IntBinaryOperator c, IntBinaryOperator s) { int size; - while ((size = end - start + 1) > INSERTION_SORT_THRESHOLD) { + while ((size = toIndex - fromIndex) > INSERTION_SORT_THRESHOLD) { // Pivot selection. - int middle = (start + end) >>> 1; + int last = toIndex - 1; + int middle = (fromIndex + last) >>> 1; int pivot; - if (size <= MIDDLE_PIVOT_THRESHOLD) { - // Select the middle as the pivot. - // If we select the median of [start, middle, end] as the pivot there is a performance - // degradation if the array is in descending order. - pivot = middle; + if (size <= SINGLE_MEDIAN_THRESHOLD) { + // Select the pivot with a single median around the middle element. + // Do not take the median between [from, mid, last] because it hurts performance + // if the order is descending. + int range = size >> 2; + pivot = median(middle - range, middle, middle + range, comparator); } else { // Select the pivot with the median of medians. int range = size >> 3; int doubleRange = range << 1; - int medianStart = median(start, start + range, start + doubleRange, c); - int medianMiddle = median(middle - range, middle, middle + range, c); - int medianEnd = median(end - doubleRange, end - range, end, c); - pivot = median(medianStart, medianMiddle, medianEnd, c); + int medianStart = median(fromIndex, fromIndex + range, fromIndex + doubleRange, comparator); + int medianMiddle = median(middle - range, middle, middle + range, comparator); + int medianEnd = median(last - doubleRange, last - range, last, comparator); + pivot = median(medianStart, medianMiddle, medianEnd, comparator); } - // 3-way partitioning. - swap(start, pivot, s); - int i = start; - int j = end + 1; - int p = start + 1; - int q = end; + // Bentley-McIlroy 3-way partitioning. + swap(fromIndex, pivot, swapper); + int i = fromIndex; + int j = toIndex; + int p = fromIndex + 1; + int q = last; while (true) { - while (++i < end && comp(i, start, c) < 0) ; - while (--j > start && comp(j, start, c) > 0) ; + int leftCmp, rightCmp; + while ((leftCmp = compare(++i, fromIndex, comparator)) < 0) ; + while ((rightCmp = compare(--j, fromIndex, comparator)) > 0) ; if (i >= j) { - if (i == j && comp(i, start, c) == 0) { - swap(i, p, s); + if (i == j && rightCmp == 0) { + swap(i, p, swapper); } break; } - swap(i, j, s); - if (comp(i, start, c) == 0) { - swap(i, p++, s); + swap(i, j, swapper); + if (rightCmp == 0) { + swap(i, p++, swapper); } - if (comp(j, start, c) == 0) { - swap(j, q--, s); + if (leftCmp == 0) { + swap(j, q--, swapper); } } i = j + 1; - for (int k = start; k < p; k++) { - swap(k, j--, s); + for (int k = fromIndex; k < p; ) { + swap(k++, j--, swapper); } - for (int k = end; k > q; k--) { - swap(k, i++, s); + for (int k = last; k > q; ) { + swap(k--, i++, swapper); } // Recursion on the smallest partition. // Replace the tail recursion by a loop. - if (j - start < end - i) { - sortInner(start, j, c, s); - start = i; + if (j - fromIndex < last - i) { + sort(fromIndex, j + 1, comparator, swapper); + fromIndex = i; } else { - sortInner(i, end, c, s); - end = j; + sort(i, toIndex, comparator, swapper); + toIndex = j + 1; } } - insertionSort(start, end, c, s); + insertionSort(fromIndex, toIndex, comparator, swapper); } - /** Sorts from start to end indices inclusive with insertion sort. */ - private static void insertionSort(int start, int end, IntBinaryOperator c, IntBinaryOperator s) { - for (int i = start + 1; i <= end; i++) { - for (int j = i; j > start && comp(j - 1, j, c) > 0; j--) { - swap(j - 1, j, s); + /** Sorts between from (inclusive) and to (exclusive) with insertion sort. */ + private static void insertionSort( + int fromIndex, int toIndex, IntBinaryOperator comparator, IntBinaryOperator swapper) { + for (int i = fromIndex + 1; i < toIndex; ) { + int current = i++; + int previous; + while (compare((previous = current - 1), current, comparator) > 0) { + swap(previous, current, swapper); + if (previous == fromIndex) { + break; + } + current = previous; } } } /** Returns the index of the median element among three elements at provided indices. */ - private static int median(int i, int j, int k, IntBinaryOperator c) { - if (comp(i, j, c) < 0) { - if (comp(j, k, c) <= 0) { + private static int median(int i, int j, int k, IntBinaryOperator comparator) { + if (compare(i, j, comparator) < 0) { + if (compare(j, k, comparator) <= 0) { return j; } - return comp(i, k, c) < 0 ? k : i; + return compare(i, k, comparator) < 0 ? k : i; } - if (comp(j, k, c) >= 0) { + if (compare(j, k, comparator) >= 0) { return j; } - return comp(i, k, c) < 0 ? i : k; + return compare(i, k, comparator) < 0 ? i : k; } /** Compares two elements at provided indices. */ - private static int comp(int i, int j, IntBinaryOperator comparator) { + private static int compare(int i, int j, IntBinaryOperator comparator) { return comparator.applyAsInt(i, j); }