├── .gitignore
├── LICENSE
├── Makefile
├── README.md
├── data
│   ├── download_text8.sh
│   ├── google_analogies_test_set
│   │   └── questions-words.txt
│   └── wikifil.pl
├── images
│   ├── visualize_nearest_man.png
│   └── visualize_nearest_science.png
└── src
    ├── compute-accuracy.c
    └── word2bits.cpp

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
*.#*
*~

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------

                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner.
      For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution.
      You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE.
      You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
CC_MIKOLOV=g++
CFLAGS=-O3 -march=native -lm -pthread -Wno-unused-result

word2bits:
	$(CC_MIKOLOV) $(CFLAGS) ./src/word2bits.cpp -o word2bits
compute_accuracy:
	$(CC_MIKOLOV) $(CFLAGS) ./src/compute-accuracy.c -o compute_accuracy
clean:
	rm -f word2bits
	rm -f compute_accuracy

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Word2Bits - Quantized Word Vectors

Word2Bits extends the Word2Vec algorithm to output high-quality
quantized word vectors that take 8x-16x less storage than
regular word vectors. Read the details at [https://arxiv.org/abs/1803.05651](https://arxiv.org/abs/1803.05651).

## What are Quantized Word Vectors?

Quantized word vectors are word vectors where each parameter
is one of `2^bitlevel` values.

For example, the 1-bit quantized vector for "king" looks something
like

```
0.33333334 0.33333334 0.33333334 -0.33333334 -0.33333334 -0.33333334 0.33333334 0.33333334 -0.33333334 0.33333334 0.33333334 ...
```

Since parameters are limited to one of `2^bitlevel` values, each parameter
takes only `bitlevel` bits to represent; this drastically reduces
the amount of storage that word vectors take.
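To make the storage claim concrete, here is a minimal standalone sketch (not
part of this repository; `pack_1bit` is a hypothetical helper) that packs the
signs of a 1-bit vector into a bit array and works out the footprint of a
400k-word x 800-dimension embedding matrix:

```c
#include <stdint.h>
#include <stdio.h>

/* Hypothetical helper: pack the signs of a 1-bit quantized vector
 * (entries are +-1/3) into a bit array, one bit per parameter.
 * `out` must be zero-initialized and hold at least dim/8 bytes. */
void pack_1bit(const float *vec, int dim, uint8_t *out) {
  for (int i = 0; i < dim; i++)
    if (vec[i] > 0) out[i / 8] |= (uint8_t)(1u << (i % 8));
}

int main(void) {
  float v[8] = {0.33333334f, -0.33333334f, 0.33333334f, 0.33333334f,
                -0.33333334f, -0.33333334f, 0.33333334f, -0.33333334f};
  uint8_t packed[1] = {0};
  pack_1bit(v, 8, packed);
  printf("packed byte: 0x%02x\n", packed[0]);  /* 0x4d: bits 0, 2, 3, 6 set */

  /* Storage for 400k words x 800 dimensions */
  long long params = 400000LL * 800;
  printf("float32 storage: %lld MB\n", params * 4 / (1 << 20));  /* ~1220 MB */
  printf("1-bit storage:   %lld MB\n", params / 8 / (1 << 20));  /* ~38 MB   */
  return 0;
}
```

One bit per parameter shrinks the same matrix from roughly 1.2 GB of raw
floats to under 40 MB, which is where the storage savings quoted above come
from.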
## Download Pretrained Word Vectors

- All word vectors are in GloVe/fastText format (format details [here](https://fasttext.cc/docs/en/english-vectors.html)). Files are compressed using gzip.

| # Bits per parameter | Dimension | Trained on | Vocabulary size | File Size (Compressed) | Download Link |
|:---------------------------:|:-------------:|:----------------------:|:----------------:|:----------------------:|:-------------:|
| 1 | 800 | English Wikipedia 2017 | Top 400k | 86M | [w2b_bitlevel1_size800_vocab400K.tar.gz](https://drive.google.com/open?id=107guTTy93J-y7UCO2ZA2spxRIFpoqhjh) |
| 1 | 1000 | English Wikipedia 2017 | Top 400k | 106M | [w2b_bitlevel1_size1000_vocab400K.tar.gz](https://drive.google.com/open?id=1boP7aFnABifVKRD9M-lxLqH-BM-C5-Z6) |
| 1 | 1200 | English Wikipedia 2017 | Top 400k | 126M | [w2b_bitlevel1_size1200_vocab400K.tar.gz](https://drive.google.com/open?id=1zmoHFd9KqsCvuvYqpMl0Si21wn9IHMq2) |
| 2 | 400 | English Wikipedia 2017 | Top 400k | 67M | [w2b_bitlevel2_size400_vocab400K.tar.gz](https://drive.google.com/open?id=1KHNDZW9dawwy9Ie73fdnMKAcGfadTI5J) |
| 2 | 800 | English Wikipedia 2017 | Top 400k | 134M | [w2b_bitlevel2_size800_vocab400K.tar.gz](https://drive.google.com/open?id=1l3G4tyI8mU7bGsMG0TTPiM4fucmniJaR) |
| 2 | 1000 | English Wikipedia 2017 | Top 400k | 168M | [w2b_bitlevel2_size1000_vocab400K.tar.gz](https://drive.google.com/open?id=1RX5z-jjpylAKTxpVazWqQmkDZ0XnumsB) |
| 32 | 200 | English Wikipedia 2017 | Top 400k | 364M | [w2b_bitlevel0_size200_vocab400K.tar.gz](https://drive.google.com/open?id=1HKiDirbJ9oxJN1HXGdczvmjWTIazE0Gb) |
| 32 | 400 | English Wikipedia 2017 | Top 400k | 724M | [w2b_bitlevel0_size400_vocab400K.tar.gz](https://drive.google.com/open?id=1ToIOpo0uhfGG48qsOZeDacPmt7Sh0Uup) |
| 32 | 800 | English Wikipedia 2017 | Top 400k | 1.4G | [w2b_bitlevel0_size800_vocab400K.tar.gz](https://drive.google.com/open?id=1IMev4MIQKSx5CPgGhTxZo2EJ7nsVEVGT) |
| 32 | 1000 | English Wikipedia 2017 | Top 400k | 1.8G | [w2b_bitlevel0_size1000_vocab400K.tar.gz](https://drive.google.com/open?id=1CtNjaQqK2Aw-iIeqRXdJOVNIqeALzOTi) |
| 1 | 800 | English Wikipedia 2017 | 3.7M (Full) | 812M | [w2b_bitlevel1_size800_vocab3.7M.tar.gz](https://drive.google.com/open?id=1fisO5pl3KbP5DEGqb3-b8RxOsbqzquZE) |
| 2 | 400 | English Wikipedia 2017 | 3.7M (Full) | 671M | [w2b_bitlevel2_size400_vocab3.7M.tar.gz](https://drive.google.com/open?id=139YwOxwhoIgKACXUJnfOdecxIueEkwf9) |
| 32 | 400 | English Wikipedia 2017 | 3.7M (Full) | 6.7G | [w2b_bitlevel0_size400_vocab3.7M.tar.gz](https://drive.google.com/open?id=1zyizh_oJ3RHtdaHdT_V7eQUTwETuXKm6) |

## Visualizing Quantized Word Vectors

![Nearest word vectors to "man"](images/visualize_nearest_man.png)

![Nearest word vectors to "science"](images/visualize_nearest_science.png)

(Note: every fifth word vector is labelled; the turquoise line marks the
boundary between the nearest and furthest word vectors from the target.)

## Using the Code

### Quickstart

Compile with
```
make word2bits
```

Run with
```
./word2bits -train input -bitlevel 1 -size 200 -window 10 -negative 12 -threads 2 -iter 5 -min-count 5 -output 1bit_200d_vectors -binary 0
```
Description of the most common flags:
```
-train       Input corpus text file
-bitlevel    Number of bits for each parameter. 0 is full precision (or 32 bits).
-size        Word vector dimension
-window      Window size
-negative    Negative sample size
-threads     Number of threads to use to train
-iter        Number of epochs to train
-min-count   Minimum count value. Words appearing fewer than this many times are removed from the corpus.
-output      Path to write output word vectors
-binary      0 to write in Glove format; 1 to write in binary format.
```

### Example: Word2Bits on text8

1. Download and preprocess text8 (make sure you're in the Word2Bits base directory).
```
bash data/download_text8.sh
```

2. Compile word2bits and compute_accuracy
```
make word2bits
```
```
make compute_accuracy
```

3. Train 1-bit, 200-dimensional word vectors for 5 epochs using 4 threads (saved in binary so that compute_accuracy can read them)
```
./word2bits -bitlevel 1 -size 200 -window 8 -negative 24 -threads 4 -iter 5 -min-count 5 -train text8 -output 1b200d_vectors -binary 1
```

(This will take several minutes. Run with more threads if you have more cores!)

4. Evaluate the vectors on the Google analogy task
```
./compute_accuracy ./1b200d_vectors < data/google_analogies_test_set/questions-words.txt
```

You should see output like:
```
Starting eval...
capital-common-countries:
ACCURACY TOP1: 19.76 % (100 / 506)
Total accuracy: 19.76 % Semantic accuracy: 19.76 % Syntactic accuracy: -nan %
capital-world:
ACCURACY TOP1: 8.81 % (239 / 2713)
Total accuracy: 10.53 % Semantic accuracy: 10.53 % Syntactic accuracy: -nan %
...
gram8-plural:
ACCURACY TOP1: 19.92 % (251 / 1260)
Total accuracy: 11.48 % Semantic accuracy: 13.27 % Syntactic accuracy: 10.25 %
gram9-plural-verbs:
ACCURACY TOP1: 6.09 % (53 / 870)
Total accuracy: 11.20 % Semantic accuracy: 13.27 % Syntactic accuracy: 9.88 %
Questions seen / total: 16284 19544 83.32 %
```
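Each reported score comes from analogy questions of the form `a : b :: c : ?`:
`compute_accuracy` forms the query vector `b - a + c` and picks the vocabulary
word with the largest dot product against it (rows are L2-normalized at load
time, so this is cosine similarity). A simplified sketch of that inner loop;
the real implementation is in `src/compute-accuracy.c`:

```c
/* Roughly what compute_accuracy does per question "a : b :: c : ?".
 * M holds row-normalized word vectors (words x size, flattened);
 * returns the index of the best candidate, or -1 if none scores > 0. */
long long answer_analogy(const float *M, long long words, long long size,
                         long long a, long long b, long long c) {
  long long best = -1;
  float bestd = 0;
  for (long long w = 0; w < words; w++) {
    if (w == a || w == b || w == c) continue;  /* exclude the question words */
    float dist = 0;
    for (long long i = 0; i < size; i++)
      dist += (M[b * size + i] - M[a * size + i] + M[c * size + i]) * M[w * size + i];
    if (dist > bestd) { bestd = dist; best = w; }
  }
  return best;  /* a question counts as correct iff best == the expected word */
}
```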
Inspecting the vector file in hex should show something like:
```
$ od --format=x1 --read-bytes=160 1b200d_vectors
0000000 36 30 32 33 38 20 32 30 30 0a 3c 2f 73 3e 20 ab
0000020 aa aa 3e ab aa aa 3e ab aa aa be ab aa aa be ab
0000040 aa aa 3e ab aa aa 3e ab aa aa 3e ab aa aa be ab
...
0000160 aa aa be ab aa aa 3e ab aa aa be ab aa aa 3e ab
0000200 aa aa be ab aa aa 3e ab aa aa 3e ab aa aa 3e ab
0000220 aa aa be ab aa aa 3e ab aa aa be ab aa aa 3e ab
```
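The leading bytes are the ASCII header `60238 200\n` (vocabulary size and
dimension) followed by the first token `</s>` and its vector; after that,
every 4-byte group is a little-endian IEEE-754 float. A small standalone
check (not repository code) decodes the two recurring byte patterns:

```c
#include <stdint.h>
#include <stdio.h>
#include <string.h>

int main(void) {
  /* Little-endian byte groups from the dump above */
  uint8_t pos[4] = {0xab, 0xaa, 0xaa, 0x3e};
  uint8_t neg[4] = {0xab, 0xaa, 0xaa, 0xbe};
  float f;
  memcpy(&f, pos, 4);
  printf("%.8f\n", f);  /* prints  0.33333334 == +1/3 */
  memcpy(&f, neg, 4);
  printf("%.8f\n", f);  /* prints -0.33333334 == -1/3 */
  return 0;
}
```

So 1-bit vectors are still written as 32-bit floats; quantization restricts
their values to +-1/3, and the storage savings are realized when that
restricted value set is bit-packed or compressed downstream.
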
--------------------------------------------------------------------------------
/data/download_text8.sh:
--------------------------------------------------------------------------------
wget http://mattmahoney.net/dc/enwik8.zip
unzip enwik8.zip
perl data/wikifil.pl enwik8 > text8
rm -f enwik8.zip
rm -f enwik8

--------------------------------------------------------------------------------
/data/wikifil.pl:
--------------------------------------------------------------------------------
#!/usr/bin/perl

# Program to filter Wikipedia XML dumps to "clean" text consisting only of lowercase
# letters (a-z, converted from A-Z), and spaces (never consecutive).
# All other characters are converted to spaces. Only text which normally appears
# in the web browser is displayed. Tables are removed. Image captions are
# preserved. Links are converted to normal text. Digits are spelled out.

# Written by Matt Mahoney, June 10, 2006. This program is released to the public domain.

$/=">";                     # input record separator
while (<>) {
  if (/<text /) {$text=1;}  # remove all but between <text> ... </text>
  if (/#redirect/i) {$text=0;}  # remove #REDIRECT
  if ($text) {

    # Remove any text not normally visible
    if (/<\/text>/) {$text=0;}
    s/<.*>//;               # remove xml tags
    s/&amp;/&/g;            # decode URL encoded chars
    s/&lt;/</g;
    s/&gt;/>/g;
    s/<ref[^<]*<\/ref>//g;  # remove references <ref...> ... </ref>
    s/<[^>]*>//g;           # remove xhtml tags
    s/\[http:[^] ]*/[/g;    # remove normal url, preserve visible text
    s/\|thumb//ig;          # remove images links, preserve caption
    s/\|left//ig;
    s/\|right//ig;
    s/\|\d+px//ig;
    s/\[\[image:[^\[\]]*\|//ig;
    s/\[\[category:([^|\]]*)[^]]*\]\]/[[$1]]/ig;  # show categories without markup
    s/\[\[[a-z\-]*:[^\]]*\]\]//g;  # remove links to other languages
    s/\[\[[^\|\]]*\|/[[/g;  # remove wiki url, preserve visible text
    s/\{\{[^\}]*\}\}//g;    # remove {{icons}} and {tables}
    s/\{[^\}]*\}//g;
    s/\[//g;                # remove [ and ]
    s/\]//g;
    s/&[^;]*;/ /g;          # remove URL encoded chars

    # convert to lowercase letters and spaces, spell digits
    $_=" $_ ";
    tr/A-Z/a-z/;
    s/0/ zero /g;
    s/1/ one /g;
    s/2/ two /g;
    s/3/ three /g;
    s/4/ four /g;
    s/5/ five /g;
    s/6/ six /g;
    s/7/ seven /g;
    s/8/ eight /g;
    s/9/ nine /g;
    tr/a-z/ /cs;
    chop;
    print $_;
  }
}

--------------------------------------------------------------------------------
/images/visualize_nearest_man.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/agnusmaximus/Word2Bits/d029cca3fac2c06e1928d2a71462530b2cbe54cd/images/visualize_nearest_man.png

--------------------------------------------------------------------------------
/images/visualize_nearest_science.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/agnusmaximus/Word2Bits/d029cca3fac2c06e1928d2a71462530b2cbe54cd/images/visualize_nearest_science.png

--------------------------------------------------------------------------------
/src/compute-accuracy.c:
--------------------------------------------------------------------------------
// Copyright 2013 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <malloc.h>
#include <ctype.h>

const long long max_size = 2000;  // max length of strings
const long long N = 1;            // number of closest words
const long long max_w = 50;       // max length of vocabulary entries

float quantize(float num, int bitlevel) {

  if (bitlevel == 0) {
    // Special bitlevel 0 => full precision
    return num;
  }

  // Extract sign
  float retval = 0;
  float sign = num < 0 ? -1 : 1;
  num *= sign;

  // 1 bit: single boundary at 0; keep only the sign, with fixed magnitude 1/3
  if (bitlevel == 1) {
    return sign / 3;
  }

  // Determine boundary and discrete activation value (2 bits)
  // Boundaries: 0, .5
  if (bitlevel == 2) {
    if (num >= 0 && num <= .5) retval = .25;
    else retval = .75;
  }

  // Determine discrete activation value (4 bits = 16 values):
  // 2^(bitlevel-1) uniform bins per sign, so outputs are multiples of
  // 1/2^(bitlevel-1), clamped to 1
  if (bitlevel >= 4) {
    int segmentation = pow(2, bitlevel-1);
    int casted = (num * segmentation) + (float).5;
    casted = casted > segmentation ? segmentation : casted;
    retval = casted / (float)segmentation;
  }

  return sign * retval;
}
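/*
 * Worked examples of the mapping above (magnitudes mirror across sign):
 *   bitlevel 1: quantize(0.7, 1)  -> +1/3;  quantize(-0.02, 1) -> -1/3
 *   bitlevel 2: quantize(0.3, 2)  -> 0.25 (|x| <= .5); quantize(-0.9, 2) -> -0.75
 *   bitlevel 4: segmentation = 8, so quantize(0.26, 4) -> round(0.26 * 8)/8 = 2/8 = 0.25
 * Every parameter lands in a small fixed set of values, which is what lets a
 * vector be represented with only `bitlevel` bits per parameter.
 */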

int main(int argc, char **argv)
{
  FILE *f;
  char st1[max_size], st2[max_size], st3[max_size], st4[max_size], bestw[N][max_size], file_name[max_size];
  float dist, len, bestd[N], vec[max_size];
  long long words, size, a, b, c, d, b1, b2, b3, threshold = 0;
  int bitlevel = 0;
  float *M;
  char *vocab;
  int TCN, CCN = 0, TACN = 0, CACN = 0, SECN = 0, SYCN = 0, SEAC = 0, SYAC = 0, QID = 0, TQ = 0, TQS = 0;
  if (argc < 2) {
    printf("Usage: ./compute-accuracy <FILE> <bitlevel> <threshold>\nwhere FILE contains word projections, bitlevel is the quantization level applied at load time (0 = full precision), and threshold is used to reduce vocabulary of the model for fast approximate evaluation (0 = off, otherwise typical value is 30000)\n");
    return 0;
  }
  strcpy(file_name, argv[1]);
  if (argc > 2) bitlevel = atoi(argv[2]);
  if (argc > 3) threshold = atoi(argv[3]);
  f = fopen(file_name, "rb");
  if (f == NULL) {
    printf("Input file not found\n");
    return -1;
  }
  fscanf(f, "%lld", &words);
  if (threshold) if (words > threshold) words = threshold;
  fscanf(f, "%lld", &size);
  vocab = (char *)malloc(words * max_w * sizeof(char));
  M = (float *)malloc(words * size * sizeof(float));
  if (M == NULL) {
    printf("Cannot allocate memory: %lld MB\n", words * size * sizeof(float) / 1048576);
    return -1;
  }
  printf("Starting eval...\n");
  fflush(stdout);
  for (b = 0; b < words; b++) {
    a = 0;
    while (1) {
      vocab[b * max_w + a] = fgetc(f);
      if (feof(f) || (vocab[b * max_w + a] == ' ')) break;
      if ((a < max_w) && (vocab[b * max_w + a] != '\n')) a++;
    }
    vocab[b * max_w + a] = 0;
    for (a = 0; a < max_w; a++) vocab[b * max_w + a] = toupper(vocab[b * max_w + a]);
    for (a = 0; a < size; a++) fread(&M[a + b * size], sizeof(float), 1, f);
    for (a = 0; a < size; a++) M[a + b * size] = quantize(M[a + b * size], bitlevel);
    len = 0;
    for (a = 0; a < size; a++) len += M[a + b * size] * M[a + b * size];
    len = sqrt(len);
    for (a = 0; a < size; a++) M[a + b * size] /= len;
  }
  fclose(f);
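  /*
   * Every row of M is now L2-normalized, so the dot products computed below
   * are cosine similarities; the analogy search over
   * vec = M[b2] - M[b1] + M[b3] is therefore a nearest-neighbor lookup by
   * cosine distance.
   */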
  TCN = 0;
  while (1) {
    for (a = 0; a < N; a++) bestd[a] = 0;
    for (a = 0; a < N; a++) bestw[a][0] = 0;
    scanf("%s", st1);
    for (a = 0; a < strlen(st1); a++) st1[a] = toupper(st1[a]);
    if ((!strcmp(st1, ":")) || (!strcmp(st1, "EXIT")) || feof(stdin)) {
      if (TCN == 0) TCN = 1;
      if (QID != 0) {
        printf("ACCURACY TOP1: %.2f %% (%d / %d)\n", CCN / (float)TCN * 100, CCN, TCN);
        printf("Total accuracy: %.2f %% Semantic accuracy: %.2f %% Syntactic accuracy: %.2f %% \n", CACN / (float)TACN * 100, SEAC / (float)SECN * 100, SYAC / (float)SYCN * 100);
      }
      QID++;
      scanf("%s", st1);
      if (feof(stdin)) break;
      printf("%s:\n", st1);
      TCN = 0;
      CCN = 0;
      continue;
    }
    if (!strcmp(st1, "EXIT")) break;
    scanf("%s", st2);
    for (a = 0; a < strlen(st2); a++) st2[a] = toupper(st2[a]);
    scanf("%s", st3);
    for (a = 0; a < strlen(st3); a++) st3[a] = toupper(st3[a]);
    scanf("%s", st4);
    for (a = 0; a < strlen(st4); a++) st4[a] = toupper(st4[a]);
    for (b = 0; b < words; b++) if (!strcmp(&vocab[b * max_w], st1)) break;
    b1 = b;
    for (b = 0; b < words; b++) if (!strcmp(&vocab[b * max_w], st2)) break;
    b2 = b;
    for (b = 0; b < words; b++) if (!strcmp(&vocab[b * max_w], st3)) break;
    b3 = b;
    for (a = 0; a < N; a++) bestd[a] = 0;
    for (a = 0; a < N; a++) bestw[a][0] = 0;
    TQ++;
    if (b1 == words) continue;
    if (b2 == words) continue;
    if (b3 == words) continue;
    for (b = 0; b < words; b++) if (!strcmp(&vocab[b * max_w], st4)) break;
    if (b == words) continue;
    for (a = 0; a < size; a++) vec[a] = (M[a + b2 * size] - M[a + b1 * size]) + M[a + b3 * size];
    TQS++;
    for (c = 0; c < words; c++) {
      if (c == b1) continue;
      if (c == b2) continue;
      if (c == b3) continue;
      dist = 0;
      for (a = 0; a < size; a++) dist += vec[a] * M[a + c * size];
      for (a = 0; a < N; a++) {
        if (dist > bestd[a]) {
          for (d = N - 1; d > a; d--) {
            bestd[d] = bestd[d - 1];
            strcpy(bestw[d], bestw[d - 1]);
          }
          bestd[a] = dist;
          strcpy(bestw[a], &vocab[c * max_w]);
          break;
        }
      }
    }
    if (!strcmp(st4, bestw[0])) {
      CCN++;
      CACN++;
      if (QID <= 5) SEAC++; else SYAC++;
    }
    if (QID <= 5) SECN++; else SYCN++;
    TCN++;
    TACN++;
  }
  printf("Questions seen / total: %d %d %.2f %% \n", TQS, TQ, TQS/(float)TQ*100);
  return 0;
}

--------------------------------------------------------------------------------
/src/word2bits.cpp:
--------------------------------------------------------------------------------
// Copyright 2013 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.


#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <time.h>
#include <pthread.h>
#include <unistd.h>
#include <limits>
#include <iostream>

using namespace std;
typedef numeric_limits< double > dbl;

#define MAX_STRING 4096
#define EXP_TABLE_SIZE 1000
#define MAX_EXP 6
#define MAX_SENTENCE_LENGTH 1000
#define MAX_CODE_LENGTH 40

const int vocab_hash_size = 30000000;  // Maximum 30 * 0.7 = 21M words in the vocabulary

typedef float real;                    // Precision of float numbers

struct vocab_word {
  long long cn;
  int *point;
  char *word, *code, codelen;
};

char train_file[MAX_STRING], output_file[MAX_STRING];
char save_vocab_file[MAX_STRING], read_vocab_file[MAX_STRING];
struct vocab_word *vocab;
int binary = 0, debug_mode = 2, window = 5, min_count = 5, num_threads = 12, min_reduce = 1, bitlevel = 1;
int *vocab_hash;
long long vocab_max_size = 1000, vocab_size = 0, layer1_size = 100;
long long train_words = 0, word_count_actual = 0, iter = 5, file_size = 0, classes = 0;
bool save_every_epoch = 0;
real alpha = 0.05, starting_alpha, sample = 1e-3;
real reg = 0;
double *thread_losses;
real *u, *v, *expTable;
clock_t start;

int hs = 0, negative = 5;
const int table_size = 1e8;
int *table;

///////////////////////////////////////////
//               Word2Bits               //
///////////////////////////////////////////

real sigmoid(real val) {
  if (val > MAX_EXP) return 1;
  if (val < -MAX_EXP) return 1e-9;  // small positive value so log(sigmoid(x)) stays finite
  return 1 / (1 + (real)exp(-val));
}

real quantize(real num, int bitlevel) {

  if (bitlevel == 0) {
    // Special bitlevel 0 => full precision
    return num;
  }

  // Extract sign
  real retval = 0;
  real sign = num < 0 ? -1 : 1;
  num *= sign;

  // 1 bit: single boundary at 0; keep only the sign, with fixed magnitude 1/3
  if (bitlevel == 1) {
    return sign / 3;
  }

  // Determine boundary and discrete activation value (2 bits)
  // Boundaries: 0, .5
  if (bitlevel == 2) {
    if (num >= 0 && num <= .5) retval = .25;
    else retval = .75;
  }

  // Determine discrete activation value (4 bits = 16 values):
  // 2^(bitlevel-1) uniform bins per sign, so outputs are multiples of
  // 1/2^(bitlevel-1), clamped to 1
  if (bitlevel >= 4) {
    int segmentation = pow(2, bitlevel-1);
    int casted = (num * segmentation) + (real).5;
    casted = casted > segmentation ? segmentation : casted;
    retval = casted / (real)segmentation;
  }

  return sign * retval;
}
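/*
 * Training-time note: u (input embeddings) and v (output embeddings) are
 * stored and updated at full precision; quantize() is applied to them on the
 * forward pass in TrainModelThread and again when the vectors are saved. The
 * full-precision weights accumulate small gradient steps while the loss is
 * always computed on quantized values, so the learned vectors stay accurate
 * after quantization.
 */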

/////////////////////////////////////////////////////////////

void InitUnigramTable() {
  int a, i;
  double train_words_pow = 0;
  double d1, power = 0.75;
  table = (int *)malloc(table_size * sizeof(int));
  for (a = 0; a < vocab_size; a++) train_words_pow += pow(vocab[a].cn, power);
  i = 0;
  d1 = pow(vocab[i].cn, power) / train_words_pow;
  for (a = 0; a < table_size; a++) {
    table[a] = i;
    if (a / (double)table_size > d1) {
      i++;
      d1 += pow(vocab[i].cn, power) / train_words_pow;
    }
    if (i >= vocab_size) i = vocab_size - 1;
  }
}
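/*
 * The table above implements word2vec's negative-sampling distribution:
 * a word w occupies a share of the table proportional to count(w)^0.75
 * (power = 0.75), which flattens the raw unigram distribution so very
 * frequent words are drawn as negatives less often than their counts alone
 * would dictate.
 */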

// Reads a single word from a file, assuming space + tab + EOL to be word boundaries
void ReadWord(char *word, FILE *fin, char *eof) {
  int a = 0, ch;
  while (1) {
    ch = fgetc(fin);
    if (ch == EOF) {
      *eof = 1;
      break;
    }
    if (ch == 13) continue;
    if ((ch == ' ') || (ch == '\t') || (ch == '\n')) {
      if (a > 0) {
        if (ch == '\n') ungetc(ch, fin);
        break;
      }
      if (ch == '\n') {
        strcpy(word, (char *)"</s>");  // newlines become the sentence-end token
        return;
      } else continue;
    }
    word[a] = ch;
    a++;
    if (a >= MAX_STRING - 1) a--;   // Truncate too long words
  }
  word[a] = 0;
}

// Returns hash value of a word
int GetWordHash(char *word) {
  unsigned long long a, hash = 0;
  for (a = 0; a < strlen(word); a++) hash = hash * 257 + word[a];
  hash = hash % vocab_hash_size;
  return hash;
}

// Returns position of a word in the vocabulary; if the word is not found, returns -1
int SearchVocab(char *word) {
  unsigned int hash = GetWordHash(word);
  while (1) {
    if (vocab_hash[hash] == -1) return -1;
    if (!strcmp(word, vocab[vocab_hash[hash]].word)) return vocab_hash[hash];
    hash = (hash + 1) % vocab_hash_size;
  }
  return -1;
}

// Reads a word and returns its index in the vocabulary
int ReadWordIndex(FILE *fin, char *eof) {
  char word[MAX_STRING], eof_l = 0;
  ReadWord(word, fin, &eof_l);
  if (eof_l) {
    *eof = 1;
    return -1;
  }
  return SearchVocab(word);
}

// Adds a word to the vocabulary
int AddWordToVocab(char *word) {
  unsigned int hash, length = strlen(word) + 1;
  if (length > MAX_STRING) length = MAX_STRING;
  vocab[vocab_size].word = (char *)calloc(length, sizeof(char));
  strcpy(vocab[vocab_size].word, word);
  vocab[vocab_size].cn = 0;
  vocab_size++;
  // Reallocate memory if needed
  if (vocab_size + 2 >= vocab_max_size) {
    vocab_max_size += 1000;
    vocab = (struct vocab_word *)realloc(vocab, vocab_max_size * sizeof(struct vocab_word));
  }
  hash = GetWordHash(word);
  while (vocab_hash[hash] != -1) hash = (hash + 1) % vocab_hash_size;
  vocab_hash[hash] = vocab_size - 1;
  return vocab_size - 1;
}

// Used later for sorting by word counts
int VocabCompare(const void *a, const void *b) {
  long long l = ((struct vocab_word *)b)->cn - ((struct vocab_word *)a)->cn;
  if (l > 0) return 1;
  if (l < 0) return -1;
  return 0;
}

// Sorts the vocabulary by frequency using word counts
void SortVocab() {
  int a, size;
  unsigned int hash;
  // Sort the vocabulary and keep </s> at the first position
  qsort(&vocab[1], vocab_size - 1, sizeof(struct vocab_word), VocabCompare);
  for (a = 0; a < vocab_hash_size; a++) vocab_hash[a] = -1;
  size = vocab_size;
  train_words = 0;
  for (a = 0; a < size; a++) {
    // Words occurring less than min_count times will be discarded from the vocab
    if ((vocab[a].cn < min_count) && (a != 0)) {
      vocab_size--;
      free(vocab[a].word);
    } else {
      // Hash will be re-computed, as after the sorting it is not actual
      hash = GetWordHash(vocab[a].word);
      while (vocab_hash[hash] != -1) hash = (hash + 1) % vocab_hash_size;
      vocab_hash[hash] = a;
      train_words += vocab[a].cn;
    }
  }
  vocab = (struct vocab_word *)realloc(vocab, (vocab_size + 1) * sizeof(struct vocab_word));
  // Allocate memory for the binary tree construction
  for (a = 0; a < vocab_size; a++) {
    vocab[a].code = (char *)calloc(MAX_CODE_LENGTH, sizeof(char));
    vocab[a].point = (int *)calloc(MAX_CODE_LENGTH, sizeof(int));
  }
}

// Reduces the vocabulary by removing infrequent tokens
void ReduceVocab() {
  int a, b = 0;
  unsigned int hash;
  for (a = 0; a < vocab_size; a++) if (vocab[a].cn > min_reduce) {
    vocab[b].cn = vocab[a].cn;
    vocab[b].word = vocab[a].word;
    b++;
  } else free(vocab[a].word);
  vocab_size = b;
  for (a = 0; a < vocab_hash_size; a++) vocab_hash[a] = -1;
  for (a = 0; a < vocab_size; a++) {
    // Hash will be re-computed, as it is not actual
    hash = GetWordHash(vocab[a].word);
    while (vocab_hash[hash] != -1) hash = (hash + 1) % vocab_hash_size;
    vocab_hash[hash] = a;
  }
  fflush(stdout);
  min_reduce++;
}

void LearnVocabFromTrainFile() {
  char word[MAX_STRING], eof = 0;
  FILE *fin;
  long long a, i, wc = 0;
  for (a = 0; a < vocab_hash_size; a++) vocab_hash[a] = -1;
  fin = fopen(train_file, "rb");
  if (fin == NULL) {
    printf("ERROR: training data file not found!\n");
    exit(1);
  }
  vocab_size = 0;
  AddWordToVocab((char *)"</s>");
  while (1) {
    ReadWord(word, fin, &eof);
    if (eof) break;
    train_words++;
    wc++;
    if ((debug_mode > 1) && (wc >= 1000000)) {
      printf("%lldM%c", train_words / 1000000, 13);
      fflush(stdout);
      wc = 0;
    }
    i = SearchVocab(word);
    if (i == -1) {
      a = AddWordToVocab(word);
      vocab[a].cn = 1;
    } else vocab[i].cn++;
    if (vocab_size > vocab_hash_size * 0.7) ReduceVocab();
  }
  SortVocab();
  if (debug_mode > 0) {
    printf("Vocab size: %lld\n", vocab_size);
    printf("Words in train file: %lld\n", train_words);
  }
  file_size = ftell(fin);
  fclose(fin);
}

void SaveVocab() {
  long long i;
  FILE *fo = fopen(save_vocab_file, "wb");
  for (i = 0; i < vocab_size; i++) fprintf(fo, "%s %lld\n", vocab[i].word, vocab[i].cn);
  fclose(fo);
}

void ReadVocab() {
  long long a, i = 0;
  char c, eof = 0;
  char word[MAX_STRING];
  FILE *fin = fopen(read_vocab_file, "rb");
  if (fin == NULL) {
    printf("Vocabulary file not found\n");
    exit(1);
  }
  for (a = 0; a < vocab_hash_size; a++) vocab_hash[a] = -1;
  vocab_size = 0;
  while (1) {
    ReadWord(word, fin, &eof);
    if (eof) break;
    a = AddWordToVocab(word);
    if (fscanf(fin, "%lld%c", &vocab[a].cn, &c));
    i++;
  }
  SortVocab();
  if (debug_mode > 0) {
    printf("Vocab size: %lld\n", vocab_size);
    printf("Words in train file: %lld\n", train_words);
  }
  fin = fopen(train_file, "rb");
  if (fin == NULL) {
    printf("ERROR: training data file not found!\n");
    exit(1);
  }
  fseek(fin, 0, SEEK_END);
  file_size = ftell(fin);
  fclose(fin);
}

void InitNet() {
  long long a, b;
  unsigned long long next_random = 1;
  a = posix_memalign((void **)&u, 128, (long long)vocab_size * layer1_size * sizeof(real));
  if (u == NULL) {printf("Memory allocation failed\n"); exit(1);}
  a = posix_memalign((void **)&v, 128, (long long)vocab_size * layer1_size * sizeof(real));
  if (v == NULL) {printf("Memory allocation failed\n"); exit(1);}
  for (a = 0; a < vocab_size; a++) {
    for (b = 0; b < layer1_size; b++) {
      next_random = next_random * (unsigned long long)25214903917 + 11;
      v[a * layer1_size + b] = (((next_random & 0xFFFF) / (real)65536) - 0.5);
    }
  }
  for (a = 0; a < vocab_size; a++)
    for (b = 0; b < layer1_size; b++) {
      next_random = next_random * (unsigned long long)25214903917 + 11;
      u[a * layer1_size + b] = (((next_random & 0xFFFF) / (real)65536) - 0.5);
    }
}
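/*
 * Both embedding matrices are initialized uniformly at random in
 * [-0.5, 0.5) per component, using the same linear congruential generator
 * (the 25214903917 multiplier) that the training loop uses for sampling.
 */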

void *TrainModelThread(void *id) {
  long long a, b, c, d, cw, word, last_word, sentence_length = 0, sentence_position = 0;
  long long word_count = 0, last_word_count = 0, sen[MAX_SENTENCE_LENGTH + 1];
  // local_iter = 1 to save the model every epoch
  long long l2, target, label, local_iter = 1;
  unsigned long long next_random = (long long)id;
  char eof = 0;
  int local_bitlevel = bitlevel;
  real f, g;
  clock_t now;
  real *context_avg = (real *)calloc(layer1_size, sizeof(real));
  real *context_avge = (real *)calloc(layer1_size, sizeof(real));
  double loss = 0, total_loss = 0;
  FILE *fi = fopen(train_file, "rb");
  fseek(fi, file_size / (long long)num_threads * (long long)id, SEEK_SET);
  while (1) {
    if (word_count - last_word_count > 10000) {
      word_count_actual += word_count - last_word_count;
      last_word_count = word_count;
      if ((debug_mode > 1)) {
        now = clock();
        printf("%cAlpha: %f Progress: %.2f%% Cost: %f Words/thread/sec: %.2fk ", 13, alpha,
               word_count_actual / (real)(iter * train_words + 1) * 100,
               loss,
               word_count_actual / ((real)(now - start + 1) / (real)CLOCKS_PER_SEC * 1000));
        loss = 0;
        fflush(stdout);
      }
      alpha = starting_alpha * (1 - word_count_actual / (real)(iter * train_words + 1));
      if (alpha < starting_alpha * 0.0001) alpha = starting_alpha * 0.0001;
    }
    if (sentence_length == 0) {
      while (1) {
        word = ReadWordIndex(fi, &eof);
        if (eof) break;
        if (word == -1) continue;
        word_count++;
        if (word == 0) break;
        // The subsampling randomly discards frequent words while keeping the ranking same
        if (sample > 0) {
          real ran = (sqrt(vocab[word].cn / (sample * train_words)) + 1) *
              (sample * train_words) / vocab[word].cn;
          next_random = next_random * (unsigned long long)25214903917 + 11;
          if (ran < (next_random & 0xFFFF) / (real)65536) continue;
        }
        sen[sentence_length] = word;
        sentence_length++;
        if (sentence_length >= MAX_SENTENCE_LENGTH) break;
      }
      sentence_position = 0;
    }
    if (eof || (word_count > train_words / num_threads)) {
      word_count_actual += word_count - last_word_count;
      local_iter--;
      if (local_iter == 0) break;
      word_count = 0;
      last_word_count = 0;
      sentence_length = 0;
      fseek(fi, file_size / (long long)num_threads * (long long)id, SEEK_SET);
      continue;
    }
    word = sen[sentence_position];
    if (word == -1) continue;
    for (c = 0; c < layer1_size; c++) context_avg[c] = 0;
    for (c = 0; c < layer1_size; c++) context_avge[c] = 0;
    next_random = next_random * (unsigned long long)25214903917 + 11;
    b = next_random % window;
    cw = 0;
    for (a = b; a < window * 2 + 1 - b; a++) if (a != window) {
      c = sentence_position - window + a;
      if (c < 0) continue;
      if (c >= sentence_length) continue;
      last_word = sen[c];
      if (last_word == -1) continue;
      real local_reg_loss = 0;
      for (c = 0; c < layer1_size; c++) {
        real cur_val = quantize(u[c + last_word * layer1_size], local_bitlevel);
        context_avg[c] += cur_val;
        local_reg_loss += cur_val * cur_val;
      }
      local_reg_loss = reg * local_reg_loss;
      loss += -local_reg_loss;
      total_loss += -local_reg_loss;
      cw++;
    }
    if (cw) {
      for (c = 0; c < layer1_size; c++) context_avg[c] /= cw;
      for (d = 0; d < negative + 1; d++) {
        if (d == 0) {
          target = word;
          label = 1;
        } else {
          next_random = next_random * (unsigned long long)25214903917 + 11;
          target = table[(next_random >> 16) % table_size];
          if (target == 0) target = next_random % (vocab_size - 1) + 1;
          if (target == word) continue;
          label = 0;
        }
        l2 = target * layer1_size;
        f = 0;
        real local_reg_loss = 0;
        for (c = 0; c < layer1_size; c++) {
          real cur_val = quantize(v[c + l2], local_bitlevel);
          f += context_avg[c] * cur_val;

          // Keep track of regularization loss
          local_reg_loss += cur_val * cur_val;
        }
        local_reg_loss = reg * local_reg_loss;

        if (f > MAX_EXP) g = (label - 1) * alpha;
        else if (f < -MAX_EXP) g = (label - 0) * alpha;
        else g = (label - expTable[(int)((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))]) * alpha;

        ////////////////////
        // Compute loss
        ////////////////////
        real dot_product = f * pow(-1, 1-label);  // +f for the positive pair, -f for negatives
        real local_loss = log(sigmoid(dot_product));
        loss += local_loss - local_reg_loss;
        total_loss += local_loss - local_reg_loss;
        /////////////////////

        for (c = 0; c < layer1_size; c++) {
          context_avge[c] += g * quantize(v[c + l2], local_bitlevel);
        }
        for (c = 0; c < layer1_size; c++) {
          v[c + l2] += g * context_avg[c] - 2*alpha*reg*v[c + l2];
        }
      }
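      /*
       * Above, g = (label - sigmoid(f)) * alpha is the scaled gradient of
       * the negative-sampling objective: label is 1 for the true center word
       * and 0 for each negative sample. expTable is a precomputed sigmoid
       * lookup, with f clamped to [-MAX_EXP, MAX_EXP]; the -2*alpha*reg*...
       * terms apply the L2 regularization controlled by the -reg flag.
       */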
      // hidden -> in
      for (a = b; a < window * 2 + 1 - b; a++) if (a != window) {
        c = sentence_position - window + a;
        if (c < 0) continue;
        if (c >= sentence_length) continue;
        last_word = sen[c];
        if (last_word == -1) continue;
        for (c = 0; c < layer1_size; c++) {
          u[c + last_word * layer1_size] += context_avge[c] - 2*alpha*reg*u[c+last_word*layer1_size];
        }
      }
    }
    sentence_position++;
    if (sentence_position >= sentence_length) {
      sentence_length = 0;
      continue;
    }
  }
  thread_losses[(long long)id] = total_loss;
  fclose(fi);
  free(context_avg);
  free(context_avge);
  pthread_exit(NULL);
}

void TrainModel() {
  long a, b;
  FILE *fo;
  thread_losses = (double *)malloc(sizeof(double) * num_threads);
  pthread_t *pt = (pthread_t *)malloc(num_threads * sizeof(pthread_t));
  printf("Starting training using file %s\n", train_file);
  starting_alpha = alpha;
  if (read_vocab_file[0] != 0) ReadVocab(); else LearnVocabFromTrainFile();
  if (save_vocab_file[0] != 0) SaveVocab();
  if (output_file[0] == 0) return;
  InitNet();
  if (negative > 0) InitUnigramTable();

  start = clock();
  for (int iteration = 0; iteration < iter; iteration++) {
    printf("Starting epoch: %d\n", iteration);
    memset(thread_losses, 0, sizeof(double) * num_threads);
    for (a = 0; a < num_threads; a++) pthread_create(&pt[a], NULL, TrainModelThread, (void *)a);
    for (a = 0; a < num_threads; a++) pthread_join(pt[a], NULL);
    double total_loss_epoch = 0;
    for (a = 0; a < num_threads; a++) total_loss_epoch += thread_losses[a];
    printf("Epoch Loss: %lf\n", total_loss_epoch);
    char output_file_cur_iter[MAX_STRING] = {0};
    sprintf(output_file_cur_iter, "%s_epoch%d", output_file, iteration);
    if (classes == 0 && save_every_epoch) {
      fo = fopen(output_file_cur_iter, "wb");
      // Save the word vectors
      fprintf(fo, "%lld %lld\n", vocab_size, layer1_size);
      for (a = 0; a < vocab_size; a++) {
        fprintf(fo, "%s ", vocab[a].word);
        for (b = 0; b < layer1_size; b++) {
          float avg = u[a*layer1_size+b] + v[a*layer1_size+b];  // sum of input and output embeddings
          avg = quantize(avg, bitlevel);
          if (binary) fwrite(&avg, sizeof(float), 1, fo);
          else fprintf(fo, "%lf ", avg);
        }
        fprintf(fo, "\n");
      }
      fclose(fo);
    }
  }

  // Write the final vectors
  fo = fopen(output_file, "wb");
  if (classes == 0) {
    // Save the word vectors
    fprintf(fo, "%lld %lld\n", vocab_size, layer1_size);
    for (a = 0; a < vocab_size; a++) {
      fprintf(fo, "%s ", vocab[a].word);
      for (b = 0; b < layer1_size; b++) {
        float avg = u[a*layer1_size+b] + v[a*layer1_size+b];  // sum of input and output embeddings
        avg = quantize(avg, bitlevel);
        if (binary) fwrite(&avg, sizeof(float), 1, fo);
        else fprintf(fo, "%lf ", avg);
      }
      fprintf(fo, "\n");
    }
  }
  fclose(fo);
}
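/*
 * Output format note: both the per-epoch files and the final file start with
 * a "vocab_size dimension" header line, followed by each word and its
 * vector. The stored vector is quantize(u + v, bitlevel); with -binary 1 the
 * components are written as raw 32-bit floats, which is the layout
 * compute-accuracy.c expects.
 */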

int ArgPos(char *str, int argc, char **argv) {
  int a;
  for (a = 1; a < argc; a++) if (!strcmp(str, argv[a])) {
    if (a == argc - 1) {
      printf("Argument missing for %s\n", str);
      exit(1);
    }
    return a;
  }
  return -1;
}

int main(int argc, char **argv) {
  int i;
  output_file[0] = 0;
  save_vocab_file[0] = 0;
  read_vocab_file[0] = 0;
  if ((i = ArgPos((char *)"-save-every-epoch", argc, argv)) > 0) save_every_epoch = atoi(argv[i + 1]);
  if ((i = ArgPos((char *)"-bitlevel", argc, argv)) > 0) bitlevel = atoi(argv[i + 1]);
  if ((i = ArgPos((char *)"-size", argc, argv)) > 0) layer1_size = atoi(argv[i + 1]);
  if ((i = ArgPos((char *)"-reg", argc, argv)) > 0) reg = atof(argv[i + 1]);
  if ((i = ArgPos((char *)"-train", argc, argv)) > 0) strcpy(train_file, argv[i + 1]);
  if ((i = ArgPos((char *)"-debug", argc, argv)) > 0) debug_mode = atoi(argv[i + 1]);
  if ((i = ArgPos((char *)"-binary", argc, argv)) > 0) binary = atoi(argv[i + 1]);
  if ((i = ArgPos((char *)"-alpha", argc, argv)) > 0) alpha = atof(argv[i + 1]);
  if ((i = ArgPos((char *)"-output", argc, argv)) > 0) strcpy(output_file, argv[i + 1]);
  if ((i = ArgPos((char *)"-window", argc, argv)) > 0) window = atoi(argv[i + 1]);
  if ((i = ArgPos((char *)"-sample", argc, argv)) > 0) sample = atof(argv[i + 1]);
  if ((i = ArgPos((char *)"-negative", argc, argv)) > 0) negative = atoi(argv[i + 1]);
  if ((i = ArgPos((char *)"-threads", argc, argv)) > 0) num_threads = atoi(argv[i + 1]);
  if ((i = ArgPos((char *)"-iter", argc, argv)) > 0) iter = atoi(argv[i + 1]);
  if ((i = ArgPos((char *)"-min-count", argc, argv)) > 0) min_count = atoi(argv[i + 1]);
  if ((i = ArgPos((char *)"-classes", argc, argv)) > 0) classes = atoi(argv[i + 1]);
  vocab = (struct vocab_word *)calloc(vocab_max_size, sizeof(struct vocab_word));
  vocab_hash = (int *)calloc(vocab_hash_size, sizeof(int));
  expTable = (real *)malloc((EXP_TABLE_SIZE + 1) * sizeof(real));
  for (i = 0; i < EXP_TABLE_SIZE; i++) {
    expTable[i] = exp((i / (real)EXP_TABLE_SIZE * 2 - 1) * MAX_EXP); // Precompute the exp() table
    expTable[i] = expTable[i] / (expTable[i] + 1);                   // Precompute f(x) = x / (x + 1)
  }
  TrainModel();
  return 0;
}
--------------------------------------------------------------------------------