├── Makefile ├── README.md ├── main.cpp └── miroslav.h /Makefile: -------------------------------------------------------------------------------- 1 | VEC_INFO=VecInfoAVX2 2 | ARCH=haswell 3 | 4 | CXXFLAGS=-std=c++14 -O3 -march=$(ARCH) -DVEC_INFO=$(VEC_INFO) -Wall -Wextra -Werror 5 | 6 | miroslav: main.cpp miroslav.h 7 | $(CXX) $(CXXFLAGS) -o miroslav main.cpp 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Miroslav 2 | ======== 3 | 4 | Miroslav is a fast regular expression matching algorithm using SIMD instructions to do a lossy NFA simulation on many input bytes in parallel. 5 | 6 | It's named after Miroslav Vitous, since his song [Infinite Search](https://www.youtube.com/watch?v=-OdIEbFwQEs) (from the album Infinite Search) came on shuffle as I was writing the initial implementation, and that seemed like an appropriate fit. The whole album is amazing by the way, some great early electric jazz from 1970. 7 | 8 | A full writeup on the algorithm should be coming soon-ish. 9 | 10 | This implementation for now only has a few features: 11 | 12 | * regular expressions can consist only of literal strings and the alternation operator `|`. 
13 | * patterns can be loaded from a file with `-f [file]` 14 | * the count of lines matched can be printed with `-c`, otherwise all matching lines are printed 15 | -------------------------------------------------------------------------------- /main.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #include "miroslav.h" 10 | 11 | // hacky macro for easier argument chomping 12 | #define EAT_ARG() do { argc--; argv++; } while (0); 13 | 14 | template 15 | void run(int argc, char **argv, NFAEdgeList &edges, MatchHandlerPrintLine &mh) { 16 | Miroslav m(edges, mh); 17 | 18 | while (argc > 0) { 19 | File f(argv[0]); 20 | EAT_ARG(); 21 | 22 | m.run(f); 23 | } 24 | } 25 | 26 | int main(int argc, char **argv) { 27 | if (argc < 3) { 28 | std::cerr << "Usage: " << argv[0] << " [-c] pattern path...\n"; 29 | std::cerr << " " << argv[0] << " [-c] -f pattern-file path...\n"; 30 | exit(1); 31 | } 32 | EAT_ARG(); 33 | 34 | // Option parsing 35 | bool print_count = false; 36 | if (!strcmp(argv[0], "-c")) { 37 | print_count = true; 38 | EAT_ARG(); 39 | } 40 | 41 | // Parse the input regex 42 | std::string pattern(argv[0]); 43 | if (!strcmp(argv[0], "-f")) { 44 | EAT_ARG(); 45 | std::ifstream f(argv[0]); 46 | EAT_ARG(); 47 | std::stringstream buffer; 48 | buffer << f.rdbuf(); 49 | pattern = buffer.str(); 50 | } else { 51 | pattern = std::string(argv[0]); 52 | EAT_ARG(); 53 | } 54 | 55 | NFAEdgeList edges; 56 | 57 | typedef std::pair edge_key; 58 | std::map edge_map{}; 59 | 60 | uint32_t last_state = START_STATE; 61 | uint32_t state = 2; 62 | bool can_be_duplicate = true; 63 | for (uint32_t i = 0; i < pattern.length(); i++) { 64 | auto c = pattern[i]; 65 | 66 | if (c == '|' || c == '\n') { 67 | last_state = START_STATE; 68 | can_be_duplicate = true; 69 | } else { 70 | // Add a new character to the pattern. 
Before we do, see if we can 71 | // branch off a previous state that matches the substring 72 | // up until now. 73 | if (can_be_duplicate) { 74 | edge_key k(c, last_state); 75 | if (edge_map.count(k)) { 76 | last_state = edge_map[k]; 77 | continue; 78 | } else 79 | can_be_duplicate = false; 80 | } 81 | if (i < pattern.length() - 1 && pattern[i + 1] != '|' && 82 | pattern[i + 1] != '\n') { 83 | edges.push_back(std::make_tuple(c, last_state, state)); 84 | edge_map[edge_key(c, last_state)] = state; 85 | last_state = state; 86 | state += 1; 87 | } else 88 | edges.push_back(std::make_tuple(c, last_state, END_STATE)); 89 | } 90 | } 91 | 92 | bool print_path = (argc > 1); 93 | MatchHandlerPrintLine mh(print_path, !print_count, print_count, true); 94 | 95 | if (state <= 32) 96 | run(argc, argv, edges, mh); 97 | else if (state <= 64) 98 | run(argc, argv, edges, mh); 99 | else if (state <= 128) 100 | run(argc, argv, edges, mh); 101 | // SIMD verifier: only allow 2^n-1 states since we need an extra for a sentinel 102 | else if (state <= 255) 103 | run_miroslav(argc, argv, edges, mh); 104 | else if (state <= 65535) 105 | run_miroslav(argc, argv, edges, mh); 106 | else 107 | run_miroslav(argc, argv, edges, mh); 108 | } 109 | -------------------------------------------------------------------------------- /miroslav.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | #include 7 | #include 8 | #include 9 | 10 | #include 11 | 12 | #define UNUSED __attribute__((unused)) 13 | 14 | #define EXPECT(x, v) __builtin_expect((x), (v)) 15 | 16 | static inline uint32_t bsf64(uint64_t x) { 17 | return __builtin_ctzll(x); 18 | } 19 | 20 | static inline uint32_t bsf32(uint32_t x) { 21 | return __builtin_ctzl(x); 22 | } 23 | 24 | #define START_STATE (0) 25 | #define END_STATE (1) 26 | 27 | typedef std::vector> NFAEdgeList; 28 | 29 | // Wacky macro to make tuple unpacking a little less annoying 30 | #define 
FOR_EACH_EDGE(c, from, to, edges) \ 31 | for (auto edge_i = edges.begin(); \ 32 | edge_i != edges.end() ? (std::tie(c, from, to) = *edge_i), 1 : 0; \ 33 | edge_i++) 34 | 35 | //////////////////////////////////////////////////////////////////////////////// 36 | // Basic mmap wrapper for file handling 37 | //////////////////////////////////////////////////////////////////////////////// 38 | struct File { 39 | const char *path; 40 | const uint8_t *data; 41 | size_t size; 42 | 43 | File() : path(NULL), data(NULL), size(0) { } 44 | 45 | File(const char *path) : path(path) { 46 | FILE *f = fopen(path, "r"); 47 | 48 | fseek(f, 0, SEEK_END); 49 | this->size = ftell(f); 50 | 51 | this->data = (uint8_t *)mmap(NULL, this->size, PROT_READ, MAP_FILE|MAP_PRIVATE, fileno(f), 0); 52 | if (this->data == MAP_FAILED) { 53 | perror("mmap"); 54 | exit(1); 55 | } 56 | 57 | fclose(f); 58 | } 59 | 60 | ~File() { 61 | if (munmap((void *)this->data, this->size) == -1) { 62 | perror("munmap"); 63 | exit(1); 64 | } 65 | } 66 | }; 67 | 68 | //////////////////////////////////////////////////////////////////////////////// 69 | // Vector definition classes 70 | // These contain the basic properties of the underlying SIMD implementation, 71 | // as well as functions for any architecture-dependent operations. 
////////////////////////////////////////////////////////////////////////////////

// AVX2: we have 32-byte vectors, but can only shuffle within 16-byte halves of
// these vectors, which requires special handling
struct VecInfoAVX2 {
    static const uint32_t VL = 32;
    typedef __m256i V;
    typedef uint32_t vmask;
    typedef uint64_t double_vmask;

    // Lookup mask: since avx2 shuffles only work in two 16-byte chunks, we have
    // to use just four bits of each character as a table index, instead of the
    // five that we'd prefer
    static const uint32_t LMASK = 16 - 1;

    // Byte shuffle (vpshufb): each result byte is table[index & 15] from the
    // same 16-byte half, or zero where the index byte has its high bit set.
    static inline V permute(V &table, V &index) {
        return _mm256_shuffle_epi8(table, index);
    }

    static inline V vec_and(V &a, V &b) {
        return _mm256_and_si256(a, b);
    }

    // Logical right shift of each 32-bit lane. Bits cross byte boundaries
    // within a lane; presumably callers mask the result (e.g. with LMASK).
    static inline V vec_shr(V &a, uint32_t shift) {
        return _mm256_srli_epi32(a, shift);
    }

    static inline V vec_lanes_shl_1(V &top, V &bottom) {
        // Move all the vector lanes in "top" to the left by one and fill
        // in the first lane with the last lane in "bottom". Since AVX2
        // generally works on two separate 16-byte vectors glued together,
        // this needs two steps. The permute takes [bottom_H, bottom_L]
        // and [top_H, top_L] and gives us [top_L, bottom_H]. The align then
        // takes [top_H, top_L] and gives us [top_H[1:], top_L[:1]], and
        // takes [top_L, bottom_H] and gives us [top_L[1:], bottom_H[:1]].
        V shl_16 = _mm256_permute2x128_si256(top, bottom, 0x03);
        return _mm256_alignr_epi8(top, shl_16, 15);
    }

    // Bit i of the result is the high bit of byte i.
    static inline vmask test_high_bit(V &a) {
        return _mm256_movemask_epi8(a);
    }

    // Bit i of the result is the low bit of byte i.
    static inline vmask test_low_bit(V &a) {
        // Movemask tests the high bit, so the input has to be shifted up
        return _mm256_movemask_epi8(_mm256_slli_epi32(a, 7));
    }

    // Bit i of the result is set when byte i is nonzero.
    static inline vmask test_nz(V &a) {
        return ~_mm256_movemask_epi8(_mm256_cmpeq_epi8(a, _mm256_setzero_si256()));
    }

    static void prepare_state_table(uint8_t state_bytes[VL]) {
        // HACK because AVX2 sucks and can only do 16-byte shuffles: mirror
        // the low 16 table bytes into the high half.
        for (uint32_t i = 0; i < 16; i++)
            state_bytes[i + 16] = state_bytes[i];
    }
};

// Functions specialized on both vector size and element size. C++ doesn't
// allow explicit specializations inside classes, so they're out here...

// Broadcast: fill every element of a vector with `value`.
template <typename VI, typename element>
inline typename VI::V broadcast(element value);
template <>
inline VecInfoAVX2::V broadcast<VecInfoAVX2, uint8_t>(uint8_t value) {
    return _mm256_set1_epi8(value);
}
template <>
inline VecInfoAVX2::V broadcast<VecInfoAVX2, uint16_t>(uint16_t value) {
    return _mm256_set1_epi16(value);
}
template <>
inline VecInfoAVX2::V broadcast<VecInfoAVX2, uint32_t>(uint32_t value) {
    return _mm256_set1_epi32(value);
}

// Test equal: compare a and b element-wise and return a bit mask of the
// equal elements.
template <typename VI, typename element>
inline typename VI::vmask test_eq(typename VI::V &a, typename VI::V &b);
template <>
inline VecInfoAVX2::vmask test_eq<VecInfoAVX2, uint8_t>(
        VecInfoAVX2::V &a, VecInfoAVX2::V &b) {
    return _mm256_movemask_epi8(_mm256_cmpeq_epi8(a, b));
}
template <>
inline VecInfoAVX2::vmask test_eq<VecInfoAVX2, uint16_t>(
        VecInfoAVX2::V &a, VecInfoAVX2::V &b) {
    // HACK: avx2 doesn't have a movemask_epi16 instruction. So we just use the
    // epi8 version (two mask bits per element), and in the one place test_eq
    // is used now, we divide the bitscan of this mask by 2.
    return _mm256_movemask_epi8(_mm256_cmpeq_epi16(a, b));
}
template <>
inline VecInfoAVX2::vmask test_eq<VecInfoAVX2, uint32_t>(
        VecInfoAVX2::V &a, VecInfoAVX2::V &b) {
    // movemask_ps takes a float vector, so reinterpret the integer compare
    // result (a cast that generates no instruction) to get one bit per
    // 32-bit element.
    return _mm256_movemask_ps(_mm256_castsi256_ps(_mm256_cmpeq_epi32(a, b)));
}

// Lossy bitset. This works pretty much like a Bloom filter. For quickly
// testing membership of a single index within the bitset, we look at
// contiguous 6-bit chunks of the index value, and set a single bit within
// the bitset based on that. For example, given the index 0xc68 (binary
// 110001101000) and shifts of 0, 3, and 6, we get the following values:
//
//   1 1 0 0 0 1[1 0 1 0 0 0] --> 101000 (40)
//   1 1 0[0 0 1 1 0 1]0 0 0 --> 001101 (13)
//   [1 1 0 0 0 1]1 0 1 0 0 0 --> 110001 (49)
//
// So we test bit 40 in the first bitset, 13 in the second, and 49 in the
// third. Only if all of these bits are set do we return true from
// might_contain.
//
// Bit positions wrap modulo 64: only the low 6 bits of each shifted index
// select a bit within its 64-bit word.
// Lossy (Bloom-filter-like) set of uint32_t indices: one 64-bit word per
// shift in SHIFTS. might_contain() never gives a false negative, but may give
// a false positive.
template <uint32_t N_SHIFTS, const uint32_t *SHIFTS>
struct _LossyBitset {
    uint64_t bitsets[N_SHIFTS];

    _LossyBitset() : bitsets{0} {
    }

    void add(uint32_t index) {
        for (uint32_t s = 0; s < N_SHIFTS; s++)
            bitsets[s] |= bit_for(index, s);
    }

    bool might_contain(uint32_t index) const {
        for (uint32_t s = 0; s < N_SHIFTS; s++) {
            if (!(bitsets[s] & bit_for(index, s)))
                return false;
        }
        return true;
    }

private:
    // Bit selected by `index` in word s. The shifted index is reduced modulo
    // 64 explicitly: shifting a uint64_t by 64 or more is undefined behavior
    // in C++, even though x86 shift instructions mask the count themselves.
    static uint64_t bit_for(uint32_t index, uint32_t s) {
        return 1ull << ((index >> SHIFTS[s]) & 63);
    }
};

static constexpr uint32_t LBS_SHIFTS[] = {0, 4};
static constexpr uint32_t N_LBS_SHIFTS = sizeof(LBS_SHIFTS) / sizeof(LBS_SHIFTS[0]);
typedef _LossyBitset<N_LBS_SHIFTS, LBS_SHIFTS> LossyBitset;

////////////////////////////////////////////////////////////////////////////////
// Match verifier
// Since the vectorized matcher can give false positives, we have to run through
// each potential match backwards from the end character.
////////////////////////////////////////////////////////////////////////////////

// Always return true. Bad for correctness, good for testing speed of the core
// Miroslav algorithm.
221 | struct FakeMatchVerifier { 222 | FakeMatchVerifier(UNUSED NFAEdgeList &edges) { } 223 | 224 | const uint8_t *verify(UNUSED const uint8_t *data, UNUSED const uint8_t *end) { 225 | return end; 226 | } 227 | }; 228 | 229 | // Basic bitset NFA simulator 230 | template 231 | class BasicMatchVerifier { 232 | typedef typename StateInfo::smask smask; 233 | static const uint32_t MAX_STATES = StateInfo::MAX_STATES; 234 | 235 | // This is a table of [input_byte][state] -> prev_states 236 | smask back_edges[256][MAX_STATES]; 237 | // All states that lead to a successful match from each input byte 238 | smask match_mask[256]; 239 | // Shortcut mask: for each input byte, keep a mask of states that have a 240 | // predecessor state so we don't try any unnecessary lookups 241 | smask next_mask[256]; 242 | 243 | public: 244 | BasicMatchVerifier(NFAEdgeList &edges) { 245 | for (uint32_t i = 0; i < 256; i++) { 246 | for (uint32_t j = 0; j < MAX_STATES; j++) 247 | back_edges[i][j] = 0; 248 | next_mask[i] = 0; 249 | match_mask[i] = 0; 250 | } 251 | 252 | // Initialize state mask tables 253 | uint8_t c; 254 | uint32_t from, to; 255 | FOR_EACH_EDGE(c, from, to, edges) { 256 | assert(from < MAX_STATES); 257 | assert(to < MAX_STATES); 258 | next_mask[c] |= (smask)1 << to; 259 | 260 | if (from == START_STATE) 261 | match_mask[c] |= (smask)1 << to; 262 | else 263 | back_edges[c][to] |= (smask)1 << from; 264 | } 265 | } 266 | 267 | const uint8_t *verify(const uint8_t *data, const uint8_t *end) { 268 | smask states = 1 << END_STATE; 269 | 270 | do { 271 | uint8_t c = *end; 272 | 273 | if (match_mask[c] & states) 274 | return end; 275 | 276 | states &= next_mask[c]; 277 | 278 | // Iterate through all current states and look up their next states 279 | smask next_states = 0; 280 | while (states) { 281 | uint32_t s = StateInfo::bsf(states); 282 | next_states |= back_edges[c][s]; 283 | states &= states - 1; 284 | } 285 | states = next_states; 286 | } while (states && --end >= data); 287 | 288 
| return NULL; 289 | } 290 | }; 291 | 292 | // Structs with definitions for the BasicMatchVerifier. 293 | // smask should have MAX_STATES bits. 294 | // The backwards state table will be (max/8) * max * 256 bytes. 295 | // So 8 states is 2K, 16 => 8K, 32 => 32K, 64 => 128K, 128 => 512K 296 | 297 | struct StateInfo32 { 298 | static const uint32_t MAX_STATES = 32; 299 | typedef uint32_t smask; 300 | static inline uint32_t bsf(smask x) { return bsf32(x); } 301 | }; 302 | struct StateInfo64 { 303 | static const uint32_t MAX_STATES = 64; 304 | typedef uint64_t smask; 305 | static inline uint32_t bsf(smask x) { return bsf64(x); } 306 | }; 307 | struct StateInfo128 { 308 | static const uint32_t MAX_STATES = 128; 309 | typedef __uint128_t smask; 310 | static inline uint32_t bsf(smask x) { 311 | if ((uint64_t)x) 312 | return bsf64(((uint64_t)x)); 313 | return bsf64(x >> 64) + 64; 314 | } 315 | }; 316 | 317 | typedef BasicMatchVerifier BasicMatchVerifier32; 318 | typedef BasicMatchVerifier BasicMatchVerifier64; 319 | typedef BasicMatchVerifier BasicMatchVerifier128; 320 | 321 | // Find the next occurrence of 'chr' in the data stream before 'bound', and return 322 | // the pointer of the character after that. We can go in both directions. 323 | inline const uint8_t *skip_chr(const uint8_t *p, const uint8_t chr, bool forwards, 324 | const uint8_t *bound) { 325 | if (forwards) 326 | while (p < bound && *p != chr) 327 | p++; 328 | else 329 | while (p >= bound && *p != chr) 330 | p--; 331 | return p; 332 | } 333 | 334 | // Dumb empty class. 
Using Squamatus as a verifier doesn't need a handler, so we 335 | // use this type as a placeholder 336 | struct DummyMatchHandler { 337 | typedef uint32_t return_type; 338 | void handle_match(UNUSED File &f, UNUSED const uint8_t *start, 339 | UNUSED const uint8_t *end) { 340 | } 341 | }; 342 | 343 | struct MatcherRegexOpts { 344 | static const bool FORWARDS = true; 345 | static const bool CONTINUOUS = true; 346 | static const bool OVERLAPPING = false; 347 | static const bool ONE_PER_LINE = false; 348 | }; 349 | 350 | struct VerifierRegexOpts { 351 | static const bool FORWARDS = false; 352 | static const bool CONTINUOUS = false; 353 | static const bool OVERLAPPING = false; 354 | static const bool ONE_PER_LINE = false; 355 | }; 356 | 357 | template 358 | class _Squamatus { 359 | // Hacky substitute for "using" 360 | typedef typename VI::V V; 361 | typedef typename VI::vmask vmask; 362 | static const uint32_t VL = VI::VL; 363 | 364 | // Vector constants 365 | static const uint64_t N_VE = VI::VL / sizeof(VE); 366 | static const uint64_t MAX_STATES = ((uint64_t)1 << 8 * sizeof(VE)); 367 | static const uint64_t SENTINEL = MAX_STATES - 1; 368 | // HACK: we divide by 2 when using 16 bit compares since there's no movemask_epi16 369 | static const uint64_t MOVEMASK_HACK = sizeof(VE) == 2 ? 2 : 1; 370 | // Another HACKish thing: for a masking operation deep in the algorithm, we 371 | // need to mask out all the bits above the ones that might be set by a vcmpeq 372 | // instruction. This should be all ones (thus a no-op) unless sizeof(VE) > 2. 
373 | static const vmask VCMP_BITS = ~(-1 << N_VE * MOVEMASK_HACK); 374 | 375 | // Associative array of state -> state transitions, with a different array per 376 | // possible input byte 377 | uint32_t key_count[256]; 378 | VE *edge_keys[256]; 379 | VE *edge_values[256]; 380 | 381 | // For keeping track of all current states and all next states 382 | VE *_state_buffer[2]; 383 | 384 | // For keeping track of where a given NFA match sequence started 385 | uint32_t *_start_buffer[2]; 386 | 387 | // Stuff to act like a full regex matcher 388 | MatchHandler &match_handler; 389 | File *input_file; 390 | 391 | static inline uint32_t INITIAL_STATE() { 392 | return Opts::FORWARDS ? START_STATE : END_STATE; 393 | } 394 | static inline uint32_t TARGET_STATE() { 395 | return Opts::FORWARDS ? END_STATE : START_STATE; 396 | } 397 | 398 | public: 399 | // Constructor wrapper using a cool NULL reference, for when using this class 400 | // as a verifier (the handler isn't touched) 401 | _Squamatus(NFAEdgeList &edges) : _Squamatus(edges, *(MatchHandler *)NULL) { 402 | assert(!Opts::CONTINUOUS); 403 | } 404 | 405 | _Squamatus(NFAEdgeList &edges, MatchHandler &handler) : match_handler(handler) { 406 | // Group NFA edges into vectors by character 407 | struct kv_pair { 408 | VE k, v; 409 | kv_pair(VE k, VE v) : k(k), v(v) { } 410 | }; 411 | std::vector edge_pairs[256]; 412 | uint8_t c; 413 | uint32_t from, to; 414 | FOR_EACH_EDGE(c, from, to, edges) { 415 | assert((uint64_t)from < SENTINEL); 416 | assert((uint64_t)to < SENTINEL); 417 | assert(from != END_STATE); 418 | assert(to != START_STATE); 419 | // Add the from/to states into the associative array. Which is 420 | // the key and which is the value depends on which direction 421 | // we're going. 
422 | if (Opts::FORWARDS) 423 | edge_pairs[c].push_back(kv_pair(from, to)); 424 | else 425 | edge_pairs[c].push_back(kv_pair(to, from)); 426 | } 427 | 428 | uint32_t max_concurrent_states = 0; 429 | 430 | for (uint32_t i = 0; i < 256; i++) { 431 | // Sort the vector so all states that a given state/character combination 432 | // can lead to are all contiguous 433 | std::stable_sort(edge_pairs[i].begin(), edge_pairs[i].end(), 434 | [](const auto& a, const auto& b) { 435 | return a.k < b.k || (a.k == b.k && a.v < b.v); 436 | }); 437 | 438 | if (edge_pairs[i].size() == 0) { 439 | edge_keys[i] = edge_values[i] = NULL; 440 | key_count[i] = 0; 441 | continue; 442 | } 443 | 444 | // Calculate number of elements. We add N_VE and clear the low bits 445 | // (x & -N_VE rounds x down to the nearest multiple of N_VE) to get 446 | // the number of bytes we will store the keys/values in. We add 447 | // N_VE to round up, but also multiples of N_VE round up to the 448 | // next multiple, so we always get at least one extra slot. That 449 | // way we can just compare states in a loop while ignoring the 450 | // array length (since we're storing 255s at the end of the table 451 | // that will always compare false). 452 | key_count[i] = (edge_pairs[i].size() + N_VE) & -N_VE; 453 | edge_keys[i] = (VE *)malloc(key_count[i] * sizeof(VE)); 454 | edge_values[i] = (VE *)malloc(key_count[i] * sizeof(VE)); 455 | 456 | // Copy the sorted vector into the key/value lists, filling in the 457 | // rest of the values with SENTINEL 458 | uint32_t j; 459 | for (j = 0; j < edge_pairs[i].size(); j++) { 460 | edge_keys[i][j] = edge_pairs[i][j].k; 461 | edge_values[i][j] = edge_pairs[i][j].v; 462 | } 463 | for (; j < key_count[i]; j++) { 464 | edge_keys[i][j] = SENTINEL; 465 | edge_values[i][j] = SENTINEL; 466 | } 467 | 468 | // Update the # of max concurrent states. 
Since we fill in the state 469 | // buffer only with values from reading the edge_values[] array for 470 | // a single input character, the maximum length of one of these arrays 471 | // is an upper bound for the number of states that this NFA can 472 | // possibly be in at once. 473 | if (key_count[i] * N_VE > max_concurrent_states) 474 | max_concurrent_states = key_count[i] * N_VE; 475 | } 476 | 477 | // Allocate two buffers for storing the current states 478 | // +VL because we can write past the end of this array 479 | size_t buf_size = max_concurrent_states * sizeof(VE) + VL; 480 | _state_buffer[0] = (VE *)malloc(buf_size); 481 | _state_buffer[1] = (VE *)malloc(buf_size); 482 | 483 | // Allocate buffers for storing the start index of each 484 | // match as well. We have to allocate a bit more scratch 485 | // space at the end, since with this array we could potentially 486 | // write up to 4*VL bytes off the end 487 | size_t start_buf_size = max_concurrent_states * sizeof(uint32_t); 488 | start_buf_size += (VL * sizeof(uint32_t) / sizeof(VE)); 489 | _start_buffer[0] = (uint32_t *)malloc(start_buf_size); 490 | _start_buffer[1] = (uint32_t *)malloc(start_buf_size); 491 | } 492 | 493 | ~_Squamatus() { 494 | for (uint32_t i = 0; i < 256; i++) { 495 | if (key_count[i]) { 496 | free(edge_keys[i]); 497 | free(edge_values[i]); 498 | } 499 | } 500 | for (uint32_t i = 0; i < 2; i++) { 501 | free(_state_buffer[i]); 502 | free(_start_buffer[i]); 503 | } 504 | } 505 | 506 | const uint8_t *verify(const uint8_t *data, const uint8_t *end) { 507 | // We have two state lists, which we switch between a la double buffering. 
508 | VE *states = _state_buffer[0], *next_states = _state_buffer[1]; 509 | uint32_t _n_states[2]; 510 | uint32_t *n_states = &_n_states[0], *next_n_states = &_n_states[1]; 511 | 512 | // Also keep an index for each state in the list, pointing to where 513 | // in the input stream the match started 514 | uint32_t *start_idx = _start_buffer[0], *next_start_idx = _start_buffer[1]; 515 | 516 | // For continuous operation, we will have the INITIAL_STATE always in 517 | // the current state set. To do this easily, we just always keep it 518 | // in the first position, and only write to entries after the first. 519 | if (Opts::CONTINUOUS) { 520 | states[0] = INITIAL_STATE(); 521 | next_states[0] = INITIAL_STATE(); 522 | } 523 | 524 | const uint8_t *input_p = Opts::FORWARDS ? data : end; 525 | 526 | *n_states = 1; 527 | states[0] = INITIAL_STATE(); 528 | start_idx[0] = input_p - data + (Opts::FORWARDS ? -1 : 1); 529 | 530 | do { 531 | start: 532 | uint8_t c = *input_p; 533 | 534 | // CONTINUOUS mode has an implied initial state always at the 535 | // beginning of the buffer 536 | if (Opts::CONTINUOUS) { 537 | *next_n_states = 1; 538 | next_start_idx[0] = input_p - data; 539 | } else 540 | *next_n_states = 0; 541 | 542 | // Add a (lossy) bitset of already examined states this iteration, 543 | // to skip duplicate states. We add the start state to the mask so 544 | // we don't have an extra branch on every state--we only have to 545 | // check for the start state inside the duplicate checking code. 546 | LossyBitset seen; 547 | seen.add(TARGET_STATE()); 548 | 549 | for (uint32_t s = 0; s < *n_states; s++) { 550 | VE state = states[s]; 551 | V state_vec = broadcast(state); 552 | V start_idx_vec = broadcast(start_idx[s]); 553 | 554 | // Check for whether we've already added this state 555 | if (seen.might_contain(state)) { 556 | // We found a match! We added the target state to the 557 | // seen bitset so we only have to do this test in the 558 | // slow path. 
559 | if (state == TARGET_STATE()) { 560 | // Non-continuous case: we were just verifying that there was 561 | // a match. Oh hey, looks like there was one. 562 | if (!Opts::CONTINUOUS) 563 | return input_p; 564 | 565 | // Continuous case: pass the match off to the match 566 | // handler and keep going 567 | 568 | // Register the match. We have to do a bit of 569 | // index math to make the start/end points line 570 | // up correctly in both the forwards and backwards 571 | // cases. Why would we be streaming backwards? Who 572 | // the hell knows? 573 | if (Opts::FORWARDS) 574 | match_handler.handle_match(*input_file, 575 | data + start_idx[s] + 1, input_p - 1); 576 | else 577 | match_handler.handle_match(*input_file, 578 | input_p + 1, data + start_idx[s] - 1); 579 | 580 | // Skip to the next newline if we only care about 581 | // one match per line 582 | if (Opts::ONE_PER_LINE) { 583 | input_p = skip_chr(input_p, '\n', Opts::FORWARDS, 584 | Opts::FORWARDS ? end : data); 585 | 586 | next_start_idx[0] = input_p - data; 587 | 588 | // Reset all states for the next iteration except 589 | // the initial state (always at slot 0). Then 590 | // break to go on to the next character. 591 | *next_n_states = 1; 592 | break; 593 | } 594 | 595 | // If we allow overlapping matches, continue on to 596 | // the next state--all states that are in flight 597 | // are still valid and need to be processed. 598 | if (Opts::OVERLAPPING) 599 | continue; 600 | 601 | // No overlapping: any match that started before 602 | // this character can be thrown away. *But* we need 603 | // to process this character again from the start 604 | // state. So reset the states to just the initial 605 | // state, and go to the start 606 | // XXX this is ugly 607 | *n_states = 1; 608 | goto start; 609 | } 610 | 611 | // We potentially have a duplicate state. 
Since the state could 612 | // be anywhere in the state array before this state, we do a bulk 613 | // compare using vector instructions against the whole array. In 614 | // the loop we only compare up to the last part of the array that 615 | // fits entirely within a vector. The rest are compared with a 616 | // special case after. 617 | uint32_t s_rounded = s & -N_VE; 618 | for (uint32_t i = 0; i < s_rounded; i += N_VE) { 619 | V key = *(V *)&states[i]; 620 | vmask eq = test_eq(key, state_vec); 621 | if (eq) 622 | goto skip; 623 | } 624 | 625 | // Compare the last set of states before this one. We 626 | // compare all of them, but only test the bits below 627 | // the one corresponding to this state. 628 | V key = *(V *)&states[s_rounded]; 629 | vmask eq = test_eq(key, state_vec); 630 | if (eq & (1 << (s - s_rounded)) - 1) 631 | goto skip; 632 | } else 633 | seen.add(state); 634 | 635 | for (uint32_t i = 0; i < key_count[c]; i += N_VE) { 636 | // Load VL contiguous key bytes into one vector, and 637 | // compare the current state against all of them at once 638 | V key = *(V *)&edge_keys[c][i]; 639 | vmask eq = test_eq(key, state_vec); 640 | 641 | // If there was a match, there might be multiple 642 | // predecessor states from this state/input byte. Since 643 | // we sort the keys, we can just get the index of the 644 | // first index with a bitscan, find the last matches with 645 | // more vcmps and another bitscan, and copy all the bytes 646 | // in between at once. 647 | if (eq) { 648 | uint32_t start = bsf64(eq) / MOVEMASK_HACK + i; 649 | 650 | // Now that we have the start index, find the next 651 | // key that *isn't* equal to this state. We can 652 | // check for one within the same vector of keys that 653 | // the first key was in with a simple bitwise check: 654 | // all of the equal bits will be in one contiguous 655 | // group. If we add in the least significant bit of 656 | // the equality mask, we'll get a carry into the first 657 | // zero bit. 
If there aren't any unequal keys within 658 | // this vector after the equal keys, this addition 659 | // will overflow past the bits in the VCMP_BITS mask. 660 | vmask gt = eq + (eq & -eq); 661 | gt &= VCMP_BITS; 662 | 663 | // Loop through the rest of the keys until we find 664 | // the first unequal key. We always store at least one 665 | // sentinel at the end, so we don't need to check the 666 | // length here 667 | uint32_t j = i; 668 | for (; !gt; j += N_VE) { 669 | V key = *(V *)&edge_keys[c][j + N_VE]; 670 | gt = ~test_eq(key, state_vec); 671 | } 672 | 673 | uint32_t end = bsf64(gt) / MOVEMASK_HACK + j; 674 | 675 | // Copy an entire vectors' worth of values at a time, 676 | // possibly past the end of the array (we should have 677 | // enough space). We additionally copy over this state's 678 | // starting index, propagating it to all of its next states. 679 | VE *from = &edge_values[c][start]; 680 | VE *to = &next_states[*next_n_states]; 681 | uint32_t *to_idx = &next_start_idx[*next_n_states]; 682 | for (uint32_t x = start; x < end; 683 | x += N_VE, from += N_VE, to += N_VE) { 684 | *(V *)to = *(V *)from; 685 | 686 | // Weird loop thing: if our main NFA states are only 687 | // stored in one or two bytes, it takes more vector writes 688 | // to copy our start index over to all the next states 689 | for (uint32_t x_i = 0; x_i < sizeof(uint32_t) / sizeof(VE); 690 | x_i++, to_idx += (VL / sizeof(uint32_t))) 691 | *(V *)to_idx = start_idx_vec; 692 | } 693 | 694 | // Update the index to reflect only the valid values 695 | (*next_n_states) += end - start; 696 | break; 697 | } 698 | } 699 | skip: ; 700 | } 701 | 702 | // Flip the double buffers for the next iteration 703 | std::swap(states, next_states); 704 | std::swap(n_states, next_n_states); 705 | std::swap(start_idx, next_start_idx); 706 | } while (*n_states && 707 | (Opts::FORWARDS ? 
++input_p <= end : --input_p >= data)); 708 | 709 | return NULL; 710 | } 711 | 712 | // Hacky method to run as a full matcher, not just verifier 713 | typename MatchHandler::return_type run(File &f) { 714 | match_handler.start(); 715 | 716 | assert(Opts::CONTINUOUS); 717 | 718 | input_file = &f; 719 | verify(f.data, f.data + f.size - 1); 720 | 721 | return match_handler.finish(f); 722 | } 723 | }; 724 | 725 | typedef _Squamatus SquamatusVerifier8; 726 | typedef _Squamatus SquamatusVerifier16; 727 | typedef _Squamatus SquamatusVerifier32; 728 | 729 | //////////////////////////////////////////////////////////////////////////////// 730 | // Match handlers 731 | // These classes take verified matches and perform an action (count it, print 732 | // it out, etc.) 733 | //////////////////////////////////////////////////////////////////////////////// 734 | 735 | // Count all occurrences of a match, ignoring newlines 736 | class MatchHandlerBasicCounter { 737 | uint32_t _match_count; 738 | 739 | public: 740 | typedef uint32_t return_type; 741 | 742 | MatchHandlerBasicCounter() { } 743 | 744 | void start() { 745 | _match_count = 0; 746 | } 747 | 748 | void handle_match(UNUSED File &f, UNUSED const uint8_t *start, UNUSED const uint8_t *end) { 749 | _match_count++; 750 | } 751 | 752 | return_type finish(UNUSED File &f) { 753 | return _match_count; 754 | } 755 | }; 756 | 757 | // Basic grep-style handling: print out the full line, ignore other matches 758 | // on the same line. We support counting matching lines and printing the path. 
759 | class MatchHandlerPrintLine { 760 | bool print_path; 761 | bool print_matches; 762 | bool print_count; 763 | bool print_colors; 764 | uint32_t match_count; 765 | const uint8_t *current_match_end; 766 | const uint8_t *current_line_end; 767 | 768 | public: 769 | typedef uint32_t return_type; 770 | 771 | MatchHandlerPrintLine(bool print_path, bool print_matches, bool print_count, 772 | bool print_colors) : print_path(print_path), print_matches(print_matches), 773 | print_count(print_count), print_colors(print_colors) { } 774 | 775 | void start() { 776 | match_count = 0; 777 | current_match_end = NULL; 778 | current_line_end = NULL; 779 | } 780 | 781 | inline void flush_line() { 782 | if (print_matches && current_match_end) { 783 | std::string ending(current_match_end + 1, current_line_end); 784 | std::cout << ending << "\n"; 785 | current_match_end = NULL; 786 | } 787 | } 788 | 789 | inline void handle_match(File &f, const uint8_t *start, const uint8_t *end) { 790 | const uint8_t *pre_match; 791 | 792 | // Is this match on a new line? 
If so, we possibly need to flush the last 793 | // part of the last match, and then print the beginning of this line 794 | // XXX this doesn't work for the backwards matching case, which 795 | // REALLY REALLY MATTERS 796 | if (start >= current_line_end) { 797 | flush_line(); 798 | 799 | pre_match = skip_chr(start, '\n', false, f.data); 800 | current_line_end = skip_chr(end, '\n', true, f.data + f.size); 801 | if (print_path) 802 | std::cout << f.path << ":"; 803 | } 804 | // Otherwise, we print out everything after the end of the last match 805 | // first 806 | else { 807 | pre_match = current_match_end; 808 | // Weird overlapping case--only happens when we allow overlapping 809 | // matches 810 | if (pre_match >= start) 811 | start = pre_match + 1; 812 | } 813 | 814 | // Print out everything on this line (or after the last match) leading 815 | // up to this match, and then print this match, optionally with ANSI 816 | // color highlighting 817 | if (print_matches) { 818 | // These strings have additional copies that are kind of annoying 819 | std::string begin(pre_match + 1, start); 820 | std::string mid(start, end + 1); 821 | 822 | if (print_colors) 823 | std::cout << begin << "\033[31;40m" << mid 824 | << "\033[0m"; 825 | else 826 | std::cout << begin << mid; 827 | } 828 | 829 | match_count++; 830 | 831 | // Store where the last match ended. 
We will print the rest of the line 832 | // later, either with more matches or not, and need to know where we 833 | // stopped printing 834 | current_match_end = end; 835 | } 836 | 837 | return_type finish(File &f) { 838 | flush_line(); 839 | 840 | if (print_count) { 841 | if (print_path) 842 | std::cout << f.path << ":"; 843 | std::cout << match_count << "\n"; 844 | } 845 | return match_count; 846 | } 847 | }; 848 | 849 | //////////////////////////////////////////////////////////////////////////////// 850 | // Miroslav: the core SIMD string matching algorithm 851 | //////////////////////////////////////////////////////////////////////////////// 852 | template 854 | class _Miroslav { 855 | // Hacky substitute for "using" 856 | typedef typename VI::V V; 857 | typedef typename VI::vmask vmask; 858 | typedef typename VI::double_vmask double_vmask; 859 | static const uint32_t VL = VI::VL; 860 | static const uint32_t LMASK = VI::LMASK; 861 | 862 | static inline uint8_t state_mask(uint32_t state, uint32_t byte) { 863 | // If we only have one byte of state bits, start and end states 864 | // get the bottom and top bits, and all the other states rotate 865 | // through the other 6 bits 866 | if (N_BYTES == 1) { 867 | if (state == START_STATE) 868 | return 1 << 0; 869 | if (state == END_STATE) 870 | return 1 << 7; 871 | return 1 << (state % 6 + 1); 872 | } 873 | // For two or more state bytes, we use the top bit of the first 874 | // two bytes for the start and end states (for vpmovmskb convenience) 875 | // and all the other states rotate through the remaining bits 876 | else { 877 | if (state == START_STATE) 878 | return byte == 0 ? 1 << 7 : 0; 879 | if (state == END_STATE) 880 | return byte == 1 ? 
1 << 7 : 0; 881 | state = state % (N_BYTES * 8 - 2); 882 | // Skip over bit 7 (start) and 15 (end) 883 | state += (state >= 7); 884 | state += (state >= 15); 885 | // Return the proper bit within this byte if the high bits of the 886 | // state id match the byte number 887 | return ((state >> 3) == byte) << (state & 7); 888 | } 889 | } 890 | 891 | static inline uint8_t char_mask(uint8_t c, uint32_t shift) { 892 | return (c >> shift) & LMASK; 893 | } 894 | 895 | MatchVerifier match_verifier; 896 | MatchHandler &match_handler; 897 | 898 | V from_states[N_BYTES][N_SHIFTS]; 899 | V to_states[N_BYTES][N_SHIFTS]; 900 | V v_char_mask; 901 | 902 | // For branchless testing of whether a pattern has a 1-character match. 903 | // Cuts down on false matches when the start and end characters of a 904 | // pattern hash to the same character class, which would cause every 905 | // occurrence of a character in that class to generate a match 906 | vmask has_1_char_match; 907 | 908 | public: 909 | _Miroslav(NFAEdgeList &edges, MatchHandler &handler) : 910 | match_verifier(edges), match_handler(handler) { 911 | uint8_t from_state_bytes[N_BYTES][N_SHIFTS][VL] = {{{0}}}; 912 | uint8_t to_state_bytes[N_BYTES][N_SHIFTS][VL] = {{{0}}}; 913 | 914 | has_1_char_match = 0; 915 | 916 | // Initialize from/to state masks tables 917 | uint8_t c; 918 | uint32_t from, to; 919 | FOR_EACH_EDGE(c, from, to, edges) { 920 | for (uint32_t b = 0; b < N_BYTES; b++) { 921 | uint8_t fm = state_mask(from, b); 922 | uint8_t tm = state_mask(to, b); 923 | for (uint32_t s = 0; s < N_SHIFTS; s++) { 924 | from_state_bytes[b][s][char_mask(c, SHIFTS[s])] |= fm; 925 | to_state_bytes[b][s][char_mask(c, SHIFTS[s])] |= tm; 926 | } 927 | } 928 | 929 | // Check for 1-character matches and update the mask 930 | if (from == START_STATE && to == END_STATE) 931 | has_1_char_match = (vmask)-1; 932 | } 933 | 934 | for (uint32_t b = 0; b < N_BYTES; b++) { 935 | for (uint32_t s = 0; s < N_SHIFTS; s++) { 936 | 
VI::prepare_state_table(from_state_bytes[b][s]); 937 | VI::prepare_state_table(to_state_bytes[b][s]); 938 | from_states[b][s] = *(V*)from_state_bytes[b][s]; 939 | to_states[b][s] = *(V*)to_state_bytes[b][s]; 940 | } 941 | } 942 | 943 | v_char_mask = broadcast(LMASK); 944 | } 945 | 946 | typename MatchHandler::return_type run(File &f) { 947 | match_handler.start(); 948 | 949 | double_vmask carry, last_carry = 0; 950 | 951 | // Fill a vector for each byte of state mask with the starting state. This 952 | // vector tracks the state between iterations of the main loop. Only the 953 | // last byte of each of these vectors is ever used. 954 | V last_to[N_BYTES]; 955 | for (uint32_t b = 0; b < N_BYTES; b++) 956 | last_to[b] = broadcast(state_mask(START_STATE, b)); 957 | 958 | const uint8_t *chunk; 959 | for (chunk = f.data; chunk + VL <= f.data + f.size; chunk += VL) { 960 | V input = *(V*)chunk; 961 | 962 | vmask start_m, end_m, seq_m = 0; 963 | 964 | for (uint32_t b = 0; b < N_BYTES; b++) { 965 | V from, to; 966 | // For each of the shifts defined in the template arguments, do 967 | // a vector table lookup by shifting each byte by the shift, masking, 968 | // and doing a permute on the appropriate table vector. We AND the 969 | // results together for each of these shifts, which should narrow down 970 | // the number of false positives from different input bytes 971 | // mapping to the same character class. 
972 | for (uint32_t s = 0; s < N_SHIFTS; s++) { 973 | V masked_input = input; 974 | if (SHIFTS[s]) 975 | masked_input = VI::vec_shr(masked_input, SHIFTS[s]); 976 | masked_input = VI::vec_and(masked_input, v_char_mask); 977 | 978 | V f = VI::permute(from_states[b][s], masked_input); 979 | V t = VI::permute(to_states[b][s], masked_input); 980 | if (s == 0) { 981 | from = f; 982 | to = t; 983 | } else { 984 | from = VI::vec_and(from, f); 985 | to = VI::vec_and(to, t); 986 | } 987 | } 988 | 989 | // Get a vector of the to states, but shifted back in the data 990 | // stream by 1 byte. We fill the empty space in the first lane with 991 | // the last lane from last_to (which is initialized to the starting 992 | // state). 993 | V shl_to_1 = VI::vec_lanes_shl_1(to, last_to[b]); 994 | last_to[b] = to; 995 | 996 | // Test which input bytes can lead from the start state, and lead to 997 | // the end state. We handle the N_BYTES=1 and N cases differently, 998 | // because for 1, both bits are in the same byte, but otherwise, it's 999 | // a bit cheaper to have the start and end state bits in the top bit 1000 | // of the first two state bytes, since vpmovmskb only looks at the 1001 | // high bits of each byte in a vector. All this code should be 1002 | // unrolled/branchless, with start_m and end_m being set exactly once. 1003 | if (N_BYTES == 1) { 1004 | start_m = VI::test_low_bit(from); 1005 | end_m = VI::test_high_bit(to); 1006 | } else { 1007 | if (b == 0) 1008 | start_m = VI::test_high_bit(from); 1009 | else if (b == 1) 1010 | end_m = VI::test_high_bit(to); 1011 | } 1012 | 1013 | // Now find all input bytes that can come from some state that the 1014 | // previous input byte could lead to. 1015 | V seq = VI::vec_and(shl_to_1, from); 1016 | seq_m |= VI::test_nz(seq); 1017 | } 1018 | 1019 | // Test for potential matches. 
We use the ripple of carries through the 1020 | // "seq" mask to find sequences of input bytes that lead from the start 1021 | // state to the end state, while passing through valid state transitions. 1022 | // To be precise, since carries will ripple past the end mask, we find 1023 | // bits in the end mask that are cleared by a carry. This is slightly 1024 | // complicated by a couple factors: first, we have to keep track of 1025 | // carries across iterations (which is why we use the "double_vmask" 1026 | // type and shift carry right by VL each iteration), and second, extra 1027 | // bits in the start mask can make us think a carry didn't happen, so 1028 | // we clear out bits from start_m before testing the carries with 1029 | // the end mask. 1030 | carry = last_carry + ((double_vmask)start_m << 1) + seq_m; 1031 | vmask covered = end_m & seq_m; 1032 | vmask matches = ~(carry & ~start_m) & covered; 1033 | 1034 | // Check for 1-char matches, if they're possible 1035 | matches |= has_1_char_match & start_m & end_m; 1036 | 1037 | last_carry = carry >> VL; 1038 | 1039 | // Look through the bitset of all potential matches, and run a 1040 | // backwards verification step to weed out false positives. Any 1041 | // real matches we pass off to the match handler. 1042 | while (EXPECT(matches, 0)) { 1043 | const uint8_t *end = chunk + bsf64(matches); 1044 | matches &= matches - 1; 1045 | 1046 | const uint8_t *start = match_verifier.verify(f.data, end); 1047 | if (EXPECT(start != NULL, 0)) 1048 | match_handler.handle_match(f, start, end); 1049 | } 1050 | } 1051 | 1052 | // Run the slow backwards NFA on all the remainder bytes that don't fit 1053 | // in a vector register. We could probably read past the original 1054 | // buffer or do an unaligned read or something, but oh well. 
1055 | for (const uint8_t *end = chunk; end < f.data + f.size; end++) { 1056 | const uint8_t *start = match_verifier.verify(f.data, end); 1057 | if (start) 1058 | match_handler.handle_match(f, start, end); 1059 | } 1060 | 1061 | return match_handler.finish(f); 1062 | } 1063 | }; 1064 | 1065 | // Template alias using VEC_INFO class defined in the Makefile 1066 | static constexpr uint32_t SHIFTS[] = {0, 3}; 1067 | static constexpr uint32_t N_SHIFTS = sizeof(SHIFTS) / sizeof(SHIFTS[0]); 1068 | template 1069 | using Miroslav = _Miroslav; 1070 | --------------------------------------------------------------------------------