├── .gitignore ├── LICENSE ├── PATENTS ├── README.md ├── classification-example.sh ├── classification-results.sh ├── eval.py ├── pom.xml ├── src ├── main │ └── java │ │ └── fasttext │ │ ├── Args.java │ │ ├── Dictionary.java │ │ ├── FastText.java │ │ ├── IOUtil.java │ │ ├── Main.java │ │ ├── Matrix.java │ │ ├── Model.java │ │ ├── Pair.java │ │ ├── Utils.java │ │ ├── Vector.java │ │ └── io │ │ ├── BufferedLineReader.java │ │ ├── LineReader.java │ │ └── MappedByteBufferLineReader.java └── test │ └── java │ └── fasttext │ └── TestDictionary.java ├── wikifil.pl └── word-vector-example.sh /.gitignore: -------------------------------------------------------------------------------- 1 | .*.swp 2 | *.o 3 | *.bin 4 | *.vec 5 | data 6 | fasttext 7 | result 8 | target -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD License 2 | 3 | For fastText software 4 | 5 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved. 6 | 7 | Redistribution and use in source and binary forms, with or without modification, 8 | are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | * Neither the name Facebook nor the names of its contributors may be used to 18 | endorse or promote products derived from this software without specific 19 | prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | -------------------------------------------------------------------------------- /PATENTS: -------------------------------------------------------------------------------- 1 | Additional Grant of Patent Rights Version 2 2 | 3 | "Software" means the fastText software distributed by Facebook, Inc. 4 | 5 | Facebook, Inc. ("Facebook") hereby grants to each recipient of the Software 6 | ("you") a perpetual, worldwide, royalty-free, non-exclusive, irrevocable 7 | (subject to the termination provision below) license under any Necessary 8 | Claims, to make, have made, use, sell, offer to sell, import, and otherwise 9 | transfer the Software. For avoidance of doubt, no license is granted under 10 | Facebook’s rights in any patent claims that are infringed by (i) modifications 11 | to the Software made by you or any third party or (ii) the Software in 12 | combination with any software or other technology. 
13 | 14 | The license granted hereunder will terminate, automatically and without notice, 15 | if you (or any of your subsidiaries, corporate affiliates or agents) initiate 16 | directly or indirectly, or take a direct financial interest in, any Patent 17 | Assertion: (i) against Facebook or any of its subsidiaries or corporate 18 | affiliates, (ii) against any party if such Patent Assertion arises in whole or 19 | in part from any software, technology, product or service of Facebook or any of 20 | its subsidiaries or corporate affiliates, or (iii) against any party relating 21 | to the Software. Notwithstanding the foregoing, if Facebook or any of its 22 | subsidiaries or corporate affiliates files a lawsuit alleging patent 23 | infringement against you in the first instance, and you respond by filing a 24 | patent infringement counterclaim in that lawsuit against that party that is 25 | unrelated to the Software, the license granted hereunder will not terminate 26 | under section (i) of this paragraph due to such counterclaim. 27 | 28 | A "Necessary Claim" is a claim of a patent owned by Facebook that is 29 | necessarily infringed by the Software standing alone. 30 | 31 | A "Patent Assertion" is any lawsuit or other action alleging direct, indirect, 32 | or contributory infringement or inducement to infringe any patent, including a 33 | cross-claim or counterclaim. 34 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # fasttext_java 2 | Java port of c++ version of facebook fasttext [UPDATED 2017-01-29] 3 | 4 | Support Load/Save facebook fasttext binary model file 5 | 6 | ## Building fastText_java 7 | Requirements: Maven, Java 1.6 or onwards 8 | 9 | In order to build `fastText_java`, use the following: 10 | 11 | ``` 12 | $ git clone https://github.com/ivanhk/fastText_java.git 13 | $ cd fastText_java 14 | $ mvn package 15 | ``` 16 | 17 | ## Resources 18 | 19 | You can find more information and resources at https://github.com/facebookresearch/fastText 20 | 21 | ## License 22 | 23 | fastText is BSD-licensed. Facebook also provide an additional patent grant. 24 | -------------------------------------------------------------------------------- /classification-example.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # 3 | # Copyright (c) 2016-present, Facebook, Inc. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE file in the root directory of this source tree. An additional grant 8 | # of patent rights can be found in the PATENTS file in the same directory. 9 | # 10 | 11 | myshuf() { 12 | perl -MList::Util=shuffle -e 'print shuffle(<>);' "$@"; 13 | } 14 | 15 | normalize_text() { 16 | tr '[:upper:]' '[:lower:]' | sed -e 's/^/__label__/g' | \ 17 | sed -e "s/'/ ' /g" -e 's/"//g' -e 's/\./ \. /g' -e 's/
/ /g' \ 18 | -e 's/,/ , /g' -e 's/(/ ( /g' -e 's/)/ ) /g' -e 's/\!/ \! /g' \ 19 | -e 's/\?/ \? /g' -e 's/\;/ /g' -e 's/\:/ /g' | tr -s " " | myshuf 20 | } 21 | 22 | RESULTDIR=result 23 | DATADIR=data 24 | 25 | mkdir -p "${RESULTDIR}" 26 | mkdir -p "${DATADIR}" 27 | 28 | if [ ! -f "${DATADIR}/dbpedia.train" ] 29 | then 30 | wget -c "https://github.com/le-scientifique/torchDatasets/raw/master/dbpedia_csv.tar.gz" -O "${DATADIR}/dbpedia_csv.tar.gz" 31 | tar -xzvf "${DATADIR}/dbpedia_csv.tar.gz" -C "${DATADIR}" 32 | cat "${DATADIR}/dbpedia_csv/train.csv" | normalize_text > "${DATADIR}/dbpedia.train" 33 | cat "${DATADIR}/dbpedia_csv/test.csv" | normalize_text > "${DATADIR}/dbpedia.test" 34 | fi 35 | 36 | mvn package 37 | 38 | JAR=./target/fasttext-0.0.1-SNAPSHOT-jar-with-dependencies.jar 39 | 40 | java -jar ${JAR} supervised -input "${DATADIR}/dbpedia.train" -output "${RESULTDIR}/dbpedia" -dim 10 -lr 0.1 -wordNgrams 2 -minCount 1 -bucket 10000000 -epoch 5 -thread 4 41 | 42 | java -jar ${JAR} test "${RESULTDIR}/dbpedia.bin" "${DATADIR}/dbpedia.test" 43 | 44 | java -jar ${JAR} predict "${RESULTDIR}/dbpedia.bin" "${DATADIR}/dbpedia.test" > "${RESULTDIR}/dbpedia.test.predict" 45 | -------------------------------------------------------------------------------- /classification-results.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # 3 | # Copyright (c) 2016-present, Facebook, Inc. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE file in the root directory of this source tree. An additional grant 8 | # of patent rights can be found in the PATENTS file in the same directory. 9 | # 10 | 11 | # This script produces the results from Table 1 in the following paper: 12 | # Bag of Tricks for Efficient Text Classification, arXiv 1607.01759, 2016 13 | 14 | myshuf() { 15 | perl -MList::Util=shuffle -e 'print shuffle(<>);' "$@"; 16 | } 17 | 18 | normalize_text() { 19 | tr '[:upper:]' '[:lower:]' | sed -e 's/^/__label__/g' | \ 20 | sed -e "s/'/ ' /g" -e 's/"//g' -e 's/\./ \. /g' -e 's/
/ /g' \ 21 | -e 's/,/ , /g' -e 's/(/ ( /g' -e 's/)/ ) /g' -e 's/\!/ \! /g' \ 22 | -e 's/\?/ \? /g' -e 's/\;/ /g' -e 's/\:/ /g' | tr -s " " | myshuf 23 | } 24 | 25 | DATASET=( 26 | ag_news 27 | sogou_news 28 | dbpedia 29 | yelp_review_polarity 30 | yelp_review_full 31 | yahoo_answers 32 | amazon_review_full 33 | amazon_review_polarity 34 | ) 35 | 36 | ID=( 37 | 0Bz8a_Dbh9QhbUDNpeUdjb0wxRms # ag_news 38 | 0Bz8a_Dbh9QhbUkVqNEszd0pHaFE # sogou_news 39 | 0Bz8a_Dbh9QhbQ2Vic1kxMmZZQ1k # dbpedia 40 | 0Bz8a_Dbh9QhbNUpYQ2N3SGlFaDg # yelp_review_polarity 41 | 0Bz8a_Dbh9QhbZlU4dXhHTFhZQU0 # yelp_review_full 42 | 0Bz8a_Dbh9Qhbd2JNdDBsQUdocVU # yahoo_answers 43 | 0Bz8a_Dbh9QhbZVhsUnRWRDhETzA # amazon_review_full 44 | 0Bz8a_Dbh9QhbaW12WVVZS2drcnM # amazon_review_polarity 45 | ) 46 | 47 | # These learning rates were chosen by validation on a subset of the training set. 48 | LR=( 0.25 0.5 0.5 0.1 0.1 0.1 0.05 0.05 ) 49 | 50 | RESULTDIR=result 51 | DATADIR=data 52 | 53 | mkdir -p "${RESULTDIR}" 54 | mkdir -p "${DATADIR}" 55 | 56 | for i in {0..7} 57 | do 58 | echo "Downloading dataset ${DATASET[i]}" 59 | if [ ! -f "${DATADIR}/${DATASET[i]}.train" ] 60 | then 61 | wget -c "https://googledrive.com/host/${ID[i]}" -O "${DATADIR}/${DATASET[i]}_csv.tar.gz" 62 | tar -xzvf "${DATADIR}/${DATASET[i]}_csv.tar.gz" -C "${DATADIR}" 63 | cat "${DATADIR}/${DATASET[i]}_csv/train.csv" | normalize_text > "${DATADIR}/${DATASET[i]}.train" 64 | cat "${DATADIR}/${DATASET[i]}_csv/test.csv" | normalize_text > "${DATADIR}/${DATASET[i]}.test" 65 | fi 66 | done 67 | 68 | mvn package 69 | 70 | JAR=./target/fasttext-0.0.1-SNAPSHOT-jar-with-dependencies.jar 71 | 72 | for i in {0..7} 73 | do 74 | echo "Working on dataset ${DATASET[i]}" 75 | java -jar ${JAR} supervised -input "${DATADIR}/${DATASET[i]}.train" \ 76 | -output "${RESULTDIR}/${DATASET[i]}" -dim 10 -lr "${LR[i]}" -wordNgrams 2 \ 77 | -minCount 1 -bucket 10000000 -epoch 5 -thread 4 > /dev/null 78 | java -jar ${JAR} test "${RESULTDIR}/${DATASET[i]}.bin" \ 79 | "${DATADIR}/${DATASET[i]}.test" 80 | done 81 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # 4 | # Copyright (c) 2016-present, Facebook, Inc. 5 | # All rights reserved. 6 | # 7 | # This source code is licensed under the BSD-style license found in the 8 | # LICENSE file in the root directory of this source tree. An additional grant 9 | # of patent rights can be found in the PATENTS file in the same directory. 
10 | # 11 | 12 | from __future__ import absolute_import 13 | from __future__ import division 14 | from __future__ import print_function 15 | from __future__ import unicode_literals 16 | import numpy as np 17 | from scipy import stats 18 | import sys 19 | import os 20 | import math 21 | import argparse 22 | 23 | def compat_splitting(line): 24 | return line.decode('utf8').split() 25 | 26 | def similarity(v1, v2): 27 | n1 = np.linalg.norm(v1) 28 | n2 = np.linalg.norm(v2) 29 | return np.dot(v1, v2) / n1 / n2 30 | 31 | parser = argparse.ArgumentParser(description='Process some integers.') 32 | parser.add_argument('--model', '-m', dest='modelPath', action='store', required=True, help='path to model') 33 | parser.add_argument('--data', '-d', dest='dataPath', action='store', required=True, help='path to data') 34 | args = parser.parse_args() 35 | 36 | vectors = {} 37 | fin = open(args.modelPath, 'rb') 38 | for i, line in enumerate(fin): 39 | try: 40 | tab = compat_splitting(line) 41 | vec = np.array(tab[1:], dtype=float) 42 | word = tab[0] 43 | if not word in vectors: 44 | vectors[word] = vec 45 | except ValueError: 46 | continue 47 | except UnicodeDecodeError: 48 | continue 49 | fin.close() 50 | 51 | mysim = [] 52 | gold = [] 53 | drop = 0.0 54 | nwords = 0.0 55 | 56 | fin = open(args.dataPath, 'rb') 57 | for line in fin: 58 | tline = compat_splitting(line) 59 | word1 = tline[0].lower() 60 | word2 = tline[1].lower() 61 | nwords = nwords + 1.0 62 | 63 | if (word1 in vectors) and (word2 in vectors): 64 | v1 = vectors[word1] 65 | v2 = vectors[word2] 66 | d = similarity(v1, v2) 67 | mysim.append(d) 68 | gold.append(float(tline[2])) 69 | else: 70 | drop = drop + 1.0 71 | fin.close() 72 | 73 | corr = stats.spearmanr(mysim, gold) 74 | dataset = os.path.basename(args.dataPath) 75 | print("{0:20s}: {1:2.0f} (OOV: {2:2.0f}%)" 76 | .format(dataset, corr[0] * 100, math.ceil(drop / nwords * 100.0))) 77 | -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 3 | 4.0.0 4 | 5 | fasttext 6 | fasttext 7 | 0.0.1-SNAPSHOT 8 | jar 9 | 10 | fasttext 11 | http://maven.apache.org 12 | 13 | 14 | UTF-8 15 | 1.6 16 | 1.6 17 | 18 | 19 | 20 | 21 | junit 22 | junit 23 | 4.12 24 | test 25 | 26 | 27 | 28 | 29 | 30 | 31 | maven-compiler-plugin 32 | 33 | UTF-8 34 | 35 | lib 36 | 37 | 1.6 38 | 1.6 39 | 1.6 40 | 41 | 3.3 42 | 43 | 44 | org.apache.maven.plugins 45 | maven-resources-plugin 46 | 47 | UTF-8 48 | 49 | 3.0.1 50 | 51 | 52 | org.apache.maven.plugins 53 | maven-source-plugin 54 | 2.4 55 | 56 | 57 | attach-sources 58 | 59 | jar 60 | 61 | 62 | 63 | 64 | 65 | 66 | maven-assembly-plugin 67 | 2.6 68 | 69 | 70 | jar-with-dependencies 71 | 72 | 73 | 74 | fasttext.Main 75 | 76 | 77 | 78 | 79 | 80 | simple-command 81 | package 82 | 83 | attached 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | -------------------------------------------------------------------------------- /src/main/java/fasttext/Args.java: -------------------------------------------------------------------------------- 1 | package fasttext; 2 | 3 | import java.io.IOException; 4 | import java.io.InputStream; 5 | import java.io.OutputStream; 6 | 7 | public class Args { 8 | 9 | public enum model_name { 10 | cbow(1), sg(2), sup(3); 11 | 12 | private int value; 13 | 14 | private model_name(int value) { 15 | this.value = value; 16 | } 17 | 18 | public int getValue() { 19 | return this.value; 20 | } 21 | 22 | public static model_name fromValue(int value) 
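// note: the binary format stores these ids 1-based (cbow=1, sg=2, sup=3; likewise hs=1,
// ns=2, softmax=3), so fromValue subtracts 1 to recover the 0-based enum ordinal and
// rejects anything out of range.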
throws IllegalArgumentException { 23 | try { 24 | value -= 1; 25 | return model_name.values()[value]; 26 | } catch (ArrayIndexOutOfBoundsException e) { 27 | throw new IllegalArgumentException("Unknown model_name enum value :" + value); 28 | } 29 | } 30 | } 31 | 32 | public enum loss_name { 33 | hs(1), ns(2), softmax(3); 34 | private int value; 35 | 36 | private loss_name(int value) { 37 | this.value = value; 38 | } 39 | 40 | public int getValue() { 41 | return this.value; 42 | } 43 | 44 | public static loss_name fromValue(int value) throws IllegalArgumentException { 45 | try { 46 | value -= 1; 47 | return loss_name.values()[value]; 48 | } catch (ArrayIndexOutOfBoundsException e) { 49 | throw new IllegalArgumentException("Unknown loss_name enum value :" + value); 50 | } 51 | } 52 | } 53 | 54 | public String input; 55 | public String output; 56 | public String test; 57 | public double lr = 0.05; 58 | public int lrUpdateRate = 100; 59 | public int dim = 100; 60 | public int ws = 5; 61 | public int epoch = 5; 62 | public int minCount = 5; 63 | public int minCountLabel = 0; 64 | public int neg = 5; 65 | public int wordNgrams = 1; 66 | public loss_name loss = loss_name.ns; 67 | public model_name model = model_name.sg; 68 | public int bucket = 2000000; 69 | public int minn = 3; 70 | public int maxn = 6; 71 | public int thread = 1; 72 | public double t = 1e-4; 73 | public String label = "__label__"; 74 | public int verbose = 2; 75 | public String pretrainedVectors = ""; 76 | 77 | public void printHelp() { 78 | System.out.println("\n" + "The following arguments are mandatory:\n" 79 | + " -input training file path\n" 80 | + " -output output file path\n\n" 81 | + "The following arguments are optional:\n" 82 | + " -lr learning rate [" + lr + "]\n" 83 | + " -lrUpdateRate change the rate of updates for the learning rate [" + lrUpdateRate + "]\n" 84 | + " -dim size of word vectors [" + dim + "]\n" 85 | + " -ws size of the context window [" + ws + "]\n" 86 | + " -epoch number of epochs [" + epoch + "]\n" 87 | + " -minCount minimal number of word occurences [" + minCount + "]\n" 88 | + " -minCountLabel minimal number of label occurences [" + minCountLabel + "]\n" 89 | + " -neg number of negatives sampled [" + neg + "]\n" 90 | + " -wordNgrams max length of word ngram [" + wordNgrams + "]\n" 91 | + " -loss loss function {ns, hs, softmax} [ns]\n" 92 | + " -bucket number of buckets [" + bucket + "]\n" 93 | + " -minn min length of char ngram [" + minn + "]\n" 94 | + " -maxn max length of char ngram [" + maxn + "]\n" 95 | + " -thread number of threads [" + thread + "]\n" 96 | + " -t sampling threshold [" + t + "]\n" 97 | + " -label labels prefix [" + label + "]\n" 98 | + " -verbose verbosity level [" + verbose + "]\n" 99 | + " -pretrainedVectors pretrained word vectors for supervised learning []"); 100 | } 101 | 102 | public void save(OutputStream ofs) throws IOException { 103 | IOUtil ioutil = new IOUtil(); 104 | ofs.write(ioutil.intToByteArray(dim)); 105 | ofs.write(ioutil.intToByteArray(ws)); 106 | ofs.write(ioutil.intToByteArray(epoch)); 107 | ofs.write(ioutil.intToByteArray(minCount)); 108 | ofs.write(ioutil.intToByteArray(neg)); 109 | ofs.write(ioutil.intToByteArray(wordNgrams)); 110 | ofs.write(ioutil.intToByteArray(loss.value)); 111 | ofs.write(ioutil.intToByteArray(model.value)); 112 | ofs.write(ioutil.intToByteArray(bucket)); 113 | ofs.write(ioutil.intToByteArray(minn)); 114 | ofs.write(ioutil.intToByteArray(maxn)); 115 | ofs.write(ioutil.intToByteArray(lrUpdateRate)); 116 | 
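// note: the fields above are serialized as 4-byte little-endian ints via IOUtil, and the
// sampling threshold t below as an 8-byte double; load() reads them back in exactly this order.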
ofs.write(ioutil.doubleToByteArray(t)); 117 | } 118 | 119 | public void load(InputStream input) throws IOException { 120 | IOUtil ioutil = new IOUtil(); 121 | dim = ioutil.readInt(input); 122 | ws = ioutil.readInt(input); 123 | epoch = ioutil.readInt(input); 124 | minCount = ioutil.readInt(input); 125 | neg = ioutil.readInt(input); 126 | wordNgrams = ioutil.readInt(input); 127 | loss = loss_name.fromValue(ioutil.readInt(input)); 128 | model = model_name.fromValue(ioutil.readInt(input)); 129 | bucket = ioutil.readInt(input); 130 | minn = ioutil.readInt(input); 131 | maxn = ioutil.readInt(input); 132 | lrUpdateRate = ioutil.readInt(input); 133 | t = ioutil.readDouble(input); 134 | } 135 | 136 | public void parseArgs(String[] args) { 137 | String command = args[0]; 138 | if ("supervised".equalsIgnoreCase(command)) { 139 | model = model_name.sup; 140 | loss = loss_name.softmax; 141 | minCount = 1; 142 | minn = 0; 143 | maxn = 0; 144 | lr = 0.1; 145 | } else if ("cbow".equalsIgnoreCase(command)) { 146 | model = model_name.cbow; 147 | } 148 | int ai = 1; 149 | while (ai < args.length) { 150 | if (args[ai].charAt(0) != '-') { 151 | System.out.println("Provided argument without a dash! Usage:"); 152 | printHelp(); 153 | System.exit(1); 154 | } 155 | if ("-h".equals(args[ai])) { 156 | System.out.println("Here is the help! Usage:"); 157 | printHelp(); 158 | System.exit(1); 159 | } else if ("-input".equals(args[ai])) { 160 | input = args[ai + 1]; 161 | } else if ("-test".equals(args[ai])) { 162 | test = args[ai + 1]; 163 | } else if ("-output".equals(args[ai])) { 164 | output = args[ai + 1]; 165 | } else if ("-lr".equals(args[ai])) { 166 | lr = Double.parseDouble(args[ai + 1]); 167 | } else if ("-lrUpdateRate".equals(args[ai])) { 168 | lrUpdateRate = Integer.parseInt(args[ai + 1]); 169 | } else if ("-dim".equals(args[ai])) { 170 | dim = Integer.parseInt(args[ai + 1]); 171 | } else if ("-ws".equals(args[ai])) { 172 | ws = Integer.parseInt(args[ai + 1]); 173 | } else if ("-epoch".equals(args[ai])) { 174 | epoch = Integer.parseInt(args[ai + 1]); 175 | } else if ("-minCount".equals(args[ai])) { 176 | minCount = Integer.parseInt(args[ai + 1]); 177 | } else if ("-minCountLabel".equals(args[ai])) { 178 | minCountLabel = Integer.parseInt(args[ai + 1]); 179 | } else if ("-neg".equals(args[ai])) { 180 | neg = Integer.parseInt(args[ai + 1]); 181 | } else if ("-wordNgrams".equals(args[ai])) { 182 | wordNgrams = Integer.parseInt(args[ai + 1]); 183 | } else if ("-loss".equals(args[ai])) { 184 | if ("hs".equalsIgnoreCase(args[ai + 1])) { 185 | loss = loss_name.hs; 186 | } else if ("ns".equalsIgnoreCase(args[ai + 1])) { 187 | loss = loss_name.ns; 188 | } else if ("softmax".equalsIgnoreCase(args[ai + 1])) { 189 | loss = loss_name.softmax; 190 | } else { 191 | System.out.println("Unknown loss: " + args[ai + 1]); 192 | printHelp(); 193 | System.exit(1); 194 | } 195 | } else if ("-bucket".equals(args[ai])) { 196 | bucket = Integer.parseInt(args[ai + 1]); 197 | } else if ("-minn".equals(args[ai])) { 198 | minn = Integer.parseInt(args[ai + 1]); 199 | } else if ("-maxn".equals(args[ai])) { 200 | maxn = Integer.parseInt(args[ai + 1]); 201 | } else if ("-thread".equals(args[ai])) { 202 | thread = Integer.parseInt(args[ai + 1]); 203 | } else if ("-t".equals(args[ai])) { 204 | t = Double.parseDouble(args[ai + 1]); 205 | } else if ("-label".equals(args[ai])) { 206 | label = args[ai + 1]; 207 | } else if ("-verbose".equals(args[ai])) { 208 | verbose = Integer.parseInt(args[ai + 1]); 209 | } else if 
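// note: options are consumed as "-flag value" pairs, which is why ai advances by 2 at the
// bottom of the loop.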
("-pretrainedVectors".equals(args[ai])) { 210 | pretrainedVectors = args[ai + 1]; 211 | } else { 212 | System.out.println("Unknown argument: " + args[ai]); 213 | printHelp(); 214 | System.exit(1); 215 | } 216 | ai += 2; 217 | } 218 | if (Utils.isEmpty(input) || Utils.isEmpty(output)) { 219 | System.out.println("Empty input or output path."); 220 | printHelp(); 221 | System.exit(1); 222 | } 223 | if (wordNgrams <= 1 && maxn == 0) { 224 | bucket = 0; 225 | } 226 | } 227 | 228 | @Override 229 | public String toString() { 230 | StringBuilder builder = new StringBuilder(); 231 | builder.append("Args [input="); 232 | builder.append(input); 233 | builder.append(", output="); 234 | builder.append(output); 235 | builder.append(", test="); 236 | builder.append(test); 237 | builder.append(", lr="); 238 | builder.append(lr); 239 | builder.append(", lrUpdateRate="); 240 | builder.append(lrUpdateRate); 241 | builder.append(", dim="); 242 | builder.append(dim); 243 | builder.append(", ws="); 244 | builder.append(ws); 245 | builder.append(", epoch="); 246 | builder.append(epoch); 247 | builder.append(", minCount="); 248 | builder.append(minCount); 249 | builder.append(", minCountLabel="); 250 | builder.append(minCountLabel); 251 | builder.append(", neg="); 252 | builder.append(neg); 253 | builder.append(", wordNgrams="); 254 | builder.append(wordNgrams); 255 | builder.append(", loss="); 256 | builder.append(loss); 257 | builder.append(", model="); 258 | builder.append(model); 259 | builder.append(", bucket="); 260 | builder.append(bucket); 261 | builder.append(", minn="); 262 | builder.append(minn); 263 | builder.append(", maxn="); 264 | builder.append(maxn); 265 | builder.append(", thread="); 266 | builder.append(thread); 267 | builder.append(", t="); 268 | builder.append(t); 269 | builder.append(", label="); 270 | builder.append(label); 271 | builder.append(", verbose="); 272 | builder.append(verbose); 273 | builder.append(", pretrainedVectors="); 274 | builder.append(pretrainedVectors); 275 | builder.append("]"); 276 | return builder.toString(); 277 | } 278 | 279 | } 280 | -------------------------------------------------------------------------------- /src/main/java/fasttext/Dictionary.java: -------------------------------------------------------------------------------- 1 | package fasttext; 2 | 3 | import java.io.IOException; 4 | import java.io.InputStream; 5 | import java.io.OutputStream; 6 | import java.util.ArrayList; 7 | import java.util.Collections; 8 | import java.util.Comparator; 9 | import java.util.HashMap; 10 | import java.util.Iterator; 11 | import java.util.List; 12 | import java.util.Map; 13 | import java.util.Random; 14 | import java.math.BigInteger; 15 | 16 | import fasttext.Args.model_name; 17 | import fasttext.io.BufferedLineReader; 18 | import fasttext.io.LineReader; 19 | 20 | public class Dictionary { 21 | 22 | private static final int MAX_VOCAB_SIZE = 30000000; 23 | private static final int MAX_LINE_SIZE = 1024; 24 | private static final Integer WORDID_DEFAULT = -1; 25 | 26 | private static final String EOS = ""; 27 | private static final String BOW = "<"; 28 | private static final String EOW = ">"; 29 | 30 | public enum entry_type { 31 | word(0), label(1); 32 | 33 | private int value; 34 | 35 | private entry_type(int value) { 36 | this.value = value; 37 | } 38 | 39 | public int getValue() { 40 | return this.value; 41 | } 42 | 43 | public static entry_type fromValue(int value) throws IllegalArgumentException { 44 | try { 45 | return entry_type.values()[value]; 46 | } catch 
(ArrayIndexOutOfBoundsException e) { 47 | throw new IllegalArgumentException("Unknown entry_type enum value :" + value); 48 | } 49 | } 50 | 51 | @Override 52 | public String toString() { 53 | return value == 0 ? "word" : value == 1 ? "label" : "unknown"; 54 | } 55 | } 56 | 57 | public class entry { 58 | String word; 59 | long count; 60 | entry_type type; 61 | List subwords; 62 | 63 | @Override 64 | public String toString() { 65 | StringBuilder builder = new StringBuilder(); 66 | builder.append("entry [word="); 67 | builder.append(word); 68 | builder.append(", count="); 69 | builder.append(count); 70 | builder.append(", type="); 71 | builder.append(type); 72 | builder.append(", subwords="); 73 | builder.append(subwords); 74 | builder.append("]"); 75 | return builder.toString(); 76 | } 77 | 78 | } 79 | 80 | private List words_; 81 | private List pdiscard_; 82 | private Map word2int_; 83 | private int size_; 84 | private int nwords_; 85 | private int nlabels_; 86 | private long ntokens_; 87 | 88 | private Args args_; 89 | 90 | private String charsetName_ = "UTF-8"; 91 | private Class lineReaderClass_ = BufferedLineReader.class; 92 | 93 | public Dictionary(Args args) { 94 | args_ = args; 95 | size_ = 0; 96 | nwords_ = 0; 97 | nlabels_ = 0; 98 | ntokens_ = 0; 99 | word2int_ = new HashMap(MAX_VOCAB_SIZE); 100 | words_ = new ArrayList(MAX_VOCAB_SIZE); 101 | } 102 | 103 | public long find(final String w) { 104 | long h = hash(w) % MAX_VOCAB_SIZE; 105 | entry e = null; 106 | while (Utils.mapGetOrDefault(word2int_, h, WORDID_DEFAULT) != WORDID_DEFAULT 107 | && ((e = words_.get(word2int_.get(h))) != null && !w.equals(e.word))) { 108 | h = (h + 1) % MAX_VOCAB_SIZE; 109 | } 110 | return h; 111 | } 112 | 113 | public void add(final String w) { 114 | long h = find(w); 115 | ntokens_++; 116 | if (Utils.mapGetOrDefault(word2int_, h, WORDID_DEFAULT) == WORDID_DEFAULT) { 117 | entry e = new entry(); 118 | e.word = w; 119 | e.count = 1; 120 | e.type = w.startsWith(args_.label) ? 
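// note: tokens carrying the label prefix (args_.label, "__label__" by default) become
// label entries; everything else is a word entry.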
entry_type.label : entry_type.word; 121 | words_.add(e); 122 | word2int_.put(h, size_++); 123 | } else { 124 | words_.get(word2int_.get(h)).count++; 125 | } 126 | } 127 | 128 | public int nwords() { 129 | return nwords_; 130 | } 131 | 132 | public int nlabels() { 133 | return nlabels_; 134 | } 135 | 136 | public long ntokens() { 137 | return ntokens_; 138 | } 139 | 140 | public final List getNgrams(int i) { 141 | Utils.checkArgument(i >= 0); 142 | Utils.checkArgument(i < nwords_); 143 | return words_.get(i).subwords; 144 | } 145 | 146 | public final List getNgrams(final String word) { 147 | List ngrams = new ArrayList(); 148 | int i = getId(word); 149 | if (i >= 0) { 150 | ngrams = words_.get(i).subwords; 151 | } else { 152 | computeNgrams(BOW + word + EOW, ngrams); 153 | } 154 | return ngrams; 155 | } 156 | 157 | public boolean discard(int id, float rand) { 158 | Utils.checkArgument(id >= 0); 159 | Utils.checkArgument(id < nwords_); 160 | if (args_.model == model_name.sup) 161 | return false; 162 | return rand > pdiscard_.get(id); 163 | } 164 | 165 | public int getId(final String w) { 166 | long h = find(w); 167 | return Utils.mapGetOrDefault(word2int_, h, WORDID_DEFAULT); 168 | } 169 | 170 | public entry_type getType(int id) { 171 | Utils.checkArgument(id >= 0); 172 | Utils.checkArgument(id < size_); 173 | return words_.get(id).type; 174 | } 175 | 176 | public String getWord(int id) { 177 | Utils.checkArgument(id >= 0); 178 | Utils.checkArgument(id < size_); 179 | return words_.get(id).word; 180 | } 181 | 182 | /** 183 | * String FNV-1a Hash 184 | * 185 | * @param str 186 | * @return 187 | */ 188 | public long hash(final String str) { 189 | int h = (int) 2166136261L;// 0xffffffc5; 190 | for (byte strByte : str.getBytes()) { 191 | h = (h ^ strByte) * 16777619; // FNV-1a 192 | // h = (h * 16777619) ^ strByte; //FNV-1 193 | } 194 | return h & 0xffffffffL; 195 | } 196 | 197 | public void computeNgrams(final String word, List ngrams) { 198 | for (int i = 0; i < word.length(); i++) { 199 | StringBuilder ngram = new StringBuilder(); 200 | if (charMatches(word.charAt(i))) { 201 | continue; 202 | } 203 | for (int j = i, n = 1; j < word.length() && n <= args_.maxn; n++) { 204 | ngram.append(word.charAt(j++)); 205 | while (j < word.length() && charMatches(word.charAt(j))) { 206 | ngram.append(word.charAt(j++)); 207 | } 208 | if (n >= args_.minn && !(n == 1 && (i == 0 || j == word.length()))) { 209 | int h = (int) (nwords_ + (hash(ngram.toString()) % args_.bucket)); 210 | if (h < 0) { 211 | System.err.println("computeNgrams h<0: " + h + " on word: " + word); 212 | } 213 | ngrams.add(h); 214 | } 215 | } 216 | } 217 | } 218 | 219 | private boolean charMatches(char ch) { 220 | if (ch == ' ' || ch == '\t' || ch == '\n' || ch == '\f' || ch == '\r') { 221 | return true; 222 | } 223 | return false; 224 | } 225 | 226 | public void initNgrams() { 227 | for (int i = 0; i < size_; i++) { 228 | String word = BOW + words_.get(i).word + EOW; 229 | entry e = words_.get(i); 230 | if (e.subwords == null) { 231 | e.subwords = new ArrayList(); 232 | } 233 | e.subwords.add(i); 234 | computeNgrams(word, e.subwords); 235 | } 236 | } 237 | 238 | public void readFromFile(String file) throws IOException, Exception { 239 | LineReader lineReader = null; 240 | 241 | try { 242 | lineReader = lineReaderClass_.getConstructor(String.class, String.class).newInstance(file, charsetName_); 243 | long minThreshold = 1; 244 | String[] lineTokens; 245 | while ((lineTokens = lineReader.readLineTokens()) != null) { 246 | for (int i = 0; i 
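// note: the loop deliberately runs one step past the last token so that the
// end-of-sentence marker (EOS) is counted for every line read.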
<= lineTokens.length; i++) { 247 | if (i == lineTokens.length) { 248 | add(EOS); 249 | } else { 250 | if (Utils.isEmpty(lineTokens[i])) { 251 | continue; 252 | } 253 | add(lineTokens[i]); 254 | } 255 | if (ntokens_ % 1000000 == 0 && args_.verbose > 1) { 256 | System.out.printf("\rRead %dM words", ntokens_ / 1000000); 257 | } 258 | if (size_ > 0.75 * MAX_VOCAB_SIZE) { 259 | minThreshold++; 260 | threshold(minThreshold, minThreshold); 261 | } 262 | } 263 | } 264 | } finally { 265 | if (lineReader != null) { 266 | lineReader.close(); 267 | } 268 | } 269 | threshold(args_.minCount, args_.minCountLabel); 270 | initTableDiscard(); 271 | if (model_name.cbow == args_.model || model_name.sg == args_.model) { 272 | initNgrams(); 273 | } 274 | if (args_.verbose > 0) { 275 | System.out.printf("\rRead %dM words\n", ntokens_ / 1000000); 276 | System.out.println("Number of words: " + nwords_); 277 | System.out.println("Number of labels: " + nlabels_); 278 | } 279 | if (size_ == 0) { 280 | System.err.println("Empty vocabulary. Try a smaller -minCount value."); 281 | System.exit(1); 282 | } 283 | } 284 | 285 | public void threshold(long t, long tl) { 286 | Collections.sort(words_, entry_comparator); 287 | Iterator iterator = words_.iterator(); 288 | while (iterator.hasNext()) { 289 | entry _entry = iterator.next(); 290 | if ((entry_type.word == _entry.type && _entry.count < t) 291 | || (entry_type.label == _entry.type && _entry.count < tl)) { 292 | iterator.remove(); 293 | } 294 | } 295 | ((ArrayList) words_).trimToSize(); 296 | size_ = 0; 297 | nwords_ = 0; 298 | nlabels_ = 0; 299 | // word2int_.clear(); 300 | word2int_ = new HashMap(words_.size()); 301 | for (entry _entry : words_) { 302 | long h = find(_entry.word); 303 | word2int_.put(h, size_++); 304 | if (entry_type.word == _entry.type) { 305 | nwords_++; 306 | } else if (entry_type.label == _entry.type) { 307 | nlabels_++; 308 | } 309 | } 310 | } 311 | 312 | private transient Comparator entry_comparator = new Comparator() { 313 | @Override 314 | public int compare(entry o1, entry o2) { 315 | int cmp = (o1.type.value < o2.type.value) ? -1 : ((o1.type.value == o2.type.value) ? 0 : 1); 316 | if (cmp == 0) { 317 | cmp = (o2.count < o1.count) ? -1 : ((o2.count == o1.count) ? 0 : 1); 318 | } 319 | return cmp; 320 | } 321 | }; 322 | 323 | public void initTableDiscard() { 324 | pdiscard_ = new ArrayList(size_); 325 | for (int i = 0; i < size_; i++) { 326 | float f = (float) (words_.get(i).count) / (float) ntokens_; 327 | pdiscard_.add((float) (Math.sqrt(args_.t / f) + args_.t / f)); 328 | } 329 | } 330 | 331 | public List getCounts(entry_type type) { 332 | List counts = entry_type.label == type ? 
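// note: these per-entry counts are what FastText passes to Model.setTargetCounts() after
// training or loading (label counts for supervised models, word counts otherwise).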
new ArrayList(nlabels()) : new ArrayList(nwords()); 333 | for (entry w : words_) { 334 | if (w.type == type) 335 | counts.add(w.count); 336 | } 337 | return counts; 338 | } 339 | 340 | public void addNgrams(List line, int n) { 341 | if (n <= 1) { 342 | return; 343 | } 344 | int line_size = line.size(); 345 | for (int i = 0; i < line_size; i++) { 346 | BigInteger h = BigInteger.valueOf(line.get(i)); 347 | BigInteger r = BigInteger.valueOf(116049371l); 348 | BigInteger b = BigInteger.valueOf(args_.bucket); 349 | 350 | for (int j = i + 1; j < line_size && j < i + n; j++) { 351 | h = h.multiply(r).add(BigInteger.valueOf(line.get(j)));; 352 | line.add(nwords_ + h.remainder(b).intValue()); 353 | } 354 | } 355 | } 356 | 357 | public int getLine(String[] tokens, List words, List labels, Random urd) { 358 | int ntokens = 0; 359 | words.clear(); 360 | labels.clear(); 361 | if (tokens != null) { 362 | for (int i = 0; i <= tokens.length; i++) { 363 | if (i < tokens.length && Utils.isEmpty(tokens[i])) { 364 | continue; 365 | } 366 | int wid = i == tokens.length ? getId(EOS) : getId(tokens[i]); 367 | if (wid < 0) { 368 | continue; 369 | } 370 | entry_type type = getType(wid); 371 | ntokens++; 372 | if (type == entry_type.word && !discard(wid, Utils.randomFloat(urd, 0, 1))) { 373 | words.add(wid); 374 | } 375 | if (type == entry_type.label) { 376 | labels.add(wid - nwords_); 377 | } 378 | if (words.size() > MAX_LINE_SIZE && args_.model != model_name.sup) { 379 | break; 380 | } 381 | // if (EOS == tokens[i]){ 382 | // break; 383 | // } 384 | } 385 | } 386 | return ntokens; 387 | } 388 | 389 | public String getLabel(int lid) { 390 | Utils.checkArgument(lid >= 0); 391 | Utils.checkArgument(lid < nlabels_); 392 | return words_.get(lid + nwords_).word; 393 | } 394 | 395 | public void save(OutputStream ofs) throws IOException { 396 | IOUtil ioutil = new IOUtil(); 397 | ofs.write(ioutil.intToByteArray(size_)); 398 | ofs.write(ioutil.intToByteArray(nwords_)); 399 | ofs.write(ioutil.intToByteArray(nlabels_)); 400 | ofs.write(ioutil.longToByteArray(ntokens_)); 401 | // Charset charset = Charset.forName("UTF-8"); 402 | for (int i = 0; i < size_; i++) { 403 | entry e = words_.get(i); 404 | ofs.write(e.word.getBytes()); 405 | ofs.write(0); 406 | ofs.write(ioutil.longToByteArray(e.count)); 407 | ofs.write(ioutil.intToByte(e.type.value)); 408 | } 409 | } 410 | 411 | public void load(InputStream ifs) throws IOException { 412 | // words_.clear(); 413 | // word2int_.clear(); 414 | IOUtil ioutil = new IOUtil(); 415 | size_ = ioutil.readInt(ifs); 416 | nwords_ = ioutil.readInt(ifs); 417 | nlabels_ = ioutil.readInt(ifs); 418 | ntokens_ = ioutil.readLong(ifs); 419 | long pruneidx_size = ioutil.readLong(ifs); 420 | 421 | word2int_ = new HashMap(size_); 422 | words_ = new ArrayList(size_); 423 | 424 | for (int i = 0; i < size_; i++) { 425 | entry e = new entry(); 426 | e.word = ioutil.readString(ifs); 427 | e.count = ioutil.readLong(ifs); 428 | e.type = entry_type.fromValue(ioutil.readByte(ifs)); 429 | words_.add(e); 430 | word2int_.put(find(e.word), i); 431 | } 432 | initTableDiscard(); 433 | //if (model_name.cbow == args_.model || model_name.sg == args_.model) { 434 | initNgrams(); 435 | //} 436 | } 437 | 438 | @Override 439 | public String toString() { 440 | StringBuilder builder = new StringBuilder(); 441 | builder.append("Dictionary [words_="); 442 | builder.append(words_); 443 | builder.append(", pdiscard_="); 444 | builder.append(pdiscard_); 445 | builder.append(", word2int_="); 446 | builder.append(word2int_); 447 | 
builder.append(", size_="); 448 | builder.append(size_); 449 | builder.append(", nwords_="); 450 | builder.append(nwords_); 451 | builder.append(", nlabels_="); 452 | builder.append(nlabels_); 453 | builder.append(", ntokens_="); 454 | builder.append(ntokens_); 455 | builder.append("]"); 456 | return builder.toString(); 457 | } 458 | 459 | public List getWords() { 460 | return words_; 461 | } 462 | 463 | public List getPdiscard() { 464 | return pdiscard_; 465 | } 466 | 467 | public Map getWord2int() { 468 | return word2int_; 469 | } 470 | 471 | public int getSize() { 472 | return size_; 473 | } 474 | 475 | public Args getArgs() { 476 | return args_; 477 | } 478 | 479 | public String getCharsetName() { 480 | return charsetName_; 481 | } 482 | 483 | public Class getLineReaderClass() { 484 | return lineReaderClass_; 485 | } 486 | 487 | public void setCharsetName(String charsetName) { 488 | this.charsetName_ = charsetName; 489 | } 490 | 491 | public void setLineReaderClass(Class lineReaderClass) { 492 | this.lineReaderClass_ = lineReaderClass; 493 | } 494 | 495 | } 496 | -------------------------------------------------------------------------------- /src/main/java/fasttext/FastText.java: -------------------------------------------------------------------------------- 1 | package fasttext; 2 | 3 | import java.io.BufferedInputStream; 4 | import java.io.BufferedOutputStream; 5 | import java.io.BufferedReader; 6 | import java.io.DataInputStream; 7 | import java.io.File; 8 | import java.io.FileInputStream; 9 | import java.io.FileOutputStream; 10 | import java.io.IOException; 11 | import java.io.InputStream; 12 | import java.io.InputStreamReader; 13 | import java.io.OutputStream; 14 | import java.io.OutputStreamWriter; 15 | import java.io.Writer; 16 | import java.text.DecimalFormat; 17 | import java.util.ArrayList; 18 | import java.util.List; 19 | import java.util.concurrent.atomic.AtomicLong; 20 | 21 | import fasttext.Args.model_name; 22 | import fasttext.Dictionary.entry; 23 | import fasttext.Dictionary.entry_type; 24 | import fasttext.io.BufferedLineReader; 25 | import fasttext.io.LineReader; 26 | 27 | /** 28 | * FastText class, can be used as a lib in other projects 29 | * 30 | * @author Ivan 31 | * 32 | */ 33 | public strictfp class FastText { 34 | 35 | private Args args_; 36 | private Dictionary dict_; 37 | private Matrix input_; 38 | private Matrix output_; 39 | private Model model_; 40 | 41 | private AtomicLong tokenCount_; 42 | private long start_; 43 | 44 | private String charsetName_ = "UTF-8"; 45 | private Class lineReaderClass_ = BufferedLineReader.class; 46 | 47 | public String[] getWords() { 48 | List entries = dict_.getWords(); 49 | int n = entries.size(); 50 | String[] words = new String[n]; 51 | for (int i = 0; i < n; i++) { 52 | words[i] = entries.get(i).word; 53 | } 54 | return words; 55 | } 56 | 57 | public void getVector(Vector vec, final String word) { 58 | final List ngrams = dict_.getNgrams(word); 59 | vec.zero(); 60 | for (Integer it : ngrams) { 61 | vec.addRow(input_, it); 62 | } 63 | if (ngrams.size() > 0) { 64 | vec.mul(1.0f / ngrams.size()); 65 | } 66 | } 67 | 68 | public void saveVectors() throws IOException { 69 | if (Utils.isEmpty(args_.output)) { 70 | if (args_.verbose > 1) { 71 | System.out.println("output is empty, skip save vector file"); 72 | } 73 | return; 74 | } 75 | 76 | File file = new File(args_.output + ".vec"); 77 | if (file.exists()) { 78 | file.delete(); 79 | } 80 | if (file.getParentFile() != null) { 81 | file.getParentFile().mkdirs(); 82 | } 83 | if 
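// note: the .vec output written below is plain text: a "nwords dim" header line, then one
// line per vocabulary word with its dim components formatted to at most five decimal places.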
(args_.verbose > 1) { 84 | System.out.println("Saving Vectors to " + file.getCanonicalPath().toString()); 85 | } 86 | Vector vec = new Vector(args_.dim); 87 | DecimalFormat df = new DecimalFormat("0.#####"); 88 | Writer writer = new OutputStreamWriter(new BufferedOutputStream(new FileOutputStream(file)), "UTF-8"); 89 | try { 90 | writer.write(dict_.nwords() + " " + args_.dim + "\n"); 91 | for (int i = 0; i < dict_.nwords(); i++) { 92 | String word = dict_.getWord(i); 93 | getVector(vec, word); 94 | writer.write(word); 95 | for (int j = 0; j < vec.m_; j++) { 96 | writer.write(" "); 97 | writer.write(df.format(vec.data_[j])); 98 | } 99 | writer.write("\n"); 100 | } 101 | } finally { 102 | writer.flush(); 103 | writer.close(); 104 | } 105 | } 106 | 107 | public void saveModel() throws IOException { 108 | if (Utils.isEmpty(args_.output)) { 109 | if (args_.verbose > 1) { 110 | System.out.println("output is empty, skip save model file"); 111 | } 112 | return; 113 | } 114 | 115 | File file = new File(args_.output + ".bin"); 116 | if (file.exists()) { 117 | file.delete(); 118 | } 119 | if (file.getParentFile() != null) { 120 | file.getParentFile().mkdirs(); 121 | } 122 | if (args_.verbose > 1) { 123 | System.out.println("Saving model to " + file.getCanonicalPath().toString()); 124 | } 125 | OutputStream ofs = new BufferedOutputStream(new FileOutputStream(file)); 126 | try { 127 | args_.save(ofs); 128 | dict_.save(ofs); 129 | input_.save(ofs); 130 | output_.save(ofs); 131 | } finally { 132 | ofs.flush(); 133 | ofs.close(); 134 | } 135 | } 136 | 137 | /** 138 | * Load binary model file. 139 | * 140 | * @param filename 141 | * @throws IOException 142 | */ 143 | public void loadModel(String filename) throws IOException { 144 | File file = new File(filename); 145 | if (!(file.exists() && file.isFile() && file.canRead())) { 146 | throw new IOException("Model file cannot be opened for loading! 
" + filename); 147 | } 148 | loadModel(new FileInputStream(file)); 149 | } 150 | 151 | public void loadModel(InputStream is) throws IOException { 152 | DataInputStream dis = null; 153 | BufferedInputStream bis = null; 154 | try { 155 | bis = new BufferedInputStream(is); 156 | dis = new DataInputStream(bis); 157 | 158 | args_ = new Args(); 159 | dict_ = new Dictionary(args_); 160 | input_ = new Matrix(); 161 | output_ = new Matrix(); 162 | 163 | // Read magic number and version 164 | IOUtil ioutil = new IOUtil(); 165 | int magic_number = ioutil.readInt(dis); 166 | int version = ioutil.readInt(dis); 167 | 168 | args_.load(dis); 169 | dict_.load(dis); 170 | boolean quant = dis.readBoolean(); 171 | input_.load(dis); 172 | quant = dis.readBoolean(); 173 | output_.load(dis); 174 | 175 | model_ = new Model(input_, output_, args_, 0); 176 | if (args_.model == model_name.sup) { 177 | model_.setTargetCounts(dict_.getCounts(entry_type.label)); 178 | } else { 179 | model_.setTargetCounts(dict_.getCounts(entry_type.word)); 180 | } 181 | } finally { 182 | if (bis != null) { 183 | bis.close(); 184 | } 185 | if (dis != null) { 186 | dis.close(); 187 | } 188 | } 189 | } 190 | 191 | public void printInfo(float progress, float loss) { 192 | float t = (float) (System.currentTimeMillis() - start_) / 1000; 193 | float ws = (float) (tokenCount_.get()) / t; 194 | float wst = (float) (tokenCount_.get()) / t / args_.thread; 195 | float lr = (float) (args_.lr * (1.0f - progress)); 196 | int eta = (int) (t / progress * (1 - progress)); 197 | int etah = eta / 3600; 198 | int etam = (eta - etah * 3600) / 60; 199 | System.out.printf("\rProgress: %.1f%% words/sec: %d words/sec/thread: %d lr: %.6f loss: %.6f eta: %d h %d m", 200 | 100 * progress, (int) ws, (int) wst, lr, loss, etah, etam); 201 | } 202 | 203 | public void supervised(Model model, float lr, final List line, final List labels) { 204 | if (labels.size() == 0 || line.size() == 0) 205 | return; 206 | int i = Utils.randomInt(model.rng, 1, labels.size()) - 1; 207 | model.update(line, labels.get(i), lr); 208 | } 209 | 210 | public void cbow(Model model, float lr, final List line) { 211 | List bow = new ArrayList(); 212 | for (int w = 0; w < line.size(); w++) { 213 | int boundary = Utils.randomInt(model.rng, 1, args_.ws); 214 | bow.clear(); 215 | for (int c = -boundary; c <= boundary; c++) { 216 | if (c != 0 && w + c >= 0 && w + c < line.size()) { 217 | final List ngrams = dict_.getNgrams(line.get(w + c)); 218 | bow.addAll(ngrams); 219 | } 220 | } 221 | model.update(bow, line.get(w), lr); 222 | } 223 | } 224 | 225 | public void skipgram(Model model, float lr, final List line) { 226 | for (int w = 0; w < line.size(); w++) { 227 | int boundary = Utils.randomInt(model.rng, 1, args_.ws); 228 | final List ngrams = dict_.getNgrams(line.get(w)); 229 | for (int c = -boundary; c <= boundary; c++) { 230 | if (c != 0 && w + c >= 0 && w + c < line.size()) { 231 | model.update(ngrams, line.get(w + c), lr); 232 | } 233 | } 234 | } 235 | } 236 | 237 | public void test(InputStream in, int k) throws IOException, Exception { 238 | int nexamples = 0, nlabels = 0; 239 | double precision = 0.0f; 240 | List line = new ArrayList(); 241 | List labels = new ArrayList(); 242 | 243 | LineReader lineReader = null; 244 | try { 245 | lineReader = lineReaderClass_.getConstructor(InputStream.class, String.class).newInstance(in, charsetName_); 246 | String[] lineTokens; 247 | while ((lineTokens = lineReader.readLineTokens()) != null) { 248 | if (lineTokens.length == 1 && 
"quit".equals(lineTokens[0])) { 249 | break; 250 | } 251 | dict_.getLine(lineTokens, line, labels, model_.rng); 252 | dict_.addNgrams(line, args_.wordNgrams); 253 | if (labels.size() > 0 && line.size() > 0) { 254 | List> modelPredictions = new ArrayList>(); 255 | model_.predict(line, k, modelPredictions); 256 | for (Pair pair : modelPredictions) { 257 | if (labels.contains(pair.getValue())) { 258 | precision += 1.0f; 259 | } 260 | } 261 | nexamples++; 262 | nlabels += labels.size(); 263 | // } else { 264 | // System.out.println("FAIL Test line: " + lineTokens + 265 | // "labels: " + labels + " line: " + line); 266 | } 267 | } 268 | } finally { 269 | if (lineReader != null) { 270 | lineReader.close(); 271 | } 272 | } 273 | 274 | System.out.printf("P@%d: %.3f%n", k, precision / (k * nexamples)); 275 | System.out.printf("R@%d: %.3f%n", k, precision / nlabels); 276 | System.out.println("Number of examples: " + nexamples); 277 | } 278 | 279 | /** 280 | * Thread-safe predict api 281 | * 282 | * @param lineTokens 283 | * @param k 284 | * @return 285 | */ 286 | public List> predict(String[] lineTokens, int k) { 287 | List words = new ArrayList(); 288 | List labels = new ArrayList(); 289 | dict_.getLine(lineTokens, words, labels, model_.rng); 290 | dict_.addNgrams(words, args_.wordNgrams); 291 | 292 | if (words.isEmpty()) { 293 | return null; 294 | } 295 | 296 | Vector hidden = new Vector(args_.dim); 297 | Vector output = new Vector(dict_.nlabels()); 298 | List> modelPredictions = new ArrayList>(k + 1); 299 | 300 | model_.predict(words, k, modelPredictions, hidden, output); 301 | 302 | List> predictions = new ArrayList>(k); 303 | for (Pair pair : modelPredictions) { 304 | predictions.add(new Pair(Math.exp(pair.getKey().doubleValue()), dict_.getLabel(pair.getValue()))); 305 | } 306 | return predictions; 307 | } 308 | 309 | public void predict(String[] lineTokens, int k, List> predictions) throws IOException { 310 | List words = new ArrayList(); 311 | List labels = new ArrayList(); 312 | dict_.getLine(lineTokens, words, labels, model_.rng); 313 | dict_.addNgrams(words, args_.wordNgrams); 314 | 315 | if (words.isEmpty()) { 316 | return; 317 | } 318 | List> modelPredictions = new ArrayList>(k + 1); 319 | model_.predict(words, k, modelPredictions); 320 | predictions.clear(); 321 | for (Pair pair : modelPredictions) { 322 | predictions.add(new Pair(pair.getKey(), dict_.getLabel(pair.getValue()))); 323 | } 324 | } 325 | 326 | public void predict(InputStream in, int k, boolean print_prob) throws IOException, Exception { 327 | List> predictions = new ArrayList>(k); 328 | 329 | LineReader lineReader = null; 330 | 331 | try { 332 | lineReader = lineReaderClass_.getConstructor(InputStream.class, String.class).newInstance(in, charsetName_); 333 | String[] lineTokens; 334 | while ((lineTokens = lineReader.readLineTokens()) != null) { 335 | if (lineTokens.length == 1 && "quit".equals(lineTokens[0])) { 336 | break; 337 | } 338 | predictions.clear(); 339 | predict(lineTokens, k, predictions); 340 | if (predictions.isEmpty()) { 341 | System.out.println("n/a"); 342 | continue; 343 | } 344 | for (Pair pair : predictions) { 345 | System.out.print(pair.getValue()); 346 | if (print_prob) { 347 | System.out.printf(" %f", Math.exp(pair.getKey())); 348 | } 349 | } 350 | System.out.println(); 351 | } 352 | } finally { 353 | if (lineReader != null) { 354 | lineReader.close(); 355 | } 356 | } 357 | } 358 | 359 | public void wordVectors() { 360 | Vector vec = new Vector(args_.dim); 361 | LineReader lineReader = null; 362 | try 
{ 363 | lineReader = lineReaderClass_.getConstructor(InputStream.class, String.class).newInstance(System.in, 364 | charsetName_); 365 | String word; 366 | while (!Utils.isEmpty((word = lineReader.readLine()))) { 367 | getVector(vec, word); 368 | System.out.println(word + " " + vec); 369 | } 370 | } catch (Exception e) { 371 | e.printStackTrace(); 372 | } finally { 373 | if (lineReader != null) { 374 | try { 375 | lineReader.close(); 376 | } catch (IOException e) { 377 | e.printStackTrace(); 378 | } 379 | } 380 | } 381 | } 382 | 383 | public void textVectors() { 384 | List line = new ArrayList(); 385 | List labels = new ArrayList(); 386 | Vector vec = new Vector(args_.dim); 387 | LineReader lineReader = null; 388 | try { 389 | lineReader = lineReaderClass_.getConstructor(InputStream.class, String.class).newInstance(System.in, 390 | charsetName_); 391 | String[] lineTokens; 392 | while ((lineTokens = lineReader.readLineTokens()) != null) { 393 | if (lineTokens.length == 1 && "quit".equals(lineTokens[0])) { 394 | break; 395 | } 396 | dict_.getLine(lineTokens, line, labels, model_.rng); 397 | dict_.addNgrams(line, args_.wordNgrams); 398 | vec.zero(); 399 | for (Integer it : line) { 400 | vec.addRow(input_, it); 401 | } 402 | if (!line.isEmpty()) { 403 | vec.mul(1.0f / line.size()); 404 | } 405 | System.out.println(vec); 406 | } 407 | } catch (Exception e) { 408 | e.printStackTrace(); 409 | } finally { 410 | if (lineReader != null) { 411 | try { 412 | lineReader.close(); 413 | } catch (IOException e) { 414 | e.printStackTrace(); 415 | } 416 | } 417 | } 418 | } 419 | 420 | public Vector sentenceVectors(String[] tokens) { 421 | Vector vec = new Vector(args_.dim); 422 | Vector svec = new Vector(args_.dim); 423 | svec.zero(); 424 | int count = 0; 425 | for (int i=0; i < tokens.length; i++) { 426 | getVector(vec, tokens[i]); 427 | float norm = vec.norm(); 428 | if (norm != 0.0f) { 429 | vec.mul((float) (1.0 / vec.norm())); 430 | svec.addVector(vec); 431 | count++; 432 | } 433 | } 434 | if (count != 0) svec.mul((float) (1.0 / count)); 435 | else svec.zero(); 436 | return svec; 437 | } 438 | 439 | public void printVectors() { 440 | if (args_.model == model_name.sup) { 441 | textVectors(); 442 | } else { 443 | wordVectors(); 444 | } 445 | } 446 | 447 | public class TrainThread extends Thread { 448 | final FastText ft; 449 | int threadId; 450 | 451 | public TrainThread(FastText ft, int threadId) { 452 | super("FT-TrainThread-" + threadId); 453 | this.ft = ft; 454 | this.threadId = threadId; 455 | } 456 | 457 | public void run() { 458 | if (args_.verbose > 2) { 459 | System.out.println("thread: " + threadId + " RUNNING!"); 460 | } 461 | Exception catchedException = null; 462 | LineReader lineReader = null; 463 | try { 464 | lineReader = lineReaderClass_.getConstructor(String.class, String.class).newInstance(args_.input, 465 | charsetName_); 466 | lineReader.skipLine(threadId * threadFileSize / args_.thread); 467 | Model model = new Model(input_, output_, args_, threadId); 468 | if (args_.model == model_name.sup) { 469 | model.setTargetCounts(dict_.getCounts(entry_type.label)); 470 | } else { 471 | model.setTargetCounts(dict_.getCounts(entry_type.word)); 472 | } 473 | 474 | final long ntokens = dict_.ntokens(); 475 | long localTokenCount = 0; 476 | 477 | List line = new ArrayList(); 478 | List labels = new ArrayList(); 479 | 480 | String[] lineTokens; 481 | while (tokenCount_.get() < args_.epoch * ntokens) { 482 | lineTokens = lineReader.readLineTokens(); 483 | if (lineTokens == null) { 484 | try { 485 | 
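// note: reaching end of file before the global budget of args_.epoch * ntokens tokens is
// spent simply rewinds the reader; epochs are tracked through tokenCount_ rather than by
// counting passes over the file.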
lineReader.rewind(); 486 | if (args_.verbose > 2) { 487 | System.out.println("Input file reloaded!"); 488 | } 489 | } catch (Exception e) { 490 | e.printStackTrace(); 491 | } 492 | lineTokens = lineReader.readLineTokens(); 493 | } 494 | 495 | float progress = (float) (tokenCount_.get()) / (args_.epoch * ntokens); 496 | float lr = (float) (args_.lr * (1.0 - progress)); 497 | localTokenCount += dict_.getLine(lineTokens, line, labels, model.rng); 498 | if (args_.model == model_name.sup) { 499 | dict_.addNgrams(line, args_.wordNgrams); 500 | if (labels.size() == 0 || line.size() == 0) { 501 | continue; 502 | } 503 | supervised(model, lr, line, labels); 504 | } else if (args_.model == model_name.cbow) { 505 | cbow(model, lr, line); 506 | } else if (args_.model == model_name.sg) { 507 | skipgram(model, lr, line); 508 | } 509 | if (localTokenCount > args_.lrUpdateRate) { 510 | tokenCount_.addAndGet(localTokenCount); 511 | localTokenCount = 0; 512 | if (threadId == 0 && args_.verbose > 1 && (System.currentTimeMillis() - start_) % 1000 == 0) { 513 | printInfo(progress, model.getLoss()); 514 | } 515 | } 516 | } 517 | 518 | if (threadId == 0 && args_.verbose > 1) { 519 | printInfo(1.0f, model.getLoss()); 520 | } 521 | } catch (Exception e) { 522 | catchedException = e; 523 | } finally { 524 | if (lineReader != null) 525 | try { 526 | lineReader.close(); 527 | } catch (IOException e) { 528 | e.printStackTrace(); 529 | } 530 | } 531 | 532 | // exit from thread 533 | synchronized (ft) { 534 | if (args_.verbose > 2) { 535 | System.out.println("\nthread: " + threadId + " EXIT!"); 536 | } 537 | ft.threadCount--; 538 | ft.notify(); 539 | if (catchedException != null) { 540 | throw new RuntimeException(catchedException); 541 | } 542 | } 543 | } 544 | } 545 | 546 | public void loadVectors(String filename) throws IOException { 547 | List words; 548 | Matrix mat; // temp. 
matrix for pretrained vectors 549 | int n, dim; 550 | 551 | BufferedReader dis = null; 552 | String line; 553 | String[] lineParts; 554 | try { 555 | dis = new BufferedReader(new InputStreamReader(new FileInputStream(filename), "UTF-8")); 556 | 557 | line = dis.readLine(); 558 | lineParts = line.split(" "); 559 | n = Integer.parseInt(lineParts[0]); 560 | dim = Integer.parseInt(lineParts[1]); 561 | 562 | words = new ArrayList(n); 563 | 564 | if (dim != args_.dim) { 565 | throw new IllegalArgumentException( 566 | "Dimension of pretrained vectors does not match args -dim option, pretrain dim is " + dim 567 | + ", args dim is " + args_.dim); 568 | } 569 | 570 | mat = new Matrix(n, dim); 571 | for (int i = 0; i < n; i++) { 572 | line = dis.readLine(); 573 | lineParts = line.split(" "); 574 | String word = lineParts[0]; 575 | for (int j = 1; j <= dim; j++) { 576 | mat.data_[i][j - 1] = Float.parseFloat(lineParts[j]); 577 | } 578 | words.add(word); 579 | dict_.add(word); 580 | } 581 | 582 | dict_.threshold(1, 0); 583 | input_ = new Matrix(dict_.nwords() + args_.bucket, args_.dim); 584 | input_.uniform(1.0f / args_.dim); 585 | for (int i = 0; i < n; i++) { 586 | int idx = dict_.getId(words.get(i)); 587 | if (idx < 0 || idx >= dict_.nwords()) 588 | continue; 589 | for (int j = 0; j < dim; j++) { 590 | input_.data_[idx][j] = mat.data_[i][j]; 591 | } 592 | } 593 | 594 | } catch (IOException e) { 595 | throw new IOException("Pretrained vectors file cannot be opened!", e); 596 | } finally { 597 | try { 598 | if (dis != null) { 599 | dis.close(); 600 | } 601 | } catch (IOException e) { 602 | e.printStackTrace(); 603 | } 604 | } 605 | } 606 | 607 | int threadCount; 608 | long threadFileSize; 609 | 610 | public void train(Args args) throws IOException, Exception { 611 | args_ = args; 612 | dict_ = new Dictionary(args_); 613 | dict_.setCharsetName(charsetName_); 614 | dict_.setLineReaderClass(lineReaderClass_); 615 | 616 | if ("-".equals(args_.input)) { 617 | throw new IOException("Cannot use stdin for training!"); 618 | } 619 | 620 | File file = new File(args_.input); 621 | if (!(file.exists() && file.isFile() && file.canRead())) { 622 | throw new IOException("Input file cannot be opened! 
" + args_.input); 623 | } 624 | 625 | dict_.readFromFile(args_.input); 626 | threadFileSize = Utils.sizeLine(args_.input); 627 | 628 | if (!Utils.isEmpty(args_.pretrainedVectors)) { 629 | loadVectors(args_.pretrainedVectors); 630 | } else { 631 | input_ = new Matrix(dict_.nwords() + args_.bucket, args_.dim); 632 | input_.uniform(1.0f / args_.dim); 633 | } 634 | 635 | if (args_.model == model_name.sup) { 636 | output_ = new Matrix(dict_.nlabels(), args_.dim); 637 | } else { 638 | output_ = new Matrix(dict_.nwords(), args_.dim); 639 | } 640 | output_.zero(); 641 | 642 | start_ = System.currentTimeMillis(); 643 | tokenCount_ = new AtomicLong(0); 644 | long t0 = System.currentTimeMillis(); 645 | threadCount = args_.thread; 646 | for (int i = 0; i < args_.thread; i++) { 647 | Thread t = new TrainThread(this, i); 648 | t.setUncaughtExceptionHandler(trainThreadExcpetionHandler); 649 | t.start(); 650 | } 651 | 652 | synchronized (this) { 653 | while (threadCount > 0) { 654 | try { 655 | wait(); 656 | } catch (InterruptedException ignored) { 657 | } 658 | } 659 | } 660 | 661 | model_ = new Model(input_, output_, args_, 0); 662 | 663 | if (args.verbose > 1) { 664 | long trainTime = (System.currentTimeMillis() - t0) / 1000; 665 | System.out.printf("\nTrain time used: %d sec\n", trainTime); 666 | } 667 | 668 | saveModel(); 669 | if (args_.model != model_name.sup) { 670 | saveVectors(); 671 | } 672 | } 673 | 674 | protected Thread.UncaughtExceptionHandler trainThreadExcpetionHandler = new Thread.UncaughtExceptionHandler() { 675 | public void uncaughtException(Thread th, Throwable ex) { 676 | ex.printStackTrace(); 677 | } 678 | }; 679 | 680 | public Args getArgs() { 681 | return args_; 682 | } 683 | 684 | public Dictionary getDict() { 685 | return dict_; 686 | } 687 | 688 | public Matrix getInput() { 689 | return input_; 690 | } 691 | 692 | public Matrix getOutput() { 693 | return output_; 694 | } 695 | 696 | public Model getModel() { 697 | return model_; 698 | } 699 | 700 | public void setArgs(Args args) { 701 | this.args_ = args; 702 | } 703 | 704 | public void setDict(Dictionary dict) { 705 | this.dict_ = dict; 706 | } 707 | 708 | public void setInput(Matrix input) { 709 | this.input_ = input; 710 | } 711 | 712 | public void setOutput(Matrix output) { 713 | this.output_ = output; 714 | } 715 | 716 | public void setModel(Model model) { 717 | this.model_ = model; 718 | } 719 | 720 | public String getCharsetName() { 721 | return charsetName_; 722 | } 723 | 724 | public Class getLineReaderClass() { 725 | return lineReaderClass_; 726 | } 727 | 728 | public void setCharsetName(String charsetName) { 729 | this.charsetName_ = charsetName; 730 | } 731 | 732 | public void setLineReaderClass(Class lineReaderClass) { 733 | this.lineReaderClass_ = lineReaderClass; 734 | } 735 | 736 | } 737 | -------------------------------------------------------------------------------- /src/main/java/fasttext/IOUtil.java: -------------------------------------------------------------------------------- 1 | package fasttext; 2 | 3 | import java.io.IOException; 4 | import java.io.InputStream; 5 | import java.nio.ByteBuffer; 6 | import java.nio.ByteOrder; 7 | 8 | /** 9 | * Read/write cpp primitive type 10 | * 11 | * @author Ivan 12 | * 13 | */ 14 | public class IOUtil { 15 | 16 | public IOUtil() { 17 | } 18 | 19 | private int string_buf_size_ = 128; 20 | private byte[] int_bytes_ = new byte[4]; 21 | private byte[] long_bytes_ = new byte[8]; 22 | private byte[] float_bytes_ = new byte[4]; 23 | private byte[] double_bytes_ = new 
byte[8]; 24 | private byte[] string_bytes_ = new byte[string_buf_size_]; 25 | private StringBuilder stringBuilder_ = new StringBuilder(); 26 | private ByteBuffer float_array_bytebuffer_ = null; 27 | private byte[] float_array_bytes_ = null; 28 | 29 | public void setStringBufferSize(int size) { 30 | string_buf_size_ = size; 31 | string_bytes_ = new byte[string_buf_size_]; 32 | } 33 | 34 | public void setFloatArrayBufferSize(int itemSize) { 35 | float_array_bytebuffer_ = ByteBuffer.allocate(itemSize * 4).order(ByteOrder.LITTLE_ENDIAN); 36 | float_array_bytes_ = new byte[itemSize * 4]; 37 | } 38 | 39 | public int readByte(InputStream is) throws IOException { 40 | return is.read() & 0xFF; 41 | } 42 | 43 | public int readInt(InputStream is) throws IOException { 44 | is.read(int_bytes_); 45 | return getInt(int_bytes_); 46 | } 47 | 48 | public int getInt(byte[] b) { 49 | return (b[0] & 0xFF) << 0 | (b[1] & 0xFF) << 8 | (b[2] & 0xFF) << 16 | (b[3] & 0xFF) << 24; 50 | } 51 | 52 | public long readLong(InputStream is) throws IOException { 53 | is.read(long_bytes_); 54 | return getLong(long_bytes_); 55 | } 56 | 57 | public long getLong(byte[] b) { 58 | return (b[0] & 0xFFL) << 0 | (b[1] & 0xFFL) << 8 | (b[2] & 0xFFL) << 16 | (b[3] & 0xFFL) << 24 59 | | (b[4] & 0xFFL) << 32 | (b[5] & 0xFFL) << 40 | (b[6] & 0xFFL) << 48 | (b[7] & 0xFFL) << 56; 60 | } 61 | 62 | public float readFloat(InputStream is) throws IOException { 63 | is.read(float_bytes_); 64 | return getFloat(float_bytes_); 65 | } 66 | 67 | public void readFloat(InputStream is, float[] data) throws IOException { 68 | is.read(float_array_bytes_); 69 | float_array_bytebuffer_.clear(); 70 | ((ByteBuffer) float_array_bytebuffer_.put(float_array_bytes_).flip()).asFloatBuffer().get(data); 71 | } 72 | 73 | public float getFloat(byte[] b) { 74 | return Float 75 | .intBitsToFloat((b[0] & 0xFF) << 0 | (b[1] & 0xFF) << 8 | (b[2] & 0xFF) << 16 | (b[3] & 0xFF) << 24); 76 | } 77 | 78 | public double readDouble(InputStream is) throws IOException { 79 | is.read(double_bytes_); 80 | return getDouble(double_bytes_); 81 | } 82 | 83 | public double getDouble(byte[] b) { 84 | return Double.longBitsToDouble(getLong(b)); 85 | } 86 | 87 | public String readString(InputStream is) throws IOException { 88 | int b = is.read(); 89 | if (b < 0) { 90 | return null; 91 | } 92 | int i = -1; 93 | stringBuilder_.setLength(0); 94 | // ascii space, \n, \0 95 | while (b > -1 && b != 32 && b != 10 && b != 0) { 96 | string_bytes_[++i] = (byte) b; 97 | b = is.read(); 98 | if (i == string_buf_size_ - 1) { 99 | stringBuilder_.append(new String(string_bytes_/*, "utf8"*/)); 100 | i = -1; 101 | } 102 | } 103 | stringBuilder_.append(new String(string_bytes_, 0, i + 1/*, "utf8"*/)); 104 | return stringBuilder_.toString(); 105 | } 106 | 107 | public int intToByte(int i) { 108 | return (i & 0xFF); 109 | } 110 | 111 | public byte[] intToByteArray(int i) { 112 | int_bytes_[0] = (byte) ((i >> 0) & 0xff); 113 | int_bytes_[1] = (byte) ((i >> 8) & 0xff); 114 | int_bytes_[2] = (byte) ((i >> 16) & 0xff); 115 | int_bytes_[3] = (byte) ((i >> 24) & 0xff); 116 | return int_bytes_; 117 | } 118 | 119 | public byte[] longToByteArray(long l) { 120 | long_bytes_[0] = (byte) ((l >> 0) & 0xff); 121 | long_bytes_[1] = (byte) ((l >> 8) & 0xff); 122 | long_bytes_[2] = (byte) ((l >> 16) & 0xff); 123 | long_bytes_[3] = (byte) ((l >> 24) & 0xff); 124 | long_bytes_[4] = (byte) ((l >> 32) & 0xff); 125 | long_bytes_[5] = (byte) ((l >> 40) & 0xff); 126 | long_bytes_[6] = (byte) ((l >> 48) & 0xff); 127 | long_bytes_[7] = 
(byte) ((l >> 56) & 0xff); 128 | 129 | return long_bytes_; 130 | } 131 | 132 | public byte[] floatToByteArray(float f) { 133 | return intToByteArray(Float.floatToIntBits(f)); 134 | } 135 | 136 | public byte[] floatToByteArray(float[] f) { 137 | float_array_bytebuffer_.clear(); 138 | float_array_bytebuffer_.asFloatBuffer().put(f); 139 | return float_array_bytebuffer_.array(); 140 | } 141 | 142 | public byte[] doubleToByteArray(double d) { 143 | return longToByteArray(Double.doubleToRawLongBits(d)); 144 | } 145 | 146 | } 147 | -------------------------------------------------------------------------------- /src/main/java/fasttext/Main.java: -------------------------------------------------------------------------------- 1 | package fasttext; 2 | 3 | import java.io.File; 4 | import java.io.FileInputStream; 5 | import java.io.IOException; 6 | 7 | public class Main { 8 | 9 | public static void printUsage() { 10 | System.out.print("usage: java -jar fasttext.jar \n\n" 11 | + "The commands supported by fasttext are:\n\n" 12 | + " supervised train a supervised classifier\n" 13 | + " test evaluate a supervised classifier\n" 14 | + " predict predict most likely labels\n" 15 | + " predict-prob predict most likely labels with probabilities\n" 16 | + " skipgram train a skipgram model\n" 17 | + " cbow train a cbow model\n" 18 | + " print-vectors print vectors given a trained model\n"); 19 | } 20 | 21 | public static void printTestUsage() { 22 | System.out.print("usage: java -jar fasttext.jar test []\n\n" 23 | + " model filename\n" 24 | + " test data filename (if -, read from stdin)\n" 25 | + " (optional; 1 by default) predict top k labels\n"); 26 | } 27 | 28 | public static void printPredictUsage() { 29 | System.out.print("usage: java -jar fasttext.jar predict[-prob] []\n\n" 30 | + " model filename\n" 31 | + " test data filename (if -, read from stdin)\n" 32 | + " (optional; 1 by default) predict top k labels\n"); 33 | } 34 | 35 | public static void printPrintVectorsUsage() { 36 | System.out.print("usage: java -jar fasttext.jar print-vectors \n\n" 37 | + " model filename\n"); 38 | } 39 | 40 | public void test(String[] args) throws IOException, Exception { 41 | int k = 1; 42 | if (args.length == 3) { 43 | k = 1; 44 | } else if (args.length == 4) { 45 | k = Integer.parseInt(args[3]); 46 | } else { 47 | printTestUsage(); 48 | System.exit(1); 49 | } 50 | FastText fasttext = new FastText(); 51 | fasttext.loadModel(args[1]); 52 | String infile = args[2]; 53 | if ("-".equals(infile)) { 54 | fasttext.test(System.in, k); 55 | } else { 56 | File file = new File(infile); 57 | if (!(file.exists() && file.isFile() && file.canRead())) { 58 | throw new IOException("Test file cannot be opened!"); 59 | } 60 | fasttext.test(new FileInputStream(file), k); 61 | } 62 | } 63 | 64 | public void predict(String[] args) throws IOException, Exception { 65 | int k = 1; 66 | if (args.length == 3) { 67 | k = 1; 68 | } else if (args.length == 4) { 69 | k = Integer.parseInt(args[3]); 70 | } else { 71 | printPredictUsage(); 72 | System.exit(1); 73 | } 74 | boolean print_prob = "predict-prob".equalsIgnoreCase(args[0]); 75 | FastText fasttext = new FastText(); 76 | fasttext.loadModel(args[1]); 77 | 78 | String infile = args[2]; 79 | if ("-".equals(infile)) { 80 | fasttext.predict(System.in, k, print_prob); 81 | } else { 82 | File file = new File(infile); 83 | if (!(file.exists() && file.isFile() && file.canRead())) { 84 | throw new IOException("Input file cannot be opened!"); 85 | } 86 | fasttext.predict(new FileInputStream(file), k, 
print_prob); 87 | } 88 | } 89 | 90 | public void printVectors(String[] args) throws IOException { 91 | if (args.length != 2) { 92 | printPrintVectorsUsage(); 93 | System.exit(1); 94 | } 95 | FastText fasttext = new FastText(); 96 | fasttext.loadModel(args[1]); 97 | fasttext.printVectors(); 98 | } 99 | 100 | public void train(String[] args) throws IOException, Exception { 101 | Args a = new Args(); 102 | a.parseArgs(args); 103 | FastText fasttext = new FastText(); 104 | fasttext.train(a); 105 | } 106 | 107 | public static void main(String[] args) { 108 | Main op = new Main(); 109 | 110 | if (args.length == 0) { 111 | printUsage(); 112 | System.exit(1); 113 | } 114 | 115 | try { 116 | String command = args[0]; 117 | if ("skipgram".equalsIgnoreCase(command) || "cbow".equalsIgnoreCase(command) 118 | || "supervised".equalsIgnoreCase(command)) { 119 | op.train(args); 120 | } else if ("test".equalsIgnoreCase(command)) { 121 | op.test(args); 122 | } else if ("print-vectors".equalsIgnoreCase(command)) { 123 | op.printVectors(args); 124 | } else if ("predict".equalsIgnoreCase(command) || "predict-prob".equalsIgnoreCase(command)) { 125 | op.predict(args); 126 | } else { 127 | printUsage(); 128 | System.exit(1); 129 | } 130 | } catch (Exception e) { 131 | e.printStackTrace(); 132 | System.exit(1); 133 | } 134 | 135 | System.exit(0); 136 | } 137 | 138 | } 139 | -------------------------------------------------------------------------------- /src/main/java/fasttext/Matrix.java: -------------------------------------------------------------------------------- 1 | package fasttext; 2 | 3 | import java.io.IOException; 4 | import java.io.InputStream; 5 | import java.io.OutputStream; 6 | import java.util.Random; 7 | 8 | public strictfp class Matrix { 9 | 10 | public float[][] data_ = null; 11 | public int m_ = 0; // vocabSize 12 | public int n_ = 0; // layer1Size 13 | 14 | public Matrix() { 15 | } 16 | 17 | public Matrix(int m, int n) { 18 | m_ = m; 19 | n_ = n; 20 | data_ = new float[m][n]; 21 | } 22 | 23 | public Matrix(final Matrix other) { 24 | m_ = other.m_; 25 | n_ = other.n_; 26 | data_ = new float[m_][n_]; 27 | for (int i = 0; i < m_; i++) { 28 | for (int j = 0; j < n_; j++) { 29 | data_[i][j] = other.data_[i][j]; 30 | } 31 | } 32 | } 33 | 34 | public void zero() { 35 | for (int i = 0; i < m_; i++) { 36 | for (int j = 0; j < n_; j++) { 37 | data_[i][j] = 0.0f; 38 | } 39 | } 40 | } 41 | 42 | public void uniform(float a) { 43 | Random random = new Random(1l); 44 | for (int i = 0; i < m_; i++) { 45 | for (int j = 0; j < n_; j++) { 46 | data_[i][j] = Utils.randomFloat(random, -a, a); 47 | } 48 | } 49 | } 50 | 51 | public void addRow(final Vector vec, int i, float a) { 52 | Utils.checkArgument(i >= 0); 53 | Utils.checkArgument(i < m_); 54 | Utils.checkArgument(vec.m_ == n_); 55 | for (int j = 0; j < n_; j++) { 56 | data_[i][j] += a * vec.data_[j]; 57 | } 58 | } 59 | 60 | public float dotRow(final Vector vec, int i) { 61 | Utils.checkArgument(i >= 0); 62 | Utils.checkArgument(i < m_); 63 | Utils.checkArgument(vec.m_ == n_); 64 | float d = 0.0f; 65 | for (int j = 0; j < n_; j++) { 66 | d += data_[i][j] * vec.data_[j]; 67 | } 68 | return d; 69 | } 70 | 71 | public void load(InputStream input) throws IOException { 72 | IOUtil ioutil = new IOUtil(); 73 | 74 | m_ = (int) ioutil.readLong(input); 75 | n_ = (int) ioutil.readLong(input); 76 | 77 | ioutil.setFloatArrayBufferSize(n_); 78 | data_ = new float[m_][n_]; 79 | for (int i = 0; i < m_; i++) { 80 | ioutil.readFloat(input, data_[i]); 81 | } 82 | } 83 | 84 | 
public void save(OutputStream ofs) throws IOException { 85 | IOUtil ioutil = new IOUtil(); 86 | ioutil.setFloatArrayBufferSize(n_); 87 | ofs.write(ioutil.longToByteArray(m_)); 88 | ofs.write(ioutil.longToByteArray(n_)); 89 | for (int i = 0; i < m_; i++) { 90 | ofs.write(ioutil.floatToByteArray(data_[i])); 91 | } 92 | } 93 | 94 | @Override 95 | public String toString() { 96 | StringBuilder builder = new StringBuilder(); 97 | builder.append("Matrix [data_="); 98 | if (data_ != null) { 99 | builder.append("["); 100 | for (int i = 0; i < m_ && i < 10; i++) { 101 | for (int j = 0; j < n_ && j < 10; j++) { 102 | builder.append(data_[i][j]).append(","); 103 | } 104 | } 105 | builder.setLength(builder.length() - 1); 106 | builder.append("]"); 107 | } else { 108 | builder.append("null"); 109 | } 110 | builder.append(", m_="); 111 | builder.append(m_); 112 | builder.append(", n_="); 113 | builder.append(n_); 114 | builder.append("]"); 115 | return builder.toString(); 116 | } 117 | 118 | } 119 | -------------------------------------------------------------------------------- /src/main/java/fasttext/Model.java: -------------------------------------------------------------------------------- 1 | package fasttext; 2 | 3 | import java.util.ArrayList; 4 | import java.util.Collections; 5 | import java.util.Comparator; 6 | import java.util.List; 7 | import java.util.Random; 8 | 9 | import fasttext.Args.loss_name; 10 | import fasttext.Args.model_name; 11 | 12 | public strictfp class Model { 13 | 14 | static final int SIGMOID_TABLE_SIZE = 512; 15 | static final int MAX_SIGMOID = 8; 16 | static final int LOG_TABLE_SIZE = 512; 17 | 18 | static final int NEGATIVE_TABLE_SIZE = 10000000; 19 | 20 | public class Node { 21 | int parent; 22 | int left; 23 | int right; 24 | long count; 25 | boolean binary; 26 | } 27 | 28 | private Matrix wi_; // input 29 | private Matrix wo_; // output 30 | private Args args_; 31 | private Vector hidden_; 32 | private Vector output_; 33 | private Vector grad_; 34 | private int hsz_; // dim 35 | @SuppressWarnings("unused") 36 | private int isz_; // input vocabSize 37 | private int osz_; // output vocabSize 38 | private float loss_; 39 | private long nexamples_; 40 | private float[] t_sigmoid; 41 | private float[] t_log; 42 | // used for negative sampling: 43 | private List negatives; 44 | private int negpos; 45 | // used for hierarchical softmax: 46 | private List> paths; 47 | private List> codes; 48 | private List tree; 49 | 50 | public transient Random rng; 51 | 52 | public Model(Matrix wi, Matrix wo, Args args, int seed) { 53 | hidden_ = new Vector(args.dim); 54 | output_ = new Vector(wo.m_); 55 | grad_ = new Vector(args.dim); 56 | rng = new Random((long) seed); 57 | 58 | wi_ = wi; 59 | wo_ = wo; 60 | args_ = args; 61 | isz_ = wi.m_; 62 | osz_ = wo.m_; 63 | hsz_ = args.dim; 64 | negpos = 0; 65 | loss_ = 0.0f; 66 | nexamples_ = 1l; 67 | initSigmoid(); 68 | initLog(); 69 | } 70 | 71 | public float binaryLogistic(int target, boolean label, float lr) { 72 | float score = sigmoid(wo_.dotRow(hidden_, target)); 73 | float alpha = lr * ((label ? 
1.0f : 0.0f) - score); 74 | grad_.addRow(wo_, target, alpha); 75 | wo_.addRow(hidden_, target, alpha); 76 | if (label) { 77 | return -log(score); 78 | } else { 79 | return -log(1.0f - score); 80 | } 81 | } 82 | 83 | public float negativeSampling(int target, float lr) { 84 | float loss = 0.0f; 85 | grad_.zero(); 86 | for (int n = 0; n <= args_.neg; n++) { 87 | if (n == 0) { 88 | loss += binaryLogistic(target, true, lr); 89 | } else { 90 | loss += binaryLogistic(getNegative(target), false, lr); 91 | } 92 | } 93 | return loss; 94 | } 95 | 96 | public float hierarchicalSoftmax(int target, float lr) { 97 | float loss = 0.0f; 98 | grad_.zero(); 99 | final List binaryCode = codes.get(target); 100 | final List pathToRoot = paths.get(target); 101 | for (int i = 0; i < pathToRoot.size(); i++) { 102 | loss += binaryLogistic(pathToRoot.get(i), binaryCode.get(i), lr); 103 | } 104 | return loss; 105 | } 106 | 107 | public void computeOutputSoftmax(Vector hidden, Vector output) { 108 | output.mul(wo_, hidden); 109 | float max = output.get(0), z = 0.0f; 110 | for (int i = 1; i < osz_; i++) { 111 | max = Math.max(output.get(i), max); 112 | } 113 | for (int i = 0; i < osz_; i++) { 114 | output.set(i, (float) Math.exp(output.get(i) - max)); 115 | z += output.get(i); 116 | } 117 | for (int i = 0; i < osz_; i++) { 118 | output.set(i, output.get(i) / z); 119 | } 120 | } 121 | 122 | public void computeOutputSoftmax() { 123 | computeOutputSoftmax(hidden_, output_); 124 | } 125 | 126 | public float softmax(int target, float lr) { 127 | grad_.zero(); 128 | computeOutputSoftmax(); 129 | for (int i = 0; i < osz_; i++) { 130 | float label = (i == target) ? 1.0f : 0.0f; 131 | float alpha = lr * (label - output_.get(i)); 132 | grad_.addRow(wo_, i, alpha); 133 | wo_.addRow(hidden_, i, alpha); 134 | } 135 | return -log(output_.get(target)); 136 | } 137 | 138 | public void computeHidden(final List input, Vector hidden) { 139 | Utils.checkArgument(hidden.size() == hsz_); 140 | hidden.zero(); 141 | for (Integer it : input) { 142 | hidden.addRow(wi_, it); 143 | } 144 | hidden.mul(1.0f / input.size()); 145 | } 146 | 147 | private Comparator> comparePairs = new Comparator>() { 148 | 149 | @Override 150 | public int compare(Pair o1, Pair o2) { 151 | return o2.getKey().compareTo(o1.getKey()); 152 | } 153 | }; 154 | 155 | public void predict(final List input, int k, List> heap, Vector hidden, 156 | Vector output) { 157 | Utils.checkArgument(k > 0); 158 | if (heap instanceof ArrayList) { 159 | ((ArrayList>) heap).ensureCapacity(k + 1); 160 | } 161 | computeHidden(input, hidden); 162 | if (args_.loss == loss_name.hs) { 163 | dfs(k, 2 * osz_ - 2, 0.0f, heap, hidden); 164 | } else { 165 | findKBest(k, heap, hidden, output); 166 | } 167 | Collections.sort(heap, comparePairs); 168 | } 169 | 170 | public void predict(final List input, int k, List> heap) { 171 | predict(input, k, heap, hidden_, output_); 172 | } 173 | 174 | public void findKBest(int k, List> heap, Vector hidden, Vector output) { 175 | computeOutputSoftmax(hidden, output); 176 | for (int i = 0; i < osz_; i++) { 177 | if (heap.size() == k && log(output.get(i)) < heap.get(heap.size() - 1).getKey()) { 178 | continue; 179 | } 180 | heap.add(new Pair(log(output.get(i)), i)); 181 | Collections.sort(heap, comparePairs); 182 | if (heap.size() > k) { 183 | Collections.sort(heap, comparePairs); 184 | heap.remove(heap.size() - 1); // pop last 185 | } 186 | } 187 | } 188 | 189 | public void dfs(int k, int node, float score, List> heap, Vector hidden) { 190 | if (heap.size() == k && 
score < heap.get(heap.size() - 1).getKey()) { 191 | return; 192 | } 193 | 194 | if (tree.get(node).left == -1 && tree.get(node).right == -1) { 195 | heap.add(new Pair(score, node)); 196 | Collections.sort(heap, comparePairs); 197 | if (heap.size() > k) { 198 | Collections.sort(heap, comparePairs); 199 | heap.remove(heap.size() - 1); // pop last 200 | } 201 | return; 202 | } 203 | 204 | float f = sigmoid(wo_.dotRow(hidden, node - osz_)); 205 | dfs(k, tree.get(node).left, score + log(1.0f - f), heap, hidden); 206 | dfs(k, tree.get(node).right, score + log(f), heap, hidden); 207 | } 208 | 209 | public void update(final List input, int target, float lr) { 210 | Utils.checkArgument(target >= 0); 211 | Utils.checkArgument(target < osz_); 212 | if (input.size() == 0) { 213 | return; 214 | } 215 | computeHidden(input, hidden_); 216 | 217 | if (args_.loss == loss_name.ns) { 218 | loss_ += negativeSampling(target, lr); 219 | } else if (args_.loss == loss_name.hs) { 220 | loss_ += hierarchicalSoftmax(target, lr); 221 | } else { 222 | loss_ += softmax(target, lr); 223 | } 224 | nexamples_ += 1; 225 | 226 | if (args_.model == model_name.sup) { 227 | grad_.mul(1.0f / input.size()); 228 | } 229 | for (Integer it : input) { 230 | wi_.addRow(grad_, it, 1.0f); 231 | } 232 | } 233 | 234 | public void setTargetCounts(final List counts) { 235 | Utils.checkArgument(counts.size() == osz_); 236 | if (args_.loss == loss_name.ns) { 237 | initTableNegatives(counts); 238 | } 239 | if (args_.loss == loss_name.hs) { 240 | buildTree(counts); 241 | } 242 | } 243 | 244 | public void initTableNegatives(final List counts) { 245 | negatives = new ArrayList(counts.size()); 246 | float z = 0.0f; 247 | for (int i = 0; i < counts.size(); i++) { 248 | z += (float) Math.pow(counts.get(i), 0.5f); 249 | } 250 | for (int i = 0; i < counts.size(); i++) { 251 | float c = (float) Math.pow(counts.get(i), 0.5f); 252 | for (int j = 0; j < c * NEGATIVE_TABLE_SIZE / z; j++) { 253 | negatives.add(i); 254 | } 255 | } 256 | Utils.shuffle(negatives, rng); 257 | } 258 | 259 | public int getNegative(int target) { 260 | int negative; 261 | do { 262 | negative = negatives.get(negpos); 263 | negpos = (negpos + 1) % negatives.size(); 264 | } while (target == negative); 265 | return negative; 266 | } 267 | 268 | public void buildTree(final List counts) { 269 | paths = new ArrayList>(osz_); 270 | codes = new ArrayList>(osz_); 271 | tree = new ArrayList(2 * osz_ - 1); 272 | 273 | for (int i = 0; i < 2 * osz_ - 1; i++) { 274 | Node node = new Node(); 275 | node.parent = -1; 276 | node.left = -1; 277 | node.right = -1; 278 | node.count = 1000000000000000L;// 1e15f; 279 | node.binary = false; 280 | tree.add(i, node); 281 | } 282 | for (int i = 0; i < osz_; i++) { 283 | tree.get(i).count = counts.get(i); 284 | } 285 | int leaf = osz_ - 1; 286 | int node = osz_; 287 | for (int i = osz_; i < 2 * osz_ - 1; i++) { 288 | int[] mini = new int[2]; 289 | for (int j = 0; j < 2; j++) { 290 | if (leaf >= 0 && tree.get(leaf).count < tree.get(node).count) { 291 | mini[j] = leaf--; 292 | } else { 293 | mini[j] = node++; 294 | } 295 | } 296 | tree.get(i).left = mini[0]; 297 | tree.get(i).right = mini[1]; 298 | tree.get(i).count = tree.get(mini[0]).count + tree.get(mini[1]).count; 299 | tree.get(mini[0]).parent = i; 300 | tree.get(mini[1]).parent = i; 301 | tree.get(mini[1]).binary = true; 302 | } 303 | for (int i = 0; i < osz_; i++) { 304 | List path = new ArrayList(); 305 | List code = new ArrayList(); 306 | int j = i; 307 | while (tree.get(j).parent != -1) { 308 | 
path.add(tree.get(j).parent - osz_); 309 | code.add(tree.get(j).binary); 310 | j = tree.get(j).parent; 311 | } 312 | paths.add(path); 313 | codes.add(code); 314 | } 315 | } 316 | 317 | public float getLoss() { 318 | return loss_ / nexamples_; 319 | } 320 | 321 | private void initSigmoid() { 322 | t_sigmoid = new float[SIGMOID_TABLE_SIZE + 1]; 323 | for (int i = 0; i < SIGMOID_TABLE_SIZE + 1; i++) { 324 | float x = (float) (i * 2 * MAX_SIGMOID) / SIGMOID_TABLE_SIZE - MAX_SIGMOID; 325 | t_sigmoid[i] = (float) (1.0f / (1.0f + Math.exp(-x))); 326 | } 327 | } 328 | 329 | private void initLog() { 330 | t_log = new float[LOG_TABLE_SIZE + 1]; 331 | for (int i = 0; i < LOG_TABLE_SIZE + 1; i++) { 332 | float x = (float) (((float) (i) + 1e-5f) / LOG_TABLE_SIZE); 333 | t_log[i] = (float) Math.log(x); 334 | } 335 | } 336 | 337 | public float log(float x) { 338 | if (x > 1.0f) { 339 | return 0.0f; 340 | } 341 | int i = (int) (x * LOG_TABLE_SIZE); 342 | return t_log[i]; 343 | } 344 | 345 | public float sigmoid(float x) { 346 | if (x < -MAX_SIGMOID) { 347 | return 0.0f; 348 | } else if (x > MAX_SIGMOID) { 349 | return 1.0f; 350 | } else { 351 | int i = (int) ((x + MAX_SIGMOID) * SIGMOID_TABLE_SIZE / MAX_SIGMOID / 2); 352 | return t_sigmoid[i]; 353 | } 354 | } 355 | } 356 | -------------------------------------------------------------------------------- /src/main/java/fasttext/Pair.java: -------------------------------------------------------------------------------- 1 | package fasttext; 2 | 3 | public class Pair { 4 | 5 | private K key_; 6 | private V value_; 7 | 8 | public Pair(K key, V value) { 9 | this.key_ = key; 10 | this.value_ = value; 11 | } 12 | 13 | public K getKey() { 14 | return key_; 15 | } 16 | 17 | public V getValue() { 18 | return value_; 19 | } 20 | 21 | public void setKey(K key) { 22 | this.key_ = key; 23 | } 24 | 25 | public void setValue(V value) { 26 | this.value_ = value; 27 | } 28 | 29 | } 30 | -------------------------------------------------------------------------------- /src/main/java/fasttext/Utils.java: -------------------------------------------------------------------------------- 1 | package fasttext; 2 | 3 | import java.io.BufferedInputStream; 4 | import java.io.BufferedReader; 5 | import java.io.FileInputStream; 6 | import java.io.IOException; 7 | import java.io.InputStream; 8 | import java.util.List; 9 | import java.util.ListIterator; 10 | import java.util.Map; 11 | import java.util.Random; 12 | import java.util.RandomAccess; 13 | 14 | public strictfp class Utils { 15 | 16 | /** 17 | * Ensures the truth of an expression involving one or more parameters to 18 | * the calling method. 19 | * 20 | * @param expression 21 | * a boolean expression 22 | * @throws IllegalArgumentException 23 | * if {@code expression} is false 24 | */ 25 | public static void checkArgument(boolean expression) { 26 | if (!expression) { 27 | throw new IllegalArgumentException(); 28 | } 29 | } 30 | 31 | public static void checkArgument(boolean expression, String message) { 32 | if (!expression) { 33 | throw new IllegalArgumentException(message); 34 | } 35 | } 36 | 37 | public static boolean isEmpty(String str) { 38 | return (str == null || str.isEmpty()); 39 | } 40 | 41 | public static V mapGetOrDefault(Map map, K key, V defaultValue) { 42 | return map.containsKey(key) ? 
map.get(key) : defaultValue; 43 | } 44 | 45 | public static int randomInt(Random rnd, int lower, int upper) { 46 | checkArgument(lower <= upper & lower > 0); 47 | if (lower == upper) { 48 | return lower; 49 | } 50 | return rnd.nextInt(upper - lower) + lower; 51 | } 52 | 53 | public static float randomFloat(Random rnd, float lower, float upper) { 54 | checkArgument(lower <= upper); 55 | if (lower == upper) { 56 | return lower; 57 | } 58 | return (rnd.nextFloat() * (upper - lower)) + lower; 59 | } 60 | 61 | public static long sizeLine(String filename) throws IOException { 62 | InputStream is = new BufferedInputStream(new FileInputStream(filename)); 63 | try { 64 | byte[] c = new byte[1024]; 65 | long count = 0; 66 | int readChars = 0; 67 | boolean endsWithoutNewLine = false; 68 | while ((readChars = is.read(c)) != -1) { 69 | for (int i = 0; i < readChars; ++i) { 70 | if (c[i] == '\n') 71 | ++count; 72 | } 73 | endsWithoutNewLine = (c[readChars - 1] != '\n'); 74 | } 75 | if (endsWithoutNewLine) { 76 | ++count; 77 | } 78 | return count; 79 | } finally { 80 | is.close(); 81 | } 82 | } 83 | 84 | /** 85 | * 86 | * @param br 87 | * @param pos 88 | * line numbers start from 1 89 | * @throws IOException 90 | */ 91 | public static void seekLine(BufferedReader br, long pos) throws IOException { 92 | // br.reset(); 93 | String line; 94 | int currentLine = 1; 95 | while (currentLine < pos && (line = br.readLine()) != null) { 96 | if (Utils.isEmpty(line) || line.startsWith("#")) { 97 | continue; 98 | } 99 | currentLine++; 100 | } 101 | } 102 | 103 | private static final int SHUFFLE_THRESHOLD = 5; 104 | 105 | @SuppressWarnings({ "rawtypes", "unchecked" }) 106 | public static void shuffle(List list, Random rnd) { 107 | int size = list.size(); 108 | if (size < SHUFFLE_THRESHOLD || list instanceof RandomAccess) { 109 | for (int i = size; i > 1; i--) 110 | swap(list, i - 1, rnd.nextInt(i)); 111 | } else { 112 | Object arr[] = list.toArray(); 113 | 114 | // Shuffle array 115 | for (int i = size; i > 1; i--) 116 | swap(arr, i - 1, rnd.nextInt(i)); 117 | 118 | // Dump array back into list 119 | // instead of using a raw type here, it's possible to capture 120 | // the wildcard but it will require a call to a supplementary 121 | // private method 122 | ListIterator it = list.listIterator(); 123 | for (int i = 0; i < arr.length; i++) { 124 | it.next(); 125 | it.set(arr[i]); 126 | } 127 | } 128 | } 129 | 130 | /** 131 | * Swaps the two specified elements in the specified array. 
132 | */ 133 | public static void swap(Object[] arr, int i, int j) { 134 | Object tmp = arr[i]; 135 | arr[i] = arr[j]; 136 | arr[j] = tmp; 137 | } 138 | 139 | @SuppressWarnings({ "rawtypes", "unchecked" }) 140 | public static void swap(List list, int i, int j) { 141 | // instead of using a raw type here, it's possible to capture 142 | // the wildcard but it will require a call to a supplementary 143 | // private method 144 | final List l = list; 145 | l.set(i, l.set(j, l.get(i))); 146 | } 147 | 148 | } 149 | -------------------------------------------------------------------------------- /src/main/java/fasttext/Vector.java: -------------------------------------------------------------------------------- 1 | package fasttext; 2 | 3 | public strictfp class Vector { 4 | 5 | public int m_; 6 | public float[] data_; 7 | 8 | public void addVector(Vector source) { 9 | for (int i = 0; i < m_; i++) { 10 | data_[i] += source.data_[i]; 11 | } 12 | } 13 | 14 | public float norm() { 15 | float sum = 0; 16 | for (int i = 0; i < m_; i++) { 17 | sum += data_[i] * data_[i]; 18 | } 19 | return (float) Math.sqrt(sum); 20 | } 21 | 22 | public Vector(int size) { 23 | m_ = size; 24 | data_ = new float[size]; 25 | } 26 | 27 | public int size() { 28 | return m_; 29 | } 30 | 31 | public void zero() { 32 | for (int i = 0; i < m_; i++) { 33 | data_[i] = 0.0f; 34 | } 35 | } 36 | 37 | public void mul(float a) { 38 | for (int i = 0; i < m_; i++) { 39 | data_[i] *= a; 40 | } 41 | } 42 | 43 | public void addRow(final Matrix A, int i) { 44 | Utils.checkArgument(i >= 0); 45 | Utils.checkArgument(i < A.m_); 46 | Utils.checkArgument(m_ == A.n_); 47 | for (int j = 0; j < A.n_; j++) { // layer size 48 | data_[j] += A.data_[i][j]; 49 | } 50 | } 51 | 52 | public void addRow(final Matrix A, int i, float a) { 53 | Utils.checkArgument(i >= 0); 54 | Utils.checkArgument(i < A.m_); 55 | Utils.checkArgument(m_ == A.n_); 56 | for (int j = 0; j < A.n_; j++) { 57 | data_[j] += a * A.data_[i][j]; 58 | } 59 | } 60 | 61 | public void mul(final Matrix A, final Vector vec) { 62 | Utils.checkArgument(A.m_ == m_); 63 | Utils.checkArgument(A.n_ == vec.m_); 64 | for (int i = 0; i < m_; i++) { 65 | data_[i] = 0.0f; 66 | for (int j = 0; j < A.n_; j++) { 67 | data_[i] += A.data_[i][j] * vec.data_[j]; 68 | } 69 | } 70 | } 71 | 72 | public int argmax() { 73 | float max = data_[0]; 74 | int argmax = 0; 75 | for (int i = 1; i < m_; i++) { 76 | if (data_[i] > max) { 77 | max = data_[i]; 78 | argmax = i; 79 | } 80 | } 81 | return argmax; 82 | } 83 | 84 | public float get(int i) { 85 | return data_[i]; 86 | } 87 | 88 | public void set(int i, float value) { 89 | data_[i] = value; 90 | } 91 | 92 | @Override 93 | public String toString() { 94 | StringBuilder builder = new StringBuilder(); 95 | for (float data : data_) { 96 | builder.append(data).append(' '); 97 | } 98 | if (builder.length() > 1) { 99 | builder.setLength(builder.length() - 1); 100 | } 101 | return builder.toString(); 102 | } 103 | 104 | } 105 | -------------------------------------------------------------------------------- /src/main/java/fasttext/io/BufferedLineReader.java: -------------------------------------------------------------------------------- 1 | package fasttext.io; 2 | 3 | import java.io.BufferedReader; 4 | import java.io.FileInputStream; 5 | import java.io.IOException; 6 | import java.io.InputStream; 7 | import java.io.InputStreamReader; 8 | import java.io.UnsupportedEncodingException; 9 | 10 | public class BufferedLineReader extends LineReader { 11 | 12 | private String 
lineDelimitingRegex_ = " |\r|\t|\\v|\f|\0"; 13 | 14 | private BufferedReader br_; 15 | 16 | public BufferedLineReader(String filename, String charsetName) throws IOException, UnsupportedEncodingException { 17 | super(filename, charsetName); 18 | FileInputStream fis = new FileInputStream(file_); 19 | br_ = new BufferedReader(new InputStreamReader(fis, charset_)); 20 | } 21 | 22 | public BufferedLineReader(InputStream inputStream, String charsetName) throws UnsupportedEncodingException { 23 | super(inputStream, charsetName); 24 | br_ = new BufferedReader(new InputStreamReader(inputStream, charset_)); 25 | } 26 | 27 | @Override 28 | public long skipLine(long n) throws IOException { 29 | if (n < 0L) { 30 | throw new IllegalArgumentException("skip value is negative"); 31 | } 32 | String line; 33 | long currentLine = 0; 34 | long readLine = 0; 35 | synchronized (lock) { 36 | while (currentLine < n && (line = br_.readLine()) != null) { 37 | readLine++; 38 | if (line == null || line.isEmpty() || line.startsWith("#")) { 39 | continue; 40 | } 41 | currentLine++; 42 | } 43 | return readLine; 44 | } 45 | } 46 | 47 | @Override 48 | public String readLine() throws IOException { 49 | synchronized (lock) { 50 | String lineString = br_.readLine(); 51 | while (lineString != null && (lineString.isEmpty() || lineString.startsWith("#"))) { 52 | lineString = br_.readLine(); 53 | } 54 | return lineString; 55 | } 56 | } 57 | 58 | @Override 59 | public String[] readLineTokens() throws IOException { 60 | String line = readLine(); 61 | if (line == null) 62 | return null; 63 | else 64 | return line.split(lineDelimitingRegex_, -1); 65 | } 66 | 67 | @Override 68 | public int read(char[] cbuf, int off, int len) throws IOException { 69 | synchronized (lock) { 70 | return br_.read(cbuf, off, len); 71 | } 72 | } 73 | 74 | @Override 75 | public void close() throws IOException { 76 | synchronized (lock) { 77 | if (br_ != null) { 78 | br_.close(); 79 | } 80 | } 81 | } 82 | 83 | @Override 84 | public void rewind() throws IOException { 85 | synchronized (lock) { 86 | if (br_ != null) { 87 | br_.close(); 88 | } 89 | if (file_ != null) { 90 | FileInputStream fis = new FileInputStream(file_); 91 | br_ = new BufferedReader(new InputStreamReader(fis, charset_)); 92 | } else { 93 | // br = new BufferedReader(new InputStreamReader(inputStream, 94 | // charset)); 95 | throw new UnsupportedOperationException("InputStream rewind not supported"); 96 | } 97 | } 98 | } 99 | 100 | public String getLineDelimitingRege() { 101 | return lineDelimitingRegex_; 102 | } 103 | 104 | public void setLineDelimitingRegex(String lineDelimitingRegex) { 105 | this.lineDelimitingRegex_ = lineDelimitingRegex; 106 | } 107 | 108 | } 109 | -------------------------------------------------------------------------------- /src/main/java/fasttext/io/LineReader.java: -------------------------------------------------------------------------------- 1 | package fasttext.io; 2 | 3 | import java.io.File; 4 | import java.io.IOException; 5 | import java.io.InputStream; 6 | import java.io.Reader; 7 | import java.io.UnsupportedEncodingException; 8 | import java.nio.charset.Charset; 9 | 10 | public abstract class LineReader extends Reader { 11 | 12 | protected InputStream inputStream_ = null; 13 | protected File file_ = null; 14 | protected Charset charset_ = null; 15 | 16 | protected LineReader() { 17 | super(); 18 | } 19 | 20 | protected LineReader(Object lock) { 21 | super(lock); 22 | } 23 | 24 | public LineReader(String filename, String charsetName) throws IOException, 
UnsupportedEncodingException { 25 | this(); 26 | this.file_ = new File(filename); 27 | this.charset_ = Charset.forName(charsetName); 28 | } 29 | 30 | public LineReader(InputStream inputStream, String charsetName) throws UnsupportedEncodingException { 31 | this(); 32 | this.inputStream_ = inputStream; 33 | this.charset_ = Charset.forName(charsetName); 34 | } 35 | 36 | /** 37 | * Skips lines. 38 | * 39 | * @param n 40 | * The number of lines to skip 41 | * @return The number of lines actually skipped 42 | * @exception IOException 43 | * If an I/O error occurs 44 | * @exception IllegalArgumentException 45 | * If n is negative. 46 | */ 47 | public abstract long skipLine(long n) throws IOException; 48 | 49 | public abstract String readLine() throws IOException; 50 | 51 | public abstract String[] readLineTokens() throws IOException; 52 | 53 | public abstract void rewind() throws IOException; 54 | } 55 | -------------------------------------------------------------------------------- /src/main/java/fasttext/io/MappedByteBufferLineReader.java: -------------------------------------------------------------------------------- 1 | package fasttext.io; 2 | 3 | import java.io.BufferedInputStream; 4 | import java.io.IOException; 5 | import java.io.InputStream; 6 | import java.io.RandomAccessFile; 7 | import java.io.UnsupportedEncodingException; 8 | import java.nio.ByteBuffer; 9 | import java.nio.CharBuffer; 10 | import java.nio.channels.FileChannel; 11 | import java.util.ArrayList; 12 | import java.util.List; 13 | 14 | public class MappedByteBufferLineReader extends LineReader { 15 | 16 | private static int DEFAULT_BUFFER_SIZE = 1024; 17 | 18 | private volatile ByteBuffer byteBuffer_ = null; // MappedByteBuffer 19 | private RandomAccessFile raf_ = null; 20 | private FileChannel channel_ = null; 21 | private byte[] bytes_ = null; 22 | 23 | private int string_buf_size_ = DEFAULT_BUFFER_SIZE; 24 | private boolean fillLine_ = false; 25 | 26 | private StringBuilder sb_ = new StringBuilder(); 27 | private List tokens_ = new ArrayList(); 28 | 29 | public MappedByteBufferLineReader(String filename, String charsetName) 30 | throws IOException, UnsupportedEncodingException { 31 | super(filename, charsetName); 32 | raf_ = new RandomAccessFile(file_, "r"); 33 | channel_ = raf_.getChannel(); 34 | byteBuffer_ = channel_.map(FileChannel.MapMode.READ_ONLY, 0, channel_.size()); 35 | bytes_ = new byte[string_buf_size_]; 36 | } 37 | 38 | public MappedByteBufferLineReader(InputStream inputStream, String charsetName) throws UnsupportedEncodingException { 39 | this(inputStream, charsetName, DEFAULT_BUFFER_SIZE); 40 | } 41 | 42 | public MappedByteBufferLineReader(InputStream inputStream, String charsetName, int buf_size) 43 | throws UnsupportedEncodingException { 44 | super(inputStream instanceof BufferedInputStream ? 
inputStream : new BufferedInputStream(inputStream), 45 | charsetName); 46 | string_buf_size_ = buf_size; 47 | byteBuffer_ = ByteBuffer.allocateDirect(string_buf_size_); // ByteBuffer.allocate(string_buf_size_); 48 | bytes_ = new byte[string_buf_size_]; 49 | if (inputStream == System.in) { 50 | fillLine_ = true; 51 | } 52 | } 53 | 54 | @Override 55 | public long skipLine(long n) throws IOException { 56 | if (n < 0L) { 57 | throw new IllegalArgumentException("skip value is negative"); 58 | } 59 | String line; 60 | long currentLine = 0; 61 | long readLine = 0; 62 | synchronized (lock) { 63 | ensureOpen(); 64 | while (currentLine < n && (line = getLine()) != null) { 65 | readLine++; 66 | if (line == null || line.isEmpty() || line.startsWith("#")) { 67 | continue; 68 | } 69 | currentLine++; 70 | } 71 | } 72 | return readLine; 73 | } 74 | 75 | @Override 76 | public String readLine() throws IOException { 77 | synchronized (lock) { 78 | ensureOpen(); 79 | String lineString = getLine(); 80 | while (lineString != null && (lineString.isEmpty() || lineString.startsWith("#"))) { 81 | lineString = getLine(); 82 | } 83 | return lineString; 84 | } 85 | } 86 | 87 | @Override 88 | public String[] readLineTokens() throws IOException { 89 | synchronized (lock) { 90 | ensureOpen(); 91 | String[] tokens = getLineTokens(); 92 | while (tokens != null && ((tokens.length == 1 && tokens[0].isEmpty()) || tokens[0].startsWith("#"))) { 93 | tokens = getLineTokens(); 94 | } 95 | return tokens; 96 | } 97 | } 98 | 99 | @Override 100 | public void rewind() throws IOException { 101 | synchronized (lock) { 102 | ensureOpen(); 103 | if (raf_ != null) { 104 | raf_.seek(0); 105 | channel_.position(0); 106 | } 107 | byteBuffer_.position(0); 108 | } 109 | } 110 | 111 | @Override 112 | public int read(char[] cbuf, int off, int len) throws IOException { 113 | synchronized (lock) { 114 | ensureOpen(); 115 | if ((off < 0) || (off > cbuf.length) || (len < 0) || ((off + len) > cbuf.length) || ((off + len) < 0)) { 116 | throw new IndexOutOfBoundsException(); 117 | } else if (len == 0) { 118 | return 0; 119 | } 120 | 121 | CharBuffer charBuffer = byteBuffer_.asCharBuffer(); 122 | int length = Math.min(len, charBuffer.remaining()); 123 | charBuffer.get(cbuf, off, length); 124 | 125 | if (inputStream_ != null) { 126 | off += length; 127 | 128 | while (off < len) { 129 | fillByteBuffer(); 130 | if (!byteBuffer_.hasRemaining()) { 131 | break; 132 | } 133 | charBuffer = byteBuffer_.asCharBuffer(); 134 | length = Math.min(len, charBuffer.remaining()); 135 | charBuffer.get(cbuf, off, length); 136 | off += length; 137 | } 138 | } 139 | return length == len ? 
len : -1; 140 | } 141 | } 142 | 143 | @Override 144 | public void close() throws IOException { 145 | synchronized (lock) { 146 | if (raf_ != null) { 147 | raf_.close(); 148 | } else if (inputStream_ != null) { 149 | inputStream_.close(); 150 | } 151 | channel_ = null; 152 | byteBuffer_ = null; 153 | } 154 | } 155 | 156 | /** Checks to make sure that the stream has not been closed */ 157 | private void ensureOpen() throws IOException { 158 | if (byteBuffer_ == null) 159 | throw new IOException("Stream closed"); 160 | } 161 | 162 | protected String getLine() throws IOException { 163 | fillByteBuffer(); 164 | if (!byteBuffer_.hasRemaining()) { 165 | return null; 166 | } 167 | sb_.setLength(0); 168 | int b = -1; 169 | int i = -1; 170 | do { 171 | b = byteBuffer_.get(); 172 | if ((b >= 10 && b <= 13) || b == 0) { 173 | break; 174 | } 175 | bytes_[++i] = (byte) b; 176 | if (i == string_buf_size_ - 1) { 177 | sb_.append(new String(bytes_, charset_)); 178 | i = -1; 179 | } 180 | fillByteBuffer(); 181 | } while (byteBuffer_.hasRemaining()); 182 | 183 | sb_.append(new String(bytes_, 0, i + 1, charset_)); 184 | return sb_.toString(); 185 | } 186 | 187 | // " |\r|\t|\\v|\f|\0" 188 | // 32 ' ', 9 \t, 10 \n, 11 \\v, 12 \f, 13 \r, 0 \0 189 | protected String[] getLineTokens() throws IOException { 190 | fillByteBuffer(); 191 | if (!byteBuffer_.hasRemaining()) { 192 | return null; 193 | } 194 | tokens_.clear(); 195 | sb_.setLength(0); 196 | 197 | int b = -1; 198 | int i = -1; 199 | do { 200 | b = byteBuffer_.get(); 201 | 202 | if ((b >= 10 && b <= 13) || b == 0) { 203 | break; 204 | } else if (b == 9 || b == 32) { 205 | sb_.append(new String(bytes_, 0, i + 1, charset_)); 206 | tokens_.add(sb_.toString()); 207 | sb_.setLength(0); 208 | i = -1; 209 | } else { 210 | bytes_[++i] = (byte) b; 211 | if (i == string_buf_size_ - 1) { 212 | sb_.append(new String(bytes_, charset_)); 213 | i = -1; 214 | } 215 | } 216 | fillByteBuffer(); 217 | } while (byteBuffer_.hasRemaining()); 218 | 219 | sb_.append(new String(bytes_, 0, i + 1, charset_)); 220 | tokens_.add(sb_.toString()); 221 | return tokens_.toArray(new String[tokens_.size()]); 222 | } 223 | 224 | private void fillByteBuffer() throws IOException { 225 | if (inputStream_ == null || byteBuffer_.hasRemaining()) { 226 | return; 227 | } 228 | 229 | byteBuffer_.clear(); 230 | 231 | int b; 232 | for (int i = 0; i < string_buf_size_; i++) { 233 | b = inputStream_.read(); 234 | if (b < 0) { // END OF STREAM 235 | break; 236 | } 237 | byteBuffer_.put((byte) b); 238 | if (fillLine_) { 239 | if ((b >= 10 && b <= 13) || b == 0) { 240 | break; 241 | } 242 | } 243 | } 244 | 245 | byteBuffer_.flip(); 246 | } 247 | 248 | } 249 | -------------------------------------------------------------------------------- /src/test/java/fasttext/TestDictionary.java: -------------------------------------------------------------------------------- 1 | package fasttext; 2 | 3 | import static org.junit.Assert.*; 4 | 5 | import java.util.Map; 6 | 7 | import org.junit.Test; 8 | 9 | public class TestDictionary { 10 | 11 | Dictionary dictionary = new Dictionary(new Args()); 12 | 13 | @Test 14 | public void testHash() { 15 | assertEquals(dictionary.hash(","), 688690635l); 16 | assertEquals(dictionary.hash("is"), 1312329493l); 17 | assertEquals(dictionary.hash(""), 3617362777l); 18 | } 19 | 20 | @Test 21 | public void testFind() { 22 | assertEquals(dictionary.find(","), 28690635l); 23 | assertEquals(dictionary.find("is"), 22329493l); 24 | assertEquals(dictionary.find(""), 17362777l); 25 | } 26 | 27 | 
@Test 28 | public void testAdd() { 29 | dictionary.add(","); 30 | dictionary.add("is"); 31 | dictionary.add("is"); 32 | String w = ""; 33 | dictionary.add(w); 34 | dictionary.add(w); 35 | dictionary.add(w); 36 | Map word2int = dictionary.getWord2int(); 37 | assertEquals(3, dictionary.getWords().get(word2int.get(dictionary.find(w))).count); 38 | assertEquals(2, dictionary.getWords().get(word2int.get(dictionary.find("is"))).count); 39 | assertEquals(1, dictionary.getWords().get(word2int.get(dictionary.find(","))).count); 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /wikifil.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/perl 2 | 3 | # Program to filter Wikipedia XML dumps to "clean" text consisting only of lowercase 4 | # letters (a-z, converted from A-Z), and spaces (never consecutive). 5 | # All other characters are converted to spaces. Only text which normally appears 6 | # in the web browser is displayed. Tables are removed. Image captions are 7 | # preserved. Links are converted to normal text. Digits are spelled out. 8 | 9 | # Written by Matt Mahoney, June 10, 2006. This program is released to the public domain. 10 | 11 | $/=">"; # input record separator 12 | while (<>) { 13 | if (/<text /) {$text=1;} # remove all but between <text> ... </text> 14 | if (/#redirect/i) {$text=0;} # remove #REDIRECT 15 | if ($text) { 16 | 17 | # Remove any text not normally visible 18 | if (/<\/text>/) {$text=0;} 19 | s/<.*>//; # remove xml tags 20 | s/&amp;/&/g; # decode URL encoded chars 21 | s/&lt;/</g; 22 | s/&gt;/>/g; 23 | s/<ref[^<]*<\/ref>//g; # remove references <ref...> ... </ref> 24 | s/<[^>]*>//g; # remove xhtml tags 25 | s/\[http:[^] ]*/[/g; # remove normal url, preserve visible text 26 | s/\|thumb//ig; # remove images links, preserve caption 27 | s/\|left//ig; 28 | s/\|right//ig; 29 | s/\|\d+px//ig; 30 | s/\[\[image:[^\[\]]*\|//ig; 31 | s/\[\[category:([^|\]]*)[^]]*\]\]/[[$1]]/ig; # show categories without markup 32 | s/\[\[[a-z\-]*:[^\]]*\]\]//g; # remove links to other languages 33 | s/\[\[[^\|\]]*\|/[[/g; # remove wiki url, preserve visible text 34 | s/{{[^}]*}}//g; # remove {{icons}} and {tables} 35 | s/{[^}]*}//g; 36 | s/\[//g; # remove [ and ] 37 | s/\]//g; 38 | s/&[^;]*;/ /g; # remove URL encoded chars 39 | 40 | # convert to lowercase letters and spaces, spell digits 41 | $_=" $_ "; 42 | tr/A-Z/a-z/; 43 | s/0/ zero /g; 44 | s/1/ one /g; 45 | s/2/ two /g; 46 | s/3/ three /g; 47 | s/4/ four /g; 48 | s/5/ five /g; 49 | s/6/ six /g; 50 | s/7/ seven /g; 51 | s/8/ eight /g; 52 | s/9/ nine /g; 53 | tr/a-z/ /cs; 54 | chop; 55 | print $_; 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /word-vector-example.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # 3 | # Copyright (c) 2016-present, Facebook, Inc. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE file in the root directory of this source tree. An additional grant 8 | # of patent rights can be found in the PATENTS file in the same directory. 9 | # 10 | 11 | RESULTDIR=result 12 | DATADIR=data 13 | 14 | mkdir -p "${RESULTDIR}" 15 | mkdir -p "${DATADIR}" 16 | 17 | if [ ! -f "${DATADIR}/text9" ] 18 | then 19 | wget -c http://mattmahoney.net/dc/enwik9.zip -P "${DATADIR}" 20 | unzip "${DATADIR}/enwik9.zip" -d "${DATADIR}" 21 | perl wikifil.pl "${DATADIR}/enwik9" > "${DATADIR}"/text9 22 | fi 23 | 24 | if [ !
-f "${DATADIR}/rw/rw.txt" ] 25 | then 26 | wget -c http://stanford.edu/~lmthang/morphoNLM/rw.zip -P "${DATADIR}" 27 | unzip "${DATADIR}/rw.zip" -d "${DATADIR}" 28 | fi 29 | 30 | mvn package 31 | 32 | JAR=./target/fasttext-0.0.1-SNAPSHOT-jar-with-dependencies.jar 33 | 34 | java -jar ${JAR} skipgram -input "${DATADIR}"/text9 -output "${RESULTDIR}"/text9 -lr 0.025 -dim 100 \ 35 | -ws 5 -epoch 1 -minCount 5 -neg 5 -loss ns -bucket 2000000 \ 36 | -minn 3 -maxn 6 -thread 4 -t 1e-4 -lrUpdateRate 100 37 | 38 | cut -f 1,2 "${DATADIR}"/rw/rw.txt | awk '{print tolower($0)}' | tr '\t' '\n' > "${DATADIR}"/queries.txt 39 | 40 | cat "${DATADIR}"/queries.txt | java -jar ${JAR} print-vectors "${RESULTDIR}"/text9.bin > "${RESULTDIR}"/vectors.txt 41 | 42 | python eval.py -m "${RESULTDIR}"/vectors.txt -d "${DATADIR}"/rw/rw.txt 43 | --------------------------------------------------------------------------------
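
Note on programmatic use: the shell examples above drive the port through the command-line entry point in Main.java. The sketch below shows the same flow called directly from Java, using only the calls Main.java itself makes (Args.parseArgs, FastText.train, FastText.loadModel, FastText.predict). It is illustrative only: the file paths, hyperparameters, and the result/model.bin file name are placeholder assumptions, not files shipped with the repository.

```
import java.io.FileInputStream;

import fasttext.Args;
import fasttext.FastText;

// Illustrative sketch; all paths and hyperparameters below are placeholders.
public class SupervisedExample {

    public static void main(String[] args) throws Exception {
        // Parse command-line style options exactly as Main.train() does.
        // The first token selects the mode (supervised, skipgram or cbow);
        // the flags mirror the ones used by the example shell scripts.
        Args trainArgs = new Args();
        trainArgs.parseArgs(new String[] {
                "supervised",
                "-input", "data/train.txt",  // placeholder: one labelled line per example
                "-output", "result/model",   // placeholder output prefix
                "-dim", "100",
                "-epoch", "5",
                "-thread", "4"
        });

        // Train; FastText.train() persists the model itself via saveModel().
        FastText fasttext = new FastText();
        fasttext.train(trainArgs);

        // Reload the saved binary model and print the top-2 labels with
        // probabilities for each line of a placeholder test file, the same
        // path the "predict-prob" command takes in Main.java.
        FastText loaded = new FastText();
        loaded.loadModel("result/model.bin"); // assumes the model was saved as <output>.bin
        loaded.predict(new FileInputStream("data/test.txt"), 2, true);
    }
}
```

Evaluation follows the same pattern through FastText.test(InputStream, int), which Main.java wires to the `test` command.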