├── .gitignore
├── LICENSE
├── README.md
├── app.py
├── dataset.py
├── download_data.sh
├── get_model.sh
├── img
│   └── model.png
├── models
│   ├── __init__.py
│   ├── blm.py
│   ├── get_canvas.cpp
│   ├── inst.py
│   ├── lblm.py
│   ├── lm.py
│   └── torch_utils.py
├── optimizer.py
├── test.py
├── train.py
├── transformer
│   ├── Beam.py
│   ├── Constants.py
│   ├── Layers.py
│   ├── Models.py
│   ├── Modules.py
│   ├── Optim.py
│   ├── SubLayers.py
│   ├── Translator.py
│   └── __init__.py
├── utils.py
└── vocab.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | rsync_exclude.txt
3 | figs/
4 |
5 | # Byte-compiled / optimized / DLL files
6 | __pycache__/
7 | *.py[cod]
8 | *$py.class
9 |
10 | # C extensions
11 | *.so
12 |
13 | # Distribution / packaging
14 | .Python
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | nosetests.xml
49 | coverage.xml
50 | *.cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 |
63 | # Flask stuff:
64 | instance/
65 | .webassets-cache
66 |
67 | # Scrapy stuff:
68 | .scrapy
69 |
70 | # Sphinx documentation
71 | docs/_build/
72 |
73 | # PyBuilder
74 | target/
75 |
76 | # Jupyter Notebook
77 | .ipynb_checkpoints
78 |
79 | # pyenv
80 | .python-version
81 |
82 | # celery beat schedule file
83 | celerybeat-schedule
84 |
85 | # SageMath parsed files
86 | *.sage.py
87 |
88 | # Environments
89 | .env
90 | .venv
91 | env/
92 | venv/
93 | ENV/
94 | env.bak/
95 | venv.bak/
96 |
97 | # Spyder project settings
98 | .spyderproject
99 | .spyproject
100 |
101 | # vim
102 | .vim
103 |
104 | # Rope project settings
105 | .ropeproject
106 |
107 | # mkdocs documentation
108 | /site
109 |
110 | # mypy
111 | .mypy_cache/
112 |
113 | # ddp
114 | __temp_weight_ddp_end.ckpt
115 |
116 | checkpoints/
117 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Blank Language Models
2 |
3 | This repository contains the code for our EMNLP 2020 paper:
4 | [**Blank Language Models**](https://arxiv.org/abs/2002.03079)
5 | Tianxiao Shen\*, Victor Quach\*, Regina Barzilay, and Tommi Jaakkola (\*: Equal contribution)
6 |
7 |
8 |
9 | Given partially specified text with one or more blanks, BLM will fill in the blanks with a variable number of tokens consistent with the context, making it ideal for text editing and rewriting.
10 |
11 | > Input: They also have \___ which \___ .
12 | > Output: They also have ice cream which is really good .
13 |
14 |
15 |
16 |
![model](img/model.png)
17 |
18 |
19 | ## Demo
20 |
21 | We have an online demo built using [streamlit](https://www.streamlit.io/), available [here](http://128.52.131.173:8501).
22 | 
23 | Or try it locally by running:
24 |
25 | ```
26 | streamlit run app.py
27 | ```
28 |
29 |
30 | ## Dependencies
31 |
32 | Our code is based on the [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning) framework.
33 |
34 | It has been tested with PyTorch 1.6.0 and PyTorch Lightning 1.0.7.
35 |
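If needed, the tested versions can be installed with pip, for example:

```
pip install torch==1.6.0 pytorch-lightning==1.0.7
```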
36 |
37 | ## Download Data
38 |
39 | Download the processed Yelp and Yahoo datasets by running:
40 | ```
41 | bash download_data.sh
42 | ```
43 |
44 |
45 | ## Training
46 |
47 | To train a BLM on Yelp negative sentences:
48 | ```
49 | python train.py --train data/yelp/train.0 --valid data/yelp/valid.0 --root_dir checkpoints/yelp/neg/blm/ \
50 | --vocab_size 10000 --max_len 20 --model_type blm --share_emb_prj_weight
51 | ```
52 |
53 | Yelp positive sentences:
54 | ```
55 | python train.py --train data/yelp/train.1 --valid data/yelp/valid.1 --root_dir checkpoints/yelp/pos/blm/ \
56 | --vocab_size 10000 --max_len 20 --model_type blm --share_emb_prj_weight
57 | ```
58 |
59 | Yahoo documents:
60 | ```
61 | python train.py --train data/yahoo/train.txt --valid data/yahoo/valid.txt --root_dir checkpoints/yahoo/blm/ \
62 | --vocab_size 20000 --max_len 205 --model_type blm --share_emb_prj_weight
63 | ```
64 |
65 | Run `python train.py -h` to see all training options.
66 |
67 | You can use Tensorboard to monitor the training progress.
68 |
69 |
70 | ## Testing
71 |
72 | After training, we can evaluate the model's perplexity via a Monte Carlo estimate, and use the model to generate text from scratch or fill in the blanks in the input.
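Roughly, the estimator (`nll_mc` in the model code) samples `--n_mc` generation orders uniformly at random, averages the joint probability of the sentence and the order, and multiplies by the number of orders (n! for an n-token sentence) to estimate p(x).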
73 |
74 | For all of the following, replace `epoch\=???.ckpt` with the checkpoint saved during training.
75 |
76 | - The following command evaluates for Yelp negative sentences:
77 |
78 | ```
79 | python test.py --checkpoint checkpoints/yelp/neg/blm/lightning_logs/version_0/checkpoints/epoch\=???.ckpt \
80 | --eval data/yelp/test.0 --n_mc 10
81 | ```
82 |
83 | - The following command samples from the model trained on Yelp negative sentences:
84 |
85 | ```
86 | python test.py --checkpoint checkpoints/yelp/neg/blm/lightning_logs/version_0/checkpoints/epoch\=???.ckpt \
87 | --sample 1000 --decode sample --output sample.txt
88 | ```
89 |
90 | - The following command uses the model trained on Yelp negative sentences to fill in blanked positive sentences to achieve sentiment transfer:
91 |
92 | ```
93 | python test.py --checkpoint checkpoints/yelp/neg/blm/lightning_logs/version_0/checkpoints/epoch\=???.ckpt \
94 | --fill data/yelp/blank/test.1.blank --output test.1.tsf
95 | ```
96 |
97 | To output the whole generation trajectory, turn on the `--write_mid` option.
98 |
99 | The output file will be stored in `outputs/` within the checkpoint directory.
100 |
101 |
102 | ## Acknowledgements
103 |
104 | We use the Transformer implementation from [jadore801120/attention-is-all-you-need-pytorch](https://github.com/jadore801120/attention-is-all-you-need-pytorch).
105 |
106 |
107 | ## Citation
108 |
109 | If you use our work, please cite:
110 |
111 | ```bibtex
112 | @inproceedings{shen2020blank,
113 | title = "{Blank Language Models}",
114 | author = "Shen, Tianxiao and
115 | Quach, Victor and
116 | Barzilay, Regina and
117 | Jaakkola, Tommi",
118 | booktitle = "Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing",
119 | month = nov,
120 | year = "2020",
121 | address = "Online",
122 | publisher = "Association for Computational Linguistics"
123 | }
124 | ```
125 |
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from tqdm import tqdm
4 |
5 | import torch
6 | import pytorch_lightning as pl
7 | import streamlit as st
8 |
9 | from vocab import Vocab
10 | from utils import load_data, load_sent, load_model, makedir, write
11 | from dataset import get_eval_dataloader
12 |
13 |
14 | st.sidebar.write("## Parameters")
15 |
16 | device = st.sidebar.selectbox("Device",
17 | ["cpu"] + ["cuda:{}".format(i) for i in range(torch.cuda.device_count())],
18 | 0,
19 | lambda key: key if key == "cpu" else "GPU {}".format(key)
20 | )
21 |
22 | st.write('# Blank Language Models: Demo')
23 |
24 | st.write("[View source](http://github.com/varal7/blank_language_model)")
25 |
26 | st.write('## Load model')
27 |
28 | yelp_neg = "checkpoints/yelp/neg/lightning_logs/version_0/checkpoints/model.ckpt"
29 | yelp_pos = "checkpoints/yelp/pos/lightning_logs/version_0/checkpoints/model.ckpt"
30 |
31 | if not os.path.exists(yelp_neg) or not os.path.exists(yelp_pos):
32 | st.write(":warning: Default models not found. Run `get_model.sh` to download models trained on Yelp.")
33 | checkpoint = st.radio("Load checkpoint", ("Custom model", ))
34 |
35 | else:
36 | checkpoint = st.radio("Load checkpoint", ("Yelp positive reviews", "Yelp negative reviews", "Custom model"))
37 |
38 | if checkpoint == "Custom model":
39 | checkpoint_file = st.text_input("Path to `model.ckpt` file", value=yelp_pos)
40 | else:
41 | checkpoint_file = yelp_pos if "Yelp positive" in checkpoint else yelp_neg
42 |
43 | @st.cache
44 | def get_model(checkpoint_file, device):
45 | model = load_model(checkpoint_file).to(device)
46 | model.eval()
47 | vocab = Vocab(os.path.join(model.hparams.root_dir, 'vocab.txt'))
48 | return model, vocab
49 |
50 |
51 | model, vocab = get_model(checkpoint_file, device)
52 |
53 | decode = st.sidebar.radio("Decoding", ("Greedy", "Sample")).lower()
54 |
55 | mode = st.sidebar.radio("Task", ('Infilling', 'Sample'))
56 |
57 |
58 | if mode == "Sample":
59 | _, full = model.generate([model.init_canvas()], decode, device)
60 | full = [[vocab.idx2word[id] for id in ids] for ids in full]
61 | for step in full:
62 | st.write(" ".join(step).replace("<blank>", "\_\_\_"))
63 |
64 | if mode == "Infilling":
65 | st.write('## Load infilling data')
66 | text_input = st.text_input("Blanked input", value="___ place ___ and ___ food ___ .").lower()
67 | s = text_input.replace("___", "").split()
68 | s += ['<eos>'] if model.hparams.add_eos else []
69 | s = [vocab.word_to_idx(w) for w in s]
70 | _, full = model.generate(s, decode, device)
71 | full = [[vocab.idx2word[id] for id in ids] for ids in full]
72 | for step in full:
73 | st.write(" ".join(step).replace("<blank>", "\_\_\_"))
74 |
75 | if st.button("Rerun"):
76 | pass
77 |
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from vocab import Vocab
4 |
5 |
6 | def get_batch(x, vocab, append_at_ends=False):
7 | seq = []
8 | n = [len(s) for s in x]
9 | n_real = []
10 | max_len = max(n)
11 | for s, l in zip(x, n):
12 | # Combine BPE tokens to count the number of words
13 | n_real.append(l - sum(1 for t in s if t.endswith("@@")))
14 |
15 | s_idx = [vocab.word_to_idx(w) for w in s]
16 | if append_at_ends:
17 | s_idx = [Vocab.first] + s_idx + [Vocab.last]
18 | seq.append(s_idx + [Vocab.pad] * (max_len - l))
19 | return torch.LongTensor(seq), torch.LongTensor(n), torch.LongTensor(n_real)
20 |
21 |
22 | def get_batches(data, vocab, max_tok, append_at_ends=False, same_len=False):
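# Sort sentences by length (descending) and pack them greedily so that
# (longest length in a batch) * (batch size) stays within the max_tok budget;
# same_len additionally restricts each batch to sentences of equal length.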
23 | offset = 2 if append_at_ends else 0
24 |
25 | order = range(len(data))
26 | z = sorted(zip(order, data), key=lambda i: len(i[1]), reverse=True)
27 | order, data = zip(*z)
28 |
29 | batches = []
30 | i = 0
31 | while i < len(data):
32 | j = i
33 | while j < len(data) and (len(data[i]) + offset) * (j-i+1) <= max_tok \
34 | and (not same_len or len(data[j]) == len(data[i])):
35 | j += 1
36 | batches.append(get_batch(data[i: j], vocab, append_at_ends))
37 | i = j
38 | return batches, order
39 |
40 |
41 | class LMDataset(torch.utils.data.Dataset):
42 | def __init__(self, batches):
43 | self.batches = batches
44 |
45 | def __getitem__(self, idx):
46 | return self.batches[idx]
47 |
48 | def __len__(self):
49 | return len(self.batches)
50 |
51 |
52 | def get_train_dataloader(train_data, vocab, max_tok, data_workers=8,
53 | model_type=None):
54 | train_batches, _ = get_batches(train_data, vocab, max_tok,
55 | append_at_ends=(model_type == 'inst'))
56 | print("Number of train batches: {}".format(len(train_batches)))
57 | train_ds = LMDataset(train_batches)
58 | return torch.utils.data.DataLoader(train_ds, num_workers=data_workers,
59 | shuffle=True, pin_memory=True)
60 |
61 |
62 | def get_eval_dataloader(val_data, vocab, max_tok, data_workers=8,
63 | model_type=None):
64 | val_batches, _ = get_batches(val_data, vocab, max_tok, same_len=True,
65 | append_at_ends=(model_type == 'inst'))
66 | print("Number of eval batches: {}".format(len(val_batches)))
67 | val_ds = LMDataset(val_batches)
68 | return torch.utils.data.DataLoader(val_ds, num_workers=data_workers,
69 | pin_memory=True)
70 |
--------------------------------------------------------------------------------
/download_data.sh:
--------------------------------------------------------------------------------
1 | mkdir data
2 | cd data
3 |
4 | dir="http://people.csail.mit.edu/tianxiao/data"
5 |
6 | wget $dir/yelp_blm.zip
7 | unzip yelp_blm.zip
8 | rm yelp_blm.zip
9 |
10 | wget $dir/yahoo_blm.zip
11 | unzip yahoo_blm.zip
12 | rm yahoo_blm.zip
13 |
--------------------------------------------------------------------------------
/get_model.sh:
--------------------------------------------------------------------------------
1 | mkdir -p checkpoints/yelp
2 | cd checkpoints/yelp
3 | wget http://128.52.131.173:8000/yelp/neg.tgz -O neg.tgz
4 | wget http://128.52.131.173:8000/yelp/pos.tgz -O pos.tgz
5 | tar -xf neg.tgz
6 | tar -xf pos.tgz
7 |
--------------------------------------------------------------------------------
/img/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Varal7/blank_language_model/0ab36a6a6e5272683ba7b059c16035b0c9d00ef0/img/model.png
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from . inst import InsTLM
2 | from . blm import BLM
3 | from . lblm import LBLM
4 |
5 |
6 | def get_model_class(model_type):
7 | if model_type == 'blm':
8 | return BLM
9 | elif model_type == 'inst':
10 | return InsTLM
11 | elif model_type == 'lblm':
12 | return LBLM
13 | else:
14 | raise ValueError('Unknown model ' + model_type)
15 |
--------------------------------------------------------------------------------
/models/blm.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from . lm import LM
7 | from . torch_utils import get_canvas, sample_permutation, seq_cross_entropy, collect, batch_randint, select
8 | from vocab import Vocab
9 |
10 |
11 | class BLM(LM):
12 | """Blank Language Model"""
13 |
14 | def __init__(self, hparams):
15 | super().__init__(hparams)
16 | hparams = self.hparams # a['key'] (if so) -> a.key
17 |
18 | self.lrb = nn.Sequential(
19 | nn.Linear(hparams.d_model * 2, hparams.d_model * 2),
20 | nn.ReLU(),
21 | nn.Linear(hparams.d_model * 2, 4)
22 | )
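# 4-way output over (left blank, right blank) choices for the inserted word, encoded as lb * 2 + rb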
23 |
24 | def init_canvas(self):
25 | return Vocab.blank
26 |
27 | def get_loss(self, seq, canvas, blanks, rest, loc, lb, rb):
28 | count = (rest != -1).sum(1)
29 | output = self.forward_encoder(canvas)
30 | output_blank = collect(output, blanks)
31 |
32 | logits_loc = self.loc(output_blank).squeeze(-1)
33 | logits_loc[blanks == -1] = float('-inf')
34 | nll_loc = -F.log_softmax(logits_loc, 1)
35 | loss_loc = collect(nll_loc, loc)
36 | loss_loc = loss_loc.sum(1) / count.float()
37 | output_loc = collect(output_blank, loc)
38 |
39 | logits_word = self.word(output_loc) * self.x_logit_scale
40 | target = collect(seq, rest, Vocab.pad)
41 | loss_word = seq_cross_entropy(logits_word, target, Vocab.pad)
42 | loss_word = loss_word.sum(1) / count.float()
43 | output_word = torch.cat((output_loc, self.enc.src_word_emb(target)), -1)
44 |
45 | logits_lrb = self.lrb(output_word)
46 | loss_lrb = seq_cross_entropy(logits_lrb, lb * 2 + rb, -3)
47 | loss_lrb = loss_lrb.sum(1) / count.float()
48 |
49 | return loss_loc, loss_word, loss_lrb
50 |
51 | def losses(self, seq, n, n_real):
52 | """
53 | Args:
54 | n: number of BPE tokens
55 | n_real: number of real words (for reporting PPL)
56 | """
57 | k = batch_randint(0, n - 1)
58 | rank = sample_permutation(seq)
59 | keep = (rank < k.unsqueeze(1))
60 | canvas, blanks, rest, loc, lb, rb = get_canvas(seq, keep, n)
61 | loss_loc, loss_word, loss_lrb = self.get_loss(seq, canvas, blanks, rest, loc, lb, rb)
62 | nll_lb = (loss_loc + loss_word + loss_lrb) * n.float() - (n + 1).float().lgamma()
63 | return {'loss': nll_lb.sum() / n_real.sum(),
64 | 'loc': loss_loc.mean(),
65 | 'word': loss_word.mean(),
66 | 'lrb': loss_lrb.mean()
67 | }
68 |
69 | def nll_mc(self, seq, n, m):
70 | """
71 | Compute negative log-likelihood by monte carlo estimate
72 | Args:
73 | m: number of samples to take
74 |
75 | Note: sentences in the batch must have the same length
76 | """
77 | a = []
78 | for _ in range(m):
79 | rank = sample_permutation(seq)
80 | logp = 0.
81 | for k in range(seq.size(1)):
82 | keep = (rank < k)
83 | canvas, blanks, rest, loc, lb, rb = get_canvas(seq, keep, n)
84 | k_th = (rank == k).nonzero(as_tuple=True)[1]
85 | x, y = (rest == k_th.unsqueeze(1)).nonzero(as_tuple=True)
86 | assert torch.all(x == torch.arange(len(seq), device=seq.device))
87 | rest, loc, lb, rb = [t[x, y].unsqueeze(1) for t in [rest, loc, lb, rb]]
88 | loss_loc, loss_word, loss_lrb = self.get_loss(seq, canvas, blanks, rest, loc, lb, rb)
89 | logp -= loss_loc + loss_word + loss_lrb
90 | a.append(logp.unsqueeze(1))
91 | return np.log(m) - (n + 1).float().lgamma() - torch.logsumexp(torch.cat(a, 1), 1)
92 |
93 | def generate(self, seq, decode, device):
94 | seq = torch.LongTensor(seq).to(device)
95 | blanks = [i for i, w in enumerate(seq) if w == Vocab.blank]
96 | is_fill = [0] * len(seq)
97 | fill = [[]]
98 | full = [seq]
99 | while len(blanks) > 0 and len(seq) <= self.hparams.max_len:
100 | output = self.forward_encoder(seq.unsqueeze(0))[0]
101 | output_blank = output[blanks]
102 | loc = select(self.loc(output_blank).squeeze(-1), decode)
103 | output_loc = output_blank[loc]
104 |
105 | logits_word = self.word(output_loc) * self.x_logit_scale
106 | logits_word[Vocab.blank] = float('-inf') # never predict <blank>
107 |
108 | # joint word, lrb prediction
109 | lprob_word = F.log_softmax(logits_word, -1)
110 | output_word = torch.cat((output_loc.unsqueeze(0).expand(self.hparams.vocab_size, -1),
111 | self.enc.src_word_emb.weight), -1)
112 | logits_lrb = self.lrb(output_word)
113 | lprob_lrb = F.log_softmax(logits_lrb, -1)
114 | lprob_word_lrb = lprob_word.unsqueeze(1) + lprob_lrb
115 | word_lrb = select(lprob_word_lrb.view(-1), decode)
116 | word, lrb = word_lrb // 4, word_lrb % 4
117 |
118 | # predict word first and then lrb
119 | # word = select(logits_word, decode)
120 | # output_word = torch.cat((output_loc, self.enc.src_word_emb(word)), dim=-1)
121 | # lrb = select(self.lrb(output_word), decode)
122 |
123 | lb, rb = lrb // 2, lrb % 2
124 | ins = ([Vocab.blank] if lb else []) + [word] + ([Vocab.blank] if rb else [])
125 | ins = torch.LongTensor(ins).to(device)
126 | pos = blanks[loc]
127 | seq = torch.cat((seq[:pos], ins, seq[pos + 1:]))
128 | blanks = [i for i, w in enumerate(seq) if w == Vocab.blank]
129 | is_fill = is_fill[:pos] + [1] * len(ins) + is_fill[pos + 1:]
130 | fill.append([id for id, isf in zip(seq, is_fill) if isf])
131 | full.append(seq)
132 | return fill, full
133 |
--------------------------------------------------------------------------------
/models/get_canvas.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <vector>
3 | #include <tuple>
4 |
5 | using namespace std;
6 |
7 | vector<vector<vector<uint32_t>>> get_insertion_canvas(
8 | vector<vector<uint32_t>>& seq,
9 | vector<vector<bool>>& keep,
10 | vector<uint32_t> n
11 | ) {
12 | vector<vector<uint32_t>> batch_canvas, batch_rest, batch_loc;
13 | for (uint32_t b = 0; b < seq.size(); b++) {
14 | vector<uint32_t> indices, canvas, rest, loc;
15 | for (uint32_t i = 0; i < n[b] + 2; i++) {
16 | if (keep[b][i]) {
17 | canvas.push_back(seq[b][i]);
18 | indices.push_back(i);
19 | } else {
20 | rest.push_back(i);
21 | }
22 | }
23 | if (rest.size() == 0) {
24 | rest.push_back(n[b] + 1);
25 | loc.push_back(n[b]);
26 | }
27 | else {
28 | uint32_t j = 0;
29 | for (uint32_t i: rest) {
30 | while (indices[j] < i) {
31 | j++;
32 | }
33 | loc.push_back(j-1);
34 | }
35 | }
36 |
37 | batch_canvas.push_back(canvas);
38 | batch_rest.push_back(rest);
39 | batch_loc.push_back(loc);
40 | }
41 | return {batch_canvas, batch_rest, batch_loc};
42 | }
43 |
44 | vector<vector<vector<uint32_t>>> get_canvas(
45 | vector<vector<uint32_t>>& seq,
46 | vector<vector<bool>>& keep,
47 | vector<uint32_t> n,
48 | uint32_t blank_id) {
49 | vector<vector<uint32_t>> batch_canvas, batch_blanks, batch_rest, batch_loc, batch_lb, batch_rb;
50 | for (uint32_t b = 0; b < seq.size(); b++) {
51 | vector<uint32_t> canvas, blanks, rest, loc, lb, rb;
52 | for (uint32_t i = 0; i < n[b]; ) {
53 | if (keep[b][i]) {
54 | canvas.push_back(seq[b][i]);
55 | i++;
56 | } else {
57 | lb.push_back(0);
58 | while (i < n[b] && !keep[b][i]) {
59 | rest.push_back(i);
60 | loc.push_back(blanks.size());
61 | lb.push_back(1);
62 | rb.push_back(1);
63 | i++;
64 | }
65 | lb.pop_back();
66 | rb.pop_back();
67 | rb.push_back(0);
68 | blanks.push_back(canvas.size());
69 | canvas.push_back(blank_id);
70 | }
71 | }
72 | batch_canvas.push_back(canvas);
73 | batch_blanks.push_back(blanks);
74 | batch_rest.push_back(rest);
75 | batch_loc.push_back(loc);
76 | batch_lb.push_back(lb);
77 | batch_rb.push_back(rb);
78 | }
79 | return {batch_canvas, batch_blanks, batch_rest, batch_loc, batch_lb, batch_rb};
80 | }
81 |
82 | vector<vector<vector<uint32_t>>> get_known_length_canvas(
83 | vector<vector<uint32_t>>& seq,
84 | vector<vector<bool>>& keep,
85 | vector<uint32_t> n,
86 | uint32_t blank_id) {
87 | vector<vector<uint32_t>> batch_canvas, batch_blanks, batch_rest, batch_loc, batch_lb;
88 | for (uint32_t b = 0; b < seq.size(); b++) {
89 | vector<uint32_t> canvas, blanks, rest, loc, lb;
90 | for (uint32_t i = 0; i < n[b]; ) {
91 | if (keep[b][i]) {
92 | canvas.push_back(seq[b][i]);
93 | i++;
94 | } else {
95 | uint32_t cur_len = 0;
96 | while (i < n[b] && !keep[b][i]) {
97 | rest.push_back(i);
98 | loc.push_back(blanks.size());
99 | lb.push_back(cur_len);
100 | i++;
101 | cur_len++;
102 | }
103 | blanks.push_back(canvas.size());
104 | canvas.push_back(blank_id + cur_len);
105 | }
106 | }
107 | batch_canvas.push_back(canvas);
108 | batch_blanks.push_back(blanks);
109 | batch_rest.push_back(rest);
110 | batch_loc.push_back(loc);
111 | batch_lb.push_back(lb);
112 | }
113 | return {batch_canvas, batch_blanks, batch_rest, batch_loc, batch_lb};
114 | }
115 |
116 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
117 | m.def("get_canvas", &get_canvas, "get_canvas");
118 | m.def("get_insertion_canvas", &get_insertion_canvas, "get_insertion_canvas");
119 | m.def("get_known_length_canvas", &get_known_length_canvas, "get_known_length_canvas");
120 |
121 | }
122 |
--------------------------------------------------------------------------------
/models/inst.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from . lm import LM
7 | from . torch_utils import get_ins_canvas, sample_permutation, seq_cross_entropy, collect, batch_randint, new_arange, select
8 | from vocab import Vocab
9 |
10 |
11 | class InsTLM(LM):
12 | """Insertion Transformer Language Model"""
13 |
14 | def __init__(self, hparams):
15 | super().__init__(hparams)
16 | hparams = self.hparams # a['key'] (if so) -> a.key
17 |
18 | self.pool_out = nn.Linear(2 * hparams.d_model, hparams.d_model)
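# each insertion slot is represented by the concatenated encoder outputs of its two adjacent tokens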
19 |
20 | def get_loss(self, seq, canvas, rest, loc, mask):
21 | count = (rest != -1).sum(1)
22 | output = self.forward_encoder(canvas)
23 | features = self.pool_out(torch.cat((output[:, :-1, :], output[:, 1:, :]), dim=-1))
24 | logits_loc = self.loc(features).squeeze(-1)
25 | logits_loc[~mask] = float('-inf')
26 | nll_loc = -F.log_softmax(logits_loc, 1)
27 | loss_loc = collect(nll_loc, loc)
28 | loss_loc = loss_loc.sum(1) / count.float()
29 | output_loc = collect(features, loc)
30 |
31 | logits_word = self.word(output_loc) * self.x_logit_scale
32 | target = collect(seq, rest, Vocab.pad)
33 | loss_word = seq_cross_entropy(logits_word, target, Vocab.pad)
34 | loss_word = loss_word.sum(1) / count.float()
35 | # output_word = torch.cat((output_loc, self.enc.src_word_emb(target)), -1)
36 | return loss_loc, loss_word
37 |
38 | def losses(self, seq, n, n_real):
39 | """
40 | Args:
41 | n: number of BPE tokens
42 | n_real: number of real words (for reporting PPL)
43 | """
44 | k = batch_randint(0, n)
45 | rank = sample_permutation(seq)
46 | keep = (rank < (k + 2).unsqueeze(1)) # keep <first> and <last> in addition
47 | canvas, rest, loc = get_ins_canvas(seq, keep, n)
48 |
49 | # canvas has <first> + k tokens + <last>, so k + 1 slots
50 | mask = (new_arange(canvas) < (k + 1).unsqueeze(1))[:, :-1] # mask for logits_loc
51 | loss_loc, loss_word = self.get_loss(seq, canvas, rest, loc, mask)
52 | nll_lb = (loss_loc + loss_word) * (n + 1).float() - (n + 1).float().lgamma()
53 | return {'loss': nll_lb.sum() / n_real.sum(),
54 | 'loc': loss_loc.mean(),
55 | 'word': loss_word.mean(),
56 | }
57 |
58 | def nll_mc(self, seq, n, m):
59 | """
60 | Compute negative log-likelihood by monte carlo estimate
61 | Args:
62 | m: number of samples to take
63 |
64 | Note: sentences in the batch must have the same length
65 | """
66 | a = []
67 | for _ in range(m):
68 | rank = sample_permutation(seq)
69 | logp = 0.
70 | for k in range(2, seq.size(1) + 1): # k from 2 to n + 2
71 | keep = (rank < k)
72 | canvas, rest, loc = get_ins_canvas(seq, keep, n)
73 | if k == seq.size(1):
74 | pass # rest and loc are already correct
75 | else:
76 | k_th = (rank == k).nonzero(as_tuple=True)[1] # First token not kept
77 | x, y = (rest == k_th.unsqueeze(1)).nonzero(as_tuple=True)
78 | assert len(seq) == len(x)
79 | assert torch.all(x == torch.arange(len(seq), device=seq.device))
80 | rest, loc = [t[x, y].unsqueeze(1) for t in [rest, loc]]
81 | mask = (new_arange(canvas) < (k - 1))[:, :-1] # mask for logits_loc
82 | loss_loc, loss_word = self.get_loss(seq, canvas, rest, loc, mask)
83 | logp -= loss_loc + loss_word
84 | a.append(logp.unsqueeze(1))
85 | return np.log(m) - (n + 1).float().lgamma() - torch.logsumexp(torch.cat(a, 1), 1)
86 |
87 | def generate(self, seq, blanks, decode, device, force_insert=False, prioritize_unfilled=False):
88 | seq = torch.LongTensor([Vocab.first] + seq + [Vocab.last]).to(device)
89 | is_fill = [0] * len(seq)
90 | fill = [[]]
91 | full = [seq[1:-1]]
92 | mandatory_blanks = np.array(blanks)
93 | if len(blanks) > 0:
94 | while len(seq) < self.hparams.max_len:
95 | output = self.forward_encoder(seq.unsqueeze(0))
96 | features = self.pool_out(
97 | torch.cat((output[:, :-1, :], output[:, 1:, :]), dim=-1)
98 | )[0]
99 |
100 | logits_loc = self.loc(features).squeeze(-1)
101 |
102 | all_filled = (mandatory_blanks == None).all()
103 | can_stop = not force_insert or all_filled
104 | end_slot = len(seq) - 2
105 |
106 | if prioritize_unfilled and not all_filled:
107 | logits_loc[np.array(blanks)[mandatory_blanks == None]] = float('-inf')
108 |
109 | if end_slot not in blanks and can_stop: # enable end slot for termination
110 | blanks_end = blanks + [end_slot]
111 | loc = select(logits_loc[blanks_end], decode)
112 | pos = blanks_end[loc]
113 | else:
114 | loc = select(logits_loc[blanks], decode)
115 | pos = blanks[loc]
116 |
117 | output_loc = features[pos]
118 | logits_word = self.word(output_loc) * self.x_logit_scale
119 |
120 | if pos == end_slot:
121 | if end_slot not in blanks: # end slot is added artificially, so no words allowed there
122 | break
123 | elif not can_stop:
124 | logits_word[Vocab.last] = float('-inf')
125 |
126 | word = select(logits_word, decode)
127 |
128 | if pos == end_slot and word.item() == Vocab.last:
129 | break
130 |
131 | blanks = blanks[:loc + 1] + [x + 1 for x in blanks[loc:]]
132 | mandatory_blanks = np.concatenate((
133 | mandatory_blanks[:loc],
134 | np.array([None]),
135 | np.array([None]),
136 | [x + 1 if x is not None else None for x in mandatory_blanks[loc + 1:]]
137 | ))
138 | seq = torch.cat((seq[:pos + 1], word.unsqueeze(0), seq[pos + 1:]))
139 | is_fill = is_fill[:pos + 1] + [1] + is_fill[pos + 1:]
140 | fill.append([id for id, isf in zip(seq, is_fill) if isf])
141 | full.append(seq[1:-1])
142 | return fill, full
143 |
--------------------------------------------------------------------------------
/models/lblm.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from . lm import LM
7 | from . torch_utils import get_known_length_canvas, sample_permutation, seq_cross_entropy, collect, batch_randint, select
8 | from vocab import Vocab
9 |
10 |
11 | class LBLM(LM):
12 | """Length-aware Blank Language Model"""
13 |
14 | def __init__(self, hparams):
15 | super().__init__(hparams)
16 | hparams = self.hparams # a['key'] (if so) -> a.key
17 |
18 | self.lrb = nn.Sequential(
19 | nn.Linear(hparams.d_model * 2, hparams.d_model * 2),
20 | nn.ReLU(),
21 | nn.Linear(hparams.d_model * 2, hparams.max_len)
22 | )
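# predicts the length of the left sub-blank created when a word is inserted into a length-constrained blank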
23 |
24 | def blank_indices(self):
25 | return Vocab.blank_0 + np.arange(self.hparams.max_len)
26 |
27 | def init_canvas(self):
28 | return np.random.choice(self.blank_indices()[1:]) # no blank_0
29 |
30 | def get_loss(self, seq, canvas, blanks, rest, loc, lb):
31 | count = (rest != -1).sum(1)
32 | output = self.forward_encoder(canvas)
33 | output_blank = collect(output, blanks)
34 |
35 | logits_loc = self.loc(output_blank).squeeze(-1)
36 | logits_loc[blanks == -1] = float('-inf')
37 | nll_loc = -F.log_softmax(logits_loc, 1)
38 | loss_loc = collect(nll_loc, loc)
39 | loss_loc = loss_loc.sum(1) / count.float()
40 | output_loc = collect(output_blank, loc)
41 |
42 | logits_word = self.word(output_loc) * self.x_logit_scale
43 | target = collect(seq, rest, Vocab.pad)
44 | loss_word = seq_cross_entropy(logits_word, target, Vocab.pad)
45 | loss_word = loss_word.sum(1) / count.float()
46 | output_word = torch.cat((output_loc, self.enc.src_word_emb(target)), -1)
47 |
48 | logits_lrb = self.lrb(output_word)
49 |
50 | # mask out illegal blank options
51 | length = collect(canvas, blanks) - Vocab.blank_0
52 | length_loc = collect(length, loc, -1)
53 | bs, seq_len = length_loc.shape
54 | ta = length_loc.unsqueeze(-1).repeat(1, 1, self.hparams.max_len)
55 | ra = torch.arange(self.hparams.max_len).unsqueeze(0).unsqueeze(0).repeat(bs, seq_len, 1).to(ta.device)
56 | mask = (ra >= ta)
57 | logits_lrb.masked_fill_(mask, float('-inf'))
58 |
59 | loss_lrb = seq_cross_entropy(logits_lrb, lb, -1)
60 | loss_lrb = loss_lrb.sum(1) / count.float()
61 |
62 | return loss_loc, loss_word, loss_lrb
63 |
64 | def losses(self, seq, n, n_real):
65 | """
66 | Args:
67 | n: number of BPE tokens
68 | n_real: number of real words (for reporting PPL)
69 | """
70 | k = batch_randint(0, n - 1)
71 | rank = sample_permutation(seq)
72 | keep = (rank < k.unsqueeze(1))
73 | canvas, blanks, rest, loc, lb = get_known_length_canvas(seq, keep, n)
74 | loss_loc, loss_word, loss_lrb = self.get_loss(seq, canvas, blanks, rest, loc, lb)
75 | nll_lb = (loss_loc + loss_word + loss_lrb) * n.float() - (n + 1).float().lgamma()
76 | return {'loss': nll_lb.sum() / n_real.sum(),
77 | 'loc': loss_loc.mean(),
78 | 'word': loss_word.mean(),
79 | 'lrb': loss_lrb.mean()
80 | }
81 |
82 | # lower than real perplexity since conditioned on length
83 | def nll_mc(self, seq, n, m):
84 | """
85 | Compute negative log-likelihood by monte carlo estimate
86 | Args:
87 | m: number of samples to take
88 |
89 | Note: sentences in the batch must have the same length
90 | """
91 | a = []
92 | for _ in range(m):
93 | rank = sample_permutation(seq)
94 | logp = 0.
95 | for k in range(seq.size(1)):
96 | keep = (rank < k)
97 | canvas, blanks, rest, loc, lb = get_known_length_canvas(seq, keep, n)
98 | k_th = (rank == k).nonzero(as_tuple=True)[1]
99 | x, y = (rest == k_th.unsqueeze(1)).nonzero(as_tuple=True)
100 | assert torch.all(x == torch.arange(len(seq), device=seq.device))
101 | rest, loc, lb = [t[x, y].unsqueeze(1) for t in [rest, loc, lb]]
102 | loss_loc, loss_word, loss_lrb = self.get_loss(seq, canvas, blanks, rest, loc, lb)
103 | logp -= loss_loc + loss_word + loss_lrb
104 | a.append(logp.unsqueeze(1))
105 | return np.log(m) - (n + 1).float().lgamma() - torch.logsumexp(torch.cat(a, 1), 1)
106 |
107 | def generate(self, seq, decode, device):
108 | seq = torch.LongTensor(seq).to(device)
109 | blanks = [i for i, w in enumerate(seq) if w.item() in self.blank_indices()]
110 | is_fill = [0] * len(seq)
111 | fill = [[]]
112 | full = [seq]
113 | while len(blanks) > 0 and len(seq) <= self.hparams.max_len:
114 | output = self.forward_encoder(seq.unsqueeze(0))[0]
115 | output_blank = output[blanks]
116 | loc = select(self.loc(output_blank).squeeze(-1), decode)
117 | output_loc = output_blank[loc]
118 |
119 | length_previous = seq[blanks[loc]] - Vocab.blank_0
120 |
121 | logits_word = self.word(output_loc) * self.x_logit_scale
122 | logits_word[self.blank_indices()] = float('-inf') # never predict blank tokens
123 |
124 | # joint word, lrb prediction
125 | lprob_word = F.log_softmax(logits_word, -1)
126 | output_word = torch.cat((output_loc.unsqueeze(0).expand(self.hparams.vocab_size, -1),
127 | self.enc.src_word_emb.weight), -1)
128 | logits_lrb = self.lrb(output_word)
129 | logits_lrb[:, length_previous:] = float('-inf') # mask out illegal blank options
130 | max_blank_len = logits_lrb.shape[1]
131 | lprob_lrb = F.log_softmax(logits_lrb, -1)
132 | lprob_word_lrb = lprob_word.unsqueeze(1) + lprob_lrb
133 | word_lrb = select(lprob_word_lrb.view(-1), decode)
134 | word, lrb = word_lrb // max_blank_len, word_lrb % max_blank_len
135 |
136 | lb = lrb
137 | rb = length_previous - lb - 1
138 |
139 | ins = ([Vocab.blank_0 + lb] if lb else []) + [word] + ([Vocab.blank_0 + rb] if rb else [])
140 | ins = torch.LongTensor(ins).to(device)
141 | pos = blanks[loc]
142 | seq = torch.cat((seq[:pos], ins, seq[pos + 1:]))
143 | blanks = [i for i, w in enumerate(seq) if w.item() in self.blank_indices()]
144 | is_fill = is_fill[:pos] + [1] * len(ins) + is_fill[pos + 1:]
145 | fill.append([id for id, isf in zip(seq, is_fill) if isf])
146 | full.append(seq)
147 | return fill, full
148 |
--------------------------------------------------------------------------------
/models/lm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import pytorch_lightning as pl
4 |
5 | from transformer.Models import Encoder
6 | from optimizer import config_opt_schedule
7 | from vocab import Vocab
8 |
9 |
10 | class LM(pl.LightningModule):
11 | """Language Model Container Class"""
12 |
13 | def __init__(self, hparams):
14 | super().__init__()
15 | self.hparams = hparams
16 | hparams = self.hparams # a['key'] (if so) -> a.key
17 |
18 | self.enc = Encoder(
19 | n_src_vocab=hparams.vocab_size, len_max_seq=hparams.max_len,
20 | d_word_vec=hparams.d_model, d_model=hparams.d_model,
21 | d_inner=hparams.d_inner_hid, d_k=hparams.d_k, d_v=hparams.d_v,
22 | n_layers=hparams.n_layers, n_head=hparams.n_head,
23 | dropout=hparams.dropout)
24 |
25 | self.word = nn.Linear(hparams.d_model, hparams.vocab_size, bias=False)
26 | nn.init.xavier_normal_(self.word.weight)
27 | self.x_logit_scale = 1.
28 | if hparams.share_emb_prj_weight:
29 | self.word.weight = self.enc.src_word_emb.weight
30 | self.x_logit_scale = (hparams.d_model ** -0.5)
31 |
32 | self.loc = nn.Linear(hparams.d_model, 1)
33 |
34 | def configure_optimizers(self):
35 | return config_opt_schedule(self.parameters(), self.hparams)
36 |
37 | def training_step(self, batch, batch_idx):
38 | seq, n, n_real = map(lambda x: x.squeeze(0), batch)
39 | losses = self('losses', seq, n, n_real)
40 | return {**losses, 'log': {**losses}}
41 |
42 | def eval_step(self, batch, batch_idx):
43 | seq, n, n_real = map(lambda x: x.squeeze(0), batch)
44 | losses = self('losses', seq, n, n_real)
45 | if self.hparams.n_mc > 0:
46 | nll = self('nll_mc', seq, n, self.hparams.n_mc).sum()
47 | else:
48 | nll = losses['loss'] * n_real.sum()
49 | n_words = n_real.sum()
50 | return {**losses, 'n_words': n_words, 'nll': nll}
51 |
52 | def eval_epoch_end(self, outputs):
53 | # n_words and nll are batch/dataset sum, other losses are mean
54 | losses = {}
55 | for key in outputs[0].keys():
56 | if key not in ['n_words', 'nll']:
57 | losses[key] = torch.stack([x[key] for x in outputs]).mean()
58 | nll = torch.stack([x['nll'] for x in outputs]).sum()
59 | n_words = torch.stack([x['n_words'] for x in outputs]).sum()
60 | ppl = torch.exp(nll / n_words)
61 | return {**losses, 'nll': nll, 'n_words': n_words, 'ppl': ppl}
62 |
63 | def validation_step(self, batch, batch_idx):
64 | return self.eval_step(batch, batch_idx)
65 |
66 | def validation_epoch_end(self, outputs):
67 | logs = self.eval_epoch_end(outputs)
68 | val_logs = {'val_' + k: v for k, v in logs.items()}
69 | return {'val_loss': logs['loss'], 'log': val_logs}
70 |
71 | def test_step(self, batch, batch_idx):
72 | return self.eval_step(batch, batch_idx)
73 |
74 | def test_epoch_end(self, outputs):
75 | logs = self.eval_epoch_end(outputs)
76 | test_logs = {'test_' + k: v for k, v in logs.items()}
77 | return {'test_loss': logs['loss'], 'log': test_logs}
78 |
79 | def forward_encoder(self, canvas):
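# positions are 1-based; padding tokens are assigned position 0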
80 | pos = (1 + torch.arange(canvas.size(1))).repeat(len(canvas), 1)
81 | pos[canvas == Vocab.pad] = 0
82 | output, *_ = self.enc(canvas, pos.to(canvas.device))
83 | return output
84 |
85 | def forward(self, action, *args):
86 | if action == 'nll_mc':
87 | return self.nll_mc(*args)
88 | elif action == 'losses':
89 | return self.losses(*args)
90 | raise NotImplementedError
91 |
--------------------------------------------------------------------------------
/models/torch_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.utils.cpp_extension import load
4 |
5 | from vocab import Vocab
6 |
7 | get_canvas_cpp = load(name='canvas', sources=['models/get_canvas.cpp'])
8 |
9 |
10 | def select(logits, decode):
11 | if decode == 'sample':
12 | return torch.multinomial(logits.exp(), num_samples=1)[0]
13 | else:
14 | return logits.argmax()
15 |
16 |
17 | def seq_cross_entropy(pred, gold, pad):
18 | gold_shape = gold.shape
19 | pred = pred.view(-1, pred.size(-1))
20 | gold = gold.view(-1)
21 | loss = F.cross_entropy(pred, gold, ignore_index=pad, reduction='none')
22 | return loss.view(gold_shape)
23 |
24 |
25 | def new_arange(x, *size):
26 | """
27 | Return a Tensor of `size` filled with a range function on the device of x
28 | If `size` is empty, using the size of the variable x
29 | """
30 | if len(size) == 0:
31 | size = x.size()
32 | return torch.arange(size[-1], device=x.device).expand(*size).contiguous()
33 |
34 |
35 | def batch_randint(start, batch_end):
36 | """
37 | Sample k from start to end (both inclusive)
38 | Return the same shape as batch_end
39 | """
40 | return start + (torch.rand_like(batch_end.float()) * (batch_end - start + 1).float()).long()
41 |
42 |
43 | def sample_permutation(seq):
44 | score = torch.rand_like(seq.float())
45 | score.masked_fill_(seq == Vocab.pad, 1) # always put pads last
46 | score.masked_fill_(seq == Vocab.first, -1) # always keep
47 | score.masked_fill_(seq == Vocab.last, -1) # always keep
48 | indices = score.argsort()
49 | rank = torch.zeros_like(seq)
50 | rank[torch.arange(len(seq)).unsqueeze(1), indices] = \
51 | torch.arange(seq.size(1), device=seq.device)
52 | return rank
53 |
54 |
55 | def collect(input, index, padding_idx=0):
56 | """
57 | Perform a batched index select where index is given for each example
58 | Args:
59 | input: tensor of shape (B, T_1, dim_2, dim_3, ...)
60 | index: tensor of shape (B, T_2)
61 | Return:
62 | tensor of shape (B, T_2, dim_2, dim_3, ...)
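Note: an index of -1 selects the padding_idx value (handled by the +1 shift below)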
63 | """
64 | # Add a column of padding_idx at index 0 (of dim 1)
65 | view = list(input.shape)
66 | view[1] = 1
67 | padding_column = input.new_ones(view) * padding_idx
68 | input = torch.cat([padding_column, input], 1)
69 |
70 | # Expand index to compatible size for gather
71 | for i in range(2, len(input.shape)):
72 | index = index.unsqueeze(i)
73 |
74 | view[0] = -1
75 | view[1] = -1
76 | index = index.expand(view)
77 | return torch.gather(input, 1, index + 1)
78 |
79 |
80 | def to_tensor(x, pad_id, device):
81 | max_len = max([len(xi) for xi in x])
82 | x_ = [xi + [pad_id] * (max_len - len(xi)) for xi in x]
83 | return torch.tensor(x_).to(device)
84 |
85 |
86 | def get_canvas(seq, keep, n):
87 | """
88 | Args:
89 | seq: original (batched) sequence of tokens
90 | keep: mask over seq indicating whether to keep each token
91 | n: number of tokens
92 | Return:
93 | canvas: replace consecutive masked tokens in seq by the <blank> token
94 | blanks: indices of tokens in canvas
95 | rest: indices of masked tokens in seq, these are the tokens to predict
96 | loc: indices of how rest relates to blanks
97 | lb: whether to create a left blank for predicting each token in rest
98 | rb: whether to create a right blank for predicting each token in rest
99 | (rest, loc, lb, rb have the same shape)
100 | """
101 | res = get_canvas_cpp.get_canvas(seq.tolist(), keep.tolist(), n.tolist(), Vocab.blank)
102 | pad = [Vocab.pad, -1, -1, -1, -1, -1]
103 | return [to_tensor(r, p, seq.device) for r, p in zip(res, pad)]
104 |
105 |
106 | def get_known_length_canvas(seq, keep, n):
107 | """
108 | Return:
109 | canvas: replace consecutive masked tokens in seq by a blank token that encodes the blank's length
110 | blanks: indices of tokens in canvas
111 | rest: indices of masked tokens in seq, these are the tokens to predict
112 | loc: indices of how rest relates to blanks
113 | lb: length of the new left blank for predicting each token in rest
114 | (rest, loc, lb have the same shape)
115 | """
116 | res = get_canvas_cpp.get_known_length_canvas(seq.tolist(), keep.tolist(), n.tolist(), Vocab.blank_0)
117 | pad = [Vocab.pad, -1, -1, -1, -1]
118 | return [to_tensor(r, p, seq.device) for r, p in zip(res, pad)]
119 |
120 |
121 | def get_ins_canvas(seq, keep, n):
122 | """
123 | Return:
124 | canvas: remove masked tokens in seq
125 | rest: indices of masked tokens in seq, these are the tokens to predict
126 | loc: indices of how rest relates to canvas
127 | (rest, loc have the same shape)
128 | """
129 | res = get_canvas_cpp.get_insertion_canvas(seq.tolist(), keep.tolist(), n.tolist())
130 | pad = [Vocab.pad, -1, -1]
131 | return [to_tensor(r, p, seq.device) for r, p in zip(res, pad)]
132 |
--------------------------------------------------------------------------------
/optimizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def config_opt_schedule(params, args):
5 | optimizer = torch.optim.Adam(
6 | params,
7 | betas=eval(args.adam_betas),
8 | eps=args.adam_eps,
9 | weight_decay=args.weight_decay,
10 | lr=args.lr
11 | )
12 |
13 | if args.lr_schedule == 'fixed':
14 | return optimizer
15 |
16 | elif args.lr_schedule == 'triangular':
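# one triangle per cycle: linear warmup from 0 to lr over warmup_steps, then linear decay over descend_steps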
17 | scheduler = torch.optim.lr_scheduler.CyclicLR(
18 | optimizer,
19 | base_lr=0,
20 | max_lr=args.lr,
21 | step_size_up=args.warmup_steps,
22 | step_size_down=args.descend_steps,
23 | cycle_momentum=False,
24 | )
25 | return [optimizer], [{'scheduler': scheduler, 'interval': 'step'}]
26 |
27 | else:
28 | raise ValueError('Unknown lr schedule ' + args.lr_schedule)
29 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from tqdm import tqdm
4 | import torch
5 | import pytorch_lightning as pl
6 |
7 | from vocab import Vocab
8 | from utils import load_data, load_sent, load_model, makedir, write
9 | from dataset import get_eval_dataloader
10 |
11 |
12 | def main(args):
13 | pl.seed_everything(args.seed)
14 |
15 | model = load_model(args.checkpoint).to(device)
16 | model.eval()
17 | vocab = Vocab(os.path.join(model.hparams.root_dir, 'vocab.txt'))
18 |
19 | if args.eval:
20 | data = load_data(args.eval, model.hparams.add_eos, model.hparams.cat_sent, model.hparams.max_len)
21 | dl = get_eval_dataloader(
22 | data, vocab, args.max_tok,
23 | data_workers=args.data_workers,
24 | model_type=model.hparams.model_type)
25 | trainer = pl.Trainer(
26 | gpus=args.gpus,
27 | amp_level=args.fp16_opt_level,
28 | precision=16 if args.fp16 else 32,
29 | default_root_dir='testing_logs')
30 | model.hparams.n_mc = args.n_mc
31 | trainer.test(model, test_dataloaders=dl)
32 |
33 | if args.output:
34 | output = os.path.join(os.path.dirname(os.path.dirname(args.checkpoint)), 'outputs/', args.output)
35 | makedir(output)
36 |
37 | if args.sample:
38 | with open(output, 'w') as f:
39 | for i in tqdm(range(args.sample)):
40 | if model.hparams.model_type == 'inst':
41 | _, full = model.generate([], [0], args.decode, device)
42 | else:
43 | _, full = model.generate([model.init_canvas()], args.decode, device)
44 |
45 | full = [[vocab.idx2word[id] for id in ids] for ids in full]
46 | write(f, full, args.write_mid)
47 |
48 | if args.fill:
49 | sents = load_sent(args.fill, model.hparams.add_eos)
50 | sents = [[vocab.word_to_idx(w) for w in s] for s in sents]
51 | with open(output + '.fill', 'w') as f_fill:
52 | with open(output + '.full', 'w') as f_full:
53 | for s in tqdm(sents):
54 | if model.hparams.model_type == 'inst':
55 | seq, blanks = [], []
56 | for w in s:
57 | if w == vocab.blank:
58 | blanks.append(len(seq))
59 | else:
60 | seq.append(w)
61 | if args.anywhere:
62 | blanks = list(range(len(seq) + 1))
63 | fill, full = model.generate(seq, blanks, args.decode, device,
64 | args.force_insert, args.prioritize_unfilled)
65 | else:
66 | fill, full = model.generate(s, args.decode, device)
67 |
68 | fill = [[vocab.idx2word[id] for id in ids] for ids in fill]
69 | full = [[vocab.idx2word[id] for id in ids] for ids in full]
70 | write(f_fill, fill, args.write_mid)
71 | write(f_full, full, args.write_mid)
72 |
73 |
74 | if __name__ == '__main__':
75 | parser = argparse.ArgumentParser()
76 |
77 | parser.add_argument('--checkpoint', required=True,
78 | help='path to checkpoint')
79 |
80 | parser.add_argument('--eval', default='',
81 | help='data file to evaluate')
82 | parser.add_argument('--n_mc', type=int, default=10,
83 | help='num of samples for Monte Carlo estimate of ppl')
84 | parser.add_argument('--max_tok', type=int, default=40000,
85 | help='max number of tokens per batch')
86 |
87 | parser.add_argument('--output', default='',
88 | help='output file')
89 | parser.add_argument('--sample', type=int, default=0,
90 | help='num of sentences to generate')
91 | parser.add_argument('--fill', default='',
92 | help='input file to fill')
93 | parser.add_argument('--decode', default='greedy',
94 | choices=['greedy', 'sample'],
95 | help='greedy decoding or sampling')
96 | parser.add_argument('--write_mid', action='store_true',
97 | help='write intermediate partial sentences')
98 |
99 | # Specific to InsT
100 | parser.add_argument('--anywhere', action='store_true',
101 | help='fill in anywhere, not only blanks')
102 | parser.add_argument('--force_insert', action='store_true',
103 | help='disable termination unless all slots are filled')
104 | parser.add_argument('--prioritize_unfilled', action='store_true',
105 | help='prioritize unfilled slots if any')
106 |
107 | parser.add_argument('--seed', type=int, default=1111,
108 | help='random seed')
109 | parser.add_argument('--data_workers', type=int, default=8,
110 | help='data workers')
111 | parser.add_argument('--no_cuda', action='store_true',
112 | help='disable CUDA')
113 | parser.add_argument('--fp16', action='store_true',
114 | help='whether to use 16-bit (mixed) precision '
115 | '(through NVIDIA apex) instead of 32-bit')
116 | parser.add_argument('--fp16_opt_level', default='O1',
117 | help="for fp16: Apex AMP optimization level selected "
118 | "in ['O0', 'O1', 'O2', and 'O3']. see details at "
119 | "https://nvidia.github.io/apex/amp.html")
120 |
121 | args = parser.parse_args()
122 |
123 | cuda = not args.no_cuda and torch.cuda.is_available()
124 | device = torch.device('cuda' if cuda else 'cpu')
125 | args.gpus = 1 if cuda else 0
126 |
127 | main(args)
128 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import torch
4 |
5 | import pytorch_lightning as pl
6 | from pytorch_lightning.callbacks import LearningRateMonitor
7 |
8 | from models import get_model_class
9 | from vocab import Vocab
10 | from utils import load_data
11 | from dataset import get_train_dataloader, get_eval_dataloader
12 |
13 |
14 | def main(args):
15 | pl.seed_everything(args.seed)
16 |
17 | torch.multiprocessing.set_sharing_strategy('file_system')
18 |
19 | args.multigpu = torch.cuda.device_count() > 1
20 |
21 | train_data = load_data(args.train, args.add_eos, args.cat_sent, args.max_len)
22 | valid_data = load_data(args.valid, args.add_eos, args.cat_sent, args.max_len)
23 |
24 | os.makedirs(args.root_dir, exist_ok=True)
25 |
26 | vocab_file = os.path.join(args.root_dir, 'vocab.txt')
27 | if not os.path.isfile(vocab_file):
28 | max_blank_len = args.max_len if args.model_type == 'lblm' else None
29 | Vocab.build(train_data, vocab_file, args.vocab_size, max_blank_len)
30 | vocab = Vocab(vocab_file)
31 | args.vocab_size = vocab.size
32 |
33 | train_dl = get_train_dataloader(
34 | train_data, vocab, args.max_tok,
35 | data_workers=args.data_workers if not args.multigpu else 0,
36 | model_type=args.model_type)
37 | val_dl = get_eval_dataloader(
38 | valid_data, vocab, args.eval_max_tok,
39 | data_workers=args.data_workers if not args.multigpu else 0,
40 | model_type=args.model_type)
41 |
42 | model = get_model_class(args.model_type)(args)
43 |
44 | trainer = pl.Trainer(
45 | accumulate_grad_batches=args.accum_grad,
46 | max_steps=args.max_steps,
47 | callbacks=[LearningRateMonitor()] if args.lr_schedule != 'fixed' else None,
48 | val_check_interval=args.val_check_interval if args.val_check_interval > 0 else 1.0,
49 | gpus=args.gpus,
50 | distributed_backend='ddp' if args.multigpu else None,
51 | amp_level=args.fp16_opt_level,
52 | precision=16 if args.fp16 else 32,
53 | default_root_dir=args.root_dir,
54 | resume_from_checkpoint=args.load_checkpoint
55 | )
56 |
57 | trainer.fit(model, train_dataloader=train_dl, val_dataloaders=val_dl)
58 |
59 |
60 | if __name__ == '__main__':
61 | parser = argparse.ArgumentParser()
62 |
63 | # Path
64 | parser.add_argument('--train',
65 | help='path to training file')
66 | parser.add_argument('--valid',
67 | help='path to validation file')
68 | parser.add_argument('--root_dir', default='checkpoints',
69 | help='directory to save checkpoints and outputs')
70 | parser.add_argument('--load_checkpoint', default=None,
71 | help='path to load checkpoint if specified')
72 |
73 | # Data
74 | parser.add_argument('--vocab_size', type=int, default=10000,
75 | help='keep N most frequent words in vocabulary')
76 | parser.add_argument('--max_len', type=int, default=512,
77 | help='max sequence length')
78 | parser.add_argument('--cat_sent', action='store_true',
79 | help='concat sentences and chunk into size of max_len')
80 | parser.add_argument('--add_eos', action='store_true',
81 | help='add <eos> at the end of each sentence')
82 |
83 | # Model
84 | parser.add_argument('--model_type', default='blm',
85 | choices=['blm', 'inst', 'lblm'],
86 | help='model type: blm, inst or lblm')
87 |
88 | parser.add_argument('--d_model', type=int, default=512,
89 | help='transformer dimension d_model')
90 | parser.add_argument('--d_inner_hid', type=int, default=2048,
91 | help='transformer dimension d_inner_hid')
92 | parser.add_argument('--d_k', type=int, default=64,
93 | help='transformer dimension d_k')
94 | parser.add_argument('--d_v', type=int, default=64,
95 | help='transformer dimension d_v')
96 | parser.add_argument('--n_head', type=int, default=8,
97 | help='number of attention heads')
98 | parser.add_argument('--n_layers', type=int, default=6,
99 | help='number of layers')
100 | parser.add_argument('--share_emb_prj_weight', action='store_true',
101 | help='share word embedding and projection weights')
102 |
103 | # Optimization
104 | parser.add_argument('--max_tok', type=int, default=10000,
105 | help='max number of tokens per batch')
106 | parser.add_argument('--accum_grad', type=int, default=1,
107 | help='accumulate gradients across N batches.')
108 |
109 | parser.add_argument('--adam_betas', default='(0.9, 0.999)',
110 | help='adam betas')
111 | parser.add_argument('--adam_eps', type=float, default=1e-8,
112 | help='adam eps')
113 | parser.add_argument('--weight_decay', type=float, default=1e-5,
114 | help='weight decay')
115 | parser.add_argument('--dropout', type=float, default=0.3,
116 | help='dropout probability (0 = no dropout)')
117 |
118 | parser.add_argument('--lr_schedule', default='fixed',
119 | choices=['fixed', 'triangular'],
120 | help='learning rate schedule')
121 | parser.add_argument('--lr', type=float, default=0.0001,
122 | help='learning rate')
123 | parser.add_argument('--warmup_steps', type=int, default=4000,
124 | help='number of warmup steps (triangular)')
125 | parser.add_argument('--descend_steps', type=int, default=300000,
126 | help='number of descending steps (triangular)')
127 | parser.add_argument('--max_steps', type=int, default=500000,
128 | help='number of training steps')
129 |
130 | # Validation
131 | parser.add_argument('--eval_max_tok', type=int, default=40000,
132 | help='max number of tokens per batch for evaluation')
133 | parser.add_argument('--val_check_interval', type=int, default=0,
134 | help='check validation set every N training batches '
135 | '(0 means checking once an epoch)')
136 | parser.add_argument('--n_mc', type=int, default=1,
137 | help='num of samples for Monte Carlo estimate of ppl')
138 |
139 | # Others
140 | parser.add_argument('--seed', type=int, default=1111,
141 | help='random seed')
142 | parser.add_argument('--data_workers', type=int, default=8,
143 | help='data workers')
144 | parser.add_argument('--gpus', type=int, default=-1,
145 | help='number of gpus to train on (-1 means all gpus)')
146 | parser.add_argument('--fp16', action='store_true',
147 | help='whether to use 16-bit (mixed) precision '
148 | '(through NVIDIA apex) instead of 32-bit')
149 | parser.add_argument('--fp16_opt_level', default='O1',
150 | help="for fp16: Apex AMP optimization level selected "
151 | "in ['O0', 'O1', 'O2', and 'O3']. see details at "
152 | "https://nvidia.github.io/apex/amp.html")
153 |
154 | args = parser.parse_args()
155 |
156 | main(args)
157 |
--------------------------------------------------------------------------------
/transformer/Beam.py:
--------------------------------------------------------------------------------
1 | """ Manage beam search info structure.
2 |
3 | Heavily borrowed from OpenNMT-py.
4 | For code in OpenNMT-py, please check the following link:
5 | https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/Beam.py
6 | """
7 |
8 | import torch
9 | import numpy as np
10 | import transformer.Constants as Constants
11 |
12 | class Beam():
13 | ''' Beam search '''
14 |
15 | def __init__(self, size, device=False):
16 |
17 | self.size = size
18 | self._done = False
19 |
20 | # The score for each translation on the beam.
21 | self.scores = torch.zeros((size,), dtype=torch.float, device=device)
22 | self.all_scores = []
23 |
24 | # The backpointers at each time-step.
25 | self.prev_ks = []
26 |
27 | # The outputs at each time-step.
28 | self.next_ys = [torch.full((size,), Constants.PAD, dtype=torch.long, device=device)]
29 | self.next_ys[0][0] = Constants.BOS
30 |
31 | def get_current_state(self):
32 | "Get the outputs for the current timestep."
33 | return self.get_tentative_hypothesis()
34 |
35 | def get_current_origin(self):
36 | "Get the backpointers for the current timestep."
37 | return self.prev_ks[-1]
38 |
39 | @property
40 | def done(self):
41 | return self._done
42 |
43 | def advance(self, word_prob):
44 | "Update beam status and check if finished or not."
45 | num_words = word_prob.size(1)
46 |
47 | # Sum the previous scores.
48 | if len(self.prev_ks) > 0:
49 | beam_lk = word_prob + self.scores.unsqueeze(1).expand_as(word_prob)
50 | else:
51 | beam_lk = word_prob[0]
52 |
53 | flat_beam_lk = beam_lk.view(-1)
54 |
55 | best_scores, best_scores_id = flat_beam_lk.topk(self.size, 0, True, True)
56 |
57 |
58 | self.all_scores.append(self.scores)
59 | self.scores = best_scores
60 |
61 | # bestScoresId is flattened as a (beam x word) array,
62 | # so we need to calculate which word and beam each score came from
63 | prev_k = best_scores_id // num_words  # floor-divide to recover the beam index
64 | self.prev_ks.append(prev_k)
65 | self.next_ys.append(best_scores_id - prev_k * num_words)
66 |
67 | # End condition is when top-of-beam is EOS.
68 | if self.next_ys[-1][0].item() == Constants.EOS:
69 | self._done = True
70 | self.all_scores.append(self.scores)
71 |
72 | return self._done
73 |
74 | def sort_scores(self):
75 | "Sort the scores."
76 | return torch.sort(self.scores, 0, True)
77 |
78 | def get_the_best_score_and_idx(self):
79 | "Get the score of the best in the beam."
80 | scores, ids = self.sort_scores()
81 | return scores[0], ids[0]
82 |
83 | def get_tentative_hypothesis(self):
84 | "Get the decoded sequence for the current timestep."
85 |
86 | if len(self.next_ys) == 1:
87 | dec_seq = self.next_ys[0].unsqueeze(1)
88 | else:
89 | _, keys = self.sort_scores()
90 | hyps = [self.get_hypothesis(k) for k in keys]
91 | hyps = [[Constants.BOS] + h for h in hyps]
92 | dec_seq = torch.LongTensor(hyps)
93 |
94 | return dec_seq
95 |
96 | def get_hypothesis(self, k):
97 | """ Walk back to construct the full hypothesis. """
98 | hyp = []
99 | for j in range(len(self.prev_ks) - 1, -1, -1):
100 | hyp.append(self.next_ys[j+1][k])
101 | k = self.prev_ks[j][k]
102 |
103 | return list(map(lambda x: x.item(), hyp[::-1]))
104 |
--------------------------------------------------------------------------------
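Note: Beam is only driven internally by transformer/Translator.py, but it can be exercised on its own. A minimal sketch with random log-probabilities (vocab_size and the 5-step cap are arbitrary toy values):

import torch
from transformer.Beam import Beam

beam = Beam(3, device='cpu')
vocab_size = 10
for _ in range(5):                      # cap the toy search at 5 steps
    if beam.done:
        break
    # fake per-beam log-probabilities over the vocabulary
    word_prob = torch.randn(3, vocab_size).log_softmax(dim=-1)
    beam.advance(word_prob)
scores, ids = beam.sort_scores()
print(beam.get_hypothesis(ids[0]))      # token ids of the current best hypothesis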
/transformer/Constants.py:
--------------------------------------------------------------------------------
1 |
2 | PAD = 0
3 | UNK = 1
4 | BOS = 2
5 | EOS = 3
6 |
7 | #PAD_WORD = '<blank>'
8 | #UNK_WORD = '<unk>'
9 | #BOS_WORD = '<s>'
10 | #EOS_WORD = '</s>'
11 |
--------------------------------------------------------------------------------
/transformer/Layers.py:
--------------------------------------------------------------------------------
1 | ''' Define the Layers '''
2 | import torch.nn as nn
3 | from transformer.SubLayers import MultiHeadAttention, PositionwiseFeedForward
4 |
5 | __author__ = "Yu-Hsiang Huang"
6 |
7 |
8 | class EncoderLayer(nn.Module):
9 | ''' Compose with two layers '''
10 |
11 | def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
12 | super(EncoderLayer, self).__init__()
13 | self.slf_attn = MultiHeadAttention(
14 | n_head, d_model, d_k, d_v, dropout=dropout)
15 | self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)
16 |
17 | def forward(self, enc_input, non_pad_mask=None, slf_attn_mask=None):
18 | enc_output, enc_slf_attn = self.slf_attn(
19 | enc_input, enc_input, enc_input, mask=slf_attn_mask)
20 | enc_output *= non_pad_mask
21 |
22 | enc_output = self.pos_ffn(enc_output)
23 | enc_output *= non_pad_mask
24 |
25 | return enc_output, enc_slf_attn
26 |
27 |
28 | class DecoderLayer(nn.Module):
29 | ''' Compose with three layers '''
30 |
31 | def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
32 | super(DecoderLayer, self).__init__()
33 | self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
34 | self.enc_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
35 | self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)
36 |
37 | def forward(self, dec_input, enc_output, non_pad_mask=None, slf_attn_mask=None, dec_enc_attn_mask=None):
38 | dec_output, dec_slf_attn = self.slf_attn(
39 | dec_input, dec_input, dec_input, mask=slf_attn_mask)
40 | dec_output *= non_pad_mask
41 |
42 | dec_output, dec_enc_attn = self.enc_attn(
43 | dec_output, enc_output, enc_output, mask=dec_enc_attn_mask)
44 | dec_output *= non_pad_mask
45 |
46 | dec_output = self.pos_ffn(dec_output)
47 | dec_output *= non_pad_mask
48 |
49 | return dec_output, dec_slf_attn, dec_enc_attn
50 |
--------------------------------------------------------------------------------
/transformer/Models.py:
--------------------------------------------------------------------------------
1 | ''' Define the Transformer model '''
2 | import torch
3 | import torch.nn as nn
4 | import numpy as np
5 | import transformer.Constants as Constants
6 | from transformer.Layers import EncoderLayer, DecoderLayer
7 |
8 | __author__ = "Yu-Hsiang Huang"
9 |
10 | def get_non_pad_mask(seq):
11 | assert seq.dim() == 2
12 | return seq.ne(Constants.PAD).type(torch.float).unsqueeze(-1)
13 |
14 | def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
15 | ''' Sinusoid position encoding table '''
16 |
17 | def cal_angle(position, hid_idx):
18 | return position / np.power(10000, 2 * (hid_idx // 2) / d_hid)
19 |
20 | def get_posi_angle_vec(position):
21 | return [cal_angle(position, hid_j) for hid_j in range(d_hid)]
22 |
23 | sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)])
24 |
25 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
26 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
27 |
28 | if padding_idx is not None:
29 | # zero vector for padding dimension
30 | sinusoid_table[padding_idx] = 0.
31 |
32 | return torch.FloatTensor(sinusoid_table)
33 |
34 | def get_attn_key_pad_mask(seq_k, seq_q):
35 | ''' For masking out the padding part of key sequence. '''
36 |
37 | # Expand to fit the shape of key query attention matrix.
38 | len_q = seq_q.size(1)
39 | padding_mask = seq_k.eq(Constants.PAD)
40 | padding_mask = padding_mask.unsqueeze(1).expand(-1, len_q, -1) # b x lq x lk
41 |
42 | return padding_mask
43 |
44 | def get_subsequent_mask(seq):
45 | ''' For masking out the subsequent info. '''
46 |
47 | sz_b, len_s = seq.size()
48 | subsequent_mask = torch.triu(
49 | torch.ones((len_s, len_s), device=seq.device, dtype=torch.uint8), diagonal=1)
50 | subsequent_mask = subsequent_mask.unsqueeze(0).expand(sz_b, -1, -1) # b x ls x ls
51 |
52 | return subsequent_mask
53 |
54 | class Encoder(nn.Module):
55 | ''' An encoder model with self attention mechanism. '''
56 |
57 | def __init__(
58 | self,
59 | n_src_vocab, len_max_seq, d_word_vec,
60 | n_layers, n_head, d_k, d_v,
61 | d_model, d_inner, dropout=0.1):
62 |
63 | super().__init__()
64 |
65 | n_position = len_max_seq + 1
66 |
67 | self.src_word_emb = nn.Embedding(
68 | n_src_vocab, d_word_vec, padding_idx=Constants.PAD)
69 |
70 | self.position_enc = nn.Embedding.from_pretrained(
71 | get_sinusoid_encoding_table(n_position, d_word_vec, padding_idx=0),
72 | freeze=True)
73 |
74 | self.layer_stack = nn.ModuleList([
75 | EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
76 | for _ in range(n_layers)])
77 |
78 | def forward(self, src_seq, src_pos, return_attns=False):
79 |
80 | enc_slf_attn_list = []
81 |
82 | # -- Prepare masks
83 | slf_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=src_seq)
84 | non_pad_mask = get_non_pad_mask(src_seq)
85 |
86 | # -- Forward
87 | enc_output = self.src_word_emb(src_seq) + self.position_enc(src_pos)
88 |
89 | for enc_layer in self.layer_stack:
90 | enc_output, enc_slf_attn = enc_layer(
91 | enc_output,
92 | non_pad_mask=non_pad_mask,
93 | slf_attn_mask=slf_attn_mask)
94 | if return_attns:
95 | enc_slf_attn_list += [enc_slf_attn]
96 |
97 | if return_attns:
98 | return enc_output, enc_slf_attn_list
99 | return enc_output,
100 |
101 | class Decoder(nn.Module):
102 | ''' A decoder model with self attention mechanism. '''
103 |
104 | def __init__(
105 | self,
106 | n_tgt_vocab, len_max_seq, d_word_vec,
107 | n_layers, n_head, d_k, d_v,
108 | d_model, d_inner, dropout=0.1):
109 |
110 | super().__init__()
111 | n_position = len_max_seq + 1
112 |
113 | self.tgt_word_emb = nn.Embedding(
114 | n_tgt_vocab, d_word_vec, padding_idx=Constants.PAD)
115 |
116 | self.position_enc = nn.Embedding.from_pretrained(
117 | get_sinusoid_encoding_table(n_position, d_word_vec, padding_idx=0),
118 | freeze=True)
119 |
120 | self.layer_stack = nn.ModuleList([
121 | DecoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
122 | for _ in range(n_layers)])
123 |
124 | def forward(self, tgt_seq, tgt_pos, src_seq, enc_output, return_attns=False):
125 |
126 | dec_slf_attn_list, dec_enc_attn_list = [], []
127 |
128 | # -- Prepare masks
129 | non_pad_mask = get_non_pad_mask(tgt_seq)
130 |
131 | slf_attn_mask_subseq = get_subsequent_mask(tgt_seq)
132 | slf_attn_mask_keypad = get_attn_key_pad_mask(seq_k=tgt_seq, seq_q=tgt_seq)
133 | slf_attn_mask = (slf_attn_mask_keypad.type_as(slf_attn_mask_subseq) + slf_attn_mask_subseq).gt(0)
134 |
135 | dec_enc_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=tgt_seq)
136 |
137 | # -- Forward
138 | dec_output = self.tgt_word_emb(tgt_seq) + self.position_enc(tgt_pos)
139 |
140 | for dec_layer in self.layer_stack:
141 | dec_output, dec_slf_attn, dec_enc_attn = dec_layer(
142 | dec_output, enc_output,
143 | non_pad_mask=non_pad_mask,
144 | slf_attn_mask=slf_attn_mask,
145 | dec_enc_attn_mask=dec_enc_attn_mask)
146 |
147 | if return_attns:
148 | dec_slf_attn_list += [dec_slf_attn]
149 | dec_enc_attn_list += [dec_enc_attn]
150 |
151 | if return_attns:
152 | return dec_output, dec_slf_attn_list, dec_enc_attn_list
153 | return dec_output,
154 |
155 | class Transformer(nn.Module):
156 | ''' A sequence to sequence model with attention mechanism. '''
157 |
158 | def __init__(
159 | self,
160 | n_src_vocab, n_tgt_vocab, len_max_seq,
161 | d_word_vec=512, d_model=512, d_inner=2048,
162 | n_layers=6, n_head=8, d_k=64, d_v=64, dropout=0.1,
163 | tgt_emb_prj_weight_sharing=True,
164 | emb_src_tgt_weight_sharing=True):
165 |
166 | super().__init__()
167 |
168 | self.encoder = Encoder(
169 | n_src_vocab=n_src_vocab, len_max_seq=len_max_seq,
170 | d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner,
171 | n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v,
172 | dropout=dropout)
173 |
174 | self.decoder = Decoder(
175 | n_tgt_vocab=n_tgt_vocab, len_max_seq=len_max_seq,
176 | d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner,
177 | n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v,
178 | dropout=dropout)
179 |
180 | self.tgt_word_prj = nn.Linear(d_model, n_tgt_vocab, bias=False)
181 | nn.init.xavier_normal_(self.tgt_word_prj.weight)
182 |
183 | assert d_model == d_word_vec, \
184 | 'To facilitate the residual connections, \
185 | the dimensions of all module outputs shall be the same.'
186 |
187 | if tgt_emb_prj_weight_sharing:
188 | # Share the weight matrix between target word embedding & the final logit dense layer
189 | self.tgt_word_prj.weight = self.decoder.tgt_word_emb.weight
190 | self.x_logit_scale = (d_model ** -0.5)
191 | else:
192 | self.x_logit_scale = 1.
193 |
194 | if emb_src_tgt_weight_sharing:
195 | # Share the weight matrix between source & target word embeddings
196 | assert n_src_vocab == n_tgt_vocab, \
197 | "To share word embedding table, the vocabulary size of src/tgt shall be the same."
198 | self.encoder.src_word_emb.weight = self.decoder.tgt_word_emb.weight
199 |
200 | def forward(self, src_seq, src_pos, tgt_seq, tgt_pos):
201 |
202 | tgt_seq, tgt_pos = tgt_seq[:, :-1], tgt_pos[:, :-1]
203 |
204 | enc_output, *_ = self.encoder(src_seq, src_pos)
205 | dec_output, *_ = self.decoder(tgt_seq, tgt_pos, src_seq, enc_output)
206 | seq_logit = self.tgt_word_prj(dec_output) * self.x_logit_scale
207 |
208 | return seq_logit.view(-1, seq_logit.size(2))
209 |
--------------------------------------------------------------------------------
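Note: a minimal smoke-test sketch of the seq2seq Transformer above, using toy dimensions and random token ids (all values are illustrative). Positions are 1-based because row 0 of the sinusoid table is reserved for padding:

import torch
from transformer.Models import Transformer

model = Transformer(n_src_vocab=100, n_tgt_vocab=100, len_max_seq=16,
                    d_word_vec=32, d_model=32, d_inner=64,
                    n_layers=2, n_head=4, d_k=8, d_v=8)

src_seq = torch.randint(4, 100, (2, 10))                 # batch of 2, length 10
src_pos = torch.arange(1, 11).unsqueeze(0).repeat(2, 1)  # 1-based positions
tgt_seq = torch.randint(4, 100, (2, 8))
tgt_pos = torch.arange(1, 9).unsqueeze(0).repeat(2, 1)

logits = model(src_seq, src_pos, tgt_seq, tgt_pos)
print(logits.shape)   # (2 * 7, 100): the last target position is dropped inside forward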
/transformer/Modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 | __author__ = "Yu-Hsiang Huang"
6 |
7 | class ScaledDotProductAttention(nn.Module):
8 | ''' Scaled Dot-Product Attention '''
9 |
10 | def __init__(self, temperature, attn_dropout=0.1):
11 | super().__init__()
12 | self.temperature = temperature
13 | self.dropout = nn.Dropout(attn_dropout)
14 | self.softmax = nn.Softmax(dim=2)
15 |
16 | def forward(self, q, k, v, mask=None):
17 |
18 | attn = torch.bmm(q, k.transpose(1, 2))
19 | attn = attn / self.temperature
20 |
21 | if mask is not None:
22 | attn = attn.masked_fill(mask, -np.inf)
23 |
24 | attn = self.softmax(attn)
25 | attn = self.dropout(attn)
26 | output = torch.bmm(attn, v)
27 |
28 | return output, attn
29 |
--------------------------------------------------------------------------------
/transformer/Optim.py:
--------------------------------------------------------------------------------
1 | '''A wrapper class for optimizer '''
2 |
3 | class LRScheduler(object):
4 |
5 | def __init__(self, optimizer, lr):
6 | self._optimizer = optimizer
7 | self.lr = lr
8 | self.set_lr()
9 |
10 | def step(self):
11 | self._optimizer.step()
12 |
13 | def zero_grad(self):
14 | self._optimizer.zero_grad()
15 |
16 | def set_lr(self):
17 | for param_group in self._optimizer.param_groups:
18 | param_group['lr'] = self.lr
19 |
20 |
21 | class InverseSqrtScheduler(LRScheduler):
22 |
23 | def __init__(self, optimizer, peak_lr, warmup_steps):
24 | super().__init__(optimizer, 0)
25 |
26 | self.warmup_steps = warmup_steps
27 | self.current_step = 0
28 | # linearly warmup for the first warmup_steps
29 | self.warmup_factor = peak_lr / warmup_steps
30 | # then, decay prop. to the inverse square root of the step number
31 | self.decay_factor = peak_lr * warmup_steps**0.5
32 |
33 | def step(self):
34 | self._update_learning_rate()
35 | super().step()
36 |
37 | def _update_learning_rate(self):
38 | self.current_step += 1
39 | if self.current_step < self.warmup_steps:
40 | self.lr = self.warmup_factor * self.current_step
41 | else:
42 | self.lr = self.decay_factor * self.current_step**-0.5
43 | self.set_lr()
44 |
45 |
46 | class LinearDecayScheduler(LRScheduler):
47 |
48 | def __init__(self, optimizer, peak_lr, warmup_steps, total_steps):
49 | super().__init__(optimizer, 0)
50 |
51 | self.warmup_steps = warmup_steps
52 | self.total_steps = total_steps
53 | self.current_step = 0
54 | # linearly warmup for the first warmup_steps
55 | self.warmup_factor = peak_lr / warmup_steps
56 | # then, linearly decay to 0
57 | self.decay_factor = peak_lr / (total_steps - warmup_steps)
58 |
59 | def step(self):
60 | self._update_learning_rate()
61 | super().step()
62 |
63 | def _update_learning_rate(self):
64 | self.current_step += 1
65 | if self.current_step < self.warmup_steps:
66 | self.lr = self.warmup_factor * self.current_step
67 | elif self.current_step < self.total_steps:
68 | self.lr = self.decay_factor * (self.total_steps - self.current_step)
69 | else:
70 | self.lr = 0
71 | self.set_lr()
72 |
--------------------------------------------------------------------------------
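Note: neither scheduler above is wired into train.py (which uses torch.optim.lr_scheduler.CyclicLR via optimizer.py). A minimal sketch of stepping InverseSqrtScheduler by hand, with a single toy parameter:

import torch
from transformer.Optim import InverseSqrtScheduler

param = torch.nn.Parameter(torch.zeros(4))
opt = torch.optim.Adam([param], lr=0.0)
sched = InverseSqrtScheduler(opt, peak_lr=1e-3, warmup_steps=4000)

for step in range(10):
    param.grad = torch.ones_like(param)  # stand-in for a real backward pass
    sched.step()                         # linear warmup, then inverse-sqrt decay
    sched.zero_grad()
print(sched.lr)                          # 10 * (1e-3 / 4000) while still warming up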
/transformer/SubLayers.py:
--------------------------------------------------------------------------------
1 | ''' Define the sublayers in encoder/decoder layer '''
2 | import numpy as np
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from transformer.Modules import ScaledDotProductAttention
6 |
7 | __author__ = "Yu-Hsiang Huang"
8 |
9 | class MultiHeadAttention(nn.Module):
10 | ''' Multi-Head Attention module '''
11 |
12 | def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
13 | super().__init__()
14 |
15 | self.n_head = n_head
16 | self.d_k = d_k
17 | self.d_v = d_v
18 |
19 | self.w_qs = nn.Linear(d_model, n_head * d_k)
20 | self.w_ks = nn.Linear(d_model, n_head * d_k)
21 | self.w_vs = nn.Linear(d_model, n_head * d_v)
22 | nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
23 | nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
24 | nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v)))
25 |
26 | self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5))
27 | self.layer_norm = nn.LayerNorm(d_model)
28 |
29 | self.fc = nn.Linear(n_head * d_v, d_model)
30 | nn.init.xavier_normal_(self.fc.weight)
31 |
32 | self.dropout = nn.Dropout(dropout)
33 |
34 |
35 | def forward(self, q, k, v, mask=None):
36 |
37 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
38 |
39 | sz_b, len_q, _ = q.size()
40 | sz_b, len_k, _ = k.size()
41 | sz_b, len_v, _ = v.size()
42 |
43 | residual = q
44 |
45 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
46 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
47 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
48 |
49 | q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk
50 | k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk
51 | v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv
52 |
53 | mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x ..
54 | output, attn = self.attention(q, k, v, mask=mask)
55 |
56 | output = output.view(n_head, sz_b, len_q, d_v)
57 | output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv)
58 |
59 | output = self.dropout(self.fc(output))
60 | output = self.layer_norm(output + residual)
61 |
62 | return output, attn
63 |
64 | class PositionwiseFeedForward(nn.Module):
65 | ''' A two-feed-forward-layer module '''
66 |
67 | def __init__(self, d_in, d_hid, dropout=0.1):
68 | super().__init__()
69 | self.w_1 = nn.Conv1d(d_in, d_hid, 1) # position-wise
70 | self.w_2 = nn.Conv1d(d_hid, d_in, 1) # position-wise
71 | self.layer_norm = nn.LayerNorm(d_in)
72 | self.dropout = nn.Dropout(dropout)
73 |
74 | def forward(self, x):
75 | residual = x
76 | output = x.transpose(1, 2)
77 | output = self.w_2(F.relu(self.w_1(output)))
78 | output = output.transpose(1, 2)
79 | output = self.dropout(output)
80 | output = self.layer_norm(output + residual)
81 | return output
82 |
--------------------------------------------------------------------------------
/transformer/Translator.py:
--------------------------------------------------------------------------------
1 | ''' This module will handle the text generation with beam search. '''
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from transformer.Models import Transformer
8 | from transformer.Beam import Beam
9 |
10 | class Translator(object):
11 | ''' Load with trained model and handle the beam search '''
12 |
13 | def __init__(self, opt):
14 | self.opt = opt
15 | self.device = torch.device('cuda' if opt.cuda else 'cpu')
16 |
17 | checkpoint = torch.load(opt.model)
18 | model_opt = checkpoint['settings']
19 | self.model_opt = model_opt
20 |
21 | model = Transformer(
22 | model_opt.src_vocab_size,
23 | model_opt.tgt_vocab_size,
24 | model_opt.max_token_seq_len,
25 | tgt_emb_prj_weight_sharing=model_opt.proj_share_weight,
26 | emb_src_tgt_weight_sharing=model_opt.embs_share_weight,
27 | d_k=model_opt.d_k,
28 | d_v=model_opt.d_v,
29 | d_model=model_opt.d_model,
30 | d_word_vec=model_opt.d_word_vec,
31 | d_inner=model_opt.d_inner_hid,
32 | n_layers=model_opt.n_layers,
33 | n_head=model_opt.n_head,
34 | dropout=model_opt.dropout)
35 |
36 | model.load_state_dict(checkpoint['model'])
37 | print('[Info] Trained model state loaded.')
38 |
39 | model.word_prob_prj = nn.LogSoftmax(dim=1)
40 |
41 | model = model.to(self.device)
42 |
43 | self.model = model
44 | self.model.eval()
45 |
46 | def translate_batch(self, src_seq, src_pos):
47 | ''' Translation work in one batch '''
48 |
49 | def get_inst_idx_to_tensor_position_map(inst_idx_list):
50 | ''' Indicate the position of an instance in a tensor. '''
51 | return {inst_idx: tensor_position for tensor_position, inst_idx in enumerate(inst_idx_list)}
52 |
53 | def collect_active_part(beamed_tensor, curr_active_inst_idx, n_prev_active_inst, n_bm):
54 | ''' Collect tensor parts associated to active instances. '''
55 |
56 | _, *d_hs = beamed_tensor.size()
57 | n_curr_active_inst = len(curr_active_inst_idx)
58 | new_shape = (n_curr_active_inst * n_bm, *d_hs)
59 |
60 | beamed_tensor = beamed_tensor.view(n_prev_active_inst, -1)
61 | beamed_tensor = beamed_tensor.index_select(0, curr_active_inst_idx)
62 | beamed_tensor = beamed_tensor.view(*new_shape)
63 |
64 | return beamed_tensor
65 |
66 | def collate_active_info(
67 | src_seq, src_enc, inst_idx_to_position_map, active_inst_idx_list):
68 | # Sentences which are still active are collected,
69 | # so the decoder will not run on completed sentences.
70 | n_prev_active_inst = len(inst_idx_to_position_map)
71 | active_inst_idx = [inst_idx_to_position_map[k] for k in active_inst_idx_list]
72 | active_inst_idx = torch.LongTensor(active_inst_idx).to(self.device)
73 |
74 | active_src_seq = collect_active_part(src_seq, active_inst_idx, n_prev_active_inst, n_bm)
75 | active_src_enc = collect_active_part(src_enc, active_inst_idx, n_prev_active_inst, n_bm)
76 | active_inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list)
77 |
78 | return active_src_seq, active_src_enc, active_inst_idx_to_position_map
79 |
80 | def beam_decode_step(
81 | inst_dec_beams, len_dec_seq, src_seq, enc_output, inst_idx_to_position_map, n_bm):
82 | ''' Decode and update beam status, and then return active beam idx '''
83 |
84 | def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq):
85 | dec_partial_seq = [b.get_current_state() for b in inst_dec_beams if not b.done]
86 | dec_partial_seq = torch.stack(dec_partial_seq).to(self.device)
87 | dec_partial_seq = dec_partial_seq.view(-1, len_dec_seq)
88 | return dec_partial_seq
89 |
90 | def prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm):
91 | dec_partial_pos = torch.arange(1, len_dec_seq + 1, dtype=torch.long, device=self.device)
92 | dec_partial_pos = dec_partial_pos.unsqueeze(0).repeat(n_active_inst * n_bm, 1)
93 | return dec_partial_pos
94 |
95 | def predict_word(dec_seq, dec_pos, src_seq, enc_output, n_active_inst, n_bm):
96 | dec_output, *_ = self.model.decoder(dec_seq, dec_pos, src_seq, enc_output)
97 | dec_output = dec_output[:, -1, :] # Pick the last step: (bh * bm) * d_h
98 | word_prob = F.log_softmax(self.model.tgt_word_prj(dec_output), dim=1)
99 | word_prob = word_prob.view(n_active_inst, n_bm, -1)
100 |
101 | return word_prob
102 |
103 | def collect_active_inst_idx_list(inst_beams, word_prob, inst_idx_to_position_map):
104 | active_inst_idx_list = []
105 | for inst_idx, inst_position in inst_idx_to_position_map.items():
106 | is_inst_complete = inst_beams[inst_idx].advance(word_prob[inst_position])
107 | if not is_inst_complete:
108 | active_inst_idx_list += [inst_idx]
109 |
110 | return active_inst_idx_list
111 |
112 | n_active_inst = len(inst_idx_to_position_map)
113 |
114 | dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq)
115 | dec_pos = prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm)
116 | word_prob = predict_word(dec_seq, dec_pos, src_seq, enc_output, n_active_inst, n_bm)
117 |
118 | # Update the beam with predicted word prob information and collect incomplete instances
119 | active_inst_idx_list = collect_active_inst_idx_list(
120 | inst_dec_beams, word_prob, inst_idx_to_position_map)
121 |
122 | return active_inst_idx_list
123 |
124 | def collect_hypothesis_and_scores(inst_dec_beams, n_best):
125 | all_hyp, all_scores = [], []
126 | for inst_idx in range(len(inst_dec_beams)):
127 | scores, tail_idxs = inst_dec_beams[inst_idx].sort_scores()
128 | all_scores += [scores[:n_best]]
129 |
130 | hyps = [inst_dec_beams[inst_idx].get_hypothesis(i) for i in tail_idxs[:n_best]]
131 | all_hyp += [hyps]
132 | return all_hyp, all_scores
133 |
134 | with torch.no_grad():
135 | #-- Encode
136 | src_seq, src_pos = src_seq.to(self.device), src_pos.to(self.device)
137 | src_enc, *_ = self.model.encoder(src_seq, src_pos)
138 |
139 | #-- Repeat data for beam search
140 | n_bm = self.opt.beam_size
141 | n_inst, len_s, d_h = src_enc.size()
142 | src_seq = src_seq.repeat(1, n_bm).view(n_inst * n_bm, len_s)
143 | src_enc = src_enc.repeat(1, n_bm, 1).view(n_inst * n_bm, len_s, d_h)
144 |
145 | #-- Prepare beams
146 | inst_dec_beams = [Beam(n_bm, device=self.device) for _ in range(n_inst)]
147 |
148 | #-- Bookkeeping for active or not
149 | active_inst_idx_list = list(range(n_inst))
150 | inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list)
151 |
152 | #-- Decode
153 | for len_dec_seq in range(1, self.model_opt.max_token_seq_len + 1):
154 |
155 | active_inst_idx_list = beam_decode_step(
156 | inst_dec_beams, len_dec_seq, src_seq, src_enc, inst_idx_to_position_map, n_bm)
157 |
158 | if not active_inst_idx_list:
159 | break # all instances have finished their path to <EOS>
160 |
161 | src_seq, src_enc, inst_idx_to_position_map = collate_active_info(
162 | src_seq, src_enc, inst_idx_to_position_map, active_inst_idx_list)
163 |
164 | batch_hyp, batch_scores = collect_hypothesis_and_scores(inst_dec_beams, self.opt.n_best)
165 |
166 | return batch_hyp, batch_scores
167 |
--------------------------------------------------------------------------------
/transformer/__init__.py:
--------------------------------------------------------------------------------
1 | import transformer.Constants
2 | import transformer.Modules
3 | import transformer.Layers
4 | import transformer.SubLayers
5 | import transformer.Models
6 | import transformer.Translator
7 | import transformer.Beam
8 | import transformer.Optim
9 |
10 | __all__ = [
11 | transformer.Constants, transformer.Modules, transformer.Layers,
12 | transformer.SubLayers, transformer.Models, transformer.Optim,
13 | transformer.Translator, transformer.Beam]
14 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import yaml
3 |
4 | from models import get_model_class
5 |
6 |
7 | def strip_eos(sents):
8 | return [sent[:sent.index('<eos>')] if '<eos>' in sent else sent
9 | for sent in sents]
10 |
11 |
12 | def makedir(path):
13 | dir = os.path.dirname(path)
14 | if dir:
15 | os.makedirs(dir, exist_ok=True)
16 |
17 |
18 | def repeat(f, x, n):
19 | for i in range(n):
20 | x = f(x)
21 | return x
22 |
23 |
24 | def get_hparams(checkpoint):
25 | hparams_file = os.path.join(os.path.dirname(os.path.dirname(checkpoint)), 'hparams.yaml')
26 | with open(hparams_file) as stream:
27 | return yaml.safe_load(stream)
28 |
29 |
30 | def load_model(checkpoint):
31 | hparams = get_hparams(checkpoint)
32 | model = get_model_class(hparams['model_type']).load_from_checkpoint(checkpoint, hparams=hparams)
33 | model.hparams.root_dir = repeat(lambda x: os.path.dirname(x), checkpoint, 4)
34 | return model
35 |
36 |
37 | def load_sent(path, add_eos=False):
38 | sents = []
39 | with open(path) as f:
40 | for line in f:
41 | s = line.split()
42 | if add_eos:
43 | s += ['<eos>']
44 | sents.append(s)
45 | return sents
46 |
47 |
48 | def load_data(path, add_eos=False, cat_sent=False, max_len=512):
49 | if not add_eos:
50 | print('WARNING: You should always use add_eos to get comparable PPL on '
51 | 'language modeling tasks')
52 |
53 | sents = load_sent(path, add_eos)
54 | if cat_sent:
55 | if not add_eos:
56 | raise ValueError('Using cat_sent without add_eos')
57 | d = [w for s in sents for w in s]
58 | data = [d[i: i + max_len] for i in range(0, len(d), max_len)]
59 | else:
60 | print('# truncated sentences:',
61 | sum(1 for s in sents if len(s) > max_len))
62 | data = [s[:max_len] for s in sents]
63 | return data
64 |
65 |
66 | def write(file, sents, write_mid):
67 | sents = strip_eos(sents)
68 | if write_mid:
69 | for s in sents:
70 | file.write(' '.join(s) + '\n')
71 | file.write('\n')
72 | else:
73 | file.write(' '.join(sents[-1]) + '\n')
74 | file.flush()
75 |
--------------------------------------------------------------------------------
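Note: a short sketch of load_data's two modes; 'data/train.txt' is an illustrative path for a whitespace-tokenized file with one sentence per line:

from utils import load_data

# sentence mode: each line is one example, truncated to max_len tokens
sents = load_data('data/train.txt', add_eos=True, cat_sent=False, max_len=512)

# concat mode (--cat_sent): all tokens are concatenated and re-chunked into
# contiguous pieces of max_len tokens (except possibly the last one)
chunks = load_data('data/train.txt', add_eos=True, cat_sent=True, max_len=512)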
/vocab.py:
--------------------------------------------------------------------------------
1 | from collections import Counter
2 |
3 |
4 | class Vocab(object):
5 | def __init__(self, path):
6 | self.word2idx = {}
7 | self.idx2word = []
8 |
9 | with open(path) as f:
10 | for line in f:
11 | w = line.split()[0]
12 | self.word2idx[w] = len(self.word2idx)
13 | self.idx2word.append(w)
14 | self.size = len(self.word2idx)
15 |
16 | pad, unk, first, last, eos, blank, blank_0 = range(7)
17 |
18 | @staticmethod
19 | def build(sents, path, size, max_blank_len=None):
20 | voc = ['<pad>', '<unk>', '<first>', '<last>', '<eos>', '<blank>', '<blank_0>']
21 | if max_blank_len:
22 | voc += ['<blank_{}>'.format(i) for i in range(1, max_blank_len)]
23 | occ = [0 for _ in voc]
24 |
25 | cnt = Counter([w for s in sents for w in s])
26 | for i, v in enumerate(voc):
27 | if v in cnt:
28 | occ[i] = cnt[v]
29 | del cnt[v]
30 | for v, o in cnt.most_common(size):
31 | voc.append(v)
32 | occ.append(o)
33 | for v, o in cnt.most_common()[size:]:
34 | occ[Vocab.unk] += o
35 |
36 | with open(path, 'w') as f:
37 | for v, o in zip(voc, occ):
38 | f.write('{}\t{}\n'.format(v, o))
39 |
40 | def word_to_idx(self, word):
41 | return self.word2idx[word] if word in self.word2idx else Vocab.unk
42 |
--------------------------------------------------------------------------------
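Note: a minimal sketch of building and querying a Vocab on toy sentences (the tokens and the temporary path are illustrative):

import os
import tempfile

from vocab import Vocab

sents = [['the', 'cat', 'sat', '<eos>'], ['the', 'dog', '<eos>']]
path = os.path.join(tempfile.mkdtemp(), 'vocab.txt')
Vocab.build(sents, path, size=100)

vocab = Vocab(path)
print(vocab.size)                   # 7 special tokens + the observed words
print(vocab.word_to_idx('the'))     # frequent words get their own index (here 7)
print(vocab.word_to_idx('zebra'))   # unseen words map to Vocab.unk == 1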