From a5994bc3d1120ebffece2ee74faee715a78ac975 Mon Sep 17 00:00:00 2001
From: "hengjiang.ly"
Date: Tue, 6 Aug 2024 14:35:18 +0800
Subject: [PATCH] add simd strstr

---
 velox/common/base/SimdUtil.cpp           | 137 +++++++++++++++++++++++
 velox/common/base/SimdUtil.h             |   2 +
 velox/common/base/tests/SimdUtilTest.cpp |  21 ++++
 3 files changed, 160 insertions(+)

diff --git a/velox/common/base/SimdUtil.cpp b/velox/common/base/SimdUtil.cpp
index 03576ac31ec43..3f7d0de91d5a6 100644
--- a/velox/common/base/SimdUtil.cpp
+++ b/velox/common/base/SimdUtil.cpp
@@ -112,4 +112,141 @@ bool initializeSimdUtil() {
 
 static bool FB_ANONYMOUS_VARIABLE(g_simdConstants) = initializeSimdUtil();
 
+namespace detail {
+
+// NOTE(review): template arguments reconstructed — confirm the arch tags
+// against the xsimd version in use.
+#if XSIMD_WITH_SSE4_2
+using CharVector = xsimd::batch<uint8_t, xsimd::sse4_2>;
+#elif XSIMD_WITH_NEON
+using CharVector = xsimd::batch<uint8_t, xsimd::neon>;
+#endif
+
+const int kPageSize = sysconf(_SC_PAGESIZE);
+
+// True when a full CharVector load starting at 'ptr' cannot cross a page
+// boundary, i.e. an over-read past the logical end of the data is safe.
+FOLLY_ALWAYS_INLINE bool pageSafe(const void* const ptr) {
+  return ((kPageSize - 1) & reinterpret_cast<uintptr_t>(ptr)) <=
+      kPageSize - CharVector::size;
+}
+
+// Vectorized substring search: compare the first and last needle bytes for a
+// whole CharVector of candidate positions at once, then confirm each hit with
+// memcmp on the middle bytes. When 'compiled' is true, 'compiledNeedleSize'
+// is the needle length known at compile time, so the memcmp length folds to
+// a constant. Requires needleSize > 1 and n > 0. Returns std::string::npos
+// when not found; because of the page-safe over-read, a returned position may
+// exceed n - needleSize and must be range-checked by the caller.
+template <bool compiled, size_t compiledNeedleSize>
+size_t FOLLY_ALWAYS_INLINE simdStrstrMemcmp(
+    const char* s,
+    size_t n,
+    const char* needle,
+    size_t needleSize) {
+  static_assert(compiledNeedleSize >= 2);
+  VELOX_CHECK_GT(needleSize, 1);
+  VELOX_CHECK_GT(n, 0);
+  auto first = CharVector::broadcast(needle[0]);
+  auto last = CharVector::broadcast(needle[needleSize - 1]);
+  size_t i = 0;
+  // Fast path for page-safe data.
+  // It's safe to over-read a CharVector if all the data is in the same page.
+  // see: https://mudongliang.github.io/x86/html/file_module_x86_id_208.html
+  // While executing in 16-bit addressing mode, a linear address for a 128-bit
+  // data access that overlaps the end of a 16-bit segment is not allowed and
+  // is defined as reserved behavior. A specific processor implementation may
+  // or may not generate a general-protection exception (#GP) in this
+  // situation, and the address that spans the end of the segment may or may
+  // not wrap around to the beginning of the segment.
+  for (; i <= n - needleSize && pageSafe(s + i + needleSize - 1) &&
+       pageSafe(s + i);
+       i += CharVector::size) {
+    auto blockFirst = CharVector::load_unaligned(s + i);
+    auto blockLast = CharVector::load_unaligned(s + i + needleSize - 1);
+
+    const auto eqFirst = (first == blockFirst);
+    const auto eqLast = (last == blockLast);
+
+    auto mask = toBitMask(eqFirst && eqLast);
+
+    while (mask != 0) {
+      const auto bitpos = __builtin_ctz(mask);
+      if constexpr (compiled) {
+        if constexpr (compiledNeedleSize == 2) {
+          // First and last byte are the whole needle — the mask hit is final.
+          return i + bitpos;
+        }
+        if (memcmp(s + i + bitpos + 1, needle + 1, compiledNeedleSize - 2) ==
+            0) {
+          return i + bitpos;
+        }
+      } else {
+        if (memcmp(s + i + bitpos + 1, needle + 1, needleSize - 2) == 0) {
+          return i + bitpos;
+        }
+      }
+      // Clear the lowest set bit; try the next candidate position.
+      mask = mask & (mask - 1);
+    }
+  }
+  // Scalar fallback for the tail, or when the page-safety check fails.
+  for (; i <= n - needleSize; ++i) {
+    if constexpr (compiled) {
+      if (memcmp(s + i, needle, compiledNeedleSize) == 0) {
+        return i;
+      }
+    } else {
+      if (memcmp(s + i, needle, needleSize) == 0) {
+        return i;
+      }
+    }
+  }
+
+  return std::string::npos;
+}
+
+} // namespace detail
+
+/// A faster implementation of strstr(), about 2x faster than string_view's
+/// find(), proved by TpchLikeBenchmark. Uses an xsimd batch to compare the
+/// first && last chars first, then a fixed-size memcmp to compare the
+/// remaining chars. Inlining in the header file would be a little faster.
+size_t simdStrstr(const char* s, size_t n, const char* needle, size_t k) {
+  size_t result = std::string::npos;
+
+  if (n < k) {
+    return result;
+  }
+
+  switch (k) {
+    case 0:
+      // Empty needle matches at position 0 (strstr/std::string semantics).
+      return 0;
+
+    case 1: {
+      // memchr, not strchr: 's' is a sized buffer that need not be
+      // NUL-terminated, and the search must not run past 'n'.
+      const char* res = static_cast<const char*>(memchr(s, needle[0], n));
+
+      return (res != nullptr) ? res - s : std::string::npos;
+    }
+#define FIXED_MEM_STRSTR(size)                                          \
+  case size:                                                            \
+    result = detail::simdStrstrMemcmp<true, size>(s, n, needle, size);  \
+    break;
+      FIXED_MEM_STRSTR(2)
+      FIXED_MEM_STRSTR(3)
+      FIXED_MEM_STRSTR(4)
+      FIXED_MEM_STRSTR(5)
+      FIXED_MEM_STRSTR(6)
+      FIXED_MEM_STRSTR(7)
+      FIXED_MEM_STRSTR(8)
+      FIXED_MEM_STRSTR(9)
+      FIXED_MEM_STRSTR(10)
+      FIXED_MEM_STRSTR(11)
+      FIXED_MEM_STRSTR(12)
+      FIXED_MEM_STRSTR(13)
+      FIXED_MEM_STRSTR(14)
+      FIXED_MEM_STRSTR(15)
+      FIXED_MEM_STRSTR(16)
+      FIXED_MEM_STRSTR(17)
+      FIXED_MEM_STRSTR(18)
+    default:
+      // Needle length only known at runtime; 2 satisfies the static_assert
+      // and is otherwise unused on this path.
+      result = detail::simdStrstrMemcmp<false, 2>(s, n, needle, k);
+      break;
+  }
+#undef FIXED_MEM_STRSTR
+  // load_unaligned is used for better performance, so result may be bigger
+  // than n - k; treat any such over-read hit as "not found".
+  if (result <= n - k) {
+    return result;
+  } else {
+    return std::string::npos;
+  }
+}
+
 } // namespace facebook::velox::simd
diff --git a/velox/common/base/SimdUtil.h b/velox/common/base/SimdUtil.h
index 9a6ad0c374253..ba63d3c1d2374 100644
--- a/velox/common/base/SimdUtil.h
+++ b/velox/common/base/SimdUtil.h
@@ -497,6 +497,8 @@ xsimd::batch<T, A> reinterpretBatch(xsimd::batch<U, A>, const A& = {});
 template <typename A = xsimd::default_arch>
 inline bool memEqualUnsafe(const void* x, const void* y, int32_t size);
 
+size_t simdStrstr(const char* s, size_t n, const char* needle, size_t k);
+
 } // namespace facebook::velox::simd
 
 #include "velox/common/base/SimdUtil-inl.h"
diff --git a/velox/common/base/tests/SimdUtilTest.cpp b/velox/common/base/tests/SimdUtilTest.cpp
index ba389780b1cba..9dbebc060fb32 100644
--- a/velox/common/base/tests/SimdUtilTest.cpp
+++ b/velox/common/base/tests/SimdUtilTest.cpp
@@ -491,4 +491,25 @@ TEST_F(SimdUtilTest, memcpyTime) {
   LOG(INFO) << "simd=" << simd << " sys=" << sys;
 }
 
+TEST_F(SimdUtilTest, testSimdStrStr) {
+  // Two identical 52-char strings: every substring of s2 occurs in s1.
+  std::string s1 = "aabbccddeeffgghhiijjkkllmmnnooppqqrrssttuuvvwwxxyyzz";
+  std::string s2 = "aabbccddeeffgghhiijjkkllmmnnooppqqrrssttuuvvwwxxyyzz";
+  std::string s3 = "xxx";
+  // Cross-check simdStrstr against std::string_view::find on the same input.
+  auto test = [](const char* text, size_t size, const char* needle, size_t k) {
+    ASSERT_EQ(
+        simd::simdStrstr(text, size, needle, k),
+        std::string_view(text, size).find(std::string_view(needle, k)));
+  };
+  // Match cases: substrings of s2 (lengths 0..27, offsets 0..19) are
+  // substrings of s1.
+  for (int i = 0; i < 20; i++) {
+    for (int k = 0; k < 28; k++) {
+      const char* data = s2.data() + i;
+      test(s1.data(), s1.size(), data, k);
+    }
+  }
+  // Not-match case: "xxx" is not in s1.
+  test(s1.data(), s1.size(), s3.data(), s3.size());
+}
+
 } // namespace