Skip to content

Commit

Permalink
opt like
Browse files Browse the repository at this point in the history
fix like

add arch-independent impl
  • Loading branch information
skadilover committed Aug 13, 2024
1 parent f340734 commit ee87495
Show file tree
Hide file tree
Showing 7 changed files with 436 additions and 22 deletions.
102 changes: 102 additions & 0 deletions velox/common/base/FixedMemCompare.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <stdint.h>
#include <cstring>
#include "folly/CPortability.h"

namespace facebook::velox::simd {

/// Comparator that accepts every pair of inputs without looking at them.
/// simdStrstr() passes it for 2-byte needles, where the first and last byte
/// have already been matched and nothing is left to compare.
bool FOLLY_ALWAYS_INLINE alwaysTrue(const char* /*a*/, const char* /*b*/) {
  return true;
}

/// Returns true iff the single byte at 'a' equals the single byte at 'b'.
/// Note: unlike ::memcmp this is an equality predicate, not a three-way
/// comparison.
bool FOLLY_ALWAYS_INLINE memcmp1(const char* a, const char* b) {
  return *a == *b;
}

/// Returns true iff the 2 bytes at 'a' equal the 2 bytes at 'b'.
/// Note: unlike ::memcmp this is an equality predicate, not a three-way
/// comparison. The pointers need no particular alignment.
bool FOLLY_ALWAYS_INLINE memcmp2(const char* a, const char* b) {
  // std::memcpy is the well-defined way to do an unaligned load; compilers
  // lower it to a single 16-bit load.
  uint16_t lhs;
  uint16_t rhs;
  std::memcpy(&lhs, a, sizeof(lhs));
  std::memcpy(&rhs, b, sizeof(rhs));
  return lhs == rhs;
}

/// Returns true iff the 3 bytes at 'a' equal the 3 bytes at 'b'.
/// Note: unlike ::memcmp this is an equality predicate, not a three-way
/// comparison. Exactly 3 bytes are read from each pointer, the pointers need
/// no particular alignment and the result does not depend on byte order.
bool FOLLY_ALWAYS_INLINE memcmp3(const char* a, const char* b) {
  // Bytes [0, 2) as one 16-bit load, byte 2 on its own, so nothing past the
  // third byte is touched.
  uint16_t lhs;
  uint16_t rhs;
  std::memcpy(&lhs, a, sizeof(lhs));
  std::memcpy(&rhs, b, sizeof(rhs));
  // Bitwise '&' keeps the comparison branch-free.
  return (lhs == rhs) & (a[2] == b[2]);
}

/// Returns true iff the 4 bytes at 'a' equal the 4 bytes at 'b'.
/// Note: unlike ::memcmp this is an equality predicate, not a three-way
/// comparison. The pointers need no particular alignment.
bool FOLLY_ALWAYS_INLINE memcmp4(const char* a, const char* b) {
  // std::memcpy is the well-defined way to do an unaligned load; compilers
  // lower it to a single 32-bit load.
  uint32_t lhs;
  uint32_t rhs;
  std::memcpy(&lhs, a, sizeof(lhs));
  std::memcpy(&rhs, b, sizeof(rhs));
  return lhs == rhs;
}

/// Returns true iff the 5 bytes at 'a' equal the 5 bytes at 'b'.
/// Note: unlike ::memcmp this is an equality predicate, not a three-way
/// comparison. Exactly 5 bytes are read from each pointer, the pointers need
/// no particular alignment and the result does not depend on byte order.
bool FOLLY_ALWAYS_INLINE memcmp5(const char* a, const char* b) {
  // Bytes [0, 4) as one 32-bit load, byte 4 on its own, so nothing past the
  // fifth byte is touched.
  uint32_t lhs;
  uint32_t rhs;
  std::memcpy(&lhs, a, sizeof(lhs));
  std::memcpy(&rhs, b, sizeof(rhs));
  // Bitwise '&' keeps the comparison branch-free.
  return (lhs == rhs) & (a[4] == b[4]);
}

/// Returns true iff the 6 bytes at 'a' equal the 6 bytes at 'b'.
/// Note: unlike ::memcmp this is an equality predicate, not a three-way
/// comparison. Exactly 6 bytes are read from each pointer, the pointers need
/// no particular alignment and the result does not depend on byte order.
bool FOLLY_ALWAYS_INLINE memcmp6(const char* a, const char* b) {
  // Bytes [0, 4) as a 32-bit load and bytes [4, 6) as a 16-bit load, so
  // nothing past the sixth byte is touched.
  uint32_t lhsHead;
  uint32_t rhsHead;
  uint16_t lhsTail;
  uint16_t rhsTail;
  std::memcpy(&lhsHead, a, sizeof(lhsHead));
  std::memcpy(&rhsHead, b, sizeof(rhsHead));
  std::memcpy(&lhsTail, a + 4, sizeof(lhsTail));
  std::memcpy(&rhsTail, b + 4, sizeof(rhsTail));
  // Bitwise '&' keeps the comparison branch-free.
  return (lhsHead == rhsHead) & (lhsTail == rhsTail);
}

/// Returns true iff the 7 bytes at 'a' equal the 7 bytes at 'b'.
/// Note: unlike ::memcmp this is an equality predicate, not a three-way
/// comparison. Exactly 7 bytes are read from each pointer, the pointers need
/// no particular alignment and the result does not depend on byte order.
bool FOLLY_ALWAYS_INLINE memcmp7(const char* a, const char* b) {
  // Two overlapping 32-bit loads cover bytes [0, 4) and [3, 7); byte 3 is
  // compared twice, which is harmless and avoids reading an 8th byte.
  uint32_t lhsLow;
  uint32_t rhsLow;
  uint32_t lhsHigh;
  uint32_t rhsHigh;
  std::memcpy(&lhsLow, a, sizeof(lhsLow));
  std::memcpy(&rhsLow, b, sizeof(rhsLow));
  std::memcpy(&lhsHigh, a + 3, sizeof(lhsHigh));
  std::memcpy(&rhsHigh, b + 3, sizeof(rhsHigh));
  // Any differing bit in either half survives the XOR/OR.
  return ((lhsLow ^ rhsLow) | (lhsHigh ^ rhsHigh)) == 0;
}

/// Returns true iff the 8 bytes at 'a' equal the 8 bytes at 'b'.
/// Note: unlike ::memcmp this is an equality predicate, not a three-way
/// comparison. The pointers need no particular alignment.
bool FOLLY_ALWAYS_INLINE memcmp8(const char* a, const char* b) {
  // std::memcpy is the well-defined way to do an unaligned load; compilers
  // lower it to a single 64-bit load.
  uint64_t lhs;
  uint64_t rhs;
  std::memcpy(&lhs, a, sizeof(lhs));
  std::memcpy(&rhs, b, sizeof(rhs));
  return lhs == rhs;
}

/// Returns true iff the 9 bytes at 'a' equal the 9 bytes at 'b'.
/// Note: unlike ::memcmp this is an equality predicate, not a three-way
/// comparison. The pointers need no particular alignment.
bool FOLLY_ALWAYS_INLINE memcmp9(const char* a, const char* b) {
  // Bytes [0, 8) as one 64-bit load (std::memcpy keeps the unaligned access
  // well-defined), byte 8 on its own.
  uint64_t lhs;
  uint64_t rhs;
  std::memcpy(&lhs, a, sizeof(lhs));
  std::memcpy(&rhs, b, sizeof(rhs));
  // Bitwise '&' keeps the comparison branch-free.
  return (lhs == rhs) & (a[8] == b[8]);
}

/// Returns true iff the 10 bytes at 'a' equal the 10 bytes at 'b'.
/// Note: unlike ::memcmp this is an equality predicate, not a three-way
/// comparison. The pointers need no particular alignment.
bool FOLLY_ALWAYS_INLINE memcmp10(const char* a, const char* b) {
  // Bytes [0, 8) as a 64-bit load and bytes [8, 10) as a 16-bit load;
  // std::memcpy keeps the unaligned accesses well-defined.
  uint64_t lhsHead;
  uint64_t rhsHead;
  uint16_t lhsTail;
  uint16_t rhsTail;
  std::memcpy(&lhsHead, a, sizeof(lhsHead));
  std::memcpy(&rhsHead, b, sizeof(rhsHead));
  std::memcpy(&lhsTail, a + 8, sizeof(lhsTail));
  std::memcpy(&rhsTail, b + 8, sizeof(rhsTail));
  // Bitwise '&' keeps the comparison branch-free.
  return (lhsHead == rhsHead) & (lhsTail == rhsTail);
}

/// Returns true iff the 11 bytes at 'a' equal the 11 bytes at 'b'.
/// Note: unlike ::memcmp this is an equality predicate, not a three-way
/// comparison. Exactly 11 bytes are read from each pointer, the pointers need
/// no particular alignment and the result does not depend on byte order.
bool FOLLY_ALWAYS_INLINE memcmp11(const char* a, const char* b) {
  // A 64-bit load covers bytes [0, 8); an overlapping 32-bit load covers
  // bytes [7, 11). Byte 7 is compared twice, which is harmless and avoids
  // reading a 12th byte.
  uint64_t lhsHead;
  uint64_t rhsHead;
  uint32_t lhsTail;
  uint32_t rhsTail;
  std::memcpy(&lhsHead, a, sizeof(lhsHead));
  std::memcpy(&rhsHead, b, sizeof(rhsHead));
  std::memcpy(&lhsTail, a + 7, sizeof(lhsTail));
  std::memcpy(&rhsTail, b + 7, sizeof(rhsTail));
  // Bitwise '&' keeps the comparison branch-free.
  return (lhsHead == rhsHead) & (lhsTail == rhsTail);
}

/// Returns true iff the 12 bytes at 'a' equal the 12 bytes at 'b'.
/// Note: unlike ::memcmp this is an equality predicate, not a three-way
/// comparison. The pointers need no particular alignment.
bool FOLLY_ALWAYS_INLINE memcmp12(const char* a, const char* b) {
  // Bytes [0, 8) as a 64-bit load and bytes [8, 12) as a 32-bit load;
  // std::memcpy keeps the unaligned accesses well-defined.
  uint64_t lhsHead;
  uint64_t rhsHead;
  uint32_t lhsTail;
  uint32_t rhsTail;
  std::memcpy(&lhsHead, a, sizeof(lhsHead));
  std::memcpy(&rhsHead, b, sizeof(rhsHead));
  std::memcpy(&lhsTail, a + 8, sizeof(lhsTail));
  std::memcpy(&rhsTail, b + 8, sizeof(rhsTail));
  // Bitwise '&' keeps the comparison branch-free.
  return (lhsHead == rhsHead) & (lhsTail == rhsTail);
}

} // namespace facebook::velox::simd
185 changes: 185 additions & 0 deletions velox/common/base/SimdUtil-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1436,4 +1436,189 @@ inline bool memEqualUnsafe(const void* x, const void* y, int32_t size) {
return true;
}

namespace detail {
/// Returns 'value' with its least-significant set bit cleared (despite the
/// name, it is the lowest set bit that is removed). 'value' must be non-zero.
template <typename T>
T clearLeftmostSet(const T value) {
  assert(value != 0);

  // Subtracting one flips the lowest set bit and every bit below it, so the
  // AND drops exactly that bit.
  const T flippedBelow = value - 1;
  return value & flippedBelow;
}

/// Returns the zero-based index of the least-significant set bit of 'value'.
/// 'value' must be non-zero.
///
/// The builtin is picked from the width of T: anything wider than
/// 'unsigned int' goes through __builtin_ctzll, which is 64 bits wide on every
/// supported ABI. This covers uint64_t (previously an explicit specialization
/// built on __builtin_ctzl, which is only 32 bits wide where 'long' is) as
/// well as other 64-bit types such as 'unsigned long long', which the generic
/// version used to truncate to 32 bits.
template <typename T>
unsigned FOLLY_ALWAYS_INLINE getFirstBitSet(const T value) {
  assert(value != 0);

  if constexpr (sizeof(T) > sizeof(unsigned int)) {
    return static_cast<unsigned>(
        __builtin_ctzll(static_cast<unsigned long long>(value)));
  } else {
    return static_cast<unsigned>(
        __builtin_ctz(static_cast<unsigned int>(value)));
  }
}

// Selects the byte-vector type used by the SIMD substring search below.
// SIMD_STRSTR is defined only when one of these instruction sets is enabled;
// without it simdStrstr() takes its std::string_view::find fallback.
#if XSIMD_WITH_AVX2
// AVX2 is preferred when available; per the original author it is faster
// than SSE2.
#define SIMD_STRSTR
using CharVector = xsimd::batch<uint8_t, xsimd::avx2>;
#elif XSIMD_WITH_NEON
// ARM NEON byte vectors.
#define SIMD_STRSTR
using CharVector = xsimd::batch<uint8_t, xsimd::neon>;
#endif

#ifdef SIMD_STRSTR
/// Searches s[0, n) for the first occurrence of needle[0, k), for a needle
/// length only known at run time.
///
/// Each vector step compares the needle's first and last byte against
/// CharVector::size candidate positions at once; the bytes in between are
/// checked with memcmp only for positions where both matched.
///
/// The vector loop runs only while both of its loads lie completely inside
/// s[0, n); the remaining positions are searched with std::string_view::find.
/// No byte outside s[0, n) or needle[0, k) is read, and the result is either
/// a valid match position or std::string::npos.
///
/// @param s Haystack, 'n' readable bytes.
/// @param n Length of the haystack.
/// @param needle Needle, 'k' readable bytes.
/// @param k Length of the needle.
/// @return Offset of the first match, or std::string::npos.
size_t FOLLY_ALWAYS_INLINE
smidStrstrAnysize(const char* s, size_t n, const char* needle, size_t k) {
  // The first/last-byte scheme needs at least two needle bytes (k - 1 and
  // k - 2 would wrap around otherwise); leave tiny needles to the generic
  // search.
  if (k < 2) {
    return std::string_view(s, n).find(std::string_view(needle, k));
  }

  const auto first = CharVector::broadcast(needle[0]);
  const auto last = CharVector::broadcast(needle[k - 1]);

  // One step loads s[i, i + size) and s[i + k - 1, i + k - 1 + size), so the
  // step is in bounds iff i + k - 1 + size <= n. Every candidate it reports
  // then also satisfies candidate + k <= n.
  size_t i = 0;
  for (; i + (k - 1) + CharVector::size <= n; i += CharVector::size) {
    const auto blockFirst = CharVector::load_unaligned(s + i);
    const auto blockLast = CharVector::load_unaligned(s + i + k - 1);

    // Bit j is set when position i + j matches both the first and the last
    // needle byte.
    auto mask = toBitMask((first == blockFirst) && (last == blockLast));

    while (mask != 0) {
      const auto bitpos = detail::getFirstBitSet(mask);

      // First and last byte are known to match; compare the k - 2 in between.
      if (memcmp(s + i + bitpos + 1, needle + 1, k - 2) == 0) {
        return i + bitpos;
      }

      mask = detail::clearLeftmostSet(mask);
    }
  }

  // Tail: positions [i, n - k] that a full-width step cannot cover safely.
  // Here i <= n always holds, so the view is well-formed.
  const auto tailPos =
      std::string_view(s + i, n - i).find(std::string_view(needle, k));
  return tailPos == std::string_view::npos ? std::string::npos : i + tailPos;
}

/// Searches s[0, n) for the first occurrence of needle[0, k), for a needle
/// length 'k' known at compile time.
///
/// Each vector step compares the needle's first and last byte against
/// CharVector::size candidate positions at once. For positions where both
/// match, 'memcmp_fun(candidate + 1, needle + 1)' decides whether the
/// remaining bytes match as well.
///
/// The vector loop runs only while both of its loads lie completely inside
/// s[0, n); the remaining positions are searched with std::string_view::find,
/// so the result is either a valid match position or std::string::npos.
///
/// @tparam k Needle length, must be positive.
/// @tparam MEMCMP Callable bool(const char*, const char*).
/// @param s Haystack, 'n' readable bytes.
/// @param n Length of the haystack.
/// @param needle Needle, 'k' readable bytes.
/// @param memcmp_fun Returns true iff needle bytes [1, k - 1) equal the bytes
/// at the same offsets of the candidate. It should read at most k - 1 bytes
/// from each pointer; anything beyond may lie outside the buffers.
/// @return Offset of the first match, or std::string::npos.
template <size_t k, typename MEMCMP>
size_t FOLLY_ALWAYS_INLINE smidStrstrMemcmp(
    const char* s,
    size_t n,
    const char* needle,
    MEMCMP memcmp_fun) {
  static_assert(k > 0, "needle length must be positive");
  assert(n > 0);

  const auto first = CharVector::broadcast(needle[0]);
  const auto last = CharVector::broadcast(needle[k - 1]);

  // One step loads s[i, i + size) and s[i + k - 1, i + k - 1 + size), so the
  // step is in bounds iff i + k - 1 + size <= n. Every candidate it reports
  // then also satisfies candidate + k <= n.
  size_t i = 0;
  for (; i + (k - 1) + CharVector::size <= n; i += CharVector::size) {
    const auto blockFirst = CharVector::load_unaligned(s + i);
    const auto blockLast = CharVector::load_unaligned(s + i + k - 1);

    // Bit j is set when position i + j matches both the first and the last
    // needle byte.
    auto mask = toBitMask((first == blockFirst) && (last == blockLast));

    while (mask != 0) {
      const auto bitpos = detail::getFirstBitSet(mask);

      if (memcmp_fun(s + i + bitpos + 1, needle + 1)) {
        return i + bitpos;
      }

      mask = detail::clearLeftmostSet(mask);
    }
  }

  // Tail: positions [i, n - k] that a full-width step cannot cover safely.
  // Here i <= n always holds, so the view is well-formed.
  const auto tailPos =
      std::string_view(s + i, n - i).find(std::string_view(needle, k));
  return tailPos == std::string_view::npos ? std::string::npos : i + tailPos;
}
#endif // SIMD_STRSTR

// Closed outside the #ifdef: 'namespace detail' is opened unconditionally, so
// it must also be closed when SIMD_STRSTR is not defined.
} // namespace detail

/// Finds the first occurrence of needle[0, k) in s[0, n). Neither buffer has
/// to be NUL-terminated; only the given lengths are used.
///
/// A faster replacement for std::string_view::find(): about 2x faster
/// according to TpchLikeBenchmark. An xsimd batch compares the first and the
/// last needle byte against many positions at once, and a fixed-size compare
/// from FixedMemCompare.h checks the bytes in between. It is defined inline in
/// the header because that measured slightly faster.
///
/// @param s Haystack, 'n' readable bytes.
/// @param n Length of the haystack.
/// @param needle Needle, 'k' readable bytes.
/// @param k Length of the needle.
/// @return Offset of the first match, 0 for an empty needle, or
/// std::string::npos when there is no match.
size_t FOLLY_ALWAYS_INLINE
simdStrstr(const char* s, size_t n, const char* needle, size_t k) {
#ifdef SIMD_STRSTR
  if (n < k) {
    return std::string::npos;
  }

  size_t result = std::string::npos;

  // The comparator handed to smidStrstrMemcmp<k> has to cover needle bytes
  // [1, k - 1); the first and last byte are matched by the SIMD step.
  switch (k) {
    case 0:
      return 0;

    case 1: {
      // memchr, unlike strchr, honors 'n': it never reads past the haystack,
      // never reports a position >= n, and can search for '\0'.
      const void* hit = memchr(s, needle[0], n);

      return (hit != nullptr)
          ? static_cast<size_t>(static_cast<const char*>(hit) - s)
          : std::string::npos;
    }

    case 2:
      result = detail::smidStrstrMemcmp<2>(s, n, needle, alwaysTrue);
      break;

    case 3:
      result = detail::smidStrstrMemcmp<3>(s, n, needle, memcmp1);
      break;

    case 4:
      result = detail::smidStrstrMemcmp<4>(s, n, needle, memcmp2);
      break;

    case 5:
      // Note: memcmp4 rather than memcmp3. The extra byte it compares is the
      // needle's last byte, which is already known to be equal, and a 4-byte
      // compare is a single load.
      result = detail::smidStrstrMemcmp<5>(s, n, needle, memcmp4);
      break;

    case 6:
      result = detail::smidStrstrMemcmp<6>(s, n, needle, memcmp4);
      break;

    case 7:
      result = detail::smidStrstrMemcmp<7>(s, n, needle, memcmp5);
      break;

    case 8:
      result = detail::smidStrstrMemcmp<8>(s, n, needle, memcmp6);
      break;

    case 9:
      // Note: memcmp8 rather than memcmp7, for the same reason as for k == 5.
      result = detail::smidStrstrMemcmp<9>(s, n, needle, memcmp8);
      break;

    case 10:
      result = detail::smidStrstrMemcmp<10>(s, n, needle, memcmp8);
      break;

    case 11:
      result = detail::smidStrstrMemcmp<11>(s, n, needle, memcmp9);
      break;

    case 12:
      result = detail::smidStrstrMemcmp<12>(s, n, needle, memcmp10);
      break;

    default:
      result = detail::smidStrstrAnysize(s, n, needle, k);
      break;
  }

  // Defensive range check: a helper that loads past the end of 's' could
  // report a position beyond n - k, which is not a real match. npos fails the
  // check as well.
  return (result <= n - k) ? result : std::string::npos;
#else
  // Generic path for string search.
  return std::string_view(s, n).find(std::string_view(needle, k));
#endif
}

} // namespace facebook::velox::simd
4 changes: 3 additions & 1 deletion velox/common/base/SimdUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <cstdint>
#include "velox/common/base/BitUtil.h"
#include "velox/common/base/Exceptions.h"
#include "velox/common/base/FixedMemCompare.h"

#include <folly/Likely.h>
#include <xsimd/xsimd.hpp>
Expand Down Expand Up @@ -496,7 +497,8 @@ xsimd::batch<T, A> reinterpretBatch(xsimd::batch<U, A>, const A& = {});
// equal. May address up to SIMD width -1 past end of either 'x' or 'y'.
template <typename A = xsimd::default_arch>
inline bool memEqualUnsafe(const void* x, const void* y, int32_t size);

/// Returns the offset of the first occurrence of the 'k' bytes at 'needle'
/// within the 'n' bytes at 's', or std::string::npos if there is none.
/// Defined in SimdUtil-inl.h.
FOLLY_ALWAYS_INLINE size_t
simdStrstr(const char* s, size_t n, const char* needle, size_t k);
} // namespace facebook::velox::simd

#include "velox/common/base/SimdUtil-inl.h"
21 changes: 21 additions & 0 deletions velox/common/base/tests/SimdUtilTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -491,4 +491,25 @@ TEST_F(SimdUtilTest, memcpyTime) {
LOG(INFO) << "simd=" << simd << " sys=" << sys;
}

// Checks simd::simdStrstr against std::string_view::find, which serves as the
// reference implementation.
TEST_F(SimdUtilTest, testSimdStrStr) {
  // 52 chars: every lowercase letter twice.
  const std::string s1 = "aabbccddeeffgghhiijjkkllmmnnooppqqrrssttuuvvwwxxyyzz";
  // Same content in a separate buffer, used as the source of needles.
  const std::string s2 = "aabbccddeeffgghhiijjkkllmmnnooppqqrrssttuuvvwwxxyyzz";
  const std::string s3 = "xxx";
  // Needle lengths 0..12 have dedicated code paths; longer ones share the
  // generic path, so 27 covers all of them.
  constexpr size_t kMaxNeedle = 27;

  auto test = [](const char* text, size_t size, const char* needle, size_t k) {
    ASSERT_EQ(
        simd::simdStrstr(text, size, needle, k),
        std::string_view(text, size).find(std::string_view(needle, k)))
        << "size=" << size << " needle=" << std::string_view(needle, k);
  };

  // Match cases: every substring of s2 of up to kMaxNeedle chars is a
  // substring of s1, including those that only occur at the very end.
  for (size_t i = 0; i <= s2.size(); ++i) {
    for (size_t k = 0; k <= kMaxNeedle && i + k <= s2.size(); ++k) {
      test(s1.data(), s1.size(), s2.data() + i, k);
    }
  }

  // No-match cases for every needle length: '#' does not occur in s1.
  for (size_t k = 1; k <= kMaxNeedle; ++k) {
    const std::string absent(k, '#');
    test(s1.data(), s1.size(), absent.data(), k);
  }

  // No-match case where the needle's bytes do occur: "xx" is in s1, "xxx" is
  // not.
  test(s1.data(), s1.size(), s3.data(), s3.size());

  // Needle longer than the haystack.
  test(s1.data(), 5, s2.data(), 10);
}

} // namespace
Loading

0 comments on commit ee87495

Please sign in to comment.