├── .gitignore ├── LICENSE ├── README.md ├── app.py ├── dataset.py ├── download_data.sh ├── get_model.sh ├── img └── model.png ├── models ├── __init__.py ├── blm.py ├── get_canvas.cpp ├── inst.py ├── lblm.py ├── lm.py └── torch_utils.py ├── optimizer.py ├── test.py ├── train.py ├── transformer ├── Beam.py ├── Constants.py ├── Layers.py ├── Models.py ├── Modules.py ├── Optim.py ├── SubLayers.py ├── Translator.py └── __init__.py ├── utils.py └── vocab.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | rsync_exclude.txt 3 | figs/ 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # pyenv 80 | .python-version 81 | 82 | # celery beat schedule file 83 | celerybeat-schedule 84 | 85 | # SageMath parsed files 86 | *.sage.py 87 | 88 | # Environments 89 | .env 90 | .venv 91 | env/ 92 | venv/ 93 | ENV/ 94 | env.bak/ 95 | venv.bak/ 96 | 97 | # Spyder project settings 98 | .spyderproject 99 | .spyproject 100 | 101 | # vim 102 | .vim 103 | 104 | # Rope project settings 105 | .ropeproject 106 | 107 | # mkdocs documentation 108 | /site 109 | 110 | # mypy 111 | .mypy_cache/ 112 | 113 | # ddp 114 | __temp_weight_ddp_end.ckpt 115 | 116 | checkpoints/ 117 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 
22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Blank Language Models
2 | 
3 | This repository contains the code for our EMNLP 2020 paper:
4 | [**Blank Language Models**](https://arxiv.org/abs/2002.03079)
5 | *Tianxiao Shen\*, Victor Quach\*, Regina Barzilay, and Tommi Jaakkola (\*: Equal contribution)*
6 | 
7 | 
8 | 9 | Given partially specified text with one or more blanks, BLM will fill in the blanks with a variable number of tokens consistent with the context, making it ideal for text editing and rewriting. 10 | 11 | > Input: They also have \___ which \___ . 12 | > Output: They also have ice cream which is really good . 13 | 14 |
15 | 16 |

17 | 
18 | 
19 | ## Demo
20 | 
21 | We have an online demo built using [streamlit](https://www.streamlit.io/), available [here](http://128.52.131.173:8501).
22 | 
23 | Or try it locally by running:
24 | 
25 | ```
26 | streamlit run app.py
27 | ```
28 | 
29 | 
30 | ## Dependencies
31 | 
32 | Our code is based on the [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning) framework.
33 | 
34 | It has been tested with PyTorch 1.6.0 and PyTorch Lightning 1.0.7.
35 | 
36 | 
37 | ## Download Data
38 | 
39 | Download the processed Yelp and Yahoo datasets by running:
40 | ```
41 | bash download_data.sh
42 | ```
43 | 
44 | 
45 | ## Training
46 | 
47 | To train a BLM on Yelp negative sentences:
48 | ```
49 | python train.py --train data/yelp/train.0 --valid data/yelp/valid.0 --root_dir checkpoints/yelp/neg/blm/ \
50 | --vocab_size 10000 --max_len 20 --model_type blm --share_emb_prj_weight
51 | ```
52 | 
53 | Yelp positive sentences:
54 | ```
55 | python train.py --train data/yelp/train.1 --valid data/yelp/valid.1 --root_dir checkpoints/yelp/pos/blm/ \
56 | --vocab_size 10000 --max_len 20 --model_type blm --share_emb_prj_weight
57 | ```
58 | 
59 | Yahoo documents:
60 | ```
61 | python train.py --train data/yahoo/train.txt --valid data/yahoo/valid.txt --root_dir checkpoints/yahoo/blm/ \
62 | --vocab_size 20000 --max_len 205 --model_type blm --share_emb_prj_weight
63 | ```
64 | 
65 | Run `python train.py -h` to see all training options.
66 | 
67 | You can use TensorBoard to monitor the training progress.
68 | 
69 | 
70 | ## Testing
71 | 
72 | After training, we can evaluate the model's perplexity using a Monte Carlo estimate, and use the model to generate text from scratch or fill in the blanks in the input.
73 | 
74 | For all of the following, replace `epoch\=???.ckpt` with the checkpoint saved during training.
75 | 
76 | - The following command evaluates the model on Yelp negative sentences:
77 | 
78 | ```
79 | python test.py --checkpoint checkpoints/yelp/neg/blm/lightning_logs/version_0/checkpoints/epoch\=???.ckpt \
80 | --eval data/yelp/test.0 --n_mc 10
81 | ```
82 | 
83 | - The following command samples from the model trained on Yelp negative sentences:
84 | 
85 | ```
86 | python test.py --checkpoint checkpoints/yelp/neg/blm/lightning_logs/version_0/checkpoints/epoch\=???.ckpt \
87 | --sample 1000 --decode sample --output sample.txt
88 | ```
89 | 
90 | - The following command uses the model trained on Yelp negative sentences to fill in blanked positive sentences to achieve sentiment transfer:
91 | 
92 | ```
93 | python test.py --checkpoint checkpoints/yelp/neg/blm/lightning_logs/version_0/checkpoints/epoch\=???.ckpt \
94 | --fill data/yelp/blank/test.1.blank --output test.1.tsf
95 | ```
96 | 
97 | To output the whole generation trajectory, turn on the `--write_mid` option.
98 | 
99 | The output file will be stored in `outputs/` within the checkpoint directory.
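You can also load a trained checkpoint and fill in blanks programmatically, much as `app.py` does. The sketch below is illustrative rather than part of the repository's CLI: the checkpoint path is a placeholder, `<eos>` handling is omitted, and it assumes a BLM-type model so that each `___` can be mapped to the blank token via `Vocab.blank`.

```python
import os
import torch

from utils import load_model
from vocab import Vocab

# Placeholder path: substitute the epoch=*.ckpt file produced by train.py.
ckpt = "checkpoints/yelp/neg/blm/lightning_logs/version_0/checkpoints/epoch=0.ckpt"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = load_model(ckpt).to(device)
model.eval()
vocab = Vocab(os.path.join(model.hparams.root_dir, "vocab.txt"))

# "___" marks each blank in the input.
text = "___ place ___ and ___ food ___ ."
seq = [Vocab.blank if w == "___" else vocab.word_to_idx(w) for w in text.split()]

# generate() returns the filled-in tokens and the full canvas after every step;
# the last canvas is the completed sentence.
_, full = model.generate(seq, "greedy", device)
print(" ".join(vocab.idx2word[i] for i in full[-1]))
```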
100 | 101 | 102 | ## Acknowledgements 103 | 104 | We use the Transformer implementation from https://github.com/jadore801120/attention-is-all-you-need-pytorch 105 | 106 | 107 | ## Citation 108 | 109 | If you use our work, please cite: 110 | 111 | ```bibtex 112 | @inproceedings{shen2020blank, 113 | title = "{Blank Language Models}", 114 | author = "Shen, Tianxiao and 115 | Quach, Victor and 116 | Barzilay, Regina and 117 | Jaakkola, Tommi", 118 | booktitle = "Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing", 119 | month = nov, 120 | year = "2020", 121 | address = "Online", 122 | publisher = "Association for Computational Linguistics" 123 | } 124 | ``` 125 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from tqdm import tqdm 4 | 5 | import torch 6 | import pytorch_lightning as pl 7 | import streamlit as st 8 | 9 | from vocab import Vocab 10 | from utils import load_data, load_sent, load_model, makedir, write 11 | from dataset import get_eval_dataloader 12 | 13 | 14 | st.sidebar.write("## Parameters") 15 | 16 | device = st.sidebar.selectbox("Device", 17 | ["cpu"] + ["cuda:{}".format(i) for i in range(torch.cuda.device_count())], 18 | 0, 19 | lambda key: key if key == "cpu" else "GPU {}".format(key) 20 | ) 21 | 22 | st.write('# Blank Language Models: Demo') 23 | 24 | st.write("[View source](http://github.com/varal7/blank_language_model)") 25 | 26 | st.write('## Load model') 27 | 28 | yelp_neg = "checkpoints/yelp/neg/lightning_logs/version_0/checkpoints/model.ckpt" 29 | yelp_pos = "checkpoints/yelp/pos/lightning_logs/version_0/checkpoints/model.ckpt" 30 | 31 | if not os.path.exists(yelp_neg) or not os.path.exists(yelp_pos): 32 | st.write(":warning: Default models not found. 
Run `get_model.sh` to download models trained on Yelp.")
33 |     checkpoint = st.radio("Load checkpoint", ("Custom model", ))
34 | 
35 | else:
36 |     checkpoint = st.radio("Load checkpoint", ("Yelp positive reviews", "Yelp negative reviews", "Custom model"))
37 | 
38 | if checkpoint == "Custom model":
39 |     checkpoint_file = st.text_input("Path to `model.ckpt` file", value=yelp_pos)
40 | else:
41 |     checkpoint_file = yelp_pos if "Yelp positive" in checkpoint else yelp_neg
42 | 
43 | @st.cache
44 | def get_model(checkpoint_file, device):
45 |     model = load_model(checkpoint_file).to(device)
46 |     model.eval()
47 |     vocab = Vocab(os.path.join(model.hparams.root_dir, 'vocab.txt'))
48 |     return model, vocab
49 | 
50 | 
51 | model, vocab = get_model(checkpoint_file, device)
52 | 
53 | decode = st.sidebar.radio("Decoding", ("Greedy", "Sample")).lower()
54 | 
55 | mode = st.sidebar.radio("Task", ('Infilling', 'Sample'))
56 | 
57 | 
58 | if mode == "Sample":
59 |     _, full = model.generate([model.init_canvas()], decode, device)
60 |     full = [[vocab.idx2word[id] for id in ids] for ids in full]
61 |     for step in full:
62 |         st.write(" ".join(step).replace("<blank>", "\_\_\_"))
63 | 
64 | if mode == "Infilling":
65 |     st.write('## Load infilling data')
66 |     text_input = st.text_input("Blanked input", value="___ place ___ and ___ food ___ .").lower()
67 |     s = text_input.replace("___", "<blank>").split()
68 |     s += ['<eos>'] if model.hparams.add_eos else []
69 |     s = [vocab.word_to_idx(w) for w in s]
70 |     _, full = model.generate(s, decode, device)
71 |     full = [[vocab.idx2word[id] for id in ids] for ids in full]
72 |     for step in full:
73 |         st.write(" ".join(step).replace("<blank>", "\_\_\_"))
74 | 
75 | if st.button("Rerun"):
76 |     pass
77 | 
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | from vocab import Vocab
4 | 
5 | 
6 | def get_batch(x, vocab, append_at_ends=False):
7 |     seq = []
8 |     n = [len(s) for s in x]
9 |     n_real = []
10 |     max_len = max(n)
11 |     for s, l in zip(x, n):
12 |         # Combine BPE tokens to count the number of words
13 |         n_real.append(l - sum(1 for t in s if t.endswith("@@")))
14 | 
15 |         s_idx = [vocab.word_to_idx(w) for w in s]
16 |         if append_at_ends:
17 |             s_idx = [Vocab.first] + s_idx + [Vocab.last]
18 |         seq.append(s_idx + [Vocab.pad] * (max_len - l))
19 |     return torch.LongTensor(seq), torch.LongTensor(n), torch.LongTensor(n_real)
20 | 
21 | 
22 | def get_batches(data, vocab, max_tok, append_at_ends=False, same_len=False):
23 |     offset = 2 if append_at_ends else 0
24 | 
25 |     order = range(len(data))
26 |     z = sorted(zip(order, data), key=lambda i: len(i[1]), reverse=True)
27 |     order, data = zip(*z)
28 | 
29 |     batches = []
30 |     i = 0
31 |     while i < len(data):
32 |         j = i
33 |         while j < len(data) and (len(data[i]) + offset) * (j-i+1) <= max_tok \
34 |                 and (not same_len or len(data[j]) == len(data[i])):
35 |             j += 1
36 |         batches.append(get_batch(data[i: j], vocab, append_at_ends))
37 |         i = j
38 |     return batches, order
39 | 
40 | 
41 | class LMDataset(torch.utils.data.Dataset):
42 |     def __init__(self, batches):
43 |         self.batches = batches
44 | 
45 |     def __getitem__(self, idx):
46 |         return self.batches[idx]
47 | 
48 |     def __len__(self):
49 |         return len(self.batches)
50 | 
51 | 
52 | def get_train_dataloader(train_data, vocab, max_tok, data_workers=8,
53 |                          model_type=None):
54 |     train_batches, _ = get_batches(train_data, vocab, max_tok,
55 |                                    append_at_ends=(model_type == 'inst'))
56 |     print("Number of train batches: 
{}".format(len(train_batches))) 57 | train_ds = LMDataset(train_batches) 58 | return torch.utils.data.DataLoader(train_ds, num_workers=data_workers, 59 | shuffle=True, pin_memory=True) 60 | 61 | 62 | def get_eval_dataloader(val_data, vocab, max_tok, data_workers=8, 63 | model_type=None): 64 | val_batches, _ = get_batches(val_data, vocab, max_tok, same_len=True, 65 | append_at_ends=(model_type == 'inst')) 66 | print("Number of eval batches: {}".format(len(val_batches))) 67 | val_ds = LMDataset(val_batches) 68 | return torch.utils.data.DataLoader(val_ds, num_workers=data_workers, 69 | pin_memory=True) 70 | -------------------------------------------------------------------------------- /download_data.sh: -------------------------------------------------------------------------------- 1 | mkdir data 2 | cd data 3 | 4 | dir="http://people.csail.mit.edu/tianxiao/data" 5 | 6 | wget $dir/yelp_blm.zip 7 | unzip yelp_blm.zip 8 | rm yelp_blm.zip 9 | 10 | wget $dir/yahoo_blm.zip 11 | unzip yahoo_blm.zip 12 | rm yahoo_blm.zip 13 | -------------------------------------------------------------------------------- /get_model.sh: -------------------------------------------------------------------------------- 1 | mkdir -p checkpoints/yelp 2 | cd checkpoints/yelp 3 | wget http://128.52.131.173:8000/yelp/neg.tgz -O neg.tgz 4 | wget http://128.52.131.173:8000/yelp/pos.tgz -O pos.tgz 5 | tar -xf neg.tgz 6 | tar -xf pos.tgz 7 | -------------------------------------------------------------------------------- /img/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Varal7/blank_language_model/0ab36a6a6e5272683ba7b059c16035b0c9d00ef0/img/model.png -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from . inst import InsTLM 2 | from . blm import BLM 3 | from . lblm import LBLM 4 | 5 | 6 | def get_model_class(model_type): 7 | if model_type == 'blm': 8 | return BLM 9 | elif model_type == 'inst': 10 | return InsTLM 11 | elif model_type == 'lblm': 12 | return LBLM 13 | else: 14 | raise ValueError('Unknown model ' + model_type) 15 | -------------------------------------------------------------------------------- /models/blm.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from . lm import LM 7 | from . 
torch_utils import get_canvas, sample_permutation, seq_cross_entropy, collect, batch_randint, select 8 | from vocab import Vocab 9 | 10 | 11 | class BLM(LM): 12 | """Blank Language Model""" 13 | 14 | def __init__(self, hparams): 15 | super().__init__(hparams) 16 | hparams = self.hparams # a['key'] (if so) -> a.key 17 | 18 | self.lrb = nn.Sequential( 19 | nn.Linear(hparams.d_model * 2, hparams.d_model * 2), 20 | nn.ReLU(), 21 | nn.Linear(hparams.d_model * 2, 4) 22 | ) 23 | 24 | def init_canvas(self): 25 | return Vocab.blank 26 | 27 | def get_loss(self, seq, canvas, blanks, rest, loc, lb, rb): 28 | count = (rest != -1).sum(1) 29 | output = self.forward_encoder(canvas) 30 | output_blank = collect(output, blanks) 31 | 32 | logits_loc = self.loc(output_blank).squeeze(-1) 33 | logits_loc[blanks == -1] = float('-inf') 34 | nll_loc = -F.log_softmax(logits_loc, 1) 35 | loss_loc = collect(nll_loc, loc) 36 | loss_loc = loss_loc.sum(1) / count.float() 37 | output_loc = collect(output_blank, loc) 38 | 39 | logits_word = self.word(output_loc) * self.x_logit_scale 40 | target = collect(seq, rest, Vocab.pad) 41 | loss_word = seq_cross_entropy(logits_word, target, Vocab.pad) 42 | loss_word = loss_word.sum(1) / count.float() 43 | output_word = torch.cat((output_loc, self.enc.src_word_emb(target)), -1) 44 | 45 | logits_lrb = self.lrb(output_word) 46 | loss_lrb = seq_cross_entropy(logits_lrb, lb * 2 + rb, -3) 47 | loss_lrb = loss_lrb.sum(1) / count.float() 48 | 49 | return loss_loc, loss_word, loss_lrb 50 | 51 | def losses(self, seq, n, n_real): 52 | """ 53 | Args: 54 | n: number of BPE tokens 55 | n_real: number of real words (for reporting PPL) 56 | """ 57 | k = batch_randint(0, n - 1) 58 | rank = sample_permutation(seq) 59 | keep = (rank < k.unsqueeze(1)) 60 | canvas, blanks, rest, loc, lb, rb = get_canvas(seq, keep, n) 61 | loss_loc, loss_word, loss_lrb = self.get_loss(seq, canvas, blanks, rest, loc, lb, rb) 62 | nll_lb = (loss_loc + loss_word + loss_lrb) * n.float() - (n + 1).float().lgamma() 63 | return {'loss': nll_lb.sum() / n_real.sum(), 64 | 'loc': loss_loc.mean(), 65 | 'word': loss_word.mean(), 66 | 'lrb': loss_lrb.mean() 67 | } 68 | 69 | def nll_mc(self, seq, n, m): 70 | """ 71 | Compute negative log-likelihood by monte carlo estimate 72 | Args: 73 | m: number of samples to take 74 | 75 | Note: sentences in the batch must have the same length 76 | """ 77 | a = [] 78 | for _ in range(m): 79 | rank = sample_permutation(seq) 80 | logp = 0. 
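            # Each pass below scores one sampled generation order: at step k the
            # canvas keeps the k tokens that appear earliest in the sampled order,
            # and the model is scored on inserting the (k+1)-th token (its blank
            # location, the word, and the left/right blank split). Accumulating
            # these log-probs gives logp = log p(x, order). Since p(x) sums
            # p(x, order) over all n! orders, the return value combines log(n!),
            # log(m), and a log-sum-exp over the m sampled orders to form the
            # Monte Carlo estimate of -log p(x).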
81 | for k in range(seq.size(1)): 82 | keep = (rank < k) 83 | canvas, blanks, rest, loc, lb, rb = get_canvas(seq, keep, n) 84 | k_th = (rank == k).nonzero(as_tuple=True)[1] 85 | x, y = (rest == k_th.unsqueeze(1)).nonzero(as_tuple=True) 86 | assert torch.all(x == torch.arange(len(seq), device=seq.device)) 87 | rest, loc, lb, rb = [t[x, y].unsqueeze(1) for t in [rest, loc, lb, rb]] 88 | loss_loc, loss_word, loss_lrb = self.get_loss(seq, canvas, blanks, rest, loc, lb, rb) 89 | logp -= loss_loc + loss_word + loss_lrb 90 | a.append(logp.unsqueeze(1)) 91 | return np.log(m) - (n + 1).float().lgamma() - torch.logsumexp(torch.cat(a, 1), 1) 92 | 93 | def generate(self, seq, decode, device): 94 | seq = torch.LongTensor(seq).to(device) 95 | blanks = [i for i, w in enumerate(seq) if w == Vocab.blank] 96 | is_fill = [0] * len(seq) 97 | fill = [[]] 98 | full = [seq] 99 | while len(blanks) > 0 and len(seq) <= self.hparams.max_len: 100 | output = self.forward_encoder(seq.unsqueeze(0))[0] 101 | output_blank = output[blanks] 102 | loc = select(self.loc(output_blank).squeeze(-1), decode) 103 | output_loc = output_blank[loc] 104 | 105 | logits_word = self.word(output_loc) * self.x_logit_scale 106 | logits_word[Vocab.blank] = float('-inf') # never predict 107 | 108 | # joint word, lrb prediction 109 | lprob_word = F.log_softmax(logits_word, -1) 110 | output_word = torch.cat((output_loc.unsqueeze(0).expand(self.hparams.vocab_size, -1), 111 | self.enc.src_word_emb.weight), -1) 112 | logits_lrb = self.lrb(output_word) 113 | lprob_lrb = F.log_softmax(logits_lrb, -1) 114 | lprob_word_lrb = lprob_word.unsqueeze(1) + lprob_lrb 115 | word_lrb = select(lprob_word_lrb.view(-1), decode) 116 | word, lrb = word_lrb // 4, word_lrb % 4 117 | 118 | # predict word first and then lrb 119 | # word = select(logits_word, decode) 120 | # output_word = torch.cat((output_loc, self.enc.src_word_emb(word)), dim=-1) 121 | # lrb = select(self.lrb(output_word), decode) 122 | 123 | lb, rb = lrb // 2, lrb % 2 124 | ins = ([Vocab.blank] if lb else []) + [word] + ([Vocab.blank] if rb else []) 125 | ins = torch.LongTensor(ins).to(device) 126 | pos = blanks[loc] 127 | seq = torch.cat((seq[:pos], ins, seq[pos + 1:])) 128 | blanks = [i for i, w in enumerate(seq) if w == Vocab.blank] 129 | is_fill = is_fill[:pos] + [1] * len(ins) + is_fill[pos + 1:] 130 | fill.append([id for id, isf in zip(seq, is_fill) if isf]) 131 | full.append(seq) 132 | return fill, full 133 | -------------------------------------------------------------------------------- /models/get_canvas.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | using namespace std; 6 | 7 | vector>> get_insertion_canvas( 8 | vector>& seq, 9 | vector>& keep, 10 | vector n 11 | ) { 12 | vector> batch_canvas, batch_rest, batch_loc; 13 | for (uint32_t b = 0; b < seq.size(); b++) { 14 | vector indices, canvas, rest, loc; 15 | for (uint32_t i = 0; i < n[b] + 2; i++) { 16 | if (keep[b][i]) { 17 | canvas.push_back(seq[b][i]); 18 | indices.push_back(i); 19 | } else { 20 | rest.push_back(i); 21 | } 22 | } 23 | if (rest.size() == 0) { 24 | rest.push_back(n[b] + 1); 25 | loc.push_back(n[b]); 26 | } 27 | else { 28 | uint32_t j =0; 29 | for (uint32_t i: rest) { 30 | while (indices[j] < i) { 31 | j++; 32 | } 33 | loc.push_back(j-1); 34 | } 35 | } 36 | 37 | batch_canvas.push_back(canvas); 38 | batch_rest.push_back(rest); 39 | batch_loc.push_back(loc); 40 | } 41 | return {batch_canvas, batch_rest, batch_loc}; 42 | } 43 | 44 | vector>> 
get_canvas( 45 | vector>& seq, 46 | vector>& keep, 47 | vector n, 48 | uint32_t blank_id) { 49 | vector> batch_canvas, batch_blanks, batch_rest, batch_loc, batch_lb, batch_rb; 50 | for (uint32_t b = 0; b < seq.size(); b++) { 51 | vector canvas, blanks, rest, loc, lb, rb; 52 | for (uint32_t i = 0; i < n[b]; ) { 53 | if (keep[b][i]) { 54 | canvas.push_back(seq[b][i]); 55 | i++; 56 | } else { 57 | lb.push_back(0); 58 | while (i < n[b] && !keep[b][i]) { 59 | rest.push_back(i); 60 | loc.push_back(blanks.size()); 61 | lb.push_back(1); 62 | rb.push_back(1); 63 | i++; 64 | } 65 | lb.pop_back(); 66 | rb.pop_back(); 67 | rb.push_back(0); 68 | blanks.push_back(canvas.size()); 69 | canvas.push_back(blank_id); 70 | } 71 | } 72 | batch_canvas.push_back(canvas); 73 | batch_blanks.push_back(blanks); 74 | batch_rest.push_back(rest); 75 | batch_loc.push_back(loc); 76 | batch_lb.push_back(lb); 77 | batch_rb.push_back(rb); 78 | } 79 | return {batch_canvas, batch_blanks, batch_rest, batch_loc, batch_lb, batch_rb}; 80 | } 81 | 82 | vector>> get_known_length_canvas( 83 | vector>& seq, 84 | vector>& keep, 85 | vector n, 86 | uint32_t blank_id) { 87 | vector> batch_canvas, batch_blanks, batch_rest, batch_loc, batch_lb; 88 | for (uint32_t b = 0; b < seq.size(); b++) { 89 | vector canvas, blanks, rest, loc, lb; 90 | for (uint32_t i = 0; i < n[b]; ) { 91 | if (keep[b][i]) { 92 | canvas.push_back(seq[b][i]); 93 | i++; 94 | } else { 95 | uint32_t cur_len = 0; 96 | while (i < n[b] && !keep[b][i]) { 97 | rest.push_back(i); 98 | loc.push_back(blanks.size()); 99 | lb.push_back(cur_len); 100 | i++; 101 | cur_len++; 102 | } 103 | blanks.push_back(canvas.size()); 104 | canvas.push_back(blank_id + cur_len); 105 | } 106 | } 107 | batch_canvas.push_back(canvas); 108 | batch_blanks.push_back(blanks); 109 | batch_rest.push_back(rest); 110 | batch_loc.push_back(loc); 111 | batch_lb.push_back(lb); 112 | } 113 | return {batch_canvas, batch_blanks, batch_rest, batch_loc, batch_lb}; 114 | } 115 | 116 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 117 | m.def("get_canvas", &get_canvas, "get_canvas"); 118 | m.def("get_insertion_canvas", &get_insertion_canvas, "get_insertion_canvas"); 119 | m.def("get_known_length_canvas", &get_known_length_canvas, "get_known_length_canvas"); 120 | 121 | } 122 | -------------------------------------------------------------------------------- /models/inst.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from . lm import LM 7 | from . 
torch_utils import get_ins_canvas, sample_permutation, seq_cross_entropy, collect, batch_randint, new_arange, select 8 | from vocab import Vocab 9 | 10 | 11 | class InsTLM(LM): 12 | """Insertion Transformer Language Model""" 13 | 14 | def __init__(self, hparams): 15 | super().__init__(hparams) 16 | hparams = self.hparams # a['key'] (if so) -> a.key 17 | 18 | self.pool_out = nn.Linear(2 * hparams.d_model, hparams.d_model) 19 | 20 | def get_loss(self, seq, canvas, rest, loc, mask): 21 | count = (rest != -1).sum(1) 22 | output = self.forward_encoder(canvas) 23 | features = self.pool_out(torch.cat((output[:, :-1, :], output[:, 1:, :]), dim=-1)) 24 | logits_loc = self.loc(features).squeeze(-1) 25 | logits_loc[~mask] = float('-inf') 26 | nll_loc = -F.log_softmax(logits_loc, 1) 27 | loss_loc = collect(nll_loc, loc) 28 | loss_loc = loss_loc.sum(1) / count.float() 29 | output_loc = collect(features, loc) 30 | 31 | logits_word = self.word(output_loc) * self.x_logit_scale 32 | target = collect(seq, rest, Vocab.pad) 33 | loss_word = seq_cross_entropy(logits_word, target, Vocab.pad) 34 | loss_word = loss_word.sum(1) / count.float() 35 | # output_word = torch.cat((output_loc, self.enc.src_word_emb(target)), -1) 36 | return loss_loc, loss_word 37 | 38 | def losses(self, seq, n, n_real): 39 | """ 40 | Args: 41 | n: number of BPE tokens 42 | n_real: number of real words (for reporting PPL) 43 | """ 44 | k = batch_randint(0, n) 45 | rank = sample_permutation(seq) 46 | keep = (rank < (k + 2).unsqueeze(1)) # keep and in addition 47 | canvas, rest, loc = get_ins_canvas(seq, keep, n) 48 | 49 | # canvas has + k tokens + , so k + 1 slots 50 | mask = (new_arange(canvas) < (k + 1).unsqueeze(1))[:, :-1] # mask for logits_loc 51 | loss_loc, loss_word = self.get_loss(seq, canvas, rest, loc, mask) 52 | nll_lb = (loss_loc + loss_word) * (n + 1).float() - (n + 1).float().lgamma() 53 | return {'loss': nll_lb.sum() / n_real.sum(), 54 | 'loc': loss_loc.mean(), 55 | 'word': loss_word.mean(), 56 | } 57 | 58 | def nll_mc(self, seq, n, m): 59 | """ 60 | Compute negative log-likelihood by monte carlo estimate 61 | Args: 62 | m: number of samples to take 63 | 64 | Note: sentences in the batch must have the same length 65 | """ 66 | a = [] 67 | for _ in range(m): 68 | rank = sample_permutation(seq) 69 | logp = 0. 
70 | for k in range(2, seq.size(1) + 1): # k from 2 to n + 2 71 | keep = (rank < k) 72 | canvas, rest, loc = get_ins_canvas(seq, keep, n) 73 | if k == seq.size(1): 74 | pass # rest and loc are already correct 75 | else: 76 | k_th = (rank == k).nonzero(as_tuple=True)[1] # First token not kept 77 | x, y = (rest == k_th.unsqueeze(1)).nonzero(as_tuple=True) 78 | assert len(seq) == len(x) 79 | assert torch.all(x == torch.arange(len(seq), device=seq.device)) 80 | rest, loc = [t[x, y].unsqueeze(1) for t in [rest, loc]] 81 | mask = (new_arange(canvas) < (k - 1))[:, :-1] # mask for logits_loc 82 | loss_loc, loss_word = self.get_loss(seq, canvas, rest, loc, mask) 83 | logp -= loss_loc + loss_word 84 | a.append(logp.unsqueeze(1)) 85 | return np.log(m) - (n + 1).float().lgamma() - torch.logsumexp(torch.cat(a, 1), 1) 86 | 87 | def generate(self, seq, blanks, decode, device, force_insert=False, prioritize_unfilled=False): 88 | seq = torch.LongTensor([Vocab.first] + seq + [Vocab.last]).to(device) 89 | is_fill = [0] * len(seq) 90 | fill = [[]] 91 | full = [seq[1:-1]] 92 | mandatory_blanks = np.array(blanks) 93 | if len(blanks) > 0: 94 | while len(seq) < self.hparams.max_len: 95 | output = self.forward_encoder(seq.unsqueeze(0)) 96 | features = self.pool_out( 97 | torch.cat((output[:, :-1, :], output[:, 1:, :]), dim=-1) 98 | )[0] 99 | 100 | logits_loc = self.loc(features).squeeze(-1) 101 | 102 | all_filled = (mandatory_blanks == None).all() 103 | can_stop = not force_insert or all_filled 104 | end_slot = len(seq) - 2 105 | 106 | if prioritize_unfilled and not all_filled: 107 | logits_loc[np.array(blanks)[mandatory_blanks == None]] = float('-inf') 108 | 109 | if end_slot not in blanks and can_stop: # enable end slot for termination 110 | blanks_end = blanks + [end_slot] 111 | loc = select(logits_loc[blanks_end], decode) 112 | pos = blanks_end[loc] 113 | else: 114 | loc = select(logits_loc[blanks], decode) 115 | pos = blanks[loc] 116 | 117 | output_loc = features[pos] 118 | logits_word = self.word(output_loc) * self.x_logit_scale 119 | 120 | if pos == end_slot: 121 | if end_slot not in blanks: # end slot is added artificially, so no words allowed there 122 | break 123 | elif not can_stop: 124 | logits_word[Vocab.last] = float('-inf') 125 | 126 | word = select(logits_word, decode) 127 | 128 | if pos == end_slot and word.item() == Vocab.last: 129 | break 130 | 131 | blanks = blanks[:loc + 1] + [x + 1 for x in blanks[loc:]] 132 | mandatory_blanks = np.concatenate(( 133 | mandatory_blanks[:loc], 134 | np.array([None]), 135 | np.array([None]), 136 | [x + 1 if x is not None else None for x in mandatory_blanks[loc + 1:]] 137 | )) 138 | seq = torch.cat((seq[:pos + 1], word.unsqueeze(0), seq[pos + 1:])) 139 | is_fill = is_fill[:pos + 1] + [1] + is_fill[pos + 1:] 140 | fill.append([id for id, isf in zip(seq, is_fill) if isf]) 141 | full.append(seq[1:-1]) 142 | return fill, full 143 | -------------------------------------------------------------------------------- /models/lblm.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from . lm import LM 7 | from . 
torch_utils import get_known_length_canvas, sample_permutation, seq_cross_entropy, collect, batch_randint, select 8 | from vocab import Vocab 9 | 10 | 11 | class LBLM(LM): 12 | """Length-aware Blank Language Model""" 13 | 14 | def __init__(self, hparams): 15 | super().__init__(hparams) 16 | hparams = self.hparams # a['key'] (if so) -> a.key 17 | 18 | self.lrb = nn.Sequential( 19 | nn.Linear(hparams.d_model * 2, hparams.d_model * 2), 20 | nn.ReLU(), 21 | nn.Linear(hparams.d_model * 2, hparams.max_len) 22 | ) 23 | 24 | def blank_indices(self): 25 | return Vocab.blank_0 + np.arange(self.hparams.max_len) 26 | 27 | def init_canvas(self): 28 | return np.random.choice(self.blank_indices()[1:]) # no blank_0 29 | 30 | def get_loss(self, seq, canvas, blanks, rest, loc, lb): 31 | count = (rest != -1).sum(1) 32 | output = self.forward_encoder(canvas) 33 | output_blank = collect(output, blanks) 34 | 35 | logits_loc = self.loc(output_blank).squeeze(-1) 36 | logits_loc[blanks == -1] = float('-inf') 37 | nll_loc = -F.log_softmax(logits_loc, 1) 38 | loss_loc = collect(nll_loc, loc) 39 | loss_loc = loss_loc.sum(1) / count.float() 40 | output_loc = collect(output_blank, loc) 41 | 42 | logits_word = self.word(output_loc) * self.x_logit_scale 43 | target = collect(seq, rest, Vocab.pad) 44 | loss_word = seq_cross_entropy(logits_word, target, Vocab.pad) 45 | loss_word = loss_word.sum(1) / count.float() 46 | output_word = torch.cat((output_loc, self.enc.src_word_emb(target)), -1) 47 | 48 | logits_lrb = self.lrb(output_word) 49 | 50 | # mask out illegal blank options 51 | length = collect(canvas, blanks) - Vocab.blank_0 52 | length_loc = collect(length, loc, -1) 53 | bs, seq_len = length_loc.shape 54 | ta = length_loc.unsqueeze(-1).repeat(1, 1, self.hparams.max_len) 55 | ra = torch.arange(self.hparams.max_len).unsqueeze(0).unsqueeze(0).repeat(bs, seq_len, 1).to(ta.device) 56 | mask = (ra >= ta) 57 | logits_lrb.masked_fill_(mask, float('-inf')) 58 | 59 | loss_lrb = seq_cross_entropy(logits_lrb, lb, -1) 60 | loss_lrb = loss_lrb.sum(1) / count.float() 61 | 62 | return loss_loc, loss_word, loss_lrb 63 | 64 | def losses(self, seq, n, n_real): 65 | """ 66 | Args: 67 | n: number of BPE tokens 68 | n_real: number of real words (for reporting PPL) 69 | """ 70 | k = batch_randint(0, n - 1) 71 | rank = sample_permutation(seq) 72 | keep = (rank < k.unsqueeze(1)) 73 | canvas, blanks, rest, loc, lb = get_known_length_canvas(seq, keep, n) 74 | loss_loc, loss_word, loss_lrb = self.get_loss(seq, canvas, blanks, rest, loc, lb) 75 | nll_lb = (loss_loc + loss_word + loss_lrb) * n.float() - (n + 1).float().lgamma() 76 | return {'loss': nll_lb.sum() / n_real.sum(), 77 | 'loc': loss_loc.mean(), 78 | 'word': loss_word.mean(), 79 | 'lrb': loss_lrb.mean() 80 | } 81 | 82 | # lower than real perplexity since conditioned on length 83 | def nll_mc(self, seq, n, m): 84 | """ 85 | Compute negative log-likelihood by monte carlo estimate 86 | Args: 87 | m: number of samples to take 88 | 89 | Note: sentences in the batch must have the same length 90 | """ 91 | a = [] 92 | for _ in range(m): 93 | rank = sample_permutation(seq) 94 | logp = 0. 
95 | for k in range(seq.size(1)): 96 | keep = (rank < k) 97 | canvas, blanks, rest, loc, lb = get_known_length_canvas(seq, keep, n) 98 | k_th = (rank == k).nonzero(as_tuple=True)[1] 99 | x, y = (rest == k_th.unsqueeze(1)).nonzero(as_tuple=True) 100 | assert torch.all(x == torch.arange(len(seq), device=seq.device)) 101 | rest, loc, lb = [t[x, y].unsqueeze(1) for t in [rest, loc, lb]] 102 | loss_loc, loss_word, loss_lrb = self.get_loss(seq, canvas, blanks, rest, loc, lb) 103 | logp -= loss_loc + loss_word + loss_lrb 104 | a.append(logp.unsqueeze(1)) 105 | return np.log(m) - (n + 1).float().lgamma() - torch.logsumexp(torch.cat(a, 1), 1) 106 | 107 | def generate(self, seq, decode, device): 108 | seq = torch.LongTensor(seq).to(device) 109 | blanks = [i for i, w in enumerate(seq) if w.item() in self.blank_indices()] 110 | is_fill = [0] * len(seq) 111 | fill = [[]] 112 | full = [seq] 113 | while len(blanks) > 0 and len(seq) <= self.hparams.max_len: 114 | output = self.forward_encoder(seq.unsqueeze(0))[0] 115 | output_blank = output[blanks] 116 | loc = select(self.loc(output_blank).squeeze(-1), decode) 117 | output_loc = output_blank[loc] 118 | 119 | length_previous = seq[blanks[loc]] - Vocab.blank_0 120 | 121 | logits_word = self.word(output_loc) * self.x_logit_scale 122 | logits_word[self.blank_indices()] = float('-inf') # never predict 123 | 124 | # joint word, lrb prediction 125 | lprob_word = F.log_softmax(logits_word, -1) 126 | output_word = torch.cat((output_loc.unsqueeze(0).expand(self.hparams.vocab_size, -1), 127 | self.enc.src_word_emb.weight), -1) 128 | logits_lrb = self.lrb(output_word) 129 | logits_lrb[:, length_previous:] = float('-inf') # mask out illegal blank options 130 | max_blank_len = logits_lrb.shape[1] 131 | lprob_lrb = F.log_softmax(logits_lrb, -1) 132 | lprob_word_lrb = lprob_word.unsqueeze(1) + lprob_lrb 133 | word_lrb = select(lprob_word_lrb.view(-1), decode) 134 | word, lrb = word_lrb // max_blank_len, word_lrb % max_blank_len 135 | 136 | lb = lrb 137 | rb = length_previous - lb - 1 138 | 139 | ins = ([Vocab.blank_0 + lb] if lb else []) + [word] + ([Vocab.blank_0 + rb] if rb else []) 140 | ins = torch.LongTensor(ins).to(device) 141 | pos = blanks[loc] 142 | seq = torch.cat((seq[:pos], ins, seq[pos + 1:])) 143 | blanks = [i for i, w in enumerate(seq) if w.item() in self.blank_indices()] 144 | is_fill = is_fill[:pos] + [1] * len(ins) + is_fill[pos + 1:] 145 | fill.append([id for id, isf in zip(seq, is_fill) if isf]) 146 | full.append(seq) 147 | return fill, full 148 | -------------------------------------------------------------------------------- /models/lm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import pytorch_lightning as pl 4 | 5 | from transformer.Models import Encoder 6 | from optimizer import config_opt_schedule 7 | from vocab import Vocab 8 | 9 | 10 | class LM(pl.LightningModule): 11 | """Language Model Container Class""" 12 | 13 | def __init__(self, hparams): 14 | super().__init__() 15 | self.hparams = hparams 16 | hparams = self.hparams # a['key'] (if so) -> a.key 17 | 18 | self.enc = Encoder( 19 | n_src_vocab=hparams.vocab_size, len_max_seq=hparams.max_len, 20 | d_word_vec=hparams.d_model, d_model=hparams.d_model, 21 | d_inner=hparams.d_inner_hid, d_k=hparams.d_k, d_v=hparams.d_v, 22 | n_layers=hparams.n_layers, n_head=hparams.n_head, 23 | dropout=hparams.dropout) 24 | 25 | self.word = nn.Linear(hparams.d_model, hparams.vocab_size, bias=False) 26 | 
nn.init.xavier_normal_(self.word.weight) 27 | self.x_logit_scale = 1. 28 | if hparams.share_emb_prj_weight: 29 | self.word.weight = self.enc.src_word_emb.weight 30 | self.x_logit_scale = (hparams.d_model ** -0.5) 31 | 32 | self.loc = nn.Linear(hparams.d_model, 1) 33 | 34 | def configure_optimizers(self): 35 | return config_opt_schedule(self.parameters(), self.hparams) 36 | 37 | def training_step(self, batch, batch_idx): 38 | seq, n, n_real = map(lambda x: x.squeeze(0), batch) 39 | losses = self('losses', seq, n, n_real) 40 | return {**losses, 'log': {**losses}} 41 | 42 | def eval_step(self, batch, batch_idx): 43 | seq, n, n_real = map(lambda x: x.squeeze(0), batch) 44 | losses = self('losses', seq, n, n_real) 45 | if self.hparams.n_mc > 0: 46 | nll = self('nll_mc', seq, n, self.hparams.n_mc).sum() 47 | else: 48 | nll = losses['loss'] * n_real.sum() 49 | n_words = n_real.sum() 50 | return {**losses, 'n_words': n_words, 'nll': nll} 51 | 52 | def eval_epoch_end(self, outputs): 53 | # n_words and nll are batch/dataset sum, other losses are mean 54 | losses = {} 55 | for key in outputs[0].keys(): 56 | if key not in ['n_words', 'nll']: 57 | losses[key] = torch.stack([x[key] for x in outputs]).mean() 58 | nll = torch.stack([x['nll'] for x in outputs]).sum() 59 | n_words = torch.stack([x['n_words'] for x in outputs]).sum() 60 | ppl = torch.exp(nll / n_words) 61 | return {**losses, 'nll': nll, 'n_words': n_words, 'ppl': ppl} 62 | 63 | def validation_step(self, batch, batch_idx): 64 | return self.eval_step(batch, batch_idx) 65 | 66 | def validation_epoch_end(self, outputs): 67 | logs = self.eval_epoch_end(outputs) 68 | val_logs = {'val_' + k: v for k, v in logs.items()} 69 | return {'val_loss': logs['loss'], 'log': val_logs} 70 | 71 | def test_step(self, batch, batch_idx): 72 | return self.eval_step(batch, batch_idx) 73 | 74 | def test_epoch_end(self, outputs): 75 | logs = self.eval_epoch_end(outputs) 76 | test_logs = {'test_' + k: v for k, v in logs.items()} 77 | return {'test_loss': logs['loss'], 'log': test_logs} 78 | 79 | def forward_encoder(self, canvas): 80 | pos = (1 + torch.arange(canvas.size(1))).repeat(len(canvas), 1) 81 | pos[canvas == Vocab.pad] = 0 82 | output, *_ = self.enc(canvas, pos.to(canvas.device)) 83 | return output 84 | 85 | def forward(self, action, *args): 86 | if action == 'nll_mc': 87 | return self.nll_mc(*args) 88 | elif action == 'losses': 89 | return self.losses(*args) 90 | raise NotImplementedError 91 | -------------------------------------------------------------------------------- /models/torch_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.utils.cpp_extension import load 4 | 5 | from vocab import Vocab 6 | 7 | get_canvas_cpp = load(name='canvas', sources=['models/get_canvas.cpp']) 8 | 9 | 10 | def select(logits, decode): 11 | if decode == 'sample': 12 | return torch.multinomial(logits.exp(), num_samples=1)[0] 13 | else: 14 | return logits.argmax() 15 | 16 | 17 | def seq_cross_entropy(pred, gold, pad): 18 | gold_shape = gold.shape 19 | pred = pred.view(-1, pred.size(-1)) 20 | gold = gold.view(-1) 21 | loss = F.cross_entropy(pred, gold, ignore_index=pad, reduction='none') 22 | return loss.view(gold_shape) 23 | 24 | 25 | def new_arange(x, *size): 26 | """ 27 | Return a Tensor of `size` filled with a range function on the device of x 28 | If `size` is empty, using the size of the variable x 29 | """ 30 | if len(size) == 0: 31 | size = x.size() 32 | return 
torch.arange(size[-1], device=x.device).expand(*size).contiguous() 33 | 34 | 35 | def batch_randint(start, batch_end): 36 | """ 37 | Sample k from start to end (both inclusive) 38 | Return the same shape as batch_end 39 | """ 40 | return start + (torch.rand_like(batch_end.float()) * (batch_end - start + 1).float()).long() 41 | 42 | 43 | def sample_permutation(seq): 44 | score = torch.rand_like(seq.float()) 45 | score.masked_fill_(seq == Vocab.pad, 1) # always put pads last 46 | score.masked_fill_(seq == Vocab.first, -1) # always keep 47 | score.masked_fill_(seq == Vocab.last, -1) # always keep 48 | indices = score.argsort() 49 | rank = torch.zeros_like(seq) 50 | rank[torch.arange(len(seq)).unsqueeze(1), indices] = \ 51 | torch.arange(seq.size(1), device=seq.device) 52 | return rank 53 | 54 | 55 | def collect(input, index, padding_idx=0): 56 | """ 57 | Perform a batched index select where index is given for each example 58 | Args: 59 | input: tensor of shape (B, T_1, dim_2, dim_3, ...) 60 | index: tensor of shape (B, T_2) 61 | Return: 62 | tensor of shape (B, T_2, dim_2, dim_3, ...) 63 | """ 64 | # Add a column of padding_idx at index 0 (of dim 1) 65 | view = list(input.shape) 66 | view[1] = 1 67 | padding_column = input.new_ones(view) * padding_idx 68 | input = torch.cat([padding_column, input], 1) 69 | 70 | # Expand index to compatible size for gather 71 | for i in range(2, len(input.shape)): 72 | index = index.unsqueeze(i) 73 | 74 | view[0] = -1 75 | view[1] = -1 76 | index = index.expand(view) 77 | return torch.gather(input, 1, index + 1) 78 | 79 | 80 | def to_tensor(x, pad_id, device): 81 | max_len = max([len(xi) for xi in x]) 82 | x_ = [xi + [pad_id] * (max_len - len(xi)) for xi in x] 83 | return torch.tensor(x_).to(device) 84 | 85 | 86 | def get_canvas(seq, keep, n): 87 | """ 88 | Args: 89 | seq: original (batched) sequence of tokens 90 | keep: mask over seq indicating whether to keep each token 91 | n: number of tokens 92 | Return: 93 | canvas: replace consecutive masked tokens in seq by the token 94 | blanks: indices of tokens in canvas 95 | rest: indices of masked tokens in seq, these are the tokens to predict 96 | loc: indices of how rest relates to blanks 97 | lb: whether to create a left blank for predicting each token in rest 98 | rb: whether to create a right blank for predicting each token in rest 99 | (rest, loc, lb, rb have the same shape) 100 | """ 101 | res = get_canvas_cpp.get_canvas(seq.tolist(), keep.tolist(), n.tolist(), Vocab.blank) 102 | pad = [Vocab.pad, -1, -1, -1, -1, -1] 103 | return [to_tensor(r, p, seq.device) for r, p in zip(res, pad)] 104 | 105 | 106 | def get_known_length_canvas(seq, keep, n): 107 | """ 108 | Return: 109 | canvas: replace consecutive masked tokens in seq by the token 110 | blanks: indices of tokens in canvas 111 | rest: indices of masked tokens in seq, these are the tokens to predict 112 | loc: indices of how rest relates to blanks 113 | lb: length of the new left blank for predicting each token in rest 114 | (rest, loc, lb have the same shape) 115 | """ 116 | res = get_canvas_cpp.get_known_length_canvas(seq.tolist(), keep.tolist(), n.tolist(), Vocab.blank_0) 117 | pad = [Vocab.pad, -1, -1, -1, -1] 118 | return [to_tensor(r, p, seq.device) for r, p in zip(res, pad)] 119 | 120 | 121 | def get_ins_canvas(seq, keep, n): 122 | """ 123 | Return: 124 | canvas: remove masked tokens in seq 125 | rest: indices of masked tokens in seq, these are the tokens to predict 126 | loc: indices of how rest relates to canvas 127 | (rest, loc have the same 
shape) 128 | """ 129 | res = get_canvas_cpp.get_insertion_canvas(seq.tolist(), keep.tolist(), n.tolist()) 130 | pad = [Vocab.pad, -1, -1] 131 | return [to_tensor(r, p, seq.device) for r, p in zip(res, pad)] 132 | -------------------------------------------------------------------------------- /optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def config_opt_schedule(params, args): 5 | optimizer = torch.optim.Adam( 6 | params, 7 | betas=eval(args.adam_betas), 8 | eps=args.adam_eps, 9 | weight_decay=args.weight_decay, 10 | lr=args.lr 11 | ) 12 | 13 | if args.lr_schedule == 'fixed': 14 | return optimizer 15 | 16 | elif args.lr_schedule == 'triangular': 17 | scheduler = torch.optim.lr_scheduler.CyclicLR( 18 | optimizer, 19 | base_lr=0, 20 | max_lr=args.lr, 21 | step_size_up=args.warmup_steps, 22 | step_size_down=args.descend_steps, 23 | cycle_momentum=False, 24 | ) 25 | return [optimizer], [{'scheduler': scheduler, 'interval': 'step'}] 26 | 27 | else: 28 | raise ValueError('Unknown lr schedule ' + args.lr_schedule) 29 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from tqdm import tqdm 4 | import torch 5 | import pytorch_lightning as pl 6 | 7 | from vocab import Vocab 8 | from utils import load_data, load_sent, load_model, makedir, write 9 | from dataset import get_eval_dataloader 10 | 11 | 12 | def main(args): 13 | pl.seed_everything(args.seed) 14 | 15 | model = load_model(args.checkpoint).to(device) 16 | model.eval() 17 | vocab = Vocab(os.path.join(model.hparams.root_dir, 'vocab.txt')) 18 | 19 | if args.eval: 20 | data = load_data(args.eval, model.hparams.add_eos, model.hparams.cat_sent, model.hparams.max_len) 21 | dl = get_eval_dataloader( 22 | data, vocab, args.max_tok, 23 | data_workers=args.data_workers, 24 | model_type=model.hparams.model_type) 25 | trainer = pl.Trainer( 26 | gpus=args.gpus, 27 | amp_level=args.fp16_opt_level, 28 | precision=16 if args.fp16 else 32, 29 | default_root_dir='testing_logs') 30 | model.hparams.n_mc = args.n_mc 31 | trainer.test(model, test_dataloaders=dl) 32 | 33 | if args.output: 34 | output = os.path.join(os.path.dirname(os.path.dirname(args.checkpoint)), 'outputs/', args.output) 35 | makedir(output) 36 | 37 | if args.sample: 38 | with open(output, 'w') as f: 39 | for i in tqdm(range(args.sample)): 40 | if model.hparams.model_type == 'inst': 41 | _, full = model.generate([], [0], args.decode, device) 42 | else: 43 | _, full = model.generate([model.init_canvas()], args.decode, device) 44 | 45 | full = [[vocab.idx2word[id] for id in ids] for ids in full] 46 | write(f, full, args.write_mid) 47 | 48 | if args.fill: 49 | sents = load_sent(args.fill, model.hparams.add_eos) 50 | sents = [[vocab.word_to_idx(w) for w in s] for s in sents] 51 | with open(output + '.fill', 'w') as f_fill: 52 | with open(output + '.full', 'w') as f_full: 53 | for s in tqdm(sents): 54 | if model.hparams.model_type == 'inst': 55 | seq, blanks = [], [] 56 | for w in s: 57 | if w == vocab.blank: 58 | blanks.append(len(seq)) 59 | else: 60 | seq.append(w) 61 | if args.anywhere: 62 | blanks = list(range(len(seq) + 1)) 63 | fill, full = model.generate(seq, blanks, args.decode, device, 64 | args.force_insert, args.prioritize_unfilled) 65 | else: 66 | fill, full = model.generate(s, args.decode, device) 67 | 68 | fill = [[vocab.idx2word[id] for id in ids] for 
ids in fill] 69 | full = [[vocab.idx2word[id] for id in ids] for ids in full] 70 | write(f_fill, fill, args.write_mid) 71 | write(f_full, full, args.write_mid) 72 | 73 | 74 | if __name__ == '__main__': 75 | parser = argparse.ArgumentParser() 76 | 77 | parser.add_argument('--checkpoint', required=True, 78 | help='path to checkpoint') 79 | 80 | parser.add_argument('--eval', default='', 81 | help='data file to evaluate') 82 | parser.add_argument('--n_mc', type=int, default=10, 83 | help='num of samples for monte carlo estimate of ppl') 84 | parser.add_argument('--max_tok', type=int, default=40000, 85 | help='max number of tokens per batch') 86 | 87 | parser.add_argument('--output', default='', 88 | help='output file') 89 | parser.add_argument('--sample', type=int, default=0, 90 | help='num of sentences to generate') 91 | parser.add_argument('--fill', default='', 92 | help='input file to fill') 93 | parser.add_argument('--decode', default='greedy', 94 | choices=['greedy', 'sample'], 95 | help='greedy decoding or sampling') 96 | parser.add_argument('--write_mid', action='store_true', 97 | help='write intermediate partial sentences') 98 | 99 | # Specific to InsT 100 | parser.add_argument('--anywhere', action='store_true', 101 | help='fill in anywhere, not only blanks') 102 | parser.add_argument('--force_insert', action='store_true', 103 | help='disable termination unless all slots are filled') 104 | parser.add_argument('--prioritize_unfilled', action='store_true', 105 | help='prioritize unfilled slots if any') 106 | 107 | parser.add_argument('--seed', type=int, default=1111, 108 | help='random seed') 109 | parser.add_argument('--data_workers', type=int, default=8, 110 | help='data workers') 111 | parser.add_argument('--no_cuda', action='store_true', 112 | help='disable CUDA') 113 | parser.add_argument('--fp16', action='store_true', 114 | help='whether to use 16-bit (mixed) precision ' 115 | '(through NVIDIA apex) instead of 32-bit') 116 | parser.add_argument('--fp16_opt_level', default='O1', 117 | help="for fp16: Apex AMP optimization level selected " 118 | "in ['O0', 'O1', 'O2', and 'O3']. 
see details at " 119 | "https://nvidia.github.io/apex/amp.html") 120 | 121 | args = parser.parse_args() 122 | 123 | cuda = not args.no_cuda and torch.cuda.is_available() 124 | device = torch.device('cuda' if cuda else 'cpu') 125 | args.gpus = 1 if cuda else 0 126 | 127 | main(args) 128 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | 5 | import pytorch_lightning as pl 6 | from pytorch_lightning.callbacks import LearningRateMonitor 7 | 8 | from models import get_model_class 9 | from vocab import Vocab 10 | from utils import load_data 11 | from dataset import get_train_dataloader, get_eval_dataloader 12 | 13 | 14 | def main(args): 15 | pl.seed_everything(args.seed) 16 | 17 | torch.multiprocessing.set_sharing_strategy('file_system') 18 | 19 | args.multigpu = torch.cuda.device_count() > 1 20 | 21 | train_data = load_data(args.train, args.add_eos, args.cat_sent, args.max_len) 22 | valid_data = load_data(args.valid, args.add_eos, args.cat_sent, args.max_len) 23 | 24 | os.makedirs(args.root_dir, exist_ok=True) 25 | 26 | vocab_file = os.path.join(args.root_dir, 'vocab.txt') 27 | if not os.path.isfile(vocab_file): 28 | max_blank_len = args.max_len if args.model_type == 'lblm' else None 29 | Vocab.build(train_data, vocab_file, args.vocab_size, max_blank_len) 30 | vocab = Vocab(vocab_file) 31 | args.vocab_size = vocab.size 32 | 33 | train_dl = get_train_dataloader( 34 | train_data, vocab, args.max_tok, 35 | data_workers=args.data_workers if not args.multigpu else 0, 36 | model_type=args.model_type) 37 | val_dl = get_eval_dataloader( 38 | valid_data, vocab, args.eval_max_tok, 39 | data_workers=args.data_workers if not args.multigpu else 0, 40 | model_type=args.model_type) 41 | 42 | model = get_model_class(args.model_type)(args) 43 | 44 | trainer = pl.Trainer( 45 | accumulate_grad_batches=args.accum_grad, 46 | max_steps=args.max_steps, 47 | callbacks=[LearningRateMonitor()] if args.lr_schedule != 'fixed' else None, 48 | val_check_interval=args.val_check_interval if args.val_check_interval > 0 else 1.0, 49 | gpus=args.gpus, 50 | distributed_backend='ddp' if args.multigpu else None, 51 | amp_level=args.fp16_opt_level, 52 | precision=16 if args.fp16 else 32, 53 | default_root_dir=args.root_dir, 54 | resume_from_checkpoint=args.load_checkpoint 55 | ) 56 | 57 | trainer.fit(model, train_dataloader=train_dl, val_dataloaders=val_dl) 58 | 59 | 60 | if __name__ == '__main__': 61 | parser = argparse.ArgumentParser() 62 | 63 | # Path 64 | parser.add_argument('--train', 65 | help='path to training file') 66 | parser.add_argument('--valid', 67 | help='path to validation file') 68 | parser.add_argument('--root_dir', default='checkpoints', 69 | help='directory to save checkpoints and outputs') 70 | parser.add_argument('--load_checkpoint', default=None, 71 | help='path to load checkpoint if specified') 72 | 73 | # Data 74 | parser.add_argument('--vocab_size', type=int, default=10000, 75 | help='keep N most frequent words in vocabulary') 76 | parser.add_argument('--max_len', type=int, default=512, 77 | help='max sequence length') 78 | parser.add_argument('--cat_sent', action='store_true', 79 | help='concat sentences and chunk into size of max_len') 80 | parser.add_argument('--add_eos', action='store_true', 81 | help='add <eos> at the end of each sentence') 82 | 83 | # Model 84 | parser.add_argument('--model_type', default='blm', 85 |
choices=['blm', 'inst', 'lblm'], 86 | help='model type: blm, inst or lblm') 87 | 88 | parser.add_argument('--d_model', type=int, default=512, 89 | help='transformer dimension d_model') 90 | parser.add_argument('--d_inner_hid', type=int, default=2048, 91 | help='transformer dimension d_inner_hid') 92 | parser.add_argument('--d_k', type=int, default=64, 93 | help='transformer dimension d_k') 94 | parser.add_argument('--d_v', type=int, default=64, 95 | help='transformer dimension d_v') 96 | parser.add_argument('--n_head', type=int, default=8, 97 | help='number of attention heads') 98 | parser.add_argument('--n_layers', type=int, default=6, 99 | help='number of layers') 100 | parser.add_argument('--share_emb_prj_weight', action='store_true', 101 | help='share word embedding and projection weights') 102 | 103 | # Optimization 104 | parser.add_argument('--max_tok', type=int, default=10000, 105 | help='max number of tokens per batch') 106 | parser.add_argument('--accum_grad', type=int, default=1, 107 | help='accumulate gradients across N batches.') 108 | 109 | parser.add_argument('--adam_betas', default='(0.9, 0.999)', 110 | help='adam betas') 111 | parser.add_argument('--adam_eps', type=float, default=1e-8, 112 | help='adam eps') 113 | parser.add_argument('--weight_decay', type=float, default=1e-5, 114 | help='weight decay') 115 | parser.add_argument('--dropout', type=float, default=0.3, 116 | help='dropout probability (0 = no dropout)') 117 | 118 | parser.add_argument('--lr_schedule', default='fixed', 119 | choices=['fixed', 'triangular'], 120 | help='learning rate schedule') 121 | parser.add_argument('--lr', type=float, default=0.0001, 122 | help='learning rate') 123 | parser.add_argument('--warmup_steps', type=int, default=4000, 124 | help='number of warmup steps (triangular)') 125 | parser.add_argument('--descend_steps', type=int, default=300000, 126 | help='number of descending steps (triangular)') 127 | parser.add_argument('--max_steps', type=int, default=500000, 128 | help='number of training steps') 129 | 130 | # Validation 131 | parser.add_argument('--eval_max_tok', type=int, default=40000, 132 | help='max number of tokens per batch for evaluation') 133 | parser.add_argument('--val_check_interval', type=int, default=0, 134 | help='check validation set every N training batches' 135 | '(0 means checking once an epoch)') 136 | parser.add_argument('--n_mc', type=int, default=1, 137 | help='num of samples for Monte Carlo estimate of ppl') 138 | 139 | # Others 140 | parser.add_argument('--seed', type=int, default=1111, 141 | help='random seed') 142 | parser.add_argument('--data_workers', type=int, default=8, 143 | help='data workers') 144 | parser.add_argument('--gpus', type=int, default=-1, 145 | help='number of gpus to train on (-1 means all gpus)') 146 | parser.add_argument('--fp16', action='store_true', 147 | help='whether to use 16-bit (mixed) precision ' 148 | '(through NVIDIA apex) instead of 32-bit') 149 | parser.add_argument('--fp16_opt_level', default='O1', 150 | help="for fp16: Apex AMP optimization level selected " 151 | "in ['O0', 'O1', 'O2', and 'O3']. see details at " 152 | "https://nvidia.github.io/apex/amp.html") 153 | 154 | args = parser.parse_args() 155 | 156 | main(args) 157 | -------------------------------------------------------------------------------- /transformer/Beam.py: -------------------------------------------------------------------------------- 1 | """ Manage beam search info structure. 2 | 3 | Heavily borrowed from OpenNMT-py. 
4 | For code in OpenNMT-py, please check the following link: 5 | https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/Beam.py 6 | """ 7 | 8 | import torch 9 | import numpy as np 10 | import transformer.Constants as Constants 11 | 12 | class Beam(): 13 | ''' Beam search ''' 14 | 15 | def __init__(self, size, device=False): 16 | 17 | self.size = size 18 | self._done = False 19 | 20 | # The score for each translation on the beam. 21 | self.scores = torch.zeros((size,), dtype=torch.float, device=device) 22 | self.all_scores = [] 23 | 24 | # The backpointers at each time-step. 25 | self.prev_ks = [] 26 | 27 | # The outputs at each time-step. 28 | self.next_ys = [torch.full((size,), Constants.PAD, dtype=torch.long, device=device)] 29 | self.next_ys[0][0] = Constants.BOS 30 | 31 | def get_current_state(self): 32 | "Get the outputs for the current timestep." 33 | return self.get_tentative_hypothesis() 34 | 35 | def get_current_origin(self): 36 | "Get the backpointers for the current timestep." 37 | return self.prev_ks[-1] 38 | 39 | @property 40 | def done(self): 41 | return self._done 42 | 43 | def advance(self, word_prob): 44 | "Update beam status and check if finished or not." 45 | num_words = word_prob.size(1) 46 | 47 | # Sum the previous scores. 48 | if len(self.prev_ks) > 0: 49 | beam_lk = word_prob + self.scores.unsqueeze(1).expand_as(word_prob) 50 | else: 51 | beam_lk = word_prob[0] 52 | 53 | flat_beam_lk = beam_lk.view(-1) 54 | 55 | best_scores, best_scores_id = flat_beam_lk.topk(self.size, 0, True, True) # 1st sort 56 | best_scores, best_scores_id = flat_beam_lk.topk(self.size, 0, True, True) # 2nd sort 57 | 58 | self.all_scores.append(self.scores) 59 | self.scores = best_scores 60 | 61 | # bestScoresId is flattened as a (beam x word) array, 62 | # so we need to calculate which word and beam each score came from 63 | prev_k = best_scores_id / num_words 64 | self.prev_ks.append(prev_k) 65 | self.next_ys.append(best_scores_id - prev_k * num_words) 66 | 67 | # End condition is when top-of-beam is EOS. 68 | if self.next_ys[-1][0].item() == Constants.EOS: 69 | self._done = True 70 | self.all_scores.append(self.scores) 71 | 72 | return self._done 73 | 74 | def sort_scores(self): 75 | "Sort the scores." 76 | return torch.sort(self.scores, 0, True) 77 | 78 | def get_the_best_score_and_idx(self): 79 | "Get the score of the best in the beam." 80 | scores, ids = self.sort_scores() 81 | return scores[1], ids[1] 82 | 83 | def get_tentative_hypothesis(self): 84 | "Get the decoded sequence for the current timestep." 85 | 86 | if len(self.next_ys) == 1: 87 | dec_seq = self.next_ys[0].unsqueeze(1) 88 | else: 89 | _, keys = self.sort_scores() 90 | hyps = [self.get_hypothesis(k) for k in keys] 91 | hyps = [[Constants.BOS] + h for h in hyps] 92 | dec_seq = torch.LongTensor(hyps) 93 | 94 | return dec_seq 95 | 96 | def get_hypothesis(self, k): 97 | """ Walk back to construct the full hypothesis. 
""" 98 | hyp = [] 99 | for j in range(len(self.prev_ks) - 1, -1, -1): 100 | hyp.append(self.next_ys[j+1][k]) 101 | k = self.prev_ks[j][k] 102 | 103 | return list(map(lambda x: x.item(), hyp[::-1])) 104 | -------------------------------------------------------------------------------- /transformer/Constants.py: -------------------------------------------------------------------------------- 1 | 2 | PAD = 0 3 | UNK = 1 4 | BOS = 2 5 | EOS = 3 6 | 7 | #PAD_WORD = '' 8 | #UNK_WORD = '' 9 | #BOS_WORD = '' 10 | #EOS_WORD = '' 11 | -------------------------------------------------------------------------------- /transformer/Layers.py: -------------------------------------------------------------------------------- 1 | ''' Define the Layers ''' 2 | import torch.nn as nn 3 | from transformer.SubLayers import MultiHeadAttention, PositionwiseFeedForward 4 | 5 | __author__ = "Yu-Hsiang Huang" 6 | 7 | 8 | class EncoderLayer(nn.Module): 9 | ''' Compose with two layers ''' 10 | 11 | def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1): 12 | super(EncoderLayer, self).__init__() 13 | self.slf_attn = MultiHeadAttention( 14 | n_head, d_model, d_k, d_v, dropout=dropout) 15 | self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout) 16 | 17 | def forward(self, enc_input, non_pad_mask=None, slf_attn_mask=None): 18 | enc_output, enc_slf_attn = self.slf_attn( 19 | enc_input, enc_input, enc_input, mask=slf_attn_mask) 20 | enc_output *= non_pad_mask 21 | 22 | enc_output = self.pos_ffn(enc_output) 23 | enc_output *= non_pad_mask 24 | 25 | return enc_output, enc_slf_attn 26 | 27 | 28 | class DecoderLayer(nn.Module): 29 | ''' Compose with three layers ''' 30 | 31 | def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1): 32 | super(DecoderLayer, self).__init__() 33 | self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout) 34 | self.enc_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout) 35 | self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout) 36 | 37 | def forward(self, dec_input, enc_output, non_pad_mask=None, slf_attn_mask=None, dec_enc_attn_mask=None): 38 | dec_output, dec_slf_attn = self.slf_attn( 39 | dec_input, dec_input, dec_input, mask=slf_attn_mask) 40 | dec_output *= non_pad_mask 41 | 42 | dec_output, dec_enc_attn = self.enc_attn( 43 | dec_output, enc_output, enc_output, mask=dec_enc_attn_mask) 44 | dec_output *= non_pad_mask 45 | 46 | dec_output = self.pos_ffn(dec_output) 47 | dec_output *= non_pad_mask 48 | 49 | return dec_output, dec_slf_attn, dec_enc_attn 50 | -------------------------------------------------------------------------------- /transformer/Models.py: -------------------------------------------------------------------------------- 1 | ''' Define the Transformer model ''' 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | import transformer.Constants as Constants 6 | from transformer.Layers import EncoderLayer, DecoderLayer 7 | 8 | __author__ = "Yu-Hsiang Huang" 9 | 10 | def get_non_pad_mask(seq): 11 | assert seq.dim() == 2 12 | return seq.ne(Constants.PAD).type(torch.float).unsqueeze(-1) 13 | 14 | def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): 15 | ''' Sinusoid position encoding table ''' 16 | 17 | def cal_angle(position, hid_idx): 18 | return position / np.power(10000, 2 * (hid_idx // 2) / d_hid) 19 | 20 | def get_posi_angle_vec(position): 21 | return [cal_angle(position, hid_j) for hid_j in range(d_hid)] 22 | 23 | sinusoid_table = 
np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)]) 24 | 25 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 26 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 27 | 28 | if padding_idx is not None: 29 | # zero vector for padding dimension 30 | sinusoid_table[padding_idx] = 0. 31 | 32 | return torch.FloatTensor(sinusoid_table) 33 | 34 | def get_attn_key_pad_mask(seq_k, seq_q): 35 | ''' For masking out the padding part of key sequence. ''' 36 | 37 | # Expand to fit the shape of key query attention matrix. 38 | len_q = seq_q.size(1) 39 | padding_mask = seq_k.eq(Constants.PAD) 40 | padding_mask = padding_mask.unsqueeze(1).expand(-1, len_q, -1) # b x lq x lk 41 | 42 | return padding_mask 43 | 44 | def get_subsequent_mask(seq): 45 | ''' For masking out the subsequent info. ''' 46 | 47 | sz_b, len_s = seq.size() 48 | subsequent_mask = torch.triu( 49 | torch.ones((len_s, len_s), device=seq.device, dtype=torch.uint8), diagonal=1) 50 | subsequent_mask = subsequent_mask.unsqueeze(0).expand(sz_b, -1, -1) # b x ls x ls 51 | 52 | return subsequent_mask 53 | 54 | class Encoder(nn.Module): 55 | ''' A encoder model with self attention mechanism. ''' 56 | 57 | def __init__( 58 | self, 59 | n_src_vocab, len_max_seq, d_word_vec, 60 | n_layers, n_head, d_k, d_v, 61 | d_model, d_inner, dropout=0.1): 62 | 63 | super().__init__() 64 | 65 | n_position = len_max_seq + 1 66 | 67 | self.src_word_emb = nn.Embedding( 68 | n_src_vocab, d_word_vec, padding_idx=Constants.PAD) 69 | 70 | self.position_enc = nn.Embedding.from_pretrained( 71 | get_sinusoid_encoding_table(n_position, d_word_vec, padding_idx=0), 72 | freeze=True) 73 | 74 | self.layer_stack = nn.ModuleList([ 75 | EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout) 76 | for _ in range(n_layers)]) 77 | 78 | def forward(self, src_seq, src_pos, return_attns=False): 79 | 80 | enc_slf_attn_list = [] 81 | 82 | # -- Prepare masks 83 | slf_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=src_seq) 84 | non_pad_mask = get_non_pad_mask(src_seq) 85 | 86 | # -- Forward 87 | enc_output = self.src_word_emb(src_seq) + self.position_enc(src_pos) 88 | 89 | for enc_layer in self.layer_stack: 90 | enc_output, enc_slf_attn = enc_layer( 91 | enc_output, 92 | non_pad_mask=non_pad_mask, 93 | slf_attn_mask=slf_attn_mask) 94 | if return_attns: 95 | enc_slf_attn_list += [enc_slf_attn] 96 | 97 | if return_attns: 98 | return enc_output, enc_slf_attn_list 99 | return enc_output, 100 | 101 | class Decoder(nn.Module): 102 | ''' A decoder model with self attention mechanism. 
''' 103 | 104 | def __init__( 105 | self, 106 | n_tgt_vocab, len_max_seq, d_word_vec, 107 | n_layers, n_head, d_k, d_v, 108 | d_model, d_inner, dropout=0.1): 109 | 110 | super().__init__() 111 | n_position = len_max_seq + 1 112 | 113 | self.tgt_word_emb = nn.Embedding( 114 | n_tgt_vocab, d_word_vec, padding_idx=Constants.PAD) 115 | 116 | self.position_enc = nn.Embedding.from_pretrained( 117 | get_sinusoid_encoding_table(n_position, d_word_vec, padding_idx=0), 118 | freeze=True) 119 | 120 | self.layer_stack = nn.ModuleList([ 121 | DecoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout) 122 | for _ in range(n_layers)]) 123 | 124 | def forward(self, tgt_seq, tgt_pos, src_seq, enc_output, return_attns=False): 125 | 126 | dec_slf_attn_list, dec_enc_attn_list = [], [] 127 | 128 | # -- Prepare masks 129 | non_pad_mask = get_non_pad_mask(tgt_seq) 130 | 131 | slf_attn_mask_subseq = get_subsequent_mask(tgt_seq) 132 | slf_attn_mask_keypad = get_attn_key_pad_mask(seq_k=tgt_seq, seq_q=tgt_seq) 133 | slf_attn_mask = (slf_attn_mask_keypad.type_as(slf_attn_mask_subseq) + slf_attn_mask_subseq).gt(0) 134 | 135 | dec_enc_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=tgt_seq) 136 | 137 | # -- Forward 138 | dec_output = self.tgt_word_emb(tgt_seq) + self.position_enc(tgt_pos) 139 | 140 | for dec_layer in self.layer_stack: 141 | dec_output, dec_slf_attn, dec_enc_attn = dec_layer( 142 | dec_output, enc_output, 143 | non_pad_mask=non_pad_mask, 144 | slf_attn_mask=slf_attn_mask, 145 | dec_enc_attn_mask=dec_enc_attn_mask) 146 | 147 | if return_attns: 148 | dec_slf_attn_list += [dec_slf_attn] 149 | dec_enc_attn_list += [dec_enc_attn] 150 | 151 | if return_attns: 152 | return dec_output, dec_slf_attn_list, dec_enc_attn_list 153 | return dec_output, 154 | 155 | class Transformer(nn.Module): 156 | ''' A sequence to sequence model with attention mechanism. ''' 157 | 158 | def __init__( 159 | self, 160 | n_src_vocab, n_tgt_vocab, len_max_seq, 161 | d_word_vec=512, d_model=512, d_inner=2048, 162 | n_layers=6, n_head=8, d_k=64, d_v=64, dropout=0.1, 163 | tgt_emb_prj_weight_sharing=True, 164 | emb_src_tgt_weight_sharing=True): 165 | 166 | super().__init__() 167 | 168 | self.encoder = Encoder( 169 | n_src_vocab=n_src_vocab, len_max_seq=len_max_seq, 170 | d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner, 171 | n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v, 172 | dropout=dropout) 173 | 174 | self.decoder = Decoder( 175 | n_tgt_vocab=n_tgt_vocab, len_max_seq=len_max_seq, 176 | d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner, 177 | n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v, 178 | dropout=dropout) 179 | 180 | self.tgt_word_prj = nn.Linear(d_model, n_tgt_vocab, bias=False) 181 | nn.init.xavier_normal_(self.tgt_word_prj.weight) 182 | 183 | assert d_model == d_word_vec, \ 184 | 'To facilitate the residual connections, \ 185 | the dimensions of all module outputs shall be the same.' 186 | 187 | if tgt_emb_prj_weight_sharing: 188 | # Share the weight matrix between target word embedding & the final logit dense layer 189 | self.tgt_word_prj.weight = self.decoder.tgt_word_emb.weight 190 | self.x_logit_scale = (d_model ** -0.5) 191 | else: 192 | self.x_logit_scale = 1. 193 | 194 | if emb_src_tgt_weight_sharing: 195 | # Share the weight matrix between source & target word embeddings 196 | assert n_src_vocab == n_tgt_vocab, \ 197 | "To share word embedding table, the vocabulary size of src/tgt shall be the same." 
198 | self.encoder.src_word_emb.weight = self.decoder.tgt_word_emb.weight 199 | 200 | def forward(self, src_seq, src_pos, tgt_seq, tgt_pos): 201 | 202 | tgt_seq, tgt_pos = tgt_seq[:, :-1], tgt_pos[:, :-1] 203 | 204 | enc_output, *_ = self.encoder(src_seq, src_pos) 205 | dec_output, *_ = self.decoder(tgt_seq, tgt_pos, src_seq, enc_output) 206 | seq_logit = self.tgt_word_prj(dec_output) * self.x_logit_scale 207 | 208 | return seq_logit.view(-1, seq_logit.size(2)) 209 | -------------------------------------------------------------------------------- /transformer/Modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | __author__ = "Yu-Hsiang Huang" 6 | 7 | class ScaledDotProductAttention(nn.Module): 8 | ''' Scaled Dot-Product Attention ''' 9 | 10 | def __init__(self, temperature, attn_dropout=0.1): 11 | super().__init__() 12 | self.temperature = temperature 13 | self.dropout = nn.Dropout(attn_dropout) 14 | self.softmax = nn.Softmax(dim=2) 15 | 16 | def forward(self, q, k, v, mask=None): 17 | 18 | attn = torch.bmm(q, k.transpose(1, 2)) 19 | attn = attn / self.temperature 20 | 21 | if mask is not None: 22 | attn = attn.masked_fill(mask, -np.inf) 23 | 24 | attn = self.softmax(attn) 25 | attn = self.dropout(attn) 26 | output = torch.bmm(attn, v) 27 | 28 | return output, attn 29 | -------------------------------------------------------------------------------- /transformer/Optim.py: -------------------------------------------------------------------------------- 1 | '''A wrapper class for optimizer ''' 2 | 3 | class LRScheduler(object): 4 | 5 | def __init__(self, optimizer, lr): 6 | self._optimizer = optimizer 7 | self.lr = lr 8 | self.set_lr() 9 | 10 | def step(self): 11 | self._optimizer.step() 12 | 13 | def zero_grad(self): 14 | self._optimizer.zero_grad() 15 | 16 | def set_lr(self): 17 | for param_group in self._optimizer.param_groups: 18 | param_group['lr'] = self.lr 19 | 20 | 21 | class InverseSqrtScheduler(LRScheduler): 22 | 23 | def __init__(self, optimizer, peak_lr, warmup_steps): 24 | super().__init__(optimizer, 0) 25 | 26 | self.warmup_steps = warmup_steps 27 | self.current_step = 0 28 | # linearly warmup for the first warmup_steps 29 | self.warmup_factor = peak_lr / warmup_steps 30 | # then, decay prop. 
to the inverse square root of the step number 31 | self.decay_factor = peak_lr * warmup_steps**0.5 32 | 33 | def step(self): 34 | self._update_learning_rate() 35 | super().step() 36 | 37 | def _update_learning_rate(self): 38 | self.current_step += 1 39 | if self.current_step < self.warmup_steps: 40 | self.lr = self.warmup_factor * self.current_step 41 | else: 42 | self.lr = self.decay_factor * self.current_step**-0.5 43 | self.set_lr() 44 | 45 | 46 | class LinearDecayScheduler(LRScheduler): 47 | 48 | def __init__(self, optimizer, peak_lr, warmup_steps, total_steps): 49 | super().__init__(optimizer, 0) 50 | 51 | self.warmup_steps = warmup_steps 52 | self.total_steps = total_steps 53 | self.current_step = 0 54 | # linearly warmup for the first warmup_steps 55 | self.warmup_factor = peak_lr / warmup_steps 56 | # then, linearly decay to 0 57 | self.decay_factor = peak_lr / (total_steps - warmup_steps) 58 | 59 | def step(self): 60 | self._update_learning_rate() 61 | super().step() 62 | 63 | def _update_learning_rate(self): 64 | self.current_step += 1 65 | if self.current_step < self.warmup_steps: 66 | self.lr = self.warmup_factor * self.current_step 67 | elif self.current_step < self.total_steps: 68 | self.lr = self.decay_factor * (self.total_steps - self.current_step) 69 | else: 70 | self.lr = 0 71 | self.set_lr() 72 | -------------------------------------------------------------------------------- /transformer/SubLayers.py: -------------------------------------------------------------------------------- 1 | ''' Define the sublayers in encoder/decoder layer ''' 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from transformer.Modules import ScaledDotProductAttention 6 | 7 | __author__ = "Yu-Hsiang Huang" 8 | 9 | class MultiHeadAttention(nn.Module): 10 | ''' Multi-Head Attention module ''' 11 | 12 | def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1): 13 | super().__init__() 14 | 15 | self.n_head = n_head 16 | self.d_k = d_k 17 | self.d_v = d_v 18 | 19 | self.w_qs = nn.Linear(d_model, n_head * d_k) 20 | self.w_ks = nn.Linear(d_model, n_head * d_k) 21 | self.w_vs = nn.Linear(d_model, n_head * d_v) 22 | nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k))) 23 | nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k))) 24 | nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v))) 25 | 26 | self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5)) 27 | self.layer_norm = nn.LayerNorm(d_model) 28 | 29 | self.fc = nn.Linear(n_head * d_v, d_model) 30 | nn.init.xavier_normal_(self.fc.weight) 31 | 32 | self.dropout = nn.Dropout(dropout) 33 | 34 | 35 | def forward(self, q, k, v, mask=None): 36 | 37 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head 38 | 39 | sz_b, len_q, _ = q.size() 40 | sz_b, len_k, _ = k.size() 41 | sz_b, len_v, _ = v.size() 42 | 43 | residual = q 44 | 45 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) 46 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) 47 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v) 48 | 49 | q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk 50 | k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk 51 | v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv 52 | 53 | mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x .. 
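        # q, k, v are now folded to (n_head * batch) x len x dim and the padding mask has been
        # repeated once per head, so a single ScaledDotProductAttention call scores all heads at once.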
54 | output, attn = self.attention(q, k, v, mask=mask) 55 | 56 | output = output.view(n_head, sz_b, len_q, d_v) 57 | output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv) 58 | 59 | output = self.dropout(self.fc(output)) 60 | output = self.layer_norm(output + residual) 61 | 62 | return output, attn 63 | 64 | class PositionwiseFeedForward(nn.Module): 65 | ''' A two-feed-forward-layer module ''' 66 | 67 | def __init__(self, d_in, d_hid, dropout=0.1): 68 | super().__init__() 69 | self.w_1 = nn.Conv1d(d_in, d_hid, 1) # position-wise 70 | self.w_2 = nn.Conv1d(d_hid, d_in, 1) # position-wise 71 | self.layer_norm = nn.LayerNorm(d_in) 72 | self.dropout = nn.Dropout(dropout) 73 | 74 | def forward(self, x): 75 | residual = x 76 | output = x.transpose(1, 2) 77 | output = self.w_2(F.relu(self.w_1(output))) 78 | output = output.transpose(1, 2) 79 | output = self.dropout(output) 80 | output = self.layer_norm(output + residual) 81 | return output 82 | -------------------------------------------------------------------------------- /transformer/Translator.py: -------------------------------------------------------------------------------- 1 | ''' This module will handle the text generation with beam search. ''' 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from transformer.Models import Transformer 8 | from transformer.Beam import Beam 9 | 10 | class Translator(object): 11 | ''' Load with trained model and handle the beam search ''' 12 | 13 | def __init__(self, opt): 14 | self.opt = opt 15 | self.device = torch.device('cuda' if opt.cuda else 'cpu') 16 | 17 | checkpoint = torch.load(opt.model) 18 | model_opt = checkpoint['settings'] 19 | self.model_opt = model_opt 20 | 21 | model = Transformer( 22 | model_opt.src_vocab_size, 23 | model_opt.tgt_vocab_size, 24 | model_opt.max_token_seq_len, 25 | tgt_emb_prj_weight_sharing=model_opt.proj_share_weight, 26 | emb_src_tgt_weight_sharing=model_opt.embs_share_weight, 27 | d_k=model_opt.d_k, 28 | d_v=model_opt.d_v, 29 | d_model=model_opt.d_model, 30 | d_word_vec=model_opt.d_word_vec, 31 | d_inner=model_opt.d_inner_hid, 32 | n_layers=model_opt.n_layers, 33 | n_head=model_opt.n_head, 34 | dropout=model_opt.dropout) 35 | 36 | model.load_state_dict(checkpoint['model']) 37 | print('[Info] Trained model state loaded.') 38 | 39 | model.word_prob_prj = nn.LogSoftmax(dim=1) 40 | 41 | model = model.to(self.device) 42 | 43 | self.model = model 44 | self.model.eval() 45 | 46 | def translate_batch(self, src_seq, src_pos): 47 | ''' Translation work in one batch ''' 48 | 49 | def get_inst_idx_to_tensor_position_map(inst_idx_list): 50 | ''' Indicate the position of an instance in a tensor. ''' 51 | return {inst_idx: tensor_position for tensor_position, inst_idx in enumerate(inst_idx_list)} 52 | 53 | def collect_active_part(beamed_tensor, curr_active_inst_idx, n_prev_active_inst, n_bm): 54 | ''' Collect tensor parts associated to active instances. 
''' 55 | 56 | _, *d_hs = beamed_tensor.size() 57 | n_curr_active_inst = len(curr_active_inst_idx) 58 | new_shape = (n_curr_active_inst * n_bm, *d_hs) 59 | 60 | beamed_tensor = beamed_tensor.view(n_prev_active_inst, -1) 61 | beamed_tensor = beamed_tensor.index_select(0, curr_active_inst_idx) 62 | beamed_tensor = beamed_tensor.view(*new_shape) 63 | 64 | return beamed_tensor 65 | 66 | def collate_active_info( 67 | src_seq, src_enc, inst_idx_to_position_map, active_inst_idx_list): 68 | # Sentences which are still active are collected, 69 | # so the decoder will not run on completed sentences. 70 | n_prev_active_inst = len(inst_idx_to_position_map) 71 | active_inst_idx = [inst_idx_to_position_map[k] for k in active_inst_idx_list] 72 | active_inst_idx = torch.LongTensor(active_inst_idx).to(self.device) 73 | 74 | active_src_seq = collect_active_part(src_seq, active_inst_idx, n_prev_active_inst, n_bm) 75 | active_src_enc = collect_active_part(src_enc, active_inst_idx, n_prev_active_inst, n_bm) 76 | active_inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list) 77 | 78 | return active_src_seq, active_src_enc, active_inst_idx_to_position_map 79 | 80 | def beam_decode_step( 81 | inst_dec_beams, len_dec_seq, src_seq, enc_output, inst_idx_to_position_map, n_bm): 82 | ''' Decode and update beam status, and then return active beam idx ''' 83 | 84 | def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq): 85 | dec_partial_seq = [b.get_current_state() for b in inst_dec_beams if not b.done] 86 | dec_partial_seq = torch.stack(dec_partial_seq).to(self.device) 87 | dec_partial_seq = dec_partial_seq.view(-1, len_dec_seq) 88 | return dec_partial_seq 89 | 90 | def prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm): 91 | dec_partial_pos = torch.arange(1, len_dec_seq + 1, dtype=torch.long, device=self.device) 92 | dec_partial_pos = dec_partial_pos.unsqueeze(0).repeat(n_active_inst * n_bm, 1) 93 | return dec_partial_pos 94 | 95 | def predict_word(dec_seq, dec_pos, src_seq, enc_output, n_active_inst, n_bm): 96 | dec_output, *_ = self.model.decoder(dec_seq, dec_pos, src_seq, enc_output) 97 | dec_output = dec_output[:, -1, :] # Pick the last step: (bh * bm) * d_h 98 | word_prob = F.log_softmax(self.model.tgt_word_prj(dec_output), dim=1) 99 | word_prob = word_prob.view(n_active_inst, n_bm, -1) 100 | 101 | return word_prob 102 | 103 | def collect_active_inst_idx_list(inst_beams, word_prob, inst_idx_to_position_map): 104 | active_inst_idx_list = [] 105 | for inst_idx, inst_position in inst_idx_to_position_map.items(): 106 | is_inst_complete = inst_beams[inst_idx].advance(word_prob[inst_position]) 107 | if not is_inst_complete: 108 | active_inst_idx_list += [inst_idx] 109 | 110 | return active_inst_idx_list 111 | 112 | n_active_inst = len(inst_idx_to_position_map) 113 | 114 | dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq) 115 | dec_pos = prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm) 116 | word_prob = predict_word(dec_seq, dec_pos, src_seq, enc_output, n_active_inst, n_bm) 117 | 118 | # Update the beam with predicted word prob information and collect incomplete instances 119 | active_inst_idx_list = collect_active_inst_idx_list( 120 | inst_dec_beams, word_prob, inst_idx_to_position_map) 121 | 122 | return active_inst_idx_list 123 | 124 | def collect_hypothesis_and_scores(inst_dec_beams, n_best): 125 | all_hyp, all_scores = [], [] 126 | for inst_idx in range(len(inst_dec_beams)): 127 | scores, tail_idxs = inst_dec_beams[inst_idx].sort_scores() 128 | all_scores += 
[scores[:n_best]] 129 | 130 | hyps = [inst_dec_beams[inst_idx].get_hypothesis(i) for i in tail_idxs[:n_best]] 131 | all_hyp += [hyps] 132 | return all_hyp, all_scores 133 | 134 | with torch.no_grad(): 135 | #-- Encode 136 | src_seq, src_pos = src_seq.to(self.device), src_pos.to(self.device) 137 | src_enc, *_ = self.model.encoder(src_seq, src_pos) 138 | 139 | #-- Repeat data for beam search 140 | n_bm = self.opt.beam_size 141 | n_inst, len_s, d_h = src_enc.size() 142 | src_seq = src_seq.repeat(1, n_bm).view(n_inst * n_bm, len_s) 143 | src_enc = src_enc.repeat(1, n_bm, 1).view(n_inst * n_bm, len_s, d_h) 144 | 145 | #-- Prepare beams 146 | inst_dec_beams = [Beam(n_bm, device=self.device) for _ in range(n_inst)] 147 | 148 | #-- Bookkeeping for active or not 149 | active_inst_idx_list = list(range(n_inst)) 150 | inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list) 151 | 152 | #-- Decode 153 | for len_dec_seq in range(1, self.model_opt.max_token_seq_len + 1): 154 | 155 | active_inst_idx_list = beam_decode_step( 156 | inst_dec_beams, len_dec_seq, src_seq, src_enc, inst_idx_to_position_map, n_bm) 157 | 158 | if not active_inst_idx_list: 159 | break # all instances have finished their path to <EOS> 160 | 161 | src_seq, src_enc, inst_idx_to_position_map = collate_active_info( 162 | src_seq, src_enc, inst_idx_to_position_map, active_inst_idx_list) 163 | 164 | batch_hyp, batch_scores = collect_hypothesis_and_scores(inst_dec_beams, self.opt.n_best) 165 | 166 | return batch_hyp, batch_scores 167 | -------------------------------------------------------------------------------- /transformer/__init__.py: -------------------------------------------------------------------------------- 1 | import transformer.Constants 2 | import transformer.Modules 3 | import transformer.Layers 4 | import transformer.SubLayers 5 | import transformer.Models 6 | import transformer.Translator 7 | import transformer.Beam 8 | import transformer.Optim 9 | 10 | __all__ = [ 11 | transformer.Constants, transformer.Modules, transformer.Layers, 12 | transformer.SubLayers, transformer.Models, transformer.Optim, 13 | transformer.Translator, transformer.Beam] 14 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | 4 | from models import get_model_class 5 | 6 | 7 | def strip_eos(sents): 8 | return [sent[:sent.index('<eos>')] if '<eos>' in sent else sent 9 | for sent in sents] 10 | 11 | 12 | def makedir(path): 13 | dir = os.path.dirname(path) 14 | if dir: 15 | os.makedirs(dir, exist_ok=True) 16 | 17 | 18 | def repeat(f, x, n): 19 | for i in range(n): 20 | x = f(x) 21 | return x 22 | 23 | 24 | def get_hparams(checkpoint): 25 | hparams_file = os.path.join(os.path.dirname(os.path.dirname(checkpoint)), 'hparams.yaml') 26 | with open(hparams_file) as stream: 27 | return yaml.safe_load(stream) 28 | 29 | 30 | def load_model(checkpoint): 31 | hparams = get_hparams(checkpoint) 32 | model = get_model_class(hparams['model_type']).load_from_checkpoint(checkpoint, hparams=hparams) 33 | model.hparams.root_dir = repeat(lambda x: os.path.dirname(x), checkpoint, 4) 34 | return model 35 | 36 | 37 | def load_sent(path, add_eos=False): 38 | sents = [] 39 | with open(path) as f: 40 | for line in f: 41 | s = line.split() 42 | if add_eos: 43 | s += ['<eos>'] 44 | sents.append(s) 45 | return sents 46 | 47 | 48 | def load_data(path, add_eos=False, cat_sent=False, max_len=512): 49 | if not add_eos: 50
| print('WARNING: You should always use add_eos to get comparable PPL on ' 51 | 'language modeling tasks') 52 | 53 | sents = load_sent(path, add_eos) 54 | if cat_sent: 55 | if not add_eos: 56 | raise ValueError('Using cat_sent without add_eos') 57 | d = [w for s in sents for w in s] 58 | data = [d[i: i + max_len] for i in range(0, len(d), max_len)] 59 | else: 60 | print('# truncated sentences:', 61 | sum(1 for s in sents if len(s) > max_len)) 62 | data = [s[:max_len] for s in sents] 63 | return data 64 | 65 | 66 | def write(file, sents, write_mid): 67 | sents = strip_eos(sents) 68 | if write_mid: 69 | for s in sents: 70 | file.write(' '.join(s) + '\n') 71 | file.write('\n') 72 | else: 73 | file.write(' '.join(sents[-1]) + '\n') 74 | file.flush() 75 | -------------------------------------------------------------------------------- /vocab.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | 3 | 4 | class Vocab(object): 5 | def __init__(self, path): 6 | self.word2idx = {} 7 | self.idx2word = [] 8 | 9 | with open(path) as f: 10 | for line in f: 11 | w = line.split()[0] 12 | self.word2idx[w] = len(self.word2idx) 13 | self.idx2word.append(w) 14 | self.size = len(self.word2idx) 15 | 16 | pad, unk, first, last, eos, blank, blank_0 = range(7) 17 | 18 | @staticmethod 19 | def build(sents, path, size, max_blank_len=None): 20 | voc = ['<pad>', '<unk>', '<first>', '<last>', '<eos>', '<blank>', '<blank_0>'] 21 | if max_blank_len: 22 | voc += ['<blank_{}>'.format(i) for i in range(1, max_blank_len)] 23 | occ = [0 for _ in voc] 24 | 25 | cnt = Counter([w for s in sents for w in s]) 26 | for i, v in enumerate(voc): 27 | if v in cnt: 28 | occ[i] = cnt[v] 29 | del cnt[v] 30 | for v, o in cnt.most_common(size): 31 | voc.append(v) 32 | occ.append(o) 33 | for v, o in cnt.most_common()[size:]: 34 | occ[Vocab.unk] += o 35 | 36 | with open(path, 'w') as f: 37 | for v, o in zip(voc, occ): 38 | f.write('{}\t{}\n'.format(v, o)) 39 | 40 | def word_to_idx(self, word): 41 | return self.word2idx[word] if word in self.word2idx else Vocab.unk 42 | --------------------------------------------------------------------------------
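An illustrative aside (not part of the repository): the collect() helper shown earlier performs a padding-aware batched index select by prepending a column of padding values along the time dimension and shifting the gather indices by one, so that the -1 padding used for index tensors lands on the padding column instead of real data. The minimal standalone sketch below reproduces that idea; the function name batched_select and the toy tensors are made up for the example.

import torch


def batched_select(input, index, padding_idx=0):
    # Prepend a column filled with padding_idx along dim 1, shifting real positions to 1..T_1.
    pad_col = input.new_full((input.size(0), 1) + tuple(input.shape[2:]), padding_idx)
    shifted = torch.cat([pad_col, input], dim=1)
    # Expand the (B, T_2) index over the trailing dims so it is compatible with gather.
    idx = index
    for _ in range(input.dim() - 2):
        idx = idx.unsqueeze(-1)
    idx = idx.expand(-1, -1, *input.shape[2:])
    # +1 compensates for the prepended column, so an index of -1 gathers the padding entry.
    return torch.gather(shifted, 1, idx + 1)


x = torch.arange(12).view(2, 3, 2)   # (B=2, T_1=3, dim=2)
i = torch.tensor([[2, -1], [0, 1]])  # -1 plays the role of a padded index
print(batched_select(x, i))          # [[[4, 5], [0, 0]], [[6, 7], [8, 9]]]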