diff --git a/flate/deflate.go b/flate/deflate.go index 944c90ea85..c3d3aabe24 100644 --- a/flate/deflate.go +++ b/flate/deflate.go @@ -425,9 +425,7 @@ Loop: } d.hash = newH } - d.index = newIndex - } else { // For matches this long, we don't bother inserting each individual // item into the table. @@ -480,7 +478,6 @@ func (d *compressor) deflateNoSkip() { d.hash = d.hasher(d.window[d.index:d.index+minMatchLength]) & hashMask } -Loop: for { if sanity && d.index > d.windowEnd { panic("index > windowEnd") @@ -488,7 +485,7 @@ Loop: lookahead := d.windowEnd - d.index if lookahead < minMatchLength+maxMatchLength { if !d.sync { - break Loop + return } if sanity && d.index > d.windowEnd { panic("index > windowEnd") @@ -507,7 +504,7 @@ Loop: } d.tokens.n = 0 } - break Loop + return } } if d.index < d.maxInsertIndex { @@ -538,6 +535,7 @@ Loop: // not better. Output the previous match. d.tokens.tokens[d.tokens.n] = matchToken(uint32(prevLength-3), uint32(prevOffset-minOffsetSize)) d.tokens.n++ + // Insert in the hash table all strings up to the end of the match. // index and index-1 are already inserted. If there is not enough // lookahead, the last two strings are not inserted into the hash @@ -573,7 +571,6 @@ Loop: } d.index = newIndex - d.byteAvailable = false d.length = minMatchLength - 1 if d.tokens.n == maxFlateBlockTokens { @@ -588,13 +585,13 @@ Loop: if d.length >= minMatchLength { d.ii = 0 } + // We have a byte waiting. Emit it. if d.byteAvailable { d.ii++ - i := d.index - 1 - d.tokens.tokens[d.tokens.n] = literalToken(uint32(d.window[i])) + d.tokens.tokens[d.tokens.n] = literalToken(uint32(d.window[d.index-1])) d.tokens.n++ if d.tokens.n == maxFlateBlockTokens { - if d.err = d.writeBlock(d.tokens, i+1, false); d.err != nil { + if d.err = d.writeBlock(d.tokens, d.index, false); d.err != nil { return } d.tokens.n = 0 @@ -604,29 +601,38 @@ Loop: // If we have a long run of no matches, skip additional bytes // Resets when d.ii overflows after 64KB. if d.ii > 31 { - n := int(d.ii >> 5) + n := int(d.ii >> 6) for j := 0; j < n; j++ { - i := d.index - 1 if d.index >= d.windowEnd-1 { break } - d.tokens.tokens[d.tokens.n] = literalToken(uint32(d.window[i])) + d.tokens.tokens[d.tokens.n] = literalToken(uint32(d.window[d.index-1])) d.tokens.n++ if d.tokens.n == maxFlateBlockTokens { - if d.err = d.writeBlock(d.tokens, i+1, false); d.err != nil { + if d.err = d.writeBlock(d.tokens, d.index, false); d.err != nil { return } d.tokens.n = 0 } d.index++ } + // Flush last byte + d.tokens.tokens[d.tokens.n] = literalToken(uint32(d.window[d.index-1])) + d.tokens.n++ + d.byteAvailable = false + // d.length = minMatchLength - 1 // not needed, since d.ii is reset above, so it should never be > minMatchLength + if d.tokens.n == maxFlateBlockTokens { + if d.err = d.writeBlock(d.tokens, d.index, false); d.err != nil { + return + } + d.tokens.n = 0 + } } - } else { d.index++ + d.byteAvailable = true } - d.byteAvailable = true } } } @@ -659,6 +665,7 @@ func (d *compressor) storeHuff() { return } d.w.writeBlockHuff(false, d.window[:d.windowEnd]) + d.err = d.w.err d.windowEnd = 0 } @@ -672,22 +679,31 @@ func (d *compressor) storeSnappy() { } snappyEncode(&d.tokens, d.window[:d.windowEnd]) d.w.writeBlock(d.tokens, false, d.window[:d.windowEnd]) + d.err = d.w.err d.tokens.n = 0 d.windowEnd = 0 } func (d *compressor) write(b []byte) (n int, err error) { + if d.err != nil { + return 0, d.err + } n = len(b) - b = b[d.fill(d, b):] for len(b) > 0 { d.step(d) b = b[d.fill(d, b):] + if d.err != nil { + return 0, d.err + } } return n, d.err } func (d *compressor) syncFlush() error { d.sync = true + if d.err != nil { + return d.err + } d.step(d) if d.err == nil { d.w.writeStoredHeader(0, false) @@ -733,9 +749,7 @@ func (d *compressor) init(w io.Writer, level int) (err error) { return nil } -var zeroes [64]int var hzeroes [256]hashid -var bzeroes [256]byte func (d *compressor) reset(w io.Writer) { d.w.reset(w) @@ -769,6 +783,9 @@ func (d *compressor) reset(w io.Writer) { } func (d *compressor) close() error { + if d.err != nil { + return d.err + } d.sync = true d.step(d) if d.err != nil { diff --git a/flate/deflate_test.go b/flate/deflate_test.go index 318adfdb14..3741dea89d 100644 --- a/flate/deflate_test.go +++ b/flate/deflate_test.go @@ -540,6 +540,8 @@ func TestWriterReset(t *testing.T) { w.d.hasher, wref.d.hasher = nil, nil w.d.bulkHasher, wref.d.bulkHasher = nil, nil w.d.matcher, wref.d.matcher = nil, nil + // hashMatch is always overwritten when used. + copy(w.d.hashMatch[:], wref.d.hashMatch[:]) if w.d.tokens.n != 0 { t.Errorf("level %d Writer not reset after Reset. %d tokens were present", level, w.d.tokens.n) } @@ -601,3 +603,67 @@ func testResetOutput(t *testing.T, newWriter func(w io.Writer) (*Writer, error)) } t.Logf("got %d bytes", len(out1)) } + +// A writer that fails after N writes. +type errorWriter struct { + N int +} + +func (e *errorWriter) Write(b []byte) (int, error) { + if e.N <= 0 { + return 0, io.ErrClosedPipe + } + e.N-- + return len(b), nil +} + +// Test if errors from the underlying writer is passed upwards. +func TestWriteError(t *testing.T) { + buf := new(bytes.Buffer) + for i := 0; i < 1024*1024; i++ { + buf.WriteString(fmt.Sprintf("asdasfasf%d%dfghfgujyut%dyutyu\n", i, i, i)) + } + in := buf.Bytes() + for l := -2; l < 10; l++ { + for fail := 1; fail <= 512; fail *= 2 { + // Fail after 2 writes + ew := &errorWriter{N: fail} + w, err := NewWriter(ew, l) + if err != nil { + t.Errorf("NewWriter: level %d: %v", l, err) + } + n, err := io.Copy(w, bytes.NewBuffer(in)) + if err == nil { + t.Errorf("Level %d: Expected an error, writer was %#v", l, ew) + } + n2, err := w.Write([]byte{1, 2, 2, 3, 4, 5}) + if n2 != 0 { + t.Error("Level", l, "Expected 0 length write, got", n) + } + if err == nil { + t.Error("Level", l, "Expected an error") + } + err = w.Flush() + if err == nil { + t.Error("Level", l, "Expected an error on close") + } + err = w.Close() + if err == nil { + t.Error("Level", l, "Expected an error on close") + } + + w.Reset(ioutil.Discard) + n2, err = w.Write([]byte{1, 2, 3, 4, 5, 6}) + if err != nil { + t.Error("Level", l, "Got unexpected error after reset:", err) + } + if n2 == 0 { + t.Error("Level", l, "Got 0 length write, expected > 0") + } + if testing.Short() { + return + } + } + } + +}