├── LICENSE ├── Makefile ├── README.md ├── beamsearch ├── dump ├── eval.py ├── model.c ├── model.h ├── omnisearch ├── parse ├── qcomp.c ├── qcomp.py ├── requirements.txt ├── run_eval.bash ├── stocsearch ├── train ├── trielookup └── triesearch /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Po-Wei Wang, Huan Zhang, Vijai Mohan, Inderjit S. Dhillon, J. Zico Kolter 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | CC=gcc 2 | MKLROOT=/opt/intel/mkl 3 | 4 | CFLAGS=-Wall -g -O3 -std=gnu11 -m64 -I${MKLROOT}/include -fopenmp -DTRANSPARENT_TRIE 5 | LDFLAGS=-L${MKLROOT}/lib/intel64 -Wl,--no-as-needed -Wl,-rpath,${MKLROOT}/lib/intel64 -lmkl_intel_ilp64 -lmkl_gnu_thread -lmkl_core -lgomp -lpthread -lm -ldl -lrt 6 | 7 | ifeq ($(UNAME_S),Linux) 8 | LDFLAGS += -lrt 9 | endif 10 | 11 | all:qcomp 12 | qcomp: qcomp.o model.o 13 | 14 | .PHONY:clean 15 | clean: 16 | rm model.o qcomp.o qcomp 17 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Realtime query completion via deep language models 2 | 3 | This is a CPU-based implementation of our paper `Realtime query completion via deep language models`, 4 | which is capable of generating 10 query completions in 16 ms. 5 | 6 | To build it, please first install the following: 7 | * Intel Math Kernel Library (MKL) (https://software.intel.com/en-us/mkl) 8 | * Please update the MKLROOT in our Makefile if MKL is not installed in `/opt/intel/mkl` 9 | 10 | ## The CPU-based query completion (qcomp.c) 11 | To play with the query completion, we provide a model pre-trained on the AOL dataset (model.c, model.h). 12 | The completion program can be compiled by 13 | ``` 14 | $ make 15 | ``` 16 | The generated `qcomp` program is soft-linked to different entries (stocsearch, beamsearch, omnisearch, triesearch, trielookup). 17 | To play with our omni-completion model, please use 18 | ``` 19 | $ ./omnisearch 20 | ``` 21 | then type in any prefix and press enter. 
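The name-based dispatch described above can be sketched in Python 3 as follows. This is only an illustration, assuming the entry-point names that `main` in qcomp.c compares against; `mode_from_argv0` is a hypothetical helper, not part of this repository. The C program finds the name by scanning back from the end of `argv[0]`, and `os.path.basename` is used here as an equivalent for these names.

```python
import os

# Entry-point names that main() in qcomp.c compares against.
MODES = ("stocsearch", "beamsearch", "omnisearch", "triesearch", "trielookup")


def mode_from_argv0(argv0):
    """Return the search mode encoded in the path the program was invoked by.

    Hypothetical helper for illustration only.
    """
    name = os.path.basename(argv0)
    if name not in MODES:
        # qcomp.c prints "ERROR CMD" and exits in this case.
        raise ValueError("unknown entry point: %s" % name)
    return name
```

Since every soft link points to the same binary, invoking `./omnisearch` and `./beamsearch` runs the same code and only the selected mode differs.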
22 | 23 | ## Train the model with AOL data (qcomp.py) 24 | 25 | To train the model, please first install the following Python dependencies: 26 | * Keras/Theano/NumPy 27 | * The dependencies can be obtained by `pip install -r requirements.txt` 28 | 29 | You will also need to download the AOL data from the Internet and save it in `aol_raw.txt`. 30 | Our program qcomp.py is again soft-linked to different entries (parse, train, dump). 31 | It's quite short so please take a look before training. 32 | 33 | First, create the parsed data by 34 | ``` 35 | ./parse aol_raw.txt > aol_parsed.txt 36 | ``` 37 | 38 | Note that aol_parsed.txt has the following tab-separated format 39 | ``` 40 | TIMESTAMP QUERY PREFIX MD5_OF_PREFIX 41 | ``` 42 | 43 | Then, we sort the data according to the chosen input assumption 44 | ``` 45 | # sort by md5 of prefix, or the timestamp 46 | sort --key 4 -t$'\t' --parallel=8 aol_parsed.txt > sorted.txt 47 | # sort --key 1 -t$'\t' -g --parallel=8 aol_parsed.txt > sorted.txt 48 | ``` 49 | 50 | The last 1% of sorted.txt will be used in testing. 
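The record format and the 1% split above can be sketched in Python 3 as follows. This is a minimal illustration under stated assumptions: `make_record` and `split_counts` are hypothetical helpers, not functions from qcomp.py or run_eval.bash, and the prefix length `r` is passed in explicitly here, whereas `parse` draws it at random.

```python
import hashlib


def make_record(timestamp, query, r):
    """Build one tab-separated line: TIMESTAMP, QUERY, PREFIX, MD5_OF_PREFIX.

    Hypothetical helper; r is the prefix length (assumed given here).
    """
    prefix = query[:r]
    # hashlib needs bytes in Python 3, so the prefix is encoded first.
    md5 = hashlib.md5(prefix.encode("utf-8")).hexdigest()
    return "\t".join([str(int(timestamp)), query, prefix, md5])


def split_counts(n_lines):
    """Return (train_lines, test_lines) using the integer 1% rule.

    Mirrors the arithmetic in run_eval.bash: the test set is the last
    n_lines / 100 lines (integer division), the rest is training data.
    """
    test = n_lines // 100
    return n_lines - test, test
```

Sorting by the MD5 of the prefix effectively shuffles the records, while sorting by timestamp makes the test set the most recent 1% of queries.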
51 | Now, we can run the training and evaluation using 52 | ``` 53 | bash ./run_eval.bash 54 | ``` 55 | -------------------------------------------------------------------------------- /beamsearch: -------------------------------------------------------------------------------- 1 | qcomp -------------------------------------------------------------------------------- /dump: -------------------------------------------------------------------------------- 1 | qcomp.py -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from math import log 3 | 4 | # occurrences of each prefix in the dataset 5 | lines1 = open(sys.argv[1]).readlines() 6 | lines2 = open(sys.argv[2]).readlines() 7 | 8 | # drop partial lines 9 | lines2 = lines2[: (len(lines2) / 16) * 16] 10 | lines1 = lines1[:len(lines2) / 16] 11 | prefix = map(float, lines1) 12 | 13 | # occurrences of each predicted completion in the dataset 14 | method = map(float, lines2) 15 | 16 | # group 16 numbers together, because we predict top-16 queries (some of them might be 0) 17 | group = lambda x: [x[i*16:(i+1)*16] for i in range(len(x)/16)] 18 | 19 | # sum the occurrences in each group 20 | method = map(sum, group(method)) 21 | 22 | # some prefixes appear 0 times in the dataset (for example, the query is too long and truncated) 23 | # in that case both x and y are 0. Avoid dividing by 0. 
24 | prob = [x*1./(y + 1e-5) for (x,y) in zip(method, prefix)] 25 | prob = filter( lambda x: x<=1, prob) 26 | 27 | print max(prob) 28 | mean = lambda x: sum(x)*1./len(x) 29 | print 'mean prob =', mean(prob), 'mean hit =', mean(method)/16 30 | -------------------------------------------------------------------------------- /model.h: -------------------------------------------------------------------------------- 1 | int LSTM_1_BIAS_SHAPE_0 = 1024; 2 | extern const float LSTM_1_BIAS[]; 3 | 4 | int LSTM_1_KERNEL_SHAPE_0 = 47; 5 | int LSTM_1_KERNEL_SHAPE_1 = 1024; 6 | extern const float LSTM_1_KERNEL[]; 7 | 8 | int LSTM_1_RECURRENT_KERNEL_SHAPE_0 = 256; 9 | int LSTM_1_RECURRENT_KERNEL_SHAPE_1 = 1024; 10 | extern const float LSTM_1_RECURRENT_KERNEL[]; 11 | 12 | int LSTM_2_BIAS_SHAPE_0 = 1024; 13 | extern const float LSTM_2_BIAS[]; 14 | 15 | int LSTM_2_KERNEL_SHAPE_0 = 256; 16 | int LSTM_2_KERNEL_SHAPE_1 = 1024; 17 | extern const float LSTM_2_KERNEL[]; 18 | 19 | int LSTM_2_RECURRENT_KERNEL_SHAPE_0 = 256; 20 | int LSTM_2_RECURRENT_KERNEL_SHAPE_1 = 1024; 21 | extern const float LSTM_2_RECURRENT_KERNEL[]; 22 | 23 | int DENSE_1_BIAS_SHAPE_0 = 47; 24 | extern const float DENSE_1_BIAS[]; 25 | 26 | int DENSE_1_KERNEL_SHAPE_0 = 256; 27 | int DENSE_1_KERNEL_SHAPE_1 = 47; 28 | extern const float DENSE_1_KERNEL[]; 29 | 30 | -------------------------------------------------------------------------------- /omnisearch: -------------------------------------------------------------------------------- 1 | qcomp -------------------------------------------------------------------------------- /parse: -------------------------------------------------------------------------------- 1 | qcomp.py -------------------------------------------------------------------------------- /qcomp.c: -------------------------------------------------------------------------------- 1 | // vim: noai:ts=4:sw=4 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 
#include 11 | #include 12 | #ifndef __unix__ 13 | #include 14 | #endif 15 | #include 16 | 17 | #include "mkl.h" 18 | #include "model.h" 19 | 20 | const char *charset = "abcdefghijklmnopqrstuvwxyz0123456789 .-%_:/\\$"; 21 | int char_index[256]; 22 | 23 | // some older compilers do not support C11, so we have to supply this function ourselves 24 | void *aligned_alloc(size_t alignment, size_t size) { 25 | void *ptr = NULL; 26 | posix_memalign(&ptr, alignment, size); 27 | return ptr; 28 | } 29 | 30 | // The LSTM struct implementing Keras's LSTM 31 | typedef struct lstm_t { 32 | float *kernel; // concat of Keras's kernel and recurrent kernel 33 | float *bias; 34 | 35 | float **c, **h; 36 | float **z, **t; 37 | float **cbuf, **hbuf; // c and h buffer for beam search 38 | 39 | int n_in, n_hid, max_batch; // # of (input, hidden) units, and max batch size 40 | } lstm_t; 41 | 42 | // A softmax layer with linear input 43 | typedef struct softmax_t { 44 | float *W, *bias; 45 | float **out; 46 | 47 | int n_in, n_out, max_batch; 48 | } softmax_t; 49 | 50 | // The NN model 51 | typedef struct model_t { 52 | float **in; // The input 53 | lstm_t lstm_1, lstm_2; // The two layer LSTM 54 | softmax_t softmax; // The output softmax layer (w/ linear unit) 55 | 56 | float dropout; // The fraction of dropout 57 | int n_in, n_hid, max_batch; 58 | } model_t; 59 | 60 | // Struct for completion distance 61 | typedef struct dist_t { 62 | int *seq; 63 | float *dist, *dist_new; 64 | int pos, len, extend; 65 | } dist_t; 66 | 67 | // Struct for omnicompletion 68 | typedef struct qcomp_t { 69 | int **cand, **result, **buf; 70 | float *cand_score, *result_score, **new_score; 71 | int *rank, *next_char; 72 | dist_t *dist, *dist_buf; 73 | 74 | int max_batch, max_len; 75 | } qcomp_t; 76 | 77 | // The Trie 78 | typedef struct list_t { 79 | struct list_t *next, *prev; 80 | struct list_t *child, *parent; 81 | struct list_t **top; // pointers to 16 most frequent strings in subtries 82 | int key, val; 83 | 
double weight; 84 | } list_t; 85 | 86 | 87 | /*********** time utils **********/ 88 | 89 | #define NS_PER_SEC 1000000000 90 | int64_t wall_clock_ns() 91 | { 92 | #ifdef __unix__ 93 | struct timespec tspec; 94 | int r = clock_gettime(CLOCK_MONOTONIC, &tspec); 95 | assert(r==0); 96 | return tspec.tv_sec*NS_PER_SEC + tspec.tv_nsec; 97 | #else 98 | struct timeval tv; 99 | int r = gettimeofday( &tv, NULL ); 100 | assert(r==0); 101 | return tv.tv_sec*NS_PER_SEC + tv.tv_usec*1000; 102 | #endif 103 | } 104 | 105 | double wall_time_diff(int64_t ed, int64_t st) 106 | { 107 | return (double)(ed-st)/(double)NS_PER_SEC; 108 | } 109 | 110 | /************* BLAS utils *************/ 111 | 112 | void mysgemm(int m, int n, int k, float alpha, const float *restrict A, const float *restrict B, float beta, float *restrict C) 113 | { cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, m, n, k, alpha, A, k, B, k, beta, C, n); } 114 | 115 | void mysaxpy(int n, float a, const float *restrict x , float *restrict y) 116 | { cblas_saxpy(n, a, x, 1, y, 1); } 117 | 118 | void myszero(int n, float *x) 119 | { memset(x, 0, n*sizeof(*x)); } 120 | 121 | void myscopy(int n, const float *restrict x, float *restrict y) 122 | { memcpy(y, x, n*sizeof(*x)); } 123 | 124 | void myclip(int n, float *x) 125 | { 126 | x = __builtin_assume_aligned(x, 8*sizeof(float)); 127 | 128 | const __m128 mzero = _mm_set1_ps(0); 129 | const __m128 mone = _mm_set1_ps(1); 130 | for(int i=0; i= 3*self.n_hid) i -= self.n_hid; 190 | else if(ii >= 2*self.n_hid) i += self.n_hid; 191 | 192 | for(int j=0; j sub) cost = sub; 391 | if(cost > del) cost = del; 392 | self.dist_new[j] = cost; 393 | } 394 | return self; 395 | } 396 | 397 | // Commit dist_new to dist 398 | dist_t dist_commit(dist_t self) 399 | { 400 | myscopy(self.len+1, self.dist_new, self.dist); 401 | self.pos++; 402 | 403 | return self; 404 | } 405 | 406 | // Copying (branching) the distance matirx 407 | void dist_copy(dist_t *src, dist_t *dst) 408 | { 409 | for(int 
i=0; i<=src->len; i++){ 410 | dst->seq[i] = src->seq[i]; 411 | dst->dist[i] = src->dist[i]; 412 | dst->dist_new[i] = src->dist_new[i]; 413 | } 414 | dst->len = src->len; 415 | dst->pos = src->pos; 416 | dst->extend = src->extend; 417 | } 418 | 419 | // The difference between dist_new and dist 420 | float dist_diff(dist_t self) 421 | { 422 | return self.dist_new[self.len] - self.dist[self.len]; 423 | } 424 | 425 | /************* qcomp_t ***********/ 426 | qcomp_t qcomp_init(int max_batch, int max_len) 427 | { 428 | int voc_size = strlen(charset)+2; 429 | qcomp_t self = { 430 | .cand = imat2d_init(max_batch, max_len), 431 | .result = imat2d_init(max_batch, max_len), 432 | .buf = imat2d_init(max_batch, max_len), 433 | .cand_score = calloc(max_batch, sizeof(float)), 434 | .result_score = calloc(max_batch, sizeof(float)), 435 | .new_score = smat2d_init(max_batch, voc_size), 436 | .rank = calloc(max_batch*voc_size, sizeof(int)), 437 | .next_char = calloc(max_batch, sizeof(int)), 438 | .dist = calloc(max_batch, sizeof(dist_t)), 439 | .dist_buf = calloc(max_batch, sizeof(dist_t)), 440 | .max_batch = max_batch, 441 | .max_len = max_len, 442 | }; 443 | 444 | for(int i=0; i1 && i < 60; i++) str[i] = seq[i] - 2; 472 | else 473 | for(i=0; seq[i]>1 && i < 60; i++) str[i] = charset[seq[i]-2]; 474 | str[i] = '\0'; 475 | } 476 | 477 | // Comparator for two index by SCORE[index] 478 | float *SCORE; 479 | int argcmp(const void *x, const void *y) 480 | { 481 | int ix = *((int*)x), iy = *((int*)y); 482 | float fx = SCORE[ix], fy = SCORE[iy]; 483 | if(fxfy) return -1; 485 | else return 0; 486 | } 487 | 488 | // Argsort, returning index matrix sorted by score[index] 489 | void argsort(float *score, int *rank, int n) 490 | { 491 | for(int i=0; i t) break; 516 | 517 | free(prob); 518 | return itval-1; 519 | } 520 | 521 | // Stochastic search 522 | void qcomp_stocsearch(qcomp_t self, model_t model, int *prefix) 523 | { 524 | int voc_size = model.n_in; 525 | for(int j=0; j1 && i < 
self.max_len; i++){ 534 | int sample = boltzman_sampling(voc_size, model.softmax.out[0], 1); 535 | self.result[j][i] = self.next_char[0] = sample; 536 | score += model.softmax.out[0][sample]; 537 | model_forward(model, 1, self.next_char); 538 | } 539 | self.result_score[j] = score; 540 | } 541 | } 542 | 543 | // Beam search 544 | // inspired by https://gist.github.com/udibr/67be473cf053d8c38730 545 | int qcomp_beamsearch(qcomp_t self, model_t model, int *prefix) 546 | { 547 | int n_cand = 1, n_result = 0; 548 | int voc_size = model.n_in; 549 | 550 | seq_copy(prefix, self.cand[0], self.max_len); 551 | self.cand_score[0] = 0; 552 | 553 | model_reset(model); 554 | int pos=0; 555 | for(int i=0; prefix[i]; i++, pos++){ 556 | self.cand[0][i] = self.next_char[0] = prefix[i]; 557 | model_forward(model, 1, self.next_char); 558 | } 559 | 560 | for(; n_cand && pos < self.max_len; pos++){ 561 | for(int i=0; i60) len = 60, buf[len] = '\0'; 688 | if(buf[len-1] == '\n') buf[--len] = '\0'; 689 | buf[len++] = END; 690 | buf[len] = '\0'; 691 | 692 | // Setup weights 693 | int count = 0; 694 | int pos = 0; 695 | // Counting # of tabs 696 | for (int i = 0; i < len; ++i) { 697 | if(buf[i] == '\t') { 698 | pos = i; 699 | count++; 700 | } 701 | } 702 | // If we find two tabs, after the second tab it is the weight 703 | // Otherwise the weight is 1 704 | float weight = 1.0f; 705 | if (count == 2) { 706 | buf[pos] = END; 707 | weight = atof(buf+pos+1); 708 | } 709 | buf[pos+1] = '\0'; 710 | 711 | // Going down the Trie 712 | list_t *p = root; 713 | int link; 714 | char *s = buf; 715 | for(; *s; s++){ 716 | for(; p->next != NULL && p->key != *s; p = p->next) 717 | ; 718 | 719 | if(p->key != *s){ 720 | link = HORIZONTAL; 721 | break; 722 | } 723 | p->val++; 724 | p->weight += weight; 725 | if(p->child == NULL){ 726 | link = VERTICAL; 727 | s++; 728 | break; 729 | } 730 | p = p->child; 731 | } 732 | 733 | // Create nodes if not matched 734 | for(; *s; s++){ 735 | list_t *old = p; 736 | p = 
calloc(1, sizeof(list_t)); 737 | trie_size++; 738 | p->key = *s; 739 | p->val = 1; 740 | p->weight = weight; 741 | if(link==VERTICAL) 742 | old->child = p, p->parent = old; 743 | else 744 | old->next = p, p->prev = old; 745 | link = VERTICAL; 746 | } 747 | } 748 | 749 | fprintf(stderr, "trie size = %d\n", trie_size); 750 | 751 | free(buf); 752 | 753 | return root; 754 | } 755 | 756 | int min(int a, int b) 757 | { 758 | if(aval); 768 | int size_b = min(16, b->val); 769 | 770 | int i=0, k=0; 771 | for(int j=0; itop[i], *pb = b->top[j]; 773 | if(!pa || (pb && pa->weight < pb->weight)) MERGE_BUF[k] = pb, j++; 774 | else MERGE_BUF[k] = pa, i++; 775 | } 776 | for(; ktop[i]; 777 | 778 | for(int i=0; itop[i] = MERGE_BUF[i]; 780 | } 781 | } 782 | 783 | // Build the top list by recursively calling merge 784 | void trie_build(list_t *root) 785 | { 786 | int top_size = min(16, root->val); 787 | root->top = calloc(top_size, sizeof(*root->top)); 788 | 789 | // leave node 790 | if(root->key == END){ 791 | root->top[0] = root; 792 | } 793 | 794 | if (root->child) { 795 | trie_build(root->child); 796 | for(list_t *p=root->child; p; p = p->next){ 797 | trie_merge_top(root, p); 798 | } 799 | } 800 | if(root->next) trie_build(root->next); 801 | } 802 | 803 | // Lookup s in the trie, 804 | // returning trie node if found and null otherwise 805 | list_t *trie_lookup(list_t *root, char *s) 806 | { 807 | list_t *p = root; 808 | while(*s){ 809 | for(; p && p->key != *s; p = p->next) 810 | ; 811 | if(!p) break; 812 | s++; 813 | if(!*s) break; 814 | p = p->child; 815 | } 816 | return p; 817 | } 818 | 819 | // Trie search 820 | int qcomp_triesearch(qcomp_t qcomp, list_t *root, int *seq) 821 | { 822 | char prefix[258], r[258], buf[258]; 823 | #ifdef TRANSPARENT_TRIE 824 | qcomp_decode(qcomp, seq, prefix, 1); 825 | #else 826 | qcomp_decode(qcomp, seq, prefix, 0); 827 | #endif 828 | list_t *q = trie_lookup(root, prefix); 829 | if(!q){ 830 | for(int i=0; i=q->val || !q->top[i]){ 836 | 
qcomp.result_score[i] = 0; 837 | qcomp.result[i][0] = 0; 838 | continue; 839 | } 840 | qcomp.result_score[i] = log(q->top[i]->weight)-log(q->weight); 841 | 842 | int len=0; 843 | for(list_t *p = q->top[i]; p != q; ){ 844 | r[len++] = p->key; 845 | while(!p->parent) 846 | p = p->prev; 847 | p = p->parent; 848 | } 849 | for(int j=0, k=len-1, t; jval; 861 | } 862 | 863 | enum{STOCSEARCH, BEAMSEARCH, OMNISEARCH, TRIESEARCH, TRIELOOKUP}; 864 | /*****************************/ 865 | int main(int argc, char **argv) 866 | { 867 | srand((unsigned)0); 868 | 869 | omp_set_num_threads(8); 870 | int n_beam = 16, max_len = 60; 871 | int voc_size = LSTM_1_KERNEL_SHAPE_0; 872 | model_t model = model_init(n_beam, voc_size, 256); 873 | model_load(model); 874 | qcomp_t qcomp = qcomp_init(n_beam, max_len); 875 | 876 | char *prefix = malloc(2048), prev[62] = {0}; 877 | char str[62]; 878 | int seq[62], rank[16]; 879 | 880 | const char *prog_name = argv[0]; 881 | const char *base_name = prog_name+strlen(prog_name)-1; 882 | for(; base_name!=prog_name && isalnum(*base_name); base_name--) 883 | ; 884 | base_name += 1; 885 | 886 | int mode = 0; 887 | int transparent = 0; 888 | // Different program names would invoke different functionalities 889 | if(strcmp(base_name, "stocsearch") == 0) 890 | mode = STOCSEARCH; 891 | else if(strcmp(base_name, "beamsearch") == 0) 892 | mode = BEAMSEARCH; 893 | else if(strcmp(base_name, "omnisearch") == 0) 894 | mode = OMNISEARCH; 895 | else if(strcmp(base_name, "triesearch") == 0) { 896 | mode = TRIESEARCH; 897 | #ifdef TRANSPARENT_TRIE 898 | transparent = 1; 899 | #endif 900 | } 901 | else if(strcmp(base_name, "trielookup") == 0) { 902 | mode = TRIELOOKUP; 903 | #ifdef TRANSPARENT_TRIE 904 | transparent = 1; 905 | #endif 906 | } 907 | else 908 | fprintf(stderr, "ERROR CMD\n"), exit(1); 909 | fprintf(stderr, "%s mode %d\n", base_name, mode); 910 | 911 | list_t *root = NULL; 912 | if(mode == TRIESEARCH || mode == TRIELOOKUP){ 913 | root = trie_init(argv[1]); 
914 | trie_build(root); 915 | } 916 | int to_end = 0; 917 | if(argc >= 3 && strcmp(argv[2], "1") == 0) to_end = 1; 918 | signal(SIGPIPE, SIG_IGN); 919 | fprintf(stderr, "ready\n"); 920 | while(NULL != fgets(prefix, 2000, stdin)){ 921 | int len = strlen(prefix); 922 | if(len>60) len = 60, prefix[len] = '\0'; 923 | if(prefix[len-1] == '\n') prefix[len-1] = '\0'; 924 | qcomp_encode(qcomp, prefix, seq, transparent); 925 | 926 | 927 | int64_t st = wall_clock_ns(); 928 | // if(strcmp(prefix, prev) == 0 && mode != TRIELOOKUP) 929 | // fprintf(stderr, "repeat\n"); 930 | if(mode==STOCSEARCH) 931 | qcomp_stocsearch(qcomp, model, seq); 932 | else if(mode==BEAMSEARCH) 933 | qcomp_beamsearch(qcomp, model, seq); 934 | else if(mode==OMNISEARCH) 935 | qcomp_omnisearch(qcomp, model, seq); 936 | else if(mode==TRIESEARCH) { 937 | int count = qcomp_triesearch(qcomp, root, seq); 938 | // also count and print to stderr 939 | fprintf(stderr, "%d\n", count); 940 | fflush(stderr); 941 | } 942 | else if(mode==TRIELOOKUP){ 943 | int len = strlen(prefix); 944 | if(to_end) prefix[len++] = END; 945 | prefix[len] = '\0'; 946 | 947 | list_t *p = trie_lookup(root, prefix); 948 | if(!p) printf("0\n"); 949 | else printf("%d %f\n", p->val, p->weight); 950 | fflush(stdout); 951 | continue; 952 | } 953 | 954 | argsort(qcomp.result_score, rank, n_beam); 955 | 956 | for(int jj=0; jj 3 52 | r = 2+random.randint(0,max(0,len(query)-3)) 53 | prefix = query[:r] 54 | md5 = hashlib.md5(prefix).hexdigest() 55 | line = '\t'.join([str(int(timestamp)), query, prefix, md5]) 56 | print(line) 57 | 58 | class Sequencer(object): 59 | PAD, END = 0, 1 60 | def __init__(self): 61 | self.token_to_indice = dict([(c,i+2) for (i,c) in enumerate(charset)]) 62 | self.vocabs = ['PAD', 'END']+list(charset) 63 | 64 | def encode(self, line, ending=True): 65 | seq = map(self.token_to_indice.__getitem__, line) 66 | if ending: 67 | seq.append(self.END) 68 | return seq 69 | 70 | def decode(self, seq): 71 | if not seq: 72 | return '' 73 
| if seq[-1] == self.END: 74 | seq = seq[:-1] 75 | line = ''.join(map(self.vocabs.__getitem__, seq)) 76 | return line 77 | 78 | def padding(seq, maxlen): 79 | return pad_sequences(seq, maxlen, padding='post', value=0) 80 | 81 | class WeightsSaver(Callback): 82 | def __init__(self, model, N): 83 | self.model = model 84 | self.N = N 85 | self.batch = 0 86 | 87 | def on_batch_end(self, batch, logs={}): 88 | if self.batch % self.N == 0: 89 | name = 'weights/weights%08d.hdf5' % self.batch 90 | self.model.save_weights(name) 91 | self.batch += 1 92 | 93 | class LanguageModel(object): 94 | def __init__(self): 95 | self.sqn = Sequencer() 96 | 97 | def save(self): 98 | pass 99 | 100 | def load(self): 101 | pass 102 | 103 | def build(self, hid_size, n_hid_layers, drp_rate, batch_size): 104 | cin = Input(batch_shape=(None, None)) 105 | voc_size = len(self.sqn.vocabs) 106 | # A trick to map categories to onehot encoding 107 | emb = Embedding(voc_size, voc_size, trainable=False, weights=[np.identity(voc_size)])(cin) 108 | prev = emb 109 | for i in range(n_hid_layers): 110 | lstm = LSTM(hid_size, return_sequences=True, implementation=2)(prev) 111 | dropout = Dropout(drp_rate)(lstm) 112 | prev = dropout 113 | cout = Dense(voc_size, activation='softmax')(prev) 114 | 115 | self.model = Model(inputs=cin, outputs=cout) 116 | self.model.summary() 117 | 118 | self.batch_size = batch_size 119 | 120 | def train(self, fname, maxlen, lr=1e-3): 121 | ref = [] 122 | 123 | for line in open(fname): 124 | line = line.strip() 125 | seq = self.sqn.encode(line) 126 | ref.append(seq) 127 | ref = np.array(ref) 128 | ref = padding(ref, maxlen+1) 129 | X, Y = ref[:, :-1], ref[:, 1:] 130 | Y = np.expand_dims(Y, -1) 131 | M = X>self.sqn.END 132 | M[:,0] = 0 133 | 134 | self.model.compile( 135 | loss='sparse_categorical_crossentropy', 136 | sample_weight_mode='temporal', 137 | optimizer=Adam(lr=lr) 138 | ) 139 | self.model.fit(X, Y, batch_size=self.batch_size, sample_weight=M, 140 | 
callbacks=[WeightsSaver(self.model, 500)], 141 | validation_split=0.01, 142 | epochs=3 143 | ) 144 | 145 | def array_str(arr): 146 | s = ', '.join(['%.8e' % x for x in arr]) 147 | return s+',\n' 148 | 149 | 150 | def sanitize_for_tf(name): 151 | #HACK for make the variable names consistent between THEANO and TENSORFLOW models 152 | return name.replace("KERNEL:0","KERNEL").replace("BIAS:0","BIAS") 153 | 154 | # Dumping the HDF5 weights to a model.c file 155 | # and specifies the dimension in model.h 156 | def dump(fname): 157 | f = h5py.File(fname) 158 | fheader = open('model.h', 'w') 159 | fctx = open('model.c', 'w') 160 | for name in f.attrs['layer_names']: 161 | if name.startswith('lstm') or name.startswith('dense'): 162 | layer = f[name][name] 163 | for elem in layer: 164 | shape = layer[elem].shape 165 | for i,n in enumerate(shape): 166 | current_row='int '+(name+'_%s_shape_%d = %d;\n'%(elem, i, n)).upper() 167 | current_row = sanitize_for_tf(current_row) 168 | fheader.write(current_row) 169 | elem_decl = 'const float '+(name+'_'+elem).upper()+'[]' 170 | 171 | elem_decl = sanitize_for_tf(elem_decl) 172 | 173 | fheader.write('extern '+elem_decl+';\n\n') 174 | 175 | fctx.write(elem_decl+' = {\n') 176 | mat = np.array(layer[elem]) 177 | if len(shape) == 2: 178 | for i in range(shape[0]): 179 | fctx.write(array_str(mat[i])) 180 | else: 181 | fctx.write(array_str(mat)) 182 | 183 | fctx.write('};\n\n') 184 | 185 | 186 | if __name__ == '__main__': 187 | prog_name = os.path.basename(sys.argv[0]) 188 | if prog_name == 'train': 189 | q = LanguageModel() 190 | q.build(256, 2, 0.5, 256) 191 | q.train(sys.argv[1], 60) 192 | elif prog_name == 'parse': 193 | parse(sys.argv[1]) 194 | elif prog_name == 'dump': 195 | dump(sys.argv[1]) 196 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | keras==2.0.4 2 | theano==0.9.0 3 | h5py 4 | numpy 5 | 
-------------------------------------------------------------------------------- /run_eval.bash: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | DATA=sorted.txt 6 | HEADER_CHAR=false 7 | 8 | # Calculate line numbers 9 | LINES=`wc -l < ${DATA}` 10 | TEST_LINES=$((${LINES}/100)) 11 | TRAIN_LINES=$((${LINES}-${TEST_LINES})) 12 | echo "Using ${TRAIN_LINES} for training, ${TEST_LINES} lines for testing" 13 | 14 | # Do the splitting 15 | if ${HEADER_CHAR}; then 16 | head -n ${TRAIN_LINES} ${DATA} | cut -f 1,2 | shuf > ${DATA}.train 17 | tail -n ${TEST_LINES} ${DATA} | cut -f 1,2 > ${DATA}.test 18 | else 19 | head -n ${TRAIN_LINES} ${DATA} | cut -f 2 | shuf > ${DATA}.train 20 | tail -n ${TEST_LINES} ${DATA} | cut -f 2 > ${DATA}.test 21 | fi 22 | cat ${DATA}.train ${DATA}.test > ${DATA}.whole 23 | # cut -f3 ${DATA}.test > ${DATA}.prefix 24 | tail -n ${TEST_LINES} ${DATA} | cut -f 3 > ${DATA}.prefix 25 | 26 | # train the model file 27 | mkdir -p weights 28 | ./train ${DATA}.whole # note that we pass in the whole file to get the testing loss 29 | mv weights/`ls -Art weights | tail -n 1` weights.hdf5 30 | ./dump weights.hdf5 31 | make clean 32 | make 33 | 34 | # do the prediction 35 | cat ${DATA}.prefix | ./stocsearch > result_stoc.txt 36 | cat ${DATA}.prefix | ./beamsearch > result_beam.txt 37 | cat ${DATA}.prefix | ./omnisearch > result_omni.txt 38 | # I think trie search should only use training set? 
39 | cat ${DATA}.prefix | ./triesearch ${DATA}.whole > result_trie.txt 40 | 41 | # count the appearance of results 42 | cut -f2 result_beam.txt | ./trielookup ${DATA}.whole 1 > result_beam_freq.txt 43 | cut -f2 result_omni.txt | ./trielookup ${DATA}.whole 1 > result_omni_freq.txt 44 | cut -f2 result_trie.txt | ./trielookup ${DATA}.whole 1 > result_trie_freq.txt 45 | 46 | # count the appearance of prefix 47 | cat ${DATA}.prefix | ./trielookup ${DATA}.whole 0 > prefix_freq.txt 48 | 49 | # evaluate the metrics 50 | echo 'beamsearch' 51 | python eval.py prefix_freq.txt result_beam_freq.txt 52 | echo 'omnisearch' 53 | python eval.py prefix_freq.txt result_omni_freq.txt 54 | echo 'triesearch' 55 | python eval.py prefix_freq.txt result_trie_freq.txt 56 | -------------------------------------------------------------------------------- /stocsearch: -------------------------------------------------------------------------------- 1 | qcomp -------------------------------------------------------------------------------- /train: -------------------------------------------------------------------------------- 1 | qcomp.py -------------------------------------------------------------------------------- /trielookup: -------------------------------------------------------------------------------- 1 | qcomp -------------------------------------------------------------------------------- /triesearch: -------------------------------------------------------------------------------- 1 | qcomp --------------------------------------------------------------------------------
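The metric that eval.py prints can be sketched in Python 3 as follows. This is a minimal illustration, not a replacement for eval.py (which is written for Python 2): `completion_metrics` is a hypothetical name, the inputs are assumed to be already parsed into lists of floats, and the group size `k` is exposed as a parameter where eval.py hard-codes 16.

```python
def completion_metrics(prefix_freq, result_freq, k=16):
    """Return (mean probability, mean hit) for top-k completion results.

    prefix_freq[i] is how often prefix i occurs in the dataset and
    result_freq holds k frequencies per prefix, one per completion.
    """
    n = len(result_freq) // k  # drop a trailing partial group
    # Total dataset frequency covered by the k completions of each prefix.
    hits = [sum(result_freq[i * k:(i + 1) * k]) for i in range(n)]
    # Fraction of each prefix's occurrences covered; epsilon avoids 0/0.
    probs = [h / (p + 1e-5) for h, p in zip(hits, prefix_freq[:n])]
    # Discard ratios above 1, as eval.py does.
    probs = [x for x in probs if x <= 1]
    mean = lambda xs: sum(xs) / len(xs)
    return mean(probs), mean(hits) / k
```

For example, `completion_metrics([10.0, 4.0], [3.0, 2.0, 1.0, 1.0], k=2)` covers 5 of 10 and 2 of 4 occurrences, so the mean probability is just under 0.5. Like eval.py, the sketch fails with a division by zero when no ratio survives the filter.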