Skip to content

Commit

Permalink
zstd: Detect short invalid signatures
Browse files Browse the repository at this point in the history
Detect short frame signatures.

Fixes #381
  • Loading branch information
klauspost committed May 28, 2021
1 parent 2199375 commit 0106520
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 11 deletions.
25 changes: 25 additions & 0 deletions zstd/decoder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,31 @@ func TestDecoderRegression(t *testing.T) {
}
}

func TestShort(t *testing.T) {
for _, in := range []string{"f", "fo", "foo"} {
inb := []byte(in)
dec, err := NewReader(nil)
if err != nil {
t.Fatal(err)
}
defer dec.Close()

t.Run(fmt.Sprintf("DecodeAll-%d", len(in)), func(t *testing.T) {
_, err := dec.DecodeAll(inb, nil)
if err == nil {
t.Error("want error, got nil")
}
})
t.Run(fmt.Sprintf("Reader-%d", len(in)), func(t *testing.T) {
dec.Reset(bytes.NewReader(inb))
_, err := io.Copy(ioutil.Discard, dec)
if err == nil {
t.Error("want error, got nil")
}
})
}
}

func TestDecoder_Reset(t *testing.T) {
in, err := ioutil.ReadFile("testdata/z000028")
if err != nil {
Expand Down
43 changes: 32 additions & 11 deletions zstd/framedec.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,28 +78,43 @@ func newFrameDec(o decoderOptions) *frameDec {
func (d *frameDec) reset(br byteBuffer) error {
d.HasCheckSum = false
d.WindowSize = 0
var b []byte
var signature [4]byte
for {
var err error
b, err = br.readSmall(4)
// Check if we can read more...
b, err := br.readSmall(1)
switch err {
case io.EOF, io.ErrUnexpectedEOF:
return io.EOF
default:
return err
case nil:
signature[0] = b[0]
}
// Read the rest, don't allow io.ErrUnexpectedEOF
b, err = br.readSmall(3)
switch err {
case io.EOF:
return io.EOF
default:
return err
case nil:
copy(signature[1:], b)
}
if !bytes.Equal(b[1:4], skippableFrameMagic) || b[0]&0xf0 != 0x50 {

if !bytes.Equal(signature[1:4], skippableFrameMagic) || signature[0]&0xf0 != 0x50 {
if debugDecoder {
println("Not skippable", hex.EncodeToString(b), hex.EncodeToString(skippableFrameMagic))
println("Not skippable", hex.EncodeToString(signature[:]), hex.EncodeToString(skippableFrameMagic))
}
// Break if not skippable frame.
break
}
// Read size to skip
b, err = br.readSmall(4)
if err != nil {
println("Reading Frame Size", err)
if debugDecoder {
println("Reading Frame Size", err)
}
return err
}
n := uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24)
Expand All @@ -112,15 +127,19 @@ func (d *frameDec) reset(br byteBuffer) error {
return err
}
}
if !bytes.Equal(b, frameMagic) {
println("Got magic numbers: ", b, "want:", frameMagic)
if !bytes.Equal(signature[:], frameMagic) {
if debugDecoder {
println("Got magic numbers: ", signature, "want:", frameMagic)
}
return ErrMagicMismatch
}

// Read Frame_Header_Descriptor
fhd, err := br.readByte()
if err != nil {
println("Reading Frame_Header_Descriptor", err)
if debugDecoder {
println("Reading Frame_Header_Descriptor", err)
}
return err
}
d.SingleSegment = fhd&(1<<5) != 0
Expand All @@ -135,7 +154,9 @@ func (d *frameDec) reset(br byteBuffer) error {
if !d.SingleSegment {
wd, err := br.readByte()
if err != nil {
println("Reading Window_Descriptor", err)
if debugDecoder {
println("Reading Window_Descriptor", err)
}
return err
}
printf("raw: %x, mantissa: %d, exponent: %d\n", wd, wd&7, wd>>3)
Expand All @@ -153,7 +174,7 @@ func (d *frameDec) reset(br byteBuffer) error {
size = 4
}

b, err = br.readSmall(int(size))
b, err := br.readSmall(int(size))
if err != nil {
println("Reading Dictionary_ID", err)
return err
Expand Down Expand Up @@ -191,7 +212,7 @@ func (d *frameDec) reset(br byteBuffer) error {
}
d.FrameContentSize = 0
if fcsSize > 0 {
b, err = br.readSmall(fcsSize)
b, err := br.readSmall(fcsSize)
if err != nil {
println("Reading Frame content", err)
return err
Expand Down

0 comments on commit 0106520

Please sign in to comment.