// "fork download" — residue from the online-compiler page this file was pasted from.
  1. /******************************************************************************
  2.  * Fast linear search (AVX-512 ➜ AVX2 ➜ scalar) *
 *  -> reports which implementation ran and how long it took.                 *
  4.  * *
  5.  * g++ -O3 -std=c++17 -mavx512f -mavx2 fast_find.cpp -pthread -o fast_find *
  6.  * ./fast_find # single-thread *
  7.  * ./fast_find --mt # multi-thread *
  8.  ******************************************************************************/
  9.  
  10. #include <algorithm>
  11. #include <atomic>
  12. #include <chrono>
  13. #include <cstddef>
  14. #include <cstdint>
  15. #include <immintrin.h>
  16. #include <iostream>
  17. #include <random>
  18. #include <string>
  19. #include <thread>
  20. #include <vector>
  21.  
  22. namespace fast_find {
  23.  
  24. // ---------------------------------------------------------------------------
  25. // Which implementation was actually used?
  26. // ---------------------------------------------------------------------------
// Tag reporting which search implementation actually executed
// (assigned by search() / search_mt() via their Impl& out-parameter).
enum class Impl { Scalar, AVX2, AVX512 };
  28.  
  29. // ---------------------------------------------------------------------------
  30. // 1. Scalar fallback
  31. // ---------------------------------------------------------------------------
  32. template <typename T>
  33. inline int scalar(const T* a, std::size_t n, T key) noexcept {
  34. for (std::size_t i = 0; i < n; ++i)
  35. if (a[i] == key) return static_cast<int>(i);
  36. return -1;
  37. }
  38.  
  39. // ---------------------------------------------------------------------------
  40. // 2. AVX2 implementation
  41. // ---------------------------------------------------------------------------
  42. #ifdef __AVX2__
  43. inline int avx2(const int* a, std::size_t n, int key) noexcept {
  44. constexpr int W = 8;
  45. const __m256i NEEDLE = _mm256_set1_epi32(key);
  46.  
  47. std::size_t i = 0;
  48. const std::size_t limit = n & ~(W * 4 - 1);
  49.  
  50. for (; i < limit; i += W * 4) {
  51. __m256i v0 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(a + i));
  52. __m256i v1 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(a + i + W));
  53. __m256i v2 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(a + i + W * 2));
  54. __m256i v3 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(a + i + W * 3));
  55.  
  56. int m0 = _mm256_movemask_epi8(_mm256_cmpeq_epi32(v0, NEEDLE));
  57. if (m0) return i + ((m0 & -m0) % 255) >> 2;
  58.  
  59. int m1 = _mm256_movemask_epi8(_mm256_cmpeq_epi32(v1, NEEDLE));
  60. if (m1) return i + W + ((m1 & -m1) % 255) >> 2;
  61.  
  62. int m2 = _mm256_movemask_epi8(_mm256_cmpeq_epi32(v2, NEEDLE));
  63. if (m2) return i + W * 2 + ((m2 & -m2) % 255) >> 2;
  64.  
  65. int m3 = _mm256_movemask_epi8(_mm256_cmpeq_epi32(v3, NEEDLE));
  66. if (m3) return i + W * 3 + ((m3 & -m3) % 255) >> 2;
  67. }
  68. for (; i < n; ++i)
  69. if (a[i] == key) return static_cast<int>(i);
  70. return -1;
  71. }
  72. #endif
  73.  
  74. // ---------------------------------------------------------------------------
  75. // 3. AVX-512 implementation
  76. // ---------------------------------------------------------------------------
  77. #ifdef __AVX512F__
  78. inline int avx512(const int* a, std::size_t n, int key) noexcept {
  79. constexpr int W = 16;
  80. const __m512i NEEDLE = _mm512_set1_epi32(key);
  81.  
  82. std::size_t i = 0;
  83. const std::size_t limit = n & ~(W * 4 - 1);
  84.  
  85. for (; i < limit; i += W * 4) {
  86. _mm_prefetch(reinterpret_cast<const char*>(a + i + 64), _MM_HINT_T0);
  87. _mm_prefetch(reinterpret_cast<const char*>(a + i + 128), _MM_HINT_T0);
  88.  
  89. __mmask16 m0 = _mm512_cmpeq_epi32_mask(_mm512_loadu_si512(a + i), NEEDLE);
  90. if (m0) return i + _tzcnt_u32(m0);
  91.  
  92. __mmask16 m1 = _mm512_cmpeq_epi32_mask(_mm512_loadu_si512(a + i + W), NEEDLE);
  93. if (m1) return i + W + _tzcnt_u32(m1);
  94.  
  95. __mmask16 m2 = _mm512_cmpeq_epi32_mask(_mm512_loadu_si512(a + i + W * 2), NEEDLE);
  96. if (m2) return i + W * 2 + _tzcnt_u32(m2);
  97.  
  98. __mmask16 m3 = _mm512_cmpeq_epi32_mask(_mm512_loadu_si512(a + i + W * 3), NEEDLE);
  99. if (m3) return i + W * 3 + _tzcnt_u32(m3);
  100. }
  101. for (; i < n; ++i)
  102. if (a[i] == key) return static_cast<int>(i);
  103. return -1;
  104. }
  105. #endif
  106.  
  107. // ---------------------------------------------------------------------------
  108. // 4. Single-thread façade (returns index + impl used)
  109. // ---------------------------------------------------------------------------
// Dispatch to the fastest implementation that is BOTH compiled in (#ifdef)
// and supported by the CPU we are running on (__builtin_cpu_supports).
// Both checks are needed: a binary built with -mavx512f may still run on a
// machine without AVX-512. Records the chosen path in `used`.
// Returns the index of the first match, or -1.
inline int search(const int* data, std::size_t n, int value, Impl& used) noexcept {
#ifdef __AVX512F__
if (__builtin_cpu_supports("avx512f")) { used = Impl::AVX512; return avx512(data, n, value); }
#endif
#ifdef __AVX2__
if (__builtin_cpu_supports("avx2")) { used = Impl::AVX2; return avx2 (data, n, value); }
#endif
// Portable fallback: always available.
used = Impl::Scalar;
return scalar(data, n, value);
}
  120.  
  121. // convenience wrapper when caller doesn't care about impl
  122. inline int search(const int* data, std::size_t n, int value) noexcept {
  123. Impl dummy;
  124. return search(data, n, value, dummy);
  125. }
  126.  
  127. // ---------------------------------------------------------------------------
  128. // 5. Multi-thread wrapper (returns index + impl used by *any* thread)
  129. // ---------------------------------------------------------------------------
  130. inline int search_mt(const int* data, std::size_t n, int value,
  131. unsigned nThreads,
  132. Impl& usedImpl)
  133. {
  134. if (nThreads == 0) nThreads = 1;
  135. if (nThreads == 1 || n < 16'384) // ST faster for small inputs
  136. return search(data, n, value, usedImpl);
  137.  
  138. const std::size_t chunk = (n + nThreads - 1) / nThreads;
  139. std::atomic<int> result{-1};
  140. std::atomic<Impl> implSeen{Impl::Scalar};
  141. std::vector<std::thread> pool;
  142.  
  143. for (unsigned t = 0; t < nThreads; ++t) {
  144. const std::size_t start = t * chunk;
  145. if (start >= n) break;
  146. const std::size_t end = std::min(start + chunk, n);
  147.  
  148. pool.emplace_back([&, start, end]() {
  149. Impl localImpl;
  150. int localIdx = search(data + start, end - start, value, localImpl);
  151. implSeen.store(localImpl, std::memory_order_relaxed);
  152.  
  153. if (localIdx != -1) {
  154. int global = static_cast<int>(start + localIdx);
  155. int expected = -1;
  156. result.compare_exchange_strong(expected, global,
  157. std::memory_order_relaxed);
  158. }
  159. });
  160. }
  161. for (auto& th : pool) th.join();
  162.  
  163. usedImpl = implSeen.load(std::memory_order_relaxed);
  164. return result.load();
  165. }
  166.  
  167. } // namespace fast_find
  168.  
  169. // ═══════════════════════════════════ Demo main ═════════════════════════════
  170. static std::string to_string(fast_find::Impl impl) {
  171. switch (impl) {
  172. case fast_find::Impl::Scalar: return "Scalar";
  173. case fast_find::Impl::AVX2: return "AVX2";
  174. case fast_find::Impl::AVX512: return "AVX-512";
  175. }
  176. return "Unknown";
  177. }
  178.  
  179. int main(int argc, char** argv) {
  180. constexpr std::size_t N = 10'000;
  181. std::vector<int> data(N);
  182. for (std::size_t i = 0; i < N; ++i) data[i] = (i * 77 + 123) & 0x7FFF;
  183.  
  184. // --------- Randomly pick a key from the data set
  185. std::random_device rd;
  186. std::mt19937 rng(rd());
  187. std::uniform_int_distribution<std::size_t> dist(0, N - 1);
  188. const std::size_t randIdx = dist(rng);
  189. const int key = data[randIdx];
  190.  
  191. const bool useMT = (argc > 1 && std::string(argv[1]) == "--mt");
  192.  
  193. const unsigned hwThreads =
  194. std::thread::hardware_concurrency() ? std::thread::hardware_concurrency() : 1;
  195.  
  196. fast_find::Impl implUsed;
  197. const auto t0 = std::chrono::high_resolution_clock::now();
  198. int idx = useMT
  199. ? fast_find::search_mt(data.data(), data.size(), key,
  200. hwThreads, implUsed)
  201. : fast_find::search (data.data(), data.size(), key,
  202. implUsed);
  203. const auto t1 = std::chrono::high_resolution_clock::now();
  204. const double micro =
  205. std::chrono::duration_cast<std::chrono::duration<double, std::micro>>(t1 - t0).count();
  206.  
  207. std::cout << (useMT ? "[MT] " : "[ST] ")
  208. << "Impl: " << to_string(implUsed)
  209. << " | Key: " << key
  210. << " | Index: " << idx
  211. << " | Time: " << micro << " µs"
  212. << " | Logical cores: " << hwThreads
  213. << '\n';
  214. }
// ---------------------------------------------------------------------------
// Sample run recorded by the online compiler (0.01 s, 5288 KB, empty stdin):
//   [ST] Impl: Scalar | Key: 13877 | Index: 6562 | Time: 4.546 µs | Logical cores: 8
// ---------------------------------------------------------------------------