Skip to content

Commit

Permalink
Various optimizations for VarDCT frames (#81)
Browse files Browse the repository at this point in the history
* jxl-render: scale_f lookup table

* jxl-vardct: Keep transposed dequant matrices

* jxl-grid: Add CutGrid split methods

* jxl-render: Optimized small-sized DCT

* jxl-grid: Add docs for split methods
  • Loading branch information
tirr-c authored Sep 6, 2023
1 parent 7853f90 commit d55ccc4
Show file tree
Hide file tree
Showing 5 changed files with 210 additions and 40 deletions.
36 changes: 36 additions & 0 deletions crates/jxl-grid/src/simple_grid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,42 @@ impl<'g, V: Copy> CutGrid<'g, V> {
}
}

impl<V: Copy> CutGrid<'_, V> {
    /// Splits this grid into a left and a right grid at column `x`.
    ///
    /// The first returned grid starts at this grid's origin and is `x` columns wide. The second
    /// starts at `(x, 0)` and takes the remaining `self.width - x` columns. Both keep the height
    /// and the stride of `self`, and both borrow from `self` for as long as they live.
    ///
    /// # Panics
    /// Panics if `x > self.width()`.
    pub fn split_horizontal(&mut self, x: usize) -> (CutGrid<'_, V>, CutGrid<'_, V>) {
        assert!(x <= self.width);

        let (height, stride) = (self.height, self.stride);
        let right_width = self.width - x;
        let left_origin = self.ptr;
        // NOTE(review): `x == self.width` is allowed by the assertion above; confirm that
        // `get_ptr` accepts a one-past-the-end column.
        let right_origin = NonNull::new(self.get_ptr(x, 0)).unwrap();
        // SAFETY: two grids are contained in `self` and disjoint.
        unsafe {
            (
                CutGrid::new(left_origin, x, height, stride),
                CutGrid::new(right_origin, right_width, height, stride),
            )
        }
    }

    /// Splits this grid into a top and a bottom grid at row `y`.
    ///
    /// The first returned grid starts at this grid's origin and is `y` rows tall. The second
    /// starts at `(0, y)` and takes the remaining `self.height - y` rows. Both keep the width
    /// and the stride of `self`, and both borrow from `self` for as long as they live.
    ///
    /// # Panics
    /// Panics if `y > self.height()`.
    pub fn split_vertical(&mut self, y: usize) -> (CutGrid<'_, V>, CutGrid<'_, V>) {
        assert!(y <= self.height);

        let (width, stride) = (self.width, self.stride);
        let bottom_height = self.height - y;
        let top_origin = self.ptr;
        // NOTE(review): `y == self.height` is allowed by the assertion above; confirm that
        // `get_ptr` accepts a one-past-the-end row.
        let bottom_origin = NonNull::new(self.get_ptr(0, y)).unwrap();
        // SAFETY: two grids are contained in `self` and disjoint.
        unsafe {
            (
                CutGrid::new(top_origin, width, y, stride),
                CutGrid::new(bottom_origin, width, bottom_height, stride),
            )
        }
    }
}

impl<'g> CutGrid<'g, f32> {
pub fn as_vectored<V: SimdVector>(&mut self) -> Option<CutGrid<'_, V>> {
let mask = V::SIZE - 1;
Expand Down
79 changes: 79 additions & 0 deletions crates/jxl-render/src/dct/generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,86 @@ use super::{consts, DctDirection};
pub fn dct_2d(io: &mut CutGrid<'_>, direction: DctDirection) {
let width = io.width();
let height = io.height();
if width * height <= 1 {
return;
}

let mul = if direction == DctDirection::Forward { 0.5 } else { 1.0 };
if width == 2 && height == 1 {
let v0 = io.get(0, 0);
let v1 = io.get(1, 0);
*io.get_mut(0, 0) = (v0 + v1) * mul;
*io.get_mut(1, 0) = (v0 - v1) * mul;
return;
}
if width == 1 && height == 2 {
let v0 = io.get(0, 0);
let v1 = io.get(0, 1);
*io.get_mut(0, 0) = (v0 + v1) * mul;
*io.get_mut(0, 1) = (v0 - v1) * mul;
return;
}
if width == 2 && height == 2 {
let v00 = io.get(0, 0);
let v01 = io.get(1, 0);
let v10 = io.get(0, 1);
let v11 = io.get(1, 1);
*io.get_mut(0, 0) = (v00 + v01 + v10 + v11) * mul * mul;
*io.get_mut(1, 0) = (v00 - v01 + v10 - v11) * mul * mul;
*io.get_mut(0, 1) = (v00 + v01 - v10 - v11) * mul * mul;
*io.get_mut(1, 1) = (v00 - v01 - v10 + v11) * mul * mul;
return;
}

let mut buf = vec![0f32; width.max(height)];
if height == 1 {
dct(io.get_row_mut(0), &mut buf, direction);
return;
}
if width == 1 {
let mut row = vec![0f32; height];
for (y, v) in row.iter_mut().enumerate() {
*v = io.get(0, y);
}
dct(&mut row, &mut buf, direction);
for (y, v) in row.into_iter().enumerate() {
*io.get_mut(0, y) = v;
}
return;
}

if height == 2 {
let (mut row0, mut row1) = io.split_vertical(1);
let row0 = row0.get_row_mut(0);
let row1 = row1.get_row_mut(0);
for (v0, v1) in row0.iter_mut().zip(row1.iter_mut()) {
let tv0 = *v0;
let tv1 = *v1;
*v0 = (tv0 + tv1) * mul;
*v1 = (tv0 - tv1) * mul;
}

dct(row0, &mut buf, direction);
dct(row1, &mut buf, direction);
return;
}
if width == 2 {
let mut row = vec![0f32; height * 2];
let (row0, row1) = row.split_at_mut(height);
for y in 0..height {
let v0 = io.get(0, y);
let v1 = io.get(1, y);
row0[y] = (v0 + v1) * mul;
row1[y] = (v0 - v1) * mul;
}
dct(row0, &mut buf, direction);
dct(row1, &mut buf, direction);
for y in 0..height {
*io.get_mut(0, y) = row0[y];
*io.get_mut(1, y) = row1[y];
}
return;
}

let row = &mut buf[..width];
for y in 0..height {
Expand Down
54 changes: 32 additions & 22 deletions crates/jxl-render/src/vardct/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,27 +150,20 @@ pub fn dequant_hf_varblock(
let need_transpose = dct_select.need_transpose();
let mul = 65536.0 / (quantizer.global_scale as i32 * hf_mul) as f32 * qm_scale[channel];

let mut new_matrix;
let mut matrix = dequant_matrices.get(channel, dct_select);
if need_transpose {
new_matrix = vec![0f32; matrix.len()];
for (idx, val) in new_matrix.iter_mut().enumerate() {
let mat_x = idx % width as usize;
let mat_y = idx / width as usize;
*val = matrix[mat_x * height as usize + mat_y];
}
matrix = &new_matrix;
}
let matrix = if need_transpose {
dequant_matrices.get_transposed(channel, dct_select)
} else {
dequant_matrices.get(channel, dct_select)
};

let mut coeff = CutGrid::from_buf(
&mut coeff_buf[offset..],
width as usize,
height as usize,
stride,
);
for y in 0..height {
let row = coeff.get_row_mut(y as usize);
let matrix_row = &matrix[(y * width) as usize..][..width as usize];
for (y, matrix_row) in matrix.chunks_exact(width as usize).enumerate() {
let row = coeff.get_row_mut(y);
for (q, &m) in row.iter_mut().zip(matrix_row) {
if q.abs() <= 1.0f32 {
*q *= quant_bias;
Expand Down Expand Up @@ -292,14 +285,6 @@ pub fn transform_with_lf(
) {
use TransformType::*;

/// Returns `1 / (cos(t * PI / 2) * cos(t * PI) * cos(t * 2 * PI))` for `t = c / b`,
/// evaluated entirely in `f32`.
fn scale_f(c: usize, b: usize) -> f32 {
    use std::f32::consts::{FRAC_PI_2, PI};

    let ratio = c as f32 / b as f32;
    let quarter_turn = (ratio * FRAC_PI_2).cos();
    let half_turn = (ratio * PI).cos();
    let full_turn = (ratio * 2.0 * PI).cos();
    (quarter_turn * half_turn * full_turn).recip()
}

let lf_region = lf.region();
let coeff_region = coeff_out.region();
let lf = lf.buffer();
Expand Down Expand Up @@ -393,3 +378,28 @@ pub fn transform_with_lf(
}
}
}

/// Returns `1 / (cos(t * PI / 2) * cos(t * PI) * cos(t * 2 * PI))` for `t = c / b`.
///
/// Ratios that land exactly on the precomputed grid (`b` divides 256 and `c * 256 / b < 32`)
/// are served from a lookup table. Every other ratio is evaluated with the closed-form
/// expression, so the result no longer depends on integer-division truncation and ratios
/// beyond the table do not index out of bounds.
///
/// # Panics
/// Panics if `b == 0`.
fn scale_f(c: usize, b: usize) -> f32 {
    // Cosine products (not yet inverted) precomputed for c = 0..32, b = 256.
    #[allow(clippy::excessive_precision)]
    const SCALE_F: [f32; 32] = [
        1.0000000000000000, 0.9996047255830407,
        0.9984194528776054, 0.9964458326264695,
        0.9936866130906366, 0.9901456355893141,
        0.9858278282666936, 0.9807391980963174,
        0.9748868211368796, 0.9682788310563117,
        0.9609244059440204, 0.9528337534340876,
        0.9440180941651672, 0.9344896436056892,
        0.9242615922757944, 0.9133480844001980,
        0.9017641950288744, 0.8895259056651056,
        0.8766500784429904, 0.8631544288990163,
        0.8490574973847023, 0.8343786191696513,
        0.8191378932865928, 0.8033561501721485,
        0.7870549181591013, 0.7702563888779096,
        0.7529833816270532, 0.7352593067735488,
        0.7171081282466044, 0.6985543251889097,
        0.6796228528314652, 0.6603391026591464,
    ];

    assert!(b != 0, "scale_f: denominator `b` must be nonzero");

    // Fast path: only valid when `c / b` is an exact multiple of 1/256 inside the table.
    if 256 % b == 0 {
        let entry = c
            .checked_mul(256 / b)
            .and_then(|idx| SCALE_F.get(idx));
        if let Some(product) = entry {
            return product.recip();
        }
    }

    // Slow path: evaluate the product in f64, then round to f32 before inverting so the
    // result is computed the same way as a table entry would be.
    let ratio = c as f64 / b as f64;
    let product = (ratio * std::f64::consts::FRAC_PI_2).cos()
        * (ratio * std::f64::consts::PI).cos()
        * (ratio * 2.0 * std::f64::consts::PI).cos();
    (product as f32).recip()
}
31 changes: 15 additions & 16 deletions crates/jxl-render/src/vardct/transform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,39 +3,38 @@ use jxl_vardct::TransformType;

use crate::dct::{dct_2d, DctDirection};

fn aux_idct2_in_place(block: &mut CutGrid<'_>, size: usize) {
debug_assert!(size.is_power_of_two());
fn aux_idct2_in_place<const SIZE: usize>(block: &mut CutGrid<'_>) {
debug_assert!(SIZE.is_power_of_two());

let num_2x2 = size / 2;
let mut scratch = vec![0.0f32; size * size];
let num_2x2 = SIZE / 2;
let mut scratch = [[0.0f32; SIZE]; SIZE];
for y in 0..num_2x2 {
for x in 0..num_2x2 {
let c00 = block.get(x, y);
let c01 = block.get(x + num_2x2, y);
let c10 = block.get(x, y + num_2x2);
let c11 = block.get(x + num_2x2, y + num_2x2);

let base_idx = 2 * (y * size + x);
scratch[base_idx] = c00 + c01 + c10 + c11;
scratch[base_idx + 1] = c00 + c01 - c10 - c11;
scratch[base_idx + size] = c00 - c01 + c10 - c11;
scratch[base_idx + size + 1] = c00 - c01 - c10 + c11;
scratch[2 * y][2 * x] = c00 + c01 + c10 + c11;
scratch[2 * y][2 * x + 1] = c00 + c01 - c10 - c11;
scratch[2 * y + 1][2 * x] = c00 - c01 + c10 - c11;
scratch[2 * y + 1][2 * x + 1] = c00 - c01 - c10 + c11;
}
}

for y in 0..size {
block.get_row_mut(y)[..size].copy_from_slice(&scratch[y * size..][..size]);
for (y, scratch_row) in scratch.into_iter().enumerate() {
block.get_row_mut(y)[..SIZE].copy_from_slice(&scratch_row);
}
}

fn transform_dct2(coeff: &mut CutGrid<'_>) {
aux_idct2_in_place(coeff, 2);
aux_idct2_in_place(coeff, 4);
aux_idct2_in_place(coeff, 8);
aux_idct2_in_place::<2>(coeff);
aux_idct2_in_place::<4>(coeff);
aux_idct2_in_place::<8>(coeff);
}

fn transform_dct4(coeff: &mut CutGrid<'_>) {
aux_idct2_in_place(coeff, 2);
aux_idct2_in_place::<2>(coeff);

let mut scratch = [0.0f32; 64];
for y in 0..2 {
Expand Down Expand Up @@ -63,7 +62,7 @@ fn transform_dct4(coeff: &mut CutGrid<'_>) {
}

fn transform_hornuss(coeff: &mut CutGrid<'_>) {
aux_idct2_in_place(coeff, 2);
aux_idct2_in_place::<2>(coeff);

let mut scratch = [0.0f32; 64];
for y in 0..2 {
Expand Down
50 changes: 48 additions & 2 deletions crates/jxl-vardct/src/dequant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,7 @@ impl BundleDefault<TransformType> for DequantMatrixParams {
// Set of dequantization matrices, stored both in their original layout and transposed.
#[derive(Debug)]
pub struct DequantMatrixSet {
// Matrices in their original layout, indexed as `matrices[slot][channel]`; each slot holds
// one flat `Vec<f32>` per channel (three channels).
matrices: Vec<[Vec<f32>; 3]>,
// Transposed copy of every entry in `matrices`, built once when the set is parsed and
// indexed the same way, so `get_transposed` can hand out a slice without reallocating.
matrices_tr: Vec<[Vec<f32>; 3]>,
}

impl Bundle<DequantMatrixSetParams<'_>> for DequantMatrixSet {
Expand Down Expand Up @@ -528,10 +529,27 @@ impl Bundle<DequantMatrixSetParams<'_>> for DequantMatrixSet {
}).collect::<Result<_>>()?
};

let matrices = param_list.into_iter()
let matrices: Vec<_> = param_list.into_iter()
.map(|params| params.into_matrix())
.collect();
Ok(Self { matrices })
let matrices_tr = matrices
.iter()
.zip(DCT_SELECT_LIST)
.map(|(matrix, dct_select)| {
std::array::from_fn(|idx| {
let matrix = &matrix[idx];
let (width, height) = dct_select.dequant_matrix_size();
let mut out = vec![0f32; matrix.len()];
for (idx, val) in out.iter_mut().enumerate() {
let mat_x = idx % height as usize;
let mat_y = idx / height as usize;
*val = matrix[mat_x * width as usize + mat_y];
}
out
})
})
.collect();
Ok(Self { matrices, matrices_tr })
}
}

Expand Down Expand Up @@ -563,4 +581,32 @@ impl DequantMatrixSet {
};
&self.matrices[idx][channel]
}

/// Returns the transposed dequantization matrix for the given channel and transform type.
///
/// The coefficients is in the raster order.
pub fn get_transposed(&self, channel: usize, dct_select: TransformType) -> &[f32] {
use TransformType::*;

let idx = match dct_select {
Dct8 => 0,
Hornuss => 1,
Dct2 => 2,
Dct4 => 3,
Dct16 => 4,
Dct32 => 5,
Dct8x16 | Dct16x8 => 6,
Dct8x32 | Dct32x8 => 7,
Dct16x32 | Dct32x16 => 8,
Dct4x8 | Dct8x4 => 9,
Afv0 | Afv1 | Afv2 | Afv3 => 10,
Dct64 => 11,
Dct32x64 | Dct64x32 => 12,
Dct128 => 13,
Dct64x128 | Dct128x64 => 14,
Dct256 => 15,
Dct128x256 | Dct256x128 => 16,
};
&self.matrices_tr[idx][channel]
}
}

0 comments on commit d55ccc4

Please sign in to comment.