Skip to content

Commit

Permalink
jxl-color: Optimize color transform ops (#158)
Browse files Browse the repository at this point in the history
* jxl-color: SIMD PQ transfer function

* jxl-color: Add a note about libjxl

* jxl-color: NEON tone mapper

* jxl-color: x86_64 SIMD tone mapper

* jxl-color: NEON gamut mapper

* jxl-color: x86_64 SIMD gamut mapper

* jxl-color: Make sure SIMD operations are inlined

* jxl-color: Multithreaded color transformation

* jxl-color: Merge matrix ops

* jxl-color: Remove some trace logs
  • Loading branch information
tirr-c authored Dec 30, 2023
1 parent c8f0d41 commit d5fe259
Show file tree
Hide file tree
Showing 13 changed files with 2,201 additions and 443 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions crates/jxl-color/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ path = "../jxl-coding"
version = "0.2.0"
path = "../jxl-grid"

[dependencies.jxl-threadpool]
version = "0.1.0"
path = "../jxl-threadpool"

[dependencies.tracing]
version = "0.1.37"
default_features = false
Expand Down
93 changes: 80 additions & 13 deletions crates/jxl-color/src/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::{
ToneMapping, TransferFunction,
};

mod gamut_map;
mod tone_map;

#[derive(Clone)]
Expand Down Expand Up @@ -393,10 +394,12 @@ impl ColorTransform {
});
}

Ok(Self {
let mut ret = Self {
begin_channels,
ops,
})
};
ret.optimize();
Ok(ret)
}

pub fn xyb_to_enum(
Expand All @@ -421,11 +424,84 @@ impl ColorTransform {

let mut num_channels = self.begin_channels;
for op in &self.ops {
tracing::trace!(?op);
num_channels = op.run(channels, num_channels, cms)?;
}
Ok(num_channels)
}

pub fn run_with_threads<Cms: ColorManagementSystem + Sync + ?Sized>(
&self,
channels: &mut [&mut [f32]],
cms: &Cms,
pool: &jxl_threadpool::JxlThreadPool,
) -> Result<usize> {
let _gurad = tracing::trace_span!("Run color transform ops").entered();

let mut chunks = Vec::new();
let mut it = channels
.iter_mut()
.map(|ch| ch.chunks_mut(65536))
.collect::<Vec<_>>();
loop {
let Some(chunk) = it
.iter_mut()
.map(|it| it.next())
.collect::<Option<Vec<_>>>()
else {
break;
};
chunks.push(chunk);
}

let ret = std::sync::Mutex::new(Ok(self.begin_channels));
pool.for_each_vec(chunks, |mut channels| {
let mut num_channels = self.begin_channels;
for op in &self.ops {
match op.run(&mut channels, num_channels, cms) {
Ok(x) => {
num_channels = x;
}
err => {
*ret.lock().unwrap() = err;
return;
}
}
}
*ret.lock().unwrap() = Ok(num_channels);
});
ret.into_inner().unwrap()
}

fn optimize(&mut self) {
let mut matrix_op_from = None;
let mut matrix = [0f32; 9];
let mut idx = 0usize;
let mut len = self.ops.len();
while idx < len {
let op = &self.ops[idx];
if let ColorTransformOp::Matrix(mat) = op {
if matrix_op_from.is_none() {
matrix_op_from = Some(idx);
matrix = *mat;
} else {
matrix = matmul3(mat, &matrix);
}
} else if let Some(from) = matrix_op_from {
self.ops[from] = ColorTransformOp::Matrix(matrix);
self.ops.drain((from + 1)..idx);
matrix_op_from = None;
idx = from + 1;
len = self.ops.len();
continue;
}
idx += 1;
}

if let Some(from) = matrix_op_from {
self.ops[from] = ColorTransformOp::Matrix(matrix);
self.ops.drain((from + 1)..len);
}
}
}

#[derive(Clone)]
Expand Down Expand Up @@ -676,16 +752,7 @@ impl ColorTransformOp {
let [r, g, b, ..] = channels else {
unreachable!()
};
for ((r, g), b) in r.iter_mut().zip(&mut **g).zip(&mut **b) {
let mapped = crate::gamut::map_gamut_generic(
[*r, *g, *b],
*luminances,
*saturation_factor,
);
*r = mapped[0];
*g = mapped[1];
*b = mapped[2];
}
gamut_map::gamut_map(r, g, b, *luminances, *saturation_factor);
3
}
Self::IccToIcc { from, to, .. } => {
Expand Down
172 changes: 172 additions & 0 deletions crates/jxl-color/src/convert/gamut_map.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
#[cfg(target_arch = "aarch64")]
use std::arch::is_aarch64_feature_detected;
#[cfg(target_arch = "x86_64")]
use std::arch::is_x86_feature_detected;

pub(super) fn gamut_map(
mut r: &mut [f32],
mut g: &mut [f32],
mut b: &mut [f32],
luminances: [f32; 3],
saturation_factor: f32,
) {
assert_eq!(r.len(), g.len());
assert_eq!(g.len(), b.len());

#[cfg(target_arch = "x86_64")]
if is_x86_feature_detected!("fma") && is_x86_feature_detected!("sse4.1") {
if is_x86_feature_detected!("avx2") {
// AVX2
// SAFETY: features are checked above.
unsafe {
(r, g, b) = gamut_map_x86_64_avx2(r, g, b, luminances, saturation_factor);
}
}

// SSE4.1 + FMA
// SAFETY: features are checked above.
unsafe {
(r, g, b) = gamut_map_x86_64_fma(r, g, b, luminances, saturation_factor);
}
} else {
// SAFETY: x86_64 implies SSE2.
unsafe {
let mut r_it = r.chunks_exact_mut(4);
let mut g_it = g.chunks_exact_mut(4);
let mut b_it = b.chunks_exact_mut(4);

for ((r, g), b) in (&mut r_it).zip(&mut g_it).zip(&mut b_it) {
let rgb = [
std::arch::x86_64::_mm_loadu_ps(r.as_ptr()),
std::arch::x86_64::_mm_loadu_ps(g.as_ptr()),
std::arch::x86_64::_mm_loadu_ps(b.as_ptr()),
];
let [vr, vg, vb] =
crate::gamut::map_gamut_x86_64_sse2(rgb, luminances, saturation_factor);
std::arch::x86_64::_mm_storeu_ps(r.as_mut_ptr(), vr);
std::arch::x86_64::_mm_storeu_ps(g.as_mut_ptr(), vg);
std::arch::x86_64::_mm_storeu_ps(b.as_mut_ptr(), vb);
}

r = r_it.into_remainder();
g = g_it.into_remainder();
b = b_it.into_remainder();
}
}

#[cfg(target_arch = "aarch64")]
if is_aarch64_feature_detected!("neon") {
// NEON
// SAFETY: features are checked above.
unsafe {
(r, g, b) = gamut_map_aarch64_neon(r, g, b, luminances, saturation_factor);
}
}

// generic
for ((r, g), b) in r.iter_mut().zip(g).zip(b) {
let mapped = crate::gamut::map_gamut_generic([*r, *g, *b], luminances, saturation_factor);
*r = mapped[0];
*g = mapped[1];
*b = mapped[2];
}
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[target_feature(enable = "fma")]
#[target_feature(enable = "sse4.1")]
pub(super) unsafe fn gamut_map_x86_64_avx2<'r, 'g, 'b>(
r: &'r mut [f32],
g: &'g mut [f32],
b: &'b mut [f32],
luminances: [f32; 3],
saturation_factor: f32,
) -> (&'r mut [f32], &'g mut [f32], &'b mut [f32]) {
let mut r_it = r.chunks_exact_mut(8);
let mut g_it = g.chunks_exact_mut(8);
let mut b_it = b.chunks_exact_mut(8);

for ((r, g), b) in (&mut r_it).zip(&mut g_it).zip(&mut b_it) {
let rgb = [
std::arch::x86_64::_mm256_loadu_ps(r.as_ptr()),
std::arch::x86_64::_mm256_loadu_ps(g.as_ptr()),
std::arch::x86_64::_mm256_loadu_ps(b.as_ptr()),
];
let [vr, vg, vb] = crate::gamut::map_gamut_x86_64_avx2(rgb, luminances, saturation_factor);
std::arch::x86_64::_mm256_storeu_ps(r.as_mut_ptr(), vr);
std::arch::x86_64::_mm256_storeu_ps(g.as_mut_ptr(), vg);
std::arch::x86_64::_mm256_storeu_ps(b.as_mut_ptr(), vb);
}

(
r_it.into_remainder(),
g_it.into_remainder(),
b_it.into_remainder(),
)
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "fma")]
#[target_feature(enable = "sse4.1")]
pub(super) unsafe fn gamut_map_x86_64_fma<'r, 'g, 'b>(
r: &'r mut [f32],
g: &'g mut [f32],
b: &'b mut [f32],
luminances: [f32; 3],
saturation_factor: f32,
) -> (&'r mut [f32], &'g mut [f32], &'b mut [f32]) {
let mut r_it = r.chunks_exact_mut(4);
let mut g_it = g.chunks_exact_mut(4);
let mut b_it = b.chunks_exact_mut(4);

for ((r, g), b) in (&mut r_it).zip(&mut g_it).zip(&mut b_it) {
let rgb = [
std::arch::x86_64::_mm_loadu_ps(r.as_ptr()),
std::arch::x86_64::_mm_loadu_ps(g.as_ptr()),
std::arch::x86_64::_mm_loadu_ps(b.as_ptr()),
];
let [vr, vg, vb] = crate::gamut::map_gamut_x86_64_fma(rgb, luminances, saturation_factor);
std::arch::x86_64::_mm_storeu_ps(r.as_mut_ptr(), vr);
std::arch::x86_64::_mm_storeu_ps(g.as_mut_ptr(), vg);
std::arch::x86_64::_mm_storeu_ps(b.as_mut_ptr(), vb);
}

(
r_it.into_remainder(),
g_it.into_remainder(),
b_it.into_remainder(),
)
}

#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub(super) unsafe fn gamut_map_aarch64_neon<'r, 'g, 'b>(
r: &'r mut [f32],
g: &'g mut [f32],
b: &'b mut [f32],
luminances: [f32; 3],
saturation_factor: f32,
) -> (&'r mut [f32], &'g mut [f32], &'b mut [f32]) {
let mut r_it = r.chunks_exact_mut(4);
let mut g_it = g.chunks_exact_mut(4);
let mut b_it = b.chunks_exact_mut(4);

for ((r, g), b) in (&mut r_it).zip(&mut g_it).zip(&mut b_it) {
let rgb = [
std::arch::aarch64::vld1q_f32(r.as_ptr()),
std::arch::aarch64::vld1q_f32(g.as_ptr()),
std::arch::aarch64::vld1q_f32(b.as_ptr()),
];
let [vr, vg, vb] = crate::gamut::map_gamut_aarch64_neon(rgb, luminances, saturation_factor);
std::arch::aarch64::vst1q_f32(r.as_mut_ptr(), vr);
std::arch::aarch64::vst1q_f32(g.as_mut_ptr(), vg);
std::arch::aarch64::vst1q_f32(b.as_mut_ptr(), vb);
}

(
r_it.into_remainder(),
g_it.into_remainder(),
b_it.into_remainder(),
)
}
Loading

0 comments on commit d5fe259

Please sign in to comment.