Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert some SSE2 intrinsics to const generics #1021

Merged
merged 8 commits into from
Feb 28, 2021
12 changes: 6 additions & 6 deletions crates/core_arch/src/x86/avx512bw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5858,7 +5858,7 @@ pub unsafe fn _mm256_maskz_srai_epi16(k: __mmask16, a: __m256i, imm8: u32) -> __
pub unsafe fn _mm_mask_srai_epi16(src: __m128i, k: __mmask8, a: __m128i, imm8: u32) -> __m128i {
macro_rules! call {
($imm8:expr) => {
_mm_srai_epi16(a, $imm8)
_mm_srai_epi16::<$imm8>(a)
};
}
let shf = constify_imm8_sae!(imm8, call);
Expand All @@ -5875,7 +5875,7 @@ pub unsafe fn _mm_mask_srai_epi16(src: __m128i, k: __mmask8, a: __m128i, imm8: u
pub unsafe fn _mm_maskz_srai_epi16(k: __mmask8, a: __m128i, imm8: u32) -> __m128i {
macro_rules! call {
($imm8:expr) => {
_mm_srai_epi16(a, $imm8)
_mm_srai_epi16::<$imm8>(a)
};
}
let shf = constify_imm8_sae!(imm8, call);
Expand Down Expand Up @@ -7414,7 +7414,7 @@ pub unsafe fn _mm_mask_shufflelo_epi16(
) -> __m128i {
macro_rules! call {
($imm8:expr) => {
_mm_shufflelo_epi16(a, $imm8)
_mm_shufflelo_epi16::<$imm8>(a)
};
}
let shuffle = constify_imm8_sae!(imm8, call);
Expand All @@ -7431,7 +7431,7 @@ pub unsafe fn _mm_mask_shufflelo_epi16(
pub unsafe fn _mm_maskz_shufflelo_epi16(k: __mmask8, a: __m128i, imm8: i32) -> __m128i {
macro_rules! call {
($imm8:expr) => {
_mm_shufflelo_epi16(a, $imm8)
_mm_shufflelo_epi16::<$imm8>(a)
};
}
let shuffle = constify_imm8_sae!(imm8, call);
Expand Down Expand Up @@ -7592,7 +7592,7 @@ pub unsafe fn _mm_mask_shufflehi_epi16(
) -> __m128i {
macro_rules! call {
($imm8:expr) => {
_mm_shufflehi_epi16(a, $imm8)
_mm_shufflehi_epi16::<$imm8>(a)
};
}
let shuffle = constify_imm8_sae!(imm8, call);
Expand All @@ -7609,7 +7609,7 @@ pub unsafe fn _mm_mask_shufflehi_epi16(
pub unsafe fn _mm_maskz_shufflehi_epi16(k: __mmask8, a: __m128i, imm8: i32) -> __m128i {
macro_rules! call {
($imm8:expr) => {
_mm_shufflehi_epi16(a, $imm8)
_mm_shufflehi_epi16::<$imm8>(a)
};
}
let shuffle = constify_imm8_sae!(imm8, call);
Expand Down
8 changes: 4 additions & 4 deletions crates/core_arch/src/x86/avx512f.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19238,7 +19238,7 @@ pub unsafe fn _mm256_maskz_srai_epi32(k: __mmask8, a: __m256i, imm8: u32) -> __m
pub unsafe fn _mm_mask_srai_epi32(src: __m128i, k: __mmask8, a: __m128i, imm8: u32) -> __m128i {
macro_rules! call {
($imm8:expr) => {
_mm_srai_epi32(a, $imm8)
_mm_srai_epi32::<$imm8>(a)
};
}
let shf = constify_imm8_sae!(imm8, call);
Expand All @@ -19255,7 +19255,7 @@ pub unsafe fn _mm_mask_srai_epi32(src: __m128i, k: __mmask8, a: __m128i, imm8: u
pub unsafe fn _mm_maskz_srai_epi32(k: __mmask8, a: __m128i, imm8: u32) -> __m128i {
macro_rules! call {
($imm8:expr) => {
_mm_srai_epi32(a, $imm8)
_mm_srai_epi32::<$imm8>(a)
};
}
let shf = constify_imm8_sae!(imm8, call);
Expand Down Expand Up @@ -22495,7 +22495,7 @@ pub unsafe fn _mm_mask_shuffle_epi32(
) -> __m128i {
macro_rules! call {
($imm8:expr) => {
_mm_shuffle_epi32(a, $imm8)
_mm_shuffle_epi32::<$imm8>(a)
};
}
let r = constify_imm8_sae!(imm8, call);
Expand All @@ -22512,7 +22512,7 @@ pub unsafe fn _mm_mask_shuffle_epi32(
pub unsafe fn _mm_maskz_shuffle_epi32(k: __mmask8, a: __m128i, imm8: _MM_PERM_ENUM) -> __m128i {
macro_rules! call {
($imm8:expr) => {
_mm_shuffle_epi32(a, $imm8)
_mm_shuffle_epi32::<$imm8>(a)
};
}
let r = constify_imm8_sae!(imm8, call);
Expand Down
224 changes: 58 additions & 166 deletions crates/core_arch/src/x86/sse2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -594,16 +594,10 @@ pub unsafe fn _mm_sll_epi64(a: __m128i, count: __m128i) -> __m128i {
#[inline]
#[target_feature(enable = "sse2")]
#[cfg_attr(test, assert_instr(psraw, imm8 = 1))]
#[rustc_args_required_const(1)]
#[rustc_legacy_const_generics(1)]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm_srai_epi16(a: __m128i, imm8: i32) -> __m128i {
let a = a.as_i16x8();
macro_rules! call {
($imm8:expr) => {
transmute(psraiw(a, $imm8))
};
}
constify_imm8!(imm8, call)
pub unsafe fn _mm_srai_epi16<const imm8: i32>(a: __m128i) -> __m128i {
transmute(psraiw(a.as_i16x8(), imm8))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should have a static_assert! to ensure the immediate is between 0 and 255.

Copy link
Member Author

@lqd lqd Feb 28, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like static_assert!(imm8: i32 where imm8 >= 0 && imm8 <= 255); is going to be used quite a bit. We could have a static_assert_imm8!, or more generally static asserts analogous to the various constify_imm* macros?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That sounds like a good idea. Also we should use a single shared Validate struct for those to reduce the MIR size (rather than instantiating a new one in each function).

}

/// Shifts packed 16-bit integers in `a` right by `count` while shifting in sign
Expand All @@ -625,16 +619,10 @@ pub unsafe fn _mm_sra_epi16(a: __m128i, count: __m128i) -> __m128i {
#[inline]
#[target_feature(enable = "sse2")]
#[cfg_attr(test, assert_instr(psrad, imm8 = 1))]
#[rustc_args_required_const(1)]
#[rustc_legacy_const_generics(1)]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm_srai_epi32(a: __m128i, imm8: i32) -> __m128i {
let a = a.as_i32x4();
macro_rules! call {
($imm8:expr) => {
transmute(psraid(a, $imm8))
};
}
constify_imm8!(imm8, call)
pub unsafe fn _mm_srai_epi32<const imm8: i32>(a: __m128i) -> __m128i {
transmute(psraid(a.as_i32x4(), imm8))
}

/// Shifts packed 32-bit integers in `a` right by `count` while shifting in sign
Expand Down Expand Up @@ -1461,60 +1449,21 @@ pub unsafe fn _mm_movemask_epi8(a: __m128i) -> i32 {
#[inline]
#[target_feature(enable = "sse2")]
#[cfg_attr(test, assert_instr(pshufd, imm8 = 9))]
#[rustc_args_required_const(1)]
#[rustc_legacy_const_generics(1)]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm_shuffle_epi32(a: __m128i, imm8: i32) -> __m128i {
// simd_shuffleX requires that its selector parameter be made up of
// constant values, but we can't enforce that here. In spirit, we need
// to write a `match` on all possible values of a byte, and for each value,
// hard-code the correct `simd_shuffleX` call using only constants. We
// then hope for LLVM to do the rest.
//
// Of course, that's... awful. So we try to use macros to do it for us.
let imm8 = (imm8 & 0xFF) as u8;
pub unsafe fn _mm_shuffle_epi32<const imm8: i32>(a: __m128i) -> __m128i {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this a breaking change?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, same thing happened in a previous PR. We're kinda allowed to file it as a "bugfix" though, because the API shouldn't accept a value outside 0..=255.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, similarly to _mm_shuffle_ps, discussed here: #1018 (comment), which was acceptable as a bug fix, if I understand correctly?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, we could make it still accept any value and do &0xFF if we wanted to preserve the bug.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, whichever you all prefer. My understanding is that a crater run will ultimately help decide whether the bug should be preserved, if need be?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see. If this is only going to break things where imm is not a constant, then I'm okay with that and agree it is a bug fix. I am only concerned about function params being changed.

I guess I would also be concerned with how this renders in rustdoc, but whatever the case, that seems fixable.

Copy link
Member Author

@lqd lqd Feb 27, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some initial discussion about this feature can also be seen in this other Zulip link, leading to the rustc PR linked above. Some more specific info about stdarch and the migration effort is in #1022.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that #[rustc_args_required_const] already required the value to be a constant.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does that mean this change cannot break anything?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is one specific situation where it can break: you can't use an expression derived from generic arguments in a const item (e.g. const or a const generic parameter). But #[rustc_args_required_const] accepted this.

Example

// Illustrates the one situation where moving an intrinsic's immediate from a
// `#[rustc_args_required_const]` argument to a const generic parameter can
// break existing callers.
fn foo<const X: i32>(a: __m128) {
	// This was previously accepted but will now fail to compile: `X + 1` is an
	// expression derived from a generic argument, which cannot be used as a
	// const generic parameter.
	stdarch::_mm_shuffle_ps(a, a, X + 1);

	// This still works: the const parameter `X` is passed on its own rather
	// than as part of a derived expression.
	stdarch::_mm_shuffle_ps(a, a, X);
}

static_assert!(imm8: i32 where imm8 >= 0 && imm8 <= 255);
let a = a.as_i32x4();

macro_rules! shuffle_done {
($x01:expr, $x23:expr, $x45:expr, $x67:expr) => {
simd_shuffle4(a, a, [$x01, $x23, $x45, $x67])
};
}
macro_rules! shuffle_x67 {
($x01:expr, $x23:expr, $x45:expr) => {
match (imm8 >> 6) & 0b11 {
0b00 => shuffle_done!($x01, $x23, $x45, 0),
0b01 => shuffle_done!($x01, $x23, $x45, 1),
0b10 => shuffle_done!($x01, $x23, $x45, 2),
_ => shuffle_done!($x01, $x23, $x45, 3),
}
};
}
macro_rules! shuffle_x45 {
($x01:expr, $x23:expr) => {
match (imm8 >> 4) & 0b11 {
0b00 => shuffle_x67!($x01, $x23, 0),
0b01 => shuffle_x67!($x01, $x23, 1),
0b10 => shuffle_x67!($x01, $x23, 2),
_ => shuffle_x67!($x01, $x23, 3),
}
};
}
macro_rules! shuffle_x23 {
($x01:expr) => {
match (imm8 >> 2) & 0b11 {
0b00 => shuffle_x45!($x01, 0),
0b01 => shuffle_x45!($x01, 1),
0b10 => shuffle_x45!($x01, 2),
_ => shuffle_x45!($x01, 3),
}
};
}
let x: i32x4 = match imm8 & 0b11 {
0b00 => shuffle_x23!(0),
0b01 => shuffle_x23!(1),
0b10 => shuffle_x23!(2),
_ => shuffle_x23!(3),
};
let x: i32x4 = simd_shuffle4(
a,
a,
[
imm8 as u32 & 0b11,
(imm8 as u32 >> 2) & 0b11,
(imm8 as u32 >> 4) & 0b11,
(imm8 as u32 >> 6) & 0b11,
],
);
transmute(x)
}

Expand All @@ -1528,53 +1477,25 @@ pub unsafe fn _mm_shuffle_epi32(a: __m128i, imm8: i32) -> __m128i {
#[inline]
#[target_feature(enable = "sse2")]
#[cfg_attr(test, assert_instr(pshufhw, imm8 = 9))]
#[rustc_args_required_const(1)]
#[rustc_legacy_const_generics(1)]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm_shufflehi_epi16(a: __m128i, imm8: i32) -> __m128i {
// See _mm_shuffle_epi32.
let imm8 = (imm8 & 0xFF) as u8;
pub unsafe fn _mm_shufflehi_epi16<const imm8: i32>(a: __m128i) -> __m128i {
static_assert!(imm8: i32 where imm8 >= 0 && imm8 <= 255);
let a = a.as_i16x8();
macro_rules! shuffle_done {
($x01:expr, $x23:expr, $x45:expr, $x67:expr) => {
simd_shuffle8(a, a, [0, 1, 2, 3, $x01 + 4, $x23 + 4, $x45 + 4, $x67 + 4])
};
}
macro_rules! shuffle_x67 {
($x01:expr, $x23:expr, $x45:expr) => {
match (imm8 >> 6) & 0b11 {
0b00 => shuffle_done!($x01, $x23, $x45, 0),
0b01 => shuffle_done!($x01, $x23, $x45, 1),
0b10 => shuffle_done!($x01, $x23, $x45, 2),
_ => shuffle_done!($x01, $x23, $x45, 3),
}
};
}
macro_rules! shuffle_x45 {
($x01:expr, $x23:expr) => {
match (imm8 >> 4) & 0b11 {
0b00 => shuffle_x67!($x01, $x23, 0),
0b01 => shuffle_x67!($x01, $x23, 1),
0b10 => shuffle_x67!($x01, $x23, 2),
_ => shuffle_x67!($x01, $x23, 3),
}
};
}
macro_rules! shuffle_x23 {
($x01:expr) => {
match (imm8 >> 2) & 0b11 {
0b00 => shuffle_x45!($x01, 0),
0b01 => shuffle_x45!($x01, 1),
0b10 => shuffle_x45!($x01, 2),
_ => shuffle_x45!($x01, 3),
}
};
}
let x: i16x8 = match imm8 & 0b11 {
0b00 => shuffle_x23!(0),
0b01 => shuffle_x23!(1),
0b10 => shuffle_x23!(2),
_ => shuffle_x23!(3),
};
let x: i16x8 = simd_shuffle8(
a,
a,
[
0,
1,
2,
3,
(imm8 as u32 & 0b11) + 4,
((imm8 as u32 >> 2) & 0b11) + 4,
((imm8 as u32 >> 4) & 0b11) + 4,
((imm8 as u32 >> 6) & 0b11) + 4,
],
);
transmute(x)
}

Expand All @@ -1588,54 +1509,25 @@ pub unsafe fn _mm_shufflehi_epi16(a: __m128i, imm8: i32) -> __m128i {
#[inline]
#[target_feature(enable = "sse2")]
#[cfg_attr(test, assert_instr(pshuflw, imm8 = 9))]
#[rustc_args_required_const(1)]
#[rustc_legacy_const_generics(1)]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm_shufflelo_epi16(a: __m128i, imm8: i32) -> __m128i {
// See _mm_shuffle_epi32.
let imm8 = (imm8 & 0xFF) as u8;
pub unsafe fn _mm_shufflelo_epi16<const imm8: i32>(a: __m128i) -> __m128i {
static_assert!(imm8: i32 where imm8 >= 0 && imm8 <= 255);
let a = a.as_i16x8();

macro_rules! shuffle_done {
($x01:expr, $x23:expr, $x45:expr, $x67:expr) => {
simd_shuffle8(a, a, [$x01, $x23, $x45, $x67, 4, 5, 6, 7])
};
}
macro_rules! shuffle_x67 {
($x01:expr, $x23:expr, $x45:expr) => {
match (imm8 >> 6) & 0b11 {
0b00 => shuffle_done!($x01, $x23, $x45, 0),
0b01 => shuffle_done!($x01, $x23, $x45, 1),
0b10 => shuffle_done!($x01, $x23, $x45, 2),
_ => shuffle_done!($x01, $x23, $x45, 3),
}
};
}
macro_rules! shuffle_x45 {
($x01:expr, $x23:expr) => {
match (imm8 >> 4) & 0b11 {
0b00 => shuffle_x67!($x01, $x23, 0),
0b01 => shuffle_x67!($x01, $x23, 1),
0b10 => shuffle_x67!($x01, $x23, 2),
_ => shuffle_x67!($x01, $x23, 3),
}
};
}
macro_rules! shuffle_x23 {
($x01:expr) => {
match (imm8 >> 2) & 0b11 {
0b00 => shuffle_x45!($x01, 0),
0b01 => shuffle_x45!($x01, 1),
0b10 => shuffle_x45!($x01, 2),
_ => shuffle_x45!($x01, 3),
}
};
}
let x: i16x8 = match imm8 & 0b11 {
0b00 => shuffle_x23!(0),
0b01 => shuffle_x23!(1),
0b10 => shuffle_x23!(2),
_ => shuffle_x23!(3),
};
let x: i16x8 = simd_shuffle8(
a,
a,
[
imm8 as u32 & 0b11,
(imm8 as u32 >> 2) & 0b11,
(imm8 as u32 >> 4) & 0b11,
(imm8 as u32 >> 6) & 0b11,
4,
5,
6,
7,
],
);
transmute(x)
}

Expand Down Expand Up @@ -3594,7 +3486,7 @@ mod tests {

#[simd_test(enable = "sse2")]
unsafe fn test_mm_srai_epi16() {
let r = _mm_srai_epi16(_mm_set1_epi16(-1), 1);
let r = _mm_srai_epi16::<1>(_mm_set1_epi16(-1));
assert_eq_m128i(r, _mm_set1_epi16(-1));
}

Expand All @@ -3608,7 +3500,7 @@ mod tests {

#[simd_test(enable = "sse2")]
unsafe fn test_mm_srai_epi32() {
let r = _mm_srai_epi32(_mm_set1_epi32(-1), 1);
let r = _mm_srai_epi32::<1>(_mm_set1_epi32(-1));
assert_eq_m128i(r, _mm_set1_epi32(-1));
}

Expand Down Expand Up @@ -4107,23 +3999,23 @@ mod tests {
#[simd_test(enable = "sse2")]
unsafe fn test_mm_shuffle_epi32() {
let a = _mm_setr_epi32(5, 10, 15, 20);
let r = _mm_shuffle_epi32(a, 0b00_01_01_11);
let r = _mm_shuffle_epi32::<0b00_01_01_11>(a);
let e = _mm_setr_epi32(20, 10, 10, 5);
assert_eq_m128i(r, e);
}

#[simd_test(enable = "sse2")]
unsafe fn test_mm_shufflehi_epi16() {
let a = _mm_setr_epi16(1, 2, 3, 4, 5, 10, 15, 20);
let r = _mm_shufflehi_epi16(a, 0b00_01_01_11);
let r = _mm_shufflehi_epi16::<0b00_01_01_11>(a);
let e = _mm_setr_epi16(1, 2, 3, 4, 20, 10, 10, 5);
assert_eq_m128i(r, e);
}

#[simd_test(enable = "sse2")]
unsafe fn test_mm_shufflelo_epi16() {
let a = _mm_setr_epi16(5, 10, 15, 20, 1, 2, 3, 4);
let r = _mm_shufflelo_epi16(a, 0b00_01_01_11);
let r = _mm_shufflelo_epi16::<0b00_01_01_11>(a);
let e = _mm_setr_epi16(20, 10, 10, 5, 1, 2, 3, 4);
assert_eq_m128i(r, e);
}
Expand Down