Skip to content

Commit

Permalink
Implements Bulk processing dictionary API
Browse files Browse the repository at this point in the history
  • Loading branch information
Barbayar committed Apr 2, 2022
1 parent 718c6ee commit 0114227
Show file tree
Hide file tree
Showing 2 changed files with 268 additions and 0 deletions.
116 changes: 116 additions & 0 deletions zstd_bulk.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
package zstd

/*
#include "zstd.h"
*/
import "C"
import (
"errors"
"unsafe"
)

// BulkProcessor implements Bulk processing dictionary API
type BulkProcessor struct {
cDict *C.struct_ZSTD_CDict_s
dDict *C.struct_ZSTD_DDict_s
}

// NewBulkProcessor creates a new BulkProcessor with a pre-trained dictionary and compression level
func NewBulkProcessor(dictionary []byte, compressionLevel int) (*BulkProcessor, error) {
p := &BulkProcessor{}
p.cDict = C.ZSTD_createCDict(
unsafe.Pointer(&dictionary[0]),
C.size_t(len(dictionary)),
C.int(compressionLevel),
)
if p.cDict == nil {
return nil, errors.New("failed to create dictionary")
}
p.dDict = C.ZSTD_createDDict(
unsafe.Pointer(&dictionary[0]),
C.size_t(len(dictionary)),
)
if p.dDict == nil {
return nil, errors.New("failed to create dictionary")
}
return p, nil
}

// Compress compresses the `src` with the dictionary
func (p *BulkProcessor) Compress(dst, src []byte) ([]byte, error) {
bound := CompressBound(len(src))
if cap(dst) >= bound {
dst = dst[0:bound]
} else {
dst = make([]byte, bound)
}

var cSrc unsafe.Pointer
if len(src) == 0 {
cSrc = unsafe.Pointer(nil)
} else {
cSrc = unsafe.Pointer(&src[0])
}

cctx := C.ZSTD_createCCtx()
cWritten := C.ZSTD_compress_usingCDict(
cctx,
unsafe.Pointer(&dst[0]),
C.size_t(len(dst)),
cSrc,
C.size_t(len(src)),
p.cDict,
)
C.ZSTD_freeCCtx(cctx)

written := int(cWritten)
if err := getError(written); err != nil {
return nil, err
}
return dst[:written], nil
}

// Decompress compresses the `dst` with the dictionary
func (p *BulkProcessor) Decompress(dst, src []byte) ([]byte, error) {
if len(src) == 0 {
return []byte{}, ErrEmptySlice
}
contentSize := uint64(C.ZSTD_getFrameContentSize(unsafe.Pointer(&src[0]), C.size_t(len(src))))
if contentSize == C.ZSTD_CONTENTSIZE_ERROR || contentSize == C.ZSTD_CONTENTSIZE_UNKNOWN {
return nil, errors.New("could not determine the content size")
}

if cap(dst) >= int(contentSize) {
dst = dst[0:contentSize]
} else {
dst = make([]byte, contentSize)
}

if contentSize == 0 {
return dst, nil
}

dctx := C.ZSTD_createDCtx()
cWritten := C.ZSTD_decompress_usingDDict(
dctx,
unsafe.Pointer(&dst[0]),
C.size_t(contentSize),
unsafe.Pointer(&src[0]),
C.size_t(len(src)),
p.dDict,
)
C.ZSTD_freeDCtx(dctx)

written := int(cWritten)
if err := getError(written); err != nil {
return nil, err
}

return dst[:written], nil
}

// Cleanup frees compression and decompression dictionaries from memory
func (p *BulkProcessor) Cleanup() {
C.ZSTD_freeCDict(p.cDict)
C.ZSTD_freeDDict(p.dDict)
}
152 changes: 152 additions & 0 deletions zstd_bulk_test.go (note: committed filename "zstd_bullk_test.go" contains a typo)
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
package zstd

import (
"bytes"
"encoding/base64"
"math/rand"
"regexp"
"strings"
"testing"
)

// dictBase64 is a small pre-trained zstd dictionary, base64-encoded with
// arbitrary line breaks; tests strip all whitespace before decoding it.
var dictBase64 string = `
N6Qw7IsuFDIdENCSQjr//////4+QlekuNkmXbUBIkIDiVRX7H4AzAFCgQCFCO9oHAAAEQEuSikaK
Dg51OYghBYgBAAAAAAAAAAAAAAAAAAAAANQVpmRQGQAAAAAAAAAAAAAAAAABAAAABAAAAAgAAABo
ZWxwIEpvaW4gZW5naW5lZXJzIGVuZ2luZWVycyBmdXR1cmUgbG92ZSB0aGF0IGFyZWlsZGluZyB1
c2UgaGVscCBoZWxwIHVzaGVyIEpvaW4gdXNlIGxvdmUgdXMgSm9pbiB1bmQgaW4gdXNoZXIgdXNo
ZXIgYSBwbGF0Zm9ybSB1c2UgYW5kIGZ1dHVyZQ==`

func getRandomText() string {
words := []string{"We", "are", "building", "a platform", "that", "engineers", "love", "to", "use", "Join", "us", "and", "help", "usher", "in", "the", "future"}
wordCount := 10 + rand.Intn(100) // 10 - 109
result := []string{}
for i := 0; i < wordCount; i++ {
result = append(result, words[rand.Intn(len(words))])
}

return strings.Join(result, " ")
}

// TestCompressAndDecompress round-trips 100 random payloads through the
// dictionary-based Compress/Decompress pair and verifies the data survives.
func TestCompressAndDecompress(t *testing.T) {
	var b64 = base64.StdEncoding
	dict, err := b64.DecodeString(regexp.MustCompile(`\s+`).ReplaceAllString(dictBase64, ""))
	if err != nil {
		t.Fatalf("failed to decode the dictionary")
	}

	p, err := NewBulkProcessor(dict, BestSpeed)
	if err != nil {
		t.Fatalf("failed to create a BulkProcessor")
	}
	// Defer so the C dictionaries are released even when a t.Fatalf fires;
	// the original trailing call was skipped on any test failure.
	defer p.Cleanup()

	for i := 0; i < 100; i++ {
		payload := []byte(getRandomText())

		compressed, err := p.Compress(nil, payload)
		if err != nil {
			t.Fatalf("failed to compress")
		}

		uncompressed, err := p.Decompress(nil, compressed)
		if err != nil {
			t.Fatalf("failed to decompress")
		}

		// bytes.Equal is the idiomatic equality check (was bytes.Compare != 0).
		if !bytes.Equal(payload, uncompressed) {
			t.Fatalf("uncompressed payload didn't match")
		}
	}
}

// TestCompressAndDecompressInReverseOrder compresses 100 payloads first and
// then decompresses them in reverse order, verifying that the processor's
// state is independent of operation ordering.
func TestCompressAndDecompressInReverseOrder(t *testing.T) {
	var b64 = base64.StdEncoding
	dict, err := b64.DecodeString(regexp.MustCompile(`\s+`).ReplaceAllString(dictBase64, ""))
	if err != nil {
		t.Fatalf("failed to decode the dictionary")
	}

	p, err := NewBulkProcessor(dict, BestSpeed)
	if err != nil {
		t.Fatalf("failed to create a BulkProcessor")
	}
	// Defer so the C dictionaries are released even when a t.Fatalf fires;
	// the original trailing call was skipped on any test failure.
	defer p.Cleanup()

	payloads := [][]byte{}
	compressedPayloads := [][]byte{}
	for i := 0; i < 100; i++ {
		payloads = append(payloads, []byte(getRandomText()))

		compressed, err := p.Compress(nil, payloads[i])
		if err != nil {
			t.Fatalf("failed to compress")
		}
		compressedPayloads = append(compressedPayloads, compressed)
	}

	for i := 99; i >= 0; i-- {
		uncompressed, err := p.Decompress(nil, compressedPayloads[i])
		if err != nil {
			t.Fatalf("failed to decompress")
		}

		// bytes.Equal is the idiomatic equality check (was bytes.Compare != 0).
		if !bytes.Equal(payloads[i], uncompressed) {
			t.Fatalf("uncompressed payload didn't match")
		}
	}
}

// BenchmarkCompress-8   	  715689	      1550 ns/op	  59.37 MB/s	     208 B/op	       5 allocs/op
//
// BenchmarkCompress measures dictionary-based compression throughput on a
// fixed short payload.
func BenchmarkCompress(b *testing.B) {
	var b64 = base64.StdEncoding
	dict, err := b64.DecodeString(regexp.MustCompile(`\s+`).ReplaceAllString(dictBase64, ""))
	if err != nil {
		b.Fatalf("failed to decode the dictionary")
	}

	p, err := NewBulkProcessor(dict, BestSpeed)
	if err != nil {
		b.Fatalf("failed to create a BulkProcessor")
	}
	// Defer so the C dictionaries are released even when b.Fatalf fires.
	defer p.Cleanup()

	payload := []byte("We're building a platform that engineers love to use. Join us, and help usher in the future.")
	// SetBytes needs to be set once, not on every iteration, and the
	// dictionary setup above must not count toward the timed region.
	b.SetBytes(int64(len(payload)))
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		if _, err := p.Compress(nil, payload); err != nil {
			b.Fatalf("failed to compress")
		}
	}
}

// BenchmarkDecompress-8   	  664922	      1544 ns/op	  36.91 MB/s	     192 B/op	       7 allocs/op
//
// BenchmarkDecompress measures dictionary-based decompression throughput on
// a fixed pre-compressed payload.
func BenchmarkDecompress(b *testing.B) {
	var b64 = base64.StdEncoding
	dict, err := b64.DecodeString(regexp.MustCompile(`\s+`).ReplaceAllString(dictBase64, ""))
	if err != nil {
		b.Fatalf("failed to decode the dictionary")
	}

	p, err := NewBulkProcessor(dict, BestSpeed)
	if err != nil {
		b.Fatalf("failed to create a BulkProcessor")
	}
	// Defer so the C dictionaries are released even when b.Fatalf fires.
	defer p.Cleanup()

	payload, err := p.Compress(nil, []byte("We're building a platform that engineers love to use. Join us, and help usher in the future."))
	if err != nil {
		b.Fatalf("failed to compress")
	}
	// SetBytes needs to be set once, not on every iteration, and the
	// setup above must not count toward the timed region.
	b.SetBytes(int64(len(payload)))
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		if _, err := p.Decompress(nil, payload); err != nil {
			b.Fatalf("failed to decompress")
		}
	}
}

0 comments on commit 0114227

Please sign in to comment.