linux/lib/sort.c
Kuan-Wei Chiu 41ed780435 lib/sort: optimize heapsort for handling final 2 or 3 elements
After building the heap, the code continuously pops two elements from the
heap until only 2 or 3 elements remain, at which point it switches back to
a regular heapsort with one element popped at a time.  However, to handle
the final 2 or 3 elements, an additional else-if statement in the while
loop was introduced, potentially increasing branch misses.  Moreover, when
there are only 2 or 3 elements left, continuing with regular heapify
operations is unnecessary as these cases are simple enough to be handled
with a single comparison and 1 or 2 swaps outside the while loop.

Eliminating the additional else-if statement and directly managing cases
involving 2 or 3 elements outside the loop reduces unnecessary conditional
branches resulting from the numerous loops and conditionals in heapify.

This optimization maintains consistent numbers of comparisons and swaps
for arrays with even lengths while reducing swaps and comparisons for
arrays with odd lengths from 2.5 swaps and 1 comparison to 1.5 swaps and 1
comparison.

Link: https://lkml.kernel.org/r/20240527203011.1644280-4-visitorckw@gmail.com
Signed-off-by: Kuan-Wei Chiu <visitorckw@gmail.com>
Cc: Ching-Chun (Jim) Huang <jserv@ccns.ncku.edu.tw>
Signed-off-by: Andrew Morton <akpm@linux-foundation.org>
2024-06-24 22:25:03 -07:00

303 lines
9.3 KiB
C

// SPDX-License-Identifier: GPL-2.0
/*
* A fast, small, non-recursive O(n log n) sort for the Linux kernel
*
* This performs n*log2(n) + 0.37*n + o(n) comparisons on average,
* and 1.5*n*log2(n) + O(n) in the (very contrived) worst case.
*
* Quicksort manages n*log2(n) - 1.26*n for random inputs (1.63*n
* better) at the expense of stack usage and much larger code to avoid
* quicksort's O(n^2) worst case.
*/
#include <linux/types.h>
#include <linux/export.h>
#include <linux/sort.h>
/**
 * is_aligned - check whether word-wide copying may be used for the elements
 * @base: pointer to data
 * @size: size of each element
 * @align: required alignment (typically 4 or 8)
 *
 * Returns true if elements can be copied using word loads and stores.
 * @size must be a multiple of @align.  The base address must be aligned
 * as well, unless CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS is set, in which
 * case @base is not looked at.
 *
 * Only the low eight bits of @size and @base take part in the test, which
 * is sufficient for the small power-of-two alignments used by the callers.
 *
 * The bits of @size and @base are OR-ed together before masking because,
 * for some reason, gcc doesn't know to turn "if (a & mask || b & mask)"
 * into "if ((a | b) & mask)" on its own.
 */
__attribute_const__ __always_inline
static bool is_aligned(const void *base, size_t size, unsigned char align)
{
	const unsigned char mask = align - 1;
	unsigned char low_bits = (unsigned char)size;

	/* Keeps @base "used" when the #ifndef section below is compiled out */
	(void)base;
#ifndef CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS
	low_bits |= (unsigned char)(uintptr_t)base;
#endif
	return !(low_bits & mask);
}
/**
 * swap_words_32 - swap two elements in 32-bit chunks
 * @a: pointer to the first element to swap
 * @b: pointer to the second element to swap
 * @n: element size (must be a non-zero multiple of 4)
 *
 * Exchange the two objects in memory, one 32-bit word at a time, walking
 * from the last word of each element down to the first.  Both elements
 * are addressed as base + @n so that only a single counter is updated
 * per iteration.
 *
 * The loop is a do/while: it always performs at least one exchange, so
 * @n must not be zero on entry.
 *
 * For some reason, on x86 gcc 7.3.0 adds a redundant test of n at the
 * bottom of the loop, even though the zero flag is still valid from the
 * subtract.  Gcc 8.1.0 doesn't have that problem.
 */
static void swap_words_32(void *a, void *b, size_t n)
{
	do {
		u32 *pa, *pb, tmp;

		n -= 4;
		pa = a + n;
		pb = b + n;

		tmp = *pa;
		*pa = *pb;
		*pb = tmp;
	} while (n);
}
/**
 * swap_words_64 - swap two elements in 64-bit chunks
 * @a: pointer to the first element to swap
 * @b: pointer to the second element to swap
 * @n: element size (must be a non-zero multiple of 8)
 *
 * Exchange the two objects in memory, eight bytes per loop iteration,
 * walking from the end of each element down to its start.  Both elements
 * are addressed as base + @n so that only a single counter is updated.
 *
 * With CONFIG_64BIT each iteration is one 64-bit load/store pair per
 * element.  Without it, 64-bit loads may not be available, and emulating
 * one requires base+index+4 addressing, which x86 has but most other
 * processors do not; so the iteration is done as two independent 32-bit
 * exchanges instead.  (It is possible to have 64-bit loads without 64-bit
 * pointers, e.g. the x32 ABI; that case takes the 32-bit path here.)
 *
 * The loop is a do/while, so @n must not be zero on entry.
 */
static void swap_words_64(void *a, void *b, size_t n)
{
	do {
#ifdef CONFIG_64BIT
		u64 *pa, *pb, tmp;

		n -= 8;
		pa = a + n;
		pb = b + n;

		tmp = *pa;
		*pa = *pb;
		*pb = tmp;
#else
		/* Use two 32-bit transfers to avoid base+index+4 addressing */
		u32 *pa, *pb, tmp;

		/* word at the higher address first ... */
		n -= 4;
		pa = a + n;
		pb = b + n;
		tmp = *pa;
		*pa = *pb;
		*pb = tmp;

		/* ... then the word just below it */
		n -= 4;
		pa = a + n;
		pb = b + n;
		tmp = *pa;
		*pa = *pb;
		*pb = tmp;
#endif
	} while (n);
}
/**
 * swap_bytes - swap two elements a byte at a time
 * @a: pointer to the first element to swap
 * @b: pointer to the second element to swap
 * @n: element size in bytes (must be non-zero)
 *
 * This is the fallback if alignment doesn't allow using larger chunks.
 * Bytes are exchanged from index @n - 1 down to index 0.  The loop is a
 * do/while that decrements before testing, so @n must be at least 1.
 */
static void swap_bytes(void *a, void *b, size_t n)
{
	char *first = a;
	char *second = b;

	do {
		char held;

		n--;
		held = first[n];
		first[n] = second[n];
		second[n] = held;
	} while (n != 0);
}
/*
 * Sentinel values stored in a swap_r_func_t to select one of the built-in
 * swap routines (or the struct wrapper indirection) instead of calling
 * through a real function pointer; do_swap() compares against them.
 *
 * The values are arbitrary as long as they can't be confused with
 * a pointer, but small integers make for the smallest compare
 * instructions.
 *
 * Note that SWAP_WORDS_64 has the same value as a NULL function pointer.
 * sort_r() treats a NULL swap_func as "pick a built-in swap" and assigns
 * one of the first three sentinels before any swap is dispatched.
 *
 * Each expansion is fully parenthesised so that it behaves as a single
 * primary expression wherever it is used.
 */
#define SWAP_WORDS_64 ((swap_r_func_t)0)
#define SWAP_WORDS_32 ((swap_r_func_t)1)
#define SWAP_BYTES    ((swap_r_func_t)2)
#define SWAP_WRAPPER  ((swap_r_func_t)3)
/*
 * Carries the callbacks given to sort() through sort_r()'s @priv pointer.
 * do_cmp() and do_swap() cast @priv back to this type when they see the
 * _CMP_WRAPPER / SWAP_WRAPPER sentinels.
 */
struct wrapper {
/* comparison callback; do_cmp() calls it as cmp(a, b) */
cmp_func_t cmp;
/* swap callback, called as swap(a, b, size); sort_r() checks it for NULL */
swap_func_t swap;
};
/*
 * do_swap - exchange two elements using the selected swap method
 * @a: pointer to the first element
 * @b: pointer to the second element
 * @size: size of each element in bytes
 * @swap_func: a SWAP_* sentinel or a real swap_r_func_t callback
 * @priv: a struct wrapper when @swap_func is SWAP_WRAPPER; otherwise it
 *        is forwarded as the last argument of the callback
 *
 * The sentinels are tested in the order SWAP_WRAPPER, SWAP_WORDS_64,
 * SWAP_WORDS_32, SWAP_BYTES; anything else is called as a function.
 * Callbacks receive @size converted to int.
 *
 * The function pointer is last to make tail calls most efficient if the
 * compiler decides not to inline this function.
 */
static void do_swap(void *a, void *b, size_t size, swap_r_func_t swap_func, const void *priv)
{
	const struct wrapper *w;

	if (swap_func == SWAP_WRAPPER) {
		w = priv;
		w->swap(a, b, (int)size);
		return;
	}

	if (swap_func == SWAP_WORDS_64) {
		swap_words_64(a, b, size);
		return;
	}
	if (swap_func == SWAP_WORDS_32) {
		swap_words_32(a, b, size);
		return;
	}
	if (swap_func == SWAP_BYTES) {
		swap_bytes(a, b, size);
		return;
	}

	/* A genuine caller-supplied callback */
	swap_func(a, b, (int)size, priv);
}
/* Sentinel telling do_cmp() that the real comparator lives in a struct wrapper */
#define _CMP_WRAPPER ((cmp_r_func_t)0L)

/*
 * do_cmp - compare two elements using the selected comparison method
 * @a: pointer to the first element
 * @b: pointer to the second element
 * @cmp: _CMP_WRAPPER, or a real cmp_r_func_t callback
 * @priv: a struct wrapper when @cmp is _CMP_WRAPPER; otherwise it is
 *        forwarded as the third argument of @cmp
 *
 * Returns whatever the underlying comparison function returns.
 */
static int do_cmp(const void *a, const void *b, cmp_r_func_t cmp, const void *priv)
{
	const struct wrapper *w = priv;

	if (cmp != _CMP_WRAPPER)
		return cmp(a, b, priv);

	return w->cmp(a, b);
}
/**
 * parent - given the offset of the child, find the offset of the parent.
 * @i: the offset of the heap element whose parent is sought.  Non-zero.
 * @lsbit: a precomputed 1-bit mask, equal to "size & -size"
 * @size: size of each element
 *
 * In terms of array indexes, the parent of element j = @i/@size is simply
 * (j-1)/2.  Working in byte offsets, though, a plain divide by two does
 * not truncate at an element boundary, so the offset has to be rounded
 * down to an even multiple of @size first.
 *
 * Only one bit of the quotient @i/@size is needed for that, not the full
 * divide.  @lsbit is the least significant set bit of @size; that bit is
 * clear in an even multiple of @size and set in an odd multiple.
 *
 * After stepping back one element, the code logically does
 * "if (i & lsbit) i -= size;".  Since that branch is unpredictable, it is
 * written branch-free: negating the isolated bit yields either zero or a
 * mask covering every bit of @size, which then selects 0 or @size as the
 * amount to subtract.
 */
__attribute_const__ __always_inline
static size_t parent(size_t i, unsigned int lsbit, size_t size)
{
	size_t odd;

	i -= size;		/* (j - 1) * size */
	odd = i & lsbit;	/* non-zero iff (j - 1) is odd */
	i -= size & -odd;	/* round down to an even multiple of size */
	return i / 2;
}
/**
 * sort_r - sort an array of elements
 * @base: pointer to data to sort
 * @num: number of elements
 * @size: size of each element
 * @cmp_func: pointer to comparison function
 * @swap_func: pointer to swap function or NULL
 * @priv: third argument passed to comparison function (and last argument
 *        passed to a caller-supplied swap function)
 *
 * This function does a heapsort on the given array.  You may provide
 * a swap_func function if you need to do something more than a memory
 * copy (e.g. fix up pointers or auxiliary data), but the built-in swap
 * avoids a slow retpoline and so is significantly faster.
 *
 * When @swap_func is NULL a built-in swap is chosen from the alignment of
 * @base and @size: 64-bit words, 32-bit words, or single bytes.
 *
 * Nothing is done when @num is less than 2 or @size is 0.
 *
 * Sorting time is O(n log n) both on average and worst-case. While
 * quicksort is slightly faster on average, it suffers from exploitable
 * O(n*n) worst-case behavior and extra memory requirements that make
 * it less suitable for kernel use.
 */
void sort_r(void *base, size_t num, size_t size,
	    cmp_r_func_t cmp_func,
	    swap_r_func_t swap_func,
	    const void *priv)
{
	/* pre-scale counters for performance */
	size_t n = num * size, a = (num/2) * size;
	const unsigned int lsbit = size & -size;  /* Used to find parent */
	/*
	 * 0 while the heap is being built, so "a" steps back one element at
	 * a time.  During extraction it is 0 or 1 and selects which child of
	 * the root (offset size << shift) held the second-largest element.
	 */
	size_t shift = 0;

	if (!a)		/* num < 2 || size == 0 */
		return;

	/* called from 'sort' without swap function, let's pick the default */
	if (swap_func == SWAP_WRAPPER && !((const struct wrapper *)priv)->swap)
		swap_func = NULL;

	if (!swap_func) {
		if (is_aligned(base, size, 8))
			swap_func = SWAP_WORDS_64;
		else if (is_aligned(base, size, 4))
			swap_func = SWAP_WORDS_32;
		else
			swap_func = SWAP_BYTES;
	}

	/*
	 * Loop invariants:
	 * 1. elements [a,n) satisfy the heap property (compare greater than
	 *    all of their children),
	 * 2. elements [n,num*size) are sorted, and
	 * 3. a <= b <= c <= d <= n (whenever they are valid).
	 */
	for (;;) {
		size_t b, c, d;

		if (a)			/* Building heap: sift down a */
			a -= size << shift;
		else if (n > 3 * size) { /* Sorting: Extract two largest elements */
			/* The root is the largest: move it to its final slot */
			n -= size;
			do_swap(base, base + n, size, swap_func, priv);
			/*
			 * The second largest is one of the root's two children.
			 * shift becomes 1 when the second child compares >= the
			 * first, so "a" is the offset of the larger child.
			 */
			shift = do_cmp(base + size, base + 2 * size, cmp_func, priv) <= 0;
			a = size << shift;
			n -= size;
			do_swap(base + a, base + n, size, swap_func, priv);
			/*
			 * Fall through to sift down at "a".  The next iteration
			 * takes the "if (a)" branch, which brings "a" back to 0
			 * and sifts down the root as well.
			 */
		} else {		/* Sort complete */
			break;
		}

		/*
		 * Sift element at "a" down into heap.  This is the
		 * "bottom-up" variant, which significantly reduces
		 * calls to cmp_func(): we find the sift-down path all
		 * the way to the leaves (one compare per level), then
		 * backtrack to find where to insert the target element.
		 *
		 * Because elements tend to sift down close to the leaves,
		 * this uses fewer compares than doing two per level
		 * on the way down.  (A bit more than half as many on
		 * average, 3/4 worst-case.)
		 *
		 * "c" and "d" are the offsets of the first and second child
		 * of "b"; the loop runs while the second child is inside
		 * the heap and follows the child that compares greater.
		 */
		for (b = a; c = 2*b + size, (d = c + size) < n;)
			b = do_cmp(base + c, base + d, cmp_func, priv) > 0 ? c : d;
		if (d == n)	/* Special case last leaf with no sibling */
			b = c;

		/* Now backtrack from "b" to the correct location for "a" */
		while (b != a && do_cmp(base + a, base + b, cmp_func, priv) >= 0)
			b = parent(b, lsbit, size);
		c = b;			/* Where "a" belongs */
		while (b != a) {	/* Shift it into place */
			b = parent(b, lsbit, size);
			do_swap(base + b, base + c, size, swap_func, priv);
		}
	}

	/*
	 * Two or three elements are left in the heap.  Move the root (the
	 * largest of them) to the end.  If three were left, the remaining
	 * two still need a single compare and possibly one swap.
	 */
	n -= size;
	do_swap(base, base + n, size, swap_func, priv);
	if (n == size * 2 && do_cmp(base, base + size, cmp_func, priv) > 0)
		do_swap(base, base + size, size, swap_func, priv);
}
EXPORT_SYMBOL(sort_r);
/**
 * sort - sort an array of elements
 * @base: pointer to data to sort
 * @num: number of elements
 * @size: size of each element
 * @cmp_func: pointer to comparison function, called as cmp_func(a, b)
 * @swap_func: pointer to swap function, called as swap_func(a, b, size),
 *             or NULL to let sort_r() pick a built-in swap
 *
 * Convenience front end for sort_r() for callbacks that take no private
 * argument.  The two callbacks are packed into a struct wrapper on the
 * stack and handed to sort_r() as its @priv pointer, together with the
 * _CMP_WRAPPER and SWAP_WRAPPER sentinels that make do_cmp() and do_swap()
 * call through that structure.
 */
void sort(void *base, size_t num, size_t size,
	  cmp_func_t cmp_func,
	  swap_func_t swap_func)
{
	struct wrapper w = {
		.cmp  = cmp_func,
		.swap = swap_func,
	};

	/*
	 * Plain call: a return statement with an expression is not allowed
	 * in a function returning void.
	 */
	sort_r(base, num, size, _CMP_WRAPPER, SWAP_WRAPPER, &w);
}
EXPORT_SYMBOL(sort);