Convert some SSE2 intrinsics to const generics #1021
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -594,16 +594,10 @@ pub unsafe fn _mm_sll_epi64(a: __m128i, count: __m128i) -> __m128i { | |
#[inline] | ||
#[target_feature(enable = "sse2")] | ||
#[cfg_attr(test, assert_instr(psraw, imm8 = 1))] | ||
#[rustc_args_required_const(1)] | ||
#[rustc_legacy_const_generics(1)] | ||
#[stable(feature = "simd_x86", since = "1.27.0")] | ||
pub unsafe fn _mm_srai_epi16(a: __m128i, imm8: i32) -> __m128i { | ||
let a = a.as_i16x8(); | ||
macro_rules! call { | ||
($imm8:expr) => { | ||
transmute(psraiw(a, $imm8)) | ||
}; | ||
} | ||
constify_imm8!(imm8, call) | ||
pub unsafe fn _mm_srai_epi16<const imm8: i32>(a: __m128i) -> __m128i { | ||
transmute(psraiw(a.as_i16x8(), imm8)) | ||
} | ||
|
||
/// Shifts packed 16-bit integers in `a` right by `count` while shifting in sign | ||
|
@@ -625,16 +619,10 @@ pub unsafe fn _mm_sra_epi16(a: __m128i, count: __m128i) -> __m128i { | |
#[inline] | ||
#[target_feature(enable = "sse2")] | ||
#[cfg_attr(test, assert_instr(psrad, imm8 = 1))] | ||
#[rustc_args_required_const(1)] | ||
#[rustc_legacy_const_generics(1)] | ||
#[stable(feature = "simd_x86", since = "1.27.0")] | ||
pub unsafe fn _mm_srai_epi32(a: __m128i, imm8: i32) -> __m128i { | ||
let a = a.as_i32x4(); | ||
macro_rules! call { | ||
($imm8:expr) => { | ||
transmute(psraid(a, $imm8)) | ||
}; | ||
} | ||
constify_imm8!(imm8, call) | ||
pub unsafe fn _mm_srai_epi32<const imm8: i32>(a: __m128i) -> __m128i { | ||
transmute(psraid(a.as_i32x4(), imm8)) | ||
} | ||
|
||
/// Shifts packed 32-bit integers in `a` right by `count` while shifting in sign | ||
|
@@ -1461,60 +1449,21 @@ pub unsafe fn _mm_movemask_epi8(a: __m128i) -> i32 { | |
#[inline] | ||
#[target_feature(enable = "sse2")] | ||
#[cfg_attr(test, assert_instr(pshufd, imm8 = 9))] | ||
#[rustc_args_required_const(1)] | ||
#[rustc_legacy_const_generics(1)] | ||
#[stable(feature = "simd_x86", since = "1.27.0")] | ||
pub unsafe fn _mm_shuffle_epi32(a: __m128i, imm8: i32) -> __m128i { | ||
// simd_shuffleX requires that its selector parameter be made up of | ||
// constant values, but we can't enforce that here. In spirit, we need | ||
// to write a `match` on all possible values of a byte, and for each value, | ||
// hard-code the correct `simd_shuffleX` call using only constants. We | ||
// then hope for LLVM to do the rest. | ||
// | ||
// Of course, that's... awful. So we try to use macros to do it for us. | ||
let imm8 = (imm8 & 0xFF) as u8; | ||
pub unsafe fn _mm_shuffle_epi32<const imm8: i32>(a: __m128i) -> __m128i { | ||
Review comments on this line:

Is this a breaking change?

Yep, same thing happened in a previous PR. We're kinda allowed to file it as "bugfix" though, because the API shouldn't accept a value outside 0..=255.

Yes, similarly to

Well, we could make it still accept any value and do

Sure, whichever you all prefer. My understanding is that a crater run will ultimately help decide whether the bug should be preserved, if need be?

Ah, I see. If this is only going to break things where imm is not a constant, then I'm okay with that and agree it is a bug fix. I am only concerned about function params being changed. I guess I would also be concerned with how this renders in rustdoc, but whatever the case, that seems fixable.

Some initial discussion about this feature can also be seen in this other zulip link, leading to the rustc PR linked above. Some more specific info about stdarch and the migration effort is in #1022.

Note that

Does that mean this change cannot break anything?

There is one specific situation where it can break: you can't use an expression derived from generic arguments in a const item (e.g.

Example

    fn foo<const X: i32>(a: __m128) {
        // This was previously accepted but will now fail to compile.
        stdarch::_mm_shuffle_ps(a, a, X + 1);
        // This still works.
        stdarch::_mm_shuffle_ps(a, a, X);
    } |
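The breakage described in the last comment can be reproduced without any SIMD types. The following is a minimal sketch, assuming a stable toolchain without `generic_const_exprs`; `add_imm`, `forward`, and `concrete` are hypothetical names used only for illustration and are not part of stdarch.

```rust
// Hypothetical stand-in for an intrinsic whose immediate is a const generic.
pub fn add_imm<const IMM: i32>(x: i32) -> i32 {
    x + IMM
}

// Forwarding the caller's const parameter unchanged is accepted.
pub fn forward<const X: i32>(x: i32) -> i32 {
    add_imm::<X>(x)
    // `add_imm::<{ X + 1 }>(x)` would be rejected here on stable Rust,
    // because the argument is an expression that depends on the generic
    // parameter `X`.
}

// With no generic parameter involved, a constant expression is accepted.
pub fn concrete(x: i32) -> i32 {
    add_imm::<{ 2 + 1 }>(x)
}
```

Only callers that derive the immediate from their own const generic parameter are affected; callers passing a literal or a plain constant expression should be unaffected.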
||
static_assert!(imm8: i32 where imm8 >= 0 && imm8 <= 255); | ||
let a = a.as_i32x4(); | ||
|
||
macro_rules! shuffle_done { | ||
($x01:expr, $x23:expr, $x45:expr, $x67:expr) => { | ||
simd_shuffle4(a, a, [$x01, $x23, $x45, $x67]) | ||
}; | ||
} | ||
macro_rules! shuffle_x67 { | ||
($x01:expr, $x23:expr, $x45:expr) => { | ||
match (imm8 >> 6) & 0b11 { | ||
0b00 => shuffle_done!($x01, $x23, $x45, 0), | ||
0b01 => shuffle_done!($x01, $x23, $x45, 1), | ||
0b10 => shuffle_done!($x01, $x23, $x45, 2), | ||
_ => shuffle_done!($x01, $x23, $x45, 3), | ||
} | ||
}; | ||
} | ||
macro_rules! shuffle_x45 { | ||
($x01:expr, $x23:expr) => { | ||
match (imm8 >> 4) & 0b11 { | ||
0b00 => shuffle_x67!($x01, $x23, 0), | ||
0b01 => shuffle_x67!($x01, $x23, 1), | ||
0b10 => shuffle_x67!($x01, $x23, 2), | ||
_ => shuffle_x67!($x01, $x23, 3), | ||
} | ||
}; | ||
} | ||
macro_rules! shuffle_x23 { | ||
($x01:expr) => { | ||
match (imm8 >> 2) & 0b11 { | ||
0b00 => shuffle_x45!($x01, 0), | ||
0b01 => shuffle_x45!($x01, 1), | ||
0b10 => shuffle_x45!($x01, 2), | ||
_ => shuffle_x45!($x01, 3), | ||
} | ||
}; | ||
} | ||
let x: i32x4 = match imm8 & 0b11 { | ||
0b00 => shuffle_x23!(0), | ||
0b01 => shuffle_x23!(1), | ||
0b10 => shuffle_x23!(2), | ||
_ => shuffle_x23!(3), | ||
}; | ||
let x: i32x4 = simd_shuffle4( | ||
a, | ||
a, | ||
[ | ||
imm8 as u32 & 0b11, | ||
(imm8 as u32 >> 2) & 0b11, | ||
(imm8 as u32 >> 4) & 0b11, | ||
(imm8 as u32 >> 6) & 0b11, | ||
], | ||
); | ||
transmute(x) | ||
} | ||
|
||
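The rewritten `_mm_shuffle_epi32` decodes four 2-bit lane selectors directly from the immediate instead of expanding nested `match` macros. As a rough illustration of that decoding, here is a pure-Rust model on plain arrays. It is a sketch only: `shuffle_epi32_model` is a hypothetical name, it does not call the intrinsic, and it omits the range check on the immediate.

```rust
/// Hypothetical model of the lane selection in `_mm_shuffle_epi32`:
/// output lane `i` is the input lane named by bits `2*i..2*i+2` of `IMM8`.
pub fn shuffle_epi32_model<const IMM8: i32>(a: [i32; 4]) -> [i32; 4] {
    let imm = IMM8 as u32;
    [
        a[(imm & 0b11) as usize],
        a[((imm >> 2) & 0b11) as usize],
        a[((imm >> 4) & 0b11) as usize],
        a[((imm >> 6) & 0b11) as usize],
    ]
}
```

With the immediate `0b00_01_01_11` and the input `[5, 10, 15, 20]` from the PR's own test, the model selects lanes 3, 1, 1, 0 and yields `[20, 10, 10, 5]`, which matches the expected vector in `test_mm_shuffle_epi32`.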
|
@@ -1528,53 +1477,25 @@ pub unsafe fn _mm_shuffle_epi32(a: __m128i, imm8: i32) -> __m128i { | |
#[inline] | ||
#[target_feature(enable = "sse2")] | ||
#[cfg_attr(test, assert_instr(pshufhw, imm8 = 9))] | ||
#[rustc_args_required_const(1)] | ||
#[rustc_legacy_const_generics(1)] | ||
#[stable(feature = "simd_x86", since = "1.27.0")] | ||
pub unsafe fn _mm_shufflehi_epi16(a: __m128i, imm8: i32) -> __m128i { | ||
// See _mm_shuffle_epi32. | ||
let imm8 = (imm8 & 0xFF) as u8; | ||
pub unsafe fn _mm_shufflehi_epi16<const imm8: i32>(a: __m128i) -> __m128i { | ||
static_assert!(imm8: i32 where imm8 >= 0 && imm8 <= 255); | ||
let a = a.as_i16x8(); | ||
macro_rules! shuffle_done { | ||
($x01:expr, $x23:expr, $x45:expr, $x67:expr) => { | ||
simd_shuffle8(a, a, [0, 1, 2, 3, $x01 + 4, $x23 + 4, $x45 + 4, $x67 + 4]) | ||
}; | ||
} | ||
macro_rules! shuffle_x67 { | ||
($x01:expr, $x23:expr, $x45:expr) => { | ||
match (imm8 >> 6) & 0b11 { | ||
0b00 => shuffle_done!($x01, $x23, $x45, 0), | ||
0b01 => shuffle_done!($x01, $x23, $x45, 1), | ||
0b10 => shuffle_done!($x01, $x23, $x45, 2), | ||
_ => shuffle_done!($x01, $x23, $x45, 3), | ||
} | ||
}; | ||
} | ||
macro_rules! shuffle_x45 { | ||
($x01:expr, $x23:expr) => { | ||
match (imm8 >> 4) & 0b11 { | ||
0b00 => shuffle_x67!($x01, $x23, 0), | ||
0b01 => shuffle_x67!($x01, $x23, 1), | ||
0b10 => shuffle_x67!($x01, $x23, 2), | ||
_ => shuffle_x67!($x01, $x23, 3), | ||
} | ||
}; | ||
} | ||
macro_rules! shuffle_x23 { | ||
($x01:expr) => { | ||
match (imm8 >> 2) & 0b11 { | ||
0b00 => shuffle_x45!($x01, 0), | ||
0b01 => shuffle_x45!($x01, 1), | ||
0b10 => shuffle_x45!($x01, 2), | ||
_ => shuffle_x45!($x01, 3), | ||
} | ||
}; | ||
} | ||
let x: i16x8 = match imm8 & 0b11 { | ||
0b00 => shuffle_x23!(0), | ||
0b01 => shuffle_x23!(1), | ||
0b10 => shuffle_x23!(2), | ||
_ => shuffle_x23!(3), | ||
}; | ||
let x: i16x8 = simd_shuffle8( | ||
a, | ||
a, | ||
[ | ||
0, | ||
1, | ||
2, | ||
3, | ||
(imm8 as u32 & 0b11) + 4, | ||
((imm8 as u32 >> 2) & 0b11) + 4, | ||
((imm8 as u32 >> 4) & 0b11) + 4, | ||
((imm8 as u32 >> 6) & 0b11) + 4, | ||
], | ||
); | ||
transmute(x) | ||
} | ||
|
||
|
@@ -1588,54 +1509,25 @@ pub unsafe fn _mm_shufflehi_epi16(a: __m128i, imm8: i32) -> __m128i { | |
#[inline] | ||
#[target_feature(enable = "sse2")] | ||
#[cfg_attr(test, assert_instr(pshuflw, imm8 = 9))] | ||
#[rustc_args_required_const(1)] | ||
#[rustc_legacy_const_generics(1)] | ||
#[stable(feature = "simd_x86", since = "1.27.0")] | ||
pub unsafe fn _mm_shufflelo_epi16(a: __m128i, imm8: i32) -> __m128i { | ||
// See _mm_shuffle_epi32. | ||
let imm8 = (imm8 & 0xFF) as u8; | ||
pub unsafe fn _mm_shufflelo_epi16<const imm8: i32>(a: __m128i) -> __m128i { | ||
static_assert!(imm8: i32 where imm8 >= 0 && imm8 <= 255); | ||
let a = a.as_i16x8(); | ||
|
||
macro_rules! shuffle_done { | ||
($x01:expr, $x23:expr, $x45:expr, $x67:expr) => { | ||
simd_shuffle8(a, a, [$x01, $x23, $x45, $x67, 4, 5, 6, 7]) | ||
}; | ||
} | ||
macro_rules! shuffle_x67 { | ||
($x01:expr, $x23:expr, $x45:expr) => { | ||
match (imm8 >> 6) & 0b11 { | ||
0b00 => shuffle_done!($x01, $x23, $x45, 0), | ||
0b01 => shuffle_done!($x01, $x23, $x45, 1), | ||
0b10 => shuffle_done!($x01, $x23, $x45, 2), | ||
_ => shuffle_done!($x01, $x23, $x45, 3), | ||
} | ||
}; | ||
} | ||
macro_rules! shuffle_x45 { | ||
($x01:expr, $x23:expr) => { | ||
match (imm8 >> 4) & 0b11 { | ||
0b00 => shuffle_x67!($x01, $x23, 0), | ||
0b01 => shuffle_x67!($x01, $x23, 1), | ||
0b10 => shuffle_x67!($x01, $x23, 2), | ||
_ => shuffle_x67!($x01, $x23, 3), | ||
} | ||
}; | ||
} | ||
macro_rules! shuffle_x23 { | ||
($x01:expr) => { | ||
match (imm8 >> 2) & 0b11 { | ||
0b00 => shuffle_x45!($x01, 0), | ||
0b01 => shuffle_x45!($x01, 1), | ||
0b10 => shuffle_x45!($x01, 2), | ||
_ => shuffle_x45!($x01, 3), | ||
} | ||
}; | ||
} | ||
let x: i16x8 = match imm8 & 0b11 { | ||
0b00 => shuffle_x23!(0), | ||
0b01 => shuffle_x23!(1), | ||
0b10 => shuffle_x23!(2), | ||
_ => shuffle_x23!(3), | ||
}; | ||
let x: i16x8 = simd_shuffle8( | ||
a, | ||
a, | ||
[ | ||
imm8 as u32 & 0b11, | ||
(imm8 as u32 >> 2) & 0b11, | ||
(imm8 as u32 >> 4) & 0b11, | ||
(imm8 as u32 >> 6) & 0b11, | ||
4, | ||
5, | ||
6, | ||
7, | ||
], | ||
); | ||
transmute(x) | ||
} | ||
|
||
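`_mm_shufflehi_epi16` and `_mm_shufflelo_epi16` use the same four 2-bit selectors but apply them to only one half of the vector: the high variant adds 4 to each selector and passes lanes 0 to 3 through, while the low variant does the reverse. A small array-based sketch of that difference follows; the function names are hypothetical and the code models the index arrays in the diff rather than calling the intrinsics.

```rust
/// Splits an 8-bit immediate into four 2-bit selectors, lowest bits first.
fn selectors<const IMM8: i32>() -> [usize; 4] {
    let imm = IMM8 as u32;
    [
        (imm & 0b11) as usize,
        ((imm >> 2) & 0b11) as usize,
        ((imm >> 4) & 0b11) as usize,
        ((imm >> 6) & 0b11) as usize,
    ]
}

/// Hypothetical model of `_mm_shufflehi_epi16`: lanes 0..4 pass through,
/// lanes 4..8 are picked from the high half.
pub fn shufflehi_epi16_model<const IMM8: i32>(a: [i16; 8]) -> [i16; 8] {
    let s = selectors::<IMM8>();
    [a[0], a[1], a[2], a[3], a[s[0] + 4], a[s[1] + 4], a[s[2] + 4], a[s[3] + 4]]
}

/// Hypothetical model of `_mm_shufflelo_epi16`: lanes 0..4 are picked from
/// the low half, lanes 4..8 pass through.
pub fn shufflelo_epi16_model<const IMM8: i32>(a: [i16; 8]) -> [i16; 8] {
    let s = selectors::<IMM8>();
    [a[s[0]], a[s[1]], a[s[2]], a[s[3]], a[4], a[5], a[6], a[7]]
}
```

The inputs and expected outputs of `test_mm_shufflehi_epi16` and `test_mm_shufflelo_epi16` in this PR can be used to check the model.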
|
@@ -3594,7 +3486,7 @@ mod tests { | |
|
||
#[simd_test(enable = "sse2")] | ||
unsafe fn test_mm_srai_epi16() { | ||
let r = _mm_srai_epi16(_mm_set1_epi16(-1), 1); | ||
let r = _mm_srai_epi16::<1>(_mm_set1_epi16(-1)); | ||
assert_eq_m128i(r, _mm_set1_epi16(-1)); | ||
} | ||
|
||
|
@@ -3608,7 +3500,7 @@ mod tests { | |
|
||
#[simd_test(enable = "sse2")] | ||
unsafe fn test_mm_srai_epi32() { | ||
let r = _mm_srai_epi32(_mm_set1_epi32(-1), 1); | ||
let r = _mm_srai_epi32::<1>(_mm_set1_epi32(-1)); | ||
assert_eq_m128i(r, _mm_set1_epi32(-1)); | ||
} | ||
|
||
|
@@ -4107,23 +3999,23 @@ mod tests { | |
#[simd_test(enable = "sse2")] | ||
unsafe fn test_mm_shuffle_epi32() { | ||
let a = _mm_setr_epi32(5, 10, 15, 20); | ||
let r = _mm_shuffle_epi32(a, 0b00_01_01_11); | ||
let r = _mm_shuffle_epi32::<0b00_01_01_11>(a); | ||
let e = _mm_setr_epi32(20, 10, 10, 5); | ||
assert_eq_m128i(r, e); | ||
} | ||
|
||
#[simd_test(enable = "sse2")] | ||
unsafe fn test_mm_shufflehi_epi16() { | ||
let a = _mm_setr_epi16(1, 2, 3, 4, 5, 10, 15, 20); | ||
let r = _mm_shufflehi_epi16(a, 0b00_01_01_11); | ||
let r = _mm_shufflehi_epi16::<0b00_01_01_11>(a); | ||
let e = _mm_setr_epi16(1, 2, 3, 4, 20, 10, 10, 5); | ||
assert_eq_m128i(r, e); | ||
} | ||
|
||
#[simd_test(enable = "sse2")] | ||
unsafe fn test_mm_shufflelo_epi16() { | ||
let a = _mm_setr_epi16(5, 10, 15, 20, 1, 2, 3, 4); | ||
let r = _mm_shufflelo_epi16(a, 0b00_01_01_11); | ||
let r = _mm_shufflelo_epi16::<0b00_01_01_11>(a); | ||
let e = _mm_setr_epi16(20, 10, 10, 5, 1, 2, 3, 4); | ||
assert_eq_m128i(r, e); | ||
} | ||
|
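The test changes above move the immediate from a positional argument to a turbofish argument. Because the functions keep `#[rustc_legacy_const_generics(1)]`, existing callers using the positional spelling are expected to keep compiling. The sketch below calls the same intrinsic both ways; it assumes an x86_64 target and a `std::arch` that exposes the const-generic signature, and `srai_both_spellings` is a hypothetical helper name.

```rust
// Hypothetical helper, not part of stdarch: calls the same intrinsic with
// both spellings and returns the lanes of each result.
#[cfg(target_arch = "x86_64")]
pub fn srai_both_spellings() -> ([i16; 8], [i16; 8]) {
    use std::arch::x86_64::{__m128i, _mm_set1_epi16, _mm_srai_epi16};
    use std::mem::transmute;

    // SAFETY: SSE2 is part of the x86_64 baseline, and `__m128i` and
    // `[i16; 8]` are both 16 bytes with no invalid bit patterns.
    unsafe {
        let a: __m128i = _mm_set1_epi16(-4);
        // New spelling: the immediate is a const generic argument.
        let turbofish = _mm_srai_epi16::<1>(a);
        // Legacy spelling: the compiler rewrites the second argument into
        // the const generic argument.
        let positional = _mm_srai_epi16(a, 1);
        (
            transmute::<__m128i, [i16; 8]>(turbofish),
            transmute::<__m128i, [i16; 8]>(positional),
        )
    }
}
```

Both calls perform an arithmetic right shift by one on every 16-bit lane, so each lane holding -4 becomes -2 in both results.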
This should have a `static_assert!` to ensure the immediate is between 0 and 255.
Looks like `static_assert!(imm8: i32 where imm8 >= 0 && imm8 <= 255);` is going to be used quite a bit. Could we have a `static_assert_imm8!`, or more generally static asserts analogous to the various `constify_imm*` macros?
That sounds like a good idea. Also, we should use a single shared `Validate` struct for those to reduce the MIR size (rather than instantiating a new one in each function).
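One way the shared-validator idea could look is sketched below. This is an assumption about the shape of such a helper, not the code stdarch ended up with: `ValidateImm8`, `static_assert_imm8!`, and `low_two_bits` are hypothetical names, and the sketch relies on `assert!` being usable in const contexts (Rust 1.57 or later).

```rust
/// Hypothetical shared validator: one generic struct reused by every
/// function, instead of a fresh struct defined inside each function body.
pub struct ValidateImm8<const IMM8: i32>;

impl<const IMM8: i32> ValidateImm8<IMM8> {
    /// Evaluated at compile time for each concrete `IMM8` that is used;
    /// an out-of-range immediate turns into a compile error.
    pub const VALID: () = assert!(IMM8 >= 0 && IMM8 <= 255, "IMM8 must be in 0..=255");
}

/// Hypothetical convenience macro wrapping the shared validator.
macro_rules! static_assert_imm8 {
    ($imm:ident) => {
        let () = ValidateImm8::<$imm>::VALID;
    };
}

/// Example user: checks the immediate, then extracts the first selector.
pub fn low_two_bits<const IMM8: i32>() -> i32 {
    static_assert_imm8!(IMM8);
    IMM8 & 0b11
}
```

Calling `low_two_bits::<0b0111>()` compiles and returns `0b11`, whereas an instantiation such as `low_two_bits::<256>()` should fail to compile because evaluating `VALID` panics in a const context.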