Skip to content

Commit

Permalink
Optimize for special modular images (#82)
Browse files Browse the repository at this point in the history
  • Loading branch information
tirr-c authored Sep 7, 2023
1 parent d55ccc4 commit 48ca79c
Show file tree
Hide file tree
Showing 5 changed files with 334 additions and 68 deletions.
90 changes: 55 additions & 35 deletions crates/jxl-coding/src/ans.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,17 @@ use crate::{Error, Result};

#[derive(Debug)]
pub struct Histogram {
dist: Vec<u16>,
symbols: Vec<u16>,
offsets: Vec<u16>,
cutoffs: Vec<u16>,
buckets: Vec<Bucket>,
log_bucket_size: u32,
single_symbol: Option<u16>,
}

#[derive(Debug)]
struct Bucket {
dist: u16,
alias_symbol: u16,
alias_offset: u16,
alias_cutoff: u16,
}

impl Histogram {
Expand Down Expand Up @@ -151,60 +157,68 @@ impl Histogram {
}

if let Some(single_sym_idx) = dist.iter().position(|&d| d == 1 << 12) {
let symbols = vec![single_sym_idx as u16; table_size];
let offsets = (0..table_size as u16).map(|i| bucket_size * i).collect();
let cutoffs = vec![0u16; table_size];
let buckets = dist.into_iter()
.enumerate()
.map(|(i, dist)| Bucket {
dist,
alias_symbol: single_sym_idx as u16,
alias_offset: bucket_size * i as u16,
alias_cutoff: 0,
})
.collect();
return Ok(Self {
dist,
symbols,
offsets,
cutoffs,
buckets,
log_bucket_size,
single_symbol: Some(single_sym_idx as u16),
});
}

let mut cutoffs = dist.clone();
let mut symbols = (0..(alphabet_size as u16)).collect::<Vec<_>>();
symbols.resize(table_size, 0);
let mut offsets = vec![0u16; table_size];
let mut buckets: Vec<_> = dist
.into_iter()
.enumerate()
.map(|(i, dist)| Bucket {
dist,
alias_symbol: if i < alphabet_size { i as u16 } else { 0 },
alias_offset: 0,
alias_cutoff: dist,
})
.collect();

let mut underfull = Vec::new();
let mut overfull = Vec::new();
for (idx, d) in dist.iter().enumerate() {
match d.cmp(&bucket_size) {
for (idx, &Bucket { dist, .. }) in buckets.iter().enumerate() {
match dist.cmp(&bucket_size) {
std::cmp::Ordering::Less => underfull.push(idx),
std::cmp::Ordering::Equal => {},
std::cmp::Ordering::Greater => overfull.push(idx),
}
}
while let (Some(o), Some(u)) = (overfull.pop(), underfull.pop()) {
let by = bucket_size - cutoffs[u];
cutoffs[o] -= by;
symbols[u] = o as u16;
offsets[u] = cutoffs[o];
match cutoffs[o].cmp(&bucket_size) {
let by = bucket_size - buckets[u].alias_cutoff;
buckets[o].alias_cutoff -= by;
buckets[u].alias_symbol = o as u16;
buckets[u].alias_offset = buckets[o].alias_cutoff;
match buckets[o].alias_cutoff.cmp(&bucket_size) {
std::cmp::Ordering::Less => underfull.push(o),
std::cmp::Ordering::Equal => {},
std::cmp::Ordering::Greater => overfull.push(o),
}
}

for idx in 0..table_size {
if cutoffs[idx] == bucket_size {
symbols[idx] = idx as u16;
offsets[idx] = 0;
cutoffs[idx] = 0;
for (idx, bucket) in buckets.iter_mut().enumerate() {
if bucket.alias_cutoff == bucket_size {
bucket.alias_symbol = idx as u16;
bucket.alias_offset = 0;
bucket.alias_cutoff = 0;
} else {
offsets[idx] -= cutoffs[idx];
bucket.alias_offset -= bucket.alias_cutoff;
}
}

Ok(Self {
dist,
symbols,
offsets,
cutoffs,
buckets,
log_bucket_size,
single_symbol: None,
})
}

Expand All @@ -222,8 +236,9 @@ impl Histogram {
fn map_alias(&self, idx: u16) -> (u16, u16) {
let i = (idx >> self.log_bucket_size) as usize;
let pos = idx & ((1 << self.log_bucket_size) - 1);
if pos >= self.cutoffs[i] {
(self.symbols[i], self.offsets[i] + pos)
let bucket = &self.buckets[i];
if pos >= bucket.alias_cutoff {
(bucket.alias_symbol, bucket.alias_offset + pos)
} else {
(i as u16, pos)
}
Expand All @@ -232,12 +247,17 @@ impl Histogram {
pub fn read_symbol<R: Read>(&self, bitstream: &mut Bitstream<R>, state: &mut u32) -> Result<u16> {
let idx = (*state & 0xfff) as u16;
let (symbol, offset) = self.map_alias(idx);
*state = (*state >> 12) * (self.dist[symbol as usize] as u32) + offset as u32;
*state = (*state >> 12) * (self.buckets[symbol as usize].dist as u32) + offset as u32;
if *state < (1 << 16) {
*state = (*state << 16) | bitstream.read_bits(16)?;
}
Ok(symbol)
}

#[inline]
pub fn single_symbol(&self) -> Option<u16> {
self.single_symbol
}
}

fn read_prefix<R: Read>(bitstream: &mut Bitstream<R>) -> Result<u16> {
Expand Down
66 changes: 66 additions & 0 deletions crates/jxl-coding/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,24 @@ impl Decoder {
})
}

pub fn as_rle(&mut self) -> Option<DecoderRleMode<'_>> {
let &Lz77::Enabled { ref state, min_symbol, min_length } = &self.lz77 else { return None; };
let lz_cluster = self.inner.lz_dist_cluster();
let lz_conf = &self.inner.configs[lz_cluster as usize];
let Some(sym) = self.inner.code.single_symbol(lz_cluster) else { return None; };
(sym == 1 && lz_conf.split_exponent == 0).then_some(DecoderRleMode {
inner: &mut self.inner,
min_symbol,
min_length,
len_config: state.lz_len_conf.clone(),
})
}

#[inline]
pub fn single_token(&self, cluster: u8) -> Option<u32> {
self.inner.single_token(cluster)
}

/// Explicitly start reading an entropy encoded stream.
///
/// This involves reading an initial state for the ANS stream. It's okay to skip this method,
Expand All @@ -155,6 +173,39 @@ impl Decoder {
}
}

/// An entropy decoder, in RLE mode.
#[derive(Debug)]
pub struct DecoderRleMode<'dec> {
inner: &'dec mut DecoderInner,
min_symbol: u32,
min_length: u32,
len_config: IntegerConfig,
}

#[derive(Debug, Copy, Clone)]
pub enum RleToken {
Value(u32),
Repeat(u32),
}

impl DecoderRleMode<'_> {
#[inline]
pub fn read_varint_clustered<R: std::io::Read>(
&mut self,
bitstream: &mut Bitstream<R>,
cluster: u8,
) -> Result<RleToken> {
let token = self.inner.code.read_symbol(bitstream, cluster)? as u32;
Ok(if token >= self.min_symbol {
let length = self.inner.read_uint(bitstream, &self.len_config, token - self.min_symbol)? + self.min_length;
RleToken::Repeat(length)
} else {
let value = self.inner.read_uint(bitstream, &self.inner.configs[cluster as usize], token)?;
RleToken::Value(value)
})
}
}

#[derive(Debug, Clone)]
enum Lz77 {
Disabled,
Expand Down Expand Up @@ -308,6 +359,13 @@ impl DecoderInner {
})
}

#[inline]
fn single_token(&self, cluster: u8) -> Option<u32> {
let single_symbol = self.code.single_symbol(cluster)? as u32;
let IntegerConfig { split, .. } = self.configs[cluster as usize];
(single_symbol < split).then_some(single_symbol)
}

fn read_uint<R: std::io::Read>(&self, bitstream: &mut Bitstream<R>, config: &IntegerConfig, token: u32) -> Result<u32> {
let &IntegerConfig { split_exponent, split, msb_in_token, lsb_in_token, .. } = config;
if token < split {
Expand Down Expand Up @@ -356,6 +414,14 @@ impl Coder {
}
}

#[inline]
fn single_symbol(&self, cluster: u8) -> Option<u16> {
match self {
Self::PrefixCode(dist) => dist[cluster as usize].single_symbol(),
Self::Ans { dist, .. } => dist[cluster as usize].single_symbol(),
}
}

fn begin<R: Read>(&mut self, bitstream: &mut Bitstream<R>) -> Result<()> {
match self {
Self::PrefixCode(_) => Ok(()),
Expand Down
6 changes: 6 additions & 0 deletions crates/jxl-coding/src/prefix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,4 +287,10 @@ impl Histogram {
}
unreachable!()
}

#[inline]
pub fn single_symbol(&self) -> Option<u16> {
let &[symbol] = &*self.symbols else { return None; };
Some(symbol)
}
}
Loading

0 comments on commit 48ca79c

Please sign in to comment.