├── .gitignore
├── LICENSE.md
├── README.md
├── __init__.py
├── data
│   ├── src-test.txt
│   ├── src-train.txt
│   ├── src-val.txt
│   ├── tgt-train.txt
│   └── tgt-val.txt
├── docs
│   ├── README.md
│   ├── _config.yml
│   ├── css
│   │   └── extra.css
│   ├── extended.md
│   ├── generate.sh
│   ├── img
│   │   ├── architecture.png
│   │   ├── brnn.png
│   │   ├── dbrnn.png
│   │   ├── favicon.ico
│   │   ├── global-attention-model.png
│   │   ├── input_feed.png
│   │   ├── logo-alpha.png
│   │   ├── pdbrnn.png
│   │   └── residual.png
│   ├── index.md
│   ├── installation.md
│   ├── options
│   │   ├── preprocess.md
│   │   ├── train.md
│   │   └── translate.md
│   ├── quickstart.md
│   └── references.md
├── eval.sh
├── mkdocs.yml
├── onmt
│   ├── Beam.py
│   ├── Constants.py
│   ├── Dataset.py
│   ├── Decoders.py
│   ├── Dict.py
│   ├── Encoders.py
│   ├── Markdown.py
│   ├── Models.py
│   ├── Optim.py
│   ├── Translator.py
│   ├── __init__.py
│   └── modules
│       ├── Attention.py
│       ├── Gate.py
│       ├── ImageEncoder.py
│       ├── Normalization.py
│       ├── SRU_units.py
│       ├── Units.py
│       └── __init__.py
├── preprocess.py
├── setup.py
├── test
│   └── test_simple.py
├── tools
│   └── extract_embeddings.py
├── train.py
├── train.sh
├── translate.py
└── translate.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | pred.txt
2 | multi-bleu.perl
3 | *.pt
4 | *.pyc
5 | #.*
6 | .idea
7 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | This software is derived from the OpenNMT project at
2 | https://github.com/OpenNMT/OpenNMT.
3 |
4 | The MIT License (MIT)
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in
14 | all copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22 | THE SOFTWARE.
23 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # OpenNMT: Open-Source Neural Machine Translation
2 |
3 | This is an extension of [OpenNMT](https://github.com/OpenNMT/OpenNMT)
4 | that includes the code for SR-NMT, introduced in
5 | [Deep Neural Machine Translation with Weakly-Recurrent Units](https://arxiv.org/abs/1805.04185).
6 |
7 |
8 |
9 | ## Quickstart
10 |
11 | ## Some useful tools:
12 |
13 | The example below uses the Moses tokenizer (http://www.statmt.org/moses/) to prepare the data and the Moses BLEU script for evaluation.
14 |
15 | ```bash
16 | wget https://raw.githubusercontent.com/moses-smt/mosesdecoder/master/scripts/tokenizer/tokenizer.perl
17 | wget https://raw.githubusercontent.com/moses-smt/mosesdecoder/master/scripts/share/nonbreaking_prefixes/nonbreaking_prefix.de
18 | wget https://raw.githubusercontent.com/moses-smt/mosesdecoder/master/scripts/share/nonbreaking_prefixes/nonbreaking_prefix.en
19 | sed -i "s/$RealBin\/..\/share\/nonbreaking_prefixes//" tokenizer.perl
20 | wget https://raw.githubusercontent.com/moses-smt/mosesdecoder/master/scripts/generic/multi-bleu.perl
21 | ```
22 |
23 | ## A simple pipeline:
24 |
25 | Download and preprocess the data as you would do for [OpenNMT-py](https://github.com/OpenNMT/OpenNMT-py).
26 | Then use preprocess.py, train.sh and translate.sh for the actual preprocessing, training and translation.
27 |
28 | ### 1) Preprocess the data.
29 |
30 | ```bash
31 | python preprocess.py -train_src /path/to/data/train.src -train_tgt /path/to/data/train.tgt -valid_src /path/to/data/valid.src -valid_tgt /path/to/data/valid.tgt -save_data /path/to/data/data
32 | ```
33 |
34 | ### 2) Train the model.
35 |
36 | ```bash
37 | sh train.sh num_layers num_gpu
38 | ```
39 |
40 | ### 3) Translate sentences.
41 |
42 | ```bash
43 | sh translate.sh model_name test_file num_gpu
44 | ```
45 |
46 | ### 4) Evaluate.
47 | ```bash
48 | sh eval.sh hypothesis target_language /path/to/test/tokenized.tgt
49 | ```
50 | This evaluation is consistent with the one used in the paper and was taken from [Tensor2Tensor](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/get_ende_bleu.sh).
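
For concreteness, an end-to-end run into German could look like the sketch below; the layer count, GPU id and file names are placeholder values, not defaults shipped with the scripts:

```bash
sh train.sh 8 0                                           # e.g. 8 layers on GPU 0
sh translate.sh my_srnmt_model.pt test.tokenized.src 0    # decode the tokenized test set
sh eval.sh pred.txt de /path/to/test/tokenized.de         # assumes the hypotheses were written to pred.txt
```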
51 |
52 | ## New versions
53 | We are working to integrate SR-NMT into:
54 | - [OpenNMT-py](https://github.com/mattiadg/OpenNMT-py)
55 | ([OpenNMT/OpenNMT-py#748](https://github.com/OpenNMT/OpenNMT-py/pull/748))
56 | Status: Testing
57 |
58 | - [OpenNMT-tf](https://github.com/mattiadg/OpenNMT-tf/tree/srnmt)
59 | Status: Development
60 |
61 | ## Citation
62 |
63 | If you use this software, please cite:
64 |
65 | ```
66 | @inproceedings{digangi2018deep,
67 | author = {Di Gangi, Mattia A and Federico, Marcello},
68 | title = {Deep Neural Machine Translation with Weakly-Recurrent Units},
69 | booktitle = {Proceedings of the 21st Annual Conference of the European Association for Machine Translation},
70 | pages = {119--128},
71 | year = {2018}
72 | }
73 | ```
74 |
75 |
76 | [OpenNMT technical report](https://doi.org/10.18653/v1/P17-4012)
77 |
78 | ```
79 | @inproceedings{opennmt,
80 | author = {Guillaume Klein and
81 | Yoon Kim and
82 | Yuntian Deng and
83 | Jean Senellart and
84 | Alexander M. Rush},
85 | title = {OpenNMT: Open-Source Toolkit for Neural Machine Translation},
86 | booktitle = {Proc. ACL},
87 | year = {2017},
88 | url = {https://doi.org/10.18653/v1/P17-4012},
89 | doi = {10.18653/v1/P17-4012}
90 | }
91 | ```
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mattiadg/SR-NMT/650c45b1981c4a9a72a8a8205d0185a9c2381f42/__init__.py
--------------------------------------------------------------------------------
/docs/README.md:
--------------------------------------------------------------------------------
1 | [MkDocs](http://www.mkdocs.org/) is used to generate the documentation at http://opennmt.net/OpenNMT/.
2 |
3 | Documentation under construction for [SR-NMT](https://github.com/mattiadg/SR-NMT)
4 |
5 |
--------------------------------------------------------------------------------
/docs/_config.yml:
--------------------------------------------------------------------------------
1 | theme: jekyll-theme-minimal
--------------------------------------------------------------------------------
/docs/css/extra.css:
--------------------------------------------------------------------------------
1 | .md-nav__item--active > .md-nav__link, .md-nav__link:active, .md-typeset a {
2 | color: #ac4142;
3 | }
4 |
5 | .md-nav__link:focus, .md-nav__link:hover, .md-typeset a:active, .md-typeset a:hover {
6 | color: #d67272;
7 | }
8 |
9 | .md-header {
10 | background-color: #ac4142;
11 | }
12 |
13 | label.md-nav__title.md-nav__title--site {
14 | background-color: #ac4142;
15 | color: white;
16 | padding: 1rem 1.2rem;
17 | font-size: 1.6rem;
18 | }
19 |
20 | .md-nav.md-nav--secondary {
21 | border-left-color: #ac4142;
22 | }
23 |
24 | .md-sidebar.md-sidebar--primary {
25 | height: 420px;
26 | }
27 |
28 | .md-flex__cell.md-flex__cell--shrink img {
29 | width:36px;height:36px;margin-top:-6px
30 | }
--------------------------------------------------------------------------------
/docs/extended.md:
--------------------------------------------------------------------------------
1 | ## Some useful tools:
2 |
3 | The example below uses the Moses tokenizer (http://www.statmt.org/moses/) to prepare the data and the Moses BLEU script for evaluation.
4 |
5 | ```bash
6 | wget https://raw.githubusercontent.com/moses-smt/mosesdecoder/master/scripts/tokenizer/tokenizer.perl
7 | wget https://raw.githubusercontent.com/moses-smt/mosesdecoder/master/scripts/share/nonbreaking_prefixes/nonbreaking_prefix.de
8 | wget https://raw.githubusercontent.com/moses-smt/mosesdecoder/master/scripts/share/nonbreaking_prefixes/nonbreaking_prefix.en
9 | sed -i "s/$RealBin\/..\/share\/nonbreaking_prefixes//" tokenizer.perl
10 | wget https://raw.githubusercontent.com/moses-smt/mosesdecoder/master/scripts/generic/multi-bleu.perl
11 | ```
12 |
13 | ## WMT'16 Multimodal Translation: Multi30k (de-en)
14 |
15 | An example of training for the WMT'16 Multimodal Translation task (http://www.statmt.org/wmt16/multimodal-task.html).
16 |
17 | ### 0) Download the data.
18 |
19 | ```bash
20 | mkdir -p data/multi30k
21 | wget http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/training.tar.gz && tar -xf training.tar.gz -C data/multi30k && rm training.tar.gz
22 | wget http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz && tar -xf validation.tar.gz -C data/multi30k && rm validation.tar.gz
23 | wget https://staff.fnwi.uva.nl/d.elliott/wmt16/mmt16_task1_test.tgz && tar -xf mmt16_task1_test.tgz -C data/multi30k && rm mmt16_task1_test.tgz
24 | ```
25 |
26 | ### 1) Preprocess the data.
27 |
28 | ```bash
29 | for l in en de; do for f in data/multi30k/*.$l; do if [[ "$f" != *"test"* ]]; then sed -i "$ d" $f; fi; done; done
30 | for l in en de; do for f in data/multi30k/*.$l; do perl tokenizer.perl -a -no-escape -l $l -q < $f > $f.atok; done; done
31 | python preprocess.py -train_src data/multi30k/train.en.atok -train_tgt data/multi30k/train.de.atok -valid_src data/multi30k/val.en.atok -valid_tgt data/multi30k/val.de.atok -save_data data/multi30k.atok.low -lower
32 | ```
33 |
34 | ### 2) Train the model.
35 |
36 | ```bash
37 | python train.py -data data/multi30k.atok.low.train.pt -save_model multi30k_model -gpus 0
38 | ```
39 |
40 | ### 3) Translate sentences.
41 |
42 | ```bash
43 | python translate.py -gpu 0 -model multi30k_model_e13_*.pt -src data/multi30k/test.en.atok -tgt data/multi30k/test.de.atok -replace_unk -verbose -output multi30k.test.pred.atok
44 | ```
45 |
46 | ### 4) Evaluate.
47 |
48 | ```bash
49 | perl multi-bleu.perl data/multi30k/test.de.atok < multi30k.test.pred.atok
50 | ```
51 |
52 | ## Pretrained Models
53 |
54 | The following pretrained models can be downloaded and used with translate.py (These were trained with an older version of the code; they will be updated soon).
55 |
56 | - [onmt_model_en_de_200k](https://s3.amazonaws.com/pytorch/examples/opennmt/models/onmt_model_en_de_200k-4783d9c3.pt): An English-German translation model based on the 200k sentence dataset at [OpenNMT/IntegrationTesting](https://github.com/OpenNMT/IntegrationTesting/tree/master/data). Perplexity: 21.
57 | - [onmt_model_en_fr_b1M](https://s3.amazonaws.com/pytorch/examples/opennmt/models/onmt_model_en_fr_b1M-261c69a7.pt): An English-French model trained on benchmark-1M. Perplexity: 4.85.
58 |
59 |
--------------------------------------------------------------------------------
/docs/generate.sh:
--------------------------------------------------------------------------------
1 | #! /bin/sh
2 |
3 | gen_script_options ()
4 | {
5 | echo "" > $2
6 | echo "" >> $2
7 | python $1 -md >> $2
8 | }
9 |
10 | gen_script_options preprocess.py docs/options/preprocess.md
11 | gen_script_options train.py docs/options/train.md
12 | gen_script_options translate.py docs/options/translate.md
13 |
--------------------------------------------------------------------------------
/docs/img/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mattiadg/SR-NMT/650c45b1981c4a9a72a8a8205d0185a9c2381f42/docs/img/architecture.png
--------------------------------------------------------------------------------
/docs/img/brnn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mattiadg/SR-NMT/650c45b1981c4a9a72a8a8205d0185a9c2381f42/docs/img/brnn.png
--------------------------------------------------------------------------------
/docs/img/dbrnn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mattiadg/SR-NMT/650c45b1981c4a9a72a8a8205d0185a9c2381f42/docs/img/dbrnn.png
--------------------------------------------------------------------------------
/docs/img/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mattiadg/SR-NMT/650c45b1981c4a9a72a8a8205d0185a9c2381f42/docs/img/favicon.ico
--------------------------------------------------------------------------------
/docs/img/global-attention-model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mattiadg/SR-NMT/650c45b1981c4a9a72a8a8205d0185a9c2381f42/docs/img/global-attention-model.png
--------------------------------------------------------------------------------
/docs/img/input_feed.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mattiadg/SR-NMT/650c45b1981c4a9a72a8a8205d0185a9c2381f42/docs/img/input_feed.png
--------------------------------------------------------------------------------
/docs/img/logo-alpha.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mattiadg/SR-NMT/650c45b1981c4a9a72a8a8205d0185a9c2381f42/docs/img/logo-alpha.png
--------------------------------------------------------------------------------
/docs/img/pdbrnn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mattiadg/SR-NMT/650c45b1981c4a9a72a8a8205d0185a9c2381f42/docs/img/pdbrnn.png
--------------------------------------------------------------------------------
/docs/img/residual.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mattiadg/SR-NMT/650c45b1981c4a9a72a8a8205d0185a9c2381f42/docs/img/residual.png
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | This portal provides detailed documentation of the OpenNMT toolkit. It describes how to use the PyTorch project and how it works.
2 |
3 | *For the Lua Torch version, visit the documentation at [GitHub](http://opennmt.net/OpenNMT).*
4 |
5 | ## Additional resources
6 |
7 | You can find additional help or tutorials in the following resources:
8 |
9 | * [Forum](http://forum.opennmt.net/)
10 | * [Gitter channel](https://gitter.im/OpenNMT/openmt)
11 |
12 | !!! note "Note"
13 | If you find an error in this documentation, please consider [opening an issue](https://github.com/OpenNMT/OpenNMT-py/issues/new) or directly submitting a modification by clicking on the edit button at the top of a page.
14 |
--------------------------------------------------------------------------------
/docs/installation.md:
--------------------------------------------------------------------------------
1 | ## Standard
2 |
3 | 1\. [Install PyTorch](http://pytorch.org/)
4 |
5 | 2\. Clone the OpenNMT-py repository:
6 |
7 | ```bash
8 | git clone https://github.com/OpenNMT/OpenNMT-py
9 | cd OpenNMT-py
10 | ```
11 |
12 | And you are ready to go! Take a look at the [quickstart](quickstart.md) to familiarize yourself with the main training workflow.
13 |
14 |
--------------------------------------------------------------------------------
/docs/options/preprocess.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # preprocess.py:
4 |
5 | ```
6 | usage: preprocess.py [-h] [-md] [-config CONFIG] -train_src TRAIN_SRC
7 | -train_tgt TRAIN_TGT -valid_src VALID_SRC -valid_tgt
8 | VALID_TGT -save_data SAVE_DATA
9 | [-src_vocab_size SRC_VOCAB_SIZE]
10 | [-tgt_vocab_size TGT_VOCAB_SIZE] [-src_vocab SRC_VOCAB]
11 | [-tgt_vocab TGT_VOCAB] [-seq_length SEQ_LENGTH]
12 | [-shuffle SHUFFLE] [-seed SEED] [-lower]
13 | [-report_every REPORT_EVERY]
14 |
15 | ```
16 |
17 | preprocess.py
18 |
19 | ## **optional arguments**:
20 | ### **-h, --help**
21 |
22 | ```
23 | show this help message and exit
24 | ```
25 |
26 | ### **-md**
27 |
28 | ```
29 | print Markdown-formatted help text and exit.
30 | ```
31 |
32 | ### **-config CONFIG**
33 |
34 | ```
35 | Read options from this file
36 | ```
37 |
38 | ### **-train_src TRAIN_SRC**
39 |
40 | ```
41 | Path to the training source data
42 | ```
43 |
44 | ### **-train_tgt TRAIN_TGT**
45 |
46 | ```
47 | Path to the training target data
48 | ```
49 |
50 | ### **-valid_src VALID_SRC**
51 |
52 | ```
53 | Path to the validation source data
54 | ```
55 |
56 | ### **-valid_tgt VALID_TGT**
57 |
58 | ```
59 | Path to the validation target data
60 | ```
61 |
62 | ### **-save_data SAVE_DATA**
63 |
64 | ```
65 | Output file for the prepared data
66 | ```
67 |
68 | ### **-src_vocab_size SRC_VOCAB_SIZE**
69 |
70 | ```
71 | Size of the source vocabulary
72 | ```
73 |
74 | ### **-tgt_vocab_size TGT_VOCAB_SIZE**
75 |
76 | ```
77 | Size of the target vocabulary
78 | ```
79 |
80 | ### **-src_vocab SRC_VOCAB**
81 |
82 | ```
83 | Path to an existing source vocabulary
84 | ```
85 |
86 | ### **-tgt_vocab TGT_VOCAB**
87 |
88 | ```
89 | Path to an existing target vocabulary
90 | ```
91 |
92 | ### **-seq_length SEQ_LENGTH**
93 |
94 | ```
95 | Maximum sequence length
96 | ```
97 |
98 | ### **-shuffle SHUFFLE**
99 |
100 | ```
101 | Shuffle data
102 | ```
103 |
104 | ### **-seed SEED**
105 |
106 | ```
107 | Random seed
108 | ```
109 |
110 | ### **-lower**
111 |
112 | ```
113 | lowercase data
114 | ```
115 |
116 | ### **-report_every REPORT_EVERY**
117 |
118 | ```
119 | Report status every this many sentences
120 | ```
121 |
--------------------------------------------------------------------------------
/docs/options/train.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # train.py:
4 |
5 | ```
6 | usage: train.py [-h] [-md] -data DATA [-save_model SAVE_MODEL]
7 | [-train_from_state_dict TRAIN_FROM_STATE_DICT]
8 | [-train_from TRAIN_FROM] [-layers LAYERS] [-rnn_size RNN_SIZE]
9 | [-word_vec_size WORD_VEC_SIZE] [-input_feed INPUT_FEED]
10 | [-brnn] [-brnn_merge BRNN_MERGE] [-batch_size BATCH_SIZE]
11 | [-max_generator_batches MAX_GENERATOR_BATCHES]
12 | [-epochs EPOCHS] [-start_epoch START_EPOCH]
13 | [-param_init PARAM_INIT] [-optim OPTIM]
14 | [-max_grad_norm MAX_GRAD_NORM] [-dropout DROPOUT]
15 | [-curriculum] [-extra_shuffle] [-learning_rate LEARNING_RATE]
16 | [-learning_rate_decay LEARNING_RATE_DECAY]
17 | [-start_decay_at START_DECAY_AT]
18 | [-pre_word_vecs_enc PRE_WORD_VECS_ENC]
19 | [-pre_word_vecs_dec PRE_WORD_VECS_DEC] [-gpus GPUS [GPUS ...]]
20 | [-log_interval LOG_INTERVAL]
21 |
22 | ```
23 |
24 | train.py
25 |
26 | ## **optional arguments**:
27 | ### **-h, --help**
28 |
29 | ```
30 | show this help message and exit
31 | ```
32 |
33 | ### **-md**
34 |
35 | ```
36 | print Markdown-formatted help text and exit.
37 | ```
38 |
39 | ### **-data DATA**
40 |
41 | ```
42 | Path to the *-train.pt file from preprocess.py
43 | ```
44 |
45 | ### **-save_model SAVE_MODEL**
46 |
47 | ```
48 | Model filename (the model will be saved as <save_model>_epochN_PPL.pt where PPL
49 | is the validation perplexity)
50 | ```
51 |
52 | ### **-train_from_state_dict TRAIN_FROM_STATE_DICT**
53 |
54 | ```
55 | If training from a checkpoint then this is the path to the pretrained model's
56 | state_dict.
57 | ```
58 |
59 | ### **-train_from TRAIN_FROM**
60 |
61 | ```
62 | If training from a checkpoint then this is the path to the pretrained model.
63 | ```
64 |
65 | ### **-layers LAYERS**
66 |
67 | ```
68 | Number of layers in the LSTM encoder/decoder
69 | ```
70 |
71 | ### **-rnn_size RNN_SIZE**
72 |
73 | ```
74 | Size of LSTM hidden states
75 | ```
76 |
77 | ### **-word_vec_size WORD_VEC_SIZE**
78 |
79 | ```
80 | Word embedding sizes
81 | ```
82 |
83 | ### **-input_feed INPUT_FEED**
84 |
85 | ```
86 | Feed the context vector at each time step as additional input (via concatenation
87 | with the word embeddings) to the decoder.
88 | ```
89 |
90 | ### **-brnn**
91 |
92 | ```
93 | Use a bidirectional encoder
94 | ```
95 |
96 | ### **-brnn_merge BRNN_MERGE**
97 |
98 | ```
99 | Merge action for the bidirectional hidden states: [concat|sum]
100 | ```
101 |
102 | ### **-batch_size BATCH_SIZE**
103 |
104 | ```
105 | Maximum batch size
106 | ```
107 |
108 | ### **-max_generator_batches MAX_GENERATOR_BATCHES**
109 |
110 | ```
111 | Maximum batches of words in a sequence to run the generator on in parallel.
112 | Higher is faster, but uses more memory.
113 | ```
114 |
115 | ### **-epochs EPOCHS**
116 |
117 | ```
118 | Number of training epochs
119 | ```
120 |
121 | ### **-start_epoch START_EPOCH**
122 |
123 | ```
124 | The epoch from which to start
125 | ```
126 |
127 | ### **-param_init PARAM_INIT**
128 |
129 | ```
130 | Parameters are initialized over uniform distribution with support (-param_init,
131 | param_init)
132 | ```
133 |
134 | ### **-optim OPTIM**
135 |
136 | ```
137 | Optimization method. [sgd|adagrad|adadelta|adam]
138 | ```
139 |
140 | ### **-max_grad_norm MAX_GRAD_NORM**
141 |
142 | ```
143 | If the norm of the gradient vector exceeds this, renormalize it to have the norm
144 | equal to max_grad_norm
145 | ```
146 |
147 | ### **-dropout DROPOUT**
148 |
149 | ```
150 | Dropout probability; applied between LSTM stacks.
151 | ```
152 |
153 | ### **-curriculum**
154 |
155 | ```
156 | For this many epochs, order the minibatches based on source sequence length.
157 | Sometimes setting this to 1 will increase convergence speed.
158 | ```
159 |
160 | ### **-extra_shuffle**
161 |
162 | ```
163 | By default only shuffle mini-batch order; when true, shuffle and re-assign mini-
164 | batches
165 | ```
166 |
167 | ### **-learning_rate LEARNING_RATE**
168 |
169 | ```
170 | Starting learning rate. If adagrad/adadelta/adam is used, then this is the
171 | global learning rate. Recommended settings: sgd = 1, adagrad = 0.1, adadelta =
172 | 1, adam = 0.001
173 | ```
174 |
175 | ### **-learning_rate_decay LEARNING_RATE_DECAY**
176 |
177 | ```
178 | If update_learning_rate, decay learning rate by this much if (i) perplexity does
179 | not decrease on the validation set or (ii) epoch has gone past start_decay_at
180 | ```
181 |
182 | ### **-start_decay_at START_DECAY_AT**
183 |
184 | ```
185 | Start decaying every epoch after and including this epoch
186 | ```
187 |
188 | ### **-pre_word_vecs_enc PRE_WORD_VECS_ENC**
189 |
190 | ```
191 | If a valid path is specified, then this will load pretrained word embeddings on
192 | the encoder side. See README for specific formatting instructions.
193 | ```
194 |
195 | ### **-pre_word_vecs_dec PRE_WORD_VECS_DEC**
196 |
197 | ```
198 | If a valid path is specified, then this will load pretrained word embeddings on
199 | the decoder side. See README for specific formatting instructions.
200 | ```
201 |
202 | ### **-gpus GPUS [GPUS ...]**
203 |
204 | ```
205 | Use CUDA on the listed devices.
206 | ```
207 |
208 | ### **-log_interval LOG_INTERVAL**
209 |
210 | ```
211 | Print stats at this interval.
212 | ```
213 |
--------------------------------------------------------------------------------
/docs/options/translate.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # translate.py:
4 |
5 | ```
6 | usage: translate.py [-h] [-md] -model MODEL -src SRC [-tgt TGT]
7 | [-output OUTPUT] [-beam_size BEAM_SIZE]
8 | [-batch_size BATCH_SIZE]
9 | [-max_sent_length MAX_SENT_LENGTH] [-replace_unk]
10 | [-verbose] [-n_best N_BEST] [-gpu GPU]
11 |
12 | ```
13 |
14 | translate.py
15 |
16 | ## **optional arguments**:
17 | ### **-h, --help**
18 |
19 | ```
20 | show this help message and exit
21 | ```
22 |
23 | ### **-md**
24 |
25 | ```
26 | print Markdown-formatted help text and exit.
27 | ```
28 |
29 | ### **-model MODEL**
30 |
31 | ```
32 | Path to model .pt file
33 | ```
34 |
35 | ### **-src SRC**
36 |
37 | ```
38 | Source sequence to decode (one line per sequence)
39 | ```
40 |
41 | ### **-tgt TGT**
42 |
43 | ```
44 | True target sequence (optional)
45 | ```
46 |
47 | ### **-output OUTPUT**
48 |
49 | ```
50 | Path to output the predictions (each line will be the decoded sequence)
51 | ```
52 |
53 | ### **-beam_size BEAM_SIZE**
54 |
55 | ```
56 | Beam size
57 | ```
58 |
59 | ### **-batch_size BATCH_SIZE**
60 |
61 | ```
62 | Batch size
63 | ```
64 |
65 | ### **-max_sent_length MAX_SENT_LENGTH**
66 |
67 | ```
68 | Maximum sentence length.
69 | ```
70 |
71 | ### **-replace_unk**
72 |
73 | ```
74 | Replace the generated UNK tokens with the source token that had the highest
75 | attention weight. If phrase_table is provided, it will lookup the identified
76 | source token and give the corresponding target token. If it is not provided (or
77 | the identified source token does not exist in the table) then it will copy the
78 | source token
79 | ```
80 |
81 | ### **-verbose**
82 |
83 | ```
84 | Print scores and predictions for each sentence
85 | ```
86 |
87 | ### **-n_best N_BEST**
88 |
89 | ```
90 | If verbose is set, will output the n_best decoded sentences
91 | ```
92 |
93 | ### **-gpu GPU**
94 |
95 | ```
96 | Device to run on
97 | ```
98 |
--------------------------------------------------------------------------------
/docs/quickstart.md:
--------------------------------------------------------------------------------
1 | ## Step 1: Preprocess the data
2 |
3 | ```bash
4 | python preprocess.py -train_src data/src-train.txt -train_tgt data/tgt-train.txt -valid_src data/src-val.txt -valid_tgt data/tgt-val.txt -save_data data/demo
5 | ```
6 |
7 | We will be working with some example data in the `data/` folder.
8 |
9 | The data consists of parallel source (`src`) and target (`tgt`) data containing one sentence per line with tokens separated by a space:
10 |
11 | * `src-train.txt`
12 | * `tgt-train.txt`
13 | * `src-val.txt`
14 | * `tgt-val.txt`
15 |
16 | Validation files are required and are used to evaluate the convergence of the training. They usually contain no more than 5000 sentences.
17 |
18 | ```text
19 | $ head -n 3 data/src-train.txt
20 | It is not acceptable that , with the help of the national bureaucracies , Parliament 's legislative prerogative should be made null and void by means of implementing provisions whose content , purpose and extent are not laid down in advance .
21 | Federal Master Trainer and Senior Instructor of the Italian Federation of Aerobic Fitness , Group Fitness , Postural Gym , Stretching and Pilates; from 2004 , he has been collaborating with Antiche Terme as personal Trainer and Instructor of Stretching , Pilates and Postural Gym .
22 | " Two soldiers came up to me and told me that if I refuse to sleep with them , they will kill me . They beat me and ripped my clothes .
23 | ```
24 |
25 | After running the preprocessing, the following files are generated:
26 |
27 | * `demo.src.dict`: Dictionary of source vocab to index mappings.
28 | * `demo.tgt.dict`: Dictionary of target vocab to index mappings.
29 | * `demo.train.pt`: serialized PyTorch file containing vocabulary, training and validation data
30 |
31 | The `*.dict` files are needed to check or reuse the vocabularies. These files are simple human-readable dictionaries.
32 |
33 | ```text
34 | $ head -n 10 data/demo.src.dict
35 | <blank> 1
36 | <unk> 2
37 | <s> 3
38 | </s> 4
39 | It 5
40 | is 6
41 | not 7
42 | acceptable 8
43 | that 9
44 | , 10
45 | with 11
46 | ```
47 |
48 | Internally the system never touches the words themselves, but uses these indices.
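
For example, you can look up the index assigned to any token directly in these files (using the sample vocabulary above):

```text
$ grep "^acceptable " data/demo.src.dict
acceptable 8
```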
49 |
50 | ## Step 2: Train the model
51 |
52 | ```bash
53 | python train.py -data data/demo.train.pt -save_model demo-model
54 | ```
55 |
56 | The main train command is quite simple. Minimally it takes a data file
57 | and a save file. This will run the default model, which consists of a
58 | 2-layer LSTM with 500 hidden units on both the encoder/decoder. You
59 | can also add `-gpus 1` to use (say) GPU 1.
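
If you prefer to spell those defaults out explicitly, the call below is a sketch that assumes the option names listed in docs/options/train.md (the SR-NMT models in this repository are normally launched through train.sh instead):

```bash
python train.py -data data/demo.train.pt -save_model demo-model -layers 2 -rnn_size 500 -gpus 1
```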
60 |
61 | ## Step 3: Translate
62 |
63 | ```bash
64 | python translate.py -model demo-model_epochX_PPL.pt -src data/src-test.txt -output pred.txt -replace_unk -verbose
65 | ```
66 |
67 | Now you have a model which you can use to predict on new data. We do this by running beam search. This will output predictions into `pred.txt`.
68 |
69 | !!! note "Note"
70 | The predictions are going to be quite terrible, as the demo dataset is small. Try running on some larger datasets! For example you can download millions of parallel sentences for [translation](http://www.statmt.org/wmt16/translation-task.html) or [summarization](https://github.com/harvardnlp/sent-summary).
71 |
--------------------------------------------------------------------------------
/docs/references.md:
--------------------------------------------------------------------------------
1 | This is the list of papers that OpenNMT has been inspired by:
2 |
3 | * Luong, M. T., Pham, H., & Manning, C. D. (2015). [Effective approaches to attention-based neural machine translation](https://arxiv.org/abs/1508.04025). arXiv preprint arXiv:1508.04025.
4 | * Sennrich, R., & Haddow, B. (2016). [Linguistic input features improve neural machine translation](https://arxiv.org/abs/1606.02892). arXiv preprint arXiv:1606.02892.
5 | * Sennrich, R., Haddow, B., & Birch, A. (2015). [Neural machine translation of rare words with subword units](https://arxiv.org/abs/1508.07909). arXiv preprint arXiv:1508.07909.
6 | * Wu, Y., Schuster, M., Chen, Z., Le, Q. V., Norouzi, M., Macherey, W., ... & Klingner, J. (2016). [Google's Neural Machine Translation System: Bridging the Gap between Human and Machine Translation](https://arxiv.org/abs/1609.08144). arXiv preprint arXiv:1609.08144.
7 | * Jean, S., Cho, K., Memisevic, R., Bengio, Y. (2015). [On Using Very Large Target Vocabulary for Neural Machine Translation](http://www.aclweb.org/anthology/P15-1001). ACL 2015
8 |
--------------------------------------------------------------------------------
/eval.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | hyp=$1
4 | tgt=$2
5 | tok_gold_targets=$3
6 |
7 | mosesdecoder=/hltsrv1/software/moses/moses-20150228_kenlm_cmph_xmlrpc_irstlm_master/
8 |
9 | sed -e "s/@@ //g" < $hyp | $mosesdecoder/scripts/tokenizer/detokenizer.perl -l $tgt | $mosesdecoder/scripts/recaser/detruecase.perl > $hyp.tmp
10 | # Tokenize.
11 | perl $mosesdecoder/scripts/tokenizer/tokenizer.perl -l $tgt < $hyp.tmp > $hyp.tok
12 |
13 | # Put compounds in ATAT format (comparable to papers like GNMT, ConvS2S).
14 | # See https://nlp.stanford.edu/projects/nmt/ :
15 | # 'Also, for historical reasons, we split compound words, e.g.,
16 | # "rich-text format" --> rich ##AT##-##AT## text format."'
17 | perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' < $tok_gold_targets > $tok_gold_targets.atat
18 | perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' < $hyp.tok > $hyp.atat
19 |
20 | # Get BLEU.
21 | perl $mosesdecoder/scripts/generic/multi-bleu.perl $tok_gold_targets.atat < $hyp.atat
22 |
--------------------------------------------------------------------------------
/mkdocs.yml:
--------------------------------------------------------------------------------
1 | site_name: OpenNMT-py
2 | repo_name: 'OpenNMT/OpenNMT-py'
3 | repo_url: https://github.com/OpenNMT/OpenNMT-py
4 | edit_uri: edit/master/docs/
5 |
6 | docs_dir: docs
7 | theme: 'material'
8 | extra:
9 | logo: 'img/logo-alpha.png'
10 | social:
11 | - type: 'globe'
12 | link: 'http://opennmt.net'
13 | - type: 'github'
14 | link: 'https://github.com/OpenNMT/OpenNMT-py'
15 |
16 | google_analytics: ['UA-89222039-1', 'opennmt.net']
17 | extra_javascript:
18 | - https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-MML-AM_HTMLorMML
19 | extra_css:
20 | - css/extra.css
21 |
22 | markdown_extensions:
23 | - math
24 | - sane_lists
25 | - def_list
26 | - fenced_code
27 | - admonition
28 | - codehilite(guess_lang=false)
29 | - toc(permalink=true)
30 |
31 | pages:
32 | - Overview: index.md
33 | - Installation: installation.md
34 | - Quickstart: quickstart.md
35 | - "Extended Example": extended.md
36 |
37 | # - Data:
38 | # - Preparation: data/preparation.md
39 | # - "Word features": data/word_features.md
40 | # - Training:
41 | # - Models: training/models.md
42 | # - Embeddings: training/embeddings.md
43 | # - Logs: training/logs.md
44 | # - "Multi GPU": training/multi_gpu.md
45 | # - Retraining: training/retraining.md
46 | # - "Decay strategies": training/decay.md
47 | # - "Data sampling": training/sampling.md
48 | # - Translation:
49 | # - Inference: translation/inference.md
50 | # - "Beam search": translation/beam_search.md
51 | # - "Unknown words": translation/unknowns.md
52 | # - Tools:
53 | # - Tokenization: tools/tokenization.md
54 | # - Servers: tools/servers.md
55 | - "Reference: Options":
56 | # - "Scripts usage": options/usage.md
57 | - "preprocess.py": options/preprocess.md
58 | - "train.py": options/train.md
59 | - "translate.py": options/translate.md
60 | # - "tag.lua": options/tag.md
61 | # - "tools/build_vocab.lua": options/build_vocab.md
62 | # - "tools/release_model.lua": options/release_model.md
63 | # - "tools/tokenize.lua": options/tokenize.md
64 | # - "tools/learn_bpe.lua": options/learn_bpe.md
65 | # - "tools/translation_server.lua": options/server.md
66 | # - "tools/rest_translation_server.lua": options/rest_server.md
67 | # - "tools/embeddings.lua": options/embeddings.md
68 | # - Extensions: extensions.md
69 | - References: references.md
70 | # - "Common issues": issues.md
71 |
--------------------------------------------------------------------------------
/onmt/Beam.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import torch
3 | import onmt
4 |
5 | """
6 | Class for managing the internals of the beam search process.
7 |
8 |
9 | hyp1-hyp1---hyp1 -hyp1
10 | \ /
11 | hyp2 \-hyp2 /-hyp2hyp2
12 | / \
13 | hyp3-hyp3---hyp3 -hyp3
14 | ========================
15 |
16 | Takes care of beams, back pointers, and scores.
17 | """
18 |
19 |
20 | class Beam(object):
21 | def __init__(self, size, cuda=False):
22 |
23 | self.size = size
24 | self.done = False
25 |
26 | self.tt = torch.cuda if cuda else torch
27 |
28 | # The score for each translation on the beam.
29 | self.scores = self.tt.FloatTensor(size).zero_()
30 | self.allScores = []
31 |
32 | # The backpointers at each time-step.
33 | self.prevKs = []
34 |
35 | # The outputs at each time-step.
36 | self.nextYs = [self.tt.LongTensor(size).fill_(onmt.Constants.PAD)]
37 | self.nextYs[0][0] = onmt.Constants.BOS
38 |
39 | # The attentions (matrix) for each time.
40 | self.attn = []
41 |
42 | def getCurrentState(self):
43 | "Get the outputs for the current timestep."
44 | return self.nextYs[-1]
45 |
46 | def getCurrentOrigin(self):
47 | "Get the backpointers for the current timestep."
48 | return self.prevKs[-1]
49 |
50 | def advance(self, wordLk, attnOut):
51 | """
52 | Given prob over words for every last beam `wordLk` and attention
53 | `attnOut`: Compute and update the beam search.
54 |
55 | Parameters:
56 |
57 | * `wordLk`- probs of advancing from the last step (K x words)
58 | * `attnOut`- attention at the last step
59 |
60 | Returns: True if beam search is complete.
61 | """
62 | numWords = wordLk.size(1)
63 |
64 | # Sum the previous scores.
65 | if len(self.prevKs) > 0:
66 | beamLk = wordLk + self.scores.unsqueeze(1).expand_as(wordLk)
67 | else:
68 | beamLk = wordLk[0]
69 |
70 | flatBeamLk = beamLk.view(-1)
71 |
72 | bestScores, bestScoresId = flatBeamLk.topk(self.size, 0, True, True)
73 | self.allScores.append(self.scores)
74 | self.scores = bestScores
75 |
76 | # bestScoresId is flattened beam x word array, so calculate which
77 | # word and beam each score came from
78 | prevK = bestScoresId / numWords
79 | self.prevKs.append(prevK)
80 | self.nextYs.append(bestScoresId - prevK * numWords)
81 | self.attn.append(attnOut.index_select(0, prevK))
82 |
83 | # End condition is when top-of-beam is EOS.
84 | if self.nextYs[-1][0] == onmt.Constants.EOS:
85 | self.done = True
86 | self.allScores.append(self.scores)
87 |
88 | return self.done
89 |
90 | def sortBest(self):
91 | return torch.sort(self.scores, 0, True)
92 |
93 | def getBest(self):
94 | "Get the score of the best in the beam."
95 | scores, ids = self.sortBest()
96 | return scores[1], ids[1]
97 |
98 | def getHyp(self, k):
99 | """
100 | Walk back to construct the full hypothesis.
101 |
102 | Parameters.
103 |
104 | * `k` - the position in the beam to construct.
105 |
106 | Returns.
107 |
108 | 1. The hypothesis
109 | 2. The attention at each time step.
110 | """
111 | hyp, attn = [], []
112 | # print(len(self.prevKs), len(self.nextYs), len(self.attn))
113 | for j in range(len(self.prevKs) - 1, -1, -1):
114 | hyp.append(self.nextYs[j+1][k])
115 | attn.append(self.attn[j][k])
116 | k = self.prevKs[j][k]
117 |
118 | return hyp[::-1], torch.stack(attn[::-1])
119 |
--------------------------------------------------------------------------------
/onmt/Constants.py:
--------------------------------------------------------------------------------
1 |
2 | PAD = 0
3 | UNK = 1
4 | BOS = 2
5 | EOS = 3
6 |
7 | PAD_WORD = '<blank>'
8 | UNK_WORD = '<unk>'
9 | BOS_WORD = '<s>'
10 | EOS_WORD = '</s>'
11 |
--------------------------------------------------------------------------------
/onmt/Dataset.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 |
3 | import math
4 | import torch
5 | from torch.autograd import Variable
6 |
7 | import onmt
8 |
9 |
10 | class Dataset(object):
11 | def __init__(self, srcData, tgtData, batchSize, cuda,
12 | volatile=False, data_type="text"):
13 | self.src = srcData
14 | self._type = data_type
15 | if tgtData:
16 | self.tgt = tgtData
17 | assert(len(self.src) == len(self.tgt))
18 | else:
19 | self.tgt = None
20 | self.cuda = cuda
21 |
22 | self.batchSize = batchSize
23 | self.numBatches = math.ceil(len(self.src)/batchSize)
24 | self.volatile = volatile
25 |
26 | def _batchify(self, data, align_right=False,
27 | include_lengths=False, dtype="text"):
28 | if dtype in ["text", "bitext", "monotext"]:
29 | lengths = [x.size(0) for x in data]
30 | max_length = max(lengths)
31 | out = data[0].new(len(data), max_length).fill_(onmt.Constants.PAD)
32 | for i in range(len(data)):
33 | data_length = data[i].size(0)
34 | offset = max_length - data_length if align_right else 0
35 | out[i].narrow(0, offset, data_length).copy_(data[i])
36 | if include_lengths:
37 | return out, lengths
38 | else:
39 | return out
40 | elif dtype == "img":
41 | heights = [x.size(1) for x in data]
42 | max_height = max(heights)
43 | widths = [x.size(2) for x in data]
44 | max_width = max(widths)
45 |
46 | out = data[0].new(len(data), 3, max_height, max_width).fill_(0)
47 | for i in range(len(data)):
48 | data_height = data[i].size(1)
49 | data_width = data[i].size(2)
50 | height_offset = max_height - data_height if align_right else 0
51 | width_offset = max_width - data_width if align_right else 0
52 | out[i].narrow(1, height_offset, data_height) \
53 | .narrow(2, width_offset, data_width).copy_(data[i])
54 | return out, widths
55 |
56 | def __getitem__(self, index):
57 | assert index < self.numBatches, "%d > %d" % (index, self.numBatches)
58 | srcBatch, lengths = self._batchify(
59 | self.src[index*self.batchSize:(index+1)*self.batchSize],
60 | align_right=False, include_lengths=True, dtype=self._type)
61 |
62 | if self.tgt:
63 | tgtBatch = self._batchify(
64 | self.tgt[index*self.batchSize:(index+1)*self.batchSize],
65 | dtype="text")
66 | else:
67 | tgtBatch = None
68 |
69 | # within batch sorting by decreasing length for variable length rnns
70 | indices = range(len(srcBatch))
71 | batch = (zip(indices, srcBatch) if tgtBatch is None
72 | else zip(indices, srcBatch, tgtBatch))
73 | batch, lengths = zip(*sorted(zip(batch, lengths), key=lambda x: -x[1]))
74 | if tgtBatch is None:
75 | indices, srcBatch = zip(*batch)
76 | else:
77 | indices, srcBatch, tgtBatch = zip(*batch)
78 |
79 | def wrap(b, dtype="text"):
80 | if b is None:
81 | return b
82 | b = torch.stack(b, 0)
83 | if dtype in ["text", "bitext", "monotext"]:
84 | b = b.t().contiguous()
85 | if self.cuda:
86 | b = b.cuda()
87 | b = Variable(b, volatile=self.volatile)
88 | return b
89 |
90 | # wrap lengths in a Variable to properly split it in DataParallel
91 | lengths = torch.LongTensor(lengths).view(1, -1)
92 | lengths = Variable(lengths, volatile=self.volatile)
93 | return (wrap(srcBatch, self._type), lengths), \
94 | wrap(tgtBatch, "text"), indices
95 |
96 | def __len__(self):
97 | return self.numBatches
98 |
99 | def shuffle(self):
100 | data = list(zip(self.src, self.tgt))
101 | self.src, self.tgt = zip(*[data[i] for i in torch.randperm(len(data))])
102 |
--------------------------------------------------------------------------------
/onmt/Decoders.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn.parameter import Parameter
5 | import onmt
6 |
7 | from onmt.modules import SRU
8 |
9 | from .modules.SRU_units import AttSRU
10 | from .modules.Attention import getAttention
11 | from .modules.Normalization import LayerNorm
12 | from torch.nn.utils.rnn import pad_packed_sequence as unpack
13 |
14 |
15 | def getDecoder(decoderType):
16 | decoders = {'StackedRNN': StackedRNNDecoder,
17 | 'SR': SGUDecoder,
18 | 'ParallelRNN': ParallelRNNDecoder}
19 |
20 | if decoderType not in decoders:
21 | raise NotImplementedError(decoderType)
22 |
23 | return decoders[decoderType]
24 |
25 |
26 | def getStackedLayer(rnn_type):
27 | if rnn_type == "LSTM":
28 | return StackedLSTM
29 | elif rnn_type == "GRU":
30 | return StackedGRU
31 | else:
32 | return None
33 |
34 | def getRNN(rnn_type):
35 | rnns = {'LSTM': nn.LSTM,
36 | 'GRU': nn.GRU
37 | }
38 |
39 | return rnns[rnn_type]
40 |
41 | class StackedLSTM(nn.Module):
42 |
43 | def __init__(self, num_layers, input_size, rnn_size, dropout):
44 | super(StackedLSTM, self).__init__()
45 | self.dropout = nn.Dropout(dropout)
46 | self.num_layers = num_layers
47 | self.layers = nn.ModuleList()
48 |
49 | for i in range(num_layers):
50 | self.layers.append(nn.LSTMCell(input_size, rnn_size))
51 | input_size = rnn_size
52 |
53 | def forward(self, input, hidden):
54 | h_0, c_0 = hidden
55 | h_1, c_1 = [], []
56 | for i, layer in enumerate(self.layers):
57 | h_1_i, c_1_i = layer(input, (h_0[i], c_0[i]))
58 | input = h_1_i
59 | if i + 1 != self.num_layers:
60 | input = self.dropout(input)
61 | h_1 += [h_1_i]
62 | c_1 += [c_1_i]
63 |
64 | h_1 = torch.stack(h_1)
65 | c_1 = torch.stack(c_1)
66 |
67 | return input, (h_1, c_1)
68 |
69 |
70 | class StackedGRU(nn.Module):
71 |
72 | def __init__(self, num_layers, input_size, rnn_size, dropout):
73 | super(StackedGRU, self).__init__()
74 | self.dropout = nn.Dropout(dropout)
75 | self.num_layers = num_layers
76 | self.layers = nn.ModuleList()
77 |
78 | for i in range(num_layers):
79 | self.layers.append(nn.GRUCell(input_size, rnn_size))
80 | input_size = rnn_size
81 |
82 | def forward(self, input, hidden):
83 | h_1 = []
84 | for i, layer in enumerate(self.layers):
85 | h_1_i = layer(input, hidden[i])
86 | input = h_1_i
87 | if i + 1 != self.num_layers:
88 | input = self.dropout(input)
89 | h_1 += [h_1_i]
90 |
91 | h_1 = torch.stack(h_1)
92 |
93 | return input, h_1
94 |
95 |
96 | class StackedSGU(nn.Module):
97 |
98 | def __init__(self, num_layers, input_size, rnn_size, layer_norm, dropout):
99 | super(StackedSGU, self).__init__()
100 | self.num_layers = num_layers
101 | self.layers = nn.ModuleList()
102 | self.dropout = nn.Dropout(dropout)
103 | for i in range(num_layers):
104 | self.layers.append(AttSRU(input_size,
105 | rnn_size, rnn_size, layer_norm, dropout))
106 | input_size = rnn_size
107 |
108 | def initialize_parameters(self, param_init):
109 | for layer in self.layers:
110 | layer.initialize_parameters(param_init)
111 |
112 | def forward(self, dec_state, hidden, enc_out):
113 | input = dec_state
114 | first_input = dec_state
115 | hiddens = []
116 | for i, layer in enumerate(self.layers):
117 | input, new_hidden, attn_state = layer(input, hidden[i], enc_out)
118 | hiddens += [new_hidden]
119 |
120 | return self.dropout(input), torch.stack(hiddens), attn_state
121 |
122 |
123 | class StackedRNNDecoder(nn.Module):
124 |
125 | def __init__(self, opt, dicts):
126 |
127 | self.layers = opt.layers_dec
128 | self.input_feed = opt.input_feed
129 | self.hidden_size = opt.rnn_size
130 |
131 | input_size = opt.word_vec_size
132 | if self.input_feed:
133 | input_size += opt.rnn_size
134 |
135 | super(StackedRNNDecoder, self).__init__()
136 | self.word_lut = nn.Embedding(dicts.size(),
137 | opt.word_vec_size,
138 | padding_idx=onmt.Constants.PAD)
139 |
140 | rnn_type = opt.rnn_decoder_type if opt.rnn_decoder_type else opt.rnn_type
141 | if rnn_type in ['LSTM', 'GRU']:
142 | self.rnn = getStackedLayer(rnn_type)\
143 | (opt.layers_dec, input_size, opt.rnn_size, opt.dropout)
144 | else:
145 | self.rnn = getStackedLayer(rnn_type) \
146 | (opt.layers_dec, input_size, opt.rnn_size, opt.activ,
147 | opt.layer_norm, opt.dropout)
148 |
149 | self.attn = getAttention(opt.attn_type)(opt.rnn_size, opt.activ)
150 |
151 | self.linear_ctx = nn.Linear(opt.rnn_size, opt.rnn_size)
152 | self.linear_out = nn.Linear(2 * opt.rnn_size, opt.rnn_size)
153 |
154 | self.dropout = nn.Dropout(opt.dropout)
155 | self.log = self.rnn.log if hasattr(self.rnn, 'log') else False
156 |
157 | self.layer_norm = opt.layer_norm
158 | if self.layer_norm:
159 | self.ctx_ln = LayerNorm(opt.rnn_size)
160 |
161 | self.activ = getattr(F, opt.activ)
162 |
163 | def load_pretrained_vectors(self, opt):
164 | if opt.pre_word_vecs_dec is not None:
165 | pretrained = torch.load(opt.pre_word_vecs_dec)
166 | self.word_lut.weight.data.copy_(pretrained)
167 |
168 | def initialize_parameters(self, param_init):
169 | pass
170 |
171 | def forward(self, input, hidden, context, init_output):
172 | """
173 | input: targetL x batch
174 | hidden: batch x hidden_dim
175 | context: sourceL x batch x hidden_dim
176 | init_output: batch x hidden_dim
177 | """
178 | # targetL x batch x hidden_dim
179 | emb = self.word_lut(input)
180 |
181 | # batch x sourceL x hidden_dim
182 | context = context.transpose(0, 1)
183 |
184 | # n.b. you can increase performance if you compute W_ih * x for all
185 | # iterations in parallel, but that's only possible if
186 | # self.input_feed=False
187 | outputs = []
188 | output = init_output
189 |
190 | for emb_t in emb.split(1):
191 | # batch x word_dim
192 | emb_inp = emb_t.squeeze(0)
193 |
194 | if self.input_feed == 1:
195 | # batch x (word_dim+hidden_dim)
196 | emb_inp_feed = torch.cat([emb_inp, output], 1)
197 | else:
198 | emb_inp_feed = emb_inp
199 |
200 | # batch x hidden_dim, layers x batch x hidden_dim
201 | if self.log:
202 | rnn_output, hidden, activ = self.rnn(emb_inp_feed, hidden)
203 | else:
204 | rnn_output, hidden = self.rnn(emb_inp_feed, hidden)
205 |
206 | values = context
207 | pctx = self.linear_ctx(self.dropout(context))
208 | if self.layer_norm:
209 | pctx = self.ctx_ln(pctx)
210 | weightedContext, attn = self.attn(rnn_output, pctx, values)
211 |
212 | contextCombined = self.linear_out(torch.cat([rnn_output, weightedContext], dim=-1))
213 |
214 | output = self.activ(contextCombined)
215 | output = self.dropout(output)
216 | outputs += [output]
217 |
218 | outputs = torch.stack(outputs)
219 |
220 | if self.log:
221 | return outputs, hidden, attn, activ
222 |
223 | return outputs, hidden, attn
224 |
225 |
226 | class SGUDecoder(nn.Module):
227 |
228 | def __init__(self, opt, dicts):
229 | self.layers = opt.layers_dec
230 | self.hidden_size = opt.rnn_size
231 |
232 | input_size = opt.word_vec_size
233 |
234 | super(SGUDecoder, self).__init__()
235 | self.word_lut = nn.Embedding(dicts.size(),
236 | opt.word_vec_size,
237 | padding_idx=onmt.Constants.PAD)
238 |
239 | self.stacked = StackedSGU(opt.layers_dec, opt.rnn_size,
240 | opt.rnn_size, opt.layer_norm,
241 | opt.dropout)
242 |
243 | self.log = False
244 |
245 | def load_pretrained_vectors(self, opt):
246 | if opt.pre_word_vecs_dec is not None:
247 | pretrained = torch.load(opt.pre_word_vecs_dec)
248 | self.word_lut.weight.data.copy_(pretrained)
249 |
250 | def initialize_parameters(self, param_init):
251 | self.stacked.initialize_parameters(param_init)
252 | #self.attn.initialize_parameters(param_init)
253 |
254 | def forward(self, input, hidden, context, init_output):
255 | """
256 | input: targetL x batch
257 | hidden: num_layers x batch x hidden_dim
258 | context: sourceL x batch x hidden_dim
259 | init_output: batch x hidden_dim
260 | """
261 | batch_size = input.size(1)
262 | hidden_dim = context.size(2)
263 |
264 | #targetL x batch x hidden_dim
265 | emb = self.word_lut(input)
266 |
267 | # batch x sourceL x hidden_dim
268 | context = context.transpose(0, 1)
269 | if len(hidden.size()) < 3:
270 | hidden = hidden.unsqueeze(0)
271 |
272 | # (targetL x batch) x sourceL x hidden_dim
273 | #values = context.repeat(emb.size(0), 1, 1)
274 | rnn_outputs = emb #.view(-1, hidden_dim)
275 |
276 | outputs, hidden, attn = self.stacked(rnn_outputs, hidden, context)
277 |
278 | return outputs, hidden, attn
279 |
280 |
281 | class StackedSRU(nn.Module):
282 | def __init__(self, num_layers, input_size, rnn_size, dropout):
283 | super(StackedSRU, self).__init__()
284 | self.dropout = nn.Dropout(dropout)
285 | self.num_layers = num_layers
286 | self.layers = nn.ModuleList()
287 |
288 | for i in range(num_layers):
289 | self.layers.append(SRU(input_size, rnn_size, dropout))
290 | input_size = rnn_size
291 |
292 | def initialize_parameters(self, param_init):
293 | for layer in self.layers:
294 | layer.initialize_parameters(param_init)
295 |
296 | def forward(self, input, hidden):
297 | """
298 |
299 | :param input: batch x hi
300 | :param hidden:
301 | :return:
302 | """
303 | h_1 = []
304 | for i, layer in enumerate(self.layers):
305 | h_1_i, h = layer(input, hidden[i])
306 | input = h_1_i
307 | h_1 += [h]
308 |
309 | h_1 = torch.stack(h_1)
310 |
311 | return input, h_1
312 |
313 |
314 | class ParallelRNNDecoder(nn.Module):
315 | def __init__(self, opt, dicts):
316 | from .modules.Attention import MLPAttention
317 |
318 | self.layers = opt.layers_dec
319 | self.hidden_size = opt.rnn_size
320 |
321 | input_size = opt.word_vec_size
322 |
323 | super(ParallelRNNDecoder, self).__init__()
324 | self.word_lut = nn.Embedding(dicts.size(),
325 | opt.word_vec_size,
326 | padding_idx=onmt.Constants.PAD)
327 |
328 | self.rnn = StackedSRU(self.layers, input_size, self.hidden_size, opt.dropout)
329 |
330 | self.attn = MLPAttention(opt.rnn_size, opt.activ) # getAttention(opt.attn_type)(opt.rnn_size, opt.activ)
331 |
332 | self.linear_ctx = nn.Linear(opt.rnn_size, opt.rnn_size)
333 | self.linear_out = nn.Linear(2 * opt.rnn_size, opt.rnn_size)
334 |
335 | self.dropout = nn.Dropout(opt.dropout)
336 |
337 | def load_pretrained_vectors(self, opt):
338 | if opt.pre_word_vecs_dec is not None:
339 | pretrained = torch.load(opt.pre_word_vecs_dec)
340 | self.word_lut.weight.data.copy_(pretrained)
341 |
342 | def initialize_parameters(self, param_init):
343 | pass
344 |
345 | def forward(self, input, hidden, context, init_output):
346 | """
347 | input: targetL x batch
348 | hidden: batch x hidden_dim
349 | context: sourceL x batch x hidden_dim
350 | init_output: batch x hidden_dim
351 | """
352 | # targetL x batch x hidden_dim
353 | emb = self.word_lut(input)
354 |
355 | # batch x sourceL x hidden_dim
356 | context = context.transpose(0, 1)
357 |
358 | # batch x hidden_dim, layers x batch x hidden_dim
359 | rnn_output, hidden = self.rnn(emb, hidden)
360 |
361 | values = context
362 | pctx = self.linear_ctx(self.dropout(context))
363 |
364 | weightedContext, attn = self.attn(self.dropout(rnn_output), pctx, values)
365 |
366 | contextCombined = self.linear_out(self.dropout(torch.cat([rnn_output, weightedContext], dim=-1)))
367 |
368 | output = F.tanh(contextCombined)
369 | output = self.dropout(output)
370 |
371 | return output, hidden, attn
372 |
--------------------------------------------------------------------------------
/onmt/Dict.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class Dict(object):
5 | def __init__(self, data=None, lower=False):
6 | self.idxToLabel = {}
7 | self.labelToIdx = {}
8 | self.frequencies = {}
9 | self.lower = lower
10 |
11 | # Special entries will not be pruned.
12 | self.special = []
13 |
14 | if data is not None:
15 | if type(data) == str:
16 | self.loadFile(data)
17 | else:
18 | self.addSpecials(data)
19 |
20 | def size(self):
21 | return len(self.idxToLabel)
22 |
23 | def loadFile(self, filename):
24 | "Load entries from a file."
25 | for line in open(filename):
26 | fields = line.split()
27 | label = fields[0]
28 | idx = int(fields[1])
29 | self.add(label, idx)
30 |
31 | def writeFile(self, filename):
32 | "Write entries to a file."
33 | with open(filename, 'w') as file:
34 | for i in range(self.size()):
35 | label = self.idxToLabel[i]
36 | file.write('%s %d\n' % (label, i))
37 |
38 | file.close()
39 |
40 | def lookup(self, key, default=None):
41 | key = key.lower() if self.lower else key
42 | try:
43 | return self.labelToIdx[key]
44 | except KeyError:
45 | return default
46 |
47 | def getLabel(self, idx, default=None):
48 | try:
49 | return self.idxToLabel[idx]
50 | except KeyError:
51 | return default
52 |
53 | def addSpecial(self, label, idx=None):
54 | "Mark this `label` and `idx` as special (i.e. will not be pruned)."
55 | idx = self.add(label, idx)
56 | self.special += [idx]
57 |
58 | def addSpecials(self, labels):
59 | "Mark all labels in `labels` as specials (i.e. will not be pruned)."
60 | for label in labels:
61 | self.addSpecial(label)
62 |
63 | def add(self, label, idx=None):
64 | "Add `label` in the dictionary. Use `idx` as its index if given."
65 | label = label.lower() if self.lower else label
66 | if idx is not None:
67 | self.idxToLabel[idx] = label
68 | self.labelToIdx[label] = idx
69 | else:
70 | if label in self.labelToIdx:
71 | idx = self.labelToIdx[label]
72 | else:
73 | idx = len(self.idxToLabel)
74 | self.idxToLabel[idx] = label
75 | self.labelToIdx[label] = idx
76 |
77 | if idx not in self.frequencies:
78 | self.frequencies[idx] = 1
79 | else:
80 | self.frequencies[idx] += 1
81 |
82 | return idx
83 |
84 | def prune(self, size):
85 | "Return a new dictionary with the `size` most frequent entries."
86 | if size >= self.size():
87 | return self
88 |
89 | # Only keep the `size` most frequent entries.
90 | freq = torch.Tensor(
91 | [self.frequencies[i] for i in range(len(self.frequencies))])
92 | _, idx = torch.sort(freq, 0, True)
93 |
94 | newDict = Dict()
95 | newDict.lower = self.lower
96 |
97 | # Add special entries in all cases.
98 | for i in self.special:
99 | newDict.addSpecial(self.idxToLabel[i])
100 |
101 | for i in idx[:size]:
102 | newDict.add(self.idxToLabel[i])
103 |
104 | return newDict
105 |
106 | def convertToIdx(self, labels, unkWord, bosWord=None, eosWord=None):
107 | """
108 | Convert `labels` to indices. Use `unkWord` if not found.
109 | Optionally insert `bosWord` at the beginning and `eosWord` at the end.
110 | """
111 | vec = []
112 |
113 | if bosWord is not None:
114 | vec += [self.lookup(bosWord)]
115 |
116 | unk = self.lookup(unkWord)
117 | vec += [self.lookup(label, default=unk) for label in labels]
118 |
119 | if eosWord is not None:
120 | vec += [self.lookup(eosWord)]
121 |
122 | return torch.LongTensor(vec)
123 |
124 | def convertToLabels(self, idx, stop):
125 | """
126 | Convert `idx` to labels.
127 | If index `stop` is reached, convert it and return.
128 | """
129 |
130 | labels = []
131 |
132 | for i in idx:
133 | labels += [self.getLabel(i)]
134 | if i == stop:
135 | break
136 |
137 | return labels
138 |
--------------------------------------------------------------------------------
/onmt/Encoders.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 |
5 | import onmt
6 | from .modules.SRU_units import BiSRU
7 |
8 | from torch.nn.utils.rnn import PackedSequence
9 | from torch.nn.utils.rnn import pad_packed_sequence as unpack
10 | from torch.nn.utils.rnn import pack_padded_sequence as pack
11 |
12 | from onmt.modules.Units import ParallelMyRNN
13 |
14 | def getEncoder(encoder_type):
15 | encoders = {'RNN': Encoder,
16 | 'SR': SGUEncoder}
17 | if encoder_type not in encoders:
18 | raise NotImplementedError(encoder_type)
19 | return encoders[encoder_type]
20 |
21 |
22 | class Encoder(nn.Module):
23 |
24 | def __init__(self, opt, dicts):
25 |
26 | def getunittype(rnn_type):
27 | if rnn_type in ['LSTM', 'GRU']:
28 | return getattr(nn, rnn_type)
29 | elif rnn_type == 'SRU':
30 | return ParallelMyRNN
31 |
32 | self.layers = opt.layers_enc
33 | self.num_directions = 2 if opt.brnn else 1
34 | assert opt.rnn_size % self.num_directions == 0
35 | self.hidden_size = opt.rnn_size // self.num_directions
36 |
37 | super(Encoder, self).__init__()
38 | self.word_lut = nn.Embedding(dicts.size(),
39 | opt.word_vec_size,
40 | padding_idx=onmt.Constants.PAD)
41 |
42 | rnn_type = opt.rnn_encoder_type if opt.rnn_encoder_type else opt.rnn_type
43 | self.rnn = getunittype(rnn_type)(
44 | opt.word_vec_size, self.hidden_size,
45 | num_layers=opt.layers_enc, dropout=opt.dropout,
46 | bidirectional=opt.brnn)
47 |
48 | def load_pretrained_vectors(self, opt):
49 | if opt.pre_word_vecs_enc is not None:
50 | pretrained = torch.load(opt.pre_word_vecs_enc)
51 | self.word_lut.weight.data.copy_(pretrained)
52 |
53 | def initialize_parameters(self, param_init):
54 | if hasattr(self.rnn, 'initialize_parameters'):
55 | self.rnn.initialize_parameters(param_init)
56 |
57 | def forward(self, input, hidden=None):
58 |
59 | if isinstance(input, tuple):
60 | # Lengths data is wrapped inside a Variable.
61 | lengths = input[1].data.view(-1).tolist()
62 | emb = pack(self.word_lut(input[0]), lengths)
63 | else:
64 | emb = self.word_lut(input)
65 | outputs, hidden_t = self.rnn(emb, hidden)
66 | if isinstance(outputs, PackedSequence):
67 | outputs = unpack(outputs)[0]
68 |
69 | return hidden_t, outputs, emb
70 |
71 | class StackedSGU(nn.Module):
72 |
73 | def __init__(self, layers, input_size, hidden_size, layer_norm, dropout):
74 | self.layers = layers
75 | super(StackedSGU, self).__init__()
76 | self.sgus = nn.ModuleList()
77 | self.dropout = nn.Dropout(dropout)
78 | for _ in range(layers):
79 | self.sgus.append(BiSRU(input_size, hidden_size, layer_norm, dropout))
80 | input_size = hidden_size
81 |
82 | def initialize_parameters(self, param_init):
83 | for sgu in self.sgus:
84 | sgu.initialize_parameters(param_init)
85 |
86 | def forward(self, input):
87 |
88 | hiddens = []
89 | for i in range(self.layers):
90 | input = self.sgus[i](input)
91 | hiddens += [input[-1]]
92 | return input, torch.stack(hiddens)
93 |
94 | class SGUEncoder(nn.Module):
95 |
96 | def __init__(self, opt, dicts):
97 |
98 | self.layers = opt.layers_enc
99 | self.num_directions = 2 if opt.brnn else 1
100 | assert opt.rnn_size % self.num_directions == 0
101 | self.hidden_size = opt.rnn_size // self.num_directions
102 |
103 | super(SGUEncoder, self).__init__()
104 | self.word_lut = nn.Embedding(dicts.size(),
105 | opt.word_vec_size,
106 | padding_idx=onmt.Constants.PAD)
107 | self.sgu = StackedSGU(self.layers, opt.word_vec_size,
108 | self.hidden_size * self.num_directions, opt.layer_norm,
109 | opt.dropout)
110 |
111 |
112 | def load_pretrained_vectors(self, opt):
113 | if opt.pre_word_vecs_enc is not None:
114 | pretrained = torch.load(opt.pre_word_vecs_enc)
115 | self.word_lut.weight.data.copy_(pretrained)
116 |
117 | def initialize_parameters(self, param_init):
118 | self.sgu.initialize_parameters(param_init)
119 |
120 | def forward(self, input, hidden=None):
121 |
122 | if isinstance(input, tuple):
123 | # Lengths data is wrapped inside a Variable.
124 | lengths = input[1].data.view(-1).tolist()
125 | emb = self.word_lut(input[0])
126 | else:
127 | emb = self.word_lut(input)
128 | outputs, hidden_t = self.sgu(emb)
129 |
130 | return hidden_t, outputs, emb
131 |
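A hedged sketch of wiring up the SR encoder through `getEncoder`; the option namespace below is a hypothetical minimum (in real runs `opt` comes from `train.py`):

```python
import argparse
import torch
from torch.autograd import Variable
import onmt
from onmt.Encoders import getEncoder

# Hypothetical minimal options; field names match what SGUEncoder reads.
opt = argparse.Namespace(layers_enc=2, brnn=True, rnn_size=64,
                         word_vec_size=64, layer_norm=True, dropout=0.1)

src_dict = onmt.Dict([onmt.Constants.PAD_WORD, onmt.Constants.UNK_WORD])
for w in "ein kleines beispiel".split():
    src_dict.add(w)

encoder = getEncoder('SR')(opt, src_dict)
src = Variable(torch.LongTensor(5, 3).random_(0, src_dict.size()))  # srcL x batch
hidden_t, outputs, emb = encoder(src)
print(outputs.size())   # srcL x batch x rnn_size
```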
--------------------------------------------------------------------------------
/onmt/Markdown.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The Chromium Authors. All rights reserved.
2 | # Use of this source code is governed by a BSD-style license that can be
3 | # found in the LICENSE file.
4 | import argparse
5 |
6 |
7 | class MarkdownHelpFormatter(argparse.HelpFormatter):
8 | """A really bare-bones argparse help formatter that generates valid markdown.
9 | This will generate something like:
10 | usage
11 | # **section heading**:
12 | ## **--argument-one**
13 | ```
14 | argument-one help text
15 | ```
16 | """
17 |
18 | def _format_usage(self, usage, actions, groups, prefix):
19 | usage_text = super(MarkdownHelpFormatter, self)._format_usage(
20 | usage, actions, groups, prefix)
21 | return '\n```\n%s\n```\n\n' % usage_text
22 |
23 | def format_help(self):
24 | self._root_section.heading = '# %s' % self._prog
25 | return super(MarkdownHelpFormatter, self).format_help()
26 |
27 | def start_section(self, heading):
28 | super(MarkdownHelpFormatter, self).start_section('## **%s**' % heading)
29 |
30 | def _format_action(self, action):
31 | lines = []
32 | action_header = self._format_action_invocation(action)
33 | lines.append('### **%s** ' % action_header)
34 | if action.help:
35 | lines.append('')
36 | lines.append('```')
37 | help_text = self._expand_help(action)
38 | lines.extend(self._split_lines(help_text, 80))
39 | lines.append('```')
40 | lines.extend(['', ''])
41 | return '\n'.join(lines)
42 |
43 |
44 | class MarkdownHelpAction(argparse.Action):
45 | def __init__(self, option_strings,
46 | dest=argparse.SUPPRESS, default=argparse.SUPPRESS,
47 | **kwargs):
48 | super(MarkdownHelpAction, self).__init__(
49 | option_strings=option_strings,
50 | dest=dest,
51 | default=default,
52 | nargs=0,
53 | **kwargs)
54 |
55 | def __call__(self, parser, namespace, values, option_string=None):
56 | parser.formatter_class = MarkdownHelpFormatter
57 | parser.print_help()
58 | parser.exit()
59 |
60 |
61 | def add_md_help_argument(parser):
62 | parser.add_argument('-md', action=MarkdownHelpAction,
63 | help='print Markdown-formatted help text and exit.')
64 |
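A small sketch (hypothetical script, not repository code) of how the `-md` action is meant to be attached to a parser; running the script with `-md` prints Markdown-formatted help and exits, which is presumably how the pages under `docs/options/` are produced via `docs/generate.sh`:

```python
import argparse
import onmt.Markdown

parser = argparse.ArgumentParser(description='demo.py')
onmt.Markdown.add_md_help_argument(parser)
parser.add_argument('-epochs', type=int, default=13,
                    help="Number of training epochs")

opt = parser.parse_args()   # `python demo.py -md` dumps Markdown help and exits
```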
--------------------------------------------------------------------------------
/onmt/Models.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from torch.autograd import Variable
3 | from .Encoders import Encoder
4 |
5 |
6 | class NMTModel(nn.Module):
7 |
8 | def __init__(self, encoder, decoder):
9 | super(NMTModel, self).__init__()
10 | self.encoder = encoder
11 | self.decoder = decoder
12 |
13 | def make_init_decoder_output(self, context):
14 | batch_size = context.size(1)
15 | h_size = (batch_size, self.decoder.hidden_size)
16 | return Variable(context.data.new(*h_size).zero_(), requires_grad=False)
17 |
18 | def load_pretrained_vectors(self, opt):
19 | self.encoder.load_pretrained_vectors(opt)
20 | self.decoder.load_pretrained_vectors(opt)
21 |
22 | def initialize_parameters(self, param_init):
23 | self.encoder.initialize_parameters(param_init)
24 | self.decoder.initialize_parameters(param_init)
25 |
26 | def brnn_merge_concat(self, h):
27 | # the encoder hidden is (layers*directions) x batch x dim
28 | # we need to convert it to layers x batch x (directions*dim)
29 | if self.encoder.num_directions == 2:
30 | return h.view(h.size(0) // 2, 2, h.size(1), h.size(2)) \
31 | .transpose(1, 2).contiguous() \
32 | .view(h.size(0) // 2, h.size(1), h.size(2) * 2)
33 | else:
34 | return h
35 |
36 | def forward(self, input):
37 | src = input[0]
38 | tgt = input[1][:-1] # exclude last target from inputs
39 | enc_hidden, context, emb = self.encoder(src)
40 | init_output = self.make_init_decoder_output(context)
41 |
42 | if isinstance(self.encoder, Encoder):
43 | if isinstance(enc_hidden, tuple):
44 | enc_hidden = tuple(self.brnn_merge_concat(enc_hidden[i])
45 | for i in range(len(enc_hidden)))
46 | else:
47 | enc_hidden = self.brnn_merge_concat(enc_hidden)
48 | if enc_hidden.size(0) < self.decoder.layers:
49 | enc_hidden = enc_hidden.repeat(self.decoder.layers, 1, 1)
50 | else:
51 | enc_hidden = Variable(enc_hidden.data.new(*enc_hidden.size()).zero_(), requires_grad=False)
52 |
53 | #self.decoder.mask_attention(src[0])
54 | out, dec_hidden, _attn = self.decoder(tgt, enc_hidden,
55 | context, init_output)
56 | return out
57 |
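A shape-only check (illustrative) of what `brnn_merge_concat` does to a bidirectional encoder state:

```python
import torch
from torch.autograd import Variable

layers, directions, batch, dim = 2, 2, 4, 8
h = Variable(torch.randn(layers * directions, batch, dim))

# Same reshape as NMTModel.brnn_merge_concat performs when num_directions == 2:
merged = h.view(h.size(0) // 2, 2, h.size(1), h.size(2)) \
          .transpose(1, 2).contiguous() \
          .view(h.size(0) // 2, h.size(1), h.size(2) * 2)
print(merged.size())   # layers x batch x (directions * dim) -> (2, 4, 16)
```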
--------------------------------------------------------------------------------
/onmt/Optim.py:
--------------------------------------------------------------------------------
1 | import torch.optim as optim
2 | from torch.nn.utils import clip_grad_norm
3 |
4 |
5 | class Optim(object):
6 |
7 | def set_parameters(self, params):
8 | self.params = list(params) # careful: params may be a generator
9 |
10 | if self.method == 'sgd':
11 | self.optimizer = optim.SGD(self.params, lr=self.lr, momentum=0.9, weight_decay=1e-5, nesterov=True)
12 | elif self.method == 'adagrad':
13 | self.optimizer = optim.Adagrad(self.params, lr=self.lr)
14 | elif self.method == 'adadelta':
15 | self.optimizer = optim.Adadelta(self.params, lr=self.lr)
16 | elif self.method == 'adam':
17 | self.optimizer = optim.Adam(self.params, lr=self.lr)
18 | else:
19 | raise RuntimeError("Invalid optim method: " + self.method)
20 |
21 | def __init__(self, method, lr, max_grad_norm,
22 | lr_decay=1, start_decay_at=None):
23 | self.last_ppl = None
24 | self.lr = lr
25 | self.max_grad_norm = max_grad_norm
26 | self.method = method
27 | self.lr_decay = lr_decay
28 | self.start_decay_at = start_decay_at
29 | self.start_decay = False
30 |
31 | def step(self):
32 |         "Clip gradients to max_grad_norm (if set) and take an optimization step."
33 | if self.max_grad_norm:
34 | clip_grad_norm(self.params, self.max_grad_norm)
35 | self.optimizer.step()
36 |
37 |
38 | def updateLearningRate(self, ppl, iter):
39 | """
40 | Decay learning rate if val perf does not improve
41 | or we hit the start_decay_at limit.
42 | """
43 |
44 | if self.start_decay_at is not None and iter >= self.start_decay_at:
45 | self.start_decay = True
46 | if self.last_ppl is not None and ppl > self.last_ppl:
47 | self.start_decay = True
48 |
49 | if self.start_decay:
50 | self.lr = self.lr * self.lr_decay
51 | print("Decaying learning rate to %g" % self.lr)
52 |
53 | self.last_ppl = ppl
54 | self.optimizer.param_groups[0]['lr'] = self.lr
55 |
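A minimal sketch of the intended call pattern for the wrapper (illustrative; the toy linear model is an assumption, not repository code):

```python
import torch
import torch.nn as nn
from torch.autograd import Variable
from onmt.Optim import Optim

model = nn.Linear(10, 2)
optim = Optim('sgd', 1.0, max_grad_norm=5, lr_decay=0.5, start_decay_at=8)
optim.set_parameters(model.parameters())

loss = model(Variable(torch.randn(3, 10))).sum()
loss.backward()
optim.step()                                 # clip to max_grad_norm, then update
optim.updateLearningRate(ppl=42.0, iter=9)   # decays lr once iter >= start_decay_at
```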
--------------------------------------------------------------------------------
/onmt/Translator.py:
--------------------------------------------------------------------------------
1 | import onmt
2 | import onmt.Models
3 | import onmt.modules
4 | import torch.nn as nn
5 | import torch
6 | from torch.autograd import Variable
7 | from .Decoders import getDecoder, SGUDecoder
8 | from .Encoders import getEncoder, Encoder, SGUEncoder
9 | from torch.nn.utils.rnn import PackedSequence, pad_packed_sequence as unpack
10 | import sys
11 |
12 |
13 |
14 |
15 | def loadImageLibs():
16 | "Conditional import of torch image libs."
17 | global Image, transforms
18 | from PIL import Image
19 | from torchvision import transforms
20 |
21 |
22 | class Translator(object):
23 | def __init__(self, opt):
24 | self.opt = opt
25 | self.tt = torch.cuda if opt.cuda else torch
26 | self.beam_accum = None
27 |
28 | checkpoint = torch.load(opt.model,
29 | map_location=lambda storage, loc: storage)
30 |
31 | model_opt = checkpoint['opt']
32 | self.model_opt = model_opt
33 | self.src_dict = checkpoint['dicts']['src']
34 | self.tgt_dict = checkpoint['dicts']['tgt']
35 | self._type = "text" #model_opt.encoder_type \
36 | #if "encoder_type" in model_opt else "text"
37 |
38 | #if self._type == "text":
39 | # encoder = Encoder(model_opt, self.src_dict)
40 | #elif self._type == "img":
41 | # loadImageLibs()
42 | # encoder = onmt.modules.ImageEncoder(model_opt)
43 | print("Translator layer_norm:", model_opt.layer_norm)
44 |
45 | encoder = getEncoder(model_opt.encoder_type)(model_opt, self.src_dict)
46 | decoder = getDecoder(model_opt.decoder_type)(model_opt, self.tgt_dict)
47 | model = onmt.Models.NMTModel(encoder, decoder)
48 |
49 | generator = nn.Sequential(
50 | nn.Linear(model_opt.rnn_size, self.tgt_dict.size()),
51 | nn.LogSoftmax())
52 |
53 | model.load_state_dict(checkpoint['model'])
54 | generator.load_state_dict(checkpoint['generator'])
55 |
56 | if opt.cuda:
57 | model.cuda()
58 | generator.cuda()
59 | else:
60 | model.cpu()
61 | generator.cpu()
62 |
63 | model.generator = generator
64 |
65 | self.model = model
66 | self.model.eval()
67 |
68 | def initBeamAccum(self):
69 | self.beam_accum = {
70 | "predicted_ids": [],
71 | "beam_parent_ids": [],
72 | "scores": [],
73 | "log_probs": []}
74 |
75 | def _getBatchSize(self, batch):
76 | if self._type == "text":
77 | return batch.size(1)
78 | else:
79 | return batch.size(0)
80 |
81 | def buildData(self, srcBatch, goldBatch):
82 | # This needs to be the same as preprocess.py.
83 | if self._type == "text":
84 | srcData = [self.src_dict.convertToIdx(b,
85 | onmt.Constants.UNK_WORD)
86 | for b in srcBatch]
87 | elif self._type == "img":
88 | srcData = [transforms.ToTensor()(
89 | Image.open(self.opt.src_img_dir + "/" + b[0]))
90 | for b in srcBatch]
91 |
92 | tgtData = None
93 | if goldBatch:
94 | tgtData = [self.tgt_dict.convertToIdx(b,
95 | onmt.Constants.UNK_WORD,
96 | onmt.Constants.BOS_WORD,
97 | onmt.Constants.EOS_WORD) for b in goldBatch]
98 |
99 | return onmt.Dataset(srcData, tgtData, self.opt.batch_size,
100 | self.opt.cuda, volatile=True,
101 | data_type=self._type)
102 |
103 | def buildTargetTokens(self, pred, src, attn):
104 | tokens = self.tgt_dict.convertToLabels(pred, onmt.Constants.EOS)
105 | tokens = tokens[:-1] # EOS
106 | if self.opt.replace_unk:
107 | for i in range(len(tokens)):
108 | if tokens[i] == onmt.Constants.UNK_WORD:
109 | _, maxIndex = attn[i].max(0)
110 | tokens[i] = src[maxIndex[0]]
111 | return tokens
112 |
113 | def translateBatch(self, srcBatch, tgtBatch):
114 | # Batch size is in different location depending on data.
115 |
116 | beamSize = self.opt.beam_size
117 |
118 | # (1) run the encoder on the src
119 | encStates, context, emb = self.model.encoder(srcBatch)
120 |
121 | # Drop the lengths needed for encoder.
122 | srcBatch = srcBatch[0]
123 | batchSize = self._getBatchSize(srcBatch)
124 |
125 | rnnSize = context.size(2)
126 | decoder = self.model.decoder
127 | attentionLayer = decoder.attn if hasattr(decoder, 'attn') else None
128 |
129 | if isinstance(self.model.encoder, Encoder):
130 | if isinstance(encStates, tuple):
131 | encStates = tuple(self.model.brnn_merge_concat(encStates[i])
132 | for i in range(len(encStates)))
133 | else:
134 | encStates = self.model.brnn_merge_concat(encStates)
135 | if encStates.size(0) < decoder.layers:
136 | encStates = encStates.repeat(decoder.layers, 1, 1)
137 | else:
138 | encStates = Variable(encStates.data.new(*encStates.size()).zero_(), requires_grad=False)
139 | # encStates = encStates.unsqueeze(0).repeat(decoder.layers, 1, 1)
140 |
141 |
142 | useMasking = not isinstance(decoder, SGUDecoder) #self._type.endswith("text")
143 |
144 | # This mask is applied to the attention model inside the decoder
145 | # so that the attention ignores source padding
146 | padMask = None
147 | if useMasking:
148 | padMask = srcBatch.data.eq(onmt.Constants.PAD).t()
149 |
150 | def mask(padMask):
151 | if useMasking:
152 | attentionLayer.applyMask(padMask)
153 |
154 | # (2) if a target is specified, compute the 'goldScore'
155 | # (i.e. log likelihood) of the target under the model
156 | goldScores = context.data.new(batchSize).zero_()
157 |
158 | if tgtBatch is not None:
159 |
160 | decStates = encStates
161 |
162 | mask(padMask)
163 | initOutput = self.model.make_init_decoder_output(context)
164 |
165 | decOut, decStates, attn = self.model.decoder(
166 | tgtBatch[:-1], decStates, context, initOutput)
167 | for dec_t, tgt_t in zip(decOut, tgtBatch[1:].data):
168 | gen_t = self.model.generator.forward(dec_t)
169 | tgt_t = tgt_t.unsqueeze(1)
170 | scores = gen_t.data.gather(1, tgt_t)
171 | scores.masked_fill_(tgt_t.eq(onmt.Constants.PAD), 0)
172 | goldScores += scores
173 |
174 | # (3) run the decoder to generate sentences, using beam search
175 |
176 | # Expand tensors for each beam.
177 | context = Variable(context.data.repeat(1, beamSize, 1))
178 | if isinstance(emb, PackedSequence):
179 | emb = Variable(unpack(emb)[0].data.repeat(1, beamSize, 1))
180 | else:
181 | emb = Variable(emb.data.repeat(1, beamSize, 1))
182 |
183 | if isinstance(encStates, tuple):
184 | decStates = tuple(Variable(encStates[i].data.repeat(1, beamSize, 1))
185 | for i in range(len(encStates)))
186 | else:
187 | decStates = Variable(encStates.data.repeat(1, beamSize, 1))
188 |
189 | beam = [onmt.Beam(beamSize, self.opt.cuda) for _ in range(batchSize)]
190 |
191 | decOut = self.model.make_init_decoder_output(context)
192 |
193 | if useMasking:
194 | padMask = srcBatch.data.eq(
195 | onmt.Constants.PAD).t() \
196 | .unsqueeze(0) \
197 | .repeat(beamSize, 1, 1)
198 |
199 | batchIdx = list(range(batchSize))
200 | remainingSents = batchSize
201 |
202 | activs = []
203 | for i in range(self.opt.max_sent_length):
204 | mask(padMask)
205 | # Prepare decoder input.
206 | input = torch.stack([b.getCurrentState() for b in beam
207 | if not b.done]).t().contiguous().view(1, -1)
208 |
209 | #if self.model.decoder.log:
210 | # decOut, decStates, attn, activ = self.model.decoder(
211 | # Variable(input, volatile=True), decStates, context, decOut, emb)
212 | # activs.append(activ)
213 | #else:
214 | decOut, decStates, attn = self.model.decoder(
215 | Variable(input, volatile=True), decStates, context, decOut)
216 |
217 | # decOut: 1 x (beam*batch) x numWords
218 | decOut = decOut.squeeze(0)
219 | out = self.model.generator.forward(decOut)
220 |
221 | # batch x beam x numWords
222 | wordLk = out.view(beamSize, remainingSents, -1) \
223 | .transpose(0, 1).contiguous()
224 | attn = attn.view(beamSize, remainingSents, -1) \
225 | .transpose(0, 1).contiguous()
226 |
227 | active = []
228 | for b in range(batchSize):
229 | if beam[b].done:
230 | continue
231 |
232 | idx = batchIdx[b]
233 | if not beam[b].advance(wordLk.data[idx], attn.data[idx]):
234 | active += [b]
235 | #print(decStates)
236 | if not isinstance(decStates, tuple):
237 | decStates = tuple(decStates.unsqueeze(0))
238 | #print(decStates)
239 | for decState in decStates: # iterate over h, c
240 | # layers x beam*sent x dim
241 | sentStates = decState.view(-1, beamSize,
242 | remainingSents,
243 | decState.size(2))[:, :, idx]
244 | sentStates.data.copy_(
245 | sentStates.data.index_select(
246 | 1, beam[b].getCurrentOrigin()))
247 |
248 | if not active:
249 | break
250 |
251 | # in this section, the sentences that are still active are
252 | # compacted so that the decoder is not run on completed sentences
253 | activeIdx = self.tt.LongTensor([batchIdx[k] for k in active])
254 | batchIdx = {beam: idx for idx, beam in enumerate(active)}
255 |
256 | def updateActive(t, lastSize=rnnSize):
257 | # select only the remaining active sentences
258 | view = t.data.view(-1, remainingSents, lastSize)
259 | newSize = list(t.size())
260 | newSize[-2] = newSize[-2] * len(activeIdx) // remainingSents
261 | return Variable(view.index_select(1, activeIdx)
262 | .view(*newSize), volatile=True)
263 |
264 | decStates = tuple(updateActive(decStates[i])
265 | for i in range(len(decStates)))
266 |
267 | if len(decStates) == 1:
268 | # The GRU needs only one matrix as hidden state
269 | decStates = decStates[0]
270 |
271 | decOut = updateActive(decOut)
272 | context = updateActive(context)
273 | emb = updateActive(emb, emb.size(2))
274 |
275 | if useMasking:
276 | padMask = padMask.index_select(1, activeIdx)
277 |
278 | remainingSents = len(active)
279 |
280 | # (4) package everything up
281 | allHyp, allScores, allAttn = [], [], []
282 | n_best = self.opt.n_best
283 |
284 | if activs:
285 | new_activs = torch.zeros((2, activs[0].size(1), len(activs)))
286 | for i, activ in enumerate(activs):
287 | new_activs[:, :activ.size(1), i] = activ.data
288 | activs = new_activs
289 | sys.stderr.write("r=\n")
290 | for i in range(activs.size(1)):
291 | for j in range(activs.size(2)):
292 | sys.stderr.write(str(activs[0][i][j]) + " ")
293 | sys.stderr.write("\n")
294 | sys.stderr.write("z=\n")
295 | for i in range(activs.size(1)):
296 | for j in range(activs.size(2)):
297 | sys.stderr.write(str(activs[1][i][j]) + " ")
298 | sys.stderr.write("\n")
299 |
300 | for b in range(batchSize):
301 | scores, ks = beam[b].sortBest()
302 |
303 | allScores += [scores[:n_best]]
304 | hyps, attn = zip(*[beam[b].getHyp(k) for k in ks[:n_best]])
305 | allHyp += [hyps]
306 | if useMasking:
307 | valid_attn = srcBatch.data[:, b].ne(onmt.Constants.PAD) \
308 | .nonzero().squeeze(1)
309 | attn = [a.index_select(1, valid_attn) for a in attn]
310 | allAttn += [attn]
311 |
312 | if self.beam_accum:
313 | self.beam_accum["beam_parent_ids"].append(
314 | [t.tolist()
315 | for t in beam[b].prevKs])
316 | self.beam_accum["scores"].append([
317 | ["%4f" % s for s in t.tolist()]
318 | for t in beam[b].allScores][1:])
319 | self.beam_accum["predicted_ids"].append(
320 | [[self.tgt_dict.getLabel(id)
321 | for id in t.tolist()]
322 | for t in beam[b].nextYs][1:])
323 |
324 | return allHyp, allScores, allAttn, goldScores
325 |
326 | def translate(self, srcBatch, goldBatch):
327 | # (1) convert words to indexes
328 | dataset = self.buildData(srcBatch, goldBatch)
329 | src, tgt, indices = dataset[0]
330 | batchSize = self._getBatchSize(src[0])
331 |
332 | # (2) translate
333 | pred, predScore, attn, goldScore = self.translateBatch(src, tgt)
334 | pred, predScore, attn, goldScore = list(zip(
335 | *sorted(zip(pred, predScore, attn, goldScore, indices),
336 | key=lambda x: x[-1])))[:-1]
337 |
338 | # (3) convert indexes to words
339 | predBatch = []
340 | for b in range(batchSize):
341 | predBatch.append(
342 | [self.buildTargetTokens(pred[b][n], srcBatch[b], attn[b][n])
343 | for n in range(self.opt.n_best)]
344 | )
345 |
346 | return predBatch, predScore, goldScore
347 |
--------------------------------------------------------------------------------
/onmt/__init__.py:
--------------------------------------------------------------------------------
1 | import onmt.Constants
2 | import onmt.Models
3 | from onmt.Translator import Translator
4 | from onmt.Dataset import Dataset
5 | from onmt.Optim import Optim
6 | from onmt.Dict import Dict
7 | from onmt.Beam import Beam
8 |
9 | # For flake8 compatibility.
10 | __all__ = [onmt.Constants, onmt.Models, Translator, Dataset, Optim, Dict, Beam, onmt.Encoders]
11 |
--------------------------------------------------------------------------------
/onmt/modules/Attention.py:
--------------------------------------------------------------------------------
1 | """
2 | Global attention takes a matrix and a query vector. It
3 | then computes a parameterized convex combination of the matrix
4 | based on the input query.
5 |
6 |
7 | H_1 H_2 H_3 ... H_n
8 | q q q q
9 | | | | |
10 | \ | | /
11 | .....
12 | \ | /
13 | a
14 |
15 | Constructs a unit mapping.
16 | $$(H_1 + H_n, q) => (a)$$
17 | Where H is of `batch x n x dim` and q is of `batch x dim`.
18 |
19 | The full definition is $$\tanh(W_2 [(softmax((W_1 q + b_1) H) H), q] + b_2)$$.
20 |
21 | """
22 |
23 | import torch
24 | import torch.nn as nn
25 | import torch.nn.functional as F
26 | import numpy as np
27 | from .Normalization import LayerNorm
28 |
29 | def getAttention(attention_type):
30 | attns = {'dot': DotAttention,
31 | 'mlp': MLPAttentionGRU,
32 | }
33 |
34 | if attention_type not in attns:
35 | raise NotImplementedError(attention_type)
36 |
37 | return attns[attention_type]
38 |
39 |
40 | class DotAttention(nn.Module):
41 | def __init__(self, dim, enc_dim=None, layer_norm=False, activ='tanh'):
42 | super(DotAttention, self).__init__()
43 | self.mask = None
44 | if not enc_dim:
45 | enc_dim = dim
46 | out_dim = dim
47 | self.linear_in = nn.Linear(dim, out_dim, bias=False)
48 | self.layer_norm = layer_norm
49 | if self.layer_norm:
50 | self.ln_in = LayerNorm(dim)
51 |
52 | def applyMask(self, mask):
53 | self.mask = mask
54 |
55 | def initialize_parameters(self, param_init):
56 | pass
57 |
58 | def forward(self, input, context, values):
59 | """
60 | input: targetL x batch x dim
61 | context: batch x sourceL x dim
62 | """
63 | batch, sourceL, dim = context.size()
64 | targetT = self.ln_in(self.linear_in(input.transpose(0, 1))) # batch x targetL x dim
65 | context = context.transpose(1, 2) # batch x dim x sourceL
66 | # Get attention
67 | attn = torch.bmm(targetT, context) # batch x targetL x sourceL
68 | if self.mask is not None:
69 | attn.data.masked_fill_(self.mask, -float('inf'))
70 | attn = F.softmax(attn.view(-1, sourceL)) # (batch x targetL) x sourceL
71 | attn3 = attn.view(batch, -1, sourceL) # batch x targetL x sourceL
72 | weightedContext = torch.bmm(attn3, values).transpose(0, 1) # targetL x batch x dim
73 |
74 | return weightedContext, attn
75 |
76 |
77 | class MLPAttention(nn.Module):
78 | def __init__(self, dim, layer_norm=False, activ='tanh'):
79 | super(MLPAttention, self).__init__()
80 | self.dim = dim
81 | self.v = nn.Linear(self.dim, 1)
82 | self.combine_hid = nn.Linear(self.dim, self.dim)
83 | #self.combine_ctx = nn.Linear(self.dim, self.dim)
84 | self.mask = None
85 | self.activ = getattr(F, activ)
86 | self.layer_norm = layer_norm
87 | if layer_norm:
88 | #self.ctx_ln = LayerNorm(dim)
89 | self.hidden_ln = LayerNorm(dim)
90 |
91 | def applyMask(self, mask):
92 | self.mask = mask
93 |
94 | def initialize_parameters(self, param_init):
95 | pass
96 |
97 |
98 | def forward(self, input, context, values):
99 | """
100 | input: targetL x batch x dim
101 | context: batch x sourceL x dim
102 | values: batch x sourceL x dim
103 |
104 | Output:
105 |
106 | output: batch x hidden_size
107 | w: batch x sourceL
108 | """
109 | targetL = input.size(0)
110 | output_size = input.size(2)
111 | sourceL = context.size(1)
112 | batch_size = input.size(1)
113 |
114 | # targetL x batch x dim
115 | input = self.combine_hid(input)
116 | # (targetL x batch) x dim
117 | #context = self.combine_ctx(context)
118 | if self.layer_norm:
119 | input = self.hidden_ln(input)
120 | #context = self.ctx_ln(context)
121 |
122 | # batch x (sourceL x targetL) x dim
123 | context = context.repeat(1, targetL, 1)
124 |
125 | # batch x targetL x dim -> batch x (targetL x sourceL) x dim
126 | input = input.transpose(0, 1).repeat(1, 1, sourceL).contiguous().view(batch_size, -1, output_size)
127 | #context = context.view(batch_size, -1, output_size)
128 | # batch x (targetL x sourceL) x dim
129 | combined = self.activ(input + context)
130 |
131 | # batch x (targetL x sourceL) x 1
132 | attn = self.v(combined)
133 |
134 | # (batch_size x targetL) x sourceL
135 | attn = attn.contiguous().view(batch_size * targetL, sourceL)
136 |
137 | if self.mask is not None:
138 | attn.data.masked_fill_(self.mask, -float('inf'))
139 |
140 | # (batch_size x targetL) x sourceL
141 | attn = F.softmax(attn)
142 |
143 | # batch_size x targetL x sourceL
144 | attn3 = attn.contiguous().view(batch_size, targetL, sourceL)
145 |
146 | # batch x targetL x dim -> targetL x batch x dim
147 | weightedContext = torch.bmm(attn3, values).transpose(0, 1)
148 |
149 | return weightedContext, attn
150 |
151 | class MLPAttentionGRU(nn.Module):
152 | def __init__(self, dim, layer_norm=False, activ='tanh'):
153 | super(MLPAttentionGRU, self).__init__()
154 | self.dim = dim
155 | self.v = nn.Linear(self.dim, 1)
156 | self.combine_hid = nn.Linear(self.dim, self.dim)
157 | # self.combine_ctx = nn.Linear(self.dim, self.dim)
158 | self.mask = None
159 | self.activ = getattr(F, activ)
160 | self.layer_norm = layer_norm
161 | if layer_norm:
162 | # self.ctx_ln = LayerNorm(dim)
163 | self.hidden_ln = LayerNorm(dim)
164 |
165 | def applyMask(self, mask):
166 | self.mask = mask
167 |
168 | def initialize_parameters(self, param_init):
169 | pass
170 |
171 | def forward(self, input, context, values):
172 | """
173 | input: batch x dim
174 | context: batch x sourceL x dim
175 | values: batch x sourceL x dim
176 |
177 | Output:
178 |
179 | output: batch x hidden_size
180 | w: batch x sourceL
181 | """
182 | sourceL = context.size(1)
183 | batch_size = input.size(0)
184 |
185 | # batch x dim
186 | input = self.combine_hid(input)
187 |
188 | if self.layer_norm:
189 | input = self.hidden_ln(input)
190 |
191 | # batch x sourceL x dim
192 | input = input.unsqueeze(1).expand_as(context)
193 | # batch x sourceL x dim
194 | combined = self.activ(input + context)
195 |
196 | # batch x sourceL x 1
197 | attn = self.v(combined)
198 |
199 | # batch_size x sourceL
200 | attn = attn.view(batch_size, sourceL)
201 |
202 | if self.mask is not None:
203 | attn.data.masked_fill_(self.mask, -float('inf'))
204 |
205 | # batch_size x sourceL
206 | attn = F.softmax(attn)
207 |
208 | # batch_size x 1 x sourceL
209 | attn3 = attn.unsqueeze(1)
210 |
211 | # batch x dim
212 | weightedContext = torch.bmm(attn3, values).squeeze(1)
213 |
214 | return weightedContext, attn
215 |
216 |
217 |
218 | class SelfAttention(nn.Module):
219 | def __init__(self, k_size, q_size, v_size, out_size):
220 | super(SelfAttention, self).__init__()
221 | self.linearK = nn.Linear(v_size, out_size)
222 | self.linearQ = nn.Linear(q_size, out_size)
223 | self.linearV = nn.Linear(v_size, out_size)
224 | self.dim = out_size
225 | self.mask = None
226 |
227 | def applyMask(self, mask):
228 | self.mask = mask
229 |
230 | def forward(self, input, context, values):
231 | """
232 | input: batch x targetL x dim
233 | context: batch x sourceL x dim
234 | values: batch x sourceL x dim
235 | """
236 | K = self.linearK(input) # batch x targetL x out_size
237 | Q = self.linearQ(context) # batch x sourceL x out_size
238 | V = self.linearV(values) # batch x sourceL x out_size
239 |
240 | dot_prod = K.bmm(Q.transpose(1, 2)) * (1 / np.sqrt(self.dim)) # batch x targetL x sourceL
241 |
242 | attn = dot_prod.sum(dim=1, keepdim=False) # batch x sourceL
243 | if self.mask is not None:
244 | attn.data.masked_fill_(self.mask, -float('inf'))
245 | attn = F.softmax(attn) # batch x sourceL
246 | attn3 = attn.unsqueeze(2) # batch x sourceL x 1
247 | weightedContext = V * attn3 # batch x sourceL x out_size
248 |
249 | return weightedContext, attn
250 |
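A shape-level sketch (not from the repository) for `MLPAttentionGRU`, where a single decoder state attends over the source context:

```python
import torch
from torch.autograd import Variable
from onmt.modules.Attention import MLPAttentionGRU

batch, sourceL, dim = 4, 7, 16
attn = MLPAttentionGRU(dim)

query   = Variable(torch.randn(batch, dim))            # current decoder state
context = Variable(torch.randn(batch, sourceL, dim))   # projected encoder states
weighted, weights = attn(query, context, context)

print(weighted.size())   # batch x dim: convex combination of the values
print(weights.size())    # batch x sourceL: attention distribution
```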
--------------------------------------------------------------------------------
/onmt/modules/Gate.py:
--------------------------------------------------------------------------------
1 | """
2 | Context gate is a decoder module that takes as input the previous word
3 | embedding, the current decoder state and the attention state, and produces a
4 | gate.
5 | The gate can be used to select the input from the target side context
6 | (decoder state), from the source context (attention state) or both.
7 | """
8 | import torch
9 | import torch.nn as nn
10 |
11 |
12 | def ContextGateFactory(type, embeddings_size, decoder_size,
13 | attention_size, output_size):
14 | """Returns the correct ContextGate class"""
15 |
16 | gate_types = {'source': SourceContextGate,
17 | 'target': TargetContextGate,
18 | 'both': BothContextGate}
19 |
20 | assert type in gate_types, "Not valid ContextGate type: {0}".format(type)
21 | return gate_types[type](embeddings_size, decoder_size, attention_size,
22 | output_size)
23 |
24 |
25 | class ContextGate(nn.Module):
26 | """Implement up to the computation of the gate"""
27 |
28 | def __init__(self, embeddings_size, decoder_size,
29 | attention_size, output_size):
30 | super(ContextGate, self).__init__()
31 | input_size = embeddings_size + decoder_size + attention_size
32 | self.gate = nn.Linear(input_size, output_size, bias=True)
33 | self.sig = nn.Sigmoid()
34 | self.source_proj = nn.Linear(attention_size, output_size)
35 | self.target_proj = nn.Linear(embeddings_size + decoder_size,
36 | output_size)
37 |
38 | def forward(self, prev_emb, dec_state, attn_state):
39 | input_tensor = torch.cat((prev_emb, dec_state, attn_state), dim=1)
40 | z = self.sig(self.gate(input_tensor))
41 | proj_source = self.source_proj(attn_state)
42 |
43 | proj_target = self.target_proj(
44 | torch.cat((prev_emb, dec_state), dim=1))
45 |
46 | return z, proj_source, proj_target
47 |
48 |
49 | class SourceContextGate(nn.Module):
50 | """Apply the context gate only to the source context"""
51 |
52 | def __init__(self, embeddings_size, decoder_size,
53 | attention_size, output_size):
54 | super(SourceContextGate, self).__init__()
55 | self.context_gate = ContextGate(embeddings_size, decoder_size,
56 | attention_size, output_size)
57 | self.tanh = nn.Tanh()
58 |
59 | def forward(self, prev_emb, dec_state, attn_state):
60 | z, source, target = self.context_gate(
61 | prev_emb, dec_state, attn_state)
62 | return target + z * source
63 |
64 |
65 | class TargetContextGate(nn.Module):
66 | """Apply the context gate only to the target context"""
67 |
68 | def __init__(self, embeddings_size, decoder_size,
69 | attention_size, output_size):
70 | super(TargetContextGate, self).__init__()
71 | self.context_gate = ContextGate(embeddings_size, decoder_size,
72 | attention_size, output_size)
73 | self.tanh = nn.Tanh()
74 |
75 | def forward(self, prev_emb, dec_state, attn_state):
76 | z, source, target = self.context_gate(prev_emb, dec_state, attn_state)
77 | return z * target + source
78 |
79 |
80 | class BothContextGate(nn.Module):
81 | """Apply the context gate to both contexts"""
82 |
83 | def __init__(self, embeddings_size, decoder_size,
84 | attention_size, output_size):
85 | super(BothContextGate, self).__init__()
86 | self.context_gate = ContextGate(embeddings_size, decoder_size,
87 | attention_size, output_size)
88 | self.tanh = nn.Tanh()
89 |
90 | def forward(self, prev_emb, dec_state, attn_state):
91 | z, source, target = self.context_gate(prev_emb, dec_state, attn_state)
92 | return (1. - z) * target + z * source
93 |
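An illustrative sketch of the factory; the sizes below are arbitrary assumptions:

```python
import torch
from torch.autograd import Variable
from onmt.modules.Gate import ContextGateFactory

emb_size, dec_size, attn_size, out_size, batch = 16, 32, 32, 32, 4
gate = ContextGateFactory('both', emb_size, dec_size, attn_size, out_size)

prev_emb   = Variable(torch.randn(batch, emb_size))
dec_state  = Variable(torch.randn(batch, dec_size))
attn_state = Variable(torch.randn(batch, attn_size))

out = gate(prev_emb, dec_state, attn_state)   # (1 - z) * target + z * source
print(out.size())                             # batch x out_size
```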
--------------------------------------------------------------------------------
/onmt/modules/ImageEncoder.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | import torch
4 | import torch.cuda
5 | from torch.autograd import Variable
6 |
7 |
8 | class ImageEncoder(nn.Module):
9 | def __init__(self, opt):
10 | super(ImageEncoder, self).__init__()
11 | self.layers = opt.layers
12 | self.num_directions = 2 if opt.brnn else 1
13 | self.hidden_size = opt.rnn_size
14 |
15 | self.layer1 = nn.Conv2d(3, 64, kernel_size=(3, 3),
16 | padding=(1, 1), stride=(1, 1))
17 | self.layer2 = nn.Conv2d(64, 128, kernel_size=(3, 3),
18 | padding=(1, 1), stride=(1, 1))
19 | self.layer3 = nn.Conv2d(128, 256, kernel_size=(3, 3),
20 | padding=(1, 1), stride=(1, 1))
21 | self.layer4 = nn.Conv2d(256, 256, kernel_size=(3, 3),
22 | padding=(1, 1), stride=(1, 1))
23 | self.layer5 = nn.Conv2d(256, 512, kernel_size=(3, 3),
24 | padding=(1, 1), stride=(1, 1))
25 | self.layer6 = nn.Conv2d(512, 512, kernel_size=(3, 3),
26 | padding=(1, 1), stride=(1, 1))
27 |
28 | self.batch_norm1 = nn.BatchNorm2d(256)
29 | self.batch_norm2 = nn.BatchNorm2d(512)
30 | self.batch_norm3 = nn.BatchNorm2d(512)
31 |
32 | input_size = 512
33 | self.rnn = nn.LSTM(input_size, opt.rnn_size,
34 | num_layers=opt.layers,
35 | dropout=opt.dropout,
36 | bidirectional=opt.brnn)
37 | self.pos_lut = nn.Embedding(1000, input_size)
38 |
39 | def load_pretrained_vectors(self, opt):
40 | pass
41 |
42 | def forward(self, input):
43 | input = input[0]
44 | batchSize = input.size(0)
45 | # (batch_size, 64, imgH, imgW)
46 | # layer 1
47 | input = F.relu(self.layer1(input[:, :, :, :]-0.5), True)
48 |
49 | # (batch_size, 64, imgH/2, imgW/2)
50 | input = F.max_pool2d(input, kernel_size=(2, 2), stride=(2, 2))
51 |
52 | # (batch_size, 128, imgH/2, imgW/2)
53 | # layer 2
54 | input = F.relu(self.layer2(input), True)
55 |
56 | # (batch_size, 128, imgH/2/2, imgW/2/2)
57 | input = F.max_pool2d(input, kernel_size=(2, 2), stride=(2, 2))
58 |
59 | # (batch_size, 256, imgH/2/2, imgW/2/2)
60 | # layer 3
61 | # batch norm 1
62 | input = F.relu(self.batch_norm1(self.layer3(input)), True)
63 |
64 | # (batch_size, 256, imgH/2/2, imgW/2/2)
65 | # layer4
66 | input = F.relu(self.layer4(input), True)
67 |
68 | # (batch_size, 256, imgH/2/2/2, imgW/2/2)
69 | input = F.max_pool2d(input, kernel_size=(1, 2), stride=(1, 2))
70 |
71 | # (batch_size, 512, imgH/2/2/2, imgW/2/2)
72 | # layer 5
73 | # batch norm 2
74 | input = F.relu(self.batch_norm2(self.layer5(input)), True)
75 |
76 | # (batch_size, 512, imgH/2/2/2, imgW/2/2/2)
77 | input = F.max_pool2d(input, kernel_size=(2, 1), stride=(2, 1))
78 |
79 | # (batch_size, 512, imgH/2/2/2, imgW/2/2/2)
80 | input = F.relu(self.batch_norm3(self.layer6(input)), True)
81 |
82 | # # (batch_size, 512, H, W)
83 | # # (batch_size, H, W, 512)
84 | all_outputs = []
85 | for row in range(input.size(2)):
86 | inp = input[:, :, row, :].transpose(0, 2)\
87 | .transpose(1, 2)
88 | pos_emb = self.pos_lut(
89 | Variable(torch.cuda.LongTensor(batchSize).fill_(row)))
90 | with_pos = torch.cat(
91 | (pos_emb.view(1, pos_emb.size(0), pos_emb.size(1)), inp), 0)
92 | outputs, hidden_t = self.rnn(with_pos)
93 | all_outputs.append(outputs)
94 | out = torch.cat(all_outputs, 0)
95 |
96 | return hidden_t, out
97 |
--------------------------------------------------------------------------------
/onmt/modules/Normalization.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class LayerNorm(nn.Module):
5 |
6 | def __init__(self, features, eps=1e-6):
7 | super().__init__()
8 | self.gamma = nn.Parameter(torch.ones(features))
9 | self.beta = nn.Parameter(torch.zeros(features))
10 | self.eps = eps
11 |
12 | def forward(self, x):
13 | mean = x.mean(-1, keepdim=True)
14 | std = x.std(-1, keepdim=True)
15 | return (self.gamma / (std + self.eps)) * (x - mean) + self.beta
16 |
17 | def initialize_parameters(self, param_init):
18 | self.gamma.data.fill_(1.)
19 | self.beta.data.fill_(0.)
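A quick numerical sanity check (illustrative only) of the `LayerNorm` above:

```python
import torch
from torch.autograd import Variable
from onmt.modules.Normalization import LayerNorm

ln = LayerNorm(8)
x = Variable(torch.randn(5, 8))
y = ln(x)

print(y.mean(-1).data.abs().max())   # ~0: each row is re-centred
print(y.std(-1).data)                # ~1: and re-scaled (gamma=1, beta=0 initially)
```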
--------------------------------------------------------------------------------
/onmt/modules/SRU_units.py:
--------------------------------------------------------------------------------
1 | """
2 | Weakly-recurrent (SRU-style) units used by SR-NMT. `SRU` is a simple
3 | recurrent unit with a highway gate over its input, `BiSRU` is its
4 | bidirectional variant used by the encoder stack, and `AttSRU` is a
5 | decoder-side variant that adds dot-product attention over the encoder
6 | context.
7 | """
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | from .Normalization import LayerNorm
12 | from torch.autograd import Variable
13 | import numpy as np
14 |
15 |
16 | class AttSRU(nn.Module):
17 |
18 | def __init__(self, input_size, attention_size, output_size, layer_norm, dropout):
19 | from .Attention import DotAttention
20 | super(AttSRU, self).__init__()
21 | self.linear_in = nn.Linear(input_size, 3*output_size, bias=(not layer_norm))
22 | self.linear_hidden = nn.Linear(output_size, output_size, bias=(not layer_norm))
23 | self.linear_ctx = nn.Linear(output_size, output_size, bias=(not layer_norm))
24 | self.linear_enc = nn.Linear(output_size, output_size, bias=(not layer_norm))
25 | self.output_size = output_size
26 | self.attn = DotAttention(attention_size, layer_norm=True)
27 | self.layer_norm = layer_norm
28 | self.dropout = nn.Dropout(dropout)
29 | if self.layer_norm:
30 | self.preact_ln = LayerNorm(3 * output_size)
31 | self.enc_ln = LayerNorm(output_size)
32 |
33 | self.trans_h_ln = LayerNorm(output_size)
34 | self.trans_c_ln = LayerNorm(output_size)
35 |
36 | def initialize_parameters(self, param_init):
37 | self.preact_ln.initialize_parameters(param_init)
38 | self.trans_h_ln.initialize_parameters(param_init)
39 | self.trans_c_ln.initialize_parameters(param_init)
40 | self.enc_ln.initialize_parameters(param_init)
41 |
42 | def forward(self, prev_layer, hidden, enc_output):
43 | """
44 | :param prev_layer: targetL x batch x output_size
45 | :param hidden: batch x output_size
46 | :param enc_output: (targetL x batch) x sourceL x output_size
47 | :return:
48 | """
49 |
50 | # targetL x batch x output_size
51 | preact = self.linear_in(self.dropout(prev_layer))
52 | pctx = self.linear_enc(self.dropout(enc_output))
53 | if self.layer_norm:
54 | preact = self.preact_ln(preact)
55 | pctx = self.enc_ln(pctx)
56 | #z = self.z_ln(z)
57 | #prev_layer_t = self.prev_layer_ln(prev_layer_t)
58 | #h_gate = self.h_gate_ln(h_gate)
59 | z, h_gate, prev_layer_t = preact.split(self.output_size, dim=-1)
60 | z, h_gate = F.sigmoid(z), F.sigmoid(h_gate)
61 |
62 | ss = []
63 | for i in range(prev_layer.size(0)):
64 | s = (1. - z[i]) * hidden + z[i] * prev_layer_t[i]
65 | # targetL x batch x output_size
66 | ss += [s]
67 | # batch x output_size
68 | hidden = s
69 |
70 | # (targetL x batch) x output_size
71 | ss = torch.stack(ss)
72 | attn_out, attn = self.attn(self.dropout(ss), pctx, pctx)
73 | attn_out = attn_out / np.sqrt(self.output_size)
74 |
75 | trans_h = self.linear_hidden(self.dropout(ss))
76 | trans_c = self.linear_ctx(self.dropout(attn_out))
77 | if self.layer_norm:
78 | #out = self.post_ln(out)
79 | trans_h = self.trans_h_ln(trans_h)
80 | trans_c = self.trans_c_ln(trans_c)
81 | #trans_h, trans_c = F.tanh(trans_h), F.tanh(trans_c)
82 | out = trans_h + trans_c
83 | out = F.tanh(out)
84 | out = out.view(prev_layer.size())
85 | out = (1. - h_gate) * out + h_gate * prev_layer
86 |
87 | return out, hidden, attn
88 |
89 | class BiSRU(nn.Module):
90 |
91 | def __init__(self, input_size, output_size, layer_norm, dropout):
92 | super(BiSRU, self).__init__()
93 | self.input_linear = nn.Linear(input_size, 3*output_size, bias=(not layer_norm))
94 | self.layer_norm = layer_norm
95 | self.output_size = output_size
96 | self.dropout = nn.Dropout(dropout)
97 | if self.layer_norm:
98 | self.preact_ln = LayerNorm(3 * output_size)
99 | #self.x_f_ln = LayerNorm(output_size // 2)
100 | #self.x_b_ln = LayerNorm(output_size // 2)
101 | #self.f_g_ln = LayerNorm(output_size // 2)
102 | #self.b_g_ln = LayerNorm(output_size // 2)
103 | #self.highway_ln = LayerNorm(output_size)
104 |
105 | def initialize_parameters(self, param_init):
106 | self.preact_ln.initialize_parameters(param_init)
107 |
108 | def forward(self, input):
109 | pre_act = self.input_linear(self.dropout(input))
110 | #h_gate = pre_act[:, :, 2*self.output_size:]
111 | #gf, gb, x_f, x_b = pre_act[:, :, :2*self.output_size].split(self.output_size // 2, dim=-1)
112 | if self.layer_norm:
113 | pre_act = self.preact_ln(pre_act)
114 | #x_f = self.x_f_ln(x_f)
115 | #x_b = self.x_b_ln(x_b)
116 | #gf = self.f_g_ln(gf)
117 | #gb = self.b_g_ln(gb)
118 | #h_gate = self.highway_ln(h_gate)
119 | h_gate = pre_act[:, :, 2*self.output_size:]
120 | g, x = pre_act[:, :, :2*self.output_size].split(self.output_size, dim=-1)
121 | gf, gb = F.sigmoid(g).split(self.output_size // 2, dim=-1)
122 | x_f, x_b = x.split(self.output_size // 2, dim=-1)
123 | h_gate = F.sigmoid(h_gate)
124 | h_f_pre = gf * x_f
125 | h_b_pre = gb * x_b
126 |
127 | h_i_f = Variable(h_f_pre.data.new(gf[0].size()).zero_(), requires_grad=False)
128 | h_i_b = Variable(h_f_pre.data.new(gf[0].size()).zero_(), requires_grad=False)
129 |
130 | h_f, h_b = [], []
131 | for i in range(input.size(0)):
132 | h_i_f = (1. - gf[i]) * h_i_f + h_f_pre[i]
133 | h_i_b = (1. - gb[-(i+1)]) * h_i_b + h_b_pre[-(i+1)]
134 | h_f += [h_i_f]
135 | h_b += [h_i_b]
136 |
137 | h = torch.cat([torch.stack(h_f), torch.stack(h_b[::-1])], dim=-1)
138 |
139 | output = (1. - h_gate) * h + input * h_gate
140 |
141 | return output
142 |
143 |
144 | class SRU(nn.Module):
145 | def __init__(self, input_size, output_size, dropout):
146 | super(SRU, self).__init__()
147 | self.linear_in = nn.Linear(input_size, 3 * output_size)
148 | if input_size != output_size:
149 | self.reduce = nn.Linear(input_size, output_size)
150 | self.input_size = input_size
151 | self.output_size = output_size
152 | self.dropout = nn.Dropout(dropout)
153 |
154 | def initialize_parameters(self, param_init):
155 | pass
156 |
157 | def forward(self, prev_layer, hidden):
158 | """
159 | :param prev_layer: targetL x batch x output_size
160 | :param hidden: batch x output_size
161 | :return:
162 | """
163 |
164 | # targetL x batch x output_size
165 | preact = self.linear_in(self.dropout(prev_layer))
166 |
167 | prev_layer_t = preact[:, :, :self.output_size]
168 | z, h_gate = F.sigmoid(preact[:, :, self.output_size:]).split(self.output_size, dim=-1)
169 |
170 | ss = []
171 | for i in range(prev_layer.size(0)):
172 | s = (1 - z[i]) * hidden + z[i] * prev_layer_t[i]
173 | # targetL x batch x output_size
174 | ss += [s]
175 | # batch x output_size
176 | hidden = s
177 |
178 | # (targetL x batch) x output_size
179 | out = torch.stack(ss)
180 | if self.input_size != self.output_size:
181 | prev_layer = self.reduce(self.dropout(prev_layer))
182 |
183 | out = (1. - h_gate) * out + h_gate * prev_layer
184 |
185 | return out, hidden
186 |
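A shape-level sketch (illustrative) of the plain `SRU` unit defined above:

```python
import torch
from torch.autograd import Variable
from onmt.modules.SRU_units import SRU

srcL, batch, dim = 6, 3, 32
unit = SRU(dim, dim, dropout=0.0)

x  = Variable(torch.randn(srcL, batch, dim))
h0 = Variable(torch.zeros(batch, dim))
out, h_last = unit(x, h0)

print(out.size())      # srcL x batch x dim: highway-gated outputs
print(h_last.size())   # batch x dim: last recurrent state
```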
--------------------------------------------------------------------------------
/onmt/modules/Units.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 |
5 | from torch.nn.utils.rnn import PackedSequence
6 | from torch.nn.utils.rnn import pad_packed_sequence as unpack
7 | from torch.nn.utils.rnn import pack_padded_sequence as pack
8 | from onmt.modules import SRU
9 |
10 | import math
11 |
12 | class ParallelMyRNN(nn.Module):
13 | def __init__(self, input_size, hidden_size,
14 | num_layers=1, dropout=0, bidirectional=False):
15 | super(ParallelMyRNN, self).__init__()
16 | self.unit = SRU
17 | self.input_size = input_size
18 | self.rnn_size = hidden_size
19 | self.num_layers = num_layers
20 | self.dropout = dropout
21 | self.Dropout = nn.Dropout(dropout)
22 | self.bidirectional = bidirectional
23 | self.num_directions = 2 if bidirectional else 1
24 | self.hidden_size = self.rnn_size * self.num_directions
25 | self.rnns = nn.ModuleList([nn.ModuleList() for _ in range(self.num_directions)])
26 |
27 | # for layer in range(num_layers):
28 | for layer in range(num_layers):
29 | layer_input_size = input_size if layer == 0 else self.hidden_size
30 | for direction in range(self.num_directions):
31 | self.rnns[direction].append(self.unit(layer_input_size, self.rnn_size, self.dropout))
32 |
33 | def reset_parameters(self):
34 | stdv = 1.0 / math.sqrt(self.rnn_size)
35 | for weight in self.parameters():
36 | weight.data.uniform_(-stdv, stdv)
37 |
38 | def initialize_parameters(self, param_init):
39 | for direction in range(self.num_directions):
40 | for layer in self.rnns[direction]:
41 | layer.initialize_parameters(param_init)
42 |
43 | def reverse_tensor(self, x, dim):
44 | idx = [i for i in range(x.size(dim) - 1, -1, -1)]
45 | idx = Variable(torch.LongTensor(idx))
46 | if x.is_cuda:
47 | idx = idx.cuda()
48 | return x.index_select(dim, idx)
49 |
50 | def forward(self, input, hidden=None):
51 |
52 | is_packed = isinstance(input, PackedSequence)
53 | if is_packed:
54 | input, batch_sizes = unpack(input)
55 | max_batch_size = batch_sizes[0]
56 |
57 | if hidden is None:
58 | # (num_layers x num_directions) x batch_size x rnn_size
59 | hidden = Variable(input.data.new(self.num_layers *
60 | self.num_directions,
61 | input.size(1),
62 | self.rnn_size).zero_(), requires_grad=False)
63 | if input.is_cuda:
64 | hidden = hidden.cuda()
65 |
66 | gru_out = []
67 | _input = input
68 | for i in range(self.num_layers):
69 | if not self.bidirectional:
70 | prev_layer = self.Dropout(_input)
71 | h = hidden[i] # batch_size x rnn_size
72 | unit = self.rnns[0][i] # Computation unit
73 |
74 | layer_out, hid_uni = unit(prev_layer, h) # src_len x batch x hidden_size
75 |
76 | else:
77 | input_forward = self.Dropout(_input)
78 | input_backward = self.Dropout(_input)
79 | h_forward = hidden[i * self.num_directions] # batch_size x rnn_size
80 | h_backward = hidden[i * self.num_directions + 1] # batch_size x rnn_size
81 | unit_forward = self.rnns[0][i] # Computation unit
82 | unit_backward = self.rnns[1][i] # Computation unit
83 |
84 | output_forward, h_forward = unit_forward(input_forward, h_forward)
85 | output_backward, h_backward = self.compute_backwards(unit_backward, input_backward, h_backward)
86 |
87 | layer_out = torch.cat([output_forward, output_backward], dim=2) # src_len x batch x hidden_size
88 |
89 | _input = layer_out
90 |
91 | if self.bidirectional:
92 | gru_out.append(output_forward[-1].unsqueeze(0))
93 | gru_out.append(output_backward[-1].unsqueeze(0)) # num_directions x [batch x rnn_size]
94 | else:
95 | gru_out.append(layer_out)
96 |
97 | hidden = torch.cat(gru_out, dim=0) # (num_layers x num_directions) x batch x rnn_size
98 |
99 | output = _input
100 |
101 | return output, hidden
102 |
103 | def __repr__(self):
104 | s = '{name}({input_size}, {rnn_size}'
105 | if self.num_layers != 1:
106 | s += ', num_layers={num_layers}'
107 | if self.dropout != 0:
108 | s += ', dropout={dropout}'
109 | if self.bidirectional is not False:
110 | s += ', bidirectional={bidirectional}'
111 | s += ')'
112 | return s.format(name=self.__class__.__name__, **self.__dict__)
113 |
114 | def compute_backwards(self, unit, input, hidden):
115 | h = hidden
116 | steps = torch.cat(input.split(1, dim=0)[::-1], dim=0)
117 | out, hidden = unit(steps, h)
118 | out = torch.cat(out.split(1, dim=0)[::-1], dim=0)
119 | return out, hidden
120 |
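A hedged usage sketch of `ParallelMyRNN` as a bidirectional SRU stack (sizes are arbitrary):

```python
import torch
from torch.autograd import Variable
from onmt.modules.Units import ParallelMyRNN

rnn = ParallelMyRNN(input_size=16, hidden_size=8,
                    num_layers=2, dropout=0.0, bidirectional=True)

x = Variable(torch.randn(5, 3, 16))   # srcL x batch x input_size
output, hidden = rnn(x)

print(output.size())   # 5 x 3 x 16: srcL x batch x (rnn_size * num_directions)
print(hidden.size())   # 4 x 3 x 8: (num_layers * num_directions) x batch x rnn_size
```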
--------------------------------------------------------------------------------
/onmt/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .SRU_units import SRU
2 | from .Units import ParallelMyRNN
--------------------------------------------------------------------------------
/preprocess.py:
--------------------------------------------------------------------------------
1 | import onmt
2 | import onmt.Markdown
3 | import argparse
4 | import torch
5 |
6 |
7 | def loadImageLibs():
8 | "Conditional import of torch image libs."
9 | global Image, transforms
10 | from PIL import Image
11 | from torchvision import transforms
12 |
13 |
14 | parser = argparse.ArgumentParser(description='preprocess.py')
15 | onmt.Markdown.add_md_help_argument(parser)
16 |
17 | # **Preprocess Options**
18 |
19 | parser.add_argument('-config', help="Read options from this file")
20 |
21 | parser.add_argument('-src_type', default="bitext",
22 | choices=["bitext", "monotext", "img"],
23 | help="""Type of the source input.
24 | This affects all the subsequent operations
25 | Options are [bitext|monotext|img].""")
26 | parser.add_argument('-src_img_dir', default=".",
27 | help="Location of source images")
28 |
29 |
30 | parser.add_argument('-train',
31 | help="""Path to the monolingual training data""")
32 | parser.add_argument('-train_src', required=False,
33 | help="Path to the training source data")
34 | parser.add_argument('-train_tgt', required=False,
35 | help="Path to the training target data")
36 | parser.add_argument('-valid',
37 | help="""Path to the monolingual validation data""")
38 | parser.add_argument('-valid_src', required=False,
39 | help="Path to the validation source data")
40 | parser.add_argument('-valid_tgt', required=False,
41 | help="Path to the validation target data")
42 |
43 | parser.add_argument('-save_data', required=True,
44 | help="Output file for the prepared data")
45 |
46 | parser.add_argument('-src_vocab_size', type=int, default=50000,
47 | help="Size of the source vocabulary")
48 | parser.add_argument('-tgt_vocab_size', type=int, default=50000,
49 | help="Size of the target vocabulary")
50 | parser.add_argument('-src_vocab',
51 | help="Path to an existing source vocabulary")
52 | parser.add_argument('-tgt_vocab',
53 | help="Path to an existing target vocabulary")
54 |
55 | parser.add_argument('-src_seq_length', type=int, default=50,
56 | help="Maximum source sequence length")
57 | parser.add_argument('-src_seq_length_trunc', type=int, default=0,
58 | help="Truncate source sequence length.")
59 | parser.add_argument('-tgt_seq_length', type=int, default=50,
60 | help="Maximum target sequence length to keep.")
61 | parser.add_argument('-tgt_seq_length_trunc', type=int, default=0,
62 | help="Truncate target sequence length.")
63 |
64 | parser.add_argument('-shuffle', type=int, default=1,
65 | help="Shuffle data")
66 | parser.add_argument('-seed', type=int, default=3435,
67 | help="Random seed")
68 |
69 | parser.add_argument('-lower', action='store_true', help='lowercase data')
70 |
71 | parser.add_argument('-report_every', type=int, default=100000,
72 | help="Report status every this many sentences")
73 |
74 | opt = parser.parse_args()
75 |
76 | torch.manual_seed(opt.seed)
77 |
78 |
79 | def makeVocabulary(filename, size):
80 | vocab = onmt.Dict([onmt.Constants.PAD_WORD, onmt.Constants.UNK_WORD,
81 | onmt.Constants.BOS_WORD, onmt.Constants.EOS_WORD],
82 | lower=opt.lower)
83 |
84 | with open(filename) as f:
85 | for sent in f.readlines():
86 | for word in sent.split():
87 | vocab.add(word)
88 |
89 | originalSize = vocab.size()
90 | vocab = vocab.prune(size)
91 | print('Created dictionary of size %d (pruned from %d)' %
92 | (vocab.size(), originalSize))
93 |
94 | return vocab
95 |
96 |
97 | def initVocabulary(name, dataFile, vocabFile, vocabSize):
98 |
99 | vocab = None
100 | if vocabFile is not None:
101 | # If given, load existing word dictionary.
102 | print('Reading ' + name + ' vocabulary from \'' + vocabFile + '\'...')
103 | vocab = onmt.Dict()
104 | vocab.loadFile(vocabFile)
105 | print('Loaded ' + str(vocab.size()) + ' ' + name + ' words')
106 |
107 | if vocab is None:
108 | # If a dictionary is still missing, generate it.
109 | print('Building ' + name + ' vocabulary...')
110 | genWordVocab = makeVocabulary(dataFile, vocabSize)
111 |
112 | vocab = genWordVocab
113 |
114 | print()
115 | return vocab
116 |
117 |
118 | def saveVocabulary(name, vocab, file):
119 | print('Saving ' + name + ' vocabulary to \'' + file + '\'...')
120 | vocab.writeFile(file)
121 |
122 |
123 | def makeBilingualData(srcFile, tgtFile, srcDicts, tgtDicts):
124 | src, tgt = [], []
125 | sizes = []
126 | count, ignored = 0, 0
127 |
128 | print('Processing %s & %s ...' % (srcFile, tgtFile))
129 | srcF = open(srcFile)
130 | tgtF = open(tgtFile)
131 |
132 | while True:
133 | sline = srcF.readline()
134 | tline = tgtF.readline()
135 |
136 | # normal end of file
137 | if sline == "" and tline == "":
138 | break
139 |
140 | # source or target does not have same number of lines
141 | if sline == "" or tline == "":
142 | print('WARNING: src and tgt do not have the same # of sentences')
143 | break
144 |
145 | sline = sline.strip()
146 | tline = tline.strip()
147 |
148 | # source and/or target are empty
149 | if sline == "" or tline == "":
150 | print('WARNING: ignoring an empty line ('+str(count+1)+')')
151 | continue
152 |
153 | srcWords = sline.split()
154 | tgtWords = tline.split()
155 |
156 | if len(srcWords) <= opt.src_seq_length \
157 | and len(tgtWords) <= opt.tgt_seq_length:
158 |
159 | # Check truncation condition.
160 | if opt.src_seq_length_trunc != 0:
161 | srcWords = srcWords[:opt.src_seq_length_trunc]
162 | if opt.tgt_seq_length_trunc != 0:
163 | tgtWords = tgtWords[:opt.tgt_seq_length_trunc]
164 |
165 | if opt.src_type == "bitext":
166 | src += [srcDicts.convertToIdx(srcWords,
167 | onmt.Constants.UNK_WORD)]
168 | elif opt.src_type == "img":
169 | loadImageLibs()
170 | src += [transforms.ToTensor()(
171 | Image.open(opt.src_img_dir + "/" + srcWords[0]))]
172 |
173 | tgt += [tgtDicts.convertToIdx(tgtWords,
174 | onmt.Constants.UNK_WORD,
175 | onmt.Constants.BOS_WORD,
176 | onmt.Constants.EOS_WORD)]
177 | sizes += [len(srcWords)]
178 | else:
179 | ignored += 1
180 |
181 | count += 1
182 |
183 | if count % opt.report_every == 0:
184 | print('... %d sentences prepared' % count)
185 |
186 | srcF.close()
187 | tgtF.close()
188 |
189 | if opt.shuffle == 1:
190 | print('... shuffling sentences')
191 | perm = torch.randperm(len(src))
192 | src = [src[idx] for idx in perm]
193 | tgt = [tgt[idx] for idx in perm]
194 | sizes = [sizes[idx] for idx in perm]
195 |
196 | print('... sorting sentences by size')
197 | _, perm = torch.sort(torch.Tensor(sizes))
198 | src = [src[idx] for idx in perm]
199 | tgt = [tgt[idx] for idx in perm]
200 |
201 | print(('Prepared %d sentences ' +
202 | '(%d ignored due to length == 0 or src len > %d or tgt len > %d)') %
203 | (len(src), ignored, opt.src_seq_length, opt.tgt_seq_length))
204 |
205 | return src, tgt
206 |
207 |
208 | def makeMonolingualData(srcFile, srcDicts):
209 | src = []
210 | sizes = []
211 | count, ignored = 0, 0
212 |
213 | print('Processing %s ...' % (srcFile))
214 |
215 | with open(srcFile) as srcF:
216 | for sline in srcF:
217 | sline = sline.strip()
218 |
219 | # source and/or target are empty
220 | if sline == "":
221 | print('WARNING: ignoring an empty line ('+str(count+1)+')')
222 | continue
223 |
224 | srcWords = sline.split()
225 |
226 | if len(srcWords) <= opt.src_seq_length:
227 |
228 |                 # Check truncation condition.
229 | if opt.src_seq_length_trunc != 0:
230 | srcWords = srcWords[:opt.src_seq_length_trunc]
231 |
232 | src += [srcDicts.convertToIdx(srcWords,
233 | onmt.Constants.UNK_WORD,
234 | onmt.Constants.BOS_WORD,
235 | onmt.Constants.EOS_WORD)]
236 | sizes += [len(srcWords)]
237 | else:
238 | ignored += 1
239 |
240 | count += 1
241 |
242 | if count % opt.report_every == 0:
243 | print('... %d sentences prepared' % count)
244 |
245 | if opt.shuffle == 1:
246 | print('... shuffling sentences')
247 | perm = torch.randperm(len(src))
248 | src = [src[idx] for idx in perm]
249 | sizes = [sizes[idx] for idx in perm]
250 |
251 | print('... sorting sentences by size')
252 | _, perm = torch.sort(torch.Tensor(sizes))
253 | src = [src[idx] for idx in perm]
254 |
255 | print(('Prepared %d sentences ' +
256 | '(%d ignored due to length == 0 or src len > %d)') %
257 | (len(src), ignored, opt.src_seq_length))
258 |
259 | return src
260 |
261 |
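# Builds or loads the vocabularies, converts the training and validation data
# to index tensors, and saves everything to a single opt.save_data + '.train.pt'
# file that train.py consumes.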
262 | def main():
263 |
264 | if opt.src_type in ['bitext', 'img']:
265 | assert None not in [opt.train_src, opt.train_tgt,
266 | opt.valid_src, opt.valid_tgt], \
267 |             "With source type %s the following parameters are " \
268 | "required: -train_src, -train_tgt, " \
269 | "-valid_src, -valid_tgt" % (opt.src_type)
270 |
271 | elif opt.src_type == 'monotext':
272 | assert None not in [opt.train, opt.valid], \
273 | "With source type monotext the following " \
274 | "parameters are required: -train, -valid"
275 |
276 | dicts = {}
277 | dicts['src'] = onmt.Dict()
278 | if opt.src_type == 'bitext':
279 | dicts['src'] = initVocabulary('source', opt.train_src, opt.src_vocab,
280 | opt.src_vocab_size)
281 | dicts['tgt'] = initVocabulary('target', opt.train_tgt, opt.tgt_vocab,
282 | opt.tgt_vocab_size)
283 |
284 | elif opt.src_type == 'monotext':
285 | dicts['src'] = initVocabulary('source', opt.train, opt.src_vocab,
286 | opt.src_vocab_size)
287 |
288 | elif opt.src_type == 'img':
289 | dicts['tgt'] = initVocabulary('target', opt.train_tgt, opt.tgt_vocab,
290 | opt.tgt_vocab_size)
291 |
292 | print('Preparing training ...')
293 | train = {}
294 | valid = {}
295 |
296 | if opt.src_type in ['bitext', 'img']:
297 | train['src'], train['tgt'] = makeBilingualData(opt.train_src,
298 | opt.train_tgt,
299 | dicts['src'],
300 | dicts['tgt'])
301 |
302 | print('Preparing validation ...')
303 | valid['src'], valid['tgt'] = makeBilingualData(opt.valid_src,
304 | opt.valid_tgt,
305 | dicts['src'],
306 | dicts['tgt'])
307 |
308 | elif opt.src_type == 'monotext':
309 | train['src'] = makeMonolingualData(opt.train, dicts['src'])
310 | train['tgt'] = train['src'] # Keeps compatibility with bilingual code
311 | print('Preparing validation ...')
312 | valid['src'] = makeMonolingualData(opt.valid, dicts['src'])
313 | valid['tgt'] = valid['src']
314 |
315 | if opt.src_vocab is None:
316 | saveVocabulary('source', dicts['src'], opt.save_data + '.src.dict')
317 | if opt.src_type in ['bitext', 'img'] and opt.tgt_vocab is None:
318 | saveVocabulary('target', dicts['tgt'], opt.save_data + '.tgt.dict')
319 |
320 | print('Saving data to \'' + opt.save_data + '.train.pt\'...')
321 | save_data = {'dicts': dicts,
322 | 'type': opt.src_type,
323 | 'train': train,
324 | 'valid': valid}
325 | torch.save(save_data, opt.save_data + '.train.pt')
326 |
327 |
328 | if __name__ == "__main__":
329 | main()
330 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | from distutils.core import setup
4 |
5 | setup(name='OpenNMT',
6 | version='0.1',
7 | description='OpenNMT',
8 | packages=['onmt', 'onmt.modules'])
9 |
--------------------------------------------------------------------------------
/test/test_simple.py:
--------------------------------------------------------------------------------
1 | import onmt
2 |
3 |
4 | def test_load():
5 | onmt
6 | pass
7 |
--------------------------------------------------------------------------------
/tools/extract_embeddings.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 |
3 | import onmt
4 | import torch
5 | import argparse
6 |
7 | import onmt.Models
8 |
9 | parser = argparse.ArgumentParser(description='translate.py')
10 |
11 | parser.add_argument('-model', required=True,
12 | help='Path to model .pt file')
13 | parser.add_argument('-output_dir', default='.',
14 | help="""Path to output the embeddings""")
15 | parser.add_argument('-gpu', type=int, default=-1,
16 | help="Device to run on")
17 |
18 |
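# Writes one token per line: its label followed by the embedding components.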
19 | def write_embeddings(filename, dict, embeddings):
20 | with open(filename, 'w') as file:
21 | for i in range(len(embeddings)):
22 |             line = dict.idxToLabel[i].encode("utf-8")
23 |             for j in range(len(embeddings[0])):
24 |                 line = line + " %5f" % (embeddings[i][j])
25 |             file.write(line + "\n")
26 |
27 |
28 | def main():
29 | opt = parser.parse_args()
30 | checkpoint = torch.load(opt.model)
31 | opt.cuda = opt.gpu > -1
32 | if opt.cuda:
33 | torch.cuda.set_device(opt.gpu)
34 |
35 | model_opt = checkpoint['opt']
36 | src_dict = checkpoint['dicts']['src']
37 | tgt_dict = checkpoint['dicts']['tgt']
38 |
39 | encoder = onmt.Models.Encoder(model_opt, src_dict)
40 | decoder = onmt.Models.Decoder(model_opt, tgt_dict)
41 | encoder_embeddings = encoder.word_lut.weight.data.tolist()
42 | decoder_embeddings = decoder.word_lut.weight.data.tolist()
43 |
44 | print("Writing source embeddings")
45 | write_embeddings(opt.output_dir + "/src_embeddings.txt", src_dict,
46 | encoder_embeddings)
47 |
48 | print("Writing target embeddings")
49 | write_embeddings(opt.output_dir + "/tgt_embeddings.txt", tgt_dict,
50 | decoder_embeddings)
51 |
52 | print('... done.')
53 | print('Converting model...')
54 |
55 |
56 | if __name__ == "__main__":
57 | main()
58 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 |
3 | import onmt
4 | import onmt.Markdown
5 | import onmt.Models
6 | import onmt.Decoders
7 | import onmt.Encoders
8 | import onmt.modules
9 | import argparse
10 | import torch
11 | import torch.nn as nn
12 | from torch import cuda
13 | from torch.autograd import Variable
14 | from torch.nn import init
15 | import math
16 | import time
17 |
18 | parser = argparse.ArgumentParser(description='train.py')
19 | onmt.Markdown.add_md_help_argument(parser)
20 |
21 | # Data options
22 |
23 | parser.add_argument('-data', required=True,
24 | help='Path to the *-train.pt file from preprocess.py')
25 | parser.add_argument('-save_model', default='model',
26 |                     help="""Model filename (the model will be saved as
27 |                     <save_model>_epochN_PPL.pt where PPL is the
28 |                     validation perplexity)""")
29 | parser.add_argument('-train_from_state_dict', default='', type=str,
30 | help="""If training from a checkpoint then this is the
31 | path to the pretrained model's state_dict.""")
32 | parser.add_argument('-train_from', default='', type=str,
33 | help="""If training from a checkpoint then this is the
34 | path to the pretrained model.""")
35 |
36 | # Model options
37 |
38 | parser.add_argument('-model_type', type=str, default='nmt',
39 | choices=['nmt', 'lm'],
40 |                     help="""Kind of model to train: either a
41 |                     neural machine translation model or a language model
42 |                     [nmt|lm]""")
43 | parser.add_argument('-layers_enc', type=int, default=2,
44 | help='Number of layers in the LSTM encoder')
45 | parser.add_argument('-layers_dec', type=int, default=2,
46 | help='Number of layers in the LSTM decoder')
47 | parser.add_argument('-rnn_size', type=int, default=500,
48 | help='Size of LSTM hidden states')
49 | parser.add_argument('-word_vec_size', type=int, default=500,
50 | help='Word embedding sizes')
51 | parser.add_argument('-input_feed', type=int, default=1,
52 | help="""Feed the context vector at each time step as
53 | additional input (via concatenation with the word
54 | embeddings) to the decoder.""")
55 | parser.add_argument('-rnn_type', type=str, default='LSTM',
56 | choices=['LSTM', 'GRU', 'SRU'],
57 | help="""The gate type to use in the RNNs""")
58 | parser.add_argument('-rnn_encoder_type', type=str,
59 | choices=['LSTM', 'GRU', 'SRU'],
60 |                     help="""The gate type to use in the encoder RNNs. It overrides -rnn_type""")
61 | parser.add_argument('-rnn_decoder_type', type=str,
62 | choices=['LSTM', 'GRU', 'SRU'],
63 |                     help="""The gate type to use in the decoder RNNs. It overrides -rnn_type""")
64 | parser.add_argument('-attn_type', type=str, default='mlp',
65 | choices=['mlp', 'dot'],
66 | help="""The attention type to use in the decoder""")
67 | parser.add_argument('-activ', type=str, default='tanh',
68 | help="""Activation function inside the RNNs.""")
69 | parser.add_argument('-brnn', action='store_true',
70 | help='Use a bidirectional encoder')
71 | parser.add_argument('-context_gate', type=str, default=None,
72 | choices=['source', 'target', 'both'],
73 | help="""Type of context gate to use [source|target|both].
74 |                     Leave unset to use no context gate.""")
75 | parser.add_argument('-decoder_type', type=str, default='StackedRNN',
76 | help="""Decoder neural architecture to use""")
77 | parser.add_argument('-encoder_type', type=str, default='RNN',
78 | help="""Encoder architecture""")
79 | parser.add_argument('-layer_norm', default=False, action="store_true",
80 | help="""Add layer normalization in recurrent units""")
81 |
82 | # Optimization options
83 |
84 | parser.add_argument('-batch_size', type=int, default=64,
85 | help='Maximum batch size')
86 | parser.add_argument('-max_generator_batches', type=int, default=32,
87 | help="""Maximum batches of words in a sequence to run
88 | the generator on in parallel. Higher is faster, but uses
89 | more memory.""")
90 | parser.add_argument('-epochs', type=int, default=13,
91 | help='Number of training epochs')
92 | parser.add_argument('-start_epoch', type=int, default=1,
93 | help='The epoch from which to start')
94 | parser.add_argument('-param_init', type=float, default=0.1,
95 | help="""Parameters are initialized over uniform distribution
96 | with support (-param_init, param_init)""")
97 | parser.add_argument('-optim', default='sgd',
98 | help="Optimization method. [sgd|adagrad|adadelta|adam]")
99 | parser.add_argument('-max_grad_norm', type=float, default=5,
100 | help="""If the norm of the gradient vector exceeds this,
101 | renormalize it to have the norm equal to max_grad_norm""")
102 | parser.add_argument('-dropout', type=float, default=0.3,
103 | help='Dropout probability; applied between LSTM stacks.')
104 | parser.add_argument('-curriculum', action="store_true",
105 | help="""For this many epochs, order the minibatches based
106 | on source sequence length. Sometimes setting this to 1 will
107 | increase convergence speed.""")
108 | parser.add_argument('-extra_shuffle', action="store_true",
109 | help="""By default only shuffle mini-batch order; when true,
110 | shuffle and re-assign mini-batches""")
111 | parser.add_argument('-change_optimizer', default=False, action='store_true',
112 |                     help="""If a model is reloaded, reset the optimizer
113 |                     values to those given in the arguments""")
114 | parser.add_argument('-enc_short_path', type=bool, default=False,
115 | help="""If True, creates a short path from the source embeddings to the output
116 | by adding them to the attention""")
117 | parser.add_argument('-use_learning_rate_decay', action="store_true",
118 | help='if set, activate learning rate decay after every checkpoint')
119 | parser.add_argument('-save_each', type=int, default=10000,
120 | help="""The number of minibatches to compute before saving a checkpoint""")
121 |
122 | # learning rate
123 | parser.add_argument('-learning_rate', type=float, default=1.0,
124 | help="""Starting learning rate. If adagrad/adadelta/adam is
125 | used, then this is the global learning rate. Recommended
126 | settings: sgd = 1, adagrad = 0.1,
127 | adadelta = 1, adam = 0.001""")
128 | parser.add_argument('-learning_rate_decay', type=float, default=0.5,
129 | help="""If update_learning_rate, decay learning rate by
130 | this much if (i) perplexity does not decrease on the
131 | validation set or (ii) epoch has gone past
132 | start_decay_at""")
133 | parser.add_argument('-start_decay_at', type=int, default=8,
134 | help="""Start decaying every epoch after and including this
135 | epoch""")
136 |
137 | # pretrained word vectors
138 |
139 | parser.add_argument('-pre_word_vecs_enc',
140 | help="""If a valid path is specified, then this will load
141 | pretrained word embeddings on the encoder side.
142 | See README for specific formatting instructions.""")
143 | parser.add_argument('-pre_word_vecs_dec',
144 | help="""If a valid path is specified, then this will load
145 | pretrained word embeddings on the decoder side.
146 | See README for specific formatting instructions.""")
147 | parser.add_argument('-pre_word_vecs',
148 | help="""If a valid path is specified, then this will load
149 | pretrained word embeddings on the language model.
150 | See README for specific formatting instructions.""")
151 |
152 | # GPU
153 | parser.add_argument('-gpus', default=[], nargs='+', type=int,
154 | help="Use CUDA on the listed devices.")
155 |
156 | parser.add_argument('-log_interval', type=int, default=50,
157 | help="Print stats at this interval.")
158 |
159 | parser.add_argument('-seed', type=int, default=-1,
160 |                     help="""Random seed used for experiment
161 | reproducibility.""")
162 |
163 | opt = parser.parse_args()
164 |
165 | print(opt)
166 |
167 | if opt.seed > 0:
168 | torch.manual_seed(opt.seed)
169 |
170 | if torch.cuda.is_available() and not opt.gpus:
171 |     print("WARNING: You have a CUDA device, so you should probably run with -gpus 0")
172 |
173 | if opt.gpus:
174 | cuda.set_device(opt.gpus[0])
175 | if opt.seed > 0:
176 | torch.cuda.manual_seed(opt.seed)
177 |
178 |
179 | def NMTCriterion(vocabSize):
180 | weight = torch.ones(vocabSize)
181 | weight[onmt.Constants.PAD] = 0
182 | crit = nn.NLLLoss(weight, size_average=False)
183 | if opt.gpus:
184 | crit.cuda()
185 | return crit
186 |
187 |
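# Computes the NLL loss over chunks of opt.max_generator_batches time steps so
# the full-vocabulary softmax is never materialized for the whole sequence at
# once; returns the accumulated gradient w.r.t. the decoder outputs, which the
# caller uses to backpropagate through the rest of the model.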
188 | def memoryEfficientLoss(outputs, targets, generator, crit, eval=False):
189 | # compute generations one piece at a time
190 | num_correct, loss = 0, 0
191 | outputs = Variable(outputs.data, requires_grad=(not eval), volatile=eval)
192 |
193 | batch_size = outputs.size(1)
194 | outputs_split = torch.split(outputs, opt.max_generator_batches)
195 | targets_split = torch.split(targets, opt.max_generator_batches)
196 | for i, (out_t, targ_t) in enumerate(zip(outputs_split, targets_split)):
197 | out_t = out_t.view(-1, out_t.size(2))
198 | scores_t = generator(out_t)
199 | loss_t = crit(scores_t, targ_t.view(-1))
200 | pred_t = scores_t.max(1)[1]
201 | num_correct_t = pred_t.data.eq(targ_t.data) \
202 | .masked_select(
203 | targ_t.ne(onmt.Constants.PAD).data) \
204 | .sum()
205 | num_correct += num_correct_t
206 | loss += loss_t.data[0]
207 | if not eval:
208 | loss_t.div(batch_size).backward()
209 |
210 | grad_output = None if outputs.grad is None else outputs.grad.data
211 | return loss, grad_output, num_correct
212 |
213 |
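# Computes per-word loss and accuracy on a dataset without updating parameters.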
214 | def eval(model, criterion, data):
215 | total_loss = 0
216 | total_words = 0
217 | total_num_correct = 0
218 |
219 | model.eval()
220 | for i in range(len(data)):
221 | # exclude original indices
222 | batch = data[i][:-1]
223 | outputs = model(batch)
224 | # exclude from targets
225 | targets = batch[1][1:]
226 | loss, _, num_correct = memoryEfficientLoss(
227 | outputs, targets, model.generator, criterion, eval=True)
228 | total_loss += loss
229 | total_num_correct += num_correct
230 | total_words += targets.data.ne(onmt.Constants.PAD).sum()
231 |
232 | model.train()
233 | return total_loss / total_words, total_num_correct / total_words
234 |
235 |
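# Top-level training loop. Note that opt.epochs bounds the total number of
# minibatch updates here (see trainEpoch), and a checkpoint is written every
# opt.save_each updates.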
236 | def trainModel(model, trainData, validData, dataset, optim, opt):
237 | print(model)
238 | model.train()
239 |
240 | # Define criterion of each GPU.
241 | criterion = NMTCriterion(dataset['dicts']['tgt'].size())
242 |
243 | start_time = time.time()
244 |
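    # One pass over the training batches: logs running statistics every
    # opt.log_interval batches and evaluates/checkpoints every opt.save_each updates.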
245 | def trainEpoch(epoch, iter):
246 |
247 | if opt.extra_shuffle and epoch > opt.curriculum:
248 | trainData.shuffle()
249 |
250 | # Shuffle mini batch order.
251 | batchOrder = torch.randperm(len(trainData))
252 |
253 | total_loss, total_words, total_num_correct = 0, 0, 0
254 | report_loss, report_tgt_words = 0, 0
255 | report_src_words, report_num_correct = 0, 0
256 | start = time.time()
257 | for i in range(len(trainData)):
258 |
259 | if iter >= opt.epochs:
260 | break
261 | iter += 1
262 |
263 | batchIdx = batchOrder[i] if epoch > opt.curriculum else i
264 | # Exclude original indices.
265 | batch = trainData[batchIdx][:-1]
266 |
267 | model.zero_grad()
268 | outputs = model(batch)
269 | # Exclude from targets.
270 | targets = batch[1][1:]
271 | loss, gradOutput, num_correct = memoryEfficientLoss(
272 | outputs, targets, model.generator, criterion)
273 |
274 | outputs.backward(gradOutput)
275 |
276 | # Update the parameters.
277 | optim.step()
278 |
279 | num_words = targets.data.ne(onmt.Constants.PAD).sum()
280 | report_loss += loss
281 | report_num_correct += num_correct
282 | report_tgt_words += num_words
283 | report_src_words += batch[0][1].data.sum()
284 | total_loss += loss
285 | total_num_correct += num_correct
286 | total_words += num_words
287 | if i % opt.log_interval == -1 % opt.log_interval:
288 | print(("Epoch %2d, %5d/%5d; acc: %6.2f; ppl: %6.2f; " +
289 | "%3.0f src tok/s; %3.0f tgt tok/s; %6.0f s elapsed") %
290 | (epoch, i+1, len(trainData),
291 | report_num_correct / report_tgt_words * 100,
292 | math.exp(report_loss / report_tgt_words),
293 | report_src_words/(time.time()-start),
294 | report_tgt_words/(time.time()-start),
295 | time.time()-start_time))
296 |
297 | report_loss, report_tgt_words = 0, 0
298 | report_src_words, report_num_correct = 0, 0
299 | start = time.time()
300 |
301 | if iter % opt.save_each == 0:
302 | # (2) evaluate on the validation set
303 | valid_loss, valid_acc = eval(model, criterion, validData)
304 | valid_ppl = math.exp(min(valid_loss, 100))
305 | print('Validation perplexity: %g' % valid_ppl)
306 | print('Validation accuracy: %g' % (valid_acc * 100))
307 |
308 | # (3) update the learning rate
309 | if opt.use_learning_rate_decay:
310 | optim.updateLearningRate(valid_ppl, iter)
311 |
312 | model_state_dict = (model.module.state_dict() if len(opt.gpus) > 1
313 | else model.state_dict())
314 | model_state_dict = {k: v for k, v in model_state_dict.items()
315 | if 'generator' not in k}
316 | generator_state_dict = (model.generator.module.state_dict()
317 | if len(opt.gpus) > 1
318 | else model.generator.state_dict())
319 | # (4) drop a checkpoint
320 | checkpoint = {
321 | 'model': model_state_dict,
322 | 'generator': generator_state_dict,
323 | 'dicts': dataset['dicts'],
324 | 'opt': opt,
325 | 'epoch': epoch,
326 | 'optim': optim,
327 | 'type': opt.model_type
328 | }
329 | torch.save(checkpoint,
330 | '%s_acc_%.2f_ppl_%.2f_iter%d_e%d.pt'
331 | % (opt.save_model, 100 * valid_acc, valid_ppl, iter, epoch))
332 |
333 | return total_loss / total_words, total_num_correct / total_words, iter
334 |
335 | epoch, iter = 1, 0
336 | while iter < opt.epochs:
337 | print('')
338 | # (1) train for one epoch on the training set
339 | train_loss, train_acc, iter = trainEpoch(epoch, iter)
340 | epoch += 1
341 | train_ppl = math.exp(min(train_loss, 100))
342 | print('Train perplexity: %g' % train_ppl)
343 | print('Train accuracy: %g' % (train_acc*100))
344 |
345 |
346 | def main():
347 | print("Loading data from '%s'" % opt.data)
348 |
349 | dataset = torch.load(opt.data)
350 | if opt.model_type == 'nmt':
351 | if dataset.get("type", "text") not in ["bitext", "text"]:
352 | print("WARNING: The provided dataset is not bilingual!")
353 | elif opt.model_type == 'lm':
354 | if dataset.get("type", "text") != 'monotext':
355 | print("WARNING: The provided dataset is not monolingual!")
356 | else:
357 |         raise NotImplementedError('Invalid model type %s' % opt.model_type)
358 |
359 | dict_checkpoint = (opt.train_from if opt.train_from
360 | else opt.train_from_state_dict)
361 | if dict_checkpoint:
362 | print('Loading dicts from checkpoint at %s' % dict_checkpoint)
363 | checkpoint = torch.load(dict_checkpoint)
364 | if opt.model_type == 'nmt':
365 | assert checkpoint.get('type', None) is None or \
366 | checkpoint['type'] == "nmt", \
367 | "The loaded model is not neural machine translation!"
368 | elif opt.model_type == 'lm':
369 | assert checkpoint['type'] == "lm", \
370 | "The loaded model is not a language model!"
371 | dataset['dicts'] = checkpoint['dicts']
372 |
373 | trainData = onmt.Dataset(dataset['train']['src'],
374 | dataset['train']['tgt'], opt.batch_size, opt.gpus,
375 | data_type=dataset.get("type", "text"))
376 | validData = onmt.Dataset(dataset['valid']['src'],
377 | dataset['valid']['tgt'], opt.batch_size, opt.gpus,
378 | volatile=True,
379 | data_type=dataset.get("type", "text"))
380 |
381 | dicts = dataset['dicts']
382 | model_opt = checkpoint['opt'] if dict_checkpoint else opt
383 | if dicts.get('tgt', None) is None:
384 | # Makes the code compatible with the language model
385 | dicts['tgt'] = dicts['src']
386 | if opt.model_type == 'nmt':
387 | print(' * vocabulary size. source = %d; target = %d' %
388 | (dicts['src'].size(), dicts['tgt'].size()))
389 | elif opt.model_type == 'lm':
390 | print(' * vocabulary size = %d' %
391 | (dicts['src'].size()))
392 | print(' * number of training sentences. %d' %
393 | len(dataset['train']['src']))
394 | print(' * maximum batch size. %d' % opt.batch_size)
395 |
396 | print('Building model...')
397 |
398 | if opt.model_type == 'nmt':
399 |
400 | decoder = onmt.Decoders.getDecoder(model_opt.decoder_type)(model_opt, dicts['tgt'])
401 | encoder = onmt.Encoders.getEncoder(model_opt.encoder_type)(model_opt, dicts['src'])
402 |
403 | model = onmt.Models.NMTModel(encoder, decoder)
404 |
405 | elif opt.model_type == 'lm':
406 | model = onmt.LanguageModel.LM(model_opt, dicts['src'])
407 |
408 | generator = nn.Sequential(
409 | nn.Linear(model_opt.rnn_size, dicts['tgt'].size()),
410 | nn.LogSoftmax())
411 |
412 | if opt.train_from:
413 | print('Loading model from checkpoint at %s' % opt.train_from)
414 | chk_model = checkpoint['model']
415 | generator_state_dict = chk_model.generator.state_dict()
416 | model_state_dict = {k: v for k, v in chk_model.state_dict().items()
417 | if 'generator' not in k}
418 | model.load_state_dict(model_state_dict)
419 | generator.load_state_dict(generator_state_dict)
420 | opt.start_epoch = checkpoint['epoch'] + 1
421 |
422 | if opt.train_from_state_dict:
423 | print('Loading model from state_dict at %s'
424 | % opt.train_from_state_dict)
425 | model.load_state_dict(checkpoint['model'])
426 | generator.load_state_dict(checkpoint['generator'])
427 | model_opt.start_epoch = opt.start_epoch
428 | model_opt.epochs = opt.epochs
429 |
430 | if len(opt.gpus) >= 1:
431 | model.cuda()
432 | generator.cuda()
433 | else:
434 | model.cpu()
435 | generator.cpu()
436 |
437 | if len(opt.gpus) > 1:
438 | model = nn.DataParallel(model, device_ids=opt.gpus, dim=1)
439 | generator = nn.DataParallel(generator, device_ids=opt.gpus, dim=0)
440 |         model_opt.gpus = opt.gpus
441 |
442 | model.generator = generator
443 |
444 | if not opt.train_from_state_dict and not opt.train_from:
445 | for p in model.parameters():
446 | #p.data.uniform_(-opt.param_init, opt.param_init)
447 | if len(p.data.size()) > 1:
448 | init.xavier_normal(p.data)
449 | else:
450 | p.data.uniform_(-opt.param_init, opt.param_init)
451 | model.initialize_parameters(opt.param_init)
452 | model.load_pretrained_vectors(opt)
453 |
454 | if (not opt.train_from_state_dict and not opt.train_from) or opt.change_optimizer:
455 | optim = onmt.Optim(
456 | opt.optim, opt.learning_rate, opt.max_grad_norm,
457 | lr_decay=opt.learning_rate_decay,
458 | start_decay_at=opt.start_decay_at
459 | )
460 | optim.set_parameters(model.parameters())
461 | model_opt.learning_rate = opt.learning_rate
462 | model_opt.learning_rate_decay = opt.learning_rate_decay
463 | model_opt.save_each = opt.save_each
464 |
465 | else:
466 | print('Loading optimizer from checkpoint:')
467 | optim = checkpoint['optim']
468 | optim.optimizer.load_state_dict(
469 | checkpoint['optim'].optimizer.state_dict())
470 | optim.set_parameters(model.parameters())
471 |
472 | nParams = sum([p.nelement() for p in model.parameters()])
473 | print('* number of parameters: %d' % nParams)
474 |
475 | if opt.train_from or opt.train_from_state_dict:
476 | print(model_opt)
477 |
478 | model_opt.use_learning_rate_decay = opt.use_learning_rate_decay
479 | trainModel(model, trainData, validData, dataset, optim, model_opt)
480 |
481 |
482 | if __name__ == "__main__":
483 | main()
484 |
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
1 | layer=$1
2 | gpu=$2
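# Trains an SR-NMT model with ${layer} encoder/decoder layers on GPU ${gpu};
# adjust the data, model, and log paths before running, e.g. bash train.sh <num_layers> <gpu_id>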
3 | python -u train.py -data data.train.pt -save_model path/to/model/SGU_${layer}layers -layer_norm \
4 | -max_grad_norm 1 -layers_enc $layer -layers_dec $layer -dropout 0.1 -gpus $gpu -optim adam -learning_rate 0.0003 -decoder_type SR -encoder_type SR \
5 | -attn_type dot -save_each 30000 -brnn -rnn_size 500 -epochs 1000000 -word_vec_size 500 > path/to/log.out
6 |
--------------------------------------------------------------------------------
/translate.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from builtins import bytes
3 |
4 | import onmt
5 | import onmt.Markdown
6 | import torch
7 | import argparse
8 | import math
9 | import codecs
10 | import os
11 |
12 | parser = argparse.ArgumentParser(description='translate.py')
13 | onmt.Markdown.add_md_help_argument(parser)
14 |
15 | parser.add_argument('-model', required=True,
16 | help='Path to model .pt file')
17 | parser.add_argument('-src', required=True,
18 | help='Source sequence to decode (one line per sequence)')
19 | parser.add_argument('-src_img_dir', default="",
20 | help='Source image directory')
21 | parser.add_argument('-tgt',
22 | help='True target sequence (optional)')
23 | parser.add_argument('-output', default='pred.txt',
24 | help="""Path to output the predictions (each line will
25 |                     be the decoded sequence)""")
26 | parser.add_argument('-beam_size', type=int, default=5,
27 | help='Beam size')
28 | parser.add_argument('-batch_size', type=int, default=30,
29 | help='Batch size')
30 | parser.add_argument('-max_sent_length', type=int, default=100,
31 | help='Maximum sentence length.')
32 | parser.add_argument('-replace_unk', action="store_true",
33 | help="""Replace the generated UNK tokens with the source
34 | token that had highest attention weight. If phrase_table
35 | is provided, it will lookup the identified source token and
36 | give the corresponding target token. If it is not provided
37 | (or the identified source token does not exist in the
38 | table) then it will copy the source token""")
39 | # parser.add_argument('-phrase_table',
40 | # help="""Path to source-target dictionary to replace UNK
41 | # tokens. See README.md for the format of this file.""")
42 | parser.add_argument('-verbose', action="store_true",
43 | help='Print scores and predictions for each sentence')
44 | parser.add_argument('-dump_beam', type=str, default="",
45 | help='File to dump beam information to.')
46 |
47 | parser.add_argument('-n_best', type=int, default=1,
48 | help="""If verbose is set, will output the n_best
49 | decoded sentences""")
50 |
51 | parser.add_argument('-gpu', type=int, default=-1,
52 | help="Device to run on")
53 |
54 |
55 | def reportScore(name, scoreTotal, wordsTotal):
56 | print("%s AVG SCORE: %.4f, %s PPL: %.4f" % (
57 | name, scoreTotal / wordsTotal,
58 | name, math.exp(-scoreTotal/wordsTotal)))
59 |
60 |
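# Yields every line of f followed by a trailing None, which tells the loop in
# main() to flush the final (possibly smaller) batch.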
61 | def addone(f):
62 | for line in f:
63 | yield line
64 | yield None
65 |
66 |
67 | def main():
68 | opt = parser.parse_args()
69 | opt.cuda = opt.gpu > -1
70 | if opt.cuda:
71 | torch.cuda.set_device(opt.gpu)
72 |
73 | translator = onmt.Translator(opt)
74 |
75 |
76 | outF = codecs.open(opt.output, 'w', 'utf-8')
77 |
78 | predScoreTotal, predWordsTotal, goldScoreTotal, goldWordsTotal = 0, 0, 0, 0
79 |
80 | srcBatch, tgtBatch = [], []
81 |
82 | count = 0
83 |
84 | tgtF = codecs.open(opt.tgt, 'r', 'utf-8') if opt.tgt else None
85 |
86 | if opt.dump_beam != "":
87 | import json
88 | translator.initBeamAccum()
89 |
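    # Accumulate source (and optional gold target) sentences into batches of
    # opt.batch_size, translate each batch, and write the top hypotheses to opt.output.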
90 | for line in addone(codecs.open(opt.src, 'r', 'utf-8')):
91 | if line is not None:
92 | srcTokens = line.split()
93 | srcBatch += [srcTokens]
94 | if tgtF:
95 | tgtTokens = tgtF.readline().split() if tgtF else None
96 | tgtBatch += [tgtTokens]
97 |
98 | if len(srcBatch) < opt.batch_size:
99 | continue
100 | else:
101 | # at the end of file, check last batch
102 | if len(srcBatch) == 0:
103 | break
104 |
105 | predBatch, predScore, goldScore = translator.translate(srcBatch,
106 | tgtBatch)
107 | predScoreTotal += sum(score[0] for score in predScore)
108 | predWordsTotal += sum(len(x[0]) for x in predBatch)
109 | if tgtF is not None:
110 | goldScoreTotal += sum(goldScore)
111 | goldWordsTotal += sum(len(x) for x in tgtBatch)
112 |
113 | for b in range(len(predBatch)):
114 | count += 1
115 | outF.write(" ".join(predBatch[b][0]) + '\n')
116 | outF.flush()
117 |
118 | if opt.verbose:
119 | srcSent = ' '.join(srcBatch[b])
120 | if translator.tgt_dict.lower:
121 | srcSent = srcSent.lower()
122 | os.write(1, bytes('SENT %d: %s\n' % (count, srcSent), 'UTF-8'))
123 | os.write(1, bytes('PRED %d: %s\n' %
124 | (count, " ".join(predBatch[b][0])), 'UTF-8'))
125 | print("PRED SCORE: %.4f" % predScore[b][0])
126 |
127 | if tgtF is not None:
128 | tgtSent = ' '.join(tgtBatch[b])
129 | if translator.tgt_dict.lower:
130 | tgtSent = tgtSent.lower()
131 | os.write(1, bytes('GOLD %d: %s\n' %
132 | (count, tgtSent), 'UTF-8'))
133 | print("GOLD SCORE: %.4f" % goldScore[b])
134 |
135 | if opt.n_best > 1:
136 | print('\nBEST HYP:')
137 | for n in range(opt.n_best):
138 | os.write(1, bytes("[%.4f] %s\n" % (predScore[b][n],
139 | " ".join(predBatch[b][n])),
140 | 'UTF-8'))
141 |
142 | print('')
143 |
144 | srcBatch, tgtBatch = [], []
145 |
146 | reportScore('PRED', predScoreTotal, predWordsTotal)
147 | if tgtF:
148 | reportScore('GOLD', goldScoreTotal, goldWordsTotal)
149 |
150 | if tgtF:
151 | tgtF.close()
152 |
153 | if opt.dump_beam:
154 | json.dump(translator.beam_accum,
155 | codecs.open(opt.dump_beam, 'w', 'utf-8'))
156 |
157 |
158 | if __name__ == "__main__":
159 | main()
160 |
--------------------------------------------------------------------------------
/translate.sh:
--------------------------------------------------------------------------------
1 | model=$1
2 | test=$2
3 | gpu=$3
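# Translates $test with $model on GPU $gpu; predictions are written to $model.test.out,
# e.g. bash translate.sh <model.pt> <test-file> <gpu-id>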
4 | python translate.py -src $test -model $model -output $model.test.out -gpu $gpu -batch_size 1
5 |
--------------------------------------------------------------------------------