├── LICENSE
├── README.md
├── Word2Vec
│   ├── .classpath
│   ├── .project
│   ├── .settings
│   │   └── org.eclipse.jdt.core.prefs
│   ├── bin
│   │   ├── Word2Vec$Builder.class
│   │   ├── Word2Vec$VocabWord.class
│   │   └── Word2Vec.class
│   ├── lib
│   │   ├── commons-cli-1.3.1.jar
│   │   ├── commons-codec-1.9.jar
│   │   ├── commons-logging-1.2.jar
│   │   ├── slf4j-api-1.7.21.jar
│   │   └── slf4j-simple-1.7.21.jar
│   └── src
│       └── Word2Vec.java
└── assets
    └── arguments.JPG

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 | 
 3 | Copyright (c) 2017 KimJunho
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # Word2Vec In Java
 2 | * https://code.google.com/archive/p/word2vec/source/default/source
 3 | * A Java port of the original word2vec C code
 4 | 
 5 | ## Usage
 6 | * Put "Input.txt" in the folder containing the source code
 7 | ```bash
 8 | The contents of Input.txt are as follows
 9 | There is one document per line
10 | All documents must be preprocessed
11 | ```
12 | 
13 | * Preprocessing: each document should be split into space-separated words, e.g. by morphological analysis
14 | * In Eclipse, you must give arguments (Run - Run Configurations...)
15 | 
16 | 
17 | ![Arguments](./assets/arguments.JPG)
18 | 
19 | 
20 | * a = input.txt, b = output.txt... that is, the names of the input and output files
21 | * However, the file names are already set in the code (lines 34 and 35), so these values are placeholders; see also the example below
22 | 
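23 | ## Calling the trainer from Java
24 | * You can also configure a run in code through the Builder class instead of program arguments. A minimal sketch (all methods shown exist in Word2Vec.java):
25 | ```java
26 | Word2Vec w2v = new Word2Vec(new Word2Vec.Builder()
27 |         .trainFile("Input.txt")    // one preprocessed document per line
28 |         .outputFile("Output.txt")  // the word vectors are written here
29 |         .minCount(0));             // keep every word (useful for tiny corpora)
30 | w2v.trainModel();
31 | ```
32 | 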
33 | ## Contents of "Input.txt" after preprocessing
34 | * Document 1 : KimJunho is interested in machine learning and deep learning
35 | * Document 2 : KimJunho is interested in recruiting professional researchers
36 | ```bash
37 | KimJunho interested machine learning deep learning
38 | KimJunho recruiting professional researchers
39 | ```
40 | 
41 | ## Main Variable Description
42 | * See the Builder class in Word2Vec.java (public static class Builder, line 877)
43 | ```java
44 | 1. cbow = false
45 |    Chooses whether to train the CBOW or the skip-gram model
46 |    false : use the skip-gram model
47 |    true  : use the CBOW model
48 | 
49 | 2. startingAlpha = 0.025F
50 |    This is the learning rate
51 |    The smaller the value, the more accurate the learning, but the slower the training
52 | 
53 | 3. window = 5
54 |    How many context words to look at around each word during training
55 |    The default value is 5, meaning up to 5 words on each side
56 | 
57 | 4. negative = 0
58 |    Chooses how the output layer is trained, which mainly affects training speed
59 |    The two methods are Hierarchical Softmax and Negative Sampling
60 |    If 0, Hierarchical Softmax is used
61 |    Otherwise, Negative Sampling is used; common values are 5~10
62 | 
63 | 5. minCount = 5
64 |    Words that occur fewer than minCount times in the corpus are discarded
65 |    If you want to learn every word, set minCount = 0
66 | 
67 | 6. layerOneSize = 200
68 |    The dimensionality of the word vectors
69 |    The default value is 200
70 |    The higher the dimension, the more precise the vectors, but the slower the training
71 | ```
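72 | 
73 | ## Example arguments
74 | * The same settings can also be passed as program arguments (parsed in main of Word2Vec.java); note that the file names given to -train and -output are currently overridden by the hardcoded Input.txt / Output.txt (lines 34, 35)
75 | ```bash
76 | -train Input.txt -output Output.txt -size 200 -window 5 -negative 5 -minCount 5
77 | ```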
78 | 

--------------------------------------------------------------------------------
/Word2Vec/.classpath:
--------------------------------------------------------------------------------
 1 | <?xml version="1.0" encoding="UTF-8"?>
 2 | <classpath>
 3 | 	<classpathentry kind="src" path="src"/>
 4 | 	<classpathentry kind="con" path="org.eclipse.jdt.launching.JRE_CONTAINER"/>
 5 | 	<classpathentry kind="lib" path="lib/commons-cli-1.3.1.jar"/>
 6 | 	<classpathentry kind="lib" path="lib/commons-codec-1.9.jar"/>
 7 | 	<classpathentry kind="lib" path="lib/commons-logging-1.2.jar"/>
 8 | 	<classpathentry kind="lib" path="lib/slf4j-api-1.7.21.jar"/>
 9 | 	<classpathentry kind="lib" path="lib/slf4j-simple-1.7.21.jar"/>
10 | 	<classpathentry kind="output" path="bin"/>
11 | </classpath>
12 | 

--------------------------------------------------------------------------------
/Word2Vec/.project:
--------------------------------------------------------------------------------
 1 | <?xml version="1.0" encoding="UTF-8"?>
 2 | <projectDescription>
 3 | 	<name>Word2Vec</name>
 4 | 	<comment></comment>
 5 | 	<projects>
 6 | 	</projects>
 7 | 	<buildSpec>
 8 | 		<buildCommand>
 9 | 			<name>org.eclipse.jdt.core.javabuilder</name>
10 | 			<arguments>
11 | 			</arguments>
12 | 		</buildCommand>
13 | 	</buildSpec>
14 | 	<natures>
15 | 		<nature>org.eclipse.jdt.core.javanature</nature>
16 | 	</natures>
17 | </projectDescription>
18 | 

--------------------------------------------------------------------------------
/Word2Vec/.settings/org.eclipse.jdt.core.prefs:
--------------------------------------------------------------------------------
 1 | eclipse.preferences.version=1
 2 | org.eclipse.jdt.core.compiler.codegen.inlineJsrBytecode=enabled
 3 | org.eclipse.jdt.core.compiler.codegen.targetPlatform=1.8
 4 | org.eclipse.jdt.core.compiler.codegen.unusedLocal=preserve
 5 | org.eclipse.jdt.core.compiler.compliance=1.8
 6 | org.eclipse.jdt.core.compiler.debug.lineNumber=generate
 7 | org.eclipse.jdt.core.compiler.debug.localVariable=generate
 8 | org.eclipse.jdt.core.compiler.debug.sourceFile=generate
 9 | org.eclipse.jdt.core.compiler.problem.assertIdentifier=error
10 | org.eclipse.jdt.core.compiler.problem.enumIdentifier=error
11 | org.eclipse.jdt.core.compiler.source=1.8

--------------------------------------------------------------------------------
/Word2Vec/bin/Word2Vec$Builder.class:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/Word2VecJava/221720ce62fc90f9c64fb16a4441468e4ad23271/Word2Vec/bin/Word2Vec$Builder.class

--------------------------------------------------------------------------------
/Word2Vec/bin/Word2Vec$VocabWord.class:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/Word2VecJava/221720ce62fc90f9c64fb16a4441468e4ad23271/Word2Vec/bin/Word2Vec$VocabWord.class

--------------------------------------------------------------------------------
/Word2Vec/bin/Word2Vec.class:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/Word2VecJava/221720ce62fc90f9c64fb16a4441468e4ad23271/Word2Vec/bin/Word2Vec.class

--------------------------------------------------------------------------------
/Word2Vec/lib/commons-cli-1.3.1.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/Word2VecJava/221720ce62fc90f9c64fb16a4441468e4ad23271/Word2Vec/lib/commons-cli-1.3.1.jar

--------------------------------------------------------------------------------
/Word2Vec/lib/commons-codec-1.9.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/Word2VecJava/221720ce62fc90f9c64fb16a4441468e4ad23271/Word2Vec/lib/commons-codec-1.9.jar

--------------------------------------------------------------------------------
/Word2Vec/lib/commons-logging-1.2.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/Word2VecJava/221720ce62fc90f9c64fb16a4441468e4ad23271/Word2Vec/lib/commons-logging-1.2.jar

--------------------------------------------------------------------------------
/Word2Vec/lib/slf4j-api-1.7.21.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/Word2VecJava/221720ce62fc90f9c64fb16a4441468e4ad23271/Word2Vec/lib/slf4j-api-1.7.21.jar

--------------------------------------------------------------------------------
/Word2Vec/lib/slf4j-simple-1.7.21.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/Word2VecJava/221720ce62fc90f9c64fb16a4441468e4ad23271/Word2Vec/lib/slf4j-simple-1.7.21.jar

--------------------------------------------------------------------------------
/Word2Vec/src/Word2Vec.java:
--------------------------------------------------------------------------------
   1 | 
   2 | 
   3 | import java.io.DataInput;
   4 | import java.io.DataInputStream;
   5 | import java.io.DataOutputStream;
   6 | import java.io.EOFException;
   7 | import java.io.File;
   8 | import java.io.FileInputStream;
   9 | import java.io.FileOutputStream;
  10 | import java.io.FileWriter;
  11 | import java.io.IOException;
  12 | import java.io.RandomAccessFile;
  13 | import java.text.DecimalFormat;
  14 | import java.text.NumberFormat;
  15 | import java.util.ArrayList;
  16 | import java.util.Arrays;
  17 | import java.util.List;
  18 | 
  19 | import org.apache.commons.cli.CommandLine;
  20 | import org.apache.commons.cli.CommandLineParser;
  21 | import org.apache.commons.cli.HelpFormatter;
  22 | import org.apache.commons.cli.Option;
  23 | import org.apache.commons.cli.OptionBuilder;
  24 | import org.apache.commons.cli.Options;
  25 | import org.apache.commons.cli.PosixParser;
  26 | import org.slf4j.Logger;
  27 | import org.slf4j.LoggerFactory;
  28 | 
  29 | 
  30 | 
  31 | @SuppressWarnings("deprecation")
  32 | public class Word2Vec {
  33 |     private static final Logger log = LoggerFactory.getLogger(Word2Vec.class);
  34 |     private static final String input = "Input.txt";
  35 |     private static final String output = "Output.txt";
  36 | 
  37 |     class VocabWord implements Comparable<VocabWord> {
  38 |         VocabWord(String word) {
  39 |             this.word = word;
  40 |         }
  41 |         int cn = 0;
  42 |         int codelen;
  43 |         int[] point = new int[MAX_CODE_LENGTH];
  44 |         long[] code = new long[MAX_CODE_LENGTH];
  45 |         String word;
  46 | 
  47 |         @Override
  48 |         public int compareTo(VocabWord that) {
  49 |             if(that==null) {
  50 |                 return 1;
  51 |             }
  52 | 
  53 |             return that.cn - this.cn;
  54 |         }
  55 |         @Override
  56 |         public String toString() {
  57 |             return this.cn + ": " + this.word;
  58 |         }
  59 |     }
  60 | 
  61 |     private static final int MAX_STRING = 100;
  62 |     private static final int EXP_TABLE_SIZE= 1000;
  63 |     private static final int MAX_EXP= 6;
  64 |     private static final int MAX_SENTENCE_LENGTH= 1000;
  65 |     private static final int MAX_CODE_LENGTH= 40;
  66 |     private static final int TABLE_SIZE = 100000000;
  67 | 
  68 |     // Maximum 30 * 0.7 = 21M words in the vocabulary
  69 |     private static final int VOCAB_HASH_SIZE = 30000000;
  70 | 
  71 |     private final int layerOneSize;
  72 |     private final File trainFile;
  73 |     private final File outputFile;
  74 |     private final File saveVocabFile;
  75 |     private final File readVocabFile;
  76 |     private final int window;
  77 |     private final int negative;
  78 |     private final int minCount;
  79 |     private final int numThreads;
  80 |     private final int classes;
  81 |     private final boolean binary;
  82 |     private final boolean cbow;
  83 |     private final boolean noHs;
  84 |     private final float startingAlpha;
  85 |     private final float sample;
  86 |     private final float[] expTable;
  87 | 
  88 |     private int minReduce = 1;
  89 |     private int vocabMaxSize = 1000;
  90 |     private VocabWord[] vocabWords = new VocabWord[vocabMaxSize];
  91 |     private int[] vocabHash = new int[VOCAB_HASH_SIZE];
  92 |     private Byte ungetc = null;
  93 | 
  94 |     private int vocabSize = 0;
  95 |     private long trainWords = 0;
  96 |     private long wordCountActual = 0;
  97 |     private int[] table;
  98 | 
  99 |     private float alpha;
 100 | 
 101 |     private float[] syn0;
 102 |     private float[] syn1;
 103 |     private float[] syn1neg;
 104 | 
 105 |     private long start;
 106 | 
 107 |     public Word2Vec(Builder b) {
 108 |         this.trainFile = b.trainFile;
 109 |         this.outputFile = b.outputFile;
 110 |         this.saveVocabFile = b.saveVocabFile;
 111 |         this.readVocabFile = b.readVocabFile;
 112 |         this.binary = b.binary;
 113 |         this.cbow = b.cbow;
 114 |         this.noHs = b.noHs;
 115 |         this.startingAlpha = b.startingAlpha;
 116 |         this.sample = b.sample;
 117 |         this.window = b.window;
 118 |         this.negative = b.negative;
 119 |         this.minCount = b.minCount;
 120 |         this.numThreads = b.numThreads;
 121 |         this.classes = b.classes;
 122 |         this.layerOneSize = b.layerOneSize;
 123 | 
 124 |         float[] tempExpTable = new float[EXP_TABLE_SIZE];
 125 |         for (int i = 0; i < tempExpTable.length; i++) {
 126 |             // Precompute the exp() table
 127 |             tempExpTable[i] = (float) Math.exp((i / (float) EXP_TABLE_SIZE * 2 - 1) * MAX_EXP);
 128 |             // Precompute f(x) = x / (x + 1)
 129 |             tempExpTable[i] = tempExpTable[i] / (tempExpTable[i] + 1);
 130 |         }
 131 |         expTable = tempExpTable;
 132 |     }
 133 |     private void readVocab() throws IOException {
 134 |         vocabSize = 0;
 135 |         try(DataInputStream is = new DataInputStream(new FileInputStream(readVocabFile))) {
 136 |             addWordToVocab("</s>"); String word;
 137 |             while((word = readWord(is)) != null) {
 138 |                 if(word.equals("</s>")) { continue; } // readWord() returns "</s>" for the newline after each count
 139 |                 int a = addWordToVocab(word);
 140 |                 vocabWords[a].cn = Integer.parseInt(readWord(is)); // saveVocab() writes the count as text
 141 |             }
 142 |             sortVocab();
 143 |             log.debug("Vocab size: {}", vocabSize);
 144 |             log.debug("Words in train file: {}", trainWords);
 145 |         } catch(IOException ioe) {
 146 |             throw ioe;
 147 |         }
 148 |     }
 149 | 
 150 |     private void learnVocabFromTrainFile() throws IOException {
 151 |         for (int a = 0; a < VOCAB_HASH_SIZE; a++) {
 152 |             vocabHash[a] = -1;
 153 |         }
 154 |         vocabSize = 0;
 155 |         addWordToVocab("</s>");
 156 |         try (DataInputStream is = new DataInputStream(new FileInputStream(trainFile))) {
 157 |             while (true) {
 158 |                 String word = readWord(is);
 159 |                 if (word==null) {
 160 |                     break;
 161 |                 }
 162 |                 trainWords++;
 163 |                 if(log.isTraceEnabled() && trainWords % 100000 == 0) {
 164 |                     log.trace("{}K training words processed.", (trainWords/1000));
 165 |                 }
 166 |                 int i = searchVocab(word);
 167 |                 if (i == -1) {
 168 |                     i = addWordToVocab(word);
 169 |                     vocabWords[i].cn = 1;
 170 |                 } else
 171 |                     vocabWords[i].cn++;
 172 |                 if (vocabSize > VOCAB_HASH_SIZE * 0.7) {
 173 |                     reduceVocab();
 174 |                 }
 175 |             }
 176 |         } catch (IOException ioe) {
 177 |             throw ioe;
 178 |         }
 179 |         sortVocab();
 180 |         log.debug("Vocab size: {}", vocabSize);
 181 |         log.debug("Words in train file: {}", trainWords);
 182 |     }
 183 | 
 184 |     private void saveVocab() throws IOException {
 185 |         saveVocabFile.delete();
 186 |         try (FileWriter fw = new FileWriter(saveVocabFile)) {
 187 |             //Don't output the </s> at element zero.
 188 |             for (int i = 1; i < vocabSize; i++) {
 189 |                 fw.write(vocabWords[i].word);
 190 |                 fw.write(" ");
 191 |                 fw.write("" + vocabWords[i].cn);
 192 |                 fw.write("\n");
 193 |             }
 194 |         }
 195 |     }
 196 |     private void initNet() {
 197 |         syn0 = new float[vocabSize * layerOneSize];
 198 |         if(!noHs) {
 199 |             syn1 = new float[vocabSize * layerOneSize];
 200 |             for(int b=0 ; b < layerOneSize ; b++) {
 201 |                 for(int a=0 ; a<vocabSize ; a++) {
 202 |                     syn1[a * layerOneSize + b] = 0;
 203 |                 }
 204 |             }
 205 |         }
 206 |         if(negative>0) {
 207 |             syn1neg = new float[vocabSize * layerOneSize];
 208 |             for(int b=0 ; b < layerOneSize ; b++) {
 209 |                 for(int a=0 ; a<vocabSize ; a++) {
 210 |                     syn1neg[a * layerOneSize + b] = 0;
 211 |                 }
 212 |             }
 213 |         }
 214 |         for(int b=0 ; b < layerOneSize ; b++) {
 215 |             for(int a=0 ; a<vocabSize ; a++) {
 216 |                 // Initialize the input vectors with small random values, as in the original C code
 217 |                 syn0[a * layerOneSize + b] = (float) ((Math.random() - 0.5) / layerOneSize);
 218 |             }
 219 |         }
 220 |         createBinaryTree();
 221 |     }
 222 | 
 223 |     // Create a binary Huffman tree using the word counts.
 224 |     // Frequent words will have short unique binary codes.
 225 |     private void createBinaryTree() {
 226 |         long[] count = new long[vocabSize * 2 + 1];
 227 |         long[] binary = new long[vocabSize * 2 + 1];
 228 |         int[] parentNode = new int[vocabSize * 2 + 1];
 229 |         for (int a = 0; a < vocabSize; a++) {
 230 |             count[a] = vocabWords[a].cn;
 231 |         }
 232 |         for (int a = vocabSize; a < vocabSize * 2; a++) {
 233 |             count[a] = (long) 1e15;
 234 |         }
 235 |         int pos1 = vocabSize - 1;
 236 |         int pos2 = vocabSize;
 237 |         int min1i;
 238 |         int min2i;
 239 | 
 240 |         // The following algorithm constructs the Huffman tree by adding one node at a time
 241 |         for (int a = 0; a < vocabSize - 1; a++) {
 242 |             // First, find two smallest nodes 'min1, min2'
 243 |             if (pos1 >= 0) {
 244 |                 if (count[pos1] < count[pos2]) {
 245 |                     min1i = pos1;
 246 |                     pos1--;
 247 |                 } else {
 248 |                     min1i = pos2;
 249 |                     pos2++;
 250 |                 }
 251 |             } else {
 252 |                 min1i = pos2;
 253 |                 pos2++;
 254 |             }
 255 |             if (pos1 >= 0) {
 256 |                 if (count[pos1] < count[pos2]) {
 257 |                     min2i = pos1;
 258 |                     pos1--;
 259 |                 } else {
 260 |                     min2i = pos2;
 261 |                     pos2++;
 262 |                 }
 263 |             } else {
 264 |                 min2i = pos2;
 265 |                 pos2++;
 266 |             }
 267 |             count[vocabSize + a] = count[min1i] + count[min2i];
 268 |             parentNode[min1i] = vocabSize + a;
 269 |             parentNode[min2i] = vocabSize + a;
 270 |             binary[min2i] = 1;
 271 |         }
 272 | 
 273 |         // Now assign binary code to each vocabulary word
 274 |         long[] code = new long[MAX_CODE_LENGTH];
 275 |         int[] point = new int[MAX_CODE_LENGTH];
 276 |         for (int a = 0; a < vocabSize; a++) {
 277 |             int b = a;
 278 |             int i = 0;
 279 |             while (true) {
 280 |                 code[i] = binary[b];
 281 |                 point[i] = b;
 282 |                 i++;
 283 |                 b = parentNode[b];
 284 |                 if (b == vocabSize * 2 - 2)
 285 |                     break;
 286 |             }
 287 |             vocabWords[a].codelen = i;
 288 |             vocabWords[a].point[0] = vocabSize - 2;
 289 |             for (b = 0; b < i; b++) {
 290 |                 vocabWords[a].code[i - b - 1] = code[b];
 291 |                 vocabWords[a].point[i - b] = point[b] - vocabSize;
 292 |             }
 293 |         }
 294 |     }
 295 |     private void initUnigramTable() {
 296 |         table = new int[TABLE_SIZE]; long trainWordsPow = 0; // the table must be allocated before it is filled
 297 |         float power = 0.75F;
 298 |         for (int a = 0; a < vocabSize; a++) {
 299 |             trainWordsPow += Math.pow(vocabWords[a].cn, power);
 300 |         }
 301 |         int i = 0;
 302 |         float d1 = (float) Math.pow(vocabWords[i].cn, power) / (float) trainWordsPow;
 303 |         for (int a = 0; a < TABLE_SIZE; a++) {
 304 |             table[a] = i;
 305 |             if (a / (float) TABLE_SIZE > d1) {
 306 |                 i++;
 307 |                 d1 += Math.pow(vocabWords[i].cn, power) / (float) trainWordsPow;
 308 |             }
 309 |             if (i >= vocabSize) {
 310 |                 i = vocabSize - 1;
 311 |             }
 312 |         }
 313 |     }
 314 | 
 315 |     //DataOutputStream#writeFloat writes the high byte first,
 316 |     //but let's write the low byte first to give ourselves a better chance of
 317 |     //compatibility with the original C code
 318 |     private void writeFloat(float f, DataOutputStream out) throws IOException {
 319 |         int v = Float.floatToIntBits(f);
 320 |         out.write((v >>> 0) & 0xFF);
 321 |         out.write((v >>> 8) & 0xFF);
 322 |         out.write((v >>> 16) & 0xFF);
 323 |         out.write((v >>> 24) & 0xFF);
 324 |     }
 325 | 
 326 |     public void trainModel() {
 327 |         if(trainFile==null && readVocabFile==null) {
 328 |             throw new IllegalStateException("You must supply either a trainFile or a readVocabFile.");
 329 |         }
 330 |         alpha = startingAlpha;
 331 |         if(readVocabFile!=null) {
 332 |             try {
 333 |                 log.info("Reading vocabulary from file {}.", readVocabFile);
 334 |                 readVocab();
 335 |             } catch(IOException ioe) {
 336 |                 log.error("There was a problem reading the vocabulary file.", ioe);
 337 |                 return;
 338 |             }
 339 |         } else {
 340 |             log.info("Starting training using file {}.", trainFile);
 341 |             try {
 342 |                 learnVocabFromTrainFile();
 343 |             } catch(IOException ioe) {
 344 |                 log.error("There was a problem reading the training file.", ioe);
 345 |                 return;
 346 |             }
 347 |         }
 348 |         if(saveVocabFile!=null) {
 349 |             try {
 350 |                 saveVocab();
 351 |             } catch(IOException ioe) {
 352 |                 log.error("There was a problem writing the vocabulary file.", ioe);
 353 |                 return;
 354 |             }
 355 |         }
 356 |         if(outputFile==null) {
 357 |             return;
 358 |         }
 359 |         initNet();
 360 |         if(negative>0) {
 361 |             initUnigramTable();
 362 |         }
 363 |         start = System.nanoTime();
 364 |         //TODO: threads
 365 |         try {
 366 |             trainModelThread(0);
 367 |         } catch(IOException ioe) {
 368 |             log.error("There was a problem reading the training file.", ioe);
 369 |             return;
 370 |         }
 371 |         outputFile.delete();
 372 |         NumberFormat vectorTextFormat = new DecimalFormat("#.######");
 373 |         try(DataOutputStream os = new DataOutputStream(new FileOutputStream(outputFile))) {
 374 |             if(classes==0) {
 375 |                 // Save the word vectors
 376 |                 os.writeBytes("" + vocabSize + " " + layerOneSize + "\n");
 377 |                 for (int a = 0; a < vocabSize; a++) {
 378 |                     os.writeBytes(vocabWords[a].word);
 379 |                     os.writeBytes(" ");
 380 |                     if (binary) {
 381 |                         for (int b = 0; b < layerOneSize; b++) {
 382 |                             writeFloat(syn0[a * layerOneSize + b], os);
 383 |                         }
 384 |                     } else {
 385 |                         for (int b = 0; b < layerOneSize; b++) {
 386 |                             int index = a * layerOneSize + b;
 387 |                             float value = syn0[index];
 388 |                             os.writeBytes(vectorTextFormat.format(value) + " ");
 389 |                         }
 390 |                     }
 391 |                     os.writeBytes("\n");
 392 | 
 393 |                 }
 394 |                 os.writeBytes("\n");
 395 |             } else {
 396 |                 // Run K-means on the word vectors
 397 |                 if((long) classes * layerOneSize > Integer.MAX_VALUE) {
 398 |                     throw new RuntimeException("Number of classes times the size of Layer One cannot be greater than " + Integer.MAX_VALUE + " (" + classes + " * " + layerOneSize + ")");
 399 |                 }
 400 |                 int[] cl = new int[vocabSize];
 401 |                 float[] cent = new float[classes * layerOneSize];
 402 |                 int[] centcn = new int[classes];
 403 |                 int numIterations = 10;
 404 | 
 405 |                 for (int a = 0; a < vocabSize; a++) {
 406 |                     cl[a] = a % classes;
 407 |                 }
 408 |                 for(int a = 0; a < numIterations; a++) {
 409 |                     for (int c = 0; c < classes * layerOneSize; c++) {
 410 |                         cent[c] = 0;
 411 |                     }
 412 |                     for (int c = 0; c < classes; c++) {
 413 |                         centcn[c] = 1;
 414 |                     }
 415 |                     for (int c = 0; c < vocabSize; c++) {
 416 |                         for (int d = 0; d < layerOneSize; d++) {
 417 |                             cent[layerOneSize * cl[c] + d] += syn0[c * layerOneSize + d];
 418 |                         }
 419 |                         centcn[cl[c]]++;
 420 |                     }
 421 |                     for (int c = 0; c < classes; c++) {
 422 |                         float closev = 0;
 423 |                         for (int d = 0; d < layerOneSize; d++) {
 424 |                             cent[layerOneSize * c + d] /= centcn[c];
 425 |                             closev += cent[layerOneSize * c + d] * cent[layerOneSize * c + d];
 426 |                         }
 427 |                         closev = (float) Math.sqrt(closev);
 428 |                         for (int d = 0; d < layerOneSize; d++) {
 429 |                             cent[layerOneSize * c + d] /= closev;
 430 |                         }
 431 |                     }
 432 |                     for (int c = 0; c < vocabSize; c++) {
 433 |                         float closev = -10;
 434 |                         int closeid = 0;
 435 |                         for (int d = 0; d < classes; d++) {
 436 |                             float x = 0;
 437 |                             for (int e = 0; e < layerOneSize; e++) {
 438 |                                 x += cent[layerOneSize * d + e] * syn0[c * layerOneSize + e];
 439 |                             }
 440 |                             if (x > closev) {
 441 |                                 closev = x;
 442 |                                 closeid = d;
 443 |                             }
 444 |                         }
 445 |                         cl[c] = closeid;
 446 |                     }
 447 |                 }
 448 |                 // Save the K-means classes
 449 |                 for(int a=0 ; a< vocabSize ; a++) {
 450 |                     os.writeBytes(vocabWords[a].word);
 451 |                     os.writeBytes(" ");
 452 |                     os.writeBytes(cl[a] + "\n"); // write the class id as text, matching the original C output
 453 |                 }
 454 |             }
 455 |         } catch(IOException ioe) {
 456 |             log.error("There was a problem writing the output file", ioe);
 457 |             return;
 458 |         }
 459 |     }
 460 |     private void trainModelThread(int id) throws IOException {
 461 |         try(RandomAccessFile raf = new RandomAccessFile(trainFile, "r")) {
 462 |             if(id>0) {
 463 |                 raf.seek(raf.length() / numThreads * id);
 464 |             }
 465 |             long wordCount = 0;
 466 |             long lastWordCount = 0;
 467 |             int word = 0;
 468 |             int target = 0;
 469 |             int label = 0;
 470 |             int sentenceLength = 0;
 471 |             int sentencePosition = 0;
 472 |             long nextRandom = id; // must be a long: the LCG below overflows an int, and a negative value would break the table lookups
 473 |             int[] sen = new int[MAX_SENTENCE_LENGTH + 1];
 474 |             float[] neu1 = new float[layerOneSize];
 475 |             float[] neu1e = new float[layerOneSize];
 476 | 
 477 |             NumberFormat alphaFormat = new DecimalFormat("0.000000");
 478 |             NumberFormat logPercentFormat = new DecimalFormat("#0.00%");
 479 |             NumberFormat wordsPerSecondFormat = new DecimalFormat("00.00k");
 480 |             long now = System.nanoTime();
 481 |             while(true) {
 482 |                 if (wordCount - lastWordCount > 10000) {
 483 |                     wordCountActual += wordCount - lastWordCount;
 484 |                     lastWordCount = wordCount;
 485 |                     if (log.isTraceEnabled()) {
 486 |                         now = System.nanoTime();
 487 |                         log.trace("Alpha: {}", alphaFormat.format(alpha));
 488 |                         log.trace("Progress: {} ", logPercentFormat.format((float) wordCountActual / (trainWords + 1)));
 489 |                         log.trace(
 490 |                             "Words/thread/sec: {}\n",
 491 |                             wordsPerSecondFormat.format((float) wordCountActual / (float) (now - start + 1)
 492 |                                 * 1000000));
 493 |                     }
 494 |                     alpha = startingAlpha * (1 - wordCountActual / (float) (trainWords + 1));
 495 |                     if (alpha < startingAlpha * 0.0001F) {
 496 |                         alpha = startingAlpha * 0.0001F;
 497 |                     }
 498 |                 }
 499 |                 if(sentenceLength==0) {
 500 |                     while(true) {
 501 |                         word = readWordIndex(raf);
 502 |                         if(word==-1) {
 503 |                             break;
 504 |                         }
 505 |                         wordCount++;
 506 |                         if(word==0) {
 507 |                             break;
 508 |                         }
 509 |                         // The subsampling randomly discards frequent words while keeping the ranking same
 510 |                         if (sample > 0) {
 511 |                             float ran = (float) (Math.sqrt(vocabWords[word].cn / (sample * trainWords)) + 1) * (sample * trainWords) / vocabWords[word].cn;
 512 |                             nextRandom = nextRandom * 25214903917L + 11;
 513 |                             if (ran < ((nextRandom & 0xFFFF) / (float) 65536)) {
 514 |                                 continue;
 515 |                             }
 516 |                         }
 517 |                         sen[sentenceLength] = word;
 518 |                         sentenceLength++;
 519 |                         if (sentenceLength >= MAX_SENTENCE_LENGTH) {
 520 |                             break;
 521 |                         }
 522 |                     }
 523 |                     sentencePosition = 0;
 524 |                 }
 525 |                 if(raf.getFilePointer()==raf.length()) {
 526 |                     break;
 527 |                 }
 528 |                 if(wordCount > trainWords / numThreads) {
 529 |                     break;
 530 |                 }
 531 |                 word = sen[sentencePosition];
 532 |                 for(int c=0 ; c<layerOneSize ; c++) { neu1[c] = 0; }
 533 |                 for(int c=0 ; c<layerOneSize ; c++) { neu1e[c] = 0; }
 534 |                 nextRandom = nextRandom * 25214903917L + 11;
 535 |                 // 'b' randomly shrinks the effective window, as in the original C code
 536 |                 int b = (int) ((nextRandom >>> 16) % window);
 537 |                 if(cbow) {
 538 |                     //train the cbow architecture
 539 |                     // in -> hidden
 540 |                     for(int a = b ; a < window * 2 + 1 - b ; a++) {
 541 |                         if(a != window) {
 542 |                             int c = sentencePosition - window + a;
 543 |                             if (c < 0) {
 544 |                                 continue;
 545 |                             }
 546 |                             if (c >= sentenceLength) {
 547 |                                 continue;
 548 |                             }
 549 |                             int lastWord = sen[c];
 550 |                             for (c = 0; c < layerOneSize; c++) {
 551 |                                 neu1[c] += syn0[c + lastWord * layerOneSize];
 552 |                             }
 553 |                         }
 554 |                     }
 555 |                     if (!noHs) {
 556 |                         for(int d=0 ; d< vocabWords[word].codelen ; d++) {
 557 |                             float f=0;
 558 |                             int l2 = vocabWords[word].point[d] * layerOneSize;
 559 |                             // Propagate hidden -> output
 560 |                             for (int c = 0; c < layerOneSize; c++) {
 561 |                                 f += neu1[c] * syn1[c + l2];
 562 |                             }
 563 |                             if (f <= -1 * MAX_EXP || f >= MAX_EXP) {
 564 |                                 continue;
 565 |                             }
 566 |                             f = expTable[(int)((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))];
 567 |                             // 'g' is the gradient multiplied by the learning rate
 568 |                             float g = (1 - vocabWords[word].code[d] - f) * alpha;
 569 |                             // Propagate errors output -> hidden
 570 |                             for (int c = 0; c < layerOneSize; c++) {
 571 |                                 neu1e[c] += g * syn1[c + l2];
 572 |                             }
 573 |                             // Learn weights hidden -> output
 574 |                             for (int c = 0; c < layerOneSize; c++) {
 575 |                                 syn1[c + l2] += g * neu1[c];
 576 |                             }
 577 |                         }
 578 |                     }
 579 |                     // NEGATIVE SAMPLING
 580 |                     if (negative > 0) {
 581 |                         for (int d = 0; d < negative + 1; d++) {
 582 |                             if (d == 0) {
 583 |                                 target = word;
 584 |                                 label = 1;
 585 |                             } else {
 586 |                                 nextRandom = nextRandom * 25214903917L + 11;
 587 |                                 target = table[(int) ((nextRandom >>> 16) % TABLE_SIZE)];
 588 |                                 if (target == 0) {
 589 |                                     target = (int) Math.floorMod(nextRandom, vocabSize - 1L) + 1;
 590 |                                 }
 591 |                                 if (target == word) {
 592 |                                     continue;
 593 |                                 }
 594 |                                 label = 0;
 595 |                             }
 596 |                             int l2 = target * layerOneSize;
 597 |                             float f = 0; // must be a float: an int would truncate the dot product
 598 |                             for (int c = 0; c < layerOneSize; c++) {
 599 |                                 f += neu1[c] * syn1neg[c + l2];
 600 |                             }
 601 |                             float g;
 602 |                             if (f > MAX_EXP) {
 603 |                                 g = (label - 1) * alpha;
 604 |                             } else if (f < -MAX_EXP) {
 605 |                                 g = (label - 0) * alpha;
 606 |                             } else {
 607 |                                 g = (label - expTable[(int)((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))]) * alpha;
 608 |                             }
 609 |                             for (int c = 0; c < layerOneSize; c++) {
 610 |                                 neu1e[c] += g * syn1neg[c + l2];
 611 |                             }
 612 |                             for (int c = 0; c < layerOneSize; c++) {
 613 |                                 syn1neg[c + l2] += g * neu1[c];
 614 |                             }
 615 |                         }
 616 |                     }
 617 |                     // hidden -> in
 618 |                     for (int a = b; a < window * 2 + 1 - b; a++) {
 619 |                         if (a != window) {
 620 |                             int c = sentencePosition - window + a;
 621 |                             if (c < 0 || c >= sentenceLength) {
 622 |                                 continue;
 623 |                             }
 624 |                             int lastWord = sen[c];
 625 |                             for (c = 0; c < layerOneSize; c++) {
 626 |                                 syn0[c + lastWord * layerOneSize] += neu1e[c];
 627 |                             }
 628 |                         }
 629 |                     }
 630 |                 } else { //train skip-gram
 631 |                     for (int a = b; a < window * 2 + 1 - b; a++) {
 632 |                         if (a != window) {
 633 |                             int lastWordIndex = sentencePosition - window + a;
 634 |                             if (lastWordIndex < 0 || lastWordIndex >= sentenceLength) {
 635 |                                 continue;
 636 |                             }
 637 |                             int lastWord = sen[lastWordIndex];
 638 |                             int l1 = lastWord * layerOneSize;
 639 |                             for (int c = 0; c < layerOneSize; c++) {
 640 |                                 neu1e[c] = 0;
 641 |                             }
 642 |                             // HIERARCHICAL SOFTMAX
 643 |                             if (!noHs) {
 644 |                                 for (int d = 0; d < vocabWords[word].codelen; d++) {
 645 |                                     float f = 0;
 646 |                                     int l2 = vocabWords[word].point[d] * layerOneSize;
 647 |                                     // Propagate hidden -> output
 648 |                                     for (int c = 0; c < layerOneSize; c++) {
 649 |                                         f += syn0[c + l1] * syn1[c + l2];
 650 |                                     }
 651 |                                     if (f <= -MAX_EXP || f >= MAX_EXP) {
 652 |                                         continue;
 653 |                                     }
 654 |                                     f = expTable[(int)((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))];
 655 |                                     // 'g' is the gradient multiplied by the learning rate
 656 |                                     float g = (1 - vocabWords[word].code[d] - f) * alpha;
 657 |                                     // Propagate errors output -> hidden
 658 |                                     for (int c = 0; c < layerOneSize; c++) {
 659 |                                         neu1e[c] += g * syn1[c + l2];
 660 |                                     }
 661 |                                     // Learn weights hidden -> output
 662 |                                     for (int c = 0; c < layerOneSize; c++) {
 663 |                                         syn1[c + l2] += g * syn0[c + l1];
 664 |                                     }
 665 |                                 }
 666 |                             }
 667 |                             // NEGATIVE SAMPLING
 668 |                             if (negative > 0) {
 669 |                                 for (int d = 0; d < negative + 1; d++) {
 670 |                                     if (d == 0) {
 671 |                                         target = word;
 672 |                                         label = 1;
 673 |                                     } else {
 674 |                                         nextRandom = nextRandom * 25214903917L + 11;
 675 |                                         target = table[(int) ((nextRandom >>> 16) % TABLE_SIZE)];
 676 |                                         if (target == 0) {
 677 |                                             target = (int) Math.floorMod(nextRandom, vocabSize - 1L) + 1;
 678 |                                         }
 679 |                                         if (target == word) {
 680 |                                             continue;
 681 |                                         }
 682 |                                         label = 0;
 683 |                                     }
 684 |                                     int l2 = target * layerOneSize;
 685 |                                     float f = 0; // must be a float: an int would truncate the dot product
 686 |                                     for (int c = 0; c < layerOneSize; c++) {
 687 |                                         f += syn0[c + l1] * syn1neg[c + l2];
 688 |                                     }
 689 |                                     float g;
 690 |                                     if (f > MAX_EXP) {
 691 |                                         g = (label - 1) * alpha;
 692 |                                     } else if (f < -MAX_EXP) {
 693 |                                         g = (label - 0) * alpha;
 694 |                                     } else {
 695 |                                         g = (label - expTable[(int)((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))]) * alpha;
 696 |                                     }
 697 |                                     for (int c = 0; c < layerOneSize; c++) {
 698 |                                         neu1e[c] += g * syn1neg[c + l2];
 699 |                                     }
 700 |                                     for (int c = 0; c < layerOneSize; c++) {
 701 |                                         syn1neg[c + l2] += g * syn0[c + l1];
 702 |                                     }
 703 |                                 }
 704 |                             }
 705 |                             // Learn weights input -> hidden
 706 |                             for (int c = 0; c < layerOneSize; c++) {
 707 |                                 syn0[c + l1] += neu1e[c];
 708 |                             }
 709 |                         }
 710 |                     }
 711 |                 }
 712 |                 sentencePosition++;
 713 |                 if (sentencePosition >= sentenceLength) {
 714 |                     sentenceLength = 0;
 715 |                     continue;
 716 |                 }
 717 |             }
 718 |         } catch(IOException ioe) {
 719 |             throw ioe;
 720 |         }
 721 |     }
 722 | 
 723 |     // Reduces the vocabulary by removing infrequent tokens
 724 |     private void reduceVocab() {
 725 |         int b=0;
 726 |         for (int a = 0; a < vocabSize; a++) {
 727 |             if (vocabWords[a].cn > minReduce) {
 728 |                 vocabWords[b].cn = vocabWords[a].cn;
 729 |                 vocabWords[b].word = vocabWords[a].word;
 730 |                 b++;
 731 |             }
 732 |         }
 733 |         vocabSize = b;
 734 |         for (int a = 0; a < VOCAB_HASH_SIZE; a++) {
 735 |             vocabHash[a] = -1;
 736 |         }
 737 |         for (int a = 0; a < vocabSize; a++) {
 738 |             // Hash will be re-computed, as it is no longer valid
 739 |             int hash = getWordHash(vocabWords[a].word);
 740 |             while (vocabHash[hash] != -1) {
 741 |                 hash = (hash + 1) % VOCAB_HASH_SIZE;
 742 |                 hash = Math.abs(hash);
 743 |             }
 744 |             vocabHash[hash] = a;
 745 |         }
 746 |         minReduce++;
 747 |     }
 748 |     // Sorts the vocabulary by frequency using word counts
 749 |     private void sortVocab() {
 750 |         // Sort the vocabulary and keep </s> at the first position
 751 |         Arrays.sort(vocabWords, 1, vocabSize); // the end index is exclusive, so this sorts elements 1..vocabSize-1
 752 |         for (int a = 0; a < vocabHash.length; a++) {
 753 |             vocabHash[a] = -1;
 754 |         }
 755 | 
 756 |         trainWords = 0;
 757 |         int originalVocabSize = vocabSize;
 758 |         List<VocabWord> wordList = new ArrayList<VocabWord>(originalVocabSize);
 759 |         int aa=0;
 760 |         for (int a = 0; a < originalVocabSize; a++) {
 761 |             VocabWord vw = vocabWords[a];
 762 |             // Words occurring less than min_count times will be discarded from the vocab
 763 |             if (vw.cn < minCount && vw.cn > 0) {
 764 |                 vocabSize--;
 765 |             } else {
 766 |                 // Hash will be re-computed, as after the sorting it is no longer valid
 767 |                 int hash = getWordHash(vw.word);
 768 |                 while (vocabHash[hash] != -1) {
 769 |                     hash = (hash + 1) % VOCAB_HASH_SIZE;
 770 |                     hash = Math.abs(hash);
 771 |                 }
 772 |                 vocabHash[hash] = aa;
 773 |                 trainWords += vw.cn;
 774 |                 wordList.add(vw);
 775 |                 aa++;
 776 |             }
 777 |         }
 778 |         vocabWords = wordList.toArray(new VocabWord[wordList.size()]);
 779 |     }
 780 | 
 781 |     private int addWordToVocab(String word) {
 782 |         int length = word.length() + 1;
 783 |         if(length > MAX_STRING) {
 784 |             length = MAX_STRING;
 785 |         }
 786 |         vocabWords[vocabSize] = new VocabWord(word);
 787 |         vocabSize++;
 788 | 
 789 |         // Reallocate memory if needed
 790 |         if (vocabSize + 2 >= vocabMaxSize) {
 791 |             vocabMaxSize += 1000;
 792 |             VocabWord[] vocabWords1 = new VocabWord[vocabMaxSize];
 793 |             System.arraycopy(vocabWords, 0, vocabWords1, 0, vocabWords.length);
 794 |             vocabWords = vocabWords1;
 795 |         }
 796 |         int hash = getWordHash(word);
 797 |         while (vocabHash[hash] != -1) {
 798 |             hash = (hash + 1) % VOCAB_HASH_SIZE;
 799 |             hash = Math.abs(hash);
 800 |         }
 801 |         vocabHash[hash] = vocabSize - 1;
 802 |         return vocabSize - 1;
 803 |     }
 804 |     private int getWordHash(String word) {
 805 |         int hash = 0;
 806 |         for (int a = 0; a < word.length(); a++) {
 807 |             hash = hash * 257 + word.charAt(a);
 808 |         }
 809 |         hash = hash % VOCAB_HASH_SIZE;
 810 |         return Math.abs(hash);
 811 |     }
 812 | 
 813 |     // Returns position of a word in the vocabulary; if the word is not found, returns -1
 814 |     private int searchVocab(String word) {
 815 |         int hash = getWordHash(word);
 816 |         while (true) {
 817 |             if (vocabHash[hash] == -1) {
 818 |                 return -1;
 819 |             }
 820 |             if(word.equals(vocabWords[vocabHash[hash]].word)) {
 821 |                 return vocabHash[hash];
 822 |             }
 823 |             hash = (hash + 1) % VOCAB_HASH_SIZE;
 824 |             hash = Math.abs(hash);
 825 |         }
 826 |     }
 827 | 
 828 |     private int readWordIndex(RandomAccessFile raf) throws IOException {
 829 |         String word = readWord(raf);
 830 |         if(raf.length()==raf.getFilePointer()) {
 831 |             return -1;
 832 |         }
 833 |         return searchVocab(word);
 834 |     }
 835 | 
 836 |     private String readWord(DataInput dataInput) throws IOException {
 837 |         StringBuilder sb = new StringBuilder();
 838 |         while(true) {
 839 |             byte ch;
 840 |             if(ungetc != null) {
 841 |                 ch = ungetc;
 842 |                 ungetc = null;
 843 |             } else {
 844 |                 try {
 845 |                     ch = dataInput.readByte();
 846 |                 } catch(EOFException eofe) {
 847 |                     break;
 848 |                 }
 849 |             }
 850 |             if(ch=='\r') {
 851 |                 continue;
 852 |             }
 853 |             if ((ch == ' ') || (ch == '\t') || (ch == '\n')) {
 854 |                 if (sb.length()>0 ) {
 855 |                     if(ch == '\n') {
 856 |                         ungetc = ch;
 857 |                     }
 858 |                     break;
 859 |                 }
 860 |                 if (ch == '\n') {
 861 |                     return "</s>";
 862 |                 } else {
 863 |                     continue;
 864 |                 }
 865 |             }
 866 |             sb.append((char) ch);
 867 | 
 868 |             // Truncate too long words
 869 |             if (sb.length() >= MAX_STRING - 1) {
 870 |                 sb.deleteCharAt(sb.length()-1);
 871 |             }
 872 |         }
 873 |         String word = sb.length()==0 ? null : sb.toString();
 874 |         return word;
 875 |     }
 876 |     // Variable descriptions (see also the README)
 877 |     public static class Builder {
 878 |         private File trainFile = null;
 879 |         private File outputFile = null;
 880 |         private File saveVocabFile = null;
 881 |         private File readVocabFile = null;
 882 |         private boolean binary = false;
 883 |         private boolean cbow = false; // if false, use the skip-gram model; if true, use the cbow model
 884 |         private boolean noHs = false;
 885 |         private float startingAlpha = 0.025F; // the starting learning rate
 886 |         private float sample = 0.0F;
 887 |         private int window = 5; // number of surrounding words to consider
 888 |         private int negative = 0; // if negative = 0, use Hierarchical Softmax;
 889 |         // otherwise Negative Sampling, common values are 5 ~ 10
 890 |         private int minCount = 5; // words occurring fewer than minCount times are discarded; minCount = 0 keeps every word
 891 |         private int numThreads = 1;
 892 |         private int classes = 0;
 893 |         private int layerOneSize = 200; // vector size
 894 |         public Builder trainFile(String trainFile) {
 895 |             this.trainFile = new File(trainFile);
 896 |             return this;
 897 |         }
 898 |         public Builder outputFile(String outputFile) {
 899 |             this.outputFile = new File(outputFile);
 900 |             return this;
 901 |         }
 902 |         public Builder saveVocabFile(String saveVocabFile) {
 903 |             this.saveVocabFile = new File(saveVocabFile);
 904 |             return this;
 905 |         }
 906 |         public Builder readVocabFile(String readVocabFile) {
 907 |             this.readVocabFile = new File(readVocabFile);
 908 |             return this;
 909 |         }
 910 |         public Builder binary() {
 911 |             this.binary = true;
 912 |             return this;
 913 |         }
 914 |         public Builder cbow() {
 915 |             this.cbow = true;
 916 |             return this;
 917 |         }
 918 |         public Builder noHs() {
 919 |             this.noHs = true;
 920 |             return this;
 921 |         }
 922 |         public Builder startingAlpha(float startingAlpha) {
 923 |             this.startingAlpha = startingAlpha;
 924 |             return this;
 925 |         }
 926 |         public Builder sample(float sample) {
 927 |             this.sample = sample;
 928 |             return this;
 929 |         }
 930 |         public Builder window(int window) {
 931 |             this.window = window;
 932 |             return this;
 933 |         }
 934 |         public Builder negative(int negative) {
 935 |             this.negative = negative;
 936 |             return this;
 937 |         }
 938 |         public Builder minCount(int minCount) {
 939 |             this.minCount = minCount;
 940 |             return this;
 941 |         }
 942 |         public Builder numThreads(int numThreads) {
 943 |             this.numThreads = numThreads;
 944 |             return this;
 945 |         }
 946 |         public Builder classes(int classes) {
 947 |             this.classes = classes;
 948 |             return this;
 949 |         }
 950 |         public Builder layerOneSize(int layerOneSize) {
 951 |             this.layerOneSize = layerOneSize;
 952 |             return this;
 953 |         }
 954 |     }
 955 | 
 956 |     @SuppressWarnings({ "static-access" })
 957 |     public static void main(String[] args) {
 958 |         Builder b = new Builder();
 959 |         Options options = new Options();
 960 |         options.addOption(OptionBuilder.hasArg().withArgName("file")
 961 |             .withDescription("Use text data from <file> to train the model").create("train"));
 962 |         options.addOption(OptionBuilder.hasArg().withArgName("file")
 963 |             .withDescription("Use <file> to save the resulting word vectors / word clusters").create("output"));
 964 |         options.addOption(OptionBuilder.hasArg().withArgName("int")
 965 |             .withDescription("Set size of word vectors; default is " + b.layerOneSize).create("size"));
 966 |         options.addOption(OptionBuilder.hasArg().withArgName("int")
 967 |             .withDescription("Set max skip length between words; default is " + b.window).create("window"));
 968 |         options.addOption(OptionBuilder
 969 |             .hasArg()
 970 |             .withArgName("float")
 971 |             .withDescription(
 972 |                 "Set threshold for occurrence of words (0=off). Those that appear with higher frequency in the training data will be randomly down-sampled; default is "
 973 |                     + b.sample + ", useful value is 1e-5").create("sample"));
 974 |         options.addOption(new Option("noHs", false, "Disable use of Hierarchical Softmax; " + (b.noHs ? "off" : "on")
 975 |             + " by default"));
 976 |         options.addOption(OptionBuilder
 977 |             .hasArg()
 978 |             .withArgName("int")
 979 |             .withDescription(
 980 |                 "Number of negative examples; default is " + b.negative
 981 |                     + ", common values are 5 - 10 (0 = not used)").create("negative"));
 982 |         options.addOption(OptionBuilder.hasArg().withArgName("int")
 983 |             .withDescription("Use <int> threads (default " + b.numThreads + ")").create("threads"));
 984 |         options.addOption(OptionBuilder.hasArg().withArgName("int")
 985 |             .withDescription("This will discard words that appear less than <int> times; default is " + b.minCount)
 986 |             .create("minCount"));
 987 |         options.addOption(OptionBuilder.hasArg().withArgName("float")
 988 |             .withDescription("Set the starting learning rate; default is " + b.startingAlpha).create("startingAlpha"));
 989 |         options.addOption(OptionBuilder
 990 |             .hasArg()
 991 |             .withArgName("int")
 992 |             .withDescription(
 993 |                 "Number of word classes to output, or 0 to output word vectors; default is " + b.classes)
 994 |             .create("classes"));
 995 |         options.addOption(new Option("binary", false, "Save the resulting vectors in binary mode; "
 996 |             + (b.binary ? "on" : "off") + " by default"));
 997 |         options.addOption(OptionBuilder.hasArg().withArgName("file")
 998 |             .withDescription("The vocabulary will be saved to <file>").create("saveVocab"));
 999 |         options.addOption(OptionBuilder.hasArg().withArgName("file")
1000 |             .withDescription("The vocabulary will be read from <file>, not constructed from the training data")
1001 |             .create("readVocab"));
"on" : "off") 1003 | + " by default (skip-gram model)")); 1004 | 1005 | CommandLineParser parser = new PosixParser(); 1006 | try { 1007 | CommandLine cl = parser.parse(options, args); 1008 | if (cl.getOptions().length == 0) { 1009 | new HelpFormatter().printHelp(Word2Vec.class.getSimpleName(), options); 1010 | System.exit(0); 1011 | } 1012 | if (cl.hasOption("size")) { 1013 | b.layerOneSize = Integer.parseInt(cl.getOptionValue("size")); 1014 | } 1015 | if (cl.hasOption("train")) { 1016 | //b.trainFile = new File(cl.getOptionValue("train")); 1017 | b.trainFile = new File(input); 1018 | } 1019 | if (cl.hasOption("saveVocab")) { 1020 | b.saveVocabFile = new File(cl.getOptionValue("saveVocab")); 1021 | } 1022 | if (cl.hasOption("readVocab")) { 1023 | b.readVocabFile = new File(cl.getOptionValue("readVocab")); 1024 | } 1025 | if (cl.hasOption("binary")) { 1026 | b.binary = true; 1027 | } 1028 | if (cl.hasOption("cbow")) { 1029 | b.cbow = true; 1030 | } 1031 | if (cl.hasOption("startingAlpha")) { 1032 | b.startingAlpha = Float.parseFloat(cl.getOptionValue("startingAlpha")); 1033 | } 1034 | if (cl.hasOption("output")) { 1035 | //b.outputFile = new File(cl.getOptionValue("output")); 1036 | b.outputFile = new File(output); 1037 | } 1038 | if (cl.hasOption("window")) { 1039 | b.window = Integer.parseInt(cl.getOptionValue("window")); 1040 | } 1041 | if (cl.hasOption("sample")) { 1042 | b.sample = Float.parseFloat(cl.getOptionValue("sample")); 1043 | } 1044 | if (cl.hasOption("noHs")) { 1045 | b.noHs = true; 1046 | } 1047 | if (cl.hasOption("negative")) { 1048 | b.negative = Integer.parseInt(cl.getOptionValue("negative")); 1049 | } 1050 | if (cl.hasOption("threads")) { 1051 | b.numThreads = Integer.parseInt(cl.getOptionValue("threads")); 1052 | } 1053 | if (cl.hasOption("minCount")) { 1054 | b.minCount = Integer.parseInt(cl.getOptionValue("minCount")); 1055 | } 1056 | if (cl.hasOption("classes")) { 1057 | b.classes = Integer.parseInt(cl.getOptionValue("classes")); 1058 | } 1059 | } catch (Exception e) { 1060 | System.err.println("Parsing command-line arguments failed. Reason: " + e.getMessage()); 1061 | new HelpFormatter().printHelp("word2vec", options); 1062 | System.exit(1); 1063 | } 1064 | Word2Vec word2vec = new Word2Vec(b); 1065 | word2vec.trainModel(); 1066 | System.out.println("train finish"); 1067 | System.exit(0); 1068 | } 1069 | } 1070 | -------------------------------------------------------------------------------- /assets/arguments.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/Word2VecJava/221720ce62fc90f9c64fb16a4441468e4ad23271/assets/arguments.JPG --------------------------------------------------------------------------------