Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More Kyber AVX2-serialization documentation. #293

Merged
merged 1 commit into from
Jun 5, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 155 additions & 29 deletions libcrux-ml-kem/src/vector/avx2/serialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,17 @@ use super::*;

#[inline(always)]
pub(crate) fn serialize_1(vector: Vec256) -> [u8; 2] {
// Suppose |vector| is laid out as follows (superscript number indicates the
// corresponding bit is duplicated that many times):
//
// 0¹⁵a₀ 0¹⁵b₀ 0¹⁵c₀ 0¹⁵d₀ | 0¹⁵e₀ 0¹⁵f₀ 0¹⁵g₀ 0¹⁵h₀ | ...
//
// We care only about the least significant bit in each lane,
// move it to the most significant position to make it easier to work with.
// |vector| now becomes:
//
// a₀0¹⁵ b₀0¹⁵ c₀0¹⁵ d₀0¹⁵ | e₀0¹⁵ f₀0¹⁵ g₀0¹⁵ h₀0¹⁵ | ↩
// i₀0¹⁵ j₀0¹⁵ k₀0¹⁵ l₀0¹⁵ | m₀0¹⁵ n₀0¹⁵ o₀0¹⁵ p₀0¹⁵
let lsb_to_msb = mm256_slli_epi16::<15>(vector);

// Get the first 8 16-bit elements ...
Expand All @@ -15,16 +24,26 @@ pub(crate) fn serialize_1(vector: Vec256) -> [u8; 2] {
// ... and then pack them into 8-bit values using signed saturation.
// This function packs all the |low_msbs|, and then the high ones.
//
// We shifted by 15 above to take advantage of signed saturation:
//
// low_msbs = a₀0¹⁵ b₀0¹⁵ c₀0¹⁵ d₀0¹⁵ | e₀0¹⁵ f₀0¹⁵ g₀0¹⁵ h₀0¹⁵
// high_msbs = i₀0¹⁵ j₀0¹⁵ k₀0¹⁵ l₀0¹⁵ | m₀0¹⁵ n₀0¹⁵ o₀0¹⁵ p₀0¹⁵
//
// We shifted by 15 above to take advantage of the signed saturation performed
// by mm_packs_epi16:
//
// - if the sign bit of the 16-bit element being packed is 1, the
// corresponding 8-bit element in |msbs| will be 0xFF.
// - if the sign bit of the 16-bit element being packed is 0, the
// corresponding 8-bit element in |msbs| will be 0.
//
// Thus, if, for example, a₀ = 1, e₀ = 1, and p₀ = 1, and every other bit
// is 0, after packing into 8 bit value, |msbs| will look like:
//
// 0xFF 0x00 0x00 0x00 | 0xFF 0x00 0x00 0x00 | 0x00 0x00 0x00 0x00 | 0x00 0x00 0x00 0xFF
let msbs = mm_packs_epi16(low_msbs, high_msbs);

// Now that we have all 16 bits we need conveniently placed in one vector,
// extract them into two bytes.
// Now that every element is either 0xFF or 0x00, we just extract the most
// significant bit from each element and collate them into two bytes.
let bits_packed = mm_movemask_epi8(msbs);

let mut serialized = [0u8; 2];
Expand All @@ -41,8 +60,8 @@ pub(crate) fn deserialize_1(bytes: &[u8]) -> Vec256 {
// duplicate them, and right-shift the 0th element by 0 bits,
// the first element by 1 bit, the second by 2 bits and so on before AND-ing
// with 0x1 to leave only the least signifinicant bit.
// But |_mm256_srlv_epi16| does not exist unfortunately, so we have to resort
// to a workaround.
// But since |_mm256_srlv_epi16| does not exist, so we have to resort to a
// workaround.
//
// Rather than shifting each element by a different amount, we'll multiply
// each element by a value such that the bit we're interested in becomes the most
Expand Down Expand Up @@ -161,28 +180,21 @@ pub(crate) fn serialize_4(vector: Vec256) -> [u8; 8] {

#[inline(always)]
pub(crate) fn deserialize_4(bytes: &[u8]) -> Vec256 {
let shift_lsbs_to_msbs = mm256_set_epi16(
1 << 0,
1 << 4,
1 << 0,
1 << 4,
1 << 0,
1 << 4,
1 << 0,
1 << 4,
1 << 0,
1 << 4,
1 << 0,
1 << 4,
1 << 0,
1 << 4,
1 << 0,
1 << 4,
);

// Every 4 bits from each byte of input should be put into its own 16-bit lane.
// Since |_mm256_srlv_epi16| does not exist, we have to resort to a workaround.
//
// Rather than shifting each element by a different amount, we'll multiply
// each element by a value such that the bits we're interested in become the most
// significant bits (of an 8-bit value).
let coefficients = mm256_set_epi16(
// In this lane, the 4 bits we need to put are already the most
// significant bits of |bytes[7]|.
bytes[7] as i16,
// In this lane, the 4 bits we need to put are the least significant bits,
// so we need to shift the 4 least-significant bits of |bytes[7]| to the
// most significant bits (of an 8-bit value).
bytes[7] as i16,
// and so on ...
bytes[6] as i16,
bytes[6] as i16,
bytes[5] as i16,
Expand All @@ -199,16 +211,53 @@ pub(crate) fn deserialize_4(bytes: &[u8]) -> Vec256 {
bytes[0] as i16,
);

let shift_lsbs_to_msbs = mm256_set_epi16(
// These constants are chosen to shift the bits of the values
// that we loaded into |coefficients|.
1 << 0,
1 << 4,
1 << 0,
1 << 4,
1 << 0,
1 << 4,
1 << 0,
1 << 4,
1 << 0,
1 << 4,
1 << 0,
1 << 4,
1 << 0,
1 << 4,
1 << 0,
1 << 4,
);

let coefficients_in_msb = mm256_mullo_epi16(coefficients, shift_lsbs_to_msbs);

// Once the 4-bit coefficients are in the most significant positions (of
// an 8-bit value), shift them all down by 4.
let coefficients_in_lsb = mm256_srli_epi16::<4>(coefficients_in_msb);

// Zero the remaining bits.
mm256_and_si256(coefficients_in_lsb, mm256_set1_epi16((1 << 4) - 1))
}

#[inline(always)]
pub(crate) fn serialize_5(vector: Vec256) -> [u8; 10] {
let mut serialized = [0u8; 32];

// If |vector| is laid out as follows (superscript number indicates the
// corresponding bit is duplicated that many times):
//
// 0¹¹a₄a₃a₂a₁a₀ 0¹¹b₄b₃b₂b₁b₀ 0¹¹c₄c₃c₂c₁c₀ 0¹¹d₄d₃d₂d₁d₀ | ↩
// 0¹¹e₄e₃e₂e₁e₀ 0¹¹f₄f₃f₂f₁f₀ 0¹¹g₄g₃g₂g₁g₀ 0¹¹h₄h₃h₂h₁h₀ | ↩
//
// |adjacent_2_combined| will be laid out as a series of 32-bit integers,
// as follows:
//
// 0²²b₄b₃b₂b₁b₀a₄a₃a₂a₁a₀ 0²²d₄d₃d₂d₁d₀c₄c₃c₂c₁c₀ | ↩
// 0²²f₄f₃f₂f₁f₀e₄e₃e₂e₁e₀ 0²²h₄h₃h₂h₁h₀g₄g₃g₂g₁g₀ | ↩
// ....
let adjacent_2_combined = mm256_madd_epi16(
vector,
mm256_set_epi16(
Expand All @@ -231,23 +280,60 @@ pub(crate) fn serialize_5(vector: Vec256) -> [u8; 10] {
),
);

// Recall that |adjacent_2_combined| is laid out as follows:
//
// 0²²b₄b₃b₂b₁b₀a₄a₃a₂a₁a₀ 0²²d₄d₃d₂d₁d₀c₄c₃c₂c₁c₀ | ↩
// 0²²f₄f₃f₂f₁f₀e₄e₃e₂e₁e₀ 0²²h₄h₃h₂h₁h₀g₄g₃g₂g₁g₀ | ↩
// ....
//
// This shift results in:
//
// b₄b₃b₂b₁b₀a₄a₃a₂a₁a₀0²² 0²²d₄d₃d₂d₁d₀c₄c₃c₂c₁c₀ | ↩
// f₄f₃f₂f₁f₀e₄e₃e₂e₁e₀0²² 0²²h₄h₃h₂h₁h₀g₄g₃g₂g₁g₀ | ↩
// ....
//
let adjacent_4_combined = mm256_sllv_epi32(
adjacent_2_combined,
mm256_set_epi32(0, 22, 0, 22, 0, 22, 0, 22),
);

// |adjacent_4_combined|, when viewed as 64-bit lanes, is:
//
// 0²²d₄d₃d₂d₁d₀c₄c₃c₂c₁c₀b₄b₃b₂b₁b₀a₄a₃a₂a₁a₀0²² | ↩
// 0²²h₄h₃h₂h₁h₀g₄g₃g₂g₁g₀f₄f₃f₂f₁f₀e₄e₃e₂e₁e₀0²² | ↩
// ...
//
// so we just shift down by 22 bits to remove the least significant 0 bits
// that aren't part of the bits we need.
let adjacent_4_combined = mm256_srli_epi64::<22>(adjacent_4_combined);

// |adjacent_4_combined|, when viewed as a set of 32-bit values, looks like:
//
// 0:0¹²d₄d₃d₂d₁d₀c₄c₃c₂c₁c₀b₄b₃b₂b₁b₀a₄a₃a₂a₁a₀ 1:0³² 2:0¹²h₄h₃h₂h₁h₀g₄g₃g₂g₁g₀f₄f₃f₂f₁f₀e₄e₃e₂e₁e₀ 3:0³² | ↩
//
// To be able to read out the bytes in one go, we need to shifts the bits in
// position 2 to position 1 in each 128-bit lane.
let adjacent_8_combined = mm256_shuffle_epi32::<0b00_00_10_00>(adjacent_4_combined);

// |adjacent_8_combined|, when viewed as a set of 32-bit values, now looks like:
//
// 0¹²d₄d₃d₂d₁d₀c₄c₃c₂c₁c₀b₄b₃b₂b₁b₀a₄a₃a₂a₁a₀ 0¹²h₄h₃h₂h₁h₀g₄g₃g₂g₁g₀f₄f₃f₂f₁f₀e₄e₃e₂e₁e₀ 0³² 0³² | ↩
//
// Once again, we line these bits up by shifting the up values at indices
// 0 and 5 by 12, viewing the resulting register as a set of 64-bit values,
// and then shifting down the 64-bit values by 12 bits.
let adjacent_8_combined = mm256_sllv_epi32(
adjacent_8_combined,
mm256_set_epi32(0, 12, 0, 12, 0, 12, 0, 12),
mm256_set_epi32(0, 0, 0, 12, 0, 0, 0, 12),
);
let adjacent_8_combined = mm256_srli_epi64::<12>(adjacent_8_combined);

// We now have 40 bits starting at position 0 in the lower 128-bit lane, ...
let lower_8 = mm256_castsi256_si128(adjacent_8_combined);
let upper_8 = mm256_extracti128_si256::<1>(adjacent_8_combined);

mm_storeu_bytes_si128(&mut serialized[0..16], lower_8);

// ... and the second 40 bits at position 0 in the upper 128-bit lane
let upper_8 = mm256_extracti128_si256::<1>(adjacent_8_combined);
mm_storeu_bytes_si128(&mut serialized[5..21], upper_8);

serialized[0..10].try_into().unwrap()
Expand Down Expand Up @@ -299,6 +385,19 @@ pub(crate) fn deserialize_5(bytes: &[u8]) -> Vec256 {
pub(crate) fn serialize_10(vector: Vec256) -> [u8; 20] {
let mut serialized = [0u8; 32];

// If |vector| is laid out as follows (superscript number indicates the
// corresponding bit is duplicated that many times):
//
// 0⁶a₉a₈a₇a₆a₅a₄a₃a₂a₁a₀ 0⁶b₉b₈b₇b₆b₅b₄b₃b₂b₁b₀ 0⁶c₉c₈c₇c₆c₅c₄c₃c₂c₁c₀ 0⁶d₉d₈d₇d₆d₅d₄d₃d₂d₁d₀ | ↩
// 0⁶e₉e₈e₇e₆e₅e₄e₃e₂e₁e₀ 0⁶f₉f₈f₇f₆f₅f₄f₃f₂f₁f₀ 0⁶g₉g₈g₇g₆g₅g₄g₃g₂g₁g₀ 0⁶h₉h₈h₇h₆h₅h₄h₃h₂h₁h₀ | ↩
// ...
//
// |adjacent_2_combined| will be laid out as a series of 32-bit integers,
// as follows:
//
// 0¹²b₉b₈b₇b₆b₅b₄b₃b₂b₁b₀a₉a₈a₇a₆a₅a₄a₃a₂a₁a₀ 0¹²d₉d₈d₇d₆d₅d₄d₃d₂d₁d₀c₉c₈c₇c₆c₅c₄c₃c₂c₁c₀ | ↩
// 0¹²f₉f₈f₇f₆f₅f₄f₃f₂f₁f₀e₉e₈e₇e₆e₅e₄e₃e₂e₁e₀ 0¹²h₉h₈h₇h₆h₅h₄h₃h₂h₁h₀g₉g₈g₇g₆g₅g₄g₃g₂g₁g₀ | ↩
// ....
let adjacent_2_combined = mm256_madd_epi16(
vector,
mm256_set_epi16(
Expand All @@ -321,12 +420,37 @@ pub(crate) fn serialize_10(vector: Vec256) -> [u8; 20] {
),
);

// Shifting up the values at the even indices by 12, we get:
//
// b₉b₈b₇b₆b₅b₄b₃b₂b₁b₀a₉a₈a₇a₆a₅a₄a₃a₂a₁a₀0¹² 0¹²d₉d₈d₇d₆d₅d₄d₃d₂d₁d₀c₉c₈c₇c₆c₅c₄c₃c₂c₁c₀ | ↩
// f₉f₈f₇f₆f₅f₄f₃f₂f₁f₀e₉e₈e₇e₆e₅e₄e₃e₂e₁e₀0¹² 0¹²h₉h₈h₇h₆h₅h₄h₃h₂h₁h₀g₉g₈g₇g₆g₅g₄g₃g₂g₁g₀ | ↩
// ...
let adjacent_4_combined = mm256_sllv_epi32(
adjacent_2_combined,
mm256_set_epi32(0, 12, 0, 12, 0, 12, 0, 12),
);

// Viewing this as a set of 64-bit integers we get:
//
// 0¹²d₉d₈d₇d₆d₅d₄d₃d₂d₁d₀c₉c₈c₇c₆c₅c₄c₃c₂c₁c₀b₉b₈b₇b₆b₅b₄b₃b₂b₁b₀a₉a₈a₇a₆a₅a₄a₃a₂a₁a₀0¹² | ↩
// 0¹²h₉h₈h₇h₆h₅h₄h₃h₂h₁h₀g₉g₈g₇g₆g₅g₄g₃g₂g₁g₀f₉f₈f₇f₆f₅f₄f₃f₂f₁f₀e₉e₈e₇e₆e₅e₄e₃e₂e₁e₀0¹² | ↩
// ...
//
// Shifting down by 12 gives us:
//
// 0²⁴d₉d₈d₇d₆d₅d₄d₃d₂d₁d₀c₉c₈c₇c₆c₅c₄c₃c₂c₁c₀b₉b₈b₇b₆b₅b₄b₃b₂b₁b₀a₉a₈a₇a₆a₅a₄a₃a₂a₁a₀ | ↩
// 0²⁴h₉h₈h₇h₆h₅h₄h₃h₂h₁h₀g₉g₈g₇g₆g₅g₄g₃g₂g₁g₀f₉f₈f₇f₆f₅f₄f₃f₂f₁f₀e₉e₈e₇e₆e₅e₄e₃e₂e₁e₀ | ↩
// ...
let adjacent_4_combined = mm256_srli_epi64::<12>(adjacent_4_combined);

// |adjacent_4_combined|, when the bottom and top 128 bit-lanes are grouped
// into bytes, looks like:
//
// 0₇0₆0₅B₄B₃B₂B₁B₀ | ↩
// 0₁₅0₁₄0₁₃B₁₂B₁₁B₁₀B₉B₈ | ↩
//
// In each 128-bit lane, we want to put bytes 8, 9, 10, 11, 12 after
// bytes 0, 1, 2, 3 to allow for sequential reading.
let adjacent_8_combined = mm256_shuffle_epi8(
adjacent_4_combined,
mm256_set_epi8(
Expand All @@ -335,10 +459,12 @@ pub(crate) fn serialize_10(vector: Vec256) -> [u8; 20] {
),
);

// We now have 64 bits starting at position 0 in the lower 128-bit lane, ...
let lower_8 = mm256_castsi256_si128(adjacent_8_combined);
let upper_8 = mm256_extracti128_si256::<1>(adjacent_8_combined);

mm_storeu_bytes_si128(&mut serialized[0..16], lower_8);

// and 64 bits starting at position 0 in the upper 128-bit lane.
let upper_8 = mm256_extracti128_si256::<1>(adjacent_8_combined);
mm_storeu_bytes_si128(&mut serialized[10..26], upper_8);

serialized[0..20].try_into().unwrap()
Expand Down
Loading