├── .gitignore
├── LICENSE
├── README.md
├── experiments
│   ├── configs
│   │   ├── iwslt14
│   │   │   ├── config-rnn.json
│   │   │   └── config-transformer.json
│   │   ├── wmt14
│   │   │   ├── config-rnn.json
│   │   │   ├── config-transformer-base.json
│   │   │   └── config-transformer-large.json
│   │   └── wmt16
│   │       ├── config-rnn.json
│   │       ├── config-transformer-base.json
│   │       └── config-transformer-large.json
│   ├── distributed.py
│   ├── nmt.py
│   ├── options.py
│   ├── scripts
│   │   └── multi-bleu.perl
│   ├── slurm.py
│   └── translate.py
├── flownmt
│   ├── __init__.py
│   ├── data
│   │   ├── __init__.py
│   │   └── dataloader.py
│   ├── flownmt.py
│   ├── flows
│   │   ├── __init__.py
│   │   ├── actnorm.py
│   │   ├── couplings
│   │   │   ├── __init__.py
│   │   │   ├── blocks.py
│   │   │   ├── coupling.py
│   │   │   └── transform.py
│   │   ├── flow.py
│   │   ├── linear.py
│   │   ├── nmt.py
│   │   └── parallel
│   │       ├── __init__.py
│   │       ├── data_parallel.py
│   │       └── parallel_apply.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── decoders
│   │   │   ├── __init__.py
│   │   │   ├── decoder.py
│   │   │   ├── rnn.py
│   │   │   ├── simple.py
│   │   │   └── transformer.py
│   │   ├── encoders
│   │   │   ├── __init__.py
│   │   │   ├── encoder.py
│   │   │   ├── rnn.py
│   │   │   └── transformer.py
│   │   ├── posteriors
│   │   │   ├── __init__.py
│   │   │   ├── posterior.py
│   │   │   ├── rnn.py
│   │   │   ├── shift_rnn.py
│   │   │   └── transformer.py
│   │   └── priors
│   │       ├── __init__.py
│   │       ├── length_predictors
│   │       │   ├── __init__.py
│   │       │   ├── diff_discretized_mix_logistic.py
│   │       │   ├── diff_softmax.py
│   │       │   ├── predictor.py
│   │       │   └── utils.py
│   │       └── prior.py
│   ├── nnet
│   │   ├── __init__.py
│   │   ├── attention.py
│   │   ├── criterion.py
│   │   ├── layer_norm.py
│   │   ├── positional_encoding.py
│   │   ├── transformer.py
│   │   └── weightnorm.py
│   ├── optim
│   │   ├── __init__.py
│   │   ├── adamw.py
│   │   └── lr_scheduler.py
│   └── utils.py
├── images
│   └── flowseq_diagram.png
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
106 | # data dir
107 | /experiments/data/
108 |
109 | # model dir
110 | /experiments/models/
111 |
112 | # log dir
113 | /experiments/log/
114 |
115 | # test dir
116 | /tests/
117 |
118 | # IDE
119 | /.idea/
120 | *.iml
121 | *.sublime-project
122 | *.sublime-workspace
123 |
124 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FlowSeq: a Generative Flow-based Sequence-to-Sequence Toolkit
2 | This is the PyTorch implementation of [FlowSeq: Non-Autoregressive Conditional Sequence Generation with Generative Flow](http://arxiv.org/abs/1909.02480), accepted by EMNLP 2019.
3 |
4 |
5 | ![FlowSeq diagram](images/flowseq_diagram.png)
6 |
7 |
8 | We propose an efficient and effective model for non-autoregressive sequence generation using latent variable models.
9 | We model the complex distributions with generative flows, and design
10 | several layers of flow tailored for modeling the conditional density of sequential latent variables.
11 | On several machine translation benchmark datasets (WMT14 En-De, WMT16 En-Ro), we achieve performance comparable
12 | to state-of-the-art non-autoregressive NMT models with almost constant decoding time w.r.t. the sequence length.
13 |
14 | ## Requirements
15 | * Python version >= 3.6
16 | * PyTorch version >= 1.1
17 | * apex
18 | * Perl
19 |
20 | ## Installation
21 | 1. Install [NVIDIA-apex](https://github.com/NVIDIA/apex).
22 | 2. Install [PyTorch and torchvision](https://pytorch.org/get-started/locally/).
23 |
24 | ## Data
25 | 1. WMT'14 English to German (EN-DE) can be obtained with scripts provided in [fairseq](https://github.com/pytorch/fairseq/blob/master/examples/translation/README.md#wmt14-english-to-german-convolutional).
26 | 2. WMT'16 English to Romanian (EN-RO) can be obtained from [here](https://github.com/nyu-dl/dl4mt-nonauto#downloading-datasets--pre-trained-models).
27 |
28 | ## Training a new model
29 | The MT datasets should be named in the format ``train.{language code}``, ``dev.{language code}``, ``test.{language code}``, e.g. ``train.de``.
30 | Suppose we put the WMT14 En-De data sets under ``data/wmt14-ende/real-bpe/``; we can then train FlowSeq on this data on one node with the
31 | following script:
32 | ```bash
33 | cd experiments
34 |
35 | python -u distributed.py \
36 | --nnodes 1 --node_rank 0 --nproc_per_node <num of gpus per node> --master_addr <master node ip> \
37 | --master_port <master port> \
38 | --config configs/wmt14/config-transformer-base.json --model_path <model path> \
39 | --data_path data/wmt14-ende/real-bpe/ \
40 | --batch_size 2048 --batch_steps 1 --init_batch_size 512 --eval_batch_size 32 \
41 | --src en --tgt de \
42 | --lr 0.0005 --beta1 0.9 --beta2 0.999 --eps 1e-8 --grad_clip 1.0 --amsgrad \
43 | --lr_decay 'expo' --weight_decay 0.001 \
44 | --init_steps 30000 --kl_warmup_steps 10000 \
45 | --subword 'joint-bpe' --bucket_batch 1 --create_vocab
46 | ```
47 | After training, under the ``<model path>`` there will be saved checkpoints, `model.pt`, `config.json`, `log.txt`,
48 | the `vocab` directory, and intermediate translation results under the `translations` directory.
49 |
50 | #### Note:
51 | - The argument `--batch_steps` is used for gradient accumulation to trade speed for memory: each batch is split into `batch_steps` segments, so the size of each segment is `batch_size / (num_gpus * batch_steps)` (see the short sketch after these notes).
52 | - To train FlowSeq on multiple nodes, we provide a script for the Slurm cluster environment, `/experiments/slurm.py`; alternatively, please
53 | refer to the PyTorch distributed training [tutorial](https://pytorch.org/tutorials/intermediate/dist_tuto.html).
54 | - To create the knowledge-distillation dataset, please use [fairseq](https://github.com/pytorch/fairseq/blob/master/examples/translation/README.md#neural-machine-translation) to train an autoregressive Transformer model
55 | and translate the source data set with it.
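For concreteness, here is a minimal sketch of the batch-size arithmetic from the first note above. It is not part of the repository; the node/GPU counts are illustrative, and the per-process division mirrors what `experiments/distributed.py` does with `--batch_size`.

```python
# Illustrative batch-size bookkeeping for the training command above.
nnodes, nproc_per_node = 1, 8          # hypothetical: one node with 8 GPUs
batch_size, batch_steps = 2048, 1      # values used in the example command

world_size = nnodes * nproc_per_node               # total number of processes
per_process_batch = batch_size // world_size       # what distributed.py assigns to each process
per_step_segment = batch_size // (world_size * batch_steps)  # sentences per gradient-accumulation step

print(world_size, per_process_batch, per_step_segment)   # 8 256 256
```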
56 |
57 | ## Translation and evaluation
58 | ```bash
59 | cd experiments
60 |
61 | python -u translate.py \
62 | --model_path <model path> \
63 | --data_path data/wmt14-ende/real-bpe/ \
64 | --batch_size 32 --bucket_batch 1 \
65 | --decode {'argmax', 'iw', 'sample'} \
66 | --tau 0.0 --nlen 3 --ntr 1
67 | ```
68 | Please check details of arguments [here](https://github.com/XuezheMax/flowseq/blob/master/experiments/options.py#L48).
69 |
70 | To keep the output translations in the original order of the input test data, use `--bucket_batch 0`.
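If you need to re-score a dumped translation file outside of `translate.py`, the sketch below (not part of the repository; the file paths are hypothetical) mirrors how `calc_bleu()` in `experiments/translate.py` calls the bundled `scripts/multi-bleu.perl`, here via `subprocess` rather than `os.system`.

```python
import os
import subprocess

def score(reference_path, hypothesis_path):
    # multi-bleu.perl takes the reference path as an argument and reads the hypothesis from stdin.
    script = os.path.join('scripts', 'multi-bleu.perl')
    with open(hypothesis_path, 'rb') as hyp:
        result = subprocess.run(['perl', script, reference_path],
                                stdin=hyp, capture_output=True, text=True)
    print(result.stdout.strip())  # prints the "BLEU = ..." summary line

# Hypothetical paths: a reference file and one of translate.py's output files.
score('data/wmt14-ende/real-bpe/test.de', '<model path>/translations/argmax.t0.0.ntr1.test.mt')
```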
71 |
72 | ## References
73 | ```
74 | @inproceedings{flowseq2019,
75 | title = {FlowSeq: Non-Autoregressive Conditional Sequence Generation with Generative Flow},
76 | author = {Ma, Xuezhe and Zhou, Chunting and Li, Xian and Neubig, Graham and Hovy, Eduard},
77 | booktitle = {Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing},
78 | address = {Hong Kong},
79 | month = {November},
80 | year = {2019}
81 | }
82 | ```
83 |
--------------------------------------------------------------------------------
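The JSON files under `experiments/configs/` that follow are plain hyper-parameter dictionaries read directly by the experiment scripts. As a rough sketch of how they are consumed (the config path and language codes are illustrative; the lookups mirror `create_dataset()` in `experiments/distributed.py`):

```python
import json

# Load a config (illustrative path) the same way the experiment scripts do.
with open('configs/wmt14/config-transformer-base.json', 'r') as f:
    params = json.load(f)

src_lang, tgt_lang = 'en', 'de'
src_max_vocab = params['{}_vocab_size'.format(src_lang)]   # 37000 in the wmt14 configs
tgt_max_vocab = params['{}_vocab_size'.format(tgt_lang)]   # 37000 in the wmt14 configs
print(src_max_vocab, tgt_max_vocab, params['prior']['flow']['levels'])   # the prior flow uses 3 levels
```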
/experiments/configs/iwslt14/config-rnn.json:
--------------------------------------------------------------------------------
1 | {
2 | "de_vocab_size": 32012,
3 | "en_vocab_size": 22825,
4 | "max_de_length": 64,
5 | "max_en_length": 64,
6 | "share_embed": false,
7 | "tie_weights": false,
8 | "embed_dim": 256,
9 | "latent_dim": 256,
10 | "hidden_size": 256,
11 | "prior": {
12 | "type": "normal",
13 | "length_predictor": {
14 | "type": "diff_softmax",
15 | "diff_range": 16,
16 | "dropout": 0.33,
17 | "label_smoothing": 0.1
18 | },
19 | "flow": {
20 | "levels": 3,
21 | "num_steps": [4, 2, 2],
22 | "factors": [2, 2],
23 | "hidden_features": 256,
24 | "transform": "affine",
25 | "coupling_type": "rnn",
26 | "rnn_mode": "LSTM",
27 | "dropout": 0.0,
28 | "inverse": true
29 | }
30 | },
31 | "encoder": {
32 | "type": "rnn",
33 | "rnn_mode": "LSTM",
34 | "num_layers": 2
35 | },
36 | "posterior": {
37 | "type": "rnn",
38 | "rnn_mode": "LSTM",
39 | "num_layers": 1,
40 | "use_attn": true,
41 | "dropout": 0.33,
42 | "dropword": 0.2
43 | },
44 | "decoder": {
45 | "type": "rnn",
46 | "rnn_mode": "LSTM",
47 | "num_layers": 1,
48 | "dropout": 0.33,
49 | "dropword": 0.2,
50 | "label_smoothing": 0.0
51 | }
52 | }
53 |
--------------------------------------------------------------------------------
/experiments/configs/iwslt14/config-transformer.json:
--------------------------------------------------------------------------------
1 | {
2 | "de_vocab_size": 32012,
3 | "en_vocab_size": 22825,
4 | "max_de_length": 64,
5 | "max_en_length": 64,
6 | "share_embed": false,
7 | "tie_weights": false,
8 | "embed_dim": 256,
9 | "latent_dim": 256,
10 | "hidden_size": 1024,
11 | "prior": {
12 | "type": "normal",
13 | "length_predictor": {
14 | "type": "diff_softmax",
15 | "diff_range": 16,
16 | "dropout": 0.33,
17 | "label_smoothing": 0.1
18 | },
19 | "flow": {
20 | "levels": 3,
21 | "num_steps": [4, 4, 2],
22 | "factors": [2, 2],
23 | "hidden_features": 512,
24 | "transform": "affine",
25 | "coupling_type": "self_attn",
26 | "heads": 4,
27 | "pos_enc": "attn",
28 | "max_length": 200,
29 | "dropout": 0.0,
30 | "inverse": true
31 | }
32 | },
33 | "encoder": {
34 | "type": "transformer",
35 | "num_layers": 5,
36 | "heads": 4,
37 | "max_length": 200,
38 | "dropout": 0.2
39 | },
40 | "posterior": {
41 | "type": "transformer",
42 | "num_layers": 3,
43 | "heads": 4,
44 | "max_length": 200,
45 | "dropout": 0.2,
46 | "dropword": 0.2
47 | },
48 | "decoder": {
49 | "type": "transformer",
50 | "num_layers": 3,
51 | "heads": 4,
52 | "max_length": 200,
53 | "dropout": 0.2,
54 | "dropword": 0.0,
55 | "label_smoothing": 0.1
56 | }
57 | }
58 |
--------------------------------------------------------------------------------
/experiments/configs/wmt14/config-rnn.json:
--------------------------------------------------------------------------------
1 | {
2 | "de_vocab_size": 37000,
3 | "en_vocab_size": 37000,
4 | "max_de_length": 80,
5 | "max_en_length": 80,
6 | "share_embed": true,
7 | "tie_weights": true,
8 | "embed_dim": 256,
9 | "latent_dim": 256,
10 | "hidden_size": 512,
11 | "prior": {
12 | "type": "normal",
13 | "length_predictor": {
14 | "type": "diff_softmax",
15 | "diff_range": 20,
16 | "dropout": 0.33,
17 | "label_smoothing": 0.1
18 | },
19 | "flow": {
20 | "levels": 3,
21 | "num_steps": [4, 4, 2],
22 | "factors": [2, 2],
23 | "hidden_features": 512,
24 | "transform": "affine",
25 | "coupling_type": "rnn",
26 | "rnn_mode": "LSTM",
27 | "dropout": 0.0,
28 | "inverse": true
29 | }
30 | },
31 | "encoder": {
32 | "type": "rnn",
33 | "rnn_mode": "LSTM",
34 | "num_layers": 3
35 | },
36 | "posterior": {
37 | "type": "rnn",
38 | "rnn_mode": "LSTM",
39 | "num_layers": 2,
40 | "use_attn": true,
41 | "dropout": 0.33,
42 | "dropword": 0.2
43 | },
44 | "decoder": {
45 | "type": "rnn",
46 | "rnn_mode": "LSTM",
47 | "num_layers": 2,
48 | "dropout": 0.33,
49 | "dropword": 0.2,
50 | "label_smoothing": 0.0
51 | }
52 | }
53 |
--------------------------------------------------------------------------------
/experiments/configs/wmt14/config-transformer-base.json:
--------------------------------------------------------------------------------
1 | {
2 | "de_vocab_size": 37000,
3 | "en_vocab_size": 37000,
4 | "max_de_length": 80,
5 | "max_en_length": 80,
6 | "share_embed": true,
7 | "tie_weights": true,
8 | "embed_dim": 256,
9 | "latent_dim": 256,
10 | "hidden_size": 1024,
11 | "prior": {
12 | "type": "normal",
13 | "length_predictor": {
14 | "type": "diff_softmax",
15 | "diff_range": 20,
16 | "dropout": 0.33,
17 | "label_smoothing": 0.1
18 | },
19 | "flow": {
20 | "levels": 3,
21 | "num_steps": [4, 4, 2],
22 | "factors": [2, 2],
23 | "hidden_features": 512,
24 | "transform": "affine",
25 | "coupling_type": "self_attn",
26 | "heads": 8,
27 | "pos_enc": "attn",
28 | "max_length": 250,
29 | "dropout": 0.0,
30 | "inverse": true
31 | }
32 | },
33 | "encoder": {
34 | "type": "transformer",
35 | "num_layers": 6,
36 | "heads": 8,
37 | "max_length": 250,
38 | "dropout": 0.1
39 | },
40 | "posterior": {
41 | "type": "transformer",
42 | "num_layers": 4,
43 | "heads": 8,
44 | "max_length": 250,
45 | "dropout": 0.1,
46 | "dropword": 0.2
47 | },
48 | "decoder": {
49 | "type": "transformer",
50 | "num_layers": 4,
51 | "heads": 8,
52 | "max_length": 250,
53 | "dropout": 0.1,
54 | "dropword": 0.0,
55 | "label_smoothing": 0.1
56 | }
57 | }
58 |
--------------------------------------------------------------------------------
/experiments/configs/wmt14/config-transformer-large.json:
--------------------------------------------------------------------------------
1 | {
2 | "de_vocab_size": 37000,
3 | "en_vocab_size": 37000,
4 | "max_de_length": 80,
5 | "max_en_length": 80,
6 | "share_embed": true,
7 | "tie_weights": true,
8 | "embed_dim": 512,
9 | "latent_dim": 512,
10 | "hidden_size": 1024,
11 | "prior": {
12 | "type": "normal",
13 | "length_predictor": {
14 | "type": "diff_softmax",
15 | "diff_range": 20,
16 | "dropout": 0.33,
17 | "label_smoothing": 0.1
18 | },
19 | "flow": {
20 | "levels": 3,
21 | "num_steps": [4, 4, 2],
22 | "factors": [2, 2],
23 | "hidden_features": 1024,
24 | "transform": "affine",
25 | "coupling_type": "self_attn",
26 | "heads": 8,
27 | "pos_enc": "attn",
28 | "max_length": 250,
29 | "dropout": 0.0,
30 | "inverse": true
31 | }
32 | },
33 | "encoder": {
34 | "type": "transformer",
35 | "num_layers": 6,
36 | "heads": 8,
37 | "max_length": 250,
38 | "dropout": 0.1
39 | },
40 | "posterior": {
41 | "type": "transformer",
42 | "num_layers": 4,
43 | "heads": 8,
44 | "max_length": 250,
45 | "dropout": 0.1,
46 | "dropword": 0.2
47 | },
48 | "decoder": {
49 | "type": "transformer",
50 | "num_layers": 4,
51 | "heads": 8,
52 | "max_length": 250,
53 | "dropout": 0.1,
54 | "dropword": 0.0,
55 | "label_smoothing": 0.1
56 | }
57 | }
58 |
--------------------------------------------------------------------------------
/experiments/configs/wmt16/config-rnn.json:
--------------------------------------------------------------------------------
1 | {
2 | "ro_vocab_size": 31500,
3 | "en_vocab_size": 31500,
4 | "max_ro_length": 80,
5 | "max_en_length": 80,
6 | "share_embed": true,
7 | "tie_weights": true,
8 | "embed_dim": 256,
9 | "latent_dim": 256,
10 | "hidden_size": 512,
11 | "prior": {
12 | "type": "normal",
13 | "length_predictor": {
14 | "type": "diff_softmax",
15 | "diff_range": 20,
16 | "dropout": 0.33,
17 | "label_smoothing": 0.1
18 | },
19 | "flow": {
20 | "levels": 3,
21 | "num_steps": [4, 4, 2],
22 | "factors": [2, 2],
23 | "hidden_features": 512,
24 | "transform": "affine",
25 | "coupling_type": "rnn",
26 | "rnn_mode": "LSTM",
27 | "dropout": 0.0,
28 | "inverse": true
29 | }
30 | },
31 | "encoder": {
32 | "type": "rnn",
33 | "rnn_mode": "LSTM",
34 | "num_layers": 3
35 | },
36 | "posterior": {
37 | "type": "rnn",
38 | "rnn_mode": "LSTM",
39 | "num_layers": 2,
40 | "use_attn": true,
41 | "dropout": 0.33,
42 | "dropword": 0.2
43 | },
44 | "decoder": {
45 | "type": "rnn",
46 | "rnn_mode": "LSTM",
47 | "num_layers": 2,
48 | "dropout": 0.33,
49 | "dropword": 0.2,
50 | "label_smoothing": 0.0
51 | }
52 | }
53 |
--------------------------------------------------------------------------------
/experiments/configs/wmt16/config-transformer-base.json:
--------------------------------------------------------------------------------
1 | {
2 | "ro_vocab_size": 31500,
3 | "en_vocab_size": 31500,
4 | "max_ro_length": 80,
5 | "max_en_length": 80,
6 | "share_embed": true,
7 | "tie_weights": true,
8 | "embed_dim": 256,
9 | "latent_dim": 256,
10 | "hidden_size": 1024,
11 | "prior": {
12 | "type": "normal",
13 | "length_predictor": {
14 | "type": "diff_softmax",
15 | "diff_range": 20,
16 | "dropout": 0.33,
17 | "label_smoothing": 0.1
18 | },
19 | "flow": {
20 | "levels": 3,
21 | "num_steps": [4, 4, 2],
22 | "factors": [2, 2],
23 | "hidden_features": 512,
24 | "transform": "affine",
25 | "coupling_type": "self_attn",
26 | "heads": 8,
27 | "pos_enc": "attn",
28 | "max_length": 250,
29 | "dropout": 0.0,
30 | "inverse": true
31 | }
32 | },
33 | "encoder": {
34 | "type": "transformer",
35 | "num_layers": 6,
36 | "heads": 8,
37 | "max_length": 250,
38 | "dropout": 0.1
39 | },
40 | "posterior": {
41 | "type": "transformer",
42 | "num_layers": 4,
43 | "heads": 8,
44 | "max_length": 250,
45 | "dropout": 0.1,
46 | "dropword": 0.2
47 | },
48 | "decoder": {
49 | "type": "transformer",
50 | "num_layers": 4,
51 | "heads": 8,
52 | "max_length": 250,
53 | "dropout": 0.1,
54 | "dropword": 0.0,
55 | "label_smoothing": 0.1
56 | }
57 | }
58 |
--------------------------------------------------------------------------------
/experiments/configs/wmt16/config-transformer-large.json:
--------------------------------------------------------------------------------
1 | {
2 | "ro_vocab_size": 31500,
3 | "en_vocab_size": 31500,
4 | "max_ro_length": 80,
5 | "max_en_length": 80,
6 | "share_embed": true,
7 | "tie_weights": true,
8 | "embed_dim": 512,
9 | "latent_dim": 512,
10 | "hidden_size": 1024,
11 | "prior": {
12 | "type": "normal",
13 | "length_predictor": {
14 | "type": "diff_softmax",
15 | "diff_range": 20,
16 | "dropout": 0.33,
17 | "label_smoothing": 0.1
18 | },
19 | "flow": {
20 | "levels": 3,
21 | "num_steps": [4, 4, 2],
22 | "factors": [2, 2],
23 | "hidden_features": 1024,
24 | "transform": "affine",
25 | "coupling_type": "self_attn",
26 | "heads": 8,
27 | "pos_enc": "attn",
28 | "max_length": 250,
29 | "dropout": 0.0,
30 | "inverse": true
31 | }
32 | },
33 | "encoder": {
34 | "type": "transformer",
35 | "num_layers": 6,
36 | "heads": 8,
37 | "max_length": 250,
38 | "dropout": 0.1
39 | },
40 | "posterior": {
41 | "type": "transformer",
42 | "num_layers": 4,
43 | "heads": 8,
44 | "max_length": 250,
45 | "dropout": 0.1,
46 | "dropword": 0.2
47 | },
48 | "decoder": {
49 | "type": "transformer",
50 | "num_layers": 4,
51 | "heads": 8,
52 | "max_length": 250,
53 | "dropout": 0.1,
54 | "dropword": 0.0,
55 | "label_smoothing": 0.1
56 | }
57 | }
58 |
--------------------------------------------------------------------------------
/experiments/distributed.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 |
4 | current_path = os.path.dirname(os.path.realpath(__file__))
5 | root_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
6 | sys.path.append(root_path)
7 |
8 | import json
9 | import signal
10 | import threading
11 | import torch
12 |
13 | from flownmt.data import NMTDataSet
14 | import experiments.options as options
15 | from experiments.nmt import main as single_process_main
16 |
17 |
18 | def create_dataset(args):
19 | model_path = args.model_path
20 | if not os.path.exists(model_path):
21 | os.makedirs(model_path)
22 |
23 | result_path = os.path.join(model_path, 'translations')
24 | if not os.path.exists(result_path):
25 | os.makedirs(result_path)
26 |
27 | vocab_path = os.path.join(model_path, 'vocab')
28 | if not os.path.exists(vocab_path):
29 | os.makedirs(vocab_path)
30 |
31 | data_path = args.data_path
32 | src_lang = args.src
33 | tgt_lang = args.tgt
34 | src_vocab_path = os.path.join(vocab_path, '{}.vocab'.format(src_lang))
35 | tgt_vocab_path = os.path.join(vocab_path, '{}.vocab'.format(tgt_lang))
36 |
37 | params = json.load(open(args.config, 'r'))
38 |
39 | src_max_vocab = params['{}_vocab_size'.format(src_lang)]
40 | tgt_max_vocab = params['{}_vocab_size'.format(tgt_lang)]
41 |
42 | NMTDataSet(data_path, src_lang, tgt_lang, src_vocab_path, tgt_vocab_path, src_max_vocab, tgt_max_vocab,
43 | subword=args.subword, create_vocab=True)
44 |
45 |
46 | def main():
47 | args = options.parse_distributed_args()
48 | args_dict = vars(args)
49 |
50 | nproc_per_node = args_dict.pop('nproc_per_node')
51 | nnodes = args_dict.pop('nnodes')
52 | node_rank = args_dict.pop('node_rank')
53 |
54 | # world size in terms of number of processes
55 | dist_world_size = nproc_per_node * nnodes
56 |
57 | # set PyTorch distributed related environmental variables
58 | current_env = os.environ
59 | current_env["MASTER_ADDR"] = args_dict.pop('master_addr')
60 | current_env["MASTER_PORT"] = str(args_dict.pop('master_port'))
61 | current_env["WORLD_SIZE"] = str(dist_world_size)
62 |
63 | create_vocab = args_dict.pop('create_vocab')
64 | if create_vocab:
65 | create_dataset(args)
66 | args.create_vocab = False
67 |
68 | batch_size = args.batch_size // dist_world_size
69 | args.batch_size = batch_size
70 |
71 | mp = torch.multiprocessing.get_context('spawn')
72 | # Create a thread to listen for errors in the child processes.
73 | error_queue = mp.SimpleQueue()
74 | error_handler = ErrorHandler(error_queue)
75 |
76 | processes = []
77 |
78 | for local_rank in range(0, nproc_per_node):
79 | # each process's rank
80 | dist_rank = nproc_per_node * node_rank + local_rank
81 | args.rank = dist_rank
82 | args.local_rank = local_rank
83 | process = mp.Process(target=run, args=(args, error_queue, ), daemon=True)
84 | process.start()
85 | error_handler.add_child(process.pid)
86 | processes.append(process)
87 |
88 | for process in processes:
89 | process.join()
90 |
91 |
92 | def run(args, error_queue):
93 | try:
94 | single_process_main(args)
95 | except KeyboardInterrupt:
96 | pass # killed by parent, do nothing
97 | except Exception:
98 | # propagate exception to parent process, keeping original traceback
99 | import traceback
100 | error_queue.put((args.rank, traceback.format_exc()))
101 |
102 |
103 | class ErrorHandler(object):
104 | """A class that listens for exceptions in children processes and propagates
105 | the tracebacks to the parent process."""
106 |
107 | def __init__(self, error_queue):
108 | self.error_queue = error_queue
109 | self.children_pids = []
110 | self.error_thread = threading.Thread(target=self.error_listener, daemon=True)
111 | self.error_thread.start()
112 | signal.signal(signal.SIGUSR1, self.signal_handler)
113 |
114 | def add_child(self, pid):
115 | self.children_pids.append(pid)
116 |
117 | def error_listener(self):
118 | (rank, original_trace) = self.error_queue.get()
119 | self.error_queue.put((rank, original_trace))
120 | os.kill(os.getpid(), signal.SIGUSR1)
121 |
122 | def signal_handler(self, signalnum, stackframe):
123 | for pid in self.children_pids:
124 | os.kill(pid, signal.SIGINT) # kill children processes
125 | (rank, original_trace) = self.error_queue.get()
126 | msg = "\n\n-- Tracebacks above this line can probably be ignored --\n\n"
127 | msg += original_trace
128 | raise Exception(msg)
129 |
130 |
131 | if __name__ == "__main__":
132 | main()
133 |
--------------------------------------------------------------------------------
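As a quick illustration of what `experiments/distributed.py` above sets up, here is a sketch (not part of the repository) of the environment variables it exports and the rank arithmetic in `main()`, for a hypothetical 2-node x 4-process topology:

```python
import os

nnodes, nproc_per_node, node_rank = 2, 4, 1        # hypothetical cluster topology (node 1 shown)
dist_world_size = nnodes * nproc_per_node          # exported as WORLD_SIZE

# Environment variables the launcher exports for torch.distributed.
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "29500"
os.environ["WORLD_SIZE"] = str(dist_world_size)

for local_rank in range(nproc_per_node):
    dist_rank = nproc_per_node * node_rank + local_rank   # global rank, as computed in main()
    print(local_rank, dist_rank)                          # -> (0, 4) (1, 5) (2, 6) (3, 7)
```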
/experiments/options.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | from argparse import ArgumentParser
3 |
4 |
5 | def parse_args():
6 | parser = ArgumentParser(description='FlowNMT')
7 | parser.add_argument('--rank', type=int, default=-1, metavar='N', help='rank of the process in all distributed processes')
8 | parser.add_argument("--local_rank", type=int, default=0, metavar='N', help='rank of the process in the machine')
9 | parser.add_argument('--config', type=str, help='config file', required=True)
10 | parser.add_argument('--batch_size', type=int, default=512, metavar='N',
11 | help='input batch size for training (default: 512)')
12 | parser.add_argument('--eval_batch_size', type=int, default=4, metavar='N',
13 | help='input batch size for eval (default: 4)')
14 | parser.add_argument('--batch_steps', type=int, default=1, metavar='N',
15 | help='number of steps for each batch (the batch size of each step is batch-size / steps) (default: 1)')
16 | parser.add_argument('--init_batch_size', type=int, default=1024, metavar='N',
17 | help='number of instances for model initialization (default: 1024)')
18 | parser.add_argument('--epochs', type=int, default=500, metavar='N', help='number of epochs to train')
19 | parser.add_argument('--kl_warmup_steps', type=int, default=10000, metavar='N', help='number of steps to warm up KL weight (default: 10000)')
20 | parser.add_argument('--init_steps', type=int, default=5000, metavar='N', help='number of steps to train decoder (default: 5000)')
21 | parser.add_argument('--seed', type=int, default=65537, metavar='S', help='random seed (default: 65537)')
22 | parser.add_argument('--loss_type', choices=['sentence', 'token'], default='sentence',
23 | help='loss type (default: sentence)')
24 | parser.add_argument('--train_k', type=int, default=1, metavar='N', help='training K (default: 1)')
25 | parser.add_argument('--log_interval', type=int, default=10, metavar='N', help='how many batches to wait before logging training status')
26 | parser.add_argument('--lr_decay', choices=['inv_sqrt', 'expo'], help='lr decay method', default='inv_sqrt')
27 | parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
28 | parser.add_argument('--beta1', type=float, default=0.9, help='beta1 of Adam')
29 | parser.add_argument('--beta2', type=float, default=0.999, help='beta2 of Adam')
30 | parser.add_argument('--eps', type=float, default=1e-6, help='eps of Adam')
31 | parser.add_argument('--weight_decay', type=float, default=0.0, help='weight for l2 norm decay')
32 | parser.add_argument('--amsgrad', action='store_true', help='AMS Grad')
33 | parser.add_argument('--grad_clip', type=float, default=0, help='max norm for gradient clip (default 0: no clip)')
34 | parser.add_argument('--model_path', help='path for saving model file.', required=True)
35 | parser.add_argument('--data_path', help='path for data file.', default=None)
36 |
37 | parser.add_argument('--src', type=str, help='source language code', required=True)
38 | parser.add_argument('--tgt', type=str, help='target language code', required=True)
39 | parser.add_argument('--create_vocab', action='store_true', help='create vocabulary.')
40 | parser.add_argument('--share_all_embeddings', action='store_true', help='share source, target and output embeddings')
41 | parser.add_argument("--subword", type=str, default="joint-bpe", choices=['joint-bpe', 'sep-bpe', 'word', 'bert-bpe', 'joint-spm'])
42 | parser.add_argument('--recover', type=int, default=-1, help='recover the model from disk.')
43 | parser.add_argument("--bucket_batch", type=int, default=0, help="whether bucket data based on tgt length in batching")
44 |
45 | return parser.parse_args()
46 |
47 |
48 | def parse_translate_args():
49 | parser = ArgumentParser(description='FlowNMT')
50 | parser.add_argument('--batch_size', type=int, default=512, metavar='N', help='input batch size for training (default: 512)')
51 | parser.add_argument('--seed', type=int, default=524287, metavar='S', help='random seed (default: 524287)')
52 | parser.add_argument('--model_path', help='path for saving model file.', required=True)
53 | parser.add_argument('--data_path', help='path for data file.', default=None)
54 | parser.add_argument("--subword", type=str, default="joint-bpe", choices=['joint-bpe', 'sep-bpe', 'word', 'bert-bpe', 'joint-spm'])
55 | parser.add_argument("--bucket_batch", type=int, default=0, help="whether bucket data based on tgt length in batching")
56 | parser.add_argument('--decode', choices=['argmax', 'iw', 'sample'], help='decoding algorithm', default='argmax')
57 | parser.add_argument('--tau', type=float, default=0.0, metavar='S', help='temperature for iw decoding (default: 0.)')
58 | parser.add_argument('--nlen', type=int, default=3, help='number of length candidates.')
59 | parser.add_argument('--ntr', type=int, default=1, help='number of samples per length candidate.')
60 | return parser.parse_args()
61 |
62 |
63 | def parse_distributed_args():
64 | """
65 | Helper function parsing the command line options
66 | @retval ArgumentParser
67 | """
68 | parser = ArgumentParser(description="Dist FlowNMT")
69 |
70 | # Optional arguments for the launch helper
71 | parser.add_argument("--nnodes", type=int, default=1,
72 | help="The number of nodes to use for distributed "
73 | "training")
74 | parser.add_argument("--node_rank", type=int, default=0,
75 | help="The rank of the node for multi-node distributed "
76 | "training")
77 | parser.add_argument("--nproc_per_node", type=int, default=1,
78 | help="The number of processes to launch on each node, "
79 | "for GPU training, this is recommended to be set "
80 | "to the number of GPUs in your system so that "
81 | "each process can be bound to a single GPU.")
82 | parser.add_argument("--master_addr", default="127.0.0.1", type=str,
83 | help="Master node (rank 0)'s address, should be either "
84 | "the IP address or the hostname of node 0, for "
85 | "single node multi-proc training, the "
86 | "--master_addr can simply be 127.0.0.1")
87 | parser.add_argument("--master_port", default=29500, type=int,
88 | help="Master node (rank 0)'s free port that needs to "
89 | "be used for communciation during distributed "
90 | "training")
91 |
92 | # arguments for flownmt model
93 | parser.add_argument('--config', type=str, help='config file', required=True)
94 | parser.add_argument('--batch_size', type=int, default=512, metavar='N',
95 | help='input batch size for training (default: 512)')
96 | parser.add_argument('--eval_batch_size', type=int, default=4, metavar='N',
97 | help='input batch size for eval (default: 4)')
98 | parser.add_argument('--init_batch_size', type=int, default=1024, metavar='N',
99 | help='number of instances for model initialization (default: 1024)')
100 | parser.add_argument('--batch_steps', type=int, default=1, metavar='N',
101 | help='number of steps for each batch (the batch size of each step is batch-size / steps) (default: 1)')
102 | parser.add_argument('--epochs', type=int, default=500, metavar='N', help='number of epochs to train')
103 | parser.add_argument('--kl_warmup_steps', type=int, default=10000, metavar='N',
104 | help='number of steps to warm up KL weight (default: 10000)')
105 | parser.add_argument('--init_steps', type=int, default=5000, metavar='N',
106 | help='number of steps to train decoder (default: 5000)')
107 | parser.add_argument('--seed', type=int, default=65537, metavar='S', help='random seed (default: 65537)')
108 | parser.add_argument('--loss_type', choices=['sentence', 'token'], default='sentence',
109 | help='loss type (default: sentence)')
110 | parser.add_argument('--train_k', type=int, default=1, metavar='N', help='training K (default: 1)')
111 | parser.add_argument('--log_interval', type=int, default=10, metavar='N',
112 | help='how many batches to wait before logging training status')
113 | parser.add_argument('--lr_decay', choices=['inv_sqrt', 'expo'], help='lr decay method', default='inv_sqrt')
114 | parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
115 | parser.add_argument('--beta1', type=float, default=0.9, help='beta1 of Adam')
116 | parser.add_argument('--beta2', type=float, default=0.999, help='beta2 of Adam')
117 | parser.add_argument('--eps', type=float, default=1e-6, help='eps of Adam')
118 | parser.add_argument('--weight_decay', type=float, default=0.0, help='weight for l2 norm decay')
119 | parser.add_argument('--amsgrad', action='store_true', help='AMS Grad')
120 | parser.add_argument('--grad_clip', type=float, default=0, help='max norm for gradient clip (default 0: no clip)')
121 | parser.add_argument('--model_path', help='path for saving model file.', required=True)
122 | parser.add_argument('--data_path', help='path for data file.', default=None)
123 |
124 | parser.add_argument('--src', type=str, help='source language code', required=True)
125 | parser.add_argument('--tgt', type=str, help='target language code', required=True)
126 | parser.add_argument('--create_vocab', action='store_true', help='create vocabulary.')
127 | parser.add_argument('--share_all_embeddings', action='store_true', help='share source, target and output embeddings')
128 | parser.add_argument("--subword", type=str, default="joint-bpe",
129 | choices=['joint-bpe', 'sep-bpe', 'word', 'bert-bpe'])
130 | parser.add_argument("--bucket_batch", type=int, default=0,
131 | help="whether bucket data based on tgt length in batching")
132 | parser.add_argument('--recover', type=int, default=-1, help='recover the model from disk.')
133 |
134 | return parser.parse_args()
135 |
--------------------------------------------------------------------------------
/experiments/scripts/multi-bleu.perl:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env perl
2 | #
3 | # This file is part of moses. Its use is licensed under the GNU Lesser General
4 | # Public License version 2.1 or, at your option, any later version.
5 |
6 | # $Id$
7 | use warnings;
8 | use strict;
9 |
10 | my $lowercase = 0;
11 | if ($ARGV[0] eq "-lc") {
12 | $lowercase = 1;
13 | shift;
14 | }
15 |
16 | my $stem = $ARGV[0];
17 | if (!defined $stem) {
18 | print STDERR "usage: multi-bleu.pl [-lc] reference < hypothesis\n";
19 | print STDERR "Reads the references from reference or reference0, reference1, ...\n";
20 | exit(1);
21 | }
22 |
23 | $stem .= ".ref" if !-e $stem && !-e $stem."0" && -e $stem.".ref0";
24 |
25 | my @REF;
26 | my $ref=0;
27 | while(-e "$stem$ref") {
28 | &add_to_ref("$stem$ref",\@REF);
29 | $ref++;
30 | }
31 | &add_to_ref($stem,\@REF) if -e $stem;
32 | die("ERROR: could not find reference file $stem") unless scalar @REF;
33 |
34 | # add additional references explicitly specified on the command line
35 | shift;
36 | foreach my $stem (@ARGV) {
37 | &add_to_ref($stem,\@REF) if -e $stem;
38 | }
39 |
40 |
41 |
42 | sub add_to_ref {
43 | my ($file,$REF) = @_;
44 | my $s=0;
45 | if ($file =~ /.gz$/) {
46 | open(REF,"gzip -dc $file|") or die "Can't read $file";
47 | } else {
48 | open(REF,$file) or die "Can't read $file";
49 | }
50 | while(<REF>) {
51 | chop;
52 | push @{$$REF[$s++]}, $_;
53 | }
54 | close(REF);
55 | }
56 |
57 | my(@CORRECT,@TOTAL,$length_translation,$length_reference);
58 | my $s=0;
59 | while(<STDIN>) {
60 | chop;
61 | $_ = lc if $lowercase;
62 | my @WORD = split;
63 | my %REF_NGRAM = ();
64 | my $length_translation_this_sentence = scalar(@WORD);
65 | my ($closest_diff,$closest_length) = (9999,9999);
66 | foreach my $reference (@{$REF[$s]}) {
67 | # print "$s $_ <=> $reference\n";
68 | $reference = lc($reference) if $lowercase;
69 | my @WORD = split(' ',$reference);
70 | my $length = scalar(@WORD);
71 | my $diff = abs($length_translation_this_sentence-$length);
72 | if ($diff < $closest_diff) {
73 | $closest_diff = $diff;
74 | $closest_length = $length;
75 | # print STDERR "$s: closest diff ".abs($length_translation_this_sentence-$length)." = abs($length_translation_this_sentence-$length), setting len: $closest_length\n";
76 | } elsif ($diff == $closest_diff) {
77 | $closest_length = $length if $length < $closest_length;
78 | # from two references with the same closeness to me
79 | # take the *shorter* into account, not the "first" one.
80 | }
81 | for(my $n=1;$n<=4;$n++) {
82 | my %REF_NGRAM_N = ();
83 | for(my $start=0;$start<=$#WORD-($n-1);$start++) {
84 | my $ngram = "$n";
85 | for(my $w=0;$w<$n;$w++) {
86 | $ngram .= " ".$WORD[$start+$w];
87 | }
88 | $REF_NGRAM_N{$ngram}++;
89 | }
90 | foreach my $ngram (keys %REF_NGRAM_N) {
91 | if (!defined($REF_NGRAM{$ngram}) ||
92 | $REF_NGRAM{$ngram} < $REF_NGRAM_N{$ngram}) {
93 | $REF_NGRAM{$ngram} = $REF_NGRAM_N{$ngram};
94 | # print "$i: REF_NGRAM{$ngram} = $REF_NGRAM{$ngram}<BR>\n";
95 | }
96 | }
97 | }
98 | }
99 | $length_translation += $length_translation_this_sentence;
100 | $length_reference += $closest_length;
101 | for(my $n=1;$n<=4;$n++) {
102 | my %T_NGRAM = ();
103 | for(my $start=0;$start<=$#WORD-($n-1);$start++) {
104 | my $ngram = "$n";
105 | for(my $w=0;$w<$n;$w++) {
106 | $ngram .= " ".$WORD[$start+$w];
107 | }
108 | $T_NGRAM{$ngram}++;
109 | }
110 | foreach my $ngram (keys %T_NGRAM) {
111 | $ngram =~ /^(\d+) /;
112 | my $n = $1;
113 | # my $corr = 0;
114 | # print "$i e $ngram $T_NGRAM{$ngram} <BR>\n";
115 | $TOTAL[$n] += $T_NGRAM{$ngram};
116 | if (defined($REF_NGRAM{$ngram})) {
117 | if ($REF_NGRAM{$ngram} >= $T_NGRAM{$ngram}) {
118 | $CORRECT[$n] += $T_NGRAM{$ngram};
119 | # $corr = $T_NGRAM{$ngram};
120 | # print "$i e correct1 $T_NGRAM{$ngram} <BR>\n";
121 | }
122 | else {
123 | $CORRECT[$n] += $REF_NGRAM{$ngram};
124 | # $corr = $REF_NGRAM{$ngram};
125 | # print "$i e correct2 $REF_NGRAM{$ngram} <BR>\n";
126 | }
127 | }
128 | # $REF_NGRAM{$ngram} = 0 if !defined $REF_NGRAM{$ngram};
129 | # print STDERR "$ngram: {$s, $REF_NGRAM{$ngram}, $T_NGRAM{$ngram}, $corr}\n"
130 | }
131 | }
132 | $s++;
133 | }
134 | my $brevity_penalty = 1;
135 | my $bleu = 0;
136 |
137 | my @bleu=();
138 |
139 | for(my $n=1;$n<=4;$n++) {
140 | if (defined ($TOTAL[$n]) and defined ($CORRECT[$n])){
141 | $bleu[$n]=($TOTAL[$n])?$CORRECT[$n]/$TOTAL[$n]:0;
142 | # print STDERR "CORRECT[$n]:$CORRECT[$n] TOTAL[$n]:$TOTAL[$n]\n";
143 | }else{
144 | $bleu[$n]=0;
145 | }
146 | }
147 |
148 | if ($length_reference==0){
149 | printf "BLEU = 0, 0/0/0/0 (BP=0, ratio=0, hyp_len=0, ref_len=0)\n";
150 | exit(1);
151 | }
152 |
153 | if ($length_translation<$length_reference) {
154 | $brevity_penalty = exp(1-$length_reference/$length_translation);
155 | }
156 | $bleu = $brevity_penalty * exp((my_log( $bleu[1] ) +
157 | my_log( $bleu[2] ) +
158 | my_log( $bleu[3] ) +
159 | my_log( $bleu[4] ) ) / 4) ;
160 | printf "BLEU = %.2f, %.1f/%.1f/%.1f/%.1f (BP=%.3f, ratio=%.3f, hyp_len=%d, ref_len=%d)\n",
161 | 100*$bleu,
162 | 100*$bleu[1],
163 | 100*$bleu[2],
164 | 100*$bleu[3],
165 | 100*$bleu[4],
166 | $brevity_penalty,
167 | $length_translation / $length_reference,
168 | $length_translation,
169 | $length_reference;
170 |
171 | sub my_log {
172 | return -9999999999 unless $_[0];
173 | return log($_[0]);
174 | }
--------------------------------------------------------------------------------
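In short, `multi-bleu.perl` above computes corpus-level BLEU as the brevity penalty times the geometric mean of the modified 1- to 4-gram precisions. A minimal Python restatement of that final step is given below; the correct/total counts and lengths are illustrative placeholders, not real results.

```python
import math

def bleu(correct, total, hyp_len, ref_len):
    # Modified n-gram precisions, n = 1..4 (zero if no n-grams were counted).
    precisions = [(c / t) if t else 0.0 for c, t in zip(correct, total)]
    # my_log() in the Perl script returns a huge negative number for zero precision.
    log_sum = sum(math.log(p) if p > 0 else -9999999999 for p in precisions)
    # Brevity penalty: 1 unless the hypothesis is shorter than the reference.
    bp = 1.0 if hyp_len >= ref_len else math.exp(1.0 - ref_len / hyp_len)
    return bp * math.exp(log_sum / 4)

print(bleu(correct=[600, 350, 210, 130], total=[1000, 980, 960, 940], hyp_len=1000, ref_len=1050))
```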
/experiments/slurm.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 |
4 | current_path = os.path.dirname(os.path.realpath(__file__))
5 | root_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
6 | sys.path.append(root_path)
7 |
8 | import torch.multiprocessing as mp
9 |
10 | import experiments.options as options
11 | from experiments.nmt import main as single_process_main
12 |
13 |
14 | def main():
15 | args = options.parse_distributed_args()
16 | args_dict = vars(args)
17 |
18 | args_dict.pop('master_addr')
19 | str(args_dict.pop('master_port'))
20 | args_dict.pop('nnodes')
21 | args_dict.pop('nproc_per_node')
22 | args_dict.pop('node_rank')
23 |
24 | current_env = os.environ
25 | nnodes = int(current_env['SLURM_NNODES'])
26 | dist_world_size = int(current_env['SLURM_NTASKS'])
27 | args.rank = int(current_env['SLURM_PROCID'])
28 | args.local_rank = int(current_env['SLURM_LOCALID'])
29 |
30 |
31 | print('start process: rank={}({}), master addr={}, port={}, nnodes={}, world size={}'.format(
32 | args.rank, args.local_rank, current_env["MASTER_ADDR"], current_env["MASTER_PORT"], nnodes, dist_world_size))
33 | current_env["WORLD_SIZE"] = str(dist_world_size)
34 |
35 | create_vocab = args_dict.pop('create_vocab')
36 | assert not create_vocab
37 | args.create_vocab = False
38 |
39 | batch_size = args.batch_size // dist_world_size
40 | args.batch_size = batch_size
41 |
42 | single_process_main(args)
43 |
44 |
45 | if __name__ == "__main__":
46 | mp.set_start_method('forkserver')
47 | main()
48 |
--------------------------------------------------------------------------------
/experiments/translate.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | current_path = os.path.dirname(os.path.realpath(__file__))
5 | root_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
6 | sys.path.append(root_path)
7 |
8 | import time
9 | import json
10 | import random
11 | import numpy as np
12 |
13 | import torch
14 |
15 | from flownmt.data import NMTDataSet, DataIterator
16 | from flownmt import FlowNMT
17 | from experiments.options import parse_translate_args
18 |
19 |
20 | def calc_bleu(fref, fmt, result_path):
21 | script = os.path.join(current_path, 'scripts/multi-bleu.perl')
22 | temp = os.path.join(result_path, 'tmp')
23 | os.system("perl %s %s < %s > %s" % (script, fref, fmt, temp))
24 | bleu = open(temp, 'r').read().strip()
25 | bleu = bleu.split(",")[0].split("=")
26 | if len(bleu) < 2:
27 | return 0.0
28 | bleu = float(bleu[1].strip())
29 | return bleu
30 |
31 |
32 | def translate_argmax(dataset, dataloader, flownmt, result_path, outfile, tau, n_tr):
33 | flownmt.eval()
34 | translations = []
35 | lengths = []
36 | length_err = 0
37 | num_insts = 0
38 | start_time = time.time()
39 | num_back = 0
40 | for step, (src, tgt, src_masks, tgt_masks) in enumerate(dataloader):
41 | trans, lens = flownmt.translate_argmax(src, src_masks, n_tr=n_tr, tau=tau)
42 | translations.append(trans)
43 | lengths.append(lens)
44 | length_err += (lens.float() - tgt_masks.sum(dim=1)).abs().sum().item()
45 | num_insts += src.size(0)
46 | if step % 10 == 0:
47 | sys.stdout.write("\b" * num_back)
48 | sys.stdout.write(" " * num_back)
49 | sys.stdout.write("\b" * num_back)
50 | log_info = 'argmax translating (tau={:.1f}, n_tr={})...{}'.format(tau, n_tr, num_insts)
51 | sys.stdout.write(log_info)
52 | sys.stdout.flush()
53 | num_back = len(log_info)
54 | print('time: {:.1f}s'.format(time.time() - start_time))
55 | outfile = os.path.join(result_path, outfile)
56 | dataset.dump_to_file(translations, lengths, outfile)
57 | bleu = calc_bleu(dataloader.tgt_sort_origin_path, outfile, result_path)
58 | print('#SENT: {}, Length Err: {:.1f}, BLEU: {:.2f}'.format(num_insts, length_err / num_insts, bleu))
59 |
60 |
61 | def translate_iw(dataset, dataloader, flownmt, result_path, outfile, tau, n_len, n_tr):
62 | flownmt.eval()
63 | iwk = 4
64 | translations = []
65 | lengths = []
66 | length_err = 0
67 | num_insts = 0
68 | start_time = time.time()
69 | num_back = 0
70 | for step, (src, tgt, src_masks, tgt_masks) in enumerate(dataloader):
71 | trans, lens = flownmt.translate_iw(src, src_masks, n_len=n_len, n_tr=n_tr, tau=tau, k=iwk)
72 | translations.append(trans)
73 | lengths.append(lens)
74 | length_err += (lens.float() - tgt_masks.sum(dim=1)).abs().sum().item()
75 | num_insts += src.size(0)
76 | if step % 10 == 0:
77 | sys.stdout.write("\b" * num_back)
78 | sys.stdout.write(" " * num_back)
79 | sys.stdout.write("\b" * num_back)
80 | log_info = 'importance weighted translating (tau={:.1f}, n_len={}, n_tr={})...{}'.format(tau, n_len, n_tr, num_insts)
81 | sys.stdout.write(log_info)
82 | sys.stdout.flush()
83 | num_back = len(log_info)
84 | print('time: {:.1f}s'.format(time.time() - start_time))
85 | outfile = os.path.join(result_path, outfile)
86 | dataset.dump_to_file(translations, lengths, outfile)
87 | bleu = calc_bleu(dataloader.tgt_sort_origin_path, outfile, result_path)
88 | print('#SENT: {}, Length Err: {:.1f}, BLEU: {:.2f}'.format(num_insts, length_err / num_insts, bleu))
89 |
90 |
91 | def sample(dataset, dataloader, flownmt, result_path, outfile, tau, n_len, n_tr):
92 | flownmt.eval()
93 | lengths = []
94 | translations = []
95 | num_insts = 0
96 | start_time = time.time()
97 | num_back = 0
98 | for step, (src, tgt, src_masks, tgt_masks) in enumerate(dataloader):
99 | trans, lens = flownmt.translate_sample(src, src_masks, n_len=n_len, n_tr=n_tr, tau=tau)
100 | translations.append(trans)
101 | lengths.append(lens)
102 | num_insts += src.size(0)
103 | if step % 10 == 0:
104 | sys.stdout.write("\b" * num_back)
105 | sys.stdout.write(" " * num_back)
106 | sys.stdout.write("\b" * num_back)
107 | log_info = 'sampling (tau={:.1f}, n_len={}, n_tr={})...{}'.format(tau, n_len, n_tr, num_insts)
108 | sys.stdout.write(log_info)
109 | sys.stdout.flush()
110 | num_back = len(log_info)
111 | print('time: {:.1f}s'.format(time.time() - start_time))
112 | outfile = os.path.join(result_path, outfile)
113 | dataset.dump_to_file(translations, lengths, outfile, post_edit=False)
114 |
115 |
116 | def setup(args):
117 | args.cuda = torch.cuda.is_available()
118 | random_seed = args.seed
119 | random.seed(random_seed)
120 | np.random.seed(random_seed)
121 | torch.manual_seed(random_seed)
122 | device = torch.device('cuda', 0) if args.cuda else torch.device('cpu')
123 | if args.cuda:
124 | torch.cuda.set_device(device)
125 | torch.cuda.manual_seed(random_seed)
126 |
127 | torch.backends.cudnn.benchmark = False
128 |
129 | model_path = args.model_path
130 | result_path = os.path.join(model_path, 'translations')
131 | args.result_path = result_path
132 | params = json.load(open(os.path.join(model_path, 'config.json'), 'r'))
133 |
134 | src_lang = params['src']
135 | tgt_lang = params['tgt']
136 | data_path = args.data_path
137 | vocab_path = os.path.join(model_path, 'vocab')
138 | src_vocab_path = os.path.join(vocab_path, '{}.vocab'.format(src_lang))
139 | tgt_vocab_path = os.path.join(vocab_path, '{}.vocab'.format(tgt_lang))
140 | src_vocab_size = params['src_vocab_size']
141 | tgt_vocab_size = params['tgt_vocab_size']
142 | args.max_src_length = params.pop('max_src_length')
143 | args.max_tgt_length = params.pop('max_tgt_length')
144 | dataset = NMTDataSet(data_path, src_lang, tgt_lang,
145 | src_vocab_path, tgt_vocab_path,
146 | src_vocab_size, tgt_vocab_size,
147 | subword=args.subword, create_vocab=False)
148 | assert src_vocab_size == dataset.src_vocab_size
149 | assert tgt_vocab_size == dataset.tgt_vocab_size
150 |
151 | flownmt = FlowNMT.load(model_path, device=device)
152 | args.length_unit = flownmt.length_unit
153 | args.device = device
154 | return args, dataset, flownmt
155 |
156 |
157 | def init_dataloader(args, dataset):
158 | eval_batch = args.batch_size
159 | val_iter = DataIterator(dataset, eval_batch, 0, args.max_src_length, args.max_tgt_length, 1000, args.device, args.result_path,
160 | bucket_data=args.bucket_batch, multi_scale=args.length_unit, corpus="dev")
161 | test_iter = DataIterator(dataset, eval_batch, 0, args.max_src_length, args.max_tgt_length, 1000, args.device, args.result_path,
162 | bucket_data=args.bucket_batch, multi_scale=args.length_unit, corpus="test")
163 | return val_iter, test_iter
164 |
165 |
166 | def main(args):
167 | args, dataset, flownmt = setup(args)
168 | print(args)
169 |
170 | val_iter, test_iter = init_dataloader(args, dataset)
171 |
172 | result_path = args.result_path
173 | if args.decode == 'argmax':
174 | tau = args.tau
175 | n_tr = args.ntr
176 | outfile = 'argmax.t{:.1f}.ntr{}.dev.mt'.format(tau, n_tr)
177 | translate_argmax(dataset, val_iter, flownmt, result_path, outfile, tau, n_tr)
178 | outfile = 'argmax.t{:.1f}.ntr{}.test.mt'.format(tau, n_tr)
179 | translate_argmax(dataset, test_iter, flownmt, result_path, outfile, tau, n_tr)
180 | elif args.decode == 'iw':
181 | tau = args.tau
182 | n_len = args.nlen
183 | n_tr = args.ntr
184 | outfile = 'iw.t{:.1f}.nlen{}.ntr{}.dev.mt'.format(tau, n_len, n_tr)
185 | translate_iw(dataset, val_iter, flownmt, result_path, outfile, tau, n_len, n_tr)
186 | outfile = 'iw.t{:.1f}.nlen{}.ntr{}.test.mt'.format(tau, n_len, n_tr)
187 | translate_iw(dataset, test_iter, flownmt, result_path, outfile, tau, n_len, n_tr)
188 | else:
189 | assert not args.bucket_batch
190 | tau = args.tau
191 | n_len = args.nlen
192 | n_tr = args.ntr
193 | outfile = 'sample.t{:.1f}.nlen{}.ntr{}.dev.mt'.format(tau, n_len, n_tr)
194 | sample(dataset, val_iter, flownmt, result_path, outfile, tau, n_len, n_tr)
195 | outfile = 'sample.t{:.1f}.nlen{}.ntr{}.test.mt'.format(tau, n_len, n_tr)
196 | sample(dataset, test_iter, flownmt, result_path, outfile, tau, n_len, n_tr)
197 |
198 |
199 | if __name__ == "__main__":
200 | args = parse_translate_args()
201 | with torch.no_grad():
202 | main(args)
--------------------------------------------------------------------------------
/flownmt/__init__.py:
--------------------------------------------------------------------------------
1 | from flownmt.flownmt import FlowNMT
2 |
3 |
--------------------------------------------------------------------------------
/flownmt/data/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = 'violet-zct'
2 |
3 | from flownmt.data.dataloader import NMTDataSet, DataIterator
4 |
--------------------------------------------------------------------------------
/flownmt/flows/__init__.py:
--------------------------------------------------------------------------------
1 | from flownmt.flows.flow import Flow
2 | from flownmt.flows.actnorm import ActNormFlow
3 | from flownmt.flows.parallel import *
4 | from flownmt.flows.linear import InvertibleMultiHeadFlow, InvertibleLinearFlow
5 | from flownmt.flows.couplings import *
6 | from flownmt.flows.nmt import NMTFlow
7 |
--------------------------------------------------------------------------------
/flownmt/flows/actnorm.py:
--------------------------------------------------------------------------------
1 | from overrides import overrides
2 | from typing import Dict, Tuple
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | from torch.nn import Parameter
7 |
8 | from flownmt.flows.flow import Flow
9 |
10 |
11 | class ActNormFlow(Flow):
12 | def __init__(self, in_features, inverse=False):
13 | super(ActNormFlow, self).__init__(inverse)
14 | self.in_features = in_features
15 | self.log_scale = Parameter(torch.Tensor(in_features))
16 | self.bias = Parameter(torch.Tensor(in_features))
17 | self.reset_parameters()
18 |
19 | def reset_parameters(self):
20 | nn.init.normal_(self.log_scale, mean=0, std=0.05)
21 | nn.init.constant_(self.bias, 0.)
22 |
23 | @overrides
24 | def forward(self, input: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
25 | """
26 |
27 | Args:
28 | input: Tensor
29 | input tensor [batch, N1, N2, ..., Nl, in_features]
30 | mask: Tensor
31 | mask tensor [batch, N1, N2, ...,Nl]
32 |
33 | Returns: out: Tensor , logdet: Tensor
34 | out: [batch, N1, N2, ..., in_features], the output of the flow
35 | logdet: [batch], the log determinant of :math:`\partial output / \partial input`
36 |
37 | """
38 | dim = input.dim()
39 | out = input * self.log_scale.exp() + self.bias
40 | out = out * mask.unsqueeze(dim - 1)
41 | logdet = self.log_scale.sum(dim=0, keepdim=True)
42 | if dim > 2:
43 | num = mask.view(out.size(0), -1).sum(dim=1)
44 | logdet = logdet * num
45 | return out, logdet
46 |
47 | @overrides
48 | def backward(self, input: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
49 | """
50 |
51 | Args:
52 | input: Tensor
53 | input tensor [batch, N1, N2, ..., Nl, in_features]
54 | mask: Tensor
55 | mask tensor [batch, N1, N2, ...,Nl]
56 |
57 | Returns: out: Tensor , logdet: Tensor
58 | out: [batch, N1, N2, ..., in_features], the output of the flow
59 | logdet: [batch], the log determinant of :math:`\partial output / \partial input`
60 |
61 | """
62 | dim = input.dim()
63 | out = (input - self.bias) * mask.unsqueeze(dim - 1)
64 | out = out.div(self.log_scale.exp() + 1e-8)
65 | logdet = self.log_scale.sum(dim=0, keepdim=True) * -1.0
66 | if dim > 2:
67 | num = mask.view(out.size(0), -1).sum(dim=1)
68 | logdet = logdet * num
69 | return out, logdet
70 |
71 | @overrides
72 | def init(self, data: torch.Tensor, mask: torch.Tensor, init_scale=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
73 | """
74 |
75 | Args:
76 | data: input: Tensor
77 | input tensor [batch, N1, N2, ..., in_features]
78 | mask: Tensor
79 | mask tensor [batch, N1, N2, ...,Nl]
80 | init_scale: float
81 | initial scale
82 |
83 | Returns: out: Tensor , logdet: Tensor
84 | out: [batch, N1, N2, ..., in_features], the output of the flow
85 | logdet: [batch], the log determinant of :math:`\partial output / \partial input`
86 |
87 | """
88 | with torch.no_grad():
89 | out, _ = self.forward(data, mask)
90 | mean = out.view(-1, self.in_features).mean(dim=0)
91 | std = out.view(-1, self.in_features).std(dim=0)
92 | inv_stdv = init_scale / (std + 1e-6)
93 |
94 | self.log_scale.add_(inv_stdv.log())
95 | self.bias.add_(-mean).mul_(inv_stdv)
96 | return self.forward(data, mask)
97 |
98 | @overrides
99 | def extra_repr(self):
100 | return 'inverse={}, in_features={}'.format(self.inverse, self.in_features)
101 |
102 | @classmethod
103 | def from_params(cls, params: Dict) -> "ActNormFlow":
104 | return ActNormFlow(**params)
105 |
106 |
107 | ActNormFlow.register('actnorm')
108 |
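A quick sanity check of ActNormFlow (an illustrative sketch, not part of the repository): after the data-dependent init() above, forward() produces roughly zero-mean, unit-variance features, and backward() inverts forward() with log-determinants that cancel.

import torch
from flownmt.flows.actnorm import ActNormFlow

torch.manual_seed(0)
flow = ActNormFlow(in_features=8)
x = torch.randn(4, 10, 8)                 # [batch, length, in_features]
mask = torch.ones(4, 10)                  # [batch, length]

flow.init(x, mask, init_scale=1.0)        # rescales log_scale / bias from this first batch
y, logdet = flow.forward(x, mask)
x_rec, logdet_inv = flow.backward(y, mask)

print(y.view(-1, 8).mean(dim=0), y.view(-1, 8).std(dim=0))  # ~0 and ~1 after init
print((x - x_rec).abs().max())            # ~0: backward inverts forward
print((logdet + logdet_inv).abs().max())  # ~0: log-determinants cancel
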
--------------------------------------------------------------------------------
/flownmt/flows/couplings/__init__.py:
--------------------------------------------------------------------------------
1 | from flownmt.flows.couplings.coupling import NICE
2 |
--------------------------------------------------------------------------------
/flownmt/flows/couplings/blocks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
4 |
5 | from flownmt.nnet.weightnorm import Conv1dWeightNorm, LinearWeightNorm
6 | from flownmt.nnet.attention import GlobalAttention, MultiHeadAttention
7 | from flownmt.nnet.positional_encoding import PositionalEncoding
8 | from flownmt.nnet.transformer import TransformerDecoderLayer
9 |
10 |
11 | class NICEConvBlock(nn.Module):
12 | def __init__(self, src_features, in_features, out_features, hidden_features, kernel_size, dropout=0.0):
13 | super(NICEConvBlock, self).__init__()
14 | self.conv1 = Conv1dWeightNorm(in_features, hidden_features, kernel_size=kernel_size, padding=kernel_size // 2, bias=True)
15 | self.conv2 = Conv1dWeightNorm(hidden_features, hidden_features, kernel_size=kernel_size, padding=kernel_size // 2, bias=True)
16 | self.activation = nn.ELU(inplace=True)
17 | self.attn = GlobalAttention(src_features, hidden_features, hidden_features, dropout=dropout)
18 | self.linear = LinearWeightNorm(hidden_features * 2, out_features, bias=True)
19 |
20 | def forward(self, x, mask, src, src_mask):
21 | """
22 |
23 | Args:
24 | x: Tensor
25 | input tensor [batch, length, in_features]
26 | mask: Tensor
27 | x mask tensor [batch, length]
28 | src: Tensor
29 | source input tensor [batch, src_length, src_features]
30 | src_mask: Tensor
31 | source mask tensor [batch, src_length]
32 |
33 | Returns: Tensor
34 | out tensor [batch, length, out_features]
35 |
36 | """
37 | out = self.activation(self.conv1(x.transpose(1, 2)))
38 | out = self.activation(self.conv2(out)).transpose(1, 2) * mask.unsqueeze(2)
39 | out = self.attn(out, src, key_mask=src_mask.eq(0))
40 | out = self.linear(torch.cat([x, out], dim=2))
41 | return out
42 |
43 | def init(self, x, mask, src, src_mask, init_scale=1.0):
44 | out = self.activation(self.conv1.init(x.transpose(1, 2), init_scale=init_scale))
45 | out = self.activation(self.conv2.init(out, init_scale=init_scale)).transpose(1, 2) * mask.unsqueeze(2)
46 | out = self.attn.init(out, src, key_mask=src_mask.eq(0), init_scale=init_scale)
47 | out = self.linear.init(torch.cat([x, out], dim=2), init_scale=0.0)
48 | return out
49 |
50 |
51 | class NICERecurrentBlock(nn.Module):
52 | def __init__(self, rnn_mode, src_features, in_features, out_features, hidden_features, dropout=0.0):
53 | super(NICERecurrentBlock, self).__init__()
54 | if rnn_mode == 'RNN':
55 | RNN = nn.RNN
56 | elif rnn_mode == 'LSTM':
57 | RNN = nn.LSTM
58 | elif rnn_mode == 'GRU':
59 | RNN = nn.GRU
60 | else:
61 | raise ValueError('Unknown RNN mode: %s' % rnn_mode)
62 |
63 | self.rnn = RNN(in_features, hidden_features // 2, batch_first=True, bidirectional=True)
64 | self.attn = GlobalAttention(src_features, hidden_features, hidden_features, dropout=dropout)
65 | self.linear = LinearWeightNorm(in_features + hidden_features, out_features, bias=True)
66 |
67 | def forward(self, x, mask, src, src_mask):
68 | lengths = mask.sum(dim=1).long()
69 | packed_out = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
70 | packed_out, _ = self.rnn(packed_out)
71 | out, _ = pad_packed_sequence(packed_out, batch_first=True, total_length=mask.size(1))
72 | # [batch, length, out_features]
73 | out = self.attn(out, src, key_mask=src_mask.eq(0))
74 | out = self.linear(torch.cat([x, out], dim=2))
75 | return out
76 |
77 | def init(self, x, mask, src, src_mask, init_scale=1.0):
78 | lengths = mask.sum(dim=1).long()
79 | packed_out = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
80 | packed_out, _ = self.rnn(packed_out)
81 | out, _ = pad_packed_sequence(packed_out, batch_first=True, total_length=mask.size(1))
82 | # [batch, length, out_features]
83 | out = self.attn.init(out, src, key_mask=src_mask.eq(0), init_scale=init_scale)
84 | out = self.linear.init(torch.cat([x, out], dim=2), init_scale=0.0)
85 | return out
86 |
87 |
88 | class NICESelfAttnBlock(nn.Module):
89 | def __init__(self, src_features, in_features, out_features, hidden_features, heads, dropout=0.0,
90 | pos_enc='add', max_length=100):
91 | super(NICESelfAttnBlock, self).__init__()
92 | assert pos_enc in ['add', 'attn']
93 | self.src_proj = nn.Linear(src_features, in_features, bias=False) if src_features != in_features else None
94 | self.pos_enc = PositionalEncoding(in_features, padding_idx=None, init_size=max_length + 1)
95 | self.pos_attn = MultiHeadAttention(in_features, heads, dropout=dropout) if pos_enc == 'attn' else None
96 | self.transformer = TransformerDecoderLayer(in_features, hidden_features, heads, dropout=dropout)
97 | self.linear = LinearWeightNorm(in_features, out_features, bias=True)
98 |
99 | def forward(self, x, mask, src, src_mask):
100 | if self.src_proj is not None:
101 | src = self.src_proj(src)
102 |
103 | key_mask = mask.eq(0)
104 | pos_enc = self.pos_enc(x) * mask.unsqueeze(2)
105 | if self.pos_attn is None:
106 | x = x + pos_enc
107 | else:
108 | x = self.pos_attn(pos_enc, x, x, key_mask)
109 |
110 | x = self.transformer(x, key_mask, src, src_mask.eq(0))
111 | return self.linear(x)
112 |
113 | def init(self, x, mask, src, src_mask, init_scale=1.0):
114 | if self.src_proj is not None:
115 | src = self.src_proj(src)
116 |
117 | key_mask = mask.eq(0)
118 | pos_enc = self.pos_enc(x) * mask.unsqueeze(2)
119 | if self.pos_attn is None:
120 | x = x + pos_enc
121 | else:
122 | x = self.pos_attn(pos_enc, x, x, key_mask)
123 |
124 | x = self.transformer.init(x, key_mask, src, src_mask.eq(0), init_scale=init_scale)
125 | x = x * mask.unsqueeze(2)
126 | return self.linear.init(x, init_scale=0.0)
127 |
--------------------------------------------------------------------------------
/flownmt/flows/couplings/coupling.py:
--------------------------------------------------------------------------------
1 | from overrides import overrides
2 | from typing import Tuple, Dict
3 | import torch
4 |
5 | from flownmt.flows.couplings.blocks import NICEConvBlock, NICERecurrentBlock, NICESelfAttnBlock
6 | from flownmt.flows.flow import Flow
7 | from flownmt.flows.couplings.transform import Transform, Additive, Affine, NLSQ
8 |
9 |
10 | class NICE(Flow):
11 | """
12 | NICE Flow
13 | """
14 | def __init__(self, src_features, features, hidden_features=None, inverse=False, split_dim=2, split_type='continuous', order='up', factor=2,
15 | transform='affine', type='conv', kernel=3, rnn_mode='LSTM', heads=1, dropout=0.0, pos_enc='add', max_length=100):
16 | super(NICE, self).__init__(inverse)
17 | self.features = features
18 | assert split_dim in [1, 2]
19 | assert split_type in ['continuous', 'skip']
20 | if split_dim == 1:
21 | assert split_type == 'skip'
22 | if factor != 2:
23 | assert split_type == 'continuous'
24 | assert order in ['up', 'down']
25 | self.split_dim = split_dim
26 | self.split_type = split_type
27 | self.up = order == 'up'
28 | if split_dim == 2:
29 | out_features = features // factor
30 | in_features = features - out_features
31 | self.z1_channels = in_features if self.up else out_features
32 | else:
33 | in_features = features
34 | out_features = features
35 | self.z1_channels = None
36 | assert transform in ['additive', 'affine', 'nlsq']
37 | if transform == 'additive':
38 | self.transform = Additive
39 | elif transform == 'affine':
40 | self.transform = Affine
41 | out_features = out_features * 2
42 | elif transform == 'nlsq':
43 | self.transform = NLSQ
44 | out_features = out_features * 5
45 | else:
46 | raise ValueError('unknown transform: {}'.format(transform))
47 |
48 | if hidden_features is None:
49 | hidden_features = min(2 * in_features, 1024)
50 | assert type in ['conv', 'self_attn', 'rnn']
51 | if type == 'conv':
52 | self.net = NICEConvBlock(src_features, in_features, out_features, hidden_features, kernel_size=kernel, dropout=dropout)
53 | elif type == 'rnn':
54 | self.net = NICERecurrentBlock(rnn_mode, src_features, in_features, out_features, hidden_features, dropout=dropout)
55 | else:
56 | self.net = NICESelfAttnBlock(src_features, in_features, out_features, hidden_features,
57 | heads=heads, dropout=dropout, pos_enc=pos_enc, max_length=max_length)
58 |
59 | def split(self, z, mask):
60 | split_dim = self.split_dim
61 | split_type = self.split_type
62 | dim = z.size(split_dim)
63 | if split_type == 'continuous':
64 | return z.split([self.z1_channels, dim - self.z1_channels], dim=split_dim), mask
65 | elif split_type == 'skip':
66 | idx1 = torch.tensor(list(range(0, dim, 2))).to(z.device)
67 | idx2 = torch.tensor(list(range(1, dim, 2))).to(z.device)
68 | z1 = z.index_select(split_dim, idx1)
69 | z2 = z.index_select(split_dim, idx2)
70 | if split_dim == 1:
71 | mask = mask.index_select(split_dim, idx1)
72 | return (z1, z2), mask
73 | else:
74 | raise ValueError('unknown split type: {}'.format(split_type))
75 |
76 | def unsplit(self, z1, z2):
77 | split_dim = self.split_dim
78 | split_type = self.split_type
79 | if split_type == 'continuous':
80 | return torch.cat([z1, z2], dim=split_dim)
81 | elif split_type == 'skip':
82 | z = torch.cat([z1, z2], dim=split_dim)
83 | dim = z1.size(split_dim)
84 | idx = torch.tensor([i // 2 if i % 2 == 0 else i // 2 + dim for i in range(dim * 2)]).to(z.device)
85 | return z.index_select(split_dim, idx)
86 | else:
87 | raise ValueError('unknown split type: {}'.format(split_type))
88 |
89 | def calc_params(self, z: torch.Tensor, mask: torch.Tensor, src: torch.Tensor, src_mask: torch.Tensor):
90 | params = self.net(z, mask, src, src_mask)
91 | return params
92 |
93 | def init_net(self, z: torch.Tensor, mask: torch.Tensor, src: torch.Tensor, src_mask: torch.Tensor, init_scale=1.0):
94 | params = self.net.init(z, mask, src, src_mask, init_scale=init_scale)
95 | return params
96 |
97 | @overrides
98 | def forward(self, input: torch.Tensor, mask: torch.Tensor, src: torch.Tensor, src_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
99 | """
100 | Args:
101 | input: Tensor
102 | input tensor [batch, length, in_features]
103 | mask: Tensor
104 | mask tensor [batch, length]
105 | src: Tensor
106 | source input tensor [batch, src_length, src_features]
107 | src_mask: Tensor
108 | source mask tensor [batch, src_length]
109 |
110 | Returns: out: Tensor , logdet: Tensor
111 | out: [batch, length, in_features], the output of the flow
112 | logdet: [batch], the log determinant of :math:`\partial output / \partial input`
113 | """
114 | # [batch, length, in_channels]
115 | (z1, z2), mask = self.split(input, mask)
116 | # [batch, length, features]
117 | z, zp = (z1, z2) if self.up else (z2, z1)
118 |
119 | params = self.calc_params(z, mask, src, src_mask)
120 | zp, logdet = self.transform.fwd(zp, mask, params)
121 |
122 | z1, z2 = (z, zp) if self.up else (zp, z)
123 | return self.unsplit(z1, z2), logdet
124 |
125 | @overrides
126 | def backward(self, input: torch.Tensor, mask: torch.Tensor, src: torch.Tensor, src_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
127 | """
128 | Args:
129 | input: Tensor
130 | input tensor [batch, length, in_features]
131 | mask: Tensor
132 | mask tensor [batch, length]
133 | src: Tensor
134 | source input tensor [batch, src_length, src_features]
135 | src_mask: Tensor
136 | source mask tensor [batch, src_length]
137 |
138 | Returns: out: Tensor , logdet: Tensor
139 | out: [batch, length, in_features], the output of the flow
140 | logdet: [batch], the log determinant of :math:`\partial output / \partial input`
141 | """
142 | # [batch, length, in_channels]
143 | (z1, z2), mask = self.split(input, mask)
144 | # [batch, length, features]
145 | z, zp = (z1, z2) if self.up else (z2, z1)
146 |
147 | params = self.calc_params(z, mask, src, src_mask)
148 | zp, logdet = self.transform.bwd(zp, mask, params)
149 |
150 | z1, z2 = (z, zp) if self.up else (zp, z)
151 | return self.unsplit(z1, z2), logdet
152 |
153 | @overrides
154 | def init(self, data: torch.Tensor, mask: torch.Tensor, src: torch.Tensor, src_mask: torch.Tensor, init_scale=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
155 | # [batch, length, in_channels]
156 | (z1, z2), mask = self.split(data, mask)
157 | # [batch, length, features]
158 | z, zp = (z1, z2) if self.up else (z2, z1)
159 |
160 | params = self.init_net(z, mask, src, src_mask, init_scale=init_scale)
161 | zp, logdet = self.transform.fwd(zp, mask, params)
162 |
163 | z1, z2 = (z, zp) if self.up else (zp, z)
164 | return self.unsplit(z1, z2), logdet
165 |
166 | @overrides
167 | def extra_repr(self):
168 |         return 'inverse={}, features={}, split_dim={}, split_type={}, order={}'.format(self.inverse, self.features, self.split_dim, self.split_type, 'up' if self.up else 'down')
169 |
170 | @classmethod
171 | def from_params(cls, params: Dict) -> "NICE":
172 | return NICE(**params)
173 |
174 |
175 | NICE.register('nice')
176 |
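For reference, a single NICE layer can be exercised end-to-end on random tensors. The sketch below is illustrative only (not part of the repository); hidden_features is set equal to the split size so the conv block's projection shapes line up, and it assumes the weight-normalized sub-modules behave as ordinary layers before their data-dependent initialization.

import torch
from flownmt.flows.couplings.coupling import NICE

torch.manual_seed(0)
batch, length, src_length = 2, 6, 7
features, src_features = 8, 16

layer = NICE(src_features, features, hidden_features=features // 2,
             transform='affine', type='conv', kernel=3)
z = torch.randn(batch, length, features)
mask = torch.ones(batch, length)
src = torch.randn(batch, src_length, src_features)
src_mask = torch.ones(batch, src_length)

with torch.no_grad():
    y, logdet = layer.forward(z, mask, src, src_mask)
    z_rec, logdet_inv = layer.backward(y, mask, src, src_mask)

print((z - z_rec).abs().max())            # ~0: the coupling is invertible
print((logdet + logdet_inv).abs().max())  # ~0: log-determinants cancel
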
--------------------------------------------------------------------------------
/flownmt/flows/couplings/transform.py:
--------------------------------------------------------------------------------
1 | import math
2 | from overrides import overrides
3 | from typing import Tuple
4 | import torch
5 |
6 |
7 | class Transform():
8 | @staticmethod
9 | def fwd(z: torch.Tensor, mask: torch.Tensor, params) -> Tuple[torch.Tensor, torch.Tensor]:
10 | raise NotImplementedError
11 |
12 | @staticmethod
13 | def bwd(z: torch.Tensor, mask: torch.Tensor, params) -> Tuple[torch.Tensor, torch.Tensor]:
14 | raise NotImplementedError
15 |
16 |
17 | class Additive(Transform):
18 | @staticmethod
19 | @overrides
20 | def fwd(z: torch.Tensor, mask: torch.Tensor, params) -> Tuple[torch.Tensor, torch.Tensor]:
21 | mu = params
22 | z = (z + mu).mul(mask.unsqueeze(2))
23 | logdet = z.new_zeros(z.size(0))
24 | return z, logdet
25 |
26 | @staticmethod
27 | @overrides
28 | def bwd(z: torch.Tensor, mask: torch.Tensor, params) -> Tuple[torch.Tensor, torch.Tensor]:
29 | mu = params
30 | z = (z - mu).mul(mask.unsqueeze(2))
31 | logdet = z.new_zeros(z.size(0))
32 | return z, logdet
33 |
34 |
35 | class Affine(Transform):
36 | @staticmethod
37 | @overrides
38 | def fwd(z: torch.Tensor, mask: torch.Tensor, params) -> Tuple[torch.Tensor, torch.Tensor]:
39 | mu, log_scale = params.chunk(2, dim=2)
40 | scale = log_scale.add_(2.0).sigmoid_()
41 | z = (scale * z + mu).mul(mask.unsqueeze(2))
42 | logdet = scale.log().mul(mask.unsqueeze(2)).view(z.size(0), -1).sum(dim=1)
43 | return z, logdet
44 |
45 | @staticmethod
46 | @overrides
47 | def bwd(z: torch.Tensor, mask: torch.Tensor, params) -> Tuple[torch.Tensor, torch.Tensor]:
48 | mu, log_scale = params.chunk(2, dim=2)
49 | scale = log_scale.add_(2.0).sigmoid_()
50 | z = (z - mu).div(scale + 1e-12).mul(mask.unsqueeze(2))
51 | logdet = scale.log().mul(mask.unsqueeze(2)).view(z.size(0), -1).sum(dim=1) * -1.0
52 | return z, logdet
53 |
54 |
55 | def arccosh(x):
56 | return torch.log(x + torch.sqrt(x.pow(2) - 1))
57 |
58 |
59 | def arcsinh(x):
60 | return torch.log(x + torch.sqrt(x.pow(2) + 1))
61 |
62 |
63 | class NLSQ(Transform):
64 | # A = 8 * math.sqrt(3) / 9 - 0.05 # 0.05 is a small number to prevent exactly 0 slope
65 | logA = math.log(8 * math.sqrt(3) / 9 - 0.05) # 0.05 is a small number to prevent exactly 0 slope
66 |
67 | @staticmethod
68 | def get_pseudo_params(params):
69 | a, logb, cprime, logd, g = params.chunk(5, dim=2)
70 |
71 | # for stability
72 | logb = logb.mul_(0.4)
73 | cprime = cprime.mul_(0.3)
74 | logd = logd.mul_(0.4)
75 |
76 | # b = logb.add_(2.0).sigmoid_()
77 | # d = logd.add_(2.0).sigmoid_()
78 | # c = (NLSQ.A * b / d).mul(cprime.tanh_())
79 |
80 | c = (NLSQ.logA + logb - logd).exp_().mul(cprime.tanh_())
81 | b = logb.exp_()
82 | d = logd.exp_()
83 | return a, b, c, d, g
84 |
85 | @staticmethod
86 | @overrides
87 | def fwd(z: torch.Tensor, mask: torch.Tensor, params) -> Tuple[torch.Tensor, torch.Tensor]:
88 | a, b, c, d, g = NLSQ.get_pseudo_params(params)
89 |
90 | arg = (d * z).add_(g)
91 | denom = arg.pow(2).add_(1)
92 | c = c / denom
93 | z = (b * z + a + c).mul(mask.unsqueeze(2))
94 | logdet = torch.log(b - 2 * c * d * arg / denom)
95 | logdet = logdet.mul(mask.unsqueeze(2)).view(z.size(0), -1).sum(dim=1)
96 | return z, logdet
97 |
98 | @staticmethod
99 | @overrides
100 | def bwd(z: torch.Tensor, mask: torch.Tensor, params) -> Tuple[torch.Tensor, torch.Tensor]:
101 | a, b, c, d, g = NLSQ.get_pseudo_params(params)
102 |
103 | # double needed for stability. No effect on overall speed
104 | a = a.double()
105 | b = b.double()
106 | c = c.double()
107 | d = d.double()
108 | g = g.double()
109 | z = z.double()
110 |
111 | aa = -b * d.pow(2)
112 | bb = (z - a) * d.pow(2) - 2 * b * d * g
113 | cc = (z - a) * 2 * d * g - b * (1 + g.pow(2))
114 | dd = (z - a) * (1 + g.pow(2)) - c
115 |
116 | p = (3 * aa * cc - bb.pow(2)) / (3 * aa.pow(2))
117 | q = (2 * bb.pow(3) - 9 * aa * bb * cc + 27 * aa.pow(2) * dd) / (27 * aa.pow(3))
118 |
119 | t = -2 * torch.abs(q) / q * torch.sqrt(torch.abs(p) / 3)
120 | inter_term1 = -3 * torch.abs(q) / (2 * p) * torch.sqrt(3 / torch.abs(p))
121 | inter_term2 = 1 / 3 * arccosh(torch.abs(inter_term1 - 1) + 1)
122 | t = t * torch.cosh(inter_term2)
123 |
124 | tpos = -2 * torch.sqrt(torch.abs(p) / 3)
125 | inter_term1 = 3 * q / (2 * p) * torch.sqrt(3 / torch.abs(p))
126 | inter_term2 = 1 / 3 * arcsinh(inter_term1)
127 | tpos = tpos * torch.sinh(inter_term2)
128 |
129 | t[p > 0] = tpos[p > 0]
130 | z = t - bb / (3 * aa)
131 | arg = d * z + g
132 | denom = arg.pow(2) + 1
133 | logdet = torch.log(b - 2 * c * d * arg / denom.pow(2))
134 |
135 | z = z.float().mul(mask.unsqueeze(2))
136 | logdet = logdet.float().mul(mask.unsqueeze(2)).view(z.size(0), -1).sum(dim=1) * -1.0
137 | return z, logdet
138 |
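The affine transform above is easy to verify in isolation: bwd() inverts fwd() and their log-determinants are exact negatives. Note that fwd()/bwd() modify the params tensor in place (add_/sigmoid_), which is why the illustrative sketch below (not part of the repository) clones params before each call.

import torch
from flownmt.flows.couplings.transform import Affine

torch.manual_seed(0)
z = torch.randn(3, 5, 4)
mask = torch.ones(3, 5)
params = torch.randn(3, 5, 8)             # chunked into mu and log_scale inside fwd/bwd

y, logdet = Affine.fwd(z, mask, params.clone())
z_rec, logdet_inv = Affine.bwd(y, mask, params.clone())

assert torch.allclose(z, z_rec, atol=1e-5)
assert torch.allclose(logdet, -logdet_inv, atol=1e-5)
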
--------------------------------------------------------------------------------
/flownmt/flows/flow.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Tuple
2 | import torch
3 | import torch.nn as nn
4 |
5 |
6 | class Flow(nn.Module):
7 | """
8 | Normalizing Flow base class
9 | """
10 | _registry = dict()
11 |
12 | def __init__(self, inverse):
13 | super(Flow, self).__init__()
14 | self.inverse = inverse
15 |
16 | def forward(self, *inputs, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
17 | """
18 |
19 | Args:
20 |             *inputs: input [batch, *input_size]
21 |
22 | Returns: out: Tensor [batch, *input_size], logdet: Tensor [batch]
23 | out, the output of the flow
24 | logdet, the log determinant of :math:`\partial output / \partial input`
25 | """
26 | raise NotImplementedError
27 |
28 | def backward(self, *inputs, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
29 | """
30 |
31 | Args:
32 |             *inputs: input [batch, *input_size]
33 |
34 | Returns: out: Tensor [batch, *input_size], logdet: Tensor [batch]
35 | out, the output of the flow
36 | logdet, the log determinant of :math:`\partial output / \partial input`
37 | """
38 | raise NotImplementedError
39 |
40 | def init(self, *input, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
41 | raise NotImplementedError
42 |
43 | def fwdpass(self, x: torch.Tensor, *h, init=False, init_scale=1.0, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
44 | """
45 |
46 | Args:
47 | x: Tensor
48 | The random variable before flow
49 | h: list of object
50 | other conditional inputs
51 | init: bool
52 | perform initialization or not (default: False)
53 | init_scale: float
54 | initial scale (default: 1.0)
55 |
56 | Returns: y: Tensor, logdet: Tensor
57 | y, the random variable after flow
58 | logdet, the log determinant of :math:`\partial y / \partial x`
59 | Then the density :math:`\log(p(y)) = \log(p(x)) - logdet`
60 |
61 | """
62 | if self.inverse:
63 | if init:
64 |                 raise RuntimeError('inverse flow should be initialized with backward pass')
65 | else:
66 | return self.backward(x, *h, **kwargs)
67 | else:
68 | if init:
69 | return self.init(x, *h, init_scale=init_scale, **kwargs)
70 | else:
71 | return self.forward(x, *h, **kwargs)
72 |
73 | def bwdpass(self, y: torch.Tensor, *h, init=False, init_scale=1.0, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
74 | """
75 |
76 | Args:
77 | y: Tensor
78 | The random variable after flow
79 | h: list of object
80 | other conditional inputs
81 | init: bool
82 | perform initialization or not (default: False)
83 | init_scale: float
84 | initial scale (default: 1.0)
85 |
86 | Returns: x: Tensor, logdet: Tensor
87 | x, the random variable before flow
88 | logdet, the log determinant of :math:`\partial x / \partial y`
89 | Then the density :math:`\log(p(y)) = \log(p(x)) + logdet`
90 |
91 | """
92 | if self.inverse:
93 | if init:
94 | return self.init(y, *h, init_scale=init_scale, **kwargs)
95 | else:
96 | return self.forward(y, *h, **kwargs)
97 | else:
98 | if init:
99 |                 raise RuntimeError('forward flow should be initialized with forward pass')
100 | else:
101 | return self.backward(y, *h, **kwargs)
102 |
103 | @classmethod
104 | def register(cls, name: str):
105 | Flow._registry[name] = cls
106 |
107 | @classmethod
108 | def by_name(cls, name: str):
109 | return Flow._registry[name]
110 |
111 | @classmethod
112 | def from_params(cls, params: Dict):
113 | raise NotImplementedError
114 |
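Concrete flows are obtained through the registry defined above: each subclass calls register(name) at import time, and a configuration dict is turned into a module with by_name(...).from_params(...). A short illustrative sketch (not part of the repository):

from flownmt.flows.actnorm import ActNormFlow   # importing the module registers 'actnorm'
from flownmt.flows.flow import Flow

flow_cls = Flow.by_name('actnorm')              # -> ActNormFlow
flow = flow_cls.from_params({'in_features': 16, 'inverse': False})
print(flow)                                     # ActNormFlow(inverse=False, in_features=16)
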
--------------------------------------------------------------------------------
/flownmt/flows/linear.py:
--------------------------------------------------------------------------------
1 | from overrides import overrides
2 | from typing import Dict, Tuple
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.nn import Parameter
7 |
8 | from flownmt.flows.flow import Flow
9 |
10 |
11 | class InvertibleLinearFlow(Flow):
12 | def __init__(self, in_features, inverse=False):
13 | super(InvertibleLinearFlow, self).__init__(inverse)
14 | self.in_features = in_features
15 | self.weight = Parameter(torch.Tensor(in_features, in_features))
16 | self.register_buffer('weight_inv', self.weight.data.clone())
17 | self.reset_parameters()
18 |
19 | def reset_parameters(self):
20 | nn.init.orthogonal_(self.weight)
21 | self.sync()
22 |
23 | def sync(self):
24 | self.weight_inv.copy_(self.weight.data.inverse())
25 |
26 | @overrides
27 | def forward(self, input: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
28 | """
29 |
30 | Args:
31 | input: Tensor
32 | input tensor [batch, N1, N2, ..., Nl, in_features]
33 | mask: Tensor
34 | mask tensor [batch, N1, N2, ...,Nl]
35 |
36 | Returns: out: Tensor , logdet: Tensor
37 | out: [batch, N1, N2, ..., in_features], the output of the flow
38 | logdet: [batch], the log determinant of :math:`\partial output / \partial input`
39 |
40 | """
41 | dim = input.dim()
42 | # [batch, N1, N2, ..., in_features]
43 | out = F.linear(input, self.weight)
44 | _, logdet = torch.slogdet(self.weight)
45 | if dim > 2:
46 | num = mask.view(out.size(0), -1).sum(dim=1)
47 | logdet = logdet * num
48 | return out, logdet
49 |
50 | @overrides
51 | def backward(self, input: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
52 | """
53 |
54 | Args:
55 | input: Tensor
56 | input tensor [batch, N1, N2, ..., Nl, in_features]
57 | mask: Tensor
58 | mask tensor [batch, N1, N2, ...,Nl]
59 |
60 | Returns: out: Tensor , logdet: Tensor
61 | out: [batch, N1, N2, ..., in_features], the output of the flow
62 | logdet: [batch], the log determinant of :math:`\partial output / \partial input`
63 |
64 | """
65 | dim = input.dim()
66 | # [batch, N1, N2, ..., in_features]
67 | out = F.linear(input, self.weight_inv)
68 | _, logdet = torch.slogdet(self.weight_inv)
69 | if dim > 2:
70 | num = mask.view(out.size(0), -1).sum(dim=1)
71 | logdet = logdet * num
72 | return out, logdet
73 |
74 | @overrides
75 | def init(self, data: torch.Tensor, mask: torch.Tensor, init_scale=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
76 | with torch.no_grad():
77 |             return self.forward(data, mask)
78 |
79 | @overrides
80 | def extra_repr(self):
81 | return 'inverse={}, in_features={}'.format(self.inverse, self.in_features)
82 |
83 | @classmethod
84 | def from_params(cls, params: Dict) -> "InvertibleLinearFlow":
85 | return InvertibleLinearFlow(**params)
86 |
87 |
88 | class InvertibleMultiHeadFlow(Flow):
89 | @staticmethod
90 | def _get_heads(in_features):
91 | units = [32, 16, 8]
92 | for unit in units:
93 | if in_features % unit == 0:
94 | return in_features // unit
95 | assert in_features < 8, 'features={}'.format(in_features)
96 | return 1
97 |
98 | def __init__(self, in_features, heads=None, type='A', inverse=False):
99 | super(InvertibleMultiHeadFlow, self).__init__(inverse)
100 | self.in_features = in_features
101 | if heads is None:
102 | heads = InvertibleMultiHeadFlow._get_heads(in_features)
103 | self.heads = heads
104 | self.type = type
105 | assert in_features % heads == 0, 'features ({}) should be divided by heads ({})'.format(in_features, heads)
106 | assert type in ['A', 'B'], 'type should belong to [A, B]'
107 | self.weight = Parameter(torch.Tensor(in_features // heads, in_features // heads))
108 | self.register_buffer('weight_inv', self.weight.data.clone())
109 | self.reset_parameters()
110 |
111 | def reset_parameters(self):
112 | nn.init.orthogonal_(self.weight)
113 | self.sync()
114 |
115 | def sync(self):
116 | self.weight_inv.copy_(self.weight.data.inverse())
117 |
118 | @overrides
119 | def forward(self, input: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
120 | """
121 |
122 | Args:
123 | input: Tensor
124 | input tensor [batch, N1, N2, ..., Nl, in_features]
125 | mask: Tensor
126 | mask tensor [batch, N1, N2, ...,Nl]
127 |
128 | Returns: out: Tensor , logdet: Tensor
129 | out: [batch, N1, N2, ..., in_features], the output of the flow
130 | logdet: [batch], the log determinant of :math:`\partial output / \partial input`
131 |
132 | """
133 | size = input.size()
134 | dim = input.dim()
135 | # [batch, N1, N2, ..., heads, in_features/ heads]
136 | if self.type == 'A':
137 | out = input.view(*size[:-1], self.heads, self.in_features // self.heads)
138 | else:
139 | out = input.view(*size[:-1], self.in_features // self.heads, self.heads).transpose(-2, -1)
140 |
141 | out = F.linear(out, self.weight)
142 | if self.type == 'B':
143 | out = out.transpose(-2, -1).contiguous()
144 | out = out.view(*size)
145 |
146 | _, logdet = torch.slogdet(self.weight)
147 | if dim > 2:
148 | num = mask.view(size[0], -1).sum(dim=1) * self.heads
149 | logdet = logdet * num
150 | return out, logdet
151 |
152 | @overrides
153 | def backward(self, input: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
154 | """
155 |
156 | Args:
157 | input: Tensor
158 | input tensor [batch, N1, N2, ..., Nl, in_features]
159 | mask: Tensor
160 | mask tensor [batch, N1, N2, ...,Nl]
161 |
162 | Returns: out: Tensor , logdet: Tensor
163 | out: [batch, N1, N2, ..., in_features], the output of the flow
164 | logdet: [batch], the log determinant of :math:`\partial output / \partial input`
165 |
166 | """
167 | size = input.size()
168 | dim = input.dim()
169 | # [batch, N1, N2, ..., heads, in_features/ heads]
170 | if self.type == 'A':
171 | out = input.view(*size[:-1], self.heads, self.in_features // self.heads)
172 | else:
173 | out = input.view(*size[:-1], self.in_features // self.heads, self.heads).transpose(-2, -1)
174 |
175 | out = F.linear(out, self.weight_inv)
176 | if self.type == 'B':
177 | out = out.transpose(-2, -1).contiguous()
178 | out = out.view(*size)
179 |
180 | _, logdet = torch.slogdet(self.weight_inv)
181 | if dim > 2:
182 | num = mask.view(size[0], -1).sum(dim=1) * self.heads
183 | logdet = logdet * num
184 | return out, logdet
185 |
186 | @overrides
187 | def init(self, data: torch.Tensor, mask: torch.Tensor, init_scale=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
188 | with torch.no_grad():
189 | return self.forward(data, mask)
190 |
191 | @overrides
192 | def extra_repr(self):
193 | return 'inverse={}, in_features={}, heads={}, type={}'.format(self.inverse, self.in_features, self.heads, self.type)
194 |
195 | @classmethod
196 | def from_params(cls, params: Dict) -> "InvertibleMultiHeadFlow":
197 | return InvertibleMultiHeadFlow(**params)
198 |
199 |
200 | InvertibleLinearFlow.register('invertible_linear')
201 | InvertibleMultiHeadFlow.register('invertible_multihead')
202 |
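A minimal check (illustrative, not part of the repository) that InvertibleLinearFlow is invertible and that, with the orthogonal initialization above, the per-sentence log-determinant starts out near zero (|det W| = 1 for an orthogonal W):

import torch
from flownmt.flows.linear import InvertibleLinearFlow

torch.manual_seed(0)
flow = InvertibleLinearFlow(in_features=8)
x = torch.randn(2, 5, 8)
mask = torch.ones(2, 5)

with torch.no_grad():
    y, logdet = flow.forward(x, mask)
    x_rec, _ = flow.backward(y, mask)

print(logdet)                         # ~0: orthogonal weight, |det| = 1
print((x - x_rec).abs().max())        # ~0 up to the numerical error of the inverse
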
--------------------------------------------------------------------------------
/flownmt/flows/parallel/__init__.py:
--------------------------------------------------------------------------------
1 | from flownmt.flows.parallel.data_parallel import DataParallelFlow
2 |
--------------------------------------------------------------------------------
/flownmt/flows/parallel/data_parallel.py:
--------------------------------------------------------------------------------
1 | from overrides import overrides
2 | from typing import Tuple
3 | import torch
4 | from torch.nn.parallel.replicate import replicate
5 | from flownmt.flows.parallel.parallel_apply import parallel_apply
6 | from torch.nn.parallel.scatter_gather import scatter_kwargs, gather
7 | from torch.nn.parallel.data_parallel import _check_balance
8 |
9 | from flownmt.flows.flow import Flow
10 |
11 |
12 | class DataParallelFlow(Flow):
13 | """
14 | Implements data parallelism at the flow level.
15 | """
16 | def __init__(self, flow: Flow, device_ids=None, output_device=None, dim=0):
17 | super(DataParallelFlow, self).__init__(flow.inverse)
18 |
19 | if not torch.cuda.is_available():
20 | self.flow = flow
21 | self.device_ids = []
22 | return
23 |
24 | if device_ids is None:
25 | device_ids = list(range(torch.cuda.device_count()))
26 | if output_device is None:
27 | output_device = device_ids[0]
28 | self.dim = dim
29 | self.flow = flow
30 | self.device_ids = device_ids
31 | self.output_device = output_device
32 |
33 | _check_balance(self.device_ids)
34 |
35 | if len(self.device_ids) == 1:
36 | self.flow.cuda(device_ids[0])
37 |
38 | @overrides
39 | def forward(self, *inputs, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
40 | if not self.device_ids:
41 | return self.flow.forward(*inputs, **kwargs)
42 | inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
43 | if len(self.device_ids) == 1:
44 | return self.flow.forward(*inputs[0], **kwargs[0])
45 | replicas = self.replicate(self.flow, self.device_ids[:len(inputs)])
46 | outputs = self.parallel_apply(replicas, inputs, kwargs)
47 | return self.gather(outputs, self.output_device)
48 |
49 | @overrides
50 | def backward(self, *inputs, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
51 | if not self.device_ids:
52 | return self.flow.backward(*inputs, **kwargs)
53 | inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
54 | if len(self.device_ids) == 1:
55 | return self.flow.backward(*inputs[0], **kwargs[0])
56 | replicas = self.replicate(self.flow, self.device_ids[:len(inputs)])
57 | outputs = self.parallel_apply(replicas, inputs, kwargs, backward=True)
58 | return self.gather(outputs, self.output_device)
59 |
60 | @overrides
61 | def init(self, *input, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
62 | return self.flow.init(*input, **kwargs)
63 |
64 | def replicate(self, flow, device_ids):
65 | return replicate(flow, device_ids)
66 |
67 | def scatter(self, inputs, kwargs, device_ids):
68 | return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
69 |
70 | def parallel_apply(self, replicas, inputs, kwargs, backward=False):
71 | return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)], backward=backward)
72 |
73 | def gather(self, outputs, output_device):
74 | return gather(outputs, output_device, dim=self.dim)
75 |
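DataParallelFlow wraps a single flow step so its forward/backward passes can be scattered across visible GPUs; without CUDA it simply falls through to the wrapped flow. An illustrative sketch (not part of the repository), using ActNormFlow as the wrapped flow:

import torch
from flownmt.flows.actnorm import ActNormFlow
from flownmt.flows.parallel import DataParallelFlow

inner = ActNormFlow(in_features=8)
if torch.cuda.is_available():
    inner = inner.cuda()                  # replicate() expects the flow on the first GPU
flow = DataParallelFlow(inner)

device = next(flow.flow.parameters()).device
x = torch.randn(4, 6, 8, device=device)
mask = torch.ones(4, 6, device=device)
y, logdet = flow.forward(x, mask)         # scattered across GPUs when several are visible
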
--------------------------------------------------------------------------------
/flownmt/flows/parallel/parallel_apply.py:
--------------------------------------------------------------------------------
1 | import threading
2 | import torch
3 |
4 |
5 | def get_a_var(obj):
6 | if isinstance(obj, torch.Tensor):
7 | return obj
8 |
9 | if isinstance(obj, list) or isinstance(obj, tuple):
10 | for result in map(get_a_var, obj):
11 | if isinstance(result, torch.Tensor):
12 | return result
13 | if isinstance(obj, dict):
14 | for result in map(get_a_var, obj.items()):
15 | if isinstance(result, torch.Tensor):
16 | return result
17 | return None
18 |
19 |
20 | def parallel_apply(flows, inputs, kwargs_tup=None, devices=None, backward=False):
21 | r"""Applies each `module` in :attr:`modules` in parallel on arguments
22 | contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword)
23 | on each of :attr:`devices`.
24 |
25 | :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
26 | :attr:`devices` (if given) should all have same length. Moreover, each
27 | element of :attr:`inputs` can either be a single object as the only argument
28 | to a module, or a collection of positional arguments.
29 | """
30 | assert len(flows) == len(inputs)
31 | if kwargs_tup is not None:
32 | assert len(flows) == len(kwargs_tup)
33 | else:
34 | kwargs_tup = ({},) * len(flows)
35 | if devices is not None:
36 | assert len(flows) == len(devices)
37 | else:
38 | devices = [None] * len(flows)
39 |
40 | lock = threading.Lock()
41 | results = {}
42 | grad_enabled = torch.is_grad_enabled()
43 |
44 | def _worker(i, flow, input, kwargs, device=None, back=False):
45 | torch.set_grad_enabled(grad_enabled)
46 | if device is None:
47 | device = get_a_var(input).get_device()
48 | try:
49 | with torch.cuda.device(device):
50 | # this also avoids accidental slicing of `input` if it is a Tensor
51 | if not isinstance(input, (list, tuple)):
52 | input = (input,)
53 | output = flow.backward(*input, **kwargs) if back else flow.forward(*input, **kwargs)
54 | with lock:
55 | results[i] = output
56 | except Exception as e:
57 | with lock:
58 | results[i] = e
59 |
60 | if len(flows) > 1:
61 | threads = [threading.Thread(target=_worker,
62 | args=(i, flow, input, kwargs, device, backward))
63 | for i, (flow, input, kwargs, device) in
64 | enumerate(zip(flows, inputs, kwargs_tup, devices))]
65 |
66 | for thread in threads:
67 | thread.start()
68 | for thread in threads:
69 | thread.join()
70 | else:
71 | _worker(0, flows[0], inputs[0], kwargs_tup[0], devices[0], backward)
72 |
73 | outputs = []
74 | for i in range(len(inputs)):
75 | output = results[i]
76 | if isinstance(output, Exception):
77 | raise output
78 | outputs.append(output)
79 | return outputs
80 |
--------------------------------------------------------------------------------
/flownmt/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from flownmt.modules.encoders import *
2 | from flownmt.modules.posteriors import *
3 | from flownmt.modules.decoders import *
4 | from flownmt.modules.priors import *
5 |
--------------------------------------------------------------------------------
/flownmt/modules/decoders/__init__.py:
--------------------------------------------------------------------------------
1 | from flownmt.modules.decoders.decoder import Decoder
2 | from flownmt.modules.decoders.simple import SimpleDecoder
3 | from flownmt.modules.decoders.rnn import RecurrentDecoder
4 | from flownmt.modules.decoders.transformer import TransformerDecoder
5 |
--------------------------------------------------------------------------------
/flownmt/modules/decoders/decoder.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Tuple
2 | import torch
3 | import torch.nn as nn
4 |
5 | from flownmt.nnet.criterion import LabelSmoothedCrossEntropyLoss
6 |
7 |
8 | class Decoder(nn.Module):
9 | """
10 | Decoder to predict translations from latent z
11 | """
12 | _registry = dict()
13 |
14 | def __init__(self, vocab_size, latent_dim, label_smoothing=0., _shared_weight=None):
15 | super(Decoder, self).__init__()
16 | self.readout = nn.Linear(latent_dim, vocab_size, bias=True)
17 | if _shared_weight is not None:
18 | self.readout.weight = _shared_weight
19 | nn.init.constant_(self.readout.bias, 0.)
20 | else:
21 | self.reset_parameters(latent_dim)
22 |
23 | if label_smoothing < 1e-5:
24 | self.criterion = nn.CrossEntropyLoss(reduction='none')
25 |         elif label_smoothing < 1.0:
26 | self.criterion = LabelSmoothedCrossEntropyLoss(label_smoothing)
27 | else:
28 | raise ValueError('label smoothing should be less than 1.0.')
29 |
30 | def reset_parameters(self, dim):
31 | # nn.init.normal_(self.readout.weight, mean=0, std=dim ** -0.5)
32 | nn.init.uniform_(self.readout.weight, -0.1, 0.1)
33 | nn.init.constant_(self.readout.bias, 0.)
34 |
35 | def init(self, z, mask, src, src_mask, init_scale=1.0):
36 | raise NotImplementedError
37 |
38 | def decode(self, z: torch.Tensor, mask: torch.Tensor, src: torch.Tensor, src_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
39 | """
40 |
41 | Args:
42 | z: Tensor
43 | latent code [batch, length, hidden_size]
44 | mask: Tensor
45 | mask [batch, length]
46 | src: Tensor
47 | src encoding [batch, src_length, hidden_size]
48 | src_mask: Tensor
49 | source mask [batch, src_length]
50 |
51 | Returns: Tensor1, Tensor2
52 |             Tensor1: decoded word index [batch, length]
53 | Tensor2: log probabilities of decoding [batch]
54 |
55 | """
56 | raise NotImplementedError
57 |
58 | def loss(self, z: torch.Tensor, target: torch.Tensor, mask: torch.Tensor,
59 | src: torch.Tensor, src_mask: torch.Tensor) -> torch.Tensor:
60 | """
61 |
62 | Args:
63 | z: Tensor
64 | latent codes [batch, length, hidden_size]
65 | target: LongTensor
66 | target translations [batch, length]
67 | mask: Tensor
68 | masks for target sentence [batch, length]
69 | src: Tensor
70 | src encoding [batch, src_length, hidden_size]
71 | src_mask: Tensor
72 | source mask [batch, src_length]
73 |
74 | Returns: Tensor
75 | tensor for loss [batch]
76 |
77 | """
78 | raise NotImplementedError
79 |
80 | @classmethod
81 | def register(cls, name: str):
82 | Decoder._registry[name] = cls
83 |
84 | @classmethod
85 | def by_name(cls, name: str):
86 | return Decoder._registry[name]
87 |
88 | @classmethod
89 | def from_params(cls, params: Dict) -> "Decoder":
90 | raise NotImplementedError
91 |
92 |
93 | Decoder.register('simple')
94 |
--------------------------------------------------------------------------------
/flownmt/modules/decoders/rnn.py:
--------------------------------------------------------------------------------
1 | from overrides import overrides
2 | from typing import Dict, Tuple
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
7 |
8 | from flownmt.modules.decoders.decoder import Decoder
9 | from flownmt.nnet.attention import GlobalAttention
10 |
11 |
12 | class RecurrentDecoder(Decoder):
13 | """
14 | Decoder with Recurrent Neural Networks
15 | """
16 | def __init__(self, vocab_size, latent_dim, rnn_mode, num_layers, hidden_size, bidirectional=True,
17 | dropout=0.0, dropword=0.0, label_smoothing=0., _shared_weight=None):
18 | super(RecurrentDecoder, self).__init__(vocab_size, latent_dim,
19 | label_smoothing=label_smoothing,
20 | _shared_weight=_shared_weight)
21 |
22 | if rnn_mode == 'RNN':
23 | RNN = nn.RNN
24 | elif rnn_mode == 'LSTM':
25 | RNN = nn.LSTM
26 | elif rnn_mode == 'GRU':
27 | RNN = nn.GRU
28 | else:
29 | raise ValueError('Unknown RNN mode: %s' % rnn_mode)
30 | assert hidden_size % 2 == 0
31 | # RNN for processing latent variables zs
32 | if bidirectional:
33 | self.rnn = RNN(latent_dim, hidden_size // 2, num_layers=num_layers, batch_first=True, bidirectional=True)
34 | else:
35 | self.rnn = RNN(latent_dim, hidden_size, num_layers=num_layers, batch_first=True, bidirectional=False)
36 |
37 | self.attn = GlobalAttention(latent_dim, hidden_size, latent_dim, hidden_features=hidden_size)
38 | self.ctx_proj = nn.Sequential(nn.Linear(latent_dim + hidden_size, latent_dim), nn.ELU())
39 | self.dropout = dropout
40 | self.dropout2d = nn.Dropout2d(dropword) if dropword > 0. else None # drop entire tokens
41 |
42 | def forward(self, z, mask, src, src_mask):
43 | lengths = mask.sum(dim=1).long()
44 | if self.dropout2d is not None:
45 | z = self.dropout2d(z)
46 |
47 | packed_z = pack_padded_sequence(z, lengths, batch_first=True, enforce_sorted=False)
48 | packed_enc, _ = self.rnn(packed_z)
49 | enc, _ = pad_packed_sequence(packed_enc, batch_first=True, total_length=mask.size(1))
50 |
51 | ctx = self.attn(enc, src, key_mask=src_mask.eq(0))
52 | ctx = torch.cat([ctx, enc], dim=2)
53 | ctx = F.dropout(self.ctx_proj(ctx), p=self.dropout, training=self.training)
54 | return self.readout(ctx)
55 |
56 | @overrides
57 | def init(self, z, mask, src, src_mask, init_scale=1.0):
58 | with torch.no_grad():
59 | return self(z, mask, src, src_mask)
60 |
61 | @overrides
62 | def decode(self, z: torch.Tensor, mask: torch.Tensor, src: torch.Tensor, src_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
63 | """
64 |
65 | Args:
66 | z: Tensor
67 | latent code [batch, length, hidden_size]
68 | mask: Tensor
69 | mask [batch, length]
70 | src: Tensor
71 | src encoding [batch, src_length, hidden_size]
72 | src_mask: Tensor
73 | source mask [batch, src_length]
74 |
75 | Returns: Tensor1, Tensor2
76 |             Tensor1: decoded word index [batch, length]
77 | Tensor2: log probabilities of decoding [batch]
78 |
79 | """
80 | # [batch, length, vocab_size]
81 | log_probs = F.log_softmax(self(z, mask, src, src_mask), dim=2)
82 | # [batch, length]
83 | log_probs, dec = log_probs.max(dim=2)
84 | dec = dec * mask.long()
85 | # [batch]
86 | log_probs = log_probs.mul(mask).sum(dim=1)
87 | return dec, log_probs
88 |
89 | @overrides
90 | def loss(self, z: torch.Tensor, target: torch.Tensor, mask: torch.Tensor,
91 | src: torch.Tensor, src_mask: torch.Tensor) -> torch.Tensor:
92 | """
93 |
94 | Args:
95 | z: Tensor
96 | latent codes [batch, length, hidden_size]
97 | target: LongTensor
98 | target translations [batch, length]
99 | mask: Tensor
100 | masks for target sentence [batch, length]
101 | src: Tensor
102 | src encoding [batch, src_length, hidden_size]
103 | src_mask: Tensor
104 | source mask [batch, src_length]
105 |
106 | Returns: Tensor
107 | tensor for loss [batch]
108 |
109 | """
110 | # [batch, length, vocab_size] -> [batch, vocab_size, length]
111 | logits = self(z, mask, src, src_mask).transpose(1, 2)
112 | # [batch, length]
113 | loss = self.criterion(logits, target).mul(mask)
114 | return loss.sum(dim=1)
115 |
116 | @classmethod
117 | def from_params(cls, params: Dict) -> "RecurrentDecoder":
118 | return RecurrentDecoder(**params)
119 |
120 |
121 | RecurrentDecoder.register('rnn')
122 |
--------------------------------------------------------------------------------
/flownmt/modules/decoders/simple.py:
--------------------------------------------------------------------------------
1 | from overrides import overrides
2 | from typing import Dict, Tuple
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from flownmt.modules.decoders.decoder import Decoder
8 | from flownmt.nnet.attention import GlobalAttention
9 |
10 | class SimpleDecoder(Decoder):
11 | """
12 | Simple Decoder to predict translations from latent z
13 | """
14 |
15 | def __init__(self, vocab_size, latent_dim, hidden_size, dropout=0.0, label_smoothing=0., _shared_weight=None):
16 | super(SimpleDecoder, self).__init__(vocab_size, latent_dim,
17 | label_smoothing=label_smoothing,
18 | _shared_weight=_shared_weight)
19 | self.attn = GlobalAttention(latent_dim, latent_dim, latent_dim, hidden_features=hidden_size)
20 | ctx_features = latent_dim * 2
21 | self.ctx_proj = nn.Sequential(nn.Linear(ctx_features, latent_dim), nn.ELU())
22 | self.dropout = dropout
23 |
24 | @overrides
25 | def forward(self, z, src, src_mask):
26 | ctx = self.attn(z, src, key_mask=src_mask.eq(0))
27 | ctx = F.dropout(self.ctx_proj(torch.cat([ctx, z], dim=2)), p=self.dropout, training=self.training)
28 | return self.readout(ctx)
29 |
30 | @overrides
31 | def init(self, z, mask, src, src_mask, init_scale=1.0):
32 | with torch.no_grad():
33 |             return self(z, src, src_mask)
34 |
35 | @overrides
36 | def decode(self, z: torch.Tensor, mask: torch.Tensor, src: torch.Tensor, src_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
37 | """
38 |
39 | Args:
40 | z: Tensor
41 | latent code [batch, length, hidden_size]
42 | mask: Tensor
43 | mask [batch, length]
44 | src: Tensor
45 | src encoding [batch, src_length, hidden_size]
46 | src_mask: Tensor
47 | source mask [batch, src_length]
48 |
49 | Returns: Tensor1, Tensor2
50 |             Tensor1: decoded word index [batch, length]
51 | Tensor2: log probabilities of decoding [batch]
52 |
53 | """
54 | # [batch, length, vocab_size]
55 | log_probs = F.log_softmax(self(z, src, src_mask), dim=2)
56 | # [batch, length]
57 | log_probs, dec = log_probs.max(dim=2)
58 | dec = dec * mask.long()
59 | # [batch]
60 | log_probs = log_probs.mul(mask).sum(dim=1)
61 | return dec, log_probs
62 |
63 | @overrides
64 | def loss(self, z: torch.Tensor, target: torch.Tensor, mask: torch.Tensor,
65 | src: torch.Tensor, src_mask: torch.Tensor) -> torch.Tensor:
66 | """
67 |
68 | Args:
69 | z: Tensor
70 | latent codes [batch, length, hidden_size]
71 | target: LongTensor
72 | target translations [batch, length]
73 | mask: Tensor
74 | masks for target sentence [batch, length]
75 | src: Tensor
76 | src encoding [batch, src_length, hidden_size]
77 | src_mask: Tensor
78 | source mask [batch, src_length]
79 |
80 | Returns: Tensor
81 | tensor for loss [batch]
82 |
83 | """
84 | # [batch, length, vocab_size] -> [batch, vocab_size, length]
85 | logits = self(z, src, src_mask).transpose(1, 2)
86 | # [batch, length]
87 | loss = self.criterion(logits, target).mul(mask)
88 | return loss.sum(dim=1)
89 |
90 | @classmethod
91 | def from_params(cls, params: Dict) -> "SimpleDecoder":
92 | return SimpleDecoder(**params)
93 |
94 |
95 | SimpleDecoder.register('simple')
96 |
--------------------------------------------------------------------------------
/flownmt/modules/decoders/transformer.py:
--------------------------------------------------------------------------------
1 | from overrides import overrides
2 | from typing import Dict, Tuple
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from flownmt.modules.decoders.decoder import Decoder
8 | from flownmt.nnet.attention import MultiHeadAttention
9 | from flownmt.nnet.transformer import TransformerDecoderLayer
10 | from flownmt.nnet.positional_encoding import PositionalEncoding
11 |
12 |
13 | class TransformerDecoder(Decoder):
14 | """
15 | Decoder with Transformer
16 | """
17 | def __init__(self, vocab_size, latent_dim, num_layers, hidden_size, heads, label_smoothing=0.,
18 | dropout=0.0, dropword=0.0, max_length=100, _shared_weight=None):
19 | super(TransformerDecoder, self).__init__(vocab_size, latent_dim,
20 | label_smoothing=label_smoothing,
21 | _shared_weight=_shared_weight)
22 | self.pos_enc = PositionalEncoding(latent_dim, None, max_length + 1)
23 | self.pos_attn = MultiHeadAttention(latent_dim, heads, dropout=dropout)
24 | layers = [TransformerDecoderLayer(latent_dim, hidden_size, heads, dropout=dropout) for _ in range(num_layers)]
25 | self.layers = nn.ModuleList(layers)
26 | self.dropword = dropword # drop entire tokens
27 |
28 | def forward(self, z, mask, src, src_mask):
29 | z = F.dropout2d(z, p=self.dropword, training=self.training)
30 | # [batch, length, latent_dim]
31 | pos_enc = self.pos_enc(z) * mask.unsqueeze(2)
32 |
33 | key_mask = mask.eq(0)
34 | ctx = self.pos_attn(pos_enc, z, z, key_mask)
35 |
36 | src_mask = src_mask.eq(0)
37 | for layer in self.layers:
38 | ctx = layer(ctx, key_mask, src, src_mask)
39 |
40 | return self.readout(ctx)
41 |
42 | @overrides
43 | def init(self, z, mask, src, src_mask, init_scale=1.0):
44 | with torch.no_grad():
45 | return self(z, mask, src, src_mask)
46 |
47 | @overrides
48 | def decode(self, z: torch.Tensor, mask: torch.Tensor, src: torch.Tensor, src_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
49 | """
50 |
51 | Args:
52 | z: Tensor
53 | latent code [batch, length, hidden_size]
54 | mask: Tensor
55 | mask [batch, length]
56 | src: Tensor
57 | src encoding [batch, src_length, hidden_size]
58 | src_mask: Tensor
59 | source mask [batch, src_length]
60 |
61 | Returns: Tensor1, Tensor2
62 | Tensor1: decoded word index [batch, length]
63 | Tensor2: log probabilities of decoding [batch]
64 |
65 | """
66 | # [batch, length, vocab_size]
67 | log_probs = F.log_softmax(self(z, mask, src, src_mask), dim=2)
68 | # [batch, length]
69 | log_probs, dec = log_probs.max(dim=2)
70 | dec = dec * mask.long()
71 | # [batch]
72 | log_probs = log_probs.mul(mask).sum(dim=1)
73 | return dec, log_probs
74 |
75 | @overrides
76 | def loss(self, z: torch.Tensor, target: torch.Tensor, mask: torch.Tensor,
77 | src: torch.Tensor, src_mask: torch.Tensor) -> torch.Tensor:
78 | """
79 |
80 | Args:
81 | z: Tensor
82 | latent codes [batch, length, hidden_size]
83 | target: LongTensor
84 | target translations [batch, length]
85 | mask: Tensor
86 | masks for target sentence [batch, length]
87 | src: Tensor
88 | src encoding [batch, src_length, hidden_size]
89 | src_mask: Tensor
90 | source mask [batch, src_length]
91 |
92 | Returns: Tensor
93 | tensor for loss [batch]
94 |
95 | """
96 | # [batch, length, vocab_size] -> [batch, vocab_size, length]
97 | logits = self(z, mask, src, src_mask).transpose(1, 2)
98 | # [batch, length]
99 | loss = self.criterion(logits, target).mul(mask)
100 | return loss.sum(dim=1)
101 |
102 | @classmethod
103 | def from_params(cls, params: Dict) -> "TransformerDecoder":
104 | return TransformerDecoder(**params)
105 |
106 |
107 | TransformerDecoder.register('transformer')
108 |
--------------------------------------------------------------------------------
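
The loss method transposes the logits to [batch, vocab_size, length] because PyTorch's cross-entropy criterion expects the class dimension second when the target has extra dimensions; the per-token losses are then masked and summed per sentence. A small self-contained sketch of that layout, assuming an unsmoothed nn.CrossEntropyLoss in place of the label-smoothed criterion:

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss(reduction='none')

logits = torch.randn(2, 4, 10)            # [batch, length, vocab_size], toy values
target = torch.randint(0, 10, (2, 4))     # [batch, length]
mask = torch.tensor([[1., 1., 1., 1.],
                     [1., 1., 1., 0.]])

# the criterion wants [batch, vocab_size, length] when targets are [batch, length]
loss = criterion(logits.transpose(1, 2), target)  # [batch, length] per-token losses
loss = loss.mul(mask).sum(dim=1)                  # [batch] masked sentence loss
print(loss.shape)                                 # torch.Size([2])
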
/flownmt/modules/encoders/__init__.py:
--------------------------------------------------------------------------------
1 | from flownmt.modules.encoders.encoder import Encoder
2 | from flownmt.modules.encoders.rnn import RecurrentEncoder
3 | from flownmt.modules.encoders.transformer import TransformerEncoder
4 |
--------------------------------------------------------------------------------
/flownmt/modules/encoders/encoder.py:
--------------------------------------------------------------------------------
1 | from overrides import overrides
2 | from typing import Dict, Tuple
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 | class Encoder(nn.Module):
8 | """
9 | Src Encoder to encode source sentence
10 | """
11 | _registry = dict()
12 |
13 | def __init__(self, vocab_size, embed_dim, padding_idx):
14 | super(Encoder, self).__init__()
15 | self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)
16 | self.reset_parameters()
17 |
18 | def reset_parameters(self):
19 | nn.init.uniform_(self.embed.weight, -0.1, 0.1)
20 | if self.embed.padding_idx is not None:
21 | with torch.no_grad():
22 | self.embed.weight[self.embed.padding_idx].fill_(0)
23 |
24 | @overrides
25 | def forward(self, src_sents, masks=None) -> Tuple[torch.Tensor, torch.Tensor]:
26 | """
27 | Encoding src sentences into src encoding representations.
28 | Args:
29 | src_sents: Tensor [batch, length]
30 | masks: Tensor or None [batch, length]
31 |
32 | Returns: Tensor1, Tensor2
33 | Tensor1: tensor for src encoding [batch, length, hidden_size]
34 | Tensor2: tensor for global state [batch, hidden_size]
35 |
36 | """
37 | raise NotImplementedError
38 |
39 | def init(self, src_sents, masks=None, init_scale=1.0) -> torch.Tensor:
40 | raise NotImplementedError
41 |
42 | @classmethod
43 | def register(cls, name: str):
44 | Encoder._registry[name] = cls
45 |
46 | @classmethod
47 | def by_name(cls, name: str):
48 | return Encoder._registry[name]
49 |
50 | @classmethod
51 | def from_params(cls, params: Dict):
52 | raise NotImplementedError
53 |
--------------------------------------------------------------------------------
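
Encoder, like the other base classes in this package, keeps a simple name-to-class registry: subclasses call register('name') at import time and are later looked up with by_name and constructed through from_params. A hypothetical usage sketch, assuming the flownmt package is importable and using made-up hyper-parameter values (note that TransformerEncoder asserts embed_dim == latent_dim):

from flownmt.modules.encoders import TransformerEncoder  # importing registers 'transformer'
from flownmt.modules.encoders.encoder import Encoder

params = {  # illustrative hyper-parameters, not taken from a shipped config
    'vocab_size': 32000, 'embed_dim': 256, 'padding_idx': 0,
    'num_layers': 4, 'latent_dim': 256, 'hidden_size': 512, 'heads': 4,
}
encoder_cls = Encoder.by_name('transformer')
assert encoder_cls is TransformerEncoder
encoder = encoder_cls.from_params(params)   # equivalent to TransformerEncoder(**params)
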
/flownmt/modules/encoders/rnn.py:
--------------------------------------------------------------------------------
1 | from overrides import overrides
2 | from typing import Dict, Tuple
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from flownmt.modules.encoders.encoder import Encoder
8 | from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
9 |
10 |
11 | class RecurrentCore(nn.Module):
12 | def __init__(self, embed, rnn_mode, num_layers, latent_dim, hidden_size, dropout=0.0):
13 | super(RecurrentCore, self).__init__()
14 | self.embed = embed
15 |
16 | if rnn_mode == 'RNN':
17 | RNN = nn.RNN
18 | elif rnn_mode == 'LSTM':
19 | RNN = nn.LSTM
20 | elif rnn_mode == 'GRU':
21 | RNN = nn.GRU
22 | else:
23 | raise ValueError('Unknown RNN mode: %s' % rnn_mode)
24 | assert hidden_size % 2 == 0
25 | self.rnn = RNN(embed.embedding_dim, hidden_size // 2,
26 | num_layers=num_layers, batch_first=True, bidirectional=True)
27 | self.enc_proj = nn.Sequential(nn.Linear(hidden_size, latent_dim), nn.ELU())
28 | self.reset_parameters()
29 |
30 | def reset_parameters(self):
31 | nn.init.constant_(self.enc_proj[0].bias, 0.)
32 |
33 | @overrides
34 | def forward(self, src_sents, masks) -> Tuple[torch.Tensor, torch.Tensor]:
35 | word_embed = F.dropout(self.embed(src_sents), p=0.2, training=self.training)
36 |
37 | lengths = masks.sum(dim=1).long()
38 | packed_embed = pack_padded_sequence(word_embed, lengths, batch_first=True, enforce_sorted=False)
39 | packed_enc, _ = self.rnn(packed_embed)
40 | # [batch, length, hidden_size]
41 | src_enc, _ = pad_packed_sequence(packed_enc, batch_first=True, total_length=masks.size(1))
42 | # [batch, length, latent_dim]
43 | src_enc = self.enc_proj(src_enc).mul(masks.unsqueeze(2))
44 |
45 | # [batch, latent_dim]
46 | batch = src_sents.size(0)
47 | idx = lengths - 1
48 | batch_idx = torch.arange(0, batch).long().to(idx.device)
49 | ctx = src_enc[batch_idx, idx]
50 | return src_enc, ctx
51 |
52 |
53 | class RecurrentEncoder(Encoder):
54 | """
55 | Src Encoder to encode source sentence with Recurrent Neural Networks
56 | """
57 |
58 | def __init__(self, vocab_size, embed_dim, padding_idx, rnn_mode, num_layers, latent_dim, hidden_size, dropout=0.0):
59 | super(RecurrentEncoder, self).__init__(vocab_size, embed_dim, padding_idx)
60 | self.core = RecurrentCore(self.embed, rnn_mode, num_layers, latent_dim, hidden_size, dropout=dropout)
61 |
62 | @overrides
63 | def forward(self, src_sents, masks=None) -> Tuple[torch.Tensor, torch.Tensor]:
64 | src_enc, ctx = self.core(src_sents, masks=masks)
65 | return src_enc, ctx
66 |
67 | def init(self, src_sents, masks=None, init_scale=1.0) -> torch.Tensor:
68 | with torch.no_grad():
69 | src_enc, _ = self.core(src_sents, masks=masks)
70 | return src_enc
71 |
72 | @classmethod
73 | def from_params(cls, params: Dict) -> "RecurrentEncoder":
74 | return RecurrentEncoder(**params)
75 |
76 |
77 | RecurrentEncoder.register('rnn')
78 |
--------------------------------------------------------------------------------
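
RecurrentCore packs the padded batch so the bidirectional RNN skips padding, then gathers the encoding at position length - 1 of each sentence as the global state ctx. The gather trick in isolation, on toy tensors with illustrative values:

import torch

# pretend src_enc came back from the RNN projection: [batch=3, length=5, latent_dim=4]
src_enc = torch.randn(3, 5, 4)
masks = torch.tensor([[1., 1., 1., 1., 0.],
                      [1., 1., 0., 0., 0.],
                      [1., 1., 1., 1., 1.]])

lengths = masks.sum(dim=1).long()      # tensor([4, 2, 5])
batch_idx = torch.arange(src_enc.size(0))
ctx = src_enc[batch_idx, lengths - 1]  # encoding at the last non-padded position
print(ctx.shape)                       # torch.Size([3, 4])
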
/flownmt/modules/encoders/transformer.py:
--------------------------------------------------------------------------------
1 | from overrides import overrides
2 | from typing import Dict, Tuple
3 | import math
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | from flownmt.modules.encoders.encoder import Encoder
9 | from flownmt.nnet.transformer import TransformerEncoderLayer
10 | from flownmt.nnet.positional_encoding import PositionalEncoding
11 |
12 |
13 | class TransformerCore(nn.Module):
14 | def __init__(self, embed, num_layers, latent_dim, hidden_size, heads, dropout=0.0, max_length=100):
15 | super(TransformerCore, self).__init__()
16 | self.embed = embed
17 | self.padding_idx = embed.padding_idx
18 |
19 | embed_dim = embed.embedding_dim
20 | self.embed_scale = math.sqrt(embed_dim)
21 | assert embed_dim == latent_dim
22 | layers = [TransformerEncoderLayer(latent_dim, hidden_size, heads, dropout=dropout) for _ in range(num_layers)]
23 | self.layers = nn.ModuleList(layers)
24 | self.pos_enc = PositionalEncoding(latent_dim, self.padding_idx, max_length + 1)
25 | self.reset_parameters()
26 |
27 | def reset_parameters(self):
28 | pass
29 |
30 | @overrides
31 | def forward(self, src_sents, masks) -> Tuple[torch.Tensor, torch.Tensor]:
32 | # [batch, length, embed_dim]
33 | x = self.embed_scale * self.embed(src_sents)
34 | x += self.pos_enc(src_sents)
35 | x = F.dropout(x, p=0.2, training=self.training)
36 |
37 | # [batch, length, latent_dim]
38 | key_mask = masks.eq(0)
39 | if not key_mask.any():
40 | key_mask = None
41 |
42 | for layer in self.layers:
43 | x = layer(x, key_mask)
44 |
45 | x *= masks.unsqueeze(2)
46 | # [batch, latent_dim]
47 | batch = src_sents.size(0)
48 | idx = masks.sum(dim=1).long() - 1
49 | batch_idx = torch.arange(0, batch).long().to(idx.device)
50 | ctx = x[batch_idx, idx]
51 | return x, ctx
52 |
53 |
54 | class TransformerEncoder(Encoder):
55 | """
56 | Src Encoder to encode source sentence with Transformer
57 | """
58 |
59 | def __init__(self, vocab_size, embed_dim, padding_idx, num_layers, latent_dim, hidden_size, heads, dropout=0.0, max_length=100):
60 | super(TransformerEncoder, self).__init__(vocab_size, embed_dim, padding_idx)
61 | self.core = TransformerCore(self.embed, num_layers, latent_dim, hidden_size, heads, dropout=dropout, max_length=max_length)
62 |
63 | @overrides
64 | def forward(self, src_sents, masks=None) -> Tuple[torch.Tensor, torch.Tensor]:
65 | src_enc, ctx = self.core(src_sents, masks=masks)
66 | return src_enc, ctx
67 |
68 | def init(self, src_sents, masks=None, init_scale=1.0) -> torch.Tensor:
69 | with torch.no_grad():
70 | src_enc, _ = self.core(src_sents, masks=masks)
71 | return src_enc
72 |
73 | @classmethod
74 | def from_params(cls, params: Dict) -> "TransformerEncoder":
75 | return TransformerEncoder(**params)
76 |
77 |
78 | TransformerEncoder.register('transformer')
79 |
--------------------------------------------------------------------------------
/flownmt/modules/posteriors/__init__.py:
--------------------------------------------------------------------------------
1 | from flownmt.modules.posteriors.posterior import Posterior
2 | from flownmt.modules.posteriors.rnn import RecurrentPosterior
3 | from flownmt.modules.posteriors.shift_rnn import ShiftRecurrentPosterior
4 | from flownmt.modules.posteriors.transformer import TransformerPosterior
5 |
--------------------------------------------------------------------------------
/flownmt/modules/posteriors/posterior.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Dict, Tuple
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 | class Posterior(nn.Module):
8 | """
9 | posterior class
10 | """
11 | _registry = dict()
12 |
13 | def __init__(self, vocab_size, embed_dim, padding_idx, _shared_embed=None):
14 | super(Posterior, self).__init__()
15 | if _shared_embed is None:
16 | self.tgt_embed = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)
17 | self.reset_parameters()
18 | else:
19 | self.tgt_embed = _shared_embed
20 |
21 | def reset_parameters(self):
22 | nn.init.uniform_(self.tgt_embed.weight, -0.1, 0.1)
23 | if self.tgt_embed.padding_idx is not None:
24 | with torch.no_grad():
25 | self.tgt_embed.weight[self.tgt_embed.padding_idx].fill_(0)
26 |
27 | def target_embed_weight(self):
28 | raise NotImplementedError
29 |
30 | @staticmethod
31 | def reparameterize(mu, logvar, mask, nsamples=1, random=True):
32 | # [batch, length, dim]
33 | size = mu.size()
34 | std = logvar.mul(0.5).exp()
35 | # [batch, nsamples, length, dim]
36 | if random:
37 | eps = torch.randn(size[0], nsamples, *size[1:], device=mu.device)
38 | eps *= mask.view(size[0], 1, size[1], 1)
39 | else:
40 | eps = mu.new_zeros(size[0], nsamples, *size[1:])
41 | return eps.mul(std.unsqueeze(1)).add(mu.unsqueeze(1)), eps
42 |
43 |
44 | @staticmethod
45 | def log_probability(z, eps, mu, logvar, mask):
46 | size = eps.size()
47 | nz = size[3]
48 | # [batch, nsamples, length, nz]
49 | log_probs = logvar.unsqueeze(1) + eps.pow(2)
50 | # [batch, 1]
51 | cc = mask.sum(dim=1, keepdim=True) * (math.log(math.pi * 2.) * nz)
52 | # [batch, nsamples, length * nz] --> [batch, nsamples]
53 | log_probs = log_probs.view(size[0], size[1], -1).sum(dim=2) + cc
54 | return log_probs * -0.5
55 |
56 | def forward(self, tgt_sents, tgt_masks, src_enc, src_masks):
57 | raise NotImplementedError
58 |
59 | def init(self, tgt_sents, tgt_masks, src_enc, src_masks, init_scale=1.0, init_mu=True, init_var=True) -> Tuple[torch.Tensor, torch.Tensor]:
60 | raise NotImplementedError
61 |
62 | def sample(self, tgt_sents: torch.Tensor, tgt_masks: torch.Tensor,
63 | src_enc: torch.Tensor, src_masks: torch.Tensor,
64 | nsamples: int =1, random=True) -> Tuple[torch.Tensor, torch.Tensor]:
65 | """
66 |
67 | Args:
68 | tgt_sents: Tensor [batch, tgt_length]
69 | tensor for target sentences
70 | tgt_masks: Tensor [batch, tgt_length]
71 | tensor for target masks
72 | src_enc: Tensor [batch, src_length, hidden_size]
73 | tensor for source encoding
74 | src_masks: Tensor [batch, src_length]
75 | tensor for source masks
76 | nsamples: int
77 | number of samples
78 | random: bool
79 | if True, perform random sampling. Otherwise, return mean.
80 |
81 | Returns: Tensor1, Tensor2
82 | Tensor1: samples from the posterior [batch, nsamples, tgt_length, nz]
83 | Tensor2: log probabilities [batch, nsamples]
84 |
85 | """
86 | raise NotImplementedError
87 |
88 | @classmethod
89 | def register(cls, name: str):
90 | Posterior._registry[name] = cls
91 |
92 | @classmethod
93 | def by_name(cls, name: str):
94 | return Posterior._registry[name]
95 |
96 | @classmethod
97 | def from_params(cls, params: Dict):
98 | raise NotImplementedError
99 |
--------------------------------------------------------------------------------
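
reparameterize draws z = mu + sigma * eps with eps zeroed on padded positions, and log_probability evaluates the diagonal-Gaussian density of those samples. It is written in terms of logvar + eps^2 rather than plugging z back in, which is algebraically identical because eps = (z - mu) / sigma. A small numeric check of that identity (toy tensors, one sample):

import math
import torch

torch.manual_seed(0)
batch, length, nz = 2, 3, 4
mu = torch.randn(batch, length, nz)
logvar = torch.randn(batch, length, nz)
mask = torch.ones(batch, length)

std = logvar.mul(0.5).exp()
eps = torch.randn(batch, 1, length, nz) * mask.view(batch, 1, length, 1)
z = eps * std.unsqueeze(1) + mu.unsqueeze(1)   # one sample, [batch, 1, length, nz]

const = mask.sum(dim=1, keepdim=True) * (nz * math.log(2.0 * math.pi))
# density written in terms of eps, as in Posterior.log_probability
lp_eps = -0.5 * ((logvar.unsqueeze(1) + eps.pow(2)).view(batch, 1, -1).sum(dim=2) + const)
# the same density written directly in terms of z
lp_z = -0.5 * ((logvar + (z.squeeze(1) - mu).pow(2) / logvar.exp()).view(batch, -1).sum(dim=1, keepdim=True) + const)
print(torch.allclose(lp_eps, lp_z, atol=1e-4))  # True
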
/flownmt/modules/posteriors/rnn.py:
--------------------------------------------------------------------------------
1 | from overrides import overrides
2 | from typing import Tuple, Dict
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
7 |
8 | from flownmt.nnet.weightnorm import LinearWeightNorm
9 | from flownmt.modules.posteriors.posterior import Posterior
10 | from flownmt.nnet.attention import GlobalAttention
11 |
12 |
13 | class RecurrentCore(nn.Module):
14 | def __init__(self, embed, rnn_mode, num_layers, latent_dim, hidden_size, use_attn=False, dropout=0.0, dropword=0.0):
15 | super(RecurrentCore, self).__init__()
16 | if rnn_mode == 'RNN':
17 | RNN = nn.RNN
18 | elif rnn_mode == 'LSTM':
19 | RNN = nn.LSTM
20 | elif rnn_mode == 'GRU':
21 | RNN = nn.GRU
22 | else:
23 | raise ValueError('Unknown RNN mode: %s' % rnn_mode)
24 | assert hidden_size % 2 == 0
25 | self.tgt_embed = embed
26 | self.rnn = RNN(embed.embedding_dim, hidden_size // 2,
27 | num_layers=num_layers, batch_first=True, bidirectional=True)
28 | self.use_attn = use_attn
29 | if use_attn:
30 | self.attn = GlobalAttention(latent_dim, hidden_size, hidden_size, hidden_features=hidden_size)
31 | self.ctx_proj = nn.Sequential(nn.Linear(hidden_size * 2, hidden_size), nn.ELU())
32 | else:
33 | self.ctx_proj = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.ELU())
34 | self.dropout = dropout
35 | self.dropout2d = nn.Dropout2d(dropword) if dropword > 0. else None # drop entire tokens
36 | self.mu = LinearWeightNorm(hidden_size, latent_dim, bias=True)
37 | self.logvar = LinearWeightNorm(hidden_size, latent_dim, bias=True)
38 |
39 | @overrides
40 | def forward(self, tgt_sents, tgt_masks, src_enc, src_masks):
41 | tgt_embed = self.tgt_embed(tgt_sents)
42 | if self.dropout2d is not None:
43 | tgt_embed = self.dropout2d(tgt_embed)
44 | lengths = tgt_masks.sum(dim=1).long()
45 | packed_embed = pack_padded_sequence(tgt_embed, lengths, batch_first=True, enforce_sorted=False)
46 | packed_enc, _ = self.rnn(packed_embed)
47 | tgt_enc, _ = pad_packed_sequence(packed_enc, batch_first=True, total_length=tgt_masks.size(1))
48 | if self.use_attn:
49 | ctx = self.attn(tgt_enc, src_enc, key_mask=src_masks.eq(0))
50 | ctx = torch.cat([tgt_enc, ctx], dim=2)
51 | else:
52 | ctx = tgt_enc
53 | ctx = F.dropout(self.ctx_proj(ctx), p=self.dropout, training=self.training)
54 | mu = self.mu(ctx) * tgt_masks.unsqueeze(2)
55 | logvar = self.logvar(ctx) * tgt_masks.unsqueeze(2)
56 | return mu, logvar
57 |
58 | def init(self, tgt_sents, tgt_masks, src_enc, src_masks, init_scale=1.0, init_mu=True, init_var=True):
59 | with torch.no_grad():
60 | tgt_embed = self.tgt_embed(tgt_sents)
61 | if self.dropout2d is not None:
62 | tgt_embed = self.dropout2d(tgt_embed)
63 | lengths = tgt_masks.sum(dim=1).long()
64 | packed_embed = pack_padded_sequence(tgt_embed, lengths, batch_first=True, enforce_sorted=False)
65 | packed_enc, _ = self.rnn(packed_embed)
66 | tgt_enc, _ = pad_packed_sequence(packed_enc, batch_first=True, total_length=tgt_masks.size(1))
67 | if self.use_attn:
68 | ctx = self.attn.init(tgt_enc, src_enc, key_mask=src_masks.eq(0), init_scale=init_scale)
69 | ctx = torch.cat([tgt_enc, ctx], dim=2)
70 | else:
71 | ctx = tgt_enc
72 | ctx = F.dropout(self.ctx_proj(ctx), p=self.dropout, training=self.training)
73 | mu = self.mu.init(ctx, init_scale=0.05 * init_scale) if init_mu else self.mu(ctx)
74 | logvar = self.logvar.init(ctx, init_scale=0.05 * init_scale) if init_var else self.logvar(ctx)
75 | mu = mu * tgt_masks.unsqueeze(2)
76 | logvar = logvar * tgt_masks.unsqueeze(2)
77 | return mu, logvar
78 |
79 |
80 | class RecurrentPosterior(Posterior):
81 | """
82 | Posterior with Recurrent Neural Networks
83 | """
84 | def __init__(self, vocab_size, embed_dim, padding_idx, rnn_mode, num_layers, latent_dim, hidden_size,
85 | use_attn=False, dropout=0.0, dropword=0.0, _shared_embed=None):
86 | super(RecurrentPosterior, self).__init__(vocab_size, embed_dim, padding_idx, _shared_embed=_shared_embed)
87 | self.core = RecurrentCore(self.tgt_embed, rnn_mode, num_layers, latent_dim, hidden_size,
88 | use_attn=use_attn, dropout=dropout, dropword=dropword)
89 |
90 | def target_embed_weight(self):
91 | if isinstance(self.core, nn.DataParallel):
92 | return self.core.module.tgt_embed.weight
93 | else:
94 | return self.core.tgt_embed.weight
95 |
96 | @overrides
97 | def forward(self, tgt_sents, tgt_masks, src_enc, src_masks):
98 | return self.core(tgt_sents, tgt_masks, src_enc, src_masks)
99 |
100 | @overrides
101 | def sample(self, tgt_sents: torch.Tensor, tgt_masks: torch.Tensor,
102 | src_enc: torch.Tensor, src_masks: torch.Tensor,
103 | nsamples: int =1, random=True) -> Tuple[torch.Tensor, torch.Tensor]:
104 | mu, logvar = self.core(tgt_sents, tgt_masks, src_enc, src_masks)
105 | z, eps = Posterior.reparameterize(mu, logvar, tgt_masks, nsamples=nsamples, random=random)
106 | log_probs = Posterior.log_probability(z, eps, mu, logvar, tgt_masks)
107 | return z, log_probs
108 |
109 | @overrides
110 | def init(self, tgt_sents, tgt_masks, src_enc, src_masks, init_scale=1.0, init_mu=True, init_var=True) -> Tuple[torch.Tensor, torch.Tensor]:
111 | mu, logvar = self.core.init(tgt_sents, tgt_masks, src_enc, src_masks,
112 | init_scale=init_scale, init_mu=init_mu, init_var=init_var)
113 | z, eps = Posterior.reparameterize(mu, logvar, tgt_masks, random=True)
114 | log_probs = Posterior.log_probability(z, eps, mu, logvar, tgt_masks)
115 | z = z.squeeze(1)
116 | log_probs = log_probs.squeeze(1)
117 | return z, log_probs
118 |
119 | @classmethod
120 | def from_params(cls, params: Dict) -> "RecurrentPosterior":
121 | return RecurrentPosterior(**params)
122 |
123 |
124 | RecurrentPosterior.register('rnn')
125 |
--------------------------------------------------------------------------------
/flownmt/modules/posteriors/shift_rnn.py:
--------------------------------------------------------------------------------
1 | from overrides import overrides
2 | from typing import Tuple, Dict
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
7 |
8 | from flownmt.nnet.weightnorm import LinearWeightNorm
9 | from flownmt.modules.posteriors.posterior import Posterior
10 | from flownmt.nnet.attention import GlobalAttention
11 |
12 |
13 | class ShiftRecurrentCore(nn.Module):
14 | def __init__(self, embed, rnn_mode, num_layers, latent_dim, hidden_size, bidirectional=True, use_attn=False, dropout=0.0, dropword=0.0):
15 | super(ShiftRecurrentCore, self).__init__()
16 | if rnn_mode == 'RNN':
17 | RNN = nn.RNN
18 | elif rnn_mode == 'LSTM':
19 | RNN = nn.LSTM
20 | elif rnn_mode == 'GRU':
21 | RNN = nn.GRU
22 | else:
23 | raise ValueError('Unknown RNN mode: %s' % rnn_mode)
24 | assert hidden_size % 2 == 0
25 | self.tgt_embed = embed
26 | assert num_layers == 1
27 | self.bidirectional = bidirectional
28 | if bidirectional:
29 | self.rnn = RNN(embed.embedding_dim, hidden_size // 2, num_layers=1, batch_first=True, bidirectional=True)
30 | else:
31 | self.rnn = RNN(embed.embedding_dim, hidden_size, num_layers=1, batch_first=True, bidirectional=False)
32 | self.use_attn = use_attn
33 | if use_attn:
34 | self.attn = GlobalAttention(latent_dim, hidden_size, hidden_size)
35 | self.ctx_proj = nn.Sequential(nn.Linear(hidden_size * 2, hidden_size), nn.ELU())
36 | else:
37 | self.ctx_proj = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.ELU())
38 | self.dropout = dropout
39 | self.dropout2d = nn.Dropout2d(dropword) if dropword > 0. else None # drop entire tokens
40 | self.mu = LinearWeightNorm(hidden_size, latent_dim, bias=True)
41 | self.logvar = LinearWeightNorm(hidden_size, latent_dim, bias=True)
42 |
43 | @overrides
44 | def forward(self, tgt_sents, tgt_masks, src_enc, src_masks):
45 | tgt_embed = self.tgt_embed(tgt_sents)
46 | if self.dropout2d is not None:
47 | tgt_embed = self.dropout2d(tgt_embed)
48 | lengths = tgt_masks.sum(dim=1).long()
49 | packed_embed = pack_padded_sequence(tgt_embed, lengths, batch_first=True, enforce_sorted=False)
50 | packed_enc, _ = self.rnn(packed_embed)
51 | tgt_enc, _ = pad_packed_sequence(packed_enc, batch_first=True, total_length=tgt_masks.size(1))
52 |
53 | if self.bidirectional:
54 | # split into fwd and bwd
55 | fwd_tgt_enc, bwd_tgt_enc = tgt_enc.chunk(2, dim=2) # (batch_size, seq_len, hidden_size // 2)
56 | pad_vector = fwd_tgt_enc.new_zeros((fwd_tgt_enc.size(0), 1, fwd_tgt_enc.size(2)))
57 | pad_fwd_tgt_enc = torch.cat([pad_vector, fwd_tgt_enc], dim=1)
58 | pad_bwd_tgt_enc = torch.cat([bwd_tgt_enc, pad_vector], dim=1)
59 | tgt_enc = torch.cat([pad_fwd_tgt_enc[:, :-1], pad_bwd_tgt_enc[:, 1:]], dim=2)
60 | else:
61 | pad_vector = tgt_enc.new_zeros((tgt_enc.size(0), 1, tgt_enc.size(2)))
62 | tgt_enc = torch.cat([pad_vector, tgt_enc], dim=1)[:, :-1]
63 |
64 | if self.use_attn:
65 | ctx = self.attn(tgt_enc, src_enc, key_mask=src_masks.eq(0))
66 | ctx = torch.cat([tgt_enc, ctx], dim=2)
67 | else:
68 | ctx = tgt_enc
69 | ctx = F.dropout(self.ctx_proj(ctx), p=self.dropout, training=self.training)
70 | mu = self.mu(ctx) * tgt_masks.unsqueeze(2)
71 | logvar = self.logvar(ctx) * tgt_masks.unsqueeze(2)
72 | return mu, logvar
73 |
74 | def init(self, tgt_sents, tgt_masks, src_enc, src_masks, init_scale=1.0, init_mu=True, init_var=True):
75 | with torch.no_grad():
76 | tgt_embed = self.tgt_embed(tgt_sents)
77 | if self.dropout2d is not None:
78 | tgt_embed = self.dropout2d(tgt_embed)
79 | lengths = tgt_masks.sum(dim=1).long()
80 | packed_embed = pack_padded_sequence(tgt_embed, lengths, batch_first=True, enforce_sorted=False)
81 | packed_enc, _ = self.rnn(packed_embed)
82 | tgt_enc, _ = pad_packed_sequence(packed_enc, batch_first=True, total_length=tgt_masks.size(1))
83 |
84 | if self.bidirectional:
85 | fwd_tgt_enc, bwd_tgt_enc = tgt_enc.chunk(2, dim=2) # (batch_size, seq_len, hidden_size // 2)
86 | pad_vector = fwd_tgt_enc.new_zeros((fwd_tgt_enc.size(0), 1, fwd_tgt_enc.size(2)))
87 | pad_fwd_tgt_enc = torch.cat([pad_vector, fwd_tgt_enc], dim=1)
88 | pad_bwd_tgt_enc = torch.cat([bwd_tgt_enc, pad_vector], dim=1)
89 | tgt_enc = torch.cat([pad_fwd_tgt_enc[:, :-1], pad_bwd_tgt_enc[:, 1:]], dim=2)
90 | else:
91 | pad_vector = tgt_enc.new_zeros((tgt_enc.size(0), 1, tgt_enc.size(2)))
92 | tgt_enc = torch.cat([pad_vector, tgt_enc], dim=1)[:, :-1]
93 |
94 | if self.use_attn:
95 | ctx = self.attn.init(tgt_enc, src_enc, key_mask=src_masks.eq(0), init_scale=init_scale)
96 | ctx = torch.cat([tgt_enc, ctx], dim=2)
97 | else:
98 | ctx = tgt_enc
99 | ctx = F.dropout(self.ctx_proj(ctx), p=self.dropout, training=self.training)
100 | mu = self.mu.init(ctx, init_scale=0.05 * init_scale) if init_mu else self.mu(ctx)
101 | logvar = self.logvar.init(ctx, init_scale=0.05 * init_scale) if init_var else self.logvar(ctx)
102 | mu = mu * tgt_masks.unsqueeze(2)
103 | logvar = logvar * tgt_masks.unsqueeze(2)
104 | return mu, logvar
105 |
106 |
107 | class ShiftRecurrentPosterior(Posterior):
108 | """
109 | Posterior with Recurrent Neural Networks (shifted hidden states)
110 | """
111 | def __init__(self, vocab_size, embed_dim, padding_idx, rnn_mode, num_layers, latent_dim, hidden_size,
112 | bidirectional=True, use_attn=False, dropout=0.0, dropword=0.0, _shared_embed=None):
113 | super(ShiftRecurrentPosterior, self).__init__(vocab_size, embed_dim, padding_idx, _shared_embed=_shared_embed)
114 | self.core = ShiftRecurrentCore(self.tgt_embed, rnn_mode, num_layers, latent_dim, hidden_size,
115 | bidirectional=bidirectional, use_attn=use_attn, dropout=dropout, dropword=dropword)
116 |
117 | def target_embed_weight(self):
118 | if isinstance(self.core, nn.DataParallel):
119 | return self.core.module.tgt_embed.weight
120 | else:
121 | return self.core.tgt_embed.weight
122 |
123 | @overrides
124 | def forward(self, tgt_sents, tgt_masks, src_enc, src_masks):
125 | return self.core(tgt_sents, tgt_masks, src_enc, src_masks)
126 |
127 | @overrides
128 | def sample(self, tgt_sents: torch.Tensor, tgt_masks: torch.Tensor,
129 | src_enc: torch.Tensor, src_masks: torch.Tensor,
130 | nsamples: int =1, random=True) -> Tuple[torch.Tensor, torch.Tensor]:
131 | mu, logvar = self.core(tgt_sents, tgt_masks, src_enc, src_masks)
132 | z, eps = Posterior.reparameterize(mu, logvar, tgt_masks, nsamples=nsamples, random=random)
133 | log_probs = Posterior.log_probability(z, eps, mu, logvar, tgt_masks)
134 | return z, log_probs
135 |
136 | @overrides
137 | def init(self, tgt_sents, tgt_masks, src_enc, src_masks, init_scale=1.0, init_mu=True, init_var=True) -> Tuple[torch.Tensor, torch.Tensor]:
138 | mu, logvar = self.core.init(tgt_sents, tgt_masks, src_enc, src_masks,
139 | init_scale=init_scale, init_mu=init_mu, init_var=init_var)
140 | z, eps = Posterior.reparameterize(mu, logvar, tgt_masks, random=True)
141 | log_probs = Posterior.log_probability(z, eps, mu, logvar, tgt_masks)
142 | z = z.squeeze(1)
143 | log_probs = log_probs.squeeze(1)
144 | return z, log_probs
145 |
146 | @classmethod
147 | def from_params(cls, params: Dict) -> "ShiftRecurrentPosterior":
148 | return ShiftRecurrentPosterior(**params)
149 |
150 |
151 | ShiftRecurrentPosterior.register('shift_rnn')
152 |
--------------------------------------------------------------------------------
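
ShiftRecurrentCore shifts the forward RNN states one step to the right and the backward states one step to the left before concatenating them, so the concatenated state at position t only summarizes tokens strictly before t (forward direction) and strictly after t (backward direction), never the token at t itself. The shift on a toy tensor (illustrative values):

import torch

batch, length, half = 1, 4, 2
fwd = torch.arange(1., 1. + length * half).view(batch, length, half)   # stand-in forward states
bwd = -torch.arange(1., 1. + length * half).view(batch, length, half)  # stand-in backward states

pad = fwd.new_zeros(batch, 1, half)
fwd_shifted = torch.cat([pad, fwd], dim=1)[:, :-1]  # position t now holds the forward state of t - 1
bwd_shifted = torch.cat([bwd, pad], dim=1)[:, 1:]   # position t now holds the backward state of t + 1
ctx = torch.cat([fwd_shifted, bwd_shifted], dim=2)  # [batch, length, 2 * half]
print(ctx[0])
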
/flownmt/modules/posteriors/transformer.py:
--------------------------------------------------------------------------------
1 | from overrides import overrides
2 | from typing import Tuple, Dict
3 | import math
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | from flownmt.nnet.weightnorm import LinearWeightNorm
9 | from flownmt.nnet.transformer import TransformerDecoderLayer
10 | from flownmt.nnet.positional_encoding import PositionalEncoding
11 | from flownmt.modules.posteriors.posterior import Posterior
12 |
13 |
14 | class TransformerCore(nn.Module):
15 | def __init__(self, embed, num_layers, latent_dim, hidden_size, heads, dropout=0.0, dropword=0.0, max_length=100):
16 | super(TransformerCore, self).__init__()
17 | self.tgt_embed = embed
18 | self.padding_idx = embed.padding_idx
19 |
20 | embed_dim = embed.embedding_dim
21 | self.embed_scale = math.sqrt(embed_dim)
22 | assert embed_dim == latent_dim
23 | layers = [TransformerDecoderLayer(latent_dim, hidden_size, heads, dropout=dropout) for _ in range(num_layers)]
24 | self.layers = nn.ModuleList(layers)
25 | self.pos_enc = PositionalEncoding(latent_dim, self.padding_idx, max_length + 1)
26 | self.dropword = dropword # drop entire tokens
27 | self.mu = LinearWeightNorm(latent_dim, latent_dim, bias=True)
28 | self.logvar = LinearWeightNorm(latent_dim, latent_dim, bias=True)
29 | self.reset_parameters()
30 |
31 | def reset_parameters(self):
32 | pass
33 |
34 | @overrides
35 | def forward(self, tgt_sents, tgt_masks, src_enc, src_masks):
36 | x = self.embed_scale * self.tgt_embed(tgt_sents)
37 | x = F.dropout2d(x, p=self.dropword, training=self.training)
38 | x += self.pos_enc(tgt_sents)
39 | x = F.dropout(x, p=0.2, training=self.training)
40 |
41 | mask = tgt_masks.eq(0)
42 | key_mask = src_masks.eq(0)
43 | for layer in self.layers:
44 | x = layer(x, mask, src_enc, key_mask)
45 |
46 | mu = self.mu(x) * tgt_masks.unsqueeze(2)
47 | logvar = self.logvar(x) * tgt_masks.unsqueeze(2)
48 | return mu, logvar
49 |
50 | def init(self, tgt_sents, tgt_masks, src_enc, src_masks, init_scale=1.0, init_mu=True, init_var=True):
51 | with torch.no_grad():
52 | x = self.embed_scale * self.tgt_embed(tgt_sents)
53 | x = F.dropout2d(x, p=self.dropword, training=self.training)
54 | x += self.pos_enc(tgt_sents)
55 | x = F.dropout(x, p=0.2, training=self.training)
56 |
57 | mask = tgt_masks.eq(0)
58 | key_mask = src_masks.eq(0)
59 | for layer in self.layers:
60 | x = layer.init(x, mask, src_enc, key_mask, init_scale=init_scale)
61 |
62 | x = x * tgt_masks.unsqueeze(2)
63 | mu = self.mu.init(x, init_scale=0.05 * init_scale) if init_mu else self.mu(x)
64 | logvar = self.logvar.init(x, init_scale=0.05 * init_scale) if init_var else self.logvar(x)
65 | mu = mu * tgt_masks.unsqueeze(2)
66 | logvar = logvar * tgt_masks.unsqueeze(2)
67 | return mu, logvar
68 |
69 |
70 | class TransformerPosterior(Posterior):
71 | """
72 | Posterior with Transformer
73 | """
74 | def __init__(self, vocab_size, embed_dim, padding_idx, num_layers, latent_dim, hidden_size, heads,
75 | dropout=0.0, dropword=0.0, max_length=100, _shared_embed=None):
76 | super(TransformerPosterior, self).__init__(vocab_size, embed_dim, padding_idx, _shared_embed=_shared_embed)
77 | self.core = TransformerCore(self.tgt_embed, num_layers, latent_dim, hidden_size, heads,
78 | dropout=dropout, dropword=dropword, max_length=max_length)
79 |
80 | def target_embed_weight(self):
81 | if isinstance(self.core, nn.DataParallel):
82 | return self.core.module.tgt_embed.weight
83 | else:
84 | return self.core.tgt_embed.weight
85 |
86 | @overrides
87 | def forward(self, tgt_sents, tgt_masks, src_enc, src_masks):
88 | return self.core(tgt_sents, tgt_masks, src_enc, src_masks)
89 |
90 | @overrides
91 | def sample(self, tgt_sents: torch.Tensor, tgt_masks: torch.Tensor,
92 | src_enc: torch.Tensor, src_masks: torch.Tensor,
93 | nsamples: int =1, random=True) -> Tuple[torch.Tensor, torch.Tensor]:
94 | mu, logvar = self.core(tgt_sents, tgt_masks, src_enc, src_masks)
95 | z, eps = Posterior.reparameterize(mu, logvar, tgt_masks, nsamples=nsamples, random=random)
96 | log_probs = Posterior.log_probability(z, eps, mu, logvar, tgt_masks)
97 | return z, log_probs
98 |
99 | @overrides
100 | def init(self, tgt_sents, tgt_masks, src_enc, src_masks, init_scale=1.0, init_mu=True, init_var=True) -> Tuple[torch.Tensor, torch.Tensor]:
101 | mu, logvar = self.core.init(tgt_sents, tgt_masks, src_enc, src_masks,
102 | init_scale=init_scale, init_mu=init_mu, init_var=init_var)
103 | z, eps = Posterior.reparameterize(mu, logvar, tgt_masks, random=True)
104 | log_probs = Posterior.log_probability(z, eps, mu, logvar, tgt_masks)
105 | z = z.squeeze(1)
106 | log_probs = log_probs.squeeze(1)
107 | return z, log_probs
108 |
109 | @classmethod
110 | def from_params(cls, params: Dict) -> "TransformerPosterior":
111 | return TransformerPosterior(**params)
112 |
113 |
114 | TransformerPosterior.register('transformer')
115 |
--------------------------------------------------------------------------------
/flownmt/modules/priors/__init__.py:
--------------------------------------------------------------------------------
1 | from flownmt.modules.priors.prior import Prior
2 | from flownmt.modules.priors.length_predictors import *
3 |
--------------------------------------------------------------------------------
/flownmt/modules/priors/length_predictors/__init__.py:
--------------------------------------------------------------------------------
1 | from flownmt.modules.priors.length_predictors.predictor import LengthPredictor
2 | from flownmt.modules.priors.length_predictors.diff_discretized_mix_logistic import DiffDiscreteMixLogisticLengthPredictor
3 | from flownmt.modules.priors.length_predictors.diff_softmax import DiffSoftMaxLengthPredictor
4 |
--------------------------------------------------------------------------------
/flownmt/modules/priors/length_predictors/diff_discretized_mix_logistic.py:
--------------------------------------------------------------------------------
1 | from overrides import overrides
2 | from typing import Dict, Tuple
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from flownmt.modules.priors.length_predictors.predictor import LengthPredictor
8 | from flownmt.modules.priors.length_predictors.utils import discretized_mix_logistic_loss, discretized_mix_logistic_topk
9 |
10 |
11 | class DiffDiscreteMixLogisticLengthPredictor(LengthPredictor):
12 | def __init__(self, features, max_src_length, diff_range, nmix=1, dropout=0.0):
13 | super(DiffDiscreteMixLogisticLengthPredictor, self).__init__()
14 | self.max_src_length = max_src_length
15 | self.range = diff_range
16 | self.nmix = nmix
17 | self.features = features
18 | self.dropout = dropout
19 | self.ctx_proj = None
20 | self.diff = None
21 |
22 | def set_length_unit(self, length_unit):
23 | self.length_unit = length_unit
24 | self.ctx_proj = nn.Sequential(nn.Linear(self.features, self.features), nn.ELU())
25 | self.diff = nn.Linear(self.features, 3 * self.nmix)
26 | self.reset_parameters()
27 |
28 | def reset_parameters(self):
29 | nn.init.constant_(self.ctx_proj[0].bias, 0.)
30 | nn.init.uniform_(self.diff.weight, -0.1, 0.1)
31 | nn.init.constant_(self.diff.bias, 0.)
32 |
33 | def forward(self, ctx):
34 | ctx = F.dropout(self.ctx_proj(ctx), p=self.dropout, training=self.training)
35 | # [batch, 3 * nmix]
36 | coeffs = self.diff(ctx)
37 | # [batch, nmix]
38 | logit_probs = F.log_softmax(coeffs[:, :self.nmix], dim=1)
39 | mu = coeffs[:, self.nmix:self.nmix * 2]
40 | log_scale = coeffs[:, self.nmix * 2:]
41 | return mu, log_scale, logit_probs
42 |
43 | @overrides
44 | def loss(self, ctx: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor) -> torch.Tensor:
45 | """
46 | Args:
47 | ctx: Tensor
48 | tensor [batch, features]
49 | src_mask: Tensor
50 | tensor for source mask [batch, src_length]
51 | tgt_mask: Tensor
52 | tensor for target mask [batch, tgt_length]
53 | Returns: Tensor
54 | tensor for loss [batch]
55 | """
56 | src_lengths = src_mask.sum(dim=1).float()
57 | tgt_lengths = tgt_mask.sum(dim=1).float()
58 | mu, log_scale, logit_probs = self(ctx)
59 | x = (tgt_lengths - src_lengths).div(self.range).clamp(min=-1, max=1)
60 | bin_size = 0.5 / self.range
61 | lower = bin_size - 1.0
62 | upper = 1.0 - bin_size
63 | loss = discretized_mix_logistic_loss(x, mu, log_scale, logit_probs, bin_size, lower, upper)
64 | return loss
65 |
66 | @overrides
67 | def predict(self, ctx: torch.Tensor, src_mask:torch.Tensor, topk: int = 1) -> Tuple[torch.LongTensor, torch.Tensor]:
68 | """
69 | Args:
70 | ctx: Tensor
71 | tensor [batch, features]
72 | src_mask: Tensor
73 | tensor for source mask [batch, src_length]
74 | topk: int (default 1)
75 | return top k length candidates for each src sentence
76 | Returns: LongTensor1, Tensor2
77 | LongTensor1: tensor for lengths [batch, topk]
78 | Tensor2: log probs for each length
79 | """
80 | bin_size = 0.5 / self.range
81 | lower = bin_size - 1.0
82 | upper = 1.0 - bin_size
83 | # [batch]
84 | src_lengths = src_mask.sum(dim=1).long()
85 | mu, log_scale, logit_probs = self(ctx)
86 | # [batch, topk]
87 | log_probs, diffs = discretized_mix_logistic_topk(mu, log_scale, logit_probs,
88 | self.range, bin_size, lower, upper, topk=topk)
89 | lengths = (diffs + src_lengths.unsqueeze(1)).clamp(min=self.length_unit)
90 | res = lengths.fmod(self.length_unit)
91 | padding = (self.length_unit - res).fmod(self.length_unit)
92 | lengths = lengths + padding
93 | return lengths, log_probs
94 |
95 | @classmethod
96 | def from_params(cls, params: Dict) -> 'DiffDiscreteMixLogisticLengthPredictor':
97 | return DiffDiscreteMixLogisticLengthPredictor(**params)
98 |
99 |
100 | DiffDiscreteMixLogisticLengthPredictor.register('diff_logistic')
101 |
--------------------------------------------------------------------------------
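
Predicted lengths must be multiples of length_unit (set by the prior from the number of flow levels), so predict rounds every candidate up to the next multiple using the fmod arithmetic above. The same rounding in isolation, with a hypothetical length_unit of 4:

import torch

length_unit = 4                                    # hypothetical value, for illustration
lengths = torch.tensor([[7, 12], [5, 16]])         # raw length candidates [batch, topk]

res = lengths.fmod(length_unit)                    # remainder w.r.t. the unit
padding = (length_unit - res).fmod(length_unit)    # 0 when already a multiple
lengths = lengths + padding
print(lengths)                                     # [[8, 12], [8, 16]]
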
/flownmt/modules/priors/length_predictors/diff_softmax.py:
--------------------------------------------------------------------------------
1 | from overrides import overrides
2 | from typing import Dict, Tuple
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from flownmt.modules.priors.length_predictors.predictor import LengthPredictor
8 | from flownmt.nnet.criterion import LabelSmoothedCrossEntropyLoss
9 |
10 |
11 | class DiffSoftMaxLengthPredictor(LengthPredictor):
12 | def __init__(self, features, max_src_length, diff_range, dropout=0.0, label_smoothing=0.):
13 | super(DiffSoftMaxLengthPredictor, self).__init__()
14 | self.max_src_length = max_src_length
15 | self.range = diff_range
16 | self.features = features
17 | self.dropout = dropout
18 | self.ctx_proj = None
19 | self.diff = None
20 | if label_smoothing < 1e-5:
21 | self.criterion = nn.CrossEntropyLoss(reduction='none')
22 | elif 1e-5 < label_smoothing < 1.0:
23 | self.criterion = LabelSmoothedCrossEntropyLoss(label_smoothing)
24 | else:
25 | raise ValueError('label smoothing should be less than 1.0.')
26 |
27 | def set_length_unit(self, length_unit):
28 | self.length_unit = length_unit
29 | self.ctx_proj = nn.Sequential(nn.Linear(self.features, self.features), nn.ELU(),
30 | nn.Linear(self.features, self.features), nn.ELU())
31 | self.diff = nn.Linear(self.features, 2 * self.range + 1)
32 | self.reset_parameters()
33 |
34 | def reset_parameters(self):
35 | nn.init.constant_(self.ctx_proj[0].bias, 0.)
36 | nn.init.constant_(self.ctx_proj[2].bias, 0.)
37 | nn.init.uniform_(self.diff.weight, -0.1, 0.1)
38 | nn.init.constant_(self.diff.bias, 0.)
39 |
40 | def forward(self, ctx):
41 | ctx = F.dropout(self.ctx_proj(ctx), p=self.dropout, training=self.training)
42 | return self.diff(ctx)
43 |
44 | @overrides
45 | def loss(self, ctx: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor) -> torch.Tensor:
46 | """
47 | Args:
48 | ctx: Tensor
49 | tensor [batch, features]
50 | src_mask: Tensor
51 | tensor for source mask [batch, src_length]
52 | tgt_mask: Tensor
53 | tensor for target mask [batch, tgt_length]
54 | Returns: Tensor
55 | tensor for loss [batch]
56 | """
57 | # [batch]
58 | src_lengths = src_mask.sum(dim=1).long()
59 | tgt_lengths = tgt_mask.sum(dim=1).long()
60 | # [batch, 2 * range + 1]
61 | logits = self(ctx)
62 | # [1, 2 * range + 1]
63 | mask = torch.arange(0, logits.size(1), device=logits.device).unsqueeze(0)
64 | # [batch, 2 * range + 1]
65 | mask = (mask + src_lengths.unsqueeze(1) - self.range).fmod(self.length_unit).ne(0)
66 | logits = logits.masked_fill(mask, float('-inf'))
67 |
68 | # handle tgt < src - range
69 | x = (tgt_lengths - src_lengths + self.range).clamp(min=0)
70 | tgt = x + src_lengths - self.range
71 | res = tgt.fmod(self.length_unit)
72 | padding = (self.length_unit - res).fmod(self.length_unit)
73 | tgt = tgt + padding
74 | # handle tgt > src + range
75 | x = (tgt - src_lengths + self.range).clamp(max=2 * self.range)
76 | tgt = x + src_lengths - self.range
77 | tgt = tgt - tgt.fmod(self.length_unit)
78 |
79 | x = tgt - src_lengths + self.range
80 | loss_length = self.criterion(logits, x)
81 | return loss_length
82 |
83 | @overrides
84 | def predict(self, ctx: torch.Tensor, src_mask:torch.Tensor, topk: int = 1) -> Tuple[torch.LongTensor, torch.Tensor]:
85 | """
86 | Args:
87 | ctx: Tensor
88 | tensor [batch, features]
89 | src_mask: Tensor
90 | tensor for source mask [batch, src_length]
91 | topk: int (default 1)
92 | return top k length candidates for each src sentence
93 | Returns: LongTensor1, Tensor2
94 | LongTensor1: tensor for lengths [batch, topk]
95 | Tensor2: log probs for each length
96 | """
97 | # [batch]
98 | src_lengths = src_mask.sum(dim=1).long()
99 | # [batch, 2 * range + 1]
100 | logits = self(ctx)
101 | # [1, 2 * range + 1]
102 | x = torch.arange(0, logits.size(1), device=logits.device).unsqueeze(0)
103 | # [batch, 2 * range + 1]
104 | tgt = x + src_lengths.unsqueeze(1) - self.range
105 | mask = tgt.fmod(self.length_unit).ne(0)
106 | logits = logits.masked_fill(mask, float('-inf'))
107 | # [batch, 2 * range + 1]
108 | log_probs = F.log_softmax(logits, dim=1)
109 | # handle tgt length <= 0
110 | mask = tgt.le(0)
111 | log_probs = log_probs.masked_fill(mask, float('-inf'))
112 | # [batch, topk]
113 | log_probs, x = log_probs.topk(topk, dim=1)
114 | lengths = x + src_lengths.unsqueeze(1) - self.range
115 | return lengths, log_probs
116 |
117 | @classmethod
118 | def from_params(cls, params: Dict) -> 'DiffSoftMaxLengthPredictor':
119 | return DiffSoftMaxLengthPredictor(**params)
120 |
121 |
122 | DiffSoftMaxLengthPredictor.register('diff_softmax')
123 |
--------------------------------------------------------------------------------
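
DiffSoftMaxLengthPredictor scores length differences in [-range, +range] around the source length, masks out every candidate whose target length is not a positive multiple of length_unit, and only then takes the softmax / top-k. A compact sketch of the masking with made-up numbers:

import torch
import torch.nn.functional as F

diff_range, length_unit = 3, 2                 # hypothetical settings
src_lengths = torch.tensor([5])                # [batch]
logits = torch.randn(1, 2 * diff_range + 1)    # scores for target lengths 2..8 (toy values)

offsets = torch.arange(0, logits.size(1)).unsqueeze(0)    # [1, 2 * range + 1]
tgt = offsets + src_lengths.unsqueeze(1) - diff_range     # candidate target lengths
invalid_unit = tgt.fmod(length_unit).ne(0)                # not a multiple of length_unit
log_probs = F.log_softmax(logits.masked_fill(invalid_unit, float('-inf')), dim=1)
log_probs = log_probs.masked_fill(tgt.le(0), float('-inf'))   # also rule out non-positive lengths
print(tgt)        # tensor([[2, 3, 4, 5, 6, 7, 8]])
print(log_probs)  # finite only at lengths 2, 4, 6, 8
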
/flownmt/modules/priors/length_predictors/predictor.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Tuple
2 | import torch
3 | import torch.nn as nn
4 |
5 |
6 | class LengthPredictor(nn.Module):
7 | """
8 | Length Predictor
9 | """
10 | _registry = dict()
11 |
12 | def __init__(self):
13 | super(LengthPredictor, self).__init__()
14 | self.length_unit = None
15 |
16 | def set_length_unit(self, length_unit):
17 | self.length_unit = length_unit
18 |
19 | def loss(self, ctx: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor) -> torch.Tensor:
20 | """
21 |
22 | Args:
23 | ctx: Tensor
24 | tensor [batch, features]
25 | src_mask: Tensor
26 | tensor for source mask [batch, src_length]
27 | tgt_mask: Tensor
28 | tensor for target mask [batch, tgt_length]
29 |
30 | Returns: Tensor
31 | tensor for loss [batch]
32 |
33 | """
34 | raise NotImplementedError
35 |
36 | def predict(self, ctx: torch.Tensor, src_mask:torch.Tensor, topk: int = 1) -> Tuple[torch.LongTensor, torch.Tensor]:
37 | """
38 | Args:
39 | ctx: Tensor
40 | tensor [batch, features]
41 | src_mask: Tensor
42 | tensor for source mask [batch, src_length]
43 | topk: int (default 1)
44 | return top k length candidates for each src sentence
45 | Returns: LongTensor1, Tensor2
46 | LongTensor1: tensor for lengths [batch, topk]
47 | Tensor2: log probs for each length
48 | """
49 | raise NotImplementedError
50 |
51 | @classmethod
52 | def register(cls, name: str):
53 | LengthPredictor._registry[name] = cls
54 |
55 | @classmethod
56 | def by_name(cls, name: str):
57 | return LengthPredictor._registry[name]
58 |
59 | @classmethod
60 | def from_params(cls, params: Dict):
61 | raise NotImplementedError
62 |
--------------------------------------------------------------------------------
/flownmt/modules/priors/length_predictors/utils.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 | import numpy as np
3 | import torch
4 | import torch.nn.functional as F
5 |
6 |
7 | def discretized_mix_logistic_loss(x, means, logscales, logit_probs,
8 | bin_size, lower, upper) -> torch.Tensor:
9 | """
10 | loss for discretized mixture logistic distribution
11 | Args:
12 | x: [batch, ]
13 | means: [batch, nmix]
14 | logscales: [batch, nmix]
15 | logit_probs: [batch, nmix]
16 | bin_size: float
17 | The segment for cdf is [x-binsize, x+binsize]
18 | lower: float
19 | upper: float
20 | Returns:
21 | loss [batch]
22 | """
23 | eps = 1e-12
24 | # [batch, 1]
25 | x = x.unsqueeze(1)
26 | # [batch, nmix]
27 | centered_x = x - means
28 | if isinstance(logscales, float):
29 | inv_stdv = np.exp(-logscales)
30 | else:
31 | inv_stdv = torch.exp(-logscales)
32 |
33 | # [batch, nmix]
34 | min_in = inv_stdv * (centered_x - bin_size)
35 | plus_in = inv_stdv * (centered_x + bin_size)
36 | x_in = inv_stdv * centered_x
37 |
38 | # [batch, nmix]
39 | cdf_min = torch.sigmoid(min_in)
40 | cdf_plus = torch.sigmoid(plus_in)
41 | # lower < x < upper
42 | cdf_delta = cdf_plus - cdf_min
43 | log_cdf_mid = torch.log(cdf_delta + eps)
44 | log_cdf_approx = x_in - logscales - 2. * F.softplus(x_in) + np.log(2 * bin_size)
45 |
46 | # x < lower
47 | log_cdf_low = plus_in - F.softplus(plus_in)
48 |
49 | # x > upper
50 | log_cdf_up = -F.softplus(min_in)
51 |
52 | # [batch, nmix]
53 | log_cdf = torch.where(cdf_delta.gt(1e-5), log_cdf_mid, log_cdf_approx)
54 | log_cdf = torch.where(x.ge(lower), log_cdf, log_cdf_low)
55 | log_cdf = torch.where(x.le(upper), log_cdf, log_cdf_up)
56 |
57 | # [batch]
58 | loss = torch.logsumexp(log_cdf + logit_probs, dim=1) * -1.
59 | return loss
60 |
61 |
62 | def discretized_mix_logistic_topk(means, logscales, logit_probs,
63 | range, bin_size, lower, upper, topk=1) -> Tuple[torch.Tensor, torch.LongTensor]:
64 | """
65 | topk for discretized mixture logistic distribution
66 | Args:
67 | means: [batch, nmix]
68 | logscales: [batch, nmix]
69 | logit_probs: [batch, nmix]
70 | range: int
71 | range of x
72 | bin_size: float
73 | The segment for cdf is [x-binsize, x+binsize]
74 | lower: float
75 | upper: float
76 | topk: int
77 | Returns: Tensor1, Tensor2
78 | Tensor1: log probs [batch, topk]
79 | Tensor2: indexes for top k [batch, topk]
80 |
81 | """
82 | eps = 1e-12
83 | # [batch, 1, nmix]
84 | means = means.unsqueeze(1)
85 | logscales = logscales.unsqueeze(1)
86 | logit_probs = logit_probs.unsqueeze(1)
87 | # [1, 2 * range + 1, 1]
88 | x = torch.arange(-range, range + 1, 1., device=means.device).unsqueeze(0).unsqueeze(2)
89 | x = x.div(range)
90 | # [batch, 2 * range + 1, nmix]
91 | centered_x = x - means
92 | if isinstance(logscales, float):
93 | inv_stdv = np.exp(-logscales)
94 | else:
95 | inv_stdv = torch.exp(-logscales)
96 |
97 | # [batch, 2 * range + 1, nmix]
98 | min_in = inv_stdv * (centered_x - bin_size)
99 | plus_in = inv_stdv * (centered_x + bin_size)
100 | x_in = inv_stdv * centered_x
101 |
102 | # [batch, 2 * range + 1, nmix]
103 | cdf_min = torch.sigmoid(min_in)
104 | cdf_plus = torch.sigmoid(plus_in)
105 | # lower < x < upper
106 | cdf_delta = cdf_plus - cdf_min
107 | log_cdf_mid = torch.log(cdf_delta + eps)
108 | log_cdf_approx = x_in - logscales - 2. * F.softplus(x_in) + np.log(2 * bin_size)
109 |
110 | # x < lower
111 | log_cdf_low = plus_in - F.softplus(plus_in)
112 |
113 | # x > upper
114 | log_cdf_up = -F.softplus(min_in)
115 |
116 | # [batch, 2 * range + 1, nmix]
117 | log_cdf = torch.where(cdf_delta.gt(1e-5), log_cdf_mid, log_cdf_approx)
118 | log_cdf = torch.where(x.ge(lower), log_cdf, log_cdf_low)
119 | log_cdf = torch.where(x.le(upper), log_cdf, log_cdf_up)
120 | # [batch, 2 * range + 1]
121 | log_probs = torch.logsumexp(log_cdf + logit_probs, dim=2)
122 | log_probs, idx = log_probs.topk(topk, dim=1)
123 |
124 | return log_probs, idx - range
125 |
--------------------------------------------------------------------------------
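
discretized_mix_logistic_loss integrates each logistic component over a bin of width 2 * bin_size around the target value, i.e. it scores sigmoid((x - mu + bin_size) / s) - sigmoid((x - mu - bin_size) / s), with separate closed forms for the boundary bins and a log-space approximation when the difference underflows. The central-bin case checked numerically for a single component (made-up parameters):

import torch

mu, log_scale = torch.tensor(0.1), torch.tensor(-1.0)   # single mixture component, made-up parameters
bin_size = 0.05
x = torch.tensor(0.2)                                    # normalized length difference in (-1, 1)

inv_stdv = torch.exp(-log_scale)
cdf_plus = torch.sigmoid(inv_stdv * (x - mu + bin_size))
cdf_min = torch.sigmoid(inv_stdv * (x - mu - bin_size))
bin_prob = cdf_plus - cdf_min   # probability mass of the bin around x; its negative log is the per-component loss
print(bin_prob)
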
/flownmt/modules/priors/prior.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Dict, Tuple, Union
3 | import torch
4 | import torch.nn as nn
5 |
6 | from flownmt.flows.nmt import NMTFlow
7 | from flownmt.modules.priors.length_predictors import LengthPredictor
8 |
9 |
10 | class Prior(nn.Module):
11 | """
12 | class for Prior with an NMTFlow inside
13 | """
14 | _registry = dict()
15 |
16 | def __init__(self, flow: NMTFlow, length_predictor: LengthPredictor):
17 | super(Prior, self).__init__()
18 | assert flow.inverse, 'prior flow should have inverse mode'
19 | self.flow = flow
20 | self.length_unit = max(2, 2 ** (self.flow.levels - 1))
21 | self.features = self.flow.features
22 | self._length_predictor = length_predictor
23 | self._length_predictor.set_length_unit(self.length_unit)
24 |
25 | def sync(self):
26 | self.flow.sync()
27 |
28 | def predict_length(self, ctx: torch.Tensor, src_mask: torch.Tensor, topk: int = 1) -> Tuple[torch.LongTensor, torch.Tensor]:
29 | """
30 | Args:
31 | ctx: Tensor
32 | tensor [batch, features]
33 | src_mask: Tensor
34 | tensor for source mask [batch, src_length]
35 | topk: int (default 1)
36 | return top k length candidates for each src sentence
37 | Returns: LongTensor1, Tensor2
38 | LongTensor1: tensor for lengths [batch, topk]
39 | Tensor2: log probs for each length [batch, topk]
40 | """
41 | return self._length_predictor.predict(ctx, src_mask, topk=topk)
42 |
43 | def length_loss(self, ctx: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor) -> torch.Tensor:
44 | """
45 |
46 | Args:
47 | ctx: Tensor
48 | tensor [batch, features]
49 | src_mask: Tensor
50 | tensor for source mask [batch, src_length]
51 | tgt_mask: Tensor
52 | tensor for target mask [batch, tgt_length]
53 |
54 | Returns: Tensor
55 | tensor for loss [batch]
56 |
57 | """
58 | return self._length_predictor.loss(ctx, src_mask, tgt_mask)
59 |
60 | def decode(self, epsilon: torch.Tensor, tgt_mask: torch.Tensor,
61 | src: torch.Tensor, src_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
62 | """
63 |
64 | Args:
65 | epsilon: Tensor
66 | epsilon [batch, tgt_length, nz]
67 | tgt_mask: Tensor
68 | tensor of target masks [batch, tgt_length]
69 | src: Tensor
70 | source encoding [batch, src_length, hidden_size]
71 | src_mask: Tensor
72 | tensor of source masks [batch, src_length]
73 |
74 | Returns: Tensor1, Tensor2
75 | Tensor1: decoded latent code z [batch, tgt_length, nz]
76 | Tensor2: log probabilities [batch]
77 |
78 | """
79 | # [batch, tgt_length, nz]
80 | z, logdet = self.flow.fwdpass(epsilon, tgt_mask, src, src_mask)
81 | # [batch, tgt_length, nz]
82 | log_probs = epsilon.mul(epsilon) + math.log(math.pi * 2.0)
83 | # apply mask
84 | log_probs = log_probs.mul(tgt_mask.unsqueeze(2))
85 | # [batch]
86 | log_probs = log_probs.view(z.size(0), -1).sum(dim=1).mul(-0.5) + logdet
87 | return z, log_probs
88 |
89 | def sample(self, nlengths: int, nsamples: int, src: torch.Tensor,
90 | ctx: torch.Tensor, src_mask: torch.Tensor,
91 | tau=0.0, include_zero=False) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
92 | """
93 |
94 | Args:
95 | nlengths: int
96 | number of lengths per sentence
97 | nsamples: int
98 | number of samples per sentence per length
99 | src: Tensor
100 | source encoding [batch, src_length, hidden_size]
101 | ctx: Tensor
102 | tensor for global state [batch, hidden_size]
103 | src_mask: Tensor
104 | tensor of masks [batch, src_length]
105 | tau: float (default 0.0)
106 | temperature of density
107 | include_zero: bool (default False)
108 | include zero sample
109 |
110 | Returns: (Tensor1, Tensor2, Tensor3), (Tensor4, Tensor5), (Tensor6, Tensor7, Tensor8)
111 | Tensor1: samples from the prior [batch * nlengths * nsamples, tgt_length, nz]
112 | Tensor2: log probabilities [batch * nlengths * nsamples]
113 | Tensor3: target masks [batch * nlengths * nsamples, tgt_length]
114 | Tensor4: lengths [batch * nlengths]
115 | Tensor5: log probabilities of lengths [batch * nlengths]
116 | Tensor6: source encoding with shape [batch * nlengths * nsamples, src_length, hidden_size]
117 | Tensor7: tensor for global state [batch * nlengths * nsamples, hidden_size]
118 | Tensor8: source masks with shape [batch * nlengths * nsamples, src_length]
119 |
120 | """
121 | batch = src.size(0)
122 | batch_nlen = batch * nlengths
123 | # [batch, nlengths]
124 | lengths, log_probs_length = self.predict_length(ctx, src_mask, topk=nlengths)
125 | # [batch * nlengths]
126 | log_probs_length = log_probs_length.view(-1)
127 | lengths = lengths.view(-1)
128 | max_length = lengths.max().item()
129 | # [batch * nlengths, max_length]
130 | tgt_mask = torch.arange(max_length).to(src.device).unsqueeze(0).expand(batch_nlen, max_length).lt(lengths.unsqueeze(1)).float()
131 |
132 | # [batch * nlengths, nsamples, tgt_length, nz]
133 | epsilon = src.new_empty(batch_nlen, nsamples, max_length, self.features).normal_()
134 | epsilon = epsilon.mul(tgt_mask.view(batch_nlen, 1, max_length, 1)) * tau
135 | if include_zero:
136 | epsilon[:, 0].zero_()
137 | # [batch * nlengths * nsamples, tgt_length, nz]
138 | epsilon = epsilon.view(-1, max_length, self.features)
139 | if nsamples * nlengths > 1:
140 | # [batch, nlengths * nsamples, src_length, hidden_size]
141 | src = src.unsqueeze(1) + src.new_zeros(batch, nlengths * nsamples, *src.size()[1:])
142 | # [batch * nlengths * nsamples, src_length, hidden_size]
143 | src = src.view(batch_nlen * nsamples, *src.size()[2:])
144 | # [batch, nlengths * nsamples, hidden_size]
145 | ctx = ctx.unsqueeze(1) + ctx.new_zeros(batch, nlengths * nsamples, ctx.size(1))
146 | # [batch * nlengths * nsamples, hidden_size]
147 | ctx = ctx.view(batch_nlen * nsamples, ctx.size(2))
148 | # [batch, nlengths * nsamples, src_length]
149 | src_mask = src_mask.unsqueeze(1) + src_mask.new_zeros(batch, nlengths * nsamples, src_mask.size(1))
150 | # [batch * nlengths * nsamples, src_length]
151 | src_mask = src_mask.view(batch_nlen * nsamples, src_mask.size(2))
152 | # [batch * nlengths, nsamples, tgt_length]
153 | tgt_mask = tgt_mask.unsqueeze(1) + tgt_mask.new_zeros(batch_nlen, nsamples, tgt_mask.size(1))
154 | # [batch * nlengths * nsamples, tgt_length]
155 | tgt_mask = tgt_mask.view(batch_nlen * nsamples, tgt_mask.size(2))
156 |
157 | # [batch * nlengths * nsamples, tgt_length, nz]
158 | z, log_probs = self.decode(epsilon, tgt_mask, src, src_mask)
159 | return (z, log_probs, tgt_mask), (lengths, log_probs_length), (src, ctx, src_mask)
160 |
161 | def log_probability(self, z: torch.Tensor, tgt_mask: torch.Tensor,
162 | src: torch.Tensor, ctx: torch.Tensor, src_mask: torch.Tensor,
163 | length_loss: bool = True) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
164 | """
165 |
166 | Args:
167 | z: Tensor
168 | tensor of latent code [batch, length, nz]
169 | tgt_mask: Tensor
170 | tensor of target masks [batch, length]
171 | src: Tensor
172 | source encoding [batch, src_length, hidden_size]
173 | ctx: Tensor
174 | tensor for global state [batch, hidden_size]
175 | src_mask: Tensor
176 | tensor of source masks [batch, src_length]
177 | length_loss: bool (default True)
178 | compute loss of length
179 |
180 | Returns: Tensor1, Tensor2
181 | Tensor1: log probabilities of z [batch]
182 | Tensor2: length loss [batch]
183 |
184 | """
185 | # [batch]
186 | loss_length = self.length_loss(ctx, src_mask, tgt_mask) if length_loss else None
187 |
188 | # [batch, length, nz]
189 | epsilon, logdet = self.flow.bwdpass(z, tgt_mask, src, src_mask)
190 | # [batch, tgt_length, nz]
191 | log_probs = epsilon.mul(epsilon) + math.log(math.pi * 2.0)
192 | # apply mask
193 | log_probs = log_probs.mul(tgt_mask.unsqueeze(2))
194 | log_probs = log_probs.view(z.size(0), -1).sum(dim=1).mul(-0.5) + logdet
195 | return log_probs, loss_length
196 |
197 | def init(self, z, tgt_mask, src, src_mask, init_scale=1.0):
198 | return self.flow.bwdpass(z, tgt_mask, src, src_mask, init=True, init_scale=init_scale)
199 |
200 | @classmethod
201 | def register(cls, name: str):
202 | Prior._registry[name] = cls
203 |
204 | @classmethod
205 | def by_name(cls, name: str):
206 | return Prior._registry[name]
207 |
208 | @classmethod
209 | def from_params(cls, params: Dict) -> "Prior":
210 | flow_params = params.pop('flow')
211 | flow = NMTFlow.from_params(flow_params)
212 | predictor_params = params.pop('length_predictor')
213 | length_predictor = LengthPredictor.by_name(predictor_params.pop('type')).from_params(predictor_params)
214 | return Prior(flow, length_predictor)
215 |
216 |
217 | Prior.register('normal')
218 |
--------------------------------------------------------------------------------
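
Prior.decode and Prior.log_probability both score epsilon under a standard Gaussian on the non-padded positions and add the flow's log-determinant, i.e. log p(z) = -0.5 * sum over valid positions of (eps^2 + log 2*pi) + logdet, where eps is the flow's inverse image of z. The masked Gaussian term on its own (toy tensors, with a zero stand-in for logdet):

import math
import torch

batch, tgt_length, nz = 2, 5, 3
epsilon = torch.randn(batch, tgt_length, nz)
tgt_mask = torch.tensor([[1., 1., 1., 0., 0.],
                         [1., 1., 1., 1., 1.]])
logdet = torch.zeros(batch)                      # stand-in for the flow's log-determinant

log_probs = epsilon.mul(epsilon) + math.log(2.0 * math.pi)   # elementwise eps^2 + log(2*pi)
log_probs = log_probs.mul(tgt_mask.unsqueeze(2))             # zero out padded positions
log_probs = log_probs.view(batch, -1).sum(dim=1).mul(-0.5) + logdet
print(log_probs.shape)                                       # torch.Size([2])
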
/flownmt/nnet/__init__.py:
--------------------------------------------------------------------------------
1 | from flownmt.nnet.weightnorm import LinearWeightNorm, Conv1dWeightNorm
2 | from flownmt.nnet.attention import GlobalAttention, MultiHeadAttention, PositionwiseFeedForward
3 | from flownmt.nnet.transformer import TransformerEncoderLayer, TransformerDecoderLayer
4 | from flownmt.nnet.layer_norm import LayerNorm
5 | from flownmt.nnet.positional_encoding import PositionalEncoding
6 | from flownmt.nnet.criterion import LabelSmoothedCrossEntropyLoss
7 |
--------------------------------------------------------------------------------
/flownmt/nnet/attention.py:
--------------------------------------------------------------------------------
1 | from overrides import overrides
2 | import torch
3 | from torch.nn import Parameter
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from flownmt.nnet.layer_norm import LayerNorm
8 |
9 |
10 | class GlobalAttention(nn.Module):
11 | """
12 | Global Attention between encoder and decoder
13 | """
14 |
15 | def __init__(self, key_features, query_features, value_features, hidden_features=None, dropout=0.0):
16 | """
17 |
18 | Args:
19 | key_features: int
20 | dimension of keys
21 | query_features: int
22 | dimension of queries
23 | value_features: int
24 | dimension of values (outputs)
25 | hidden_features: int
26 | dimension of hidden states (default value_features)
27 | dropout: float
28 | dropout rate
29 | """
30 | super(GlobalAttention, self).__init__()
31 | if hidden_features is None:
32 | hidden_features = value_features
33 | self.key_proj = nn.Linear(key_features, 2 * hidden_features, bias=True)
34 | self.query_proj = nn.Linear(query_features, hidden_features, bias=True)
35 | self.dropout = dropout
36 | self.fc = nn.Linear(hidden_features, value_features)
37 | self.hidden_features = hidden_features
38 | self.reset_parameters()
39 |
40 | def reset_parameters(self):
41 | # key proj
42 | nn.init.xavier_uniform_(self.key_proj.weight)
43 | nn.init.constant_(self.key_proj.bias, 0)
44 | # query proj
45 | nn.init.xavier_uniform_(self.query_proj.weight)
46 | nn.init.constant_(self.query_proj.bias, 0)
47 | # fc
48 | nn.init.xavier_uniform_(self.fc.weight)
49 | nn.init.constant_(self.fc.bias, 0)
50 |
51 | @overrides
52 | def forward(self, query, key, key_mask=None):
53 | """
54 |
55 | Args:
56 | query: Tensor
57 | query tensor [batch, query_length, query_features]
58 | key: Tensor
59 | key tensor [batch, key_length, key_features]
60 | key_mask: ByteTensor or None
61 |                 binary ByteTensor [batch, src_len], where padding elements are indicated by 1s.
62 |
63 | Returns: Tensor
64 | value tensor [batch, query_length, value_features]
65 |
66 | """
67 | bs, timesteps, _ = key.size()
68 | dim = self.hidden_features
69 | # [batch, query_length, dim]
70 | query = self.query_proj(query)
71 |
72 | # [batch, key_length, 2 * dim]
73 | c = self.key_proj(key)
74 | # [batch, key_length, 2, dim]
75 | c = c.view(bs, timesteps, 2, dim)
76 | # [batch, key_length, dim]
77 | key = c[:, :, 0]
78 | value = c[:, :, 1]
79 |
80 | # attention weights [batch, query_length, key_length]
81 | attn_weights = torch.bmm(query, key.transpose(1, 2))
82 | if key_mask is not None:
83 | attn_weights = attn_weights.masked_fill(key_mask.unsqueeze(1), float('-inf'))
84 |
85 | attn_weights = F.softmax(attn_weights.float(), dim=-1,
86 | dtype=torch.float32 if attn_weights.dtype == torch.float16 else attn_weights.dtype)
87 |
88 | # values [batch, query_length, dim]
89 | out = torch.bmm(attn_weights, value)
90 | out = F.dropout(self.fc(out), p=self.dropout, training=self.training)
91 | return out
92 |
93 | def init(self, query, key, key_mask=None, init_scale=1.0):
94 | with torch.no_grad():
95 | return self(query, key, key_mask=key_mask)
96 |
97 |
98 | class MultiHeadAttention(nn.Module):
99 | """
100 | Multi-head Attention
101 | """
102 | def __init__(self, model_dim, heads, dropout=0.0, mask_diag=False):
103 | """
104 |
105 | Args:
106 | model_dim: int
107 | the input dimension for keys, queries and values
108 | heads: int
109 | number of heads
110 | dropout: float
111 | dropout rate
112 | """
113 | super(MultiHeadAttention, self).__init__()
114 | self.model_dim = model_dim
115 | self.head_dim = model_dim // heads
116 | self.heads = heads
117 | self.dropout = dropout
118 | self.mask_diag = mask_diag
119 | assert self.head_dim * heads == self.model_dim, "model_dim must be divisible by number of heads"
120 | self.scaling = self.head_dim ** -0.5
121 | self.in_proj_weight = Parameter(torch.empty(3 * model_dim, model_dim))
122 | self.in_proj_bias = Parameter(torch.empty(3 * model_dim))
123 | self.layer_norm = LayerNorm(model_dim)
124 | self.reset_parameters()
125 |
126 | def reset_parameters(self):
127 | # in proj
128 | nn.init.xavier_uniform_(self.in_proj_weight[:self.model_dim, :])
129 | nn.init.xavier_uniform_(self.in_proj_weight[self.model_dim:(self.model_dim * 2), :])
130 | nn.init.xavier_uniform_(self.in_proj_weight[(self.model_dim * 2):, :])
131 | nn.init.constant_(self.in_proj_bias, 0.)
132 |
133 | def forward(self, query, key, value, key_mask=None):
134 | """
135 |
136 | Args:
137 |             query: Tensor
138 | [batch, tgt_len, model_dim]
139 | key: Tensor
140 | [batch, src_len, model_dim]
141 | value: Tensor
142 | [batch, src_len, model_dim]
143 | key_mask: ByteTensor or None
144 |                 binary ByteTensor [batch, src_len], where padding elements are indicated by 1s.
145 |
146 |         Returns: Tensor
147 |             output tensor [batch, tgt_len, model_dim]
148 | """
149 | qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr()
150 | kv_same = key.data_ptr() == value.data_ptr()
151 |
152 | bs, src_len, model_dim = key.size()
153 | tgt_len = query.size(1)
154 | heads = self.heads
155 | residual = query
156 |
157 | # k, v: [bs, src_len, model_dim]
158 | # q: [bs, tgt_len, model_dim]
159 | if qkv_same:
160 | # self-attention
161 | q, k, v = self._in_proj_qkv(query)
162 | elif kv_same:
163 | # encoder-decoder attention
164 | q = self._in_proj_q(query)
165 | k, v = self._in_proj_kv(key)
166 | else:
167 | q = self._in_proj_q(query)
168 | k = self._in_proj_k(key)
169 | v = self._in_proj_v(value)
170 | q *= self.scaling
171 |
172 | model_dim = q.size(2)
173 | dim = model_dim // heads
174 |
175 | # [len, batch, model_dim] -> [len, batch * heads, dim] -> [batch * heads, len, dim]
176 | q = q.transpose(0, 1).contiguous().view(tgt_len, bs * heads, dim).transpose(0, 1)
177 | k = k.transpose(0, 1).contiguous().view(src_len, bs * heads, dim).transpose(0, 1)
178 | v = v.transpose(0, 1).contiguous().view(src_len, bs * heads, dim).transpose(0, 1)
179 |
180 | # attention weights [batch * heads, tgt_len, src_len]
181 | attn_weights = torch.bmm(q, k.transpose(1, 2))
182 | if key_mask is not None:
183 | attn_weights = attn_weights.view(bs, heads, tgt_len, src_len)
184 | attn_weights = attn_weights.masked_fill(key_mask.unsqueeze(1).unsqueeze(2), float('-inf'))
185 | attn_weights = attn_weights.view(bs * heads, tgt_len, src_len)
186 |
187 | if self.mask_diag:
188 | assert tgt_len == src_len
189 | # [1, tgt_len, tgt_len]
190 | diag_mask = torch.eye(tgt_len, device=query.device, dtype=torch.uint8).unsqueeze(0)
191 | attn_weights = attn_weights.masked_fill(diag_mask, float('-inf'))
192 |
193 | attn_weights = F.softmax(attn_weights.float(), dim=-1,
194 | dtype=torch.float32 if attn_weights.dtype == torch.float16 else attn_weights.dtype)
195 |
196 |
197 | # outputs [batch * heads, tgt_len, dim]
198 | out = torch.bmm(attn_weights, v)
199 | # merge heads
200 | # [batch, heads, tgt_len, dim] -> [batch, tgt_len, heads, dim]
201 | # -> [batch, tgt_len, model_dim]
202 | out = out.view(bs, heads, tgt_len, dim).transpose(1, 2).contiguous().view(bs, tgt_len, model_dim)
203 | out = F.dropout(out, p=self.dropout, training=self.training)
204 | out = self.layer_norm(out + residual)
205 | return out
206 |
207 | def init(self, query, key, value, key_mask=None, init_scale=1.0):
208 | with torch.no_grad():
209 | return self(query, key, value, key_mask=key_mask)
210 |
211 | def _in_proj_qkv(self, query):
212 | return self._in_proj(query).chunk(3, dim=-1)
213 |
214 | def _in_proj_kv(self, key):
215 | return self._in_proj(key, start=self.model_dim).chunk(2, dim=-1)
216 |
217 | def _in_proj_q(self, query):
218 | return self._in_proj(query, end=self.model_dim)
219 |
220 | def _in_proj_k(self, key):
221 | return self._in_proj(key, start=self.model_dim, end=2 * self.model_dim)
222 |
223 | def _in_proj_v(self, value):
224 | return self._in_proj(value, start=2 * self.model_dim)
225 |
226 | def _in_proj(self, input, start=0, end=None):
227 | weight = self.in_proj_weight
228 | bias = self.in_proj_bias
229 | weight = weight[start:end, :]
230 | if bias is not None:
231 | bias = bias[start:end]
232 | return F.linear(input, weight, bias)
233 |
234 |
235 | class PositionwiseFeedForward(nn.Module):
236 | def __init__(self, features, hidden_features, dropout=0.0):
237 | super(PositionwiseFeedForward, self).__init__()
238 | self.linear1 = nn.Linear(features, hidden_features)
239 | self.dropout = dropout
240 | self.linear2 = nn.Linear(hidden_features, features)
241 | self.layer_norm = LayerNorm(features)
242 |
243 | def forward(self, x):
244 | residual = x
245 | x = F.relu(self.linear1(x), inplace=True)
246 | x = F.dropout(x, p=self.dropout, training=self.training)
247 | x = F.dropout(self.linear2(x), p=self.dropout, training=self.training)
248 | x = self.layer_norm(residual + x)
249 | return x
250 |
251 | def init(self, x, init_scale=1.0):
252 | with torch.no_grad():
253 | return self(x)
254 |
--------------------------------------------------------------------------------
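
Usage sketch for the two attention modules above (toy tensor sizes are hypothetical; assumes the flownmt package is importable). Key masks follow the docstring convention: a boolean/byte mask over key positions, with 1/True marking padding.

    import torch
    from flownmt.nnet.attention import GlobalAttention, MultiHeadAttention

    batch, src_len, tgt_len, dim = 2, 5, 4, 8
    key_mask = torch.zeros(batch, src_len, dtype=torch.bool)
    key_mask[:, -1] = True  # mark the last source position as padding

    # Encoder-decoder style global attention.
    global_attn = GlobalAttention(key_features=dim, query_features=dim, value_features=dim)
    out = global_attn(torch.randn(batch, tgt_len, dim), torch.randn(batch, src_len, dim), key_mask=key_mask)
    print(out.shape)  # [batch, tgt_len, value_features]

    # Multi-head self-attention (query, key and value are the same tensor,
    # so the fused qkv projection path is used).
    mha = MultiHeadAttention(model_dim=dim, heads=2)
    x = torch.randn(batch, src_len, dim)
    print(mha(x, x, x, key_mask=key_mask).shape)  # [batch, src_len, model_dim]
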
/flownmt/nnet/criterion.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | import torch.nn as nn
3 |
4 |
5 | class LabelSmoothedCrossEntropyLoss(nn.Module):
6 | """
7 | Cross Entropy loss with label smoothing.
8 | For training, the loss is smoothed with parameter eps, while for evaluation, the smoothing is disabled.
9 | """
10 | def __init__(self, label_smoothing):
11 | super(LabelSmoothedCrossEntropyLoss, self).__init__()
12 | self.eps = label_smoothing
13 |
14 | def forward(self, input, target):
15 | # [batch, c, d1, ..., dk]
16 | loss = F.log_softmax(input, dim=1) * -1.
17 | # [batch, d1, ..., dk]
18 | nll_loss = loss.gather(dim=1, index=target.unsqueeze(1)).squeeze(1)
19 | if self.training:
20 | # [batch, c, d1, ..., dk]
21 | inf_mask = loss.eq(float('inf'))
22 | # [batch, d1, ..., dk]
23 | smooth_loss = loss.masked_fill(inf_mask, 0.).sum(dim=1)
24 | eps_i = self.eps / (1.0 - inf_mask.float()).sum(dim=1)
25 | return nll_loss * (1. - self.eps) + smooth_loss * eps_i
26 | else:
27 | return nll_loss
28 |
--------------------------------------------------------------------------------
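
A short sketch of how the loss above is called (hypothetical toy shapes, assuming the flownmt package is importable): it expects unnormalized logits with the class dimension at dim=1 and integer targets, and the label-smoothing term is only active in training mode.

    import torch
    from flownmt.nnet.criterion import LabelSmoothedCrossEntropyLoss

    criterion = LabelSmoothedCrossEntropyLoss(label_smoothing=0.1)
    logits = torch.randn(2, 5, 3)          # [batch, classes, length]
    target = torch.randint(0, 5, (2, 3))   # [batch, length]

    criterion.train()
    print(criterion(logits, target).shape)  # smoothed loss, [batch, length]
    criterion.eval()
    print(criterion(logits, target).shape)  # plain NLL, [batch, length]
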
/flownmt/nnet/layer_norm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
6 | if not export and torch.cuda.is_available():
7 | try:
8 | from apex.normalization import FusedLayerNorm
9 | return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
10 | except ImportError:
11 | pass
12 | return nn.LayerNorm(normalized_shape, eps, elementwise_affine)
13 |
--------------------------------------------------------------------------------
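
This factory returns apex's FusedLayerNorm when CUDA and apex are available, and otherwise falls back to torch.nn.LayerNorm, so callers can treat it as a drop-in layer. A minimal sketch (hypothetical sizes):

    import torch
    from flownmt.nnet.layer_norm import LayerNorm

    ln = LayerNorm(16)                       # FusedLayerNorm if apex + CUDA, else nn.LayerNorm
    print(ln(torch.randn(2, 4, 16)).shape)   # [2, 4, 16]
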
/flownmt/nnet/positional_encoding.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 |
5 | from flownmt.utils import make_positions
6 |
7 |
8 | class PositionalEncoding(nn.Module):
9 | """This module produces sinusoidal positional embeddings of any length.
10 | Padding symbols are ignored.
11 | """
12 |
13 | def __init__(self, encoding_dim, padding_idx, init_size=1024):
14 | super().__init__()
15 | self.encoding_dim = encoding_dim
16 | self.padding_idx = padding_idx
17 | self.weights = PositionalEncoding.get_embedding(
18 | init_size,
19 | encoding_dim,
20 | padding_idx,
21 | )
22 | self.register_buffer('_float_tensor', torch.FloatTensor(1))
23 |
24 | @staticmethod
25 | def get_embedding(num_encodings, encoding_dim, padding_idx=None):
26 | """Build sinusoidal embeddings.
27 | This matches the implementation in tensor2tensor, but differs slightly
28 | from the description in Section 3.5 of "Attention Is All You Need".
29 | """
30 | half_dim = encoding_dim // 2
31 | emb = math.log(10000) / (half_dim - 1)
32 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
33 | emb = torch.arange(num_encodings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
34 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_encodings, -1)
35 | if encoding_dim % 2 == 1:
36 | # zero pad
37 | emb = torch.cat([emb, torch.zeros(num_encodings, 1)], dim=1)
38 | emb[0, :] = 0
39 | return emb
40 |
41 | def forward(self, x):
42 | """Input is expected to be of size [bsz x seqlen]."""
43 | bsz, seq_len = x.size()[:2]
44 | max_pos = seq_len + 1
45 | if self.weights is None or max_pos > self.weights.size(0):
46 | # recompute/expand embeddings if needed
47 | self.weights = PositionalEncoding.get_embedding(
48 | max_pos,
49 |                 self.encoding_dim,
50 | self.padding_idx,
51 | )
52 | self.weights = self.weights.type_as(self._float_tensor)
53 |
54 | if self.padding_idx is None:
55 | return self.weights[1:seq_len + 1].detach()
56 | else:
57 | positions = make_positions(x, self.padding_idx)
58 | return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()
59 |
60 | def max_positions(self):
61 | """Maximum number of supported positions."""
62 | return int(1e5) # an arbitrary large number
63 |
--------------------------------------------------------------------------------
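
Usage sketch for the positional encoding above (hypothetical token ids; assumes the flownmt package is importable): positions are derived from the padding index via make_positions, so padded symbols map to the all-zero row 0 of the embedding table.

    import torch
    from flownmt.nnet.positional_encoding import PositionalEncoding

    pe = PositionalEncoding(encoding_dim=16, padding_idx=0, init_size=32)
    tokens = torch.tensor([[5, 7, 9, 0],
                           [3, 4, 0, 0]])    # 0 is padding
    print(pe(tokens).shape)                  # [batch, seq_len, encoding_dim]
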
/flownmt/nnet/transformer.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from flownmt.nnet.attention import MultiHeadAttention, PositionwiseFeedForward
4 |
5 |
6 | class TransformerEncoderLayer(nn.Module):
7 | def __init__(self, model_dim, hidden_dim, heads, dropout=0.0, mask_diag=False):
8 | super(TransformerEncoderLayer, self).__init__()
9 | self.slf_attn = MultiHeadAttention(model_dim, heads, dropout=dropout, mask_diag=mask_diag)
10 | self.pos_ffn = PositionwiseFeedForward(model_dim, hidden_dim, dropout=dropout)
11 |
12 | def forward(self, x, mask):
13 | out = self.slf_attn(x, x, x, key_mask=mask)
14 | out = self.pos_ffn(out)
15 | return out
16 |
17 | def init(self, x, mask, init_scale=1.0):
18 | out = self.slf_attn.init(x, x, x, key_mask=mask, init_scale=init_scale)
19 | out = self.pos_ffn.init(out, init_scale=init_scale)
20 | return out
21 |
22 |
23 | class TransformerDecoderLayer(nn.Module):
24 | def __init__(self, model_dim, hidden_dim, heads, dropout=0.0, mask_diag=False):
25 | super(TransformerDecoderLayer, self).__init__()
26 | self.slf_attn = MultiHeadAttention(model_dim, heads, dropout=dropout, mask_diag=mask_diag)
27 | self.enc_attn = MultiHeadAttention(model_dim, heads, dropout=dropout)
28 | self.pos_ffn = PositionwiseFeedForward(model_dim, hidden_dim, dropout=dropout)
29 |
30 | def forward(self, x, mask, src, src_mask):
31 | out = self.slf_attn(x, x, x, key_mask=mask)
32 | out = self.enc_attn(out, src, src, key_mask=src_mask)
33 | out = self.pos_ffn(out)
34 | return out
35 |
36 | def init(self, x, mask, src, src_mask, init_scale=1.0):
37 | out = self.slf_attn.init(x, x, x, key_mask=mask, init_scale=init_scale)
38 | out = self.enc_attn.init(out, src, src, key_mask=src_mask, init_scale=init_scale)
39 | out = self.pos_ffn.init(out, init_scale=init_scale)
40 | return out
41 |
--------------------------------------------------------------------------------
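
A usage sketch for stacking these layers (toy sizes are hypothetical; assumes the flownmt package is importable). Masks mark padding positions with 1/True, as in the attention modules above.

    import torch
    from flownmt.nnet.transformer import TransformerEncoderLayer, TransformerDecoderLayer

    batch, src_len, tgt_len, dim = 2, 6, 4, 8
    src_mask = torch.zeros(batch, src_len, dtype=torch.bool)
    tgt_mask = torch.zeros(batch, tgt_len, dtype=torch.bool)

    enc_layer = TransformerEncoderLayer(model_dim=dim, hidden_dim=16, heads=2)
    src = enc_layer(torch.randn(batch, src_len, dim), src_mask)

    dec_layer = TransformerDecoderLayer(model_dim=dim, hidden_dim=16, heads=2)
    out = dec_layer(torch.randn(batch, tgt_len, dim), tgt_mask, src, src_mask)
    print(out.shape)  # [batch, tgt_len, model_dim]
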
/flownmt/nnet/weightnorm.py:
--------------------------------------------------------------------------------
1 | from overrides import overrides
2 | import torch
3 | import torch.nn as nn
4 |
5 |
6 | class LinearWeightNorm(nn.Module):
7 | """
8 | Linear with weight normalization
9 | """
10 | def __init__(self, in_features, out_features, bias=True):
11 | super(LinearWeightNorm, self).__init__()
12 | self.linear = nn.Linear(in_features, out_features, bias=bias)
13 | self.reset_parameters()
14 |
15 | def reset_parameters(self):
16 | nn.init.normal_(self.linear.weight, mean=0.0, std=0.05)
17 | if self.linear.bias is not None:
18 | nn.init.constant_(self.linear.bias, 0)
19 | self.linear = nn.utils.weight_norm(self.linear)
20 |
21 | def extra_repr(self):
22 | return 'in_features={}, out_features={}, bias={}'.format(
23 |             self.linear.in_features, self.linear.out_features, self.linear.bias is not None
24 | )
25 |
26 | def init(self, x, init_scale=1.0):
27 | with torch.no_grad():
28 | # [batch, out_features]
29 | out = self(x).view(-1, self.linear.out_features)
30 | # [out_features]
31 | mean = out.mean(dim=0)
32 | std = out.std(dim=0)
33 | inv_stdv = init_scale / (std + 1e-6)
34 |
35 | self.linear.weight_g.mul_(inv_stdv.unsqueeze(1))
36 | if self.linear.bias is not None:
37 | self.linear.bias.add_(-mean).mul_(inv_stdv)
38 | return self(x)
39 |
40 | def forward(self, input):
41 | return self.linear(input)
42 |
43 |
44 | class Conv1dWeightNorm(nn.Module):
45 | """
46 | Conv1d with weight normalization
47 | """
48 | def __init__(self, in_channels, out_channels, kernel_size, stride=1,
49 | padding=0, dilation=1, groups=1, bias=True):
50 | super(Conv1dWeightNorm, self).__init__()
51 | self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride=stride,
52 | padding=padding, dilation=dilation, groups=groups, bias=bias)
53 | self.reset_parameters()
54 |
55 | def reset_parameters(self):
56 | nn.init.normal_(self.conv.weight, mean=0.0, std=0.05)
57 | if self.conv.bias is not None:
58 | nn.init.constant_(self.conv.bias, 0)
59 | self.conv = nn.utils.weight_norm(self.conv)
60 |
61 | def init(self, x, init_scale=1.0):
62 | with torch.no_grad():
63 | # [batch, n_channels, L]
64 | out = self(x)
65 | n_channels = out.size(1)
66 | out = out.transpose(0, 1).contiguous().view(n_channels, -1)
67 | # [n_channels]
68 | mean = out.mean(dim=1)
69 | std = out.std(dim=1)
70 | inv_stdv = init_scale / (std + 1e-6)
71 |
72 | self.conv.weight_g.mul_(inv_stdv.view(n_channels, 1, 1))
73 | if self.conv.bias is not None:
74 | self.conv.bias.add_(-mean).mul_(inv_stdv)
75 | return self(x)
76 |
77 | def forward(self, input):
78 | return self.conv(input)
79 |
80 | @overrides
81 | def extra_repr(self):
82 | return self.conv.extra_repr()
83 |
--------------------------------------------------------------------------------
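
Both wrappers expose an init() method for data-dependent initialization: on top of the standard weight-norm reparameterization, init() rescales weight_g and shifts the bias so that the outputs on a calibration batch have roughly zero mean and standard deviation init_scale. A sketch with hypothetical toy shapes:

    import torch
    from flownmt.nnet.weightnorm import LinearWeightNorm, Conv1dWeightNorm

    linear = LinearWeightNorm(in_features=8, out_features=4)
    x = torch.randn(64, 8)                       # hypothetical calibration batch
    out = linear.init(x, init_scale=1.0)
    print(out.mean().item(), out.std().item())   # roughly 0 and 1

    conv = Conv1dWeightNorm(in_channels=3, out_channels=6, kernel_size=3, padding=1)
    print(conv.init(torch.randn(64, 3, 20), init_scale=1.0).shape)  # [64, 6, 20]
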
/flownmt/optim/__init__.py:
--------------------------------------------------------------------------------
1 | from flownmt.optim.adamw import AdamW
2 | from flownmt.optim.lr_scheduler import InverseSquareRootScheduler, ExponentialScheduler
3 |
--------------------------------------------------------------------------------
/flownmt/optim/adamw.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch.optim.optimizer import Optimizer
4 |
5 |
6 | class AdamW(Optimizer):
7 | r"""Implements AdamW algorithm.
8 | This implementation is modified from torch.optim.Adam based on:
9 |     `Decoupled Weight Decay Regularization`
10 | (see https://arxiv.org/abs/1711.05101)
11 |
12 | Adam has been proposed in `Adam: A Method for Stochastic Optimization`_.
13 |
14 | Arguments:
15 | params (iterable): iterable of parameters to optimize or dicts defining
16 | parameter groups
17 | lr (float, optional): learning rate (default: 1e-3)
18 | betas (Tuple[float, float], optional): coefficients used for computing
19 | running averages of gradient and its square (default: (0.9, 0.999))
20 | eps (float, optional): term added to the denominator to improve
21 | numerical stability (default: 1e-8)
22 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
23 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this
24 | algorithm from the paper `On the Convergence of Adam and Beyond`_
25 | (default: False)
26 |
27 | .. _Adam\: A Method for Stochastic Optimization:
28 | https://arxiv.org/abs/1412.6980
29 | .. _On the Convergence of Adam and Beyond:
30 | https://openreview.net/forum?id=ryQu7f-RZ
31 | """
32 |
33 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
34 | weight_decay=0, amsgrad=False):
35 | if not 0.0 <= lr:
36 | raise ValueError("Invalid learning rate: {}".format(lr))
37 | if not 0.0 <= eps:
38 | raise ValueError("Invalid epsilon value: {}".format(eps))
39 | if not 0.0 <= betas[0] < 1.0:
40 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
41 | if not 0.0 <= betas[1] < 1.0:
42 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
43 | defaults = dict(lr=lr, betas=betas, eps=eps,
44 | weight_decay=weight_decay, amsgrad=amsgrad)
45 | super(AdamW, self).__init__(params, defaults)
46 |
47 | def __setstate__(self, state):
48 | super(AdamW, self).__setstate__(state)
49 | for group in self.param_groups:
50 | group.setdefault('amsgrad', False)
51 |
52 | def step(self, closure=None):
53 | """Performs a single optimization step.
54 |
55 | Arguments:
56 | closure (callable, optional): A closure that reevaluates the model
57 | and returns the loss.
58 | """
59 | loss = None
60 | if closure is not None:
61 | loss = closure()
62 |
63 | for group in self.param_groups:
64 | for p in group['params']:
65 | if p.grad is None:
66 | continue
67 | grad = p.grad.data
68 | if grad.is_sparse:
69 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
70 | amsgrad = group['amsgrad']
71 |
72 | state = self.state[p]
73 |
74 | # State initialization
75 | if len(state) == 0:
76 | state['step'] = 0
77 | # Exponential moving average of gradient values
78 | state['exp_avg'] = torch.zeros_like(p.data)
79 | # Exponential moving average of squared gradient values
80 | state['exp_avg_sq'] = torch.zeros_like(p.data)
81 | if amsgrad:
82 | # Maintains max of all exp. moving avg. of sq. grad. values
83 | state['max_exp_avg_sq'] = torch.zeros_like(p.data)
84 |
85 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
86 | if amsgrad:
87 | max_exp_avg_sq = state['max_exp_avg_sq']
88 | beta1, beta2 = group['betas']
89 |
90 | state['step'] += 1
91 |
92 | # Decay the first and second moment running average coefficient
93 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
94 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
95 | if amsgrad:
96 | # Maintains the maximum of all 2nd moment running avg. till now
97 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
98 | # Use the max. for normalizing running avg. of gradient
99 | denom = max_exp_avg_sq.sqrt().add_(group['eps'])
100 | else:
101 | denom = exp_avg_sq.sqrt().add_(group['eps'])
102 |
103 | bias_correction1 = 1 - beta1 ** state['step']
104 | bias_correction2 = 1 - beta2 ** state['step']
105 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
106 |
107 | if group['weight_decay'] != 0:
108 | p.data.add_(-group['weight_decay'] * group['lr'], p.data)
109 |
110 | p.data.addcdiv_(-step_size, exp_avg, denom)
111 |
112 | return loss
113 |
--------------------------------------------------------------------------------
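
Usage sketch (hypothetical model and hyper-parameters; assumes the PyTorch version this repository targets, since step() uses the older add_/addcdiv_ call signatures). The distinguishing detail of this AdamW is that weight decay is applied directly to the parameters, decoupled from the adaptive gradient update.

    import torch
    import torch.nn as nn
    from flownmt.optim import AdamW

    model = nn.Linear(4, 2)
    optimizer = AdamW(model.parameters(), lr=1e-3, betas=(0.9, 0.999),
                      eps=1e-8, weight_decay=1e-4, amsgrad=False)

    loss = model(torch.randn(8, 4)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()  # decay is p <- p - lr * weight_decay * p, not added to the gradient
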
/flownmt/optim/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | from torch.optim.optimizer import Optimizer
2 |
3 |
4 | class _LRScheduler(object):
5 | def __init__(self, optimizer, last_epoch=-1):
6 | if not isinstance(optimizer, Optimizer):
7 | raise TypeError('{} is not an Optimizer'.format(
8 | type(optimizer).__name__))
9 | self.optimizer = optimizer
10 | if last_epoch == -1:
11 | for group in optimizer.param_groups:
12 | group.setdefault('initial_lr', group['lr'])
13 | last_epoch = 0
14 | else:
15 | for i, group in enumerate(optimizer.param_groups):
16 | if 'initial_lr' not in group:
17 | raise KeyError("param 'initial_lr' is not specified "
18 | "in param_groups[{}] when resuming an optimizer".format(i))
19 | self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
20 |
21 | def state_dict(self):
22 | """Returns the state of the scheduler as a :class:`dict`.
23 |
24 | It contains an entry for every variable in self.__dict__ which
25 | is not the optimizer.
26 | """
27 | return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
28 |
29 | def load_state_dict(self, state_dict):
30 | """Loads the schedulers state.
31 |
32 | Arguments:
33 | state_dict (dict): scheduler state. Should be an object returned
34 | from a call to :meth:`state_dict`.
35 | """
36 | self.__dict__.update(state_dict)
37 |
38 | def get_lr(self):
39 | raise NotImplementedError
40 |
41 | def step(self, epoch=None):
42 | if epoch is None:
43 | epoch = self.last_epoch + 1
44 | self.last_epoch = epoch
45 | for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
46 | param_group['lr'] = lr
47 |
48 |
49 | class InverseSquareRootScheduler(_LRScheduler):
50 | """
51 | Decay the LR based on the inverse square root of the update number.
52 |     We also support a warmup phase where we linearly increase the learning rate
53 |     from ``init_lr`` to the configured base learning rate.
54 |     Thereafter we decay proportional to the inverse square root of the number of
55 |     updates, with a decay factor set to align with the configured learning rate.
56 |     During warmup::
57 |         lrs = torch.linspace(init_lr, base_lr, warmup_steps)
58 |         lr = lrs[step]
59 |     After warmup::
60 |         decay_factor = base_lr * sqrt(warmup_steps)
61 |         lr = decay_factor / sqrt(step)
62 | """
63 | def __init__(self, optimizer, warmup_steps, init_lr, last_epoch=-1):
64 | assert warmup_steps > 0, 'warmup steps should be larger than 0.'
65 | super(InverseSquareRootScheduler, self).__init__(optimizer, last_epoch)
66 | self.warmup_steps = float(warmup_steps)
67 | self.init_lr = init_lr
68 | self.lr_steps = [(base_lr - init_lr) / warmup_steps for base_lr in self.base_lrs]
69 | self.decay_factor = self.warmup_steps ** 0.5
70 | if last_epoch == -1:
71 | last_epoch = 0
72 | self.step(last_epoch)
73 |
74 | def get_lr(self):
75 | if self.last_epoch < self.warmup_steps:
76 | return [self.init_lr + lr_step * self.last_epoch for lr_step in self.lr_steps]
77 | else:
78 | lr_factor = self.decay_factor * self.last_epoch**-0.5
79 | return [base_lr * lr_factor for base_lr in self.base_lrs]
80 |
81 |
82 | class ExponentialScheduler(_LRScheduler):
83 | """Set the learning rate of each parameter group to the initial lr decayed
84 | by gamma every epoch. When last_epoch=-1, sets initial lr as lr.
85 |     We also support a warmup phase where we linearly increase the learning rate
86 |     from ``init_lr`` to the configured base learning rate.
87 | Args:
88 | optimizer (Optimizer): Wrapped optimizer.
89 | gamma (float): Multiplicative factor of learning rate decay.
90 |         warmup_steps (int): Warmup steps.
91 | last_epoch (int): The index of last epoch. Default: -1.
92 | """
93 |
94 | def __init__(self, optimizer, gamma, warmup_steps, init_lr, last_epoch=-1):
95 | super(ExponentialScheduler, self).__init__(optimizer, last_epoch)
96 | self.gamma = gamma
97 | # handle warmup <= 0
98 | self.warmup_steps = max(1, warmup_steps)
99 | self.init_lr = init_lr
100 | self.lr_steps = [(base_lr - init_lr) / self.warmup_steps for base_lr in self.base_lrs]
101 | if last_epoch == -1:
102 | last_epoch = 0
103 | self.step(last_epoch)
104 |
105 | def get_lr(self):
106 | if self.last_epoch < self.warmup_steps:
107 | return [self.init_lr + lr_step * self.last_epoch for lr_step in self.lr_steps]
108 | else:
109 | lr_factor = self.gamma ** (self.last_epoch - self.warmup_steps)
110 | return [base_lr * lr_factor for base_lr in self.base_lrs]
111 |
--------------------------------------------------------------------------------
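
A sketch of the warmup/decay behaviour (hypothetical numbers; assumes the flownmt package is importable): the learning rate increases linearly from init_lr to the base lr over warmup_steps updates and then decays as 1/sqrt(step), with the two branches meeting at step == warmup_steps.

    import torch.nn as nn
    from flownmt.optim import AdamW, InverseSquareRootScheduler

    optimizer = AdamW(nn.Linear(4, 2).parameters(), lr=1e-3)
    scheduler = InverseSquareRootScheduler(optimizer, warmup_steps=4, init_lr=0.0)

    for step in range(1, 9):
        scheduler.step(step)
        print(step, optimizer.param_groups[0]['lr'])  # linear up to 1e-3, then 1e-3 * 2 / sqrt(step)
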
/flownmt/utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import sys
3 | from typing import Tuple, List
4 | import torch
5 | from torch._six import inf
6 |
7 |
8 | def get_logger(name, level=logging.INFO, handler=sys.stdout,
9 | formatter='%(asctime)s - %(name)s - %(levelname)s - %(message)s'):
10 | logger = logging.getLogger(name)
11 |     logger.setLevel(level)
12 | formatter = logging.Formatter(formatter)
13 | stream_handler = logging.StreamHandler(handler)
14 | stream_handler.setLevel(level)
15 | stream_handler.setFormatter(formatter)
16 | logger.addHandler(stream_handler)
17 | return logger
18 |
19 |
20 | def norm(p: torch.Tensor, dim: int):
21 | """Computes the norm over all dimensions except dim"""
22 | if dim is None:
23 | return p.norm()
24 | elif dim == 0:
25 | output_size = (p.size(0),) + (1,) * (p.dim() - 1)
26 | return p.contiguous().view(p.size(0), -1).norm(dim=1).view(*output_size)
27 | elif dim == p.dim() - 1:
28 | output_size = (1,) * (p.dim() - 1) + (p.size(-1),)
29 | return p.contiguous().view(-1, p.size(-1)).norm(dim=0).view(*output_size)
30 | else:
31 | return norm(p.transpose(0, dim), 0).transpose(0, dim)
32 |
33 |
34 | def exponentialMovingAverage(original, shadow, decay_rate, init=False):
35 | params = dict()
36 | for name, param in shadow.named_parameters():
37 | params[name] = param
38 | for name, param in original.named_parameters():
39 | shadow_param = params[name]
40 | if init:
41 | shadow_param.data.copy_(param.data)
42 | else:
43 | shadow_param.data.add_((1 - decay_rate) * (param.data - shadow_param.data))
44 |
45 |
46 | def logPlusOne(x):
47 | """
48 | compute log(x + 1) for small x
49 | Args:
50 | x: Tensor
51 | Returns: Tensor
52 | log(x+1)
53 | """
54 | eps = 1e-4
55 | mask = x.abs().le(eps).type_as(x)
56 | return x.mul(x.mul(-0.5) + 1.0) * mask + (x + 1.0).log() * (1.0 - mask)
57 |
58 |
59 | def gate(x1, x2):
60 | return x1 * x2.sigmoid_()
61 |
62 |
63 | def total_grad_norm(parameters, norm_type=2):
64 | if isinstance(parameters, torch.Tensor):
65 | parameters = [parameters]
66 | parameters = list(filter(lambda p: p.grad is not None, parameters))
67 | norm_type = float(norm_type)
68 | if norm_type == inf:
69 | total_norm = max(p.grad.data.abs().max() for p in parameters)
70 | else:
71 | total_norm = 0
72 | for p in parameters:
73 | param_norm = p.grad.data.norm(norm_type)
74 | total_norm += param_norm.item() ** norm_type
75 | total_norm = total_norm ** (1. / norm_type)
76 | return total_norm
77 |
78 |
79 | def squeeze(x: torch.Tensor, mask: torch.Tensor, factor: int = 2) -> Tuple[torch.Tensor, torch.Tensor]:
80 | """
81 | Args:
82 | x: Tensor
83 | input tensor [batch, length, features]
84 | mask: Tensor
85 | mask tensor [batch, length]
86 | factor: int
87 | squeeze factor (default 2)
88 | Returns: Tensor1, Tensor2
89 | squeezed x [batch, length // factor, factor * features]
90 | squeezed mask [batch, length // factor]
91 | """
92 | assert factor >= 1
93 | if factor == 1:
94 |         return x, mask
95 |
96 | batch, length, features = x.size()
97 | assert length % factor == 0
98 | # [batch, length // factor, factor * features]
99 | x = x.contiguous().view(batch, length // factor, factor * features)
100 | mask = mask.view(batch, length // factor, factor).sum(dim=2).clamp(max=1.0)
101 | return x, mask
102 |
103 |
104 | def unsqueeze(x: torch.Tensor, factor: int = 2) -> torch.Tensor:
105 | """
106 | Args:
107 | x: Tensor
108 | input tensor [batch, length, features]
109 | factor: int
110 | unsqueeze factor (default 2)
111 | Returns: Tensor
112 |         unsqueezed tensor [batch, length * factor, features // factor]
113 | """
114 | assert factor >= 1
115 | if factor == 1:
116 | return x
117 |
118 | batch, length, features = x.size()
119 | assert features % factor == 0
120 | # [batch, length * factor, features // factor]
121 | x = x.view(batch, length * factor, features // factor)
122 | return x
123 |
124 |
125 | def split(x: torch.Tensor, z1_features) -> Tuple[torch.Tensor, torch.Tensor]:
126 | """
127 | Args:
128 | x: Tensor
129 | input tensor [batch, length, features]
130 | z1_features: int
131 | the number of features of z1
132 | Returns: Tensor, Tensor
133 | split tensors [batch, length, z1_features], [batch, length, features-z1_features]
134 | """
135 | z1 = x[:, :, :z1_features]
136 | z2 = x[:, :, z1_features:]
137 | return z1, z2
138 |
139 |
140 | def unsplit(xs: List[torch.Tensor]) -> torch.Tensor:
141 | """
142 | Args:
143 | xs: List[Tensor]
144 | tensors to be combined
145 | Returns: Tensor
146 | combined tensor
147 | """
148 | return torch.cat(xs, dim=2)
149 |
150 |
151 | def make_positions(tensor, padding_idx):
152 | """Replace non-padding symbols with their position numbers.
153 | Position numbers begin at padding_idx+1. Padding symbols are ignored.
154 | """
155 | mask = tensor.ne(padding_idx).long()
156 | return torch.cumsum(mask, dim=1) * mask
157 |
158 |
159 | # def prepare_rnn_seq(rnn_input, lengths, batch_first=False):
160 | # '''
161 | # Args:
162 | # rnn_input: [seq_len, batch, input_size]: tensor containing the features of the input sequence.
163 | #             lengths: [batch]: tensor containing the lengths of the input sequence
164 | # batch_first: If True, then the input and output tensors are provided as [batch, seq_len, feature].
165 | # Returns:
166 | # '''
167 | #
168 | # def check_decreasing(lengths):
169 | # lens, order = torch.sort(lengths, dim=0, descending=True)
170 | # if torch.ne(lens, lengths).sum() == 0:
171 | # return None
172 | # else:
173 | # _, rev_order = torch.sort(order)
174 | # return lens, order, rev_order
175 | #
176 | # check_res = check_decreasing(lengths)
177 | #
178 | # if check_res is None:
179 | # lens = lengths
180 | # rev_order = None
181 | # else:
182 | # lens, order, rev_order = check_res
183 | # batch_dim = 0 if batch_first else 1
184 | # rnn_input = rnn_input.index_select(batch_dim, order)
185 | # lens = lens.tolist()
186 | # seq = pack_padded_sequence(rnn_input, lens, batch_first=batch_first)
187 | # return seq, rev_order
188 | #
189 | # def recover_rnn_seq(seq, rev_order, batch_first=False, total_length=None):
190 | # output, _ = pad_packed_sequence(seq, batch_first=batch_first, total_length=total_length)
191 | # if rev_order is not None:
192 | # batch_dim = 0 if batch_first else 1
193 | # output = output.index_select(batch_dim, rev_order)
194 | # return output
195 | #
196 | #
197 | # def recover_order(tensors, rev_order):
198 | # if rev_order is None:
199 | # return tensors
200 | # recovered_tensors = [tensor.index_select(0, rev_order) for tensor in tensors]
201 | # return recovered_tensors
202 | #
203 | #
204 | # def decreasing_order(lengths, tensors):
205 | # def check_decreasing(lengths):
206 | # lens, order = torch.sort(lengths, dim=0, descending=True)
207 | # if torch.ne(lens, lengths).sum() == 0:
208 | # return None
209 | # else:
210 | # _, rev_order = torch.sort(order)
211 | # return lens, order, rev_order
212 | #
213 | # check_res = check_decreasing(lengths)
214 | #
215 | # if check_res is None:
216 | # lens = lengths
217 | # rev_order = None
218 | # ordered_tensors = tensors
219 | # else:
220 | # lens, order, rev_order = check_res
221 | # ordered_tensors = [tensor.index_select(0, order) for tensor in tensors]
222 | #
223 | # return lens, ordered_tensors, rev_order
224 |
--------------------------------------------------------------------------------
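
The squeeze/unsqueeze and split/unsplit helpers above are exact inverses of each other when the shape constraints hold (length divisible by the squeeze factor, features divisible for unsqueeze). A round-trip sketch with hypothetical toy tensors, assuming flownmt.utils imports cleanly in the targeted PyTorch version (it pulls inf from torch._six):

    import torch
    from flownmt.utils import squeeze, unsqueeze, split, unsplit

    x = torch.randn(2, 6, 4)
    mask = torch.ones(2, 6)

    sq_x, sq_mask = squeeze(x, mask, factor=2)
    print(sq_x.shape, sq_mask.shape)                   # [2, 3, 8], [2, 3]
    print(torch.equal(unsqueeze(sq_x, factor=2), x))   # True

    z1, z2 = split(x, z1_features=3)
    print(torch.equal(unsplit([z1, z2]), x))           # True
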
/images/flowseq_diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XuezheMax/flowseq/8cb4ae00c26fbeb3e1459e3b3b90e7e9a84c3d2b/images/flowseq_diagram.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # python >= 3.6
2 | torch
3 | numpy
4 | overrides
4 |
--------------------------------------------------------------------------------