├── .gitignore ├── LICENSE ├── Layers.py ├── Model.py ├── README.md ├── train.py ├── transformer_xl ├── Layers.py ├── Transformer_xl.py └── __init__.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | .idea/ 4 | Transformer-master/ 5 | 6 | 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | # *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | pip-wheel-metadata/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # celery beat schedule file 98 | celerybeat-schedule 99 | 100 | # SageMath parsed files 101 | *.sage.py 102 | 103 | # Environments 104 | .env 105 | .venv 106 | env/ 107 | venv/ 108 | ENV/ 109 | env.bak/ 110 | venv.bak/ 111 | 112 | # Spyder project settings 113 | .spyderproject 114 | .spyproject 115 | 116 | # Rope project settings 117 | .ropeproject 118 | 119 | # mkdocs documentation 120 | /site 121 | 122 | # mypy 123 | .mypy_cache/ 124 | .dmypy.json 125 | dmypy.json 126 | 127 | # Pyre type checker 128 | .pyre/ 129 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [2019] [cyk1337] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /Layers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # -*- encoding: utf-8 4 | 5 | ''' 6 | _____.___._______________ __.____ __________ _________ ___ ___ _____ .___ 7 | \__ | |\_ _____/ |/ _| | \ \ \_ ___ \ / | \ / _ \ | | 8 | / | | | __)_| < | | / | \ / \ \// ~ \/ /_\ \| | 9 | \____ | | \ | \| | / | \ \ \___\ Y / | \ | 10 | / ______|/_______ /____|__ \______/\____|__ / \______ /\___|_ /\____|__ /___| 11 | \/ \/ \/ \/ \/ \/ \/ 12 | 13 | ========================================================================================== 14 | 15 | @author: Yekun Chai 16 | 17 | @license: School of Informatics, Edinburgh 18 | 19 | @contact: chaiyekun@gmail.com 20 | 21 | @file: Layers.py 22 | 23 | @time: 29/09/2019 20:51 24 | 25 | @desc: 26 | 27 | ''' 28 | import torch 29 | import torch.nn as nn 30 | import torch.nn.functional as F 31 | from torch.autograd import Variable 32 | 33 | import matplotlib.pyplot as plt 34 | import math 35 | import numpy as np 36 | 37 | import utils 38 | 39 | 40 | class LayerNorm(nn.Module): 41 | """ layer norm""" 42 | 43 | def __init__(self, features, eps=1e-6): 44 | super(LayerNorm, self).__init__() 45 | self.weight = nn.Parameter(torch.ones(features)) 46 | self.bias = nn.Parameter(torch.zeros(features)) 47 | self.eps = eps 48 | 49 | def forward(self, x): 50 | mean = x.mean(-1, keepdim=True) 51 | std = x.std(-1, keepdim=True) 52 | return self.weight * (x - mean) / (std + self.eps) + self.bias 53 | 54 | 55 | class SublayerConnection(nn.Module): 56 | """ 57 | a residual connection followed by a layer norm 58 | """ 59 | 60 | def __init__(self, size, dropout): 61 | super(SublayerConnection, self).__init__() 62 | self.norm = LayerNorm(size) 63 | self.dropout = nn.Dropout(dropout) 64 | 65 | def forward(self, x, sublayer): 66 | """ Apply residual connection to any sublayer with the same size""" 67 | return x + self.dropout(sublayer(self.norm(x))) 68 | 69 | 70 | class EncoderLayer(nn.Module): 71 | """ encoder consists of a self-attn and ffc""" 72 | 73 | def __init__(self, size, self_attn, feed_forward, dropout): 74 | super(EncoderLayer, self).__init__() 75 | self.self_attn = self_attn 76 | self.feed_forward = feed_forward 77 | self.sublayer = utils.clones(SublayerConnection(size, dropout), 2) 78 | self.size = size 79 | self.local_rnn = LocalRNNLayer(size, dropout) 80 | 81 | def forward(self, x, mask): 82 | x = self.local_rnn(x) 83 | x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask)) 84 | return self.sublayer[1](x, self.feed_forward) 85 | 86 | 87 | class DecoderLayer(nn.Module): 88 | """ decoder""" 89 | 90 | def __init__(self, size, self_attn, src_attn, feed_forward, dropout): 91 | super(DecoderLayer, self).__init__() 92 | self.size = size 93 | self.self_attn = self_attn 94 | self.src_attn = src_attn 95 | self.feed_forward = feed_forward 96 | self.sublyer = utils.clones(SublayerConnection(size, dropout), 3) 97 | 98 | def forward(self, x, memory, src_mask, tgt_mask): 99 | m = 
memory 100 | x = self.sublyer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask)) 101 | x = self.sublyer[1](x, lambda x: self.src_attn(x, m, m, src_mask)) 102 | return self.sublyer[2](x, self.feed_forward) 103 | 104 | 105 | def attention(query, key, value, mask=None, dropout=None): 106 | """ 107 | scaled dot product 108 | --------------------------- 109 | L : target sequence length 110 | S : source sequence length: 111 | N : batch size 112 | E : embedding dim 113 | 114 | h : # of attn head 115 | d_k: E // h 116 | --------------------------- 117 | :param query: (N, h, L, d_k) 118 | :param key: (N, h, S, d_k) 119 | :param value: (N, h, S, d_k) 120 | :param mask: 121 | :param dropout: float 122 | :return: 123 | """ 124 | d_k = query.size(-1) 125 | # (nbatch, h, seq_len, d_k) @ (nbatch, h, d_k, seq_len) => (nbatch, h, seq_len, seq_len) 126 | scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k) 127 | if mask is not None: 128 | scores = scores.masked_fill(mask == 0, -1e9) 129 | p_attn = F.softmax(scores, dim=-1) 130 | if dropout: 131 | p_attn = dropout(p_attn) 132 | # (nbatch, h, seq_len, seq_len) * (nbatch, h, seq_len, d_k) = > (nbatch, h, seq_len, d_k) 133 | return torch.matmul(p_attn, value), p_attn 134 | 135 | 136 | class MultiHeadedAttention(nn.Module): 137 | def __init__(self, d_model, h, dropout=0.1): 138 | """ 139 | multi-head attention 140 | :param h: nhead 141 | :param d_model: d_model 142 | :param dropout: float 143 | """ 144 | super(MultiHeadedAttention, self).__init__() 145 | assert d_model % h == 0 146 | # assume d_v always equals d_k 147 | self.d_k = d_model // h 148 | self.h = h 149 | self.linears = utils.clones(nn.Linear(d_model, d_model), 4) 150 | self.attn = None 151 | self.dropout = nn.Dropout(p=dropout) 152 | 153 | def forward(self, query, key, value, mask=None): 154 | """ 155 | --------------------------- 156 | L : target sequence length 157 | S : source sequence length: 158 | N : batch size 159 | E : embedding dim 160 | --------------------------- 161 | :param query: (N,L,E) 162 | :param key: (N,S,E) 163 | :param value: (N,S,E) 164 | :param mask: 165 | """ 166 | if mask is not None: 167 | # Same mask applied to all h heads. 168 | mask = mask.unsqueeze(1) 169 | nbatches = query.size(0) # batch size 170 | 171 | # 1) split embedding dim to h heads : from d_model => h * d_k 172 | # dim: (nbatch, h, seq_length, d_model//h) 173 | query, key, value = \ 174 | [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2) 175 | for l, x in zip(self.linears, (query, key, value))] 176 | 177 | # 2) compute attention 178 | x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout) 179 | 180 | # 3) "Concat" using a view and apply a final linear. 
181 | # dim: (nbatch, h, d_model) 182 | x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k) 183 | return self.linears[-1](x) 184 | 185 | 186 | class PositionwiseFeedForward(nn.Module): 187 | """ FFN """ 188 | 189 | def __init__(self, d_model, d_ff, dropout=.1): 190 | super(PositionwiseFeedForward, self).__init__() 191 | self.w_1 = nn.Linear(d_model, d_ff) 192 | self.w_2 = nn.Linear(d_ff, d_model) 193 | self.dropout = nn.Dropout(dropout) 194 | 195 | def forward(self, x): 196 | # return self.w_2(self.dropout(F.relu(self.w_1(x)))) 197 | 198 | ## Swish 199 | x = self.w_1(x) 200 | x *= F.sigmoid(x) 201 | return self.w_2(self.dropout(x)) 202 | 203 | 204 | class LocalRNNLayer(nn.Module): 205 | def __init__(self, size, dropout=.0): 206 | super(LocalRNNLayer, self).__init__() 207 | self.local_rnn = LocalRNN(size, size, window_size=5) 208 | self.sublayer = SublayerConnection(size, dropout) 209 | 210 | def forward(self, x): 211 | return self.sublayer(x, self.local_rnn) 212 | 213 | 214 | class LocalRNN(nn.Module): 215 | """ R transformer""" 216 | 217 | def __init__(self, input_size, output_size, window_size, rnn_type='GRU', MAX_LENGTH=10000): 218 | super(LocalRNN, self).__init__() 219 | self.window_size = window_size 220 | if rnn_type == 'GRU': 221 | # set `batch_first`=True so that the input and output dim are both (nBatch, seq_len, d_model) 222 | self.rnn = nn.GRU(output_size, output_size, batch_first=True) 223 | elif rnn_type == 'LSTM': 224 | self.rnn = nn.LSTM(output_size, output_size, batch_first=True) 225 | else: 226 | self.rnn = nn.RNN(output_size, output_size, batch_first=True) 227 | # self.output = nn.Sequential(nn.Linear(output_size, output_size), nn.ReLU()) 228 | 229 | # generate segments according to window_size. 230 | # -> e.g. window size = 4, generate [1,2,3,4, 231 | # 2,3,4,5, 232 | # 3,4,5,6, 233 | # 4,5,6,7, 234 | # ... 235 | # MAX_LEN - 1 -k ,... , MAX_LEN-2, MAX_LEN-1] 236 | idx = [i for j in range(window_size - 1, MAX_LENGTH) for i in range(j - (window_size - 1), j + 1)] 237 | self.idx = torch.LongTensor(idx) 238 | # padding (k-1) before the beginning of the sequence 239 | self.zeros_pad = torch.zeros((window_size - 1, input_size)) 240 | 241 | def forward(self, x): 242 | """ regard window size dim as batch dim""" 243 | assert x.dim() == 3, '3 dimensions of input expected!' 
244 | nbatches, seq_len, d_model = x.size() 245 | 246 | x = self._gather_seg_sequence(x) 247 | output, _ = self.rnn(x) 248 | h_last_per_batch = output[:, -1, :] 249 | return h_last_per_batch.view(nbatches, seq_len, d_model) 250 | 251 | def _gather_seg_sequence(self, x): 252 | nbatch, seq_len, d_model = x.size() 253 | # use `repeat` to pad one batch -> (nbatch, k01, input_size) 254 | zeros = self.zeros_pad.repeat(nbatch, 1, 1) 255 | # concat padded zeros and the sequence along the sequence dim 256 | x = torch.cat((zeros, x), dim=1) 257 | # gather the corresponding embeddings along the sequence dim (1) 258 | idx = self.idx[:self.window_size * seq_len] # 259 | x_ = torch.index_select(input=x, dim=1, index=idx) 260 | # reshape -> (bsz * seq_len, window_size, d_model) 261 | x_ = x_.reshape(nbatch * seq_len, self.window_size, -1) 262 | return x_ 263 | 264 | 265 | class Embeddings(nn.Module): 266 | def __init__(self, d_model, vocab): 267 | super(Embeddings, self).__init__() 268 | self.lut = nn.Embedding(vocab, d_model) 269 | self.d_model = d_model 270 | 271 | def forward(self, x): 272 | return self.lut(x) * math.sqrt(self.d_model) # to make positional encoding smaller 273 | 274 | 275 | class PositionalEncoding(nn.Module): 276 | def __init__(self, d_model, dropout, max_len=5000): 277 | super(PositionalEncoding, self).__init__() 278 | self.dropout = nn.Dropout(p=dropout) 279 | 280 | pe = torch.zeros(max_len, d_model) 281 | position = torch.arange(0., max_len).unsqueeze(1) 282 | div_term = torch.exp(torch.arange(0., d_model, 2) * - (math.log(1e4) / d_model)) 283 | pe[:, ::2] = torch.sin(position * div_term) 284 | pe[:, 1::2] = torch.cos(position * div_term) 285 | pe = pe.unsqueeze(0) 286 | self.register_buffer('pe', pe) 287 | 288 | def forward(self, x): 289 | x = x + Variable(self.pe[:, :x.size(1)], requires_grad=False) 290 | return self.dropout(x) 291 | 292 | 293 | class MultiHeadedAttention_RPR(nn.Module): 294 | """ @ author: Yekun CHAI """ 295 | 296 | def __init__(self, d_model, h, max_relative_position, dropout=.0): 297 | """ 298 | multi-head attention 299 | :param h: nhead 300 | :param d_model: d_model 301 | :param dropout: float 302 | """ 303 | super(MultiHeadedAttention_RPR, self).__init__() 304 | assert d_model % h == 0 305 | # assume d_v always equals d_k 306 | self.d_k = d_model // h 307 | self.h = h 308 | self.linears = utils.clones(nn.Linear(d_model, d_model), 4) 309 | self.dropout = nn.Dropout(p=dropout) 310 | 311 | self.max_relative_position = max_relative_position 312 | self.vocab_size = max_relative_position * 2 + 1 313 | self.embed_K = nn.Embedding(self.vocab_size, self.d_k) 314 | self.embed_V = nn.Embedding(self.vocab_size, self.d_k) 315 | 316 | def forward(self, query, key, value, mask=None): 317 | """ 318 | --------------------------- 319 | L : target sequence length 320 | S : source sequence length: 321 | N : batch size 322 | E : embedding dim 323 | --------------------------- 324 | :param query: (N,L,E) 325 | :param key: (N,S,E) 326 | :param value: (N,S,E) 327 | :param mask: 328 | """ 329 | nbatches = query.size(0) # batch size 330 | seq_len = query.size(1) 331 | # 1) split embedding dim to h heads : from d_model => h * d_k 332 | # dim: (nbatch, h, seq_length, d_model//h) 333 | query, key, value = \ 334 | [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2) 335 | for l, x in zip(self.linears, (query, key, value))] 336 | 337 | # 2) rpr 338 | relation_keys = self.generate_relative_positions_embeddings(seq_len, seq_len, self.embed_K) 339 | relation_values = 
self.generate_relative_positions_embeddings(seq_len, seq_len, self.embed_V) 340 | logits = self._relative_attn_inner(query, key, relation_keys, True) 341 | weights = self.dropout(F.softmax(logits, -1)) 342 | x = self._relative_attn_inner(weights, value, relation_values, False) 343 | # 3) "Concat" using a view and apply a final linear. 344 | # dim: (nbatch, h, d_model) 345 | x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k) 346 | return self.linears[-1](x) 347 | 348 | def _generate_relative_positions_matrix(self, len_q, len_k): 349 | """ 350 | genetate rpr matrix 351 | --------------------------- 352 | :param len_q: seq_len 353 | :param len_k: seq_len 354 | :return: rpr matrix, dim: (len_q, len_q) 355 | """ 356 | assert len_q == len_k 357 | range_vec_q = range_vec_k = torch.arange(len_q) 358 | distance_mat = range_vec_k.unsqueeze(0) - range_vec_q.unsqueeze(-1) 359 | disntance_mat_clipped = torch.clamp(distance_mat, -self.max_relative_position, self.max_relative_position) 360 | return disntance_mat_clipped + self.max_relative_position 361 | 362 | def generate_relative_positions_embeddings(self, len_q, len_k, embedding_table): 363 | """ 364 | generate relative position embedding 365 | ---------------------- 366 | :param len_q: 367 | :param len_k: 368 | :return: rpr embedding, dim: (len_q, len_q, d_k) 369 | """ 370 | relative_position_matrix = self._generate_relative_positions_matrix(len_q, len_k) 371 | return embedding_table(relative_position_matrix) 372 | 373 | def _relative_attn_inner(self, x, y, z, transpose): 374 | """ 375 | efficient implementation 376 | ------------------------ 377 | :param x: 378 | :param y: 379 | :param z: 380 | :param transpose: 381 | :return: 382 | """ 383 | nbatches = x.size(0) 384 | heads = x.size(1) 385 | seq_len = x.size(2) 386 | 387 | # (N, h, s, s) 388 | xy_matmul = torch.matmul(x, y.transpose(-1, -2) if transpose else y) 389 | # (s, N, h, d) => (s, N*h, d) 390 | x_t_v = x.permute(2, 0, 1, 3).contiguous().view(seq_len, nbatches * heads, -1) 391 | # (s, N*h, d) @ (s, d, s) => (s, N*h, s) 392 | x_tz_matmul = torch.matmul(x_t_v, z.transpose(-1, -2) if transpose else z) 393 | # (N, h, s, s) 394 | x_tz_matmul_v_t = x_tz_matmul.view(seq_len, nbatches, heads, -1).permute(1, 2, 0, 3) 395 | return xy_matmul + x_tz_matmul_v_t 396 | 397 | 398 | if __name__ == '__main__': 399 | # import os 400 | # 401 | # os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" 402 | # 403 | # plt.figure(figsize=(15, 5)) 404 | # pe = PositionalEncoding(20, 0) 405 | # y = pe.forward(Variable(torch.zeros(1, 100, 20))) 406 | # plt.plot(np.arange(100), y[0, :, 4:8].data.numpy()) 407 | # plt.legend(["dim %d" % p for p in list(range(4, 8))]) 408 | # plt.show() 409 | pe = MultiHeadedAttention_RPR(8, 256) 410 | x = torch.randn((64, 10, 256)) 411 | y = pe(x, x, x) 412 | print(y.size()) 413 | -------------------------------------------------------------------------------- /Model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # -*- encoding: utf-8 4 | 5 | ''' 6 | _____.___._______________ __.____ __________ _________ ___ ___ _____ .___ 7 | \__ | |\_ _____/ |/ _| | \ \ \_ ___ \ / | \ / _ \ | | 8 | / | | | __)_| < | | / | \ / \ \// ~ \/ /_\ \| | 9 | \____ | | \ | \| | / | \ \ \___\ Y / | \ | 10 | / ______|/_______ /____|__ \______/\____|__ / \______ /\___|_ /\____|__ /___| 11 | \/ \/ \/ \/ \/ \/ \/ 12 | 13 | ========================================================================================== 14 | 15 | @author: 
Yekun Chai 16 | 17 | @license: School of Informatics, Edinburgh 18 | 19 | @contact: chaiyekun@gmail.com 20 | 21 | @file: Model.py 22 | 23 | @time: 29/09/2019 20:25 24 | 25 | @desc: 26 | 27 | ''' 28 | 29 | import torch 30 | import torch.nn as nn 31 | import torch.nn.functional as F 32 | 33 | import numpy as np 34 | import math, copy, time 35 | import seaborn 36 | 37 | seaborn.set_context(context="talk") 38 | 39 | import utils 40 | from Layers import MultiHeadedAttention, PositionwiseFeedForward, PositionalEncoding, EncoderLayer, DecoderLayer, \ 41 | Embeddings, MultiHeadedAttention_RPR 42 | 43 | 44 | class EncoderDecoder(nn.Module): 45 | """ 46 | standard encoder decoder architecture 47 | """ 48 | 49 | def __init__(self, encoder, decoder, src_embed, tgt_embed, generator): 50 | super(EncoderDecoder, self).__init__() 51 | self.encoder = encoder 52 | self.decoder = decoder 53 | self.src_embed = src_embed 54 | self.tgt_embed = tgt_embed 55 | self.generator = generator 56 | 57 | def forward(self, src, tgt, src_mask, tgt_mask): 58 | return self.decode(self.encode(src, src_mask), src_mask, tgt, tgt_mask) 59 | 60 | def encode(self, src, src_mask): 61 | return self.encoder(self.src_embed(src), src_mask) 62 | 63 | def decode(self, memory, src_mask, tgt, tgt_mask): 64 | return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask) 65 | 66 | 67 | class Generator(nn.Module): 68 | def __init__(self, d_model, vocab): 69 | super(Generator, self).__init__() 70 | self.proj = nn.Linear(d_model, vocab) 71 | 72 | def forward(self, x): 73 | return F.softmax(self.proj(x), dim=-1) 74 | 75 | 76 | class Encoder(nn.Module): 77 | """ Core encoder -> a stack of N layers """ 78 | 79 | def __init__(self, layer, N): 80 | super(Encoder, self).__init__() 81 | self.layers = utils.clones(layer, N) 82 | size = layer.size 83 | self.norm = nn.LayerNorm(size) 84 | 85 | def forward(self, x, mask): 86 | """ pass input and mask through each layer in turn""" 87 | 88 | for layer in self.layers: 89 | x = layer(x, mask) 90 | return self.norm(x) 91 | 92 | 93 | class Decoder(nn.Module): 94 | """ N layer decoder with masking""" 95 | 96 | def __init__(self, layer, N): 97 | super(Decoder, self).__init__() 98 | self.layers = utils.clones(layer, N) 99 | self.norm = nn.LayerNorm(layer.size) 100 | 101 | def forward(self, x, memory, src_mask, tgt_mask): 102 | for layer in self.layers: 103 | x = layer(x, memory, src_mask, tgt_mask) 104 | return self.norm(x) 105 | 106 | 107 | def make_model(src_vocab, tgt_vocab, N=6, d_model=512, d_ff=2048, h=8, dropout=.1): 108 | """ construct model from hyper-parameters""" 109 | c = copy.deepcopy 110 | attn_rpr = MultiHeadedAttention_RPR(d_model, h, max_relative_position=5) 111 | attn = MultiHeadedAttention(d_model, h) 112 | ff = PositionwiseFeedForward(d_model, d_ff, dropout) 113 | position = PositionalEncoding(d_model, dropout) 114 | model = EncoderDecoder( 115 | 116 | Encoder(EncoderLayer(d_model, c(attn_rpr), c(ff), dropout), N), 117 | Decoder(DecoderLayer(d_model, c(attn_rpr), c(attn), c(ff), dropout), N), 118 | nn.Sequential(Embeddings(d_model, src_vocab), c(position)), 119 | nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)), 120 | Generator(d_model, tgt_vocab) 121 | ) 122 | 123 | for p in model.parameters(): 124 | if p.dim() > 1: 125 | nn.init.xavier_uniform_(p) 126 | return model 127 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Transformer-in-PyTorch 2 | 
![GitHub](https://img.shields.io/github/license/cyk1337/Transformer-in-Pytorch) 3 | - This repo contains the Transformer variants implementation in PyTorch (`Transformer` / `Transformer-XL` / `R-Transformer`). PR is welcome. 4 | - If you are unfamilar with Transformer and its variants, refer to my blog: [transformer explanation](http://ychai.uk/notes/2019/01/22/NLP/Attention-in-a-nutshell/#Transformer). 5 | 6 | ## Citation 7 | For attribution in academic contexts, please cite this work as: 8 | ``` 9 | @misc{chai2019-transformer-in-pytorch, 10 | author = {Chai, Yekun}, 11 | title = {Transformer-in-PyTorch}, 12 | year = {2019}, 13 | publisher = {GitHub}, 14 | journal = {GitHub repository}, 15 | howpublished = {\url{https://github.com/cyk1337/Transformer-in-PyTorch}} 16 | } 17 | 18 | @misc{chai2019attn-summary, 19 | author = {Chai, Yekun}, 20 | title = {{Attention in a Nutshell}}, 21 | year = {2019}, 22 | howpublished = {\url{http://cyk1337.github.io/notes/2019/01/22/NLP/Attention-in-a-nutshell/}}, 23 | } 24 | ``` 25 | 26 | References: 27 | - [The annotated Transformer](https://nlp.seas.harvard.edu/2018/04/03/attention.html) 28 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # -*- encoding: utf-8 4 | 5 | ''' 6 | _____.___._______________ __.____ __________ _________ ___ ___ _____ .___ 7 | \__ | |\_ _____/ |/ _| | \ \ \_ ___ \ / | \ / _ \ | | 8 | / | | | __)_| < | | / | \ / \ \// ~ \/ /_\ \| | 9 | \____ | | \ | \| | / | \ \ \___\ Y / | \ | 10 | / ______|/_______ /____|__ \______/\____|__ / \______ /\___|_ /\____|__ /___| 11 | \/ \/ \/ \/ \/ \/ \/ 12 | 13 | ========================================================================================== 14 | 15 | @author: Yekun Chai 16 | 17 | @license: School of Informatics, Edinburgh 18 | 19 | @contact: chaiyekun@gmail.com 20 | 21 | @file: train.py 22 | 23 | @time: 30/09/2019 15:02 24 | 25 | @desc: 26 | 27 | ''' 28 | import time 29 | import torch 30 | from torch.autograd import Variable 31 | import numpy as np 32 | 33 | import utils 34 | from Model import make_model 35 | 36 | 37 | class Batch: 38 | def __init__(self, src, trg=None, pad=0): 39 | self.src = src 40 | self.src_mask = (src != pad).unsqueeze(-2) 41 | if trg is not None: 42 | self.trg = trg[:, :-1] 43 | self.trg_y = trg[:, 1:] 44 | self.trg_mask = self.make_std_mask(self.trg, pad) 45 | self.ntokens = (self.trg_y != pad).data.sum() 46 | 47 | @staticmethod 48 | def make_std_mask(tgt, pad): 49 | """ create a mask to hide padding and future words""" 50 | tgt_mask = (tgt != pad).unsqueeze(-2) # pad mask 51 | tgt_mask = tgt_mask & Variable(utils.subsequent_mask(tgt.size(-1)).type_as(tgt_mask.data)) 52 | return tgt_mask 53 | 54 | 55 | def run_epoch(data_iter, model, loss_compute): 56 | "Standard Training and Logging Function" 57 | start = time.time() 58 | total_tokens = 0 59 | total_loss = 0 60 | tokens = 0 61 | for i, batch in enumerate(data_iter): 62 | out = model.forward(batch.src, batch.trg, 63 | batch.src_mask, batch.trg_mask) 64 | loss = loss_compute(out, batch.trg_y, batch.ntokens) 65 | total_loss += loss 66 | total_tokens += batch.ntokens 67 | tokens += batch.ntokens 68 | if i % 50 == 1: 69 | elapsed = time.time() - start 70 | print("Epoch Step: %d Loss: %f Tokens per Sec: %f" % 71 | (i, loss / batch.ntokens, tokens / elapsed)) 72 | start = time.time() 73 | tokens = 0 74 | return total_loss / total_tokens 75 | 76 | 77 | global 
max_src_in_batch, max_tgt_in_batch 78 | 79 | 80 | def batch_size_fn(new, count, sofar): 81 | global max_src_in_batch, max_tgt_in_batch 82 | if count == 1: 83 | max_src_in_batch = 0 84 | max_tgt_in_batch = 0 85 | max_src_in_batch = max(max_src_in_batch, len(new.src)) 86 | max_tgt_in_batch = max(max_tgt_in_batch, len(new.trg) + 2) 87 | src_elements = count * max_src_in_batch 88 | tgt_elements = count * max_tgt_in_batch 89 | return max(src_elements, tgt_elements) 90 | 91 | 92 | class NoamOpt: 93 | """ Optim wrapper """ 94 | 95 | def __init__(self, model_size, factor, warmup, optimizer): 96 | self.optimizer = optimizer 97 | self._step = 0 98 | self.warmup = warmup 99 | self.factor = factor 100 | self.model_size = model_size 101 | self._rate = 0 102 | 103 | def step(self): 104 | """update parameters and rate""" 105 | self._step += 1 106 | rate = self.rate() 107 | for p in self.optimizer.param_groups: 108 | p['lr'] = rate 109 | self._rate = rate 110 | self.optimizer.step() 111 | 112 | def rate(self, step=None): 113 | if step is None: 114 | step = self._step 115 | return self.factor * (self.model_size ** (-.5) * min(step ** (-.5), step * self.warmup ** (-1.5))) 116 | 117 | 118 | def get_std_opt(model): 119 | return NoamOpt(model.src_embed[0].d_model, 2, 4000, 120 | torch.optim.Adam(model.parameters(), lr=0, betas=(.9, .98), eps=1e-9)) 121 | 122 | 123 | def data_gen(V, batch, nbatches): 124 | """ Generate random data for a src-tgt copy task""" 125 | for i in range(nbatches): 126 | data = torch.from_numpy(np.random.randint(1, V, size=(batch, 10))) 127 | data[:, 0] = 1 128 | src = Variable(data, requires_grad=False) 129 | tgt = Variable(data, requires_grad=False) 130 | yield Batch(src, tgt, 0) 131 | 132 | 133 | class SimpleLossCompute: 134 | "A simple loss compute and train function." 
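# Projects decoder outputs through the generator, evaluates the criterion on the flattened
# predictions and targets, normalises by the batch token count, backpropagates, and, when an
# optimizer wrapper is supplied, performs the update step before zeroing the gradients.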
135 | 136 | def __init__(self, generator, criterion, opt=None): 137 | self.generator = generator 138 | self.criterion = criterion 139 | self.opt = opt 140 | 141 | def __call__(self, x, y, norm): 142 | x = self.generator(x) 143 | loss = self.criterion(x.contiguous().view(-1, x.size(-1)), 144 | y.contiguous().view(-1)) / norm 145 | loss.backward() 146 | if self.opt is not None: 147 | self.opt.step() 148 | self.opt.optimizer.zero_grad() 149 | return loss.data.item() * norm 150 | 151 | 152 | def run_test(): 153 | V = 11 154 | criterion = utils.LabelSmoothing(size=V, padding_idx=0, smoothing=0.0) 155 | model = make_model(V, V, N=2, d_model=256, d_ff=256, dropout=.0) 156 | model_opt = NoamOpt(model.src_embed[0].d_model, 1, 400, 157 | torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)) 158 | 159 | for epoch in range(10): 160 | model.train() 161 | run_epoch(data_gen(V, 30, 20), model, 162 | SimpleLossCompute(model.generator, criterion, model_opt)) 163 | model.eval() 164 | print(run_epoch(data_gen(V, 30, 5), model, 165 | SimpleLossCompute(model.generator, criterion, None))) 166 | 167 | # model.eval() 168 | # src = Variable(torch.LongTensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])) 169 | # src_mask = Variable(torch.ones(1, 1, 10)) 170 | # print(greedy_decode(model, src, src_mask, max_len=10, start_symbol=1)) 171 | 172 | 173 | def greedy_decode(model, src, src_mask, max_len, start_symbol): 174 | memory = model.encode(src, src_mask) 175 | ys = torch.ones(1, 1).fill_(start_symbol).type_as(src.data) 176 | for i in range(max_len - 1): 177 | out = model.decode(memory, src_mask, 178 | Variable(ys), 179 | Variable(utils.subsequent_mask(ys.size(1)) 180 | .type_as(src.data))) 181 | prob = model.generator(out[:, -1]) 182 | _, next_word = torch.max(prob, dim=1) 183 | next_word = next_word.data[0] 184 | ys = torch.cat([ys, 185 | torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=1) 186 | return ys 187 | 188 | 189 | if __name__ == '__main__': 190 | # import matplotlib.pyplot as plt 191 | # import numpy as np 192 | # 193 | # opts = [NoamOpt(512, 1, 4000, None), 194 | # NoamOpt(512, 1, 8000, None), 195 | # NoamOpt(256, 1, 4000, None)] 196 | # plt.plot(np.arange(1, 20000), [[opt.rate(i) for opt in opts] for i in range(1, 20000)]) 197 | # plt.legend(["512:4000", "512:8000", "256:4000"]) 198 | # plt.show() 199 | 200 | run_test() 201 | -------------------------------------------------------------------------------- /transformer_xl/Layers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # -*- encoding: utf-8 4 | 5 | ''' 6 | _____.___._______________ __.____ __________ _________ ___ ___ _____ .___ 7 | \__ | |\_ _____/ |/ _| | \ \ \_ ___ \ / | \ / _ \ | | 8 | / | | | __)_| < | | / | \ / \ \// ~ \/ /_\ \| | 9 | \____ | | \ | \| | / | \ \ \___\ Y / | \ | 10 | / ______|/_______ /____|__ \______/\____|__ / \______ /\___|_ /\____|__ /___| 11 | \/ \/ \/ \/ \/ \/ \/ 12 | 13 | ========================================================================================== 14 | 15 | @author: Yekun Chai 16 | 17 | @license: School of Informatics, Edinburgh 18 | 19 | @contact: chaiyekun@gmail.com 20 | 21 | @file: Layers.py 22 | 23 | @time: 17/10/2019 15:11 24 | 25 | @desc: 26 | 27 | ''' 28 | import torch 29 | import torch.nn as nn 30 | import torch.nn.functional as F 31 | 32 | 33 | class PositionalEmbedding(nn.Module): 34 | def __init__(self, d_embed): 35 | super(PositionalEmbedding, self).__init__() 36 | self.d_embed = d_embed 37 | inv_freq = 1 / 
(10000 ** (torch.arange(0.0, d_embed, 2.0) / d_embed)) 38 | self.register_buffer('inv_freq', inv_freq) 39 | 40 | def forward(self, pos_seq, nbatch=None): 41 | sinusoid_inp = torch.ger(pos_seq, self.inv_freq) 42 | pos_embed = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1) 43 | if nbatch is None: 44 | return pos_embed[:, None, :] 45 | else: 46 | return pos_embed[:, None, :].expand(-1, nbatch, -1) 47 | 48 | 49 | class PositionwiseFF(nn.Module): 50 | def __init__(self, d_model, d_hid, dropout, pre_lnorm=False): 51 | super(PositionwiseFF, self).__init__() 52 | self.d_model = d_model 53 | self.d_inner = d_hid 54 | self.dropout = dropout 55 | 56 | self.net = nn.Sequential( 57 | nn.Linear(d_model, d_hid), nn.ReLU(inplace=True), 58 | nn.Dropout(dropout), 59 | nn.Linear(d_hid, d_model), 60 | nn.Dropout(dropout), 61 | ) 62 | 63 | self.layer_norm = nn.LayerNorm(d_model) 64 | self.pre_lnorm = pre_lnorm 65 | 66 | def forward(self, x): 67 | if self.pre_lnorm: 68 | out = self.net(self.layer_norm(x)) 69 | out += x 70 | else: 71 | out = self.net(x) 72 | out = self.layer_norm(x + out) 73 | return out 74 | 75 | 76 | class MultiHeadAttn(nn.Module): 77 | def __init__(self, nheads, d_model, d_head, dropout, dropatt=0, pre_lnorm=False): 78 | super(MultiHeadAttn, self).__init__() 79 | self.nheads = nheads 80 | self.d_model = d_model 81 | self.d_head = d_head # d_model//nheads 82 | self.dropout = dropout 83 | 84 | self.q_net = nn.Linear(d_model, nheads * d_head, bias=False) 85 | self.kv_net = nn.Linear(d_model, 2 * nheads * d_head, bias=False) 86 | 87 | self.dropout = nn.Dropout(dropout) 88 | self.dropout_attn = nn.Dropout(dropatt) 89 | self.o_net = nn.Linear(nheads * d_head, d_model, bias=False) 90 | 91 | self.layer_norm = nn.LayerNorm(d_model) 92 | self.scale = 1 / (d_head ** .5) 93 | self.pre_lnorm = pre_lnorm 94 | 95 | def forward(self, h, attn_mask=None, mems=None): 96 | # [seq_len x nbatch x nheads x d_head] 97 | if mems is not None: # prepend cached memory states, if provided 98 | c = torch.cat([mems, h], 0) 99 | else: 100 | c = h 101 | 102 | if self.pre_lnorm: 103 | c = self.layer_norm(c) 104 | 105 | head_q = self.q_net(h) 106 | head_k, head_v = torch.chunk(self.kv_net(c), 2, -1) 107 | 108 | head_q = head_q.view(h.size(0), h.size(1), self.nheads, self.d_head) 109 | head_k = head_k.view(c.size(0), c.size(1), self.nheads, self.d_head) 110 | head_v = head_v.view(c.size(0), c.size(1), self.nheads, self.d_head) 111 | 112 | # [q_len, k_len, nbatch, n_head] 113 | attn_score = torch.einsum('qbnd,kbnd->qkbn', (head_q, head_k)) 114 | attn_score.mul_(self.scale) 115 | if attn_mask is not None and attn_mask.any().item(): 116 | if attn_mask.dim() == 2: 117 | attn_score.masked_fill_(attn_mask[None, :, :, None], -float('inf')) 118 | elif attn_mask.dim() == 3: 119 | attn_score.masked_fill_(attn_mask[:, :, :, None], -float('inf')) 120 | 121 | attn_prob = F.softmax(attn_score, dim=1) 122 | attn_prob = self.dropout_attn(attn_prob) 123 | 124 | attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, head_v)) 125 | attn_vec = attn_vec.contiguous().view(attn_vec.size(0), attn_vec.size(1), self.nheads * self.d_head) 126 | 127 | attn_out = self.o_net(attn_vec) 128 | attn_out = self.dropout(attn_out) 129 | 130 | if self.pre_lnorm: 131 | output = h + attn_out 132 | else: 133 | output = self.layer_norm(h + attn_out) 134 | return output 135 | 136 | 137 | class RelMultiHeadAttn(nn.Module): 138 | def __init__(self, nheads, d_model, d_head, dropout_p, dropout_attn_p=0, tgt_len=None, ext_len=None, mem_len=None, 139 | pre_lnorm=False): 140 | super(RelMultiHeadAttn, 
self).__init__() 141 | 142 | self.nheads = nheads 143 | self.d_model = d_model 144 | self.d_head = d_head 145 | self.dropout_p = dropout_p 146 | self.qkv_net = nn.Linear(d_model, 3 * nheads * d_head, bias=False) 147 | 148 | self.dropout = nn.Dropout(dropout_p) 149 | self.dropout_attn = nn.Dropout(dropout_attn_p) 150 | self.o_net = nn.Linear(nheads * d_head, d_model, bias=False) 151 | 152 | self.layer_norm = nn.LayerNorm(d_model) 153 | self.scale = 1 / (d_head ** .5) 154 | 155 | self.pre_lnorm = pre_lnorm 156 | 157 | def _parallelogram_mask(self, h, w, left=False): 158 | mask = torch.ones((h, w)).byte() 159 | m = min(h, w) 160 | mask[:m, :m] = torch.triu(mask[:m, :m]) 161 | mask[-m:, -m:] = torch.tril(mask[-m:, -m:]) 162 | 163 | if left: 164 | return mask 165 | else: 166 | return mask.flip(0) 167 | 168 | def _shift(self, x, qlen, klen, mask, left=False): 169 | if qlen > 1: 170 | zero_pad = torch.zeros((x.size(0), qlen - 1, x.size(2), x.size(3)), device=x.device, dtype=x.dtype) 171 | else: 172 | zero_pad = torch.zeros(0, device=x.device, dtype=x.dtype) 173 | 174 | if left: 175 | mask = mask.flip(1) 176 | x_padded = torch.cat([zero_pad, x], dim=1).expand(qlen, -1, -1, -1) 177 | else: 178 | x_padded = torch.cat([x, zero_pad], dim=1).expand(qlen, -1, -1, -1) 179 | 180 | x = x_padded.masked_select(mask[:, :, None, None]).view(qlen, klen, x.size(2), x.size(3)) 181 | return x 182 | 183 | def _rel_shift(self, x, zero_triu=False): 184 | zero_pad = torch.zeros((x.size(0), 1, *x.size()[2:]), device=x.device, dtype=x.dtype) 185 | x_padded = torch.cat([zero_pad, x], dim=1) 186 | 187 | x_padded = x_padded.view(x.size(1) + 1, x.size(0), *x.size()[2:]) 188 | 189 | x = x_padded[1:].view_as(x) 190 | 191 | if zero_triu: 192 | ones = torch.ones((x.size(0), x.size(1))) 193 | x = x * torch.tril(ones, x.size(1) - x.size(0))[:, :, None, None] 194 | return x 195 | 196 | def forward(self, w, r, attn_mask=None, mems=None): 197 | raise NotImplementedError 198 | 199 | 200 | class RelPartialLearnableMulltiHeadAttn(RelMultiHeadAttn): 201 | def __init__(self, *args, **kwargs): 202 | super(RelPartialLearnableMulltiHeadAttn, self).__init__(*args, **kwargs) 203 | self.r_net = nn.Linear(self.d_model, self.nheads * self.d_head, bias=False) 204 | 205 | def forward(self, w, r, r_w_bias, r_r_bias, attn_mask=None, mems=None): 206 | qlen, rlen, nbatch = w.size(0), r.size(0), w.size(1) 207 | 208 | if mems is not None: 209 | cat = torch.cat([mems, w], 0) 210 | if self.pre_lnorm: 211 | w_heads = self.qkv_net(self.layer_norm(cat)) 212 | else: 213 | w_heads = self.qkv_net(cat) 214 | 215 | r_head_k = self.r_net(r) 216 | 217 | w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1) 218 | w_head_q = w_head_q[-qlen:] 219 | else: 220 | if self.pre_lnorm: 221 | w_heads = self.qkv_net(self.layer_norm(w)) 222 | else: 223 | w_heads = self.qkv_net(w) 224 | r_head_k = self.r_net(r) 225 | 226 | w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1) 227 | klen = w_head_k.size(0) 228 | w_head_q = w_head_q.view(qlen, nbatch, self.nheads, self.d_head) 229 | w_head_k = w_head_k.view(klen, nbatch, self.nheads, self.d_head) 230 | w_head_v = w_head_v.view(klen, nbatch, self.nheads, self.d_head) 231 | 232 | r_head_k = r_head_k.view(rlen, self.nheads, self.d_head) 233 | 234 | # compute attn score 235 | rw_head_q = w_head_q + r_w_bias 236 | AC = torch.einsum('ibnd, jbnd->ijbn', (rw_head_q, w_head_k)) 237 | 238 | rr_head_q = w_head_q + r_r_bias 239 | BD = torch.einsum('ibnd, jnd->ijbn', (rr_head_q, r_head_k)) 240 | BD = self._rel_shift(BD) 241 
| 242 | attn_score = AC + BD 243 | attn_score.mul_(self.scale) 244 | 245 | # compute attn probability 246 | if attn_mask is not None and attn_mask.any().item(): 247 | if attn_mask.dim() == 2: 248 | attn_score = attn_score.float().masked_fill( 249 | attn_mask[None, :, :, None], -float('inf')).type_as(attn_score) 250 | elif attn_mask.dim() == 3: 251 | attn_score = attn_score.float().masked_fill(attn_mask[:, :, :, None], -float('inf')).type_as(attn_score) 252 | 253 | attn_prob = F.softmax(attn_score, dim=1) 254 | attn_prob = self.dropout_attn(attn_prob) 255 | 256 | # compute attn vec 257 | attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v)) 258 | 259 | attn_vec = attn_vec.contiguous().view(attn_vec.size(0), attn_vec.size(1), self.nheads * self.d_head) 260 | 261 | # linear projection 262 | attn_out = self.o_net(attn_vec) 263 | attn_out = self.dropout(attn_out) 264 | 265 | if self.pre_lnorm: 266 | output = w + attn_out 267 | else: 268 | output = self.layer_norm(w + attn_out) 269 | return output 270 | 271 | 272 | class RelLearnableMultiHeadAttn(RelMultiHeadAttn): 273 | def __init__(self, *args, **kwargs): 274 | super(RelLearnableMultiHeadAttn, self).__init__(*args, **kwargs) 275 | 276 | def forward(self, w, r_emb, r_w_bias, r_bias, attn_mask=None, mems=None): 277 | qlen, nbatch = w.size(0), w.size(1) 278 | if mems is not None: 279 | cat = torch.cat([mems, w], 0) 280 | if self.pre_lnorm: 281 | w_heads = self.qkv_net(self.layer_norm(cat)) 282 | else: 283 | w_heads = self.qkv_net(cat) 284 | w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1) 285 | 286 | w_head_q = w_head_q[-qlen:] 287 | else: 288 | if self.pre_lnorm: 289 | w_heads = self.qkv_net(self.layer_norm(w)) 290 | else: 291 | w_heads = self.qkv_net(w) 292 | w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1) 293 | 294 | klen = w_head_k.size(0) 295 | 296 | w_head_q = w_head_q.view(qlen, nbatch, self.nheads, self.d_head) 297 | w_head_k = w_head_k.view(klen, nbatch, self.nheads, self.d_head) 298 | w_head_v = w_head_v.view(klen, nbatch, self.nheads, self.d_head) 299 | 300 | if klen > r_emb.size(0): 301 | r_emb_pad = r_emb[0:1].expand(klen - r_emb.size(0), -1, -1) 302 | r_emb = torch.cat([r_emb_pad, r_emb], 0) 303 | r_bias_pad = r_bias[0:1].expand(klen - r_bias.size(0), -1) 304 | r_bias = torch.cat([r_bias_pad, r_bias], 0) 305 | else: 306 | r_emb = r_emb[-klen:] 307 | r_bias = r_bias[-klen:] 308 | 309 | # compute the attn score 310 | rw_head_q = w_head_q + r_w_bias[None] 311 | 312 | AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k)) 313 | B_ = torch.einsum('ibnd,jnd->ijbn', (w_head_q, r_emb)) 314 | D_ = r_bias[None, :, None] 315 | BD = self._rel_shift(B_ + D_) 316 | 317 | attn_score = AC + BD 318 | attn_score.mul_(self.scale) 319 | 320 | # compute attn prob 321 | if attn_mask is not None and attn_mask.any().item(): 322 | if attn_mask.dim() == 2: 323 | attn_score.masked_fill_(attn_mask[None, :, :, None], -float('inf')) 324 | elif attn_mask.dim() == 3: 325 | attn_score.masked_fill_(attn_mask[:, :, :, None], -float('inf')) 326 | 327 | attn_prob = F.softmax(attn_score, dim=1) 328 | attn_prob = self.dropout_attn(attn_prob) 329 | 330 | attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v)) 331 | 332 | attn_vec = attn_vec.contiguous().view(attn_vec.size(0), attn_vec.size(1), self.nheads * self.d_head) 333 | 334 | attn_out = self.o_net(attn_vec) 335 | attn_out = self.dropout(attn_out) 336 | 337 | if self.pre_lnorm: 338 | output = w + attn_out 339 | else: 340 | output = self.layer_norm(w + attn_out) 341 | 342 
return output 343 | 344 | 345 | class DecoderLayer(nn.Module): 346 | def __init__(self, nheads, d_model, d_head, d_inner, dropout, **kwargs): 347 | super(DecoderLayer, self).__init__() 348 | self.dec_attn = MultiHeadAttn(nheads, d_model, d_head, dropout, **kwargs) 349 | self.pos_ff = PositionwiseFF(d_model, d_inner, dropout, pre_lnorm=kwargs.get('pre_lnorm')) 350 | 351 | def forward(self, dec_inp, dec_attn_mask=None, mems=None): 352 | output = self.dec_attn(dec_inp, attn_mask=dec_attn_mask, mems=mems) 353 | output = self.pos_ff(output) 354 | return output 355 | 356 | 357 | class RelLearnableDecoderLayer(nn.Module): 358 | def __init__(self, nheads, d_model, d_head, d_inner, dropout, **kwargs): 359 | super(RelLearnableDecoderLayer, self).__init__() 360 | self.dec_attn = RelLearnableMultiHeadAttn(nheads, d_model, d_head, dropout, **kwargs) 361 | self.pos_ff = PositionwiseFF(d_model, d_inner, dropout, pre_lnorm=kwargs.get('pre_lnorm')) 362 | 363 | def forward(self, dec_inp, r_emb, r_w_bias, dec_attn_mask=None, mems=None): 364 | output = self.dec_attn(dec_inp, r_emb, r_w_bias, attn_mask=dec_attn_mask, mems=mems) 365 | output = self.pos_ff(output) 366 | return output 367 | 368 | 369 | class RelPartialLearnabledecoderLayer(nn.Module): 370 | def __init__(self, nheads, d_model, d_head, d_inner, dropout, **kwargs): 371 | super(RelPartialLearnabledecoderLayer, self).__init__() 372 | self.dec_attn = RelPartialLearnableMulltiHeadAttn(nheads, d_model, d_head, dropout, **kwargs) 373 | self.pos_ff = PositionwiseFF(d_model, d_inner, dropout, pre_lnorm=kwargs.get('pre_lnorm')) 374 | 375 | def forward(self, dec_input, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None): 376 | output = self.dec_attn(dec_input, r, r_w_bias, r_r_bias, attn_mask=dec_attn_mask, mems=mems) 377 | output = self.pos_ff(output) 378 | return output 379 | 380 | 381 | class AdaptiveEmbedding(nn.Module): 382 | def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, sample_softmax=False): 383 | super(AdaptiveEmbedding, self).__init__() 384 | 385 | self.n_token = n_token 386 | self.d_embed = d_embed 387 | self.cutoffs = cutoffs + [n_token] 388 | self.div_val = div_val 389 | self.d_proj = d_proj 390 | 391 | self.emb_scale = d_proj ** .5 392 | self.cutoff_ends = [0] + self.cutoffs 393 | 394 | self.emb_layers = nn.ModuleList() 395 | self.emb_projs = nn.ParameterList() 396 | 397 | if div_val == 1: 398 | self.embed_layers.append(nn.Embedding(n_token, d_embed, sparse=sample_softmax > 0)) 399 | if d_proj != d_embed: 400 | self.emb_projs.append(nn.Parameter(torch.Tensor(d_proj, d_embed))) 401 | else: 402 | for i in range(len(self.cutoffs)): 403 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] 404 | d_emb_i = d_embed // (div_val ** i) 405 | self.emb_layers.append(nn.Embedding(r_idx - l_idx, d_emb_i)) 406 | self.emb_projs.append(nn.Parameter(torch.Tensor(d_proj, d_emb_i))) 407 | 408 | def forward(self, input): 409 | if self.div_val == 1: 410 | embed = self.emb_layers[0][input] 411 | if self.d_proj != self.d_embed: 412 | embed = F.linear(embed, self.emb_projs[0]) 413 | else: 414 | param = next(self.parameters()) 415 | input_flat = input.view(-1) 416 | emb_flat = torch.zeros([input_flat.size(0), self.d_proj], dtype=param.dtype, device=param.device) 417 | for i in range(len(self.cutoffs)): 418 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] 419 | 420 | mask_i = (input_flat >= l_idx) & (input_flat < r_idx) 421 | indices_i = mask_i.nonzero().squeeze() 422 | 423 | if indices_i.numel() == 0: 424 | continue 425 | inp_i = 
input_flat.index_select(0, indices_i) - l_idx 426 | emb_i = self.emb_layers[i](inp_i) 427 | emb_i = F.linear(emb_i, self.emb_projs[i]) 428 | 429 | emb_flat.index_copy_(0, indices_i, emb_i) 430 | 431 | embed = emb_flat.view(*input.size(), self.d_proj) 432 | embed.mul_(self.emb_scale) 433 | return embed 434 | 435 | 436 | class MemTransformerLM(nn.Module): 437 | def __init__(self, n_token, n_layer, n_head, d_model, d_head, d_inner, dropout, dropatt, tie_weight=True, 438 | d_embed=None, div_val=1, tie_projs=[False], pre_lnorm=False, tgt_len=None, ext_len=None, mem_len=None, 439 | cutoffs=[], adapt_inp=False, same_length=False, attn_type=0, clamp_len=-1, sample_softmax=-1): 440 | super(MemTransformerLM, self).__init__() 441 | self.n_token = n_token 442 | 443 | d_embed = d_model if d_embed is None else d_embed 444 | self.d_embed = d_embed 445 | self.d_model = d_model 446 | self.n_head = n_head 447 | self.d_head = d_head 448 | 449 | self.word_emb = AdaptiveEmbedding(n_token, d_embed, d_model, cutoffs, div_val=div_val) 450 | 451 | self.drop = nn.Dropout(dropout) 452 | 453 | self.n_layer = n_layer 454 | self.tgt_len = tgt_len 455 | self.mem_len = mem_len 456 | self.ext_len = ext_len 457 | 458 | self.max_klen = tgt_len + ext_len + mem_len 459 | 460 | self.attn_type = attn_type 461 | 462 | self.layers = nn.ModuleList() 463 | 464 | if attn_type == 0: # default attn 465 | for i in range(n_layer): 466 | self.layers.append(RelPartialLearnabledecoderLayer( 467 | n_head, d_model, d_head, d_inner, dropout, 468 | tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len, 469 | dropatt=dropatt, pre_lnorm=pre_lnorm)) 470 | 471 | elif attn_type == 1: # learnable embeddings 472 | for i in range(n_layer): 473 | self.layers.append(RelLearnableDecoderLayer( 474 | n_head, d_model, d_head, d_inner, dropout, 475 | tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len, 476 | dropatt=dropatt, pre_lnorm=pre_lnorm 477 | )) 478 | elif attn_type in [2, 3]: # abs embedding 479 | for i in range(n_layer): 480 | self.layers.append( 481 | DecoderLayer(n_head, d_model, d_head, d_inner, dropout, 482 | dropatt=dropatt, pre_lnorm=pre_lnorm 483 | )) 484 | 485 | self.sample_softmax = sample_softmax 486 | # use sampled softmax 487 | if sample_softmax > 0: 488 | self.out_layer = nn.Linear(d_model, n_token) 489 | if tie_weight: 490 | self.out_layer.weight = self.word_emb.weight 491 | self.tie_weight = tie_weight 492 | raise NotImplementedError 493 | 494 | 495 | if __name__ == '__main__': 496 | pe = PositionalEmbedding(256) 497 | res = pe(3) 498 | -------------------------------------------------------------------------------- /transformer_xl/Transformer_xl.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # -*- encoding: utf-8 4 | 5 | ''' 6 | _____.___._______________ __.____ __________ _________ ___ ___ _____ .___ 7 | \__ | |\_ _____/ |/ _| | \ \ \_ ___ \ / | \ / _ \ | | 8 | / | | | __)_| < | | / | \ / \ \// ~ \/ /_\ \| | 9 | \____ | | \ | \| | / | \ \ \___\ Y / | \ | 10 | / ______|/_______ /____|__ \______/\____|__ / \______ /\___|_ /\____|__ /___| 11 | \/ \/ \/ \/ \/ \/ \/ 12 | 13 | ========================================================================================== 14 | 15 | @author: Yekun Chai 16 | 17 | @license: School of Informatics, Edinburgh 18 | 19 | @contact: chaiyekun@gmail.com 20 | 21 | @file: Transformer_xl.py 22 | 23 | @time: 17/10/2019 15:11 24 | 25 | @desc: 26 | 27 | ''' 28 | 29 | import torch 30 | import torch.nn as nn 31 | import torch.nn.functional as F 32 
class TransformerXL(nn.Module):
    def __init__(self, d_model, n_head, d_head, mem_len, n_layer, clamp_len, tgt_len,
                 ext_len, dropatt, d_inner, pre_lnorm, dropout):
        super(TransformerXL, self).__init__()
        self.n_layer = n_layer
        self.mem_len = mem_len
        # NOTE: the original file left the embedding as a bare placeholder (`_`).
        # An embedding module (e.g. nn.Embedding or an adaptive embedding) has to
        # be assigned here before the model can be used.
        self.word_emb = None
        self.clamp_len = clamp_len
        self.d_model = d_model
        self.n_head = n_head
        self.d_head = d_head
        self.drop = nn.Dropout(p=dropout)
        self.layers = nn.ModuleList()

        for i in range(n_layer):
            self.layers.append(RelPartialLearnableDecLayer(
                n_head, d_model, d_head, d_inner, dropout, tgt_len=tgt_len,
                ext_len=ext_len, mem_len=mem_len, dropatt=dropatt, pre_lnorm=pre_lnorm))

        # positional embedding and the two learnable relative-attention biases
        self._create_params()

    def _create_params(self):
        self.pos_emb = PositionEmbedding(self.d_model)
        self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
        self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))

    def init_mems(self):
        if self.mem_len > 0:
            mems = []
            param = next(self.parameters())
            for _ in range(self.n_layer + 1):
                empty = torch.empty(0, dtype=param.dtype, device=param.device)
                mems.append(empty)
            return mems
        else:  # do not use mems
            return None

    def _update_mems(self, hids, mems, qlen, mlen):
        if mems is None:
            return None

        assert len(hids) == len(mems), 'len(hids) != len(mems)!'

        with torch.no_grad():
            new_mems = []
            end_idx = mlen + qlen
            begin_idx = max(0, end_idx - self.mem_len)
            for i in range(len(hids)):
                # the cached state may still be the empty placeholder from init_mems
                cat = torch.cat([mems[i], hids[i]], dim=0) if mems[i].numel() else hids[i]
                new_mems.append(cat[begin_idx:end_idx].detach())
        return new_mems

    def _forward(self, inp, mems=None):
        qlen, bsz = inp.size()
        word_emb = self.word_emb(inp)

        mlen = mems[0].size(0) if mems is not None else 0
        klen = mlen + qlen

        # each query attends to the mlen cached positions plus current positions <= itself
        dec_attn_mask = torch.triu(word_emb.new_ones(qlen, klen), diagonal=1 + mlen).bool()[:, :, None]

        hiddens = []
        pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device, dtype=word_emb.dtype)
        if self.clamp_len > 0:
            pos_seq.clamp_(max=self.clamp_len)
        pos_emb = self.pos_emb(pos_seq)

        core_out = self.drop(word_emb)
        pos_emb = self.drop(pos_emb)

        hiddens.append(core_out)

        for i, layer in enumerate(self.layers):
            mems_i = None if mems is None else mems[i]
            core_out = layer(core_out, pos_emb, self.r_w_bias, self.r_r_bias,
                             dec_attn_mask=dec_attn_mask, mems=mems_i)
            hiddens.append(core_out)

        core_out = self.drop(core_out)

        new_mems = self._update_mems(hiddens, mems, qlen, mlen)

        return core_out, new_mems

    def forward(self, x, y, *mems):
        if not mems:
            mems = self.init_mems()

        tgt_len = y.size(0)
        hidden, new_mems = self._forward(x, mems=mems)
        pred_hid = hidden[-tgt_len:]
        # the original file stopped here without returning; hand back the hidden
        # states of the target positions together with the refreshed memories
        return pred_hid, new_mems


class RelPartialLearnableDecLayer(nn.Module):
    def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs):
        # argument order matches the call site in TransformerXL.__init__
        super(RelPartialLearnableDecLayer, self).__init__()
        self.dec_attn = RelPartialLearnableMHDPA(n_head, d_model, d_head, dropout, **kwargs)
        self.ffn = PositionwiseFF(d_model, d_inner, dropout, pre_lnorm=kwargs.get('pre_lnorm'))

    def forward(self, inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None):
        out = self.dec_attn(inp, r, r_w_bias, r_r_bias, attn_mask=dec_attn_mask, mems=mems)
        out = self.ffn(out)
        return out
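# --- Added illustration (not part of the original repository) ---------------
# A minimal, hedged sketch of segment-level recurrence with the classes above:
# run two consecutive segments and feed the returned memories back in.  The
# `_demo_` name, the stand-in nn.Embedding, the zero-initialised biases and all
# hyper-parameter values are assumptions made only for this demo.
def _demo_segment_recurrence():
    vocab, tgt_len, batch = 1000, 8, 2
    model = TransformerXL(d_model=64, n_head=4, d_head=16, mem_len=8, n_layer=2,
                          clamp_len=-1, tgt_len=tgt_len, ext_len=0, dropatt=0.0,
                          d_inner=128, pre_lnorm=False, dropout=0.1)
    model.word_emb = nn.Embedding(vocab, 64)  # stand-in for the missing embedding
    nn.init.zeros_(model.r_w_bias)            # demo-only initialisation
    nn.init.zeros_(model.r_r_bias)

    x1 = torch.randint(0, vocab, (tgt_len, batch))  # (seq_len, batch)
    y1 = torch.randint(0, vocab, (tgt_len, batch))
    out1, mems = model(x1, y1)                      # first segment: fresh (empty) memories

    x2 = torch.randint(0, vocab, (tgt_len, batch))
    y2 = torch.randint(0, vocab, (tgt_len, batch))
    out2, _ = model(x2, y2, *mems)                  # second segment reuses the cached states
    print(out1.shape, out2.shape, len(mems))        # (8, 2, 64), (8, 2, 64), n_layer + 1
# -----------------------------------------------------------------------------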
class RelPartialLearnableMHDPA(nn.Module):
    def __init__(self, n_head, d_model, d_head, dropout, tgt_len=None, ext_len=None, mem_len=None,
                 dropatt=0, pre_lnorm=False):
        super(RelPartialLearnableMHDPA, self).__init__()

        self.n_head = n_head
        self.d_model = d_model
        self.d_head = d_head
        self.dropout = dropout

        self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head, bias=False)
        self.drop = nn.Dropout(p=dropout)
        self.dropatt = nn.Dropout(p=dropatt)
        self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
        self.layer_norm = nn.LayerNorm(d_model)

        self.scale = 1 / (d_head ** .5)
        self.pre_ln = pre_lnorm
        # xl: projection of the relative positional embeddings
        self.r_net = nn.Linear(self.d_model, self.n_head * self.d_head, bias=False)

    def _rel_shift(self, x, zero_triu=False):
        # x: (qlen, klen, bsz, n_head); shift row i left so that column j indexes
        # the relative distance between query i and key j
        qlen, klen, bsz, n_head = x.size()
        zero_pad = torch.zeros((qlen, 1, bsz, n_head), device=x.device, dtype=x.dtype)
        x_padded = torch.cat((zero_pad, x), 1)  # qlen, klen + 1, bsz, n_head
        x_padded = x_padded.view(klen + 1, qlen, bsz, n_head)
        x = x_padded[1:].view_as(x)

        if zero_triu:
            ones = torch.ones((x.size(0), x.size(1)), device=x.device, dtype=x.dtype)
            x = x * torch.tril(ones, x.size(1) - x.size(0))[:, :, None, None]
        return x

    def forward(self, w, r, r_w_bias, r_r_bias, attn_mask=None, mems=None):
        qlen, rlen, bsz = w.size(0), r.size(0), w.size(1)

        # empty placeholder memories (from init_mems) are treated as "no memory"
        if mems is not None and mems.numel():
            cat = torch.cat((mems, w), 0)
            if self.pre_ln:
                w_heads = self.qkv_net(self.layer_norm(cat))
            else:
                w_heads = self.qkv_net(cat)

            r_head_k = self.r_net(r)
            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
            w_head_q = w_head_q[-qlen:]
        else:
            if self.pre_ln:
                w_heads = self.qkv_net(self.layer_norm(w))
            else:
                w_heads = self.qkv_net(w)
            r_head_k = self.r_net(r)
            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)

        klen = w_head_k.size(0)

        w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head)  # qlen, bsz, n_head, d_head
        w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head)  # memlen + qlen, bsz, n_head, d_head
        w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head)  # memlen + qlen, bsz, n_head, d_head

        r_head_k = r_head_k.view(rlen, self.n_head, self.d_head)  # rlen x n_head x d_head

        # content-based term: (query + content bias) against the keys
        rw_head_q = w_head_q + r_w_bias
        AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k))

        # position-based term: (query + position bias) against the relative embeddings
        rr_head_q = w_head_q + r_r_bias
        BD = torch.einsum('ibnd,jnd->ijbn', (rr_head_q, r_head_k))
        BD = self._rel_shift(BD)

        # qlen, klen, bsz, n_head
        attn_score = AC + BD
        attn_score.mul_(self.scale)

        if attn_mask is not None and attn_mask.any().item():
            if attn_mask.dim() == 2:
                attn_score = attn_score.float().masked_fill(
                    attn_mask[None, :, :, None], -float('inf')).type_as(attn_score)
            elif attn_mask.dim() == 3:
                attn_score = attn_score.float().masked_fill(
                    attn_mask[:, :, :, None], -float('inf')).type_as(attn_score)

        attn_p = F.softmax(attn_score, -1)
        attn_p = self.dropatt(attn_p)

        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_p, w_head_v))

        attn_vec = attn_vec.contiguous().view(attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)

        attn_out = self.o_net(attn_vec)
        attn_out = self.drop(attn_out)

        if self.pre_ln:
            out = w + attn_out
        else:
            out = self.layer_norm(w + attn_out)
        return out
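# --- Added illustration (not part of the original repository) ---------------
# Hedged sketch of what _rel_shift above does to the position-based BD term:
# row i of the (qlen, klen) score matrix is shifted left by (qlen - 1 - i), so
# that after the shift column j corresponds to relative distance mlen + i - j.
# The `_demo_` name and the toy sizes are assumptions made only for this demo.
def _demo_rel_shift():
    attn = RelPartialLearnableMHDPA(n_head=1, d_model=8, d_head=8, dropout=0.0)
    qlen, klen = 3, 5
    # encode the column index as the value so the shift is easy to see
    x = torch.arange(klen, dtype=torch.float).repeat(qlen, 1)[:, :, None, None]
    shifted = attn._rel_shift(x)
    print(shifted[:, :, 0, 0])
    # tensor([[2., 3., 4., 0., 0.],
    #         [1., 2., 3., 4., 0.],
    #         [0., 1., 2., 3., 4.]])
    # the trailing entries of each row are artefacts of the shift; they sit in
    # the future region, which the causal attention mask removes afterwards.
# -----------------------------------------------------------------------------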
class PositionwiseFF(nn.Module):
    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
        super(PositionwiseFF, self).__init__()
        self.d_model = d_model
        self.d_inner = d_inner
        self.dropout = dropout

        self.coreNet = nn.Sequential(
            nn.Linear(d_model, d_inner), nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(d_inner, d_model),
            nn.Dropout(dropout),
        )
        self.layer_norm = nn.LayerNorm(d_model)
        self.pre_ln = pre_lnorm

    def forward(self, inp):
        if self.pre_ln:
            # pre-norm: normalise the input first, then add the residual
            core_out = self.coreNet(self.layer_norm(inp))
            out = core_out + inp
        else:
            core_out = self.coreNet(inp)
            out = self.layer_norm(inp + core_out)
        return out


class PositionEmbedding(nn.Module):
    """ R_{i-j} in the relative attention (A^rel) of Transformer-XL """

    def __init__(self, d_emb):
        super(PositionEmbedding, self).__init__()
        self.d_emb = d_emb
        inv_freq = 1 / (10000 ** (torch.arange(.0, d_emb, 2.0) / d_emb))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, pos_seq, bsz=None):
        sinusoid_inp = torch.ger(pos_seq, self.inv_freq)  # outer product: (len, d_emb / 2)
        pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)  # (len, d_emb)

        if bsz is not None:
            return pos_emb[:, None, :].expand(-1, bsz, -1)
        else:
            return pos_emb[:, None, :]
--------------------------------------------------------------------------------
/transformer_xl/__init__.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python

#-*- encoding: utf-8

'''
_____.___._______________ __.____ __________ _________ ___ ___ _____ .___
\__ | |\_ _____/ |/ _| | \ \ \_ ___ \ / | \ / _ \ | |
/ | | | __)_| < | | / | \ / \ \// ~ \/ /_\ \| |
\____ | | \ | \| | / | \ \ \___\ Y / | \ |
/ ______|/_______ /____|__ \______/\____|__ / \______ /\___|_ /\____|__ /___|
\/ \/ \/ \/ \/ \/ \/

==========================================================================================

@author: Yekun Chai

@license: School of Informatics, Edinburgh

@contact: chaiyekun@gmail.com

@file: __init__.py

@time: 17/10/2019 15:10

@desc:

'''
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python

# -*- encoding: utf-8

'''
_____.___._______________ __.____ __________ _________ ___ ___ _____ .___
\__ | |\_ _____/ |/ _| | \ \ \_ ___ \ / | \ / _ \ | |
/ | | | __)_| < | | / | \ / \ \// ~ \/ /_\ \| |
\____ | | \ | \| | / | \ \ \___\ Y / | \ |
/ ______|/_______ /____|__ \______/\____|__ / \______ /\___|_ /\____|__ /___|
\/ \/ \/ \/ \/ \/ \/

==========================================================================================

@author: Yekun Chai

@license: School of Informatics, Edinburgh

@contact: chaiyekun@gmail.com

@file: utils.py

@time: 29/09/2019 20:41

@desc:

'''
import torch
import torch.nn as nn
from torch.autograd import Variable
import matplotlib.pyplot as plt

import numpy as np
import copy


def clones(module, N):
    """
    produce N identical layers
    :param module:
    :param N:
    :return:
    """
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
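# --- Added illustration (not part of the original repository) ---------------
# Hedged sketch of the two helpers in this file: `clones` builds N independent
# copies of a module (deep copies, no parameter sharing), and `subsequent_mask`
# (defined just below) is True wherever a position is allowed to attend.
# The `_demo_` name and the toy sizes are assumptions made only for this demo.
def _demo_utils():
    layers = clones(nn.Linear(8, 8), 3)
    print(len(layers), layers[0].weight is layers[1].weight)  # 3 False
    print(subsequent_mask(4)[0].int())
    # lower-triangular pattern: row i may attend to positions j <= i only
# -----------------------------------------------------------------------------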


def subsequent_mask(size):
    """ Mask out subsequent positions """
    attn_shape = (1, size, size)
    subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
    return torch.from_numpy(subsequent_mask) == 0


class LabelSmoothing(nn.Module):
    def __init__(self, size, padding_idx, smoothing=.0):
        super(LabelSmoothing, self).__init__()
        self.criterion = nn.KLDivLoss(reduction='sum')
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size
        self.true_dist = None

    def forward(self, x, target):
        assert x.size(1) == self.size
        true_dist = x.data.clone()
        true_dist.fill_(self.smoothing / (self.size - 2))
        true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
        true_dist[:, self.padding_idx] = 0
        mask = torch.nonzero(target.data == self.padding_idx)
        if mask.dim() > 0:
            true_dist.index_fill_(0, mask.squeeze(), .0)
        self.true_dist = true_dist
        return self.criterion(x, Variable(true_dist, requires_grad=False))


if __name__ == '__main__':
    # test subsequent_mask
    # ---------------------------------
    # plt.figure(figsize=(5, 5))
    # plt.imshow(subsequent_mask(20)[0])
    # plt.show()
    # ---------------------------------
    import os

    os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
    # test label smoothing
    # crit = LabelSmoothing(5, 0, .4)
    # predict = torch.FloatTensor([[0, 0.2, 0.7, 0.1, 0], [0, 0.2, 0.7, 0.1, 0], [0, 0.2, 0.7, 0.1, 0]])
    # v = crit(Variable(predict.log()),
    #          Variable(torch.LongTensor([2, 1, 0])))
    # plt.imshow(crit.true_dist)
    # plt.show()
    #
    crit = LabelSmoothing(5, 0, 0.1)

    def loss(x):
        d = x + 3 * 1
        predict = torch.FloatTensor([[0, x / d, 1 / d, 1 / d, 1 / d],
                                     ])
        # print(predict)
        return crit(Variable(predict.log()),
                    Variable(torch.LongTensor([1]))).data.item()

    plt.plot(np.arange(1, 100), [loss(x) for x in range(1, 100)])
    plt.show()
--------------------------------------------------------------------------------