Skip to content

Commit

Permalink
tag=FseCode should take into account "variable bit-packing" (#1251)
Browse files Browse the repository at this point in the history
* fix: account for variable bit-packing in fse code section

* chore: range starts from 1 (ignore 0)
  • Loading branch information
roynalnaruto authored May 6, 2024
1 parent 60477bd commit 1011f5b
Show file tree
Hide file tree
Showing 3 changed files with 209 additions and 59 deletions.
172 changes: 113 additions & 59 deletions aggregator/src/aggregation/decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -443,18 +443,11 @@ pub struct BitstreamDecoder {
bit_index_end_cmp_23: ComparatorConfig<Fr, 1>,
/// The value of the binary bitstring.
bitstring_value: Column<Advice>,
/// Helper gadget to know when the bitstring value is 0. This contributes to an edge-case in
/// decoding and reconstructing the FSE table from normalised distributions, where a value=0
/// implies prob=-1 ("less than 1" probability). In this case, the symbol is allocated a state
/// at the end of the FSE table, with baseline=0x00 and nb=AL, i.e. reset state.
bitstring_value_eq_0: IsEqualConfig<Fr>,
/// Helper gadget to know when the bitstring value is 1 or 3. This is useful in the case
/// of decoding/reconstruction of FSE table, where a value=1 implies a special case of
/// prob=0, where the symbol is instead followed by a 2-bit repeat flag. The repeat flag
/// bits themselves could be followed by another 2-bit repeat flag if the repeat flag's
/// value is 3.
bitstring_value_eq_1: IsEqualConfig<Fr>,
/// Helper config as per the above doc.
/// When we have encountered a symbol with value=1, i.e. prob=0, it is followed by 2-bits
/// repeat bits flag that tells us the number of symbols following the current one that also
/// have a probability of prob=0. If the repeat bits flag itself is [1, 1], i.e.
/// bitstring_value==3, then it is followed by another 2-bits repeat bits flag and so on. We
/// utilise this equality config to identify these cases.
bitstring_value_eq_3: IsEqualConfig<Fr>,
/// Boolean that is set for a special case:
/// - The bitstring that we have read in the current row is byte-aligned up to the next or the
Expand Down Expand Up @@ -504,18 +497,6 @@ impl BitstreamDecoder {
u8_table.into(),
),
bitstring_value,
bitstring_value_eq_0: IsEqualChip::configure(
meta,
|meta| not::expr(meta.query_advice(is_padding, Rotation::cur())),
|meta| meta.query_advice(bitstring_value, Rotation::cur()),
|_| 0.expr(),
),
bitstring_value_eq_1: IsEqualChip::configure(
meta,
|meta| not::expr(meta.query_advice(is_padding, Rotation::cur())),
|meta| meta.query_advice(bitstring_value, Rotation::cur()),
|_| 1.expr(),
),
bitstring_value_eq_3: IsEqualChip::configure(
meta,
|meta| not::expr(meta.query_advice(is_padding, Rotation::cur())),
Expand Down Expand Up @@ -552,25 +533,6 @@ impl BitstreamDecoder {
meta.query_advice(self.is_nb0, rotation)
}

/// Boolean expression that is set when the bitstring value at `rotation` is 0.
///
/// While reconstructing the FSE table, a value of 0 denotes the "less than 1" probability
/// (prob=-1) edge-case.
///
/// * `meta` - virtual cells used to query the advice column at the requested rotation.
/// * `rotation` - the row offset at which the equality is evaluated.
fn is_prob_less_than1(
    &self,
    meta: &mut VirtualCells<Fr>,
    rotation: Rotation,
) -> Expression<Fr> {
    let bitstring_value = meta.query_advice(self.bitstring_value, rotation);
    // `bitstring_value_eq_0` is configured to compare the bitstring value against 0, so the
    // expression is rebuilt at `rotation` with that same right-hand side.
    self.bitstring_value_eq_0
        .expr_at(meta, rotation, bitstring_value, 0.expr())
}

/// Expression flagging a bitstring value of 1 at the given rotation. During FSE table
/// reconstruction this denotes prob=0, in which case the symbol is followed by a 2-bits repeat
/// flag instead.
fn is_prob0(&self, meta: &mut VirtualCells<Fr>, rotation: Rotation) -> Expression<Fr> {
    // Query the compared cell first, then rebuild the equality against the constant 1.
    let (lhs, rhs) = (meta.query_advice(self.bitstring_value, rotation), 1.expr());
    self.bitstring_value_eq_1.expr_at(meta, rotation, lhs, rhs)
}

/// Whether the 2-bits repeat flag was [1, 1]. In this case, the repeat flag is followed by
/// another repeat flag.
fn is_rb_flag3(&self, meta: &mut VirtualCells<Fr>, rotation: Rotation) -> Expression<Fr> {
Expand Down Expand Up @@ -702,24 +664,49 @@ pub struct FseDecoder {
table_size: Column<Advice>,
/// The incremental symbol for which probability is decoded.
symbol: Column<Advice>,
/// The value decoded as per variable bit-packing.
value_decoded: Column<Advice>,
/// An accumulator of the number of states allocated to each symbol as we decode the FSE table.
/// This is the normalised probability for the symbol.
probability_acc: Column<Advice>,
/// Whether we are in the repeat bits loop.
is_repeat_bits_loop: Column<Advice>,
/// Whether this row represents the 0-7 trailing bits that should be ignored.
is_trailing_bits: Column<Advice>,
/// Helper gadget to know when the decoded value is 0. This contributes to an edge-case in
/// decoding and reconstructing the FSE table from normalised distributions, where a value=0
/// implies prob=-1 ("less than 1" probability). In this case, the symbol is allocated a state
/// at the end of the FSE table, with baseline=0x00 and nb=AL, i.e. reset state.
value_decoded_eq_0: IsEqualConfig<Fr>,
/// Helper gadget to know when the decoded value is 1. This is useful in the edge-case in
/// decoding and reconstructing the FSE table, where a value=1 implies a special case of
/// prob=0, where the symbol is instead followed by a 2-bit repeat flag.
value_decoded_eq_1: IsEqualConfig<Fr>,
}

impl FseDecoder {
fn configure(meta: &mut ConstraintSystem<Fr>) -> Self {
fn configure(meta: &mut ConstraintSystem<Fr>, is_padding: Column<Advice>) -> Self {
let value_decoded = meta.advice_column();
Self {
table_kind: meta.advice_column(),
table_size: meta.advice_column(),
symbol: meta.advice_column(),
value_decoded,
probability_acc: meta.advice_column(),
is_repeat_bits_loop: meta.advice_column(),
is_trailing_bits: meta.advice_column(),
value_decoded_eq_0: IsEqualChip::configure(
meta,
|meta| not::expr(meta.query_advice(is_padding, Rotation::cur())),
|meta| meta.query_advice(value_decoded, Rotation::cur()),
|_| 0.expr(),
),
value_decoded_eq_1: IsEqualChip::configure(
meta,
|meta| not::expr(meta.query_advice(is_padding, Rotation::cur())),
|meta| meta.query_advice(value_decoded, Rotation::cur()),
|_| 1.expr(),
),
}
}
}
Expand All @@ -746,6 +733,25 @@ impl FseDecoder {
* (table_kind.expr() - FseTableKind::MLT.expr())
* invert_of_2
}

/// Boolean expression that is set when the decoded value at `rotation` is 0.
///
/// While reconstructing the FSE table, a decoded value of 0 denotes the "less than 1"
/// probability (prob=-1) edge-case.
///
/// * `meta` - virtual cells used to query the advice column at the requested rotation.
/// * `rotation` - the row offset at which the equality is evaluated.
fn is_prob_less_than1(
    &self,
    meta: &mut VirtualCells<Fr>,
    rotation: Rotation,
) -> Expression<Fr> {
    let value_decoded = meta.query_advice(self.value_decoded, rotation);
    // `value_decoded_eq_0` is configured to compare the decoded value against 0, so the
    // expression is rebuilt at `rotation` with that same right-hand side.
    self.value_decoded_eq_0
        .expr_at(meta, rotation, value_decoded, 0.expr())
}

/// Expression flagging a decoded value of 1 at the given rotation. During FSE table
/// reconstruction this denotes prob=0, in which case the symbol is followed by a 2-bits repeat
/// flag instead.
fn is_prob0(&self, meta: &mut VirtualCells<Fr>, rotation: Rotation) -> Expression<Fr> {
    // Query the compared cell first, then rebuild the equality against the constant 1.
    let (lhs, rhs) = (meta.query_advice(self.value_decoded, rotation), 1.expr());
    self.value_decoded_eq_1.expr_at(meta, rotation, lhs, rhs)
}
}

#[derive(Clone, Debug)]
Expand Down Expand Up @@ -964,7 +970,7 @@ impl DecoderConfig {
let sequences_header_decoder =
SequencesHeaderDecoder::configure(meta, byte, is_padding, u8_table);
let bitstream_decoder = BitstreamDecoder::configure(meta, is_padding, u8_table);
let fse_decoder = FseDecoder::configure(meta);
let fse_decoder = FseDecoder::configure(meta, is_padding);
let sequences_data_decoder = SequencesDataDecoder::configure(meta);

// TODO(enable):
Expand Down Expand Up @@ -2143,7 +2149,7 @@ impl DecoderConfig {
cb.condition(
and::expr([
not::expr(is_repeat_bits_loop.expr()),
config.bitstream_decoder.is_prob0(meta, Rotation::cur()),
config.fse_decoder.is_prob0(meta, Rotation::cur()),
]),
|cb| {
cb.require_equal(
Expand Down Expand Up @@ -2193,12 +2199,20 @@ impl DecoderConfig {
// updating and the FSE symbol itself.
//
// If no bitstring was read, even the symbol value is carried forward.
let (prob_acc_cur, prob_acc_prev, fse_symbol_cur, fse_symbol_prev, value) = (
let (
prob_acc_cur,
prob_acc_prev,
fse_symbol_cur,
fse_symbol_prev,
bitstring_value,
value_decoded,
) = (
meta.query_advice(config.fse_decoder.probability_acc, Rotation::cur()),
meta.query_advice(config.fse_decoder.probability_acc, Rotation::prev()),
meta.query_advice(config.fse_decoder.symbol, Rotation::cur()),
meta.query_advice(config.fse_decoder.symbol, Rotation::prev()),
meta.query_advice(config.bitstream_decoder.bitstring_value, Rotation::cur()),
meta.query_advice(config.fse_decoder.value_decoded, Rotation::cur()),
);
cb.condition(
config.bitstream_decoder.is_nil(meta, Rotation::cur()),
Expand Down Expand Up @@ -2231,11 +2245,9 @@ impl DecoderConfig {
"fse: probability_acc is updated correctly",
prob_acc_cur.expr(),
select::expr(
config
.bitstream_decoder
.is_prob_less_than1(meta, Rotation::cur()),
config.fse_decoder.is_prob_less_than1(meta, Rotation::cur()),
prob_acc_prev.expr() + 1.expr(),
prob_acc_prev.expr() + value.expr() - 1.expr(),
prob_acc_prev.expr() + value_decoded.expr() - 1.expr(),
),
);
cb.require_equal(
Expand Down Expand Up @@ -2268,7 +2280,7 @@ impl DecoderConfig {
cb.require_equal(
"fse: repeat-bits increases by the 2-bit value",
fse_symbol_cur,
fse_symbol_prev + value,
fse_symbol_prev + bitstring_value,
);
});

Expand Down Expand Up @@ -2379,6 +2391,49 @@ impl DecoderConfig {
},
);

// Lookup into the fixed table under the `VariableBitPacking` tag. It ties together the range,
// the value read from the bitstream, the decoded value and the number of bits read on rows
// that carry an FSE-code bitstring.
meta.lookup_any(
"DecoderConfig: tag ZstdBlockSequenceFseCode (variable bit-packing)",
|meta| {
// At every row where a non-nil bitstring is read:
// - except the AL bits (is_change=true)
// - except when we are in repeat-bits loop
// - except the trailing bits (if they exist)
//
// NOTE(review): the list above mentions the trailing bits, but the condition below has no
// `is_trailing_bits` term -- confirm whether that exclusion is intentionally omitted.
let condition = and::expr([
meta.query_advice(config.tag_config.is_fse_code, Rotation::cur()),
config.bitstream_decoder.is_not_nil(meta, Rotation::cur()),
not::expr(meta.query_advice(config.tag_config.is_change, Rotation::cur())),
not::expr(
meta.query_advice(config.fse_decoder.is_repeat_bits_loop, Rotation::cur()),
),
]);

// Note that `probability_acc` is taken from the previous row, while every other cell is
// taken from the current row.
let (table_size, probability_acc, value_read, value_decoded, num_bits) = (
meta.query_advice(config.fse_decoder.table_size, Rotation::cur()),
meta.query_advice(config.fse_decoder.probability_acc, Rotation::prev()),
meta.query_advice(config.bitstream_decoder.bitstring_value, Rotation::cur()),
meta.query_advice(config.fse_decoder.value_decoded, Rotation::cur()),
config
.bitstream_decoder
.bitstring_len_unchecked(meta, Rotation::cur()),
);

// The range looked up is the table size minus the probability accumulated up to the previous
// row, plus one.
let range = table_size - probability_acc + 1.expr();
// The two trailing zeroes pad the input so that `zip_eq` pairs it one-to-one with the fixed
// table's expressions. Every input is gated by `condition`.
[
FixedLookupTag::VariableBitPacking.expr(),
range,
value_read,
value_decoded,
num_bits,
0.expr(),
0.expr(),
]
.into_iter()
.zip_eq(config.fixed_table.table_exprs(meta))
.map(|(arg, table)| (condition.expr() * arg, table))
.collect()
},
);

meta.lookup_any(
"DecoderConfig: tag ZstdBlockSequenceFseCode (normalised probability of symbol)",
|meta| {
Expand All @@ -2391,7 +2446,7 @@ impl DecoderConfig {
meta.query_advice(config.tag_config.is_fse_code, Rotation::cur()),
config.bitstream_decoder.is_not_nil(meta, Rotation::cur()),
not::expr(meta.query_advice(config.tag_config.is_change, Rotation::cur())),
not::expr(config.bitstream_decoder.is_prob0(meta, Rotation::cur())),
not::expr(config.fse_decoder.is_prob0(meta, Rotation::cur())),
not::expr(
meta.query_advice(config.fse_decoder.is_repeat_bits_loop, Rotation::cur()),
),
Expand All @@ -2400,20 +2455,19 @@ impl DecoderConfig {
),
]);

let (block_idx, fse_table_kind, fse_table_size, fse_symbol, bitstring_value) = (
let (block_idx, fse_table_kind, fse_table_size, fse_symbol, value_decoded) = (
meta.query_advice(config.block_config.block_idx, Rotation::cur()),
meta.query_advice(config.fse_decoder.table_kind, Rotation::cur()),
meta.query_advice(config.fse_decoder.table_size, Rotation::cur()),
meta.query_advice(config.fse_decoder.symbol, Rotation::cur()),
meta.query_advice(config.bitstream_decoder.bitstring_value, Rotation::cur()),
meta.query_advice(config.fse_decoder.value_decoded, Rotation::cur()),
);
let is_prob_less_than1 = config
.bitstream_decoder
.is_prob_less_than1(meta, Rotation::cur());
let is_prob_less_than1 =
config.fse_decoder.is_prob_less_than1(meta, Rotation::cur());
let norm_prob = select::expr(
is_prob_less_than1.expr(),
1.expr(),
bitstring_value - 1.expr(),
value_decoded - 1.expr(),
);

[
Expand Down
9 changes: 9 additions & 0 deletions aggregator/src/aggregation/decoder/tables/fixed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ use seq_tag_order::RomSeqTagOrder;
mod tag_transition;
use tag_transition::RomTagTransition;

mod variable_bit_packing;
use variable_bit_packing::RomVariableBitPacking;

pub trait FixedLookupValues {
fn values() -> Vec<[Value<Fr>; 7]>;
}
Expand All @@ -52,6 +55,11 @@ pub enum FixedLookupTag {
/// Represents the FSE table reconstructed from the default distributions, i.e. Predefined FSE
/// table.
PredefinedFse,
/// Represents read and decoded values for the variable bit-packing as specified in the [zstd
/// compression format][doclink].
///
/// [doclink]: https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#fse-table-description
VariableBitPacking,
}

impl_expr!(FixedLookupTag);
Expand All @@ -65,6 +73,7 @@ impl FixedLookupTag {
Self::SeqCodeToValue => RomSeqCodeToValue::values(),
Self::FseTableTransition => RomFseTableTransition::values(),
Self::PredefinedFse => RomPredefinedFse::values(),
Self::VariableBitPacking => RomVariableBitPacking::values(),
}
}
}
Expand Down
Loading

0 comments on commit 1011f5b

Please sign in to comment.