├── .gitignore ├── .ipynb_checkpoints └── Untitled-checkpoint.ipynb ├── Beam.py ├── Transformer.py ├── Untitled.ipynb ├── config.py ├── mytest.py ├── mytest1.py ├── mytest_elmo.py ├── readme.md ├── record.md ├── train.py ├── train1.py ├── train_elmo.py ├── transformer ├── Beam.py ├── Constants.py ├── Layers.py ├── Models.py ├── Modules.py ├── Optim.py ├── SubLayers.py ├── Translator.py └── __init__.py ├── translate.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | /data/* 2 | /ckpts/* 3 | /runs/ 4 | /.idea/* 5 | /.vscode/* 6 | /log 7 | /__pycache__/ 8 | /*.tar.gz 9 | /sumdata 10 | /biendata 11 | /build_vocab_for_biendata.py 12 | *.bak 13 | *.pyc 14 | *~ 15 | -------------------------------------------------------------------------------- /.ipynb_checkpoints/Untitled-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [], 3 | "metadata": {}, 4 | "nbformat": 4, 5 | "nbformat_minor": 2 6 | } 7 | -------------------------------------------------------------------------------- /Beam.py: -------------------------------------------------------------------------------- 1 | """Beam search implementation in PyTorch.""" 2 | # 3 | # 4 | # hyp1#-hyp1---hyp1 -hyp1 5 | # \ / 6 | # hyp2 \-hyp2 /-hyp2#hyp2 7 | # / \ 8 | # hyp3#-hyp3---hyp3 -hyp3 9 | # ======================== 10 | # 11 | # Takes care of beams, back pointers, and scores. 12 | 13 | # Code borrowed from PyTorch OpenNMT example 14 | # https://github.com/pytorch/examples/blob/master/OpenNMT/onmt/Beam.py 15 | 16 | import torch 17 | 18 | 19 | class Beam(object): 20 | """Ordered beam of candidate outputs.""" 21 | 22 | def __init__(self, size, vocab, max_trg_len, device=torch.device("cuda:0")): 23 | """Initialize params.""" 24 | self.size = size 25 | self.done = False 26 | # self.pad = vocab[''] 27 | self.bos = vocab[''] 28 | self.eos = vocab[''] 29 | self.pad = vocab[''] 30 | self.device = device 31 | self.tt = torch.cuda if device.type == "cuda" else torch 32 | # The score for each translation on the beam. 33 | self.scores = self.tt.FloatTensor(size, device=self.device).zero_() 34 | 35 | self.length=1 36 | self.lengths = self.tt.FloatTensor(size, device=self.device).fill_(12) 37 | # max_trg_len+1, include eos 38 | self.sequence = self.tt.LongTensor(size, max_trg_len+1, device=self.device).fill_(self.pad) 39 | self.sequence[0, 0] = self.bos 40 | 41 | def get_sequence(self): 42 | return self.sequence 43 | 44 | def advance(self, log_probs): 45 | if self.done: 46 | return True 47 | 48 | """Advance the beam.""" 49 | log_probs = log_probs.squeeze() # k*V 50 | num_words = log_probs.shape[-1] 51 | 52 | # Sum the previous scores. 53 | if self.length > 1: 54 | scores = log_probs + self.scores.unsqueeze(1).expand_as(log_probs) 55 | else: 56 | scores = log_probs[0] 57 | 58 | flat_scores = scores.view(-1) 59 | 60 | bestScores, bestScoresId = flat_scores.topk(self.size, 0, True, True) 61 | self.scores = bestScores 62 | 63 | # bestScoresId is flattened beam x word array, so calculate which 64 | # word and beam each score came from 65 | prev_k = bestScoresId / num_words 66 | word_idx = bestScoresId % num_words 67 | 68 | self.sequence = self.sequence[prev_k] 69 | self.sequence[:, self.length] = word_idx # = bestScoresId - prev_k * num_words 70 | 71 | self.length += 1 72 | 73 | # End condition is when top-of-beam is EOS. 
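# Note: `self.sequence` is a fixed (size, max_trg_len+1) buffer pre-filled with the pad id,
# so the check below only marks the beam as done once that final slot actually holds EOS.
# Also, `prev_k = bestScoresId / num_words` above relies on `/` doing integer division on
# LongTensors, which holds for the pytorch==0.4.0 pinned in the readme; on newer releases it
# would need `bestScoresId // num_words` (or torch.div(bestScoresId, num_words,
# rounding_mode='floor')) so that prev_k stays usable as an index.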
74 | if self.sequence[0,-1] == self.eos: 75 | self.done = True 76 | 77 | def advance_v1(self, log_probs): 78 | if self.done: 79 | return True 80 | 81 | """Advance the beam.""" 82 | log_probs = log_probs.squeeze() # k*V 83 | num_words = log_probs.shape[-1] 84 | 85 | non_eos_mask = torch.ones_like(log_probs).cuda() 86 | non_eos_mask[:, self.eos] = 0 87 | # for i in range(self.size): 88 | # if self.sequence[i][int(self.lengths[i])-1] == self.eos: 89 | # non_eos_mask[i,:] = 0 90 | 91 | # Sum the previous scores. 92 | if self.length > 1: 93 | scores = log_probs + self.scores.unsqueeze(1).expand_as(log_probs) 94 | new_len = self.lengths.unsqueeze(1).expand_as(scores) + non_eos_mask 95 | normalized_scores = scores / new_len 96 | else: 97 | scores = log_probs[0] 98 | normalized_scores = scores 99 | 100 | flat_scores = normalized_scores.view(-1) 101 | 102 | bestScores, bestScoresId = flat_scores.topk(self.size, 0, True, True) 103 | self.scores = scores.view(-1)[bestScoresId] 104 | 105 | # bestScoresId is flattened beam x word array, so calculate which 106 | # word and beam each score came from 107 | prev_k = bestScoresId / num_words 108 | word_idx = bestScoresId % num_words 109 | 110 | self.sequence = self.sequence[prev_k] 111 | self.sequence[:, self.length] = word_idx # = bestScoresId - prev_k * num_words 112 | 113 | self.length += 1 114 | self.lengths[word_idx!=self.eos] += 1 115 | 116 | # End condition is when top-of-beam is EOS. 117 | if self.sequence[0,-1] == self.eos: 118 | self.done = True 119 | 120 | def get_hyp(self): 121 | return self.sequence[0] 122 | 123 | 124 | # import torch 125 | 126 | 127 | # class Beam(object): 128 | # """Ordered beam of candidate outputs.""" 129 | 130 | # def __init__(self, size, vocab, max_trg_len, device=torch.device("cuda:0"), use_ptr_gen=False): 131 | # """Initialize params.""" 132 | # self.size = size 133 | # self.done = False 134 | # self.vocab = vocab 135 | # self.rev_vocab = {v:k for k,v in vocab.items()} 136 | # self.use_ptr_gen = True 137 | # # self.pad = vocab[''] 138 | # self.bos = vocab[''] 139 | # self.eos = vocab[''] 140 | # self.pad = vocab[''] 141 | # self.device = device 142 | # self.tt = torch.cuda if device.type == "cuda" else torch 143 | # # The score for each translation on the beam. 144 | # self.scores = self.tt.FloatTensor(size, device=self.device).zero_() 145 | 146 | # self.length=1 147 | # # fill_(12) lead to better performance on giga, 0.1 for DUC2004 148 | # self.lengths = self.tt.FloatTensor(size, device=self.device).fill_(18) 149 | # # max_trg_len+1, include eos 150 | # self.ys = self.tt.LongTensor(size, max_trg_len+1, device=self.device).fill_(self.pad) 151 | # self.ys[0, 0] = self.bos 152 | # self.ext_ys = self.tt.LongTensor(size, max_trg_len+1, device=self.device).fill_(self.pad) 153 | # self.ext_ys[0,0] = self.bos 154 | 155 | # def get_sequence(self): 156 | # return self.ys 157 | 158 | # def advance(self, log_probs): 159 | # if self.done: 160 | # return True 161 | 162 | # """Advance the beam.""" 163 | # log_probs = log_probs.squeeze() # k*V 164 | # num_words = log_probs.shape[-1] 165 | 166 | # # Sum the previous scores. 
167 | # if self.length > 1: 168 | # scores = log_probs + self.scores.unsqueeze(1).expand_as(log_probs) 169 | # else: 170 | # scores = log_probs[0] 171 | 172 | # flat_scores = scores.view(-1) 173 | 174 | # bestScores, bestScoresId = flat_scores.topk(self.size, 0, True, True) 175 | # self.scores = bestScores 176 | 177 | # # bestScoresId is flattened beam x word array, so calculate which 178 | # # word and beam each score came from 179 | # prev_k = bestScoresId / num_words 180 | # word_idx = bestScoresId % num_words 181 | 182 | # self.ext_ys = self.ext_ys[prev_k] 183 | # self.ext_ys[:, self.length] = word_idx # = bestScoresId - prev_k * num_words 184 | # self.ys = self.ys[prev_k] 185 | # for j in range(len(self.ys)): 186 | # self.ys[j,self.length] = word_idx[j] if int(word_idx[j].cpu().detach()) in self.rev_vocab else vocab[''] 187 | 188 | # self.length += 1 189 | 190 | # # End condition is when top-of-beam is EOS. 191 | # if self.ys[0,-1] == self.eos: 192 | # self.done = True 193 | 194 | # def advance_v1(self, log_probs): 195 | # if self.done: 196 | # return True 197 | 198 | # """Advance the beam.""" 199 | # log_probs = log_probs.squeeze() # k*V 200 | # num_words = log_probs.shape[-1] 201 | 202 | # non_eos_mask = torch.ones_like(log_probs).cuda() 203 | # non_eos_mask[:, self.eos] = 0 204 | # # for i in range(self.size): 205 | # # if self.ys[i][int(self.lengths[i])-1] == self.eos: 206 | # # non_eos_mask[i,:] = 0 207 | 208 | # # Sum the previous scores. 209 | # if self.length > 1: 210 | # scores = log_probs + self.scores.unsqueeze(1).expand_as(log_probs) 211 | # new_len = self.lengths.unsqueeze(1).expand_as(scores) + non_eos_mask 212 | # normalized_scores = scores / new_len 213 | # else: 214 | # scores = log_probs[0] 215 | # normalized_scores = scores 216 | 217 | # flat_scores = normalized_scores.view(-1) 218 | 219 | # bestScores, bestScoresId = flat_scores.topk(self.size, 0, True, True) 220 | # self.scores = scores.view(-1)[bestScoresId] 221 | 222 | # # bestScoresId is flattened beam x word array, so calculate which 223 | # # word and beam each score came from 224 | # prev_k = bestScoresId / num_words 225 | # word_idx = bestScoresId % num_words 226 | 227 | # self.ext_ys = self.ext_ys[prev_k] 228 | # self.ext_ys[:, self.length] = word_idx # = bestScoresId - prev_k * num_words 229 | 230 | # self.ys = self.ys[prev_k] 231 | # for j in range(len(self.ys)): 232 | # self.ys[j, self.length] = word_idx[j] if int(word_idx[j].cpu().detach()) in self.rev_vocab else vocab[''] 233 | # # self.ys[:, self.length] = word_idx # = bestScoresId - prev_k * num_words 234 | 235 | # self.length += 1 236 | # self.lengths = self.lengths[prev_k] 237 | # self.lengths[word_idx!=self.eos] += 1 238 | 239 | # # End condition is when top-of-beam is EOS. 
240 | # if self.ys[0,-1] == self.eos: 241 | # self.done = True 242 | 243 | # def get_hyp(self): 244 | # return self.ext_ys[0] if self.use_ptr_gen else self.ys[0] 245 | -------------------------------------------------------------------------------- /Transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | import config 5 | # use elmo embeddings, comment the following line if you don't need it 6 | # from allennlp.modules.elmo import Elmo, batch_to_ids 7 | 8 | pad_index = config.pad_index 9 | 10 | 11 | def get_attn_mask(seq_q, seq_k): 12 | """ 13 | :param seq_q: [batch, l_q] 14 | :param seq_k: [batch, l_k] 15 | """ 16 | l_q = seq_q.size(-1) 17 | mask = seq_k.eq(pad_index).unsqueeze(1).expand(-1, l_q, -1) 18 | return mask 19 | 20 | 21 | def get_subsequent_mask(seq_q): 22 | bs, l_q = seq_q.size() 23 | subsequent_mask = torch.triu( 24 | torch.ones((l_q, l_q), device=seq_q.device, dtype=torch.uint8), diagonal=1 25 | ) 26 | subsequent_mask = subsequent_mask.unsqueeze(0).expand(bs, -1, -1) 27 | return subsequent_mask 28 | 29 | 30 | def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): 31 | ''' Sinusoid position encoding table ''' 32 | 33 | exponents = np.array([2 * (i // 2) / d_hid for i in range(d_hid)]) 34 | pow_table = np.power(10000, exponents) 35 | sinusoid_table = np.array([pos / pow_table for pos in range(n_position)]) 36 | 37 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 38 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 39 | 40 | if padding_idx is not None: 41 | # zero vector for padding dimension 42 | sinusoid_table[padding_idx] = 0. 43 | 44 | return torch.FloatTensor(sinusoid_table) 45 | 46 | 47 | class Embedding(nn.Module): 48 | def __init__(self, n_vocab, d_model, max_seq_len, embeddings=None): 49 | super(Embedding, self).__init__() 50 | self.word_embedding = nn.Embedding(n_vocab, d_model) 51 | if embeddings is not None: 52 | self.word_embedding = nn.Embedding.from_pretrained(embeddings, freeze=True) 53 | self.pos_embedding = nn.Embedding.from_pretrained( 54 | get_sinusoid_encoding_table(max_seq_len, d_model)) 55 | 56 | def forward(self, input_ids): 57 | pos_ids = torch.arange(input_ids.shape[-1], dtype=torch.long, device=input_ids.device) 58 | pos_ids.unsqueeze(0).expand_as(input_ids) 59 | embeddings = self.word_embedding(input_ids) + self.pos_embedding(pos_ids) 60 | return embeddings 61 | 62 | 63 | class ScaledDotAttention(nn.Module): 64 | def __init__(self, dropout=0.1): 65 | super(ScaledDotAttention, self).__init__() 66 | self.softmax = nn.Softmax(dim=-1) 67 | self.dropout = nn.Dropout(p=dropout) 68 | 69 | def forward(self, q, k, v, mask): 70 | """ 71 | perform scaled-dot-attention, given query,key,value 72 | :params q: query, size=[batch, q_len, d_q] 73 | :params k: key, size=[batch, k_len, d_k], d_q == d_k 74 | :params v: value, size=[batch, v_len, d_v], v_len==k_len 75 | :params mask: size=[batch, q_len, k_len] 76 | :return attn_vec: 77 | :return attn_weight: 78 | """ 79 | scale = 1 / np.sqrt(q.size(-1)) 80 | scores = torch.bmm(q, k.transpose(-1, -2)) * scale 81 | if mask is not None: 82 | scores = scores.masked_fill(mask, -1e9) 83 | attn_weight = self.softmax(scores) 84 | attn_weight = self.dropout(attn_weight) # dropout 85 | context = torch.bmm(attn_weight, v) 86 | return context, attn_weight 87 | 88 | 89 | class MultiheadAttention(nn.Module): 90 | def __init__(self, n_head, d_model, dropout=0.1): 91 | 
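# d_model is split evenly across the heads (d_k = d_v = d_model // n_head), so d_model must be
# divisible by n_head; W_Q/W_K/W_V project queries, keys and values for all heads in a single
# matmul each, and W_O maps the concatenated head outputs back to d_model before the
# residual connection + LayerNorm applied in forward().
# Minimal usage sketch (batch size, lengths and dtypes here are illustrative assumptions,
# not taken from the repo):
#   mha = MultiheadAttention(n_head=4, d_model=256)
#   q = k = v = torch.rand(2, 10, 256)
#   mask = torch.zeros(2, 10, 10, dtype=torch.uint8)   # nonzero entries are suppressed
#   out, ctx, attn = mha(q, k, v, mask)                # out/ctx: [2, 10, 256], attn: [2, 10, 10]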
super(MultiheadAttention, self).__init__() 92 | self.n_head = n_head 93 | d_k = d_model // n_head 94 | d_v = d_model // n_head 95 | self.W_Q = nn.Linear(d_model, d_k * n_head) 96 | self.W_K = nn.Linear(d_model, d_k * n_head) 97 | self.W_V = nn.Linear(d_model, d_v * n_head) 98 | self.W_O = nn.Linear(d_v * n_head, d_model) 99 | 100 | self.attn_layer = ScaledDotAttention() 101 | self.dropout = nn.Dropout(p=dropout) 102 | self.layer_norm = nn.LayerNorm(d_model) 103 | 104 | def forward(self, Q, K, V, mask): 105 | """ 106 | 107 | :param Q: size=[bsz, l_q, d] 108 | :param K: size=[bsz, l_k, d] 109 | :param V: size=[bsz, l_k, d] 110 | :param mask: size=[bsz, l_q, l_k] 111 | :return: 112 | """ 113 | (bsz, l_q, d), l_k = Q.shape, K.shape[1] 114 | n = self.n_head 115 | 116 | residual = Q 117 | Qs = self.W_Q(Q).view(bsz, l_q, n, -1).transpose(1,2).contiguous().view(bsz * n, l_q, -1) 118 | Ks = self.W_K(K).view(bsz, l_k, n, -1).transpose(1,2).contiguous().view(bsz * n, l_k, -1) 119 | Vs = self.W_V(V).view(bsz, l_k, n, -1).transpose(1,2).contiguous().view(bsz * n, l_k, -1) 120 | 121 | mask = torch.cat([mask for _ in range(n)], dim=0) 122 | # [bsz * n, l_q, d // n], [bsz * n, l_q, l_k] 123 | context, attns = self.attn_layer(Qs, Ks, Vs, mask) 124 | 125 | context = context.view(bsz, n, l_q, -1).transpose(1, 2).contiguous().view(bsz, l_q, d) 126 | attns = attns.view(bsz, n, l_q, l_k).transpose(0, 1) # [n, bsz, l_q, l_k] 127 | 128 | # attn_weight = torch.cat(attn_weight, dim=-1) 129 | output = self.W_O(context) 130 | # TODO try output = self.layer_norm(self.dropout(output+residual)) 131 | # instead of self.layer_norm(self.dropout(output) + residual) 132 | output = self.dropout(output) 133 | output = self.layer_norm(output + residual) 134 | return output, context, attns.sum(dim=0) / n 135 | # return output, context, attn_weight 136 | 137 | 138 | class PositionWiseFeedForwardNet(nn.Module): 139 | def __init__(self, d_model, d_inner, dropout=0.1): 140 | super(PositionWiseFeedForwardNet, self).__init__() 141 | self.w1 = nn.Linear(d_model, d_inner, bias=True) 142 | self.w2 = nn.Linear(d_inner, d_model, bias=True) 143 | self.relu = nn.ReLU() 144 | self.layer_norm = nn.LayerNorm(d_model) 145 | self.dropout = nn.Dropout(p=dropout) 146 | 147 | def forward(self, inputs): 148 | """ 149 | :param inputs: [batch, src_len, d_model] 150 | """ 151 | residual = inputs 152 | output = self.relu(self.w1(inputs)) 153 | output = self.w2(output) 154 | output = self.dropout(output) 155 | output = self.layer_norm(output + residual) 156 | return output 157 | 158 | 159 | # class PositionWiseFeedForwardNet(nn.Module): 160 | # def __init__(self, d_model, d_inner, dropout=0.1): 161 | # super(PositionWiseFeedForwardNet, self).__init__() 162 | # self.conv1 = nn.Conv1d(d_model, d_inner, 1, bias=True) 163 | # self.conv2 = nn.Conv1d(d_inner, d_model, 1, bias=True) 164 | # self.relu = nn.ReLU() 165 | # self.layer_norm = nn.LayerNorm(d_model) 166 | # self.dropout = nn.Dropout(p=dropout) 167 | 168 | # def forward(self, inputs): 169 | # """ 170 | # :param inputs: [batch, src_len, d_model] 171 | # """ 172 | # residual = inputs 173 | # # why transpose? 
174 | # # w1 = [C_in, C_out, kernel_size] 175 | # # which requires input.shape=[N, C_in, L], C_in == in_channels 176 | # # output.shape=[N, C_out, L], N is batch_size 177 | # output = self.relu(self.conv1(inputs.transpose(-1, -2))) 178 | # output = self.conv2(output).transpose(-1, -2) 179 | # output = self.dropout(output) 180 | # output = self.layer_norm(output + residual) 181 | # return output 182 | 183 | 184 | class EncoderLayer(nn.Module): 185 | def __init__(self, n_head, d_model, d_inner, dropout=0.1): 186 | super(EncoderLayer, self).__init__() 187 | self.multi_attn = MultiheadAttention(n_head, d_model, dropout) 188 | self.poswise_ffn = PositionWiseFeedForwardNet(d_model, d_inner, dropout) 189 | 190 | def forward(self, enc_inputs, mask=None): 191 | output, _, attn_weights = self.multi_attn(enc_inputs, enc_inputs, enc_inputs, mask) 192 | output = self.poswise_ffn(output) 193 | return output, attn_weights 194 | 195 | 196 | class Encoder(nn.Module): 197 | def __init__(self, n_vocab, max_seq_len, n_layer, n_head, 198 | d_model, d_inner, dropout=0.1, embeddings=None): 199 | super(Encoder, self).__init__() 200 | self.embedding = Embedding(n_vocab, d_model, max_seq_len, embeddings) 201 | self.layers = nn.ModuleList( 202 | [EncoderLayer(n_head, d_model, d_inner, dropout) for _ in range(n_layer)] 203 | ) 204 | 205 | def forward(self, src_seq): 206 | 207 | mask = get_attn_mask(src_seq, src_seq) 208 | 209 | enc_output_list = [] 210 | attn_weight_list = [] 211 | 212 | enc_output = self.embedding(src_seq) 213 | for layer in self.layers: 214 | enc_output, attn_weight = layer(enc_output, mask) 215 | enc_output_list.append(enc_output) 216 | attn_weight_list.append(attn_weight) 217 | 218 | return enc_output, attn_weight_list 219 | 220 | 221 | class DecoderLayer(nn.Module): 222 | def __init__(self, n_head, d_model, d_inner, dropout=0.1): 223 | super(DecoderLayer, self).__init__() 224 | self.multi_attn_masked = MultiheadAttention(n_head, d_model, dropout) 225 | self.multi_attn = MultiheadAttention(n_head, d_model, dropout) 226 | self.poswise_ffn = PositionWiseFeedForwardNet(d_model, d_inner, dropout) 227 | 228 | def forward(self, enc_outputs, dec_inputs, self_mask=None, dec_enc_mask=None): 229 | dec_outputs, _, self_attns = self.multi_attn_masked( 230 | dec_inputs, dec_inputs, dec_inputs, self_mask 231 | ) 232 | 233 | dec_outputs, _, dec_enc_attns = self.multi_attn( 234 | dec_outputs, enc_outputs, enc_outputs, dec_enc_mask 235 | ) 236 | 237 | output = self.poswise_ffn(dec_outputs) 238 | return output, self_attns, dec_enc_attns 239 | 240 | 241 | class Decoder(nn.Module): 242 | def __init__(self, n_vocab, max_seq_len, n_layer, n_head, 243 | d_model, d_inner, dropout=0.1, embeddings=None): 244 | super(Decoder, self).__init__() 245 | self.embedding = Embedding(n_vocab, d_model, max_seq_len, embeddings) 246 | self.layers = nn.ModuleList( 247 | [DecoderLayer(n_head, d_model, d_inner, dropout) for _ in range(n_layer)] 248 | ) 249 | 250 | def forward(self, enc_outputs, src_seq, tgt_seq): 251 | """ """ 252 | dec_slf_mask = get_attn_mask(tgt_seq, tgt_seq) 253 | dec_subseq_mask = get_subsequent_mask(tgt_seq) 254 | dec_slf_mask = (dec_slf_mask + dec_subseq_mask).gt(0) 255 | 256 | dec_enc_mask = get_attn_mask(tgt_seq, src_seq) 257 | 258 | dec_slf_attn_list = [] 259 | dec_enc_attn_list = [] 260 | 261 | dec_outputs = self.embedding(tgt_seq) # init 262 | 263 | for layer in self.layers: 264 | dec_outputs, dec_slf_attn, dec_enc_attn = \ 265 | layer(enc_outputs, dec_outputs, dec_slf_mask, dec_enc_mask) 266 | 267 | 
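# Each DecoderLayer returns the updated decoder states plus its masked self-attention and
# decoder-encoder attention weights; the two lists below collect them per layer, and
# Transformer.decode() later consumes only the last layer's decoder-encoder attention.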
dec_slf_attn_list.append(dec_slf_attn) 268 | dec_enc_attn_list.append(dec_enc_attn) 269 | 270 | return dec_outputs, dec_slf_attn_list, dec_enc_attn_list 271 | 272 | 273 | class Transformer(nn.Module): 274 | def __init__(self, n_src_vocab, n_tgt_vocab, max_src_len, max_tgt_len, n_layer, 275 | n_head, d_model, d_inner, dropout=0.1, embeddings=None, 276 | src_tgt_emb_share=True, tgt_prj_wt_share=True): 277 | super(Transformer, self).__init__() 278 | 279 | self.encoder = Encoder(n_src_vocab, max_src_len, n_layer, n_head, 280 | d_model, d_inner, dropout, embeddings) 281 | self.decoder = Decoder(n_tgt_vocab, max_tgt_len, n_layer, n_head, 282 | d_model, d_inner, dropout, embeddings) 283 | 284 | self.tgt_word_proj = nn.Linear(d_model, n_tgt_vocab, bias=False) 285 | 286 | # It seems weight sharing leads to GPU out of memory 287 | if src_tgt_emb_share: 288 | assert n_src_vocab == n_tgt_vocab, \ 289 | "To share word embedding table, the vocabulary size of src/tgt shall be the same." 290 | self.encoder.embedding.word_embedding.weight \ 291 | = self.decoder.embedding.word_embedding.weight 292 | 293 | if tgt_prj_wt_share: 294 | self.tgt_word_proj.weight = self.decoder.embedding.word_embedding.weight 295 | self.logit_scale = (d_model ** -0.5) 296 | # self.logit_scale = 1. 297 | else: 298 | self.logit_scale = 1. 299 | 300 | self.loss_layer = nn.CrossEntropyLoss(ignore_index=pad_index) 301 | 302 | def forward(self, src_seq, tgt_seq): 303 | enc_outputs, *_ = self.encoder(src_seq) 304 | 305 | tgt_seq = tgt_seq[:, :-1] 306 | logits, dec_enc_attn = self.decode(enc_outputs, src_seq, tgt_seq) 307 | return logits * self.logit_scale, dec_enc_attn 308 | 309 | def encode(self, src_seq): 310 | enc_outputs, *_ = self.encoder(src_seq) 311 | return enc_outputs 312 | 313 | def decode(self, enc_outputs, src_seq, tgt_seq): 314 | dec_outputs, _, dec_enc_attn_list = self.decoder(enc_outputs, src_seq, tgt_seq) 315 | logits = self.tgt_word_proj(dec_outputs) 316 | return logits * self.logit_scale, dec_enc_attn_list[-1] 317 | 318 | 319 | class EncoderShareEmbedding(nn.Module): 320 | def __init__(self, n_layer, n_head, d_model, d_inner, dropout=0.1): 321 | super(EncoderShareEmbedding, self).__init__() 322 | self.layers = nn.ModuleList( 323 | [EncoderLayer(n_head, d_model, d_inner, dropout) for _ in range(n_layer)] 324 | ) 325 | 326 | def forward(self, enc_inputs, src_seq): 327 | mask = get_attn_mask(src_seq, src_seq) 328 | 329 | enc_output_list = [] 330 | attn_weight_list = [] 331 | 332 | enc_output = enc_inputs 333 | for layer in self.layers: 334 | enc_output, attn_weight = layer(enc_output, mask) 335 | enc_output_list.append(enc_output) 336 | attn_weight_list.append(attn_weight) 337 | 338 | return enc_output_list, attn_weight_list 339 | 340 | 341 | class DecoderShareEmbedding(nn.Module): 342 | def __init__(self, n_layer, n_head, d_model, d_inner, dropout=0.1): 343 | super(DecoderShareEmbedding, self).__init__() 344 | self.layers = nn.ModuleList( 345 | [DecoderLayer(n_head, d_model, d_inner, dropout) for _ in range(n_layer)] 346 | ) 347 | 348 | def forward(self, enc_output_list, dec_inputs, src_seq, tgt_seq): 349 | """ 350 | Noticement: enc_outputs is a list 351 | """ 352 | dec_slf_mask = get_attn_mask(tgt_seq, tgt_seq) 353 | dec_subseq_mask = get_subsequent_mask(tgt_seq) 354 | dec_slf_mask = (dec_slf_mask + dec_subseq_mask).gt(0) 355 | 356 | dec_enc_mask = get_attn_mask(tgt_seq, src_seq) 357 | 358 | dec_slf_attn_list = [] 359 | dec_enc_attn_list = [] 360 | 361 | dec_outputs = dec_inputs 362 | enc_outputs = enc_output_list[-1] # 
use only the top layer 363 | for layer in self.layers: 364 | dec_outputs, dec_slf_attn, dec_enc_attn = \ 365 | layer(enc_outputs, dec_outputs, dec_slf_mask, dec_enc_mask) 366 | 367 | dec_slf_attn_list.append(dec_slf_attn) 368 | dec_enc_attn_list.append(dec_enc_attn) 369 | 370 | return dec_outputs, dec_slf_attn_list, dec_enc_attn_list 371 | 372 | 373 | class TransformerShareEmbedding(nn.Module): 374 | def __init__(self, n_vocab, max_seq_len, n_layer, n_head, 375 | d_model, d_inner, tgt_prj_wt_share=False, embeddings=None): 376 | super(TransformerShareEmbedding, self).__init__() 377 | 378 | self.embedding = Embedding(n_vocab, d_model, max_seq_len, embeddings=embeddings) 379 | 380 | self.encoder = EncoderShareEmbedding(n_layer, n_head, d_model, d_inner) 381 | self.decoder = DecoderShareEmbedding(n_layer, n_head, d_model, d_inner) 382 | 383 | if tgt_prj_wt_share: 384 | self.tgt_word_proj = nn.Linear(d_model, n_vocab) 385 | self.tgt_word_proj.weight = self.embedding.word_embedding.weight 386 | self.logit_scale = (d_model ** -0.5) 387 | else: 388 | self.tgt_word_proj = nn.Linear(d_model, n_vocab) 389 | self.logit_scale = 1. 390 | 391 | self.loss_layer = nn.CrossEntropyLoss(ignore_index=pad_index) 392 | 393 | def forward(self, src_seq, tgt_seq): 394 | tgt_seq = tgt_seq[:, :-1] 395 | 396 | src_embeds = self.embedding(src_seq) 397 | tgt_embeds = self.embedding(tgt_seq) 398 | enc_output_list, *_ = self.encoder(src_embeds, src_seq) 399 | dec_outputs, *_ = self.decoder(enc_output_list, tgt_embeds, src_seq, tgt_seq) 400 | 401 | logits = self.tgt_word_proj(dec_outputs) * self.logit_scale 402 | return logits 403 | 404 | 405 | class ElmoEmbedder(nn.Module): 406 | def __init__(self, requires_grad=False, dropout=0.5): 407 | super(ElmoEmbedder, self).__init__() 408 | # TODO, 409 | # set num_output_representations=3 may result to better performance 410 | self.elmo = Elmo(config.options_file, 411 | config.weight_file, 412 | num_output_representations=2, 413 | requires_grad=requires_grad, 414 | dropout=dropout) 415 | 416 | @ staticmethod 417 | def batch_to_ids(stncs, tgt_flag=False): 418 | """ 419 | convert list of text into ids that elmo accepts 420 | :param stncs: [['I', 'Like', 'you'],['Yes'] ] 421 | :param tgt_flag: indicates if the inputs is a target sentences, if it is, 422 | use only the previous words as context, and neglect last word 423 | :return ids: indices to feed into elmo 424 | """ 425 | ids = batch_to_ids(stncs) # (batch, seqlen, 50) 426 | if tgt_flag: 427 | ids = ids[:,:-1,:] # neglect the last word 428 | b_size, _len, dim = ids.shape 429 | expand_ids = torch.zeros(b_size * _len, _len, dim, dtype=torch.long) 430 | for i in range(1, _len + 1): 431 | expand_ids[b_size*(i-1):b_size*i, :i, :] = ids[:, :i, :] 432 | return expand_ids 433 | return ids 434 | 435 | def forward(self, stncs, tgt_flag=False, idx=-1): 436 | """ 437 | produce elmo embedding of a batch of variable length text 438 | :param stnc: list of sentences, a sentence is a list of words 439 | :param tgt_flag: indicates if the inputs is a target sentences, if it is, 440 | use only the previous words as context 441 | :param idx: the idx-th layer representation 442 | :return embeddings: list of tensor in shape [batch, max_len, d_model] 443 | :return mask: 0-1 tensor in shape [batch, max_len], 0 for padding 444 | """ 445 | elmo_ids = self.batch_to_ids(stncs, tgt_flag).cuda() 446 | elmo_output = self.elmo(elmo_ids) 447 | embeddings = elmo_output['elmo_representations'][idx] 448 | mask = elmo_output['mask'] 449 | if tgt_flag: 450 | b_size = 
len(stncs) 451 | embeddings = [embeddings[i][i//b_size].unsqueeze(0) 452 | for i in range(embeddings.shape[0])] 453 | embeddings = torch.cat(embeddings, dim=0).view(elmo_ids.shape[1], b_size, -1) 454 | embeddings = embeddings.transpose(0, 1) 455 | mask = mask[-b_size:] 456 | return embeddings, mask 457 | 458 | 459 | class ElmoTransformer(nn.Module): 460 | def __init__(self, max_len, n_vocab, n_layer, n_head, 461 | d_model, d_inner, dropout=0.1, elmo_requires_grad=False): 462 | super(ElmoTransformer, self).__init__() 463 | 464 | self.n_vocab = n_vocab 465 | self.d_model = d_model # 256 is the size of small elmo 466 | self.elmo_embedder = ElmoEmbedder(requires_grad=elmo_requires_grad, 467 | dropout=dropout) # False means freeze 468 | self.word_embedder = nn.Embedding(n_vocab, d_model) 469 | self.position_embedder = nn.Embedding.from_pretrained( 470 | get_sinusoid_encoding_table(max_len, self.d_emb), freeze=True) 471 | 472 | self.encoder = EncoderShareEmbedding(n_layer, n_head, d_model, d_inner, dropout) 473 | self.decoder = DecoderShareEmbedding(n_layer, n_head, d_model, d_inner, dropout) 474 | 475 | self.dropout = nn.Dropout(dropout) 476 | self.tgt_word_prj = nn.Linear(self.d_model, n_vocab) 477 | 478 | self.loss_layer = nn.CrossEntropyLoss(ignore_index=config.pad_index) 479 | 480 | @ staticmethod 481 | def get_position_ids(stncs, tgt_flag=False): 482 | max_len = max(len(s) for s in stncs) 483 | if tgt_flag: 484 | max_len -= 1 # exclude token 485 | pos_ids = torch.arange(max_len, dtype=torch.long) 486 | pos_ids = pos_ids.unsqueeze(0).expand(len(stncs), -1) 487 | return pos_ids.cuda() 488 | 489 | def embed(self, stncs, ids, tgt_flag=False): 490 | elmo_emb, mask = self.elmo_embedder(stncs, tgt_flag, idx=-1) 491 | pos_emb = self.position_embedder(self.get_position_ids(stncs, tgt_flag)) 492 | if tgt_flag: 493 | word_emb = self.word_embedder(ids[:,:-1]) 494 | else: 495 | word_emb = self.word_embedder(ids) 496 | embedding = torch.cat([elmo_emb, self.dropout(word_emb + pos_emb)], dim=-1) 497 | return embedding, mask 498 | 499 | def forward(self, src_stncs, tgt_stncs, src_ids, tgt_ids): 500 | src_embeds, src_mask = self.embed(src_stncs, src_ids, tgt_flag=False) 501 | tgt_embeds, tgt_mask = self.embed(tgt_stncs, tgt_ids, tgt_flag=True) 502 | 503 | enc_output_list, *_ = self.encoder(src_embeds, src_ids) 504 | dec_output, *_ = self.decoder(enc_output_list, tgt_embeds, src_mask, tgt_mask) 505 | logits = self.tgt_word_prj(dec_output) 506 | return logits 507 | 508 | -------------------------------------------------------------------------------- /Untitled.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from Transformer import Transformer\n", 10 | "from utils import BatchManager, load_data\n", 11 | "import translate\n", 12 | "from translate import greedy\n", 13 | "from imp import reload\n", 14 | "import torch\n", 15 | "import os\n", 16 | "import numpy as np\n", 17 | "import matplotlib.pyplot as plt\n", 18 | "import json\n", 19 | "from translate import greedy\n", 20 | "\n", 21 | "os.environ['CUDA_VISIBLE_DEVICES'] = '3'" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 12, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "data_dir = '/home/disk3/tiankeke/sumdata/'\n", 31 | "TEST_X = os.path.join(data_dir, 'Giga/input.txt')\n", 32 | "\n", 33 | "vocab_file = 'sumdata/small_vocab.json'\n", 34 | 
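"# small_vocab.json maps token -> id; BatchManager below batches the test sentences using this vocab\n",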
"vocab = json.load(open(vocab_file))\n", 35 | "\n", 36 | "max_src_len = 100\n", 37 | "max_tgt_len = 40\n", 38 | "bsz = 10\n", 39 | "\n", 40 | "test_x = BatchManager(load_data(TEST_X, max_src_len), bsz, vocab)\n" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 3, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "model = Transformer(len(vocab), len(vocab), 200, 200, 2, 4, 256,\n", 50 | " 1024, src_tgt_emb_share=False, tgt_prj_wt_share=True).cuda().eval()\n", 51 | "states = torch.load('models/params_v2_9.pkl')\n", 52 | "model.load_state_dict(states['state_dict'])" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 4, 58 | "metadata": { 59 | "collapsed": true 60 | }, 61 | "outputs": [], 62 | "source": [ 63 | "def calc_entropy(attns):\n", 64 | " \"\"\"\n", 65 | " Calculate the entropy of attention distribution over src text\n", 66 | " :param attns: [bsz, len_tgt, len_src]\n", 67 | " \"\"\"\n", 68 | " ent_word = (- np.log(attns) * attns).sum(axis=-1)\n", 69 | " ent_stnc = ent_word.sum(axis=-1) / (attns.shape[1])\n", 70 | " ent_batch =ent_stnc.sum() / attns.shape[0]\n", 71 | " return ent_word, ent_stnc, ent_batch\n", 72 | "\n", 73 | "def plot_bar(dist):\n", 74 | " plt.bar(range(len(dist)), height=dist)\n", 75 | " plt.show()" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 17, 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "_, x = test_x.next_batch()\n", 85 | "logits, attns = greedy(model, x, vocab)\n", 86 | "attns = attns.cpu().detach().numpy()\n", 87 | "attns += 1e-9" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": 14, 93 | "metadata": {}, 94 | "outputs": [ 95 | { 96 | "data": { 97 | "text/plain": [ 98 | "(10, 14, 38)" 99 | ] 100 | }, 101 | "execution_count": 14, 102 | "metadata": {}, 103 | "output_type": "execute_result" 104 | } 105 | ], 106 | "source": [ 107 | "attns.shape" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": 18, 113 | "metadata": {}, 114 | "outputs": [], 115 | "source": [ 116 | "ent_w, ent_s, ent_b = calc_entropy(attns)" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": 19, 122 | "metadata": {}, 123 | "outputs": [ 124 | { 125 | "name": "stdout", 126 | "output_type": "stream", 127 | "text": [ 128 | "3.02136974335\n" 129 | ] 130 | } 131 | ], 132 | "source": [ 133 | "print(ent_b)" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": 22, 139 | "metadata": {}, 140 | "outputs": [ 141 | { 142 | "data": { 143 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAX0AAAD8CAYAAACb4nSYAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAEWFJREFUeJzt3X+MZWV9x/H3p7uwWjSosDaWhc5S6I+1GkvX1URrjVRc\npGVtCuli09KEZttGmjbW6JKmqFQTMG3pH6U/aEEJ1ALFmm5kW2rEtokxuIsisOLWAbcyYmQtiKUG\ncfHbP+6hXm9nmDOzd+fenef9SiZzznOec+d7n5353Oc+996zqSokSW34vkkXIElaOYa+JDXE0Jek\nhhj6ktQQQ1+SGmLoS1JDDH1JaoihL0kNMfQlqSFrJ13AqBNPPLFmZmYmXYYkHVXuvPPOr1XV+sX6\nTV3oz8zMsHfv3kmXIUlHlST/2aefyzuS1BBDX5IaYuhLUkMMfUlqiKEvSQ0x9CWpIYa+JDXE0Jek\nhhj6ktSQqftErqTxmdl567ztBy4/Z4Ur0bRwpi9JDTH0JakhvUI/ydYk+5PMJtk5z/HXJPl0kkNJ\nzhs5dmGSL3RfF46rcEnS0i0a+knWAFcBZwObgAuSbBrp9iXg14APjpz7AuCdwCuALcA7kzz/8MuW\nJC1Hn5n+FmC2qh6oqieBG4Ftwx2q6kBV3Q18Z+TcNwAfrapHqupR4KPA1jHULUlahj6hfxLw4ND+\nXNfWx+GcK0kasz6hn3naquft9zo3yY4ke5PsPXjwYM+bliQtVZ/QnwNOHtrfADzU8/Z7nVtVV1fV\n5qravH79ov/blyRpmfqE/h7g9CQbkxwLbAd29bz924Czkjy/ewH3rK5NkjQBi4Z+VR0CLmYQ1vcB\nN1fVviSXJTkXIMnLk8wB5wN/lWRfd+4jwB8yeODYA1zWtUmSJqDXZRiqajewe6Tt0qHtPQyWbuY7\n91rg2sOoUZI0Jn4iV5IaYuhLUkMMfUlqiKEvSQ0x9CWpIYa+JDXE0Jekhhj6ktQQQ1+SGmLoS1JD\nDH1JaoihL0kNMfQlqSGGviQ1xNCXpIYY+pLUkF7/iUorZnbeuuCxA5efs4KVSNKR4Uxfkhpi6EtS\nQwx9SWqIoS9JDTH0Jakhhr4kNcTQl6SGGPqS1BBDX5IaYuhLUkMMfUlqiKEvSQ0x9CWpIYa+JDWk\nV+gn2Zpkf5LZJDvnOb4uyU3d8TuSzHTtxyS5Lsk9Se5Lcsl4y5ckLcWioZ9kDXAVcDawCbggyaaR\nbhcBj1bVacCVwBVd+/nAuqp6CfBTwG88/YAgSVp5fWb6W4DZqnqgqp4EbgS2jfTZBlzXbd8CnJkk\nQAHHJVkLPBt4EvjGWCqXJC1Zn9A/CXhwaH+ua5u3T1UdAh4DTmDwAPA/wFeALwF/VFWPHGbNkqRl\n6hP6maetevbZAjwF/CCwEfi9JKf+vx+Q7EiyN8negwcP9ihJkrQcfUJ/Djh5aH8D8NBCfbqlnOOB\nR4A3A/9cVd+uqoeBTwCbR39AVV1dVZuravP69euXfi8kSb30Cf09wOlJNiY5FtgO7Brpswu4sNs+\nD7i9qorBks7rMnAc8Erg8+MpXZK0VIuGfrdGfzFwG3AfcHNV7UtyWZJzu27XACckmQXeCjz9ts6r\ngOcA9zJ48Hh/Vd095vsgSeppbZ9OVbUb2D3SdunQ9hMM3p45et7j87VLkibDT+RKUkMMfUlqiKEv\nSQ0x9CWpIYa+JDXE0JekhvR6y6Z0JM3svHXBYwcuP2cFK5FWP2f6ktQQQ1+SGmLoS1JDDH1Jaoih\nL0kNMfQlqSGGviQ1xNCXpIYY+pLUEENfkhpi6EtSQwx9SWqIoS9JDTH0Jakhhr4kNcTQl6SGGPqS\n1BBDX5IaYuhLUkMMfUlqiKEvSQ0x9CWpIYa+JDXE0JekhvQK/SRbk+xPMptk5zzH1yW5qTt+R5KZ\noWMvTfLJJPuS3JPkWeMrX5K0FIuGfpI1wFXA2cAm4IIkm0a6XQQ8WlWnAVcCV3TnrgVuAH6zql4M\nvBb49tiqlyQtSZ+Z/hZgtqoeqKongRuBbSN9tgHXddu3AGcmCXAWcHdVfRagqv6rqp4aT+mSpKXq\nE/onAQ8O7c91bfP2qapDwGPACcCPAJXktiSfTvL2wy9ZkrRca3v0yTxt1bPPWuDVwMuBbwIfS3Jn\nVX3se05OdgA7AE455ZQeJUmSlqPPTH8OOHlofwPw0EJ9unX844FHuvZ/q6qvVdU3gd3AGaM/oKqu\nrqrNVbV5/fr1S78XkqRe+oT+HuD0JBuTHAtsB3aN9NkFXNhtnwfcXlUF3Aa8NMn3dw8GPwN8bjyl\nS5KWatHlnao6lORiBgG+Bri2qvYluQzYW1W7gGuA65PMMpjhb+/OfTTJnzB44Chgd1XdeoTuyxE3\ns3P+0g9cfs4KVyJJy9NnTZ+q2s1gaWa47dKh7SeA8xc49wYGb9uUJE2Yn8iVpIb0mulr+i209AQu\nP0n6Lmf6ktQQQ1+SGmLoS1JDDH1JaoihL0kNMfQlqSGGviQ1xNCXpIYY+pLUEENfkhpi6EtSQwx9\nSWqIoS9JDTH0Jakhhr4kNcTQl6SGGPqS1BBDX5IaYuhLUkMMfUlqiKEvSQ0x9CWpIYa+JDXE0Jek\nhhj6ktQQQ1+SGmLoS1JD1k66gHGb2XnrvO0HLj9nhSuRpOmz6kJ/tVrowQx8QJPUX6/lnSRbk+xP\nMptk5zzH1yW5qTt+R5KZkeOnJHk8ydvGU7YkaTkWneknWQNcBbwemAP2JNlVVZ8b6nYR8GhVnZZk\nO3AF8EtDx68E/ml8ZUs6XD57bFOfmf4WYLaqHqiqJ4EbgW0jfbYB13XbtwBnJglAkjcBDwD7xlOy\nJGm5+oT+ScCDQ/tzXdu8farqEPAYcEKS44B3AO8+/FIlSYerT+hnnrbq2efdwJVV9fgz/oBkR5K9\nSfYePHiwR0mSpOXo8+6dOeDkof0NwEML9JlLshY4HngEeAVwXpL3Ac8DvpPkiar6s+GTq+pq4GqA\nzZs3jz6gSJLGpE/o7wFOT7IR+DKwHXjzSJ9dwIXAJ4HzgNurqoCffrpDkncBj48GviRp5Swa+lV1\nKMnFwG3AGuDaqtqX5DJgb1XtAq4Brk8yy2CGv/1IFi1JWp5eH86qqt3A7pG2S4e2nwDOX+Q23rWM\n+iRJY9TUJ3J9X7Kk1nnBNUlqiKEvSQ0x9CWpIU2t6UvjtppfJ1rN961lzvQlqSGGviQ1xNCXpIYY\n+pLUEF/I1dTz/z2WxseZviQ1xNCXpIYY+pLUEENfkhpi6EtSQwx9SWqIoS9JDTH0Jakhhr4kNcTQ\nl6SGGPqS1BCvvaMV4fVzpOlg6Es6Inygn04u70hSQ5zpryBnPpImzZm+JDXEmb7Gwmcx0tHB0J8i\nBqekI83lHUlqiKEvSQ0x9CWpIb1CP8nWJPuTzCbZOc/xdUlu6o7fkWSma399kjuT3NN9f914y5ck\nLcWioZ9kDXAVcDawCbggyaaRbhcBj1bVacCVwBVd+9eA
n6+qlwAXAtePq3BJ0tL1melvAWar6oGq\nehK4Edg20mcbcF23fQtwZpJU1Weq6qGufR/wrCTrxlG4JGnp+rxl8yTgwaH9OeAVC/WpqkNJHgNO\nYDDTf9ovAp+pqm8tv1xpdfFtulppfUI/87TVUvokeTGDJZ+z5v0ByQ5gB8App5zSoyRJ0nL0Wd6Z\nA04e2t8APLRQnyRrgeOBR7r9DcCHgV+tqvvn+wFVdXVVba6qzevXr1/aPZAk9dZnpr8HOD3JRuDL\nwHbgzSN9djF4ofaTwHnA7VVVSZ4H3ApcUlWfGF/Z0tFhoeUb6LeE4/KPxm3RmX5VHQIuBm4D7gNu\nrqp9SS5Lcm7X7RrghCSzwFuBp9/WeTFwGvAHSe7qvl449nshSeql17V3qmo3sHuk7dKh7SeA8+c5\n7z3Aew6zRknSmPiJXElqiKEvSQ0x9CWpIV5PX1qE76DRauJMX5Ia4kxfq54zdem7nOlLUkOc6Uua\nCJ+BTYYzfUlqiDP9RjirWphjM538dzkyDH0Bi/+BTfMf4DTXJk0bQ186ivmAp6Uy9CUdlXzAWx5f\nyJWkhjjTHyNnHmqJv+9HJ2f6ktQQQ1+SGmLoS1JDDH1JaoihL0kNMfQlqSGGviQ1xNCXpIYY+pLU\nEENfkhpi6EtSQwx9SWqIoS9JDTH0Jakhhr4kNcTQl6SG9Ar9JFuT7E8ym2TnPMfXJbmpO35Hkpmh\nY5d07fuTvGF8pUuSlmrR0E+yBrgKOBvYBFyQZNNIt4uAR6vqNOBK4Iru3E3AduDFwFbgz7vbkyRN\nQJ+Z/hZgtqoeqKongRuBbSN9tgHXddu3AGcmSdd+Y1V9q6q+CMx2tydJmoA+oX8S8ODQ/lzXNm+f\nqjoEPAac0PNcSdIKSVU9c4fkfOANVfXr3f6vAFuq6reH+uzr+sx1+/czmNFfBnyyqm7o2q8BdlfV\nh0Z+xg5gR7f7o8D+Mdw3gBOBr43ptsbN2pZnmmuD6a7P2pbnaKnth6pq/WInrO1xo3PAyUP7G4CH\nFugzl2QtcDzwSM9zqaqrgat71LIkSfZW1eZx3+44WNvyTHNtMN31WdvyrLba+izv7AFOT7IxybEM\nXpjdNdJnF3Bht30ecHsNnkLsArZ37+7ZCJwOfGopBUqSxmfRmX5VHUpyMXAbsAa4tqr2JbkM2FtV\nu4BrgOuTzDKY4W/vzt2X5Gbgc8Ah4C1V9dQRui+SpEX0Wd6hqnYDu0faLh3afgI4f4Fz3wu89zBq\nPBxjXzIaI2tbnmmuDaa7PmtbnlVV26Iv5EqSVg8vwyBJDVmVob/YZSMmLcmBJPckuSvJ3gnXcm2S\nh5PcO9T2giQfTfKF7vvzp6i2dyX5cjd2dyV544RqOznJx5Pcl2Rfkt/p2ic+ds9Q28THLsmzknwq\nyWe72t7dtW/sLuHyhe6SLseudG2L1PeBJF8cGruXTaK+rpY1ST6T5CPd/tLGrqpW1ReDF5vvB04F\njgU+C2yadF0jNR4ATpx0HV0trwHOAO4dansfsLPb3glcMUW1vQt42xSM24uAM7rt5wL/weAyJRMf\nu2eobeJjBwR4Trd9DHAH8ErgZmB71/6XwG9NWX0fAM6b9O9dV9dbgQ8CH+n2lzR2q3Gm3+eyEepU\n1b8zeMfVsOHLalwHvGlFi+osUNtUqKqvVNWnu+3/Bu5j8GnziY/dM9Q2cTXweLd7TPdVwOsYXMIF\nJvs7t1B9UyHJBuAc4G+6/bDEsVuNoX80XPqhgH9Jcmf3aeRp8wNV9RUYBAjwwgnXM+riJHd3yz8T\nWXoa1l1V9icZzAqnauxGaoMpGLtueeIu4GHgowyemX+9BpdwgQn/zY7WV1VPj917u7G7Msm6CZX3\np8Dbge90+yewxLFbjaGfedqm5pG686qqOoPBlUvfkuQ1ky7oKPIXwA8DLwO+AvzxJItJ8hzgQ8Dv\nVtU3JlnLqHlqm4qxq6qnquplDD6hvwX48fm6rWxVQz94pL4kPwFcAvwY8HLgBcA7VrquJD8HPFxV\ndw43z9P1GcduNYZ+r0s/TFJVPdR9fxj4MNN35dGvJnkRQPf94QnX83+q6qvdH+V3gL9mgmOX5BgG\nofq3VfUPXfNUjN18tU3T2HX1fB34VwZr5s/rLuECU/I3O1Tf1m7JrKrqW8D7mczYvQo4N8kBBsvW\nr2Mw81/S2K3G0O9z2YiJSXJckuc+vQ2cBdz7zGetuOHLalwI/OMEa/keTwdq5xeY0Nh1a6nXAPdV\n1Z8MHZr42C1U2zSMXZL1SZ7XbT8b+FkGrzl8nMElXGCCv3ML1Pf5oQfyMFgzX/Gxq6pLqmpDVc0w\nyLXbq+qXWerYTfqV6CP06vYbGbxj4X7g9yddz0htpzJ4R9FngX2Trg/4OwZP9b/N4FnSRQzWCT8G\nfKH7/oIpqu164B7gbgYB+6IJ1fZqBk+j7wbu6r7eOA1j9wy1TXzsgJcCn+lquBe4tGs/lcF1uWaB\nvwfWTejfdaH6bu/G7l7gBrp3+EzqC3gt3333zpLGzk/kSlJDVuPyjiRpAYa+JDXE0Jekhhj6ktQQ\nQ1+SGmLoS1JDDH1JaoihL0kN+V82bZcpaxTlIgAAAABJRU5ErkJggg==\n", 144 | "text/plain": [ 145 | "" 146 | ] 147 | }, 148 | "metadata": {}, 149 | "output_type": "display_data" 150 | } 151 | ], 152 | "source": [ 153 | "idx = 0\n", 154 | "plot_bar(attns[0,idx,:])" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": null, 160 | "metadata": { 161 | "collapsed": true 162 | }, 163 | "outputs": [], 164 | "source": [] 165 | } 166 | ], 167 | "metadata": { 168 | "kernelspec": { 169 | "display_name": "Python 3", 170 | "language": "python", 171 | "name": "python3" 172 | }, 173 | "language_info": { 174 | "codemirror_mode": { 175 | "name": "ipython", 176 | "version": 3 177 | }, 178 | "file_extension": ".py", 179 | "mimetype": "text/x-python", 180 | "name": "python", 181 | "nbconvert_exporter": "python", 182 | "pygments_lexer": "ipython3", 183 | "version": "3.6.2" 184 | } 185 | }, 186 | "nbformat": 4, 187 | "nbformat_minor": 2 188 
| } 189 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # configuration file 2 | pad_tok = '' 3 | start_tok = '' 4 | end_tok = '' 5 | unk_tok = '' 6 | 7 | pad_index = 0 8 | 9 | # elmo small 10 | weight_file = '/home/tiankeke/workspace/Elmo/small/elmo_2x1024_128_2048cnn_1xhighway_weights.hdf5' 11 | options_file = '/home/tiankeke/workspace/Elmo/small/options.json' 12 | 13 | # elmo medium 14 | # weight_file = '/home/tiankeke/workspace/Elmo/medium/elmo_2x2048_256_2048cnn_1xhighway_weights.hdf5' 15 | # options_file = '/home/tiankeke/workspace/Elmo/medium/options.json' -------------------------------------------------------------------------------- /mytest.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import argparse 5 | import numpy as np 6 | from utils import BatchManager, load_data, get_vocab, build_vocab 7 | from Transformer import Transformer 8 | from translate import translate, print_summaries 9 | 10 | parser = argparse.ArgumentParser(description='Selective Encoding for Abstractive Sentence Summarization in pytorch') 11 | 12 | parser.add_argument('--n_test', type=int, default=1951, 13 | help='Number of validation data (up to 189651 in gigaword) [default: 189651])') 14 | parser.add_argument('--input_file', type=str, default="sumdata/Giga/input.txt", help='input file') 15 | parser.add_argument('--output_dir', type=str, default="sumdata/Giga/systems/", help='') 16 | parser.add_argument('--batch_size', type=int, default=32, help='Mini batch size [default: 32]') 17 | parser.add_argument('--ckpt_file', type=str, default='./ckpts/params_v2_0.pkl', help='model file path') 18 | parser.add_argument('--search', type=str, default='greedy', help='greedy/beam') 19 | parser.add_argument('--beam_width', type=int, default=12, help='beam search width') 20 | args = parser.parse_args() 21 | print(args) 22 | 23 | 24 | def my_test(test_x, model, tgt_vocab): 25 | summaries = translate(test_x, model, tgt_vocab, search=args.search) 26 | print_summaries(summaries, tgt_vocab, args.output_dir) 27 | print("Done!") 28 | 29 | 30 | def main(): 31 | # if not os.path.exists(args.ckpt_file): 32 | # raise FileNotFoundError("model file not found") 33 | 34 | data_dir = '/home/tiankeke/workspace/datas/sumdata/' 35 | TRAIN_X = os.path.join(data_dir, 'train/train.article.txt') 36 | TRAIN_Y = os.path.join(data_dir, 'train/train.title.txt') 37 | TEST_X = args.input_file 38 | 39 | small_vocab_file = 'sumdata/small_vocab.json' 40 | if os.path.exists(small_vocab_file): 41 | small_vocab = json.load(open(small_vocab_file)) 42 | else: 43 | small_vocab = build_vocab([TRAIN_X, TRAIN_Y], small_vocab_file, vocab_size=80000) 44 | 45 | max_src_len = 100 46 | max_tgt_len = 40 47 | vocab = small_vocab 48 | 49 | test_x = BatchManager(load_data(TEST_X, max_src_len, args.n_test), args.batch_size, small_vocab) 50 | 51 | model = Transformer(len(vocab), len(vocab), 200, 200, 2, 4, 256, 52 | 1024, src_tgt_emb_share=True, tgt_prj_wt_share=True).cuda() 53 | 54 | saved_state = torch.load(args.ckpt_file) 55 | model.load_state_dict(saved_state['state_dict']) 56 | print('Load model parameters from %s' % args.ckpt_file) 57 | 58 | my_test(test_x, model, small_vocab) 59 | 60 | 61 | if __name__ == '__main__': 62 | main() 63 | 64 | -------------------------------------------------------------------------------- /mytest1.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import argparse 5 | from utils import BatchManager, load_data, get_vocab, build_vocab 6 | from transformer.Models import Transformer 7 | from Beam import Beam 8 | import torch.nn.functional as F 9 | 10 | parser = argparse.ArgumentParser(description='Selective Encoding for Abstractive Sentence Summarization in pytorch') 11 | 12 | parser.add_argument('--n_test', type=int, default=1951, 13 | help='Number of validation data (up to 189651 in gigaword) [default: 189651])') 14 | parser.add_argument('--input_file', type=str, default="sumdata/Giga/input.txt", help='input file') 15 | parser.add_argument('--output_dir', type=str, default="sumdata/Giga/systems/", help='') 16 | parser.add_argument('--batch_size', type=int, default=32, help='Mini batch size [default: 32]') 17 | parser.add_argument('--emb_dim', type=int, default=300, help='Embedding size [default: 256]') 18 | parser.add_argument('--hid_dim', type=int, default=512, help='Hidden state size [default: 256]') 19 | parser.add_argument('--maxout_dim', type=int, default=2, help='Maxout size [default: 2]') 20 | parser.add_argument('--ckpt_file', type=str, default='./models/params_v1_0.pkl', help='model file path') 21 | parser.add_argument('--search', type=str, default='greedy', help='greedy/beam') 22 | parser.add_argument('--beam_width', type=int, default=12, help='beam search width') 23 | args = parser.parse_args() 24 | print(args) 25 | 26 | 27 | def print_summaries(summaries, vocab): 28 | """ 29 | param summaries: in shape (seq_len, batch) 30 | """ 31 | i2w = {key: value for value, key in vocab.items()} 32 | 33 | for idx in range(len(summaries)): 34 | fout = open(os.path.join(args.output_dir, "%d.txt" % idx), "w") 35 | line = [i2w[tok] for tok in summaries[idx] if tok != vocab[""]] 36 | fout.write(" ".join(line) + "\n") 37 | fout.close() 38 | 39 | 40 | def greedy(model, x, tgt_vocab, max_trg_len=15): 41 | y = torch.ones(x.shape[0], max_trg_len, dtype=torch.long).cuda() * tgt_vocab[""] 42 | y[:,0] = tgt_vocab[""] 43 | 44 | pos_x = torch.arange(x.shape[1]).unsqueeze(0).expand_as(x).cuda() 45 | pos_y = torch.arange(y.shape[1]).unsqueeze(0).expand_as(y).cuda() 46 | 47 | for i in range(max_trg_len-1): 48 | logits = model(x, pos_x, y, pos_y) 49 | y[:, i+1] = torch.argmax(logits[:,i,:], dim=-1) 50 | return y[:,1:].detach().cpu().tolist() 51 | 52 | 53 | def beam_search(model, batch_x, vocab, max_trg_len=10, k=args.beam_width): 54 | 55 | beams = [Beam(k, vocab, max_trg_len) for _ in range(batch_x.shape[0])] 56 | 57 | for i in range(max_trg_len): 58 | for j in range(len(beams)): 59 | x = batch_x[j].unsqueeze(0).expand(k, -1) 60 | y = beams[j].get_sequence() 61 | 62 | pos_x = torch.arange(x.shape[1]).unsqueeze(0).expand_as(x).cuda() 63 | pos_y = torch.arange(y.shape[1]).unsqueeze(0).expand_as(y).cuda() 64 | 65 | logit = model(x, pos_x, y, pos_y) 66 | # logit: [k, seqlen, V] 67 | log_probs = torch.log(F.softmax(logit[:, i, :], -1)) 68 | beams[j].advance_(log_probs) 69 | 70 | allHyp = [b.get_hyp().cpu().numpy() for b in beams] 71 | return allHyp 72 | 73 | 74 | def my_test(valid_x, model, tgt_vocab): 75 | summaries = [] 76 | with torch.no_grad(): 77 | for i in range(valid_x.steps): 78 | _, x = valid_x.next_batch() 79 | if args.search == "greedy": 80 | summary = greedy(model, x, tgt_vocab) 81 | elif args.search == "beam": 82 | summary = beam_search(model, x, tgt_vocab) 83 | else: 84 | raise NameError("Unknown search method") 85 | 
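# greedy() returns one list of token ids per source sentence, while beam_search() returns each
# beam's best hypothesis; both are accumulated here and written out by print_summaries().
# Note that beam_search() above calls beams[j].advance_(...), whereas the Beam class in Beam.py
# only defines advance()/advance_v1(), so that call may need to be renamed for this script to
# run end to end.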
summaries.extend(summary) 86 | print_summaries(summaries, tgt_vocab) 87 | print("Done!") 88 | 89 | 90 | def main(): 91 | if not os.path.exists(args.ckpt_file): 92 | raise FileNotFoundError("model file not found") 93 | 94 | data_dir = '/home/tiankeke/workspace/datas/sumdata/' 95 | TRAIN_X = os.path.join(data_dir, 'train/train.article.txt') 96 | TRAIN_Y = os.path.join(data_dir, 'train/train.title.txt') 97 | TEST_X = args.input_file 98 | 99 | small_vocab_file = 'sumdata/small_vocab.json' 100 | if os.path.exists(small_vocab_file): 101 | small_vocab = json.load(open(small_vocab_file)) 102 | else: 103 | small_vocab = build_vocab([TRAIN_X, TRAIN_Y], small_vocab_file, vocab_size=80000) 104 | 105 | max_src_len = 101 106 | max_tgt_len = 47 107 | 108 | test_x = BatchManager(load_data(TEST_X, max_src_len, args.n_test), args.batch_size, small_vocab) 109 | 110 | model = Transformer(len(small_vocab), len(small_vocab), max_src_len, d_word_vec=300, 111 | d_model=300, d_inner=1200, n_layers=1, n_head=6, d_k=50, 112 | d_v=50, dropout=0.1, tgt_emb_prj_weight_sharing=True, 113 | emb_src_tgt_weight_sharing=True).cuda() 114 | # print(model) 115 | model.eval() 116 | 117 | saved_state = torch.load(args.ckpt_file) 118 | model.load_state_dict(saved_state['state_dict']) 119 | print('Load model parameters from %s' % args.ckpt_file) 120 | 121 | my_test(test_x, model, small_vocab) 122 | 123 | 124 | if __name__ == '__main__': 125 | main() 126 | 127 | -------------------------------------------------------------------------------- /mytest_elmo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import argparse 5 | from utils import BatchManager, load_data, get_vocab, build_vocab 6 | from Transformer import ElmoTransformer 7 | from Beam import Beam 8 | import torch.nn.functional as F 9 | import copy 10 | 11 | parser = argparse.ArgumentParser(description='Selective Encoding for Abstractive Sentence Summarization in pytorch') 12 | 13 | parser.add_argument('--n_test', type=int, default=1951, 14 | help='Number of validation data (up to 189651 in gigaword) [default: 189651])') 15 | parser.add_argument('--input_file', type=str, default="sumdata/Giga/input.txt", help='input file') 16 | parser.add_argument('--output_dir', type=str, default="sumdata/Giga/systems/", help='') 17 | parser.add_argument('--batch_size', type=int, default=64, help='Mini batch size [default: 32]') 18 | parser.add_argument('--ckpt_file', type=str, default='./models/elmo_small_2L_8H_512_epoch0.pkl', help='model file path') 19 | parser.add_argument('--search', type=str, default='greedy', help='greedy/beam') 20 | parser.add_argument('--beam_width', type=int, default=12, help='beam search width') 21 | args = parser.parse_args() 22 | print(args) 23 | 24 | 25 | def print_summaries(summaries, vocab, output_dir, pattern='%d.txt'): 26 | """ 27 | param summaries: in shape (seq_len, batch) 28 | """ 29 | for idx in range(len(summaries)): 30 | fout = open(os.path.join(output_dir, pattern % idx), "w") 31 | line = [tok for tok in summaries[idx] if tok not in ['', '', '']] 32 | fout.write(" ".join(line) + "\n") 33 | fout.close() 34 | 35 | 36 | def run_batch(batch_x, batch_y, model): 37 | x_stncs, x_ids = batch_x.next_batch() 38 | y_stncs, y_ids = batch_y.next_batch() 39 | 40 | logits = model(x_stncs, y_stncs, x_ids, y_ids) 41 | loss = model.loss_layer(logits.view(-1, logits.shape[-1]), 42 | y_ids[:, 1:].contiguous().view(-1)) 43 | return loss 44 | 45 | 46 | def greedy(model, x_stncs, x_ids, 
tgt_vocab, max_trg_len=15): 47 | b_size = len(x_stncs) 48 | y_stncs = [[''] * max_trg_len for _ in range(b_size)] 49 | y_ids = torch.ones(b_size, max_trg_len, dtype=torch.long).cuda() 50 | y_ids *= tgt_vocab[''] 51 | id2w = {v: k for k,v in tgt_vocab.items()} 52 | for i in range(max_trg_len - 1): 53 | logits = model.forward(x_stncs, y_stncs, x_ids, y_ids) 54 | argmax = torch.argmax(logits[:, i, :], dim=-1).detach().cpu().tolist() 55 | for j in range(b_size): 56 | y_stncs[j][i+1] = id2w[argmax[j]] 57 | y_ids[j][i+1] = argmax[j] 58 | return y_stncs 59 | 60 | 61 | def beam_search(model, batch_x, vocab, max_trg_len=10, k=3): 62 | beams = [Beam(k, vocab, max_trg_len) for _ in range(batch_x.shape[0])] 63 | 64 | for i in range(max_trg_len): 65 | for j in range(len(beams)): 66 | x = batch_x[j].unsqueeze(0).expand(k, -1) 67 | y = beams[j].get_sequence() 68 | logit = model(x, y) 69 | # logit: [k, seqlen, V] 70 | log_probs = torch.log(F.softmax(logit[:, i, :], -1)) 71 | beams[j].advance_(log_probs) 72 | 73 | allHyp = [b.get_hyp().cpu().numpy() for b in beams] 74 | return allHyp 75 | 76 | 77 | def translate(valid_x, model, tgt_vocab, search='greedy', beam_width=5): 78 | summaries = [] 79 | model.eval() 80 | with torch.no_grad(): 81 | for i in range(valid_x.steps): 82 | print(i, flush=True) 83 | x_stncs, x_ids = valid_x.next_batch() 84 | if search == "greedy": 85 | summary = greedy(model, x_stncs, x_ids, tgt_vocab) 86 | elif search == "beam": 87 | summary = beam_search(model, x_stncs, x_ids, tgt_vocab, k=beam_width) 88 | else: 89 | raise NameError("Unknown search method") 90 | summaries.extend(summary) 91 | return summaries 92 | 93 | 94 | def my_test(valid_x, model, tgt_vocab): 95 | summaries = translate(valid_x, model, tgt_vocab, search='greedy') 96 | print_summaries(summaries, tgt_vocab, args.output_dir) 97 | print("Done!") 98 | 99 | 100 | def main(): 101 | if not os.path.exists(args.ckpt_file): 102 | raise FileNotFoundError("model file not found") 103 | 104 | data_dir = '/home/tiankeke/workspace/datas/sumdata/' 105 | TRAIN_X = os.path.join(data_dir, 'train/train.article.txt') 106 | TRAIN_Y = os.path.join(data_dir, 'train/train.title.txt') 107 | TEST_X = args.input_file 108 | 109 | small_vocab_file = 'sumdata/small_vocab.json' 110 | if os.path.exists(small_vocab_file): 111 | small_vocab = json.load(open(small_vocab_file)) 112 | else: 113 | small_vocab = build_vocab([TRAIN_X, TRAIN_Y], small_vocab_file, vocab_size=80000) 114 | 115 | max_src_len = 60 116 | max_tgt_len = 20 117 | 118 | bs = args.batch_size 119 | n_test = args.n_test 120 | 121 | vocab = small_vocab 122 | test_x = BatchManager(load_data(TEST_X, max_src_len, n_test), bs, vocab) 123 | 124 | model = ElmoTransformer(max_src_len, len(vocab), 2, 8, 64, 64, 256, 512, 2048, 125 | dropout=0.5, elmo_requires_grad=False).cuda() 126 | 127 | saved_state = torch.load(args.ckpt_file) 128 | model.load_state_dict(saved_state['state_dict']) 129 | print('Load model parameters from %s' % args.ckpt_file) 130 | 131 | my_test(test_x, model, small_vocab) 132 | 133 | 134 | if __name__ == '__main__': 135 | main() 136 | 137 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | ### Transformer for text summarization, implemented in pytorch 2 | - author: Kirk 3 | - mail: cotitan@outlook.com 4 | 5 | ### Requirments, * means not necessary 6 | - pytorch==0.4.0 7 | - numpy==1.12.1+ 8 | - python==3.5+ 9 | - tensorboardx==1.2* 10 | - 
word2vec==0.10.2* 11 | - allennlp==0.8.2* 12 | - pyrouge==0.1.3* 13 | 14 | ### Data 15 | Training and evaluation data for Gigaword is available https://drive.google.com/open?id=0B6N7tANPyVeBNmlSX19Ld2xDU1E 16 | 17 | Training and evaluation data for CNN/DM is available https://s3.amazonaws.com/opennmt-models/Summary/cnndm.tar.gz 18 | 19 | ### Noticement 20 | 1. we use another thread to preprocess a batch of data, which would not terminate after the main process terminate. So you need to press ctrl+c again to terminate the thread. 21 | 22 | ### Directories: 23 | ``` 24 | . 25 | ├── log 26 | ├── models 27 | ├── sumdata 28 | ├── tmp 29 | ├── transformer 30 | ├── Beam.py 31 | ├── config.py 32 | ├── train.py 33 | ├── mytest.py 34 | ├── Transformer.py 35 | ├── translate.py 36 | └── utils.py 37 | ``` 38 | Make sure your project contains the folders above. 39 | 40 | ### How-to 41 | 1. Run _python train.py_ to train 42 | 2. Run _python mytest.py_ to generate summaries 43 | 44 | -------------------------------------------------------------------------------- /record.md: -------------------------------------------------------------------------------- 1 | ### Transformer1 2 | Giga, DUC2004 3 | 4 | ||R|P|F| 5 | |-|-|-|-| 6 | |R-1|20.48|29.14|23.26| 7 | |R-2|7.02|10.00|7.97| 8 | |R-3|18.98|26.89|21.52| 9 | 10 | 11 | ### Transformer1_dropout 12 | epoch 1, greedy search 13 | 14 | ||R|P|F|\||R|P|F| 15 | |-|-|-|-|-|-|-|-| 16 | |R-1|22.36|30.82|25.02|\||16.52|29.94|20.99| 17 | |R-2|8.10|11.19|9.00|\||3.74|7.26|4.85| 18 | |R-3|20.81|28.63|23.26|\||14.90|27.13|18.96| 19 | 20 | epoch 1, beam search 21 | 22 | ||R|P|F|\||R|P|F| 23 | |-|-|-|-|-|-|-|-| 24 | |R-1|19.26|27.89|22.02|\||15.24|28.41|19.63| 25 | |R-2|6.14|8.95|6.99|\||3.56|7.06|4.67| 26 | |R-3|17.29|24.96|19.74|\||13.23|24.60|17.02| 27 | 28 | epoch 2, greedy 29 | 30 | ||R|P|F|\||R|P|F| 31 | |-|-|-|-|-|-|-|-| 32 | |R-1|23.16|32.12|26.04|\||17.07|30.19|21.52| 33 | |R-2|8.63|11.98|9.66|\||4.09|7.60|5.23| 34 | |R-3|21.53|29.77|24.17|\||15.39|27.25|19.41| 35 | 36 | epoch 3, greedy 37 | 38 | ||R|P|F|\||R|P|F| 39 | |-|-|-|-|-|-|-|-| 40 | |R-1|23.96|32.94|26.79|\||17.51|30.37|21.90| 41 | |R-2|9.07|12.44|10.04|\||4.16|7.61|5.28| 42 | |R-3|22.43|30.74|25.03|\||15.72|27.36|19.67| 43 | 44 | ### Transformer1_dropout, learning_rate decay 0.3/epoch 45 | 46 | epoch 1, greedy 47 | 48 | ||R|P|F|\||R|P|F| 49 | |-|-|-|-|-|-|-|-| 50 | |R-1|24.96|34.22|27.97|\||17.75|30.81|22.18| 51 | |R-2|9.53|13.07|10.64|\||4.51|8.41|5.76| 52 | |R-3|23.31|31.79|26.06|\||16.09|28.09|20.14| 53 | 54 | 55 | epoch 2, greedy 56 | 57 | ||R|P|F|\||R|P|F| 58 | |-|-|-|-|-|-|-|-| 59 | |R-1|25.39|34.48|28.31|\||18.80|31.90|23.39| 60 | |R-2|10.00|13.61|11.11|\||4.92|8.85|6.23| 61 | |R-3|23.74|32.10|26.41|\||16.90|28.74|21.04| 62 | 63 | 64 | epoch 3, greedy 65 | 66 | ||R|P|F|\||R|P|F| 67 | |-|-|-|-|-|-|-|-| 68 | |R-1|25.64|34.76|28.59|\||18.44|31.59|23.00| 69 | |R-2|10.13|13.89|11.31|\||4.91|8.71|6.21| 70 | |R-3|23.84|32.20|26.54|\||16.70|28.69|20.84| 71 | 72 | 73 | epoch 4, greedy 74 | 75 | ||R|P|F|\||R|P|F| 76 | |-|-|-|-|-|-|-|-| 77 | |R-1|26.22|34.63|28.89|\||19.37|32.17|23.78| 78 | |R-2|10.53|13.95|11.56|\||5.09|8.91|6.36| 79 | |R-3|24.41|32.09|26.85|\||17.45|29.10|21.44| 80 | 81 | 82 | ### elmo_small 83 | epoch 3, greedy 84 | 85 | ||R|P|F|\||R|P|F| 86 | |-|-|-|-|-|-|-|-| 87 | |R-1|27.89|31.44|28.65|\||21.02|29.51|24.24| 88 | |R-2|11.70|13.04|11.90|\||5.90|8.62|6.89| 89 | |R-3|26.14|29.40|26.82|\||18.90|26.60|21.82| 90 | 91 | 92 | ### elmo_small_translayer2 93 | epoch 3, greedy 94 | 95 | 
||R|P|F|\||R|P|F| 96 | |-|-|-|-|-|-|-|-| 97 | |R-1|28.85|31.72|29.21|\||21.91|30.02|24.98| 98 | |R-2|12.08|13.15|12.12|\||6.21|8.94|7.19| 99 | |R-3|26.78|29.44|27.10|\||19.55|26.86|22.31| 100 | 101 | epoch 6, greedy 102 | 103 | ||R|P|F|\||R|P|F| 104 | |-|-|-|-|-|-|-|-| 105 | |R-1|28.93|32.46|29.57|\||21.20|29.82|24.41| 106 | |R-2|12.24|13.51|12.37|\||5.96|8.82|6.98| 107 | |R-3|26.89|30.12|27.46|\||19.05|26.83|21.94| 108 | 109 | epoch 9, greedy 110 | 111 | ||R|P|F|\||R|P|F| 112 | |-|-|-|-|-|-|-|-| 113 | |R-1|28.98|32.47|29.58|\||21.35|29.93|24.58| 114 | |R-2|12.31|13.54|12.41|\||5.97|8.77|6.98| 115 | |R-3|26.99|30.15|27.51|\||19.16|26.89|22.06| -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import logging 4 | import json 5 | import torch 6 | import argparse 7 | from Transformer import Transformer, TransformerShareEmbedding 8 | from tensorboardX import SummaryWriter 9 | import utils 10 | from utils import BatchManager, load_data, load_vocab, build_vocab 11 | from pyrouge import Rouge155 12 | from translate import greedy, print_summaries 13 | 14 | parser = argparse.ArgumentParser(description='Selective Encoding for Abstractive Sentence Summarization in DyNet') 15 | 16 | parser.add_argument('--n_epochs', type=int, default=10, help='Number of epochs [default: 3]') 17 | parser.add_argument('--n_train', type=int, default=3803900, 18 | help='Number of training data (up to 3803957 in gigaword) [default: 3803957]') 19 | parser.add_argument('--n_valid', type=int, default=189651, 20 | help='Number of validation data (up to 189651 in gigaword) [default: 189651])') 21 | parser.add_argument('--batch_size', type=int, default=64, help='Mini batch size [default: 32]') 22 | parser.add_argument('--ckpt_file', type=str, default='./ckpts/params_v2_9.pkl') 23 | args = parser.parse_args() 24 | 25 | logging.basicConfig( 26 | level=logging.INFO, 27 | format='%(asctime)s - %(levelname)s - %(message)s', 28 | filename='log/train2.log', 29 | filemode='w' 30 | ) 31 | 32 | # define a new Handler to log to console as well 33 | console = logging.StreamHandler() 34 | # optional, set the logging level 35 | console.setLevel(logging.INFO) 36 | # set a format which is the same for console use 37 | formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') 38 | # tell the handler to use this format 39 | console.setFormatter(formatter) 40 | # add the handler to the root logger 41 | logging.getLogger('').addHandler(console) 42 | 43 | model_dir = './ckpts' 44 | if not os.path.exists(model_dir): 45 | os.mkdir(model_dir) 46 | 47 | 48 | def run_batch(valid_x, valid_y, model): 49 | _, x = valid_x.next_batch() 50 | _, y = valid_y.next_batch() 51 | logits, _ = model(x, y) 52 | loss = model.loss_layer(logits.view(-1, logits.shape[-1]), 53 | y[:, 1:].contiguous().view(-1)) 54 | return loss 55 | 56 | 57 | def eval_model(valid_x, valid_y, vocab, model): 58 | # if we put the following part outside the code, 59 | # error occurs 60 | r = Rouge155() 61 | r.system_dir = 'tmp/systems' 62 | r.model_dir = 'tmp/models' 63 | r.system_filename_pattern = "(\d+).txt" 64 | r.model_filename_pattern = "[A-Z].#ID#.txt" 65 | 66 | logging.info('Evaluating on a minibatch...') 67 | model.eval() 68 | _, x = valid_x.next_batch() 69 | with torch.no_grad(): 70 | pred = greedy(model, x, vocab) 71 | _, y = valid_y.next_batch() 72 | y = y[:,1:].tolist() 73 | print_summaries(pred, vocab, 
'tmp/systems', '%d.txt') 74 | print_summaries(y, vocab, 'tmp/models', 'A.%d.txt') 75 | 76 | try: 77 | output = r.convert_and_evaluate() 78 | output_dict = r.output_to_dict(output) 79 | logging.info('Rouge1-F: %f, Rouge2-F: %f, RougeL-F: %f' 80 | % (output_dict['rouge_1_f_score'], 81 | output_dict['rouge_2_f_score'], 82 | output_dict['rouge_l_f_score'])) 83 | except Exception as e: 84 | logging.info('Failed to evaluate') 85 | 86 | model.train() 87 | 88 | 89 | def adjust_lr(optimizer, epoch): 90 | if (epoch + 1) % 2 == 0: 91 | # optimizer.param_groups[0]['lr'] *= math.sqrt((epoch+1)/10) 92 | optimizer.param_groups[0]['lr'] *= 0.5 93 | 94 | 95 | def train(train_x, train_y, valid_x, valid_y, model, 96 | optimizer, tgt_vocab, scheduler, n_epochs=1, epoch=0): 97 | logging.info("Start to train with lr=%f..." % optimizer.param_groups[0]['lr']) 98 | n_batches = train_x.steps 99 | model.train() 100 | 101 | if os.path.isdir('runs/epoch%d' % epoch): 102 | shutil.rmtree('runs/epoch%d' % epoch) 103 | writer = SummaryWriter('runs/epoch%d' % epoch) 104 | i = epoch * train_x.steps 105 | for epoch in range(epoch, n_epochs): 106 | valid_x.bid = 0 107 | valid_y.bid = 0 108 | 109 | 110 | for idx in range(n_batches): 111 | optimizer.zero_grad() 112 | 113 | loss = run_batch(train_x, train_y, model) 114 | loss.backward() # do not use retain_graph=True 115 | # torch.nn.utils.clip_grad_value_(model.parameters(), 5) 116 | 117 | optimizer.step() 118 | 119 | if (idx + 1) % 50 == 0: 120 | train_loss = loss.cpu().detach().numpy() 121 | model.eval() 122 | with torch.no_grad(): 123 | valid_loss = run_batch(valid_x, valid_y, model) 124 | logging.info('epoch %d: %d, training loss = %f, validation loss = %f' 125 | % (epoch, idx + 1, train_loss, valid_loss)) 126 | writer.add_scalar('scalar/train_loss', train_loss, i) 127 | writer.add_scalar('scalar/valid_loss', valid_loss, i) 128 | i += 1 129 | model.train() 130 | # if (idx + 1) % 2000 == 0: 131 | # eval_model(valid_x, valid_y, tgt_vocab, model) 132 | # dump_tensors() 133 | 134 | adjust_lr(optimizer, epoch) 135 | save_state = {'state_dict': model.state_dict(), 136 | 'epoch': epoch + 1, 137 | 'lr': optimizer.param_groups[0]['lr']} 138 | torch.save(save_state, os.path.join(model_dir, 'params_v2_%d.pkl' % epoch)) 139 | logging.info('Model saved in dir %s' % model_dir) 140 | writer.close() 141 | 142 | 143 | def main(): 144 | print(args) 145 | 146 | data_dir = '/home/disk3/tiankeke/sumdata/' 147 | TRAIN_X = os.path.join(data_dir, 'train/train.article.txt') 148 | TRAIN_Y = os.path.join(data_dir, 'train/train.title.txt') 149 | VALID_X = os.path.join(data_dir, 'train/valid.article.filter.txt') 150 | VALID_Y = os.path.join(data_dir, 'train/valid.title.filter.txt') 151 | 152 | src_vocab_file = 'sumdata/src_vocab.txt' 153 | if not os.path.exists(src_vocab_file): 154 | build_vocab([TRAIN_X], src_vocab_file) 155 | src_vocab = load_vocab(src_vocab_file, vocab_size=90000) 156 | 157 | tgt_vocab_file = 'sumdata/tgt_vocab.txt' 158 | if not os.path.exists(tgt_vocab_file): 159 | build_vocab([TRAIN_Y], tgt_vocab_file) 160 | tgt_vocab = load_vocab(tgt_vocab_file) 161 | 162 | # emb_file = '/home/tiankeke/workspace/embeddings/giga-vec1.bin' 163 | # vocab, embeddings = load_word2vec_embedding(emb_file) 164 | 165 | max_src_len = 100 166 | max_tgt_len = 40 167 | max_pos = 200 168 | 169 | bs = args.batch_size 170 | n_train = args.n_train 171 | n_valid = args.n_valid 172 | 173 | train_x = BatchManager(load_data(TRAIN_X, max_src_len, n_train), bs, src_vocab) 174 | train_y = BatchManager(load_data(TRAIN_Y, 
max_tgt_len, n_train), bs, tgt_vocab) 175 | train_x, train_y = utils.shuffle(train_x, train_y) 176 | 177 | valid_x = BatchManager(load_data(VALID_X, max_src_len, n_valid), bs, src_vocab) 178 | valid_y = BatchManager(load_data(VALID_Y, max_tgt_len, n_valid), bs, tgt_vocab) 179 | valid_x, valid_y = utils.shuffle(valid_x, valid_y) 180 | # model = Transformer(len(vocab), len(vocab), max_src_len, max_tgt_len, 1, 4, 256, 181 | # 64, 64, 1024, src_tgt_emb_share=True, tgt_prj_wt_share=True).cuda() 182 | model = Transformer(len(src_vocab), len(tgt_vocab), max_pos, max_pos, 2, 4, 256, 183 | 1024, src_tgt_emb_share=False, tgt_prj_wt_share=True).cuda() 184 | # model = TransformerShareEmbedding(len(vocab), max_src_len, 2, 4, 185 | # 256, 1024, False, True).cuda() 186 | 187 | # print(model) 188 | saved_state = {'epoch': 0, 'lr': 0.001} 189 | if os.path.exists(args.ckpt_file): 190 | saved_state = torch.load(args.ckpt_file) 191 | model.load_state_dict(saved_state['state_dict']) 192 | logging.info('Load model parameters from %s' % args.ckpt_file) 193 | 194 | parameters = filter(lambda p: p.requires_grad, model.parameters()) 195 | optimizer = torch.optim.Adam(parameters, lr=saved_state['lr']) 196 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.3) 197 | scheduler.step() # last_epoch=-1, which will not update lr at the first time 198 | 199 | # eval_model(valid_x, valid_y, vocab, model) 200 | train(train_x, train_y, valid_x, valid_y, model, 201 | optimizer, tgt_vocab, scheduler, args.n_epochs, saved_state['epoch']) 202 | 203 | 204 | if __name__ == '__main__': 205 | main() 206 | # TODO 207 | # 使用Pycharm,逐过程查看内部状态,看哪一步的结果值很小,可能是该步出问题 208 | 209 | -------------------------------------------------------------------------------- /train1.py: -------------------------------------------------------------------------------- 1 | """ use another implementation from ./transformer """ 2 | import os 3 | import logging 4 | import torch 5 | import json 6 | import argparse 7 | from transformer.Models import Transformer 8 | from utils import BatchManager, load_data, get_vocab, build_vocab 9 | from tensorboardX import SummaryWriter 10 | import config 11 | 12 | parser = argparse.ArgumentParser(description='Selective Encoding for Abstractive Sentence Summarization in DyNet') 13 | 14 | parser.add_argument('--n_epochs', type=int, default=5, help='Number of epochs [default: 3]') 15 | parser.add_argument('--n_train', type=int, default=3803900, 16 | help='Number of training data (up to 3803957 in gigaword) [default: 3803957]') 17 | parser.add_argument('--n_valid', type=int, default=189651, 18 | help='Number of validation data (up to 189651 in gigaword) [default: 189651])') 19 | parser.add_argument('--batch_size', type=int, default=32, help='Mini batch size [default: 32]') 20 | parser.add_argument('--emb_dim', type=int, default=300, help='Embedding size [default: 256]') 21 | parser.add_argument('--hid_dim', type=int, default=512, help='Hidden state size [default: 256]') 22 | parser.add_argument('--maxout_dim', type=int, default=2, help='Maxout size [default: 2]') 23 | parser.add_argument('--ckpt_file', type=str, default='./models/params_v1_0.pkl') 24 | args = parser.parse_args() 25 | 26 | 27 | logging.basicConfig( 28 | level=logging.INFO, 29 | format='%(asctime)s - %(levelname)s - %(message)s', 30 | filename='log/train1.log', 31 | filemode='w' 32 | ) 33 | 34 | # define a new Handler to log to console as well 35 | console = logging.StreamHandler() 36 | # optional, set the logging level 37 | 
console.setLevel(logging.INFO) 38 | # set a format which is the same for console use 39 | formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') 40 | # tell the handler to use this format 41 | console.setFormatter(formatter) 42 | # add the handler to the root logger 43 | logging.getLogger('').addHandler(console) 44 | 45 | 46 | model_dir = './models' 47 | if not os.path.exists(model_dir): 48 | os.mkdir(model_dir) 49 | 50 | loss_layer = torch.nn.CrossEntropyLoss(ignore_index=config.pad_index) 51 | 52 | 53 | def run_batch(valid_x, valid_y, model): 54 | _, x = valid_x.next_batch() 55 | _, y = valid_y.next_batch() 56 | 57 | pos_x = torch.arange(x.shape[1]).unsqueeze(0).expand_as(x).cuda() 58 | pos_y = torch.arange(y.shape[1]).unsqueeze(0).expand_as(y).cuda() 59 | 60 | logits = model(x, pos_x, y, pos_y) 61 | # print(logits.shape) 62 | 63 | loss = 0 64 | for i in range(x.shape[0]): 65 | loss += loss_layer(logits[i], y[i, 1:]) 66 | loss /= x.shape[0] 67 | return loss 68 | 69 | 70 | def train(train_x, train_y, valid_x, valid_y, model, optimizer, scheduler, epochs=1, epoch=0): 71 | logging.info("Start to train...") 72 | n_batches = train_x.steps 73 | writer = SummaryWriter() 74 | model.train() 75 | for epoch in range(epoch, epochs): 76 | for idx in range(n_batches): 77 | optimizer.zero_grad() 78 | 79 | loss = run_batch(train_x, train_y, model) 80 | loss.backward() # do not use retain_graph=True 81 | torch.nn.utils.clip_grad_value_(model.parameters(), 5) 82 | 83 | optimizer.step() 84 | 85 | if (idx + 1) % 50 == 0: 86 | train_loss = loss.cpu().detach().numpy() 87 | model.eval() 88 | valid_loss = run_batch(valid_x, valid_y, model) 89 | logging.info('epoch %d, step %d, training loss = %f, validation loss = %f' 90 | % (epoch, idx + 1, train_loss, valid_loss)) 91 | writer.add_scalar('scalar/epoch_%d/train_loss' % epoch, train_loss, (idx+1)//50) 92 | writer.add_scalar('scalar/epoch_%d/valid_loss' % epoch, valid_loss, (idx+1)//50) 93 | model.train() 94 | del loss 95 | scheduler.step() 96 | 97 | saved_state = {'state_dict': model.state_dict(), 98 | 'epoch': epoch + 1, # +1 for the next epoch 99 | 'lr': optimizer.param_groups[0]['lr']} 100 | 101 | torch.save(saved_state, os.path.join(model_dir, 'params_v1_%d.pkl' % epoch)) 102 | logging.info('Model saved in dir %s' % model_dir) 103 | 104 | 105 | def main(): 106 | print(args) 107 | 108 | data_dir = '/home/tiankeke/workspace/datas/sumdata/' 109 | TRAIN_X = os.path.join(data_dir, 'train/train.article.txt') 110 | TRAIN_Y = os.path.join(data_dir, 'train/train.title.txt') 111 | VALID_X = os.path.join(data_dir, 'train/valid.article.filter.txt') 112 | VALID_Y = os.path.join(data_dir, 'train/valid.title.filter.txt') 113 | 114 | src_vocab, tgt_vocab = get_vocab(TRAIN_X, TRAIN_Y) 115 | 116 | small_vocab_file = 'sumdata/small_vocab.json' 117 | if os.path.exists(small_vocab_file): 118 | small_vocab = json.load(open(small_vocab_file)) 119 | else: 120 | small_vocab = build_vocab([TRAIN_X, TRAIN_Y], small_vocab_file, vocab_size=80000) 121 | 122 | max_src_len = 101 123 | max_tgt_len = 47 124 | bs = args.batch_size 125 | n_train = args.n_train 126 | n_valid = args.n_valid 127 | 128 | vocab = small_vocab 129 | 130 | train_x = BatchManager(load_data(TRAIN_X, max_src_len, n_train), bs, vocab) 131 | train_y = BatchManager(load_data(TRAIN_Y, max_tgt_len, n_train), bs, vocab) 132 | valid_x = BatchManager(load_data(VALID_X, max_src_len, n_valid), bs, vocab) 133 | valid_y = BatchManager(load_data(VALID_Y, max_tgt_len, n_valid), bs, vocab) 134 | 135 | model = 
Transformer(len(vocab), len(vocab), max_src_len, d_word_vec=300, 136 | d_model=300, d_inner=1200, n_layers=1, n_head=6, d_k=50, 137 | d_v=50, dropout=0.1, tgt_emb_prj_weight_sharing=True, 138 | emb_src_tgt_weight_sharing=True).cuda() 139 | # print(model) 140 | 141 | saved_state = {'epoch': 0, 'lr': 0.001} 142 | if os.path.exists(args.ckpt_file): 143 | saved_state = torch.load(args.ckpt_file) 144 | model.load_state_dict(saved_state['state_dict']) 145 | logging.info('Load model parameters from %s' % args.ckpt_file) 146 | 147 | parameters = filter(lambda p : p.requires_grad, model.parameters()) 148 | optimizer = torch.optim.Adam(parameters, lr=saved_state['lr']) 149 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.3) 150 | scheduler.step() # last_epoch=-1, which will not update lr at the first time 151 | 152 | train(train_x, train_y, valid_x, valid_y, model, optimizer, scheduler, 153 | args.n_epochs, saved_state['epoch']) 154 | 155 | 156 | if __name__ == '__main__': 157 | main() 158 | #TODO 159 | # 使用Pycharm,逐过程查看内部状态,看哪一步的结果值很小,可能是该步出问题 160 | 161 | -------------------------------------------------------------------------------- /train_elmo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import logging 4 | import json 5 | import torch 6 | import argparse 7 | from Transformer import ElmoTransformer 8 | from tensorboardX import SummaryWriter 9 | from utils import BatchManager, load_data, get_vocab, build_vocab, load_word2vec_embedding, dump_tensors 10 | from pyrouge import Rouge155 11 | from translate import greedy, print_summaries 12 | 13 | parser = argparse.ArgumentParser(description='Selective Encoding for Abstractive Sentence Summarization in DyNet') 14 | 15 | parser.add_argument('--n_epochs', type=int, default=10, help='Number of epochs [default: 3]') 16 | parser.add_argument('--n_train', type=int, default=3803900, 17 | help='Number of training data (up to 3803957 in gigaword) [default: 3803957]') 18 | parser.add_argument('--n_valid', type=int, default=189651, 19 | help='Number of validation data (up to 189651 in gigaword) [default: 189651])') 20 | parser.add_argument('--batch_size', type=int, default=32, help='Mini batch size [default: 32]') 21 | parser.add_argument('--ckpt_file', type=str, default='models/elmo_small_2L_8H_512_0.pkl') 22 | args = parser.parse_args() 23 | 24 | save_name = "elmo_small_2L_8H_512" 25 | 26 | logging.basicConfig( 27 | level=logging.INFO, 28 | format='%(asctime)s - %(levelname)s - %(message)s', 29 | filename='log/train_%s.log' % save_name, 30 | filemode='w' 31 | ) 32 | 33 | # define a new Handler to log to console as well 34 | console = logging.StreamHandler() 35 | # optional, set the logging level 36 | console.setLevel(logging.INFO) 37 | # set a format which is the same for console use 38 | formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') 39 | # tell the handler to use this format 40 | console.setFormatter(formatter) 41 | # add the handler to the root logger 42 | logging.getLogger('').addHandler(console) 43 | 44 | model_dir = './models' 45 | if not os.path.exists(model_dir): 46 | os.mkdir(model_dir) 47 | 48 | 49 | def run_batch(batch_x, batch_y, model): 50 | x_stncs, x_ids = batch_x.next_batch() 51 | y_stncs, y_ids = batch_y.next_batch() 52 | 53 | logits = model(x_stncs, y_stncs, x_ids, y_ids) 54 | loss = model.loss_layer(logits.view(-1, logits.shape[-1]), 55 | y_ids[:, 1:].contiguous().view(-1)) 56 | return loss 57 | 58 | 59 | def 
eval_model(valid_x, valid_y, vocab, model): 60 | # if we put the following part outside the code, 61 | # error occurs 62 | r = Rouge155() 63 | r.system_dir = 'tmp/systems' 64 | r.model_dir = 'tmp/models' 65 | r.system_filename_pattern = "(\d+).txt" 66 | r.model_filename_pattern = "[A-Z].#ID#.txt" 67 | 68 | logging.info('Evaluating on a minibatch...') 69 | model.eval() 70 | x = valid_x.next_batch() 71 | with torch.no_grad(): 72 | pred = greedy(model, x, vocab) 73 | y = valid_y.next_batch()[:,1:].tolist() 74 | print_summaries(pred, vocab, 'tmp/systems', '%d.txt') 75 | print_summaries(y, vocab, 'tmp/models', 'A.%d.txt') 76 | 77 | try: 78 | output = r.convert_and_evaluate() 79 | output_dict = r.output_to_dict(output) 80 | logging.info('Rouge1-F: %f, Rouge2-F: %f, RougeL-F: %f' 81 | % (output_dict['rouge_1_f_score'], 82 | output_dict['rouge_2_f_score'], 83 | output_dict['rouge_l_f_score'])) 84 | except Exception as e: 85 | logging.info('Failed to evaluate') 86 | 87 | model.train() 88 | 89 | 90 | def train(train_x, train_y, valid_x, valid_y, model, 91 | optimizer, vocab, scheduler, n_epochs=1, epoch=0): 92 | logging.info("Start to train with lr=%f..." % optimizer.param_groups[0]['lr']) 93 | n_batches = train_x.steps 94 | model.train() 95 | for epoch in range(epoch, n_epochs): 96 | valid_x.bid = 0 97 | valid_y.bid = 0 98 | 99 | writer_dir = 'runs/%s_epoch%d' % (save_name, epoch) 100 | if os.path.isdir(writer_dir): 101 | shutil.rmtree(writer_dir) 102 | writer = SummaryWriter(writer_dir) 103 | 104 | for idx in range(n_batches): 105 | optimizer.zero_grad() 106 | 107 | loss = run_batch(train_x, train_y, model) 108 | loss.backward() # do not use retain_graph=True 109 | # torch.nn.utils.clip_grad_value_(model.parameters(), 5) 110 | 111 | optimizer.step() 112 | 113 | if (idx + 1) % 50 == 0: 114 | train_loss = loss.cpu().detach().numpy() 115 | model.eval() 116 | with torch.no_grad(): 117 | valid_loss = run_batch(valid_x, valid_y, model) 118 | logging.info('epoch %d, step %d, training loss = %f, validation loss = %f' 119 | % (epoch, idx + 1, train_loss, valid_loss)) 120 | writer.add_scalar('scalar/train_loss', train_loss, (idx + 1) // 50) 121 | writer.add_scalar('scalar/valid_loss', valid_loss, (idx + 1) // 50) 122 | model.train() 123 | # torch.cuda.empty_cache() 124 | # if (idx + 1) % 2000 == 0: 125 | # eval_model(valid_x, valid_y, vocab, model) 126 | # dump_tensors() 127 | 128 | if epoch < 6: 129 | scheduler.step() # make sure lr will not be too small 130 | save_state = {'state_dict': model.state_dict(), 131 | 'epoch': epoch + 1, 132 | 'lr': optimizer.param_groups[0]['lr']} 133 | save_path = os.path.join(model_dir, '%s_epoch%d.pkl' % (save_name, epoch)) 134 | torch.save(save_state, save_path) 135 | logging.info('Model saved in file %s' % save_path) 136 | writer.close() 137 | 138 | 139 | def main(): 140 | print(args) 141 | 142 | data_dir = '/home/tiankeke/workspace/datas/sumdata/' 143 | TRAIN_X = os.path.join(data_dir, 'train/train.article.txt') 144 | TRAIN_Y = os.path.join(data_dir, 'train/train.title.txt') 145 | VALID_X = os.path.join(data_dir, 'train/valid.article.filter.txt') 146 | VALID_Y = os.path.join(data_dir, 'train/valid.title.filter.txt') 147 | 148 | small_vocab_file = 'sumdata/small_vocab.json' 149 | if os.path.exists(small_vocab_file): 150 | small_vocab = json.load(open(small_vocab_file)) 151 | else: 152 | small_vocab = build_vocab([TRAIN_X, TRAIN_Y], small_vocab_file, vocab_size=80000) 153 | 154 | # emb_file = '/home/tiankeke/workspace/embeddings/giga-vec1.bin' 155 | # vocab, embeddings = 
load_word2vec_embedding(emb_file) 156 | 157 | max_src_len = 60 158 | max_tgt_len = 20 159 | 160 | bs = args.batch_size 161 | n_train = args.n_train 162 | n_valid = args.n_valid 163 | 164 | vocab = small_vocab 165 | 166 | train_x = BatchManager(load_data(TRAIN_X, max_src_len, n_train), bs, vocab) 167 | train_y = BatchManager(load_data(TRAIN_Y, max_tgt_len, n_train), bs, vocab) 168 | valid_x = BatchManager(load_data(VALID_X, max_src_len, n_valid), bs, vocab) 169 | valid_y = BatchManager(load_data(VALID_Y, max_tgt_len, n_valid), bs, vocab) 170 | 171 | # model = Transformer(len(vocab), len(vocab), max_src_len, max_tgt_len, 1, 4, 256, 172 | # 64, 64, 1024, src_tgt_emb_share=True, tgt_prj_emb_share=True).cuda() 173 | # model = Transformer(len(vocab), len(vocab), max_src_len, max_tgt_len, 1, 6, 300, 174 | # 50, 50, 1200, src_tgt_emb_share=True, tgt_prj_emb_share=True).cuda() 175 | # elmo_requries_grad=True after epoch 3 176 | model = ElmoTransformer(max_src_len, len(vocab), 2, 8, 64, 64, 256, 512, 2048, 177 | dropout=0.5, elmo_requires_grad=False).cuda() 178 | 179 | # print(model) 180 | saved_state = {'epoch': 0, 'lr': 0.001} 181 | if os.path.exists(args.ckpt_file): 182 | saved_state = torch.load(args.ckpt_file) 183 | model.load_state_dict(saved_state['state_dict']) 184 | logging.info('Load model parameters from %s' % args.ckpt_file) 185 | 186 | parameters = filter(lambda p: p.requires_grad, model.parameters()) 187 | optimizer = torch.optim.Adam(parameters, lr=saved_state['lr']) 188 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5) 189 | scheduler.step() # last_epoch=-1, which will not update lr at the first time 190 | 191 | # eval_model(valid_x, valid_y, vocab, model) 192 | train(train_x, train_y, valid_x, valid_y, model, 193 | optimizer, vocab, scheduler, args.n_epochs, saved_state['epoch']) 194 | 195 | 196 | if __name__ == '__main__': 197 | main() 198 | # TODO 199 | # 使用Pycharm,逐过程查看内部状态,看哪一步的结果值很小,可能是该步出问题 200 | 201 | -------------------------------------------------------------------------------- /transformer/Beam.py: -------------------------------------------------------------------------------- 1 | """ Manage beam search info structure. 2 | 3 | Heavily borrowed from OpenNMT-py. 4 | For code in OpenNMT-py, please check the following link: 5 | https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/Beam.py 6 | """ 7 | 8 | import torch 9 | import numpy as np 10 | import transformer.Constants as Constants 11 | 12 | class Beam(): 13 | ''' Beam search ''' 14 | 15 | def __init__(self, size, device=False): 16 | 17 | self.size = size 18 | self._done = False 19 | 20 | # The score for each translation on the beam. 21 | self.scores = torch.zeros((size,), dtype=torch.float, device=device) 22 | self.all_scores = [] 23 | 24 | # The backpointers at each time-step. 25 | self.prev_ks = [] 26 | 27 | # The outputs at each time-step. 28 | self.next_ys = [torch.full((size,), Constants.PAD, dtype=torch.long, device=device)] 29 | self.next_ys[0][0] = Constants.BOS 30 | 31 | def get_current_state(self): 32 | "Get the outputs for the current timestep." 33 | return self.get_tentative_hypothesis() 34 | 35 | def get_current_origin(self): 36 | "Get the backpointers for the current timestep." 37 | return self.prev_ks[-1] 38 | 39 | @property 40 | def done(self): 41 | return self._done 42 | 43 | def advance(self, word_prob): 44 | "Update beam status and check if finished or not." 45 | num_words = word_prob.size(1) 46 | 47 | # Sum the previous scores. 
48 | if len(self.prev_ks) > 0: 49 | beam_lk = word_prob + self.scores.unsqueeze(1).expand_as(word_prob) 50 | else: 51 | beam_lk = word_prob[0] 52 | 53 | flat_beam_lk = beam_lk.view(-1) 54 | 55 | best_scores, best_scores_id = flat_beam_lk.topk(self.size, 0, True, True) # 1st sort 56 | best_scores, best_scores_id = flat_beam_lk.topk(self.size, 0, True, True) # 2nd sort 57 | 58 | self.all_scores.append(self.scores) 59 | self.scores = best_scores 60 | 61 | # bestScoresId is flattened as a (beam x word) array, 62 | # so we need to calculate which word and beam each score came from 63 | prev_k = best_scores_id / num_words 64 | self.prev_ks.append(prev_k) 65 | self.next_ys.append(best_scores_id - prev_k * num_words) 66 | 67 | # End condition is when top-of-beam is EOS. 68 | if self.next_ys[-1][0].item() == Constants.EOS: 69 | self._done = True 70 | self.all_scores.append(self.scores) 71 | 72 | return self._done 73 | 74 | def sort_scores(self): 75 | "Sort the scores." 76 | return torch.sort(self.scores, 0, True) 77 | 78 | def get_the_best_score_and_idx(self): 79 | "Get the score of the best in the beam." 80 | scores, ids = self.sort_scores() 81 | return scores[1], ids[1] 82 | 83 | def get_tentative_hypothesis(self): 84 | "Get the decoded sequence for the current timestep." 85 | 86 | if len(self.next_ys) == 1: 87 | dec_seq = self.next_ys[0].unsqueeze(1) 88 | else: 89 | _, keys = self.sort_scores() 90 | hyps = [self.get_hypothesis(k) for k in keys] 91 | hyps = [[Constants.BOS] + h for h in hyps] 92 | dec_seq = torch.LongTensor(hyps) 93 | 94 | return dec_seq 95 | 96 | def get_hypothesis(self, k): 97 | """ Walk back to construct the full hypothesis. """ 98 | hyp = [] 99 | for j in range(len(self.prev_ks) - 1, -1, -1): 100 | hyp.append(self.next_ys[j+1][k]) 101 | k = self.prev_ks[j][k] 102 | 103 | return list(map(lambda x: x.item(), hyp[::-1])) 104 | -------------------------------------------------------------------------------- /transformer/Constants.py: -------------------------------------------------------------------------------- 1 | 2 | BOS = 0 3 | EOS = 1 4 | UNK = 2 5 | PAD = 3 6 | 7 | PAD_WORD = '' 8 | UNK_WORD = '' 9 | BOS_WORD = '' 10 | EOS_WORD = '' 11 | -------------------------------------------------------------------------------- /transformer/Layers.py: -------------------------------------------------------------------------------- 1 | ''' Define the Layers ''' 2 | import torch.nn as nn 3 | from transformer.SubLayers import MultiHeadAttention, PositionwiseFeedForward 4 | 5 | __author__ = "Yu-Hsiang Huang" 6 | 7 | 8 | class EncoderLayer(nn.Module): 9 | ''' Compose with two layers ''' 10 | 11 | def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1): 12 | super(EncoderLayer, self).__init__() 13 | self.slf_attn = MultiHeadAttention( 14 | n_head, d_model, d_k, d_v, dropout=dropout) 15 | self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout) 16 | 17 | def forward(self, enc_input, non_pad_mask=None, slf_attn_mask=None): 18 | enc_output, enc_slf_attn = self.slf_attn( 19 | enc_input, enc_input, enc_input, mask=slf_attn_mask) 20 | enc_output *= non_pad_mask 21 | 22 | enc_output = self.pos_ffn(enc_output) 23 | enc_output *= non_pad_mask 24 | 25 | return enc_output, enc_slf_attn 26 | 27 | 28 | class DecoderLayer(nn.Module): 29 | ''' Compose with three layers ''' 30 | 31 | def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1): 32 | super(DecoderLayer, self).__init__() 33 | self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout) 
34 | self.enc_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout) 35 | self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout) 36 | 37 | def forward(self, dec_input, enc_output, non_pad_mask=None, slf_attn_mask=None, dec_enc_attn_mask=None): 38 | dec_output, dec_slf_attn = self.slf_attn( 39 | dec_input, dec_input, dec_input, mask=slf_attn_mask) 40 | # dec_output *= non_pad_mask 41 | 42 | dec_output, dec_enc_attn = self.enc_attn( 43 | dec_output, enc_output, enc_output, mask=dec_enc_attn_mask) 44 | # dec_output *= non_pad_mask 45 | 46 | dec_output = self.pos_ffn(dec_output) 47 | # dec_output *= non_pad_mask 48 | 49 | return dec_output, dec_slf_attn, dec_enc_attn 50 | -------------------------------------------------------------------------------- /transformer/Models.py: -------------------------------------------------------------------------------- 1 | ''' Define the Transformer model ''' 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | import transformer.Constants as Constants 6 | from transformer.Layers import EncoderLayer, DecoderLayer 7 | 8 | __author__ = "Yu-Hsiang Huang" 9 | 10 | def get_non_pad_mask(seq): 11 | assert seq.dim() == 2 12 | return seq.ne(Constants.PAD).type(torch.float).unsqueeze(-1) 13 | 14 | def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): 15 | ''' Sinusoid position encoding table ''' 16 | 17 | def cal_angle(position, hid_idx): 18 | return position / np.power(10000, 2 * (hid_idx // 2) / d_hid) 19 | 20 | def get_posi_angle_vec(position): 21 | return [cal_angle(position, hid_j) for hid_j in range(d_hid)] 22 | 23 | sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)]) 24 | 25 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 26 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 27 | 28 | if padding_idx is not None: 29 | # zero vector for padding dimension 30 | sinusoid_table[padding_idx] = 0. 31 | 32 | return torch.FloatTensor(sinusoid_table) 33 | 34 | def get_attn_key_pad_mask(seq_k, seq_q): 35 | ''' For masking out the padding part of key sequence. ''' 36 | 37 | # Expand to fit the shape of key query attention matrix. 38 | len_q = seq_q.size(1) 39 | padding_mask = seq_k.eq(Constants.PAD) 40 | padding_mask = padding_mask.unsqueeze(1).expand(-1, len_q, -1) # b x lq x lk 41 | 42 | return padding_mask 43 | 44 | def get_subsequent_mask(seq): 45 | ''' For masking out the subsequent info. ''' 46 | 47 | sz_b, len_s = seq.size() 48 | subsequent_mask = torch.triu( 49 | torch.ones((len_s, len_s), device=seq.device, dtype=torch.uint8), diagonal=1) 50 | subsequent_mask = subsequent_mask.unsqueeze(0).expand(sz_b, -1, -1) # b x ls x ls 51 | 52 | return subsequent_mask 53 | 54 | class Encoder(nn.Module): 55 | ''' A encoder model with self attention mechanism. 
''' 56 | 57 | def __init__( 58 | self, 59 | n_src_vocab, len_max_seq, d_word_vec, 60 | n_layers, n_head, d_k, d_v, 61 | d_model, d_inner, dropout=0.1): 62 | 63 | super().__init__() 64 | 65 | n_position = len_max_seq + 1 66 | 67 | self.src_word_emb = nn.Embedding( 68 | n_src_vocab, d_word_vec, padding_idx=Constants.PAD) 69 | 70 | self.position_enc = nn.Embedding.from_pretrained( 71 | get_sinusoid_encoding_table(n_position, d_word_vec, padding_idx=0), 72 | freeze=True) 73 | 74 | self.layer_stack = nn.ModuleList([ 75 | EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout) 76 | for _ in range(n_layers)]) 77 | 78 | def forward(self, src_seq, src_pos, return_attns=False): 79 | 80 | enc_slf_attn_list = [] 81 | 82 | # -- Prepare masks 83 | slf_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=src_seq) 84 | non_pad_mask = get_non_pad_mask(src_seq) 85 | 86 | # -- Forward 87 | enc_output = self.src_word_emb(src_seq) + self.position_enc(src_pos) 88 | 89 | for enc_layer in self.layer_stack: 90 | enc_output, enc_slf_attn = enc_layer( 91 | enc_output, 92 | non_pad_mask=non_pad_mask, 93 | slf_attn_mask=slf_attn_mask) 94 | if return_attns: 95 | enc_slf_attn_list += [enc_slf_attn] 96 | 97 | if return_attns: 98 | return enc_output, enc_slf_attn_list 99 | return enc_output, 100 | 101 | class Decoder(nn.Module): 102 | ''' A decoder model with self attention mechanism. ''' 103 | 104 | def __init__( 105 | self, 106 | n_tgt_vocab, len_max_seq, d_word_vec, 107 | n_layers, n_head, d_k, d_v, 108 | d_model, d_inner, dropout=0.1): 109 | 110 | super().__init__() 111 | n_position = len_max_seq + 1 112 | 113 | self.tgt_word_emb = nn.Embedding( 114 | n_tgt_vocab, d_word_vec, padding_idx=Constants.PAD) 115 | 116 | self.position_enc = nn.Embedding.from_pretrained( 117 | get_sinusoid_encoding_table(n_position, d_word_vec, padding_idx=0), 118 | freeze=True) 119 | 120 | self.layer_stack = nn.ModuleList([ 121 | DecoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout) 122 | for _ in range(n_layers)]) 123 | 124 | def forward(self, tgt_seq, tgt_pos, src_seq, enc_output, return_attns=False): 125 | 126 | dec_slf_attn_list, dec_enc_attn_list = [], [] 127 | 128 | # -- Prepare masks 129 | non_pad_mask = get_non_pad_mask(tgt_seq) 130 | 131 | slf_attn_mask_subseq = get_subsequent_mask(tgt_seq) 132 | slf_attn_mask_keypad = get_attn_key_pad_mask(seq_k=tgt_seq, seq_q=tgt_seq) 133 | slf_attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0) 134 | 135 | dec_enc_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=tgt_seq) 136 | 137 | # -- Forward 138 | dec_output = self.tgt_word_emb(tgt_seq) + self.position_enc(tgt_pos) 139 | 140 | for dec_layer in self.layer_stack: 141 | dec_output, dec_slf_attn, dec_enc_attn = dec_layer( 142 | dec_output, enc_output, 143 | non_pad_mask=non_pad_mask, 144 | slf_attn_mask=slf_attn_mask, 145 | dec_enc_attn_mask=dec_enc_attn_mask) 146 | 147 | if return_attns: 148 | dec_slf_attn_list += [dec_slf_attn] 149 | dec_enc_attn_list += [dec_enc_attn] 150 | 151 | if return_attns: 152 | return dec_output, dec_slf_attn_list, dec_enc_attn_list 153 | return dec_output, 154 | 155 | class Transformer(nn.Module): 156 | ''' A sequence to sequence model with attention mechanism. 
''' 157 | 158 | def __init__( 159 | self, 160 | n_src_vocab, n_tgt_vocab, len_max_seq, 161 | d_word_vec=512, d_model=512, d_inner=2048, 162 | n_layers=6, n_head=8, d_k=64, d_v=64, dropout=0.1, 163 | tgt_emb_prj_weight_sharing=True, 164 | emb_src_tgt_weight_sharing=True): 165 | 166 | super().__init__() 167 | 168 | self.encoder = Encoder( 169 | n_src_vocab=n_src_vocab, len_max_seq=len_max_seq, 170 | d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner, 171 | n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v, 172 | dropout=dropout) 173 | 174 | self.decoder = Decoder( 175 | n_tgt_vocab=n_tgt_vocab, len_max_seq=len_max_seq, 176 | d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner, 177 | n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v, 178 | dropout=dropout) 179 | 180 | self.tgt_word_prj = nn.Linear(d_model, n_tgt_vocab, bias=False) 181 | nn.init.xavier_normal_(self.tgt_word_prj.weight) 182 | 183 | assert d_model == d_word_vec, \ 184 | 'To facilitate the residual connections, \ 185 | the dimensions of all module outputs shall be the same.' 186 | 187 | if tgt_emb_prj_weight_sharing: 188 | # Share the weight matrix between target word embedding & the final logit dense layer 189 | self.tgt_word_prj.weight = self.decoder.tgt_word_emb.weight 190 | self.x_logit_scale = (d_model ** -0.5) 191 | else: 192 | self.x_logit_scale = 1. 193 | 194 | if emb_src_tgt_weight_sharing: 195 | # Share the weight matrix between source & target word embeddings 196 | assert n_src_vocab == n_tgt_vocab, \ 197 | "To share word embedding table, the vocabulary size of src/tgt shall be the same." 198 | self.encoder.src_word_emb.weight = self.decoder.tgt_word_emb.weight 199 | 200 | def forward(self, src_seq, src_pos, tgt_seq, tgt_pos): 201 | 202 | tgt_seq, tgt_pos = tgt_seq[:, :-1], tgt_pos[:, :-1] 203 | 204 | enc_output, *_ = self.encoder(src_seq, src_pos) 205 | dec_output, *_ = self.decoder(tgt_seq, tgt_pos, src_seq, enc_output) 206 | seq_logit = self.tgt_word_prj(dec_output) * self.x_logit_scale 207 | 208 | return seq_logit 209 | -------------------------------------------------------------------------------- /transformer/Modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | __author__ = "Yu-Hsiang Huang" 6 | 7 | class ScaledDotProductAttention(nn.Module): 8 | ''' Scaled Dot-Product Attention ''' 9 | 10 | def __init__(self, temperature, attn_dropout=0.1): 11 | super().__init__() 12 | self.temperature = temperature 13 | self.dropout = nn.Dropout(attn_dropout) 14 | self.softmax = nn.Softmax(dim=2) 15 | 16 | def forward(self, q, k, v, mask=None): 17 | 18 | attn = torch.bmm(q, k.transpose(1, 2)) 19 | attn = attn / self.temperature 20 | 21 | if mask is not None: 22 | attn = attn.masked_fill(mask, -np.inf) 23 | 24 | attn = self.softmax(attn) 25 | attn = self.dropout(attn) 26 | output = torch.bmm(attn, v) 27 | 28 | return output, attn 29 | -------------------------------------------------------------------------------- /transformer/Optim.py: -------------------------------------------------------------------------------- 1 | '''A wrapper class for optimizer ''' 2 | import numpy as np 3 | 4 | class ScheduledOptim(): 5 | '''A simple wrapper class for learning rate scheduling''' 6 | 7 | def __init__(self, optimizer, d_model, n_warmup_steps): 8 | self._optimizer = optimizer 9 | self.n_warmup_steps = n_warmup_steps 10 | self.n_current_steps = 0 11 | self.init_lr = np.power(d_model, -0.5) 12 | 13 | def 
step_and_update_lr(self): 14 | "Step with the inner optimizer" 15 | self._update_learning_rate() 16 | self._optimizer.step() 17 | 18 | def zero_grad(self): 19 | "Zero out the gradients by the inner optimizer" 20 | self._optimizer.zero_grad() 21 | 22 | def _get_lr_scale(self): 23 | return np.min([ 24 | np.power(self.n_current_steps, -0.5), 25 | np.power(self.n_warmup_steps, -1.5) * self.n_current_steps]) 26 | 27 | def _update_learning_rate(self): 28 | ''' Learning rate scheduling per step ''' 29 | 30 | self.n_current_steps += 1 31 | lr = self.init_lr * self._get_lr_scale() 32 | 33 | for param_group in self._optimizer.param_groups: 34 | param_group['lr'] = lr 35 | 36 | -------------------------------------------------------------------------------- /transformer/SubLayers.py: -------------------------------------------------------------------------------- 1 | ''' Define the sublayers in encoder/decoder layer ''' 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from transformer.Modules import ScaledDotProductAttention 6 | 7 | __author__ = "Yu-Hsiang Huang" 8 | 9 | class MultiHeadAttention(nn.Module): 10 | ''' Multi-Head Attention module ''' 11 | 12 | def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1): 13 | super().__init__() 14 | 15 | self.n_head = n_head 16 | self.d_k = d_k 17 | self.d_v = d_v 18 | 19 | self.w_qs = nn.Linear(d_model, n_head * d_k) 20 | self.w_ks = nn.Linear(d_model, n_head * d_k) 21 | self.w_vs = nn.Linear(d_model, n_head * d_v) 22 | nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k))) 23 | nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k))) 24 | nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v))) 25 | 26 | self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5)) 27 | self.layer_norm = nn.LayerNorm(d_model) 28 | 29 | self.fc = nn.Linear(n_head * d_v, d_model) 30 | nn.init.xavier_normal_(self.fc.weight) 31 | 32 | self.dropout = nn.Dropout(dropout) 33 | 34 | 35 | def forward(self, q, k, v, mask=None): 36 | 37 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head 38 | 39 | sz_b, len_q, _ = q.size() 40 | sz_b, len_k, _ = k.size() 41 | sz_b, len_v, _ = v.size() 42 | 43 | residual = q 44 | 45 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) 46 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) 47 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v) 48 | 49 | q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk 50 | k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk 51 | v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv 52 | 53 | mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x .. 
54 | output, attn = self.attention(q, k, v, mask=mask) 55 | 56 | output = output.view(n_head, sz_b, len_q, d_v) 57 | output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv) 58 | 59 | output = self.dropout(self.fc(output)) 60 | output = self.layer_norm(output + residual) 61 | 62 | return output, attn 63 | 64 | class PositionwiseFeedForward(nn.Module): 65 | ''' A two-feed-forward-layer module ''' 66 | 67 | def __init__(self, d_in, d_hid, dropout=0.1): 68 | super().__init__() 69 | self.w_1 = nn.Conv1d(d_in, d_hid, 1) # position-wise 70 | self.w_2 = nn.Conv1d(d_hid, d_in, 1) # position-wise 71 | self.layer_norm = nn.LayerNorm(d_in) 72 | self.dropout = nn.Dropout(dropout) 73 | 74 | def forward(self, x): 75 | residual = x 76 | output = x.transpose(1, 2) 77 | output = self.w_2(F.relu(self.w_1(output))) 78 | output = output.transpose(1, 2) 79 | output = self.dropout(output) 80 | output = self.layer_norm(output + residual) 81 | return output 82 | -------------------------------------------------------------------------------- /transformer/Translator.py: -------------------------------------------------------------------------------- 1 | ''' This module will handle the text generation with beam search. ''' 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from transformer.Models import Transformer 8 | from transformer.Beam import Beam 9 | 10 | class Translator(object): 11 | ''' Load with trained model and handle the beam search ''' 12 | 13 | def __init__(self, opt): 14 | self.opt = opt 15 | self.device = torch.device('cuda' if opt.cuda else 'cpu') 16 | 17 | checkpoint = torch.load(opt.model) 18 | model_opt = checkpoint['settings'] 19 | self.model_opt = model_opt 20 | 21 | model = Transformer( 22 | model_opt.src_vocab_size, 23 | model_opt.tgt_vocab_size, 24 | model_opt.max_token_seq_len, 25 | tgt_emb_prj_weight_sharing=model_opt.proj_share_weight, 26 | emb_src_tgt_weight_sharing=model_opt.embs_share_weight, 27 | d_k=model_opt.d_k, 28 | d_v=model_opt.d_v, 29 | d_model=model_opt.d_model, 30 | d_word_vec=model_opt.d_word_vec, 31 | d_inner=model_opt.d_inner_hid, 32 | n_layers=model_opt.n_layers, 33 | n_head=model_opt.n_head, 34 | dropout=model_opt.dropout) 35 | 36 | model.load_state_dict(checkpoint['model']) 37 | print('[Info] Trained model state loaded.') 38 | 39 | model.word_prob_prj = nn.LogSoftmax(dim=1) 40 | 41 | model = model.to(self.device) 42 | 43 | self.model = model 44 | self.model.eval() 45 | 46 | def translate_batch(self, src_seq, src_pos): 47 | ''' Translation work in one batch ''' 48 | 49 | def get_inst_idx_to_tensor_position_map(inst_idx_list): 50 | ''' Indicate the position of an instance in a tensor. ''' 51 | return {inst_idx: tensor_position for tensor_position, inst_idx in enumerate(inst_idx_list)} 52 | 53 | def collect_active_part(beamed_tensor, curr_active_inst_idx, n_prev_active_inst, n_bm): 54 | ''' Collect tensor parts associated to active instances. 
''' 55 | 56 | _, *d_hs = beamed_tensor.size() 57 | n_curr_active_inst = len(curr_active_inst_idx) 58 | new_shape = (n_curr_active_inst * n_bm, *d_hs) 59 | 60 | beamed_tensor = beamed_tensor.view(n_prev_active_inst, -1) 61 | beamed_tensor = beamed_tensor.index_select(0, curr_active_inst_idx) 62 | beamed_tensor = beamed_tensor.view(*new_shape) 63 | 64 | return beamed_tensor 65 | 66 | def collate_active_info( 67 | src_seq, src_enc, inst_idx_to_position_map, active_inst_idx_list): 68 | # Sentences which are still active are collected, 69 | # so the decoder will not run on completed sentences. 70 | n_prev_active_inst = len(inst_idx_to_position_map) 71 | active_inst_idx = [inst_idx_to_position_map[k] for k in active_inst_idx_list] 72 | active_inst_idx = torch.LongTensor(active_inst_idx).to(self.device) 73 | 74 | active_src_seq = collect_active_part(src_seq, active_inst_idx, n_prev_active_inst, n_bm) 75 | active_src_enc = collect_active_part(src_enc, active_inst_idx, n_prev_active_inst, n_bm) 76 | active_inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list) 77 | 78 | return active_src_seq, active_src_enc, active_inst_idx_to_position_map 79 | 80 | def beam_decode_step( 81 | inst_dec_beams, len_dec_seq, src_seq, enc_output, inst_idx_to_position_map, n_bm): 82 | ''' Decode and update beam status, and then return active beam idx ''' 83 | 84 | def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq): 85 | dec_partial_seq = [b.get_current_state() for b in inst_dec_beams if not b.done] 86 | dec_partial_seq = torch.stack(dec_partial_seq).to(self.device) 87 | dec_partial_seq = dec_partial_seq.view(-1, len_dec_seq) 88 | return dec_partial_seq 89 | 90 | def prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm): 91 | dec_partial_pos = torch.arange(1, len_dec_seq + 1, dtype=torch.long, device=self.device) 92 | dec_partial_pos = dec_partial_pos.unsqueeze(0).repeat(n_active_inst * n_bm, 1) 93 | return dec_partial_pos 94 | 95 | def predict_word(dec_seq, dec_pos, src_seq, enc_output, n_active_inst, n_bm): 96 | dec_output, *_ = self.model.decoder(dec_seq, dec_pos, src_seq, enc_output) 97 | dec_output = dec_output[:, -1, :] # Pick the last step: (bh * bm) * d_h 98 | word_prob = F.log_softmax(self.model.tgt_word_prj(dec_output), dim=1) 99 | word_prob = word_prob.view(n_active_inst, n_bm, -1) 100 | 101 | return word_prob 102 | 103 | def collect_active_inst_idx_list(inst_beams, word_prob, inst_idx_to_position_map): 104 | active_inst_idx_list = [] 105 | for inst_idx, inst_position in inst_idx_to_position_map.items(): 106 | is_inst_complete = inst_beams[inst_idx].advance(word_prob[inst_position]) 107 | if not is_inst_complete: 108 | active_inst_idx_list += [inst_idx] 109 | 110 | return active_inst_idx_list 111 | 112 | n_active_inst = len(inst_idx_to_position_map) 113 | 114 | dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq) 115 | dec_pos = prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm) 116 | word_prob = predict_word(dec_seq, dec_pos, src_seq, enc_output, n_active_inst, n_bm) 117 | 118 | # Update the beam with predicted word prob information and collect incomplete instances 119 | active_inst_idx_list = collect_active_inst_idx_list( 120 | inst_dec_beams, word_prob, inst_idx_to_position_map) 121 | 122 | return active_inst_idx_list 123 | 124 | def collect_hypothesis_and_scores(inst_dec_beams, n_best): 125 | all_hyp, all_scores = [], [] 126 | for inst_idx in range(len(inst_dec_beams)): 127 | scores, tail_idxs = inst_dec_beams[inst_idx].sort_scores() 128 | all_scores += 
[scores[:n_best]] 129 | 130 | hyps = [inst_dec_beams[inst_idx].get_hypothesis(i) for i in tail_idxs[:n_best]] 131 | all_hyp += [hyps] 132 | return all_hyp, all_scores 133 | 134 | with torch.no_grad(): 135 | #-- Encode 136 | src_seq, src_pos = src_seq.to(self.device), src_pos.to(self.device) 137 | src_enc, *_ = self.model.encoder(src_seq, src_pos) 138 | 139 | #-- Repeat data for beam search 140 | n_bm = self.opt.beam_size 141 | n_inst, len_s, d_h = src_enc.size() 142 | src_seq = src_seq.repeat(1, n_bm).view(n_inst * n_bm, len_s) 143 | src_enc = src_enc.repeat(1, n_bm, 1).view(n_inst * n_bm, len_s, d_h) 144 | 145 | #-- Prepare beams 146 | inst_dec_beams = [Beam(n_bm, device=self.device) for _ in range(n_inst)] 147 | 148 | #-- Bookkeeping for active or not 149 | active_inst_idx_list = list(range(n_inst)) 150 | inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list) 151 | 152 | #-- Decode 153 | for len_dec_seq in range(1, self.model_opt.max_token_seq_len + 1): 154 | 155 | active_inst_idx_list = beam_decode_step( 156 | inst_dec_beams, len_dec_seq, src_seq, src_enc, inst_idx_to_position_map, n_bm) 157 | 158 | if not active_inst_idx_list: 159 | break # all instances have finished their path to 160 | 161 | src_seq, src_enc, inst_idx_to_position_map = collate_active_info( 162 | src_seq, src_enc, inst_idx_to_position_map, active_inst_idx_list) 163 | 164 | batch_hyp, batch_scores = collect_hypothesis_and_scores(inst_dec_beams, self.opt.n_best) 165 | 166 | return batch_hyp, batch_scores 167 | -------------------------------------------------------------------------------- /transformer/__init__.py: -------------------------------------------------------------------------------- 1 | import transformer.Constants 2 | import transformer.Modules 3 | import transformer.Layers 4 | import transformer.SubLayers 5 | import transformer.Models 6 | import transformer.Translator 7 | import transformer.Beam 8 | import transformer.Optim 9 | 10 | __all__ = [ 11 | transformer.Constants, transformer.Modules, transformer.Layers, 12 | transformer.SubLayers, transformer.Models, transformer.Optim, 13 | transformer.Translator, transformer.Beam] 14 | -------------------------------------------------------------------------------- /translate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | from Beam import Beam 5 | import torch.nn.functional as F 6 | import time 7 | 8 | 9 | def print_summaries(summaries, vocab, output_dir, pattern='%d.txt'): 10 | """ 11 | param summaries: in shape (seq_len, batch) 12 | """ 13 | i2w = {key: value for value, key in vocab.items()} 14 | i2w[vocab['']] = 'UNK' 15 | 16 | for idx in range(len(summaries)): 17 | fout = open(os.path.join(output_dir, pattern % idx), "w") 18 | line = [summaries[idx][0]] 19 | for tok in summaries[idx][1:]: 20 | if tok in [vocab[''], vocab['']]: 21 | break 22 | if tok != line[-1]: 23 | line.append(tok) 24 | if len(line)==0: 25 | line.append(3) # 3 for unk 26 | line = [i2w[tok] for tok in line] 27 | fout.write(" ".join(line[1:]) + "\n") 28 | fout.close() 29 | 30 | 31 | def greedy(model, x, tgt_vocab, max_trg_len=20, repl_unk=False): 32 | y = torch.ones(len(x), max_trg_len, dtype=torch.long).cuda() * tgt_vocab[""] 33 | y[:, 0] = tgt_vocab[""] 34 | enc_outputs = model.encode(x) 35 | # print(enc_outputs.shape) 36 | for i in range(max_trg_len - 1): 37 | logits, dec_enc_attns = model.decode(enc_outputs, x, y[:, :i+1]) 38 | y[:, i + 1] = torch.argmax(logits[:, i, 
:], dim=-1) 39 | if repl_unk: 40 | argmax = dec_enc_attns[:,i,:].argmax(dim=-1) 41 | for j in range(y.shape[0]): 42 | if int(y[j,i+1].cpu().detach()) == tgt_vocab['']: 43 | y[j,i+1] == x[j, int(argmax[j].cpu().detach())] 44 | return y.detach().cpu().tolist(), dec_enc_attns 45 | 46 | 47 | def beam_search(model, batch_x, vocab, max_trg_len=18, k=3): 48 | beams = [Beam(k, vocab, max_trg_len) for _ in range(batch_x.shape[0])] 49 | enc_outputs = model.encode(batch_x) 50 | 51 | for i in range(max_trg_len): 52 | todo = [j for j in range(len(beams)) if not beams[j].done] 53 | xs = torch.cat([batch_x[j].unsqueeze(0).expand(k, -1) for j in todo], dim=0) 54 | ys = torch.cat([beams[j].get_sequence() for j in todo], dim=0) 55 | enc_outs = torch.cat([enc_outputs[j].unsqueeze(0).expand(k, -1, -1) for j in todo], dim=0) 56 | logits, *_ = model.decode(enc_outs, xs, ys[:, :i+1]) 57 | log_probs = torch.log(F.softmax(logits[:, i, :], -1)) 58 | idx = 0 59 | for j in todo: 60 | beams[j].advance_v1(log_probs[idx: idx+k]) 61 | idx += k 62 | 63 | allHyp = [b.get_hyp().cpu().numpy() for b in beams] 64 | return allHyp 65 | 66 | 67 | def translate(valid_x, model, tgt_vocab, search='greedy', beam_width=5): 68 | summaries = [] 69 | model.eval() 70 | start = time.time() 71 | with torch.no_grad(): 72 | for i in range(valid_x.steps): 73 | _, batch_x = valid_x.next_batch() 74 | if search == "greedy": 75 | summary, dec_enc_attns = greedy(model, batch_x, tgt_vocab) 76 | elif search == "beam": 77 | summary = beam_search(model, batch_x, tgt_vocab, k=beam_width) 78 | else: 79 | raise NameError("Unknown search method") 80 | summaries.extend(summary) 81 | end = time.time() 82 | print('%.1f seconds spent, speed=%f/seconds' % (end-start, len(valid_x.data)/(end-start))) 83 | return summaries 84 | 85 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | from collections import defaultdict 4 | import torch 5 | import os 6 | import config 7 | import random 8 | 9 | pad_tok = config.pad_tok 10 | start_tok = config.start_tok 11 | end_tok = config.end_tok 12 | unk_tok = config.unk_tok 13 | 14 | pad_index = config.pad_index 15 | 16 | """ Caution: 17 | In training data, unk_tok='', but in test data, unk_tok='UNK'. 
18 | This is reasonable, because if the unk_tok you predict is the same as the 19 | unk_tok in the test data, then your prediction would be regard as correct, 20 | but since unk_tok is unknown, it's impossible to give a correct prediction 21 | """ 22 | 23 | 24 | def my_pad_sequence(batch, pad_tok): 25 | max_len = max([len(b) for b in batch]) 26 | batch = [b + [pad_tok] * (max_len - len(b)) for b in batch] 27 | return batch 28 | 29 | 30 | def shuffle(bm1, bm2): 31 | c = list(zip(bm1.data, bm2.data)) 32 | random.shuffle(c) 33 | bm1.data, bm2.data = zip(*c) 34 | return bm1, bm2 35 | 36 | 37 | class BatchManager: 38 | def __init__(self, data, batch_size, vocab): 39 | self.steps = int(len(data) / batch_size) 40 | # comment following two lines to neglect the last batch 41 | if self.steps * batch_size < len(data): 42 | self.steps += 1 43 | self.vocab = vocab 44 | self.data = data 45 | self.batch_size = batch_size 46 | self.bid = 0 47 | 48 | def next_batch(self, pad_flag=True, cuda=True): 49 | stncs = list(self.data[self.bid * self.batch_size: (self.bid + 1) * self.batch_size]) 50 | if pad_flag: 51 | stncs = my_pad_sequence(stncs, pad_tok) 52 | ids = [[self.vocab.get(tok, self.vocab[unk_tok]) for tok in stnc] for stnc in stncs] 53 | ids = torch.tensor(ids) 54 | self.bid += 1 55 | if self.bid == self.steps: 56 | self.bid = 0 57 | return stncs, ids.cuda() if cuda else ids 58 | 59 | 60 | def build_vocab(filelist=['sumdata/train/train.article.txt', 'sumdata/train/train.title.txt'], 61 | vocab_file='sumdata/vocab.json', min_count=0, vocab_size=1e9): 62 | print("Building vocab with min_count=%d..." % min_count) 63 | word_freq = defaultdict(int) 64 | for file in filelist: 65 | fin = open(file, "r", encoding="utf8") 66 | for _, line in enumerate(fin): 67 | for word in line.strip().split(): 68 | word_freq[word] += 1 69 | fin.close() 70 | print('Number of all words: %d' % len(word_freq)) 71 | 72 | if unk_tok in word_freq: 73 | word_freq.pop(unk_tok) 74 | sorted_freq = sorted(word_freq.items(), key=lambda x: x[1], reverse=True) 75 | 76 | vocab = {pad_tok: 0, start_tok: 1, end_tok: 2, unk_tok: 3} 77 | for word, freq in sorted_freq: 78 | if freq > min_count: 79 | vocab[word] = len(vocab) 80 | if len(vocab) == vocab_size: 81 | break 82 | print('Number of filtered words: %d, %f%% ' % (len(vocab), len(vocab)/len(word_freq)*100)) 83 | 84 | json.dump(vocab, open(vocab_file,'w')) 85 | return vocab 86 | 87 | 88 | def load_vocab(vocab_file, vocab_size=None): 89 | fin = open(vocab_file) 90 | vocab = {} 91 | if vocab_size is None: 92 | vocab_size = int(1e9) # means the whole vocab 93 | for word in list(fin.readlines())[:vocab_size]: 94 | vocab[word.strip()] = len(vocab) # [:-1] to remove \n 95 | fin.close() 96 | return vocab 97 | 98 | 99 | def get_vocab(TRAIN_X, TRAIN_Y): 100 | src_vocab_file = "sumdata/src_vocab.json" 101 | if not os.path.exists(src_vocab_file): 102 | src_vocab = build_vocab([TRAIN_X], src_vocab_file) 103 | else: 104 | src_vocab = json.load(open(src_vocab_file)) 105 | 106 | tgt_vocab_file = "sumdata/tgt_vocab.json" 107 | if not os.path.exists(tgt_vocab_file): 108 | tgt_vocab = build_vocab([TRAIN_Y], tgt_vocab_file) 109 | else: 110 | tgt_vocab = json.load(open(tgt_vocab_file)) 111 | return src_vocab, tgt_vocab 112 | 113 | 114 | def load_data(filename, max_len, n_data=None): 115 | """ 116 | :param filename: the file to read 117 | :param max_len: maximum length of a line 118 | :param vocab: dict {word: id}, if no vocab provided, return raw text 119 | :param n_data: number of lines to read 120 | :return: 
datas 121 | """ 122 | fin = open(filename, "r", encoding="utf8") 123 | datas = [] 124 | for idx, line in enumerate(fin): 125 | if idx == n_data or line == '': 126 | break 127 | words = line.strip().split() 128 | if len(words) > max_len - 2: 129 | words = words[:max_len-2] 130 | words = [''] + words + [''] 131 | datas.append(words) 132 | return datas 133 | 134 | --------------------------------------------------------------------------------