Skip to content

Commit 628f7a3

Browse files
fix: make avx512 fp16 a runtime check (#1884)
Makes [avx512 fp16](https://networkbuilders.intel.com/solutionslibrary/intel-avx-512-fp16-instruction-set-for-intel-xeon-processor-based-products-technology-guide) support a runtime check. This allows binaries compiled with the avx512fp16 feature to run on hardware that doesn't support this feature (e.g. x86 CPUs before Sapphire Rapids). The check does not add a performance penalty: ``` albertlockett@albert-ubuntu-saphire:~/lance/rust/lance-linalg$ TARGET_TIME=55 cargo bench \ --bench dot \ -F avx512fp16 Compiling lance-linalg v0.9.9 (/home/albertlockett/lance/rust/lance-linalg) Finished bench [optimized + debuginfo] target(s) in 55.77s Running benches/dot.rs (/home/albertlockett/lance/rust/target/release/deps/dot-f42dee3ad61e0342) Gnuplot not found, using plotters backend Dot(half::binary16::f16, arrow_artiy) time: [2.5228 s 2.5230 s 2.5233 s] change: [-0.0915% -0.0641% -0.0381%] (p = 0.00 < 0.10) Change within noise threshold. Dot(half::binary16::f16, auto-vectorization) time: [167.90 ms 168.05 ms 168.34 ms] change: [-0.3945% -0.1097% +0.1731%] (p = 0.47 > 0.10) No change in performance detected. Dot(f16, SIMD) time: [167.03 ms 167.22 ms 167.50 ms] change: [-1.4038% -0.9215% -0.4951%] (p = 0.00 < 0.10) Change within noise threshold. ```
1 parent 2f67cf9 commit 628f7a3

File tree

7 files changed

+119
-42
lines changed

7 files changed

+119
-42
lines changed

rust/lance-core/src/utils.rs

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
// limitations under the License.
1414

1515
pub mod address;
16+
pub mod cpu;
1617
pub mod deletion;
1718
pub mod mask;
1819
pub mod testing;

rust/lance-core/src/utils/cpu.rs

+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// Copyright 2024 Lance Developers.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#[cfg(target_arch = "x86_64")]
16+
pub mod x86 {
17+
use core::arch::x86_64::__cpuid;
18+
19+
use lazy_static::lazy_static;
20+
21+
#[inline]
22+
fn check_flag(x: usize, position: u32) -> bool {
23+
x & (1 << position) != 0
24+
}
25+
26+
lazy_static! {
27+
pub static ref AVX512_F16_SUPPORTED: bool = {
28+
// this macro does many OS checks/etc. to determine if allowed to use AVX512
29+
if !is_x86_feature_detected!("avx512f") {
30+
return false;
31+
}
32+
33+
// EAX=7, ECX=0: Extended Features (includes AVX512)
34+
// More info on calling CPUID can be found here (section 1.4)
35+
// https://www.intel.com/content/dam/develop/external/us/en/documents/architecture-instruction-set-extensions-programming-reference.pdf
36+
let ext_cpuid_result = unsafe { __cpuid(7) };
37+
check_flag(ext_cpuid_result.edx as usize, 23)
38+
};
39+
}
40+
}

rust/lance-linalg/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ arrow-schema = { workspace = true }
1616
futures = { workspace = true }
1717
half = { workspace = true }
1818
lance-arrow = { workspace = true }
19+
lance-core = { workspace = true }
1920
log = { workspace = true }
2021
num_cpus = { workspace = true }
2122
num-traits = { workspace = true }

rust/lance-linalg/benches/dot.rs

+11-2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
// limitations under the License.
1414

1515
use std::iter::{repeat_with, Sum};
16+
use std::time::Duration;
1617

1718
use arrow_array::{
1819
types::{Float16Type, Float32Type, Float64Type},
@@ -131,18 +132,26 @@ fn bench_distance(c: &mut Criterion) {
131132
run_bench::<Float64Type>(c);
132133
}
133134

135+
fn bench_time() -> Duration {
136+
let secs: u64 = option_env!("TARGET_TIME").unwrap_or("5").parse().unwrap();
137+
Duration::from_secs(secs)
138+
}
139+
134140
#[cfg(target_os = "linux")]
135141
criterion_group!(
136142
name=benches;
137-
config = Criterion::default().significance_level(0.1).sample_size(10)
143+
config = Criterion::default()
144+
.significance_level(0.1)
145+
.sample_size(10)
146+
.measurement_time(bench_time())
138147
.with_profiler(PProfProfiler::new(100, Output::Flamegraph(None)));
139148
targets = bench_distance);
140149

141150
// Non-linux version does not support pprof.
142151
#[cfg(not(target_os = "linux"))]
143152
criterion_group!(
144153
name=benches;
145-
config = Criterion::default().significance_level(0.1).sample_size(10);
154+
config = Criterion::default().significance_level(0.1).sample_size(10).measurement_time(bench_time());
146155
targets = bench_distance);
147156

148157
criterion_main!(benches);

rust/lance-linalg/src/distance/dot.rs

+12-4
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ use lance_arrow::{ArrowFloatType, FloatArray, FloatToArrayType};
2828
use num_traits::real::Real;
2929
use num_traits::AsPrimitive;
3030

31+
#[cfg(all(target_os = "linux", feature = "avx512fp16", target_arch = "x86_64"))]
32+
use lance_core::utils::cpu::x86::AVX512_F16_SUPPORTED;
33+
3134
use crate::simd::{
3235
f32::{f32x16, f32x8},
3336
SIMD,
@@ -112,13 +115,18 @@ mod kernel {
112115
impl Dot for Float16Type {
113116
#[inline]
114117
fn dot(x: &[f16], y: &[f16]) -> f32 {
115-
#[cfg(any(
116-
all(target_os = "macos", target_feature = "neon"),
117-
all(target_os = "linux", feature = "avx512fp16")
118-
))]
118+
#[cfg(all(target_os = "macos", target_feature = "neon"))]
119119
unsafe {
120120
kernel::dot_f16(x.as_ptr(), y.as_ptr(), x.len() as u32)
121121
}
122+
123+
#[cfg(all(target_os = "linux", feature = "avx512fp16", target_arch = "x86_64"))]
124+
if *AVX512_F16_SUPPORTED {
125+
unsafe { kernel::dot_f16(x.as_ptr(), y.as_ptr(), x.len() as u32) }
126+
} else {
127+
dot_scalar::<f16, 16>(x, y)
128+
}
129+
122130
#[cfg(not(any(
123131
all(target_os = "macos", target_feature = "neon"),
124132
all(target_os = "linux", feature = "avx512fp16")

rust/lance-linalg/src/distance/l2.rs

+10-4
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ use half::{bf16, f16};
2929
use lance_arrow::{bfloat16::BFloat16Type, ArrowFloatType, FloatArray, FloatToArrayType};
3030
use num_traits::{AsPrimitive, Float};
3131

32+
#[cfg(all(target_os = "linux", feature = "avx512fp16", target_arch = "x86_64"))]
33+
use lance_core::utils::cpu::x86::AVX512_F16_SUPPORTED;
34+
3235
use crate::simd::{
3336
f32::{f32x16, f32x8},
3437
SIMD,
@@ -119,13 +122,16 @@ mod kernel {
119122
impl L2 for Float16Type {
120123
#[inline]
121124
fn l2(x: &[f16], y: &[f16]) -> f32 {
122-
#[cfg(any(
123-
all(target_os = "macos", target_feature = "neon"),
124-
all(target_os = "linux", feature = "avx512fp16")
125-
))]
125+
#[cfg(all(target_os = "macos", target_feature = "neon"))]
126126
unsafe {
127127
kernel::l2_f16(x.as_ptr(), y.as_ptr(), x.len() as u32)
128128
}
129+
#[cfg(all(target_os = "linux", feature = "avx512fp16", target_arch = "x86_64"))]
130+
if *AVX512_F16_SUPPORTED {
131+
unsafe { kernel::l2_f16(x.as_ptr(), y.as_ptr(), x.len() as u32) }
132+
} else {
133+
l2_scalar::<f16, 16>(x, y)
134+
}
129135
#[cfg(not(any(
130136
all(target_os = "macos", target_feature = "neon"),
131137
all(target_os = "linux", feature = "avx512fp16")

rust/lance-linalg/src/distance/norm_l2.rs

+44-32
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ use std::iter::Sum;
1717
use half::{bf16, f16};
1818
use num_traits::{AsPrimitive, Float};
1919

20+
#[cfg(all(target_os = "linux", feature = "avx512fp16", target_arch = "x86_64"))]
21+
use lance_core::utils::cpu::x86::AVX512_F16_SUPPORTED;
22+
2023
use crate::simd::{
2124
f32::{f32x16, f32x8},
2225
SIMD,
@@ -45,47 +48,56 @@ mod kernel {
4548
impl Normalize<f16> for &[f16] {
4649
// #[inline]
4750
fn norm_l2(&self) -> f32 {
48-
#[cfg(any(
49-
all(target_os = "macos", target_feature = "neon"),
50-
feature = "avx512fp16"
51-
))]
51+
#[cfg(all(target_os = "macos", target_feature = "neon"))]
5252
unsafe {
5353
kernel::norm_l2_f16(self.as_ptr(), self.len() as u32)
5454
}
55+
56+
#[cfg(all(target_os = "linux", feature = "avx512fp16", target_arch = "x86_64"))]
57+
if *AVX512_F16_SUPPORTED {
58+
unsafe { kernel::norm_l2_f16(self.as_ptr(), self.len() as u32) }
59+
} else {
60+
norm_l2_f16_impl(self)
61+
}
62+
5563
#[cfg(not(any(
5664
all(target_os = "macos", target_feature = "neon"),
5765
feature = "avx512fp16"
5866
)))]
59-
{
60-
// Please run `cargo bench --bench norm_l2" on Apple Silicon when
61-
// change the following code.
62-
const LANES: usize = 16;
63-
let chunks = self.chunks_exact(LANES);
64-
let sum = if chunks.remainder().is_empty() {
65-
0.0
66-
} else {
67-
chunks
68-
.remainder()
69-
.iter()
70-
.map(|v| v.to_f32().powi(2))
71-
.sum::<f32>()
72-
};
73-
74-
let mut sums: [f32; LANES] = [0_f32; LANES];
75-
for chk in chunks {
76-
// Convert to f32
77-
let mut f32_vals: [f32; LANES] = [0_f32; LANES];
78-
for i in 0..LANES {
79-
f32_vals[i] = chk[i].to_f32();
80-
}
81-
// Vectorized multiply
82-
for i in 0..LANES {
83-
sums[i] += f32_vals[i].powi(2);
84-
}
85-
}
86-
(sums.iter().copied().sum::<f32>() + sum).sqrt()
67+
norm_l2_f16_impl(self)
68+
}
69+
}
70+
71+
#[inline]
72+
#[cfg(not(all(target_os = "macos", target_feature = "neon")))]
73+
fn norm_l2_f16_impl(arr: &[f16]) -> f32 {
74+
// Please run `cargo bench --bench norm_l2" on Apple Silicon when
75+
// change the following code.
76+
const LANES: usize = 16;
77+
let chunks = arr.chunks_exact(LANES);
78+
let sum = if chunks.remainder().is_empty() {
79+
0.0
80+
} else {
81+
chunks
82+
.remainder()
83+
.iter()
84+
.map(|v| v.to_f32().powi(2))
85+
.sum::<f32>()
86+
};
87+
88+
let mut sums: [f32; LANES] = [0_f32; LANES];
89+
for chk in chunks {
90+
// Convert to f32
91+
let mut f32_vals: [f32; LANES] = [0_f32; LANES];
92+
for i in 0..LANES {
93+
f32_vals[i] = chk[i].to_f32();
94+
}
95+
// Vectorized multiply
96+
for i in 0..LANES {
97+
sums[i] += f32_vals[i].powi(2);
8798
}
8899
}
100+
(sums.iter().copied().sum::<f32>() + sum).sqrt()
89101
}
90102

91103
impl Normalize<bf16> for &[bf16] {

0 commit comments

Comments
 (0)