From 54198fb08092a219c9447819de1ffc88af46e37e Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Thu, 24 Oct 2019 17:43:10 +0200 Subject: [PATCH] zstd: Fix reuse of huff0 when data hard to compress (#173) * zstd: Fix reuse of huff0 when data hard to compress zstd would reject huff0 compressed literals if the improvement was too small. However, this would update the huff0 state to contain a new table which could be reused. In that case a wrong table could be used for the next block. We move the rejection code to huff0, so the state can be properly maintained. Fixes #170 --- huff0/compress.go | 16 +++++++++----- huff0/compress_test.go | 49 ++++++++++++++++++++++++++++-------------- huff0/huff0.go | 6 ++++++ zstd/blockenc.go | 10 ++------- zstd/enc_dfast.go | 6 +++--- zstd/zstd.go | 1 + 6 files changed, 56 insertions(+), 32 deletions(-) diff --git a/huff0/compress.go b/huff0/compress.go index 61c6ede080..51e00aaeb2 100644 --- a/huff0/compress.go +++ b/huff0/compress.go @@ -54,6 +54,12 @@ func compress(in []byte, s *Scratch, compressor func(src []byte) ([]byte, error) canReuse = s.canUseTable(s.prevTable) } + // We want the output size to be less than this: + wantSize := len(in) + if s.WantLogLess > 0 { + wantSize -= wantSize >> s.WantLogLess + } + // Reset for next run. s.clearCount = true s.maxCount = 0 @@ -77,7 +83,7 @@ func compress(in []byte, s *Scratch, compressor func(src []byte) ([]byte, error) s.cTable = s.prevTable s.Out, err = compressor(in) s.cTable = keepTable - if err == nil && len(s.Out) < len(in) { + if err == nil && len(s.Out) < wantSize { s.OutData = s.Out return s.Out, true, nil } @@ -100,16 +106,16 @@ func compress(in []byte, s *Scratch, compressor func(src []byte) ([]byte, error) hSize := len(s.Out) oldSize := s.prevTable.estimateSize(s.count[:s.symbolLen]) newSize := s.cTable.estimateSize(s.count[:s.symbolLen]) - if oldSize <= hSize+newSize || hSize+12 >= len(in) { + if oldSize <= hSize+newSize || hSize+12 >= wantSize { // Retain cTable even if we re-use. keepTable := s.cTable s.cTable = s.prevTable s.Out, err = compressor(in) + s.cTable = keepTable if err != nil { return nil, false, err } - s.cTable = keepTable - if len(s.Out) >= len(in) { + if len(s.Out) >= wantSize { return nil, false, ErrIncompressible } s.OutData = s.Out @@ -131,7 +137,7 @@ func compress(in []byte, s *Scratch, compressor func(src []byte) ([]byte, error) s.OutTable = nil return nil, false, err } - if len(s.Out) >= len(in) { + if len(s.Out) >= wantSize { s.OutTable = nil return nil, false, ErrIncompressible } diff --git a/huff0/compress_test.go b/huff0/compress_test.go index f8d6cfeaa7..3bf5fbf7c6 100644 --- a/huff0/compress_test.go +++ b/huff0/compress_test.go @@ -91,19 +91,24 @@ func init() { func TestCompressRegression(t *testing.T) { // Match the fuzz function var testInput = func(data []byte) int { - var sc Scratch - comp, _, err := Compress1X(data, &sc) + var enc Scratch + enc.WantLogLess = 5 + comp, _, err := Compress1X(data, &enc) if err == ErrIncompressible || err == ErrUseRLE || err == ErrTooBig { return 0 } if err != nil { panic(err) } - s, remain, err := ReadTable(comp, nil) + if len(comp) >= len(data)-len(data)>>enc.WantLogLess { + panic(fmt.Errorf("too large output provided. got %d, but should be < %d", len(comp), len(data)-len(data)>>enc.WantLogLess)) + } + + dec, remain, err := ReadTable(comp, nil) if err != nil { panic(err) } - out, err := s.Decompress1X(remain) + out, err := dec.Decompress1X(remain) if err != nil { panic(err) } @@ -111,22 +116,26 @@ func TestCompressRegression(t *testing.T) { panic("decompression 1x mismatch") } // Reuse as 4X - sc.Reuse = ReusePolicyAllow - comp, reUsed, err := Compress4X(data, &sc) + enc.Reuse = ReusePolicyAllow + comp, reUsed, err := Compress4X(data, &enc) if err == ErrIncompressible || err == ErrUseRLE || err == ErrTooBig { return 0 } if err != nil { panic(err) } + if len(comp) >= len(data)-len(data)>>enc.WantLogLess { + panic(fmt.Errorf("too large output provided. got %d, but should be < %d", len(comp), len(data)-len(data)>>enc.WantLogLess)) + } + remain = comp if !reUsed { - s, remain, err = ReadTable(comp, s) + dec, remain, err = ReadTable(comp, dec) if err != nil { panic(err) } } - out, err = s.Decompress4X(remain, len(data)) + out, err = dec.Decompress4X(remain, len(data)) if err != nil { panic(err) } @@ -134,8 +143,8 @@ func TestCompressRegression(t *testing.T) { panic("decompression 4x with reuse mismatch") } - s.Reuse = ReusePolicyNone - comp, reUsed, err = Compress4X(data, s) + enc.Reuse = ReusePolicyNone + comp, reUsed, err = Compress4X(data, &enc) if err == ErrIncompressible || err == ErrUseRLE || err == ErrTooBig { return 0 } @@ -145,11 +154,15 @@ func TestCompressRegression(t *testing.T) { if reUsed { panic("reused when asked not to") } - s, remain, err = ReadTable(comp, nil) + if len(comp) >= len(data)-len(data)>>enc.WantLogLess { + panic(fmt.Errorf("too large output provided. got %d, but should be < %d", len(comp), len(data)-len(data)>>enc.WantLogLess)) + } + + dec, remain, err = ReadTable(comp, dec) if err != nil { panic(err) } - out, err = s.Decompress4X(remain, len(data)) + out, err = dec.Decompress4X(remain, len(data)) if err != nil { panic(err) } @@ -158,22 +171,26 @@ func TestCompressRegression(t *testing.T) { } // Reuse as 1X - s.Reuse = ReusePolicyAllow - comp, reUsed, err = Compress1X(data, &sc) + dec.Reuse = ReusePolicyAllow + comp, reUsed, err = Compress1X(data, &enc) if err == ErrIncompressible || err == ErrUseRLE || err == ErrTooBig { return 0 } if err != nil { panic(err) } + if len(comp) >= len(data)-len(data)>>enc.WantLogLess { + panic(fmt.Errorf("too large output provided. got %d, but should be < %d", len(comp), len(data)-len(data)>>enc.WantLogLess)) + } + remain = comp if !reUsed { - s, remain, err = ReadTable(comp, s) + dec, remain, err = ReadTable(comp, dec) if err != nil { panic(err) } } - out, err = s.Decompress1X(remain) + out, err = dec.Decompress1X(remain) if err != nil { panic(err) } diff --git a/huff0/huff0.go b/huff0/huff0.go index 6f823f94d7..6bc23bbf00 100644 --- a/huff0/huff0.go +++ b/huff0/huff0.go @@ -89,6 +89,12 @@ type Scratch struct { // Reuse will specify the reuse policy Reuse ReusePolicy + // WantLogLess allows to specify a log 2 reduction that should at least be achieved, + // otherwise the block will be returned as incompressible. + // The reduction should then at least be (input size >> WantLogLess) + // If WantLogLess == 0 any improvement will do. + WantLogLess uint8 + // MaxDecodedSize will set the maximum allowed output size. // This value will automatically be set to BlockSizeMax if not set. // Decoders will return ErrMaxDecodedSizeExceeded is this limit is exceeded. diff --git a/zstd/blockenc.go b/zstd/blockenc.go index 9d9151a0ef..1dd8278022 100644 --- a/zstd/blockenc.go +++ b/zstd/blockenc.go @@ -51,7 +51,7 @@ func (b *blockEnc) init() { b.coders.llEnc = &fseEncoder{} b.coders.llPrev = &fseEncoder{} } - b.litEnc = &huff0.Scratch{} + b.litEnc = &huff0.Scratch{WantLogLess: 4} b.reset(nil) } @@ -415,16 +415,10 @@ func (b *blockEnc) encode() error { if len(b.literals) >= 1024 { // Use 4 Streams. out, reUsed, err = huff0.Compress4X(b.literals, b.litEnc) - if len(out) > len(b.literals)-len(b.literals)>>4 { - err = huff0.ErrIncompressible - } } else if len(b.literals) > 32 { // Use 1 stream single = true out, reUsed, err = huff0.Compress1X(b.literals, b.litEnc) - if len(out) > len(b.literals)-len(b.literals)>>4 { - err = huff0.ErrIncompressible - } } else { err = huff0.ErrIncompressible } @@ -711,7 +705,7 @@ func (b *blockEnc) encode() error { return nil } -var errIncompressible = errors.New("uncompressible") +var errIncompressible = errors.New("incompressible") func (b *blockEnc) genCodes() { if len(b.sequences) == 0 { diff --git a/zstd/enc_dfast.go b/zstd/enc_dfast.go index e120625d85..2f41bcd0d5 100644 --- a/zstd/enc_dfast.go +++ b/zstd/enc_dfast.go @@ -235,7 +235,7 @@ encodeLoop: if debug && s-t > e.maxMatchOff { panic("s - t >e.maxMatchOff") } - if debug { + if debugMatches { println("long match") } break @@ -259,7 +259,7 @@ encodeLoop: // but the likelihood of both the first 4 bytes and the hash matching should be enough. t = candidateL.offset - e.cur s += checkAt - if debug { + if debugMatches { println("long match (after short)") } break @@ -275,7 +275,7 @@ encodeLoop: if debug && t < 0 { panic("t<0") } - if debug { + if debugMatches { println("short match") } break diff --git a/zstd/zstd.go b/zstd/zstd.go index b975954c1c..57a8a2f5bb 100644 --- a/zstd/zstd.go +++ b/zstd/zstd.go @@ -11,6 +11,7 @@ import ( const debug = false const debugSequences = false +const debugMatches = false // force encoder to use predefined tables. const forcePreDef = false