├── .gitignore ├── .gitmodules ├── LICENSE ├── Makefile ├── README.md ├── build.sh ├── scripts ├── parameters_setting.txt ├── run.py └── sense_file.txt ├── src ├── data_block.cpp ├── data_block.h ├── dictionary.cpp ├── dictionary.h ├── huffman_encoder.cpp ├── huffman_encoder.h ├── main.cpp ├── multiverso_skipgram_mixture.cpp ├── multiverso_skipgram_mixture.h ├── multiverso_tablesid.h ├── param_loader.cpp ├── param_loader.h ├── reader.cpp ├── reader.h ├── skipgram_mixture_neural_network.cpp ├── skipgram_mixture_neural_network.h ├── trainer.cpp ├── trainer.h ├── util.cpp └── util.h └── windows └── distributed_skipgram_mixture ├── distributed_skipgram_mixture.sln └── distributed_skipgram_mixture.vcxproj /.gitignore: -------------------------------------------------------------------------------- 1 | bin/ 2 | multiverso 3 | src/*.o 4 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "multiverso"] 2 | path = multiverso 3 | url = https://github.com/Microsoft/multiverso.git 4 | branch = multiverso-initial 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) Microsoft Corporation 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the 
Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | PROJECT := $(shell readlink $(dir $(lastword $(MAKEFILE_LIST))) -f) 2 | 3 | CXX = g++ 4 | CXXFLAGS = -O3 \ 5 | -std=c++11 \ 6 | -Wall \ 7 | -Wno-sign-compare \ 8 | -fno-omit-frame-pointer 9 | 10 | MULTIVERSO_DIR = $(PROJECT)/multiverso 11 | MULTIVERSO_INC = $(MULTIVERSO_DIR)/include/ 12 | MULTIVERSO_LIB = $(MULTIVERSO_DIR)/lib 13 | THIRD_PARTY_LIB = $(MULTIVERSO_DIR)/third_party/lib 14 | 15 | INC_FLAGS = -I$(MULTIVERSO_INC) 16 | LD_FLAGS = -L$(MULTIVERSO_LIB) -lmultiverso -lpthread 17 | LD_FLAGS += -L$(THIRD_PARTY_LIB) -lzmq -lmpich -lmpl 18 | 19 | WORD_EMBEDDING_HEADERS = $(shell find $(PROJECT)/src -type f -name "*.h") 20 | WORD_EMBEDDING_SRC = $(shell find $(PROJECT)/src -type f -name "*.cpp") 21 | WORD_EMBEDDING_OBJ = $(WORD_EMBEDDING_SRC:.cpp=.o) 22 | 23 | BIN_DIR = $(PROJECT)/bin 24 | WORD_EMBEDDING = $(BIN_DIR)/distributed_skipgram_mixture 25 | 26 | all: path \ 27 | multisense_word_embedding 28 | 29 | path: $(BIN_DIR) 30 | 31 | $(BIN_DIR): 32 | mkdir -p $@ 33 | 34 | $(WORD_EMBEDDING): $(WORD_EMBEDDING_OBJ) 35 | $(CXX) $(WORD_EMBEDDING_OBJ) $(CXXFLAGS) $(INC_FLAGS) $(LD_FLAGS) -o $@ 36 | 37 | $(WORD_EMBEDDING_OBJ): %.o: %.cpp $(WORD_EMBEDDING_HEADERS) $(MULTIVERSO_INC) 38 | $(CXX) $(CXXFLAGS) $(INC_FLAGS) -c $< -o $@ 39 | 40 | multisense_word_embedding: path 
$(WORD_EMBEDDING) 41 | 42 | clean: 43 | rm -rf $(BIN_DIR) $(WORD_EMBEDDING_OBJ) 44 | 45 | .PHONY: all path multisense_word_embedding clean 46 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Distributed Multisense Word Embedding 2 | ========== 3 | 4 | The Distributed Multisense Word Embedding (DMWE) tool is a parallelization of the Skip-Gram Mixture [1] algorithm on top of the DMTK parameter server. It provides an efficient "scaling to industry size" solution for multi sense word embedding. 5 | 6 | For more details, please view our website [http://www.dmtk.io](http://www.dmtk.io) 7 | 8 | Download 9 | ---------- 10 | $ git clone https://github.com/Microsoft/distributed_skipgram_mixture 11 | 12 | Build 13 | ---------- 14 | 15 | **Prerequisite** 16 | 17 | DMWE is built on top of the DMTK parameter server, therefore please download and build DMTK first (https://github.com/Microsoft/multiverso). 18 | 19 | **For Windows** 20 | 21 | Open windows\distributed_skipgram_mixture\distributed_skipgram_mixture.sln using Visual Studio 2013. Add the necessary include path (for example, the path for DMTK multiverso) and lib path. Then build the solution. 22 | 23 | **For Ubuntu (Tested on Ubuntu 12.04)** 24 | 25 | Download and build by running ```$ sh build.sh```. Modify the include and lib path in Makefile. Then run ```$ make all -j4```. 26 | 27 | Run 28 | ---------- 29 | For parameter settings, see ```scripts/parameters_setting.txt```. For running it, see the example script ```scripts/run.py```. 30 | 31 | Reference 32 | ---------- 33 | [1] Tian, F., Dai, H., Bian, J., Gao, B., Zhang, R., Chen, E., & Liu, T. Y. (2014). [A probabilistic model for learning multi-prototype word embeddings](http://www.aclweb.org/anthology/C14-1016). In Proceedings of COLING (pp. 151-160). 
34 | 35 | Microsoft Open Source Code of Conduct 36 | ------------ 37 | 38 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. 39 | -------------------------------------------------------------------------------- /build.sh: -------------------------------------------------------------------------------- 1 | # build word_embedding 2 | 3 | git clone https://github.com/Microsoft/multiverso.git 4 | 5 | cd multiverso 6 | cd third_party 7 | sh install.sh 8 | cd .. 9 | make -j4 all 10 | 11 | cd .. 12 | make -j4 13 | -------------------------------------------------------------------------------- /scripts/parameters_setting.txt: -------------------------------------------------------------------------------- 1 | The parameters for Skip-Gram Mixture Multi-Sense model: 2 | 3 | /************************For IO************************/ 4 | -train_file: STRING. The training corpus file, e.g. enwiki2014. 5 | -vocab_file: STRING. The file to read all the vocab counts info, you must extract the words count info into this file. 6 | -binary: INTEGER. 0, 1 or 2, indicates whether to write all the embeddings vectors into binary format. Setting to 2 means to output both binary and text formats. 7 | -read_sense: STRING. The file storing all the pre-defined multi sense words. Each line for each word. 8 | -binary_embedding_file: STRING. The output file to store the multi sense input embedding vectors in binary format. 9 | -text_embedding_file: STRING. The output file to store the multi sense input embedding vectors in text format. 10 | -huff_tree_file: STRING. The output file to store the huffman tree structure. 11 | -outputlayer_binary_file: STRING. 
The output file to store the huffman tree node embedding vectors in binary format. 12 | -outputlayer_text_file: STRING. The output file to store the huffman tree node embedding vectors in text format. 13 | -stopwords: INTEGER. 0 or 1, whether to avoid training stop words. 14 | -sw_file: STRING. The stop words file storing all the stop words, valid when -stopwords = 1. 15 | 16 | /************************For training configuration************************/ 17 | -size: INTEGER. Word embedding size, e.g. 50. 18 | -init_learning_rate: FLOAT. Initial learning rate, usually set to 0.025, then it will be linearly reduced to 0 during the training process. 19 | -window: INTEGER. The window size. 20 | -threads: INTEGER. The thread number to run in parallel in every single machine. 21 | -min_count: INTEGER. Words with lower frequency than min-count are removed from dictionary. 22 | -epoch: INTEGER. The epoch number. 23 | -sense_num_multi: INTEGER. How many senses for the multi-sense words. 24 | -momentum: FLOAT. The init momentum, must lie in (0, 1). Used to update the sense_priors by sense_priors_t = momentum * sense_priors_[t-1] + (1 - momentum) * \phi. It will be linearly increased to 1 during the training. 25 | -EM_iteration: INTEGER. The number of iterations in EM algorithm. Setting it to 1 is good and fast enough. 26 | -store_multinomial: BINARY. Ways to store the sense priors. 1 means store them as multinomial parameters (between 0 and 1 and their summation is 1), 0 means store them as the log of multinomial parameters. 27 | -top_n: INTEGER. Set the top_n frequent words as multi sense words. 28 | -top_ratio: FLOAT. Must lie in (0, 1), set the top_ratio most frequent words as multi sense words. 29 | 30 | /************************For Distributed Setting related with DMTK************************/ 31 | -lock_option: INTEGER. The lock option. See documents of DMTK. 32 | -num_lock: INTEGER. The number of locks. See documents of DMTK. 33 | -max_delay: INTEGER. The max delay. 
See documents of DMTK. Strongly recommend setting to 0. 34 | -is_pipline: BINARY. The pipeline setting. See documents of DMTK. Strongly recommend setting to 0. 35 | -data_block_size: INTEGER. The data block size. See documents of DMTK. 36 | -pre_load_data_blocks: INTEGER. The number of blocks to load before training. Set it to avoid memory malloc failure in case of very large corpus. -------------------------------------------------------------------------------- /scripts/run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import time 4 | import random 5 | import sys 6 | import shutil 7 | import subprocess 8 | 9 | from subprocess import STDOUT 10 | 11 | #Execute a subprocess with standard output 12 | def execute(command): 13 | popen = subprocess.Popen(command, stdout=subprocess.PIPE) 14 | lines_iterator = iter(popen.stdout.readline, b"") 15 | for line in lines_iterator: 16 | print(line) # yield line 17 | 18 | #Parameter w.r.t. MPI 19 | work_dir = 'Your Directory' 20 | port = 'Your port number for MPI' 21 | machinefile= 'Your host file for MPI' 22 | 23 | #Parameter w.r.t. SG-Mixture Training 24 | size = 50 25 | train = 'Your Training File' 26 | ''' 27 | Please visit http://ms-dmtk.azurewebsites.net/word2vec_multi.html#p5 for wiki2014 and clueweb. 28 | In running multi machine version, please separate the file and distribute the subfiles into different machines. 29 | ''' 30 | read_vocab = 'Your Vocab File' #each line of the file stores 'Word\tCount' info. 
31 | binary = 2 32 | init_learning_rate = 0.025 33 | epoch = 1 34 | window = 5 35 | threads = 8 36 | mincount = 5 37 | EM_iteration = 1 38 | momentum = 0.05 39 | 40 | default_sense = 1 41 | #Default number of senses for multi sense words 42 | sense_num_multi = 5 43 | 44 | ''' 45 | Three ways of specifying multi sense words, each with sense_num_multi prototypes: 46 | 1)Set top_n frequent words 47 | 2)Set top_ratio (lie between 0 to 1) frequent words 48 | 3)Write all these words into sense_file 49 | ''' 50 | top_n = 0 51 | top_ratio = 0 52 | sense_file = 'Your Sense File, see sense_file.txt as an example' 53 | 54 | #Output files 55 | binary_embedding_file = 'emb.bin' 56 | text_embedding_file = 'emb.txt' 57 | huff_tree_file = 'huff.txt' 58 | outputlayer_binary_file = 'emb_out.bin' 59 | outputlayer_text_file = 'emb_out.txt' 60 | 61 | preload_cnt = 5 62 | 63 | #Number of sentences for each datablock. 64 | #Warning: for wiki2014, set it to 50000, for clueweb09, set it to 750000. Other values are not tested. 65 | data_block_size = 50000 66 | 67 | #Warning: enable pipeline in multiverso will lead to some performance drop 68 | pipline = 0 69 | 70 | #Whether to store the multinomial parameters in its original form. If false, will store their log values instead. 
71 | multinomial = 0 72 | 73 | mpi_args = '-port {0} -wdir {1} -machinefile {2} '.format(port, work_dir, machinefile) 74 | sg_mixture_args = ' -train_file {0} -binary_embedding_file {1} -text_embedding_file {2} -threads {3} -size {4} -binary {5} -epoch {6} -init_learning_rate {7} -min_count {8} -window {9} -momentum {10} -EM_iteration {11} -top_n {12} -top_ratio {13} -default_sense {14} -sense_num_multi {15} -huff_tree_file {16} -vocab_file {17} -outputlayer_binary_file {18} -outputlayer_text_file {19} -read_sense {20} -data_block_size {21} -is_pipline {22} -store_multinomial {23} -max_preload_size {24}'.format(train, binary_embedding_file, text_embedding_file, threads, size, binary, epoch, init_learning_rate, mincount, window, momentum, EM_iteration, top_n, top_ratio, default_sense, sense_num_multi, huff_tree_file, read_vocab, outputlayer_binary_file, outputlayer_text_file, sense_file, data_block_size, pipline, multinomial, preload_cnt) 75 | 76 | print(mpi_args) 77 | print(sg_mixture_args) 78 | 79 | #Execute MPI 80 | proc = execute("mpiexec " + mpi_args + 'distributed_skipgram_mixture ' + sg_mixture_args) 81 | -------------------------------------------------------------------------------- /scripts/sense_file.txt: -------------------------------------------------------------------------------- 1 | love 2 | sex 3 | tiger 4 | cat 5 | book 6 | paper 7 | computer 8 | keyboard 9 | internet 10 | plane 11 | car 12 | train 13 | telephone 14 | communication 15 | television 16 | radio 17 | media 18 | drug 19 | abuse 20 | bread 21 | butter 22 | cucumber 23 | potato 24 | doctor 25 | nurse 26 | professor 27 | student 28 | smart 29 | stupid 30 | company 31 | stock 32 | market 33 | phone 34 | jaguar 35 | egg 36 | fertility 37 | live 38 | life 39 | library 40 | bank 41 | money 42 | wood 43 | forest 44 | cash 45 | king 46 | cabbage 47 | queen 48 | rook 49 | bishop 50 | rabbi 51 | holy 52 | fuck 53 | football 54 | soccer 55 | basketball 56 | tennis 57 | racket 58 | peace 59 | terror 
60 | law 61 | lawyer 62 | movie 63 | star 64 | popcorn 65 | critic 66 | theater 67 | physics 68 | proton 69 | chemistry 70 | space 71 | alcohol 72 | vodka 73 | gin 74 | brandy 75 | drink 76 | ear 77 | mouth 78 | eat 79 | baby 80 | mother 81 | automobile 82 | gem 83 | jewel 84 | journey 85 | voyage 86 | boy 87 | lad 88 | coast 89 | shore 90 | asylum 91 | madhouse 92 | magician 93 | wizard 94 | midday 95 | noon 96 | furnace 97 | stove 98 | food 99 | fruit 100 | bird 101 | cock 102 | crane 103 | tool 104 | implement 105 | brother 106 | monk 107 | oracle 108 | cemetery 109 | woodland 110 | rooster 111 | hill 112 | graveyard 113 | slave 114 | chord 115 | smile 116 | glass 117 | string 118 | dollar 119 | currency 120 | wealth 121 | property 122 | possession 123 | deposit 124 | withdrawal 125 | laundering 126 | operation 127 | feline 128 | carnivore 129 | mammal 130 | animal 131 | organism 132 | fauna 133 | zoo 134 | psychology 135 | psychiatry 136 | anxiety 137 | fear 138 | depression 139 | clinic 140 | mind 141 | health 142 | science 143 | discipline 144 | cognition 145 | planet 146 | constellation 147 | moon 148 | sun 149 | galaxy 150 | astronomer 151 | precedent 152 | example 153 | information 154 | collection 155 | group 156 | antecedent 157 | cup 158 | coffee 159 | tableware 160 | article 161 | artifact 162 | object 163 | entity 164 | substance 165 | liquid 166 | energy 167 | secretary 168 | senate 169 | laboratory 170 | weapon 171 | secret 172 | fingerprint 173 | investigation 174 | effort 175 | water 176 | scientist 177 | news 178 | report 179 | canyon 180 | landscape 181 | image 182 | surface 183 | discovery 184 | seepage 185 | sign 186 | recess 187 | mile 188 | kilometer 189 | territory 190 | atmosphere 191 | president 192 | medal 193 | war 194 | troops 195 | record 196 | number 197 | skin 198 | eye 199 | history 200 | volunteer 201 | motto 202 | prejudice 203 | recognition 204 | decoration 205 | valor 206 | century 207 | year 208 | nation 209 | delay 210 | 
racism 211 | minister 212 | party 213 | plan 214 | minority 215 | attempt 216 | government 217 | crisis 218 | deployment 219 | departure 220 | announcement 221 | stroke 222 | hospital 223 | disability 224 | death 225 | victim 226 | emergency 227 | treatment 228 | recovery 229 | journal 230 | association 231 | personnel 232 | liability 233 | insurance 234 | school 235 | center 236 | reason 237 | hypertension 238 | criterion 239 | hundred 240 | percent 241 | infrastructure 242 | row 243 | inmate 244 | evidence 245 | term 246 | word 247 | similarity 248 | board 249 | recommendation 250 | governor 251 | interview 252 | country 253 | travel 254 | activity 255 | competition 256 | price 257 | consumer 258 | confidence 259 | problem 260 | airport 261 | flight 262 | credit 263 | card 264 | hotel 265 | reservation 266 | grocery 267 | registration 268 | arrangement 269 | accommodation 270 | month 271 | type 272 | kind 273 | arrival 274 | bed 275 | closet 276 | clothes 277 | situation 278 | conclusion 279 | isolation 280 | impartiality 281 | interest 282 | direction 283 | combination 284 | street 285 | place 286 | avenue 287 | block 288 | children 289 | listing 290 | proximity 291 | category 292 | cell 293 | production 294 | hike 295 | benchmark 296 | index 297 | trading 298 | gain 299 | dividend 300 | payment 301 | calculation 302 | computation 303 | oil 304 | warning 305 | profit 306 | loss 307 | yen 308 | buck 309 | software 310 | network 311 | hardware 312 | equipment 313 | maker 314 | luxury 315 | five 316 | investor 317 | earning 318 | baseball 319 | season 320 | game 321 | victory 322 | team 323 | marathon 324 | sprint 325 | series 326 | defeat 327 | seven 328 | seafood 329 | sea 330 | lobster 331 | wine 332 | preparation 333 | video 334 | archive 335 | start 336 | match 337 | round 338 | boxing 339 | championship 340 | tournament 341 | fighting 342 | defeating 343 | line 344 | day 345 | summer 346 | drought 347 | nature 348 | dawn 349 | environment 350 | ecology 351 | 
man 352 | woman 353 | murder 354 | manslaughter 355 | soap 356 | opera 357 | performance 358 | lesson 359 | focus 360 | crew 361 | film 362 | lover 363 | quarrel 364 | viewer 365 | serial 366 | possibility 367 | girl 368 | population 369 | development 370 | morality 371 | importance 372 | marriage 373 | gender 374 | equality 375 | change 376 | attitude 377 | family 378 | planning 379 | industry 380 | sugar 381 | approach 382 | practice 383 | institution 384 | ministry 385 | culture 386 | challenge 387 | size 388 | prominence 389 | citizen 390 | people 391 | issue 392 | experience 393 | music 394 | project 395 | metal 396 | aluminum 397 | chance 398 | credibility 399 | exhibit 400 | memorabilia 401 | concert 402 | virtuoso 403 | rock 404 | jazz 405 | museum 406 | observation 407 | architecture 408 | world 409 | preservation 410 | admission 411 | ticket 412 | shower 413 | thunderstorm 414 | flood 415 | weather 416 | forecast 417 | disaster 418 | area 419 | office 420 | brazil 421 | nut 422 | triple 423 | cd 424 | aglow 425 | harvard 426 | yale 427 | cambridge 428 | israel 429 | east 430 | israeli 431 | japanese 432 | american 433 | jerusalem 434 | wall 435 | mexico 436 | puebla 437 | opec 438 | saudi 439 | palestinian 440 | arab 441 | wednesday 442 | weekday 443 | haven 444 | ability 445 | know-how 446 | persecution 447 | accepted 448 | acknowledged 449 | believe 450 | welcome 451 | accommodate 452 | adjust 453 | settlement 454 | acronym 455 | form 456 | inaction 457 | confession 458 | matriculation 459 | advance 460 | headway 461 | propose 462 | advised 463 | inform 464 | affect 465 | tense 466 | aged 467 | develop 468 | old 469 | worn 470 | young 471 | aim 472 | cause 473 | hold 474 | thing 475 | airfield 476 | sterol 477 | component 478 | ambush 479 | surprise 480 | herbivore 481 | statement 482 | answering 483 | counter 484 | react 485 | worry 486 | access 487 | converge 488 | arch 489 | curve 490 | building 491 | urban 492 | shoulder 493 | flower 494 | weight 
495 | ash 496 | bone 497 | assignment 498 | document 499 | assimilation 500 | americanization 501 | realtor 502 | astrophysicist 503 | bedlam 504 | shelter 505 | stp 506 | crime 507 | attempted 508 | initiate 509 | give 510 | attended 511 | guard 512 | respect 513 | window 514 | treat 515 | back 516 | lumbar 517 | out 518 | second 519 | side 520 | riverbank 521 | base 522 | air 523 | build 524 | structure 525 | ball 526 | tip-off 527 | bat 528 | bats 529 | placental 530 | turn 531 | beam 532 | signal 533 | plant 534 | platform 535 | plot 536 | point 537 | bent 538 | double 539 | inclination 540 | bias 541 | experimenter 542 | bigger 543 | large 544 | meat 545 | bitter 546 | ale 547 | resentful 548 | taste 549 | close 550 | blow 551 | blowing 552 | exhale 553 | insufflate 554 | blown 555 | swat 556 | deal 557 | directorate 558 | open 559 | phrase 560 | picture 561 | booster 562 | advertiser 563 | advocate 564 | bootleg 565 | covering 566 | produce 567 | whiskey 568 | bore 569 | cut 570 | flow 571 | stuff 572 | bound 573 | boundary 574 | skirt 575 | vault 576 | bow 577 | knot 578 | bowling 579 | wheel 580 | yield 581 | hit 582 | angstrom 583 | branching 584 | bifurcation 585 | grow 586 | cover 587 | bring 588 | return 589 | bronze 590 | nickel 591 | sculpture 592 | freemason 593 | bill 594 | horse 595 | buffer 596 | zone 597 | bug 598 | defect 599 | insect 600 | burning 601 | important 602 | burnt 603 | treated 604 | bury 605 | hide 606 | lay 607 | steal 608 | cake 609 | dish 610 | tablet 611 | interpolation 612 | camp 613 | gathering 614 | camping 615 | pitch 616 | capital 617 | assets 618 | primary 619 | seed 620 | playing 621 | ration 622 | see 623 | limousine 624 | predator 625 | carrier 626 | traveler 627 | vehicle 628 | cast 629 | shoot 630 | catching 631 | catch 632 | perceive 633 | compartment 634 | membrane 635 | site 636 | central 637 | fiscal 638 | refer 639 | think 640 | tract 641 | integer 642 | litigate 643 | contest 644 | fortune 645 | ring 646 | 
relation 647 | channels 648 | transmission 649 | charcoal 650 | fuel 651 | checked 652 | examine 653 | offspring 654 | cholera 655 | infectious 656 | arpeggio 657 | play 658 | circulation 659 | dissemination 660 | citation 661 | act 662 | award 663 | national 664 | clean 665 | dry 666 | jerk 667 | remove 668 | clear 669 | innocence 670 | make 671 | session 672 | closed 673 | state 674 | up 675 | storage 676 | toilet 677 | closing 678 | finale 679 | motion 680 | snap 681 | glide 682 | slope 683 | penis 684 | populace 685 | collect 686 | due 687 | archives 688 | data 689 | prayer 690 | scrape 691 | take 692 | colored 693 | black 694 | grey 695 | colors 696 | emblem 697 | ensign 698 | pigment 699 | scheme 700 | commit 701 | vow 702 | abstraction 703 | communications 704 | compact 705 | case 706 | pack 707 | packed 708 | short 709 | write 710 | compound 711 | enhance 712 | recombine 713 | whole 714 | conversion 715 | procedure 716 | assumption 717 | judgment 718 | connect 719 | natural 720 | topology 721 | content 722 | disapproval 723 | limit 724 | convinced 725 | disarm 726 | cool 727 | air-conditioned 728 | coldness 729 | answer 730 | negative 731 | ally 732 | course 733 | stream 734 | workshop 735 | blanket 736 | cloak 737 | submerge 738 | crabs 739 | crab 740 | crack 741 | check 742 | cracking 743 | noise 744 | sound 745 | ovation 746 | condition 747 | measurement 748 | cross 749 | marking 750 | meet 751 | gherkin 752 | melon 753 | counterculture 754 | subculture 755 | shape 756 | curb 757 | bit 758 | smother 759 | exchange 760 | prevalence 761 | cutting 762 | nip 763 | damned 764 | lost 765 | raise 766 | begin 767 | hour 768 | morrow 769 | deaf 770 | organic 771 | deceased 772 | born 773 | dead 774 | service 775 | set 776 | disappointment 777 | defeated 778 | upset 779 | ending 780 | overrun 781 | veto 782 | defense 783 | biological 784 | protection 785 | hesitate 786 | moratorium 787 | deliver 788 | serve 789 | demand 790 | call 791 | claim 792 | request 793 | 
supply 794 | demonic 795 | evil 796 | variation 797 | redeployment 798 | geological 799 | derived 800 | reap 801 | mental 802 | deepening 803 | elaboration 804 | qibla 805 | anorgasmia 806 | disabled 807 | unfit 808 | adversity 809 | wave 810 | system 811 | breakthrough 812 | disclosure 813 | display 814 | light 815 | model 816 | stick 817 | ditch 818 | abandon 819 | drain 820 | waterway 821 | numerator 822 | diving 823 | one-half 824 | theologian 825 | dominican 826 | dot 827 | disk 828 | seashore 829 | multiple 830 | draw 831 | entertainer 832 | gully 833 | drawing 834 | drag 835 | frame 836 | twitch 837 | dress 838 | formal 839 | morning 840 | neckline 841 | wear 842 | drive 843 | impulse 844 | mechanism 845 | operate 846 | ride 847 | drop 848 | descend 849 | measure 850 | dropping 851 | sink 852 | teardrop 853 | time 854 | anesthetic 855 | dwelling 856 | exist 857 | education 858 | effects 859 | backdate 860 | happen 861 | repercussion 862 | liberation 863 | white 864 | eliminate 865 | temp 866 | emphasized 867 | stress 868 | endeavor 869 | worst 870 | force 871 | heat 872 | habitat 873 | equal 874 | differ 875 | inadequate 876 | sameness 877 | tie 878 | person 879 | rescue 880 | establishment 881 | collectivization 882 | even 883 | identification 884 | notarize 885 | specimen 886 | parade 887 | possess 888 | show 889 | inexperience 890 | suffer 891 | extension 892 | expansion 893 | look 894 | fail 895 | choke 896 | shipwreck 897 | fair 898 | join 899 | midway 900 | moderate 901 | child 902 | famine 903 | irish 904 | lack 905 | fast 906 | abstain 907 | hunger 908 | sudden 909 | avifauna 910 | fearlessness 911 | panic 912 | sterility 913 | fiction 914 | canard 915 | fantasy 916 | field 917 | handle 918 | brush 919 | resist 920 | filled 921 | medium 922 | fine 923 | precise 924 | firm 925 | corporation 926 | hard 927 | forward 928 | fleet 929 | aircraft 930 | sortie 931 | debacle 932 | fill 933 | flush 934 | age 935 | down 936 | good 937 | rich 938 | fly 939 | 
blur 940 | concentration 941 | distinctness 942 | follow 943 | evaluate 944 | predict 945 | prognosis 946 | biome 947 | growth 948 | abbreviation 949 | fort 950 | presidio 951 | bold 952 | player 953 | position 954 | transport 955 | found 956 | wage 957 | freeing 958 | parole 959 | freelance 960 | work 961 | bear 962 | product 963 | intercourse 964 | fire 965 | full 966 | further 967 | far 968 | promote 969 | fused 970 | gauge 971 | united 972 | account 973 | obtain 974 | gallery 975 | audience 976 | parlay 977 | playoff 978 | gas 979 | sewer 980 | art 981 | crystalline 982 | asexual 983 | gen 984 | gig 985 | harpoon 986 | seidel 987 | gone 988 | gore 989 | blood 990 | pierce 991 | control 992 | graduate 993 | confer 994 | high 995 | instrument 996 | receive 997 | scholar 998 | greater 999 | shelf 1000 | vinyl 1001 | fractious 1002 | soft 1003 | weaponry 1004 | hay 1005 | fodder 1006 | healing 1007 | better 1008 | illness 1009 | heel 1010 | stack 1011 | height 1012 | degree 1013 | increase 1014 | tor 1015 | recital 1016 | topographical 1017 | home 1018 | away 1019 | honest 1020 | sincere 1021 | honey 1022 | hooks 1023 | hand 1024 | hot 1025 | calorific 1026 | violent 1027 | cardiovascular 1028 | ice 1029 | appearance 1030 | visualize 1031 | use 1032 | standing 1033 | increased 1034 | add 1035 | enhanced 1036 | maximize 1037 | list 1038 | mass 1039 | safety 1040 | enterprise 1041 | confirmation 1042 | datum 1043 | initial 1044 | first 1045 | letter 1046 | resident 1047 | instability 1048 | disorder 1049 | vicariate 1050 | instruction 1051 | recipe 1052 | contract 1053 | benefit 1054 | enthusiasm 1055 | intermediate 1056 | chemical 1057 | introduction 1058 | opening 1059 | usher 1060 | count 1061 | fishing 1062 | iron 1063 | alpha 1064 | robust 1065 | alienation 1066 | solitude 1067 | edit 1068 | edition 1069 | overriding 1070 | unblock 1071 | neck 1072 | writing 1073 | ship 1074 | keep 1075 | continue 1076 | stronghold 1077 | holder 1078 | typeset 1079 | rival 1080 
| region 1081 | exclude 1082 | painting 1083 | scenery 1084 | language 1085 | soliloquy 1086 | last 1087 | end 1088 | past 1089 | populate 1090 | rank 1091 | run 1092 | senior 1093 | machine 1094 | wash 1095 | joke 1096 | mosaic 1097 | barrister 1098 | song 1099 | learned 1100 | discover 1101 | educated 1102 | scholarly 1103 | left 1104 | dad 1105 | admonition 1106 | lever 1107 | house 1108 | being 1109 | lift 1110 | aid 1111 | consequence 1112 | lower 1113 | move 1114 | lighter 1115 | heavy 1116 | headlight 1117 | lighten 1118 | smoke 1119 | limb 1120 | branch 1121 | crib 1122 | lie 1123 | matter 1124 | organization 1125 | post 1126 | sick 1127 | shopping 1128 | looking 1129 | search 1130 | sightseeing 1131 | sparkle 1132 | reducing 1133 | wastage 1134 | sleep 1135 | agape 1136 | care 1137 | like 1138 | lowering 1139 | decrease 1140 | devalue 1141 | movement 1142 | reef 1143 | escapologist 1144 | magus 1145 | manufacturer 1146 | homicide 1147 | manufacture 1148 | commercial 1149 | mark 1150 | broad 1151 | buoy 1152 | class 1153 | labor 1154 | shop 1155 | trade 1156 | spouse 1157 | married 1158 | final 1159 | matchstick 1160 | mature 1161 | distinguish 1162 | telecommunication 1163 | mention 1164 | allusion 1165 | comment 1166 | photograph 1167 | merit 1168 | demerit 1169 | worthiness 1170 | silver 1171 | tombac 1172 | middle 1173 | phase 1174 | put 1175 | milk 1176 | nutriment 1177 | recall 1178 | mine 1179 | strip 1180 | sulfur 1181 | clergyman 1182 | foreign 1183 | priesthood 1184 | minor 1185 | miss 1186 | failure 1187 | mock 1188 | derision 1189 | tease 1190 | ptolemaic 1191 | mole 1192 | spy 1193 | pile 1194 | religious 1195 | solar 1196 | week 1197 | inner 1198 | virtue 1199 | morocco 1200 | levant 1201 | moroccan 1202 | motel 1203 | motor 1204 | mother-in-law 1205 | yeast 1206 | agent 1207 | stepper 1208 | catchphrase 1209 | mourning 1210 | sadness 1211 | beak 1212 | tongue 1213 | ms 1214 | gulf 1215 | depository 1216 | narrow 1217 | determine 1218 | 
limited 1219 | strait 1220 | spanish 1221 | characteristic 1222 | nerves 1223 | psychological 1224 | newspaper 1225 | item 1226 | publisher 1227 | update 1228 | nobility 1229 | aristocrat 1230 | majority 1231 | practical 1232 | business 1233 | printing 1234 | offset 1235 | balance 1236 | compensation 1237 | rift 1238 | musical 1239 | action 1240 | prophecy 1241 | orchestrated 1242 | score 1243 | individual 1244 | oxygen 1245 | battalion 1246 | panorama 1247 | parent 1248 | part 1249 | interruption 1250 | dance 1251 | union 1252 | pass 1253 | accomplishment 1254 | passing 1255 | done 1256 | leave 1257 | patronage 1258 | blessing 1259 | pat 1260 | touch 1261 | repayment 1262 | amity 1263 | occupancy 1264 | perfect 1265 | polish 1266 | unbroken 1267 | magic 1268 | department 1269 | police 1270 | delegate 1271 | siphon 1272 | plain 1273 | knit 1274 | llano 1275 | obvious 1276 | simple 1277 | employee 1278 | seaplane 1279 | follower 1280 | schedule 1281 | plasma 1282 | gamma 1283 | curtain 1284 | played 1285 | die 1286 | foul 1287 | thousand 1288 | poll 1289 | homo 1290 | straw 1291 | vote 1292 | pop 1293 | father 1294 | overpopulation 1295 | criminal 1296 | liabilities 1297 | expectation 1298 | vine 1299 | pound 1300 | formalism 1301 | learn 1302 | prairie 1303 | grassland 1304 | civil 1305 | smell 1306 | self-preservation 1307 | pressing 1308 | squeeze 1309 | cost 1310 | principle 1311 | hellenism 1312 | yin 1313 | print 1314 | contact 1315 | difficulty 1316 | poser 1317 | growing 1318 | academician 1319 | advantage 1320 | bulge 1321 | projecting 1322 | communicate 1323 | moneymaker 1324 | limelight 1325 | promise 1326 | betrothal 1327 | declare 1328 | pledge 1329 | nucleon 1330 | vicinity 1331 | psychotherapy 1332 | argue 1333 | quiet 1334 | louden 1335 | order 1336 | leader 1337 | title 1338 | anti-semitism 1339 | profiling 1340 | broadcasting 1341 | raising 1342 | bump 1343 | rise 1344 | upbringing 1345 | reading 1346 | interpretation 1347 | read 1348 | explanation 
1349 | generalize 1350 | present 1351 | rebel 1352 | soldier 1353 | designation 1354 | advice 1355 | puff 1356 | improvement 1357 | rally 1358 | reduced 1359 | abbreviate 1360 | bated 1361 | low 1362 | reduce 1363 | simplify 1364 | slash 1365 | reed 1366 | body 1367 | entrance 1368 | relieve 1369 | comfort 1370 | free 1371 | announce 1372 | blue 1373 | resistance 1374 | resolution 1375 | physical 1376 | retirement 1377 | status 1378 | reverse 1379 | gear 1380 | opposition 1381 | right 1382 | proper 1383 | chondrite 1384 | limestone 1385 | rod 1386 | role 1387 | hat 1388 | rolling 1389 | robbery 1390 | romance 1391 | intrigue 1392 | quality 1393 | chicken 1394 | rose 1395 | damask 1396 | soar 1397 | rough 1398 | crushed 1399 | golf 1400 | ammunition 1401 | dispute 1402 | terrace 1403 | rubber 1404 | crepe 1405 | running 1406 | dash 1407 | sacrifice 1408 | kill 1409 | sales 1410 | divestiture 1411 | income 1412 | sample 1413 | satellite 1414 | outer 1415 | scattered 1416 | break 1417 | disband 1418 | incoherent 1419 | alumnus 1420 | virtuosity 1421 | cosmographer 1422 | south 1423 | leap 1424 | ordinal 1425 | assistant 1426 | password 1427 | appreciate 1428 | diocese 1429 | seeing 1430 | exudation 1431 | sent 1432 | ordered 1433 | quarterly 1434 | broadcast 1435 | spot 1436 | foreplay 1437 | sheet 1438 | expanse 1439 | worksheet 1440 | shifting 1441 | shining 1442 | shooting 1443 | sucker 1444 | arrive 1445 | diamond 1446 | shot 1447 | colorful 1448 | grapeshot 1449 | plumbing 1450 | token 1451 | single 1452 | common 1453 | singles 1454 | badminton 1455 | singleton 1456 | sit 1457 | element 1458 | classify 1459 | coat 1460 | magnitude 1461 | slain 1462 | worker 1463 | slight 1464 | dismiss 1465 | insignificant 1466 | silent 1467 | slip 1468 | freudian 1469 | slow 1470 | gradual 1471 | smashing 1472 | blast 1473 | expression 1474 | smoking 1475 | emit 1476 | inhale 1477 | vaporization 1478 | bribe 1479 | saddle 1480 | compatible 1481 | sole 1482 | solo 1483 | perform 
1484 | sounding 1485 | depth 1486 | location 1487 | spiral 1488 | coiled 1489 | helix 1490 | splash 1491 | spread 1492 | discharge 1493 | disparity 1494 | dispersion 1495 | distributed 1496 | square 1497 | angular 1498 | lawful 1499 | polygon 1500 | stalls 1501 | livery 1502 | have 1503 | path 1504 | beginning 1505 | starter 1506 | get 1507 | stated 1508 | still 1509 | silence 1510 | merchandise 1511 | stocks 1512 | framework 1513 | heater 1514 | striking 1515 | conspicuous 1516 | crash 1517 | impressive 1518 | arrange 1519 | fingerboard 1520 | strike 1521 | study 1522 | examination 1523 | learning 1524 | review 1525 | submarine 1526 | sandwich 1527 | subordinate 1528 | dog 1529 | insubordinate 1530 | under 1531 | succeeding 1532 | suffering 1533 | enjoy 1534 | pain 1535 | beet 1536 | solstice 1537 | sunburst 1538 | bubble 1539 | horizontal 1540 | suspension 1541 | lapse 1542 | sway 1543 | power 1544 | cutlery 1545 | color 1546 | telling 1547 | narration 1548 | referent 1549 | tenure 1550 | africa 1551 | box 1552 | things 1553 | expect 1554 | storm 1555 | tag 1556 | means 1557 | topping 1558 | tops 1559 | crown 1560 | touched 1561 | touching 1562 | universe 1563 | trace 1564 | support 1565 | traffic 1566 | commerce 1567 | gravitation 1568 | tour 1569 | walk 1570 | dealing 1571 | trench 1572 | trick 1573 | deceive 1574 | shift 1575 | daze 1576 | hostile 1577 | march 1578 | trouble 1579 | perturbation 1580 | twist 1581 | adult 1582 | identify 1583 | version 1584 | uniform 1585 | jump 1586 | upgrade 1587 | afflict 1588 | agitation 1589 | disturbance 1590 | troubled 1591 | application 1592 | consume 1593 | used 1594 | functional 1595 | usual 1596 | familiar 1597 | vanished 1598 | vector 1599 | radius 1600 | videotape 1601 | violate 1602 | musician 1603 | vision 1604 | imagination 1605 | screwdriver 1606 | void 1607 | invalid 1608 | nonexistence 1609 | validate 1610 | waking 1611 | sleeping 1612 | walking 1613 | accompany 1614 | locomotion 1615 | wade 1616 | wandering 
1617 | about 1618 | stray 1619 | battle 1620 | hostility 1621 | strategic 1622 | shampoo 1623 | lake 1624 | perspiration 1625 | wet 1626 | ways 1627 | shipyard 1628 | weak 1629 | diluted 1630 | weakening 1631 | transformation 1632 | flimsy 1633 | abundance 1634 | persuasion 1635 | precipitation 1636 | welsh 1637 | brythonic 1638 | wetness 1639 | bleach 1640 | light-skinned 1641 | dark 1642 | delaware 1643 | withdraw 1644 | retrograde 1645 | female 1646 | silva 1647 | red 1648 | workings 1649 | excavation 1650 | chow 1651 | poor 1652 | result 1653 | wounded 1654 | wrong 1655 | false 1656 | improper 1657 | injury 1658 | y2k 1659 | yielding 1660 | assent 1661 | facility 1662 | undertaking 1663 | retail 1664 | credence 1665 | bundle 1666 | overlook 1667 | happening 1668 | room 1669 | relearn 1670 | express 1671 | anagram 1672 | insane 1673 | awake 1674 | lamb 1675 | mutiny 1676 | enclose 1677 | mute 1678 | foundation 1679 | nonprofessional 1680 | grazed 1681 | extort 1682 | efferent 1683 | brave 1684 | hurl 1685 | asian 1686 | longing 1687 | intolerable 1688 | grieve 1689 | hailstone 1690 | considered 1691 | wind 1692 | kick 1693 | formation 1694 | scan 1695 | protest 1696 | desk 1697 | indirect 1698 | compel 1699 | prepared 1700 | ask 1701 | stabilize 1702 | conference 1703 | impermanent 1704 | backward 1705 | solemnize 1706 | liquidate 1707 | damage 1708 | likeness 1709 | constitute 1710 | patterned 1711 | prepare 1712 | english 1713 | lip 1714 | nearness 1715 | nonexistent 1716 | imperative 1717 | mocha 1718 | sauce 1719 | living 1720 | calmness 1721 | male 1722 | flap 1723 | maradona 1724 | arafat 1725 | -------------------------------------------------------------------------------- /src/data_block.cpp: -------------------------------------------------------------------------------- 1 | #include "data_block.h" 2 | 3 | size_t DataBlock::Size() 4 | { 5 | return m_sentences.size(); 6 | } 7 | 8 | void DataBlock::Add(int *head, int sentence_length, int64_t word_count, 
uint64_t next_random) 9 | { 10 | Sentence sentence(head, sentence_length, word_count, next_random); 11 | m_sentences.push_back(sentence); 12 | } 13 | 14 | void DataBlock::UpdateNextRandom() 15 | { 16 | for (int i = 0; i < m_sentences.size(); ++i) 17 | m_sentences[i].next_random *= (uint64_t)rand(); 18 | } 19 | 20 | void DataBlock::Get(int index, int* &head, int &sentence_length, int64_t &word_count, uint64_t &next_random) 21 | { 22 | if (index >= 0 && index < m_sentences.size()) 23 | { 24 | m_sentences[index].Get(head, sentence_length, word_count, next_random); 25 | } 26 | else 27 | { 28 | head = nullptr; 29 | sentence_length = 0; 30 | word_count = 0; 31 | next_random = 0; 32 | } 33 | } 34 | 35 | void DataBlock::ReleaseSentences() 36 | { 37 | for (int i = 0; i < m_sentences.size(); ++i) 38 | delete m_sentences[i].head; 39 | m_sentences.clear(); 40 | } 41 | 42 | void DataBlock::AddTable(int table_id) 43 | { 44 | m_tables.push_back(table_id); 45 | } 46 | 47 | std::vector & DataBlock::GetTables() 48 | { 49 | return m_tables; 50 | } 51 | 52 | void DataBlock::SetEpochId(const int epoch_id) 53 | { 54 | m_epoch_id = epoch_id; 55 | } 56 | 57 | int DataBlock::GetEpochId() 58 | { 59 | return m_epoch_id; 60 | } 61 | -------------------------------------------------------------------------------- /src/data_block.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | /*! 4 | * \file DataBlock.h 5 | * \brief Defines class DataBlock to store the necessary data for trainer and param_loader 6 | * \author 7 | * - v-fetia 8 | */ 9 | #include "util.h" 10 | #include 11 | #include "huffman_encoder.h" 12 | 13 | /*! 14 | * \brief The class DataBlock stores train for trainer and param_loader 15 | */ 16 | class DataBlock : public multiverso::DataBlockBase 17 | { 18 | public: 19 | /*! 20 | * \brief Get the number of sentences stored in DataBlock 21 | * \return the number of sentences 22 | */ 23 | size_t Size(); 24 | /*! 
25 | * \brief Add a new sentence to the DataBlock 26 | * \param sentence the starting address of the sentence 27 | * \param sentence_length the length of the sentence 28 | * \param word_count the number of words when getting the sentence from train-file 29 | * \param next_random the seed for getting random number 30 | */ 31 | void Add(int *sentence, int sentence_length, int64_t word_count, uint64_t next_random); 32 | /*! 33 | * \brief Get the information of the index-th sentence 34 | * \param index the id of the sentence 35 | * \param sentence the starting address of the sentence 36 | * \param sentence_length the number of words in the sentence 37 | * \param word_count the number of words when getting the sentence from train-file 38 | * \param next_random the seed for getting random number 39 | */ 40 | void Get(int index, int* &sentence, int &sentence_length, int64_t &word_count, uint64_t &next_random); 41 | 42 | 43 | void UpdateNextRandom(); 44 | 45 | void AddTable(int table_id); 46 | 47 | std::vector & GetTables(); 48 | 49 | void ReleaseSentences(); 50 | 51 | int GetEpochId(); 52 | 53 | void SetEpochId(const int epoch_id); 54 | 55 | private: 56 | struct Sentence 57 | { 58 | int* head; 59 | int length; 60 | int64_t word_count; 61 | uint64_t next_random; 62 | Sentence(int *head, int length, int64_t word_count, uint64_t next_random) 63 | :head(head), length(length), word_count(word_count), next_random(next_random){} 64 | void Get(int* &local_head, int &sentence_length, int64_t &local_word_count, uint64_t &local_next_random) 65 | { 66 | local_head = head; 67 | sentence_length = length; 68 | local_word_count = word_count; 69 | local_next_random = next_random; 70 | } 71 | }; 72 | 73 | std::vector m_tables; 74 | std::vector m_sentences; 75 | int m_epoch_id; 76 | }; 77 | -------------------------------------------------------------------------------- /src/dictionary.cpp: -------------------------------------------------------------------------------- 1 | #include 
"dictionary.h" 2 | 3 | Dictionary::Dictionary() 4 | { 5 | combine =0; 6 | Clear(); 7 | } 8 | 9 | Dictionary::Dictionary(int i) 10 | { 11 | combine = i; 12 | Clear(); 13 | } 14 | 15 | void Dictionary::Clear() 16 | { 17 | m_word_idx_map.clear(); 18 | m_word_info.clear(); 19 | m_word_whitelist.clear(); 20 | } 21 | 22 | void Dictionary::SetWhiteList(const std::vector& whitelist) 23 | { 24 | for (unsigned int i = 0; i < whitelist.size(); ++i) 25 | m_word_whitelist.insert(whitelist[i]); 26 | } 27 | 28 | void Dictionary::MergeInfrequentWords(int64_t threshold) 29 | { 30 | m_word_idx_map.clear(); 31 | std::vector tmp_info; 32 | tmp_info.clear(); 33 | int infreq_idx = -1; 34 | 35 | for (auto& word_info : m_word_info) 36 | { 37 | if (word_info.freq >= threshold || word_info.freq == 0 || m_word_whitelist.count(word_info.word)) 38 | { 39 | m_word_idx_map[word_info.word] = static_cast(tmp_info.size()); 40 | tmp_info.push_back(word_info); 41 | } 42 | else { 43 | if (infreq_idx < 0) 44 | { 45 | WordInfo infreq_word_info; 46 | infreq_word_info.word = "WE_ARE_THE_INFREQUENT_WORDS"; 47 | infreq_word_info.freq = 0; 48 | m_word_idx_map[infreq_word_info.word] = static_cast(tmp_info.size()); 49 | infreq_idx = static_cast(tmp_info.size()); 50 | tmp_info.push_back(infreq_word_info); 51 | } 52 | m_word_idx_map[word_info.word] = infreq_idx; 53 | tmp_info[infreq_idx].freq += word_info.freq; 54 | } 55 | } 56 | m_word_info = tmp_info; 57 | } 58 | 59 | void Dictionary::RemoveWordsLessThan(int64_t min_count) 60 | { 61 | m_word_idx_map.clear(); 62 | std::vector tmp_info; 63 | tmp_info.clear(); 64 | for (auto& info : m_word_info) 65 | { 66 | if (info.freq >= min_count || info.freq == 0 || m_word_whitelist.count(info.word)) 67 | { 68 | m_word_idx_map[info.word] = static_cast(tmp_info.size()); 69 | tmp_info.push_back(info); 70 | } 71 | } 72 | m_word_info = tmp_info; 73 | } 74 | 75 | void Dictionary::Insert(const char* word, int64_t cnt) 76 | { 77 | const auto& it = m_word_idx_map.find(word); 78 | if 
(it != m_word_idx_map.end()) 79 | m_word_info[it->second].freq += cnt; 80 | else 81 | { 82 | m_word_idx_map[word] = static_cast(m_word_info.size()); 83 | m_word_info.push_back(WordInfo(word, cnt)); 84 | } 85 | } 86 | 87 | void Dictionary::LoadFromFile(const char* filename) 88 | { 89 | FILE* fid = fopen(filename, "r"); 90 | 91 | if(fid) 92 | { 93 | char sz_label[MAX_WORD_SIZE]; 94 | 95 | while (fscanf(fid, "%s", sz_label, MAX_WORD_SIZE) != EOF) 96 | { 97 | int freq; 98 | fscanf(fid, "%d", &freq); 99 | Insert(sz_label, freq); 100 | } 101 | fclose(fid); 102 | } 103 | } 104 | 105 | void Dictionary::LoadTriLetterFromFile(const char* filename, unsigned int min_cnt, unsigned int letter_count) 106 | { 107 | FILE* fid = fopen(filename, "r"); 108 | if(fid) 109 | { 110 | char sz_label[MAX_WORD_SIZE]; 111 | while (fscanf(fid, "%s", sz_label, MAX_WORD_SIZE) != EOF) 112 | { 113 | int freq; 114 | fscanf(fid, "%d", &freq); 115 | if (static_cast(freq) < min_cnt) continue; 116 | 117 | // Construct Tri-letter From word 118 | size_t len = strlen(sz_label); 119 | if (len > MAX_WORD_SIZE) 120 | { 121 | printf("ignore super long term"); 122 | continue; 123 | } 124 | 125 | char tri_letters[MAX_WORD_SIZE + 2]; 126 | tri_letters[0] = '#'; 127 | int i = 0; 128 | for (i = 0; i < strlen(sz_label); i++) 129 | { 130 | tri_letters[i+1] = sz_label[i]; 131 | } 132 | 133 | tri_letters[i+1] = '#'; 134 | tri_letters[i+2] = 0; 135 | if (combine) Insert(sz_label,freq); 136 | 137 | if (strlen(tri_letters) <= letter_count) { 138 | Insert(tri_letters, freq); 139 | } else { 140 | for (i = 0; i <= strlen(tri_letters) - letter_count; ++i) 141 | { 142 | char tri_word[MAX_WORD_SIZE]; 143 | unsigned int j = 0; 144 | for(j = 0; j < letter_count; j++) 145 | { 146 | tri_word[j] = tri_letters[i+j]; 147 | } 148 | tri_word[j] = 0; 149 | Insert(tri_word, freq); 150 | } 151 | } 152 | } 153 | fclose(fid); 154 | } 155 | } 156 | 157 | 158 | int Dictionary::GetWordIdx(const char* word) 159 | { 160 | const auto& it = 
m_word_idx_map.find(word); 161 | if (it != m_word_idx_map.end()) 162 | return it->second; 163 | return -1; 164 | } 165 | 166 | int Dictionary::Size() 167 | { 168 | return static_cast(m_word_info.size()); 169 | } 170 | 171 | const WordInfo* Dictionary::GetWordInfo(const char* word) 172 | { 173 | const auto& it = m_word_idx_map.find(word); 174 | if (it != m_word_idx_map.end()) 175 | return GetWordInfo(it->second); 176 | return NULL; 177 | } 178 | 179 | const WordInfo* Dictionary::GetWordInfo(int word_idx) 180 | { 181 | if (word_idx >= 0 && word_idx < m_word_info.size()) 182 | return &m_word_info[word_idx]; 183 | return NULL; 184 | } 185 | 186 | void Dictionary::StartIteration() 187 | { 188 | m_word_iterator = m_word_info.begin(); 189 | } 190 | 191 | bool Dictionary::HasMore() 192 | { 193 | return m_word_iterator != m_word_info.end(); 194 | } 195 | 196 | const WordInfo* Dictionary::Next() 197 | { 198 | const WordInfo* entry = &(*m_word_iterator); 199 | ++m_word_iterator; 200 | return entry; 201 | } 202 | 203 | std::vector::iterator Dictionary::Begin() 204 | { 205 | return m_word_info.begin(); 206 | } 207 | std::vector::iterator Dictionary::End() 208 | { 209 | return m_word_info.end(); 210 | } 211 | -------------------------------------------------------------------------------- /src/dictionary.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include "util.h" 8 | 9 | const int MAX_WORD_SIZE = 901; 10 | 11 | struct WordInfo 12 | { 13 | std::string word; 14 | int64_t freq; 15 | WordInfo() 16 | { 17 | freq = 0; 18 | word.clear(); 19 | } 20 | WordInfo(const std::string& _word, int64_t _freq) 21 | { 22 | word = _word; 23 | freq = _freq; 24 | } 25 | }; 26 | 27 | class Dictionary 28 | { 29 | public: 30 | Dictionary(); 31 | Dictionary(int i); 32 | void Clear(); 33 | void SetWhiteList(const std::vector& whitelist); 34 | void RemoveWordsLessThan(int64_t min_count); 35 
| void MergeInfrequentWords(int64_t threshold); 36 | void Insert(const char* word, int64_t cnt = 1); 37 | void LoadFromFile(const char* filename); 38 | void LoadTriLetterFromFile(const char* filename, unsigned int min_cnt = 1, unsigned int letter_count = 3); 39 | int GetWordIdx(const char* word); 40 | const WordInfo* GetWordInfo(const char* word); 41 | const WordInfo* GetWordInfo(int word_idx); 42 | int Size(); 43 | void StartIteration(); 44 | bool HasMore(); 45 | const WordInfo* Next(); 46 | std::vector::iterator Begin(); 47 | std::vector::iterator End(); 48 | 49 | private: 50 | int combine; 51 | std::vector m_word_info; 52 | std::vector::iterator m_word_iterator; 53 | std::unordered_map m_word_idx_map; 54 | std::unordered_set m_word_whitelist; 55 | }; 56 | -------------------------------------------------------------------------------- /src/huffman_encoder.cpp: -------------------------------------------------------------------------------- 1 | #include "huffman_encoder.h" 2 | #include 3 | #include 4 | 5 | HuffmanEncoder::HuffmanEncoder() 6 | { 7 | m_dict = NULL; 8 | } 9 | 10 | void HuffmanEncoder::Save2File(const char* filename) 11 | { 12 | FILE* fid = fopen(filename, "w"); 13 | if(fid) 14 | { 15 | fprintf(fid, "%lld\n", m_hufflabel_info.size()); 16 | 17 | for (unsigned i = 0; i < m_hufflabel_info.size(); ++i) 18 | { 19 | const auto& info = m_hufflabel_info[i]; 20 | const auto& word = m_dict->GetWordInfo(i); 21 | fprintf(fid, "%s %d", word->word.c_str(), info.codelen); 22 | 23 | for (int j = 0; j < info.codelen; ++j) 24 | fprintf(fid, " %d", info.code[j]); 25 | 26 | for (int j = 0; j < info.codelen; ++j) 27 | fprintf(fid, " %d", info.point[j]); 28 | 29 | fprintf(fid, "\n"); 30 | } 31 | 32 | fclose(fid); 33 | } 34 | else 35 | { 36 | printf("file open failed %s", filename); 37 | } 38 | } 39 | 40 | void HuffmanEncoder::RecoverFromFile(const char* filename) 41 | { 42 | m_dict = new Dictionary(); 43 | FILE* fid = fopen(filename, "r"); 44 | if(fid) 45 | { 46 | int 
vocab_size; 47 | fscanf(fid, "%lld", &vocab_size); 48 | m_hufflabel_info.reserve(vocab_size); 49 | m_hufflabel_info.clear(); 50 | 51 | int tmp; 52 | char sz_label[MAX_WORD_SIZE]; 53 | for (int i = 0; i < vocab_size; ++i) 54 | { 55 | HuffLabelInfo info; 56 | 57 | fscanf(fid, "%s", sz_label, MAX_WORD_SIZE); 58 | m_dict->Insert(sz_label); 59 | 60 | fscanf(fid, "%d", &info.codelen); 61 | 62 | info.code.clear(); 63 | info.point.clear(); 64 | 65 | for (int j = 0; j < info.codelen; ++j) 66 | { 67 | fscanf(fid, "%d", &tmp); 68 | info.code.push_back(tmp); 69 | } 70 | for (int j = 0; j < info.codelen; ++j) 71 | { 72 | fscanf(fid, "%d", &tmp); 73 | info.point.push_back(tmp); 74 | } 75 | 76 | m_hufflabel_info.push_back(info); 77 | } 78 | fclose(fid); 79 | } 80 | else 81 | { 82 | printf("file open failed %s", filename); 83 | } 84 | } 85 | 86 | bool compare(const std::pair& x, const std::pair& y) 87 | { 88 | if (x.second == 0) return true; 89 | if (y.second == 0) return false; 90 | return (x.second > y.second); 91 | } 92 | 93 | void HuffmanEncoder::BuildHuffmanTreeFromDict() 94 | { 95 | std::vector > ordered_words; 96 | ordered_words.reserve(m_dict->Size()); 97 | ordered_words.clear(); 98 | for (unsigned i = 0; i < static_cast(m_dict->Size()); ++i) 99 | ordered_words.push_back(std::pair(i, m_dict->GetWordInfo(i)->freq)); 100 | std::sort(ordered_words.begin(), ordered_words.end(), compare); 101 | 102 | unsigned vocab_size = (unsigned) ordered_words.size(); 103 | int64_t *count = new int64_t[vocab_size * 2 + 1]; //frequency 104 | unsigned *binary = new unsigned[vocab_size * 2 + 1]; //huffman code relative to parent node [1,0] of each node 105 | memset(binary, 0, sizeof(unsigned)* (vocab_size * 2 + 1)); 106 | 107 | unsigned *parent_node = new unsigned[vocab_size * 2 + 1]; // 108 | memset(parent_node, 0, sizeof(unsigned)* (vocab_size * 2 + 1)); 109 | unsigned code[MAX_CODE_LENGTH], point[MAX_CODE_LENGTH]; 110 | 111 | for (unsigned i = 0; i < vocab_size; ++i) 112 | count[i] =
ordered_words[i].second; 113 | for (unsigned i = vocab_size; i < vocab_size * 2; i++) 114 | count[i] = static_cast(1e15); 115 | int pos1 = vocab_size - 1; 116 | int pos2 = vocab_size; 117 | int min1i, min2i; 118 | for (unsigned i = 0; i < vocab_size - 1; i++) 119 | { 120 | // First, find two smallest nodes 'min1, min2' 121 | assert(pos2 < vocab_size * 2 - 1); 122 | //find the smallest node 123 | if (pos1 >= 0) 124 | { 125 | if (count[pos1] < count[pos2]) 126 | { 127 | min1i = pos1; 128 | pos1--; 129 | } 130 | else 131 | { 132 | min1i = pos2; 133 | pos2++; 134 | } 135 | } 136 | else 137 | { 138 | min1i = pos2; 139 | pos2++; 140 | } 141 | 142 | //find the second smallest node 143 | if (pos1 >= 0) 144 | { 145 | if (count[pos1] < count[pos2]) 146 | { 147 | min2i = pos1; 148 | pos1--; 149 | } 150 | else 151 | { 152 | min2i = pos2; 153 | pos2++; 154 | } 155 | } 156 | else 157 | { 158 | min2i = pos2; 159 | pos2++; 160 | } 161 | 162 | count[vocab_size + i] = count[min1i] + count[min2i]; 163 | 164 | assert(min1i >= 0 && min1i < vocab_size * 2 - 1 && min2i >= 0 && min2i < vocab_size * 2 - 1); 165 | parent_node[min1i] = vocab_size + i; 166 | parent_node[min2i] = vocab_size + i; 167 | binary[min2i] = 1; 168 | } 169 | assert(pos1 < 0); 170 | 171 | //generate the huffman code for each leaf node 172 | m_hufflabel_info.clear(); 173 | for (unsigned a = 0; a < vocab_size; ++a) 174 | m_hufflabel_info.push_back(HuffLabelInfo()); 175 | for (unsigned a = 0; a < vocab_size; a++) 176 | { 177 | unsigned b = a, i = 0; 178 | while (1) 179 | { 180 | assert(i < MAX_CODE_LENGTH); 181 | code[i] = binary[b]; 182 | point[i] = b; 183 | i++; 184 | b = parent_node[b]; 185 | if (b == vocab_size * 2 - 2) break; 186 | } 187 | unsigned cur_word = ordered_words[a].first; 188 | 189 | m_hufflabel_info[cur_word].codelen = i; 190 | m_hufflabel_info[cur_word].point.push_back(vocab_size - 2); 191 | 192 | for (b = 0; b < i; b++) 193 | { 194 | m_hufflabel_info[cur_word].code.push_back(code[i - b - 1]); 195 | if
(b) 196 | m_hufflabel_info[cur_word].point.push_back(point[i - b] - vocab_size); 197 | } 198 | } 199 | 200 | delete[] count; 201 | count = nullptr; 202 | delete[] binary; 203 | binary = nullptr; 204 | delete[] parent_node; 205 | parent_node = nullptr; 206 | } 207 | 208 | void HuffmanEncoder::BuildFromTermFrequency(const char* filename) 209 | { 210 | FILE* fid = fopen(filename, "r"); 211 | if(fid) 212 | { 213 | char sz_label[MAX_WORD_SIZE]; 214 | m_dict = new Dictionary(); 215 | 216 | while (fscanf(fid, "%s", sz_label, MAX_WORD_SIZE) != EOF) 217 | { 218 | HuffLabelInfo info; 219 | int freq; 220 | fscanf(fid, "%d", &freq); 221 | m_dict->Insert(sz_label, freq); 222 | } 223 | fclose(fid); 224 | 225 | BuildHuffmanTreeFromDict(); 226 | } 227 | else 228 | { 229 | printf("file open failed %s", filename); 230 | } 231 | } 232 | 233 | void HuffmanEncoder::BuildFromTermFrequency(Dictionary* dict) 234 | { 235 | m_dict = dict; 236 | BuildHuffmanTreeFromDict(); 237 | } 238 | 239 | int HuffmanEncoder::GetLabelSize() 240 | { 241 | return m_dict->Size(); 242 | } 243 | 244 | int HuffmanEncoder::GetLabelIdx(const char* label) 245 | { 246 | return m_dict->GetWordIdx(label); 247 | } 248 | 249 | HuffLabelInfo* HuffmanEncoder::GetLabelInfo(char* label) 250 | { 251 | int idx = GetLabelIdx(label); 252 | if (idx == -1) 253 | return NULL; 254 | return GetLabelInfo(idx); 255 | } 256 | 257 | HuffLabelInfo* HuffmanEncoder::GetLabelInfo(int label_idx) 258 | { 259 | if (label_idx == -1) return NULL; 260 | return &m_hufflabel_info[label_idx]; 261 | } 262 | 263 | Dictionary* HuffmanEncoder::GetDict() 264 | { 265 | return m_dict; 266 | } 267 | -------------------------------------------------------------------------------- /src/huffman_encoder.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "dictionary.h" 4 | 5 | const int MAX_CODE_LENGTH = 100; 6 | 7 | struct HuffLabelInfo 8 | { 9 | std::vector point; //internal node ids in the 
code path 10 | std::vector code; //huffman code 11 | int codelen; 12 | HuffLabelInfo() 13 | { 14 | codelen = 0; 15 | point.clear(); 16 | code.clear(); 17 | } 18 | }; 19 | 20 | class HuffmanEncoder 21 | { 22 | public: 23 | HuffmanEncoder(); 24 | 25 | void Save2File(const char* filename); 26 | void RecoverFromFile(const char* filename); 27 | void BuildFromTermFrequency(const char* filename); 28 | void BuildFromTermFrequency(Dictionary* dict); 29 | 30 | int GetLabelSize(); 31 | int GetLabelIdx(const char* label); 32 | HuffLabelInfo* GetLabelInfo(char* label); 33 | HuffLabelInfo* GetLabelInfo(int label_idx); 34 | Dictionary* GetDict(); 35 | 36 | private: 37 | void BuildHuffmanTreeFromDict(); 38 | std::vector m_hufflabel_info; 39 | Dictionary* m_dict; 40 | }; 41 | -------------------------------------------------------------------------------- /src/main.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | #include "dictionary.h" 13 | #include "huffman_encoder.h" 14 | #include "util.h" 15 | #include "reader.h" 16 | #include "multiverso_skipgram_mixture.h" 17 | #include "param_loader.h" 18 | #include "trainer.h" 19 | #include "skipgram_mixture_neural_network.h" 20 | 21 | bool ReadWord(char *word, FILE *fin) 22 | { 23 | int idx = 0; 24 | char ch; 25 | while (!feof(fin)) 26 | { 27 | ch = fgetc(fin); 28 | if (ch == 13) continue; 29 | if ((ch == ' ') || (ch == '\t') || (ch == '\n')) 30 | { 31 | if (idx > 0) 32 | { 33 | if (ch == '\n') 34 | ungetc(ch, fin); 35 | break; 36 | } 37 | 38 | if (ch == '\n') 39 | { 40 | strcpy(word, (char *)""); 41 | return true; 42 | } 43 | else 44 | { 45 | continue; 46 | } 47 | } 48 | 49 | word[idx++] = ch; 50 | if (idx >= MAX_STRING - 1) idx--; // Truncate too long words 51 | } 52 | 53 | word[idx] = 0; 54 | return idx > 0; 55 | } 56 | 57 | // Read the vocabulary file; create 
the dictionary and huffman_encoder according opt 58 | int64_t LoadVocab(Option *opt, Dictionary *dictionary, HuffmanEncoder *huffman_encoder) 59 | { 60 | int64_t total_words = 0; 61 | char word[MAX_STRING]; 62 | FILE* fid = nullptr; 63 | printf("vocab_file %s\n", opt->read_vocab_file); 64 | if (opt->read_vocab_file != nullptr && strlen(opt->read_vocab_file) > 0) 65 | { 66 | printf("Begin to load vocabulary file [%s] ...\n", opt->read_vocab_file); 67 | fid = fopen(opt->read_vocab_file, "r"); 68 | int word_freq; 69 | while (fscanf(fid, "%s %d", word, &word_freq) != EOF) 70 | { 71 | dictionary->Insert(word, word_freq); 72 | } 73 | } 74 | 75 | dictionary->RemoveWordsLessThan(opt->min_count); 76 | printf("Dictionary size: %d\n", dictionary->Size()); 77 | total_words = 0; 78 | for (int i = 0; i < dictionary->Size(); ++i) 79 | total_words += dictionary->GetWordInfo(i)->freq; 80 | printf("Words in corpus %I64d\n", total_words); 81 | huffman_encoder->BuildFromTermFrequency(dictionary); 82 | fclose(fid); 83 | 84 | return total_words; 85 | } 86 | 87 | 88 | int main(int argc, char *argv[]) 89 | { 90 | srand(static_cast(time(NULL))); 91 | Option *option = new Option(); 92 | Dictionary *dictionary = new Dictionary(); 93 | HuffmanEncoder *huffman_encoder = new HuffmanEncoder(); 94 | 95 | // Parse argument and store them in option 96 | option->ParseArgs(argc, argv); 97 | option->PrintArgs(); 98 | if (!option->CheckArgs()) 99 | { 100 | printf("Fatal error in arguments\n"); 101 | return -1; 102 | } 103 | // Read the vocabulary file; create the dictionary and huffman_encoder according opt 104 | printf("Loading vocabulary ...\n"); 105 | option->total_words = LoadVocab(option, dictionary, huffman_encoder); 106 | printf("Loaded vocabulary\n"); 107 | fflush(stdout); 108 | 109 | Reader *reader = new Reader(dictionary, option); 110 | 111 | MultiversoSkipGramMixture *multiverso_word2vector = new MultiversoSkipGramMixture(option, dictionary, huffman_encoder, reader); 112 | 113 | 
fflush(stdout); 114 | 115 | multiverso_word2vector->Train(argc, argv); 116 | 117 | delete multiverso_word2vector; 118 | delete reader; 119 | delete huffman_encoder; 120 | delete dictionary; 121 | delete option; 122 | 123 | return 0; 124 | } 125 | -------------------------------------------------------------------------------- /src/multiverso_skipgram_mixture.cpp: -------------------------------------------------------------------------------- 1 | #include "multiverso_skipgram_mixture.h" 2 | #include 3 | 4 | MultiversoSkipGramMixture::MultiversoSkipGramMixture(Option *option, Dictionary *dictionary, HuffmanEncoder *huffman_encoder, Reader *reader) 5 | { 6 | m_option = option; 7 | m_dictionary = dictionary; 8 | m_huffman_encoder = huffman_encoder; 9 | m_reader = reader; 10 | 11 | InitSenseCntInfo(); 12 | } 13 | 14 | void MultiversoSkipGramMixture::InitSenseCntInfo() 15 | { 16 | //First, determine #senses for words according to configuration parameters: top_N and top_ratio 17 | int threshold = (m_option->top_N ? 
std::min(m_option->top_N, m_dictionary->Size()) : m_dictionary->Size()); 18 | threshold = static_cast(std::min(static_cast(m_option->top_ratio) * m_dictionary->Size(), static_cast(threshold))); 19 | 20 | m_word_sense_info.total_senses_cnt = threshold * m_option->sense_num_multi + (m_dictionary->Size() - threshold); 21 | 22 | std::pair* wordlist = new std::pair[m_dictionary->Size() + 10]; 23 | for (int i = 0; i < m_dictionary->Size(); ++i) 24 | wordlist[i] = std::pair(i, m_dictionary->GetWordInfo(i)->freq); 25 | 26 | std::sort(wordlist, wordlist + m_dictionary->Size(), [](std::pair a, std::pair b) { 27 | return a.second > b.second; 28 | }); 29 | 30 | m_word_sense_info.word_sense_cnts_info.resize(m_dictionary->Size()); 31 | 32 | for (int i = 0; i < threshold; ++i) 33 | m_word_sense_info.word_sense_cnts_info[wordlist[i].first] = m_option->sense_num_multi; 34 | for (int i = threshold; i < m_dictionary->Size(); ++i) 35 | m_word_sense_info.word_sense_cnts_info[wordlist[i].first] = 1; 36 | 37 | //Then, read words #sense info from the sense file 38 | if (m_option->sense_file) 39 | { 40 | FILE* fid = fopen(m_option->sense_file, "r"); 41 | char word[1000]; 42 | while (fscanf(fid, "%s", word) != EOF) 43 | { 44 | int word_idx = m_dictionary->GetWordIdx(word); 45 | if (word_idx == -1) 46 | continue; 47 | if (m_word_sense_info.word_sense_cnts_info[word_idx] == 1) 48 | { 49 | m_word_sense_info.word_sense_cnts_info[word_idx] = m_option->sense_num_multi; 50 | m_word_sense_info.total_senses_cnt += (m_option->sense_num_multi - 1); 51 | } 52 | } 53 | fclose(fid); 54 | } 55 | 56 | //At last, point pointers to the right position 57 | m_word_sense_info.p_input_embedding.resize(m_dictionary->Size()); 58 | int cnt = 0; 59 | m_word_sense_info.multi_senses_words_cnt = 0; 60 | 61 | for (int i = 0; i < m_dictionary->Size(); ++i) 62 | { 63 | m_word_sense_info.p_input_embedding[i] = cnt; 64 | if (m_word_sense_info.word_sense_cnts_info[i] > 1) 65 | m_word_sense_info.p_wordidx2sense_idx[i] = 
// ---- src/multiverso_skipgram_mixture.cpp (continued) ----
// NOTE(review): this chunk comes from a text dump in which template argument
// lists (e.g. SkipGramMixtureNeuralNetwork<...>, std::vector<...>) and
// static_cast target types were stripped by the extraction. The code below is
// reproduced as-is; confirm the missing <...> parts against the repository.

// (tail of InitSenseCntInfo: accumulate total sense counts per word)
m_word_sense_info.multi_senses_words_cnt++;
cnt += m_word_sense_info.word_sense_cnts_info[i];
}

printf("Total senses:%d, total multiple mearning words:%d\n", m_word_sense_info.total_senses_cnt, m_word_sense_info.multi_senses_words_cnt);

}

// Entry point of distributed training.
// Builds the double-buffered neural networks, one Trainer per local thread and
// a ParameterLoader, boots the Multiverso environment, registers the parameter
// tables, runs the training loop, then tears everything down.
// argc/argv are forwarded to multiverso::Multiverso::Init (MPI-style init).
void MultiversoSkipGramMixture::Train(int argc, char *argv[])
{
    // Barrier shared by all local trainer threads.
    multiverso::Barrier* barrier = new multiverso::Barrier(m_option->thread_cnt);

    printf("Inited barrier\n");

    // Two networks act as a double buffer: the parameter loader and the
    // trainers alternate between index (count % 2) — see param_loader.cpp.
    SkipGramMixtureNeuralNetwork* word2vector_neural_networks[2] = { new SkipGramMixtureNeuralNetwork(m_option, m_huffman_encoder, &m_word_sense_info, m_dictionary, m_dictionary->Size()),
        new SkipGramMixtureNeuralNetwork(m_option, m_huffman_encoder, &m_word_sense_info, m_dictionary, m_dictionary->Size()) };

    // Create Multiverso ParameterLoader and Trainers,
    // start Multiverso environment
    printf("Initializing Multiverso ...\n");

    fflush(stdout);
    std::vector trainers;
    for (int i = 0; i < m_option->thread_cnt; ++i)
    {
        trainers.push_back(new Trainer(i, m_option, (void**)word2vector_neural_networks, barrier, m_dictionary, &m_word_sense_info, m_huffman_encoder));
    }

    ParameterLoader *parameter_loader = new ParameterLoader(m_option, (void**)word2vector_neural_networks, &m_word_sense_info);
    // Copy user options into the Multiverso runtime configuration.
    multiverso::Config config;
    config.max_delay = m_option->max_delay;
    config.num_servers = m_option->num_servers;
    config.num_aggregator = m_option->num_aggregator;
    config.lock_option = static_cast(m_option->lock_option); // NOTE(review): cast target type stripped in dump
    config.num_lock = m_option->num_lock;
    config.is_pipeline = m_option->pipline;

    fflush(stdout);

    multiverso::Multiverso::Init(trainers, parameter_loader, config, &argc, &argv);

    fflush(stdout);
    multiverso::Log::ResetLogFile("log.txt");
    m_process_id = multiverso::Multiverso::ProcessRank();
    PrepareMultiversoParameterTables(m_option, m_dictionary);

    printf("Start to train ...\n");
    TrainNeuralNetwork();
    printf("Rank %d Finish training\n", m_process_id);

    delete barrier;
    delete word2vector_neural_networks[0];
    delete word2vector_neural_networks[1];
    for (auto &trainer : trainers)
    {
        delete trainer;
    }
    delete parameter_loader;
    multiverso::Multiverso::Close();
}

// Registers one logical parameter table in all three Multiverso tiers
// (server, local cache, aggregator) under the same id/shape/type.
void MultiversoSkipGramMixture::AddMultiversoParameterTable(multiverso::integer_t table_id, multiverso::integer_t rows,
    multiverso::integer_t cols, multiverso::Type type, multiverso::Format default_format)
{
    multiverso::Multiverso::AddServerTable(table_id, rows, cols, type, default_format);
    multiverso::Multiverso::AddCacheTable(table_id, rows, cols, type, default_format, 0);
    multiverso::Multiverso::AddAggregatorTable(table_id, rows, cols, type, default_format, 0);
}

// Creates the four parameter tables (input embeddings, output/Huffman-node
// embeddings, global word count, sense priors) and seeds their initial values
// on the server. Must run between BeginConfig()/EndConfig().
void MultiversoSkipGramMixture::PrepareMultiversoParameterTables(Option *opt, Dictionary *dictionary)
{
    multiverso::Multiverso::BeginConfig();
    int proc_count = multiverso::Multiverso::TotalProcessCount();

    // create tables
    AddMultiversoParameterTable(kInputEmbeddingTableId, m_word_sense_info.total_senses_cnt, opt->embeding_size, multiverso::Type::Float, multiverso::Format::Dense);
    AddMultiversoParameterTable(kEmbeddingOutputTableId, dictionary->Size(), opt->embeding_size, multiverso::Type::Float, multiverso::Format::Dense);
    AddMultiversoParameterTable(kWordCountActualTableId, 1, 1, multiverso::Type::LongLong, multiverso::Format::Dense);
    AddMultiversoParameterTable(kWordSensePriorTableId, m_word_sense_info.multi_senses_words_cnt, m_option->sense_num_multi, multiverso::Type::Float, multiverso::Format::Dense);

    // Initialize input embeddings with small uniform noise; divided by
    // proc_count because every process adds its share to the same server cell.
    for (int row = 0; row < m_word_sense_info.total_senses_cnt; ++row)
    {
        for (int col = 0; col < opt->embeding_size; ++col)
        {
            multiverso::Multiverso::AddToServer(kInputEmbeddingTableId, row, col, static_cast((static_cast(rand()) / RAND_MAX - 0.5) / opt->embeding_size / proc_count));
        }
    }

    // Initialize sense priors uniformly; stored either as the multinomial
    // itself or as its log, depending on store_multinomial.
    for (int row = 0; row < m_word_sense_info.multi_senses_words_cnt; ++row)
    {
        for (int col = 0; col < opt->sense_num_multi; ++col)
        {
            multiverso::Multiverso::AddToServer(kWordSensePriorTableId, row, col,
                static_cast(m_option->store_multinomial ? 1.0 / m_option->sense_num_multi : log(1.0 / m_option->sense_num_multi)));
        }
    }
    multiverso::Multiverso::EndConfig();
}

// Load the sentences from the train file and store them in data_block.
// NOTE(review): the `size` parameter is unused; the loop is bounded by
// m_option->data_block_size instead — confirm this is intentional.
void MultiversoSkipGramMixture::LoadData(DataBlock *data_block, Reader *reader, int64_t size)
{
    data_block->ReleaseSentences();
    while (data_block->Size() < m_option->data_block_size)
    {
        int64_t word_count = 0;
        int *sentence = new (std::nothrow)int[MAX_SENTENCE_LENGTH + 2];
        assert(sentence != nullptr);
        int sentence_length = reader->GetSentence(sentence, word_count);
        if (sentence_length > 0)
        {
            // Ownership of `sentence` passes to the data block; the last
            // argument seeds per-sentence randomness.
            data_block->Add(sentence, sentence_length, word_count, (uint64_t)rand() * 10000 + (uint64_t)rand());
        }
        else
        {
            // Reader reached end of file.
            delete[] sentence;
            return;
        }
    }
}

// Hands a data block to Multiverso and remembers it locally so it can be
// freed once processed. Blocks (with a 200 ms poll) while too many blocks
// are in flight, to bound memory usage.
void MultiversoSkipGramMixture::PushDataBlock(
    std::queue &datablock_queue, DataBlock* data_block)
{

    multiverso::Multiverso::PushDataBlock(data_block);

    datablock_queue.push(data_block);
    // Limit the max number of outstanding datablocks to avoid running out of memory.
    while (static_cast(datablock_queue.size()) > m_option->max_preload_blocks_cnt)
    {
        std::chrono::milliseconds dura(200);
        std::this_thread::sleep_for(dura);

        RemoveDoneDataBlock(datablock_queue);
    }
}

// Remove the datablocks which have been dealt with by
// the parameter loader and the trainers.
// Pops finished blocks from the front of the FIFO and frees them; stops at
// the first block that is still in flight (completion is in push order).
void MultiversoSkipGramMixture::RemoveDoneDataBlock(std::queue &datablock_queue)
{
    while (datablock_queue.empty() == false
        && datablock_queue.front()->IsDone())
    {
        DataBlock *p_data_block = datablock_queue.front();
        datablock_queue.pop();
        delete p_data_block;
    }
}

// Main producer loop: for each epoch, stream the training file into data
// blocks and push them into Multiverso; after each epoch push one special
// block that only requests the model tables, so the trainers dump the model.
// NOTE(review): template/type arguments (std::queue<...>) were stripped by
// the text extraction; code reproduced as-is.
void MultiversoSkipGramMixture::TrainNeuralNetwork()
{
    std::queuedatablock_queue;
    int data_block_count = 0;

    multiverso::Multiverso::BeginTrain();

    for (int curr_epoch = 0; curr_epoch < m_option->epoch; ++curr_epoch)
    {
        m_reader->Open(m_option->train_file);
        while (1)
        {
            ++data_block_count;
            DataBlock *data_block = new (std::nothrow)DataBlock();
            assert(data_block != nullptr);
            clock_t start = clock();
            LoadData(data_block, m_reader, m_option->data_block_size);
            if (data_block->Size() <= 0)
            {
                // Empty block means EOF: end of this epoch.
                delete data_block;
                break;
            }
            multiverso::Log::Info("Rank%d Load%d^thDataBlockTime:%lfs\n", m_process_id, data_block_count,
                (clock() - start) / (double)CLOCKS_PER_SEC);
            multiverso::Multiverso::BeginClock();
            PushDataBlock(datablock_queue, data_block);
            multiverso::Multiverso::EndClock();
        }

        m_reader->Close();

        multiverso::Multiverso::BeginClock();

        DataBlock *output_data_block = new DataBlock(); // special data_block used to dump the model files
        output_data_block->AddTable(kInputEmbeddingTableId);
        output_data_block->AddTable(kEmbeddingOutputTableId);
        output_data_block->AddTable(kWordSensePriorTableId);
        output_data_block->SetEpochId(curr_epoch);

        ++data_block_count;
        multiverso::Multiverso::PushDataBlock(output_data_block);
        multiverso::Multiverso::EndClock();
    }

    multiverso::Log::Info("Rank %d pushed %d blocks\n", multiverso::Multiverso::ProcessRank(), data_block_count);

    multiverso::Multiverso::EndTrain();

    // After EndTrain all datablocks are guaranteed done,
    // so we can free everything still in the queue.
    RemoveDoneDataBlock(datablock_queue);
}

// -------------------------------------------------------
// /src/multiverso_skipgram_mixture.h:
// -------------------------------------------------------
#pragma once

// NOTE(review): the angle-bracket header names of the following includes were
// stripped by the text extraction (likely <vector>, <queue>, <thread>, ...).
#include
#include
#include
#include
#include
#include
#include
#include

#include "util.h"
#include "huffman_encoder.h"
#include "data_block.h"
#include "param_loader.h"
#include "trainer.h"
#include "reader.h"

// Orchestrates distributed multi-sense skip-gram training on Multiverso:
// owns the reader/dictionary/Huffman encoder and drives the epoch loop.
class MultiversoSkipGramMixture
{
public:
    MultiversoSkipGramMixture(Option *option, Dictionary *dictionary, HuffmanEncoder *huffman_encoder, Reader *reader);

    void Train(int argc, char *argv[]);

private:
    int m_process_id;            // Multiverso process rank, set in Train()
    Option* m_option;
    Dictionary* m_dictionary;
    HuffmanEncoder* m_huffman_encoder;
    Reader* m_reader;

    WordSenseInfo m_word_sense_info;

    /*!
     * \brief Complete the train task with multiverso
     */
    void TrainNeuralNetwork();


    /*!
     * \brief Create a new table in the multiverso
     */
    void AddMultiversoParameterTable(multiverso::integer_t table_id, multiverso::integer_t rows,
        multiverso::integer_t cols, multiverso::Type type, multiverso::Format default_format);

    /*!
     * \brief Prepare parameter table in the multiverso
     */
    void PrepareMultiversoParameterTables(Option *opt, Dictionary *dictionary);


    /*!
     * \brief Load data from train_file to datablock
     * \param datablock the datablock which needs to be assigned
     * \param reader some useful function for calling
     * \param size datablock limit byte size
     */
    void LoadData(DataBlock *data_block, Reader *reader, int64_t size);

    /*!
 * \brief Push the datablock into the multiverso and datablock_queue
 */
void PushDataBlock(std::queue &datablock_queue, DataBlock* data_block);

/*!
 * \brief Remove datablock which is finished by multiverso thread
 * \param datablock_queue store the pushed datablocks
 */
void RemoveDoneDataBlock(std::queue &datablock_queue);

/*!
 * \brief Init the sense count info for all words
 */
void InitSenseCntInfo();
};

// -------------------------------------------------------
// /src/multiverso_tablesid.h:
// -------------------------------------------------------
#pragma once

#include
/*!
 * \brief Defines the index of parameter tables.
 */
const multiverso::integer_t kInputEmbeddingTableId = 0;  // Input embedding vector table
const multiverso::integer_t kEmbeddingOutputTableId = 1; // Huffman tree node embedding vector table
const multiverso::integer_t kWordCountActualTableId = 2; // Word count table
const multiverso::integer_t kWordSensePriorTableId = 3;  // Sense priors table

// -------------------------------------------------------
// /src/param_loader.cpp:
// -------------------------------------------------------
#include "param_loader.h"

// NOTE(review): template parameter lists and other <...> arguments in this
// file were stripped by the text extraction; code reproduced as-is.

template
ParameterLoader::ParameterLoader(Option *option, void** word2vector_neural_networks, WordSenseInfo* word_sense_info)
{
    m_option = option;
    m_parse_and_request_count = 0;
    m_sgmixture_neural_networks = word2vector_neural_networks;
    // NOTE(review): fopen result is used unchecked below; if the log file
    // cannot be created, fprintf(m_log_file, ...) is undefined — consider a guard.
    m_log_file = fopen("parameter_loader.log", "w");
    m_words_sense_info = word_sense_info;
}

// Inspects one data block and pre-requests from the Multiverso server every
// row the trainers will need: input embeddings for each sense of each input
// word, output (Huffman node) embeddings, the global word count, and the
// sense-prior rows of polysemous words.
template
void ParameterLoader::ParseAndRequest(multiverso::DataBlockBase *data_block)
{
    if (m_parse_and_request_count == 0)
    {
        m_start_time = clock();
    }

    // Log elapsed time at entry and exit (see matching fprintf at the end).
    fprintf(m_log_file, "%lf\n", (clock() - m_start_time) / (double)CLOCKS_PER_SEC);
    multiverso::Log::Info("Rank %d ParameterLoader begin %d\n", multiverso::Multiverso::ProcessRank(), m_parse_and_request_count);
    DataBlock *data = reinterpret_cast(data_block);

    // Pick the network of the double buffer that corresponds to this block.
    SkipGramMixtureNeuralNetwork* sg_mixture_neural_network = reinterpret_cast*>(m_sgmixture_neural_networks[m_parse_and_request_count % 2]);
    ++m_parse_and_request_count;
    data->UpdateNextRandom();
    sg_mixture_neural_network->PrepareParmeter(data);

    std::vector& input_layer_nodes = sg_mixture_neural_network->GetInputLayerNodes();
    std::vector& output_layer_nodes = sg_mixture_neural_network->GetOutputLayerNodes();
    // status is a simple 0 -> 1 -> 2 handshake with the trainer threads.
    assert(sg_mixture_neural_network->status == 0);
    sg_mixture_neural_network->status = 1;

    // Request one embedding row per sense of each input word.
    for (int i = 0; i < input_layer_nodes.size(); ++i)
    {
        int word_id = input_layer_nodes[i];
        for (int j = 0; j < m_words_sense_info->word_sense_cnts_info[word_id]; ++j)
            RequestRow(kInputEmbeddingTableId, m_words_sense_info->p_input_embedding[word_id] + j);
    }

    for (int i = 0; i < output_layer_nodes.size(); ++i)
        RequestRow(kEmbeddingOutputTableId, output_layer_nodes[i]);

    RequestRow(kWordCountActualTableId, 0);

    // Sense priors only exist for words with more than one sense.
    for (int i = 0; i < input_layer_nodes.size(); ++i)
    {
        int word_id = input_layer_nodes[i];
        if (m_words_sense_info->word_sense_cnts_info[word_id] > 1)
            RequestRow(kWordSensePriorTableId, m_words_sense_info->p_wordidx2sense_idx[word_id]);
    }

    // Whole-table requests (used by the per-epoch model-dump block).
    std::vector & tables = data->GetTables();
    for (int i = 0; i < tables.size(); ++i)
        RequestTable(tables[i]);

    multiverso::Log::Info("Rank %d ParameterLoader finish %d\n", multiverso::Multiverso::ProcessRank(), m_parse_and_request_count - 1);
    fprintf(m_log_file, "%lf\n", (clock() - m_start_time) / (double)CLOCKS_PER_SEC);
    assert(sg_mixture_neural_network->status == 1);
    sg_mixture_neural_network->status = 2;
}

// Explicit instantiations (element types stripped by the extraction;
// presumably float and double).
template class ParameterLoader;
template class ParameterLoader;

// -------------------------------------------------------
// /src/param_loader.h:
// -------------------------------------------------------
#pragma once

#include
#include "data_block.h"
#include "multiverso_tablesid.h"
#include "util.h"
#include "huffman_encoder.h"
#include "skipgram_mixture_neural_network.h"


/*!
 * \brief The class ParameterLoader preloads the parameters from multiverso server
 */
template
class ParameterLoader : public multiverso::ParameterLoaderBase
{
public:
    ParameterLoader(Option *opt, void ** word2vector_neural_networks, WordSenseInfo* word_sense_info);
    /*!
     * \brief Request the parameters from multiverso server according to data_block
     * \param data_block stores the information of sentences
     */
    void ParseAndRequest(multiverso::DataBlockBase* data_block) override;

private:
    int m_parse_and_request_count;      // number of blocks processed; parity selects the buffer
    Option* m_option;
    clock_t m_start_time;               // set on first ParseAndRequest call
    WordSenseInfo* m_words_sense_info;
    void ** m_sgmixture_neural_networks; // the two double-buffered networks
    FILE* m_log_file;                    // timing log, opened in the constructor
};

// -------------------------------------------------------
// /src/reader.cpp:
// -------------------------------------------------------
#include "reader.h"

// Builds the reader and, when stop-word filtering is enabled, loads the
// stop-word file and subtracts the frequency of known stop words from the
// total word count used for progress/learning-rate computation.
Reader::Reader(Dictionary *dictionary, Option *option)
{
    m_dictionary = dictionary;
    m_option = option;

    m_stopwords_table.clear();
    if (m_option->stopwords)
    {
        // NOTE(review): fopen result is not checked; a missing stop-word file
        // would make ReadWord operate on a null FILE* — consider a guard.
        FILE* fid = fopen(m_option->sw_file, "r");
        while (ReadWord(m_word, fid))
        {
            m_stopwords_table.insert(m_word);
            if (m_dictionary->GetWordIdx(m_word) != -1)
                m_option->total_words -= m_dictionary->GetWordInfo(m_word)->freq;
        }

        fclose(fid);
    }
}

// Opens the training file for sequential reading.
// NOTE(review): fopen result is not checked here either; GetSentence would
// then read from a null FILE*.
void Reader::Open(const char *input_file)
{
    m_fin = fopen(input_file, "r");
}
28 | void Reader::Close() 29 | { 30 | fclose(m_fin); 31 | m_fin = nullptr; 32 | } 33 | 34 | int Reader::GetSentence(int *sentence, int64_t &word_count) 35 | { 36 | int length = 0, word_idx; 37 | word_count = 0; 38 | while (1) 39 | { 40 | if (!ReadWord(m_word, m_fin)) 41 | break; 42 | word_idx = m_dictionary->GetWordIdx(m_word); 43 | if (word_idx == -1) 44 | continue; 45 | word_count++; 46 | if (m_option->stopwords && m_stopwords_table.count(m_word)) 47 | continue; 48 | sentence[length++] = word_idx; 49 | if (length >= MAX_SENTENCE_LENGTH) 50 | break; 51 | } 52 | 53 | return length; 54 | } 55 | 56 | 57 | bool Reader::ReadWord(char *word, FILE *fin) 58 | { 59 | int idx = 0; 60 | char ch; 61 | while (!feof(fin)) 62 | { 63 | ch = fgetc(fin); 64 | if (ch == 13) continue; 65 | if ((ch == ' ') || (ch == '\t') || (ch == '\n')) 66 | { 67 | if (idx > 0) 68 | { 69 | if (ch == '\n') 70 | ungetc(ch, fin); 71 | break; 72 | } 73 | if (ch == '\n') 74 | { 75 | strcpy(word, (char *)""); 76 | return true; 77 | } 78 | else continue; 79 | } 80 | word[idx++] = ch; 81 | if (idx >= MAX_STRING - 1) idx--; // Truncate too long words 82 | } 83 | word[idx] = 0; 84 | return idx != 0; 85 | } 86 | -------------------------------------------------------------------------------- /src/reader.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "util.h" 4 | #include "dictionary.h" 5 | #include 6 | #include 7 | 8 | class Reader 9 | { 10 | public: 11 | Reader(Dictionary *dictionary, Option *option); 12 | void Open(const char *input_file); 13 | void Close(); 14 | int GetSentence(int *sentence, int64_t &word_count); 15 | 16 | private: 17 | Option* m_option; 18 | FILE* m_fin; 19 | char m_word[MAX_STRING + 1]; 20 | Dictionary *m_dictionary; 21 | std::unordered_set m_stopwords_table; 22 | 23 | bool ReadWord(char *word, FILE *fin); 24 | }; 25 | -------------------------------------------------------------------------------- 
/src/skipgram_mixture_neural_network.cpp: -------------------------------------------------------------------------------- 1 | #include "skipgram_mixture_neural_network.h" 2 | 3 | template 4 | SkipGramMixtureNeuralNetwork::SkipGramMixtureNeuralNetwork(Option* option, HuffmanEncoder* huffmanEncoder, WordSenseInfo* word_sense_info, Dictionary* dic, int dicSize) 5 | { 6 | status = 0; 7 | m_option = option; 8 | m_huffman_encoder = huffmanEncoder; 9 | m_word_sense_info = word_sense_info; 10 | m_dictionary_size = dicSize; 11 | m_dictionary = dic; 12 | 13 | m_input_embedding_weights_ptr = new T*[m_dictionary_size]; 14 | m_sense_priors_ptr = new T*[m_dictionary_size]; 15 | m_sense_priors_paras_ptr = new T*[m_dictionary_size]; 16 | 17 | m_output_embedding_weights_ptr = new T*[m_dictionary_size]; 18 | m_seleted_input_embedding_weights = new bool[m_dictionary_size]; 19 | m_selected_output_embedding_weights = new bool[m_dictionary_size]; 20 | assert(m_input_embedding_weights_ptr != nullptr); 21 | assert(m_output_embedding_weights_ptr != nullptr); 22 | assert(m_seleted_input_embedding_weights != nullptr); 23 | assert(m_selected_output_embedding_weights != nullptr); 24 | memset(m_seleted_input_embedding_weights, 0, sizeof(bool) * m_dictionary_size); 25 | memset(m_selected_output_embedding_weights, 0, sizeof(bool) * m_dictionary_size); 26 | } 27 | 28 | template 29 | SkipGramMixtureNeuralNetwork::~SkipGramMixtureNeuralNetwork() 30 | { 31 | delete m_input_embedding_weights_ptr; 32 | delete m_output_embedding_weights_ptr; 33 | delete m_sense_priors_ptr; 34 | delete m_sense_priors_paras_ptr; 35 | delete m_seleted_input_embedding_weights; 36 | delete m_selected_output_embedding_weights; 37 | } 38 | 39 | template 40 | void SkipGramMixtureNeuralNetwork::Train(int* sentence, int sentence_length, T* gamma, T* f_table, T* input_backup) 41 | { 42 | ParseSentence(sentence, sentence_length, gamma, f_table, input_backup, &SkipGramMixtureNeuralNetwork::TrainSample); 43 | } 44 | 45 | template 46 
// NOTE(review): template parameter lists and other <...> arguments in this
// section were stripped by the text extraction (bare `template`,
// `std::vector >`, etc.); code reproduced as-is.

// The E-step: estimate the posterior multinomial probabilities (gammas) of
// each sense of word_input given the observed Huffman path, and return the
// total log-likelihood. f_m caches the sigmoid values for reuse in the M-step.
T SkipGramMixtureNeuralNetwork::EstimateGamma(int word_input, std::vector >& output_nodes, T* posterior_ll, T* estimation, T* sense_prior, T* f_m)
{
    T* input_embedding = m_input_embedding_weights_ptr[word_input];
    T f, log_likelihood = 0;
    // One pass per sense; input_embedding advances to the next sense's vector.
    for (int sense_idx = 0; sense_idx < m_word_sense_info->word_sense_cnts_info[word_input]; ++sense_idx, input_embedding += m_option->embeding_size)
    {
        posterior_ll[sense_idx] = sense_prior[sense_idx] < eps ? MIN_LOG : log(sense_prior[sense_idx]); // posterior log-likelihood for each sense

        // f_m is laid out as [sense][code position], MAX_CODE_LENGTH per sense.
        int64_t fidx = sense_idx * MAX_CODE_LENGTH;

        for (int d = 0; d < output_nodes.size(); ++d, fidx++)
        {
            f = Util::InnerProduct(input_embedding, m_output_embedding_weights_ptr[output_nodes[d].first], m_option->embeding_size);
            f = Util::Sigmoid(f);
            f_m[fidx] = f;
            if (output_nodes[d].second) // huffman code, 0 or 1
                f = 1 - f;
            posterior_ll[sense_idx] += f < eps ? MIN_LOG : log(f);
        }
        log_likelihood += posterior_ll[sense_idx];
    }
    // Single-sense words have a trivial posterior.
    if (m_word_sense_info->word_sense_cnts_info[word_input] == 1)
    {
        estimation[0] = 1;
        return log_likelihood;
    }

    Util::SoftMax(posterior_ll, estimation, m_word_sense_info->word_sense_cnts_info[word_input]);

    return log_likelihood;
}

template
// The M-step: update the sense prior probabilities to maximize the Q function,
// using an exponential moving average with sense_prior_momentum.
void SkipGramMixtureNeuralNetwork::MaximizeSensePriors(int word_input, T* log_likelihood)
{
    if (m_word_sense_info->word_sense_cnts_info[word_input] == 1)
    {
        return; // nothing to update for single-sense words
    }

    for (int sense_idx = 0; sense_idx < m_word_sense_info->word_sense_cnts_info[word_input]; ++sense_idx)
    {
        T new_alpha = log_likelihood[sense_idx];
        m_sense_priors_paras_ptr[word_input][sense_idx] = m_sense_priors_paras_ptr[word_input][sense_idx] * sense_prior_momentum + new_alpha * (1 - sense_prior_momentum);
    }

    if (!m_option->store_multinomial)
        Util::SoftMax(m_sense_priors_paras_ptr[word_input], m_sense_priors_ptr[word_input], m_option->sense_num_multi); // Update the multinomial parameters
}

template
// The M-step: gradient update of the embedding vectors to maximize the Q
// function, weighted by the per-sense posteriors (`estimation`). When updating
// the output vectors, the pre-update input vectors (input_backup) are used so
// both updates see consistent parameters.
void SkipGramMixtureNeuralNetwork::UpdateEmbeddings(int word_input, std::vector >& output_nodes, T* estimation, T* f_m, T* input_backup, UpdateDirection direction)
{
    T g;
    T* output_embedding;
    T* input_embedding;
    if (direction == UpdateDirection::UPDATE_INPUT)
        input_embedding = m_input_embedding_weights_ptr[word_input];
    else input_embedding = input_backup;
    for (int sense_idx = 0; sense_idx < m_word_sense_info->word_sense_cnts_info[word_input]; ++sense_idx, input_embedding += m_option->embeding_size)
    {
        int64_t fidx = sense_idx * MAX_CODE_LENGTH;
        for (int d = 0; d < output_nodes.size(); ++d, ++fidx)
        {
            output_embedding = m_output_embedding_weights_ptr[output_nodes[d].first];
            // Gradient: posterior-weighted (label - sigmoid) times learning rate.
            g = estimation[sense_idx] * (1 - output_nodes[d].second - f_m[fidx]) * learning_rate;
            if (direction == UpdateDirection::UPDATE_INPUT) // Update Input
            {
                for (int j = 0; j < m_option->embeding_size; ++j)
                    input_embedding[j] += g * output_embedding[j];
            }
            else // Update Output
            {
                for (int j = 0; j < m_option->embeding_size; ++j)
                    output_embedding[j] += g * input_embedding[j];
            }
        }
    }
}


template
// Trains one window sample with EM_iteration rounds of E-step + M-step,
// updating sense priors and the input & output embedding vectors.
void SkipGramMixtureNeuralNetwork::TrainSample(int input_node, std::vector >& output_nodes, void* v_gamma, void* v_fTable, void* v_input_backup)
{
    T* gamma = (T*)v_gamma; // stores the posterior probabilities
    T* f_table = (T*)v_fTable; // stores the sigmoid of inner products of input and output embeddings
    T* input_backup = (T*)v_input_backup;

    T posterior_ll[MAX_SENSE_CNT]; // stores the posterior log likelihood
    T senses[1] = { 1.0 }; // trivial prior for words with only one sense

    // Pick where the sense prior lives: trivial, raw multinomial, or softmax form.
    T* sense_prior = m_word_sense_info->word_sense_cnts_info[input_node] == 1 ? senses : (m_option->store_multinomial ? m_sense_priors_paras_ptr[input_node] : m_sense_priors_ptr[input_node]);

    T log_likelihood;

    for (int iter = 0; iter < m_option->EM_iteration; ++iter)
    {
        // backup input embeddings so the output update uses pre-update inputs
        memcpy(input_backup, m_input_embedding_weights_ptr[input_node], m_option->embeding_size * m_word_sense_info->word_sense_cnts_info[input_node] * sizeof(T));
        log_likelihood = 0;

        // E-Step
        log_likelihood += EstimateGamma(input_node, output_nodes, posterior_ll, gamma, sense_prior, f_table);

        // M-Step: priors are updated from gammas or raw log-likelihoods
        // depending on the storage mode.
        if (m_option->store_multinomial)
            MaximizeSensePriors(input_node, gamma);
        else
            MaximizeSensePriors(input_node, posterior_ll);

        UpdateEmbeddings(input_node, output_nodes, gamma, f_table, input_backup, UpdateDirection::UPDATE_INPUT);
        UpdateEmbeddings(input_node, output_nodes, gamma, f_table, input_backup, UpdateDirection::UPDATE_OUTPUT);

    }
}

template
// Collects all the input words and output nodes referenced by the data block,
// so the parameter loader knows which rows to request. Also resets the
// selection bookkeeping from the previous block.
void SkipGramMixtureNeuralNetwork::PrepareParmeter(DataBlock* data_block)
{
    // Clear state left over from the previous data block.
    for (int i = 0; i < m_input_layer_nodes.size(); ++i)
    {
        m_input_embedding_weights_ptr[m_input_layer_nodes[i]] = nullptr;
        m_seleted_input_embedding_weights[m_input_layer_nodes[i]] = false;
    }

    for (int i = 0; i < m_output_layer_nodes.size(); ++i)
    {
        m_output_embedding_weights_ptr[m_output_layer_nodes[i]] = nullptr;
        m_selected_output_embedding_weights[m_output_layer_nodes[i]] = false;
    }

    m_input_layer_nodes.clear();
    m_output_layer_nodes.clear();

    int sentence_length;
    int64_t word_count_deta;
    int* sentence;
    uint64_t next_random;

    // Same traversal as Train(), but the callback only records node ids.
    for (int i = 0; i < data_block->Size(); ++i)
    {
        data_block->Get(i, sentence, sentence_length, word_count_deta, next_random);
        ParseSentence(sentence, sentence_length, nullptr, nullptr, nullptr, &SkipGramMixtureNeuralNetwork::DealPrepareParameter);
    }
}

template
// Callback for PrepareParmeter: records the input node and its output
// (Huffman) nodes in the private selection sets.
void SkipGramMixtureNeuralNetwork::DealPrepareParameter(int input_node, std::vector >& output_nodes, void* v_gamma, void* v_fTable, void* v_input_backup)
{
    AddInputLayerNode(input_node);
    for (int i = 0; i < output_nodes.size(); ++i)
        AddOutputLayerNode(output_nodes[i].first);
}

template
/*
 * Parses a sentence and dispatches each (context word, center word) pair to
 * the given member-function callback: either TrainSample (training) or
 * DealPrepareParameter (parameter collection).
 */
void SkipGramMixtureNeuralNetwork::ParseSentence(int* sentence, int sentence_length, T* gamma, T* f_table, T* input_backup, FunctionType function)
{
    if (sentence_length == 0)
        return;

    int feat[MAX_SENTENCE_LENGTH + 10];
    int input_node;
    std::vector > output_nodes;
    for (int sentence_position = 0; sentence_position < sentence_length; ++sentence_position)
    {
        if (sentence[sentence_position] == -1) continue;
        int feat_size = 0;

        // Scan the symmetric context window, skipping the center position.
        for (int i = 0; i < m_option->window_size * 2 + 1; ++i)
            if (i != m_option->window_size)
            {
                int c = sentence_position - m_option->window_size + i;
                if (c < 0 || c >= sentence_length || sentence[c] == -1) continue;
                feat[feat_size++] = sentence[c];

                // Begin: Train SkipGram — context word predicts center word's Huffman path
                {
                    input_node = feat[feat_size - 1];
                    output_nodes.clear();
                    Parse(input_node, sentence[sentence_position], output_nodes);
                    (this->*function)(input_node, output_nodes, gamma, f_table, input_backup);
                }
            }
    }
}

template
// Expands a center word into its Huffman path: (node index, code bit) pairs.
void SkipGramMixtureNeuralNetwork::Parse(int feat, int out_word_idx, std::vector >& output_nodes)
{
    const auto info = m_huffman_encoder->GetLabelInfo(out_word_idx);
    for (int d = 0; d < info->codelen; d++)
        output_nodes.push_back(std::make_pair(info->point[d], info->code[d]));

}

template
// Records an input word id once (deduplicated via the selection flag array).
void SkipGramMixtureNeuralNetwork::AddInputLayerNode(int node_id)
{
    if (m_seleted_input_embedding_weights[node_id] == false)
    {
        m_seleted_input_embedding_weights[node_id] = true;
        m_input_layer_nodes.push_back(node_id);
    }
}

template
// Records a Huffman-tree node id once (deduplicated via the selection flag array).
void SkipGramMixtureNeuralNetwork::AddOutputLayerNode(int node_id)
{
    if (m_selected_output_embedding_weights[node_id] == false)
    {
        m_selected_output_embedding_weights[node_id] = true;
        m_output_layer_nodes.push_back(node_id);
    }
}

template
std::vector& SkipGramMixtureNeuralNetwork::GetInputLayerNodes()
{
    return m_input_layer_nodes;
}

template
std::vector& SkipGramMixtureNeuralNetwork::GetOutputLayerNodes()
{
    return m_output_layer_nodes;
}

// Setters below point the per-word slots at rows fetched from the parameter server.
template
void SkipGramMixtureNeuralNetwork::SetInputEmbeddingWeights(int input_node_id, T* ptr)
{
    m_input_embedding_weights_ptr[input_node_id] = ptr;
}

template
void SkipGramMixtureNeuralNetwork::SetOutputEmbeddingWeights(int output_node_id, T* ptr)
{
    m_output_embedding_weights_ptr[output_node_id] = ptr;
}

template
void SkipGramMixtureNeuralNetwork::SetSensePriorWeights(int input_node_id, T*ptr)
{
    m_sense_priors_ptr[input_node_id] = ptr;
}

template
void SkipGramMixtureNeuralNetwork::SetSensePriorParaWeights(int input_node_id, T* ptr)
{
    m_sense_priors_paras_ptr[input_node_id] = ptr;
}

template
T* SkipGramMixtureNeuralNetwork::GetInputEmbeddingWeights(int input_node_id)
{
    return m_input_embedding_weights_ptr[input_node_id];
}

template
T* SkipGramMixtureNeuralNetwork::GetEmbeddingOutputWeights(int output_node_id)
{
    return
m_output_embedding_weights_ptr[output_node_id];
}

template
T* SkipGramMixtureNeuralNetwork::GetSensePriorWeights(int input_node_id)
{
    return m_sense_priors_ptr[input_node_id];
}

template
T* SkipGramMixtureNeuralNetwork::GetSensePriorParaWeights(int input_node_id)
{
    return m_sense_priors_paras_ptr[input_node_id];
}

// Explicit instantiations (element types stripped by the extraction;
// presumably float and double).
template class SkipGramMixtureNeuralNetwork;
template class SkipGramMixtureNeuralNetwork;

// -------------------------------------------------------
// /src/skipgram_mixture_neural_network.h:
// -------------------------------------------------------
#pragma once

// NOTE(review): template parameter lists and angle-bracket include names in
// this header were stripped by the text extraction; reproduced as-is.
#include

#include "util.h"
#include
#include "huffman_encoder.h"
#include "multiverso_skipgram_mixture.h"
#include "cstring"

// Which side of the (input, output) embedding pair an M-step update targets.
enum class UpdateDirection
{
    UPDATE_INPUT,
    UPDATE_OUTPUT
};

// Multi-sense skip-gram model: one embedding vector per (word, sense), a
// sense-prior multinomial per polysemous word, and hierarchical-softmax
// output vectors on the Huffman tree. Trained with EM (EstimateGamma /
// MaximizeSensePriors / UpdateEmbeddings).
template
class SkipGramMixtureNeuralNetwork
{
public:
    T learning_rate;
    T sense_prior_momentum;  // EMA factor used by MaximizeSensePriors

    int status;  // 0 idle, 1 parameters being loaded, 2 ready (set by ParameterLoader)
    SkipGramMixtureNeuralNetwork(Option* option, HuffmanEncoder* huffmanEncoder, WordSenseInfo* word_sense_info, Dictionary* dic, int dicSize);
    ~SkipGramMixtureNeuralNetwork();

    void Train(int* sentence, int sentence_length, T* gamma, T* fTable, T* input_backup);

    /*!
     * \brief Collect all the input words and output nodes in the data block
     */
    void PrepareParmeter(DataBlock *data_block);

    std::vector& GetInputLayerNodes();
    std::vector& GetOutputLayerNodes();

    /*!
     * \brief Set the pointers to those local parameters
     */
    void SetInputEmbeddingWeights(int input_node_id, T* ptr);
    void SetOutputEmbeddingWeights(int output_node_id, T* ptr);
    void SetSensePriorWeights(int input_node_id, T*ptr);
    void SetSensePriorParaWeights(int input_node_id, T* ptr);

    /*!
     * \brief Get the pointers to those locally updated parameters
     */
    T* GetInputEmbeddingWeights(int input_node_id);
    T* GetEmbeddingOutputWeights(int output_node_id);
    T* GetSensePriorWeights(int input_node_id);
    T* GetSensePriorParaWeights(int input_node_id);

private:
    Option *m_option;
    Dictionary *m_dictionary;
    HuffmanEncoder *m_huffman_encoder;
    int m_dictionary_size;

    WordSenseInfo* m_word_sense_info;

    T** m_input_embedding_weights_ptr; // Points to every word's input embedding vector
    bool *m_seleted_input_embedding_weights;
    T** m_output_embedding_weights_ptr; // Points to every huffman node's embedding vector
    bool *m_selected_output_embedding_weights;

    T** m_sense_priors_ptr; // Points to the multinomial parameters, if store_multinomial is set to zero.
    T** m_sense_priors_paras_ptr; // Points to sense prior parameters. If store_multinomial is zero, then it points to the log of multinomial, otherwise points to the multinomial parameters

    std::vector m_input_layer_nodes;
    std::vector m_output_layer_nodes;

    // Member-function callback type used by ParseSentence to dispatch either
    // to TrainSample or DealPrepareParameter.
    typedef void(SkipGramMixtureNeuralNetwork::*FunctionType)(int input_node, std::vector >& output_nodes, void* v_gamma, void* v_fTable, void* v_input_backup);

    /*!
     * \brief Parse the needed parameter in a window
     */
    void Parse(int feat, int word_idx, std::vector >& output_nodes);

    /*!
     * \brief Parse a sentence and deepen into two branches:
     * \one for TrainNN, the other one is for Parameter_parse&request
     */
    void ParseSentence(int* sentence, int sentence_length, T* gamma, T* fTable, T* input_backup, FunctionType function);

    /*!
     * \brief Copy the input_nodes&output_nodes to WordEmbedding private set
     */
    void DealPrepareParameter(int input_nodes, std::vector >& output_nodes, void* v_gamma, void* v_fTable, void* v_input_backup);

    /*!
     * \brief Train a window sample and update the
     * \input-embedding&output-embedding vectors
     * \param word_input represent the input words
     * \param output_nodes represent the output nodes on huffman tree, including the node index and path label
     * \param v_gamma is the temp memory to store the posterior probabilities of each sense
     * \param v_fTable is the temp memory to store the sigmoid value of inner product of input and output embeddings
     * \param v_input_backup stores the input embedding vectors as backup
     */
    void TrainSample(int word_input, std::vector >& output_nodes, void* v_gamma, void* v_fTable, void* v_input_backup);

    /*!
     * \brief The E-step, estimate the posterior multinomial probabilities
     * \param word_input represent the input words
     * \param output_nodes represent the output nodes on huffman tree, including the node index and path label
     * \param posterior represents the calculated posterior log likelihood
     * \param estimation represents the calculated gammas (see the paper), that is, the softmax terms of posterior
     * \param sense_prior represents the parameters of sense prior probabilities for each polysemous words
     * \param f_m is the temp memory to store the sigmoid value of inner products of input and output embeddings
     */
    T EstimateGamma(int word_input, std::vector >& output_nodes, T* posterior, T* estimation, T* sense_prior, T* f_m);

    /*!
     * \brief The M step: update the embedding vectors to maximize the Q function
     * \param word_input represent the input words
     * \param output_nodes represent the output nodes on huffman tree, including the node index and path label
     * \param estimation represents the calculated gammas (see the paper), that is, the softmax terms of posterior
     * \param f_m is the temp memory to store the sigmoid value of inner products of input and output embeddings
     * \param input_backup stores the input embedding vectors as backup
     * \param direction: update input vectors or output vectors
     */
    void UpdateEmbeddings(int word_input, std::vector >& output_nodes, T* estimation, T* f_m, T* input_backup, UpdateDirection direction);

    /*!
     * \brief The M Step: update the sense prior probabilities to maximize the Q function
     * \param word_input represent the input words
     * \param curr_priors are the closed form values of the sense priors in this iteration
     */
    void MaximizeSensePriors(int word_input, T* curr_priors);

    /*
     * \brief Record the input word so that parameter loading can be performed
     */
    void AddInputLayerNode(int node_id);

    /*
     * \brief Record the huffman tree node so that parameter loading can be performed
     */
    void AddOutputLayerNode(int node_id);
};

// -------------------------------------------------------
// /src/trainer.cpp:
// -------------------------------------------------------
// (definition continues beyond this chunk; reproduced verbatim, not reviewed)
#include "trainer.h"

template
Trainer::Trainer(int trainer_id, Option *option, void** word2vector_neural_networks, multiverso::Barrier *barrier, Dictionary* dictionary, WordSenseInfo* word_sense_info, HuffmanEncoder* huff_encoder)
{
    m_trainer_id = trainer_id;
    m_option = option;
    m_word_count = m_last_word_count = 0;
    m_sgmixture_neural_networks = word2vector_neural_networks;
m_barrier = barrier; 11 | m_dictionary = dictionary; 12 | m_word_sense_info = word_sense_info; 13 | m_huffman_encoder = huff_encoder; 14 | 15 | gamma = (T*)calloc(m_option->window_size * MAX_SENSE_CNT, sizeof(T)); 16 | fTable = (T*)calloc(m_option->window_size * MAX_CODE_LENGTH * MAX_SENSE_CNT, sizeof(T)); 17 | input_backup = (T*)calloc(m_option->embeding_size * MAX_SENSE_CNT, sizeof(T)); 18 | 19 | m_start_time = 0; 20 | m_train_count = 0; 21 | m_executive_time = 0; 22 | if (m_trainer_id == 0) 23 | { 24 | m_log_file = fopen("trainer.log", "w"); 25 | } 26 | } 27 | 28 | template 29 | //Train one datablock 30 | void Trainer::TrainIteration(multiverso::DataBlockBase *data_block) 31 | { 32 | if (m_train_count == 0) 33 | { 34 | m_start_time = clock(); 35 | m_process_id = multiverso::Multiverso::ProcessRank(); 36 | } 37 | 38 | printf("Rank %d Begin TrainIteration...%d\n", m_process_id, m_train_count); 39 | clock_t train_interation_start = clock(); 40 | fflush(stdout); 41 | 42 | m_process_count = multiverso::Multiverso::TotalProcessCount(); 43 | 44 | DataBlock *data = reinterpret_cast(data_block); 45 | SkipGramMixtureNeuralNetwork* word2vector_neural_network = reinterpret_cast*>(m_sgmixture_neural_networks[m_train_count % 2]); 46 | ++m_train_count; 47 | std::vector& input_layer_nodes = word2vector_neural_network->GetInputLayerNodes(); 48 | std::vector& output_layer_nodes = word2vector_neural_network->GetOutputLayerNodes(); 49 | std::vector local_input_layer_nodes, local_output_layer_nodes; 50 | assert(word2vector_neural_network->status == 2); 51 | if (m_trainer_id == 0) 52 | { 53 | multiverso::Log::Info("Rank %d input_layer_size=%d, output_layer_size=%d\n", m_process_id, input_layer_nodes.size(), output_layer_nodes.size()); 54 | } 55 | 56 | for (int i = m_trainer_id; i < input_layer_nodes.size(); i += m_option->thread_cnt) 57 | { 58 | local_input_layer_nodes.push_back(input_layer_nodes[i]); 59 | } 60 | 61 | for (int i = m_trainer_id; i < output_layer_nodes.size(); i += 
m_option->thread_cnt) 62 | { 63 | local_output_layer_nodes.push_back(output_layer_nodes[i]); 64 | } 65 | 66 | CopyParameterFromMultiverso(local_input_layer_nodes, local_output_layer_nodes, word2vector_neural_network); 67 | 68 | multiverso::Row& word_count_actual_row = GetRow(kWordCountActualTableId, 0); 69 | T learning_rate = m_option->init_learning_rate * (1 - word_count_actual_row.At(0) / (T)(m_option->total_words * m_option->epoch + 1)); 70 | if (learning_rate < m_option->init_learning_rate * (real)0.0001) 71 | learning_rate = m_option->init_learning_rate * (real)0.0001; 72 | word2vector_neural_network->learning_rate = learning_rate; 73 | 74 | //Linearly increase the momentum from init_sense_prior_momentum to 1 75 | word2vector_neural_network->sense_prior_momentum = m_option->init_sense_prior_momentum + 76 | (1 - m_option->init_sense_prior_momentum) * word_count_actual_row.At(0) / (T)(m_option->total_words * m_option->epoch + 1); 77 | 78 | m_barrier->Wait(); 79 | 80 | for (int i = m_trainer_id; i < data->Size(); i += m_option->thread_cnt) //i iterates over all sentences 81 | { 82 | int sentence_length; 83 | int64_t word_count_deta; 84 | int *sentence; 85 | uint64_t next_random; 86 | data->Get(i, sentence, sentence_length, word_count_deta, next_random); 87 | 88 | word2vector_neural_network->Train(sentence, sentence_length, gamma, fTable, input_backup); 89 | 90 | m_word_count += word_count_deta; 91 | if (m_word_count - m_last_word_count > 10000) 92 | { 93 | multiverso::Row& word_count_actual_row = GetRow(kWordCountActualTableId, 0); 94 | Add(kWordCountActualTableId, 0, 0, m_word_count - m_last_word_count); 95 | m_last_word_count = m_word_count; 96 | m_now_time = clock(); 97 | 98 | if (m_trainer_id % 3 == 0) 99 | { 100 | multiverso::Log::Info("Rank %d Trainer %d lr: %.5f Mom: %.4f Progress: %.2f%% Words/thread/sec(total): %.2fk W/t/sec(executive): %.2fk\n", 101 | m_process_id, m_trainer_id, 102 | word2vector_neural_network->learning_rate, 
word2vector_neural_network->sense_prior_momentum, 103 | word_count_actual_row.At(0) / (real)(m_option->total_words * m_option->epoch + 1) * 100, 104 | m_last_word_count / ((real)(m_now_time - m_start_time + 1) / (real)CLOCKS_PER_SEC * 1000), 105 | m_last_word_count / ((real)(m_executive_time + clock() - train_interation_start + 1) / (real)CLOCKS_PER_SEC * 1000)); 106 | 107 | fflush(stdout); 108 | } 109 | 110 | T learning_rate = m_option->init_learning_rate * (1 - word_count_actual_row.At(0) / (T)(m_option->total_words * m_option->epoch + 1)); 111 | if (learning_rate < m_option->init_learning_rate * (real)0.0001) 112 | learning_rate = m_option->init_learning_rate * (real)0.0001; 113 | word2vector_neural_network->learning_rate = learning_rate; 114 | 115 | word2vector_neural_network->sense_prior_momentum = m_option->init_sense_prior_momentum + (1 - m_option->init_sense_prior_momentum) * word_count_actual_row.At(0) / (T)(m_option->total_words * m_option->epoch + 1); 116 | } 117 | } 118 | 119 | m_barrier->Wait(); 120 | AddParameterToMultiverso(local_input_layer_nodes, local_output_layer_nodes, word2vector_neural_network); 121 | 122 | m_executive_time += clock() - train_interation_start; 123 | 124 | multiverso::Log::Info("Rank %d Train %d end at %lfs, cost %lfs, total cost %lfs\n", 125 | m_process_id, 126 | m_trainer_id, clock() / (double)CLOCKS_PER_SEC, 127 | (clock() - train_interation_start) / (double)CLOCKS_PER_SEC, 128 | m_executive_time / (double)CLOCKS_PER_SEC); 129 | fflush(stdout); 130 | 131 | if (data->GetTables().size() > 0 && m_trainer_id == 0) //Dump model files 132 | { 133 | SaveMultiInputEmbedding(data->GetEpochId()); 134 | SaveOutputEmbedding(data->GetEpochId()); 135 | if (data->GetEpochId() == 0) 136 | SaveHuffEncoder(); 137 | 138 | fprintf(m_log_file, "%d %lf\t %lf\n", data->GetEpochId(), (clock() - m_start_time) / (double)CLOCKS_PER_SEC, m_executive_time / (double)CLOCKS_PER_SEC); 139 | } 140 | 141 | assert(word2vector_neural_network->status == 2); 142 
| 143 | word2vector_neural_network->status = 0; 144 | 145 | multiverso::Log::Info("Rank %d Train %d are leaving training iter with nn status:%d\n", m_process_id, m_trainer_id, word2vector_neural_network->status); 146 | fflush(stdout); 147 | } 148 | 149 | template 150 | //Copy a size of memory from source row to dest 151 | void Trainer::CopyMemory(T* dest, multiverso::Row& source, int size) 152 | { 153 | for (int i = 0; i < size; ++i) 154 | dest[i] = source.At(i); 155 | } 156 | 157 | template 158 | //Copy the needed parameter from buffer to local blocks 159 | int Trainer::CopyParameterFromMultiverso(std::vector& input_layer_nodes, std::vector& output_layer_nodes, void* local_word2vector_neural_network) 160 | { 161 | SkipGramMixtureNeuralNetwork* word2vector_neural_network = (SkipGramMixtureNeuralNetwork*)local_word2vector_neural_network; 162 | 163 | //Copy input embedding 164 | for (int i = 0; i < input_layer_nodes.size(); ++i) 165 | { 166 | T* ptr = (T*)calloc(m_word_sense_info->word_sense_cnts_info[input_layer_nodes[i]] * m_option->embeding_size, sizeof(T)); 167 | int row_id_base = m_word_sense_info->p_input_embedding[input_layer_nodes[i]]; 168 | for (int j = 0, row_id = row_id_base; j < m_word_sense_info->word_sense_cnts_info[input_layer_nodes[i]]; ++j, ++row_id) 169 | CopyMemory(ptr + j * m_option->embeding_size, GetRow(kInputEmbeddingTableId, row_id), m_option->embeding_size); 170 | word2vector_neural_network->SetInputEmbeddingWeights(input_layer_nodes[i], ptr); 171 | } 172 | 173 | //Copy output embedding 174 | for (int i = 0; i < output_layer_nodes.size(); ++i) 175 | { 176 | T* ptr = (T*)calloc(m_option->embeding_size, sizeof(T)); 177 | CopyMemory(ptr, GetRow(kEmbeddingOutputTableId, output_layer_nodes[i]), m_option->embeding_size); 178 | for (int j = 0; j < m_option->embeding_size; j += 5) 179 | if (!Util::ValidF(static_cast(ptr[j]))) 180 | { 181 | printf("invalid number\n"); 182 | fflush(stdout); 183 | throw std::runtime_error("Invalid output embeddings"); 
184 | } 185 | word2vector_neural_network->SetOutputEmbeddingWeights(output_layer_nodes[i], ptr); 186 | } 187 | 188 | //Copy sense prior 189 | for (int i = 0; i < input_layer_nodes.size(); ++i) 190 | { 191 | if (m_word_sense_info->word_sense_cnts_info[input_layer_nodes[i]] > 1) 192 | { 193 | T* ptr = (T*)calloc(m_option->sense_num_multi, sizeof(T)); 194 | T* para_ptr = (T*)calloc(m_option->sense_num_multi, sizeof(T)); 195 | 196 | CopyMemory(para_ptr, GetRow(kWordSensePriorTableId, m_word_sense_info->p_wordidx2sense_idx[input_layer_nodes[i]]), m_option->sense_num_multi); 197 | 198 | if (!m_option->store_multinomial)//softmax the para_ptr to obtain the multinomial parameters 199 | Util::SoftMax(para_ptr, ptr, m_option->sense_num_multi); 200 | word2vector_neural_network->SetSensePriorWeights(input_layer_nodes[i], ptr); 201 | word2vector_neural_network->SetSensePriorParaWeights(input_layer_nodes[i], para_ptr); 202 | } 203 | } 204 | 205 | return 0; 206 | } 207 | 208 | template 209 | //Add delta of a row of local parameters to the parameter stored in the buffer and send it to multiverso 210 | void Trainer::AddParameterRowToMultiverso(T* ptr, int table_id, int row_id, int size, real momentum) 211 | { 212 | multiverso::Row& row = GetRow(table_id, row_id); 213 | for (int i = 0; i < size; ++i) 214 | { 215 | T dest = ptr[i] * (1 - momentum) + row.At(i) * momentum; 216 | T delta = (dest - row.At(i)) / m_process_count; 217 | Add(table_id, row_id, i, delta); 218 | } 219 | } 220 | 221 | template 222 | //Add delta to the parameter stored in the buffer and send it to multiverso 223 | int Trainer::AddParameterToMultiverso(std::vector& input_layer_nodes, std::vector& output_layer_nodes, void* local_word2vector_neural_network) 224 | { 225 | SkipGramMixtureNeuralNetwork* word2vector_neural_network = (SkipGramMixtureNeuralNetwork*)local_word2vector_neural_network; 226 | std::vector blocks; //used to store locally malloced memorys 227 | 228 | //Add input embeddings 229 | for (int i = 0; i 
< input_layer_nodes.size(); ++i) 230 | { 231 | int table_id = kInputEmbeddingTableId; 232 | int row_id_base = m_word_sense_info->p_input_embedding[input_layer_nodes[i]]; 233 | T* ptr = word2vector_neural_network->GetInputEmbeddingWeights(input_layer_nodes[i]); 234 | 235 | for (int j = 0, row_id = row_id_base; j < m_word_sense_info->word_sense_cnts_info[input_layer_nodes[i]]; ++j, ++row_id) 236 | AddParameterRowToMultiverso(ptr + m_option->embeding_size * j, table_id, row_id, m_option->embeding_size); 237 | blocks.push_back(ptr); 238 | } 239 | 240 | //Add output embeddings 241 | for (int i = 0; i < output_layer_nodes.size(); ++i) 242 | { 243 | int table_id = kEmbeddingOutputTableId; 244 | int row_id = output_layer_nodes[i]; 245 | T* ptr = word2vector_neural_network->GetEmbeddingOutputWeights(row_id); 246 | AddParameterRowToMultiverso(ptr, table_id, row_id, m_option->embeding_size); 247 | blocks.push_back(ptr); 248 | } 249 | 250 | //Add sense priors 251 | for (int i = 0; i < input_layer_nodes.size(); ++i) 252 | { 253 | if (m_word_sense_info->word_sense_cnts_info[input_layer_nodes[i]] > 1) 254 | { 255 | int table_id = kWordSensePriorTableId; 256 | int row_id = m_word_sense_info->p_wordidx2sense_idx[input_layer_nodes[i]]; 257 | 258 | T* ptr = word2vector_neural_network->GetSensePriorWeights(input_layer_nodes[i]); 259 | T* para_ptr = word2vector_neural_network->GetSensePriorParaWeights(input_layer_nodes[i]); 260 | 261 | AddParameterRowToMultiverso(para_ptr, table_id, row_id, m_option->sense_num_multi, static_cast(word2vector_neural_network->sense_prior_momentum)); 262 | 263 | blocks.push_back(ptr); 264 | blocks.push_back(para_ptr); 265 | } 266 | 267 | } 268 | 269 | for (auto& x : blocks) 270 | free(x); 271 | 272 | return 0; 273 | } 274 | 275 | template 276 | void Trainer::SaveMultiInputEmbedding(const int epoch_id) 277 | { 278 | FILE* fid = nullptr; 279 | T* sense_priors_ptr = (T*)calloc(m_option->sense_num_multi, sizeof(real)); 280 | 281 | char outfile[2000]; 282 | if 
(m_option->output_binary) 283 | { 284 | sprintf(outfile, "%s%d", m_option->binary_embedding_file, epoch_id); 285 | 286 | fid = fopen(outfile, "wb"); 287 | 288 | fprintf(fid, "%d %d %d\n", m_dictionary->Size(), m_word_sense_info->total_senses_cnt, m_option->embeding_size); 289 | for (int i = 0; i < m_dictionary->Size(); ++i) 290 | { 291 | fprintf(fid, "%s %d ", m_dictionary->GetWordInfo(i)->word.c_str(), m_word_sense_info->word_sense_cnts_info[i]); 292 | int emb_row_id; 293 | real emb_tmp; 294 | 295 | if (m_word_sense_info->word_sense_cnts_info[i] > 1) 296 | { 297 | CopyMemory(sense_priors_ptr, GetRow(kWordSensePriorTableId, m_word_sense_info->p_wordidx2sense_idx[i]), m_option->sense_num_multi); 298 | if (!m_option->store_multinomial) 299 | Util::SoftMax(sense_priors_ptr, sense_priors_ptr, m_option->sense_num_multi); 300 | 301 | for (int j = 0; j < m_option->sense_num_multi; ++j) 302 | { 303 | fwrite(sense_priors_ptr + j, sizeof(real), 1, fid); 304 | emb_row_id = m_word_sense_info->p_input_embedding[i] + j; 305 | multiverso::Row& embedding = GetRow(kInputEmbeddingTableId, emb_row_id); 306 | for (int k = 0; k < m_option->embeding_size; ++k) 307 | { 308 | emb_tmp = embedding.At(k); 309 | fwrite(&emb_tmp, sizeof(real), 1, fid); 310 | } 311 | } 312 | fprintf(fid, "\n"); 313 | } 314 | else 315 | { 316 | real prob = 1.0; 317 | fwrite(&prob, sizeof(real), 1, fid); 318 | emb_row_id = m_word_sense_info->p_input_embedding[i]; 319 | multiverso::Row& embedding = GetRow(kInputEmbeddingTableId, emb_row_id); 320 | 321 | for (int k = 0; k < m_option->embeding_size; ++k) 322 | { 323 | emb_tmp = embedding.At(k); 324 | fwrite(&emb_tmp, sizeof(real), 1, fid); 325 | } 326 | fprintf(fid, "\n"); 327 | } 328 | } 329 | 330 | fclose(fid); 331 | } 332 | if (m_option->output_binary % 2 == 0) 333 | { 334 | sprintf(outfile, "%s%d", m_option->text_embedding_file, epoch_id); 335 | 336 | fid = fopen(outfile, "w"); 337 | fprintf(fid, "%d %d %d\n", m_dictionary->Size(), 
m_word_sense_info->total_senses_cnt, m_option->embeding_size); 338 | for (int i = 0; i < m_dictionary->Size(); ++i) 339 | { 340 | fprintf(fid, "%s %d\n", m_dictionary->GetWordInfo(i)->word.c_str(), m_word_sense_info->word_sense_cnts_info[i]); 341 | 342 | int emb_row_id; 343 | real emb_tmp; 344 | 345 | if (m_word_sense_info->word_sense_cnts_info[i] > 1) 346 | { 347 | CopyMemory(sense_priors_ptr, GetRow(kWordSensePriorTableId, m_word_sense_info->p_wordidx2sense_idx[i]), m_option->sense_num_multi); 348 | 349 | if (!m_option->store_multinomial) 350 | Util::SoftMax(sense_priors_ptr, sense_priors_ptr, m_option->sense_num_multi); 351 | 352 | for (int j = 0; j < m_option->sense_num_multi; ++j) 353 | { 354 | fprintf(fid, "%.4f", sense_priors_ptr[j]); 355 | 356 | emb_row_id = m_word_sense_info->p_input_embedding[i] + j; 357 | multiverso::Row& embedding = GetRow(kInputEmbeddingTableId, emb_row_id); 358 | for (int k = 0; k < m_option->embeding_size; ++k) 359 | { 360 | emb_tmp = embedding.At(k); 361 | fprintf(fid, " %.3f", emb_tmp); 362 | } 363 | fprintf(fid, "\n"); 364 | } 365 | } 366 | else 367 | { 368 | real prob = 1.0; 369 | fprintf(fid, "%.4f", 1.0); 370 | 371 | emb_row_id = m_word_sense_info->p_input_embedding[i]; 372 | multiverso::Row& embedding = GetRow(kInputEmbeddingTableId, emb_row_id); 373 | for (int k = 0; k < m_option->embeding_size; ++k) 374 | { 375 | emb_tmp = embedding.At(k); 376 | fprintf(fid, " %.3f", emb_tmp); 377 | } 378 | fprintf(fid, "\n"); 379 | } 380 | } 381 | 382 | fclose(fid); 383 | } 384 | } 385 | 386 | template 387 | void Trainer::SaveOutputEmbedding(const int epoch_id) 388 | { 389 | char outfile[2000]; 390 | if (m_option->output_binary) 391 | { 392 | sprintf(outfile, "%s%d", m_option->outputlayer_binary_file, epoch_id); 393 | 394 | FILE* fid = fopen(outfile, "wb"); 395 | fprintf(fid, "%d %d\n", m_dictionary->Size(), m_option->embeding_size); 396 | for (int i = 0; i < m_dictionary->Size(); ++i) 397 | { 398 | multiverso::Row& hs_embedding = 
GetRow(kEmbeddingOutputTableId, i); 399 | for (int j = 0; j < m_option->embeding_size; ++j) 400 | { 401 | real emb_tmp = hs_embedding.At(j); 402 | fwrite(&emb_tmp, sizeof(real), 1, fid); 403 | } 404 | } 405 | fclose(fid); 406 | } 407 | if (m_option->output_binary % 2 == 0) 408 | { 409 | sprintf(outfile, "%s%d", m_option->outputlayer_text_file, epoch_id); 410 | 411 | FILE* fid = fopen(outfile, "w"); 412 | fprintf(fid, "%d %d\n", m_dictionary->Size(), m_option->embeding_size); 413 | for (int i = 0; i < m_dictionary->Size(); ++i) 414 | { 415 | multiverso::Row& hs_embedding = GetRow(kEmbeddingOutputTableId, i); 416 | 417 | for (int j = 0; j < m_option->embeding_size; ++j) 418 | fprintf(fid, "%.2f ", hs_embedding.At(j)); 419 | fprintf(fid, "\n"); 420 | } 421 | fclose(fid); 422 | } 423 | } 424 | 425 | template 426 | void Trainer::SaveHuffEncoder() 427 | { 428 | FILE* fid = fopen(m_option->huff_tree_file, "w"); 429 | fprintf(fid, "%d\n", m_dictionary->Size()); 430 | for (int i = 0; i < m_dictionary->Size(); ++i) 431 | { 432 | fprintf(fid, "%s", m_dictionary->GetWordInfo(i)->word.c_str()); 433 | const auto info = m_huffman_encoder->GetLabelInfo(i); 434 | fprintf(fid, " %d", info->codelen); 435 | for (int j = 0; j < info->codelen; ++j) 436 | fprintf(fid, " %d", info->code[j]); 437 | for (int j = 0; j < info->codelen; ++j) 438 | fprintf(fid, " %d", info->point[j]); 439 | fprintf(fid, "\n"); 440 | } 441 | fclose(fid); 442 | } 443 | 444 | template class Trainer; 445 | template class Trainer; 446 | -------------------------------------------------------------------------------- /src/trainer.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #include "data_block.h" 10 | #include "multiverso_tablesid.h" 11 | #include "util.h" 12 | #include "huffman_encoder.h" 13 | #include "skipgram_mixture_neural_network.h" 14 | 15 | 16 | template 17 | class Trainer : 
public multiverso::TrainerBase 18 | { 19 | public: 20 | Trainer(int trainer_id, Option *option, void** word2vector_neural_networks, multiverso::Barrier* barrier, Dictionary* dictionary, WordSenseInfo* word_sense_info, HuffmanEncoder* huff_encoder); 21 | 22 | /*! 23 | * /brief Train one datablock 24 | */ 25 | void TrainIteration(multiverso::DataBlockBase* data_block) override; 26 | 27 | private: 28 | int m_process_id; 29 | int m_trainer_id; 30 | int m_train_count; //threads count 31 | int m_process_count; //machines count 32 | 33 | Option *m_option; 34 | WordSenseInfo* m_word_sense_info; 35 | HuffmanEncoder* m_huffman_encoder; 36 | 37 | int64_t m_word_count, m_last_word_count; 38 | 39 | T* gamma, * fTable, *input_backup; //temp memories to store middle results in the EM algorithm 40 | 41 | clock_t m_start_time, m_now_time, m_executive_time; 42 | void ** m_sgmixture_neural_networks; 43 | multiverso::Barrier *m_barrier; 44 | Dictionary* m_dictionary; 45 | FILE* m_log_file; 46 | 47 | /*! 48 | * \brief Save the multi sense input-embedding vectors 49 | * \param epoch_id, the embedding vectors after epoch_id is dumped 50 | */ 51 | void SaveMultiInputEmbedding(const int epoch_id); 52 | 53 | /*! 54 | * \brief Save the outpue embedding vectors, i.e. the embeddings for huffman tree nodes 55 | * \param epoch_id, the embedding vectors after epoch_id is dumped 56 | */ 57 | void SaveOutputEmbedding(const int epoch_id); 58 | 59 | /*! 60 | * \brief Save the Huffman tree structure 61 | */ 62 | void SaveHuffEncoder(); 63 | 64 | /*! 65 | * \brief Copy the needed parameter from buffer to local blocks 66 | */ 67 | void CopyMemory(T* dest, multiverso::Row& source, int size); 68 | int CopyParameterFromMultiverso(std::vector& input_layer_nodes, std::vector& output_layer_nodes, void* word2vector_neural_networks); 69 | 70 | /*! 
71 | * \brief Add delta to the parameter stored in the 72 | * \buffer and send it to multiverso 73 | */ 74 | int AddParameterToMultiverso(std::vector& input_layer_nodes, std::vector& output_layer_nodes, void* word2vector_neural_networks); 75 | /*! 76 | * \brief Add delta of a row of local parameters to the parameter stored in the 77 | * \buffer and send it to multiverso 78 | * \param momentum: new_value = old_value * momentum + current_value * (1 - momentum). Set to non zero when updating the sense_priors 79 | */ 80 | void AddParameterRowToMultiverso(T* ptr, int table_id, int row_id, int size, real momentum = 0); 81 | 82 | }; 83 | 84 | -------------------------------------------------------------------------------- /src/util.cpp: -------------------------------------------------------------------------------- 1 | #include "util.h" 2 | 3 | Option::Option() 4 | { 5 | train_file = NULL; 6 | read_vocab_file = NULL; 7 | binary_embedding_file = NULL; 8 | text_embedding_file = NULL; 9 | 10 | sw_file = NULL; 11 | output_binary = 2; 12 | embeding_size = 0; 13 | thread_cnt = 1; 14 | window_size = 5; 15 | min_count = 5; 16 | data_block_size = 100; 17 | init_learning_rate = static_cast(0.025); 18 | epoch = 1; 19 | stopwords = false; 20 | total_words = 0; 21 | 22 | //multisense config 23 | store_multinomial = false; 24 | EM_iteration = 1; 25 | top_N = 0; 26 | top_ratio = static_cast(0.1); 27 | sense_num_multi = 1; 28 | init_sense_prior_momentum = static_cast(0.1); 29 | sense_file = NULL; 30 | huff_tree_file = NULL; 31 | outputlayer_binary_file = NULL; 32 | outputlayer_text_file = NULL; 33 | 34 | // multiverso config 35 | num_servers = 0; 36 | num_aggregator = 1; 37 | lock_option = 1; 38 | num_lock = 100; 39 | max_delay = 0; 40 | } 41 | 42 | void Option::ParseArgs(int argc, char* argv[]) 43 | { 44 | for (int i = 1; i < argc; i += 2) 45 | { 46 | if (strcmp(argv[i], "-size") == 0) embeding_size = atoi(argv[i + 1]); 47 | if (strcmp(argv[i], "-train_file") == 0) train_file = argv[i 
+ 1]; 48 | if (strcmp(argv[i], "-vocab_file") == 0) read_vocab_file = argv[i + 1]; 49 | if (strcmp(argv[i], "-binary") == 0) output_binary = atoi(argv[i + 1]); 50 | if (strcmp(argv[i], "-init_learning_rate") == 0) init_learning_rate = static_cast(atof(argv[i + 1])); 51 | if (strcmp(argv[i], "-binary_embedding_file") == 0) binary_embedding_file = argv[i + 1]; 52 | if (strcmp(argv[i], "-text_embedding_file") == 0) text_embedding_file = argv[i + 1]; 53 | if (strcmp(argv[i], "-window") == 0) window_size = atoi(argv[i + 1]); 54 | if (strcmp(argv[i], "-data_block_size") == 0) data_block_size = atoi(argv[i + 1]); 55 | if (strcmp(argv[i], "-threads") == 0) thread_cnt = atoi(argv[i + 1]); 56 | if (strcmp(argv[i], "-min_count") == 0) min_count = atoi(argv[i + 1]); 57 | if (strcmp(argv[i], "-epoch") == 0) epoch = atoi(argv[i + 1]); 58 | if (strcmp(argv[i], "-stopwords") == 0) stopwords = atoi(argv[i + 1]) != 0; 59 | if (strcmp(argv[i], "-sw_file") == 0) sw_file = argv[i + 1]; 60 | if (strcmp(argv[i], "-num_servers") == 0) num_servers = atoi(argv[i + 1]); 61 | if (strcmp(argv[i], "-num_aggregator") == 0) num_aggregator = atoi(argv[i + 1]); 62 | if (strcmp(argv[i], "-lock_option") == 0) lock_option = atoi(argv[i + 1]); 63 | if (strcmp(argv[i], "-num_lock") == 0) num_lock = atoi(argv[i + 1]); 64 | if (strcmp(argv[i], "-max_delay") == 0) max_delay = atoi(argv[i + 1]); 65 | if (strcmp(argv[i], "-max_preload_size") == 0) max_preload_blocks_cnt = atoi(argv[i + 1]); 66 | if (strcmp(argv[i], "-is_pipline") == 0) pipline = atoi(argv[i + 1]) != 0; 67 | 68 | if (strcmp(argv[i], "-sense_num_multi") == 0) sense_num_multi = atoi(argv[i + 1]); 69 | if (strcmp(argv[i], "-momentum") == 0) init_sense_prior_momentum = static_cast(atof(argv[i + 1])); 70 | if (strcmp(argv[i], "-EM_iteration") == 0) EM_iteration = atoi(argv[i + 1]); 71 | if (strcmp(argv[i], "-store_multinomial") == 0) store_multinomial = atoi(argv[i + 1]) != 0; 72 | if (strcmp(argv[i], "-top_n") == 0) top_N = atoi(argv[i + 1]); 73 
| if (strcmp(argv[i], "-top_ratio") == 0) top_ratio = static_cast(atof(argv[i + 1])); 74 | if (strcmp(argv[i], "-read_sense") == 0) sense_file = argv[i + 1]; 75 | if (strcmp(argv[i], "-huff_tree_file") == 0) huff_tree_file = argv[i + 1]; 76 | if (strcmp(argv[i], "-outputlayer_binary_file") == 0) outputlayer_binary_file = argv[i + 1]; 77 | if (strcmp(argv[i], "-outputlayer_text_file") == 0) outputlayer_text_file = argv[i + 1]; 78 | } 79 | } 80 | 81 | void Option::PrintArgs() 82 | { 83 | printf("train_file: %s\n", train_file); 84 | printf("read_vocab_file: %s\n", read_vocab_file); 85 | printf("binary_embedding_file: %s\n", binary_embedding_file); 86 | printf("sw_file: %s\n", sw_file); 87 | printf("output_binary: %d\n", output_binary); 88 | printf("stopwords: %d\n", stopwords); 89 | printf("embeding_size: %d\n", embeding_size); 90 | printf("thread_cnt: %d\n", thread_cnt); 91 | printf("window_size: %d\n", window_size); 92 | printf("min_count: %d\n", min_count); 93 | printf("epoch: %d\n", epoch); 94 | printf("total_words: %lld\n", total_words); 95 | printf("init_learning_rate: %lf\n", init_learning_rate); 96 | printf("data_block_size: %d\n", data_block_size); 97 | printf("pre_load_data_blocks: %d\n", max_preload_blocks_cnt); 98 | printf("num_servers: %d\n", num_servers); 99 | printf("num_aggregator: %d\n", num_aggregator); 100 | printf("lock_option: %d\n", lock_option); 101 | printf("num_lock: %d\n", num_lock); 102 | printf("max_delay: %d\n", max_delay); 103 | printf("is_pipline:%d\n", pipline); 104 | printf("top_ratio: %lf\n", top_ratio); 105 | printf("top_N: %d\n", top_N); 106 | printf("store_multinomial: %d\n", store_multinomial); 107 | } 108 | 109 | //Check whether the user defined arguments are valid 110 | bool Option::CheckArgs() 111 | { 112 | if (!Util::IsFileExist(train_file)) 113 | { 114 | printf("Train corpus does not exist\n"); 115 | return false; 116 | } 117 | 118 | if (!Util::IsFileExist(read_vocab_file)) 119 | { 120 | printf("Vocab file does not exist\n"); 
121 | return false; 122 | } 123 | 124 | if (output_binary && (binary_embedding_file == NULL || outputlayer_binary_file == NULL)) 125 | { 126 | printf("Binary output file name not specified\n"); 127 | return false; 128 | } 129 | 130 | if (output_binary % 2 == 0 && (text_embedding_file == NULL || outputlayer_text_file == NULL)) 131 | { 132 | printf("Text output file name not specified\n"); 133 | return false; 134 | } 135 | 136 | if (huff_tree_file == NULL) 137 | { 138 | printf("Huffman tree file name not speficied\n"); 139 | return false; 140 | } 141 | 142 | if (stopwords && !Util::IsFileExist(sw_file)) 143 | { 144 | printf("Stop words file does not exist\n"); 145 | return false; 146 | } 147 | 148 | if (init_sense_prior_momentum < -eps || init_sense_prior_momentum >= 1) 149 | { 150 | printf("Init momentum %.4f out of range, must lie between 0.0 and 1.0\n", init_sense_prior_momentum); 151 | return false; 152 | } 153 | 154 | if (top_ratio < -eps || top_ratio >= 1) 155 | { 156 | printf("Top ratio %.4f out of range, must lie between 0.0 and 1.0\n", init_sense_prior_momentum); 157 | return false; 158 | } 159 | 160 | if (sense_num_multi > MAX_SENSE_CNT) 161 | { 162 | printf("Sense number is too big, the maximum value is 50\n"); 163 | return false; 164 | } 165 | 166 | if (fabs(static_cast(max_delay)) > eps) 167 | { 168 | printf("Warning: better set max_delay to 0!\n"); 169 | } 170 | 171 | return true; 172 | } 173 | 174 | bool Util::ValidF(const real &f) 175 | { 176 | return f < 1 || f >= 1; 177 | } 178 | -------------------------------------------------------------------------------- /src/util.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | typedef float real; 14 | 15 | #define MAX_STRING 100 16 | #define MAX_SENTENCE_LENGTH 2000 17 | #define MAX_EXP 6 18 | #define MAX_SENSE_CNT 50 19 | 
#define MIN_LOG -15 20 | 21 | const int table_size = (int)1e8; 22 | const real eps = (real)1e-8; 23 | 24 | struct WordSenseInfo 25 | { 26 | std::vector p_input_embedding; //Points to a word's row index in kInputEmbeddingTable 27 | std::unordered_map p_wordidx2sense_idx; //Map a word's idx to its row index in the table kWordSensePriorTableId 28 | 29 | std::vector word_sense_cnts_info; //Record every word's #sense count information 30 | int total_senses_cnt; 31 | int multi_senses_words_cnt; //Total number of words with multiple senses 32 | }; 33 | 34 | struct Option 35 | { 36 | const char* train_file; 37 | const char* read_vocab_file; 38 | const char* binary_embedding_file; 39 | const char* text_embedding_file; 40 | const char* sw_file; 41 | int output_binary, stopwords; 42 | int data_block_size; 43 | int embeding_size, thread_cnt, window_size, min_count, epoch; 44 | int64_t total_words; 45 | real init_learning_rate; 46 | int num_servers, num_aggregator, lock_option, num_lock, max_delay; 47 | bool pipline; 48 | int64_t max_preload_blocks_cnt; 49 | 50 | /*Multi sense config*/ 51 | int EM_iteration; 52 | int top_N; //The top top_N frequent words has multi senses, e.g. 500, 1000,... 53 | real top_ratio; // The top top_ratop frequent words has multi senses, e.g. 0.05, 0.1... 54 | int sense_num_multi; //Default number of senses for the multi_sense words 55 | real init_sense_prior_momentum; //Initial momentum, momentum is used in updating the sense priors 56 | bool store_multinomial; //Use multinomial parameters. If set to false, use the log of multinomial instead 57 | const char* sense_file; //The sense file storing (word, #sense) mapping 58 | const char* huff_tree_file; // The output file storing the huffman tree structure 59 | const char* outputlayer_binary_file; //The output binary file storing all the output embedding(i.e. the huffman node embedding) 60 | const char* outputlayer_text_file; //The output text file storing all the output embedding(i.e. 
the huffman node embedding) 61 | 62 | Option(); 63 | void ParseArgs(int argc, char* argv[]); 64 | void PrintArgs(); 65 | bool CheckArgs(); 66 | }; 67 | 68 | 69 | class Util 70 | { 71 | public: 72 | static void SaveVocab(); 73 | 74 | template 75 | static T InnerProduct(T* x, T* y, int length) 76 | { 77 | T result = 0; 78 | for (int i = 0; i < length; ++i) 79 | result += x[i] * y[i]; 80 | return result; 81 | } 82 | 83 | static bool ValidF(const real &f); 84 | 85 | template 86 | static T Sigmoid(T f) 87 | { 88 | if (f < -MAX_EXP) 89 | return 0; 90 | if (f > MAX_EXP) 91 | return 1; 92 | return 1 / (1 + exp(-f)); 93 | } 94 | 95 | template 96 | static void SoftMax(T* s, T* result, int size) 97 | { 98 | T sum = 0, max_v = s[0]; 99 | for (int j = 1; j < size; ++j) 100 | max_v = std::max(max_v, s[j]); 101 | for (int j = 0; j < size; ++j) 102 | sum += exp(s[j] - max_v); 103 | for (int j = 0; j < size; ++j) 104 | result[j] = exp(s[j] - max_v) / sum; 105 | } 106 | 107 | static bool IsFileExist(const char *fileName) 108 | { 109 | std::ifstream infile(fileName); 110 | return infile.good(); 111 | } 112 | 113 | }; 114 | 115 | -------------------------------------------------------------------------------- /windows/distributed_skipgram_mixture/distributed_skipgram_mixture.sln: -------------------------------------------------------------------------------- 1 |  2 | Microsoft Visual Studio Solution File, Format Version 12.00 3 | # Visual Studio 2013 4 | VisualStudioVersion = 12.0.40629.0 5 | MinimumVisualStudioVersion = 10.0.40219.1 6 | Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "distributed_skipgram_mixture", "distributed_skipgram_mixture.vcxproj", "{05DD6CF3-6096-44CE-A3C2-BBAE7A947DE8}" 7 | EndProject 8 | Global 9 | GlobalSection(SolutionConfigurationPlatforms) = preSolution 10 | Debug|Win32 = Debug|Win32 11 | Debug|x64 = Debug|x64 12 | Release|Win32 = Release|Win32 13 | Release|x64 = Release|x64 14 | EndGlobalSection 15 | GlobalSection(ProjectConfigurationPlatforms) = 
postSolution 16 | {05DD6CF3-6096-44CE-A3C2-BBAE7A947DE8}.Debug|Win32.ActiveCfg = Debug|Win32 17 | {05DD6CF3-6096-44CE-A3C2-BBAE7A947DE8}.Debug|Win32.Build.0 = Debug|Win32 18 | {05DD6CF3-6096-44CE-A3C2-BBAE7A947DE8}.Debug|x64.ActiveCfg = Debug|x64 19 | {05DD6CF3-6096-44CE-A3C2-BBAE7A947DE8}.Debug|x64.Build.0 = Debug|x64 20 | {05DD6CF3-6096-44CE-A3C2-BBAE7A947DE8}.Release|Win32.ActiveCfg = Release|Win32 21 | {05DD6CF3-6096-44CE-A3C2-BBAE7A947DE8}.Release|Win32.Build.0 = Release|Win32 22 | {05DD6CF3-6096-44CE-A3C2-BBAE7A947DE8}.Release|x64.ActiveCfg = Release|x64 23 | {05DD6CF3-6096-44CE-A3C2-BBAE7A947DE8}.Release|x64.Build.0 = Release|x64 24 | EndGlobalSection 25 | GlobalSection(SolutionProperties) = preSolution 26 | HideSolutionNode = FALSE 27 | EndGlobalSection 28 | EndGlobal 29 | -------------------------------------------------------------------------------- /windows/distributed_skipgram_mixture/distributed_skipgram_mixture.vcxproj: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | 5 | Debug 6 | Win32 7 | 8 | 9 | Debug 10 | x64 11 | 12 | 13 | Release 14 | Win32 15 | 16 | 17 | Release 18 | x64 19 | 20 | 21 | 22 | {05DD6CF3-6096-44CE-A3C2-BBAE7A947DE8} 23 | Win32Proj 24 | distributed_skipgram_mixture 25 | 26 | 27 | 28 | Application 29 | true 30 | v120 31 | Unicode 32 | 33 | 34 | Application 35 | true 36 | v120 37 | Unicode 38 | 39 | 40 | Application 41 | false 42 | v120 43 | true 44 | Unicode 45 | 46 | 47 | Application 48 | false 49 | v120 50 | true 51 | Unicode 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | true 71 | 72 | 73 | true 74 | $(SolutionDir)/../../../multiverso/include;$(VC_IncludePath);$(WindowsSDK_IncludePath); 75 | $(SolutionDir)/../../../multiverso/windows/x64/Debug;$(VC_LibraryPath_x64);$(WindowsSDK_LibraryPath_x64); 76 | 77 | 78 | false 79 | 80 | 81 | false 82 | $(SolutionDir)/../../../multiverso/include;$(VC_IncludePath);$(WindowsSDK_IncludePath); 83 | 
$(SolutionDir)/../../../multiverso/windows/x64/Release;$(VC_LibraryPath_x64);$(WindowsSDK_LibraryPath_x64); 84 | 85 | 86 | 87 | 88 | 89 | Level3 90 | Disabled 91 | WIN32;_DEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions) 92 | 93 | 94 | Console 95 | true 96 | 97 | 98 | 99 | 100 | 101 | 102 | Level3 103 | Disabled 104 | WIN32;_DEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions) 105 | 106 | 107 | Console 108 | true 109 | multiverso.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies) 110 | 111 | 112 | 113 | 114 | Level3 115 | 116 | 117 | MaxSpeed 118 | true 119 | true 120 | WIN32;NDEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions) 121 | 122 | 123 | Console 124 | true 125 | true 126 | true 127 | 128 | 129 | 130 | 131 | Level3 132 | 133 | 134 | MaxSpeed 135 | true 136 | true 137 | _CRT_SECURE_NO_WARNINGS;WIN32;NDEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions) 138 | 139 | 140 | Console 141 | true 142 | true 143 | true 144 | multiverso.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies) 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | --------------------------------------------------------------------------------