From 4116d322ab040d85fdc1443dcd7487da0e7d196e Mon Sep 17 00:00:00 2001 From: michalursa Date: Wed, 4 May 2022 17:32:49 -0700 Subject: [PATCH] Fix for buffer overrun in bit_util --- cpp/src/arrow/compute/exec/util.cc | 71 +++++++++++++++++++++++++----- cpp/src/arrow/compute/exec/util.h | 2 + 2 files changed, 63 insertions(+), 10 deletions(-) diff --git a/cpp/src/arrow/compute/exec/util.cc b/cpp/src/arrow/compute/exec/util.cc index ef56e6128a363..c38f7803fb937 100644 --- a/cpp/src/arrow/compute/exec/util.cc +++ b/cpp/src/arrow/compute/exec/util.cc @@ -29,6 +29,38 @@ using bit_util::CountTrailingZeros; namespace util { +inline uint64_t bit_util::SafeLoadUpTo8Bytes(const uint8_t* bytes, int num_bytes) { + // This will not be correct on big-endian architectures. +#if !ARROW_LITTLE_ENDIAN + ARROW_DCHECK(false); +#endif + ARROW_DCHECK(num_bytes >= 0 && num_bytes <= 8); + if (num_bytes == 8) { + return util::SafeLoad(reinterpret_cast(bytes)); + } else { + uint64_t word = 0; + for (int i = 0; i < num_bytes; ++i) { + word |= static_cast(bytes[i]) << (8 * i); + } + return word; + } +} + +inline void bit_util::SafeStoreUpTo8Bytes(uint8_t* bytes, int num_bytes, uint64_t value) { + // This will not be correct on big-endian architectures. +#if !ARROW_LITTLE_ENDIAN + ARROW_DCHECK(false); +#endif + ARROW_DCHECK(num_bytes >= 0 && num_bytes <= 8); + if (num_bytes == 8) { + util::SafeStore(reinterpret_cast(bytes), value); + } else { + for (int i = 0; i < num_bytes; ++i) { + bytes[i] = static_cast(value >> (8 * i)); + } + } +} + inline void bit_util::bits_to_indexes_helper(uint64_t word, uint16_t base_index, int* num_indexes, uint16_t* indexes) { int n = *num_indexes; @@ -86,8 +118,8 @@ void bit_util::bits_to_indexes_internal(int64_t hardware_flags, const int num_bi #endif // Optionally process the last partial word with masking out bits outside range if (tail) { - uint64_t word = - util::SafeLoad(&reinterpret_cast(bits)[num_bits / unroll]); + const uint8_t* bits_tail = bits + (num_bits - tail) / 8; + uint64_t word = SafeLoadUpTo8Bytes(bits_tail, (tail + 7) / 8); if (bit_to_search == 0) { word = ~word; } @@ -109,8 +141,7 @@ void bit_util::bits_to_indexes(int bit_to_search, int64_t hardware_flags, int nu *num_indexes = 0; uint16_t base_index = 0; if (bit_offset != 0) { - uint64_t bits_head = - util::SafeLoad(reinterpret_cast(bits)) >> bit_offset; + uint64_t bits_head = bits[0] >> bit_offset; int bits_in_first_byte = std::min(num_bits, 8 - bit_offset); bits_to_indexes(bit_to_search, hardware_flags, bits_in_first_byte, reinterpret_cast(&bits_head), num_indexes, indexes); @@ -143,8 +174,7 @@ void bit_util::bits_filter_indexes(int bit_to_search, int64_t hardware_flags, bit_offset %= 8; if (bit_offset != 0) { int num_indexes_head = 0; - uint64_t bits_head = - util::SafeLoad(reinterpret_cast(bits)) >> bit_offset; + uint64_t bits_head = bits[0] >> bit_offset; int bits_in_first_byte = std::min(num_bits, 8 - bit_offset); bits_filter_indexes(bit_to_search, hardware_flags, bits_in_first_byte, reinterpret_cast(&bits_head), input_indexes, @@ -185,8 +215,7 @@ void bit_util::bits_to_bytes(int64_t hardware_flags, const int num_bits, bits += bit_offset / 8; bit_offset %= 8; if (bit_offset != 0) { - uint64_t bits_head = - util::SafeLoad(reinterpret_cast(bits)) >> bit_offset; + uint64_t bits_head = bits[0] >> bit_offset; int bits_in_first_byte = std::min(num_bits, 8 - bit_offset); bits_to_bytes(hardware_flags, bits_in_first_byte, reinterpret_cast(&bits_head), bytes); @@ -207,7 +236,7 @@ void bit_util::bits_to_bytes(int64_t hardware_flags, const int num_bits, #endif // Processing 8 bits at a time constexpr int unroll = 8; - for (int i = num_processed / unroll; i < (num_bits + unroll - 1) / unroll; ++i) { + for (int i = num_processed / unroll; i < num_bits / unroll; ++i) { uint8_t bits_next = bits[i]; // Clear the lowest bit and then make 8 copies of remaining 7 bits, each 7 bits apart // from the previous. @@ -219,6 +248,19 @@ void bit_util::bits_to_bytes(int64_t hardware_flags, const int num_bits, unpacked *= 255; util::SafeStore(&reinterpret_cast(bytes)[i], unpacked); } + int tail = num_bits % unroll; + if (tail) { + uint8_t bits_next = bits[(num_bits - tail) / unroll]; + // Clear the lowest bit and then make 8 copies of remaining 7 bits, each 7 bits apart + // from the previous. + uint64_t unpacked = static_cast(bits_next & 0xfe) * + ((1ULL << 7) | (1ULL << 14) | (1ULL << 21) | (1ULL << 28) | + (1ULL << 35) | (1ULL << 42) | (1ULL << 49)); + unpacked |= (bits_next & 1); + unpacked &= 0x0101010101010101ULL; + unpacked *= 255; + SafeStoreUpTo8Bytes(bytes + num_bits - tail, tail, unpacked); + } } void bit_util::bytes_to_bits(int64_t hardware_flags, const int num_bits, @@ -250,7 +292,7 @@ void bit_util::bytes_to_bits(int64_t hardware_flags, const int num_bits, #endif // Process 8 bits at a time constexpr int unroll = 8; - for (int i = num_processed / unroll; i < (num_bits + unroll - 1) / unroll; ++i) { + for (int i = num_processed / unroll; i < num_bits / unroll; ++i) { uint64_t bytes_next = util::SafeLoad(&reinterpret_cast(bytes)[i]); bytes_next &= 0x0101010101010101ULL; bytes_next |= (bytes_next >> 7); // Pairs of adjacent output bits in individual bytes @@ -258,6 +300,15 @@ void bit_util::bytes_to_bits(int64_t hardware_flags, const int num_bits, bytes_next |= (bytes_next >> 28); // All 8 output bits in the lowest byte bits[i] = static_cast(bytes_next & 0xff); } + int tail = num_bits % unroll; + if (tail) { + uint64_t bytes_next = SafeLoadUpTo8Bytes(bytes + num_bits - tail, tail); + bytes_next &= 0x0101010101010101ULL; + bytes_next |= (bytes_next >> 7); // Pairs of adjacent output bits in individual bytes + bytes_next |= (bytes_next >> 14); // 4 adjacent output bits in individual bytes + bytes_next |= (bytes_next >> 28); // All 8 output bits in the lowest byte + bits[num_bits / 8] = static_cast(bytes_next & 0xff); + } } bool bit_util::are_all_bytes_zero(int64_t hardware_flags, const uint8_t* bytes, diff --git a/cpp/src/arrow/compute/exec/util.h b/cpp/src/arrow/compute/exec/util.h index b1a417e1c370d..31549c4ae827f 100644 --- a/cpp/src/arrow/compute/exec/util.h +++ b/cpp/src/arrow/compute/exec/util.h @@ -194,6 +194,8 @@ class bit_util { uint32_t num_bytes); private: + inline static uint64_t SafeLoadUpTo8Bytes(const uint8_t* bytes, int num_bytes); + inline static void SafeStoreUpTo8Bytes(uint8_t* bytes, int num_bytes, uint64_t value); inline static void bits_to_indexes_helper(uint64_t word, uint16_t base_index, int* num_indexes, uint16_t* indexes); inline static void bits_filter_indexes_helper(uint64_t word,