diff --git a/zstd_bulk.go b/zstd_bulk.go
index f0ecf26..fadfe5d 100644
--- a/zstd_bulk.go
+++ b/zstd_bulk.go
@@ -6,10 +6,26 @@ package zstd
 import "C"
 import (
 	"errors"
+	"runtime"
 	"unsafe"
 )
 
-// BulkProcessor implements Bulk processing dictionary API
+var (
+	// ErrEmptyDictionary is returned when the given dictionary is empty
+	ErrEmptyDictionary = errors.New("Dictionary is empty")
+	// ErrBadDictionary is returned when the given dictionary cannot be loaded
+	ErrBadDictionary = errors.New("Cannot load dictionary")
+	// ErrContentSize is returned when the content size cannot be determined
+	ErrContentSize = errors.New("Cannot determine the content size")
+)
+
+// BulkProcessor implements the Bulk processing dictionary API.
+// When compressing multiple messages or blocks with the same dictionary,
+// it's recommended to digest the dictionary only once, since digesting is a costly operation.
+// NewBulkProcessor() creates such a digested state from a dictionary.
+// The resulting state can be used for future compression/decompression operations with very limited startup cost.
+// A BulkProcessor can be created once and shared by multiple goroutines concurrently, since its usage is read-only.
+// The state is freed when the garbage collector collects the BulkProcessor.
 type BulkProcessor struct {
 	cDict *C.struct_ZSTD_CDict_s
 	dDict *C.struct_ZSTD_DDict_s
@@ -17,26 +33,35 @@ type BulkProcessor struct {
 
 // NewBulkProcessor creates a new BulkProcessor with a pre-trained dictionary and compression level
 func NewBulkProcessor(dictionary []byte, compressionLevel int) (*BulkProcessor, error) {
+	if len(dictionary) < 1 {
+		return nil, ErrEmptyDictionary
+	}
+
 	p := &BulkProcessor{}
+	runtime.SetFinalizer(p, finalizeBulkProcessor)
+
 	p.cDict = C.ZSTD_createCDict(
 		unsafe.Pointer(&dictionary[0]),
 		C.size_t(len(dictionary)),
 		C.int(compressionLevel),
 	)
 	if p.cDict == nil {
-		return nil, errors.New("failed to create dictionary")
+		return nil, ErrBadDictionary
 	}
 	p.dDict = C.ZSTD_createDDict(
 		unsafe.Pointer(&dictionary[0]),
 		C.size_t(len(dictionary)),
 	)
 	if p.dDict == nil {
-		return nil, errors.New("failed to create dictionary")
+		return nil, ErrBadDictionary
 	}
+
 	return p, nil
 }
 
-// Compress compresses the `src` with the dictionary
+// Compress compresses `src` into `dst` with the dictionary given when creating the BulkProcessor.
+// If you have a buffer to use, you can pass it to prevent allocation.
+// If it is too small, or if nil is passed, a new buffer will be allocated and returned.
 func (p *BulkProcessor) Compress(dst, src []byte) ([]byte, error) {
 	bound := CompressBound(len(src))
 	if cap(dst) >= bound {
@@ -45,22 +70,31 @@ func (p *BulkProcessor) Compress(dst, src []byte) ([]byte, error) {
 		dst = dst[0:bound]
 	} else {
 		dst = make([]byte, bound)
 	}
 
-	var cSrc unsafe.Pointer
+	cctx := C.ZSTD_createCCtx()
+	// We need unsafe.Pointer(&src[0]) in the Cgo call to avoid "Go pointer to Go pointer" panics.
+	// This means we need to special case empty input. See:
+	// https://github.com/golang/go/issues/14210#issuecomment-346402945
+	var cWritten C.size_t
 	if len(src) == 0 {
-		cSrc = unsafe.Pointer(nil)
+		cWritten = C.ZSTD_compress_usingCDict(
+			cctx,
+			unsafe.Pointer(&dst[0]),
+			C.size_t(len(dst)),
+			unsafe.Pointer(nil),
+			C.size_t(len(src)),
+			p.cDict,
+		)
 	} else {
-		cSrc = unsafe.Pointer(&src[0])
+		cWritten = C.ZSTD_compress_usingCDict(
+			cctx,
+			unsafe.Pointer(&dst[0]),
+			C.size_t(len(dst)),
+			unsafe.Pointer(&src[0]),
+			C.size_t(len(src)),
+			p.cDict,
+		)
 	}
-	cctx := C.ZSTD_createCCtx()
-	cWritten := C.ZSTD_compress_usingCDict(
-		cctx,
-		unsafe.Pointer(&dst[0]),
-		C.size_t(len(dst)),
-		cSrc,
-		C.size_t(len(src)),
-		p.cDict,
-	)
 	C.ZSTD_freeCCtx(cctx)
 
 	written := int(cWritten)
@@ -70,14 +104,16 @@ func (p *BulkProcessor) Compress(dst, src []byte) ([]byte, error) {
 	return dst[:written], nil
 }
 
-// Decompress compresses the `dst` with the dictionary
+// Decompress decompresses `src` into `dst` with the dictionary given when creating the BulkProcessor.
+// If you have a buffer to use, you can pass it to prevent allocation.
+// If it is too small, or if nil is passed, a new buffer will be allocated and returned.
 func (p *BulkProcessor) Decompress(dst, src []byte) ([]byte, error) {
 	if len(src) == 0 {
-		return []byte{}, ErrEmptySlice
+		return nil, ErrEmptySlice
 	}
 	contentSize := uint64(C.ZSTD_getFrameContentSize(unsafe.Pointer(&src[0]), C.size_t(len(src))))
 	if contentSize == C.ZSTD_CONTENTSIZE_ERROR || contentSize == C.ZSTD_CONTENTSIZE_UNKNOWN {
-		return nil, errors.New("could not determine the content size")
+		return nil, ErrContentSize
	}
 
 	if cap(dst) >= int(contentSize) {
@@ -109,8 +145,12 @@ func (p *BulkProcessor) Decompress(dst, src []byte) ([]byte, error) {
 	return dst[:written], nil
 }
 
-// Cleanup frees compression and decompression dictionaries from memory
-func (p *BulkProcessor) Cleanup() {
-	C.ZSTD_freeCDict(p.cDict)
-	C.ZSTD_freeDDict(p.dDict)
+// finalizeBulkProcessor frees compression and decompression dictionaries from memory
+func finalizeBulkProcessor(p *BulkProcessor) {
+	if p.cDict != nil {
+		C.ZSTD_freeCDict(p.cDict)
+	}
+	if p.dDict != nil {
+		C.ZSTD_freeDDict(p.dDict)
+	}
 }
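
For reference, the intended end-to-end usage of the new API looks like this. This is a minimal sketch rather than code from this change, assuming it ships in the github.com/DataDog/zstd package; the inline dictionary literal is only a stand-in for a real pre-trained dictionary (zstd loads a blob without the dictionary magic number as a raw content dictionary):

package main

import (
	"fmt"
	"log"

	"github.com/DataDog/zstd"
)

func main() {
	// Digest the dictionary once; this is the costly step. The resulting
	// BulkProcessor is read-only and can be shared by any number of goroutines.
	dictionary := []byte("engineers love building platforms usher future Join us")
	p, err := zstd.NewBulkProcessor(dictionary, zstd.BestSpeed)
	if err != nil {
		log.Fatal(err)
	}

	compressed, err := p.Compress(nil, []byte("We're building a platform that engineers love to use."))
	if err != nil {
		log.Fatal(err)
	}

	decompressed, err := p.Decompress(nil, compressed)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("%s\n", decompressed)

	// Note: no Cleanup() call. With this change the C dictionaries are
	// released by finalizeBulkProcessor when p is garbage collected, so
	// callers can no longer free them while the processor is still in use.
}
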
diff --git a/zstd_bullk_test.go b/zstd_bullk_test.go
index 868e678..eeba156 100644
--- a/zstd_bullk_test.go
+++ b/zstd_bullk_test.go
@@ -15,6 +15,32 @@ var dictBase64 string = `
 ZWxwIEpvaW4gZW5naW5lZXJzIGVuZ2luZWVycyBmdXR1cmUgbG92ZSB0aGF0IGFyZWlsZGluZyB1
 c2UgaGVscCBoZWxwIHVzaGVyIEpvaW4gdXNlIGxvdmUgdXMgSm9pbiB1bmQgaW4gdXNoZXIgdXNo
 ZXIgYSBwbGF0Zm9ybSB1c2UgYW5kIGZ1dHVyZQ==`
+var dict []byte
+var compressedPayload []byte
+
+func init() {
+	var err error
+	dict, err = base64.StdEncoding.DecodeString(regexp.MustCompile(`\s+`).ReplaceAllString(dictBase64, ""))
+	if err != nil {
+		panic("failed to decode the dictionary")
+	}
+	p, err := NewBulkProcessor(dict, BestSpeed)
+	if err != nil {
+		panic("failed to create bulk processor")
+	}
+	compressedPayload, err = p.Compress(nil, []byte("We're building a platform that engineers love to use. Join us, and help usher in the future."))
+	if err != nil {
+		panic("failed to compress payload")
+	}
+}
+
+func newBulkProcessor(t testing.TB, dict []byte, level int) *BulkProcessor {
+	p, err := NewBulkProcessor(dict, level)
+	if err != nil {
+		t.Fatal("failed to create a BulkProcessor")
+	}
+	return p
+}
 
 func getRandomText() string {
 	words := []string{"We", "are", "building", "a platform", "that", "engineers", "love", "to", "use", "Join", "us", "and", "help", "usher", "in", "the", "future"}
@@ -27,51 +53,145 @@ func getRandomText() string {
 	return strings.Join(result, " ")
 }
 
-func TestCompressAndDecompress(t *testing.T) {
-	var b64 = base64.StdEncoding
-	dict, err := b64.DecodeString(regexp.MustCompile(`\s+`).ReplaceAllString(dictBase64, ""))
-	if err != nil {
-		t.Fatalf("failed to decode the dictionary")
-	}
-
-	p, err := NewBulkProcessor(dict, BestSpeed)
-	if err != nil {
-		t.Fatalf("failed to create a BulkProcessor")
+func TestBulkDictionary(t *testing.T) {
+	if len(dict) < 1 {
+		t.Error("dictionary is empty")
 	}
+}
 
+func TestBulkCompressAndDecompress(t *testing.T) {
+	p := newBulkProcessor(t, dict, BestSpeed)
 	for i := 0; i < 100; i++ {
 		payload := []byte(getRandomText())
 
 		compressed, err := p.Compress(nil, payload)
 		if err != nil {
-			t.Fatalf("failed to compress")
+			t.Error("failed to compress")
 		}
 
 		uncompressed, err := p.Decompress(nil, compressed)
 		if err != nil {
-			t.Fatalf("failed to decompress")
+			t.Error("failed to decompress")
 		}
 
 		if bytes.Compare(payload, uncompressed) != 0 {
-			t.Fatalf("uncompressed payload didn't match")
+			t.Error("uncompressed payload didn't match")
 		}
 	}
+}
 
-	p.Cleanup()
+func TestBulkEmptyOrNilDictionary(t *testing.T) {
+	p, err := NewBulkProcessor(nil, BestSpeed)
+	if p != nil {
+		t.Error("nil is expected")
+	}
+	if err != ErrEmptyDictionary {
+		t.Error("ErrEmptyDictionary is expected")
+	}
+
+	p, err = NewBulkProcessor([]byte{}, BestSpeed)
+	if p != nil {
+		t.Error("nil is expected")
+	}
+	if err != ErrEmptyDictionary {
+		t.Error("ErrEmptyDictionary is expected")
+	}
 }
 
-func TestCompressAndDecompressInReverseOrder(t *testing.T) {
-	var b64 = base64.StdEncoding
-	dict, err := b64.DecodeString(regexp.MustCompile(`\s+`).ReplaceAllString(dictBase64, ""))
+func TestBulkCompressEmptyOrNilContent(t *testing.T) {
+	p := newBulkProcessor(t, dict, BestSpeed)
+	compressed, err := p.Compress(nil, nil)
 	if err != nil {
-		t.Fatalf("failed to decode the dictionary")
+		t.Error("failed to compress")
+	}
+	if len(compressed) < 4 {
+		t.Error("magic number doesn't exist")
 	}
 
-	p, err := NewBulkProcessor(dict, BestSpeed)
+	compressed, err = p.Compress(nil, []byte{})
 	if err != nil {
-		t.Fatalf("failed to create a BulkProcessor")
+		t.Error("failed to compress")
 	}
+	if len(compressed) < 4 {
+		t.Error("magic number doesn't exist")
+	}
+}
 
+func TestBulkCompressIntoGivenDestination(t *testing.T) {
+	p := newBulkProcessor(t, dict, BestSpeed)
+	dst := make([]byte, 100000)
+	compressed, err := p.Compress(dst, []byte(getRandomText()))
+	if err != nil {
+		t.Error("failed to compress")
+	}
+	if len(compressed) < 4 {
+		t.Error("magic number doesn't exist")
+	}
+	if &dst[0] != &compressed[0] {
+		t.Error("'dst' and 'compressed' are not the same object")
+	}
+}
+
+func TestBulkCompressNotEnoughDestination(t *testing.T) {
+	p := newBulkProcessor(t, dict, BestSpeed)
+	dst := make([]byte, 1)
+	compressed, err := p.Compress(dst, []byte(getRandomText()))
+	if err != nil {
+		t.Error("failed to compress")
+	}
+	if len(compressed) < 4 {
+		t.Error("magic number doesn't exist")
+	}
+	if &dst[0] == &compressed[0] {
+		t.Error("'dst' and 'compressed' are the same object")
+	}
+}
+
+func TestBulkDecompressIntoGivenDestination(t *testing.T) {
+	p := newBulkProcessor(t, dict, BestSpeed)
+	dst := make([]byte, 100000)
+	decompressed, err := p.Decompress(dst, compressedPayload)
+	if err != nil {
+		t.Error("failed to decompress")
+	}
+	if &dst[0] != &decompressed[0] {
+		t.Error("'dst' and 'decompressed' are not the same object")
+	}
+}
+
+func TestBulkDecompressNotEnoughDestination(t *testing.T) {
+	p := newBulkProcessor(t, dict, BestSpeed)
+	dst := make([]byte, 1)
+	decompressed, err := p.Decompress(dst, compressedPayload)
+	if err != nil {
+		t.Error("failed to decompress")
+	}
+	if &dst[0] == &decompressed[0] {
+		t.Error("'dst' and 'decompressed' are the same object")
+	}
+}
+
+func TestBulkDecompressEmptyOrNilContent(t *testing.T) {
+	p := newBulkProcessor(t, dict, BestSpeed)
+	decompressed, err := p.Decompress(nil, nil)
+	if err != ErrEmptySlice {
+		t.Error("ErrEmptySlice is expected")
+	}
+	if decompressed != nil {
+		t.Error("nil is expected")
+	}
+
+	decompressed, err = p.Decompress(nil, []byte{})
+	if err != ErrEmptySlice {
+		t.Error("ErrEmptySlice is expected")
+	}
+	if decompressed != nil {
+		t.Error("nil is expected")
+	}
+}
+
+func TestBulkCompressAndDecompressInReverseOrder(t *testing.T) {
+	p := newBulkProcessor(t, dict, BestSpeed)
 	payloads := [][]byte{}
 	compressedPayloads := [][]byte{}
 	for i := 0; i < 100; i++ {
@@ -79,7 +199,7 @@ func TestCompressAndDecompressInReverseOrder(t *testing.T) {
 
 		compressed, err := p.Compress(nil, payloads[i])
 		if err != nil {
-			t.Fatalf("failed to compress")
+			t.Error("failed to compress")
 		}
 		compressedPayloads = append(compressedPayloads, compressed)
 	}
@@ -87,66 +207,38 @@ func TestCompressAndDecompressInReverseOrder(t *testing.T) {
 	for i := 99; i >= 0; i-- {
 		uncompressed, err := p.Decompress(nil, compressedPayloads[i])
 		if err != nil {
-			t.Fatalf("failed to decompress")
+			t.Error("failed to decompress")
 		}
 
 		if bytes.Compare(payloads[i], uncompressed) != 0 {
-			t.Fatalf("uncompressed payload didn't match")
+			t.Error("uncompressed payload didn't match")
 		}
 	}
-
-	p.Cleanup()
 }
 
-// BenchmarkCompress-8   	  715689	      1550 ns/op	  59.37 MB/s	     208 B/op	       5 allocs/op
-func BenchmarkCompress(b *testing.B) {
-	var b64 = base64.StdEncoding
-	dict, err := b64.DecodeString(regexp.MustCompile(`\s+`).ReplaceAllString(dictBase64, ""))
-	if err != nil {
-		b.Fatalf("failed to decode the dictionary")
-	}
-
-	p, err := NewBulkProcessor(dict, BestSpeed)
-	if err != nil {
-		b.Fatalf("failed to create a BulkProcessor")
-	}
+// BenchmarkBulkCompress-8   	  780148	      1505 ns/op	  61.14 MB/s	     208 B/op	       5 allocs/op
+func BenchmarkBulkCompress(b *testing.B) {
+	p := newBulkProcessor(b, dict, BestSpeed)
 
 	payload := []byte("We're building a platform that engineers love to use. Join us, and help usher in the future.")
+	b.SetBytes(int64(len(payload)))
 	for n := 0; n < b.N; n++ {
 		_, err := p.Compress(nil, payload)
 		if err != nil {
-			b.Fatalf("failed to compress")
+			b.Error("failed to compress")
 		}
-		b.SetBytes(int64(len(payload)))
 	}
-
-	p.Cleanup()
 }
 
-// BenchmarkDecompress-8   	  664922	      1544 ns/op	  36.91 MB/s	     192 B/op	       7 allocs/op
-func BenchmarkDecompress(b *testing.B) {
-	var b64 = base64.StdEncoding
-	dict, err := b64.DecodeString(regexp.MustCompile(`\s+`).ReplaceAllString(dictBase64, ""))
-	if err != nil {
-		b.Fatalf("failed to decode the dictionary")
-	}
+// BenchmarkBulkDecompress-8   	  817425	      1412 ns/op	  40.37 MB/s	     192 B/op	       7 allocs/op
+func BenchmarkBulkDecompress(b *testing.B) {
+	p := newBulkProcessor(b, dict, BestSpeed)
 
-	p, err := NewBulkProcessor(dict, BestSpeed)
-	if err != nil {
-		b.Fatalf("failed to create a BulkProcessor")
-	}
-
-	payload, err := p.Compress(nil, []byte("We're building a platform that engineers love to use. Join us, and help usher in the future."))
-	if err != nil {
-		b.Fatalf("failed to compress")
-	}
+	b.SetBytes(int64(len(compressedPayload)))
 	for n := 0; n < b.N; n++ {
-		_, err := p.Decompress(nil, payload)
+		_, err := p.Decompress(nil, compressedPayload)
 		if err != nil {
-			b.Fatalf("failed to decompress")
+			b.Error("failed to decompress")
 		}
-		b.SetBytes(int64(len(payload)))
 	}
-
-	p.Cleanup()
 }
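
The dst parameter exercised by the ...IntoGivenDestination and ...NotEnoughDestination tests above is what enables allocation-free reuse across calls. A sketch of that pattern under the same assumptions (the codec type and roundTrip helper are hypothetical names, not part of this change):

package main

import (
	"fmt"
	"log"

	"github.com/DataDog/zstd"
)

// codec keeps scratch buffers across calls. When a scratch buffer has enough
// capacity, Compress/Decompress write into it and return a slice aliasing it;
// when it is too small, they return a freshly allocated, larger buffer, which
// we keep for subsequent calls.
type codec struct {
	p          *zstd.BulkProcessor
	cBuf, dBuf []byte
}

func (c *codec) roundTrip(msg []byte) ([]byte, error) {
	var err error
	if c.cBuf, err = c.p.Compress(c.cBuf, msg); err != nil {
		return nil, err
	}
	if c.dBuf, err = c.p.Decompress(c.dBuf, c.cBuf); err != nil {
		return nil, err
	}
	return c.dBuf, nil
}

func main() {
	// A raw content dictionary stands in for a real trained one here.
	p, err := zstd.NewBulkProcessor([]byte("a raw content dictionary"), zstd.BestSpeed)
	if err != nil {
		log.Fatal(err)
	}
	c := &codec{p: p}
	out, err := c.roundTrip([]byte("We're building a platform that engineers love to use."))
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("%s\n", out)
}

Because the returned slices are kept as the next call's scratch buffers, the buffers stop growing once they reach the workload's peak size and steady-state calls stop allocating.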