├── .gitignore
├── CODEOWNERS
├── Dockerfile
├── LICENSE
├── OpenNMT-py
│   ├── .gitignore
│   ├── LICENSE.md
│   ├── README.md
│   ├── cache_embeddings.py
│   ├── corenlp_tokenize.py
│   ├── get_embed_for_dict.py
│   ├── iwslt_xml2txt.py
│   ├── multi30k_corenlp_tokenizer.sh
│   ├── onmt
│   │   ├── Beam.py
│   │   ├── Constants.py
│   │   ├── Dataset.py
│   │   ├── Dict.py
│   │   ├── Models.py
│   │   ├── Optim.py
│   │   ├── Translator.py
│   │   ├── __init__.py
│   │   └── modules
│   │       ├── GlobalAttention.py
│   │       └── __init__.py
│   ├── preprocess.py
│   ├── train.py
│   ├── translate.py
│   ├── wmt_clean.py
│   └── wmt_sgm2txt.py
├── README.md
├── cove
│   ├── __init__.py
│   └── encoder.py
├── get_data.sh
├── requirements.txt
├── setup.py
├── test
│   └── example.py
└── wmt_clean.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 |
--------------------------------------------------------------------------------
/CODEOWNERS:
--------------------------------------------------------------------------------
1 | # Comment line immediately above ownership line is reserved for related gus information. Please be careful while editing.
2 | #ECCN:Open Source
3 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | # docker build --no-cache -t multitasking .
2 | FROM nvidia/cuda:9.0-cudnn7-devel-ubuntu16.04
3 |
4 | RUN apt-get update \
5 | && apt-get install -y --no-install-recommends \
6 | git \
7 | ssh \
8 | build-essential \
9 | locales \
10 | ca-certificates \
11 | curl \
12 | unzip
13 |
14 | RUN curl -o ~/miniconda.sh -O https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
15 | chmod +x ~/miniconda.sh && \
16 | ~/miniconda.sh -b -p /opt/conda && \
17 | rm ~/miniconda.sh && \
18 | /opt/conda/bin/conda install conda-build python=3.6.3 numpy pyyaml mkl && \
19 | /opt/conda/bin/conda clean -ya
20 | ENV PATH /opt/conda/bin:$PATH
21 |
22 | # Default to utf-8 encodings in python
23 | # Can verify in container with:
24 | # python -c 'import locale; print(locale.getpreferredencoding(False))'
25 | RUN locale-gen en_US.UTF-8
26 | ENV LANG en_US.UTF-8
27 | ENV LANGUAGE en_US:en
28 | ENV LC_ALL en_US.UTF-8
29 |
30 | RUN conda install -c pytorch pytorch cuda90
31 |
32 | RUN pip install tqdm
33 | RUN pip install requests
34 | RUN pip install git+https://github.com/pytorch/text.git
35 | ADD ./README.md /README.md
36 | ADD ./cove/ /cove/
37 | ADD ./setup.py /setup.py
38 | RUN python setup.py develop
39 |
40 | ADD ./test/ /test/
41 |
42 | CMD python /test/example.py
43 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2017, Salesforce.com, Inc.
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | * Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | * Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | * Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/OpenNMT-py/.gitignore:
--------------------------------------------------------------------------------
1 | pred.txt
2 | multi-bleu.perl
3 | *.pt
4 | *.pyc
5 |
--------------------------------------------------------------------------------
/OpenNMT-py/LICENSE.md:
--------------------------------------------------------------------------------
1 | This software is derived from the OpenNMT project at
2 | https://github.com/OpenNMT/OpenNMT.
3 |
4 | The MIT License (MIT)
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in
14 | all copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22 | THE SOFTWARE.
23 |
--------------------------------------------------------------------------------
/OpenNMT-py/README.md:
--------------------------------------------------------------------------------
1 | # OpenNMT: Open-Source Neural Machine Translation
2 |
3 | This is a [Pytorch](https://github.com/pytorch/pytorch)
4 | port of [OpenNMT](https://github.com/OpenNMT/OpenNMT),
5 | an open-source (MIT) neural machine translation system.
6 |
7 |
8 |
9 | # Requirements
10 |
12 | ## Some useful tools:
13 |
14 | The example below uses the Moses tokenizer (http://www.statmt.org/moses/) to prepare the data and the Moses BLEU script for evaluation.
15 |
16 | ```bash
18 | wget https://raw.githubusercontent.com/moses-smt/mosesdecoder/master/scripts/tokenizer/tokenizer.perl
19 | wget https://raw.githubusercontent.com/moses-smt/mosesdecoder/master/scripts/tokenizer/lowercase.perl
20 | sed -i "s/$RealBin\/..\/share\/nonbreaking_prefixes//" tokenizer.perl
21 | wget https://raw.githubusercontent.com/moses-smt/mosesdecoder/master/scripts/share/nonbreaking_prefixes/nonbreaking_prefix.de
22 | wget https://raw.githubusercontent.com/moses-smt/mosesdecoder/master/scripts/share/nonbreaking_prefixes/nonbreaking_prefix.en
23 | wget https://raw.githubusercontent.com/moses-smt/mosesdecoder/master/scripts/generic/multi-bleu.perl
24 | ```
25 | ## WMT'16 Multimodal Translation: Multi30k (de-en)
26 |
27 | An example of training for the WMT'16 Multimodal Translation task (http://www.statmt.org/wmt16/multimodal-task.html).
28 |
29 | ### 0) Download the data.
30 |
31 | ```bash
32 | mkdir -p data/multi30k
33 | wget http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/training.tar.gz && tar -xf training.tar.gz -C data/multi30k && rm training.tar.gz
34 | wget http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz && tar -xf validation.tar.gz -C data/multi30k && rm validation.tar.gz
35 | wget https://staff.fnwi.uva.nl/d.elliott/wmt16/mmt16_task1_test.tgz && tar -xf mmt16_task1_test.tgz -C data/multi30k && rm mmt16_task1_test.tgz
36 | for l in en de; do for f in data/multi30k/*.$l; do if [[ "$f" != *"test"* ]]; then sed -i "$ d" $f; fi; done; done
37 | ```
38 |
39 | The last line of each train and validation file is blank, so the final command above (the sed loop) deletes those trailing blank lines.
40 |
41 | ### 1) Preprocess the data.
42 |
43 | Moses tokenization without HTML escaping (add the -a option after -no-escape for aggressive hyphen splitting):
44 |
45 | ```bash
46 | for l in en de; do for f in data/multi30k/*.$l; do perl tokenizer.perl -no-escape -l $l -q < $f > $f.tok; done; done
47 | ```
48 |
49 | Typically, we lowercase this dataset, as the important comparisons are in uncased BLEU:
50 |
51 | ```bash
52 | for f in data/multi30k/*.tok; do perl lowercase.perl < $f > $f.low; done # if you ran Moses
53 | ```
54 |
55 | If you would like to use the Moses tokenization for source and target, prepare the data for the model like so:
56 |
57 | ```bash
58 | python preprocess.py -train_src data/multi30k/train.en.tok.low -train_tgt data/multi30k/train.de.tok.low -valid_src data/multi30k/val.en.tok.low -valid_tgt data/multi30k/val.de.tok.low -save_data data/multi30k.tok.low -lower
59 | ```
60 |
64 | The -lower option in the command above ensures that the vocabulary object converts all words to lowercase before lookup.
65 |
66 | If you would like to use GloVe vectors and character embeddings, now's the time:
67 |
68 | ```bash
69 | python get_embed_for_dict.py data/multi30k.tok.low.src.dict -glove -chargram -d_hid 400
70 | python get_embed_for_dict.py data/multi30k.tok.low.src.dict -glove -d_hid 300
71 | ```
72 |
73 | ### 2) Train the model.
74 |
75 | ```bash
76 | python train.py -data data/multi30k.tok.low.train.pt -save_model snapshots/multi30k.tok.low.600h.400d.2dp.brnn.2l.fixed_glove_char.model -brnn -pre_word_vecs_enc data/multi30k.tok.low.src.dict.glove.chargram -fix_embed
77 |
78 | python train.py -data data/multi30k.tok.low.train.pt -save_model snapshots/multi30k.tok.low.600h.300d.2dp.brnn.2l.fixed_glove.model -brnn -rnn_size 600 -word_vec_size 300 -pre_word_vecs_enc data/multi30k.tok.low.src.dict.glove -fix_embed
79 | ```
80 |
81 | ### 3) Translate sentences.
82 |
83 | ```bash
84 | python translate.py -gpu 0 -model model_name -src data/multi30k/test.en.tok.low -tgt data/multi30k/test.de.tok.low -replace_unk -verbose -output multi30k.tok.low.test.pred
85 |
86 | ```
87 |
88 | ### 4) Evaluate.
89 |
90 | ```bash
91 | perl multi-bleu.perl data/multi30k/test.de.tok.low < multi30k.tok.low.test.pred
92 | ```
93 |
94 | ## IWSLT'16 (de-en)
95 |
96 | ### 0) Download the data.
97 |
98 | ```bash
99 | mkdir -p data/iwslt16
100 | wget https://wit3.fbk.eu/archive/2016-01//texts/de/en/de-en.tgz && tar -xf de-en.tgz -C data
101 | ```
102 |
103 | ### 1) Preprocess the data.
104 |
105 | ```bash
106 | python iwslt_xml2txt.py data/de-en
107 | python iwslt_xml2txt.py data/de-en -a
108 |
109 | python preprocess.py -train_src data/de-en/train.de-en.en.tok -train_tgt data/de-en/train.de-en.de.tok -valid_src data/de-en/IWSLT16.TED.tst2013.de-en.en.tok -valid_tgt data/de-en/IWSLT16.TED.tst2013.de-en.de.tok -save_data data/iwslt16.tok.low -lower -src_vocab_size 22822 -tgt_vocab_size 32009
110 |
111 | # GloVe vectors + character n-grams
112 | python get_embed_for_dict.py data/iwslt16.tok.low.src.dict -glove -chargram
113 | python get_embed_for_dict.py data/iwslt16.tok.low.src.dict -glove
114 | ```
115 |
116 | ### 2) Train the model.
117 |
118 | ```bash
119 | python train.py -data data/iwslt16.tok.low.train.pt -save_model snapshots/iwslt16.tok.low.600h.400d.2dp.brnn.2l.fixed_glove_char.model -gpus 0 -brnn -rnn_size 600 -fix_embed -pre_word_vecs_enc data/iwslt16.tok.low.src.dict.glove.chargram > iwslt16.clean.tok.low.600h.400d.2l.brnn.2dp.fixed_glove_char.log
120 |
121 | python train.py -data data/iwslt16.tok.low.train.pt -save_model snapshots/iwslt16.tok.low.600h.300d.2dp.brnn.2l.fixed_glove_char.model -gpus 0 -brnn -rnn_size 600 -word_vec_size 300 -fix_embed -pre_word_vecs_enc data/iwslt16.tok.low.src.dict.glove > iwslt16.tok.low.600h.300d.2dp.brnn.2l.fixed_glove.log
122 | ```
123 |
124 | ### 3) Translate sentences.
125 |
126 | ```bash
127 | python translate.py -gpu 0 -model model_name -src data/de-en/IWSLT16.TED.tst2014.de-en.en.tok -tgt data/de-en/IWSLT16.TED.tst2014.de-en.de.tok -replace_unk -verbose -output iwslt.ted.tst2014.de-en.tok.low.pred
128 | ```
129 |
130 | ### 4) Evaluate.
131 |
132 | ```bash
133 | perl multi-bleu.perl data/de-en/IWSLT16.TED.tst2014.de-en.de.tok < iwslt.ted.tst2014.de-en.tok.low.pred
134 | ```
135 |
136 | ## WMT'17 (de-en)
137 |
138 | ### 0) Download the data.
139 |
140 | ```bash
141 | mkdir -p data/wmt17
142 | cd data/wmt17
143 | wget http://www.statmt.org/wmt13/training-parallel-europarl-v7.tgz
144 | wget http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz
145 | wget http://data.statmt.org/wmt17/translation-task/training-parallel-nc-v12.tgz
146 | wget http://data.statmt.org/wmt17/translation-task/rapid2016.tgz
147 | wget http://data.statmt.org/wmt17/translation-task/dev.tgz
148 | tar -xzf training-parallel-europarl-v7.tgz
149 | tar -xzf training-parallel-commoncrawl.tgz
150 | tar -xzf training-parallel-nc-v12.tgz
151 | tar -xzf rapid2016.tgz
152 | tar -xzf dev.tgz
153 | mkdir de-en
154 | mv *de-en* de-en
155 | mv training/*de-en* de-en
156 | mv dev/*deen* de-en
157 | mv dev/*ende* de-en
158 | mv dev/*.de de-en
159 | mv dev/*.en de-en
160 | mv dev/newstest2009*.en* de-en
161 | mv dev/news-test2008*.en* de-en
162 |
163 | python ../../wmt_clean.py de-en
164 | for l in de; do for f in de-en/*.clean.$l; do perl ../../tokenizer.perl -no-escape -l $l -q < $f > $f.tok; done; done
165 | for l in en; do for f in de-en/*.clean.$l; do perl ../../tokenizer.perl -no-escape -l $l -q < $f > $f.tok; done; done
166 | for l in en de; do for f in de-en/*.clean.$l.tok; do perl ../../lowercase.perl < $f > $f.low; done; done
167 | for l in en de; do perl ../../tokenizer.perl -no-escape -l $l -q < de-en/newstest2013.$l > de-en/newstest2013.$l.tok; done
168 | for l in en de; do perl ../../lowercase.perl < de-en/newstest2013.$l.tok > de-en/newstest2013.$l.tok.low; done
169 | for l in en de; do cat de-en/commoncrawl*.clean.$l.tok.low de-en/europarl*.clean.$l.tok.low de-en/news-commentary*.clean.$l.tok.low de-en/rapid*.clean.$l.tok.low > de-en/train.clean.$l.tok.low; done
170 | ```
171 |
172 | ### 1) Preprocess the data.
173 |
174 | ```bash
175 | # News Commentary
176 | python preprocess.py -train_src data/wmt17/de-en/news-commentary-v12.de-en.clean.en.tok.low -train_tgt data/wmt17/de-en/news-commentary-v12.de-en.clean.de.tok.low -valid_src data/wmt17/de-en/newstest2013.en.tok.low -valid_tgt data/wmt17/de-en/newstest2013.de.tok.low -save_data data/news-commentary.clean.tok.low -lower -seq_length 75
177 | python get_embed_for_dict.py data/news-commentary.clean.tok.low.src.dict -glove -d_hid 300
178 | python get_embed_for_dict.py data/news-commentary.clean.tok.low.src.dict -glove -chargram -d_hid 400
179 |
180 | # Rapid Fire
181 | python preprocess.py -train_src data/wmt17/de-en/rapid*.clean.en.tok.low -train_tgt data/wmt17/de-en/rapid*.clean.de.tok.low -valid_src data/wmt17/de-en/newstest2013.en.tok.low -valid_tgt data/wmt17/de-en/newstest2013.de.tok.low -save_data data/rapid.clean.tok.low -lower -seq_length 75
182 | python get_embed_for_dict.py data/rapid.clean.tok.low.src.dict -glove -d_hid 300
183 | python get_embed_for_dict.py data/rapid.clean.tok.low.src.dict -glove -chargram -d_hid 400
184 |
185 | # Europarl
186 | python preprocess.py -train_src data/wmt17/de-en/europarl*.clean.en.tok.low -train_tgt data/wmt17/de-en/europarl*.clean.de.tok.low -valid_src data/wmt17/de-en/newstest2013.en.tok.low -valid_tgt data/wmt17/de-en/newstest2013.de.tok.low -save_data data/europarl.clean.tok.low -lower -seq_length 75
187 | python get_embed_for_dict.py data/europarl.clean.tok.low.src.dict -glove -d_hid 300
188 | python get_embed_for_dict.py data/europarl.clean.tok.low.src.dict -glove -chargram -d_hid 400
189 |
190 | # Common Crawl
191 | python preprocess.py -train_src data/wmt17/de-en/commoncrawl*.clean.en.tok.low -train_tgt data/wmt17/de-en/commoncrawl*.clean.de.tok.low -valid_src data/wmt17/de-en/newstest2013.en.tok.low -valid_tgt data/wmt17/de-en/newstest2013.de.tok.low -save_data data/commoncrawl.clean.tok.low -lower -seq_length 75
192 | python get_embed_for_dict.py data/commoncrawl.clean.tok.low.src.dict -glove -d_hid 300
193 | python get_embed_for_dict.py data/commoncrawl.clean.tok.low.src.dict -glove -chargram -d_hid 400
194 |
195 | # WMT'17
196 | python preprocess.py -train_src data/wmt17/de-en/train.clean.en.tok.low -train_tgt data/wmt17/de-en/train.clean.de.tok.low -valid_src data/wmt17/de-en/newstest2013.en.tok.low -valid_tgt data/wmt17/de-en/newstest2013.de.tok.low -save_data data/wmt17.clean.tok.low -lower -seq_length 75
197 | python get_embed_for_dict.py data/wmt17.clean.tok.low.src.dict -glove -d_hid 300
198 | python get_embed_for_dict.py data/wmt17.clean.tok.low.src.dict -glove -chargram -d_hid 400
199 | ```
200 |
201 | ### 2) Train the model
202 |
203 | ```bash
204 | # Train fixed glove+char models
205 | for corpus in wmt17
206 | do
207 | python train.py -data data/${corpus}.clean.tok.low.train.pt -save_model snapshots/${corpus}.clean.tok.low.600h.400d.2l.brnn.2dp.fixed_glove_char.model -gpus 0 -brnn -word_vec_size 400 -pre_word_vecs_enc data/${corpus}.clean.tok.low.src.dict.glove.chargram -fix_embed > logs/${corpus}.clean.tok.low.600h.400d.2l.brnn.2dp.fixed_glove_char.log
208 | done
209 |
210 | # Train fixed glove models
211 | for corpus in wmt17
212 | do
213 | python train.py -data data/${corpus}.clean.tok.low.train.pt -save_model snapshots/${corpus}.clean.tok.low.600h.300d.2l.brnn.2dp.fixed_glove.model -gpus 0 -brnn -word_vec_size 300 -pre_word_vecs_enc data/${corpus}.clean.tok.low.src.dict.glove -fix_embed > logs/${corpus}.clean.tok.low.600h.300d.2l.brnn.2dp.fixed_glove.log
214 | done
215 |
216 | # Train fixed glove+char models
217 | for corpus in news-commentary rapid europarl commoncrawl
218 | do
219 | python train.py -data data/${corpus}.clean.tok.low.train.pt -save_model snapshots/${corpus}.clean.tok.low.600h.400d.2l.brnn.2dp.fixed_glove_char.model -gpus 0 -brnn -word_vec_size 400 -pre_word_vecs_enc data/${corpus}.clean.tok.low.src.dict.glove.chargram -fix_embed > logs/${corpus}.clean.tok.low.600h.400d.2l.brnn.2dp.fixed_glove_char.log
220 | done
221 |
222 | # Train fixed glove models
223 | for corpus in news-commentary rapid europarl commoncrawl
224 | do
225 | python train.py -data data/${corpus}.clean.tok.low.train.pt -save_model snapshots/${corpus}.clean.tok.low.600h.300d.2l.brnn.2dp.fixed_glove.model -gpus 0 -brnn -word_vec_size 300 -pre_word_vecs_enc data/${corpus}.clean.tok.low.src.dict.glove -fix_embed > logs/${corpus}.clean.tok.low.600h.300d.2l.brnn.2dp.fixed_glove.log
226 | done
227 | ```
228 |
229 | ### 3) Translate sentences.
230 |
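231 | Translation follows the same pattern as the Multi30k and IWSLT'16 sections above. For example, to translate the newstest2013 set downloaded earlier (substitute the snapshot produced by train.py for model_name; the -output filename is arbitrary):
232 |
233 | ```bash
234 | python translate.py -gpu 0 -model model_name -src data/wmt17/de-en/newstest2013.en.tok.low -tgt data/wmt17/de-en/newstest2013.de.tok.low -replace_unk -verbose -output wmt17.tok.low.newstest2013.pred
235 | ```
236 |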
237 | ### 4) Evaluate.
238 |
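239 | Scoring then uses the Moses BLEU script exactly as in the earlier sections (this assumes the prediction file written by the translate step above):
240 |
241 | ```bash
242 | perl multi-bleu.perl data/wmt17/de-en/newstest2013.de.tok.low < wmt17.tok.low.newstest2013.pred
243 | ```
244 |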
--------------------------------------------------------------------------------
/OpenNMT-py/cache_embeddings.py:
--------------------------------------------------------------------------------
1 | from os import path, environ, makedirs
2 | import logging
3 | import requests
4 | from argparse import ArgumentParser
5 |
6 | import torch
7 |
8 |
9 |
10 | def cache_glove(glove_prefix):
11 | stoi = {}
12 | itos = []
13 | vectors = []
14 | fname = glove_prefix+'.txt'
15 |
16 | with open(fname, 'rb') as f:
17 | for l in f:
18 | l = l.strip().split(b' ')
19 | word, vector = l[0], l[1:]
20 | try:
21 | word = word.decode()
22 |             except UnicodeDecodeError:
23 | print('non-UTF8 token', repr(word), 'ignored')
24 | continue
25 | stoi[word] = len(itos)
26 | itos.append(word)
27 | vectors.append([float(x) for x in vector])
28 | d = {'stoi': stoi, 'itos': itos, 'vectors': torch.FloatTensor(vectors)}
29 | torch.save(d, glove_prefix+'.pt')
30 |
31 | def cache_chargrams():
32 | stoi = {}
33 | itos = []
34 | vectors = []
35 | fname = 'kazuma1.emb'
36 |
37 | with open(fname, 'rb') as f:
38 | for l in f:
39 | l = l.strip().split(b' ')
40 | word = l[0]
41 | vector = [float(n) for n in l[1:]]
42 |
43 | try:
44 | word = word.decode()
45 |             except UnicodeDecodeError:
46 | print('non-UTF8 token', repr(word), 'ignored')
47 | continue
48 |
49 | stoi[word] = len(itos)
50 | itos.append(word)
51 | vectors.append(vector)
52 |
53 | d = {'stoi': stoi, 'itos': itos, 'vectors': torch.FloatTensor(vectors)}
54 | torch.save(d, 'kazuma.100d.pt')
55 |
56 |
57 | if __name__ == '__main__':
58 | parser = ArgumentParser()
59 | parser.add_argument('embeddings')
60 | parser.add_argument('-glove_prefix', default='glove.840B.300d', type=str)
61 | args = parser.parse_args()
62 |
63 | if args.embeddings == 'glove':
64 | cache_glove(args.glove_prefix)
65 | else:
66 | cache_chargrams()
67 |
68 |
--------------------------------------------------------------------------------
/OpenNMT-py/corenlp_tokenize.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import time
3 | sys.path.append(os.getcwd())
4 |
5 | from stanza.nlp.corenlp import CoreNLPClient
6 | import json
7 | from tqdm import tqdm
8 | from collections import defaultdict
9 | import re
10 | import gzip
11 | import argparse
12 |
13 | parser = argparse.ArgumentParser(description='preprocess.py')
14 |
15 | ##
16 | ## **Preprocess Options**
17 | ##
18 |
19 | parser.add_argument('-config', help="Read options from this file")
20 |
21 | parser.add_argument('-input-fn', required=True,
22 | help="Path to the input english data")
23 | parser.add_argument('-output-fn', required=True,
24 | help="Path to the output english data")
25 |
26 | parser.add_argument('-src_vocab_size', type=int, default=50000,
27 | help="Size of the source vocabulary")
28 | parser.add_argument('-tgt_vocab_size', type=int, default=50000,
29 | help="Size of the target vocabulary")
30 | parser.add_argument('-src_vocab',
31 | help="Path to an existing source vocabulary")
32 | parser.add_argument('-tgt_vocab',
33 | help="Path to an existing target vocabulary")
34 |
35 |
36 | parser.add_argument('-seq_length', type=int, default=50,
37 | help="Maximum sequence length")
38 | parser.add_argument('-shuffle', type=int, default=1,
39 | help="Shuffle data")
40 | parser.add_argument('-seed', type=int, default=3435,
41 | help="Random seed")
42 |
43 | parser.add_argument('-lower', action='store_true', help='lowercase data')
44 |
45 | parser.add_argument('-report_every', type=int, default=100000,
46 | help="Report status every this many sentences")
47 |
48 | opt = parser.parse_args()
49 |
50 | corenlp = CoreNLPClient(default_annotators=['tokenize', 'ssplit'])
51 |
52 | def annotate_sentence(corenlp, gloss):
53 | try:
54 | parse = corenlp.annotate(gloss)
55 |     except Exception:
56 | time.sleep(10)
57 | parse = corenlp.annotate(gloss)
58 | token_str = ' '.join([token['word'] for sentence in parse.json['sentence'] for token in sentence['token'] ])
59 | #return parse.json['sentence'][0]['token']
60 | return token_str
61 |
62 | with open(opt.input_fn, 'r') as f:
63 | with open(opt.output_fn, 'w') as f2:
64 | for sent in f.readlines():
65 | f2.write(annotate_sentence(corenlp, sent))
66 | f2.write('\n')
67 |
--------------------------------------------------------------------------------
/OpenNMT-py/get_embed_for_dict.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from argparse import ArgumentParser
3 |
4 | parser = ArgumentParser()
5 | parser.add_argument('path')
6 | parser.add_argument('-glove', action='store_true', dest='glove')
7 | parser.add_argument('-small-glove', action='store_true', dest='small_glove')
8 | parser.add_argument('-chargram', action='store_true', dest='chargram')
9 | parser.add_argument('-d_hid', default=400, type=int)
10 |
11 | args = parser.parse_args()
12 |
13 |
14 | def ngrams(sentence, n):
15 | return [sentence[i:i+n] for i in range(len(sentence)-n+1)]
16 |
17 |
18 | def charemb(w):
19 | chars = ['#BEGIN#'] + list(w) + ['#END#']
20 | match = {}
21 | for i in [2, 3, 4]:
22 | grams = ngrams(chars, i)
23 | for g in grams:
24 | g = '{}gram-{}'.format(i, ''.join(g))
25 | e = None
26 | if g in kazuma['stoi']:
27 | e = kazuma['vectors'][kazuma['stoi'][g]]
28 | if e is not None:
29 | match[g] = e
30 | if match:
31 | emb = sum(match.values()) / len(match)
32 | else:
33 | emb = torch.FloatTensor(100).uniform_(-0.1, 0.1)
34 | return emb
35 |
36 |
37 | with open(args.path, 'rb') as f:
38 | vocab = [l.strip().split(b' ')[0] for l in f]
39 |
40 | if args.glove:
41 | glove = torch.load('glove.840B.300d.pt')
42 | if args.chargram:
43 | kazuma = torch.load('kazuma.100d.pt')
44 |
45 | vectors = []
46 | for word in vocab:
47 | vector = torch.FloatTensor(args.d_hid).uniform_(-0.1, 0.1)
48 | try:
49 | word = word.decode()
50 | glove_dim = args.d_hid - 100 if args.chargram else args.d_hid
51 | if args.glove and word in glove['stoi']:
52 | vector[:glove_dim] = glove['vectors'][glove['stoi'][word]]
53 | if args.chargram:
54 | vector[glove_dim:] = charemb(word)
55 |     except UnicodeDecodeError:
56 |         # keep the random initialization for tokens that fail to decode
57 |         print('non-UTF-8 token', repr(word), 'ignored')
58 | vectors.append(vector)
59 |
60 | ext = '.glove' if args.glove else ''
61 | ext += '.chargram' if args.chargram else ''
62 | torch.save(torch.stack(vectors), args.path + ext)
63 |
--------------------------------------------------------------------------------
/OpenNMT-py/iwslt_xml2txt.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import subprocess
4 | import xml.etree.ElementTree as ET
5 | from argparse import ArgumentParser
6 |
7 | parser = ArgumentParser(description='xml2txt')
8 | parser.add_argument('path')
9 | parser.add_argument('-t', '--tags', nargs='+', default=['seg'])
10 | parser.add_argument('-th', '--threads', default=8, type=int)
11 | parser.add_argument('-a', '--aggressive', action='store_true')
12 | parser.add_argument('-corenlp', action='store_true')
13 | args = parser.parse_args()
14 |
15 | def tokenize(f_txt):
16 | lang = os.path.splitext(f_txt)[1][1:]
17 | f_tok = f_txt
18 | if args.aggressive:
19 | f_tok += '.atok'
20 | elif args.corenlp and lang == 'en':
21 | f_tok += '.corenlp'
22 | else:
23 | f_tok += '.tok'
24 | with open(f_tok, 'w') as fout, open(f_txt) as fin:
25 | if args.aggressive:
26 | pipe = subprocess.call(['perl', 'tokenizer.perl', '-a', '-q', '-threads', str(args.threads), '-no-escape', '-l', lang], stdin=fin, stdout=fout)
27 | elif args.corenlp and lang=='en':
28 | pipe = subprocess.call(['python', 'corenlp_tokenize.py', '-input-fn', f_txt, '-output-fn', f_tok])
29 | else:
30 | pipe = subprocess.call(['perl', 'tokenizer.perl', '-q', '-threads', str(args.threads), '-no-escape', '-l', lang], stdin=fin, stdout=fout)
31 |
32 | for f_xml in glob.iglob(os.path.join(args.path, '*.xml')):
33 | print(f_xml)
34 | f_txt = os.path.splitext(f_xml)[0]
35 | with open(f_txt, 'w') as fd_txt:
36 | root = ET.parse(f_xml).getroot()[0]
37 | for doc in root.findall('doc'):
38 | for tag in args.tags:
39 | for e in doc.findall(tag):
40 | fd_txt.write(e.text.strip() + '\n')
41 | tokenize(f_txt)
42 |
43 | xml_tags = ['<url', '<keywords', '<talkid', '<description', '<reviewer', '<translator', '<title', '<speaker']
--------------------------------------------------------------------------------
/OpenNMT-py/onmt/Beam.py:
--------------------------------------------------------------------------------
62 |         if len(self.prevKs) > 0:
63 | beamLk = wordLk + self.scores.unsqueeze(1).expand_as(wordLk)
64 | else:
65 | beamLk = wordLk[0]
66 |
67 | flatBeamLk = beamLk.view(-1)
68 |
69 | bestScores, bestScoresId = flatBeamLk.topk(self.size, 0, True, True)
70 | self.scores = bestScores
71 |
72 | # bestScoresId is flattened beam x word array, so calculate which
73 | # word and beam each score came from
74 | prevK = bestScoresId / numWords
75 | self.prevKs.append(prevK)
76 | self.nextYs.append(bestScoresId - prevK * numWords)
77 | self.attn.append(attnOut.index_select(0, prevK))
78 |
79 | # End condition is when top-of-beam is EOS.
80 | if self.nextYs[-1][0] == onmt.Constants.EOS:
81 | self.done = True
82 |
83 | return self.done
84 |
85 | def sortBest(self):
86 | return torch.sort(self.scores, 0, True)
87 |
88 | # Get the score of the best in the beam.
89 | def getBest(self):
90 | scores, ids = self.sortBest()
91 |         return scores[0], ids[0]  # index 0 holds the best score after the descending sort
92 |
93 | # Walk back to construct the full hypothesis.
94 | #
95 | # Parameters.
96 | #
97 | # * `k` - the position in the beam to construct.
98 | #
99 | # Returns.
100 | #
101 | # 1. The hypothesis
102 | # 2. The attention at each time step.
103 | def getHyp(self, k):
104 | hyp, attn = [], []
105 | # print(len(self.prevKs), len(self.nextYs), len(self.attn))
106 | for j in range(len(self.prevKs) - 1, -1, -1):
107 | hyp.append(self.nextYs[j+1][k])
108 | attn.append(self.attn[j][k])
109 | k = self.prevKs[j][k]
110 |
111 | return hyp[::-1], torch.stack(attn[::-1])
112 |
--------------------------------------------------------------------------------
/OpenNMT-py/onmt/Constants.py:
--------------------------------------------------------------------------------
1 |
2 | PAD = 0
3 | UNK = 1
4 | BOS = 2
5 | EOS = 3
6 |
7 | PAD_WORD = '<blank>'
8 | UNK_WORD = '<unk>'
9 | BOS_WORD = '<s>'
10 | EOS_WORD = '</s>'
11 |
--------------------------------------------------------------------------------
/OpenNMT-py/onmt/Dataset.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 |
3 | import math
4 | import random
5 |
6 | import torch
7 | from torch.autograd import Variable
8 |
9 | import onmt
10 |
11 |
12 | class Dataset(object):
13 |
14 | def __init__(self, srcData, tgtData, batchSize, cuda, volatile=False):
15 | self.src = srcData
16 | if tgtData:
17 | self.tgt = tgtData
18 | assert(len(self.src) == len(self.tgt))
19 | else:
20 | self.tgt = None
21 | self.cuda = cuda
22 |
23 | self.batchSize = batchSize
24 | self.numBatches = math.ceil(len(self.src)/batchSize)
25 | self.volatile = volatile
26 |
27 | def _batchify(self, data, align_right=False, include_lengths=False):
28 | lengths = [x.size(0) for x in data]
29 | max_length = max(lengths)
30 | out = data[0].new(len(data), max_length).fill_(onmt.Constants.PAD)
31 | for i in range(len(data)):
32 | data_length = data[i].size(0)
33 | offset = max_length - data_length if align_right else 0
34 | out[i].narrow(0, offset, data_length).copy_(data[i])
35 |
36 | if include_lengths:
37 | return out, lengths
38 | else:
39 | return out
40 |
41 | def __getitem__(self, index):
42 | assert index < self.numBatches, "%d > %d" % (index, self.numBatches)
43 | srcBatch, lengths = self._batchify(
44 | self.src[index*self.batchSize:(index+1)*self.batchSize],
45 | align_right=False, include_lengths=True)
46 |
47 | if self.tgt:
48 | tgtBatch = self._batchify(
49 | self.tgt[index*self.batchSize:(index+1)*self.batchSize])
50 | else:
51 | tgtBatch = None
52 |
53 | # within batch sorting by decreasing length for variable length rnns
54 | indices = range(len(srcBatch))
55 | batch = zip(indices, srcBatch) if tgtBatch is None else zip(indices, srcBatch, tgtBatch)
56 | batch, lengths = zip(*sorted(zip(batch, lengths), key=lambda x: -x[1]))
57 | if tgtBatch is None:
58 | indices, srcBatch = zip(*batch)
59 | else:
60 | indices, srcBatch, tgtBatch = zip(*batch)
61 |
62 | def wrap(b):
63 | if b is None:
64 | return b
65 | b = torch.stack(b, 0).t().contiguous()
66 | if self.cuda:
67 | b = b.cuda()
68 | b = Variable(b, volatile=self.volatile)
69 | return b
70 |
71 | return (wrap(srcBatch), lengths), wrap(tgtBatch), indices
72 |
73 | def __len__(self):
74 | return self.numBatches
75 |
76 |
77 | def shuffle(self):
78 | data = list(zip(self.src, self.tgt))
79 | self.src, self.tgt = zip(*[data[i] for i in torch.randperm(len(data))])
80 |
--------------------------------------------------------------------------------
/OpenNMT-py/onmt/Dict.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class Dict(object):
5 | def __init__(self, data=None, lower=False):
6 | self.idxToLabel = {}
7 | self.labelToIdx = {}
8 | self.frequencies = {}
9 | self.lower = lower
10 |
11 | # Special entries will not be pruned.
12 | self.special = []
13 |
14 | if data is not None:
15 | if type(data) == str:
16 | self.loadFile(data)
17 | else:
18 | self.addSpecials(data)
19 |
20 | def size(self):
21 | return len(self.idxToLabel)
22 |
23 | # Load entries from a file.
24 | def loadFile(self, filename):
25 | for line in open(filename):
26 | fields = line.split()
27 | label = fields[0]
28 | idx = int(fields[1])
29 | self.add(label, idx)
30 |
31 | # Write entries to a file.
32 | def writeFile(self, filename):
33 | with open(filename, 'w') as file:
34 | for i in range(self.size()):
35 | label = self.idxToLabel[i]
36 | file.write('%s %d\n' % (label, i))
37 |
39 |
40 | def lookup(self, key, default=None):
41 | key = key.lower() if self.lower else key
42 | try:
43 | return self.labelToIdx[key]
44 | except KeyError:
45 | return default
46 |
47 | def getLabel(self, idx, default=None):
48 | try:
49 | return self.idxToLabel[idx]
50 | except KeyError:
51 | return default
52 |
53 | # Mark this `label` and `idx` as special (i.e. will not be pruned).
54 | def addSpecial(self, label, idx=None):
55 | idx = self.add(label, idx)
56 | self.special += [idx]
57 |
58 | # Mark all labels in `labels` as specials (i.e. will not be pruned).
59 | def addSpecials(self, labels):
60 | for label in labels:
61 | self.addSpecial(label)
62 |
63 | # Add `label` in the dictionary. Use `idx` as its index if given.
64 | def add(self, label, idx=None):
65 | label = label.lower() if self.lower else label
66 | if idx is not None:
67 | self.idxToLabel[idx] = label
68 | self.labelToIdx[label] = idx
69 | else:
70 | if label in self.labelToIdx:
71 | idx = self.labelToIdx[label]
72 | else:
73 | idx = len(self.idxToLabel)
74 | self.idxToLabel[idx] = label
75 | self.labelToIdx[label] = idx
76 |
77 | if idx not in self.frequencies:
78 | self.frequencies[idx] = 1
79 | else:
80 | self.frequencies[idx] += 1
81 |
82 | return idx
83 |
84 | # Return a new dictionary with the `size` most frequent entries.
85 | def prune(self, size):
86 | if size >= self.size():
87 | return self
88 |
89 | # Only keep the `size` most frequent entries.
90 | freq = torch.Tensor(
91 | [self.frequencies[i] for i in range(len(self.frequencies))])
92 | _, idx = torch.sort(freq, 0, True)
93 |
94 | newDict = Dict()
95 | newDict.lower = self.lower
96 |
97 | # Add special entries in all cases.
98 | for i in self.special:
99 | newDict.addSpecial(self.idxToLabel[i])
100 |
101 | for i in idx[:size]:
102 | newDict.add(self.idxToLabel[i])
103 |
104 | return newDict
105 |
106 | # Convert `labels` to indices. Use `unkWord` if not found.
107 |     # Optionally insert `bosWord` at the beginning and `eosWord` at the end.
108 | def convertToIdx(self, labels, unkWord, bosWord=None, eosWord=None):
109 | vec = []
110 |
111 | if bosWord is not None:
112 | vec += [self.lookup(bosWord)]
113 |
114 | unk = self.lookup(unkWord)
115 | vec += [self.lookup(label, default=unk) for label in labels]
116 |
117 | if eosWord is not None:
118 | vec += [self.lookup(eosWord)]
119 |
120 | return torch.LongTensor(vec)
121 |
122 | # Convert `idx` to labels. If index `stop` is reached, convert it and return.
123 | def convertToLabels(self, idx, stop):
124 | labels = []
125 |
126 | for i in idx:
127 | labels += [self.getLabel(i)]
128 | if i == stop:
129 | break
130 |
131 | return labels
132 |
--------------------------------------------------------------------------------
/OpenNMT-py/onmt/Models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 | import onmt.modules
5 | from torch.nn.utils.rnn import pad_packed_sequence as unpack
6 | from torch.nn.utils.rnn import pack_padded_sequence as pack
7 |
8 | class Encoder(nn.Module):
9 |
10 | def __init__(self, opt, dicts):
11 | super(Encoder, self).__init__()
12 | self.detach_embed = opt.detach_embed if hasattr(opt, 'detach_embed') else 0
13 | self.fix_embed = opt.fix_embed
14 | self.count = 0
15 | self.layers = opt.layers
16 | self.dropout = nn.Dropout(opt.dropout)
17 | self.num_directions = 2 if opt.brnn else 1
18 | assert opt.rnn_size % self.num_directions == 0
19 | self.hidden_size = opt.rnn_size // self.num_directions
20 | input_size = opt.word_vec_size
21 |
22 | self.word_lut = nn.Embedding(dicts.size(),
23 | opt.word_vec_size,
24 | padding_idx=onmt.Constants.PAD)
25 | self.rnn = nn.LSTM(input_size, self.hidden_size,
26 | num_layers=opt.layers,
27 | dropout=opt.dropout,
28 | bidirectional=opt.brnn)
29 |
30 | def load_pretrained_vectors(self, opt):
31 | if opt.pre_word_vecs_enc is not None:
32 | pretrained = torch.load(opt.pre_word_vecs_enc)
33 | self.word_lut.weight.data.copy_(pretrained)
34 |
35 | def forward(self, input, hidden=None):
36 | if isinstance(input, tuple):
37 | emb = self.word_lut(input[0])
38 | else:
39 | emb = self.word_lut(input)
40 | if self.fix_embed or self.count < self.detach_embed:
41 | emb = emb.detach()
42 | emb = self.dropout(emb)
43 | if isinstance(input, tuple):
44 | emb = pack(emb, input[1])
45 | outputs, hidden_t = self.rnn(emb, hidden)
46 | if isinstance(input, tuple):
47 | outputs = self.dropout(unpack(outputs)[0])
48 | self.count += 1
49 | return hidden_t, outputs
50 |
51 |
52 | class StackedLSTM(nn.Module):
53 | def __init__(self, num_layers, input_size, rnn_size, dropout):
54 | super(StackedLSTM, self).__init__()
55 | self.dropout = nn.Dropout(dropout)
56 | self.num_layers = num_layers
57 | self.layers = nn.ModuleList()
58 |
59 | for i in range(num_layers):
60 | self.layers.append(nn.LSTMCell(input_size, rnn_size))
61 | input_size = rnn_size
62 |
63 | def forward(self, input, hidden):
64 | h_0, c_0 = hidden
65 | h_1, c_1 = [], []
66 | for i, layer in enumerate(self.layers):
67 | input = self.dropout(input)
68 | h_1_i, c_1_i = layer(input, (h_0[i], c_0[i]))
69 | input = h_1_i
70 | h_1 += [h_1_i]
71 | c_1 += [c_1_i]
72 |
73 | h_1 = torch.stack(h_1)
74 | c_1 = torch.stack(c_1)
75 |
76 | return input, (h_1, c_1)
77 |
78 |
79 | class Decoder(nn.Module):
80 |
81 | def __init__(self, opt, dicts):
82 | super(Decoder, self).__init__()
83 | self.layers = opt.layers
84 | self.input_feed = opt.input_feed
85 | input_size = opt.word_vec_size
86 | self.dropout = nn.Dropout(opt.dropout)
87 | if self.input_feed:
88 | input_size += opt.rnn_size
89 |
90 | self.word_lut = nn.Embedding(dicts.size(),
91 | opt.word_vec_size,
92 | padding_idx=onmt.Constants.PAD)
93 | self.rnn = StackedLSTM(opt.layers, input_size, opt.rnn_size, opt.dropout)
94 | self.attn = onmt.modules.GlobalAttention(opt.rnn_size, opt.dot)
95 | self.dropout = nn.Dropout(opt.dropout)
96 |
97 | self.hidden_size = opt.rnn_size
98 |
99 | def load_pretrained_vectors(self, opt):
100 | if opt.pre_word_vecs_dec is not None:
101 | pretrained = torch.load(opt.pre_word_vecs_dec)
102 | self.word_lut.weight.data.copy_(pretrained)
103 |
104 | def forward(self, input, hidden, context, init_output):
105 | emb = self.word_lut(input)
106 |
107 | # n.b. you can increase performance if you compute W_ih * x for all
108 | # iterations in parallel, but that's only possible if
109 | # self.input_feed=False
110 | outputs = []
111 | output = init_output
112 | for emb_t in emb.split(1):
113 | emb_t = self.dropout(emb_t)
114 | output = self.dropout(output)
115 | emb_t = emb_t.squeeze(0)
116 | if self.input_feed:
117 | emb_t = torch.cat([emb_t, output], 1)
118 |
119 | output, hidden = self.rnn(emb_t, hidden)
120 | output, attn = self.attn(output, context.t())
121 | output = self.dropout(output)
122 | outputs += [output]
123 |
124 | outputs = torch.stack(outputs)
125 | return outputs, hidden, attn
126 |
127 |
128 | class NMTModel(nn.Module):
129 |
130 | def __init__(self, encoder, decoder):
131 | super(NMTModel, self).__init__()
132 | self.encoder = encoder
133 | self.decoder = decoder
134 |
135 | def make_init_decoder_output(self, context):
136 | batch_size = context.size(1)
137 | h_size = (batch_size, self.decoder.hidden_size)
138 | return Variable(context.data.new(*h_size).zero_(), requires_grad=False)
139 |
140 | def _fix_enc_hidden(self, h):
141 | # the encoder hidden is (layers*directions) x batch x dim
142 | # we need to convert it to layers x batch x (directions*dim)
143 | if self.encoder.num_directions == 2:
144 | return h.view(h.size(0) // 2, 2, h.size(1), h.size(2)) \
145 | .transpose(1, 2).contiguous() \
146 | .view(h.size(0) // 2, h.size(1), h.size(2) * 2)
147 | else:
148 | return h
149 |
150 | def forward(self, input):
151 | src = input[0]
152 | tgt = input[1][:-1] # exclude last target from inputs
153 | enc_hidden, context = self.encoder(src)
154 | init_output = self.make_init_decoder_output(context)
155 |
156 | enc_hidden = (self._fix_enc_hidden(enc_hidden[0]),
157 | self._fix_enc_hidden(enc_hidden[1]))
158 |
159 | out, dec_hidden, _attn = self.decoder(tgt, enc_hidden, context, init_output)
160 |
161 | return out
162 |
--------------------------------------------------------------------------------
/OpenNMT-py/onmt/Optim.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch.optim as optim
3 | import torch.nn as nn
4 | from torch.nn.utils import clip_grad_norm
5 |
6 | class Optim(object):
7 |
8 | def set_parameters(self, params):
9 | self.params = list(params) # careful: params may be a generator
10 | if self.method == 'sgd':
11 | self.optimizer = optim.SGD(self.params, lr=self.lr)
12 | elif self.method == 'adagrad':
13 | self.optimizer = optim.Adagrad(self.params, lr=self.lr)
14 | elif self.method == 'adadelta':
15 | self.optimizer = optim.Adadelta(self.params, lr=self.lr)
16 | elif self.method == 'adam':
17 | self.optimizer = optim.Adam(self.params, lr=self.lr)
18 | else:
19 | raise RuntimeError("Invalid optim method: " + self.method)
20 |
21 | def __init__(self, method, lr, max_grad_norm, lr_decay=1, start_decay_at=None):
22 | self.last_ppl = None
23 | self.lr = lr
24 | self.max_grad_norm = max_grad_norm
25 | self.method = method
26 | self.lr_decay = lr_decay
27 | self.start_decay_at = start_decay_at
28 |
29 | def step(self):
30 | # Compute gradients norm.
31 | if self.max_grad_norm:
32 | clip_grad_norm(self.params, self.max_grad_norm)
33 | self.optimizer.step()
34 |
35 | # decay learning rate if val perf does not improve or we hit the start_decay_at limit
36 | def updateLearningRate(self, ppl, epoch):
37 | start_decay = False
38 | if self.start_decay_at is not None and epoch >= self.start_decay_at:
39 | start_decay = True
40 | if self.last_ppl is not None and ppl > self.last_ppl:
41 | start_decay = True
42 |
43 | if start_decay:
44 | self.lr = self.lr * self.lr_decay
45 | print("Decaying learning rate to %g" % self.lr)
46 |
47 | self.last_ppl = ppl
48 | self.optimizer.param_groups[0]['lr'] = self.lr
49 |
--------------------------------------------------------------------------------
/OpenNMT-py/onmt/Translator.py:
--------------------------------------------------------------------------------
1 | import onmt
2 | import torch.nn as nn
3 | import torch
4 | from torch.autograd import Variable
5 |
6 |
7 | class Translator(object):
8 | def __init__(self, opt):
9 | self.opt = opt
10 | self.tt = torch.cuda if opt.cuda else torch
11 |
12 | checkpoint = torch.load(opt.model)
13 |
14 | model_opt = checkpoint['opt']
15 | self.src_dict = checkpoint['dicts']['src']
16 | self.tgt_dict = checkpoint['dicts']['tgt']
17 |
18 | encoder = onmt.Models.Encoder(model_opt, self.src_dict)
19 | decoder = onmt.Models.Decoder(model_opt, self.tgt_dict)
20 | model = onmt.Models.NMTModel(encoder, decoder)
21 |
22 | generator = nn.Sequential(
23 | nn.Linear(model_opt.rnn_size, self.tgt_dict.size()),
24 | nn.LogSoftmax())
25 |
26 | model.load_state_dict(checkpoint['model'])
27 | generator.load_state_dict(checkpoint['generator'])
28 |
29 | if opt.cuda:
30 | model.cuda()
31 | generator.cuda()
32 | else:
33 | model.cpu()
34 | generator.cpu()
35 |
36 | model.generator = generator
37 |
38 | self.model = model
39 | self.model.eval()
40 |
41 |
42 | def buildData(self, srcBatch, goldBatch):
43 | srcData = [self.src_dict.convertToIdx(b,
44 | onmt.Constants.UNK_WORD) for b in srcBatch]
45 | tgtData = None
46 | if goldBatch:
47 | tgtData = [self.tgt_dict.convertToIdx(b,
48 | onmt.Constants.UNK_WORD,
49 | onmt.Constants.BOS_WORD,
50 | onmt.Constants.EOS_WORD) for b in goldBatch]
51 |
52 | return onmt.Dataset(srcData, tgtData,
53 | self.opt.batch_size, self.opt.cuda, volatile=True)
54 |
55 | def buildTargetTokens(self, pred, src, attn):
56 | tokens = self.tgt_dict.convertToLabels(pred, onmt.Constants.EOS)
57 | tokens = tokens[:-1] # EOS
58 | if self.opt.replace_unk:
59 | for i in range(len(tokens)):
60 | if tokens[i] == onmt.Constants.UNK_WORD:
61 | _, maxIndex = attn[i].max(0)
62 | tokens[i] = src[maxIndex[0]]
63 | return tokens
64 |
65 | def translateBatch(self, srcBatch, tgtBatch):
66 | batchSize = srcBatch[0].size(1)
67 | beamSize = self.opt.beam_size
68 |
69 | # (1) run the encoder on the src
70 | encStates, context = self.model.encoder(srcBatch)
71 | srcBatch = srcBatch[0] # drop the lengths needed for encoder
72 |
73 | rnnSize = context.size(2)
74 | encStates = (self.model._fix_enc_hidden(encStates[0]),
75 | self.model._fix_enc_hidden(encStates[1]))
76 |
77 | # This mask is applied to the attention model inside the decoder
78 | # so that the attention ignores source padding
79 | padMask = srcBatch.data.eq(onmt.Constants.PAD).t()
80 | def applyContextMask(m):
81 | if isinstance(m, onmt.modules.GlobalAttention):
82 | m.applyMask(padMask)
83 |
84 | # (2) if a target is specified, compute the 'goldScore'
85 | # (i.e. log likelihood) of the target under the model
86 | goldScores = context.data.new(batchSize).zero_()
87 | if tgtBatch is not None:
88 | decStates = encStates
89 | decOut = self.model.make_init_decoder_output(context)
90 | self.model.decoder.apply(applyContextMask)
91 | initOutput = self.model.make_init_decoder_output(context)
92 |
93 | decOut, decStates, attn = self.model.decoder(
94 | tgtBatch[:-1], decStates, context, initOutput)
95 | for dec_t, tgt_t in zip(decOut, tgtBatch[1:].data):
96 | gen_t = self.model.generator.forward(dec_t)
97 | tgt_t = tgt_t.unsqueeze(1)
98 | scores = gen_t.data.gather(1, tgt_t)
99 | scores.masked_fill_(tgt_t.eq(onmt.Constants.PAD), 0)
100 | goldScores += scores
101 |
102 | # (3) run the decoder to generate sentences, using beam search
103 |
104 | # Expand tensors for each beam.
105 | context = Variable(context.data.repeat(1, beamSize, 1))
106 | decStates = (Variable(encStates[0].data.repeat(1, beamSize, 1)),
107 | Variable(encStates[1].data.repeat(1, beamSize, 1)))
108 |
109 | beam = [onmt.Beam(beamSize, self.opt.cuda) for k in range(batchSize)]
110 |
111 | decOut = self.model.make_init_decoder_output(context)
112 |
113 | padMask = srcBatch.data.eq(onmt.Constants.PAD).t().unsqueeze(0).repeat(beamSize, 1, 1)
114 |
115 | batchIdx = list(range(batchSize))
116 | remainingSents = batchSize
117 | for i in range(self.opt.max_sent_length):
118 |
119 | self.model.decoder.apply(applyContextMask)
120 |
121 | # Prepare decoder input.
122 | input = torch.stack([b.getCurrentState() for b in beam
123 | if not b.done]).t().contiguous().view(1, -1)
124 |
125 | decOut, decStates, attn = self.model.decoder(
126 | Variable(input, volatile=True), decStates, context, decOut)
127 | # decOut: 1 x (beam*batch) x numWords
128 | decOut = decOut.squeeze(0)
129 | out = self.model.generator.forward(decOut)
130 |
131 | # batch x beam x numWords
132 | wordLk = out.view(beamSize, remainingSents, -1).transpose(0, 1).contiguous()
133 | attn = attn.view(beamSize, remainingSents, -1).transpose(0, 1).contiguous()
134 |
135 | active = []
136 | for b in range(batchSize):
137 | if beam[b].done:
138 | continue
139 |
140 | idx = batchIdx[b]
141 | if not beam[b].advance(wordLk.data[idx], attn.data[idx]):
142 | active += [b]
143 |
144 | for decState in decStates: # iterate over h, c
145 | # layers x beam*sent x dim
146 | sentStates = decState.view(
147 | -1, beamSize, remainingSents, decState.size(2))[:, :, idx]
148 | sentStates.data.copy_(
149 | sentStates.data.index_select(1, beam[b].getCurrentOrigin()))
150 |
151 | if not active:
152 | break
153 |
154 | # in this section, the sentences that are still active are
155 | # compacted so that the decoder is not run on completed sentences
156 | activeIdx = self.tt.LongTensor([batchIdx[k] for k in active])
157 | batchIdx = {beam: idx for idx, beam in enumerate(active)}
158 |
159 | def updateActive(t):
160 | # select only the remaining active sentences
161 | view = t.data.view(-1, remainingSents, rnnSize)
162 | newSize = list(t.size())
163 | newSize[-2] = newSize[-2] * len(activeIdx) // remainingSents
164 | return Variable(view.index_select(1, activeIdx) \
165 | .view(*newSize), volatile=True)
166 |
167 | decStates = (updateActive(decStates[0]), updateActive(decStates[1]))
168 | decOut = updateActive(decOut)
169 | context = updateActive(context)
170 | padMask = padMask.index_select(1, activeIdx)
171 |
172 | remainingSents = len(active)
173 |
174 | # (4) package everything up
175 |
176 | allHyp, allScores, allAttn = [], [], []
177 | n_best = self.opt.n_best
178 |
179 | for b in range(batchSize):
180 | scores, ks = beam[b].sortBest()
181 |
182 | allScores += [scores[:n_best]]
183 | valid_attn = srcBatch.data[:, b].ne(onmt.Constants.PAD).nonzero().squeeze(1)
184 | hyps, attn = zip(*[beam[b].getHyp(k) for k in ks[:n_best]])
185 | attn = [a.index_select(1, valid_attn) for a in attn]
186 | allHyp += [hyps]
187 | allAttn += [attn]
188 |
189 | return allHyp, allScores, allAttn, goldScores
190 |
191 | def translate(self, srcBatch, goldBatch):
192 | # (1) convert words to indexes
193 | dataset = self.buildData(srcBatch, goldBatch)
194 | src, tgt, indices = dataset[0]
195 |
196 | # (2) translate
197 | pred, predScore, attn, goldScore = self.translateBatch(src, tgt)
198 | pred, predScore, attn, goldScore = list(zip(*sorted(zip(pred, predScore, attn, goldScore, indices), key=lambda x: x[-1])))[:-1]
199 |
200 | # (3) convert indexes to words
201 | predBatch = []
202 | for b in range(src[0].size(1)):
203 | predBatch.append(
204 | [self.buildTargetTokens(pred[b][n], srcBatch[b], attn[b][n])
205 | for n in range(self.opt.n_best)]
206 | )
207 |
208 | return predBatch, predScore, goldScore
209 |
--------------------------------------------------------------------------------
/OpenNMT-py/onmt/__init__.py:
--------------------------------------------------------------------------------
1 | import onmt.Constants
2 | import onmt.Models
3 | from onmt.Translator import Translator
4 | from onmt.Dataset import Dataset
5 | from onmt.Optim import Optim
6 | from onmt.Dict import Dict
7 | from onmt.Beam import Beam
8 |
--------------------------------------------------------------------------------
/OpenNMT-py/onmt/modules/GlobalAttention.py:
--------------------------------------------------------------------------------
1 | """
2 | Global attention takes a matrix and a query vector. It
3 | then computes a parameterized convex combination of the matrix
4 | based on the input query.
5 |
6 |
7 | H_1 H_2 H_3 ... H_n
8 | q q q q
9 | | | | |
10 | \ | | /
11 | .....
12 | \ | /
13 | a
14 |
15 | Constructs a unit mapping.
16 | $$(H_1 + H_n, q) => (a)$$
17 | Where H is of `batch x n x dim` and q is of `batch x dim`.
18 |
19 | The full def is $$\tanh(W_2 [(softmax((W_1 q + b_1) H) H), q] + b_2)$$.
20 |
21 | """
22 |
23 | import torch
24 | import torch.nn as nn
25 | import math
26 |
27 | class GlobalAttention(nn.Module):
28 | def __init__(self, dim, dot=False):
29 | super(GlobalAttention, self).__init__()
30 | self.linear_in = nn.Linear(dim, dim, bias=False)
31 | self.sm = nn.Softmax()
32 | self.linear_out = nn.Linear(dim*2, dim, bias=False)
33 | self.tanh = nn.Tanh()
34 | self.mask = None
35 | self.dot = dot
36 |
37 | def applyMask(self, mask):
38 | self.mask = mask
39 |
40 | def forward(self, input, context):
41 | """
42 | input: batch x dim
43 | context: batch x sourceL x dim
44 | """
45 | if not self.dot:
46 | targetT = self.linear_in(input).unsqueeze(2) # batch x dim x 1
47 | else:
48 | targetT = input.unsqueeze(2)
49 |
50 | # Get attention
51 | attn = torch.bmm(context, targetT).squeeze(2) # batch x sourceL
52 | if self.mask is not None:
53 | attn.data.masked_fill_(self.mask, -float('inf'))
54 | attn = self.sm(attn)
55 | attn3 = attn.view(attn.size(0), 1, attn.size(1)) # batch x 1 x sourceL
56 |
57 | weightedContext = torch.bmm(attn3, context).squeeze(1) # batch x dim
58 | contextCombined = torch.cat((weightedContext, input), 1)
59 |
60 | contextOutput = self.tanh(self.linear_out(contextCombined))
61 |
62 | return contextOutput, attn
63 |
--------------------------------------------------------------------------------
/OpenNMT-py/onmt/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from onmt.modules.GlobalAttention import GlobalAttention
2 |
--------------------------------------------------------------------------------
/OpenNMT-py/preprocess.py:
--------------------------------------------------------------------------------
1 | import onmt
2 |
3 | import argparse
4 | import torch
5 |
6 | parser = argparse.ArgumentParser(description='preprocess.py')
7 |
8 | ##
9 | ## **Preprocess Options**
10 | ##
11 |
12 | parser.add_argument('-config', help="Read options from this file")
13 |
14 | parser.add_argument('-train_src', required=True,
15 | help="Path to the training source data")
16 | parser.add_argument('-train_tgt', required=True,
17 | help="Path to the training target data")
18 | parser.add_argument('-valid_src', required=True,
19 | help="Path to the validation source data")
20 | parser.add_argument('-valid_tgt', required=True,
21 | help="Path to the validation target data")
22 |
23 | parser.add_argument('-save_data', required=True,
24 | help="Output file for the prepared data")
25 |
26 | parser.add_argument('-src_vocab_size', type=int, default=50000,
27 | help="Size of the source vocabulary")
28 | parser.add_argument('-tgt_vocab_size', type=int, default=50000,
29 | help="Size of the target vocabulary")
30 | parser.add_argument('-src_vocab',
31 | help="Path to an existing source vocabulary")
32 | parser.add_argument('-tgt_vocab',
33 | help="Path to an existing target vocabulary")
34 |
35 |
36 | parser.add_argument('-seq_length', type=int, default=50,
37 | help="Maximum sequence length")
38 | parser.add_argument('-shuffle', type=int, default=1,
39 | help="Shuffle data")
40 | parser.add_argument('-seed', type=int, default=3435,
41 | help="Random seed")
42 |
43 | parser.add_argument('-lower', action='store_true', help='lowercase data')
44 |
45 | parser.add_argument('-report_every', type=int, default=100000,
46 | help="Report status every this many sentences")
47 |
48 | opt = parser.parse_args()
49 |
50 | torch.manual_seed(opt.seed)
51 |
52 | def makeVocabulary(filename, size):
53 | vocab = onmt.Dict([onmt.Constants.PAD_WORD, onmt.Constants.UNK_WORD,
54 | onmt.Constants.BOS_WORD, onmt.Constants.EOS_WORD], lower=opt.lower)
55 |
56 | with open(filename) as f:
57 | for sent in f.readlines():
58 | for word in sent.split():
59 | vocab.add(word)
60 |
61 | originalSize = vocab.size()
62 | vocab = vocab.prune(size)
63 | print('Created dictionary of size %d (pruned from %d)' %
64 | (vocab.size(), originalSize))
65 |
66 | return vocab
67 |
68 |
69 | def initVocabulary(name, dataFile, vocabFile, vocabSize):
70 |
71 | vocab = None
72 | if vocabFile is not None:
73 | # If given, load existing word dictionary.
74 | print('Reading ' + name + ' vocabulary from \'' + vocabFile + '\'...')
75 | vocab = onmt.Dict()
76 | vocab.loadFile(vocabFile)
77 |         print('Loaded ' + str(vocab.size()) + ' ' + name + ' words')
78 |
79 | if vocab is None:
80 | # If a dictionary is still missing, generate it.
81 | print('Building ' + name + ' vocabulary...')
82 | genWordVocab = makeVocabulary(dataFile, vocabSize)
83 |
84 | vocab = genWordVocab
85 |
86 | print()
87 | return vocab
88 |
89 |
90 | def saveVocabulary(name, vocab, file):
91 | print('Saving ' + name + ' vocabulary to \'' + file + '\'...')
92 | vocab.writeFile(file)
93 |
94 |
95 | def makeData(srcFile, tgtFile, srcDicts, tgtDicts):
96 | src, tgt = [], []
97 | sizes = []
98 | count, ignored = 0, 0
99 |
100 | print('Processing %s & %s ...' % (srcFile, tgtFile))
101 | srcF = open(srcFile)
102 | tgtF = open(tgtFile)
103 |
104 | while True:
105 | srcWords = srcF.readline().split()
106 | tgtWords = tgtF.readline().split()
107 |
108 | if not srcWords or not tgtWords:
109 | if srcWords and not tgtWords or not srcWords and tgtWords:
110 | print('WARNING: source and target do not have the same number of sentences')
111 | break
112 |
113 | if len(srcWords) <= opt.seq_length and len(tgtWords) <= opt.seq_length:
114 |
115 | src += [srcDicts.convertToIdx(srcWords,
116 | onmt.Constants.UNK_WORD)]
117 | tgt += [tgtDicts.convertToIdx(tgtWords,
118 | onmt.Constants.UNK_WORD,
119 | onmt.Constants.BOS_WORD,
120 | onmt.Constants.EOS_WORD)]
121 |
122 | sizes += [len(srcWords)]
123 | else:
124 | ignored += 1
125 |
126 | count += 1
127 |
128 | if count % opt.report_every == 0:
129 | print('... %d sentences prepared' % count)
130 |
131 | srcF.close()
132 | tgtF.close()
133 |
134 | if opt.shuffle == 1:
135 | print('... shuffling sentences')
136 | perm = torch.randperm(len(src))
137 | src = [src[idx] for idx in perm]
138 | tgt = [tgt[idx] for idx in perm]
139 | sizes = [sizes[idx] for idx in perm]
140 |
141 | print('... sorting sentences by size')
142 | _, perm = torch.sort(torch.Tensor(sizes))
143 | src = [src[idx] for idx in perm]
144 | tgt = [tgt[idx] for idx in perm]
145 |
146 | print('Prepared %d sentences (%d ignored due to length == 0 or > %d)' %
147 | (len(src), ignored, opt.seq_length))
148 |
149 | return src, tgt
150 |
151 |
152 | def main():
153 |
154 | dicts = {}
155 | dicts['src'] = initVocabulary('source', opt.train_src, opt.src_vocab,
156 | opt.src_vocab_size)
157 | dicts['tgt'] = initVocabulary('target', opt.train_tgt, opt.tgt_vocab,
158 | opt.tgt_vocab_size)
159 |
160 | print('Preparing training ...')
161 | train = {}
162 | train['src'], train['tgt'] = makeData(opt.train_src, opt.train_tgt,
163 | dicts['src'], dicts['tgt'])
164 |
165 | print('Preparing validation ...')
166 | valid = {}
167 | valid['src'], valid['tgt'] = makeData(opt.valid_src, opt.valid_tgt,
168 | dicts['src'], dicts['tgt'])
169 |
170 | if opt.src_vocab is None:
171 | saveVocabulary('source', dicts['src'], opt.save_data + '.src.dict')
172 | if opt.tgt_vocab is None:
173 | saveVocabulary('target', dicts['tgt'], opt.save_data + '.tgt.dict')
174 |
175 |
176 | print('Saving data to \'' + opt.save_data + '.train.pt\'...')
177 | save_data = {'dicts': dicts,
178 | 'train': train,
179 | 'valid': valid}
180 | torch.save(save_data, opt.save_data + '.train.pt')
181 |
182 |
183 | if __name__ == "__main__":
184 | main()
185 |
--------------------------------------------------------------------------------
/OpenNMT-py/train.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 |
3 | import onmt
4 | import argparse
5 | import torch
6 | import torch.nn as nn
7 | from torch import cuda
8 | from torch.autograd import Variable
9 | import math
10 | import time
11 |
12 | parser = argparse.ArgumentParser(description='train.py')
13 |
14 | ## Data options
15 |
16 | parser.add_argument('-data', required=True,
17 | help='Path to the *-train.pt file from preprocess.py')
18 | parser.add_argument('-save_model', default='model',
19 |                     help="""Model filename prefix (the model will be saved as
20 |                     <save_model>_ppl_PPL_acc_ACC_loss_LOSS_eN.pt, using values
21 |                     from the validation set after each epoch)""")
22 | parser.add_argument('-train_from_state_dict', default='', type=str,
23 | help="""If training from a checkpoint then this is the
24 | path to the pretrained model's state_dict.""")
25 | parser.add_argument('-train_from', default='', type=str,
26 | help="""If training from a checkpoint then this is the
27 | path to the pretrained model.""")
28 |
29 | ## Model options
30 |
31 | parser.add_argument('-layers', type=int, default=2,
32 | help='Number of layers in the LSTM encoder/decoder')
33 | parser.add_argument('-rnn_size', type=int, default=600,
34 | help='Size of LSTM hidden states')
35 | parser.add_argument('-word_vec_size', type=int, default=400,
36 | help='Word embedding sizes')
37 | parser.add_argument('-input_feed', type=int, default=1,
38 | help="""Feed the context vector at each time step as
39 | additional input (via concatenation with the word
40 | embeddings) to the decoder.""")
41 | # parser.add_argument('-residual', action="store_true",
42 | # help="Add residual connections between RNN layers.")
43 | parser.add_argument('-brnn', action='store_true',
44 | help='Use a bidirectional encoder')
45 | parser.add_argument('-dot', action='store_true',
46 | help='Use dot attention')
47 | parser.add_argument('-brnn_merge', default='concat',
48 | help="""Merge action for the bidirectional hidden states:
49 | [concat|sum]""")
50 |
51 | ## Optimization options
52 |
53 | parser.add_argument('-batch_size', type=int, default=64,
54 | help='Maximum batch size')
55 | parser.add_argument('-max_generator_batches', type=int, default=100,
56 |                     help="""Run the generator on the decoder outputs in
57 |                     chunks of this many timesteps. Higher is faster, but uses
58 |                     more memory.""")
59 | parser.add_argument('-epochs', type=int, default=1000,
60 | help='Number of training epochs')
61 | parser.add_argument('-start_epoch', type=int, default=1,
62 | help='The epoch from which to start')
63 | parser.add_argument('-param_init', type=float, default=0.1,
64 | help="""Parameters are initialized over uniform distribution
65 | with support (-param_init, param_init)""")
66 | parser.add_argument('-optim', default='sgd',
67 | help="Optimization method. [sgd|adagrad|adadelta|adam]")
68 | parser.add_argument('-max_grad_norm', type=float, default=5,
69 | help="""If the norm of the gradient vector exceeds this,
70 | renormalize it to have the norm equal to max_grad_norm""")
71 | parser.add_argument('-dropout', type=float, default=0.2,
72 | help='Dropout probability; applied between LSTM stacks.')
73 | parser.add_argument('-curriculum', action="store_true",
74 |                     help="""If set, order the minibatches by source sequence
75 |                     length for the first epoch. This can sometimes
76 |                     increase convergence speed.""")
77 | parser.add_argument('-extra_shuffle', action="store_true",
78 | help="""By default only shuffle mini-batch order; when true,
79 | shuffle and re-assign mini-batches""")
80 |
81 | #learning rate
82 | parser.add_argument('-learning_rate', type=float, default=1.0,
83 | help="""Starting learning rate. If adagrad/adadelta/adam is
84 | used, then this is the global learning rate. Recommended
85 | settings: sgd = 1, adagrad = 0.1, adadelta = 1, adam = 0.001""")
86 | parser.add_argument('-lr_cutoff', type=float, default=0.03,
87 |                     help='Stop training once the learning rate decays below this value')
88 | parser.add_argument('-learning_rate_decay', type=float, default=0.5,
89 | help="""If update_learning_rate, decay learning rate by
90 | this much if (i) perplexity does not decrease on the
91 | validation set or (ii) epoch has gone past
92 | start_decay_at""")
93 | parser.add_argument('-start_decay_at', type=int, default=1000,
94 | help="""Start decaying every epoch after and including this
95 | epoch""")
96 |
97 | #pretrained word vectors
98 |
99 | parser.add_argument('-pre_word_vecs_enc',
100 | help="""If a valid path is specified, then this will load
101 | pretrained word embeddings on the encoder side.
102 | See README for specific formatting instructions.""")
103 | parser.add_argument('-pre_word_vecs_dec',
104 | help="""If a valid path is specified, then this will load
105 | pretrained word embeddings on the decoder side.
106 | See README for specific formatting instructions.""")
107 | parser.add_argument('-detach_embed', default=0, type=int)
108 | parser.add_argument('-fix_embed', action='store_true')
109 |
110 | # GPU
111 | parser.add_argument('-gpus', default=[0], nargs='+', type=int,
112 | help="Use CUDA on the listed devices.")
113 |
114 | parser.add_argument('-log_interval', type=int, default=50,
115 | help="Print stats at this interval.")
116 |
117 | opt = parser.parse_args()
118 |
119 | print(opt)
120 |
121 | if torch.cuda.is_available() and not opt.gpus:
122 | print("WARNING: You have a CUDA device, so you should probably run with -gpus 0")
123 |
124 | if opt.gpus:
125 | cuda.set_device(opt.gpus[0])
126 |
127 | def NMTCriterion(vocabSize):
128 | weight = torch.ones(vocabSize)
129 | weight[onmt.Constants.PAD] = 0
130 | crit = nn.NLLLoss(weight, size_average=False)
131 | if opt.gpus:
132 | crit.cuda()
133 | return crit
134 |
135 |
136 | def memoryEfficientLoss(outputs, targets, generator, crit, eval=False):
137 | # compute generations one piece at a time
138 | num_correct, loss = 0, 0
139 | outputs = Variable(outputs.data, requires_grad=(not eval), volatile=eval)
140 |
141 | batch_size = outputs.size(1)
142 | outputs_split = torch.split(outputs, opt.max_generator_batches)
143 | targets_split = torch.split(targets, opt.max_generator_batches)
144 | for i, (out_t, targ_t) in enumerate(zip(outputs_split, targets_split)):
145 | out_t = out_t.view(-1, out_t.size(2))
146 | scores_t = generator(out_t)
147 | loss_t = crit(scores_t, targ_t.view(-1))
148 | pred_t = scores_t.max(1)[1]
149 | num_correct_t = pred_t.data.eq(targ_t.data).masked_select(targ_t.ne(onmt.Constants.PAD).data).sum()
150 | num_correct += num_correct_t
151 | loss += loss_t.data[0]
152 | if not eval:
153 | loss_t.div(batch_size).backward()
154 |
155 | grad_output = None if outputs.grad is None else outputs.grad.data
156 | return loss, grad_output, num_correct
157 |
158 |
159 | def eval(model, criterion, data):
160 | total_loss = 0
161 | total_words = 0
162 | total_num_correct = 0
163 |
164 | model.eval()
165 | for i in range(len(data)):
166 | batch = data[i][:-1] # exclude original indices
167 | outputs = model(batch)
168 | targets = batch[1][1:] # exclude from targets
169 | loss, _, num_correct = memoryEfficientLoss(
170 | outputs, targets, model.generator, criterion, eval=True)
171 | total_loss += loss
172 | total_num_correct += num_correct
173 | total_words += targets.data.ne(onmt.Constants.PAD).sum()
174 |
175 | model.train()
176 | return total_loss / total_words, total_num_correct / total_words
177 |
178 |
179 | def trainModel(model, trainData, validData, dataset, optim):
180 | print(model)
181 | model.train()
182 |
183 | # define criterion of each GPU
184 | criterion = NMTCriterion(dataset['dicts']['tgt'].size())
185 |
186 | start_time = time.time()
187 | def trainEpoch(epoch):
188 |
189 | if opt.extra_shuffle and epoch > opt.curriculum:
190 | trainData.shuffle()
191 |
192 | # shuffle mini batch order
193 | batchOrder = torch.randperm(len(trainData))
194 |
195 | total_loss, total_words, total_num_correct = 0, 0, 0
196 | report_loss, report_tgt_words, report_src_words, report_num_correct = 0, 0, 0, 0
197 | start = time.time()
198 | for i in range(len(trainData)):
199 |
200 | batchIdx = batchOrder[i] if epoch > opt.curriculum else i
201 | batch = trainData[batchIdx][:-1] # exclude original indices
202 |
203 | model.zero_grad()
204 | outputs = model(batch)
205 | targets = batch[1][1:] # exclude from targets
206 | loss, gradOutput, num_correct = memoryEfficientLoss(
207 | outputs, targets, model.generator, criterion)
208 |
209 | outputs.backward(gradOutput)
210 |
211 | # update the parameters
212 | optim.step()
213 |
214 | num_words = targets.data.ne(onmt.Constants.PAD).sum()
215 | report_loss += loss
216 | report_num_correct += num_correct
217 | report_tgt_words += num_words
218 | report_src_words += sum(batch[0][1])
219 | total_loss += loss
220 | total_num_correct += num_correct
221 | total_words += num_words
222 |             if i % opt.log_interval == -1 % opt.log_interval:  # true on the last step of each log interval
223 | print("Epoch %2d, %5d/%5d; acc: %6.2f; ppl: %6.2f; %3.0f src tok/s; %3.0f tgt tok/s; %6.0f s elapsed" %
224 | (epoch, i+1, len(trainData),
225 | report_num_correct / report_tgt_words * 100,
226 | math.exp(report_loss / report_tgt_words),
227 | report_src_words/(time.time()-start),
228 | report_tgt_words/(time.time()-start),
229 | time.time()-start_time))
230 |
231 | report_loss = report_tgt_words = report_src_words = report_num_correct = 0
232 | start = time.time()
233 |
234 | return total_loss / total_words, total_num_correct / total_words
235 |
236 | for epoch in range(opt.start_epoch, opt.epochs + 1):
237 | print('')
238 |
239 | # (1) train for one epoch on the training set
240 | train_loss, train_acc = trainEpoch(epoch)
241 | train_ppl = math.exp(min(train_loss, 100))
242 | print('Train perplexity: %g' % train_ppl)
243 | print('Train accuracy: %g' % (train_acc*100))
244 |
245 | # (2) evaluate on the validation set
246 | valid_loss, valid_acc = eval(model, criterion, validData)
247 | valid_ppl = math.exp(min(valid_loss, 100))
248 | print('Validation perplexity: %g' % valid_ppl)
249 | print('Validation accuracy: %g' % (valid_acc*100))
250 |
251 | # (3) update the learning rate
252 | optim.updateLearningRate(valid_ppl, epoch)
253 |
254 | model_state_dict = model.module.state_dict() if len(opt.gpus) > 1 else model.state_dict()
255 | model_state_dict = {k: v for k, v in model_state_dict.items() if 'generator' not in k}
256 | generator_state_dict = model.generator.module.state_dict() if len(opt.gpus) > 1 else model.generator.state_dict()
257 | # (4) drop a checkpoint
258 | checkpoint = {
259 | 'model': model_state_dict,
260 | 'generator': generator_state_dict,
261 | 'dicts': dataset['dicts'],
262 | 'opt': opt,
263 | 'epoch': epoch,
264 | 'optim': optim,
265 | 'ppl': valid_ppl,
266 | 'loss': valid_loss,
267 | 'acc': valid_acc
268 | }
269 | torch.save(checkpoint,
270 | '%s_ppl_%.2f_acc_%.2f_loss_%.2f_e%d.pt' % (opt.save_model, valid_ppl, 100*valid_acc, valid_loss, epoch))
271 | if optim.lr < opt.lr_cutoff:
272 | print('Learning rate decayed below cutoff: {} < {}'.format(optim.lr, opt.lr_cutoff))
273 | break
274 |
275 | def main():
276 |
277 | print("Loading data from '%s'" % opt.data)
278 |
279 | dataset = torch.load(opt.data)
280 |
281 | dict_checkpoint = opt.train_from if opt.train_from else opt.train_from_state_dict
282 | if dict_checkpoint:
283 | print('Loading dicts from checkpoint at %s' % dict_checkpoint)
284 | checkpoint = torch.load(dict_checkpoint)
285 | dataset['dicts'] = checkpoint['dicts']
286 |
287 | trainData = onmt.Dataset(dataset['train']['src'],
288 | dataset['train']['tgt'], opt.batch_size, opt.gpus)
289 | validData = onmt.Dataset(dataset['valid']['src'],
290 | dataset['valid']['tgt'], opt.batch_size, opt.gpus,
291 | volatile=True)
292 |
293 | dicts = dataset['dicts']
294 | print(' * vocabulary size. source = %d; target = %d' %
295 | (dicts['src'].size(), dicts['tgt'].size()))
296 | print(' * number of training sentences. %d' %
297 | len(dataset['train']['src']))
298 | print(' * maximum batch size. %d' % opt.batch_size)
299 |
300 | print('Building model...')
301 |
302 | encoder = onmt.Models.Encoder(opt, dicts['src'])
303 | decoder = onmt.Models.Decoder(opt, dicts['tgt'])
304 |
305 | generator = nn.Sequential(
306 | nn.Linear(opt.rnn_size, dicts['tgt'].size()),
307 | nn.LogSoftmax())
308 |
309 | model = onmt.Models.NMTModel(encoder, decoder)
310 |
311 | if opt.train_from:
312 | print('Loading model from checkpoint at %s' % opt.train_from)
313 | chk_model = checkpoint['model']
314 | generator_state_dict = chk_model.generator.state_dict()
315 | model_state_dict = {k: v for k, v in chk_model.state_dict().items() if 'generator' not in k}
316 | model.load_state_dict(model_state_dict)
317 | generator.load_state_dict(generator_state_dict)
318 | opt.start_epoch = checkpoint['epoch'] + 1
319 |
320 | if opt.train_from_state_dict:
321 | print('Loading model from checkpoint at %s' % opt.train_from_state_dict)
322 | model.load_state_dict(checkpoint['model'])
323 | generator.load_state_dict(checkpoint['generator'])
324 | opt.start_epoch = checkpoint['epoch'] + 1
325 |
326 | if len(opt.gpus) >= 1:
327 | model.cuda()
328 | generator.cuda()
329 | else:
330 | model.cpu()
331 | generator.cpu()
332 |
333 | if len(opt.gpus) > 1:
334 | model = nn.DataParallel(model, device_ids=opt.gpus, dim=1)
335 | generator = nn.DataParallel(generator, device_ids=opt.gpus, dim=0)
336 |
337 | model.generator = generator
338 |
339 | if not opt.train_from_state_dict and not opt.train_from:
340 | for p in model.parameters():
341 | p.data.uniform_(-opt.param_init, opt.param_init)
342 |
343 | encoder.load_pretrained_vectors(opt)
344 | decoder.load_pretrained_vectors(opt)
345 |
346 | optim = onmt.Optim(
347 | opt.optim, opt.learning_rate, opt.max_grad_norm,
348 | lr_decay=opt.learning_rate_decay,
349 | start_decay_at=opt.start_decay_at
350 | )
351 | else:
352 | print('Loading optimizer from checkpoint:')
353 | optim = checkpoint['optim']
354 | print(optim)
355 |
356 | optim.set_parameters(model.parameters())
357 |
358 | if opt.train_from or opt.train_from_state_dict:
359 | optim.optimizer.load_state_dict(checkpoint['optim'].optimizer.state_dict())
360 |
361 | nParams = sum([p.nelement() for p in model.parameters()])
362 | print('* number of parameters: %d' % nParams)
363 |
364 | trainModel(model, trainData, validData, dataset, optim)
365 |
366 |
367 | if __name__ == "__main__":
368 | main()
369 |
--------------------------------------------------------------------------------
/OpenNMT-py/translate.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 |
3 | import onmt
4 | import torch
5 | import argparse
6 | import math
7 |
8 | parser = argparse.ArgumentParser(description='translate.py')
9 |
10 | parser.add_argument('-model', required=True,
11 | help='Path to model .pt file')
12 | parser.add_argument('-src', required=True,
13 | help='Source sequence to decode (one line per sequence)')
14 | parser.add_argument('-tgt',
15 | help='True target sequence (optional)')
16 | parser.add_argument('-output', default='pred.txt',
17 | help="""Path to output the predictions (each line will
18 |                     be the decoded sequence)""")
19 | parser.add_argument('-beam_size', type=int, default=5,
20 | help='Beam size')
21 | parser.add_argument('-batch_size', type=int, default=30,
22 | help='Batch size')
23 | parser.add_argument('-max_sent_length', type=int, default=100,
24 | help='Maximum sentence length.')
25 | parser.add_argument('-replace_unk', action="store_true",
26 | help="""Replace the generated UNK tokens with the source
27 | token that had the highest attention weight. If phrase_table
28 | is provided, it will lookup the identified source token and
29 | give the corresponding target token. If it is not provided
30 | (or the identified source token does not exist in the
31 | table) then it will copy the source token""")
32 | # parser.add_argument('-phrase_table',
33 | # help="""Path to source-target dictionary to replace UNK
34 | # tokens. See README.md for the format of this file.""")
35 | parser.add_argument('-verbose', action="store_true",
36 | help='Print scores and predictions for each sentence')
37 | parser.add_argument('-n_best', type=int, default=1,
38 | help="""If verbose is set, will output the n_best
39 | decoded sentences""")
40 |
41 | parser.add_argument('-gpu', type=int, default=-1,
42 | help="Device to run on")
43 |
44 |
45 |
46 | def reportScore(name, scoreTotal, wordsTotal):
47 | print("%s AVG SCORE: %.4f, %s PPL: %.4f" % (
48 | name, scoreTotal / wordsTotal,
49 | name, math.exp(-scoreTotal/wordsTotal)))
50 |
51 | def addone(f):
52 | for line in f:
53 | yield line
54 | yield None
55 |
56 | def main():
57 | opt = parser.parse_args()
58 | opt.cuda = opt.gpu > -1
59 | if opt.cuda:
60 | torch.cuda.set_device(opt.gpu)
61 |
62 | translator = onmt.Translator(opt)
63 |
64 | outF = open(opt.output, 'w')
65 |
66 | predScoreTotal, predWordsTotal, goldScoreTotal, goldWordsTotal = 0, 0, 0, 0
67 |
68 | srcBatch, tgtBatch = [], []
69 |
70 | count = 0
71 |
72 | tgtF = open(opt.tgt) if opt.tgt else None
73 | for line in addone(open(opt.src)):
74 |
75 | if line is not None:
76 | srcTokens = line.split()
77 | srcBatch += [srcTokens]
78 | if tgtF:
79 |                 tgtTokens = tgtF.readline().split()
80 | tgtBatch += [tgtTokens]
81 |
82 | if len(srcBatch) < opt.batch_size:
83 | continue
84 | else:
85 | # at the end of file, check last batch
86 | if len(srcBatch) == 0:
87 | break
88 |
89 | predBatch, predScore, goldScore = translator.translate(srcBatch, tgtBatch)
90 |
91 | predScoreTotal += sum(score[0] for score in predScore)
92 | predWordsTotal += sum(len(x[0]) for x in predBatch)
93 | if tgtF is not None:
94 | goldScoreTotal += sum(goldScore)
95 | goldWordsTotal += sum(len(x) for x in tgtBatch)
96 |
97 | for b in range(len(predBatch)):
98 | count += 1
99 | outF.write(" ".join(predBatch[b][0]) + '\n')
100 | outF.flush()
101 |
102 | if opt.verbose:
103 | srcSent = ' '.join(srcBatch[b])
104 | if translator.tgt_dict.lower:
105 | srcSent = srcSent.lower()
106 | print('SENT %d: %s' % (count, srcSent))
107 | print('PRED %d: %s' % (count, " ".join(predBatch[b][0])))
108 | print("PRED SCORE: %.4f" % predScore[b][0])
109 |
110 | if tgtF is not None:
111 | tgtSent = ' '.join(tgtBatch[b])
112 | if translator.tgt_dict.lower:
113 | tgtSent = tgtSent.lower()
114 | print('GOLD %d: %s ' % (count, tgtSent))
115 | print("GOLD SCORE: %.4f" % goldScore[b])
116 |
117 | if opt.n_best > 1:
118 | print('\nBEST HYP:')
119 | for n in range(opt.n_best):
120 | print("[%.4f] %s" % (predScore[b][n], " ".join(predBatch[b][n])))
121 |
122 | print('')
123 |
124 | srcBatch, tgtBatch = [], []
125 |
126 | reportScore('PRED', predScoreTotal, predWordsTotal)
127 | if tgtF:
128 | reportScore('GOLD', goldScoreTotal, goldWordsTotal)
129 |
130 | if tgtF:
131 | tgtF.close()
132 |
133 |
134 | if __name__ == "__main__":
135 | main()
136 |
--------------------------------------------------------------------------------
/OpenNMT-py/wmt_clean.py:
--------------------------------------------------------------------------------
1 | from collections import Counter
2 | import pycld2
3 | import unicodeblock.blocks
4 | from argparse import ArgumentParser
5 |
6 | parser = ArgumentParser()
7 | parser.add_argument('prefix', nargs='?', default='data/wmt17/de-en/')
8 | args = parser.parse_args()
9 |
10 | langs = ('de','en')
11 | lang_fix = '.' + '-'.join(langs)
12 | subsets = 'commoncrawl', 'europarl-v7', 'news-commentary-v12', 'rapid2016'
13 | for x in subsets:
14 | path_prefix = args.prefix + x + lang_fix
15 | paths_in = [path_prefix+'.'+lang for lang in langs]
16 | paths_out = [path_prefix+'.clean.'+lang for lang in langs]
17 | latin = lambda s: all("LATIN" in b or "PUNCT" in b or "DIGIT" in b or "SPAC" in b for b in map(unicodeblock.blocks.of,s) if b is not None)
18 | good_src = lambda s: pycld2.detect(s)[2][0][1] in [langs[0],'un'] and latin(s.decode()) and len(s)>1
19 | good_trg = lambda s: pycld2.detect(s)[2][0][1] in [langs[1],'un'] and latin(s.decode()) and len(s)>1
20 |
21 | with open(paths_in[0],'rb') as src, open(paths_in[1],'rb') as trg, open(paths_out[0],'wb') as src_out, open(paths_out[1],'wb') as trg_out:
22 | for srcline,trgline in zip(src,trg):
23 | try:
24 | if good_src(srcline) and good_trg(trgline):
25 | src_out.write(srcline)
26 | trg_out.write(trgline)
27 |             except Exception:  # language detection or decoding failed; retry after re-encoding
28 | try:
29 | srcline = srcline.decode("utf-8").encode("latin-1")
30 | trgline = trgline.decode("utf-8").encode("latin-1")
31 | try:
32 | if good_src(srcline) and good_trg(trgline):
33 | src_out.write(srcline)
34 | trg_out.write(trgline)
35 |                     except Exception:
36 | pass
37 |                 except Exception:  # still failing; skip this sentence pair
38 | pass
39 |
--------------------------------------------------------------------------------
/OpenNMT-py/wmt_sgm2txt.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import subprocess
4 | import xml.etree.ElementTree as ET
5 | from argparse import ArgumentParser
6 |
7 | parser = ArgumentParser(description='sgm2txt')
8 | parser.add_argument('path')
9 | parser.add_argument('-t', '--tags', nargs='+', default=['seg'])
10 | parser.add_argument('-th', '--threads', default=8, type=int)
11 | parser.add_argument('-a', '--aggressive', action='store_true')
12 | args = parser.parse_args()
13 |
14 | def tokenize(f_txt):
15 | lang = os.path.splitext(f_txt)[1][1:]
16 | f_tok = f_txt
17 | if args.aggressive:
18 | f_tok += '.atok'
19 | else:
20 | f_tok += '.tok'
21 | with open(f_tok, 'w') as fout, open(f_txt) as fin:
22 | if args.aggressive:
23 | pipe = subprocess.call(['perl', 'tokenizer.perl', '-a', '-q', '-threads', str(args.threads), '-no-escape', '-l', lang], stdin=fin, stdout=fout)
24 | else:
25 | pipe = subprocess.call(['perl', 'tokenizer.perl', '-q', '-threads', str(args.threads), '-no-escape', '-l', lang], stdin=fin, stdout=fout)
26 |
27 | for f_xml in glob.iglob(os.path.join(args.path, '*.sgm')):
28 | print(f_xml)
29 | f_txt = os.path.splitext(f_xml)[0]
30 | with open(f_txt, 'w') as fd_txt:
31 | root = ET.parse(f_xml).getroot()[0]
32 | for doc in root.findall('doc'):
33 | for tag in args.tags:
34 | for e in doc.findall(tag):
35 | fd_txt.write(e.text.strip() + '\n')
36 | tokenize(f_txt)
37 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Contextualized Word Vectors (CoVe)
2 |
3 | This repo provides the best pretrained MT-LSTM from the paper [Learned in Translation: Contextualized Word Vectors (McCann et. al. 2017)](http://papers.nips.cc/paper/7209-learned-in-translation-contextualized-word-vectors.pdf).
4 | For a high-level overview of why CoVe vectors are useful, check out the [post](https://einstein.ai/research/learned-in-translation-contextualized-word-vectors).
5 |
6 | This repository uses a [PyTorch](http://pytorch.org/) implementation of the MTLSTM class in `cove/encoder.py` to load a pretrained encoder,
7 | which takes in sequences of vectors pretrained with GloVe and outputs CoVe.
8 |
9 | ## Need CoVe in Tensorflow?
10 |
11 | A Keras/TensorFlow implementation of the MT-LSTM/CoVe can be found at https://github.com/rgsachin/CoVe.
12 |
13 | ## Unknown Words
14 |
15 | Out-of-vocabulary words for CoVe are also out of vocabulary for GloVe, which should be rare for most use cases. During training, the CoVe encoder received a zero vector for any word that was not in GloVe. We also used zero vectors for unknown words in our classification and question answering experiments, so that is what we recommend.
16 |
17 | You could also try initializing unknown inputs to something close to GloVe vectors instead, but we have no experiments suggesting that this would work better than zero vectors. If you want to try this, GloVe vectors follow (very roughly) a Gaussian with mean 0 and standard deviation 0.4. You could initialize by randomly drawing from that distribution, but you would probably want to train those embeddings while keeping the CoVe encoder (MTLSTM) and the GloVe vectors fixed.
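
A minimal sketch of that fallback (`n_vocab` and `unknown_rows` below are hypothetical placeholders for your vocabulary size and the indices of words missing from GloVe):

```python
import torch
import torch.nn as nn

n_vocab, dim = 10000, 300          # hypothetical vocabulary size
emb = nn.Embedding(n_vocab, dim)   # rows you would otherwise fill with GloVe
unknown_rows = [5, 42, 99]         # hypothetical indices of out-of-GloVe words
# Rough Gaussian matching GloVe statistics (mean 0, std ~0.4), per the text above:
emb.weight.data[unknown_rows] = torch.randn(len(unknown_rows), dim) * 0.4
# The recommended default is simply zeros:
# emb.weight.data[unknown_rows] = 0
```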
18 |
19 | ## Example Usage
20 |
21 | The following example can be found in `test/example.py`. It demonstrates a few different ways to use the pretrained MTLSTM class, which generates contextualized word vectors (CoVe) programmatically.
22 |
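If you already have GloVe vectors for a batch of sequences, the smallest use of the encoder looks roughly like this (a sketch; the shapes and values are illustrative, and the pretrained weights are downloaded on first use):

```python
import torch
from cove import MTLSTM

glove = torch.randn(2, 7, 300)      # (batch, time, 300d GloVe); illustrative values
lengths = torch.LongTensor([7, 4])  # true length of each sequence before padding
cove = MTLSTM()                     # last-layer CoVe only; weights fetched via model_zoo
outputs = cove(glove, lengths)      # (2, 7, 600) contextualized vectors, detached
```
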
23 | ### Running with Docker
24 |
25 | Install [Docker](https://www.docker.com/get-docker).
26 | Install [nvidia-docker](https://github.com/NVIDIA/nvidia-docker) if you would like to use it with a GPU.
27 |
28 | ```bash
29 | docker pull bmccann/cove # pull the docker image
30 | # On CPU
31 | docker run -it --rm -v `pwd`/.embeddings:/.embeddings/ -v `pwd`/.data/:/.data/ bmccann/cove bash -c "python /test/example.py --device -1"
32 | # On GPU
33 | nvidia-docker run -it --rm -v `pwd`/.embeddings:/.embeddings/ -v `pwd`/.data/:/.data/ bmccann/cove bash -c "python /test/example.py"
34 | ```
35 |
36 | ### Running without Docker
37 |
38 | Install [PyTorch](http://pytorch.org/).
39 |
40 | ```bash
41 | git clone https://github.com/salesforce/cove.git # use ssh: git@github.com:salesforce/cove.git
42 | cd cove
43 | pip install -r requirements.txt
44 | python setup.py develop
45 | # On CPU
46 | python test/example.py --device -1
47 | # On GPU
48 | python test/example.py
49 | ```
50 | ## Re-training CoVe
51 |
52 | There is also a third option if you are operating in an entirely different context: retrain the bidirectional LSTM using your own trained embeddings. If you are mostly encoding a non-English language, that might be the best option. Check out the paper for details; the code for this is included in the directory OpenNMT-py, which was forked from [OpenNMT-py](https://github.com/OpenNMT/OpenNMT-py) a while back and includes changes we made to the repo internally.
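
As a rough sketch of that pipeline using the scripts in OpenNMT-py (all file paths below are placeholders; see the argparse options in `preprocess.py` and `train.py` for the full set of flags):

```bash
cd OpenNMT-py
python preprocess.py -train_src train.src -train_tgt train.tgt \
    -valid_src valid.src -valid_tgt valid.tgt -save_data demo
python train.py -data demo.train.pt -save_model demo_model -brnn \
    -pre_word_vecs_enc embeddings.enc.pt   # pretrained encoder embeddings, optional
```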
53 |
54 | ## References
55 |
56 | If using this code, please cite:
57 |
58 | B. McCann, J. Bradbury, C. Xiong, R. Socher, [*Learned in Translation: Contextualized Word Vectors*](http://papers.nips.cc/paper/7209-learned-in-translation-contextualized-word-vectors.pdf)
59 |
60 | ```
61 | @inproceedings{mccann2017learned,
62 | title={Learned in translation: Contextualized word vectors},
63 | author={McCann, Bryan and Bradbury, James and Xiong, Caiming and Socher, Richard},
64 | booktitle={Advances in Neural Information Processing Systems},
65 | pages={6297--6308},
66 | year={2017}
67 | }
68 | ```
69 |
70 | Contact: [bmccann@salesforce.com](mailto:bmccann@salesforce.com)
71 |
--------------------------------------------------------------------------------
/cove/__init__.py:
--------------------------------------------------------------------------------
1 | from .encoder import *
2 |
3 |
4 | __all__ = ['MTLSTM']
5 |
--------------------------------------------------------------------------------
/cove/encoder.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from torch import nn
5 | from torch.nn.utils.rnn import pad_packed_sequence as unpack
6 | from torch.nn.utils.rnn import pack_padded_sequence as pack
7 | import torch.utils.model_zoo as model_zoo
8 |
9 |
10 | model_urls = {
11 | 'wmt-lstm' : 'https://s3.amazonaws.com/research.metamind.io/cove/wmtlstm-8f474287.pth'
12 | }
13 |
14 | MODEL_CACHE = os.path.join(os.path.dirname(os.path.realpath(__file__)), '.torch')
15 |
16 |
17 | class MTLSTM(nn.Module):
18 |
19 | def __init__(self, n_vocab=None, vectors=None, residual_embeddings=False, layer0=False, layer1=True, trainable=False, model_cache=MODEL_CACHE):
20 |         """Initialize an MTLSTM. If layer0 and layer1 are True, their outputs are concatenated along the last dimension, so that layer0 outputs
21 |         contribute the first 600 entries and layer1 contributes the second 600 entries. If residual_embeddings is also True, the inputs
22 |         are also concatenated along the last dimension with any outputs, such that they form the first 300 entries.
23 |
24 | Arguments:
25 | n_vocab (int): If not None, initialize MTLSTM with an embedding matrix with n_vocab vectors
26 | vectors (Float Tensor): If not None, initialize embedding matrix with specified vectors (These should be 300d CommonCrawl GloVe vectors)
27 |             residual_embeddings (bool): If True, concatenate the input GloVe embeddings with contextualized word vectors as the final output
28 | layer0 (bool): If True, return the outputs of the first layer of the MTLSTM
29 | layer1 (bool): If True, return the outputs of the second layer of the MTLSTM
30 | trainable (bool): If True, do not detach outputs; i.e. train the MTLSTM (recommended to leave False)
31 |             model_cache (str): path to the model file for the MTLSTM to load pretrained weights (defaults to the best MTLSTM from McCann et al. 2017,
32 |                 which was trained with 300d 840B GloVe on the WMT 2017 machine translation dataset)
33 | """
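        # For reference (per the docstring above): with the defaults, only the
        # second layer's 600-d bidirectional outputs are returned. layer0 adds
        # its 600-d outputs first, and residual_embeddings prepends the 300-d
        # GloVe inputs, so all three together yield 1500-d vectors.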
34 | super(MTLSTM, self).__init__()
35 | self.layer0 = layer0
36 | self.layer1 = layer1
37 | self.residual_embeddings = residual_embeddings
38 | self.trainable = trainable
39 | self.embed = False
40 | if n_vocab is not None:
41 | self.embed = True
42 | self.vectors = nn.Embedding(n_vocab, 300)
43 | if vectors is not None:
44 | self.vectors.weight.data = vectors
45 | state_dict = model_zoo.load_url(model_urls['wmt-lstm'], model_dir=model_cache)
46 | if layer0:
47 | layer0_dict = {k: v for k, v in state_dict.items() if 'l0' in k}
48 | self.rnn0 = nn.LSTM(300, 300, num_layers=1, bidirectional=True, batch_first=True)
49 | self.rnn0.load_state_dict(layer0_dict)
50 | if layer1:
51 | layer1_dict = {k.replace('l1', 'l0'): v for k, v in state_dict.items() if 'l1' in k}
52 | self.rnn1 = nn.LSTM(600, 300, num_layers=1, bidirectional=True, batch_first=True)
53 | self.rnn1.load_state_dict(layer1_dict)
54 | elif layer1:
55 | self.rnn1 = nn.LSTM(300, 300, num_layers=2, bidirectional=True, batch_first=True)
56 | self.rnn1.load_state_dict(model_zoo.load_url(model_urls['wmt-lstm'], model_dir=model_cache))
57 | else:
58 | raise ValueError('At least one of layer0 and layer1 must be True.')
59 |
60 |
61 | def forward(self, inputs, lengths, hidden=None):
62 | """
63 | Arguments:
64 | inputs (Tensor): If MTLSTM handles embedding, a Long Tensor of size (batch_size, timesteps).
65 | Otherwise, a Float Tensor of size (batch_size, timesteps, features).
66 |             lengths (Long Tensor): lengths of each sequence for handling padding
67 | hidden (Float Tensor): initial hidden state of the LSTM
68 | """
69 | if self.embed:
70 | inputs = self.vectors(inputs)
71 | if not isinstance(lengths, torch.Tensor):
72 | lengths = torch.Tensor(lengths).long()
73 | if inputs.is_cuda:
74 | with torch.cuda.device_of(inputs):
75 | lengths = lengths.cuda(torch.cuda.current_device())
76 | lens, indices = torch.sort(lengths, 0, True)
77 | outputs = [inputs] if self.residual_embeddings else []
78 | len_list = lens.tolist()
79 | packed_inputs = pack(inputs[indices], len_list, batch_first=True)
80 |
81 | if self.layer0:
82 | outputs0, hidden_t0 = self.rnn0(packed_inputs, hidden)
83 | unpacked_outputs0 = unpack(outputs0, batch_first=True)[0]
84 | _, _indices = torch.sort(indices, 0)
85 | unpacked_outputs0 = unpacked_outputs0[_indices]
86 | outputs.append(unpacked_outputs0)
87 | packed_inputs = outputs0
88 | if self.layer1:
89 | outputs1, hidden_t1 = self.rnn1(packed_inputs, hidden)
90 | unpacked_outputs1 = unpack(outputs1, batch_first=True)[0]
91 | _, _indices = torch.sort(indices, 0)
92 | unpacked_outputs1 = unpacked_outputs1[_indices]
93 | outputs.append(unpacked_outputs1)
94 |
95 | outputs = torch.cat(outputs, 2)
96 | return outputs if self.trainable else outputs.detach()
97 |
--------------------------------------------------------------------------------
/get_data.sh:
--------------------------------------------------------------------------------
1 | wget https://raw.githubusercontent.com/moses-smt/mosesdecoder/master/scripts/tokenizer/tokenizer.perl
2 | wget https://raw.githubusercontent.com/moses-smt/mosesdecoder/master/scripts/tokenizer/lowercase.perl
3 | sed -i "s/$RealBin\/..\/share\/nonbreaking_prefixes//" tokenizer.perl
4 | wget https://raw.githubusercontent.com/moses-smt/mosesdecoder/master/scripts/share/nonbreaking_prefixes/nonbreaking_prefix.de
5 | wget https://raw.githubusercontent.com/moses-smt/mosesdecoder/master/scripts/share/nonbreaking_prefixes/nonbreaking_prefix.en
6 | wget https://raw.githubusercontent.com/moses-smt/mosesdecoder/master/scripts/generic/multi-bleu.perl
7 |
8 | mkdir -p data/wmt17
9 | cd data/wmt17
10 | wget http://www.statmt.org/wmt13/training-parallel-europarl-v7.tgz
11 | wget http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz
12 | wget http://data.statmt.org/wmt17/translation-task/training-parallel-nc-v12.tgz
13 | wget http://data.statmt.org/wmt17/translation-task/rapid2016.tgz
14 | wget http://data.statmt.org/wmt17/translation-task/dev.tgz
15 | tar -xzf training-parallel-europarl-v7.tgz
16 | tar -xzf training-parallel-commoncrawl.tgz
17 | tar -xzf training-parallel-nc-v12.tgz
18 | tar -xzf rapid2016.tgz
19 | tar -xzf dev.tgz
20 | mkdir de-en
21 | mv *de-en* de-en
22 | mv training/*de-en* de-en
23 | mv dev/*deen* de-en
24 | mv dev/*ende* de-en
25 | mv dev/*.de de-en
26 | mv dev/*.en de-en
27 | mv dev/newstest2009*.en* de-en
28 | mv dev/news-test2008*.en* de-en
29 |
30 | python ../../wmt_clean.py de-en
31 | for l in de; do for f in de-en/*.clean.$l; do perl ../../tokenizer.perl -no-escape -l $l -q < $f > $f.tok; done; done
32 | for l in en; do for f in de-en/*.clean.$l; do perl ../../tokenizer.perl -no-escape -l $l -q < $f > $f.tok; done; done
33 | for l in en de; do for f in de-en/*.clean.$l.tok; do perl ../../lowercase.perl < $f > $f.low; done; done
34 | for l in en de; do perl ../../tokenizer.perl -no-escape -l $l -q < de-en/newstest2013.$l > de-en/newstest2013.$l.tok; done
35 | for l in en de; do perl ../../lowercase.perl < de-en/newstest2013.$l.tok > de-en/newstest2013.$l.tok.low; done
36 | for l in en de; do cat de-en/commoncraw*clean.$l.tok.low de-en/europarl*.clean.$l.tok.low de-en/news-commentary*.clean.$l.tok.low de-en/rapid*.clean.$l.tok.low > de-en/train.clean.$l.tok.low; done
37 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | -r https://raw.githubusercontent.com/pytorch/text/master/requirements.txt
2 | git+https://github.com/pytorch/text.git
3 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | from setuptools import setup, find_packages
3 | from codecs import open
4 | from os import path
5 |
6 |
7 | with open(path.join(path.abspath(path.dirname(__file__)), 'README.md'), encoding='utf-8') as f:
8 | long_description = f.read()
9 |
10 | setup_info = dict(
11 | name='cove',
12 | version='1.0.0',
13 | author='Bryan McCann',
14 | author_email='Bryan.McCann.is@gmail.com',
15 | url='https://github.com/salesforce/cove',
16 | description='Context Vectors for Deep Learning and NLP',
17 | long_description=long_description,
18 | license='BSD 3-Clause',
19 | keywords='cove, context vectors, deep learning, natural language processing',
20 | packages=find_packages()
21 | )
22 |
23 | setup(**setup_info)
24 |
--------------------------------------------------------------------------------
/test/example.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | import numpy as np
3 |
4 | import torch
5 | from torchtext import data
6 | from torchtext import datasets
7 | from torchtext.vocab import GloVe
8 |
9 | from cove import MTLSTM
10 |
11 |
12 | parser = ArgumentParser()
13 | parser.add_argument('--device', default=0, help='Which device to run on; -1 for CPU', type=int)
14 | parser.add_argument('--data', default='.data', help='where to store data')
15 | parser.add_argument('--embeddings', default='.embeddings', help='where to store embeddings')
16 | args = parser.parse_args()
17 |
18 | inputs = data.Field(lower=True, include_lengths=True, batch_first=True)
19 |
20 | print('Generating train, dev, test splits')
21 | train, dev, test = datasets.IWSLT.splits(root=args.data, exts=['.en', '.de'], fields=[inputs, inputs])
22 | train_iter, dev_iter, test_iter = data.Iterator.splits(
23 | (train, dev, test), batch_size=100, device=torch.device(args.device) if args.device >= 0 else None)
24 |
25 | print('Building vocabulary')
26 | inputs.build_vocab(train, dev, test)
27 | inputs.vocab.load_vectors(vectors=GloVe(name='840B', dim=300, cache=args.embeddings))
28 |
29 | outputs_last_layer_cove = MTLSTM(n_vocab=len(inputs.vocab), vectors=inputs.vocab.vectors, model_cache=args.embeddings)
30 | outputs_both_layer_cove = MTLSTM(n_vocab=len(inputs.vocab), vectors=inputs.vocab.vectors, layer0=True, model_cache=args.embeddings)
31 | outputs_both_layer_cove_with_glove = MTLSTM(n_vocab=len(inputs.vocab), vectors=inputs.vocab.vectors, layer0=True, residual_embeddings=True, model_cache=args.embeddings)
32 |
33 | if args.device >= 0:
34 | outputs_last_layer_cove.cuda()
35 | outputs_both_layer_cove.cuda()
36 | outputs_both_layer_cove_with_glove.cuda()
37 |
38 | train_iter.init_epoch()
39 | print('Generating CoVe')
40 | for batch_idx, batch in enumerate(train_iter):
41 | if batch_idx > 0:
42 | break
43 | last_layer_cove = outputs_last_layer_cove(*batch.src)
44 | print(last_layer_cove.size())
45 | first_then_last_layer_cove = outputs_both_layer_cove(*batch.src)
46 | print(first_then_last_layer_cove.size())
47 | glove_then_first_then_last_layer_cove = outputs_both_layer_cove_with_glove(*batch.src)
48 | print(glove_then_first_then_last_layer_cove.size())
49 | assert np.allclose(last_layer_cove.cpu(), first_then_last_layer_cove[:, :, -600:].cpu())
50 | assert np.allclose(last_layer_cove.cpu(), glove_then_first_then_last_layer_cove[:, :, -600:].cpu())
51 | assert np.allclose(first_then_last_layer_cove[:, :, :600].cpu(), glove_then_first_then_last_layer_cove[:, :, 300:900].cpu())
52 | print(last_layer_cove[:, :, -10:])
53 | print(first_then_last_layer_cove[:, :, -10:])
54 | print(glove_then_first_then_last_layer_cove[:, :, -10:])
55 |
--------------------------------------------------------------------------------
/wmt_clean.py:
--------------------------------------------------------------------------------
1 | from collections import Counter
2 | import pycld2
3 | import unicodeblock.blocks
4 | from argparse import ArgumentParser
5 |
6 | parser = ArgumentParser()
7 | parser.add_argument('prefix', nargs='?', default='data/wmt17/de-en/')
8 | args = parser.parse_args()
9 |
10 | langs = ('de','en')
11 | lang_fix = '.' + '-'.join(langs)
12 | subsets = 'commoncrawl', 'europarl-v7', 'news-commentary-v12', 'rapid2016'
13 | for x in subsets:
14 | path_prefix = args.prefix + x + lang_fix
15 | paths_in = [path_prefix+'.'+lang for lang in langs]
16 | paths_out = [path_prefix+'.clean.'+lang for lang in langs]
17 | latin = lambda s: all("LATIN" in b or "PUNCT" in b or "DIGIT" in b or "SPAC" in b for b in map(unicodeblock.blocks.of,s) if b is not None)
18 | good_src = lambda s: pycld2.detect(s)[2][0][1] in [langs[0],'un'] and latin(s.decode()) and len(s)>1
19 | good_trg = lambda s: pycld2.detect(s)[2][0][1] in [langs[1],'un'] and latin(s.decode()) and len(s)>1
20 |
21 | with open(paths_in[0],'rb') as src, open(paths_in[1],'rb') as trg, open(paths_out[0],'wb') as src_out, open(paths_out[1],'wb') as trg_out:
22 | for srcline,trgline in zip(src,trg):
23 | try:
24 | if good_src(srcline) and good_trg(trgline):
25 | src_out.write(srcline)
26 | trg_out.write(trgline)
27 |             except Exception:  # language detection or decoding failed; retry after re-encoding
28 | try:
29 | srcline = srcline.decode("utf-8").encode("latin-1")
30 | trgline = trgline.decode("utf-8").encode("latin-1")
31 | try:
32 | if good_src(srcline) and good_trg(trgline):
33 | src_out.write(srcline)
34 | trg_out.write(trgline)
35 |                     except Exception:
36 | pass
37 |                 except Exception:  # still failing; skip this sentence pair
38 | pass
39 |
--------------------------------------------------------------------------------