├── .gitignore
├── CODEOWNERS
├── Dockerfile
├── LICENSE
├── OpenNMT-py
│   ├── .gitignore
│   ├── LICENSE.md
│   ├── README.md
│   ├── cache_embeddings.py
│   ├── corenlp_tokenize.py
│   ├── get_embed_for_dict.py
│   ├── iwslt_xml2txt.py
│   ├── multi30k_corenlp_tokenizer.sh
│   ├── onmt
│   │   ├── Beam.py
│   │   ├── Constants.py
│   │   ├── Dataset.py
│   │   ├── Dict.py
│   │   ├── Models.py
│   │   ├── Optim.py
│   │   ├── Translator.py
│   │   ├── __init__.py
│   │   └── modules
│   │       ├── GlobalAttention.py
│   │       └── __init__.py
│   ├── preprocess.py
│   ├── train.py
│   ├── translate.py
│   ├── wmt_clean.py
│   └── wmt_sgm2txt.py
├── README.md
├── cove
│   ├── __init__.py
│   └── encoder.py
├── get_data.sh
├── requirements.txt
├── setup.py
├── test
│   └── example.py
└── wmt_clean.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
*.pyc

--------------------------------------------------------------------------------
/CODEOWNERS:
--------------------------------------------------------------------------------
# Comment line immediately above ownership line is reserved for related gus information. Please be careful while editing.
#ECCN:Open Source

--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
# docker build --no-cache -t multitasking .
FROM nvidia/cuda:9.0-cudnn7-devel-ubuntu16.04

RUN apt-get update \
 && apt-get install -y --no-install-recommends \
    git \
    ssh \
    build-essential \
    locales \
    ca-certificates \
    curl \
    unzip

RUN curl -o ~/miniconda.sh https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
    chmod +x ~/miniconda.sh && \
    ~/miniconda.sh -b -p /opt/conda && \
    rm ~/miniconda.sh && \
    /opt/conda/bin/conda install conda-build python=3.6.3 numpy pyyaml mkl && \
    /opt/conda/bin/conda clean -ya
ENV PATH /opt/conda/bin:$PATH

# Default to utf-8 encodings in python
# Can verify in container with:
# python -c 'import locale; print(locale.getpreferredencoding(False))'
RUN locale-gen en_US.UTF-8
ENV LANG en_US.UTF-8
ENV LANGUAGE en_US:en
ENV LC_ALL en_US.UTF-8

RUN conda install -c pytorch pytorch cuda90

RUN pip install tqdm
RUN pip install requests
RUN pip install git+https://github.com/pytorch/text.git
ADD ./README.md /README.md
ADD ./cove/ /cove/
ADD ./setup.py /setup.py
RUN python setup.py develop

ADD ./test/ /test/

CMD python /test/example.py
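For reference, a typical build-and-run cycle for this image (the tag comes from the comment in the Dockerfile itself; the GPU runtime flag depends on your local Docker/NVIDIA setup, so treat this as a sketch):

```bash
docker build --no-cache -t multitasking .
docker run --runtime=nvidia multitasking   # runs the CMD smoke test, /test/example.py
```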
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
BSD 3-Clause License

Copyright (c) 2017, Salesforce.com, Inc.
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
  contributors may be used to endorse or promote products derived from
  this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

--------------------------------------------------------------------------------
/OpenNMT-py/.gitignore:
--------------------------------------------------------------------------------
pred.txt
multi-bleu.perl
*.pt
*.pyc

--------------------------------------------------------------------------------
/OpenNMT-py/LICENSE.md:
--------------------------------------------------------------------------------
This software is derived from the OpenNMT project at
https://github.com/OpenNMT/OpenNMT.

The MIT License (MIT)

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

--------------------------------------------------------------------------------
/OpenNMT-py/README.md:
--------------------------------------------------------------------------------
# OpenNMT: Open-Source Neural Machine Translation

This is a [Pytorch](https://github.com/pytorch/pytorch)
port of [OpenNMT](https://github.com/OpenNMT/OpenNMT),
an open-source (MIT) neural machine translation system.
# Requirements

## Some useful tools

The examples below use the Moses tokenizer (http://www.statmt.org/moses/) to prepare the data and the Moses BLEU script for evaluation:

```bash
wget https://raw.githubusercontent.com/moses-smt/mosesdecoder/master/scripts/tokenizer/tokenizer.perl
wget https://raw.githubusercontent.com/moses-smt/mosesdecoder/master/scripts/tokenizer/lowercase.perl
sed -i "s/$RealBin\/..\/share\/nonbreaking_prefixes//" tokenizer.perl
wget https://raw.githubusercontent.com/moses-smt/mosesdecoder/master/scripts/share/nonbreaking_prefixes/nonbreaking_prefix.de
wget https://raw.githubusercontent.com/moses-smt/mosesdecoder/master/scripts/share/nonbreaking_prefixes/nonbreaking_prefix.en
wget https://raw.githubusercontent.com/moses-smt/mosesdecoder/master/scripts/generic/multi-bleu.perl
```

## WMT'16 Multimodal Translation: Multi30k (de-en)

An example of training for the WMT'16 Multimodal Translation task (http://www.statmt.org/wmt16/multimodal-task.html).

### 0) Download the data.

```bash
mkdir -p data/multi30k
wget http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/training.tar.gz && tar -xf training.tar.gz -C data/multi30k && rm training.tar.gz
wget http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz && tar -xf validation.tar.gz -C data/multi30k && rm validation.tar.gz
wget https://staff.fnwi.uva.nl/d.elliott/wmt16/mmt16_task1_test.tgz && tar -xf mmt16_task1_test.tgz -C data/multi30k && rm mmt16_task1_test.tgz
for l in en de; do for f in data/multi30k/*.$l; do if [[ "$f" != *"test"* ]]; then sed -i "$ d" $f; fi; done; done
```

The last line of each train and validation file is blank, so the final loop above strips that trailing blank line (the test files are left untouched).

### 1) Preprocess the data.

Moses tokenization without HTML escaping (add the -a option after -no-escape for aggressive hyphen splitting):

```bash
for l in en de; do for f in data/multi30k/*.$l; do perl tokenizer.perl -no-escape -l $l -q < $f > $f.tok; done; done
```

Typically, we lowercase this dataset, as the important comparisons are in uncased BLEU:

```bash
for f in data/multi30k/*.tok; do perl lowercase.perl < $f > $f.low; done # if you ran Moses
```

If you would like to use the Moses tokenization for source and target, prepare the data for the model as follows:

```bash
python preprocess.py -train_src data/multi30k/train.en.tok.low -train_tgt data/multi30k/train.de.tok.low -valid_src data/multi30k/val.en.tok.low -valid_tgt data/multi30k/val.de.tok.low -save_data data/multi30k.tok.low -lower
```

The extra -lower option ensures that the vocabulary object converts all words to lowercase before lookup.

If you would like to use GloVe vectors and character embeddings, now's the time:

```bash
python get_embed_for_dict.py data/multi30k.tok.low.src.dict -glove -chargram -d_hid 400
python get_embed_for_dict.py data/multi30k.tok.low.src.dict -glove -d_hid 300
```
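These two commands assume the cached embedding files (`glove.840B.300d.pt` and `kazuma.100d.pt`) already exist in the working directory. A sketch of producing them with `cache_embeddings.py` (the raw `glove.840B.300d.txt` and `kazuma1.emb` files must be downloaded separately):

```bash
python cache_embeddings.py glove      # reads glove.840B.300d.txt, writes glove.840B.300d.pt
python cache_embeddings.py chargram   # any value other than "glove": reads kazuma1.emb, writes kazuma.100d.pt
```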
### 2) Train the model.

The -fix_embed flag keeps the pretrained encoder embeddings frozen (the encoder detaches them in onmt/Models.py).

```bash
python train.py -data data/multi30k.tok.low.train.pt -save_model snapshots/multi30k.tok.low.600h.400d.2dp.brnn.2l.fixed_glove_char.model -brnn -pre_word_vecs_enc data/multi30k.tok.low.src.dict.glove.chargram -fix_embed

python train.py -data data/multi30k.tok.low.train.pt -save_model snapshots/multi30k.tok.low.600h.300d.2dp.brnn.2l.fixed_glove.model -brnn -rnn_size 600 -word_vec_size 300 -pre_word_vecs_enc data/multi30k.tok.low.src.dict.glove -fix_embed
```
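Note that -pre_word_vecs_enc must line up with -word_vec_size (400 for GloVe plus char n-grams, 300 for GloVe alone). A quick sanity check, our addition rather than part of the original scripts:

```python
import torch

vocab_path = 'data/multi30k.tok.low.src.dict'       # written by preprocess.py
emb = torch.load(vocab_path + '.glove.chargram')    # written by get_embed_for_dict.py
vocab_size = sum(1 for _ in open(vocab_path))       # one "label idx" line per entry
assert emb.size() == (vocab_size, 400), emb.size()
```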
### 3) Translate sentences.

```bash
python translate.py -gpu 0 -model model_name -src data/multi30k/test.en.tok.low -tgt data/multi30k/test.de.tok.low -replace_unk -verbose -output multi30k.tok.low.test.pred
```

### 4) Evaluate.

```bash
perl multi-bleu.perl data/multi30k/test.de.tok.low < multi30k.tok.low.test.pred
```

## IWSLT'16 (de-en)

### 0) Download the data.

```bash
mkdir -p data/iwslt16
wget https://wit3.fbk.eu/archive/2016-01//texts/de/en/de-en.tgz && tar -xf de-en.tgz -C data
```

### 1) Preprocess the data.

```bash
python iwslt_xml2txt.py data/de-en
python iwslt_xml2txt.py data/de-en -a

python preprocess.py -train_src data/de-en/train.de-en.en.tok -train_tgt data/de-en/train.de-en.de.tok -valid_src data/de-en/IWSLT16.TED.tst2013.de-en.en.tok -valid_tgt data/de-en/IWSLT16.TED.tst2013.de-en.de.tok -save_data data/iwslt16.tok.low -lower -src_vocab_size 22822 -tgt_vocab_size 32009

# GloVe vectors + char n-grams
python get_embed_for_dict.py data/iwslt16.tok.low.src.dict -glove -chargram
python get_embed_for_dict.py data/iwslt16.tok.low.src.dict -glove
```

### 2) Train the model.

```bash
python train.py -data data/iwslt16.tok.low.train.pt -save_model snapshots/iwslt16.tok.low.600h.400d.2dp.brnn.2l.fixed_glove_char.model -gpus 0 -brnn -rnn_size 600 -fix_embed -pre_word_vecs_enc data/iwslt16.tok.low.src.dict.glove.chargram > iwslt16.clean.tok.low.600h.400d.2l.brnn.2dp.fixed_glove_char.log

python train.py -data data/iwslt16.tok.low.train.pt -save_model snapshots/iwslt16.tok.low.600h.300d.2dp.brnn.2l.fixed_glove.model -gpus 0 -brnn -rnn_size 600 -word_vec_size 300 -fix_embed -pre_word_vecs_enc data/iwslt16.tok.low.src.dict.glove > iwslt16.tok.low.600h.300d.2dp.brnn.2l.fixed_glove.log
```

### 3) Translate sentences.

```bash
python translate.py -gpu 0 -model model_name -src data/de-en/IWSLT16.TED.tst2014.de-en.en.tok -tgt data/de-en/IWSLT16.TED.tst2014.de-en.de.tok -replace_unk -verbose -output iwslt.ted.tst2014.de-en.tok.low.pred
```

### 4) Evaluate.

```bash
perl multi-bleu.perl data/de-en/IWSLT16.TED.tst2014.de-en.de.tok < iwslt.ted.tst2014.de-en.tok.low.pred
```

## WMT'17 (de-en)

### 0) Download the data.

```bash
mkdir -p data/wmt17
cd data/wmt17
wget http://www.statmt.org/wmt13/training-parallel-europarl-v7.tgz
wget http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz
wget http://data.statmt.org/wmt17/translation-task/training-parallel-nc-v12.tgz
wget http://data.statmt.org/wmt17/translation-task/rapid2016.tgz
wget http://data.statmt.org/wmt17/translation-task/dev.tgz
tar -xzf training-parallel-europarl-v7.tgz
tar -xzf training-parallel-commoncrawl.tgz
tar -xzf training-parallel-nc-v12.tgz
tar -xzf rapid2016.tgz
tar -xzf dev.tgz
mkdir de-en
mv *de-en* de-en
mv training/*de-en* de-en
mv dev/*deen* de-en
mv dev/*ende* de-en
mv dev/*.de de-en
mv dev/*.en de-en
mv dev/newstest2009*.en* de-en
mv dev/news-test2008*.en* de-en

python ../../wmt_clean.py de-en
for l in de; do for f in de-en/*.clean.$l; do perl ../../tokenizer.perl -no-escape -l $l -q < $f > $f.tok; done; done
for l in en; do for f in de-en/*.clean.$l; do perl ../../tokenizer.perl -no-escape -l $l -q < $f > $f.tok; done; done
for l in en de; do for f in de-en/*.clean.$l.tok; do perl ../../lowercase.perl < $f > $f.low; done; done
for l in en de; do perl ../../tokenizer.perl -no-escape -l $l -q < de-en/newstest2013.$l > de-en/newstest2013.$l.tok; done
for l in en de; do perl ../../lowercase.perl < de-en/newstest2013.$l.tok > de-en/newstest2013.$l.tok.low; done
for l in en de; do cat de-en/commoncrawl*.clean.$l.tok.low de-en/europarl*.clean.$l.tok.low de-en/news-commentary*.clean.$l.tok.low de-en/rapid*.clean.$l.tok.low > de-en/train.clean.$l.tok.low; done
```
### 1) Preprocess the data.

```bash
# News Commentary
python preprocess.py -train_src data/wmt17/de-en/news-commentary-v12.de-en.clean.en.tok.low -train_tgt data/wmt17/de-en/news-commentary-v12.de-en.clean.de.tok.low -valid_src data/wmt17/de-en/newstest2013.en.tok.low -valid_tgt data/wmt17/de-en/newstest2013.de.tok.low -save_data data/news-commentary.clean.tok.low -lower -seq_length 75
python get_embed_for_dict.py data/news-commentary.clean.tok.low.src.dict -glove -d_hid 300
python get_embed_for_dict.py data/news-commentary.clean.tok.low.src.dict -glove -chargram -d_hid 400

# Rapid
python preprocess.py -train_src data/wmt17/de-en/rapid*.clean.en.tok.low -train_tgt data/wmt17/de-en/rapid*.clean.de.tok.low -valid_src data/wmt17/de-en/newstest2013.en.tok.low -valid_tgt data/wmt17/de-en/newstest2013.de.tok.low -save_data data/rapid.clean.tok.low -lower -seq_length 75
python get_embed_for_dict.py data/rapid.clean.tok.low.src.dict -glove -d_hid 300
python get_embed_for_dict.py data/rapid.clean.tok.low.src.dict -glove -chargram -d_hid 400

# Europarl
python preprocess.py -train_src data/wmt17/de-en/europarl*.clean.en.tok.low -train_tgt data/wmt17/de-en/europarl*.clean.de.tok.low -valid_src data/wmt17/de-en/newstest2013.en.tok.low -valid_tgt data/wmt17/de-en/newstest2013.de.tok.low -save_data data/europarl.clean.tok.low -lower -seq_length 75
python get_embed_for_dict.py data/europarl.clean.tok.low.src.dict -glove -d_hid 300
python get_embed_for_dict.py data/europarl.clean.tok.low.src.dict -glove -chargram -d_hid 400

# Common Crawl
python preprocess.py -train_src data/wmt17/de-en/commoncrawl*.clean.en.tok.low -train_tgt data/wmt17/de-en/commoncrawl*.clean.de.tok.low -valid_src data/wmt17/de-en/newstest2013.en.tok.low -valid_tgt data/wmt17/de-en/newstest2013.de.tok.low -save_data data/commoncrawl.clean.tok.low -lower -seq_length 75
python get_embed_for_dict.py data/commoncrawl.clean.tok.low.src.dict -glove -d_hid 300
python get_embed_for_dict.py data/commoncrawl.clean.tok.low.src.dict -glove -chargram -d_hid 400

# WMT'17
python preprocess.py -train_src data/wmt17/de-en/train.clean.en.tok.low -train_tgt data/wmt17/de-en/train.clean.de.tok.low -valid_src data/wmt17/de-en/newstest2013.en.tok.low -valid_tgt data/wmt17/de-en/newstest2013.de.tok.low -save_data data/wmt17.clean.tok.low -lower -seq_length 75
python get_embed_for_dict.py data/wmt17.clean.tok.low.src.dict -glove -d_hid 300
python get_embed_for_dict.py data/wmt17.clean.tok.low.src.dict -glove -chargram -d_hid 400
```

### 2) Train the model.

```bash
# Train fixed glove+char models
for corpus in wmt17
do
python train.py -data data/${corpus}.clean.tok.low.train.pt -save_model snapshots/${corpus}.clean.tok.low.600h.400d.2l.brnn.2dp.fixed_glove_char.model -gpus 0 -brnn -word_vec_size 400 -pre_word_vecs_enc data/${corpus}.clean.tok.low.src.dict.glove.chargram -fix_embed > logs/${corpus}.clean.tok.low.600h.400d.2l.brnn.2dp.fixed_glove_char.log
done

# Train fixed glove models
for corpus in wmt17
do
python train.py -data data/${corpus}.clean.tok.low.train.pt -save_model snapshots/${corpus}.clean.tok.low.600h.300d.2l.brnn.2dp.fixed_glove.model -gpus 0 -brnn -word_vec_size 300 -pre_word_vecs_enc data/${corpus}.clean.tok.low.src.dict.glove -fix_embed > logs/${corpus}.clean.tok.low.600h.300d.2l.brnn.2dp.fixed_glove.log
done

# Train fixed glove+char models
for corpus in news-commentary rapid europarl commoncrawl
do
python train.py -data data/${corpus}.clean.tok.low.train.pt -save_model snapshots/${corpus}.clean.tok.low.600h.400d.2l.brnn.2dp.fixed_glove_char.model -gpus 0 -brnn -word_vec_size 400 -pre_word_vecs_enc data/${corpus}.clean.tok.low.src.dict.glove.chargram -fix_embed > logs/${corpus}.clean.tok.low.600h.400d.2l.brnn.2dp.fixed_glove_char.log
done

# Train fixed glove models
for corpus in news-commentary rapid europarl commoncrawl
do
python train.py -data data/${corpus}.clean.tok.low.train.pt -save_model snapshots/${corpus}.clean.tok.low.600h.300d.2l.brnn.2dp.fixed_glove.model -gpus 0 -brnn -word_vec_size 300 -pre_word_vecs_enc data/${corpus}.clean.tok.low.src.dict.glove -fix_embed > logs/${corpus}.clean.tok.low.600h.300d.2l.brnn.2dp.fixed_glove.log
done
```

### 3) Translate sentences.

### 4) Evaluate.
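Concrete commands are not given for these two steps; by analogy with the IWSLT'16 recipe above, something like the following should work, here scoring the newstest2013 dev set (model_name is whichever snapshot you trained):

```bash
python translate.py -gpu 0 -model model_name -src data/wmt17/de-en/newstest2013.en.tok.low -tgt data/wmt17/de-en/newstest2013.de.tok.low -replace_unk -verbose -output wmt17.tok.low.newstest2013.pred
perl multi-bleu.perl data/wmt17/de-en/newstest2013.de.tok.low < wmt17.tok.low.newstest2013.pred
```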

--------------------------------------------------------------------------------
/OpenNMT-py/cache_embeddings.py:
--------------------------------------------------------------------------------
import torch
from argparse import ArgumentParser


def cache_glove(glove_prefix):
    stoi = {}
    itos = []
    vectors = []
    fname = glove_prefix + '.txt'

    with open(fname, 'rb') as f:
        for l in f:
            l = l.strip().split(b' ')
            word, vector = l[0], l[1:]
            try:
                word = word.decode()
            except UnicodeDecodeError:
                print('non-UTF8 token', repr(word), 'ignored')
                continue
            stoi[word] = len(itos)
            itos.append(word)
            vectors.append([float(x) for x in vector])
    d = {'stoi': stoi, 'itos': itos, 'vectors': torch.FloatTensor(vectors)}
    torch.save(d, glove_prefix + '.pt')


def cache_chargrams():
    stoi = {}
    itos = []
    vectors = []
    fname = 'kazuma1.emb'

    with open(fname, 'rb') as f:
        for l in f:
            l = l.strip().split(b' ')
            word = l[0]
            vector = [float(n) for n in l[1:]]

            try:
                word = word.decode()
            except UnicodeDecodeError:
                print('non-UTF8 token', repr(word), 'ignored')
                continue

            stoi[word] = len(itos)
            itos.append(word)
            vectors.append(vector)

    d = {'stoi': stoi, 'itos': itos, 'vectors': torch.FloatTensor(vectors)}
    torch.save(d, 'kazuma.100d.pt')


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('embeddings')
    parser.add_argument('-glove_prefix', default='glove.840B.300d', type=str)
    args = parser.parse_args()

    if args.embeddings == 'glove':
        cache_glove(args.glove_prefix)
    else:
        cache_chargrams()
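The cached .pt files store a plain dict, so lookups elsewhere are straightforward; a minimal example of reading one back (assuming the cache has already been built):

```python
import torch

d = torch.load('glove.840B.300d.pt')
vec = d['vectors'][d['stoi']['the']]   # 300-dim FloatTensor for "the"
print(d['itos'][:5], vec.size())
```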

--------------------------------------------------------------------------------
/OpenNMT-py/corenlp_tokenize.py:
--------------------------------------------------------------------------------
import os, sys
import time
sys.path.append(os.getcwd())

from stanza.nlp.corenlp import CoreNLPClient
import argparse

parser = argparse.ArgumentParser(description='corenlp_tokenize.py')

parser.add_argument('-input-fn', required=True,
                    help="Path to the input english data")
parser.add_argument('-output-fn', required=True,
                    help="Path to the output english data")

opt = parser.parse_args()

corenlp = CoreNLPClient(default_annotators=['tokenize', 'ssplit'])


def annotate_sentence(corenlp, gloss):
    try:
        parse = corenlp.annotate(gloss)
    except Exception:
        # the CoreNLP server can drop requests under load; wait and retry once
        time.sleep(10)
        parse = corenlp.annotate(gloss)
    token_str = ' '.join([token['word']
                          for sentence in parse.json['sentence']
                          for token in sentence['token']])
    return token_str


with open(opt.input_fn) as f, open(opt.output_fn, 'w') as f2:
    for sent in f.readlines():
        f2.write(annotate_sentence(corenlp, sent))
        f2.write('\n')

--------------------------------------------------------------------------------
/OpenNMT-py/get_embed_for_dict.py:
--------------------------------------------------------------------------------
import torch
from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument('path')
parser.add_argument('-glove', action='store_true', dest='glove')
parser.add_argument('-small-glove', action='store_true', dest='small_glove')
parser.add_argument('-chargram', action='store_true', dest='chargram')
parser.add_argument('-d_hid', default=400, type=int)

args = parser.parse_args()


def ngrams(sentence, n):
    return [sentence[i:i+n] for i in range(len(sentence)-n+1)]


def charemb(w):
    chars = ['#BEGIN#'] + list(w) + ['#END#']
    match = {}
    for i in [2, 3, 4]:
        grams = ngrams(chars, i)
        for g in grams:
            g = '{}gram-{}'.format(i, ''.join(g))
            e = None
            if g in kazuma['stoi']:
                e = kazuma['vectors'][kazuma['stoi'][g]]
            if e is not None:
                match[g] = e
    if match:
        emb = sum(match.values()) / len(match)
    else:
        emb = torch.FloatTensor(100).uniform_(-0.1, 0.1)
    return emb


with open(args.path, 'rb') as f:
    vocab = [l.strip().split(b' ')[0] for l in f]

if args.glove:
    glove = torch.load('glove.840B.300d.pt')
if args.chargram:
    kazuma = torch.load('kazuma.100d.pt')

vectors = []
for word in vocab:
    vector = torch.FloatTensor(args.d_hid).uniform_(-0.1, 0.1)
    try:
        word = word.decode()
        glove_dim = args.d_hid - 100 if args.chargram else args.d_hid
        if args.glove and word in glove['stoi']:
            vector[:glove_dim] = glove['vectors'][glove['stoi'][word]]
        if args.chargram:
            vector[glove_dim:] = charemb(word)
    except UnicodeDecodeError:
        print('non-UTF-8 token', repr(word), 'ignored')
    vectors.append(vector)

ext = '.glove' if args.glove else ''
ext += '.chargram' if args.chargram else ''
torch.save(torch.stack(vectors), args.path + ext)
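To make the char-n-gram lookup in charemb() concrete, here is the n-gram expansion for a short word (standalone copy of ngrams() from above); charemb() averages the kazuma vectors of every 2/3/4-gram key it finds, falling back to uniform noise when none match:

```python
def ngrams(sentence, n):
    return [sentence[i:i+n] for i in range(len(sentence)-n+1)]

chars = ['#BEGIN#'] + list('cat') + ['#END#']
print(ngrams(chars, 2))
# [['#BEGIN#', 'c'], ['c', 'a'], ['a', 't'], ['t', '#END#']]
```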

--------------------------------------------------------------------------------
/OpenNMT-py/iwslt_xml2txt.py:
--------------------------------------------------------------------------------
import os
import glob
import subprocess
import xml.etree.ElementTree as ET
from argparse import ArgumentParser

parser = ArgumentParser(description='xml2txt')
parser.add_argument('path')
parser.add_argument('-t', '--tags', nargs='+', default=['seg'])
parser.add_argument('-th', '--threads', default=8, type=int)
parser.add_argument('-a', '--aggressive', action='store_true')
parser.add_argument('-corenlp', action='store_true')
args = parser.parse_args()


def tokenize(f_txt):
    lang = os.path.splitext(f_txt)[1][1:]
    f_tok = f_txt
    if args.aggressive:
        f_tok += '.atok'
    elif args.corenlp and lang == 'en':
        f_tok += '.corenlp'
    else:
        f_tok += '.tok'
    with open(f_tok, 'w') as fout, open(f_txt) as fin:
        if args.aggressive:
            pipe = subprocess.call(['perl', 'tokenizer.perl', '-a', '-q', '-threads', str(args.threads), '-no-escape', '-l', lang], stdin=fin, stdout=fout)
        elif args.corenlp and lang == 'en':
            pipe = subprocess.call(['python', 'corenlp_tokenize.py', '-input-fn', f_txt, '-output-fn', f_tok])
        else:
            pipe = subprocess.call(['perl', 'tokenizer.perl', '-q', '-threads', str(args.threads), '-no-escape', '-l', lang], stdin=fin, stdout=fout)

for f_xml in glob.iglob(os.path.join(args.path, '*.xml')):
    print(f_xml)
    f_txt = os.path.splitext(f_xml)[0]
    with open(f_txt, 'w') as fd_txt:
        root = ET.parse(f_xml).getroot()[0]
        for doc in root.findall('doc'):
            for tag in args.tags:
                for e in doc.findall(tag):
                    fd_txt.write(e.text.strip() + '\n')
    tokenize(f_txt)

xml_tags = ['<

--------------------------------------------------------------------------------
/OpenNMT-py/onmt/Beam.py:
--------------------------------------------------------------------------------
    def advance(self, wordLk, attnOut):
        numWords = wordLk.size(1)

        if len(self.prevKs) > 0:
            beamLk = wordLk + self.scores.unsqueeze(1).expand_as(wordLk)
        else:
            beamLk = wordLk[0]

        flatBeamLk = beamLk.view(-1)

        bestScores, bestScoresId = flatBeamLk.topk(self.size, 0, True, True)
        self.scores = bestScores

        # bestScoresId is flattened beam x word array, so calculate which
        # word and beam each score came from
        prevK = bestScoresId / numWords
        self.prevKs.append(prevK)
        self.nextYs.append(bestScoresId - prevK * numWords)
        self.attn.append(attnOut.index_select(0, prevK))

        # End condition is when top-of-beam is EOS.
        if self.nextYs[-1][0] == onmt.Constants.EOS:
            self.done = True

        return self.done

    def sortBest(self):
        return torch.sort(self.scores, 0, True)

    # Get the score of the best in the beam.
    def getBest(self):
        scores, ids = self.sortBest()
        return scores[1], ids[1]

    # Walk back to construct the full hypothesis.
    #
    # Parameters.
    #
    #     * `k` - the position in the beam to construct.
    #
    # Returns.
    #
    #     1. The hypothesis
    #     2. The attention at each time step.
    def getHyp(self, k):
        hyp, attn = [], []
        for j in range(len(self.prevKs) - 1, -1, -1):
            hyp.append(self.nextYs[j+1][k])
            attn.append(self.attn[j][k])
            k = self.prevKs[j][k]

        return hyp[::-1], torch.stack(attn[::-1])

--------------------------------------------------------------------------------
/OpenNMT-py/onmt/Constants.py:
--------------------------------------------------------------------------------

PAD = 0
UNK = 1
BOS = 2
EOS = 3

PAD_WORD = '<blank>'
UNK_WORD = '<unk>'
BOS_WORD = '<s>'
EOS_WORD = '</s>'

--------------------------------------------------------------------------------
/OpenNMT-py/onmt/Dataset.py:
--------------------------------------------------------------------------------
from __future__ import division

import math

import torch
from torch.autograd import Variable

import onmt


class Dataset(object):

    def __init__(self, srcData, tgtData, batchSize, cuda, volatile=False):
        self.src = srcData
        if tgtData:
            self.tgt = tgtData
            assert(len(self.src) == len(self.tgt))
        else:
            self.tgt = None
        self.cuda = cuda

        self.batchSize = batchSize
        self.numBatches = math.ceil(len(self.src)/batchSize)
        self.volatile = volatile

    def _batchify(self, data, align_right=False, include_lengths=False):
        lengths = [x.size(0) for x in data]
        max_length = max(lengths)
        out = data[0].new(len(data), max_length).fill_(onmt.Constants.PAD)
        for i in range(len(data)):
            data_length = data[i].size(0)
            offset = max_length - data_length if align_right else 0
            out[i].narrow(0, offset, data_length).copy_(data[i])

        if include_lengths:
            return out, lengths
        else:
            return out

    def __getitem__(self, index):
        assert index < self.numBatches, "%d > %d" % (index, self.numBatches)
        srcBatch, lengths = self._batchify(
            self.src[index*self.batchSize:(index+1)*self.batchSize],
            align_right=False, include_lengths=True)

        if self.tgt:
            tgtBatch = self._batchify(
                self.tgt[index*self.batchSize:(index+1)*self.batchSize])
        else:
            tgtBatch = None

        # within batch sorting by decreasing length for variable length rnns
        indices = range(len(srcBatch))
        batch = zip(indices, srcBatch) if tgtBatch is None else zip(indices, srcBatch, tgtBatch)
        batch, lengths = zip(*sorted(zip(batch, lengths), key=lambda x: -x[1]))
        if tgtBatch is None:
            indices, srcBatch = zip(*batch)
        else:
            indices, srcBatch, tgtBatch = zip(*batch)

        def wrap(b):
            if b is None:
                return b
            b = torch.stack(b, 0).t().contiguous()
            if self.cuda:
                b = b.cuda()
            b = Variable(b, volatile=self.volatile)
            return b

        return (wrap(srcBatch), lengths), wrap(tgtBatch), indices

    def __len__(self):
        return self.numBatches

    def shuffle(self):
        data = list(zip(self.src, self.tgt))
        self.src, self.tgt = zip(*[data[i] for i in torch.randperm(len(data))])
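Batches come out of Dataset.__getitem__ time-major because wrap() transposes with .t(); a sketch of consuming one, where srcData and tgtData are the lists of LongTensors produced by preprocess.py (shapes inferred from the code above):

```python
dataset = onmt.Dataset(srcData, tgtData, batchSize=64, cuda=False)
(src, lengths), tgt, indices = dataset[0]
# src: seq_len x batch LongTensor, PAD-filled, sorted by decreasing length
# lengths: per-sentence source lengths; indices: original positions within the batch
```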

--------------------------------------------------------------------------------
/OpenNMT-py/onmt/Dict.py:
--------------------------------------------------------------------------------
import torch


class Dict(object):
    def __init__(self, data=None, lower=False):
        self.idxToLabel = {}
        self.labelToIdx = {}
        self.frequencies = {}
        self.lower = lower

        # Special entries will not be pruned.
        self.special = []

        if data is not None:
            if type(data) == str:
                self.loadFile(data)
            else:
                self.addSpecials(data)

    def size(self):
        return len(self.idxToLabel)

    # Load entries from a file.
    def loadFile(self, filename):
        for line in open(filename):
            fields = line.split()
            label = fields[0]
            idx = int(fields[1])
            self.add(label, idx)

    # Write entries to a file.
    def writeFile(self, filename):
        with open(filename, 'w') as file:
            for i in range(self.size()):
                label = self.idxToLabel[i]
                file.write('%s %d\n' % (label, i))

    def lookup(self, key, default=None):
        key = key.lower() if self.lower else key
        try:
            return self.labelToIdx[key]
        except KeyError:
            return default

    def getLabel(self, idx, default=None):
        try:
            return self.idxToLabel[idx]
        except KeyError:
            return default

    # Mark this `label` and `idx` as special (i.e. will not be pruned).
    def addSpecial(self, label, idx=None):
        idx = self.add(label, idx)
        self.special += [idx]

    # Mark all labels in `labels` as specials (i.e. will not be pruned).
    def addSpecials(self, labels):
        for label in labels:
            self.addSpecial(label)

    # Add `label` in the dictionary. Use `idx` as its index if given.
    def add(self, label, idx=None):
        label = label.lower() if self.lower else label
        if idx is not None:
            self.idxToLabel[idx] = label
            self.labelToIdx[label] = idx
        else:
            if label in self.labelToIdx:
                idx = self.labelToIdx[label]
            else:
                idx = len(self.idxToLabel)
                self.idxToLabel[idx] = label
                self.labelToIdx[label] = idx

        if idx not in self.frequencies:
            self.frequencies[idx] = 1
        else:
            self.frequencies[idx] += 1

        return idx

    # Return a new dictionary with the `size` most frequent entries.
    def prune(self, size):
        if size >= self.size():
            return self

        # Only keep the `size` most frequent entries.
        freq = torch.Tensor(
            [self.frequencies[i] for i in range(len(self.frequencies))])
        _, idx = torch.sort(freq, 0, True)

        newDict = Dict()
        newDict.lower = self.lower

        # Add special entries in all cases.
        for i in self.special:
            newDict.addSpecial(self.idxToLabel[i])

        for i in idx[:size]:
            newDict.add(self.idxToLabel[i])

        return newDict

    # Convert `labels` to indices. Use `unkWord` if not found.
    # Optionally insert `bosWord` at the beginning and `eosWord` at the end.
    def convertToIdx(self, labels, unkWord, bosWord=None, eosWord=None):
        vec = []

        if bosWord is not None:
            vec += [self.lookup(bosWord)]

        unk = self.lookup(unkWord)
        vec += [self.lookup(label, default=unk) for label in labels]

        if eosWord is not None:
            vec += [self.lookup(eosWord)]

        return torch.LongTensor(vec)

    # Convert `idx` to labels. If index `stop` is reached, convert it and return.
    def convertToLabels(self, idx, stop):
        labels = []

        for i in idx:
            labels += [self.getLabel(i)]
            if i == stop:
                break

        return labels
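A small round trip through Dict, using the special tokens from onmt/Constants.py (our example, not part of the repo):

```python
import onmt

d = onmt.Dict([onmt.Constants.PAD_WORD, onmt.Constants.UNK_WORD,
               onmt.Constants.BOS_WORD, onmt.Constants.EOS_WORD], lower=True)
for w in 'a man sleeps'.split():
    d.add(w)
ids = d.convertToIdx(['a', 'man', 'naps'], onmt.Constants.UNK_WORD)
print(ids)                                     # 'naps' is OOV, so it maps to UNK (index 1)
print(d.convertToLabels(ids, onmt.Constants.EOS))
```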

--------------------------------------------------------------------------------
/OpenNMT-py/onmt/Models.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torch.autograd import Variable
import onmt.modules
from torch.nn.utils.rnn import pad_packed_sequence as unpack
from torch.nn.utils.rnn import pack_padded_sequence as pack


class Encoder(nn.Module):

    def __init__(self, opt, dicts):
        super(Encoder, self).__init__()
        self.detach_embed = opt.detach_embed if hasattr(opt, 'detach_embed') else 0
        self.fix_embed = opt.fix_embed
        self.count = 0
        self.layers = opt.layers
        self.dropout = nn.Dropout(opt.dropout)
        self.num_directions = 2 if opt.brnn else 1
        assert opt.rnn_size % self.num_directions == 0
        self.hidden_size = opt.rnn_size // self.num_directions
        input_size = opt.word_vec_size

        self.word_lut = nn.Embedding(dicts.size(),
                                     opt.word_vec_size,
                                     padding_idx=onmt.Constants.PAD)
        self.rnn = nn.LSTM(input_size, self.hidden_size,
                           num_layers=opt.layers,
                           dropout=opt.dropout,
                           bidirectional=opt.brnn)

    def load_pretrained_vectors(self, opt):
        if opt.pre_word_vecs_enc is not None:
            pretrained = torch.load(opt.pre_word_vecs_enc)
            self.word_lut.weight.data.copy_(pretrained)

    def forward(self, input, hidden=None):
        if isinstance(input, tuple):
            emb = self.word_lut(input[0])
        else:
            emb = self.word_lut(input)
        if self.fix_embed or self.count < self.detach_embed:
            emb = emb.detach()
        emb = self.dropout(emb)
        if isinstance(input, tuple):
            emb = pack(emb, input[1])
        outputs, hidden_t = self.rnn(emb, hidden)
        if isinstance(input, tuple):
            outputs = self.dropout(unpack(outputs)[0])
        self.count += 1
        return hidden_t, outputs


class StackedLSTM(nn.Module):
    def __init__(self, num_layers, input_size, rnn_size, dropout):
        super(StackedLSTM, self).__init__()
        self.dropout = nn.Dropout(dropout)
        self.num_layers = num_layers
        self.layers = nn.ModuleList()

        for i in range(num_layers):
            self.layers.append(nn.LSTMCell(input_size, rnn_size))
            input_size = rnn_size

    def forward(self, input, hidden):
        h_0, c_0 = hidden
        h_1, c_1 = [], []
        for i, layer in enumerate(self.layers):
            input = self.dropout(input)
            h_1_i, c_1_i = layer(input, (h_0[i], c_0[i]))
            input = h_1_i
            h_1 += [h_1_i]
            c_1 += [c_1_i]

        h_1 = torch.stack(h_1)
        c_1 = torch.stack(c_1)

        return input, (h_1, c_1)


class Decoder(nn.Module):

    def __init__(self, opt, dicts):
        super(Decoder, self).__init__()
        self.layers = opt.layers
        self.input_feed = opt.input_feed
        input_size = opt.word_vec_size
        if self.input_feed:
            input_size += opt.rnn_size

        self.word_lut = nn.Embedding(dicts.size(),
                                     opt.word_vec_size,
                                     padding_idx=onmt.Constants.PAD)
        self.rnn = StackedLSTM(opt.layers, input_size, opt.rnn_size, opt.dropout)
        self.attn = onmt.modules.GlobalAttention(opt.rnn_size, opt.dot)
        self.dropout = nn.Dropout(opt.dropout)

        self.hidden_size = opt.rnn_size

    def load_pretrained_vectors(self, opt):
        if opt.pre_word_vecs_dec is not None:
            pretrained = torch.load(opt.pre_word_vecs_dec)
            self.word_lut.weight.data.copy_(pretrained)

    def forward(self, input, hidden, context, init_output):
        emb = self.word_lut(input)

        # n.b. you can increase performance if you compute W_ih * x for all
        # iterations in parallel, but that's only possible if
        # self.input_feed=False
        outputs = []
        output = init_output
        for emb_t in emb.split(1):
            emb_t = self.dropout(emb_t)
            output = self.dropout(output)
            emb_t = emb_t.squeeze(0)
            if self.input_feed:
                emb_t = torch.cat([emb_t, output], 1)

            output, hidden = self.rnn(emb_t, hidden)
            output, attn = self.attn(output, context.t())
            output = self.dropout(output)
            outputs += [output]

        outputs = torch.stack(outputs)
        return outputs, hidden, attn


class NMTModel(nn.Module):

    def __init__(self, encoder, decoder):
        super(NMTModel, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def make_init_decoder_output(self, context):
        batch_size = context.size(1)
        h_size = (batch_size, self.decoder.hidden_size)
        return Variable(context.data.new(*h_size).zero_(), requires_grad=False)

    def _fix_enc_hidden(self, h):
        # the encoder hidden is (layers*directions) x batch x dim
        # we need to convert it to layers x batch x (directions*dim)
        if self.encoder.num_directions == 2:
            return h.view(h.size(0) // 2, 2, h.size(1), h.size(2)) \
                    .transpose(1, 2).contiguous() \
                    .view(h.size(0) // 2, h.size(1), h.size(2) * 2)
        else:
            return h

    def forward(self, input):
        src = input[0]
        tgt = input[1][:-1]  # exclude last target from inputs
        enc_hidden, context = self.encoder(src)
        init_output = self.make_init_decoder_output(context)

        enc_hidden = (self._fix_enc_hidden(enc_hidden[0]),
                      self._fix_enc_hidden(enc_hidden[1]))

        out, dec_hidden, _attn = self.decoder(tgt, enc_hidden, context, init_output)

        return out

--------------------------------------------------------------------------------
/OpenNMT-py/onmt/Optim.py:
--------------------------------------------------------------------------------
import torch.optim as optim
from torch.nn.utils import clip_grad_norm


class Optim(object):

    def set_parameters(self, params):
        self.params = list(params)  # careful: params may be a generator
        if self.method == 'sgd':
            self.optimizer = optim.SGD(self.params, lr=self.lr)
        elif self.method == 'adagrad':
            self.optimizer = optim.Adagrad(self.params, lr=self.lr)
        elif self.method == 'adadelta':
            self.optimizer = optim.Adadelta(self.params, lr=self.lr)
        elif self.method == 'adam':
            self.optimizer = optim.Adam(self.params, lr=self.lr)
        else:
            raise RuntimeError("Invalid optim method: " + self.method)

    def __init__(self, method, lr, max_grad_norm, lr_decay=1, start_decay_at=None):
        self.last_ppl = None
        self.lr = lr
        self.max_grad_norm = max_grad_norm
        self.method = method
        self.lr_decay = lr_decay
        self.start_decay_at = start_decay_at

    def step(self):
        # Clip the gradient norm, then update the parameters.
        if self.max_grad_norm:
            clip_grad_norm(self.params, self.max_grad_norm)
        self.optimizer.step()

    # decay learning rate if val perf does not improve or we hit the start_decay_at limit
    def updateLearningRate(self, ppl, epoch):
        start_decay = False
        if self.start_decay_at is not None and epoch >= self.start_decay_at:
            start_decay = True
        if self.last_ppl is not None and ppl > self.last_ppl:
            start_decay = True

        if start_decay:
            self.lr = self.lr * self.lr_decay
            print("Decaying learning rate to %g" % self.lr)

        self.last_ppl = ppl
        self.optimizer.param_groups[0]['lr'] = self.lr
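Optim wraps a torch optimizer with gradient clipping and perplexity-based learning-rate decay; typical wiring, shown here as an assumed sketch (model, loss, val_ppl, and epoch are placeholders for the caller's objects):

```python
optim = onmt.Optim('sgd', lr=1.0, max_grad_norm=5, lr_decay=0.5, start_decay_at=8)
optim.set_parameters(model.parameters())

loss.backward()
optim.step()                               # clips, then updates
optim.updateLearningRate(val_ppl, epoch)   # decays if ppl stalls or epoch >= start_decay_at
```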

--------------------------------------------------------------------------------
/OpenNMT-py/onmt/Translator.py:
--------------------------------------------------------------------------------
import onmt
import torch.nn as nn
import torch
from torch.autograd import Variable


class Translator(object):
    def __init__(self, opt):
        self.opt = opt
        self.tt = torch.cuda if opt.cuda else torch

        checkpoint = torch.load(opt.model)

        model_opt = checkpoint['opt']
        self.src_dict = checkpoint['dicts']['src']
        self.tgt_dict = checkpoint['dicts']['tgt']

        encoder = onmt.Models.Encoder(model_opt, self.src_dict)
        decoder = onmt.Models.Decoder(model_opt, self.tgt_dict)
        model = onmt.Models.NMTModel(encoder, decoder)

        generator = nn.Sequential(
            nn.Linear(model_opt.rnn_size, self.tgt_dict.size()),
            nn.LogSoftmax())

        model.load_state_dict(checkpoint['model'])
        generator.load_state_dict(checkpoint['generator'])

        if opt.cuda:
            model.cuda()
            generator.cuda()
        else:
            model.cpu()
            generator.cpu()

        model.generator = generator

        self.model = model
        self.model.eval()

    def buildData(self, srcBatch, goldBatch):
        srcData = [self.src_dict.convertToIdx(b,
                   onmt.Constants.UNK_WORD) for b in srcBatch]
        tgtData = None
        if goldBatch:
            tgtData = [self.tgt_dict.convertToIdx(b,
                       onmt.Constants.UNK_WORD,
                       onmt.Constants.BOS_WORD,
                       onmt.Constants.EOS_WORD) for b in goldBatch]

        return onmt.Dataset(srcData, tgtData,
                            self.opt.batch_size, self.opt.cuda, volatile=True)

    def buildTargetTokens(self, pred, src, attn):
        tokens = self.tgt_dict.convertToLabels(pred, onmt.Constants.EOS)
        tokens = tokens[:-1]  # EOS
        if self.opt.replace_unk:
            for i in range(len(tokens)):
                if tokens[i] == onmt.Constants.UNK_WORD:
                    _, maxIndex = attn[i].max(0)
                    tokens[i] = src[maxIndex[0]]
        return tokens

    def translateBatch(self, srcBatch, tgtBatch):
        batchSize = srcBatch[0].size(1)
        beamSize = self.opt.beam_size

        # (1) run the encoder on the src
        encStates, context = self.model.encoder(srcBatch)
        srcBatch = srcBatch[0]  # drop the lengths needed for encoder

        rnnSize = context.size(2)
        encStates = (self.model._fix_enc_hidden(encStates[0]),
                     self.model._fix_enc_hidden(encStates[1]))

        # This mask is applied to the attention model inside the decoder
        # so that the attention ignores source padding
        padMask = srcBatch.data.eq(onmt.Constants.PAD).t()

        def applyContextMask(m):
            if isinstance(m, onmt.modules.GlobalAttention):
                m.applyMask(padMask)

        # (2) if a target is specified, compute the 'goldScore'
        # (i.e. log likelihood) of the target under the model
        goldScores = context.data.new(batchSize).zero_()
        if tgtBatch is not None:
            decStates = encStates
            self.model.decoder.apply(applyContextMask)
            initOutput = self.model.make_init_decoder_output(context)

            decOut, decStates, attn = self.model.decoder(
                tgtBatch[:-1], decStates, context, initOutput)
            for dec_t, tgt_t in zip(decOut, tgtBatch[1:].data):
                gen_t = self.model.generator.forward(dec_t)
                tgt_t = tgt_t.unsqueeze(1)
                scores = gen_t.data.gather(1, tgt_t)
                scores.masked_fill_(tgt_t.eq(onmt.Constants.PAD), 0)
                goldScores += scores

        # (3) run the decoder to generate sentences, using beam search

        # Expand tensors for each beam.
        context = Variable(context.data.repeat(1, beamSize, 1))
        decStates = (Variable(encStates[0].data.repeat(1, beamSize, 1)),
                     Variable(encStates[1].data.repeat(1, beamSize, 1)))

        beam = [onmt.Beam(beamSize, self.opt.cuda) for k in range(batchSize)]

        decOut = self.model.make_init_decoder_output(context)

        padMask = srcBatch.data.eq(onmt.Constants.PAD).t().unsqueeze(0).repeat(beamSize, 1, 1)

        batchIdx = list(range(batchSize))
        remainingSents = batchSize
        for i in range(self.opt.max_sent_length):

            self.model.decoder.apply(applyContextMask)

            # Prepare decoder input.
            input = torch.stack([b.getCurrentState() for b in beam
                                 if not b.done]).t().contiguous().view(1, -1)

            decOut, decStates, attn = self.model.decoder(
                Variable(input, volatile=True), decStates, context, decOut)
            # decOut: 1 x (beam*batch) x numWords
            decOut = decOut.squeeze(0)
            out = self.model.generator.forward(decOut)

            # batch x beam x numWords
            wordLk = out.view(beamSize, remainingSents, -1).transpose(0, 1).contiguous()
            attn = attn.view(beamSize, remainingSents, -1).transpose(0, 1).contiguous()

            active = []
            for b in range(batchSize):
                if beam[b].done:
                    continue

                idx = batchIdx[b]
                if not beam[b].advance(wordLk.data[idx], attn.data[idx]):
                    active += [b]

                for decState in decStates:  # iterate over h, c
                    # layers x beam*sent x dim
                    sentStates = decState.view(
                        -1, beamSize, remainingSents, decState.size(2))[:, :, idx]
                    sentStates.data.copy_(
                        sentStates.data.index_select(1, beam[b].getCurrentOrigin()))

            if not active:
                break

            # in this section, the sentences that are still active are
            # compacted so that the decoder is not run on completed sentences
            activeIdx = self.tt.LongTensor([batchIdx[k] for k in active])
            batchIdx = {beam: idx for idx, beam in enumerate(active)}

            def updateActive(t):
                # select only the remaining active sentences
                view = t.data.view(-1, remainingSents, rnnSize)
                newSize = list(t.size())
                newSize[-2] = newSize[-2] * len(activeIdx) // remainingSents
                return Variable(view.index_select(1, activeIdx) \
                                    .view(*newSize), volatile=True)

            decStates = (updateActive(decStates[0]), updateActive(decStates[1]))
            decOut = updateActive(decOut)
            context = updateActive(context)
            padMask = padMask.index_select(1, activeIdx)

            remainingSents = len(active)

        # (4) package everything up

        allHyp, allScores, allAttn = [], [], []
        n_best = self.opt.n_best

        for b in range(batchSize):
            scores, ks = beam[b].sortBest()

            allScores += [scores[:n_best]]
            valid_attn = srcBatch.data[:, b].ne(onmt.Constants.PAD).nonzero().squeeze(1)
            hyps, attn = zip(*[beam[b].getHyp(k) for k in ks[:n_best]])
            attn = [a.index_select(1, valid_attn) for a in attn]
            allHyp += [hyps]
            allAttn += [attn]

        return allHyp, allScores, allAttn, goldScores

    def translate(self, srcBatch, goldBatch):
        # (1) convert words to indexes
        dataset = self.buildData(srcBatch, goldBatch)
        src, tgt, indices = dataset[0]

        # (2) translate
        pred, predScore, attn, goldScore = self.translateBatch(src, tgt)
        pred, predScore, attn, goldScore = list(zip(*sorted(zip(pred, predScore, attn, goldScore, indices), key=lambda x: x[-1])))[:-1]

        # (3) convert indexes to words
        predBatch = []
        for b in range(src[0].size(1)):
            predBatch.append(
                [self.buildTargetTokens(pred[b][n], srcBatch[b], attn[b][n])
                 for n in range(self.opt.n_best)]
            )

        return predBatch, predScore, goldScore

--------------------------------------------------------------------------------
/OpenNMT-py/onmt/__init__.py:
--------------------------------------------------------------------------------
import onmt.Constants
import onmt.Models
from onmt.Translator import Translator
from onmt.Dataset import Dataset
from onmt.Optim import Optim
from onmt.Dict import Dict
from onmt.Beam import Beam

--------------------------------------------------------------------------------
/OpenNMT-py/onmt/modules/GlobalAttention.py:
--------------------------------------------------------------------------------
"""
Global attention takes a matrix and a query vector. It
then computes a parameterized convex combination of the matrix
based on the input query.


        H_1 H_2 H_3 ... H_n
         q  q   q     q
          |  |   |   |
           \ |   |  /
              .....
            \  |  /
               a

Constructs a unit mapping.
    $$(H_1 + H_n, q) => (a)$$
    Where H is of `batch x n x dim` and q is of `batch x dim`.

    The full def is $$\tanh(W_2 [(softmax((W_1 q + b_1) H) H), q] + b_2)$$.

"""

import torch
import torch.nn as nn


class GlobalAttention(nn.Module):
    def __init__(self, dim, dot=False):
        super(GlobalAttention, self).__init__()
        self.linear_in = nn.Linear(dim, dim, bias=False)
        self.sm = nn.Softmax()
        self.linear_out = nn.Linear(dim*2, dim, bias=False)
        self.tanh = nn.Tanh()
        self.mask = None
        self.dot = dot

    def applyMask(self, mask):
        self.mask = mask

    def forward(self, input, context):
        """
        input: batch x dim
        context: batch x sourceL x dim
        """
        if not self.dot:
            targetT = self.linear_in(input).unsqueeze(2)  # batch x dim x 1
        else:
            targetT = input.unsqueeze(2)

        # Get attention
        attn = torch.bmm(context, targetT).squeeze(2)  # batch x sourceL
        if self.mask is not None:
            attn.data.masked_fill_(self.mask, -float('inf'))
        attn = self.sm(attn)
        attn3 = attn.view(attn.size(0), 1, attn.size(1))  # batch x 1 x sourceL

        weightedContext = torch.bmm(attn3, context).squeeze(1)  # batch x dim
        contextCombined = torch.cat((weightedContext, input), 1)

        contextOutput = self.tanh(self.linear_out(contextCombined))

        return contextOutput, attn
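The shape contract of the module above, as a quick example (old-style Variable API, matching the PyTorch version this code targets; the sizes are assumptions):

```python
import torch
from torch.autograd import Variable
from onmt.modules import GlobalAttention

attn = GlobalAttention(dim=600)
q = Variable(torch.randn(32, 600))         # batch x dim query
ctx = Variable(torch.randn(32, 20, 600))   # batch x sourceL x dim
out, scores = attn(q, ctx)
print(out.size(), scores.size())           # (32, 600) (32, 20)
```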

--------------------------------------------------------------------------------
/OpenNMT-py/onmt/modules/__init__.py:
--------------------------------------------------------------------------------
from onmt.modules.GlobalAttention import GlobalAttention

--------------------------------------------------------------------------------
/OpenNMT-py/preprocess.py:
--------------------------------------------------------------------------------
import onmt

import argparse
import torch

parser = argparse.ArgumentParser(description='preprocess.py')

##
## **Preprocess Options**
##

parser.add_argument('-config', help="Read options from this file")

parser.add_argument('-train_src', required=True,
                    help="Path to the training source data")
parser.add_argument('-train_tgt', required=True,
                    help="Path to the training target data")
parser.add_argument('-valid_src', required=True,
                    help="Path to the validation source data")
parser.add_argument('-valid_tgt', required=True,
                    help="Path to the validation target data")

parser.add_argument('-save_data', required=True,
                    help="Output file for the prepared data")

parser.add_argument('-src_vocab_size', type=int, default=50000,
                    help="Size of the source vocabulary")
parser.add_argument('-tgt_vocab_size', type=int, default=50000,
                    help="Size of the target vocabulary")
parser.add_argument('-src_vocab',
                    help="Path to an existing source vocabulary")
parser.add_argument('-tgt_vocab',
                    help="Path to an existing target vocabulary")

parser.add_argument('-seq_length', type=int, default=50,
                    help="Maximum sequence length")
parser.add_argument('-shuffle', type=int, default=1,
                    help="Shuffle data")
parser.add_argument('-seed', type=int, default=3435,
                    help="Random seed")

parser.add_argument('-lower', action='store_true', help='lowercase data')

parser.add_argument('-report_every', type=int, default=100000,
                    help="Report status every this many sentences")

opt = parser.parse_args()

torch.manual_seed(opt.seed)


def makeVocabulary(filename, size):
    vocab = onmt.Dict([onmt.Constants.PAD_WORD, onmt.Constants.UNK_WORD,
                       onmt.Constants.BOS_WORD, onmt.Constants.EOS_WORD], lower=opt.lower)

    with open(filename) as f:
        for sent in f.readlines():
            for word in sent.split():
                vocab.add(word)

    originalSize = vocab.size()
    vocab = vocab.prune(size)
    print('Created dictionary of size %d (pruned from %d)' %
          (vocab.size(), originalSize))

    return vocab


def initVocabulary(name, dataFile, vocabFile, vocabSize):
    vocab = None
    if vocabFile is not None:
        # If given, load existing word dictionary.
        print('Reading ' + name + ' vocabulary from \'' + vocabFile + '\'...')
        vocab = onmt.Dict()
        vocab.loadFile(vocabFile)
        print('Loaded %d %s words' % (vocab.size(), name))

    if vocab is None:
        # If a dictionary is still missing, generate it.
        print('Building ' + name + ' vocabulary...')
        genWordVocab = makeVocabulary(dataFile, vocabSize)
        vocab = genWordVocab

    print()
    return vocab


def saveVocabulary(name, vocab, file):
    print('Saving ' + name + ' vocabulary to \'' + file + '\'...')
    vocab.writeFile(file)


def makeData(srcFile, tgtFile, srcDicts, tgtDicts):
    src, tgt = [], []
    sizes = []
    count, ignored = 0, 0

    print('Processing %s & %s ...' % (srcFile, tgtFile))
    srcF = open(srcFile)
    tgtF = open(tgtFile)

    while True:
        srcWords = srcF.readline().split()
        tgtWords = tgtF.readline().split()

        if not srcWords or not tgtWords:
            if srcWords and not tgtWords or not srcWords and tgtWords:
                print('WARNING: source and target do not have the same number of sentences')
            break

        if len(srcWords) <= opt.seq_length and len(tgtWords) <= opt.seq_length:
            src += [srcDicts.convertToIdx(srcWords,
                                          onmt.Constants.UNK_WORD)]
            tgt += [tgtDicts.convertToIdx(tgtWords,
                                          onmt.Constants.UNK_WORD,
                                          onmt.Constants.BOS_WORD,
                                          onmt.Constants.EOS_WORD)]

            sizes += [len(srcWords)]
        else:
            ignored += 1

        count += 1

        if count % opt.report_every == 0:
            print('... %d sentences prepared' % count)

    srcF.close()
    tgtF.close()

    if opt.shuffle == 1:
        print('... shuffling sentences')
        perm = torch.randperm(len(src))
        src = [src[idx] for idx in perm]
        tgt = [tgt[idx] for idx in perm]
        sizes = [sizes[idx] for idx in perm]

    print('... sorting sentences by size')
sorting sentences by size') 142 | _, perm = torch.sort(torch.Tensor(sizes)) 143 | src = [src[idx] for idx in perm] 144 | tgt = [tgt[idx] for idx in perm] 145 | 146 | print('Prepared %d sentences (%d ignored due to length == 0 or > %d)' % 147 | (len(src), ignored, opt.seq_length)) 148 | 149 | return src, tgt 150 | 151 | 152 | def main(): 153 | 154 | dicts = {} 155 | dicts['src'] = initVocabulary('source', opt.train_src, opt.src_vocab, 156 | opt.src_vocab_size) 157 | dicts['tgt'] = initVocabulary('target', opt.train_tgt, opt.tgt_vocab, 158 | opt.tgt_vocab_size) 159 | 160 | print('Preparing training ...') 161 | train = {} 162 | train['src'], train['tgt'] = makeData(opt.train_src, opt.train_tgt, 163 | dicts['src'], dicts['tgt']) 164 | 165 | print('Preparing validation ...') 166 | valid = {} 167 | valid['src'], valid['tgt'] = makeData(opt.valid_src, opt.valid_tgt, 168 | dicts['src'], dicts['tgt']) 169 | 170 | if opt.src_vocab is None: 171 | saveVocabulary('source', dicts['src'], opt.save_data + '.src.dict') 172 | if opt.tgt_vocab is None: 173 | saveVocabulary('target', dicts['tgt'], opt.save_data + '.tgt.dict') 174 | 175 | 176 | print('Saving data to \'' + opt.save_data + '.train.pt\'...') 177 | save_data = {'dicts': dicts, 178 | 'train': train, 179 | 'valid': valid} 180 | torch.save(save_data, opt.save_data + '.train.pt') 181 | 182 | 183 | if __name__ == "__main__": 184 | main() 185 | -------------------------------------------------------------------------------- /OpenNMT-py/train.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import onmt 4 | import argparse 5 | import torch 6 | import torch.nn as nn 7 | from torch import cuda 8 | from torch.autograd import Variable 9 | import math 10 | import time 11 | 12 | parser = argparse.ArgumentParser(description='train.py') 13 | 14 | ## Data options 15 | 16 | parser.add_argument('-data', required=True, 17 | help='Path to the *-train.pt file from preprocess.py') 18 | parser.add_argument('-save_model', default='model', 19 | help="""Model filename (the model will be saved as 20 | _epochN_PPL.pt where PPL is the 21 | validation perplexity""") 22 | parser.add_argument('-train_from_state_dict', default='', type=str, 23 | help="""If training from a checkpoint then this is the 24 | path to the pretrained model's state_dict.""") 25 | parser.add_argument('-train_from', default='', type=str, 26 | help="""If training from a checkpoint then this is the 27 | path to the pretrained model.""") 28 | 29 | ## Model options 30 | 31 | parser.add_argument('-layers', type=int, default=2, 32 | help='Number of layers in the LSTM encoder/decoder') 33 | parser.add_argument('-rnn_size', type=int, default=600, 34 | help='Size of LSTM hidden states') 35 | parser.add_argument('-word_vec_size', type=int, default=400, 36 | help='Word embedding sizes') 37 | parser.add_argument('-input_feed', type=int, default=1, 38 | help="""Feed the context vector at each time step as 39 | additional input (via concatenation with the word 40 | embeddings) to the decoder.""") 41 | # parser.add_argument('-residual', action="store_true", 42 | # help="Add residual connections between RNN layers.") 43 | parser.add_argument('-brnn', action='store_true', 44 | help='Use a bidirectional encoder') 45 | parser.add_argument('-dot', action='store_true', 46 | help='Use dot attention') 47 | parser.add_argument('-brnn_merge', default='concat', 48 | help="""Merge action for the bidirectional hidden states: 49 | [concat|sum]""") 50 | 51 | 
## Optimization options
52 | 
53 | parser.add_argument('-batch_size', type=int, default=64,
54 |                     help='Maximum batch size')
55 | parser.add_argument('-max_generator_batches', type=int, default=100,
56 |                     help="""Maximum batches of words in a sequence to run
57 |                     the generator on in parallel. Higher is faster, but uses
58 |                     more memory.""")
59 | parser.add_argument('-epochs', type=int, default=1000,
60 |                     help='Number of training epochs')
61 | parser.add_argument('-start_epoch', type=int, default=1,
62 |                     help='The epoch from which to start')
63 | parser.add_argument('-param_init', type=float, default=0.1,
64 |                     help="""Parameters are initialized over uniform distribution
65 |                     with support (-param_init, param_init)""")
66 | parser.add_argument('-optim', default='sgd',
67 |                     help="Optimization method. [sgd|adagrad|adadelta|adam]")
68 | parser.add_argument('-max_grad_norm', type=float, default=5,
69 |                     help="""If the norm of the gradient vector exceeds this,
70 |                     renormalize it to have the norm equal to max_grad_norm""")
71 | parser.add_argument('-dropout', type=float, default=0.2,
72 |                     help='Dropout probability; applied between LSTM stacks.')
73 | parser.add_argument('-curriculum', action="store_true",
74 |                     help="""If set, order the minibatches by source sequence
75 |                     length for the first epoch. This can sometimes
76 |                     increase convergence speed.""")
77 | parser.add_argument('-extra_shuffle', action="store_true",
78 |                     help="""By default only shuffle mini-batch order; when true,
79 |                     shuffle and re-assign mini-batches""")
80 | 
81 | # Learning rate
82 | parser.add_argument('-learning_rate', type=float, default=1.0,
83 |                     help="""Starting learning rate. If adagrad/adadelta/adam is
84 |                     used, then this is the global learning rate. Recommended
85 |                     settings: sgd = 1, adagrad = 0.1, adadelta = 1, adam = 0.001""")
86 | parser.add_argument('-lr_cutoff', type=float, default=0.03,
87 |                     help='Stop training once the learning rate decays below this value.')
88 | parser.add_argument('-learning_rate_decay', type=float, default=0.5,
89 |                     help="""If update_learning_rate, decay learning rate by
90 |                     this much if (i) perplexity does not decrease on the
91 |                     validation set or (ii) epoch has gone past
92 |                     start_decay_at""")
93 | parser.add_argument('-start_decay_at', type=int, default=1000,
94 |                     help="""Start decaying every epoch after and including this
95 |                     epoch""")
96 | 
97 | # Pretrained word vectors
98 | 
99 | parser.add_argument('-pre_word_vecs_enc',
100 |                     help="""If a valid path is specified, then this will load
101 |                     pretrained word embeddings on the encoder side.
102 |                     See README for specific formatting instructions.""")
103 | parser.add_argument('-pre_word_vecs_dec',
104 |                     help="""If a valid path is specified, then this will load
105 |                     pretrained word embeddings on the decoder side.
106 |                     See README for specific formatting instructions.""")
107 | parser.add_argument('-detach_embed', default=0, type=int)
108 | parser.add_argument('-fix_embed', action='store_true')
109 | 
110 | # GPU
111 | parser.add_argument('-gpus', default=[0], nargs='+', type=int,
112 |                     help="Use CUDA on the listed devices.")
113 | 
114 | parser.add_argument('-log_interval', type=int, default=50,
115 |                     help="Print stats at this interval.")
116 | 
117 | opt = parser.parse_args()
118 | 
119 | print(opt)
120 | 
121 | if torch.cuda.is_available() and not opt.gpus:
122 |     print("WARNING: You have a CUDA device, so you should probably run with -gpus 0")
123 | 
124 | if opt.gpus:
125 |     cuda.set_device(opt.gpus[0])
126 | 
127 | def NMTCriterion(vocabSize):
128 |     weight = torch.ones(vocabSize)
129 |     weight[onmt.Constants.PAD] = 0
130 |     crit = nn.NLLLoss(weight, size_average=False)
131 |     if opt.gpus:
132 |         crit.cuda()
133 |     return crit
134 | 
135 | 
136 | def memoryEfficientLoss(outputs, targets, generator, crit, eval=False):
137 |     # compute generations one piece at a time
138 |     num_correct, loss = 0, 0
139 |     outputs = Variable(outputs.data, requires_grad=(not eval), volatile=eval)
140 | 
141 |     batch_size = outputs.size(1)
142 |     outputs_split = torch.split(outputs, opt.max_generator_batches)
143 |     targets_split = torch.split(targets, opt.max_generator_batches)
144 |     for i, (out_t, targ_t) in enumerate(zip(outputs_split, targets_split)):
145 |         out_t = out_t.view(-1, out_t.size(2))
146 |         scores_t = generator(out_t)
147 |         loss_t = crit(scores_t, targ_t.view(-1))
148 |         pred_t = scores_t.max(1)[1]
149 |         num_correct_t = pred_t.data.eq(targ_t.data).masked_select(targ_t.ne(onmt.Constants.PAD).data).sum()
150 |         num_correct += num_correct_t
151 |         loss += loss_t.data[0]
152 |         if not eval:
153 |             loss_t.div(batch_size).backward()
154 | 
155 |     grad_output = None if outputs.grad is None else outputs.grad.data
156 |     return loss, grad_output, num_correct
157 | 
158 | 
159 | def eval(model, criterion, data):
160 |     total_loss = 0
161 |     total_words = 0
162 |     total_num_correct = 0
163 | 
164 |     model.eval()
165 |     for i in range(len(data)):
166 |         batch = data[i][:-1]  # exclude original indices
167 |         outputs = model(batch)
168 |         targets = batch[1][1:]  # exclude <s> from targets
169 |         loss, _, num_correct = memoryEfficientLoss(
170 |             outputs, targets, model.generator, criterion, eval=True)
171 |         total_loss += loss
172 |         total_num_correct += num_correct
173 |         total_words += targets.data.ne(onmt.Constants.PAD).sum()
174 | 
175 |     model.train()
176 |     return total_loss / total_words, total_num_correct / total_words
177 | 
178 | 
179 | def trainModel(model, trainData, validData, dataset, optim):
180 |     print(model)
181 |     model.train()
182 | 
183 |     # define criterion on each GPU
184 |     criterion = NMTCriterion(dataset['dicts']['tgt'].size())
185 | 
186 |     start_time = time.time()
187 |     def trainEpoch(epoch):
188 | 
189 |         if opt.extra_shuffle and epoch > opt.curriculum:
190 |             trainData.shuffle()
191 | 
192 |         # shuffle mini batch order
193 |         batchOrder = torch.randperm(len(trainData))
194 | 
195 |         total_loss, total_words, total_num_correct = 0, 0, 0
196 |         report_loss, report_tgt_words, report_src_words, report_num_correct = 0, 0, 0, 0
197 |         start = time.time()
198 |         for i in range(len(trainData)):
199 | 
200 |             batchIdx = batchOrder[i] if epoch > opt.curriculum else i
201 |             batch = trainData[batchIdx][:-1]  # exclude original indices
202 | 
203 |             model.zero_grad()
204 |             outputs = model(batch)
205 |             targets = batch[1][1:]  # exclude <s> from targets
206 |             loss, gradOutput, 
num_correct = memoryEfficientLoss( 207 | outputs, targets, model.generator, criterion) 208 | 209 | outputs.backward(gradOutput) 210 | 211 | # update the parameters 212 | optim.step() 213 | 214 | num_words = targets.data.ne(onmt.Constants.PAD).sum() 215 | report_loss += loss 216 | report_num_correct += num_correct 217 | report_tgt_words += num_words 218 | report_src_words += sum(batch[0][1]) 219 | total_loss += loss 220 | total_num_correct += num_correct 221 | total_words += num_words 222 | if i % opt.log_interval == -1 % opt.log_interval: 223 | print("Epoch %2d, %5d/%5d; acc: %6.2f; ppl: %6.2f; %3.0f src tok/s; %3.0f tgt tok/s; %6.0f s elapsed" % 224 | (epoch, i+1, len(trainData), 225 | report_num_correct / report_tgt_words * 100, 226 | math.exp(report_loss / report_tgt_words), 227 | report_src_words/(time.time()-start), 228 | report_tgt_words/(time.time()-start), 229 | time.time()-start_time)) 230 | 231 | report_loss = report_tgt_words = report_src_words = report_num_correct = 0 232 | start = time.time() 233 | 234 | return total_loss / total_words, total_num_correct / total_words 235 | 236 | for epoch in range(opt.start_epoch, opt.epochs + 1): 237 | print('') 238 | 239 | # (1) train for one epoch on the training set 240 | train_loss, train_acc = trainEpoch(epoch) 241 | train_ppl = math.exp(min(train_loss, 100)) 242 | print('Train perplexity: %g' % train_ppl) 243 | print('Train accuracy: %g' % (train_acc*100)) 244 | 245 | # (2) evaluate on the validation set 246 | valid_loss, valid_acc = eval(model, criterion, validData) 247 | valid_ppl = math.exp(min(valid_loss, 100)) 248 | print('Validation perplexity: %g' % valid_ppl) 249 | print('Validation accuracy: %g' % (valid_acc*100)) 250 | 251 | # (3) update the learning rate 252 | optim.updateLearningRate(valid_ppl, epoch) 253 | 254 | model_state_dict = model.module.state_dict() if len(opt.gpus) > 1 else model.state_dict() 255 | model_state_dict = {k: v for k, v in model_state_dict.items() if 'generator' not in k} 256 | generator_state_dict = model.generator.module.state_dict() if len(opt.gpus) > 1 else model.generator.state_dict() 257 | # (4) drop a checkpoint 258 | checkpoint = { 259 | 'model': model_state_dict, 260 | 'generator': generator_state_dict, 261 | 'dicts': dataset['dicts'], 262 | 'opt': opt, 263 | 'epoch': epoch, 264 | 'optim': optim, 265 | 'ppl': valid_ppl, 266 | 'loss': valid_loss, 267 | 'acc': valid_acc 268 | } 269 | torch.save(checkpoint, 270 | '%s_ppl_%.2f_acc_%.2f_loss_%.2f_e%d.pt' % (opt.save_model, valid_ppl, 100*valid_acc, valid_loss, epoch)) 271 | if optim.lr < opt.lr_cutoff: 272 | print('Learning rate decayed below cutoff: {} < {}'.format(optim.lr, opt.lr_cutoff)) 273 | break 274 | 275 | def main(): 276 | 277 | print("Loading data from '%s'" % opt.data) 278 | 279 | dataset = torch.load(opt.data) 280 | 281 | dict_checkpoint = opt.train_from if opt.train_from else opt.train_from_state_dict 282 | if dict_checkpoint: 283 | print('Loading dicts from checkpoint at %s' % dict_checkpoint) 284 | checkpoint = torch.load(dict_checkpoint) 285 | dataset['dicts'] = checkpoint['dicts'] 286 | 287 | trainData = onmt.Dataset(dataset['train']['src'], 288 | dataset['train']['tgt'], opt.batch_size, opt.gpus) 289 | validData = onmt.Dataset(dataset['valid']['src'], 290 | dataset['valid']['tgt'], opt.batch_size, opt.gpus, 291 | volatile=True) 292 | 293 | dicts = dataset['dicts'] 294 | print(' * vocabulary size. source = %d; target = %d' % 295 | (dicts['src'].size(), dicts['tgt'].size())) 296 | print(' * number of training sentences. 
%d' %
297 |           len(dataset['train']['src']))
298 |     print(' * maximum batch size. %d' % opt.batch_size)
299 | 
300 |     print('Building model...')
301 | 
302 |     encoder = onmt.Models.Encoder(opt, dicts['src'])
303 |     decoder = onmt.Models.Decoder(opt, dicts['tgt'])
304 | 
305 |     generator = nn.Sequential(
306 |         nn.Linear(opt.rnn_size, dicts['tgt'].size()),
307 |         nn.LogSoftmax())
308 | 
309 |     model = onmt.Models.NMTModel(encoder, decoder)
310 | 
311 |     if opt.train_from:
312 |         print('Loading model from checkpoint at %s' % opt.train_from)
313 |         chk_model = checkpoint['model']
314 |         generator_state_dict = chk_model.generator.state_dict()
315 |         model_state_dict = {k: v for k, v in chk_model.state_dict().items() if 'generator' not in k}
316 |         model.load_state_dict(model_state_dict)
317 |         generator.load_state_dict(generator_state_dict)
318 |         opt.start_epoch = checkpoint['epoch'] + 1
319 | 
320 |     if opt.train_from_state_dict:
321 |         print('Loading model from checkpoint at %s' % opt.train_from_state_dict)
322 |         model.load_state_dict(checkpoint['model'])
323 |         generator.load_state_dict(checkpoint['generator'])
324 |         opt.start_epoch = checkpoint['epoch'] + 1
325 | 
326 |     if len(opt.gpus) >= 1:
327 |         model.cuda()
328 |         generator.cuda()
329 |     else:
330 |         model.cpu()
331 |         generator.cpu()
332 | 
333 |     if len(opt.gpus) > 1:
334 |         model = nn.DataParallel(model, device_ids=opt.gpus, dim=1)
335 |         generator = nn.DataParallel(generator, device_ids=opt.gpus, dim=0)
336 | 
337 |     model.generator = generator
338 | 
339 |     if not opt.train_from_state_dict and not opt.train_from:
340 |         for p in model.parameters():
341 |             p.data.uniform_(-opt.param_init, opt.param_init)
342 | 
343 |         encoder.load_pretrained_vectors(opt)
344 |         decoder.load_pretrained_vectors(opt)
345 | 
346 |         optim = onmt.Optim(
347 |             opt.optim, opt.learning_rate, opt.max_grad_norm,
348 |             lr_decay=opt.learning_rate_decay,
349 |             start_decay_at=opt.start_decay_at
350 |         )
351 |     else:
352 |         print('Loading optimizer from checkpoint:')
353 |         optim = checkpoint['optim']
354 |         print(optim)
355 | 
356 |     optim.set_parameters(model.parameters())
357 | 
358 |     if opt.train_from or opt.train_from_state_dict:
359 |         optim.optimizer.load_state_dict(checkpoint['optim'].optimizer.state_dict())
360 | 
361 |     nParams = sum([p.nelement() for p in model.parameters()])
362 |     print('* number of parameters: %d' % nParams)
363 | 
364 |     trainModel(model, trainData, validData, dataset, optim)
365 | 
366 | 
367 | if __name__ == "__main__":
368 |     main()
369 | 
--------------------------------------------------------------------------------
/OpenNMT-py/translate.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | 
3 | import onmt
4 | import torch
5 | import argparse
6 | import math
7 | 
8 | parser = argparse.ArgumentParser(description='translate.py')
9 | 
10 | parser.add_argument('-model', required=True,
11 |                     help='Path to model .pt file')
12 | parser.add_argument('-src', required=True,
13 |                     help='Source sequence to decode (one line per sequence)')
14 | parser.add_argument('-tgt',
15 |                     help='True target sequence (optional)')
16 | parser.add_argument('-output', default='pred.txt',
17 |                     help="""Path to output the predictions (each line will
18 |                     be the decoded sequence)""")
19 | parser.add_argument('-beam_size', type=int, default=5,
20 |                     help='Beam size')
21 | parser.add_argument('-batch_size', type=int, default=30,
22 |                     help='Batch size')
23 | parser.add_argument('-max_sent_length', type=int, default=100,
24 |                     help='Maximum sentence length.')
25 | 
parser.add_argument('-replace_unk', action="store_true", 26 | help="""Replace the generated UNK tokens with the source 27 | token that had the highest attention weight. If phrase_table 28 | is provided, it will lookup the identified source token and 29 | give the corresponding target token. If it is not provided 30 | (or the identified source token does not exist in the 31 | table) then it will copy the source token""") 32 | # parser.add_argument('-phrase_table', 33 | # help="""Path to source-target dictionary to replace UNK 34 | # tokens. See README.md for the format of this file.""") 35 | parser.add_argument('-verbose', action="store_true", 36 | help='Print scores and predictions for each sentence') 37 | parser.add_argument('-n_best', type=int, default=1, 38 | help="""If verbose is set, will output the n_best 39 | decoded sentences""") 40 | 41 | parser.add_argument('-gpu', type=int, default=-1, 42 | help="Device to run on") 43 | 44 | 45 | 46 | def reportScore(name, scoreTotal, wordsTotal): 47 | print("%s AVG SCORE: %.4f, %s PPL: %.4f" % ( 48 | name, scoreTotal / wordsTotal, 49 | name, math.exp(-scoreTotal/wordsTotal))) 50 | 51 | def addone(f): 52 | for line in f: 53 | yield line 54 | yield None 55 | 56 | def main(): 57 | opt = parser.parse_args() 58 | opt.cuda = opt.gpu > -1 59 | if opt.cuda: 60 | torch.cuda.set_device(opt.gpu) 61 | 62 | translator = onmt.Translator(opt) 63 | 64 | outF = open(opt.output, 'w') 65 | 66 | predScoreTotal, predWordsTotal, goldScoreTotal, goldWordsTotal = 0, 0, 0, 0 67 | 68 | srcBatch, tgtBatch = [], [] 69 | 70 | count = 0 71 | 72 | tgtF = open(opt.tgt) if opt.tgt else None 73 | for line in addone(open(opt.src)): 74 | 75 | if line is not None: 76 | srcTokens = line.split() 77 | srcBatch += [srcTokens] 78 | if tgtF: 79 | tgtTokens = tgtF.readline().split() if tgtF else None 80 | tgtBatch += [tgtTokens] 81 | 82 | if len(srcBatch) < opt.batch_size: 83 | continue 84 | else: 85 | # at the end of file, check last batch 86 | if len(srcBatch) == 0: 87 | break 88 | 89 | predBatch, predScore, goldScore = translator.translate(srcBatch, tgtBatch) 90 | 91 | predScoreTotal += sum(score[0] for score in predScore) 92 | predWordsTotal += sum(len(x[0]) for x in predBatch) 93 | if tgtF is not None: 94 | goldScoreTotal += sum(goldScore) 95 | goldWordsTotal += sum(len(x) for x in tgtBatch) 96 | 97 | for b in range(len(predBatch)): 98 | count += 1 99 | outF.write(" ".join(predBatch[b][0]) + '\n') 100 | outF.flush() 101 | 102 | if opt.verbose: 103 | srcSent = ' '.join(srcBatch[b]) 104 | if translator.tgt_dict.lower: 105 | srcSent = srcSent.lower() 106 | print('SENT %d: %s' % (count, srcSent)) 107 | print('PRED %d: %s' % (count, " ".join(predBatch[b][0]))) 108 | print("PRED SCORE: %.4f" % predScore[b][0]) 109 | 110 | if tgtF is not None: 111 | tgtSent = ' '.join(tgtBatch[b]) 112 | if translator.tgt_dict.lower: 113 | tgtSent = tgtSent.lower() 114 | print('GOLD %d: %s ' % (count, tgtSent)) 115 | print("GOLD SCORE: %.4f" % goldScore[b]) 116 | 117 | if opt.n_best > 1: 118 | print('\nBEST HYP:') 119 | for n in range(opt.n_best): 120 | print("[%.4f] %s" % (predScore[b][n], " ".join(predBatch[b][n]))) 121 | 122 | print('') 123 | 124 | srcBatch, tgtBatch = [], [] 125 | 126 | reportScore('PRED', predScoreTotal, predWordsTotal) 127 | if tgtF: 128 | reportScore('GOLD', goldScoreTotal, goldWordsTotal) 129 | 130 | if tgtF: 131 | tgtF.close() 132 | 133 | 134 | if __name__ == "__main__": 135 | main() 136 | -------------------------------------------------------------------------------- 
/OpenNMT-py/wmt_clean.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | import pycld2 3 | import unicodeblock.blocks 4 | from argparse import ArgumentParser 5 | 6 | parser = ArgumentParser() 7 | parser.add_argument('prefix', default='data/wmt17/de-en/') 8 | args = parser.parse_args() 9 | 10 | langs = ('de','en') 11 | lang_fix = '.' + '-'.join(langs) 12 | subsets = 'commoncrawl', 'europarl-v7', 'news-commentary-v12', 'rapid2016' 13 | for x in subsets: 14 | path_prefix = args.prefix + x + lang_fix 15 | paths_in = [path_prefix+'.'+lang for lang in langs] 16 | paths_out = [path_prefix+'.clean.'+lang for lang in langs] 17 | latin = lambda s: all("LATIN" in b or "PUNCT" in b or "DIGIT" in b or "SPAC" in b for b in map(unicodeblock.blocks.of,s) if b is not None) 18 | good_src = lambda s: pycld2.detect(s)[2][0][1] in [langs[0],'un'] and latin(s.decode()) and len(s)>1 19 | good_trg = lambda s: pycld2.detect(s)[2][0][1] in [langs[1],'un'] and latin(s.decode()) and len(s)>1 20 | 21 | with open(paths_in[0],'rb') as src, open(paths_in[1],'rb') as trg, open(paths_out[0],'wb') as src_out, open(paths_out[1],'wb') as trg_out: 22 | for srcline,trgline in zip(src,trg): 23 | try: 24 | if good_src(srcline) and good_trg(trgline): 25 | src_out.write(srcline) 26 | trg_out.write(trgline) 27 | except: 28 | try: 29 | srcline = srcline.decode("utf-8").encode("latin-1") 30 | trgline = trgline.decode("utf-8").encode("latin-1") 31 | try: 32 | if good_src(srcline) and good_trg(trgline): 33 | src_out.write(srcline) 34 | trg_out.write(trgline) 35 | except: 36 | pass 37 | except: 38 | pass 39 | -------------------------------------------------------------------------------- /OpenNMT-py/wmt_sgm2txt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import subprocess 4 | import xml.etree.ElementTree as ET 5 | from argparse import ArgumentParser 6 | 7 | parser = ArgumentParser(description='sgm2txt') 8 | parser.add_argument('path') 9 | parser.add_argument('-t', '--tags', nargs='+', default=['seg']) 10 | parser.add_argument('-th', '--threads', default=8, type=int) 11 | parser.add_argument('-a', '--aggressive', action='store_true') 12 | args = parser.parse_args() 13 | 14 | def tokenize(f_txt): 15 | lang = os.path.splitext(f_txt)[1][1:] 16 | f_tok = f_txt 17 | if args.aggressive: 18 | f_tok += '.atok' 19 | else: 20 | f_tok += '.tok' 21 | with open(f_tok, 'w') as fout, open(f_txt) as fin: 22 | if args.aggressive: 23 | pipe = subprocess.call(['perl', 'tokenizer.perl', '-a', '-q', '-threads', str(args.threads), '-no-escape', '-l', lang], stdin=fin, stdout=fout) 24 | else: 25 | pipe = subprocess.call(['perl', 'tokenizer.perl', '-q', '-threads', str(args.threads), '-no-escape', '-l', lang], stdin=fin, stdout=fout) 26 | 27 | for f_xml in glob.iglob(os.path.join(args.path, '*.sgm')): 28 | print(f_xml) 29 | f_txt = os.path.splitext(f_xml)[0] 30 | with open(f_txt, 'w') as fd_txt: 31 | root = ET.parse(f_xml).getroot()[0] 32 | for doc in root.findall('doc'): 33 | for tag in args.tags: 34 | for e in doc.findall(tag): 35 | fd_txt.write(e.text.strip() + '\n') 36 | tokenize(f_txt) 37 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Contextualized Word Vectors (CoVe) 2 | 3 | This repo provides the best, pretrained MT-LSTM from the paper [Learned in Translation: Contextualized 
Word Vectors (McCann et al. 2017)](http://papers.nips.cc/paper/7209-learned-in-translation-contextualized-word-vectors.pdf).
4 | For a high-level overview of why CoVe vectors are useful, check out the [post](https://einstein.ai/research/learned-in-translation-contextualized-word-vectors).
5 | 
6 | This repository uses a [PyTorch](http://pytorch.org/) implementation of the MTLSTM class in cove/encoder.py to load a pretrained encoder,
7 | which takes in sequences of vectors pretrained with GloVe and outputs CoVe.
8 | 
9 | ## Need CoVe in TensorFlow?
10 | 
11 | A Keras/TensorFlow implementation of the MT-LSTM/CoVe can be found at https://github.com/rgsachin/CoVe.
12 | 
13 | ## Unknown Words
14 | 
15 | Out-of-vocabulary words for CoVe are also out of vocabulary for GloVe, which should be rare for most use cases. During training, the CoVe encoder received a zero vector for any word that was not in GloVe, and we used zero vectors for unknown words in our classification and question answering experiments, so that is what we recommend.
16 | 
17 | You could also try initializing unknown inputs to something close to GloVe vectors instead, but we have no experiments suggesting that this would work better than zero vectors. If you wanted to try this, GloVe vectors follow (very roughly) a Gaussian with mean 0 and standard deviation 0.4. You could initialize by randomly drawing from that distribution, but you would probably want to train those embeddings while keeping the CoVe encoder (MTLSTM) and GloVe fixed (a minimal sketch of this appears just before the References section below).
18 | 
19 | ## Example Usage
20 | 
21 | The following example can be found in `test/example.py`. It demonstrates a few different ways to use the pretrained MTLSTM class that generates contextualized word vectors (CoVe) programmatically.
22 | 
23 | ### Running with Docker
24 | 
25 | Install [Docker](https://www.docker.com/get-docker).
26 | Install [nvidia-docker](https://github.com/NVIDIA/nvidia-docker) if you would like to use it with a GPU.
27 | 
28 | ```bash
29 | docker pull bmccann/cove # pull the docker image
30 | # On CPU
31 | docker run -it --rm -v `pwd`/.embeddings:/.embeddings/ -v `pwd`/.data/:/.data/ bmccann/cove bash -c "python /test/example.py --device -1"
32 | # On GPU
33 | nvidia-docker run -it --rm -v `pwd`/.embeddings:/.embeddings/ -v `pwd`/.data/:/.data/ bmccann/cove bash -c "python /test/example.py"
34 | ```
35 | 
36 | ### Running without Docker
37 | 
38 | Install [PyTorch](http://pytorch.org/).
39 | 
40 | ```bash
41 | git clone https://github.com/salesforce/cove.git # use ssh: git@github.com:salesforce/cove.git
42 | cd cove
43 | pip install -r requirements.txt
44 | python setup.py develop
45 | # On CPU
46 | python test/example.py --device -1
47 | # On GPU
48 | python test/example.py
49 | ```
50 | ## Re-training CoVe
51 | 
52 | There is also a third option if you are operating in an entirely different context: retrain the bidirectional LSTM using your own embeddings. If you are mostly encoding a non-English language, that might be the best option. Check out the paper for details; code for this is included in the directory OpenNMT-py, which was forked from [OpenNMT-py](https://github.com/OpenNMT/OpenNMT-py) a while back and includes changes we made to the repo internally.
53 | 
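## Sketch: Random Initialization for Unknown Words

This expands on the Unknown Words section above. It is a minimal sketch rather than code from this repo: `vocab_size`, `glove_vectors` (a `vocab_size x 300` FloatTensor with zero rows for words missing from GloVe), and `unknown_idx` (a LongTensor listing those rows) are all assumed to exist already.

```python
import torch
from cove import MTLSTM

# Re-initialize the missing rows from a rough Gaussian fit of GloVe
# (mean 0, standard deviation about 0.4) instead of leaving them at zero.
vectors = glove_vectors.clone()
vectors[unknown_idx] = torch.randn(len(unknown_idx), 300) * 0.4

# trainable=True keeps gradients flowing back to the embedding matrix;
# freezing every other parameter leaves the pretrained MT-LSTM fixed.
cove = MTLSTM(n_vocab=vocab_size, vectors=vectors, trainable=True)
for name, p in cove.named_parameters():
    if not name.startswith('vectors'):
        p.requires_grad = False
```

A fuller version would also mask gradient updates for the rows that did come from GloVe, so that only the randomly initialized embeddings move; that bookkeeping is omitted here for brevity.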
54 | ## References
55 | 
56 | If using this code, please cite:
57 | 
58 | B. McCann, J. Bradbury, C. Xiong, R. Socher, [*Learned in Translation: Contextualized Word Vectors*](http://papers.nips.cc/paper/7209-learned-in-translation-contextualized-word-vectors.pdf)
59 | 
60 | ```
61 | @inproceedings{mccann2017learned,
62 |   title={Learned in translation: Contextualized word vectors},
63 |   author={McCann, Bryan and Bradbury, James and Xiong, Caiming and Socher, Richard},
64 |   booktitle={Advances in Neural Information Processing Systems},
65 |   pages={6297--6308},
66 |   year={2017}
67 | }
68 | ```
69 | 
70 | Contact: [bmccann@salesforce.com](mailto:bmccann@salesforce.com)
71 | 
--------------------------------------------------------------------------------
/cove/__init__.py:
--------------------------------------------------------------------------------
1 | from .encoder import *
2 | 
3 | 
4 | __all__ = ['MTLSTM']
5 | 
--------------------------------------------------------------------------------
/cove/encoder.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | import torch
4 | from torch import nn
5 | from torch.nn.utils.rnn import pad_packed_sequence as unpack
6 | from torch.nn.utils.rnn import pack_padded_sequence as pack
7 | import torch.utils.model_zoo as model_zoo
8 | 
9 | 
10 | model_urls = {
11 |     'wmt-lstm' : 'https://s3.amazonaws.com/research.metamind.io/cove/wmtlstm-8f474287.pth'
12 | }
13 | 
14 | MODEL_CACHE = os.path.join(os.path.dirname(os.path.realpath(__file__)), '.torch')
15 | 
16 | 
17 | class MTLSTM(nn.Module):
18 | 
19 |     def __init__(self, n_vocab=None, vectors=None, residual_embeddings=False, layer0=False, layer1=True, trainable=False, model_cache=MODEL_CACHE):
20 |         """Initialize an MTLSTM. If layer0 and layer1 are True, their outputs are concatenated along the last dimension so that layer0 outputs
21 |         contribute the first 600 entries and layer1 contributes the second 600 entries. If residual_embeddings is also True, the inputs
22 |         are also concatenated along the last dimension with any outputs such that they form the first 300 entries.
23 | 
24 |         Arguments:
25 |             n_vocab (int): If not None, initialize MTLSTM with an embedding matrix with n_vocab vectors
26 |             vectors (Float Tensor): If not None, initialize embedding matrix with specified vectors (These should be 300d CommonCrawl GloVe vectors)
27 |             residual_embeddings (bool): If True, concatenate the input GloVe embeddings with contextualized word vectors as final output
28 |             layer0 (bool): If True, return the outputs of the first layer of the MTLSTM
29 |             layer1 (bool): If True, return the outputs of the second layer of the MTLSTM
30 |             trainable (bool): If True, do not detach outputs; i.e. train the MTLSTM (recommended to leave False)
31 |             model_cache (str): path to the model file for the MTLSTM to load pretrained weights (defaults to the best MTLSTM from McCann et al. 2017,
32 |                 which was trained with 300d 840B GloVe on the WMT 2017 machine translation dataset).
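
        Example (illustrative only; `glove` is assumed to be an (n_vocab x 300)
        tensor of GloVe vectors, `word_ids` a padded LongTensor batch, and
        `lengths` its sequence lengths):

            >>> cove = MTLSTM(n_vocab=10000, vectors=glove,
            ...               layer0=True, layer1=True, residual_embeddings=True)
            >>> cove(word_ids, lengths).size(-1)  # 300 GloVe + 600 layer0 + 600 layer1
            1500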
33 | """ 34 | super(MTLSTM, self).__init__() 35 | self.layer0 = layer0 36 | self.layer1 = layer1 37 | self.residual_embeddings = residual_embeddings 38 | self.trainable = trainable 39 | self.embed = False 40 | if n_vocab is not None: 41 | self.embed = True 42 | self.vectors = nn.Embedding(n_vocab, 300) 43 | if vectors is not None: 44 | self.vectors.weight.data = vectors 45 | state_dict = model_zoo.load_url(model_urls['wmt-lstm'], model_dir=model_cache) 46 | if layer0: 47 | layer0_dict = {k: v for k, v in state_dict.items() if 'l0' in k} 48 | self.rnn0 = nn.LSTM(300, 300, num_layers=1, bidirectional=True, batch_first=True) 49 | self.rnn0.load_state_dict(layer0_dict) 50 | if layer1: 51 | layer1_dict = {k.replace('l1', 'l0'): v for k, v in state_dict.items() if 'l1' in k} 52 | self.rnn1 = nn.LSTM(600, 300, num_layers=1, bidirectional=True, batch_first=True) 53 | self.rnn1.load_state_dict(layer1_dict) 54 | elif layer1: 55 | self.rnn1 = nn.LSTM(300, 300, num_layers=2, bidirectional=True, batch_first=True) 56 | self.rnn1.load_state_dict(model_zoo.load_url(model_urls['wmt-lstm'], model_dir=model_cache)) 57 | else: 58 | raise ValueError('At least one of layer0 and layer1 must be True.') 59 | 60 | 61 | def forward(self, inputs, lengths, hidden=None): 62 | """ 63 | Arguments: 64 | inputs (Tensor): If MTLSTM handles embedding, a Long Tensor of size (batch_size, timesteps). 65 | Otherwise, a Float Tensor of size (batch_size, timesteps, features). 66 | lengths (Long Tensor): lenghts of each sequence for handling padding 67 | hidden (Float Tensor): initial hidden state of the LSTM 68 | """ 69 | if self.embed: 70 | inputs = self.vectors(inputs) 71 | if not isinstance(lengths, torch.Tensor): 72 | lengths = torch.Tensor(lengths).long() 73 | if inputs.is_cuda: 74 | with torch.cuda.device_of(inputs): 75 | lengths = lengths.cuda(torch.cuda.current_device()) 76 | lens, indices = torch.sort(lengths, 0, True) 77 | outputs = [inputs] if self.residual_embeddings else [] 78 | len_list = lens.tolist() 79 | packed_inputs = pack(inputs[indices], len_list, batch_first=True) 80 | 81 | if self.layer0: 82 | outputs0, hidden_t0 = self.rnn0(packed_inputs, hidden) 83 | unpacked_outputs0 = unpack(outputs0, batch_first=True)[0] 84 | _, _indices = torch.sort(indices, 0) 85 | unpacked_outputs0 = unpacked_outputs0[_indices] 86 | outputs.append(unpacked_outputs0) 87 | packed_inputs = outputs0 88 | if self.layer1: 89 | outputs1, hidden_t1 = self.rnn1(packed_inputs, hidden) 90 | unpacked_outputs1 = unpack(outputs1, batch_first=True)[0] 91 | _, _indices = torch.sort(indices, 0) 92 | unpacked_outputs1 = unpacked_outputs1[_indices] 93 | outputs.append(unpacked_outputs1) 94 | 95 | outputs = torch.cat(outputs, 2) 96 | return outputs if self.trainable else outputs.detach() 97 | -------------------------------------------------------------------------------- /get_data.sh: -------------------------------------------------------------------------------- 1 | wget https://raw.githubusercontent.com/moses-smt/mosesdecoder/master/scripts/tokenizer/tokenizer.perl 2 | wget https://raw.githubusercontent.com/moses-smt/mosesdecoder/master/scripts/tokenizer/lowercase.perl 3 | sed -i "s/$RealBin\/..\/share\/nonbreaking_prefixes//" tokenizer.perl 4 | wget https://raw.githubusercontent.com/moses-smt/mosesdecoder/master/scripts/share/nonbreaking_prefixes/nonbreaking_prefix.de 5 | wget https://raw.githubusercontent.com/moses-smt/mosesdecoder/master/scripts/share/nonbreaking_prefixes/nonbreaking_prefix.en 6 | wget 
https://raw.githubusercontent.com/moses-smt/mosesdecoder/master/scripts/generic/multi-bleu.perl
7 | 
8 | mkdir -p data/wmt17
9 | cd data/wmt17
10 | wget http://www.statmt.org/wmt13/training-parallel-europarl-v7.tgz
11 | wget http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz
12 | wget http://data.statmt.org/wmt17/translation-task/training-parallel-nc-v12.tgz
13 | wget http://data.statmt.org/wmt17/translation-task/rapid2016.tgz
14 | wget http://data.statmt.org/wmt17/translation-task/dev.tgz
15 | tar -xzf training-parallel-europarl-v7.tgz
16 | tar -xzf training-parallel-commoncrawl.tgz
17 | tar -xzf training-parallel-nc-v12.tgz
18 | tar -xzf rapid2016.tgz
19 | tar -xzf dev.tgz
20 | mkdir de-en
21 | mv *de-en* de-en
22 | mv training/*de-en* de-en
23 | mv dev/*deen* de-en
24 | mv dev/*ende* de-en
25 | mv dev/*.de de-en
26 | mv dev/*.en de-en
27 | mv dev/newstest2009*.en* de-en
28 | mv dev/news-test2008*.en* de-en
29 | 
30 | python ../../wmt_clean.py de-en
31 | for l in de; do for f in de-en/*.clean.$l; do perl ../../tokenizer.perl -no-escape -l $l -q < $f > $f.tok; done; done
32 | for l in en; do for f in de-en/*.clean.$l; do perl ../../tokenizer.perl -no-escape -l $l -q < $f > $f.tok; done; done
33 | for l in en de; do for f in de-en/*.clean.$l.tok; do perl ../../lowercase.perl < $f > $f.low; done; done
34 | for l in en de; do perl ../../tokenizer.perl -no-escape -l $l -q < de-en/newstest2013.$l > de-en/newstest2013.$l.tok; done
35 | for l in en de; do perl ../../lowercase.perl < de-en/newstest2013.$l.tok > de-en/newstest2013.$l.tok.low; done
36 | for l in en de; do cat de-en/commoncrawl*.clean.$l.tok.low de-en/europarl*.clean.$l.tok.low de-en/news-commentary*.clean.$l.tok.low de-en/rapid*.clean.$l.tok.low > de-en/train.clean.$l.tok.low; done
37 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | -r https://raw.githubusercontent.com/pytorch/text/master/requirements.txt
2 | git+https://github.com/pytorch/text.git
3 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | from setuptools import setup, find_packages
3 | from codecs import open
4 | from os import path
5 | 
6 | 
7 | with open(path.join(path.abspath(path.dirname(__file__)), 'README.md'), encoding='utf-8') as f:
8 |     long_description = f.read()
9 | 
10 | setup_info = dict(
11 |     name='cove',
12 |     version='1.0.0',
13 |     author='Bryan McCann',
14 |     author_email='Bryan.McCann.is@gmail.com',
15 |     url='https://github.com/salesforce/cove',
16 |     description='Context Vectors for Deep Learning and NLP',
17 |     long_description=long_description,
18 |     license='BSD 3-Clause',
19 |     keywords='cove, context vectors, deep learning, natural language processing',
20 |     packages=find_packages()
21 | )
22 | 
23 | setup(**setup_info)
24 | 
--------------------------------------------------------------------------------
/test/example.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | import numpy as np
3 | 
4 | import torch
5 | from torchtext import data
6 | from torchtext import datasets
7 | from torchtext.vocab import GloVe
8 | 
9 | from cove import MTLSTM
10 | 
11 | 
12 | parser = ArgumentParser()
13 | parser.add_argument('--device', default=0, help='Which device to run on; -1 for CPU', type=int)
14 | 
parser.add_argument('--data', default='.data', help='where to store data') 15 | parser.add_argument('--embeddings', default='.embeddings', help='where to store embeddings') 16 | args = parser.parse_args() 17 | 18 | inputs = data.Field(lower=True, include_lengths=True, batch_first=True) 19 | 20 | print('Generating train, dev, test splits') 21 | train, dev, test = datasets.IWSLT.splits(root=args.data, exts=['.en', '.de'], fields=[inputs, inputs]) 22 | train_iter, dev_iter, test_iter = data.Iterator.splits( 23 | (train, dev, test), batch_size=100, device=torch.device(args.device) if args.device >= 0 else None) 24 | 25 | print('Building vocabulary') 26 | inputs.build_vocab(train, dev, test) 27 | inputs.vocab.load_vectors(vectors=GloVe(name='840B', dim=300, cache=args.embeddings)) 28 | 29 | outputs_last_layer_cove = MTLSTM(n_vocab=len(inputs.vocab), vectors=inputs.vocab.vectors, model_cache=args.embeddings) 30 | outputs_both_layer_cove = MTLSTM(n_vocab=len(inputs.vocab), vectors=inputs.vocab.vectors, layer0=True, model_cache=args.embeddings) 31 | outputs_both_layer_cove_with_glove = MTLSTM(n_vocab=len(inputs.vocab), vectors=inputs.vocab.vectors, layer0=True, residual_embeddings=True, model_cache=args.embeddings) 32 | 33 | if args.device >= 0: 34 | outputs_last_layer_cove.cuda() 35 | outputs_both_layer_cove.cuda() 36 | outputs_both_layer_cove_with_glove.cuda() 37 | 38 | train_iter.init_epoch() 39 | print('Generating CoVe') 40 | for batch_idx, batch in enumerate(train_iter): 41 | if batch_idx > 0: 42 | break 43 | last_layer_cove = outputs_last_layer_cove(*batch.src) 44 | print(last_layer_cove.size()) 45 | first_then_last_layer_cove = outputs_both_layer_cove(*batch.src) 46 | print(first_then_last_layer_cove.size()) 47 | glove_then_first_then_last_layer_cove = outputs_both_layer_cove_with_glove(*batch.src) 48 | print(glove_then_first_then_last_layer_cove.size()) 49 | assert np.allclose(last_layer_cove, first_then_last_layer_cove[:, :, -600:]) 50 | assert np.allclose(last_layer_cove, glove_then_first_then_last_layer_cove[:, :, -600:]) 51 | assert np.allclose(first_then_last_layer_cove[:, :, :600], glove_then_first_then_last_layer_cove[:, :, 300:900]) 52 | print(last_layer_cove[:, :, -10:]) 53 | print(first_then_last_layer_cove[:, :, -10:]) 54 | print(glove_then_first_then_last_layer_cove[:, :, -10:]) 55 | -------------------------------------------------------------------------------- /wmt_clean.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | import pycld2 3 | import unicodeblock.blocks 4 | from argparse import ArgumentParser 5 | 6 | parser = ArgumentParser() 7 | parser.add_argument('prefix', default='data/wmt17/de-en/') 8 | args = parser.parse_args() 9 | 10 | langs = ('de','en') 11 | lang_fix = '.' 
+ '-'.join(langs) 12 | subsets = 'commoncrawl', 'europarl-v7', 'news-commentary-v12', 'rapid2016' 13 | for x in subsets: 14 | path_prefix = args.prefix + x + lang_fix 15 | paths_in = [path_prefix+'.'+lang for lang in langs] 16 | paths_out = [path_prefix+'.clean.'+lang for lang in langs] 17 | latin = lambda s: all("LATIN" in b or "PUNCT" in b or "DIGIT" in b or "SPAC" in b for b in map(unicodeblock.blocks.of,s) if b is not None) 18 | good_src = lambda s: pycld2.detect(s)[2][0][1] in [langs[0],'un'] and latin(s.decode()) and len(s)>1 19 | good_trg = lambda s: pycld2.detect(s)[2][0][1] in [langs[1],'un'] and latin(s.decode()) and len(s)>1 20 | 21 | with open(paths_in[0],'rb') as src, open(paths_in[1],'rb') as trg, open(paths_out[0],'wb') as src_out, open(paths_out[1],'wb') as trg_out: 22 | for srcline,trgline in zip(src,trg): 23 | try: 24 | if good_src(srcline) and good_trg(trgline): 25 | src_out.write(srcline) 26 | trg_out.write(trgline) 27 | except: 28 | try: 29 | srcline = srcline.decode("utf-8").encode("latin-1") 30 | trgline = trgline.decode("utf-8").encode("latin-1") 31 | try: 32 | if good_src(srcline) and good_trg(trgline): 33 | src_out.write(srcline) 34 | trg_out.write(trgline) 35 | except: 36 | pass 37 | except: 38 | pass 39 | --------------------------------------------------------------------------------