├── README.md
├── hierarchical_generate.py
├── vocabulary-embedding.py
└── acl17_model.py

/README.md:
--------------------------------------------------------------------------------
1 | Code for the ACL'17 paper:
2 | Jiwei Tan, Xiaojun Wan and Jianguo Xiao. Abstractive Document Summarization with a Graph-Based Attentional Neural Model.
3 | 
4 | RUN: run vocabulary-embedding.py first, then hierarchical_generate.py, to produce hie-embedding.pkl; acl17_model.py then trains and evaluates the model on that file.
5 | The model is built with Theano 0.8.2 and Keras 1.0.6.
--------------------------------------------------------------------------------
/hierarchical_generate.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Thu Dec 22 09:50:30 2016
4 | 
5 | @author: tanjiwei
6 | """
7 | import cPickle
8 | import re
9 | import numpy as np
10 | 
11 | def myrouge_2(sent,ref):
12 |     n = 2
13 |     sent_tokens=sent.split()
14 |     ref_tokens=ref.split()
15 |     sent_ngrams=set([' '.join(sent_tokens[i:i+n]) for i in range(len(sent_tokens)-n)])
16 |     ref_ngrams=set([' '.join(ref_tokens[i:i+n]) for i in range(len(ref_tokens)-n)])
17 |     if '@entity 1' in sent_ngrams:
18 |         sent_ngrams.remove('@entity 1')
19 |     if '@entity 1' in ref_ngrams:
20 |         ref_ngrams.remove('@entity 1')
21 |     if len(sent_ngrams)*len(ref_ngrams)==0:
22 |         return 0.0
23 |     recall = len(sent_ngrams.intersection(ref_ngrams))/float(len(ref_ngrams))
24 |     precision = len(sent_ngrams.intersection(ref_ngrams))/float(len(sent_ngrams))
25 |     if recall==0.0 and precision==0.0:
26 |         return 0.0
27 |     fscore = 2*recall*precision/(recall+precision)
28 |     return fscore
29 | 
30 | nb_summ = 34
31 | nb_ref = 5
32 | 
33 | FN0 = 'dailymail-embedding-only'
34 | FN = 'hie-embedding'
35 | 
36 | with open('data/%s.pkl'%FN0, 'rb') as fp:
37 |     embedding, idx2word, word2idx, glove_idx2idx = cPickle.load(fp)
38 | 
39 | def filter_entity(sent):
40 |     return re.sub('@entity \d',' ',sent)
41 | 
42 | def map_index(sent):
43 |     return [word2idx[_token] for _token in sent.split()]
44 | 
45 | 
46 | def compress(_refs,_docs):
47 |     rouges = [myrouge_2(_sent,' '.join(_refs)) for _sent in _docs]
48 |     ranks = np.argsort(rouges)[::-1]
49 |     #ranks = range(nb_summ)
50 |     results = []
51 |     for i in range(len(_docs)):
52 |         if i in ranks[:nb_summ]:
53 |             results.append(_docs[i])
54 |     return results
55 | 
56 | 
57 | folder = 'dailymail'
58 | df=cPickle.load(open('neuralsum/%s/all_replaced.pkl'%folder))
59 | 
60 | highlights_train = df['highlight_training']
61 | docs_train = df['training']
62 | 
63 | highlights_valid = df['highlight_valid']
64 | docs_valid = df['valid']
65 | 
66 | highlights_test = df['highlight_test']
67 | docs_test = df['test']
68 | 
69 | X=[]
70 | Y=[]
71 | for (_refs,_docs) in zip(highlights_train+highlights_valid+highlights_test,docs_train+docs_valid+docs_test):
72 |     if len(_refs)<2:
73 |         continue
74 |     if len(_docs)>nb_summ:
75 |         _docs = compress(_refs,_docs)
76 |     X.append([map_index(_sent) for _sent in _docs[:nb_summ-1]]+[[word2idx['']]]+[[0]]*(nb_summ-len(_docs)-1))
77 |     appy = [map_index(_sent) for _sent in _refs[:nb_ref]]
78 |     if len(appy)<nb_ref:
79 |         appy = appy+[[word2idx['']]]
80 |     Y.append(appy+[[0]]*(nb_ref-len(_refs)-1))
81 | 
82 | 
83 | with open('data/%s.pkl'%FN,'wb') as fp:
84 |     cPickle.dump((X, Y, embedding, idx2word, word2idx, glove_idx2idx),fp,-1)
--------------------------------------------------------------------------------
/vocabulary-embedding.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Wed Apr 27 16:47:31
2016 4 | 5 | @author: tanjiwei 6 | """ 7 | 8 | FN = 'dailymail-embedding-only' 9 | seed = 42 10 | vocab_size = 40000 11 | embedding_dim = 100 12 | lower = False # dont lower case the text 13 | 14 | #read tokenized headlines and descriptions 15 | import cPickle as pickle 16 | import re 17 | 18 | ''' 19 | cnn=pickle.load(open('neuralsum/cnn/all.pkl')) 20 | dailymail=pickle.load(open('/neuralsum/dailymail/all.pkl')) 21 | 22 | def replace_digit(token): 23 | return re.sub('\d','#',token) 24 | 25 | def replace_entity(sent,entitydic): 26 | tokens=sent.split() 27 | replaced=[replace_digit(t) if t not in entitydic.keys() else '@entity '+str(len(entitydic[t].split()))+' '+entitydic[t].decode('utf8','ignore') for t in tokens ] 28 | return ' '.join(replaced).lower() 29 | 30 | allsents=[] 31 | for folder in [cnn,dailymail]: 32 | for dataset in ['training','validation','test']: 33 | df=folder[dataset] 34 | highlights=df['highlight'].tolist() 35 | sentences=df['sentence'].tolist() 36 | entities=df['entity'].tolist() 37 | for _highlight,_entitydic in zip(highlights,entities): 38 | allsents+=[replace_entity(_s.decode('utf8','ignore').replace('*',''),_entitydic) for _s in _highlight] 39 | for _sentence,_entitydic in zip(sentences,entities): 40 | allsents+=[replace_entity(_s[0].decode('utf8','ignore').replace('*',''),_entitydic) for _s in _sentence] 41 | writer=open('neuralsum/all_processed_sents.pkl','wb') 42 | pickle.dump(allsents,writer,-1) 43 | writer.close() 44 | ''' 45 | allsents=pickle.load(open('neuralsum/all_processed_sents.pkl')) 46 | 47 | #build vocabulary 48 | from collections import Counter 49 | from itertools import chain 50 | def get_vocab(lst): 51 | vocabcount = Counter(w for txt in lst for w in txt.split()) 52 | vocab = map(lambda x: x[0], sorted(vocabcount.items(), key=lambda x: -x[1])) 53 | return vocab, vocabcount 54 | vocab, vocabcount = get_vocab(allsents) 55 | 56 | #Index words 57 | empty = 0 # RNN mask of no data 58 | eos = 1 # end of sentence 59 | eod = 2 60 | entity_unk_0 = 3 61 | entity_unk_1 = 4 62 | entity_unk_2 = 5 63 | entity_unk_3 = 6 64 | entity_unk_4 = 7 65 | start_idx = entity_unk_4+1 # first real word 66 | def get_idx(vocab, vocabcount): 67 | word2idx = dict((word, idx+start_idx) for idx,word in enumerate(vocab)) 68 | word2idx[''] = empty 69 | word2idx[''] = eos 70 | word2idx[''] = eod 71 | word2idx[''] = entity_unk_0 72 | word2idx[''] = entity_unk_1 73 | word2idx[''] = entity_unk_2 74 | word2idx[''] = entity_unk_3 75 | word2idx[''] = entity_unk_4 76 | 77 | idx2word = dict((idx,word) for word,idx in word2idx.iteritems()) 78 | 79 | return word2idx, idx2word 80 | word2idx, idx2word = get_idx(vocab, vocabcount) 81 | 82 | #Word Embedding 83 | #read GloVe 84 | import numpy as np 85 | fname = 'glove.6B.%dd.txt'%embedding_dim 86 | glove_name = 'glove.6B/'+fname 87 | glove_n_symbols=400000 88 | glove_index_dict = {} 89 | glove_embedding_weights = np.empty((glove_n_symbols, embedding_dim)) 90 | globale_scale=.1 91 | with open(glove_name, 'r') as fp: 92 | i = 0 93 | for l in fp: 94 | l = l.strip().split() 95 | w = l[0] 96 | glove_index_dict[w] = i 97 | glove_embedding_weights[i,:] = map(float,l[1:]) 98 | i += 1 99 | glove_embedding_weights *= globale_scale 100 | glove_embedding_weights.std() 101 | for w,i in glove_index_dict.iteritems(): 102 | w = w.lower() 103 | if w not in glove_index_dict: 104 | glove_index_dict[w] = i 105 | 106 | #embedding matrix 107 | #use GloVe to initialize seperate embedding matrix for headlines and description 108 | # generate random embedding with same 
scale as glove 109 | np.random.seed(seed) 110 | shape = (vocab_size, embedding_dim) 111 | scale = glove_embedding_weights.std()*np.sqrt(12)/2 # uniform and not normal 112 | embedding = np.random.uniform(low=-scale, high=scale, size=shape) 113 | print 'random-embedding/glove scale', scale, 'std', embedding.std() 114 | 115 | # copy from glove weights of words that appear in our short vocabulary (idx2word) 116 | c = 0 117 | for i in range(vocab_size): 118 | w = idx2word[i] 119 | g = glove_index_dict.get(w, glove_index_dict.get(w.lower())) 120 | if g is None and w.startswith('#'): # glove has no hastags (I think...) 121 | w = w[1:] 122 | g = glove_index_dict.get(w, glove_index_dict.get(w.lower())) 123 | if g is not None: 124 | embedding[i,:] = glove_embedding_weights[g,:] 125 | c+=1 126 | print 'number of tokens, in small vocab, found in glove and copied to embedding', c,c/float(vocab_size) 127 | 128 | glove_thr = 0.5 129 | word2glove = {} 130 | for w in word2idx: 131 | if w in glove_index_dict: 132 | g = w 133 | elif w.lower() in glove_index_dict: 134 | g = w.lower() 135 | elif w.startswith('#') and w[1:] in glove_index_dict: 136 | g = w[1:] 137 | elif w.startswith('#') and w[1:].lower() in glove_index_dict: 138 | g = w[1:].lower() 139 | else: 140 | continue 141 | word2glove[w] = g 142 | 143 | #for every word outside the embedding matrix find the closest word inside the mebedding matrix 144 | normed_embedding = embedding/np.array([np.sqrt(np.dot(gweight,gweight)) for gweight in embedding])[:,None] 145 | 146 | nb_unknown_words = 100 147 | 148 | glove_match = [] 149 | for w,idx in word2idx.iteritems(): 150 | if idx >= vocab_size-nb_unknown_words and w.isalpha() and w in word2glove: 151 | gidx = glove_index_dict[word2glove[w]] 152 | gweight = glove_embedding_weights[gidx,:].copy() 153 | # find row in embedding that has the highest cos score with gweight 154 | gweight /= np.sqrt(np.dot(gweight,gweight)) 155 | score = np.dot(normed_embedding[:vocab_size-nb_unknown_words], gweight) 156 | while True: 157 | embedding_idx = score.argmax() 158 | s = score[embedding_idx] 159 | if s < glove_thr: 160 | break 161 | if idx2word[embedding_idx] in word2glove : 162 | glove_match.append((w, embedding_idx, s)) 163 | break 164 | score[embedding_idx] = -1 165 | glove_match.sort(key = lambda x: -x[2]) 166 | print '# of glove substitutes found', len(glove_match) 167 | for orig, sub, score in glove_match[-10:]: 168 | print score, orig,'=>', idx2word[sub] 169 | 170 | #build a lookup table of index of outside words to index of inside words 171 | glove_idx2idx = dict((word2idx[w],embedding_idx) for w, embedding_idx, _ in glove_match) 172 | 173 | #DATA 174 | Y = [[word2idx[token] for token in headline.split()] for headline in allsents] 175 | 176 | with open('neuralsum/%s.pkl'%FN,'wb') as fp: 177 | pickle.dump((embedding, idx2word, word2idx, glove_idx2idx),fp,-1) -------------------------------------------------------------------------------- /acl17_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Wed Apr 27 19:44:16 2016 4 | 5 | @author: tanjiwei 6 | 7 | 8 | """ 9 | 10 | 11 | import keras 12 | FN0 = 'hie-embedding' 13 | FN1 = 'acl17_release_dailymail' 14 | FN1 = None 15 | FN = 'train_acl17_dailymail' 16 | 17 | alpha = 0.9 18 | factor = 10000 19 | #input data (X) is made from maxlend description words followed by eos 20 | maxlend=50 # 0 - if we dont want to use description at all 21 | maxlenh=50 22 | maxlen = maxlend + maxlenh 23 
| rnn_size = 512 # must be same as 160330-word-gen 24 | 25 | maxsents = 34 26 | maxhighs = 5 27 | nb_summ = maxsents+1+maxhighs 28 | 29 | seed =42 30 | p_W, p_U, p_dense, p_emb, weight_decay = 0, 0, 0, 0, 0 31 | optimizer = 'adamax' 32 | batch_size=8 33 | 34 | nb_train_samples = 10000 35 | nb_val_samples = 1008 36 | 37 | NB_TEST = 10317+1008 38 | 39 | # read word embedding 40 | import cPickle as pickle 41 | 42 | with open('data/%s.pkl'%FN0, 'rb') as fp: 43 | X, Y, embedding, idx2word, word2idx, glove_idx2idx = pickle.load(fp) 44 | vocab_size, embedding_size = embedding.shape 45 | nb_unknown_words = 40 46 | 47 | print 'number of examples',len(X),len(Y) 48 | print 'dimension of embedding space for words',embedding_size 49 | print 'vocabulary size', vocab_size, 'the last %d words can be used as place holders for unknown/oov words'%nb_unknown_words 50 | print 'total number of different words',len(idx2word), len(word2idx) 51 | print 'number of words outside vocabulary which we can substitue using glove similarity', len(glove_idx2idx) 52 | print 'number of words that will be regarded as unknonw(unk)/out-of-vocabulary(oov)',len(idx2word)-vocab_size-len(glove_idx2idx) 53 | 54 | for i in range(nb_unknown_words): 55 | idx2word[vocab_size-1-i] = '<%d>'%i 56 | 57 | # when printing mark words outside vocabulary with ^ at their end 58 | for i in range(vocab_size-nb_unknown_words, len(idx2word)): 59 | idx2word[i] = idx2word[i]+'^' 60 | 61 | X_train = X[:-NB_TEST] 62 | Y_train = Y[:-NB_TEST] 63 | X_valid = X[-NB_TEST:-nb_val_samples] 64 | Y_valid = Y[-NB_TEST:-nb_val_samples] 65 | X_test = X[-10317:] 66 | Y_test = Y[-10317:] 67 | 68 | len(X_train), len(Y_train), len(X_test), len(Y_test) 69 | del X 70 | del Y 71 | 72 | empty = 0 73 | eos = 1 74 | eod = 2 75 | idx2word[empty] = '_' 76 | idx2word[eos] = '~' 77 | 78 | import numpy as np 79 | from keras.preprocessing import sequence 80 | from keras.utils import np_utils 81 | import random, sys 82 | 83 | def prt(label, x): 84 | print label+':', 85 | for w in x: 86 | print idx2word[w], 87 | print 88 | 89 | import numpy as np 90 | from keras.preprocessing import sequence 91 | from keras.utils import np_utils 92 | import random, sys, re 93 | from pattern.en import tokenize 94 | 95 | from keras.models import Sequential 96 | from keras.layers.core import Dense, Activation, Dropout, RepeatVector, Merge, TimeDistributedDense 97 | from keras.layers.recurrent import LSTM 98 | from keras.layers.embeddings import Embedding 99 | from keras.regularizers import l2 100 | from keras.models import Model 101 | from keras.layers import Input,TimeDistributed 102 | from keras.layers.core import Lambda,Reshape,Flatten,Masking,Permute 103 | from keras.layers import merge 104 | from keras.engine.topology import Layer 105 | from keras.optimizers import Adam, RMSprop # usually I prefer Adam but article used rmsprop 106 | import theano 107 | import theano.tensor as T 108 | 109 | # seed weight initialization 110 | random.seed(seed) 111 | np.random.seed(seed) 112 | 113 | # start with a standaed stacked LSTM 114 | regularizer = l2(weight_decay) if weight_decay else None 115 | 116 | # A special layer that reduces the input just to its headline part 117 | from keras.layers.core import Lambda 118 | import keras.backend as K 119 | 120 | class MaskLayer(Layer): 121 | def __init__(self,**kwargs): 122 | super(MaskLayer,self).__init__(**kwargs) 123 | def call(self,x,mask): 124 | return K.not_equal(x,0) 125 | def get_output_shape_for(self, input_shape): 126 | return input_shape 127 | 128 | class 
DemaskLayer(Layer): 129 | def __init__(self,**kwargs): 130 | super(DemaskLayer,self).__init__(**kwargs) 131 | def call(self,x,mask): 132 | return x 133 | def compute_mask(self, input, input_mask): 134 | return None 135 | def get_output_shape_for(self, input_shape): 136 | return input_shape 137 | 138 | class SliceLayer(Layer): 139 | def __init__(self,dim,**kwargs): 140 | super(SliceLayer,self).__init__(**kwargs) 141 | self.supports_masking=True 142 | self.dim=dim 143 | def call(self,x,mask): 144 | return x[:,:,self.dim,:] 145 | def get_output_shape_for(self, input_shape): 146 | return (input_shape[0], input_shape[1], input_shape[3]) 147 | 148 | class LeftsubLayer(Layer): 149 | def __init__(self,dim,**kwargs): 150 | super(LeftsubLayer,self).__init__(**kwargs) 151 | self.supports_masking=True 152 | self.dim=dim 153 | def call(self,x,mask): 154 | return x[:,:,:self.dim,:] 155 | def get_output_shape_for(self, input_shape): 156 | return (input_shape[0], input_shape[1], self.dim, input_shape[3]) 157 | 158 | class RightsubLayer(Layer): 159 | def __init__(self,dim,**kwargs): 160 | super(RightsubLayer,self).__init__(**kwargs) 161 | self.supports_masking=True 162 | self.dim=dim 163 | def call(self,x,mask): 164 | return x[:,:,self.dim:self.dim+maxlenh-1,:] 165 | def get_output_shape_for(self, input_shape): 166 | return (input_shape[0], input_shape[1], input_shape[2]-self.dim-1, input_shape[3]) 167 | 168 | 169 | class UpsubLayer(Layer): 170 | def __init__(self,dim,**kwargs): 171 | super(UpsubLayer,self).__init__(**kwargs) 172 | self.supports_masking=True 173 | self.dim=dim 174 | def call(self,x,mask=None): 175 | return x[:,:self.dim,:] 176 | def get_output_shape_for(self, input_shape): 177 | return (input_shape[0], self.dim, input_shape[2]) 178 | 179 | class DownsubLayer(Layer): 180 | def __init__(self,dim,**kwargs): 181 | super(DownsubLayer,self).__init__(**kwargs) 182 | self.supports_masking=True 183 | self.dim=dim 184 | def call(self,x,mask): 185 | return x[:,self.dim:,:,:] 186 | def get_output_shape_for(self, input_shape): 187 | return (input_shape[0], input_shape[1]-self.dim, input_shape[2], input_shape[3]) 188 | 189 | def page_ranking(query,candidates): 190 | reprs = K.concatenate((query[None,:],candidates),axis=0) 191 | sims = K.dot(reprs,K.transpose(reprs)) 192 | W_mask = 1-K.eye(maxsents+1) 193 | W = W_mask*sims 194 | d = (K.epsilon()+K.sum(W,axis=0))**-1 195 | D = K.eye(maxsents+1)*d 196 | P = K.dot(W,D) 197 | y = K.concatenate((K.ones(1),K.zeros(maxsents))) 198 | x_r = (1-alpha)*K.dot(T.nlinalg.matrix_inverse(K.eye(maxsents+1)-alpha*P),y) 199 | return x_r[1:] 200 | 201 | def rank_function(x): 202 | input_reprs = x[:maxsents,:] 203 | output_reprs = x[maxsents:,:] 204 | activation_energies = theano.map(lambda _x:page_ranking(_x,input_reprs),output_reprs)[0] 205 | return activation_energies 206 | 207 | class PageattLayer(Layer): 208 | def _init__(self,**kwargs): 209 | super(PageattLayer,self).__init__(**kwargs) 210 | self.supports_masking=True 211 | def call(self,x,mask): 212 | x_switched = K.switch(mask[:,:,None],x,0.0) 213 | activation_ranks = theano.map(rank_function,x_switched)[0] 214 | activation_energies = K.switch(mask[:,None,:maxsents],activation_ranks,-1e20) 215 | activation_weights = theano.map(K.softmax,activation_energies)[0] 216 | base_values = (mask*((K.sum(mask[:,:maxsents]+0.0,axis=-1))**-1)[:,None])[:,None,:maxsents] 217 | pad_weights = K.concatenate((base_values,activation_weights[:,:-1,:]),axis=1) 218 | diff_weights = activation_weights - pad_weights 219 | posi_diffs = 
K.switch(diff_weights>0,diff_weights,0.0) 220 | norm_pds = (K.sum(posi_diffs,axis=-1)+K.epsilon())**-1 221 | attentions = posi_diffs*norm_pds[:,:,None] 222 | return attentions 223 | def compute_mask(self, input, input_mask): 224 | return None 225 | 226 | def get_output_shape_for(self, input_shape): 227 | return (input_shape[0],maxhighs+1,maxsents) 228 | 229 | 230 | #Embedding Model 231 | embedding_inputs = Input(shape=(None,),dtype='int32',name='embedding_inputs') 232 | embedding_x = Embedding(vocab_size, embedding_size, 233 | W_regularizer=regularizer, dropout=p_emb, weights=[embedding], mask_zero=True, trainable=True, name='embedding_x')(embedding_inputs) 234 | 235 | embedding_model=Model(input=embedding_inputs,output=embedding_x,name='embedding_model') 236 | embedding_model.compile(loss='mse', optimizer=optimizer) 237 | 238 | #Mask Model 239 | mask_inputs = MaskLayer(name='mask_x')(embedding_inputs) 240 | #mask_inputs_model = Model(input=[embedding_inputs],output=mask_inputs) 241 | mask_repeat = RepeatVector(embedding_size,name='mask_repeat')(mask_inputs) 242 | mask_permute = Permute((2,1),name='mask_permute')(mask_repeat) 243 | mask_model = Model(input=[embedding_inputs],output=mask_permute) 244 | mask_model.compile(loss='mse',optimizer=optimizer) 245 | 246 | #Encoder Model 247 | encoder_input=Input(shape=(maxlend,embedding_size),name='encoder_input') 248 | encoder_mask=Masking(name='encoder_mask')(encoder_input) 249 | encoder_layer1=LSTM(rnn_size, return_sequences=True, # batch_norm=batch_norm, 250 | W_regularizer=regularizer, U_regularizer=regularizer, consume_less='mem', 251 | b_regularizer=regularizer, dropout_W=p_W, dropout_U=p_U, name='encoder_layer1', trainable=True 252 | )(encoder_mask) 253 | encoder_layer2=LSTM(rnn_size, return_sequences=True, # batch_norm=batch_norm, 254 | W_regularizer=regularizer, U_regularizer=regularizer, consume_less='mem', 255 | b_regularizer=regularizer, dropout_W=p_W, dropout_U=p_U, name='encoder_layer2', trainable=True 256 | )(encoder_layer1) 257 | encoder_layer3=LSTM(rnn_size, return_sequences=True, # batch_norm=batch_norm, 258 | W_regularizer=regularizer, U_regularizer=regularizer, consume_less='mem', 259 | b_regularizer=regularizer, dropout_W=p_W, dropout_U=p_U, name='encoder_layer3', trainable=True 260 | )(encoder_layer2) 261 | encoder_model=Model(input=encoder_input,output=encoder_layer3,name='encoder_model') 262 | encoder_model.compile(loss='categorical_crossentropy', optimizer=optimizer) 263 | 264 | #Summ Model 265 | summ_input=Input(shape=(nb_summ,maxlen),dtype='int32', name='summ_input') 266 | #headline_mask=HeadlineMaskLayer(dim=maxlend,name='headline_mask')(summ_input) 267 | summ_x=TimeDistributed(embedding_model,name='summ_x',trainable=True)(summ_input) 268 | summ_input_masks = TimeDistributed(mask_model,name='summ_input_masks')(summ_input) 269 | summ_x_masked = merge([summ_x,summ_input_masks],mode='mul',name='summ_x_masked') 270 | summ_x_masked_model = Model(input=[summ_input],output=summ_x_masked) 271 | 272 | #left sub embeddings to get the input words 273 | summ_leftx = LeftsubLayer(dim=maxlend,name='summ_leftx')(summ_x_masked) 274 | summ_leftx_model = Model(input=[summ_input],output=summ_leftx) 275 | 276 | #encode inputs to sentence embeddings 277 | summ_encodings=TimeDistributed(encoder_model,name='summ_encodings',trainable=True)(summ_leftx) 278 | summ_encodings_model=Model(input=summ_input,output=summ_encodings) 279 | summ_encodings_model.compile(loss='categorical_crossentropy', optimizer=optimizer) 280 | 281 | #slice to get the 
last state as the sentence embeddings 282 | summ_last=SliceLayer(dim=maxlend-1,name='summ_last')(summ_encodings) 283 | summ_last_model=Model(input=[summ_input],output=summ_last) 284 | summ_last_masked = Masking(name='summ_last_masked')(summ_last) 285 | summ_last_masked_model = Model(input=[summ_input],output=summ_last_masked) 286 | 287 | #512 dim input sentence embeddings 288 | sents_repr = UpsubLayer(maxsents,name='sents_repr')(summ_last_masked) 289 | sents_repr_model = Model(input=[summ_input],output=sents_repr) 290 | 291 | #sentence encoder to turn 512-100 and get output sentence embeddings 292 | summ_merged = LSTM(embedding_size,name='summ_merged',return_sequences=True)(summ_last_masked) 293 | summ_merged_model = Model(input=[summ_input],output=summ_merged) 294 | 295 | #get sentence-level attention weights according to the 100 dim sentence hidden vectors 296 | summ_densed = TimeDistributed(Dense(embedding_size,bias=False),name='summ_densed')(summ_merged) 297 | summ_densed_model = Model(input=[summ_input],output=summ_densed) 298 | summ_sentatt = PageattLayer(name='summ_sentatt')(summ_densed) 299 | summ_sentatt_model = Model(input=[summ_input],output=summ_sentatt) 300 | #summ_sentatt = NewattLayer(name='summ_sentatt')(summ_merged) 301 | #summ_sentatt_model = Model(input=[summ_input],output=summ_sentatt) 302 | 303 | #context vectors merged according to sentence-level attention 304 | context_vecs = merge([summ_sentatt,sents_repr],mode='dot',dot_axes=(2,1),name='context_vecs') 305 | context_vecs_model = Model(input=[summ_input],output=context_vecs) 306 | context_flatten = Flatten(name='context_flatten')(context_vecs) 307 | context_repeat = RepeatVector(maxlenh,name='context_repeat')(context_flatten) 308 | context_repeat_model = Model(input=[summ_input],output=context_repeat) 309 | context_reshape = Reshape((maxlenh,maxhighs+1,rnn_size),name='context_reshape')(context_repeat) 310 | context_reshape_model = Model(input=[summ_input],output=context_reshape) 311 | context_permute = Permute((2,1,3),name='context_permute')(context_reshape) 312 | context_permute_model = Model(input=[summ_input],output=context_permute) 313 | 314 | #expand output sentence embedding 1 new dim 315 | summ_merged_demasked = DemaskLayer(name='summ_merged_demasked')(summ_merged) 316 | summ_expanded = Reshape((nb_summ,1,embedding_size),name='summ_expanded')(summ_merged_demasked) 317 | summ_expanded_model = Model(input=[summ_input],output=summ_expanded) 318 | 319 | #select the right parts (target output) of input word-level representations 320 | refs_x = RightsubLayer(dim=maxlend,name='refs_x')(summ_x_masked) 321 | refs_x_model = Model(input=summ_input,output=refs_x) 322 | 323 | #merge the output sentence embedding with the target output words 324 | merge_x = merge([summ_expanded,refs_x],mode='concat',concat_axis=2,name='merge_x') 325 | merge_x_model = Model(input=summ_input,output=merge_x) 326 | 327 | #keep only the target output sentences 328 | down_x = DownsubLayer(dim=maxsents,name='down_x')(merge_x) 329 | down_x_model = Model(input=summ_input,output=down_x) 330 | 331 | #Choice One: An independent decoder 332 | decoder_input = Input(shape=(maxlenh,embedding_size),name='decoder_input') 333 | decoder_mask = Masking(name='decoder_mask')(decoder_input) 334 | decoder_layer1=LSTM(rnn_size, return_sequences=True, # batch_norm=batch_norm, 335 | W_regularizer=regularizer, U_regularizer=regularizer, consume_less='mem', 336 | b_regularizer=regularizer, dropout_W=p_W, dropout_U=p_U, name='decoder_layer1', trainable=True 337 | 
)(decoder_mask) 338 | decoder_layer2=LSTM(rnn_size, return_sequences=True, # batch_norm=batch_norm, 339 | W_regularizer=regularizer, U_regularizer=regularizer, consume_less='mem', 340 | b_regularizer=regularizer, dropout_W=p_W, dropout_U=p_U, name='decoder_layer2', trainable=True 341 | )(decoder_layer1) 342 | decoder_layer3=LSTM(rnn_size, return_sequences=True, # batch_norm=batch_norm, 343 | W_regularizer=regularizer, U_regularizer=regularizer, consume_less='mem', 344 | b_regularizer=regularizer, dropout_W=p_W, dropout_U=p_U, name='decoder_layer3', trainable=True 345 | )(decoder_layer2) 346 | decoder_layer_model = Model(input=decoder_input,output=decoder_layer3) 347 | 348 | #decode the summs with decoders 349 | decoded_x = TimeDistributed(decoder_layer_model,name='decoded_x')(down_x) 350 | decoded_x_model = Model(input=summ_input,output=decoded_x) 351 | 352 | #merge the decoded representations with attentioned contexts 353 | decoded_merged = merge([decoded_x,context_permute],mode='concat',concat_axis=-1,name='decoded_merged') 354 | decoded_merged_model = Model(input=[summ_input],output=decoded_merged) 355 | 356 | #high-dimensional dense model 357 | dense_input = Input(shape=(maxlend,rnn_size*2),name='dense_input') 358 | dense_output = TimeDistributed(Dense(vocab_size,activation='softmax'),name='dense_output')(dense_input) 359 | dense_model = Model(input=dense_input,output=dense_output) 360 | 361 | #use the dense model to map embeddings into hot vectors 362 | decoder_words = TimeDistributed(dense_model,name='decoder_words')(decoded_merged) 363 | decoder_words_model = Model(input=summ_input,output=decoder_words) 364 | 365 | all_flatten = Reshape(((maxhighs+1)*maxlenh,vocab_size),name='all_flatten')(decoder_words) 366 | all_flatten_model = Model(input=summ_input,output=all_flatten,name='all_flatten_model') 367 | 368 | all_flatten_model.compile(loss='categorical_crossentropy', optimizer=optimizer) 369 | 370 | def myrouge_2(sent,ref): 371 | n = 2 372 | sent_tokens=sent.split() 373 | ref_tokens=ref.split() 374 | sent_ngrams=set([' '.join(sent_tokens[i:i+n]) for i in range(len(sent_tokens)-n)]) 375 | ref_ngrams=set([' '.join(ref_tokens[i:i+n]) for i in range(len(ref_tokens)-n)]) 376 | if '@entity 1' in sent_ngrams: 377 | sent_ngrams.remove('@entity 1') 378 | if '@entity 1' in ref_ngrams: 379 | ref_ngrams.remove('@entity 1') 380 | if len(sent_ngrams)*len(ref_ngrams)==0: 381 | return 0.0 382 | recall = len(sent_ngrams.intersection(ref_ngrams))/float(len(ref_ngrams)) 383 | precision = len(sent_ngrams.intersection(ref_ngrams))/float(len(sent_ngrams)) 384 | if recall==0.0 and precision==0.0: 385 | return 0.0 386 | fscore = 2*recall*precision/(recall+precision) 387 | return fscore 388 | 389 | def lpadd(xs, tolen, eos=eos): 390 | """left (pre) pad a description to maxlend and then add eos. 391 | The eos is the input to predicting the first word in the headline 392 | """ 393 | pads = [] 394 | for x in xs: 395 | n = len(x) 396 | if n > tolen: 397 | x = x[-tolen+1:] 398 | n = tolen 399 | if sum(x)>0: 400 | pads.append([empty]*(tolen-n-1) + x + [eos]) 401 | else: 402 | pads.append([empty]*(tolen-n-1) + x + [0]) 403 | return pads 404 | 405 | def concat_output(xd_pad): 406 | results = [] 407 | for i in range(len(xd_pad)-1): 408 | results.append(xd_pad[i]+[_x for _x in xd_pad[i+1] if _x!=0]) 409 | results.append(xd_pad[-1]+[3,1]) 410 | return results 411 | 412 | def vocab_fold(xs): 413 | """convert list of word indexes that may contain words outside vocab_size to words inside. 
414 | If a word is outside, try first to use glove_idx2idx to find a similar word inside. 415 | If none exist then replace all accurancies of the same unknown word with <0>, <1>, ... 416 | """ 417 | xs = [x if x < vocab_size-nb_unknown_words else glove_idx2idx.get(x,x) for x in xs] 418 | # the more popular word is <0> and so on 419 | outside = sorted([x for x in xs if x >= vocab_size-nb_unknown_words]) 420 | # if there are more than nb_unknown_words oov words then put them all in nb_unknown_words-1 421 | outside = dict((x,vocab_size-1-min(i, nb_unknown_words-1)) for i, x in enumerate(outside)) 422 | xs = [outside.get(x,x) for x in xs] 423 | return xs 424 | 425 | def vocab_fold_list(xs): 426 | return [vocab_fold(_xs) for _xs in xs] 427 | 428 | def vocab_unfold(desc,xs): 429 | # assume desc is the unfolded version of the start of xs 430 | unfold = {} 431 | for i, unfold_idx in enumerate(desc): 432 | fold_idx = xs[i] 433 | if fold_idx >= vocab_size-nb_unknown_words: 434 | unfold[fold_idx] = unfold_idx 435 | return [unfold.get(x,x) for x in xs] 436 | 437 | def conv_seq_labels(xds, xhs): 438 | """description and hedlines are converted to padded input vectors. headlines are one-hot to label""" 439 | batch_size = len(xhs) 440 | assert len(xds) == batch_size 441 | def process_xdxh(xd,xh): 442 | concated_xd = xd+[[3]]+xh 443 | padded_xd = lpadd(concated_xd,maxlend) 444 | concated_xdxh = concat_output(padded_xd) 445 | return vocab_fold_list(concated_xdxh) 446 | x_raw = [process_xdxh(xd,xh) for xd,xh in zip(xds,xhs)] # the input does not have 2nd eos 447 | x = np.asarray([sequence.pad_sequences(_x, maxlen=maxlen, value=empty, padding='post', truncating='post') for _x in x_raw]) 448 | #x = flip_headline(x, nflips=nflips, model=model, debug=debug) 449 | 450 | def padeod_xh(xh): 451 | if [2] in xh: 452 | return xh+[[0]] 453 | else: 454 | return xh+[[2]] 455 | y = np.zeros((batch_size, maxhighs+1, maxlenh, vocab_size)) 456 | xhs_fold = [vocab_fold_list(padeod_xh(xh)) for xh in xhs] 457 | 458 | def process_xh(xh): 459 | if sum(xh)>0: 460 | xh_pad = xh + [eos] + [empty]*maxlenh # output does have a eos at end 461 | else: 462 | xh_pad = xh + [empty]*maxlenh 463 | xh_truncated = xh_pad[:maxlenh] 464 | return np_utils.to_categorical(xh_truncated, vocab_size) 465 | for i, xh in enumerate(xhs_fold): 466 | y[i,:,:,:] = np.asarray([process_xh(xh) for xh in xhs_fold[i]]) 467 | 468 | return x, y.reshape((batch_size,(maxhighs+1)*maxlenh,vocab_size)) 469 | 470 | def gen(Xd, Xh, batch_size=batch_size): 471 | while True: 472 | xds = [] 473 | xhs = [] 474 | for b in range(batch_size): 475 | t = random.randint(0,len(Xd)-1) 476 | xds.append(Xd[t]) 477 | xhs.append(Xh[t]) 478 | yield conv_seq_labels(xds, xhs) 479 | 480 | def greedysearch(Yp): 481 | samples = np.argmax(Yp,axis=-1).tolist() 482 | Ys = [[_word for _word in _sample if _word!=0] for _sample in samples] 483 | return [' '.join([idx2word[_w] for _w in _ys]) for _ys in Ys] 484 | 485 | def gensamples(gens): 486 | i = random.randint(0,len(gens)-1) 487 | print 'HEAD:\n ','\n '.join([' '.join([idx2word[w] for w in sent]) for sent in Y_test[i]]) 488 | #print '\nDESC:\n ','\n '.join([' '.join([idx2word[w] for w in sent]) for sent in X_test[i]]) 489 | print '\nGEND:',gens[i] 490 | sys.stdout.flush() 491 | 492 | 493 | def predict(samples,decode_model,dense_model,context_vec,start_vec): 494 | sample_lengths = map(len, samples) 495 | assert max(sample_lengths)=0 524 | if recall==0.0 and precision==0.0: 525 | fscore = 0.0 526 | else: 527 | fscore = 
2*recall*precision/(recall+precision) 528 | return fscore 529 | 530 | def beamsearch(predict,decode_model,dense_model,context_vec,start_vec,mask,reference,rouge_factor,history_gen): 531 | def sample(energy, n): 532 | indexs=np.argsort(energy)[:n] 533 | scores = [energy[_ind] for _ind in indexs] 534 | return indexs,scores 535 | def rerank(iniranks,scores): 536 | pairs = [(_rank,_score) for _rank,_score in zip(iniranks,scores)] 537 | sorted_pairs = sorted(pairs,key=lambda x:x[1])[:beam_size] 538 | #ranks = [s[0] for s in sorted_pairs] 539 | #scores = [s[1] for s in sorted_pairs] 540 | return sorted_pairs 541 | def rank_pair(live_pairs,dead_pairs): 542 | merge_pairs = live_pairs+dead_pairs 543 | sorted_merge = sorted(merge_pairs,key=lambda x:x[1])[:beam_size] 544 | ranks_dead = [-1-s[0] for s in sorted_merge if s[0]<0] 545 | ranks_live = [s[0] for s in sorted_merge if s[0]>=0] 546 | dead_scores = [s[1] for s in sorted_merge if s[0]<0] 547 | live_scores = [s[1] for s in sorted_merge if s[0]>=0] 548 | return ranks_dead, ranks_live, live_scores 549 | 550 | dead_k = 0 # samples that reached eos 551 | dead_samples = [] 552 | dead_scores = [] 553 | live_samples=[[]]*beam_size 554 | live_k = 1 555 | live_scores = [0] 556 | probs = predict(live_samples,decode_model,dense_model,context_vec,start_vec)[0] 557 | live_samples = sample(-probs, beam_size*100)[0][:,None].tolist() 558 | ref_tokens = [] 559 | for _ref in reference: 560 | ref_tokens += _ref 561 | gen_tokens = [] 562 | for _gen in history_gen: 563 | gen_tokens += _gen 564 | if word2idx['@entity'] in gen_tokens: 565 | gen_tokens.remove(word2idx['@entity']) 566 | #left_tokens = set(ref_tokens).difference(gen_tokens) 567 | left_tokens = set(ref_tokens) 568 | live_samples = [_sample for _sample in live_samples if _sample[0] in left_tokens] 569 | if len(live_samples)']]!=1 and [2] in live_samples: 573 | live_samples.remove([2]) 574 | live_samples.append([word2idx['@entity']]) 575 | 576 | while live_k: 577 | # for every possible live sample calc prob for every possible label 578 | probs = predict(live_samples,decode_model,dense_model,context_vec,start_vec) 579 | voc_size = probs.shape[1] 580 | # total score for every sample is sum of -log of word prb 581 | cand_scores = np.array(live_scores)[:,None] - np.log(probs+1e-20) 582 | cand_scores[:,empty] = 1e20 583 | cand_scores = cand_scores * mask[None,:] + ((1-mask)*1e20)[None,:] 584 | ''' 585 | #length control 586 | gen_len=max(map(len,live_samples)) 587 | if gen_len < 15: 588 | cand_scores[:,eos] = 1e20 589 | 590 | #prevent repeat 591 | for i in range(len(cand_scores)): 592 | for j in range(len(live_samples[i])): 593 | cand_scores[i][live_samples[i][j]] = 1e20 594 | ''' 595 | live_scores = list(cand_scores.flatten()) 596 | 597 | # find the best (lowest) scores we have from all possible dead samples and 598 | # all live samples and all possible new words added 599 | ini_ranks,ini_scores = sample(live_scores, beam_size*10) 600 | cand_samples = [live_samples[r//voc_size]+[r%voc_size] for r in ini_ranks] 601 | r_scores = [rouge_factor*(rouge_recall(history_gen+[_sample],reference)-rouge_recall(history_gen+[_sample[:-1]],reference)) for _sample in cand_samples] 602 | merge_scores = np.subtract(ini_scores,r_scores) 603 | 604 | live_pairs = rerank(ini_ranks,merge_scores) 605 | dead_pairs = [(-dind-1,dead_scores[dind]) for dind in range(len(dead_scores))] 606 | 607 | ranks_dead, ranks_live, live_scores = rank_pair(live_pairs,dead_pairs) 608 | 609 | dead_scores = [dead_scores[r] for r in ranks_dead] 610 | 
dead_samples = [dead_samples[r] for r in ranks_dead] 611 | 612 | #live_scores = [live_scores[r] for r in ranks_live] 613 | 614 | # append the new words to their appropriate live sample 615 | live_samples = [live_samples[r//voc_size]+[r%voc_size] for r in ranks_live] 616 | 617 | # live samples that should be dead are... 618 | # even if len(live_samples) == maxsample we dont want it dead because we want one 619 | # last prediction out of it to reach a headline of maxlenh 620 | zombie = [s[-1] == eos or len(s) > maxlenh-1 for s in live_samples] 621 | 622 | # add zombies to the dead 623 | dead_samples += [s for s,z in zip(live_samples,zombie) if z] 624 | dead_scores += [s for s,z in zip(live_scores,zombie) if z] 625 | dead_k = len(dead_samples) 626 | # remove zombies from the living 627 | live_samples = [s for s,z in zip(live_samples,zombie) if not z] 628 | live_scores = [s for s,z in zip(live_scores,zombie) if not z] 629 | live_k = len(live_samples) 630 | all_samples = dead_samples + live_samples 631 | all_scores = dead_scores + live_scores 632 | indexs = np.argsort(all_scores) 633 | return [all_samples[i] for i in indexs], [all_scores[i] for i in indexs] 634 | 635 | def word_mask(_X): 636 | words = set(_X.flatten()) 637 | mask = np.zeros((vocab_size,)) 638 | for _word in words: 639 | mask[_word] = 1 640 | return mask 641 | 642 | 643 | #dx dy must have 1 first dim 644 | def decoder(dx,dy,min_sents,rouge_factor,decay): 645 | dX,dY=conv_seq_labels(dx,dy) 646 | dX[:,maxsents:]=0 647 | mask = word_mask(dX) 648 | mask[word2idx['']] = 0 649 | sent_generate = [3,1] 650 | score = 0.0 651 | #reference = [[_t for _t in dX[0][:3,:maxlend][_di] if _t!=0] for _di in range(3)] 652 | #reference += [[2,1]] 653 | history_gen = [] 654 | history_att = [] 655 | for epoch in range(maxhighs+1): 656 | #reference = [[_t for _t in dX[0][epoch:epoch+1,:maxlend][0] if _t!=0]] 657 | dX[:,maxsents+epoch,maxlend-len(sent_generate):maxlend] = sent_generate 658 | if word2idx[''] in sent_generate: 659 | break 660 | if epoch > min_sents: 661 | mask[word2idx['']] = 1 662 | #mask = decay_mask(sent_generate,mask,decay) 663 | attention = summ_sentatt_model.predict(dX)[0,epoch] 664 | ori_inds = np.argsort(attention)[::-1] 665 | sort_inds = [_ind for _ind in ori_inds if attention[_ind]>0 and _ind not in history_att] 666 | if len(sort_inds) == 0: 667 | for j in range(maxsents): 668 | if j not in history_att: 669 | sort_inds += [j] 670 | #print sort_inds 671 | reference = [[_t for _t in dX[0,sort_inds[0],:maxlend] if _t!=0]] 672 | history_att.append(sort_inds[0]) 673 | if epoch > min_sents: 674 | reference += [[2,1]] 675 | context_vec = context_vecs_model.predict(dX)[0,epoch:epoch+1,:] 676 | start_vec = summ_merged_model.predict(dX)[0,maxsents+epoch:maxsents+epoch+1,:] 677 | try: 678 | sent_samples,sent_scores = beamsearch(predict,decoder_layer_model,dense_model,context_vec,start_vec,mask,reference,rouge_factor,history_gen) 679 | except: 680 | break 681 | assigned = False 682 | for i in range(len(sent_samples)): 683 | _generate = sent_samples[i] 684 | if _generate[-1] == eos: 685 | sent_generate = _generate 686 | assigned = True 687 | score += sent_scores[i] 688 | break 689 | if not assigned: 690 | sent_generate = sent_samples[0][:-1]+[1] 691 | score += sent_scores[0] 692 | history_gen.append(sent_generate) 693 | generated_tokens = [t for t in dX[:,maxsents+1:].flatten().tolist() if t!=0] 694 | return generated_tokens,score 695 | 696 | def visualize(code): 697 | return ' '.join([idx2word[w] for w in code]) 698 | 699 | def 
remove_indicate(gen): 700 | return gen.replace('^','') 701 | 702 | def remove_entity(gen): 703 | import re 704 | return re.sub('@entity \d',' ',gen) 705 | 706 | def greedy_decode(Yp): 707 | samples = np.argmax(Yp,axis=-1).tolist() 708 | Ys = [[_word for _word in _sample if _word!=0] for _sample in samples] 709 | return Ys 710 | 711 | def collect_entitys(_X,_Y): 712 | entitys = [] 713 | former_dic = {} 714 | latter_dic = {} 715 | context_dic = {} 716 | for _x in _X+_Y: 717 | for i in range(len(_x)): 718 | if _x[i]==8: 719 | number_index = _x[i+1] 720 | if number_index < vocab_size: 721 | number = int(idx2word[number_index]) 722 | current_entity = ' '.join([str(_t) for _t in _x[i+2:i+2+number]]) 723 | entitys.append(current_entity) 724 | if i>1: 725 | former_token = _x[i-1] 726 | if current_entity in former_dic: 727 | former_dic[current_entity].append(former_token) 728 | else: 729 | former_dic[current_entity] = [former_token] 730 | if i+2+number < len(_x): 731 | latter_token = _x[i+2+number] 732 | if current_entity in latter_dic: 733 | latter_dic[current_entity].append(latter_token) 734 | else: 735 | latter_dic[current_entity] = [latter_token] 736 | if i>1 and i+2+number < len(_x): 737 | context_token = [_x[i-1],_x[i+2+number]] 738 | if current_entity in context_dic: 739 | context_dic[current_entity].append(context_token) 740 | else: 741 | context_dic[current_entity] = [context_token] 742 | from collections import Counter 743 | entity_counter = Counter(entitys) 744 | indexer = 0 745 | entity_dic = {} 746 | list_entity = [] 747 | for _entity,_count in entity_counter.most_common(): 748 | entity_dic[_entity] = indexer 749 | list_entity.append([int(_w) for _w in _entity.split()]) 750 | indexer+=1 751 | return entity_dic,list_entity,former_dic,latter_dic,context_dic 752 | 753 | def entity_replace(_x,entity_dic,list_entity,former_dic,latter_dic,context_dic): 754 | replaced_list = [] 755 | jump = 0 756 | for i in range(len(_x)): 757 | if jump>0: 758 | jump -= 1 759 | continue 760 | if _x[i]!=8: #not entity, add to final list 761 | replaced_list.append(_x[i]) 762 | continue 763 | #get the entity and its context tokens 764 | if i0: 776 | former_token = _x[i-1] 777 | else: 778 | former_token = None 779 | if i+2+number < len(_x): 780 | latter_token = _x[i+2+number] 781 | else: 782 | latter_token = None 783 | if i>0 and i+2+number < len(_x): 784 | context_token = [_x[i-1],_x[i+2+number]] 785 | else: 786 | context_token = None 787 | except: #not a number token 788 | current_entity = None 789 | if i>0: 790 | former_token = _x[i-1] 791 | else: 792 | former_token = None 793 | if i0 and i1] 812 | matched_pairs = [_p for _p in target_pairs if len(set(current_tokens).intersection(set(_p)))>0] 813 | if len(matched_pairs)>0: 814 | replaced_list += matched_pairs[0] 815 | jump = 1+len(current_tokens) 816 | #print 'Case 2: replace %s into %s'%(' '.join([idx2word[_w] for _w in current_tokens]),' '.join([idx2word[_w] for _w in matched_pairs[0]])) 817 | continue 818 | #case 3: no entity match; match context 819 | continue_flag = False 820 | if context_token: 821 | for _listentity in list_entity: 822 | _key = ' '.join([str(_t) for _t in _listentity]) 823 | if context_dic.has_key(_key): 824 | if context_token in context_dic[_key]: 825 | replaced_list += _listentity 826 | #print 'Case 3: replace %s into %s'%(str(current_entity),' '.join([idx2word[_w] for _w in _listentity])) 827 | if current_entity: 828 | jump = 1+len(current_entity.split()) 829 | continue_flag = True 830 | break 831 | if continue_flag: 832 | continue 
833 | #case 4: match former or latter toekn 834 | for _listentity in list_entity: 835 | _key = ' '.join([str(_t) for _t in _listentity]) 836 | if former_dic.has_key(_key): 837 | if former_token in former_dic[_key]: 838 | replaced_list += _listentity 839 | #print 'Case 4: replace %s into %s'%(str(current_entity),' '.join([idx2word[_w] for _w in _listentity])) 840 | if current_entity: 841 | jump = 1+len(current_entity.split()) 842 | continue_flag = True 843 | break 844 | if latter_dic.has_key(_key): 845 | if latter_token in latter_dic[_key]: 846 | replaced_list += _listentity 847 | #print 'Case 4: replace %s into %s'%(str(current_entity),' '.join([idx2word[_w] for _w in _listentity])) 848 | if current_entity: 849 | jump = 1+len(current_entity.split()) 850 | continue_flag = True 851 | break 852 | if continue_flag: 853 | continue 854 | #case 5: no match at all. use the most frequent entity 855 | replaced_list += list_entity[0] 856 | #print 'Case 5: replace %s into %s'%(str(current_entity),' '.join([idx2word[_w] for _w in list_entity[0]])) 857 | if current_entity: 858 | jump = 1+len(current_entity.split()) 859 | return replaced_list 860 | 861 | def entity_process(code,_X,_Y): 862 | entity_dic,list_entity,former_dic,latter_dic,context_dic = collect_entitys(_X,_Y) 863 | replaced_list = entity_replace(code,entity_dic,list_entity,former_dic,latter_dic,context_dic) 864 | return replaced_list 865 | 866 | def evaluate(X_test,Y_test,min_sents,rouge_factor,decay): 867 | beam_gens = [] 868 | Y_descs = [' '.join([' '.join([idx2word[_w] for _w in _sent]) for _sent in _Y]) for _Y in Y_test] 869 | for _dx,_dy in zip(X_test,Y_test): 870 | try: 871 | _gen = decoder([_dx],[_dy],min_sents,rouge_factor,decay) 872 | except: 873 | _gen = decoder([_dx],[_dy],0,rouge_factor,decay) 874 | beam_gens.append(_gen) 875 | print 'Sample %d: %.4f\n%s' %(len(beam_gens),myrouge_2(visualize(_gen[0]),Y_descs[len(beam_gens)-1]),visualize(_gen[0])) 876 | beam_codes = [_gen[0] for _gen in beam_gens] 877 | beam_replaceds = [entity_process(code,_X,_Y) for code,_X,_Y in zip(beam_codes,X_test,Y_test)] 878 | visualized_raws = [visualize(_gen) for _gen in beam_codes] 879 | visualized_replaceds = [visualize(_gen) for _gen in beam_replaceds] 880 | visualized_ys = ['\n'.join([visualize(_y) for _y in _Y]) for _Y in Y_test] 881 | raw_scores = [myrouge_2(_gen,_desc) for (_gen,_desc) in zip(visualized_raws,map(remove_entity,Y_descs))] 882 | replaced_scores = [myrouge_2(_gen,_desc) for (_gen,_desc) in zip(visualized_replaceds,map(remove_entity,Y_descs))] 883 | return {'beam_gens':beam_gens,'beam_replaceds':beam_replaceds,'visualized_raws':visualized_raws,'visualized_replaceds':visualized_replaceds,'raw_scores':raw_scores,'replaced_scores':replaced_scores} 884 | 885 | 886 | r = next(gen(X_test, Y_test, batch_size=batch_size)) 887 | r[0].shape, r[1].shape, len(r) 888 | 889 | traingen = gen(X_train, Y_train, batch_size=batch_size) 890 | valgen = gen(X_valid, Y_valid, batch_size=batch_size) 891 | #assert 0==1 892 | history = {} 893 | rouges = [] 894 | Y_descs = [' '.join([' '.join([idx2word[_w] for _w in _sent]) for _sent in _Y]) for _Y in Y_valid] 895 | 896 | 897 | beam_size = 15 898 | min_sents = 0 899 | rouge_factor = 300 900 | decay = 1.0 901 | batch_index = 0 902 | large_batch = 100 903 | iteration_threshold = 50 904 | print 'Rouge factor: ',rouge_factor 905 | print '\tMin sents: ',min_sents 906 | 907 | if FN1: 908 | all_flatten_model.load_weights('data/%s.weights.pkl'%FN1) 909 | 910 | #training function 911 | rouges = [] 912 | for iteration in 
range(1000):
913 |     print '%s\tIteration'%FN, iteration
914 | 
915 |     #evaluate on the validation set
916 |     if iteration > iteration_threshold:
917 |         trained_embedding = embedding_model.get_weights()[0]
918 |         results = evaluate(X_valid[batch_index*large_batch:(batch_index+1)*large_batch],Y_valid[batch_index*large_batch:(batch_index+1)*large_batch],min_sents,rouge_factor,decay)
919 |         rouge_score = np.average(results['replaced_scores'])
920 |         print '\t\t raw scores: %.4f, replaced scores: %.4f'%(np.average(results['raw_scores']),np.average(results['replaced_scores']))
921 | 
922 |     else:
923 |         gens = []
924 |         #for _t in range(nb_val_samples/batch_size):
925 |         for _t in range(100):
926 |             Y_predicts = all_flatten_model.predict(conv_seq_labels(X_valid[_t*batch_size:(_t+1)*batch_size],Y_valid[_t*batch_size:(_t+1)*batch_size])[0],batch_size=batch_size)
927 |             gens += greedysearch(Y_predicts)
928 |         rouge_score = np.average([myrouge_2(_gen,_desc) for (_gen,_desc) in zip(gens,Y_descs)])
929 |         results = []
930 | 
931 |     rouges.append(rouge_score)
932 |     print 'Current Rouge score: %.4f'%rouge_score
933 |     history['rouge'] = rouges
934 | 
935 | 
936 |     with open('data/%s.history.pkl'%(str(FN)),'wb') as fp:
937 |         pickle.dump(history,fp,-1)
938 |     if iteration>iteration_threshold and rouge_score == max(history['rouge'][iteration_threshold:]):
939 |         all_flatten_model.save_weights('data/%s.weights.pkl'%(str(FN),), overwrite=True)
940 |         results_writer = open('data/%s.results.pkl'%(str(FN)),'wb')
941 |         pickle.dump(results,results_writer,-1)
942 |         results_writer.close()
943 | 
944 |     gensamples(gens)
945 | 
946 |     #train
947 |     h = all_flatten_model.fit_generator(traingen,samples_per_epoch=nb_train_samples,nb_epoch=1,validation_data=valgen,nb_val_samples=nb_val_samples)
948 |     for k,v in h.history.iteritems():
949 |         history[k] = history.get(k,[]) + v
950 | 
951 | #predict
952 | trained_embedding = embedding_model.get_weights()[0]
953 | results = evaluate(X_test,Y_test,min_sents,rouge_factor,decay)
954 | outputs = results['visualized_replaceds']
955 | 
956 | 
957 | 
958 | 
--------------------------------------------------------------------------------
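
A note on the graph-based attention in acl17_model.py: page_ranking (applied once per output sentence by rank_function inside PageattLayer) scores the maxsents input-sentence vectors against one output-sentence vector by building a dot-product similarity graph over the sentences, column-normalising it into a transition matrix P, and solving the PageRank-style system x = (1-alpha) * (I - alpha*P)^(-1) * y, where y puts all teleport mass on the query sentence. Below is a minimal NumPy sketch of just that ranking step, assuming the same alpha = 0.9 set at the top of acl17_model.py; the name page_ranking_np, the eps argument and the toy random vectors are illustrative and not part of the released code, and np.linalg.solve stands in for the explicit matrix inverse used in the Theano version.

import numpy as np

def page_ranking_np(query, candidates, alpha=0.9, eps=1e-7):
    # stack the query sentence vector on top of the candidate sentence vectors: (1+m, d)
    reprs = np.vstack([query[None, :], candidates])
    # dot-product similarity graph with self-links removed (zeroed diagonal)
    W = np.dot(reprs, reprs.T) * (1.0 - np.eye(len(reprs)))
    # column-normalise into a transition-like matrix P = W * D
    D = np.diag(1.0 / (eps + W.sum(axis=0)))
    P = np.dot(W, D)
    # teleport distribution: all mass on the query (row 0)
    y = np.zeros(len(reprs))
    y[0] = 1.0
    # x = (1 - alpha) * (I - alpha * P)^-1 * y, computed with a linear solve
    x = (1.0 - alpha) * np.linalg.solve(np.eye(len(reprs)) - alpha * P, y)
    # relevance scores of the candidates with respect to the query
    return x[1:]

# toy usage: rank 34 document-sentence vectors (maxsents=34, embedding_size=100)
# against one output-sentence state
rng = np.random.RandomState(42)
doc_sents = rng.randn(34, 100)
decoder_state = rng.randn(100)
scores = page_ranking_np(decoder_state, doc_sents)
assert scores.shape == (34,)

In the full model this ranking runs once per output sentence (theano.map over rank_function), and PageattLayer then masks the scores, applies a softmax, and keeps the positive, renormalised differences between consecutive distributions as the final sentence-level attention weights.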