From 010652059a70c53c1d81a6aecfe8d901c0abb5e8 Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Fri, 28 May 2021 09:44:48 +0200 Subject: [PATCH] zstd: Detect short invalid signatures Detect short frame signatures. Fixes #381 --- zstd/decoder_test.go | 25 +++++++++++++++++++++++++ zstd/framedec.go | 43 ++++++++++++++++++++++++++++++++----------- 2 files changed, 57 insertions(+), 11 deletions(-) diff --git a/zstd/decoder_test.go b/zstd/decoder_test.go index 0476751667..fcc5dd98a8 100644 --- a/zstd/decoder_test.go +++ b/zstd/decoder_test.go @@ -575,6 +575,31 @@ func TestDecoderRegression(t *testing.T) { } } +func TestShort(t *testing.T) { + for _, in := range []string{"f", "fo", "foo"} { + inb := []byte(in) + dec, err := NewReader(nil) + if err != nil { + t.Fatal(err) + } + defer dec.Close() + + t.Run(fmt.Sprintf("DecodeAll-%d", len(in)), func(t *testing.T) { + _, err := dec.DecodeAll(inb, nil) + if err == nil { + t.Error("want error, got nil") + } + }) + t.Run(fmt.Sprintf("Reader-%d", len(in)), func(t *testing.T) { + dec.Reset(bytes.NewReader(inb)) + _, err := io.Copy(ioutil.Discard, dec) + if err == nil { + t.Error("want error, got nil") + } + }) + } +} + func TestDecoder_Reset(t *testing.T) { in, err := ioutil.ReadFile("testdata/z000028") if err != nil { diff --git a/zstd/framedec.go b/zstd/framedec.go index 52d035ee9c..e8cc9a2c22 100644 --- a/zstd/framedec.go +++ b/zstd/framedec.go @@ -78,20 +78,33 @@ func newFrameDec(o decoderOptions) *frameDec { func (d *frameDec) reset(br byteBuffer) error { d.HasCheckSum = false d.WindowSize = 0 - var b []byte + var signature [4]byte for { var err error - b, err = br.readSmall(4) + // Check if we can read more... + b, err := br.readSmall(1) switch err { case io.EOF, io.ErrUnexpectedEOF: return io.EOF default: return err case nil: + signature[0] = b[0] + } + // Read the rest, don't allow io.ErrUnexpectedEOF + b, err = br.readSmall(3) + switch err { + case io.EOF: + return io.EOF + default: + return err + case nil: + copy(signature[1:], b) } - if !bytes.Equal(b[1:4], skippableFrameMagic) || b[0]&0xf0 != 0x50 { + + if !bytes.Equal(signature[1:4], skippableFrameMagic) || signature[0]&0xf0 != 0x50 { if debugDecoder { - println("Not skippable", hex.EncodeToString(b), hex.EncodeToString(skippableFrameMagic)) + println("Not skippable", hex.EncodeToString(signature[:]), hex.EncodeToString(skippableFrameMagic)) } // Break if not skippable frame. break @@ -99,7 +112,9 @@ func (d *frameDec) reset(br byteBuffer) error { // Read size to skip b, err = br.readSmall(4) if err != nil { - println("Reading Frame Size", err) + if debugDecoder { + println("Reading Frame Size", err) + } return err } n := uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24) @@ -112,15 +127,19 @@ func (d *frameDec) reset(br byteBuffer) error { return err } } - if !bytes.Equal(b, frameMagic) { - println("Got magic numbers: ", b, "want:", frameMagic) + if !bytes.Equal(signature[:], frameMagic) { + if debugDecoder { + println("Got magic numbers: ", signature, "want:", frameMagic) + } return ErrMagicMismatch } // Read Frame_Header_Descriptor fhd, err := br.readByte() if err != nil { - println("Reading Frame_Header_Descriptor", err) + if debugDecoder { + println("Reading Frame_Header_Descriptor", err) + } return err } d.SingleSegment = fhd&(1<<5) != 0 @@ -135,7 +154,9 @@ func (d *frameDec) reset(br byteBuffer) error { if !d.SingleSegment { wd, err := br.readByte() if err != nil { - println("Reading Window_Descriptor", err) + if debugDecoder { + println("Reading Window_Descriptor", err) + } return err } printf("raw: %x, mantissa: %d, exponent: %d\n", wd, wd&7, wd>>3) @@ -153,7 +174,7 @@ func (d *frameDec) reset(br byteBuffer) error { size = 4 } - b, err = br.readSmall(int(size)) + b, err := br.readSmall(int(size)) if err != nil { println("Reading Dictionary_ID", err) return err @@ -191,7 +212,7 @@ func (d *frameDec) reset(br byteBuffer) error { } d.FrameContentSize = 0 if fcsSize > 0 { - b, err = br.readSmall(fcsSize) + b, err := br.readSmall(fcsSize) if err != nil { println("Reading Frame content", err) return err