├── .gitignore
├── LICENSE
├── Makefile
├── README.md
├── data
│   ├── download_text8.sh
│   ├── google_analogies_test_set
│   │   └── questions-words.txt
│   └── wikifil.pl
├── images
│   ├── visualize_nearest_man.png
│   └── visualize_nearest_science.png
└── src
    ├── compute-accuracy.c
    └── word2bits.cpp
/.gitignore:
--------------------------------------------------------------------------------
1 | *.#*
2 | *~
3 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | CC_MIKOLOV=g++
2 | CFLAGS=-O3 -march=native -lm -pthread -Wno-unused-result
3 |
4 | word2bits:
5 | $(CC_MIKOLOV) $(CFLAGS) ./src/word2bits.cpp -o word2bits
6 | compute_accuracy:
7 | $(CC_MIKOLOV) $(CFLAGS) ./src/compute-accuracy.c -o compute_accuracy
8 | clean:
9 | rm -f word2bits
10 | rm -f compute_accuracy
11 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Word2Bits - Quantized Word Vectors
2 |
3 | Word2Bits extends the Word2Vec algorithm to output high quality
4 | quantized word vectors that take 8x-16x less storage than
5 | regular word vectors. Read the details at [https://arxiv.org/abs/1803.05651](https://arxiv.org/abs/1803.05651).
6 |
7 | ## What are Quantized Word Vectors?
8 |
9 | Quantized word vectors are word vectors where each parameter
10 | is one of `2^bitlevel` values.
11 |
12 | For example, the 1-bit quantized vector for "king" looks something
13 | like
14 |
15 | ```
16 | 0.33333334 0.33333334 0.33333334 -0.33333334 -0.33333334 -0.33333334 0.33333334 0.33333334 -0.33333334 0.33333334 0.33333334 ...
17 | ```
18 |
19 | Since parameters are limited to one of `2^bitlevel` values, each parameter
20 | takes only `bitlevel` bits to represent; this drastically reduces
21 | the amount of storage that word vectors take.
22 |
23 | ## Download Pretrained Word Vectors
24 |
25 | - All word vectors are in Glove/Fasttext format (format details [here](https://fasttext.cc/docs/en/english-vectors.html)). Files are compressed using gzip.
26 |
27 | | # Bits per parameter | Dimension | Trained on | Vocabulary size | File Size (Compressed) | Download Link |
28 | |:---------------------------:|:-------------:|:----------------------:|:----------------:|:----------------------:|:-------------:|
29 | | 1 | 800 | English Wikipedia 2017 | Top 400k | 86M | [w2b_bitlevel1_size800_vocab400K.tar.gz](https://drive.google.com/open?id=107guTTy93J-y7UCO2ZA2spxRIFpoqhjh) |
30 | | 1 | 1000 | English Wikipedia 2017 | Top 400k | 106M | [w2b_bitlevel1_size1000_vocab400K.tar.gz](https://drive.google.com/open?id=1boP7aFnABifVKRD9M-lxLqH-BM-C5-Z6) |
31 | | 1 | 1200 | English Wikipedia 2017 | Top 400k | 126M | [w2b_bitlevel1_size1200_vocab400K.tar.gz](https://drive.google.com/open?id=1zmoHFd9KqsCvuvYqpMl0Si21wn9IHMq2) |
32 | | 2 | 400 | English Wikipedia 2017 | Top 400k | 67M | [w2b_bitlevel2_size400_vocab400K.tar.gz](https://drive.google.com/open?id=1KHNDZW9dawwy9Ie73fdnMKAcGfadTI5J) |
33 | | 2 | 800 | English Wikipedia 2017 | Top 400k | 134M | [w2b_bitlevel2_size800_vocab400K.tar.gz](https://drive.google.com/open?id=1l3G4tyI8mU7bGsMG0TTPiM4fucmniJaR) |
34 | | 2 | 1000 | English Wikipedia 2017 | Top 400k | 168M | [w2b_bitlevel2_size1000_vocab400K.tar.gz](https://drive.google.com/open?id=1RX5z-jjpylAKTxpVazWqQmkDZ0XnumsB) |
35 | | 32 | 200 | English Wikipedia 2017 | Top 400k | 364M | [w2b_bitlevel0_size200_vocab400K.tar.gz](https://drive.google.com/open?id=1HKiDirbJ9oxJN1HXGdczvmjWTIazE0Gb) |
36 | | 32 | 400 | English Wikipedia 2017 | Top 400k | 724M | [w2b_bitlevel0_size400_vocab400K.tar.gz](https://drive.google.com/open?id=1ToIOpo0uhfGG48qsOZeDacPmt7Sh0Uup) |
37 | | 32 | 800 | English Wikipedia 2017 | Top 400k | 1.4G | [w2b_bitlevel0_size800_vocab400K.tar.gz](https://drive.google.com/open?id=1IMev4MIQKSx5CPgGhTxZo2EJ7nsVEVGT) |
38 | | 32 | 1000 | English Wikipedia 2017 | Top 400k | 1.8G | [w2b_bitlevel0_size1000_vocab400K.tar.gz](https://drive.google.com/open?id=1CtNjaQqK2Aw-iIeqRXdJOVNIqeALzOTi) |
39 | | 1 | 800 | English Wikipedia 2017 | 3.7M (Full) | 812M | [w2b_bitlevel1_size800_vocab3.7M.tar.gz](https://drive.google.com/open?id=1fisO5pl3KbP5DEGqb3-b8RxOsbqzquZE) |
40 | | 2 | 400 | English Wikipedia 2017 | 3.7M (Full) | 671M | [w2b_bitlevel2_size400_vocab3.7M.tar.gz](https://drive.google.com/open?id=139YwOxwhoIgKACXUJnfOdecxIueEkwf9) |
41 | | 32 | 400 | English Wikipedia 2017 | 3.7M (Full) | 6.7G | [w2b_bitlevel0_size400_vocab3.7M.tar.gz](https://drive.google.com/open?id=1zyizh_oJ3RHtdaHdT_V7eQUTwETuXKm6) |
42 |
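Since the vector files are plain text after decompression, loading them needs no special tooling. A minimal reader sketch (the file name below is hypothetical; decompress the download first):

```
#include <fstream>
#include <string>
#include <vector>

int main() {
  std::ifstream in("w2b_bitlevel1_size800_vocab400K");  // hypothetical name
  long long words, dim;
  in >> words >> dim;              // header line: vocabulary size and dimension
  std::string word;
  std::vector<float> vec(dim);
  for (long long i = 0; i < words && (in >> word); i++) {
    for (long long d = 0; d < dim; d++) in >> vec[d];
    // ... use (word, vec) ...
  }
  return 0;
}
```
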
43 | ## Visualizing Quantized Word Vectors
44 |
 45 | ![Nearest word vectors to "man"](images/visualize_nearest_man.png)
 46 | ![Nearest word vectors to "science"](images/visualize_nearest_science.png)
 47 | (Note: every fifth word vector is labelled; the turquoise line marks the boundary between the word vectors nearest to and furthest from the target word.)
48 |
49 | ## Using the Code
50 |
51 | ### Quickstart
52 |
53 | Compile with
54 | ```
55 | make word2bits
56 | ```
57 |
58 | Run with
59 | ```
60 | ./word2bits -train input -bitlevel 1 -size 200 -window 10 -negative 12 -threads 2 -iter 5 -min-count 5 -output 1bit_200d_vectors -binary 0
61 | ```
62 | Description of the most common flags:
63 | ```
64 | -train Input corpus text file
65 | -bitlevel Number of bits for each parameter. 0 is full precision (or 32 bits).
66 | -size Word vector dimension
67 | -window Window size
68 | -negative Negative sample size
69 | -threads Number of threads to use to train
70 | -iter Number of epochs to train
 71 | -min-count  Minimum count; words appearing fewer than this many times are removed from the corpus.
72 | -output Path to write output word vectors
73 | -binary 0 to write in Glove format; 1 to write in binary format.
74 | ```
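
Note that `word2bits` writes the quantized values out as regular floats; the storage saving comes from the fact that each parameter needs only `bitlevel` bits once packed. A sketch of packing a 1-bit vector into a bitmap (an illustration of the idea, not something this repo's output path does):

```
#include <stdint.h>
#include <stddef.h>
#include <vector>

// Pack 1-bit parameters (+-1/3 floats) into a bitmap: 1 = positive,
// 0 = negative. 32x smaller than storing raw 32-bit floats.
std::vector<uint8_t> pack1bit(const std::vector<float> &params) {
  std::vector<uint8_t> bits((params.size() + 7) / 8, 0);
  for (size_t i = 0; i < params.size(); i++)
    if (params[i] > 0) bits[i / 8] |= 1u << (i % 8);
  return bits;
}
```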
75 |
76 | ### Example: Word2Bits on text8
77 |
78 | 1. Download and preprocess text8 (make sure you're in the Word2Bits base directory).
79 | ```
80 | bash data/download_text8.sh
81 | ```
82 |
 83 | 2. Compile Word2Bits and the `compute_accuracy` evaluator
84 | ```
85 | make word2bits
86 | ```
87 | ```
88 | make compute_accuracy
89 | ```
90 |
 91 | 3. Train 1-bit, 200-dimensional word vectors for 5 epochs using 4 threads (save in binary format so that `compute_accuracy` can read them)
92 | ```
93 | ./word2bits -bitlevel 1 -size 200 -window 8 -negative 24 -threads 4 -iter 5 -min-count 5 -train text8 -output 1b200d_vectors -binary 1
94 | ```
95 |
96 | (This will take several minutes. Run with more threads if you have more cores!)
97 |
98 | 4. Evaluate vectors on Google Analogy Task
99 | ```
100 | ./compute_accuracy ./1b200d_vectors < data/google_analogies_test_set/questions-words.txt
101 | ```
102 |
103 | You should see output like the following (Syntactic accuracy reads -nan until the first syntactic category has been scored):
104 | ```
105 | Starting eval...
106 | capital-common-countries:
107 | ACCURACY TOP1: 19.76 % (100 / 506)
108 | Total accuracy: 19.76 % Semantic accuracy: 19.76 % Syntactic accuracy: -nan %
109 | capital-world:
110 | ACCURACY TOP1: 8.81 % (239 / 2713)
111 | Total accuracy: 10.53 % Semantic accuracy: 10.53 % Syntactic accuracy: -nan %
112 | ...
113 | gram8-plural:
114 | ACCURACY TOP1: 19.92 % (251 / 1260)
115 | Total accuracy: 11.48 % Semantic accuracy: 13.27 % Syntactic accuracy: 10.25 %
116 | gram9-plural-verbs:
117 | ACCURACY TOP1: 6.09 % (53 / 870)
118 | Total accuracy: 11.20 % Semantic accuracy: 13.27 % Syntactic accuracy: 9.88 %
119 | Questions seen / total: 16284 19544 83.32 %
120 | ```
121 |
122 | Inspecting the vector file in hex should show something like:
123 | ```
124 | $ od --format=x1 --read-bytes=160 1b200d_vectors
125 | 0000000 36 30 32 33 38 20 32 30 30 0a 3c 2f 73 3e 20 ab
126 | 0000020 aa aa 3e ab aa aa 3e ab aa aa be ab aa aa be ab
127 | 0000040 aa aa 3e ab aa aa 3e ab aa aa 3e ab aa aa be ab
128 | ...
129 | 0000160 aa aa be ab aa aa 3e ab aa aa be ab aa aa 3e ab
130 | 0000200 aa aa be ab aa aa 3e ab aa aa 3e ab aa aa 3e ab
131 | 0000220 aa aa be ab aa aa 3e ab aa aa be ab aa aa 3e ab
132 | ```
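
The repeating four-byte patterns are the two 1-bit values in little-endian IEEE 754: `ab aa aa 3e` is 0.33333334 and `ab aa aa be` is -0.33333334 (the header "60238 200" and the first vocabulary word `</s>` account for the leading bytes). A small check:

```
#include <stdio.h>
#include <string.h>
#include <stdint.h>

int main() {
  uint8_t pos[4] = {0xab, 0xaa, 0xaa, 0x3e};  // bytes from the dump above
  uint8_t neg[4] = {0xab, 0xaa, 0xaa, 0xbe};
  float p, n;
  memcpy(&p, pos, 4);
  memcpy(&n, neg, 4);
  printf("%.8f %.8f\n", p, n);  // 0.33333334 -0.33333334
  return 0;
}
```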
133 |
--------------------------------------------------------------------------------
/data/download_text8.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Download enwik8 and clean it into text8 using wikifil.pl.
3 | # Run from the Word2Bits base directory: bash data/download_text8.sh
4 | set -e
5 | wget http://mattmahoney.net/dc/enwik8.zip
6 | unzip enwik8.zip
7 | perl data/wikifil.pl enwik8 > text8
8 | rm -f enwik8.zip
9 | rm -f enwik8

--------------------------------------------------------------------------------
/data/wikifil.pl:
--------------------------------------------------------------------------------
1 | #!/usr/bin/perl
2 |
3 | # Program to filter Wikipedia XML dumps to "clean" text consisting only of lowercase
4 | # letters (a-z, converted from A-Z), and spaces (never consecutive).
5 | # All other characters are converted to spaces. Only text which normally appears
6 | # in the web browser is displayed. Tables are removed. Image captions are
7 | # preserved. Links are converted to normal text. Digits are spelled out.
8 |
9 | # Written by Matt Mahoney, June 10, 2006. This program is released to the public domain.
10 |
11 | $/=">"; # input record separator
12 | while (<>) {
13 | if (/<text /) {$text=1;} # remove all but between <text> ... </text>
14 | if (/#redirect/i) {$text=0;} # remove #REDIRECT
15 | if ($text) {
16 |
17 | # Remove any text not normally visible
18 | if (/<\/text>/) {$text=0;}
19 | s/<.*>//; # remove xml tags
20 | s/&amp;/&/g; # decode URL encoded chars
21 | s/&lt;/</g;
22 | s/&gt;/>/g;
23 | s/<ref[^<]*<\/ref>//g; # remove references <ref...> ... </ref>
24 | s/<[^>]*>//g; # remove xhtml tags
25 | s/\[http:[^] ]*/[/g; # remove normal url, preserve visible text
26 | s/\|thumb//ig; # remove images links, preserve caption
27 | s/\|left//ig;
28 | s/\|right//ig;
29 | s/\|\d+px//ig;
30 | s/\[\[image:[^\[\]]*\|//ig;
31 | s/\[\[category:([^|\]]*)[^]]*\]\]/[[$1]]/ig; # show categories without markup
32 | s/\[\[[a-z\-]*:[^\]]*\]\]//g; # remove links to other languages
33 | s/\[\[[^\|\]]*\|/[[/g; # remove wiki url, preserve visible text
34 | s/\{\{[^\}]*\}\}//g; # remove {{icons}} and {tables}
35 | s/\{[^\}]*\}//g;
36 | s/\[//g; # remove [ and ]
37 | s/\]//g;
38 | s/&[^;]*;/ /g; # remove URL encoded chars
39 |
40 | # convert to lowercase letters and spaces, spell digits
41 | $_=" $_ ";
42 | tr/A-Z/a-z/;
43 | s/0/ zero /g;
44 | s/1/ one /g;
45 | s/2/ two /g;
46 | s/3/ three /g;
47 | s/4/ four /g;
48 | s/5/ five /g;
49 | s/6/ six /g;
50 | s/7/ seven /g;
51 | s/8/ eight /g;
52 | s/9/ nine /g;
53 | tr/a-z/ /cs;
54 | chop;
55 | print $_;
56 | }
57 | }
58 |
--------------------------------------------------------------------------------
/images/visualize_nearest_man.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/agnusmaximus/Word2Bits/d029cca3fac2c06e1928d2a71462530b2cbe54cd/images/visualize_nearest_man.png
--------------------------------------------------------------------------------
/images/visualize_nearest_science.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/agnusmaximus/Word2Bits/d029cca3fac2c06e1928d2a71462530b2cbe54cd/images/visualize_nearest_science.png
--------------------------------------------------------------------------------
/src/compute-accuracy.c:
--------------------------------------------------------------------------------
1 | // Copyright 2013 Google Inc. All Rights Reserved.
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #include <stdio.h>
16 | #include <stdlib.h>
17 | #include <string.h>
18 | #include <math.h>
19 | #include <malloc.h>
20 | #include <ctype.h>
21 |
22 | const long long max_size = 2000; // max length of strings
23 | const long long N = 1; // number of closest words
24 | const long long max_w = 50; // max length of vocabulary entries
25 |
26 | float quantize(float num, int bitlevel) {
27 |
28 | if (bitlevel == 0) {
29 | // Special bitlevel 0 => full precision
30 | return num;
31 | }
32 |
33 | // Extract sign
34 | float retval = 0;
35 | float sign = num < 0 ? -1 : 1;
36 | num *= sign;
37 |
38 | // Boundaries: 0
39 | if (bitlevel == 1) {
40 | return sign / 3;
41 | }
42 |
43 | // Determine boundary and discrete activation value (2 bits)
44 | // Boundaries: 0, .5
45 | if (bitlevel == 2) {
46 | if (num >= 0 && num <= .5) retval = .25;
47 | else retval = .75;
48 | }
49 |
50 | // Determine boundary and discrete activation value (4 bits = 16 values)
51 | // Boundaries: 0, .1, .2, .3, .4, .5, .6, .7, .8
52 | //real boundaries[] = {0, .25, .5, .75, 1, 1.25, 1.5, 1.75};
53 | if (bitlevel >= 4) {
54 | int segmentation = pow(2, bitlevel-1);
55 | int casted = (num * segmentation) + (float).5;
56 | casted = casted > segmentation ? segmentation : casted;
57 | retval = casted / (float)segmentation;
58 | }
59 |
60 | return sign * retval;
61 | }
62 |
63 | int main(int argc, char **argv)
64 | {
65 | FILE *f;
66 | char st1[max_size], st2[max_size], st3[max_size], st4[max_size], bestw[N][max_size], file_name[max_size];
67 | float dist, len, bestd[N], vec[max_size];
68 | long long words, size, a, b, c, d, b1, b2, b3, threshold = 0;
69 | int bitlevel = 0;
70 | float *M;
71 | char *vocab;
72 | int TCN, CCN = 0, TACN = 0, CACN = 0, SECN = 0, SYCN = 0, SEAC = 0, SYAC = 0, QID = 0, TQ = 0, TQS = 0;
73 | if (argc < 2) {
74 | printf("Usage: ./compute-accuracy \nwhere FILE contains word projections, and threshold is used to reduce vocabulary of the model for fast approximate evaluation (0 = off, otherwise typical value is 30000)\n");
75 | return 0;
76 | }
77 | strcpy(file_name, argv[1]);
78 | if (argc > 2) bitlevel = atoi(argv[2]);
79 | if (argc > 3) threshold = atoi(argv[3]);
80 | f = fopen(file_name, "rb");
81 | if (f == NULL) {
82 | printf("Input file not found\n");
83 | return -1;
84 | }
85 | fscanf(f, "%lld", &words);
86 | if (threshold) if (words > threshold) words = threshold;
87 | fscanf(f, "%lld", &size);
88 | vocab = (char *)malloc(words * max_w * sizeof(char));
89 | M = (float *)malloc(words * size * sizeof(float));
90 | if (M == NULL) {
91 | printf("Cannot allocate memory: %lld MB\n", words * size * sizeof(float) / 1048576);
92 | return -1;
93 | }
94 | printf("Starting eval...\n");
95 | fflush(stdout);
96 | for (b = 0; b < words; b++) {
97 | a = 0;
98 | while (1) {
99 | vocab[b * max_w + a] = fgetc(f);
100 | if (feof(f) || (vocab[b * max_w + a] == ' ')) break;
101 | if ((a < max_w) && (vocab[b * max_w + a] != '\n')) a++;
102 | }
103 | vocab[b * max_w + a] = 0;
104 | for (a = 0; a < max_w; a++) vocab[b * max_w + a] = toupper(vocab[b * max_w + a]);
105 | for (a = 0; a < size; a++) fread(&M[a + b * size], sizeof(float), 1, f);
106 | for (a = 0; a < size; a++) M[a+b*size] = quantize(M[a+b*size], bitlevel);
107 | len = 0;
108 | for (a = 0; a < size; a++) len += M[a + b * size] * M[a + b * size];
109 | len = sqrt(len);
110 | for (a = 0; a < size; a++) M[a + b * size] /= len;
111 | }
112 | fclose(f);
113 | TCN = 0;
114 | while (1) {
115 | for (a = 0; a < N; a++) bestd[a] = 0;
116 | for (a = 0; a < N; a++) bestw[a][0] = 0;
117 | scanf("%s", st1);
118 | for (a = 0; a < strlen(st1); a++) st1[a] = toupper(st1[a]);
119 | if ((!strcmp(st1, ":")) || (!strcmp(st1, "EXIT")) || feof(stdin)) {
120 | if (TCN == 0) TCN = 1;
121 | if (QID != 0) {
122 | printf("ACCURACY TOP1: %.2f %% (%d / %d)\n", CCN / (float)TCN * 100, CCN, TCN);
123 | printf("Total accuracy: %.2f %% Semantic accuracy: %.2f %% Syntactic accuracy: %.2f %% \n", CACN / (float)TACN * 100, SEAC / (float)SECN * 100, SYAC / (float)SYCN * 100);
124 | }
125 | QID++;
126 | scanf("%s", st1);
127 | if (feof(stdin)) break;
128 | printf("%s:\n", st1);
129 | TCN = 0;
130 | CCN = 0;
131 | continue;
132 | }
133 | if (!strcmp(st1, "EXIT")) break;
134 | scanf("%s", st2);
135 | for (a = 0; a < strlen(st2); a++) st2[a] = toupper(st2[a]);
136 | scanf("%s", st3);
137 | for (a = 0; a < strlen(st3); a++) st3[a] = toupper(st3[a]);
138 | scanf("%s", st4);
139 | for (a = 0; a < strlen(st4); a++) st4[a] = toupper(st4[a]);
140 | // Look up the query words in the vocabulary
141 | for (b = 0; b < words; b++) if (!strcmp(&vocab[b * max_w], st1)) break;
142 | b1 = b;
143 | for (b = 0; b < words; b++) if (!strcmp(&vocab[b * max_w], st2)) break;
144 | b2 = b;
145 | for (b = 0; b < words; b++) if (!strcmp(&vocab[b * max_w], st3)) break;
146 | b3 = b;
147 | for (a = 0; a < N; a++) bestd[a] = 0;
148 | for (a = 0; a < N; a++) bestw[a][0] = 0;
149 | TQ++;
150 | // Skip questions containing out-of-vocabulary words
151 | if (b1 == words) continue;
152 | if (b2 == words) continue;
153 | if (b3 == words) continue;
154 | for (b = 0; b < words; b++) if (!strcmp(&vocab[b * max_w], st4)) break;
155 | if (b == words) continue;
156 | // Analogy vector: st2 - st1 + st3
157 | for (a = 0; a < size; a++) vec[a] = (M[a + b2 * size] - M[a + b1 * size]) + M[a + b3 * size];
158 | TQS++;
159 | // Rank all words by dot product with the analogy vector
160 | for (c = 0; c < words; c++) {
161 | if (c == b1) continue;
162 | if (c == b2) continue;
163 | if (c == b3) continue;
164 | dist = 0;
165 | for (a = 0; a < size; a++) dist += vec[a] * M[a + c * size];
166 | for (a = 0; a < N; a++) {
167 | if (dist > bestd[a]) {
168 | for (d = N - 1; d > a; d--) {
169 | bestd[d] = bestd[d - 1];
170 | strcpy(bestw[d], bestw[d - 1]);
171 | }
172 | bestd[a] = dist;
173 | strcpy(bestw[a], &vocab[c * max_w]);
174 | break;
175 | }
176 | }
177 | }
178 | if (!strcmp(st4, bestw[0])) {
179 | CCN++;
180 | CACN++;
181 | if (QID <= 5) SEAC++; else SYAC++;
182 | }
183 | if (QID <= 5) SECN++; else SYCN++;
184 | TCN++;
185 | TACN++;
186 | }
187 | printf("Questions seen / total: %d %d %.2f %% \n", TQS, TQ, TQS/(float)TQ*100);
188 | return 0;
189 | }
190 |
--------------------------------------------------------------------------------
/src/word2bits.cpp:
--------------------------------------------------------------------------------
1 | // Copyright 2013 Google Inc. All Rights Reserved.
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 |
16 | #include <stdio.h>
17 | #include <stdlib.h>
18 | #include <string.h>
19 | #include <math.h>
20 | #include <pthread.h>
21 | #include <time.h>
22 | #include <limits>
23 | #include <iostream>
24 | #include <iomanip>
26 | using namespace std;
27 | typedef numeric_limits< double > dbl;
28 |
29 | #define MAX_STRING 4096
30 | #define EXP_TABLE_SIZE 1000
31 | #define MAX_EXP 6
32 | #define MAX_SENTENCE_LENGTH 1000
33 | #define MAX_CODE_LENGTH 40
34 |
35 | const int vocab_hash_size = 30000000; // Maximum 30 * 0.7 = 21M words in the vocabulary
36 |
37 | typedef float real; // Precision of float numbers
38 |
39 | struct vocab_word {
40 | long long cn;
41 | int *point;
42 | char *word, *code, codelen;
43 | };
44 |
45 | char train_file[MAX_STRING], output_file[MAX_STRING];
46 | char save_vocab_file[MAX_STRING], read_vocab_file[MAX_STRING];
47 | struct vocab_word *vocab;
48 | int binary = 0, debug_mode = 2, window = 5, min_count = 5, num_threads = 12, min_reduce = 1, bitlevel = 1;
49 | int *vocab_hash;
50 | long long vocab_max_size = 1000, vocab_size = 0, layer1_size = 100;
51 | long long train_words = 0, word_count_actual = 0, iter = 5, file_size = 0, classes = 0;
52 | bool save_every_epoch = 0;
53 | real alpha = 0.05, starting_alpha, sample = 1e-3;
54 | real reg = 0;
55 | double *thread_losses;
56 | real *u, *v, *expTable;
57 | clock_t start;
58 |
59 | int hs = 0, negative = 5;
60 | const int table_size = 1e8;
61 | int *table;
62 |
63 | ///////////////////////////////////////////
64 | // Word2Bits //
65 | ///////////////////////////////////////////
66 |
67 | real sigmoid(real val) {
68 | if (val > MAX_EXP) return 1;
69 | if (val < -MAX_EXP) return 1e-9;
70 | return 1 / (1 + (real)exp(-val));
71 | }
72 |
73 | real quantize(real num, int bitlevel) {
74 |
75 | if (bitlevel == 0) {
76 | // Special bitlevel 0 => full precision
77 | return num;
78 | }
79 |
80 | // Extract sign
81 | real retval = 0;
82 | real sign = num < 0 ? -1 : 1;
83 | num *= sign;
84 |
85 | // Boundaries: 0
86 | if (bitlevel == 1) {
87 | return sign / 3;
88 | }
89 |
90 | // Determine boundary and discrete activation value (2 bits)
91 | // Boundaries: 0, .5
92 | if (bitlevel == 2) {
93 | if (num >= 0 && num <= .5) retval = .25;
94 | else retval = .75;
95 | }
96 |
97 | // Determine boundary and discrete activation value (4 bits = 16 values)
98 | // Boundaries: 0, .1, .2, .3, .4, .5, .6, .7, .8
99 | //real boundaries[] = {0, .25, .5, .75, 1, 1.25, 1.5, 1.75};
100 | if (bitlevel >= 4) {
101 | int segmentation = pow(2, bitlevel-1);
102 | int casted = (num * segmentation) + (real).5;
103 | casted = casted > segmentation ? segmentation : casted;
104 | retval = casted / (real)segmentation;
105 | }
106 |
107 | return sign * retval;
108 | }
109 |
110 | /////////////////////////////////////////////////////////////
111 |
112 | void InitUnigramTable() {
113 | int a, i;
114 | double train_words_pow = 0;
115 | double d1, power = 0.75;
116 | table = (int *)malloc(table_size * sizeof(int));
117 | for (a = 0; a < vocab_size; a++) train_words_pow += pow(vocab[a].cn, power);
118 | i = 0;
119 | d1 = pow(vocab[i].cn, power) / train_words_pow;
120 | for (a = 0; a < table_size; a++) {
121 | table[a] = i;
122 | if (a / (double)table_size > d1) {
123 | i++;
124 | d1 += pow(vocab[i].cn, power) / train_words_pow;
125 | }
126 | if (i >= vocab_size) i = vocab_size - 1;
127 | }
128 | }
129 |
130 | // Reads a single word from a file, assuming space + tab + EOL to be word boundaries
131 | void ReadWord(char *word, FILE *fin, char *eof) {
132 | int a = 0, ch;
133 | while (1) {
134 | ch = fgetc(fin);
135 | if (ch == EOF) {
136 | *eof = 1;
137 | break;
138 | }
139 | if (ch == 13) continue;
140 | if ((ch == ' ') || (ch == '\t') || (ch == '\n')) {
141 | if (a > 0) {
142 | if (ch == '\n') ungetc(ch, fin);
143 | break;
144 | }
145 | if (ch == '\n') {
146 | strcpy(word, (char *)"</s>");
147 | return;
148 | } else continue;
149 | }
150 | word[a] = ch;
151 | a++;
152 | if (a >= MAX_STRING - 1) a--; // Truncate too long words
153 | }
154 | word[a] = 0;
155 | }
156 |
157 | // Returns hash value of a word
158 | int GetWordHash(char *word) {
159 | unsigned long long a, hash = 0;
160 | for (a = 0; a < strlen(word); a++) hash = hash * 257 + word[a];
161 | hash = hash % vocab_hash_size;
162 | return hash;
163 | }
164 |
165 | // Returns position of a word in the vocabulary; if the word is not found, returns -1
166 | int SearchVocab(char *word) {
167 | unsigned int hash = GetWordHash(word);
168 | while (1) {
169 | if (vocab_hash[hash] == -1) return -1;
170 | if (!strcmp(word, vocab[vocab_hash[hash]].word)) return vocab_hash[hash];
171 | hash = (hash + 1) % vocab_hash_size;
172 | }
173 | return -1;
174 | }
175 |
176 | // Reads a word and returns its index in the vocabulary
177 | int ReadWordIndex(FILE *fin, char *eof) {
178 | char word[MAX_STRING], eof_l = 0;
179 | ReadWord(word, fin, &eof_l);
180 | if (eof_l) {
181 | *eof = 1;
182 | return -1;
183 | }
184 | return SearchVocab(word);
185 | }
186 |
187 | // Adds a word to the vocabulary
188 | int AddWordToVocab(char *word) {
189 | unsigned int hash, length = strlen(word) + 1;
190 | if (length > MAX_STRING) length = MAX_STRING;
191 | vocab[vocab_size].word = (char *)calloc(length, sizeof(char));
192 | strcpy(vocab[vocab_size].word, word);
193 | vocab[vocab_size].cn = 0;
194 | vocab_size++;
195 | // Reallocate memory if needed
196 | if (vocab_size + 2 >= vocab_max_size) {
197 | vocab_max_size += 1000;
198 | vocab = (struct vocab_word *)realloc(vocab, vocab_max_size * sizeof(struct vocab_word));
199 | }
200 | hash = GetWordHash(word);
201 | while (vocab_hash[hash] != -1) hash = (hash + 1) % vocab_hash_size;
202 | vocab_hash[hash] = vocab_size - 1;
203 | return vocab_size - 1;
204 | }
205 |
206 | // Used later for sorting by word counts
207 | int VocabCompare(const void *a, const void *b) {
208 | long long l = ((struct vocab_word *)b)->cn - ((struct vocab_word *)a)->cn;
209 | if (l > 0) return 1;
210 | if (l < 0) return -1;
211 | return 0;
212 | }
213 |
214 | // Sorts the vocabulary by frequency using word counts
215 | void SortVocab() {
216 | int a, size;
217 | unsigned int hash;
218 | // Sort the vocabulary and keep </s> at the first position
219 | qsort(&vocab[1], vocab_size - 1, sizeof(struct vocab_word), VocabCompare);
220 | for (a = 0; a < vocab_hash_size; a++) vocab_hash[a] = -1;
221 | size = vocab_size;
222 | train_words = 0;
223 | for (a = 0; a < size; a++) {
224 | // Words occuring less than min_count times will be discarded from the vocab
225 | if ((vocab[a].cn < min_count) && (a != 0)) {
226 | vocab_size--;
227 | free(vocab[a].word);
228 | } else {
229 | // Hash will be re-computed, as after the sorting it is not actual
230 | hash=GetWordHash(vocab[a].word);
231 | while (vocab_hash[hash] != -1) hash = (hash + 1) % vocab_hash_size;
232 | vocab_hash[hash] = a;
233 | train_words += vocab[a].cn;
234 | }
235 | }
236 | vocab = (struct vocab_word *)realloc(vocab, (vocab_size + 1) * sizeof(struct vocab_word));
237 | // Allocate memory for the binary tree construction
238 | for (a = 0; a < vocab_size; a++) {
239 | vocab[a].code = (char *)calloc(MAX_CODE_LENGTH, sizeof(char));
240 | vocab[a].point = (int *)calloc(MAX_CODE_LENGTH, sizeof(int));
241 | }
242 | }
243 |
244 | // Reduces the vocabulary by removing infrequent tokens
245 | void ReduceVocab() {
246 | int a, b = 0;
247 | unsigned int hash;
248 | for (a = 0; a < vocab_size; a++) if (vocab[a].cn > min_reduce) {
249 | vocab[b].cn = vocab[a].cn;
250 | vocab[b].word = vocab[a].word;
251 | b++;
252 | } else free(vocab[a].word);
253 | vocab_size = b;
254 | for (a = 0; a < vocab_hash_size; a++) vocab_hash[a] = -1;
255 | for (a = 0; a < vocab_size; a++) {
256 | // Hash will be re-computed, as it is not actual
257 | hash = GetWordHash(vocab[a].word);
258 | while (vocab_hash[hash] != -1) hash = (hash + 1) % vocab_hash_size;
259 | vocab_hash[hash] = a;
260 | }
261 | fflush(stdout);
262 | min_reduce++;
263 | }
264 |
265 | void LearnVocabFromTrainFile() {
266 | char word[MAX_STRING], eof = 0;
267 | FILE *fin;
268 | long long a, i, wc = 0;
269 | for (a = 0; a < vocab_hash_size; a++) vocab_hash[a] = -1;
270 | fin = fopen(train_file, "rb");
271 | if (fin == NULL) {
272 | printf("ERROR: training data file not found!\n");
273 | exit(1);
274 | }
275 | vocab_size = 0;
276 | AddWordToVocab((char *)"</s>");
277 | while (1) {
278 | ReadWord(word, fin, &eof);
279 | if (eof) break;
280 | train_words++;
281 | wc++;
282 | if ((debug_mode > 1) && (wc >= 1000000)) {
283 | printf("%lldM%c", train_words / 1000000, 13);
284 | fflush(stdout);
285 | wc = 0;
286 | }
287 | i = SearchVocab(word);
288 | if (i == -1) {
289 | a = AddWordToVocab(word);
290 | vocab[a].cn = 1;
291 | } else vocab[i].cn++;
292 | if (vocab_size > vocab_hash_size * 0.7) ReduceVocab();
293 | }
294 | SortVocab();
295 | if (debug_mode > 0) {
296 | printf("Vocab size: %lld\n", vocab_size);
297 | printf("Words in train file: %lld\n", train_words);
298 | }
299 | file_size = ftell(fin);
300 | fclose(fin);
301 | }
302 |
303 | void SaveVocab() {
304 | long long i;
305 | FILE *fo = fopen(save_vocab_file, "wb");
306 | for (i = 0; i < vocab_size; i++) fprintf(fo, "%s %lld\n", vocab[i].word, vocab[i].cn);
307 | fclose(fo);
308 | }
309 |
310 | void ReadVocab() {
311 | long long a, i = 0;
312 | char c, eof = 0;
313 | char word[MAX_STRING];
314 | FILE *fin = fopen(read_vocab_file, "rb");
315 | if (fin == NULL) {
316 | printf("Vocabulary file not found\n");
317 | exit(1);
318 | }
319 | for (a = 0; a < vocab_hash_size; a++) vocab_hash[a] = -1;
320 | vocab_size = 0;
321 | while (1) {
322 | ReadWord(word, fin, &eof);
323 | if (eof) break;
324 | a = AddWordToVocab(word);
325 | if (fscanf(fin, "%lld%c", &vocab[a].cn, &c));
326 | i++;
327 | }
328 | SortVocab();
329 | if (debug_mode > 0) {
330 | printf("Vocab size: %lld\n", vocab_size);
331 | printf("Words in train file: %lld\n", train_words);
332 | }
333 | fin = fopen(train_file, "rb");
334 | if (fin == NULL) {
335 | printf("ERROR: training data file not found!\n");
336 | exit(1);
337 | }
338 | fseek(fin, 0, SEEK_END);
339 | file_size = ftell(fin);
340 | fclose(fin);
341 | }
342 |
343 | void InitNet() {
344 | long long a, b;
345 | unsigned long long next_random = 1;
346 | a = posix_memalign((void **)&u, 128, (long long)vocab_size * layer1_size * sizeof(real));
347 | if (u == NULL) {printf("Memory allocation failed\n"); exit(1);}
348 | a = posix_memalign((void **)&v, 128, (long long)vocab_size * layer1_size * sizeof(real));
349 | if (v == NULL) {printf("Memory allocation failed\n"); exit(1);}
350 | for (a = 0; a < vocab_size; a++) {
351 | for (b = 0; b < layer1_size; b++) {
352 | next_random = next_random * (unsigned long long)25214903917 + 11;
353 | v[a * layer1_size + b] = (((next_random & 0xFFFF) / (real)65536) - 0.5) ;
354 | }
355 | }
356 | for (a = 0; a < vocab_size; a++)
357 | for (b = 0; b < layer1_size; b++) {
358 | next_random = next_random * (unsigned long long)25214903917 + 11;
359 | u[a * layer1_size + b] = (((next_random & 0xFFFF) / (real)65536) - 0.5) ;
360 | }
361 | }
362 |
363 | void *TrainModelThread(void *id) {
364 | long long a, b, c, d, cw, word, last_word, sentence_length = 0, sentence_position = 0;
365 | long long word_count = 0, last_word_count = 0, sen[MAX_SENTENCE_LENGTH + 1];
366 | // local_iter = 1 to save the model every epoch
367 | long long l2, target, label, local_iter = 1;
368 | unsigned long long next_random = (long long)id;
369 | char eof = 0;
370 | int local_bitlevel = bitlevel;
371 | real f, g;
372 | clock_t now;
373 | real *context_avg = (real *)calloc(layer1_size, sizeof(real));
374 | real *context_avge = (real *)calloc(layer1_size, sizeof(real));
375 | double loss = 0, total_loss = 0;
376 | FILE *fi = fopen(train_file, "rb");
377 | fseek(fi, file_size / (long long)num_threads * (long long)id, SEEK_SET);
378 | while (1) {
379 | if (word_count - last_word_count > 10000) {
380 | word_count_actual += word_count - last_word_count;
381 | last_word_count = word_count;
382 | if ((debug_mode > 1)) {
383 | now=clock();
384 | printf("%cAlpha: %f Progress: %.2f%% Cost: %f Words/thread/sec: %.2fk ", 13, alpha,
385 | word_count_actual / (real)(iter * train_words + 1) * 100,
386 | loss,
387 | word_count_actual / ((real)(now - start + 1) / (real)CLOCKS_PER_SEC * 1000));
388 | loss = 0;
389 | fflush(stdout);
390 | }
391 | alpha = starting_alpha * (1 - word_count_actual / (real)(iter * train_words + 1));
392 | if (alpha < starting_alpha * 0.0001) alpha = starting_alpha * 0.0001;
393 | }
394 | if (sentence_length == 0) {
395 | while (1) {
396 | word = ReadWordIndex(fi, &eof);
397 | if (eof) break;
398 | if (word == -1) continue;
399 | word_count++;
400 | if (word == 0) break;
401 | // The subsampling randomly discards frequent words while keeping the ranking same
402 | if (sample > 0) {
403 | real ran = (sqrt(vocab[word].cn / (sample * train_words)) + 1) *
404 | (sample * train_words) / vocab[word].cn;
405 | next_random = next_random * (unsigned long long)25214903917 + 11;
406 | if (ran < (next_random & 0xFFFF) / (real)65536) continue;
407 | }
408 | sen[sentence_length] = word;
409 | sentence_length++;
410 | if (sentence_length >= MAX_SENTENCE_LENGTH) break;
411 | }
412 | sentence_position = 0;
413 | }
414 | if (eof || (word_count > train_words / num_threads)) {
415 | word_count_actual += word_count - last_word_count;
416 | local_iter--;
417 | if (local_iter == 0) break;
418 | word_count = 0;
419 | last_word_count = 0;
420 | sentence_length = 0;
421 | fseek(fi, file_size / (long long)num_threads * (long long)id, SEEK_SET);
422 | continue;
423 | }
424 | word = sen[sentence_position];
425 | if (word == -1) continue;
426 | for (c = 0; c < layer1_size; c++) context_avg[c] = 0;
427 | for (c = 0; c < layer1_size; c++) context_avge[c] = 0;
428 | next_random = next_random * (unsigned long long)25214903917 + 11;
429 | b = next_random % window;
430 | cw = 0;
431 | for (a = b; a < window * 2 + 1 - b; a++) if (a != window) {
432 | c = sentence_position - window + a;
433 | if (c < 0) continue;
434 | if (c >= sentence_length) continue;
435 | last_word = sen[c];
436 | if (last_word == -1) continue;
437 | real local_reg_loss = 0;
438 | for (c = 0; c < layer1_size; c++) {
439 | real cur_val = quantize(u[c + last_word * layer1_size], local_bitlevel);
440 | context_avg[c] += cur_val;
441 | local_reg_loss += cur_val * cur_val;
442 | }
443 | local_reg_loss = reg * local_reg_loss;
444 | loss += -local_reg_loss;
445 | total_loss += -local_reg_loss;
446 | cw++;
447 | }
448 | if (cw) {
449 | for (c = 0; c < layer1_size; c++) context_avg[c] /= cw;
450 | for (d = 0; d < negative + 1; d++) {
451 | if (d == 0) {
452 | target = word;
453 | label = 1;
454 | } else {
455 | next_random = next_random * (unsigned long long)25214903917 + 11;
456 | target = table[(next_random >> 16) % table_size];
457 | if (target == 0) target = next_random % (vocab_size - 1) + 1;
458 | if (target == word) continue;
459 | label = 0;
460 | }
461 | l2 = target * layer1_size;
462 | f = 0;
463 | real local_reg_loss = 0;
464 | for (c = 0; c < layer1_size; c++) {
465 | real cur_val = quantize(v[c + l2], local_bitlevel);
466 | f += context_avg[c] * cur_val;
467 |
468 | // Keep track of regularization loss
469 | local_reg_loss += cur_val * cur_val;
470 | }
471 | local_reg_loss = reg * local_reg_loss;
472 |
473 | if (f > MAX_EXP) g = (label - 1) * alpha;
474 | else if (f < -MAX_EXP) g = (label - 0) * alpha;
475 | else g = (label - expTable[(int)((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))]) * alpha;
476 |
477 | ////////////////////
478 | // Compute loss
479 | ////////////////////
480 | real dot_product = f * pow(-1, 1-label);
481 | real local_loss = log(sigmoid(dot_product));
482 | loss += local_loss - local_reg_loss;
483 | total_loss += local_loss - local_reg_loss;
484 | /////////////////////
485 |
486 | for (c = 0; c < layer1_size; c++) {
487 | context_avge[c] += g * quantize(v[c + l2], local_bitlevel);
488 | }
489 | for (c = 0; c < layer1_size; c++) {
490 | v[c + l2] += g * context_avg[c] - 2*alpha*reg*v[c + l2];
491 | }
492 | }
493 | // hidden -> in
494 | for (a = b; a < window * 2 + 1 - b; a++) if (a != window) {
495 | c = sentence_position - window + a;
496 | if (c < 0) continue;
497 | if (c >= sentence_length) continue;
498 | last_word = sen[c];
499 | if (last_word == -1) continue;
500 | for (c = 0; c < layer1_size; c++) {
501 | u[c + last_word * layer1_size] += context_avge[c] - 2*alpha*reg*u[c+last_word*layer1_size];
502 | }
503 | }
504 | }
505 | sentence_position++;
506 | if (sentence_position >= sentence_length) {
507 | sentence_length = 0;
508 | continue;
509 | }
510 | }
511 | thread_losses[(long long)id] = total_loss;
512 | fclose(fi);
513 | free(context_avg);
514 | free(context_avge);
515 | pthread_exit(NULL);
516 | }
517 |
518 | void TrainModel() {
519 | long a, b;
520 | FILE *fo;
521 | thread_losses = (double *)malloc(sizeof(double) * num_threads);
522 | pthread_t *pt = (pthread_t *)malloc(num_threads * sizeof(pthread_t));
523 | printf("Starting training using file %s\n", train_file);
524 | starting_alpha = alpha;
525 | if (read_vocab_file[0] != 0) ReadVocab(); else LearnVocabFromTrainFile();
526 | if (save_vocab_file[0] != 0) SaveVocab();
527 | if (output_file[0] == 0) return;
528 | InitNet();
529 | if (negative > 0) InitUnigramTable();
530 |
531 | start = clock();
532 | for (int iteration = 0; iteration < iter; iteration++) {
533 | printf("Starting epoch: %d\n", iteration);
534 | memset(thread_losses, 0, sizeof(double) * num_threads);
535 | for (a = 0; a < num_threads; a++) pthread_create(&pt[a], NULL, TrainModelThread, (void *)a);
536 | for (a = 0; a < num_threads; a++) pthread_join(pt[a], NULL);
537 | double total_loss_epoch = 0;
538 | for (a = 0; a < num_threads; a++) total_loss_epoch += thread_losses[a];
539 | printf("Epoch Loss: %lf\n", total_loss_epoch);
540 | char output_file_cur_iter[MAX_STRING] = {0};
541 | sprintf(output_file_cur_iter, "%s_epoch%d", output_file, iteration);
542 | if (classes == 0 && save_every_epoch) {
543 | fo = fopen(output_file_cur_iter, "wb");
544 | // Save the word vectors
545 | fprintf(fo, "%lld %lld\n", vocab_size, layer1_size);
546 | for (a = 0; a < vocab_size; a++) {
547 | fprintf(fo, "%s ", vocab[a].word);
548 | for (b = 0; b < layer1_size; b++) {
549 | float avg = u[a*layer1_size+b] + v[a*layer1_size+b];
550 | avg = quantize(avg, bitlevel);
551 | if (binary) fwrite(&avg, sizeof(float), 1, fo);
552 | else fprintf(fo, "%lf ", avg);
553 | }
554 | fprintf(fo, "\n");
555 | }
556 | fclose(fo);
557 | }
558 | }
559 |
560 | // Write an extra file
561 | fo = fopen(output_file, "wb");
562 | if (classes == 0) {
563 | // Save the word vectors
564 | fprintf(fo, "%lld %lld\n", vocab_size, layer1_size);
565 | for (a = 0; a < vocab_size; a++) {
566 | fprintf(fo, "%s ", vocab[a].word);
567 | for (b = 0; b < layer1_size; b++) {
568 | float avg = u[a*layer1_size+b] + v[a*layer1_size+b];
569 | avg = quantize(avg, bitlevel);
570 | if (binary) fwrite(&avg, sizeof(float), 1, fo);
571 | else fprintf(fo, "%lf ", avg);
572 | }
573 | fprintf(fo, "\n");
574 | }
575 | }
576 | fclose(fo);
577 | }
578 |
579 | int ArgPos(char *str, int argc, char **argv) {
580 | int a;
581 | for (a = 1; a < argc; a++) if (!strcmp(str, argv[a])) {
582 | if (a == argc - 1) {
583 | printf("Argument missing for %s\n", str);
584 | exit(1);
585 | }
586 | return a;
587 | }
588 | return -1;
589 | }
590 |
591 | int main(int argc, char **argv) {
592 | int i;
593 | output_file[0] = 0;
594 | save_vocab_file[0] = 0;
595 | read_vocab_file[0] = 0;
596 | if ((i = ArgPos((char *)"-save-every-epoch", argc, argv)) > 0) save_every_epoch = atoi(argv[i + 1]);
597 | if ((i = ArgPos((char *)"-bitlevel", argc, argv)) > 0) bitlevel = atoi(argv[i + 1]);
598 | if ((i = ArgPos((char *)"-size", argc, argv)) > 0) layer1_size = atoi(argv[i + 1]);
599 | if ((i = ArgPos((char *)"-reg", argc, argv)) > 0) reg = atof(argv[i + 1]);
600 | if ((i = ArgPos((char *)"-train", argc, argv)) > 0) strcpy(train_file, argv[i + 1]);
601 | if ((i = ArgPos((char *)"-debug", argc, argv)) > 0) debug_mode = atoi(argv[i + 1]);
602 | if ((i = ArgPos((char *)"-binary", argc, argv)) > 0) binary = atoi(argv[i + 1]);
603 | if ((i = ArgPos((char *)"-alpha", argc, argv)) > 0) alpha = atof(argv[i + 1]);
604 | if ((i = ArgPos((char *)"-output", argc, argv)) > 0) strcpy(output_file, argv[i + 1]);
605 | if ((i = ArgPos((char *)"-window", argc, argv)) > 0) window = atoi(argv[i + 1]);
606 | if ((i = ArgPos((char *)"-sample", argc, argv)) > 0) sample = atof(argv[i + 1]);
607 | if ((i = ArgPos((char *)"-negative", argc, argv)) > 0) negative = atoi(argv[i + 1]);
608 | if ((i = ArgPos((char *)"-threads", argc, argv)) > 0) num_threads = atoi(argv[i + 1]);
609 | if ((i = ArgPos((char *)"-iter", argc, argv)) > 0) iter = atoi(argv[i + 1]);
610 | if ((i = ArgPos((char *)"-min-count", argc, argv)) > 0) min_count = atoi(argv[i + 1]);
611 | if ((i = ArgPos((char *)"-classes", argc, argv)) > 0) classes = atoi(argv[i + 1]);
612 | vocab = (struct vocab_word *)calloc(vocab_max_size, sizeof(struct vocab_word));
613 | vocab_hash = (int *)calloc(vocab_hash_size, sizeof(int));
614 | expTable = (real *)malloc((EXP_TABLE_SIZE + 1) * sizeof(real));
615 | for (i = 0; i < EXP_TABLE_SIZE; i++) {
616 | expTable[i] = exp((i / (real)EXP_TABLE_SIZE * 2 - 1) * MAX_EXP); // Precompute the exp() table
617 | expTable[i] = expTable[i] / (expTable[i] + 1); // Precompute f(x) = x / (x + 1)
618 | }
619 | TrainModel();
620 | return 0;
621 | }
622 |
--------------------------------------------------------------------------------