diff --git a/zstd_stream.go b/zstd_stream.go
index b351c27..5e5718f 100644
--- a/zstd_stream.go
+++ b/zstd_stream.go
@@ -421,29 +421,66 @@ func (r *reader) Read(p []byte) (int, error) {
 		return 0, r.firstError
 	}
 
-	// If we already have enough bytes, return
-	if r.decompSize-r.decompOff >= len(p) {
-		copy(p, r.decompressionBuffer[r.decompOff:])
-		r.decompOff += len(p)
-		return len(p), nil
+	if len(p) == 0 {
+		return 0, nil
+	}
+
+	// If we already have some uncompressed bytes, return without blocking
+	if r.decompSize > r.decompOff {
+		if r.decompSize-r.decompOff > len(p) {
+			copy(p, r.decompressionBuffer[r.decompOff:])
+			r.decompOff += len(p)
+			return len(p), nil
+		}
+		// From https://golang.org/pkg/io/#Reader
+		// > Read conventionally returns what is available instead of waiting for more.
+		copy(p, r.decompressionBuffer[r.decompOff:r.decompSize])
+		got := r.decompSize - r.decompOff
+		r.decompOff = r.decompSize
+		return got, nil
 	}
 
-	copy(p, r.decompressionBuffer[r.decompOff:r.decompSize])
-	got := r.decompSize - r.decompOff
-	r.decompSize = 0
-	r.decompOff = 0
-
-	for got < len(p) {
-		// Populate src
-		src := r.compressionBuffer
-		reader := r.underlyingReader
-		n, err := TryReadFull(reader, src[r.compressionLeft:])
-		if err != nil && err != errShortRead { // Handle underlying reader errors first
-			return 0, fmt.Errorf("failed to read from underlying reader: %s", err)
-		} else if n == 0 && r.compressionLeft == 0 {
-			return got, io.EOF
+	// Repeatedly read from the underlying reader until we get
+	// at least one zstd block, so that we don't block if the
+	// other end has flushed a block.
+	for {
+		// - If the last decompression didn't entirely fill the decompression buffer,
+		//   zstd flushed all it could, and needs new data. In that case, do 1 Read.
+		// - If the last decompression did entirely fill the decompression buffer,
+		//   it might have needed more room to decompress the input. In that case,
+		//   don't do any unnecessary Read that might block.
+		needsData := r.decompSize < len(r.decompressionBuffer)
+
+		var src []byte
+		if !needsData {
+			src = r.compressionBuffer[:r.compressionLeft]
+		} else {
+			src = r.compressionBuffer
+			var n int
+			var err error
+			// Read until data arrives or an error occurs.
+			for n == 0 && err == nil {
+				n, err = r.underlyingReader.Read(src[r.compressionLeft:])
+			}
+			if err != nil && err != io.EOF { // Handle underlying reader errors first
+				return 0, fmt.Errorf("failed to read from underlying reader: %s", err)
+			}
+			if n == 0 {
+				// Ideally, we'd return with ErrUnexpectedEOF in all cases where the stream was unexpectedly EOF'd
+				// during a block or frame, i.e. when there is incomplete, pending compression data.
+				// However, it's hard to detect those cases with zstd. Namely, there is no way to know the size of
+				// the current buffered compression data in the zstd stream internal buffers.
+				// Best effort: return ErrUnexpectedEOF if we still have some pending buffered compression data that
+				// zstd doesn't want to accept.
+				// If we don't have any buffered compression data but zstd still has some in its internal buffers,
+				// we will return with EOF instead.
+				if r.compressionLeft > 0 {
+					return 0, io.ErrUnexpectedEOF
+				}
+				return 0, io.EOF
+			}
+			src = src[:r.compressionLeft+n]
 		}
-		src = src[:r.compressionLeft+n]
 
 		// C code
 		var srcPtr *byte // Do not point anywhere, if src is empty
@@ -461,9 +498,9 @@ func (r *reader) Read(p []byte) (int, error) {
 		)
 		retCode := int(r.resultBuffer.return_code)
 
-		// Keep src here eventhough we reuse later, the code might be deleted at some point
+		// Keep src here even though we reuse later, the code might be deleted at some point
 		runtime.KeepAlive(src)
-		if err = getError(retCode); err != nil {
+		if err := getError(retCode); err != nil {
 			return 0, fmt.Errorf("failed to decompress: %s", err)
 		}
 
@@ -473,10 +510,9 @@ func (r *reader) Read(p []byte) (int, error) {
 			left := src[bytesConsumed:]
 			copy(r.compressionBuffer, left)
 		}
-		r.compressionLeft = len(src) - int(bytesConsumed)
+		r.compressionLeft = len(src) - bytesConsumed
 		r.decompSize = int(r.resultBuffer.bytes_written)
-		r.decompOff = copy(p[got:], r.decompressionBuffer[:r.decompSize])
-		got += r.decompOff
+		r.decompOff = copy(p, r.decompressionBuffer[:r.decompSize])
 
 		// Resize buffers
 		nsize := retCode // Hint for next src buffer size
@@ -488,25 +524,9 @@ func (r *reader) Read(p []byte) (int, error) {
 			nsize = r.compressionLeft
 		}
 		r.compressionBuffer = resize(r.compressionBuffer, nsize)
-	}
-	return got, nil
-}
 
-// TryReadFull reads buffer just as ReadFull does
-// Here we expect that buffer may end and we do not return ErrUnexpectedEOF as ReadAtLeast does.
-// We return errShortRead instead to distinguish short reads and failures.
-// We cannot use ReadFull/ReadAtLeast because it masks Reader errors, such as network failures
-// and causes panic instead of error.
-func TryReadFull(r io.Reader, buf []byte) (n int, err error) {
-	for n < len(buf) && err == nil {
-		var nn int
-		nn, err = r.Read(buf[n:])
-		n += nn
-	}
-	if n == len(buf) && err == io.EOF {
-		err = nil // EOF at the end is somewhat expected
-	} else if err == io.EOF {
-		err = errShortRead
+		if r.decompOff > 0 {
+			return r.decompOff, nil
+		}
 	}
-	return
 }
diff --git a/zstd_stream_test.go b/zstd_stream_test.go
index 79f412e..bad908a 100644
--- a/zstd_stream_test.go
+++ b/zstd_stream_test.go
@@ -39,7 +39,7 @@ func testCompressionDecompression(t *testing.T, dict []byte, payload []byte) {
 	// Decompress
 	r := NewReaderDict(rr, dict)
 	dst := make([]byte, len(payload))
-	n, err := r.Read(dst)
+	n, err := io.ReadFull(r, dst)
 	if err != nil {
 		failOnError(t, "Failed to read for decompression", err)
 	}
@@ -211,9 +211,16 @@ func TestStreamEmptyPayload(t *testing.T) {
 }
 
 func TestStreamFlush(t *testing.T) {
-	var w bytes.Buffer
-	writer := NewWriter(&w)
-	reader := NewReader(&w)
+	// use an actual os pipe so that
+	// - it's buffered and we don't get a 1-read = 1-write behaviour (io.Pipe)
+	// - reading doesn't send EOF when we're done reading the buffer (bytes.Buffer)
+	pr, pw, err := os.Pipe()
+	failOnError(t, "Failed creating pipe", err)
+	defer pw.Close()
+	defer pr.Close()
+
+	writer := NewWriter(pw)
+	reader := NewReader(pr)
 
 	payload := "cc" // keep the payload short to make sure it will not be automatically flushed by zstd
 	buf := make([]byte, len(payload))
@@ -429,7 +436,7 @@ func BenchmarkStreamDecompression(b *testing.B) {
 	for i := 0; i < b.N; i++ {
 		rr := bytes.NewReader(compressed)
 		r := NewReader(rr)
-		_, err := r.Read(dst)
+		_, err := io.ReadFull(r, dst)
 		if err != nil {
 			b.Fatalf("Failed to decompress: %s", err)
 		}
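For reviewers, here is a minimal sketch (not part of the diff) of the behaviour this patch enables, mirroring TestStreamFlush. It assumes the package is imported as `github.com/DataDog/zstd`; `NewWriter`, `NewReader`, and the writer's `Flush` are the same APIs the tests above exercise.

```go
package main

import (
	"fmt"
	"os"

	"github.com/DataDog/zstd" // assumed import path for this package
)

func main() {
	// As in TestStreamFlush: an os pipe buffers writes and doesn't
	// report EOF while the write end is still open.
	pr, pw, err := os.Pipe()
	if err != nil {
		panic(err)
	}
	defer pr.Close()
	defer pw.Close()

	writer := zstd.NewWriter(pw)
	reader := zstd.NewReader(pr)
	defer reader.Close()
	defer writer.Close()

	// Write a short payload and flush it so that a complete zstd
	// block is sitting on the pipe.
	if _, err := writer.Write([]byte("cc")); err != nil {
		panic(err)
	}
	if err := writer.Flush(); err != nil {
		panic(err)
	}

	// With this patch, Read returns the 2 flushed bytes right away;
	// before it, Read would block trying to fill all 16 bytes.
	buf := make([]byte, 16)
	n, err := reader.Read(buf)
	fmt.Printf("n=%d err=%v data=%q\n", n, err, buf[:n]) // n=2 err=<nil> data="cc"
}
```

The trade-off, reflected in the test changes, is that a single `Read` may now return fewer bytes than `len(p)` even before EOF, so callers that need an exactly-full buffer should wrap the call in `io.ReadFull`, as `testCompressionDecompression` and `BenchmarkStreamDecompression` now do.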