add simd strstr
skadilover committed Aug 27, 2024
1 parent 763c19c commit a5994bc
Showing 3 changed files with 160 additions and 0 deletions.
137 changes: 137 additions & 0 deletions velox/common/base/SimdUtil.cpp
@@ -112,4 +112,141 @@ bool initializeSimdUtil() {

static bool FB_ANONYMOUS_VARIABLE(g_simdConstants) = initializeSimdUtil();

namespace detail {

#if XSIMD_WITH_SSE4_2
using CharVector = xsimd::batch<uint8_t, xsimd::sse4_2>;
#elif XSIMD_WITH_NEON
using CharVector = xsimd::batch<uint8_t, xsimd::neon>;
#endif

const int kPageSize = sysconf(_SC_PAGESIZE);
FOLLY_ALWAYS_INLINE bool pageSafe(const void* const ptr) {
return ((kPageSize - 1) & reinterpret_cast<std::uintptr_t>(ptr)) <=
kPageSize - CharVector::size;
}
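// Illustration (assumed values, not part of the original code): with a
// 4096-byte page and a 16-byte CharVector, the page offset of ptr must be at
// most 4080, so an unaligned 16-byte load starting at ptr stays within the
// same page.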

template <bool compiled, size_t compiledNeedleSize>
size_t FOLLY_ALWAYS_INLINE smidStrstrMemcmp(
const char* s,
size_t n,
const char* needle,
size_t needleSize) {
static_assert(compiledNeedleSize >= 2);
VELOX_CHECK_GT(needleSize, 1);
VELOX_CHECK_GT(n, 0);
auto first = CharVector::broadcast(needle[0]);
auto last = CharVector::broadcast(needle[needleSize - 1]);
size_t i = 0;
// Fast path for page-safe data.
// It's safe to over-read a CharVector if all the data is in the same page.
// see: https://mudongliang.github.io/x86/html/file_module_x86_id_208.html
// While executing in 16-bit addressing mode, a linear address for a 128-bit
// data access that overlaps the end of a 16-bit segment is not allowed and is
// defined as reserved behavior. A specific processor implementation may or
// may not generate a general-protection exception (#GP) in this situation,
// and the address that spans the end of the segment may or may not wrap
// around to the beginning of the segment.
for (; i <= n - needleSize && pageSafe(s + i + needleSize - 1) &&
pageSafe(s + i);
i += CharVector::size) {
auto blockFirst = CharVector::load_unaligned(s + i);
auto blockLast = CharVector::load_unaligned(s + i + needleSize - 1);

const auto eqFirst = (first == blockFirst);
const auto eqLast = (last == blockLast);

auto mask = toBitMask(eqFirst && eqLast);

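// Each set bit in mask marks an offset where both the first and the last
// needle characters match; only those offsets need a full memcmp.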
while (mask != 0) {
const auto bitpos = __builtin_ctz(mask);
if constexpr (compiled) {
if constexpr (compiledNeedleSize == 2) {
return i + bitpos;
}
if (memcmp(s + i + bitpos + 1, needle + 1, compiledNeedleSize - 2) ==
0) {
return i + bitpos;
}
} else {
if (memcmp(s + i + bitpos + 1, needle + 1, needleSize - 2) == 0) {
return i + bitpos;
}
}
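// Clear the lowest set bit and move on to the next candidate offset.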
mask = mask & (mask - 1);
}
}
// Fallback path: plain scan for the tail and for data that is not page-safe.
for (; i <= n - needleSize; ++i) {
if constexpr (compiled) {
if (memcmp(s + i, needle, compiledNeedleSize) == 0) {
return i;
}
} else {
if (memcmp(s + i, needle, needleSize) == 0) {
return i;
}
}
}

return std::string::npos;
};

} // namespace detail

/// A faster implementation of C strstr(), about 2x faster than string_view's
/// find() as measured by TpchLikeBenchmark. An xsimd batch is used to compare
/// the first and last characters of the needle, and a fixed-size memcmp to
/// compare the remaining characters. Inlining this in the header file would be
/// slightly faster.
size_t simdStrstr(const char* s, size_t n, const char* needle, size_t k) {
size_t result = std::string::npos;

if (n < k) {
return result;
}

switch (k) {
case 0:
return 0;

case 1: {
const char* res = strchr(s, needle[0]);

return (res != nullptr) ? res - s : std::string::npos;
}
#define FIXED_MEM_STRSTR(size) \
case size: \
result = detail::smidStrstrMemcmp<true, size>(s, n, needle, size); \
break;
FIXED_MEM_STRSTR(2)
FIXED_MEM_STRSTR(3)
FIXED_MEM_STRSTR(4)
FIXED_MEM_STRSTR(5)
FIXED_MEM_STRSTR(6)
FIXED_MEM_STRSTR(7)
FIXED_MEM_STRSTR(8)
FIXED_MEM_STRSTR(9)
FIXED_MEM_STRSTR(10)
FIXED_MEM_STRSTR(11)
FIXED_MEM_STRSTR(12)
FIXED_MEM_STRSTR(13)
FIXED_MEM_STRSTR(14)
FIXED_MEM_STRSTR(15)
FIXED_MEM_STRSTR(16)
FIXED_MEM_STRSTR(17)
FIXED_MEM_STRSTR(18)
default:
result = detail::smidStrstrMemcmp<false, 2>(s, n, needle, k);
break;
}
#undef FIXED_MEM_STRSTR
// load_unaligned is used for better performance, so result may be bigger than
// n - k.
if (result <= n - k) {
return result;
} else {
return std::string::npos;
}
}

} // namespace facebook::velox::simd
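For context (not part of this diff), a minimal usage sketch: simdStrstr follows the same contract as std::string_view::find, returning the position of the first match or std::string::npos. The wrapper name findSubstring below is hypothetical.

#include "velox/common/base/SimdUtil.h"

#include <string_view>

// Returns the position of the first occurrence of needle in haystack,
// or std::string::npos if it does not occur.
size_t findSubstring(std::string_view haystack, std::string_view needle) {
  return facebook::velox::simd::simdStrstr(
      haystack.data(), haystack.size(), needle.data(), needle.size());
}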
2 changes: 2 additions & 0 deletions velox/common/base/SimdUtil.h
@@ -497,6 +497,8 @@ xsimd::batch<T, A> reinterpretBatch(xsimd::batch<U, A>, const A& = {});
template <typename A = xsimd::default_arch>
inline bool memEqualUnsafe(const void* x, const void* y, int32_t size);

size_t simdStrstr(const char* s, size_t n, const char* needle, size_t k);

} // namespace facebook::velox::simd

#include "velox/common/base/SimdUtil-inl.h"
21 changes: 21 additions & 0 deletions velox/common/base/tests/SimdUtilTest.cpp
@@ -491,4 +491,25 @@ TEST_F(SimdUtilTest, memcpyTime) {
LOG(INFO) << "simd=" << simd << " sys=" << sys;
}

TEST_F(SimdUtilTest, testSimdStrStr) {
// 52 chars.
std::string s1 = "aabbccddeeffgghhiijjkkllmmnnooppqqrrssttuuvvwwxxyyzz";
std::string s2 = "aabbccddeeffgghhiijjkkllmmnnooppqqrrssttuuvvwwxxyyzz";
std::string s3 = "xxx";
auto test = [](char* text, size_t size, char* needle, size_t k) {
ASSERT_EQ(
simd::simdStrstr(text, size, needle, k),
std::string_view(text, size).find(std::string_view(needle, k)));
};
// Match cases: every substring of s2 is also a substring of s1.
for (int i = 0; i < 20; i++) {
for (int k = 0; k < 28; k++) {
char* data = s2.data() + i;
test(s1.data(), s1.size(), data, k);
}
}
// No-match case: "xxx" is not in s1.
test(s1.data(), s1.size(), s3.data(), s3.size());
}

} // namespace
