├── README.md
├── config.py
├── dataset.py
├── decoder.py
├── decoder_block.py
├── emb.py
├── encoder.py
├── encoder_block.py
├── evaluation.py
├── multihead_attn.py
├── ref
│   ├── AI是如何学习的.pptx
│   ├── aixiaoyutang.jpg
│   ├── transformer_1.pdf
│   └── transformer_2.pdf
├── train.py
└── transformer.py

/README.md:
--------------------------------------------------------------------------------
# pytorch-transformer

```
A PyTorch re-implementation of the Transformer.

Dataset: German-to-English translation
```

# Sample output

```
$ python evaluation.py
Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche. -> Two young , White males are outside near many bushes .
Mehrere Männer mit Schutzhelmen bedienen ein Antriebsradsystem. -> Several men in hard hats are operating a giant pulley system .
Ein kleines Mädchen klettert in ein Spielhaus aus Holz. -> A little girl climbing into a wooden playhouse .
Ein Mann in einem blauen Hemd steht auf einer Leiter und putzt ein Fenster. -> A man in a blue shirt is standing on a ladder cleaning a window .
Zwei Männer stehen am Herd und bereiten Essen zu. -> Two men are at the stove preparing food .
Ein Mann in grün hält eine Gitarre, während der andere Mann sein Hemd ansieht. -> A man in green holds a guitar while the other man observes his shirt .
Ein Mann lächelt einen ausgestopften Löwen an. -> A man is smiling at a stuffed lion
Ein schickes Mädchen spricht mit dem Handy während sie langsam die Straße entlangschwebt. -> A trendy girl talking on her cellphone while gliding slowly down the street .
```

# Notes

```
ref: theory material
checkpoints: the trained model
train.py: training
evaluation.py: inference
```

# Learn more

```
All of the code and materials are free and public, but if you find them hard to follow, you should probably join my private tutoring class:

Inside the fish pond I do two things:
1. Answer questions about this project
2. Run regular live sessions that walk through this project
```

![ai小鱼塘](ref/aixiaoyutang.jpg)
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
import torch

# device
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# maximum sequence length (limited by the position embedding)
SEQ_MAX_LEN=5000
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
'''
German -> English translation dataset
Reference: https://pytorch.org/tutorials/beginner/translation_transformer.html
'''

from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets import multi30k, Multi30k

# download the translation dataset
multi30k.URL["train"] = "https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/training.tar.gz"
multi30k.URL["valid"] = "https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/validation.tar.gz"
train_dataset = list(Multi30k(split='train', language_pair=('de', 'en')))

# build the tokenizers
de_tokenizer=get_tokenizer('spacy', language='de_core_news_sm')
en_tokenizer=get_tokenizer('spacy', language='en_core_web_sm')

# build the vocabularies
UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3 # special tokens
UNK_SYM, PAD_SYM, BOS_SYM, EOS_SYM = '<unk>', '<pad>', '<bos>', '<eos>'

de_tokens=[] # tokenized German sentences
en_tokens=[] # tokenized English sentences
for de,en in train_dataset:
    de_tokens.append(de_tokenizer(de))
    en_tokens.append(en_tokenizer(en))

de_vocab=build_vocab_from_iterator(de_tokens,specials=[UNK_SYM, PAD_SYM, BOS_SYM, EOS_SYM],special_first=True) # German vocabulary
de_vocab.set_default_index(UNK_IDX)
en_vocab=build_vocab_from_iterator(en_tokens,specials=[UNK_SYM, PAD_SYM, BOS_SYM, EOS_SYM],special_first=True) # English vocabulary
en_vocab.set_default_index(UNK_IDX)

# sentence preprocessing: tokenize, wrap with <bos>/<eos>, map tokens to ids
def de_preprocess(de_sentence):
    tokens=de_tokenizer(de_sentence)
    tokens=[BOS_SYM]+tokens+[EOS_SYM]
    ids=de_vocab(tokens)
    return tokens,ids

def en_preprocess(en_sentence):
    tokens=en_tokenizer(en_sentence)
    tokens=[BOS_SYM]+tokens+[EOS_SYM]
    ids=en_vocab(tokens)
    return tokens,ids

if __name__ == '__main__':
    # vocabulary sizes
    print('de vocab:', len(de_vocab))
    print('en vocab:', len(en_vocab))

    # preprocessing example
    de_sentence,en_sentence=train_dataset[0]
    print('de preprocess:',*de_preprocess(de_sentence))
    print('en preprocess:',*en_preprocess(en_sentence))
--------------------------------------------------------------------------------
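The preprocessing contract above (spaCy tokenization, <bos>/<eos> wrapping, id lookup with <unk> as the fallback) is what every module below relies on. Here is a minimal sketch of it, not part of the repository, assuming dataset.py imports cleanly (Multi30k downloaded, both spaCy models installed); the example sentence and the made-up word are arbitrary:

```
from dataset import de_preprocess, de_vocab, UNK_IDX, BOS_IDX, EOS_IDX

tokens, ids = de_preprocess('Zwei Männer stehen am Herd.')   # hypothetical input sentence
print(tokens)                                 # spaCy tokens wrapped in '<bos>' ... '<eos>'
print(ids[0] == BOS_IDX, ids[-1] == EOS_IDX)  # True True, ids line up 1:1 with tokens

# set_default_index(UNK_IDX) means out-of-vocabulary words map to UNK_IDX
# instead of raising an error
print(de_vocab(['Quantencomputer123']) == [UNK_IDX])          # expected: True
```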
/decoder.py:
--------------------------------------------------------------------------------
'''
Decoder: given the current target token sequence, output the probability of the next token at every position.
'''
from torch import nn
import torch
from emb import EmbeddingWithPosition
from dataset import de_preprocess,en_preprocess,train_dataset,de_vocab,PAD_IDX,en_vocab
from decoder_block import DecoderBlock
from encoder import Encoder
from config import DEVICE

class Decoder(nn.Module):
    def __init__(self,vocab_size,emb_size,q_k_size,v_size,f_size,head,nblocks,dropout=0.1,seq_max_len=5000):
        super().__init__()
        self.emb=EmbeddingWithPosition(vocab_size,emb_size,dropout,seq_max_len)

        self.decoder_blocks=nn.ModuleList()
        for _ in range(nblocks):
            self.decoder_blocks.append(DecoderBlock(emb_size,q_k_size,v_size,f_size,head))

        # project each output vector to vocabulary logits
        self.linear=nn.Linear(emb_size,vocab_size)

    def forward(self,x,encoder_z,encoder_x): # x: (batch_size,seq_len)
        first_attn_mask=(x==PAD_IDX).unsqueeze(1).expand(x.size()[0],x.size()[1],x.size()[1]).to(DEVICE) # pad mask of the target sequence
        first_attn_mask=first_attn_mask|torch.triu(torch.ones(x.size()[1],x.size()[1]),diagonal=1).bool().unsqueeze(0).expand(x.size()[0],-1,-1).to(DEVICE) # combined (OR) with the look-ahead mask of the target sequence
        # use the source sequence's pad mask to block the decoder's attention to its pad positions
        second_attn_mask=(encoder_x==PAD_IDX).unsqueeze(1).expand(encoder_x.size()[0],x.size()[1],encoder_x.size()[1]).to(DEVICE) # (batch_size,target_len,src_len)

        x=self.emb(x)
        for block in self.decoder_blocks:
            x=block(x,encoder_z,first_attn_mask,second_attn_mask)

        return self.linear(x) # (batch_size,target_len,vocab_size)

if __name__=='__main__':
    # take 2 German sentences and convert them to token-id sequences for the encoder
    de_tokens1,de_ids1=de_preprocess(train_dataset[0][0])
    de_tokens2,de_ids2=de_preprocess(train_dataset[1][0])
    # the corresponding 2 English sentences become token-id sequences for the decoder
    en_tokens1,en_ids1=en_preprocess(train_dataset[0][1])
    en_tokens2,en_ids2=en_preprocess(train_dataset[1][1])

    # pad the German sentences to the same length and form a batch
    if len(de_ids1)<len(de_ids2):
        de_ids1.extend([PAD_IDX]*(len(de_ids2)-len(de_ids1)))
    elif len(de_ids1)>len(de_ids2):
        de_ids2.extend([PAD_IDX]*(len(de_ids1)-len(de_ids2)))

    enc_x_batch=torch.tensor([de_ids1,de_ids2],dtype=torch.long).to(DEVICE)
    print('enc_x_batch batch:', enc_x_batch.size())

    # pad the English sentences to the same length and form a batch
    if len(en_ids1)<len(en_ids2):
        en_ids1.extend([PAD_IDX]*(len(en_ids2)-len(en_ids1)))
    elif len(en_ids1)>len(en_ids2):
        en_ids2.extend([PAD_IDX]*(len(en_ids1)-len(en_ids2)))

    dec_x_batch=torch.tensor([en_ids1,en_ids2],dtype=torch.long).to(DEVICE)
    print('dec_x_batch batch:', dec_x_batch.size())

    # run the encoder, producing one encoded vector per source token
    enc=Encoder(vocab_size=len(de_vocab),emb_size=128,q_k_size=256,v_size=512,f_size=512,head=8,nblocks=3).to(DEVICE)
    enc_outputs=enc(enc_x_batch)
    print('encoder outputs:', enc_outputs.size())

    # run the decoder, producing next-token probabilities for every target position
    dec=Decoder(vocab_size=len(en_vocab),emb_size=128,q_k_size=256,v_size=512,f_size=512,head=8,nblocks=3).to(DEVICE)
    dec_outputs=dec(dec_x_batch,enc_outputs,enc_x_batch)
    print('decoder outputs:', dec_outputs.size())
--------------------------------------------------------------------------------
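Decoder.forward above builds two boolean masks, and their shapes are where most confusion happens. The following standalone sketch (it only assumes PAD_IDX==1, as defined in dataset.py) prints both masks for a toy batch; True marks the attention scores that masked_fill later pushes to -1e9:

```
import torch

PAD_IDX = 1
dec_x = torch.tensor([[2, 7, 9, 3, 1, 1]])        # decoder input, last two positions are padding
enc_x = torch.tensor([[2, 5, 6, 8, 3, 1, 1, 1]])  # encoder input, last three positions are padding

# self-attention mask: pad columns OR the upper-triangular "no peeking ahead" mask
pad_mask = (dec_x == PAD_IDX).unsqueeze(1).expand(1, dec_x.size(1), dec_x.size(1))
causal_mask = torch.triu(torch.ones(dec_x.size(1), dec_x.size(1)), diagonal=1).bool().unsqueeze(0)
first_attn_mask = pad_mask | causal_mask
print(first_attn_mask[0].int())    # (target_len, target_len)

# cross-attention mask: every decoder position ignores the encoder's pad columns
second_attn_mask = (enc_x == PAD_IDX).unsqueeze(1).expand(1, dec_x.size(1), enc_x.size(1))
print(second_attn_mask[0].int())   # (target_len, src_len)
```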
/decoder_block.py:
--------------------------------------------------------------------------------
'''
Decoder blocks can be stacked; each block takes an embedding sequence and returns an embedding sequence of the same length (1:1).
'''
from torch import nn
import torch
from multihead_attn import MultiHeadAttention
from emb import EmbeddingWithPosition
from dataset import de_preprocess,en_preprocess,train_dataset,de_vocab,PAD_IDX,en_vocab
from encoder import Encoder
from config import DEVICE

class DecoderBlock(nn.Module):
    def __init__(self,emb_size,q_k_size,v_size,f_size,head):
        super().__init__()

        # first multi-head attention (masked self-attention)
        self.first_multihead_attn=MultiHeadAttention(emb_size,q_k_size,v_size,head)
        self.z_linear1=nn.Linear(head*v_size,emb_size)
        self.addnorm1=nn.LayerNorm(emb_size)

        # second multi-head attention (cross-attention over the encoder output)
        self.second_multihead_attn=MultiHeadAttention(emb_size,q_k_size,v_size,head)
        self.z_linear2=nn.Linear(head*v_size,emb_size)
        self.addnorm2=nn.LayerNorm(emb_size)

        # feed-forward network
        self.feedforward=nn.Sequential(
            nn.Linear(emb_size,f_size),
            nn.ReLU(),
            nn.Linear(f_size,emb_size)
        )
        self.addnorm3=nn.LayerNorm(emb_size)

    def forward(self,x,encoder_z,first_attn_mask,second_attn_mask): # x: (batch_size,seq_len,emb_size)
        # first multi-head attention
        z=self.first_multihead_attn(x,x,first_attn_mask) # z: (batch_size,seq_len,head*v_size); first_attn_mask hides the decoder sequence's pad positions and stops each query from attending to later tokens
        z=self.z_linear1(z) # z: (batch_size,seq_len,emb_size)
        output1=self.addnorm1(z+x) # output1: (batch_size,seq_len,emb_size)

        # second multi-head attention
        z=self.second_multihead_attn(output1,encoder_z,second_attn_mask) # z: (batch_size,seq_len,head*v_size); second_attn_mask hides the encoder sequence's pad positions so decoder queries ignore them
        z=self.z_linear2(z) # z: (batch_size,seq_len,emb_size)
        output2=self.addnorm2(z+output1) # output2: (batch_size,seq_len,emb_size)

        # final feed-forward
        z=self.feedforward(output2) # z: (batch_size,seq_len,emb_size)
        return self.addnorm3(z+output2) # (batch_size,seq_len,emb_size)

if __name__=='__main__':
    # take 2 German sentences and convert them to token-id sequences for the encoder
    de_tokens1,de_ids1=de_preprocess(train_dataset[0][0])
    de_tokens2,de_ids2=de_preprocess(train_dataset[1][0])
    # the corresponding 2 English sentences become token-id sequences, which are embedded and fed to the decoder
    en_tokens1,en_ids1=en_preprocess(train_dataset[0][1])
    en_tokens2,en_ids2=en_preprocess(train_dataset[1][1])

    # pad the German sentences to the same length and form a batch
    if len(de_ids1)<len(de_ids2):
        de_ids1.extend([PAD_IDX]*(len(de_ids2)-len(de_ids1)))
    elif len(de_ids1)>len(de_ids2):
        de_ids2.extend([PAD_IDX]*(len(de_ids1)-len(de_ids2)))

    enc_x_batch=torch.tensor([de_ids1,de_ids2],dtype=torch.long).to(DEVICE)
    print('enc_x_batch batch:', enc_x_batch.size())

    # pad the English sentences to the same length and form a batch
    if len(en_ids1)<len(en_ids2):
        en_ids1.extend([PAD_IDX]*(len(en_ids2)-len(en_ids1)))
    elif len(en_ids1)>len(en_ids2):
        en_ids2.extend([PAD_IDX]*(len(en_ids1)-len(en_ids2)))

    dec_x_batch=torch.tensor([en_ids1,en_ids2],dtype=torch.long).to(DEVICE)
    print('dec_x_batch batch:', dec_x_batch.size())

    # run the encoder, producing one encoded vector per source token
    enc=Encoder(vocab_size=len(de_vocab),emb_size=128,q_k_size=256,v_size=512,f_size=512,head=8,nblocks=3).to(DEVICE)
    enc_outputs=enc(enc_x_batch)
    print('encoder outputs:', enc_outputs.size())

    # build the masks the decoder needs
    first_attn_mask=(dec_x_batch==PAD_IDX).unsqueeze(1).expand(dec_x_batch.size()[0],dec_x_batch.size()[1],dec_x_batch.size()[1]) # pad mask of the target sequence
    first_attn_mask=first_attn_mask|torch.triu(torch.ones(dec_x_batch.size()[1],dec_x_batch.size()[1]),diagonal=1).bool().unsqueeze(0).expand(dec_x_batch.size()[0],-1,-1).to(DEVICE) # combined (OR) with the look-ahead mask of the target sequence
    print('first_attn_mask:',first_attn_mask.size())
    # use the source sequence's pad mask to block every decoder query's attention to the encoder's pad keys
    second_attn_mask=(enc_x_batch==PAD_IDX).unsqueeze(1).expand(enc_x_batch.size()[0],dec_x_batch.size()[1],enc_x_batch.size()[1]) # (batch_size,target_len,src_len)
    print('second_attn_mask:',second_attn_mask.size())

    first_attn_mask=first_attn_mask.to(DEVICE)
    second_attn_mask=second_attn_mask.to(DEVICE)

    # embed the decoder input first
    emb=EmbeddingWithPosition(len(en_vocab),128).to(DEVICE)
    dec_x_emb_batch=emb(dec_x_batch)
    print('dec_x_emb_batch:',dec_x_emb_batch.size())

    # stack 5 decoder blocks
    decoder_blocks=[]
    for i in range(5):
        decoder_blocks.append(DecoderBlock(emb_size=128,q_k_size=256,v_size=512,f_size=512,head=8).to(DEVICE))

    for i in range(5):
        dec_x_emb_batch=decoder_blocks[i](dec_x_emb_batch,enc_outputs,first_attn_mask,second_attn_mask)
    print('decoder_outputs:',dec_x_emb_batch.size())
--------------------------------------------------------------------------------
/emb.py:
--------------------------------------------------------------------------------
'''
Take a token-id sequence, embed each id, then add positional information.
'''
from torch import nn
import torch
from dataset import de_vocab,de_preprocess,train_dataset
import math

class EmbeddingWithPosition(nn.Module):
    def __init__(self,vocab_size,emb_size,dropout=0.1,seq_max_len=5000):
        super().__init__()

        # map every token in the sequence to an embedding vector; the other dimensions are unchanged
        self.seq_emb=nn.Embedding(vocab_size,emb_size)

        # prepare one position vector (also emb_size wide) for every position in the sequence
        position_idx=torch.arange(0,seq_max_len,dtype=torch.float).unsqueeze(-1)
        position_emb_fill=position_idx*torch.exp(-torch.arange(0,emb_size,2)*math.log(10000.0)/emb_size)
        pos_encoding=torch.zeros(seq_max_len,emb_size)
        pos_encoding[:,0::2]=torch.sin(position_emb_fill)
        pos_encoding[:,1::2]=torch.cos(position_emb_fill)
        self.register_buffer('pos_encoding',pos_encoding) # fixed, not trained

        # dropout against overfitting
        self.dropout=nn.Dropout(dropout)

    def forward(self,x): # x: (batch_size,seq_len)
        x=self.seq_emb(x) # x: (batch_size,seq_len,emb_size)
        x=x+self.pos_encoding.unsqueeze(0)[:,:x.size()[1],:] # x: (batch_size,seq_len,emb_size)
        return self.dropout(x)

if __name__=='__main__':
    emb=EmbeddingWithPosition(len(de_vocab),128)

    de_tokens,de_ids=de_preprocess(train_dataset[0][0]) # take a German sentence and convert it to a token-id sequence
    de_ids_tensor=torch.tensor(de_ids,dtype=torch.long)

    emb_result=emb(de_ids_tensor.unsqueeze(0)) # add a batch dimension before feeding the model
    print('de_ids_tensor:', de_ids_tensor.size(), 'emb_result:', emb_result.size())
--------------------------------------------------------------------------------
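EmbeddingWithPosition fills its buffer with the sinusoidal encoding PE(pos, 2i) = sin(pos / 10000^(2i/d)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d)). A quick sketch, not part of the repository, that checks one arbitrary position/dimension pair against that closed form (it assumes emb.py, and therefore dataset.py, imports cleanly):

```
import math
import torch
from emb import EmbeddingWithPosition

d = 128
emb = EmbeddingWithPosition(vocab_size=100, emb_size=d)  # vocab_size does not affect the positional buffer
pos, i = 10, 3                                           # arbitrary position / even-dimension pair
angle = pos / (10000 ** (2 * i / d))
print(torch.isclose(emb.pos_encoding[pos, 2 * i], torch.tensor(math.sin(angle))))      # expected: True
print(torch.isclose(emb.pos_encoding[pos, 2 * i + 1], torch.tensor(math.cos(angle))))  # expected: True
```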
/encoder.py:
--------------------------------------------------------------------------------
'''
Encoder: takes a token-id sequence and outputs one encoded vector per token (input and output are 1:1).
'''
from torch import nn
import torch
from encoder_block import EncoderBlock
from emb import EmbeddingWithPosition
from dataset import de_preprocess,train_dataset,de_vocab,PAD_IDX
from config import DEVICE

class Encoder(nn.Module):
    def __init__(self,vocab_size,emb_size,q_k_size,v_size,f_size,head,nblocks,dropout=0.1,seq_max_len=5000):
        super().__init__()
        self.emb=EmbeddingWithPosition(vocab_size,emb_size,dropout,seq_max_len)

        self.encoder_blocks=nn.ModuleList()
        for _ in range(nblocks):
            self.encoder_blocks.append(EncoderBlock(emb_size,q_k_size,v_size,f_size,head))

    def forward(self,x): # x:(batch_size,seq_len)
        pad_mask=(x==PAD_IDX).unsqueeze(1) # pad_mask:(batch_size,1,seq_len)
        pad_mask=pad_mask.expand(x.size()[0],x.size()[1],x.size()[1]) # pad_mask:(batch_size,seq_len,seq_len)

        pad_mask=pad_mask.to(DEVICE)

        x=self.emb(x)
        for block in self.encoder_blocks:
            x=block(x,pad_mask) # x:(batch_size,seq_len,emb_size)
        return x

if __name__=='__main__':
    # take 2 German sentences and convert them to token-id sequences
    de_tokens1,de_ids1=de_preprocess(train_dataset[0][0])
    de_tokens2,de_ids2=de_preprocess(train_dataset[1][0])

    # pad to the same length and form a batch
    if len(de_ids1)<len(de_ids2):
        de_ids1.extend([PAD_IDX]*(len(de_ids2)-len(de_ids1)))
    elif len(de_ids1)>len(de_ids2):
        de_ids2.extend([PAD_IDX]*(len(de_ids1)-len(de_ids2)))

    batch=torch.tensor([de_ids1,de_ids2],dtype=torch.long).to(DEVICE)
    print('batch:', batch.size())

    # encode
    encoder=Encoder(vocab_size=len(de_vocab),emb_size=128,q_k_size=256,v_size=512,f_size=512,head=8,nblocks=3).to(DEVICE)
    z=encoder.forward(batch)
    print('encoder outputs:', z.size())
--------------------------------------------------------------------------------
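The encoder only needs a pad mask (no causal mask), and that mask changes shape twice before it is applied. A standalone sketch, assuming PAD_IDX==1 and, for brevity, 2 heads, tracing (batch, seq) -> (batch, seq, seq) in Encoder.forward -> (batch, head, seq, seq) inside MultiHeadAttention:

```
import torch

PAD_IDX, head = 1, 2
x = torch.tensor([[5, 9, 4, 1, 1]])                          # one sentence, two pad positions

pad_mask = (x == PAD_IDX).unsqueeze(1)                       # (1, 1, 5)
pad_mask = pad_mask.expand(x.size(0), x.size(1), x.size(1))  # (1, 5, 5), what Encoder.forward passes down
per_head = pad_mask.unsqueeze(1).expand(-1, head, -1, -1)    # (1, 2, 5, 5), what MultiHeadAttention applies
print(pad_mask.shape, per_head.shape)
print(per_head[0, 0].int())                                  # columns 3 and 4 are masked for every query
```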
/encoder_block.py:
--------------------------------------------------------------------------------
'''
Encoder blocks can be stacked; each block takes an embedding sequence and returns an embedding sequence of the same length (1:1).
'''
from torch import nn
import torch
from multihead_attn import MultiHeadAttention
from emb import EmbeddingWithPosition
from dataset import de_preprocess,train_dataset,de_vocab

class EncoderBlock(nn.Module):
    def __init__(self,emb_size,q_k_size,v_size,f_size,head):
        super().__init__()

        self.multihead_attn=MultiHeadAttention(emb_size,q_k_size,v_size,head) # multi-head attention
        self.z_linear=nn.Linear(head*v_size,emb_size) # project the multi-head output back to emb_size
        self.addnorm1=nn.LayerNorm(emb_size) # normalize over the last dim

        # feed-forward network
        self.feedforward=nn.Sequential(
            nn.Linear(emb_size,f_size),
            nn.ReLU(),
            nn.Linear(f_size,emb_size)
        )
        self.addnorm2=nn.LayerNorm(emb_size) # normalize over the last dim

    def forward(self,x,attn_mask): # x: (batch_size,seq_len,emb_size)
        z=self.multihead_attn(x,x,attn_mask) # z: (batch_size,seq_len,head*v_size)
        z=self.z_linear(z) # z: (batch_size,seq_len,emb_size)
        output1=self.addnorm1(z+x) # output1: (batch_size,seq_len,emb_size)

        z=self.feedforward(output1) # z: (batch_size,seq_len,emb_size)
        return self.addnorm2(z+output1) # (batch_size,seq_len,emb_size)

if __name__=='__main__':
    # prepare a batch of 1
    emb=EmbeddingWithPosition(len(de_vocab),128)
    de_tokens,de_ids=de_preprocess(train_dataset[0][0]) # take a German sentence and convert it to a token-id sequence
    de_ids_tensor=torch.tensor(de_ids,dtype=torch.long)
    emb_result=emb(de_ids_tensor.unsqueeze(0)) # add a batch dimension before feeding the model
    print('emb_result:', emb_result.size())

    attn_mask=torch.zeros((1,de_ids_tensor.size()[0],de_ids_tensor.size()[0]),dtype=torch.bool) # one boolean attention mask per sample; all False, so nothing is masked

    # stack 5 encoder blocks
    encoder_blocks=[]
    for i in range(5):
        encoder_blocks.append(EncoderBlock(emb_size=128,q_k_size=256,v_size=512,f_size=512,head=8))

    # forward pass
    encoder_outputs=emb_result
    for i in range(5):
        encoder_outputs=encoder_blocks[i](encoder_outputs,attn_mask)
    print('encoder_outputs:',encoder_outputs.size())
--------------------------------------------------------------------------------
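The addnorm layers above implement the post-norm residual pattern: add the sublayer output to its input, then LayerNorm over the embedding dimension only, so tokens are never mixed at this step. A tiny standalone check of that claim against a freshly initialised LayerNorm (weight=1, bias=0):

```
import torch
from torch import nn

emb_size = 8
addnorm = nn.LayerNorm(emb_size)
x = torch.randn(2, 5, emb_size)   # (batch_size, seq_len, emb_size)
z = torch.randn(2, 5, emb_size)   # stand-in for a sublayer output

out = addnorm(z + x)              # what addnorm1/addnorm2 compute
s = z + x
manual = (s - s.mean(-1, keepdim=True)) / torch.sqrt(s.var(-1, unbiased=False, keepdim=True) + addnorm.eps)
print(torch.allclose(out, manual, atol=1e-5))   # expected: True, normalisation is per token position
```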
/evaluation.py:
--------------------------------------------------------------------------------
import torch
from dataset import de_preprocess,train_dataset,BOS_IDX,EOS_IDX,UNK_IDX,PAD_IDX,en_vocab
from config import DEVICE,SEQ_MAX_LEN

# translate German to English
def translate(transformer,de_sentence):
    # tokenize the German sentence
    de_tokens,de_ids=de_preprocess(de_sentence)
    if len(de_tokens)>SEQ_MAX_LEN:
        raise Exception('sentences longer than {} tokens are not supported'.format(SEQ_MAX_LEN))

    # encoder stage
    enc_x_batch=torch.tensor([de_ids],dtype=torch.long).to(DEVICE) # prepare the encoder input
    encoder_z=transformer.encode(enc_x_batch) # encode

    # decoder stage: greedy decoding, one token at a time
    en_token_ids=[BOS_IDX] # the translation so far
    while len(en_token_ids)<SEQ_MAX_LEN:
        dec_x_batch=torch.tensor([en_token_ids],dtype=torch.long).to(DEVICE) # prepare the decoder input
        decoder_z=transformer.decode(dec_x_batch,encoder_z,enc_x_batch) # decode
        next_token_logits=decoder_z[0,-1,:] # logits of the token that follows the last generated one
        next_token_id=torch.argmax(next_token_logits).item() # pick the most likely token
        en_token_ids.append(next_token_id)
        if next_token_id==EOS_IDX: # stop at the end-of-sentence token
            break

    # build the translated sentence, dropping the special tokens
    en_token_ids=[id for id in en_token_ids if id not in (BOS_IDX,EOS_IDX,UNK_IDX,PAD_IDX)]
    en_tokens=en_vocab.lookup_tokens(en_token_ids)
    return ' '.join(en_tokens)

if __name__=='__main__':
    # load the trained model
    transformer=torch.load('checkpoints/model.pth')
    transformer.eval()

    # translate the first few training sentences
    for i in range(8):
        de,en=train_dataset[i]
        print('{} -> {}'.format(de,translate(transformer,de)))

    '''
    # compare predictions against the reference translations
    for i in range(len(train_dataset)):
        de,en=train_dataset[i]
        en1=translate(transformer,de)
        print('{} -> {} -> {}'.format(de,en,en1))
    '''
--------------------------------------------------------------------------------
/multihead_attn.py:
--------------------------------------------------------------------------------
'''
Takes an embedded token sequence, uses the Q, K, V projections to measure how related the tokens are,
and produces for every token a new embedding that gathers information from the others (1:1 with the input tokens).
'''
from torch import nn
import torch
from dataset import de_vocab,de_preprocess,train_dataset
from emb import EmbeddingWithPosition
import math

class MultiHeadAttention(nn.Module):
    def __init__(self,emb_size,q_k_size,v_size,head):
        super().__init__()
        self.emb_size=emb_size
        self.q_k_size=q_k_size
        self.v_size=v_size
        self.head=head

        self.w_q=nn.Linear(emb_size,head*q_k_size) # multi-head projections
        self.w_k=nn.Linear(emb_size,head*q_k_size)
        self.w_v=nn.Linear(emb_size,head*v_size)

    def forward(self,x_q,x_k_v,attn_mask):
        # x_q: (batch_size,seq_len,emb_size)
        q=self.w_q(x_q) # q: (batch_size,seq_len,head*q_k_size)
        k=self.w_k(x_k_v) # k: (batch_size,seq_len,head*q_k_size)

        # split into heads
        q=q.view(q.size()[0],q.size()[1],self.head,self.q_k_size).transpose(1,2) # q: (batch_size,head,seq_len,q_k_size)
        k=k.view(k.size()[0],k.size()[1],self.head,self.q_k_size).transpose(1,2).transpose(2,3) # k: (batch_size,head,q_k_size,seq_len)

        # attention scores
        attn=torch.matmul(q,k)/math.sqrt(self.q_k_size) # (batch_size,head,seq_len,seq_len), rows are queries, columns are keys

        # mask and normalize the scores
        # attn_mask: (batch_size,seq_len,seq_len)
        attn_mask=attn_mask.unsqueeze(1).expand(-1,self.head,-1,-1) # attn_mask: (batch_size,head,seq_len,seq_len)
        attn=attn.masked_fill(attn_mask,-1e9)
        attn=torch.softmax(attn,dim=-1) # attn: (batch_size,head,seq_len,seq_len)

        # weight V by the attention
        v=self.w_v(x_k_v) # v: (batch_size,seq_len,head*v_size)
        v=v.view(v.size()[0],v.size()[1],self.head,self.v_size).transpose(1,2) # v: (batch_size,head,seq_len,v_size)
        z=torch.matmul(attn,v) # z: (batch_size,head,seq_len,v_size)
        z=z.transpose(1,2) # z: (batch_size,seq_len,head,v_size)
        return z.reshape(z.size()[0],z.size()[1],-1) # z: (batch_size,seq_len,head*v_size)

if __name__=='__main__':
    # prepare a batch of 1
    emb=EmbeddingWithPosition(len(de_vocab),128)
    de_tokens,de_ids=de_preprocess(train_dataset[0][0]) # take a German sentence and convert it to a token-id sequence
    de_ids_tensor=torch.tensor(de_ids,dtype=torch.long)
    emb_result=emb(de_ids_tensor.unsqueeze(0)) # add a batch dimension before feeding the model
    print('emb_result:', emb_result.size())

    # multi-head attention
    multihead=MultiHeadAttention(emb_size=128,q_k_size=256,v_size=512,head=8)
    attn_mask=torch.zeros((1,de_ids_tensor.size()[0],de_ids_tensor.size()[0]),dtype=torch.bool) # one boolean attention mask per sample; all False, so nothing is masked
    multihead_result=multihead(x_q=emb_result,x_k_v=emb_result,attn_mask=attn_mask)
    print('multihead_result:', multihead_result.size())
--------------------------------------------------------------------------------
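Stripped of the head dimension and the linear projections, MultiHeadAttention computes softmax(Q K^T / sqrt(d_k)) V per head. A single-head standalone sketch of exactly that arithmetic, including the -1e9 masking trick used above (the sizes and the mask are made up):

```
import math
import torch

d_k, d_v, seq_len = 4, 6, 3
q = torch.randn(seq_len, d_k)
k = torch.randn(seq_len, d_k)
v = torch.randn(seq_len, d_v)
mask = torch.tensor([[False, False, True],     # pretend the 3rd token is padding:
                     [False, False, True],     # every query ignores that column
                     [False, False, True]])

scores = q @ k.T / math.sqrt(d_k)              # (seq_len, seq_len)
scores = scores.masked_fill(mask, -1e9)        # masked scores vanish after softmax
weights = torch.softmax(scores, dim=-1)        # each row sums to 1
z = weights @ v                                # (seq_len, d_v)
print(weights)
print(z.shape)
```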
/ref/AI是如何学习的.pptx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/owenliang/pytorch-transformer/22eb49272edb9da52887a59e360a4f2ad22330c9/ref/AI是如何学习的.pptx
--------------------------------------------------------------------------------
/ref/aixiaoyutang.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/owenliang/pytorch-transformer/22eb49272edb9da52887a59e360a4f2ad22330c9/ref/aixiaoyutang.jpg
--------------------------------------------------------------------------------
/ref/transformer_1.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/owenliang/pytorch-transformer/22eb49272edb9da52887a59e360a4f2ad22330c9/ref/transformer_1.pdf
--------------------------------------------------------------------------------
/ref/transformer_2.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/owenliang/pytorch-transformer/22eb49272edb9da52887a59e360a4f2ad22330c9/ref/transformer_2.pdf
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
'''
Train the German -> English translation model
'''
from torch import nn
import torch
from dataset import en_preprocess,de_preprocess,train_dataset,en_vocab,de_vocab,PAD_IDX
from transformer import Transformer
from torch.utils.data import DataLoader,Dataset
from config import DEVICE,SEQ_MAX_LEN
from torch.nn.utils.rnn import pad_sequence

# dataset
class De2EnDataset(Dataset):
    def __init__(self):
        super().__init__()

        self.enc_x=[]
        self.dec_x=[]
        for de,en in train_dataset:
            # tokenize
            de_tokens,de_ids=de_preprocess(de)
            en_tokens,en_ids=en_preprocess(en)
            # skip sequences that are too long
            if len(de_ids)>SEQ_MAX_LEN or len(en_ids)>SEQ_MAX_LEN:
                continue
            self.enc_x.append(de_ids)
            self.dec_x.append(en_ids)

    def __len__(self):
        return len(self.enc_x)

    def __getitem__(self, index):
        return self.enc_x[index],self.dec_x[index]

def collate_fn(batch):
    enc_x_batch=[]
    dec_x_batch=[]
    for enc_x,dec_x in batch:
        enc_x_batch.append(torch.tensor(enc_x,dtype=torch.long))
        dec_x_batch.append(torch.tensor(dec_x,dtype=torch.long))

    # pad every sequence to the longest one in the batch
    pad_enc_x=pad_sequence(enc_x_batch,True,PAD_IDX)
    pad_dec_x=pad_sequence(dec_x_batch,True,PAD_IDX)
    return pad_enc_x,pad_dec_x

if __name__=='__main__':
    # German -> English dataset
    dataset=De2EnDataset()
    dataloader=DataLoader(dataset,batch_size=250,shuffle=True,num_workers=4,persistent_workers=True,collate_fn=collate_fn)

    # model: resume from the checkpoint if one exists, otherwise create a new one
    try:
        transformer=torch.load('checkpoints/model.pth')
    except:
        transformer=Transformer(enc_vocab_size=len(de_vocab),dec_vocab_size=len(en_vocab),emb_size=512,q_k_size=64,v_size=64,f_size=2048,head=8,nblocks=6,dropout=0.1,seq_max_len=SEQ_MAX_LEN).to(DEVICE)

    # loss function and optimizer; pad tokens in the reference output do not contribute to the loss
    loss_fn=nn.CrossEntropyLoss(ignore_index=PAD_IDX)
    optimizer=torch.optim.SGD(transformer.parameters(), lr=1e-3, momentum=0.99)

    # start training
    transformer.train()
    EPOCHS=300
    for epoch in range(EPOCHS):
        batch_i=0
        loss_sum=0
        for pad_enc_x,pad_dec_x in dataloader:
            real_dec_z=pad_dec_x[:,1:].to(DEVICE) # the decoder's reference output (targets shifted left by one)
            pad_enc_x=pad_enc_x.to(DEVICE)
            pad_dec_x=pad_dec_x[:,:-1].to(DEVICE) # the decoder's actual input (targets without the last token)
            dec_z=transformer(pad_enc_x,pad_dec_x) # the decoder's actual output

            batch_i+=1
            loss=loss_fn(dec_z.view(-1,dec_z.size()[-1]),real_dec_z.reshape(-1)) # flatten every token in the batch
            loss_sum+=loss.item()
            print('epoch:{} batch:{} loss:{}'.format(epoch,batch_i,loss.item()))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        torch.save(transformer,'checkpoints/model.pth')
--------------------------------------------------------------------------------
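The training loop above uses teacher forcing with shifted targets: the decoder reads pad_dec_x[:, :-1] and is scored against pad_dec_x[:, 1:], while ignore_index=PAD_IDX keeps the padded slots out of the loss. A standalone sketch of just that bookkeeping (PAD_IDX==1 as in dataset.py; the vocabulary size and logits are made up):

```
import torch
from torch import nn

PAD_IDX = 1
pad_dec_x = torch.tensor([[2, 11, 17, 23, 3, 1, 1]])  # <bos> w1 w2 w3 <eos> <pad> <pad>

dec_in  = pad_dec_x[:, :-1]   # what the decoder sees:     <bos> w1 w2 w3 <eos> <pad>
targets = pad_dec_x[:, 1:]    # what it must predict next: w1 w2 w3 <eos> <pad> <pad>

vocab_size = 30
logits = torch.randn(1, dec_in.size(1), vocab_size, requires_grad=True)  # stand-in for the transformer output
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
loss = loss_fn(logits.view(-1, vocab_size), targets.reshape(-1))
loss.backward()               # gradients flow only through non-pad target positions
print(loss.item())
```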
/transformer.py:
--------------------------------------------------------------------------------
'''
The Transformer model, composed of an encoder and a decoder.
'''
from torch import nn
import torch
from decoder import Decoder
from encoder import Encoder
from dataset import en_preprocess,de_preprocess,train_dataset,en_vocab,de_vocab,PAD_IDX
from config import DEVICE

class Transformer(nn.Module):
    def __init__(self,enc_vocab_size,dec_vocab_size,emb_size,q_k_size,v_size,f_size,head,nblocks,dropout=0.1,seq_max_len=5000):
        super().__init__()
        self.encoder=Encoder(enc_vocab_size,emb_size,q_k_size,v_size,f_size,head,nblocks,dropout,seq_max_len)
        self.decoder=Decoder(dec_vocab_size,emb_size,q_k_size,v_size,f_size,head,nblocks,dropout,seq_max_len)

    def forward(self,encoder_x,decoder_x):
        encoder_z=self.encode(encoder_x)
        return self.decode(decoder_x,encoder_z,encoder_x)

    def encode(self,encoder_x):
        encoder_z=self.encoder(encoder_x)
        return encoder_z

    def decode(self,decoder_x,encoder_z,encoder_x):
        decoder_z=self.decoder(decoder_x,encoder_z,encoder_x)
        return decoder_z

if __name__=='__main__':
    # encoder vocabulary is German, decoder vocabulary is English (matching train.py)
    transformer=Transformer(enc_vocab_size=len(de_vocab),dec_vocab_size=len(en_vocab),emb_size=128,q_k_size=256,v_size=512,f_size=512,head=8,nblocks=3,dropout=0.1,seq_max_len=5000).to(DEVICE)

    # take 2 German sentences and convert them to token-id sequences for the encoder
    de_tokens1,de_ids1=de_preprocess(train_dataset[0][0])
    de_tokens2,de_ids2=de_preprocess(train_dataset[1][0])
    # the corresponding 2 English sentences become token-id sequences for the decoder
    en_tokens1,en_ids1=en_preprocess(train_dataset[0][1])
    en_tokens2,en_ids2=en_preprocess(train_dataset[1][1])

    # pad the German sentences to the same length and form a batch
    if len(de_ids1)<len(de_ids2):
        de_ids1.extend([PAD_IDX]*(len(de_ids2)-len(de_ids1)))
    elif len(de_ids1)>len(de_ids2):
        de_ids2.extend([PAD_IDX]*(len(de_ids1)-len(de_ids2)))

    enc_x_batch=torch.tensor([de_ids1,de_ids2],dtype=torch.long).to(DEVICE)
    print('enc_x_batch batch:', enc_x_batch.size())

    # pad the English sentences to the same length and form a batch
    if len(en_ids1)<len(en_ids2):
        en_ids1.extend([PAD_IDX]*(len(en_ids2)-len(en_ids1)))
    elif len(en_ids1)>len(en_ids2):
        en_ids2.extend([PAD_IDX]*(len(en_ids1)-len(en_ids2)))

    dec_x_batch=torch.tensor([en_ids1,en_ids2],dtype=torch.long).to(DEVICE)
    print('dec_x_batch batch:', dec_x_batch.size())

    # next-token probabilities for every English position
    decoder_z=transformer(enc_x_batch,dec_x_batch)
    print('decoder outputs:',decoder_z.size())
--------------------------------------------------------------------------------