├── .gitignore ├── LICENSE ├── README.md ├── experiments ├── configs │ ├── iwslt14 │ │ ├── config-rnn.json │ │ └── config-transformer.json │ ├── wmt14 │ │ ├── config-rnn.json │ │ ├── config-transformer-base.json │ │ └── config-transformer-large.json │ └── wmt16 │ │ ├── config-rnn.json │ │ ├── config-transformer-base.json │ │ └── config-transformer-large.json ├── distributed.py ├── nmt.py ├── options.py ├── scripts │ └── multi-bleu.perl ├── slurm.py └── translate.py ├── flownmt ├── __init__.py ├── data │ ├── __init__.py │ └── dataloader.py ├── flownmt.py ├── flows │ ├── __init__.py │ ├── actnorm.py │ ├── couplings │ │ ├── __init__.py │ │ ├── blocks.py │ │ ├── coupling.py │ │ └── transform.py │ ├── flow.py │ ├── linear.py │ ├── nmt.py │ └── parallel │ │ ├── __init__.py │ │ ├── data_parallel.py │ │ └── parallel_apply.py ├── modules │ ├── __init__.py │ ├── decoders │ │ ├── __init__.py │ │ ├── decoder.py │ │ ├── rnn.py │ │ ├── simple.py │ │ └── transformer.py │ ├── encoders │ │ ├── __init__.py │ │ ├── encoder.py │ │ ├── rnn.py │ │ └── transformer.py │ ├── posteriors │ │ ├── __init__.py │ │ ├── posterior.py │ │ ├── rnn.py │ │ ├── shift_rnn.py │ │ └── transformer.py │ └── priors │ │ ├── __init__.py │ │ ├── length_predictors │ │ ├── __init__.py │ │ ├── diff_discretized_mix_logistic.py │ │ ├── diff_softmax.py │ │ ├── predictor.py │ │ └── utils.py │ │ └── prior.py ├── nnet │ ├── __init__.py │ ├── attention.py │ ├── criterion.py │ ├── layer_norm.py │ ├── positional_encoding.py │ ├── transformer.py │ └── weightnorm.py ├── optim │ ├── __init__.py │ ├── adamw.py │ └── lr_scheduler.py └── utils.py ├── images └── flowseq_diagram.png └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # data dir 107 | /experiments/data/ 108 | 109 | # model dir 110 | /experiments/models/ 111 | 112 | # log dir 113 | /experiments/log/ 114 | 115 | # test dir 116 | /tests/ 117 | 118 | # IDE 119 | /.idea/ 120 | *.iml 121 | *.sublime-project 122 | *.sublime-workspace 123 | 124 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # FlowSeq: a Generative Flow based Sequence-to-Sequence Toolkit
2 | This is the PyTorch implementation of [FlowSeq: Non-Autoregressive Conditional Sequence Generation with Generative Flow](http://arxiv.org/abs/1909.02480), accepted by EMNLP 2019.
3 |
4 |

5 | ![FlowSeq model diagram](images/flowseq_diagram.png)
6 |

7 |
8 | We propose an efficient and effective model for non-autoregressive sequence generation using latent variable models.
9 | We model the complex distributions with generative flows, and design
10 | several layers of flow tailored for modeling the conditional density of sequential latent variables.
11 | On several machine translation benchmark datasets (WMT14 EN-DE, WMT16 EN-RO), we achieve performance comparable
12 | to state-of-the-art non-autoregressive NMT models, with almost constant decoding time w.r.t. the sequence length.
13 |
14 | ## Requirements
15 | * Python version >= 3.6
16 | * PyTorch version >= 1.1
17 | * apex
18 | * Perl
19 |
20 | ## Installation
21 | 1. Install [NVIDIA-apex](https://github.com/NVIDIA/apex).
22 | 2. Install [PyTorch and torchvision](https://pytorch.org/get-started/locally/).
23 |
24 | ## Data
25 | 1. WMT'14 English to German (EN-DE) can be obtained with the scripts provided in [fairseq](https://github.com/pytorch/fairseq/blob/master/examples/translation/README.md#wmt14-english-to-german-convolutional).
26 | 2. WMT'16 English to Romanian (EN-RO) can be obtained from [here](https://github.com/nyu-dl/dl4mt-nonauto#downloading-datasets--pre-trained-models).
27 |
28 | ## Training a new model
29 | The MT datasets should be named in the format ``train.{language code}``, ``dev.{language code}``, ``test.{language code}``, e.g. ``train.de``.
30 | Assuming the WMT14-ENDE data sets are placed under ``data/wmt14-ende/real-bpe/``, FlowSeq can be trained on this data on one node with the
31 | following script:
32 | ```bash
33 | cd experiments
34 |
35 | python -u distributed.py \
36 | --nnodes 1 --node_rank 0 --nproc_per_node <num gpus per node> --master_addr <master address> \
37 | --master_port <master port> \
38 | --config configs/wmt14/config-transformer-base.json --model_path <model path> \
39 | --data_path data/wmt14-ende/real-bpe/ \
40 | --batch_size 2048 --batch_steps 1 --init_batch_size 512 --eval_batch_size 32 \
41 | --src en --tgt de \
42 | --lr 0.0005 --beta1 0.9 --beta2 0.999 --eps 1e-8 --grad_clip 1.0 --amsgrad \
43 | --lr_decay 'expo' --weight_decay 0.001 \
44 | --init_steps 30000 --kl_warmup_steps 10000 \
45 | --subword 'joint-bpe' --bucket_batch 1 --create_vocab
46 | ```
47 | After training, the `<model path>` directory will contain the saved checkpoint `model.pt`, `config.json`, `log.txt`,
48 | the `vocab` directory, and intermediate translation results under the `translations` directory.
49 |
50 | #### Note:
51 | - The argument `--batch_steps` is used for gradient accumulation to trade speed for memory. The size of each segment of a data batch is batch_size / (num_gpus * batch_steps); e.g., `--batch_size 2048` on one node with 8 GPUs and `--batch_steps 1` gives segments of size 2048 / (8 * 1) = 256.
52 | - To train FlowSeq on multiple nodes, we provide a script for the Slurm cluster environment, `experiments/slurm.py`; alternatively, please
53 | refer to the PyTorch distributed parallel training [tutorial](https://pytorch.org/tutorials/intermediate/dist_tuto.html).
54 | - To create a distillation dataset, please use [fairseq](https://github.com/pytorch/fairseq/blob/master/examples/translation/README.md#neural-machine-translation) to train a Transformer model
55 | and translate the source data set.
56 |
57 | ## Translation and evaluation
58 | ```bash
59 | cd experiments
60 |
61 | python -u translate.py \
62 | --model_path <model path> \
63 | --data_path data/wmt14-ende/real-bpe/ \
64 | --batch_size 32 --bucket_batch 1 \
65 | --decode {'argmax', 'iw', 'sample'} \
66 | --tau 0.0 --nlen 3 --ntr 1
67 | ```
68 | Please check the details of the arguments [here](https://github.com/XuezheMax/flowseq/blob/master/experiments/options.py#L48).
69 |
70 | To keep the output translations in the original order of the input test data, use `--bucket_batch 0`.
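For a quick manual check of a finished run, the translation files written by `translate.py` can be re-scored with the bundled Moses script `scripts/multi-bleu.perl` (the same script `translate.py` calls internally). A minimal sketch, assuming the output was produced with `--bucket_batch 0` so it is in the original test order and uses the same tokenization as the reference; both paths are placeholders:

```bash
# Illustrative paths only: the reference is the target side of the test set, and the
# hypothesis is a file produced by translate.py under <model path>/translations/.
# multi-bleu.perl takes the reference as an argument and reads the hypothesis from stdin.
perl scripts/multi-bleu.perl data/wmt14-ende/real-bpe/test.de < translations/argmax.t0.0.ntr1.test.mt
```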
71 | 72 | ## References 73 | ``` 74 | @inproceedings{flowseq2019, 75 | title = {FlowSeq: Non-Autoregressive Conditional Sequence Generation with Generative Flow}, 76 | author = {Ma, Xuezhe and Zhou, Chunting and Li, Xian and Neubig, Graham and Hovy, Eduard}, 77 | booktitle = {Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing}, 78 | address = {Hong Kong}, 79 | month = {November}, 80 | year = {2019} 81 | } 82 | ``` 83 | -------------------------------------------------------------------------------- /experiments/configs/iwslt14/config-rnn.json: -------------------------------------------------------------------------------- 1 | { 2 | "de_vocab_size": 32012, 3 | "en_vocab_size": 22825, 4 | "max_de_length": 64, 5 | "max_en_length": 64, 6 | "share_embed": false, 7 | "tie_weights": false, 8 | "embed_dim": 256, 9 | "latent_dim": 256, 10 | "hidden_size": 256, 11 | "prior": { 12 | "type": "normal", 13 | "length_predictor": { 14 | "type": "diff_softmax", 15 | "diff_range": 16, 16 | "dropout": 0.33, 17 | "label_smoothing": 0.1 18 | }, 19 | "flow": { 20 | "levels": 3, 21 | "num_steps": [4, 2, 2], 22 | "factors": [2, 2], 23 | "hidden_features": 256, 24 | "transform": "affine", 25 | "coupling_type": "rnn", 26 | "rnn_mode": "LSTM", 27 | "dropout": 0.0, 28 | "inverse": true 29 | } 30 | }, 31 | "encoder": { 32 | "type": "rnn", 33 | "rnn_mode": "LSTM", 34 | "num_layers": 2 35 | }, 36 | "posterior": { 37 | "type": "rnn", 38 | "rnn_mode": "LSTM", 39 | "num_layers": 1, 40 | "use_attn": true, 41 | "dropout": 0.33, 42 | "dropword": 0.2 43 | }, 44 | "decoder": { 45 | "type": "rnn", 46 | "rnn_mode": "LSTM", 47 | "num_layers": 1, 48 | "dropout": 0.33, 49 | "dropword": 0.2, 50 | "label_smoothing": 0.0 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /experiments/configs/iwslt14/config-transformer.json: -------------------------------------------------------------------------------- 1 | { 2 | "de_vocab_size": 32012, 3 | "en_vocab_size": 22825, 4 | "max_de_length": 64, 5 | "max_en_length": 64, 6 | "share_embed": false, 7 | "tie_weights": false, 8 | "embed_dim": 256, 9 | "latent_dim": 256, 10 | "hidden_size": 1024, 11 | "prior": { 12 | "type": "normal", 13 | "length_predictor": { 14 | "type": "diff_softmax", 15 | "diff_range": 16, 16 | "dropout": 0.33, 17 | "label_smoothing": 0.1 18 | }, 19 | "flow": { 20 | "levels": 3, 21 | "num_steps": [4, 4, 2], 22 | "factors": [2, 2], 23 | "hidden_features": 512, 24 | "transform": "affine", 25 | "coupling_type": "self_attn", 26 | "heads": 4, 27 | "pos_enc": "attn", 28 | "max_length": 200, 29 | "dropout": 0.0, 30 | "inverse": true 31 | } 32 | }, 33 | "encoder": { 34 | "type": "transformer", 35 | "num_layers": 5, 36 | "heads": 4, 37 | "max_length": 200, 38 | "dropout": 0.2 39 | }, 40 | "posterior": { 41 | "type": "transformer", 42 | "num_layers": 3, 43 | "heads": 4, 44 | "max_length": 200, 45 | "dropout": 0.2, 46 | "dropword": 0.2 47 | }, 48 | "decoder": { 49 | "type": "transformer", 50 | "num_layers": 3, 51 | "heads": 4, 52 | "max_length": 200, 53 | "dropout": 0.2, 54 | "dropword": 0.0, 55 | "label_smoothing": 0.1 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /experiments/configs/wmt14/config-rnn.json: -------------------------------------------------------------------------------- 1 | { 2 | "de_vocab_size": 37000, 3 | "en_vocab_size": 37000, 4 | "max_de_length": 80, 5 | "max_en_length": 80, 6 | "share_embed": true, 7 | "tie_weights": true, 8 
| "embed_dim": 256, 9 | "latent_dim": 256, 10 | "hidden_size": 512, 11 | "prior": { 12 | "type": "normal", 13 | "length_predictor": { 14 | "type": "diff_softmax", 15 | "diff_range": 20, 16 | "dropout": 0.33, 17 | "label_smoothing": 0.1 18 | }, 19 | "flow": { 20 | "levels": 3, 21 | "num_steps": [4, 4, 2], 22 | "factors": [2, 2], 23 | "hidden_features": 512, 24 | "transform": "affine", 25 | "coupling_type": "rnn", 26 | "rnn_mode": "LSTM", 27 | "dropout": 0.0, 28 | "inverse": true 29 | } 30 | }, 31 | "encoder": { 32 | "type": "rnn", 33 | "rnn_mode": "LSTM", 34 | "num_layers": 3 35 | }, 36 | "posterior": { 37 | "type": "rnn", 38 | "rnn_mode": "LSTM", 39 | "num_layers": 2, 40 | "use_attn": true, 41 | "dropout": 0.33, 42 | "dropword": 0.2 43 | }, 44 | "decoder": { 45 | "type": "rnn", 46 | "rnn_mode": "LSTM", 47 | "num_layers": 2, 48 | "dropout": 0.33, 49 | "dropword": 0.2, 50 | "label_smoothing": 0.0 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /experiments/configs/wmt14/config-transformer-base.json: -------------------------------------------------------------------------------- 1 | { 2 | "de_vocab_size": 37000, 3 | "en_vocab_size": 37000, 4 | "max_de_length": 80, 5 | "max_en_length": 80, 6 | "share_embed": true, 7 | "tie_weights": true, 8 | "embed_dim": 256, 9 | "latent_dim": 256, 10 | "hidden_size": 1024, 11 | "prior": { 12 | "type": "normal", 13 | "length_predictor": { 14 | "type": "diff_softmax", 15 | "diff_range": 20, 16 | "dropout": 0.33, 17 | "label_smoothing": 0.1 18 | }, 19 | "flow": { 20 | "levels": 3, 21 | "num_steps": [4, 4, 2], 22 | "factors": [2, 2], 23 | "hidden_features": 512, 24 | "transform": "affine", 25 | "coupling_type": "self_attn", 26 | "heads": 8, 27 | "pos_enc": "attn", 28 | "max_length": 250, 29 | "dropout": 0.0, 30 | "inverse": true 31 | } 32 | }, 33 | "encoder": { 34 | "type": "transformer", 35 | "num_layers": 6, 36 | "heads": 8, 37 | "max_length": 250, 38 | "dropout": 0.1 39 | }, 40 | "posterior": { 41 | "type": "transformer", 42 | "num_layers": 4, 43 | "heads": 8, 44 | "max_length": 250, 45 | "dropout": 0.1, 46 | "dropword": 0.2 47 | }, 48 | "decoder": { 49 | "type": "transformer", 50 | "num_layers": 4, 51 | "heads": 8, 52 | "max_length": 250, 53 | "dropout": 0.1, 54 | "dropword": 0.0, 55 | "label_smoothing": 0.1 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /experiments/configs/wmt14/config-transformer-large.json: -------------------------------------------------------------------------------- 1 | { 2 | "de_vocab_size": 37000, 3 | "en_vocab_size": 37000, 4 | "max_de_length": 80, 5 | "max_en_length": 80, 6 | "share_embed": true, 7 | "tie_weights": true, 8 | "embed_dim": 512, 9 | "latent_dim": 512, 10 | "hidden_size": 1024, 11 | "prior": { 12 | "type": "normal", 13 | "length_predictor": { 14 | "type": "diff_softmax", 15 | "diff_range": 20, 16 | "dropout": 0.33, 17 | "label_smoothing": 0.1 18 | }, 19 | "flow": { 20 | "levels": 3, 21 | "num_steps": [4, 4, 2], 22 | "factors": [2, 2], 23 | "hidden_features": 1024, 24 | "transform": "affine", 25 | "coupling_type": "self_attn", 26 | "heads": 8, 27 | "pos_enc": "attn", 28 | "max_length": 250, 29 | "dropout": 0.0, 30 | "inverse": true 31 | } 32 | }, 33 | "encoder": { 34 | "type": "transformer", 35 | "num_layers": 6, 36 | "heads": 8, 37 | "max_length": 250, 38 | "dropout": 0.1 39 | }, 40 | "posterior": { 41 | "type": "transformer", 42 | "num_layers": 4, 43 | "heads": 8, 44 | "max_length": 250, 45 | "dropout": 0.1, 46 
| "dropword": 0.2 47 | }, 48 | "decoder": { 49 | "type": "transformer", 50 | "num_layers": 4, 51 | "heads": 8, 52 | "max_length": 250, 53 | "dropout": 0.1, 54 | "dropword": 0.0, 55 | "label_smoothing": 0.1 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /experiments/configs/wmt16/config-rnn.json: -------------------------------------------------------------------------------- 1 | { 2 | "ro_vocab_size": 31500, 3 | "en_vocab_size": 31500, 4 | "max_ro_length": 80, 5 | "max_en_length": 80, 6 | "share_embed": true, 7 | "tie_weights": true, 8 | "embed_dim": 256, 9 | "latent_dim": 256, 10 | "hidden_size": 512, 11 | "prior": { 12 | "type": "normal", 13 | "length_predictor": { 14 | "type": "diff_softmax", 15 | "diff_range": 20, 16 | "dropout": 0.33, 17 | "label_smoothing": 0.1 18 | }, 19 | "flow": { 20 | "levels": 3, 21 | "num_steps": [4, 4, 2], 22 | "factors": [2, 2], 23 | "hidden_features": 512, 24 | "transform": "affine", 25 | "coupling_type": "rnn", 26 | "rnn_mode": "LSTM", 27 | "dropout": 0.0, 28 | "inverse": true 29 | } 30 | }, 31 | "encoder": { 32 | "type": "rnn", 33 | "rnn_mode": "LSTM", 34 | "num_layers": 3 35 | }, 36 | "posterior": { 37 | "type": "rnn", 38 | "rnn_mode": "LSTM", 39 | "num_layers": 2, 40 | "use_attn": true, 41 | "dropout": 0.33, 42 | "dropword": 0.2 43 | }, 44 | "decoder": { 45 | "type": "rnn", 46 | "rnn_mode": "LSTM", 47 | "num_layers": 2, 48 | "dropout": 0.33, 49 | "dropword": 0.2, 50 | "label_smoothing": 0.0 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /experiments/configs/wmt16/config-transformer-base.json: -------------------------------------------------------------------------------- 1 | { 2 | "ro_vocab_size": 31500, 3 | "en_vocab_size": 31500, 4 | "max_ro_length": 80, 5 | "max_en_length": 80, 6 | "share_embed": true, 7 | "tie_weights": true, 8 | "embed_dim": 256, 9 | "latent_dim": 256, 10 | "hidden_size": 1024, 11 | "prior": { 12 | "type": "normal", 13 | "length_predictor": { 14 | "type": "diff_softmax", 15 | "diff_range": 20, 16 | "dropout": 0.33, 17 | "label_smoothing": 0.1 18 | }, 19 | "flow": { 20 | "levels": 3, 21 | "num_steps": [4, 4, 2], 22 | "factors": [2, 2], 23 | "hidden_features": 512, 24 | "transform": "affine", 25 | "coupling_type": "self_attn", 26 | "heads": 8, 27 | "pos_enc": "attn", 28 | "max_length": 250, 29 | "dropout": 0.0, 30 | "inverse": true 31 | } 32 | }, 33 | "encoder": { 34 | "type": "transformer", 35 | "num_layers": 6, 36 | "heads": 8, 37 | "max_length": 250, 38 | "dropout": 0.1 39 | }, 40 | "posterior": { 41 | "type": "transformer", 42 | "num_layers": 4, 43 | "heads": 8, 44 | "max_length": 250, 45 | "dropout": 0.1, 46 | "dropword": 0.2 47 | }, 48 | "decoder": { 49 | "type": "transformer", 50 | "num_layers": 4, 51 | "heads": 8, 52 | "max_length": 250, 53 | "dropout": 0.1, 54 | "dropword": 0.0, 55 | "label_smoothing": 0.1 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /experiments/configs/wmt16/config-transformer-large.json: -------------------------------------------------------------------------------- 1 | { 2 | "ro_vocab_size": 31500, 3 | "en_vocab_size": 31500, 4 | "max_ro_length": 80, 5 | "max_en_length": 80, 6 | "share_embed": true, 7 | "tie_weights": true, 8 | "embed_dim": 512, 9 | "latent_dim": 512, 10 | "hidden_size": 1024, 11 | "prior": { 12 | "type": "normal", 13 | "length_predictor": { 14 | "type": "diff_softmax", 15 | "diff_range": 20, 16 | "dropout": 0.33, 17 | 
"label_smoothing": 0.1 18 | }, 19 | "flow": { 20 | "levels": 3, 21 | "num_steps": [4, 4, 2], 22 | "factors": [2, 2], 23 | "hidden_features": 1024, 24 | "transform": "affine", 25 | "coupling_type": "self_attn", 26 | "heads": 8, 27 | "pos_enc": "attn", 28 | "max_length": 250, 29 | "dropout": 0.0, 30 | "inverse": true 31 | } 32 | }, 33 | "encoder": { 34 | "type": "transformer", 35 | "num_layers": 6, 36 | "heads": 8, 37 | "max_length": 250, 38 | "dropout": 0.1 39 | }, 40 | "posterior": { 41 | "type": "transformer", 42 | "num_layers": 4, 43 | "heads": 8, 44 | "max_length": 250, 45 | "dropout": 0.1, 46 | "dropword": 0.2 47 | }, 48 | "decoder": { 49 | "type": "transformer", 50 | "num_layers": 4, 51 | "heads": 8, 52 | "max_length": 250, 53 | "dropout": 0.1, 54 | "dropword": 0.0, 55 | "label_smoothing": 0.1 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /experiments/distributed.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | current_path = os.path.dirname(os.path.realpath(__file__)) 5 | root_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 6 | sys.path.append(root_path) 7 | 8 | import json 9 | import signal 10 | import threading 11 | import torch 12 | 13 | from flownmt.data import NMTDataSet 14 | import experiments.options as options 15 | from experiments.nmt import main as single_process_main 16 | 17 | 18 | def create_dataset(args): 19 | model_path = args.model_path 20 | if not os.path.exists(model_path): 21 | os.makedirs(model_path) 22 | 23 | result_path = os.path.join(model_path, 'translations') 24 | if not os.path.exists(result_path): 25 | os.makedirs(result_path) 26 | 27 | vocab_path = os.path.join(model_path, 'vocab') 28 | if not os.path.exists(vocab_path): 29 | os.makedirs(vocab_path) 30 | 31 | data_path = args.data_path 32 | src_lang = args.src 33 | tgt_lang = args.tgt 34 | src_vocab_path = os.path.join(vocab_path, '{}.vocab'.format(src_lang)) 35 | tgt_vocab_path = os.path.join(vocab_path, '{}.vocab'.format(tgt_lang)) 36 | 37 | params = json.load(open(args.config, 'r')) 38 | 39 | src_max_vocab = params['{}_vocab_size'.format(src_lang)] 40 | tgt_max_vocab = params['{}_vocab_size'.format(tgt_lang)] 41 | 42 | NMTDataSet(data_path, src_lang, tgt_lang, src_vocab_path, tgt_vocab_path, src_max_vocab, tgt_max_vocab, 43 | subword=args.subword, create_vocab=True) 44 | 45 | 46 | def main(): 47 | args = options.parse_distributed_args() 48 | args_dict = vars(args) 49 | 50 | nproc_per_node = args_dict.pop('nproc_per_node') 51 | nnodes = args_dict.pop('nnodes') 52 | node_rank = args_dict.pop('node_rank') 53 | 54 | # world size in terms of number of processes 55 | dist_world_size = nproc_per_node * nnodes 56 | 57 | # set PyTorch distributed related environmental variables 58 | current_env = os.environ 59 | current_env["MASTER_ADDR"] = args_dict.pop('master_addr') 60 | current_env["MASTER_PORT"] = str(args_dict.pop('master_port')) 61 | current_env["WORLD_SIZE"] = str(dist_world_size) 62 | 63 | create_vocab = args_dict.pop('create_vocab') 64 | if create_vocab: 65 | create_dataset(args) 66 | args.create_vocab = False 67 | 68 | batch_size = args.batch_size // dist_world_size 69 | args.batch_size = batch_size 70 | 71 | mp = torch.multiprocessing.get_context('spawn') 72 | # Create a thread to listen for errors in the child processes. 
73 | error_queue = mp.SimpleQueue() 74 | error_handler = ErrorHandler(error_queue) 75 | 76 | processes = [] 77 | 78 | for local_rank in range(0, nproc_per_node): 79 | # each process's rank 80 | dist_rank = nproc_per_node * node_rank + local_rank 81 | args.rank = dist_rank 82 | args.local_rank = local_rank 83 | process = mp.Process(target=run, args=(args, error_queue, ), daemon=True) 84 | process.start() 85 | error_handler.add_child(process.pid) 86 | processes.append(process) 87 | 88 | for process in processes: 89 | process.join() 90 | 91 | 92 | def run(args, error_queue): 93 | try: 94 | single_process_main(args) 95 | except KeyboardInterrupt: 96 | pass # killed by parent, do nothing 97 | except Exception: 98 | # propagate exception to parent process, keeping original traceback 99 | import traceback 100 | error_queue.put((args.rank, traceback.format_exc())) 101 | 102 | 103 | class ErrorHandler(object): 104 | """A class that listens for exceptions in children processes and propagates 105 | the tracebacks to the parent process.""" 106 | 107 | def __init__(self, error_queue): 108 | self.error_queue = error_queue 109 | self.children_pids = [] 110 | self.error_thread = threading.Thread(target=self.error_listener, daemon=True) 111 | self.error_thread.start() 112 | signal.signal(signal.SIGUSR1, self.signal_handler) 113 | 114 | def add_child(self, pid): 115 | self.children_pids.append(pid) 116 | 117 | def error_listener(self): 118 | (rank, original_trace) = self.error_queue.get() 119 | self.error_queue.put((rank, original_trace)) 120 | os.kill(os.getpid(), signal.SIGUSR1) 121 | 122 | def signal_handler(self, signalnum, stackframe): 123 | for pid in self.children_pids: 124 | os.kill(pid, signal.SIGINT) # kill children processes 125 | (rank, original_trace) = self.error_queue.get() 126 | msg = "\n\n-- Tracebacks above this line can probably be ignored --\n\n" 127 | msg += original_trace 128 | raise Exception(msg) 129 | 130 | 131 | if __name__ == "__main__": 132 | main() 133 | -------------------------------------------------------------------------------- /experiments/options.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | from argparse import ArgumentParser 3 | 4 | 5 | def parse_args(): 6 | parser = ArgumentParser(description='FlowNMT') 7 | parser.add_argument('--rank', type=int, default=-1, metavar='N', help='rank of the process in all distributed processes') 8 | parser.add_argument("--local_rank", type=int, default=0, metavar='N', help='rank of the process in the machine') 9 | parser.add_argument('--config', type=str, help='config file', required=True) 10 | parser.add_argument('--batch_size', type=int, default=512, metavar='N', 11 | help='input batch size for training (default: 512)') 12 | parser.add_argument('--eval_batch_size', type=int, default=4, metavar='N', 13 | help='input batch size for eval (default: 4)') 14 | parser.add_argument('--batch_steps', type=int, default=1, metavar='N', 15 | help='number of steps for each batch (the batch size of each step is batch-size / steps (default: 1)') 16 | parser.add_argument('--init_batch_size', type=int, default=1024, metavar='N', 17 | help='number of instances for model initialization (default: 1024)') 18 | parser.add_argument('--epochs', type=int, default=500, metavar='N', help='number of epochs to train') 19 | parser.add_argument('--kl_warmup_steps', type=int, default=10000, metavar='N', help='number of steps to warm up KL weight(default: 10000)') 20 | parser.add_argument('--init_steps', 
type=int, default=5000, metavar='N', help='number of steps to train decoder (default: 5000)') 21 | parser.add_argument('--seed', type=int, default=65537, metavar='S', help='random seed (default: 65537)') 22 | parser.add_argument('--loss_type', choices=['sentence', 'token'], default='sentence', 23 | help='loss type (default: sentence)') 24 | parser.add_argument('--train_k', type=int, default=1, metavar='N', help='training K (default: 1)') 25 | parser.add_argument('--log_interval', type=int, default=10, metavar='N', help='how many batches to wait before logging training status') 26 | parser.add_argument('--lr_decay', choices=['inv_sqrt', 'expo'], help='lr decay method', default='inv_sqrt') 27 | parser.add_argument('--lr', type=float, default=0.001, help='learning rate') 28 | parser.add_argument('--beta1', type=float, default=0.9, help='beta1 of Adam') 29 | parser.add_argument('--beta2', type=float, default=0.999, help='beta2 of Adam') 30 | parser.add_argument('--eps', type=float, default=1e-6, help='eps of Adam') 31 | parser.add_argument('--weight_decay', type=float, default=0.0, help='weight for l2 norm decay') 32 | parser.add_argument('--amsgrad', action='store_true', help='AMS Grad') 33 | parser.add_argument('--grad_clip', type=float, default=0, help='max norm for gradient clip (default 0: no clip') 34 | parser.add_argument('--model_path', help='path for saving model file.', required=True) 35 | parser.add_argument('--data_path', help='path for data file.', default=None) 36 | 37 | parser.add_argument('--src', type=str, help='source language code', required=True) 38 | parser.add_argument('--tgt', type=str, help='target language code', required=True) 39 | parser.add_argument('--create_vocab', action='store_true', help='create vocabulary.') 40 | parser.add_argument('--share_all_embeddings', action='store_true', help='share source, target and output embeddings') 41 | parser.add_argument("--subword", type=str, default="joint-bpe", choices=['joint-bpe', 'sep-bpe', 'word', 'bert-bpe', 'joint-spm']) 42 | parser.add_argument('--recover', type=int, default=-1, help='recover the model from disk.') 43 | parser.add_argument("--bucket_batch", type=int, default=0, help="whether bucket data based on tgt length in batching") 44 | 45 | return parser.parse_args() 46 | 47 | 48 | def parse_translate_args(): 49 | parser = ArgumentParser(description='FlowNMT') 50 | parser.add_argument('--batch_size', type=int, default=512, metavar='N', help='input batch size for training (default: 512)') 51 | parser.add_argument('--seed', type=int, default=524287, metavar='S', help='random seed (default: 65537)') 52 | parser.add_argument('--model_path', help='path for saving model file.', required=True) 53 | parser.add_argument('--data_path', help='path for data file.', default=None) 54 | parser.add_argument("--subword", type=str, default="joint-bpe", choices=['joint-bpe', 'sep-bpe', 'word', 'bert-bpe', 'joint-spm']) 55 | parser.add_argument("--bucket_batch", type=int, default=0, help="whether bucket data based on tgt length in batching") 56 | parser.add_argument('--decode', choices=['argmax', 'iw', 'sample'], help='decoding algorithm', default='argmax') 57 | parser.add_argument('--tau', type=float, default=0.0, metavar='S', help='temperature for iw decoding (default: 0.)') 58 | parser.add_argument('--nlen', type=int, default=3, help='number of length candidates.') 59 | parser.add_argument('--ntr', type=int, default=1, help='number of samples per length candidate.') 60 | return parser.parse_args() 61 | 62 | 63 | def 
parse_distributed_args(): 64 | """ 65 | Helper function parsing the command line options 66 | @retval ArgumentParser 67 | """ 68 | parser = ArgumentParser(description="Dist FlowNMT") 69 | 70 | # Optional arguments for the launch helper 71 | parser.add_argument("--nnodes", type=int, default=1, 72 | help="The number of nodes to use for distributed " 73 | "training") 74 | parser.add_argument("--node_rank", type=int, default=0, 75 | help="The rank of the node for multi-node distributed " 76 | "training") 77 | parser.add_argument("--nproc_per_node", type=int, default=1, 78 | help="The number of processes to launch on each node, " 79 | "for GPU training, this is recommended to be set " 80 | "to the number of GPUs in your system so that " 81 | "each process can be bound to a single GPU.") 82 | parser.add_argument("--master_addr", default="127.0.0.1", type=str, 83 | help="Master node (rank 0)'s address, should be either " 84 | "the IP address or the hostname of node 0, for " 85 | "single node multi-proc training, the " 86 | "--master_addr can simply be 127.0.0.1") 87 | parser.add_argument("--master_port", default=29500, type=int, 88 | help="Master node (rank 0)'s free port that needs to " 89 | "be used for communciation during distributed " 90 | "training") 91 | 92 | # arguments for flownmt model 93 | parser.add_argument('--config', type=str, help='config file', required=True) 94 | parser.add_argument('--batch_size', type=int, default=512, metavar='N', 95 | help='input batch size for training (default: 512)') 96 | parser.add_argument('--eval_batch_size', type=int, default=4, metavar='N', 97 | help='input batch size for eval (default: 4)') 98 | parser.add_argument('--init_batch_size', type=int, default=1024, metavar='N', 99 | help='number of instances for model initialization (default: 1024)') 100 | parser.add_argument('--batch_steps', type=int, default=1, metavar='N', 101 | help='number of steps for each batch (the batch size of each step is batch-size / steps (default: 1)') 102 | parser.add_argument('--epochs', type=int, default=500, metavar='N', help='number of epochs to train') 103 | parser.add_argument('--kl_warmup_steps', type=int, default=10000, metavar='N', 104 | help='number of steps to warm up KL weight(default: 10000)') 105 | parser.add_argument('--init_steps', type=int, default=5000, metavar='N', 106 | help='number of steps to train decoder (default: 5000)') 107 | parser.add_argument('--seed', type=int, default=65537, metavar='S', help='random seed (default: 524287)') 108 | parser.add_argument('--loss_type', choices=['sentence', 'token'], default='sentence', 109 | help='loss type (default: sentence)') 110 | parser.add_argument('--train_k', type=int, default=1, metavar='N', help='training K (default: 1)') 111 | parser.add_argument('--log_interval', type=int, default=10, metavar='N', 112 | help='how many batches to wait before logging training status') 113 | parser.add_argument('--lr_decay', choices=['inv_sqrt', 'expo'], help='lr decay method', default='inv_sqrt') 114 | parser.add_argument('--lr', type=float, default=0.001, help='learning rate') 115 | parser.add_argument('--beta1', type=float, default=0.9, help='beta1 of Adam') 116 | parser.add_argument('--beta2', type=float, default=0.999, help='beta2 of Adam') 117 | parser.add_argument('--eps', type=float, default=1e-6, help='eps of Adam') 118 | parser.add_argument('--weight_decay', type=float, default=0.0, help='weight for l2 norm decay') 119 | parser.add_argument('--amsgrad', action='store_true', help='AMS Grad') 120 | 
parser.add_argument('--grad_clip', type=float, default=0, help='max norm for gradient clip (default 0: no clip)')
121 | parser.add_argument('--model_path', help='path for saving model file.', required=True)
122 | parser.add_argument('--data_path', help='path for data file.', default=None)
123 |
124 | parser.add_argument('--src', type=str, help='source language code', required=True)
125 | parser.add_argument('--tgt', type=str, help='target language code', required=True)
126 | parser.add_argument('--create_vocab', action='store_true', help='create vocabulary.')
127 | parser.add_argument('--share_all_embeddings', action='store_true', help='share source, target and output embeddings')
128 | parser.add_argument("--subword", type=str, default="joint-bpe",
129 | choices=['joint-bpe', 'sep-bpe', 'word', 'bert-bpe'])
130 | parser.add_argument("--bucket_batch", type=int, default=0,
131 | help="whether bucket data based on tgt length in batching")
132 | parser.add_argument('--recover', type=int, default=-1, help='recover the model from disk.')
133 |
134 | return parser.parse_args()
135 |
-------------------------------------------------------------------------------- /experiments/scripts/multi-bleu.perl: --------------------------------------------------------------------------------
1 | #!/usr/bin/env perl
2 | #
3 | # This file is part of moses. Its use is licensed under the GNU Lesser General
4 | # Public License version 2.1 or, at your option, any later version.
5 |
6 | # $Id$
7 | use warnings;
8 | use strict;
9 |
10 | my $lowercase = 0;
11 | if ($ARGV[0] eq "-lc") {
12 | $lowercase = 1;
13 | shift;
14 | }
15 |
16 | my $stem = $ARGV[0];
17 | if (!defined $stem) {
18 | print STDERR "usage: multi-bleu.pl [-lc] reference < hypothesis\n";
19 | print STDERR "Reads the references from reference or reference0, reference1, ...\n";
20 | exit(1);
21 | }
22 |
23 | $stem .= ".ref" if !-e $stem && !-e $stem."0" && -e $stem.".ref0";
24 |
25 | my @REF;
26 | my $ref=0;
27 | while(-e "$stem$ref") {
28 | &add_to_ref("$stem$ref",\@REF);
29 | $ref++;
30 | }
31 | &add_to_ref($stem,\@REF) if -e $stem;
32 | die("ERROR: could not find reference file $stem") unless scalar @REF;
33 |
34 | # add additional references explicitly specified on the command line
35 | shift;
36 | foreach my $stem (@ARGV) {
37 | &add_to_ref($stem,\@REF) if -e $stem;
38 | }
39 |
40 |
41 |
42 | sub add_to_ref {
43 | my ($file,$REF) = @_;
44 | my $s=0;
45 | if ($file =~ /.gz$/) {
46 | open(REF,"gzip -dc $file|") or die "Can't read $file";
47 | } else {
48 | open(REF,$file) or die "Can't read $file";
49 | }
50 | while(<REF>) {
51 | chop;
52 | push @{$$REF[$s++]}, $_;
53 | }
54 | close(REF);
55 | }
56 |
57 | my(@CORRECT,@TOTAL,$length_translation,$length_reference);
58 | my $s=0;
59 | while(<STDIN>) {
60 | chop;
61 | $_ = lc if $lowercase;
62 | my @WORD = split;
63 | my %REF_NGRAM = ();
64 | my $length_translation_this_sentence = scalar(@WORD);
65 | my ($closest_diff,$closest_length) = (9999,9999);
66 | foreach my $reference (@{$REF[$s]}) {
67 | # print "$s $_ <=> $reference\n";
68 | $reference = lc($reference) if $lowercase;
69 | my @WORD = split(' ',$reference);
70 | my $length = scalar(@WORD);
71 | my $diff = abs($length_translation_this_sentence-$length);
72 | if ($diff < $closest_diff) {
73 | $closest_diff = $diff;
74 | $closest_length = $length;
75 | # print STDERR "$s: closest diff ".abs($length_translation_this_sentence-$length)."
= abs($length_translation_this_sentence-$length), setting len: $closest_length\n"; 76 | } elsif ($diff == $closest_diff) { 77 | $closest_length = $length if $length < $closest_length; 78 | # from two references with the same closeness to me 79 | # take the *shorter* into account, not the "first" one. 80 | } 81 | for(my $n=1;$n<=4;$n++) { 82 | my %REF_NGRAM_N = (); 83 | for(my $start=0;$start<=$#WORD-($n-1);$start++) { 84 | my $ngram = "$n"; 85 | for(my $w=0;$w<$n;$w++) { 86 | $ngram .= " ".$WORD[$start+$w]; 87 | } 88 | $REF_NGRAM_N{$ngram}++; 89 | } 90 | foreach my $ngram (keys %REF_NGRAM_N) { 91 | if (!defined($REF_NGRAM{$ngram}) || 92 | $REF_NGRAM{$ngram} < $REF_NGRAM_N{$ngram}) { 93 | $REF_NGRAM{$ngram} = $REF_NGRAM_N{$ngram}; 94 | # print "$i: REF_NGRAM{$ngram} = $REF_NGRAM{$ngram}
\n"; 95 | } 96 | } 97 | } 98 | } 99 | $length_translation += $length_translation_this_sentence; 100 | $length_reference += $closest_length; 101 | for(my $n=1;$n<=4;$n++) { 102 | my %T_NGRAM = (); 103 | for(my $start=0;$start<=$#WORD-($n-1);$start++) { 104 | my $ngram = "$n"; 105 | for(my $w=0;$w<$n;$w++) { 106 | $ngram .= " ".$WORD[$start+$w]; 107 | } 108 | $T_NGRAM{$ngram}++; 109 | } 110 | foreach my $ngram (keys %T_NGRAM) { 111 | $ngram =~ /^(\d+) /; 112 | my $n = $1; 113 | # my $corr = 0; 114 | # print "$i e $ngram $T_NGRAM{$ngram}
\n"; 115 | $TOTAL[$n] += $T_NGRAM{$ngram}; 116 | if (defined($REF_NGRAM{$ngram})) { 117 | if ($REF_NGRAM{$ngram} >= $T_NGRAM{$ngram}) { 118 | $CORRECT[$n] += $T_NGRAM{$ngram}; 119 | # $corr = $T_NGRAM{$ngram}; 120 | # print "$i e correct1 $T_NGRAM{$ngram}
\n"; 121 | } 122 | else { 123 | $CORRECT[$n] += $REF_NGRAM{$ngram}; 124 | # $corr = $REF_NGRAM{$ngram}; 125 | # print "$i e correct2 $REF_NGRAM{$ngram}
\n"; 126 | } 127 | } 128 | # $REF_NGRAM{$ngram} = 0 if !defined $REF_NGRAM{$ngram}; 129 | # print STDERR "$ngram: {$s, $REF_NGRAM{$ngram}, $T_NGRAM{$ngram}, $corr}\n" 130 | } 131 | } 132 | $s++; 133 | } 134 | my $brevity_penalty = 1; 135 | my $bleu = 0; 136 | 137 | my @bleu=(); 138 | 139 | for(my $n=1;$n<=4;$n++) { 140 | if (defined ($TOTAL[$n]) and defined ($CORRECT[$n])){ 141 | $bleu[$n]=($TOTAL[$n])?$CORRECT[$n]/$TOTAL[$n]:0; 142 | # print STDERR "CORRECT[$n]:$CORRECT[$n] TOTAL[$n]:$TOTAL[$n]\n"; 143 | }else{ 144 | $bleu[$n]=0; 145 | } 146 | } 147 | 148 | if ($length_reference==0){ 149 | printf "BLEU = 0, 0/0/0/0 (BP=0, ratio=0, hyp_len=0, ref_len=0)\n"; 150 | exit(1); 151 | } 152 | 153 | if ($length_translation<$length_reference) { 154 | $brevity_penalty = exp(1-$length_reference/$length_translation); 155 | } 156 | $bleu = $brevity_penalty * exp((my_log( $bleu[1] ) + 157 | my_log( $bleu[2] ) + 158 | my_log( $bleu[3] ) + 159 | my_log( $bleu[4] ) ) / 4) ; 160 | printf "BLEU = %.2f, %.1f/%.1f/%.1f/%.1f (BP=%.3f, ratio=%.3f, hyp_len=%d, ref_len=%d)\n", 161 | 100*$bleu, 162 | 100*$bleu[1], 163 | 100*$bleu[2], 164 | 100*$bleu[3], 165 | 100*$bleu[4], 166 | $brevity_penalty, 167 | $length_translation / $length_reference, 168 | $length_translation, 169 | $length_reference; 170 | 171 | sub my_log { 172 | return -9999999999 unless $_[0]; 173 | return log($_[0]); 174 | } -------------------------------------------------------------------------------- /experiments/slurm.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | current_path = os.path.dirname(os.path.realpath(__file__)) 5 | root_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 6 | sys.path.append(root_path) 7 | 8 | import torch.multiprocessing as mp 9 | 10 | import experiments.options as options 11 | from experiments.nmt import main as single_process_main 12 | 13 | 14 | def main(): 15 | args = options.parse_distributed_args() 16 | args_dict = vars(args) 17 | 18 | args_dict.pop('master_addr') 19 | str(args_dict.pop('master_port')) 20 | args_dict.pop('nnodes') 21 | args_dict.pop('nproc_per_node') 22 | args_dict.pop('node_rank') 23 | 24 | current_env = os.environ 25 | nnodes = int(current_env['SLURM_NNODES']) 26 | dist_world_size = int(current_env['SLURM_NTASKS']) 27 | args.rank = int(current_env['SLURM_PROCID']) 28 | args.local_rank = int(current_env['SLURM_LOCALID']) 29 | 30 | 31 | print('start process: rank={}({}), master addr={}, port={}, nnodes={}, world size={}'.format( 32 | args.rank, args.local_rank, current_env["MASTER_ADDR"], current_env["MASTER_PORT"], nnodes, dist_world_size)) 33 | current_env["WORLD_SIZE"] = str(dist_world_size) 34 | 35 | create_vocab = args_dict.pop('create_vocab') 36 | assert not create_vocab 37 | args.create_vocab = False 38 | 39 | batch_size = args.batch_size // dist_world_size 40 | args.batch_size = batch_size 41 | 42 | single_process_main(args) 43 | 44 | 45 | if __name__ == "__main__": 46 | mp.set_start_method('forkserver') 47 | main() 48 | -------------------------------------------------------------------------------- /experiments/translate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | current_path = os.path.dirname(os.path.realpath(__file__)) 5 | root_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 6 | sys.path.append(root_path) 7 | 8 | import time 9 | import json 10 | import random 11 | import numpy as np 12 | 13 | import 
torch 14 | 15 | from flownmt.data import NMTDataSet, DataIterator 16 | from flownmt import FlowNMT 17 | from experiments.options import parse_translate_args 18 | 19 | 20 | def calc_bleu(fref, fmt, result_path): 21 | script = os.path.join(current_path, 'scripts/multi-bleu.perl') 22 | temp = os.path.join(result_path, 'tmp') 23 | os.system("perl %s %s < %s > %s" % (script, fref, fmt, temp)) 24 | bleu = open(temp, 'r').read().strip() 25 | bleu = bleu.split(",")[0].split("=") 26 | if len(bleu) < 2: 27 | return 0.0 28 | bleu = float(bleu[1].strip()) 29 | return bleu 30 | 31 | 32 | def translate_argmax(dataset, dataloader, flownmt, result_path, outfile, tau, n_tr): 33 | flownmt.eval() 34 | translations = [] 35 | lengths = [] 36 | length_err = 0 37 | num_insts = 0 38 | start_time = time.time() 39 | num_back = 0 40 | for step, (src, tgt, src_masks, tgt_masks) in enumerate(dataloader): 41 | trans, lens = flownmt.translate_argmax(src, src_masks, n_tr=n_tr, tau=tau) 42 | translations.append(trans) 43 | lengths.append(lens) 44 | length_err += (lens.float() - tgt_masks.sum(dim=1)).abs().sum().item() 45 | num_insts += src.size(0) 46 | if step % 10 == 0: 47 | sys.stdout.write("\b" * num_back) 48 | sys.stdout.write(" " * num_back) 49 | sys.stdout.write("\b" * num_back) 50 | log_info = 'argmax translating (tau={:.1f}, n_tr={})...{}'.format(tau, n_tr, num_insts) 51 | sys.stdout.write(log_info) 52 | sys.stdout.flush() 53 | num_back = len(log_info) 54 | print('time: {:.1f}s'.format(time.time() - start_time)) 55 | outfile = os.path.join(result_path, outfile) 56 | dataset.dump_to_file(translations, lengths, outfile) 57 | bleu = calc_bleu(dataloader.tgt_sort_origin_path, outfile, result_path) 58 | print('#SENT: {}, Length Err: {:.1f}, BLEU: {:.2f}'.format(num_insts, length_err / num_insts, bleu)) 59 | 60 | 61 | def translate_iw(dataset, dataloader, flownmt, result_path, outfile, tau, n_len, n_tr): 62 | flownmt.eval() 63 | iwk = 4 64 | translations = [] 65 | lengths = [] 66 | length_err = 0 67 | num_insts = 0 68 | start_time = time.time() 69 | num_back = 0 70 | for step, (src, tgt, src_masks, tgt_masks) in enumerate(dataloader): 71 | trans, lens = flownmt.translate_iw(src, src_masks, n_len=n_len, n_tr=n_tr, tau=tau, k=iwk) 72 | translations.append(trans) 73 | lengths.append(lens) 74 | length_err += (lens.float() - tgt_masks.sum(dim=1)).abs().sum().item() 75 | num_insts += src.size(0) 76 | if step % 10 == 0: 77 | sys.stdout.write("\b" * num_back) 78 | sys.stdout.write(" " * num_back) 79 | sys.stdout.write("\b" * num_back) 80 | log_info = 'importance weighted translating (tau={:.1f}, n_len={}, n_tr={})...{}'.format(tau, n_len, n_tr, num_insts) 81 | sys.stdout.write(log_info) 82 | sys.stdout.flush() 83 | num_back = len(log_info) 84 | print('time: {:.1f}s'.format(time.time() - start_time)) 85 | outfile = os.path.join(result_path, outfile) 86 | dataset.dump_to_file(translations, lengths, outfile) 87 | bleu = calc_bleu(dataloader.tgt_sort_origin_path, outfile, result_path) 88 | print('#SENT: {}, Length Err: {:.1f}, BLEU: {:.2f}'.format(num_insts, length_err / num_insts, bleu)) 89 | 90 | 91 | def sample(dataset, dataloader, flownmt, result_path, outfile, tau, n_len, n_tr): 92 | flownmt.eval() 93 | lengths = [] 94 | translations = [] 95 | num_insts = 0 96 | start_time = time.time() 97 | num_back = 0 98 | for step, (src, tgt, src_masks, tgt_masks) in enumerate(dataloader): 99 | trans, lens = flownmt.translate_sample(src, src_masks, n_len=n_len, n_tr=n_tr, tau=tau) 100 | translations.append(trans) 101 | lengths.append(lens) 
102 | num_insts += src.size(0) 103 | if step % 10 == 0: 104 | sys.stdout.write("\b" * num_back) 105 | sys.stdout.write(" " * num_back) 106 | sys.stdout.write("\b" * num_back) 107 | log_info = 'sampling (tau={:.1f}, n_len={}, n_tr={})...{}'.format(tau, n_len, n_tr, num_insts) 108 | sys.stdout.write(log_info) 109 | sys.stdout.flush() 110 | num_back = len(log_info) 111 | print('time: {:.1f}s'.format(time.time() - start_time)) 112 | outfile = os.path.join(result_path, outfile) 113 | dataset.dump_to_file(translations, lengths, outfile, post_edit=False) 114 | 115 | 116 | def setup(args): 117 | args.cuda = torch.cuda.is_available() 118 | random_seed = args.seed 119 | random.seed(random_seed) 120 | np.random.seed(random_seed) 121 | torch.manual_seed(random_seed) 122 | device = torch.device('cuda', 0) if args.cuda else torch.device('cpu') 123 | if args.cuda: 124 | torch.cuda.set_device(device) 125 | torch.cuda.manual_seed(random_seed) 126 | 127 | torch.backends.cudnn.benchmark = False 128 | 129 | model_path = args.model_path 130 | result_path = os.path.join(model_path, 'translations') 131 | args.result_path = result_path 132 | params = json.load(open(os.path.join(model_path, 'config.json'), 'r')) 133 | 134 | src_lang = params['src'] 135 | tgt_lang = params['tgt'] 136 | data_path = args.data_path 137 | vocab_path = os.path.join(model_path, 'vocab') 138 | src_vocab_path = os.path.join(vocab_path, '{}.vocab'.format(src_lang)) 139 | tgt_vocab_path = os.path.join(vocab_path, '{}.vocab'.format(tgt_lang)) 140 | src_vocab_size = params['src_vocab_size'] 141 | tgt_vocab_size = params['tgt_vocab_size'] 142 | args.max_src_length = params.pop('max_src_length') 143 | args.max_tgt_length = params.pop('max_tgt_length') 144 | dataset = NMTDataSet(data_path, src_lang, tgt_lang, 145 | src_vocab_path, tgt_vocab_path, 146 | src_vocab_size, tgt_vocab_size, 147 | subword=args.subword, create_vocab=False) 148 | assert src_vocab_size == dataset.src_vocab_size 149 | assert tgt_vocab_size == dataset.tgt_vocab_size 150 | 151 | flownmt = FlowNMT.load(model_path, device=device) 152 | args.length_unit = flownmt.length_unit 153 | args.device = device 154 | return args, dataset, flownmt 155 | 156 | 157 | def init_dataloader(args, dataset): 158 | eval_batch = args.batch_size 159 | val_iter = DataIterator(dataset, eval_batch, 0, args.max_src_length, args.max_tgt_length, 1000, args.device, args.result_path, 160 | bucket_data=args.bucket_batch, multi_scale=args.length_unit, corpus="dev") 161 | test_iter = DataIterator(dataset, eval_batch, 0, args.max_src_length, args.max_tgt_length, 1000, args.device, args.result_path, 162 | bucket_data=args.bucket_batch, multi_scale=args.length_unit, corpus="test") 163 | return val_iter, test_iter 164 | 165 | 166 | def main(args): 167 | args, dataset, flownmt = setup(args) 168 | print(args) 169 | 170 | val_iter, test_iter = init_dataloader(args, dataset) 171 | 172 | result_path = args.result_path 173 | if args.decode == 'argmax': 174 | tau = args.tau 175 | n_tr = args.ntr 176 | outfile = 'argmax.t{:.1f}.ntr{}.dev.mt'.format(tau, n_tr) 177 | translate_argmax(dataset, val_iter, flownmt, result_path, outfile, tau, n_tr) 178 | outfile = 'argmax.t{:.1f}.ntr{}.test.mt'.format(tau, n_tr) 179 | translate_argmax(dataset, test_iter, flownmt, result_path, outfile, tau, n_tr) 180 | elif args.decode == 'iw': 181 | tau = args.tau 182 | n_len = args.nlen 183 | n_tr = args.ntr 184 | outfile = 'iw.t{:.1f}.nlen{}.ntr{}.dev.mt'.format(tau, n_len, n_tr) 185 | translate_iw(dataset, val_iter, flownmt, result_path, 
outfile, tau, n_len, n_tr) 186 | outfile = 'iw.t{:.1f}.nlen{}.ntr{}.test.mt'.format(tau, n_len, n_tr) 187 | translate_iw(dataset, test_iter, flownmt, result_path, outfile, tau, n_len, n_tr) 188 | else: 189 | assert not args.bucket_batch 190 | tau = args.tau 191 | n_len = args.nlen 192 | n_tr = args.ntr 193 | outfile = 'sample.t{:.1f}.nlen{}.ntr{}.dev.mt'.format(tau, n_len, n_tr) 194 | sample(dataset, val_iter, flownmt, result_path, outfile, tau, n_len, n_tr) 195 | outfile = 'sample.t{:.1f}.nlen{}.ntr{}.test.mt'.format(tau, n_len, n_tr) 196 | sample(dataset, test_iter, flownmt, result_path, outfile, tau, n_len, n_tr) 197 | 198 | 199 | if __name__ == "__main__": 200 | args = parse_translate_args() 201 | with torch.no_grad(): 202 | main(args) -------------------------------------------------------------------------------- /flownmt/__init__.py: -------------------------------------------------------------------------------- 1 | from flownmt.flownmt import FlowNMT 2 | 3 | -------------------------------------------------------------------------------- /flownmt/data/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'violet-zct' 2 | 3 | from flownmt.data.dataloader import NMTDataSet, DataIterator 4 | -------------------------------------------------------------------------------- /flownmt/flows/__init__.py: -------------------------------------------------------------------------------- 1 | from flownmt.flows.flow import Flow 2 | from flownmt.flows.actnorm import ActNormFlow 3 | from flownmt.flows.parallel import * 4 | from flownmt.flows.linear import InvertibleMultiHeadFlow, InvertibleLinearFlow 5 | from flownmt.flows.couplings import * 6 | from flownmt.flows.nmt import NMTFlow 7 | -------------------------------------------------------------------------------- /flownmt/flows/actnorm.py: -------------------------------------------------------------------------------- 1 | from overrides import overrides 2 | from typing import Dict, Tuple 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from torch.nn import Parameter 7 | 8 | from flownmt.flows.flow import Flow 9 | 10 | 11 | class ActNormFlow(Flow): 12 | def __init__(self, in_features, inverse=False): 13 | super(ActNormFlow, self).__init__(inverse) 14 | self.in_features = in_features 15 | self.log_scale = Parameter(torch.Tensor(in_features)) 16 | self.bias = Parameter(torch.Tensor(in_features)) 17 | self.reset_parameters() 18 | 19 | def reset_parameters(self): 20 | nn.init.normal_(self.log_scale, mean=0, std=0.05) 21 | nn.init.constant_(self.bias, 0.) 
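    # ActNorm applies an element-wise affine transform along the feature axis,
    #     out = input * exp(log_scale) + bias,
    # so the log-determinant contributed by one unmasked position is
    # sum(log_scale); forward/backward below scale that by the number of valid
    # positions in the mask. The data-dependent `init` further down shifts and
    # rescales log_scale/bias so the first batch comes out roughly zero-mean,
    # unit-variance per feature.
    #
    # Rough usage sketch (shapes assumed for illustration, not taken from the
    # repo's tests):
    #   flow = ActNormFlow(in_features=8)
    #   x, mask = torch.randn(2, 5, 8), torch.ones(2, 5)
    #   y, logdet = flow.init(x, mask)   # y: [2, 5, 8], logdet: [2]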
22 | 23 | @overrides 24 | def forward(self, input: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 25 | """ 26 | 27 | Args: 28 | input: Tensor 29 | input tensor [batch, N1, N2, ..., Nl, in_features] 30 | mask: Tensor 31 | mask tensor [batch, N1, N2, ...,Nl] 32 | 33 | Returns: out: Tensor , logdet: Tensor 34 | out: [batch, N1, N2, ..., in_features], the output of the flow 35 | logdet: [batch], the log determinant of :math:`\partial output / \partial input` 36 | 37 | """ 38 | dim = input.dim() 39 | out = input * self.log_scale.exp() + self.bias 40 | out = out * mask.unsqueeze(dim - 1) 41 | logdet = self.log_scale.sum(dim=0, keepdim=True) 42 | if dim > 2: 43 | num = mask.view(out.size(0), -1).sum(dim=1) 44 | logdet = logdet * num 45 | return out, logdet 46 | 47 | @overrides 48 | def backward(self, input: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 49 | """ 50 | 51 | Args: 52 | input: Tensor 53 | input tensor [batch, N1, N2, ..., Nl, in_features] 54 | mask: Tensor 55 | mask tensor [batch, N1, N2, ...,Nl] 56 | 57 | Returns: out: Tensor , logdet: Tensor 58 | out: [batch, N1, N2, ..., in_features], the output of the flow 59 | logdet: [batch], the log determinant of :math:`\partial output / \partial input` 60 | 61 | """ 62 | dim = input.dim() 63 | out = (input - self.bias) * mask.unsqueeze(dim - 1) 64 | out = out.div(self.log_scale.exp() + 1e-8) 65 | logdet = self.log_scale.sum(dim=0, keepdim=True) * -1.0 66 | if dim > 2: 67 | num = mask.view(out.size(0), -1).sum(dim=1) 68 | logdet = logdet * num 69 | return out, logdet 70 | 71 | @overrides 72 | def init(self, data: torch.Tensor, mask: torch.Tensor, init_scale=1.0) -> Tuple[torch.Tensor, torch.Tensor]: 73 | """ 74 | 75 | Args: 76 | data: input: Tensor 77 | input tensor [batch, N1, N2, ..., in_features] 78 | mask: Tensor 79 | mask tensor [batch, N1, N2, ...,Nl] 80 | init_scale: float 81 | initial scale 82 | 83 | Returns: out: Tensor , logdet: Tensor 84 | out: [batch, N1, N2, ..., in_features], the output of the flow 85 | logdet: [batch], the log determinant of :math:`\partial output / \partial input` 86 | 87 | """ 88 | with torch.no_grad(): 89 | out, _ = self.forward(data, mask) 90 | mean = out.view(-1, self.in_features).mean(dim=0) 91 | std = out.view(-1, self.in_features).std(dim=0) 92 | inv_stdv = init_scale / (std + 1e-6) 93 | 94 | self.log_scale.add_(inv_stdv.log()) 95 | self.bias.add_(-mean).mul_(inv_stdv) 96 | return self.forward(data, mask) 97 | 98 | @overrides 99 | def extra_repr(self): 100 | return 'inverse={}, in_features={}'.format(self.inverse, self.in_features) 101 | 102 | @classmethod 103 | def from_params(cls, params: Dict) -> "ActNormFlow": 104 | return ActNormFlow(**params) 105 | 106 | 107 | ActNormFlow.register('actnorm') 108 | -------------------------------------------------------------------------------- /flownmt/flows/couplings/__init__.py: -------------------------------------------------------------------------------- 1 | from flownmt.flows.couplings.coupling import NICE 2 | -------------------------------------------------------------------------------- /flownmt/flows/couplings/blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence 4 | 5 | from flownmt.nnet.weightnorm import Conv1dWeightNorm, LinearWeightNorm 6 | from flownmt.nnet.attention import GlobalAttention, MultiHeadAttention 7 | from 
flownmt.nnet.positional_encoding import PositionalEncoding 8 | from flownmt.nnet.transformer import TransformerDecoderLayer 9 | 10 | 11 | class NICEConvBlock(nn.Module): 12 | def __init__(self, src_features, in_features, out_features, hidden_features, kernel_size, dropout=0.0): 13 | super(NICEConvBlock, self).__init__() 14 | self.conv1 = Conv1dWeightNorm(in_features, hidden_features, kernel_size=kernel_size, padding=kernel_size // 2, bias=True) 15 | self.conv2 = Conv1dWeightNorm(hidden_features, hidden_features, kernel_size=kernel_size, padding=kernel_size // 2, bias=True) 16 | self.activation = nn.ELU(inplace=True) 17 | self.attn = GlobalAttention(src_features, hidden_features, hidden_features, dropout=dropout) 18 | self.linear = LinearWeightNorm(hidden_features * 2, out_features, bias=True) 19 | 20 | def forward(self, x, mask, src, src_mask): 21 | """ 22 | 23 | Args: 24 | x: Tensor 25 | input tensor [batch, length, in_features] 26 | mask: Tensor 27 | x mask tensor [batch, length] 28 | src: Tensor 29 | source input tensor [batch, src_length, src_features] 30 | src_mask: Tensor 31 | source mask tensor [batch, src_length] 32 | 33 | Returns: Tensor 34 | out tensor [batch, length, out_features] 35 | 36 | """ 37 | out = self.activation(self.conv1(x.transpose(1, 2))) 38 | out = self.activation(self.conv2(out)).transpose(1, 2) * mask.unsqueeze(2) 39 | out = self.attn(out, src, key_mask=src_mask.eq(0)) 40 | out = self.linear(torch.cat([x, out], dim=2)) 41 | return out 42 | 43 | def init(self, x, mask, src, src_mask, init_scale=1.0): 44 | out = self.activation(self.conv1.init(x.transpose(1, 2), init_scale=init_scale)) 45 | out = self.activation(self.conv2.init(out, init_scale=init_scale)).transpose(1, 2) * mask.unsqueeze(2) 46 | out = self.attn.init(out, src, key_mask=src_mask.eq(0), init_scale=init_scale) 47 | out = self.linear.init(torch.cat([x, out], dim=2), init_scale=0.0) 48 | return out 49 | 50 | 51 | class NICERecurrentBlock(nn.Module): 52 | def __init__(self, rnn_mode, src_features, in_features, out_features, hidden_features, dropout=0.0): 53 | super(NICERecurrentBlock, self).__init__() 54 | if rnn_mode == 'RNN': 55 | RNN = nn.RNN 56 | elif rnn_mode == 'LSTM': 57 | RNN = nn.LSTM 58 | elif rnn_mode == 'GRU': 59 | RNN = nn.GRU 60 | else: 61 | raise ValueError('Unknown RNN mode: %s' % rnn_mode) 62 | 63 | self.rnn = RNN(in_features, hidden_features // 2, batch_first=True, bidirectional=True) 64 | self.attn = GlobalAttention(src_features, hidden_features, hidden_features, dropout=dropout) 65 | self.linear = LinearWeightNorm(in_features + hidden_features, out_features, bias=True) 66 | 67 | def forward(self, x, mask, src, src_mask): 68 | lengths = mask.sum(dim=1).long() 69 | packed_out = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False) 70 | packed_out, _ = self.rnn(packed_out) 71 | out, _ = pad_packed_sequence(packed_out, batch_first=True, total_length=mask.size(1)) 72 | # [batch, length, out_features] 73 | out = self.attn(out, src, key_mask=src_mask.eq(0)) 74 | out = self.linear(torch.cat([x, out], dim=2)) 75 | return out 76 | 77 | def init(self, x, mask, src, src_mask, init_scale=1.0): 78 | lengths = mask.sum(dim=1).long() 79 | packed_out = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False) 80 | packed_out, _ = self.rnn(packed_out) 81 | out, _ = pad_packed_sequence(packed_out, batch_first=True, total_length=mask.size(1)) 82 | # [batch, length, out_features] 83 | out = self.attn.init(out, src, key_mask=src_mask.eq(0), init_scale=init_scale) 84 | out 
= self.linear.init(torch.cat([x, out], dim=2), init_scale=0.0) 85 | return out 86 | 87 | 88 | class NICESelfAttnBlock(nn.Module): 89 | def __init__(self, src_features, in_features, out_features, hidden_features, heads, dropout=0.0, 90 | pos_enc='add', max_length=100): 91 | super(NICESelfAttnBlock, self).__init__() 92 | assert pos_enc in ['add', 'attn'] 93 | self.src_proj = nn.Linear(src_features, in_features, bias=False) if src_features != in_features else None 94 | self.pos_enc = PositionalEncoding(in_features, padding_idx=None, init_size=max_length + 1) 95 | self.pos_attn = MultiHeadAttention(in_features, heads, dropout=dropout) if pos_enc == 'attn' else None 96 | self.transformer = TransformerDecoderLayer(in_features, hidden_features, heads, dropout=dropout) 97 | self.linear = LinearWeightNorm(in_features, out_features, bias=True) 98 | 99 | def forward(self, x, mask, src, src_mask): 100 | if self.src_proj is not None: 101 | src = self.src_proj(src) 102 | 103 | key_mask = mask.eq(0) 104 | pos_enc = self.pos_enc(x) * mask.unsqueeze(2) 105 | if self.pos_attn is None: 106 | x = x + pos_enc 107 | else: 108 | x = self.pos_attn(pos_enc, x, x, key_mask) 109 | 110 | x = self.transformer(x, key_mask, src, src_mask.eq(0)) 111 | return self.linear(x) 112 | 113 | def init(self, x, mask, src, src_mask, init_scale=1.0): 114 | if self.src_proj is not None: 115 | src = self.src_proj(src) 116 | 117 | key_mask = mask.eq(0) 118 | pos_enc = self.pos_enc(x) * mask.unsqueeze(2) 119 | if self.pos_attn is None: 120 | x = x + pos_enc 121 | else: 122 | x = self.pos_attn(pos_enc, x, x, key_mask) 123 | 124 | x = self.transformer.init(x, key_mask, src, src_mask.eq(0), init_scale=init_scale) 125 | x = x * mask.unsqueeze(2) 126 | return self.linear.init(x, init_scale=0.0) 127 | -------------------------------------------------------------------------------- /flownmt/flows/couplings/coupling.py: -------------------------------------------------------------------------------- 1 | from overrides import overrides 2 | from typing import Tuple, Dict 3 | import torch 4 | 5 | from flownmt.flows.couplings.blocks import NICEConvBlock, NICERecurrentBlock, NICESelfAttnBlock 6 | from flownmt.flows.flow import Flow 7 | from flownmt.flows.couplings.transform import Transform, Additive, Affine, NLSQ 8 | 9 | 10 | class NICE(Flow): 11 | """ 12 | NICE Flow 13 | """ 14 | def __init__(self, src_features, features, hidden_features=None, inverse=False, split_dim=2, split_type='continuous', order='up', factor=2, 15 | transform='affine', type='conv', kernel=3, rnn_mode='LSTM', heads=1, dropout=0.0, pos_enc='add', max_length=100): 16 | super(NICE, self).__init__(inverse) 17 | self.features = features 18 | assert split_dim in [1, 2] 19 | assert split_type in ['continuous', 'skip'] 20 | if split_dim == 1: 21 | assert split_type == 'skip' 22 | if factor != 2: 23 | assert split_type == 'continuous' 24 | assert order in ['up', 'down'] 25 | self.split_dim = split_dim 26 | self.split_type = split_type 27 | self.up = order == 'up' 28 | if split_dim == 2: 29 | out_features = features // factor 30 | in_features = features - out_features 31 | self.z1_channels = in_features if self.up else out_features 32 | else: 33 | in_features = features 34 | out_features = features 35 | self.z1_channels = None 36 | assert transform in ['additive', 'affine', 'nlsq'] 37 | if transform == 'additive': 38 | self.transform = Additive 39 | elif transform == 'affine': 40 | self.transform = Affine 41 | out_features = out_features * 2 42 | elif transform == 'nlsq': 43 | 
self.transform = NLSQ 44 | out_features = out_features * 5 45 | else: 46 | raise ValueError('unknown transform: {}'.format(transform)) 47 | 48 | if hidden_features is None: 49 | hidden_features = min(2 * in_features, 1024) 50 | assert type in ['conv', 'self_attn', 'rnn'] 51 | if type == 'conv': 52 | self.net = NICEConvBlock(src_features, in_features, out_features, hidden_features, kernel_size=kernel, dropout=dropout) 53 | elif type == 'rnn': 54 | self.net = NICERecurrentBlock(rnn_mode, src_features, in_features, out_features, hidden_features, dropout=dropout) 55 | else: 56 | self.net = NICESelfAttnBlock(src_features, in_features, out_features, hidden_features, 57 | heads=heads, dropout=dropout, pos_enc=pos_enc, max_length=max_length) 58 | 59 | def split(self, z, mask): 60 | split_dim = self.split_dim 61 | split_type = self.split_type 62 | dim = z.size(split_dim) 63 | if split_type == 'continuous': 64 | return z.split([self.z1_channels, dim - self.z1_channels], dim=split_dim), mask 65 | elif split_type == 'skip': 66 | idx1 = torch.tensor(list(range(0, dim, 2))).to(z.device) 67 | idx2 = torch.tensor(list(range(1, dim, 2))).to(z.device) 68 | z1 = z.index_select(split_dim, idx1) 69 | z2 = z.index_select(split_dim, idx2) 70 | if split_dim == 1: 71 | mask = mask.index_select(split_dim, idx1) 72 | return (z1, z2), mask 73 | else: 74 | raise ValueError('unknown split type: {}'.format(split_type)) 75 | 76 | def unsplit(self, z1, z2): 77 | split_dim = self.split_dim 78 | split_type = self.split_type 79 | if split_type == 'continuous': 80 | return torch.cat([z1, z2], dim=split_dim) 81 | elif split_type == 'skip': 82 | z = torch.cat([z1, z2], dim=split_dim) 83 | dim = z1.size(split_dim) 84 | idx = torch.tensor([i // 2 if i % 2 == 0 else i // 2 + dim for i in range(dim * 2)]).to(z.device) 85 | return z.index_select(split_dim, idx) 86 | else: 87 | raise ValueError('unknown split type: {}'.format(split_type)) 88 | 89 | def calc_params(self, z: torch.Tensor, mask: torch.Tensor, src: torch.Tensor, src_mask: torch.Tensor): 90 | params = self.net(z, mask, src, src_mask) 91 | return params 92 | 93 | def init_net(self, z: torch.Tensor, mask: torch.Tensor, src: torch.Tensor, src_mask: torch.Tensor, init_scale=1.0): 94 | params = self.net.init(z, mask, src, src_mask, init_scale=init_scale) 95 | return params 96 | 97 | @overrides 98 | def forward(self, input: torch.Tensor, mask: torch.Tensor, src: torch.Tensor, src_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 99 | """ 100 | Args: 101 | input: Tensor 102 | input tensor [batch, length, in_features] 103 | mask: Tensor 104 | mask tensor [batch, length] 105 | src: Tensor 106 | source input tensor [batch, src_length, src_features] 107 | src_mask: Tensor 108 | source mask tensor [batch, src_length] 109 | 110 | Returns: out: Tensor , logdet: Tensor 111 | out: [batch, length, in_features], the output of the flow 112 | logdet: [batch], the log determinant of :math:`\partial output / \partial input` 113 | """ 114 | # [batch, length, in_channels] 115 | (z1, z2), mask = self.split(input, mask) 116 | # [batch, length, features] 117 | z, zp = (z1, z2) if self.up else (z2, z1) 118 | 119 | params = self.calc_params(z, mask, src, src_mask) 120 | zp, logdet = self.transform.fwd(zp, mask, params) 121 | 122 | z1, z2 = (z, zp) if self.up else (zp, z) 123 | return self.unsplit(z1, z2), logdet 124 | 125 | @overrides 126 | def backward(self, input: torch.Tensor, mask: torch.Tensor, src: torch.Tensor, src_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 127 | """ 128 
| Args: 129 | input: Tensor 130 | input tensor [batch, length, in_features] 131 | mask: Tensor 132 | mask tensor [batch, length] 133 | src: Tensor 134 | source input tensor [batch, src_length, src_features] 135 | src_mask: Tensor 136 | source mask tensor [batch, src_length] 137 | 138 | Returns: out: Tensor , logdet: Tensor 139 | out: [batch, length, in_features], the output of the flow 140 | logdet: [batch], the log determinant of :math:`\partial output / \partial input` 141 | """ 142 | # [batch, length, in_channels] 143 | (z1, z2), mask = self.split(input, mask) 144 | # [batch, length, features] 145 | z, zp = (z1, z2) if self.up else (z2, z1) 146 | 147 | params = self.calc_params(z, mask, src, src_mask) 148 | zp, logdet = self.transform.bwd(zp, mask, params) 149 | 150 | z1, z2 = (z, zp) if self.up else (zp, z) 151 | return self.unsplit(z1, z2), logdet 152 | 153 | @overrides 154 | def init(self, data: torch.Tensor, mask: torch.Tensor, src: torch.Tensor, src_mask: torch.Tensor, init_scale=1.0) -> Tuple[torch.Tensor, torch.Tensor]: 155 | # [batch, length, in_channels] 156 | (z1, z2), mask = self.split(data, mask) 157 | # [batch, length, features] 158 | z, zp = (z1, z2) if self.up else (z2, z1) 159 | 160 | params = self.init_net(z, mask, src, src_mask, init_scale=init_scale) 161 | zp, logdet = self.transform.fwd(zp, mask, params) 162 | 163 | z1, z2 = (z, zp) if self.up else (zp, z) 164 | return self.unsplit(z1, z2), logdet 165 | 166 | @overrides 167 | def extra_repr(self): 168 | return 'inverse={}, in_channels={}, scale={}'.format(self.inverse, self.in_channels, self.scale) 169 | 170 | @classmethod 171 | def from_params(cls, params: Dict) -> "NICE": 172 | return NICE(**params) 173 | 174 | 175 | NICE.register('nice') 176 | -------------------------------------------------------------------------------- /flownmt/flows/couplings/transform.py: -------------------------------------------------------------------------------- 1 | import math 2 | from overrides import overrides 3 | from typing import Tuple 4 | import torch 5 | 6 | 7 | class Transform(): 8 | @staticmethod 9 | def fwd(z: torch.Tensor, mask: torch.Tensor, params) -> Tuple[torch.Tensor, torch.Tensor]: 10 | raise NotImplementedError 11 | 12 | @staticmethod 13 | def bwd(z: torch.Tensor, mask: torch.Tensor, params) -> Tuple[torch.Tensor, torch.Tensor]: 14 | raise NotImplementedError 15 | 16 | 17 | class Additive(Transform): 18 | @staticmethod 19 | @overrides 20 | def fwd(z: torch.Tensor, mask: torch.Tensor, params) -> Tuple[torch.Tensor, torch.Tensor]: 21 | mu = params 22 | z = (z + mu).mul(mask.unsqueeze(2)) 23 | logdet = z.new_zeros(z.size(0)) 24 | return z, logdet 25 | 26 | @staticmethod 27 | @overrides 28 | def bwd(z: torch.Tensor, mask: torch.Tensor, params) -> Tuple[torch.Tensor, torch.Tensor]: 29 | mu = params 30 | z = (z - mu).mul(mask.unsqueeze(2)) 31 | logdet = z.new_zeros(z.size(0)) 32 | return z, logdet 33 | 34 | 35 | class Affine(Transform): 36 | @staticmethod 37 | @overrides 38 | def fwd(z: torch.Tensor, mask: torch.Tensor, params) -> Tuple[torch.Tensor, torch.Tensor]: 39 | mu, log_scale = params.chunk(2, dim=2) 40 | scale = log_scale.add_(2.0).sigmoid_() 41 | z = (scale * z + mu).mul(mask.unsqueeze(2)) 42 | logdet = scale.log().mul(mask.unsqueeze(2)).view(z.size(0), -1).sum(dim=1) 43 | return z, logdet 44 | 45 | @staticmethod 46 | @overrides 47 | def bwd(z: torch.Tensor, mask: torch.Tensor, params) -> Tuple[torch.Tensor, torch.Tensor]: 48 | mu, log_scale = params.chunk(2, dim=2) 49 | scale = log_scale.add_(2.0).sigmoid_() 50 
| z = (z - mu).div(scale + 1e-12).mul(mask.unsqueeze(2)) 51 | logdet = scale.log().mul(mask.unsqueeze(2)).view(z.size(0), -1).sum(dim=1) * -1.0 52 | return z, logdet 53 | 54 | 55 | def arccosh(x): 56 | return torch.log(x + torch.sqrt(x.pow(2) - 1)) 57 | 58 | 59 | def arcsinh(x): 60 | return torch.log(x + torch.sqrt(x.pow(2) + 1)) 61 | 62 | 63 | class NLSQ(Transform): 64 | # A = 8 * math.sqrt(3) / 9 - 0.05 # 0.05 is a small number to prevent exactly 0 slope 65 | logA = math.log(8 * math.sqrt(3) / 9 - 0.05) # 0.05 is a small number to prevent exactly 0 slope 66 | 67 | @staticmethod 68 | def get_pseudo_params(params): 69 | a, logb, cprime, logd, g = params.chunk(5, dim=2) 70 | 71 | # for stability 72 | logb = logb.mul_(0.4) 73 | cprime = cprime.mul_(0.3) 74 | logd = logd.mul_(0.4) 75 | 76 | # b = logb.add_(2.0).sigmoid_() 77 | # d = logd.add_(2.0).sigmoid_() 78 | # c = (NLSQ.A * b / d).mul(cprime.tanh_()) 79 | 80 | c = (NLSQ.logA + logb - logd).exp_().mul(cprime.tanh_()) 81 | b = logb.exp_() 82 | d = logd.exp_() 83 | return a, b, c, d, g 84 | 85 | @staticmethod 86 | @overrides 87 | def fwd(z: torch.Tensor, mask: torch.Tensor, params) -> Tuple[torch.Tensor, torch.Tensor]: 88 | a, b, c, d, g = NLSQ.get_pseudo_params(params) 89 | 90 | arg = (d * z).add_(g) 91 | denom = arg.pow(2).add_(1) 92 | c = c / denom 93 | z = (b * z + a + c).mul(mask.unsqueeze(2)) 94 | logdet = torch.log(b - 2 * c * d * arg / denom) 95 | logdet = logdet.mul(mask.unsqueeze(2)).view(z.size(0), -1).sum(dim=1) 96 | return z, logdet 97 | 98 | @staticmethod 99 | @overrides 100 | def bwd(z: torch.Tensor, mask: torch.Tensor, params) -> Tuple[torch.Tensor, torch.Tensor]: 101 | a, b, c, d, g = NLSQ.get_pseudo_params(params) 102 | 103 | # double needed for stability. No effect on overall speed 104 | a = a.double() 105 | b = b.double() 106 | c = c.double() 107 | d = d.double() 108 | g = g.double() 109 | z = z.double() 110 | 111 | aa = -b * d.pow(2) 112 | bb = (z - a) * d.pow(2) - 2 * b * d * g 113 | cc = (z - a) * 2 * d * g - b * (1 + g.pow(2)) 114 | dd = (z - a) * (1 + g.pow(2)) - c 115 | 116 | p = (3 * aa * cc - bb.pow(2)) / (3 * aa.pow(2)) 117 | q = (2 * bb.pow(3) - 9 * aa * bb * cc + 27 * aa.pow(2) * dd) / (27 * aa.pow(3)) 118 | 119 | t = -2 * torch.abs(q) / q * torch.sqrt(torch.abs(p) / 3) 120 | inter_term1 = -3 * torch.abs(q) / (2 * p) * torch.sqrt(3 / torch.abs(p)) 121 | inter_term2 = 1 / 3 * arccosh(torch.abs(inter_term1 - 1) + 1) 122 | t = t * torch.cosh(inter_term2) 123 | 124 | tpos = -2 * torch.sqrt(torch.abs(p) / 3) 125 | inter_term1 = 3 * q / (2 * p) * torch.sqrt(3 / torch.abs(p)) 126 | inter_term2 = 1 / 3 * arcsinh(inter_term1) 127 | tpos = tpos * torch.sinh(inter_term2) 128 | 129 | t[p > 0] = tpos[p > 0] 130 | z = t - bb / (3 * aa) 131 | arg = d * z + g 132 | denom = arg.pow(2) + 1 133 | logdet = torch.log(b - 2 * c * d * arg / denom.pow(2)) 134 | 135 | z = z.float().mul(mask.unsqueeze(2)) 136 | logdet = logdet.float().mul(mask.unsqueeze(2)).view(z.size(0), -1).sum(dim=1) * -1.0 137 | return z, logdet 138 | -------------------------------------------------------------------------------- /flownmt/flows/flow.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Tuple 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class Flow(nn.Module): 7 | """ 8 | Normalizing Flow base class 9 | """ 10 | _registry = dict() 11 | 12 | def __init__(self, inverse): 13 | super(Flow, self).__init__() 14 | self.inverse = inverse 15 | 16 | def forward(self, *inputs, **kwargs) -> 
Tuple[torch.Tensor, torch.Tensor]: 17 | """ 18 | 19 | Args: 20 | *input: input [batch, *input_size] 21 | 22 | Returns: out: Tensor [batch, *input_size], logdet: Tensor [batch] 23 | out, the output of the flow 24 | logdet, the log determinant of :math:`\partial output / \partial input` 25 | """ 26 | raise NotImplementedError 27 | 28 | def backward(self, *inputs, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]: 29 | """ 30 | 31 | Args: 32 | *input: input [batch, *input_size] 33 | 34 | Returns: out: Tensor [batch, *input_size], logdet: Tensor [batch] 35 | out, the output of the flow 36 | logdet, the log determinant of :math:`\partial output / \partial input` 37 | """ 38 | raise NotImplementedError 39 | 40 | def init(self, *input, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]: 41 | raise NotImplementedError 42 | 43 | def fwdpass(self, x: torch.Tensor, *h, init=False, init_scale=1.0, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]: 44 | """ 45 | 46 | Args: 47 | x: Tensor 48 | The random variable before flow 49 | h: list of object 50 | other conditional inputs 51 | init: bool 52 | perform initialization or not (default: False) 53 | init_scale: float 54 | initial scale (default: 1.0) 55 | 56 | Returns: y: Tensor, logdet: Tensor 57 | y, the random variable after flow 58 | logdet, the log determinant of :math:`\partial y / \partial x` 59 | Then the density :math:`\log(p(y)) = \log(p(x)) - logdet` 60 | 61 | """ 62 | if self.inverse: 63 | if init: 64 | raise RuntimeError('inverse flow shold be initialized with backward pass') 65 | else: 66 | return self.backward(x, *h, **kwargs) 67 | else: 68 | if init: 69 | return self.init(x, *h, init_scale=init_scale, **kwargs) 70 | else: 71 | return self.forward(x, *h, **kwargs) 72 | 73 | def bwdpass(self, y: torch.Tensor, *h, init=False, init_scale=1.0, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]: 74 | """ 75 | 76 | Args: 77 | y: Tensor 78 | The random variable after flow 79 | h: list of object 80 | other conditional inputs 81 | init: bool 82 | perform initialization or not (default: False) 83 | init_scale: float 84 | initial scale (default: 1.0) 85 | 86 | Returns: x: Tensor, logdet: Tensor 87 | x, the random variable before flow 88 | logdet, the log determinant of :math:`\partial x / \partial y` 89 | Then the density :math:`\log(p(y)) = \log(p(x)) + logdet` 90 | 91 | """ 92 | if self.inverse: 93 | if init: 94 | return self.init(y, *h, init_scale=init_scale, **kwargs) 95 | else: 96 | return self.forward(y, *h, **kwargs) 97 | else: 98 | if init: 99 | raise RuntimeError('forward flow should be initialzed with forward pass') 100 | else: 101 | return self.backward(y, *h, **kwargs) 102 | 103 | @classmethod 104 | def register(cls, name: str): 105 | Flow._registry[name] = cls 106 | 107 | @classmethod 108 | def by_name(cls, name: str): 109 | return Flow._registry[name] 110 | 111 | @classmethod 112 | def from_params(cls, params: Dict): 113 | raise NotImplementedError 114 | -------------------------------------------------------------------------------- /flownmt/flows/linear.py: -------------------------------------------------------------------------------- 1 | from overrides import overrides 2 | from typing import Dict, Tuple 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.nn import Parameter 7 | 8 | from flownmt.flows.flow import Flow 9 | 10 | 11 | class InvertibleLinearFlow(Flow): 12 | def __init__(self, in_features, inverse=False): 13 | super(InvertibleLinearFlow, self).__init__(inverse) 14 | self.in_features = 
in_features 15 | self.weight = Parameter(torch.Tensor(in_features, in_features)) 16 | self.register_buffer('weight_inv', self.weight.data.clone()) 17 | self.reset_parameters() 18 | 19 | def reset_parameters(self): 20 | nn.init.orthogonal_(self.weight) 21 | self.sync() 22 | 23 | def sync(self): 24 | self.weight_inv.copy_(self.weight.data.inverse()) 25 | 26 | @overrides 27 | def forward(self, input: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 28 | """ 29 | 30 | Args: 31 | input: Tensor 32 | input tensor [batch, N1, N2, ..., Nl, in_features] 33 | mask: Tensor 34 | mask tensor [batch, N1, N2, ...,Nl] 35 | 36 | Returns: out: Tensor , logdet: Tensor 37 | out: [batch, N1, N2, ..., in_features], the output of the flow 38 | logdet: [batch], the log determinant of :math:`\partial output / \partial input` 39 | 40 | """ 41 | dim = input.dim() 42 | # [batch, N1, N2, ..., in_features] 43 | out = F.linear(input, self.weight) 44 | _, logdet = torch.slogdet(self.weight) 45 | if dim > 2: 46 | num = mask.view(out.size(0), -1).sum(dim=1) 47 | logdet = logdet * num 48 | return out, logdet 49 | 50 | @overrides 51 | def backward(self, input: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 52 | """ 53 | 54 | Args: 55 | input: Tensor 56 | input tensor [batch, N1, N2, ..., Nl, in_features] 57 | mask: Tensor 58 | mask tensor [batch, N1, N2, ...,Nl] 59 | 60 | Returns: out: Tensor , logdet: Tensor 61 | out: [batch, N1, N2, ..., in_features], the output of the flow 62 | logdet: [batch], the log determinant of :math:`\partial output / \partial input` 63 | 64 | """ 65 | dim = input.dim() 66 | # [batch, N1, N2, ..., in_features] 67 | out = F.linear(input, self.weight_inv) 68 | _, logdet = torch.slogdet(self.weight_inv) 69 | if dim > 2: 70 | num = mask.view(out.size(0), -1).sum(dim=1) 71 | logdet = logdet * num 72 | return out, logdet 73 | 74 | @overrides 75 | def init(self, data: torch.Tensor, mask: torch.Tensor, init_scale=1.0) -> Tuple[torch.Tensor, torch.Tensor]: 76 | with torch.no_grad(): 77 | return self.forward(data) 78 | 79 | @overrides 80 | def extra_repr(self): 81 | return 'inverse={}, in_features={}'.format(self.inverse, self.in_features) 82 | 83 | @classmethod 84 | def from_params(cls, params: Dict) -> "InvertibleLinearFlow": 85 | return InvertibleLinearFlow(**params) 86 | 87 | 88 | class InvertibleMultiHeadFlow(Flow): 89 | @staticmethod 90 | def _get_heads(in_features): 91 | units = [32, 16, 8] 92 | for unit in units: 93 | if in_features % unit == 0: 94 | return in_features // unit 95 | assert in_features < 8, 'features={}'.format(in_features) 96 | return 1 97 | 98 | def __init__(self, in_features, heads=None, type='A', inverse=False): 99 | super(InvertibleMultiHeadFlow, self).__init__(inverse) 100 | self.in_features = in_features 101 | if heads is None: 102 | heads = InvertibleMultiHeadFlow._get_heads(in_features) 103 | self.heads = heads 104 | self.type = type 105 | assert in_features % heads == 0, 'features ({}) should be divided by heads ({})'.format(in_features, heads) 106 | assert type in ['A', 'B'], 'type should belong to [A, B]' 107 | self.weight = Parameter(torch.Tensor(in_features // heads, in_features // heads)) 108 | self.register_buffer('weight_inv', self.weight.data.clone()) 109 | self.reset_parameters() 110 | 111 | def reset_parameters(self): 112 | nn.init.orthogonal_(self.weight) 113 | self.sync() 114 | 115 | def sync(self): 116 | self.weight_inv.copy_(self.weight.data.inverse()) 117 | 118 | @overrides 119 | def forward(self, input: 
torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 120 | """ 121 | 122 | Args: 123 | input: Tensor 124 | input tensor [batch, N1, N2, ..., Nl, in_features] 125 | mask: Tensor 126 | mask tensor [batch, N1, N2, ...,Nl] 127 | 128 | Returns: out: Tensor , logdet: Tensor 129 | out: [batch, N1, N2, ..., in_features], the output of the flow 130 | logdet: [batch], the log determinant of :math:`\partial output / \partial input` 131 | 132 | """ 133 | size = input.size() 134 | dim = input.dim() 135 | # [batch, N1, N2, ..., heads, in_features/ heads] 136 | if self.type == 'A': 137 | out = input.view(*size[:-1], self.heads, self.in_features // self.heads) 138 | else: 139 | out = input.view(*size[:-1], self.in_features // self.heads, self.heads).transpose(-2, -1) 140 | 141 | out = F.linear(out, self.weight) 142 | if self.type == 'B': 143 | out = out.transpose(-2, -1).contiguous() 144 | out = out.view(*size) 145 | 146 | _, logdet = torch.slogdet(self.weight) 147 | if dim > 2: 148 | num = mask.view(size[0], -1).sum(dim=1) * self.heads 149 | logdet = logdet * num 150 | return out, logdet 151 | 152 | @overrides 153 | def backward(self, input: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 154 | """ 155 | 156 | Args: 157 | input: Tensor 158 | input tensor [batch, N1, N2, ..., Nl, in_features] 159 | mask: Tensor 160 | mask tensor [batch, N1, N2, ...,Nl] 161 | 162 | Returns: out: Tensor , logdet: Tensor 163 | out: [batch, N1, N2, ..., in_features], the output of the flow 164 | logdet: [batch], the log determinant of :math:`\partial output / \partial input` 165 | 166 | """ 167 | size = input.size() 168 | dim = input.dim() 169 | # [batch, N1, N2, ..., heads, in_features/ heads] 170 | if self.type == 'A': 171 | out = input.view(*size[:-1], self.heads, self.in_features // self.heads) 172 | else: 173 | out = input.view(*size[:-1], self.in_features // self.heads, self.heads).transpose(-2, -1) 174 | 175 | out = F.linear(out, self.weight_inv) 176 | if self.type == 'B': 177 | out = out.transpose(-2, -1).contiguous() 178 | out = out.view(*size) 179 | 180 | _, logdet = torch.slogdet(self.weight_inv) 181 | if dim > 2: 182 | num = mask.view(size[0], -1).sum(dim=1) * self.heads 183 | logdet = logdet * num 184 | return out, logdet 185 | 186 | @overrides 187 | def init(self, data: torch.Tensor, mask: torch.Tensor, init_scale=1.0) -> Tuple[torch.Tensor, torch.Tensor]: 188 | with torch.no_grad(): 189 | return self.forward(data, mask) 190 | 191 | @overrides 192 | def extra_repr(self): 193 | return 'inverse={}, in_features={}, heads={}, type={}'.format(self.inverse, self.in_features, self.heads, self.type) 194 | 195 | @classmethod 196 | def from_params(cls, params: Dict) -> "InvertibleMultiHeadFlow": 197 | return InvertibleMultiHeadFlow(**params) 198 | 199 | 200 | InvertibleLinearFlow.register('invertible_linear') 201 | InvertibleMultiHeadFlow.register('invertible_multihead') 202 | -------------------------------------------------------------------------------- /flownmt/flows/parallel/__init__.py: -------------------------------------------------------------------------------- 1 | from flownmt.flows.parallel.data_parallel import DataParallelFlow 2 | -------------------------------------------------------------------------------- /flownmt/flows/parallel/data_parallel.py: -------------------------------------------------------------------------------- 1 | from overrides import overrides 2 | from typing import Tuple 3 | import torch 4 | from torch.nn.parallel.replicate import replicate 5 
| from flownmt.flows.parallel.parallel_apply import parallel_apply 6 | from torch.nn.parallel.scatter_gather import scatter_kwargs, gather 7 | from torch.nn.parallel.data_parallel import _check_balance 8 | 9 | from flownmt.flows.flow import Flow 10 | 11 | 12 | class DataParallelFlow(Flow): 13 | """ 14 | Implements data parallelism at the flow level. 15 | """ 16 | def __init__(self, flow: Flow, device_ids=None, output_device=None, dim=0): 17 | super(DataParallelFlow, self).__init__(flow.inverse) 18 | 19 | if not torch.cuda.is_available(): 20 | self.flow = flow 21 | self.device_ids = [] 22 | return 23 | 24 | if device_ids is None: 25 | device_ids = list(range(torch.cuda.device_count())) 26 | if output_device is None: 27 | output_device = device_ids[0] 28 | self.dim = dim 29 | self.flow = flow 30 | self.device_ids = device_ids 31 | self.output_device = output_device 32 | 33 | _check_balance(self.device_ids) 34 | 35 | if len(self.device_ids) == 1: 36 | self.flow.cuda(device_ids[0]) 37 | 38 | @overrides 39 | def forward(self, *inputs, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]: 40 | if not self.device_ids: 41 | return self.flow.forward(*inputs, **kwargs) 42 | inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) 43 | if len(self.device_ids) == 1: 44 | return self.flow.forward(*inputs[0], **kwargs[0]) 45 | replicas = self.replicate(self.flow, self.device_ids[:len(inputs)]) 46 | outputs = self.parallel_apply(replicas, inputs, kwargs) 47 | return self.gather(outputs, self.output_device) 48 | 49 | @overrides 50 | def backward(self, *inputs, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]: 51 | if not self.device_ids: 52 | return self.flow.backward(*inputs, **kwargs) 53 | inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) 54 | if len(self.device_ids) == 1: 55 | return self.flow.backward(*inputs[0], **kwargs[0]) 56 | replicas = self.replicate(self.flow, self.device_ids[:len(inputs)]) 57 | outputs = self.parallel_apply(replicas, inputs, kwargs, backward=True) 58 | return self.gather(outputs, self.output_device) 59 | 60 | @overrides 61 | def init(self, *input, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]: 62 | return self.flow.init(*input, **kwargs) 63 | 64 | def replicate(self, flow, device_ids): 65 | return replicate(flow, device_ids) 66 | 67 | def scatter(self, inputs, kwargs, device_ids): 68 | return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim) 69 | 70 | def parallel_apply(self, replicas, inputs, kwargs, backward=False): 71 | return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)], backward=backward) 72 | 73 | def gather(self, outputs, output_device): 74 | return gather(outputs, output_device, dim=self.dim) 75 | -------------------------------------------------------------------------------- /flownmt/flows/parallel/parallel_apply.py: -------------------------------------------------------------------------------- 1 | import threading 2 | import torch 3 | 4 | 5 | def get_a_var(obj): 6 | if isinstance(obj, torch.Tensor): 7 | return obj 8 | 9 | if isinstance(obj, list) or isinstance(obj, tuple): 10 | for result in map(get_a_var, obj): 11 | if isinstance(result, torch.Tensor): 12 | return result 13 | if isinstance(obj, dict): 14 | for result in map(get_a_var, obj.items()): 15 | if isinstance(result, torch.Tensor): 16 | return result 17 | return None 18 | 19 | 20 | def parallel_apply(flows, inputs, kwargs_tup=None, devices=None, backward=False): 21 | r"""Applies each `module` in :attr:`modules` in parallel on arguments 22 | contained in 
:attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword) 23 | on each of :attr:`devices`. 24 | 25 | :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and 26 | :attr:`devices` (if given) should all have same length. Moreover, each 27 | element of :attr:`inputs` can either be a single object as the only argument 28 | to a module, or a collection of positional arguments. 29 | """ 30 | assert len(flows) == len(inputs) 31 | if kwargs_tup is not None: 32 | assert len(flows) == len(kwargs_tup) 33 | else: 34 | kwargs_tup = ({},) * len(flows) 35 | if devices is not None: 36 | assert len(flows) == len(devices) 37 | else: 38 | devices = [None] * len(flows) 39 | 40 | lock = threading.Lock() 41 | results = {} 42 | grad_enabled = torch.is_grad_enabled() 43 | 44 | def _worker(i, flow, input, kwargs, device=None, back=False): 45 | torch.set_grad_enabled(grad_enabled) 46 | if device is None: 47 | device = get_a_var(input).get_device() 48 | try: 49 | with torch.cuda.device(device): 50 | # this also avoids accidental slicing of `input` if it is a Tensor 51 | if not isinstance(input, (list, tuple)): 52 | input = (input,) 53 | output = flow.backward(*input, **kwargs) if back else flow.forward(*input, **kwargs) 54 | with lock: 55 | results[i] = output 56 | except Exception as e: 57 | with lock: 58 | results[i] = e 59 | 60 | if len(flows) > 1: 61 | threads = [threading.Thread(target=_worker, 62 | args=(i, flow, input, kwargs, device, backward)) 63 | for i, (flow, input, kwargs, device) in 64 | enumerate(zip(flows, inputs, kwargs_tup, devices))] 65 | 66 | for thread in threads: 67 | thread.start() 68 | for thread in threads: 69 | thread.join() 70 | else: 71 | _worker(0, flows[0], inputs[0], kwargs_tup[0], devices[0], backward) 72 | 73 | outputs = [] 74 | for i in range(len(inputs)): 75 | output = results[i] 76 | if isinstance(output, Exception): 77 | raise output 78 | outputs.append(output) 79 | return outputs 80 | -------------------------------------------------------------------------------- /flownmt/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from flownmt.modules.encoders import * 2 | from flownmt.modules.posteriors import * 3 | from flownmt.modules.decoders import * 4 | from flownmt.modules.priors import * 5 | -------------------------------------------------------------------------------- /flownmt/modules/decoders/__init__.py: -------------------------------------------------------------------------------- 1 | from flownmt.modules.decoders.decoder import Decoder 2 | from flownmt.modules.decoders.simple import SimpleDecoder 3 | from flownmt.modules.decoders.rnn import RecurrentDecoder 4 | from flownmt.modules.decoders.transformer import TransformerDecoder 5 | -------------------------------------------------------------------------------- /flownmt/modules/decoders/decoder.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Tuple 2 | import torch 3 | import torch.nn as nn 4 | 5 | from flownmt.nnet.criterion import LabelSmoothedCrossEntropyLoss 6 | 7 | 8 | class Decoder(nn.Module): 9 | """ 10 | Decoder to predict translations from latent z 11 | """ 12 | _registry = dict() 13 | 14 | def __init__(self, vocab_size, latent_dim, label_smoothing=0., _shared_weight=None): 15 | super(Decoder, self).__init__() 16 | self.readout = nn.Linear(latent_dim, vocab_size, bias=True) 17 | if _shared_weight is not None: 18 | self.readout.weight = _shared_weight 19 | 
nn.init.constant_(self.readout.bias, 0.) 20 | else: 21 | self.reset_parameters(latent_dim) 22 | 23 | if label_smoothing < 1e-5: 24 | self.criterion = nn.CrossEntropyLoss(reduction='none') 25 | elif 1e-5 < label_smoothing < 1.0: 26 | self.criterion = LabelSmoothedCrossEntropyLoss(label_smoothing) 27 | else: 28 | raise ValueError('label smoothing should be less than 1.0.') 29 | 30 | def reset_parameters(self, dim): 31 | # nn.init.normal_(self.readout.weight, mean=0, std=dim ** -0.5) 32 | nn.init.uniform_(self.readout.weight, -0.1, 0.1) 33 | nn.init.constant_(self.readout.bias, 0.) 34 | 35 | def init(self, z, mask, src, src_mask, init_scale=1.0): 36 | raise NotImplementedError 37 | 38 | def decode(self, z: torch.Tensor, mask: torch.Tensor, src: torch.Tensor, src_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 39 | """ 40 | 41 | Args: 42 | z: Tensor 43 | latent code [batch, length, hidden_size] 44 | mask: Tensor 45 | mask [batch, length] 46 | src: Tensor 47 | src encoding [batch, src_length, hidden_size] 48 | src_mask: Tensor 49 | source mask [batch, src_length] 50 | 51 | Returns: Tensor1, Tensor2 52 | Tenser1: decoded word index [batch, length] 53 | Tensor2: log probabilities of decoding [batch] 54 | 55 | """ 56 | raise NotImplementedError 57 | 58 | def loss(self, z: torch.Tensor, target: torch.Tensor, mask: torch.Tensor, 59 | src: torch.Tensor, src_mask: torch.Tensor) -> torch.Tensor: 60 | """ 61 | 62 | Args: 63 | z: Tensor 64 | latent codes [batch, length, hidden_size] 65 | target: LongTensor 66 | target translations [batch, length] 67 | mask: Tensor 68 | masks for target sentence [batch, length] 69 | src: Tensor 70 | src encoding [batch, src_length, hidden_size] 71 | src_mask: Tensor 72 | source mask [batch, src_length] 73 | 74 | Returns: Tensor 75 | tensor for loss [batch] 76 | 77 | """ 78 | raise NotImplementedError 79 | 80 | @classmethod 81 | def register(cls, name: str): 82 | Decoder._registry[name] = cls 83 | 84 | @classmethod 85 | def by_name(cls, name: str): 86 | return Decoder._registry[name] 87 | 88 | @classmethod 89 | def from_params(cls, params: Dict) -> "Decoder": 90 | raise NotImplementedError 91 | 92 | 93 | Decoder.register('simple') 94 | -------------------------------------------------------------------------------- /flownmt/modules/decoders/rnn.py: -------------------------------------------------------------------------------- 1 | from overrides import overrides 2 | from typing import Dict, Tuple 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence 7 | 8 | from flownmt.modules.decoders.decoder import Decoder 9 | from flownmt.nnet.attention import GlobalAttention 10 | 11 | 12 | class RecurrentDecoder(Decoder): 13 | """ 14 | Decoder with Recurrent Neural Networks 15 | """ 16 | def __init__(self, vocab_size, latent_dim, rnn_mode, num_layers, hidden_size, bidirectional=True, 17 | dropout=0.0, dropword=0.0, label_smoothing=0., _shared_weight=None): 18 | super(RecurrentDecoder, self).__init__(vocab_size, latent_dim, 19 | label_smoothing=label_smoothing, 20 | _shared_weight=_shared_weight) 21 | 22 | if rnn_mode == 'RNN': 23 | RNN = nn.RNN 24 | elif rnn_mode == 'LSTM': 25 | RNN = nn.LSTM 26 | elif rnn_mode == 'GRU': 27 | RNN = nn.GRU 28 | else: 29 | raise ValueError('Unknown RNN mode: %s' % rnn_mode) 30 | assert hidden_size % 2 == 0 31 | # RNN for processing latent variables zs 32 | if bidirectional: 33 | self.rnn = RNN(latent_dim, hidden_size // 2, 
num_layers=num_layers, batch_first=True, bidirectional=True) 34 | else: 35 | self.rnn = RNN(latent_dim, hidden_size, num_layers=num_layers, batch_first=True, bidirectional=False) 36 | 37 | self.attn = GlobalAttention(latent_dim, hidden_size, latent_dim, hidden_features=hidden_size) 38 | self.ctx_proj = nn.Sequential(nn.Linear(latent_dim + hidden_size, latent_dim), nn.ELU()) 39 | self.dropout = dropout 40 | self.dropout2d = nn.Dropout2d(dropword) if dropword > 0. else None # drop entire tokens 41 | 42 | def forward(self, z, mask, src, src_mask): 43 | lengths = mask.sum(dim=1).long() 44 | if self.dropout2d is not None: 45 | z = self.dropout2d(z) 46 | 47 | packed_z = pack_padded_sequence(z, lengths, batch_first=True, enforce_sorted=False) 48 | packed_enc, _ = self.rnn(packed_z) 49 | enc, _ = pad_packed_sequence(packed_enc, batch_first=True, total_length=mask.size(1)) 50 | 51 | ctx = self.attn(enc, src, key_mask=src_mask.eq(0)) 52 | ctx = torch.cat([ctx, enc], dim=2) 53 | ctx = F.dropout(self.ctx_proj(ctx), p=self.dropout, training=self.training) 54 | return self.readout(ctx) 55 | 56 | @overrides 57 | def init(self, z, mask, src, src_mask, init_scale=1.0): 58 | with torch.no_grad(): 59 | return self(z, mask, src, src_mask) 60 | 61 | @overrides 62 | def decode(self, z: torch.Tensor, mask: torch.Tensor, src: torch.Tensor, src_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 63 | """ 64 | 65 | Args: 66 | z: Tensor 67 | latent code [batch, length, hidden_size] 68 | mask: Tensor 69 | mask [batch, length] 70 | src: Tensor 71 | src encoding [batch, src_length, hidden_size] 72 | src_mask: Tensor 73 | source mask [batch, src_length] 74 | 75 | Returns: Tensor1, Tensor2 76 | Tenser1: decoded word index [batch, length] 77 | Tensor2: log probabilities of decoding [batch] 78 | 79 | """ 80 | # [batch, length, vocab_size] 81 | log_probs = F.log_softmax(self(z, mask, src, src_mask), dim=2) 82 | # [batch, length] 83 | log_probs, dec = log_probs.max(dim=2) 84 | dec = dec * mask.long() 85 | # [batch] 86 | log_probs = log_probs.mul(mask).sum(dim=1) 87 | return dec, log_probs 88 | 89 | @overrides 90 | def loss(self, z: torch.Tensor, target: torch.Tensor, mask: torch.Tensor, 91 | src: torch.Tensor, src_mask: torch.Tensor) -> torch.Tensor: 92 | """ 93 | 94 | Args: 95 | z: Tensor 96 | latent codes [batch, length, hidden_size] 97 | target: LongTensor 98 | target translations [batch, length] 99 | mask: Tensor 100 | masks for target sentence [batch, length] 101 | src: Tensor 102 | src encoding [batch, src_length, hidden_size] 103 | src_mask: Tensor 104 | source mask [batch, src_length] 105 | 106 | Returns: Tensor 107 | tensor for loss [batch] 108 | 109 | """ 110 | # [batch, length, vocab_size] -> [batch, vocab_size, length] 111 | logits = self(z, mask, src, src_mask).transpose(1, 2) 112 | # [batch, length] 113 | loss = self.criterion(logits, target).mul(mask) 114 | return loss.sum(dim=1) 115 | 116 | @classmethod 117 | def from_params(cls, params: Dict) -> "RecurrentDecoder": 118 | return RecurrentDecoder(**params) 119 | 120 | 121 | RecurrentDecoder.register('rnn') 122 | -------------------------------------------------------------------------------- /flownmt/modules/decoders/simple.py: -------------------------------------------------------------------------------- 1 | from overrides import overrides 2 | from typing import Dict, Tuple 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from flownmt.modules.decoders.decoder import Decoder 8 | from flownmt.nnet.attention import 
GlobalAttention 9 | 10 | class SimpleDecoder(Decoder): 11 | """ 12 | Simple Decoder to predict translations from latent z 13 | """ 14 | 15 | def __init__(self, vocab_size, latent_dim, hidden_size, dropout=0.0, label_smoothing=0., _shared_weight=None): 16 | super(SimpleDecoder, self).__init__(vocab_size, latent_dim, 17 | label_smoothing=label_smoothing, 18 | _shared_weight=_shared_weight) 19 | self.attn = GlobalAttention(latent_dim, latent_dim, latent_dim, hidden_features=hidden_size) 20 | ctx_features = latent_dim * 2 21 | self.ctx_proj = nn.Sequential(nn.Linear(ctx_features, latent_dim), nn.ELU()) 22 | self.dropout = dropout 23 | 24 | @overrides 25 | def forward(self, z, src, src_mask): 26 | ctx = self.attn(z, src, key_mask=src_mask.eq(0)) 27 | ctx = F.dropout(self.ctx_proj(torch.cat([ctx, z], dim=2)), p=self.dropout, training=self.training) 28 | return self.readout(ctx) 29 | 30 | @overrides 31 | def init(self, z, mask, src, src_mask, init_scale=1.0): 32 | with torch.no_grad(): 33 | self(z, src, src_mask) 34 | 35 | @overrides 36 | def decode(self, z: torch.Tensor, mask: torch.Tensor, src: torch.Tensor, src_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 37 | """ 38 | 39 | Args: 40 | z: Tensor 41 | latent code [batch, length, hidden_size] 42 | mask: Tensor 43 | mask [batch, length] 44 | src: Tensor 45 | src encoding [batch, src_length, hidden_size] 46 | src_mask: Tensor 47 | source mask [batch, src_length] 48 | 49 | Returns: Tensor1, Tensor2 50 | Tenser1: decoded word index [batch, length] 51 | Tensor2: log probabilities of decoding [batch] 52 | 53 | """ 54 | # [batch, length, vocab_size] 55 | log_probs = F.log_softmax(self(z, src, src_mask), dim=2) 56 | # [batch, length] 57 | log_probs, dec = log_probs.max(dim=2) 58 | dec = dec * mask.long() 59 | # [batch] 60 | log_probs = log_probs.mul(mask).sum(dim=1) 61 | return dec, log_probs 62 | 63 | @overrides 64 | def loss(self, z: torch.Tensor, target: torch.Tensor, mask: torch.Tensor, 65 | src: torch.Tensor, src_mask: torch.Tensor) -> torch.Tensor: 66 | """ 67 | 68 | Args: 69 | z: Tensor 70 | latent codes [batch, length, hidden_size] 71 | target: LongTensor 72 | target translations [batch, length] 73 | mask: Tensor 74 | masks for target sentence [batch, length] 75 | src: Tensor 76 | src encoding [batch, src_length, hidden_size] 77 | src_mask: Tensor 78 | source mask [batch, src_length] 79 | 80 | Returns: Tensor 81 | tensor for loss [batch] 82 | 83 | """ 84 | # [batch, length, vocab_size] -> [batch, vocab_size, length] 85 | logits = self(z, src, src_mask).transpose(1, 2) 86 | # [batch, length] 87 | loss = self.criterion(logits, target).mul(mask) 88 | return loss.sum(dim=1) 89 | 90 | @classmethod 91 | def from_params(cls, params: Dict) -> "SimpleDecoder": 92 | return SimpleDecoder(**params) 93 | 94 | 95 | SimpleDecoder.register('simple') 96 | -------------------------------------------------------------------------------- /flownmt/modules/decoders/transformer.py: -------------------------------------------------------------------------------- 1 | from overrides import overrides 2 | from typing import Dict, Tuple 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from flownmt.modules.decoders.decoder import Decoder 8 | from flownmt.nnet.attention import MultiHeadAttention 9 | from flownmt.nnet.transformer import TransformerDecoderLayer 10 | from flownmt.nnet.positional_encoding import PositionalEncoding 11 | 12 | 13 | class TransformerDecoder(Decoder): 14 | """ 15 | Decoder with Transformer 16 | """ 17 
| def __init__(self, vocab_size, latent_dim, num_layers, hidden_size, heads, label_smoothing=0., 18 | dropout=0.0, dropword=0.0, max_length=100, _shared_weight=None): 19 | super(TransformerDecoder, self).__init__(vocab_size, latent_dim, 20 | label_smoothing=label_smoothing, 21 | _shared_weight=_shared_weight) 22 | self.pos_enc = PositionalEncoding(latent_dim, None, max_length + 1) 23 | self.pos_attn = MultiHeadAttention(latent_dim, heads, dropout=dropout) 24 | layers = [TransformerDecoderLayer(latent_dim, hidden_size, heads, dropout=dropout) for _ in range(num_layers)] 25 | self.layers = nn.ModuleList(layers) 26 | self.dropword = dropword # drop entire tokens 27 | 28 | def forward(self, z, mask, src, src_mask): 29 | z = F.dropout2d(z, p=self.dropword, training=self.training) 30 | # [batch, length, latent_dim] 31 | pos_enc = self.pos_enc(z) * mask.unsqueeze(2) 32 | 33 | key_mask = mask.eq(0) 34 | ctx = self.pos_attn(pos_enc, z, z, key_mask) 35 | 36 | src_mask = src_mask.eq(0) 37 | for layer in self.layers: 38 | ctx = layer(ctx, key_mask, src, src_mask) 39 | 40 | return self.readout(ctx) 41 | 42 | @overrides 43 | def init(self, z, mask, src, src_mask, init_scale=1.0): 44 | with torch.no_grad(): 45 | return self(z, mask, src, src_mask) 46 | 47 | @overrides 48 | def decode(self, z: torch.Tensor, mask: torch.Tensor, src: torch.Tensor, src_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 49 | """ 50 | 51 | Args: 52 | z: Tensor 53 | latent code [batch, length, hidden_size] 54 | mask: Tensor 55 | mask [batch, length] 56 | src: Tensor 57 | src encoding [batch, src_length, hidden_size] 58 | src_mask: Tensor 59 | source mask [batch, src_length] 60 | 61 | Returns: Tensor1, Tensor2 62 | Tenser1: decoded word index [batch, length] 63 | Tensor2: log probabilities of decoding [batch] 64 | 65 | """ 66 | # [batch, length, vocab_size] 67 | log_probs = F.log_softmax(self(z, mask, src, src_mask), dim=2) 68 | # [batch, length] 69 | log_probs, dec = log_probs.max(dim=2) 70 | dec = dec * mask.long() 71 | # [batch] 72 | log_probs = log_probs.mul(mask).sum(dim=1) 73 | return dec, log_probs 74 | 75 | @overrides 76 | def loss(self, z: torch.Tensor, target: torch.Tensor, mask: torch.Tensor, 77 | src: torch.Tensor, src_mask: torch.Tensor) -> torch.Tensor: 78 | """ 79 | 80 | Args: 81 | z: Tensor 82 | latent codes [batch, length, hidden_size] 83 | target: LongTensor 84 | target translations [batch, length] 85 | mask: Tensor 86 | masks for target sentence [batch, length] 87 | src: Tensor 88 | src encoding [batch, src_length, hidden_size] 89 | src_mask: Tensor 90 | source mask [batch, src_length] 91 | 92 | Returns: Tensor 93 | tensor for loss [batch] 94 | 95 | """ 96 | # [batch, length, vocab_size] -> [batch, vocab_size, length] 97 | logits = self(z, mask, src, src_mask).transpose(1, 2) 98 | # [batch, length] 99 | loss = self.criterion(logits, target).mul(mask) 100 | return loss.sum(dim=1) 101 | 102 | @classmethod 103 | def from_params(cls, params: Dict) -> "TransformerDecoder": 104 | return TransformerDecoder(**params) 105 | 106 | 107 | TransformerDecoder.register('transformer') 108 | -------------------------------------------------------------------------------- /flownmt/modules/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | from flownmt.modules.encoders.encoder import Encoder 2 | from flownmt.modules.encoders.rnn import RecurrentEncoder 3 | from flownmt.modules.encoders.transformer import TransformerEncoder 4 | 
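# A minimal sketch of the registry pattern shared by the Encoder, Decoder and
# Flow base classes in this repo: each concrete subclass calls `register(name)`
# at import time, and configuration dictionaries are resolved back into modules
# via `by_name(...)` + `from_params(...)`. The parameter values below are
# illustrative placeholders, not taken from the shipped configs under
# experiments/configs/.
from flownmt.modules.decoders import Decoder  # package import registers the subclasses

decoder_cls = Decoder.by_name('transformer')   # -> TransformerDecoder
decoder = decoder_cls.from_params({
    'vocab_size': 32000, 'latent_dim': 256, 'num_layers': 3,
    'hidden_size': 512, 'heads': 4, 'dropout': 0.1, 'max_length': 100,
})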
-------------------------------------------------------------------------------- /flownmt/modules/encoders/encoder.py: -------------------------------------------------------------------------------- 1 | from overrides import overrides 2 | from typing import Dict, Tuple 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class Encoder(nn.Module): 8 | """ 9 | Src Encoder to encode source sentence 10 | """ 11 | _registry = dict() 12 | 13 | def __init__(self, vocab_size, embed_dim, padding_idx): 14 | super(Encoder, self).__init__() 15 | self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx) 16 | self.reset_parameters() 17 | 18 | def reset_parameters(self): 19 | nn.init.uniform_(self.embed.weight, -0.1, 0.1) 20 | if self.embed.padding_idx is not None: 21 | with torch.no_grad(): 22 | self.embed.weight[self.embed.padding_idx].fill_(0) 23 | 24 | @overrides 25 | def forward(self, src_sents, masks=None) -> Tuple[torch.Tensor, torch.Tensor]: 26 | """ 27 | Encoding src sentences into src encoding representations. 28 | Args: 29 | src_sents: Tensor [batch, length] 30 | masks: Tensor or None [batch, length] 31 | 32 | Returns: Tensor1, Tensor2 33 | Tensor1: tensor for src encoding [batch, length, hidden_size] 34 | Tensor2: tensor for global state [batch, hidden_size] 35 | 36 | """ 37 | raise NotImplementedError 38 | 39 | def init(self, src_sents, masks=None, init_scale=1.0) -> torch.Tensor: 40 | raise NotImplementedError 41 | 42 | @classmethod 43 | def register(cls, name: str): 44 | Encoder._registry[name] = cls 45 | 46 | @classmethod 47 | def by_name(cls, name: str): 48 | return Encoder._registry[name] 49 | 50 | @classmethod 51 | def from_params(cls, params: Dict): 52 | raise NotImplementedError 53 | -------------------------------------------------------------------------------- /flownmt/modules/encoders/rnn.py: -------------------------------------------------------------------------------- 1 | from overrides import overrides 2 | from typing import Dict, Tuple 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from flownmt.modules.encoders.encoder import Encoder 8 | from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence 9 | 10 | 11 | class RecurrentCore(nn.Module): 12 | def __init__(self, embed, rnn_mode, num_layers, latent_dim, hidden_size, dropout=0.0): 13 | super(RecurrentCore, self).__init__() 14 | self.embed = embed 15 | 16 | if rnn_mode == 'RNN': 17 | RNN = nn.RNN 18 | elif rnn_mode == 'LSTM': 19 | RNN = nn.LSTM 20 | elif rnn_mode == 'GRU': 21 | RNN = nn.GRU 22 | else: 23 | raise ValueError('Unknown RNN mode: %s' % rnn_mode) 24 | assert hidden_size % 2 == 0 25 | self.rnn = RNN(embed.embedding_dim, hidden_size // 2, 26 | num_layers=num_layers, batch_first=True, bidirectional=True) 27 | self.enc_proj = nn.Sequential(nn.Linear(hidden_size, latent_dim), nn.ELU()) 28 | self.reset_parameters() 29 | 30 | def reset_parameters(self): 31 | nn.init.constant_(self.enc_proj[0].bias, 0.) 
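    # The forward pass below packs the embedded source by true length, runs a
    # bi-directional RNN, projects the hidden states to latent_dim, and zeroes
    # padded positions via the mask. The per-sentence context `ctx` is read
    # from the last *valid* time step (index lengths - 1), not the last padded
    # slot.
    #
    # Shape sketch (assumed, for orientation only):
    #   src_sents: [batch, length] token ids    masks: [batch, length] in {0, 1}
    #   src_enc:   [batch, length, latent_dim]  ctx:   [batch, latent_dim]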
32 | 33 | @overrides 34 | def forward(self, src_sents, masks) -> Tuple[torch.Tensor, torch.Tensor]: 35 | word_embed = F.dropout(self.embed(src_sents), p=0.2, training=self.training) 36 | 37 | lengths = masks.sum(dim=1).long() 38 | packed_embed = pack_padded_sequence(word_embed, lengths, batch_first=True, enforce_sorted=False) 39 | packed_enc, _ = self.rnn(packed_embed) 40 | # [batch, length, hidden_size] 41 | src_enc, _ = pad_packed_sequence(packed_enc, batch_first=True, total_length=masks.size(1)) 42 | # [batch, length, latent_dim] 43 | src_enc = self.enc_proj(src_enc).mul(masks.unsqueeze(2)) 44 | 45 | # [batch, latent_dim] 46 | batch = src_sents.size(0) 47 | idx = lengths - 1 48 | batch_idx = torch.arange(0, batch).long().to(idx.device) 49 | ctx = src_enc[batch_idx, idx] 50 | return src_enc, ctx 51 | 52 | 53 | class RecurrentEncoder(Encoder): 54 | """ 55 | Src Encoder to encode source sentence with Recurrent Neural Networks 56 | """ 57 | 58 | def __init__(self, vocab_size, embed_dim, padding_idx, rnn_mode, num_layers, latent_dim, hidden_size, dropout=0.0): 59 | super(RecurrentEncoder, self).__init__(vocab_size, embed_dim, padding_idx) 60 | self.core = RecurrentCore(self.embed, rnn_mode, num_layers, latent_dim, hidden_size, dropout=dropout) 61 | 62 | @overrides 63 | def forward(self, src_sents, masks=None) -> Tuple[torch.Tensor, torch.Tensor]: 64 | src_enc, ctx = self.core(src_sents, masks=masks) 65 | return src_enc, ctx 66 | 67 | def init(self, src_sents, masks=None, init_scale=1.0) -> torch.Tensor: 68 | with torch.no_grad(): 69 | src_enc, _ = self.core(src_sents, masks=masks) 70 | return src_enc 71 | 72 | @classmethod 73 | def from_params(cls, params: Dict) -> "RecurrentEncoder": 74 | return RecurrentEncoder(**params) 75 | 76 | 77 | RecurrentEncoder.register('rnn') 78 | -------------------------------------------------------------------------------- /flownmt/modules/encoders/transformer.py: -------------------------------------------------------------------------------- 1 | from overrides import overrides 2 | from typing import Dict, Tuple 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from flownmt.modules.encoders.encoder import Encoder 9 | from flownmt.nnet.transformer import TransformerEncoderLayer 10 | from flownmt.nnet.positional_encoding import PositionalEncoding 11 | 12 | 13 | class TransformerCore(nn.Module): 14 | def __init__(self, embed, num_layers, latent_dim, hidden_size, heads, dropout=0.0, max_length=100): 15 | super(TransformerCore, self).__init__() 16 | self.embed = embed 17 | self.padding_idx = embed.padding_idx 18 | 19 | embed_dim = embed.embedding_dim 20 | self.embed_scale = math.sqrt(embed_dim) 21 | assert embed_dim == latent_dim 22 | layers = [TransformerEncoderLayer(latent_dim, hidden_size, heads, dropout=dropout) for _ in range(num_layers)] 23 | self.layers = nn.ModuleList(layers) 24 | self.pos_enc = PositionalEncoding(latent_dim, self.padding_idx, max_length + 1) 25 | self.reset_parameters() 26 | 27 | def reset_parameters(self): 28 | pass 29 | 30 | @overrides 31 | def forward(self, src_sents, masks) -> Tuple[torch.Tensor, torch.Tensor]: 32 | # [batch, leagth, embed_dim] 33 | x = self.embed_scale * self.embed(src_sents) 34 | x += self.pos_enc(src_sents) 35 | x = F.dropout(x, p=0.2, training=self.training) 36 | 37 | # [batch, leagth, latent_dim] 38 | key_mask = masks.eq(0) 39 | if not key_mask.any(): 40 | key_mask = None 41 | 42 | for layer in self.layers: 43 | x = layer(x, key_mask) 44 | 45 | x *= 
masks.unsqueeze(2) 46 | # [batch, latent_dim] 47 | batch = src_sents.size(0) 48 | idx = masks.sum(dim=1).long() - 1 49 | batch_idx = torch.arange(0, batch).long().to(idx.device) 50 | ctx = x[batch_idx, idx] 51 | return x, ctx 52 | 53 | 54 | class TransformerEncoder(Encoder): 55 | """ 56 | Src Encoder to encode source sentence with Transformer 57 | """ 58 | 59 | def __init__(self, vocab_size, embed_dim, padding_idx, num_layers, latent_dim, hidden_size, heads, dropout=0.0, max_length=100): 60 | super(TransformerEncoder, self).__init__(vocab_size, embed_dim, padding_idx) 61 | self.core = TransformerCore(self.embed, num_layers, latent_dim, hidden_size, heads, dropout=dropout, max_length=max_length) 62 | 63 | @overrides 64 | def forward(self, src_sents, masks=None) -> Tuple[torch.Tensor, torch.Tensor]: 65 | src_enc, ctx = self.core(src_sents, masks=masks) 66 | return src_enc, ctx 67 | 68 | def init(self, src_sents, masks=None, init_scale=1.0) -> torch.Tensor: 69 | with torch.no_grad(): 70 | src_enc, _ = self.core(src_sents, masks=masks) 71 | return src_enc 72 | 73 | @classmethod 74 | def from_params(cls, params: Dict) -> "TransformerEncoder": 75 | return TransformerEncoder(**params) 76 | 77 | 78 | TransformerEncoder.register('transformer') 79 | -------------------------------------------------------------------------------- /flownmt/modules/posteriors/__init__.py: -------------------------------------------------------------------------------- 1 | from flownmt.modules.posteriors.posterior import Posterior 2 | from flownmt.modules.posteriors.rnn import RecurrentPosterior 3 | from flownmt.modules.posteriors.shift_rnn import ShiftRecurrentPosterior 4 | from flownmt.modules.posteriors.transformer import TransformerPosterior 5 | -------------------------------------------------------------------------------- /flownmt/modules/posteriors/posterior.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Dict, Tuple 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class Posterior(nn.Module): 8 | """ 9 | posterior class 10 | """ 11 | _registry = dict() 12 | 13 | def __init__(self, vocab_size, embed_dim, padding_idx, _shared_embed=None): 14 | super(Posterior, self).__init__() 15 | if _shared_embed is None: 16 | self.tgt_embed = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx) 17 | self.reset_parameters() 18 | else: 19 | self.tgt_embed = _shared_embed 20 | 21 | def reset_parameters(self): 22 | nn.init.uniform_(self.tgt_embed.weight, -0.1, 0.1) 23 | if self.tgt_embed.padding_idx is not None: 24 | with torch.no_grad(): 25 | self.tgt_embed.weight[self.tgt_embed.padding_idx].fill_(0) 26 | 27 | def target_embed_weight(self): 28 | raise NotImplementedError 29 | 30 | @staticmethod 31 | def reparameterize(mu, logvar, mask, nsamples=1, random=True): 32 | # [batch, length, dim] 33 | size = mu.size() 34 | std = logvar.mul(0.5).exp() 35 | # [batch, nsamples, length, dim] 36 | if random: 37 | eps = torch.randn(size[0], nsamples, *size[1:], device=mu.device) 38 | eps *= mask.view(size[0], 1, size[1], 1) 39 | else: 40 | eps = mu.new_zeros(size[0], nsamples, *size[1:]) 41 | return eps.mul(std.unsqueeze(1)).add(mu.unsqueeze(1)), eps 42 | 43 | 44 | @staticmethod 45 | def log_probability(z, eps, mu, logvar, mask): 46 | size = eps.size() 47 | nz = size[3] 48 | # [batch, nsamples, length, nz] 49 | log_probs = logvar.unsqueeze(1) + eps.pow(2) 50 | # [batch, 1] 51 | cc = mask.sum(dim=1, keepdim=True) * (math.log(math.pi * 2.) 
* nz) 52 | # [batch, nsamples, length * nz] --> [batch, nsamples] 53 | log_probs = log_probs.view(size[0], size[1], -1).sum(dim=2) + cc 54 | return log_probs * -0.5 55 | 56 | def forward(self, tgt_sents, tgt_masks, src_enc, src_masks): 57 | raise NotImplementedError 58 | 59 | def init(self, tgt_sents, tgt_masks, src_enc, src_masks, init_scale=1.0, init_mu=True, init_var=True) -> Tuple[torch.Tensor, torch.Tensor]: 60 | raise NotImplementedError 61 | 62 | def sample(self, tgt_sents: torch.Tensor, tgt_masks: torch.Tensor, 63 | src_enc: torch.Tensor, src_masks: torch.Tensor, 64 | nsamples: int =1, random=True) -> Tuple[torch.Tensor, torch.Tensor]: 65 | """ 66 | 67 | Args: 68 | tgt_sents: Tensor [batch, tgt_length] 69 | tensor for target sentences 70 | tgt_masks: Tensor [batch, tgt_length] 71 | tensor for target masks 72 | src_enc: Tensor [batch, src_length, hidden_size] 73 | tensor for source encoding 74 | src_masks: Tensor [batch, src_length] 75 | tensor for source masks 76 | nsamples: int 77 | number of samples 78 | random: bool 79 | if True, perform random sampling. Otherwise, return mean. 80 | 81 | Returns: Tensor1, Tensor2 82 | Tensor1: samples from the posterior [batch, nsamples, tgt_length, nz] 83 | Tensor2: log probabilities [batch, nsamples] 84 | 85 | """ 86 | raise NotImplementedError 87 | 88 | @classmethod 89 | def register(cls, name: str): 90 | Posterior._registry[name] = cls 91 | 92 | @classmethod 93 | def by_name(cls, name: str): 94 | return Posterior._registry[name] 95 | 96 | @classmethod 97 | def from_params(cls, params: Dict): 98 | raise NotImplementedError 99 | -------------------------------------------------------------------------------- /flownmt/modules/posteriors/rnn.py: -------------------------------------------------------------------------------- 1 | from overrides import overrides 2 | from typing import Tuple, Dict 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence 7 | 8 | from flownmt.nnet.weightnorm import LinearWeightNorm 9 | from flownmt.modules.posteriors.posterior import Posterior 10 | from flownmt.nnet.attention import GlobalAttention 11 | 12 | 13 | class RecurrentCore(nn.Module): 14 | def __init__(self, embed, rnn_mode, num_layers, latent_dim, hidden_size, use_attn=False, dropout=0.0, dropword=0.0): 15 | super(RecurrentCore, self).__init__() 16 | if rnn_mode == 'RNN': 17 | RNN = nn.RNN 18 | elif rnn_mode == 'LSTM': 19 | RNN = nn.LSTM 20 | elif rnn_mode == 'GRU': 21 | RNN = nn.GRU 22 | else: 23 | raise ValueError('Unknown RNN mode: %s' % rnn_mode) 24 | assert hidden_size % 2 == 0 25 | self.tgt_embed = embed 26 | self.rnn = RNN(embed.embedding_dim, hidden_size // 2, 27 | num_layers=num_layers, batch_first=True, bidirectional=True) 28 | self.use_attn = use_attn 29 | if use_attn: 30 | self.attn = GlobalAttention(latent_dim, hidden_size, hidden_size, hidden_features=hidden_size) 31 | self.ctx_proj = nn.Sequential(nn.Linear(hidden_size * 2, hidden_size), nn.ELU()) 32 | else: 33 | self.ctx_proj = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.ELU()) 34 | self.dropout = dropout 35 | self.dropout2d = nn.Dropout2d(dropword) if dropword > 0. 
else None # drop entire tokens 36 | self.mu = LinearWeightNorm(hidden_size, latent_dim, bias=True) 37 | self.logvar = LinearWeightNorm(hidden_size, latent_dim, bias=True) 38 | 39 | @overrides 40 | def forward(self, tgt_sents, tgt_masks, src_enc, src_masks): 41 | tgt_embed = self.tgt_embed(tgt_sents) 42 | if self.dropout2d is not None: 43 | tgt_embed = self.dropout2d(tgt_embed) 44 | lengths = tgt_masks.sum(dim=1).long() 45 | packed_embed = pack_padded_sequence(tgt_embed, lengths, batch_first=True, enforce_sorted=False) 46 | packed_enc, _ = self.rnn(packed_embed) 47 | tgt_enc, _ = pad_packed_sequence(packed_enc, batch_first=True, total_length=tgt_masks.size(1)) 48 | if self.use_attn: 49 | ctx = self.attn(tgt_enc, src_enc, key_mask=src_masks.eq(0)) 50 | ctx = torch.cat([tgt_enc, ctx], dim=2) 51 | else: 52 | ctx = tgt_enc 53 | ctx = F.dropout(self.ctx_proj(ctx), p=self.dropout, training=self.training) 54 | mu = self.mu(ctx) * tgt_masks.unsqueeze(2) 55 | logvar = self.logvar(ctx) * tgt_masks.unsqueeze(2) 56 | return mu, logvar 57 | 58 | def init(self, tgt_sents, tgt_masks, src_enc, src_masks, init_scale=1.0, init_mu=True, init_var=True): 59 | with torch.no_grad(): 60 | tgt_embed = self.tgt_embed(tgt_sents) 61 | if self.dropout2d is not None: 62 | tgt_embed = self.dropout2d(tgt_embed) 63 | lengths = tgt_masks.sum(dim=1).long() 64 | packed_embed = pack_padded_sequence(tgt_embed, lengths, batch_first=True, enforce_sorted=False) 65 | packed_enc, _ = self.rnn(packed_embed) 66 | tgt_enc, _ = pad_packed_sequence(packed_enc, batch_first=True, total_length=tgt_masks.size(1)) 67 | if self.use_attn: 68 | ctx = self.attn.init(tgt_enc, src_enc, key_mask=src_masks.eq(0), init_scale=init_scale) 69 | ctx = torch.cat([tgt_enc, ctx], dim=2) 70 | else: 71 | ctx = tgt_enc 72 | ctx = F.dropout(self.ctx_proj(ctx), p=self.dropout, training=self.training) 73 | mu = self.mu.init(ctx, init_scale=0.05 * init_scale) if init_mu else self.mu(ctx) 74 | logvar = self.logvar.init(ctx, init_scale=0.05 * init_scale) if init_var else self.logvar(ctx) 75 | mu = mu * tgt_masks.unsqueeze(2) 76 | logvar = logvar * tgt_masks.unsqueeze(2) 77 | return mu, logvar 78 | 79 | 80 | class RecurrentPosterior(Posterior): 81 | """ 82 | Posterior with Recurrent Neural Networks 83 | """ 84 | def __init__(self, vocab_size, embed_dim, padding_idx, rnn_mode, num_layers, latent_dim, hidden_size, 85 | use_attn=False, dropout=0.0, dropword=0.0, _shared_embed=None): 86 | super(RecurrentPosterior, self).__init__(vocab_size, embed_dim, padding_idx, _shared_embed=_shared_embed) 87 | self.core = RecurrentCore(self.tgt_embed, rnn_mode, num_layers, latent_dim, hidden_size, 88 | use_attn=use_attn, dropout=dropout, dropword=dropword) 89 | 90 | def target_embed_weight(self): 91 | if isinstance(self.core, nn.DataParallel): 92 | return self.core.module.tgt_embedd.weight 93 | else: 94 | return self.core.tgt_embed.weight 95 | 96 | @overrides 97 | def forward(self, tgt_sents, tgt_masks, src_enc, src_masks): 98 | return self.core(tgt_sents, tgt_masks, src_enc, src_masks) 99 | 100 | @overrides 101 | def sample(self, tgt_sents: torch.Tensor, tgt_masks: torch.Tensor, 102 | src_enc: torch.Tensor, src_masks: torch.Tensor, 103 | nsamples: int =1, random=True) -> Tuple[torch.Tensor, torch.Tensor]: 104 | mu, logvar = self.core(tgt_sents, tgt_masks, src_enc, src_masks) 105 | z, eps = Posterior.reparameterize(mu, logvar, tgt_masks, nsamples=nsamples, random=random) 106 | log_probs = Posterior.log_probability(z, eps, mu, logvar, tgt_masks) 107 | return z, log_probs 108 | 109 | 
@overrides 110 | def init(self, tgt_sents, tgt_masks, src_enc, src_masks, init_scale=1.0, init_mu=True, init_var=True) -> Tuple[torch.Tensor, torch.Tensor]: 111 | mu, logvar = self.core.init(tgt_sents, tgt_masks, src_enc, src_masks, 112 | init_scale=init_scale, init_mu=init_mu, init_var=init_var) 113 | z, eps = Posterior.reparameterize(mu, logvar, tgt_masks, random=True) 114 | log_probs = Posterior.log_probability(z, eps, mu, logvar, tgt_masks) 115 | z = z.squeeze(1) 116 | log_probs = log_probs.squeeze(1) 117 | return z, log_probs 118 | 119 | @classmethod 120 | def from_params(cls, params: Dict) -> "RecurrentPosterior": 121 | return RecurrentPosterior(**params) 122 | 123 | 124 | RecurrentPosterior.register('rnn') 125 | -------------------------------------------------------------------------------- /flownmt/modules/posteriors/shift_rnn.py: -------------------------------------------------------------------------------- 1 | from overrides import overrides 2 | from typing import Tuple, Dict 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence 7 | 8 | from flownmt.nnet.weightnorm import LinearWeightNorm 9 | from flownmt.modules.posteriors.posterior import Posterior 10 | from flownmt.nnet.attention import GlobalAttention 11 | 12 | 13 | class ShiftRecurrentCore(nn.Module): 14 | def __init__(self, embed, rnn_mode, num_layers, latent_dim, hidden_size, bidirectional=True, use_attn=False, dropout=0.0, dropword=0.0): 15 | super(ShiftRecurrentCore, self).__init__() 16 | if rnn_mode == 'RNN': 17 | RNN = nn.RNN 18 | elif rnn_mode == 'LSTM': 19 | RNN = nn.LSTM 20 | elif rnn_mode == 'GRU': 21 | RNN = nn.GRU 22 | else: 23 | raise ValueError('Unknown RNN mode: %s' % rnn_mode) 24 | assert hidden_size % 2 == 0 25 | self.tgt_embed = embed 26 | assert num_layers == 1 27 | self.bidirectional = bidirectional 28 | if bidirectional: 29 | self.rnn = RNN(embed.embedding_dim, hidden_size // 2, num_layers=1, batch_first=True, bidirectional=True) 30 | else: 31 | self.rnn = RNN(embed.embedding_dim, hidden_size, num_layers=1, batch_first=True, bidirectional=False) 32 | self.use_attn = use_attn 33 | if use_attn: 34 | self.attn = GlobalAttention(latent_dim, hidden_size, hidden_size) 35 | self.ctx_proj = nn.Sequential(nn.Linear(hidden_size * 2, hidden_size), nn.ELU()) 36 | else: 37 | self.ctx_proj = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.ELU()) 38 | self.dropout = dropout 39 | self.dropout2d = nn.Dropout2d(dropword) if dropword > 0. 
else None # drop entire tokens 40 | self.mu = LinearWeightNorm(hidden_size, latent_dim, bias=True) 41 | self.logvar = LinearWeightNorm(hidden_size, latent_dim, bias=True) 42 | 43 | @overrides 44 | def forward(self, tgt_sents, tgt_masks, src_enc, src_masks): 45 | tgt_embed = self.tgt_embed(tgt_sents) 46 | if self.dropout2d is not None: 47 | tgt_embed = self.dropout2d(tgt_embed) 48 | lengths = tgt_masks.sum(dim=1).long() 49 | packed_embed = pack_padded_sequence(tgt_embed, lengths, batch_first=True, enforce_sorted=False) 50 | packed_enc, _ = self.rnn(packed_embed) 51 | tgt_enc, _ = pad_packed_sequence(packed_enc, batch_first=True, total_length=tgt_masks.size(1)) 52 | 53 | if self.bidirectional: 54 | # split into fwd and bwd 55 | fwd_tgt_enc, bwd_tgt_enc = tgt_enc.chunk(2, dim=2) # (batch_size, seq_len, hidden_size // 2) 56 | pad_vector = fwd_tgt_enc.new_zeros((fwd_tgt_enc.size(0), 1, fwd_tgt_enc.size(2))) 57 | pad_fwd_tgt_enc = torch.cat([pad_vector, fwd_tgt_enc], dim=1) 58 | pad_bwd_tgt_enc = torch.cat([bwd_tgt_enc, pad_vector], dim=1) 59 | tgt_enc = torch.cat([pad_fwd_tgt_enc[:, :-1], pad_bwd_tgt_enc[:, 1:]], dim=2) 60 | else: 61 | pad_vector = tgt_enc.new_zeros((tgt_enc.size(0), 1, tgt_enc.size(2))) 62 | tgt_enc = torch.cat([pad_vector, tgt_enc], dim=1)[:, :-1] 63 | 64 | if self.use_attn: 65 | ctx = self.attn(tgt_enc, src_enc, key_mask=src_masks.eq(0)) 66 | ctx = torch.cat([tgt_enc, ctx], dim=2) 67 | else: 68 | ctx = tgt_enc 69 | ctx = F.dropout(self.ctx_proj(ctx), p=self.dropout, training=self.training) 70 | mu = self.mu(ctx) * tgt_masks.unsqueeze(2) 71 | logvar = self.logvar(ctx) * tgt_masks.unsqueeze(2) 72 | return mu, logvar 73 | 74 | def init(self, tgt_sents, tgt_masks, src_enc, src_masks, init_scale=1.0, init_mu=True, init_var=True): 75 | with torch.no_grad(): 76 | tgt_embed = self.tgt_embed(tgt_sents) 77 | if self.dropout2d is not None: 78 | tgt_embed = self.dropout2d(tgt_embed) 79 | lengths = tgt_masks.sum(dim=1).long() 80 | packed_embed = pack_padded_sequence(tgt_embed, lengths, batch_first=True, enforce_sorted=False) 81 | packed_enc, _ = self.rnn(packed_embed) 82 | tgt_enc, _ = pad_packed_sequence(packed_enc, batch_first=True, total_length=tgt_masks.size(1)) 83 | 84 | if self.bidirectional: 85 | fwd_tgt_enc, bwd_tgt_enc = tgt_enc.chunk(2, dim=2) # (batch_size, seq_len, hidden_size // 2) 86 | pad_vector = fwd_tgt_enc.new_zeros((fwd_tgt_enc.size(0), 1, fwd_tgt_enc.size(2))) 87 | pad_fwd_tgt_enc = torch.cat([pad_vector, fwd_tgt_enc], dim=1) 88 | pad_bwd_tgt_enc = torch.cat([bwd_tgt_enc, pad_vector], dim=1) 89 | tgt_enc = torch.cat([pad_fwd_tgt_enc[:, :-1], pad_bwd_tgt_enc[:, 1:]], dim=2) 90 | else: 91 | pad_vector = tgt_enc.new_zeros((tgt_enc.size(0), 1, tgt_enc.size(2))) 92 | tgt_enc = torch.cat([pad_vector, tgt_enc], dim=1)[:, :-1] 93 | 94 | if self.use_attn: 95 | ctx = self.attn.init(tgt_enc, src_enc, key_mask=src_masks.eq(0), init_scale=init_scale) 96 | ctx = torch.cat([tgt_enc, ctx], dim=2) 97 | else: 98 | ctx = tgt_enc 99 | ctx = F.dropout(self.ctx_proj(ctx), p=self.dropout, training=self.training) 100 | mu = self.mu.init(ctx, init_scale=0.05 * init_scale) if init_mu else self.mu(ctx) 101 | logvar = self.logvar.init(ctx, init_scale=0.05 * init_scale) if init_var else self.logvar(ctx) 102 | mu = mu * tgt_masks.unsqueeze(2) 103 | logvar = logvar * tgt_masks.unsqueeze(2) 104 | return mu, logvar 105 | 106 | 107 | class ShiftRecurrentPosterior(Posterior): 108 | """ 109 | Posterior with Recurrent Neural Networks 110 | """ 111 | def __init__(self, vocab_size, embed_dim, padding_idx, 
rnn_mode, num_layers, latent_dim, hidden_size, 112 | bidirectional=True, use_attn=False, dropout=0.0, dropword=0.0, _shared_embed=None): 113 | super(ShiftRecurrentPosterior, self).__init__(vocab_size, embed_dim, padding_idx, _shared_embed=_shared_embed) 114 | self.core = ShiftRecurrentCore(self.tgt_embed, rnn_mode, num_layers, latent_dim, hidden_size, 115 | bidirectional=bidirectional, use_attn=use_attn, dropout=dropout, dropword=dropword) 116 | 117 | def target_embed_weight(self): 118 | if isinstance(self.core, nn.DataParallel): 119 | return self.core.module.tgt_embedd.weight 120 | else: 121 | return self.core.tgt_embed.weight 122 | 123 | @overrides 124 | def forward(self, tgt_sents, tgt_masks, src_enc, src_masks): 125 | return self.core(tgt_sents, tgt_masks, src_enc, src_masks) 126 | 127 | @overrides 128 | def sample(self, tgt_sents: torch.Tensor, tgt_masks: torch.Tensor, 129 | src_enc: torch.Tensor, src_masks: torch.Tensor, 130 | nsamples: int =1, random=True) -> Tuple[torch.Tensor, torch.Tensor]: 131 | mu, logvar = self.core(tgt_sents, tgt_masks, src_enc, src_masks) 132 | z, eps = Posterior.reparameterize(mu, logvar, tgt_masks, nsamples=nsamples, random=random) 133 | log_probs = Posterior.log_probability(z, eps, mu, logvar, tgt_masks) 134 | return z, log_probs 135 | 136 | @overrides 137 | def init(self, tgt_sents, tgt_masks, src_enc, src_masks, init_scale=1.0, init_mu=True, init_var=True) -> Tuple[torch.Tensor, torch.Tensor]: 138 | mu, logvar = self.core.init(tgt_sents, tgt_masks, src_enc, src_masks, 139 | init_scale=init_scale, init_mu=init_mu, init_var=init_var) 140 | z, eps = Posterior.reparameterize(mu, logvar, tgt_masks, random=True) 141 | log_probs = Posterior.log_probability(z, eps, mu, logvar, tgt_masks) 142 | z = z.squeeze(1) 143 | log_probs = log_probs.squeeze(1) 144 | return z, log_probs 145 | 146 | @classmethod 147 | def from_params(cls, params: Dict) -> "ShiftRecurrentPosterior": 148 | return ShiftRecurrentPosterior(**params) 149 | 150 | 151 | ShiftRecurrentPosterior.register('shift_rnn') 152 | -------------------------------------------------------------------------------- /flownmt/modules/posteriors/transformer.py: -------------------------------------------------------------------------------- 1 | from overrides import overrides 2 | from typing import Tuple, Dict 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from flownmt.nnet.weightnorm import LinearWeightNorm 9 | from flownmt.nnet.transformer import TransformerDecoderLayer 10 | from flownmt.nnet.positional_encoding import PositionalEncoding 11 | from flownmt.modules.posteriors.posterior import Posterior 12 | 13 | 14 | class TransformerCore(nn.Module): 15 | def __init__(self, embed, num_layers, latent_dim, hidden_size, heads, dropout=0.0, dropword=0.0, max_length=100): 16 | super(TransformerCore, self).__init__() 17 | self.tgt_embed = embed 18 | self.padding_idx = embed.padding_idx 19 | 20 | embed_dim = embed.embedding_dim 21 | self.embed_scale = math.sqrt(embed_dim) 22 | assert embed_dim == latent_dim 23 | layers = [TransformerDecoderLayer(latent_dim, hidden_size, heads, dropout=dropout) for _ in range(num_layers)] 24 | self.layers = nn.ModuleList(layers) 25 | self.pos_enc = PositionalEncoding(latent_dim, self.padding_idx, max_length + 1) 26 | self.dropword = dropword # drop entire tokens 27 | self.mu = LinearWeightNorm(latent_dim, latent_dim, bias=True) 28 | self.logvar = LinearWeightNorm(latent_dim, latent_dim, bias=True) 29 | self.reset_parameters() 30 | 31 | 
def reset_parameters(self): 32 | pass 33 | 34 | @overrides 35 | def forward(self, tgt_sents, tgt_masks, src_enc, src_masks): 36 | x = self.embed_scale * self.tgt_embed(tgt_sents) 37 | x = F.dropout2d(x, p=self.dropword, training=self.training) 38 | x += self.pos_enc(tgt_sents) 39 | x = F.dropout(x, p=0.2, training=self.training) 40 | 41 | mask = tgt_masks.eq(0) 42 | key_mask = src_masks.eq(0) 43 | for layer in self.layers: 44 | x = layer(x, mask, src_enc, key_mask) 45 | 46 | mu = self.mu(x) * tgt_masks.unsqueeze(2) 47 | logvar = self.logvar(x) * tgt_masks.unsqueeze(2) 48 | return mu, logvar 49 | 50 | def init(self, tgt_sents, tgt_masks, src_enc, src_masks, init_scale=1.0, init_mu=True, init_var=True): 51 | with torch.no_grad(): 52 | x = self.embed_scale * self.tgt_embed(tgt_sents) 53 | x = F.dropout2d(x, p=self.dropword, training=self.training) 54 | x += self.pos_enc(tgt_sents) 55 | x = F.dropout(x, p=0.2, training=self.training) 56 | 57 | mask = tgt_masks.eq(0) 58 | key_mask = src_masks.eq(0) 59 | for layer in self.layers: 60 | x = layer.init(x, mask, src_enc, key_mask, init_scale=init_scale) 61 | 62 | x = x * tgt_masks.unsqueeze(2) 63 | mu = self.mu.init(x, init_scale=0.05 * init_scale) if init_mu else self.mu(x) 64 | logvar = self.logvar.init(x, init_scale=0.05 * init_scale) if init_var else self.logvar(x) 65 | mu = mu * tgt_masks.unsqueeze(2) 66 | logvar = logvar * tgt_masks.unsqueeze(2) 67 | return mu, logvar 68 | 69 | 70 | class TransformerPosterior(Posterior): 71 | """ 72 | Posterior with Transformer 73 | """ 74 | def __init__(self, vocab_size, embed_dim, padding_idx, num_layers, latent_dim, hidden_size, heads, 75 | dropout=0.0, dropword=0.0, max_length=100, _shared_embed=None): 76 | super(TransformerPosterior, self).__init__(vocab_size, embed_dim, padding_idx, _shared_embed=_shared_embed) 77 | self.core = TransformerCore(self.tgt_embed, num_layers, latent_dim, hidden_size, heads, 78 | dropout=dropout, dropword=dropword, max_length=max_length) 79 | 80 | def target_embed_weight(self): 81 | if isinstance(self.core, nn.DataParallel): 82 | return self.core.module.tgt_embedd.weight 83 | else: 84 | return self.core.tgt_embed.weight 85 | 86 | @overrides 87 | def forward(self, tgt_sents, tgt_masks, src_enc, src_masks): 88 | return self.core(tgt_sents, tgt_masks, src_enc, src_masks) 89 | 90 | @overrides 91 | def sample(self, tgt_sents: torch.Tensor, tgt_masks: torch.Tensor, 92 | src_enc: torch.Tensor, src_masks: torch.Tensor, 93 | nsamples: int =1, random=True) -> Tuple[torch.Tensor, torch.Tensor]: 94 | mu, logvar = self.core(tgt_sents, tgt_masks, src_enc, src_masks) 95 | z, eps = Posterior.reparameterize(mu, logvar, tgt_masks, nsamples=nsamples, random=random) 96 | log_probs = Posterior.log_probability(z, eps, mu, logvar, tgt_masks) 97 | return z, log_probs 98 | 99 | @overrides 100 | def init(self, tgt_sents, tgt_masks, src_enc, src_masks, init_scale=1.0, init_mu=True, init_var=True) -> Tuple[torch.Tensor, torch.Tensor]: 101 | mu, logvar = self.core.init(tgt_sents, tgt_masks, src_enc, src_masks, 102 | init_scale=init_scale, init_mu=init_mu, init_var=init_var) 103 | z, eps = Posterior.reparameterize(mu, logvar, tgt_masks, random=True) 104 | log_probs = Posterior.log_probability(z, eps, mu, logvar, tgt_masks) 105 | z = z.squeeze(1) 106 | log_probs = log_probs.squeeze(1) 107 | return z, log_probs 108 | 109 | @classmethod 110 | def from_params(cls, params: Dict) -> "TransformerPosterior": 111 | return TransformerPosterior(**params) 112 | 113 | 114 | TransformerPosterior.register('transformer') 
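The posteriors above all share the reparameterization and masked Gaussian density helpers defined on the Posterior base class. A minimal, self-contained sketch of that computation for a single sample, with dummy shapes (illustrative only; the mask handling is written out explicitly here, whereas in the repo mu and logvar arrive pre-masked):

import math
import torch

batch, length, nz = 2, 4, 8
mu = torch.randn(batch, length, nz)
logvar = torch.randn(batch, length, nz)
mask = torch.tensor([[1., 1., 1., 0.],
                     [1., 1., 1., 1.]])

# reparameterize: z = mu + std * eps, with eps zeroed at padded positions
std = logvar.mul(0.5).exp()
eps = torch.randn_like(mu) * mask.unsqueeze(2)
z = mu + std * eps

# masked diagonal-Gaussian log density, cf. Posterior.log_probability
const = mask.sum(dim=1) * (math.log(math.pi * 2.) * nz)
log_q = ((logvar + eps.pow(2)) * mask.unsqueeze(2)).view(batch, -1).sum(dim=1)
log_q = -0.5 * (log_q + const)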
115 | -------------------------------------------------------------------------------- /flownmt/modules/priors/__init__.py: -------------------------------------------------------------------------------- 1 | from flownmt.modules.priors.prior import Prior 2 | from flownmt.modules.priors.length_predictors import * 3 | -------------------------------------------------------------------------------- /flownmt/modules/priors/length_predictors/__init__.py: -------------------------------------------------------------------------------- 1 | from flownmt.modules.priors.length_predictors.predictor import LengthPredictor 2 | from flownmt.modules.priors.length_predictors.diff_discretized_mix_logistic import DiffDiscreteMixLogisticLengthPredictor 3 | from flownmt.modules.priors.length_predictors.diff_softmax import DiffSoftMaxLengthPredictor 4 | -------------------------------------------------------------------------------- /flownmt/modules/priors/length_predictors/diff_discretized_mix_logistic.py: -------------------------------------------------------------------------------- 1 | from overrides import overrides 2 | from typing import Dict, Tuple 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from flownmt.modules.priors.length_predictors.predictor import LengthPredictor 8 | from flownmt.modules.priors.length_predictors.utils import discretized_mix_logistic_loss, discretized_mix_logistic_topk 9 | 10 | 11 | class DiffDiscreteMixLogisticLengthPredictor(LengthPredictor): 12 | def __init__(self, features, max_src_length, diff_range, nmix=1, dropout=0.0): 13 | super(DiffDiscreteMixLogisticLengthPredictor, self).__init__() 14 | self.max_src_length = max_src_length 15 | self.range = diff_range 16 | self.nmix = nmix 17 | self.features = features 18 | self.dropout = dropout 19 | self.ctx_proj = None 20 | self.diff = None 21 | 22 | def set_length_unit(self, length_unit): 23 | self.length_unit = length_unit 24 | self.ctx_proj = nn.Sequential(nn.Linear(self.features, self.features), nn.ELU()) 25 | self.diff = nn.Linear(self.features, 3 * self.nmix) 26 | self.reset_parameters() 27 | 28 | def reset_parameters(self): 29 | nn.init.constant_(self.ctx_proj[0].bias, 0.) 30 | nn.init.uniform_(self.diff.weight, -0.1, 0.1) 31 | nn.init.constant_(self.diff.bias, 0.) 
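    # Note on predict() below: the predicted length difference is added to the
    # source length and then rounded up to the next multiple of
    # self.length_unit (set by the prior from its number of flow levels), so
    # the target length stays compatible with the flow's length unit. For
    # example, with an illustrative length_unit of 4, src_length 20 and
    # predicted diff +3:
    #   length = 20 + 3 = 23; res = 23 % 4 = 3; padding = (4 - 3) % 4 = 1
    #   -> final length 24.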
32 | 33 | def forward(self, ctx): 34 | ctx = F.dropout(self.ctx_proj(ctx), p=self.dropout, training=self.training) 35 | # [batch, 3 * nmix] 36 | coeffs = self.diff(ctx) 37 | # [batch, nmix] 38 | logit_probs = F.log_softmax(coeffs[:, :self.nmix], dim=1) 39 | mu = coeffs[:, self.nmix:self.nmix * 2] 40 | log_scale = coeffs[:, self.nmix * 2:] 41 | return mu, log_scale, logit_probs 42 | 43 | @overrides 44 | def loss(self, ctx: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor) -> torch.Tensor: 45 | """ 46 | Args: 47 | ctx: Tensor 48 | tensor [batch, features] 49 | src_mask: Tensor 50 | tensor for source mask [batch, src_length] 51 | tgt_mask: Tensor 52 | tensor for target mask [batch, tgt_length] 53 | Returns: Tensor 54 | tensor for loss [batch] 55 | """ 56 | src_lengths = src_mask.sum(dim=1).float() 57 | tgt_lengths = tgt_mask.sum(dim=1).float() 58 | mu, log_scale, logit_probs = self(ctx, src_lengths.long()) 59 | x = (tgt_lengths - src_lengths).div(self.range).clamp(min=-1, max=1) 60 | bin_size = 0.5 / self.range 61 | lower = bin_size - 1.0 62 | upper = 1.0 - bin_size 63 | loss = discretized_mix_logistic_loss(x, mu, log_scale, logit_probs, bin_size, lower, upper) 64 | return loss 65 | 66 | @overrides 67 | def predict(self, ctx: torch.Tensor, src_mask:torch.Tensor, topk: int = 1) -> Tuple[torch.Tensor, torch.LongTensor]: 68 | """ 69 | Args: 70 | ctx: Tensor 71 | tensor [batch, features] 72 | src_mask: Tensor 73 | tensor for source mask [batch, src_length] 74 | topk: int (default 1) 75 | return top k length candidates for each src sentence 76 | Returns: Tensor1, LongTensor2 77 | Tensor1: log probs for each length 78 | LongTensor2: tensor for lengths [batch, topk] 79 | """ 80 | bin_size = 0.5 / self.range 81 | lower = bin_size - 1.0 82 | upper = 1.0 - bin_size 83 | # [batch] 84 | src_lengths = src_mask.sum(dim=1).long() 85 | mu, log_scale, logit_probs = self(ctx, src_lengths) 86 | # [batch, topk] 87 | log_probs, diffs = discretized_mix_logistic_topk(mu, log_scale, logit_probs, 88 | self.range, bin_size, lower, upper, topk=topk) 89 | lengths = (diffs + src_lengths.unsqueeze(1)).clamp(min=self.length_unit) 90 | res = lengths.fmod(self.length_unit) 91 | padding = (self.length_unit - res).fmod(self.length_unit) 92 | lengths = lengths + padding 93 | return log_probs, lengths 94 | 95 | @classmethod 96 | def from_params(cls, params: Dict) -> 'DiffDiscreteMixLogisticLengthPredictor': 97 | return DiffDiscreteMixLogisticLengthPredictor(**params) 98 | 99 | 100 | DiffDiscreteMixLogisticLengthPredictor.register('diff_logistic') 101 | -------------------------------------------------------------------------------- /flownmt/modules/priors/length_predictors/diff_softmax.py: -------------------------------------------------------------------------------- 1 | from overrides import overrides 2 | from typing import Dict, Tuple 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from flownmt.modules.priors.length_predictors.predictor import LengthPredictor 8 | from flownmt.nnet.criterion import LabelSmoothedCrossEntropyLoss 9 | 10 | 11 | class DiffSoftMaxLengthPredictor(LengthPredictor): 12 | def __init__(self, features, max_src_length, diff_range, dropout=0.0, label_smoothing=0.): 13 | super(DiffSoftMaxLengthPredictor, self).__init__() 14 | self.max_src_length = max_src_length 15 | self.range = diff_range 16 | self.features = features 17 | self.dropout = dropout 18 | self.ctx_proj = None 19 | self.diff = None 20 | if label_smoothing < 1e-5: 21 | self.criterion = 
nn.CrossEntropyLoss(reduction='none') 22 | elif 1e-5 < label_smoothing < 1.0: 23 | self.criterion = LabelSmoothedCrossEntropyLoss(label_smoothing) 24 | else: 25 | raise ValueError('label smoothing should be less than 1.0.') 26 | 27 | def set_length_unit(self, length_unit): 28 | self.length_unit = length_unit 29 | self.ctx_proj = nn.Sequential(nn.Linear(self.features, self.features), nn.ELU(), 30 | nn.Linear(self.features, self.features), nn.ELU()) 31 | self.diff = nn.Linear(self.features, 2 * self.range + 1) 32 | self.reset_parameters() 33 | 34 | def reset_parameters(self): 35 | nn.init.constant_(self.ctx_proj[0].bias, 0.) 36 | nn.init.constant_(self.ctx_proj[2].bias, 0.) 37 | nn.init.uniform_(self.diff.weight, -0.1, 0.1) 38 | nn.init.constant_(self.diff.bias, 0.) 39 | 40 | def forward(self, ctx): 41 | ctx = F.dropout(self.ctx_proj(ctx), p=self.dropout, training=self.training) 42 | return self.diff(ctx) 43 | 44 | @overrides 45 | def loss(self, ctx: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor) -> torch.Tensor: 46 | """ 47 | Args: 48 | ctx: Tensor 49 | tensor [batch, features] 50 | src_mask: Tensor 51 | tensor for source mask [batch, src_length] 52 | tgt_mask: Tensor 53 | tensor for target mask [batch, tgt_length] 54 | Returns: Tensor 55 | tensor for loss [batch] 56 | """ 57 | # [batch] 58 | src_lengths = src_mask.sum(dim=1).long() 59 | tgt_lengths = tgt_mask.sum(dim=1).long() 60 | # [batch, 2 * range + 1] 61 | logits = self(ctx) 62 | # [1, 2 * range + 1] 63 | mask = torch.arange(0, logits.size(1), device=logits.device).unsqueeze(0) 64 | # [batch, 2 * range + 1] 65 | mask = (mask + src_lengths.unsqueeze(1) - self.range).fmod(self.length_unit).ne(0) 66 | logits = logits.masked_fill(mask, float('-inf')) 67 | 68 | # handle tgt < src - range 69 | x = (tgt_lengths - src_lengths + self.range).clamp(min=0) 70 | tgt = x + src_lengths - self.range 71 | res = tgt.fmod(self.length_unit) 72 | padding = (self.length_unit - res).fmod(self.length_unit) 73 | tgt = tgt + padding 74 | # handle tgt > src + range 75 | x = (tgt - src_lengths + self.range).clamp(max=2 * self.range) 76 | tgt = x + src_lengths - self.range 77 | tgt = tgt - tgt.fmod(self.length_unit) 78 | 79 | x = tgt - src_lengths + self.range 80 | loss_length = self.criterion(logits, x) 81 | return loss_length 82 | 83 | @overrides 84 | def predict(self, ctx: torch.Tensor, src_mask:torch.Tensor, topk: int = 1) -> Tuple[torch.LongTensor, torch.Tensor]: 85 | """ 86 | Args: 87 | ctx: Tensor 88 | tensor [batch, features] 89 | src_mask: Tensor 90 | tensor for source mask [batch, src_length] 91 | topk: int (default 1) 92 | return top k length candidates for each src sentence 93 | Returns: LongTensor1, Tensor2 94 | LongTensor1: tensor for lengths [batch, topk] 95 | Tensor2: log probs for each length 96 | """ 97 | # [batch] 98 | src_lengths = src_mask.sum(dim=1).long() 99 | # [batch, 2 * range + 1] 100 | logits = self(ctx) 101 | # [1, 2 * range + 1] 102 | x = torch.arange(0, logits.size(1), device=logits.device).unsqueeze(0) 103 | # [batch, 2 * range + 1] 104 | tgt = x + src_lengths.unsqueeze(1) - self.range 105 | mask = tgt.fmod(self.length_unit).ne(0) 106 | logits = logits.masked_fill(mask, float('-inf')) 107 | # [batch, 2 * range + 1] 108 | log_probs = F.log_softmax(logits, dim=1) 109 | # handle tgt length <= 0 110 | mask = tgt.le(0) 111 | log_probs = log_probs.masked_fill(mask, float('-inf')) 112 | # [batch, topk] 113 | log_probs, x = log_probs.topk(topk, dim=1) 114 | lengths = x + src_lengths.unsqueeze(1) - self.range 115 | return 
lengths, log_probs 116 | 117 | @classmethod 118 | def from_params(cls, params: Dict) -> 'DiffSoftMaxLengthPredictor': 119 | return DiffSoftMaxLengthPredictor(**params) 120 | 121 | 122 | DiffSoftMaxLengthPredictor.register('diff_softmax') 123 | -------------------------------------------------------------------------------- /flownmt/modules/priors/length_predictors/predictor.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Tuple 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class LengthPredictor(nn.Module): 7 | """ 8 | Length Predictor 9 | """ 10 | _registry = dict() 11 | 12 | def __init__(self): 13 | super(LengthPredictor, self).__init__() 14 | self.length_unit = None 15 | 16 | def set_length_unit(self, length_unit): 17 | self.length_unit = length_unit 18 | 19 | def loss(self, ctx: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor) -> torch.Tensor: 20 | """ 21 | 22 | Args: 23 | ctx: Tensor 24 | tensor [batch, features] 25 | src_mask: Tensor 26 | tensor for source mask [batch, src_length] 27 | tgt_mask: Tensor 28 | tensor for target mask [batch, tgt_length] 29 | 30 | Returns: Tensor 31 | tensor for loss [batch] 32 | 33 | """ 34 | raise NotImplementedError 35 | 36 | def predict(self, ctx: torch.Tensor, src_mask:torch.Tensor, topk: int = 1) -> Tuple[torch.LongTensor, torch.Tensor]: 37 | """ 38 | Args: 39 | ctx: Tensor 40 | tensor [batch, features] 41 | src_mask: Tensor 42 | tensor for source mask [batch, src_length] 43 | topk: int (default 1) 44 | return top k length candidates for each src sentence 45 | Returns: LongTensor1, Tensor2 46 | LongTensor1: tensor for lengths [batch, topk] 47 | Tensor2: log probs for each length 48 | """ 49 | raise NotImplementedError 50 | 51 | @classmethod 52 | def register(cls, name: str): 53 | LengthPredictor._registry[name] = cls 54 | 55 | @classmethod 56 | def by_name(cls, name: str): 57 | return LengthPredictor._registry[name] 58 | 59 | @classmethod 60 | def from_params(cls, params: Dict): 61 | raise NotImplementedError 62 | -------------------------------------------------------------------------------- /flownmt/modules/priors/length_predictors/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | def discretized_mix_logistic_loss(x, means, logscales, logit_probs, 8 | bin_size, lower, upper) -> torch.Tensor: 9 | """ 10 | loss for discretized mixture logistic distribution 11 | Args: 12 | x: [batch, ] 13 | means: [batch, nmix] 14 | logscales: [batch, nmix] 15 | logit_probs:, [batch, nmix] 16 | bin_size: float 17 | The segment for cdf is [x-binsize, x+binsize] 18 | lower: float 19 | upper: float 20 | Returns: 21 | loss [batch] 22 | """ 23 | eps = 1e-12 24 | # [batch, 1] 25 | x = x.unsqueeze(1) 26 | # [batch, nmix] 27 | centered_x = x - means 28 | if isinstance(logscales, float): 29 | inv_stdv = np.exp(-logscales) 30 | else: 31 | inv_stdv = torch.exp(-logscales) 32 | 33 | # [batch, nmix] 34 | min_in = inv_stdv * (centered_x - bin_size) 35 | plus_in = inv_stdv * (centered_x + bin_size) 36 | x_in = inv_stdv * centered_x 37 | 38 | # [batch, nmix] 39 | cdf_min = torch.sigmoid(min_in) 40 | cdf_plus = torch.sigmoid(plus_in) 41 | # lower < x < upper 42 | cdf_delta = cdf_plus - cdf_min 43 | log_cdf_mid = torch.log(cdf_delta + eps) 44 | log_cdf_approx = x_in - logscales - 2. 
* F.softplus(x_in) + np.log(2 * bin_size) 45 | 46 | # x < lower 47 | log_cdf_low = plus_in - F.softplus(plus_in) 48 | 49 | # x > upper 50 | log_cdf_up = -F.softplus(min_in) 51 | 52 | # [batch, nmix] 53 | log_cdf = torch.where(cdf_delta.gt(1e-5), log_cdf_mid, log_cdf_approx) 54 | log_cdf = torch.where(x.ge(lower), log_cdf, log_cdf_low) 55 | log_cdf = torch.where(x.le(upper), log_cdf, log_cdf_up) 56 | 57 | # [batch] 58 | loss = torch.logsumexp(log_cdf + logit_probs, dim=1) * -1. 59 | return loss 60 | 61 | 62 | def discretized_mix_logistic_topk(means, logscales, logit_probs, 63 | range, bin_size, lower, upper, topk=1) -> Tuple[torch.Tensor, torch.LongTensor]: 64 | """ 65 | topk for discretized mixture logistic distribution 66 | Args: 67 | means: [batch, nmix] 68 | logscales: [batch, nmix] 69 | logit_probs:, [batch, nmix] 70 | range: int 71 | range of x 72 | bin_size: float 73 | The segment for cdf is [x-binsize, x+binsize] 74 | lower: float 75 | upper: float 76 | topk: int 77 | Returns: Tensor1, Tensor2 78 | Tensor1: log probs [batch, topk] 79 | Tensor2: indexes for top k [batch, topk] 80 | 81 | """ 82 | eps = 1e-12 83 | # [batch, 1, nmix] 84 | means = means.unsqueeze(1) 85 | logscales = logscales.unsqueeze(1) 86 | logit_probs = logit_probs.unsqueeze(1) 87 | # [1, 2 * range + 1, 1] 88 | x = torch.arange(-range, range + 1, 1., device=means.device).unsqueeze(0).unsqueeze(2) 89 | x = x.div(range) 90 | # [batch, 2 * range + 1, nmix] 91 | centered_x = x - means 92 | if isinstance(logscales, float): 93 | inv_stdv = np.exp(-logscales) 94 | else: 95 | inv_stdv = torch.exp(-logscales) 96 | 97 | # [batch, 2 * range + 1, nmix] 98 | min_in = inv_stdv * (centered_x - bin_size) 99 | plus_in = inv_stdv * (centered_x + bin_size) 100 | x_in = inv_stdv * centered_x 101 | 102 | # [batch, 2 * range + 1, nmix] 103 | cdf_min = torch.sigmoid(min_in) 104 | cdf_plus = torch.sigmoid(plus_in) 105 | # lower < x < upper 106 | cdf_delta = cdf_plus - cdf_min 107 | log_cdf_mid = torch.log(cdf_delta + eps) 108 | log_cdf_approx = x_in - logscales - 2. 
* F.softplus(x_in) + np.log(2 * bin_size) 109 | 110 | # x < lower 111 | log_cdf_low = plus_in - F.softplus(plus_in) 112 | 113 | # x > upper 114 | log_cdf_up = -F.softplus(min_in) 115 | 116 | # [batch, 2 * range + 1, nmix] 117 | log_cdf = torch.where(cdf_delta.gt(1e-5), log_cdf_mid, log_cdf_approx) 118 | log_cdf = torch.where(x.ge(lower), log_cdf, log_cdf_low) 119 | log_cdf = torch.where(x.le(upper), log_cdf, log_cdf_up) 120 | # [batch, 2 * range + 1] 121 | log_probs = torch.logsumexp(log_cdf + logit_probs, dim=2) 122 | log_probs, idx = log_probs.topk(topk, dim=1) 123 | 124 | return log_probs, idx - range 125 | -------------------------------------------------------------------------------- /flownmt/modules/priors/prior.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Dict, Tuple, Union 3 | import torch 4 | import torch.nn as nn 5 | 6 | from flownmt.flows.nmt import NMTFlow 7 | from flownmt.modules.priors.length_predictors import LengthPredictor 8 | 9 | 10 | class Prior(nn.Module): 11 | """ 12 | class for Prior with a NMTFlow inside 13 | """ 14 | _registry = dict() 15 | 16 | def __init__(self, flow: NMTFlow, length_predictor: LengthPredictor): 17 | super(Prior, self).__init__() 18 | assert flow.inverse, 'prior flow should have inverse mode' 19 | self.flow = flow 20 | self.length_unit = max(2, 2 ** (self.flow.levels - 1)) 21 | self.features = self.flow.features 22 | self._length_predictor = length_predictor 23 | self._length_predictor.set_length_unit(self.length_unit) 24 | 25 | def sync(self): 26 | self.flow.sync() 27 | 28 | def predict_length(self, ctx: torch.Tensor, src_mask: torch.Tensor, topk: int = 1) -> Tuple[torch.LongTensor, torch.Tensor]: 29 | """ 30 | Args: 31 | ctx: Tensor 32 | tensor [batch, features] 33 | src_mask: Tensor 34 | tensor for source mask [batch, src_length] 35 | topk: int (default 1) 36 | return top k length candidates for each src sentence 37 | Returns: LongTensor1, Tensor2 38 | LongTensor1: tensor for lengths [batch, topk] 39 | Tensor2: log probs for each length [batch, topk] 40 | """ 41 | return self._length_predictor.predict(ctx, src_mask, topk=topk) 42 | 43 | def length_loss(self, ctx: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor) -> torch.Tensor: 44 | """ 45 | 46 | Args: 47 | ctx: Tensor 48 | tensor [batch, features] 49 | src_mask: Tensor 50 | tensor for source mask [batch, src_length] 51 | tgt_mask: Tensor 52 | tensor for target mask [batch, tgt_length] 53 | 54 | Returns: Tensor 55 | tensor for loss [batch] 56 | 57 | """ 58 | return self._length_predictor.loss(ctx, src_mask, tgt_mask) 59 | 60 | def decode(self, epsilon: torch.Tensor, tgt_mask: torch.Tensor, 61 | src: torch.Tensor, src_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 62 | """ 63 | 64 | Args: 65 | epsilon: Tensor 66 | epslion [batch, tgt_length, nz] 67 | tgt_mask: Tensor 68 | tensor of target masks [batch, tgt_length] 69 | src: Tensor 70 | source encoding [batch, src_length, hidden_size] 71 | src_mask: Tensor 72 | tensor of source masks [batch, src_length] 73 | 74 | Returns: Tensor1, Tensor2 75 | Tensor1: decoded latent code z [batch, tgt_length, nz] 76 | Tensor2: log probabilities [batch] 77 | 78 | """ 79 | # [batch, tgt_length, nz] 80 | z, logdet = self.flow.fwdpass(epsilon, tgt_mask, src, src_mask) 81 | # [batch, tgt_length, nz] 82 | log_probs = epsilon.mul(epsilon) + math.log(math.pi * 2.0) 83 | # apply mask 84 | log_probs = log_probs.mul(tgt_mask.unsqueeze(2)) 85 | # [batch] 86 | log_probs = 
log_probs.view(z.size(0), -1).sum(dim=1).mul(-0.5) + logdet 87 | return z, log_probs 88 | 89 | def sample(self, nlengths: int, nsamples: int, src: torch.Tensor, 90 | ctx: torch.Tensor, src_mask: torch.Tensor, 91 | tau=0.0, include_zero=False) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: 92 | """ 93 | 94 | Args: 95 | nlengths: int 96 | number of lengths per sentence 97 | nsamples: int 98 | number of samples per sentence per length 99 | src: Tensor 100 | source encoding [batch, src_length, hidden_size] 101 | ctx: Tensor 102 | tensor for global state [batch, hidden_size] 103 | src_mask: Tensor 104 | tensor of masks [batch, src_length] 105 | tau: float (default 0.0) 106 | temperature of density 107 | include_zero: bool (default False) 108 | include zero sample 109 | 110 | Returns: (Tensor1, Tensor2, Tensor3), (Tensor4, Tensor5), (Tensor6, Tensor7, Tensor8) 111 | Tensor1: samples from the prior [batch * nlengths * nsamples, tgt_length, nz] 112 | Tensor2: log probabilities [batch * nlengths * nsamples] 113 | Tensor3: target masks [batch * nlengths * nsamples, tgt_length] 114 | Tensor4: lengths [batch * nlengths] 115 | Tensor5: log probabilities of lengths [batch * nlengths] 116 | Tensor6: source encoding with shape [batch * nlengths * nsamples, src_length, hidden_size] 117 | Tensor7: tensor for global state [batch * nlengths * nsamples, hidden_size] 118 | Tensor8: source masks with shape [batch * nlengths * nsamples, src_length] 119 | 120 | """ 121 | batch = src.size(0) 122 | batch_nlen = batch * nlengths 123 | # [batch, nlenths] 124 | lengths, log_probs_length = self.predict_length(ctx, src_mask, topk=nlengths) 125 | # [batch * nlengths] 126 | log_probs_length = log_probs_length.view(-1) 127 | lengths = lengths.view(-1) 128 | max_length = lengths.max().item() 129 | # [batch * nlengths, max_length] 130 | tgt_mask = torch.arange(max_length).to(src.device).unsqueeze(0).expand(batch_nlen, max_length).lt(lengths.unsqueeze(1)).float() 131 | 132 | # [batch * nlengths, nsamples, tgt_length, nz] 133 | epsilon = src.new_empty(batch_nlen, nsamples, max_length, self.features).normal_() 134 | epsilon = epsilon.mul(tgt_mask.view(batch_nlen, 1, max_length, 1)) * tau 135 | if include_zero: 136 | epsilon[:, 0].zero_() 137 | # [batch * nlengths * nsamples, tgt_length, nz] 138 | epsilon = epsilon.view(-1, max_length, self.features) 139 | if nsamples * nlengths > 1: 140 | # [batch, nlengths * nsamples, src_length, hidden_size] 141 | src = src.unsqueeze(1) + src.new_zeros(batch, nlengths * nsamples, *src.size()[1:]) 142 | # [batch * nlengths * nsamples, src_length, hidden_size] 143 | src = src.view(batch_nlen * nsamples, *src.size()[2:]) 144 | # [batch, nlengths * nsamples, hidden_size] 145 | ctx = ctx.unsqueeze(1) + ctx.new_zeros(batch, nlengths * nsamples, ctx.size(1)) 146 | # [batch * nlengths * nsamples, hidden_size] 147 | ctx = ctx.view(batch_nlen * nsamples, ctx.size(2)) 148 | # [batch, nlengths * nsamples, src_length] 149 | src_mask = src_mask.unsqueeze(1) + src_mask.new_zeros(batch, nlengths * nsamples, src_mask.size(1)) 150 | # [batch * nlengths * nsamples, src_length] 151 | src_mask = src_mask.view(batch_nlen * nsamples, src_mask.size(2)) 152 | # [batch * nlengths, nsamples, tgt_length] 153 | tgt_mask = tgt_mask.unsqueeze(1) + tgt_mask.new_zeros(batch_nlen, nsamples, tgt_mask.size(1)) 154 | # [batch * nlengths * nsamples, tgt_length] 155 | tgt_mask = tgt_mask.view(batch_nlen * nsamples, tgt_mask.size(2)) 156 
| 157 | # [batch * nlength * nsamples, tgt_length, nz] 158 | z, log_probs = self.decode(epsilon, tgt_mask, src, src_mask) 159 | return (z, log_probs, tgt_mask), (lengths, log_probs_length), (src, ctx, src_mask) 160 | 161 | def log_probability(self, z: torch.Tensor, tgt_mask: torch.Tensor, 162 | src: torch.Tensor, ctx: torch.Tensor, src_mask: torch.Tensor, 163 | length_loss: bool = True) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]: 164 | """ 165 | 166 | Args: 167 | z: Tensor 168 | tensor of latent code [batch, length, nz] 169 | tgt_mask: Tensor 170 | tensor of target masks [batch, length] 171 | src: Tensor 172 | source encoding [batch, src_length, hidden_size] 173 | ctx: Tensor 174 | tensor for global state [batch, hidden_size] 175 | src_mask: Tensor 176 | tensor of source masks [batch, src_length] 177 | length_loss: bool (default True) 178 | compute loss of length 179 | 180 | Returns: Tensor1, Tensor2 181 | Tensor1: log probabilities of z [batch] 182 | Tensor2: length loss [batch] 183 | 184 | """ 185 | # [batch] 186 | loss_length = self.length_loss(ctx, src_mask, tgt_mask) if length_loss else None 187 | 188 | # [batch, length, nz] 189 | epsilon, logdet = self.flow.bwdpass(z, tgt_mask, src, src_mask) 190 | # [batch, tgt_length, nz] 191 | log_probs = epsilon.mul(epsilon) + math.log(math.pi * 2.0) 192 | # apply mask 193 | log_probs = log_probs.mul(tgt_mask.unsqueeze(2)) 194 | log_probs = log_probs.view(z.size(0), -1).sum(dim=1).mul(-0.5) + logdet 195 | return log_probs, loss_length 196 | 197 | def init(self, z, tgt_mask, src, src_mask, init_scale=1.0): 198 | return self.flow.bwdpass(z, tgt_mask, src, src_mask, init=True, init_scale=init_scale) 199 | 200 | @classmethod 201 | def register(cls, name: str): 202 | Prior._registry[name] = cls 203 | 204 | @classmethod 205 | def by_name(cls, name: str): 206 | return Prior._registry[name] 207 | 208 | @classmethod 209 | def from_params(cls, params: Dict) -> "Prior": 210 | flow_params = params.pop('flow') 211 | flow = NMTFlow.from_params(flow_params) 212 | predictor_params = params.pop('length_predictor') 213 | length_predictor = LengthPredictor.by_name(predictor_params.pop('type')).from_params(predictor_params) 214 | return Prior(flow, length_predictor) 215 | 216 | 217 | Prior.register('normal') 218 | -------------------------------------------------------------------------------- /flownmt/nnet/__init__.py: -------------------------------------------------------------------------------- 1 | from flownmt.nnet.weightnorm import LinearWeightNorm, Conv1dWeightNorm 2 | from flownmt.nnet.attention import GlobalAttention, MultiHeadAttention, PositionwiseFeedForward 3 | from flownmt.nnet.transformer import TransformerEncoderLayer, TransformerDecoderLayer 4 | from flownmt.nnet.layer_norm import LayerNorm 5 | from flownmt.nnet.positional_encoding import PositionalEncoding 6 | from flownmt.nnet.criterion import LabelSmoothedCrossEntropyLoss 7 | -------------------------------------------------------------------------------- /flownmt/nnet/attention.py: -------------------------------------------------------------------------------- 1 | from overrides import overrides 2 | import torch 3 | from torch.nn import Parameter 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from flownmt.nnet.layer_norm import LayerNorm 8 | 9 | 10 | class GlobalAttention(nn.Module): 11 | """ 12 | Global Attention between encoder and decoder 13 | """ 14 | 15 | def __init__(self, key_features, query_features, value_features, hidden_features=None, dropout=0.0): 16 | 
""" 17 | 18 | Args: 19 | key_features: int 20 | dimension of keys 21 | query_features: int 22 | dimension of queries 23 | value_features: int 24 | dimension of values (outputs) 25 | hidden_features: int 26 | dimension of hidden states (default value_features) 27 | dropout: float 28 | dropout rate 29 | """ 30 | super(GlobalAttention, self).__init__() 31 | if hidden_features is None: 32 | hidden_features = value_features 33 | self.key_proj = nn.Linear(key_features, 2 * hidden_features, bias=True) 34 | self.query_proj = nn.Linear(query_features, hidden_features, bias=True) 35 | self.dropout = dropout 36 | self.fc = nn.Linear(hidden_features, value_features) 37 | self.hidden_features = hidden_features 38 | self.reset_parameters() 39 | 40 | def reset_parameters(self): 41 | # key proj 42 | nn.init.xavier_uniform_(self.key_proj.weight) 43 | nn.init.constant_(self.key_proj.bias, 0) 44 | # query proj 45 | nn.init.xavier_uniform_(self.query_proj.weight) 46 | nn.init.constant_(self.query_proj.bias, 0) 47 | # fc 48 | nn.init.xavier_uniform_(self.fc.weight) 49 | nn.init.constant_(self.fc.bias, 0) 50 | 51 | @overrides 52 | def forward(self, query, key, key_mask=None): 53 | """ 54 | 55 | Args: 56 | query: Tensor 57 | query tensor [batch, query_length, query_features] 58 | key: Tensor 59 | key tensor [batch, key_length, key_features] 60 | key_mask: ByteTensor or None 61 | binary ByteTensor [batch, src_len] padding elements are indicated by 1s. 62 | 63 | Returns: Tensor 64 | value tensor [batch, query_length, value_features] 65 | 66 | """ 67 | bs, timesteps, _ = key.size() 68 | dim = self.hidden_features 69 | # [batch, query_length, dim] 70 | query = self.query_proj(query) 71 | 72 | # [batch, key_length, 2 * dim] 73 | c = self.key_proj(key) 74 | # [batch, key_length, 2, dim] 75 | c = c.view(bs, timesteps, 2, dim) 76 | # [batch, key_length, dim] 77 | key = c[:, :, 0] 78 | value = c[:, :, 1] 79 | 80 | # attention weights [batch, query_length, key_length] 81 | attn_weights = torch.bmm(query, key.transpose(1, 2)) 82 | if key_mask is not None: 83 | attn_weights = attn_weights.masked_fill(key_mask.unsqueeze(1), float('-inf')) 84 | 85 | attn_weights = F.softmax(attn_weights.float(), dim=-1, 86 | dtype=torch.float32 if attn_weights.dtype == torch.float16 else attn_weights.dtype) 87 | 88 | # values [batch, query_length, dim] 89 | out = torch.bmm(attn_weights, value) 90 | out = F.dropout(self.fc(out), p=self.dropout, training=self.training) 91 | return out 92 | 93 | def init(self, query, key, key_mask=None, init_scale=1.0): 94 | with torch.no_grad(): 95 | return self(query, key, key_mask=key_mask) 96 | 97 | 98 | class MultiHeadAttention(nn.Module): 99 | """ 100 | Multi-head Attention 101 | """ 102 | def __init__(self, model_dim, heads, dropout=0.0, mask_diag=False): 103 | """ 104 | 105 | Args: 106 | model_dim: int 107 | the input dimension for keys, queries and values 108 | heads: int 109 | number of heads 110 | dropout: float 111 | dropout rate 112 | """ 113 | super(MultiHeadAttention, self).__init__() 114 | self.model_dim = model_dim 115 | self.head_dim = model_dim // heads 116 | self.heads = heads 117 | self.dropout = dropout 118 | self.mask_diag = mask_diag 119 | assert self.head_dim * heads == self.model_dim, "model_dim must be divisible by number of heads" 120 | self.scaling = self.head_dim ** -0.5 121 | self.in_proj_weight = Parameter(torch.empty(3 * model_dim, model_dim)) 122 | self.in_proj_bias = Parameter(torch.empty(3 * model_dim)) 123 | self.layer_norm = LayerNorm(model_dim) 124 | self.reset_parameters() 
125 | 126 | def reset_parameters(self): 127 | # in proj 128 | nn.init.xavier_uniform_(self.in_proj_weight[:self.model_dim, :]) 129 | nn.init.xavier_uniform_(self.in_proj_weight[self.model_dim:(self.model_dim * 2), :]) 130 | nn.init.xavier_uniform_(self.in_proj_weight[(self.model_dim * 2):, :]) 131 | nn.init.constant_(self.in_proj_bias, 0.) 132 | 133 | def forward(self, query, key, value, key_mask=None): 134 | """ 135 | 136 | Args: 137 | query: Tenfor 138 | [batch, tgt_len, model_dim] 139 | key: Tensor 140 | [batch, src_len, model_dim] 141 | value: Tensor 142 | [batch, src_len, model_dim] 143 | key_mask: ByteTensor or None 144 | binary ByteTensor [batch, src_len] padding elements are indicated by 1s. 145 | 146 | Returns: 147 | 148 | """ 149 | qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr() 150 | kv_same = key.data_ptr() == value.data_ptr() 151 | 152 | bs, src_len, model_dim = key.size() 153 | tgt_len = query.size(1) 154 | heads = self.heads 155 | residual = query 156 | 157 | # k, v: [bs, src_len, model_dim] 158 | # q: [bs, tgt_len, model_dim] 159 | if qkv_same: 160 | # self-attention 161 | q, k, v = self._in_proj_qkv(query) 162 | elif kv_same: 163 | # encoder-decoder attention 164 | q = self._in_proj_q(query) 165 | k, v = self._in_proj_kv(key) 166 | else: 167 | q = self._in_proj_q(query) 168 | k = self._in_proj_k(key) 169 | v = self._in_proj_v(value) 170 | q *= self.scaling 171 | 172 | model_dim = q.size(2) 173 | dim = model_dim // heads 174 | 175 | # [len, batch, model_dim] -> [len, batch * heads, dim] -> [batch * heads, len, dim] 176 | q = q.transpose(0, 1).contiguous().view(tgt_len, bs * heads, dim).transpose(0, 1) 177 | k = k.transpose(0, 1).contiguous().view(src_len, bs * heads, dim).transpose(0, 1) 178 | v = v.transpose(0, 1).contiguous().view(src_len, bs * heads, dim).transpose(0, 1) 179 | 180 | # attention weights [batch * heads, tgt_len, src_len] 181 | attn_weights = torch.bmm(q, k.transpose(1, 2)) 182 | if key_mask is not None: 183 | attn_weights = attn_weights.view(bs, heads, tgt_len, src_len) 184 | attn_weights = attn_weights.masked_fill(key_mask.unsqueeze(1).unsqueeze(2), float('-inf')) 185 | attn_weights = attn_weights.view(bs * heads, tgt_len, src_len) 186 | 187 | if self.mask_diag: 188 | assert tgt_len == src_len 189 | # [1, tgt_len, tgt_len] 190 | diag_mask = torch.eye(tgt_len, device=query.device, dtype=torch.uint8).unsqueeze(0) 191 | attn_weights = attn_weights.masked_fill(diag_mask, float('-inf')) 192 | 193 | attn_weights = F.softmax(attn_weights.float(), dim=-1, 194 | dtype=torch.float32 if attn_weights.dtype == torch.float16 else attn_weights.dtype) 195 | 196 | 197 | # outputs [batch * heads, tgt_len, dim] 198 | out = torch.bmm(attn_weights, v) 199 | # merge heads 200 | # [batch, heads, tgt_len, dim] -> [batch, tgt_len, heads, dim] 201 | # -> [batch, tgt_len, model_dim] 202 | out = out.view(bs, heads, tgt_len, dim).transpose(1, 2).contiguous().view(bs, tgt_len, model_dim) 203 | out = F.dropout(out, p=self.dropout, training=self.training) 204 | out = self.layer_norm(out + residual) 205 | return out 206 | 207 | def init(self, query, key, value, key_mask=None, init_scale=1.0): 208 | with torch.no_grad(): 209 | return self(query, key, value, key_mask=key_mask) 210 | 211 | def _in_proj_qkv(self, query): 212 | return self._in_proj(query).chunk(3, dim=-1) 213 | 214 | def _in_proj_kv(self, key): 215 | return self._in_proj(key, start=self.model_dim).chunk(2, dim=-1) 216 | 217 | def _in_proj_q(self, query): 218 | return self._in_proj(query, end=self.model_dim) 
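    # Note on the _in_proj_* helpers: they slice the packed weight by rows, so
    # queries use rows [0, model_dim), keys [model_dim, 2 * model_dim) and
    # values [2 * model_dim, 3 * model_dim). _in_proj_qkv and _in_proj_kv run a
    # single matmul and then chunk the result, which is what the qkv_same
    # (self-attention) and kv_same (encoder-decoder attention) fast paths in
    # forward() rely on.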
219 | 220 | def _in_proj_k(self, key): 221 | return self._in_proj(key, start=self.model_dim, end=2 * self.model_dim) 222 | 223 | def _in_proj_v(self, value): 224 | return self._in_proj(value, start=2 * self.model_dim) 225 | 226 | def _in_proj(self, input, start=0, end=None): 227 | weight = self.in_proj_weight 228 | bias = self.in_proj_bias 229 | weight = weight[start:end, :] 230 | if bias is not None: 231 | bias = bias[start:end] 232 | return F.linear(input, weight, bias) 233 | 234 | 235 | class PositionwiseFeedForward(nn.Module): 236 | def __init__(self, features, hidden_features, dropout=0.0): 237 | super(PositionwiseFeedForward, self).__init__() 238 | self.linear1 = nn.Linear(features, hidden_features) 239 | self.dropout = dropout 240 | self.linear2 = nn.Linear(hidden_features, features) 241 | self.layer_norm = LayerNorm(features) 242 | 243 | def forward(self, x): 244 | residual = x 245 | x = F.relu(self.linear1(x), inplace=True) 246 | x = F.dropout(x, p=self.dropout, training=self.training) 247 | x = F.dropout(self.linear2(x), p=self.dropout, training=self.training) 248 | x = self.layer_norm(residual + x) 249 | return x 250 | 251 | def init(self, x, init_scale=1.0): 252 | with torch.no_grad(): 253 | return self(x) 254 | -------------------------------------------------------------------------------- /flownmt/nnet/criterion.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import torch.nn as nn 3 | 4 | 5 | class LabelSmoothedCrossEntropyLoss(nn.Module): 6 | """ 7 | Cross Entropy loss with label smoothing. 8 | For training, the loss is smoothed with parameter eps, while for evaluation, the smoothing is disabled. 9 | """ 10 | def __init__(self, label_smoothing): 11 | super(LabelSmoothedCrossEntropyLoss, self).__init__() 12 | self.eps = label_smoothing 13 | 14 | def forward(self, input, target): 15 | # [batch, c, d1, ..., dk] 16 | loss = F.log_softmax(input, dim=1) * -1. 17 | # [batch, d1, ..., dk] 18 | nll_loss = loss.gather(dim=1, index=target.unsqueeze(1)).squeeze(1) 19 | if self.training: 20 | # [batch, c, d1, ..., dk] 21 | inf_mask = loss.eq(float('inf')) 22 | # [batch, d1, ..., dk] 23 | smooth_loss = loss.masked_fill(inf_mask, 0.).sum(dim=1) 24 | eps_i = self.eps / (1.0 - inf_mask.float()).sum(dim=1) 25 | return nll_loss * (1. - self.eps) + smooth_loss * eps_i 26 | else: 27 | return nll_loss 28 | -------------------------------------------------------------------------------- /flownmt/nnet/layer_norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False): 6 | if not export and torch.cuda.is_available(): 7 | try: 8 | from apex.normalization import FusedLayerNorm 9 | return FusedLayerNorm(normalized_shape, eps, elementwise_affine) 10 | except ImportError: 11 | pass 12 | return nn.LayerNorm(normalized_shape, eps, elementwise_affine) 13 | -------------------------------------------------------------------------------- /flownmt/nnet/positional_encoding.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | from flownmt.utils import make_positions 6 | 7 | 8 | class PositionalEncoding(nn.Module): 9 | """This module produces sinusoidal positional embeddings of any length. 10 | Padding symbols are ignored. 
11 |     """
12 | 
13 |     def __init__(self, encoding_dim, padding_idx, init_size=1024):
14 |         super().__init__()
15 |         self.encoding_dim = encoding_dim
16 |         self.padding_idx = padding_idx
17 |         self.weights = PositionalEncoding.get_embedding(
18 |             init_size,
19 |             encoding_dim,
20 |             padding_idx,
21 |         )
22 |         self.register_buffer('_float_tensor', torch.FloatTensor(1))
23 | 
24 |     @staticmethod
25 |     def get_embedding(num_encodings, encoding_dim, padding_idx=None):
26 |         """Build sinusoidal embeddings.
27 |         This matches the implementation in tensor2tensor, but differs slightly
28 |         from the description in Section 3.5 of "Attention Is All You Need".
29 |         """
30 |         half_dim = encoding_dim // 2
31 |         emb = math.log(10000) / (half_dim - 1)
32 |         emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
33 |         emb = torch.arange(num_encodings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
34 |         emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_encodings, -1)
35 |         if encoding_dim % 2 == 1:
36 |             # zero pad
37 |             emb = torch.cat([emb, torch.zeros(num_encodings, 1)], dim=1)
38 |         emb[0, :] = 0
39 |         return emb
40 | 
41 |     def forward(self, x):
42 |         """Input is expected to be of size [bsz x seqlen]."""
43 |         bsz, seq_len = x.size()[:2]
44 |         max_pos = seq_len + 1
45 |         if self.weights is None or max_pos > self.weights.size(0):
46 |             # recompute/expand embeddings if needed
47 |             self.weights = PositionalEncoding.get_embedding(
48 |                 max_pos,
49 |                 self.encoding_dim,
50 |                 self.padding_idx,
51 |             )
52 |         self.weights = self.weights.type_as(self._float_tensor)
53 | 
54 |         if self.padding_idx is None:
55 |             return self.weights[1:seq_len + 1].detach()
56 |         else:
57 |             positions = make_positions(x, self.padding_idx)
58 |             return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()
59 | 
60 |     def max_positions(self):
61 |         """Maximum number of supported positions."""
62 |         return int(1e5)  # an arbitrary large number
63 | 
--------------------------------------------------------------------------------
/flownmt/nnet/transformer.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | 
3 | from flownmt.nnet.attention import MultiHeadAttention, PositionwiseFeedForward
4 | 
5 | 
6 | class TransformerEncoderLayer(nn.Module):
7 |     def __init__(self, model_dim, hidden_dim, heads, dropout=0.0, mask_diag=False):
8 |         super(TransformerEncoderLayer, self).__init__()
9 |         self.slf_attn = MultiHeadAttention(model_dim, heads, dropout=dropout, mask_diag=mask_diag)
10 |         self.pos_ffn = PositionwiseFeedForward(model_dim, hidden_dim, dropout=dropout)
11 | 
12 |     def forward(self, x, mask):
13 |         out = self.slf_attn(x, x, x, key_mask=mask)
14 |         out = self.pos_ffn(out)
15 |         return out
16 | 
17 |     def init(self, x, mask, init_scale=1.0):
18 |         out = self.slf_attn.init(x, x, x, key_mask=mask, init_scale=init_scale)
19 |         out = self.pos_ffn.init(out, init_scale=init_scale)
20 |         return out
21 | 
22 | 
23 | class TransformerDecoderLayer(nn.Module):
24 |     def __init__(self, model_dim, hidden_dim, heads, dropout=0.0, mask_diag=False):
25 |         super(TransformerDecoderLayer, self).__init__()
26 |         self.slf_attn = MultiHeadAttention(model_dim, heads, dropout=dropout, mask_diag=mask_diag)
27 |         self.enc_attn = MultiHeadAttention(model_dim, heads, dropout=dropout)
28 |         self.pos_ffn = PositionwiseFeedForward(model_dim, hidden_dim, dropout=dropout)
29 | 
30 |     def forward(self, x, mask, src, src_mask):
31 |         out = self.slf_attn(x, x, x, key_mask=mask)
32 |         out = self.enc_attn(out, src, src, key_mask=src_mask)
33 |         out = self.pos_ffn(out)
34 |         return out
35 | 
36 |     def init(self, x, mask, src, src_mask, init_scale=1.0):
37 |         out = self.slf_attn.init(x, x, x, key_mask=mask, init_scale=init_scale)
38 |         out = self.enc_attn.init(out, src, src, key_mask=src_mask, init_scale=init_scale)
39 |         out = self.pos_ffn.init(out, init_scale=init_scale)
40 |         return out
41 | 
--------------------------------------------------------------------------------
/flownmt/nnet/weightnorm.py:
--------------------------------------------------------------------------------
1 | from overrides import overrides
2 | import torch
3 | import torch.nn as nn
4 | 
5 | 
6 | class LinearWeightNorm(nn.Module):
7 |     """
8 |     Linear with weight normalization
9 |     """
10 |     def __init__(self, in_features, out_features, bias=True):
11 |         super(LinearWeightNorm, self).__init__()
12 |         self.linear = nn.Linear(in_features, out_features, bias=bias)
13 |         self.reset_parameters()
14 | 
15 |     def reset_parameters(self):
16 |         nn.init.normal_(self.linear.weight, mean=0.0, std=0.05)
17 |         if self.linear.bias is not None:
18 |             nn.init.constant_(self.linear.bias, 0)
19 |         self.linear = nn.utils.weight_norm(self.linear)
20 | 
21 |     def extra_repr(self):
22 |         return 'in_features={}, out_features={}, bias={}'.format(
23 |             self.linear.in_features, self.linear.out_features, self.linear.bias is not None
24 |         )
25 | 
26 |     def init(self, x, init_scale=1.0):
27 |         with torch.no_grad():
28 |             # [batch, out_features]
29 |             out = self(x).view(-1, self.linear.out_features)
30 |             # [out_features]
31 |             mean = out.mean(dim=0)
32 |             std = out.std(dim=0)
33 |             inv_stdv = init_scale / (std + 1e-6)
34 | 
35 |             self.linear.weight_g.mul_(inv_stdv.unsqueeze(1))
36 |             if self.linear.bias is not None:
37 |                 self.linear.bias.add_(-mean).mul_(inv_stdv)
38 |             return self(x)
39 | 
40 |     def forward(self, input):
41 |         return self.linear(input)
42 | 
43 | 
44 | class Conv1dWeightNorm(nn.Module):
45 |     """
46 |     Conv1d with weight normalization
47 |     """
48 |     def __init__(self, in_channels, out_channels, kernel_size, stride=1,
49 |                  padding=0, dilation=1, groups=1, bias=True):
50 |         super(Conv1dWeightNorm, self).__init__()
51 |         self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride=stride,
52 |                               padding=padding, dilation=dilation, groups=groups, bias=bias)
53 |         self.reset_parameters()
54 | 
55 |     def reset_parameters(self):
56 |         nn.init.normal_(self.conv.weight, mean=0.0, std=0.05)
57 |         if self.conv.bias is not None:
58 |             nn.init.constant_(self.conv.bias, 0)
59 |         self.conv = nn.utils.weight_norm(self.conv)
60 | 
61 |     def init(self, x, init_scale=1.0):
62 |         with torch.no_grad():
63 |             # [batch, n_channels, L]
64 |             out = self(x)
65 |             n_channels = out.size(1)
66 |             out = out.transpose(0, 1).contiguous().view(n_channels, -1)
67 |             # [n_channels]
68 |             mean = out.mean(dim=1)
69 |             std = out.std(dim=1)
70 |             inv_stdv = init_scale / (std + 1e-6)
71 | 
72 |             self.conv.weight_g.mul_(inv_stdv.view(n_channels, 1, 1))
73 |             if self.conv.bias is not None:
74 |                 self.conv.bias.add_(-mean).mul_(inv_stdv)
75 |             return self(x)
76 | 
77 |     def forward(self, input):
78 |         return self.conv(input)
79 | 
80 |     @overrides
81 |     def extra_repr(self):
82 |         return self.conv.extra_repr()
83 | 
--------------------------------------------------------------------------------
/flownmt/optim/__init__.py:
--------------------------------------------------------------------------------
1 | from flownmt.optim.adamw import AdamW
2 | from flownmt.optim.lr_scheduler import InverseSquareRootScheduler, ExponentialScheduler
3 | 
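For reference, a minimal sketch of how the layers above compose, assuming the package is importable as laid out (flownmt.nnet.transformer, flownmt.nnet.weightnorm); the dimensions and tensors are hypothetical:

import torch
from flownmt.nnet.transformer import TransformerEncoderLayer
from flownmt.nnet.weightnorm import LinearWeightNorm

layer = TransformerEncoderLayer(model_dim=512, hidden_dim=2048, heads=8, dropout=0.1)
x = torch.randn(4, 20, 512)                    # [batch, length, model_dim]
mask = torch.zeros(4, 20, dtype=torch.uint8)   # 1s mark padding positions (see attention.py)
mask[:, 18:] = 1
out = layer(x, mask)                           # [batch, length, model_dim]
# One-time data-dependent initialization pass, mirroring the init() methods above.
_ = layer.init(x, mask, init_scale=1.0)

# Weight-normalized layers rescale weight_g in init() so initial outputs have
# roughly unit variance on the first batch they see.
proj = LinearWeightNorm(512, 256, bias=True)
_ = proj.init(torch.randn(64, 512), init_scale=1.0)
y = proj(torch.randn(4, 512))                  # [4, 256]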
-------------------------------------------------------------------------------- /flownmt/optim/adamw.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.optim.optimizer import Optimizer 4 | 5 | 6 | class AdamW(Optimizer): 7 | r"""Implements AdamW algorithm. 8 | This implementation is modified from torch.optim.Adam based on: 9 | `Fixed Weight Decay Regularization in Adam` 10 | (see https://arxiv.org/abs/1711.05101) 11 | 12 | Adam has been proposed in `Adam: A Method for Stochastic Optimization`_. 13 | 14 | Arguments: 15 | params (iterable): iterable of parameters to optimize or dicts defining 16 | parameter groups 17 | lr (float, optional): learning rate (default: 1e-3) 18 | betas (Tuple[float, float], optional): coefficients used for computing 19 | running averages of gradient and its square (default: (0.9, 0.999)) 20 | eps (float, optional): term added to the denominator to improve 21 | numerical stability (default: 1e-8) 22 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 23 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 24 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 25 | (default: False) 26 | 27 | .. _Adam\: A Method for Stochastic Optimization: 28 | https://arxiv.org/abs/1412.6980 29 | .. _On the Convergence of Adam and Beyond: 30 | https://openreview.net/forum?id=ryQu7f-RZ 31 | """ 32 | 33 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 34 | weight_decay=0, amsgrad=False): 35 | if not 0.0 <= lr: 36 | raise ValueError("Invalid learning rate: {}".format(lr)) 37 | if not 0.0 <= eps: 38 | raise ValueError("Invalid epsilon value: {}".format(eps)) 39 | if not 0.0 <= betas[0] < 1.0: 40 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 41 | if not 0.0 <= betas[1] < 1.0: 42 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 43 | defaults = dict(lr=lr, betas=betas, eps=eps, 44 | weight_decay=weight_decay, amsgrad=amsgrad) 45 | super(AdamW, self).__init__(params, defaults) 46 | 47 | def __setstate__(self, state): 48 | super(AdamW, self).__setstate__(state) 49 | for group in self.param_groups: 50 | group.setdefault('amsgrad', False) 51 | 52 | def step(self, closure=None): 53 | """Performs a single optimization step. 54 | 55 | Arguments: 56 | closure (callable, optional): A closure that reevaluates the model 57 | and returns the loss. 58 | """ 59 | loss = None 60 | if closure is not None: 61 | loss = closure() 62 | 63 | for group in self.param_groups: 64 | for p in group['params']: 65 | if p.grad is None: 66 | continue 67 | grad = p.grad.data 68 | if grad.is_sparse: 69 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 70 | amsgrad = group['amsgrad'] 71 | 72 | state = self.state[p] 73 | 74 | # State initialization 75 | if len(state) == 0: 76 | state['step'] = 0 77 | # Exponential moving average of gradient values 78 | state['exp_avg'] = torch.zeros_like(p.data) 79 | # Exponential moving average of squared gradient values 80 | state['exp_avg_sq'] = torch.zeros_like(p.data) 81 | if amsgrad: 82 | # Maintains max of all exp. moving avg. of sq. grad. 
values 83 | state['max_exp_avg_sq'] = torch.zeros_like(p.data) 84 | 85 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 86 | if amsgrad: 87 | max_exp_avg_sq = state['max_exp_avg_sq'] 88 | beta1, beta2 = group['betas'] 89 | 90 | state['step'] += 1 91 | 92 | # Decay the first and second moment running average coefficient 93 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 94 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 95 | if amsgrad: 96 | # Maintains the maximum of all 2nd moment running avg. till now 97 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 98 | # Use the max. for normalizing running avg. of gradient 99 | denom = max_exp_avg_sq.sqrt().add_(group['eps']) 100 | else: 101 | denom = exp_avg_sq.sqrt().add_(group['eps']) 102 | 103 | bias_correction1 = 1 - beta1 ** state['step'] 104 | bias_correction2 = 1 - beta2 ** state['step'] 105 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 106 | 107 | if group['weight_decay'] != 0: 108 | p.data.add_(-group['weight_decay'] * group['lr'], p.data) 109 | 110 | p.data.addcdiv_(-step_size, exp_avg, denom) 111 | 112 | return loss 113 | -------------------------------------------------------------------------------- /flownmt/optim/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | from torch.optim.optimizer import Optimizer 2 | 3 | 4 | class _LRScheduler(object): 5 | def __init__(self, optimizer, last_epoch=-1): 6 | if not isinstance(optimizer, Optimizer): 7 | raise TypeError('{} is not an Optimizer'.format( 8 | type(optimizer).__name__)) 9 | self.optimizer = optimizer 10 | if last_epoch == -1: 11 | for group in optimizer.param_groups: 12 | group.setdefault('initial_lr', group['lr']) 13 | last_epoch = 0 14 | else: 15 | for i, group in enumerate(optimizer.param_groups): 16 | if 'initial_lr' not in group: 17 | raise KeyError("param 'initial_lr' is not specified " 18 | "in param_groups[{}] when resuming an optimizer".format(i)) 19 | self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups)) 20 | 21 | def state_dict(self): 22 | """Returns the state of the scheduler as a :class:`dict`. 23 | 24 | It contains an entry for every variable in self.__dict__ which 25 | is not the optimizer. 26 | """ 27 | return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} 28 | 29 | def load_state_dict(self, state_dict): 30 | """Loads the schedulers state. 31 | 32 | Arguments: 33 | state_dict (dict): scheduler state. Should be an object returned 34 | from a call to :meth:`state_dict`. 35 | """ 36 | self.__dict__.update(state_dict) 37 | 38 | def get_lr(self): 39 | raise NotImplementedError 40 | 41 | def step(self, epoch=None): 42 | if epoch is None: 43 | epoch = self.last_epoch + 1 44 | self.last_epoch = epoch 45 | for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): 46 | param_group['lr'] = lr 47 | 48 | 49 | class InverseSquareRootScheduler(_LRScheduler): 50 | """ 51 | Decay the LR based on the inverse square root of the update number. 52 | We also support a warmup phase where we linearly increase the learning rate 53 | from zero until the configured learning rate (``--lr``). 54 | Thereafter we decay proportional to the number of 55 | updates, with a decay factor set to align with the configured learning rate. 
56 | During warmup:: 57 | lrs = torch.linspace(0, args.lr, args.warmup_updates) 58 | lr = lrs[update_num] 59 | After warmup:: 60 | decay_factor = args.lr * sqrt(args.warmup_updates) 61 | lr = decay_factor / sqrt(update_num) 62 | """ 63 | def __init__(self, optimizer, warmup_steps, init_lr, last_epoch=-1): 64 | assert warmup_steps > 0, 'warmup steps should be larger than 0.' 65 | super(InverseSquareRootScheduler, self).__init__(optimizer, last_epoch) 66 | self.warmup_steps = float(warmup_steps) 67 | self.init_lr = init_lr 68 | self.lr_steps = [(base_lr - init_lr) / warmup_steps for base_lr in self.base_lrs] 69 | self.decay_factor = self.warmup_steps ** 0.5 70 | if last_epoch == -1: 71 | last_epoch = 0 72 | self.step(last_epoch) 73 | 74 | def get_lr(self): 75 | if self.last_epoch < self.warmup_steps: 76 | return [self.init_lr + lr_step * self.last_epoch for lr_step in self.lr_steps] 77 | else: 78 | lr_factor = self.decay_factor * self.last_epoch**-0.5 79 | return [base_lr * lr_factor for base_lr in self.base_lrs] 80 | 81 | 82 | class ExponentialScheduler(_LRScheduler): 83 | """Set the learning rate of each parameter group to the initial lr decayed 84 | by gamma every epoch. When last_epoch=-1, sets initial lr as lr. 85 | We also support a warmup phase where we linearly increase the learning rate 86 | from zero until the configured learning rate (``--lr``). 87 | Args: 88 | optimizer (Optimizer): Wrapped optimizer. 89 | gamma (float): Multiplicative factor of learning rate decay. 90 | warmup_steps (int): Warmup steps.. 91 | last_epoch (int): The index of last epoch. Default: -1. 92 | """ 93 | 94 | def __init__(self, optimizer, gamma, warmup_steps, init_lr, last_epoch=-1): 95 | super(ExponentialScheduler, self).__init__(optimizer, last_epoch) 96 | self.gamma = gamma 97 | # handle warmup <= 0 98 | self.warmup_steps = max(1, warmup_steps) 99 | self.init_lr = init_lr 100 | self.lr_steps = [(base_lr - init_lr) / self.warmup_steps for base_lr in self.base_lrs] 101 | if last_epoch == -1: 102 | last_epoch = 0 103 | self.step(last_epoch) 104 | 105 | def get_lr(self): 106 | if self.last_epoch < self.warmup_steps: 107 | return [self.init_lr + lr_step * self.last_epoch for lr_step in self.lr_steps] 108 | else: 109 | lr_factor = self.gamma ** (self.last_epoch - self.warmup_steps) 110 | return [base_lr * lr_factor for base_lr in self.base_lrs] 111 | -------------------------------------------------------------------------------- /flownmt/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | from typing import Tuple, List 4 | import torch 5 | from torch._six import inf 6 | 7 | 8 | def get_logger(name, level=logging.INFO, handler=sys.stdout, 9 | formatter='%(asctime)s - %(name)s - %(levelname)s - %(message)s'): 10 | logger = logging.getLogger(name) 11 | logger.setLevel(logging.INFO) 12 | formatter = logging.Formatter(formatter) 13 | stream_handler = logging.StreamHandler(handler) 14 | stream_handler.setLevel(level) 15 | stream_handler.setFormatter(formatter) 16 | logger.addHandler(stream_handler) 17 | return logger 18 | 19 | 20 | def norm(p: torch.Tensor, dim: int): 21 | """Computes the norm over all dimensions except dim""" 22 | if dim is None: 23 | return p.norm() 24 | elif dim == 0: 25 | output_size = (p.size(0),) + (1,) * (p.dim() - 1) 26 | return p.contiguous().view(p.size(0), -1).norm(dim=1).view(*output_size) 27 | elif dim == p.dim() - 1: 28 | output_size = (1,) * (p.dim() - 1) + (p.size(-1),) 29 | return p.contiguous().view(-1, 
p.size(-1)).norm(dim=0).view(*output_size) 30 | else: 31 | return norm(p.transpose(0, dim), 0).transpose(0, dim) 32 | 33 | 34 | def exponentialMovingAverage(original, shadow, decay_rate, init=False): 35 | params = dict() 36 | for name, param in shadow.named_parameters(): 37 | params[name] = param 38 | for name, param in original.named_parameters(): 39 | shadow_param = params[name] 40 | if init: 41 | shadow_param.data.copy_(param.data) 42 | else: 43 | shadow_param.data.add_((1 - decay_rate) * (param.data - shadow_param.data)) 44 | 45 | 46 | def logPlusOne(x): 47 | """ 48 | compute log(x + 1) for small x 49 | Args: 50 | x: Tensor 51 | Returns: Tensor 52 | log(x+1) 53 | """ 54 | eps = 1e-4 55 | mask = x.abs().le(eps).type_as(x) 56 | return x.mul(x.mul(-0.5) + 1.0) * mask + (x + 1.0).log() * (1.0 - mask) 57 | 58 | 59 | def gate(x1, x2): 60 | return x1 * x2.sigmoid_() 61 | 62 | 63 | def total_grad_norm(parameters, norm_type=2): 64 | if isinstance(parameters, torch.Tensor): 65 | parameters = [parameters] 66 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 67 | norm_type = float(norm_type) 68 | if norm_type == inf: 69 | total_norm = max(p.grad.data.abs().max() for p in parameters) 70 | else: 71 | total_norm = 0 72 | for p in parameters: 73 | param_norm = p.grad.data.norm(norm_type) 74 | total_norm += param_norm.item() ** norm_type 75 | total_norm = total_norm ** (1. / norm_type) 76 | return total_norm 77 | 78 | 79 | def squeeze(x: torch.Tensor, mask: torch.Tensor, factor: int = 2) -> Tuple[torch.Tensor, torch.Tensor]: 80 | """ 81 | Args: 82 | x: Tensor 83 | input tensor [batch, length, features] 84 | mask: Tensor 85 | mask tensor [batch, length] 86 | factor: int 87 | squeeze factor (default 2) 88 | Returns: Tensor1, Tensor2 89 | squeezed x [batch, length // factor, factor * features] 90 | squeezed mask [batch, length // factor] 91 | """ 92 | assert factor >= 1 93 | if factor == 1: 94 | return x 95 | 96 | batch, length, features = x.size() 97 | assert length % factor == 0 98 | # [batch, length // factor, factor * features] 99 | x = x.contiguous().view(batch, length // factor, factor * features) 100 | mask = mask.view(batch, length // factor, factor).sum(dim=2).clamp(max=1.0) 101 | return x, mask 102 | 103 | 104 | def unsqueeze(x: torch.Tensor, factor: int = 2) -> torch.Tensor: 105 | """ 106 | Args: 107 | x: Tensor 108 | input tensor [batch, length, features] 109 | factor: int 110 | unsqueeze factor (default 2) 111 | Returns: Tensor 112 | squeezed tensor [batch, length * 2, features // 2] 113 | """ 114 | assert factor >= 1 115 | if factor == 1: 116 | return x 117 | 118 | batch, length, features = x.size() 119 | assert features % factor == 0 120 | # [batch, length * factor, features // factor] 121 | x = x.view(batch, length * factor, features // factor) 122 | return x 123 | 124 | 125 | def split(x: torch.Tensor, z1_features) -> Tuple[torch.Tensor, torch.Tensor]: 126 | """ 127 | Args: 128 | x: Tensor 129 | input tensor [batch, length, features] 130 | z1_features: int 131 | the number of features of z1 132 | Returns: Tensor, Tensor 133 | split tensors [batch, length, z1_features], [batch, length, features-z1_features] 134 | """ 135 | z1 = x[:, :, :z1_features] 136 | z2 = x[:, :, z1_features:] 137 | return z1, z2 138 | 139 | 140 | def unsplit(xs: List[torch.Tensor]) -> torch.Tensor: 141 | """ 142 | Args: 143 | xs: List[Tensor] 144 | tensors to be combined 145 | Returns: Tensor 146 | combined tensor 147 | """ 148 | return torch.cat(xs, dim=2) 149 | 150 | 151 | def 
make_positions(tensor, padding_idx): 152 | """Replace non-padding symbols with their position numbers. 153 | Position numbers begin at padding_idx+1. Padding symbols are ignored. 154 | """ 155 | mask = tensor.ne(padding_idx).long() 156 | return torch.cumsum(mask, dim=1) * mask 157 | 158 | 159 | # def prepare_rnn_seq(rnn_input, lengths, batch_first=False): 160 | # ''' 161 | # Args: 162 | # rnn_input: [seq_len, batch, input_size]: tensor containing the features of the input sequence. 163 | # lengths: [batch]: tensor containing the lengthes of the input sequence 164 | # batch_first: If True, then the input and output tensors are provided as [batch, seq_len, feature]. 165 | # Returns: 166 | # ''' 167 | # 168 | # def check_decreasing(lengths): 169 | # lens, order = torch.sort(lengths, dim=0, descending=True) 170 | # if torch.ne(lens, lengths).sum() == 0: 171 | # return None 172 | # else: 173 | # _, rev_order = torch.sort(order) 174 | # return lens, order, rev_order 175 | # 176 | # check_res = check_decreasing(lengths) 177 | # 178 | # if check_res is None: 179 | # lens = lengths 180 | # rev_order = None 181 | # else: 182 | # lens, order, rev_order = check_res 183 | # batch_dim = 0 if batch_first else 1 184 | # rnn_input = rnn_input.index_select(batch_dim, order) 185 | # lens = lens.tolist() 186 | # seq = pack_padded_sequence(rnn_input, lens, batch_first=batch_first) 187 | # return seq, rev_order 188 | # 189 | # def recover_rnn_seq(seq, rev_order, batch_first=False, total_length=None): 190 | # output, _ = pad_packed_sequence(seq, batch_first=batch_first, total_length=total_length) 191 | # if rev_order is not None: 192 | # batch_dim = 0 if batch_first else 1 193 | # output = output.index_select(batch_dim, rev_order) 194 | # return output 195 | # 196 | # 197 | # def recover_order(tensors, rev_order): 198 | # if rev_order is None: 199 | # return tensors 200 | # recovered_tensors = [tensor.index_select(0, rev_order) for tensor in tensors] 201 | # return recovered_tensors 202 | # 203 | # 204 | # def decreasing_order(lengths, tensors): 205 | # def check_decreasing(lengths): 206 | # lens, order = torch.sort(lengths, dim=0, descending=True) 207 | # if torch.ne(lens, lengths).sum() == 0: 208 | # return None 209 | # else: 210 | # _, rev_order = torch.sort(order) 211 | # return lens, order, rev_order 212 | # 213 | # check_res = check_decreasing(lengths) 214 | # 215 | # if check_res is None: 216 | # lens = lengths 217 | # rev_order = None 218 | # ordered_tensors = tensors 219 | # else: 220 | # lens, order, rev_order = check_res 221 | # ordered_tensors = [tensor.index_select(0, order) for tensor in tensors] 222 | # 223 | # return lens, ordered_tensors, rev_order 224 | -------------------------------------------------------------------------------- /images/flowseq_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XuezheMax/flowseq/8cb4ae00c26fbeb3e1459e3b3b90e7e9a84c3d2b/images/flowseq_diagram.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | python = 3.6 2 | numpy 3 | overrides 4 | --------------------------------------------------------------------------------
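Taken together, the AdamW optimizer and the warmup schedulers in flownmt/optim above are meant to be stepped once per parameter update. A minimal sketch, with a dummy model and hypothetical hyperparameters (not values from the experiment configs):

import torch
import torch.nn as nn
from flownmt.optim import AdamW, InverseSquareRootScheduler

model = nn.Linear(512, 512)
optimizer = AdamW(model.parameters(), lr=5e-4, betas=(0.9, 0.999), weight_decay=1e-6)
# Linear warmup from init_lr to lr over warmup_steps, then decay as lr * sqrt(warmup_steps / step).
scheduler = InverseSquareRootScheduler(optimizer, warmup_steps=4000, init_lr=1e-7)

for step in range(1, 10001):
    optimizer.zero_grad()
    loss = model(torch.randn(8, 512)).pow(2).mean()   # placeholder loss
    loss.backward()
    optimizer.step()
    scheduler.step()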