diff --git a/crates/jxl-coding/src/ans.rs b/crates/jxl-coding/src/ans.rs
index c3e2aa40..7b4d9fcb 100644
--- a/crates/jxl-coding/src/ans.rs
+++ b/crates/jxl-coding/src/ans.rs
@@ -6,11 +6,17 @@ use crate::{Error, Result};
 
 #[derive(Debug)]
 pub struct Histogram {
-    dist: Vec<u16>,
-    symbols: Vec<u16>,
-    offsets: Vec<u16>,
-    cutoffs: Vec<u16>,
+    buckets: Vec<Bucket>,
     log_bucket_size: u32,
+    single_symbol: Option<u16>,
+}
+
+#[derive(Debug)]
+struct Bucket {
+    dist: u16,
+    alias_symbol: u16,
+    alias_offset: u16,
+    alias_cutoff: u16,
 }
 
 impl Histogram {
@@ -151,60 +157,68 @@ impl Histogram {
         }
 
         if let Some(single_sym_idx) = dist.iter().position(|&d| d == 1 << 12) {
-            let symbols = vec![single_sym_idx as u16; table_size];
-            let offsets = (0..table_size as u16).map(|i| bucket_size * i).collect();
-            let cutoffs = vec![0u16; table_size];
+            let buckets = dist.into_iter()
+                .enumerate()
+                .map(|(i, dist)| Bucket {
+                    dist,
+                    alias_symbol: single_sym_idx as u16,
+                    alias_offset: bucket_size * i as u16,
+                    alias_cutoff: 0,
+                })
+                .collect();
             return Ok(Self {
-                dist,
-                symbols,
-                offsets,
-                cutoffs,
+                buckets,
                 log_bucket_size,
+                single_symbol: Some(single_sym_idx as u16),
             });
         }
 
-        let mut cutoffs = dist.clone();
-        let mut symbols = (0..(alphabet_size as u16)).collect::<Vec<_>>();
-        symbols.resize(table_size, 0);
-        let mut offsets = vec![0u16; table_size];
+        let mut buckets: Vec<_> = dist
+            .into_iter()
+            .enumerate()
+            .map(|(i, dist)| Bucket {
+                dist,
+                alias_symbol: if i < alphabet_size { i as u16 } else { 0 },
+                alias_offset: 0,
+                alias_cutoff: dist,
+            })
+            .collect();
 
         let mut underfull = Vec::new();
         let mut overfull = Vec::new();
-        for (idx, d) in dist.iter().enumerate() {
-            match d.cmp(&bucket_size) {
+        for (idx, &Bucket { dist, .. }) in buckets.iter().enumerate() {
+            match dist.cmp(&bucket_size) {
                 std::cmp::Ordering::Less => underfull.push(idx),
                 std::cmp::Ordering::Equal => {},
                 std::cmp::Ordering::Greater => overfull.push(idx),
             }
         }
 
         while let (Some(o), Some(u)) = (overfull.pop(), underfull.pop()) {
-            let by = bucket_size - cutoffs[u];
-            cutoffs[o] -= by;
-            symbols[u] = o as u16;
-            offsets[u] = cutoffs[o];
-            match cutoffs[o].cmp(&bucket_size) {
+            let by = bucket_size - buckets[u].alias_cutoff;
+            buckets[o].alias_cutoff -= by;
+            buckets[u].alias_symbol = o as u16;
+            buckets[u].alias_offset = buckets[o].alias_cutoff;
+            match buckets[o].alias_cutoff.cmp(&bucket_size) {
                 std::cmp::Ordering::Less => underfull.push(o),
                 std::cmp::Ordering::Equal => {},
                 std::cmp::Ordering::Greater => overfull.push(o),
            }
         }
 
-        for idx in 0..table_size {
-            if cutoffs[idx] == bucket_size {
-                symbols[idx] = idx as u16;
-                offsets[idx] = 0;
-                cutoffs[idx] = 0;
+        for (idx, bucket) in buckets.iter_mut().enumerate() {
+            if bucket.alias_cutoff == bucket_size {
+                bucket.alias_symbol = idx as u16;
+                bucket.alias_offset = 0;
+                bucket.alias_cutoff = 0;
             } else {
-                offsets[idx] -= cutoffs[idx];
+                bucket.alias_offset -= bucket.alias_cutoff;
             }
         }
 
         Ok(Self {
-            dist,
-            symbols,
-            offsets,
-            cutoffs,
+            buckets,
             log_bucket_size,
+            single_symbol: None,
        })
    }
 
@@ -222,8 +236,9 @@ impl Histogram {
     fn map_alias(&self, idx: u16) -> (u16, u16) {
         let i = (idx >> self.log_bucket_size) as usize;
         let pos = idx & ((1 << self.log_bucket_size) - 1);
-        if pos >= self.cutoffs[i] {
-            (self.symbols[i], self.offsets[i] + pos)
+        let bucket = &self.buckets[i];
+        if pos >= bucket.alias_cutoff {
+            (bucket.alias_symbol, bucket.alias_offset + pos)
         } else {
             (i as u16, pos)
         }
@@ -232,12 +247,17 @@ impl Histogram {
     pub fn read_symbol(&self, bitstream: &mut Bitstream, state: &mut u32) -> Result<u16> {
         let idx = (*state & 0xfff) as u16;
         let (symbol, offset) = self.map_alias(idx);
-        *state = (*state >> 12) * (self.dist[symbol as usize] as u32) + offset as u32;
+        *state = (*state >> 12) * (self.buckets[symbol as usize].dist as u32) + offset as u32;
         if *state < (1 << 16) {
             *state = (*state << 16) | bitstream.read_bits(16)?;
         }
         Ok(symbol)
     }
+
+    #[inline]
+    pub fn single_symbol(&self) -> Option<u16> {
+        self.single_symbol
+    }
 }
 
 fn read_prefix(bitstream: &mut Bitstream) -> Result<u32> {
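The alias table is what keeps `read_symbol` branch-light: the low 12 bits of the ANS state select a bucket, and a single cutoff comparison decides whether the slot belongs to the bucket's own symbol or to its alias. Below is a self-contained sketch of that lookup with a hand-built two-symbol table; `Bucket` and `map_alias` mirror the shapes in the diff, but the table values are illustrative and not produced by the crate's construction code.

```rust
/// Sketch of the alias-table lookup (illustrative values, not the crate's API).
/// A 12-bit ANS state splits into a bucket index and a position; positions at or
/// above `alias_cutoff` are redirected to the alias symbol.
struct Bucket {
    dist: u16,         // quantized frequency of the bucket's own symbol
    alias_symbol: u16, // symbol owning the redirected upper part of the bucket
    alias_offset: u16, // where the redirected part lands in the alias symbol's range
    alias_cutoff: u16, // first position that is redirected
}

fn map_alias(buckets: &[Bucket], log_bucket_size: u32, idx: u16) -> (u16, u16) {
    let i = (idx >> log_bucket_size) as usize;
    let pos = idx & ((1 << log_bucket_size) - 1);
    let bucket = &buckets[i];
    if pos >= bucket.alias_cutoff {
        (bucket.alias_symbol, bucket.alias_offset + pos)
    } else {
        (i as u16, pos)
    }
}

fn main() {
    // Two symbols with frequencies 3072 and 1024 (summing to 1 << 12), so the
    // table has two buckets of 2048 slots: bucket 1 keeps its first 1024 slots
    // and redirects the rest to symbol 0.
    let buckets = [
        Bucket { dist: 3072, alias_symbol: 0, alias_offset: 0, alias_cutoff: 0 },
        Bucket { dist: 1024, alias_symbol: 0, alias_offset: 1024, alias_cutoff: 1024 },
    ];
    assert_eq!(map_alias(&buckets, 11, 0x0123), (0, 0x0123)); // bucket 0 is fully owned by symbol 0
    assert_eq!(map_alias(&buckets, 11, 0x0805), (1, 5));      // bucket 1, below the cutoff
    let (sym, offset) = map_alias(&buckets, 11, 0x0d00);      // bucket 1, redirected to the alias
    assert_eq!((sym, offset), (0, 0x0900));
    assert!(offset < buckets[sym as usize].dist);             // offsets stay inside the symbol's dist
}
```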
diff --git a/crates/jxl-coding/src/lib.rs b/crates/jxl-coding/src/lib.rs
index c05100c7..1930d6aa 100644
--- a/crates/jxl-coding/src/lib.rs
+++ b/crates/jxl-coding/src/lib.rs
@@ -133,6 +133,24 @@ impl Decoder {
         })
     }
 
+    pub fn as_rle(&mut self) -> Option<DecoderRleMode<'_>> {
+        let &Lz77::Enabled { ref state, min_symbol, min_length } = &self.lz77 else { return None; };
+        let lz_cluster = self.inner.lz_dist_cluster();
+        let lz_conf = &self.inner.configs[lz_cluster as usize];
+        let Some(sym) = self.inner.code.single_symbol(lz_cluster) else { return None; };
+        (sym == 1 && lz_conf.split_exponent == 0).then_some(DecoderRleMode {
+            inner: &mut self.inner,
+            min_symbol,
+            min_length,
+            len_config: state.lz_len_conf.clone(),
+        })
+    }
+
+    #[inline]
+    pub fn single_token(&self, cluster: u8) -> Option<u32> {
+        self.inner.single_token(cluster)
+    }
+
     /// Explicitly start reading an entropy encoded stream.
     ///
     /// This involves reading an initial state for the ANS stream. It's okay to skip this method,
@@ -155,6 +173,39 @@ impl Decoder {
     }
 }
 
+/// An entropy decoder, in RLE mode.
+#[derive(Debug)]
+pub struct DecoderRleMode<'dec> {
+    inner: &'dec mut DecoderInner,
+    min_symbol: u32,
+    min_length: u32,
+    len_config: IntegerConfig,
+}
+
+#[derive(Debug, Copy, Clone)]
+pub enum RleToken {
+    Value(u32),
+    Repeat(u32),
+}
+
+impl DecoderRleMode<'_> {
+    #[inline]
+    pub fn read_varint_clustered(
+        &mut self,
+        bitstream: &mut Bitstream,
+        cluster: u8,
+    ) -> Result<RleToken> {
+        let token = self.inner.code.read_symbol(bitstream, cluster)? as u32;
+        Ok(if token >= self.min_symbol {
+            let length = self.inner.read_uint(bitstream, &self.len_config, token - self.min_symbol)? + self.min_length;
+            RleToken::Repeat(length)
+        } else {
+            let value = self.inner.read_uint(bitstream, &self.inner.configs[cluster as usize], token)?;
+            RleToken::Value(value)
+        })
+    }
+}
+
 #[derive(Debug, Clone)]
 enum Lz77 {
     Disabled,
@@ -308,6 +359,13 @@ impl DecoderInner {
         })
     }
 
+    #[inline]
+    fn single_token(&self, cluster: u8) -> Option<u32> {
+        let single_symbol = self.code.single_symbol(cluster)? as u32;
+        let IntegerConfig { split, .. } = self.configs[cluster as usize];
+        (single_symbol < split).then_some(single_symbol)
+    }
+
     fn read_uint(&self, bitstream: &mut Bitstream, config: &IntegerConfig, token: u32) -> Result<u32> {
         let &IntegerConfig { split_exponent, split, msb_in_token, lsb_in_token, .. } = config;
         if token < split {
@@ -356,6 +414,14 @@ impl Coder {
         }
     }
 
+    #[inline]
+    fn single_symbol(&self, cluster: u8) -> Option<u16> {
+        match self {
+            Self::PrefixCode(dist) => dist[cluster as usize].single_symbol(),
+            Self::Ans { dist, .. } => dist[cluster as usize].single_symbol(),
+        }
+    }
+
     fn begin(&mut self, bitstream: &mut Bitstream) -> Result<()> {
         match self {
             Self::PrefixCode(_) => Ok(()),
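`DecoderRleMode::read_varint_clustered` turns the stream into `Value`/`Repeat` tokens, and the consumer is expected to re-emit the previous sample for each repeat. The sketch below shows that expansion in isolation, mirroring the `next` closure that `decode_image_rle` builds in jxl-modular; the `RleToken` enum here is a standalone copy for illustration, and the sign unpacking applied to `Value` tokens in the real decoder is left out.

```rust
#[derive(Debug, Copy, Clone)]
enum RleToken {
    Value(u32),
    Repeat(u32),
}

/// Expands tokens the way the RLE `next` closure does: a `Value` starts a new
/// run with one sample, a `Repeat(n)` re-emits the previous value `n` more times.
fn expand(tokens: &[RleToken]) -> Vec<u32> {
    let mut out = Vec::new();
    let mut last = 0u32;
    for &token in tokens {
        match token {
            RleToken::Value(v) => {
                last = v;
                out.push(v);
            }
            RleToken::Repeat(n) => out.extend(std::iter::repeat(last).take(n as usize)),
        }
    }
    out
}

fn main() {
    let tokens = [RleToken::Value(7), RleToken::Repeat(3), RleToken::Value(0), RleToken::Repeat(2)];
    assert_eq!(expand(&tokens), [7, 7, 7, 7, 0, 0, 0]);
}
```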
diff --git a/crates/jxl-coding/src/prefix.rs b/crates/jxl-coding/src/prefix.rs
index e9b737f4..fb29ce0c 100644
--- a/crates/jxl-coding/src/prefix.rs
+++ b/crates/jxl-coding/src/prefix.rs
@@ -287,4 +287,10 @@ impl Histogram {
         }
         unreachable!()
     }
+
+    #[inline]
+    pub fn single_symbol(&self) -> Option<u16> {
+        let &[symbol] = &*self.symbols else { return None; };
+        Some(symbol)
+    }
 }
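Both histogram types now report a single-symbol alphabet, and `Decoder::single_token` only treats that as a constant when the symbol sits below the hybrid-integer `split`, i.e. when the token carries its value directly with no extra bits. A minimal sketch of that condition, assuming `split = 1 << split_exponent` and using a hypothetical, stripped-down `IntegerConfig`:

```rust
// Hypothetical, stripped-down config: only the split exponent matters here.
struct IntegerConfig {
    split_exponent: u32,
}

/// A cluster collapses to a constant only when its code has exactly one symbol
/// *and* that symbol is a direct value (below the split), so no extra bits follow.
fn single_token(single_symbol: Option<u16>, config: &IntegerConfig) -> Option<u32> {
    let symbol = single_symbol? as u32;
    let split = 1u32 << config.split_exponent; // assumed relation between split and split_exponent
    (symbol < split).then_some(symbol)
}

fn main() {
    let config = IntegerConfig { split_exponent: 4 }; // tokens 0..16 carry no extra bits
    assert_eq!(single_token(Some(3), &config), Some(3)); // constant channel
    assert_eq!(single_token(Some(20), &config), None);   // token would need extra bits
    assert_eq!(single_token(None, &config), None);       // more than one symbol
}
```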
diff --git a/crates/jxl-modular/src/image.rs b/crates/jxl-modular/src/image.rs
index f87c5f3f..3c1bb7d4 100644
--- a/crates/jxl-modular/src/image.rs
+++ b/crates/jxl-modular/src/image.rs
@@ -1,13 +1,14 @@
 use std::io::Read;
 
-use jxl_bitstream::Bitstream;
+use jxl_bitstream::{Bitstream, unpack_signed};
+use jxl_coding::{DecoderRleMode, RleToken};
 use jxl_grid::Grid;
 
 use crate::{
     ModularChannels,
-    predictor::{WpHeader, PredictorState},
+    predictor::{WpHeader, PredictorState, Predictor},
     Result,
-    SubimageChannelInfo, MaConfig,
+    SubimageChannelInfo, MaConfig, ma::{MaTreeLeafClustered, FlatMaTree}, ModularChannelInfo,
 };
 
 /// Decoded Modular image.
@@ -80,6 +81,11 @@ impl Image {
         let mut decoder = ma_ctx.decoder().clone();
         decoder.begin(bitstream)?;
 
+        if let Some(rle_decoder) = decoder.as_rle() {
+            self.decode_image_rle(bitstream, stream_index, wp_header, ma_ctx, rle_decoder)?;
+            decoder.finalize()?;
+            return Ok(());
+        }
 
         let dist_multiplier = self.channels.info.iter()
             .skip(self.channels.nb_meta_channels as usize)
@@ -87,54 +93,202 @@
             .max()
             .unwrap_or(0);
 
-        let mut channels: Vec<_> = self.channels.info.iter()
-            .zip(self.data.iter_mut())
-            .enumerate()
-            .filter(|(_, (_, grid))| grid.width() != 0 && grid.height() != 0)
-            .collect();
-        let len = channels.len();
-        for idx in 0..len {
-            let (prev, left) = channels.split_at_mut(idx);
-            let (i, (info, ref mut grid)) = left[0];
-            let prev = prev
+        let mut prev: Vec<(&ModularChannelInfo, &mut Grid)> = Vec::new();
+        for (i, (info, grid)) in self.channels.info.iter().zip(&mut self.data).enumerate() {
+            let filtered_prev = prev
                 .iter()
-                .filter(|(_, (prev_info, _))| {
+                .filter(|&(prev_info, _)| {
                     info.width == prev_info.width &&
                         info.height == prev_info.height &&
                         info.hshift == prev_info.hshift &&
                         info.vshift == prev_info.vshift
                 })
+                .map(|(_, g)| &**g)
                 .collect::<Vec<_>>();
 
+            let width = grid.width();
+            let height = grid.height();
+            let ma_tree = ma_ctx.make_flat_tree(i as u32, stream_index);
+            if let Some(&MaTreeLeafClustered { cluster, predictor, offset, multiplier }) = ma_tree.single_node() {
+                tracing::trace!(cluster, ?predictor, "Single MA tree node");
+                if predictor == Predictor::Zero {
+                    if let Some(token) = decoder.single_token(cluster) {
+                        tracing::trace!("Single token in cluster: hyper fast path");
+                        let value = unpack_signed(token) * multiplier as i32 + offset;
+                        for y in 0..height {
+                            for x in 0..width {
+                                grid.set(x, y, value);
+                            }
+                        }
+                    } else {
+                        tracing::trace!("Fast path");
+                        for y in 0..height {
+                            for x in 0..width {
+                                let token = decoder.read_varint_with_multiplier_clustered(bitstream, cluster, dist_multiplier)?;
+                                let value = unpack_signed(token) * multiplier as i32 + offset;
+                                grid.set(x, y, value);
+                            }
+                        }
+                    }
+                    continue;
+                }
+                if predictor == Predictor::Gradient && offset == 0 && multiplier == 1 {
+                    tracing::trace!("Quite fast path");
+                    let mut prev_row = vec![0i32; width];
+                    for y in 0..height {
+                        let mut w = prev_row[0] as i64;
+                        let mut nw = w;
+                        for (x, prev) in prev_row.iter_mut().enumerate() {
+                            let n = if y == 0 { w } else { *prev as i64 };
+                            let pred = (n + w - nw).clamp(w.min(n), w.max(n));
+
+                            let token = decoder.read_varint_with_multiplier_clustered(bitstream, cluster, dist_multiplier)?;
+                            let value = ((unpack_signed(token) as i64) + pred) as i32;
+                            grid.set(x, y, value);
+                            *prev = value;
+                            nw = n;
+                            w = value as i64;
+                        }
+                    }
+                    continue;
+                }
+            }
+
+            let wp_header = ma_tree.need_self_correcting().then_some(wp_header);
+            let mut predictor = PredictorState::new(width as u32, i as u32, stream_index, filtered_prev.len(), wp_header);
+            let mut next = |cluster: u8| -> Result<i32> {
+                let token = decoder.read_varint_with_multiplier_clustered(bitstream, cluster, dist_multiplier)?;
+                Ok(unpack_signed(token))
+            };
+            decode_channel_slow(&mut next, &ma_tree, &mut predictor, grid, &filtered_prev)?;
+            prev.push((info, grid));
+        }
+
+        decoder.finalize()?;
+        Ok(())
+    }
+
+    fn decode_image_rle(
+        &mut self,
+        bitstream: &mut Bitstream,
+        stream_index: u32,
+        wp_header: &WpHeader,
+        ma_ctx: &MaConfig,
+        mut decoder: DecoderRleMode<'_>,
+    ) -> Result<()> {
+        let mut rle_value = 0i32;
+        let mut rle_left = 0u32;
+
+        let mut next = |cluster: u8| -> Result<i32> {
+            Ok(if rle_left > 0 {
+                rle_left -= 1;
+                rle_value
+            } else {
+                match decoder.read_varint_clustered(bitstream, cluster)? {
+                    RleToken::Value(v) => {
+                        rle_value = unpack_signed(v);
+                        rle_value
+                    },
+                    RleToken::Repeat(len) => {
+                        rle_left = len - 1;
+                        rle_value
+                    },
+                }
+            })
+        };
+
+        let mut prev: Vec<(&ModularChannelInfo, &mut Grid)> = Vec::new();
+        for (i, (info, grid)) in self.channels.info.iter().zip(&mut self.data).enumerate() {
+            let filtered_prev = prev
+                .iter()
+                .filter(|&(prev_info, _)| {
+                    info.width == prev_info.width &&
+                        info.height == prev_info.height &&
+                        info.hshift == prev_info.hshift &&
+                        info.vshift == prev_info.vshift
+                })
+                .map(|(_, g)| &**g)
+                .collect::<Vec<_>>();
 
             let width = grid.width();
             let height = grid.height();
-            let mut predictor = PredictorState::new(width as u32, i as u32, stream_index, prev.len(), wp_header);
-            let mut prev_channel_samples = vec![0i32; prev.len()];
 
-            for y in 0..height {
-                for x in 0..width {
-                    for ((_, (_, grid)), sample) in prev.iter().zip(&mut prev_channel_samples) {
-                        *sample = *grid.get(x, y).unwrap();
+            let ma_tree = ma_ctx.make_flat_tree(i as u32, stream_index);
+            if let Some(&MaTreeLeafClustered { cluster, predictor, offset, multiplier }) = ma_tree.single_node() {
+                tracing::trace!(cluster, ?predictor, "Single MA tree node");
+                if predictor == Predictor::Zero {
+                    tracing::trace!("Quite fast path");
+                    for y in 0..height {
+                        for x in 0..width {
+                            let token = next(cluster)?;
+                            let value = (token * multiplier as i32) + offset;
+                            grid.set(x, y, value);
+                        }
                     }
+                    continue;
+                }
+                if predictor == Predictor::Gradient && offset == 0 && multiplier == 1 {
+                    tracing::trace!("libjxl fast-lossless: quite fast path");
+                    let mut prev_row = vec![0i32; width];
+                    for y in 0..height {
+                        let mut w = prev_row[0] as i64;
+                        let mut nw = w;
+                        for (x, prev) in prev_row.iter_mut().enumerate() {
+                            let n = if y == 0 { w } else { *prev as i64 };
+                            let pred = (n + w - nw).clamp(w.min(n), w.max(n));
 
-                    let properties = predictor.properties(&prev_channel_samples);
-                    let (diff, predictor) = ma_tree.decode_sample(bitstream, &mut decoder, &properties, dist_multiplier)?;
-                    let sample_prediction = predictor.predict(&properties);
-                    let true_value = (diff as i64 + sample_prediction) as i32;
-                    grid.set(x, y, true_value);
-                    properties.record(true_value);
+                            let token = next(cluster)?;
+                            let value = ((token as i64) + pred) as i32;
+                            grid.set(x, y, value);
+                            *prev = value;
+                            nw = n;
+                            w = value as i64;
+                        }
+                    }
+                    continue;
                 }
             }
 
-        decoder.finalize()?;
+            let wp_header = ma_tree.need_self_correcting().then_some(wp_header);
+            let mut predictor = PredictorState::new(width as u32, i as u32, stream_index, filtered_prev.len(), wp_header);
+            decode_channel_slow(&mut next, &ma_tree, &mut predictor, grid, &filtered_prev)?;
+            prev.push((info, grid));
+        }
 
         Ok(())
     }
 }
 
+fn decode_channel_slow(
+    next: &mut impl FnMut(u8) -> Result<i32>,
+    ma_tree: &FlatMaTree,
+    predictor: &mut PredictorState,
+    grid: &mut Grid,
+    prev: &[&Grid],
+) -> Result<()> {
+    let width = grid.width();
+    let height = grid.height();
+
+    let mut prev_channel_samples = vec![0i32; prev.len()];
+
+    for y in 0..height {
+        for x in 0..width {
+            for (grid, sample) in prev.iter().zip(&mut prev_channel_samples) {
+                *sample = *grid.get(x, y).unwrap();
+            }
+
+            let properties = predictor.properties(&prev_channel_samples);
+            let (diff, predictor) = ma_tree.decode_sample_rle(next, &properties)?;
+            let sample_prediction = predictor.predict(&properties);
+            let true_value = (diff as i64 + sample_prediction) as i32;
+            grid.set(x, y, true_value);
+            properties.record(true_value);
+        }
+    }
+
+    Ok(())
+}
+
 impl Image {
     pub fn group_dim(&self) -> u32 {
         self.group_dim
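The `Predictor::Gradient` fast paths keep only one previous row and predict each sample as `clamp(N + W - NW, min(N, W), max(N, W))` before adding the decoded residual. Here is a minimal standalone sketch of that loop, assuming residuals arrive as an iterator of already-unpacked `i32` values and writing into a flat `Vec<i32>` instead of the crate's `Grid`.

```rust
/// Decodes one channel with the gradient predictor: each sample is predicted
/// from its West, North and NorthWest neighbours, clamped to [min(N, W), max(N, W)],
/// and the residual from `residuals` is added on top.
fn decode_gradient(width: usize, height: usize, mut residuals: impl Iterator<Item = i32>) -> Vec<i32> {
    let mut image = vec![0i32; width * height];
    let mut prev_row = vec![0i32; width];
    for y in 0..height {
        let mut w = prev_row[0] as i64; // first column predicts from the sample above (or 0)
        let mut nw = w;
        for (x, prev) in prev_row.iter_mut().enumerate() {
            let n = if y == 0 { w } else { *prev as i64 };
            let pred = (n + w - nw).clamp(w.min(n), w.max(n));
            let value = (residuals.next().unwrap() as i64 + pred) as i32;
            image[y * width + x] = value;
            *prev = value; // this row becomes the next row's North
            nw = n;
            w = value as i64;
        }
    }
    image
}

fn main() {
    // A single residual of 10 followed by zeros reproduces the top-left sample everywhere.
    let image = decode_gradient(4, 3, std::iter::once(10).chain(std::iter::repeat(0)));
    assert!(image.iter().all(|&v| v == 10));
}
```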
diff --git a/crates/jxl-modular/src/ma.rs b/crates/jxl-modular/src/ma.rs
index 3ec08ac3..3732a22f 100644
--- a/crates/jxl-modular/src/ma.rs
+++ b/crates/jxl-modular/src/ma.rs
@@ -173,11 +173,11 @@ enum FlatMaTreeNode {
 }
 
 #[derive(Debug, Clone, PartialEq, Eq)]
-struct MaTreeLeafClustered {
-    cluster: u8,
-    predictor: super::predictor::Predictor,
-    offset: i32,
-    multiplier: u32,
+pub(crate) struct MaTreeLeafClustered {
+    pub(crate) cluster: u8,
+    pub(crate) predictor: super::predictor::Predictor,
+    pub(crate) offset: i32,
+    pub(crate) multiplier: u32,
 }
 
 impl FlatMaTree {
@@ -227,6 +227,26 @@ impl FlatMaTree {
         let diff = unpack_signed(diff) * leaf.multiplier as i32 + leaf.offset;
         Ok((diff, leaf.predictor))
     }
+
+    #[inline]
+    pub(crate) fn decode_sample_rle(
+        &self,
+        next: &mut impl FnMut(u8) -> Result<i32>,
+        properties: &Properties,
+    ) -> Result<(i32, super::predictor::Predictor)> {
+        let leaf = self.get_leaf(properties)?;
+        let diff = next(leaf.cluster)?;
+        let diff = diff * leaf.multiplier as i32 + leaf.offset;
+        Ok((diff, leaf.predictor))
+    }
+
+    #[inline]
+    pub(crate) fn single_node(&self) -> Option<&MaTreeLeafClustered> {
+        match self.nodes.get(0) {
+            Some(FlatMaTreeNode::Leaf(node)) => Some(node),
+            _ => None,
+        }
+    }
 }
 
 #[derive(Debug)]
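When `single_node` returns a leaf, every sample of a channel shares one `(cluster, predictor, offset, multiplier)` tuple, so the per-sample tree walk disappears and only the token-to-diff mapping remains. The sketch below mirrors that leaf handling with local stand-ins for `Predictor` and the leaf struct, and assumes `unpack_signed` is the usual zigzag mapping (0, -1, 1, -2, 2, ...).

```rust
// Stand-ins for the crate's types; `unpack_signed` is written out here under the
// assumption that it is the usual zigzag mapping.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Predictor {
    Zero,
    Gradient,
}

struct Leaf {
    cluster: u8,
    predictor: Predictor,
    offset: i32,
    multiplier: u32,
}

fn unpack_signed(u: u32) -> i32 {
    if u % 2 == 0 { (u / 2) as i32 } else { -(((u + 1) / 2) as i32) }
}

// Mirrors the leaf handling in decode_sample: token -> signed diff, then scale and offset.
fn decode_diff(leaf: &Leaf, token: u32) -> (i32, Predictor) {
    let diff = unpack_signed(token) * leaf.multiplier as i32 + leaf.offset;
    (diff, leaf.predictor)
}

fn main() {
    let gradient = Leaf { cluster: 0, predictor: Predictor::Gradient, offset: 0, multiplier: 1 };
    assert_eq!(decode_diff(&gradient, 0), (0, Predictor::Gradient));
    assert_eq!(decode_diff(&gradient, 5), (-3, Predictor::Gradient));

    let zero = Leaf { cluster: 1, predictor: Predictor::Zero, offset: 3, multiplier: 2 };
    assert_eq!(decode_diff(&zero, 4), (7, Predictor::Zero)); // 2 * 2 + 3
    assert_eq!(zero.cluster, 1); // cluster selects the entropy-coding cluster for the whole channel
}
```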