Skip to content

Commit

Permalink
[huf] Improve fast C & ASM performance on small data
Browse files Browse the repository at this point in the history
* Rename `ilimit` to `ilowest` and set it equal to `src` instead of
  `src + 6 + 8`. This is safe because the fast decoding loops guarantee
  to never read below `ilowest` already. This allows the fast decoder to
  run for at least two more iterations, because it consumes at most 7
  bytes per iteration.
* Continue the fast loop all the way until the number of safe iterations
  is 0. Initially, I thought that as the loop approached the end, the
  computation of how many iterations are safe might become expensive. But
  it ends up being slower to have to decode each of the 4 streams
  individually, which makes sense.

This drastically speeds up the Huffman decoder on the `github` dataset
for the issue raised in facebook#3762, measured with `zstd -b1e1r github/`.

| Decoder  | Speed before | Speed after |
|----------|--------------|-------------|
| Fallback | 477 MB/s     | 477 MB/s    |
| Fast C   | 384 MB/s     | 492 MB/s    |
| Assembly | 385 MB/s     | 501 MB/s    |

We can also look at the speed delta for different block sizes of silesia
using `zstd -b1e1r silesia.tar -B#`.

| Decoder  | -B1K ∆ | -B2K ∆ | -B4K ∆ | -B8K ∆ | -B16K ∆ | -B32K ∆ | -B64K ∆ | -B128K ∆ |
|----------|--------|--------|--------|--------|---------|---------|---------|----------|
| Fast C   | +11.2% | +8.2%  | +6.1%  | +4.4%  | +2.7%   | +1.5%   | +0.6%   | +0.2%    |
| Assembly | +12.5% | +9.0%  | +6.2%  | +3.6%  | +1.5%   | +0.7%   | +0.2%   | +0.03%   |
  • Loading branch information
Nick Terrell authored and hswong3i committed Jan 5, 2025
1 parent dbe3ff5 commit 9328f85
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 57 deletions.
84 changes: 45 additions & 39 deletions lib/decompress/huf_decompress.c
Original file line number Diff line number Diff line change
Expand Up @@ -164,17 +164,18 @@ static size_t HUF_initFastDStream(BYTE const* ip) {
* op [in/out] - The output pointers, must be updated to reflect what is written.
* bits [in/out] - The bitstream containers, must be updated to reflect the current state.
* dt [in] - The decoding table.
* ilimit [in] - The input limit, stop when any input pointer is below ilimit.
* ilowest [in] - The beginning of the valid range of the input. Decoders may read
* down to this pointer. It may be below iend[0].
* oend [in] - The end of the output stream. op[3] must not cross oend.
* iend [in] - The end of each input stream. ip[i] may cross iend[i],
* as long as it is above ilimit, but that indicates corruption.
* as long as it is above ilowest, but that indicates corruption.
*/
typedef struct {
BYTE const* ip[4];
BYTE* op[4];
U64 bits[4];
void const* dt;
BYTE const* ilimit;
BYTE const* ilowest;
BYTE* oend;
BYTE const* iend[4];
} HUF_DecompressFastArgs;
Expand All @@ -192,7 +193,7 @@ static size_t HUF_DecompressFastArgs_init(HUF_DecompressFastArgs* args, void* ds
void const* dt = DTable + 1;
U32 const dtLog = HUF_getDTableDesc(DTable).tableLog;

const BYTE* const ilimit = (const BYTE*)src + 6 + 8;
const BYTE* const istart = (const BYTE*)src;

BYTE* const oend = ZSTD_maybeNullPtrAdd((BYTE*)dst, dstSize);

Expand All @@ -215,7 +216,6 @@ static size_t HUF_DecompressFastArgs_init(HUF_DecompressFastArgs* args, void* ds

/* Read the jump table. */
{
const BYTE* const istart = (const BYTE*)src;
size_t const length1 = MEM_readLE16(istart);
size_t const length2 = MEM_readLE16(istart+2);
size_t const length3 = MEM_readLE16(istart+4);
Expand All @@ -227,10 +227,8 @@ static size_t HUF_DecompressFastArgs_init(HUF_DecompressFastArgs* args, void* ds

/* HUF_initFastDStream() requires this, and this small of an input
* won't benefit from the ASM loop anyways.
* length1 must be >= 16 so that ip[0] >= ilimit before the loop
* starts.
*/
if (length1 < 16 || length2 < 8 || length3 < 8 || length4 < 8)
if (length1 < 8 || length2 < 8 || length3 < 8 || length4 < 8)
return 0;
if (length4 > srcSize) return ERROR(corruption_detected); /* overflow */
}
Expand Down Expand Up @@ -262,11 +260,12 @@ static size_t HUF_DecompressFastArgs_init(HUF_DecompressFastArgs* args, void* ds
args->bits[2] = HUF_initFastDStream(args->ip[2]);
args->bits[3] = HUF_initFastDStream(args->ip[3]);

/* If ip[] >= ilimit, it is guaranteed to be safe to
* reload bits[]. It may be beyond its section, but is
* guaranteed to be valid (>= istart).
*/
args->ilimit = ilimit;
/* The decoders must be sure to never read below ilowest.
* This is lower than iend[0], but allowing decoders to read
* down to ilowest can allow an extra iteration or two in the
* fast loop.
*/
args->ilowest = istart;

args->oend = oend;
args->dt = dt;
Expand All @@ -291,7 +290,7 @@ static size_t HUF_initRemainingDStream(BIT_DStream_t* bit, HUF_DecompressFastArg
assert(sizeof(size_t) == 8);
bit->bitContainer = MEM_readLEST(args->ip[stream]);
bit->bitsConsumed = ZSTD_countTrailingZeros64(args->bits[stream]);
bit->start = (const char*)args->iend[0];
bit->start = (const char*)args->ilowest;
bit->limitPtr = bit->start + sizeof(size_t);
bit->ptr = (const char*)args->ip[stream];

Expand Down Expand Up @@ -717,7 +716,7 @@ void HUF_decompress4X1_usingDTable_internal_fast_c_loop(HUF_DecompressFastArgs*
BYTE* op[4];
U16 const* const dtable = (U16 const*)args->dt;
BYTE* const oend = args->oend;
BYTE const* const ilimit = args->ilimit;
BYTE const* const ilowest = args->ilowest;

/* Copy the arguments to local variables */
ZSTD_memcpy(&bits, &args->bits, sizeof(bits));
Expand All @@ -735,7 +734,7 @@ void HUF_decompress4X1_usingDTable_internal_fast_c_loop(HUF_DecompressFastArgs*
#ifndef NDEBUG
for (stream = 0; stream < 4; ++stream) {
assert(op[stream] <= (stream == 3 ? oend : op[stream + 1]));
assert(ip[stream] >= ilimit);
assert(ip[stream] >= ilowest);
}
#endif
/* Compute olimit */
Expand All @@ -745,7 +744,7 @@ void HUF_decompress4X1_usingDTable_internal_fast_c_loop(HUF_DecompressFastArgs*
/* Each iteration consumes up to 11 bits * 5 = 55 bits < 7 bytes
* per stream.
*/
size_t const iiters = (size_t)(ip[0] - ilimit) / 7;
size_t const iiters = (size_t)(ip[0] - ilowest) / 7;
/* We can safely run iters iterations before running bounds checks */
size_t const iters = MIN(oiters, iiters);
size_t const symbols = iters * 5;
Expand All @@ -756,8 +755,8 @@ void HUF_decompress4X1_usingDTable_internal_fast_c_loop(HUF_DecompressFastArgs*
*/
olimit = op[3] + symbols;

/* Exit fast decoding loop once we get close to the end. */
if (op[3] + 20 > olimit)
/* Exit fast decoding loop once we reach the end. */
if (op[3] == olimit)
break;

/* Exit the decoding loop if any input pointer has crossed the
Expand Down Expand Up @@ -836,7 +835,7 @@ HUF_decompress4X1_usingDTable_internal_fast(
HUF_DecompressFastLoopFn loopFn)
{
void const* dt = DTable + 1;
const BYTE* const iend = (const BYTE*)cSrc + 6;
BYTE const* const ilowest = (BYTE const*)cSrc;
BYTE* const oend = ZSTD_maybeNullPtrAdd((BYTE*)dst, dstSize);
HUF_DecompressFastArgs args;
{ size_t const ret = HUF_DecompressFastArgs_init(&args, dst, dstSize, cSrc, cSrcSize, DTable);
Expand All @@ -845,18 +844,22 @@ HUF_decompress4X1_usingDTable_internal_fast(
return 0;
}

assert(args.ip[0] >= args.ilimit);
assert(args.ip[0] >= args.ilowest);
loopFn(&args);

/* Our loop guarantees that ip[] >= ilimit and that we haven't
/* Our loop guarantees that ip[] >= ilowest and that we haven't
* overwritten any op[].
*/
assert(args.ip[0] >= iend);
assert(args.ip[1] >= iend);
assert(args.ip[2] >= iend);
assert(args.ip[3] >= iend);
assert(args.ip[0] >= ilowest);
assert(args.ip[0] >= ilowest);
assert(args.ip[1] >= ilowest);
assert(args.ip[2] >= ilowest);
assert(args.ip[3] >= ilowest);
assert(args.op[3] <= oend);
(void)iend;

assert(ilowest == args.ilowest);
assert(ilowest + 6 == args.iend[0]);
(void)ilowest;

/* finish bit streams one by one. */
{ size_t const segmentSize = (dstSize+3) / 4;
Expand Down Expand Up @@ -1512,7 +1515,7 @@ void HUF_decompress4X2_usingDTable_internal_fast_c_loop(HUF_DecompressFastArgs*
BYTE* op[4];
BYTE* oend[4];
HUF_DEltX2 const* const dtable = (HUF_DEltX2 const*)args->dt;
BYTE const* const ilimit = args->ilimit;
BYTE const* const ilowest = args->ilowest;

/* Copy the arguments to local registers. */
ZSTD_memcpy(&bits, &args->bits, sizeof(bits));
Expand All @@ -1535,7 +1538,7 @@ void HUF_decompress4X2_usingDTable_internal_fast_c_loop(HUF_DecompressFastArgs*
#ifndef NDEBUG
for (stream = 0; stream < 4; ++stream) {
assert(op[stream] <= oend[stream]);
assert(ip[stream] >= ilimit);
assert(ip[stream] >= ilowest);
}
#endif
/* Compute olimit */
Expand All @@ -1548,7 +1551,7 @@ void HUF_decompress4X2_usingDTable_internal_fast_c_loop(HUF_DecompressFastArgs*
* We also know that each input pointer is >= ip[0]. So we can run
* iters loops before running out of input.
*/
size_t iters = (size_t)(ip[0] - ilimit) / 7;
size_t iters = (size_t)(ip[0] - ilowest) / 7;
/* Each iteration can produce up to 10 bytes of output per stream.
* Each output stream may advance at different rates. So take the
* minimum number of safe iterations among all the output streams.
Expand All @@ -1566,8 +1569,8 @@ void HUF_decompress4X2_usingDTable_internal_fast_c_loop(HUF_DecompressFastArgs*
*/
olimit = op[3] + (iters * 5);

/* Exit the fast decoding loop if we are too close to the end. */
if (op[3] + 10 > olimit)
/* Exit the fast decoding loop once we reach the end. */
if (op[3] == olimit)
break;

/* Exit the decoding loop if any input pointer has crossed the
Expand Down Expand Up @@ -1652,7 +1655,7 @@ HUF_decompress4X2_usingDTable_internal_fast(
const HUF_DTable* DTable,
HUF_DecompressFastLoopFn loopFn) {
void const* dt = DTable + 1;
const BYTE* const iend = (const BYTE*)cSrc + 6;
const BYTE* const ilowest = (const BYTE*)cSrc;
BYTE* const oend = ZSTD_maybeNullPtrAdd((BYTE*)dst, dstSize);
HUF_DecompressFastArgs args;
{
Expand All @@ -1662,16 +1665,19 @@ HUF_decompress4X2_usingDTable_internal_fast(
return 0;
}

assert(args.ip[0] >= args.ilimit);
assert(args.ip[0] >= args.ilowest);
loopFn(&args);

/* note : op4 already verified within main loop */
assert(args.ip[0] >= iend);
assert(args.ip[1] >= iend);
assert(args.ip[2] >= iend);
assert(args.ip[3] >= iend);
assert(args.ip[0] >= ilowest);
assert(args.ip[1] >= ilowest);
assert(args.ip[2] >= ilowest);
assert(args.ip[3] >= ilowest);
assert(args.op[3] <= oend);
(void)iend;

assert(ilowest == args.ilowest);
assert(ilowest + 6 == args.iend[0]);
(void)ilowest;

/* finish bitStreams one by one */
{
Expand Down
34 changes: 16 additions & 18 deletions lib/decompress/huf_decompress_amd64.S
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ HUF_decompress4X1_usingDTable_internal_fast_asm_loop:
movq 88(%rax), %bits3
movq 96(%rax), %dtable
push %rax /* argument */
push 104(%rax) /* ilimit */
push 104(%rax) /* ilowest */
push 112(%rax) /* oend */
push %olimit /* olimit space */

Expand All @@ -156,11 +156,11 @@ HUF_decompress4X1_usingDTable_internal_fast_asm_loop:
shrq $2, %r15

movq %ip0, %rax /* rax = ip0 */
movq 40(%rsp), %rdx /* rdx = ilimit */
subq %rdx, %rax /* rax = ip0 - ilimit */
movq %rax, %rbx /* rbx = ip0 - ilimit */
movq 40(%rsp), %rdx /* rdx = ilowest */
subq %rdx, %rax /* rax = ip0 - ilowest */
movq %rax, %rbx /* rbx = ip0 - ilowest */

/* rdx = (ip0 - ilimit) / 7 */
/* rdx = (ip0 - ilowest) / 7 */
movabsq $2635249153387078803, %rdx
mulq %rdx
subq %rdx, %rbx
Expand All @@ -183,9 +183,8 @@ HUF_decompress4X1_usingDTable_internal_fast_asm_loop:

/* If (op3 == olimit) go to exit */
movq %op3, %rax /* rax = op3 */
addq $20, %rax /* rax = op3 + 20 */
cmpq %rax, %olimit /* op3 + 20 > olimit */
jb .L_4X1_exit
cmpq %rax, %olimit /* op3 == olimit */
je .L_4X1_exit

/* If (ip1 < ip0) go to exit */
cmpq %ip0, %ip1
Expand Down Expand Up @@ -316,7 +315,7 @@ HUF_decompress4X1_usingDTable_internal_fast_asm_loop:
/* Restore stack (oend & olimit) */
pop %rax /* olimit */
pop %rax /* oend */
pop %rax /* ilimit */
pop %rax /* ilowest */
pop %rax /* arg */

/* Save ip / op / bits */
Expand Down Expand Up @@ -387,7 +386,7 @@ HUF_decompress4X2_usingDTable_internal_fast_asm_loop:
movq 96(%rax), %dtable
push %rax /* argument */
push %rax /* olimit */
push 104(%rax) /* ilimit */
push 104(%rax) /* ilowest */

movq 112(%rax), %rax
push %rax /* oend3 */
Expand All @@ -414,9 +413,9 @@ HUF_decompress4X2_usingDTable_internal_fast_asm_loop:

/* We can consume up to 7 input bytes each iteration. */
movq %ip0, %rax /* rax = ip0 */
movq 40(%rsp), %rdx /* rdx = ilimit */
subq %rdx, %rax /* rax = ip0 - ilimit */
movq %rax, %r15 /* r15 = ip0 - ilimit */
movq 40(%rsp), %rdx /* rdx = ilowest */
subq %rdx, %rax /* rax = ip0 - ilowest */
movq %rax, %r15 /* r15 = ip0 - ilowest */

/* rdx = rax / 7 */
movabsq $2635249153387078803, %rdx
Expand All @@ -426,7 +425,7 @@ HUF_decompress4X2_usingDTable_internal_fast_asm_loop:
addq %r15, %rdx
shrq $2, %rdx

/* r15 = (ip0 - ilimit) / 7 */
/* r15 = (ip0 - ilowest) / 7 */
movq %rdx, %r15

/* r15 = min(r15, min(oend0 - op0, oend1 - op1, oend2 - op2, oend3 - op3) / 10) */
Expand Down Expand Up @@ -467,9 +466,8 @@ HUF_decompress4X2_usingDTable_internal_fast_asm_loop:

/* If (op3 == olimit) go to exit */
movq %op3, %rax /* rax = op3 */
addq $10, %rax /* rax = op3 + 10 */
cmpq %rax, %olimit /* op3 + 10 > olimit */
jb .L_4X2_exit
cmpq %rax, %olimit /* op3 == olimit */
je .L_4X2_exit

/* If (ip1 < ip0) go to exit */
cmpq %ip0, %ip1
Expand Down Expand Up @@ -537,7 +535,7 @@ HUF_decompress4X2_usingDTable_internal_fast_asm_loop:
pop %rax /* oend1 */
pop %rax /* oend2 */
pop %rax /* oend3 */
pop %rax /* ilimit */
pop %rax /* ilowest */
pop %rax /* olimit */
pop %rax /* arg */

Expand Down

0 comments on commit 9328f85

Please sign in to comment.