├── .gitattributes
├── .gitignore
├── LICENSE
├── README.md
└── src
    └── com
        └── ansj
            └── vec
                ├── Learn.java
                ├── Word2VEC.java
                ├── domain
                │   ├── HiddenNeuron.java
                │   ├── Neuron.java
                │   ├── WordEntry.java
                │   └── WordNeuron.java
                └── util
                    ├── Haffman.java
                    └── MapCount.java

/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 | 
4 | # Custom for Visual Studio
5 | *.cs diff=csharp
6 | 
7 | # Standard to msysgit
8 | *.doc diff=astextplain
9 | *.DOC diff=astextplain
10 | *.docx diff=astextplain
11 | *.DOCX diff=astextplain
12 | *.dot diff=astextplain
13 | *.DOT diff=astextplain
14 | *.pdf diff=astextplain
15 | *.PDF diff=astextplain
16 | *.rtf diff=astextplain
17 | *.RTF diff=astextplain
18 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.class
2 | 
3 | # Package Files #
4 | *.jar
5 | *.war
6 | *.ear
7 | 
8 | # Eclipse project files
9 | .classpath
10 | .project
11 | .settings/
12 | 
13 | # Intellij project files
14 | *.iml
15 | .idea/
16 | 
17 | # Others
18 | target/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 | 
3 | Copyright (c) 2016 Yaopeng Liu, Hao Peng, Yangqiu Song, Jianxin Li
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # incremental-word2vec
2 | 
3 | Train word2vec with a hierarchical softmax output layer, incrementally. Inspired by Tomas Mikolov et al.'s article "Efficient Estimation of Word Representations in
4 | Vector Space", [link](https://arxiv.org/pdf/1301.3781v3.pdf).
5 | 
6 | The incremental word2vec code is provided for research purposes only and without any warranty.
7 | Any commercial use is prohibited.
8 | When using the incremental word2vec code in your research work, please cite the following paper:
9 | 
10 | H. Peng, J. Li, Y. Song, Y. Liu.
11 | Incrementally Learning the Hierarchical Softmax Function for Neural Language Models.
12 | AAAI, February 2017.
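
## Example usage

A minimal sketch of how the classes under `src/com/ansj/vec` fit together. The corpus and model file names below are placeholders; the corpus is expected to be plain text with space-separated tokens, one sentence per line (see `Learn.main` and `Word2VEC.main` for the actual entry points).

```java
import com.ansj.vec.Learn;
import com.ansj.vec.Word2VEC;
import java.io.File;

public class Demo {
    public static void main(String[] args) throws Exception {
        // 1) Train on the initial corpus and persist everything the
        //    incremental pass will need.
        Learn learn = new Learn();
        learn.learnFile(new File("corpus-initial.txt"));     // placeholder path
        learn.saveModel(new File("model.vec"));              // word vectors
        learn.saveModelPlus(new File("model.plus"));         // vectors + word frequencies
        learn.saveTreeNodes(new File("model.tree"));         // Huffman inner-node weights

        // 2) Later, fold in new text without retraining from scratch.
        Learn incremental = new Learn();
        incremental.learnFile_Incrementally(new File("corpus-initial.txt"),
                new File("corpus-added.txt"),                // placeholder path
                new File("model.tree"), new File("model.plus"));
        incremental.saveModel(new File("model-updated.vec"));

        // 3) Query nearest neighbours from a saved model.
        Word2VEC vec = new Word2VEC();
        vec.loadJavaModel("model-updated.vec");
        System.out.println(vec.distance("king"));
    }
}
```

`saveModelPlus` and `saveTreeNodes` persist the word frequencies and the Huffman inner-node weights that `learnFile_Incrementally` needs in order to restore the old tree before folding in the added corpus.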
13 | -------------------------------------------------------------------------------- /src/com/ansj/vec/Learn.java: -------------------------------------------------------------------------------- 1 | package com.ansj.vec; 2 | 3 | import com.ansj.vec.domain.HiddenNeuron; 4 | import com.ansj.vec.domain.Neuron; 5 | import com.ansj.vec.domain.WordNeuron; 6 | import com.ansj.vec.util.Haffman; 7 | import com.ansj.vec.util.MapCount; 8 | 9 | import java.io.*; 10 | import java.util.*; 11 | import java.util.Map.Entry; 12 | import java.util.concurrent.ArrayBlockingQueue; 13 | import java.util.concurrent.ThreadPoolExecutor; 14 | import java.util.concurrent.TimeUnit; 15 | 16 | public class Learn { 17 | 18 | private Map wordMap = new HashMap<>(); 19 | /** 20 | * Training feature number 21 | */ 22 | private int layerSize = 50; 23 | 24 | /** 25 | * Context window size 26 | */ 27 | private int window = 5; 28 | 29 | private double sample = 1e-4; 30 | private double alpha = 0.025; 31 | private double startingAlpha = alpha; 32 | 33 | private int EXP_TABLE_SIZE = 1000; 34 | 35 | private Boolean isCbow = false; 36 | 37 | private double[] expTable = new double[EXP_TABLE_SIZE]; 38 | 39 | private int trainWordsCount = 0; 40 | 41 | private int MAX_EXP = 6; 42 | 43 | private LinkedList taskList = new LinkedList<>(); 44 | 45 | private int MAX_SIZE = 5000; 46 | 47 | private int threadSize = 20; 48 | 49 | public Learn(Boolean isCbow, Integer layerSize, Integer window, Double alpha, Double sample) { 50 | createExpTable(); 51 | if (isCbow != null) { 52 | this.isCbow = isCbow; 53 | } 54 | if (layerSize != null) 55 | this.layerSize = layerSize; 56 | if (window != null) 57 | this.window = window; 58 | if (alpha != null) 59 | this.alpha = alpha; 60 | if (sample != null) 61 | this.sample = sample; 62 | } 63 | 64 | public Learn() { 65 | createExpTable(); 66 | } 67 | 68 | /** 69 | * trainModel Globally 70 | * @throws java.io.IOException 71 | */ 72 | private void trainModel(File file) throws IOException { 73 | ThreadPoolExecutor executor = new ThreadPoolExecutor(threadSize, threadSize, threadSize, TimeUnit.MILLISECONDS, 74 | new ArrayBlockingQueue(threadSize)); 75 | for(int i = 0;i < threadSize;++i) { 76 | MyTask myTask = new MyTask(); 77 | executor.execute(myTask); 78 | } 79 | 80 | try (BufferedReader br = new BufferedReader( 81 | new InputStreamReader(new FileInputStream(file)))) { 82 | String temp = null; 83 | long nextRandom = 5; 84 | int wordCount = 0; 85 | int lastWordCount = 0; 86 | int wordCountActual = 0; 87 | synchronized (taskList) { 88 | while ((temp = br.readLine()) != null) { 89 | if (wordCount - lastWordCount > 10000) { 90 | // System.out.println("alpha:" + alpha + "\tProgress: " 91 | // + (int) (wordCountActual / (double) (trainWordsCount + 1) * 100) 92 | // + "%"); 93 | wordCountActual += wordCount - lastWordCount; 94 | lastWordCount = wordCount; 95 | alpha = startingAlpha * (1 - wordCountActual / (double) (trainWordsCount + 1)); 96 | if (alpha < startingAlpha * 0.0001) { 97 | alpha = startingAlpha * 0.0001; 98 | } 99 | } 100 | String[] strs = temp.split(" "); 101 | wordCount += strs.length; 102 | List sentence = new ArrayList(); 103 | for (int i = 0; i < strs.length; i++) { 104 | Neuron entry = wordMap.get(strs[i]); 105 | if (entry == null) { 106 | continue; 107 | } 108 | // The subsampling randomly discards frequent words while keeping the ranking same 109 | if (sample > 0) { 110 | double ran = (Math.sqrt(entry.freq / (sample * trainWordsCount)) + 1) 111 | * (sample * trainWordsCount) / entry.freq; 112 | 
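// Note: 25214903917 (0x5DEECE66D) with increment 11 is the same linear
// congruential generator used by java.util.Random and the reference C word2vec;
// the low 16 bits below act as a uniform draw in [0, 1) for the subsampling test.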
nextRandom = nextRandom * 25214903917L + 11; 113 | if (ran < (nextRandom & 0xFFFF) / (double) 65536) { 114 | continue; 115 | } 116 | } 117 | sentence.add((WordNeuron) entry); 118 | } 119 | 120 | while (taskList.size() > MAX_SIZE) 121 | { 122 | try 123 | { 124 | taskList.wait(); 125 | } 126 | catch (InterruptedException e) 127 | { 128 | e.printStackTrace(); 129 | } 130 | } 131 | taskList.add(new Tri(nextRandom,sentence,(short)2)); 132 | taskList.notifyAll(); 133 | // for (int index = 0; index < sentence.size(); index++) { 134 | // nextRandom = nextRandom * 25214903917L + 11; 135 | // if (isCbow) { 136 | // cbowGram(index, sentence, (int) nextRandom % window); 137 | // } else { 138 | // skipGram(index, sentence, (int) nextRandom % window); 139 | // } 140 | // } 141 | } 142 | // System.out.println("Vocab size: " + wordMap.size()); 143 | // System.out.println("Words in train file: " + trainWordsCount); 144 | // System.out.println("sucess train over!"); 145 | } 146 | } 147 | executor.shutdown(); 148 | } 149 | 150 | private void trainModelBlindly(File file, File fileAdded) throws IOException { 151 | String temp = null; 152 | long nextRandom = 5; 153 | int wordCount = 0; 154 | int lastWordCount = 0; 155 | int wordCountActual = 0; 156 | try (BufferedReader br = new BufferedReader( 157 | new InputStreamReader(new FileInputStream(file)))) { 158 | while ((temp = br.readLine()) != null) { 159 | if (wordCount - lastWordCount > 10000) { 160 | wordCountActual += wordCount - lastWordCount; 161 | lastWordCount = wordCount; 162 | alpha = startingAlpha * (1 - wordCountActual / (double) (trainWordsCount + 1)); 163 | if (alpha < startingAlpha * 0.0001) { 164 | alpha = startingAlpha * 0.0001; 165 | } 166 | } 167 | String[] strs = temp.split(" "); 168 | wordCount += strs.length; 169 | List sentence = new ArrayList(); 170 | for (int i = 0; i < strs.length; i++) { 171 | Neuron entry = wordMap.get(strs[i]); 172 | if (entry == null) { 173 | continue; 174 | } 175 | // The subsampling randomly discards frequent words while keeping the ranking same 176 | if (sample > 0) { 177 | double ran = (Math.sqrt(entry.freq / (sample * trainWordsCount)) + 1) 178 | * (sample * trainWordsCount) / entry.freq; 179 | nextRandom = nextRandom * 25214903917L + 11; 180 | if (ran < (nextRandom & 0xFFFF) / (double) 65536) { 181 | continue; 182 | } 183 | } 184 | sentence.add((WordNeuron) entry); 185 | } 186 | 187 | for (int index = 0; index < sentence.size(); index++) { 188 | nextRandom = nextRandom * 25214903917L + 11; 189 | if (isCbow) { 190 | cbowGram(index, sentence, (int) nextRandom % window); 191 | } else { 192 | skipGram(index, sentence, (int) nextRandom % window); 193 | } 194 | } 195 | 196 | } 197 | } 198 | long start = System.currentTimeMillis(); 199 | int xx = 0; 200 | try (BufferedReader br = new BufferedReader( 201 | new InputStreamReader(new FileInputStream(fileAdded)))) { 202 | while ((temp = br.readLine()) != null) { 203 | if (wordCount - lastWordCount > 10000) { 204 | // System.out.println("alpha:" + alpha + "\tProgress: " 205 | // + (int) (wordCountActual / (double) (trainWordsCount + 1) * 100) 206 | // + "%"); 207 | wordCountActual += wordCount - lastWordCount; 208 | lastWordCount = wordCount; 209 | alpha = startingAlpha * (1 - wordCountActual / (double) (trainWordsCount + 1)); 210 | if (alpha < startingAlpha * 0.0001) { 211 | alpha = startingAlpha * 0.0001; 212 | } 213 | } 214 | String[] strs = temp.split(" "); 215 | wordCount += strs.length; 216 | List sentence = new ArrayList(); 217 | for (int i = 0; i < strs.length; i++) 
{ 218 | Neuron entry = wordMap.get(strs[i]); 219 | if (entry == null) { 220 | continue; 221 | } 222 | // The subsampling randomly discards frequent words while keeping the ranking same 223 | if (sample > 0) { 224 | double ran = (Math.sqrt(entry.freq / (sample * trainWordsCount)) + 1) 225 | * (sample * trainWordsCount) / entry.freq; 226 | nextRandom = nextRandom * 25214903917L + 11; 227 | if (ran < (nextRandom & 0xFFFF) / (double) 65536) { 228 | continue; 229 | } 230 | } 231 | sentence.add((WordNeuron) entry); 232 | } 233 | 234 | for (int index = 0; index < sentence.size(); index++) { 235 | nextRandom = nextRandom * 25214903917L + 11; 236 | if (isCbow) { 237 | cbowGram(index, sentence, (int) nextRandom % window); 238 | } else { 239 | skipGram(index, sentence, (int) nextRandom % window); 240 | } 241 | } 242 | xx++; 243 | } 244 | } 245 | System.out.println((System.currentTimeMillis() - start)); 246 | } 247 | 248 | private void trainModelThirdType(File file, File fileAdded) throws IOException { 249 | String temp = null; 250 | long nextRandom = 5; 251 | int wordCount = 0; 252 | int lastWordCount = 0; 253 | int wordCountActual = 0; 254 | 255 | ThreadPoolExecutor executor = new ThreadPoolExecutor(threadSize, threadSize, threadSize, TimeUnit.MILLISECONDS, 256 | new ArrayBlockingQueue(threadSize)); 257 | for(int i = 0;i < threadSize;++i) { 258 | MyTaskThirdType myTask = new MyTaskThirdType(); 259 | executor.execute(myTask); 260 | } 261 | 262 | try (BufferedReader br = new BufferedReader( 263 | new InputStreamReader(new FileInputStream(file)))) { 264 | synchronized (taskList) { 265 | while ((temp = br.readLine()) != null) { 266 | if (wordCount - lastWordCount > 10000) { 267 | // System.out.println("alpha:" + alpha + "\tProgress: " 268 | // + (int) (wordCountActual / (double) (trainWordsCount + 1) * 100) 269 | // + "%"); 270 | wordCountActual += wordCount - lastWordCount; 271 | lastWordCount = wordCount; 272 | alpha = startingAlpha * (1 - wordCountActual / (double) (trainWordsCount + 1)); 273 | if (alpha < startingAlpha * 0.0001) { 274 | alpha = startingAlpha * 0.0001; 275 | } 276 | } 277 | 278 | String[] strs = temp.split(" "); 279 | // wordCount += strs.length; 280 | List sentence = new ArrayList(); 281 | for (int i = 0; i < strs.length; i++) { 282 | Neuron entry = wordMap.get(strs[i]); 283 | if (entry == null) { 284 | continue; 285 | } 286 | // The subsampling randomly discards frequent words while keeping the ranking same 287 | if (sample > 0) { 288 | double ran = (Math.sqrt(entry.freq / (sample * trainWordsCount)) + 1) 289 | * (sample * trainWordsCount) / entry.freq; 290 | nextRandom = nextRandom * 25214903917L + 11; 291 | if (ran < (nextRandom & 0xFFFF) / (double) 65536) { 292 | continue; 293 | } 294 | } 295 | sentence.add((WordNeuron) entry); 296 | } 297 | wordCount += sentence.size(); 298 | if(sentence.isEmpty())continue; 299 | while (taskList.size() > MAX_SIZE) 300 | { 301 | try 302 | { 303 | taskList.wait(); 304 | } 305 | catch (InterruptedException e) 306 | { 307 | e.printStackTrace(); 308 | } 309 | } 310 | taskList.add(new Tri(nextRandom,sentence,(short)0)); 311 | taskList.notifyAll(); 312 | } 313 | } 314 | } 315 | 316 | int xx = 0; 317 | try (BufferedReader br = new BufferedReader( 318 | new InputStreamReader(new FileInputStream(fileAdded)))) { 319 | synchronized (taskList) { 320 | while ((temp = br.readLine()) != null) { 321 | if (wordCount - lastWordCount > 10000) { 322 | wordCountActual += wordCount - lastWordCount; 323 | lastWordCount = wordCount; 324 | alpha = startingAlpha * (1 - 
wordCountActual / (double) (trainWordsCount + 1));
325 | if (alpha < startingAlpha * 0.0001) {
326 | alpha = startingAlpha * 0.0001;
327 | }
328 | }
329 | String[] strs = temp.split(" ");
330 | wordCount += strs.length;
331 | List<WordNeuron> sentence = new ArrayList<>();
332 | for (int i = 0; i < strs.length; i++) {
333 | Neuron entry = wordMap.get(strs[i]);
334 | if (entry == null) {
335 | continue;
336 | }
337 | // Subsampling randomly discards frequent words while keeping the frequency ranking the same
338 | if (sample > 0) {
339 | double ran = (Math.sqrt(entry.freq / (sample * trainWordsCount)) + 1)
340 | * (sample * trainWordsCount) / entry.freq;
341 | nextRandom = nextRandom * 25214903917L + 11;
342 | if (ran < (nextRandom & 0xFFFF) / (double) 65536) {
343 | continue;
344 | }
345 | }
346 | sentence.add((WordNeuron) entry);
347 | }
348 | 
349 | while (taskList.size() > MAX_SIZE) {
350 | try {
351 | taskList.wait();
352 | } catch (InterruptedException e) {
353 | e.printStackTrace();
354 | }
355 | }
356 | taskList.add(new Tri(nextRandom, sentence, (short) 2));
357 | taskList.notifyAll();
358 | xx++;
359 | }
360 | }
361 | }
362 | executor.shutdown();
363 | }
364 | 
365 | /**
366 | * Skip-gram model training
367 | * @param sentence
368 | */
369 | private void skipGram(int index, List<WordNeuron> sentence, int b) {
370 | // TODO Auto-generated method stub
371 | WordNeuron word = sentence.get(index);
372 | 
373 | int a, c = 0;
374 | for (a = b; a < window * 2 + 1 - b; a++) {
375 | if (a == window) {
376 | continue;
377 | }
378 | c = index - window + a;
379 | if (c < 0 || c >= sentence.size()) {
380 | continue;
381 | }
382 | 
383 | double[] neu1e = new double[layerSize]; // error term
384 | // HIERARCHICAL SOFTMAX
385 | List<Neuron> neurons = word.neurons;
386 | WordNeuron we = sentence.get(c);
387 | for (int i = 0; i < neurons.size(); i++) {
388 | HiddenNeuron out = (HiddenNeuron) neurons.get(i);
389 | if (out.syn1 == null) {
390 | // System.out.print(we.name + " out ");
391 | out.syn1 = new double[layerSize];
392 | }
393 | double f = 0;
394 | // Propagate hidden -> output
395 | for (int j = 0; j < layerSize; j++) {
396 | f += we.syn0[j] * out.syn1[j];
397 | }
398 | if (f <= -MAX_EXP || f >= MAX_EXP) {
399 | continue;
400 | } else {
401 | f = (f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2);
402 | f = expTable[(int) f];
403 | }
404 | // 'g' is the gradient multiplied by the learning rate
405 | double g = (1 - word.codeArr[i] - f) * alpha;
406 | // Propagate errors output -> hidden
407 | for (c = 0; c < layerSize; c++) {
408 | neu1e[c] += g * out.syn1[c];
409 | // we.splitSyn0.get(i)[c] = g * out.syn1[c];
410 | }
411 | // Learn weights hidden -> output
412 | for (c = 0; c < layerSize; c++) {
413 | out.syn1[c] += g * we.syn0[c];
414 | }
415 | }
416 | 
417 | // Learn weights input -> hidden
418 | for (int j = 0; j < layerSize; j++) {
419 | we.syn0[j] += neu1e[j];
420 | }
421 | }
422 | 
423 | }
424 | 
425 | /**
426 | * Continuous bag-of-words model training
427 | * @param index
428 | * @param sentence
429 | * @param b
430 | */
431 | private void cbowGram(int index, List<WordNeuron> sentence, int b) {
432 | WordNeuron word = sentence.get(index);
433 | int a, c = 0;
434 | 
435 | List<Neuron> neurons = word.neurons;
436 | double[] neu1e = new double[layerSize]; // error term
437 | double[] neu1 = new double[layerSize]; // accumulated context vector
438 | WordNeuron last_word;
439 | 
440 | for (a = b; a < window * 2 + 1 - b; a++)
441 | if (a != window) {
442 | c = index - window + a;
443 | if (c < 0)
444 | continue;
445 | if (c >= sentence.size())
446 | continue;
447 | last_word =
sentence.get(c); 448 | if (last_word == null) 449 | continue; 450 | for (c = 0; c < layerSize; c++) 451 | neu1[c] += last_word.syn0[c]; 452 | } 453 | 454 | //HIERARCHICAL SOFTMAX 455 | for (int d = 0; d < neurons.size(); d++) { 456 | HiddenNeuron out = (HiddenNeuron) neurons.get(d); 457 | if(out.syn1 == null) { 458 | // System.out.print(we.name + " out "); 459 | out.syn1 = new double[layerSize]; 460 | } 461 | double f = 0; 462 | // Propagate hidden -> output 463 | for (c = 0; c < layerSize; c++) 464 | f += neu1[c] * out.syn1[c]; 465 | if (f <= -MAX_EXP) 466 | continue; 467 | else if (f >= MAX_EXP) 468 | continue; 469 | else 470 | f = expTable[(int) ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))]; 471 | // 'g' is the gradient multiplied by the learning rate 472 | // double g = (1 - word.codeArr[d] - f) * alpha; 473 | // double g = f*(1-f)*( word.codeArr[i] - f) * alpha; 474 | double g = f * (1 - f) * (word.codeArr[d] - f) * alpha; 475 | // 476 | for (c = 0; c < layerSize; c++) { 477 | neu1e[c] += g * out.syn1[c]; 478 | //we.splitSyn0.get(i)[c] = g * out.syn1[c]; 479 | } 480 | // Learn weights hidden -> output 481 | for (c = 0; c < layerSize; c++) { 482 | out.syn1[c] += g * neu1[c]; 483 | } 484 | } 485 | for (a = b; a < window * 2 + 1 - b; a++) { 486 | if (a != window) { 487 | c = index - window + a; 488 | if (c < 0) 489 | continue; 490 | if (c >= sentence.size()) 491 | continue; 492 | last_word = sentence.get(c); 493 | if (last_word == null) 494 | continue; 495 | for (c = 0; c < layerSize; c++) 496 | last_word.syn0[c] += neu1e[c]; 497 | } 498 | 499 | } 500 | } 501 | 502 | private void skipGram_Incrementally(int index, List sentence, int b) { 503 | // TODO Auto-generated method stubS 504 | WordNeuron word = sentence.get(index); 505 | if(!word.alter)return; 506 | int a, c = 0; 507 | for (a = b; a < window * 2 + 1 - b; a++) { 508 | if (a == window) { 509 | continue; 510 | } 511 | c = index - window + a; 512 | if (c < 0 || c >= sentence.size()) { 513 | continue; 514 | } 515 | 516 | //HIERARCHICAL SOFTMAX 517 | List neurons = word.neurons; 518 | List oldNeurons = word.oldNeurons; 519 | WordNeuron we = sentence.get(c); 520 | 521 | for (int i = word.lowestPublicNode + 1; ; i++) { 522 | if(i >= neurons.size()) 523 | break; 524 | HiddenNeuron out = (HiddenNeuron) neurons.get(i); 525 | // if(!out.alter)continue; 526 | double f = 0; 527 | // Propagate hidden -> output 528 | for (int j = 0; j < layerSize; j++) { 529 | f += we.syn0[j] * out.syn1[j]; 530 | } 531 | if (f <= -MAX_EXP || f >= MAX_EXP) { 532 | continue; 533 | } else { 534 | f = (f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2); 535 | f = expTable[(int) f]; 536 | } 537 | // 'g' is the gradient multiplied by the learning rate 538 | double g = (1 - word.codeArr[i] - f) * alpha; 539 | // Learn weights hidden -> output 540 | for (c = 0; c < layerSize; c++) { 541 | out.syn1[c] += g * we.syn0[c]; 542 | } 543 | // Undo Learn weights hidden -> output 544 | if(i < oldNeurons.size()){ 545 | HiddenNeuron oldOut = (HiddenNeuron) oldNeurons.get(i); 546 | for (c = 0; c < layerSize; c++) { 547 | oldOut.syn1[c] -= g * we.syn0[c]; 548 | } 549 | } 550 | } 551 | } 552 | 553 | } 554 | 555 | private void cbowGram_Incrementally(int index, List sentence, int b) { 556 | WordNeuron word = sentence.get(index); 557 | if(!word.alter) 558 | return; 559 | int a, c = 0; 560 | 561 | List neurons = word.neurons; 562 | List oldNeurons = word.oldNeurons; 563 | double[] neu1 = new double[layerSize];//Error term 564 | WordNeuron last_word; 565 | 566 | for (a = b; a < window * 2 + 
1 - b; a++) 567 | if (a != window) { 568 | c = index - window + a; 569 | if (c < 0) 570 | continue; 571 | if (c >= sentence.size()) 572 | continue; 573 | last_word = sentence.get(c); 574 | if (last_word == null) 575 | continue; 576 | for (c = 0; c < layerSize; c++) 577 | neu1[c] += last_word.syn0[c]; 578 | } 579 | 580 | //HIERARCHICAL SOFTMAX 581 | for (int d = word.lowestPublicNode+1; d < neurons.size(); d++) { 582 | HiddenNeuron out = (HiddenNeuron) neurons.get(d); 583 | // if (!out.alter) continue; 584 | double f = 0; 585 | // Propagate hidden -> output 586 | for (c = 0; c < layerSize; c++) 587 | f += neu1[c] * out.syn1[c]; 588 | if (f <= -MAX_EXP) 589 | continue; 590 | else if (f >= MAX_EXP) 591 | continue; 592 | else 593 | f = expTable[(int) ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))]; 594 | // 'g' is the gradient multiplied by the learning rate 595 | // double g = (1 - word.codeArr[d] - f) * alpha; 596 | // double g = f*(1-f)*( word.codeArr[i] - f) * alpha; 597 | double g = f * (1 - f) * (word.codeArr[d] - f) * alpha; 598 | 599 | // Learn weights hidden -> output 600 | for (c = 0; c < layerSize; c++) { 601 | out.syn1[c] += g * neu1[c]; 602 | } 603 | if(d < oldNeurons.size()){ 604 | HiddenNeuron oldOut = (HiddenNeuron) oldNeurons.get(d); 605 | for (c = 0; c < layerSize; c++) { 606 | oldOut.syn1[c] -= g * neu1[c]; 607 | } 608 | } 609 | } 610 | } 611 | 612 | /** 613 | * Count frequency 614 | * @param file 615 | * @throws java.io.IOException 616 | */ 617 | private void readVocab(File file) throws IOException { 618 | MapCount mc = new MapCount<>(); 619 | try (BufferedReader br = new BufferedReader( 620 | new InputStreamReader(new FileInputStream(file)))) { 621 | String temp = null; 622 | while ((temp = br.readLine()) != null) { 623 | String[] split = temp.split(" "); 624 | trainWordsCount += split.length; 625 | for (String string : split) { 626 | mc.add(string); 627 | } 628 | } 629 | } 630 | for (Entry element : mc.get().entrySet()) { 631 | wordMap.put(element.getKey(), new WordNeuron(element.getKey(), element.getValue(), 632 | layerSize)); 633 | } 634 | } 635 | 636 | /** 637 | * Adjust word frequency 638 | * @param file 639 | * @throws java.io.IOException 640 | */ 641 | private void addVocab(File file) throws IOException { 642 | MapCount mc = new MapCount<>();int xx = 0;int cntBytes = 0; 643 | try (BufferedReader br = new BufferedReader( 644 | new InputStreamReader(new FileInputStream(file)))) { 645 | String temp = null; 646 | while ((temp = br.readLine()) != null) { 647 | cntBytes += temp.getBytes().length; 648 | 649 | String[] split = temp.split(" "); 650 | trainWordsCount += split.length; 651 | for (String string : split) { 652 | mc.add(string); 653 | } 654 | xx++; 655 | } 656 | } 657 | for (Entry element : mc.get().entrySet()) { 658 | if(wordMap.containsKey(element.getKey())){ 659 | // System.out.print(element.getKey() + ":" + wordMap.get(element.getKey()).freq + "+" + element.getValue()); 660 | wordMap.get(element.getKey()).freq += element.getValue(); 661 | // System.out.println("="+wordMap.get(element.getKey()).freq); 662 | } 663 | else { 664 | // System.out.println(element.getKey()+":"+element.getValue()+"(new word)"); 665 | wordMap.put(element.getKey(), new WordNeuron(element.getKey(), element.getValue(), 666 | layerSize)); 667 | } 668 | } 669 | } 670 | 671 | /** 672 | * Count frequency & Preset word vector 673 | * @param modelFile 674 | * @throws java.io.IOException 675 | */ 676 | private void readVocabFromModelPlus(File modelFile) throws IOException { 677 | try 
(DataInputStream dis = new DataInputStream(new BufferedInputStream(new FileInputStream(modelFile)))) {
678 | int size = dis.readInt();
679 | this.layerSize = dis.readInt();
680 | this.trainWordsCount = dis.readInt();
681 | String key = null;
682 | int val = 0;
683 | for (int i = 0; i < size; ++i) {
684 | key = dis.readUTF();
685 | val = dis.readInt();
686 | wordMap.put(key, new WordNeuron(key, val, layerSize, dis));
687 | }
688 | }
689 | 
690 | }
691 | 
692 | /**
693 | * Count frequency & Reset word vector
694 | * @param modelFile
695 | * @throws java.io.IOException
696 | */
697 | private void readVocabFromModelAndReset(File modelFile) throws IOException {
698 | try (DataInputStream dis = new DataInputStream(new BufferedInputStream(new FileInputStream(modelFile)))) {
699 | int size = dis.readInt();
700 | this.layerSize = dis.readInt();
701 | this.trainWordsCount = dis.readInt();
702 | String key = null;
703 | int val = 0;
704 | for (int i = 0; i < size; ++i) {
705 | key = dis.readUTF();
706 | val = dis.readInt();
707 | wordMap.put(key, new WordNeuron(key, val, layerSize));
708 | for (int j = 0; j < layerSize; ++j)
709 | dis.readFloat(); // discard the stored vector: vectors are re-initialized randomly
710 | }
711 | }
712 | }
713 | /**
714 | * Precompute the sigmoid table:
715 | * f(x) = exp(x) / (exp(x) + 1)
716 | */
717 | private void createExpTable() {
718 | for (int i = 0; i < EXP_TABLE_SIZE; i++) {
719 | // exp(6 * ((i-500)/500)) => e^-6 ~ e^6
720 | expTable[i] = Math.exp(((i / (double) EXP_TABLE_SIZE * 2 - 1) * MAX_EXP));
721 | // e^x/(e^x + 1) = 1/(1+e^-x)
722 | expTable[i] = expTable[i] / (expTable[i] + 1);
723 | }
724 | }
725 | 
726 | /**
727 | * Learn by file
728 | * @param file
729 | * @throws java.io.IOException
730 | */
731 | public void learnFile(File file) throws IOException {
732 | readVocab(file);
733 | long start = System.currentTimeMillis();
734 | new Haffman(layerSize).make(wordMap.values());
735 | 
736 | // Find each neuron's path
737 | for (Neuron neuron : wordMap.values()) {
738 | ((WordNeuron) neuron).makeNeurons(layerSize);
739 | }
740 | System.out.println("wordMap.size= " + wordMap.size());
741 | trainModel(file);
742 | }
743 | 
744 | /**
745 | * Build the tree from file; learn from file & fileAdded
746 | * @param file
747 | * @throws java.io.IOException
748 | */
749 | public void learnFileBlindly(File file, File fileAdded) throws IOException {
750 | readVocab(file);
751 | long start = System.currentTimeMillis();
752 | new Haffman(layerSize).make(wordMap.values());
753 | 
754 | // Find each neuron's path
755 | for (Neuron neuron : wordMap.values()) {
756 | ((WordNeuron) neuron).makeNeurons(layerSize);
757 | }
758 | System.out.print((System.currentTimeMillis() - start) + " ");
759 | trainModelBlindly(file, fileAdded);
760 | }
761 | 
762 | /**
763 | * Learn from file & fileAdded
764 | * @param file, fileAdded, treeFile, modelFile
765 | * @throws java.io.IOException
766 | */
767 | public void learnFile_Incrementally(File file, File fileAdded, File treeFile, File modelFile) throws IOException {
768 | // Restore word vectors & binary tree
769 | readVocabFromModelPlus(modelFile);
770 | new Haffman(layerSize).make(wordMap.values(), treeFile);
771 | // Find each neuron's path
772 | for (Neuron neuron : wordMap.values()) {
773 | ((WordNeuron) neuron).makeNeurons2(layerSize);
774 | }
775 | 
776 | long cnt = 0;
777 | // Incrementally reconstruct the binary tree
778 | addVocab(fileAdded);
779 | Neuron root = new Haffman(layerSize).makeWithRoot(wordMap.values());
780 | // Compare each neuron's old and new paths
781 | for (Neuron neuron : wordMap.values()) {
782 | cnt += ((WordNeuron) neuron).inheritNeurons(layerSize, root);
783 | }
784 | 
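// cnt totals the inner nodes that inherited syn1 weights from the old tree
// (each node counted once via its alter flag); printed relative to vocabulary size.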
System.out.println(":"+cnt + "/" + wordMap.size()+"="+((double) cnt / wordMap.size())); 785 | long start = System.currentTimeMillis(); 786 | trainModelThirdType(file, fileAdded); 787 | System.out.println((System.currentTimeMillis() - start)); 788 | } 789 | 790 | public void learnFile_Incrementally_Count(File file,File fileAdded,File treeFile,File modelFile) throws IOException { 791 | //Restore Word Vectors & Binary Trees 792 | readVocabFromModelPlus(modelFile); 793 | new Haffman(layerSize).make(wordMap.values(),treeFile); 794 | //Find each neuron 795 | for (Neuron neuron : wordMap.values()) { 796 | ((WordNeuron)neuron).makeNeurons(layerSize) ; 797 | } 798 | 799 | int cnt = 0,cnt2 = 0; 800 | //Increment Constructs a binary tree 801 | addVocab(fileAdded); 802 | Neuron root = new Haffman(layerSize).makeWithRoot(wordMap.values()); 803 | //Compared each neuron 804 | for (Neuron neuron : wordMap.values()) { 805 | cnt2 += ((WordNeuron)neuron).inheritNeuronsCount(layerSize,root) ; 806 | // cnt += ((WordNeuron)neuron).inheritNeurons(layerSize,root) ; 807 | } 808 | cnt = new Haffman(layerSize).getInheritNeurons(wordMap.values()); 809 | cnt2 = new Haffman(layerSize).getInheritNeurons2(wordMap.values()); 810 | // int cnt2 = 0; 811 | // for (Neuron neuron : wordMap.values()) { 812 | // if(neuron.alter) 813 | // cnt2++; 814 | // } 815 | System.out.println("Longer nodes: "+cnt + "/" + wordMap.size() + "=" + ((double) cnt / wordMap.size())); 816 | System.out.println("Shorter nodes: "+cnt2 + "/" + wordMap.size() + "=" + ((double) cnt2 / wordMap.size())); 817 | // __out.write("Non-leaf nodes: " + cnt + "/" + wordMap.size() + "=" + ((double) cnt / wordMap.size()) + "\r\n"); 818 | // __out.write("Leaf nodes: "+cnt2 + "/" + wordMap.size() + "=" + ((double) cnt2 / wordMap.size()) +"\r\n"); 819 | // __out.write(((double) cnt / wordMap.size()) + " " + ((double) cnt2 / wordMap.size()) + " "); 820 | // System.out.print(((double) cnt / wordMap.size()) + " " + ((double) cnt2 / wordMap.size()) + " "); 821 | // System.exit(-1); 822 | } 823 | 824 | /** 825 | * Save model 826 | */ 827 | public void saveModel(File file) { 828 | // TODO Auto-generated method stub 829 | 830 | try (DataOutputStream dataOutputStream = new DataOutputStream(new BufferedOutputStream( 831 | new FileOutputStream(file)))) { 832 | dataOutputStream.writeInt(wordMap.size()); 833 | dataOutputStream.writeInt(layerSize); 834 | double[] syn0 = null; 835 | for (Entry element : wordMap.entrySet()) { 836 | dataOutputStream.writeUTF(element.getKey()); 837 | syn0 = ((WordNeuron) element.getValue()).syn0; 838 | for (double d : syn0) { 839 | dataOutputStream.writeFloat(((Double) d).floatValue()); 840 | } 841 | } 842 | } catch (IOException e) { 843 | // TODO Auto-generated catch block 844 | e.printStackTrace(); 845 | } 846 | } 847 | /** 848 | * Save model & word frequency 849 | */ 850 | public void saveModelPlus(File file) { 851 | // TODO Auto-generated method stub 852 | 853 | try (DataOutputStream dataOutputStream = new DataOutputStream(new BufferedOutputStream( 854 | new FileOutputStream(file)))) { 855 | dataOutputStream.writeInt(wordMap.size()); 856 | dataOutputStream.writeInt(layerSize); 857 | dataOutputStream.writeInt(trainWordsCount); 858 | double[] syn0 = null; 859 | for (Entry element : wordMap.entrySet()) { 860 | dataOutputStream.writeUTF(element.getKey()); 861 | dataOutputStream.writeInt(element.getValue().freq); 862 | syn0 = ((WordNeuron) element.getValue()).syn0; 863 | for (double d : syn0) { 864 | dataOutputStream.writeFloat(((Double) 
d).floatValue()); 865 | } 866 | } 867 | } catch (IOException e) { 868 | // TODO Auto-generated catch block 869 | e.printStackTrace(); 870 | } 871 | } 872 | /* 873 | Save the binary tree non-leaf nodes 874 | */ 875 | public void saveTreeNodes(File file) { 876 | // TODO Auto-generated method stub 877 | 878 | try (DataOutputStream dataOutputStream = new DataOutputStream(new BufferedOutputStream( 879 | new FileOutputStream(file)))) { 880 | // dataOutputStream.writeInt(wordMap.size()); 881 | // dataOutputStream.writeInt(layerSize); 882 | new Haffman(layerSize).get(wordMap.values(), dataOutputStream); 883 | } catch (IOException e) { 884 | // TODO Auto-generated catch block 885 | e.printStackTrace(); 886 | } 887 | } 888 | /* 889 | Save the binary tree leaf nodes , which is useless 890 | */ 891 | public void saveLeaveNodes(File file){ 892 | // TODO Auto-generated method stub 893 | 894 | try (DataOutputStream dataOutputStream = new DataOutputStream(new BufferedOutputStream( 895 | new FileOutputStream(file)))) { 896 | // dataOutputStream.writeInt(wordMap.size()); 897 | // dataOutputStream.writeInt(layerSize); 898 | int[] codeArr = null; 899 | for (Entry element : wordMap.entrySet()) { 900 | dataOutputStream.writeUTF(element.getKey()); 901 | codeArr = ((WordNeuron) element.getValue()).codeArr; 902 | dataOutputStream.writeInt(codeArr.length); 903 | for (int d : codeArr) { 904 | dataOutputStream.writeInt(d); 905 | } 906 | } 907 | } catch (IOException e) { 908 | // TODO Auto-generated catch block 909 | e.printStackTrace(); 910 | } 911 | } 912 | 913 | /* 914 | Save the bottom non-leaf nodes , which is useless 915 | */ 916 | public void saveTheta(File file) { 917 | // TODO Auto-generated method stub 918 | 919 | try (DataOutputStream dataOutputStream = new DataOutputStream(new BufferedOutputStream( 920 | new FileOutputStream(file)))) { 921 | dataOutputStream.writeInt(wordMap.size()); 922 | dataOutputStream.writeInt(layerSize); 923 | double[] syn1 = null; 924 | for (Entry element : wordMap.entrySet()) { 925 | dataOutputStream.writeUTF(element.getKey()); 926 | WordNeuron wordNeuron = (WordNeuron)element.getValue(); 927 | HiddenNeuron hiddenNeuron = (HiddenNeuron) wordNeuron.neurons.get(wordNeuron.neurons.size() - 1); 928 | syn1 = hiddenNeuron.syn1; 929 | if(syn1 == null)syn1 = new double[layerSize]; 930 | for (double d : syn1) { 931 | dataOutputStream.writeFloat(((Double) d).floatValue()); 932 | } 933 | } 934 | } catch (IOException e) { 935 | // TODO Auto-generated catch block 936 | e.printStackTrace(); 937 | } 938 | } 939 | 940 | public int getLayerSize() { 941 | return layerSize; 942 | } 943 | 944 | public void setLayerSize(int layerSize) { 945 | this.layerSize = layerSize; 946 | } 947 | 948 | public int getWindow() { 949 | return window; 950 | } 951 | 952 | public void setWindow(int window) { 953 | this.window = window; 954 | } 955 | 956 | public double getSample() { 957 | return sample; 958 | } 959 | 960 | public void setSample(double sample) { 961 | this.sample = sample; 962 | } 963 | 964 | public double getAlpha() { 965 | return alpha; 966 | } 967 | 968 | public void setAlpha(double alpha) { 969 | this.alpha = alpha; 970 | this.startingAlpha = alpha; 971 | } 972 | 973 | public Boolean getIsCbow() { 974 | return isCbow; 975 | } 976 | 977 | public void setIsCbow(Boolean isCbow) { 978 | this.isCbow = isCbow; 979 | } 980 | 981 | public static void main(String[] args) throws IOException { 982 | Learn learn = new Learn(); 983 | long start = System.currentTimeMillis() ; 984 | learn.learnFile(new 
File("InputFiles/xh.txt")); 985 | System.out.println("use time "+(System.currentTimeMillis()-start)); 986 | learn.saveModel(new File("InputFiles/javaVector")); 987 | 988 | } 989 | class MyTask implements Runnable { 990 | private long nextRandom,nextRandom2,nextRandom3,nextRandom4; 991 | private List sentence = null,sentence2 = null,sentence3 = null,sentence4 = null; 992 | private short step = 0,step2 = 0,step3 = 0,step4 = 0; 993 | 994 | public void setArguments(Tri tri,Tri tri2,Tri tri3,Tri tri4){ 995 | this.nextRandom = tri.nextRandom; 996 | this.sentence = tri.sentence; 997 | this.step = tri.step; 998 | this.nextRandom2 = tri2.nextRandom; 999 | this.sentence2 = tri2.sentence; 1000 | this.step2 = tri.step; 1001 | this.nextRandom3 = tri3.nextRandom; 1002 | this.sentence3 = tri3.sentence; 1003 | this.step3 = tri.step; 1004 | this.nextRandom4 = tri4.nextRandom; 1005 | this.sentence4 = tri4.sentence; 1006 | this.step4 = tri.step; 1007 | } 1008 | @Override 1009 | public void run() { 1010 | while (true) { 1011 | 1012 | synchronized (taskList) { 1013 | while (taskList.size() < 4) { 1014 | try { 1015 | taskList.wait(); 1016 | } catch (InterruptedException e) { 1017 | e.printStackTrace(); 1018 | } 1019 | } 1020 | 1021 | setArguments(taskList.poll(),taskList.poll(),taskList.poll(),taskList.poll()); 1022 | 1023 | taskList.notifyAll(); 1024 | } 1025 | 1026 | LearnByPiece( step, sentence, nextRandom); 1027 | LearnByPiece( step2, sentence2, nextRandom2); 1028 | LearnByPiece( step3, sentence3, nextRandom3); 1029 | LearnByPiece( step4, sentence4, nextRandom4); 1030 | sentence = null; 1031 | sentence2 = null; 1032 | sentence3 = null; 1033 | sentence4 = null; 1034 | } 1035 | } 1036 | } 1037 | private void LearnByPiece(short step, List sentence, long nextRandom){ 1038 | for (int index = 0; index < sentence.size(); index++) { 1039 | nextRandom = nextRandom * 25214903917L + 11; 1040 | if (isCbow) { 1041 | cbowGram(index, sentence, (int) nextRandom % window); 1042 | } else { 1043 | skipGram(index, sentence, (int) nextRandom % window); 1044 | } 1045 | } 1046 | } 1047 | class Tri { 1048 | private long nextRandom; 1049 | private List sentence = null; 1050 | private short step; 1051 | 1052 | public Tri(long nextRandom, List sentence, short step) { 1053 | this.nextRandom = nextRandom; 1054 | this.sentence = sentence; 1055 | this.step = step; 1056 | } 1057 | } 1058 | 1059 | class MyTaskThirdType implements Runnable { 1060 | private long nextRandom,nextRandom2,nextRandom3,nextRandom4; 1061 | private List sentence = null,sentence2 = null,sentence3 = null,sentence4 = null; 1062 | private short step = 0,step2 = 0,step3 = 0,step4 = 0; 1063 | 1064 | public void setArguments(Tri tri,Tri tri2,Tri tri3,Tri tri4){ 1065 | this.nextRandom = tri.nextRandom; 1066 | this.sentence = tri.sentence; 1067 | this.step = tri.step; 1068 | this.nextRandom2 = tri2.nextRandom; 1069 | this.sentence2 = tri2.sentence; 1070 | this.step2 = tri.step; 1071 | this.nextRandom3 = tri3.nextRandom; 1072 | this.sentence3 = tri3.sentence; 1073 | this.step3 = tri.step; 1074 | this.nextRandom4 = tri4.nextRandom; 1075 | this.sentence4 = tri4.sentence; 1076 | this.step4 = tri.step; 1077 | } 1078 | @Override 1079 | public void run() { 1080 | while (true) { 1081 | 1082 | synchronized (taskList) { 1083 | while (taskList.size() < 4) { 1084 | try { 1085 | taskList.wait(); 1086 | } catch (InterruptedException e) { 1087 | e.printStackTrace(); 1088 | } 1089 | } 1090 | 1091 | setArguments(taskList.poll(),taskList.poll(),taskList.poll(),taskList.poll()); 1092 | 1093 | 
taskList.notifyAll(); 1094 | } 1095 | 1096 | increLearnByPiece_Incrementally(step, sentence, nextRandom); 1097 | increLearnByPiece_Incrementally(step2, sentence2, nextRandom2); 1098 | increLearnByPiece_Incrementally(step3, sentence3, nextRandom3); 1099 | increLearnByPiece_Incrementally(step4, sentence4, nextRandom4); 1100 | sentence = null; 1101 | sentence2 = null; 1102 | sentence3 = null; 1103 | sentence4 = null; 1104 | } 1105 | } 1106 | } 1107 | private void increLearnByPiece_Incrementally(short step, List sentence, long nextRandom){ 1108 | if(step == 0) { 1109 | for (int index = 0; index < sentence.size(); index++) { 1110 | nextRandom = nextRandom * 25214903917L + 11; 1111 | if (isCbow) { 1112 | cbowGram_Incrementally(index, sentence, (int) nextRandom % window); 1113 | } else { 1114 | skipGram_Incrementally(index, sentence, (int) nextRandom % window); 1115 | } 1116 | } 1117 | }else{ 1118 | for (int index = 0; index < sentence.size(); index++) { 1119 | nextRandom = nextRandom * 25214903917L + 11; 1120 | if (isCbow) { 1121 | cbowGram(index, sentence, (int) nextRandom % window); 1122 | } else { 1123 | skipGram(index, sentence, (int) nextRandom % window); 1124 | } 1125 | } 1126 | } 1127 | } 1128 | } 1129 | -------------------------------------------------------------------------------- /src/com/ansj/vec/Word2VEC.java: -------------------------------------------------------------------------------- 1 | package com.ansj.vec; 2 | 3 | import com.ansj.vec.domain.WordEntry; 4 | 5 | import java.io.*; 6 | import java.util.*; 7 | import java.util.Map.Entry; 8 | 9 | public class Word2VEC { 10 | private static boolean getVec = false; 11 | private static boolean global = false; 12 | public static void main(String[] args) throws IOException { 13 | if(args.length >= 1){ 14 | global = Boolean.valueOf(args[0]); 15 | } 16 | long start = 0; 17 | Learn learn; 18 | String[] strs = {"man", "woman", "war", "murder", "fog", "disease", "bribe", "obama", "fire", "issue"}; 19 | try { 20 | if(global) { 21 | start = System.currentTimeMillis(); 22 | learn = new Learn(); 23 | learn.learnFile(new File("InputFiles/wiki.enLemmatize.4.text")); 24 | learn.saveModel(new File("InputFiles/javaSkip300Test7G")); 25 | learn.saveModelPlus(new File("InputFiles/javaSkip300Plus7G")); 26 | learn.saveTreeNodes(new File("InputFiles/javaSkip300Tree7G")); 27 | System.out.println( (System.currentTimeMillis() - start) ); 28 | 29 | testVec("InputFiles/javaSkip300Test7G",getVec,strs); 30 | }else { 31 | 32 | learn = new Learn(); 33 | learn.learnFile_Incrementally(new File("InputFiles/wiki.enLemmatize.3.text"), new File("InputFiles/xh-added-5G.txt"), 34 | new File("InputFiles/javaSkip300Tree2G"), new File("InputFiles/javaSkip300Plus2G")); 35 | learn.saveModel(new File("InputFiles/javaSkip300Test2G+5G")); 36 | learn.saveModelPlus(new File("InputFiles/javaSkip300Plus2G+5G")); 37 | learn.saveTreeNodes(new File("InputFiles/javaSkip300Tree2G+5G")); 38 | 39 | testVec("InputFiles/javaSkip300Test2G+5G", getVec, strs); 40 | } 41 | }catch (Exception e) { 42 | e.printStackTrace(); 43 | } 44 | } 45 | 46 | private HashMap wordMap = new HashMap(); 47 | 48 | private int words; 49 | private int size; 50 | private int topNSize = 30; 51 | 52 | /** 53 | * @param path 54 | * @throws java.io.IOException 55 | */ 56 | public static void testVec(String path,boolean getVec,String[] strs) throws IOException { 57 | Word2VEC vec = new Word2VEC(); 58 | vec.loadJavaModel(path); 59 | for (String str : strs) { 60 | if (getVec) { 61 | float[] tmp = vec.getWordVector(str); 62 | 
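// getVec mode: print the stored (normalized) vector components verbatim;
// the else branch below prints nearest neighbours via distance() instead.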
System.out.print(str + ":[");
63 | for (float each : tmp) {
64 | System.out.print(each + ",");
65 | }
66 | System.out.println("]");
67 | }
68 | else
69 | System.out.println(str + ":" + vec.distance(str));
70 | }
71 | //System.out.println("distance of \"男人\" & \"女人\" is " + vec.distanceOfWord("男人", "女人"));
72 | System.exit(-1);
73 | }
74 | 
75 | /**
76 | * @param path
77 | * @throws java.io.IOException
78 | */
79 | public void loadGoogleModel(String path) throws IOException {
80 | DataInputStream dis = null;
81 | BufferedInputStream bis = null;
82 | double len = 0;
83 | float vector = 0;
84 | try {
85 | bis = new BufferedInputStream(new FileInputStream(path));
86 | dis = new DataInputStream(bis);
87 | // read the vocabulary size
88 | words = Integer.parseInt(readString(dis));
89 | // read the vector dimensionality
90 | size = Integer.parseInt(readString(dis));
91 | String word;
92 | float[] vectors = null;
93 | for (int i = 0; i < words; i++) {
94 | word = readString(dis);
95 | vectors = new float[size];
96 | len = 0;
97 | for (int j = 0; j < size; j++) {
98 | vector = readFloat(dis);
99 | len += vector * vector;
100 | vectors[j] = (float) vector;
101 | }
102 | len = Math.sqrt(len);
103 | 
104 | for (int j = 0; j < size; j++) {
105 | vectors[j] /= len;
106 | }
107 | 
108 | wordMap.put(word, vectors);
109 | dis.read(); // consume the separator byte between entries
110 | }
111 | } finally {
112 | bis.close();
113 | dis.close();
114 | }
115 | }
116 | 
117 | /**
118 | * @param path
119 | * @throws java.io.IOException
120 | */
121 | public void loadJavaModel(String path) throws IOException {
122 | try (DataInputStream dis = new DataInputStream(new BufferedInputStream(new FileInputStream(path)))) {
123 | words = dis.readInt();
124 | size = dis.readInt();
125 | 
126 | float vector = 0;
127 | 
128 | String key = null;
129 | float[] value = null;
130 | for (int i = 0; i < words; i++) {
131 | double len = 0;
132 | key = dis.readUTF();
133 | value = new float[size];
134 | for (int j = 0; j < size; j++) {
135 | vector = dis.readFloat();
136 | len += vector * vector;
137 | value[j] = vector;
138 | }
139 | 
140 | len = Math.sqrt(len);
141 | // Normalize to unit length
142 | for (int j = 0; j < size; j++) {
143 | value[j] /= len;
144 | }
145 | wordMap.put(key, value);
146 | }
147 | 
148 | }
149 | }
150 | 
151 | private static final int MAX_SIZE = 50;
152 | 
153 | /**
154 | * Solve the analogy word0 : word1 :: word2 : ?
155 | *
156 | * @return the top-N candidate words
157 | */
158 | public TreeSet<WordEntry> analogy(String word0, String word1, String word2) {
159 | float[] wv0 = getWordVector(word0);
160 | float[] wv1 = getWordVector(word1);
161 | float[] wv2 = getWordVector(word2);
162 | 
163 | if (wv1 == null || wv2 == null || wv0 == null) {
164 | return null;
165 | }
166 | float[] wordVector = new float[size];
167 | for (int i = 0; i < size; i++) {
168 | wordVector[i] = wv1[i] - wv0[i] + wv2[i];
169 | }
170 | float[] tempVector;
171 | String name;
172 | List<WordEntry> wordEntrys = new ArrayList<>(topNSize);
173 | for (Entry<String, float[]> entry : wordMap.entrySet()) {
174 | name = entry.getKey();
175 | if (name.equals(word0) || name.equals(word1) || name.equals(word2)) {
176 | continue;
177 | }
178 | float dist = 0;
179 | tempVector = entry.getValue();
180 | for (int i = 0; i < wordVector.length; i++) {
181 | dist += wordVector[i] * tempVector[i];
182 | }
183 | insertTopN(name, dist, wordEntrys);
184 | }
185 | return new TreeSet<>(wordEntrys);
186 | }
187 | 
188 | private void insertTopN(String name, float score, List<WordEntry> wordsEntrys) {
189 | // TODO Auto-generated method stub
190 | if (wordsEntrys.size() < topNSize) {
191 | wordsEntrys.add(new WordEntry(name, score));
192 | return;
193 | }
194 | 
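// List is full: locate the lowest-scoring entry and evict it if the new
// candidate scores higher.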
float min = Float.MAX_VALUE;
195 | int minOffe = 0;
196 | for (int i = 0; i < topNSize; i++) {
197 | WordEntry wordEntry = wordsEntrys.get(i);
198 | if (min > wordEntry.score) {
199 | min = wordEntry.score;
200 | minOffe = i;
201 | }
202 | }
203 | 
204 | if (score > min) {
205 | wordsEntrys.set(minOffe, new WordEntry(name, score));
206 | }
207 | 
208 | }
209 | 
210 | public float distanceOfWord(String wordA, String wordB) {
211 | float dist = 0;
212 | if (wordA.compareTo(wordB) == 0)
213 | return 1;
214 | float[] vectorA = getWordVector(wordA);
215 | float[] vectorB = getWordVector(wordB);
216 | if (vectorA == null || vectorB == null)
217 | return 0;
218 | for (int i = 0; i < vectorA.length; i++) {
219 | dist += vectorB[i] * vectorA[i];
220 | }
221 | return dist;
222 | }
223 | 
224 | public int getMapsize() {
225 | return wordMap.size();
226 | }
227 | public Set<WordEntry> distance(String queryWord) {
228 | 
229 | float[] center = wordMap.get(queryWord);
230 | if (center == null) {
231 | return Collections.emptySet();
232 | }
233 | 
234 | int resultSize = wordMap.size() < topNSize ? wordMap.size() : topNSize;
235 | TreeSet<WordEntry> result = new TreeSet<>();
236 | 
237 | double min = -Float.MAX_VALUE; // lower bound below any similarity (dot products can be negative)
238 | for (Entry<String, float[]> entry : wordMap.entrySet()) {
239 | float[] vector = entry.getValue();
240 | float dist = 0;
241 | for (int i = 0; i < vector.length; i++) {
242 | dist += center[i] * vector[i];
243 | }
244 | 
245 | if (dist > min) {
246 | result.add(new WordEntry(entry.getKey(), dist));
247 | if (resultSize < result.size()) {
248 | result.pollLast();
249 | }
250 | min = result.last().score;
251 | }
252 | }
253 | result.pollFirst(); // drop the query word itself (always the top match)
254 | 
255 | return result;
256 | }
257 | 
258 | public Set<WordEntry> distance(List<String> words) {
259 | 
260 | float[] center = null;
261 | for (String word : words) {
262 | center = sum(center, wordMap.get(word));
263 | }
264 | 
265 | if (center == null) {
266 | return Collections.emptySet();
267 | }
268 | 
269 | int resultSize = wordMap.size() < topNSize ?
wordMap.size() : topNSize; 270 | TreeSet result = new TreeSet(); 271 | 272 | double min = Float.MIN_VALUE; 273 | for (Entry entry : wordMap.entrySet()) { 274 | float[] vector = entry.getValue(); 275 | float dist = 0; 276 | for (int i = 0; i < vector.length; i++) { 277 | dist += center[i] * vector[i]; 278 | } 279 | 280 | if (dist > min) { 281 | result.add(new WordEntry(entry.getKey(), dist)); 282 | if (resultSize < result.size()) { 283 | result.pollLast(); 284 | } 285 | min = result.last().score; 286 | } 287 | } 288 | result.pollFirst(); 289 | 290 | return result; 291 | } 292 | 293 | private float[] sum(float[] center, float[] fs) { 294 | // TODO Auto-generated method stub 295 | 296 | if (center == null && fs == null) { 297 | return null; 298 | } 299 | 300 | if (fs == null) { 301 | return center; 302 | } 303 | 304 | if (center == null) { 305 | return fs; 306 | } 307 | 308 | for (int i = 0; i < fs.length; i++) { 309 | center[i] += fs[i]; 310 | } 311 | 312 | return center; 313 | } 314 | 315 | /** 316 | * get word vector 317 | * 318 | * @param word 319 | * @return 320 | */ 321 | public float[] getWordVector(String word) { 322 | return wordMap.get(word); 323 | } 324 | 325 | public static float readFloat(InputStream is) throws IOException { 326 | byte[] bytes = new byte[4]; 327 | is.read(bytes); 328 | return getFloat(bytes); 329 | } 330 | 331 | /** 332 | * read a float 333 | * 334 | * @param b 335 | * @return 336 | */ 337 | public static float getFloat(byte[] b) { 338 | int accum = 0; 339 | accum = accum | (b[0] & 0xff) << 0; 340 | accum = accum | (b[1] & 0xff) << 8; 341 | accum = accum | (b[2] & 0xff) << 16; 342 | accum = accum | (b[3] & 0xff) << 24; 343 | return Float.intBitsToFloat(accum); 344 | } 345 | 346 | /** 347 | * read a string 348 | * 349 | * @param dis 350 | * @return 351 | * @throws java.io.IOException 352 | */ 353 | private static String readString(DataInputStream dis) throws IOException { 354 | // TODO Auto-generated method stub 355 | byte[] bytes = new byte[MAX_SIZE]; 356 | byte b = dis.readByte(); 357 | int i = -1; 358 | StringBuilder sb = new StringBuilder(); 359 | while (b != 32 && b != 10) { 360 | i++; 361 | bytes[i] = b; 362 | b = dis.readByte(); 363 | if (i == 49) { 364 | sb.append(new String(bytes)); 365 | i = -1; 366 | bytes = new byte[MAX_SIZE]; 367 | } 368 | } 369 | sb.append(new String(bytes, 0, i + 1)); 370 | return sb.toString(); 371 | } 372 | 373 | public int getTopNSize() { 374 | return topNSize; 375 | } 376 | 377 | public void setTopNSize(int topNSize) { 378 | this.topNSize = topNSize; 379 | } 380 | 381 | public HashMap getWordMap() { 382 | return wordMap; 383 | } 384 | 385 | public int getWords() { 386 | return words; 387 | } 388 | 389 | public int getSize() { 390 | return size; 391 | } 392 | 393 | } 394 | -------------------------------------------------------------------------------- /src/com/ansj/vec/domain/HiddenNeuron.java: -------------------------------------------------------------------------------- 1 | package com.ansj.vec.domain; 2 | 3 | public class HiddenNeuron extends Neuron{ 4 | 5 | public double[] syn1 ; //hidden->out 6 | public Neuron leftNode = null,rightNode = null;//left --> 0 right --> 1 7 | 8 | public HiddenNeuron(int layerSize){ 9 | syn1 = new double[layerSize] ; 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /src/com/ansj/vec/domain/Neuron.java: -------------------------------------------------------------------------------- 1 | package com.ansj.vec.domain; 2 | 3 | public abstract class 
Neuron implements Comparable { 4 | public int freq;public int freq2; 5 | public Neuron parent; 6 | public int code; 7 | public boolean alter,cnt2Alter,cntAlter,cengAlter; 8 | 9 | @Override 10 | public int compareTo(Neuron o) { 11 | // TODO Auto-generated method stub 12 | if (this.freq > o.freq) { 13 | return 1; 14 | } else { 15 | return -1; 16 | } 17 | } 18 | 19 | } 20 | -------------------------------------------------------------------------------- /src/com/ansj/vec/domain/WordEntry.java: -------------------------------------------------------------------------------- 1 | package com.ansj.vec.domain; 2 | 3 | 4 | public class WordEntry implements Comparable { 5 | public String name; 6 | public float score; 7 | 8 | public WordEntry(String name, float score) { 9 | this.name = name; 10 | this.score = score; 11 | } 12 | 13 | @Override 14 | public String toString() { 15 | // TODO Auto-generated method stub 16 | return this.name + "\t" + score; 17 | } 18 | 19 | @Override 20 | public int compareTo(WordEntry o) { 21 | // TODO Auto-generated method stub 22 | if (this.score < o.score) { 23 | return 1; 24 | } else { 25 | return -1; 26 | } 27 | } 28 | 29 | } -------------------------------------------------------------------------------- /src/com/ansj/vec/domain/WordNeuron.java: -------------------------------------------------------------------------------- 1 | package com.ansj.vec.domain; 2 | 3 | import java.io.DataInputStream; 4 | import java.io.IOException; 5 | import java.util.Collections; 6 | import java.util.LinkedList; 7 | import java.util.List; 8 | import java.util.Random; 9 | 10 | public class WordNeuron extends Neuron { 11 | public String name; 12 | public double[] syn0 = null; //input->hidden 13 | public List neurons = null;//neurons on path 14 | public int[] codeArr = null; 15 | public List oldNeurons = null; 16 | public int lowestPublicNode = -1; 17 | 18 | public List makeNeurons(int layerSize) { 19 | if (neurons != null) { 20 | return neurons; 21 | } 22 | Neuron neuron = this; 23 | neurons = new LinkedList<>(); 24 | while ((neuron = neuron.parent) != null) { 25 | neurons.add(neuron); 26 | } 27 | Collections.reverse(neurons); 28 | codeArr = new int[neurons.size()]; 29 | 30 | for (int i = 1; i < neurons.size(); i++) { 31 | codeArr[i - 1] = neurons.get(i).code; 32 | } 33 | codeArr[codeArr.length - 1] = this.code; 34 | 35 | return neurons; 36 | } 37 | 38 | public int makeNeurons2(int layerSize) { 39 | if (neurons != null) { 40 | return neurons.size(); 41 | } 42 | Neuron neuron = this; 43 | neurons = new LinkedList<>(); 44 | while ((neuron = neuron.parent) != null) { 45 | neurons.add(neuron); 46 | } 47 | Collections.reverse(neurons); 48 | codeArr = new int[neurons.size()]; 49 | 50 | for (int i = 1; i < neurons.size(); i++) { 51 | codeArr[i - 1] = neurons.get(i).code; 52 | } 53 | codeArr[codeArr.length - 1] = this.code; 54 | 55 | return neurons.size(); 56 | } 57 | 58 | public int inheritNeurons(int layerSize , Neuron root) { 59 | Neuron neuron = this; 60 | List inputNeurons = new LinkedList<>(); 61 | while ((neuron = neuron.parent) != null) { 62 | inputNeurons.add(neuron); 63 | } 64 | Collections.reverse(inputNeurons); 65 | int[] inputCodeArr = new int[inputNeurons.size()]; 66 | 67 | for (int i = 1; i < inputNeurons.size(); i++) { 68 | inputCodeArr[i - 1] = inputNeurons.get(i).code; 69 | } 70 | inputCodeArr[inputCodeArr.length - 1] = this.code; 71 | 72 | if(neurons == null) { 73 | //TO DO 74 | neurons = inputNeurons; 75 | codeArr = inputCodeArr; 76 | return 0; 77 | } 78 | int cnt = 0; 79 | 
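// Replay this word's old Huffman code down the new tree: copy each old path
// node's syn1 into the corresponding new node (once, guarded by alter), and
// record in lowestPublicNode how far the old and new codes still agree.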
80 | oldNeurons = new LinkedList<>(); 81 | HiddenNeuron hn = null,oldHn = null; 82 | boolean flag = true; 83 | int len = 0; 84 | 85 | for(int i = 0;i < codeArr.length;++i){ 86 | if(root instanceof WordNeuron) { 87 | break; 88 | } 89 | if(flag && i < inputCodeArr.length && inputCodeArr[i] == codeArr[i]) 90 | lowestPublicNode = i; 91 | else { 92 | flag = false; 93 | this.alter = true; 94 | } 95 | if(!root.alter){ 96 | oldHn = (HiddenNeuron) neurons.get(i); 97 | hn = (HiddenNeuron)root; 98 | for (int j = 0; j < layerSize; ++j) { 99 | if (!Double.isNaN(oldHn.syn1[j])) 100 | hn.syn1[j] = oldHn.syn1[j]; 101 | } 102 | oldHn.syn1 = null; 103 | cnt++; 104 | root.alter = true; 105 | } 106 | 107 | oldNeurons.add(root); 108 | 109 | if(codeArr[i] == 0) 110 | root = ((HiddenNeuron)root).leftNode; 111 | else 112 | root = ((HiddenNeuron)root).rightNode; 113 | } 114 | for(int i = 0;i < inputNeurons.size();++i){ 115 | Neuron x = inputNeurons.get(i); 116 | if(!x.cengAlter){ 117 | x.cengAlter = true; 118 | } 119 | } 120 | neurons = inputNeurons; 121 | codeArr = inputCodeArr; 122 | return cnt; 123 | } 124 | 125 | public int inheritNeuronsCount(int layerSize , Neuron root) { 126 | Neuron neuron = this; 127 | List inputNeurons = new LinkedList<>(); 128 | while ((neuron = neuron.parent) != null) { 129 | inputNeurons.add(neuron); 130 | } 131 | Collections.reverse(inputNeurons); 132 | int[] inputCodeArr = new int[inputNeurons.size()]; 133 | 134 | for (int i = 1; i < inputNeurons.size(); i++) { 135 | inputCodeArr[i - 1] = inputNeurons.get(i).code; 136 | } 137 | inputCodeArr[inputCodeArr.length - 1] = this.code; 138 | 139 | if(neurons == null) { 140 | //TO DO 141 | neurons = inputNeurons; 142 | codeArr = inputCodeArr; 143 | return 0; 144 | } 145 | int cnt = 0; 146 | 147 | oldNeurons = new LinkedList<>(); 148 | HiddenNeuron hn = null,oldHn = null; 149 | boolean flag = true; 150 | int longer = codeArr.length - inputCodeArr.length; 151 | for(int i = 0;i < codeArr.length;++i){ 152 | if(root instanceof WordNeuron) { 153 | break; 154 | } 155 | if(flag && i < inputCodeArr.length && inputCodeArr[i] == codeArr[i]) 156 | lowestPublicNode = i; 157 | else { 158 | flag = false; 159 | this.alter = true; 160 | } 161 | if(i >= inputCodeArr.length){ 162 | root.cntAlter = true; 163 | } 164 | if(!root.alter){ 165 | oldHn = (HiddenNeuron) neurons.get(i); 166 | hn = (HiddenNeuron)root; 167 | for (int j = 0; j < layerSize; ++j) { 168 | if (!Double.isNaN(oldHn.syn1[j])) 169 | hn.syn1[j] = oldHn.syn1[j]; 170 | } 171 | oldHn.syn1 = null; 172 | // cnt++; 173 | root.alter = true; 174 | } 175 | oldNeurons.add(root); 176 | 177 | if(codeArr[i] == 0) 178 | root = ((HiddenNeuron)root).leftNode; 179 | else 180 | root = ((HiddenNeuron)root).rightNode; 181 | } 182 | if(longer < 0){ 183 | for(int i = inputCodeArr.length + longer ; i < inputCodeArr.length ;i++){ 184 | if(inputNeurons.get(i) instanceof WordNeuron) { 185 | break; 186 | } 187 | hn = (HiddenNeuron)inputNeurons.get(i); 188 | hn.cnt2Alter = true; 189 | } 190 | } 191 | neurons = inputNeurons; 192 | codeArr = inputCodeArr; 193 | return cnt; 194 | } 195 | public WordNeuron(String name, int freq, int layerSize) { 196 | this.name = name; 197 | this.freq = freq; 198 | this.syn0 = new double[layerSize]; 199 | Random random = new Random(); 200 | for (int i = 0; i < syn0.length; i++) { 201 | syn0[i] = (random.nextDouble() - 0.5) / layerSize; 202 | } 203 | } 204 | 205 | public WordNeuron(String name, int freq, int layerSize, DataInputStream dis) throws IOException { 206 | this.name = name; 207 | 
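// Restore the previously trained input vector (syn0) from the serialized
// model instead of initializing it randomly.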
-------------------------------------------------------------------------------- /src/com/ansj/vec/util/Haffman.java: --------------------------------------------------------------------------------
package com.ansj.vec.util;

import com.ansj.vec.domain.HiddenNeuron;
import com.ansj.vec.domain.Neuron;

import java.io.*;
import java.util.Collection;
import java.util.TreeSet;

/**
 * Construct the Huffman tree over the vocabulary.
 * @author ansj
 */
public class Haffman {
    private int layerSize;

    public Haffman(int layerSize) {
        this.layerSize = layerSize;
    }

    private TreeSet<Neuron> set = new TreeSet<>();

    public void make(Collection<Neuron> neurons) {
        set.addAll(neurons);
        while (set.size() > 1) {
            merger();
        }
    }

    private void merger() {
        // Merge the two lowest-frequency nodes under a new hidden node.
        HiddenNeuron hn = new HiddenNeuron(layerSize);
        Neuron min1 = set.pollFirst();
        Neuron min2 = set.pollFirst();
        hn.freq = min1.freq + min2.freq;
        min1.parent = hn;
        min2.parent = hn;
        min1.code = 0;
        min2.code = 1;
        set.add(hn);
    }

    public Neuron makeWithRoot(Collection<Neuron> neurons) {
        set.addAll(neurons);
        while (set.size() > 1) {
            mergerWithFatherToSon();
        }
        return set.pollFirst();
    }

    private void mergerWithFatherToSon() {
        // Same as merger(), but also keeps parent -> child links so the tree
        // can be walked top-down during weight inheritance.
        HiddenNeuron hn = new HiddenNeuron(layerSize);
        Neuron min1 = set.pollFirst();
        Neuron min2 = set.pollFirst();
        hn.freq = min1.freq + min2.freq;
        min1.parent = hn;
        min2.parent = hn;
        min1.code = 0;
        min2.code = 1;
        hn.leftNode = min1;
        hn.rightNode = min2;
        set.add(hn);
    }

    public void make(Collection<Neuron> neurons, File treeFile) throws IOException {
        set.addAll(neurons);
        try (DataInputStream dis = new DataInputStream(new BufferedInputStream(new FileInputStream(treeFile)))) {
            while (set.size() > 1) {
                merger(dis);
            }
        }
    }

    private void merger(DataInputStream dis) throws IOException {
        HiddenNeuron hn = new HiddenNeuron(layerSize);
        Neuron min1 = set.pollFirst();
        Neuron min2 = set.pollFirst();
        hn.freq = min1.freq + min2.freq;
        min1.parent = hn;
        min2.parent = hn;
        min1.code = 0;
        min2.code = 1;
        if (!set.isEmpty()) {
            // The root's weights are never persisted, so skip reading them
            // for the final merge.
            for (int j = 0; j < layerSize; j++) {
                hn.syn1[j] = dis.readFloat();
            }
        }
        set.add(hn);
    }

    public void get(Collection<Neuron> neurons, DataOutputStream dataOutputStream) throws IOException {
        set.addAll(neurons);
        while (set.size() > 1) {
            getParentNode(dataOutputStream);
        }
    }

    private void getParentNode(DataOutputStream dataOutputStream) throws IOException {
        // Replays the merge order and streams each hidden node's weights out,
        // so make(neurons, treeFile) can read them back in the same order.
        Neuron min1 = set.pollFirst();
        Neuron min2 = set.pollFirst();
        if (min1 instanceof HiddenNeuron) {
            for (double d : ((HiddenNeuron) min1).syn1) {
                dataOutputStream.writeFloat((float) d);
            }
        }
        if (min2 instanceof HiddenNeuron) {
            for (double d : ((HiddenNeuron) min2).syn1) {
                dataOutputStream.writeFloat((float) d);
            }
        }
        set.add(min1.parent);
    }

    public int getInheritNeurons(Collection<Neuron> neurons) throws IOException {
        set.addAll(neurons);
        int cnt = 0;
        while (set.size() > 1) {
            cnt += getInheritParentNode();
        }
        return cnt;
    }

    private int getInheritParentNode() throws IOException {
        // Count hidden nodes flagged by the old-path-longer case.
        Neuron min1 = set.pollFirst();
        Neuron min2 = set.pollFirst();
        int cnt = 0;
        if (min1 instanceof HiddenNeuron && min1.cntAlter && min1.alter) {
            cnt++;
        }
        if (min2 instanceof HiddenNeuron && min2.cntAlter && min2.alter) {
            cnt++;
        }
        set.add(min1.parent);
        return cnt;
    }

    public int getInheritNeurons2(Collection<Neuron> neurons) throws IOException {
        set.addAll(neurons);
        int cnt = 0;
        while (set.size() > 1) {
            cnt += getInheritParentNode2();
        }
        return cnt;
    }

    private int getInheritParentNode2() throws IOException {
        // Count hidden nodes flagged by the new-path-longer case.
        Neuron min1 = set.pollFirst();
        Neuron min2 = set.pollFirst();
        int cnt = 0;
        if (min1 instanceof HiddenNeuron && min1.cnt2Alter && min1.alter) {
            cnt++;
        }
        if (min2 instanceof HiddenNeuron && min2.cnt2Alter && min2.alter) {
            cnt++;
        }
        set.add(min1.parent);
        return cnt;
    }
}
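The pairing of get(...) and make(neurons, treeFile) is what lets training resume between runs: get streams each inner node's syn1 vector in merge order, and make replays the same merges while reading the vectors back in. A minimal round-trip sketch, assuming a scratch file (the name huffman.tree is invented here):

import java.io.*;
import java.util.Arrays;
import java.util.List;
import com.ansj.vec.domain.Neuron;
import com.ansj.vec.domain.WordNeuron;
import com.ansj.vec.util.Haffman;

public class TreeRoundTrip {
    public static void main(String[] args) throws IOException {
        int layerSize = 50;
        File treeFile = new File("huffman.tree"); // illustrative path

        // Build a tree and persist the inner-node weights.
        List<Neuron> words = Arrays.<Neuron>asList(
                new WordNeuron("cat", 8, layerSize),
                new WordNeuron("dog", 5, layerSize),
                new WordNeuron("fish", 2, layerSize));
        new Haffman(layerSize).make(words);
        try (DataOutputStream out = new DataOutputStream(
                new BufferedOutputStream(new FileOutputStream(treeFile)))) {
            new Haffman(layerSize).get(words, out);
        }

        // Later, e.g. before an incremental pass: the same frequencies
        // rebuild the same tree shape while the saved weights stream back in.
        List<Neuron> reloaded = Arrays.<Neuron>asList(
                new WordNeuron("cat", 8, layerSize),
                new WordNeuron("dog", 5, layerSize),
                new WordNeuron("fish", 2, layerSize));
        new Haffman(layerSize).make(reloaded, treeFile);
    }
}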
-------------------------------------------------------------------------------- /src/com/ansj/vec/util/MapCount.java: --------------------------------------------------------------------------------
//
// Source code recreated from a .class file by IntelliJ IDEA
// (powered by Fernflower decompiler)
//

package com.ansj.vec.util;

import java.util.HashMap;
import java.util.Map.Entry;

public class MapCount<T> {
    private HashMap<T, Integer> hm = null;

    public MapCount() {
        this.hm = new HashMap<>();
    }

    public MapCount(int initialCapacity) {
        this.hm = new HashMap<>(initialCapacity);
    }

    public void add(T t, int n) {
        // Increment the count for t by n, creating the entry if absent.
        Integer integer;
        if ((integer = this.hm.get(t)) != null) {
            this.hm.put(t, integer + n);
        } else {
            this.hm.put(t, n);
        }
    }

    public void add(T t) {
        this.add(t, 1);
    }

    public int size() {
        return this.hm.size();
    }

    public void remove(T t) {
        this.hm.remove(t);
    }

    public HashMap<T, Integer> get() {
        return this.hm;
    }

    public String getDic() {
        // Render the counts as "key<TAB>count" lines.
        StringBuilder sb = new StringBuilder();
        for (Entry<T, Integer> next : this.hm.entrySet()) {
            sb.append(next.getKey());
            sb.append("\t");
            sb.append(next.getValue());
            sb.append("\n");
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        System.out.println(9223372036854775807L); // Long.MAX_VALUE
    }
}
--------------------------------------------------------------------------------
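For completeness, a minimal sketch of how MapCount is typically used to build the word-frequency table before Huffman construction (the sample sentence is invented):

import com.ansj.vec.util.MapCount;

public class CountDemo {
    public static void main(String[] args) {
        MapCount<String> mc = new MapCount<>();
        for (String w : "the cat sat on the mat".split(" ")) {
            mc.add(w); // increments the count for w by 1
        }
        System.out.println(mc.get().get("the")); // 2
        System.out.print(mc.getDic()); // one "word<TAB>count" line per word
    }
}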