├── LICENSE
├── README.md
├── Word2Vec
│   ├── .classpath
│   ├── .project
│   ├── .settings
│   │   └── org.eclipse.jdt.core.prefs
│   ├── bin
│   │   ├── Word2Vec$Builder.class
│   │   ├── Word2Vec$VocabWord.class
│   │   └── Word2Vec.class
│   ├── lib
│   │   ├── commons-cli-1.3.1.jar
│   │   ├── commons-codec-1.9.jar
│   │   ├── commons-logging-1.2.jar
│   │   ├── slf4j-api-1.7.21.jar
│   │   └── slf4j-simple-1.7.21.jar
│   └── src
│       └── Word2Vec.java
└── assets
    └── arguments.JPG
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 KimJunho
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Word2Vec In Java
2 | * https://code.google.com/archive/p/word2vec/source/default/source
3 | * A Java port of the original word2vec C code
4 |
5 | ## Usage
6 | * Put "Input.txt" in the folder containing the source code
7 | ```bash
8 | The format of Input.txt is as follows:
9 | There is one document per line
10 | All documents must be preprocessed
11 | ```
12 |
13 | * Preprocessing: each document should be split into space-separated words (e.g. with a morphological analyzer)
14 | * In Eclipse, you must supply program arguments (Run - Run Configurations...)
15 |
16 |
17 | 
18 |
19 |
20 | * a = input.txt, b = output.txt... that is, the names of the input and output files.
21 | * However, the code already hard-codes these names (Line 34, 35); a command-line alternative is shown in the Run section at the bottom of this README.
22 |
23 | ## Contents of "Input.txt" after preprocessing
24 | * Document 1 : KimJunho is interested in machine learning and deep learning
25 | * Document 2 : KimJunho is interested in recruiting professional researchers
26 | ```bash
27 | KimJunho interested machine learning deep learning
28 | KimJunho recruiting professional researchers
29 | ```
30 |
31 | ## Main Variable Description
32 | * See Line 877 (public static class Builder)
33 | ```java
34 | 1. cbow = false
35 | Selects which model to train: skip-gram or CBOW
36 | false : use the skip-gram model
37 | true : use the CBOW model
38 |
39 | 2. startingAlpha = 0.025F
40 | The initial learning rate
41 | Smaller values learn more accurately but train more slowly
42 |
43 | 3. window = 5
44 | How many surrounding words to look at during training
45 | The default value is 5, meaning up to 5 context words are considered on each side
46 |
47 | 4. negative = 0
48 | Selects the technique used to speed up the output-layer computation
49 | The two methods are Hierarchical Softmax and Negative Sampling
50 | If 0, Hierarchical Softmax is used
51 | Otherwise, Negative Sampling with this many negative examples; common values are 5~10
52 |
53 | 5. minCount = 5
54 | Only words that occur at least this many times are kept in the vocabulary
55 | If you want to learn every word, set minCount = 0
56 |
57 | 6. layerOneSize = 200
58 | The dimension of the word vectors
59 | The default value is 200
60 | Higher dimensions capture more nuance but train more slowly
61 | ```
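62 |
63 | * A minimal programmatic sketch of a training run using the Builder above (all of these Builder methods exist in Word2Vec.java; the file names match the hard-coded defaults):
64 | ```java
65 | Word2Vec.Builder b = new Word2Vec.Builder()
66 |         .trainFile("Input.txt")    // one preprocessed document per line
67 |         .outputFile("Output.txt")  // the word vectors are written here
68 |         .minCount(0)               // keep every word, even rare ones
69 |         .window(5);                // look at up to 5 surrounding words
70 | new Word2Vec(b).trainModel();      // skip-gram + hierarchical softmax by default
71 | ```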
72 |
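73 | ## Run
74 | * From a shell instead of Eclipse, something like this should work (a sketch: it assumes the compiled classes are in Word2Vec/bin and the bundled jars in Word2Vec/lib; on Windows use ; instead of : in the classpath):
75 | ```bash
76 | cd Word2Vec
77 | java -cp "bin:lib/*" Word2Vec -train Input.txt -output Output.txt -minCount 0
78 | ```
79 | * Output.txt starts with a header line (vocabulary size and vector dimension), followed by one word and its vector per line.
80 |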
--------------------------------------------------------------------------------
/Word2Vec/.classpath:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <classpath>
3 | <classpathentry kind="src" path="src"/>
4 | <classpathentry kind="con" path="org.eclipse.jdt.launching.JRE_CONTAINER/org.eclipse.jdt.internal.debug.ui.launcher.StandardVMType/JavaSE-1.8"/>
5 | <classpathentry kind="lib" path="lib/commons-cli-1.3.1.jar"/>
6 | <classpathentry kind="lib" path="lib/commons-codec-1.9.jar"/>
7 | <classpathentry kind="lib" path="lib/commons-logging-1.2.jar"/>
8 | <classpathentry kind="lib" path="lib/slf4j-api-1.7.21.jar"/>
9 | <classpathentry kind="lib" path="lib/slf4j-simple-1.7.21.jar"/>
10 | <classpathentry kind="output" path="bin"/>
11 | </classpath>
12 |
--------------------------------------------------------------------------------
/Word2Vec/.project:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <projectDescription>
3 | <name>Word2Vec</name>
4 | <comment></comment>
5 | <projects>
6 | </projects>
7 | <buildSpec>
8 | <buildCommand>
9 | <name>org.eclipse.jdt.core.javabuilder</name>
10 | <arguments>
11 | </arguments>
12 | </buildCommand>
13 | </buildSpec>
14 | <natures>
15 | <nature>org.eclipse.jdt.core.javanature</nature>
16 | </natures>
17 | </projectDescription>
18 |
--------------------------------------------------------------------------------
/Word2Vec/.settings/org.eclipse.jdt.core.prefs:
--------------------------------------------------------------------------------
1 | eclipse.preferences.version=1
2 | org.eclipse.jdt.core.compiler.codegen.inlineJsrBytecode=enabled
3 | org.eclipse.jdt.core.compiler.codegen.targetPlatform=1.8
4 | org.eclipse.jdt.core.compiler.codegen.unusedLocal=preserve
5 | org.eclipse.jdt.core.compiler.compliance=1.8
6 | org.eclipse.jdt.core.compiler.debug.lineNumber=generate
7 | org.eclipse.jdt.core.compiler.debug.localVariable=generate
8 | org.eclipse.jdt.core.compiler.debug.sourceFile=generate
9 | org.eclipse.jdt.core.compiler.problem.assertIdentifier=error
10 | org.eclipse.jdt.core.compiler.problem.enumIdentifier=error
11 | org.eclipse.jdt.core.compiler.source=1.8
12 |
--------------------------------------------------------------------------------
/Word2Vec/bin/Word2Vec$Builder.class:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/Word2VecJava/221720ce62fc90f9c64fb16a4441468e4ad23271/Word2Vec/bin/Word2Vec$Builder.class
--------------------------------------------------------------------------------
/Word2Vec/bin/Word2Vec$VocabWord.class:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/Word2VecJava/221720ce62fc90f9c64fb16a4441468e4ad23271/Word2Vec/bin/Word2Vec$VocabWord.class
--------------------------------------------------------------------------------
/Word2Vec/bin/Word2Vec.class:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/Word2VecJava/221720ce62fc90f9c64fb16a4441468e4ad23271/Word2Vec/bin/Word2Vec.class
--------------------------------------------------------------------------------
/Word2Vec/lib/commons-cli-1.3.1.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/Word2VecJava/221720ce62fc90f9c64fb16a4441468e4ad23271/Word2Vec/lib/commons-cli-1.3.1.jar
--------------------------------------------------------------------------------
/Word2Vec/lib/commons-codec-1.9.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/Word2VecJava/221720ce62fc90f9c64fb16a4441468e4ad23271/Word2Vec/lib/commons-codec-1.9.jar
--------------------------------------------------------------------------------
/Word2Vec/lib/commons-logging-1.2.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/Word2VecJava/221720ce62fc90f9c64fb16a4441468e4ad23271/Word2Vec/lib/commons-logging-1.2.jar
--------------------------------------------------------------------------------
/Word2Vec/lib/slf4j-api-1.7.21.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/Word2VecJava/221720ce62fc90f9c64fb16a4441468e4ad23271/Word2Vec/lib/slf4j-api-1.7.21.jar
--------------------------------------------------------------------------------
/Word2Vec/lib/slf4j-simple-1.7.21.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/Word2VecJava/221720ce62fc90f9c64fb16a4441468e4ad23271/Word2Vec/lib/slf4j-simple-1.7.21.jar
--------------------------------------------------------------------------------
/Word2Vec/src/Word2Vec.java:
--------------------------------------------------------------------------------
1 | import java.io.BufferedReader;
2 | import java.io.FileReader;
3 | import java.io.DataInput;
4 | import java.io.DataInputStream;
5 | import java.io.DataOutputStream;
6 | import java.io.EOFException;
7 | import java.io.File;
8 | import java.io.FileInputStream;
9 | import java.io.FileOutputStream;
10 | import java.io.FileWriter;
11 | import java.io.IOException;
12 | import java.io.RandomAccessFile;
13 | import java.text.DecimalFormat;
14 | import java.text.NumberFormat;
15 | import java.util.ArrayList;
16 | import java.util.Arrays;
17 | import java.util.List;
18 |
19 | import org.apache.commons.cli.CommandLine;
20 | import org.apache.commons.cli.CommandLineParser;
21 | import org.apache.commons.cli.HelpFormatter;
22 | import org.apache.commons.cli.Option;
23 | import org.apache.commons.cli.OptionBuilder;
24 | import org.apache.commons.cli.Options;
25 | import org.apache.commons.cli.PosixParser;
26 | import org.slf4j.Logger;
27 | import org.slf4j.LoggerFactory;
28 |
29 |
30 |
31 | @SuppressWarnings("deprecation")
32 | public class Word2Vec {
33 | private static final Logger log = LoggerFactory.getLogger(Word2Vec.class);
34 | private static final String input = "Input.txt";
35 | private static final String output = "Output.txt";
36 |
37 | class VocabWord implements Comparable<VocabWord> {
38 | VocabWord(String word) {
39 | this.word = word;
40 | }
41 | int cn = 0;
42 | int codelen;
43 | int[] point = new int[MAX_CODE_LENGTH];
44 | long[] code = new long[MAX_CODE_LENGTH];
45 | String word;
46 |
47 | @Override
48 | public int compareTo(VocabWord that) {
49 | if(that==null) {
50 | return 1;
51 | }
52 |
53 | return that.cn - this.cn;
54 | }
55 | @Override
56 | public String toString() {
57 | return this.cn + ": " + this.word;
58 | }
59 | }
60 |
61 | private static final int MAX_STRING = 100;
62 | private static final int EXP_TABLE_SIZE= 1000;
63 | private static final int MAX_EXP= 6;
64 | private static final int MAX_SENTENCE_LENGTH= 1000;
65 | private static final int MAX_CODE_LENGTH= 40;
66 | private static final int TABLE_SIZE = 100000000;
67 |
68 | // Maximum 30 * 0.7 = 21M words in the vocabulary
69 | private static final int VOCAB_HASH_SIZE = 30000000;
70 |
71 | private final int layerOneSize;
72 | private final File trainFile;
73 | private final File outputFile;
74 | private final File saveVocabFile;
75 | private final File readVocabFile;
76 | private final int window;
77 | private final int negative;
78 | private final int minCount;
79 | private final int numThreads;
80 | private final int classes;
81 | private final boolean binary;
82 | private final boolean cbow;
83 | private final boolean noHs;
84 | private final float startingAlpha;
85 | private final float sample;
86 | private final float[] expTable;
87 |
88 | private int minReduce = 1;
89 | private int vocabMaxSize = 1000;
90 | private VocabWord[] vocabWords = new VocabWord[vocabMaxSize];
91 | private int[] vocabHash = new int[VOCAB_HASH_SIZE];
92 | private Byte ungetc = null;
93 |
94 | private int vocabSize = 0;
95 | private long trainWords = 0;
96 | private long wordCountActual = 0;
97 | private int[] table;
98 |
99 | private float alpha;
100 |
101 | private float[] syn0;
102 | private float[] syn1;
103 | private float[] syn1neg;
104 |
105 | private long start;
106 |
107 | public Word2Vec(Builder b) {
108 | this.trainFile = b.trainFile;
109 | this.outputFile = b.outputFile;
110 | this.saveVocabFile = b.saveVocabFile;
111 | this.readVocabFile = b.readVocabFile;
112 | this.binary = b.binary;
113 | this.cbow = b.cbow;
114 | this.noHs = b.noHs;
115 | this.startingAlpha = b.startingAlpha;
116 | this.sample = b.sample;
117 | this.window = b.window;
118 | this.negative = b.negative;
119 | this.minCount = b.minCount;
120 | this.numThreads = b.numThreads;
121 | this.classes = b.classes;
122 | this.layerOneSize = b.layerOneSize;
123 |
124 | float[] tempExpTable = new float[EXP_TABLE_SIZE];
125 | for (int i = 0; i < tempExpTable.length; i++) {
126 | // Precompute the exp() table
127 | tempExpTable[i] = (float) Math.exp((i / (float) EXP_TABLE_SIZE * 2 - 1) * MAX_EXP);
128 | // Precompute f(x) = x / (x + 1)
129 | tempExpTable[i] = tempExpTable[i] / (tempExpTable[i] + 1);
130 | }
131 | expTable = tempExpTable;
132 | }
133 | private void readVocab() throws IOException {
134 | vocabSize = 0;
135 | for (int a = 0; a < VOCAB_HASH_SIZE; a++) { vocabHash[a] = -1; }
136 | addWordToVocab("</s>"); // saveVocab skips the </s> entry at element zero, so restore it here
137 | try(BufferedReader br = new BufferedReader(new FileReader(readVocabFile))) {
138 | String line;
139 | while((line = br.readLine()) != null) {
140 | String[] parts = line.split(" "); // each line was written by saveVocab as "word count"
141 | int a = addWordToVocab(parts[0]);
142 | vocabWords[a].cn = Integer.parseInt(parts[1]);
143 | }
144 | sortVocab();
145 | log.debug("Vocab size: {}", vocabSize);
146 | log.debug("Words in train file: {}", trainWords);
147 | }
148 | }
149 |
150 | private void learnVocabFromTrainFile() throws IOException {
151 | for (int a = 0; a < VOCAB_HASH_SIZE; a++) {
152 | vocabHash[a] = -1;
153 | }
154 | vocabSize = 0;
155 | addWordToVocab("</s>");
156 | try (DataInputStream is = new DataInputStream(new FileInputStream(trainFile))) {
157 | while (true) {
158 | String word = readWord(is);
159 | if (word==null) {
160 | break;
161 | }
162 | trainWords++;
163 | if(log.isTraceEnabled() && trainWords % 100000 == 0) {
164 | log.trace("{}K training words processed.", (trainWords/1000));
165 | }
166 | int i = searchVocab(word);
167 | if (i == -1) {
168 | i = addWordToVocab(word);
169 | vocabWords[i].cn = 1;
170 | } else
171 | vocabWords[i].cn++;
172 | if (vocabSize > VOCAB_HASH_SIZE * 0.7) {
173 | reduceVocab();
174 | }
175 | }
176 | } catch (IOException ioe) {
177 | throw ioe;
178 | }
179 | sortVocab();
180 | log.debug("Vocab size: {}", vocabSize);
181 | log.debug("Words in train file: {}", trainWords);
182 | }
183 |
184 | private void saveVocab() throws IOException {
185 | saveVocabFile.delete();
186 | try (FileWriter fw = new FileWriter(saveVocabFile)) {
187 | //Don't output the </s>, at element zero.
188 | for (int i = 1; i < vocabSize; i++) {
189 | fw.write(vocabWords[i].word);
190 | fw.write(" ");
191 | fw.write("" + vocabWords[i].cn);
192 | fw.write("\n");
193 | }
194 | }
195 | }
196 | private void initNet() {
197 | syn0 = new float[vocabSize * layerOneSize];
198 | if(!noHs) {
199 | syn1 = new float[vocabSize * layerOneSize];
200 | for(int b=0 ; b < layerOneSize ; b++) {
201 | for(int a=0 ; a<vocabSize ; a++) {
202 | syn1[a * layerOneSize + b] = 0;
203 | }
204 | }
205 | }
206 | if(negative>0) {
207 | syn1neg = new float[vocabSize * layerOneSize];
208 | for(int b=0 ; b < layerOneSize ; b++) {
209 | for(int a=0 ; a<vocabSize ; a++) {
210 | syn1neg[a * layerOneSize + b] = 0;
211 | }
212 | }
213 | }
214 | // Initialize the word vectors with small random values
215 | for(int b=0 ; b < layerOneSize ; b++) {
216 | for(int a=0 ; a<vocabSize ; a++) {
217 | syn0[a * layerOneSize + b] = (float) ((Math.random() - 0.5) / layerOneSize);
218 | }
219 | }
220 | createBinaryTree();
221 | }
222 |
223 | // Create binary Huffman tree using the word counts
224 | // Frequent words will have short unique binary codes
225 | private void createBinaryTree() {
226 | long[] count = new long[vocabSize * 2 + 1];
227 | long[] binary = new long[vocabSize * 2 + 1];
228 | int[] parentNode = new int[vocabSize * 2 + 1];
229 | for (int a = 0; a < vocabSize; a++) {
230 | count[a] = vocabWords[a].cn;
231 | }
232 | for (int a = vocabSize; a < vocabSize * 2; a++) {
233 | count[a] = (long) 1e15;
234 | }
235 | int pos1 = vocabSize - 1;
236 | int pos2 = vocabSize;
237 | int min1i;
238 | int min2i;
239 |
240 | // Following algorithm constructs the Huffman tree by adding one node at a time
241 | for (int a = 0; a < vocabSize - 1; a++) {
242 | // First, find two smallest nodes 'min1' and 'min2'
243 | if (pos1 >= 0) {
244 | if (count[pos1] < count[pos2]) {
245 | min1i = pos1;
246 | pos1--;
247 | } else {
248 | min1i = pos2;
249 | pos2++;
250 | }
251 | } else {
252 | min1i = pos2;
253 | pos2++;
254 | }
255 | if (pos1 >= 0) {
256 | if (count[pos1] < count[pos2]) {
257 | min2i = pos1;
258 | pos1--;
259 | } else {
260 | min2i = pos2;
261 | pos2++;
262 | }
263 | } else {
264 | min2i = pos2;
265 | pos2++;
266 | }
267 | count[vocabSize + a] = count[min1i] + count[min2i];
268 | parentNode[min1i] = vocabSize + a;
269 | parentNode[min2i] = vocabSize + a;
270 | binary[min2i] = 1;
271 | }
272 |
273 | // Now assign binary code to each vocabulary word
274 | long[] code = new long[MAX_CODE_LENGTH];
275 | int[] point = new int[MAX_CODE_LENGTH];
276 | for (int a = 0; a < vocabSize; a++) {
277 | int b = a;
278 | int i = 0;
279 | while (true) {
280 | code[i] = binary[b];
281 | point[i] = b;
282 | i++;
283 | b = parentNode[b];
284 | if (b == vocabSize * 2 - 2)
285 | break;
286 | }
287 | vocabWords[a].codelen = i;
288 | vocabWords[a].point[0] = vocabSize - 2;
289 | for (b = 0; b < i; b++) {
290 | vocabWords[a].code[i - b - 1] = code[b];
291 | vocabWords[a].point[i - b] = point[b] - vocabSize;
292 | }
293 | }
294 | }
295 | private void initUnigramTable() {
296 | table = new int[TABLE_SIZE]; long trainWordsPow = 0; // allocate the unigram table before filling it
297 | float power = 0.75F;
298 | for (int a = 0; a < vocabSize; a++) {
299 | trainWordsPow += Math.pow(vocabWords[a].cn, power);
300 | }
301 | int i = 0;
302 | float d1 = (float) Math.pow(vocabWords[i].cn, power) / (float) trainWordsPow;
303 | for (int a = 0; a < TABLE_SIZE; a++) {
304 | table[a] = i;
305 | if (a / (float) TABLE_SIZE > d1) {
306 | i++;
307 | d1 += Math.pow(vocabWords[i].cn, power) / (float) trainWordsPow;
308 | }
309 | if (i >= vocabSize) {
310 | i = vocabSize - 1;
311 | }
312 | }
313 | }
314 |
315 | //DataOutputStream#writeFloat writes the high byte first
316 | //but let's write the low byte first to give ourselves a better chance of
317 | //compatibility with the original c++ code
318 | private void writeFloat(float f, DataOutputStream out) throws IOException {
319 | int v = Float.floatToIntBits(f);
320 | out.write((v >>> 0) & 0xFF);
321 | out.write((v >>> 8) & 0xFF);
322 | out.write((v >>> 16) & 0xFF);
323 | out.write((v >>> 24) & 0xFF);
324 | }
325 |
326 | public void trainModel() {
327 | if(trainFile==null && readVocabFile==null) {
328 | throw new IllegalStateException("You must supply either a trainFile or a readVocabFile.");
329 | }
330 | alpha = startingAlpha;
331 | if(readVocabFile!=null) {
332 | try {
333 | log.info("Reading vocabulary from file {}.", readVocabFile);
334 | readVocab();
335 | } catch(IOException ioe) {
336 | log.error("There was a problem reading the vocabulary file.", ioe);
337 | return;
338 | }
339 | } else {
340 | log.info("Starting training using file {}.", trainFile);
341 | try {
342 | learnVocabFromTrainFile();
343 | } catch(IOException ioe) {
344 | log.error("There was a problem reading the training file.", ioe);
345 | return;
346 | }
347 | }
348 | if(saveVocabFile!=null) {
349 | try {
350 | saveVocab();
351 | } catch(IOException ioe) {
352 | log.error("There was a problem writing the vocabulary file.", ioe);
353 | return;
354 | }
355 | }
356 | if(outputFile==null) {
357 | return;
358 | }
359 | initNet();
360 | if(negative>0) {
361 | initUnigramTable();
362 | }
363 | start = System.nanoTime();
364 | //TODO: threads
365 | try {
366 | trainModelThread(0);
367 | } catch(IOException ioe) {
368 | log.error("There was a problem reading the training file.", ioe);
369 | return;
370 | }
371 | outputFile.delete();
372 | NumberFormat vectorTextFormat = new DecimalFormat("#.######");
373 | try(DataOutputStream os = new DataOutputStream(new FileOutputStream(outputFile))) {
374 | if(classes==0) {
375 | // Save the word vectors
376 | os.writeBytes("" + vocabSize + " " + layerOneSize + "\n");
377 | for (int a = 0; a < vocabSize; a++) {
378 | os.writeBytes(vocabWords[a].word);
379 | os.writeBytes(" ");
380 | if (binary) {
381 | for (int b = 0; b < layerOneSize; b++) {
382 | writeFloat(syn0[a * layerOneSize + b], os);
383 | }
384 | } else {
385 | for (int b = 0; b < layerOneSize; b++) {
386 | int index = a * layerOneSize + b;
387 | float value = syn0[index];
388 | os.writeBytes(vectorTextFormat.format(value) + " ");
389 | }
390 | }
391 | os.writeBytes("\n");
392 |
393 | }
394 | os.writeBytes("\n");
395 | } else {
396 | // Run K-means on the word vectors
397 | if((long) classes * layerOneSize > Integer.MAX_VALUE) {
398 | throw new RuntimeException("Number of classes times the size of Layer One cannot be greater than " + Integer.MAX_VALUE + " (" + classes + " * " + layerOneSize + ")");
399 | }
400 | int[] cl = new int[vocabSize];
401 | float[] cent = new float[classes * layerOneSize];
402 | int[] centcn = new int[classes];
403 | int numIterations = 10;
404 |
405 | for (int a = 0; a < vocabSize; a++) {
406 | cl[a] = a % classes;
407 | }
408 | for(int a = 0; a < numIterations; a++) {
409 | for (int b = 0; b < classes * layerOneSize; b++) {
410 | cent[b] = 0;
411 | }
412 | for (int b = 0; b < classes; b++) {
413 | centcn[b] = 1;
414 | }
415 | for (int c = 0; c < vocabSize; c++) {
416 | for (int d = 0; d < layerOneSize; d++) {
417 | cent[layerOneSize * cl[c] + d] += syn0[c * layerOneSize + d];
418 | }
419 | centcn[cl[c]]++;
420 | }
421 | for (int b = 0; b < classes; b++) {
422 | float closev = 0;
423 | for (int c = 0; c < layerOneSize; c++) {
424 | cent[layerOneSize * b + c] /= centcn[b];
425 | closev += cent[layerOneSize * b + c] * cent[layerOneSize * b + c];
426 | }
427 | closev = (float) Math.sqrt(closev);
428 | for (int c = 0; c < layerOneSize; c++) {
429 | cent[layerOneSize * b + c] /= closev;
430 | }
431 | }
432 | // Assign each word vector to the closest centroid
433 | for (int c = 0; c < vocabSize; c++) {
434 | float closev = -10;
435 | int closeid = 0;
436 | for (int d = 0; d < classes; d++) {
437 | float x = 0;
438 | for (int b = 0; b < layerOneSize; b++)
439 | x += cent[layerOneSize * d + b] * syn0[c * layerOneSize + b];
440 | if (x > closev) {
441 | closev = x;
442 | closeid = d;
443 | }
444 | }
445 | cl[c] = closeid;
446 | }
447 | }
448 | // Save the K-means classes
449 | for(int a=0 ; a< vocabSize ; a++) {
450 | os.writeBytes(vocabWords[a].word);
451 | os.writeBytes(" ");
452 | os.writeInt(cl[a]);
453 | }
454 | }
455 | } catch(IOException ioe) {
456 | log.error("There was a problem writing the output file", ioe);
457 | return;
458 | }
459 | }
460 | private void trainModelThread(int id) throws IOException {
461 | try(RandomAccessFile raf = new RandomAccessFile(trainFile, "r")) {
462 | if(id>0) {
463 | raf.seek(raf.length() / numThreads * id);
464 | }
465 | long wordCount = 0;
466 | long lastWordCount = 0;
467 | int word = 0;
468 | int target = 0;
469 | int label = 0;
470 | int sentenceLength = 0;
471 | int sentencePosition = 0;
472 | long nextRandom = id;
473 | int[] sen = new int[MAX_SENTENCE_LENGTH + 1];
474 | float[] neu1 = new float[layerOneSize];
475 | float[] neu1e = new float[layerOneSize];
476 |
477 | NumberFormat alphaFormat = new DecimalFormat("0.000000");
478 | NumberFormat logPercentFormat = new DecimalFormat("#0.00%");
479 | NumberFormat wordsPerSecondFormat = new DecimalFormat("00.00k");
480 | long now = System.nanoTime();
481 | while(true) {
482 | if (wordCount - lastWordCount > 10000) {
483 | wordCountActual += wordCount - lastWordCount;
484 | lastWordCount = wordCount;
485 | if (log.isTraceEnabled()) {
486 | now = System.nanoTime();
487 | log.trace("Alpha: {}", alphaFormat.format(alpha));
488 | log.trace("Progress: {} ", logPercentFormat.format((float) wordCountActual / (trainWords + 1)));
489 | log.trace(
490 | "Words/thread/sec: {}\n",
491 | wordsPerSecondFormat.format((float) wordCountActual / (float) (now - start + 1)
492 | * 1000000));
493 | }
494 | alpha = startingAlpha * (1 - wordCountActual / (float) (trainWords + 1));
495 | if (alpha < startingAlpha * 0.0001F) {
496 | alpha = startingAlpha * 0.0001F;
497 | }
498 | }
499 | if(sentenceLength==0) {
500 | while(true) {
501 | word = readWordIndex(raf);
502 | if(word==-1) {
503 | break;
504 | }
505 | wordCount++;
506 | if(word==0) {
507 | break;
508 | }
509 | // The subsampling randomly discards frequent words while keeping the ranking same
510 | if (sample > 0) {
511 | float ran = (float) (Math.sqrt(vocabWords[word].cn / (sample * trainWords)) + 1) * (sample * trainWords) / vocabWords[word].cn;
512 | nextRandom = nextRandom * 25214903917L + 11;
513 | if (ran < ((nextRandom & 0xFFFF) / (float) 65536)) {
514 | continue;
515 | }
516 | }
517 | sen[sentenceLength] = word;
518 | sentenceLength++;
519 | if (sentenceLength >= MAX_SENTENCE_LENGTH) {
520 | break;
521 | }
522 | }
523 | sentencePosition = 0;
524 | }
525 | if(raf.getFilePointer()==raf.length()) {
526 | break;
527 | }
528 | if(wordCount > trainWords / numThreads) {
529 | break;
530 | }
531 | word = sen[sentencePosition];
532 | for(int c=0 ; c < layerOneSize ; c++) {
533 | neu1[c] = 0;
534 | neu1e[c] = 0;
535 | }
536 | nextRandom = nextRandom * 25214903917L + 11;
537 | int b = (int) Math.floorMod(nextRandom, (long) window); // random shrink of the window, kept in [0, window)
538 | if (cbow) { //train the cbow architecture
539 | // in -> hidden
540 | for(int a = b ; a < window * 2 + 1 - b ; a++) {
541 | if(a != window) {
542 | int c = sentencePosition - window + a;
543 | if (c < 0) {
544 | continue;
545 | }
546 | if (c >= sentenceLength) {
547 | continue;
548 | }
549 | int lastWord = sen[c];
550 | for (c = 0; c < layerOneSize; c++) {
551 | neu1[c] += syn0[c + lastWord * layerOneSize];
552 | }
553 | }
554 | }
555 | if (!noHs) {
556 | for(int d=0 ; d< vocabWords[word].codelen ; d++) {
557 | float f=0;
558 | int l2 = vocabWords[word].point[d] * layerOneSize;
559 | // Propagate hidden -> output
560 | for (int c = 0; c < layerOneSize; c++) {
561 | f += neu1[c] * syn1[c + l2];
562 | }
563 | if (f <= -1 * MAX_EXP || f >= MAX_EXP) {
564 | continue;
565 | }
566 | f = expTable[(int)((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))];
567 | // 'g' is the gradient multiplied by the learning rate
568 | float g = (1 - vocabWords[word].code[d] - f) * alpha;
569 | // Propagate errors output -> hidden
570 | for (int c = 0; c < layerOneSize; c++) {
571 | neu1e[c] += g * syn1[c + l2];
572 | }
573 | // Learn weights hidden -> output
574 | for (int c = 0; c < layerOneSize; c++) {
575 | syn1[c + l2] += g * neu1[c];
576 | }
577 | }
578 | }
579 | // NEGATIVE SAMPLING
580 | if (negative > 0) {
581 | for (int d = 0; d < negative + 1; d++) {
582 | if (d == 0) {
583 | target = word;
584 | label = 1;
585 | } else {
586 | nextRandom = nextRandom * 25214903917L + 11;
587 | target = table[(int) ((nextRandom >>> 16) % TABLE_SIZE)];
588 | if (target == 0) {
589 | target = (int) Math.floorMod(nextRandom, vocabSize - 1L) + 1;
590 | }
591 | if (target == word) {
592 | continue;
593 | }
594 | label = 0;
595 | }
596 | int l2 = target * layerOneSize;
597 | float f = 0;
598 | for (int c = 0; c < layerOneSize; c++) {
599 | f += neu1[c] * syn1neg[c + l2];
600 | }
601 | float g;
602 | if (f > MAX_EXP) {
603 | g = (label - 1) * alpha;
604 | } else if (f < -MAX_EXP) {
605 | g = (label - 0) * alpha;
606 | } else {
607 | g = (label - expTable[(int)((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))]) * alpha;
608 | }
609 | for (int c = 0; c < layerOneSize; c++) {
610 | neu1e[c] += g * syn1neg[c + l2];
611 | }
612 | for (int c = 0; c < layerOneSize; c++) {
613 | syn1neg[c + l2] += g * neu1[c];
614 | }
615 | }
616 | }
617 | // hidden -> in
618 | for (int a = b; a < window * 2 + 1 - b; a++) {
619 | if (a != window) {
620 | int c = sentencePosition - window + a;
621 | if (c < 0 || c >= sentenceLength) {
622 | continue;
623 | }
624 | int lastWord = sen[c];
625 | for (c = 0; c < layerOneSize; c++) {
626 | syn0[c + lastWord * layerOneSize] += neu1e[c];
627 | }
628 | }
629 | }
630 | } else { //train skip-gram
631 | for (int a = b; a < window * 2 + 1 - b; a++) {
632 | if (a != window) {
633 | int lastWordIndex = sentencePosition - window + a;
634 | if (lastWordIndex < 0 || lastWordIndex >= sentenceLength) {
635 | continue;
636 | }
637 | int lastWord = sen[lastWordIndex];
638 | int l1 = lastWord * layerOneSize;
639 | for (int c = 0; c < layerOneSize; c++) {
640 | neu1e[c] = 0;
641 | }
642 | // HIERARCHICAL SOFTMAX
643 | if (!noHs) {
644 | for (int d = 0; d < vocabWords[word].codelen; d++) {
645 | float f = 0;
646 | int l2 = vocabWords[word].point[d] * layerOneSize;
647 | // Propagate hidden -> output
648 | for (int c = 0; c < layerOneSize; c++) {
649 | f += syn0[c + l1] * syn1[c + l2];
650 | }
651 | if (f <= -MAX_EXP || f >= MAX_EXP) {
652 | continue;
653 | }
654 | f = expTable[(int)((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))];
655 | // 'g' is the gradient multiplied by the learning rate
656 | float g = (1 - vocabWords[word].code[d] - f) * alpha;
657 | // Propagate errors output -> hidden
658 | for (int c = 0; c < layerOneSize; c++) {
659 | neu1e[c] += g * syn1[c + l2];
660 | }
661 | // Learn weights hidden -> output
662 | for (int c = 0; c < layerOneSize; c++) {
663 | syn1[c + l2] += g * syn0[c + l1];
664 | }
665 | }
666 | }
667 | // NEGATIVE SAMPLING
668 | if (negative > 0) {
669 | for (int d = 0; d < negative + 1; d++) {
670 | if (d == 0) {
671 | target = word;
672 | label = 1;
673 | } else {
674 | nextRandom = nextRandom * 25214903917L + 11;
675 | target = table[(int) ((nextRandom >>> 16) % TABLE_SIZE)];
676 | if (target == 0) {
677 | target = (int) Math.floorMod(nextRandom, vocabSize - 1L) + 1;
678 | }
679 | if (target == word) {
680 | continue;
681 | }
682 | label = 0;
683 | }
684 | int l2 = target * layerOneSize;
685 | float f = 0;
686 | for (int c = 0; c < layerOneSize; c++) {
687 | f += syn0[c + l1] * syn1neg[c + l2];
688 | }
689 | float g;
690 | if (f > MAX_EXP) {
691 | g = (label - 1) * alpha;
692 | } else if (f < -MAX_EXP) {
693 | g = (label - 0) * alpha;
694 | } else {
695 | g = (label - expTable[(int)((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))]) * alpha;
696 | }
697 | for (int c = 0; c < layerOneSize; c++) {
698 | neu1e[c] += g * syn1neg[c + l2];
699 | }
700 | for (int c = 0; c < layerOneSize; c++) {
701 | syn1neg[c + l2] += g * syn0[c + l1];
702 | }
703 | }
704 | }
705 | // Learn weights input -> hidden
706 | for (int c = 0; c < layerOneSize; c++) {
707 | syn0[c + l1] += neu1e[c];
708 | }
709 | }
710 | }
711 | }
712 | sentencePosition++;
713 | if (sentencePosition >= sentenceLength) {
714 | sentenceLength = 0;
715 | continue;
716 | }
717 | }
718 | } catch(IOException ioe) {
719 | throw ioe;
720 | }
721 | }
722 |
723 | // Reduces the vocabulary by removing infrequent tokens
724 | private void reduceVocab() {
725 | int b=0;
726 | for (int a = 0; a < vocabSize; a++) {
727 | if (vocabWords[a].cn > minReduce) {
728 | vocabWords[b].cn = vocabWords[a].cn;
729 | vocabWords[b].word = vocabWords[a].word;
730 | b++;
731 | }
732 | }
733 | vocabSize = b;
734 | for (int a = 0; a < VOCAB_HASH_SIZE; a++) {
735 | vocabHash[a] = -1;
736 | }
737 | for (int a = 0; a < vocabSize; a++) {
738 | // Hash will be re-computed, as it is not actual
739 | int hash = getWordHash(vocabWords[a].word);
740 | while (vocabHash[hash] != -1) {
741 | hash = (hash + 1) % VOCAB_HASH_SIZE;
742 | hash = Math.abs(hash);
743 | }
744 | vocabHash[hash] = a;
745 | }
746 | minReduce++;
747 | }
748 | // Sorts the vocabulary by frequency using word counts
749 | private void sortVocab() {
750 | // Sort the vocabulary and keep </s> at the first position
751 | Arrays.sort(vocabWords, 1, vocabSize);
752 | for (int a = 0; a < vocabHash.length; a++) {
753 | vocabHash[a] = -1;
754 | }
755 |
756 | trainWords = 0;
757 | int originalVocabSize = vocabSize;
758 | List<VocabWord> wordList = new ArrayList<VocabWord>(originalVocabSize);
759 | int aa=0;
760 | for (int a = 0; a < originalVocabSize; a++) {
761 | VocabWord vw = vocabWords[a];
762 | // Words occurring less than min_count times will be discarded from the vocab
763 | if (vw.cn < minCount && vw.cn > 0) {
764 | vocabSize--;
765 | } else {
766 | // Hash will be re-computed, as after the sorting it is not actual
767 | int hash = getWordHash(vw.word);
768 | while (vocabHash[hash] != -1) {
769 | hash = (hash + 1) % VOCAB_HASH_SIZE;
770 | hash = Math.abs(hash);
771 | }
772 | vocabHash[hash] = aa;
773 | trainWords += vw.cn;
774 | wordList.add(vw);
775 | aa++;
776 | }
777 | }
778 | vocabWords = wordList.toArray(new VocabWord[wordList.size()]);
779 | }
780 |
781 | private int addWordToVocab(String word) {
782 | int length = word.length() + 1;
783 | if(length > MAX_STRING) {
784 | length = MAX_STRING;
785 | }
786 | vocabWords[vocabSize] = new VocabWord(word);
787 | vocabSize++;
788 |
789 | // Reallocate memory if needed
790 | if (vocabSize + 2 >= vocabMaxSize) {
791 | vocabMaxSize += 1000;
792 | VocabWord[] vocabWords1 = new VocabWord[vocabMaxSize];
793 | System.arraycopy(vocabWords, 0, vocabWords1, 0, vocabWords.length);
794 | vocabWords = vocabWords1;
795 | }
796 | int hash = getWordHash(word);
797 | while (vocabHash[hash] != -1) {
798 | hash = (hash + 1) % VOCAB_HASH_SIZE;
799 | hash = Math.abs(hash);
800 | }
801 | vocabHash[hash] = vocabSize - 1;
802 | return vocabSize - 1;
803 | }
804 | private int getWordHash(String word) {
805 | int hash = 0;
806 | for (int a = 0; a < word.length(); a++) {
807 | hash = hash * 257 + word.charAt(a);
808 | }
809 | hash = hash % VOCAB_HASH_SIZE;
810 | return Math.abs(hash);
811 | }
812 |
813 | // Returns position of a word in the vocabulary; if the word is not found, returns -1
814 | private int searchVocab(String word) {
815 | int hash = getWordHash(word);
816 | while (true) {
817 | if (vocabHash[hash] == -1) {
818 | return -1;
819 | }
820 | if(word.equals(vocabWords[vocabHash[hash]].word)) {
821 | return vocabHash[hash];
822 | }
823 | hash = (hash + 1) % VOCAB_HASH_SIZE;
824 | hash = Math.abs(hash);
825 | }
826 | }
827 |
828 | private int readWordIndex(RandomAccessFile raf) throws IOException {
829 | String word = readWord(raf);
830 | if(raf.length()==raf.getFilePointer()) {
831 | return -1;
832 | }
833 | return searchVocab(word);
834 | }
835 |
836 | private String readWord(DataInput dataInput) throws IOException {
837 | StringBuilder sb = new StringBuilder();
838 | while(true) {
839 | byte ch;
840 | if(ungetc != null) {
841 | ch = ungetc;
842 | ungetc = null;
843 | } else {
844 | try {
845 | ch = dataInput.readByte();
846 | } catch(EOFException eofe) {
847 | break;
848 | }
849 | }
850 | if(ch=='\r') {
851 | continue;
852 | }
853 | if ((ch == ' ') || (ch == '\t') || (ch == '\n')) {
854 | if (sb.length()>0 ) {
855 | if(ch == '\n') {
856 | ungetc = ch;
857 | }
858 | break;
859 | }
860 | if (ch == '\n') {
861 | return "</s>";
862 | } else {
863 | continue;
864 | }
865 | }
866 | sb.append((char) ch);
867 |
868 | // Truncate too long words
869 | if (sb.length() >= MAX_STRING - 1) {
870 | sb.deleteCharAt(sb.length()-1);
871 | }
872 | }
873 | String word = sb.length()==0 ? null : sb.toString();
874 | return word;
875 | }
876 | // Variable descriptions
877 | public static class Builder {
878 | private File trainFile = null;
879 | private File outputFile = null;
880 | private File saveVocabFile = null;
881 | private File readVocabFile = null;
882 | private boolean binary = false;
883 | private boolean cbow = false; // if false, use skip-gram model.. else, use cbow model
884 | private boolean noHs = false;
885 | private float startingAlpha = 0.025F; //0.025F, learning rate
886 | private float sample = 0.0F;
887 | private int window = 5; // Surrounding word
888 | private int negative = 0; // If negative = 0, use Hierarchical Softmax ...
889 | // otherwise Negative Sampling; typical values are 5 ~ 10
890 | private int minCount = 5; // minimum word count; set minCount = 0 to keep every word
891 | private int numThreads = 1;
892 | private int classes = 0;
893 | private int layerOneSize = 200; // vector size
894 | public Builder trainFile(String trainFile) {
895 | this.trainFile = new File(trainFile);
896 | return this;
897 | }
898 | public Builder outputFile(String outputFile) {
899 | this.outputFile = new File(outputFile);
900 | return this;
901 | }
902 | public Builder saveVocabFile(String saveVocabFile) {
903 | this.saveVocabFile = new File(saveVocabFile);
904 | return this;
905 | }
906 | public Builder readVocabFile(String readVocabFile) {
907 | this.readVocabFile = new File(readVocabFile);
908 | return this;
909 | }
910 | public Builder binary() {
911 | this.binary = true;
912 | return this;
913 | }
914 | public Builder cbow() {
915 | this.cbow = true;
916 | return this;
917 | }
918 | public Builder noHs() {
919 | this.noHs = true;
920 | return this;
921 | }
922 | public Builder startingAlpha(float startingAlpha) {
923 | this.startingAlpha = startingAlpha;
924 | return this;
925 | }
926 | public Builder sample(float sample) {
927 | this.sample = sample;
928 | return this;
929 | }
930 | public Builder window(int window) {
931 | this.window = window;
932 | return this;
933 | }
934 | public Builder negative(int negative) {
935 | this.negative = negative;
936 | return this;
937 | }
938 | public Builder minCount(int minCount) {
939 | this.minCount = minCount;
940 | return this;
941 | }
942 | public Builder numThreads(int numThreads) {
943 | this.numThreads = numThreads;
944 | return this;
945 | }
946 | public Builder classes(int classes) {
947 | this.classes = classes;
948 | return this;
949 | }
950 | public Builder layerOneSize(int layerOneSize) {
951 | this.layerOneSize = layerOneSize;
952 | return this;
953 | }
954 | }
955 |
956 | @SuppressWarnings({ "static-access" })
957 | public static void main(String[] args) {
958 | Builder b = new Builder();
959 | Options options = new Options();
960 | options.addOption(OptionBuilder.hasArg().withArgName("file")
961 | .withDescription("Use text data from to train the model").create("train"));
962 | options.addOption(OptionBuilder.hasArg().withArgName("file")
963 | .withDescription("Use to save the resulting word vectors / word clusters").create("output"));
964 | options.addOption(OptionBuilder.hasArg().withArgName("int")
965 | .withDescription("Set size of word vectors; default is " + b.layerOneSize).create("size"));
966 | options.addOption(OptionBuilder.hasArg().withArgName("int")
967 | .withDescription("Set max skip length between words; default is " + b.window).create("window"));
968 | options.addOption(OptionBuilder
969 | .hasArg()
970 | .withArgName("int")
971 | .withDescription(
972 | "Set threshold for occurrence of words (0=off). Those that appear with higher frequency in the training data will be randomly down-sampled; default is "
973 | + b.sample + ", useful value is 1e-5").create("sample"));
974 | options.addOption(new Option("noHs", false, "Disable use of Hierarchical Softmax; " + (b.noHs ? "off" : "on")
975 | + " by default"));
976 | options.addOption(OptionBuilder
977 | .hasArg()
978 | .withArgName("int")
979 | .withDescription(
980 | "Number of negative examples; default is " + b.negative
981 | + ", common values are 5 - 10 (0 = not used)").create("negative"));
982 | options.addOption(OptionBuilder.hasArg().withArgName("int")
983 | .withDescription("Use threads (default " + b.numThreads + ")").create("threads"));
984 | options.addOption(OptionBuilder.hasArg().withArgName("int")
985 | .withDescription("This will discard words that appear less than times; default is " + b.minCount)
986 | .create("minCount"));
987 | options.addOption(OptionBuilder.hasArg().withArgName("float")
988 | .withDescription("Set the starting learning rate; default is " + b.startingAlpha).create("startingAlpha"));
989 | options.addOption(OptionBuilder
990 | .hasArg()
991 | .withArgName("int")
992 | .withDescription(
993 | "Number of word classes to output, or 0 to output word vectors; default is " + b.classes)
994 | .create("classes"));
995 | options.addOption(new Option("binary", false, "Save the resulting vectors in binary moded; "
996 | + (b.binary ? "on" : "off") + " by default"));
997 | options.addOption(OptionBuilder.hasArg().withArgName("file")
998 | .withDescription("The vocabulary will be saved to ").create("saveVocab"));
999 | options.addOption(OptionBuilder.hasArg().withArgName("file")
1000 | .withDescription("The vocabulary will be read from , not constructed from the training data")
1001 | .create("readVocab"));
1002 | options.addOption(new Option("cbow", false, "Use the continuous bag of words model; " + (b.cbow ? "on" : "off")
1003 | + " by default (skip-gram model)"));
1004 |
1005 | CommandLineParser parser = new PosixParser();
1006 | try {
1007 | CommandLine cl = parser.parse(options, args);
1008 | if (cl.getOptions().length == 0) {
1009 | new HelpFormatter().printHelp(Word2Vec.class.getSimpleName(), options);
1010 | System.exit(0);
1011 | }
1012 | if (cl.hasOption("size")) {
1013 | b.layerOneSize = Integer.parseInt(cl.getOptionValue("size"));
1014 | }
1015 | if (cl.hasOption("train")) {
1016 | //b.trainFile = new File(cl.getOptionValue("train"));
1017 | b.trainFile = new File(input);
1018 | }
1019 | if (cl.hasOption("saveVocab")) {
1020 | b.saveVocabFile = new File(cl.getOptionValue("saveVocab"));
1021 | }
1022 | if (cl.hasOption("readVocab")) {
1023 | b.readVocabFile = new File(cl.getOptionValue("readVocab"));
1024 | }
1025 | if (cl.hasOption("binary")) {
1026 | b.binary = true;
1027 | }
1028 | if (cl.hasOption("cbow")) {
1029 | b.cbow = true;
1030 | }
1031 | if (cl.hasOption("startingAlpha")) {
1032 | b.startingAlpha = Float.parseFloat(cl.getOptionValue("startingAlpha"));
1033 | }
1034 | if (cl.hasOption("output")) {
1035 | //b.outputFile = new File(cl.getOptionValue("output"));
1036 | b.outputFile = new File(output);
1037 | }
1038 | if (cl.hasOption("window")) {
1039 | b.window = Integer.parseInt(cl.getOptionValue("window"));
1040 | }
1041 | if (cl.hasOption("sample")) {
1042 | b.sample = Float.parseFloat(cl.getOptionValue("sample"));
1043 | }
1044 | if (cl.hasOption("noHs")) {
1045 | b.noHs = true;
1046 | }
1047 | if (cl.hasOption("negative")) {
1048 | b.negative = Integer.parseInt(cl.getOptionValue("negative"));
1049 | }
1050 | if (cl.hasOption("threads")) {
1051 | b.numThreads = Integer.parseInt(cl.getOptionValue("threads"));
1052 | }
1053 | if (cl.hasOption("minCount")) {
1054 | b.minCount = Integer.parseInt(cl.getOptionValue("minCount"));
1055 | }
1056 | if (cl.hasOption("classes")) {
1057 | b.classes = Integer.parseInt(cl.getOptionValue("classes"));
1058 | }
1059 | } catch (Exception e) {
1060 | System.err.println("Parsing command-line arguments failed. Reason: " + e.getMessage());
1061 | new HelpFormatter().printHelp("word2vec", options);
1062 | System.exit(1);
1063 | }
1064 | Word2Vec word2vec = new Word2Vec(b);
1065 | word2vec.trainModel();
1066 | System.out.println("Training finished.");
1067 | System.exit(0);
1068 | }
1069 | }
1070 |
--------------------------------------------------------------------------------
/assets/arguments.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/Word2VecJava/221720ce62fc90f9c64fb16a4441468e4ad23271/assets/arguments.JPG
--------------------------------------------------------------------------------