Skip to content

Commit

Permalink
Fix for buffer overrun in bit_util
Browse files Browse the repository at this point in the history
  • Loading branch information
michalursa committed May 5, 2022
1 parent d632536 commit 4116d32
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 10 deletions.
71 changes: 61 additions & 10 deletions cpp/src/arrow/compute/exec/util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,38 @@ using bit_util::CountTrailingZeros;

namespace util {

inline uint64_t bit_util::SafeLoadUpTo8Bytes(const uint8_t* bytes, int num_bytes) {
// This will not be correct on big-endian architectures.
#if !ARROW_LITTLE_ENDIAN
ARROW_DCHECK(false);
#endif
ARROW_DCHECK(num_bytes >= 0 && num_bytes <= 8);
if (num_bytes == 8) {
return util::SafeLoad(reinterpret_cast<const uint64_t*>(bytes));
} else {
uint64_t word = 0;
for (int i = 0; i < num_bytes; ++i) {
word |= static_cast<uint64_t>(bytes[i]) << (8 * i);
}
return word;
}
}

inline void bit_util::SafeStoreUpTo8Bytes(uint8_t* bytes, int num_bytes, uint64_t value) {
// This will not be correct on big-endian architectures.
#if !ARROW_LITTLE_ENDIAN
ARROW_DCHECK(false);
#endif
ARROW_DCHECK(num_bytes >= 0 && num_bytes <= 8);
if (num_bytes == 8) {
util::SafeStore(reinterpret_cast<uint64_t*>(bytes), value);
} else {
for (int i = 0; i < num_bytes; ++i) {
bytes[i] = static_cast<uint8_t>(value >> (8 * i));
}
}
}

inline void bit_util::bits_to_indexes_helper(uint64_t word, uint16_t base_index,
int* num_indexes, uint16_t* indexes) {
int n = *num_indexes;
Expand Down Expand Up @@ -86,8 +118,8 @@ void bit_util::bits_to_indexes_internal(int64_t hardware_flags, const int num_bi
#endif
// Optionally process the last partial word with masking out bits outside range
if (tail) {
uint64_t word =
util::SafeLoad(&reinterpret_cast<const uint64_t*>(bits)[num_bits / unroll]);
const uint8_t* bits_tail = bits + (num_bits - tail) / 8;
uint64_t word = SafeLoadUpTo8Bytes(bits_tail, (tail + 7) / 8);
if (bit_to_search == 0) {
word = ~word;
}
Expand All @@ -109,8 +141,7 @@ void bit_util::bits_to_indexes(int bit_to_search, int64_t hardware_flags, int nu
*num_indexes = 0;
uint16_t base_index = 0;
if (bit_offset != 0) {
uint64_t bits_head =
util::SafeLoad(reinterpret_cast<const uint64_t*>(bits)) >> bit_offset;
uint64_t bits_head = bits[0] >> bit_offset;
int bits_in_first_byte = std::min(num_bits, 8 - bit_offset);
bits_to_indexes(bit_to_search, hardware_flags, bits_in_first_byte,
reinterpret_cast<const uint8_t*>(&bits_head), num_indexes, indexes);
Expand Down Expand Up @@ -143,8 +174,7 @@ void bit_util::bits_filter_indexes(int bit_to_search, int64_t hardware_flags,
bit_offset %= 8;
if (bit_offset != 0) {
int num_indexes_head = 0;
uint64_t bits_head =
util::SafeLoad(reinterpret_cast<const uint64_t*>(bits)) >> bit_offset;
uint64_t bits_head = bits[0] >> bit_offset;
int bits_in_first_byte = std::min(num_bits, 8 - bit_offset);
bits_filter_indexes(bit_to_search, hardware_flags, bits_in_first_byte,
reinterpret_cast<const uint8_t*>(&bits_head), input_indexes,
Expand Down Expand Up @@ -185,8 +215,7 @@ void bit_util::bits_to_bytes(int64_t hardware_flags, const int num_bits,
bits += bit_offset / 8;
bit_offset %= 8;
if (bit_offset != 0) {
uint64_t bits_head =
util::SafeLoad(reinterpret_cast<const uint64_t*>(bits)) >> bit_offset;
uint64_t bits_head = bits[0] >> bit_offset;
int bits_in_first_byte = std::min(num_bits, 8 - bit_offset);
bits_to_bytes(hardware_flags, bits_in_first_byte,
reinterpret_cast<const uint8_t*>(&bits_head), bytes);
Expand All @@ -207,7 +236,7 @@ void bit_util::bits_to_bytes(int64_t hardware_flags, const int num_bits,
#endif
// Processing 8 bits at a time
constexpr int unroll = 8;
for (int i = num_processed / unroll; i < (num_bits + unroll - 1) / unroll; ++i) {
for (int i = num_processed / unroll; i < num_bits / unroll; ++i) {
uint8_t bits_next = bits[i];
// Clear the lowest bit and then make 8 copies of remaining 7 bits, each 7 bits apart
// from the previous.
Expand All @@ -219,6 +248,19 @@ void bit_util::bits_to_bytes(int64_t hardware_flags, const int num_bits,
unpacked *= 255;
util::SafeStore(&reinterpret_cast<uint64_t*>(bytes)[i], unpacked);
}
int tail = num_bits % unroll;
if (tail) {
uint8_t bits_next = bits[(num_bits - tail) / unroll];
// Clear the lowest bit and then make 8 copies of remaining 7 bits, each 7 bits apart
// from the previous.
uint64_t unpacked = static_cast<uint64_t>(bits_next & 0xfe) *
((1ULL << 7) | (1ULL << 14) | (1ULL << 21) | (1ULL << 28) |
(1ULL << 35) | (1ULL << 42) | (1ULL << 49));
unpacked |= (bits_next & 1);
unpacked &= 0x0101010101010101ULL;
unpacked *= 255;
SafeStoreUpTo8Bytes(bytes + num_bits - tail, tail, unpacked);
}
}

void bit_util::bytes_to_bits(int64_t hardware_flags, const int num_bits,
Expand Down Expand Up @@ -250,14 +292,23 @@ void bit_util::bytes_to_bits(int64_t hardware_flags, const int num_bits,
#endif
// Process 8 bits at a time
constexpr int unroll = 8;
for (int i = num_processed / unroll; i < (num_bits + unroll - 1) / unroll; ++i) {
for (int i = num_processed / unroll; i < num_bits / unroll; ++i) {
uint64_t bytes_next = util::SafeLoad(&reinterpret_cast<const uint64_t*>(bytes)[i]);
bytes_next &= 0x0101010101010101ULL;
bytes_next |= (bytes_next >> 7); // Pairs of adjacent output bits in individual bytes
bytes_next |= (bytes_next >> 14); // 4 adjacent output bits in individual bytes
bytes_next |= (bytes_next >> 28); // All 8 output bits in the lowest byte
bits[i] = static_cast<uint8_t>(bytes_next & 0xff);
}
int tail = num_bits % unroll;
if (tail) {
uint64_t bytes_next = SafeLoadUpTo8Bytes(bytes + num_bits - tail, tail);
bytes_next &= 0x0101010101010101ULL;
bytes_next |= (bytes_next >> 7); // Pairs of adjacent output bits in individual bytes
bytes_next |= (bytes_next >> 14); // 4 adjacent output bits in individual bytes
bytes_next |= (bytes_next >> 28); // All 8 output bits in the lowest byte
bits[num_bits / 8] = static_cast<uint8_t>(bytes_next & 0xff);
}
}

bool bit_util::are_all_bytes_zero(int64_t hardware_flags, const uint8_t* bytes,
Expand Down
2 changes: 2 additions & 0 deletions cpp/src/arrow/compute/exec/util.h
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,8 @@ class bit_util {
uint32_t num_bytes);

private:
inline static uint64_t SafeLoadUpTo8Bytes(const uint8_t* bytes, int num_bytes);
inline static void SafeStoreUpTo8Bytes(uint8_t* bytes, int num_bytes, uint64_t value);
inline static void bits_to_indexes_helper(uint64_t word, uint16_t base_index,
int* num_indexes, uint16_t* indexes);
inline static void bits_filter_indexes_helper(uint64_t word,
Expand Down

0 comments on commit 4116d32

Please sign in to comment.