├── LICENSE
├── README.md
├── Transformer.py
├── fra.txt
├── main.ipynb
├── source
│   ├── 1.jpg
│   ├── 2.png
│   └── 3.png
├── train.py
└── util.py

/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution."
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## A PyTorch implementation of the Transformer
What each file contains:
- Transformer.py ---> the complete implementation of the Transformer architecture
- util.py ---> data loading
- train.py ---> training and evaluation code
- main.ipynb ---> the notebook that runs everything
- fra.txt ---> the dataset

To run the project, just open main.ipynb and execute the cells.

For a detailed walkthrough of the code, see the write-up on [Jianshu](https://www.jianshu.com/p/b0cf5520c4fa).

torch version: 1.1.0

python version: 3.7.0
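
If you prefer a plain script to the notebook, the sketch below wires the same pieces together. It is a minimal, untested sketch using the notebook's default hyperparameters and the names exported by Transformer.py, util.py, and train.py; the `Seq2Seq` wrapper here mirrors the `transformer` class that main.ipynb defines inline.

```python
import torch
import torch.nn as nn
from Transformer import TransformerEncoder, TransformerDecoder
from util import load_data_nmt
from train import train, translate

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
src_vocab, tgt_vocab, train_iter = load_data_nmt(batch_size=64, max_len=10)

encoder = TransformerEncoder(vocab_size=len(src_vocab), embedding_size=32,
                             n_layers=2, hidden_size=64, num_heads=4, dropout=0.05)
decoder = TransformerDecoder(vocab_size=len(tgt_vocab), embedding_size=32,
                             n_layers=2, hidden_size=64, num_heads=4, dropout=0.05)

class Seq2Seq(nn.Module):
    """Same wrapper the notebook defines inline (there named `transformer`)."""
    def __init__(self, enc_net, dec_net):
        super(Seq2Seq, self).__init__()
        self.enc_net, self.dec_net = enc_net, dec_net

    def forward(self, enc_X, dec_X, valid_length=None, max_seq_len=None):
        enc_output = self.enc_net(enc_X, valid_length, max_seq_len)
        state = self.dec_net.init_state(enc_output, valid_length)
        return self.dec_net(dec_X, state)

model = Seq2Seq(encoder, decoder)
model.train()
train(model, train_iter, lr=0.005, factor=1, warmup=2000, num_epochs=500, device=device)
model.eval()
print(translate(model, 'Go .', src_vocab, tgt_vocab, max_len=10, device=device))
```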
--------------------------------------------------------------------------------
/Transformer.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
import numpy as np

# ========================= Masking helper ==============================
def masked_softmax(X, valid_length, value=-1e6):
    # If valid_length is 1-D, its length equals batch_size: one valid length is
    # given per batch element, and each entry is repeated seq_len (X.size()[1])
    # times. Every query position within one batch element therefore gets the
    # same attention mask.
    #
    # If valid_length is 2-D, its shape is [batch_size, seq_length]: a separate
    # valid length is given for every query position of every batch element.
    if valid_length is None:
        return F.softmax(X, dim=-1)
    else:
        X_size = X.size()
        device = valid_length.device
        if valid_length.dim() == 1:
            # .cpu() first so .numpy() works for CUDA tensors as well
            valid_length = torch.tensor(valid_length.cpu().numpy().repeat(X_size[1], axis=0),
                                        dtype=torch.float, device=device)
        else:
            valid_length = valid_length.view([-1])
        X = X.view([-1, X_size[-1]])
        max_seq_length = X_size[-1]
        valid_length = valid_length.to(torch.device('cpu'))
        # positions at or beyond the valid length get a large negative score,
        # so their weights vanish after the softmax
        mask = torch.arange(max_seq_length, dtype=torch.float)[None, :] >= valid_length[:, None]
        X[mask] = value
        X = X.view(X_size)
        return F.softmax(X, dim=-1)
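
# A quick sanity check for masked_softmax: an illustrative sketch, not part of
# the model. With valid lengths [2, 3], weights beyond position 2 in row 0 and
# beyond position 3 in row 1 should be numerically zero.
if __name__ == "__main__":
    scores = torch.zeros(2, 1, 4)                      # [batch, queries, keys]
    lengths = torch.tensor([2, 3], dtype=torch.float)  # one valid length per batch element
    print(masked_softmax(scores, lengths))
    # row 0 -> [0.5, 0.5, 0.0, 0.0]; row 1 -> [1/3, 1/3, 1/3, 0.0]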
# ============================ Encoder ==================================
class DotProductAttention(nn.Module):
    # DotProductAttention leaves the tensor shape unchanged:
    # input and output are both [batch_size*h, seq_len, d_model//h]
    def __init__(self, dropout,):
        super(DotProductAttention, self).__init__()
        self.drop = nn.Dropout(dropout)

    def forward(self, Q, K, V, valid_length):
        # Q, K, V shape: [batch_size*h, seq_len, d_model//h]
        d_model = Q.size()[-1]  # int
        # torch.bmm multiplies batches of matrices (3-D tensors)
        attention_scores = torch.bmm(Q, K.transpose(1, 2))/math.sqrt(d_model)
        # attention_scores shape: [batch_size*h, seq_len, seq_len]
        attention_weights = self.drop(masked_softmax(attention_scores, valid_length))
        return torch.bmm(attention_weights, V)  # [batch_size*h, seq_len, d_model//h]

class MultiHeadAttention(nn.Module):
    def __init__(self, input_size, hidden_size, num_heads, dropout,):
        super(MultiHeadAttention, self).__init__()
        # keep the input and output dimensions of MultiHeadAttention identical
        assert hidden_size % num_heads == 0
        # num_heads => h; hidden_size => d_model
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        # d_model is the width of the hidden projections: d_model = h*d_v = h*d_k = h*d_q
        self.Wq = nn.Linear(input_size, hidden_size, bias=False)
        self.Wk = nn.Linear(input_size, hidden_size, bias=False)
        self.Wv = nn.Linear(input_size, hidden_size, bias=False)
        self.Wo = nn.Linear(hidden_size, hidden_size, bias=False)
        self.attention = DotProductAttention(dropout)

    def _transpose_qkv(self, X):
        # X comes in as [batch_size, seq_len, d_model];
        # reshape it to [batch_size*num_heads, seq_len, d_model//num_heads]
        self._batch, self._seq_len = X.size()[0], X.size()[1]
        X = X.view([self._batch, self._seq_len, self.num_heads, self.hidden_size//self.num_heads])  # [batch_size, seq_len, num_heads, d_model//num_heads]
        X = X.permute([0, 2, 1, 3])  # [batch_size, num_heads, seq_len, d_model//num_heads]
        return X.contiguous().view([self._batch*self.num_heads, self._seq_len, self.hidden_size//self.num_heads])

    def _transpose_output(self, X):
        X = X.view([self._batch, self.num_heads, -1, self.hidden_size//self.num_heads])
        X = X.permute([0, 2, 1, 3])
        return X.contiguous().view([self._batch, -1, self.hidden_size])

    def forward(self, query, key, value, valid_length):
        Q = self._transpose_qkv(self.Wq(query))
        K = self._transpose_qkv(self.Wk(key))
        V = self._transpose_qkv(self.Wv(value))
        # valid_length is given per batch element, but _transpose_qkv changes the
        # batch dimension: Q's first dimension grows from batch to batch*num_heads.
        # valid_length therefore has to be copied accordingly, i.e. an np.tile
        # operation.
        if valid_length is not None:
            device = valid_length.device
            valid_length = valid_length.cpu().numpy() if valid_length.is_cuda else valid_length.numpy()
            if valid_length.ndim == 1:
                valid_length = np.tile(valid_length, self.num_heads)
            else:
                valid_length = np.tile(valid_length, [self.num_heads, 1])
            valid_length = torch.tensor(valid_length, dtype=torch.float, device=device)
        output = self.attention(Q, K, V, valid_length)
        output_concat = self._transpose_output(output)
        return self.Wo(output_concat)

class PositionWiseFFN(nn.Module):
    # y = W2 * max(0, W1*x + b1) + b2
    def __init__(self, input_size, fft_hidden_size, output_size,):
        super(PositionWiseFFN, self).__init__()
        self.FFN1 = nn.Linear(input_size, fft_hidden_size)
        self.FFN2 = nn.Linear(fft_hidden_size, output_size)

    def forward(self, X):
        return self.FFN2(F.relu(self.FFN1(X)))

class AddNorm(nn.Module):
    def __init__(self, hidden_size, dropout,):
        super(AddNorm, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.LN = nn.LayerNorm(hidden_size)

    def forward(self, X, Y):
        assert X.size() == Y.size()
        return self.LN(self.drop(Y) + X)

class PositionalEncoding(nn.Module):
    def __init__(self, dropout,):
        super(PositionalEncoding, self).__init__()

    def forward(self, X, max_seq_len=None):
        if max_seq_len is None:
            max_seq_len = X.size()[1]
        # X is the word-embedding input; the positional encoding does not depend
        # on the batch dimension. The larger max_seq_len is, the shorter the
        # period of sin()/cos(), so the same X can receive different positional
        # encodings for different values of max_seq_len.
        assert X.size()[1] <= max_seq_len
        # X has shape [batch_size, seq_len, embed_size],
        # with seq_len = l and embed_size = d
        l, d = X.size()[1], X.size()[-1]
        # P_{i,2j}   = sin(i/10000^{2j/d})
        # P_{i,2j+1} = cos(i/10000^{2j/d})
        # for i = 0,1,...,l-1 and j = 0,1,2,...,floor((d-2)/2)
        # round max_seq_len down to a multiple of l so the stride below is exact
        max_seq_len = int((max_seq_len//l)*l)
        P = np.zeros([1, l, d])
        # T = i/10000^{2j/d}
        T = [i*1.0/10000**(2*j*1.0/d) for i in range(0, max_seq_len, max_seq_len//l) for j in range((d+1)//2)]
        T = np.array(T).reshape([l, (d+1)//2])
        if d % 2 != 0:
            P[0, :, 1::2] = np.cos(T[:, :-1])
        else:
            P[0, :, 1::2] = np.cos(T)
        P[0, :, 0::2] = np.sin(T)
        return torch.tensor(P, dtype=torch.float, device=X.device)

class EncoderBlock(nn.Module):
    # An encoder block has four parts: multi-head attention, add & norm,
    # a position-wise feed-forward network, and another add & norm
    def __init__(self, embedding_size, ffn_hidden_size, num_heads, dropout,):
        super(EncoderBlock, self).__init__()
        self.attention = MultiHeadAttention(input_size=embedding_size,
                                            hidden_size=embedding_size,
                                            num_heads=num_heads,
                                            dropout=dropout, )
        self.addnorm1 = AddNorm(hidden_size=embedding_size, dropout=dropout,)
        self.ffn = PositionWiseFFN(input_size=embedding_size,
                                   fft_hidden_size=ffn_hidden_size,
                                   output_size=embedding_size, )
        self.addnorm2 = AddNorm(hidden_size=embedding_size, dropout=dropout,)

    def forward(self, X, valid_length=None):
        atten_out = self.attention(query=X, key=X, value=X, valid_length=valid_length)
        addnorm_out = self.addnorm1(X, atten_out)
        ffn_out = self.ffn(addnorm_out)
        return self.addnorm2(addnorm_out, ffn_out)

class TransformerEncoder(nn.Module):
    def __init__(self, vocab_size, embedding_size, n_layers, hidden_size, num_heads, dropout, ):
        super(TransformerEncoder, self).__init__()
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.dropout = dropout

        self.word_embed = nn.Embedding(self.vocab_size, self.embedding_size)
        self.position_embed = PositionalEncoding(self.dropout,)
        self.drop = nn.Dropout(self.dropout)
        self.encoders = nn.ModuleList()
        for _ in range(self.n_layers):
            self.encoders.append(EncoderBlock(embedding_size=self.embedding_size,
                                              ffn_hidden_size=self.hidden_size,
                                              num_heads=self.num_heads,
                                              dropout=self.dropout, ))

    def forward(self, X, valid_length=None, max_seq_len=None):
        word_embedding = self.word_embed(X)
        # scale the embedding by sqrt(d_model) before adding the positional encoding
        word_embedding = word_embedding*math.sqrt(self.embedding_size) + \
            self.position_embed(word_embedding, max_seq_len=max_seq_len)
        Y = self.drop(word_embedding)
        for i in range(self.n_layers):
            Y = self.encoders[i](Y, valid_length=valid_length)
        return Y
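
# Shape sanity check for the encoder stack: an illustrative sketch with made-up
# sizes, not part of the library. The encoder maps token ids [batch, seq_len]
# to features [batch, seq_len, embedding_size].
if __name__ == "__main__":
    enc = TransformerEncoder(vocab_size=100, embedding_size=32, n_layers=2,
                             hidden_size=64, num_heads=4, dropout=0.1)
    tokens = torch.randint(0, 100, (2, 10))            # [batch=2, seq_len=10]
    lengths = torch.tensor([10, 7], dtype=torch.float)  # padded positions masked out
    print(enc(tokens, valid_length=lengths).size())    # torch.Size([2, 10, 32])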
# ============================ Decoder ==================================
class DecoderBlock(nn.Module):
    def __init__(self, embedding_size, ffn_hidden_size, num_heads, dropout,):
        super(DecoderBlock, self).__init__()
        self.attention1 = MultiHeadAttention(input_size=embedding_size,
                                             hidden_size=embedding_size,
                                             num_heads=num_heads,
                                             dropout=dropout, )
        self.addnorm1 = AddNorm(hidden_size=embedding_size, dropout=dropout,)
        self.attention2 = MultiHeadAttention(input_size=embedding_size,
                                             hidden_size=embedding_size,
                                             num_heads=num_heads,
                                             dropout=dropout, )
        self.addnorm2 = AddNorm(hidden_size=embedding_size, dropout=dropout,)
        self.ffn = PositionWiseFFN(input_size=embedding_size,
                                   fft_hidden_size=ffn_hidden_size,
                                   output_size=embedding_size, )
        self.addnorm3 = AddNorm(hidden_size=embedding_size, dropout=dropout,)

    def forward(self, X, state):
        enc_output, enc_valid_length = state[0], state[1]

        if self.training:  # nn.Module maintains self.training automatically
            batch_size, seq_len = X.size()[0], X.size()[1]
            # causal masking during training: query position i may attend to
            # positions 1..i only, expressed here as per-position valid lengths
            dec_valid_length = torch.tensor(np.tile(np.arange(1, seq_len+1), [batch_size, 1]),
                                            dtype=torch.float, device=X.device)
        else:
            dec_valid_length = None

        attention_1_out = self.attention1(X, X, X, dec_valid_length)
        addnorm_1_out = self.addnorm1(X, attention_1_out)
        attention_2_out = self.attention2(addnorm_1_out, enc_output, enc_output, enc_valid_length)
        addnorm_2_out = self.addnorm2(addnorm_1_out, attention_2_out)
        ffn_out = self.ffn(addnorm_2_out)
        addnorm_3_out = self.addnorm3(addnorm_2_out, ffn_out)
        return addnorm_3_out, state

class TransformerDecoder(nn.Module):
    def __init__(self, vocab_size, embedding_size, n_layers, hidden_size,
                 num_heads, dropout, ):
        super(TransformerDecoder, self).__init__()
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.dropout = dropout

        self.word_embed = nn.Embedding(vocab_size, embedding_size)
        self.position_embed = PositionalEncoding(self.dropout)
        self.dense = nn.Linear(embedding_size, vocab_size)
        self.drop = nn.Dropout(self.dropout)
        self.decoders = nn.ModuleList()
        for _ in range(self.n_layers):
            self.decoders.append(DecoderBlock(embedding_size=self.embedding_size,
                                              ffn_hidden_size=self.hidden_size,
                                              num_heads=self.num_heads,
                                              dropout=self.dropout, ))

    def init_state(self, enc_output, enc_valid_length):
        return [enc_output, enc_valid_length]

    def forward(self, X, state, max_seq_len=None):
        word_embedding = self.word_embed(X)
        word_embedding = word_embedding*math.sqrt(self.embedding_size) + \
            self.position_embed(word_embedding, max_seq_len=max_seq_len)
        Y = self.drop(word_embedding)
        for i in range(self.n_layers):
            Y, state = self.decoders[i](Y, state)
        return self.dense(Y), state
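
# End-to-end smoke test: an illustrative sketch with made-up sizes that wires
# the encoder and decoder together; main.ipynb defines the equivalent wrapper.
if __name__ == "__main__":
    enc = TransformerEncoder(vocab_size=100, embedding_size=32, n_layers=2,
                             hidden_size=64, num_heads=4, dropout=0.1)
    dec = TransformerDecoder(vocab_size=120, embedding_size=32, n_layers=2,
                             hidden_size=64, num_heads=4, dropout=0.1)
    src = torch.randint(0, 100, (2, 10))
    tgt = torch.randint(0, 120, (2, 10))
    src_len = torch.tensor([10, 7], dtype=torch.float)
    enc_out = enc(src, valid_length=src_len)
    logits, _ = dec(tgt, dec.init_state(enc_out, src_len))
    print(logits.size())  # torch.Size([2, 10, 120]): per-token vocabulary logits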
{ 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "### 1、导入相关的库\n", 22 | "- Transformer里面是关于Transformer模型的函数\n", 23 | "- util里面是相关的数据读取文件\n", 24 | "- train内是相关的训练和测试函数" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 2, 30 | "metadata": { 31 | "ExecuteTime": { 32 | "end_time": "2020-05-08T04:50:32.449156Z", 33 | "start_time": "2020-05-08T04:50:32.026571Z" 34 | } 35 | }, 36 | "outputs": [], 37 | "source": [ 38 | "import os\n", 39 | "from Transformer import *\n", 40 | "from util import *\n", 41 | "from train import *\n", 42 | "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1,2,3,4\"\n", 43 | "device = torch.device('cuda')" 44 | ] 45 | }, 46 | { 47 | "cell_type": "markdown", 48 | "metadata": {}, 49 | "source": [ 50 | "### 2、设置相关的参数" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 3, 56 | "metadata": { 57 | "ExecuteTime": { 58 | "end_time": "2020-05-08T04:50:32.458145Z", 59 | "start_time": "2020-05-08T04:50:32.451725Z" 60 | } 61 | }, 62 | "outputs": [], 63 | "source": [ 64 | "embedding_size = 32 # token的维度\n", 65 | "num_layers = 2 # 编码器和解码器的层数,这里两者层数相同,也可以不同\n", 66 | "dropout = 0.05 # 所有层的droprate都相同,也可以不同\n", 67 | "batch_size = 64 # 批次\n", 68 | "num_steps = 10 # 预测步长\n", 69 | "factor = 1 # 学习率因子\n", 70 | "warmup = 2000 # 学习率上升步长\n", 71 | "lr, num_epochs, ctx = 0.005, 500, device # 学习率;周期;设备\n", 72 | "num_hiddens, num_heads = 64, 4 # 隐层单元的数目——表示FFN中间层的输出维度;attention的数目" 73 | ] 74 | }, 75 | { 76 | "cell_type": "markdown", 77 | "metadata": {}, 78 | "source": [ 79 | "### 3、导入文件\n", 80 | "文件为fra.txt文件" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": 4, 86 | "metadata": { 87 | "ExecuteTime": { 88 | "end_time": "2020-05-08T04:50:38.060788Z", 89 | "start_time": "2020-05-08T04:50:32.461294Z" 90 | } 91 | }, 92 | "outputs": [], 93 | "source": [ 94 | "src_vocab, tgt_vocab, train_iter = load_data_nmt(batch_size, num_steps)" 95 | ] 96 | }, 97 | { 98 | "cell_type": "markdown", 99 | "metadata": {}, 100 | "source": [ 101 | "### 4、加载模型\n", 102 | "- TransformerEncoder为编码器模型\n", 103 | "- TransformerDecoder为解码器模型\n", 104 | "- transformer为编码器和解码器构成的最终模型" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": 5, 110 | "metadata": { 111 | "ExecuteTime": { 112 | "end_time": "2020-05-08T04:50:38.083069Z", 113 | "start_time": "2020-05-08T04:50:38.064776Z" 114 | } 115 | }, 116 | "outputs": [], 117 | "source": [ 118 | "encoder = TransformerEncoder(vocab_size=len(src_vocab), \n", 119 | " embedding_size=embedding_size, \n", 120 | " n_layers=num_layers, \n", 121 | " hidden_size=num_hiddens, \n", 122 | " num_heads=num_heads, \n", 123 | " dropout=dropout, )\n", 124 | "decoder = TransformerDecoder(vocab_size=len(src_vocab), \n", 125 | " embedding_size=embedding_size, \n", 126 | " n_layers=num_layers, \n", 127 | " hidden_size=num_hiddens, \n", 128 | " num_heads=num_heads, \n", 129 | " dropout=dropout, )" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": 6, 135 | "metadata": { 136 | "ExecuteTime": { 137 | "end_time": "2020-05-08T04:50:38.095197Z", 138 | "start_time": "2020-05-08T04:50:38.085535Z" 139 | } 140 | }, 141 | "outputs": [], 142 | "source": [ 143 | "class transformer(nn.Module):\n", 144 | " def __init__(self, enc_net, dec_net):\n", 145 | " super(transformer, self).__init__()\n", 146 | " self.enc_net = enc_net # TransformerEncoder的对象 \n", 147 | " self.dec_net = dec_net # TransformerDecoder的对象\n", 148 | " \n", 149 | " def forward(self, enc_X, dec_X, valid_length=None, 
max_seq_len=None):\n", 150 | " \"\"\"\n", 151 | " enc_X: 编码器的输入\n", 152 | " dec_X: 解码器的输入\n", 153 | " valid_length: 编码器的输入对应的valid_length,主要用于编码器attention的masksoftmax中,\n", 154 | " 并且还用于解码器的第二个attention的masksoftmax中\n", 155 | " max_seq_len: 位置编码时调整sin和cos周期大小的,默认大小为enc_X的第一个维度seq_len\n", 156 | " \"\"\"\n", 157 | " \n", 158 | " # 1、通过编码器得到编码器最后一层的输出enc_output\n", 159 | " enc_output = self.enc_net(enc_X, valid_length, max_seq_len)\n", 160 | " # 2、state为解码器的初始状态,state包含两个元素,分别为[enc_output, valid_length]\n", 161 | " state = self.dec_net.init_state(enc_output, valid_length)\n", 162 | " # 3、通过解码器得到编码器最后一层到线性层的输出output,这里的output不是解码器最后一层的输出,而是\n", 163 | " # 最后一层再连接线性层的输出\n", 164 | " output = self.dec_net(dec_X, state)\n", 165 | " return output" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": 7, 171 | "metadata": { 172 | "ExecuteTime": { 173 | "end_time": "2020-05-08T04:50:38.101624Z", 174 | "start_time": "2020-05-08T04:50:38.097736Z" 175 | } 176 | }, 177 | "outputs": [], 178 | "source": [ 179 | "model = transformer(encoder, decoder)" 180 | ] 181 | }, 182 | { 183 | "cell_type": "markdown", 184 | "metadata": {}, 185 | "source": [ 186 | "### 5、训练模型" 187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": 8, 192 | "metadata": { 193 | "ExecuteTime": { 194 | "end_time": "2020-05-08T04:55:48.207886Z", 195 | "start_time": "2020-05-08T04:50:38.104140Z" 196 | } 197 | }, 198 | "outputs": [ 199 | { 200 | "name": "stdout", 201 | "output_type": "stream", 202 | "text": [ 203 | "epoch 50,loss 0.096, time 29.3 sec\n", 204 | "epoch 100,loss 0.049, time 30.8 sec\n", 205 | "epoch 150,loss 0.042, time 30.0 sec\n", 206 | "epoch 200,loss 0.036, time 30.9 sec\n", 207 | "epoch 250,loss 0.035, time 31.7 sec\n", 208 | "epoch 300,loss 0.033, time 30.1 sec\n", 209 | "epoch 350,loss 0.032, time 31.6 sec\n", 210 | "epoch 400,loss 0.031, time 31.9 sec\n", 211 | "epoch 450,loss 0.031, time 30.1 sec\n", 212 | "epoch 500,loss 0.031, time 30.7 sec\n" 213 | ] 214 | } 215 | ], 216 | "source": [ 217 | "model.train()\n", 218 | "train(model, train_iter, lr, factor, warmup, num_epochs, ctx)" 219 | ] 220 | }, 221 | { 222 | "cell_type": "markdown", 223 | "metadata": {}, 224 | "source": [ 225 | "### 6、测试模型" 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "execution_count": 9, 231 | "metadata": { 232 | "ExecuteTime": { 233 | "end_time": "2020-05-08T04:55:48.353943Z", 234 | "start_time": "2020-05-08T04:55:48.212492Z" 235 | } 236 | }, 237 | "outputs": [ 238 | { 239 | "name": "stdout", 240 | "output_type": "stream", 241 | "text": [ 242 | "Go . => va !\n", 243 | "Wow ! => !\n", 244 | "I'm OK . => ça .\n", 245 | "I won ! 
=> je l'ai !\n" 246 | ] 247 | } 248 | ], 249 | "source": [ 250 | "model.eval()\n", 251 | "for sentence in ['Go .', 'Wow !', \"I'm OK .\", 'I won !']:\n", 252 | " print(sentence + ' => ' + translate(model, sentence, src_vocab, tgt_vocab, num_steps, ctx))" 253 | ] 254 | } 255 | ], 256 | "metadata": { 257 | "kernelspec": { 258 | "display_name": "Python [conda env:gpu-py37]", 259 | "language": "python", 260 | "name": "conda-env-gpu-py37-py" 261 | }, 262 | "language_info": { 263 | "codemirror_mode": { 264 | "name": "ipython", 265 | "version": 3 266 | }, 267 | "file_extension": ".py", 268 | "mimetype": "text/x-python", 269 | "name": "python", 270 | "nbconvert_exporter": "python", 271 | "pygments_lexer": "ipython3", 272 | "version": "3.7.6" 273 | } 274 | }, 275 | "nbformat": 4, 276 | "nbformat_minor": 2 277 | } 278 | -------------------------------------------------------------------------------- /source/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cingtiye/Transformer-pytorch/416af13f53ef197aaf11fdf19bc33a27d878a9f9/source/1.jpg -------------------------------------------------------------------------------- /source/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cingtiye/Transformer-pytorch/416af13f53ef197aaf11fdf19bc33a27d878a9f9/source/2.png -------------------------------------------------------------------------------- /source/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cingtiye/Transformer-pytorch/416af13f53ef197aaf11fdf19bc33a27d878a9f9/source/3.png -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import torch.optim as optim 4 | import time 5 | import torch.nn as nn 6 | 7 | def SequenceMask(X, X_len,value=0): 8 | maxlen = X.size(1) 9 | mask = torch.arange(maxlen)[None, :].to(X_len.device) < X_len[:, None] 10 | X[~mask] = value 11 | return X 12 | 13 | class MaskedSoftmaxCELoss(nn.CrossEntropyLoss): 14 | def forward(self, pred, label, valid_length): 15 | # the sample weights shape should be (batch_size, seq_len) 16 | weights = torch.ones_like(label) 17 | weights = SequenceMask(weights, valid_length).float() 18 | self.reduction='none' 19 | output=super(MaskedSoftmaxCELoss, self).forward(pred.transpose(1,2), label) 20 | return (output*weights).mean(dim=1) 21 | 22 | class NoamOpt: 23 | def __init__(self, model_size, factor, warmup, optimizer): 24 | self.optimizer = optimizer # 优化器 25 | self._step = 0 # 步长 26 | self.warmup = warmup # warmup_steps 27 | self.factor = factor # 学习率因子(就是学习率前面的系数) 28 | self.model_size = model_size # d_model 29 | self._rate = 0 # 学习率 30 | 31 | def step(self): 32 | "Update parameters and rate" 33 | self._step += 1 34 | rate = self.rate() 35 | for p in self.optimizer.param_groups: 36 | p['lr'] = rate 37 | self._rate = rate 38 | self.optimizer.step() 39 | 40 | def rate(self, step=None): 41 | "Implement `lrate` above" 42 | if step is None: 43 | step = self._step 44 | return self.factor * \ 45 | (self.model_size ** (-0.5) * 46 | min(step ** (-0.5), step * self.warmup ** (-1.5))) 47 | 48 | def grad_clipping(params, theta, device): 49 | """Clip the gradient.""" 50 | norm = torch.tensor([0], dtype=torch.float32, device=device) 51 | for param in params: 52 | norm += (param.grad ** 
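
# A quick look at the Noam schedule (an illustrative sketch, not used by the
# training code): the rate grows linearly for `warmup` steps and then decays
# as step^(-0.5), peaking around step == warmup.
if __name__ == "__main__":
    sched = NoamOpt(model_size=32, factor=1, warmup=2000,
                    optimizer=torch.optim.Adam([torch.zeros(1, requires_grad=True)], lr=0))
    for s in [1, 1000, 2000, 4000, 10000]:
        print(s, sched.rate(s))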
def grad_clipping(params, theta, device):
    """Clip the gradient."""
    norm = torch.tensor([0], dtype=torch.float32, device=device)
    for param in params:
        norm += (param.grad ** 2).sum()
    norm = norm.sqrt().item()
    if norm > theta:
        for param in params:
            param.grad.data.mul_(theta / norm)

def grad_clipping_nn(model, theta, device):
    """Clip the gradient for a nn model."""
    grad_clipping(model.parameters(), theta, device)

# def get_std_opt(model):
#     return NoamOpt(model.src_embed[0].d_model, 2, 4000,
#                    torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))

def train(model, data_iter, lr, factor, warmup, num_epochs, device):
    model.to(device)
    # optimizer = optim.Adam(model.parameters(), lr=lr)
    optimizer = NoamOpt(model.enc_net.embedding_size, factor, warmup,
                        torch.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.98), eps=1e-9))
    loss = MaskedSoftmaxCELoss()
    tic = time.time()
    for epoch in range(1, num_epochs + 1):
        l_sum, num_tokens_sum = 0.0, 0.0
        for batch in data_iter:
            optimizer.optimizer.zero_grad()
            X, X_vlen, Y, Y_vlen = [x.to(device) for x in batch]
            Y_input, Y_label, Y_vlen = Y[:, :-1], Y[:, 1:], Y_vlen - 1

            Y_hat, _ = model(X, Y_input, X_vlen)
            l = loss(Y_hat, Y_label, Y_vlen).sum()
            l.backward()

            with torch.no_grad():
                grad_clipping_nn(model, 5, device)
            num_tokens = Y_vlen.sum().item()
            optimizer.step()
            l_sum += l.sum().item()
            num_tokens_sum += num_tokens
        if epoch % 50 == 0:
            print("epoch {0:4d},loss {1:.3f}, time {2:.1f} sec".format(
                epoch, (l_sum / num_tokens_sum), time.time() - tic))
            tic = time.time()

def translate(model, src_sentence, src_vocab, tgt_vocab, max_len, device):
    """Translate based on an encoder-decoder model with greedy search."""
    src_tokens = src_vocab[src_sentence.lower().split(' ')]
    src_len = len(src_tokens)
    if src_len < max_len:
        src_tokens += [src_vocab.pad] * (max_len - src_len)
    enc_X = torch.tensor(src_tokens, device=device)
    enc_valid_length = torch.tensor([src_len], device=device)
    # unsqueeze adds the batch_size dimension.
    enc_outputs = model.enc_net(enc_X.unsqueeze(dim=0), enc_valid_length)
    dec_state = model.dec_net.init_state(enc_outputs, enc_valid_length)
    dec_X = torch.tensor([tgt_vocab.bos], device=device).unsqueeze(dim=0)
    predict_tokens = []
    for _ in range(max_len):
        # Note: only the most recent token is fed back in, and the decoder keeps
        # no cache, so self-attention at inference sees just that one token.
        Y, dec_state = model.dec_net(dec_X, dec_state)
        # The token with the highest score is used as the next time step's input.
        dec_X = Y.argmax(dim=2)
        py = dec_X.squeeze(dim=0).int().item()
        if py == tgt_vocab.eos:
            break
        predict_tokens.append(py)
    return ' '.join(tgt_vocab.to_tokens(predict_tokens))
--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

import torch
from torch.utils import data
import collections

class Vocab(object):  # This class is saved in d2l.
    def __init__(self, tokens, min_freq=0, use_special_tokens=False):
        # sort by frequency (descending), breaking ties by token
        counter = collections.Counter(tokens)
        token_freqs = sorted(counter.items(), key=lambda x: x[0])
        token_freqs.sort(key=lambda x: x[1], reverse=True)
        if use_special_tokens:
            # padding, beginning of sentence, end of sentence, unknown
            self.pad, self.bos, self.eos, self.unk = (0, 1, 2, 3)
            tokens = ['<pad>', '<bos>', '<eos>', '<unk>']
        else:
            self.unk = 0
            tokens = ['<unk>']
        tokens += [token for token, freq in token_freqs if freq >= min_freq]
        self.idx_to_token = []
        self.token_to_idx = dict()
        for token in tokens:
            self.idx_to_token.append(token)
            self.token_to_idx[token] = len(self.idx_to_token) - 1

    def __len__(self):
        return len(self.idx_to_token)

    def __getitem__(self, tokens):
        if not isinstance(tokens, (list, tuple)):
            return self.token_to_idx.get(tokens, self.unk)
        else:
            return [self.__getitem__(token) for token in tokens]

    def to_tokens(self, indices):
        if not isinstance(indices, (list, tuple)):
            return self.idx_to_token[indices]
        else:
            return [self.idx_to_token[index] for index in indices]
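
# Vocab in action (an illustrative sketch, not used by the training code):
if __name__ == "__main__":
    v = Vocab('the cat sat on the mat'.split(), use_special_tokens=True)
    print(len(v))                       # 4 special tokens + 5 distinct words = 9
    print(v[['the', 'cat', 'dog']])     # unseen 'dog' maps to v.unk
    print(v.to_tokens([v.pad, v.bos]))  # ['<pad>', '<bos>']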

def load_data_nmt(batch_size, max_len, num_examples=1000):
    """Read the local English-French dataset (fra.txt) and return its vocabularies and data iterator."""

    # Preprocess: normalize whitespace and put a space before punctuation
    def preprocess_raw(text):
        text = text.replace('\u202f', ' ').replace('\xa0', ' ')
        out = ''
        for i, char in enumerate(text.lower()):
            if char in (',', '!', '.') and text[i - 1] != ' ':
                out += ' '
            out += char
        return out

    with open('./fra.txt', 'r', encoding='utf-8') as f:
        raw_text = f.read()

    text = preprocess_raw(raw_text)

    # Tokenize
    source, target = [], []
    for i, line in enumerate(text.split('\n')):
        if i >= num_examples:
            break
        parts = line.split('\t')
        if len(parts) >= 2:
            source.append(parts[0].split(' '))
            target.append(parts[1].split(' '))

    # Build vocab
    def build_vocab(tokens):
        tokens = [token for line in tokens for token in line]
        return Vocab(tokens, min_freq=3, use_special_tokens=True)

    src_vocab, tgt_vocab = build_vocab(source), build_vocab(target)

    # Convert to index arrays
    def pad(line, max_len, padding_token):
        if len(line) > max_len:
            return line[:max_len]
        return line + [padding_token] * (max_len - len(line))

    def build_array(lines, vocab, max_len, is_source):
        lines = [vocab[line] for line in lines]
        if not is_source:
            lines = [[vocab.bos] + line + [vocab.eos] for line in lines]
        array = torch.tensor([pad(line, max_len, vocab.pad) for line in lines])
        valid_len = (array != vocab.pad).sum(1)
        return array, valid_len

    src_array, src_valid_len = build_array(source, src_vocab, max_len, True)
    tgt_array, tgt_valid_len = build_array(target, tgt_vocab, max_len, False)
    train_data = data.TensorDataset(src_array, src_valid_len, tgt_array, tgt_valid_len)
    train_iter = data.DataLoader(train_data, batch_size, shuffle=True)
    return src_vocab, tgt_vocab, train_iter
--------------------------------------------------------------------------------