├── Constants.py ├── Decoders.py ├── EncDecModel.py ├── MAL.py ├── MALCopyDataset.py ├── MALDataset.py ├── README.md ├── Rouge.py ├── Run_CaKe.py ├── Utils.py ├── bleu.py ├── data ├── CopyDataset.py ├── Dataset.py ├── Utils.py └── __init__.py ├── dataset └── data.zip ├── modules ├── Attentions.py ├── Criterions.py ├── Decoders.py ├── Encoders.py ├── Evaluations.py ├── Generations.py ├── Generators.py ├── Utils.py └── __init__.py └── trainers ├── DefaultTrainer.py └── __init__.py /Constants.py: -------------------------------------------------------------------------------- 1 | PAD_WORD = '' 2 | PAD=0 3 | BOS_WORD = '' 4 | BOS=1 5 | UNK_WORD = '' 6 | UNK = 2 7 | EOS_WORD = '' 8 | EOS=3 9 | LINESEP_WORD = '' 10 | LINESEP=4 11 | 12 | id2vocab=dict({PAD:PAD_WORD, BOS:BOS_WORD, UNK:UNK_WORD, EOS:EOS_WORD, LINESEP:LINESEP_WORD}) 13 | vocab2id=dict({PAD_WORD:PAD, BOS_WORD:BOS, UNK_WORD:UNK, EOS_WORD:EOS, LINESEP_WORD:LINESEP}) 14 | 15 | -------------------------------------------------------------------------------- /Decoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from modules.Attentions import * 4 | from modules.Utils import * 5 | from Constants import * 6 | 7 | class BBCDecoder(nn.Module): 8 | def __init__(self, embedding_size, hidden_size, tgt_vocab_size, embedding=None, num_layers=4, dropout=0.5): 9 | super(BBCDecoder, self).__init__() 10 | 11 | # Keep for reference 12 | self.hidden_size = hidden_size 13 | self.embedding_size=embedding_size 14 | self.tgt_vocab_size = tgt_vocab_size 15 | 16 | if embedding is not None: 17 | self.embedding =embedding 18 | else: 19 | self.embedding = nn.Embedding(tgt_vocab_size, embedding_size, padding_idx=PAD) 20 | self.embedding_dropout = nn.Dropout(dropout) 21 | 22 | self.src_attn = BilinearAttention( 23 | query_size=hidden_size, key_size=2*hidden_size, hidden_size=hidden_size, dropout=dropout, coverage=False 24 | ) 25 | self.bg_attn = BilinearAttention( 26 | query_size=hidden_size, key_size=2 * hidden_size, hidden_size=hidden_size, dropout=dropout, coverage=False 27 | ) #background attention score 28 | 29 | self.gru = nn.GRU(2*hidden_size+2*hidden_size+embedding_size, hidden_size, bidirectional=False, num_layers=num_layers, dropout=dropout) 30 | 31 | self.readout = nn.Linear(embedding_size + hidden_size + 2*hidden_size+ 2*hidden_size, hidden_size) 32 | 33 | def forward(self, tgt, state, src_output, bg_output, src_mask=None, bg_mask=None): 34 | gru_state = state[0] 35 | 36 | embedded = self.embedding(tgt) 37 | embedded = self.embedding_dropout(embedded) 38 | 39 | src_context, src_attn=self.src_attn(gru_state[:,-1].unsqueeze(1), src_output, src_output, query_mask=None, key_mask=src_mask) 40 | src_context=src_context.squeeze(1) 41 | src_attn = src_attn.squeeze(1) 42 | #background attention, bg_context(context),bc_attn(attention score) 43 | bg_context, bg_attn = self.bg_attn(gru_state[:, -1].unsqueeze(1), bg_output, bg_output, query_mask=None, key_mask=bg_mask) 44 | bg_context = bg_context.squeeze(1) 45 | bg_attn = bg_attn.squeeze(1) 46 | 47 | #output bg_attn, score r_t, regularize r_t-r_{t-1} 48 | print('bg_attn',bg_attn.squeeze(0).tolist()) 49 | 50 | gru_input = torch.cat((embedded, src_context, bg_context), dim=1) 51 | gru_output, gru_state=self.gru(gru_input.unsqueeze(0), gru_state.transpose(0,1)) 52 | gru_state=gru_state.transpose(0,1) 53 | 54 | concat_output = torch.cat((embedded, gru_state[:,-1], src_context, bg_context), dim=1) 55 | 56 | 
feature_output=self.readout(concat_output) 57 | return feature_output, [gru_state], [src_attn, bg_attn], bg_context #bg_attn: BBCDecoder[2][1] 58 | 59 | class MALDecoder(nn.Module): 60 | def __init__(self, embedding_size, hidden_size, tgt_vocab_size, embedding=None, num_layers=4, dropout=0.5): 61 | super(MALDecoder, self).__init__() 62 | 63 | # Keep for reference 64 | self.hidden_size = hidden_size 65 | self.embedding_size=embedding_size 66 | self.tgt_vocab_size = tgt_vocab_size 67 | 68 | if embedding is not None: 69 | self.embedding =embedding 70 | else: 71 | self.embedding = nn.Embedding(tgt_vocab_size, embedding_size, padding_idx=0) 72 | self.embedding_dropout = nn.Dropout(dropout) 73 | 74 | self.src_attn = BilinearAttention( 75 | query_size=hidden_size, key_size=2*hidden_size, hidden_size=hidden_size, dropout=dropout, coverage=False 76 | ) 77 | self.bg_attn = BilinearAttention( 78 | query_size=hidden_size, key_size=2 * hidden_size, hidden_size=hidden_size, dropout=dropout, coverage=False 79 | ) 80 | 81 | self.gru = nn.GRU(2*hidden_size+2*hidden_size+2*hidden_size+2*hidden_size+embedding_size, hidden_size, bidirectional=False, num_layers=num_layers, dropout=dropout) 82 | 83 | self.readout = nn.Linear(embedding_size + hidden_size + 2*hidden_size+ 2*hidden_size, hidden_size) 84 | 85 | def forward(self, tgt, state, src_output, bg_output, sel_bg, src_mask=None, bg_mask=None): 86 | gru_state = state[0] 87 | 88 | embedded = self.embedding(tgt) 89 | embedded = self.embedding_dropout(embedded) 90 | 91 | src_context, src_attn=self.src_attn(gru_state[:,-1].unsqueeze(1), src_output, src_output, query_mask=None, key_mask=src_mask) 92 | src_context=src_context.squeeze(1) 93 | src_attn = src_attn.squeeze(1) 94 | bg_context, bg_attn = self.bg_attn(gru_state[:, -1].unsqueeze(1), bg_output, bg_output, query_mask=None, key_mask=bg_mask) 95 | bg_context = bg_context.squeeze(1) 96 | bg_attn = bg_attn.squeeze(1) 97 | 98 | gru_input = torch.cat((embedded, src_context, bg_context, sel_bg), dim=1) 99 | gru_output, gru_state=self.gru(gru_input.unsqueeze(0), gru_state.transpose(0,1)) 100 | gru_state=gru_state.transpose(0,1) 101 | 102 | concat_output = torch.cat((embedded, gru_state[:,-1], src_context, bg_context), dim=1) 103 | 104 | feature_output=self.readout(concat_output) 105 | return feature_output, [gru_state], [src_attn, bg_attn], bg_context -------------------------------------------------------------------------------- /EncDecModel.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from Utils import * 3 | from modules.Generations import * 4 | 5 | class EncDecModel(nn.Module): 6 | def __init__(self, src_vocab_size, embedding_size, hidden_size, tgt_vocab_size, src_id2vocab=None, src_vocab2id=None, tgt_id2vocab=None, tgt_vocab2id=None, max_dec_len=120, beam_width=1, eps=1e-10): 7 | super(EncDecModel, self).__init__() 8 | self.src_vocab_size=src_vocab_size 9 | self.embedding_size=embedding_size 10 | self.hidden_size=hidden_size 11 | self.tgt_vocab_size=tgt_vocab_size 12 | self.tgt_id2vocab=tgt_id2vocab 13 | self.tgt_vocab2id=tgt_vocab2id 14 | self.src_id2vocab=src_id2vocab 15 | self.src_vocab2id=src_vocab2id 16 | self.eps=eps 17 | self.beam_width=beam_width 18 | self.max_dec_len=max_dec_len 19 | 20 | def encode(self, data): 21 | raise NotImplementedError 22 | 23 | def init_decoder_states(self, data, encode_output): 24 | return None 25 | 26 | def decode(self, data, tgt, state, encode_output): 27 | raise NotImplementedError 28 | 29 | def 
generate(self, data, decode_output, softmax=False): 30 | raise NotImplementedError 31 | 32 | def to_word(self, data, gen_output, k=5, sampling=False): 33 | if not sampling: 34 | return topk(gen_output, k=k) 35 | else: 36 | return randomk(gen_output, k=k) 37 | 38 | def generation_to_decoder_input(self, data, indices): 39 | return indices 40 | 41 | def loss(self,data, all_gen_output, all_decode_output, encode_output, reduction='mean'): 42 | raise NotImplementedError 43 | 44 | def to_sentence(self, data, batch_indice): 45 | return to_sentence(batch_indice, self.tgt_id2vocab) 46 | 47 | def sample(self, data): 48 | return sample(self, data, self.max_dec_len) 49 | 50 | def greedy(self, data): 51 | return greedy(self,data, self.max_dec_len) 52 | 53 | def beam(self, data): 54 | return beam(self, data, self.max_dec_len, self.beam_width) 55 | 56 | def mle_train(self, data): 57 | encode_output, init_decoder_state, all_decode_output, all_gen_output=decode_to_end(self,data,schedule_rate=1) 58 | 59 | loss=self.loss(data,all_gen_output,all_decode_output,encode_output) 60 | 61 | return loss.unsqueeze(0) 62 | 63 | def forward(self, data, method): 64 | if method=='mle_train': 65 | return self.mle_train(data) 66 | elif method=='reinforce_train': 67 | return self.reinforce_train(data) 68 | elif method=='test': 69 | if self.beam_width==1: 70 | return self.greedy(data) 71 | else: 72 | return self.beam(data) 73 | elif method=='sample': 74 | return self.sample(data) -------------------------------------------------------------------------------- /MAL.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from data.Utils import * 4 | from EncDecModel import * 5 | from modules.Criterions import * 6 | from modules.Generators import * 7 | import numpy 8 | 9 | class Environment(nn.Module): 10 | def __init__(self,context_size, action_size, hidden_size): 11 | super(Environment, self).__init__() 12 | 13 | self.c_enc = nn.GRU(context_size, hidden_size, num_layers=1, bidirectional=True, dropout=0.5, batch_first=True) 14 | self.a_enc = nn.GRU(action_size, hidden_size, num_layers=1, bidirectional=True, dropout=0.5, batch_first=True) 15 | 16 | self.attn = BilinearAttention( 17 | query_size=2*hidden_size, key_size=2*hidden_size, hidden_size=hidden_size, dropout=0.5, coverage=False 18 | ) 19 | self.match_gru = nn.GRU(4*hidden_size, hidden_size, num_layers=1, bidirectional=True, dropout=0.5, batch_first=True) 20 | self.mlp=nn.Sequential(nn.Linear(2*hidden_size,1,bias=False), nn.Sigmoid()) 21 | 22 | def reward(self, context, action, context_mask, action_mask): 23 | context,_=gru_forward(self.c_enc, context, context_mask.sum(dim=1).long()) 24 | action, _ = gru_forward(self.a_enc, action, action_mask.sum(dim=1).long()) 25 | 26 | action_att, _ = self.attn(action, context, context, query_mask=action_mask, key_mask=context_mask) 27 | action_mask = action_mask.float().detach() 28 | feature = torch.cat([action_att, action], dim=2) * action_mask.unsqueeze(2) 29 | feature, _ = gru_forward(self.match_gru, feature, action_mask.sum(dim=1).long()) 30 | rewards = self.mlp(feature).squeeze(2) * action_mask 31 | 32 | return rewards 33 | 34 | def forward(self, context, action, context_mask, action_mask, y): 35 | rewards=self.reward(context, action, context_mask, action_mask) 36 | rewards = rewards.sum(dim=1) / action_mask.float().sum(dim=1) 37 | if y==1: 38 | gt = torch.ones(rewards.size(0)).float() 39 | if torch.cuda.is_available(): 40 | gt = 
gt.cuda() 41 | elif y==0: 42 | gt = torch.zeros(rewards.size(0)).float() 43 | if torch.cuda.is_available(): 44 | gt = gt.cuda() 45 | loss=F.binary_cross_entropy(rewards, gt)+1e-2*torch.distributions.bernoulli.Bernoulli(probs=rewards).entropy().mean() 46 | return loss 47 | 48 | class Encoder(nn.Module): 49 | def __init__(self, src_vocab_size, embedding_size, hidden_size): 50 | super(Encoder, self).__init__() 51 | self.c_embedding = nn.Embedding(src_vocab_size, embedding_size, padding_idx=PAD) 52 | # self.b_embedding = nn.Embedding(src_vocab_size, embedding_size,padding_idx=PAD) 53 | self.b_embedding = self.c_embedding 54 | self.c_embedding_dropout = nn.Dropout(0.5) 55 | self.b_embedding_dropout = nn.Dropout(0.5) 56 | 57 | self.c_enc = nn.GRU(embedding_size, hidden_size, num_layers=1, bidirectional=True, dropout=0.5, batch_first=True) 58 | self.b_enc = nn.GRU(embedding_size, hidden_size, num_layers=1, bidirectional=True, dropout=0.5, batch_first=True) 59 | 60 | self.attn = BilinearAttention( 61 | query_size=2 * hidden_size, key_size=2 * hidden_size, hidden_size=hidden_size, dropout=0.5, coverage=False 62 | ) 63 | 64 | self.matching_gru = nn.GRU(8 * hidden_size, hidden_size, num_layers=1, bidirectional=True, dropout=0.5) 65 | 66 | def forward(self, data): 67 | c_mask = data['context'].ne(PAD).detach() 68 | b_mask = data['background'].ne(PAD).detach() 69 | 70 | c_words = self.c_embedding_dropout(self.c_embedding(data['context'])) 71 | b_words = self.b_embedding_dropout(self.b_embedding(data['background'])) 72 | 73 | c_lengths = c_mask.sum(dim=1).detach() 74 | b_lengths = b_mask.sum(dim=1).detach() 75 | c_enc_output, c_state = gru_forward(self.c_enc, c_words, c_lengths) 76 | b_enc_output, b_state = gru_forward(self.b_enc, b_words, b_lengths) 77 | 78 | batch_size, c_len, hidden_size = c_enc_output.size() 79 | batch_size, b_len, hidden_size = b_enc_output.size() 80 | 81 | score = self.attn.unnormalized_score(b_enc_output, c_enc_output, 82 | key_mask=c_mask) # batch_size, bg_len, src_len 83 | 84 | 85 | 86 | b2c = F.softmax(score, dim=2) 87 | 88 | 89 | 90 | b2c = b2c.masked_fill((1 - b_mask).unsqueeze(2).expand(batch_size, b_len, c_len), 0) 91 | 92 | 93 | b2c = torch.bmm(b2c, c_enc_output) # batch_size, bg_len, hidden_size 94 | 95 | c2b = F.softmax(torch.max(score, dim=2)[0], dim=1).unsqueeze(1) # batch_size, 1, bg_len 96 | 97 | c2b = torch.bmm(c2b, b_enc_output).expand(-1, b_len, -1) # batch_size, bg_len, hidden_size 98 | 99 | g = torch.cat([b_enc_output, b2c, b_enc_output * b2c, b_enc_output * c2b], dim=-1) # batch_size, bg_len, 8*hidden_size 100 | 101 | m, _ = gru_forward(self.matching_gru, g, b_mask.sum(dim=1)) # batch_size, bg_len, 2*hidden_size 102 | 103 | return c_enc_output, c_state, b_enc_output, b_state, g, m 104 | 105 | class Selector(nn.Module): 106 | def __init__(self, embedding_size, hidden_size, tgt_vocab_size, embedding=None): 107 | super(Selector, self).__init__() 108 | 109 | if embedding is None: 110 | self.o_embedding = nn.Embedding(tgt_vocab_size, embedding_size, padding_idx=PAD) 111 | else: 112 | self.o_embedding =embedding 113 | self.o_embedding_dropout = nn.Dropout(0.5) 114 | 115 | self.matching_gru = nn.GRU(hidden_size+embedding_size+2*hidden_size, hidden_size, num_layers=1, bidirectional=True, dropout=0.5) 116 | 117 | self.p1_g = nn.Linear(8 * hidden_size, 1, bias=False) 118 | self.p1_m = nn.Linear(2*hidden_size, 1, bias=False) 119 | self.p1_s = nn.Linear(hidden_size+embedding_size+2*hidden_size, 1, bias=False) 120 | self.p1_t = nn.Linear(2*hidden_size, 1, bias=False) 121 | 
122 | def forward(self, data, tgt, state, encode_output): 123 | b = data['background'] 124 | b_mask = b.ne(PAD) 125 | c_enc_output, c_state, b_enc_output, b_state, g, m = encode_output 126 | batch_size, b_len, _=b_enc_output.size() 127 | 128 | embedded = self.o_embedding(tgt) 129 | embedded = self.o_embedding_dropout(embedded) 130 | 131 | s = torch.cat([state[0].expand(-1,b_len,-1), embedded.unsqueeze(1).expand(-1,b_len,-1), m], dim=-1) 132 | t, _ = gru_forward(self.matching_gru, s, b_mask.sum(dim=1)) 133 | 134 | p1 = (self.p1_g(g) + self.p1_m(m)+ self.p1_s(s)+ self.p1_t(t)).squeeze(2) # batch_size, bg_len 135 | p1 = p1.masked_fill(1 - b_mask, -float('inf')) 136 | 137 | return p1, state 138 | 139 | 140 | class Generator(nn.Module): 141 | def __init__(self,embedding_size, hidden_size, tgt_vocab_size, embedding=None): 142 | super(Generator, self).__init__() 143 | 144 | self.embedding_size=embedding_size 145 | self.hidden_size=hidden_size 146 | 147 | if embedding is None: 148 | self.o_embedding = nn.Embedding(tgt_vocab_size, embedding_size, padding_idx=PAD) 149 | else: 150 | self.o_embedding =embedding 151 | self.o_embedding_dropout = nn.Dropout(0.5) 152 | 153 | self.attn = BilinearAttention( 154 | query_size=hidden_size, key_size=2*hidden_size, hidden_size=hidden_size, dropout=0.5, coverage=False 155 | ) 156 | 157 | self.dec = nn.GRU(2*hidden_size + embedding_size, hidden_size, bidirectional=False, num_layers=1, dropout=0.5, batch_first=True) 158 | self.readout = nn.Linear(embedding_size + hidden_size + 2*hidden_size, hidden_size) 159 | self.gen=LinearGenerator(feature_size=hidden_size, tgt_vocab_size=tgt_vocab_size) 160 | 161 | def forward(self, data, tgt, state, encode_output): 162 | c_enc_output, c_state, b_enc_output, b_state, g, m=encode_output 163 | c_mask=data['context'].ne(PAD) 164 | 165 | state = state[0] 166 | 167 | embedded = self.o_embedding(tgt) 168 | embedded = self.o_embedding_dropout(embedded) 169 | 170 | # attended vector, attention 171 | attn_context_1, attn = self.attn(state, c_enc_output, c_enc_output, query_mask=None, key_mask=c_mask) 172 | 173 | # input:x_{t-1}, attended vector 174 | gru_input = torch.cat((embedded.unsqueeze(1), attn_context_1), dim=2) 175 | 176 | # update state by attended vector and x_{t-1} 177 | gru_output, state = self.dec(gru_input, state.transpose(0,1)) #gru_input: 2hidden+embedding 178 | state=state.transpose(0,1) 179 | 180 | # use new state to get new attention vector 181 | attn_context, attn = self.attn(state, c_enc_output, c_enc_output, query_mask=None, key_mask=c_mask) 182 | attn_context = attn_context.squeeze(1) 183 | 184 | concat_output = torch.cat((embedded, state.squeeze(1), attn_context), dim=1) 185 | 186 | feature_output = self.readout(concat_output) 187 | 188 | return self.gen(feature_output, softmax=False), [state] 189 | 190 | 191 | class Mixture(nn.Module): 192 | def __init__(self, state_size): 193 | super(Mixture, self).__init__() 194 | self.linear_mixture = nn.Linear(state_size, 1) 195 | 196 | 197 | def forward(self, state, selector_action, generator_action, b_map): 198 | p_s_g = torch.sigmoid(self.linear_mixture(state[0].squeeze(1))) 199 | 200 | selector_action=F.softmax(selector_action, dim=1) # p_background=softmax(p1) 201 | 202 | 203 | generator_action = F.softmax(generator_action, dim=1) # 204 | 205 | generator_action = torch.mul(generator_action, p_s_g.expand_as(generator_action)) #p*p_vocab 206 | 207 | selector_action = torch.bmm(selector_action.unsqueeze(1), b_map.float()).squeeze(1) #selector action is p1 208 | 209 | 
selector_action = torch.mul(selector_action, (1-p_s_g).expand_as(selector_action)) #(1-p)p_background 210 | 211 | return torch.cat([generator_action, selector_action], 1) #return final 212 | 213 | 214 | class MAL(EncDecModel): 215 | def __init__(self, encoder, selector, generator, env, src_id2vocab, src_vocab2id, tgt_id2vocab, tgt_vocab2id, max_dec_len, beam_width, eps=1e-10): 216 | super(MAL, self).__init__(src_vocab_size=len(src_id2vocab), embedding_size=generator.embedding_size, 217 | hidden_size=generator.hidden_size, tgt_vocab_size=len(tgt_id2vocab), src_id2vocab=src_id2vocab, 218 | src_vocab2id=src_vocab2id, tgt_id2vocab=tgt_id2vocab, tgt_vocab2id=tgt_vocab2id, max_dec_len=max_dec_len, beam_width=beam_width, 219 | eps=eps) 220 | self.encoder=encoder 221 | self.selector=selector 222 | self.generator=generator 223 | self.env=env 224 | 225 | self.state_initializer = nn.Linear(2*self.hidden_size, self.hidden_size) 226 | self.mixture=Mixture(self.hidden_size) 227 | self.criterion = CopyCriterion(len(tgt_id2vocab), force_copy=False, eps=eps) 228 | 229 | def encode(self,data): 230 | return self.encoder(data) 231 | 232 | def init_decoder_states(self,data, encode_output): 233 | c_enc_output, c_state, b_enc_output, b_state, g, m = encode_output 234 | batch_size=c_state.size(0) 235 | 236 | return [self.state_initializer(c_state.contiguous().view(batch_size,-1)).view(batch_size, 1, -1)] 237 | 238 | def decode(self, data, tgt, state, encode_output): 239 | sel_decode_output=self.selector(data, tgt, state, encode_output) 240 | gen_decode_output=self.generator(data, tgt, state, encode_output) 241 | return [sel_decode_output[0],gen_decode_output[0]], gen_decode_output[1] 242 | 243 | def generate(self, data, decode_output, softmax=True): 244 | actions, state = decode_output 245 | self.mixture(state, actions[0], actions[1], data['background_map']) 246 | return self.mixture(state, actions[0], actions[1], data['background_map']) 247 | 248 | def loss(self,data, all_gen_output, all_decode_output, encode_output, reduction='mean'): 249 | loss=self.criterion(all_gen_output, data['output'], data['background_copy'], reduction=reduction) 250 | return loss 251 | # return loss+1e-2*torch.distributions.categorical.Categorical(probs=all_gen_output.view(-1, all_gen_output.size(2))).entropy().mean() 252 | 253 | def generation_to_decoder_input(self, data, indices): 254 | return indices.masked_fill(indices>=self.tgt_vocab_size, UNK) 255 | 256 | def to_word(self, data, gen_output, k=5, sampling=False): 257 | if not sampling: 258 | return copy_topk(gen_output, data['background_vocab_map'], data['background_vocab_overlap'], k=k) 259 | else: 260 | return randomk(gen_output, k=k) 261 | 262 | def to_sentence(self, data, batch_indices): 263 | return to_copy_sentence(data, batch_indices, self.tgt_id2vocab, data['background_dyn_vocab']) 264 | 265 | def forward(self, data, method='mle_train'): 266 | if method=='mle_train': 267 | return self.mle_train(data) 268 | elif method=='mal_train': 269 | return self.mal_train(data) 270 | elif method=='env_train': 271 | return self.env_train(data) 272 | elif method=='test': 273 | if self.beam_width==1: 274 | return self.greedy(data) 275 | else: 276 | return self.beam(data) 277 | elif method=='sample': 278 | return self.sample(data) 279 | 280 | def env_train(self, data): 281 | c_mask = data['context'].ne(PAD).detach() 282 | o_mask = data['output'].ne(PAD).detach() 283 | 284 | with torch.no_grad(): 285 | c = self.encoder.c_embedding(data['context']).detach() 286 | o = 
self.generator.o_embedding(data['output']).detach() 287 | 288 | a, encode_outputs, init_decoder_states, all_decode_outputs, all_gen_outputs=sample(self, data, max_len=self.max_dec_len) 289 | a.masked_fill_(a >= self.tgt_vocab_size, UNK) 290 | a_mask = a.ne(PAD).detach() 291 | a = self.generator.o_embedding(a).detach() 292 | 293 | return self.env(c, o, c_mask, o_mask, 1).unsqueeze(0), self.env(c, a, c_mask, a_mask, 0).unsqueeze(0) 294 | 295 | def mle_train(self, data): 296 | encode_output, init_decoder_state, all_decode_output, all_gen_output=decode_to_end(self,data,schedule_rate=1) 297 | 298 | gen_loss=self.loss(data,all_gen_output,all_decode_output,encode_output).unsqueeze(0) 299 | 300 | return gen_loss 301 | -------------------------------------------------------------------------------- /MALCopyDataset.py: -------------------------------------------------------------------------------- 1 | import codecs 2 | from torch.utils.data import Dataset 3 | from Constants import * 4 | from data.Utils import * 5 | import json 6 | 7 | # from multiprocessing.dummy import Pool as ThreadPool 8 | # pool = ThreadPool(32) 9 | 10 | class MALDataset(Dataset): 11 | def __init__(self, files, src_vocab2id, tgt_vocab2id,n=1E10): 12 | super(MALDataset, self).__init__() 13 | self.ids = list() 14 | self.contexts = list() 15 | self.queries = list() 16 | self.outputs = list() 17 | self.backgrounds = list() 18 | 19 | self.id_arrays = list() 20 | self.context_arrays = list() 21 | self.query_arrays = list() 22 | self.output_arrays = list() 23 | self.background_arrays = list() 24 | self.background_selection_arrays = list() 25 | self.background_ref_start_arrays = list() 26 | self.background_ref_end_arrays = list() 27 | 28 | self.bg_dyn_vocab2ids=list() 29 | self.bg_dyn_id2vocabs=list() 30 | self.background_copy_arrays= list() 31 | 32 | self.src_vocab2id=src_vocab2id 33 | self.tgt_vocab2id=tgt_vocab2id 34 | self.files=files 35 | self.n=n 36 | 37 | self.load() 38 | 39 | def load(self): 40 | with codecs.open(self.files[0], encoding='utf-8') as f: 41 | data = json.load(f) 42 | for id in range(len(data)): 43 | sample=data[id] 44 | 45 | context = sample['context'].split(' ') 46 | self.contexts.append(context) 47 | self.context_arrays.append(torch.tensor([self.src_vocab2id.get(w.lower(), UNK) for w in context], requires_grad=False).long()) 48 | 49 | query = sample['query'].split(' ') 50 | # query = context 51 | # query = query[max(0, len(query) - 65):] 52 | self.queries.append(query) 53 | self.query_arrays.append(torch.tensor([self.src_vocab2id.get(w.lower(), UNK) for w in query], requires_grad=False).long()) 54 | 55 | background = sample['background'].split(' ') 56 | # background = background[:min(self.max_bg, len(background))] 57 | self.backgrounds.append(background) 58 | self.background_arrays.append(torch.tensor([self.src_vocab2id.get(w.lower(), UNK) for w in background], requires_grad=False).long()) 59 | 60 | bg_dyn_vocab2id, bg_dyn_id2vocab = build_vocab(sample['background'].lower().split(' ')) 61 | self.bg_dyn_vocab2ids.append((id, bg_dyn_vocab2id)) 62 | self.bg_dyn_id2vocabs.append((id, bg_dyn_id2vocab)) 63 | 64 | output = sample['response'].lower().split(' ') 65 | self.outputs.append(output) 66 | self.output_arrays.append(torch.tensor([self.tgt_vocab2id.get(w, UNK) for w in output] + [EOS], requires_grad=False).long()) 67 | self.background_copy_arrays.append(torch.tensor([bg_dyn_vocab2id.get(w, UNK) for w in output] + [EOS], requires_grad=False).long()) 68 | 69 | output = set(output) 70 | 
self.background_selection_arrays.append(torch.tensor([1 if w.lower() in output else 0 for w in background], requires_grad=False).long()) 71 | 72 | if 'bg_ref_start' in sample: 73 | self.background_ref_start_arrays.append(torch.tensor([sample['bg_ref_start']], requires_grad=False)) 74 | self.background_ref_end_arrays.append(torch.tensor([sample['bg_ref_end'] - 1], requires_grad=False)) 75 | else: 76 | self.background_ref_start_arrays.append(torch.tensor([-1], requires_grad=False)) 77 | self.background_ref_end_arrays.append(torch.tensor([-1], requires_grad=False)) 78 | 79 | self.ids.append(id) 80 | self.id_arrays.append(torch.tensor([id]).long()) 81 | 82 | if len(self.contexts)>=self.n: 83 | break 84 | self.len = len(self.contexts) 85 | print('data size: ', self.len) 86 | 87 | def __getitem__(self, index): 88 | return [self.ids[index], self.id_arrays[index], self.contexts[index],self.context_arrays[index],self.queries[index], self.query_arrays[index], 89 | self.backgrounds[index], self.background_arrays[index], self.outputs[index], self.output_arrays[index],self.background_copy_arrays[index], 90 | self.bg_dyn_id2vocabs[index], self.bg_dyn_vocab2ids[index], self.tgt_vocab2id, self.background_ref_start_arrays[index], self.background_ref_end_arrays[index]] 91 | 92 | def __len__(self): 93 | return self.len 94 | 95 | def input(self,id): 96 | return self.contexts[id] 97 | 98 | def output(self,id): 99 | return self.outputs[id] 100 | 101 | def background(self,id): 102 | return self.backgrounds[id] 103 | 104 | def train_collate_fn(data): 105 | id,id_a,context,context_a,query,query_a,background,background_a,output,output_a,background_copy_a,bg_dyn_id2vocab,bg_dyn_vocab2id,tgt_vocab2id, background_ref_start_a, background_ref_end_a = zip(*data) 106 | 107 | batch_size=len(id_a) 108 | 109 | id_t = torch.cat(id_a) 110 | context_t = torch.zeros(batch_size, max([len(s) for s in context_a]), requires_grad=False).long() 111 | query_t = torch.zeros(batch_size, max([len(s) for s in query_a]), requires_grad=False).long() 112 | output_t = torch.zeros(batch_size, max([len(s) for s in output_a]), requires_grad=False).long() 113 | background_copy_t = torch.zeros(batch_size, max([len(s) for s in background_copy_a]), requires_grad=False).long() 114 | background_t = torch.zeros(batch_size, max([len(s) for s in background_a]), requires_grad=False).long() 115 | # bg_sel_t = torch.zeros(batch_size, max([len(s) for s in bg_sel_a]), requires_grad=False).long() 116 | background_ref_start_t = torch.cat(background_ref_start_a) 117 | background_ref_end_t = torch.cat(background_ref_end_a) 118 | 119 | background_map_t=torch.zeros(batch_size, max([len(s) for s in background_a]), max([len(s) for k,s in bg_dyn_vocab2id]), requires_grad=False).long() 120 | 121 | output_text=dict() 122 | background_text=dict() 123 | 124 | def one_instance(b): 125 | context_t[b, :len(context_a[b])] = context_a[b] 126 | query_t[b, :len(query_a[b])] = query_a[b] 127 | output_t[b, :len(output_a[b])] = output_a[b] 128 | background_copy_t[b, :len(background_copy_a[b])] = background_copy_a[b] 129 | background_t[b, :len(background_a[b])] = background_a[b] 130 | # bg_sel_t[b, :len(bg_sel_a[b])] = bg_sel_a[b] 131 | 132 | output_text[id_a[b].item()]=output[b] 133 | background_text[id_a[b].item()]=background[b] 134 | 135 | _, vocab2id = bg_dyn_vocab2id[b] 136 | for j in range(len(background[b])): 137 | # if j >= background_ref_start_a[b] and j <= background_ref_end_a[b]: 138 | background_map_t[b, j, vocab2id[background[b][j].lower()]] = 1 139 | 140 | for b in 
range(batch_size): 141 | one_instance(b) 142 | 143 | # pool.map(one_instance, range(batch_size)) 144 | 145 | return {'id':id_t,'context':context_t, 'query':query_t, 'output':output_t,'output_text':output_text, 'background':background_t, 'background_text': background_text, 'background_ref_start':background_ref_start_t,'background_ref_end':background_ref_end_t, 'background_dyn_vocab':dict(bg_dyn_id2vocab), 'background_copy':background_copy_t, 'background_map':background_map_t} 146 | 147 | 148 | def test_collate_fn(data): 149 | id, id_a, context, context_a,query,query_a, background, background_a, output, output_a, background_copy_a, bg_dyn_id2vocab, bg_dyn_vocab2id, tgt_vocab2id, background_ref_start_a, background_ref_end_a = zip( 150 | *data) 151 | tgt_vocab2id = tgt_vocab2id[0] 152 | 153 | batch_size = len(id_a) 154 | 155 | id_t = torch.cat(id_a) 156 | context_t = torch.zeros(batch_size, max([len(s) for s in context_a]), requires_grad=False).long() 157 | query_t = torch.zeros(batch_size, max([len(s) for s in query_a]), requires_grad=False).long() 158 | output_t = torch.zeros(batch_size, max([len(s) for s in output_a]), requires_grad=False).long() 159 | background_copy_t = torch.zeros(batch_size, max([len(s) for s in background_copy_a]), requires_grad=False).long() 160 | background_t = torch.zeros(batch_size, max([len(s) for s in background_a]), requires_grad=False).long() 161 | # bg_sel_t = torch.zeros(batch_size, max([len(s) for s in bg_sel_a]), requires_grad=False).long() 162 | 163 | background_map_t = torch.zeros(batch_size, max([len(s) for s in background_a]), max([len(s) for k, s in bg_dyn_vocab2id]), 164 | requires_grad=False).long() 165 | background_vocab_map_t = torch.zeros(batch_size, max([len(s) for k, s in bg_dyn_id2vocab]), len(tgt_vocab2id), 166 | requires_grad=False).float() 167 | background_vocab_overlap_t = torch.ones(batch_size, max([len(s) for k, s in bg_dyn_id2vocab]), 168 | requires_grad=False).float() 169 | background_text=dict() 170 | 171 | def one_instance(b): 172 | context_t[b, :len(context_a[b])] = context_a[b] 173 | query_t[b, :len(query_a[b])] = query_a[b] 174 | output_t[b, :len(output_a[b])] = output_a[b] 175 | background_copy_t[b, :len(background_copy_a[b])] = background_copy_a[b] 176 | background_t[b, :len(background_a[b])] = background_a[b] 177 | # bg_sel_t[b, :len(bg_sel_a[b])] = bg_sel_a[b] 178 | 179 | background_text[id_a[b].item()]=background[b] 180 | 181 | _, vocab2id = bg_dyn_vocab2id[b] 182 | for j in range(len(background[b])): 183 | # if j >= background_ref_start_a[b] and j <= background_ref_end_a[b]: 184 | background_map_t[b, j, vocab2id[background[b][j].lower()]] = 1 185 | 186 | _, id2vocab = bg_dyn_id2vocab[b] 187 | for id in id2vocab: 188 | if id2vocab[id] in tgt_vocab2id: 189 | background_vocab_map_t[b, id, tgt_vocab2id[id2vocab[id]]] = 1 190 | background_vocab_overlap_t[b, id] = 0 191 | 192 | for b in range(batch_size): 193 | one_instance(b) 194 | 195 | # pool.map(one_instance, range(batch_size)) 196 | 197 | return {'id': id_t, 'context':context_t, 'query':query_t, 'output':output_t, 'background':background_t, 198 | 'background_dyn_vocab': dict(bg_dyn_id2vocab), 'background_copy': background_copy_t, 199 | 'background_map': background_map_t, 'background_vocab_map': background_vocab_map_t, 200 | 'background_vocab_overlap': background_vocab_overlap_t, 'background_text': background_text} 201 | 202 | -------------------------------------------------------------------------------- /MALDataset.py: 
-------------------------------------------------------------------------------- 1 | import codecs 2 | from torch.utils.data import Dataset 3 | from Constants import * 4 | from data.Utils import * 5 | import json 6 | from multiprocessing.dummy import Pool as ThreadPool 7 | pool = ThreadPool(32) 8 | 9 | class MALDataset(Dataset): 10 | def __init__(self, files, src_vocab2id, tgt_vocab2id, n=1E10): 11 | super(MALDataset, self).__init__() 12 | self.ids=list() 13 | self.contexts=list() 14 | self.queries = list() 15 | self.outputs=list() 16 | self.backgrounds=list() 17 | 18 | self.id_arrays = list() 19 | self.context_arrays = list() 20 | self.query_arrays = list() 21 | self.output_arrays = list() 22 | self.background_arrays = list() 23 | self.background_selection_arrays=list() 24 | self.background_ref_start_arrays = list() 25 | self.background_ref_end_arrays = list() 26 | 27 | self.src_vocab2id=src_vocab2id 28 | self.tgt_vocab2id=tgt_vocab2id 29 | self.files=files 30 | self.n=n 31 | 32 | self.load() 33 | 34 | def load(self): 35 | with codecs.open(self.files[0], encoding='utf-8') as f: 36 | data = json.load(f) 37 | for id in range(len(data)): 38 | sample=data[id] 39 | 40 | context=sample['context'].split(' ') 41 | self.contexts.append(context) 42 | self.context_arrays.append(torch.tensor([self.src_vocab2id.get(w.lower(), UNK) for w in context], requires_grad=False).long()) 43 | 44 | query = sample['query'].split(' ') 45 | # query=context 46 | self.queries.append(query) 47 | self.query_arrays.append(torch.tensor([self.src_vocab2id.get(w.lower(), UNK) for w in query], requires_grad=False).long()) 48 | 49 | background = sample['background'].split(' ') 50 | self.backgrounds.append(background) 51 | self.background_arrays.append(torch.tensor([self.src_vocab2id.get(w.lower(), UNK) for w in background], requires_grad=False).long()) 52 | 53 | output = sample['response'].lower().split(' ') 54 | self.outputs.append(output) 55 | self.output_arrays.append(torch.tensor([self.tgt_vocab2id.get(w, UNK) for w in output] + [EOS], requires_grad=False).long()) 56 | 57 | output =set(output) 58 | self.background_selection_arrays.append(torch.tensor([1 if w.lower() in output else 0 for w in background], requires_grad=False).long()) 59 | 60 | if 'bg_ref_start' in sample: 61 | self.background_ref_start_arrays.append(torch.tensor([sample['bg_ref_start']], requires_grad=False)) 62 | self.background_ref_end_arrays.append(torch.tensor([sample['bg_ref_end'] - 1], requires_grad=False)) 63 | else: 64 | self.background_ref_start_arrays.append(torch.tensor([-1], requires_grad=False)) 65 | self.background_ref_end_arrays.append(torch.tensor([-1], requires_grad=False)) 66 | 67 | self.ids.append(id) 68 | self.id_arrays.append(torch.tensor([id]).long()) 69 | 70 | if len(self.contexts)>=self.n: 71 | break 72 | self.len = len(self.contexts) 73 | print('data size: ', self.len) 74 | 75 | def __getitem__(self, index): 76 | return [self.ids[index], self.id_arrays[index], self.contexts[index], self.context_arrays[index], self.queries[index], self.query_arrays[index], 77 | self.backgrounds[index], self.background_arrays[index], self.outputs[index], self.output_arrays[index], self.background_ref_start_arrays[index], self.background_ref_end_arrays[index]] 78 | 79 | def __len__(self): 80 | return self.len 81 | 82 | def input(self,id): 83 | return self.contexts[id] 84 | 85 | def output(self,id): 86 | return self.outputs[id] 87 | 88 | def background(self,id): 89 | return self.backgrounds[id] 90 | 91 | def train_collate_fn(data): 92 | 
id,id_a,context,context_a,query,query_a,background,background_a,output,output_a, background_ref_start_a, background_ref_end_a = zip(*data) 93 | 94 | batch_size=len(id_a) 95 | 96 | id_t = torch.cat(id_a) 97 | context_t = torch.zeros(batch_size, max([len(s) for s in context_a]), requires_grad=False).long() 98 | query_t = torch.zeros(batch_size, max([len(s) for s in query_a]), requires_grad=False).long() 99 | output_t = torch.zeros(batch_size, max([len(s) for s in output_a]), requires_grad=False).long() 100 | background_t = torch.zeros(batch_size, max([len(s) for s in background_a]), requires_grad=False).long() 101 | # bg_sel_t = torch.zeros(batch_size, max([len(s) for s in bg_sel_a]), requires_grad=False).long() 102 | background_ref_start_t = torch.cat(background_ref_start_a) 103 | background_ref_end_t = torch.cat(background_ref_end_a) 104 | 105 | def one_instance(b): 106 | context_t[b, :len(context_a[b])] = context_a[b] 107 | query_t[b, :len(query_a[b])] = query_a[b] 108 | output_t[b, :len(output_a[b])] = output_a[b] 109 | background_t[b, :len(background_a[b])] = background_a[b] 110 | # bg_sel_t[b, :len(bg_sel_a[b])] = bg_sel_a[b] 111 | 112 | # for b in range(batch_size): 113 | # one_instance(b) 114 | 115 | pool.map(one_instance, range(batch_size)) 116 | 117 | return {'id':id_t, 'context':context_t, 'query':query_t, 'output':output_t, 'background':background_t, 'background_ref_start':background_ref_start_t,'background_ref_end':background_ref_end_t} 118 | 119 | 120 | def test_collate_fn(data): 121 | id, id_a, context, context_a, query, query_a, background, background_a, output, output_a, background_ref_start_a, background_ref_end_a = zip(*data) 122 | 123 | batch_size = len(id_a) 124 | 125 | id_t = torch.cat(id_a) 126 | context_t = torch.zeros(batch_size, max([len(s) for s in context_a]), requires_grad=False).long() 127 | query_t = torch.zeros(batch_size, max([len(s) for s in query_a]), requires_grad=False).long() 128 | output_t = torch.zeros(batch_size, max([len(s) for s in output_a]), requires_grad=False).long() 129 | background_t = torch.zeros(batch_size, max([len(s) for s in background_a]), requires_grad=False).long() 130 | background_text=dict() 131 | 132 | def one_instance(b): 133 | context_t[b, :len(context_a[b])] = context_a[b] 134 | query_t[b, :len(query_a[b])] = query_a[b] 135 | output_t[b, :len(output_a[b])] = output_a[b] 136 | background_t[b, :len(background_a[b])] = background_a[b] 137 | 138 | background_text[id[b]]=background[b] 139 | 140 | # for b in range(batch_size): 141 | # one_instance(b) 142 | 143 | pool.map(one_instance, range(batch_size)) 144 | 145 | return {'id': id_t, 'context': context_t, 'query':query_t, 'output': output_t, 'background': background_t, 'background_text': background_text} 146 | 147 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | use run_Cake.py file to run the code. 2 | Change batch_size and input/output file. 3 | Also you need to change the t parameter. t=10 is for test. 4 | -------------------------------------------------------------------------------- /Rouge.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright 2017 Google Inc. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ROUGe metric implementation. 16 | 17 | This is a modified and slightly extended verison of 18 | https://github.com/miso-belica/sumy/blob/dev/sumy/evaluation/rouge.py. 19 | """ 20 | 21 | from __future__ import absolute_import 22 | from __future__ import division 23 | from __future__ import print_function 24 | from __future__ import unicode_literals 25 | 26 | import itertools 27 | import numpy as np 28 | 29 | #pylint: disable=C0103 30 | 31 | 32 | def _get_ngrams(n, text): 33 | """Calcualtes n-grams. 34 | 35 | Args: 36 | n: which n-grams to calculate 37 | text: An array of tokens 38 | 39 | Returns: 40 | A set of n-grams 41 | """ 42 | ngram_set = set() 43 | text_length = len(text) 44 | max_index_ngram_start = text_length - n 45 | for i in range(max_index_ngram_start + 1): 46 | ngram_set.add(tuple(text[i:i + n])) 47 | return ngram_set 48 | 49 | 50 | def _split_into_words(sentences): 51 | """Splits multiple sentences into words and flattens the result""" 52 | return list(itertools.chain(*[_.split(" ") for _ in sentences])) 53 | 54 | 55 | def _get_word_ngrams(n, sentences): 56 | """Calculates word n-grams for multiple sentences. 57 | """ 58 | assert len(sentences) > 0 59 | assert n > 0 60 | 61 | words = _split_into_words(sentences) 62 | return _get_ngrams(n, words) 63 | 64 | 65 | def _len_lcs(x, y): 66 | """ 67 | Returns the length of the Longest Common Subsequence between sequences x 68 | and y. 69 | Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence 70 | 71 | Args: 72 | x: sequence of words 73 | y: sequence of words 74 | 75 | Returns 76 | integer: Length of LCS between x and y 77 | """ 78 | table = _lcs(x, y) 79 | n, m = len(x), len(y) 80 | return table[n, m] 81 | 82 | 83 | def _lcs(x, y): 84 | """ 85 | Computes the length of the longest common subsequence (lcs) between two 86 | strings. The implementation below uses a DP programming algorithm and runs 87 | in O(nm) time where n = len(x) and m = len(y). 88 | Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence 89 | 90 | Args: 91 | x: collection of words 92 | y: collection of words 93 | 94 | Returns: 95 | Table of dictionary of coord and len lcs 96 | """ 97 | n, m = len(x), len(y) 98 | table = dict() 99 | for i in range(n + 1): 100 | for j in range(m + 1): 101 | if i == 0 or j == 0: 102 | table[i, j] = 0 103 | elif x[i - 1] == y[j - 1]: 104 | table[i, j] = table[i - 1, j - 1] + 1 105 | else: 106 | table[i, j] = max(table[i - 1, j], table[i, j - 1]) 107 | return table 108 | 109 | 110 | def _recon_lcs(x, y): 111 | """ 112 | Returns the Longest Subsequence between x and y. 
113 | Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence 114 | 115 | Args: 116 | x: sequence of words 117 | y: sequence of words 118 | 119 | Returns: 120 | sequence: LCS of x and y 121 | """ 122 | i, j = len(x), len(y) 123 | table = _lcs(x, y) 124 | 125 | def _recon(i, j): 126 | """private recon calculation""" 127 | if i == 0 or j == 0: 128 | return [] 129 | elif x[i - 1] == y[j - 1]: 130 | return _recon(i - 1, j - 1) + [(x[i - 1], i)] 131 | elif table[i - 1, j] > table[i, j - 1]: 132 | return _recon(i - 1, j) 133 | else: 134 | return _recon(i, j - 1) 135 | 136 | recon_tuple = tuple(map(lambda x: x[0], _recon(i, j))) 137 | return recon_tuple 138 | 139 | 140 | def rouge_n(evaluated_sentences, reference_sentences, n=2): 141 | """ 142 | Computes ROUGE-N of two text collections of sentences. 143 | Sourece: http://research.microsoft.com/en-us/um/people/cyl/download/ 144 | papers/rouge-working-note-v1.3.1.pdf 145 | 146 | Args: 147 | evaluated_sentences: The sentences that have been picked by the summarizer 148 | reference_sentences: The sentences from the referene set 149 | n: Size of ngram. Defaults to 2. 150 | 151 | Returns: 152 | A tuple (f1, precision, recall) for ROUGE-N 153 | 154 | Raises: 155 | ValueError: raises exception if a param has len <= 0 156 | """ 157 | if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0: 158 | raise ValueError("Collections must contain at least 1 sentence.") 159 | 160 | evaluated_ngrams = _get_word_ngrams(n, evaluated_sentences) 161 | reference_ngrams = _get_word_ngrams(n, reference_sentences) 162 | reference_count = len(reference_ngrams) 163 | evaluated_count = len(evaluated_ngrams) 164 | 165 | # Gets the overlapping ngrams between evaluated and reference 166 | overlapping_ngrams = evaluated_ngrams.intersection(reference_ngrams) 167 | overlapping_count = len(overlapping_ngrams) 168 | 169 | # Handle edge case. This isn't mathematically correct, but it's good enough 170 | if evaluated_count == 0: 171 | precision = 0.0 172 | else: 173 | precision = overlapping_count / evaluated_count 174 | 175 | if reference_count == 0: 176 | recall = 0.0 177 | else: 178 | recall = overlapping_count / reference_count 179 | 180 | f1_score = 2.0 * ((precision * recall) / (precision + recall + 1e-8)) 181 | 182 | # return overlapping_count / reference_count 183 | return f1_score, precision, recall 184 | 185 | 186 | def _f_p_r_lcs(llcs, m, n): 187 | """ 188 | Computes the LCS-based F-measure score 189 | Source: http://research.microsoft.com/en-us/um/people/cyl/download/papers/ 190 | rouge-working-note-v1.3.1.pdf 191 | 192 | Args: 193 | llcs: Length of LCS 194 | m: number of words in reference summary 195 | n: number of words in candidate summary 196 | 197 | Returns: 198 | Float. LCS-based F-measure score 199 | """ 200 | r_lcs = llcs / m 201 | p_lcs = llcs / n 202 | beta = p_lcs / (r_lcs + 1e-12) 203 | num = (1 + (beta**2)) * r_lcs * p_lcs 204 | denom = r_lcs + ((beta**2) * p_lcs) 205 | f_lcs = num / (denom + 1e-12) 206 | return f_lcs, p_lcs, r_lcs 207 | 208 | 209 | def rouge_l_sentence_level(evaluated_sentences, reference_sentences): 210 | """ 211 | Computes ROUGE-L (sentence level) of two text collections of sentences. 
212 | http://research.microsoft.com/en-us/um/people/cyl/download/papers/ 213 | rouge-working-note-v1.3.1.pdf 214 | 215 | Calculated according to: 216 | R_lcs = LCS(X,Y)/m 217 | P_lcs = LCS(X,Y)/n 218 | F_lcs = ((1 + beta^2)*R_lcs*P_lcs) / (R_lcs + (beta^2) * P_lcs) 219 | 220 | where: 221 | X = reference summary 222 | Y = Candidate summary 223 | m = length of reference summary 224 | n = length of candidate summary 225 | 226 | Args: 227 | evaluated_sentences: The sentences that have been picked by the summarizer 228 | reference_sentences: The sentences from the referene set 229 | 230 | Returns: 231 | A float: F_lcs 232 | 233 | Raises: 234 | ValueError: raises exception if a param has len <= 0 235 | """ 236 | if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0: 237 | raise ValueError("Collections must contain at least 1 sentence.") 238 | reference_words = _split_into_words(reference_sentences) 239 | evaluated_words = _split_into_words(evaluated_sentences) 240 | m = len(reference_words) 241 | n = len(evaluated_words) 242 | lcs = _len_lcs(evaluated_words, reference_words) 243 | return _f_p_r_lcs(lcs, m, n) 244 | 245 | 246 | def _union_lcs(evaluated_sentences, reference_sentence): 247 | """ 248 | Returns LCS_u(r_i, C) which is the LCS score of the union longest common 249 | subsequence between reference sentence ri and candidate summary C. For example 250 | if r_i= w1 w2 w3 w4 w5, and C contains two sentences: c1 = w1 w2 w6 w7 w8 and 251 | c2 = w1 w3 w8 w9 w5, then the longest common subsequence of r_i and c1 is 252 | “w1 w2” and the longest common subsequence of r_i and c2 is “w1 w3 w5”. The 253 | union longest common subsequence of r_i, c1, and c2 is “w1 w2 w3 w5” and 254 | LCS_u(r_i, C) = 4/5. 255 | 256 | Args: 257 | evaluated_sentences: The sentences that have been picked by the summarizer 258 | reference_sentence: One of the sentences in the reference summaries 259 | 260 | Returns: 261 | float: LCS_u(r_i, C) 262 | 263 | ValueError: 264 | Raises exception if a param has len <= 0 265 | """ 266 | if len(evaluated_sentences) <= 0: 267 | raise ValueError("Collections must contain at least 1 sentence.") 268 | 269 | lcs_union = set() 270 | reference_words = _split_into_words([reference_sentence]) 271 | combined_lcs_length = 0 272 | for eval_s in evaluated_sentences: 273 | evaluated_words = _split_into_words([eval_s]) 274 | lcs = set(_recon_lcs(reference_words, evaluated_words)) 275 | combined_lcs_length += len(lcs) 276 | lcs_union = lcs_union.union(lcs) 277 | 278 | union_lcs_count = len(lcs_union) 279 | union_lcs_value = union_lcs_count / combined_lcs_length 280 | return union_lcs_value 281 | 282 | 283 | def rouge_l_summary_level(evaluated_sentences, reference_sentences): 284 | """ 285 | Computes ROUGE-L (summary level) of two text collections of sentences. 
286 | http://research.microsoft.com/en-us/um/people/cyl/download/papers/ 287 | rouge-working-note-v1.3.1.pdf 288 | 289 | Calculated according to: 290 | R_lcs = SUM(1, u)[LCS(r_i,C)]/m 291 | P_lcs = SUM(1, u)[LCS(r_i,C)]/n 292 | F_lcs = ((1 + beta^2)*R_lcs*P_lcs) / (R_lcs + (beta^2) * P_lcs) 293 | 294 | where: 295 | SUM(i,u) = SUM from i through u 296 | u = number of sentences in reference summary 297 | C = Candidate summary made up of v sentences 298 | m = number of words in reference summary 299 | n = number of words in candidate summary 300 | 301 | Args: 302 | evaluated_sentences: The sentences that have been picked by the summarizer 303 | reference_sentence: One of the sentences in the reference summaries 304 | 305 | Returns: 306 | A float: F_lcs 307 | 308 | Raises: 309 | ValueError: raises exception if a param has len <= 0 310 | """ 311 | if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0: 312 | raise ValueError("Collections must contain at least 1 sentence.") 313 | 314 | # total number of words in reference sentences 315 | m = len(_split_into_words(reference_sentences)) 316 | 317 | # total number of words in evaluated sentences 318 | n = len(_split_into_words(evaluated_sentences)) 319 | 320 | union_lcs_sum_across_all_references = 0 321 | for ref_s in reference_sentences: 322 | union_lcs_sum_across_all_references += _union_lcs(evaluated_sentences, 323 | ref_s) 324 | return _f_p_r_lcs(union_lcs_sum_across_all_references, m, n) 325 | 326 | 327 | def rouge(hypotheses, references): 328 | """Calculates average rouge scores for a list of hypotheses and 329 | references""" 330 | 331 | # Filter out hyps that are of 0 length 332 | # hyps_and_refs = zip(hypotheses, references) 333 | # hyps_and_refs = [_ for _ in hyps_and_refs if len(_[0]) > 0] 334 | # hypotheses, references = zip(*hyps_and_refs) 335 | 336 | # Calculate ROUGE-1 F1, precision, recall scores 337 | rouge_1 = [ 338 | rouge_n([hyp], [ref], 1) for hyp, ref in zip(hypotheses, references) 339 | ] 340 | rouge_1_f, rouge_1_p, rouge_1_r = map(np.mean, zip(*rouge_1)) 341 | 342 | # Calculate ROUGE-2 F1, precision, recall scores 343 | rouge_2 = [ 344 | rouge_n([hyp], [ref], 2) for hyp, ref in zip(hypotheses, references) 345 | ] 346 | rouge_2_f, rouge_2_p, rouge_2_r = map(np.mean, zip(*rouge_2)) 347 | 348 | # Calculate ROUGE-L F1, precision, recall scores 349 | rouge_l = [ 350 | rouge_l_sentence_level([hyp], [ref]) 351 | for hyp, ref in zip(hypotheses, references) 352 | ] 353 | rouge_l_f, rouge_l_p, rouge_l_r = map(np.mean, zip(*rouge_l)) 354 | 355 | return { 356 | "rouge_1/f_score": rouge_1_f, 357 | "rouge_1/r_score": rouge_1_r, 358 | "rouge_1/p_score": rouge_1_p, 359 | "rouge_2/f_score": rouge_2_f, 360 | "rouge_2/r_score": rouge_2_r, 361 | "rouge_2/p_score": rouge_2_p, 362 | "rouge_l/f_score": rouge_l_f, 363 | "rouge_l/r_score": rouge_l_r, 364 | "rouge_l/p_score": rouge_l_p, 365 | } -------------------------------------------------------------------------------- /Run_CaKe.py: -------------------------------------------------------------------------------- 1 | from MALCopyDataset import * 2 | from MAL import * 3 | from torch import optim 4 | from trainers.DefaultTrainer import * 5 | import torch.backends.cudnn as cudnn 6 | 7 | if __name__ == '__main__': 8 | 9 | cudnn.enabled = True 10 | cudnn.benchmark = True 11 | cudnn.deterministic = False 12 | print(torch.__version__) 13 | print(torch.version.cuda) 14 | print(cudnn.version()) 15 | 16 | init_seed(123456) 17 | 18 | data_path = 'dataset/holl.256/' 19 | 20 | output_path = 
'log_mal/holl.256/' 21 | 22 | 23 | src_vocab2id, src_id2vocab, src_id2freq = load_vocab(data_path + 'holl_input_output.256.vocab', t=10) 24 | tgt_vocab2id, tgt_id2vocab, tgt_id2freq = src_vocab2id, src_id2vocab, src_id2freq 25 | 26 | train_dataset = MALDataset([data_path + 'holl-train.256.json'], src_vocab2id, tgt_vocab2id) 27 | dev_dataset = MALDataset([data_path + 'holl-dev.256.json'], src_vocab2id, tgt_vocab2id) 28 | test_dataset = MALDataset([data_path + 'holl-test.256.json'], src_vocab2id, tgt_vocab2id) 29 | 30 | 31 | # env = Environment(128, 128, 256) 32 | encoder=Encoder(len(src_vocab2id), 128, 256) 33 | selector =Selector(128, 256, len(tgt_vocab2id)) 34 | generator = Generator(128, 256, len(tgt_vocab2id)) 35 | model=MAL(encoder, selector, generator, None, src_id2vocab, src_vocab2id, tgt_id2vocab, tgt_vocab2id, max_dec_len=50, beam_width=1) 36 | init_params(model) 37 | 38 | # env_optimizer = optim.Adam(filter(lambda x: x.requires_grad, env.parameters())) 39 | # selector_optimizer = optim.Adam(filter(lambda x: x.requires_grad, selector.parameters())) 40 | # generator_optimizer = optim.Adam(filter(lambda x: x.requires_grad, generator.parameters())) 41 | agent_optimizer = optim.Adam(filter(lambda x: x.requires_grad, list(selector.parameters()) + list(generator.parameters()))) 42 | model_optimizer = optim.Adam(filter(lambda x: x.requires_grad, model.parameters())) 43 | 44 | batch_size = 8 45 | # batch_size = 32 46 | 47 | trainer = DefaultTrainer(model) 48 | 49 | for i in range(100): 50 | trainer.train_epoch('mle_train', train_dataset, train_collate_fn, batch_size, i, model_optimizer) 51 | rouges = trainer.test(dev_dataset, test_collate_fn, batch_size, i, output_path=output_path) 52 | rouges = trainer.test(test_dataset, test_collate_fn, batch_size, 100+i, output_path=output_path) 53 | trainer.serialize(i, output_path=output_path) -------------------------------------------------------------------------------- /Utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import random 4 | import time 5 | import codecs 6 | from Constants import * 7 | from torch.distributions.categorical import * 8 | import torch.nn.functional as F 9 | from modules.Utils import * 10 | 11 | def get_ms(): 12 | return time.time() * 1000 13 | 14 | 15 | def init_seed(seed=None): 16 | if seed is None: 17 | seed = int(get_ms() // 1000) 18 | 19 | np.random.seed(seed) 20 | torch.manual_seed(seed) 21 | random.seed(seed) 22 | torch.cuda.manual_seed(seed) 23 | torch.cuda.manual_seed_all(seed) 24 | 25 | def importance_sampling(prob,topk): 26 | m = Categorical(logits=prob) 27 | indices = m.sample((topk,)).transpose(0,1) # batch, topk 28 | 29 | values = prob.gather(1, indices) 30 | return values, indices 31 | 32 | def sequence_mask(lengths, max_len=None): 33 | """ 34 | Creates a boolean mask from sequence lengths. 
35 | """ 36 | batch_size = lengths.numel() 37 | max_len = max_len or lengths.max() 38 | mask= (torch.arange(0, max_len) 39 | .type_as(lengths) 40 | .repeat(batch_size, 1) 41 | .lt(lengths.unsqueeze(1))) 42 | if torch.cuda.is_available(): 43 | mask=mask.cuda() 44 | return mask 45 | 46 | def start_end_mask(starts, ends, max_len): 47 | batch_size=len(starts) 48 | mask = torch.arange(1, max_len + 1) 49 | if torch.cuda.is_available(): 50 | mask = mask.cuda() 51 | mask = mask.unsqueeze(0).expand(batch_size, -1) 52 | mask1 = mask >= starts.unsqueeze(1).expand_as(mask) 53 | mask2 = mask <= ends.unsqueeze(1).expand_as(mask) 54 | mask = (mask1 * mask2) 55 | return mask 56 | 57 | 58 | def decode_to_end(model, data, max_target_length=None, schedule_rate=1, softmax=False): 59 | tgt = data['output'] 60 | batch_size = tgt.size(0) 61 | if max_target_length is None: 62 | max_target_length = tgt.size(1) 63 | 64 | encode_outputs = model.encode(data) 65 | init_decoder_states = model.init_decoder_states(data, encode_outputs) 66 | 67 | decoder_input = new_tensor([BOS] * batch_size, requires_grad=False) 68 | 69 | prob = torch.ones((batch_size,)) * schedule_rate 70 | if torch.cuda.is_available(): 71 | prob=prob.cuda() 72 | 73 | all_gen_outputs = list() 74 | all_decode_outputs = list() 75 | decoder_states = init_decoder_states 76 | 77 | for t in range(max_target_length): 78 | # decoder_outputs, decoder_states,... 79 | decode_outputs = model.decode( 80 | data, decoder_input, decoder_states, encode_outputs 81 | ) 82 | # decoder_outputs=decode_outputs[0] 83 | decoder_states = decode_outputs[1] #[sel_d ecode_output[0],gen_decode_output[0]], gen_decode_output[1] 84 | 85 | output = model.generate(data, decode_outputs, softmax=softmax) 86 | 87 | all_gen_outputs.append(output.unsqueeze(0)) 88 | all_decode_outputs.append(decode_outputs) 89 | 90 | if schedule_rate >=1: 91 | decoder_input = tgt[:, t] 92 | elif schedule_rate<=0: 93 | probs, ids = model.to_word(data, output, 1) 94 | decoder_input = model.generation_to_decoder_input(data, ids[:, 0]) 95 | else: 96 | probs, ids = model.to_word(data, output, 1) 97 | indices = model.generation_to_decoder_input(data, ids[:, 0]) 98 | 99 | draws = torch.bernoulli(prob).long() 100 | decoder_input = tgt[:, t] * draws + indices * (1 - draws) 101 | 102 | all_gen_outputs = torch.cat(all_gen_outputs, dim=0).transpose(0, 1).contiguous() 103 | 104 | return encode_outputs, init_decoder_states, all_decode_outputs, all_gen_outputs 105 | 106 | def randomk(gen_output, k=5): 107 | gen_output[:, PAD] = -float('inf') 108 | gen_output[:, BOS] = -float('inf') 109 | gen_output[:, UNK] = -float('inf') 110 | values, indices = importance_sampling(gen_output, k) 111 | # words=[[tgt_id2vocab[id.item()] for id in one] for one in indices] 112 | return values, indices 113 | 114 | def topk(gen_output, k=5): 115 | gen_output[:, PAD] = 0 116 | gen_output[:, BOS] = 0 117 | gen_output[:, UNK] = 0 118 | if k>1: 119 | values, indices = torch.topk(gen_output, k, dim=1, largest=True, 120 | sorted=True, out=None) 121 | else: 122 | values, indices = torch.max(gen_output, dim=1, keepdim=True) 123 | return values, indices 124 | 125 | 126 | def copy_topk(gen_output, vocab_map, vocab_overlap, k=5): 127 | vocab=gen_output[:, :vocab_map.size(-1)] 128 | dy_vocab=gen_output[:, vocab_map.size(-1):] 129 | 130 | vocab=vocab+torch.bmm(dy_vocab.unsqueeze(1), vocab_map).squeeze(1) 131 | dy_vocab=dy_vocab*vocab_overlap 132 | 133 | gen_output=torch.cat([vocab, dy_vocab], dim=-1) 134 | return topk(gen_output, k) 135 | 136 | # def 
copy_topk(gen_output,tgt_id2vocab,dyn_id2vocabs, topk=5): 137 | # values, indices = torch.topk(gen_output[:, :len(tgt_id2vocab)], topk + 3, dim=1, largest=True, 138 | # sorted=True, out=None) 139 | # 140 | # copy_gen_output = gen_output[:, len(tgt_id2vocab):] 141 | # 142 | # k_values = [] 143 | # k_indices = [] 144 | # words = [] 145 | # for b in range(indices.size(0)): 146 | # temp = dict() 147 | # for i in range(indices.size(1)): 148 | # id = indices[b, i].item() 149 | # if id == PAD or id == UNK or id == BOS: 150 | # continue 151 | # w = tgt_id2vocab[id] 152 | # temp[w] = (id, values[b, i].item()) 153 | # if len(temp) == topk: 154 | # break 155 | # 156 | # for i in range(copy_gen_output.size(1)): 157 | # if i >= len(dyn_id2vocabs[b]): 158 | # continue 159 | # w = dyn_id2vocabs[b][i] 160 | # if w == PAD_WORD or w == BOS_WORD: 161 | # continue 162 | # 163 | # if w not in temp: 164 | # # temp[w] = (tgt_vocab2id.get(w, UNK), copy_gen_output[b, i].item()) 165 | # temp[w] = (i + len(tgt_id2vocab), copy_gen_output[b, i].item()) 166 | # else: 167 | # temp[w] = (temp[w][0], temp[w][1] + copy_gen_output[b, i].item()) 168 | # 169 | # k_items = sorted(temp.items(), key=lambda d: d[1][1], reverse=True)[:topk] 170 | # words.append([i[0] for i in k_items]) 171 | # k_indices.append(new_tensor([[i[1][0] for i in k_items]])) 172 | # k_values.append(new_tensor([[i[1][1] for i in k_items]])) 173 | # indices = torch.cat(k_indices, dim=0) 174 | # values = torch.cat(k_values, dim=0) 175 | # 176 | # return values, indices 177 | 178 | # def copy_topk(gen_output,tgt_id2vocab,tgt_vocab2id, dyn_id2vocabs, topk=5): 179 | # with torch.no_grad(): 180 | # values, indices = torch.topk(gen_output[:, :len(tgt_id2vocab)], topk+3, dim=1, largest=True, 181 | # sorted=True, out=None) 182 | # 183 | # copy_gen_output = gen_output[:, len(tgt_id2vocab):] 184 | # 185 | # k_values = [] 186 | # k_indices = [] 187 | # words = [] 188 | # for b in range(indices.size(0)): 189 | # temp = dict() 190 | # for i in range(indices.size(1)): 191 | # id = indices[b, i].item() 192 | # if id==PAD or id==UNK or id==BOS: 193 | # continue 194 | # w = tgt_id2vocab[id] 195 | # temp[w] = (id, values[b, i].item()) 196 | # if len(temp)==topk: 197 | # break 198 | # 199 | # for i in range(copy_gen_output.size(1)): 200 | # if i>=len(dyn_id2vocabs[b]): 201 | # continue 202 | # w = dyn_id2vocabs[b][i] 203 | # if w == PAD_WORD or w==BOS_WORD: 204 | # continue 205 | # 206 | # if w not in temp: 207 | # # temp[w] = (tgt_vocab2id.get(w, UNK), copy_gen_output[b, i].item()) 208 | # temp[w] = (i+len(tgt_id2vocab), copy_gen_output[b, i].item()) 209 | # else: 210 | # temp[w] = (temp[w][0], temp[w][1] + copy_gen_output[b, i].item()) 211 | # 212 | # k_items = sorted(temp.items(), key=lambda d: d[1][1],reverse=True)[:topk] 213 | # words.append([i[0] for i in k_items]) 214 | # k_indices.append(torch.tensor([[i[1][0] for i in k_items]])) 215 | # k_values.append(torch.tensor([[i[1][1] for i in k_items]])) 216 | # indices = torch.cat(k_indices, dim=0) 217 | # values = torch.cat(k_values, dim=0) 218 | # 219 | # return words, values, indices 220 | 221 | def to_sentence(batch_indices, id2vocab): 222 | batch_size=len(batch_indices) 223 | summ=list() 224 | for i in range(batch_size): 225 | indexes=batch_indices[i] 226 | text_summ2 = [] 227 | for index in indexes: 228 | index = index.item() 229 | w = id2vocab[index] 230 | if w == BOS_WORD or w == PAD_WORD: 231 | continue 232 | if w == EOS_WORD: 233 | break 234 | text_summ2.append(w) 235 | if len(text_summ2)==0: 236 | 
text_summ2.append(UNK_WORD) 237 | summ.append(text_summ2) 238 | return summ 239 | 240 | def to_copy_sentence(data, batch_indices,tgt_id2vocab, dyn_id2vocab_map): 241 | ids=data['id'] 242 | batch_size=len(batch_indices) 243 | summ=list() 244 | for i in range(batch_size): 245 | indexes=batch_indices[i] 246 | text_summ2 = [] 247 | dyn_id2vocab=dyn_id2vocab_map[ids[i].item()] 248 | for index in indexes: 249 | index = index.item() 250 | if index < len(tgt_id2vocab): 251 | w = tgt_id2vocab[index] 252 | elif index - len(tgt_id2vocab) in dyn_id2vocab: 253 | w = dyn_id2vocab[index - len(tgt_id2vocab)] 254 | else: 255 | w = PAD_WORD 256 | 257 | if w == BOS_WORD or w == PAD_WORD: 258 | continue 259 | 260 | if w == EOS_WORD: 261 | break 262 | 263 | text_summ2.append(w) 264 | 265 | if len(text_summ2)==0: 266 | text_summ2.append(UNK_WORD) 267 | 268 | summ.append(text_summ2) 269 | return summ 270 | 271 | def position(seq): 272 | pos=new_tensor([i+1 for i in range(seq.size(1))],requires_grad=False).long() 273 | pos=pos.repeat(seq.size(0),1) 274 | pos=pos.mul(seq.ne(PAD).long()).long() 275 | return pos -------------------------------------------------------------------------------- /bleu.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright 2017 Google Inc. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """BLEU metric implementation. 16 | """ 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | from __future__ import unicode_literals 22 | 23 | import os 24 | import re 25 | import subprocess 26 | import tempfile 27 | import numpy as np 28 | 29 | from six.moves import urllib 30 | import tensorflow as tf 31 | 32 | 33 | def moses_multi_bleu(hypotheses, references, lowercase=False): 34 | """Calculate the bleu score for hypotheses and references 35 | using the MOSES ulti-bleu.perl script. 36 | 37 | Args: 38 | hypotheses: A numpy array of strings where each string is a single example. 39 | references: A numpy array of strings where each string is a single example. 40 | lowercase: If true, pass the "-lc" flag to the multi-bleu script 41 | 42 | Returns: 43 | The BLEU score as a float32 value. 
44 | """ 45 | 46 | if np.size(hypotheses) == 0: 47 | return np.float32(0.0) 48 | 49 | # Get MOSES multi-bleu script 50 | try: 51 | multi_bleu_path, _ = urllib.request.urlretrieve( 52 | "https://raw.githubusercontent.com/moses-smt/mosesdecoder/" 53 | "master/scripts/generic/multi-bleu.perl") 54 | os.chmod(multi_bleu_path, 0o755) 55 | except: #pylint: disable=W0702 56 | tf.logging.info("Unable to fetch multi-bleu.perl script, using local.") 57 | metrics_dir = os.path.dirname(os.path.realpath(__file__)) 58 | bin_dir = os.path.abspath(os.path.join(metrics_dir, "..", "..", "bin")) 59 | multi_bleu_path = os.path.join(bin_dir, "tools/multi-bleu.perl") 60 | 61 | # Dump hypotheses and references to tempfiles 62 | hypothesis_file = tempfile.NamedTemporaryFile() 63 | hypothesis_file.write("\n".join(hypotheses).encode("utf-8")) 64 | hypothesis_file.write(b"\n") 65 | hypothesis_file.flush() 66 | reference_file = tempfile.NamedTemporaryFile() 67 | reference_file.write("\n".join(references).encode("utf-8")) 68 | reference_file.write(b"\n") 69 | reference_file.flush() 70 | 71 | # Calculate BLEU using multi-bleu script 72 | with open(hypothesis_file.name, "r") as read_pred: 73 | bleu_cmd = [multi_bleu_path] 74 | if lowercase: 75 | bleu_cmd += ["-lc"] 76 | bleu_cmd += [reference_file.name] 77 | try: 78 | bleu_out = subprocess.check_output( 79 | bleu_cmd, stdin=read_pred, stderr=subprocess.STDOUT) 80 | bleu_out = bleu_out.decode("utf-8") 81 | bleu_score = re.search(r"BLEU = (.+?),", bleu_out).group(1) 82 | bleu_score = float(bleu_score) 83 | except subprocess.CalledProcessError as error: 84 | if error.output is not None: 85 | tf.logging.warning("multi-bleu.perl script returned non-zero exit code") 86 | tf.logging.warning(error.output) 87 | bleu_score = np.float32(0.0) 88 | 89 | # Close temp files 90 | hypothesis_file.close() 91 | reference_file.close() 92 | 93 | return np.float32(bleu_score) -------------------------------------------------------------------------------- /data/CopyDataset.py: -------------------------------------------------------------------------------- 1 | from Constants import * 2 | from data.Dataset import * 3 | 4 | class CopyTrainDataset(TrainDataset): 5 | 6 | def __init__(self, files,src_vocab2id,tgt_vocab2id,srcs=None,tgts=None,max_src=80, max_tgt=40,n=1E10): 7 | super(CopyTrainDataset, self).__init__(files,src_vocab2id,tgt_vocab2id,srcs=srcs,tgts=tgts,max_src=max_src,max_tgt=max_tgt,n=n) 8 | 9 | def __getitem__(self, index): 10 | src = torch.tensor([self.src_vocab2id.get(w, self.src_vocab2id.get(UNK_WORD)) for w in self.srcs[index]], 11 | requires_grad=False).long() 12 | tgt = torch.tensor([self.tgt_vocab2id.get(w, self.tgt_vocab2id.get(UNK_WORD)) for w in self.tgts[index]] + [ 13 | self.tgt_vocab2id.get(EOS_WORD)], requires_grad=False).long() 14 | dyn_vocab2id, dyn_id2vocab = build_vocab(self.srcs[index]) 15 | 16 | src_copy = torch.tensor([dyn_vocab2id.get(w, dyn_vocab2id.get(UNK_WORD)) for w in self.tgts[index]] + [ 17 | dyn_vocab2id.get(EOS_WORD)], requires_grad=False).long() 18 | src_map = build_words_vocab_map(self.srcs[index], dyn_vocab2id) 19 | 20 | vocab_map, vocab_overlap = build_vocab_vocab_map(dyn_id2vocab, self.tgt_vocab2id) 21 | 22 | return [self.ids[index],src, tgt, src_copy, src_map,dyn_id2vocab, vocab_map, vocab_overlap] 23 | 24 | 25 | def train_collate_fn(data): 26 | id,src,tgt, tgt_copy, src_map,dyn_id2vocab, vocab_map, vocab_overlap = zip(*data) 27 | 28 | dyn_id2vocab_map={} 29 | for i in range(len(id)): 30 | dyn_id2vocab_map[id[i]]=dyn_id2vocab[i] 31 | 32 | src = 
merge1D(src) 33 | tgt = merge1D(tgt) 34 | tgt_copy = merge1D(tgt_copy) 35 | src_map = merge2D(src_map) 36 | id = torch.tensor(id, requires_grad=False) 37 | vocab_map = merge2D(vocab_map) 38 | vocab_overlap = merge1D(vocab_overlap) 39 | 40 | return {'id':id, 'input':src, 'output':tgt, 'input_copy':tgt_copy, 'input_map':src_map, 'input_dyn_vocab':dyn_id2vocab_map, 'vocab_map':vocab_map, 'vocab_overlap':vocab_overlap} 41 | 42 | 43 | class CopyTestDataset(TestDataset): 44 | 45 | def __init__(self, files,src_vocab2id,tgt_vocab2id,srcs=None,tgts=None,max_src=80, n=1E10): 46 | super(CopyTestDataset, self).__init__(files,src_vocab2id,tgt_vocab2id,srcs=srcs,tgts=tgts,max_src=max_src,n=n) 47 | 48 | def __getitem__(self, index): 49 | src = torch.tensor([self.src_vocab2id.get(w, self.src_vocab2id.get(UNK_WORD)) for w in self.srcs[index]], 50 | requires_grad=False).long() 51 | dyn_vocab2id, dyn_id2vocab = build_vocab(self.srcs[index]) 52 | 53 | src_map = build_words_vocab_map(self.srcs[index], dyn_vocab2id) 54 | 55 | vocab_map, vocab_overlap = build_vocab_vocab_map(dyn_id2vocab, self.tgt_vocab2id) 56 | 57 | return [self.ids[index], src, src_map, dyn_id2vocab, vocab_map, vocab_overlap] 58 | 59 | def test_collate_fn(data): 60 | id, src, src_map,dyn_id2vocab, vocab_map, vocab_overlap = zip(*data) 61 | 62 | dyn_id2vocab_map={} 63 | for i in range(len(id)): 64 | dyn_id2vocab_map[id[i]]=dyn_id2vocab[i] 65 | 66 | src = merge1D(src) 67 | src_map = merge2D(src_map) 68 | id = torch.tensor(id, requires_grad=False) 69 | vocab_map=merge2D(vocab_map) 70 | vocab_overlap = merge1D(vocab_overlap) 71 | 72 | return {'id':id, 'input':src,'input_map':src_map, 'input_dyn_vocab':dyn_id2vocab_map, 'vocab_map':vocab_map, 'vocab_overlap':vocab_overlap} 73 | 74 | -------------------------------------------------------------------------------- /data/Dataset.py: -------------------------------------------------------------------------------- 1 | import codecs 2 | from torch.utils.data import Dataset 3 | from Constants import * 4 | from data.Utils import * 5 | import os 6 | import re 7 | 8 | class TrainDataset(Dataset): 9 | 10 | def __init__(self, files,src_vocab2id,tgt_vocab2id,srcs=None,tgts=None,max_src=80, max_tgt=40,n=1E10): 11 | super(TrainDataset, self).__init__() 12 | self.ids=list() 13 | self.srcs=list() 14 | self.tgts=list() 15 | self.src_vocab2id=src_vocab2id 16 | self.tgt_vocab2id=tgt_vocab2id 17 | self.files=files 18 | self.n=n 19 | self.max_src=max_src 20 | self.max_tgt=max_tgt 21 | if srcs is None: 22 | self.load() 23 | else: 24 | self.srcs=srcs 25 | self.tgts=tgts 26 | self.len=len(self.srcs) 27 | 28 | def load(self): 29 | id=0 30 | with codecs.open(self.files[0], encoding='utf-8') as f: 31 | for line in f: 32 | if len(self.srcs) >= self.n: 33 | break 34 | temp = line.strip('\n').strip('\r').lower().replace(' ','') 35 | words = re.split('\s', temp) 36 | self.srcs.append(words[:min(len(words), self.max_src)]) 37 | self.ids.append(id) 38 | id = id + 1 39 | with codecs.open(self.files[1], encoding='utf-8') as f: 40 | for line in f: 41 | if len(self.tgts) >= self.n: 42 | break 43 | temp = line.strip('\n').strip('\r').lower().replace(' ','') 44 | words = re.split('\s', temp) 45 | self.tgts.append(words[:min(len(words), self.max_tgt)]) 46 | self.len = len(self.srcs) 47 | print('data size: ', self.len) 48 | 49 | 50 | def __getitem__(self, index): 51 | src = torch.tensor([self.src_vocab2id.get(w, self.src_vocab2id.get(UNK_WORD)) for w in self.srcs[index]] + [ 52 | self.src_vocab2id.get(EOS_WORD)], 
requires_grad=False).long() 53 | tgt = torch.tensor([self.tgt_vocab2id.get(w, self.tgt_vocab2id.get(UNK_WORD)) for w in self.tgts[index]] + [ 54 | self.tgt_vocab2id.get(EOS_WORD)], requires_grad=False).long() 55 | return [self.ids[index], src, tgt] 56 | 57 | def __len__(self): 58 | return self.len 59 | 60 | def input(self,id): 61 | return self.srcs[id] 62 | 63 | def output(self,id): 64 | return self.tgts[id] 65 | 66 | def train_collate_fn(data): 67 | id,src,tgt = zip(*data) 68 | 69 | src = merge1D(src) 70 | tgt = merge1D(tgt) 71 | id = torch.tensor(id, requires_grad=False) 72 | 73 | return {'id':id, 'input':src, 'output':tgt} 74 | 75 | 76 | class TestDataset(Dataset): 77 | 78 | def __init__(self, files,src_vocab2id,tgt_vocab2id,srcs=None,tgts=None,max_src=80, n=1E10): 79 | super(TestDataset, self).__init__() 80 | self.ids=list() 81 | self.srcs = list() 82 | self.tgts = list() 83 | self.src_vocab2id=src_vocab2id 84 | self.tgt_vocab2id=tgt_vocab2id 85 | self.files=files 86 | self.n=n 87 | self.max_src=max_src 88 | if srcs is None: 89 | self.load() 90 | else: 91 | self.srcs=srcs 92 | self.tgts=tgts 93 | self.len=len(self.srcs) 94 | 95 | def load(self): 96 | id=0 97 | with codecs.open(self.files[0], encoding='utf-8') as f: 98 | for line in f: 99 | if len(self.srcs) >= self.n: 100 | break 101 | temp = line.strip('\n').strip('\r').lower().replace(' ','') 102 | words = re.split('\s', temp) 103 | self.srcs.append(words[:min(len(words), self.max_src)]) 104 | self.ids.append(id) 105 | id=id+1 106 | 107 | for file in self.files[1:]: 108 | temp_tgt=list() 109 | with codecs.open(file, encoding='utf-8') as f: 110 | for line in f: 111 | if len(self.tgts)>=self.n: 112 | break 113 | temp = line.strip('\n').strip('\r').lower().replace(' ','') 114 | temp_tgt.append(temp) 115 | if len(self.tgts)==0: 116 | for i in range(len(temp_tgt)): 117 | self.tgts.append([os.linesep.join([sent.strip() for sent in re.split(r'|', temp_tgt[i])]).strip('\n').strip('\r')]) 118 | else: 119 | for i in range(len(temp_tgt)): 120 | self.tgts[i].append(os.linesep.join([sent.strip() for sent in re.split(r'|', temp_tgt[i])]).strip('\n').strip('\r')) 121 | 122 | self.len=len(self.srcs) 123 | print('data size: ',self.len) 124 | 125 | def __getitem__(self, index): 126 | src = torch.tensor([self.src_vocab2id.get(w, self.src_vocab2id.get(UNK_WORD)) for w in self.srcs[index]]+ [ 127 | self.src_vocab2id.get(EOS_WORD)], requires_grad=False).long() 128 | 129 | return [self.ids[index], src] 130 | 131 | def __len__(self): 132 | return self.len 133 | 134 | def input(self,id): 135 | return self.srcs[id] 136 | 137 | def output(self,id): 138 | return self.tgts[id] 139 | 140 | def test_collate_fn(data): 141 | id, src = zip(*data) 142 | 143 | src = merge1D(src) 144 | id = torch.tensor(id, requires_grad=False) 145 | 146 | return {'id':id, 'input':src} 147 | 148 | -------------------------------------------------------------------------------- /data/Utils.py: -------------------------------------------------------------------------------- 1 | import codecs 2 | from Constants import * 3 | import torch 4 | import random 5 | import string 6 | from torch.nn.init import * 7 | 8 | class SampleCache(object): 9 | 10 | def __init__(self, size=100): 11 | super(SampleCache, self).__init__() 12 | self.cache={} 13 | self.size=size 14 | 15 | def put(self,ids,samples): 16 | ids=ids.cpu() 17 | samples=samples.cpu() 18 | for i in range(len(ids)): 19 | id=ids[i].item() 20 | if id not in self.cache: 21 | self.cache[id]=[] 22 | for sent in samples: 23 | if 
len(self.cache[id])>=self.size: 24 | j=random.randint(0,self.size+1) 25 | if j=word_len: 155 | break 156 | tmp[i][j] = char2id[ch] if ch in char2id else UNK 157 | return tmp 158 | 159 | char2id=dict(vocab2id) 160 | id2char=dict(id2vocab) 161 | def load_char(): 162 | chars=list(string.printable) 163 | for c in chars: 164 | id = len(char2id) 165 | char2id[c] = id 166 | id2char[id] = c 167 | 168 | 169 | def load_vocab(vocab_file,t=0): 170 | thisvocab2id = dict(vocab2id) 171 | thisid2vocab = dict(id2vocab) 172 | id2freq = {} 173 | 174 | sum_freq = 0 175 | with codecs.open(vocab_file, encoding='utf-8') as f: 176 | for line in f: 177 | try: 178 | name,freq = line.strip('\n').strip('\r').split('\t') 179 | except Exception: 180 | continue 181 | if int(freq)>=t: 182 | id=len(thisvocab2id) 183 | thisvocab2id[name] = id 184 | thisid2vocab[id] = name 185 | id2freq[id]=int(freq) 186 | sum_freq+=int(freq) 187 | id2freq[0] = sum_freq/len(id2freq) 188 | id2freq[1] = id2freq[0] 189 | id2freq[2] = id2freq[0] 190 | id2freq[3] = id2freq[0] 191 | 192 | print('item size: ', len(thisvocab2id)) 193 | 194 | return thisvocab2id, thisid2vocab, id2freq 195 | 196 | 197 | def load_embedding(src_vocab2id, file): 198 | model=dict() 199 | with codecs.open(file, encoding='utf-8') as f: 200 | for line in f: 201 | splitLine = line.split() 202 | word = splitLine[0] 203 | model[word] = torch.tensor([float(val) for val in splitLine[1:]]) 204 | matrix = torch.zeros((len(src_vocab2id), 100)) 205 | xavier_uniform_(matrix) 206 | for word in model: 207 | if word in src_vocab2id: 208 | matrix[src_vocab2id[word]]=model[word] 209 | return matrix -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/repozhang/bbc-pre-selection/9fb625623b84f948d48d507a11f305f336e8289f/data/__init__.py -------------------------------------------------------------------------------- /dataset/data.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/repozhang/bbc-pre-selection/9fb625623b84f948d48d507a11f305f336e8289f/dataset/data.zip -------------------------------------------------------------------------------- /modules/Attentions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import torch.nn.functional as F 5 | 6 | class BilinearAttention(nn.Module): 7 | def __init__(self, query_size, key_size, hidden_size, dropout=0.5, coverage=False, gumbel=False, temperature=100): 8 | super().__init__() 9 | self.linear_key = nn.Linear(key_size, hidden_size, bias=False) 10 | self.linear_query = nn.Linear(query_size, hidden_size, bias=True) 11 | self.v = nn.Linear(hidden_size, 1, bias=False) 12 | self.hidden_size=hidden_size 13 | # self.dropout = nn.Dropout(dropout) 14 | # self.softmax = nn.Softmax(dim=1) 15 | self.tanh = nn.Tanh() 16 | if coverage: 17 | self.linear_coverage = nn.Linear(1, hidden_size, bias=False) 18 | self.gumbel=gumbel 19 | self.temperature=temperature 20 | 21 | def score(self, query, key, query_mask=None, key_mask=None, mask=None, sum_attention=None): 22 | batch_size, key_len, key_size = key.size() 23 | batch_size, query_len, query_size = query.size() 24 | attn=self.unnormalized_score(query, key, key_mask, mask, sum_attention) 25 | 26 | # attn=self.softmax(attn.view(-1, key_len)).view(batch_size, query_len, key_len) 27 | 
if self.gumbel: 28 | if self.training: 29 | attn =F.gumbel_softmax(attn.view(attn.size(0),-1), self.temperature, hard=False).view(attn.size()) 30 | else: 31 | attn_ = torch.zeros(attn.size()) 32 | if torch.cuda.is_available(): 33 | attn_ = attn_.cuda() 34 | attn = attn_.scatter_(2, attn.argmax(dim=2).unsqueeze(2), 1) 35 | else: 36 | attn=F.softmax(attn, dim=2) 37 | if query_mask is not None: 38 | attn = attn.masked_fill(1-query_mask.unsqueeze(2).expand(batch_size, query_len, key_len), 0) 39 | # attn = self.dropout(attn) 40 | 41 | return attn 42 | 43 | 44 | def unnormalized_score(self, query, key, key_mask=None, mask=None, sum_attention=None): 45 | batch_size, key_len, key_size = key.size() 46 | batch_size, query_len, query_size = query.size() 47 | 48 | wq = self.linear_query(query.view(-1, query_size)) 49 | wq = wq.view(batch_size, query_len, 1, self.hidden_size) 50 | wq = wq.expand(batch_size, query_len, key_len, self.hidden_size) 51 | 52 | uh = self.linear_key(key.view(-1, key_size)) 53 | uh = uh.view(batch_size, 1, key_len, self.hidden_size) 54 | uh = uh.expand(batch_size, query_len, key_len, self.hidden_size) 55 | 56 | wuc = wq + uh 57 | if sum_attention is not None: 58 | batch_size, key_len=sum_attention.size() 59 | wc = self.linear_coverage(sum_attention.view(-1,1)).view(batch_size, 1, key_len, self.hidden_size) 60 | wc = wc.expand(batch_size, query_len, key_len, self.hidden_size) 61 | wuc = wuc + wc 62 | 63 | wquh = self.tanh(wuc) 64 | 65 | attn = self.v(wquh.view(-1, self.hidden_size)).view(batch_size, query_len, key_len) 66 | 67 | if key_mask is not None: 68 | attn = attn.masked_fill(1-key_mask.unsqueeze(1).expand(batch_size, query_len, key_len), -float('inf')) 69 | 70 | if mask is not None: 71 | attn = attn.masked_fill(1 - mask, -float('inf')) 72 | return attn 73 | 74 | def forward(self, query, key, value, query_mask=None, key_mask=None, mask=None, sum_attention=None): 75 | 76 | attn = self.score(query, key, query_mask=query_mask, key_mask=key_mask, mask=mask, sum_attention=sum_attention) 77 | attn_value = torch.bmm(attn,value) 78 | 79 | return attn_value, attn 80 | 81 | class ScaledDotProductAttention(nn.Module): 82 | ''' Scaled Dot-Product Attention ''' 83 | 84 | def __init__(self, temperature, dropout=0.5): 85 | super().__init__() 86 | self.temperature = temperature 87 | # self.dropout = nn.Dropout(dropout) 88 | # self.softmax = nn.Softmax(dim=1) 89 | 90 | def score(self, query, key, query_mask=None, key_mask=None, mask=None): 91 | batch_size, key_len, key_size = key.size() 92 | batch_size, query_len, query_size = query.size() 93 | attn = self.unnormalized_score(query, key, key_mask, mask) 94 | 95 | # attn = self.softmax(attn.view(-1, key_len)).view(batch_size, query_len, key_len) 96 | attn=F.softmax(attn, dim=2) 97 | if query_mask is not None: 98 | attn = attn.masked_fill(1 - query_mask.unsqueeze(2).expand(batch_size, query_len, key_len), 0) 99 | # attn = self.dropout(attn) 100 | return attn 101 | 102 | def unnormalized_score(self, query, key, key_mask=None, mask=None): 103 | batch_size, key_len, key_size = key.size() 104 | batch_size, query_len, query_size = query.size() 105 | 106 | attn = torch.bmm(query, key.transpose(1, 2)) 107 | attn = attn / self.temperature 108 | 109 | if key_mask is not None: 110 | attn = attn.masked_fill(1-key_mask.unsqueeze(1).expand(batch_size, query_len, key_len), -float('inf')) 111 | 112 | if mask is not None: 113 | attn = attn.masked_fill(1 - mask, -float('inf')) 114 | 115 | return attn 116 | 117 | def forward(self, query, key, value, 
query_mask=None, key_mask=None, mask=None): 118 | attn = self.score(query, key, query_mask=query_mask, key_mask=key_mask, mask=mask) 119 | attn_value = torch.bmm(attn, value) # attn: attention weight; value: vectors 120 | 121 | return attn_value, attn 122 | 123 | class MultiHeadAttention(nn.Module): 124 | ''' Multi-Head Attention module ''' 125 | 126 | def __init__(self, num_head, query_size, key_size, value_size, dropout=0.5, attention=None): 127 | super().__init__() 128 | 129 | self.num_head = num_head 130 | self.key_size = key_size 131 | self.value_size = value_size 132 | 133 | self.w_qs = nn.Linear(query_size, num_head * query_size) 134 | self.w_ks = nn.Linear(key_size, num_head * key_size) 135 | self.w_vs = nn.Linear(value_size, num_head * value_size) 136 | 137 | if attention is None: 138 | self.attention = ScaledDotProductAttention(temperature=np.power(key_size, 0.5)) 139 | else: 140 | self.attention =attention 141 | 142 | self.layer_norm = nn.LayerNorm(query_size) 143 | 144 | self.fc = nn.Linear(num_head * value_size, query_size) 145 | 146 | self.dropout = nn.Dropout(dropout) 147 | 148 | 149 | def forward(self, query, key, value, query_mask=None, key_mask=None, mask=None): 150 | 151 | 152 | batch_size, query_len, query_size = query.size() 153 | batch_size, key_len, key_size = key.size() 154 | batch_size, value_len, value_size = value.size() 155 | 156 | num_head= self.num_head 157 | 158 | residual = query 159 | 160 | query = self.w_qs(query).view(batch_size, query_len, num_head, query_size) 161 | key = self.w_ks(key).view(batch_size, key_len, num_head, key_size) 162 | value = self.w_vs(value).view(batch_size, value_len, num_head, value_size) 163 | 164 | query = query.permute(2, 0, 1, 3).contiguous().view(-1, query_len, query_size) 165 | key = key.permute(2, 0, 1, 3).contiguous().view(-1, key_len, key_size) 166 | value = value.permute(2, 0, 1, 3).contiguous().view(-1, value_len, value_size) 167 | 168 | if query_mask is not None: 169 | query_mask = query_mask.unsqueeze(0).expand(num_head, batch_size, query_len).contiguous().view(-1,query_len) 170 | if key_mask is not None: 171 | # print(key_mask.size(), key_len) 172 | key_mask = key_mask.unsqueeze(0).expand(num_head, batch_size, key_len).contiguous().view(-1,key_len) 173 | if mask is not None: 174 | mask = mask.unsqueeze(0).expand(num_head, batch_size, query_len, key_len).contiguous().view(-1, query_len, key_len) 175 | 176 | output, attn = self.attention(query, key, value, query_mask=query_mask, key_mask=key_mask, mask=mask) 177 | 178 | output = output.view(num_head, batch_size, query_len, value_size) 179 | output = output.permute(1, 2, 0, 3).contiguous().view(batch_size, query_len, -1) 180 | 181 | output = self.dropout(self.fc(output)) 182 | output = self.layer_norm(output + residual) 183 | 184 | return output, attn.view(num_head, batch_size, query_len, key_len) 185 | -------------------------------------------------------------------------------- /modules/Criterions.py: -------------------------------------------------------------------------------- 1 | from Constants import * 2 | import torch 3 | from modules.Attentions import * 4 | 5 | class RecallCriterion(nn.Module): 6 | def __init__(self, hidden_size): 7 | super(RecallCriterion, self).__init__() 8 | self.hidden_size=hidden_size 9 | 10 | def forward(self, y1, y2, mask1, mask2, reduction='mean'): 11 | y = torch.ones(mask2.size()) 12 | if torch.cuda.is_available(): 13 | y = y.cuda() 14 | recall=(y1*mask1.float().unsqueeze(2)).sum(dim=1, keepdim=True).expand(-1, y2.size(1), 
-1).contiguous() 15 | r_loss = F.cosine_embedding_loss(y2.view(-1, self.hidden_size), recall.view(-1, self.hidden_size), y.view(-1), reduction='none') 16 | r_loss = r_loss * mask2.float().view(-1) 17 | r_loss = r_loss.sum() / mask2.sum().float().detach() 18 | 19 | return r_loss 20 | 21 | class F1Criterion(nn.Module): 22 | def __init__(self, hidden_size): 23 | super(F1Criterion, self).__init__() 24 | self.hidden_size=hidden_size 25 | self.attn = BilinearAttention( 26 | query_size=hidden_size, key_size=hidden_size, hidden_size=hidden_size, dropout=0.5, coverage=False 27 | ) 28 | 29 | def forward(self, y1, y2, mask1, mask2, reduction='mean'): 30 | y = torch.ones(mask1.size()) 31 | if torch.cuda.is_available(): 32 | y = y.cuda() 33 | precision,_=self.attn(y1, y2, y2) 34 | p_loss=F.cosine_embedding_loss(y1.view(-1, self.hidden_size), precision.view(-1, self.hidden_size), y.view(-1), reduction='none') 35 | p_loss= p_loss * mask1.float().view(-1) 36 | p_loss= p_loss.sum() / mask1.sum().float().detach() 37 | 38 | y = torch.ones(mask2.size()) 39 | if torch.cuda.is_available(): 40 | y = y.cuda() 41 | recall, _ = self.attn(y2, y1, y1) 42 | r_loss = F.cosine_embedding_loss(y2.view(-1, self.hidden_size), recall.view(-1, self.hidden_size), y.view(-1), reduction='none') 43 | r_loss = r_loss * mask2.float().view(-1) 44 | r_loss = r_loss.sum() / mask2.sum().float().detach() 45 | 46 | return (p_loss+r_loss)/2 47 | 48 | 49 | class CopyCriterion(object): 50 | def __init__(self, tgt_vocab_size, force_copy=False, eps=1e-20): 51 | self.force_copy = force_copy 52 | self.eps = eps 53 | self.offset = tgt_vocab_size 54 | 55 | def __call__(self, gen_output, tgt, tgt_copy, reduction='mean'): 56 | copy_unk = tgt_copy.eq(UNK).float() 57 | copy_not_unk = tgt_copy.ne(UNK).float() 58 | target_unk = tgt.eq(UNK).float() 59 | target_not_unk = tgt.ne(UNK).float() 60 | target_not_pad=tgt.ne(PAD).float() 61 | 62 | # Copy probability of tokens in source 63 | if len(gen_output.size())>2: 64 | gen_output = gen_output.view(-1, gen_output.size(-1)) 65 | copy_p = gen_output.gather(1, tgt_copy.view(-1, 1) + self.offset).view(-1) 66 | # Set scores for unk to 0 and add eps 67 | copy_p = copy_p.mul(copy_not_unk.view(-1)) + self.eps 68 | # Get scores for tokens in target 69 | tgt_p = gen_output.gather(1, tgt.view(-1, 1)).view(-1) 70 | 71 | # Regular prob (no unks and unks that can't be copied) 72 | if not self.force_copy: 73 | # Add score for non-unks in target 74 | p = copy_p + tgt_p.mul(target_not_unk.view(-1)) 75 | # Add score for when word is unk in both align and tgt 76 | p = p + tgt_p.mul(copy_unk.view(-1)).mul(target_unk.view(-1)) 77 | else: 78 | # Forced copy. Add only probability for not-copied tokens 79 | p = copy_p + tgt_p.mul(copy_unk.view(-1)) 80 | 81 | p = p.log() 82 | 83 | # Drop padding. 
84 | loss = -p.mul(target_not_pad.view(-1)) 85 | if reduction=='mean': 86 | return loss.sum()/target_not_pad.sum() 87 | elif reduction=='none': 88 | return loss.view(tgt.size()) 89 | 90 | class CoverageCriterion(object): 91 | def __init__(self, alpha=1): 92 | self.alpha = alpha 93 | 94 | def __call__(self, attentions, mask): 95 | 96 | sum_attentions=[attentions[0]] 97 | for i in range(len(attentions)-1): 98 | if i==0: 99 | sum_attentions.append(attentions[0]) 100 | else: 101 | sum_attentions.append(sum_attentions[-1] + attentions[i]) 102 | 103 | loss = torch.min(torch.cat(sum_attentions, dim=0), attentions[1:]).mul(mask.view(-1,1)) 104 | 105 | loss = self.alpha*loss.sum()/loss.size(0) 106 | 107 | return loss -------------------------------------------------------------------------------- /modules/Decoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from modules.Attentions import * 4 | from modules.Utils import * 5 | from Constants import * 6 | from modules.PositionwiseFeedForward import * 7 | from Utils import * 8 | 9 | class GruDecoder(nn.Module): 10 | def __init__(self, embedding_size, hidden_size, tgt_vocab_size, embedding=None, num_layers=4, dropout=0.5): 11 | super(GruDecoder, self).__init__() 12 | 13 | # Keep for reference 14 | self.hidden_size = hidden_size 15 | self.embedding_size=embedding_size 16 | self.tgt_vocab_size = tgt_vocab_size 17 | 18 | if embedding is not None: 19 | self.embedding =embedding 20 | else: 21 | self.embedding = nn.Embedding(tgt_vocab_size, embedding_size, padding_idx=PAD) 22 | self.embedding_dropout = nn.Dropout(dropout) 23 | 24 | self.attn = BilinearAttention( 25 | query_size=hidden_size, key_size=2*hidden_size, hidden_size=hidden_size, dropout=dropout, coverage=False 26 | ) 27 | 28 | self.gru = nn.GRU(2*hidden_size+embedding_size, hidden_size, bidirectional=False, num_layers=num_layers, dropout=dropout) 29 | 30 | self.readout = nn.Linear(embedding_size + hidden_size + 2*hidden_size, hidden_size) 31 | 32 | def forward(self, tgt, state, enc_output, enc_mask=None): 33 | gru_state = state[0] 34 | # sum_attention= state[1] 35 | 36 | embedded = self.embedding(tgt) 37 | embedded = self.embedding_dropout(embedded) 38 | 39 | attn_context_1, attn=self.attn(gru_state[:,-1].unsqueeze(1), enc_output, enc_output, query_mask=None, key_mask=enc_mask) 40 | attn_context_1=attn_context_1.squeeze(1) 41 | 42 | gru_input = torch.cat((embedded, attn_context_1), dim=1) 43 | gru_output, gru_state=self.gru(gru_input.unsqueeze(0), gru_state.transpose(0,1)) 44 | # gru_output=gru_output.squeeze(0) 45 | gru_state=gru_state.transpose(0,1) 46 | 47 | attn_context, attn=self.attn(gru_state[:,-1].unsqueeze(1), enc_output, enc_output, query_mask=None, key_mask=enc_mask, sum_attention=None) 48 | attn=attn.squeeze(1) 49 | # sum_attention=sum_attention+attn 50 | attn_context = attn_context.squeeze(1) 51 | 52 | concat_output = torch.cat((embedded, gru_state[:,-1], attn_context), dim=1) 53 | 54 | feature_output=self.readout(concat_output) 55 | return feature_output, [gru_state], attn -------------------------------------------------------------------------------- /modules/Encoders.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from modules.Utils import * 3 | from Constants import * 4 | from modules.Attentions import * 5 | from Utils import * 6 | 7 | class GRUEncoder(nn.Module): 8 | def __init__(self, src_vocab_size, embedding_size, 
hidden_size, embedding_weight=None, num_layers=4, dropout=0.5): 9 | super(GRUEncoder, self).__init__() 10 | 11 | self.src_vocab_size = src_vocab_size 12 | self.embedding_size=embedding_size 13 | self.hidden_size = hidden_size 14 | self.num_layers=num_layers 15 | 16 | self.embedding = nn.Embedding(src_vocab_size, embedding_size, padding_idx=0, _weight=embedding_weight) 17 | self.embedding_dropout = nn.Dropout(dropout) 18 | self.gru = nn.GRU(embedding_size, hidden_size, num_layers=self.num_layers, bidirectional=True, dropout=dropout, batch_first=True) 19 | 20 | 21 | def forward(self, src, state=None): 22 | embedded = self.embedding(src) 23 | embedded =self.embedding_dropout(embedded) 24 | 25 | outputs, state=gru_forward(self.gru, embedded, src.ne(PAD).sum(dim=1),state) 26 | 27 | return outputs, state -------------------------------------------------------------------------------- /modules/Evaluations.py: -------------------------------------------------------------------------------- 1 | # from pythonrouge.pythonrouge import Pythonrouge 2 | import nltk.translate.bleu_score as bleu 3 | # from pyrouge import Rouge155 4 | from tempfile import * 5 | import os 6 | import codecs 7 | import logging 8 | # from pyrouge.utils import log 9 | 10 | # log.get_global_console_logger().setLevel(logging.WARNING) 11 | 12 | # def rouge(hypothesises, references): 13 | # dir=mktemp() 14 | # os.mkdir(dir) 15 | # sys_dir=os.path.join(dir, 'sys') 16 | # os.mkdir(sys_dir) 17 | # ref_dir = os.path.join(dir, 'ref') 18 | # os.mkdir(ref_dir) 19 | # 20 | # uppercase=['A','B','C','D','E','F','G','H','I','J'] 21 | # 22 | # for i in range(len(hypothesises)): 23 | # sys_file = codecs.open(os.path.join(sys_dir, 'sys.'+str(i)+'.txt'), "w", "utf-8") 24 | # sys_file.write(hypothesises[i].replace('<','___').replace('>','___')) 25 | # sys_file.close() 26 | # 27 | # for j in range(len(references[i])): 28 | # ref_file = codecs.open(os.path.join(ref_dir, 'ref.'+uppercase[j]+'.' 
+ str(i) + '.txt'), "w", "utf-8") 29 | # ref_file.write(references[i][j].replace('<','___').replace('>','___')) 30 | # ref_file.close() 31 | # 32 | # r = Rouge155() 33 | # r.system_dir = sys_dir 34 | # r.model_dir = ref_dir 35 | # r.system_filename_pattern = 'sys.(\d+).txt' 36 | # r.model_filename_pattern = 'ref.[A-Z].#ID#.txt' 37 | # 38 | # scores = r.convert_and_evaluate() 39 | # # print(scores) 40 | # scores = r.output_to_dict(scores) 41 | # return scores 42 | 43 | # def rouge(hypothesises, references): 44 | # b_systems = [] 45 | # b_references = [] 46 | # for i in range(len(hypothesises)): 47 | # hypothesis = hypothesises[i] 48 | # reference = [[r] for r in references[i]] 49 | # b_systems.append([hypothesis]) 50 | # b_references.append(reference) 51 | # 52 | # rouge = Pythonrouge(summary_file_exist=False, 53 | # summary=b_systems, reference=b_references, 54 | # n_gram=2, ROUGE_SU4=True, ROUGE_L=True, ROUGE_W_Weight=1.2, 55 | # recall_only=False, stemming=True, stopwords=False, 56 | # word_level=True, length_limit=False, length=75, 57 | # use_cf=True, cf=95, scoring_formula='average', 58 | # resampling=True, samples=1000, favor=True, p=0.5) 59 | # scores = rouge.calc_score() 60 | # 61 | # print(scores) 62 | # return scores 63 | 64 | def distinct(self, hypothesises): 65 | scores = dict() 66 | unigram = set() 67 | unigram_count = 0 68 | bigram = set() 69 | bigram_count = 0 70 | for hypothesis in hypothesises: 71 | words = hypothesis.split(' ') 72 | unigram_count += len(words) 73 | for i in range(len(words)): 74 | unigram.add(words[i]) 75 | if i < len(words) - 1: 76 | bigram.add(words[i] + ' ' + words[i + 1]) 77 | bigram_count += 1 78 | scores['unigram'] = len(unigram) 79 | scores['unigram_count'] = unigram_count 80 | scores['distinct-1'] = len(unigram) / unigram_count 81 | scores['bigram'] = len(bigram) 82 | scores['bigram_count'] = bigram_count 83 | scores['distinct-2'] = len(bigram) / bigram_count 84 | # print(scores) 85 | return scores 86 | 87 | def perplexity(hypothesises, probabilities): 88 | scores=dict() 89 | avg_perplexity=0 90 | for i in range(len(hypothesises)): 91 | perplexity = 1 92 | N = 0 93 | for j in len(hypothesises[i]): 94 | N += 1 95 | perplexity = perplexity * (1/probabilities[i,j].item()) 96 | perplexity = pow(perplexity, 1/float(N)) 97 | avg_perplexity+=perplexity 98 | scores['perplexity'] =avg_perplexity/len(hypothesises) 99 | # print(scores) 100 | return scores 101 | 102 | def sentence_bleu(hypothesises, references): 103 | scores = dict() 104 | avg_bleu=0 105 | for i in range(len(hypothesises)): 106 | avg_bleu+=bleu.sentence_bleu([r.split(' ') for r in references[i]], hypothesises[i].split(' ')) 107 | scores['sentence_bleu'] = avg_bleu/len(hypothesises) 108 | # print(scores) 109 | return scores 110 | 111 | def corpus_bleu(hypothesises, references): 112 | scores = dict() 113 | b_systems = [] 114 | b_references = [] 115 | for i in range(len(hypothesises)): 116 | b_systems.append(hypothesises[i].split(' ')) 117 | b_references.append([r.split(' ') for r in references[i]]) 118 | scores['corpus_bleu'] = bleu.corpus_bleu(b_references, b_systems) 119 | # print(scores) 120 | return scores 121 | 122 | -------------------------------------------------------------------------------- /modules/Generations.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | import torch 3 | import torch.nn.functional as F 4 | import math 5 | import random 6 | from data.Utils import * 7 | from torch.distributions.categorical import * 8 | 
from Constants import * 9 | from modules.Utils import * 10 | 11 | def sample(model, data, max_len=20): 12 | batch_size = data['id'].size(0) 13 | 14 | encode_outputs = model.encode(data) 15 | 16 | init_decoder_states = model.init_decoder_states(data, encode_outputs) 17 | 18 | init_decoder_input = new_tensor([BOS] * batch_size, requires_grad=False) 19 | 20 | indices = list() 21 | end = new_tensor([0] * batch_size).long() == 1 22 | 23 | decoder_input = init_decoder_input 24 | decoder_states = init_decoder_states 25 | all_gen_outputs = list() 26 | all_decode_outputs = list() 27 | 28 | # ranp=random.randint(0, max_len-1) 29 | for t in range(max_len): 30 | decode_outputs = model.decode( 31 | data, decoder_input, decoder_states, encode_outputs 32 | ) 33 | print(decode_outputs[2]) 34 | 35 | gen_output = model.generate(data, decode_outputs, softmax=True) 36 | 37 | all_gen_outputs.append(gen_output.unsqueeze(0)) 38 | all_decode_outputs.append(decode_outputs) 39 | 40 | probs, ids = model.to_word(data, F.softmax(gen_output, dim=1), 1, sampling=True) 41 | print(probs, ids) 42 | # if random.uniform(0,1)>0.9: 43 | # probs, ids = model.to_word(data, F.softmax(gen_output, dim=1), 1, sampling=True) 44 | # else: 45 | # probs, ids = model.to_word(data, F.softmax(gen_output, dim=1), 1, sampling=False) 46 | 47 | decoder_states = decode_outputs[1] 48 | 49 | indice = ids[:, 0] 50 | this_end = indice == EOS 51 | if t == 0: 52 | indice.masked_fill_(this_end, UNK) 53 | elif t==max_len-1: 54 | indice[:]=EOS 55 | indice.masked_fill_(end, PAD) 56 | else: 57 | indice.masked_fill_(end, PAD) 58 | indices.append(indice.unsqueeze(1)) 59 | end = end | this_end 60 | 61 | decoder_input = model.generation_to_decoder_input(data, indice) 62 | 63 | all_gen_outputs = torch.cat(all_gen_outputs, dim=0).transpose(0, 1).contiguous() 64 | 65 | return torch.cat(indices, dim=1), encode_outputs, init_decoder_states, all_decode_outputs, all_gen_outputs 66 | 67 | 68 | def greedy(model,data,max_len=20): 69 | batch_size=data['id'].size(0) 70 | 71 | encode_outputs= model.encode(data) 72 | 73 | decoder_states = model.init_decoder_states(data, encode_outputs) 74 | 75 | decoder_input = new_tensor([BOS] * batch_size, requires_grad=False) 76 | 77 | greedy_indices=list() 78 | greedy_end = new_tensor([0] * batch_size).long() == 1 79 | for t in range(max_len): 80 | decode_outputs = model.decode( 81 | data, decoder_input, decoder_states, encode_outputs 82 | ) 83 | 84 | gen_output=model.generate(data, decode_outputs, softmax=True) 85 | 86 | probs, ids=model.to_word(data, gen_output, 1) 87 | 88 | decoder_states = decode_outputs[1] 89 | 90 | greedy_indice = ids[:,0] 91 | greedy_this_end = greedy_indice == EOS 92 | if t == 0: 93 | greedy_indice.masked_fill_(greedy_this_end, UNK) 94 | else: 95 | greedy_indice.masked_fill_(greedy_end, PAD) 96 | greedy_indices.append(greedy_indice.unsqueeze(1)) 97 | greedy_end = greedy_end | greedy_this_end 98 | 99 | decoder_input = model.generation_to_decoder_input(data, greedy_indice) 100 | 101 | greedy_indice=torch.cat(greedy_indices,dim=1) 102 | return greedy_indice 103 | 104 | 105 | def beam(model,data,max_len=20,width=5): 106 | batch_size = data['id'].size(0) 107 | 108 | encode_outputs = model.encode(data) 109 | decoder_states = model.init_decoder_states(data, encode_outputs) 110 | 111 | num_states=len(decoder_states) 112 | num_encodes=len(encode_outputs) 113 | 114 | next_fringe = [] 115 | results = dict() 116 | for i in range(batch_size): 117 | next_fringe += [Node(parent=None, state=[s[i].unsqueeze(0) for s in 
decoder_states], word=BOS_WORD, value=BOS, cost=0.0, encode_outputs=[o[i].unsqueeze(0) for o in encode_outputs], data=get_data(i,data), batch_id=i)] 118 | results[i] = [] 119 | 120 | for l in range(max_len+1): 121 | fringe = [] 122 | for n in next_fringe: 123 | if n.value == EOS or l == max_len: 124 | results[n.batch_id].append(n) 125 | else: 126 | fringe.append(n) 127 | 128 | if len(fringe) == 0: 129 | break 130 | 131 | data=concat_data([n.data for n in fringe]) 132 | 133 | decoder_input= new_tensor([n.value for n in fringe], requires_grad=False) 134 | decoder_input = model.generation_to_decoder_input(data, decoder_input) 135 | 136 | decoder_states=[] 137 | for i in range(num_states): 138 | decoder_states+=[torch.cat([n.state[i] for n in fringe],dim=0)] 139 | 140 | encode_outputs=[] 141 | for i in range(num_encodes): 142 | encode_outputs+=[torch.cat([n.encode_outputs[i] for n in fringe],dim=0)] 143 | 144 | decode_outputs = model.decode( 145 | data, decoder_input, decoder_states, encode_outputs 146 | ) 147 | decoder_states = decode_outputs[1] 148 | 149 | gen_output = model.generate(data, decode_outputs, softmax=True) 150 | 151 | probs, ids = model.to_word(data, gen_output, width) 152 | 153 | next_fringe_dict = dict() 154 | for i in range(batch_size): 155 | next_fringe_dict[i] = [] 156 | 157 | for i in range(len(fringe)): 158 | n = fringe[i] 159 | state_n = [s[i].unsqueeze(0) for s in decoder_states] 160 | 161 | for j in range(width): 162 | loss = -math.log(probs[i,j].item() + 1e-10) 163 | 164 | n_new = Node(parent=n, state=state_n, word=None, value=ids[i,j].item(), cost=loss, 165 | encode_outputs=n.encode_outputs, 166 | data=n.data, batch_id=n.batch_id) 167 | 168 | next_fringe_dict[n_new.batch_id].append(n_new) 169 | 170 | next_fringe = [] 171 | for i in range(batch_size): 172 | next_fringe += sorted(next_fringe_dict[i], key=lambda n: n.cum_cost / n.length)[:width] 173 | 174 | outputs = [] 175 | for i in range(batch_size): 176 | results[i].sort(key=lambda n: n.cum_cost / n.length) 177 | outputs.append(results[i][0])# currently only select the first one 178 | 179 | # sents=[node.to_sequence_of_words()[1:-1] for node in outputs] 180 | indices=merge1D([new_tensor(node.to_sequence_of_values()[1:]) for node in outputs]) 181 | 182 | return indices 183 | 184 | 185 | 186 | def transformer_greedy(model,data,max_len=20): 187 | batch_size = data['id'].size(0) 188 | 189 | encode_outputs= model.encode(data) 190 | 191 | decoder_input = new_tensor([BOS] * batch_size).view(batch_size, 1) 192 | 193 | greedy_indices=list() 194 | greedy_indices.append(decoder_input) 195 | greedy_end = new_tensor([0] * batch_size).long() == 1 196 | for t in range(max_len): 197 | decoder_input = torch.cat(greedy_indices, dim=1) 198 | decode_outputs = model.decode( 199 | data, decoder_input, None, encode_outputs 200 | ) 201 | 202 | batch_size, tgt_len, hidden_size=decode_outputs.size() 203 | decode_outputs = decode_outputs.view(batch_size, tgt_len, -1)[:, -1] 204 | gen_output=model.generate(data, decode_outputs, softmax=True) 205 | 206 | probs, ids=model.to_word(data, gen_output, 5) 207 | 208 | greedy_indice = ids[:,2] 209 | greedy_this_end = greedy_indice == EOS 210 | if t == 0: 211 | greedy_indice.masked_fill_(greedy_this_end, UNK) 212 | else: 213 | greedy_indice.masked_fill_(greedy_end, PAD) 214 | greedy_indices.append(greedy_indice.unsqueeze(1)) 215 | greedy_end = greedy_end | greedy_this_end 216 | 217 | greedy_indice=torch.cat(greedy_indices,dim=1) 218 | return greedy_indice 219 | 220 | 221 | def 
transformer_beam(model,data,max_len=20,width=5): 222 | batch_size = data['id'].size(0) 223 | 224 | encode_outputs = model.encode(data) 225 | 226 | num_encodes=len(encode_outputs) 227 | 228 | next_fringe = [] 229 | results = dict() 230 | for i in range(batch_size): 231 | next_fringe += [Node(parent=None, state=None, word=BOS_WORD, value=BOS, cost=0.0, encode_outputs=[o[i].unsqueeze(0) for o in encode_outputs], data=get_data(i,data), batch_id=i)] 232 | results[i] = [] 233 | 234 | for l in range(max_len+1): 235 | fringe = [] 236 | for n in next_fringe: 237 | if n.value == EOS or l == max_len: 238 | results[n.batch_id].append(n) 239 | else: 240 | fringe.append(n) 241 | 242 | if len(fringe) == 0: 243 | break 244 | 245 | decoder_input = merge1D([new_tensor(node.to_sequence_of_values()) for node in fringe]) 246 | decoder_input = model.generation_to_decoder_input(data, decoder_input) 247 | 248 | encode_outputs=[] 249 | for i in range(num_encodes): 250 | encode_outputs+=[torch.cat([n.encode_outputs[i] for n in fringe],dim=0)] 251 | 252 | data=concat_data([n.data for n in fringe]) 253 | 254 | decode_outputs = model.decode( 255 | data, decoder_input, None, encode_outputs 256 | ) 257 | 258 | this_batch_size, tgt_len, hidden_size = decode_outputs.size() 259 | lengths = decoder_input.ne(PAD).sum(dim=1).long() - 1 260 | last_decode_output = list() 261 | for i in range(this_batch_size): 262 | last_decode_output.append(decode_outputs[i, lengths[i].item()].unsqueeze(0)) 263 | decode_outputs = torch.cat(last_decode_output, dim=0) 264 | 265 | gen_output = model.generate(data, decode_outputs, softmax=True) 266 | 267 | probs, ids = model.to_word(data, gen_output, width) 268 | 269 | next_fringe_dict = dict() 270 | for i in range(batch_size): 271 | next_fringe_dict[i] = [] 272 | 273 | for i in range(len(fringe)): 274 | n = fringe[i] 275 | 276 | for j in range(width): 277 | loss = -math.log(probs[i,j].item() + 1e-10) 278 | 279 | n_new = Node(parent=n, state=None, word=None, value=ids[i,j].item(), cost=loss, 280 | encode_outputs=n.encode_outputs, 281 | data=n.data, batch_id=n.batch_id) 282 | 283 | next_fringe_dict[n_new.batch_id].append(n_new) 284 | 285 | next_fringe = [] 286 | for i in range(batch_size): 287 | next_fringe += sorted(next_fringe_dict[i], key=lambda n: n.cum_cost / n.length)[:width] 288 | 289 | outputs = [] 290 | for i in range(batch_size): 291 | results[i].sort(key=lambda n: n.cum_cost / n.length) 292 | outputs.append(results[i][0])# currently only select the first one 293 | 294 | # sents=[node.to_sequence_of_words()[1:-1] for node in outputs] 295 | indices=merge1D([new_tensor(node.to_sequence_of_values()[1:]) for node in outputs]) 296 | 297 | return indices 298 | 299 | class Node(object): 300 | def __init__(self, parent, state, word, value, cost, encode_outputs, data, batch_id=None): 301 | super(Node, self).__init__() 302 | self.word=word 303 | self.value = value 304 | self.parent = parent # parent Node, None for root 305 | self.state = state 306 | self.cum_cost = parent.cum_cost + cost if parent else cost # e.g. -log(p) of sequence up to current node (including) 307 | self.length = 1 if parent is None else parent.length + 1 308 | self.encode_outputs = encode_outputs # can hold, for example, attention weights 309 | self._sequence = None 310 | self.batch_id=batch_id 311 | self.data=data 312 | 313 | def to_sequence(self): 314 | # Return sequence of nodes from root to current node. 
315 | if not self._sequence: 316 | self._sequence = [] 317 | current_node = self 318 | while current_node: 319 | self._sequence.insert(0, current_node) 320 | current_node = current_node.parent 321 | return self._sequence 322 | 323 | def to_sequence_of_values(self): 324 | return [s.value for s in self.to_sequence()] 325 | 326 | def to_sequence_of_words(self): 327 | return [s.word for s in self.to_sequence()] 328 | -------------------------------------------------------------------------------- /modules/Generators.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | import torch.cuda 5 | from Constants import * 6 | 7 | class LinearGenerator(nn.Module): 8 | def __init__(self, feature_size, tgt_vocab_size, logit_scale=1, weight=None): 9 | super(LinearGenerator, self).__init__() 10 | self.linear = nn.Linear(feature_size, tgt_vocab_size) 11 | self.logit_scale=logit_scale 12 | if weight is not None: 13 | self.linear.weight =weight 14 | 15 | def forward(self, feature, softmax=True): 16 | logits = self.linear(feature) * self.logit_scale 17 | if softmax: 18 | logits=F.softmax(logits,dim=1) 19 | return logits 20 | 21 | class CopyGenerator(nn.Module): 22 | def __init__(self, feature_size, tgt_vocab_size, logit_scale=1, weight=None): 23 | super(CopyGenerator, self).__init__() 24 | self.linear = nn.Linear(feature_size, tgt_vocab_size) 25 | self.linear_copy = nn.Linear(feature_size, 1) 26 | self.logit_scale = logit_scale 27 | if weight is not None: 28 | self.linear.weight = weight 29 | 30 | def forward(self, feature, attention, src_map): 31 | """ 32 | attention: [batch, source_len] batch can be batch*target_len 33 | src_map: [batch, source_len, copy_vocab_size] value: 0 or 1 34 | """ 35 | # CHECKS 36 | batch, _ = feature.size() 37 | batch, slen = attention.size() 38 | batch,slen,cvocab = src_map.size() 39 | 40 | # Original probabilities. 41 | logits = self.linear(feature) * self.logit_scale 42 | logits[:, PAD] = -float('inf') 43 | logits = F.softmax(logits, dim=1) 44 | 45 | # Probability of copying p(z=1) batch. 
46 | p_copy = torch.sigmoid(self.linear_copy(feature)) 47 | # Probibility of not copying: p_{word}(w) * (1 - p(z)) 48 | out_prob = torch.mul(logits, 1 - p_copy.expand_as(logits)) 49 | # mul_attn = torch.mul(attention, p_copy.expand_as(attention)) 50 | # copy_prob = torch.bmm(mul_attn.view(-1, batch, slen) 51 | # .transpose(0, 1), 52 | # src_map.float()).squeeze(1) 53 | 54 | copy_prob = torch.bmm(attention.view(-1, batch, slen) 55 | .transpose(0, 1), 56 | src_map.float()).squeeze(1) 57 | copy_prob = torch.mul(copy_prob, p_copy.expand_as(copy_prob)) 58 | 59 | return torch.cat([out_prob, copy_prob], 1) 60 | # return out_prob, copy_prob 61 | 62 | -------------------------------------------------------------------------------- /modules/Utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from Constants import * 4 | from torch.nn.utils.rnn import pack_padded_sequence, PackedSequence 5 | 6 | 7 | def new_tensor(array, requires_grad=False): 8 | tensor=torch.tensor(array, requires_grad=requires_grad) 9 | if torch.cuda.is_available(): 10 | tensor=tensor.cuda() 11 | return tensor 12 | 13 | # def hotfix_pack_padded_sequence(input, lengths, batch_first=True): 14 | # lengths = torch.as_tensor(lengths, dtype=torch.int64) 15 | # lengths = lengths.cpu() 16 | # return PackedSequence(torch._C._VariableFunctions._pack_padded_sequence(input, lengths, batch_first)) 17 | 18 | def gru_forward(gru, input, lengths, state=None, batch_first=True): 19 | input_lengths, perm = torch.sort(lengths, descending=True) 20 | 21 | input = input[perm] 22 | if state is not None: 23 | state = state[perm].transpose(0, 1).contiguous() 24 | 25 | total_length=input.size(1) 26 | if not batch_first: 27 | input = input.transpose(0, 1) # B x L x N -> L x B x N 28 | packed = torch.nn.utils.rnn.pack_padded_sequence(input, input_lengths, batch_first) 29 | # packed = hotfix_pack_padded_sequence(embedded, input_lengths, batch_first) 30 | # self.gru.flatten_parameters() 31 | outputs, state = gru(packed, state) # -> L x B x N * n_directions, 1, B, N 32 | outputs, output_lengths = torch.nn.utils.rnn.pad_packed_sequence(outputs, batch_first=batch_first, total_length=total_length) # unpack (back to padded) 33 | 34 | _, perm = torch.sort(perm, descending=False) 35 | if not batch_first: 36 | outputs = outputs.transpose(0, 1) 37 | outputs=outputs[perm] 38 | state = state.transpose(0, 1)[perm] 39 | 40 | return outputs, state 41 | 42 | def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): 43 | ''' Sinusoid position encoding table ''' 44 | 45 | def cal_angle(position, hid_idx): 46 | return position / np.power(10000, 2 * (hid_idx // 2) / d_hid) 47 | 48 | def get_posi_angle_vec(position): 49 | return [cal_angle(position, hid_j) for hid_j in range(d_hid)] 50 | 51 | sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)]) 52 | 53 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 54 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 55 | 56 | if padding_idx is not None: 57 | # zero vector for padding dimension 58 | sinusoid_table[padding_idx] = 0. 59 | 60 | return torch.FloatTensor(sinusoid_table) 61 | 62 | def get_attn_key_pad_mask(seq_k, seq_q): 63 | ''' For masking out the padding part of key sequence. ''' 64 | 65 | # Expand to fit the shape of key query attention matrix. 
66 | len_q = seq_q.size(1) 67 | padding_mask = seq_k.ne(PAD) 68 | padding_mask = padding_mask.unsqueeze(1).expand(-1, len_q, -1) # b x lq x lk 69 | 70 | return padding_mask 71 | 72 | def get_subsequent_mask(seq): 73 | ''' For masking out the subsequent info. ''' 74 | 75 | sz_b, len_s = seq.size() 76 | subsequent_mask = torch.triu( 77 | torch.ones((len_s, len_s), device=seq.device, dtype=torch.uint8), diagonal=1) 78 | if torch.cuda.is_available(): 79 | subsequent_mask=subsequent_mask.cuda() 80 | subsequent_mask = subsequent_mask.unsqueeze(0).expand(sz_b, -1, -1) # b x ls x ls 81 | subsequent_mask=1 - subsequent_mask 82 | 83 | return subsequent_mask -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/repozhang/bbc-pre-selection/9fb625623b84f948d48d507a11f305f336e8289f/modules/__init__.py -------------------------------------------------------------------------------- /trainers/DefaultTrainer.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.utils.data import DataLoader 3 | import time 4 | from torch.nn.init import * 5 | from torch.optim.lr_scheduler import * 6 | from modules.Generations import * 7 | import codecs 8 | import os 9 | import sys 10 | from Constants import * 11 | from Rouge import * 12 | import json 13 | 14 | 15 | def init_params(model): 16 | for name, param in model.named_parameters(): 17 | print(name, param.size()) 18 | if param.data.dim() > 1: 19 | xavier_uniform_(param.data) 20 | 21 | class DefaultTrainer(object): 22 | def __init__(self, model): 23 | super(DefaultTrainer, self).__init__() 24 | 25 | if torch.cuda.is_available(): 26 | self.model =model.cuda() 27 | else: 28 | self.model = model 29 | self.eval_model = self.model 30 | 31 | self.distributed = False 32 | if torch.cuda.device_count()>1: 33 | self.distributed=True 34 | print("Let's use", torch.cuda.device_count(), "GPUs!") 35 | # print('GPU', torch.cuda.current_device(), 'ready') 36 | torch.distributed.init_process_group(backend='nccl') 37 | self.model = torch.nn.parallel.DistributedDataParallel(self.model, find_unused_parameters=True) 38 | 39 | def train_batch(self, epoch, data, method, optimizer): 40 | optimizer.zero_grad() 41 | loss = self.model(data, method=method) 42 | 43 | if isinstance(loss, tuple): 44 | closs = [l.mean().cpu().item() for l in loss] 45 | # loss = torch.cat([l.mean().view(1) for l in loss]).sum() 46 | loss = torch.cat(loss, dim=-1).mean() 47 | else: 48 | loss = loss.mean() 49 | closs = [loss.cpu().item()] 50 | 51 | loss.backward() 52 | 53 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5) 54 | optimizer.step() 55 | return closs 56 | 57 | def serialize(self,epoch, output_path): 58 | output_path = os.path.join(output_path, 'model/') 59 | if not os.path.exists(output_path): 60 | os.makedirs(output_path) 61 | torch.save(self.eval_model.state_dict(), os.path.join(output_path, '.'.join([str(epoch), 'pkl']))) 62 | 63 | def train_epoch(self, method, train_dataset, train_collate_fn, batch_size, epoch, optimizer): 64 | self.model.train() 65 | 66 | train_loader = torch.utils.data.DataLoader(train_dataset, collate_fn=train_collate_fn, batch_size=batch_size, 67 | shuffle=True, pin_memory=True) 68 | 69 | start_time = time.time() 70 | count_batch=0 71 | for j, data in enumerate(train_loader, 0): 72 | if torch.cuda.is_available(): 73 | data_cuda = dict() 74 | for 
key, value in data.items(): 75 | if isinstance(value, torch.Tensor): 76 | data_cuda[key] = value.cuda() 77 | else: 78 | data_cuda[key] = value 79 | data = data_cuda 80 | count_batch += 1 81 | 82 | bloss = self.train_batch(epoch, data, method=method, optimizer=optimizer) 83 | 84 | if j >= 0 and j%100==0: 85 | elapsed_time = time.time() - start_time 86 | print('Method', method, 'Epoch', epoch, 'Batch ', count_batch, 'Loss ', bloss, 'Time ', elapsed_time) 87 | sys.stdout.flush() 88 | del bloss 89 | 90 | # elapsed_time = time.time() - start_time 91 | # print(method + ' ', epoch, 'time ', elapsed_time) 92 | sys.stdout.flush() 93 | 94 | def predict(self,dataset, collate_fn, batch_size, epoch, output_path): 95 | self.eval_model.eval() 96 | 97 | with torch.no_grad(): 98 | test_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False, 99 | pin_memory=True, collate_fn=collate_fn, 100 | num_workers=0) 101 | 102 | srcs = [] 103 | systems = [] 104 | references = [] 105 | for k, data in enumerate(test_loader, 0): 106 | if torch.cuda.is_available(): 107 | data_cuda=dict() 108 | for key, value in data.items(): 109 | if isinstance(value, torch.Tensor): 110 | data_cuda[key]=value.cuda() 111 | else: 112 | data_cuda[key] = value 113 | data = data_cuda 114 | 115 | indices = self.eval_model(data, method='test') 116 | sents=self.eval_model.to_sentence(data,indices) 117 | 118 | srcs += [' '.join(dataset.input(id.item())) for id in data['id']] 119 | systems += [' '.join(s).replace(LINESEP_WORD, os.linesep).lower() for s in sents] 120 | for id in data['id']: 121 | refs=' '.join(dataset.output(id.item())) 122 | references.append(refs.lower()) 123 | 124 | output_path = os.path.join(output_path, 'result/') 125 | if not os.path.exists(output_path): 126 | os.makedirs(output_path) 127 | 128 | file = codecs.open(os.path.join(output_path, str(epoch)+'.txt'), "w", "utf-8") 129 | for i in range(len(systems)): 130 | file.write(systems[i]+ os.linesep) 131 | file.close() 132 | return systems,references,data 133 | 134 | def test(self,dataset, collate_fn, batch_size, epoch, output_path): 135 | with torch.no_grad(): 136 | systems,references,data=self.predict(dataset, collate_fn, batch_size, epoch, output_path) 137 | 138 | rouges= rouge(systems, references) 139 | scores=rouges 140 | # bleu=sentence_bleu(systems, references) 141 | # scores['sentence_bleu']=bleu['sentence_bleu'] 142 | print(scores) 143 | scores['score']=scores['rouge_l/f_score'] 144 | 145 | return scores,data -------------------------------------------------------------------------------- /trainers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/repozhang/bbc-pre-selection/9fb625623b84f948d48d507a11f305f336e8289f/trainers/__init__.py --------------------------------------------------------------------------------
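
GRUEncoder in modules/Encoders.py is the bidirectional encoder that feeds every decoder above. A minimal call sketch, assuming the repository root is on PYTHONPATH; the vocabulary size, embedding size, and token ids below are illustrative, not values from the repository:

import torch
from modules.Encoders import GRUEncoder

encoder = GRUEncoder(src_vocab_size=100, embedding_size=8, hidden_size=16)
src = torch.tensor([[5, 6, 7, 0, 0],
                    [9, 3, 2, 4, 8]])          # 0 is the PAD id; rows are padded source sequences
outputs, state = encoder(src)
print(outputs.size())  # torch.Size([2, 5, 32]) -- bidirectional GRU, 2*hidden_size features per position
print(state.size())    # torch.Size([2, 8, 16]) -- (batch, num_layers*2, hidden_size) after gru_forward reorders it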
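
The decoders in Decoders.py and modules/Decoders.py query BilinearAttention with a single decoder state against those bidirectional encoder outputs. A minimal sketch of that call pattern with illustrative tensor sizes; the padding masks that the decoders normally pass as key_mask are omitted here:

import torch
from modules.Attentions import BilinearAttention

batch, src_len, hidden = 2, 7, 16
attn = BilinearAttention(query_size=hidden, key_size=2 * hidden, hidden_size=hidden)

query = torch.randn(batch, 1, hidden)              # current decoder state, one step
enc_out = torch.randn(batch, src_len, 2 * hidden)  # bidirectional encoder outputs (keys and values)

context, weights = attn(query, enc_out, enc_out)
print(context.size())   # torch.Size([2, 1, 32]) -- attention-weighted sum of the values
print(weights.size())   # torch.Size([2, 1, 7])  -- normalized attention scores over source positions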
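
CopyGenerator in modules/Generators.py returns the static target-vocabulary distribution with the per-example copy distribution appended after it, which is why copy_topk and to_copy_sentence in Utils.py treat any index >= len(tgt_id2vocab) as a dynamic-vocabulary id. A shape-only sketch under the same PYTHONPATH assumption, with illustrative sizes:

import torch
from modules.Generators import CopyGenerator

batch, hidden, tgt_vocab, src_len, dyn_vocab = 2, 16, 50, 7, 9
generator = CopyGenerator(feature_size=hidden, tgt_vocab_size=tgt_vocab)

feature = torch.randn(batch, hidden)                           # decoder readout for one step
attention = torch.softmax(torch.randn(batch, src_len), dim=1)  # source attention weights
src_map = torch.zeros(batch, src_len, dyn_vocab)               # 0/1 map from source position to dynamic id

probs = generator(feature, attention, src_map)
print(probs.size())  # torch.Size([2, 59]): 50 static-vocabulary entries, then 9 copy entries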
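
sequence_mask in Utils.py builds the padding masks that the attention modules consume as key_mask. A small example with made-up lengths (the function moves the result to CUDA automatically when a GPU is visible):

import torch
from Utils import sequence_mask

lengths = torch.tensor([3, 1, 4])
mask = sequence_mask(lengths, max_len=5)
print(mask)  # shape (3, 5); row i has lengths[i] leading True values and False over the padding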
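
moses_multi_bleu in bleu.py shells out to the MOSES multi-bleu.perl script, so it needs Perl, TensorFlow, and either network access or a local copy of the script. A hypothetical call with made-up sentences, assuming those dependencies are in place:

import numpy as np
from bleu import moses_multi_bleu

hypotheses = np.array(["the cat sat on the mat", "knowledge grounded replies are hard"])
references = np.array(["the cat is on the mat", "knowledge grounded responses are hard"])
print(moses_multi_bleu(hypotheses, references, lowercase=True))  # corpus BLEU as np.float32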