Skip to content

Commit

Permalink
Also check write/flush after close.
Browse files Browse the repository at this point in the history
  • Loading branch information
klauspost committed Oct 8, 2024
1 parent a452cbb commit 05c93ea
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 5 deletions.
24 changes: 19 additions & 5 deletions zstd/encoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,9 @@ func (e *Encoder) ResetContentSize(w io.Writer, size int64) {
// and write CRC if requested.
func (e *Encoder) Write(p []byte) (n int, err error) {
s := &e.state
if s.eofWritten {
return 0, ErrEncoderClosed
}
for len(p) > 0 {
if len(p)+len(s.filling) < e.o.blockSize {
if e.o.crc {
Expand Down Expand Up @@ -289,6 +292,9 @@ func (e *Encoder) nextBlock(final bool) error {
s.filling, s.current, s.previous = s.previous[:0], s.filling, s.current
s.nInput += int64(len(s.current))
s.wg.Add(1)
if final {
s.eofWritten = true
}
go func(src []byte) {
if debugEncoder {
println("Adding block,", len(src), "bytes, final:", final)
Expand All @@ -304,9 +310,6 @@ func (e *Encoder) nextBlock(final bool) error {
blk := enc.Block()
enc.Encode(blk, src)
blk.last = final
if final {
s.eofWritten = true
}
// Wait for pending writes.
s.wWg.Wait()
if s.writeErr != nil {
Expand Down Expand Up @@ -402,12 +405,20 @@ func (e *Encoder) Flush() error {
if len(s.filling) > 0 {
err := e.nextBlock(false)
if err != nil {
// Ignore Flush after Close.
if errors.Is(s.err, ErrEncoderClosed) {
return nil
}
return err
}
}
s.wg.Wait()
s.wWg.Wait()
if s.err != nil {
// Ignore Flush after Close.
if errors.Is(s.err, ErrEncoderClosed) {
return nil
}
return s.err
}
return s.writeErr
Expand All @@ -418,11 +429,14 @@ func (e *Encoder) Flush() error {
// The Encoder can still be re-used after calling this.
func (e *Encoder) Close() error {
s := &e.state
if s.encoder == nil || errors.Is(s.err, ErrDecoderClosed) {
if s.encoder == nil {
return nil
}
err := e.nextBlock(true)
if err != nil {
if errors.Is(s.err, ErrEncoderClosed) {
return nil
}
return err
}
if s.frameContentSize > 0 {
Expand Down Expand Up @@ -461,7 +475,7 @@ func (e *Encoder) Close() error {
_, s.err = s.w.Write(frame)
}
if s.err == nil {
s.err = ErrDecoderClosed
s.err = ErrEncoderClosed
return nil
}

Expand Down
5 changes: 5 additions & 0 deletions zstd/encoder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package zstd

import (
"bytes"
"errors"
"fmt"
"io"
"math/rand"
Expand Down Expand Up @@ -282,6 +283,10 @@ func TestEncoderRegression(t *testing.T) {
if err != nil {
t.Error(err)
}
_, err = enc.Write([]byte{1, 2, 3, 4})
if !errors.Is(err, ErrEncoderClosed) {
t.Errorf("unexpected error: %v", err)
}
encoded = dst.Bytes()
if len(encoded) > enc.MaxEncodedSize(len(in)) {
t.Errorf("max encoded size for %v: got: %d, want max: %d", len(in), len(encoded), enc.MaxEncodedSize(len(in)))
Expand Down

0 comments on commit 05c93ea

Please sign in to comment.