├── Makefile ├── README.md ├── src ├── args.cc ├── args.h ├── asvoid.h ├── dictionary.cc ├── dictionary.h ├── fasttext.cc ├── fasttext.h ├── main.cc ├── matrix.cc ├── matrix.h ├── model.cc ├── model.h ├── productquantizer.cc ├── productquantizer.h ├── qmatrix.cc ├── qmatrix.h ├── real.cc ├── real.h ├── sent2vec.pyx ├── shmem_matrix.cc ├── shmem_matrix.h ├── utils.cc ├── utils.h ├── vector.cc └── vector.h └── vectors_by_lang.py /Makefile: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2016-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. An additional grant 7 | # of patent rights can be found in the PATENTS file in the same directory. 8 | # 9 | 10 | CXX = c++ 11 | CXXFLAGS = -pthread -std=c++0x 12 | OBJS = args.o dictionary.o productquantizer.o matrix.o shmem_matrix.o qmatrix.o vector.o model.o utils.o fasttext.o 13 | INCLUDES = -I. 
14 | ifneq ($(shell uname),Darwin) 15 | LINK_RT := -lrt 16 | endif 17 | 18 | opt: CXXFLAGS += -O3 -funroll-loops 19 | opt: fasttext 20 | 21 | debug: CXXFLAGS += -g -O0 -fno-inline 22 | debug: fasttext 23 | 24 | args.o: src/args.cc src/args.h 25 | $(CXX) $(CXXFLAGS) -c src/args.cc 26 | 27 | dictionary.o: src/dictionary.cc src/dictionary.h src/args.h 28 | $(CXX) $(CXXFLAGS) -c src/dictionary.cc 29 | 30 | productquantizer.o: src/productquantizer.cc src/productquantizer.h src/utils.h 31 | $(CXX) $(CXXFLAGS) -c src/productquantizer.cc 32 | 33 | matrix.o: src/matrix.cc src/matrix.h src/utils.h 34 | $(CXX) $(CXXFLAGS) -c src/matrix.cc 35 | 36 | shmem_matrix.o: src/shmem_matrix.cc src/shmem_matrix.h 37 | $(CXX) $(CXXFLAGS) -c src/shmem_matrix.cc 38 | 39 | qmatrix.o: src/qmatrix.cc src/qmatrix.h src/utils.h 40 | $(CXX) $(CXXFLAGS) -c src/qmatrix.cc 41 | 42 | vector.o: src/vector.cc src/vector.h src/utils.h 43 | $(CXX) $(CXXFLAGS) -c src/vector.cc 44 | 45 | model.o: src/model.cc src/model.h src/args.h 46 | $(CXX) $(CXXFLAGS) -c src/model.cc 47 | 48 | utils.o: src/utils.cc src/utils.h 49 | $(CXX) $(CXXFLAGS) -c src/utils.cc 50 | 51 | fasttext.o: src/fasttext.cc src/*.h 52 | $(CXX) $(CXXFLAGS) -c src/fasttext.cc 53 | 54 | fasttext: $(OBJS) src/main.cc 55 | $(CXX) $(CXXFLAGS) $(OBJS) src/main.cc -o fasttext $(LINK_RT) 56 | 57 | clean: 58 | rm -rf *.o fasttext 59 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Bi-Sent2Vec 2 | 3 | TLDR: This library provides cross-lingual numerical representations (features) for words, short texts, or sentences, which can be used as input to any machine learning task with applications geared towards cross-lingual word translation, cross-lingual sentence retrieval as well as cross-lingual downstream NLP tasks. The library is a cross-lingual extension of [Sent2Vec](https://github.com/epfml/sent2vec). 
4 | 5 | Bi-Sent2Vec vectors are also well suited to monolingual tasks as indicated by a marked improvement in the monolingual quality of the word embeddings. (For more details, see [paper](https://arxiv.org/abs/1912.12481)) 6 | 7 | ### Table of Contents 8 | 9 | * [Setup and Requirements](#setup-and-requirements) 10 | * [Using the model](#using-the-model) 11 | - [Downloading Bi-Sent2Vec pre-trained vectors](#downloading-bi-sent2vec-pre-trained-vectors) 12 | - [Train a New Bi-Sent2Vec Model](#train-a-new-bi-sent2vec-model) 13 | * [Evaluation](#evaluation) 14 | * [References](#references) 15 | 16 | # Setup and Requirements 17 | 18 | Our code builds upon [Facebook's FastText library](https://github.com/facebookresearch/fastText). 19 | 20 | To compile the library, simply run the `make` command. 21 | 22 | # Using the model 23 | 24 | For the purpose of generating cross-lingual word and sentence representations, we introduce our Bi-Sent2vec method and provide code and models. 25 | 26 | The method uses a simple but efficient objective to train distributed representations of sentences. The algorithm outperforms the current state-of-the-art bag-of-words based models on most of the benchmark tasks, and is also competitive with deep models on some of the tasks, highlighting the robustness of the produced word and sentence embeddings, see [*the paper*](https://arxiv.org/abs/1912.12481) for more details. 27 | 28 | ## Downloading Bi-Sent2Vec pre-trained vectors 29 | 30 | Models trained and tested in the Bi-Sent2Vec paper can be downloaded from the following links. Users are encouraged to add more bi-lingual models to the list provided they have been benchmarked properly. 
31 | 32 | ### Unigram 33 | 34 | [EN](https://drive.google.com/file/d/1schNkg0OLTrTqA_VSCcpJnczaZbEfUiW/view?usp=sharing)-[DE](https://drive.google.com/file/d/1S76Pf_UByF9vHfGHx3EAP5bB3Vvyi8_l/view?usp=sharing) 35 | • 36 | [EN](https://drive.google.com/file/d/1b_q6WCXdQEKz0Grx21mzxVaGqBV7Y5WY/view?usp=sharing)-[ES](https://drive.google.com/file/d/1pEusR2238oJwLmRzC0j7pduaKW6FLzOv/view?usp=sharing) 37 | • 38 | [EN](https://drive.google.com/file/d/1Omac6Cbkb7cmyOeTZpyacGOKGHy9ixo8/view?usp=sharing)-[FI](https://drive.google.com/file/d/1rr_ZhDPjp901vGKUuK4gjXEM9aDBDOOD/view?usp=sharing) 39 | • 40 | [EN](https://drive.google.com/file/d/1Ny7TDW_1jRZTH327OhrGbSpIPgsSr3LJ/view?usp=sharing)-[FR](https://drive.google.com/file/d/1WTsLmVcjG_M8vwgUvfvM_A1386q7Nv0H/view?usp=sharing) 41 | • 42 | [EN](https://drive.google.com/file/d/1dPmM270pUTW2ETl14SfcFFI0hscEeQXO/view?usp=sharing)-[HU](https://drive.google.com/file/d/1aLe8CsB2o0fjmMTonYozcj0V69pPY59_/view?usp=sharing) 43 | • 44 | [EN](https://drive.google.com/file/d/1C6e-6qkhsoYjWlJ0OcjWeStOSUvQUOuQ/view?usp=sharing)-[IT](https://drive.google.com/file/d/1_rO75UgpZpug7kzjtho-jkigwl_igFqm/view?usp=sharing) 45 | 46 | ### Bigram 47 | 48 | [EN](https://drive.google.com/file/d/1CI4sFR0Y6v6zHzdaFN17vn9Wo_m1ywno/view?usp=sharing)-[DE](https://drive.google.com/file/d/1HyKS0QpBHd_2pLp_0JxGFDAMASydR5Xe/view?usp=sharing) 49 | • 50 | [EN](https://drive.google.com/file/d/1XKk4Vw4ATMcYAhmDl_nUx1HnpIcfuEwX/view?usp=sharing)-[ES](https://drive.google.com/file/d/1oJ2LXUk0CZzwj02sWIVZtsXH6psICOHl/view?usp=sharing) 51 | • 52 | [EN](https://drive.google.com/file/d/1q9dn76Sau3ArOEJ-J2mYghAnnoPjb9us/view?usp=sharing)-[FI](https://drive.google.com/file/d/1cqen99e_BNZp13wWGBpJf7x5b-PmHEKl/view?usp=sharing) 53 | • 54 | [EN](https://drive.google.com/file/d/1ztCsll3YUUBVMDHTZBPDmNMcZPkQbEu-/view?usp=sharing)-[FR](https://drive.google.com/file/d/1KWuuFpNDOmEXoLvwWU2OastlK5MyIke6/view?usp=sharing) 55 | • 56 | 
[EN](https://drive.google.com/file/d/15sMJQNm3s6uWh80y2SkxY6hCWnMx5pmx/view?usp=sharing)-[HU](https://drive.google.com/file/d/1K88rEsVM7mrcHZlqwHPvtjWy5JPkIvJ6/view?usp=sharing) 57 | • 58 | [EN](https://drive.google.com/file/d/1Iv-vuPWw40mvkbzfRJ7c0EIvvO_VRjT6/view?usp=sharing)-[IT](https://drive.google.com/file/d/1XXDGKQFscr_snJFzGc-aafUVRy_q6r0g/view?usp=sharing) 59 | 60 | 61 | ## Train a New Bi-Sent2Vec Model 62 | ### Tokenizing and data format 63 | Bi-Sent2Vec requires parallel sentences (sentences which are translations of each other) for training. 64 | We use [spacy](https://spacy.io/) tokenizer to tokenize the text. 65 | 66 | The required data format is one sentence pair per line. The two parallel sentences are separated by a \<\\> token and each word has its language code attached to it as a prefix. For example, here is an example of a snapshot of a valid English-French dataset - 67 | ``` 68 | the_en train_en is_en arriving_en ._en <> le_fr train_fr arrive_fr ._fr 69 | france_en won_en the_en world_en cup_en ._en <> la_fr france_fr a_fr gagné_fr la_fr coupe_fr du_fr monde_fr ._fr 70 | ``` 71 | 72 | ## Training 73 | 74 | Assuming en-fr_sentences.txt is the pre-processed training corpus, here is an example of a command to train a Bi-Sent2Vec model: 75 | 76 | ./fasttext bisent2vec -input en-fr_sentences.txt -output model-en-fr -dim 300 -lr 0.2 -neg 10 -bucket 2000000 -maxVocabSize 750000 -thread 30 -t 0.000005 -epoch 5 -minCount 8 -dropoutK 4 -loss ns -wordNgrams 2 -numCheckPoints 5 77 | 78 | Here is a description of all available arguments: 79 | 80 | ``` 81 | The following arguments are mandatory: 82 | -input training file path 83 | -output output file path (model is stored in the .bin file and the vectors in .vec file) 84 | 85 | The following arguments are optional: 86 | -lr learning rate [0.2] 87 | -lrUpdateRate change the rate of updates for the learning rate [100] 88 | -dim dimension of word and sentence vectors [100] 89 | -epoch number of epochs [5] 90 | 
-minCount minimal number of word occurrences [5] 91 | -minCountLabel minimal number of label occurrences [0] 92 | -neg number of negatives sampled [10] 93 | -wordNgrams max length of word ngram [2] 94 | -loss loss function {ns, hs, softmax} [ns] 95 | -bucket number of hash buckets for vocabulary [2000000] 96 | -thread number of threads [2] 97 | -t sampling threshold [0.0001] 98 | -dropoutK number of ngrams dropped when training a Bi-Sent2Vec model [2] 99 | -verbose verbosity level [2] 100 | -maxVocabSize vocabulary exceeding this size will be truncated [None] 101 | -numCheckPoints number of intermediary checkpoints to save when training [1] 102 | ``` 103 | ### Post Processing 104 | Use vectors_by_lang.py to separate the vectors for the two different languages. 105 | Example - 106 | ``` 107 | python vectors_by_lang.py model-en-fr.vec en fr 108 | ``` 109 | This code will create two files model-en-fr_en.vec and model-en-fr_fr.vec in word2vec format containing vectors for English and French respectively. 110 | 111 | # Evaluation 112 | Our models are evaluated using the standard evaluation tool in the [MUSE](https://github.com/facebookresearch/MUSE) repository by Facebook AI Research. 
113 | 114 | # References 115 | When using this code or some of our pretrained vectors for your application, please cite the following paper: 116 | 117 | Ali Sabet, Prakhar Gupta, Jean-Baptiste Cordonnier, Robert West, Martin Jaggi [*Robust Cross-lingual Embeddings from Parallel Sentences*](https://arxiv.org/abs/1912.12481) 118 | 119 | ``` 120 | @article{Sabet2019RobustCE, 121 | title={Robust Cross-lingual Embeddings from Parallel Sentences}, 122 | author={Ali Sabet and Prakhar Gupta and Jean-Baptiste Cordonnier and Robert West and Martin Jaggi}, 123 | journal={ArXiv 1912.12481}, 124 | year={2020}, 125 | } 126 | ``` 127 | -------------------------------------------------------------------------------- /src/args.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. An additional grant 7 | * of patent rights can be found in the PATENTS file in the same directory. 
8 | */ 9 | 10 | #include "args.h" 11 | 12 | #include 13 | #include 14 | 15 | #include 16 | 17 | namespace fasttext { 18 | 19 | Args::Args() { 20 | lr = 0.05; 21 | boostNgrams = 1.0; 22 | dim = 100; 23 | ws = 5; 24 | dropoutK = 0; 25 | epoch = 5; 26 | minCount = 5; 27 | minCountLabel = 0; 28 | neg = 5; 29 | wordNgrams = 1; 30 | loss = loss_name::ns; 31 | model = model_name::transgram; 32 | bucket = 2000000; 33 | bucketChar = 1; 34 | minn = 0; 35 | maxn = 0; 36 | thread = 12; 37 | lrUpdateRate = 100; 38 | t = 1e-4; 39 | label = "__label__"; 40 | verbose = 2; 41 | pretrainedVectors = ""; 42 | saveOutput = 0; 43 | maxVocabSize = -1; 44 | numCheckPoints = 1; 45 | qout = false; 46 | retrain = false; 47 | qnorm = false; 48 | cutoff = 0; 49 | dsub = 2; 50 | } 51 | 52 | void Args::parseArgs(int argc, char** argv) { 53 | std::string command(argv[1]); 54 | if (command == "bisent2vec") { 55 | model = model_name::bisent2vec; 56 | loss = loss_name::ns; 57 | neg = 10; 58 | minCount = 5; 59 | minn = 3; 60 | maxn = 6; 61 | lr = 0.05; 62 | dropoutK = 2; 63 | } 64 | int ai = 2; 65 | while (ai < argc) { 66 | if (argv[ai][0] != '-') { 67 | std::cout << "Provided argument without a dash! Usage:" << std::endl; 68 | printHelp(); 69 | exit(EXIT_FAILURE); 70 | } 71 | if (strcmp(argv[ai], "-h") == 0) { 72 | std::cout << "Here is the help! 
Usage:" << std::endl; 73 | printHelp(); 74 | exit(EXIT_FAILURE); 75 | } else if (strcmp(argv[ai], "-input") == 0) { 76 | input = std::string(argv[ai + 1]); 77 | } else if (strcmp(argv[ai], "-test") == 0) { 78 | test = std::string(argv[ai + 1]); 79 | } else if (strcmp(argv[ai], "-output") == 0) { 80 | output = std::string(argv[ai + 1]); 81 | } else if (strcmp(argv[ai], "-dict") == 0) { 82 | dict = std::string(argv[ai + 1]); 83 | } else if (strcmp(argv[ai], "-lr") == 0) { 84 | lr = atof(argv[ai + 1]); 85 | } else if (strcmp(argv[ai], "-boostNgrams") == 0) { 86 | boostNgrams = atof(argv[ai + 1]); 87 | } else if (strcmp(argv[ai], "-lrUpdateRate") == 0) { 88 | lrUpdateRate = atoi(argv[ai + 1]); 89 | } else if (strcmp(argv[ai], "-dim") == 0) { 90 | dim = atoi(argv[ai + 1]); 91 | } else if (strcmp(argv[ai], "-ws") == 0) { 92 | ws = atoi(argv[ai + 1]); 93 | } else if (strcmp(argv[ai], "-epoch") == 0) { 94 | epoch = atoi(argv[ai + 1]); 95 | } else if (strcmp(argv[ai], "-minCount") == 0) { 96 | minCount = atoi(argv[ai + 1]); 97 | } else if (strcmp(argv[ai], "-minCountLabel") == 0) { 98 | minCountLabel = atoi(argv[ai + 1]); 99 | } else if (strcmp(argv[ai], "-neg") == 0) { 100 | neg = atoi(argv[ai + 1]); 101 | } else if (strcmp(argv[ai], "-numCheckPoints") == 0) { 102 | numCheckPoints = atoi(argv[ai + 1]); 103 | } else if (strcmp(argv[ai], "-dropoutK") == 0) { 104 | dropoutK = atoi(argv[ai + 1]); 105 | } else if (strcmp(argv[ai], "-wordNgrams") == 0) { 106 | wordNgrams = atoi(argv[ai + 1]); 107 | if (wordNgrams == 1) bucket = 1; 108 | } else if (strcmp(argv[ai], "-loss") == 0) { 109 | if (strcmp(argv[ai + 1], "hs") == 0) { 110 | loss = loss_name::hs; 111 | } else if (strcmp(argv[ai + 1], "ns") == 0) { 112 | loss = loss_name::ns; 113 | } else if (strcmp(argv[ai + 1], "softmax") == 0) { 114 | loss = loss_name::softmax; 115 | } else { 116 | std::cout << "Unknown loss: " << argv[ai + 1] << std::endl; 117 | printHelp(); 118 | exit(EXIT_FAILURE); 119 | } 120 | } else if 
(strcmp(argv[ai], "-bucket") == 0) { 121 | bucket = atoi(argv[ai + 1]); 122 | } else if (strcmp(argv[ai], "-bucketChar") == 0) { 123 | bucketChar = atoi(argv[ai + 1]); 124 | } else if (strcmp(argv[ai], "-minn") == 0) { 125 | minn = atoi(argv[ai + 1]); 126 | } else if (strcmp(argv[ai], "-maxn") == 0) { 127 | maxn = atoi(argv[ai + 1]); 128 | } else if (strcmp(argv[ai], "-thread") == 0) { 129 | thread = atoi(argv[ai + 1]); 130 | } else if (strcmp(argv[ai], "-t") == 0) { 131 | t = atof(argv[ai + 1]); 132 | } else if (strcmp(argv[ai], "-label") == 0) { 133 | label = std::string(argv[ai + 1]); 134 | } else if (strcmp(argv[ai], "-verbose") == 0) { 135 | verbose = atoi(argv[ai + 1]); 136 | } else if (strcmp(argv[ai], "-maxVocabSize") == 0) { 137 | maxVocabSize = atoi(argv[ai + 1]); 138 | } else if (strcmp(argv[ai], "-pretrainedVectors") == 0) { 139 | pretrainedVectors = std::string(argv[ai + 1]); 140 | } else if (strcmp(argv[ai], "-saveOutput") == 0) { 141 | saveOutput = atoi(argv[ai + 1]); 142 | } else if (strcmp(argv[ai], "-qnorm") == 0) { 143 | qnorm = true; ai--; 144 | } else if (strcmp(argv[ai], "-retrain") == 0) { 145 | retrain = true; ai--; 146 | } else if (strcmp(argv[ai], "-qout") == 0) { 147 | qout = true; ai--; 148 | } else if (strcmp(argv[ai], "-cutoff") == 0) { 149 | cutoff = atoi(argv[ai + 1]); 150 | } else if (strcmp(argv[ai], "-dsub") == 0) { 151 | dsub = atoi(argv[ai + 1]); 152 | } else { 153 | std::cout << "Unknown argument: " << argv[ai] << std::endl; 154 | printHelp(); 155 | exit(EXIT_FAILURE); 156 | } 157 | ai += 2; 158 | } 159 | if (input.empty() || output.empty()) { 160 | std::cout << "Empty input or output path." 
<< std::endl; 161 | printHelp(); 162 | exit(EXIT_FAILURE); 163 | } 164 | if (wordNgrams <= 1 && maxn == 0) { 165 | bucket = 0; 166 | } 167 | } 168 | 169 | void Args::printHelp() { 170 | std::string lname = "ns"; 171 | if (loss == loss_name::hs) lname = "hs"; 172 | if (loss == loss_name::softmax) lname = "softmax"; 173 | std::cout 174 | << "\n" 175 | << "The following arguments are mandatory:\n" 176 | << " -input training file path\n" 177 | << " -output output file path\n\n" 178 | << "The following arguments are optional:\n" 179 | << " -lr learning rate [" << lr << "]\n" 180 | << " -lrUpdateRate change the rate of updates for the learning rate [" << lrUpdateRate << "]\n" 181 | << " -dim size of word vectors [" << dim << "]\n" 182 | << " -ws size of the context window [" << ws << "]\n" 183 | << " -epoch number of epochs [" << epoch << "]\n" 184 | << " -minCount minimal number of word occurences [" << minCount << "]\n" 185 | << " -minCountLabel minimal number of label occurences [" << minCountLabel << "]\n" 186 | << " -neg number of negatives sampled [" << neg << "]\n" 187 | << " -wordNgrams max length of word ngram [" << wordNgrams << "]\n" 188 | << " -loss loss function {ns, hs, softmax} [ns]\n" 189 | << " -bucket number of buckets [" << bucket << "]\n" 190 | << " -maxVocabSize vocabulary exceeding this size will be truncated [None]\n" 191 | << " -numCheckPoints number of intermediary checkpoints to save when training [" << numCheckPoints << "]\n" 192 | << " -minn min length of char ngram [" << minn << "]\n" 193 | << " -maxn max length of char ngram [" << maxn << "]\n" 194 | << " -thread number of threads [" << thread << "]\n" 195 | << " -t sampling threshold [" << t << "]\n" 196 | << " -label labels prefix [" << label << "]\n" 197 | << " -dropoutK number of ngrams dropped when training a sent2vec model [" << dropoutK << "]\n" 198 | << " -verbose verbosity level [" << verbose << "]\n" 199 | << " -pretrainedVectors pretrained word vectors for supervised learning 
[]\n" 200 | << " -saveOutput whether output params should be saved [" << saveOutput << "]\n" 201 | << "\nThe following arguments for quantization are optional:\n" 202 | << " -cutoff number of words and ngrams to retain [" << cutoff << "]\n" 203 | << " -retrain finetune embeddings if a cutoff is applied [" << retrain << "]\n" 204 | << " -qnorm quantizing the norm separately [" << qnorm << "]\n" 205 | << " -qout quantizing the classifier [" << qout << "]\n" 206 | << " -dsub size of each sub-vector [" << dsub << "]\n" 207 | << std::endl; 208 | } 209 | 210 | void Args::save(std::ostream& out) { 211 | out.write((char*) &(dim), sizeof(int)); 212 | out.write((char*) &(ws), sizeof(int)); 213 | out.write((char*) &(epoch), sizeof(int)); 214 | out.write((char*) &(minCount), sizeof(int)); 215 | out.write((char*) &(neg), sizeof(int)); 216 | out.write((char*) &(wordNgrams), sizeof(int)); 217 | out.write((char*) &(loss), sizeof(loss_name)); 218 | out.write((char*) &(model), sizeof(model_name)); 219 | out.write((char*) &(bucket), sizeof(int)); 220 | out.write((char*) &(minn), sizeof(int)); 221 | out.write((char*) &(maxn), sizeof(int)); 222 | out.write((char*) &(lrUpdateRate), sizeof(int)); 223 | out.write((char*) &(t), sizeof(double)); 224 | } 225 | 226 | void Args::load(std::istream& in) { 227 | in.read((char*) &(dim), sizeof(int)); 228 | in.read((char*) &(ws), sizeof(int)); 229 | in.read((char*) &(epoch), sizeof(int)); 230 | in.read((char*) &(minCount), sizeof(int)); 231 | in.read((char*) &(neg), sizeof(int)); 232 | in.read((char*) &(wordNgrams), sizeof(int)); 233 | in.read((char*) &(loss), sizeof(loss_name)); 234 | in.read((char*) &(model), sizeof(model_name)); 235 | in.read((char*) &(bucket), sizeof(int)); 236 | in.read((char*) &(minn), sizeof(int)); 237 | in.read((char*) &(maxn), sizeof(int)); 238 | in.read((char*) &(lrUpdateRate), sizeof(int)); 239 | in.read((char*) &(t), sizeof(double)); 240 | } 241 | 242 | } 243 | 
-------------------------------------------------------------------------------- /src/args.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. An additional grant 7 | * of patent rights can be found in the PATENTS file in the same directory. 8 | */ 9 | 10 | #ifndef FASTTEXT_ARGS_H 11 | #define FASTTEXT_ARGS_H 12 | 13 | #include 14 | #include 15 | #include 16 | 17 | namespace fasttext { 18 | 19 | enum class model_name : int {cbow=1, transgram, supi, bisent2vec}; 20 | enum class loss_name : int {hs=1, ns, softmax}; 21 | 22 | class Args { 23 | public: 24 | Args(); 25 | std::string input; 26 | std::string test; 27 | std::string output; 28 | std::string dict; 29 | double lr; 30 | double boostNgrams; 31 | int lrUpdateRate; 32 | int dim; 33 | int ws; 34 | int dropoutK; 35 | int epoch; 36 | int maxVocabSize; 37 | int minCount; 38 | int minCountLabel; 39 | int neg; 40 | int wordNgrams; 41 | int numCheckPoints; 42 | loss_name loss; 43 | model_name model; 44 | int bucket; 45 | int bucketChar; 46 | int minn; 47 | int maxn; 48 | int thread; 49 | double t; 50 | std::string label; 51 | int verbose; 52 | std::string pretrainedVectors; 53 | int saveOutput; 54 | 55 | bool qout; 56 | bool retrain; 57 | bool qnorm; 58 | size_t cutoff; 59 | size_t dsub; 60 | 61 | void parseArgs(int, char**); 62 | void printHelp(); 63 | void save(std::ostream&); 64 | void load(std::istream&); 65 | }; 66 | 67 | } 68 | 69 | #endif 70 | -------------------------------------------------------------------------------- /src/asvoid.h: -------------------------------------------------------------------------------- 1 | template 2 | inline void *asvoid(std::vector *buf) 3 | { 4 | std::vector& tmp = *buf; 5 | return (void*)(&tmp[0]); 6 | } 
-------------------------------------------------------------------------------- /src/dictionary.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. An additional grant 7 | * of patent rights can be found in the PATENTS file in the same directory. 8 | */ 9 | 10 | #include "dictionary.h" 11 | 12 | #include 13 | 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | 20 | namespace fasttext { 21 | 22 | const std::string Dictionary::EOS = ""; 23 | const std::string Dictionary::BOW = "<"; 24 | const std::string Dictionary::EOW = ">"; 25 | 26 | Dictionary::Dictionary(std::shared_ptr args) : args_(args), 27 | word2int_(MAX_VOCAB_SIZE, -1), size_(0), nwords_(0), nlabels_(0), 28 | ntokens_(0) {} 29 | 30 | int32_t Dictionary::find(const std::string& w) const { 31 | int32_t h = hash(w) % MAX_VOCAB_SIZE; 32 | while (word2int_[h] != -1 && words_[word2int_[h]].word != w) { 33 | h = (h + 1) % MAX_VOCAB_SIZE; 34 | } 35 | return h; 36 | } 37 | 38 | void Dictionary::add(const std::string& w) { 39 | int32_t h = find(w); 40 | ntokens_++; 41 | if (word2int_[h] == -1) { 42 | entry e; 43 | e.word = w; 44 | e.count = 1; 45 | e.type = getType(w); 46 | words_.push_back(e); 47 | word2int_[h] = size_++; 48 | } else { 49 | words_[word2int_[h]].count++; 50 | } 51 | } 52 | 53 | int32_t Dictionary::nwords() const { 54 | return nwords_; 55 | } 56 | 57 | int32_t Dictionary::nlabels() const { 58 | return nlabels_; 59 | } 60 | 61 | int64_t Dictionary::ntokens() const { 62 | return ntokens_; 63 | } 64 | 65 | const std::vector& Dictionary::getNgrams(int32_t i) const { 66 | assert(i >= 0); 67 | assert(i < nwords_); 68 | return words_[i].subwords; 69 | } 70 | 71 | const std::vector Dictionary::getNgrams(const std::string& word) 
const { 72 | int32_t i = getId(word); 73 | if (i >= 0) { 74 | return getNgrams(i); 75 | } 76 | std::vector ngrams; 77 | computeNgrams(BOW + word + EOW, ngrams); 78 | return ngrams; 79 | } 80 | 81 | void Dictionary::getNgrams(const std::string& word, 82 | std::vector& ngrams, 83 | std::vector& substrings) const { 84 | int32_t i = getId(word); 85 | ngrams.clear(); 86 | substrings.clear(); 87 | if (i >= 0) { 88 | ngrams.push_back(i); 89 | substrings.push_back(words_[i].word); 90 | } else { 91 | ngrams.push_back(-1); 92 | substrings.push_back(word); 93 | } 94 | computeNgrams(BOW + word + EOW, ngrams, substrings); 95 | } 96 | 97 | bool Dictionary::discard(int32_t id, real rand) const { 98 | assert(id >= 0); 99 | assert(id < nwords_); 100 | if (args_->model == model_name::bisent2vec) return false; 101 | return rand > pdiscard_[id]; 102 | } 103 | 104 | int32_t Dictionary::getId(const std::string& w) const { 105 | int32_t h = find(w); 106 | return word2int_[h]; 107 | } 108 | 109 | entry_type Dictionary::getType(int32_t id) const { 110 | assert(id >= 0); 111 | assert(id < size_); 112 | return words_[id].type; 113 | } 114 | 115 | entry_type Dictionary::getType(const std::string& w) const { 116 | return (w.find(args_->label) == 0) ? 
entry_type::label : entry_type::word; 117 | } 118 | 119 | std::string Dictionary::getWord(int32_t id) const { 120 | assert(id >= 0); 121 | assert(id < size_); 122 | 123 | return words_[id].word; 124 | } 125 | 126 | std::vector Dictionary::getVocab() const { 127 | std::vector vocab; 128 | for (auto& w : words_) { 129 | if (w.type == entry_type::word) vocab.push_back(w.word); 130 | } 131 | return vocab; 132 | } 133 | 134 | int64_t Dictionary::getTokenCount(int32_t id) const { 135 | assert(id >= 0); 136 | assert(id < size_); 137 | return words_[id].count; 138 | } 139 | 140 | real Dictionary::getPDiscard(int32_t id) const { 141 | assert(id >= 0); 142 | assert(id < size_); 143 | return pdiscard_[id]; 144 | } 145 | 146 | uint32_t Dictionary::hash(const std::string& str) const { 147 | uint32_t h = 2166136261; 148 | for (size_t i = 0; i < str.size(); i++) { 149 | h = h ^ uint32_t(str[i]); 150 | h = h * 16777619; 151 | } 152 | return h; 153 | } 154 | 155 | void Dictionary::computeNgrams(const std::string& word, 156 | std::vector& ngrams, 157 | std::vector& substrings) const { 158 | for (size_t i = 0; i < word.size(); i++) { 159 | std::string ngram; 160 | if ((word[i] & 0xC0) == 0x80) continue; 161 | for (size_t j = i, n = 1; j < word.size() && n <= args_->maxn; n++) { 162 | ngram.push_back(word[j++]); 163 | while (j < word.size() && (word[j] & 0xC0) == 0x80) { 164 | ngram.push_back(word[j++]); 165 | } 166 | if (n >= args_->minn && !(n == 1 && (i == 0 || j == word.size()))) { 167 | int32_t h = hash(ngram) % args_->bucket; 168 | ngrams.push_back(nwords_ + h); 169 | substrings.push_back(ngram); 170 | } 171 | } 172 | } 173 | } 174 | 175 | void Dictionary::computeNgrams(const std::string& word, 176 | std::vector& ngrams) const { 177 | for (size_t i = 0; i < word.size(); i++) { 178 | std::string ngram; 179 | if ((word[i] & 0xC0) == 0x80) continue; 180 | for (size_t j = i, n = 1; j < word.size() && n <= args_->maxn; n++) { 181 | ngram.push_back(word[j++]); 182 | while (j < 
word.size() && (word[j] & 0xC0) == 0x80) { 183 | ngram.push_back(word[j++]); 184 | } 185 | if (n >= args_->minn && !(n == 1 && (i == 0 || j == word.size()))) { 186 | int32_t h = hash(ngram) % args_->bucket; 187 | ngrams.push_back(nwords_ + h); 188 | } 189 | } 190 | } 191 | } 192 | 193 | void Dictionary::initNgrams() { 194 | for (size_t i = 0; i < size_; i++) { 195 | std::string word = BOW + words_[i].word + EOW; 196 | words_[i].subwords.push_back(i); 197 | computeNgrams(word, words_[i].subwords); 198 | } 199 | } 200 | 201 | bool Dictionary::readWord(std::istream& in, std::string& word) const 202 | { 203 | char c; 204 | std::streambuf& sb = *in.rdbuf(); 205 | word.clear(); 206 | while ((c = sb.sbumpc()) != EOF) { 207 | if (c == ' ' || c == '\n' || c == '\r' || c == '\t' || c == '\v' || 208 | c == '\f' || c == '\0') { 209 | if (word.empty()) { 210 | if (c == '\n') { 211 | word += EOS; 212 | return true; 213 | } 214 | continue; 215 | } else { 216 | if (c == '\n') 217 | sb.sungetc(); 218 | return true; 219 | } 220 | } 221 | word.push_back(c); 222 | } 223 | // trigger eofbit 224 | in.get(); 225 | return !word.empty(); 226 | } 227 | 228 | void Dictionary::readFromFile(std::istream& in) { 229 | std::string word; 230 | int64_t minThreshold = 1; 231 | while (readWord(in, word)) { 232 | add(word); 233 | if (ntokens_ % 1000000 == 0 && args_->verbose > 1) { 234 | std::cerr << "\rRead " << ntokens_ / 1000000 << "M words" << std::flush; 235 | } 236 | if (size_ > 0.75 * MAX_VOCAB_SIZE) { 237 | minThreshold++; 238 | threshold(minThreshold, minThreshold); 239 | } 240 | } 241 | if (args_->model == model_name::bisent2vec) { 242 | int32_t h = find(""); 243 | entry e; 244 | e.word = ""; 245 | e.count = 1e+18; 246 | e.type = entry_type::word; 247 | words_.push_back(e); 248 | word2int_[h] = size_++; 249 | } 250 | threshold(args_->minCount, args_->minCountLabel); 251 | if (args_->maxVocabSize > 0) { 252 | truncate(args_->maxVocabSize); 253 | } 254 | initTableDiscard(); 255 | initNgrams(); 
256 | if (args_->model == model_name::bisent2vec) { 257 | assert(words_[0].word == ""); 258 | words_[0].count = 0; 259 | } 260 | if (args_->verbose > 0) { 261 | std::cerr << "\rRead " << ntokens_ / 1000000 << "M words" << std::endl; 262 | std::cerr << "Number of words: " << nwords_ << std::endl; 263 | std::cerr << "Number of labels: " << nlabels_ << std::endl; 264 | } 265 | if (size_ == 0) { 266 | std::cerr << "Empty vocabulary. Try a smaller -minCount value." 267 | << std::endl; 268 | exit(EXIT_FAILURE); 269 | } 270 | } 271 | 272 | void Dictionary::threshold(int64_t t, int64_t tl) { 273 | sort(words_.begin(), words_.end(), [](const entry& e1, const entry& e2) { 274 | if (e1.type != e2.type) return e1.type < e2.type; 275 | return e1.count > e2.count; 276 | }); 277 | words_.erase(remove_if(words_.begin(), words_.end(), [&](const entry& e) { 278 | return (e.type == entry_type::word && e.count < t) || 279 | (e.type == entry_type::label && e.count < tl); 280 | }), words_.end()); 281 | words_.shrink_to_fit(); 282 | size_ = 0; 283 | nwords_ = 0; 284 | nlabels_ = 0; 285 | std::fill(word2int_.begin(), word2int_.end(), -1); 286 | for (auto it = words_.begin(); it != words_.end(); ++it) { 287 | int32_t h = find(it->word); 288 | word2int_[h] = size_++; 289 | if (it->type == entry_type::word) nwords_++; 290 | if (it->type == entry_type::label) nlabels_++; 291 | } 292 | } 293 | 294 | void Dictionary::truncate(int64_t maxVocabSize) { 295 | if (maxVocabSize >= words_.size()) 296 | return; 297 | sort(words_.begin(), words_.end(), [](const entry& e1, const entry& e2) { 298 | if (e1.type != e2.type) return e1.type < e2.type; 299 | return e1.count > e2.count; 300 | }); 301 | words_.resize(maxVocabSize); 302 | words_.shrink_to_fit(); 303 | size_ = 0; 304 | nwords_ = 0; 305 | nlabels_ = 0; 306 | std::fill(word2int_.begin(), word2int_.end(), -1); 307 | for (auto it = words_.begin(); it != words_.end(); ++it) { 308 | int32_t h = find(it->word); 309 | word2int_[h] = size_++; 310 | if 
(it->type == entry_type::word) nwords_++; 311 | if (it->type == entry_type::label) nlabels_++; 312 | } 313 | } 314 | 315 | void Dictionary::initTableDiscard() { 316 | pdiscard_.resize(size_); 317 | for (size_t i = 0; i < size_; i++) { 318 | real f = real(words_[i].count) / real(ntokens_); 319 | pdiscard_[i] = std::sqrt(args_->t / f) + args_->t / f; 320 | } 321 | } 322 | 323 | std::vector Dictionary::getCounts(entry_type type) const { 324 | std::vector counts; 325 | for (auto& w : words_) { 326 | if (w.type == type) counts.push_back(w.count); 327 | } 328 | return counts; 329 | } 330 | 331 | void Dictionary::addNgrams(std::vector& line, 332 | const std::vector& hashes, 333 | int32_t n) const { 334 | if (pruneidx_size_ == 0) return; 335 | for (int32_t i = 0; i < hashes.size(); i++) { 336 | uint64_t h = hashes[i]; 337 | for (int32_t j = i + 1; j < hashes.size() && j < i + n; j++) { 338 | h = h * 116049371 + hashes[j]; 339 | int64_t id = h % args_->bucket; 340 | if (pruneidx_size_ > 0) { 341 | if (pruneidx_.count(id)) { 342 | id = pruneidx_.at(id); 343 | } else {continue;} 344 | } 345 | line.push_back(nwords_ + id); 346 | } 347 | } 348 | } 349 | 350 | void Dictionary::addNgrams(std::vector& out, std::vector& line, int32_t start, int32_t end, int32_t n, int32_t k, std::minstd_rand& rng) { 351 | int32_t num_discarded = 0; 352 | int32_t line_size = end - start; 353 | std::vector discard; 354 | discard.resize(line_size, false); 355 | std::uniform_int_distribution<> uniform(0, line_size - 1); 356 | while (num_discarded < k && line_size - num_discarded > 2) { 357 | int32_t token_to_discard = uniform(rng); 358 | if (!discard[token_to_discard]) { 359 | discard[token_to_discard] = true; 360 | num_discarded++; 361 | } 362 | } 363 | for (int32_t i = start; i <= end; i++) { 364 | if (discard[i - start]) continue; 365 | uint64_t h = line[i]; 366 | for (int32_t j = i + 1; j <= end && j < i + n; j++) { 367 | if (discard[j - start]) break; 368 | h = h * 116049371 + line[j]; 369 | 
out.push_back(nwords_ + (h % args_->bucket)); 370 | } 371 | } 372 | 373 | } 374 | 375 | void Dictionary::addNgrams(std::vector& line, int32_t n, int32_t k, std::minstd_rand& rng) const { 376 | int32_t num_discarded = 0; 377 | int32_t line_size = line.size(); 378 | std::vector discard; 379 | discard.resize(line_size, false); 380 | std::uniform_int_distribution<> uniform(0, line_size - 1); 381 | while (num_discarded < k && line_size - num_discarded > 2) { 382 | int32_t token_to_discard = uniform(rng); 383 | if (!discard[token_to_discard]) { 384 | discard[token_to_discard] = true; 385 | num_discarded++; 386 | } 387 | } 388 | for (int32_t i = 0; i < line_size; i++) { 389 | if (discard[i]) continue; 390 | uint64_t h = line[i]; 391 | for (int32_t j = i + 1; j < line_size && j < i + n; j++) { 392 | if (discard[j]) break; 393 | h = h * 116049371 + line[j]; 394 | line.push_back(nwords_ + (h % args_->bucket)); 395 | } 396 | } 397 | } 398 | 399 | void Dictionary::addNgrams(std::vector& line, int32_t n) const { 400 | int32_t line_size = line.size(); 401 | for (int32_t i = 0; i < line_size; i++) { 402 | uint64_t h = line[i]; 403 | for (int32_t j = i + 1; j < line_size && j < i + n; j++) { 404 | h = h * 116049371 + line[j]; 405 | line.push_back(nwords_ + (h % args_->bucket)); 406 | } 407 | } 408 | } 409 | 410 | int32_t Dictionary::getLine(std::istream& in, 411 | std::vector& words, 412 | std::vector& word_hashes, 413 | std::vector& labels, 414 | std::minstd_rand& rng) const { 415 | std::uniform_real_distribution<> uniform(0, 1); 416 | 417 | if (in.eof()) { 418 | in.clear(); 419 | in.seekg(std::streampos(0)); 420 | } 421 | 422 | words.clear(); 423 | labels.clear(); 424 | word_hashes.clear(); 425 | int32_t ntokens = 0; 426 | std::string token; 427 | while (readWord(in, token)) { 428 | if (token == EOS && (args_-> model == model_name::bisent2vec)){ 429 | break; 430 | } 431 | int32_t h = find(token); 432 | int32_t wid = word2int_[h]; 433 | if (wid < 0) { 434 | entry_type type = 
getType(token); 435 | if (type == entry_type::word) word_hashes.push_back(hash(token)); 436 | continue; 437 | } 438 | entry_type type = getType(wid); 439 | ntokens++; 440 | if (type == entry_type::word && !discard(wid, uniform(rng))) { 441 | words.push_back(wid); 442 | word_hashes.push_back(hash(token)); 443 | } 444 | if (type == entry_type::label) { 445 | labels.push_back(wid - nwords_); 446 | } 447 | if (token == EOS) break; 448 | if (ntokens > MAX_LINE_SIZE && args_->model != model_name::bisent2vec) break; 449 | } 450 | return ntokens; 451 | } 452 | 453 | int32_t Dictionary::getLine(std::istream& in, 454 | std::vector& words, 455 | std::vector& labels, 456 | std::minstd_rand& rng) const { 457 | std::vector word_hashes; 458 | int32_t ntokens = getLine(in, words, word_hashes, labels, rng); 459 | 460 | return ntokens; 461 | } 462 | 463 | std::string Dictionary::getLabel(int32_t lid) const { 464 | assert(lid >= 0); 465 | assert(lid < nlabels_); 466 | return words_[lid + nwords_].word; 467 | } 468 | 469 | void Dictionary::save(std::ostream& out) const { 470 | out.write((char*) &size_, sizeof(int32_t)); 471 | out.write((char*) &nwords_, sizeof(int32_t)); 472 | out.write((char*) &nlabels_, sizeof(int32_t)); 473 | out.write((char*) &ntokens_, sizeof(int64_t)); 474 | out.write((char*) &pruneidx_size_, sizeof(int64_t)); 475 | for (int32_t i = 0; i < size_; i++) { 476 | entry e = words_[i]; 477 | out.write(e.word.data(), e.word.size() * sizeof(char)); 478 | out.put(0); 479 | out.write((char*) &(e.count), sizeof(int64_t)); 480 | out.write((char*) &(e.type), sizeof(entry_type)); 481 | } 482 | for (const auto pair : pruneidx_) { 483 | out.write((char*) &(pair.first), sizeof(int32_t)); 484 | out.write((char*) &(pair.second), sizeof(int32_t)); 485 | } 486 | } 487 | 488 | void Dictionary::load(std::istream& in) { 489 | words_.clear(); 490 | std::fill(word2int_.begin(), word2int_.end(), -1); 491 | in.read((char*) &size_, sizeof(int32_t)); 492 | in.read((char*) &nwords_, 
sizeof(int32_t)); 493 | in.read((char*) &nlabels_, sizeof(int32_t)); 494 | in.read((char*) &ntokens_, sizeof(int64_t)); 495 | in.read((char*) &pruneidx_size_, sizeof(int64_t)); 496 | for (int32_t i = 0; i < size_; i++) { 497 | char c; 498 | entry e; 499 | while ((c = in.get()) != 0) { 500 | e.word.push_back(c); 501 | } 502 | in.read((char*) &e.count, sizeof(int64_t)); 503 | in.read((char*) &e.type, sizeof(entry_type)); 504 | words_.push_back(e); 505 | word2int_[find(e.word)] = i; 506 | } 507 | pruneidx_.clear(); 508 | for (int32_t i = 0; i < pruneidx_size_; i++) { 509 | int32_t first; 510 | int32_t second; 511 | in.read((char*) &first, sizeof(int32_t)); 512 | in.read((char*) &second, sizeof(int32_t)); 513 | pruneidx_[first] = second; 514 | } 515 | initTableDiscard(); 516 | initNgrams(); 517 | } 518 | 519 | void Dictionary::prune(std::vector& idx) { 520 | std::vector words, ngrams; 521 | for (auto it = idx.cbegin(); it != idx.cend(); ++it) { 522 | if (*it < nwords_) {words.push_back(*it);} 523 | else {ngrams.push_back(*it);} 524 | } 525 | std::sort(words.begin(), words.end()); 526 | idx = words; 527 | 528 | if (ngrams.size() != 0) { 529 | int32_t j = 0; 530 | for (const auto ngram : ngrams) { 531 | pruneidx_[ngram - nwords_] = j; 532 | j++; 533 | } 534 | idx.insert(idx.end(), ngrams.begin(), ngrams.end()); 535 | } 536 | pruneidx_size_ = pruneidx_.size(); 537 | 538 | std::fill(word2int_.begin(), word2int_.end(), -1); 539 | 540 | int32_t j = 0; 541 | for (int32_t i = 0; i < words_.size(); i++) { 542 | if (getType(i) == entry_type::label || (j < words.size() && words[j] == i)) { 543 | words_[j] = words_[i]; 544 | word2int_[find(words_[j].word)] = j; 545 | j++; 546 | } 547 | } 548 | nwords_ = words.size(); 549 | size_ = nwords_ + nlabels_; 550 | words_.erase(words_.begin() + size_, words_.end()); 551 | } 552 | 553 | } 554 | -------------------------------------------------------------------------------- /src/dictionary.h: 
-------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. An additional grant 7 | * of patent rights can be found in the PATENTS file in the same directory. 8 | */ 9 | 10 | #ifndef FASTTEXT_DICTIONARY_H 11 | #define FASTTEXT_DICTIONARY_H 12 | 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | 21 | #include "args.h" 22 | #include "real.h" 23 | 24 | namespace fasttext { 25 | 26 | typedef int32_t id_type; 27 | enum class entry_type : int8_t {word=0, label=1}; 28 | 29 | struct entry { 30 | std::string word; 31 | int64_t count; 32 | entry_type type; 33 | std::vector subwords; 34 | }; 35 | 36 | class Dictionary { 37 | private: 38 | static const int32_t MAX_VOCAB_SIZE = 30000000; 39 | static const int32_t MAX_LINE_SIZE = 1024; 40 | 41 | int32_t find(const std::string&) const; 42 | void initTableDiscard(); 43 | void initNgrams(); 44 | 45 | std::shared_ptr args_; 46 | std::vector word2int_; 47 | std::vector words_; 48 | 49 | std::vector pdiscard_; 50 | int32_t size_; 51 | int32_t nwords_; 52 | int32_t nlabels_; 53 | int64_t ntokens_; 54 | 55 | int64_t pruneidx_size_ = -1; 56 | std::unordered_map pruneidx_; 57 | void addNgrams( 58 | std::vector& line, 59 | const std::vector& hashes, 60 | int32_t n) const; 61 | 62 | public: 63 | static const std::string EOS; 64 | static const std::string BOW; 65 | static const std::string EOW; 66 | 67 | explicit Dictionary(std::shared_ptr); 68 | int32_t nwords() const; 69 | int32_t nlabels() const; 70 | int64_t ntokens() const; 71 | real getPDiscard(int32_t) const; 72 | int32_t getId(const std::string&) const; 73 | int64_t getTokenCount(int32_t) const; 74 | entry_type getType(int32_t) const; 75 | entry_type getType(const std::string&) const; 76 
| bool discard(int32_t, real) const; 77 | std::string getWord(int32_t) const; 78 | const std::vector& getNgrams(int32_t) const; 79 | const std::vector getNgrams(const std::string&) const; 80 | void getNgrams(const std::string&, std::vector&, 81 | std::vector&) const; 82 | void computeNgrams(const std::string&, std::vector&) const; 83 | void computeNgrams(const std::string&, std::vector&, 84 | std::vector&) const; 85 | uint32_t hash(const std::string& str) const; 86 | void add(const std::string&); 87 | bool readWord(std::istream&, std::string&) const; 88 | void readFromFile(std::istream&); 89 | std::string getLabel(int32_t) const; 90 | void save(std::ostream&) const; 91 | void load(std::istream&); 92 | std::vector getCounts(entry_type) const; 93 | std::vector getVocab() const; 94 | void addNgrams(std::vector&, int32_t, int32_t, std::minstd_rand&) const; 95 | void addNgrams(std::vector&, std::vector&e, int32_t, int32_t, int32_t, int32_t, std::minstd_rand&); 96 | void addNgrams(std::vector&, int32_t) const; 97 | int32_t getLine(std::istream&, std::vector&, std::vector&, 98 | std::vector&, std::minstd_rand&) const; 99 | int32_t getLine(std::istream&, std::vector&, 100 | std::vector&, std::minstd_rand&) const; 101 | void threshold(int64_t, int64_t); 102 | void truncate(int64_t); 103 | void prune(std::vector&); 104 | void convertNgrams(std::vector&); 105 | }; 106 | 107 | } 108 | 109 | #endif 110 | -------------------------------------------------------------------------------- /src/fasttext.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. An additional grant 7 | * of patent rights can be found in the PATENTS file in the same directory. 
8 | */ 9 | 10 | #include "fasttext.h" 11 | #include "shmem_matrix.h" 12 | 13 | #include 14 | 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | 25 | 26 | namespace fasttext { 27 | 28 | FastText::FastText() : quant_(false) {} 29 | 30 | std::vector FastText::getVocab() { 31 | return dict_->getVocab(); 32 | } 33 | 34 | std::vector FastText::getUnigramsCounts() { 35 | return dict_->getCounts(entry_type::word); 36 | } 37 | 38 | void FastText::getVector(Vector& vec, const std::string& word) const { 39 | const std::vector& ngrams = dict_->getNgrams(word); 40 | vec.zero(); 41 | for (auto it = ngrams.begin(); it != ngrams.end(); ++it) { 42 | vec.addRow(*input_, *it); 43 | } 44 | if (ngrams.size() > 0) { 45 | vec.mul(1.0 / ngrams.size()); 46 | } 47 | } 48 | 49 | void FastText::saveVectors() { 50 | std::ofstream ofs(args_->output + ".vec"); 51 | if (!ofs.is_open()) { 52 | std::cerr << "Error opening file for saving vectors." << std::endl; 53 | exit(EXIT_FAILURE); 54 | } 55 | ofs << dict_->nwords() << " " << args_->dim << std::endl; 56 | Vector vec(args_->dim); 57 | for (int32_t i = 0; i < dict_->nwords(); i++) { 58 | std::string word = dict_->getWord(i); 59 | getVector_ckpt(vec, i); 60 | ofs << word << " " << vec << std::endl; 61 | } 62 | ofs.close(); 63 | } 64 | 65 | void FastText::saveVectors(int32_t checkpoint) { 66 | std::ofstream ofs(args_->output + "_Chk" + std::to_string(checkpoint) + ".ckpt.vec"); 67 | if (!ofs.is_open()) { 68 | std::cerr << "Error opening file for saving vectors." 
<< std::endl; 69 | exit(EXIT_FAILURE); 70 | } 71 | ofs << dict_->nwords() << " " << args_->dim << std::endl; 72 | Vector vec(args_->dim); 73 | for (int32_t i = 0; i < dict_->nwords(); i++) { 74 | std::string word = dict_->getWord(i); 75 | getVector_ckpt(vec, i); 76 | ofs << word << " " << vec << std::endl; 77 | } 78 | ofs.close(); 79 | } 80 | 81 | void FastText::saveOutput() { 82 | std::ofstream ofs(args_->output + ".output"); 83 | if (!ofs.is_open()) { 84 | std::cerr << "Error opening file for saving vectors." << std::endl; 85 | exit(EXIT_FAILURE); 86 | } 87 | ofs << dict_->nwords() << " " << args_->dim << std::endl; 88 | Vector vec(args_->dim); 89 | for (int32_t i = 0; i < dict_->nwords(); i++) { 90 | std::string word = dict_->getWord(i); 91 | vec.zero(); 92 | vec.addRow(*output_, i); 93 | ofs << word << " " << vec << std::endl; 94 | } 95 | ofs.close(); 96 | } 97 | 98 | bool FastText::checkModel(std::istream& in) { 99 | int32_t magic; 100 | int32_t version; 101 | in.read((char*)&(magic), sizeof(int32_t)); 102 | if (magic != FASTTEXT_FILEFORMAT_MAGIC_INT32) { 103 | return false; 104 | } 105 | in.read((char*)&(version), sizeof(int32_t)); 106 | if (version != FASTTEXT_VERSION) { 107 | return false; 108 | } 109 | return true; 110 | } 111 | 112 | void FastText::signModel(std::ostream& out) { 113 | const int32_t magic = FASTTEXT_FILEFORMAT_MAGIC_INT32; 114 | const int32_t version = FASTTEXT_VERSION; 115 | out.write((char*)&(magic), sizeof(int32_t)); 116 | out.write((char*)&(version), sizeof(int32_t)); 117 | } 118 | 119 | void FastText::saveDict() { 120 | std::string fn(args_->output); 121 | std::ofstream ofs(fn, std::ofstream::binary); 122 | if (!ofs.is_open()) { 123 | std::cerr << "Model file cannot be opened for saving!" 
<< std::endl; 124 | exit(EXIT_FAILURE); 125 | } 126 | args_->save(ofs); 127 | dict_->save(ofs); 128 | ofs.close(); 129 | } 130 | 131 | void FastText::loadDict(const std::string& filename) { 132 | std::ifstream ifs(filename, std::ifstream::binary); 133 | if (!ifs.is_open()) { 134 | std::cerr << "Model file cannot be opened for loading!" << std::endl; 135 | exit(EXIT_FAILURE); 136 | } 137 | if (!checkModel(ifs)) { 138 | std::cerr << "Model file has wrong file format!" << std::endl; 139 | exit(EXIT_FAILURE); 140 | } 141 | loadDict(ifs); 142 | ifs.close(); 143 | } 144 | 145 | void FastText::loadDict(std::istream& in) { 146 | args_ = std::make_shared(); 147 | dict_ = std::make_shared(args_); 148 | 149 | args_->load(in); 150 | dict_->load(in); 151 | } 152 | 153 | void FastText::saveModel() { 154 | std::string fn(args_->output); 155 | if (quant_) { 156 | fn += ".ftz"; 157 | } else { 158 | fn += ".bin"; 159 | } 160 | std::ofstream ofs(fn, std::ofstream::binary); 161 | if (!ofs.is_open()) { 162 | std::cerr << "Model file cannot be opened for saving!" << std::endl; 163 | exit(EXIT_FAILURE); 164 | } 165 | signModel(ofs); 166 | args_->save(ofs); 167 | dict_->save(ofs); 168 | 169 | ofs.write((char*)&(quant_), sizeof(bool)); 170 | if (quant_) { 171 | qinput_->save(ofs); 172 | } else { 173 | input_->save(ofs); 174 | } 175 | 176 | ofs.write((char*)&(args_->qout), sizeof(bool)); 177 | if (quant_ && args_->qout) { 178 | qoutput_->save(ofs); 179 | } else { 180 | output_->save(ofs); 181 | } 182 | 183 | ofs.close(); 184 | } 185 | 186 | void FastText::saveModel(int32_t checkpoint) { 187 | std::string fn(args_->output + "_Chk" + std::to_string(checkpoint) + ".ckpt"); 188 | if (quant_) { 189 | fn += ".ftz"; 190 | } else { 191 | fn += ".bin"; 192 | } 193 | std::ofstream ofs(fn, std::ofstream::binary); 194 | if (!ofs.is_open()) { 195 | std::cerr << "Model file cannot be opened for saving!" 
<< std::endl; 196 | exit(EXIT_FAILURE); 197 | } 198 | signModel(ofs); 199 | args_->save(ofs); 200 | dict_->save(ofs); 201 | 202 | ofs.write((char*)&(quant_), sizeof(bool)); 203 | if (quant_) { 204 | qinput_->save(ofs); 205 | } else { 206 | input_->save(ofs); 207 | } 208 | 209 | ofs.write((char*)&(args_->qout), sizeof(bool)); 210 | if (quant_ && args_->qout) { 211 | qoutput_->save(ofs); 212 | } else { 213 | output_->save(ofs); 214 | } 215 | 216 | ofs.close(); 217 | } 218 | 219 | void FastText::loadModel(const std::string& filename, 220 | const bool inference_mode /* = false */, 221 | const int timeout_sec /* = -1 */) { 222 | std::ifstream ifs(filename, std::ifstream::binary); 223 | if (!ifs.is_open()) { 224 | std::cerr << "Model file cannot be opened for loading!" << std::endl; 225 | exit(EXIT_FAILURE); 226 | } 227 | if (!checkModel(ifs)) { 228 | std::cerr << "Model file has wrong file format!" << std::endl; 229 | exit(EXIT_FAILURE); 230 | } 231 | if (inference_mode) { 232 | loadModelForInference(ifs, filename, timeout_sec); 233 | } else { 234 | loadModel(ifs); 235 | } 236 | ifs.close(); 237 | } 238 | 239 | void FastText::loadModel(std::istream& in) { 240 | args_ = std::make_shared(); 241 | dict_ = std::make_shared(args_); 242 | input_ = std::make_shared(); 243 | output_ = std::make_shared(); 244 | qinput_ = std::make_shared(); 245 | qoutput_ = std::make_shared(); 246 | args_->load(in); 247 | 248 | dict_->load(in); 249 | 250 | bool quant_input; 251 | in.read((char*) &quant_input, sizeof(bool)); 252 | if (quant_input) { 253 | quant_ = true; 254 | qinput_->load(in); 255 | } else { 256 | input_->load(in); 257 | } 258 | 259 | in.read((char*) &args_->qout, sizeof(bool)); 260 | if (quant_ && args_->qout) { 261 | qoutput_->load(in); 262 | } else { 263 | output_->load(in); 264 | } 265 | 266 | model_ = std::make_shared(input_, output_, args_, 0); 267 | model_->quant_ = quant_; 268 | model_->setQuantizePointer(qinput_, qoutput_, args_->qout); 269 | 270 | 
model_->setTargetCounts(dict_->getCounts(entry_type::word)); 271 | 272 | } 273 | 274 | static std::string basename(const std::string& filename) { 275 | std::string s = filename; 276 | size_t separator_idx = s.find_last_of("\\/"); 277 | if (separator_idx != std::string::npos) { 278 | s.erase(0, separator_idx + 1); 279 | } 280 | size_t extension_idx = s.rfind('.'); 281 | if (extension_idx != std::string::npos) { 282 | s.erase(extension_idx); 283 | } 284 | return s; 285 | } 286 | 287 | void FastText::loadModelForInference(std::istream& in, 288 | const std::string& filename, 289 | const int timeout_sec) { 290 | std::string shmem_name = "s2v_" + basename(filename) + "_input_matrix"; 291 | 292 | args_ = std::make_shared(); 293 | args_->load(in); 294 | 295 | dict_ = std::make_shared(args_); 296 | dict_->load(in); 297 | 298 | in.read((char*) &quant_, sizeof(bool)); 299 | 300 | input_ = ShmemMatrix::load(in, shmem_name, timeout_sec); 301 | 302 | in.read((char*) &args_->qout, sizeof(bool)); 303 | 304 | output_ = std::make_shared(); 305 | in.read((char*) &(output_->m_), sizeof(int64_t)); 306 | in.read((char*) &(output_->n_), sizeof(int64_t)); 307 | 308 | model_ = std::make_shared(input_, output_, args_, 0); 309 | 310 | model_->setTargetCounts(dict_->getCounts(entry_type::word)); 311 | } 312 | 313 | void FastText::printInfo(real progress, real loss) { 314 | real t = real(clock() - start) / CLOCKS_PER_SEC; 315 | real wst = real(tokenCount) / t; 316 | real lr = args_->lr * (1.0 - progress); 317 | int eta = int(t / progress * (1 - progress) / args_->thread); 318 | int etah = eta / 3600; 319 | int etam = (eta - etah * 3600) / 60; 320 | std::cerr << std::fixed; 321 | std::cerr << "\rProgress: " << std::setprecision(1) << 100 * progress << "%"; 322 | std::cerr << " words/sec/thread: " << std::setprecision(0) << wst; 323 | std::cerr << " lr: " << std::setprecision(6) << lr; 324 | std::cerr << " loss: " << std::setprecision(6) << loss; 325 | std::cerr << " eta: " << etah << "h" << 
etam << "m "; 326 | std::cerr << std::flush; 327 | } 328 | 329 | std::vector FastText::selectEmbeddings(int32_t cutoff) const { 330 | Vector norms(input_->m_); 331 | input_->l2NormRow(norms); 332 | std::vector idx(input_->m_, 0); 333 | std::iota(idx.begin(), idx.end(), 0); 334 | auto eosid = dict_->getId(Dictionary::EOS); 335 | std::sort(idx.begin(), idx.end(), 336 | [&norms, eosid] (size_t i1, size_t i2) { 337 | return eosid ==i1 || (eosid != i2 && norms[i1] > norms[i2]); 338 | }); 339 | idx.erase(idx.begin() + cutoff, idx.end()); 340 | return idx; 341 | } 342 | 343 | void FastText::quantize(std::shared_ptr qargs) { 344 | if (qargs->output.empty()) { 345 | std::cerr<<"No model provided!"<output + ".bin"); 348 | 349 | args_->input = qargs->input; 350 | args_->qout = qargs->qout; 351 | args_->output = qargs->output; 352 | 353 | 354 | if (qargs->cutoff > 0 && qargs->cutoff < input_->m_) { 355 | auto idx = selectEmbeddings(qargs->cutoff); 356 | dict_->prune(idx); 357 | std::shared_ptr ninput = 358 | std::make_shared (idx.size(), args_->dim); 359 | for (auto i = 0; i < idx.size(); i++) { 360 | for (auto j = 0; j < args_->dim; j++) { 361 | ninput->at(i,j) = input_->at(idx[i], j); 362 | } 363 | } 364 | input_ = ninput; 365 | if (qargs->retrain) { 366 | args_->epoch = qargs->epoch; 367 | args_->lr = qargs->lr; 368 | args_->thread = qargs->thread; 369 | args_->verbose = qargs->verbose; 370 | tokenCount = 0; 371 | std::vector threads; 372 | for (int32_t i = 0; i < args_->thread; i++) { 373 | threads.push_back(std::thread([=]() { trainThread(i); })); 374 | } 375 | for (auto it = threads.begin(); it != threads.end(); ++it) { 376 | it->join(); 377 | } 378 | } 379 | } 380 | 381 | qinput_ = std::make_shared(*input_, qargs->dsub, qargs->qnorm); 382 | 383 | if (args_->qout) { 384 | qoutput_ = std::make_shared(*output_, 2, qargs->qnorm); 385 | } 386 | 387 | quant_ = true; 388 | saveModel(); 389 | } 390 | 391 | void FastText::supervised(Model& model, real lr, 392 | const 
std::vector& line, 393 | const std::vector& labels) { 394 | if (labels.size() == 0 || line.size() == 0) return; 395 | std::uniform_int_distribution<> uniform(0, labels.size() - 1); 396 | int32_t i = uniform(model.rng); 397 | model.update(line, labels[i], lr); 398 | } 399 | 400 | void FastText::transgram(Model& model, real lr, 401 | const std::vector& line) { 402 | if (line.size() <= 1) return; 403 | std::uniform_int_distribution<> uniform(1, args_->ws); 404 | std::vector line1; 405 | std::vector line2; 406 | int32_t flag = 0; 407 | for (int32_t i=0; igetId("<>")){ 409 | flag = 1; 410 | continue; 411 | } 412 | if (flag==0) 413 | line1.push_back(line[i]); 414 | else 415 | line2.push_back(line[i]); 416 | } 417 | for (int32_t w = 0; w < line1.size(); w++) { 418 | int32_t boundary = uniform(model.rng); 419 | std::vector word; 420 | word.push_back(line1[w]); 421 | for (int32_t c = -boundary; c <= boundary; c++) { 422 | if (c != 0 && w + c >= 0 && w + c < line1.size()) { 423 | model.update(word, line1[w + c], lr); 424 | } 425 | } 426 | for (int32_t i=0; i word; 433 | word.push_back(line2[w]); 434 | for (int32_t c = -boundary; c <= boundary; c++) { 435 | if (c != 0 && w + c >= 0 && w + c < line2.size()) { 436 | model.update(word, line2[w + c], lr); 437 | } 438 | } 439 | for (int32_t i=0; i& line) { 448 | std::vector bow; 449 | std::uniform_int_distribution<> uniform(1, args_->ws); 450 | for (int32_t w = 0; w < line.size(); w++) { 451 | int32_t boundary = uniform(model.rng); 452 | bow.clear(); 453 | for (int32_t c = -boundary; c <= boundary; c++) { 454 | if (c != 0 && w + c >= 0 && w + c < line.size()) { 455 | const std::vector& ngrams = dict_->getNgrams(line[w + c]); 456 | bow.insert(bow.end(), ngrams.cbegin(), ngrams.cend()); 457 | } 458 | } 459 | model.update(bow, line[w], lr); 460 | } 461 | } 462 | 463 | void FastText::bisent2vec(Model& model, real lr, const std::vector& line){ 464 | if (line.size() <= 1) return; 465 | std::vector context; 466 | std::vector line1; 467 | 
std::vector line2; 468 | int32_t flag = 0; 469 | for (int32_t i=0; igetId("<>")){ 471 | flag = 1; 472 | continue; 473 | } 474 | if (flag==0) 475 | line1.push_back(line[i]); 476 | else 477 | line2.push_back(line[i]); 478 | } 479 | std::uniform_real_distribution<> uniform(0, 1); 480 | for (int32_t i=0; i dict_->getPDiscard(line1[i]) || dict_->getTokenCount(line1[i]) < args_->minCountLabel) 482 | continue; 483 | context = line1; 484 | context[i] = 0; 485 | dict_->addNgrams(context, args_->wordNgrams, args_->dropoutK, model.rng); 486 | model.update(context, line1[i], lr); 487 | context = line2; 488 | dict_->addNgrams(context, args_->wordNgrams, args_->dropoutK, model.rng); 489 | model.update(context, line1[i], lr); 490 | 491 | } 492 | for (int32_t i=0; i dict_->getPDiscard(line2[i]) || dict_->getTokenCount(line2[i]) < args_->minCountLabel) 494 | continue; 495 | context = line2; 496 | context[i] = 0; 497 | dict_->addNgrams(context, args_->wordNgrams, args_->dropoutK, model.rng); 498 | model.update(context, line2[i], lr); 499 | context = line1; 500 | dict_->addNgrams(context, args_->wordNgrams, args_->dropoutK, model.rng); 501 | model.update(context, line2[i], lr); 502 | } 503 | } 504 | 505 | void FastText::test(std::istream& in, int32_t k) { 506 | int32_t nexamples = 0, nlabels = 0; 507 | double precision = 0.0; 508 | std::vector line, labels; 509 | 510 | while (in.peek() != EOF) { 511 | dict_->getLine(in, line, labels, model_->rng); 512 | if (labels.size() > 0 && line.size() > 0) { 513 | std::vector> modelPredictions; 514 | model_->predict(line, k, modelPredictions); 515 | for (auto it = modelPredictions.cbegin(); it != modelPredictions.cend(); it++) { 516 | if (std::find(labels.begin(), labels.end(), it->second) != labels.end()) { 517 | precision += 1.0; 518 | } 519 | } 520 | nexamples++; 521 | nlabels += labels.size(); 522 | } 523 | } 524 | std::cout << "N" << "\t" << nexamples << std::endl; 525 | std::cout << std::setprecision(3); 526 | std::cout << "P@" << k << 
"\t" << precision / (k * nexamples) << std::endl; 527 | std::cout << "R@" << k << "\t" << precision / nlabels << std::endl; 528 | std::cerr << "Number of examples: " << nexamples << std::endl; 529 | } 530 | 531 | void FastText::predict(std::istream& in, int32_t k, 532 | std::vector>& predictions) const { 533 | std::vector words, labels; 534 | dict_->getLine(in, words, labels, model_->rng); 535 | if (words.empty()) return; 536 | Vector hidden(args_->dim); 537 | Vector output(dict_->nlabels()); 538 | std::vector> modelPredictions; 539 | model_->predict(words, k, modelPredictions, hidden, output); 540 | predictions.clear(); 541 | for (auto it = modelPredictions.cbegin(); it != modelPredictions.cend(); it++) { 542 | predictions.push_back(std::make_pair(it->first, dict_->getLabel(it->second))); 543 | } 544 | } 545 | 546 | void FastText::predict(std::istream& in, int32_t k, bool print_prob) { 547 | std::vector> predictions; 548 | while (in.peek() != EOF) { 549 | predict(in, k, predictions); 550 | if (predictions.empty()) { 551 | std::cout << std::endl; 552 | continue; 553 | } 554 | for (auto it = predictions.cbegin(); it != predictions.cend(); it++) { 555 | if (it != predictions.cbegin()) { 556 | std::cout << " "; 557 | } 558 | std::cout << it->second; 559 | if (print_prob) { 560 | std::cout << " " << exp(it->first); 561 | } 562 | } 563 | std::cout << std::endl; 564 | } 565 | } 566 | 567 | void FastText::wordVectors() { 568 | std::string word; 569 | Vector vec(args_->dim); 570 | while (std::cin >> word) { 571 | getVector(vec, word); 572 | std::cout << word << " " << vec << std::endl; 573 | } 574 | } 575 | 576 | void FastText::sentenceVectors() { 577 | Vector vec(args_->dim); 578 | std::string sentence; 579 | Vector svec(args_->dim); 580 | std::string word; 581 | while (std::getline(std::cin, sentence)) { 582 | std::istringstream iss(sentence); 583 | svec.zero(); 584 | int32_t count = 0; 585 | while(iss >> word) { 586 | getVector(vec, word); 587 | vec.mul(1.0 / 
vec.norm()); 588 | svec.addVector(vec); 589 | count++; 590 | } 591 | svec.mul(1.0 / count); 592 | std::cout << sentence << " " << svec << std::endl; 593 | } 594 | } 595 | 596 | void FastText::ngramVectors(std::string word) { 597 | std::vector ngrams; 598 | std::vector substrings; 599 | Vector vec(args_->dim); 600 | dict_->getNgrams(word, ngrams, substrings); 601 | for (int32_t i = 0; i < ngrams.size(); i++) { 602 | vec.zero(); 603 | if (ngrams[i] >= 0) { 604 | vec.addRow(*input_, ngrams[i]); 605 | } 606 | std::cout << substrings[i] << " " << vec << std::endl; 607 | } 608 | } 609 | 610 | void FastText::textVectors() { 611 | std::vector line, labels; 612 | Vector vec(args_->dim); 613 | while (std::cin.peek() != EOF) { 614 | dict_->getLine(std::cin, line, labels, model_->rng); 615 | vec.zero(); 616 | if (args_->model == model_name::bisent2vec){ 617 | dict_->addNgrams(line, args_->wordNgrams); 618 | } 619 | for (auto it = line.cbegin(); it != line.cend(); ++it) { 620 | vec.addRow(*input_, *it); 621 | } 622 | if (!line.empty()) { 623 | vec.mul(1.0 / line.size()); 624 | } 625 | std::cout << vec << std::endl; 626 | } 627 | } 628 | 629 | void FastText::textVectorThread(int thread_id, std::shared_ptr> sentences, std::shared_ptr emb, int num_threads) { 630 | std::vector line, labels; 631 | for (int sent_idx=thread_id; sent_idx < sentences->size(); sent_idx+=num_threads) { 632 | Vector vec(args_->dim); 633 | textVector(sentences->operator[](sent_idx), vec, line, labels); 634 | emb->addRow(vec, sent_idx, 1.); 635 | } 636 | } 637 | 638 | void FastText::textVectors(std::vector& sentences, int num_threads, std::vector& final) { 639 | std::shared_ptr emb; 640 | std::shared_ptr> sents; 641 | sents = std::make_shared>(sentences); 642 | emb = std::make_shared(sentences.size(), args_->dim); 643 | emb->zero(); 644 | std::vector threads; 645 | for (int32_t i = 0; i < num_threads; i++) { 646 | threads.push_back(std::thread([=]() { textVectorThread(i, sents, emb, num_threads); })); 647 | 
} 648 | for (auto it = threads.begin(); it != threads.end(); ++it) { 649 | it->join(); 650 | } 651 | memcpy(&final[0], &emb->data_[0], emb->m_*emb->n_ * sizeof(real)); 652 | } 653 | 654 | void FastText::getVector(Vector& vec, int32_t wordIdx) const { 655 | 656 | const std::vector& ngrams = dict_->getNgrams(wordIdx); 657 | vec.zero(); 658 | for (auto it = ngrams.begin(); it != ngrams.end(); ++it) { 659 | vec.addRow(*input_, *it); 660 | } 661 | if (ngrams.size() > 0) { 662 | vec.mul(1.0 / ngrams.size()); 663 | } 664 | } 665 | 666 | void FastText::getVector_ckpt(Vector& vec, int32_t wordIdx) const { 667 | const std::vector& ngrams = dict_->getNgrams(wordIdx); 668 | vec.zero(); 669 | vec.addRow(*input_, wordIdx); 670 | } 671 | 672 | void FastText::textVector(std::string text, Vector& vec, std::vector& line, std::vector& labels) { 673 | std::istringstream text_stream(text); 674 | dict_->getLine(text_stream, line, labels, model_->rng); 675 | vec.zero(); 676 | if (args_->model == model_name::bisent2vec){ 677 | dict_->addNgrams(line, args_->wordNgrams); 678 | } 679 | for (auto it = line.cbegin(); it != line.cend(); ++it) { 680 | vec.addRow(*input_, *it); 681 | } 682 | if (!line.empty()) { 683 | vec.mul(1.0 / line.size()); 684 | } 685 | } 686 | 687 | void FastText::printWordVectors() { 688 | wordVectors(); 689 | } 690 | 691 | void FastText::printVocabularyVectors(bool input_matrix) { 692 | std::vector words = getVocab(); 693 | std::shared_ptr matrix = input_matrix ? 
input_ : output_; 694 | Vector vec(args_->dim); 695 | std::string word; 696 | for(int i = 0; i < words.size(); i++) { 697 | word = words[i]; 698 | const std::vector &ngrams = dict_->getNgrams(word); 699 | vec.zero(); 700 | for (auto it = ngrams.begin(); it != ngrams.end(); ++it) { 701 | vec.addRow(*matrix, *it); 702 | } 703 | if (ngrams.size() > 0) { 704 | vec.mul(1.0 / ngrams.size()); 705 | } 706 | std::cout << word << " " << " " << vec << std::endl; 707 | } 708 | } 709 | 710 | void FastText::printSentenceVectors() { 711 | if (args_->model == model_name::bisent2vec) { 712 | textVectors(); 713 | } else { 714 | sentenceVectors(); 715 | } 716 | } 717 | 718 | void FastText::precomputeWordVectors(Matrix& wordVectors) { 719 | Vector vec(args_->dim); 720 | wordVectors.zero(); 721 | std::cerr << "Pre-computing word vectors..."; 722 | for (int32_t i = 0; i < dict_->nwords(); i++) { 723 | std::string word = dict_->getWord(i); 724 | getVector(vec, word); 725 | real norm = vec.norm(); 726 | wordVectors.addRow(vec, i, 1.0 / norm); 727 | } 728 | std::cerr << " done." << std::endl; 729 | } 730 | 731 | void FastText::precomputeSentenceVectors(Matrix& sentenceVectors,std::ifstream& in) { 732 | Vector vec(args_->dim); 733 | sentenceVectors.zero(); 734 | std::cerr << "Pre-computing sentence vectors..."; 735 | std::vector line; 736 | std::vector labels; 737 | int32_t i = 0; 738 | while (i < sentenceVectors.m_) { 739 | 740 | dict_->getLine(in, line, labels, model_->rng); 741 | dict_->addNgrams(line, args_->wordNgrams); 742 | 743 | vec.zero(); 744 | for (auto it = line.cbegin(); it != line.cend(); ++it) { 745 | vec.addRow(*input_, *it); 746 | } 747 | if (!line.empty()) { 748 | vec.mul(1.0 / line.size()); 749 | } 750 | real norm = vec.norm(); 751 | if(norm != 0) 752 | sentenceVectors.addRow(vec, i, 1.0 / norm); 753 | i++; 754 | } 755 | std::cerr << " done." 
<< std::endl; 756 | } 757 | 758 | void FastText::findNN(const Matrix& wordVectors, const Vector& queryVec, 759 | int32_t k, const std::set& banSet) { 760 | real queryNorm = queryVec.norm(); 761 | if (std::abs(queryNorm) < 1e-8) { 762 | queryNorm = 1; 763 | } 764 | std::priority_queue> heap; 765 | Vector vec(args_->dim); 766 | for (int32_t i = 0; i < dict_->nwords(); i++) { 767 | std::string word = dict_->getWord(i); 768 | real dp = wordVectors.dotRow(queryVec, i); 769 | heap.push(std::make_pair(dp / queryNorm, word)); 770 | } 771 | int32_t i = 0; 772 | while (i < k && heap.size() > 0) { 773 | auto it = banSet.find(heap.top().second); 774 | if (it == banSet.end()) { 775 | std::cout << heap.top().second << " " << heap.top().first << std::endl; 776 | i++; 777 | } 778 | heap.pop(); 779 | } 780 | } 781 | 782 | void FastText::findNNSent(const Matrix& sentenceVectors, const Vector& queryVec, 783 | int32_t k, const std::set& banSet, int64_t numSent, 784 | const std::vector& sentences) { 785 | real queryNorm = queryVec.norm(); 786 | if (std::abs(queryNorm) < 1e-8) { 787 | queryNorm = 1; 788 | } 789 | std::priority_queue> heap; 790 | Vector vec(args_->dim); 791 | 792 | for (int32_t i = 0; i < numSent; i++) { 793 | std::string sentence = std::to_string(i) + " " + sentences[i]; 794 | real dp = sentenceVectors.dotRow(queryVec, i); 795 | heap.push(std::make_pair(dp / queryNorm, sentence)); 796 | } 797 | 798 | int32_t i = 0; 799 | while (i < k && heap.size() > 0) { 800 | auto it = banSet.find(heap.top().second); 801 | if (!std::isnan(heap.top().first)) { 802 | std::cout << heap.top().first << " " 803 | << heap.top().second << " " 804 | << std::endl; 805 | i++; 806 | } 807 | heap.pop(); 808 | } 809 | } 810 | 811 | 812 | void FastText::nn(int32_t k) { 813 | std::string queryWord; 814 | Vector queryVec(args_->dim); 815 | Matrix wordVectors(dict_->nwords(), args_->dim); 816 | precomputeWordVectors(wordVectors); 817 | std::set banSet; 818 | std::cerr << "Query word? 
" << std::endl; 819 | while (std::cin >> queryWord) { 820 | banSet.clear(); 821 | banSet.insert(queryWord); 822 | getVector(queryVec, queryWord); 823 | findNN(wordVectors, queryVec, k, banSet); 824 | std::cerr << "Query word? " << std::endl; 825 | } 826 | } 827 | 828 | void FastText::analogies(int32_t k) { 829 | std::string word; 830 | Vector buffer(args_->dim), query(args_->dim); 831 | Matrix wordVectors(dict_->nwords(), args_->dim); 832 | precomputeWordVectors(wordVectors); 833 | std::set banSet; 834 | std::cerr << "Query triplet (A - B + C)? " << std::endl; 835 | while (true) { 836 | banSet.clear(); 837 | query.zero(); 838 | std::cin >> word; 839 | banSet.insert(word); 840 | getVector(buffer, word); 841 | query.addVector(buffer, 1.0); 842 | std::cin >> word; 843 | banSet.insert(word); 844 | getVector(buffer, word); 845 | query.addVector(buffer, -1.0); 846 | std::cin >> word; 847 | banSet.insert(word); 848 | getVector(buffer, word); 849 | query.addVector(buffer, 1.0); 850 | 851 | findNN(wordVectors, query, k, banSet); 852 | std::cerr << "Query triplet (A - B + C)? " << std::endl; 853 | } 854 | } 855 | 856 | void FastText::nnSent(int32_t k, std::string filename) { 857 | std::string sentence; 858 | std::ifstream in1(filename); 859 | int64_t n = 0; 860 | 861 | Vector buffer(args_->dim), query(args_->dim); 862 | std::vector sentences; 863 | 864 | std::vector line, labels; 865 | std::ifstream in2(filename); 866 | 867 | while (in2.peek() != EOF) { 868 | std::getline(in2, sentence); 869 | sentences.push_back(sentence); 870 | n++; 871 | } 872 | std::cout << "Number of sentences in the corpus file is " << n << "." << std::endl ; 873 | Matrix sentenceVectors(n+1, args_->dim); 874 | 875 | precomputeSentenceVectors(sentenceVectors, in1); 876 | std::set banSet; 877 | 878 | std::cerr << "Query sentence? 
" << std::endl; 879 | while (std::cin.peek() != EOF) { 880 | query.zero(); 881 | dict_->getLine(std::cin, line, labels, model_->rng); 882 | dict_->addNgrams(line, args_->wordNgrams); 883 | buffer.zero(); 884 | for (auto it = line.cbegin(); it != line.cend(); ++it) { 885 | buffer.addRow(*input_, *it); 886 | } 887 | if (!line.empty()) { 888 | buffer.mul(1.0 / line.size()); 889 | } 890 | query.addVector(buffer, 1.0); 891 | 892 | findNNSent(sentenceVectors, query, k, banSet, n, sentences); 893 | std::cout << std::endl; 894 | std::cerr << "Query sentence? " << std::endl; 895 | } 896 | } 897 | 898 | 899 | void FastText::analogiesSent(int32_t k, std::string filename) { 900 | std::string sentence; 901 | std::ifstream in1(filename); 902 | int64_t n = 0; 903 | 904 | Vector buffer(args_->dim), query(args_->dim); 905 | std::vector sentences; 906 | 907 | std::vector line, labels; 908 | 909 | std::ifstream in2(filename); 910 | 911 | while (in2.peek() != EOF) { 912 | std::getline(in2, sentence); 913 | sentences.push_back(sentence); 914 | n++; 915 | } 916 | std::cout << "Number of sentences in the corpus file is " << n << "." << std::endl ; 917 | 918 | Matrix sentenceVectors(n+1, args_->dim); 919 | 920 | precomputeSentenceVectors(sentenceVectors, in1); 921 | std::set banSet; 922 | std::cerr << "Query triplet sentences (A - B + C)? 
" << std::endl; 923 | while (true) { 924 | banSet.clear(); 925 | query.zero(); 926 | dict_->getLine(std::cin, line, labels, model_->rng); 927 | dict_->addNgrams(line, args_->wordNgrams); 928 | buffer.zero(); 929 | for (auto it = line.cbegin(); it != line.cend(); ++it) { 930 | buffer.addRow(*input_, *it); 931 | } 932 | if (!line.empty()) { 933 | buffer.mul(1.0 / line.size()); 934 | } 935 | query.addVector(buffer, 1.0); 936 | 937 | dict_->getLine(std::cin, line, labels, model_->rng); 938 | dict_->addNgrams(line, args_->wordNgrams); 939 | buffer.zero(); 940 | for (auto it = line.cbegin(); it != line.cend(); ++it) { 941 | buffer.addRow(*input_, *it); 942 | } 943 | if (!line.empty()) { 944 | buffer.mul(1.0 / line.size()); 945 | } 946 | 947 | query.addVector(buffer, -1.0); 948 | 949 | dict_->getLine(std::cin, line, labels, model_->rng); 950 | dict_->addNgrams(line, args_->wordNgrams); 951 | buffer.zero(); 952 | for (auto it = line.cbegin(); it != line.cend(); ++it) { 953 | buffer.addRow(*input_, *it); 954 | } 955 | if (!line.empty()) { 956 | buffer.mul(1.0 / line.size()); 957 | } 958 | 959 | query.addVector(buffer, 1.0); 960 | 961 | findNNSent(sentenceVectors, query, k, banSet, n, sentences); 962 | std::cerr << "Query triplet sentences (A - B + C)? 
" << std::endl; 963 | } 964 | } 965 | 966 | 967 | void FastText::trainThread(int32_t threadId) { 968 | std::ifstream ifs(args_->input); 969 | utils::seek(ifs, threadId * utils::size(ifs) / args_->thread); 970 | int32_t currCheckPoint = 0; 971 | Model model(input_, output_, args_, threadId); 972 | model.setTargetCounts(dict_->getCounts(entry_type::word)); 973 | 974 | const int64_t ntokens = dict_->ntokens(); 975 | int64_t localTokenCount = 0; 976 | std::vector line, labels; 977 | while (tokenCount < args_->epoch * ntokens) { 978 | real progress = real(tokenCount) / (args_->epoch * ntokens); 979 | real lr = args_->lr * (1.0 - progress); 980 | localTokenCount += dict_->getLine(ifs, line, labels, model.rng); 981 | if (args_->model == model_name::bisent2vec) { 982 | bisent2vec(model, lr, line); 983 | } else if (args_->model == model_name::cbow) { 984 | cbow(model, lr, line); 985 | } else if (args_->model == model_name::transgram) { 986 | transgram(model, lr, line); 987 | } 988 | if (localTokenCount > args_->lrUpdateRate) { 989 | tokenCount += localTokenCount; 990 | localTokenCount = 0; 991 | if (threadId == 0 && args_->verbose > 1) { 992 | printInfo(progress, model.getLoss()); 993 | } 994 | } 995 | if (threadId == 0 && currCheckPoint != (int)(progress*args_->numCheckPoints)) { 996 | currCheckPoint++; 997 | printInfo(progress, model.getLoss()); 998 | std::cerr << std::endl; 999 | std::cerr<<"Saving Model ----- Checkpoint "<< currCheckPoint<< std::endl; 1000 | model_ = std::make_shared(input_, output_, args_, 0); 1001 | saveModel(currCheckPoint); 1002 | saveVectors(currCheckPoint); 1003 | } 1004 | } 1005 | if (threadId == 0 && args_->verbose > 0) { 1006 | printInfo(1.0, model.getLoss()); 1007 | std::cerr << std::endl; 1008 | } 1009 | ifs.close(); 1010 | } 1011 | 1012 | void FastText::loadVectors(std::string filename) { 1013 | std::ifstream in(filename); 1014 | std::vector words; 1015 | std::shared_ptr mat; // temp. 
matrix for pretrained vectors 1016 | int64_t n, dim; 1017 | real x; 1018 | if (!in.is_open()) { 1019 | std::cerr << "Pretrained vectors file cannot be opened!" << std::endl; 1020 | exit(EXIT_FAILURE); 1021 | } 1022 | in >> n >> dim; 1023 | if (dim != args_->dim) { 1024 | std::cerr << "Dimension of pretrained vectors does not match -dim option" 1025 | << std::endl; 1026 | exit(EXIT_FAILURE); 1027 | } 1028 | mat = std::make_shared(n, dim); 1029 | for (size_t i = 0; i < n; i++) { 1030 | std::string word; 1031 | in >> word; 1032 | if (dict_->getId(word) != -1){ 1033 | words.push_back(word); 1034 | dict_->add(word); 1035 | 1036 | for (size_t j = 0; j < dim; j++) { 1037 | in >> mat->data_[(words.size() -1) * dim + j]; 1038 | } 1039 | 1040 | } 1041 | else{ 1042 | for (size_t j = 0; j < dim; j++) { 1043 | in >> x; 1044 | } 1045 | } 1046 | } 1047 | in.close(); 1048 | 1049 | input_ = std::make_shared(dict_->nwords()+args_->bucket, args_->dim); 1050 | input_->uniform(1.0 / args_->dim); 1051 | 1052 | for (size_t i = 0; i < words.size(); i++) { 1053 | int32_t idx = dict_->getId(words[i]); 1054 | 1055 | if (idx < 0 || idx >= dict_->nwords()) continue; 1056 | for (size_t j = 0; j < dim; j++) { 1057 | input_->data_[idx * dim + j] = mat->data_[i * dim + j]; 1058 | } 1059 | } 1060 | } 1061 | 1062 | void FastText::loadOutputVectors(std::string filename) { 1063 | std::ifstream in(filename); 1064 | std::vector words; 1065 | std::shared_ptr mat; // temp. matrix for pretrained vectors 1066 | int64_t n, dim; 1067 | real x; 1068 | if (!in.is_open()) { 1069 | std::cerr << "Pretrained vectors file cannot be opened!" 
<< std::endl; 1070 | exit(EXIT_FAILURE); 1071 | } 1072 | in >> n >> dim; 1073 | if (dim != args_->dim) { 1074 | std::cerr << "Dimension of pretrained vectors does not match -dim option" 1075 | << std::endl; 1076 | exit(EXIT_FAILURE); 1077 | } 1078 | mat = std::make_shared(n, dim); 1079 | for (size_t i = 0; i < n; i++) { 1080 | std::string word; 1081 | in >> word; 1082 | if (dict_->getId(word) != -1){ 1083 | words.push_back(word); 1084 | dict_->add(word); 1085 | 1086 | for (size_t j = 0; j < dim; j++) { 1087 | in >> mat->data_[(words.size() -1) * dim + j]; 1088 | } 1089 | 1090 | } 1091 | else{ 1092 | for (size_t j = 0; j < dim; j++) { 1093 | in >> x; 1094 | } 1095 | } 1096 | } 1097 | in.close(); 1098 | 1099 | output_ = std::make_shared(dict_->nwords(), args_->dim); 1100 | output_->uniform(0.0 / args_->dim); 1101 | 1102 | for (size_t i = 0; i < words.size(); i++) { 1103 | int32_t idx = dict_->getId(words[i]); 1104 | 1105 | if (idx < 0 || idx >= dict_->nwords()) continue; 1106 | for (size_t j = 0; j < dim; j++) { 1107 | output_->data_[idx * dim + j] = mat->data_[i * dim + j]; 1108 | } 1109 | } 1110 | 1111 | } 1112 | 1113 | void FastText::savedDictTrain(std::shared_ptr args) { 1114 | loadDict(args_->dict); 1115 | if (args_->input == "-") { 1116 | // manage expectations 1117 | std::cerr << "Cannot use stdin for training!" << std::endl; 1118 | exit(EXIT_FAILURE); 1119 | } 1120 | std::ifstream ifs(args_->input); 1121 | if (!ifs.is_open()) { 1122 | std::cerr << "Input file cannot be opened!" 
<< std::endl; 1123 | exit(EXIT_FAILURE); 1124 | } 1125 | dict_->readFromFile(ifs); 1126 | ifs.close(); 1127 | 1128 | if (args_->pretrainedVectors.size() != 0) { 1129 | loadVectors(args_->pretrainedVectors); 1130 | } else { 1131 | input_ = std::make_shared(dict_->nwords()+args_->bucket, args_->dim); 1132 | input_->uniform(1.0 / args_->dim); 1133 | } 1134 | 1135 | output_ = std::make_shared(dict_->nwords(), args_->dim); 1136 | output_->zero(); 1137 | 1138 | start = clock(); 1139 | tokenCount = 0; 1140 | if (args_->thread > 1) { 1141 | std::vector threads; 1142 | for (int32_t i = 0; i < args_->thread; i++) { 1143 | threads.push_back(std::thread([=]() { trainThread(i); })); 1144 | } 1145 | for (auto it = threads.begin(); it != threads.end(); ++it) { 1146 | it->join(); 1147 | } 1148 | } else { 1149 | trainThread(0); 1150 | } 1151 | model_ = std::make_shared(input_, output_, args_, 0); 1152 | 1153 | saveModel(); 1154 | if (args_->model != model_name::bisent2vec) { 1155 | saveVectors(); 1156 | if (args_->saveOutput > 0) { 1157 | saveOutput(); 1158 | } 1159 | } 1160 | } 1161 | 1162 | void FastText::trainDict(std::shared_ptr args) { 1163 | args_ = args; 1164 | dict_ = std::make_shared(args_); 1165 | if (args_->input == "-") { 1166 | // manage expectations 1167 | std::cerr << "Cannot use stdin for training!" << std::endl; 1168 | exit(EXIT_FAILURE); 1169 | } 1170 | std::ifstream ifs(args_->input); 1171 | if (!ifs.is_open()) { 1172 | std::cerr << "Input file cannot be opened!" << std::endl; 1173 | exit(EXIT_FAILURE); 1174 | } 1175 | dict_->readFromFile(ifs); 1176 | saveDict(); 1177 | ifs.close(); 1178 | } 1179 | 1180 | void FastText::train(std::shared_ptr args) { 1181 | args_ = args; 1182 | dict_ = std::make_shared(args_); 1183 | if (args_->input == "-") { 1184 | // manage expectations 1185 | std::cerr << "Cannot use stdin for training!" 
<< std::endl; 1186 | exit(EXIT_FAILURE); 1187 | } 1188 | std::ifstream ifs(args_->input); 1189 | if (!ifs.is_open()) { 1190 | std::cerr << "Input file cannot be opened!" << std::endl; 1191 | exit(EXIT_FAILURE); 1192 | } 1193 | dict_->readFromFile(ifs); 1194 | ifs.close(); 1195 | 1196 | if (args_->pretrainedVectors.size() != 0) { 1197 | loadVectors(args_->pretrainedVectors); 1198 | loadOutputVectors(args_->pretrainedVectors); 1199 | } else { 1200 | input_ = std::make_shared(dict_->nwords()+args_->bucket+args_->bucketChar, args_->dim); 1201 | input_->uniform(1.0 / args_->dim); 1202 | output_ = std::make_shared(dict_->nwords(), args_->dim); 1203 | output_->zero(); 1204 | } 1205 | 1206 | start = clock(); 1207 | tokenCount = 0; 1208 | if (args_->thread > 1) { 1209 | std::vector threads; 1210 | for (int32_t i = 0; i < args_->thread; i++) { 1211 | threads.push_back(std::thread([=]() { trainThread(i); })); 1212 | } 1213 | for (auto it = threads.begin(); it != threads.end(); ++it) { 1214 | it->join(); 1215 | } 1216 | } else { 1217 | trainThread(0); 1218 | } 1219 | model_ = std::make_shared(input_, output_, args_, 0); 1220 | 1221 | saveModel(); 1222 | saveVectors(); 1223 | if (args_->saveOutput > 0) { 1224 | saveOutput(); 1225 | } 1226 | 1227 | } 1228 | 1229 | int FastText::getDimension() const { 1230 | return args_->dim; 1231 | } 1232 | 1233 | } 1234 | -------------------------------------------------------------------------------- /src/fasttext.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. An additional grant 7 | * of patent rights can be found in the PATENTS file in the same directory. 
8 | */ 9 | 10 | #ifndef FASTTEXT_FASTTEXT_H 11 | #define FASTTEXT_FASTTEXT_H 12 | 13 | #define FASTTEXT_VERSION 11 /* Version 1a */ 14 | #define FASTTEXT_FILEFORMAT_MAGIC_INT32 793712314 15 | 16 | #include 17 | 18 | #include 19 | #include 20 | #include 21 | 22 | #include "args.h" 23 | #include "dictionary.h" 24 | #include "matrix.h" 25 | #include "qmatrix.h" 26 | #include "model.h" 27 | #include "real.h" 28 | #include "utils.h" 29 | #include "vector.h" 30 | 31 | namespace fasttext { 32 | 33 | class FastText { 34 | private: 35 | std::shared_ptr args_; 36 | std::shared_ptr dict_; 37 | 38 | std::shared_ptr input_; 39 | std::shared_ptr output_; 40 | 41 | std::shared_ptr qinput_; 42 | std::shared_ptr qoutput_; 43 | 44 | std::shared_ptr model_; 45 | 46 | std::atomic tokenCount; 47 | clock_t start; 48 | void signModel(std::ostream&); 49 | bool checkModel(std::istream&); 50 | 51 | bool quant_; 52 | 53 | public: 54 | FastText(); 55 | 56 | void getVector(Vector&, const std::string&) const; 57 | void getVector(Vector&, int32_t) const; 58 | void getVector_ckpt(Vector&, int32_t) const; 59 | void saveVectors(); 60 | void saveVectors(int32_t); 61 | void saveOutput(); 62 | void saveModel(); 63 | void saveModel(int32_t); 64 | void saveDict(); 65 | 66 | void loadModel(const std::string&, const bool inference_mode = false, const int timeout_sec = -1); 67 | void loadModel(std::istream&); 68 | void loadModelForInference(std::istream&, const std::string&, const int); 69 | void loadDict(const std::string&); 70 | void loadDict(std::istream&); 71 | void printInfo(real, real); 72 | 73 | void supervised(Model&, real, const std::vector&, 74 | const std::vector&); 75 | void cbow(Model&, real, const std::vector&); 76 | void bisent2vec(Model&, real, const std::vector&); 77 | void transgram(Model&, real, const std::vector&); 78 | std::vector selectEmbeddings(int32_t) const; 79 | void quantize(std::shared_ptr); 80 | void test(std::istream&, int32_t); 81 | void predict(std::istream&, int32_t, 
bool); 82 | void predict(std::istream&, int32_t, std::vector>&) const; 83 | void wordVectors(); 84 | void sentenceVectors(); 85 | void ngramVectors(std::string); 86 | void textVectors(); 87 | void textVectorThread(int, std::shared_ptr>, std::shared_ptr, int); 88 | void textVectors(std::vector&, int, std::vector&); 89 | void textVector(std::string, Vector&, std::vector&, std::vector&); 90 | void printWordVectors(); 91 | void printVocabularyVectors(bool); 92 | void printSentenceVectors(); 93 | std::vector getVocab(); 94 | std::vector getUnigramsCounts(); 95 | void trainThread(int32_t); 96 | 97 | void savedDictTrain(std::shared_ptr); 98 | void trainDict(std::shared_ptr); 99 | 100 | 101 | void train(std::shared_ptr); 102 | void precomputeWordVectors(Matrix&); 103 | void precomputeSentenceVectors(Matrix&,std::ifstream&); 104 | void findNN(const Matrix&, const Vector&, int32_t, 105 | const std::set&); 106 | void findNNSent(const Matrix&, const Vector&, int32_t, 107 | const std::set&, int64_t, const std::vector&); 108 | void nn(int32_t); 109 | void analogies(int32_t); 110 | void nnSent(int32_t, std::string ); 111 | void analogiesSent(int32_t, std::string ); 112 | 113 | void loadVectors(std::string); 114 | void loadOutputVectors(std::string); 115 | int getDimension() const; 116 | }; 117 | 118 | } 119 | 120 | #endif 121 | -------------------------------------------------------------------------------- /src/main.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. An additional grant 7 | * of patent rights can be found in the PATENTS file in the same directory. 
8 | */ 9 | 10 | #include 11 | 12 | #include "fasttext.h" 13 | #include "args.h" 14 | 15 | using namespace fasttext; 16 | 17 | void printUsage() { 18 | std::cerr 19 | << "usage: fasttext \n\n" 20 | << "The commands supported by fasttext are:\n\n" 21 | << " bisent2vec train unsupervised sentence embeddings\n" 22 | << " quantize quantize a model to reduce the memory usage\n" 23 | << " test evaluate a supervised classifier\n" 24 | << " predict predict most likely labels\n" 25 | << " predict-prob predict most likely labels with probabilities\n" 26 | << " transgram train a skipgram model\n" 27 | << " print-word-vectors print word vectors given a trained model\n" 28 | << " print-sentence-vectors print sentence vectors given a trained model\n" 29 | << " print-vocabulary-vectors print unigram vectors in vocabulary\n" 30 | << " print-vocabulary print word in vocabulary with count\n" 31 | << " nn query for nearest neighbors\n" 32 | << " nnSent query for nearest neighbors for sentences\n" 33 | << " analogies query for analogies\n" 34 | << " analogiesSent query for analogies for Sentences\n" 35 | << std::endl; 36 | } 37 | 38 | void printQuantizeUsage() { 39 | std::cerr 40 | << "usage: fasttext quantize " 41 | << std::endl; 42 | } 43 | 44 | void printTestUsage() { 45 | std::cerr 46 | << "usage: fasttext test []\n\n" 47 | << " model filename\n" 48 | << " test data filename (if -, read from stdin)\n" 49 | << " (optional; 1 by default) predict top k labels\n" 50 | << std::endl; 51 | } 52 | 53 | void printPredictUsage() { 54 | std::cerr 55 | << "usage: fasttext predict[-prob] []\n\n" 56 | << " model filename\n" 57 | << " test data filename (if -, read from stdin)\n" 58 | << " (optional; 1 by default) predict top k labels\n" 59 | << std::endl; 60 | } 61 | 62 | void printPrintWordVectorsUsage() { 63 | std::cerr 64 | << "usage: fasttext print-word-vectors \n\n" 65 | << " model filename\n" 66 | << std::endl; 67 | } 68 | 69 | void printPrintSentenceVectorsUsage() { 70 | std::cerr 71 | << 
"usage: fasttext print-sentence-vectors \n\n" 72 | << " model filename\n" 73 | << std::endl; 74 | } 75 | 76 | void printPrintVocabularyVectorsUsage() { 77 | std::cerr 78 | << "usage: fasttext print-vocabulary-vectors \n\n" 79 | << " model filename\n" 80 | << " embedding matrix: input|output\n" 81 | << std::endl; 82 | } 83 | 84 | void printPrintVocabularyUsage() { 85 | std::cerr 86 | << "usage: fasttext print-vocabulary \n\n" 87 | << " model filename\n" 88 | << std::endl; 89 | } 90 | 91 | void printPrintNgramsUsage() { 92 | std::cerr 93 | << "usage: fasttext print-ngrams \n\n" 94 | << " model filename\n" 95 | << " word to print\n" 96 | << std::endl; 97 | } 98 | 99 | void quantize(int argc, char** argv) { 100 | std::shared_ptr a = std::make_shared(); 101 | if (argc < 3) { 102 | printQuantizeUsage(); 103 | a->printHelp(); 104 | exit(EXIT_FAILURE); 105 | } 106 | a->parseArgs(argc, argv); 107 | FastText fasttext; 108 | fasttext.quantize(a); 109 | exit(0); 110 | } 111 | 112 | void printNNUsage() { 113 | std::cout 114 | << "usage: fasttext nn \n\n" 115 | << " model filename\n" 116 | << " (optional; 10 by default) predict top k labels\n" 117 | << std::endl; 118 | } 119 | 120 | void printNNSentUsage() { 121 | std::cerr 122 | << "usage: fasttext nnSent \n\n" 123 | << " model filename\n" 124 | << " corpus filename \n" 125 | << " (optional; 10 by default) predict top k labels\n" 126 | << std::endl; 127 | std::cout<<"NOTE : A corpus file is required to find similar sentences."< \n\n" 133 | << " model filename\n" 134 | << " (optional; 10 by default) predict top k labels\n" 135 | << std::endl; 136 | } 137 | 138 | void printAnalogiesSentUsage() { 139 | std::cout 140 | << "usage: fasttext analogiesSent \n\n" 141 | << " model filename\n" 142 | << " corpus filename \n" 143 | << " (optional; 10 by default) predict top k labels\n" 144 | << std::endl; 145 | std::cout<<"NOTE : A corpus file is required to find similar sentences."< 5) { 150 | printTestUsage(); 151 | exit(EXIT_FAILURE); 
152 | } 153 | int32_t k = 1; 154 | if (argc >= 5) { 155 | k = atoi(argv[4]); 156 | } 157 | 158 | FastText fasttext; 159 | fasttext.loadModel(std::string(argv[2])); 160 | 161 | std::string infile(argv[3]); 162 | if (infile == "-") { 163 | fasttext.test(std::cin, k); 164 | } else { 165 | std::ifstream ifs(infile); 166 | if (!ifs.is_open()) { 167 | std::cerr << "Test file cannot be opened!" << std::endl; 168 | exit(EXIT_FAILURE); 169 | } 170 | fasttext.test(ifs, k); 171 | ifs.close(); 172 | } 173 | exit(0); 174 | } 175 | 176 | void predict(int argc, char** argv) { 177 | if (argc < 4 || argc > 5) { 178 | printPredictUsage(); 179 | exit(EXIT_FAILURE); 180 | } 181 | int32_t k = 1; 182 | if (argc >= 5) { 183 | k = atoi(argv[4]); 184 | } 185 | 186 | bool print_prob = std::string(argv[1]) == "predict-prob"; 187 | FastText fasttext; 188 | fasttext.loadModel(std::string(argv[2])); 189 | 190 | std::string infile(argv[3]); 191 | if (infile == "-") { 192 | fasttext.predict(std::cin, k, print_prob); 193 | } else { 194 | std::ifstream ifs(infile); 195 | if (!ifs.is_open()) { 196 | std::cerr << "Input file cannot be opened!" 
<< std::endl; 197 | exit(EXIT_FAILURE); 198 | } 199 | fasttext.predict(ifs, k, print_prob); 200 | ifs.close(); 201 | } 202 | 203 | exit(0); 204 | } 205 | 206 | void printWordVectors(int argc, char** argv) { 207 | if (argc != 3) { 208 | printPrintWordVectorsUsage(); 209 | exit(EXIT_FAILURE); 210 | } 211 | FastText fasttext; 212 | fasttext.loadModel(std::string(argv[2])); 213 | fasttext.printWordVectors(); 214 | exit(0); 215 | } 216 | 217 | void printSentenceVectors(int argc, char** argv) { 218 | if (argc != 3) { 219 | printPrintSentenceVectorsUsage(); 220 | exit(EXIT_FAILURE); 221 | } 222 | FastText fasttext; 223 | fasttext.loadModel(std::string(argv[2])); 224 | fasttext.printSentenceVectors(); 225 | exit(0); 226 | } 227 | 228 | void printVocabularyVectors(int argc, char** argv) { 229 | if (argc != 4) { 230 | printPrintVocabularyVectorsUsage(); 231 | exit(EXIT_FAILURE); 232 | } 233 | std::string matrix_type = std::string(argv[3]); 234 | if (matrix_type != "input" && matrix_type != "output") { 235 | printPrintVocabularyVectorsUsage(); 236 | exit(EXIT_FAILURE); 237 | } 238 | FastText fasttext; 239 | fasttext.loadModel(std::string(argv[2])); 240 | fasttext.printVocabularyVectors(matrix_type == "input"); 241 | exit(0); 242 | } 243 | 244 | void printVocabulary(int argc, char** argv) { 245 | if (argc != 3) { 246 | printPrintVocabularyUsage(); 247 | exit(EXIT_FAILURE); 248 | } 249 | 250 | FastText fasttext; 251 | fasttext.loadModel(std::string(argv[2])); 252 | std::vector vocabulary = fasttext.getVocab(); 253 | std::vector counts = fasttext.getUnigramsCounts(); 254 | for (int i = 0; i < vocabulary.size(); i++) { 255 | std::cout << vocabulary[i] << " " << counts[i] << std::endl; 256 | } 257 | exit(0); 258 | } 259 | 260 | void printNgrams(int argc, char** argv) { 261 | if (argc != 4) { 262 | printPrintNgramsUsage(); 263 | exit(EXIT_FAILURE); 264 | } 265 | FastText fasttext; 266 | fasttext.loadModel(std::string(argv[2])); 267 | fasttext.ngramVectors(std::string(argv[3])); 268 
| exit(0); 269 | } 270 | 271 | void nn(int argc, char** argv) { 272 | int32_t k; 273 | if (argc == 3) { 274 | k = 10; 275 | } else if (argc == 4) { 276 | k = atoi(argv[3]); 277 | } else { 278 | printNNUsage(); 279 | exit(EXIT_FAILURE); 280 | } 281 | FastText fasttext; 282 | fasttext.loadModel(std::string(argv[2])); 283 | fasttext.nn(k); 284 | exit(0); 285 | } 286 | 287 | void nnSent(int argc, char** argv) { 288 | int32_t k; 289 | if (argc == 4) { 290 | k = 10; 291 | } else if (argc == 5) { 292 | k = atoi(argv[4]); 293 | } else { 294 | printNNSentUsage(); 295 | exit(EXIT_FAILURE); 296 | } 297 | FastText fasttext; 298 | fasttext.loadModel(std::string(argv[2])); 299 | fasttext.nnSent(k,std::string(argv[3])); 300 | exit(0); 301 | } 302 | 303 | 304 | void analogies(int argc, char** argv) { 305 | int32_t k; 306 | if (argc == 3) { 307 | k = 10; 308 | } else if (argc == 4) { 309 | k = atoi(argv[3]); 310 | } else { 311 | printAnalogiesUsage(); 312 | exit(EXIT_FAILURE); 313 | } 314 | FastText fasttext; 315 | fasttext.loadModel(std::string(argv[2])); 316 | fasttext.analogies(k); 317 | exit(0); 318 | } 319 | 320 | void analogiesSent(int argc, char** argv) { 321 | int32_t k; 322 | if (argc == 4) { 323 | k = 10; 324 | } else if (argc == 5) { 325 | k = atoi(argv[4]); 326 | } else { 327 | printAnalogiesSentUsage(); 328 | exit(EXIT_FAILURE); 329 | } 330 | FastText fasttext; 331 | fasttext.loadModel(std::string(argv[2])); 332 | fasttext.analogiesSent(k,std::string(argv[3])); 333 | exit(0); 334 | } 335 | 336 | void train(int argc, char** argv) { 337 | std::shared_ptr a = std::make_shared(); 338 | a->parseArgs(argc, argv); 339 | FastText fasttext; 340 | fasttext.train(a); 341 | } 342 | 343 | void saveDict(int argc, char** argv) { 344 | std::shared_ptr a = std::make_shared(); 345 | a->parseArgs(argc, argv); 346 | FastText fasttext; 347 | fasttext.trainDict(a); 348 | } 349 | 350 | void trainFromDict(int argc, char** argv) { 351 | std::shared_ptr a = std::make_shared(); 352 | 
a->parseArgs(argc, argv); 353 | FastText fasttext; 354 | fasttext.savedDictTrain(a); 355 | } 356 | 357 | int main(int argc, char** argv) { 358 | if (argc < 2) { 359 | printUsage(); 360 | exit(EXIT_FAILURE); 361 | } 362 | std::string command(argv[1]); 363 | if (command == "transgram" || command == "bisent2vec" ) { 364 | train(argc, argv); 365 | } else if (command == "test") { 366 | test(argc, argv); 367 | } else if (command == "saveDict") { 368 | saveDict(argc, argv); 369 | } else if (command == "sent2vecFromDict") { 370 | trainFromDict(argc, argv); 371 | } else if (command == "quantize") { 372 | quantize(argc, argv); 373 | } else if (command == "print-word-vectors") { 374 | printWordVectors(argc, argv); 375 | } else if (command == "print-sentence-vectors") { 376 | printSentenceVectors(argc, argv); 377 | } else if (command == "print-vocabulary-vectors") { 378 | printVocabularyVectors(argc, argv); 379 | } else if (command == "print-vocabulary") { 380 | printVocabulary(argc, argv); 381 | } else if (command == "print-ngrams") { 382 | printNgrams(argc, argv); 383 | } else if (command == "nn") { 384 | nn(argc, argv); 385 | } else if (command == "nnSent") { 386 | nnSent(argc, argv); 387 | } else if (command == "analogies") { 388 | analogies(argc, argv); 389 | } else if (command == "analogiesSent") { 390 | analogiesSent(argc, argv); 391 | } else if (command == "predict" || command == "predict-prob" ) { 392 | predict(argc, argv); 393 | } else { 394 | printUsage(); 395 | exit(EXIT_FAILURE); 396 | } 397 | return 0; 398 | } 399 | -------------------------------------------------------------------------------- /src/matrix.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 
An additional grant 7 | * of patent rights can be found in the PATENTS file in the same directory. 8 | */ 9 | 10 | #include "matrix.h" 11 | 12 | #include 13 | 14 | #include 15 | 16 | #include "utils.h" 17 | #include "vector.h" 18 | 19 | namespace fasttext { 20 | 21 | Matrix::Matrix() { 22 | m_ = 0; 23 | n_ = 0; 24 | data_ = nullptr; 25 | } 26 | 27 | Matrix::Matrix(int64_t m, int64_t n) { 28 | m_ = m; 29 | n_ = n; 30 | data_ = new real[m * n]; 31 | } 32 | 33 | Matrix::Matrix(const Matrix& other) { 34 | m_ = other.m_; 35 | n_ = other.n_; 36 | data_ = new real[m_ * n_]; 37 | for (int64_t i = 0; i < (m_ * n_); i++) { 38 | data_[i] = other.data_[i]; 39 | } 40 | } 41 | 42 | Matrix& Matrix::operator=(const Matrix& other) { 43 | Matrix temp(other); 44 | m_ = temp.m_; 45 | n_ = temp.n_; 46 | std::swap(data_, temp.data_); 47 | return *this; 48 | } 49 | 50 | Matrix::~Matrix() { 51 | delete[] data_; 52 | } 53 | 54 | void Matrix::zero() { 55 | for (int64_t i = 0; i < (m_ * n_); i++) { 56 | data_[i] = 0.0; 57 | } 58 | } 59 | 60 | void Matrix::uniform(real a) { 61 | std::minstd_rand rng(1); 62 | std::uniform_real_distribution<> uniform(-a, a); 63 | for (int64_t i = 0; i < (m_ * n_); i++) { 64 | data_[i] = uniform(rng); 65 | } 66 | } 67 | 68 | real Matrix::dotRow(const Vector& vec, int64_t i) const { 69 | assert(i >= 0); 70 | assert(i < m_); 71 | assert(vec.size() == n_); 72 | real d = 0.0; 73 | for (int64_t j = 0; j < n_; j++) { 74 | d += at(i, j) * vec.data_[j]; 75 | } 76 | return d; 77 | } 78 | 79 | void Matrix::addRow(const Vector& vec, int64_t i, real a) { 80 | assert(i >= 0); 81 | assert(i < m_); 82 | assert(vec.size() == n_); 83 | for (int64_t j = 0; j < n_; j++) { 84 | data_[i * n_ + j] += a * vec.data_[j]; 85 | } 86 | } 87 | 88 | void Matrix::multiplyRow(const Vector& nums, int64_t ib, int64_t ie) { 89 | if (ie == -1) {ie = m_;} 90 | assert(ie <= nums.size()); 91 | for (auto i = ib; i < ie; i++) { 92 | real n = nums[i-ib]; 93 | if (n != 0) { 94 | for (auto j = 0; j < n_; 
j++) { 95 | at(i, j) *= n; 96 | } 97 | } 98 | } 99 | } 100 | 101 | void Matrix::divideRow(const Vector& denoms, int64_t ib, int64_t ie) { 102 | if (ie == -1) {ie = m_;} 103 | assert(ie <= denoms.size()); 104 | for (auto i = ib; i < ie; i++) { 105 | real n = denoms[i-ib]; 106 | if (n != 0) { 107 | for (auto j = 0; j < n_; j++) { 108 | at(i, j) /= n; 109 | } 110 | } 111 | } 112 | } 113 | 114 | real Matrix::l2NormRow(int64_t i) const { 115 | auto norm = 0.0; 116 | for (auto j = 0; j < n_; j++) { 117 | const real v = at(i,j); 118 | norm += v * v; 119 | } 120 | return std::sqrt(norm); 121 | } 122 | 123 | void Matrix::l2NormRow(Vector& norms) const { 124 | assert(norms.size() == m_); 125 | for (auto i = 0; i < m_; i++) { 126 | norms[i] = l2NormRow(i); 127 | } 128 | } 129 | 130 | void Matrix::save(std::ostream& out) { 131 | out.write((char*) &m_, sizeof(int64_t)); 132 | out.write((char*) &n_, sizeof(int64_t)); 133 | out.write((char*) data_, m_ * n_ * sizeof(real)); 134 | } 135 | 136 | void Matrix::load(std::istream& in) { 137 | in.read((char*) &m_, sizeof(int64_t)); 138 | in.read((char*) &n_, sizeof(int64_t)); 139 | delete[] data_; 140 | data_ = new real[m_ * n_]; 141 | in.read((char*) data_, m_ * n_ * sizeof(real)); 142 | } 143 | 144 | } 145 | -------------------------------------------------------------------------------- /src/matrix.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. An additional grant 7 | * of patent rights can be found in the PATENTS file in the same directory. 
8 | */ 9 | 10 | #ifndef FASTTEXT_MATRIX_H 11 | #define FASTTEXT_MATRIX_H 12 | 13 | #include 14 | #include 15 | #include 16 | 17 | #include "real.h" 18 | 19 | namespace fasttext { 20 | 21 | class Vector; 22 | 23 | class Matrix { 24 | 25 | public: 26 | real* data_; 27 | int64_t m_; 28 | int64_t n_; 29 | 30 | Matrix(); 31 | Matrix(int64_t, int64_t); 32 | Matrix(const Matrix&); 33 | Matrix& operator=(const Matrix&); 34 | ~Matrix(); 35 | 36 | inline const real& at(int64_t i, int64_t j) const {return data_[i * n_ + j];}; 37 | inline real& at(int64_t i, int64_t j) {return data_[i * n_ + j];}; 38 | 39 | 40 | void zero(); 41 | void uniform(real); 42 | real dotRow(const Vector&, int64_t) const; 43 | void addRow(const Vector&, int64_t, real); 44 | 45 | void multiplyRow(const Vector& nums, int64_t ib = 0, int64_t ie = -1); 46 | void divideRow(const Vector& denoms, int64_t ib = 0, int64_t ie = -1); 47 | 48 | real l2NormRow(int64_t i) const; 49 | void l2NormRow(Vector& norms) const; 50 | 51 | void save(std::ostream&); 52 | void load(std::istream&); 53 | }; 54 | 55 | } 56 | 57 | #endif 58 | -------------------------------------------------------------------------------- /src/model.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. An additional grant 7 | * of patent rights can be found in the PATENTS file in the same directory. 
8 | */ 9 | 10 | #include "model.h" 11 | 12 | #include 13 | #include 14 | #include 15 | 16 | namespace fasttext { 17 | 18 | Model::Model(std::shared_ptr wi, 19 | std::shared_ptr wo, 20 | std::shared_ptr args, 21 | int32_t seed) 22 | : hidden_(args->dim), output_(wo->m_), 23 | grad_(args->dim), rng(seed), quant_(false) 24 | { 25 | wi_ = wi; 26 | wo_ = wo; 27 | args_ = args; 28 | osz_ = wo->m_; 29 | hsz_ = args->dim; 30 | negpos = 0; 31 | loss_ = 0.0; 32 | nexamples_ = 1; 33 | initSigmoid(); 34 | initLog(); 35 | } 36 | 37 | Model::~Model() { 38 | delete[] t_sigmoid; 39 | delete[] t_log; 40 | } 41 | 42 | void Model::setQuantizePointer(std::shared_ptr qwi, 43 | std::shared_ptr qwo, bool qout) { 44 | qwi_ = qwi; 45 | qwo_ = qwo; 46 | if (qout) { 47 | osz_ = qwo_->getM(); 48 | } 49 | } 50 | 51 | real Model::binaryLogistic(int32_t target, bool label, real lr) { 52 | real score = sigmoid(wo_->dotRow(hidden_, target)); 53 | real alpha = lr * (real(label) - score); 54 | grad_.addRow(*wo_, target, alpha); 55 | wo_->addRow(hidden_, target, alpha); 56 | if (label) { 57 | return -log(score); 58 | } else { 59 | return -log(1.0 - score); 60 | } 61 | } 62 | 63 | real Model::negativeSampling(int32_t target, real lr) { 64 | real loss = 0.0; 65 | grad_.zero(); 66 | for (int32_t n = 0; n <= args_->neg; n++) { 67 | if (n == 0) { 68 | loss += binaryLogistic(target, true, lr); 69 | } else { 70 | loss += binaryLogistic(getNegative(target), false, lr); 71 | } 72 | } 73 | return loss; 74 | } 75 | 76 | real Model::hierarchicalSoftmax(int32_t target, real lr) { 77 | real loss = 0.0; 78 | grad_.zero(); 79 | const std::vector& binaryCode = codes[target]; 80 | const std::vector& pathToRoot = paths[target]; 81 | for (int32_t i = 0; i < pathToRoot.size(); i++) { 82 | loss += binaryLogistic(pathToRoot[i], binaryCode[i], lr); 83 | } 84 | return loss; 85 | } 86 | 87 | void Model::computeOutputSoftmax(Vector& hidden, Vector& output) const { 88 | if (quant_ && args_->qout) { 89 | output.mul(*qwo_, 
hidden); 90 | } else { 91 | output.mul(*wo_, hidden); 92 | } 93 | real max = output[0], z = 0.0; 94 | for (int32_t i = 0; i < osz_; i++) { 95 | max = std::max(output[i], max); 96 | } 97 | for (int32_t i = 0; i < osz_; i++) { 98 | output[i] = exp(output[i] - max); 99 | z += output[i]; 100 | } 101 | for (int32_t i = 0; i < osz_; i++) { 102 | output[i] /= z; 103 | } 104 | } 105 | 106 | void Model::computeOutputSoftmax() { 107 | computeOutputSoftmax(hidden_, output_); 108 | } 109 | 110 | real Model::softmax(int32_t target, real lr) { 111 | grad_.zero(); 112 | computeOutputSoftmax(); 113 | for (int32_t i = 0; i < osz_; i++) { 114 | real label = (i == target) ? 1.0 : 0.0; 115 | real alpha = lr * (label - output_[i]); 116 | grad_.addRow(*wo_, i, alpha); 117 | wo_->addRow(hidden_, i, alpha); 118 | } 119 | return -log(output_[target]); 120 | } 121 | 122 | void Model::computeHidden(const std::vector& input, Vector& hidden) const { 123 | assert(hidden.size() == hsz_); 124 | hidden.zero(); 125 | for (auto it = input.cbegin(); it != input.cend(); ++it) { 126 | if(quant_) { 127 | hidden.addRow(*qwi_, *it); 128 | } else { 129 | hidden.addRow(*wi_, *it); 130 | } 131 | } 132 | hidden.mul(1.0 / input.size()); 133 | } 134 | 135 | bool Model::comparePairs(const std::pair &l, 136 | const std::pair &r) { 137 | return l.first > r.first; 138 | } 139 | 140 | void Model::predict(const std::vector& input, int32_t k, 141 | std::vector>& heap, 142 | Vector& hidden, Vector& output) const { 143 | assert(k > 0); 144 | heap.reserve(k + 1); 145 | computeHidden(input, hidden); 146 | if (args_->loss == loss_name::hs) { 147 | dfs(k, 2 * osz_ - 2, 0.0, heap, hidden); 148 | } else { 149 | findKBest(k, heap, hidden, output); 150 | } 151 | std::sort_heap(heap.begin(), heap.end(), comparePairs); 152 | } 153 | 154 | void Model::predict(const std::vector& input, int32_t k, 155 | std::vector>& heap) { 156 | predict(input, k, heap, hidden_, output_); 157 | } 158 | 159 | void Model::findKBest(int32_t k, 
std::vector>& heap, 160 | Vector& hidden, Vector& output) const { 161 | computeOutputSoftmax(hidden, output); 162 | for (int32_t i = 0; i < osz_; i++) { 163 | if (heap.size() == k && log(output[i]) < heap.front().first) { 164 | continue; 165 | } 166 | heap.push_back(std::make_pair(log(output[i]), i)); 167 | std::push_heap(heap.begin(), heap.end(), comparePairs); 168 | if (heap.size() > k) { 169 | std::pop_heap(heap.begin(), heap.end(), comparePairs); 170 | heap.pop_back(); 171 | } 172 | } 173 | } 174 | 175 | void Model::dfs(int32_t k, int32_t node, real score, 176 | std::vector>& heap, 177 | Vector& hidden) const { 178 | if (heap.size() == k && score < heap.front().first) { 179 | return; 180 | } 181 | 182 | if (tree[node].left == -1 && tree[node].right == -1) { 183 | heap.push_back(std::make_pair(score, node)); 184 | std::push_heap(heap.begin(), heap.end(), comparePairs); 185 | if (heap.size() > k) { 186 | std::pop_heap(heap.begin(), heap.end(), comparePairs); 187 | heap.pop_back(); 188 | } 189 | return; 190 | } 191 | 192 | real f; 193 | if (quant_ && args_->qout) { 194 | f= sigmoid(qwo_->dotRow(hidden, node - osz_)); 195 | } else { 196 | f= sigmoid(wo_->dotRow(hidden, node - osz_)); 197 | } 198 | 199 | dfs(k, tree[node].left, score + log(1.0 - f), heap, hidden); 200 | dfs(k, tree[node].right, score + log(f), heap, hidden); 201 | } 202 | 203 | void Model::update(const std::vector& input, int32_t target, real lr) { 204 | assert(target >= 0); 205 | assert(target < osz_); 206 | if (input.size() == 0) return; 207 | computeHidden(input, hidden_); 208 | if (args_->loss == loss_name::ns) { 209 | loss_ += negativeSampling(target, lr); 210 | } else if (args_->loss == loss_name::hs) { 211 | loss_ += hierarchicalSoftmax(target, lr); 212 | } else { 213 | loss_ += softmax(target, lr); 214 | } 215 | nexamples_ += 1; 216 | 217 | if (args_->model == model_name::bisent2vec) { 218 | grad_.mul(1.0 / input.size()); 219 | } 220 | for (auto it = input.cbegin(); it != input.cend(); ++it) 
{ 221 | wi_->addRow(grad_, *it, 1.0); 222 | } 223 | } 224 | 225 | void Model::update(const std::vector& input, int32_t target, real lr, real boostNgrams, int32_t lowIdx, int32_t hiIdx) { 226 | assert(target >= 0); 227 | assert(target < osz_); 228 | if (input.size() == 0) return; 229 | computeHidden(input, hidden_); 230 | if (args_->loss == loss_name::ns) { 231 | loss_ += negativeSampling(target, lr); 232 | } else if (args_->loss == loss_name::hs) { 233 | loss_ += hierarchicalSoftmax(target, lr); 234 | } else { 235 | loss_ += softmax(target, lr); 236 | } 237 | nexamples_ += 1; 238 | 239 | //if (args_->model == model_name::sup || args_->model == model_name::bisent2vec) { 240 | // grad_.mul(1.0 / input.size()); 241 | //} 242 | for (auto it = input.cbegin(); it != input.cend(); ++it) { 243 | if (*it > lowIdx && *it < hiIdx) { 244 | wi_->addRow(grad_, *it, boostNgrams); 245 | } else { 246 | wi_->addRow(grad_, *it, 1.0); 247 | } 248 | } 249 | } 250 | 251 | 252 | void Model::setTargetCounts(const std::vector& counts) { 253 | assert(counts.size() == osz_); 254 | if (args_->loss == loss_name::ns) { 255 | initTableNegatives(counts); 256 | } 257 | if (args_->loss == loss_name::hs) { 258 | buildTree(counts); 259 | } 260 | } 261 | 262 | void Model::initTableNegatives(const std::vector& counts) { 263 | real z = 0.0; 264 | for (size_t i = 0; i < counts.size(); i++) { 265 | z += pow(counts[i], 0.5); 266 | } 267 | for (size_t i = 0; i < counts.size(); i++) { 268 | real c = pow(counts[i], 0.5); 269 | for (size_t j = 0; j < c * NEGATIVE_TABLE_SIZE / z; j++) { 270 | negatives.push_back(i); 271 | } 272 | } 273 | std::shuffle(negatives.begin(), negatives.end(), rng); 274 | } 275 | 276 | int32_t Model::getNegative(int32_t target) { 277 | int32_t negative; 278 | do { 279 | negative = negatives[negpos]; 280 | negpos = (negpos + 1) % negatives.size(); 281 | } while (target == negative); 282 | return negative; 283 | } 284 | 285 | void Model::buildTree(const std::vector& counts) { 286 | 
tree.resize(2 * osz_ - 1); 287 | for (int32_t i = 0; i < 2 * osz_ - 1; i++) { 288 | tree[i].parent = -1; 289 | tree[i].left = -1; 290 | tree[i].right = -1; 291 | tree[i].count = 1e15; 292 | tree[i].binary = false; 293 | } 294 | for (int32_t i = 0; i < osz_; i++) { 295 | tree[i].count = counts[i]; 296 | } 297 | int32_t leaf = osz_ - 1; 298 | int32_t node = osz_; 299 | for (int32_t i = osz_; i < 2 * osz_ - 1; i++) { 300 | int32_t mini[2]; 301 | for (int32_t j = 0; j < 2; j++) { 302 | if (leaf >= 0 && tree[leaf].count < tree[node].count) { 303 | mini[j] = leaf--; 304 | } else { 305 | mini[j] = node++; 306 | } 307 | } 308 | tree[i].left = mini[0]; 309 | tree[i].right = mini[1]; 310 | tree[i].count = tree[mini[0]].count + tree[mini[1]].count; 311 | tree[mini[0]].parent = i; 312 | tree[mini[1]].parent = i; 313 | tree[mini[1]].binary = true; 314 | } 315 | for (int32_t i = 0; i < osz_; i++) { 316 | std::vector path; 317 | std::vector code; 318 | int32_t j = i; 319 | while (tree[j].parent != -1) { 320 | path.push_back(tree[j].parent - osz_); 321 | code.push_back(tree[j].binary); 322 | j = tree[j].parent; 323 | } 324 | paths.push_back(path); 325 | codes.push_back(code); 326 | } 327 | } 328 | 329 | real Model::getLoss() const { 330 | return loss_ / nexamples_; 331 | } 332 | 333 | void Model::initSigmoid() { 334 | t_sigmoid = new real[SIGMOID_TABLE_SIZE + 1]; 335 | for (int i = 0; i < SIGMOID_TABLE_SIZE + 1; i++) { 336 | real x = real(i * 2 * MAX_SIGMOID) / SIGMOID_TABLE_SIZE - MAX_SIGMOID; 337 | t_sigmoid[i] = 1.0 / (1.0 + std::exp(-x)); 338 | } 339 | } 340 | 341 | void Model::initLog() { 342 | t_log = new real[LOG_TABLE_SIZE + 1]; 343 | for (int i = 0; i < LOG_TABLE_SIZE + 1; i++) { 344 | real x = (real(i) + 1e-5) / LOG_TABLE_SIZE; 345 | t_log[i] = std::log(x); 346 | } 347 | } 348 | 349 | real Model::log(real x) const { 350 | if (x > 1.0) { 351 | return 0.0; 352 | } 353 | int i = int(x * LOG_TABLE_SIZE); 354 | return t_log[i]; 355 | } 356 | 357 | real Model::sigmoid(real x) 
const { 358 | if (x < -MAX_SIGMOID) { 359 | return 0.0; 360 | } else if (x > MAX_SIGMOID) { 361 | return 1.0; 362 | } else { 363 | int i = int((x + MAX_SIGMOID) * SIGMOID_TABLE_SIZE / MAX_SIGMOID / 2); 364 | return t_sigmoid[i]; 365 | } 366 | } 367 | 368 | } 369 | -------------------------------------------------------------------------------- /src/model.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. An additional grant 7 | * of patent rights can be found in the PATENTS file in the same directory. 8 | */ 9 | 10 | #ifndef FASTTEXT_MODEL_H 11 | #define FASTTEXT_MODEL_H 12 | 13 | #include 14 | #include 15 | #include 16 | #include 17 | 18 | #include "args.h" 19 | #include "matrix.h" 20 | #include "vector.h" 21 | #include "qmatrix.h" 22 | #include "real.h" 23 | 24 | #define SIGMOID_TABLE_SIZE 512 25 | #define MAX_SIGMOID 8 26 | #define LOG_TABLE_SIZE 512 27 | 28 | namespace fasttext { 29 | 30 | struct Node { 31 | int32_t parent; 32 | int32_t left; 33 | int32_t right; 34 | int64_t count; 35 | bool binary; 36 | }; 37 | 38 | class Model { 39 | private: 40 | std::shared_ptr wi_; 41 | std::shared_ptr wo_; 42 | std::shared_ptr qwi_; 43 | std::shared_ptr qwo_; 44 | std::shared_ptr args_; 45 | Vector hidden_; 46 | Vector output_; 47 | Vector grad_; 48 | int32_t hsz_; 49 | int32_t isz_; 50 | int32_t osz_; 51 | real loss_; 52 | int64_t nexamples_; 53 | real* t_sigmoid; 54 | real* t_log; 55 | // used for negative sampling: 56 | std::vector negatives; 57 | size_t negpos; 58 | // used for hierarchical softmax: 59 | std::vector< std::vector > paths; 60 | std::vector< std::vector > codes; 61 | std::vector tree; 62 | 63 | static bool comparePairs(const std::pair&, 64 | const std::pair&); 65 | 66 | int32_t getNegative(int32_t 
target); 67 | void initSigmoid(); 68 | void initLog(); 69 | 70 | static const int32_t NEGATIVE_TABLE_SIZE = 10000000; 71 | 72 | public: 73 | Model(std::shared_ptr, std::shared_ptr, 74 | std::shared_ptr, int32_t); 75 | ~Model(); 76 | 77 | real binaryLogistic(int32_t, bool, real); 78 | real negativeSampling(int32_t, real); 79 | real hierarchicalSoftmax(int32_t, real); 80 | real softmax(int32_t, real); 81 | 82 | void predict(const std::vector&, int32_t, 83 | std::vector>&, 84 | Vector&, Vector&) const; 85 | void predict(const std::vector&, int32_t, 86 | std::vector>&); 87 | void dfs(int32_t, int32_t, real, 88 | std::vector>&, 89 | Vector&) const; 90 | void findKBest(int32_t, std::vector>&, 91 | Vector&, Vector&) const; 92 | void update(const std::vector&, int32_t, real); 93 | void update(const std::vector&, int32_t, real, real, int32_t, int32_t); 94 | void computeHidden(const std::vector&, Vector&) const; 95 | void computeOutputSoftmax(Vector&, Vector&) const; 96 | void computeOutputSoftmax(); 97 | 98 | void setTargetCounts(const std::vector&); 99 | void initTableNegatives(const std::vector&); 100 | void buildTree(const std::vector&); 101 | real getLoss() const; 102 | real sigmoid(real) const; 103 | real log(real) const; 104 | 105 | std::minstd_rand rng; 106 | bool quant_; 107 | void setQuantizePointer(std::shared_ptr, std::shared_ptr, bool); 108 | }; 109 | 110 | } 111 | 112 | #endif 113 | -------------------------------------------------------------------------------- /src/productquantizer.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. An additional grant 7 | * of patent rights can be found in the PATENTS file in the same directory. 
8 | */ 9 | 10 | #include "productquantizer.h" 11 | 12 | #include 13 | #include 14 | 15 | namespace fasttext { 16 | 17 | real distL2(const real* x, const real* y, int32_t d) { 18 | real dist = 0; 19 | for (auto i = 0; i < d; i++) { 20 | auto tmp = x[i] - y[i]; 21 | dist += tmp * tmp; 22 | } 23 | return dist; 24 | } 25 | 26 | ProductQuantizer::ProductQuantizer(int32_t dim, int32_t dsub): dim_(dim), 27 | nsubq_(dim / dsub), dsub_(dsub), centroids_(dim * ksub_), rng(seed_) { 28 | lastdsub_ = dim_ % dsub; 29 | if (lastdsub_ == 0) {lastdsub_ = dsub_;} 30 | else {nsubq_++;} 31 | } 32 | 33 | const real* ProductQuantizer::get_centroids(int32_t m, uint8_t i) const { 34 | if (m == nsubq_ - 1) {return ¢roids_[m * ksub_ * dsub_ + i * lastdsub_];} 35 | return ¢roids_[(m * ksub_ + i) * dsub_]; 36 | } 37 | 38 | real* ProductQuantizer::get_centroids(int32_t m, uint8_t i) { 39 | if (m == nsubq_ - 1) {return ¢roids_[m * ksub_ * dsub_ + i * lastdsub_];} 40 | return ¢roids_[(m * ksub_ + i) * dsub_]; 41 | } 42 | 43 | real ProductQuantizer::assign_centroid(const real * x, const real* c0, 44 | uint8_t* code, int32_t d) const { 45 | const real* c = c0; 46 | real dis = distL2(x, c, d); 47 | code[0] = 0; 48 | for (auto j = 1; j < ksub_; j++) { 49 | c += d; 50 | real disij = distL2(x, c, d); 51 | if (disij < dis) { 52 | code[0] = (uint8_t) j; 53 | dis = disij; 54 | } 55 | } 56 | return dis; 57 | } 58 | 59 | void ProductQuantizer::Estep(const real* x, const real* centroids, 60 | uint8_t* codes, int32_t d, 61 | int32_t n) const { 62 | for (auto i = 0; i < n; i++) { 63 | assign_centroid(x + i * d, centroids, codes + i, d); 64 | } 65 | } 66 | 67 | void ProductQuantizer::MStep(const real* x0, real* centroids, 68 | const uint8_t* codes, 69 | int32_t d, int32_t n) { 70 | std::vector nelts(ksub_, 0); 71 | memset(centroids, 0, sizeof(real) * d * ksub_); 72 | const real* x = x0; 73 | for (auto i = 0; i < n; i++) { 74 | auto k = codes[i]; 75 | real* c = centroids + k * d; 76 | for (auto j = 0; j < d; 
j++) { 77 | c[j] += x[j]; 78 | } 79 | nelts[k]++; 80 | x += d; 81 | } 82 | 83 | real* c = centroids; 84 | for (auto k = 0; k < ksub_; k++) { 85 | real z = (real) nelts[k]; 86 | if (z != 0) { 87 | for (auto j = 0; j < d; j++) { 88 | c[j] /= z; 89 | } 90 | } 91 | c += d; 92 | } 93 | 94 | std::uniform_real_distribution<> runiform(0,1); 95 | for (auto k = 0; k < ksub_; k++) { 96 | if (nelts[k] == 0) { 97 | int32_t m = 0; 98 | while (runiform(rng) * (n - ksub_) >= nelts[m] - 1) { 99 | m = (m + 1) % ksub_; 100 | } 101 | memcpy(centroids + k * d, centroids + m * d, sizeof(real) * d); 102 | for (auto j = 0; j < d; j++) { 103 | int32_t sign = (j % 2) * 2 - 1; 104 | centroids[k * d + j] += sign * eps_; 105 | centroids[m * d + j] -= sign * eps_; 106 | } 107 | nelts[k] = nelts[m] / 2; 108 | nelts[m] -= nelts[k]; 109 | } 110 | } 111 | } 112 | 113 | void ProductQuantizer::kmeans(const real *x, real* c, int32_t n, int32_t d) { 114 | std::vector perm(n,0); 115 | std::iota(perm.begin(), perm.end(), 0); 116 | std::shuffle(perm.begin(), perm.end(), rng); 117 | for (auto i = 0; i < ksub_; i++) { 118 | memcpy (&c[i * d], x + perm[i] * d, d * sizeof(real)); 119 | } 120 | uint8_t* codes = new uint8_t[n]; 121 | for (auto i = 0; i < niter_; i++) { 122 | Estep(x, c, codes, d, n); 123 | MStep(x, c, codes, d, n); 124 | } 125 | delete [] codes; 126 | } 127 | 128 | void ProductQuantizer::train(int32_t n, const real * x) { 129 | if (n < ksub_) { 130 | std::cerr<<"Matrix too small for quantization, must have > 256 rows"< perm(n, 0); 134 | std::iota(perm.begin(), perm.end(), 0); 135 | auto d = dsub_; 136 | auto np = std::min(n, max_points_); 137 | real* xslice = new real[np * dsub_]; 138 | for (auto m = 0; m < nsubq_; m++) { 139 | if (m == nsubq_-1) {d = lastdsub_;} 140 | if (np != n) {std::shuffle(perm.begin(), perm.end(), rng);} 141 | for (auto j = 0; j < np; j++) { 142 | memcpy (xslice + j * d, x + perm[j] * dim_ + m * dsub_, d * sizeof(real)); 143 | } 144 | kmeans(xslice, get_centroids(m, 0), 
np, d); 145 | } 146 | delete [] xslice; 147 | } 148 | 149 | real ProductQuantizer::mulcode(const Vector& x, const uint8_t* codes, 150 | int32_t t, real alpha) const { 151 | real res = 0.0; 152 | auto d = dsub_; 153 | const uint8_t* code = codes + nsubq_ * t; 154 | for (auto m = 0; m < nsubq_; m++) { 155 | const real* c = get_centroids(m, code[m]); 156 | if (m == nsubq_ - 1) {d = lastdsub_;} 157 | for(auto n = 0; n < d; n++) { 158 | res += x[m * dsub_ + n] * c[n]; 159 | } 160 | } 161 | return res * alpha; 162 | } 163 | 164 | void ProductQuantizer::addcode(Vector& x, const uint8_t* codes, 165 | int32_t t, real alpha) const { 166 | auto d = dsub_; 167 | const uint8_t* code = codes + nsubq_ * t; 168 | for (auto m = 0; m < nsubq_; m++) { 169 | const real* c = get_centroids(m, code[m]); 170 | if (m == nsubq_ - 1) {d = lastdsub_;} 171 | for(auto n = 0; n < d; n++) { 172 | x[m * dsub_ + n] += alpha * c[n]; 173 | } 174 | } 175 | } 176 | 177 | void ProductQuantizer::compute_code(const real* x, uint8_t* code) const { 178 | auto d = dsub_; 179 | for (auto m = 0; m < nsubq_; m++) { 180 | if (m == nsubq_ - 1) {d = lastdsub_;} 181 | assign_centroid(x + m * dsub_, get_centroids(m, 0), code + m, d); 182 | } 183 | } 184 | 185 | void ProductQuantizer::compute_codes(const real* x, uint8_t* codes, 186 | int32_t n) const { 187 | for (auto i = 0; i < n; i++) { 188 | compute_code(x + i * dim_, codes + i * nsubq_); 189 | } 190 | } 191 | 192 | void ProductQuantizer::save(std::ostream& out) { 193 | out.write((char*) &dim_, sizeof(dim_)); 194 | out.write((char*) &nsubq_, sizeof(nsubq_)); 195 | out.write((char*) &dsub_, sizeof(dsub_)); 196 | out.write((char*) &lastdsub_, sizeof(lastdsub_)); 197 | out.write((char*) centroids_.data(), centroids_.size() * sizeof(real)); 198 | } 199 | 200 | void ProductQuantizer::load(std::istream& in) { 201 | in.read((char*) &dim_, sizeof(dim_)); 202 | in.read((char*) &nsubq_, sizeof(nsubq_)); 203 | in.read((char*) &dsub_, sizeof(dsub_)); 204 | in.read((char*) 
&lastdsub_, sizeof(lastdsub_)); 205 | centroids_.resize(dim_ * ksub_); 206 | for (auto i=0; i < centroids_.size(); i++) { 207 | in.read((char*) ¢roids_[i], sizeof(real)); 208 | } 209 | } 210 | 211 | } 212 | -------------------------------------------------------------------------------- /src/productquantizer.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. An additional grant 7 | * of patent rights can be found in the PATENTS file in the same directory. 8 | */ 9 | 10 | #ifndef FASTTEXT_PRODUCT_QUANTIZER_H 11 | #define FASTTEXT_PRODUCT_QUANTIZER_H 12 | 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | #include "real.h" 20 | #include "vector.h" 21 | 22 | namespace fasttext { 23 | 24 | class ProductQuantizer { 25 | private: 26 | const int32_t nbits_ = 8; 27 | const int32_t ksub_ = 1 << nbits_; 28 | const int32_t max_points_per_cluster_ = 256; 29 | const int32_t max_points_ = max_points_per_cluster_ * ksub_; 30 | const int32_t seed_ = 1234; 31 | const int32_t niter_ = 25; 32 | const real eps_ = 1e-7; 33 | 34 | int32_t dim_; 35 | int32_t nsubq_; 36 | int32_t dsub_; 37 | int32_t lastdsub_; 38 | 39 | std::vector centroids_; 40 | 41 | std::minstd_rand rng; 42 | 43 | public: 44 | ProductQuantizer() {} 45 | ProductQuantizer(int32_t, int32_t); 46 | 47 | real* get_centroids (int32_t, uint8_t); 48 | const real* get_centroids(int32_t, uint8_t) const; 49 | 50 | real assign_centroid(const real*, const real*, uint8_t*, int32_t) const; 51 | void Estep(const real*, const real*, uint8_t*, int32_t, int32_t) const; 52 | void MStep(const real*, real*, const uint8_t*, int32_t, int32_t); 53 | void kmeans(const real*, real*, int32_t, int32_t); 54 | void train(int, const real*); 55 | 56 | real 
mulcode(const Vector&, const uint8_t*, int32_t, real) const; 57 | void addcode(Vector&, const uint8_t*, int32_t, real) const; 58 | void compute_code(const real*, uint8_t*) const; 59 | void compute_codes(const real*, uint8_t*, int32_t) const; 60 | 61 | void save(std::ostream&); 62 | void load(std::istream&); 63 | }; 64 | 65 | } 66 | 67 | #endif 68 | -------------------------------------------------------------------------------- /src/qmatrix.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. An additional grant 7 | * of patent rights can be found in the PATENTS file in the same directory. 8 | */ 9 | 10 | #include "qmatrix.h" 11 | 12 | #include 13 | #include 14 | #include 15 | 16 | namespace fasttext { 17 | 18 | QMatrix::QMatrix() : qnorm_(false), 19 | m_(0), n_(0), codesize_(0) {} 20 | 21 | QMatrix::QMatrix(const Matrix& mat, int32_t dsub, bool qnorm) 22 | : qnorm_(qnorm), m_(mat.m_), n_(mat.n_), 23 | codesize_(m_ * std::ceil(n_ / dsub)) { 24 | codes_ = new uint8_t[codesize_]; 25 | pq_ = std::unique_ptr( new ProductQuantizer(n_, dsub)); 26 | if (qnorm_) { 27 | norm_codes_ = new uint8_t[m_]; 28 | npq_ = std::unique_ptr( new ProductQuantizer(1, 1)); 29 | } 30 | quantize(mat); 31 | } 32 | 33 | QMatrix::~QMatrix() { 34 | if (codesize_) { delete[] codes_; } 35 | if (qnorm_) { delete[] norm_codes_; } 36 | } 37 | 38 | void QMatrix::quantizeNorm(const Vector& norms) { 39 | assert(qnorm_); 40 | assert(norms.m_ == m_); 41 | auto dataptr = norms.data_; 42 | npq_->train(m_, dataptr); 43 | npq_->compute_codes(dataptr, norm_codes_, m_); 44 | } 45 | 46 | void QMatrix::quantize(const Matrix& matrix) { 47 | assert(n_ == matrix.n_); 48 | assert(m_ == matrix.m_); 49 | Matrix temp(matrix); 50 | if (qnorm_) { 51 | Vector 
norms(temp.m_); 52 | temp.l2NormRow(norms); 53 | temp.divideRow(norms); 54 | quantizeNorm(norms); 55 | } 56 | auto dataptr = temp.data_; 57 | pq_->train(m_, dataptr); 58 | pq_->compute_codes(dataptr, codes_, m_); 59 | } 60 | 61 | void QMatrix::addToVector(Vector& x, int32_t t) const { 62 | real norm = 1; 63 | if (qnorm_) { 64 | norm = npq_->get_centroids(0, norm_codes_[t])[0]; 65 | } 66 | pq_->addcode(x, codes_, t, norm); 67 | } 68 | 69 | real QMatrix::dotRow(const Vector& vec, int64_t i) const { 70 | assert(i >= 0); 71 | assert(i < m_); 72 | assert(vec.size() == n_); 73 | real norm = 1; 74 | if (qnorm_) { 75 | norm = npq_->get_centroids(0, norm_codes_[i])[0]; 76 | } 77 | return pq_->mulcode(vec, codes_, i, norm); 78 | } 79 | 80 | int64_t QMatrix::getM() const { 81 | return m_; 82 | } 83 | 84 | int64_t QMatrix::getN() const { 85 | return n_; 86 | } 87 | 88 | void QMatrix::save(std::ostream& out) { 89 | out.write((char*) &qnorm_, sizeof(qnorm_)); 90 | out.write((char*) &m_, sizeof(m_)); 91 | out.write((char*) &n_, sizeof(n_)); 92 | out.write((char*) &codesize_, sizeof(codesize_)); 93 | out.write((char*) codes_, codesize_ * sizeof(uint8_t)); 94 | pq_->save(out); 95 | if (qnorm_) { 96 | out.write((char*) norm_codes_, m_ * sizeof(uint8_t)); 97 | npq_->save(out); 98 | } 99 | } 100 | 101 | void QMatrix::load(std::istream& in) { 102 | in.read((char*) &qnorm_, sizeof(qnorm_)); 103 | in.read((char*) &m_, sizeof(m_)); 104 | in.read((char*) &n_, sizeof(n_)); 105 | in.read((char*) &codesize_, sizeof(codesize_)); 106 | codes_ = new uint8_t[codesize_]; 107 | in.read((char*) codes_, codesize_ * sizeof(uint8_t)); 108 | pq_ = std::unique_ptr( new ProductQuantizer()); 109 | pq_->load(in); 110 | if (qnorm_) { 111 | norm_codes_ = new uint8_t[m_]; 112 | in.read((char*) norm_codes_, m_ * sizeof(uint8_t)); 113 | npq_ = std::unique_ptr( new ProductQuantizer()); 114 | npq_->load(in); 115 | } 116 | } 117 | 118 | } 119 | 
-------------------------------------------------------------------------------- /src/qmatrix.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. An additional grant 7 | * of patent rights can be found in the PATENTS file in the same directory. 8 | */ 9 | 10 | #ifndef FASTTEXT_QMATRIX_H 11 | #define FASTTEXT_QMATRIX_H 12 | 13 | #include 14 | #include 15 | #include 16 | 17 | #include 18 | #include 19 | 20 | #include "real.h" 21 | 22 | #include "matrix.h" 23 | #include "vector.h" 24 | 25 | #include "productquantizer.h" 26 | 27 | namespace fasttext { 28 | 29 | class QMatrix { 30 | private: 31 | std::unique_ptr pq_; 32 | std::unique_ptr npq_; 33 | 34 | uint8_t* codes_; 35 | uint8_t* norm_codes_; 36 | 37 | bool qnorm_; 38 | 39 | int64_t m_; 40 | int64_t n_; 41 | 42 | int32_t codesize_; 43 | 44 | public: 45 | 46 | QMatrix(); 47 | QMatrix(const Matrix&, int32_t, bool); 48 | ~QMatrix(); 49 | 50 | int64_t getM() const; 51 | int64_t getN() const; 52 | 53 | void quantizeNorm(const Vector&); 54 | void quantize(const Matrix&); 55 | 56 | void addToVector(Vector& x, int32_t t) const; 57 | real dotRow(const Vector&, int64_t) const; 58 | 59 | void save(std::ostream&); 60 | void load(std::istream&); 61 | }; 62 | 63 | } 64 | 65 | #endif 66 | -------------------------------------------------------------------------------- /src/real.cc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfml/Bi-Sent2Vec/b3bf758f86a7006112a7a65bf9955d88ef3465d7/src/real.cc -------------------------------------------------------------------------------- /src/real.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 
2016-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree. An additional grant
 * of patent rights can be found in the PATENTS file in the same directory.
 */

#ifndef FASTTEXT_REAL_H
#define FASTTEXT_REAL_H

namespace fasttext {

// Single floating-point type used throughout the library.
typedef float real;

}

#endif
--------------------------------------------------------------------------------
/src/sent2vec.pyx:
--------------------------------------------------------------------------------
import os.path
import subprocess
from collections import OrderedDict

import numpy as np
cimport numpy as cnp

from libcpp cimport bool
from libcpp.string cimport string
from libcpp.vector cimport vector
from libc.stdint cimport int64_t

#from libc.stdlib cimport free
#from cpython cimport PyObject, Py_INCREF

cnp.import_array()

cdef extern from "fasttext.h" namespace "fasttext":

    cdef cppclass FastText:
        FastText() except +
        void loadModel(const string&, bool, int)
        void textVector(string, vector[float]&)
        void textVectors(vector[string]&, int, vector[float])#&)
        int getDimension()
        vector[string]& getVocab()
        vector[int64_t]& getUnigramsCounts()


cdef extern from "asvoid.h":
    void *asvoid(vector[float] *buf)


class stdvector_base:
    # Bare Python object used only as the base holding __array_interface__
    # so numpy can wrap the C++ vector's buffer without copying.
    pass


cdef class vector_wrapper:
    # Owns a heap-allocated C++ std::vector[float] and exposes its buffer
    # to numpy zero-copy via asarray().
    cdef:
        vector[float] *buf

    def __cinit__(vector_wrapper self, n):
        self.buf = NULL

    def __init__(vector_wrapper self, cnp.intp_t n):
        self.buf = new vector[float](n)

    def __dealloc__(vector_wrapper self):
        if self.buf != NULL:
            del self.buf

    def asarray(vector_wrapper self, cnp.intp_t n):
        """
        Interpret the vector as np.ndarray without
        copying the data.
        """
        base = stdvector_base()
        intbuf = asvoid(self.buf)
        dtype = np.dtype(np.float32)
        base.__array_interface__ = dict(
            data = (intbuf, False),
            descr = dtype.descr,
            shape = (n,),
            strides = (dtype.itemsize,),
            typestr = dtype.str,
            version = 3,
        )
        # Keep the wrapper alive as long as the array references base.
        base.vector_wrapper = self
        return np.asarray(base)


cdef class Sent2vecModel:
    # Thin Python-facing wrapper around the C++ FastText object.

    cdef FastText* _thisptr

    def __cinit__(self):
        self._thisptr = new FastText()

    def __dealloc__(self):
        del self._thisptr

    def __init__(self):
        pass

    def get_emb_size(self):
        """Return the embedding dimension of the loaded model."""
        return self._thisptr.getDimension()

    def load_model(self, model_path, inference_mode=False, timeout_sec=-1):
        """Load a .bin model; timeout_sec bounds the wait for the shared
        input matrix (-1 means wait forever)."""
        cdef string cmodel_path = model_path.encode('utf-8', 'ignore');
        cdef bool cinference_mode = inference_mode
        cdef int ctimeout_sec = timeout_sec
        self._thisptr.loadModel(cmodel_path, cinference_mode, ctimeout_sec)

    def embed_sentences(self, sentences, num_threads=1):
        """Embed an iterable of sentences; returns (len(sentences), dim)."""
        if num_threads <= 0:
            num_threads = 1
        cdef vector[string] csentences
        cdef int cnum_threads = num_threads
        for s in sentences:
            csentences.push_back(s.encode('utf-8', 'ignore'));
        # FIX: declare the wrapper with its cdef type (the original declared
        # an unused `cdef vector_wrapper array` and left `w` untyped), so
        # the C-level attribute access `w.buf` is statically resolved.
        cdef vector_wrapper w
        w = vector_wrapper(len(sentences) * self.get_emb_size())
        self._thisptr.textVectors(csentences, cnum_threads, w.buf[0])
        final = w.asarray(len(sentences) * self.get_emb_size())
        return final.reshape(len(sentences), self.get_emb_size())

    def embed_sentence(self, sentence, num_threads=1):
        """Embed a single sentence; returns a (1, dim) array."""
        return self.embed_sentences([sentence], num_threads)

    def get_vocabulary(self):
        """Return an OrderedDict mapping vocabulary word -> unigram count."""
        vocab = list(self._thisptr.getVocab())
        vocab = [w.decode('utf-8', 'ignore') for w in vocab]
        freqs = list(self._thisptr.getUnigramsCounts())
        assert len(vocab) == len(freqs)
        return OrderedDict(zip(vocab, freqs))

    def get_unigram_embeddings(self):
        """Embed every vocabulary word; returns (embeddings, vocab list)."""
        vocab = [w for w, c in self.get_vocabulary().items()]
        return self.embed_sentences(vocab), vocab

    def embed_unigrams(self, unigrams):
        """Embed single words only (each entry must contain no spaces)."""
        assert all(len(w.split(' ')) == 1 for w in unigrams)
        return self.embed_sentences(unigrams)

    @staticmethod
    def release_shared_mem(model_path):
        """Unlink the /dev/shm segments created for this model's input
        matrix (see the note in shmem_matrix.cc)."""
        model_basename = os.path.splitext(os.path.basename(model_path))[0]
        shm_path = ''.join(['/dev/shm/', 's2v_', model_basename, '_input_matrix'])
        subprocess.run(f'unlink {shm_path}.init', shell=True)
        subprocess.run(f'unlink {shm_path}', shell=True)
--------------------------------------------------------------------------------
/src/shmem_matrix.cc:
--------------------------------------------------------------------------------
/**
 * Copyright (c) 2016-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree. An additional grant
 * of patent rights can be found in the PATENTS file in the same directory.
 */

// NOTE(review): the seven include targets below were stripped by the text
// extraction; restored from what this file uses (errno, shm_open/O_* flags,
// perror/fprintf, exit, mmap, mode bits, close/sleep/ftruncate/link).
#include <errno.h>
#include <fcntl.h>
#include <stdio.h>
#include <stdlib.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>

#include "shmem_matrix.h"

namespace fasttext {

/* Note: The shared memory segment is created when loading the input matrix for
 * the first time. It is not unlinked by the code, because it is hard to do
 * this from the code. It is safe to unlink the segment once every interested
 * process has opened it, but it is impossible to know this without some kind
 * of interprocess synchronization. Therefore, it is left to the user to unlink
 * it at some point when it is no longer needed.
28 | **/ 29 | 30 | ShmemMatrix::ShmemMatrix(const char* name, const int64_t m, const int64_t n, const int timeout_sec) { 31 | m_ = m; 32 | n_ = n; 33 | 34 | // Open an existing shared memory segment (keep retrying until timeout expires) 35 | int fd = -1; 36 | int waited_sec = 0; 37 | while (true) { 38 | fd = shm_open(name, O_RDONLY, 0444); 39 | if (fd != -1) break; 40 | if (errno == ENOENT) { 41 | if (timeout_sec != -1 && waited_sec >= timeout_sec) { 42 | fprintf(stderr, "ERROR ShmemMatrix::ShmemMatrix: timeout expired\n"); 43 | exit(-1); 44 | } else { 45 | sleep(1); 46 | waited_sec += 1; 47 | } 48 | } else { 49 | perror("ERROR ShmemMatrix::ShmemMatrix: shm_open failed"); 50 | exit(-1); 51 | } 52 | } 53 | 54 | // Map the shared memory segment 55 | size_t size = m_ * n_ * sizeof(real); 56 | void* ptr = mmap(NULL, size, PROT_READ, MAP_SHARED, fd, 0); 57 | if (ptr == (void*)-1) { 58 | perror("ERROR ShmemMatrix::ShmemMatrix: mmap failed"); 59 | exit(-1); 60 | } else { 61 | data_ = (real*)ptr; 62 | } 63 | 64 | // Close the file descriptor 65 | int ret = close(fd); 66 | if (ret == -1) { 67 | perror("ERROR ShmemMatrix::ShmemMatrix: close failed"); 68 | exit(-1); 69 | } 70 | } 71 | 72 | ShmemMatrix::~ShmemMatrix() { 73 | // Unmap the shared memory segment 74 | size_t size = m_ * n_ * sizeof(real); 75 | int ret = munmap((void*)data_, size); 76 | if (ret == -1) { 77 | perror("ERROR ShmemMatrix::~ShmemMatrix: munmap failed"); 78 | exit(-1); 79 | } 80 | data_ = nullptr; 81 | } 82 | 83 | std::shared_ptr ShmemMatrix::load(std::istream& in, 84 | const std::string& name, 85 | const int timeout_sec) { 86 | std::string init_name = name + ".init"; 87 | 88 | int64_t m, n; 89 | in.read((char*)&m, sizeof(int64_t)); 90 | in.read((char*)&n, sizeof(int64_t)); 91 | size_t size = m * n * sizeof(real); 92 | 93 | // Create a shared memory segment to be initialized with the input matrix 94 | bool new_segment = true; 95 | int fd = shm_open(init_name.c_str(), O_RDWR | O_CREAT | O_EXCL, 0444); 96 
| if (fd == -1) { 97 | if (errno == EEXIST) { 98 | new_segment = false; 99 | } else { 100 | perror("ERROR ShmemMatrix::load: shm_open failed"); 101 | exit(-1); 102 | } 103 | } 104 | 105 | if (new_segment) { 106 | // Set the size for shared memory segment 107 | int ret = ftruncate(fd, size); 108 | if (ret == -1) { 109 | perror("ERROR ShmemMatrix::load: ftruncate failed"); 110 | exit(-1); 111 | } 112 | 113 | // Map the shared memory segment 114 | void* ptr = mmap(NULL, size, PROT_WRITE, MAP_SHARED, fd, 0); 115 | if (ptr == (void*)-1) { 116 | perror("ERROR ShmemMatrix::load: mmap failed"); 117 | exit(-1); 118 | } 119 | 120 | // Close the file descriptor 121 | ret = close(fd); 122 | if (ret == -1) { 123 | perror("ERROR ShmemMatrix::load: close failed"); 124 | exit(-1); 125 | } 126 | 127 | // Populate the shared memory segment 128 | in.read((char*)ptr, size); 129 | 130 | // Unmap the shared memory segment 131 | ret = munmap(ptr, size); 132 | if (ret == -1) { 133 | perror("ERROR ShmemMatrix::load: munmap failed"); 134 | exit(-1); 135 | } 136 | 137 | // Atomically link to the expected name 138 | std::string init_name_path = "/dev/shm/" + init_name; 139 | std::string name_path = "/dev/shm/" + name; 140 | ret = link(init_name_path.c_str(), name_path.c_str()); 141 | if (ret == -1) { 142 | perror("ERROR ShmemMatrix::load: link failed"); 143 | exit(-1); 144 | } 145 | } else { 146 | // Seek in the stream to skip the input matrix data 147 | in.seekg(size, in.cur); 148 | } 149 | 150 | return std::make_shared(name.c_str(), m, n, timeout_sec); 151 | } 152 | 153 | } 154 | -------------------------------------------------------------------------------- /src/shmem_matrix.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 
An additional grant 7 | * of patent rights can be found in the PATENTS file in the same directory. 8 | */ 9 | 10 | #ifndef FASTTEXT_SHMEM_MATRIX_H 11 | #define FASTTEXT_SHMEM_MATRIX_H 12 | 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | #include "matrix.h" 20 | #include "real.h" 21 | 22 | namespace fasttext { 23 | 24 | class ShmemMatrix : public Matrix { 25 | public: 26 | ShmemMatrix(const char*, const int64_t, const int64_t, const int); 27 | ~ShmemMatrix(); 28 | 29 | Matrix& operator=(const Matrix&) = delete; 30 | void save(std::ostream&) = delete; 31 | void load(std::istream&) = delete; 32 | 33 | static std::shared_ptr load(std::istream&, const std::string&, const int); 34 | }; 35 | 36 | } 37 | 38 | #endif 39 | -------------------------------------------------------------------------------- /src/utils.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. An additional grant 7 | * of patent rights can be found in the PATENTS file in the same directory. 8 | */ 9 | 10 | #include "utils.h" 11 | 12 | #include 13 | 14 | namespace fasttext { 15 | 16 | namespace utils { 17 | 18 | int64_t size(std::ifstream& ifs) { 19 | ifs.seekg(std::streamoff(0), std::ios::end); 20 | return ifs.tellg(); 21 | } 22 | 23 | void seek(std::ifstream& ifs, int64_t pos) { 24 | ifs.clear(); 25 | ifs.seekg(std::streampos(pos)); 26 | } 27 | } 28 | 29 | } 30 | -------------------------------------------------------------------------------- /src/utils.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016-present, Facebook, Inc. 3 | * All rights reserved. 
4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. An additional grant 7 | * of patent rights can be found in the PATENTS file in the same directory. 8 | */ 9 | 10 | #ifndef FASTTEXT_UTILS_H 11 | #define FASTTEXT_UTILS_H 12 | 13 | #include 14 | 15 | namespace fasttext { 16 | 17 | namespace utils { 18 | 19 | int64_t size(std::ifstream&); 20 | void seek(std::ifstream&, int64_t); 21 | } 22 | 23 | } 24 | 25 | #endif 26 | -------------------------------------------------------------------------------- /src/vector.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. An additional grant 7 | * of patent rights can be found in the PATENTS file in the same directory. 
8 | */ 9 | 10 | #include "vector.h" 11 | 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | #include "matrix.h" 18 | #include "qmatrix.h" 19 | 20 | namespace fasttext { 21 | 22 | Vector::Vector(int64_t m) { 23 | m_ = m; 24 | data_ = new real[m]; 25 | } 26 | 27 | Vector::~Vector() { 28 | delete[] data_; 29 | } 30 | 31 | int64_t Vector::size() const { 32 | return m_; 33 | } 34 | 35 | void Vector::zero() { 36 | for (int64_t i = 0; i < m_; i++) { 37 | data_[i] = 0.0; 38 | } 39 | } 40 | 41 | real Vector::norm() const { 42 | real sum = 0; 43 | for (int64_t i = 0; i < m_; i++) { 44 | sum += data_[i] * data_[i]; 45 | } 46 | return std::sqrt(sum); 47 | } 48 | 49 | void Vector::mul(real a) { 50 | for (int64_t i = 0; i < m_; i++) { 51 | data_[i] *= a; 52 | } 53 | } 54 | 55 | void Vector::addVector(const Vector& source) { 56 | assert(m_ == source.m_); 57 | for (int64_t i = 0; i < m_; i++) { 58 | data_[i] += source.data_[i]; 59 | } 60 | } 61 | 62 | void Vector::addVector(const Vector& source, real s) { 63 | assert(m_ == source.m_); 64 | for (int64_t i = 0; i < m_; i++) { 65 | data_[i] += s * source.data_[i]; 66 | } 67 | } 68 | 69 | void Vector::addRow(const Matrix& A, int64_t i) { 70 | assert(i >= 0); 71 | assert(i < A.m_); 72 | assert(m_ == A.n_); 73 | for (int64_t j = 0; j < A.n_; j++) { 74 | data_[j] += A.at(i, j); 75 | } 76 | } 77 | 78 | void Vector::addRow(const Matrix& A, int64_t i, real a) { 79 | assert(i >= 0); 80 | assert(i < A.m_); 81 | assert(m_ == A.n_); 82 | for (int64_t j = 0; j < A.n_; j++) { 83 | data_[j] += a * A.at(i, j); 84 | } 85 | } 86 | 87 | void Vector::addRow(const QMatrix& A, int64_t i) { 88 | assert(i >= 0); 89 | A.addToVector(*this, i); 90 | } 91 | 92 | void Vector::mul(const Matrix& A, const Vector& vec) { 93 | assert(A.m_ == m_); 94 | assert(A.n_ == vec.m_); 95 | for (int64_t i = 0; i < m_; i++) { 96 | data_[i] = A.dotRow(vec, i); 97 | } 98 | } 99 | 100 | void Vector::mul(const QMatrix& A, const Vector& vec) { 101 | assert(A.getM() == m_); 102 
| assert(A.getN() == vec.m_); 103 | for (int64_t i = 0; i < m_; i++) { 104 | data_[i] = A.dotRow(vec, i); 105 | } 106 | } 107 | 108 | int64_t Vector::argmax() { 109 | real max = data_[0]; 110 | int64_t argmax = 0; 111 | for (int64_t i = 1; i < m_; i++) { 112 | if (data_[i] > max) { 113 | max = data_[i]; 114 | argmax = i; 115 | } 116 | } 117 | return argmax; 118 | } 119 | 120 | real& Vector::operator[](int64_t i) { 121 | return data_[i]; 122 | } 123 | 124 | const real& Vector::operator[](int64_t i) const { 125 | return data_[i]; 126 | } 127 | 128 | std::ostream& operator<<(std::ostream& os, const Vector& v) 129 | { 130 | os << std::setprecision(5); 131 | for (int64_t j = 0; j < v.m_; j++) { 132 | os << v.data_[j] << ' '; 133 | } 134 | return os; 135 | } 136 | 137 | } 138 | -------------------------------------------------------------------------------- /src/vector.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. An additional grant 7 | * of patent rights can be found in the PATENTS file in the same directory. 
8 | */ 9 | 10 | #ifndef FASTTEXT_VECTOR_H 11 | #define FASTTEXT_VECTOR_H 12 | 13 | #include 14 | #include 15 | 16 | #include "real.h" 17 | 18 | namespace fasttext { 19 | 20 | class Matrix; 21 | class QMatrix; 22 | 23 | class Vector { 24 | 25 | public: 26 | int64_t m_; 27 | real* data_; 28 | 29 | explicit Vector(int64_t); 30 | ~Vector(); 31 | 32 | real& operator[](int64_t); 33 | const real& operator[](int64_t) const; 34 | 35 | int64_t size() const; 36 | void zero(); 37 | void mul(real); 38 | real norm() const; 39 | void addVector(const Vector& source); 40 | void addVector(const Vector&, real); 41 | void addRow(const Matrix&, int64_t); 42 | void addRow(const QMatrix&, int64_t); 43 | void addRow(const Matrix&, int64_t, real); 44 | void mul(const QMatrix&, const Vector&); 45 | void mul(const Matrix&, const Vector&); 46 | int64_t argmax(); 47 | }; 48 | 49 | std::ostream& operator<<(std::ostream&, const Vector&); 50 | 51 | } 52 | 53 | #endif 54 | -------------------------------------------------------------------------------- /vectors_by_lang.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import codecs 3 | import sys, getopt 4 | import csv 5 | import os 6 | import gzip, pickle 7 | import sys 8 | from scipy import stats 9 | import argparse 10 | 11 | def divide_vectors(vec_file): 12 | file_contents=open(vec_file,'r',encoding='utf-8-sig') 13 | vectors = list() 14 | id2word = dict() 15 | num_vec = dict() 16 | word_count = 0 17 | for i,line in enumerate(file_contents): 18 | entries = line.split(" ",1) 19 | if i==0: 20 | num_words = int(entries[0]) 21 | dim = int(entries[1]) 22 | else: 23 | if "_" not in entries[0]: 24 | continue 25 | word = entries[0][:-3] 26 | lang = entries[0][-2:] 27 | id2word[len(vectors)] = (lang,word) 28 | if lang not in num_vec: 29 | num_vec[lang]=0 30 | num_vec[lang]+=1 31 | vectors.append(entries[1]) 32 | if i%100000==0: 33 | print(str(i) + " words loaded") 34 | file_contents.close() 35 
| 36 | file_contents_output = dict() 37 | print("Writing vectors") 38 | 39 | for lang in num_vec: 40 | file_contents_output[lang] = open(vec_file[:-4] + "_" + lang + ".vec",'w',encoding='utf-8-sig') 41 | file_contents_output[lang].write(str(num_vec[lang]) + " " + str(dim) + "\n") 42 | 43 | for i,vector in enumerate(vectors): 44 | file_contents_output[id2word[i][0]].write(id2word[i][1] + " " + vector) 45 | 46 | for lang in num_vec: 47 | file_contents_output[lang].close() 48 | 49 | return 50 | 51 | if __name__ == "__main__": 52 | 53 | parser = argparse.ArgumentParser(description='Parameters for vector separation by language') 54 | 55 | parser.add_argument('--vector_file', action='store', type=str, 56 | help='vector file location') 57 | 58 | args = parser.parse_args() 59 | 60 | divide_vectors(args.vector_file) 61 | --------------------------------------------------------------------------------