Merge pull request #117 from bsergean/patch-1
Add SetNbWorkers api to the writer code (see #108)
Viq111 authored Jun 6, 2022
2 parents d64f463 + c798238 commit fd035e5
Showing 3 changed files with 46 additions and 6 deletions.
zstd.go: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ package zstd
// support decoding of "legacy" zstd payloads from versions [0.4, 0.8], matching the
// default configuration of the zstd command line tool:
// https://github.com/facebook/zstd/blob/dev/programs/README.md
-#cgo CFLAGS: -DZSTD_LEGACY_SUPPORT=4
+#cgo CFLAGS: -DZSTD_LEGACY_SUPPORT=4 -DZSTD_MULTITHREAD=1
#include "zstd.h"
*/
zstd_stream.go: 23 additions & 0 deletions
@@ -72,6 +72,7 @@ import (

var errShortRead = errors.New("short read")
var errReaderClosed = errors.New("Reader is closed")
var ErrNoParallelSupport = errors.New("No parallel support")

// Writer is an io.WriteCloser that zstd-compresses its input.
type Writer struct {
@@ -302,6 +303,28 @@ func (w *Writer) Close() error {
return getError(int(C.ZSTD_freeCStream(w.ctx)))
}

// SetNbWorkers sets the number of worker threads used to run compression in parallel.
// If n > 1, Write() becomes asynchronous: data is buffered internally until the workers have processed it.
// If you call Write() faster than the workers can compress, the internal buffer can grow up to the size of your input.
// Consider calling Flush() periodically if you compress a very large file that would not fit in memory all at once.
// By default only one worker is used.
func (w *Writer) SetNbWorkers(n int) error {
	if w.firstError != nil {
		return w.firstError
	}
	if err := getError(int(C.ZSTD_CCtx_setParameter(w.ctx, C.ZSTD_c_nbWorkers, C.int(n)))); err != nil {
		w.firstError = err
		// First error case: a shared library is in use and it was compiled without parallel support
		if err.Error() == "Unsupported parameter" {
			return ErrNoParallelSupport
		} else {
			// This could happen if a very large number is passed in: zstd may refuse to create that many threads, or the OS may fail to do so
			return err
		}
	}
	return nil
}

// cSize is the recommended size of reader.compressionBuffer. This func and
// invocation allow for a one-time check for validity.
var cSize = func() int {
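The doc comment above pairs SetNbWorkers with periodic Flush() calls for very large inputs. Below is a minimal usage sketch of the new API under stated assumptions: the import path github.com/DataDog/zstd, the compressStream helper name, the worker count of 4, the 1 MiB read chunks, and the 64 MiB flush cadence are illustrative choices, not part of this commit.

package example

import (
	"io"

	"github.com/DataDog/zstd" // assumed import path for this package; adjust as needed
)

// compressStream zstd-compresses src into dst, asking for four worker threads
// and falling back to the default single worker when parallel support is missing.
func compressStream(dst io.Writer, src io.Reader) error {
	w := zstd.NewWriter(dst)

	if err := w.SetNbWorkers(4); err != nil && err != zstd.ErrNoParallelSupport {
		w.Close()
		return err
	}
	// On ErrNoParallelSupport we simply keep compressing with one worker.

	const flushEvery = 64 << 20 // illustrative: flush roughly every 64 MiB
	buf := make([]byte, 1<<20)  // illustrative 1 MiB read chunks
	pending := 0
	for {
		n, rerr := src.Read(buf)
		if n > 0 {
			if _, werr := w.Write(buf[:n]); werr != nil {
				w.Close()
				return werr
			}
			pending += n
			// With nbWorkers > 1, Write is asynchronous; a periodic Flush keeps
			// buffered-but-unprocessed data from growing without bound.
			if pending >= flushEvery {
				if ferr := w.Flush(); ferr != nil {
					w.Close()
					return ferr
				}
				pending = 0
			}
		}
		if rerr == io.EOF {
			return w.Close() // Close writes any remaining data and ends the zstd frame.
		}
		if rerr != nil {
			w.Close()
			return rerr
		}
	}
}

Whether parallel compression is actually available depends on the -DZSTD_MULTITHREAD=1 flag added above (or on the shared libzstd in use), which is what ErrNoParallelSupport reports.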
zstd_stream_test.go: 22 additions & 5 deletions
@@ -9,6 +9,7 @@ import (
"log"
"os"
"runtime/debug"
"strings"
"testing"
)

@@ -19,9 +20,16 @@ func failOnError(t *testing.T, msg string, err error) {
}
}

-func testCompressionDecompression(t *testing.T, dict []byte, payload []byte) {
+func testCompressionDecompression(t *testing.T, dict []byte, payload []byte, nbWorkers int) {
var w bytes.Buffer
writer := NewWriterLevelDict(&w, DefaultCompression, dict)

if nbWorkers > 1 {
if err := writer.SetNbWorkers(nbWorkers); err == ErrNoParallelSupport {
t.Skip()
}
}

_, err := writer.Write(payload)
failOnError(t, "Failed writing to compress object", err)
failOnError(t, "Failed to close compress object", writer.Close())
@@ -79,19 +87,19 @@ func TestResize(t *testing.T) {
}

func TestStreamSimpleCompressionDecompression(t *testing.T) {
testCompressionDecompression(t, nil, []byte("Hello world!"))
testCompressionDecompression(t, nil, []byte("Hello world!"), 1)
}

func TestStreamEmptySlice(t *testing.T) {
-	testCompressionDecompression(t, nil, []byte{})
+	testCompressionDecompression(t, nil, []byte{}, 1)
}

func TestZstdReaderLong(t *testing.T) {
var long bytes.Buffer
for i := 0; i < 10000; i++ {
long.Write([]byte("Hellow World!"))
}
-	testCompressionDecompression(t, nil, long.Bytes())
+	testCompressionDecompression(t, nil, long.Bytes(), 1)
}

func doStreamCompressionDecompression() error {
@@ -186,7 +194,7 @@ func TestStreamRealPayload(t *testing.T) {
if raw == nil {
t.Skip(ErrNoPayloadEnv)
}
-	testCompressionDecompression(t, nil, raw)
+	testCompressionDecompression(t, nil, raw, 1)
}

func TestStreamEmptyPayload(t *testing.T) {
@@ -398,12 +406,21 @@ func TestStreamWriteNoGoPointers(t *testing.T) {
})
}

func TestStreamSetNbWorkers(t *testing.T) {
// Build a big string first
s := strings.Repeat("foobaa", 1000*1000)

nbWorkers := 4
testCompressionDecompression(t, nil, []byte(s), nbWorkers)
}

func BenchmarkStreamCompression(b *testing.B) {
if raw == nil {
b.Fatal(ErrNoPayloadEnv)
}
var intermediate bytes.Buffer
w := NewWriter(&intermediate)
// w.SetNbWorkers(8)
defer w.Close()
b.SetBytes(int64(len(raw)))
b.ResetTimer()
