├── .github └── workflows │ └── semgrep.yml ├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── ac.cxx ├── ac.h ├── ac_fast.cxx ├── ac_fast.hpp ├── ac_lua.cxx ├── ac_slow.cxx ├── ac_slow.hpp ├── ac_util.hpp ├── load_ac.lua ├── mytest.cxx └── tests ├── Makefile ├── ac_bench.cxx ├── ac_test_aggr.cxx ├── ac_test_simple.cxx ├── dict ├── README.txt └── dict1.txt ├── load_ac_test.lua ├── lua_test.lua ├── test_base.hpp ├── test_bigfile.cxx └── test_main.cxx /.github/workflows/semgrep.yml: -------------------------------------------------------------------------------- 1 | 2 | on: 3 | pull_request: {} 4 | workflow_dispatch: {} 5 | push: 6 | branches: 7 | - main 8 | - master 9 | name: Semgrep config 10 | jobs: 11 | semgrep: 12 | name: semgrep/ci 13 | runs-on: ubuntu-20.04 14 | env: 15 | SEMGREP_APP_TOKEN: ${{ secrets.SEMGREP_APP_TOKEN }} 16 | SEMGREP_URL: https://cloudflare.semgrep.dev 17 | SEMGREP_APP_URL: https://cloudflare.semgrep.dev 18 | SEMGREP_VERSION_CHECK_URL: https://cloudflare.semgrep.dev/api/check-version 19 | container: 20 | image: returntocorp/semgrep 21 | steps: 22 | - uses: actions/checkout@v3 23 | - run: semgrep ci 24 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.d 2 | *.o 3 | *.a 4 | *.so 5 | *_dep.txt 6 | tests/testinput 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2014 CloudFlare, Inc. All rights reserved. 2 | 3 | Redistribution and use in source and binary forms, with or without 4 | modification, are permitted provided that the following conditions are 5 | met: 6 | 7 | * Redistributions of source code must retain the above copyright 8 | notice, this list of conditions and the following disclaimer. 9 | * Redistributions in binary form must reproduce the above 10 | copyright notice, this list of conditions and the following disclaimer 11 | in the documentation and/or other materials provided with the 12 | distribution. 13 | * Neither the name of CloudFlare, Inc. nor the names of its 14 | contributors may be used to endorse or promote products derived from 15 | this software without specific prior written permission. 16 | 17 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 18 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 19 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 20 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 21 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 22 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 23 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 24 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 25 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 26 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | 29 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | OS := $(shell uname) 2 | 3 | ifeq ($(OS), Darwin) 4 | SO_EXT := dylib 5 | else 6 | SO_EXT := so 7 | endif 8 | 9 | ############################################################################# 10 | # 11 | # Binaries we are going to build 12 | # 13 | ############################################################################# 14 | # 15 | C_SO_NAME = libac.$(SO_EXT) 16 | LUA_SO_NAME = ahocorasick.$(SO_EXT) 17 | AR_NAME = libac.a 18 | 19 | ############################################################################# 20 | # 21 | # Compile and link flags 22 | # 23 | ############################################################################# 24 | PREFIX ?= /usr/local 25 | LUA_VERSION := 5.1 26 | LUA_INCLUDE_DIR := $(PREFIX)/include/lua$(LUA_VERSION) 27 | SO_TARGET_DIR := $(PREFIX)/lib/lua/$(LUA_VERSION) 28 | LUA_TARGET_DIR := $(PREFIX)/share/lua/$(LUA_VERSION) 29 | 30 | # Available directives: 31 | # -DDEBUG : Turn on debugging support 32 | # -DVERIFY : To verify if the slow-version and fast-version implementations 33 | # get exactly the same result. Note -DVERIFY implies -DDEBUG. 34 | # 35 | COMMON_FLAGS = -O3 #-g -DVERIFY -msse2 -msse3 -msse4.1 36 | COMMON_FLAGS += -fvisibility=hidden -Wall $(CXXFLAGS) $(MY_CXXFLAGS) $(CPPFLAGS) 37 | 38 | SO_CXXFLAGS = $(COMMON_FLAGS) -fPIC 39 | SO_LFLAGS = $(COMMON_FLAGS) $(LDFLAGS) 40 | AR_CXXFLAGS = $(COMMON_FLAGS) 41 | 42 | # -DVERIFY implies -DDEBUG 43 | ifneq ($(findstring -DVERIFY, $(COMMON_FLAGS)), ) 44 | ifeq ($(findstring -DDEBUG, $(COMMON_FLAGS)), ) 45 | COMMON_FLAGS += -DDEBUG 46 | endif 47 | endif 48 | 49 | AR = ar 50 | AR_FLAGS = cru 51 | 52 | ############################################################################# 53 | # 54 | # Divide source codes and objects into several categories 55 | # 56 | ############################################################################# 57 | # 58 | SRC_COMMON := ac_fast.cxx ac_slow.cxx 59 | LIBAC_SO_SRC := $(SRC_COMMON) ac.cxx # source for libac.so 60 | LUA_SO_SRC := $(SRC_COMMON) ac_lua.cxx # source for ahocorasick.so 61 | LIBAC_A_SRC := $(LIBAC_SO_SRC) # source for libac.a 62 | 63 | ############################################################################# 64 | # 65 | # Make rules 66 | # 67 | ############################################################################# 68 | # 69 | .PHONY = all clean test benchmark prepare 70 | all : $(C_SO_NAME) $(LUA_SO_NAME) $(AR_NAME) 71 | 72 | -include c_so_dep.txt 73 | -include lua_so_dep.txt 74 | -include ar_dep.txt 75 | 76 | BUILD_SO_DIR := build_so 77 | BUILD_AR_DIR := build_ar 78 | 79 | $(BUILD_SO_DIR) :; mkdir $@ 80 | $(BUILD_AR_DIR) :; mkdir $@ 81 | 82 | $(BUILD_SO_DIR)/%.o : %.cxx | $(BUILD_SO_DIR) 83 | $(CXX) $< -c $(SO_CXXFLAGS) -I$(LUA_INCLUDE_DIR) -MMD -o $@ 84 | 85 | $(BUILD_AR_DIR)/%.o : %.cxx | $(BUILD_AR_DIR) 86 | $(CXX) $< -c $(AR_CXXFLAGS) -I$(LUA_INCLUDE_DIR) -MMD -o $@ 87 | 88 | ifneq ($(OS), Darwin) 89 | $(C_SO_NAME) : $(addprefix $(BUILD_SO_DIR)/, ${LIBAC_SO_SRC:.cxx=.o}) 90 | $(CXX) $+ -shared -Wl,-soname=$(C_SO_NAME) $(SO_LFLAGS) -o $@ 91 | cat $(addprefix $(BUILD_SO_DIR)/, ${LIBAC_SO_SRC:.cxx=.d}) > c_so_dep.txt 92 | 93 | $(LUA_SO_NAME) : $(addprefix $(BUILD_SO_DIR)/, ${LUA_SO_SRC:.cxx=.o}) 94 | $(CXX) $+ -shared -Wl,-soname=$(LUA_SO_NAME) $(SO_LFLAGS) -o $@ 95 | cat $(addprefix $(BUILD_SO_DIR)/, ${LUA_SO_SRC:.cxx=.d}) > lua_so_dep.txt 96 | 97 | else 98 | $(C_SO_NAME) : $(addprefix $(BUILD_SO_DIR)/, ${LIBAC_SO_SRC:.cxx=.o}) 99 | $(CXX) $+ -shared $(SO_LFLAGS) -o $@ 100 | cat $(addprefix $(BUILD_SO_DIR)/, ${LIBAC_SO_SRC:.cxx=.d}) > c_so_dep.txt 101 | 102 | $(LUA_SO_NAME) : $(addprefix $(BUILD_SO_DIR)/, ${LUA_SO_SRC:.cxx=.o}) 103 | $(CXX) $+ -shared $(SO_LFLAGS) -o $@ -Wl,-undefined,dynamic_lookup 104 | cat $(addprefix $(BUILD_SO_DIR)/, ${LUA_SO_SRC:.cxx=.d}) > lua_so_dep.txt 105 | endif 106 | 107 | $(AR_NAME) : $(addprefix $(BUILD_AR_DIR)/, ${LIBAC_A_SRC:.cxx=.o}) 108 | $(AR) $(AR_FLAGS) $@ $+ 109 | cat $(addprefix $(BUILD_AR_DIR)/, ${LIBAC_A_SRC:.cxx=.d}) > lua_so_dep.txt 110 | 111 | ############################################################################# 112 | # 113 | # Misc 114 | # 115 | ############################################################################# 116 | # 117 | test : $(C_SO_NAME) 118 | $(MAKE) -C tests && \ 119 | luajit tests/lua_test.lua && \ 120 | luajit tests/load_ac_test.lua 121 | 122 | benchmark: $(C_SO_NAME) 123 | $(MAKE) benchmark -C tests 124 | 125 | clean : 126 | -rm -rf *.o *.d c_so_dep.txt lua_so_dep.txt ar_dep.txt $(TEST) \ 127 | $(C_SO_NAME) $(LUA_SO_NAME) $(TEST) $(BUILD_SO_DIR) $(BUILD_AR_DIR) \ 128 | $(AR_NAME) 129 | make clean -C tests 130 | 131 | install: 132 | install -D -m 755 $(C_SO_NAME) $(DESTDIR)/$(SO_TARGET_DIR)/$(C_SO_NAME) 133 | install -D -m 755 $(LUA_SO_NAME) $(DESTDIR)/$(SO_TARGET_DIR)/$(LUA_SO_NAME) 134 | install -D -m 664 load_ac.lua $(DESTDIR)/$(LUA_TARGET_DIR)/load_ac.lua 135 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | aho-corasick-lua 2 | ================ 3 | 4 | C++ and Lua Implementation of the Aho-Corasick (AC) string matching algorithm 5 | (http://dl.acm.org/citation.cfm?id=360855). 6 | 7 | We began with pure Lua implementation and realize the performance is not 8 | satisfactory. So we switch to C/C++ implementation. 9 | 10 | There are two shared objects provied by this package: libac.so and ahocorasick.so 11 | The former is a regular shared object which can be directly used by C/C++ 12 | application, or by Lua via FFI; and the later is a Lua module. An example usage 13 | is shown below: 14 | 15 | ```lua 16 | local ac = require "ahocorasick" 17 | local dict = {"string1", "string", "etc"} 18 | local acinst = ac.create(dict) 19 | local r = ac.match(acinst, "mystring") 20 | ``` 21 | 22 | For efficiency reasons, the implementation is slightly different from the 23 | standard AC algorithm in that it doesn't return a set of strings in the dictionary 24 | that match the given string, instead it only returns one of them in case the string 25 | matches. The functionality of our implementation can be (precisely) described by 26 | following pseudo-c snippet. 27 | 28 | ```C 29 | string foo(input-string, dictionary) { 30 | string ret = the-end-of-input-string; 31 | for each string s in dictionary { 32 | // find the first occurrence match sub-string. 33 | ret = min(ret, strstr(input-string, s); 34 | } 35 | return ret; 36 | } 37 | ``` 38 | 39 | It's pretty easy to get rid of this limitation, just to associate each state with 40 | a spare bit-vector depicting the set of strings recognized by that state. 41 | -------------------------------------------------------------------------------- /ac.cxx: -------------------------------------------------------------------------------- 1 | // Interface functions for libac.so 2 | // 3 | #include "ac_slow.hpp" 4 | #include "ac_fast.hpp" 5 | #include "ac.h" 6 | 7 | static inline ac_result_t 8 | _match(buf_header_t* ac, const char* str, unsigned int len) { 9 | AC_Buffer* buf = (AC_Buffer*)(void*)ac; 10 | ASSERT(ac->magic_num == AC_MAGIC_NUM); 11 | 12 | ac_result_t r = Match(buf, str, len); 13 | 14 | #ifdef VERIFY 15 | { 16 | Match_Result r2 = buf->slow_impl->Match(str, len); 17 | if (r.match_begin != r2.begin) { 18 | ASSERT(0); 19 | } else { 20 | ASSERT((r.match_begin < 0) || 21 | (r.match_end == r2.end && 22 | r.pattern_idx == r2.pattern_idx)); 23 | } 24 | } 25 | #endif 26 | return r; 27 | } 28 | 29 | extern "C" int 30 | ac_match2(ac_t* ac, const char* str, unsigned int len) { 31 | ac_result_t r = _match((buf_header_t*)(void*)ac, str, len); 32 | return r.match_begin; 33 | } 34 | 35 | extern "C" ac_result_t 36 | ac_match(ac_t* ac, const char* str, unsigned int len) { 37 | return _match((buf_header_t*)(void*)ac, str, len); 38 | } 39 | 40 | extern "C" ac_result_t 41 | ac_match_longest_l(ac_t* ac, const char* str, unsigned int len) { 42 | AC_Buffer* buf = (AC_Buffer*)(void*)ac; 43 | ASSERT(((buf_header_t*)ac)->magic_num == AC_MAGIC_NUM); 44 | 45 | ac_result_t r = Match_Longest_L(buf, str, len); 46 | return r; 47 | } 48 | 49 | class BufAlloc : public Buf_Allocator { 50 | public: 51 | virtual AC_Buffer* alloc(int sz) { 52 | return (AC_Buffer*)(new unsigned char[sz]); 53 | } 54 | 55 | // Do not de-allocate the buffer when the BufAlloc die. 56 | virtual void free() {} 57 | 58 | static void myfree(AC_Buffer* buf) { 59 | ASSERT(buf->hdr.magic_num == AC_MAGIC_NUM); 60 | const char* b = (const char*)buf; 61 | delete[] b; 62 | } 63 | }; 64 | 65 | extern "C" ac_t* 66 | ac_create(const char** strv, unsigned int* strlenv, unsigned int v_len) { 67 | if (v_len >= 65535) { 68 | // TODO: Currently we use 16-bit to encode pattern-index (see the 69 | // comment to AC_State::is_term), therefore we are not able to 70 | // handle pattern set with more than 65535 entries. 71 | return 0; 72 | } 73 | 74 | ACS_Constructor *acc; 75 | #ifdef VERIFY 76 | acc = new ACS_Constructor; 77 | #else 78 | ACS_Constructor tmp; 79 | acc = &tmp; 80 | #endif 81 | acc->Construct(strv, strlenv, v_len); 82 | 83 | BufAlloc ba; 84 | AC_Converter cvt(*acc, ba); 85 | AC_Buffer* buf = cvt.Convert(); 86 | 87 | #ifdef VERIFY 88 | buf->slow_impl = acc; 89 | #endif 90 | return (ac_t*)(void*)buf; 91 | } 92 | 93 | extern "C" void 94 | ac_free(void* ac) { 95 | AC_Buffer* buf = (AC_Buffer*)ac; 96 | #ifdef VERIFY 97 | delete buf->slow_impl; 98 | #endif 99 | 100 | BufAlloc::myfree(buf); 101 | } 102 | -------------------------------------------------------------------------------- /ac.h: -------------------------------------------------------------------------------- 1 | #ifndef AC_H 2 | #define AC_H 3 | #ifdef __cplusplus 4 | extern "C" { 5 | #endif 6 | 7 | #define AC_EXPORT __attribute__ ((visibility ("default"))) 8 | 9 | /* If the subject-string doesn't match any of the given patterns, "match_begin" 10 | * should be a negative; otherwise the substring of the subject-string, 11 | * starting from offset "match_begin" to "match_end" incusively, 12 | * should exactly match the pattern specified by the 'pattern_idx' (i.e. 13 | * the pattern is "pattern_v[pattern_idx]" where the "pattern_v" is the 14 | * first actual argument passing to ac_create()) 15 | */ 16 | typedef struct { 17 | int match_begin; 18 | int match_end; 19 | int pattern_idx; 20 | } ac_result_t; 21 | 22 | struct ac_t; 23 | 24 | /* Create an AC instance. "pattern_v" is a vector of patterns, the length of 25 | * i-th pattern is specified by "pattern_len_v[i]"; the number of patterns 26 | * is specified by "vect_len". 27 | * 28 | * Return the instance on success, or NUL otherwise. 29 | */ 30 | ac_t* ac_create(const char** pattern_v, unsigned int* pattern_len_v, 31 | unsigned int vect_len) AC_EXPORT; 32 | 33 | ac_result_t ac_match(ac_t*, const char *str, unsigned int len) AC_EXPORT; 34 | 35 | ac_result_t ac_match_longest_l(ac_t*, const char *str, unsigned int len) AC_EXPORT; 36 | 37 | /* Similar to ac_match() except that it only returns match-begin. The rationale 38 | * for this interface is that luajit has hard time in dealing with strcture- 39 | * return-value. 40 | */ 41 | int ac_match2(ac_t*, const char *str, unsigned int len) AC_EXPORT; 42 | 43 | void ac_free(void*) AC_EXPORT; 44 | 45 | #ifdef __cplusplus 46 | } 47 | #endif 48 | 49 | #endif /* AC_H */ 50 | -------------------------------------------------------------------------------- /ac_fast.cxx: -------------------------------------------------------------------------------- 1 | #include // for std::sort 2 | #include "ac_slow.hpp" 3 | #include "ac_fast.hpp" 4 | 5 | uint32 6 | AC_Converter::Calc_State_Sz(const ACS_State* s) const { 7 | AC_State dummy; 8 | uint32 sz = offsetof(AC_State, input_vect); 9 | sz += s->Get_GotoNum() * sizeof(dummy.input_vect[0]); 10 | 11 | if (sz < sizeof(AC_State)) 12 | sz = sizeof(AC_State); 13 | 14 | uint32 align = __alignof__(dummy); 15 | sz = (sz + align - 1) & ~(align - 1); 16 | return sz; 17 | } 18 | 19 | AC_Buffer* 20 | AC_Converter::Alloc_Buffer() { 21 | const vector& all_states = _acs.Get_All_States(); 22 | const ACS_State* root_state = _acs.Get_Root_State(); 23 | uint32 root_fanout = root_state->Get_GotoNum(); 24 | 25 | // Step 1: Calculate the buffer size 26 | AC_Ofst root_goto_ofst, states_ofst_ofst, first_state_ofst; 27 | 28 | // part 1 : buffer header 29 | uint32 sz = root_goto_ofst = sizeof(AC_Buffer); 30 | 31 | // part 2: Root-node's goto function 32 | if (likely(root_fanout != 255)) 33 | sz += 256; 34 | else 35 | root_goto_ofst = 0; 36 | 37 | // part 3: mapping of state's relative position. 38 | unsigned align = __alignof__(AC_Ofst); 39 | sz = (sz + align - 1) & ~(align - 1); 40 | states_ofst_ofst = sz; 41 | 42 | sz += sizeof(AC_Ofst) * all_states.size(); 43 | 44 | // part 4: state's contents 45 | align = __alignof__(AC_State); 46 | sz = (sz + align - 1) & ~(align - 1); 47 | first_state_ofst = sz; 48 | 49 | uint32 state_sz = 0; 50 | for (vector::const_iterator i = all_states.begin(), 51 | e = all_states.end(); i != e; i++) { 52 | state_sz += Calc_State_Sz(*i); 53 | } 54 | state_sz -= Calc_State_Sz(root_state); 55 | 56 | sz += state_sz; 57 | 58 | // Step 2: Allocate buffer, and populate header. 59 | AC_Buffer* buf = _buf_alloc.alloc(sz); 60 | 61 | buf->hdr.magic_num = AC_MAGIC_NUM; 62 | buf->hdr.impl_variant = IMPL_FAST_VARIANT; 63 | buf->buf_len = sz; 64 | buf->root_goto_ofst = root_goto_ofst; 65 | buf->states_ofst_ofst = states_ofst_ofst; 66 | buf->first_state_ofst = first_state_ofst; 67 | buf->root_goto_num = root_fanout; 68 | buf->state_num = _acs.Get_State_Num(); 69 | return buf; 70 | } 71 | 72 | void 73 | AC_Converter::Populate_Root_Goto_Func(AC_Buffer* buf, 74 | GotoVect& goto_vect) { 75 | unsigned char *buf_base = (unsigned char*)(buf); 76 | InputTy* root_gotos = (InputTy*)(buf_base + buf->root_goto_ofst); 77 | const ACS_State* root_state = _acs.Get_Root_State(); 78 | 79 | root_state->Get_Sorted_Gotos(goto_vect); 80 | 81 | // Renumber the ID of root-node's immediate kids. 82 | uint32 new_id = 1; 83 | bool full_fantout = (goto_vect.size() == 255); 84 | if (likely(!full_fantout)) 85 | bzero(root_gotos, 256*sizeof(InputTy)); 86 | 87 | for (GotoVect::iterator i = goto_vect.begin(), e = goto_vect.end(); 88 | i != e; i++, new_id++) { 89 | InputTy c = i->first; 90 | ACS_State* s = i->second; 91 | _id_map[s->Get_ID()] = new_id; 92 | 93 | if (likely(!full_fantout)) 94 | root_gotos[c] = new_id; 95 | } 96 | } 97 | 98 | AC_Buffer* 99 | AC_Converter::Convert() { 100 | // Step 1: Some preparation stuff. 101 | GotoVect gotovect; 102 | 103 | _id_map.clear(); 104 | _ofst_map.clear(); 105 | _id_map.resize(_acs.Get_Next_Node_Id()); 106 | _ofst_map.resize(_acs.Get_Next_Node_Id()); 107 | 108 | // Step 2: allocate buffer to accommodate the entire AC graph. 109 | AC_Buffer* buf = Alloc_Buffer(); 110 | unsigned char* buf_base = (unsigned char*)buf; 111 | 112 | // Step 3: Root node need special care. 113 | Populate_Root_Goto_Func(buf, gotovect); 114 | buf->root_goto_num = gotovect.size(); 115 | _id_map[_acs.Get_Root_State()->Get_ID()] = 0; 116 | 117 | // Step 4: Converting the remaining states by BFSing the graph. 118 | // First of all, enter root's immediate kids to the working list. 119 | vector wl; 120 | State_ID id = 1; 121 | for (GotoVect::iterator i = gotovect.begin(), e = gotovect.end(); 122 | i != e; i++, id++) { 123 | ACS_State* s = i->second; 124 | wl.push_back(s); 125 | _id_map[s->Get_ID()] = id; 126 | } 127 | 128 | AC_Ofst* state_ofst_vect = (AC_Ofst*)(buf_base + buf->states_ofst_ofst); 129 | AC_Ofst ofst = buf->first_state_ofst; 130 | for (uint32 idx = 0; idx < wl.size(); idx++) { 131 | const ACS_State* old_s = wl[idx]; 132 | AC_State* new_s = (AC_State*)(buf_base + ofst); 133 | 134 | // This property should hold as we: 135 | // - States are appended to worklist in the BFS order. 136 | // - sibling states are appended to worklist in the order of their 137 | // corresponding input. 138 | // 139 | State_ID state_id = idx + 1; 140 | ASSERT(_id_map[old_s->Get_ID()] == state_id); 141 | 142 | state_ofst_vect[state_id] = ofst; 143 | 144 | new_s->first_kid = wl.size() + 1; 145 | new_s->depth = old_s->Get_Depth(); 146 | new_s->is_term = old_s->is_Terminal() ? 147 | old_s->get_Pattern_Idx() + 1 : 0; 148 | 149 | uint32 gotonum = old_s->Get_GotoNum(); 150 | new_s->goto_num = gotonum; 151 | 152 | // Populate the "input" field 153 | old_s->Get_Sorted_Gotos(gotovect); 154 | uint32 input_idx = 0; 155 | uint32 id = wl.size() + 1; 156 | InputTy* input_vect = new_s->input_vect; 157 | for (GotoVect::iterator i = gotovect.begin(), e = gotovect.end(); 158 | i != e; i++, id++, input_idx++) { 159 | input_vect[input_idx] = i->first; 160 | 161 | ACS_State* kid = i->second; 162 | _id_map[kid->Get_ID()] = id; 163 | wl.push_back(kid); 164 | } 165 | 166 | _ofst_map[old_s->Get_ID()] = ofst; 167 | ofst += Calc_State_Sz(old_s); 168 | } 169 | 170 | // This assertion might be useful to catch buffer overflow 171 | ASSERT(ofst == buf->buf_len); 172 | 173 | // Populate the fail-link field. 174 | for (vector::iterator i = wl.begin(), e = wl.end(); 175 | i != e; i++) { 176 | const ACS_State* slow_s = *i; 177 | State_ID fast_s_id = _id_map[slow_s->Get_ID()]; 178 | AC_State* fast_s = (AC_State*)(buf_base + state_ofst_vect[fast_s_id]); 179 | if (const ACS_State* fl = slow_s->Get_FailLink()) { 180 | State_ID id = _id_map[fl->Get_ID()]; 181 | fast_s->fail_link = id; 182 | } else 183 | fast_s->fail_link = 0; 184 | } 185 | #ifdef DEBUG 186 | //dump_buffer(buf, stderr); 187 | #endif 188 | return buf; 189 | } 190 | 191 | static inline AC_State* 192 | Get_State_Addr(unsigned char* buf_base, AC_Ofst* StateOfstVect, uint32 state_id) { 193 | ASSERT(state_id != 0 && "root node is handled in speical way"); 194 | ASSERT(state_id < ((AC_Buffer*)buf_base)->state_num); 195 | return (AC_State*)(buf_base + StateOfstVect[state_id]); 196 | } 197 | 198 | // The performance of the binary search is critical to this work. 199 | // 200 | // Here we provide two versions of binary-search functions. 201 | // The non-pristine version seems to consistently out-perform "pristine" one on 202 | // bunch of benchmarks we tested. With the benchmark under tests/testinput/ 203 | // 204 | // The speedup is following on my laptop (core i7, ubuntu): 205 | // 206 | // benchmark was is 207 | // ---------------------------------------- 208 | // image.bin 2.3s 2.0s 209 | // test.tar 6.7s 5.7s 210 | // 211 | // NOTE: As of I write this comment, we only measure the performance on about 212 | // 10+ benchmarks. It's still too early to say which one works better. 213 | // 214 | #if !defined(BS_MULTI_VER) 215 | static bool __attribute__((always_inline)) inline 216 | Binary_Search_Input(InputTy* input_vect, int vect_len, InputTy input, int& idx) { 217 | if (vect_len <= 8) { 218 | for (int i = 0; i < vect_len; i++) { 219 | if (input_vect[i] == input) { 220 | idx = i; 221 | return true; 222 | } 223 | } 224 | return false; 225 | } 226 | 227 | // The "low" and "high" must be signed integers, as they could become -1. 228 | // Also since they are signed integer, "(low + high)/2" is slightly more 229 | // expensive than (low+high)>>1 or ((unsigned)(low + high))/2. 230 | // 231 | int low = 0, high = vect_len - 1; 232 | while (low <= high) { 233 | int mid = (low + high) >> 1; 234 | InputTy mid_c = input_vect[mid]; 235 | 236 | if (input < mid_c) 237 | high = mid - 1; 238 | else if (input > mid_c) 239 | low = mid + 1; 240 | else { 241 | idx = mid; 242 | return true; 243 | } 244 | } 245 | return false; 246 | } 247 | 248 | #else 249 | 250 | /* Let us call this version "pristine" version. */ 251 | static inline bool 252 | Binary_Search_Input(InputTy* input_vect, int vect_len, InputTy input, int& idx) { 253 | int low = 0, high = vect_len - 1; 254 | while (low <= high) { 255 | int mid = (low + high) >> 1; 256 | InputTy mid_c = input_vect[mid]; 257 | 258 | if (input < mid_c) 259 | high = mid - 1; 260 | else if (input > mid_c) 261 | low = mid + 1; 262 | else { 263 | idx = mid; 264 | return true; 265 | } 266 | } 267 | return false; 268 | } 269 | #endif 270 | 271 | typedef enum { 272 | // Look for the first match. e.g. pattern set = {"ab", "abc", "def"}, 273 | // subject string "ababcdef". The first match would be "ab" at the 274 | // beginning of the subject string. 275 | MV_FIRST_MATCH, 276 | 277 | // Look for the left-most longest match. Follow above example; there are 278 | // two longest matches, "abc" and "def", and the left-most longest match 279 | // is "abc". 280 | MV_LEFT_LONGEST, 281 | 282 | // Similar to the left-most longest match, except that it returns the 283 | // *right* most longest match. Follow above example, the match would 284 | // be "def". NYI. 285 | MV_RIGHT_LONGEST, 286 | 287 | // Return all patterns that match that given subject string. NYI. 288 | MV_ALL_MATCHES, 289 | } MATCH_VARIANT; 290 | 291 | /* The Match_Tmpl is the template for vairants MV_FIRST_MATCH, MV_LEFT_LONGEST, 292 | * MV_RIGHT_LONGEST (If we really really need MV_RIGHT_LONGEST variant, we are 293 | * better off implementing it in a separate function). 294 | * 295 | * The Match_Tmpl supports three variants at once "symbolically", once it's 296 | * instanced to a particular variants, all the code irrelevant to the variants 297 | * will be statically removed. So don't worry about the code like 298 | * "if (variant == MV_XXXX)"; they will not incur any penalty. 299 | * 300 | * The drawback of using template is increased code size. Unfortunately, there 301 | * is no silver bullet. 302 | */ 303 | template static ac_result_t 304 | Match_Tmpl(AC_Buffer* buf, const char* str, uint32 len) { 305 | unsigned char* buf_base = (unsigned char*)(buf); 306 | unsigned char* root_goto = buf_base + buf->root_goto_ofst; 307 | AC_Ofst* states_ofst_vect = (AC_Ofst* )(buf_base + buf->states_ofst_ofst); 308 | 309 | AC_State* state = 0; 310 | uint32 idx = 0; 311 | 312 | // Skip leading chars that are not valid input of root-nodes. 313 | if (likely(buf->root_goto_num != 255)) { 314 | while(idx < len) { 315 | unsigned char c = str[idx++]; 316 | if (unsigned char kid_id = root_goto[c]) { 317 | state = Get_State_Addr(buf_base, states_ofst_vect, kid_id); 318 | break; 319 | } 320 | } 321 | } else { 322 | idx = 1; 323 | state = Get_State_Addr(buf_base, states_ofst_vect, *str); 324 | } 325 | 326 | ac_result_t r = {-1, -1}; 327 | if (likely(state != 0)) { 328 | if (unlikely(state->is_term)) { 329 | /* Dictionary may have string of length 1 */ 330 | r.match_begin = idx - state->depth; 331 | r.match_end = idx - 1; 332 | r.pattern_idx = state->is_term - 1; 333 | 334 | if (variant == MV_FIRST_MATCH) { 335 | return r; 336 | } 337 | } 338 | } 339 | 340 | while (idx < len) { 341 | unsigned char c = str[idx]; 342 | int res; 343 | bool found; 344 | found = Binary_Search_Input(state->input_vect, state->goto_num, c, res); 345 | if (found) { 346 | // The "t = goto(c, current_state)" is valid, advance to state "t". 347 | uint32 kid = state->first_kid + res; 348 | state = Get_State_Addr(buf_base, states_ofst_vect, kid); 349 | idx++; 350 | } else { 351 | // Follow the fail-link. 352 | State_ID fl = state->fail_link; 353 | if (fl == 0) { 354 | // fail-link is root-node, which implies the root-node doesn't 355 | // have 255 valid transitions (otherwise, the fail-link should 356 | // points to "goto(root, c)"), so we don't need speical handling 357 | // as we did before this while-loop is entered. 358 | // 359 | while(idx < len) { 360 | InputTy c = str[idx++]; 361 | if (unsigned char kid_id = root_goto[c]) { 362 | state = 363 | Get_State_Addr(buf_base, states_ofst_vect, kid_id); 364 | break; 365 | } 366 | } 367 | } else { 368 | state = Get_State_Addr(buf_base, states_ofst_vect, fl); 369 | } 370 | } 371 | 372 | // Check to see if the state is terminal state? 373 | if (state->is_term) { 374 | if (variant == MV_FIRST_MATCH) { 375 | ac_result_t r; 376 | r.match_begin = idx - state->depth; 377 | r.match_end = idx - 1; 378 | r.pattern_idx = state->is_term - 1; 379 | return r; 380 | } 381 | 382 | if (variant == MV_LEFT_LONGEST) { 383 | int match_begin = idx - state->depth; 384 | int match_end = idx - 1; 385 | 386 | if (r.match_begin == -1 || 387 | match_end - match_begin > r.match_end - r.match_begin) { 388 | r.match_begin = match_begin; 389 | r.match_end = match_end; 390 | r.pattern_idx = state->is_term - 1; 391 | } 392 | continue; 393 | } 394 | 395 | ASSERT(false && "NYI"); 396 | } 397 | } 398 | 399 | return r; 400 | } 401 | 402 | ac_result_t 403 | Match(AC_Buffer* buf, const char* str, uint32 len) { 404 | return Match_Tmpl(buf, str, len); 405 | } 406 | 407 | ac_result_t 408 | Match_Longest_L(AC_Buffer* buf, const char* str, uint32 len) { 409 | return Match_Tmpl(buf, str, len); 410 | } 411 | 412 | #ifdef DEBUG 413 | void 414 | AC_Converter::dump_buffer(AC_Buffer* buf, FILE* f) { 415 | vector state_ofst; 416 | state_ofst.resize(_id_map.size()); 417 | 418 | fprintf(f, "Id maps between old/slow and new/fast graphs\n"); 419 | int old_id = 0; 420 | for (vector::iterator i = _id_map.begin(), e = _id_map.end(); 421 | i != e; i++, old_id++) { 422 | State_ID new_id = *i; 423 | if (new_id != 0) { 424 | fprintf(f, "%d -> %d, ", old_id, new_id); 425 | } 426 | } 427 | fprintf(f, "\n"); 428 | 429 | int idx = 0; 430 | for (vector::iterator i = _id_map.begin(), e = _id_map.end(); 431 | i != e; i++, idx++) { 432 | uint32 id = *i; 433 | if (id == 0) continue; 434 | state_ofst[id] = _ofst_map[idx]; 435 | } 436 | 437 | unsigned char* buf_base = (unsigned char*)buf; 438 | 439 | // dump root goto-function. 440 | fprintf(f, "root, fanout:%d goto {", buf->root_goto_num); 441 | if (buf->root_goto_num != 255) { 442 | unsigned char* root_goto = buf_base + buf->root_goto_ofst; 443 | for (uint32 i = 0; i < 255; i++) { 444 | if (root_goto[i] != 0) 445 | fprintf(f, "%c->S:%d, ", (unsigned char)i, root_goto[i]); 446 | } 447 | } else { 448 | fprintf(f, "full fanout\n"); 449 | } 450 | fprintf(f, "}\n"); 451 | 452 | // dump remaining states. 453 | AC_Ofst* state_ofst_vect = (AC_Ofst*)(buf_base + buf->states_ofst_ofst); 454 | for (uint32 i = 1, e = buf->state_num; i < e; i++) { 455 | AC_Ofst ofst = state_ofst_vect[i]; 456 | ASSERT(ofst == state_ofst[i]); 457 | fprintf(f, "S:%d, ofst:%d, goto={", i, ofst); 458 | 459 | AC_State* s = (AC_State*)(buf_base + ofst); 460 | State_ID kid = s->first_kid; 461 | for (uint32 k = 0, ke = s->goto_num; k < ke; k++, kid++) 462 | fprintf(f, "%c->S:%d, ", s->input_vect[k], kid); 463 | 464 | fprintf(f, "}, fail-link = S:%d, %s\n", s->fail_link, 465 | s->is_term ? "terminal" : ""); 466 | } 467 | } 468 | #endif 469 | -------------------------------------------------------------------------------- /ac_fast.hpp: -------------------------------------------------------------------------------- 1 | #ifndef AC_FAST_H 2 | #define AC_FAST_H 3 | 4 | #include 5 | #include "ac.h" 6 | #include "ac_slow.hpp" 7 | 8 | using namespace std; 9 | 10 | class ACS_Constructor; 11 | 12 | typedef uint32 AC_Ofst; 13 | typedef uint32 State_ID; 14 | 15 | // The entire "fast" AC graph is converted from its "slow" version, and store 16 | // in an consecutive trunk of memory or "buffer". Since the pointers in the 17 | // fast AC graph are represented as offset relative to the base address of 18 | // the buffer, this fast AC graph is position-independent, meaning cloning 19 | // the fast graph is just to memcpy the entire buffer. 20 | // 21 | // The buffer is laid-out as following: 22 | // 23 | // 1. The buffer header. (i.e. the AC_Buffer content) 24 | // 2. root-node's goto functions. It is represented as an array indiced by 25 | // root-node's valid inputs, and the element is the ID of the corresponding 26 | // transition state (aka kid). To save space, we used 8-bit to represent 27 | // the IDs. ID of root's kids starts with 1. 28 | // 29 | // Root may have 255 valid inputs. In this speical case, i-th element 30 | // stores value i -- i.e the i-th state. So, we don't need such array 31 | // at all. On the other hand, 8-bit is insufficient to encode kids' ID. 32 | // 33 | // 3. An array indiced by state's id, and the element is the offset 34 | // of corresponding state wrt the base address of the buffer. 35 | // 36 | // 4. the contents of states. 37 | // 38 | typedef struct { 39 | buf_header_t hdr; // The header exposed to the user using this lib. 40 | #ifdef VERIFY 41 | ACS_Constructor* slow_impl; 42 | #endif 43 | uint32 buf_len; 44 | AC_Ofst root_goto_ofst; // addr of root node's goto() function. 45 | AC_Ofst states_ofst_ofst; // addr of state pointer vector (indiced by id) 46 | AC_Ofst first_state_ofst; // addr of the first state in the buffer. 47 | uint16 root_goto_num; // fan-out of root-node. 48 | uint16 state_num; // number of states 49 | 50 | // Followed by the gut of the buffer: 51 | // 1. map: root's-valid-input -> kid's id 52 | // 2. map: state's ID -> offset of the state 53 | // 3. states' content. 54 | } AC_Buffer; 55 | 56 | // Depict the state of "fast" AC graph. 57 | typedef struct { 58 | // transition are sorted. For instance, state s1, has two transitions : 59 | // goto(b) -> S_b, goto(a)->S_a. The inputs are sorted in the ascending 60 | // order, and the target states are permuted accordingly. In this case, 61 | // the inputs are sorted as : a, b, and the target states are permuted 62 | // into S_a, S_b. So, S_a is the 1st kid, the ID of kids are consecutive, 63 | // so we don't need to save all the target kids. 64 | // 65 | State_ID first_kid; 66 | AC_Ofst fail_link; 67 | short depth; // How far away from root. 68 | unsigned short is_term; // Is terminal node. if is_term != 0, it encodes 69 | // the value of "1 + pattern-index". 70 | unsigned char goto_num; // The number of valid transition. 71 | InputTy input_vect[1]; // Vector of valid input. Must be last field! 72 | } AC_State; 73 | 74 | class Buf_Allocator { 75 | public: 76 | Buf_Allocator() : _buf(0) {} 77 | virtual ~Buf_Allocator() { free(); } 78 | 79 | virtual AC_Buffer* alloc(int sz) = 0; 80 | virtual void free() {}; 81 | protected: 82 | AC_Buffer* _buf; 83 | }; 84 | 85 | // Convert slow-AC-graph into fast one. 86 | class AC_Converter { 87 | public: 88 | AC_Converter(ACS_Constructor& acs, Buf_Allocator& ba) : 89 | _acs(acs), _buf_alloc(ba) {} 90 | AC_Buffer* Convert(); 91 | 92 | private: 93 | // Return the size in byte needed to to save the specified state. 94 | uint32 Calc_State_Sz(const ACS_State *) const; 95 | 96 | // In fast-AC-graph, the ID is bit trikcy. Given a state of slow-graph, 97 | // this function is to return the ID of its counterpart in the fast-graph. 98 | State_ID Get_Renumbered_Id(const ACS_State *s) const { 99 | const vector &m = _id_map; 100 | return m[s->Get_ID()]; 101 | } 102 | 103 | AC_Buffer* Alloc_Buffer(); 104 | void Populate_Root_Goto_Func(AC_Buffer *, GotoVect&); 105 | 106 | #ifdef DEBUG 107 | void dump_buffer(AC_Buffer*, FILE*); 108 | #endif 109 | 110 | private: 111 | ACS_Constructor& _acs; 112 | Buf_Allocator& _buf_alloc; 113 | 114 | // map: ID of state in slow-graph -> ID of counterpart in fast-graph. 115 | vector _id_map; 116 | 117 | // map: ID of state in slow-graph -> offset of counterpart in fast-graph. 118 | vector _ofst_map; 119 | }; 120 | 121 | ac_result_t Match(AC_Buffer* buf, const char* str, uint32 len); 122 | ac_result_t Match_Longest_L(AC_Buffer* buf, const char* str, uint32 len); 123 | 124 | #endif // AC_FAST_H 125 | -------------------------------------------------------------------------------- /ac_lua.cxx: -------------------------------------------------------------------------------- 1 | // Interface functions for libac.so 2 | // 3 | #include 4 | #include 5 | #include "ac_slow.hpp" 6 | #include "ac_fast.hpp" 7 | #include "ac.h" // for the definition of ac_result_t 8 | #include "ac_util.hpp" 9 | 10 | extern "C" { 11 | #include 12 | #include 13 | } 14 | 15 | #if defined(USE_SLOW_VER) 16 | #error "Not going to implement it" 17 | #endif 18 | 19 | using namespace std; 20 | static const char* tname = "aho-corasick"; 21 | 22 | class BufAlloc : public Buf_Allocator { 23 | public: 24 | BufAlloc(lua_State* L) : _L(L) {} 25 | virtual AC_Buffer* alloc(int sz) { 26 | return (AC_Buffer*)lua_newuserdata (_L, sz); 27 | } 28 | 29 | // Let GC to take care. 30 | virtual void free() {} 31 | 32 | private: 33 | lua_State* _L; 34 | }; 35 | 36 | static bool 37 | _create_helper(lua_State* L, const vector& str_v, 38 | const vector& strlen_v) { 39 | ASSERT(str_v.size() == strlen_v.size()); 40 | 41 | ACS_Constructor acc; 42 | BufAlloc ba(L); 43 | 44 | // Step 1: construct the slow version. 45 | unsigned int strnum = str_v.size(); 46 | const char** str_vect = new const char*[strnum]; 47 | unsigned int* strlen_vect = new unsigned int[strnum]; 48 | 49 | int idx = 0; 50 | for (vector::const_iterator i = str_v.begin(), e = str_v.end(); 51 | i != e; i++) { 52 | str_vect[idx++] = *i; 53 | } 54 | 55 | idx = 0; 56 | for (vector::const_iterator i = strlen_v.begin(), 57 | e = strlen_v.end(); i != e; i++) { 58 | strlen_vect[idx++] = *i; 59 | } 60 | 61 | acc.Construct(str_vect, strlen_vect, idx); 62 | delete[] str_vect; 63 | delete[] strlen_vect; 64 | 65 | // Step 2: convert to fast version 66 | AC_Converter cvt(acc, ba); 67 | return cvt.Convert() != 0; 68 | } 69 | 70 | static ac_result_t 71 | _match_helper(buf_header_t* ac, const char *str, unsigned int len) { 72 | AC_Buffer* buf = (AC_Buffer*)(void*)ac; 73 | ASSERT(ac->magic_num == AC_MAGIC_NUM); 74 | 75 | ac_result_t r = Match(buf, str, len); 76 | return r; 77 | } 78 | 79 | // LUA semantic: 80 | // input: array of strings 81 | // output: userdata containing the AC-graph (i.e. the AC_Buffer). 82 | // 83 | static int 84 | lac_create(lua_State* L) { 85 | // The table of the array must be the 1st argument. 86 | int input_tab = 1; 87 | 88 | luaL_checktype(L, input_tab, LUA_TTABLE); 89 | 90 | // Init the "iteartor". 91 | lua_pushnil(L); 92 | 93 | vector str_v; 94 | vector strlen_v; 95 | 96 | // Loop over the elements 97 | while (lua_next(L, input_tab)) { 98 | size_t str_len; 99 | const char* s = luaL_checklstring(L, -1, &str_len); 100 | str_v.push_back(s); 101 | strlen_v.push_back(str_len); 102 | 103 | // remove the value, but keep the key as the iterator. 104 | lua_pop(L, 1); 105 | } 106 | 107 | // pop the nil value 108 | lua_pop(L, 1); 109 | 110 | if (_create_helper(L, str_v, strlen_v)) { 111 | // The AC graph, as a userdata is already pushed to the stack, hence 1. 112 | return 1; 113 | } 114 | 115 | return 0; 116 | } 117 | 118 | // LUA input: 119 | // arg1: the userdata, representing the AC graph, returned from l_create(). 120 | // arg2: the string to be matched. 121 | // 122 | // LUA return: 123 | // if match, return index range of the match; otherwise nil is returned. 124 | // 125 | static int 126 | lac_match(lua_State* L) { 127 | buf_header_t* ac = (buf_header_t*)lua_touserdata(L, 1); 128 | if (!ac) { 129 | luaL_checkudata(L, 1, tname); 130 | return 0; 131 | } 132 | 133 | size_t len; 134 | const char* str; 135 | #if LUA_VERSION_NUM >= 502 136 | str = luaL_tolstring(L, 2, &len); 137 | #else 138 | str = lua_tolstring(L, 2, &len); 139 | #endif 140 | if (!str) { 141 | luaL_checkstring(L, 2); 142 | return 0; 143 | } 144 | 145 | ac_result_t r = _match_helper(ac, str, len); 146 | if (r.match_begin != -1) { 147 | lua_pushinteger(L, r.match_begin); 148 | lua_pushinteger(L, r.match_end); 149 | return 2; 150 | } 151 | 152 | return 0; 153 | } 154 | 155 | static const struct luaL_Reg lib_funcs[] = { 156 | { "create", lac_create }, 157 | { "match", lac_match }, 158 | {0, 0} 159 | }; 160 | 161 | extern "C" int AC_EXPORT 162 | luaopen_ahocorasick(lua_State* L) { 163 | luaL_newmetatable(L, tname); 164 | 165 | #if LUA_VERSION_NUM == 501 166 | luaL_register(L, tname, lib_funcs); 167 | #elif LUA_VERSION_NUM >= 502 168 | luaL_newlib(L, lib_funcs); 169 | #else 170 | #error "Don't know how to do it right" 171 | #endif 172 | return 1; 173 | } 174 | -------------------------------------------------------------------------------- /ac_slow.cxx: -------------------------------------------------------------------------------- 1 | #include 2 | #include // for bzero 3 | #include 4 | #include "ac_slow.hpp" 5 | #include "ac.h" 6 | 7 | ////////////////////////////////////////////////////////////////////////// 8 | // 9 | // Implementation of AhoCorasick_Slow 10 | // 11 | ////////////////////////////////////////////////////////////////////////// 12 | // 13 | ACS_Constructor::ACS_Constructor() : _next_node_id(1) { 14 | _root = new_state(); 15 | _root_char = new InputTy[256]; 16 | bzero((void*)_root_char, 256); 17 | 18 | #ifdef VERIFY 19 | _pattern_buf = 0; 20 | #endif 21 | } 22 | 23 | ACS_Constructor::~ACS_Constructor() { 24 | for (std::vector::iterator i = _all_states.begin(), 25 | e = _all_states.end(); i != e; i++) { 26 | delete *i; 27 | } 28 | _all_states.clear(); 29 | delete[] _root_char; 30 | 31 | #ifdef VERIFY 32 | delete[] _pattern_buf; 33 | #endif 34 | } 35 | 36 | ACS_State* 37 | ACS_Constructor::new_state() { 38 | ACS_State* t = new ACS_State(_next_node_id++); 39 | _all_states.push_back(t); 40 | return t; 41 | } 42 | 43 | void 44 | ACS_Constructor::Add_Pattern(const char* str, unsigned int str_len, 45 | int pattern_idx) { 46 | ACS_State* state = _root; 47 | for (unsigned int i = 0; i < str_len; i++) { 48 | const char c = str[i]; 49 | ACS_State* new_s = state->Get_Goto(c); 50 | if (!new_s) { 51 | new_s = new_state(); 52 | new_s->_depth = state->_depth + 1; 53 | state->Set_Goto(c, new_s); 54 | } 55 | state = new_s; 56 | } 57 | state->_is_terminal = true; 58 | state->set_Pattern_Idx(pattern_idx); 59 | } 60 | 61 | void 62 | ACS_Constructor::Propagate_faillink() { 63 | ACS_State* r = _root; 64 | std::vector wl; 65 | 66 | const ACS_Goto_Map& m = r->Get_Goto_Map(); 67 | for (ACS_Goto_Map::const_iterator i = m.begin(), e = m.end(); i != e; i++) { 68 | ACS_State* s = i->second; 69 | s->_fail_link = r; 70 | wl.push_back(s); 71 | } 72 | 73 | // For any input c, make sure "goto(root, c)" is valid, which make the 74 | // fail-link propagation lot easier. 75 | ACS_Goto_Map goto_save = r->_goto_map; 76 | for (uint32 i = 0; i <= 255; i++) { 77 | ACS_State* s = r->Get_Goto(i); 78 | if (!s) r->Set_Goto(i, r); 79 | } 80 | 81 | for (uint32 i = 0; i < wl.size(); i++) { 82 | ACS_State* s = wl[i]; 83 | ACS_State* fl = s->_fail_link; 84 | 85 | const ACS_Goto_Map& tran_map = s->Get_Goto_Map(); 86 | 87 | for (ACS_Goto_Map::const_iterator ii = tran_map.begin(), 88 | ee = tran_map.end(); ii != ee; ii++) { 89 | InputTy c = ii->first; 90 | ACS_State *tran = ii->second; 91 | 92 | ACS_State* tran_fl = 0; 93 | for (ACS_State* fl_walk = fl; ;) { 94 | if (ACS_State* t = fl_walk->Get_Goto(c)) { 95 | tran_fl = t; 96 | break; 97 | } else { 98 | fl_walk = fl_walk->Get_FailLink(); 99 | } 100 | } 101 | 102 | tran->_fail_link = tran_fl; 103 | wl.push_back(tran); 104 | } 105 | } 106 | 107 | // Remove "goto(root, c) == root" transitions 108 | r->_goto_map = goto_save; 109 | } 110 | 111 | void 112 | ACS_Constructor::Construct(const char** strv, unsigned int* strlenv, 113 | uint32 strnum) { 114 | Save_Patterns(strv, strlenv, strnum); 115 | 116 | for (uint32 i = 0; i < strnum; i++) { 117 | Add_Pattern(strv[i], strlenv[i], i); 118 | } 119 | 120 | Propagate_faillink(); 121 | unsigned char* p = _root_char; 122 | 123 | const ACS_Goto_Map& m = _root->Get_Goto_Map(); 124 | for (ACS_Goto_Map::const_iterator i = m.begin(), e = m.end(); 125 | i != e; i++) { 126 | p[i->first] = 1; 127 | } 128 | } 129 | 130 | Match_Result 131 | ACS_Constructor::MatchHelper(const char *str, uint32 len) const { 132 | const ACS_State* root = _root; 133 | const ACS_State* state = root; 134 | 135 | uint32 idx = 0; 136 | while (idx < len) { 137 | InputTy c = str[idx]; 138 | idx++; 139 | if (_root_char[c]) { 140 | state = root->Get_Goto(c); 141 | break; 142 | } 143 | } 144 | 145 | if (unlikely(state->is_Terminal())) { 146 | // This could happen if the one of the pattern has only one char! 147 | uint32 pos = idx - 1; 148 | Match_Result r(pos - state->Get_Depth() + 1, pos, 149 | state->get_Pattern_Idx()); 150 | return r; 151 | } 152 | 153 | while (idx < len) { 154 | InputTy c = str[idx]; 155 | ACS_State* gs = state->Get_Goto(c); 156 | 157 | if (!gs) { 158 | ACS_State* fl = state->Get_FailLink(); 159 | if (fl == root) { 160 | while (idx < len) { 161 | InputTy c = str[idx]; 162 | idx++; 163 | if (_root_char[c]) { 164 | state = root->Get_Goto(c); 165 | break; 166 | } 167 | } 168 | } else { 169 | state = fl; 170 | } 171 | } else { 172 | idx ++; 173 | state = gs; 174 | } 175 | 176 | if (state->is_Terminal()) { 177 | uint32 pos = idx - 1; 178 | Match_Result r = Match_Result(pos - state->Get_Depth() + 1, pos, 179 | state->get_Pattern_Idx()); 180 | return r; 181 | } 182 | } 183 | 184 | return Match_Result(-1, -1, -1); 185 | } 186 | 187 | #ifdef DEBUG 188 | void 189 | ACS_Constructor::dump_text(const char* txtfile) const { 190 | FILE* f = fopen(txtfile, "w+"); 191 | for (std::vector::const_iterator i = _all_states.begin(), 192 | e = _all_states.end(); i != e; i++) { 193 | ACS_State* s = *i; 194 | 195 | fprintf(f, "S%d goto:{", s->Get_ID()); 196 | const ACS_Goto_Map& goto_func = s->Get_Goto_Map(); 197 | 198 | for (ACS_Goto_Map::const_iterator i = goto_func.begin(), e = goto_func.end(); 199 | i != e; i++) { 200 | InputTy input = i->first; 201 | ACS_State* tran = i->second; 202 | if (isprint(input)) 203 | fprintf(f, "'%c' -> S:%d,", input, tran->Get_ID()); 204 | else 205 | fprintf(f, "%#x -> S:%d,", input, tran->Get_ID()); 206 | } 207 | fprintf(f, "} "); 208 | 209 | if (s->_fail_link) { 210 | fprintf(f, ", fail=S:%d", s->_fail_link->Get_ID()); 211 | } 212 | 213 | if (s->_is_terminal) { 214 | fprintf(f, ", terminal"); 215 | } 216 | 217 | fprintf(f, "\n"); 218 | } 219 | fclose(f); 220 | } 221 | 222 | void 223 | ACS_Constructor::dump_dot(const char *dotfile) const { 224 | FILE* f = fopen(dotfile, "w+"); 225 | const char* indent = " "; 226 | 227 | fprintf(f, "digraph G {\n"); 228 | 229 | // Emit node information 230 | fprintf(f, "%s%d [style=filled];\n", indent, _root->Get_ID()); 231 | for (std::vector::const_iterator i = _all_states.begin(), 232 | e = _all_states.end(); i != e; i++) { 233 | ACS_State *s = *i; 234 | if (s->_is_terminal) { 235 | fprintf(f, "%s%d [shape=doublecircle];\n", indent, s->Get_ID()); 236 | } 237 | } 238 | fprintf(f, "\n"); 239 | 240 | // Emit edge information 241 | for (std::vector::const_iterator i = _all_states.begin(), 242 | e = _all_states.end(); i != e; i++) { 243 | ACS_State* s = *i; 244 | uint32 id = s->Get_ID(); 245 | 246 | const ACS_Goto_Map& m = s->Get_Goto_Map(); 247 | for (ACS_Goto_Map::const_iterator ii = m.begin(), ee = m.end(); 248 | ii != ee; ii++) { 249 | InputTy input = ii->first; 250 | ACS_State* tran = ii->second; 251 | if (isalnum(input)) 252 | fprintf(f, "%s%d -> %d [label=%c];\n", 253 | indent, id, tran->Get_ID(), input); 254 | else 255 | fprintf(f, "%s%d -> %d [label=\"%#x\"];\n", 256 | indent, id, tran->Get_ID(), input); 257 | 258 | } 259 | 260 | // Emit fail-link 261 | ACS_State* fl = s->Get_FailLink(); 262 | if (fl && fl != _root) { 263 | fprintf(f, "%s%d -> %d [style=dotted, color=red]; \n", 264 | indent, id, fl->Get_ID()); 265 | } 266 | } 267 | fprintf(f, "}\n"); 268 | fclose(f); 269 | } 270 | #endif 271 | 272 | #ifdef VERIFY 273 | void 274 | ACS_Constructor::Verify_Result(const char* subject, const Match_Result* r) 275 | const { 276 | if (r->begin >= 0) { 277 | unsigned len = r->end - r->begin + 1; 278 | int ptn_idx = r->pattern_idx; 279 | 280 | ASSERT(ptn_idx >= 0 && 281 | len == get_ith_Pattern_Len(ptn_idx) && 282 | memcmp(subject + r->begin, get_ith_Pattern(ptn_idx), len) == 0); 283 | } 284 | } 285 | 286 | void 287 | ACS_Constructor::Save_Patterns(const char** strv, unsigned int* strlenv, 288 | int pattern_num) { 289 | // calculate the total size needed to save all patterns. 290 | // 291 | int buf_size = 0; 292 | for (int i = 0; i < pattern_num; i++) { buf_size += strlenv[i]; } 293 | 294 | // HINT: patterns are delimited by '\0' in order to ease debugging. 295 | buf_size += pattern_num; 296 | ASSERT(_pattern_buf == 0); 297 | _pattern_buf = new char[buf_size + 1]; 298 | #define MAGIC_NUM 0x5a 299 | _pattern_buf[buf_size] = MAGIC_NUM; 300 | 301 | int ofst = 0; 302 | _pattern_lens.resize(pattern_num); 303 | _pattern_vect.resize(pattern_num); 304 | for (int i = 0; i < pattern_num; i++) { 305 | int l = strlenv[i]; 306 | _pattern_lens[i] = l; 307 | _pattern_vect[i] = _pattern_buf + ofst; 308 | 309 | memcpy(_pattern_buf + ofst, strv[i], l); 310 | ofst += l; 311 | _pattern_buf[ofst++] = '\0'; 312 | } 313 | 314 | ASSERT(_pattern_buf[buf_size] == MAGIC_NUM); 315 | #undef MAGIC_NUM 316 | } 317 | 318 | #endif 319 | -------------------------------------------------------------------------------- /ac_slow.hpp: -------------------------------------------------------------------------------- 1 | #ifndef MY_AC_H 2 | #define MY_AC_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include // for std::sort 9 | #include "ac_util.hpp" 10 | 11 | // Forward decl. the acronym "ACS" stands for "Aho-Corasick Slow implementation" 12 | class ACS_State; 13 | class ACS_Constructor; 14 | class AhoCorasick; 15 | 16 | using namespace std; 17 | 18 | typedef std::map ACS_Goto_Map; 19 | 20 | class Match_Result { 21 | public: 22 | int begin; 23 | int end; 24 | int pattern_idx; 25 | Match_Result(int b, int e, int p): begin(b), end(e), pattern_idx(p) {} 26 | }; 27 | 28 | typedef pair GotoPair; 29 | typedef vector GotoVect; 30 | 31 | // Sorting functor 32 | class GotoSort { 33 | public: 34 | bool operator() (const GotoPair& g1, const GotoPair& g2) { 35 | return g1.first < g2.first; 36 | } 37 | }; 38 | 39 | class ACS_State { 40 | friend class ACS_Constructor; 41 | 42 | public: 43 | ACS_State(uint32 id): _id(id), _pattern_idx(-1), _depth(0), 44 | _is_terminal(false), _fail_link(0){} 45 | ~ACS_State() {}; 46 | 47 | void Set_Goto(InputTy c, ACS_State* s) { _goto_map[c] = s; } 48 | ACS_State *Get_Goto(InputTy c) const { 49 | ACS_Goto_Map::const_iterator iter = _goto_map.find(c); 50 | return iter != _goto_map.end() ? (*iter).second : 0; 51 | } 52 | 53 | // Return all transitions sorted in the ascending order of their input. 54 | void Get_Sorted_Gotos(GotoVect& Gotos) const { 55 | const ACS_Goto_Map& m = _goto_map; 56 | Gotos.clear(); 57 | for (ACS_Goto_Map::const_iterator i = m.begin(), e = m.end(); 58 | i != e; i++) { 59 | Gotos.push_back(GotoPair(i->first, i->second)); 60 | } 61 | sort(Gotos.begin(), Gotos.end(), GotoSort()); 62 | } 63 | 64 | ACS_State* Get_FailLink() const { return _fail_link; } 65 | uint32 Get_GotoNum() const { return _goto_map.size(); } 66 | uint32 Get_ID() const { return _id; } 67 | uint32 Get_Depth() const { return _depth; } 68 | const ACS_Goto_Map& Get_Goto_Map(void) const { return _goto_map; } 69 | bool is_Terminal() const { return _is_terminal; } 70 | int get_Pattern_Idx() const { 71 | ASSERT(is_Terminal() && _pattern_idx >= 0); 72 | return _pattern_idx; 73 | } 74 | 75 | private: 76 | void set_Pattern_Idx(int idx) { 77 | ASSERT(is_Terminal()); 78 | _pattern_idx = idx; 79 | } 80 | 81 | private: 82 | uint32 _id; 83 | int _pattern_idx; 84 | short _depth; 85 | bool _is_terminal; 86 | ACS_Goto_Map _goto_map; 87 | ACS_State* _fail_link; 88 | }; 89 | 90 | class ACS_Constructor { 91 | public: 92 | ACS_Constructor(); 93 | ~ACS_Constructor(); 94 | 95 | void Construct(const char** strv, unsigned int* strlenv, 96 | unsigned int strnum); 97 | 98 | Match_Result Match(const char* s, uint32 len) const { 99 | Match_Result r = MatchHelper(s, len); 100 | Verify_Result(s, &r); 101 | return r; 102 | } 103 | 104 | Match_Result Match(const char* s) const { return Match(s, strlen(s)); } 105 | 106 | #ifdef DEBUG 107 | void dump_text(const char* = "ac.txt") const; 108 | void dump_dot(const char* = "ac.dot") const; 109 | #endif 110 | const ACS_State *Get_Root_State() const { return _root; } 111 | const vector& Get_All_States() const { 112 | return _all_states; 113 | } 114 | 115 | uint32 Get_Next_Node_Id() const { return _next_node_id; } 116 | uint32 Get_State_Num() const { return _next_node_id - 1; } 117 | 118 | private: 119 | void Add_Pattern(const char* str, unsigned int str_len, int pattern_idx); 120 | ACS_State* new_state(); 121 | void Propagate_faillink(); 122 | 123 | Match_Result MatchHelper(const char*, uint32 len) const; 124 | 125 | #ifdef VERIFY 126 | void Verify_Result(const char* subject, const Match_Result* r) const; 127 | void Save_Patterns(const char** strv, unsigned int* strlenv, int vect_len); 128 | const char* get_ith_Pattern(unsigned i) const { 129 | ASSERT(i < _pattern_vect.size()); 130 | return _pattern_vect.at(i); 131 | } 132 | unsigned get_ith_Pattern_Len(unsigned i) const { 133 | ASSERT(i < _pattern_lens.size()); 134 | return _pattern_lens.at(i); 135 | } 136 | #else 137 | void Verify_Result(const char* subject, const Match_Result* r) const { 138 | (void)subject; (void)r; 139 | } 140 | void Save_Patterns(const char** strv, unsigned int* strlenv, int vect_len) { 141 | (void)strv; (void)strlenv; 142 | } 143 | #endif 144 | 145 | private: 146 | ACS_State* _root; 147 | vector _all_states; 148 | unsigned char* _root_char; 149 | uint32 _next_node_id; 150 | 151 | #ifdef VERIFY 152 | char* _pattern_buf; 153 | vector _pattern_lens; 154 | vector _pattern_vect; 155 | #endif 156 | }; 157 | 158 | #endif 159 | -------------------------------------------------------------------------------- /ac_util.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (c) 2014 CloudFlare, Inc. All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are 6 | met: 7 | 8 | * Redistributions of source code must retain the above copyright 9 | notice, this list of conditions and the following disclaimer. 10 | * Redistributions in binary form must reproduce the above 11 | copyright notice, this list of conditions and the following disclaimer 12 | in the documentation and/or other materials provided with the 13 | distribution. 14 | * Neither the name of CloudFlare, Inc. nor the names of its 15 | contributors may be used to endorse or promote products derived from 16 | this software without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 22 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | */ 30 | #ifndef AC_UTIL_H 31 | #define AC_UTIL_H 32 | 33 | #ifdef DEBUG 34 | #include // for fprintf 35 | #include // for abort 36 | #endif 37 | 38 | typedef unsigned short uint16; 39 | typedef unsigned int uint32; 40 | typedef unsigned long uint64; 41 | typedef unsigned char InputTy; 42 | 43 | #ifdef DEBUG 44 | // Usage examples: ASSERT(a > b), ASSERT(foo() && "Oops, foo() returned 0"); 45 | #define ASSERT(c) if (!(c))\ 46 | { fprintf(stderr, "%s:%d Assert: %s\n", __FILE__, __LINE__, #c); abort(); } 47 | #else 48 | #define ASSERT(c) ((void)0) 49 | #endif 50 | 51 | #define likely(x) __builtin_expect((x),1) 52 | #define unlikely(x) __builtin_expect((x),0) 53 | 54 | #ifndef offsetof 55 | #define offsetof(st, m) ((size_t)(&((st *)0)->m)) 56 | #endif 57 | 58 | typedef enum { 59 | IMPL_SLOW_VARIANT = 1, 60 | IMPL_FAST_VARIANT = 2, 61 | } impl_var_t; 62 | 63 | #define AC_MAGIC_NUM 0x5a 64 | typedef struct { 65 | unsigned char magic_num; 66 | unsigned char impl_variant; 67 | } buf_header_t; 68 | 69 | #endif //AC_UTIL_H 70 | -------------------------------------------------------------------------------- /load_ac.lua: -------------------------------------------------------------------------------- 1 | -- Helper wrappring script for loading shared object libac.so (FFI interface) 2 | -- from package.cpath instead of LD_LIBRARTY_PATH. 3 | -- 4 | 5 | local ffi = require 'ffi' 6 | ffi.cdef[[ 7 | void* ac_create(const char** str_v, unsigned int* strlen_v, 8 | unsigned int v_len); 9 | int ac_match2(void*, const char *str, int len); 10 | void ac_free(void*); 11 | ]] 12 | 13 | local _M = {} 14 | 15 | local string_gmatch = string.gmatch 16 | local string_match = string.match 17 | 18 | local ac_lib = nil 19 | local ac_create = nil 20 | local ac_match = nil 21 | local ac_free = nil 22 | 23 | --[[ Find shared object file package.cpath, obviating the need of setting 24 | LD_LIBRARY_PATH 25 | ]] 26 | local function find_shared_obj(cpath, so_name) 27 | for k, v in string_gmatch(cpath, "[^;]+") do 28 | local so_path = string_match(k, "(.*/)") 29 | if so_path then 30 | -- "so_path" could be nil. e.g, the dir path component is "." 31 | so_path = so_path .. so_name 32 | 33 | -- Don't get me wrong, the only way to know if a file exist is 34 | -- trying to open it. 35 | local f = io.open(so_path) 36 | if f ~= nil then 37 | io.close(f) 38 | return so_path 39 | end 40 | end 41 | end 42 | end 43 | 44 | function _M.load_ac_lib() 45 | if ac_lib ~= nil then 46 | return ac_lib 47 | else 48 | local so_path = find_shared_obj(package.cpath, "libac.so") 49 | if so_path ~= nil then 50 | ac_lib = ffi.load(so_path) 51 | ac_create = ac_lib.ac_create 52 | ac_match = ac_lib.ac_match2 53 | ac_free = ac_lib.ac_free 54 | return ac_lib 55 | end 56 | end 57 | end 58 | 59 | -- Create an Aho-Corasick instance, and return the instance if it was 60 | -- successful. 61 | function _M.create_ac(dict) 62 | local strnum = #dict 63 | if ac_lib == nil then 64 | _M.load_ac_lib() 65 | end 66 | 67 | local str_v = ffi.new("const char *[?]", strnum) 68 | local strlen_v = ffi.new("unsigned int [?]", strnum) 69 | 70 | for i = 1, strnum do 71 | local s = dict[i] 72 | str_v[i - 1] = s 73 | strlen_v[i - 1] = #s 74 | end 75 | 76 | local ac = ac_create(str_v, strlen_v, strnum); 77 | if ac ~= nil then 78 | return ffi.gc(ac, ac_free) 79 | end 80 | end 81 | 82 | -- Return nil if str doesn't match the dictionary, else return non-nil. 83 | function _M.match(ac, str) 84 | local r = ac_match(ac, str, #str); 85 | if r >= 0 then 86 | return r 87 | end 88 | end 89 | 90 | return _M 91 | -------------------------------------------------------------------------------- /mytest.cxx: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include "ac.h" 5 | 6 | using namespace std; 7 | 8 | ///////////////////////////////////////////////////////////////////////// 9 | // 10 | // Test using strings from input files 11 | // 12 | ///////////////////////////////////////////////////////////////////////// 13 | // 14 | class BigFileTester { 15 | public: 16 | BigFileTester(const char* filepath); 17 | 18 | private: 19 | void Genector 20 | privaete: 21 | const char* _msg; 22 | int _msg_len; 23 | int _key_num; // number of strings in dictionary 24 | int _key_len_idx; 25 | }; 26 | 27 | ///////////////////////////////////////////////////////////////////////// 28 | // 29 | // Simple (yet maybe tricky) testings 30 | // 31 | ///////////////////////////////////////////////////////////////////////// 32 | // 33 | typedef struct { 34 | const char* str; 35 | const char* match; 36 | } StrPair; 37 | 38 | typedef struct { 39 | const char* name; 40 | const char** dict; 41 | StrPair* strpairs; 42 | int dict_len; 43 | int strpair_num; 44 | } TestingCase; 45 | 46 | class Tests { 47 | public: 48 | Tests(const char* name, 49 | const char* dict[], int dict_len, 50 | StrPair strpairs[], int strpair_num) { 51 | if (!_tests) 52 | _tests = new vector; 53 | 54 | TestingCase tc; 55 | tc.name = name; 56 | tc.dict = dict; 57 | tc.strpairs = strpairs; 58 | tc.dict_len = dict_len; 59 | tc.strpair_num = strpair_num; 60 | _tests->push_back(tc); 61 | } 62 | 63 | static vector* Get_Tests() { return _tests; } 64 | static void Erase_Tests() { delete _tests; _tests = 0; } 65 | 66 | private: 67 | static vector *_tests; 68 | }; 69 | 70 | vector* Tests::_tests = 0; 71 | 72 | static void 73 | simple_test(void) { 74 | int total = 0; 75 | int fail = 0; 76 | 77 | vector *tests = Tests::Get_Tests(); 78 | if (!tests) 79 | return 0; 80 | 81 | for (vector::iterator i = tests->begin(), e = tests->end(); 82 | i != e; i++) { 83 | TestingCase& t = *i; 84 | fprintf(stdout, ">Testing %s\nDictionary:[ ", t.name); 85 | for (int i = 0, e = t.dict_len, need_break=0; i < e; i++) { 86 | fprintf(stdout, "%s, ", t.dict[i]); 87 | if (need_break++ == 16) { 88 | fputs("\n ", stdout); 89 | need_break = 0; 90 | } 91 | } 92 | fputs("]\n", stdout); 93 | 94 | /* Create the dictionary */ 95 | int dict_len = t.dict_len; 96 | ac_t* ac = ac_create(t.dict, dict_len); 97 | 98 | for (int ii = 0, ee = t.strpair_num; ii < ee; ii++, total++) { 99 | const StrPair& sp = t.strpairs[ii]; 100 | const char *str = sp.str; // the string to be matched 101 | const char *match = sp.match; 102 | 103 | fprintf(stdout, "[%3d] Testing '%s' : ", total, str); 104 | 105 | int len = strlen(str); 106 | ac_result_t r = ac_match(ac, str, len); 107 | int m_b = r.match_begin; 108 | int m_e = r.match_end; 109 | 110 | // The return value per se is insane. 111 | if (m_b > m_e || 112 | ((m_b < 0 || m_e < 0) && (m_b != -1 || m_e != -1))) { 113 | fprintf(stdout, "Insane return value (%d, %d)\n", m_b, m_e); 114 | fail ++; 115 | continue; 116 | } 117 | 118 | // If the string is not supposed to match the dictionary. 119 | if (!match) { 120 | if (m_b != -1 || m_e != -1) { 121 | fail ++; 122 | fprintf(stdout, "Not Supposed to match (%d, %d) \n", 123 | m_b, m_e); 124 | } else 125 | fputs("Pass\n", stdout); 126 | continue; 127 | } 128 | 129 | // The string or its substring is match the dict. 130 | if (m_b >= len || m_b >= len) { 131 | fail ++; 132 | fprintf(stdout, 133 | "Return value >= the length of the string (%d, %d)\n", 134 | m_b, m_e); 135 | continue; 136 | } else { 137 | int mlen = strlen(match); 138 | if ((mlen != m_e - m_b + 1) || 139 | strncmp(str + m_b, match, mlen)) { 140 | fail ++; 141 | fprintf(stdout, "Fail\n"); 142 | } else 143 | fprintf(stdout, "Pass\n"); 144 | } 145 | } 146 | fputs("\n", stdout); 147 | ac_free(ac); 148 | } 149 | 150 | fprintf(stdout, "Total : %d, Fail %d\n", total, fail); 151 | 152 | return fail ? -1 : 0; 153 | } 154 | 155 | int 156 | main (int argc, char** argv) { 157 | int res = simple_test(); 158 | return res; 159 | }; 160 | 161 | /* test 1*/ 162 | const char *dict1[] = {"he", "she", "his", "her"}; 163 | StrPair strpair1[] = { 164 | {"he", "he"}, {"she", "she"}, {"his", "his"}, 165 | {"hers", "he"}, {"ahe", "he"}, {"shhe", "he"}, 166 | {"shis2", "his"}, {"ahhe", "he"} 167 | }; 168 | Tests test1("test 1", 169 | dict1, sizeof(dict1)/sizeof(dict1[0]), 170 | strpair1, sizeof(strpair1)/sizeof(strpair1[0])); 171 | 172 | /* test 2*/ 173 | const char *dict2[] = {"poto", "poto"}; /* duplicated strings*/ 174 | StrPair strpair2[] = {{"The pot had a handle", 0}}; 175 | Tests test2("test 2", dict2, 2, strpair2, 1); 176 | 177 | /* test 3*/ 178 | const char *dict3[] = {"The"}; 179 | StrPair strpair3[] = {{"The pot had a handle", "The"}}; 180 | Tests test3("test 3", dict3, 1, strpair3, 1); 181 | 182 | /* test 4*/ 183 | const char *dict4[] = {"pot"}; 184 | StrPair strpair4[] = {{"The pot had a handle", "pot"}}; 185 | Tests test4("test 4", dict4, 1, strpair4, 1); 186 | 187 | /* test 5*/ 188 | const char *dict5[] = {"pot "}; 189 | StrPair strpair5[] = {{"The pot had a handle", "pot "}}; 190 | Tests test5("test 5", dict5, 1, strpair5, 1); 191 | 192 | /* test 6*/ 193 | const char *dict6[] = {"ot h"}; 194 | StrPair strpair6[] = {{"The pot had a handle", "ot h"}}; 195 | Tests test6("test 6", dict6, 1, strpair6, 1); 196 | 197 | /* test 7*/ 198 | const char *dict7[] = {"andle"}; 199 | StrPair strpair7[] = {{"The pot had a handle", "andle"}}; 200 | Tests test7("test 7", dict7, 1, strpair7, 1); 201 | -------------------------------------------------------------------------------- /tests/Makefile: -------------------------------------------------------------------------------- 1 | OS := $(shell uname) 2 | ifeq ($(OS), Darwin) 3 | SO_EXT := dylib 4 | else 5 | SO_EXT := so 6 | endif 7 | 8 | .PHONY = all clean test runtest benchmark 9 | 10 | PROGRAM = ac_test 11 | BENCHMARK = ac_bench 12 | all: runtest 13 | 14 | CXXFLAGS = -O3 -g -march=native -Wall -DDEBUG 15 | MYCXXFLAGS = -MMD -I.. $(CXXFLAGS) 16 | %.o : %.cxx 17 | $(CXX) $< -c $(MYCXXFLAGS) 18 | 19 | -include dep.cxx 20 | SRC = test_main.cxx ac_test_simple.cxx ac_test_aggr.cxx test_bigfile.cxx 21 | 22 | OBJ = ${SRC:.cxx=.o} 23 | 24 | -include test_dep.txt 25 | -include bench_dep.txt 26 | 27 | $(PROGRAM) $(BENCHMARK) : testinput/text.tar testinput/image.bin 28 | $(PROGRAM) : $(OBJ) ../libac.$(SO_EXT) 29 | $(CXX) $(OBJ) -L.. -lac -o $@ 30 | -cat *.d > test_dep.txt 31 | 32 | $(BENCHMARK) : ac_bench.o ../libac.$(SO_EXT) 33 | $(CXX) ac_bench.o -L.. -lac -o $@ 34 | -cat *.d > bench_dep.txt 35 | 36 | ifneq ($(OS), Darwin) 37 | runtest:$(PROGRAM) 38 | LD_LIBRARY_PATH=$(LD_LIBRARY_PATH):.. ./$(PROGRAM) testinput/* 39 | 40 | benchmark:$(BENCHMARK) 41 | LD_LIBRARY_PATH=$(LD_LIBRARY_PATH):.. ./ac_bench 42 | 43 | else 44 | runtest:$(PROGRAM) 45 | DYLD_LIBRARY_PATH=$(DYLD_LIBRARY_PATH):.. ./$(PROGRAM) testinput/* 46 | 47 | benchmark:$(BENCHMARK) 48 | DYLD_LIBRARY_PATH=$(DYLD_LIBRARY_PATH):.. ./ac_bench 49 | 50 | endif 51 | 52 | testinput/text.tar: 53 | echo "download testing files (gcc tarball)..." 54 | if [ ! -d testinput ] ; then mkdir testinput; fi 55 | cd testinput && \ 56 | curl ftp://ftp.gnu.org/gnu/gcc/gcc-1.42.tar.gz -o text.tar.gz 2>/dev/null \ 57 | && gzip -d text.tar.gz 58 | 59 | testinput/image.bin: 60 | echo "download testing files.." 61 | if [ ! -d testinput ] ; then mkdir testinput; fi 62 | curl http://www.3dvisionlive.com/sites/default/files/Curiosity_render_hiresb.jpg -o $@ 2>/dev/null 63 | 64 | clean: 65 | -rm -f *.o *.d dep.txt $(PROGRAM) $(BENCHMARK) 66 | -------------------------------------------------------------------------------- /tests/ac_bench.cxx: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | #include 17 | #include 18 | #include "ac.h" 19 | #include "ac_util.hpp" 20 | 21 | using namespace std; 22 | 23 | static bool SomethingWrong = false; 24 | 25 | static int iteration = 300; 26 | static string dict_dir; 27 | static string obj_file_dir; 28 | static bool print_help = false; 29 | static int piece_size = 1024; 30 | 31 | class PatternSet { 32 | public: 33 | PatternSet(const char* filepath); 34 | ~PatternSet() { Cleanup(); } 35 | 36 | int getPatternNum() const { return _pat_num; } 37 | const char** getPatternVector() const { return _patterns; } 38 | unsigned int* getPatternLenVector() const { return _pat_len; } 39 | 40 | const char* getErrMessage() const { return _errmsg; } 41 | static bool isDictFile(const char* filepath) { 42 | if (strncmp(basename(const_cast(filepath)), "dict", 4)) 43 | return false; 44 | return true; 45 | } 46 | 47 | private: 48 | bool ExtractPattern(const char* filepath); 49 | void Cleanup(); 50 | 51 | const char** _patterns; 52 | unsigned int* _pat_len; 53 | char* _mmap; 54 | int _fd; 55 | size_t _mmap_size; 56 | int _pat_num; 57 | 58 | const char* _errmsg; 59 | }; 60 | 61 | bool 62 | PatternSet::ExtractPattern(const char* filepath) { 63 | if (!isDictFile(filepath)) 64 | return false; 65 | 66 | struct stat filestat; 67 | if (stat(filepath, &filestat)) { 68 | _errmsg = "fail to call stat()"; 69 | return false; 70 | } 71 | 72 | if (filestat.st_size > 4096 * 1024) { 73 | /* It doesn't seem to be a dictionary file*/ 74 | _errmsg = "file too big?"; 75 | return false; 76 | } 77 | 78 | _fd = open(filepath, 0); 79 | if (_fd == -1) { 80 | _errmsg = "fail to open dictionary file"; 81 | return false; 82 | } 83 | 84 | _mmap_size = filestat.st_size; 85 | _mmap = (char*)mmap(0, filestat.st_size, PROT_READ|PROT_WRITE, 86 | MAP_PRIVATE, _fd, 0); 87 | if (_mmap == MAP_FAILED) { 88 | _errmsg = "fail to call mmap"; 89 | return false; 90 | } 91 | 92 | const char* pat = _mmap; 93 | vector pat_vect; 94 | vector pat_len_vect; 95 | 96 | for (size_t i = 0, e = filestat.st_size; i < e; i++) { 97 | if (_mmap[i] == '\r' || _mmap[i] == '\n') { 98 | _mmap[i] = '\0'; 99 | int len = _mmap + i - pat; 100 | if (len > 0) { 101 | pat_vect.push_back(pat); 102 | pat_len_vect.push_back(len); 103 | } 104 | pat = _mmap + i + 1; 105 | } 106 | } 107 | 108 | ASSERT(pat_vect.size() == pat_len_vect.size()); 109 | 110 | int pat_num = pat_vect.size(); 111 | if (pat_num > 0) { 112 | const char** p = _patterns = new const char*[pat_num]; 113 | int i = 0; 114 | for (vector::iterator iter = pat_vect.begin(), 115 | iter_e = pat_vect.end(); iter != iter_e; ++iter) { 116 | p[i++] = *iter; 117 | } 118 | 119 | i = 0; 120 | unsigned int* q = _pat_len = new unsigned int[pat_num]; 121 | for (vector::iterator iter = pat_len_vect.begin(), 122 | iter_e = pat_len_vect.end(); iter != iter_e; ++iter) { 123 | q[i++] = *iter; 124 | } 125 | } 126 | 127 | _pat_num = pat_num; 128 | if (pat_num <= 0) { 129 | _errmsg = "no pattern at all"; 130 | return false; 131 | } 132 | 133 | return true; 134 | } 135 | 136 | void 137 | PatternSet::Cleanup() { 138 | if (_mmap != MAP_FAILED) { 139 | munmap(_mmap, _mmap_size); 140 | _mmap = (char*)MAP_FAILED; 141 | _mmap_size = 0; 142 | } 143 | 144 | delete[] _patterns; 145 | delete[] _pat_len; 146 | if (_fd != -1) 147 | close(_fd); 148 | _pat_num = -1; 149 | } 150 | 151 | PatternSet::PatternSet(const char* filepath) { 152 | _patterns = 0; 153 | _pat_len = 0; 154 | _mmap = (char*)MAP_FAILED; 155 | _mmap_size = 0; 156 | _pat_num = -1; 157 | _errmsg = ""; 158 | 159 | if (!ExtractPattern(filepath)) 160 | Cleanup(); 161 | } 162 | 163 | bool 164 | getFilesUnderDir(vector& files, const char* path) { 165 | files.clear(); 166 | 167 | DIR* dir = opendir(path); 168 | if (!dir) 169 | return false; 170 | 171 | string path_dir = path; 172 | path_dir += "/"; 173 | 174 | for (;;) { 175 | struct dirent* entry = readdir(dir); 176 | if (entry) { 177 | string filepath = path_dir + entry->d_name; 178 | struct stat file_stat; 179 | if (stat(filepath.c_str(), &file_stat)) { 180 | closedir(dir); 181 | return false; 182 | } 183 | 184 | if (S_ISREG(file_stat.st_mode)) 185 | files.push_back(filepath); 186 | 187 | continue; 188 | } 189 | 190 | if (errno) { 191 | return false; 192 | } 193 | break; 194 | } 195 | closedir(dir); 196 | return true; 197 | } 198 | 199 | class Timer { 200 | public: 201 | Timer() { 202 | my_clock_gettime(&_start); 203 | _stop = _start; 204 | _acc.tv_sec = 0; 205 | _acc.tv_nsec = 0; 206 | } 207 | 208 | const Timer& operator += (const Timer& that) { 209 | time_t sec = _acc.tv_sec + that._acc.tv_sec; 210 | long nsec = _acc.tv_nsec + that._acc.tv_nsec; 211 | if (nsec > 1000000000) { 212 | nsec -= 1000000000; 213 | sec += 1; 214 | } 215 | _acc.tv_sec = sec; 216 | _acc.tv_nsec = nsec; 217 | return *this; 218 | } 219 | 220 | // return duration in us 221 | size_t getDuration() const { 222 | return _acc.tv_sec * (size_t)1000000 + _acc.tv_nsec/1000; 223 | } 224 | 225 | void Start(bool acc=true) { 226 | my_clock_gettime(&_start); 227 | } 228 | 229 | void Stop() { 230 | my_clock_gettime(&_stop); 231 | struct timespec t = CalcDuration(); 232 | _acc = add_duration(_acc, t); 233 | } 234 | 235 | private: 236 | int my_clock_gettime(struct timespec* t) { 237 | #ifdef __linux 238 | return clock_gettime(CLOCK_PROCESS_CPUTIME_ID, t); 239 | #else 240 | struct timeval tv; 241 | int rc = gettimeofday(&tv, 0); 242 | t->tv_sec = tv.tv_sec; 243 | t->tv_nsec = tv.tv_usec * 1000; 244 | return rc; 245 | #endif 246 | } 247 | 248 | struct timespec add_duration(const struct timespec& dur1, 249 | const struct timespec& dur2) { 250 | time_t sec = dur1.tv_sec + dur2.tv_sec; 251 | long nsec = dur1.tv_nsec + dur2.tv_nsec; 252 | if (nsec > 1000000000) { 253 | nsec -= 1000000000; 254 | sec += 1; 255 | } 256 | timespec t; 257 | t.tv_sec = sec; 258 | t.tv_nsec = nsec; 259 | 260 | return t; 261 | } 262 | 263 | struct timespec CalcDuration() const { 264 | timespec diff; 265 | if ((_stop.tv_nsec - _start.tv_nsec)<0) { 266 | diff.tv_sec = _stop.tv_sec - _start.tv_sec - 1; 267 | diff.tv_nsec = 1000000000 + _stop.tv_nsec - _start.tv_nsec; 268 | } else { 269 | diff.tv_sec = _stop.tv_sec - _start.tv_sec; 270 | diff.tv_nsec = _stop.tv_nsec - _start.tv_nsec; 271 | } 272 | return diff; 273 | } 274 | 275 | struct timespec _start; 276 | struct timespec _stop; 277 | struct timespec _acc; 278 | }; 279 | 280 | class Benchmark { 281 | public: 282 | Benchmark(const PatternSet& pat_set, const char* infile): 283 | _pat_set(pat_set), _infile(infile) { 284 | _mmap = (char*)MAP_FAILED; 285 | _file_sz = 0; 286 | _fd = -1; 287 | } 288 | 289 | ~Benchmark() { 290 | if (_mmap != MAP_FAILED) 291 | munmap(_mmap, _file_sz); 292 | if (_fd != -1) 293 | close(_fd); 294 | } 295 | 296 | bool Run(int iteration); 297 | const Timer& getTimer() const { return _timer; } 298 | 299 | private: 300 | const PatternSet& _pat_set; 301 | const char* _infile; 302 | char* _mmap; 303 | int _fd; 304 | size_t _file_sz; // input file size 305 | Timer _timer; 306 | }; 307 | 308 | bool 309 | Benchmark::Run(int iteration) { 310 | if (_pat_set.getPatternNum() <= 0) { 311 | SomethingWrong = true; 312 | return false; 313 | } 314 | 315 | if (_mmap == MAP_FAILED) { 316 | struct stat filestat; 317 | if (stat(_infile, &filestat)) { 318 | SomethingWrong = true; 319 | return false; 320 | } 321 | 322 | if (!S_ISREG(filestat.st_mode)) { 323 | SomethingWrong = true; 324 | return false; 325 | } 326 | 327 | _fd = open(_infile, 0); 328 | if (_fd == -1) 329 | return false; 330 | 331 | _mmap = (char*)mmap(0, filestat.st_size, PROT_READ|PROT_WRITE, 332 | MAP_PRIVATE, _fd, 0); 333 | 334 | if (_mmap == MAP_FAILED) { 335 | SomethingWrong = true; 336 | return false; 337 | } 338 | 339 | _file_sz = filestat.st_size; 340 | } 341 | 342 | ac_t* ac = ac_create(_pat_set.getPatternVector(), 343 | _pat_set.getPatternLenVector(), 344 | _pat_set.getPatternNum()); 345 | if (!ac) { 346 | SomethingWrong = true; 347 | return false; 348 | } 349 | 350 | int piece_num = _file_sz/piece_size; 351 | 352 | _timer.Start(false); 353 | 354 | /* Stupid compiler may not be able to promote piece_size into register. 355 | * Do it manually. 356 | */ 357 | int piece_sz = piece_size; 358 | for (int i = 0; i < iteration; i++) { 359 | size_t match_ofst = 0; 360 | for (int piece_idx = 0; piece_idx < piece_num; piece_idx ++) { 361 | ac_match2(ac, _mmap + match_ofst, piece_sz); 362 | match_ofst += piece_sz; 363 | } 364 | if (match_ofst != _file_sz) 365 | ac_match2(ac, _mmap + match_ofst, _file_sz - match_ofst); 366 | } 367 | _timer.Stop(); 368 | return true; 369 | } 370 | 371 | const char* short_opt = "hd:f:i:p:"; 372 | const struct option long_opts[] = { 373 | {"help", no_argument, 0, 'h'}, 374 | {"iteration", required_argument, 0, 'i'}, 375 | {"dictionary-dir", required_argument, 0, 'd'}, 376 | {"obj-file-dir", required_argument, 0, 'f'}, 377 | {"piece-size", required_argument, 0, 'p'}, 378 | }; 379 | 380 | static void 381 | PrintHelp(const char* prog_name) { 382 | const char* msg = 383 | "Usage %s [OPTIONS]\n" 384 | " -d, --dictionary-dir : specify the dictionary directory (./dict by default)\n" 385 | " -f, --obj-file-dir : specify the object file directory\n" 386 | " (./testinput by default)\n" 387 | " -i, --iteration : Run this many iteration for each pattern match\n" 388 | " -p, --piece-size : The size of 'piece' in byte. The input file is\n" 389 | " divided into pieces, and match function is working\n" 390 | " on one piece at a time. The default size of piece\n" 391 | " is 1k byte.\n"; 392 | 393 | fprintf(stdout, msg, prog_name); 394 | } 395 | 396 | static bool 397 | getOptions(int argc, char** argv) { 398 | bool dict_dir_set = false; 399 | bool objfile_dir_set = false; 400 | int opt_index; 401 | 402 | while (1) { 403 | if (print_help) break; 404 | 405 | int c = getopt_long(argc, argv, short_opt, long_opts, &opt_index); 406 | 407 | if (c == -1) break; 408 | if (c == 0) { c = long_opts[opt_index].val; } 409 | 410 | switch(c) { 411 | case 'h': 412 | print_help = true; 413 | break; 414 | 415 | case 'i': 416 | iteration = atol(optarg); 417 | break; 418 | 419 | case 'd': 420 | dict_dir = optarg; 421 | dict_dir_set = true; 422 | break; 423 | 424 | case 'f': 425 | obj_file_dir = optarg; 426 | objfile_dir_set = true; 427 | break; 428 | 429 | case 'p': 430 | piece_size = atol(optarg); 431 | break; 432 | 433 | case '?': 434 | default: 435 | return false; 436 | } 437 | } 438 | 439 | if (print_help) 440 | return true; 441 | 442 | string basedir(dirname(argv[0])); 443 | if (!dict_dir_set) 444 | dict_dir = basedir + "/dict"; 445 | 446 | if (!objfile_dir_set) 447 | obj_file_dir = basedir + "/testinput"; 448 | 449 | return true; 450 | } 451 | 452 | int 453 | main(int argc, char** argv) { 454 | if (!getOptions(argc, argv)) 455 | return -1; 456 | 457 | if (print_help) { 458 | PrintHelp(argv[0]); 459 | return 0; 460 | } 461 | 462 | #ifndef __linux 463 | fprintf(stdout, "\n!!!WARNING: On this OS, the execution time is measured" 464 | " by gettimeofday(2) which is imprecise!!!\n\n"); 465 | #endif 466 | 467 | fprintf(stdout, "Test with iteration = %d, piece size = %d, and", 468 | iteration, piece_size); 469 | fprintf(stdout, "\n dictionary dir = %s\n object file dir = %s\n\n", 470 | dict_dir.c_str(), obj_file_dir.c_str()); 471 | 472 | vector dict_files; 473 | vector input_files; 474 | 475 | if (!getFilesUnderDir(dict_files, dict_dir.c_str())) { 476 | fprintf(stdout, "fail to find dictionary files\n"); 477 | return -1; 478 | } 479 | 480 | if (!getFilesUnderDir(input_files, obj_file_dir.c_str())) { 481 | fprintf(stdout, "fail to find test input files\n"); 482 | return -1; 483 | } 484 | 485 | for (vector::iterator diter = dict_files.begin(), 486 | diter_e = dict_files.end(); diter != diter_e; ++diter) { 487 | 488 | const char* dict_name = diter->c_str(); 489 | if (!PatternSet::isDictFile(dict_name)) 490 | continue; 491 | 492 | PatternSet ps(dict_name); 493 | if (ps.getPatternNum() <= 0) { 494 | fprintf(stdout, "fail to open dictionary file %s : %s\n", 495 | dict_name, ps.getErrMessage()); 496 | SomethingWrong = true; 497 | continue; 498 | } 499 | 500 | fprintf(stdout, "Using dictionary %s\n", dict_name); 501 | Timer timer; 502 | for (vector::iterator iter = input_files.begin(), 503 | iter_e = input_files.end(); iter != iter_e; ++iter) { 504 | fprintf(stdout, " testing %s ... ", iter->c_str()); 505 | fflush(stdout); 506 | Benchmark bm(ps, iter->c_str()); 507 | bm.Run(iteration); 508 | const Timer& t = bm.getTimer(); 509 | timer += bm.getTimer(); 510 | fprintf(stdout, "elapsed %.3f\n", t.getDuration() / 1000000.0); 511 | } 512 | 513 | fprintf(stdout, 514 | "\n==========================================================\n" 515 | " Total Elapse %.3f\n\n", timer.getDuration() / 1000000.0); 516 | } 517 | 518 | return SomethingWrong ? -1 : 0; 519 | } 520 | -------------------------------------------------------------------------------- /tests/ac_test_aggr.cxx: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | #include "ac.h" 13 | #include "ac_util.hpp" 14 | #include "test_base.hpp" 15 | 16 | using namespace std; 17 | 18 | namespace { 19 | class ACBigFileTester : public BigFileTester { 20 | public: 21 | ACBigFileTester(const char* filepath) : BigFileTester(filepath){}; 22 | 23 | private: 24 | virtual buf_header_t* PM_Create(const char** strv, uint32* strlenv, 25 | uint32 vect_len) { 26 | return (buf_header_t*)ac_create(strv, strlenv, vect_len); 27 | } 28 | 29 | virtual void PM_Free(buf_header_t* PM) { ac_free(PM); } 30 | virtual bool Run_Helper(buf_header_t* PM); 31 | }; 32 | 33 | class ACTestAggressive: public ACTestBase { 34 | public: 35 | ACTestAggressive(const vector& files, const char* banner) 36 | : ACTestBase(banner), _files(files) {} 37 | virtual bool Run(); 38 | 39 | private: 40 | void PrintSummary(int total, int fail) { 41 | fprintf(stdout, "Test count : %d, fail: %d\n", total, fail); 42 | fflush(stdout); 43 | } 44 | vector _files; 45 | }; 46 | 47 | } // end of anonymous namespace 48 | 49 | bool 50 | ACBigFileTester::Run_Helper(buf_header_t* PM) { 51 | int fail = 0; 52 | // advance one chunk at a time. 53 | int len = _msg_len; 54 | int chunk_sz = _chunk_sz; 55 | 56 | vector c_style_keys; 57 | for (int i = 0, e = _keys.size(); i != e; i++) { 58 | const char* key = _keys[i].first; 59 | int len = _keys[i].second; 60 | char *t = new char[len+1]; 61 | memcpy(t, key, len); 62 | t[len] = '\0'; 63 | c_style_keys.push_back(t); 64 | } 65 | 66 | for (int ofst = 0, chunk_idx = 0, chunk_num = _chunk_num; 67 | chunk_idx < chunk_num; ofst += chunk_sz, chunk_idx++) { 68 | const char* substring = _msg + ofst; 69 | ac_result_t r = ac_match((ac_t*)(void*)PM, substring , len - ofst); 70 | int m_b = r.match_begin; 71 | int m_e = r.match_end; 72 | 73 | if (m_b < 0 || m_e < 0 || m_e <= m_b || m_e >= len) { 74 | fprintf(stdout, "fail to find match substring[%d:%d])\n", 75 | ofst, len - 1); 76 | fail ++; 77 | continue; 78 | } 79 | 80 | const char* match_str = _msg + len; 81 | int strstr_len = 0; 82 | int key_idx = -1; 83 | 84 | for (int i = 0, e = c_style_keys.size(); i != e; i++) { 85 | const char* key = c_style_keys[i]; 86 | if (const char *m = strstr(substring, key)) { 87 | if (m < match_str) { 88 | match_str = m; 89 | strstr_len = _keys[i].second; 90 | key_idx = i; 91 | } 92 | } 93 | } 94 | ASSERT(key_idx != -1); 95 | if ((match_str - substring != m_b)) { 96 | fprintf(stdout, 97 | "Fail to find match substring[%d:%d])," 98 | " expected to find match at offset %d instead of %d\n", 99 | ofst, len - 1, 100 | (int)(match_str - _msg), ofst + m_b); 101 | fprintf(stdout, "%d vs %d (key idx %d)\n", strstr_len, m_e - m_b + 1, key_idx); 102 | PrintStr(stdout, match_str, strstr_len); 103 | fprintf(stdout, "\n"); 104 | PrintStr(stdout, _msg + ofst + m_b, 105 | m_e - m_b + 1); 106 | fprintf(stdout, "\n"); 107 | fail ++; 108 | } 109 | } 110 | for (vector::iterator i = c_style_keys.begin(), 111 | e = c_style_keys.end(); i != e; i++) { 112 | delete[] *i; 113 | } 114 | 115 | return fail == 0; 116 | } 117 | 118 | bool 119 | ACTestAggressive::Run() { 120 | int fail = 0; 121 | for (vector::iterator i = _files.begin(), e = _files.end(); 122 | i != e; i++) { 123 | ACBigFileTester bft(*i); 124 | if (!bft.Run()) 125 | fail ++; 126 | } 127 | return fail == 0; 128 | } 129 | 130 | bool 131 | Run_AC_Aggressive_Test(const vector& files) { 132 | ACTestAggressive t(files, "AC Aggressive test"); 133 | t.PrintBanner(); 134 | return t.Run(); 135 | } 136 | -------------------------------------------------------------------------------- /tests/ac_test_simple.cxx: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include "ac.h" 7 | #include "ac_util.hpp" 8 | #include "test_base.hpp" 9 | 10 | using namespace std; 11 | 12 | namespace { 13 | typedef struct { 14 | const char* str; 15 | const char* match; 16 | } StrPair; 17 | 18 | typedef enum { 19 | MV_FIRST_MATCH = 0, 20 | MV_LEFT_LONGEST = 1, 21 | } MatchVariant; 22 | 23 | typedef struct { 24 | const char* name; 25 | const char** dict; 26 | StrPair* strpairs; 27 | int dict_len; 28 | int strpair_num; 29 | MatchVariant match_variant; 30 | } TestingCase; 31 | 32 | class Tests { 33 | public: 34 | Tests(const char* name, 35 | const char* dict[], int dict_len, 36 | StrPair strpairs[], int strpair_num, 37 | MatchVariant mv = MV_FIRST_MATCH) { 38 | if (!_tests) 39 | _tests = new vector; 40 | 41 | TestingCase tc; 42 | tc.name = name; 43 | tc.dict = dict; 44 | tc.strpairs = strpairs; 45 | tc.dict_len = dict_len; 46 | tc.strpair_num = strpair_num; 47 | tc.match_variant = mv; 48 | _tests->push_back(tc); 49 | } 50 | 51 | static vector* Get_Tests() { return _tests; } 52 | static void Erase_Tests() { delete _tests; _tests = 0; } 53 | 54 | private: 55 | static vector *_tests; 56 | }; 57 | 58 | class LeftLongestTests : public Tests { 59 | public: 60 | LeftLongestTests (const char* name, const char* dict[], int dict_len, 61 | StrPair strpairs[], int strpair_num): 62 | Tests(name, dict, dict_len, strpairs, strpair_num, MV_LEFT_LONGEST) { 63 | } 64 | }; 65 | 66 | vector* Tests::_tests = 0; 67 | 68 | class ACTestSimple: public ACTestBase { 69 | public: 70 | ACTestSimple(const char* banner) : ACTestBase(banner) {} 71 | virtual bool Run(); 72 | 73 | private: 74 | void PrintSummary(int total, int fail) { 75 | fprintf(stdout, "Test count : %d, fail: %d\n", total, fail); 76 | fflush(stdout); 77 | } 78 | }; 79 | } 80 | 81 | bool 82 | ACTestSimple::Run() { 83 | int total = 0; 84 | int fail = 0; 85 | 86 | vector *tests = Tests::Get_Tests(); 87 | if (!tests) { 88 | PrintSummary(0, 0); 89 | return true; 90 | } 91 | 92 | for (vector::iterator i = tests->begin(), e = tests->end(); 93 | i != e; i++) { 94 | TestingCase& t = *i; 95 | int dict_len = t.dict_len; 96 | unsigned int* strlen_v = new unsigned int[dict_len]; 97 | 98 | fprintf(stdout, ">Testing %s\nDictionary:[ ", t.name); 99 | for (int i = 0, need_break=0; i < dict_len; i++) { 100 | const char* s = t.dict[i]; 101 | fprintf(stdout, "%s, ", s); 102 | strlen_v[i] = strlen(s); 103 | if (need_break++ == 16) { 104 | fputs("\n ", stdout); 105 | need_break = 0; 106 | } 107 | } 108 | fputs("]\n", stdout); 109 | 110 | /* Create the dictionary */ 111 | ac_t* ac = ac_create(t.dict, strlen_v, dict_len); 112 | delete[] strlen_v; 113 | 114 | for (int ii = 0, ee = t.strpair_num; ii < ee; ii++, total++) { 115 | const StrPair& sp = t.strpairs[ii]; 116 | const char *str = sp.str; // the string to be matched 117 | const char *match = sp.match; 118 | 119 | fprintf(stdout, "[%3d] Testing '%s' : ", total, str); 120 | 121 | int len = strlen(str); 122 | ac_result_t r; 123 | if (t.match_variant == MV_FIRST_MATCH) 124 | r = ac_match(ac, str, len); 125 | else if (t.match_variant == MV_LEFT_LONGEST) 126 | r = ac_match_longest_l(ac, str, len); 127 | else { 128 | ASSERT(false && "Unknown variant"); 129 | } 130 | 131 | int m_b = r.match_begin; 132 | int m_e = r.match_end; 133 | 134 | // The return value per se is insane. 135 | if (m_b > m_e || 136 | ((m_b < 0 || m_e < 0) && (m_b != -1 || m_e != -1))) { 137 | fprintf(stdout, "Insane return value (%d, %d)\n", m_b, m_e); 138 | fail ++; 139 | continue; 140 | } 141 | 142 | // If the string is not supposed to match the dictionary. 143 | if (!match) { 144 | if (m_b != -1 || m_e != -1) { 145 | fail ++; 146 | fprintf(stdout, "Not Supposed to match (%d, %d) \n", 147 | m_b, m_e); 148 | } else 149 | fputs("Pass\n", stdout); 150 | continue; 151 | } 152 | 153 | // The string or its substring is match the dict. 154 | if (m_b >= len || m_b >= len) { 155 | fail ++; 156 | fprintf(stdout, 157 | "Return value >= the length of the string (%d, %d)\n", 158 | m_b, m_e); 159 | continue; 160 | } else { 161 | int mlen = strlen(match); 162 | if ((mlen != m_e - m_b + 1) || 163 | strncmp(str + m_b, match, mlen)) { 164 | fail ++; 165 | fprintf(stdout, "Fail\n"); 166 | } else 167 | fprintf(stdout, "Pass\n"); 168 | } 169 | } 170 | fputs("\n", stdout); 171 | ac_free(ac); 172 | } 173 | 174 | PrintSummary(total, fail); 175 | return fail == 0; 176 | } 177 | 178 | bool 179 | Run_AC_Simple_Test() { 180 | ACTestSimple t("AC Simple test"); 181 | t.PrintBanner(); 182 | return t.Run(); 183 | } 184 | 185 | ////////////////////////////////////////////////////////////////////////////// 186 | // 187 | // Testing cases for first-match variant (i.e. test ac_match()) 188 | // 189 | ////////////////////////////////////////////////////////////////////////////// 190 | // 191 | 192 | /* test 1*/ 193 | const char *dict1[] = {"he", "she", "his", "her"}; 194 | StrPair strpair1[] = { 195 | {"he", "he"}, {"she", "she"}, {"his", "his"}, 196 | {"hers", "he"}, {"ahe", "he"}, {"shhe", "he"}, 197 | {"shis2", "his"}, {"ahhe", "he"} 198 | }; 199 | Tests test1("test 1", 200 | dict1, sizeof(dict1)/sizeof(dict1[0]), 201 | strpair1, sizeof(strpair1)/sizeof(strpair1[0])); 202 | 203 | /* test 2*/ 204 | const char *dict2[] = {"poto", "poto"}; /* duplicated strings*/ 205 | StrPair strpair2[] = {{"The pot had a handle", 0}}; 206 | Tests test2("test 2", dict2, 2, strpair2, 1); 207 | 208 | /* test 3*/ 209 | const char *dict3[] = {"The"}; 210 | StrPair strpair3[] = {{"The pot had a handle", "The"}}; 211 | Tests test3("test 3", dict3, 1, strpair3, 1); 212 | 213 | /* test 4*/ 214 | const char *dict4[] = {"pot"}; 215 | StrPair strpair4[] = {{"The pot had a handle", "pot"}}; 216 | Tests test4("test 4", dict4, 1, strpair4, 1); 217 | 218 | /* test 5*/ 219 | const char *dict5[] = {"pot "}; 220 | StrPair strpair5[] = {{"The pot had a handle", "pot "}}; 221 | Tests test5("test 5", dict5, 1, strpair5, 1); 222 | 223 | /* test 6*/ 224 | const char *dict6[] = {"ot h"}; 225 | StrPair strpair6[] = {{"The pot had a handle", "ot h"}}; 226 | Tests test6("test 6", dict6, 1, strpair6, 1); 227 | 228 | /* test 7*/ 229 | const char *dict7[] = {"andle"}; 230 | StrPair strpair7[] = {{"The pot had a handle", "andle"}}; 231 | Tests test7("test 7", dict7, 1, strpair7, 1); 232 | 233 | const char *dict8[] = {"aaab"}; 234 | StrPair strpair8[] = {{"aaaaaaab", "aaab"}}; 235 | Tests test8("test 8", dict8, 1, strpair8, 1); 236 | 237 | const char *dict9[] = {"haha", "z"}; 238 | StrPair strpair9[] = {{"aaaaz", "z"}, {"z", "z"}}; 239 | Tests test9("test 9", dict9, 2, strpair9, 2); 240 | 241 | /* test the case when input string doesn't contain even a single char 242 | * of the pattern in dictionary. 243 | */ 244 | const char *dict10[] = {"abc"}; 245 | StrPair strpair10[] = {{"cde", 0}}; 246 | Tests test10("test 10", dict10, 1, strpair10, 1); 247 | 248 | 249 | ////////////////////////////////////////////////////////////////////////////// 250 | // 251 | // Testing cases for first longest match variant (i.e. 252 | // test ac_match_longest_l()) 253 | // 254 | ////////////////////////////////////////////////////////////////////////////// 255 | // 256 | 257 | // This was actually first motivation for left-longest-match 258 | const char *dict100[] = {"Mozilla", "Mozilla Mobile"}; 259 | StrPair strpair100[] = {{"User Agent containing string Mozilla Mobile", "Mozilla Mobile"}}; 260 | LeftLongestTests test100("l_test 100", dict100, 2, strpair100, 1); 261 | 262 | // Dict with single char is tricky 263 | const char *dict101[] = {"a", "abc"}; 264 | StrPair strpair101[] = {{"abcdef", "abc"}}; 265 | LeftLongestTests test101("l_test 101", dict101, 2, strpair101, 1); 266 | 267 | // Testing case with partially overlapping patterns. The purpose is to 268 | // check if the fail-link leading from terminal state is correct. 269 | // 270 | // The fail-link leading from terminal-state does not matter in 271 | // match-first-occurrence variant, as it stop when a terminal is hit. 272 | // 273 | const char *dict102[] = {"abc", "bcdef"}; 274 | StrPair strpair102[] = {{"abcdef", "bcdef"}}; 275 | LeftLongestTests test102("l_test 102", dict102, 2, strpair102, 1); 276 | -------------------------------------------------------------------------------- /tests/dict/README.txt: -------------------------------------------------------------------------------- 1 | This directory contains pattern set of benchmark purpose. 2 | -------------------------------------------------------------------------------- /tests/dict/dict1.txt: -------------------------------------------------------------------------------- 1 | false_return@ 2 | forloop#haha 3 | wtfprogram 4 | mmaporunmap 5 | ThIs?Module!IsEssential 6 | struct rtlwtf 7 | gettIMEOfdayWrong 8 | edistribution_and_use_in_@source 9 | Copyright~#@ 10 | while {! 11 | !%SQLinje 12 | -------------------------------------------------------------------------------- /tests/load_ac_test.lua: -------------------------------------------------------------------------------- 1 | -- This script is to test load_ac.lua 2 | -- 3 | -- Some notes: 4 | -- 1. The purpose of this script is not to check if the libac.so work 5 | -- properly, it is to check if there are something stupid in load_ac.lua 6 | -- 7 | -- 2. There are bunch of collectgarbage() calls, the purpose is to make 8 | -- sure the shared lib is not unloaded after GC. 9 | 10 | -- load_ac.lua looks up libac.so via package.cpath rather than LD_LIBRARY_PATH, 11 | -- prepend (instead of appending) some insane paths here to see if it quit 12 | -- prematurely. 13 | -- 14 | package.cpath = ".;./?.so;" .. package.cpath 15 | 16 | local ac = require "load_ac" 17 | 18 | local ac_create = ac.create_ac 19 | local ac_match = ac.match 20 | local string_fmt = string.format 21 | local string_sub = string.sub 22 | 23 | local err_cnt = 0 24 | local function mytest(testname, dict, match, notmatch) 25 | print(">Testing ", testname) 26 | 27 | io.write(string_fmt("Dictionary: ")); 28 | for i=1, #dict do 29 | io.write(string_fmt("%s, ", dict[i])) 30 | end 31 | print "" 32 | 33 | local ac_inst = ac_create(dict); 34 | collectgarbage() 35 | for i=1, #match do 36 | local str = match[i] 37 | io.write(string_fmt("Matching %s, ", str)) 38 | local b = ac_match(ac_inst, str) 39 | if b then 40 | print "pass" 41 | else 42 | err_cnt = err_cnt + 1 43 | print "fail" 44 | end 45 | collectgarbage() 46 | end 47 | 48 | if notmatch == nil then 49 | return 50 | end 51 | 52 | collectgarbage() 53 | 54 | for i = 1, #notmatch do 55 | local str = notmatch[i] 56 | io.write(string_fmt("*Matching %s, ", str)) 57 | local r = ac_match(ac_inst, str) 58 | if r then 59 | err_cnt = err_cnt + 1 60 | print("fail") 61 | else 62 | print("succ") 63 | end 64 | collectgarbage() 65 | end 66 | ac_inst = nil 67 | collectgarbage() 68 | end 69 | 70 | print("") 71 | print("====== Test to see if load_ac.lua works properly ========") 72 | 73 | mytest("test1", 74 | {"he", "she", "his", "her", "str\0ing"}, 75 | -- matching cases 76 | { "he", "she", "his", "hers", "ahe", "shhe", "shis2", "ahhe", "str\0ing" }, 77 | 78 | -- not matching case 79 | {"str\0", "str"} 80 | ) 81 | 82 | os.exit((err_cnt == 0) and 0 or 1) 83 | -------------------------------------------------------------------------------- /tests/lua_test.lua: -------------------------------------------------------------------------------- 1 | -- This script is to test ahocorasick.so not libac.so 2 | -- 3 | local ac = require "ahocorasick" 4 | 5 | local ac_create = ac.create 6 | local ac_match = ac.match 7 | local string_fmt = string.format 8 | local string_sub = string.sub 9 | 10 | local err_cnt = 0 11 | local function mytest(testname, dict, match, notmatch) 12 | print(">Testing ", testname) 13 | 14 | io.write(string_fmt("Dictionary: ")); 15 | for i=1, #dict do 16 | io.write(string_fmt("%s, ", dict[i])) 17 | end 18 | print "" 19 | 20 | local ac_inst = ac_create(dict); 21 | for i=1, #match do 22 | local str = match[i][1] 23 | local substr = match[i][2] 24 | io.write(string_fmt("Matching %s, ", str)) 25 | local b, e = ac_match(ac_inst, str) 26 | if b and e and (string_sub(str, b+1, e+1) == substr) then 27 | print "pass" 28 | else 29 | err_cnt = err_cnt + 1 30 | print "fail" 31 | end 32 | --print("gc is called") 33 | collectgarbage() 34 | end 35 | 36 | if notmatch == nil then 37 | return 38 | end 39 | 40 | for i = 1, #notmatch do 41 | local str = notmatch[i] 42 | io.write(string_fmt("*Matching %s, ", str)) 43 | local r = ac_match(ac_inst, str) 44 | if r then 45 | err_cnt = err_cnt + 1 46 | print("fail") 47 | else 48 | print("succ") 49 | end 50 | collectgarbage() 51 | end 52 | end 53 | 54 | mytest("test1", 55 | {"he", "she", "his", "her", "str\0ing"}, 56 | -- matching cases 57 | { {"he", "he"}, {"she", "she"}, {"his", "his"}, {"hers", "he"}, 58 | {"ahe", "he"}, {"shhe", "he"}, {"shis2", "his"}, {"ahhe", "he"}, 59 | {"str\0ing", "str\0ing"} 60 | }, 61 | 62 | -- not matching case 63 | {"str\0", "str"} 64 | 65 | ) 66 | 67 | os.exit((err_cnt == 0) and 0 or 1) 68 | -------------------------------------------------------------------------------- /tests/test_base.hpp: -------------------------------------------------------------------------------- 1 | #ifndef TEST_BASE_H 2 | #define TEST_BASE_H 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | using namespace std; 9 | class ACTestBase { 10 | public: 11 | ACTestBase(const char* name) :_banner(name) {} 12 | virtual void PrintBanner() { 13 | fprintf(stdout, "\n===== %s ====\n", _banner.c_str()); 14 | } 15 | 16 | virtual bool Run() = 0; 17 | private: 18 | string _banner; 19 | }; 20 | 21 | typedef std::pair StrInfo; 22 | class BigFileTester { 23 | public: 24 | BigFileTester(const char* filepath); 25 | virtual ~BigFileTester() { Cleanup(); } 26 | 27 | bool Run(); 28 | 29 | protected: 30 | virtual buf_header_t* PM_Create(const char** strv, uint32_t* strlenv, 31 | uint32_t vect_len) = 0; 32 | virtual void PM_Free(buf_header_t*) = 0; 33 | virtual bool Run_Helper(buf_header_t* PM) = 0; 34 | 35 | // Return true if the '\0' is valid char of a string. 36 | virtual bool Str_C_Style() { return true; } 37 | 38 | bool GenerateKeys(); 39 | void Cleanup(); 40 | void PrintStr(FILE*, const char* str, int len); 41 | 42 | protected: 43 | const char* _filepath; 44 | int _fd; 45 | vector _keys; 46 | char* _msg; 47 | int _msg_len; 48 | int _key_num; // number of strings in dictionary 49 | int _chunk_sz; 50 | int _chunk_num; 51 | 52 | int _max_key_num; 53 | int _key_min_len; 54 | int _key_max_len; 55 | }; 56 | 57 | extern bool Run_AC_Simple_Test(); 58 | extern bool Run_AC_Aggressive_Test(const vector& files); 59 | 60 | #endif 61 | -------------------------------------------------------------------------------- /tests/test_bigfile.cxx: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | #include "ac.h" 13 | #include "ac_util.hpp" 14 | #include "test_base.hpp" 15 | 16 | /////////////////////////////////////////////////////////////////////////// 17 | // 18 | // Implementation of BigFileTester 19 | // 20 | /////////////////////////////////////////////////////////////////////////// 21 | // 22 | BigFileTester::BigFileTester(const char* filepath) { 23 | _filepath = filepath; 24 | _fd = -1; 25 | _msg = (char*)MAP_FAILED; 26 | _msg_len = 0; 27 | _key_num = 0; 28 | _chunk_sz = 0; 29 | _chunk_num = 0; 30 | 31 | _max_key_num = 100; 32 | _key_min_len = 20; 33 | _key_max_len = 80; 34 | } 35 | 36 | void 37 | BigFileTester::Cleanup() { 38 | if (_msg != MAP_FAILED) { 39 | munmap((void*)_msg, _msg_len); 40 | _msg = (char*)MAP_FAILED; 41 | _msg_len = 0; 42 | } 43 | 44 | if (_fd != -1) { 45 | close(_fd); 46 | _fd = -1; 47 | } 48 | } 49 | 50 | bool 51 | BigFileTester::GenerateKeys() { 52 | int chunk_sz = 4096; 53 | int max_key_num = _max_key_num; 54 | int key_min_len = _key_min_len; 55 | int key_max_len = _key_max_len; 56 | 57 | int t = _msg_len / chunk_sz; 58 | int keynum = t > max_key_num ? max_key_num : t; 59 | 60 | if (keynum <= 4) { 61 | // file is too small 62 | return false; 63 | } 64 | chunk_sz = _msg_len / keynum; 65 | _chunk_sz = chunk_sz; 66 | 67 | // For each chunck, "randomly" grab a sub-string searving 68 | // as key. 69 | int random_ofst[] = { 12, 30, 23, 15 }; 70 | int rofstsz = sizeof(random_ofst)/sizeof(random_ofst[0]); 71 | int ofst = 0; 72 | const char* msg = _msg; 73 | _chunk_num = keynum - 1; 74 | for (int idx = 0, e = _chunk_num; idx < e; idx++) { 75 | const char* key = msg + ofst + idx % rofstsz; 76 | int key_len = key_min_len + idx % (key_max_len - key_min_len); 77 | _keys.push_back(StrInfo(key, key_len)); 78 | ofst += chunk_sz; 79 | } 80 | return true; 81 | } 82 | 83 | bool 84 | BigFileTester::Run() { 85 | // Step 1: Bring the file into memory 86 | fprintf(stdout, "Testing using file '%s'...\n", _filepath); 87 | 88 | int fd = _fd = ::open(_filepath, O_RDONLY); 89 | if (fd == -1) { 90 | perror("open"); 91 | return false; 92 | } 93 | 94 | struct stat sb; 95 | if (fstat(fd, &sb) == -1) { 96 | perror("fstat"); 97 | return false; 98 | } 99 | 100 | if (!S_ISREG (sb.st_mode)) { 101 | fprintf(stderr, "%s is not regular file\n", _filepath); 102 | return false; 103 | } 104 | 105 | int ten_M = 1024 * 1024 * 10; 106 | int map_sz = _msg_len = sb.st_size > ten_M ? ten_M : sb.st_size; 107 | char* p = _msg = 108 | (char*)mmap (0, map_sz, PROT_READ|PROT_WRITE, MAP_PRIVATE, fd, 0); 109 | if (p == MAP_FAILED) { 110 | perror("mmap"); 111 | return false; 112 | } 113 | 114 | // Get rid of '\0' if we are picky at it. 115 | if (Str_C_Style()) { 116 | for (int i = 0; i < map_sz; i++) { if (!p[i]) p[i] = 'a'; } 117 | p[map_sz - 1] = 0; 118 | } 119 | 120 | // Step 2: "Fabricate" some keys from the file. 121 | if (!GenerateKeys()) { 122 | close(fd); 123 | return false; 124 | } 125 | 126 | // Step 3: Create PM instance 127 | const char** keys = new const char*[_keys.size()]; 128 | unsigned int* keylens = new unsigned int[_keys.size()]; 129 | 130 | int i = 0; 131 | for (vector::iterator si = _keys.begin(), se = _keys.end(); 132 | si != se; si++, i++) { 133 | const StrInfo& strinfo = *si; 134 | keys[i] = strinfo.first; 135 | keylens[i] = strinfo.second; 136 | } 137 | 138 | buf_header_t* PM = PM_Create(keys, keylens, i); 139 | delete[] keys; 140 | delete[] keylens; 141 | 142 | // Step 4: Run testing 143 | bool res = Run_Helper(PM); 144 | PM_Free(PM); 145 | 146 | // Step 5: Clanup 147 | munmap(p, map_sz); 148 | _msg = (char*)MAP_FAILED; 149 | close(fd); 150 | _fd = -1; 151 | 152 | fprintf(stdout, "%s\n", res ? "succ" : "fail"); 153 | return res; 154 | } 155 | 156 | void 157 | BigFileTester::PrintStr(FILE* f, const char* str, int len) { 158 | fprintf(f, "{"); 159 | for (int i = 0; i < len; i++) { 160 | unsigned char c = str[i]; 161 | if (isprint(c)) 162 | fprintf(f, "'%c', ", c); 163 | else 164 | fprintf(f, "%#x, ", c); 165 | } 166 | fprintf(f, "}"); 167 | }; 168 | -------------------------------------------------------------------------------- /tests/test_main.cxx: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include "ac.h" 12 | #include "ac_util.hpp" 13 | #include "test_base.hpp" 14 | 15 | using namespace std; 16 | 17 | 18 | ///////////////////////////////////////////////////////////////////////// 19 | // 20 | // Simple (yet maybe tricky) testings 21 | // 22 | ///////////////////////////////////////////////////////////////////////// 23 | // 24 | int 25 | main (int argc, char** argv) { 26 | bool succ = Run_AC_Simple_Test(); 27 | 28 | vector files; 29 | for (int i = 1; i < argc; i++) { files.push_back(argv[i]); } 30 | succ = Run_AC_Aggressive_Test(files) && succ; 31 | 32 | return succ ? 0 : -1; 33 | }; 34 | --------------------------------------------------------------------------------