diff --git a/zstd.go b/zstd.go index 82af3c1..2cf5c61 100644 --- a/zstd.go +++ b/zstd.go @@ -36,7 +36,7 @@ const ( // decompressed, err := zstd.Decompress(dst, src) decompressSizeBufferLimit = 1000 * 1000 - zstdFrameHeaderSizeMax = 18 // From zstd.h. Since it's experimental API, hardcoding it + zstdFrameHeaderSizeMin = 2 // From zstd.h. Since it's experimental API, hardcoding it ) // CompressBound returns the worst case size needed for a destination buffer, @@ -67,11 +67,14 @@ func decompressSizeHint(src []byte) int { } hint := upperBound - if len(src) >= zstdFrameHeaderSizeMax { + if len(src) >= zstdFrameHeaderSizeMin { hint = int(C.ZSTD_getFrameContentSize(unsafe.Pointer(&src[0]), C.size_t(len(src)))) if hint < 0 { // On error, just use upperBound hint = upperBound } + if hint == 0 { // When compressing the empty slice, we need an output of at least 1 to pass down to the C lib + hint = 1 + } } // Take the minimum of both diff --git a/zstd_test.go b/zstd_test.go index b7299cc..0253537 100644 --- a/zstd_test.go +++ b/zstd_test.go @@ -293,6 +293,25 @@ func TestBadPayloadZipBomb(t *testing.T) { } } +func TestSmallPayload(t *testing.T) { + // Test that we can compress really small payloads and this doesn't generate a huge output buffer + compressed, err := Compress(nil, []byte("a")) + if err != nil { + t.Fatalf("failed to compress: %s", err) + } + + preAllocated := make([]byte, 1, 64) // Don't use more than that + decompressed, err := Decompress(preAllocated, compressed) + if err != nil { + t.Fatalf("failed to compress: %s", err) + } + + if &(preAllocated[0]) != &(decompressed[0]) { // They should point to the same spot (no realloc) + t.Fatal("Compression buffer was changed") + } + +} + func BenchmarkCompression(b *testing.B) { if raw == nil { b.Fatal(ErrNoPayloadEnv)