Skip to content

Commit

Permalink
zstd: Improve block encoding speed (#456)
Browse files Browse the repository at this point in the history
* zstd: Improve block encoding speed
* Unify loops, avoid check.
  • Loading branch information
klauspost authored Dec 1, 2021
1 parent 6f71bfc commit 901aaf2
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 49 deletions.
22 changes: 21 additions & 1 deletion zstd/bitwriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,34 @@ func (b *bitWriter) addBits16NC(value uint16, bits uint8) {
b.nBits += bits
}

// addBits32NC appends the low `bits` bits of value to the bit container.
// At most 31 bits can be added per call: the mask is looked up with
// bits&31, so a count of 32 wraps around to mask index 0.
// No capacity check is performed ("NC"); the caller must have flushed
// recently enough that the container has room for the new bits.
func (b *bitWriter) addBits32NC(value uint32, bits uint8) {
	// Strip anything above the requested width before widening.
	masked := uint64(value & bitMask32[bits&31])
	// Place the new bits directly above the bits already held.
	b.bitContainer |= masked << (b.nBits & 63)
	b.nBits += bits
}

// addBits64NC appends up to 64 bits of value to the bit container.
// There must be space for 32 bits when it is called.
//
// The bits are handed to addBits32Clean unmasked, so value must not have
// any bits set above the `bits` lowest ones.
func (b *bitWriter) addBits64NC(value uint64, bits uint8) {
	if bits > 31 {
		// Too wide for a single add: write the low 32 bits, flush,
		// then write whatever remains from the high half.
		low, high := uint32(value), uint32(value>>32)
		b.addBits32Clean(low, 32)
		b.flush32()
		b.addBits32Clean(high, bits-32)
		return
	}
	// 31 bits or fewer fit in a single add.
	b.addBits32Clean(uint32(value), bits)
}

// addBits32Clean appends up to 32 bits to the bit container without masking.
// value must be "clean": it may not have bits set beyond the `bits` lowest ones,
// since every set bit of value is OR'ed into the container.
// No capacity check is performed; the caller must ensure there is room.
func (b *bitWriter) addBits32Clean(value uint32, bits uint8) {
	// Offset at which the new bits start inside the 64-bit container.
	shift := b.nBits & 63
	b.bitContainer |= uint64(value) << shift
	b.nBits += bits
}

// addBits16Clean will add up to 16 bits. value may not contain more set bits than indicated.
// It will not check if there is space for them, so the caller must ensure that it has flushed recently.
func (b *bitWriter) addBits16Clean(value uint16, bits uint8) {
Expand Down
97 changes: 49 additions & 48 deletions zstd/blockenc.go
Original file line number Diff line number Diff line change
Expand Up @@ -722,52 +722,53 @@ func (b *blockEnc) encode(org []byte, raw, rawAllLits bool) error {
println("Encoded seq", seq, s, "codes:", s.llCode, s.mlCode, s.ofCode, "states:", ll.state, ml.state, of.state, "bits:", llB, mlB, ofB)
}
seq--
if llEnc.maxBits+mlEnc.maxBits+ofEnc.maxBits <= 32 {
// No need to flush (common)
for seq >= 0 {
s = b.sequences[seq]
wr.flush32()
llB, ofB, mlB := llTT[s.llCode], ofTT[s.ofCode], mlTT[s.mlCode]
// tablelog max is 8 for all.
of.encode(ofB)
ml.encode(mlB)
ll.encode(llB)
wr.flush32()

// We checked that all can stay within 32 bits
wr.addBits32NC(s.litLen, llB.outBits)
wr.addBits32NC(s.matchLen, mlB.outBits)
wr.addBits32NC(s.offset, ofB.outBits)

if debugSequences {
println("Encoded seq", seq, s)
}

seq--
}
} else {
for seq >= 0 {
s = b.sequences[seq]
wr.flush32()
llB, ofB, mlB := llTT[s.llCode], ofTT[s.ofCode], mlTT[s.mlCode]
// tablelog max is below 8 for each.
of.encode(ofB)
ml.encode(mlB)
ll.encode(llB)
wr.flush32()

// ml+ll = max 32 bits total
wr.addBits32NC(s.litLen, llB.outBits)
wr.addBits32NC(s.matchLen, mlB.outBits)
wr.flush32()
wr.addBits32NC(s.offset, ofB.outBits)

if debugSequences {
println("Encoded seq", seq, s)
}

seq--
}
// Store sequences in reverse...
for seq >= 0 {
s = b.sequences[seq]

ofB := ofTT[s.ofCode]
wr.flush32() // tablelog max is below 8 for each, so it will fill max 24 bits.
//of.encode(ofB)
nbBitsOut := (uint32(of.state) + ofB.deltaNbBits) >> 16
dstState := int32(of.state>>(nbBitsOut&15)) + int32(ofB.deltaFindState)
wr.addBits16NC(of.state, uint8(nbBitsOut))
of.state = of.stateTable[dstState]

// Accumulate extra bits.
outBits := ofB.outBits & 31
extraBits := uint64(s.offset & bitMask32[outBits])
extraBitsN := outBits

mlB := mlTT[s.mlCode]
//ml.encode(mlB)
nbBitsOut = (uint32(ml.state) + mlB.deltaNbBits) >> 16
dstState = int32(ml.state>>(nbBitsOut&15)) + int32(mlB.deltaFindState)
wr.addBits16NC(ml.state, uint8(nbBitsOut))
ml.state = ml.stateTable[dstState]

outBits = mlB.outBits & 31
extraBits = extraBits<<outBits | uint64(s.matchLen&bitMask32[outBits])
extraBitsN += outBits

llB := llTT[s.llCode]
//ll.encode(llB)
nbBitsOut = (uint32(ll.state) + llB.deltaNbBits) >> 16
dstState = int32(ll.state>>(nbBitsOut&15)) + int32(llB.deltaFindState)
wr.addBits16NC(ll.state, uint8(nbBitsOut))
ll.state = ll.stateTable[dstState]

outBits = llB.outBits & 31
extraBits = extraBits<<outBits | uint64(s.litLen&bitMask32[outBits])
extraBitsN += outBits

wr.flush32()
wr.addBits64NC(extraBits, extraBitsN)

if debugSequences {
println("Encoded seq", seq, s)
}

seq--
}
ml.flush(mlEnc.actualTableLog)
of.flush(ofEnc.actualTableLog)
Expand Down Expand Up @@ -820,7 +821,8 @@ func (b *blockEnc) genCodes() {
}

var llMax, ofMax, mlMax uint8
for i, seq := range b.sequences {
for i := range b.sequences {
seq := &b.sequences[i]
v := llCode(seq.litLen)
seq.llCode = v
llH[v]++
Expand All @@ -844,7 +846,6 @@ func (b *blockEnc) genCodes() {
panic(fmt.Errorf("mlMax > maxMatchLengthSymbol (%d), matchlen: %d", mlMax, seq.matchLen))
}
}
b.sequences[i] = seq
}
maxCount := func(a []uint32) int {
var max uint32
Expand Down

0 comments on commit 901aaf2

Please sign in to comment.