diff --git a/zstd/decoder.go b/zstd/decoder.go index 30459cd3fb..2aeb953ca4 100644 --- a/zstd/decoder.go +++ b/zstd/decoder.go @@ -341,15 +341,8 @@ func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) { } return dst, err } - if frame.DictionaryID != nil { - dict, ok := d.dicts[*frame.DictionaryID] - if !ok { - return nil, ErrUnknownDictionary - } - if debugDecoder { - println("setting dict", frame.DictionaryID) - } - frame.history.setDict(&dict) + if err = d.setDict(frame); err != nil { + return nil, err } if frame.WindowSize > d.o.maxWindowSize { if debugDecoder { @@ -495,18 +488,12 @@ func (d *Decoder) nextBlockSync() (ok bool) { if !d.syncStream.inFrame { d.frame.history.reset() d.current.err = d.frame.reset(&d.syncStream.br) + if d.current.err == nil { + d.current.err = d.setDict(d.frame) + } if d.current.err != nil { return false } - if d.frame.DictionaryID != nil { - dict, ok := d.dicts[*d.frame.DictionaryID] - if !ok { - d.current.err = ErrUnknownDictionary - return false - } else { - d.frame.history.setDict(&dict) - } - } if d.frame.WindowSize > d.o.maxDecodedSize || d.frame.WindowSize > d.o.maxWindowSize { d.current.err = ErrDecoderSizeExceeded return false @@ -865,13 +852,8 @@ decodeStream: if debugDecoder && err != nil { println("Frame decoder returned", err) } - if err == nil && frame.DictionaryID != nil { - dict, ok := d.dicts[*frame.DictionaryID] - if !ok { - err = ErrUnknownDictionary - } else { - frame.history.setDict(&dict) - } + if err == nil { + err = d.setDict(frame) } if err == nil && d.frame.WindowSize > d.o.maxWindowSize { if debugDecoder { @@ -953,3 +935,20 @@ decodeStream: hist.reset() d.frame.history.b = frameHistCache } + +func (d *Decoder) setDict(frame *frameDec) (err error) { + dict, ok := d.dicts[frame.DictionaryID] + if ok { + if debugDecoder { + println("setting dict", frame.DictionaryID) + } + frame.history.setDict(&dict) + } else if frame.DictionaryID != 0 { + // A zero or missing dictionary id is ambiguous: + // either 
dictionary zero, or no dictionary. In particular, + // zstd --patch-from uses this id for the source file, + // so only return an error if the dictionary id is not zero. + err = ErrUnknownDictionary + } + return err +} diff --git a/zstd/decoder_options.go b/zstd/decoder_options.go index f42448e69c..c36853d660 100644 --- a/zstd/decoder_options.go +++ b/zstd/decoder_options.go @@ -6,6 +6,8 @@ package zstd import ( "errors" + "fmt" + "math/bits" "runtime" ) @@ -85,7 +87,13 @@ func WithDecoderMaxMemory(n uint64) DOption { } // WithDecoderDicts allows to register one or more dictionaries for the decoder. -// If several dictionaries with the same ID is provided the last one will be used. +// +// Each slice in dicts must be in the [dictionary format] produced by +// "zstd --train" from the Zstandard reference implementation. +// +// If several dictionaries with the same ID are provided, the last one will be used. +// +// [dictionary format]: https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#dictionary-format func WithDecoderDicts(dicts ...[]byte) DOption { return func(o *decoderOptions) error { for _, b := range dicts { @@ -99,6 +107,18 @@ func WithDecoderDicts(dicts ...[]byte) DOption { } } +// WithDecoderDictRaw registers a dictionary that may be used by the decoder. +// The slice content can be arbitrary data. +func WithDecoderDictRaw(id uint32, content []byte) DOption { + return func(o *decoderOptions) error { + if bits.UintSize > 32 && uint(len(content)) > dictMaxLength { + return fmt.Errorf("dictionary of size %d > 2GiB too large", len(content)) + } + o.dicts = append(o.dicts, dict{id: id, content: content, offsets: [3]int{1, 4, 8}}) + return nil + } +} + // WithDecoderMaxWindow allows to set a maximum window size for decodes. // This allows rejecting packets that will cause big memory usage. // The Decoder will likely allocate more memory based on the WithDecoderLowmem setting. 
diff --git a/zstd/dict.go b/zstd/dict.go index b2725f77b5..66a95c18ef 100644 --- a/zstd/dict.go +++ b/zstd/dict.go @@ -21,6 +21,9 @@ type dict struct { const dictMagic = "\x37\xa4\x30\xec" +// Maximum dictionary size for the reference implementation (1.5.3) is 2 GiB. +const dictMaxLength = 1 << 31 + // ID returns the dictionary id or 0 if d is nil. func (d *dict) ID() uint32 { if d == nil { diff --git a/zstd/dict_test.go b/zstd/dict_test.go index 872d594af3..28024d03a4 100644 --- a/zstd/dict_test.go +++ b/zstd/dict_test.go @@ -459,3 +459,38 @@ func readDicts(tb testing.TB, zr *zip.Reader) [][]byte { } return dicts } + +// Test decoding of zstd --patch-from output. +func TestDecoderRawDict(t *testing.T) { + t.Parallel() + + dict, err := os.ReadFile("testdata/delta/source.txt") + if err != nil { + t.Fatal(err) + } + + delta, err := os.Open("testdata/delta/target.txt.zst") + if err != nil { + t.Fatal(err) + } + defer delta.Close() + + dec, err := NewReader(delta, WithDecoderDictRaw(0, dict)) + if err != nil { + t.Fatal(err) + } + + out, err := io.ReadAll(dec) + if err != nil { + t.Fatal(err) + } + + ref, err := os.ReadFile("testdata/delta/target.txt") + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(out, ref) { + t.Errorf("mismatch: got %q, wanted %q", out, ref) + } +} diff --git a/zstd/encoder_options.go b/zstd/encoder_options.go index 6015f498af..8e15be2f7f 100644 --- a/zstd/encoder_options.go +++ b/zstd/encoder_options.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "math" + "math/bits" "runtime" "strings" ) @@ -305,7 +306,13 @@ func WithLowerEncoderMem(b bool) EOption { } // WithEncoderDict allows to register a dictionary that will be used for the encode. +// +// The slice dict must be in the [dictionary format] produced by +// "zstd --train" from the Zstandard reference implementation. +// // The encoder *may* choose to use no dictionary instead for certain payloads. 
+// +// [dictionary format]: https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#dictionary-format func WithEncoderDict(dict []byte) EOption { return func(o *encoderOptions) error { d, err := loadDict(dict) @@ -316,3 +323,17 @@ func WithEncoderDict(dict []byte) EOption { return nil } } + +// WithEncoderDictRaw registers a dictionary that may be used by the encoder. +// +// The slice content may contain arbitrary data. It will be used as an initial +// history. +func WithEncoderDictRaw(id uint32, content []byte) EOption { + return func(o *encoderOptions) error { + if bits.UintSize > 32 && uint(len(content)) > dictMaxLength { + return fmt.Errorf("dictionary of size %d > 2GiB too large", len(content)) + } + o.dict = &dict{id: id, content: content, offsets: [3]int{1, 4, 8}} + return nil + } +} diff --git a/zstd/example_test.go b/zstd/example_test.go new file mode 100644 index 0000000000..13a13326e2 --- /dev/null +++ b/zstd/example_test.go @@ -0,0 +1,46 @@ +package zstd_test + +import ( + "bytes" + "fmt" + + "github.com/klauspost/compress/zstd" +) + +func ExampleWithEncoderDictRaw() { + // "Raw" dictionaries can be used for compressed delta encoding. + + source := []byte(` + This is the source file. Compression of the target file with + the source file as the dictionary will produce a compressed + delta encoding of the target file.`) + target := []byte(` + This is the target file. Decompression of the delta encoding with + the source file as the dictionary will produce this file.`) + + // The dictionary id is arbitrary. We use zero for compatibility + // with zstd --patch-from, but applications can use any id + // not in the range [32768, 1<<31). 
+ const id = 0 + + bestLevel := zstd.WithEncoderLevel(zstd.SpeedBestCompression) + + w, _ := zstd.NewWriter(nil, bestLevel, + zstd.WithEncoderDictRaw(id, source)) + delta := w.EncodeAll(target, nil) + + r, _ := zstd.NewReader(nil, zstd.WithDecoderDictRaw(id, source)) + out, err := r.DecodeAll(delta, nil) + if err != nil || !bytes.Equal(out, target) { + panic("decoding error") + } + + // Ordinary compression, for reference. + w, _ = zstd.NewWriter(nil, bestLevel) + compressed := w.EncodeAll(target, nil) + + // Check that the delta is at most half as big as the compressed file. + fmt.Println(len(delta) < len(compressed)/2) + // Output: + // true +} diff --git a/zstd/framedec.go b/zstd/framedec.go index 65984bf07c..d8e8a05bd7 100644 --- a/zstd/framedec.go +++ b/zstd/framedec.go @@ -29,7 +29,7 @@ type frameDec struct { FrameContentSize uint64 - DictionaryID *uint32 + DictionaryID uint32 HasCheckSum bool SingleSegment bool } @@ -155,7 +155,7 @@ func (d *frameDec) reset(br byteBuffer) error { // Read Dictionary_ID // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#dictionary_id - d.DictionaryID = nil + d.DictionaryID = 0 if size := fhd & 3; size != 0 { if size == 3 { size = 4 @@ -178,11 +178,7 @@ func (d *frameDec) reset(br byteBuffer) error { if debugDecoder { println("Dict size", size, "ID:", id) } - if id > 0 { - // ID 0 means "sorry, no dictionary anyway". 
- // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#dictionary-format - d.DictionaryID = &id - } + d.DictionaryID = id } // Read Frame_Content_Size diff --git a/zstd/testdata/delta/source.txt b/zstd/testdata/delta/source.txt new file mode 100644 index 0000000000..d97357d4db --- /dev/null +++ b/zstd/testdata/delta/source.txt @@ -0,0 +1,5 @@ +0000000000000000 + +This file is to be used as the dictionary for compressing target.txt: + + zstd -19 --patch-from=source.txt target.txt diff --git a/zstd/testdata/delta/target.txt b/zstd/testdata/delta/target.txt new file mode 100644 index 0000000000..a7a6c91ec5 --- /dev/null +++ b/zstd/testdata/delta/target.txt @@ -0,0 +1,5 @@ +0000000000000000 + +This file is to be compressed with source.txt as the dictionary: + + zstd -19 --patch-from=source.txt target.txt diff --git a/zstd/testdata/delta/target.txt.zst b/zstd/testdata/delta/target.txt.zst new file mode 100644 index 0000000000..e2d12c56d8 Binary files /dev/null and b/zstd/testdata/delta/target.txt.zst differ diff --git a/zstd/zstd.go b/zstd/zstd.go index b1886f7c74..5ffa82f5ac 100644 --- a/zstd/zstd.go +++ b/zstd/zstd.go @@ -72,7 +72,6 @@ var ( ErrDecoderSizeExceeded = errors.New("decompressed size exceeds configured limit") // ErrUnknownDictionary is returned if the dictionary ID is unknown. - // For the time being dictionaries are not supported. ErrUnknownDictionary = errors.New("unknown dictionary") // ErrFrameSizeExceeded is returned if the stated frame size is exceeded.