├── .gitignore ├── LICENSE.md ├── README.md ├── __init__.py ├── data ├── src-test.txt ├── src-train.txt ├── src-val.txt ├── tgt-train.txt └── tgt-val.txt ├── docs ├── README.md ├── _config.yml ├── css │ └── extra.css ├── extended.md ├── generate.sh ├── img │ ├── architecture.png │ ├── brnn.png │ ├── dbrnn.png │ ├── favicon.ico │ ├── global-attention-model.png │ ├── input_feed.png │ ├── logo-alpha.png │ ├── pdbrnn.png │ └── residual.png ├── index.md ├── installation.md ├── options │ ├── preprocess.md │ ├── train.md │ └── translate.md ├── quickstart.md └── references.md ├── eval.sh ├── mkdocs.yml ├── onmt ├── Beam.py ├── Constants.py ├── Dataset.py ├── Decoders.py ├── Dict.py ├── Encoders.py ├── Markdown.py ├── Models.py ├── Optim.py ├── Translator.py ├── __init__.py └── modules │ ├── Attention.py │ ├── Gate.py │ ├── ImageEncoder.py │ ├── Normalization.py │ ├── SRU_units.py │ ├── Units.py │ └── __init__.py ├── preprocess.py ├── setup.py ├── test └── test_simple.py ├── tools └── extract_embeddings.py ├── train.py ├── train.sh ├── translate.py └── translate.sh /.gitignore: -------------------------------------------------------------------------------- 1 | pred.txt 2 | multi-bleu.perl 3 | *.pt 4 | *.pyc 5 | #.* 6 | .idea 7 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | This software is derived from the OpenNMT project at 2 | https://github.com/OpenNMT/OpenNMT. 
3 | 4 | The MIT License (MIT) 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in 14 | all copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 22 | THE SOFTWARE. 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # OpenNMT: Open-Source Neural Machine Translation 2 | 3 | This is an extension of [OpenNMT](https://github.com/OpenNMT/OpenNMT), 4 | which includes the code for the SR-NMT that has been introduced in 5 | [Deep Neural Machine Translation with Weakly-Recurrent Units](https://arxiv.org/abs/1805.04185). 6 | 7 |
8 | 9 | ## Quickstart 10 | 11 | ## Some useful tools: 12 | 13 | The example below uses the Moses tokenizer (http://www.statmt.org/moses/) to prepare the data and the Moses BLEU script for evaluation. 14 | 15 | ```bash 16 | wget https://raw.githubusercontent.com/moses-smt/mosesdecoder/master/scripts/tokenizer/tokenizer.perl 17 | wget https://raw.githubusercontent.com/moses-smt/mosesdecoder/master/scripts/share/nonbreaking_prefixes/nonbreaking_prefix.de 18 | wget https://raw.githubusercontent.com/moses-smt/mosesdecoder/master/scripts/share/nonbreaking_prefixes/nonbreaking_prefix.en 19 | sed -i "s/$RealBin\/..\/share\/nonbreaking_prefixes//" tokenizer.perl 20 | wget https://raw.githubusercontent.com/moses-smt/mosesdecoder/master/scripts/generic/multi-bleu.perl 21 | ``` 22 | 23 | ## A simple pipeline: 24 | 25 | Download and prepare the data as you would for [OpenNMT-py](https://github.com/OpenNMT/OpenNMT-py). 26 | Then use preprocess.py, train.sh and translate.sh for preprocessing, training and translation. 27 | 28 | ### 1) Preprocess the data. 29 | 30 | ```bash 31 | python preprocess.py -train_src /path/to/data/train.src -train_tgt /path/to/data/train.tgt -valid_src /path/to/data/valid.src -valid_tgt /path/to/data/valid.tgt -save_data /path/to/data/data 32 | ``` 33 | 34 | ### 2) Train the model. 35 | 36 | ```bash 37 | sh train.sh num_layers num_gpu 38 | ``` 39 | 40 | ### 3) Translate sentences. 41 | 42 | ```bash 43 | sh translate.sh model_name test_file num_gpu 44 | ``` 45 | 46 | ### 4) Evaluate. 47 | ```bash 48 | sh eval.sh hypothesis target_language /path/to/test/tokenized.tgt 49 | ``` 50 | This evaluation is consistent with the one used in the paper and was adapted from [Tensor2Tensor](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/get_ende_bleu.sh).
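The two pure string transformations in `eval.sh`, BPE-marker removal (`sed -e "s/@@ //g"`) and ATAT compound splitting (the `perl -ple` one-liners), can be mirrored in Python when you want to inspect intermediate files. The sketch below is a hypothetical helper, not part of this repository: the function names are illustrative, it skips the Moses detokenize/detruecase/retokenize steps that `eval.sh` runs between the two transformations, and BLEU is still computed with `multi-bleu.perl`.

```python
import re

# Same pattern and replacement as the perl one-liner in eval.sh:
#   perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g'
_COMPOUND = re.compile(r"(\S)-(\S)")


def remove_bpe(line: str) -> str:
    """Merge BPE subwords, like `sed -e "s/@@ //g"` in eval.sh."""
    return line.replace("@@ ", "")


def to_atat(line: str) -> str:
    """Split hyphenated compounds: 'rich-text' -> 'rich ##AT##-##AT## text'."""
    return _COMPOUND.sub(r"\1 ##AT##-##AT## \2", line)


if __name__ == "__main__":
    print(remove_bpe("Ver@@ such@@ e es"))  # Versuche es
    print(to_atat("rich-text format"))      # rich ##AT##-##AT## text format
```

Like Perl's `s///g`, `re.sub` replaces non-overlapping matches from left to right, so a chain such as `a-b-c` comes out as `a ##AT##-##AT## b-c` in both tools.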
51 | 52 | ## New versions 53 | We are working to integrate SR-NMT into: 54 | - [OpenNMT-py](https://github.com/mattiadg/OpenNMT-py) 55 | ([OpenNMT/OpenNMT-py#748](https://github.com/OpenNMT/OpenNMT-py/pull/748)) 56 | Status: Testing 57 | 58 | - [OpenNMT-tf](https://github.com/mattiadg/OpenNMT-tf/tree/srnmt) 59 | Status: Development 60 | 61 | ## Citation 62 | 63 | If you use this software, please cite: 64 | 65 | ``` 66 | @inproceedings{digangi2018deep, 67 | author = {Di Gangi, Mattia A and Federico, Marcello}, 68 | title = {Deep Neural Machine Translation with Weakly-Recurrent Units}, 69 | booktitle = {Proceedings of the 21st Annual Conference of the European Association for Machine Translation}, 70 | pages = {119--128}, 71 | year = {2018} 72 | } 73 | ``` 74 | 75 | 76 | [OpenNMT technical report](https://doi.org/10.18653/v1/P17-4012) 77 | 78 | ``` 79 | @inproceedings{opennmt, 80 | author = {Guillaume Klein and 81 | Yoon Kim and 82 | Yuntian Deng and 83 | Jean Senellart and 84 | Alexander M. Rush}, 85 | title = {OpenNMT: Open-Source Toolkit for Neural Machine Translation}, 86 | booktitle = {Proc. ACL}, 87 | year = {2017}, 88 | url = {https://doi.org/10.18653/v1/P17-4012}, 89 | doi = {10.18653/v1/P17-4012} 90 | } 91 | ``` -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mattiadg/SR-NMT/650c45b1981c4a9a72a8a8205d0185a9c2381f42/__init__.py -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | [MkDocs](http://www.mkdocs.org/) is used to generate the documentation at http://opennmt.net/OpenNMT/. 2 | 3 | Documentation under construction for [SR-NMT](https://github.com/mattiadg/SR-NMT) 4 | 5 |
-------------------------------------------------------------------------------- /docs/_config.yml: -------------------------------------------------------------------------------- 1 | theme: jekyll-theme-minimal -------------------------------------------------------------------------------- /docs/css/extra.css: -------------------------------------------------------------------------------- 1 | .md-nav__item--active > .md-nav__link, .md-nav__link:active, .md-typeset a { 2 | color: #ac4142; 3 | } 4 | 5 | .md-nav__link:focus, .md-nav__link:hover, .md-typeset a:active, .md-typeset a:hover { 6 | color: #d67272; 7 | } 8 | 9 | .md-header { 10 | background-color: #ac4142; 11 | } 12 | 13 | label.md-nav__title.md-nav__title--site { 14 | background-color: #ac4142; 15 | color: white; 16 | padding: 1rem 1.2rem; 17 | font-size: 1.6rem; 18 | } 19 | 20 | .md-nav.md-nav--secondary { 21 | border-left-color: #ac4142; 22 | } 23 | 24 | .md-sidebar.md-sidebar--primary { 25 | height: 420px; 26 | } 27 | 28 | .md-flex__cell.md-flex__cell--shrink img { 29 | width:36px;height:36px;margin-top:-6px 30 | } -------------------------------------------------------------------------------- /docs/extended.md: -------------------------------------------------------------------------------- 1 | ## Some useful tools: 2 | 3 | The example below uses the Moses tokenizer (http://www.statmt.org/moses/) to prepare the data and the Moses BLEU script for evaluation.
4 | 5 | ```bash 6 | wget https://raw.githubusercontent.com/moses-smt/mosesdecoder/master/scripts/tokenizer/tokenizer.perl 7 | wget https://raw.githubusercontent.com/moses-smt/mosesdecoder/master/scripts/share/nonbreaking_prefixes/nonbreaking_prefix.de 8 | wget https://raw.githubusercontent.com/moses-smt/mosesdecoder/master/scripts/share/nonbreaking_prefixes/nonbreaking_prefix.en 9 | sed -i "s/$RealBin\/..\/share\/nonbreaking_prefixes//" tokenizer.perl 10 | wget https://raw.githubusercontent.com/moses-smt/mosesdecoder/master/scripts/generic/multi-bleu.perl 11 | ``` 12 | 13 | ## WMT'16 Multimodal Translation: Multi30k (de-en) 14 | 15 | An example of training for the WMT'16 Multimodal Translation task (http://www.statmt.org/wmt16/multimodal-task.html). 16 | 17 | ### 0) Download the data. 18 | 19 | ```bash 20 | mkdir -p data/multi30k 21 | wget http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/training.tar.gz && tar -xf training.tar.gz -C data/multi30k && rm training.tar.gz 22 | wget http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz && tar -xf validation.tar.gz -C data/multi30k && rm validation.tar.gz 23 | wget https://staff.fnwi.uva.nl/d.elliott/wmt16/mmt16_task1_test.tgz && tar -xf mmt16_task1_test.tgz -C data/multi30k && rm mmt16_task1_test.tgz 24 | ``` 25 | 26 | ### 1) Preprocess the data. 27 | 28 | ```bash 29 | for l in en de; do for f in data/multi30k/*.$l; do if [[ "$f" != *"test"* ]]; then sed -i "$ d" $f; fi; done; done 30 | for l in en de; do for f in data/multi30k/*.$l; do perl tokenizer.perl -a -no-escape -l $l -q < $f > $f.atok; done; done 31 | python preprocess.py -train_src data/multi30k/train.en.atok -train_tgt data/multi30k/train.de.atok -valid_src data/multi30k/val.en.atok -valid_tgt data/multi30k/val.de.atok -save_data data/multi30k.atok.low -lower 32 | ``` 33 | 34 | ### 2) Train the model. 
35 | 36 | ```bash 37 | python train.py -data data/multi30k.atok.low.train.pt -save_model multi30k_model -gpus 0 38 | ``` 39 | 40 | ### 3) Translate sentences. 41 | 42 | ```bash 43 | python translate.py -gpu 0 -model multi30k_model_e13_*.pt -src data/multi30k/test.en.atok -tgt data/multi30k/test.de.atok -replace_unk -verbose -output multi30k.test.pred.atok 44 | ``` 45 | 46 | ### 4) Evaluate. 47 | 48 | ```bash 49 | perl multi-bleu.perl data/multi30k/test.de.atok < multi30k.test.pred.atok 50 | ``` 51 | 52 | ## Pretrained Models 53 | 54 | The following pretrained models can be downloaded and used with translate.py (These were trained with an older version of the code; they will be updated soon). 55 | 56 | - [onmt_model_en_de_200k](https://s3.amazonaws.com/pytorch/examples/opennmt/models/onmt_model_en_de_200k-4783d9c3.pt): An English-German translation model based on the 200k sentence dataset at [OpenNMT/IntegrationTesting](https://github.com/OpenNMT/IntegrationTesting/tree/master/data). Perplexity: 21. 57 | - [onmt_model_en_fr_b1M](https://s3.amazonaws.com/pytorch/examples/opennmt/models/onmt_model_en_fr_b1M-261c69a7.pt): An English-French model trained on benchmark-1M. Perplexity: 4.85. 58 | 59 | -------------------------------------------------------------------------------- /docs/generate.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/sh 2 | 3 | gen_script_options () 4 | { 5 | echo "" > $2 6 | echo "" >> $2 7 | python $1 -md >> $2 8 | } 9 | 10 | gen_script_options preprocess.py docs/options/preprocess.md 11 | gen_script_options train.py docs/options/train.md 12 | gen_script_options translate.py docs/options/translate.md 13 | -------------------------------------------------------------------------------- /docs/img/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mattiadg/SR-NMT/650c45b1981c4a9a72a8a8205d0185a9c2381f42/docs/img/architecture.png -------------------------------------------------------------------------------- /docs/img/brnn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mattiadg/SR-NMT/650c45b1981c4a9a72a8a8205d0185a9c2381f42/docs/img/brnn.png -------------------------------------------------------------------------------- /docs/img/dbrnn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mattiadg/SR-NMT/650c45b1981c4a9a72a8a8205d0185a9c2381f42/docs/img/dbrnn.png -------------------------------------------------------------------------------- /docs/img/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mattiadg/SR-NMT/650c45b1981c4a9a72a8a8205d0185a9c2381f42/docs/img/favicon.ico -------------------------------------------------------------------------------- /docs/img/global-attention-model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mattiadg/SR-NMT/650c45b1981c4a9a72a8a8205d0185a9c2381f42/docs/img/global-attention-model.png -------------------------------------------------------------------------------- /docs/img/input_feed.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/mattiadg/SR-NMT/650c45b1981c4a9a72a8a8205d0185a9c2381f42/docs/img/input_feed.png -------------------------------------------------------------------------------- /docs/img/logo-alpha.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mattiadg/SR-NMT/650c45b1981c4a9a72a8a8205d0185a9c2381f42/docs/img/logo-alpha.png -------------------------------------------------------------------------------- /docs/img/pdbrnn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mattiadg/SR-NMT/650c45b1981c4a9a72a8a8205d0185a9c2381f42/docs/img/pdbrnn.png -------------------------------------------------------------------------------- /docs/img/residual.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mattiadg/SR-NMT/650c45b1981c4a9a72a8a8205d0185a9c2381f42/docs/img/residual.png -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | This portal provides detailed documentation of the OpenNMT toolkit. It describes how to use the PyTorch project and how it works. 2 | 3 | *For the Lua Torch version, visit the documentation at [GitHub](http://opennmt.net/OpenNMT).* 4 | 5 | ## Additional resources 6 | 7 | You can find additional help or tutorials in the following resources: 8 | 9 | * [Forum](http://forum.opennmt.net/) 10 | * [Gitter channel](https://gitter.im/OpenNMT/openmt) 11 | 12 | !!! note "Note" 13 | If you find an error in this documentation, please consider [opening an issue](https://github.com/OpenNMT/OpenNMT-py/issues/new) or directly submitting a modification by clicking on the edit button at the top of a page.
14 | -------------------------------------------------------------------------------- /docs/installation.md: -------------------------------------------------------------------------------- 1 | ## Standard 2 | 3 | 1\. [Install PyTorch](http://pytorch.org/) 4 | 5 | 2\. Clone the OpenNMT-py repository: 6 | 7 | ```bash 8 | git clone https://github.com/OpenNMT/OpenNMT-py 9 | cd OpenNMT-py 10 | ``` 11 | 12 | And you are ready to go! Take a look at the [quickstart](quickstart.md) to familiarize yourself with the main training workflow. 13 | 14 | -------------------------------------------------------------------------------- /docs/options/preprocess.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # preprocess.py: 4 | 5 | ``` 6 | usage: preprocess.py [-h] [-md] [-config CONFIG] -train_src TRAIN_SRC 7 | -train_tgt TRAIN_TGT -valid_src VALID_SRC -valid_tgt 8 | VALID_TGT -save_data SAVE_DATA 9 | [-src_vocab_size SRC_VOCAB_SIZE] 10 | [-tgt_vocab_size TGT_VOCAB_SIZE] [-src_vocab SRC_VOCAB] 11 | [-tgt_vocab TGT_VOCAB] [-seq_length SEQ_LENGTH] 12 | [-shuffle SHUFFLE] [-seed SEED] [-lower] 13 | [-report_every REPORT_EVERY] 14 | 15 | ``` 16 | 17 | preprocess.py 18 | 19 | ## **optional arguments**: 20 | ### **-h, --help** 21 | 22 | ``` 23 | show this help message and exit 24 | ``` 25 | 26 | ### **-md** 27 | 28 | ``` 29 | print Markdown-formatted help text and exit. 
30 | ``` 31 | 32 | ### **-config CONFIG** 33 | 34 | ``` 35 | Read options from this file 36 | ``` 37 | 38 | ### **-train_src TRAIN_SRC** 39 | 40 | ``` 41 | Path to the training source data 42 | ``` 43 | 44 | ### **-train_tgt TRAIN_TGT** 45 | 46 | ``` 47 | Path to the training target data 48 | ``` 49 | 50 | ### **-valid_src VALID_SRC** 51 | 52 | ``` 53 | Path to the validation source data 54 | ``` 55 | 56 | ### **-valid_tgt VALID_TGT** 57 | 58 | ``` 59 | Path to the validation target data 60 | ``` 61 | 62 | ### **-save_data SAVE_DATA** 63 | 64 | ``` 65 | Output file for the prepared data 66 | ``` 67 | 68 | ### **-src_vocab_size SRC_VOCAB_SIZE** 69 | 70 | ``` 71 | Size of the source vocabulary 72 | ``` 73 | 74 | ### **-tgt_vocab_size TGT_VOCAB_SIZE** 75 | 76 | ``` 77 | Size of the target vocabulary 78 | ``` 79 | 80 | ### **-src_vocab SRC_VOCAB** 81 | 82 | ``` 83 | Path to an existing source vocabulary 84 | ``` 85 | 86 | ### **-tgt_vocab TGT_VOCAB** 87 | 88 | ``` 89 | Path to an existing target vocabulary 90 | ``` 91 | 92 | ### **-seq_length SEQ_LENGTH** 93 | 94 | ``` 95 | Maximum sequence length 96 | ``` 97 | 98 | ### **-shuffle SHUFFLE** 99 | 100 | ``` 101 | Shuffle data 102 | ``` 103 | 104 | ### **-seed SEED** 105 | 106 | ``` 107 | Random seed 108 | ``` 109 | 110 | ### **-lower** 111 | 112 | ``` 113 | lowercase data 114 | ``` 115 | 116 | ### **-report_every REPORT_EVERY** 117 | 118 | ``` 119 | Report status every this many sentences 120 | ``` 121 | -------------------------------------------------------------------------------- /docs/options/train.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # train.py: 4 | 5 | ``` 6 | usage: train.py [-h] [-md] -data DATA [-save_model SAVE_MODEL] 7 | [-train_from_state_dict TRAIN_FROM_STATE_DICT] 8 | [-train_from TRAIN_FROM] [-layers LAYERS] [-rnn_size RNN_SIZE] 9 | [-word_vec_size WORD_VEC_SIZE] [-input_feed INPUT_FEED] 10 | [-brnn] [-brnn_merge BRNN_MERGE] [-batch_size 
BATCH_SIZE] 11 | [-max_generator_batches MAX_GENERATOR_BATCHES] 12 | [-epochs EPOCHS] [-start_epoch START_EPOCH] 13 | [-param_init PARAM_INIT] [-optim OPTIM] 14 | [-max_grad_norm MAX_GRAD_NORM] [-dropout DROPOUT] 15 | [-curriculum] [-extra_shuffle] [-learning_rate LEARNING_RATE] 16 | [-learning_rate_decay LEARNING_RATE_DECAY] 17 | [-start_decay_at START_DECAY_AT] 18 | [-pre_word_vecs_enc PRE_WORD_VECS_ENC] 19 | [-pre_word_vecs_dec PRE_WORD_VECS_DEC] [-gpus GPUS [GPUS ...]] 20 | [-log_interval LOG_INTERVAL] 21 | 22 | ``` 23 | 24 | train.py 25 | 26 | ## **optional arguments**: 27 | ### **-h, --help** 28 | 29 | ``` 30 | show this help message and exit 31 | ``` 32 | 33 | ### **-md** 34 | 35 | ``` 36 | print Markdown-formatted help text and exit. 37 | ``` 38 | 39 | ### **-data DATA** 40 | 41 | ``` 42 | Path to the *.train.pt file from preprocess.py 43 | ``` 44 | 45 | ### **-save_model SAVE_MODEL** 46 | 47 | ``` 48 | Model filename (the model will be saved as _epochN_PPL.pt where PPL 49 | is the validation perplexity) 50 | ``` 51 | 52 | ### **-train_from_state_dict TRAIN_FROM_STATE_DICT** 53 | 54 | ``` 55 | If training from a checkpoint then this is the path to the pretrained model's 56 | state_dict. 57 | ``` 58 | 59 | ### **-train_from TRAIN_FROM** 60 | 61 | ``` 62 | If training from a checkpoint then this is the path to the pretrained model. 63 | ``` 64 | 65 | ### **-layers LAYERS** 66 | 67 | ``` 68 | Number of layers in the LSTM encoder/decoder 69 | ``` 70 | 71 | ### **-rnn_size RNN_SIZE** 72 | 73 | ``` 74 | Size of LSTM hidden states 75 | ``` 76 | 77 | ### **-word_vec_size WORD_VEC_SIZE** 78 | 79 | ``` 80 | Word embedding size 81 | ``` 82 | 83 | ### **-input_feed INPUT_FEED** 84 | 85 | ``` 86 | Feed the context vector at each time step as additional input (via concatenation 87 | with the word embeddings) to the decoder.
88 | ``` 89 | 90 | ### **-brnn** 91 | 92 | ``` 93 | Use a bidirectional encoder 94 | ``` 95 | 96 | ### **-brnn_merge BRNN_MERGE** 97 | 98 | ``` 99 | Merge action for the bidirectional hidden states: [concat|sum] 100 | ``` 101 | 102 | ### **-batch_size BATCH_SIZE** 103 | 104 | ``` 105 | Maximum batch size 106 | ``` 107 | 108 | ### **-max_generator_batches MAX_GENERATOR_BATCHES** 109 | 110 | ``` 111 | Maximum batches of words in a sequence to run the generator on in parallel. 112 | Higher is faster, but uses more memory. 113 | ``` 114 | 115 | ### **-epochs EPOCHS** 116 | 117 | ``` 118 | Number of training epochs 119 | ``` 120 | 121 | ### **-start_epoch START_EPOCH** 122 | 123 | ``` 124 | The epoch from which to start 125 | ``` 126 | 127 | ### **-param_init PARAM_INIT** 128 | 129 | ``` 130 | Parameters are initialized over uniform distribution with support (-param_init, 131 | param_init) 132 | ``` 133 | 134 | ### **-optim OPTIM** 135 | 136 | ``` 137 | Optimization method. [sgd|adagrad|adadelta|adam] 138 | ``` 139 | 140 | ### **-max_grad_norm MAX_GRAD_NORM** 141 | 142 | ``` 143 | If the norm of the gradient vector exceeds this, renormalize it to have the norm 144 | equal to max_grad_norm 145 | ``` 146 | 147 | ### **-dropout DROPOUT** 148 | 149 | ``` 150 | Dropout probability; applied between LSTM stacks. 151 | ``` 152 | 153 | ### **-curriculum** 154 | 155 | ``` 156 | For this many epochs, order the minibatches based on source sequence length. 157 | Sometimes setting this to 1 will increase convergence speed. 158 | ``` 159 | 160 | ### **-extra_shuffle** 161 | 162 | ``` 163 | By default only shuffle mini-batch order; when true, shuffle and re-assign mini- 164 | batches 165 | ``` 166 | 167 | ### **-learning_rate LEARNING_RATE** 168 | 169 | ``` 170 | Starting learning rate. If adagrad/adadelta/adam is used, then this is the 171 | global learning rate. 
Recommended settings: sgd = 1, adagrad = 0.1, adadelta = 172 | 1, adam = 0.001 173 | ``` 174 | 175 | ### **-learning_rate_decay LEARNING_RATE_DECAY** 176 | 177 | ``` 178 | If update_learning_rate, decay learning rate by this much if (i) perplexity does 179 | not decrease on the validation set or (ii) epoch has gone past start_decay_at 180 | ``` 181 | 182 | ### **-start_decay_at START_DECAY_AT** 183 | 184 | ``` 185 | Start decaying every epoch after and including this epoch 186 | ``` 187 | 188 | ### **-pre_word_vecs_enc PRE_WORD_VECS_ENC** 189 | 190 | ``` 191 | If a valid path is specified, then this will load pretrained word embeddings on 192 | the encoder side. See README for specific formatting instructions. 193 | ``` 194 | 195 | ### **-pre_word_vecs_dec PRE_WORD_VECS_DEC** 196 | 197 | ``` 198 | If a valid path is specified, then this will load pretrained word embeddings on 199 | the decoder side. See README for specific formatting instructions. 200 | ``` 201 | 202 | ### **-gpus GPUS [GPUS ...]** 203 | 204 | ``` 205 | Use CUDA on the listed devices. 206 | ``` 207 | 208 | ### **-log_interval LOG_INTERVAL** 209 | 210 | ``` 211 | Print stats at this interval. 212 | ``` 213 | -------------------------------------------------------------------------------- /docs/options/translate.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # translate.py: 4 | 5 | ``` 6 | usage: translate.py [-h] [-md] -model MODEL -src SRC [-tgt TGT] 7 | [-output OUTPUT] [-beam_size BEAM_SIZE] 8 | [-batch_size BATCH_SIZE] 9 | [-max_sent_length MAX_SENT_LENGTH] [-replace_unk] 10 | [-verbose] [-n_best N_BEST] [-gpu GPU] 11 | 12 | ``` 13 | 14 | translate.py 15 | 16 | ## **optional arguments**: 17 | ### **-h, --help** 18 | 19 | ``` 20 | show this help message and exit 21 | ``` 22 | 23 | ### **-md** 24 | 25 | ``` 26 | print Markdown-formatted help text and exit. 
27 | ``` 28 | 29 | ### **-model MODEL** 30 | 31 | ``` 32 | Path to model .pt file 33 | ``` 34 | 35 | ### **-src SRC** 36 | 37 | ``` 38 | Source sequence to decode (one line per sequence) 39 | ``` 40 | 41 | ### **-tgt TGT** 42 | 43 | ``` 44 | True target sequence (optional) 45 | ``` 46 | 47 | ### **-output OUTPUT** 48 | 49 | ``` 50 | Path to output the predictions (each line will be the decoded sequence) 51 | ``` 52 | 53 | ### **-beam_size BEAM_SIZE** 54 | 55 | ``` 56 | Beam size 57 | ``` 58 | 59 | ### **-batch_size BATCH_SIZE** 60 | 61 | ``` 62 | Batch size 63 | ``` 64 | 65 | ### **-max_sent_length MAX_SENT_LENGTH** 66 | 67 | ``` 68 | Maximum sentence length. 69 | ``` 70 | 71 | ### **-replace_unk** 72 | 73 | ``` 74 | Replace the generated UNK tokens with the source token that had the highest 75 | attention weight. If phrase_table is provided, it will look up the identified 76 | source token and give the corresponding target token. If it is not provided (or 77 | the identified source token does not exist in the table) then it will copy the 78 | source token 79 | ``` 80 | 81 | ### **-verbose** 82 | 83 | ``` 84 | Print scores and predictions for each sentence 85 | ``` 86 | 87 | ### **-n_best N_BEST** 88 | 89 | ``` 90 | If verbose is set, will output the n_best decoded sentences 91 | ``` 92 | 93 | ### **-gpu GPU** 94 | 95 | ``` 96 | Device to run on 97 | ``` 98 | -------------------------------------------------------------------------------- /docs/quickstart.md: -------------------------------------------------------------------------------- 1 | ## Step 1: Preprocess the data 2 | 3 | ```bash 4 | python preprocess.py -train_src data/src-train.txt -train_tgt data/tgt-train.txt -valid_src data/src-val.txt -valid_tgt data/tgt-val.txt -save_data data/demo 5 | ``` 6 | 7 | We will be working with some example data in the `data/` folder.
8 | 9 | The data consists of parallel source (`src`) and target (`tgt`) data containing one sentence per line with tokens separated by a space: 10 | 11 | * `src-train.txt` 12 | * `tgt-train.txt` 13 | * `src-val.txt` 14 | * `tgt-val.txt` 15 | 16 | Validation files are required and used to evaluate the convergence of the training. They usually contain no more than 5000 sentences. 17 | 18 | ```text 19 | $ head -n 3 data/src-train.txt 20 | It is not acceptable that , with the help of the national bureaucracies , Parliament 's legislative prerogative should be made null and void by means of implementing provisions whose content , purpose and extent are not laid down in advance . 21 | Federal Master Trainer and Senior Instructor of the Italian Federation of Aerobic Fitness , Group Fitness , Postural Gym , Stretching and Pilates; from 2004 , he has been collaborating with Antiche Terme as personal Trainer and Instructor of Stretching , Pilates and Postural Gym . 22 | " Two soldiers came up to me and told me that if I refuse to sleep with them , they will kill me . They beat me and ripped my clothes . 23 | ``` 24 | 25 | After running the preprocessing, the following files are generated: 26 | 27 | * `demo.src.dict`: Dictionary of source vocab to index mappings. 28 | * `demo.tgt.dict`: Dictionary of target vocab to index mappings. 29 | * `demo.train.pt`: Serialized PyTorch file containing vocabulary, training and validation data. 30 | 31 | The `*.dict` files are needed to check or reuse the vocabularies. These files are simple human-readable dictionaries. 32 | 33 | ```text 34 | $ head -n 10 data/demo.src.dict 35 | 1 36 | 2 37 | 3 38 | 4 39 | It 5 40 | is 6 41 | not 7 42 | acceptable 8 43 | that 9 44 | , 10 45 | with 11 46 | ``` 47 | 48 | Internally the system never touches the words themselves, but uses these indices.
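Because each `*.dict` line is just a token followed by its index, the vocabulary can be inspected without loading the toolkit. The sketch below is a hypothetical standalone helper, not the API of the repository's own `onmt/Dict.py`; the function names are illustrative, and the index used for out-of-vocabulary words is passed in as `unk_idx` because the toolkit's actual unknown-word index is not shown here.

```python
from typing import Dict, Iterable, List


def parse_dict(lines: Iterable[str]) -> Dict[str, int]:
    """Parse 'word index' lines, as in demo.src.dict, into a word -> index map."""
    vocab: Dict[str, int] = {}
    for line in lines:
        line = line.rstrip("\n")
        if not line:
            continue  # skip blank lines
        # The index is the last space-separated field on each line.
        word, idx = line.rsplit(" ", 1)
        vocab[word] = int(idx)
    return vocab


def load_dict(path: str) -> Dict[str, int]:
    """Read a *.dict file in the 'word index' format shown above."""
    with open(path, encoding="utf-8") as f:
        return parse_dict(f)


def to_indices(sentence: str, vocab: Dict[str, int], unk_idx: int) -> List[int]:
    """Map a space-tokenized sentence to indices; unseen words map to unk_idx."""
    return [vocab.get(tok, unk_idx) for tok in sentence.split()]


if __name__ == "__main__":
    vocab = parse_dict(["It 5\n", "is 6\n", "not 7\n", "acceptable 8\n"])
    print(to_indices("It is not acceptable", vocab, unk_idx=2))  # [5, 6, 7, 8]
```

Splitting on the last space keeps single-character tokens such as `,` intact; the value `unk_idx=2` in the usage line is an arbitrary placeholder.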
49 | 50 | ## Step 2: Train the model 51 | 52 | ```bash 53 | python train.py -data data/demo.train.pt -save_model demo-model 54 | ``` 55 | 56 | The main train command is quite simple. At minimum, it takes a data file 57 | and a save file. This will run the default model, which consists of a 58 | 2-layer LSTM with 500 hidden units on both the encoder and the decoder. You 59 | can also add `-gpus 1` to use (say) GPU 1. 60 | 61 | ## Step 3: Translate 62 | 63 | ```bash 64 | python translate.py -model demo-model_epochX_PPL.pt -src data/src-test.txt -output pred.txt -replace_unk -verbose 65 | ``` 66 | 67 | Now you have a model which you can use to predict on new data. We do this by running beam search. This will output predictions into `pred.txt`. 68 | 69 | !!! note "Note" 70 | The predictions are going to be quite terrible, as the demo dataset is small. Try running on some larger datasets! For example you can download millions of parallel sentences for [translation](http://www.statmt.org/wmt16/translation-task.html) or [summarization](https://github.com/harvardnlp/sent-summary). 71 | -------------------------------------------------------------------------------- /docs/references.md: -------------------------------------------------------------------------------- 1 | This is the list of papers that OpenNMT has been inspired by:
[Google's Neural Machine Translation System: Bridging the Gap between Human and Machine Translation](https://arxiv.org/abs/1609.08144). arXiv preprint arXiv:1609.08144. 7 | * Jean, S., Cho, K., Memisevic, R., & Bengio, Y. (2015). [On Using Very Large Target Vocabulary for Neural Machine Translation](http://www.aclweb.org/anthology/P15-1001). ACL 2015. 8 | -------------------------------------------------------------------------------- /eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | hyp=$1 4 | tgt=$2 5 | tok_gold_targets=$3 6 | # Site-specific path: point this at your own mosesdecoder checkout. 7 | mosesdecoder=/hltsrv1/software/moses/moses-20150228_kenlm_cmph_xmlrpc_irstlm_master/ 8 | # Undo BPE ("@@ "), then detokenize and detruecase the hypotheses in the target language. 9 | sed -e "s/@@ //g" < $hyp | perl $mosesdecoder/scripts/tokenizer/detokenizer.perl -l $tgt | perl $mosesdecoder/scripts/recaser/detruecase.perl > $hyp.tmp 10 | # Tokenize. 11 | perl $mosesdecoder/scripts/tokenizer/tokenizer.perl -l $tgt < $hyp.tmp > $hyp.tok 12 | 13 | # Put compounds in ATAT format (comparable to papers like GNMT, ConvS2S). 14 | # See https://nlp.stanford.edu/projects/nmt/ : 15 | # 'Also, for historical reasons, we split compound words, e.g., 16 | # "rich-text format" --> rich ##AT##-##AT## text format."' 17 | perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' < $tok_gold_targets > $tok_gold_targets.atat 18 | perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' < $hyp.tok > $hyp.atat 19 | 20 | # Get BLEU.
21 | perl $mosesdecoder/scripts/generic/multi-bleu.perl $tok_gold_targets.atat < $hyp.atat 22 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: OpenNMT-py 2 | repo_name: 'OpenNMT/OpenNMT-py' 3 | repo_url: https://github.com/OpenNMT/OpenNMT-py 4 | edit_uri: edit/master/docs/ 5 | 6 | docs_dir: docs 7 | theme: 'material' 8 | extra: 9 | logo: 'img/logo-alpha.png' 10 | social: 11 | - type: 'globe' 12 | link: 'http://opennmt.net' 13 | - type: 'github' 14 | link: 'https://github.com/OpenNMT/OpenNMT-py' 15 | 16 | google_analytics: ['UA-89222039-1', 'opennmt.net'] 17 | extra_javascript: 18 | - https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-MML-AM_HTMLorMML 19 | extra_css: 20 | - css/extra.css 21 | 22 | markdown_extensions: 23 | - math 24 | - sane_lists 25 | - def_list 26 | - fenced_code 27 | - admonition 28 | - codehilite(guess_lang=false) 29 | - toc(permalink=true) 30 | 31 | pages: 32 | - Overview: index.md 33 | - Installation: installation.md 34 | - Quickstart: quickstart.md 35 | - "Extended Example": extended.md 36 | 37 | # - Data: 38 | # - Preparation: data/preparation.md 39 | # - "Word features": data/word_features.md 40 | # - Training: 41 | # - Models: training/models.md 42 | # - Embeddings: training/embeddings.md 43 | # - Logs: training/logs.md 44 | # - "Multi GPU": training/multi_gpu.md 45 | # - Retraining: training/retraining.md 46 | # - "Decay strategies": training/decay.md 47 | # - "Data sampling": training/sampling.md 48 | # - Translation: 49 | # - Inference: translation/inference.md 50 | # - "Beam search": translation/beam_search.md 51 | # - "Unknown words": translation/unknowns.md 52 | # - Tools: 53 | # - Tokenization: tools/tokenization.md 54 | # - Servers: tools/servers.md 55 | - "Reference: Options": 56 | # - "Scripts usage": options/usage.md 57 | - "preprocess.py": options/preprocess.md 58 
| - "train.py": options/train.md 59 | - "translate.py": options/translate.md 60 | # - "tag.lua": options/tag.md 61 | # - "tools/build_vocab.lua": options/build_vocab.md 62 | # - "tools/release_model.lua": options/release_model.md 63 | # - "tools/tokenize.lua": options/tokenize.md 64 | # - "tools/learn_bpe.lua": options/learn_bpe.md 65 | # - "tools/translation_server.lua": options/server.md 66 | # - "tools/rest_translation_server.lua": options/rest_server.md 67 | # - "tools/embeddings.lua": options/embeddings.md 68 | # - Extensions: extensions.md 69 | - References: references.md 70 | # - "Common issues": issues.md 71 | -------------------------------------------------------------------------------- /onmt/Beam.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import onmt 4 | 5 | """ 6 | Class for managing the internals of the beam search process. 7 | 8 | 9 | hyp1-hyp1---hyp1 -hyp1 10 | \ / 11 | hyp2 \-hyp2 /-hyp2hyp2 12 | / \ 13 | hyp3-hyp3---hyp3 -hyp3 14 | ======================== 15 | 16 | Takes care of beams, back pointers, and scores. 17 | """ 18 | 19 | 20 | class Beam(object): 21 | def __init__(self, size, cuda=False): 22 | 23 | self.size = size 24 | self.done = False 25 | 26 | self.tt = torch.cuda if cuda else torch 27 | 28 | # The score for each translation on the beam. 29 | self.scores = self.tt.FloatTensor(size).zero_() 30 | self.allScores = [] 31 | 32 | # The backpointers at each time-step. 33 | self.prevKs = [] 34 | 35 | # The outputs at each time-step. 36 | self.nextYs = [self.tt.LongTensor(size).fill_(onmt.Constants.PAD)] 37 | self.nextYs[0][0] = onmt.Constants.BOS 38 | 39 | # The attentions (matrix) for each time. 40 | self.attn = [] 41 | 42 | def getCurrentState(self): 43 | "Get the outputs for the current timestep." 44 | return self.nextYs[-1] 45 | 46 | def getCurrentOrigin(self): 47 | "Get the backpointers for the current timestep." 
48 | return self.prevKs[-1] 49 | 50 | def advance(self, wordLk, attnOut): 51 | """ 52 | Given log-probabilities over words for every hypothesis on the beam (`wordLk`) and attention 53 | `attnOut`, compute and update the beam search. 54 | 55 | Parameters: 56 | 57 | * `wordLk` - log-probabilities of advancing from the last step (K x words) 58 | * `attnOut` - attention at the last step 59 | 60 | Returns: True if beam search is complete. 61 | """ 62 | numWords = wordLk.size(1) 63 | 64 | # Sum the previous scores. 65 | if len(self.prevKs) > 0: 66 | beamLk = wordLk + self.scores.unsqueeze(1).expand_as(wordLk) 67 | else: 68 | beamLk = wordLk[0] 69 | 70 | flatBeamLk = beamLk.view(-1) 71 | 72 | bestScores, bestScoresId = flatBeamLk.topk(self.size, 0, True, True) 73 | self.allScores.append(self.scores) 74 | self.scores = bestScores 75 | 76 | # bestScoresId indexes the flattened (beam x word) array, so calculate which 77 | # word and beam each score came from 78 | prevK = bestScoresId / numWords 79 | self.prevKs.append(prevK) 80 | self.nextYs.append(bestScoresId - prevK * numWords) 81 | self.attn.append(attnOut.index_select(0, prevK)) 82 | 83 | # End condition is when top-of-beam is EOS. 84 | if self.nextYs[-1][0] == onmt.Constants.EOS: 85 | self.done = True 86 | self.allScores.append(self.scores) 87 | 88 | return self.done 89 | 90 | def sortBest(self): 91 | return torch.sort(self.scores, 0, True) 92 | 93 | def getBest(self): 94 | "Get the score and index of the best hypothesis in the beam." 95 | scores, ids = self.sortBest() 96 | return scores[0], ids[0] 97 | 98 | def getHyp(self, k): 99 | """ 100 | Walk back to construct the full hypothesis. 101 | 102 | Parameters: 103 | 104 | * `k` - the position in the beam to construct. 105 | 106 | Returns: 107 | 108 | 1. The hypothesis 109 | 2. The attention at each time step.
110 | """ 111 | hyp, attn = [], [] 112 | # print(len(self.prevKs), len(self.nextYs), len(self.attn)) 113 | for j in range(len(self.prevKs) - 1, -1, -1): 114 | hyp.append(self.nextYs[j+1][k]) 115 | attn.append(self.attn[j][k]) 116 | k = self.prevKs[j][k] 117 | 118 | return hyp[::-1], torch.stack(attn[::-1]) 119 | -------------------------------------------------------------------------------- /onmt/Constants.py: -------------------------------------------------------------------------------- 1 | 2 | PAD = 0 3 | UNK = 1 4 | BOS = 2 5 | EOS = 3 6 | 7 | PAD_WORD = '<blank>' 8 | UNK_WORD = '<unk>' 9 | BOS_WORD = '<s>' 10 | EOS_WORD = '</s>' 11 | -------------------------------------------------------------------------------- /onmt/Dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import math 4 | import torch 5 | from torch.autograd import Variable 6 | 7 | import onmt 8 | 9 | 10 | class Dataset(object): 11 | def __init__(self, srcData, tgtData, batchSize, cuda, 12 | volatile=False, data_type="text"): 13 | self.src = srcData 14 | self._type = data_type 15 | if tgtData: 16 | self.tgt = tgtData 17 | assert(len(self.src) == len(self.tgt)) 18 | else: 19 | self.tgt = None 20 | self.cuda = cuda 21 | 22 | self.batchSize = batchSize 23 | self.numBatches = math.ceil(len(self.src)/batchSize) 24 | self.volatile = volatile 25 | 26 | def _batchify(self, data, align_right=False, 27 | include_lengths=False, dtype="text"): 28 | if dtype in ["text", "bitext", "monotext"]: 29 | lengths = [x.size(0) for x in data] 30 | max_length = max(lengths) 31 | out = data[0].new(len(data), max_length).fill_(onmt.Constants.PAD) 32 | for i in range(len(data)): 33 | data_length = data[i].size(0) 34 | offset = max_length - data_length if align_right else 0 35 | out[i].narrow(0, offset, data_length).copy_(data[i]) 36 | if include_lengths: 37 | return out, lengths 38 | else: 39 | return out 40 | elif dtype == "img": 41 | heights = [x.size(1)
for x in data] 42 | max_height = max(heights) 43 | widths = [x.size(2) for x in data] 44 | max_width = max(widths) 45 | 46 | out = data[0].new(len(data), 3, max_height, max_width).fill_(0) 47 | for i in range(len(data)): 48 | data_height = data[i].size(1) 49 | data_width = data[i].size(2) 50 | height_offset = max_height - data_height if align_right else 0 51 | width_offset = max_width - data_width if align_right else 0 52 | out[i].narrow(1, height_offset, data_height) \ 53 | .narrow(2, width_offset, data_width).copy_(data[i]) 54 | return out, widths 55 | 56 | def __getitem__(self, index): 57 | assert index < self.numBatches, "%d >= %d" % (index, self.numBatches) 58 | srcBatch, lengths = self._batchify( 59 | self.src[index*self.batchSize:(index+1)*self.batchSize], 60 | align_right=False, include_lengths=True, dtype=self._type) 61 | 62 | if self.tgt: 63 | tgtBatch = self._batchify( 64 | self.tgt[index*self.batchSize:(index+1)*self.batchSize], 65 | dtype="text") 66 | else: 67 | tgtBatch = None 68 | 69 | # sort within the batch by decreasing source length, as required for packed variable-length RNNs 70 | indices = range(len(srcBatch)) 71 | batch = (zip(indices, srcBatch) if tgtBatch is None 72 | else zip(indices, srcBatch, tgtBatch)) 73 | batch, lengths = zip(*sorted(zip(batch, lengths), key=lambda x: -x[1])) 74 | if tgtBatch is None: 75 | indices, srcBatch = zip(*batch) 76 | else: 77 | indices, srcBatch, tgtBatch = zip(*batch) 78 | 79 | def wrap(b, dtype="text"): 80 | if b is None: 81 | return b 82 | b = torch.stack(b, 0) 83 | if dtype in ["text", "bitext", "monotext"]: 84 | b = b.t().contiguous() 85 | if self.cuda: 86 | b = b.cuda() 87 | b = Variable(b, volatile=self.volatile) 88 | return b 89 | 90 | # wrap lengths in a Variable to properly split it in DataParallel 91 | lengths = torch.LongTensor(lengths).view(1, -1) 92 | lengths = Variable(lengths, volatile=self.volatile) 93 | return (wrap(srcBatch, self._type), lengths), \ 94 | wrap(tgtBatch, "text"), indices 95 | 96 | def __len__(self): 97 |
return self.numBatches 98 | 99 | def shuffle(self): 100 | data = list(zip(self.src, self.tgt)) 101 | self.src, self.tgt = zip(*[data[i] for i in torch.randperm(len(data))]) 102 | -------------------------------------------------------------------------------- /onmt/Decoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.parameter import Parameter 5 | import onmt 6 | 7 | from onmt.modules import SRU 8 | 9 | from .modules.SRU_units import AttSRU 10 | from .modules.Attention import getAttention 11 | from .modules.Normalization import LayerNorm 12 | from torch.nn.utils.rnn import pad_packed_sequence as unpack 13 | 14 | 15 | def getDecoder(decoderType): 16 | decoders = {'StackedRNN': StackedRNNDecoder, 17 | 'SR': SGUDecoder, 18 | 'ParallelRNN': ParallelRNNDecoder} 19 | 20 | if decoderType not in decoders: 21 | raise NotImplementedError(decoderType) 22 | 23 | return decoders[decoderType] 24 | 25 | 26 | def getStackedLayer(rnn_type): 27 | if rnn_type == "LSTM": 28 | return StackedLSTM 29 | elif rnn_type == "GRU": 30 | return StackedGRU 31 | else: 32 | return None 33 | 34 | def getRNN(rnn_type): 35 | rnns = {'LSTM': nn.LSTM, 36 | 'GRU': nn.GRU 37 | } 38 | 39 | return rnns[rnn_type] 40 | 41 | class StackedLSTM(nn.Module): 42 | 43 | def __init__(self, num_layers, input_size, rnn_size, dropout): 44 | super(StackedLSTM, self).__init__() 45 | self.dropout = nn.Dropout(dropout) 46 | self.num_layers = num_layers 47 | self.layers = nn.ModuleList() 48 | 49 | for i in range(num_layers): 50 | self.layers.append(nn.LSTMCell(input_size, rnn_size)) 51 | input_size = rnn_size 52 | 53 | def forward(self, input, hidden): 54 | h_0, c_0 = hidden 55 | h_1, c_1 = [], [] 56 | for i, layer in enumerate(self.layers): 57 | h_1_i, c_1_i = layer(input, (h_0[i], c_0[i])) 58 | input = h_1_i 59 | if i + 1 != self.num_layers: 60 | input = self.dropout(input) 61 | h_1 += 
[h_1_i] 62 | c_1 += [c_1_i] 63 | 64 | h_1 = torch.stack(h_1) 65 | c_1 = torch.stack(c_1) 66 | 67 | return input, (h_1, c_1) 68 | 69 | 70 | class StackedGRU(nn.Module): 71 | 72 | def __init__(self, num_layers, input_size, rnn_size, dropout): 73 | super(StackedGRU, self).__init__() 74 | self.dropout = nn.Dropout(dropout) 75 | self.num_layers = num_layers 76 | self.layers = nn.ModuleList() 77 | 78 | for i in range(num_layers): 79 | self.layers.append(nn.GRUCell(input_size, rnn_size)) 80 | input_size = rnn_size 81 | 82 | def forward(self, input, hidden): 83 | h_1 = [] 84 | for i, layer in enumerate(self.layers): 85 | h_1_i = layer(input, hidden[i]) 86 | input = h_1_i 87 | if i + 1 != self.num_layers: 88 | input = self.dropout(input) 89 | h_1 += [h_1_i] 90 | 91 | h_1 = torch.stack(h_1) 92 | 93 | return input, h_1 94 | 95 | 96 | class StackedSGU(nn.Module): 97 | 98 | def __init__(self, num_layers, input_size, rnn_size, layer_norm, dropout): 99 | super(StackedSGU, self).__init__() 100 | self.num_layers = num_layers 101 | self.layers = nn.ModuleList() 102 | self.dropout = nn.Dropout(dropout) 103 | for i in range(num_layers): 104 | self.layers.append(AttSRU(input_size, 105 | rnn_size, rnn_size, layer_norm, dropout)) 106 | input_size = rnn_size 107 | 108 | def initialize_parameters(self, param_init): 109 | for layer in self.layers: 110 | layer.initialize_parameters(param_init) 111 | 112 | def forward(self, dec_state, hidden, enc_out): 113 | input = dec_state 114 | first_input = dec_state 115 | hiddens = [] 116 | for i, layer in enumerate(self.layers): 117 | input, new_hidden, attn_state = layer(input, hidden[i], enc_out) 118 | hiddens += [new_hidden] 119 | 120 | return self.dropout(input), torch.stack(hiddens), attn_state 121 | 122 | 123 | class StackedRNNDecoder(nn.Module): 124 | 125 | def __init__(self, opt, dicts): 126 | 127 | self.layers = opt.layers_dec 128 | self.input_feed = opt.input_feed 129 | self.hidden_size = opt.rnn_size 130 | 131 | input_size = 
opt.word_vec_size 132 | if self.input_feed: 133 | input_size += opt.rnn_size 134 | 135 | super(StackedRNNDecoder, self).__init__() 136 | self.word_lut = nn.Embedding(dicts.size(), 137 | opt.word_vec_size, 138 | padding_idx=onmt.Constants.PAD) 139 | 140 | rnn_type = opt.rnn_decoder_type if opt.rnn_decoder_type else opt.rnn_type 141 | if rnn_type in ['LSTM', 'GRU']: 142 | self.rnn = getStackedLayer(rnn_type)\ 143 | (opt.layers_dec, input_size, opt.rnn_size, opt.dropout) 144 | else: 145 | self.rnn = getStackedLayer(rnn_type) \ 146 | (opt.layers_dec, input_size, opt.rnn_size, opt.activ, 147 | opt.layer_norm, opt.dropout) 148 | 149 | self.attn = getAttention(opt.attn_type)(opt.rnn_size, opt.activ) 150 | 151 | self.linear_ctx = nn.Linear(opt.rnn_size, opt.rnn_size) 152 | self.linear_out = nn.Linear(2 * opt.rnn_size, opt.rnn_size) 153 | 154 | self.dropout = nn.Dropout(opt.dropout) 155 | self.log = self.rnn.log if hasattr(self.rnn, 'log') else False 156 | 157 | self.layer_norm = opt.layer_norm 158 | if self.layer_norm: 159 | self.ctx_ln = LayerNorm(opt.rnn_size) 160 | 161 | self.activ = getattr(F, opt.activ) 162 | 163 | def load_pretrained_vectors(self, opt): 164 | if opt.pre_word_vecs_dec is not None: 165 | pretrained = torch.load(opt.pre_word_vecs_dec) 166 | self.word_lut.weight.data.copy_(pretrained) 167 | 168 | def initialize_parameters(self, param_init): 169 | pass 170 | 171 | def forward(self, input, hidden, context, init_output): 172 | """ 173 | input: targetL x batch 174 | hidden: batch x hidden_dim 175 | context: sourceL x batch x hidden_dim 176 | init_output: batch x hidden_dim 177 | """ 178 | # targetL x batch x hidden_dim 179 | emb = self.word_lut(input) 180 | 181 | # batch x sourceL x hidden_dim 182 | context = context.transpose(0, 1) 183 | 184 | # n.b. 
you can increase performance if you compute W_ih * x for all 185 | # iterations in parallel, but that's only possible if 186 | # self.input_feed=False 187 | outputs = [] 188 | output = init_output 189 | 190 | for emb_t in emb.split(1): 191 | # batch x word_dim 192 | emb_inp = emb_t.squeeze(0) 193 | 194 | if self.input_feed == 1: 195 | # batch x (word_dim+hidden_dim) 196 | emb_inp_feed = torch.cat([emb_inp, output], 1) 197 | else: 198 | emb_inp_feed = emb_inp 199 | 200 | # batch x hidden_dim, layers x batch x hidden_dim 201 | if self.log: 202 | rnn_output, hidden, activ = self.rnn(emb_inp_feed, hidden) 203 | else: 204 | rnn_output, hidden = self.rnn(emb_inp_feed, hidden) 205 | 206 | values = context 207 | pctx = self.linear_ctx(self.dropout(context)) 208 | if self.layer_norm: 209 | pctx = self.ctx_ln(pctx) 210 | weightedContext, attn = self.attn(rnn_output, pctx, values) 211 | 212 | contextCombined = self.linear_out(torch.cat([rnn_output, weightedContext], dim=-1)) 213 | 214 | output = self.activ(contextCombined) 215 | output = self.dropout(output) 216 | outputs += [output] 217 | 218 | outputs = torch.stack(outputs) 219 | 220 | if self.log: 221 | return outputs, hidden, attn, activ 222 | 223 | return outputs, hidden, attn 224 | 225 | 226 | class SGUDecoder(nn.Module): 227 | 228 | def __init__(self, opt, dicts): 229 | self.layers = opt.layers_dec 230 | self.hidden_size = opt.rnn_size 231 | 232 | input_size = opt.word_vec_size 233 | 234 | super(SGUDecoder, self).__init__() 235 | self.word_lut = nn.Embedding(dicts.size(), 236 | opt.word_vec_size, 237 | padding_idx=onmt.Constants.PAD) 238 | 239 | self.stacked = StackedSGU(opt.layers_dec, opt.rnn_size, 240 | opt.rnn_size, opt.layer_norm, 241 | opt.dropout) 242 | 243 | self.log = False 244 | 245 | def load_pretrained_vectors(self, opt): 246 | if opt.pre_word_vecs_dec is not None: 247 | pretrained = torch.load(opt.pre_word_vecs_dec) 248 | self.word_lut.weight.data.copy_(pretrained) 249 | 250 | def 
initialize_parameters(self, param_init): 251 | self.stacked.initialize_parameters(param_init) 252 | #self.attn.initialize_parameters(param_init) 253 | 254 | def forward(self, input, hidden, context, init_output): 255 | """ 256 | input: targetL x batch 257 | hidden: num_layers x batch x hidden_dim 258 | context: sourceL x batch x hidden_dim 259 | init_output: batch x hidden_dim 260 | """ 261 | batch_size = input.size(1) 262 | hidden_dim = context.size(2) 263 | 264 | #targetL x batch x hidden_dim 265 | emb = self.word_lut(input) 266 | 267 | # batch x sourceL x hidden_dim 268 | context = context.transpose(0, 1) 269 | if len(hidden.size()) < 3: 270 | hidden = hidden.unsqueeze(0) 271 | 272 | # (targetL x batch) x sourceL x hidden_dim 273 | #values = context.repeat(emb.size(0), 1, 1) 274 | rnn_outputs = emb #.view(-1, hidden_dim) 275 | 276 | outputs, hidden, attn = self.stacked(rnn_outputs, hidden, context) 277 | 278 | return outputs, hidden, attn 279 | 280 | 281 | class StackedSRU(nn.Module): 282 | def __init__(self, num_layers, input_size, rnn_size, dropout): 283 | super(StackedSRU, self).__init__() 284 | self.dropout = nn.Dropout(dropout) 285 | self.num_layers = num_layers 286 | self.layers = nn.ModuleList() 287 | 288 | for i in range(num_layers): 289 | self.layers.append(SRU(input_size, rnn_size, dropout)) 290 | input_size = rnn_size 291 | 292 | def initialize_parameters(self, param_init): 293 | for layer in self.layers: 294 | layer.initialize_parameters(param_init) 295 | 296 | def forward(self, input, hidden): 297 | """ 298 | Run `input` through each SRU layer in turn. 299 | :param input: targetL x batch x input_size 300 | :param hidden: num_layers x batch x rnn_size, the initial state of each layer 301 | :return: the output of the last layer and the stacked per-layer states 302 | """ 303 | h_1 = [] 304 | for i, layer in enumerate(self.layers): 305 | h_1_i, h = layer(input, hidden[i]) 306 | input = h_1_i 307 | h_1 += [h] 308 | 309 | h_1 = torch.stack(h_1) 310 | 311 | return input, h_1 312 | 313 | 314 | class ParallelRNNDecoder(nn.Module): 315 | def __init__(self, opt, dicts): 316 | from .modules.Attention import MLPAttention 317 | 318 | self.layers
= opt.layers_dec 319 | self.hidden_size = opt.rnn_size 320 | 321 | input_size = opt.word_vec_size 322 | 323 | super(ParallelRNNDecoder, self).__init__() 324 | self.word_lut = nn.Embedding(dicts.size(), 325 | opt.word_vec_size, 326 | padding_idx=onmt.Constants.PAD) 327 | 328 | self.rnn = StackedSRU(self.layers, input_size, self.hidden_size, opt.dropout) 329 | 330 | self.attn = MLPAttention(opt.rnn_size, opt.activ) # getAttention(opt.attn_type)(opt.rnn_size, opt.activ) 331 | 332 | self.linear_ctx = nn.Linear(opt.rnn_size, opt.rnn_size) 333 | self.linear_out = nn.Linear(2 * opt.rnn_size, opt.rnn_size) 334 | 335 | self.dropout = nn.Dropout(opt.dropout) 336 | 337 | def load_pretrained_vectors(self, opt): 338 | if opt.pre_word_vecs_dec is not None: 339 | pretrained = torch.load(opt.pre_word_vecs_dec) 340 | self.word_lut.weight.data.copy_(pretrained) 341 | 342 | def initialize_parameters(self, param_init): 343 | pass 344 | 345 | def forward(self, input, hidden, context, init_output): 346 | """ 347 | input: targetL x batch 348 | hidden: batch x hidden_dim 349 | context: sourceL x batch x hidden_dim 350 | init_output: batch x hidden_dim 351 | """ 352 | # targetL x batch x hidden_dim 353 | emb = self.word_lut(input) 354 | 355 | # batch x sourceL x hidden_dim 356 | context = context.transpose(0, 1) 357 | 358 | # batch x hidden_dim, layers x batch x hidden_dim 359 | rnn_output, hidden = self.rnn(emb, hidden) 360 | 361 | values = context 362 | pctx = self.linear_ctx(self.dropout(context)) 363 | 364 | weightedContext, attn = self.attn(self.dropout(rnn_output), pctx, values) 365 | 366 | contextCombined = self.linear_out(self.dropout(torch.cat([rnn_output, weightedContext], dim=-1))) 367 | 368 | output = F.tanh(contextCombined) 369 | output = self.dropout(output) 370 | 371 | return output, hidden, attn 372 | -------------------------------------------------------------------------------- /onmt/Dict.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Dict(object): 5 | def __init__(self, data=None, lower=False): 6 | self.idxToLabel = {} 7 | self.labelToIdx = {} 8 | self.frequencies = {} 9 | self.lower = lower 10 | 11 | # Special entries will not be pruned. 12 | self.special = [] 13 | 14 | if data is not None: 15 | if type(data) == str: 16 | self.loadFile(data) 17 | else: 18 | self.addSpecials(data) 19 | 20 | def size(self): 21 | return len(self.idxToLabel) 22 | 23 | def loadFile(self, filename): 24 | "Load entries from a file." 25 | for line in open(filename): 26 | fields = line.split() 27 | label = fields[0] 28 | idx = int(fields[1]) 29 | self.add(label, idx) 30 | 31 | def writeFile(self, filename): 32 | "Write entries to a file." 33 | with open(filename, 'w') as file: 34 | for i in range(self.size()): 35 | label = self.idxToLabel[i] 36 | file.write('%s %d\n' % (label, i)) 37 | 38 | file.close() 39 | 40 | def lookup(self, key, default=None): 41 | key = key.lower() if self.lower else key 42 | try: 43 | return self.labelToIdx[key] 44 | except KeyError: 45 | return default 46 | 47 | def getLabel(self, idx, default=None): 48 | try: 49 | return self.idxToLabel[idx] 50 | except KeyError: 51 | return default 52 | 53 | def addSpecial(self, label, idx=None): 54 | "Mark this `label` and `idx` as special (i.e. will not be pruned)." 55 | idx = self.add(label, idx) 56 | self.special += [idx] 57 | 58 | def addSpecials(self, labels): 59 | "Mark all labels in `labels` as specials (i.e. will not be pruned)." 60 | for label in labels: 61 | self.addSpecial(label) 62 | 63 | def add(self, label, idx=None): 64 | "Add `label` in the dictionary. Use `idx` as its index if given." 
65 | label = label.lower() if self.lower else label 66 | if idx is not None: 67 | self.idxToLabel[idx] = label 68 | self.labelToIdx[label] = idx 69 | else: 70 | if label in self.labelToIdx: 71 | idx = self.labelToIdx[label] 72 | else: 73 | idx = len(self.idxToLabel) 74 | self.idxToLabel[idx] = label 75 | self.labelToIdx[label] = idx 76 | 77 | if idx not in self.frequencies: 78 | self.frequencies[idx] = 1 79 | else: 80 | self.frequencies[idx] += 1 81 | 82 | return idx 83 | 84 | def prune(self, size): 85 | "Return a new dictionary with the `size` most frequent entries, plus all special entries." 86 | if size >= self.size(): 87 | return self 88 | 89 | # Only keep the `size` most frequent entries. 90 | freq = torch.Tensor( 91 | [self.frequencies[i] for i in range(len(self.frequencies))]) 92 | _, idx = torch.sort(freq, 0, True) 93 | 94 | newDict = Dict() 95 | newDict.lower = self.lower 96 | 97 | # Add special entries in all cases. 98 | for i in self.special: 99 | newDict.addSpecial(self.idxToLabel[i]) 100 | 101 | for i in idx[:size]: 102 | newDict.add(self.idxToLabel[i]) 103 | 104 | return newDict 105 | 106 | def convertToIdx(self, labels, unkWord, bosWord=None, eosWord=None): 107 | """ 108 | Convert `labels` to indices. Use `unkWord` if not found. 109 | Optionally insert `bosWord` at the beginning and `eosWord` at the end. 110 | """ 111 | vec = [] 112 | 113 | if bosWord is not None: 114 | vec += [self.lookup(bosWord)] 115 | 116 | unk = self.lookup(unkWord) 117 | vec += [self.lookup(label, default=unk) for label in labels] 118 | 119 | if eosWord is not None: 120 | vec += [self.lookup(eosWord)] 121 | 122 | return torch.LongTensor(vec) 123 | 124 | def convertToLabels(self, idx, stop): 125 | """ 126 | Convert `idx` to labels. 127 | If index `stop` is reached, convert it and return.
128 | """ 129 | 130 | labels = [] 131 | 132 | for i in idx: 133 | labels += [self.getLabel(i)] 134 | if i == stop: 135 | break 136 | 137 | return labels 138 | -------------------------------------------------------------------------------- /onmt/Encoders.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | 5 | import onmt 6 | from .modules.SRU_units import BiSRU 7 | 8 | from torch.nn.utils.rnn import PackedSequence 9 | from torch.nn.utils.rnn import pad_packed_sequence as unpack 10 | from torch.nn.utils.rnn import pack_padded_sequence as pack 11 | 12 | from onmt.modules.Units import ParallelMyRNN 13 | 14 | def getEncoder(encoder_type): 15 | encoders = {'RNN': Encoder, 16 | 'SR': SGUEncoder} 17 | if encoder_type not in encoders: 18 | raise NotImplementedError(encoder_type) 19 | return encoders[encoder_type] 20 | 21 | 22 | class Encoder(nn.Module): 23 | 24 | def __init__(self, opt, dicts): 25 | 26 | def getunittype(rnn_type): 27 | if rnn_type in ['LSTM', 'GRU']: 28 | return getattr(nn, rnn_type) 29 | elif rnn_type == 'SRU': 30 | return ParallelMyRNN 31 | 32 | self.layers = opt.layers_enc 33 | self.num_directions = 2 if opt.brnn else 1 34 | assert opt.rnn_size % self.num_directions == 0 35 | self.hidden_size = opt.rnn_size // self.num_directions 36 | 37 | super(Encoder, self).__init__() 38 | self.word_lut = nn.Embedding(dicts.size(), 39 | opt.word_vec_size, 40 | padding_idx=onmt.Constants.PAD) 41 | 42 | rnn_type = opt.rnn_encoder_type if opt.rnn_encoder_type else opt.rnn_type 43 | self.rnn = getunittype(rnn_type)( 44 | opt.word_vec_size, self.hidden_size, 45 | num_layers=opt.layers_enc, dropout=opt.dropout, 46 | bidirectional=opt.brnn) 47 | 48 | def load_pretrained_vectors(self, opt): 49 | if opt.pre_word_vecs_enc is not None: 50 | pretrained = torch.load(opt.pre_word_vecs_enc) 51 | self.word_lut.weight.data.copy_(pretrained) 52 | 53 | def initialize_parameters(self, param_init): 54 | 
if hasattr(self.rnn, 'initialize_parameters'): 55 | self.rnn.initialize_parameters(param_init) 56 | 57 | def forward(self, input, hidden=None): 58 | 59 | if isinstance(input, tuple): 60 | # Lengths data is wrapped inside a Variable. 61 | lengths = input[1].data.view(-1).tolist() 62 | emb = pack(self.word_lut(input[0]), lengths) 63 | else: 64 | emb = self.word_lut(input) 65 | outputs, hidden_t = self.rnn(emb, hidden) 66 | if isinstance(outputs, PackedSequence): 67 | outputs = unpack(outputs)[0] 68 | 69 | return hidden_t, outputs, emb 70 | 71 | class StackedSGU(nn.Module): 72 | 73 | def __init__(self, layers, input_size, hidden_size, layer_norm, dropout): 74 | self.layers = layers 75 | super(StackedSGU, self).__init__() 76 | self.sgus = nn.ModuleList() 77 | self.dropout = nn.Dropout(dropout) 78 | for _ in range(layers): 79 | self.sgus.append(BiSRU(input_size, hidden_size, layer_norm, dropout)) 80 | input_size = hidden_size 81 | 82 | def initialize_parameters(self, param_init): 83 | for sgu in self.sgus: 84 | sgu.initialize_parameters(param_init) 85 | 86 | def forward(self, input): 87 | 88 | hiddens = [] 89 | for i in range(self.layers): 90 | input = self.sgus[i](input) 91 | hiddens += [input[-1]] 92 | return input, torch.stack(hiddens) 93 | 94 | class SGUEncoder(nn.Module): 95 | 96 | def __init__(self, opt, dicts): 97 | 98 | self.layers = opt.layers_enc 99 | self.num_directions = 2 if opt.brnn else 1 100 | assert opt.rnn_size % self.num_directions == 0 101 | self.hidden_size = opt.rnn_size // self.num_directions 102 | 103 | super(SGUEncoder, self).__init__() 104 | self.word_lut = nn.Embedding(dicts.size(), 105 | opt.word_vec_size, 106 | padding_idx=onmt.Constants.PAD) 107 | self.sgu = StackedSGU(self.layers, opt.word_vec_size, 108 | self.hidden_size * self.num_directions, opt.layer_norm, 109 | opt.dropout) 110 | 111 | 112 | def load_pretrained_vectors(self, opt): 113 | if opt.pre_word_vecs_enc is not None: 114 | pretrained = torch.load(opt.pre_word_vecs_enc) 115 | 
self.word_lut.weight.data.copy_(pretrained) 116 | 117 | def initialize_parameters(self, param_init): 118 | self.sgu.initialize_parameters(param_init) 119 | 120 | def forward(self, input, hidden=None): 121 | 122 | if isinstance(input, tuple): 123 | # Lengths data is wrapped inside a Variable. 124 | lengths = input[1].data.view(-1).tolist() 125 | emb = self.word_lut(input[0]) 126 | else: 127 | emb = self.word_lut(input) 128 | outputs, hidden_t = self.sgu(emb) 129 | 130 | return hidden_t, outputs, emb 131 | -------------------------------------------------------------------------------- /onmt/Markdown.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The Chromium Authors. All rights reserved. 2 | # Use of this source code is governed by a BSD-style license that can be 3 | # found in the LICENSE file. 4 | import argparse 5 | 6 | 7 | class MarkdownHelpFormatter(argparse.HelpFormatter): 8 | """A really bare-bones argparse help formatter that generates valid markdown. 
9 | This will generate something like: 10 | usage 11 | # **section heading**: 12 | ## **--argument-one** 13 | ``` 14 | argument-one help text 15 | ``` 16 | """ 17 | 18 | def _format_usage(self, usage, actions, groups, prefix): 19 | usage_text = super(MarkdownHelpFormatter, self)._format_usage( 20 | usage, actions, groups, prefix) 21 | return '\n```\n%s\n```\n\n' % usage_text 22 | 23 | def format_help(self): 24 | self._root_section.heading = '# %s' % self._prog 25 | return super(MarkdownHelpFormatter, self).format_help() 26 | 27 | def start_section(self, heading): 28 | super(MarkdownHelpFormatter, self).start_section('## **%s**' % heading) 29 | 30 | def _format_action(self, action): 31 | lines = [] 32 | action_header = self._format_action_invocation(action) 33 | lines.append('### **%s** ' % action_header) 34 | if action.help: 35 | lines.append('') 36 | lines.append('```') 37 | help_text = self._expand_help(action) 38 | lines.extend(self._split_lines(help_text, 80)) 39 | lines.append('```') 40 | lines.extend(['', '']) 41 | return '\n'.join(lines) 42 | 43 | 44 | class MarkdownHelpAction(argparse.Action): 45 | def __init__(self, option_strings, 46 | dest=argparse.SUPPRESS, default=argparse.SUPPRESS, 47 | **kwargs): 48 | super(MarkdownHelpAction, self).__init__( 49 | option_strings=option_strings, 50 | dest=dest, 51 | default=default, 52 | nargs=0, 53 | **kwargs) 54 | 55 | def __call__(self, parser, namespace, values, option_string=None): 56 | parser.formatter_class = MarkdownHelpFormatter 57 | parser.print_help() 58 | parser.exit() 59 | 60 | 61 | def add_md_help_argument(parser): 62 | parser.add_argument('-md', action=MarkdownHelpAction, 63 | help='print Markdown-formatted help text and exit.') 64 | -------------------------------------------------------------------------------- /onmt/Models.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.autograd import Variable 3 | from .Encoders import 
Encoder 4 | 5 | 6 | class NMTModel(nn.Module): 7 | 8 | def __init__(self, encoder, decoder): 9 | super(NMTModel, self).__init__() 10 | self.encoder = encoder 11 | self.decoder = decoder 12 | 13 | def make_init_decoder_output(self, context): 14 | batch_size = context.size(1) 15 | h_size = (batch_size, self.decoder.hidden_size) 16 | return Variable(context.data.new(*h_size).zero_(), requires_grad=False) 17 | 18 | def load_pretrained_vectors(self, opt): 19 | self.encoder.load_pretrained_vectors(opt) 20 | self.decoder.load_pretrained_vectors(opt) 21 | 22 | def initialize_parameters(self, param_init): 23 | self.encoder.initialize_parameters(param_init) 24 | self.decoder.initialize_parameters(param_init) 25 | 26 | def brnn_merge_concat(self, h): 27 | # the encoder hidden is (layers*directions) x batch x dim 28 | # we need to convert it to layers x batch x (directions*dim) 29 | if self.encoder.num_directions == 2: 30 | return h.view(h.size(0) // 2, 2, h.size(1), h.size(2)) \ 31 | .transpose(1, 2).contiguous() \ 32 | .view(h.size(0) // 2, h.size(1), h.size(2) * 2) 33 | else: 34 | return h 35 | 36 | def forward(self, input): 37 | src = input[0] 38 | tgt = input[1][:-1] # exclude last target from inputs 39 | enc_hidden, context, emb = self.encoder(src) 40 | init_output = self.make_init_decoder_output(context) 41 | 42 | if isinstance(self.encoder, Encoder): 43 | if isinstance(enc_hidden, tuple): 44 | enc_hidden = tuple(self.brnn_merge_concat(enc_hidden[i]) 45 | for i in range(len(enc_hidden))) 46 | else: 47 | enc_hidden = self.brnn_merge_concat(enc_hidden) 48 | if enc_hidden.size(0) < self.decoder.layers: 49 | enc_hidden = enc_hidden.repeat(self.decoder.layers, 1, 1) 50 | else: 51 | enc_hidden = Variable(enc_hidden.data.new(*enc_hidden.size()).zero_(), requires_grad=False) 52 | 53 | #self.decoder.mask_attention(src[0]) 54 | out, dec_hidden, _attn = self.decoder(tgt, enc_hidden, 55 | context, init_output) 56 | return out 57 | 
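The reshape in `NMTModel.brnn_merge_concat` is easiest to see on concrete values. The sketch below is a hypothetical pure-Python mirror of it (the name `merge_bidirectional` is not part of this repository), assuming PyTorch's documented layout for bidirectional RNN final states, where the state for layer `l` and direction `d` sits at index `l * 2 + d`. Under that assumption, `view(layers, 2, batch, dim).transpose(1, 2).view(layers, batch, 2 * dim)` amounts to concatenating each layer's forward state with its backward state.

```python
from typing import List

Matrix = List[List[float]]  # batch x dim


def merge_bidirectional(h: List[Matrix]) -> List[Matrix]:
    """Hypothetical pure-Python mirror of brnn_merge_concat (num_directions == 2).

    Assumes `h` is indexed as h[layer * 2 + direction][batch][dim], the layout
    PyTorch documents for bidirectional RNN states. Returns out[layer][batch],
    where each row is the forward state followed by the backward state.
    """
    if len(h) % 2 != 0:
        raise ValueError("expected (layers * 2) stacked states")
    num_layers = len(h) // 2
    out = []
    for layer in range(num_layers):
        fwd, bwd = h[2 * layer], h[2 * layer + 1]
        # Concatenate along the feature dimension, one batch row at a time.
        out.append([f_row + b_row for f_row, b_row in zip(fwd, bwd)])
    return out


# Two layers, batch of one, dim 2: (layers * 2) x batch x dim.
states = [
    [[1.0, 2.0]],  # layer 0, forward
    [[3.0, 4.0]],  # layer 0, backward
    [[5.0, 6.0]],  # layer 1, forward
    [[7.0, 8.0]],  # layer 1, backward
]
print(merge_bidirectional(states))
# [[[1.0, 2.0, 3.0, 4.0]], [[5.0, 6.0, 7.0, 8.0]]]
```

The merged tensor has `layers x batch x (2 * dim)` entries, which matches the decoder's `rnn_size` when the encoder uses `rnn_size // 2` units per direction, as `Encoder.__init__` does.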
-------------------------------------------------------------------------------- /onmt/Optim.py: -------------------------------------------------------------------------------- 1 | import torch.optim as optim 2 | from torch.nn.utils import clip_grad_norm 3 | 4 | 5 | class Optim(object): 6 | 7 | def set_parameters(self, params): 8 | self.params = list(params) # careful: params may be a generator 9 | 10 | if self.method == 'sgd': 11 | self.optimizer = optim.SGD(self.params, lr=self.lr, momentum=0.9, weight_decay=1e-5, nesterov=True) 12 | elif self.method == 'adagrad': 13 | self.optimizer = optim.Adagrad(self.params, lr=self.lr) 14 | elif self.method == 'adadelta': 15 | self.optimizer = optim.Adadelta(self.params, lr=self.lr) 16 | elif self.method == 'adam': 17 | self.optimizer = optim.Adam(self.params, lr=self.lr) 18 | else: 19 | raise RuntimeError("Invalid optim method: " + self.method) 20 | 21 | def __init__(self, method, lr, max_grad_norm, 22 | lr_decay=1, start_decay_at=None): 23 | self.last_ppl = None 24 | self.lr = lr 25 | self.max_grad_norm = max_grad_norm 26 | self.method = method 27 | self.lr_decay = lr_decay 28 | self.start_decay_at = start_decay_at 29 | self.start_decay = False 30 | 31 | def step(self): 32 | "Clip gradients to `max_grad_norm` (if set) and take an optimizer step." 33 | if self.max_grad_norm: 34 | clip_grad_norm(self.params, self.max_grad_norm) 35 | self.optimizer.step() 36 | 37 | 38 | def updateLearningRate(self, ppl, iter): 39 | """ 40 | Decay the learning rate if validation perplexity does not improve, 41 | or once `iter` reaches `start_decay_at`.
42 | """ 43 | 44 | if self.start_decay_at is not None and iter >= self.start_decay_at: 45 | self.start_decay = True 46 | if self.last_ppl is not None and ppl > self.last_ppl: 47 | self.start_decay = True 48 | 49 | if self.start_decay: 50 | self.lr = self.lr * self.lr_decay 51 | print("Decaying learning rate to %g" % self.lr) 52 | 53 | self.last_ppl = ppl 54 | self.optimizer.param_groups[0]['lr'] = self.lr 55 | -------------------------------------------------------------------------------- /onmt/Translator.py: -------------------------------------------------------------------------------- 1 | import onmt 2 | import onmt.Models 3 | import onmt.modules 4 | import torch.nn as nn 5 | import torch 6 | from torch.autograd import Variable 7 | from .Decoders import getDecoder, SGUDecoder 8 | from .Encoders import getEncoder, Encoder, SGUEncoder 9 | from torch.nn.utils.rnn import PackedSequence, pad_packed_sequence as unpack 10 | import sys 11 | 12 | 13 | 14 | 15 | def loadImageLibs(): 16 | "Conditional import of torch image libs." 
17 | global Image, transforms 18 | from PIL import Image 19 | from torchvision import transforms 20 | 21 | 22 | class Translator(object): 23 | def __init__(self, opt): 24 | self.opt = opt 25 | self.tt = torch.cuda if opt.cuda else torch 26 | self.beam_accum = None 27 | 28 | checkpoint = torch.load(opt.model, 29 | map_location=lambda storage, loc: storage) 30 | 31 | model_opt = checkpoint['opt'] 32 | self.model_opt = model_opt 33 | self.src_dict = checkpoint['dicts']['src'] 34 | self.tgt_dict = checkpoint['dicts']['tgt'] 35 | self._type = "text" #model_opt.encoder_type \ 36 | #if "encoder_type" in model_opt else "text" 37 | 38 | #if self._type == "text": 39 | # encoder = Encoder(model_opt, self.src_dict) 40 | #elif self._type == "img": 41 | # loadImageLibs() 42 | # encoder = onmt.modules.ImageEncoder(model_opt) 43 | print("Translator layer_norm:", model_opt.layer_norm) 44 | 45 | encoder = getEncoder(model_opt.encoder_type)(model_opt, self.src_dict) 46 | decoder = getDecoder(model_opt.decoder_type)(model_opt, self.tgt_dict) 47 | model = onmt.Models.NMTModel(encoder, decoder) 48 | 49 | generator = nn.Sequential( 50 | nn.Linear(model_opt.rnn_size, self.tgt_dict.size()), 51 | nn.LogSoftmax()) 52 | 53 | model.load_state_dict(checkpoint['model']) 54 | generator.load_state_dict(checkpoint['generator']) 55 | 56 | if opt.cuda: 57 | model.cuda() 58 | generator.cuda() 59 | else: 60 | model.cpu() 61 | generator.cpu() 62 | 63 | model.generator = generator 64 | 65 | self.model = model 66 | self.model.eval() 67 | 68 | def initBeamAccum(self): 69 | self.beam_accum = { 70 | "predicted_ids": [], 71 | "beam_parent_ids": [], 72 | "scores": [], 73 | "log_probs": []} 74 | 75 | def _getBatchSize(self, batch): 76 | if self._type == "text": 77 | return batch.size(1) 78 | else: 79 | return batch.size(0) 80 | 81 | def buildData(self, srcBatch, goldBatch): 82 | # This needs to be the same as preprocess.py. 
83 | if self._type == "text": 84 | srcData = [self.src_dict.convertToIdx(b, 85 | onmt.Constants.UNK_WORD) 86 | for b in srcBatch] 87 | elif self._type == "img": 88 | srcData = [transforms.ToTensor()( 89 | Image.open(self.opt.src_img_dir + "/" + b[0])) 90 | for b in srcBatch] 91 | 92 | tgtData = None 93 | if goldBatch: 94 | tgtData = [self.tgt_dict.convertToIdx(b, 95 | onmt.Constants.UNK_WORD, 96 | onmt.Constants.BOS_WORD, 97 | onmt.Constants.EOS_WORD) for b in goldBatch] 98 | 99 | return onmt.Dataset(srcData, tgtData, self.opt.batch_size, 100 | self.opt.cuda, volatile=True, 101 | data_type=self._type) 102 | 103 | def buildTargetTokens(self, pred, src, attn): 104 | tokens = self.tgt_dict.convertToLabels(pred, onmt.Constants.EOS) 105 | tokens = tokens[:-1] # EOS 106 | if self.opt.replace_unk: 107 | for i in range(len(tokens)): 108 | if tokens[i] == onmt.Constants.UNK_WORD: 109 | _, maxIndex = attn[i].max(0) 110 | tokens[i] = src[maxIndex[0]] 111 | return tokens 112 | 113 | def translateBatch(self, srcBatch, tgtBatch): 114 | # Batch size is in different location depending on data. 115 | 116 | beamSize = self.opt.beam_size 117 | 118 | # (1) run the encoder on the src 119 | encStates, context, emb = self.model.encoder(srcBatch) 120 | 121 | # Drop the lengths needed for encoder. 
122 | srcBatch = srcBatch[0] 123 | batchSize = self._getBatchSize(srcBatch) 124 | 125 | rnnSize = context.size(2) 126 | decoder = self.model.decoder 127 | attentionLayer = decoder.attn if hasattr(decoder, 'attn') else None 128 | 129 | if isinstance(self.model.encoder, Encoder): 130 | if isinstance(encStates, tuple): 131 | encStates = tuple(self.model.brnn_merge_concat(encStates[i]) 132 | for i in range(len(encStates))) 133 | else: 134 | encStates = self.model.brnn_merge_concat(encStates) 135 | if encStates.size(0) < decoder.layers: 136 | encStates = encStates.repeat(decoder.layers, 1, 1) 137 | else: 138 | encStates = Variable(encStates.data.new(*encStates.size()).zero_(), requires_grad=False) 139 | # encStates = encStates.unsqueeze(0).repeat(decoder.layers, 1, 1) 140 | 141 | 142 | useMasking = not isinstance(decoder, SGUDecoder) #self._type.endswith("text") 143 | 144 | # This mask is applied to the attention model inside the decoder 145 | # so that the attention ignores source padding 146 | padMask = None 147 | if useMasking: 148 | padMask = srcBatch.data.eq(onmt.Constants.PAD).t() 149 | 150 | def mask(padMask): 151 | if useMasking: 152 | attentionLayer.applyMask(padMask) 153 | 154 | # (2) if a target is specified, compute the 'goldScore' 155 | # (i.e. 
log likelihood) of the target under the model 156 | goldScores = context.data.new(batchSize).zero_() 157 | 158 | if tgtBatch is not None: 159 | 160 | decStates = encStates 161 | 162 | mask(padMask) 163 | initOutput = self.model.make_init_decoder_output(context) 164 | 165 | decOut, decStates, attn = self.model.decoder( 166 | tgtBatch[:-1], decStates, context, initOutput) 167 | for dec_t, tgt_t in zip(decOut, tgtBatch[1:].data): 168 | gen_t = self.model.generator.forward(dec_t) 169 | tgt_t = tgt_t.unsqueeze(1) 170 | scores = gen_t.data.gather(1, tgt_t) 171 | scores.masked_fill_(tgt_t.eq(onmt.Constants.PAD), 0) 172 | goldScores += scores 173 | 174 | # (3) run the decoder to generate sentences, using beam search 175 | 176 | # Expand tensors for each beam. 177 | context = Variable(context.data.repeat(1, beamSize, 1)) 178 | if isinstance(emb, PackedSequence): 179 | emb = Variable(unpack(emb)[0].data.repeat(1, beamSize, 1)) 180 | else: 181 | emb = Variable(emb.data.repeat(1, beamSize, 1)) 182 | 183 | if isinstance(encStates, tuple): 184 | decStates = tuple(Variable(encStates[i].data.repeat(1, beamSize, 1)) 185 | for i in range(len(encStates))) 186 | else: 187 | decStates = Variable(encStates.data.repeat(1, beamSize, 1)) 188 | 189 | beam = [onmt.Beam(beamSize, self.opt.cuda) for _ in range(batchSize)] 190 | 191 | decOut = self.model.make_init_decoder_output(context) 192 | 193 | if useMasking: 194 | padMask = srcBatch.data.eq( 195 | onmt.Constants.PAD).t() \ 196 | .unsqueeze(0) \ 197 | .repeat(beamSize, 1, 1) 198 | 199 | batchIdx = list(range(batchSize)) 200 | remainingSents = batchSize 201 | 202 | activs = [] 203 | for i in range(self.opt.max_sent_length): 204 | mask(padMask) 205 | # Prepare decoder input. 
206 | input = torch.stack([b.getCurrentState() for b in beam 207 | if not b.done]).t().contiguous().view(1, -1) 208 | 209 | #if self.model.decoder.log: 210 | # decOut, decStates, attn, activ = self.model.decoder( 211 | # Variable(input, volatile=True), decStates, context, decOut, emb) 212 | # activs.append(activ) 213 | #else: 214 | decOut, decStates, attn = self.model.decoder( 215 | Variable(input, volatile=True), decStates, context, decOut) 216 | 217 | # decOut: 1 x (beam*batch) x numWords 218 | decOut = decOut.squeeze(0) 219 | out = self.model.generator.forward(decOut) 220 | 221 | # batch x beam x numWords 222 | wordLk = out.view(beamSize, remainingSents, -1) \ 223 | .transpose(0, 1).contiguous() 224 | attn = attn.view(beamSize, remainingSents, -1) \ 225 | .transpose(0, 1).contiguous() 226 | 227 | active = [] 228 | for b in range(batchSize): 229 | if beam[b].done: 230 | continue 231 | 232 | idx = batchIdx[b] 233 | if not beam[b].advance(wordLk.data[idx], attn.data[idx]): 234 | active += [b] 235 | #print(decStates) 236 | if not isinstance(decStates, tuple): 237 | decStates = tuple(decStates.unsqueeze(0)) 238 | #print(decStates) 239 | for decState in decStates: # iterate over h, c 240 | # layers x beam*sent x dim 241 | sentStates = decState.view(-1, beamSize, 242 | remainingSents, 243 | decState.size(2))[:, :, idx] 244 | sentStates.data.copy_( 245 | sentStates.data.index_select( 246 | 1, beam[b].getCurrentOrigin())) 247 | 248 | if not active: 249 | break 250 | 251 | # in this section, the sentences that are still active are 252 | # compacted so that the decoder is not run on completed sentences 253 | activeIdx = self.tt.LongTensor([batchIdx[k] for k in active]) 254 | batchIdx = {beam: idx for idx, beam in enumerate(active)} 255 | 256 | def updateActive(t, lastSize=rnnSize): 257 | # select only the remaining active sentences 258 | view = t.data.view(-1, remainingSents, lastSize) 259 | newSize = list(t.size()) 260 | newSize[-2] = newSize[-2] * len(activeIdx) // 
remainingSents 261 | return Variable(view.index_select(1, activeIdx) 262 | .view(*newSize), volatile=True) 263 | 264 | decStates = tuple(updateActive(decStates[i]) 265 | for i in range(len(decStates))) 266 | 267 | if len(decStates) == 1: 268 | # The GRU needs only one matrix as hidden state 269 | decStates = decStates[0] 270 | 271 | decOut = updateActive(decOut) 272 | context = updateActive(context) 273 | emb = updateActive(emb, emb.size(2)) 274 | 275 | if useMasking: 276 | padMask = padMask.index_select(1, activeIdx) 277 | 278 | remainingSents = len(active) 279 | 280 | # (4) package everything up 281 | allHyp, allScores, allAttn = [], [], [] 282 | n_best = self.opt.n_best 283 | 284 | if activs: 285 | new_activs = torch.zeros((2, activs[0].size(1), len(activs))) 286 | for i, activ in enumerate(activs): 287 | new_activs[:, :activ.size(1), i] = activ.data 288 | activs = new_activs 289 | sys.stderr.write("r=\n") 290 | for i in range(activs.size(1)): 291 | for j in range(activs.size(2)): 292 | sys.stderr.write(str(activs[0][i][j]) + " ") 293 | sys.stderr.write("\n") 294 | sys.stderr.write("z=\n") 295 | for i in range(activs.size(1)): 296 | for j in range(activs.size(2)): 297 | sys.stderr.write(str(activs[1][i][j]) + " ") 298 | sys.stderr.write("\n") 299 | 300 | for b in range(batchSize): 301 | scores, ks = beam[b].sortBest() 302 | 303 | allScores += [scores[:n_best]] 304 | hyps, attn = zip(*[beam[b].getHyp(k) for k in ks[:n_best]]) 305 | allHyp += [hyps] 306 | if useMasking: 307 | valid_attn = srcBatch.data[:, b].ne(onmt.Constants.PAD) \ 308 | .nonzero().squeeze(1) 309 | attn = [a.index_select(1, valid_attn) for a in attn] 310 | allAttn += [attn] 311 | 312 | if self.beam_accum: 313 | self.beam_accum["beam_parent_ids"].append( 314 | [t.tolist() 315 | for t in beam[b].prevKs]) 316 | self.beam_accum["scores"].append([ 317 | ["%4f" % s for s in t.tolist()] 318 | for t in beam[b].allScores][1:]) 319 | self.beam_accum["predicted_ids"].append( 320 | 
[[self.tgt_dict.getLabel(id) 321 | for id in t.tolist()] 322 | for t in beam[b].nextYs][1:]) 323 | 324 | return allHyp, allScores, allAttn, goldScores 325 | 326 | def translate(self, srcBatch, goldBatch): 327 | # (1) convert words to indexes 328 | dataset = self.buildData(srcBatch, goldBatch) 329 | src, tgt, indices = dataset[0] 330 | batchSize = self._getBatchSize(src[0]) 331 | 332 | # (2) translate 333 | pred, predScore, attn, goldScore = self.translateBatch(src, tgt) 334 | pred, predScore, attn, goldScore = list(zip( 335 | *sorted(zip(pred, predScore, attn, goldScore, indices), 336 | key=lambda x: x[-1])))[:-1] 337 | 338 | # (3) convert indexes to words 339 | predBatch = [] 340 | for b in range(batchSize): 341 | predBatch.append( 342 | [self.buildTargetTokens(pred[b][n], srcBatch[b], attn[b][n]) 343 | for n in range(self.opt.n_best)] 344 | ) 345 | 346 | return predBatch, predScore, goldScore 347 | -------------------------------------------------------------------------------- /onmt/__init__.py: -------------------------------------------------------------------------------- 1 | import onmt.Constants 2 | import onmt.Models 3 | from onmt.Translator import Translator 4 | from onmt.Dataset import Dataset 5 | from onmt.Optim import Optim 6 | from onmt.Dict import Dict 7 | from onmt.Beam import Beam 8 | 9 | # For flake8 compatibility. 10 | __all__ = [onmt.Constants, onmt.Models, Translator, Dataset, Optim, Dict, Beam, Encoders] 11 | -------------------------------------------------------------------------------- /onmt/modules/Attention.py: -------------------------------------------------------------------------------- 1 | """ 2 | Global attention takes a matrix and a query vector. It 3 | then computes a parameterized convex combination of the matrix 4 | based on the input query. 5 | 6 | 7 | H_1 H_2 H_3 ... H_n 8 | q q q q 9 | | | | | 10 | \ | | / 11 | ..... 12 | \ | / 13 | a 14 | 15 | Constructs a unit mapping. 
16 | $$(H_1 + H_n, q) => (a)$$ 17 | Where H is of `batch x n x dim` and q is of `batch x dim`. 18 | 19 | The full definition is $$\tanh(W_2 [(softmax((W_1 q + b_1) H) H), q] + b_2)$$. 20 | 21 | """ 22 | 23 | import torch 24 | import torch.nn as nn 25 | import torch.nn.functional as F 26 | import numpy as np 27 | from .Normalization import LayerNorm 28 | 29 | def getAttention(attention_type): 30 | attns = {'dot': DotAttention, 31 | 'mlp': MLPAttentionGRU, 32 | } 33 | 34 | if attention_type not in attns: 35 | raise NotImplementedError(attention_type) 36 | 37 | return attns[attention_type] 38 | 39 | 40 | class DotAttention(nn.Module): 41 | def __init__(self, dim, enc_dim=None, layer_norm=False, activ='tanh'): 42 | super(DotAttention, self).__init__() 43 | self.mask = None 44 | if not enc_dim: 45 | enc_dim = dim 46 | out_dim = dim 47 | self.linear_in = nn.Linear(dim, out_dim, bias=False) 48 | self.layer_norm = layer_norm 49 | if self.layer_norm: 50 | self.ln_in = LayerNorm(dim) 51 | 52 | def applyMask(self, mask): 53 | self.mask = mask 54 | 55 | def initialize_parameters(self, param_init): 56 | pass 57 | 58 | def forward(self, input, context, values): 59 | """ 60 | input: targetL x batch x dim 61 | context: batch x sourceL x dim 62 | """ 63 | batch, sourceL, dim = context.size() 64 | targetT = self.linear_in(input.transpose(0, 1)) # batch x targetL x dim 65 | context = context.transpose(1, 2) # batch x dim x sourceL 66 | if self.layer_norm: targetT = self.ln_in(targetT) # ln_in exists only when layer_norm=True 67 | attn = torch.bmm(targetT, context) # batch x targetL x sourceL 68 | if self.mask is not None: 69 | attn.data.masked_fill_(self.mask, -float('inf')) 70 | attn = F.softmax(attn.view(-1, sourceL)) # (batch x targetL) x sourceL 71 | attn3 = attn.view(batch, -1, sourceL) # batch x targetL x sourceL 72 | weightedContext = torch.bmm(attn3, values).transpose(0, 1) # targetL x batch x dim 73 | 74 | return weightedContext, attn 75 | 76 | 77 | class MLPAttention(nn.Module): 78 | def __init__(self, dim, layer_norm=False, activ='tanh'):
79 | super(MLPAttention, self).__init__() 80 | self.dim = dim 81 | self.v = nn.Linear(self.dim, 1) 82 | self.combine_hid = nn.Linear(self.dim, self.dim) 83 | #self.combine_ctx = nn.Linear(self.dim, self.dim) 84 | self.mask = None 85 | self.activ = getattr(F, activ) 86 | self.layer_norm = layer_norm 87 | if layer_norm: 88 | #self.ctx_ln = LayerNorm(dim) 89 | self.hidden_ln = LayerNorm(dim) 90 | 91 | def applyMask(self, mask): 92 | self.mask = mask 93 | 94 | def initialize_parameters(self, param_init): 95 | pass 96 | 97 | 98 | def forward(self, input, context, values): 99 | """ 100 | input: targetL x batch x dim 101 | context: batch x sourceL x dim 102 | values: batch x sourceL x dim 103 | 104 | Output: 105 | 106 | output: batch x hidden_size 107 | w: batch x sourceL 108 | """ 109 | targetL = input.size(0) 110 | output_size = input.size(2) 111 | sourceL = context.size(1) 112 | batch_size = input.size(1) 113 | 114 | # targetL x batch x dim 115 | input = self.combine_hid(input) 116 | # (targetL x batch) x dim 117 | #context = self.combine_ctx(context) 118 | if self.layer_norm: 119 | input = self.hidden_ln(input) 120 | #context = self.ctx_ln(context) 121 | 122 | # batch x (sourceL x targetL) x dim 123 | context = context.repeat(1, targetL, 1) 124 | 125 | # batch x targetL x dim -> batch x (targetL x sourceL) x dim 126 | input = input.transpose(0, 1).repeat(1, 1, sourceL).contiguous().view(batch_size, -1, output_size) 127 | #context = context.view(batch_size, -1, output_size) 128 | # batch x (targetL x sourceL) x dim 129 | combined = self.activ(input + context) 130 | 131 | # batch x (targetL x sourceL) x 1 132 | attn = self.v(combined) 133 | 134 | # (batch_size x targetL) x sourceL 135 | attn = attn.contiguous().view(batch_size * targetL, sourceL) 136 | 137 | if self.mask is not None: 138 | attn.data.masked_fill_(self.mask, -float('inf')) 139 | 140 | # (batch_size x targetL) x sourceL 141 | attn = F.softmax(attn) 142 | 143 | # batch_size x targetL x sourceL 144 | attn3 = 
attn.contiguous().view(batch_size, targetL, sourceL) 145 | 146 | # batch x targetL x dim -> targetL x batch x dim 147 | weightedContext = torch.bmm(attn3, values).transpose(0, 1) 148 | 149 | return weightedContext, attn 150 | 151 | class MLPAttentionGRU(nn.Module): 152 | def __init__(self, dim, layer_norm=False, activ='tanh'): 153 | super(MLPAttentionGRU, self).__init__() 154 | self.dim = dim 155 | self.v = nn.Linear(self.dim, 1) 156 | self.combine_hid = nn.Linear(self.dim, self.dim) 157 | # self.combine_ctx = nn.Linear(self.dim, self.dim) 158 | self.mask = None 159 | self.activ = getattr(F, activ) 160 | self.layer_norm = layer_norm 161 | if layer_norm: 162 | # self.ctx_ln = LayerNorm(dim) 163 | self.hidden_ln = LayerNorm(dim) 164 | 165 | def applyMask(self, mask): 166 | self.mask = mask 167 | 168 | def initialize_parameters(self, param_init): 169 | pass 170 | 171 | def forward(self, input, context, values): 172 | """ 173 | input: batch x dim 174 | context: batch x sourceL x dim 175 | values: batch x sourceL x dim 176 | 177 | Output: 178 | 179 | output: batch x hidden_size 180 | w: batch x sourceL 181 | """ 182 | sourceL = context.size(1) 183 | batch_size = input.size(0) 184 | 185 | # batch x dim 186 | input = self.combine_hid(input) 187 | 188 | if self.layer_norm: 189 | input = self.hidden_ln(input) 190 | 191 | # batch x sourceL x dim 192 | input = input.unsqueeze(1).expand_as(context) 193 | # batch x sourceL x dim 194 | combined = self.activ(input + context) 195 | 196 | # batch x sourceL x 1 197 | attn = self.v(combined) 198 | 199 | # batch_size x sourceL 200 | attn = attn.view(batch_size, sourceL) 201 | 202 | if self.mask is not None: 203 | attn.data.masked_fill_(self.mask, -float('inf')) 204 | 205 | # batch_size x sourceL 206 | attn = F.softmax(attn) 207 | 208 | # batch_size x 1 x sourceL 209 | attn3 = attn.unsqueeze(1) 210 | 211 | # batch x dim 212 | weightedContext = torch.bmm(attn3, values).squeeze(1) 213 | 214 | return weightedContext, attn 215 | 216 | 217 
| 218 | class SelfAttention(nn.Module): 219 | def __init__(self, k_size, q_size, v_size, out_size): 220 | super(SelfAttention, self).__init__() 221 | self.linearK = nn.Linear(k_size, out_size) 222 | self.linearQ = nn.Linear(q_size, out_size) 223 | self.linearV = nn.Linear(v_size, out_size) 224 | self.dim = out_size 225 | self.mask = None 226 | 227 | def applyMask(self, mask): 228 | self.mask = mask 229 | 230 | def forward(self, input, context, values): 231 | """ 232 | input: batch x targetL x dim 233 | context: batch x sourceL x dim 234 | values: batch x sourceL x dim 235 | """ 236 | K = self.linearK(input) # batch x targetL x out_size 237 | Q = self.linearQ(context) # batch x sourceL x out_size 238 | V = self.linearV(values) # batch x sourceL x out_size 239 | 240 | dot_prod = K.bmm(Q.transpose(1, 2)) * (1 / np.sqrt(self.dim)) # batch x targetL x sourceL 241 | 242 | attn = dot_prod.sum(dim=1, keepdim=False) # batch x sourceL 243 | if self.mask is not None: 244 | attn.data.masked_fill_(self.mask, -float('inf')) 245 | attn = F.softmax(attn) # batch x sourceL 246 | attn3 = attn.unsqueeze(2) # batch x sourceL x 1 247 | weightedContext = V * attn3 # batch x sourceL x out_size 248 | 249 | return weightedContext, attn 250 | -------------------------------------------------------------------------------- /onmt/modules/Gate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Context gate is a decoder module that takes as input the previous word 3 | embedding, the current decoder state and the attention state, and produces a 4 | gate. 5 | The gate can be used to select the input from the target side context 6 | (decoder state), from the source context (attention state) or both.
7 | """ 8 | import torch 9 | import torch.nn as nn 10 | 11 | 12 | def ContextGateFactory(type, embeddings_size, decoder_size, 13 | attention_size, output_size): 14 | """Returns the correct ContextGate class""" 15 | 16 | gate_types = {'source': SourceContextGate, 17 | 'target': TargetContextGate, 18 | 'both': BothContextGate} 19 | 20 | assert type in gate_types, "Not valid ContextGate type: {0}".format(type) 21 | return gate_types[type](embeddings_size, decoder_size, attention_size, 22 | output_size) 23 | 24 | 25 | class ContextGate(nn.Module): 26 | """Implement up to the computation of the gate""" 27 | 28 | def __init__(self, embeddings_size, decoder_size, 29 | attention_size, output_size): 30 | super(ContextGate, self).__init__() 31 | input_size = embeddings_size + decoder_size + attention_size 32 | self.gate = nn.Linear(input_size, output_size, bias=True) 33 | self.sig = nn.Sigmoid() 34 | self.source_proj = nn.Linear(attention_size, output_size) 35 | self.target_proj = nn.Linear(embeddings_size + decoder_size, 36 | output_size) 37 | 38 | def forward(self, prev_emb, dec_state, attn_state): 39 | input_tensor = torch.cat((prev_emb, dec_state, attn_state), dim=1) 40 | z = self.sig(self.gate(input_tensor)) 41 | proj_source = self.source_proj(attn_state) 42 | 43 | proj_target = self.target_proj( 44 | torch.cat((prev_emb, dec_state), dim=1)) 45 | 46 | return z, proj_source, proj_target 47 | 48 | 49 | class SourceContextGate(nn.Module): 50 | """Apply the context gate only to the source context""" 51 | 52 | def __init__(self, embeddings_size, decoder_size, 53 | attention_size, output_size): 54 | super(SourceContextGate, self).__init__() 55 | self.context_gate = ContextGate(embeddings_size, decoder_size, 56 | attention_size, output_size) 57 | self.tanh = nn.Tanh() 58 | 59 | def forward(self, prev_emb, dec_state, attn_state): 60 | z, source, target = self.context_gate( 61 | prev_emb, dec_state, attn_state) 62 | return target + z * source 63 | 64 | 65 | class 
TargetContextGate(nn.Module): 66 | """Apply the context gate only to the target context""" 67 | 68 | def __init__(self, embeddings_size, decoder_size, 69 | attention_size, output_size): 70 | super(TargetContextGate, self).__init__() 71 | self.context_gate = ContextGate(embeddings_size, decoder_size, 72 | attention_size, output_size) 73 | self.tanh = nn.Tanh() 74 | 75 | def forward(self, prev_emb, dec_state, attn_state): 76 | z, source, target = self.context_gate(prev_emb, dec_state, attn_state) 77 | return z * target + source 78 | 79 | 80 | class BothContextGate(nn.Module): 81 | """Apply the context gate to both contexts""" 82 | 83 | def __init__(self, embeddings_size, decoder_size, 84 | attention_size, output_size): 85 | super(BothContextGate, self).__init__() 86 | self.context_gate = ContextGate(embeddings_size, decoder_size, 87 | attention_size, output_size) 88 | self.tanh = nn.Tanh() 89 | 90 | def forward(self, prev_emb, dec_state, attn_state): 91 | z, source, target = self.context_gate(prev_emb, dec_state, attn_state) 92 | return (1. 
- z) * target + z * source 93 | -------------------------------------------------------------------------------- /onmt/modules/ImageEncoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | import torch.cuda 5 | from torch.autograd import Variable 6 | 7 | 8 | class ImageEncoder(nn.Module): 9 | def __init__(self, opt): 10 | super(ImageEncoder, self).__init__() 11 | self.layers = opt.layers 12 | self.num_directions = 2 if opt.brnn else 1 13 | self.hidden_size = opt.rnn_size 14 | 15 | self.layer1 = nn.Conv2d(3, 64, kernel_size=(3, 3), 16 | padding=(1, 1), stride=(1, 1)) 17 | self.layer2 = nn.Conv2d(64, 128, kernel_size=(3, 3), 18 | padding=(1, 1), stride=(1, 1)) 19 | self.layer3 = nn.Conv2d(128, 256, kernel_size=(3, 3), 20 | padding=(1, 1), stride=(1, 1)) 21 | self.layer4 = nn.Conv2d(256, 256, kernel_size=(3, 3), 22 | padding=(1, 1), stride=(1, 1)) 23 | self.layer5 = nn.Conv2d(256, 512, kernel_size=(3, 3), 24 | padding=(1, 1), stride=(1, 1)) 25 | self.layer6 = nn.Conv2d(512, 512, kernel_size=(3, 3), 26 | padding=(1, 1), stride=(1, 1)) 27 | 28 | self.batch_norm1 = nn.BatchNorm2d(256) 29 | self.batch_norm2 = nn.BatchNorm2d(512) 30 | self.batch_norm3 = nn.BatchNorm2d(512) 31 | 32 | input_size = 512 33 | self.rnn = nn.LSTM(input_size, opt.rnn_size, 34 | num_layers=opt.layers, 35 | dropout=opt.dropout, 36 | bidirectional=opt.brnn) 37 | self.pos_lut = nn.Embedding(1000, input_size) 38 | 39 | def load_pretrained_vectors(self, opt): 40 | pass 41 | 42 | def forward(self, input): 43 | input = input[0] 44 | batchSize = input.size(0) 45 | # (batch_size, 64, imgH, imgW) 46 | # layer 1 47 | input = F.relu(self.layer1(input[:, :, :, :]-0.5), True) 48 | 49 | # (batch_size, 64, imgH/2, imgW/2) 50 | input = F.max_pool2d(input, kernel_size=(2, 2), stride=(2, 2)) 51 | 52 | # (batch_size, 128, imgH/2, imgW/2) 53 | # layer 2 54 | input = F.relu(self.layer2(input), True) 55 
| 56 | # (batch_size, 128, imgH/2/2, imgW/2/2) 57 | input = F.max_pool2d(input, kernel_size=(2, 2), stride=(2, 2)) 58 | 59 | # (batch_size, 256, imgH/2/2, imgW/2/2) 60 | # layer 3 61 | # batch norm 1 62 | input = F.relu(self.batch_norm1(self.layer3(input)), True) 63 | 64 | # (batch_size, 256, imgH/2/2, imgW/2/2) 65 | # layer4 66 | input = F.relu(self.layer4(input), True) 67 | 68 | # (batch_size, 256, imgH/2/2, imgW/2/2/2) 69 | input = F.max_pool2d(input, kernel_size=(1, 2), stride=(1, 2)) 70 | 71 | # (batch_size, 512, imgH/2/2, imgW/2/2/2) 72 | # layer 5 73 | # batch norm 2 74 | input = F.relu(self.batch_norm2(self.layer5(input)), True) 75 | 76 | # (batch_size, 512, imgH/2/2/2, imgW/2/2/2) 77 | input = F.max_pool2d(input, kernel_size=(2, 1), stride=(2, 1)) 78 | 79 | # (batch_size, 512, imgH/2/2/2, imgW/2/2/2) 80 | input = F.relu(self.batch_norm3(self.layer6(input)), True) 81 | 82 | # # (batch_size, 512, H, W) 83 | # # (batch_size, H, W, 512) 84 | all_outputs = [] 85 | for row in range(input.size(2)): 86 | inp = input[:, :, row, :].transpose(0, 2)\ 87 | .transpose(1, 2) 88 | pos_emb = self.pos_lut( 89 | Variable(input.data.new(batchSize).long().fill_(row))) 90 | with_pos = torch.cat( 91 | (pos_emb.view(1, pos_emb.size(0), pos_emb.size(1)), inp), 0) 92 | outputs, hidden_t = self.rnn(with_pos) 93 | all_outputs.append(outputs) 94 | out = torch.cat(all_outputs, 0) 95 | 96 | return hidden_t, out 97 | -------------------------------------------------------------------------------- /onmt/modules/Normalization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class LayerNorm(nn.Module): 5 | 6 | def __init__(self, features, eps=1e-6): 7 | super().__init__() 8 | self.gamma = nn.Parameter(torch.ones(features)) 9 | self.beta = nn.Parameter(torch.zeros(features)) 10 | self.eps = eps 11 | 12 | def forward(self, x): 13 | mean = x.mean(-1, keepdim=True) 14 | std = x.std(-1, keepdim=True) 15 |
return (self.gamma / (std + self.eps)) * (x - mean) + self.beta 16 | 17 | def initialize_parameters(self, param_init): 18 | self.gamma.data.fill_(1.) 19 | self.beta.data.fill_(0.) -------------------------------------------------------------------------------- /onmt/modules/SRU_units.py: -------------------------------------------------------------------------------- 1 | """ 2 | Weakly-recurrent units (SRU variants) used by SR-NMT. 3 | AttSRU: decoder layer that combines its recurrent state with dot-product 4 | attention over the encoder output. 5 | BiSRU: bidirectional encoder layer. 6 | SRU: unidirectional layer with a highway connection to its input. 7 | """ 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from .Normalization import LayerNorm 12 | from torch.autograd import Variable 13 | import numpy as np 14 | 15 | 16 | class AttSRU(nn.Module): 17 | 18 | def __init__(self, input_size, attention_size, output_size, layer_norm, dropout): 19 | from .Attention import DotAttention 20 | super(AttSRU, self).__init__() 21 | self.linear_in = nn.Linear(input_size, 3*output_size, bias=(not layer_norm)) 22 | self.linear_hidden = nn.Linear(output_size, output_size, bias=(not layer_norm)) 23 | self.linear_ctx = nn.Linear(output_size, output_size, bias=(not layer_norm)) 24 | self.linear_enc = nn.Linear(output_size, output_size, bias=(not layer_norm)) 25 | self.output_size = output_size 26 | self.attn = DotAttention(attention_size, layer_norm=True) 27 | self.layer_norm = layer_norm 28 | self.dropout = nn.Dropout(dropout) 29 | if self.layer_norm: 30 | self.preact_ln = LayerNorm(3 * output_size) 31 | self.enc_ln = LayerNorm(output_size) 32 | 33 | self.trans_h_ln = LayerNorm(output_size) 34 | self.trans_c_ln = LayerNorm(output_size) 35 | 36 | def initialize_parameters(self, param_init): 37 | self.preact_ln.initialize_parameters(param_init) 38 | self.trans_h_ln.initialize_parameters(param_init) 39 |
self.trans_c_ln.initialize_parameters(param_init) 40 | self.enc_ln.initialize_parameters(param_init) 41 | 42 | def forward(self, prev_layer, hidden, enc_output): 43 | """ 44 | :param prev_layer: targetL x batch x output_size 45 | :param hidden: batch x output_size 46 | :param enc_output: (targetL x batch) x sourceL x output_size 47 | :return: 48 | """ 49 | 50 | # targetL x batch x output_size 51 | preact = self.linear_in(self.dropout(prev_layer)) 52 | pctx = self.linear_enc(self.dropout(enc_output)) 53 | if self.layer_norm: 54 | preact = self.preact_ln(preact) 55 | pctx = self.enc_ln(pctx) 56 | #z = self.z_ln(z) 57 | #prev_layer_t = self.prev_layer_ln(prev_layer_t) 58 | #h_gate = self.h_gate_ln(h_gate) 59 | z, h_gate, prev_layer_t = preact.split(self.output_size, dim=-1) 60 | z, h_gate = F.sigmoid(z), F.sigmoid(h_gate) 61 | 62 | ss = [] 63 | for i in range(prev_layer.size(0)): 64 | s = (1. - z[i]) * hidden + z[i] * prev_layer_t[i] 65 | # targetL x batch x output_size 66 | ss += [s] 67 | # batch x output_size 68 | hidden = s 69 | 70 | # (targetL x batch) x output_size 71 | ss = torch.stack(ss) 72 | attn_out, attn = self.attn(self.dropout(ss), pctx, pctx) 73 | attn_out = attn_out / np.sqrt(self.output_size) 74 | 75 | trans_h = self.linear_hidden(self.dropout(ss)) 76 | trans_c = self.linear_ctx(self.dropout(attn_out)) 77 | if self.layer_norm: 78 | #out = self.post_ln(out) 79 | trans_h = self.trans_h_ln(trans_h) 80 | trans_c = self.trans_c_ln(trans_c) 81 | #trans_h, trans_c = F.tanh(trans_h), F.tanh(trans_c) 82 | out = trans_h + trans_c 83 | out = F.tanh(out) 84 | out = out.view(prev_layer.size()) 85 | out = (1. 
- h_gate) * out + h_gate * prev_layer 86 | 87 | return out, hidden, attn 88 | 89 | class BiSRU(nn.Module): 90 | 91 | def __init__(self, input_size, output_size, layer_norm, dropout): 92 | super(BiSRU, self).__init__() 93 | self.input_linear = nn.Linear(input_size, 3*output_size, bias=(not layer_norm)) 94 | self.layer_norm = layer_norm 95 | self.output_size = output_size 96 | self.dropout = nn.Dropout(dropout) 97 | if self.layer_norm: 98 | self.preact_ln = LayerNorm(3 * output_size) 99 | #self.x_f_ln = LayerNorm(output_size // 2) 100 | #self.x_b_ln = LayerNorm(output_size // 2) 101 | #self.f_g_ln = LayerNorm(output_size // 2) 102 | #self.b_g_ln = LayerNorm(output_size // 2) 103 | #self.highway_ln = LayerNorm(output_size) 104 | 105 | def initialize_parameters(self, param_init): 106 | self.preact_ln.initialize_parameters(param_init) 107 | 108 | def forward(self, input): 109 | pre_act = self.input_linear(self.dropout(input)) 110 | #h_gate = pre_act[:, :, 2*self.output_size:] 111 | #gf, gb, x_f, x_b = pre_act[:, :, :2*self.output_size].split(self.output_size // 2, dim=-1) 112 | if self.layer_norm: 113 | pre_act = self.preact_ln(pre_act) 114 | #x_f = self.x_f_ln(x_f) 115 | #x_b = self.x_b_ln(x_b) 116 | #gf = self.f_g_ln(gf) 117 | #gb = self.b_g_ln(gb) 118 | #h_gate = self.highway_ln(h_gate) 119 | h_gate = pre_act[:, :, 2*self.output_size:] 120 | g, x = pre_act[:, :, :2*self.output_size].split(self.output_size, dim=-1) 121 | gf, gb = F.sigmoid(g).split(self.output_size // 2, dim=-1) 122 | x_f, x_b = x.split(self.output_size // 2, dim=-1) 123 | h_gate = F.sigmoid(h_gate) 124 | h_f_pre = gf * x_f 125 | h_b_pre = gb * x_b 126 | 127 | h_i_f = Variable(h_f_pre.data.new(gf[0].size()).zero_(), requires_grad=False) 128 | h_i_b = Variable(h_f_pre.data.new(gf[0].size()).zero_(), requires_grad=False) 129 | 130 | h_f, h_b = [], [] 131 | for i in range(input.size(0)): 132 | h_i_f = (1. - gf[i]) * h_i_f + h_f_pre[i] 133 | h_i_b = (1. 
- gb[-(i+1)]) * h_i_b + h_b_pre[-(i+1)] 134 | h_f += [h_i_f] 135 | h_b += [h_i_b] 136 | 137 | h = torch.cat([torch.stack(h_f), torch.stack(h_b[::-1])], dim=-1) 138 | 139 | output = (1. - h_gate) * h + input * h_gate 140 | 141 | return output 142 | 143 | 144 | class SRU(nn.Module): 145 | def __init__(self, input_size, output_size, dropout): 146 | super(SRU, self).__init__() 147 | self.linear_in = nn.Linear(input_size, 3 * output_size) 148 | if input_size != output_size: 149 | self.reduce = nn.Linear(input_size, output_size) 150 | self.input_size = input_size 151 | self.output_size = output_size 152 | self.dropout = nn.Dropout(dropout) 153 | 154 | def initialize_parameters(self, param_init): 155 | pass 156 | 157 | def forward(self, prev_layer, hidden): 158 | """ 159 | :param prev_layer: targetL x batch x output_size 160 | :param hidden: batch x output_size 161 | :return: 162 | """ 163 | 164 | # targetL x batch x output_size 165 | preact = self.linear_in(self.dropout(prev_layer)) 166 | 167 | prev_layer_t = preact[:, :, :self.output_size] 168 | z, h_gate = F.sigmoid(preact[:, :, self.output_size:]).split(self.output_size, dim=-1) 169 | 170 | ss = [] 171 | for i in range(prev_layer.size(0)): 172 | s = (1 - z[i]) * hidden + z[i] * prev_layer_t[i] 173 | # targetL x batch x output_size 174 | ss += [s] 175 | # batch x output_size 176 | hidden = s 177 | 178 | # (targetL x batch) x output_size 179 | out = torch.stack(ss) 180 | if self.input_size != self.output_size: 181 | prev_layer = self.reduce(self.dropout(prev_layer)) 182 | 183 | out = (1. 
- h_gate) * out + h_gate * prev_layer 184 | 185 | return out, hidden 186 | -------------------------------------------------------------------------------- /onmt/modules/Units.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | 5 | from torch.nn.utils.rnn import PackedSequence 6 | from torch.nn.utils.rnn import pad_packed_sequence as unpack 7 | from torch.nn.utils.rnn import pack_padded_sequence as pack 8 | from onmt.modules import SRU 9 | 10 | import math 11 | 12 | class ParallelMyRNN(nn.Module): 13 | def __init__(self, input_size, hidden_size, 14 | num_layers=1, dropout=0, bidirectional=False): 15 | super(ParallelMyRNN, self).__init__() 16 | self.unit = SRU 17 | self.input_size = input_size 18 | self.rnn_size = hidden_size 19 | self.num_layers = num_layers 20 | self.dropout = dropout 21 | self.Dropout = nn.Dropout(dropout) 22 | self.bidirectional = bidirectional 23 | self.num_directions = 2 if bidirectional else 1 24 | self.hidden_size = self.rnn_size * self.num_directions 25 | self.rnns = nn.ModuleList([nn.ModuleList() for _ in range(self.num_directions)]) 26 | 27 | # for layer in range(num_layers): 28 | for layer in range(num_layers): 29 | layer_input_size = input_size if layer == 0 else self.hidden_size 30 | for direction in range(self.num_directions): 31 | self.rnns[direction].append(self.unit(layer_input_size, self.rnn_size, self.dropout)) 32 | 33 | def reset_parameters(self): 34 | stdv = 1.0 / math.sqrt(self.rnn_size) 35 | for weight in self.parameters(): 36 | weight.data.uniform_(-stdv, stdv) 37 | 38 | def initialize_parameters(self, param_init): 39 | for direction in range(self.num_directions): 40 | for layer in self.rnns[direction]: 41 | layer.initialize_parameters(param_init) 42 | 43 | def reverse_tensor(self, x, dim): 44 | idx = [i for i in range(x.size(dim) - 1, -1, -1)] 45 | idx = Variable(torch.LongTensor(idx)) 46 | if x.is_cuda: 
47 | idx = idx.cuda() 48 | return x.index_select(dim, idx) 49 | 50 | def forward(self, input, hidden=None): 51 | 52 | is_packed = isinstance(input, PackedSequence) 53 | if is_packed: 54 | input, batch_sizes = unpack(input) 55 | max_batch_size = batch_sizes[0] 56 | 57 | if hidden is None: 58 | # (num_layers x num_directions) x batch_size x rnn_size 59 | hidden = Variable(input.data.new(self.num_layers * 60 | self.num_directions, 61 | input.size(1), 62 | self.rnn_size).zero_(), requires_grad=False) 63 | if input.is_cuda: 64 | hidden = hidden.cuda() 65 | 66 | gru_out = [] 67 | _input = input 68 | for i in range(self.num_layers): 69 | if not self.bidirectional: 70 | prev_layer = self.Dropout(_input) 71 | h = hidden[i] # batch_size x rnn_size 72 | unit = self.rnns[0][i] # Computation unit 73 | 74 | layer_out, hid_uni = unit(prev_layer, h) # src_len x batch x hidden_size 75 | 76 | else: 77 | input_forward = self.Dropout(_input) 78 | input_backward = self.Dropout(_input) 79 | h_forward = hidden[i * self.num_directions] # batch_size x rnn_size 80 | h_backward = hidden[i * self.num_directions + 1] # batch_size x rnn_size 81 | unit_forward = self.rnns[0][i] # Computation unit 82 | unit_backward = self.rnns[1][i] # Computation unit 83 | 84 | output_forward, h_forward = unit_forward(input_forward, h_forward) 85 | output_backward, h_backward = self.compute_backwards(unit_backward, input_backward, h_backward) 86 | 87 | layer_out = torch.cat([output_forward, output_backward], dim=2) # src_len x batch x hidden_size 88 | 89 | _input = layer_out 90 | 91 | if self.bidirectional: 92 | gru_out.append(output_forward[-1].unsqueeze(0)) 93 | gru_out.append(output_backward[-1].unsqueeze(0)) # num_directions x [batch x rnn_size] 94 | else: 95 | gru_out.append(layer_out[-1].unsqueeze(0)) # 1 x batch x rnn_size (last time step) 96 | 97 | hidden = torch.cat(gru_out, dim=0) # (num_layers x num_directions) x batch x rnn_size 98 | 99 | output = _input 100 | 101 | return output, hidden 102 | 103 | def __repr__(self): 104 | s = '{name}({input_size},
{rnn_size}' 105 | if self.num_layers != 1: 106 | s += ', num_layers={num_layers}' 107 | if self.dropout != 0: 108 | s += ', dropout={dropout}' 109 | if self.bidirectional is not False: 110 | s += ', bidirectional={bidirectional}' 111 | s += ')' 112 | return s.format(name=self.__class__.__name__, **self.__dict__) 113 | 114 | def compute_backwards(self, unit, input, hidden): 115 | h = hidden 116 | steps = torch.cat(input.split(1, dim=0)[::-1], dim=0) 117 | out, hidden = unit(steps, h) 118 | out = torch.cat(out.split(1, dim=0)[::-1], dim=0) 119 | return out, hidden 120 | -------------------------------------------------------------------------------- /onmt/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .SRU_units import SRU 2 | from .Units import ParallelMyRNN -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import onmt 2 | import onmt.Markdown 3 | import argparse 4 | import torch 5 | 6 | 7 | def loadImageLibs(): 8 | "Conditional import of torch image libs." 9 | global Image, transforms 10 | from PIL import Image 11 | from torchvision import transforms 12 | 13 | 14 | parser = argparse.ArgumentParser(description='preprocess.py') 15 | onmt.Markdown.add_md_help_argument(parser) 16 | 17 | # **Preprocess Options** 18 | 19 | parser.add_argument('-config', help="Read options from this file") 20 | 21 | parser.add_argument('-src_type', default="bitext", 22 | choices=["bitext", "monotext", "img"], 23 | help="""Type of the source input. 
24 | This affects all the subsequent operations. 25 | Options are [bitext|monotext|img].""") 26 | parser.add_argument('-src_img_dir', default=".", 27 | help="Location of source images") 28 | 29 | 30 | parser.add_argument('-train', 31 | help="""Path to the monolingual training data""") 32 | parser.add_argument('-train_src', required=False, 33 | help="Path to the training source data") 34 | parser.add_argument('-train_tgt', required=False, 35 | help="Path to the training target data") 36 | parser.add_argument('-valid', 37 | help="""Path to the monolingual validation data""") 38 | parser.add_argument('-valid_src', required=False, 39 | help="Path to the validation source data") 40 | parser.add_argument('-valid_tgt', required=False, 41 | help="Path to the validation target data") 42 | 43 | parser.add_argument('-save_data', required=True, 44 | help="Output file for the prepared data") 45 | 46 | parser.add_argument('-src_vocab_size', type=int, default=50000, 47 | help="Size of the source vocabulary") 48 | parser.add_argument('-tgt_vocab_size', type=int, default=50000, 49 | help="Size of the target vocabulary") 50 | parser.add_argument('-src_vocab', 51 | help="Path to an existing source vocabulary") 52 | parser.add_argument('-tgt_vocab', 53 | help="Path to an existing target vocabulary") 54 | 55 | parser.add_argument('-src_seq_length', type=int, default=50, 56 | help="Maximum source sequence length to keep.") 57 | parser.add_argument('-src_seq_length_trunc', type=int, default=0, 58 | help="Truncate source sequences to this length (0 disables truncation).") 59 | parser.add_argument('-tgt_seq_length', type=int, default=50, 60 | help="Maximum target sequence length to keep.") 61 | parser.add_argument('-tgt_seq_length_trunc', type=int, default=0, 62 | help="Truncate target sequences to this length (0 disables truncation).") 63 | 64 | parser.add_argument('-shuffle', type=int, default=1, 65 | help="Shuffle the data before sorting it by length (1 = shuffle)") 66 | parser.add_argument('-seed', type=int, default=3435, 67 | help="Random seed") 68 | 69 | parser.add_argument('-lower',
action='store_true', help='lowercase data') 70 | 71 | parser.add_argument('-report_every', type=int, default=100000, 72 | help="Report status every this many sentences") 73 | 74 | opt = parser.parse_args() 75 | 76 | torch.manual_seed(opt.seed) 77 | 78 | 79 | def makeVocabulary(filename, size): 80 | vocab = onmt.Dict([onmt.Constants.PAD_WORD, onmt.Constants.UNK_WORD, 81 | onmt.Constants.BOS_WORD, onmt.Constants.EOS_WORD], 82 | lower=opt.lower) 83 | 84 | with open(filename) as f: 85 | for sent in f.readlines(): 86 | for word in sent.split(): 87 | vocab.add(word) 88 | 89 | originalSize = vocab.size() 90 | vocab = vocab.prune(size) 91 | print('Created dictionary of size %d (pruned from %d)' % 92 | (vocab.size(), originalSize)) 93 | 94 | return vocab 95 | 96 | 97 | def initVocabulary(name, dataFile, vocabFile, vocabSize): 98 | 99 | vocab = None 100 | if vocabFile is not None: 101 | # If given, load existing word dictionary. 102 | print('Reading ' + name + ' vocabulary from \'' + vocabFile + '\'...') 103 | vocab = onmt.Dict() 104 | vocab.loadFile(vocabFile) 105 | print('Loaded ' + str(vocab.size()) + ' ' + name + ' words') 106 | 107 | if vocab is None: 108 | # If a dictionary is still missing, generate it. 109 | print('Building ' + name + ' vocabulary...') 110 | genWordVocab = makeVocabulary(dataFile, vocabSize) 111 | 112 | vocab = genWordVocab 113 | 114 | print() 115 | return vocab 116 | 117 | 118 | def saveVocabulary(name, vocab, file): 119 | print('Saving ' + name + ' vocabulary to \'' + file + '\'...') 120 | vocab.writeFile(file) 121 | 122 | 123 | def makeBilingualData(srcFile, tgtFile, srcDicts, tgtDicts): 124 | src, tgt = [], [] 125 | sizes = [] 126 | count, ignored = 0, 0 127 | 128 | print('Processing %s & %s ...' 
% (srcFile, tgtFile)) 129 | srcF = open(srcFile) 130 | tgtF = open(tgtFile) 131 | 132 | while True: 133 | sline = srcF.readline() 134 | tline = tgtF.readline() 135 | 136 | # Normal end of file. 137 | if sline == "" and tline == "": 138 | break 139 | 140 | # Source and target do not have the same number of lines. 141 | if sline == "" or tline == "": 142 | print('WARNING: src and tgt do not have the same # of sentences') 143 | break 144 | 145 | sline = sline.strip() 146 | tline = tline.strip() 147 | 148 | # Skip pairs where the source or target line is empty. 149 | if sline == "" or tline == "": 150 | print('WARNING: ignoring an empty line ('+str(count+1)+')') 151 | continue 152 | 153 | srcWords = sline.split() 154 | tgtWords = tline.split() 155 | 156 | if len(srcWords) <= opt.src_seq_length \ 157 | and len(tgtWords) <= opt.tgt_seq_length: 158 | 159 | # Check truncation condition. 160 | if opt.src_seq_length_trunc != 0: 161 | srcWords = srcWords[:opt.src_seq_length_trunc] 162 | if opt.tgt_seq_length_trunc != 0: 163 | tgtWords = tgtWords[:opt.tgt_seq_length_trunc] 164 | 165 | if opt.src_type == "bitext": 166 | src += [srcDicts.convertToIdx(srcWords, 167 | onmt.Constants.UNK_WORD)] 168 | elif opt.src_type == "img": 169 | loadImageLibs() 170 | src += [transforms.ToTensor()( 171 | Image.open(opt.src_img_dir + "/" + srcWords[0]))] 172 | 173 | tgt += [tgtDicts.convertToIdx(tgtWords, 174 | onmt.Constants.UNK_WORD, 175 | onmt.Constants.BOS_WORD, 176 | onmt.Constants.EOS_WORD)] 177 | sizes += [len(srcWords)] 178 | else: 179 | ignored += 1 180 | 181 | count += 1 182 | 183 | if count % opt.report_every == 0: 184 | print('... %d sentences prepared' % count) 185 | 186 | srcF.close() 187 | tgtF.close() 188 | 189 | if opt.shuffle == 1: 190 | print('... shuffling sentences') 191 | perm = torch.randperm(len(src)) 192 | src = [src[idx] for idx in perm] 193 | tgt = [tgt[idx] for idx in perm] 194 | sizes = [sizes[idx] for idx in perm] 195 | 196 | print('...
sorting sentences by size') 197 | _, perm = torch.sort(torch.Tensor(sizes)) 198 | src = [src[idx] for idx in perm] 199 | tgt = [tgt[idx] for idx in perm] 200 | 201 | print(('Prepared %d sentences ' + 202 | '(%d ignored because src len > %d or tgt len > %d)') % 203 | (len(src), ignored, opt.src_seq_length, opt.tgt_seq_length)) 204 | 205 | return src, tgt 206 | 207 | 208 | def makeMonolingualData(srcFile, srcDicts): 209 | src = [] 210 | sizes = [] 211 | count, ignored = 0, 0 212 | 213 | print('Processing %s ...' % (srcFile)) 214 | 215 | with open(srcFile) as srcF: 216 | for sline in srcF: 217 | sline = sline.strip() 218 | 219 | # Skip empty lines. 220 | if sline == "": 221 | print('WARNING: ignoring an empty line ('+str(count+1)+')') 222 | continue 223 | 224 | srcWords = sline.split() 225 | 226 | if len(srcWords) <= opt.src_seq_length: 227 | 228 | # Check truncation condition. 229 | if opt.src_seq_length_trunc != 0: 230 | srcWords = srcWords[:opt.src_seq_length_trunc] 231 | 232 | src += [srcDicts.convertToIdx(srcWords, 233 | onmt.Constants.UNK_WORD, 234 | onmt.Constants.BOS_WORD, 235 | onmt.Constants.EOS_WORD)] 236 | sizes += [len(srcWords)] 237 | else: 238 | ignored += 1 239 | 240 | count += 1 241 | 242 | if count % opt.report_every == 0: 243 | print('... %d sentences prepared' % count) 244 | 245 | if opt.shuffle == 1: 246 | print('... shuffling sentences') 247 | perm = torch.randperm(len(src)) 248 | src = [src[idx] for idx in perm] 249 | sizes = [sizes[idx] for idx in perm] 250 | 251 | print('...
sorting sentences by size') 252 | _, perm = torch.sort(torch.Tensor(sizes)) 253 | src = [src[idx] for idx in perm] 254 | 255 | print(('Prepared %d sentences ' + 256 | '(%d ignored because src len > %d)') % 257 | (len(src), ignored, opt.src_seq_length)) 258 | 259 | return src 260 | 261 | 262 | def main(): 263 | 264 | if opt.src_type in ['bitext', 'img']: 265 | assert None not in [opt.train_src, opt.train_tgt, 266 | opt.valid_src, opt.valid_tgt], \ 267 | "With source type %s the following parameters are " \ 268 | "required: -train_src, -train_tgt, " \ 269 | "-valid_src, -valid_tgt" % (opt.src_type) 270 | 271 | elif opt.src_type == 'monotext': 272 | assert None not in [opt.train, opt.valid], \ 273 | "With source type monotext the following " \ 274 | "parameters are required: -train, -valid" 275 | 276 | dicts = {} 277 | dicts['src'] = onmt.Dict() 278 | if opt.src_type == 'bitext': 279 | dicts['src'] = initVocabulary('source', opt.train_src, opt.src_vocab, 280 | opt.src_vocab_size) 281 | dicts['tgt'] = initVocabulary('target', opt.train_tgt, opt.tgt_vocab, 282 | opt.tgt_vocab_size) 283 | 284 | elif opt.src_type == 'monotext': 285 | dicts['src'] = initVocabulary('source', opt.train, opt.src_vocab, 286 | opt.src_vocab_size) 287 | 288 | elif opt.src_type == 'img': 289 | dicts['tgt'] = initVocabulary('target', opt.train_tgt, opt.tgt_vocab, 290 | opt.tgt_vocab_size) 291 | 292 | print('Preparing training ...') 293 | train = {} 294 | valid = {} 295 | 296 | if opt.src_type in ['bitext', 'img']: 297 | train['src'], train['tgt'] = makeBilingualData(opt.train_src, 298 | opt.train_tgt, 299 | dicts['src'], 300 | dicts['tgt']) 301 | 302 | print('Preparing validation ...') 303 | valid['src'], valid['tgt'] = makeBilingualData(opt.valid_src, 304 | opt.valid_tgt, 305 | dicts['src'], 306 | dicts['tgt']) 307 | 308 | elif opt.src_type == 'monotext': 309 | train['src'] = makeMonolingualData(opt.train, dicts['src']) 310 | train['tgt'] = train['src'] # Keeps compatibility with
bilingual code 311 | print('Preparing validation ...') 312 | valid['src'] = makeMonolingualData(opt.valid, dicts['src']) 313 | valid['tgt'] = valid['src'] 314 | 315 | if opt.src_vocab is None: 316 | saveVocabulary('source', dicts['src'], opt.save_data + '.src.dict') 317 | if opt.src_type in ['bitext', 'img'] and opt.tgt_vocab is None: 318 | saveVocabulary('target', dicts['tgt'], opt.save_data + '.tgt.dict') 319 | 320 | print('Saving data to \'' + opt.save_data + '.train.pt\'...') 321 | save_data = {'dicts': dicts, 322 | 'type': opt.src_type, 323 | 'train': train, 324 | 'valid': valid} 325 | torch.save(save_data, opt.save_data + '.train.pt') 326 | 327 | 328 | if __name__ == "__main__": 329 | main() 330 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from distutils.core import setup 4 | 5 | setup(name='OpenNMT', 6 | version='0.1', 7 | description='OpenNMT', 8 | packages=['onmt', 'onmt.modules']) 9 | -------------------------------------------------------------------------------- /test/test_simple.py: -------------------------------------------------------------------------------- 1 | import onmt 2 | 3 | 4 | def test_load(): 5 | onmt 6 | pass 7 | -------------------------------------------------------------------------------- /tools/extract_embeddings.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import onmt 4 | import torch 5 | import argparse 6 | 7 | import onmt.Models 8 | 9 | parser = argparse.ArgumentParser(description='extract_embeddings.py') 10 | 11 | parser.add_argument('-model', required=True, 12 | help='Path to model .pt file') 13 | parser.add_argument('-output_dir', default='.', 14 | help="""Directory in which to write the embeddings""") 15 | parser.add_argument('-gpu', type=int, default=-1, 16 | help="Device to run on") 17 | 18 | 19 | def
write_embeddings(filename, vocab, embeddings): 20 | with open(filename, 'w', encoding='utf-8') as f: 21 | for i in range(len(embeddings)): 22 | line = vocab.idxToLabel[i] 23 | for value in embeddings[i]: 24 | line += " %5f" % value 25 | f.write(line + "\n") 26 | 27 | 28 | def main(): 29 | opt = parser.parse_args() 30 | checkpoint = torch.load(opt.model) 31 | opt.cuda = opt.gpu > -1 32 | if opt.cuda: 33 | torch.cuda.set_device(opt.gpu) 34 | 35 | model_opt = checkpoint['opt'] 36 | src_dict = checkpoint['dicts']['src'] 37 | tgt_dict = checkpoint['dicts']['tgt'] 38 | 39 | encoder = onmt.Models.Encoder(model_opt, src_dict) 40 | decoder = onmt.Models.Decoder(model_opt, tgt_dict) 41 | encoder_embeddings = encoder.word_lut.weight.data.tolist() 42 | decoder_embeddings = decoder.word_lut.weight.data.tolist() 43 | 44 | print("Writing source embeddings") 45 | write_embeddings(opt.output_dir + "/src_embeddings.txt", src_dict, 46 | encoder_embeddings) 47 | 48 | print("Writing target embeddings") 49 | write_embeddings(opt.output_dir + "/tgt_embeddings.txt", tgt_dict, 50 | decoder_embeddings) 51 | 52 | print('...
done.') 53 | 54 | 55 | 56 | if __name__ == "__main__": 57 | main() 58 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import onmt 4 | import onmt.Markdown 5 | import onmt.Models 6 | import onmt.Decoders 7 | import onmt.Encoders 8 | import onmt.modules 9 | import argparse 10 | import torch 11 | import torch.nn as nn 12 | from torch import cuda 13 | from torch.autograd import Variable 14 | from torch.nn import init 15 | import math 16 | import time 17 | 18 | parser = argparse.ArgumentParser(description='train.py') 19 | onmt.Markdown.add_md_help_argument(parser) 20 | 21 | # Data options 22 | 23 | parser.add_argument('-data', required=True, 24 | help='Path to the *.train.pt file from preprocess.py') 25 | parser.add_argument('-save_model', default='model', 26 | help="""Model filename prefix (checkpoints are saved as 27 | PREFIX_acc_ACC_ppl_PPL_iterN_eE.pt, where ACC and PPL are the 28 | validation accuracy and perplexity)""") 29 | parser.add_argument('-train_from_state_dict', default='', type=str, 30 | help="""If training from a checkpoint then this is the 31 | path to the pretrained model's state_dict.""") 32 | parser.add_argument('-train_from', default='', type=str, 33 | help="""If training from a checkpoint then this is the 34 | path to the pretrained model.""") 35 | 36 | # Model options 37 | 38 | parser.add_argument('-model_type', type=str, default='nmt', 39 | choices=['nmt', 'lm'], 40 | help="""Kind of model to train, it can be 41 | neural machine translation or language model 42 | [nmt|lm]""") 43 | parser.add_argument('-layers_enc', type=int, default=2, 44 | help='Number of layers in the RNN encoder') 45 | parser.add_argument('-layers_dec', type=int, default=2, 46 | help='Number of layers in the RNN decoder') 47 | parser.add_argument('-rnn_size', type=int, default=500, 48 | help='Size of the RNN hidden states')
parser.add_argument('-word_vec_size', type=int, default=500, 50 | help='Word embedding size') 51 | parser.add_argument('-input_feed', type=int, default=1, 52 | help="""Feed the context vector at each time step as 53 | additional input (via concatenation with the word 54 | embeddings) to the decoder.""") 55 | parser.add_argument('-rnn_type', type=str, default='LSTM', 56 | choices=['LSTM', 'GRU', 'SRU'], 57 | help="""The recurrent unit type to use in the RNNs""") 58 | parser.add_argument('-rnn_encoder_type', type=str, 59 | choices=['LSTM', 'GRU', 'SRU'], 60 | help="""The recurrent unit type to use in the encoder RNNs. It overrides -rnn_type""") 61 | parser.add_argument('-rnn_decoder_type', type=str, 62 | choices=['LSTM', 'GRU', 'SRU'], 63 | help="""The recurrent unit type to use in the decoder RNNs. It overrides -rnn_type""") 64 | parser.add_argument('-attn_type', type=str, default='mlp', 65 | choices=['mlp', 'dot'], 66 | help="""The attention type to use in the decoder""") 67 | parser.add_argument('-activ', type=str, default='tanh', 68 | help="""Activation function inside the RNNs.""") 69 | parser.add_argument('-brnn', action='store_true', 70 | help='Use a bidirectional encoder') 71 | parser.add_argument('-context_gate', type=str, default=None, 72 | choices=['source', 'target', 'both'], 73 | help="""Type of context gate to use [source|target|both].
74 | Omit this option to disable the context gate.""") 75 | parser.add_argument('-decoder_type', type=str, default='StackedRNN', 76 | help="""Decoder neural architecture to use""") 77 | parser.add_argument('-encoder_type', type=str, default='RNN', 78 | help="""Encoder architecture""") 79 | parser.add_argument('-layer_norm', default=False, action="store_true", 80 | help="""Add layer normalization in recurrent units""") 81 | 82 | # Optimization options 83 | 84 | parser.add_argument('-batch_size', type=int, default=64, 85 | help='Maximum batch size') 86 | parser.add_argument('-max_generator_batches', type=int, default=32, 87 | help="""Maximum number of target time steps to run 88 | the generator on in parallel. Higher is faster, but uses 89 | more memory.""") 90 | parser.add_argument('-epochs', type=int, default=13, 91 | help='Number of training iterations (minibatch updates) after which training stops') 92 | parser.add_argument('-start_epoch', type=int, default=1, 93 | help='The epoch from which to start') 94 | parser.add_argument('-param_init', type=float, default=0.1, 95 | help="""Parameters are initialized over uniform distribution 96 | with support (-param_init, param_init)""") 97 | parser.add_argument('-optim', default='sgd', 98 | help="Optimization method. [sgd|adagrad|adadelta|adam]") 99 | parser.add_argument('-max_grad_norm', type=float, default=5, 100 | help="""If the norm of the gradient vector exceeds this, 101 | renormalize it to have the norm equal to max_grad_norm""") 102 | parser.add_argument('-dropout', type=float, default=0.3, 103 | help='Dropout probability; applied between stacked RNN layers.') 104 | parser.add_argument('-curriculum', action="store_true", 105 | help="""If set, order the minibatches of the first epoch based 106 | on source sequence length. This sometimes 107 | increases convergence speed.""") 108 | parser.add_argument('-extra_shuffle', action="store_true", 109 | help="""By default only shuffle mini-batch order; when true, 110 | shuffle and re-assign mini-batches""") 111 | parser.add_argument('-change_optimizer', default=False, action='store_true', 112 | help="""When resuming from a checkpoint, reset the optimizer 113 | settings to the values given on the command line""") 114 | parser.add_argument('-enc_short_path', type=bool, default=False, 115 | help="""If True, creates a short path from the source embeddings to the output 116 | by adding them to the attention""") 117 | parser.add_argument('-use_learning_rate_decay', action="store_true", 118 | help='If set, update (and possibly decay) the learning rate at every checkpoint') 119 | parser.add_argument('-save_each', type=int, default=10000, 120 | help="""The number of minibatches to compute before saving a checkpoint""") 121 | 122 | # learning rate 123 | parser.add_argument('-learning_rate', type=float, default=1.0, 124 | help="""Starting learning rate. If adagrad/adadelta/adam is 125 | used, then this is the global learning rate. Recommended 126 | settings: sgd = 1, adagrad = 0.1, 127 | adadelta = 1, adam = 0.001""") 128 | parser.add_argument('-learning_rate_decay', type=float, default=0.5, 129 | help="""If -use_learning_rate_decay is set, decay the learning rate by 130 | this much if (i) perplexity does not decrease on the 131 | validation set or (ii) epoch has gone past 132 | start_decay_at""") 133 | parser.add_argument('-start_decay_at', type=int, default=8, 134 | help="""Start decaying every epoch after and including this 135 | epoch""") 136 | 137 | # pretrained word vectors 138 | 139 | parser.add_argument('-pre_word_vecs_enc', 140 | help="""If a valid path is specified, then this will load 141 | pretrained word embeddings on the encoder side.
142 | See README for specific formatting instructions.""") 143 | parser.add_argument('-pre_word_vecs_dec', 144 | help="""If a valid path is specified, then this will load 145 | pretrained word embeddings on the decoder side. 146 | See README for specific formatting instructions.""") 147 | parser.add_argument('-pre_word_vecs', 148 | help="""If a valid path is specified, then this will load 149 | pretrained word embeddings on the language model. 150 | See README for specific formatting instructions.""") 151 | 152 | # GPU 153 | parser.add_argument('-gpus', default=[], nargs='+', type=int, 154 | help="Use CUDA on the listed devices.") 155 | 156 | parser.add_argument('-log_interval', type=int, default=50, 157 | help="Print training statistics every this many minibatches.") 158 | 159 | parser.add_argument('-seed', type=int, default=-1, 160 | help="""Random seed for reproducible experiments 161 | (only used if greater than 0).""") 162 | 163 | opt = parser.parse_args() 164 | 165 | print(opt) 166 | 167 | if opt.seed > 0: 168 | torch.manual_seed(opt.seed) 169 | 170 | if torch.cuda.is_available() and not opt.gpus: 171 | print("WARNING: You have a CUDA device, so you should probably run with -gpus 0") 172 | 173 | if opt.gpus: 174 | cuda.set_device(opt.gpus[0]) 175 | if opt.seed > 0: 176 | torch.cuda.manual_seed(opt.seed) 177 | 178 | 179 | def NMTCriterion(vocabSize): 180 | weight = torch.ones(vocabSize) 181 | weight[onmt.Constants.PAD] = 0 182 | crit = nn.NLLLoss(weight, size_average=False) 183 | if opt.gpus: 184 | crit.cuda() 185 | return crit 186 | 187 | 188 | def memoryEfficientLoss(outputs, targets, generator, crit, eval=False): 189 | # Run the generator on chunks of max_generator_batches time steps to save memory 190 | num_correct, loss = 0, 0 191 | outputs = Variable(outputs.data, requires_grad=(not eval), volatile=eval) 192 | 193 | batch_size = outputs.size(1) 194 | outputs_split = torch.split(outputs, opt.max_generator_batches) 195 | targets_split = torch.split(targets, opt.max_generator_batches) 196 | for i, (out_t, targ_t) in enumerate(zip(outputs_split,
targets_split)): 197 | out_t = out_t.view(-1, out_t.size(2)) 198 | scores_t = generator(out_t) 199 | loss_t = crit(scores_t, targ_t.view(-1)) 200 | pred_t = scores_t.max(1)[1] 201 | num_correct_t = pred_t.data.eq(targ_t.view(-1).data) \ 202 | .masked_select( 203 | targ_t.view(-1).ne(onmt.Constants.PAD).data) \ 204 | .sum() 205 | num_correct += num_correct_t 206 | loss += loss_t.data[0] 207 | if not eval: 208 | loss_t.div(batch_size).backward() 209 | 210 | grad_output = None if outputs.grad is None else outputs.grad.data 211 | return loss, grad_output, num_correct 212 | 213 | 214 | def eval(model, criterion, data): 215 | total_loss = 0 216 | total_words = 0 217 | total_num_correct = 0 218 | 219 | model.eval() 220 | for i in range(len(data)): 221 | # Exclude the original indices (last element of the batch). 222 | batch = data[i][:-1] 223 | outputs = model(batch) 224 | # Exclude the leading BOS token from the targets. 225 | targets = batch[1][1:] 226 | loss, _, num_correct = memoryEfficientLoss( 227 | outputs, targets, model.generator, criterion, eval=True) 228 | total_loss += loss 229 | total_num_correct += num_correct 230 | total_words += targets.data.ne(onmt.Constants.PAD).sum() 231 | 232 | model.train() 233 | return total_loss / total_words, total_num_correct / total_words 234 | 235 | 236 | def trainModel(model, trainData, validData, dataset, optim, opt): 237 | print(model) 238 | model.train() 239 | 240 | # Define the training criterion. 241 | criterion = NMTCriterion(dataset['dicts']['tgt'].size()) 242 | 243 | start_time = time.time() 244 | 245 | def trainEpoch(epoch, iter): 246 | 247 | if opt.extra_shuffle and epoch > opt.curriculum: 248 | trainData.shuffle() 249 | 250 | # Shuffle mini batch order.
251 | batchOrder = torch.randperm(len(trainData)) 252 | 253 | total_loss, total_words, total_num_correct = 0, 0, 0 254 | report_loss, report_tgt_words = 0, 0 255 | report_src_words, report_num_correct = 0, 0 256 | start = time.time() 257 | for i in range(len(trainData)): 258 | # Note: opt.epochs caps the total number of minibatch updates (iter), not the number of passes over the data. 259 | if iter >= opt.epochs: 260 | break 261 | iter += 1 262 | 263 | batchIdx = batchOrder[i] if epoch > opt.curriculum else i 264 | # Exclude original indices. 265 | batch = trainData[batchIdx][:-1] 266 | 267 | model.zero_grad() 268 | outputs = model(batch) 269 | # Exclude the leading BOS token from the targets. 270 | targets = batch[1][1:] 271 | loss, gradOutput, num_correct = memoryEfficientLoss( 272 | outputs, targets, model.generator, criterion) 273 | 274 | outputs.backward(gradOutput) 275 | 276 | # Update the parameters. 277 | optim.step() 278 | 279 | num_words = targets.data.ne(onmt.Constants.PAD).sum() 280 | report_loss += loss 281 | report_num_correct += num_correct 282 | report_tgt_words += num_words 283 | report_src_words += batch[0][1].data.sum() 284 | total_loss += loss 285 | total_num_correct += num_correct 286 | total_words += num_words 287 | if i % opt.log_interval == -1 % opt.log_interval: 288 | print(("Epoch %2d, %5d/%5d; acc: %6.2f; ppl: %6.2f; " + 289 | "%3.0f src tok/s; %3.0f tgt tok/s; %6.0f s elapsed") % 290 | (epoch, i+1, len(trainData), 291 | report_num_correct / report_tgt_words * 100, 292 | math.exp(report_loss / report_tgt_words), 293 | report_src_words/(time.time()-start), 294 | report_tgt_words/(time.time()-start), 295 | time.time()-start_time)) 296 | 297 | report_loss, report_tgt_words = 0, 0 298 | report_src_words, report_num_correct = 0, 0 299 | start = time.time() 300 | 301 | if iter % opt.save_each == 0: 302 | # (2) evaluate on the validation set 303 | valid_loss, valid_acc = eval(model, criterion, validData) 304 | valid_ppl = math.exp(min(valid_loss, 100)) 305 | print('Validation perplexity: %g' % valid_ppl) 306 | print('Validation accuracy: %g' % (valid_acc * 100)) 307 | 308 | #
(3) update the learning rate 309 | if opt.use_learning_rate_decay: 310 | optim.updateLearningRate(valid_ppl, iter) 311 | 312 | model_state_dict = (model.module.state_dict() if len(opt.gpus) > 1 313 | else model.state_dict()) 314 | model_state_dict = {k: v for k, v in model_state_dict.items() 315 | if 'generator' not in k} 316 | generator_state_dict = (model.generator.module.state_dict() 317 | if len(opt.gpus) > 1 318 | else model.generator.state_dict()) 319 | # (4) drop a checkpoint 320 | checkpoint = { 321 | 'model': model_state_dict, 322 | 'generator': generator_state_dict, 323 | 'dicts': dataset['dicts'], 324 | 'opt': opt, 325 | 'epoch': epoch, 326 | 'optim': optim, 327 | 'type': opt.model_type 328 | } 329 | torch.save(checkpoint, 330 | '%s_acc_%.2f_ppl_%.2f_iter%d_e%d.pt' 331 | % (opt.save_model, 100 * valid_acc, valid_ppl, iter, epoch)) 332 | 333 | return total_loss / total_words, total_num_correct / total_words, iter 334 | 335 | epoch, iter = 1, 0 336 | while iter < opt.epochs: 337 | print('') 338 | # (1) train for one epoch on the training set 339 | train_loss, train_acc, iter = trainEpoch(epoch, iter) 340 | epoch += 1 341 | train_ppl = math.exp(min(train_loss, 100)) 342 | print('Train perplexity: %g' % train_ppl) 343 | print('Train accuracy: %g' % (train_acc*100)) 344 | 345 | 346 | def main(): 347 | print("Loading data from '%s'" % opt.data) 348 | 349 | dataset = torch.load(opt.data) 350 | if opt.model_type == 'nmt': 351 | if dataset.get("type", "text") not in ["bitext", "text"]: 352 | print("WARNING: The provided dataset is not bilingual!") 353 | elif opt.model_type == 'lm': 354 | if dataset.get("type", "text") != 'monotext': 355 | print("WARNING: The provided dataset is not monolingual!") 356 | else: 357 | raise NotImplementedError('Not valid model type %s' % opt.model_type) 358 | 359 | dict_checkpoint = (opt.train_from if opt.train_from 360 | else opt.train_from_state_dict) 361 | if dict_checkpoint: 362 | print('Loading dicts from checkpoint at %s' % 
dict_checkpoint) 363 | checkpoint = torch.load(dict_checkpoint) 364 | if opt.model_type == 'nmt': 365 | assert checkpoint.get('type', None) is None or \ 366 | checkpoint['type'] == "nmt", \ 367 | "The loaded model is not neural machine translation!" 368 | elif opt.model_type == 'lm': 369 | assert checkpoint['type'] == "lm", \ 370 | "The loaded model is not a language model!" 371 | dataset['dicts'] = checkpoint['dicts'] 372 | 373 | trainData = onmt.Dataset(dataset['train']['src'], 374 | dataset['train']['tgt'], opt.batch_size, opt.gpus, 375 | data_type=dataset.get("type", "text")) 376 | validData = onmt.Dataset(dataset['valid']['src'], 377 | dataset['valid']['tgt'], opt.batch_size, opt.gpus, 378 | volatile=True, 379 | data_type=dataset.get("type", "text")) 380 | 381 | dicts = dataset['dicts'] 382 | model_opt = checkpoint['opt'] if dict_checkpoint else opt 383 | if dicts.get('tgt', None) is None: 384 | # Makes the code compatible with the language model 385 | dicts['tgt'] = dicts['src'] 386 | if opt.model_type == 'nmt': 387 | print(' * vocabulary size. source = %d; target = %d' % 388 | (dicts['src'].size(), dicts['tgt'].size())) 389 | elif opt.model_type == 'lm': 390 | print(' * vocabulary size = %d' % 391 | (dicts['src'].size())) 392 | print(' * number of training sentences. %d' % 393 | len(dataset['train']['src'])) 394 | print(' * maximum batch size. 
%d' % opt.batch_size) 395 | 396 | print('Building model...') 397 | 398 | if opt.model_type == 'nmt': 399 | 400 | decoder = onmt.Decoders.getDecoder(model_opt.decoder_type)(model_opt, dicts['tgt']) 401 | encoder = onmt.Encoders.getEncoder(model_opt.encoder_type)(model_opt, dicts['src']) 402 | 403 | model = onmt.Models.NMTModel(encoder, decoder) 404 | 405 | elif opt.model_type == 'lm': 406 | model = onmt.LanguageModel.LM(model_opt, dicts['src']) 407 | 408 | generator = nn.Sequential( 409 | nn.Linear(model_opt.rnn_size, dicts['tgt'].size()), 410 | nn.LogSoftmax()) 411 | 412 | if opt.train_from: 413 | print('Loading model from checkpoint at %s' % opt.train_from) 414 | chk_model = checkpoint['model'] 415 | generator_state_dict = chk_model.generator.state_dict() 416 | model_state_dict = {k: v for k, v in chk_model.state_dict().items() 417 | if 'generator' not in k} 418 | model.load_state_dict(model_state_dict) 419 | generator.load_state_dict(generator_state_dict) 420 | opt.start_epoch = checkpoint['epoch'] + 1 421 | 422 | if opt.train_from_state_dict: 423 | print('Loading model from state_dict at %s' 424 | % opt.train_from_state_dict) 425 | model.load_state_dict(checkpoint['model']) 426 | generator.load_state_dict(checkpoint['generator']) 427 | model_opt.start_epoch = opt.start_epoch 428 | model_opt.epochs = opt.epochs 429 | 430 | if len(opt.gpus) >= 1: 431 | model.cuda() 432 | generator.cuda() 433 | else: 434 | model.cpu() 435 | generator.cpu() 436 | 437 | if len(opt.gpus) > 1: 438 | model = nn.DataParallel(model, device_ids=opt.gpus, dim=1) 439 | generator = nn.DataParallel(generator, device_ids=opt.gpus, dim=0) 440 | model_opt["gpus"] = opt.gpus 441 | 442 | model.generator = generator 443 | 444 | if not opt.train_from_state_dict and not opt.train_from: 445 | for p in model.parameters(): 446 | #p.data.uniform_(-opt.param_init, opt.param_init) 447 | if len(p.data.size()) > 1: 448 | init.xavier_normal(p.data) 449 | else: 450 | p.data.uniform_(-opt.param_init, 
opt.param_init) 451 | model.initialize_parameters(opt.param_init) 452 | model.load_pretrained_vectors(opt) 453 | 454 | if (not opt.train_from_state_dict and not opt.train_from) or opt.change_optimizer: 455 | optim = onmt.Optim( 456 | opt.optim, opt.learning_rate, opt.max_grad_norm, 457 | lr_decay=opt.learning_rate_decay, 458 | start_decay_at=opt.start_decay_at 459 | ) 460 | optim.set_parameters(model.parameters()) 461 | model_opt.learning_rate = opt.learning_rate 462 | model_opt.learning_rate_decay = opt.learning_rate_decay 463 | model_opt.save_each = opt.save_each 464 | 465 | else: 466 | print('Loading optimizer from checkpoint:') 467 | optim = checkpoint['optim'] 468 | optim.optimizer.load_state_dict( 469 | checkpoint['optim'].optimizer.state_dict()) 470 | optim.set_parameters(model.parameters()) 471 | 472 | nParams = sum([p.nelement() for p in model.parameters()]) 473 | print('* number of parameters: %d' % nParams) 474 | 475 | if opt.train_from or opt.train_from_state_dict: 476 | print(model_opt) 477 | 478 | model_opt.use_learning_rate_decay = opt.use_learning_rate_decay 479 | trainModel(model, trainData, validData, dataset, optim, model_opt) 480 | 481 | 482 | if __name__ == "__main__": 483 | main() 484 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | layer=$1 2 | gpu=$2 3 | python -u train.py -data data.train.pt -save_model path/to/model/SGU_${layer}layers -layer_norm 4 | -max_grad_norm 1 -layers_enc $layer -layers_dec $layer -dropout 0.1 -gpus $gpu -optim adam -learning_rate 0.0003 -decoder_type SR -encoder_type SR 5 | -attn_type dot -save_each 30000 -brnn -rnn_size 500 -epochs 1000000 -word_vec_size 500 > path/to/log.out 6 | -------------------------------------------------------------------------------- /translate.py: -------------------------------------------------------------------------------- 1 | from __future__ import 
division 2 | from builtins import bytes 3 | 4 | import onmt 5 | import onmt.Markdown 6 | import torch 7 | import argparse 8 | import math 9 | import codecs 10 | import os 11 | 12 | parser = argparse.ArgumentParser(description='translate.py') 13 | onmt.Markdown.add_md_help_argument(parser) 14 | 15 | parser.add_argument('-model', required=True, 16 | help='Path to model .pt file') 17 | parser.add_argument('-src', required=True, 18 | help='Source sequence to decode (one line per sequence)') 19 | parser.add_argument('-src_img_dir', default="", 20 | help='Source image directory') 21 | parser.add_argument('-tgt', 22 | help='True target sequence (optional)') 23 | parser.add_argument('-output', default='pred.txt', 24 | help="""Path to output the predictions (each line will 25 | be the decoded sequence""") 26 | parser.add_argument('-beam_size', type=int, default=5, 27 | help='Beam size') 28 | parser.add_argument('-batch_size', type=int, default=30, 29 | help='Batch size') 30 | parser.add_argument('-max_sent_length', type=int, default=100, 31 | help='Maximum sentence length.') 32 | parser.add_argument('-replace_unk', action="store_true", 33 | help="""Replace the generated UNK tokens with the source 34 | token that had highest attention weight. If phrase_table 35 | is provided, it will lookup the identified source token and 36 | give the corresponding target token. If it is not provided 37 | (or the identified source token does not exist in the 38 | table) then it will copy the source token""") 39 | # parser.add_argument('-phrase_table', 40 | # help="""Path to source-target dictionary to replace UNK 41 | # tokens. 
See README.md for the format of this file.""") 42 | parser.add_argument('-verbose', action="store_true", 43 | help='Print scores and predictions for each sentence') 44 | parser.add_argument('-dump_beam', type=str, default="", 45 | help='File to dump beam information to.') 46 | 47 | parser.add_argument('-n_best', type=int, default=1, 48 | help="""If verbose is set, will output the n_best 49 | decoded sentences""") 50 | 51 | parser.add_argument('-gpu', type=int, default=-1, 52 | help="Device to run on") 53 | 54 | 55 | def reportScore(name, scoreTotal, wordsTotal): 56 | print("%s AVG SCORE: %.4f, %s PPL: %.4f" % ( 57 | name, scoreTotal / wordsTotal, 58 | name, math.exp(-scoreTotal/wordsTotal))) 59 | 60 | 61 | def addone(f): 62 | for line in f: 63 | yield line 64 | yield None 65 | 66 | 67 | def main(): 68 | opt = parser.parse_args() 69 | opt.cuda = opt.gpu > -1 70 | if opt.cuda: 71 | torch.cuda.set_device(opt.gpu) 72 | 73 | translator = onmt.Translator(opt) 74 | 75 | 76 | outF = codecs.open(opt.output, 'w', 'utf-8') 77 | 78 | predScoreTotal, predWordsTotal, goldScoreTotal, goldWordsTotal = 0, 0, 0, 0 79 | 80 | srcBatch, tgtBatch = [], [] 81 | 82 | count = 0 83 | 84 | tgtF = codecs.open(opt.tgt, 'r', 'utf-8') if opt.tgt else None 85 | 86 | if opt.dump_beam != "": 87 | import json 88 | translator.initBeamAccum() 89 | 90 | for line in addone(codecs.open(opt.src, 'r', 'utf-8')): 91 | if line is not None: 92 | srcTokens = line.split() 93 | srcBatch += [srcTokens] 94 | if tgtF: 95 | tgtTokens = tgtF.readline().split() if tgtF else None 96 | tgtBatch += [tgtTokens] 97 | 98 | if len(srcBatch) < opt.batch_size: 99 | continue 100 | else: 101 | # at the end of file, check last batch 102 | if len(srcBatch) == 0: 103 | break 104 | 105 | predBatch, predScore, goldScore = translator.translate(srcBatch, 106 | tgtBatch) 107 | predScoreTotal += sum(score[0] for score in predScore) 108 | predWordsTotal += sum(len(x[0]) for x in predBatch) 109 | if tgtF is not None: 110 | goldScoreTotal 
+= sum(goldScore) 111 | goldWordsTotal += sum(len(x) for x in tgtBatch) 112 | 113 | for b in range(len(predBatch)): 114 | count += 1 115 | outF.write(" ".join(predBatch[b][0]) + '\n') 116 | outF.flush() 117 | 118 | if opt.verbose: 119 | srcSent = ' '.join(srcBatch[b]) 120 | if translator.tgt_dict.lower: 121 | srcSent = srcSent.lower() 122 | os.write(1, bytes('SENT %d: %s\n' % (count, srcSent), 'UTF-8')) 123 | os.write(1, bytes('PRED %d: %s\n' % 124 | (count, " ".join(predBatch[b][0])), 'UTF-8')) 125 | print("PRED SCORE: %.4f" % predScore[b][0]) 126 | 127 | if tgtF is not None: 128 | tgtSent = ' '.join(tgtBatch[b]) 129 | if translator.tgt_dict.lower: 130 | tgtSent = tgtSent.lower() 131 | os.write(1, bytes('GOLD %d: %s\n' % 132 | (count, tgtSent), 'UTF-8')) 133 | print("GOLD SCORE: %.4f" % goldScore[b]) 134 | 135 | if opt.n_best > 1: 136 | print('\nBEST HYP:') 137 | for n in range(opt.n_best): 138 | os.write(1, bytes("[%.4f] %s\n" % (predScore[b][n], 139 | " ".join(predBatch[b][n])), 140 | 'UTF-8')) 141 | 142 | print('') 143 | 144 | srcBatch, tgtBatch = [], [] 145 | 146 | reportScore('PRED', predScoreTotal, predWordsTotal) 147 | if tgtF: 148 | reportScore('GOLD', goldScoreTotal, goldWordsTotal) 149 | 150 | if tgtF: 151 | tgtF.close() 152 | 153 | if opt.dump_beam: 154 | json.dump(translator.beam_accum, 155 | codecs.open(opt.dump_beam, 'w', 'utf-8')) 156 | 157 | 158 | if __name__ == "__main__": 159 | main() 160 | -------------------------------------------------------------------------------- /translate.sh: -------------------------------------------------------------------------------- 1 | model=$1 2 | test=$2 3 | gpu=$3 4 | python translate.py -src $test -model $model -output $model.test.out -gpu $gpu -batch_size 1 5 | --------------------------------------------------------------------------------
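Two arithmetic idioms in the `train.py` training loop are easy to misread. The logging test `i % opt.log_interval == -1 % opt.log_interval` relies on Python's `%` returning a result with the sign of the divisor, so `-1 % n` equals `n - 1` and the log line fires on the last batch of every block of `n` batches. The validation and end-of-epoch perplexities are `exp` of the per-word loss, clamped at 100 before exponentiation so `math.exp` cannot overflow (the per-interval log line does not clamp). The sketch below reproduces both in plain Python, without PyTorch; the helper names `log_steps` and `perplexity` are hypothetical and not part of the repository.

```python
import math


def log_steps(num_batches, log_interval):
    """Return the 0-based batch indices at which train.py's logging test fires."""
    # In Python, -1 % n == n - 1, so this matches i = n-1, 2n-1, 3n-1, ...
    return [i for i in range(num_batches)
            if i % log_interval == -1 % log_interval]


def perplexity(total_loss, total_words):
    """Per-word perplexity, clamped the same way as the validation report."""
    # min(..., 100) keeps math.exp from overflowing when training diverges.
    return math.exp(min(total_loss / total_words, 100))


if __name__ == "__main__":
    print(log_steps(10, 3))     # [2, 5, 8]
    print(perplexity(0.0, 10))  # 1.0
```

With `-log_interval 3`, for example, the first report therefore appears after batch 3 (index 2), not after batch 1.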
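`translate.py` reads the source file through `addone`, which yields every line followed by a single `None`. The `None` acts as an end-of-file sentinel: a batch is translated as soon as it reaches `-batch_size` sentences, and the sentinel forces one extra pass so that a final partial batch is still translated (an empty leftover triggers `break` instead). A minimal, PyTorch-free sketch of that control flow follows; `batches` is a hypothetical helper that collects the batches `translator.translate` would receive rather than calling the translator.

```python
def addone(f):
    # Same generator as translate.py: every item, then a None sentinel.
    for line in f:
        yield line
    yield None


def batches(lines, batch_size):
    """Collect the token batches translate.py would pass to the translator."""
    out, src_batch = [], []
    for line in addone(lines):
        if line is not None:
            src_batch.append(line.split())
            if len(src_batch) < batch_size:
                continue  # keep filling the current batch
        elif not src_batch:
            break  # sentinel reached with nothing left over
        out.append(src_batch)  # stands in for translator.translate(...)
        src_batch = []
    return out


if __name__ == "__main__":
    print([len(b) for b in batches(["a b", "c", "d e f", "g", "h"], 2)])  # [2, 2, 1]
```

Because `translate.sh` passes `-batch_size 1`, every sentence is translated on its own and the sentinel pass always finds an empty batch.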