Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

jxl-color: Optimize color transform ops #158

Merged
merged 10 commits into from
Dec 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions crates/jxl-color/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ path = "../jxl-coding"
version = "0.2.0"
path = "../jxl-grid"

[dependencies.jxl-threadpool]
version = "0.1.0"
path = "../jxl-threadpool"

[dependencies.tracing]
version = "0.1.37"
default_features = false
Expand Down
93 changes: 80 additions & 13 deletions crates/jxl-color/src/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::{
ToneMapping, TransferFunction,
};

mod gamut_map;
mod tone_map;

#[derive(Clone)]
Expand Down Expand Up @@ -393,10 +394,12 @@ impl ColorTransform {
});
}

Ok(Self {
let mut ret = Self {
begin_channels,
ops,
})
};
ret.optimize();
Ok(ret)
}

pub fn xyb_to_enum(
Expand All @@ -421,11 +424,84 @@ impl ColorTransform {

let mut num_channels = self.begin_channels;
for op in &self.ops {
tracing::trace!(?op);
num_channels = op.run(channels, num_channels, cms)?;
}
Ok(num_channels)
}

/// Runs every op of this transform over `channels`, distributing the work over `pool`.
///
/// Each channel is cut into runs of at most 65536 samples. The i-th runs of all channels
/// form one work item, and the whole op sequence is applied to each work item independently.
///
/// When there is nothing to distribute (no channels at all, or a channel without samples),
/// the ops are run once on the calling thread over empty slices, so the returned channel
/// count still reflects the full op sequence.
///
/// Returns the channel count reported by the last op.
///
/// # Errors
///
/// Returns an error if any op fails on any work item. Once an error has been recorded it is
/// never replaced by the result of another work item.
pub fn run_with_threads<Cms: ColorManagementSystem + Sync + ?Sized>(
    &self,
    channels: &mut [&mut [f32]],
    cms: &Cms,
    pool: &jxl_threadpool::JxlThreadPool,
) -> Result<usize> {
    /// Maximum number of samples per channel in one work item.
    const CHUNK_LEN: usize = 65536;

    let _guard = tracing::trace_span!("Run color transform ops").entered();

    // With no channels the chunking loop below would never terminate: collecting an empty
    // iterator into `Option<Vec<_>>` always yields `Some(vec![])`. With an empty channel it
    // would produce no work item at all and the ops would never be consulted for the final
    // channel count. Handle both cases up front without touching any sample.
    if channels.is_empty() || channels.iter().any(|ch| ch.is_empty()) {
        let mut empty = channels
            .iter_mut()
            .map(|ch| &mut ch[..0])
            .collect::<Vec<_>>();
        let mut num_channels = self.begin_channels;
        for op in &self.ops {
            num_channels = op.run(&mut empty, num_channels, cms)?;
        }
        return Ok(num_channels);
    }

    // Transpose "per channel, a sequence of chunks" into "per chunk, one slice per channel".
    let mut chunks = Vec::new();
    let mut chunk_iters = channels
        .iter_mut()
        .map(|ch| ch.chunks_mut(CHUNK_LEN))
        .collect::<Vec<_>>();
    while let Some(chunk) = chunk_iters
        .iter_mut()
        .map(|it| it.next())
        .collect::<Option<Vec<_>>>()
    {
        chunks.push(chunk);
    }

    let ret: std::sync::Mutex<Result<usize>> = std::sync::Mutex::new(Ok(self.begin_channels));
    pool.for_each_vec(chunks, |mut channels| {
        // Stops at the first failing op of this work item.
        let outcome = self
            .ops
            .iter()
            .try_fold(self.begin_channels, |num_channels, op| {
                op.run(&mut channels, num_channels, cms)
            });

        let mut slot = ret.lock().unwrap();
        // Keep an already recorded error: a work item that finishes successfully afterwards
        // must not hide the failure of another one.
        if slot.is_ok() {
            *slot = outcome;
        }
    });
    ret.into_inner().unwrap()
}

/// Collapses every run of adjacent `Matrix` ops into a single `Matrix` op.
///
/// The matrices of a run are combined in op order as `matmul3(next, &accumulated)`. Ops of
/// any other kind are kept as they are and in the same relative order.
fn optimize(&mut self) {
    let original: Vec<_> = self.ops.drain(..).collect();
    let mut rebuilt = Vec::with_capacity(original.len());
    // Product of the run of matrices currently being collected, if any.
    let mut pending: Option<[f32; 9]> = None;

    for op in original {
        match op {
            ColorTransformOp::Matrix(mat) => {
                pending = Some(match pending {
                    Some(acc) => matmul3(&mat, &acc),
                    None => mat,
                });
            }
            other => {
                // A non-matrix op ends the current run; emit its product first.
                if let Some(acc) = pending.take() {
                    rebuilt.push(ColorTransformOp::Matrix(acc));
                }
                rebuilt.push(other);
            }
        }
    }

    // The op list may end while a run is still open.
    if let Some(acc) = pending {
        rebuilt.push(ColorTransformOp::Matrix(acc));
    }

    self.ops.extend(rebuilt);
}
}

#[derive(Clone)]
Expand Down Expand Up @@ -676,16 +752,7 @@ impl ColorTransformOp {
let [r, g, b, ..] = channels else {
unreachable!()
};
for ((r, g), b) in r.iter_mut().zip(&mut **g).zip(&mut **b) {
let mapped = crate::gamut::map_gamut_generic(
[*r, *g, *b],
*luminances,
*saturation_factor,
);
*r = mapped[0];
*g = mapped[1];
*b = mapped[2];
}
gamut_map::gamut_map(r, g, b, *luminances, *saturation_factor);
3
}
Self::IccToIcc { from, to, .. } => {
Expand Down
172 changes: 172 additions & 0 deletions crates/jxl-color/src/convert/gamut_map.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
#[cfg(target_arch = "aarch64")]
use std::arch::is_aarch64_feature_detected;
#[cfg(target_arch = "x86_64")]
use std::arch::is_x86_feature_detected;

/// Applies gamut mapping in place to three planar channels of equal length.
///
/// The widest SIMD kernel supported by the running CPU processes as many whole vectors as it
/// can. The samples it leaves over are passed on to the next narrower kernel, and the scalar
/// routine finishes whatever remains.
///
/// # Panics
///
/// Panics if `r`, `g` and `b` do not all have the same length.
pub(super) fn gamut_map(
    mut r: &mut [f32],
    mut g: &mut [f32],
    mut b: &mut [f32],
    luminances: [f32; 3],
    saturation_factor: f32,
) {
    assert_eq!(r.len(), g.len());
    assert_eq!(g.len(), b.len());

    #[cfg(target_arch = "x86_64")]
    {
        use std::arch::x86_64::{_mm_loadu_ps, _mm_storeu_ps};

        let fma_available = is_x86_feature_detected!("fma") && is_x86_feature_detected!("sse4.1");
        if !fma_available {
            // Baseline path: four samples per step.
            let mut r_chunks = r.chunks_exact_mut(4);
            let mut g_chunks = g.chunks_exact_mut(4);
            let mut b_chunks = b.chunks_exact_mut(4);

            for ((cr, cg), cb) in r_chunks
                .by_ref()
                .zip(g_chunks.by_ref())
                .zip(b_chunks.by_ref())
            {
                // SAFETY: x86_64 implies SSE2, and every chunk holds exactly four `f32`s, so
                // each unaligned 128-bit load and store stays inside its chunk.
                unsafe {
                    let input = [
                        _mm_loadu_ps(cr.as_ptr()),
                        _mm_loadu_ps(cg.as_ptr()),
                        _mm_loadu_ps(cb.as_ptr()),
                    ];
                    let [out_r, out_g, out_b] =
                        crate::gamut::map_gamut_x86_64_sse2(input, luminances, saturation_factor);
                    _mm_storeu_ps(cr.as_mut_ptr(), out_r);
                    _mm_storeu_ps(cg.as_mut_ptr(), out_g);
                    _mm_storeu_ps(cb.as_mut_ptr(), out_b);
                }
            }

            r = r_chunks.into_remainder();
            g = g_chunks.into_remainder();
            b = b_chunks.into_remainder();
        } else {
            if is_x86_feature_detected!("avx2") {
                // SAFETY: AVX2, FMA and SSE4.1 were all detected at runtime just above.
                unsafe {
                    (r, g, b) = gamut_map_x86_64_avx2(r, g, b, luminances, saturation_factor);
                }
            }

            // Also picks up what the AVX2 kernel left over.
            // SAFETY: FMA and SSE4.1 were detected at runtime just above.
            unsafe {
                (r, g, b) = gamut_map_x86_64_fma(r, g, b, luminances, saturation_factor);
            }
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        if is_aarch64_feature_detected!("neon") {
            // SAFETY: NEON was detected at runtime just above.
            unsafe {
                (r, g, b) = gamut_map_aarch64_neon(r, g, b, luminances, saturation_factor);
            }
        }
    }

    // Scalar fallback for the samples no SIMD kernel consumed.
    for ((pr, pg), pb) in r.iter_mut().zip(g.iter_mut()).zip(b.iter_mut()) {
        let mapped =
            crate::gamut::map_gamut_generic([*pr, *pg, *pb], luminances, saturation_factor);
        *pr = mapped[0];
        *pg = mapped[1];
        *pb = mapped[2];
    }
}

/// Gamut-maps `r`, `g` and `b` in place, eight samples per step, using unaligned 256-bit
/// loads and stores around `crate::gamut::map_gamut_x86_64_avx2`.
///
/// Returns, for each channel, the trailing samples that do not fill a whole group of eight;
/// those are left untouched for the caller to process.
///
/// # Safety
///
/// The CPU executing this function must support AVX2, FMA and SSE4.1.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[target_feature(enable = "fma")]
#[target_feature(enable = "sse4.1")]
pub(super) unsafe fn gamut_map_x86_64_avx2<'r, 'g, 'b>(
    r: &'r mut [f32],
    g: &'g mut [f32],
    b: &'b mut [f32],
    luminances: [f32; 3],
    saturation_factor: f32,
) -> (&'r mut [f32], &'g mut [f32], &'b mut [f32]) {
    use std::arch::x86_64::{_mm256_loadu_ps, _mm256_storeu_ps};

    /// Number of `f32` lanes in one 256-bit vector.
    const LANES: usize = 8;

    let mut r_chunks = r.chunks_exact_mut(LANES);
    let mut g_chunks = g.chunks_exact_mut(LANES);
    let mut b_chunks = b.chunks_exact_mut(LANES);

    // Every chunk holds exactly `LANES` samples, so each load and store below stays inside
    // its chunk.
    for ((cr, cg), cb) in r_chunks
        .by_ref()
        .zip(g_chunks.by_ref())
        .zip(b_chunks.by_ref())
    {
        let input = [
            _mm256_loadu_ps(cr.as_ptr()),
            _mm256_loadu_ps(cg.as_ptr()),
            _mm256_loadu_ps(cb.as_ptr()),
        ];
        let [out_r, out_g, out_b] =
            crate::gamut::map_gamut_x86_64_avx2(input, luminances, saturation_factor);
        _mm256_storeu_ps(cr.as_mut_ptr(), out_r);
        _mm256_storeu_ps(cg.as_mut_ptr(), out_g);
        _mm256_storeu_ps(cb.as_mut_ptr(), out_b);
    }

    let r_tail = r_chunks.into_remainder();
    let g_tail = g_chunks.into_remainder();
    let b_tail = b_chunks.into_remainder();
    (r_tail, g_tail, b_tail)
}

/// Gamut-maps `r`, `g` and `b` in place, four samples per step, using unaligned 128-bit
/// loads and stores around `crate::gamut::map_gamut_x86_64_fma`.
///
/// Returns, for each channel, the trailing samples that do not fill a whole group of four;
/// those are left untouched for the caller to process.
///
/// # Safety
///
/// The CPU executing this function must support FMA and SSE4.1.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "fma")]
#[target_feature(enable = "sse4.1")]
pub(super) unsafe fn gamut_map_x86_64_fma<'r, 'g, 'b>(
    r: &'r mut [f32],
    g: &'g mut [f32],
    b: &'b mut [f32],
    luminances: [f32; 3],
    saturation_factor: f32,
) -> (&'r mut [f32], &'g mut [f32], &'b mut [f32]) {
    use std::arch::x86_64::{_mm_loadu_ps, _mm_storeu_ps};

    /// Number of `f32` lanes in one 128-bit vector.
    const LANES: usize = 4;

    let mut r_chunks = r.chunks_exact_mut(LANES);
    let mut g_chunks = g.chunks_exact_mut(LANES);
    let mut b_chunks = b.chunks_exact_mut(LANES);

    // Every chunk holds exactly `LANES` samples, so each load and store below stays inside
    // its chunk.
    for ((cr, cg), cb) in r_chunks
        .by_ref()
        .zip(g_chunks.by_ref())
        .zip(b_chunks.by_ref())
    {
        let input = [
            _mm_loadu_ps(cr.as_ptr()),
            _mm_loadu_ps(cg.as_ptr()),
            _mm_loadu_ps(cb.as_ptr()),
        ];
        let [out_r, out_g, out_b] =
            crate::gamut::map_gamut_x86_64_fma(input, luminances, saturation_factor);
        _mm_storeu_ps(cr.as_mut_ptr(), out_r);
        _mm_storeu_ps(cg.as_mut_ptr(), out_g);
        _mm_storeu_ps(cb.as_mut_ptr(), out_b);
    }

    let r_tail = r_chunks.into_remainder();
    let g_tail = g_chunks.into_remainder();
    let b_tail = b_chunks.into_remainder();
    (r_tail, g_tail, b_tail)
}

/// Gamut-maps `r`, `g` and `b` in place, four samples per step, using 128-bit NEON loads
/// and stores around `crate::gamut::map_gamut_aarch64_neon`.
///
/// Returns, for each channel, the trailing samples that do not fill a whole group of four;
/// those are left untouched for the caller to process.
///
/// # Safety
///
/// The CPU executing this function must support NEON.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub(super) unsafe fn gamut_map_aarch64_neon<'r, 'g, 'b>(
    r: &'r mut [f32],
    g: &'g mut [f32],
    b: &'b mut [f32],
    luminances: [f32; 3],
    saturation_factor: f32,
) -> (&'r mut [f32], &'g mut [f32], &'b mut [f32]) {
    use std::arch::aarch64::{vld1q_f32, vst1q_f32};

    /// Number of `f32` lanes in one 128-bit vector.
    const LANES: usize = 4;

    let mut r_chunks = r.chunks_exact_mut(LANES);
    let mut g_chunks = g.chunks_exact_mut(LANES);
    let mut b_chunks = b.chunks_exact_mut(LANES);

    // Every chunk holds exactly `LANES` samples, so each load and store below stays inside
    // its chunk.
    for ((cr, cg), cb) in r_chunks
        .by_ref()
        .zip(g_chunks.by_ref())
        .zip(b_chunks.by_ref())
    {
        let input = [
            vld1q_f32(cr.as_ptr()),
            vld1q_f32(cg.as_ptr()),
            vld1q_f32(cb.as_ptr()),
        ];
        let [out_r, out_g, out_b] =
            crate::gamut::map_gamut_aarch64_neon(input, luminances, saturation_factor);
        vst1q_f32(cr.as_mut_ptr(), out_r);
        vst1q_f32(cg.as_mut_ptr(), out_g);
        vst1q_f32(cb.as_mut_ptr(), out_b);
    }

    let r_tail = r_chunks.into_remainder();
    let g_tail = g_chunks.into_remainder();
    let b_tail = b_chunks.into_remainder();
    (r_tail, g_tail, b_tail)
}
Loading