├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── Makefile ├── PATENTS ├── README.md ├── eval.py ├── src ├── args.cc ├── args.h ├── dictionary.cc ├── dictionary.h ├── fasttext.cc ├── matrix.cc ├── matrix.h ├── model.cc ├── model.h ├── real.h ├── utils.cc ├── utils.h ├── vector.cc └── vector.h └── wikifil.pl /.gitignore: -------------------------------------------------------------------------------- 1 | .*.swp 2 | *.o 3 | *.bin 4 | *.vec 5 | data 6 | fasttext 7 | result 8 | .DS_Store -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | 2 | # Contributing to fastText 3 | We want to make contributing to this project as easy and transparent as 4 | possible. 5 | 6 | ## Pull Requests 7 | We actively welcome your pull requests. 8 | 9 | 1. Fork the repo and create your branch from `master`. 10 | 2. If you've added code that should be tested, add tests. 11 | 3. If you've changed APIs, update the documentation. 12 | 4. Ensure the test suite passes. 13 | 5. Make sure your code lints. 14 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 15 | 16 | ## Contributor License Agreement ("CLA") 17 | In order to accept your pull request, we need you to submit a CLA. You only need 18 | to do this once to work on any of Facebook's open source projects. 19 | 20 | Complete your CLA here: 21 | 22 | ## Issues 23 | We use GitHub issues to track public bugs. Please ensure your description is 24 | clear and has sufficient instructions to be able to reproduce the issue. 25 | 26 | ## License 27 | By contributing to fastText, you agree that your contributions will be licensed 28 | under its BSD license. 
29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD License 2 | 3 | For fastText software 4 | 5 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved. 6 | 7 | Redistribution and use in source and binary forms, with or without modification, 8 | are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | * Neither the name Facebook nor the names of its contributors may be used to 18 | endorse or promote products derived from this software without specific 19 | prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
31 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2016-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. An additional grant 7 | # of patent rights can be found in the PATENTS file in the same directory. 8 | # 9 | 10 | CXX = c++ 11 | CXXFLAGS = -pthread -std=c++0x 12 | OBJS = args.o dictionary.o matrix.o vector.o model.o utils.o 13 | INCLUDES = -I. 14 | 15 | opt: CXXFLAGS += -O3 -funroll-loops 16 | opt: fasttext 17 | 18 | debug: CXXFLAGS += -g -O0 -fno-inline 19 | debug: fasttext 20 | 21 | args.o: src/args.cc src/args.h 22 | $(CXX) $(CXXFLAGS) -c src/args.cc 23 | 24 | dictionary.o: src/dictionary.cc src/dictionary.h src/args.h 25 | $(CXX) $(CXXFLAGS) -c src/dictionary.cc 26 | 27 | matrix.o: src/matrix.cc src/matrix.h src/utils.h 28 | $(CXX) $(CXXFLAGS) -c src/matrix.cc 29 | 30 | vector.o: src/vector.cc src/vector.h src/utils.h 31 | $(CXX) $(CXXFLAGS) -c src/vector.cc 32 | 33 | model.o: src/model.cc src/model.h src/args.h 34 | $(CXX) $(CXXFLAGS) -c src/model.cc 35 | 36 | utils.o: src/utils.cc src/utils.h 37 | $(CXX) $(CXXFLAGS) -c src/utils.cc 38 | 39 | fasttext : $(OBJS) src/fasttext.cc 40 | $(CXX) $(CXXFLAGS) $(OBJS) src/fasttext.cc -o fasttext 41 | 42 | clean: 43 | rm -rf *.o fasttext 44 | -------------------------------------------------------------------------------- /PATENTS: -------------------------------------------------------------------------------- 1 | Additional Grant of Patent Rights Version 2 2 | 3 | "Software" means the fastText software distributed by Facebook, Inc. 4 | 5 | Facebook, Inc. 
("Facebook") hereby grants to each recipient of the Software 6 | ("you") a perpetual, worldwide, royalty-free, non-exclusive, irrevocable 7 | (subject to the termination provision below) license under any Necessary 8 | Claims, to make, have made, use, sell, offer to sell, import, and otherwise 9 | transfer the Software. For avoidance of doubt, no license is granted under 10 | Facebook’s rights in any patent claims that are infringed by (i) modifications 11 | to the Software made by you or any third party or (ii) the Software in 12 | combination with any software or other technology. 13 | 14 | The license granted hereunder will terminate, automatically and without notice, 15 | if you (or any of your subsidiaries, corporate affiliates or agents) initiate 16 | directly or indirectly, or take a direct financial interest in, any Patent 17 | Assertion: (i) against Facebook or any of its subsidiaries or corporate 18 | affiliates, (ii) against any party if such Patent Assertion arises in whole or 19 | in part from any software, technology, product or service of Facebook or any of 20 | its subsidiaries or corporate affiliates, or (iii) against any party relating 21 | to the Software. Notwithstanding the foregoing, if Facebook or any of its 22 | subsidiaries or corporate affiliates files a lawsuit alleging patent 23 | infringement against you in the first instance, and you respond by filing a 24 | patent infringement counterclaim in that lawsuit against that party that is 25 | unrelated to the Software, the license granted hereunder will not terminate 26 | under section (i) of this paragraph due to such counterclaim. 27 | 28 | A "Necessary Claim" is a claim of a patent owned by Facebook that is 29 | necessarily infringed by the Software standing alone. 30 | 31 | A "Patent Assertion" is any lawsuit or other action alleging direct, indirect, 32 | or contributory infringement or inducement to infringe any patent, including a 33 | cross-claim or counterclaim. 
34 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # fastText_doc2vec 2 | 3 | fastText_doc2vec is an extension of the [Facebook fastText](https://github.com/facebookresearch/fastText) for document embedding. 4 | 5 | ## Requirements 6 | 7 | **fastText_doc2vec** builds on similar environment of [Facebook fastText](https://github.com/facebookresearch/fastText#requirements). 8 | 9 | ## Building fastText 10 | 11 | In order to build `fastText_doc2vec`, use the following: 12 | 13 | ``` 14 | $ git clone https://github.com/Skarface-/fastText_doc2vec.git 15 | $ cd fastText_doc2vec 16 | $ make 17 | ``` 18 | 19 | This will produce object files for all the classes as well as the main binary `fasttext`.
If you do not plan on using the default system-wide compiler, update the two macros defined at the beginning of the Makefile (CXX and INCLUDES). 20 | 21 | ## Example use cases 22 | 23 | This library has two use cases: document embedding using the PV-DM and PV-DBOW models.
24 | These were described in the paper [1](#distributed-representations-of-sentences-and-documents). 25 | 26 | ### Document Embedding 27 | 28 | In order to embed document in vector space, as described in [1](#distributed-representations-of-sentences-and-documents), do: 29 | 30 | ``` 31 | $ ./fasttext pvdm -model model.bin -input data.txt -output docvecs 32 | or 33 | $ ./fasttext pvdbow -model model.bin -input data.txt -output docvecs 34 | ``` 35 | 36 | where `model.bin` is a previously trained model using [fasttext word representation learning](https://github.com/facebookresearch/fastText#word-representation-learning).
37 | As such, most options are inherited from [fasttext word representation learning](https://github.com/facebookresearch/fastText#word-representation-learning), except for epoch, thread, [etc](#full-documentation).
38 | `data.txt` is a file containing `utf-8` encoded labeled documents, one per line, in the format (\_\_label\_\_\<label\>, \<document\>)
39 | At the end of document embedding, the program will save a single file: `docvecs.vec`.
40 | `docvecs.vec` is a text file containing the labeled document vectors, one per line. 41 | 42 | ## Full documentation 43 | 44 | Invoke a command without arguments to list available arguments and their default values: 45 | 46 | ``` 47 | $ ./fasttext pvdm 48 | Empty model or input or output path. 49 | 50 | The following arguments are mandatory: 51 | -model (mandatory, only pvdm or pvdbow) model.bin file path for document embedding 52 | -input training file path 53 | -output output file path 54 | 55 | The following arguments are optional: 56 | -lr learning rate [0.05] 57 | -epoch number of epochs [5] 58 | -thread number of threads [12] 59 | -verbose how often to print to stdout [10000] 60 | -label labels prefix [__label__] 61 | ``` 62 | 63 | ## References 64 | 65 | Please cite [1](#distributed-representations-of-sentences-and-documents) and [2](#enriching-word-vectors-with-subword-information). 66 | 67 | ### Distributed Representations of Sentences and Documents 68 | 69 | [1] Quoc V. Le, T. Mikolov, [*Distributed Representations Of Sentences And Documents*](https://arxiv.org/abs/1405.4053v2) 70 | 71 | ``` 72 | @article{quoc2014distributed, 73 | title={Distributed Representations of Sentences and Documents}, 74 | author={Le, Quoc V. and Mikolov, Tomas}, 75 | journal={arXiv preprint arXiv:1405.4053v2}, 76 | year={2014} 77 | } 78 | ``` 79 | 80 | ### Enriching Word Vectors with Subword Information 81 | 82 | [2] P. Bojanowski\*, E. Grave\*, A. Joulin, T. 
Mikolov, [*Enriching Word Vectors with Subword Information*](https://arxiv.org/pdf/1607.04606v1.pdf) 83 | 84 | ``` 85 | @article{bojanowski2016enriching, 86 | title={Enriching Word Vectors with Subword Information}, 87 | author={Bojanowski, Piotr and Grave, Edouard and Joulin, Armand and Mikolov, Tomas}, 88 | journal={arXiv preprint arXiv:1607.04606}, 89 | year={2016} 90 | } 91 | ``` 92 | 93 | ## The fastText community 94 | 95 | * Facebook page: https://www.facebook.com/groups/1174547215919768 96 | * Google group: https://groups.google.com/forum/#!forum/fasttext-library 97 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # 4 | # Copyright (c) 2016-present, Facebook, Inc. 5 | # All rights reserved. 6 | # 7 | # This source code is licensed under the BSD-style license found in the 8 | # LICENSE file in the root directory of this source tree. An additional grant 9 | # of patent rights can be found in the PATENTS file in the same directory. 
#

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import argparse
import math
import os
import sys

import numpy as np
from scipy import stats


def compat_splitting(line):
    """Split a line into whitespace-separated tokens.

    Under Python 2 the file iterator yields byte strings, so decode
    them as UTF-8 first; under Python 3 lines are already text.
    """
    # sys.version_info is the robust way to branch on the major version
    # (string comparison on sys.version is fragile).
    if sys.version_info[0] >= 3:
        return line.split()
    return line.decode('utf8').split()


def similarity(v1, v2):
    """Return the cosine similarity of two 1-D vectors."""
    n1 = np.linalg.norm(v1)
    n2 = np.linalg.norm(v2)
    return np.dot(v1, v2) / n1 / n2


def load_vectors(path):
    """Read a word-vector text file into a {word: ndarray} dict.

    Each line is "<word> <v1> <v2> ...". Unparseable lines (e.g. a
    leading "count dim" header, or undecodable bytes) are silently
    skipped, and only the first occurrence of a word is kept.
    """
    vectors = {}
    with open(path, 'r') as fin:
        for line in fin:
            try:
                tab = compat_splitting(line)
                vec = np.array(tab[1:], dtype=float)
                word = tab[0]
                if word not in vectors:
                    vectors[word] = vec
            except (ValueError, UnicodeDecodeError):
                # Tolerate malformed rows, matching the original
                # best-effort behaviour.
                continue
    return vectors


def main():
    """Score a model on a word-similarity dataset (Spearman rho).

    Prints "<dataset>: <rho*100> (OOV: <dropped%>)" to stdout.
    """
    parser = argparse.ArgumentParser(
        description='Evaluate word vectors on a word-similarity dataset.')
    parser.add_argument('--model', '-m', dest='modelPath', action='store',
                        required=True, help='path to model')
    parser.add_argument('--data', '-d', dest='dataPath', action='store',
                        required=True, help='path to data')
    args = parser.parse_args()

    vectors = load_vectors(args.modelPath)

    mysim = []    # model similarities for pairs covered by the vocabulary
    gold = []     # human judgements for the same pairs
    drop = 0.0    # pairs skipped because a word is out of vocabulary
    nwords = 0.0  # total pairs read

    with open(args.dataPath, 'r') as fin:
        for line in fin:
            tline = compat_splitting(line)
            word1 = tline[0].lower()
            word2 = tline[1].lower()
            nwords = nwords + 1.0
            if (word1 in vectors) and (word2 in vectors):
                mysim.append(similarity(vectors[word1], vectors[word2]))
                gold.append(float(tline[2]))
            else:
                drop = drop + 1.0

    if nwords == 0:
        # Avoid a ZeroDivisionError on an empty dataset file.
        sys.exit('No word pairs found in ' + args.dataPath)

    corr = stats.spearmanr(mysim, gold)
    dataset = os.path.basename(args.dataPath)
    print("{0:20s}: {1:2.0f} (OOV: {2:2.0f}%)"
          .format(dataset, corr[0] * 100, math.ceil(drop / nwords * 100.0)))


if __name__ == '__main__':
    main()
-------------------------------------------------------------------------------- /src/args.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. An additional grant 7 | * of patent rights can be found in the PATENTS file in the same directory. 8 | */ 9 | 10 | #include "args.h" 11 | 12 | #include 13 | #include 14 | 15 | #include 16 | #include 17 | 18 | Args::Args() { 19 | lr = 0.05; 20 | dim = 100; 21 | ws = 5; 22 | epoch = 5; 23 | minCount = 5; 24 | neg = 5; 25 | wordNgrams = 1; 26 | loss = loss_name::ns; 27 | model = model_name::sg; 28 | bucket = 2000000; 29 | minn = 3; 30 | maxn = 6; 31 | thread = 12; 32 | lrUpdateRate = 100; 33 | t = 1e-4; 34 | label = "__label__"; 35 | } 36 | 37 | void Args::parseArgs(int argc, char** argv) { 38 | std::string command(argv[1]); 39 | if (command == "supervised") { 40 | model = model_name::sup; 41 | loss = loss_name::softmax; 42 | minCount = 1; 43 | } else if (command == "cbow") { 44 | model = model_name::cbow; 45 | } else if (command == "pvdm") { 46 | model = model_name::pvdm; 47 | } else if (command == "pvdbow") { 48 | model = model_name::pvdbow; 49 | } 50 | 51 | int ai = 2; 52 | while (ai < argc) { 53 | if (argv[ai][0] != '-') { 54 | std::cout << "Provided argument without a dash! Usage:" << std::endl; 55 | printHelp(); 56 | exit(EXIT_FAILURE); 57 | } 58 | if (strcmp(argv[ai], "-h") == 0) { 59 | std::cout << "Here is the help! 
Usage:" << std::endl; 60 | printHelp(); 61 | exit(EXIT_FAILURE); 62 | } else if (strcmp(argv[ai], "-input") == 0) { 63 | input = std::string(argv[ai + 1]); 64 | } else if (strcmp(argv[ai], "-model") == 0) { 65 | modelInput = std::string(argv[ai + 1]); 66 | } else if (strcmp(argv[ai], "-test") == 0) { 67 | test = std::string(argv[ai + 1]); 68 | } else if (strcmp(argv[ai], "-output") == 0) { 69 | output = std::string(argv[ai + 1]); 70 | } else if (strcmp(argv[ai], "-lr") == 0) { 71 | lr = atof(argv[ai + 1]); 72 | } else if (strcmp(argv[ai], "-lrUpdateRate") == 0) { 73 | lrUpdateRate = atoi(argv[ai + 1]); 74 | } else if (strcmp(argv[ai], "-dim") == 0) { 75 | dim = atoi(argv[ai + 1]); 76 | } else if (strcmp(argv[ai], "-ws") == 0) { 77 | ws = atoi(argv[ai + 1]); 78 | } else if (strcmp(argv[ai], "-epoch") == 0) { 79 | epoch = atoi(argv[ai + 1]); 80 | } else if (strcmp(argv[ai], "-minCount") == 0) { 81 | minCount = atoi(argv[ai + 1]); 82 | } else if (strcmp(argv[ai], "-neg") == 0) { 83 | neg = atoi(argv[ai + 1]); 84 | } else if (strcmp(argv[ai], "-wordNgrams") == 0) { 85 | wordNgrams = atoi(argv[ai + 1]); 86 | } else if (strcmp(argv[ai], "-loss") == 0) { 87 | if (strcmp(argv[ai + 1], "hs") == 0) { 88 | loss = loss_name::hs; 89 | } else if (strcmp(argv[ai + 1], "ns") == 0) { 90 | loss = loss_name::ns; 91 | } else if (strcmp(argv[ai + 1], "softmax") == 0) { 92 | loss = loss_name::softmax; 93 | } else { 94 | std::cout << "Unknown loss: " << argv[ai + 1] << std::endl; 95 | printHelp(); 96 | exit(EXIT_FAILURE); 97 | } 98 | } else if (strcmp(argv[ai], "-bucket") == 0) { 99 | bucket = atoi(argv[ai + 1]); 100 | } else if (strcmp(argv[ai], "-minn") == 0) { 101 | minn = atoi(argv[ai + 1]); 102 | } else if (strcmp(argv[ai], "-maxn") == 0) { 103 | maxn = atoi(argv[ai + 1]); 104 | } else if (strcmp(argv[ai], "-thread") == 0) { 105 | thread = atoi(argv[ai + 1]); 106 | } else if (strcmp(argv[ai], "-t") == 0) { 107 | t = atof(argv[ai + 1]); 108 | } else if (strcmp(argv[ai], "-label") == 
0) { 109 | label = std::string(argv[ai + 1]); 110 | } else { 111 | std::cout << "Unknown argument: " << argv[ai] << std::endl; 112 | printHelp(); 113 | exit(EXIT_FAILURE); 114 | } 115 | ai += 2; 116 | } 117 | 118 | if ((model == model_name::pvdm || model == model_name::pvdbow) && modelInput.empty()) { 119 | std::cout << "Empty model.bin path." << std::endl; 120 | printHelp(); 121 | exit(EXIT_FAILURE); 122 | } 123 | 124 | if (input.empty() || output.empty()) { 125 | std::cout << "Empty input or output path." << std::endl; 126 | printHelp(); 127 | exit(EXIT_FAILURE); 128 | } 129 | } 130 | 131 | void Args::printHelp() { 132 | std::cout 133 | << "\n" 134 | << "The following arguments are mandatory:\n" 135 | << " -model (mandatory, only pvdm or pvdbow) model.bin file path for document embedding\n" 136 | << " -input training file path\n" 137 | << " -output output file path\n\n" 138 | << "The following arguments are optional:\n" 139 | << " -lr learning rate [" << lr << "]\n" 140 | << " -lrUpdateRate change the rate of updates for the learning rate [" << lrUpdateRate << "]\n" 141 | << " -dim size of word vectors [" << dim << "]\n" 142 | << " -ws size of the context window [" << ws << "]\n" 143 | << " -epoch number of epochs [" << epoch << "]\n" 144 | << " -minCount minimal number of word occurences [" << minCount << "]\n" 145 | << " -neg number of negatives sampled [" << neg << "]\n" 146 | << " -wordNgrams max length of word ngram [" << wordNgrams << "]\n" 147 | << " -loss loss function {ns, hs, softmax} [ns]\n" 148 | << " -bucket number of buckets [" << bucket << "]\n" 149 | << " -minn min length of char ngram [" << minn << "]\n" 150 | << " -maxn max length of char ngram [" << maxn << "]\n" 151 | << " -thread number of threads [" << thread << "]\n" 152 | << " -t sampling threshold [" << t << "]\n" 153 | << " -label labels prefix [" << label << "]\n" 154 | << std::endl; 155 | } 156 | 157 | void Args::save(std::ofstream& ofs) { 158 | if (ofs.is_open()) { 159 | 
ofs.write((char*) &(dim), sizeof(int)); 160 | ofs.write((char*) &(ws), sizeof(int)); 161 | ofs.write((char*) &(epoch), sizeof(int)); 162 | ofs.write((char*) &(minCount), sizeof(int)); 163 | ofs.write((char*) &(neg), sizeof(int)); 164 | ofs.write((char*) &(wordNgrams), sizeof(int)); 165 | ofs.write((char*) &(loss), sizeof(loss_name)); 166 | ofs.write((char*) &(model), sizeof(model_name)); 167 | ofs.write((char*) &(bucket), sizeof(int)); 168 | ofs.write((char*) &(minn), sizeof(int)); 169 | ofs.write((char*) &(maxn), sizeof(int)); 170 | ofs.write((char*) &(lrUpdateRate), sizeof(int)); 171 | ofs.write((char*) &(t), sizeof(double)); 172 | } 173 | } 174 | 175 | void Args::load(std::ifstream& ifs) { 176 | if (ifs.is_open()) { 177 | ifs.read((char*) &(dim), sizeof(int)); 178 | ifs.read((char*) &(ws), sizeof(int)); 179 | ifs.read((char*) &(epoch), sizeof(int)); 180 | ifs.read((char*) &(minCount), sizeof(int)); 181 | ifs.read((char*) &(neg), sizeof(int)); 182 | ifs.read((char*) &(wordNgrams), sizeof(int)); 183 | ifs.read((char*) &(loss), sizeof(loss_name)); 184 | ifs.read((char*) &(model), sizeof(model_name)); 185 | ifs.read((char*) &(bucket), sizeof(int)); 186 | ifs.read((char*) &(minn), sizeof(int)); 187 | ifs.read((char*) &(maxn), sizeof(int)); 188 | ifs.read((char*) &(lrUpdateRate), sizeof(int)); 189 | ifs.read((char*) &(t), sizeof(double)); 190 | } 191 | } 192 | -------------------------------------------------------------------------------- /src/args.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. An additional grant 7 | * of patent rights can be found in the PATENTS file in the same directory. 
8 | */ 9 | 10 | #ifndef FASTTEXT_ARGS_H 11 | #define FASTTEXT_ARGS_H 12 | 13 | #include 14 | 15 | enum class model_name : int {cbow=1, sg, sup, pvdm, pvdbow}; 16 | enum class loss_name : int {hs=1, ns, softmax}; 17 | 18 | class Args { 19 | public: 20 | Args(); 21 | std::string input; 22 | std::string modelInput; 23 | std::string test; 24 | std::string output; 25 | double lr; 26 | int lrUpdateRate; 27 | int dim; 28 | int ws; 29 | int epoch; 30 | int minCount; 31 | int neg; 32 | int wordNgrams; 33 | loss_name loss; 34 | model_name model; 35 | int bucket; 36 | int minn; 37 | int maxn; 38 | int thread; 39 | double t; 40 | std::string label; 41 | 42 | void parseArgs(int, char**); 43 | void printHelp(); 44 | void save(std::ofstream&); 45 | void load(std::ifstream&); 46 | }; 47 | 48 | #endif 49 | -------------------------------------------------------------------------------- /src/dictionary.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. An additional grant 7 | * of patent rights can be found in the PATENTS file in the same directory. 
8 | */ 9 | 10 | #include "dictionary.h" 11 | 12 | #include 13 | 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | 20 | #include "args.h" 21 | 22 | extern Args args; 23 | 24 | const std::string Dictionary::EOS = ""; 25 | const std::string Dictionary::BOW = "<"; 26 | const std::string Dictionary::EOW = ">"; 27 | 28 | Dictionary::Dictionary() { 29 | size_ = 0; 30 | nwords_ = 0; 31 | nlabels_ = 0; 32 | ntokens_ = 0; 33 | word2int_.resize(MAX_VOCAB_SIZE); 34 | for (int32_t i = 0; i < MAX_VOCAB_SIZE; i++) { 35 | word2int_[i] = -1; 36 | } 37 | } 38 | 39 | int32_t Dictionary::find(const std::string& w) { 40 | int32_t h = hash(w) % MAX_VOCAB_SIZE; 41 | while (word2int_[h] != -1 && words_[word2int_[h]].word != w) { 42 | h = (h + 1) % MAX_VOCAB_SIZE; 43 | } 44 | return h; 45 | } 46 | 47 | void Dictionary::add(const std::string& w) { 48 | int32_t h = find(w); 49 | ntokens_++; 50 | if (word2int_[h] == -1) { 51 | entry e; 52 | e.word = w; 53 | e.count = 1; 54 | e.type = (w.find(args.label) == 0) ? 
entry_type::label : entry_type::word; 55 | words_.push_back(e); 56 | word2int_[h] = size_++; 57 | } else { 58 | words_[word2int_[h]].count++; 59 | } 60 | } 61 | 62 | int32_t Dictionary::nwords() { 63 | return nwords_; 64 | } 65 | 66 | int32_t Dictionary::nlabels() { 67 | return nlabels_; 68 | } 69 | 70 | int32_t Dictionary::nsizes() { 71 | return size_; 72 | } 73 | 74 | int64_t Dictionary::ntokens() { 75 | return ntokens_; 76 | } 77 | 78 | const std::vector& Dictionary::getNgrams(int32_t i) { 79 | assert(i >= 0); 80 | assert(i < nwords_); 81 | return words_[i].subwords; 82 | } 83 | 84 | const std::vector Dictionary::getNgrams(const std::string& word) { 85 | std::vector ngrams; 86 | int32_t i = getId(word); 87 | if (i >= 0) { 88 | ngrams = words_[i].subwords; 89 | } else { 90 | computeNgrams(BOW + word + EOW, ngrams); 91 | } 92 | return ngrams; 93 | } 94 | 95 | bool Dictionary::discard(int32_t id, real rand) { 96 | assert(id >= 0); 97 | assert(id < nwords_); 98 | if (args.model == model_name::sup) return false; 99 | return rand > pdiscard_[id]; 100 | } 101 | 102 | int32_t Dictionary::getId(const std::string& w) { 103 | int32_t h = find(w); 104 | return word2int_[h]; 105 | } 106 | 107 | entry_type Dictionary::getType(int32_t id) { 108 | assert(id >= 0); 109 | assert(id < size_); 110 | return words_[id].type; 111 | } 112 | 113 | std::string Dictionary::getWord(int32_t id) { 114 | assert(id >= 0); 115 | assert(id < size_); 116 | return words_[id].word; 117 | } 118 | 119 | uint32_t Dictionary::hash(const std::string& str) { 120 | uint32_t h = 2166136261; 121 | for (size_t i = 0; i < str.size(); i++) { 122 | h = h ^ uint32_t(str[i]); 123 | h = h * 16777619; 124 | } 125 | return h; 126 | } 127 | 128 | void Dictionary::computeNgrams(const std::string& word, 129 | std::vector& ngrams) { 130 | for (size_t i = 0; i < word.size(); i++) { 131 | std::string ngram; 132 | 133 | if ((word[i] & 0xC0) == 0x80) continue; 134 | for (size_t j = i, n = 1; j < word.size() && n <= 
args.maxn; n++) { 135 | ngram.push_back(word[j++]); 136 | while (j < word.size() && (word[j] & 0xC0) == 0x80) { 137 | ngram.push_back(word[j++]); 138 | } 139 | if (n >= args.minn) { 140 | int32_t h = hash(ngram) % args.bucket; 141 | ngrams.push_back(nwords_ + h); 142 | } 143 | } 144 | } 145 | } 146 | 147 | void Dictionary::initNgrams() { 148 | for (size_t i = 0; i < size_; i++) { 149 | std::string word = BOW + words_[i].word + EOW; 150 | words_[i].subwords.push_back(i); 151 | computeNgrams(word, words_[i].subwords); 152 | } 153 | } 154 | 155 | std::string Dictionary::readWord(std::ifstream& fin) { 156 | char c; 157 | std::string word; 158 | while (fin.peek() != EOF) { 159 | fin.get(c); 160 | if (isspace(c) || c == 0) { 161 | if (word.empty()) { 162 | if (c == '\n') return EOS; 163 | continue; 164 | } else { 165 | if (c == '\n') fin.unget(); 166 | return word; 167 | } 168 | } 169 | word.push_back(c); 170 | } 171 | return word; 172 | } 173 | 174 | void Dictionary::readFromFile(std::ifstream& ifs) { 175 | std::string word; 176 | int64_t minThreshold = 1; 177 | while (!(word = readWord(ifs)).empty()) { 178 | add(word); 179 | if (ntokens_ % 1000000 == 0) { 180 | std::cout << "\rRead " << ntokens_ / 1000000 << "M words" << std::flush; 181 | } 182 | if (size_ > 0.75 * MAX_VOCAB_SIZE) { 183 | threshold(minThreshold++); 184 | } 185 | } 186 | std::cout << "\rRead " << ntokens_ / 1000000 << "M words" << std::endl; 187 | threshold(args.minCount); 188 | initTableDiscard(); 189 | initNgrams(); 190 | } 191 | 192 | void Dictionary::threshold(int64_t t) { 193 | sort(words_.begin(), words_.end(), [](const entry& e1, const entry& e2) { 194 | if (e1.type != e2.type) return e1.type < e2.type; 195 | return e1.count > e2.count; 196 | }); 197 | words_.erase(remove_if(words_.begin(), words_.end(), [&](const entry& e) { 198 | return e.count < t; 199 | }), words_.end()); 200 | words_.shrink_to_fit(); 201 | size_ = 0; 202 | nwords_ = 0; 203 | nlabels_ = 0; 204 | for (int32_t i = 0; i < 
MAX_VOCAB_SIZE; i++) { 205 | word2int_[i] = -1; 206 | } 207 | for (auto it = words_.begin(); it != words_.end(); ++it) { 208 | int32_t h = find(it->word); 209 | word2int_[h] = size_++; 210 | if (it->type == entry_type::word) nwords_++; 211 | if (it->type == entry_type::label) nlabels_++; 212 | } 213 | } 214 | 215 | void Dictionary::initTableDiscard() { 216 | pdiscard_.resize(size_); 217 | for (size_t i = 0; i < size_; i++) { 218 | real f = real(words_[i].count) / real(ntokens_); 219 | pdiscard_[i] = sqrt(args.t / f) + args.t / f; 220 | } 221 | } 222 | 223 | std::vector Dictionary::getCounts(entry_type type) { 224 | std::vector counts; 225 | for (auto& w : words_) { 226 | if (w.type == type) counts.push_back(w.count); 227 | } 228 | return counts; 229 | } 230 | 231 | void Dictionary::addNgrams(std::vector& line, int32_t n) { 232 | int32_t line_size = line.size(); 233 | for (int32_t i = 0; i < line_size; i++) { 234 | uint64_t h = line[i]; 235 | for (int32_t j = i + 1; j < line_size && j < i + n; j++) { 236 | h = h * 116049371 + line[j]; 237 | line.push_back(nwords_ + (h % args.bucket)); 238 | } 239 | } 240 | } 241 | 242 | int32_t Dictionary::getLine(std::ifstream& ifs, 243 | std::vector& words, 244 | std::vector& labels, 245 | std::minstd_rand& rng) { 246 | std::uniform_real_distribution<> uniform(0, 1); 247 | std::string token; 248 | int32_t ntokens = 0; 249 | words.clear(); 250 | labels.clear(); 251 | if (ifs.eof()) { 252 | ifs.clear(); 253 | ifs.seekg(std::streampos(0)); 254 | } 255 | while (!(token = readWord(ifs)).empty()) { 256 | // read a one line. 
257 | if (token == EOS) break; 258 | int32_t wid = getId(token); 259 | if (wid < 0) continue; 260 | entry_type type = getType(wid); 261 | ntokens++; 262 | if (type == entry_type::word && !discard(wid, uniform(rng))) { 263 | words.push_back(wid); 264 | } 265 | if (type == entry_type::label) { 266 | labels.push_back(wid-nwords_); 267 | } 268 | if (words.size() > MAX_LINE_SIZE && args.model != model_name::sup) break; 269 | } 270 | return ntokens; 271 | } 272 | 273 | std::string Dictionary::getLabel(int32_t lid) { 274 | assert(lid >= 0); 275 | assert(lid < nlabels_); 276 | return words_[lid + nwords_].word; 277 | } 278 | 279 | void Dictionary::save(std::ofstream& ofs) { 280 | ofs.write((char*) &size_, sizeof(int32_t)); 281 | ofs.write((char*) &nwords_, sizeof(int32_t)); 282 | ofs.write((char*) &nlabels_, sizeof(int32_t)); 283 | ofs.write((char*) &ntokens_, sizeof(int64_t)); 284 | for (int32_t i = 0; i < size_; i++) { 285 | entry e = words_[i]; 286 | ofs.write(e.word.data(), e.word.size() * sizeof(char)); 287 | ofs.put(0); 288 | ofs.write((char*) &(e.count), sizeof(int64_t)); 289 | ofs.write((char*) &(e.type), sizeof(entry_type)); 290 | } 291 | } 292 | 293 | void Dictionary::load(std::ifstream& ifs) { 294 | words_.clear(); 295 | for (int32_t i = 0; i < MAX_VOCAB_SIZE; i++) { 296 | word2int_[i] = -1; 297 | } 298 | ifs.read((char*) &size_, sizeof(int32_t)); 299 | ifs.read((char*) &nwords_, sizeof(int32_t)); 300 | ifs.read((char*) &nlabels_, sizeof(int32_t)); 301 | ifs.read((char*) &ntokens_, sizeof(int64_t)); 302 | for (int32_t i = 0; i < size_; i++) { 303 | char c; 304 | entry e; 305 | while ((c = ifs.get()) != 0) { 306 | e.word.push_back(c); 307 | } 308 | ifs.read((char*) &e.count, sizeof(int64_t)); 309 | ifs.read((char*) &e.type, sizeof(entry_type)); 310 | words_.push_back(e); 311 | word2int_[find(e.word)] = i; 312 | } 313 | initTableDiscard(); 314 | initNgrams(); 315 | } 316 | 317 | void Dictionary::load(std::ifstream& ifs, std::ifstream& modelIfs) { 318 | 
words_.clear(); 319 | for (int32_t i = 0; i < MAX_VOCAB_SIZE; i++) { 320 | word2int_[i] = -1; 321 | } 322 | modelIfs.read((char*) &size_, sizeof(int32_t)); 323 | modelIfs.read((char*) &nwords_, sizeof(int32_t)); 324 | modelIfs.read((char*) &nlabels_, sizeof(int32_t)); 325 | modelIfs.read((char*) &ntokens_, sizeof(int64_t)); 326 | for (int32_t i = 0; i < size_; i++) { 327 | char c; 328 | entry e; 329 | while ((c = modelIfs.get()) != 0) { 330 | e.word.push_back(c); 331 | } 332 | modelIfs.read((char*) &e.count, sizeof(int64_t)); 333 | modelIfs.read((char*) &e.type, sizeof(entry_type)); 334 | words_.push_back(e); 335 | word2int_[find(e.word)] = i; 336 | } 337 | 338 | std::string word; 339 | while (!(word = readWord(ifs)).empty()) { 340 | if (word.find(args.label) == 0) { 341 | add(word); 342 | nlabels_++; 343 | } 344 | } 345 | 346 | initTableDiscard(); 347 | initNgrams(); 348 | } 349 | -------------------------------------------------------------------------------- /src/dictionary.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. An additional grant 7 | * of patent rights can be found in the PATENTS file in the same directory. 
8 | */ 9 | 10 | #ifndef FASTTEXT_DICTIONARY_H 11 | #define FASTTEXT_DICTIONARY_H 12 | 13 | #include 14 | #include 15 | #include 16 | #include 17 | 18 | #include "real.h" 19 | 20 | typedef int32_t id_type; 21 | enum class entry_type : int8_t {word=0, label=1}; 22 | 23 | struct entry { 24 | std::string word; 25 | int64_t count; 26 | entry_type type; 27 | std::vector subwords; 28 | }; 29 | 30 | class Dictionary { 31 | private: 32 | static const int32_t MAX_VOCAB_SIZE = 30000000; 33 | static const int32_t MAX_LINE_SIZE = 1024; 34 | 35 | int32_t find(const std::string&); 36 | void initTableDiscard(); 37 | void initNgrams(); 38 | void threshold(int64_t); 39 | 40 | std::vector word2int_; 41 | std::vector words_; 42 | std::vector pdiscard_; 43 | int32_t size_; 44 | int32_t nwords_; 45 | int32_t nlabels_; 46 | int64_t ntokens_; 47 | 48 | public: 49 | static const std::string EOS; 50 | static const std::string BOW; 51 | static const std::string EOW; 52 | 53 | Dictionary(); 54 | int32_t nwords(); 55 | int32_t nsizes(); 56 | int32_t nlabels(); 57 | int64_t ntokens(); 58 | int32_t getId(const std::string&); 59 | entry_type getType(int32_t); 60 | bool discard(int32_t, real); 61 | std::string getWord(int32_t); 62 | const std::vector& getNgrams(int32_t); 63 | const std::vector getNgrams(const std::string&); 64 | void computeNgrams(const std::string&, std::vector&); 65 | uint32_t hash(const std::string& str); 66 | void add(const std::string&); 67 | std::string readWord(std::ifstream&); 68 | void readFromFile(std::ifstream&); 69 | std::string getLabel(int32_t); 70 | void save(std::ofstream&); 71 | void load(std::ifstream&); 72 | void load(std::ifstream&, std::ifstream&); 73 | std::vector getCounts(entry_type); 74 | void addNgrams(std::vector&, int32_t); 75 | int32_t getLine(std::ifstream&, std::vector&, 76 | std::vector&, std::minstd_rand&); 77 | }; 78 | 79 | #endif 80 | -------------------------------------------------------------------------------- /src/fasttext.cc: 
-------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. An additional grant 7 | * of patent rights can be found in the PATENTS file in the same directory. 8 | */ 9 | 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | 22 | #include "matrix.h" 23 | #include "vector.h" 24 | #include "dictionary.h" 25 | #include "model.h" 26 | #include "utils.h" 27 | #include "real.h" 28 | #include "args.h" 29 | 30 | Args args; 31 | 32 | namespace info { 33 | clock_t start; 34 | std::atomic allWords(0); 35 | std::atomic allLabels(0); 36 | std::atomic allN(0); 37 | double allLoss(0.0); 38 | } 39 | 40 | void getVector(Dictionary& dict, Matrix& input, Vector& vec, std::string word) { 41 | const std::vector& ngrams = dict.getNgrams(word); 42 | vec.zero(); 43 | for (auto it = ngrams.begin(); it != ngrams.end(); ++it) { 44 | vec.addRow(input, *it); 45 | } 46 | if (ngrams.size() > 0) { 47 | vec.mul(1.0 / ngrams.size()); 48 | } 49 | } 50 | 51 | void saveVectors(Dictionary& dict, Matrix& input, Matrix& output) { 52 | std::ofstream ofs(args.output + ".vec"); 53 | if (!ofs.is_open()) { 54 | std::cout << "Error opening file for saving vectors." 
<< std::endl; 55 | exit(EXIT_FAILURE); 56 | } 57 | ofs << dict.nwords() << " " << args.dim << std::endl; 58 | Vector vec(args.dim); 59 | for (int32_t i = 0; i < dict.nwords(); i++) { 60 | std::string word = dict.getWord(i); 61 | getVector(dict, input, vec, word); 62 | ofs << word << " " << vec << std::endl; 63 | } 64 | 65 | ofs.close(); 66 | } 67 | 68 | void saveDocVectors(Dictionary& dict, Matrix& input) { 69 | std::ofstream ofs(args.output + ".vec"); 70 | if (!ofs.is_open()) { 71 | std::cout << "Error opening file for saving document vectors." << std::endl; 72 | exit(EXIT_FAILURE); 73 | } 74 | ofs << dict.nlabels() << " " << args.dim << std::endl; 75 | Vector vec(args.dim); 76 | int32_t nwords = dict.nwords(); 77 | for (int32_t i = 0; i < dict.nlabels(); i++) { 78 | std::string label = dict.getLabel(i); 79 | vec.zero(); 80 | vec.addRow(input, i + nwords + args.bucket); 81 | ofs << label << "\t" << vec << std::endl; 82 | } 83 | 84 | ofs.close(); 85 | } 86 | 87 | void printVectors(Dictionary& dict, Matrix& input) { 88 | std::string word; 89 | Vector vec(args.dim); 90 | while (std::cin >> word) { 91 | getVector(dict, input, vec, word); 92 | std::cout << word << " " << vec << std::endl; 93 | } 94 | } 95 | 96 | void saveModel(Dictionary& dict, Matrix& input, Matrix& output) { 97 | std::ofstream ofs(args.output + ".bin"); 98 | if (!ofs.is_open()) { 99 | std::cerr << "Model file cannot be opened for saving!" << std::endl; 100 | exit(EXIT_FAILURE); 101 | } 102 | args.save(ofs); 103 | dict.save(ofs); 104 | input.save(ofs); 105 | output.save(ofs); 106 | ofs.close(); 107 | } 108 | 109 | void loadModel(std::string filename, Dictionary& dict, 110 | Matrix& input, Matrix& output) { 111 | std::ifstream ifs(filename); 112 | if (!ifs.is_open()) { 113 | std::cerr << "Model file cannot be opened for loading!" 
<< std::endl; 114 | exit(EXIT_FAILURE); 115 | } 116 | args.load(ifs); 117 | dict.load(ifs); 118 | input.load(ifs); 119 | output.load(ifs); 120 | 121 | ifs.close(); 122 | } 123 | 124 | void loadModel(std::string documentsFilename, std::string modelFilename, Dictionary& dict, 125 | Matrix& input, Matrix& output) { 126 | std::ifstream ifs(documentsFilename); 127 | std::ifstream modelIfs(modelFilename); 128 | if (!ifs.is_open()) { 129 | std::cerr << "Model file cannot be opened for loading!" << std::endl; 130 | exit(EXIT_FAILURE); 131 | } 132 | args.load(modelIfs); 133 | dict.load(ifs, modelIfs); 134 | input.load(modelIfs, dict.nlabels()); 135 | output.load(modelIfs); 136 | 137 | ifs.close(); 138 | modelIfs.close(); 139 | } 140 | 141 | void printInfo(Model& model, real progress) { 142 | real loss = info::allLoss / info::allN; 143 | real t = real(clock() - info::start) / CLOCKS_PER_SEC; 144 | real wst = real(info::allWords) / t; 145 | int eta = int(t / progress * (1 - progress) / args.thread); 146 | int etah = eta / 3600; 147 | int etam = (eta - etah * 3600) / 60; 148 | std::cout << std::fixed; 149 | std::cout << "\rProgress: " << std::setprecision(1) << 100 * progress << "%"; 150 | std::cout << " words/sec/thread: " << std::setprecision(0) << wst; 151 | std::cout << " lr: " << std::setprecision(6) << model.getLearningRate(); 152 | std::cout << " allN: " << std::setprecision(6) << info::allN; 153 | std::cout << " allLoss: " << std::setprecision(6) << info::allLoss; 154 | std::cout << " loss: " << std::setprecision(6) << loss; 155 | std::cout << " eta: " << etah << "h" << etam << "m "; 156 | std::cout << std::flush; 157 | } 158 | 159 | void supervised(Model& model, 160 | const std::vector& line, 161 | const std::vector& labels, 162 | double& loss, int32_t& nexamples) { 163 | if (labels.size() == 0 || line.size() == 0) return; 164 | std::uniform_int_distribution<> uniform(0, labels.size() - 1); 165 | int32_t i = uniform(model.rng); 166 | loss += model.update(line, 
labels[i]); 167 | nexamples++; 168 | } 169 | 170 | void pvdbow(Dictionary& dict, Model& model, 171 | const std::vector& line, 172 | const std::vector& labels, 173 | double& loss, int32_t& nexamples) { 174 | std::uniform_int_distribution<> uniform(1, args.ws); 175 | for (int32_t w = 0; w < line.size(); w++) { 176 | int32_t boundary = uniform(model.rng); 177 | for (int32_t c = -boundary; c <= boundary; c++) { 178 | if (c != 0 && w + c >= 0 && w + c < line.size()) { 179 | loss += model.update(labels, line[w+c]); 180 | nexamples++; 181 | } 182 | } 183 | } 184 | } 185 | 186 | void pvdm(Dictionary& dict, Model& model, 187 | const std::vector& line, 188 | const std::vector& labels, 189 | double& loss, int32_t& nexamples) { 190 | std::vector bow; 191 | std::uniform_int_distribution<> uniform(1, args.ws); 192 | for (int32_t w = 0; w < line.size(); w++) { 193 | int32_t boundary = uniform(model.rng); 194 | bow.clear(); 195 | for (int32_t c = -boundary; c <= boundary; c++) { 196 | if (c != 0 && w + c >= 0 && w + c < line.size()) { 197 | const std::vector& ngrams = dict.getNgrams(line[w + c]); 198 | bow.insert(bow.end(), ngrams.cbegin(), ngrams.cend()); 199 | } 200 | } 201 | bow.insert(bow.end(), labels.cbegin(), labels.cend()); 202 | loss += model.update(bow, line[w]); 203 | nexamples++; 204 | } 205 | } 206 | 207 | void cbow(Dictionary& dict, Model& model, 208 | const std::vector& line, 209 | double& loss, int32_t& nexamples) { 210 | std::vector bow; 211 | std::uniform_int_distribution<> uniform(1, args.ws); 212 | for (int32_t w = 0; w < line.size(); w++) { 213 | int32_t boundary = uniform(model.rng); 214 | bow.clear(); 215 | for (int32_t c = -boundary; c <= boundary; c++) { 216 | if (c != 0 && w + c >= 0 && w + c < line.size()) { 217 | const std::vector& ngrams = dict.getNgrams(line[w + c]); 218 | bow.insert(bow.end(), ngrams.cbegin(), ngrams.cend()); 219 | } 220 | } 221 | loss += model.update(bow, line[w]); 222 | nexamples++; 223 | } 224 | } 225 | 226 | void 
skipgram(Dictionary& dict, Model& model, 227 | const std::vector& line, 228 | double& loss, int32_t& nexamples) { 229 | std::uniform_int_distribution<> uniform(1, args.ws); 230 | for (int32_t w = 0; w < line.size(); w++) { 231 | int32_t boundary = uniform(model.rng); 232 | const std::vector& ngrams = dict.getNgrams(line[w]); 233 | for (int32_t c = -boundary; c <= boundary; c++) { 234 | if (c != 0 && w + c >= 0 && w + c < line.size()) { 235 | loss += model.update(ngrams, line[w + c]); 236 | nexamples++; 237 | } 238 | } 239 | } 240 | } 241 | 242 | void test(Dictionary& dict, Model& model, std::string filename, int32_t k) { 243 | int32_t nexamples = 0, nlabels = 0; 244 | double precision = 0.0; 245 | std::vector line, labels; 246 | std::ifstream ifs(filename); 247 | if (!ifs.is_open()) { 248 | std::cerr << "Test file cannot be opened!" << std::endl; 249 | exit(EXIT_FAILURE); 250 | } 251 | while (ifs.peek() != EOF) { 252 | dict.getLine(ifs, line, labels, model.rng); 253 | dict.addNgrams(line, args.wordNgrams); 254 | if (labels.size() > 0 && line.size() > 0) { 255 | std::vector> predictions; 256 | model.predict(line, k, predictions); 257 | for (auto it = predictions.cbegin(); it != predictions.cend(); it++) { 258 | if (std::find(labels.begin(), labels.end(), it->second) != labels.end()) { 259 | precision += 1.0; 260 | } 261 | } 262 | nexamples++; 263 | nlabels += labels.size(); 264 | } 265 | } 266 | ifs.close(); 267 | std::cout << std::setprecision(3); 268 | std::cout << "P@" << k << ": " << precision / (k * nexamples) << std::endl; 269 | std::cout << "R@" << k << ": " << precision / nlabels << std::endl; 270 | std::cout << "Number of examples: " << nexamples << std::endl; 271 | } 272 | 273 | void predict(Dictionary& dict, Model& model, std::string filename, int32_t k) { 274 | std::vector line, labels; 275 | std::ifstream ifs(filename); 276 | if (!ifs.is_open()) { 277 | std::cerr << "Test file cannot be opened!" 
<< std::endl; 278 | exit(EXIT_FAILURE); 279 | } 280 | while (ifs.peek() != EOF) { 281 | dict.getLine(ifs, line, labels, model.rng); 282 | dict.addNgrams(line, args.wordNgrams); 283 | if (line.empty()) { 284 | std::cout << "n/a" << std::endl; 285 | continue; 286 | } 287 | std::vector> predictions; 288 | model.predict(line, k, predictions); 289 | for (auto it = predictions.cbegin(); it != predictions.cend(); it++) { 290 | if (it != predictions.cbegin()) { 291 | std::cout << ' '; 292 | } 293 | std::cout << dict.getLabel(it->second); 294 | } 295 | std::cout << std::endl; 296 | } 297 | ifs.close(); 298 | } 299 | 300 | void printUsage() { 301 | std::cout 302 | << "usage: fasttext \n\n" 303 | << "The commands supported by fasttext are:\n\n" 304 | << " supervised train a supervised classifier\n" 305 | << " test evaluate a supervised classifier\n" 306 | << " predict predict most likely label\n" 307 | << " skipgram train a skipgram model\n" 308 | << " cbow train a cbow model\n" 309 | << " pvdm train a pvdm model\n" 310 | << " pvbow train a pvdbow model\n" 311 | << " print-vectors print vectors given a trained model\n" 312 | << std::endl; 313 | } 314 | 315 | void printTestUsage() { 316 | std::cout 317 | << "usage: fasttext test []\n\n" 318 | << " model filename\n" 319 | << " test data filename\n" 320 | << " (optional; 1 by default) predict top k labels\n" 321 | << std::endl; 322 | } 323 | 324 | void printPredictUsage() { 325 | std::cout 326 | << "usage: fasttext predict []\n\n" 327 | << " model filename\n" 328 | << " test data filename\n" 329 | << " (optional; 1 by default) predict top k labels\n" 330 | << std::endl; 331 | } 332 | 333 | void printPrintVectorsUsage() { 334 | std::cout 335 | << "usage: fasttext print-vectors \n\n" 336 | << " model filename\n" 337 | << std::endl; 338 | } 339 | 340 | void test(int argc, char** argv) { 341 | int32_t k; 342 | if (argc == 4) { 343 | k = 1; 344 | } else if (argc == 5) { 345 | k = atoi(argv[4]); 346 | } else { 347 | printTestUsage(); 
348 | exit(EXIT_FAILURE); 349 | } 350 | Dictionary dict; 351 | Matrix input, output; 352 | loadModel(std::string(argv[2]), dict, input, output); 353 | Model model(input, output, args.dim, args.lr, 1); 354 | model.setTargetCounts(dict.getCounts(entry_type::label)); 355 | test(dict, model, std::string(argv[3]), k); 356 | exit(0); 357 | } 358 | 359 | void predict(int argc, char** argv) { 360 | int32_t k; 361 | if (argc == 4) { 362 | k = 1; 363 | } else if (argc == 5) { 364 | k = atoi(argv[4]); 365 | } else { 366 | printPredictUsage(); 367 | exit(EXIT_FAILURE); 368 | } 369 | Dictionary dict; 370 | Matrix input, output; 371 | loadModel(std::string(argv[2]), dict, input, output); 372 | Model model(input, output, args.dim, args.lr, 1); 373 | model.setTargetCounts(dict.getCounts(entry_type::label)); 374 | predict(dict, model, std::string(argv[3]), k); 375 | exit(0); 376 | } 377 | 378 | void printVectors(int argc, char** argv) { 379 | if (argc != 3) { 380 | printPrintVectorsUsage(); 381 | exit(EXIT_FAILURE); 382 | } 383 | Dictionary dict; 384 | Matrix input, output; 385 | loadModel(std::string(argv[2]), dict, input, output); 386 | printVectors(dict, input); 387 | exit(0); 388 | } 389 | 390 | void trainThread(Dictionary& dict, Matrix& input, Matrix& output, 391 | int32_t threadId) { 392 | std::ifstream ifs(args.input); 393 | utils::seek(ifs, threadId * utils::size(ifs) / args.thread); 394 | 395 | Model model(input, output, args.dim, args.lr, threadId); 396 | if (args.model == model_name::sup) { 397 | model.setTargetCounts(dict.getCounts(entry_type::label)); 398 | } else { 399 | model.setTargetCounts(dict.getCounts(entry_type::word)); 400 | } 401 | 402 | real progress; 403 | const int64_t ntokens = dict.ntokens(); 404 | int64_t tokenCount = 0, printCount = 0, deltaCount = 0; 405 | double loss = 0.0; 406 | int32_t nexamples = 0; 407 | std::vector line, labels; 408 | while (info::allWords < args.epoch * ntokens) { 409 | deltaCount = dict.getLine(ifs, line, labels, model.rng); 
410 | tokenCount += deltaCount; 411 | printCount += deltaCount; 412 | if (args.model == model_name::sup) { 413 | dict.addNgrams(line, args.wordNgrams); 414 | supervised(model, line, labels, loss, nexamples); 415 | } else if (args.model == model_name::cbow) { 416 | cbow(dict, model, line, loss, nexamples); 417 | } else if (args.model == model_name::sg) { 418 | skipgram(dict, model, line, loss, nexamples); 419 | } 420 | if (tokenCount > args.lrUpdateRate) { 421 | info::allWords += tokenCount; 422 | info::allLoss += loss; 423 | info::allN += nexamples; 424 | tokenCount = 0; 425 | loss = 0.0; 426 | nexamples = 0; 427 | progress = real(info::allWords) / (args.epoch * ntokens); 428 | model.setLearningRate(args.lr * (1.0 - progress)); 429 | if (threadId == 0) { 430 | printInfo(model, progress); 431 | } 432 | } 433 | } 434 | if (threadId == 0) { 435 | printInfo(model, 1.0); 436 | std::cout << std::endl; 437 | } 438 | ifs.close(); 439 | } 440 | 441 | void train(int argc, char** argv) { 442 | args.parseArgs(argc, argv); 443 | 444 | Dictionary dict; 445 | std::ifstream ifs(args.input); 446 | if (!ifs.is_open()) { 447 | std::cerr << "Input file cannot be opened!" 
<< std::endl; 448 | exit(EXIT_FAILURE); 449 | } 450 | dict.readFromFile(ifs); 451 | ifs.close(); 452 | 453 | Matrix input(dict.nwords() + args.bucket, args.dim); 454 | Matrix output; 455 | if (args.model == model_name::sup) { 456 | output = Matrix(dict.nlabels(), args.dim); 457 | } else { 458 | output = Matrix(dict.nwords(), args.dim); 459 | } 460 | input.uniform(1.0 / args.dim); 461 | output.zero(); 462 | 463 | info::start = clock(); 464 | time_t t0 = time(nullptr); 465 | std::vector threads; 466 | for (int32_t i = 0; i < args.thread; i++) { 467 | threads.push_back(std::thread(&trainThread, std::ref(dict), 468 | std::ref(input), std::ref(output), i)); 469 | } 470 | for (auto it = threads.begin(); it != threads.end(); ++it) { 471 | it->join(); 472 | } 473 | double trainTime = difftime(time(nullptr), t0); 474 | std::cout << "Train time: " << trainTime << " sec" << std::endl; 475 | 476 | if (args.output.size() != 0) { 477 | saveModel(dict, input, output); 478 | saveVectors(dict, input, output); 479 | } 480 | } 481 | 482 | void embeddingThread(Dictionary& dict, Matrix& input, Matrix& output, 483 | int32_t threadId) { 484 | std::ifstream ifs(args.input); 485 | utils::seek(ifs, threadId * utils::size(ifs) / args.thread); 486 | 487 | Model model(input, output, args.dim, args.lr, threadId); 488 | model.setTargetCounts(dict.getCounts(entry_type::word)); 489 | 490 | real progress; 491 | const int64_t ntokens = dict.ntokens(); 492 | const int64_t nwords = dict.nwords(); 493 | const int64_t nlabels = dict.nlabels(); 494 | int64_t tokenCount = 0, printCount = 0, deltaCount = 0; 495 | double loss = 0.0; 496 | int32_t nexamples = 0; 497 | std::vector line, labels; 498 | 499 | while (info::allLabels < args.epoch * nlabels) { 500 | deltaCount = dict.getLine(ifs, line, labels, model.rng); 501 | if (labels.size() == 0 || line.size() == 0) continue; 502 | tokenCount += deltaCount; 503 | printCount += deltaCount; 504 | 505 | labels[0] += nwords + args.bucket; 506 | if (args.model == 
model_name::pvdm) { 507 | pvdm(dict, model, line, labels, loss, nexamples); 508 | } else if (args.model == model_name::pvdbow) { 509 | pvdbow(dict, model, line, labels, loss, nexamples); 510 | } 511 | 512 | if (tokenCount > args.lrUpdateRate) { 513 | info::allLabels += 1; 514 | info::allWords += tokenCount; 515 | info::allLoss += loss; 516 | info::allN += nexamples; 517 | tokenCount = 0; 518 | loss = 0.0; 519 | nexamples = 0; 520 | progress = real(info::allLabels) / (args.epoch * nlabels); 521 | model.setLearningRate(args.lr * (1.0 - progress)); 522 | if (threadId == 0) { 523 | printInfo(model, progress); 524 | } 525 | } 526 | } 527 | if (threadId == 0) { 528 | printInfo(model, 1.0); 529 | std::cout << std::endl; 530 | } 531 | ifs.close(); 532 | } 533 | 534 | void embedding(int argc, char** argv) { 535 | args.parseArgs(argc, argv); 536 | 537 | model_name modelName = args.model; 538 | int epoch = args.epoch; 539 | 540 | Dictionary dict; 541 | Matrix input, output; 542 | time_t t0 = time(nullptr); 543 | loadModel(args.input, args.modelInput, dict, input, output); 544 | double trainTime = difftime(time(nullptr), t0); 545 | std::cout << "Model loading time: " << trainTime << " sec " << std::endl; 546 | std::cout << "size: " << dict.nsizes() << " nlabels: " << dict.nlabels() << " nwords: " << dict.nwords() << " ntokens: " << dict.ntokens() << std::endl; 547 | 548 | args.model = modelName; 549 | args.epoch = epoch; 550 | 551 | std::vector threads; 552 | t0 = time(nullptr); 553 | for (int32_t i = 0; i < args.thread; i++) { 554 | threads.push_back(std::thread(&embeddingThread, std::ref(dict), 555 | std::ref(input), std::ref(output), i)); 556 | } 557 | for (auto it = threads.begin(); it != threads.end(); ++it) { 558 | it->join(); 559 | } 560 | 561 | trainTime = difftime(time(nullptr), t0); 562 | std::cout << "Train time: " << trainTime << " sec" << std::endl; 563 | 564 | if (args.output.size() != 0) { 565 | saveDocVectors(dict, input); 566 | } 567 | 568 | exit(0); 569 | } 
570 | 571 | int main(int argc, char** argv) { 572 | utils::initTables(); 573 | if (argc < 2) { 574 | printUsage(); 575 | exit(EXIT_FAILURE); 576 | } 577 | std::string command(argv[1]); 578 | if (command == "skipgram" || command == "cbow" || command == "supervised") { 579 | train(argc, argv); 580 | } else if (command == "pvdm" || command == "pvdbow") { 581 | embedding(argc, argv); 582 | } else if (command == "test") { 583 | test(argc, argv); 584 | } else if (command == "print-vectors") { 585 | printVectors(argc, argv); 586 | } else if (command == "predict") { 587 | predict(argc, argv); 588 | } else { 589 | printUsage(); 590 | exit(EXIT_FAILURE); 591 | } 592 | utils::freeTables(); 593 | return 0; 594 | } 595 | -------------------------------------------------------------------------------- /src/matrix.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. An additional grant 7 | * of patent rights can be found in the PATENTS file in the same directory. 
8 | */ 9 | 10 | #include "matrix.h" 11 | 12 | #include 13 | 14 | #include 15 | 16 | #include "utils.h" 17 | #include "vector.h" 18 | 19 | Matrix::Matrix() { 20 | m_ = 0; 21 | n_ = 0; 22 | data_ = nullptr; 23 | } 24 | 25 | Matrix::Matrix(int64_t m, int64_t n) { 26 | m_ = m; 27 | n_ = n; 28 | data_ = new real[m * n]; 29 | } 30 | 31 | Matrix::Matrix(const Matrix& other) { 32 | m_ = other.m_; 33 | n_ = other.n_; 34 | data_ = new real[m_ * n_]; 35 | for (int64_t i = 0; i < (m_ * n_); i++) { 36 | data_[i] = other.data_[i]; 37 | } 38 | } 39 | 40 | Matrix& Matrix::operator=(const Matrix& other) { 41 | Matrix temp(other); 42 | m_ = temp.m_; 43 | n_ = temp.n_; 44 | std::swap(data_, temp.data_); 45 | return *this; 46 | } 47 | 48 | Matrix::~Matrix() { 49 | delete[] data_; 50 | } 51 | 52 | void Matrix::zero() { 53 | for (int64_t i = 0; i < (m_ * n_); i++) { 54 | data_[i] = 0.0; 55 | } 56 | } 57 | 58 | void Matrix::uniform(real a) { 59 | std::minstd_rand rng(1); 60 | std::uniform_real_distribution<> uniform(-a, a); 61 | for (int64_t i = 0; i < (m_ * n_); i++) { 62 | data_[i] = uniform(rng); 63 | } 64 | } 65 | 66 | void Matrix::addRow(const Vector& vec, int64_t i, real a) { 67 | assert(i >= 0); 68 | assert(i < m_); 69 | assert(vec.m_ == n_); 70 | for (int64_t j = 0; j < n_; j++) { 71 | data_[i * n_ + j] += a * vec.data_[j]; 72 | } 73 | } 74 | 75 | real Matrix::dotRow(const Vector& vec, int64_t i) { 76 | assert(i >= 0); 77 | assert(i < m_); 78 | assert(vec.m_ == n_); 79 | real d = 0.0; 80 | for (int64_t j = 0; j < n_; j++) { 81 | d += data_[i * n_ + j] * vec.data_[j]; 82 | } 83 | return d; 84 | } 85 | 86 | void Matrix::save(std::ofstream& ofs) { 87 | ofs.write((char*) &m_, sizeof(int64_t)); 88 | ofs.write((char*) &n_, sizeof(int64_t)); 89 | ofs.write((char*) data_, m_ * n_ * sizeof(real)); 90 | } 91 | 92 | void Matrix::load(std::ifstream& ifs) { 93 | ifs.read((char*) &m_, sizeof(int64_t)); 94 | ifs.read((char*) &n_, sizeof(int64_t)); 95 | delete[] data_; 96 | data_ = new real[m_ * 
n_]; 97 | ifs.read((char*) data_, m_ * n_ * sizeof(real)); 98 | } 99 | 100 | void Matrix::load(std::ifstream& ifs, int64_t extraM_) { 101 | ifs.read((char*) &m_, sizeof(int64_t)); 102 | m_ += extraM_; 103 | ifs.read((char*) &n_, sizeof(int64_t)); 104 | 105 | delete[] data_; 106 | data_ = new real[(m_) * n_]; 107 | uniform(1.0 / n_); 108 | 109 | ifs.read((char*) data_, (m_-extraM_) * n_ * sizeof(real)); 110 | } 111 | -------------------------------------------------------------------------------- /src/matrix.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. An additional grant 7 | * of patent rights can be found in the PATENTS file in the same directory. 8 | */ 9 | 10 | #ifndef FASTTEXT_MATRIX_H 11 | #define FASTTEXT_MATRIX_H 12 | 13 | #include 14 | #include 15 | 16 | #include "real.h" 17 | 18 | class Vector; 19 | 20 | class Matrix { 21 | 22 | public: 23 | real* data_; 24 | int64_t m_; 25 | int64_t n_; 26 | 27 | Matrix(); 28 | Matrix(int64_t, int64_t); 29 | Matrix(const Matrix&); 30 | Matrix& operator=(const Matrix&); 31 | ~Matrix(); 32 | 33 | void zero(); 34 | void uniform(real); 35 | real dotRow(const Vector&, int64_t); 36 | void addRow(const Vector&, int64_t, real); 37 | 38 | void save(std::ofstream&); 39 | void load(std::ifstream&); 40 | void load(std::ifstream&, int64_t); 41 | }; 42 | 43 | #endif 44 | -------------------------------------------------------------------------------- /src/model.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 
An additional grant 7 | * of patent rights can be found in the PATENTS file in the same directory. 8 | */ 9 | 10 | #include "model.h" 11 | #include 12 | #include 13 | #include "args.h" 14 | #include "utils.h" 15 | 16 | extern Args args; 17 | 18 | real Model::lr_ = MIN_LR; 19 | 20 | Model::Model(Matrix& wi, Matrix& wo, int32_t hsz, real lr, int32_t seed) 21 | : wi_(wi), wo_(wo), hidden_(hsz), output_(wo.m_), 22 | grad_(hsz), rng(seed) { 23 | isz_ = wi.m_; 24 | osz_ = wo.m_; 25 | hsz_ = hsz; 26 | lr_ = lr; 27 | negpos = 0; 28 | } 29 | 30 | void Model::setLearningRate(real lr) { 31 | lr_ = (lr < MIN_LR) ? MIN_LR : lr; 32 | } 33 | 34 | real Model::getLearningRate() { 35 | return lr_; 36 | } 37 | 38 | real Model::binaryLogistic(int32_t target, bool label) { 39 | real score = utils::sigmoid(wo_.dotRow(hidden_, target)); 40 | real alpha = lr_ * (real(label) - score); 41 | grad_.addRow(wo_, target, alpha); 42 | 43 | if (args.model != model_name::pvdm && args.model != model_name::pvdbow) { 44 | wo_.addRow(hidden_, target, alpha); 45 | } 46 | 47 | if (label) { 48 | return -utils::log(score); 49 | } else { 50 | return -utils::log(1.0 - score); 51 | } 52 | } 53 | 54 | real Model::negativeSampling(int32_t target) { 55 | real loss = 0.0; 56 | grad_.zero(); 57 | for (int32_t n = 0; n <= args.neg; n++) { 58 | if (n == 0) { 59 | loss += binaryLogistic(target, true); 60 | } else { 61 | loss += binaryLogistic(getNegative(target), false); 62 | } 63 | } 64 | 65 | return loss; 66 | } 67 | 68 | real Model::hierarchicalSoftmax(int32_t target) { 69 | real loss = 0.0; 70 | grad_.zero(); 71 | const std::vector& binaryCode = codes[target]; 72 | const std::vector& pathToRoot = paths[target]; 73 | for (int32_t i = 0; i < pathToRoot.size(); i++) { 74 | loss += binaryLogistic(pathToRoot[i], binaryCode[i]); 75 | } 76 | return loss; 77 | } 78 | 79 | real Model::softmax(int32_t target) { 80 | grad_.zero(); 81 | /** 82 | * multi context model 83 | * hidden_ = 1/C * wi_ * (w1 + ... 
+ wC) 84 | * output_ : u 85 | * u = wo_j * hidden_i 86 | * * (matrix or vector).mul: 내부에서 원소 초기화 함. 87 | **/ 88 | output_.mul(wo_, hidden_); 89 | real max = output_[0], z = 0.0; 90 | for (int32_t i = 0; i < osz_; i++) { 91 | max = std::max(output_[i], max); 92 | } 93 | for (int32_t i = 0; i < osz_; i++) { 94 | output_[i] = exp(output_[i] - max); 95 | z += output_[i]; 96 | } 97 | for (int32_t i = 0; i < osz_; i++) { 98 | // label = iverson bracket output. 99 | real label = (i == target) ? 1.0 : 0.0; 100 | 101 | // log-linear classification obtain to multinomial distribution. 102 | output_[i] /= z; 103 | 104 | /** 105 | * label = output_[i]: update equation for output weight 106 | * 실제 라벨과 차이가 있으면 alpha는 음의 방향으로 크다. 107 | * 실제 라벨과 차이가 있다면 alpha에 의해서 grad_은 자연스럽게 local minima를 향해 수렴된다. 108 | * wo_도 마찬가지. 109 | **/ 110 | real alpha = lr_ * (label - output_[i]); 111 | grad_.addRow(wo_, i, alpha); 112 | // output matrix 113 | wo_.addRow(hidden_, i, alpha); 114 | } 115 | return -utils::log(output_[target]); 116 | } 117 | 118 | void Model::computeHidden(const std::vector& input) { 119 | hidden_.zero(); 120 | for (auto it = input.cbegin(); it != input.cend(); ++it) { 121 | hidden_.addRow(wi_, *it); 122 | } 123 | hidden_.mul(1.0 / input.size()); 124 | } 125 | 126 | bool Model::comparePairs(const std::pair &l, 127 | const std::pair &r) { 128 | return l.first > r.first; 129 | } 130 | 131 | void Model::predict(const std::vector& input, int32_t k, 132 | std::vector>& heap) { 133 | assert(k > 0); 134 | heap.reserve(k + 1); 135 | computeHidden(input); 136 | if (args.loss == loss_name::hs) { 137 | dfs(k, 2 * osz_ - 2, 0.0, heap); 138 | } else { 139 | output_.mul(wo_, hidden_); 140 | findKBest(k, heap); 141 | } 142 | std::sort_heap(heap.begin(), heap.end(), comparePairs); 143 | } 144 | 145 | void Model::findKBest(int32_t k, std::vector>& heap) { 146 | for (int32_t i = 0; i < osz_; i++) { 147 | if (heap.size() == k && output_[i] < heap.front().first) { 148 | continue; 149 | } 150 | 
heap.push_back(std::make_pair(output_[i], i)); 151 | std::push_heap(heap.begin(), heap.end(), comparePairs); 152 | if (heap.size() > k) { 153 | std::pop_heap(heap.begin(), heap.end(), comparePairs); 154 | heap.pop_back(); 155 | } 156 | } 157 | } 158 | 159 | void Model::dfs(int32_t k, int32_t node, real score, 160 | std::vector>& heap) { 161 | if (heap.size() == k && score < heap.front().first) { 162 | return; 163 | } 164 | 165 | if (tree[node].left == -1 && tree[node].right == -1) { 166 | heap.push_back(std::make_pair(score, node)); 167 | std::push_heap(heap.begin(), heap.end(), comparePairs); 168 | if (heap.size() > k) { 169 | std::pop_heap(heap.begin(), heap.end(), comparePairs); 170 | heap.pop_back(); 171 | } 172 | return; 173 | } 174 | 175 | real f = utils::sigmoid(wo_.dotRow(hidden_, node - osz_)); 176 | dfs(k, tree[node].left, score + utils::log(1.0 - f), heap); 177 | dfs(k, tree[node].right, score + utils::log(f), heap); 178 | } 179 | 180 | real Model::update(const std::vector& input, int32_t target) { 181 | assert(target >= 0); 182 | assert(target < osz_); 183 | if (input.size() == 0) return 0.0; 184 | 185 | hidden_.zero(); 186 | for (auto it = input.cbegin(); it != input.cend(); ++it) { 187 | hidden_.addRow(wi_, *it); 188 | } 189 | hidden_.mul(1.0 / input.size()); 190 | 191 | real loss; 192 | if (args.loss == loss_name::ns) { 193 | loss = negativeSampling(target); 194 | } else if (args.loss == loss_name::hs) { 195 | loss = hierarchicalSoftmax(target); 196 | } else { 197 | loss = softmax(target); 198 | } 199 | 200 | if (args.model == model_name::sup) { 201 | grad_.mul(1.0 / input.size()); 202 | } 203 | 204 | if (args.model == model_name::pvdm || args.model == model_name::pvdbow) { 205 | wi_.addRow(grad_, input.back(), 1.0); 206 | } else { 207 | for (auto it = input.cbegin(); it != input.cend(); ++it) { 208 | wi_.addRow(grad_, *it, 1.0); 209 | } 210 | } 211 | 212 | return loss; 213 | } 214 | 215 | void Model::setTargetCounts(const std::vector& counts) { 216 
| assert(counts.size() == osz_); 217 | if (args.loss == loss_name::ns) { 218 | initTableNegatives(counts); 219 | } 220 | if (args.loss == loss_name::hs) { 221 | buildTree(counts); 222 | } 223 | } 224 | 225 | void Model::initTableNegatives(const std::vector& counts) { 226 | real z = 0.0; 227 | for (size_t i = 0; i < counts.size(); i++) { 228 | z += pow(counts[i], 0.5); 229 | } 230 | for (size_t i = 0; i < counts.size(); i++) { 231 | real c = pow(counts[i], 0.5); 232 | for (size_t j = 0; j < c * NEGATIVE_TABLE_SIZE / z; j++) { 233 | negatives.push_back(i); 234 | } 235 | } 236 | std::shuffle(negatives.begin(), negatives.end(), rng); 237 | } 238 | 239 | int32_t Model::getNegative(int32_t target) { 240 | int32_t negative; 241 | do { 242 | negative = negatives[negpos]; 243 | negpos = (negpos + 1) % negatives.size(); 244 | } while (target == negative); 245 | return negative; 246 | } 247 | 248 | void Model::buildTree(const std::vector& counts) { 249 | tree.resize(2 * osz_ - 1); 250 | for (int32_t i = 0; i < 2 * osz_ - 1; i++) { 251 | tree[i].parent = -1; 252 | tree[i].left = -1; 253 | tree[i].right = -1; 254 | tree[i].count = 1e15; 255 | tree[i].binary = false; 256 | } 257 | for (int32_t i = 0; i < osz_; i++) { 258 | tree[i].count = counts[i]; 259 | } 260 | // word들은 leaf로 취급 261 | int32_t leaf = osz_ - 1; 262 | // 노드는 osz_만큼 생성하고 osz_ 이상부터 인덱스로 사용 263 | int32_t node = osz_; 264 | // leaf가 아닌 노드들 초기화 265 | for (int32_t i = osz_; i < 2 * osz_ - 1; i++) { 266 | int32_t mini[2]; 267 | for (int32_t j = 0; j < 2; j++) { 268 | if (leaf >= 0 && tree[leaf].count < tree[node].count) { 269 | mini[j] = leaf--; 270 | } else { 271 | mini[j] = node++; 272 | } 273 | } 274 | tree[i].left = mini[0]; 275 | tree[i].right = mini[1]; 276 | tree[i].count = tree[mini[0]].count + tree[mini[1]].count; 277 | tree[mini[0]].parent = i; 278 | tree[mini[1]].parent = i; 279 | tree[mini[1]].binary = true; 280 | } 281 | for (int32_t i = 0; i < osz_; i++) { 282 | std::vector path; 283 | std::vector code; 
284 | int32_t j = i; 285 | while (tree[j].parent != -1) { 286 | path.push_back(tree[j].parent - osz_); 287 | code.push_back(tree[j].binary); 288 | j = tree[j].parent; 289 | } 290 | paths.push_back(path); 291 | codes.push_back(code); 292 | } 293 | } 294 | -------------------------------------------------------------------------------- /src/model.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. An additional grant 7 | * of patent rights can be found in the PATENTS file in the same directory. 8 | */ 9 | 10 | #ifndef FASTTEXT_MODEL_H 11 | #define FASTTEXT_MODEL_H 12 | 13 | #include 14 | #include 15 | #include 16 | 17 | #include "matrix.h" 18 | #include "vector.h" 19 | #include "real.h" 20 | 21 | struct Node { 22 | int32_t parent; 23 | int32_t left; 24 | int32_t right; 25 | int64_t count; 26 | bool binary; 27 | }; 28 | 29 | class Model { 30 | private: 31 | Matrix& wi_; 32 | Matrix& wo_; 33 | Vector hidden_; 34 | Vector output_; 35 | Vector grad_; 36 | int32_t hsz_; 37 | int32_t isz_; 38 | int32_t osz_; 39 | 40 | static real lr_; 41 | 42 | static bool comparePairs(const std::pair&, 43 | const std::pair&); 44 | 45 | std::vector negatives; 46 | size_t negpos; 47 | std::vector< std::vector > paths; 48 | std::vector< std::vector > codes; 49 | std::vector tree; 50 | 51 | static const int32_t NEGATIVE_TABLE_SIZE = 10000000; 52 | static constexpr real MIN_LR = 0.000001; 53 | 54 | public: 55 | Model(Matrix&, Matrix&, int32_t, real, int32_t); 56 | 57 | void setLearningRate(real); 58 | real getLearningRate(); 59 | 60 | real binaryLogistic(int32_t, bool); 61 | real negativeSampling(int32_t); 62 | real hierarchicalSoftmax(int32_t); 63 | real softmax(int32_t); 64 | 65 | void predict(const std::vector&, int32_t, 66 | 
std::vector>&); 67 | void dfs(int32_t, int32_t, real, std::vector>&); 68 | void findKBest(int32_t, std::vector>&); 69 | real update(const std::vector&, int32_t); 70 | void computeHidden(const std::vector&); 71 | 72 | void setTargetCounts(const std::vector&); 73 | void initTableNegatives(const std::vector&); 74 | int32_t getNegative(int32_t target); 75 | void buildTree(const std::vector&); 76 | 77 | std::minstd_rand rng; 78 | }; 79 | 80 | #endif 81 | -------------------------------------------------------------------------------- /src/real.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. An additional grant 7 | * of patent rights can be found in the PATENTS file in the same directory. 8 | */ 9 | 10 | #ifndef FASTTEXT_REAL_H 11 | #define FASTTEXT_REAL_H 12 | 13 | typedef float real; 14 | 15 | #endif 16 | -------------------------------------------------------------------------------- /src/utils.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. An additional grant 7 | * of patent rights can be found in the PATENTS file in the same directory. 
8 | */ 9 | 10 | #include "utils.h" 11 | 12 | #include 13 | #include 14 | 15 | namespace utils { 16 | real* t_sigmoid; 17 | real* t_log; 18 | 19 | real log(real x) { 20 | if (x > 1.0) { 21 | return 0.0; 22 | } 23 | int i = int(x * LOG_TABLE_SIZE); 24 | return t_log[i]; 25 | } 26 | 27 | real sigmoid(real x) { 28 | if (x < -MAX_SIGMOID) { 29 | return 0.0; 30 | } else if (x > MAX_SIGMOID) { 31 | return 1.0; 32 | } else { 33 | int i = int((x + MAX_SIGMOID) * SIGMOID_TABLE_SIZE / MAX_SIGMOID / 2); 34 | return t_sigmoid[i]; 35 | } 36 | } 37 | 38 | void initTables() { 39 | initSigmoid(); 40 | initLog(); 41 | } 42 | 43 | void initSigmoid() { 44 | t_sigmoid = new real[SIGMOID_TABLE_SIZE + 1]; 45 | for (int i = 0; i < SIGMOID_TABLE_SIZE + 1; i++) { 46 | real x = real(i * 2 * MAX_SIGMOID) / SIGMOID_TABLE_SIZE - MAX_SIGMOID; 47 | t_sigmoid[i] = 1.0 / (1.0 + std::exp(-x)); 48 | } 49 | } 50 | 51 | void initLog() { 52 | t_log = new real[LOG_TABLE_SIZE + 1]; 53 | for (int i = 0; i < LOG_TABLE_SIZE + 1; i++) { 54 | real x = (real(i) + 1e-5) / LOG_TABLE_SIZE; 55 | t_log[i] = std::log(x); 56 | } 57 | } 58 | 59 | void freeTables() { 60 | delete[] t_sigmoid; 61 | delete[] t_log; 62 | t_sigmoid = nullptr; 63 | t_log = nullptr; 64 | } 65 | 66 | int64_t size(std::ifstream& ifs) { 67 | ifs.seekg(std::streamoff(0), std::ios::end); 68 | return ifs.tellg(); 69 | } 70 | 71 | void seek(std::ifstream& ifs, int64_t pos) { 72 | char c; 73 | ifs.clear(); 74 | ifs.seekg(std::streampos(pos)); 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /src/utils.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 
An additional grant 7 | * of patent rights can be found in the PATENTS file in the same directory. 8 | */ 9 | 10 | #ifndef FASTTEXT_UTILS_H 11 | #define FASTTEXT_UTILS_H 12 | 13 | #include 14 | 15 | #include "real.h" 16 | 17 | #define SIGMOID_TABLE_SIZE 512 18 | #define MAX_SIGMOID 8 19 | #define LOG_TABLE_SIZE 512 20 | 21 | namespace utils { 22 | 23 | real log(real); 24 | real sigmoid(real); 25 | 26 | void initTables(); 27 | void initSigmoid(); 28 | void initLog(); 29 | void freeTables(); 30 | 31 | int64_t size(std::ifstream&); 32 | void seek(std::ifstream&, int64_t); 33 | } 34 | 35 | #endif 36 | -------------------------------------------------------------------------------- /src/vector.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. An additional grant 7 | * of patent rights can be found in the PATENTS file in the same directory. 
8 | */ 9 | 10 | #include "vector.h" 11 | 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | #include "matrix.h" 18 | #include "utils.h" 19 | 20 | Vector::Vector(int64_t m) { 21 | m_ = m; 22 | data_ = new real[m]; 23 | } 24 | 25 | Vector::~Vector() { 26 | delete[] data_; 27 | } 28 | 29 | void Vector::zero() { 30 | for (int64_t i = 0; i < m_; i++) { 31 | data_[i] = 0.0; 32 | } 33 | } 34 | 35 | void Vector::mul(real a) { 36 | for (int64_t i = 0; i < m_; i++) { 37 | data_[i] *= a; 38 | } 39 | } 40 | 41 | void Vector::addRow(const Matrix& A, int64_t i) { 42 | assert(i >= 0); 43 | assert(i < A.m_); 44 | assert(m_ == A.n_); 45 | for (int64_t j = 0; j < A.n_; j++) { 46 | data_[j] += A.data_[i * A.n_ + j]; 47 | } 48 | } 49 | 50 | void Vector::addRow(const Matrix& A, int64_t i, real a) { 51 | assert(i >= 0); 52 | assert(i < A.m_); 53 | assert(m_ == A.n_); 54 | for (int64_t j = 0; j < A.n_; j++) { 55 | data_[j] += a * A.data_[i * A.n_ + j]; 56 | } 57 | } 58 | 59 | void Vector::mul(const Matrix& A, const Vector& vec) { 60 | assert(A.m_ == m_); 61 | assert(A.n_ == vec.m_); 62 | for (int64_t i = 0; i < m_; i++) { 63 | data_[i] = 0.0; 64 | for (int64_t j = 0; j < A.n_; j++) { 65 | data_[i] += A.data_[i * A.n_ + j] * vec.data_[j]; 66 | } 67 | } 68 | } 69 | 70 | int64_t Vector::argmax() { 71 | real max = data_[0]; 72 | int64_t argmax = 0; 73 | for (int64_t i = 1; i < m_; i++) { 74 | if (data_[i] > max) { 75 | max = data_[i]; 76 | argmax = i; 77 | } 78 | } 79 | return argmax; 80 | } 81 | 82 | real& Vector::operator[](int64_t i) { 83 | return data_[i]; 84 | } 85 | 86 | const real& Vector::operator[](int64_t i) const { 87 | return data_[i]; 88 | } 89 | 90 | std::ostream& operator<<(std::ostream& os, const Vector& v) 91 | { 92 | os << std::setprecision(5); 93 | os << v.data_[0]; 94 | for (int64_t j = 1; j < v.m_; j++) { 95 | os << ',' << v.data_[j]; 96 | } 97 | return os; 98 | } 99 | -------------------------------------------------------------------------------- /src/vector.h: 
-------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. An additional grant 7 | * of patent rights can be found in the PATENTS file in the same directory. 8 | */ 9 | 10 | #ifndef FASTTEXT_VECTOR_H 11 | #define FASTTEXT_VECTOR_H 12 | 13 | #include 14 | #include 15 | 16 | #include "real.h" 17 | 18 | class Matrix; 19 | 20 | class Vector { 21 | 22 | public: 23 | int64_t m_; 24 | real* data_; 25 | 26 | explicit Vector(int64_t); 27 | ~Vector(); 28 | 29 | real& operator[](int64_t); 30 | const real& operator[](int64_t) const; 31 | 32 | void zero(); 33 | void mul(real); 34 | void addRow(const Matrix&, int64_t); 35 | void addRow(const Matrix&, int64_t, real); 36 | void mul(const Matrix&, const Vector&); 37 | int64_t argmax(); 38 | }; 39 | 40 | std::ostream& operator<<(std::ostream&, const Vector&); 41 | 42 | #endif 43 | -------------------------------------------------------------------------------- /wikifil.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/perl 2 | 3 | # Program to filter Wikipedia XML dumps to "clean" text consisting only of lowercase 4 | # letters (a-z, converted from A-Z), and spaces (never consecutive). 5 | # All other characters are converted to spaces. Only text which normally appears 6 | # in the web browser is displayed. Tables are removed. Image captions are 7 | # preserved. Links are converted to normal text. Digits are spelled out. 8 | 9 | # Written by Matt Mahoney, June 10, 2006. This program is released to the public domain. 10 | 11 | $/=">"; # input record separator 12 | while (<>) { 13 | if (/ ... 
14 | if (/#redirect/i) {$text=0;} # remove #REDIRECT 15 | if ($text) { 16 | 17 | # Remove any text not normally visible 18 | if (/<\/text>/) {$text=0;} 19 | s/<.*>//; # remove xml tags 20 | s/&/&/g; # decode URL encoded chars 21 | s/<//g; 23 | s///g; # remove references ... 24 | s/<[^>]*>//g; # remove xhtml tags 25 | s/\[http:[^] ]*/[/g; # remove normal url, preserve visible text 26 | s/\|thumb//ig; # remove images links, preserve caption 27 | s/\|left//ig; 28 | s/\|right//ig; 29 | s/\|\d+px//ig; 30 | s/\[\[image:[^\[\]]*\|//ig; 31 | s/\[\[category:([^|\]]*)[^]]*\]\]/[[$1]]/ig; # show categories without markup 32 | s/\[\[[a-z\-]*:[^\]]*\]\]//g; # remove links to other languages 33 | s/\[\[[^\|\]]*\|/[[/g; # remove wiki url, preserve visible text 34 | s/{{[^}]*}}//g; # remove {{icons}} and {tables} 35 | s/{[^}]*}//g; 36 | s/\[//g; # remove [ and ] 37 | s/\]//g; 38 | s/&[^;]*;/ /g; # remove URL encoded chars 39 | 40 | # convert to lowercase letters and spaces, spell digits 41 | $_=" $_ "; 42 | tr/A-Z/a-z/; 43 | s/0/ zero /g; 44 | s/1/ one /g; 45 | s/2/ two /g; 46 | s/3/ three /g; 47 | s/4/ four /g; 48 | s/5/ five /g; 49 | s/6/ six /g; 50 | s/7/ seven /g; 51 | s/8/ eight /g; 52 | s/9/ nine /g; 53 | tr/a-z/ /cs; 54 | chop; 55 | print $_; 56 | } 57 | } 58 | --------------------------------------------------------------------------------