diff --git a/crates/jxl-grid/src/simple_grid.rs b/crates/jxl-grid/src/simple_grid.rs index 1935e7ed..4512b036 100644 --- a/crates/jxl-grid/src/simple_grid.rs +++ b/crates/jxl-grid/src/simple_grid.rs @@ -212,6 +212,42 @@ impl<'g, V: Copy> CutGrid<'g, V> { } } +impl CutGrid<'_, V> { + /// Split the grid horizontally at an index. + /// + /// # Panics + /// Panics if `x > self.width()`. + pub fn split_horizontal(&mut self, x: usize) -> (CutGrid<'_, V>, CutGrid<'_, V>) { + assert!(x <= self.width); + + let left_ptr = self.ptr; + let right_ptr = NonNull::new(self.get_ptr(x, 0)).unwrap(); + // SAFETY: two grids are contained in `self` and disjoint. + unsafe { + let left_grid = CutGrid::new(left_ptr, x, self.height, self.stride); + let right_grid = CutGrid::new(right_ptr, self.width - x, self.height, self.stride); + (left_grid, right_grid) + } + } + + /// Split the grid vertically at an index. + /// + /// # Panics + /// Panics if `y > self.height()`. + pub fn split_vertical(&mut self, y: usize) -> (CutGrid<'_, V>, CutGrid<'_, V>) { + assert!(y <= self.height); + + let top_ptr = self.ptr; + let bottom_ptr = NonNull::new(self.get_ptr(0, y)).unwrap(); + // SAFETY: two grids are contained in `self` and disjoint. + unsafe { + let top_grid = CutGrid::new(top_ptr, self.width, y, self.stride); + let bottom_grid = CutGrid::new(bottom_ptr, self.width, self.height - y, self.stride); + (top_grid, bottom_grid) + } + } +} + impl<'g> CutGrid<'g, f32> { pub fn as_vectored(&mut self) -> Option> { let mask = V::SIZE - 1; diff --git a/crates/jxl-render/src/dct/generic.rs b/crates/jxl-render/src/dct/generic.rs index bcf3e027..ab5aa479 100644 --- a/crates/jxl-render/src/dct/generic.rs +++ b/crates/jxl-render/src/dct/generic.rs @@ -5,7 +5,86 @@ use super::{consts, DctDirection}; pub fn dct_2d(io: &mut CutGrid<'_>, direction: DctDirection) { let width = io.width(); let height = io.height(); + if width * height <= 1 { + return; + } + + let mul = if direction == DctDirection::Forward { 0.5 } else { 1.0 }; + if width == 2 && height == 1 { + let v0 = io.get(0, 0); + let v1 = io.get(1, 0); + *io.get_mut(0, 0) = (v0 + v1) * mul; + *io.get_mut(1, 0) = (v0 - v1) * mul; + return; + } + if width == 1 && height == 2 { + let v0 = io.get(0, 0); + let v1 = io.get(0, 1); + *io.get_mut(0, 0) = (v0 + v1) * mul; + *io.get_mut(0, 1) = (v0 - v1) * mul; + return; + } + if width == 2 && height == 2 { + let v00 = io.get(0, 0); + let v01 = io.get(1, 0); + let v10 = io.get(0, 1); + let v11 = io.get(1, 1); + *io.get_mut(0, 0) = (v00 + v01 + v10 + v11) * mul * mul; + *io.get_mut(1, 0) = (v00 - v01 + v10 - v11) * mul * mul; + *io.get_mut(0, 1) = (v00 + v01 - v10 - v11) * mul * mul; + *io.get_mut(1, 1) = (v00 - v01 - v10 + v11) * mul * mul; + return; + } + let mut buf = vec![0f32; width.max(height)]; + if height == 1 { + dct(io.get_row_mut(0), &mut buf, direction); + return; + } + if width == 1 { + let mut row = vec![0f32; height]; + for (y, v) in row.iter_mut().enumerate() { + *v = io.get(0, y); + } + dct(&mut row, &mut buf, direction); + for (y, v) in row.into_iter().enumerate() { + *io.get_mut(0, y) = v; + } + return; + } + + if height == 2 { + let (mut row0, mut row1) = io.split_vertical(1); + let row0 = row0.get_row_mut(0); + let row1 = row1.get_row_mut(0); + for (v0, v1) in row0.iter_mut().zip(row1.iter_mut()) { + let tv0 = *v0; + let tv1 = *v1; + *v0 = (tv0 + tv1) * mul; + *v1 = (tv0 - tv1) * mul; + } + + dct(row0, &mut buf, direction); + dct(row1, &mut buf, direction); + return; + } + if width == 2 { + let mut row = vec![0f32; height * 2]; + let (row0, row1) = row.split_at_mut(height); + for y in 0..height { + let v0 = io.get(0, y); + let v1 = io.get(1, y); + row0[y] = (v0 + v1) * mul; + row1[y] = (v0 - v1) * mul; + } + dct(row0, &mut buf, direction); + dct(row1, &mut buf, direction); + for y in 0..height { + *io.get_mut(0, y) = row0[y]; + *io.get_mut(1, y) = row1[y]; + } + return; + } let row = &mut buf[..width]; for y in 0..height { diff --git a/crates/jxl-render/src/vardct/mod.rs b/crates/jxl-render/src/vardct/mod.rs index 8135f97c..273764b0 100644 --- a/crates/jxl-render/src/vardct/mod.rs +++ b/crates/jxl-render/src/vardct/mod.rs @@ -150,17 +150,11 @@ pub fn dequant_hf_varblock( let need_transpose = dct_select.need_transpose(); let mul = 65536.0 / (quantizer.global_scale as i32 * hf_mul) as f32 * qm_scale[channel]; - let mut new_matrix; - let mut matrix = dequant_matrices.get(channel, dct_select); - if need_transpose { - new_matrix = vec![0f32; matrix.len()]; - for (idx, val) in new_matrix.iter_mut().enumerate() { - let mat_x = idx % width as usize; - let mat_y = idx / width as usize; - *val = matrix[mat_x * height as usize + mat_y]; - } - matrix = &new_matrix; - } + let matrix = if need_transpose { + dequant_matrices.get_transposed(channel, dct_select) + } else { + dequant_matrices.get(channel, dct_select) + }; let mut coeff = CutGrid::from_buf( &mut coeff_buf[offset..], @@ -168,9 +162,8 @@ pub fn dequant_hf_varblock( height as usize, stride, ); - for y in 0..height { - let row = coeff.get_row_mut(y as usize); - let matrix_row = &matrix[(y * width) as usize..][..width as usize]; + for (y, matrix_row) in matrix.chunks_exact(width as usize).enumerate() { + let row = coeff.get_row_mut(y); for (q, &m) in row.iter_mut().zip(matrix_row) { if q.abs() <= 1.0f32 { *q *= quant_bias; @@ -292,14 +285,6 @@ pub fn transform_with_lf( ) { use TransformType::*; - fn scale_f(c: usize, b: usize) -> f32 { - let cb = c as f32 / b as f32; - let recip = (cb * std::f32::consts::FRAC_PI_2).cos() * - (cb * std::f32::consts::PI).cos() * - (cb * 2.0 * std::f32::consts::PI).cos(); - recip.recip() - } - let lf_region = lf.region(); let coeff_region = coeff_out.region(); let lf = lf.buffer(); @@ -393,3 +378,28 @@ pub fn transform_with_lf( } } } + +fn scale_f(c: usize, b: usize) -> f32 { + // Precomputed for c = 0..32, b = 256 + #[allow(clippy::excessive_precision)] + const SCALE_F: [f32; 32] = [ + 1.0000000000000000, 0.9996047255830407, + 0.9984194528776054, 0.9964458326264695, + 0.9936866130906366, 0.9901456355893141, + 0.9858278282666936, 0.9807391980963174, + 0.9748868211368796, 0.9682788310563117, + 0.9609244059440204, 0.9528337534340876, + 0.9440180941651672, 0.9344896436056892, + 0.9242615922757944, 0.9133480844001980, + 0.9017641950288744, 0.8895259056651056, + 0.8766500784429904, 0.8631544288990163, + 0.8490574973847023, 0.8343786191696513, + 0.8191378932865928, 0.8033561501721485, + 0.7870549181591013, 0.7702563888779096, + 0.7529833816270532, 0.7352593067735488, + 0.7171081282466044, 0.6985543251889097, + 0.6796228528314652, 0.6603391026591464, + ]; + let c = c * (256 / b); + SCALE_F[c].recip() +} diff --git a/crates/jxl-render/src/vardct/transform.rs b/crates/jxl-render/src/vardct/transform.rs index 9c6e598d..aede772f 100644 --- a/crates/jxl-render/src/vardct/transform.rs +++ b/crates/jxl-render/src/vardct/transform.rs @@ -3,11 +3,11 @@ use jxl_vardct::TransformType; use crate::dct::{dct_2d, DctDirection}; -fn aux_idct2_in_place(block: &mut CutGrid<'_>, size: usize) { - debug_assert!(size.is_power_of_two()); +fn aux_idct2_in_place(block: &mut CutGrid<'_>) { + debug_assert!(SIZE.is_power_of_two()); - let num_2x2 = size / 2; - let mut scratch = vec![0.0f32; size * size]; + let num_2x2 = SIZE / 2; + let mut scratch = [[0.0f32; SIZE]; SIZE]; for y in 0..num_2x2 { for x in 0..num_2x2 { let c00 = block.get(x, y); @@ -15,27 +15,26 @@ fn aux_idct2_in_place(block: &mut CutGrid<'_>, size: usize) { let c10 = block.get(x, y + num_2x2); let c11 = block.get(x + num_2x2, y + num_2x2); - let base_idx = 2 * (y * size + x); - scratch[base_idx] = c00 + c01 + c10 + c11; - scratch[base_idx + 1] = c00 + c01 - c10 - c11; - scratch[base_idx + size] = c00 - c01 + c10 - c11; - scratch[base_idx + size + 1] = c00 - c01 - c10 + c11; + scratch[2 * y][2 * x] = c00 + c01 + c10 + c11; + scratch[2 * y][2 * x + 1] = c00 + c01 - c10 - c11; + scratch[2 * y + 1][2 * x] = c00 - c01 + c10 - c11; + scratch[2 * y + 1][2 * x + 1] = c00 - c01 - c10 + c11; } } - for y in 0..size { - block.get_row_mut(y)[..size].copy_from_slice(&scratch[y * size..][..size]); + for (y, scratch_row) in scratch.into_iter().enumerate() { + block.get_row_mut(y)[..SIZE].copy_from_slice(&scratch_row); } } fn transform_dct2(coeff: &mut CutGrid<'_>) { - aux_idct2_in_place(coeff, 2); - aux_idct2_in_place(coeff, 4); - aux_idct2_in_place(coeff, 8); + aux_idct2_in_place::<2>(coeff); + aux_idct2_in_place::<4>(coeff); + aux_idct2_in_place::<8>(coeff); } fn transform_dct4(coeff: &mut CutGrid<'_>) { - aux_idct2_in_place(coeff, 2); + aux_idct2_in_place::<2>(coeff); let mut scratch = [0.0f32; 64]; for y in 0..2 { @@ -63,7 +62,7 @@ fn transform_dct4(coeff: &mut CutGrid<'_>) { } fn transform_hornuss(coeff: &mut CutGrid<'_>) { - aux_idct2_in_place(coeff, 2); + aux_idct2_in_place::<2>(coeff); let mut scratch = [0.0f32; 64]; for y in 0..2 { diff --git a/crates/jxl-vardct/src/dequant.rs b/crates/jxl-vardct/src/dequant.rs index c7535cce..21d03fb8 100644 --- a/crates/jxl-vardct/src/dequant.rs +++ b/crates/jxl-vardct/src/dequant.rs @@ -488,6 +488,7 @@ impl BundleDefault for DequantMatrixParams { #[derive(Debug)] pub struct DequantMatrixSet { matrices: Vec<[Vec; 3]>, + matrices_tr: Vec<[Vec; 3]>, } impl Bundle> for DequantMatrixSet { @@ -528,10 +529,27 @@ impl Bundle> for DequantMatrixSet { }).collect::>()? }; - let matrices = param_list.into_iter() + let matrices: Vec<_> = param_list.into_iter() .map(|params| params.into_matrix()) .collect(); - Ok(Self { matrices }) + let matrices_tr = matrices + .iter() + .zip(DCT_SELECT_LIST) + .map(|(matrix, dct_select)| { + std::array::from_fn(|idx| { + let matrix = &matrix[idx]; + let (width, height) = dct_select.dequant_matrix_size(); + let mut out = vec![0f32; matrix.len()]; + for (idx, val) in out.iter_mut().enumerate() { + let mat_x = idx % height as usize; + let mat_y = idx / height as usize; + *val = matrix[mat_x * width as usize + mat_y]; + } + out + }) + }) + .collect(); + Ok(Self { matrices, matrices_tr }) } } @@ -563,4 +581,32 @@ impl DequantMatrixSet { }; &self.matrices[idx][channel] } + + /// Returns the transposed dequantization matrix for the given channel and transform type. + /// + /// The coefficients is in the raster order. + pub fn get_transposed(&self, channel: usize, dct_select: TransformType) -> &[f32] { + use TransformType::*; + + let idx = match dct_select { + Dct8 => 0, + Hornuss => 1, + Dct2 => 2, + Dct4 => 3, + Dct16 => 4, + Dct32 => 5, + Dct8x16 | Dct16x8 => 6, + Dct8x32 | Dct32x8 => 7, + Dct16x32 | Dct32x16 => 8, + Dct4x8 | Dct8x4 => 9, + Afv0 | Afv1 | Afv2 | Afv3 => 10, + Dct64 => 11, + Dct32x64 | Dct64x32 => 12, + Dct128 => 13, + Dct64x128 | Dct128x64 => 14, + Dct256 => 15, + Dct128x256 | Dct256x128 => 16, + }; + &self.matrices_tr[idx][channel] + } }