Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

zstd: Detect short invalid signatures #382

Merged
merged 1 commit into from
May 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions zstd/decoder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,31 @@ func TestDecoderRegression(t *testing.T) {
}
}

func TestShort(t *testing.T) {
for _, in := range []string{"f", "fo", "foo"} {
inb := []byte(in)
dec, err := NewReader(nil)
if err != nil {
t.Fatal(err)
}
defer dec.Close()

t.Run(fmt.Sprintf("DecodeAll-%d", len(in)), func(t *testing.T) {
_, err := dec.DecodeAll(inb, nil)
if err == nil {
t.Error("want error, got nil")
}
})
t.Run(fmt.Sprintf("Reader-%d", len(in)), func(t *testing.T) {
dec.Reset(bytes.NewReader(inb))
_, err := io.Copy(ioutil.Discard, dec)
if err == nil {
t.Error("want error, got nil")
}
})
}
}

func TestDecoder_Reset(t *testing.T) {
in, err := ioutil.ReadFile("testdata/z000028")
if err != nil {
Expand Down
43 changes: 32 additions & 11 deletions zstd/framedec.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,28 +78,43 @@ func newFrameDec(o decoderOptions) *frameDec {
func (d *frameDec) reset(br byteBuffer) error {
d.HasCheckSum = false
d.WindowSize = 0
var b []byte
var signature [4]byte
for {
var err error
b, err = br.readSmall(4)
// Check if we can read more...
b, err := br.readSmall(1)
switch err {
case io.EOF, io.ErrUnexpectedEOF:
return io.EOF
default:
return err
case nil:
signature[0] = b[0]
}
// Read the rest, don't allow io.ErrUnexpectedEOF
b, err = br.readSmall(3)
switch err {
case io.EOF:
return io.EOF
default:
return err
case nil:
copy(signature[1:], b)
}
if !bytes.Equal(b[1:4], skippableFrameMagic) || b[0]&0xf0 != 0x50 {

if !bytes.Equal(signature[1:4], skippableFrameMagic) || signature[0]&0xf0 != 0x50 {
if debugDecoder {
println("Not skippable", hex.EncodeToString(b), hex.EncodeToString(skippableFrameMagic))
println("Not skippable", hex.EncodeToString(signature[:]), hex.EncodeToString(skippableFrameMagic))
}
// Break if not skippable frame.
break
}
// Read size to skip
b, err = br.readSmall(4)
if err != nil {
println("Reading Frame Size", err)
if debugDecoder {
println("Reading Frame Size", err)
}
return err
}
n := uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24)
Expand All @@ -112,15 +127,19 @@ func (d *frameDec) reset(br byteBuffer) error {
return err
}
}
if !bytes.Equal(b, frameMagic) {
println("Got magic numbers: ", b, "want:", frameMagic)
if !bytes.Equal(signature[:], frameMagic) {
if debugDecoder {
println("Got magic numbers: ", signature, "want:", frameMagic)
}
return ErrMagicMismatch
}

// Read Frame_Header_Descriptor
fhd, err := br.readByte()
if err != nil {
println("Reading Frame_Header_Descriptor", err)
if debugDecoder {
println("Reading Frame_Header_Descriptor", err)
}
return err
}
d.SingleSegment = fhd&(1<<5) != 0
Expand All @@ -135,7 +154,9 @@ func (d *frameDec) reset(br byteBuffer) error {
if !d.SingleSegment {
wd, err := br.readByte()
if err != nil {
println("Reading Window_Descriptor", err)
if debugDecoder {
println("Reading Window_Descriptor", err)
}
return err
}
printf("raw: %x, mantissa: %d, exponent: %d\n", wd, wd&7, wd>>3)
Expand All @@ -153,7 +174,7 @@ func (d *frameDec) reset(br byteBuffer) error {
size = 4
}

b, err = br.readSmall(int(size))
b, err := br.readSmall(int(size))
if err != nil {
println("Reading Dictionary_ID", err)
return err
Expand Down Expand Up @@ -191,7 +212,7 @@ func (d *frameDec) reset(br byteBuffer) error {
}
d.FrameContentSize = 0
if fcsSize > 0 {
b, err = br.readSmall(fcsSize)
b, err := br.readSmall(fcsSize)
if err != nil {
println("Reading Frame content", err)
return err
Expand Down