├── .gitignore
├── LICENSE
├── PATENTS
├── README.md
├── classification-example.sh
├── classification-results.sh
├── eval.py
├── pom.xml
├── src
│   ├── main
│   │   └── java
│   │       └── fasttext
│   │           ├── Args.java
│   │           ├── Dictionary.java
│   │           ├── FastText.java
│   │           ├── IOUtil.java
│   │           ├── Main.java
│   │           ├── Matrix.java
│   │           ├── Model.java
│   │           ├── Pair.java
│   │           ├── Utils.java
│   │           ├── Vector.java
│   │           └── io
│   │               ├── BufferedLineReader.java
│   │               ├── LineReader.java
│   │               └── MappedByteBufferLineReader.java
│   └── test
│       └── java
│           └── fasttext
│               └── TestDictionary.java
├── wikifil.pl
└── word-vector-example.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | .*.swp
2 | *.o
3 | *.bin
4 | *.vec
5 | data
6 | fasttext
7 | result
8 | target
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD License
2 |
3 | For fastText software
4 |
5 | Copyright (c) 2016-present, Facebook, Inc. All rights reserved.
6 |
7 | Redistribution and use in source and binary forms, with or without modification,
8 | are permitted provided that the following conditions are met:
9 |
10 | * Redistributions of source code must retain the above copyright notice, this
11 | list of conditions and the following disclaimer.
12 |
13 | * Redistributions in binary form must reproduce the above copyright notice,
14 | this list of conditions and the following disclaimer in the documentation
15 | and/or other materials provided with the distribution.
16 |
17 | * Neither the name Facebook nor the names of its contributors may be used to
18 | endorse or promote products derived from this software without specific
19 | prior written permission.
20 |
21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 |
--------------------------------------------------------------------------------
/PATENTS:
--------------------------------------------------------------------------------
1 | Additional Grant of Patent Rights Version 2
2 |
3 | "Software" means the fastText software distributed by Facebook, Inc.
4 |
5 | Facebook, Inc. ("Facebook") hereby grants to each recipient of the Software
6 | ("you") a perpetual, worldwide, royalty-free, non-exclusive, irrevocable
7 | (subject to the termination provision below) license under any Necessary
8 | Claims, to make, have made, use, sell, offer to sell, import, and otherwise
9 | transfer the Software. For avoidance of doubt, no license is granted under
10 | Facebook’s rights in any patent claims that are infringed by (i) modifications
11 | to the Software made by you or any third party or (ii) the Software in
12 | combination with any software or other technology.
13 |
14 | The license granted hereunder will terminate, automatically and without notice,
15 | if you (or any of your subsidiaries, corporate affiliates or agents) initiate
16 | directly or indirectly, or take a direct financial interest in, any Patent
17 | Assertion: (i) against Facebook or any of its subsidiaries or corporate
18 | affiliates, (ii) against any party if such Patent Assertion arises in whole or
19 | in part from any software, technology, product or service of Facebook or any of
20 | its subsidiaries or corporate affiliates, or (iii) against any party relating
21 | to the Software. Notwithstanding the foregoing, if Facebook or any of its
22 | subsidiaries or corporate affiliates files a lawsuit alleging patent
23 | infringement against you in the first instance, and you respond by filing a
24 | patent infringement counterclaim in that lawsuit against that party that is
25 | unrelated to the Software, the license granted hereunder will not terminate
26 | under section (i) of this paragraph due to such counterclaim.
27 |
28 | A "Necessary Claim" is a claim of a patent owned by Facebook that is
29 | necessarily infringed by the Software standing alone.
30 |
31 | A "Patent Assertion" is any lawsuit or other action alleging direct, indirect,
32 | or contributory infringement or inducement to infringe any patent, including a
33 | cross-claim or counterclaim.
34 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # fasttext_java
2 | Java port of the C++ version of Facebook's fastText [UPDATED 2017-01-29]
3 |
4 | Supports loading and saving fastText binary model files produced by the original C++ implementation (see the example usage sketch at the end of this README).
5 |
6 | ## Building fastText_java
7 | Requirements: Maven, Java 1.6 or later
8 |
9 | In order to build `fastText_java`, use the following:
10 |
11 | ```
12 | $ git clone https://github.com/ivanhk/fastText_java.git
13 | $ cd fastText_java
14 | $ mvn package
15 | ```
16 |
17 | ## Resources
18 |
19 | You can find more information and resources at https://github.com/facebookresearch/fastText
20 |
21 | ## License
22 |
23 | fastText is BSD-licensed. Facebook also provides an additional patent grant.
24 |
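25 | ## Example usage (sketch)
26 |
27 | The snippet below is an illustrative sketch of calling `fastText_java` as a library, using the `Args`, `FastText` and `Pair` classes from this repository. The file paths, the class name `Example` and the sample sentence are placeholders, and error handling is omitted.
28 |
29 | ```
30 | import java.util.List;
31 |
32 | import fasttext.Args;
33 | import fasttext.FastText;
34 | import fasttext.Pair;
35 |
36 | public class Example {
37 |
38 |     public static void main(String[] args) throws Exception {
39 |         // Train a supervised classifier, same arguments as the "supervised" CLI command.
40 |         Args trainArgs = new Args();
41 |         trainArgs.parseArgs(new String[] { "supervised",
42 |                 "-input", "data/dbpedia.train", "-output", "result/dbpedia" });
43 |         new FastText().train(trainArgs);
44 |
45 |         // Load a binary model (trained by this port or by the C++ fastText) and predict.
46 |         FastText fastText = new FastText();
47 |         fastText.loadModel("result/dbpedia.bin");
48 |         String[] tokens = "anarchism is a political philosophy".split(" ");
49 |         List<Pair<Double, String>> predictions = fastText.predict(tokens, 2);
50 |         if (predictions != null) {
51 |             for (Pair<Double, String> p : predictions) {
52 |                 System.out.println(p.getValue() + "\t" + p.getKey());
53 |             }
54 |         }
55 |     }
56 | }
57 | ```
58 |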
--------------------------------------------------------------------------------
/classification-example.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | #
3 | # Copyright (c) 2016-present, Facebook, Inc.
4 | # All rights reserved.
5 | #
6 | # This source code is licensed under the BSD-style license found in the
7 | # LICENSE file in the root directory of this source tree. An additional grant
8 | # of patent rights can be found in the PATENTS file in the same directory.
9 | #
10 |
11 | myshuf() {
12 | perl -MList::Util=shuffle -e 'print shuffle(<>);' "$@";
13 | }
14 |
15 | normalize_text() {
16 | tr '[:upper:]' '[:lower:]' | sed -e 's/^/__label__/g' | \
17 | sed -e "s/'/ ' /g" -e 's/"//g' -e 's/\./ \. /g' -e 's/<br \/>/ /g' \
18 | -e 's/,/ , /g' -e 's/(/ ( /g' -e 's/)/ ) /g' -e 's/\!/ \! /g' \
19 | -e 's/\?/ \? /g' -e 's/\;/ /g' -e 's/\:/ /g' | tr -s " " | myshuf
20 | }
21 |
22 | RESULTDIR=result
23 | DATADIR=data
24 |
25 | mkdir -p "${RESULTDIR}"
26 | mkdir -p "${DATADIR}"
27 |
28 | if [ ! -f "${DATADIR}/dbpedia.train" ]
29 | then
30 | wget -c "https://github.com/le-scientifique/torchDatasets/raw/master/dbpedia_csv.tar.gz" -O "${DATADIR}/dbpedia_csv.tar.gz"
31 | tar -xzvf "${DATADIR}/dbpedia_csv.tar.gz" -C "${DATADIR}"
32 | cat "${DATADIR}/dbpedia_csv/train.csv" | normalize_text > "${DATADIR}/dbpedia.train"
33 | cat "${DATADIR}/dbpedia_csv/test.csv" | normalize_text > "${DATADIR}/dbpedia.test"
34 | fi
35 |
36 | mvn package
37 |
38 | JAR=./target/fasttext-0.0.1-SNAPSHOT-jar-with-dependencies.jar
39 |
40 | java -jar ${JAR} supervised -input "${DATADIR}/dbpedia.train" -output "${RESULTDIR}/dbpedia" -dim 10 -lr 0.1 -wordNgrams 2 -minCount 1 -bucket 10000000 -epoch 5 -thread 4
41 |
42 | java -jar ${JAR} test "${RESULTDIR}/dbpedia.bin" "${DATADIR}/dbpedia.test"
43 |
44 | java -jar ${JAR} predict "${RESULTDIR}/dbpedia.bin" "${DATADIR}/dbpedia.test" > "${RESULTDIR}/dbpedia.test.predict"
45 |
--------------------------------------------------------------------------------
/classification-results.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | #
3 | # Copyright (c) 2016-present, Facebook, Inc.
4 | # All rights reserved.
5 | #
6 | # This source code is licensed under the BSD-style license found in the
7 | # LICENSE file in the root directory of this source tree. An additional grant
8 | # of patent rights can be found in the PATENTS file in the same directory.
9 | #
10 |
11 | # This script produces the results from Table 1 in the following paper:
12 | # Bag of Tricks for Efficient Text Classification, arXiv 1607.01759, 2016
13 |
14 | myshuf() {
15 | perl -MList::Util=shuffle -e 'print shuffle(<>);' "$@";
16 | }
17 |
18 | normalize_text() {
19 | tr '[:upper:]' '[:lower:]' | sed -e 's/^/__label__/g' | \
20 | sed -e "s/'/ ' /g" -e 's/"//g' -e 's/\./ \. /g' -e 's/<br \/>/ /g' \
21 | -e 's/,/ , /g' -e 's/(/ ( /g' -e 's/)/ ) /g' -e 's/\!/ \! /g' \
22 | -e 's/\?/ \? /g' -e 's/\;/ /g' -e 's/\:/ /g' | tr -s " " | myshuf
23 | }
24 |
25 | DATASET=(
26 | ag_news
27 | sogou_news
28 | dbpedia
29 | yelp_review_polarity
30 | yelp_review_full
31 | yahoo_answers
32 | amazon_review_full
33 | amazon_review_polarity
34 | )
35 |
36 | ID=(
37 | 0Bz8a_Dbh9QhbUDNpeUdjb0wxRms # ag_news
38 | 0Bz8a_Dbh9QhbUkVqNEszd0pHaFE # sogou_news
39 | 0Bz8a_Dbh9QhbQ2Vic1kxMmZZQ1k # dbpedia
40 | 0Bz8a_Dbh9QhbNUpYQ2N3SGlFaDg # yelp_review_polarity
41 | 0Bz8a_Dbh9QhbZlU4dXhHTFhZQU0 # yelp_review_full
42 | 0Bz8a_Dbh9Qhbd2JNdDBsQUdocVU # yahoo_answers
43 | 0Bz8a_Dbh9QhbZVhsUnRWRDhETzA # amazon_review_full
44 | 0Bz8a_Dbh9QhbaW12WVVZS2drcnM # amazon_review_polarity
45 | )
46 |
47 | # These learning rates were chosen by validation on a subset of the training set.
48 | LR=( 0.25 0.5 0.5 0.1 0.1 0.1 0.05 0.05 )
49 |
50 | RESULTDIR=result
51 | DATADIR=data
52 |
53 | mkdir -p "${RESULTDIR}"
54 | mkdir -p "${DATADIR}"
55 |
56 | for i in {0..7}
57 | do
58 | echo "Downloading dataset ${DATASET[i]}"
59 | if [ ! -f "${DATADIR}/${DATASET[i]}.train" ]
60 | then
61 | wget -c "https://googledrive.com/host/${ID[i]}" -O "${DATADIR}/${DATASET[i]}_csv.tar.gz"
62 | tar -xzvf "${DATADIR}/${DATASET[i]}_csv.tar.gz" -C "${DATADIR}"
63 | cat "${DATADIR}/${DATASET[i]}_csv/train.csv" | normalize_text > "${DATADIR}/${DATASET[i]}.train"
64 | cat "${DATADIR}/${DATASET[i]}_csv/test.csv" | normalize_text > "${DATADIR}/${DATASET[i]}.test"
65 | fi
66 | done
67 |
68 | mvn package
69 |
70 | JAR=./target/fasttext-0.0.1-SNAPSHOT-jar-with-dependencies.jar
71 |
72 | for i in {0..7}
73 | do
74 | echo "Working on dataset ${DATASET[i]}"
75 | java -jar ${JAR} supervised -input "${DATADIR}/${DATASET[i]}.train" \
76 | -output "${RESULTDIR}/${DATASET[i]}" -dim 10 -lr "${LR[i]}" -wordNgrams 2 \
77 | -minCount 1 -bucket 10000000 -epoch 5 -thread 4 > /dev/null
78 | java -jar ${JAR} test "${RESULTDIR}/${DATASET[i]}.bin" \
79 | "${DATADIR}/${DATASET[i]}.test"
80 | done
81 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | #
4 | # Copyright (c) 2016-present, Facebook, Inc.
5 | # All rights reserved.
6 | #
7 | # This source code is licensed under the BSD-style license found in the
8 | # LICENSE file in the root directory of this source tree. An additional grant
9 | # of patent rights can be found in the PATENTS file in the same directory.
10 | #
11 |
12 | from __future__ import absolute_import
13 | from __future__ import division
14 | from __future__ import print_function
15 | from __future__ import unicode_literals
16 | import numpy as np
17 | from scipy import stats
18 | import sys
19 | import os
20 | import math
21 | import argparse
22 |
23 | def compat_splitting(line):
24 | return line.decode('utf8').split()
25 |
26 | def similarity(v1, v2):
27 | n1 = np.linalg.norm(v1)
28 | n2 = np.linalg.norm(v2)
29 | return np.dot(v1, v2) / n1 / n2
30 |
31 | parser = argparse.ArgumentParser(description='Evaluate word vectors against a word similarity dataset.')
32 | parser.add_argument('--model', '-m', dest='modelPath', action='store', required=True, help='path to model')
33 | parser.add_argument('--data', '-d', dest='dataPath', action='store', required=True, help='path to data')
34 | args = parser.parse_args()
35 |
36 | vectors = {}
37 | fin = open(args.modelPath, 'rb')
38 | for i, line in enumerate(fin):
39 | try:
40 | tab = compat_splitting(line)
41 | vec = np.array(tab[1:], dtype=float)
42 | word = tab[0]
43 | if not word in vectors:
44 | vectors[word] = vec
45 | except ValueError:
46 | continue
47 | except UnicodeDecodeError:
48 | continue
49 | fin.close()
50 |
51 | mysim = []
52 | gold = []
53 | drop = 0.0
54 | nwords = 0.0
55 |
56 | fin = open(args.dataPath, 'rb')
57 | for line in fin:
58 | tline = compat_splitting(line)
59 | word1 = tline[0].lower()
60 | word2 = tline[1].lower()
61 | nwords = nwords + 1.0
62 |
63 | if (word1 in vectors) and (word2 in vectors):
64 | v1 = vectors[word1]
65 | v2 = vectors[word2]
66 | d = similarity(v1, v2)
67 | mysim.append(d)
68 | gold.append(float(tline[2]))
69 | else:
70 | drop = drop + 1.0
71 | fin.close()
72 |
73 | corr = stats.spearmanr(mysim, gold)
74 | dataset = os.path.basename(args.dataPath)
75 | print("{0:20s}: {1:2.0f} (OOV: {2:2.0f}%)"
76 | .format(dataset, corr[0] * 100, math.ceil(drop / nwords * 100.0)))
77 |
--------------------------------------------------------------------------------
/pom.xml:
--------------------------------------------------------------------------------
1 | <project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
2 |   xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
3 |   <modelVersion>4.0.0</modelVersion>
4 |
5 |   <groupId>fasttext</groupId>
6 |   <artifactId>fasttext</artifactId>
7 |   <version>0.0.1-SNAPSHOT</version>
8 |   <packaging>jar</packaging>
9 |
10 |   <name>fasttext</name>
11 |   <url>http://maven.apache.org</url>
12 |
13 |   <properties>
14 |     <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
15 |     <maven.compiler.source>1.6</maven.compiler.source>
16 |     <maven.compiler.target>1.6</maven.compiler.target>
17 |   </properties>
18 |
19 |   <dependencies>
20 |     <dependency>
21 |       <groupId>junit</groupId>
22 |       <artifactId>junit</artifactId>
23 |       <version>4.12</version>
24 |       <scope>test</scope>
25 |     </dependency>
26 |   </dependencies>
27 |
28 |   <build>
29 |     <plugins>
30 |       <plugin>
31 |         <artifactId>maven-compiler-plugin</artifactId>
32 |         <configuration>
33 |           <encoding>UTF-8</encoding>
34 |           <compilerArguments>
35 |             <extdirs>lib</extdirs>
36 |           </compilerArguments>
37 |           <source>1.6</source>
38 |           <target>1.6</target>
39 |           <compilerVersion>1.6</compilerVersion>
40 |         </configuration>
41 |         <version>3.3</version>
42 |       </plugin>
43 |       <plugin>
44 |         <groupId>org.apache.maven.plugins</groupId>
45 |         <artifactId>maven-resources-plugin</artifactId>
46 |         <configuration>
47 |           <encoding>UTF-8</encoding>
48 |         </configuration>
49 |         <version>3.0.1</version>
50 |       </plugin>
51 |       <plugin>
52 |         <groupId>org.apache.maven.plugins</groupId>
53 |         <artifactId>maven-source-plugin</artifactId>
54 |         <version>2.4</version>
55 |         <executions>
56 |           <execution>
57 |             <id>attach-sources</id>
58 |             <goals>
59 |               <goal>jar</goal>
60 |             </goals>
61 |           </execution>
62 |         </executions>
63 |       </plugin>
64 |       <plugin>
65 |         <artifactId>maven-assembly-plugin</artifactId>
66 |         <version>2.6</version>
67 |         <configuration>
68 |           <descriptorRefs>
69 |             <descriptorRef>jar-with-dependencies</descriptorRef>
70 |           </descriptorRefs>
71 |           <archive>
72 |             <manifest>
73 |               <mainClass>fasttext.Main</mainClass>
74 |             </manifest>
75 |           </archive>
76 |         </configuration>
77 |         <executions>
78 |           <execution>
79 |             <id>simple-command</id>
80 |             <phase>package</phase>
81 |             <goals>
82 |               <goal>attached</goal>
83 |             </goals>
84 |           </execution>
85 |         </executions>
86 |       </plugin>
87 |     </plugins>
88 |   </build>
89 | </project>
90 |
--------------------------------------------------------------------------------
/src/main/java/fasttext/Args.java:
--------------------------------------------------------------------------------
1 | package fasttext;
2 |
3 | import java.io.IOException;
4 | import java.io.InputStream;
5 | import java.io.OutputStream;
6 |
7 | public class Args {
8 |
9 | public enum model_name {
10 | cbow(1), sg(2), sup(3);
11 |
12 | private int value;
13 |
14 | private model_name(int value) {
15 | this.value = value;
16 | }
17 |
18 | public int getValue() {
19 | return this.value;
20 | }
21 |
22 | public static model_name fromValue(int value) throws IllegalArgumentException {
23 | try {
24 | value -= 1;
25 | return model_name.values()[value];
26 | } catch (ArrayIndexOutOfBoundsException e) {
27 | throw new IllegalArgumentException("Unknown model_name enum value :" + value);
28 | }
29 | }
30 | }
31 |
32 | public enum loss_name {
33 | hs(1), ns(2), softmax(3);
34 | private int value;
35 |
36 | private loss_name(int value) {
37 | this.value = value;
38 | }
39 |
40 | public int getValue() {
41 | return this.value;
42 | }
43 |
44 | public static loss_name fromValue(int value) throws IllegalArgumentException {
45 | try {
46 | value -= 1;
47 | return loss_name.values()[value];
48 | } catch (ArrayIndexOutOfBoundsException e) {
49 | throw new IllegalArgumentException("Unknown loss_name enum value :" + value);
50 | }
51 | }
52 | }
53 |
54 | public String input;
55 | public String output;
56 | public String test;
57 | public double lr = 0.05;
58 | public int lrUpdateRate = 100;
59 | public int dim = 100;
60 | public int ws = 5;
61 | public int epoch = 5;
62 | public int minCount = 5;
63 | public int minCountLabel = 0;
64 | public int neg = 5;
65 | public int wordNgrams = 1;
66 | public loss_name loss = loss_name.ns;
67 | public model_name model = model_name.sg;
68 | public int bucket = 2000000;
69 | public int minn = 3;
70 | public int maxn = 6;
71 | public int thread = 1;
72 | public double t = 1e-4;
73 | public String label = "__label__";
74 | public int verbose = 2;
75 | public String pretrainedVectors = "";
76 |
77 | public void printHelp() {
78 | System.out.println("\n" + "The following arguments are mandatory:\n"
79 | + " -input training file path\n"
80 | + " -output output file path\n\n"
81 | + "The following arguments are optional:\n"
82 | + " -lr learning rate [" + lr + "]\n"
83 | + " -lrUpdateRate change the rate of updates for the learning rate [" + lrUpdateRate + "]\n"
84 | + " -dim size of word vectors [" + dim + "]\n"
85 | + " -ws size of the context window [" + ws + "]\n"
86 | + " -epoch number of epochs [" + epoch + "]\n"
87 | + " -minCount minimal number of word occurences [" + minCount + "]\n"
88 | + " -minCountLabel minimal number of label occurences [" + minCountLabel + "]\n"
89 | + " -neg number of negatives sampled [" + neg + "]\n"
90 | + " -wordNgrams max length of word ngram [" + wordNgrams + "]\n"
91 | + " -loss loss function {ns, hs, softmax} [ns]\n"
92 | + " -bucket number of buckets [" + bucket + "]\n"
93 | + " -minn min length of char ngram [" + minn + "]\n"
94 | + " -maxn max length of char ngram [" + maxn + "]\n"
95 | + " -thread number of threads [" + thread + "]\n"
96 | + " -t sampling threshold [" + t + "]\n"
97 | + " -label labels prefix [" + label + "]\n"
98 | + " -verbose verbosity level [" + verbose + "]\n"
99 | + " -pretrainedVectors pretrained word vectors for supervised learning []");
100 | }
101 |
102 | public void save(OutputStream ofs) throws IOException {
103 | IOUtil ioutil = new IOUtil();
104 | ofs.write(ioutil.intToByteArray(dim));
105 | ofs.write(ioutil.intToByteArray(ws));
106 | ofs.write(ioutil.intToByteArray(epoch));
107 | ofs.write(ioutil.intToByteArray(minCount));
108 | ofs.write(ioutil.intToByteArray(neg));
109 | ofs.write(ioutil.intToByteArray(wordNgrams));
110 | ofs.write(ioutil.intToByteArray(loss.value));
111 | ofs.write(ioutil.intToByteArray(model.value));
112 | ofs.write(ioutil.intToByteArray(bucket));
113 | ofs.write(ioutil.intToByteArray(minn));
114 | ofs.write(ioutil.intToByteArray(maxn));
115 | ofs.write(ioutil.intToByteArray(lrUpdateRate));
116 | ofs.write(ioutil.doubleToByteArray(t));
117 | }
118 |
119 | public void load(InputStream input) throws IOException {
120 | IOUtil ioutil = new IOUtil();
121 | dim = ioutil.readInt(input);
122 | ws = ioutil.readInt(input);
123 | epoch = ioutil.readInt(input);
124 | minCount = ioutil.readInt(input);
125 | neg = ioutil.readInt(input);
126 | wordNgrams = ioutil.readInt(input);
127 | loss = loss_name.fromValue(ioutil.readInt(input));
128 | model = model_name.fromValue(ioutil.readInt(input));
129 | bucket = ioutil.readInt(input);
130 | minn = ioutil.readInt(input);
131 | maxn = ioutil.readInt(input);
132 | lrUpdateRate = ioutil.readInt(input);
133 | t = ioutil.readDouble(input);
134 | }
135 |
136 | public void parseArgs(String[] args) {
137 | String command = args[0];
138 | if ("supervised".equalsIgnoreCase(command)) {
139 | model = model_name.sup;
140 | loss = loss_name.softmax;
141 | minCount = 1;
142 | minn = 0;
143 | maxn = 0;
144 | lr = 0.1;
145 | } else if ("cbow".equalsIgnoreCase(command)) {
146 | model = model_name.cbow;
147 | }
148 | int ai = 1;
149 | while (ai < args.length) {
150 | if (args[ai].charAt(0) != '-') {
151 | System.out.println("Provided argument without a dash! Usage:");
152 | printHelp();
153 | System.exit(1);
154 | }
155 | if ("-h".equals(args[ai])) {
156 | System.out.println("Here is the help! Usage:");
157 | printHelp();
158 | System.exit(1);
159 | } else if ("-input".equals(args[ai])) {
160 | input = args[ai + 1];
161 | } else if ("-test".equals(args[ai])) {
162 | test = args[ai + 1];
163 | } else if ("-output".equals(args[ai])) {
164 | output = args[ai + 1];
165 | } else if ("-lr".equals(args[ai])) {
166 | lr = Double.parseDouble(args[ai + 1]);
167 | } else if ("-lrUpdateRate".equals(args[ai])) {
168 | lrUpdateRate = Integer.parseInt(args[ai + 1]);
169 | } else if ("-dim".equals(args[ai])) {
170 | dim = Integer.parseInt(args[ai + 1]);
171 | } else if ("-ws".equals(args[ai])) {
172 | ws = Integer.parseInt(args[ai + 1]);
173 | } else if ("-epoch".equals(args[ai])) {
174 | epoch = Integer.parseInt(args[ai + 1]);
175 | } else if ("-minCount".equals(args[ai])) {
176 | minCount = Integer.parseInt(args[ai + 1]);
177 | } else if ("-minCountLabel".equals(args[ai])) {
178 | minCountLabel = Integer.parseInt(args[ai + 1]);
179 | } else if ("-neg".equals(args[ai])) {
180 | neg = Integer.parseInt(args[ai + 1]);
181 | } else if ("-wordNgrams".equals(args[ai])) {
182 | wordNgrams = Integer.parseInt(args[ai + 1]);
183 | } else if ("-loss".equals(args[ai])) {
184 | if ("hs".equalsIgnoreCase(args[ai + 1])) {
185 | loss = loss_name.hs;
186 | } else if ("ns".equalsIgnoreCase(args[ai + 1])) {
187 | loss = loss_name.ns;
188 | } else if ("softmax".equalsIgnoreCase(args[ai + 1])) {
189 | loss = loss_name.softmax;
190 | } else {
191 | System.out.println("Unknown loss: " + args[ai + 1]);
192 | printHelp();
193 | System.exit(1);
194 | }
195 | } else if ("-bucket".equals(args[ai])) {
196 | bucket = Integer.parseInt(args[ai + 1]);
197 | } else if ("-minn".equals(args[ai])) {
198 | minn = Integer.parseInt(args[ai + 1]);
199 | } else if ("-maxn".equals(args[ai])) {
200 | maxn = Integer.parseInt(args[ai + 1]);
201 | } else if ("-thread".equals(args[ai])) {
202 | thread = Integer.parseInt(args[ai + 1]);
203 | } else if ("-t".equals(args[ai])) {
204 | t = Double.parseDouble(args[ai + 1]);
205 | } else if ("-label".equals(args[ai])) {
206 | label = args[ai + 1];
207 | } else if ("-verbose".equals(args[ai])) {
208 | verbose = Integer.parseInt(args[ai + 1]);
209 | } else if ("-pretrainedVectors".equals(args[ai])) {
210 | pretrainedVectors = args[ai + 1];
211 | } else {
212 | System.out.println("Unknown argument: " + args[ai]);
213 | printHelp();
214 | System.exit(1);
215 | }
216 | ai += 2;
217 | }
218 | if (Utils.isEmpty(input) || Utils.isEmpty(output)) {
219 | System.out.println("Empty input or output path.");
220 | printHelp();
221 | System.exit(1);
222 | }
223 | if (wordNgrams <= 1 && maxn == 0) {
224 | bucket = 0;
225 | }
226 | }
227 |
228 | @Override
229 | public String toString() {
230 | StringBuilder builder = new StringBuilder();
231 | builder.append("Args [input=");
232 | builder.append(input);
233 | builder.append(", output=");
234 | builder.append(output);
235 | builder.append(", test=");
236 | builder.append(test);
237 | builder.append(", lr=");
238 | builder.append(lr);
239 | builder.append(", lrUpdateRate=");
240 | builder.append(lrUpdateRate);
241 | builder.append(", dim=");
242 | builder.append(dim);
243 | builder.append(", ws=");
244 | builder.append(ws);
245 | builder.append(", epoch=");
246 | builder.append(epoch);
247 | builder.append(", minCount=");
248 | builder.append(minCount);
249 | builder.append(", minCountLabel=");
250 | builder.append(minCountLabel);
251 | builder.append(", neg=");
252 | builder.append(neg);
253 | builder.append(", wordNgrams=");
254 | builder.append(wordNgrams);
255 | builder.append(", loss=");
256 | builder.append(loss);
257 | builder.append(", model=");
258 | builder.append(model);
259 | builder.append(", bucket=");
260 | builder.append(bucket);
261 | builder.append(", minn=");
262 | builder.append(minn);
263 | builder.append(", maxn=");
264 | builder.append(maxn);
265 | builder.append(", thread=");
266 | builder.append(thread);
267 | builder.append(", t=");
268 | builder.append(t);
269 | builder.append(", label=");
270 | builder.append(label);
271 | builder.append(", verbose=");
272 | builder.append(verbose);
273 | builder.append(", pretrainedVectors=");
274 | builder.append(pretrainedVectors);
275 | builder.append("]");
276 | return builder.toString();
277 | }
278 |
279 | }
280 |
--------------------------------------------------------------------------------
/src/main/java/fasttext/Dictionary.java:
--------------------------------------------------------------------------------
1 | package fasttext;
2 |
3 | import java.io.IOException;
4 | import java.io.InputStream;
5 | import java.io.OutputStream;
6 | import java.util.ArrayList;
7 | import java.util.Collections;
8 | import java.util.Comparator;
9 | import java.util.HashMap;
10 | import java.util.Iterator;
11 | import java.util.List;
12 | import java.util.Map;
13 | import java.util.Random;
14 | import java.math.BigInteger;
15 |
16 | import fasttext.Args.model_name;
17 | import fasttext.io.BufferedLineReader;
18 | import fasttext.io.LineReader;
19 |
20 | public class Dictionary {
21 |
22 | private static final int MAX_VOCAB_SIZE = 30000000;
23 | private static final int MAX_LINE_SIZE = 1024;
24 | private static final Integer WORDID_DEFAULT = -1;
25 |
26 | private static final String EOS = "</s>";
27 | private static final String BOW = "<";
28 | private static final String EOW = ">";
29 |
30 | public enum entry_type {
31 | word(0), label(1);
32 |
33 | private int value;
34 |
35 | private entry_type(int value) {
36 | this.value = value;
37 | }
38 |
39 | public int getValue() {
40 | return this.value;
41 | }
42 |
43 | public static entry_type fromValue(int value) throws IllegalArgumentException {
44 | try {
45 | return entry_type.values()[value];
46 | } catch (ArrayIndexOutOfBoundsException e) {
47 | throw new IllegalArgumentException("Unknown entry_type enum value :" + value);
48 | }
49 | }
50 |
51 | @Override
52 | public String toString() {
53 | return value == 0 ? "word" : value == 1 ? "label" : "unknown";
54 | }
55 | }
56 |
57 | public class entry {
58 | String word;
59 | long count;
60 | entry_type type;
61 | List<Integer> subwords;
62 |
63 | @Override
64 | public String toString() {
65 | StringBuilder builder = new StringBuilder();
66 | builder.append("entry [word=");
67 | builder.append(word);
68 | builder.append(", count=");
69 | builder.append(count);
70 | builder.append(", type=");
71 | builder.append(type);
72 | builder.append(", subwords=");
73 | builder.append(subwords);
74 | builder.append("]");
75 | return builder.toString();
76 | }
77 |
78 | }
79 |
80 | private List<entry> words_;
81 | private List<Float> pdiscard_;
82 | private Map<Long, Integer> word2int_;
83 | private int size_;
84 | private int nwords_;
85 | private int nlabels_;
86 | private long ntokens_;
87 |
88 | private Args args_;
89 |
90 | private String charsetName_ = "UTF-8";
91 | private Class<? extends LineReader> lineReaderClass_ = BufferedLineReader.class;
92 |
93 | public Dictionary(Args args) {
94 | args_ = args;
95 | size_ = 0;
96 | nwords_ = 0;
97 | nlabels_ = 0;
98 | ntokens_ = 0;
99 | word2int_ = new HashMap<Long, Integer>(MAX_VOCAB_SIZE);
100 | words_ = new ArrayList<entry>(MAX_VOCAB_SIZE);
101 | }
102 |
103 | public long find(final String w) {
104 | long h = hash(w) % MAX_VOCAB_SIZE;
105 | entry e = null;
106 | while (Utils.mapGetOrDefault(word2int_, h, WORDID_DEFAULT) != WORDID_DEFAULT
107 | && ((e = words_.get(word2int_.get(h))) != null && !w.equals(e.word))) {
108 | h = (h + 1) % MAX_VOCAB_SIZE;
109 | }
110 | return h;
111 | }
112 |
113 | public void add(final String w) {
114 | long h = find(w);
115 | ntokens_++;
116 | if (Utils.mapGetOrDefault(word2int_, h, WORDID_DEFAULT) == WORDID_DEFAULT) {
117 | entry e = new entry();
118 | e.word = w;
119 | e.count = 1;
120 | e.type = w.startsWith(args_.label) ? entry_type.label : entry_type.word;
121 | words_.add(e);
122 | word2int_.put(h, size_++);
123 | } else {
124 | words_.get(word2int_.get(h)).count++;
125 | }
126 | }
127 |
128 | public int nwords() {
129 | return nwords_;
130 | }
131 |
132 | public int nlabels() {
133 | return nlabels_;
134 | }
135 |
136 | public long ntokens() {
137 | return ntokens_;
138 | }
139 |
140 | public final List<Integer> getNgrams(int i) {
141 | Utils.checkArgument(i >= 0);
142 | Utils.checkArgument(i < nwords_);
143 | return words_.get(i).subwords;
144 | }
145 |
146 | public final List<Integer> getNgrams(final String word) {
147 | List<Integer> ngrams = new ArrayList<Integer>();
148 | int i = getId(word);
149 | if (i >= 0) {
150 | ngrams = words_.get(i).subwords;
151 | } else {
152 | computeNgrams(BOW + word + EOW, ngrams);
153 | }
154 | return ngrams;
155 | }
156 |
157 | public boolean discard(int id, float rand) {
158 | Utils.checkArgument(id >= 0);
159 | Utils.checkArgument(id < nwords_);
160 | if (args_.model == model_name.sup)
161 | return false;
162 | return rand > pdiscard_.get(id);
163 | }
164 |
165 | public int getId(final String w) {
166 | long h = find(w);
167 | return Utils.mapGetOrDefault(word2int_, h, WORDID_DEFAULT);
168 | }
169 |
170 | public entry_type getType(int id) {
171 | Utils.checkArgument(id >= 0);
172 | Utils.checkArgument(id < size_);
173 | return words_.get(id).type;
174 | }
175 |
176 | public String getWord(int id) {
177 | Utils.checkArgument(id >= 0);
178 | Utils.checkArgument(id < size_);
179 | return words_.get(id).word;
180 | }
181 |
182 | /**
183 | * String FNV-1a Hash
184 | *
185 | * @param str
186 | * @return
187 | */
188 | public long hash(final String str) {
189 | int h = (int) 2166136261L;// 0xffffffc5;
190 | for (byte strByte : str.getBytes()) {
191 | h = (h ^ strByte) * 16777619; // FNV-1a
192 | // h = (h * 16777619) ^ strByte; //FNV-1
193 | }
194 | return h & 0xffffffffL;
195 | }
196 |
197 | public void computeNgrams(final String word, List<Integer> ngrams) {
198 | for (int i = 0; i < word.length(); i++) {
199 | StringBuilder ngram = new StringBuilder();
200 | if (charMatches(word.charAt(i))) {
201 | continue;
202 | }
203 | for (int j = i, n = 1; j < word.length() && n <= args_.maxn; n++) {
204 | ngram.append(word.charAt(j++));
205 | while (j < word.length() && charMatches(word.charAt(j))) {
206 | ngram.append(word.charAt(j++));
207 | }
208 | if (n >= args_.minn && !(n == 1 && (i == 0 || j == word.length()))) {
209 | int h = (int) (nwords_ + (hash(ngram.toString()) % args_.bucket));
210 | if (h < 0) {
211 | System.err.println("computeNgrams h<0: " + h + " on word: " + word);
212 | }
213 | ngrams.add(h);
214 | }
215 | }
216 | }
217 | }
218 |
219 | private boolean charMatches(char ch) {
220 | if (ch == ' ' || ch == '\t' || ch == '\n' || ch == '\f' || ch == '\r') {
221 | return true;
222 | }
223 | return false;
224 | }
225 |
226 | public void initNgrams() {
227 | for (int i = 0; i < size_; i++) {
228 | String word = BOW + words_.get(i).word + EOW;
229 | entry e = words_.get(i);
230 | if (e.subwords == null) {
231 | e.subwords = new ArrayList<Integer>();
232 | }
233 | e.subwords.add(i);
234 | computeNgrams(word, e.subwords);
235 | }
236 | }
237 |
238 | public void readFromFile(String file) throws IOException, Exception {
239 | LineReader lineReader = null;
240 |
241 | try {
242 | lineReader = lineReaderClass_.getConstructor(String.class, String.class).newInstance(file, charsetName_);
243 | long minThreshold = 1;
244 | String[] lineTokens;
245 | while ((lineTokens = lineReader.readLineTokens()) != null) {
246 | for (int i = 0; i <= lineTokens.length; i++) {
247 | if (i == lineTokens.length) {
248 | add(EOS);
249 | } else {
250 | if (Utils.isEmpty(lineTokens[i])) {
251 | continue;
252 | }
253 | add(lineTokens[i]);
254 | }
255 | if (ntokens_ % 1000000 == 0 && args_.verbose > 1) {
256 | System.out.printf("\rRead %dM words", ntokens_ / 1000000);
257 | }
258 | if (size_ > 0.75 * MAX_VOCAB_SIZE) {
259 | minThreshold++;
260 | threshold(minThreshold, minThreshold);
261 | }
262 | }
263 | }
264 | } finally {
265 | if (lineReader != null) {
266 | lineReader.close();
267 | }
268 | }
269 | threshold(args_.minCount, args_.minCountLabel);
270 | initTableDiscard();
271 | if (model_name.cbow == args_.model || model_name.sg == args_.model) {
272 | initNgrams();
273 | }
274 | if (args_.verbose > 0) {
275 | System.out.printf("\rRead %dM words\n", ntokens_ / 1000000);
276 | System.out.println("Number of words: " + nwords_);
277 | System.out.println("Number of labels: " + nlabels_);
278 | }
279 | if (size_ == 0) {
280 | System.err.println("Empty vocabulary. Try a smaller -minCount value.");
281 | System.exit(1);
282 | }
283 | }
284 |
285 | public void threshold(long t, long tl) {
286 | Collections.sort(words_, entry_comparator);
287 | Iterator<entry> iterator = words_.iterator();
288 | while (iterator.hasNext()) {
289 | entry _entry = iterator.next();
290 | if ((entry_type.word == _entry.type && _entry.count < t)
291 | || (entry_type.label == _entry.type && _entry.count < tl)) {
292 | iterator.remove();
293 | }
294 | }
295 | ((ArrayList<entry>) words_).trimToSize();
296 | size_ = 0;
297 | nwords_ = 0;
298 | nlabels_ = 0;
299 | // word2int_.clear();
300 | word2int_ = new HashMap<Long, Integer>(words_.size());
301 | for (entry _entry : words_) {
302 | long h = find(_entry.word);
303 | word2int_.put(h, size_++);
304 | if (entry_type.word == _entry.type) {
305 | nwords_++;
306 | } else if (entry_type.label == _entry.type) {
307 | nlabels_++;
308 | }
309 | }
310 | }
311 |
312 | private transient Comparator<entry> entry_comparator = new Comparator<entry>() {
313 | @Override
314 | public int compare(entry o1, entry o2) {
315 | int cmp = (o1.type.value < o2.type.value) ? -1 : ((o1.type.value == o2.type.value) ? 0 : 1);
316 | if (cmp == 0) {
317 | cmp = (o2.count < o1.count) ? -1 : ((o2.count == o1.count) ? 0 : 1);
318 | }
319 | return cmp;
320 | }
321 | };
322 |
323 | public void initTableDiscard() {
324 | pdiscard_ = new ArrayList<Float>(size_);
325 | for (int i = 0; i < size_; i++) {
326 | float f = (float) (words_.get(i).count) / (float) ntokens_;
327 | pdiscard_.add((float) (Math.sqrt(args_.t / f) + args_.t / f));
328 | }
329 | }
330 |
331 | public List<Long> getCounts(entry_type type) {
332 | List<Long> counts = entry_type.label == type ? new ArrayList<Long>(nlabels()) : new ArrayList<Long>(nwords());
333 | for (entry w : words_) {
334 | if (w.type == type)
335 | counts.add(w.count);
336 | }
337 | return counts;
338 | }
339 |
340 | public void addNgrams(List<Integer> line, int n) {
341 | if (n <= 1) {
342 | return;
343 | }
344 | int line_size = line.size();
345 | for (int i = 0; i < line_size; i++) {
346 | BigInteger h = BigInteger.valueOf(line.get(i));
347 | BigInteger r = BigInteger.valueOf(116049371l);
348 | BigInteger b = BigInteger.valueOf(args_.bucket);
349 |
350 | for (int j = i + 1; j < line_size && j < i + n; j++) {
351 | h = h.multiply(r).add(BigInteger.valueOf(line.get(j)));
352 | line.add(nwords_ + h.remainder(b).intValue());
353 | }
354 | }
355 | }
356 |
357 | public int getLine(String[] tokens, List<Integer> words, List<Integer> labels, Random urd) {
358 | int ntokens = 0;
359 | words.clear();
360 | labels.clear();
361 | if (tokens != null) {
362 | for (int i = 0; i <= tokens.length; i++) {
363 | if (i < tokens.length && Utils.isEmpty(tokens[i])) {
364 | continue;
365 | }
366 | int wid = i == tokens.length ? getId(EOS) : getId(tokens[i]);
367 | if (wid < 0) {
368 | continue;
369 | }
370 | entry_type type = getType(wid);
371 | ntokens++;
372 | if (type == entry_type.word && !discard(wid, Utils.randomFloat(urd, 0, 1))) {
373 | words.add(wid);
374 | }
375 | if (type == entry_type.label) {
376 | labels.add(wid - nwords_);
377 | }
378 | if (words.size() > MAX_LINE_SIZE && args_.model != model_name.sup) {
379 | break;
380 | }
381 | // if (EOS == tokens[i]){
382 | // break;
383 | // }
384 | }
385 | }
386 | return ntokens;
387 | }
388 |
389 | public String getLabel(int lid) {
390 | Utils.checkArgument(lid >= 0);
391 | Utils.checkArgument(lid < nlabels_);
392 | return words_.get(lid + nwords_).word;
393 | }
394 |
395 | public void save(OutputStream ofs) throws IOException {
396 | IOUtil ioutil = new IOUtil();
397 | ofs.write(ioutil.intToByteArray(size_));
398 | ofs.write(ioutil.intToByteArray(nwords_));
399 | ofs.write(ioutil.intToByteArray(nlabels_));
400 | ofs.write(ioutil.longToByteArray(ntokens_));
401 | // Charset charset = Charset.forName("UTF-8");
402 | for (int i = 0; i < size_; i++) {
403 | entry e = words_.get(i);
404 | ofs.write(e.word.getBytes());
405 | ofs.write(0);
406 | ofs.write(ioutil.longToByteArray(e.count));
407 | ofs.write(ioutil.intToByte(e.type.value));
408 | }
409 | }
410 |
411 | public void load(InputStream ifs) throws IOException {
412 | // words_.clear();
413 | // word2int_.clear();
414 | IOUtil ioutil = new IOUtil();
415 | size_ = ioutil.readInt(ifs);
416 | nwords_ = ioutil.readInt(ifs);
417 | nlabels_ = ioutil.readInt(ifs);
418 | ntokens_ = ioutil.readLong(ifs);
419 | long pruneidx_size = ioutil.readLong(ifs);
420 |
421 | word2int_ = new HashMap<Long, Integer>(size_);
422 | words_ = new ArrayList<entry>(size_);
423 |
424 | for (int i = 0; i < size_; i++) {
425 | entry e = new entry();
426 | e.word = ioutil.readString(ifs);
427 | e.count = ioutil.readLong(ifs);
428 | e.type = entry_type.fromValue(ioutil.readByte(ifs));
429 | words_.add(e);
430 | word2int_.put(find(e.word), i);
431 | }
432 | initTableDiscard();
433 | //if (model_name.cbow == args_.model || model_name.sg == args_.model) {
434 | initNgrams();
435 | //}
436 | }
437 |
438 | @Override
439 | public String toString() {
440 | StringBuilder builder = new StringBuilder();
441 | builder.append("Dictionary [words_=");
442 | builder.append(words_);
443 | builder.append(", pdiscard_=");
444 | builder.append(pdiscard_);
445 | builder.append(", word2int_=");
446 | builder.append(word2int_);
447 | builder.append(", size_=");
448 | builder.append(size_);
449 | builder.append(", nwords_=");
450 | builder.append(nwords_);
451 | builder.append(", nlabels_=");
452 | builder.append(nlabels_);
453 | builder.append(", ntokens_=");
454 | builder.append(ntokens_);
455 | builder.append("]");
456 | return builder.toString();
457 | }
458 |
459 | public List<entry> getWords() {
460 | return words_;
461 | }
462 |
463 | public List<Float> getPdiscard() {
464 | return pdiscard_;
465 | }
466 |
467 | public Map<Long, Integer> getWord2int() {
468 | return word2int_;
469 | }
470 |
471 | public int getSize() {
472 | return size_;
473 | }
474 |
475 | public Args getArgs() {
476 | return args_;
477 | }
478 |
479 | public String getCharsetName() {
480 | return charsetName_;
481 | }
482 |
483 | public Class<? extends LineReader> getLineReaderClass() {
484 | return lineReaderClass_;
485 | }
486 |
487 | public void setCharsetName(String charsetName) {
488 | this.charsetName_ = charsetName;
489 | }
490 |
491 | public void setLineReaderClass(Class<? extends LineReader> lineReaderClass) {
492 | this.lineReaderClass_ = lineReaderClass;
493 | }
494 |
495 | }
496 |
--------------------------------------------------------------------------------
/src/main/java/fasttext/FastText.java:
--------------------------------------------------------------------------------
1 | package fasttext;
2 |
3 | import java.io.BufferedInputStream;
4 | import java.io.BufferedOutputStream;
5 | import java.io.BufferedReader;
6 | import java.io.DataInputStream;
7 | import java.io.File;
8 | import java.io.FileInputStream;
9 | import java.io.FileOutputStream;
10 | import java.io.IOException;
11 | import java.io.InputStream;
12 | import java.io.InputStreamReader;
13 | import java.io.OutputStream;
14 | import java.io.OutputStreamWriter;
15 | import java.io.Writer;
16 | import java.text.DecimalFormat;
17 | import java.util.ArrayList;
18 | import java.util.List;
19 | import java.util.concurrent.atomic.AtomicLong;
20 |
21 | import fasttext.Args.model_name;
22 | import fasttext.Dictionary.entry;
23 | import fasttext.Dictionary.entry_type;
24 | import fasttext.io.BufferedLineReader;
25 | import fasttext.io.LineReader;
26 |
27 | /**
28 | * FastText class, can be used as a lib in other projects
29 | *
30 | * @author Ivan
31 | *
32 | */
33 | public strictfp class FastText {
34 |
35 | private Args args_;
36 | private Dictionary dict_;
37 | private Matrix input_;
38 | private Matrix output_;
39 | private Model model_;
40 |
41 | private AtomicLong tokenCount_;
42 | private long start_;
43 |
44 | private String charsetName_ = "UTF-8";
45 | private Class<? extends LineReader> lineReaderClass_ = BufferedLineReader.class;
46 |
47 | public String[] getWords() {
48 | List<entry> entries = dict_.getWords();
49 | int n = entries.size();
50 | String[] words = new String[n];
51 | for (int i = 0; i < n; i++) {
52 | words[i] = entries.get(i).word;
53 | }
54 | return words;
55 | }
56 |
57 | public void getVector(Vector vec, final String word) {
58 | final List<Integer> ngrams = dict_.getNgrams(word);
59 | vec.zero();
60 | for (Integer it : ngrams) {
61 | vec.addRow(input_, it);
62 | }
63 | if (ngrams.size() > 0) {
64 | vec.mul(1.0f / ngrams.size());
65 | }
66 | }
67 |
68 | public void saveVectors() throws IOException {
69 | if (Utils.isEmpty(args_.output)) {
70 | if (args_.verbose > 1) {
71 | System.out.println("output is empty, skip save vector file");
72 | }
73 | return;
74 | }
75 |
76 | File file = new File(args_.output + ".vec");
77 | if (file.exists()) {
78 | file.delete();
79 | }
80 | if (file.getParentFile() != null) {
81 | file.getParentFile().mkdirs();
82 | }
83 | if (args_.verbose > 1) {
84 | System.out.println("Saving Vectors to " + file.getCanonicalPath().toString());
85 | }
86 | Vector vec = new Vector(args_.dim);
87 | DecimalFormat df = new DecimalFormat("0.#####");
88 | Writer writer = new OutputStreamWriter(new BufferedOutputStream(new FileOutputStream(file)), "UTF-8");
89 | try {
90 | writer.write(dict_.nwords() + " " + args_.dim + "\n");
91 | for (int i = 0; i < dict_.nwords(); i++) {
92 | String word = dict_.getWord(i);
93 | getVector(vec, word);
94 | writer.write(word);
95 | for (int j = 0; j < vec.m_; j++) {
96 | writer.write(" ");
97 | writer.write(df.format(vec.data_[j]));
98 | }
99 | writer.write("\n");
100 | }
101 | } finally {
102 | writer.flush();
103 | writer.close();
104 | }
105 | }
106 |
107 | public void saveModel() throws IOException {
108 | if (Utils.isEmpty(args_.output)) {
109 | if (args_.verbose > 1) {
110 | System.out.println("output is empty, skip save model file");
111 | }
112 | return;
113 | }
114 |
115 | File file = new File(args_.output + ".bin");
116 | if (file.exists()) {
117 | file.delete();
118 | }
119 | if (file.getParentFile() != null) {
120 | file.getParentFile().mkdirs();
121 | }
122 | if (args_.verbose > 1) {
123 | System.out.println("Saving model to " + file.getCanonicalPath().toString());
124 | }
125 | OutputStream ofs = new BufferedOutputStream(new FileOutputStream(file));
126 | try {
127 | args_.save(ofs);
128 | dict_.save(ofs);
129 | input_.save(ofs);
130 | output_.save(ofs);
131 | } finally {
132 | ofs.flush();
133 | ofs.close();
134 | }
135 | }
136 |
137 | /**
138 | * Load binary model file.
139 | *
140 | * @param filename
141 | * @throws IOException
142 | */
143 | public void loadModel(String filename) throws IOException {
144 | File file = new File(filename);
145 | if (!(file.exists() && file.isFile() && file.canRead())) {
146 | throw new IOException("Model file cannot be opened for loading! " + filename);
147 | }
148 | loadModel(new FileInputStream(file));
149 | }
150 |
151 | public void loadModel(InputStream is) throws IOException {
152 | DataInputStream dis = null;
153 | BufferedInputStream bis = null;
154 | try {
155 | bis = new BufferedInputStream(is);
156 | dis = new DataInputStream(bis);
157 |
158 | args_ = new Args();
159 | dict_ = new Dictionary(args_);
160 | input_ = new Matrix();
161 | output_ = new Matrix();
162 |
163 | // Read magic number and version
164 | IOUtil ioutil = new IOUtil();
165 | int magic_number = ioutil.readInt(dis);
166 | int version = ioutil.readInt(dis);
167 |
168 | args_.load(dis);
169 | dict_.load(dis);
170 | boolean quant = dis.readBoolean();
171 | input_.load(dis);
172 | quant = dis.readBoolean();
173 | output_.load(dis);
174 |
175 | model_ = new Model(input_, output_, args_, 0);
176 | if (args_.model == model_name.sup) {
177 | model_.setTargetCounts(dict_.getCounts(entry_type.label));
178 | } else {
179 | model_.setTargetCounts(dict_.getCounts(entry_type.word));
180 | }
181 | } finally {
182 | if (bis != null) {
183 | bis.close();
184 | }
185 | if (dis != null) {
186 | dis.close();
187 | }
188 | }
189 | }
190 |
191 | public void printInfo(float progress, float loss) {
192 | float t = (float) (System.currentTimeMillis() - start_) / 1000;
193 | float ws = (float) (tokenCount_.get()) / t;
194 | float wst = (float) (tokenCount_.get()) / t / args_.thread;
195 | float lr = (float) (args_.lr * (1.0f - progress));
196 | int eta = (int) (t / progress * (1 - progress));
197 | int etah = eta / 3600;
198 | int etam = (eta - etah * 3600) / 60;
199 | System.out.printf("\rProgress: %.1f%% words/sec: %d words/sec/thread: %d lr: %.6f loss: %.6f eta: %d h %d m",
200 | 100 * progress, (int) ws, (int) wst, lr, loss, etah, etam);
201 | }
202 |
203 | public void supervised(Model model, float lr, final List<Integer> line, final List<Integer> labels) {
204 | if (labels.size() == 0 || line.size() == 0)
205 | return;
206 | int i = Utils.randomInt(model.rng, 1, labels.size()) - 1;
207 | model.update(line, labels.get(i), lr);
208 | }
209 |
210 | public void cbow(Model model, float lr, final List<Integer> line) {
211 | List<Integer> bow = new ArrayList<Integer>();
212 | for (int w = 0; w < line.size(); w++) {
213 | int boundary = Utils.randomInt(model.rng, 1, args_.ws);
214 | bow.clear();
215 | for (int c = -boundary; c <= boundary; c++) {
216 | if (c != 0 && w + c >= 0 && w + c < line.size()) {
217 | final List<Integer> ngrams = dict_.getNgrams(line.get(w + c));
218 | bow.addAll(ngrams);
219 | }
220 | }
221 | model.update(bow, line.get(w), lr);
222 | }
223 | }
224 |
225 | public void skipgram(Model model, float lr, final List<Integer> line) {
226 | for (int w = 0; w < line.size(); w++) {
227 | int boundary = Utils.randomInt(model.rng, 1, args_.ws);
228 | final List<Integer> ngrams = dict_.getNgrams(line.get(w));
229 | for (int c = -boundary; c <= boundary; c++) {
230 | if (c != 0 && w + c >= 0 && w + c < line.size()) {
231 | model.update(ngrams, line.get(w + c), lr);
232 | }
233 | }
234 | }
235 | }
236 |
237 | public void test(InputStream in, int k) throws IOException, Exception {
238 | int nexamples = 0, nlabels = 0;
239 | double precision = 0.0f;
240 | List<Integer> line = new ArrayList<Integer>();
241 | List<Integer> labels = new ArrayList<Integer>();
242 |
243 | LineReader lineReader = null;
244 | try {
245 | lineReader = lineReaderClass_.getConstructor(InputStream.class, String.class).newInstance(in, charsetName_);
246 | String[] lineTokens;
247 | while ((lineTokens = lineReader.readLineTokens()) != null) {
248 | if (lineTokens.length == 1 && "quit".equals(lineTokens[0])) {
249 | break;
250 | }
251 | dict_.getLine(lineTokens, line, labels, model_.rng);
252 | dict_.addNgrams(line, args_.wordNgrams);
253 | if (labels.size() > 0 && line.size() > 0) {
254 | List<Pair<Float, Integer>> modelPredictions = new ArrayList<Pair<Float, Integer>>();
255 | model_.predict(line, k, modelPredictions);
256 | for (Pair<Float, Integer> pair : modelPredictions) {
257 | if (labels.contains(pair.getValue())) {
258 | precision += 1.0f;
259 | }
260 | }
261 | nexamples++;
262 | nlabels += labels.size();
263 | // } else {
264 | // System.out.println("FAIL Test line: " + lineTokens +
265 | // "labels: " + labels + " line: " + line);
266 | }
267 | }
268 | } finally {
269 | if (lineReader != null) {
270 | lineReader.close();
271 | }
272 | }
273 |
274 | System.out.printf("P@%d: %.3f%n", k, precision / (k * nexamples));
275 | System.out.printf("R@%d: %.3f%n", k, precision / nlabels);
276 | System.out.println("Number of examples: " + nexamples);
277 | }
278 |
279 | /**
280 | * Thread-safe predict api
281 | *
282 | * @param lineTokens
283 | * @param k
284 | * @return
285 | */
286 | public List<Pair<Double, String>> predict(String[] lineTokens, int k) {
287 | List<Integer> words = new ArrayList<Integer>();
288 | List<Integer> labels = new ArrayList<Integer>();
289 | dict_.getLine(lineTokens, words, labels, model_.rng);
290 | dict_.addNgrams(words, args_.wordNgrams);
291 |
292 | if (words.isEmpty()) {
293 | return null;
294 | }
295 |
296 | Vector hidden = new Vector(args_.dim);
297 | Vector output = new Vector(dict_.nlabels());
298 | List<Pair<Float, Integer>> modelPredictions = new ArrayList<Pair<Float, Integer>>(k + 1);
299 |
300 | model_.predict(words, k, modelPredictions, hidden, output);
301 |
302 | List<Pair<Double, String>> predictions = new ArrayList<Pair<Double, String>>(k);
303 | for (Pair<Float, Integer> pair : modelPredictions) {
304 | predictions.add(new Pair<Double, String>(Math.exp(pair.getKey().doubleValue()), dict_.getLabel(pair.getValue())));
305 | }
306 | return predictions;
307 | }
308 |
309 | public void predict(String[] lineTokens, int k, List<Pair<Float, String>> predictions) throws IOException {
310 | List<Integer> words = new ArrayList<Integer>();
311 | List<Integer> labels = new ArrayList<Integer>();
312 | dict_.getLine(lineTokens, words, labels, model_.rng);
313 | dict_.addNgrams(words, args_.wordNgrams);
314 |
315 | if (words.isEmpty()) {
316 | return;
317 | }
318 | List<Pair<Float, Integer>> modelPredictions = new ArrayList<Pair<Float, Integer>>(k + 1);
319 | model_.predict(words, k, modelPredictions);
320 | predictions.clear();
321 | for (Pair<Float, Integer> pair : modelPredictions) {
322 | predictions.add(new Pair<Float, String>(pair.getKey(), dict_.getLabel(pair.getValue())));
323 | }
324 | }
325 |
326 | public void predict(InputStream in, int k, boolean print_prob) throws IOException, Exception {
327 | List<Pair<Float, String>> predictions = new ArrayList<Pair<Float, String>>(k);
328 |
329 | LineReader lineReader = null;
330 |
331 | try {
332 | lineReader = lineReaderClass_.getConstructor(InputStream.class, String.class).newInstance(in, charsetName_);
333 | String[] lineTokens;
334 | while ((lineTokens = lineReader.readLineTokens()) != null) {
335 | if (lineTokens.length == 1 && "quit".equals(lineTokens[0])) {
336 | break;
337 | }
338 | predictions.clear();
339 | predict(lineTokens, k, predictions);
340 | if (predictions.isEmpty()) {
341 | System.out.println("n/a");
342 | continue;
343 | }
344 | for (Pair<Float, String> pair : predictions) {
345 | System.out.print(pair.getValue());
346 | if (print_prob) {
347 | System.out.printf(" %f", Math.exp(pair.getKey()));
348 | }
349 | }
350 | System.out.println();
351 | }
352 | } finally {
353 | if (lineReader != null) {
354 | lineReader.close();
355 | }
356 | }
357 | }
358 |
359 | public void wordVectors() {
360 | Vector vec = new Vector(args_.dim);
361 | LineReader lineReader = null;
362 | try {
363 | lineReader = lineReaderClass_.getConstructor(InputStream.class, String.class).newInstance(System.in,
364 | charsetName_);
365 | String word;
366 | while (!Utils.isEmpty((word = lineReader.readLine()))) {
367 | getVector(vec, word);
368 | System.out.println(word + " " + vec);
369 | }
370 | } catch (Exception e) {
371 | e.printStackTrace();
372 | } finally {
373 | if (lineReader != null) {
374 | try {
375 | lineReader.close();
376 | } catch (IOException e) {
377 | e.printStackTrace();
378 | }
379 | }
380 | }
381 | }
382 |
383 | public void textVectors() {
384 | List<Integer> line = new ArrayList<Integer>();
385 | List<Integer> labels = new ArrayList<Integer>();
386 | Vector vec = new Vector(args_.dim);
387 | LineReader lineReader = null;
388 | try {
389 | lineReader = lineReaderClass_.getConstructor(InputStream.class, String.class).newInstance(System.in,
390 | charsetName_);
391 | String[] lineTokens;
392 | while ((lineTokens = lineReader.readLineTokens()) != null) {
393 | if (lineTokens.length == 1 && "quit".equals(lineTokens[0])) {
394 | break;
395 | }
396 | dict_.getLine(lineTokens, line, labels, model_.rng);
397 | dict_.addNgrams(line, args_.wordNgrams);
398 | vec.zero();
399 | for (Integer it : line) {
400 | vec.addRow(input_, it);
401 | }
402 | if (!line.isEmpty()) {
403 | vec.mul(1.0f / line.size());
404 | }
405 | System.out.println(vec);
406 | }
407 | } catch (Exception e) {
408 | e.printStackTrace();
409 | } finally {
410 | if (lineReader != null) {
411 | try {
412 | lineReader.close();
413 | } catch (IOException e) {
414 | e.printStackTrace();
415 | }
416 | }
417 | }
418 | }
419 |
420 | public Vector sentenceVectors(String[] tokens) {
421 | Vector vec = new Vector(args_.dim);
422 | Vector svec = new Vector(args_.dim);
423 | svec.zero();
424 | int count = 0;
425 | for (int i=0; i < tokens.length; i++) {
426 | getVector(vec, tokens[i]);
427 | float norm = vec.norm();
428 | if (norm != 0.0f) {
429 | vec.mul((float) (1.0 / vec.norm()));
430 | svec.addVector(vec);
431 | count++;
432 | }
433 | }
434 | if (count != 0) svec.mul((float) (1.0 / count));
435 | else svec.zero();
436 | return svec;
437 | }
438 |
439 | public void printVectors() {
440 | if (args_.model == model_name.sup) {
441 | textVectors();
442 | } else {
443 | wordVectors();
444 | }
445 | }
446 |
447 | public class TrainThread extends Thread {
448 | final FastText ft;
449 | int threadId;
450 |
451 | public TrainThread(FastText ft, int threadId) {
452 | super("FT-TrainThread-" + threadId);
453 | this.ft = ft;
454 | this.threadId = threadId;
455 | }
456 |
457 | public void run() {
458 | if (args_.verbose > 2) {
459 | System.out.println("thread: " + threadId + " RUNNING!");
460 | }
461 | Exception catchedException = null;
462 | LineReader lineReader = null;
463 | try {
464 | lineReader = lineReaderClass_.getConstructor(String.class, String.class).newInstance(args_.input,
465 | charsetName_);
466 | lineReader.skipLine(threadId * threadFileSize / args_.thread);
467 | Model model = new Model(input_, output_, args_, threadId);
468 | if (args_.model == model_name.sup) {
469 | model.setTargetCounts(dict_.getCounts(entry_type.label));
470 | } else {
471 | model.setTargetCounts(dict_.getCounts(entry_type.word));
472 | }
473 |
474 | final long ntokens = dict_.ntokens();
475 | long localTokenCount = 0;
476 |
477 | List<Integer> line = new ArrayList<Integer>();
478 | List<Integer> labels = new ArrayList<Integer>();
479 |
480 | String[] lineTokens;
481 | while (tokenCount_.get() < args_.epoch * ntokens) {
482 | lineTokens = lineReader.readLineTokens();
483 | if (lineTokens == null) {
484 | try {
485 | lineReader.rewind();
486 | if (args_.verbose > 2) {
487 | System.out.println("Input file reloaded!");
488 | }
489 | } catch (Exception e) {
490 | e.printStackTrace();
491 | }
492 | lineTokens = lineReader.readLineTokens();
493 | }
494 |
495 | float progress = (float) (tokenCount_.get()) / (args_.epoch * ntokens);
496 | float lr = (float) (args_.lr * (1.0 - progress));
497 | localTokenCount += dict_.getLine(lineTokens, line, labels, model.rng);
498 | if (args_.model == model_name.sup) {
499 | dict_.addNgrams(line, args_.wordNgrams);
500 | if (labels.size() == 0 || line.size() == 0) {
501 | continue;
502 | }
503 | supervised(model, lr, line, labels);
504 | } else if (args_.model == model_name.cbow) {
505 | cbow(model, lr, line);
506 | } else if (args_.model == model_name.sg) {
507 | skipgram(model, lr, line);
508 | }
509 | if (localTokenCount > args_.lrUpdateRate) {
510 | tokenCount_.addAndGet(localTokenCount);
511 | localTokenCount = 0;
512 | if (threadId == 0 && args_.verbose > 1 && (System.currentTimeMillis() - start_) % 1000 == 0) {
513 | printInfo(progress, model.getLoss());
514 | }
515 | }
516 | }
517 |
518 | if (threadId == 0 && args_.verbose > 1) {
519 | printInfo(1.0f, model.getLoss());
520 | }
521 | } catch (Exception e) {
522 | catchedException = e;
523 | } finally {
524 | if (lineReader != null)
525 | try {
526 | lineReader.close();
527 | } catch (IOException e) {
528 | e.printStackTrace();
529 | }
530 | }
531 |
532 | // exit from thread
533 | synchronized (ft) {
534 | if (args_.verbose > 2) {
535 | System.out.println("\nthread: " + threadId + " EXIT!");
536 | }
537 | ft.threadCount--;
538 | ft.notify();
539 | if (catchedException != null) {
540 | throw new RuntimeException(catchedException);
541 | }
542 | }
543 | }
544 | }
545 |
546 | public void loadVectors(String filename) throws IOException {
547 | List<String> words;
548 | Matrix mat; // temp. matrix for pretrained vectors
549 | int n, dim;
550 |
551 | BufferedReader dis = null;
552 | String line;
553 | String[] lineParts;
554 | try {
555 | dis = new BufferedReader(new InputStreamReader(new FileInputStream(filename), "UTF-8"));
556 |
557 | line = dis.readLine();
558 | lineParts = line.split(" ");
559 | n = Integer.parseInt(lineParts[0]);
560 | dim = Integer.parseInt(lineParts[1]);
561 |
562 | words = new ArrayList<String>(n);
563 |
564 | if (dim != args_.dim) {
565 | throw new IllegalArgumentException(
566 | "Dimension of pretrained vectors does not match args -dim option, pretrain dim is " + dim
567 | + ", args dim is " + args_.dim);
568 | }
569 |
570 | mat = new Matrix(n, dim);
571 | for (int i = 0; i < n; i++) {
572 | line = dis.readLine();
573 | lineParts = line.split(" ");
574 | String word = lineParts[0];
575 | for (int j = 1; j <= dim; j++) {
576 | mat.data_[i][j - 1] = Float.parseFloat(lineParts[j]);
577 | }
578 | words.add(word);
579 | dict_.add(word);
580 | }
581 |
582 | dict_.threshold(1, 0);
583 | input_ = new Matrix(dict_.nwords() + args_.bucket, args_.dim);
584 | input_.uniform(1.0f / args_.dim);
585 | for (int i = 0; i < n; i++) {
586 | int idx = dict_.getId(words.get(i));
587 | if (idx < 0 || idx >= dict_.nwords())
588 | continue;
589 | for (int j = 0; j < dim; j++) {
590 | input_.data_[idx][j] = mat.data_[i][j];
591 | }
592 | }
593 |
594 | } catch (IOException e) {
595 | throw new IOException("Pretrained vectors file cannot be opened!", e);
596 | } finally {
597 | try {
598 | if (dis != null) {
599 | dis.close();
600 | }
601 | } catch (IOException e) {
602 | e.printStackTrace();
603 | }
604 | }
605 | }
606 |
607 | int threadCount;
608 | long threadFileSize;
609 |
610 | public void train(Args args) throws IOException, Exception {
611 | args_ = args;
612 | dict_ = new Dictionary(args_);
613 | dict_.setCharsetName(charsetName_);
614 | dict_.setLineReaderClass(lineReaderClass_);
615 |
616 | if ("-".equals(args_.input)) {
617 | throw new IOException("Cannot use stdin for training!");
618 | }
619 |
620 | File file = new File(args_.input);
621 | if (!(file.exists() && file.isFile() && file.canRead())) {
622 | throw new IOException("Input file cannot be opened! " + args_.input);
623 | }
624 |
625 | dict_.readFromFile(args_.input);
626 | threadFileSize = Utils.sizeLine(args_.input);
627 |
628 | if (!Utils.isEmpty(args_.pretrainedVectors)) {
629 | loadVectors(args_.pretrainedVectors);
630 | } else {
631 | input_ = new Matrix(dict_.nwords() + args_.bucket, args_.dim);
632 | input_.uniform(1.0f / args_.dim);
633 | }
634 |
635 | if (args_.model == model_name.sup) {
636 | output_ = new Matrix(dict_.nlabels(), args_.dim);
637 | } else {
638 | output_ = new Matrix(dict_.nwords(), args_.dim);
639 | }
640 | output_.zero();
641 |
642 | start_ = System.currentTimeMillis();
643 | tokenCount_ = new AtomicLong(0);
644 | long t0 = System.currentTimeMillis();
645 | threadCount = args_.thread;
646 | for (int i = 0; i < args_.thread; i++) {
647 | Thread t = new TrainThread(this, i);
648 | t.setUncaughtExceptionHandler(trainThreadExceptionHandler);
649 | t.start();
650 | }
651 |
652 | synchronized (this) {
653 | while (threadCount > 0) {
654 | try {
655 | wait();
656 | } catch (InterruptedException ignored) {
657 | }
658 | }
659 | }
660 |
661 | model_ = new Model(input_, output_, args_, 0);
662 |
663 | if (args.verbose > 1) {
664 | long trainTime = (System.currentTimeMillis() - t0) / 1000;
665 | System.out.printf("\nTrain time used: %d sec\n", trainTime);
666 | }
667 |
668 | saveModel();
669 | if (args_.model != model_name.sup) {
670 | saveVectors();
671 | }
672 | }
673 |
674 | protected Thread.UncaughtExceptionHandler trainThreadExceptionHandler = new Thread.UncaughtExceptionHandler() {
675 | public void uncaughtException(Thread th, Throwable ex) {
676 | ex.printStackTrace();
677 | }
678 | };
679 |
680 | public Args getArgs() {
681 | return args_;
682 | }
683 |
684 | public Dictionary getDict() {
685 | return dict_;
686 | }
687 |
688 | public Matrix getInput() {
689 | return input_;
690 | }
691 |
692 | public Matrix getOutput() {
693 | return output_;
694 | }
695 |
696 | public Model getModel() {
697 | return model_;
698 | }
699 |
700 | public void setArgs(Args args) {
701 | this.args_ = args;
702 | }
703 |
704 | public void setDict(Dictionary dict) {
705 | this.dict_ = dict;
706 | }
707 |
708 | public void setInput(Matrix input) {
709 | this.input_ = input;
710 | }
711 |
712 | public void setOutput(Matrix output) {
713 | this.output_ = output;
714 | }
715 |
716 | public void setModel(Model model) {
717 | this.model_ = model;
718 | }
719 |
720 | public String getCharsetName() {
721 | return charsetName_;
722 | }
723 |
724 | public Class<? extends LineReader> getLineReaderClass() {
725 | return lineReaderClass_;
726 | }
727 |
728 | public void setCharsetName(String charsetName) {
729 | this.charsetName_ = charsetName;
730 | }
731 |
732 | public void setLineReaderClass(Class<? extends LineReader> lineReaderClass) {
733 | this.lineReaderClass_ = lineReaderClass;
734 | }
735 |
736 | }
737 |
--------------------------------------------------------------------------------
/src/main/java/fasttext/IOUtil.java:
--------------------------------------------------------------------------------
1 | package fasttext;
2 |
3 | import java.io.IOException;
4 | import java.io.InputStream;
5 | import java.nio.ByteBuffer;
6 | import java.nio.ByteOrder;
7 |
8 | /**
9 | * Read/write C++ primitive types (little-endian byte order)
10 | *
11 | * @author Ivan
12 | *
13 | */
14 | public class IOUtil {
15 |
16 | public IOUtil() {
17 | }
18 |
19 | private int string_buf_size_ = 128;
20 | private byte[] int_bytes_ = new byte[4];
21 | private byte[] long_bytes_ = new byte[8];
22 | private byte[] float_bytes_ = new byte[4];
23 | private byte[] double_bytes_ = new byte[8];
24 | private byte[] string_bytes_ = new byte[string_buf_size_];
25 | private StringBuilder stringBuilder_ = new StringBuilder();
26 | private ByteBuffer float_array_bytebuffer_ = null;
27 | private byte[] float_array_bytes_ = null;
28 |
29 | public void setStringBufferSize(int size) {
30 | string_buf_size_ = size;
31 | string_bytes_ = new byte[string_buf_size_];
32 | }
33 |
34 | public void setFloatArrayBufferSize(int itemSize) {
35 | float_array_bytebuffer_ = ByteBuffer.allocate(itemSize * 4).order(ByteOrder.LITTLE_ENDIAN);
36 | float_array_bytes_ = new byte[itemSize * 4];
37 | }
38 |
39 | public int readByte(InputStream is) throws IOException {
40 | return is.read() & 0xFF;
41 | }
42 |
43 | public int readInt(InputStream is) throws IOException {
44 | is.read(int_bytes_);
45 | return getInt(int_bytes_);
46 | }
47 |
48 | public int getInt(byte[] b) {
49 | return (b[0] & 0xFF) << 0 | (b[1] & 0xFF) << 8 | (b[2] & 0xFF) << 16 | (b[3] & 0xFF) << 24;
50 | }
51 |
52 | public long readLong(InputStream is) throws IOException {
53 | is.read(long_bytes_);
54 | return getLong(long_bytes_);
55 | }
56 |
57 | public long getLong(byte[] b) {
58 | return (b[0] & 0xFFL) << 0 | (b[1] & 0xFFL) << 8 | (b[2] & 0xFFL) << 16 | (b[3] & 0xFFL) << 24
59 | | (b[4] & 0xFFL) << 32 | (b[5] & 0xFFL) << 40 | (b[6] & 0xFFL) << 48 | (b[7] & 0xFFL) << 56;
60 | }
61 |
62 | public float readFloat(InputStream is) throws IOException {
63 | is.read(float_bytes_);
64 | return getFloat(float_bytes_);
65 | }
66 |
67 | public void readFloat(InputStream is, float[] data) throws IOException {
68 | is.read(float_array_bytes_);
69 | float_array_bytebuffer_.clear();
70 | ((ByteBuffer) float_array_bytebuffer_.put(float_array_bytes_).flip()).asFloatBuffer().get(data);
71 | }
72 |
73 | public float getFloat(byte[] b) {
74 | return Float
75 | .intBitsToFloat((b[0] & 0xFF) << 0 | (b[1] & 0xFF) << 8 | (b[2] & 0xFF) << 16 | (b[3] & 0xFF) << 24);
76 | }
77 |
78 | public double readDouble(InputStream is) throws IOException {
79 | is.read(double_bytes_);
80 | return getDouble(double_bytes_);
81 | }
82 |
83 | public double getDouble(byte[] b) {
84 | return Double.longBitsToDouble(getLong(b));
85 | }
86 |
87 | public String readString(InputStream is) throws IOException {
88 | int b = is.read();
89 | if (b < 0) {
90 | return null;
91 | }
92 | int i = -1;
93 | stringBuilder_.setLength(0);
94 | // ascii space, \n, \0
95 | while (b > -1 && b != 32 && b != 10 && b != 0) {
96 | string_bytes_[++i] = (byte) b;
97 | b = is.read();
98 | if (i == string_buf_size_ - 1) {
99 | stringBuilder_.append(new String(string_bytes_/*, "utf8"*/));
100 | i = -1;
101 | }
102 | }
103 | stringBuilder_.append(new String(string_bytes_, 0, i + 1/*, "utf8"*/));
104 | return stringBuilder_.toString();
105 | }
106 |
107 | public int intToByte(int i) {
108 | return (i & 0xFF);
109 | }
110 |
111 | public byte[] intToByteArray(int i) {
112 | int_bytes_[0] = (byte) ((i >> 0) & 0xff);
113 | int_bytes_[1] = (byte) ((i >> 8) & 0xff);
114 | int_bytes_[2] = (byte) ((i >> 16) & 0xff);
115 | int_bytes_[3] = (byte) ((i >> 24) & 0xff);
116 | return int_bytes_;
117 | }
118 |
119 | public byte[] longToByteArray(long l) {
120 | long_bytes_[0] = (byte) ((l >> 0) & 0xff);
121 | long_bytes_[1] = (byte) ((l >> 8) & 0xff);
122 | long_bytes_[2] = (byte) ((l >> 16) & 0xff);
123 | long_bytes_[3] = (byte) ((l >> 24) & 0xff);
124 | long_bytes_[4] = (byte) ((l >> 32) & 0xff);
125 | long_bytes_[5] = (byte) ((l >> 40) & 0xff);
126 | long_bytes_[6] = (byte) ((l >> 48) & 0xff);
127 | long_bytes_[7] = (byte) ((l >> 56) & 0xff);
128 |
129 | return long_bytes_;
130 | }
131 |
132 | public byte[] floatToByteArray(float f) {
133 | return intToByteArray(Float.floatToIntBits(f));
134 | }
135 |
136 | public byte[] floatToByteArray(float[] f) {
137 | float_array_bytebuffer_.clear();
138 | float_array_bytebuffer_.asFloatBuffer().put(f);
139 | return float_array_bytebuffer_.array();
140 | }
141 |
142 | public byte[] doubleToByteArray(double d) {
143 | return longToByteArray(Double.doubleToRawLongBits(d));
144 | }
145 |
146 | }
147 |
--------------------------------------------------------------------------------
/src/main/java/fasttext/Main.java:
--------------------------------------------------------------------------------
1 | package fasttext;
2 |
3 | import java.io.File;
4 | import java.io.FileInputStream;
5 | import java.io.IOException;
6 |
7 | public class Main {
8 |
9 | public static void printUsage() {
10 | System.out.print("usage: java -jar fasttext.jar \n\n"
11 | + "The commands supported by fasttext are:\n\n"
12 | + " supervised train a supervised classifier\n"
13 | + " test evaluate a supervised classifier\n"
14 | + " predict predict most likely labels\n"
15 | + " predict-prob predict most likely labels with probabilities\n"
16 | + " skipgram train a skipgram model\n"
17 | + " cbow train a cbow model\n"
18 | + " print-vectors print vectors given a trained model\n");
19 | }
20 |
21 | public static void printTestUsage() {
22 | System.out.print("usage: java -jar fasttext.jar test []\n\n"
23 | + " model filename\n"
24 | + " test data filename (if -, read from stdin)\n"
25 | + " (optional; 1 by default) predict top k labels\n");
26 | }
27 |
28 | public static void printPredictUsage() {
29 | System.out.print("usage: java -jar fasttext.jar predict[-prob] []\n\n"
30 | + " model filename\n"
31 | + " test data filename (if -, read from stdin)\n"
32 | + " (optional; 1 by default) predict top k labels\n");
33 | }
34 |
35 | public static void printPrintVectorsUsage() {
36 | System.out.print("usage: java -jar fasttext.jar print-vectors \n\n"
37 | + " model filename\n");
38 | }
39 |
40 | public void test(String[] args) throws IOException, Exception {
41 | int k = 1;
42 | if (args.length == 3) {
43 | k = 1;
44 | } else if (args.length == 4) {
45 | k = Integer.parseInt(args[3]);
46 | } else {
47 | printTestUsage();
48 | System.exit(1);
49 | }
50 | FastText fasttext = new FastText();
51 | fasttext.loadModel(args[1]);
52 | String infile = args[2];
53 | if ("-".equals(infile)) {
54 | fasttext.test(System.in, k);
55 | } else {
56 | File file = new File(infile);
57 | if (!(file.exists() && file.isFile() && file.canRead())) {
58 | throw new IOException("Test file cannot be opened!");
59 | }
60 | fasttext.test(new FileInputStream(file), k);
61 | }
62 | }
63 |
64 | public void predict(String[] args) throws IOException, Exception {
65 | int k = 1;
66 | if (args.length == 3) {
67 | k = 1;
68 | } else if (args.length == 4) {
69 | k = Integer.parseInt(args[3]);
70 | } else {
71 | printPredictUsage();
72 | System.exit(1);
73 | }
74 | boolean print_prob = "predict-prob".equalsIgnoreCase(args[0]);
75 | FastText fasttext = new FastText();
76 | fasttext.loadModel(args[1]);
77 |
78 | String infile = args[2];
79 | if ("-".equals(infile)) {
80 | fasttext.predict(System.in, k, print_prob);
81 | } else {
82 | File file = new File(infile);
83 | if (!(file.exists() && file.isFile() && file.canRead())) {
84 | throw new IOException("Input file cannot be opened!");
85 | }
86 | fasttext.predict(new FileInputStream(file), k, print_prob);
87 | }
88 | }
89 |
90 | public void printVectors(String[] args) throws IOException {
91 | if (args.length != 2) {
92 | printPrintVectorsUsage();
93 | System.exit(1);
94 | }
95 | FastText fasttext = new FastText();
96 | fasttext.loadModel(args[1]);
97 | fasttext.printVectors();
98 | }
99 |
100 | public void train(String[] args) throws IOException, Exception {
101 | Args a = new Args();
102 | a.parseArgs(args);
103 | FastText fasttext = new FastText();
104 | fasttext.train(a);
105 | }
106 |
107 | public static void main(String[] args) {
108 | Main op = new Main();
109 |
110 | if (args.length == 0) {
111 | printUsage();
112 | System.exit(1);
113 | }
114 |
115 | try {
116 | String command = args[0];
117 | if ("skipgram".equalsIgnoreCase(command) || "cbow".equalsIgnoreCase(command)
118 | || "supervised".equalsIgnoreCase(command)) {
119 | op.train(args);
120 | } else if ("test".equalsIgnoreCase(command)) {
121 | op.test(args);
122 | } else if ("print-vectors".equalsIgnoreCase(command)) {
123 | op.printVectors(args);
124 | } else if ("predict".equalsIgnoreCase(command) || "predict-prob".equalsIgnoreCase(command)) {
125 | op.predict(args);
126 | } else {
127 | printUsage();
128 | System.exit(1);
129 | }
130 | } catch (Exception e) {
131 | e.printStackTrace();
132 | System.exit(1);
133 | }
134 |
135 | System.exit(0);
136 | }
137 |
138 | }
139 |
--------------------------------------------------------------------------------
/src/main/java/fasttext/Matrix.java:
--------------------------------------------------------------------------------
1 | package fasttext;
2 |
3 | import java.io.IOException;
4 | import java.io.InputStream;
5 | import java.io.OutputStream;
6 | import java.util.Random;
7 |
8 | public strictfp class Matrix {
9 |
10 | public float[][] data_ = null;
11 | public int m_ = 0; // vocabSize
12 | public int n_ = 0; // layer1Size
13 |
14 | public Matrix() {
15 | }
16 |
17 | public Matrix(int m, int n) {
18 | m_ = m;
19 | n_ = n;
20 | data_ = new float[m][n];
21 | }
22 |
23 | public Matrix(final Matrix other) {
24 | m_ = other.m_;
25 | n_ = other.n_;
26 | data_ = new float[m_][n_];
27 | for (int i = 0; i < m_; i++) {
28 | for (int j = 0; j < n_; j++) {
29 | data_[i][j] = other.data_[i][j];
30 | }
31 | }
32 | }
33 |
34 | public void zero() {
35 | for (int i = 0; i < m_; i++) {
36 | for (int j = 0; j < n_; j++) {
37 | data_[i][j] = 0.0f;
38 | }
39 | }
40 | }
41 |
42 | public void uniform(float a) {
43 | Random random = new Random(1l);
44 | for (int i = 0; i < m_; i++) {
45 | for (int j = 0; j < n_; j++) {
46 | data_[i][j] = Utils.randomFloat(random, -a, a);
47 | }
48 | }
49 | }
50 |
51 | public void addRow(final Vector vec, int i, float a) {
52 | Utils.checkArgument(i >= 0);
53 | Utils.checkArgument(i < m_);
54 | Utils.checkArgument(vec.m_ == n_);
55 | for (int j = 0; j < n_; j++) {
56 | data_[i][j] += a * vec.data_[j];
57 | }
58 | }
59 |
60 | public float dotRow(final Vector vec, int i) {
61 | Utils.checkArgument(i >= 0);
62 | Utils.checkArgument(i < m_);
63 | Utils.checkArgument(vec.m_ == n_);
64 | float d = 0.0f;
65 | for (int j = 0; j < n_; j++) {
66 | d += data_[i][j] * vec.data_[j];
67 | }
68 | return d;
69 | }
70 |
71 | public void load(InputStream input) throws IOException {
72 | IOUtil ioutil = new IOUtil();
73 |
74 | m_ = (int) ioutil.readLong(input);
75 | n_ = (int) ioutil.readLong(input);
76 |
77 | ioutil.setFloatArrayBufferSize(n_);
78 | data_ = new float[m_][n_];
79 | for (int i = 0; i < m_; i++) {
80 | ioutil.readFloat(input, data_[i]);
81 | }
82 | }
83 |
84 | public void save(OutputStream ofs) throws IOException {
85 | IOUtil ioutil = new IOUtil();
86 | ioutil.setFloatArrayBufferSize(n_);
87 | ofs.write(ioutil.longToByteArray(m_));
88 | ofs.write(ioutil.longToByteArray(n_));
89 | for (int i = 0; i < m_; i++) {
90 | ofs.write(ioutil.floatToByteArray(data_[i]));
91 | }
92 | }
93 |
94 | @Override
95 | public String toString() {
96 | StringBuilder builder = new StringBuilder();
97 | builder.append("Matrix [data_=");
98 | if (data_ != null) {
99 | builder.append("[");
100 | for (int i = 0; i < m_ && i < 10; i++) {
101 | for (int j = 0; j < n_ && j < 10; j++) {
102 | builder.append(data_[i][j]).append(",");
103 | }
104 | }
105 | builder.setLength(builder.length() - 1);
106 | builder.append("]");
107 | } else {
108 | builder.append("null");
109 | }
110 | builder.append(", m_=");
111 | builder.append(m_);
112 | builder.append(", n_=");
113 | builder.append(n_);
114 | builder.append("]");
115 | return builder.toString();
116 | }
117 |
118 | }
119 |
--------------------------------------------------------------------------------
/src/main/java/fasttext/Model.java:
--------------------------------------------------------------------------------
1 | package fasttext;
2 |
3 | import java.util.ArrayList;
4 | import java.util.Collections;
5 | import java.util.Comparator;
6 | import java.util.List;
7 | import java.util.Random;
8 |
9 | import fasttext.Args.loss_name;
10 | import fasttext.Args.model_name;
11 |
12 | public strictfp class Model {
13 |
14 | static final int SIGMOID_TABLE_SIZE = 512;
15 | static final int MAX_SIGMOID = 8;
16 | static final int LOG_TABLE_SIZE = 512;
17 |
18 | static final int NEGATIVE_TABLE_SIZE = 10000000;
19 |
20 | public class Node {
21 | int parent;
22 | int left;
23 | int right;
24 | long count;
25 | boolean binary;
26 | }
27 |
28 | private Matrix wi_; // input
29 | private Matrix wo_; // output
30 | private Args args_;
31 | private Vector hidden_;
32 | private Vector output_;
33 | private Vector grad_;
34 | private int hsz_; // dim
35 | @SuppressWarnings("unused")
36 | private int isz_; // input vocabSize
37 | private int osz_; // output vocabSize
38 | private float loss_;
39 | private long nexamples_;
40 | private float[] t_sigmoid;
41 | private float[] t_log;
42 | // used for negative sampling:
43 | private List<Integer> negatives;
44 | private int negpos;
45 | // used for hierarchical softmax:
46 | private List<List<Integer>> paths;
47 | private List<List<Boolean>> codes;
48 | private List<Node> tree;
49 |
50 | public transient Random rng;
51 |
52 | public Model(Matrix wi, Matrix wo, Args args, int seed) {
53 | hidden_ = new Vector(args.dim);
54 | output_ = new Vector(wo.m_);
55 | grad_ = new Vector(args.dim);
56 | rng = new Random((long) seed);
57 |
58 | wi_ = wi;
59 | wo_ = wo;
60 | args_ = args;
61 | isz_ = wi.m_;
62 | osz_ = wo.m_;
63 | hsz_ = args.dim;
64 | negpos = 0;
65 | loss_ = 0.0f;
66 | nexamples_ = 1l;
67 | initSigmoid();
68 | initLog();
69 | }
70 |
71 | public float binaryLogistic(int target, boolean label, float lr) {
72 | float score = sigmoid(wo_.dotRow(hidden_, target));
73 | float alpha = lr * ((label ? 1.0f : 0.0f) - score);
74 | grad_.addRow(wo_, target, alpha);
75 | wo_.addRow(hidden_, target, alpha);
76 | if (label) {
77 | return -log(score);
78 | } else {
79 | return -log(1.0f - score);
80 | }
81 | }
82 |
83 | public float negativeSampling(int target, float lr) {
84 | float loss = 0.0f;
85 | grad_.zero();
86 | for (int n = 0; n <= args_.neg; n++) {
87 | if (n == 0) {
88 | loss += binaryLogistic(target, true, lr);
89 | } else {
90 | loss += binaryLogistic(getNegative(target), false, lr);
91 | }
92 | }
93 | return loss;
94 | }
95 |
96 | public float hierarchicalSoftmax(int target, float lr) {
97 | float loss = 0.0f;
98 | grad_.zero();
99 | final List<Boolean> binaryCode = codes.get(target);
100 | final List<Integer> pathToRoot = paths.get(target);
101 | for (int i = 0; i < pathToRoot.size(); i++) {
102 | loss += binaryLogistic(pathToRoot.get(i), binaryCode.get(i), lr);
103 | }
104 | return loss;
105 | }
106 |
107 | public void computeOutputSoftmax(Vector hidden, Vector output) {
108 | output.mul(wo_, hidden);
109 | float max = output.get(0), z = 0.0f;
110 | for (int i = 1; i < osz_; i++) {
111 | max = Math.max(output.get(i), max);
112 | }
113 | for (int i = 0; i < osz_; i++) {
114 | output.set(i, (float) Math.exp(output.get(i) - max));
115 | z += output.get(i);
116 | }
117 | for (int i = 0; i < osz_; i++) {
118 | output.set(i, output.get(i) / z);
119 | }
120 | }
121 |
122 | public void computeOutputSoftmax() {
123 | computeOutputSoftmax(hidden_, output_);
124 | }
125 |
126 | public float softmax(int target, float lr) {
127 | grad_.zero();
128 | computeOutputSoftmax();
129 | for (int i = 0; i < osz_; i++) {
130 | float label = (i == target) ? 1.0f : 0.0f;
131 | float alpha = lr * (label - output_.get(i));
132 | grad_.addRow(wo_, i, alpha);
133 | wo_.addRow(hidden_, i, alpha);
134 | }
135 | return -log(output_.get(target));
136 | }
137 |
138 | public void computeHidden(final List<Integer> input, Vector hidden) {
139 | Utils.checkArgument(hidden.size() == hsz_);
140 | hidden.zero();
141 | for (Integer it : input) {
142 | hidden.addRow(wi_, it);
143 | }
144 | hidden.mul(1.0f / input.size());
145 | }
146 |
147 | private Comparator<Pair<Float, Integer>> comparePairs = new Comparator<Pair<Float, Integer>>() {
148 |
149 | @Override
150 | public int compare(Pair<Float, Integer> o1, Pair<Float, Integer> o2) {
151 | return o2.getKey().compareTo(o1.getKey());
152 | }
153 | };
154 |
155 | public void predict(final List<Integer> input, int k, List<Pair<Float, Integer>> heap, Vector hidden,
156 | Vector output) {
157 | Utils.checkArgument(k > 0);
158 | if (heap instanceof ArrayList) {
159 | ((ArrayList<Pair<Float, Integer>>) heap).ensureCapacity(k + 1);
160 | }
161 | computeHidden(input, hidden);
162 | if (args_.loss == loss_name.hs) {
163 | dfs(k, 2 * osz_ - 2, 0.0f, heap, hidden);
164 | } else {
165 | findKBest(k, heap, hidden, output);
166 | }
167 | Collections.sort(heap, comparePairs);
168 | }
169 |
170 | public void predict(final List<Integer> input, int k, List<Pair<Float, Integer>> heap) {
171 | predict(input, k, heap, hidden_, output_);
172 | }
173 |
174 | public void findKBest(int k, List<Pair<Float, Integer>> heap, Vector hidden, Vector output) {
175 | computeOutputSoftmax(hidden, output);
176 | for (int i = 0; i < osz_; i++) {
177 | if (heap.size() == k && log(output.get(i)) < heap.get(heap.size() - 1).getKey()) {
178 | continue;
179 | }
180 | heap.add(new Pair<Float, Integer>(log(output.get(i)), i));
181 | Collections.sort(heap, comparePairs);
182 | if (heap.size() > k) {
183 | Collections.sort(heap, comparePairs);
184 | heap.remove(heap.size() - 1); // pop last
185 | }
186 | }
187 | }
188 |
189 | public void dfs(int k, int node, float score, List<Pair<Float, Integer>> heap, Vector hidden) {
190 | if (heap.size() == k && score < heap.get(heap.size() - 1).getKey()) {
191 | return;
192 | }
193 |
194 | if (tree.get(node).left == -1 && tree.get(node).right == -1) {
195 | heap.add(new Pair<Float, Integer>(score, node));
196 | Collections.sort(heap, comparePairs);
197 | if (heap.size() > k) {
198 | Collections.sort(heap, comparePairs);
199 | heap.remove(heap.size() - 1); // pop last
200 | }
201 | return;
202 | }
203 |
204 | float f = sigmoid(wo_.dotRow(hidden, node - osz_));
205 | dfs(k, tree.get(node).left, score + log(1.0f - f), heap, hidden);
206 | dfs(k, tree.get(node).right, score + log(f), heap, hidden);
207 | }
208 |
209 | public void update(final List<Integer> input, int target, float lr) {
210 | Utils.checkArgument(target >= 0);
211 | Utils.checkArgument(target < osz_);
212 | if (input.size() == 0) {
213 | return;
214 | }
215 | computeHidden(input, hidden_);
216 |
217 | if (args_.loss == loss_name.ns) {
218 | loss_ += negativeSampling(target, lr);
219 | } else if (args_.loss == loss_name.hs) {
220 | loss_ += hierarchicalSoftmax(target, lr);
221 | } else {
222 | loss_ += softmax(target, lr);
223 | }
224 | nexamples_ += 1;
225 |
226 | if (args_.model == model_name.sup) {
227 | grad_.mul(1.0f / input.size());
228 | }
229 | for (Integer it : input) {
230 | wi_.addRow(grad_, it, 1.0f);
231 | }
232 | }
233 |
234 | public void setTargetCounts(final List<Long> counts) {
235 | Utils.checkArgument(counts.size() == osz_);
236 | if (args_.loss == loss_name.ns) {
237 | initTableNegatives(counts);
238 | }
239 | if (args_.loss == loss_name.hs) {
240 | buildTree(counts);
241 | }
242 | }
243 |
244 | public void initTableNegatives(final List<Long> counts) { // sampling table: word i appears proportionally to count^0.5
245 | negatives = new ArrayList<Integer>(counts.size());
246 | float z = 0.0f;
247 | for (int i = 0; i < counts.size(); i++) {
248 | z += (float) Math.pow(counts.get(i), 0.5f);
249 | }
250 | for (int i = 0; i < counts.size(); i++) {
251 | float c = (float) Math.pow(counts.get(i), 0.5f);
252 | for (int j = 0; j < c * NEGATIVE_TABLE_SIZE / z; j++) {
253 | negatives.add(i);
254 | }
255 | }
256 | Utils.shuffle(negatives, rng);
257 | }
258 |
259 | public int getNegative(int target) {
260 | int negative;
261 | do {
262 | negative = negatives.get(negpos);
263 | negpos = (negpos + 1) % negatives.size();
264 | } while (target == negative);
265 | return negative;
266 | }
267 |
268 | public void buildTree(final List<Long> counts) { // Huffman tree over output counts for hierarchical softmax
269 | paths = new ArrayList<List<Integer>>(osz_);
270 | codes = new ArrayList<List<Boolean>>(osz_);
271 | tree = new ArrayList<Node>(2 * osz_ - 1);
272 |
273 | for (int i = 0; i < 2 * osz_ - 1; i++) {
274 | Node node = new Node();
275 | node.parent = -1;
276 | node.left = -1;
277 | node.right = -1;
278 | node.count = 1000000000000000L;// 1e15f;
279 | node.binary = false;
280 | tree.add(i, node);
281 | }
282 | for (int i = 0; i < osz_; i++) {
283 | tree.get(i).count = counts.get(i);
284 | }
285 | int leaf = osz_ - 1;
286 | int node = osz_;
287 | for (int i = osz_; i < 2 * osz_ - 1; i++) {
288 | int[] mini = new int[2];
289 | for (int j = 0; j < 2; j++) {
290 | if (leaf >= 0 && tree.get(leaf).count < tree.get(node).count) {
291 | mini[j] = leaf--;
292 | } else {
293 | mini[j] = node++;
294 | }
295 | }
296 | tree.get(i).left = mini[0];
297 | tree.get(i).right = mini[1];
298 | tree.get(i).count = tree.get(mini[0]).count + tree.get(mini[1]).count;
299 | tree.get(mini[0]).parent = i;
300 | tree.get(mini[1]).parent = i;
301 | tree.get(mini[1]).binary = true;
302 | }
303 | for (int i = 0; i < osz_; i++) {
304 | List<Integer> path = new ArrayList<Integer>();
305 | List<Boolean> code = new ArrayList<Boolean>();
306 | int j = i;
307 | while (tree.get(j).parent != -1) {
308 | path.add(tree.get(j).parent - osz_);
309 | code.add(tree.get(j).binary);
310 | j = tree.get(j).parent;
311 | }
312 | paths.add(path);
313 | codes.add(code);
314 | }
315 | }
316 |
317 | public float getLoss() {
318 | return loss_ / nexamples_;
319 | }
320 |
321 | private void initSigmoid() {
322 | t_sigmoid = new float[SIGMOID_TABLE_SIZE + 1];
323 | for (int i = 0; i < SIGMOID_TABLE_SIZE + 1; i++) {
324 | float x = (float) (i * 2 * MAX_SIGMOID) / SIGMOID_TABLE_SIZE - MAX_SIGMOID;
325 | t_sigmoid[i] = (float) (1.0f / (1.0f + Math.exp(-x)));
326 | }
327 | }
328 |
329 | private void initLog() {
330 | t_log = new float[LOG_TABLE_SIZE + 1];
331 | for (int i = 0; i < LOG_TABLE_SIZE + 1; i++) {
332 | float x = (float) (((float) (i) + 1e-5f) / LOG_TABLE_SIZE);
333 | t_log[i] = (float) Math.log(x);
334 | }
335 | }
336 |
337 | public float log(float x) {
338 | if (x > 1.0f) {
339 | return 0.0f;
340 | }
341 | int i = (int) (x * LOG_TABLE_SIZE);
342 | return t_log[i];
343 | }
344 |
345 | public float sigmoid(float x) {
346 | if (x < -MAX_SIGMOID) {
347 | return 0.0f;
348 | } else if (x > MAX_SIGMOID) {
349 | return 1.0f;
350 | } else {
351 | int i = (int) ((x + MAX_SIGMOID) * SIGMOID_TABLE_SIZE / MAX_SIGMOID / 2);
352 | return t_sigmoid[i];
353 | }
354 | }
355 | }
356 |
--------------------------------------------------------------------------------
/src/main/java/fasttext/Pair.java:
--------------------------------------------------------------------------------
1 | package fasttext;
2 |
3 | public class Pair<K, V> {
4 |
5 | private K key_;
6 | private V value_;
7 |
8 | public Pair(K key, V value) {
9 | this.key_ = key;
10 | this.value_ = value;
11 | }
12 |
13 | public K getKey() {
14 | return key_;
15 | }
16 |
17 | public V getValue() {
18 | return value_;
19 | }
20 |
21 | public void setKey(K key) {
22 | this.key_ = key;
23 | }
24 |
25 | public void setValue(V value) {
26 | this.value_ = value;
27 | }
28 |
29 | }
30 |
--------------------------------------------------------------------------------
/src/main/java/fasttext/Utils.java:
--------------------------------------------------------------------------------
1 | package fasttext;
2 |
3 | import java.io.BufferedInputStream;
4 | import java.io.BufferedReader;
5 | import java.io.FileInputStream;
6 | import java.io.IOException;
7 | import java.io.InputStream;
8 | import java.util.List;
9 | import java.util.ListIterator;
10 | import java.util.Map;
11 | import java.util.Random;
12 | import java.util.RandomAccess;
13 |
14 | public strictfp class Utils {
15 |
16 | /**
17 | * Ensures the truth of an expression involving one or more parameters to
18 | * the calling method.
19 | *
20 | * @param expression
21 | * a boolean expression
22 | * @throws IllegalArgumentException
23 | * if {@code expression} is false
24 | */
25 | public static void checkArgument(boolean expression) {
26 | if (!expression) {
27 | throw new IllegalArgumentException();
28 | }
29 | }
30 |
31 | public static void checkArgument(boolean expression, String message) {
32 | if (!expression) {
33 | throw new IllegalArgumentException(message);
34 | }
35 | }
36 |
37 | public static boolean isEmpty(String str) {
38 | return (str == null || str.isEmpty());
39 | }
40 |
41 | public static <K, V> V mapGetOrDefault(Map<K, V> map, K key, V defaultValue) {
42 | return map.containsKey(key) ? map.get(key) : defaultValue;
43 | }
44 |
45 | public static int randomInt(Random rnd, int lower, int upper) {
46 | checkArgument(lower <= upper & lower > 0);
47 | if (lower == upper) {
48 | return lower;
49 | }
50 | return rnd.nextInt(upper - lower) + lower;
51 | }
52 |
53 | public static float randomFloat(Random rnd, float lower, float upper) {
54 | checkArgument(lower <= upper);
55 | if (lower == upper) {
56 | return lower;
57 | }
58 | return (rnd.nextFloat() * (upper - lower)) + lower;
59 | }
60 |
61 | public static long sizeLine(String filename) throws IOException {
62 | InputStream is = new BufferedInputStream(new FileInputStream(filename));
63 | try {
64 | byte[] c = new byte[1024];
65 | long count = 0;
66 | int readChars = 0;
67 | boolean endsWithoutNewLine = false;
68 | while ((readChars = is.read(c)) != -1) {
69 | for (int i = 0; i < readChars; ++i) {
70 | if (c[i] == '\n')
71 | ++count;
72 | }
73 | endsWithoutNewLine = (c[readChars - 1] != '\n');
74 | }
75 | if (endsWithoutNewLine) {
76 | ++count;
77 | }
78 | return count;
79 | } finally {
80 | is.close();
81 | }
82 | }
83 |
84 | /**
85 | *
86 | * @param br
87 | * @param pos
88 | * line numbers start from 1
89 | * @throws IOException
90 | */
91 | public static void seekLine(BufferedReader br, long pos) throws IOException {
92 | // br.reset();
93 | String line;
94 | int currentLine = 1;
95 | while (currentLine < pos && (line = br.readLine()) != null) {
96 | if (Utils.isEmpty(line) || line.startsWith("#")) {
97 | continue;
98 | }
99 | currentLine++;
100 | }
101 | }
102 |
103 | private static final int SHUFFLE_THRESHOLD = 5;
104 |
105 | @SuppressWarnings({ "rawtypes", "unchecked" })
106 | public static void shuffle(List<?> list, Random rnd) {
107 | int size = list.size();
108 | if (size < SHUFFLE_THRESHOLD || list instanceof RandomAccess) {
109 | for (int i = size; i > 1; i--)
110 | swap(list, i - 1, rnd.nextInt(i));
111 | } else {
112 | Object arr[] = list.toArray();
113 |
114 | // Shuffle array
115 | for (int i = size; i > 1; i--)
116 | swap(arr, i - 1, rnd.nextInt(i));
117 |
118 | // Dump array back into list
119 | // instead of using a raw type here, it's possible to capture
120 | // the wildcard but it will require a call to a supplementary
121 | // private method
122 | ListIterator it = list.listIterator();
123 | for (int i = 0; i < arr.length; i++) {
124 | it.next();
125 | it.set(arr[i]);
126 | }
127 | }
128 | }
129 |
130 | /**
131 | * Swaps the two specified elements in the specified array.
132 | */
133 | public static void swap(Object[] arr, int i, int j) {
134 | Object tmp = arr[i];
135 | arr[i] = arr[j];
136 | arr[j] = tmp;
137 | }
138 |
139 | @SuppressWarnings({ "rawtypes", "unchecked" })
140 | public static void swap(List<?> list, int i, int j) {
141 | // instead of using a raw type here, it's possible to capture
142 | // the wildcard but it will require a call to a supplementary
143 | // private method
144 | final List l = list;
145 | l.set(i, l.set(j, l.get(i)));
146 | }
147 |
148 | }
149 |
--------------------------------------------------------------------------------
/src/main/java/fasttext/Vector.java:
--------------------------------------------------------------------------------
1 | package fasttext;
2 |
3 | public strictfp class Vector {
4 |
5 | public int m_;
6 | public float[] data_;
7 |
8 | public void addVector(Vector source) {
9 | for (int i = 0; i < m_; i++) {
10 | data_[i] += source.data_[i];
11 | }
12 | }
13 |
14 | public float norm() {
15 | float sum = 0;
16 | for (int i = 0; i < m_; i++) {
17 | sum += data_[i] * data_[i];
18 | }
19 | return (float) Math.sqrt(sum);
20 | }
21 |
22 | public Vector(int size) {
23 | m_ = size;
24 | data_ = new float[size];
25 | }
26 |
27 | public int size() {
28 | return m_;
29 | }
30 |
31 | public void zero() {
32 | for (int i = 0; i < m_; i++) {
33 | data_[i] = 0.0f;
34 | }
35 | }
36 |
37 | public void mul(float a) {
38 | for (int i = 0; i < m_; i++) {
39 | data_[i] *= a;
40 | }
41 | }
42 |
43 | public void addRow(final Matrix A, int i) {
44 | Utils.checkArgument(i >= 0);
45 | Utils.checkArgument(i < A.m_);
46 | Utils.checkArgument(m_ == A.n_);
47 | for (int j = 0; j < A.n_; j++) { // layer size
48 | data_[j] += A.data_[i][j];
49 | }
50 | }
51 |
52 | public void addRow(final Matrix A, int i, float a) {
53 | Utils.checkArgument(i >= 0);
54 | Utils.checkArgument(i < A.m_);
55 | Utils.checkArgument(m_ == A.n_);
56 | for (int j = 0; j < A.n_; j++) {
57 | data_[j] += a * A.data_[i][j];
58 | }
59 | }
60 |
61 | public void mul(final Matrix A, final Vector vec) {
62 | Utils.checkArgument(A.m_ == m_);
63 | Utils.checkArgument(A.n_ == vec.m_);
64 | for (int i = 0; i < m_; i++) {
65 | data_[i] = 0.0f;
66 | for (int j = 0; j < A.n_; j++) {
67 | data_[i] += A.data_[i][j] * vec.data_[j];
68 | }
69 | }
70 | }
71 |
72 | public int argmax() {
73 | float max = data_[0];
74 | int argmax = 0;
75 | for (int i = 1; i < m_; i++) {
76 | if (data_[i] > max) {
77 | max = data_[i];
78 | argmax = i;
79 | }
80 | }
81 | return argmax;
82 | }
83 |
84 | public float get(int i) {
85 | return data_[i];
86 | }
87 |
88 | public void set(int i, float value) {
89 | data_[i] = value;
90 | }
91 |
92 | @Override
93 | public String toString() {
94 | StringBuilder builder = new StringBuilder();
95 | for (float data : data_) {
96 | builder.append(data).append(' ');
97 | }
98 | if (builder.length() > 1) {
99 | builder.setLength(builder.length() - 1);
100 | }
101 | return builder.toString();
102 | }
103 |
104 | }
105 |
--------------------------------------------------------------------------------
/src/main/java/fasttext/io/BufferedLineReader.java:
--------------------------------------------------------------------------------
1 | package fasttext.io;
2 |
3 | import java.io.BufferedReader;
4 | import java.io.FileInputStream;
5 | import java.io.IOException;
6 | import java.io.InputStream;
7 | import java.io.InputStreamReader;
8 | import java.io.UnsupportedEncodingException;
9 |
10 | public class BufferedLineReader extends LineReader {
11 |
12 | private String lineDelimitingRegex_ = " |\r|\t|\\v|\f|\0";
13 |
14 | private BufferedReader br_;
15 |
16 | public BufferedLineReader(String filename, String charsetName) throws IOException, UnsupportedEncodingException {
17 | super(filename, charsetName);
18 | FileInputStream fis = new FileInputStream(file_);
19 | br_ = new BufferedReader(new InputStreamReader(fis, charset_));
20 | }
21 |
22 | public BufferedLineReader(InputStream inputStream, String charsetName) throws UnsupportedEncodingException {
23 | super(inputStream, charsetName);
24 | br_ = new BufferedReader(new InputStreamReader(inputStream, charset_));
25 | }
26 |
27 | @Override
28 | public long skipLine(long n) throws IOException {
29 | if (n < 0L) {
30 | throw new IllegalArgumentException("skip value is negative");
31 | }
32 | String line;
33 | long currentLine = 0;
34 | long readLine = 0;
35 | synchronized (lock) {
36 | while (currentLine < n && (line = br_.readLine()) != null) {
37 | readLine++;
38 | if (line == null || line.isEmpty() || line.startsWith("#")) {
39 | continue;
40 | }
41 | currentLine++;
42 | }
43 | return readLine;
44 | }
45 | }
46 |
47 | @Override
48 | public String readLine() throws IOException {
49 | synchronized (lock) {
50 | String lineString = br_.readLine();
51 | while (lineString != null && (lineString.isEmpty() || lineString.startsWith("#"))) {
52 | lineString = br_.readLine();
53 | }
54 | return lineString;
55 | }
56 | }
57 |
58 | @Override
59 | public String[] readLineTokens() throws IOException {
60 | String line = readLine();
61 | if (line == null)
62 | return null;
63 | else
64 | return line.split(lineDelimitingRegex_, -1);
65 | }
66 |
67 | @Override
68 | public int read(char[] cbuf, int off, int len) throws IOException {
69 | synchronized (lock) {
70 | return br_.read(cbuf, off, len);
71 | }
72 | }
73 |
74 | @Override
75 | public void close() throws IOException {
76 | synchronized (lock) {
77 | if (br_ != null) {
78 | br_.close();
79 | }
80 | }
81 | }
82 |
83 | @Override
84 | public void rewind() throws IOException {
85 | synchronized (lock) {
86 | if (br_ != null) {
87 | br_.close();
88 | }
89 | if (file_ != null) {
90 | FileInputStream fis = new FileInputStream(file_);
91 | br_ = new BufferedReader(new InputStreamReader(fis, charset_));
92 | } else {
93 | // br = new BufferedReader(new InputStreamReader(inputStream,
94 | // charset));
95 | throw new UnsupportedOperationException("InputStream rewind not supported");
96 | }
97 | }
98 | }
99 |
100 | public String getLineDelimitingRegex() {
101 | return lineDelimitingRegex_;
102 | }
103 |
104 | public void setLineDelimitingRegex(String lineDelimitingRegex) {
105 | this.lineDelimitingRegex_ = lineDelimitingRegex;
106 | }
107 |
108 | }
109 |
--------------------------------------------------------------------------------
/src/main/java/fasttext/io/LineReader.java:
--------------------------------------------------------------------------------
1 | package fasttext.io;
2 |
3 | import java.io.File;
4 | import java.io.IOException;
5 | import java.io.InputStream;
6 | import java.io.Reader;
7 | import java.io.UnsupportedEncodingException;
8 | import java.nio.charset.Charset;
9 |
10 | public abstract class LineReader extends Reader {
11 |
12 | protected InputStream inputStream_ = null;
13 | protected File file_ = null;
14 | protected Charset charset_ = null;
15 |
16 | protected LineReader() {
17 | super();
18 | }
19 |
20 | protected LineReader(Object lock) {
21 | super(lock);
22 | }
23 |
24 | public LineReader(String filename, String charsetName) throws IOException, UnsupportedEncodingException {
25 | this();
26 | this.file_ = new File(filename);
27 | this.charset_ = Charset.forName(charsetName);
28 | }
29 |
30 | public LineReader(InputStream inputStream, String charsetName) throws UnsupportedEncodingException {
31 | this();
32 | this.inputStream_ = inputStream;
33 | this.charset_ = Charset.forName(charsetName);
34 | }
35 |
36 | /**
37 | * Skips lines.
38 | *
39 | * @param n
40 | * The number of lines to skip
41 | * @return The number of lines actually skipped
42 | * @exception IOException
43 | * If an I/O error occurs
44 | * @exception IllegalArgumentException
45 | * If n is negative.
46 | */
47 | public abstract long skipLine(long n) throws IOException;
48 |
49 | public abstract String readLine() throws IOException;
50 |
51 | public abstract String[] readLineTokens() throws IOException;
52 |
53 | public abstract void rewind() throws IOException;
54 | }
55 |
--------------------------------------------------------------------------------
/src/main/java/fasttext/io/MappedByteBufferLineReader.java:
--------------------------------------------------------------------------------
1 | package fasttext.io;
2 |
3 | import java.io.BufferedInputStream;
4 | import java.io.IOException;
5 | import java.io.InputStream;
6 | import java.io.RandomAccessFile;
7 | import java.io.UnsupportedEncodingException;
8 | import java.nio.ByteBuffer;
9 | import java.nio.CharBuffer;
10 | import java.nio.channels.FileChannel;
11 | import java.util.ArrayList;
12 | import java.util.List;
13 |
14 | public class MappedByteBufferLineReader extends LineReader {
15 |
16 | private static int DEFAULT_BUFFER_SIZE = 1024;
17 |
18 | private volatile ByteBuffer byteBuffer_ = null; // MappedByteBuffer
19 | private RandomAccessFile raf_ = null;
20 | private FileChannel channel_ = null;
21 | private byte[] bytes_ = null;
22 |
23 | private int string_buf_size_ = DEFAULT_BUFFER_SIZE;
24 | private boolean fillLine_ = false;
25 |
26 | private StringBuilder sb_ = new StringBuilder();
27 | private List<String> tokens_ = new ArrayList<String>();
28 |
29 | public MappedByteBufferLineReader(String filename, String charsetName)
30 | throws IOException, UnsupportedEncodingException {
31 | super(filename, charsetName);
32 | raf_ = new RandomAccessFile(file_, "r");
33 | channel_ = raf_.getChannel();
34 | byteBuffer_ = channel_.map(FileChannel.MapMode.READ_ONLY, 0, channel_.size());
35 | bytes_ = new byte[string_buf_size_];
36 | }
37 |
38 | public MappedByteBufferLineReader(InputStream inputStream, String charsetName) throws UnsupportedEncodingException {
39 | this(inputStream, charsetName, DEFAULT_BUFFER_SIZE);
40 | }
41 |
42 | public MappedByteBufferLineReader(InputStream inputStream, String charsetName, int buf_size)
43 | throws UnsupportedEncodingException {
44 | super(inputStream instanceof BufferedInputStream ? inputStream : new BufferedInputStream(inputStream),
45 | charsetName);
46 | string_buf_size_ = buf_size;
47 | byteBuffer_ = ByteBuffer.allocateDirect(string_buf_size_); // ByteBuffer.allocate(string_buf_size_);
48 | bytes_ = new byte[string_buf_size_];
49 | if (inputStream == System.in) {
50 | fillLine_ = true;
51 | }
52 | }
53 |
54 | @Override
55 | public long skipLine(long n) throws IOException {
56 | if (n < 0L) {
57 | throw new IllegalArgumentException("skip value is negative");
58 | }
59 | String line;
60 | long currentLine = 0;
61 | long readLine = 0;
62 | synchronized (lock) {
63 | ensureOpen();
64 | while (currentLine < n && (line = getLine()) != null) {
65 | readLine++;
66 | if (line == null || line.isEmpty() || line.startsWith("#")) {
67 | continue;
68 | }
69 | currentLine++;
70 | }
71 | }
72 | return readLine;
73 | }
74 |
75 | @Override
76 | public String readLine() throws IOException {
77 | synchronized (lock) {
78 | ensureOpen();
79 | String lineString = getLine();
80 | while (lineString != null && (lineString.isEmpty() || lineString.startsWith("#"))) {
81 | lineString = getLine();
82 | }
83 | return lineString;
84 | }
85 | }
86 |
87 | @Override
88 | public String[] readLineTokens() throws IOException {
89 | synchronized (lock) {
90 | ensureOpen();
91 | String[] tokens = getLineTokens();
92 | while (tokens != null && ((tokens.length == 1 && tokens[0].isEmpty()) || tokens[0].startsWith("#"))) {
93 | tokens = getLineTokens();
94 | }
95 | return tokens;
96 | }
97 | }
98 |
99 | @Override
100 | public void rewind() throws IOException {
101 | synchronized (lock) {
102 | ensureOpen();
103 | if (raf_ != null) {
104 | raf_.seek(0);
105 | channel_.position(0);
106 | }
107 | byteBuffer_.position(0);
108 | }
109 | }
110 |
111 | @Override
112 | public int read(char[] cbuf, int off, int len) throws IOException {
113 | synchronized (lock) {
114 | ensureOpen();
115 | if ((off < 0) || (off > cbuf.length) || (len < 0) || ((off + len) > cbuf.length) || ((off + len) < 0)) {
116 | throw new IndexOutOfBoundsException();
117 | } else if (len == 0) {
118 | return 0;
119 | }
120 |
121 | CharBuffer charBuffer = byteBuffer_.asCharBuffer();
122 | int length = Math.min(len, charBuffer.remaining());
123 | charBuffer.get(cbuf, off, length);
124 |
125 | if (inputStream_ != null) {
126 | off += length;
127 |
128 | while (off < len) {
129 | fillByteBuffer();
130 | if (!byteBuffer_.hasRemaining()) {
131 | break;
132 | }
133 | charBuffer = byteBuffer_.asCharBuffer();
134 | length = Math.min(len, charBuffer.remaining());
135 | charBuffer.get(cbuf, off, length);
136 | off += length;
137 | }
138 | }
139 | return length == len ? len : -1;
140 | }
141 | }
142 |
143 | @Override
144 | public void close() throws IOException {
145 | synchronized (lock) {
146 | if (raf_ != null) {
147 | raf_.close();
148 | } else if (inputStream_ != null) {
149 | inputStream_.close();
150 | }
151 | channel_ = null;
152 | byteBuffer_ = null;
153 | }
154 | }
155 |
156 | /** Checks to make sure that the stream has not been closed */
157 | private void ensureOpen() throws IOException {
158 | if (byteBuffer_ == null)
159 | throw new IOException("Stream closed");
160 | }
161 |
162 | protected String getLine() throws IOException {
163 | fillByteBuffer();
164 | if (!byteBuffer_.hasRemaining()) {
165 | return null;
166 | }
167 | sb_.setLength(0);
168 | int b = -1;
169 | int i = -1;
170 | do {
171 | b = byteBuffer_.get();
172 | if ((b >= 10 && b <= 13) || b == 0) {
173 | break;
174 | }
175 | bytes_[++i] = (byte) b;
176 | if (i == string_buf_size_ - 1) {
177 | sb_.append(new String(bytes_, charset_));
178 | i = -1;
179 | }
180 | fillByteBuffer();
181 | } while (byteBuffer_.hasRemaining());
182 |
183 | sb_.append(new String(bytes_, 0, i + 1, charset_));
184 | return sb_.toString();
185 | }
186 |
187 | // " |\r|\t|\\v|\f|\0"
188 | // 32 ' ', 9 \t, 10 \n, 11 \\v, 12 \f, 13 \r, 0 \0
189 | protected String[] getLineTokens() throws IOException {
190 | fillByteBuffer();
191 | if (!byteBuffer_.hasRemaining()) {
192 | return null;
193 | }
194 | tokens_.clear();
195 | sb_.setLength(0);
196 |
197 | int b = -1;
198 | int i = -1;
199 | do {
200 | b = byteBuffer_.get();
201 |
202 | if ((b >= 10 && b <= 13) || b == 0) {
203 | break;
204 | } else if (b == 9 || b == 32) {
205 | sb_.append(new String(bytes_, 0, i + 1, charset_));
206 | tokens_.add(sb_.toString());
207 | sb_.setLength(0);
208 | i = -1;
209 | } else {
210 | bytes_[++i] = (byte) b;
211 | if (i == string_buf_size_ - 1) {
212 | sb_.append(new String(bytes_, charset_));
213 | i = -1;
214 | }
215 | }
216 | fillByteBuffer();
217 | } while (byteBuffer_.hasRemaining());
218 |
219 | sb_.append(new String(bytes_, 0, i + 1, charset_));
220 | tokens_.add(sb_.toString());
221 | return tokens_.toArray(new String[tokens_.size()]);
222 | }
223 |
224 | private void fillByteBuffer() throws IOException {
225 | if (inputStream_ == null || byteBuffer_.hasRemaining()) {
226 | return;
227 | }
228 |
229 | byteBuffer_.clear();
230 |
231 | int b;
232 | for (int i = 0; i < string_buf_size_; i++) {
233 | b = inputStream_.read();
234 | if (b < 0) { // END OF STREAM
235 | break;
236 | }
237 | byteBuffer_.put((byte) b);
238 | if (fillLine_) {
239 | if ((b >= 10 && b <= 13) || b == 0) {
240 | break;
241 | }
242 | }
243 | }
244 |
245 | byteBuffer_.flip();
246 | }
247 |
248 | }
249 |
--------------------------------------------------------------------------------
/src/test/java/fasttext/TestDictionary.java:
--------------------------------------------------------------------------------
1 | package fasttext;
2 |
3 | import static org.junit.Assert.*;
4 |
5 | import java.util.Map;
6 |
7 | import org.junit.Test;
8 |
9 | public class TestDictionary {
10 |
11 | Dictionary dictionary = new Dictionary(new Args());
12 |
13 | @Test
14 | public void testHash() {
15 | assertEquals(dictionary.hash(","), 688690635l);
16 | assertEquals(dictionary.hash("is"), 1312329493l);
17 | assertEquals(dictionary.hash(""), 3617362777l);
18 | }
19 |
20 | @Test
21 | public void testFind() {
22 | assertEquals(dictionary.find(","), 28690635l);
23 | assertEquals(dictionary.find("is"), 22329493l);
24 | assertEquals(dictionary.find(""), 17362777l);
25 | }
26 |
27 | @Test
28 | public void testAdd() {
29 | dictionary.add(",");
30 | dictionary.add("is");
31 | dictionary.add("is");
32 | String w = "";
33 | dictionary.add(w);
34 | dictionary.add(w);
35 | dictionary.add(w);
36 | Map<Long, Integer> word2int = dictionary.getWord2int();
37 | assertEquals(3, dictionary.getWords().get(word2int.get(dictionary.find(w))).count);
38 | assertEquals(2, dictionary.getWords().get(word2int.get(dictionary.find("is"))).count);
39 | assertEquals(1, dictionary.getWords().get(word2int.get(dictionary.find(","))).count);
40 | }
41 | }
42 |
--------------------------------------------------------------------------------
/wikifil.pl:
--------------------------------------------------------------------------------
1 | #!/usr/bin/perl
2 |
3 | # Program to filter Wikipedia XML dumps to "clean" text consisting only of lowercase
4 | # letters (a-z, converted from A-Z), and spaces (never consecutive).
5 | # All other characters are converted to spaces. Only text which normally appears
6 | # in the web browser is displayed. Tables are removed. Image captions are
7 | # preserved. Links are converted to normal text. Digits are spelled out.
8 |
9 | # Written by Matt Mahoney, June 10, 2006. This program is released to the public domain.
10 |
11 | $/=">"; # input record separator
12 | while (<>) {
13 | if (/<text /) {$text=1;} # remove all but between <text> ... </text>
14 | if (/#redirect/i) {$text=0;} # remove #REDIRECT
15 | if ($text) {
16 |
17 | # Remove any text not normally visible
18 | if (/<\/text>/) {$text=0;}
19 | s/<.*>//; # remove xml tags
20 | s/&amp;/&/g; # decode URL encoded chars
21 | s/&lt;/</g;
22 | s/&gt;/>/g;
23 | s/<ref[^<]*<\/ref>//g; # remove references <ref...> ... </ref>
24 | s/<[^>]*>//g; # remove xhtml tags
25 | s/\[http:[^] ]*/[/g; # remove normal url, preserve visible text
26 | s/\|thumb//ig; # remove images links, preserve caption
27 | s/\|left//ig;
28 | s/\|right//ig;
29 | s/\|\d+px//ig;
30 | s/\[\[image:[^\[\]]*\|//ig;
31 | s/\[\[category:([^|\]]*)[^]]*\]\]/[[$1]]/ig; # show categories without markup
32 | s/\[\[[a-z\-]*:[^\]]*\]\]//g; # remove links to other languages
33 | s/\[\[[^\|\]]*\|/[[/g; # remove wiki url, preserve visible text
34 | s/{{[^}]*}}//g; # remove {{icons}} and {tables}
35 | s/{[^}]*}//g;
36 | s/\[//g; # remove [ and ]
37 | s/\]//g;
38 | s/&[^;]*;/ /g; # remove URL encoded chars
39 |
40 | # convert to lowercase letters and spaces, spell digits
41 | $_=" $_ ";
42 | tr/A-Z/a-z/;
43 | s/0/ zero /g;
44 | s/1/ one /g;
45 | s/2/ two /g;
46 | s/3/ three /g;
47 | s/4/ four /g;
48 | s/5/ five /g;
49 | s/6/ six /g;
50 | s/7/ seven /g;
51 | s/8/ eight /g;
52 | s/9/ nine /g;
53 | tr/a-z/ /cs;
54 | chop;
55 | print $_;
56 | }
57 | }
58 |
--------------------------------------------------------------------------------
/word-vector-example.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | #
3 | # Copyright (c) 2016-present, Facebook, Inc.
4 | # All rights reserved.
5 | #
6 | # This source code is licensed under the BSD-style license found in the
7 | # LICENSE file in the root directory of this source tree. An additional grant
8 | # of patent rights can be found in the PATENTS file in the same directory.
9 | #
10 |
11 | RESULTDIR=result
12 | DATADIR=data
13 |
14 | mkdir -p "${RESULTDIR}"
15 | mkdir -p "${DATADIR}"
16 |
17 | if [ ! -f "${DATADIR}/text9" ]
18 | then
19 | wget -c http://mattmahoney.net/dc/enwik9.zip -P "${DATADIR}"
20 | unzip "${DATADIR}/enwik9.zip" -d "${DATADIR}"
21 | perl wikifil.pl "${DATADIR}/enwik9" > "${DATADIR}"/text9
22 | fi
23 |
24 | if [ ! -f "${DATADIR}/rw/rw.txt" ]
25 | then
26 | wget -c http://stanford.edu/~lmthang/morphoNLM/rw.zip -P "${DATADIR}"
27 | unzip "${DATADIR}/rw.zip" -d "${DATADIR}"
28 | fi
29 |
30 | mvn package
31 |
32 | JAR=./target/fasttext-0.0.1-SNAPSHOT-jar-with-dependencies.jar
33 |
34 | java -jar ${JAR} skipgram -input "${DATADIR}"/text9 -output "${RESULTDIR}"/text9 -lr 0.025 -dim 100 \
35 | -ws 5 -epoch 1 -minCount 5 -neg 5 -loss ns -bucket 2000000 \
36 | -minn 3 -maxn 6 -thread 4 -t 1e-4 -lrUpdateRate 100
37 |
38 | cut -f 1,2 "${DATADIR}"/rw/rw.txt | awk '{print tolower($0)}' | tr '\t' '\n' > "${DATADIR}"/queries.txt
39 |
40 | cat "${DATADIR}"/queries.txt | ./fasttext print-vectors "${RESULTDIR}"/text9.bin > "${RESULTDIR}"/vectors.txt
41 |
42 | python eval.py -m "${RESULTDIR}"/vectors.txt -d "${DATADIR}"/rw/rw.txt
43 |
--------------------------------------------------------------------------------