├── .gitignore
├── tuning.sh
├── optim.py
├── LICENSE
├── label_smoothing.py
├── prepare_rouge.py
├── word_prob_layer.py
├── configs.py
├── README.md
├── data.py
├── model.py
├── bleu.py
├── utils_pg.py
├── prepare_data.py
├── transformer.py
└── main.py

/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.log
3 | /data*/*
4 | /model*/*
5 | /cnndm/*
--------------------------------------------------------------------------------
/tuning.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | FILES=./cnndm/model/*
3 | for f in $FILES; do
4 |     echo "==========================" ${f##*/}
5 |     python -u main.py ${f##*/}
6 |     python prepare_rouge.py
7 |     cd ./cnndm/result/
8 |     perl /home/pijili/tools/ROUGE-1.5.5/ROUGE-1.5.5.pl -n 4 -w 1.2 -m -2 4 -u -c 95 -r 1000 -f A -p 0.5 -t 0 myROUGE_Config.xml C
9 |     cd ../../
10 | done
--------------------------------------------------------------------------------
/optim.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | 
3 | class Optim:
4 |     "Optimizer wrapper that implements the Noam learning-rate schedule."
5 |     def __init__(self, model_size, factor, warmup, optimizer):
6 |         self.optimizer = optimizer
7 |         self._step = 0
8 |         self.warmup = warmup
9 |         self.factor = factor
10 |         self.model_size = model_size
11 |         self._rate = 0
12 | 
13 |     def step(self):
14 |         "Update parameters and rate"
15 |         self._step += 1
16 |         rate = self.rate()
17 |         for p in self.optimizer.param_groups:
18 |             p['lr'] = rate
19 |         self._rate = rate
20 |         self.optimizer.step()
21 | 
22 |     def rate(self, step = None):
23 |         "lr = factor * model_size^(-0.5) * min(step^(-0.5), step * warmup^(-1.5))"
24 |         if step is None:
25 |             step = self._step
26 |         return self.factor * (self.model_size ** (-0.5) * min(step ** (-0.5), step * self.warmup ** (-1.5)))
27 | 
28 |     def state_dict(self):
29 |         return self.optimizer.state_dict()
30 | 
31 |     def load_state_dict(self, m):
32 |         self.optimizer.load_state_dict(m)
33 | 
--------------------------------------------------------------------------------
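Note: Optim above is the warmup-then-decay ("Noam") schedule from the Annotated Transformer: lr = factor * model_size^(-0.5) * min(step^(-0.5), step * warmup^(-1.5)), i.e. linear warmup followed by inverse-square-root decay. A minimal usage sketch — the linear layer, factor=1.0, warmup=8000, and the Adam hyper-parameters below are illustrative placeholders, not values taken from this repo:

```
import torch
from optim import Optim

model = torch.nn.Linear(512, 512)  # stand-in for the real model
adam = torch.optim.Adam(model.parameters(), lr=0.0, betas=(0.9, 0.998), eps=1e-9)
opt = Optim(model_size=512, factor=1.0, warmup=8000, optimizer=adam)

# rate() rises linearly for the first `warmup` steps, then decays as step^-0.5
for s in (1, 4000, 8000, 32000):
    print(s, opt.rate(step=s))
```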
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2018 Piji Li
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/label_smoothing.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | class LabelSmoothing(nn.Module):
6 |     "Implement label smoothing."
7 |     def __init__(self, device, size, padding_idx, label_smoothing=0.0):
8 |         super(LabelSmoothing, self).__init__()
9 |         assert 0.0 < label_smoothing <= 1.0
10 |         self.padding_idx = padding_idx
11 |         self.size = size
12 |         self.device = device
13 | 
14 |         self.smoothing_value = label_smoothing / (size - 2)  # smoothed mass shared by all tokens except the gold label and the padding token
15 |         self.one_hot = torch.full((1, size), self.smoothing_value).to(device)
16 |         self.one_hot[0, self.padding_idx] = 0
17 | 
18 |         self.confidence = 1.0 - label_smoothing
19 | 
20 |     def forward(self, output, target):
21 |         real_size = output.size(1)  # output: log-probs over the (possibly extended) vocabulary
22 |         if real_size > self.size:
23 |             real_size -= self.size  # number of extra per-batch OOV slots
24 |         else:
25 |             real_size = 0
26 | 
27 |         model_prob = self.one_hot.repeat(target.size(0), 1)
28 |         if real_size > 0:
29 |             ext_zeros = torch.full((model_prob.size(0), real_size), self.smoothing_value).to(self.device)
30 |             model_prob = torch.cat((model_prob, ext_zeros), -1)
31 |         model_prob.scatter_(1, target, self.confidence)
32 |         model_prob.masked_fill_((target == self.padding_idx), 0.)
33 | 
34 |         return F.kl_div(output, model_prob, reduction='sum')
35 | 
--------------------------------------------------------------------------------
/prepare_rouge.py:
--------------------------------------------------------------------------------
1 | #pylint: skip-file
2 | import sys
3 | import os
4 | from configs import *
5 | 
6 | cfg = DeepmindConfigs()
7 | 
8 | # config file for ROUGE
9 | ROUGE_PATH = cfg.cc.RESULT_PATH
10 | SUMM_PATH = cfg.cc.SUMM_PATH
11 | MODEL_PATH = cfg.cc.GROUND_TRUTH_PATH  # ROUGE "models" are the reference summaries
12 | i2summ = {}
13 | summ2i = {}
14 | i2model = {}
15 | 
16 | # for result
17 | flist = os.listdir(SUMM_PATH)
18 | i = 0
19 | for fname in flist:
20 |     i2summ[str(i)] = fname
21 |     summ2i[fname] = str(i)
22 |     i += 1
23 | 
24 | # for models
25 | flist = os.listdir(MODEL_PATH)
26 | i2model = {}
27 | for fname in flist:
28 |     if fname not in summ2i:
29 |         raise IOError
30 | 
31 |     i = summ2i[fname]
32 |     i2model[i] = fname
33 | 
34 | assert len(i2model) == len(i2summ)
35 | 
36 | # write to config file
37 | rouge_s = "<ROUGE_EVAL version=\"1.0\">"
38 | file_id = 0
39 | for file_id, fsumm in i2summ.items():
40 |     rouge_s += "\n<EVAL ID=\"" + file_id + "\">" \
41 |             + "\n<PEER-ROOT>" \
42 |             + SUMM_PATH \
43 |             + "\n</PEER-ROOT>" \
44 |             + "\n<MODEL-ROOT>" \
45 |             + "\n" + MODEL_PATH \
46 |             + "\n</MODEL-ROOT>" \
47 |             + "\n<INPUT-FORMAT TYPE=\"SPL\">" \
48 |             + "\n</INPUT-FORMAT>" \
49 |             + "\n<PEERS>" \
50 |             + "\n<P ID=\"C\">" + fsumm + "</P>" \
51 |             + "\n</PEERS>" \
52 |             + "\n<MODELS>"
53 |     rouge_s += "\n<M ID=\"" + file_id + "\">" + i2model[file_id] + "</M>"
54 |     rouge_s += "\n</MODELS>\n</EVAL>"
55 | 
56 | rouge_s += "\n</ROUGE_EVAL>"
57 | 
58 | with open(ROUGE_PATH + "myROUGE_Config.xml", "w") as f_rouge:
59 |     f_rouge.write(rouge_s)
--------------------------------------------------------------------------------
/word_prob_layer.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #pylint: skip-file
3 | import torch
4 | import torch as T
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from torch.autograd import Variable
8 | 
9 | from utils_pg import *
10 | from transformer import MultiheadAttention
11 | 
12 | class WordProbLayer(nn.Module):
13 |     def __init__(self, hidden_size, dict_size, device, copy, coverage, dropout):
14 |         super(WordProbLayer, self).__init__()
15 |         self.hidden_size = hidden_size
16 |         self.dict_size = dict_size
17 |         self.device = device
18 |         self.copy = copy
19 |         self.coverage = coverage
20 |         self.dropout = dropout
21 | 
22 |         if self.copy:
23 |             self.external_attn = MultiheadAttention(self.hidden_size, 1, self.dropout, weights_dropout=False)
24 |             self.proj = nn.Linear(self.hidden_size * 3, self.dict_size)
25 |             self.v = nn.Parameter(torch.Tensor(1, self.hidden_size * 3))
26 |             self.bv = nn.Parameter(torch.Tensor(1))
27 |         else:
28 |             self.proj = nn.Linear(self.hidden_size, self.dict_size)
29 | 
30 |         self.init_weights()
31 | 
32 |     def init_weights(self):
33 |         init_linear_weight(self.proj)
34 |         if self.copy:
35 |             init_xavier_weight(self.v)
36 |             init_bias(self.bv)
37 | 
38 |     def forward(self, h, y_emb=None, memory=None, mask_x=None, xids=None, max_ext_len=None):
39 | 
40 |         if self.copy:
41 |             atts, dists = self.external_attn(query=h, key=memory, value=memory, key_padding_mask=mask_x, need_weights = True)
42 |             pred = T.softmax(self.proj(T.cat([h, y_emb, atts], -1)), dim=-1)
43 |             if max_ext_len > 0:  # make room for per-batch OOV tokens
44 |                 ext_zeros = Variable(torch.zeros(pred.size(0), pred.size(1), max_ext_len)).to(self.device)
45 |                 pred = T.cat((pred, ext_zeros), -1)
46 |             g = T.sigmoid(F.linear(T.cat([h, y_emb, atts], -1), self.v, self.bv))  # generation gate
47 |             xids = xids.transpose(0, 1).unsqueeze(0).repeat(pred.size(0), 1, 1)
48 |             pred = (g * pred).scatter_add(2, xids, (1 - g) * dists)  # pointer-generator mixture
49 |         else:
50 |             pred = T.softmax(self.proj(h), dim=-1)
51 |             dists = None
52 |         return pred, dists
53 | 
--------------------------------------------------------------------------------
/configs.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #pylint: skip-file
3 | import os
4 | 
5 | class CommonConfigs(object):
6 |     def __init__(self, d_type):
7 |         self.ROOT_PATH = os.getcwd() + "/"
8 |         self.TRAINING_DATA_PATH = self.ROOT_PATH + d_type + "/train_set/"
9 |         self.VALIDATE_DATA_PATH = self.ROOT_PATH + d_type + "/validate_set/"
10 |         self.TESTING_DATA_PATH = self.ROOT_PATH + d_type + "/test_set/"
11 |         self.RESULT_PATH = self.ROOT_PATH + d_type + "/result/"
12 |         self.MODEL_PATH = self.ROOT_PATH + d_type + "/model/"
13 |         self.BEAM_SUMM_PATH = self.RESULT_PATH + "/beam_summary/"
14 |         self.BEAM_GT_PATH = self.RESULT_PATH + "/beam_ground_truth/"
15 |         self.GROUND_TRUTH_PATH = self.RESULT_PATH + "/ground_truth/"
16 |         self.SUMM_PATH = self.RESULT_PATH + "/summary/"
17 |         self.TMP_PATH = self.ROOT_PATH + d_type + "/tmp/"
18 | 
19 | 
20 | class DeepmindTraining(object):
21 |     IS_UNICODE = False
22 |     REMOVES_PUNCTION = False
23 |     HAS_Y = True
24 |     BATCH_SIZE = 20
25 | 
26 | class DeepmindTesting(object):
27 |     IS_UNICODE = False
28 |     HAS_Y = True
29 |     BATCH_SIZE = 80
30 |     MIN_LEN_PREDICT = 35
31 |     MAX_LEN_PREDICT = 120
32 |     MAX_BYTE_PREDICT = None
33 |     PRINT_SIZE = 500
34 |     REMOVES_PUNCTION = False
35 | 
36 | class DeepmindConfigs():
37 | 
38 |     cc = CommonConfigs("cnndm")
39 |     FIRE = False
40 | 
41 |     CELL = "transformer"
42 |     CUDA = True
43 |     COPY = True
44 |     COVERAGE = True
45 | 
46 |     BI_RNN = False
47 |     AVG_NLL = True
48 |     NORM_CLIP = 2
49 |     if not AVG_NLL:
50 |         NORM_CLIP = 5
51 |     LR = 0.15
52 |     SMOOTHING = 0.1
53 | 
54 |     BEAM_SEARCH = True
55 |     BEAM_SIZE = 4
56 |     ALPHA = 0.9 # length penalty
57 |     BETA = 5 # coverage penalty during beam search
58 | 
59 |     DIM_X = 512
60 |     DIM_Y = DIM_X
61 |     HIDDEN_SIZE = 512
62 |     FF_SIZE = 1024
63 |     NUM_H = 8 # multi-head attention heads
64 |     DROPOUT = 0.2
65 |     NUM_L = 4 # num of layers
66 |     MIN_LEN_X = 10
67 |     MIN_LEN_Y = 10
68 |     MAX_LEN_X = 400
69 |     MAX_LEN_Y = 100
70 |     MIN_NUM_X = 1
71 |     MAX_NUM_X = 1
72 |     MAX_NUM_Y = None
73 | 
74 |     NUM_Y = 1
75 | 
76 |     UNI_LOW_FREQ_THRESHOLD = 10
77 | 
78 |     PG_DICT_SIZE = 50000 # dict size used in the ACL'17 pointer-generator paper
79 | 
80 |     W_UNK = "<unk>"
81 |     W_BOS = "<bos>"
82 |     W_EOS = "<eos>"
83 |     W_PAD = "<pad>"
84 |     W_LS = "<s>"
85 |     W_RS = "</s>"
86 | 
--------------------------------------------------------------------------------
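Note: ALPHA and BETA in DeepmindConfigs are beam-search-time knobs. BETA scales the coverage penalty computed in beam_decode (main.py): the per-step attention rows of a partial hypothesis are summed over decoding steps, clipped from below at 1, and the excess over the source length is subtracted from the hypothesis score. A standalone sketch of that computation, with made-up attention values:

```
import torch

beta = 5.0
# attns: decoding steps x source positions (each row is an attention distribution)
attns = torch.tensor([[0.9, 0.1, 0.0],
                      [0.8, 0.1, 0.1],
                      [0.7, 0.2, 0.1]])
cov = torch.max(attns.sum(dim=0), torch.ones(3))  # clip coverage at 1
penalty = -beta * (cov.sum().item() - 3)          # 3 = number of source positions
print(penalty)  # repeatedly attending to one token makes the score worse
```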
/README.md:
--------------------------------------------------------------------------------
1 | # TranSummar
2 | Transformer for abstractive summarization
3 | 
4 | #### cnndm (with copy and coverage), epoch 57:
5 | ```
6 | ---------------------------------------------
7 | C ROUGE-1 Average_R: 0.41097 (95%-conf.int. 0.40861 - 0.41346)
8 | C ROUGE-1 Average_P: 0.40874 (95%-conf.int. 0.40619 - 0.41141)
9 | C ROUGE-1 Average_F: 0.39656 (95%-conf.int. 0.39451 - 0.39871)
10 | ---------------------------------------------
11 | C ROUGE-2 Average_R: 0.17821 (95%-conf.int. 0.17590 - 0.18049)
12 | C ROUGE-2 Average_P: 0.17781 (95%-conf.int. 0.17540 - 0.18037)
13 | C ROUGE-2 Average_F: 0.17208 (95%-conf.int. 0.16990 - 0.17433)
14 | ---------------------------------------------
15 | C ROUGE-3 Average_R: 0.09845 (95%-conf.int. 0.09640 - 0.10064)
16 | C ROUGE-3 Average_P: 0.09844 (95%-conf.int. 0.09627 - 0.10069)
17 | C ROUGE-3 Average_F: 0.09505 (95%-conf.int. 0.09307 - 0.09713)
18 | ---------------------------------------------
19 | C ROUGE-4 Average_R: 0.06297 (95%-conf.int. 0.06109 - 0.06499)
20 | C ROUGE-4 Average_P: 0.06329 (95%-conf.int. 0.06137 - 0.06537)
21 | C ROUGE-4 Average_F: 0.06086 (95%-conf.int. 0.05908 - 0.06275)
22 | ---------------------------------------------
23 | C ROUGE-L Average_R: 0.37912 (95%-conf.int. 0.37682 - 0.38160)
24 | C ROUGE-L Average_P: 0.37726 (95%-conf.int. 0.37484 - 0.37987)
25 | C ROUGE-L Average_F: 0.36593 (95%-conf.int. 0.36388 - 0.36810)
26 | ---------------------------------------------
27 | C ROUGE-W-1.2 Average_R: 0.16602 (95%-conf.int. 0.16487 - 0.16715)
28 | C ROUGE-W-1.2 Average_P: 0.27554 (95%-conf.int. 0.27362 - 0.27758)
29 | C ROUGE-W-1.2 Average_F: 0.20031 (95%-conf.int. 0.19902 - 0.20156)
30 | ---------------------------------------------
31 | C ROUGE-SU4 Average_R: 0.18191 (95%-conf.int. 0.17981 - 0.18403)
32 | C ROUGE-SU4 Average_P: 0.18101 (95%-conf.int. 0.17890 - 0.18320)
33 | C ROUGE-SU4 Average_F: 0.17496 (95%-conf.int. 0.17308 - 0.17693)
34 | ```
35 | #### gigaword (no copy and no coverage), epoch 18:
36 | ```
37 | ---------------------------------------------
38 | C ROUGE-1 Average_R: 0.36144 (95%-conf.int. 0.34958 - 0.37330)
39 | C ROUGE-1 Average_P: 0.37213 (95%-conf.int. 0.36018 - 0.38460)
40 | C ROUGE-1 Average_F: 0.35586 (95%-conf.int. 0.34433 - 0.36747)
41 | ---------------------------------------------
42 | C ROUGE-2 Average_R: 0.17568 (95%-conf.int. 0.16614 - 0.18606)
43 | C ROUGE-2 Average_P: 0.18536 (95%-conf.int. 0.17463 - 0.19625)
44 | C ROUGE-2 Average_F: 0.17467 (95%-conf.int. 0.16489 - 0.18467)
45 | ---------------------------------------------
46 | C ROUGE-3 Average_R: 0.09628 (95%-conf.int. 0.08782 - 0.10555)
47 | C ROUGE-3 Average_P: 0.10448 (95%-conf.int. 0.09482 - 0.11429)
48 | C ROUGE-3 Average_F: 0.09643 (95%-conf.int. 0.08763 - 0.10558)
49 | ---------------------------------------------
50 | C ROUGE-4 Average_R: 0.05583 (95%-conf.int. 0.04812 - 0.06380)
51 | C ROUGE-4 Average_P: 0.06323 (95%-conf.int. 0.05447 - 0.07197)
52 | C ROUGE-4 Average_F: 0.05653 (95%-conf.int. 0.04871 - 0.06425)
53 | ---------------------------------------------
54 | C ROUGE-L Average_R: 0.33497 (95%-conf.int. 0.32382 - 0.34678)
55 | C ROUGE-L Average_P: 0.34521 (95%-conf.int. 0.33414 - 0.35770)
56 | C ROUGE-L Average_F: 0.33005 (95%-conf.int. 0.31905 - 0.34173)
57 | ---------------------------------------------
58 | C ROUGE-W-1.2 Average_R: 0.20852 (95%-conf.int. 0.20134 - 0.21619)
59 | C ROUGE-W-1.2 Average_P: 0.32259 (95%-conf.int. 0.31225 - 0.33452)
60 | C ROUGE-W-1.2 Average_F: 0.24404 (95%-conf.int. 0.23594 - 0.25286)
61 | ---------------------------------------------
62 | C ROUGE-SU4 Average_R: 0.20404 (95%-conf.int. 0.19456 - 0.21456)
63 | C ROUGE-SU4 Average_P: 0.21410 (95%-conf.int. 0.20406 - 0.22493)
64 | C ROUGE-SU4 Average_F: 0.19664 (95%-conf.int. 0.18736 - 0.20654)
65 | ```
66 | 
67 | ### How to run:
68 | - Python 3.7, PyTorch 0.4+
69 | - Download the processed dataset from https://drive.google.com/file/d/1EUuEMBSlrlnf_J2jcAVl1v4owSvw_8ZF/view?usp=sharing , or download the original FINISHED_FILES from https://github.com/JafferWilson/Process-Data-of-CNN-DailyMail and process them yourself.
70 | - Modify the data path in prepare_data.py, then run it: python prepare_data.py
71 | - Training: python -u main.py | tee train.log
72 | - Tuning: in main.py, set is_predicting=True and model_selection=True (see the snippet below), then run "bash tuning.sh | tee tune.log"
73 | - Testing: in main.py, set is_predicting=True and model_selection=False, then run "python main.py your-best-model (say cnndm.s2s.gpu4.epoch7.1)"; go to "./cnndm/result/" and run $ROUGE$ myROUGE_Config.xml C to get the results.
74 | - The Perl ROUGE package is enough; I did not use pyrouge.
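For reference, the two switches used by the Tuning/Testing steps above are set near the top of init_modules() in main.py:

```
# in main.py, init_modules():
#   training             -> is_predicting = False
#   tuning  (validation) -> is_predicting = True, model_selection = True
#   testing (test set)   -> is_predicting = True, model_selection = False
options["is_predicting"] = True
options["model_selection"] = False
```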
75 | 76 | ### Reference: 77 | - fairseq: https://github.com/pytorch/fairseq 78 | - The Annotated Transformer: http://nlp.seas.harvard.edu/2018/04/03/attention.html 79 | - bert:https://github.com/jcyk/BERT 80 | - Rush-Gigaword: https://drive.google.com/open?id=0B6N7tANPyVeBNmlSX19Ld2xDU1E 81 | - Rush-CNN/Dailymail: https://s3.amazonaws.com/opennmt-models/Summary/cnndm.tar.gz 82 | 83 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | #pylint: skip-file 3 | import sys 4 | import os 5 | import os.path 6 | import time 7 | from operator import itemgetter 8 | import numpy as np 9 | import pickle 10 | from random import shuffle 11 | 12 | class BatchData: 13 | def __init__(self, flist, modules, consts, options): 14 | self.batch_size = len(flist) 15 | self.x = np.zeros((consts["len_x"], self.batch_size), dtype = np.int64) 16 | self.x_ext = np.zeros((consts["len_x"], self.batch_size), dtype = np.int64) 17 | self.y = np.zeros((consts["len_y"], self.batch_size), dtype = np.int64) 18 | self.y_inp = np.zeros((consts["len_y"], self.batch_size), dtype = np.int64) 19 | self.y_ext = np.zeros((consts["len_y"], self.batch_size), dtype = np.int64) 20 | self.x_mask = np.zeros((consts["len_x"], self.batch_size, 1), dtype = np.int64) 21 | self.y_mask = np.zeros((consts["len_y"], self.batch_size, 1), dtype = np.int64) 22 | self.len_x = [] 23 | self.len_y = [] 24 | self.original_contents = [] 25 | self.original_summarys = [] 26 | self.x_ext_words = [] 27 | self.max_ext_len = 0 28 | 29 | w2i = modules["w2i"] 30 | i2w = modules["i2w"] 31 | dict_size = len(w2i) 32 | 33 | for idx_doc in range(len(flist)): 34 | if len(flist[idx_doc]) == 2: 35 | contents, summarys = flist[idx_doc] 36 | else: 37 | print ("ERROR!") 38 | return 39 | 40 | content, original_content = contents 41 | summary, original_summary = summarys 42 | self.original_contents.append(original_content) 43 | self.original_summarys.append(original_summary) 44 | 45 | xi_oovs = [] 46 | for idx_word in range(len(content)): 47 | # some sentences in duc is longer than len_x 48 | if idx_word == consts["len_x"]: 49 | break 50 | w = content[idx_word] 51 | 52 | if w not in w2i: # OOV 53 | if w not in xi_oovs: 54 | xi_oovs.append(w) 55 | self.x_ext[idx_word, idx_doc] = dict_size + xi_oovs.index(w) # 500005, 51000 56 | w = i2w[modules["lfw_emb"]] 57 | else: 58 | self.x_ext[idx_word, idx_doc] = w2i[w] 59 | 60 | self.x[idx_word, idx_doc] = w2i[w] 61 | self.x_mask[idx_word, idx_doc, 0] = 1 62 | self.len_x.append(np.sum(self.x_mask[:, idx_doc, :])) 63 | self.x_ext_words.append(xi_oovs) 64 | if self.max_ext_len < len(xi_oovs): 65 | self.max_ext_len = len(xi_oovs) 66 | 67 | if options["has_y"]: 68 | for idx_word in range(len(summary)): 69 | w = summary[idx_word] 70 | 71 | if w not in w2i: 72 | if w in xi_oovs: 73 | self.y_ext[idx_word, idx_doc] = dict_size + xi_oovs.index(w) 74 | else: 75 | self.y_ext[idx_word, idx_doc] = w2i[i2w[modules["lfw_emb"]]] # unk 76 | w = i2w[modules["lfw_emb"]] 77 | else: 78 | self.y_ext[idx_word, idx_doc] = w2i[w] 79 | self.y[idx_word, idx_doc] = w2i[w] 80 | 81 | if idx_word == 0: 82 | self.y_inp[idx_word, idx_doc] = modules["bos_idx"] 83 | if idx_word < (len(summary) - 1): 84 | self.y_inp[idx_word + 1, idx_doc] = w2i[w] 85 | 86 | if not options["is_predicting"]: 87 | self.y_mask[idx_word, idx_doc, 0] = 1 88 | self.len_y.append(len(summary)) 89 | else: 90 | self.y = self.y_mask = None 91 | 92 | 
max_len_x = int(np.max(self.len_x)) 93 | max_len_y = int(np.max(self.len_y)) 94 | 95 | self.x = self.x[0:max_len_x, :] 96 | self.x_ext = self.x_ext[0:max_len_x, :] 97 | self.x_mask = self.x_mask[0:max_len_x, :, :] 98 | self.y = self.y[0:max_len_y, :] 99 | self.y_inp = self.y_inp[0:max_len_y, :] 100 | self.y_ext = self.y_ext[0:max_len_y, :] 101 | self.y_mask = self.y_mask[0:max_len_y, :, :] 102 | 103 | def get_data(xy_list, modules, consts, options): 104 | return BatchData(xy_list, modules, consts, options) 105 | 106 | def batched(x_size, options, consts): 107 | batch_size = consts["testing_batch_size"] if options["is_predicting"] else consts["batch_size"] 108 | if options["is_debugging"]: 109 | x_size = 13 110 | ids = [i for i in range(x_size)] 111 | if not options["is_predicting"]: 112 | shuffle(ids) 113 | batch_list = [] 114 | batch_ids = [] 115 | for i in range(x_size): 116 | idx = ids[i] 117 | batch_ids.append(idx) 118 | if len(batch_ids) == batch_size or i == (x_size - 1): 119 | batch_list.append(batch_ids) 120 | batch_ids = [] 121 | return batch_list, len(ids), len(batch_list) 122 | 123 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | #pylint: skip-file 3 | import sys 4 | import numpy as np 5 | import torch 6 | import torch as T 7 | import torch.nn as nn 8 | from torch.autograd import Variable 9 | import torch.nn.functional as F 10 | 11 | from utils_pg import * 12 | 13 | from transformer import TransformerLayer, Embedding, LearnedPositionalEmbedding, gelu, LayerNorm, SelfAttentionMask 14 | from word_prob_layer import * 15 | from label_smoothing import LabelSmoothing 16 | 17 | class Model(nn.Module): 18 | def __init__(self, modules, consts, options): 19 | super(Model, self).__init__() 20 | 21 | self.has_learnable_w2v = options["has_learnable_w2v"] 22 | self.is_predicting = options["is_predicting"] 23 | self.is_bidirectional = options["is_bidirectional"] 24 | self.beam_decoding = options["beam_decoding"] 25 | self.cell = options["cell"] 26 | self.device = options["device"] 27 | self.copy = options["copy"] 28 | self.coverage = options["coverage"] 29 | self.avg_nll = options["avg_nll"] 30 | 31 | self.dim_x = consts["dim_x"] 32 | self.dim_y = consts["dim_y"] 33 | self.len_x = consts["len_x"] 34 | self.len_y = consts["len_y"] 35 | self.hidden_size = consts["hidden_size"] 36 | self.dict_size = consts["dict_size"] 37 | self.pad_token_idx = consts["pad_token_idx"] 38 | self.ctx_size = self.hidden_size * 2 if self.is_bidirectional else self.hidden_size 39 | self.num_layers = consts["num_layers"] 40 | self.d_ff = consts["d_ff"] 41 | self.num_heads = consts["num_heads"] 42 | self.dropout = consts["dropout"] 43 | self.smoothing_factor = consts["label_smoothing"] 44 | 45 | self.tok_embed = nn.Embedding(self.dict_size, self.dim_x, self.pad_token_idx) 46 | self.pos_embed = LearnedPositionalEmbedding(self.dim_x, device=self.device) 47 | 48 | self.enc_layers = nn.ModuleList() 49 | for i in range(self.num_layers): 50 | self.enc_layers.append(TransformerLayer(self.dim_x, self.d_ff, self.num_heads, self.dropout)) 51 | 52 | self.dec_layers = nn.ModuleList() 53 | for i in range(self.num_layers): 54 | self.dec_layers.append(TransformerLayer(self.dim_x, self.d_ff, self.num_heads, self.dropout, with_external=True)) 55 | 56 | self.attn_mask = SelfAttentionMask(device=self.device) 57 | 58 | self.emb_layer_norm = LayerNorm(self.dim_x) 59 | 60 | 
self.word_prob = WordProbLayer(self.hidden_size, self.dict_size, self.device, self.copy, self.coverage, self.dropout)
61 | 
62 |         self.smoothing = LabelSmoothing(self.device, self.dict_size, self.pad_token_idx, self.smoothing_factor)
63 | 
64 |         self.init_weights()
65 | 
66 |     def init_weights(self):
67 |         init_uniform_weight(self.tok_embed.weight)
68 | 
69 | 
70 |     def label_smoothing_loss(self, y_pred, y, y_mask, avg=True):
71 |         seq_len, bsz = y.size()
72 | 
73 |         y_pred = T.log(y_pred.clamp(min=1e-8))
74 |         loss = self.smoothing(y_pred.view(seq_len * bsz, -1), y.view(seq_len * bsz, -1))
75 |         if avg:
76 |             return loss / T.sum(y_mask)
77 |         else:
78 |             return loss / bsz
79 | 
80 |     def nll_loss(self, y_pred, y, y_mask, avg=True):
81 |         cost = -T.log(T.gather(y_pred, 2, y.view(y.size(0), y.size(1), 1)))
82 |         cost = cost.view(y.shape)
83 |         y_mask = y_mask.view(y.shape)
84 |         if avg:
85 |             cost = T.sum(cost * y_mask, 0) / T.sum(y_mask, 0)
86 |         else:
87 |             cost = T.sum(cost * y_mask, 0)
88 |         cost = cost.view((y.size(1), -1))
89 |         return T.mean(cost)
90 | 
91 |     def encode(self, inp):
92 |         seq_len, bsz = inp.size()
93 |         x = self.tok_embed(inp) + self.pos_embed(inp)
94 |         x = self.emb_layer_norm(x)
95 |         x = F.dropout(x, p=self.dropout, training=self.training)
96 |         padding_mask = torch.eq(inp, self.pad_token_idx)
97 |         if not padding_mask.any():
98 |             padding_mask = None
99 | 
100 |         xs = []
101 |         for layer_id, layer in enumerate(self.enc_layers):
102 |             x, _ ,_ = layer(x, self_padding_mask=padding_mask)
103 |             xs.append(x)
104 | 
105 |         return x, padding_mask
106 | 
107 | 
108 |     def decode(self, inp, mask_x, mask_y, src, src_padding_mask, xids=None, max_ext_len=None):
109 |         seq_len, bsz = inp.size()
110 |         x = self.tok_embed(inp) + self.pos_embed(inp)
111 |         x = self.emb_layer_norm(x)
112 |         x = F.dropout(x, p=self.dropout, training=self.training)
113 |         h = x
114 |         if not self.is_predicting:
115 |             mask_y = mask_y.view((seq_len, bsz))
116 |             padding_mask = torch.eq(mask_y, self.pad_token_idx)
117 |             if not padding_mask.any():
118 |                 padding_mask = None
119 |         else:
120 |             padding_mask = None
121 | 
122 |         self_attn_mask = self.attn_mask(seq_len)
123 | 
124 |         for layer_id, layer in enumerate(self.dec_layers):
125 |             x, _, _ = layer(x, self_padding_mask=padding_mask,\
126 |                     self_attn_mask = self_attn_mask,\
127 |                     external_memories = src,\
128 |                     external_padding_mask = src_padding_mask,\
129 |                     need_weights = False)
130 |         if self.copy:
131 |             y_dec, attn_dist = self.word_prob(x, h, src, src_padding_mask, xids, max_ext_len)
132 |         else:
133 |             y_dec, attn_dist = self.word_prob(x)
134 | 
135 |         return y_dec, attn_dist
136 | 
137 | 
138 |     def forward(self, x, y_inp, y_tgt, mask_x, mask_y, x_ext, y_ext, max_ext_len):
139 |         hs, src_padding_mask = self.encode(x)
140 |         if self.copy:
141 |             y_pred, _ = self.decode(y_inp, mask_x, mask_y, hs, src_padding_mask, x_ext, max_ext_len)
142 |             cost = self.label_smoothing_loss(y_pred, y_ext, mask_y, self.avg_nll)
143 |         else:
144 |             y_pred, _ = self.decode(y_inp, mask_x, mask_y, hs, src_padding_mask)
145 |             cost = self.nll_loss(y_pred, y_tgt, mask_y, self.avg_nll)
146 | 
147 |         return y_pred, cost
148 | 
149 | 
--------------------------------------------------------------------------------
/bleu.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | 17 | 18 | """Python implementation of BLEU and smooth-BLEU. 19 | 20 | copy from: https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py 21 | 22 | This module provides a Python implementation of BLEU and smooth-BLEU. 23 | Smooth BLEU is computed following the method outlined in the paper: 24 | Chin-Yew Lin, Franz Josef Och. ORANGE: a method for evaluating automatic 25 | evaluation metrics for machine translation. COLING 2004. 26 | """ 27 | 28 | import collections 29 | import math 30 | import os 31 | import argparse 32 | 33 | def load_lines(f_path): 34 | lines = [] 35 | with open(f_path, "r") as f: 36 | for line in f: 37 | line = line.strip('\n').strip('\r') 38 | fs = line.split() 39 | lines.append(fs) 40 | return lines 41 | 42 | def _get_ngrams(segment, max_order): 43 | """Extracts all n-grams upto a given maximum order from an input segment. 44 | 45 | Args: 46 | segment: text segment from which n-grams will be extracted. 47 | max_order: maximum length in tokens of the n-grams returned by this 48 | methods. 49 | 50 | Returns: 51 | The Counter containing all n-grams upto max_order in segment 52 | with a count of how many times each n-gram occurred. 53 | """ 54 | ngram_counts = collections.Counter() 55 | for order in range(1, max_order + 1): 56 | for i in range(0, len(segment) - order + 1): 57 | ngram = tuple(segment[i:i+order]) 58 | ngram_counts[ngram] += 1 59 | return ngram_counts 60 | 61 | 62 | def compute_bleu(reference_corpus, translation_corpus, max_order=4, 63 | smooth=False): 64 | """Computes BLEU score of translated segments against one or more references. 65 | 66 | Args: 67 | reference_corpus: list of lists of references for each translation. Each 68 | reference should be tokenized into a list of tokens. 69 | translation_corpus: list of translations to score. Each translation 70 | should be tokenized into a list of tokens. 71 | max_order: Maximum n-gram order to use when computing BLEU score. 72 | smooth: Whether or not to apply Lin et al. 2004 smoothing. 73 | 74 | Returns: 75 | 3-Tuple with the BLEU score, n-gram precisions, geometric mean of n-gram 76 | precisions and brevity penalty. 
77 |     """
78 |     matches_by_order = [0] * max_order
79 |     possible_matches_by_order = [0] * max_order
80 |     reference_length = 0
81 |     translation_length = 0
82 |     for (references, translation) in zip(reference_corpus,
83 |                                          translation_corpus):
84 |         reference_length += min(len(r) for r in references)
85 |         translation_length += len(translation)
86 | 
87 |         merged_ref_ngram_counts = collections.Counter()
88 |         for reference in references:
89 |             merged_ref_ngram_counts |= _get_ngrams(reference, max_order)
90 |         translation_ngram_counts = _get_ngrams(translation, max_order)
91 |         overlap = translation_ngram_counts & merged_ref_ngram_counts
92 |         for ngram in overlap:
93 |             matches_by_order[len(ngram)-1] += overlap[ngram]
94 |         for order in range(1, max_order+1):
95 |             possible_matches = len(translation) - order + 1
96 |             if possible_matches > 0:
97 |                 possible_matches_by_order[order-1] += possible_matches
98 | 
99 |     precisions = [0] * max_order
100 |     for i in range(0, max_order):
101 |         if smooth:
102 |             precisions[i] = ((matches_by_order[i] + 1.) /
103 |                              (possible_matches_by_order[i] + 1.))
104 |         else:
105 |             if possible_matches_by_order[i] > 0:
106 |                 precisions[i] = (float(matches_by_order[i]) /
107 |                                  possible_matches_by_order[i])
108 |             else:
109 |                 precisions[i] = 0.0
110 | 
111 |     if min(precisions) > 0:
112 |         p_log_sum = sum((1. / max_order) * math.log(p) for p in precisions)
113 |         geo_mean = math.exp(p_log_sum)
114 |     else:
115 |         geo_mean = 0
116 | 
117 |     ratio = float(translation_length) / reference_length
118 | 
119 |     if ratio > 1.0:
120 |         bp = 1.
121 |     else:
122 |         bp = math.exp(1 - 1. / ratio)
123 | 
124 |     bleu = geo_mean * bp
125 | 
126 |     return (bleu, precisions, bp, ratio, translation_length, reference_length)
127 | 
128 | 
129 | 
130 | def bleu(ref_path, pred_path, smooth=True, n = 1):
131 |     id2f_ref = {}
132 |     id2f_pred = {}
133 | 
134 |     flist = os.listdir(ref_path)
135 |     for fname in flist:
136 |         id_ = fname
137 |         id2f_ref[id_] = ref_path + fname
138 | 
139 |     flist = os.listdir(pred_path)
140 |     for fname in flist:
141 |         id_ = fname
142 |         id2f_pred[id_] = pred_path + fname
143 | 
144 |     assert len(id2f_ref) == len(id2f_pred)
145 | 
146 |     ref_lists = []
147 |     pred_lists = []
148 |     for fid, fpath in id2f_ref.items():
149 |         ref_list = load_lines(fpath)
150 |         assert len(ref_list) == n
151 |         ref_lists.append(ref_list)
152 | 
153 |         pred_list = load_lines(id2f_pred[fid])
154 |         assert len(pred_list) == n
155 |         pred_lists.append(pred_list[0])
156 | 
157 | 
158 |     return compute_bleu(ref_lists, pred_lists, smooth=smooth)
159 | 
160 | #bleu("./weibo/result/ground_truth/", "./weibo/result/summary/", smooth=True)  # leftover debugging call; would run at import time, so it is disabled
161 | 
162 | if __name__ == "__main__":
163 |     parser = argparse.ArgumentParser()
164 |     parser.add_argument("-r", "--ref", help="reference path")
165 |     parser.add_argument("-p", "--pred", help="prediction path")
166 |     args = parser.parse_args()
167 | 
168 |     bleu, precisions, bp, ratio, translation_length, reference_length = bleu(args.ref, args.pred)
169 |     print("BLEU = ", bleu)
170 |     print("BLEU1 = ", precisions[0])
171 |     print("BLEU2 = ", precisions[1])
172 |     print("BLEU3 = ", precisions[2])
173 |     print("BLEU4 = ", precisions[3])
174 |     print("ratio = ", ratio)
175 | 
--------------------------------------------------------------------------------
/utils_pg.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #pylint: skip-file
3 | import numpy as np
4 | from numpy.random import random as rand
5 | import pickle
6 | import sys
7 | import os
8 | import shutil
9 | from copy import 
deepcopy 10 | import random 11 | 12 | import torch 13 | from torch import nn 14 | 15 | 16 | def init_seeds(): 17 | random.seed(123) 18 | torch.manual_seed(123) 19 | if torch.cuda.is_available(): 20 | torch.cuda.manual_seed_all(123) 21 | 22 | def init_lstm_weight(lstm): 23 | for param in lstm.parameters(): 24 | if len(param.shape) >= 2: # weights 25 | init_ortho_weight(param.data) 26 | else: # bias 27 | init_bias(param.data) 28 | 29 | def init_gru_weight(gru): 30 | for param in gru.parameters(): 31 | if len(param.shape) >= 2: # weights 32 | init_ortho_weight(param.data) 33 | else: # bias 34 | init_bias(param.data) 35 | 36 | def init_linear_weight(linear): 37 | init_xavier_weight(linear.weight) 38 | if linear.bias is not None: 39 | init_bias(linear.bias) 40 | 41 | def init_normal_weight(w): 42 | nn.init.normal_(w, mean=0, std=0.01) 43 | 44 | def init_uniform_weight(w): 45 | nn.init.uniform_(w, -0.1, 0.1) 46 | 47 | def init_ortho_weight(w): 48 | nn.init.orthogonal_(w) 49 | 50 | def init_xavier_weight(w): 51 | nn.init.xavier_normal_(w) 52 | 53 | def init_bias(b): 54 | nn.init.constant_(b, 0.) 55 | 56 | def rebuild_dir(path): 57 | if os.path.exists(path): 58 | try: 59 | shutil.rmtree(path) 60 | except OSError: 61 | pass 62 | os.mkdir(path) 63 | 64 | def save_model(f, model, optimizer): 65 | torch.save({"model_state_dict" : model.state_dict(), 66 | "optimizer_state_dict" : optimizer.state_dict()}, 67 | f) 68 | 69 | def load_model(f, model, optimizer): 70 | checkpoint = torch.load(f) 71 | model.load_state_dict(checkpoint["model_state_dict"]) 72 | optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) 73 | return model, optimizer 74 | 75 | def sort_samples(x, len_x, mask_x, y, len_y, \ 76 | mask_y, oys, x_ext, y_ext, oovs): 77 | sorted_x_idx = np.argsort(len_x)[::-1] 78 | 79 | sorted_x_len = np.array(len_x)[sorted_x_idx] 80 | sorted_x = x[:, sorted_x_idx] 81 | sorted_x_mask = mask_x[:, sorted_x_idx, :] 82 | sorted_oovs = [oovs[i] for i in sorted_x_idx] 83 | 84 | sorted_y_len = np.array(len_y)[sorted_x_idx] 85 | sorted_y = y[:, sorted_x_idx] 86 | sorted_y_mask = mask_y[:, sorted_x_idx, :] 87 | sorted_oys = [oys[i] for i in sorted_x_idx] 88 | sorted_x_ext = x_ext[:, sorted_x_idx] 89 | sorted_y_ext = y_ext[:, sorted_x_idx] 90 | 91 | return sorted_x, sorted_x_len, sorted_x_mask, sorted_y, \ 92 | sorted_y_len, sorted_y_mask, sorted_oys, \ 93 | sorted_x_ext, sorted_y_ext, sorted_oovs 94 | 95 | def print_sent_dec(y_pred, y, y_mask, oovs, modules, consts, options, batch_size): 96 | print("golden truth and prediction samples:") 97 | max_y_words = np.sum(y_mask, axis = 0) 98 | max_y_words = max_y_words.reshape((batch_size)) 99 | max_num_docs = 16 if batch_size > 16 else batch_size 100 | is_unicode = options["is_unicode"] 101 | dict_size = len(modules["i2w"]) 102 | for idx_doc in range(max_num_docs): 103 | print(idx_doc + 1, "----------------------------------------------------------------------------------------------------") 104 | sent_true= "" 105 | for idx_word in range(max_y_words[idx_doc]): 106 | i = y[idx_word, idx_doc] if options["has_learnable_w2v"] else np.argmax(y[idx_word, idx_doc]) 107 | if i in modules["i2w"]: 108 | sent_true += modules["i2w"][i] 109 | else: 110 | sent_true += oovs[idx_doc][i - dict_size] 111 | if not is_unicode: 112 | sent_true += " " 113 | 114 | if is_unicode: 115 | print(sent_true.encode("utf-8")) 116 | else: 117 | print(sent_true) 118 | 119 | print() 120 | 121 | sent_pred = "" 122 | for idx_word in range(max_y_words[idx_doc]): 123 | i = torch.argmax(y_pred[idx_word, 
idx_doc, :]).item() 124 | if i in modules["i2w"]: 125 | sent_pred += modules["i2w"][i] 126 | else: 127 | sent_pred += oovs[idx_doc][i - dict_size] 128 | if not is_unicode: 129 | sent_pred += " " 130 | if is_unicode: 131 | print(sent_pred.encode("utf-8")) 132 | else: 133 | print(sent_pred) 134 | print("----------------------------------------------------------------------------------------------------") 135 | print() 136 | 137 | 138 | def write_for_rouge(fname, ref_sents, dec_words, cfg): 139 | dec_sents = [] 140 | while len(dec_words) > 0: 141 | try: 142 | fst_period_idx = dec_words.index(".") 143 | except ValueError: 144 | fst_period_idx = len(dec_words) 145 | sent = dec_words[:fst_period_idx + 1] 146 | dec_words = dec_words[fst_period_idx + 1:] 147 | dec_sents.append(' '.join(sent)) 148 | 149 | ref_file = "".join((cfg.cc.GROUND_TRUTH_PATH, fname)) 150 | decoded_file = "".join((cfg.cc.SUMM_PATH, fname)) 151 | 152 | with open(ref_file, "w") as f: 153 | for idx, sent in enumerate(ref_sents): 154 | sent = sent.strip() 155 | f.write(sent) if idx == len(ref_sents) - 1 else f.write(sent + "\n") 156 | with open(decoded_file, "w") as f: 157 | for idx, sent in enumerate(dec_sents): 158 | sent = sent.strip() 159 | f.write(sent) if idx == len(dec_sents) - 1 else f.write(sent + "\n") 160 | 161 | def write_summ(dst_path, summ_list, num_summ, options, i2w = None, oovs=None, score_list = None): 162 | assert num_summ > 0 163 | with open(dst_path, "w") as f_summ: 164 | if num_summ == 1: 165 | if score_list != None: 166 | f_summ.write(str(score_list[0])) 167 | f_summ.write("\t") 168 | if i2w != None: 169 | ''' 170 | for e in summ_list: 171 | e = int(e) 172 | if e in i2w: 173 | print i2w[e], 174 | else: 175 | print oovs[e - len(i2w)], 176 | print "\n" 177 | ''' 178 | s = [] 179 | for e in summ_list: 180 | e = int(e) 181 | if e in i2w: 182 | s.append(i2w[e]) 183 | else: 184 | s.append(oovs[e - len(i2w)]) 185 | s = " ".join(s) 186 | else: 187 | s = " ".join(summ_list) 188 | f_summ.write(s) 189 | f_summ.write("\n") 190 | else: 191 | assert num_summ == len(summ_list) 192 | if score_list != None: 193 | assert num_summ == len(score_list) 194 | 195 | for i in range(num_summ): 196 | if score_list != None: 197 | f_summ.write(str(score_list[i])) 198 | f_summ.write("\t") 199 | if i2w != None: 200 | ''' 201 | for e in summ_list[i]: 202 | e = int(e) 203 | if e in i2w: 204 | print i2w[e], 205 | else: 206 | print oovs[e - len(i2w)], 207 | print "\n" 208 | ''' 209 | s = [] 210 | for e in summ_list[i]: 211 | e = int(e) 212 | if e in i2w: 213 | s.append(i2w[e]) 214 | else: 215 | s.append(oovs[e - len(i2w)]) 216 | s = " ".join(s) 217 | else: 218 | s = " ".join(summ_list[i]) 219 | 220 | f_summ.write(s) 221 | f_summ.write("\n") 222 | 223 | 224 | -------------------------------------------------------------------------------- /prepare_data.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import operator 3 | from os import makedirs 4 | from os.path import exists 5 | import argparse 6 | from configs import * 7 | import pickle 8 | import numpy as np 9 | import re 10 | from random import shuffle 11 | import string 12 | import struct 13 | 14 | def run(d_type, d_path): 15 | prepare_deepmind(d_path) 16 | 17 | stop_words = {"-lrb-", "-rrb-", "-"} 18 | unk_words = {"unk", ""} 19 | 20 | def get_xy_tuple(cont, head, cfg): 21 | x = read_cont(cont, cfg) 22 | y = read_head(head, cfg) 23 | 24 | if x != None and y != None: 25 | return (x, y) 26 | else: 27 | return None 28 | 29 
| def load_lines(d_path, f_name, configs):
30 |     lines = []
31 |     f_path = d_path + f_name
32 |     with open(f_path, 'r') as f:
33 |         for line in f:
34 |             line = line.strip("\n").lower()
35 |             fs = line.split("<summ-content>")  # separator between summary and article in the processed data
36 |             if len(fs) == 2:
37 |                 xy_tuple = get_xy_tuple(fs[1], fs[0], configs)
38 |             else:
39 |                 print("ERROR:" + line)
40 |                 continue
41 |             if xy_tuple != None:
42 |                 lines.append(xy_tuple)
43 |     return lines
44 | 
45 | def load_dict(d_path, f_name, dic, dic_list):
46 |     f_path = d_path + f_name
47 |     f = open(f_path, "r")
48 |     for line in f:
49 |         line = line.strip('\n').strip('\r').lower()
50 |         if line:
51 |             tf = line.split()
52 |             if len(tf) == 2:
53 |                 dic[tf[0]] = int(tf[1])
54 |                 dic_list.append(tf[0])
55 |             else:
56 |                 print("warning in vocab:", line)
57 |     return dic, dic_list
58 | 
59 | def to_dict(xys, dic):
60 |     # the dict should not consider the test set!!!!!
61 |     for xy in xys:
62 |         sents, summs = xy
63 |         y = summs[0]
64 |         for w in y:
65 |             if w in dic:
66 |                 dic[w] += 1
67 |             else:
68 |                 dic[w] = 1
69 | 
70 |         x = sents[0]
71 |         for w in x:
72 |             if w in dic:
73 |                 dic[w] += 1
74 |             else:
75 |                 dic[w] = 1
76 |     return dic
77 | 
78 | 
79 | def del_num(s):
80 |     return re.sub(r"(\b|\s+\-?|^\-?)(\d+|\d*\.\d+)\b","#", s)
81 | 
82 | def read_cont(f_cont, cfg):
83 |     lines = []
84 |     line = f_cont #del_num(f_cont)
85 |     words = line.split()
86 |     num_words = len(words)
87 |     if num_words >= cfg.MIN_LEN_X and num_words < cfg.MAX_LEN_X:
88 |         lines += words
89 |     elif num_words >= cfg.MAX_LEN_X:
90 |         lines += words[0:cfg.MAX_LEN_X]
91 |     lines += [cfg.W_EOS]
92 |     return (lines, f_cont) if len(lines) >= cfg.MIN_LEN_X and len(lines) <= cfg.MAX_LEN_X+1 else None
93 | 
94 | def abstract2sents(abstract, cfg):
95 |     cur = 0
96 |     sents = []
97 |     while True:
98 |         try:
99 |             start_p = abstract.index(cfg.W_LS, cur)
100 |             end_p = abstract.index(cfg.W_RS, start_p + 1)
101 |             cur = end_p + len(cfg.W_RS)
102 |             sents.append(abstract[start_p+len(cfg.W_LS):end_p])
103 |         except ValueError as e: # no more sentences
104 |             return sents
105 | 
106 | def read_head(f_head, cfg):
107 |     lines = []
108 | 
109 |     sents = abstract2sents(f_head, cfg)
110 |     line = ' '.join(sents)
111 |     words = line.split()
112 |     num_words = len(words)
113 |     if num_words >= cfg.MIN_LEN_Y and num_words <= cfg.MAX_LEN_Y:
114 |         lines += words
115 |         lines += [cfg.W_EOS]
116 |     elif num_words > cfg.MAX_LEN_Y: # not sure whether it should be stopped here
117 |         lines = words[0 : cfg.MAX_LEN_Y + 1] # one more word.
118 | 119 | return (lines, sents) if len(lines) >= cfg.MIN_LEN_Y and len(lines) <= cfg.MAX_LEN_Y+1 else None 120 | 121 | def prepare_deepmind(d_path): 122 | configs = DeepmindConfigs() 123 | TRAINING_PATH = configs.cc.TRAINING_DATA_PATH 124 | VALIDATE_PATH = configs.cc.VALIDATE_DATA_PATH 125 | TESTING_PATH = configs.cc.TESTING_DATA_PATH 126 | RESULT_PATH = configs.cc.RESULT_PATH 127 | MODEL_PATH = configs.cc.MODEL_PATH 128 | BEAM_SUMM_PATH = configs.cc.BEAM_SUMM_PATH 129 | BEAM_GT_PATH = configs.cc.BEAM_GT_PATH 130 | GROUND_TRUTH_PATH = configs.cc.GROUND_TRUTH_PATH 131 | SUMM_PATH = configs.cc.SUMM_PATH 132 | TMP_PATH = configs.cc.TMP_PATH 133 | 134 | print ("train: " + TRAINING_PATH) 135 | print ("test: " + TESTING_PATH) 136 | print ("validate: " + VALIDATE_PATH) 137 | print ("result: " + RESULT_PATH) 138 | print ("model: " + MODEL_PATH) 139 | print ("tmp: " + TMP_PATH) 140 | 141 | if not exists(TRAINING_PATH): 142 | makedirs(TRAINING_PATH) 143 | if not exists(VALIDATE_PATH): 144 | makedirs(VALIDATE_PATH) 145 | if not exists(TESTING_PATH): 146 | makedirs(TESTING_PATH) 147 | if not exists(RESULT_PATH): 148 | makedirs(RESULT_PATH) 149 | if not exists(MODEL_PATH): 150 | makedirs(MODEL_PATH) 151 | if not exists(BEAM_SUMM_PATH): 152 | makedirs(BEAM_SUMM_PATH) 153 | if not exists(BEAM_GT_PATH): 154 | makedirs(BEAM_GT_PATH) 155 | if not exists(GROUND_TRUTH_PATH): 156 | makedirs(GROUND_TRUTH_PATH) 157 | if not exists(SUMM_PATH): 158 | makedirs(SUMM_PATH) 159 | if not exists(TMP_PATH): 160 | makedirs(TMP_PATH) 161 | 162 | 163 | print ("trainset...") 164 | train_xy_list = load_lines(d_path, "train.txt", configs) 165 | 166 | print ("dump train...") 167 | pickle.dump(train_xy_list, open(TRAINING_PATH + "train.pkl", "wb"), protocol = pickle.HIGHEST_PROTOCOL) 168 | 169 | 170 | print ("fitering and building dict...") 171 | use_abisee = True 172 | all_dic1 = {} 173 | all_dic2 = {} 174 | dic_list = [] 175 | all_dic1, dic_list = load_dict(d_path, "vocab", all_dic1, dic_list) 176 | all_dic2 = to_dict(train_xy_list, all_dic2) 177 | for w, tf in all_dic2.items(): 178 | if w not in all_dic1: 179 | all_dic1[w] = tf 180 | 181 | candiate_list = dic_list[0:configs.PG_DICT_SIZE] # 50000 182 | candiate_set = set(candiate_list) 183 | 184 | dic = {} 185 | w2i = {} 186 | i2w = {} 187 | w2w = {} 188 | 189 | for w in [configs.W_PAD, configs.W_UNK, configs.W_BOS, configs.W_EOS]: 190 | #for w in [configs.W_PAD, configs.W_UNK, configs.W_BOS, configs.W_EOS, configs.W_LS, configs.W_RS]: 191 | w2i[w] = len(dic) 192 | i2w[w2i[w]] = w 193 | dic[w] = 10000 194 | w2w[w] = w 195 | 196 | for w, tf in all_dic1.items(): 197 | if w in candiate_set: 198 | w2i[w] = len(dic) 199 | i2w[w2i[w]] = w 200 | dic[w] = tf 201 | w2w[w] = w 202 | else: 203 | w2w[w] = configs.W_UNK 204 | hfw = [] 205 | sorted_x = sorted(dic.items(), key=operator.itemgetter(1), reverse=True) 206 | for w in sorted_x: 207 | hfw.append(w[0]) 208 | 209 | assert len(hfw) == len(dic) 210 | assert len(w2i) == len(dic) 211 | print ("dump dict...") 212 | pickle.dump([all_dic1, dic, hfw, w2i, i2w, w2w], open(TRAINING_PATH + "dic.pkl", "wb"), protocol = pickle.HIGHEST_PROTOCOL) 213 | 214 | print ("testset...") 215 | test_xy_list = load_lines(d_path, "test.txt", configs) 216 | 217 | print ("validset...") 218 | valid_xy_list = load_lines(d_path, "val.txt", configs) 219 | 220 | 221 | print ("#train = ", len(train_xy_list)) 222 | print ("#test = ", len(test_xy_list)) 223 | print ("#validate = ", len(valid_xy_list)) 224 | 225 | print ("#all_dic = ", len(all_dic1), ", #dic = ", 
len(dic), ", #hfw = ", len(hfw)) 226 | 227 | print ("dump test...") 228 | pickle.dump(test_xy_list, open(TESTING_PATH + "test.pkl", "wb"), protocol = pickle.HIGHEST_PROTOCOL) 229 | shuffle(test_xy_list) 230 | pickle.dump(test_xy_list[0:2000], open(TESTING_PATH + "pj2000.pkl", "wb"), protocol = pickle.HIGHEST_PROTOCOL) 231 | 232 | print ("dump validate...") 233 | pickle.dump(valid_xy_list, open(VALIDATE_PATH + "valid.pkl", "wb"), protocol = pickle.HIGHEST_PROTOCOL) 234 | pickle.dump(valid_xy_list[0:1000], open(VALIDATE_PATH + "pj1000.pkl", "wb"), protocol = pickle.HIGHEST_PROTOCOL) 235 | 236 | print ("done.") 237 | 238 | if __name__ == "__main__": 239 | parser = argparse.ArgumentParser() 240 | parser.add_argument("-d", "--data", default="cnndm", help="dataset path", ) 241 | args = parser.parse_args() 242 | 243 | data_type = "cnndm" 244 | raw_path = "./data/" 245 | 246 | print (data_type, raw_path) 247 | run(data_type, raw_path) 248 | -------------------------------------------------------------------------------- /transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import Parameter 4 | import torch.nn.functional as F 5 | import math 6 | 7 | class TransformerLayer(nn.Module): 8 | 9 | def __init__(self, embed_dim, ff_embed_dim, num_heads, dropout, with_external=False, weights_dropout = True): 10 | super(TransformerLayer, self).__init__() 11 | self.self_attn = MultiheadAttention(embed_dim, num_heads, dropout, weights_dropout) 12 | self.fc1 = nn.Linear(embed_dim, ff_embed_dim) 13 | self.fc2 = nn.Linear(ff_embed_dim, embed_dim) 14 | self.attn_layer_norm = LayerNorm(embed_dim) 15 | self.ff_layer_norm = LayerNorm(embed_dim) 16 | self.with_external = with_external 17 | self.dropout = dropout 18 | if self.with_external: 19 | self.external_attn = MultiheadAttention(embed_dim, num_heads, dropout, weights_dropout) 20 | self.external_layer_norm = LayerNorm(embed_dim) 21 | self.reset_parameters() 22 | 23 | def reset_parameters(self): 24 | nn.init.normal_(self.fc1.weight, std=0.02) 25 | nn.init.normal_(self.fc2.weight, std=0.02) 26 | nn.init.constant_(self.fc1.bias, 0.) 27 | nn.init.constant_(self.fc2.bias, 0.) 
28 | 29 | def forward(self, x, kv = None, 30 | self_padding_mask = None, self_attn_mask = None, 31 | external_memories = None, external_padding_mask=None, 32 | need_weights = False): 33 | # x: seq_len x bsz x embed_dim 34 | residual = x 35 | if kv is None: 36 | x, self_attn = self.self_attn(query=x, key=x, value=x, key_padding_mask=self_padding_mask, attn_mask=self_attn_mask, need_weights = need_weights) 37 | else: 38 | x, self_attn = self.self_attn(query=x, key=kv, value=kv, key_padding_mask=self_padding_mask, attn_mask=self_attn_mask, need_weights = need_weights) 39 | 40 | x = F.dropout(x, p=self.dropout, training=self.training) 41 | x = self.attn_layer_norm(residual + x) 42 | 43 | if self.with_external: 44 | residual = x 45 | x, external_attn = self.external_attn(query=x, key=external_memories, value=external_memories, key_padding_mask=external_padding_mask, need_weights = need_weights) 46 | x = F.dropout(x, p=self.dropout, training=self.training) 47 | x = self.external_layer_norm(residual + x) 48 | else: 49 | external_attn = None 50 | 51 | residual = x 52 | x = gelu(self.fc1(x)) 53 | x = F.dropout(x, p=self.dropout, training=self.training) 54 | x = self.fc2(x) 55 | x = F.dropout(x, p=self.dropout, training=self.training) 56 | x = self.ff_layer_norm(residual + x) 57 | 58 | return x, self_attn, external_attn 59 | 60 | class MultiheadAttention(nn.Module): 61 | 62 | def __init__(self, embed_dim, num_heads, dropout=0., weights_dropout=True): 63 | super(MultiheadAttention, self).__init__() 64 | self.embed_dim = embed_dim 65 | self.num_heads = num_heads 66 | self.dropout = dropout 67 | self.head_dim = embed_dim // num_heads 68 | assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" 69 | self.scaling = self.head_dim ** -0.5 70 | 71 | self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim)) 72 | self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim)) 73 | 74 | self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True) 75 | self.weights_dropout = weights_dropout 76 | self.reset_parameters() 77 | 78 | def reset_parameters(self): 79 | nn.init.normal_(self.in_proj_weight, std=0.02) 80 | nn.init.normal_(self.out_proj.weight, std=0.02) 81 | nn.init.constant_(self.in_proj_bias, 0.) 82 | nn.init.constant_(self.out_proj.bias, 0.) 
83 | 84 | def forward(self, query, key, value, key_padding_mask=None, attn_mask=None, need_weights=False): 85 | """ Input shape: Time x Batch x Channel 86 | key_padding_mask: Time x batch 87 | attn_mask: tgt_len x src_len 88 | """ 89 | qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr() 90 | kv_same = key.data_ptr() == value.data_ptr() 91 | 92 | tgt_len, bsz, embed_dim = query.size() 93 | assert key.size() == value.size() 94 | 95 | if qkv_same: 96 | # self-attention 97 | q, k, v = self.in_proj_qkv(query) 98 | elif kv_same: 99 | # encoder-decoder attention 100 | q = self.in_proj_q(query) 101 | k, v = self.in_proj_kv(key) 102 | else: 103 | q = self.in_proj_q(query) 104 | k = self.in_proj_k(key) 105 | v = self.in_proj_v(value) 106 | q = q*self.scaling 107 | 108 | q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1) 109 | k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) 110 | v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) 111 | 112 | src_len = k.size(1) 113 | # k,v: bsz*heads x src_len x dim 114 | # q: bsz*heads x tgt_len x dim 115 | 116 | attn_weights = torch.bmm(q, k.transpose(1, 2)) 117 | assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] 118 | 119 | if attn_mask is not None: 120 | attn_weights.masked_fill_( 121 | attn_mask.unsqueeze(0), 122 | float('-inf') 123 | ) 124 | 125 | if key_padding_mask is not None: 126 | # don't attend to padding symbols 127 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) 128 | attn_weights.masked_fill_( 129 | key_padding_mask.transpose(0, 1).unsqueeze(1).unsqueeze(2), 130 | float('-inf') 131 | ) 132 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) 133 | 134 | 135 | attn_weights = F.softmax(attn_weights, dim=-1) 136 | 137 | if self.weights_dropout: 138 | attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training) 139 | 140 | attn = torch.bmm(attn_weights, v) 141 | if not self.weights_dropout: 142 | attn = F.dropout(attn, p=self.dropout, training=self.training) 143 | 144 | assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] 145 | 146 | attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) 147 | attn = self.out_proj(attn) 148 | 149 | if need_weights: 150 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) 151 | 152 | #attn_weights, _ = attn_weights.max(dim=1) 153 | attn_weights = attn_weights[:, 0, :, :] 154 | #attn_weights = attn_weights.mean(dim=1) 155 | attn_weights = attn_weights.transpose(0, 1) 156 | else: 157 | attn_weights = None 158 | 159 | return attn, attn_weights 160 | 161 | def in_proj_qkv(self, query): 162 | return self._in_proj(query).chunk(3, dim=-1) 163 | 164 | def in_proj_kv(self, key): 165 | return self._in_proj(key, start=self.embed_dim).chunk(2, dim=-1) 166 | 167 | def in_proj_q(self, query): 168 | return self._in_proj(query, end=self.embed_dim) 169 | 170 | def in_proj_k(self, key): 171 | return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim) 172 | 173 | def in_proj_v(self, value): 174 | return self._in_proj(value, start=2 * self.embed_dim) 175 | 176 | def _in_proj(self, input, start=0, end=None): 177 | weight = self.in_proj_weight 178 | bias = self.in_proj_bias 179 | weight = weight[start:end, :] 180 | if bias is not None: 181 | bias = bias[start:end] 182 | return F.linear(input, weight, bias) 183 | 184 | def gelu(x): 185 | cdf = 0.5 * (1.0 + torch.erf(x / 
math.sqrt(2.0))) 186 | return cdf*x 187 | 188 | class LayerNorm(nn.Module): 189 | def __init__(self, hidden_size, eps=1e-12): 190 | super(LayerNorm, self).__init__() 191 | self.weight = nn.Parameter(torch.Tensor(hidden_size)) 192 | self.bias = nn.Parameter(torch.Tensor(hidden_size)) 193 | self.eps = eps 194 | self.reset_parameters() 195 | def reset_parameters(self): 196 | nn.init.constant_(self.weight, 1.) 197 | nn.init.constant_(self.bias, 0.) 198 | 199 | def forward(self, x): 200 | u = x.mean(-1, keepdim=True) 201 | s = (x - u).pow(2).mean(-1, keepdim=True) 202 | x = (x - u) / torch.sqrt(s + self.eps) 203 | return self.weight * x + self.bias 204 | 205 | def Embedding(num_embeddings, embedding_dim, padding_idx): 206 | m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) 207 | nn.init.normal_(m.weight, std=0.02) 208 | nn.init.constant_(m.weight[padding_idx], 0) 209 | return m 210 | 211 | class SelfAttentionMask(nn.Module): 212 | def __init__(self, init_size = 100, device = 0): 213 | super(SelfAttentionMask, self).__init__() 214 | self.weights = SelfAttentionMask.get_mask(init_size) 215 | self.device = device 216 | 217 | @staticmethod 218 | def get_mask(size): 219 | weights = torch.triu(torch.ones((size, size), dtype = torch.bool), 1) 220 | return weights 221 | 222 | def forward(self, size): 223 | if self.weights is None or size > self.weights.size(0): 224 | self.weights = SelfAttentionMask.get_mask(size) 225 | res = self.weights[:size,:size].cuda(self.device).detach() 226 | return res 227 | 228 | class LearnedPositionalEmbedding(nn.Module): 229 | """This module produces LearnedPositionalEmbedding. 230 | """ 231 | def __init__(self, embedding_dim, init_size=1024, device=0): 232 | super(LearnedPositionalEmbedding, self).__init__() 233 | self.weights = nn.Embedding(init_size, embedding_dim) 234 | self.device= device 235 | self.reset_parameters() 236 | 237 | def reset_parameters(self): 238 | nn.init.normal_(self.weights.weight, std=0.02) 239 | 240 | def forward(self, input, offset=0): 241 | """Input is expected to be of size [seq_len x bsz].""" 242 | seq_len, bsz = input.size() 243 | positions = (offset + torch.arange(seq_len)).cuda(self.device) 244 | res = self.weights(positions).unsqueeze(1).expand(-1, bsz, -1) 245 | return res 246 | 247 | class SinusoidalPositionalEmbedding(nn.Module): 248 | """This module produces sinusoidal positional embeddings of any length. 249 | """ 250 | def __init__(self, embedding_dim, init_size=1024, device=0): 251 | super(SinusoidalPositionalEmbedding, self).__init__() 252 | self.embedding_dim = embedding_dim 253 | self.weights = SinusoidalPositionalEmbedding.get_embedding( 254 | init_size, 255 | embedding_dim 256 | ) 257 | self.device= device 258 | 259 | @staticmethod 260 | def get_embedding(num_embeddings, embedding_dim): 261 | """Build sinusoidal embeddings. 262 | This matches the implementation in tensor2tensor, but differs slightly 263 | from the description in Section 3.5 of "Attention Is All You Need". 
264 |         """
265 |         half_dim = embedding_dim // 2
266 |         emb = math.log(10000) / (half_dim - 1)
267 |         emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
268 |         emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
269 |         emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
270 |         if embedding_dim % 2 == 1:
271 |             emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
272 |         return emb
273 | 
274 |     def forward(self, input, offset=0):
275 |         """Input is expected to be of size [seq_len x bsz]."""
276 |         seq_len, bsz = input.size()
277 |         mx_position = seq_len + offset
278 |         if self.weights is None or mx_position > self.weights.size(0):
279 |             # recompute/expand embeddings if needed
280 |             self.weights = SinusoidalPositionalEmbedding.get_embedding(
281 |                 mx_position,
282 |                 self.embedding_dim,
283 |             )
284 | 
285 |         positions = offset + torch.arange(seq_len)
286 |         res = self.weights.index_select(0, positions).unsqueeze(1).expand(-1, bsz, -1).cuda(self.device).detach()
287 |         return res
288 | 
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | cudaid = 6
4 | os.environ["CUDA_VISIBLE_DEVICES"] = str(cudaid)
5 | 
6 | import sys
7 | import time
8 | import numpy as np
9 | import pickle
10 | import copy
11 | import random
12 | from random import shuffle
13 | import math
14 | 
15 | import torch
16 | import torch.nn as nn
17 | from torch.autograd import Variable
18 | 
19 | import data as datar
20 | from model import *
21 | from utils_pg import *
22 | from configs import *
23 | from optim import Optim
24 | 
25 | cfg = DeepmindConfigs()
26 | TRAINING_DATASET_CLS = DeepmindTraining
27 | TESTING_DATASET_CLS = DeepmindTesting
28 | 
29 | def print_basic_info(modules, consts, options):
30 |     if options["is_debugging"]:
31 |         print("\nWARNING: IN DEBUGGING MODE\n")
32 |     if options["copy"]:
33 |         print("USE COPY MECHANISM")
34 |     if options["coverage"]:
35 |         print("USE COVERAGE MECHANISM")
36 |     if options["avg_nll"]:
37 |         print("USE AVG NLL as LOSS")
38 |     else:
39 |         print("USE NLL as LOSS")
40 |     if options["has_learnable_w2v"]:
41 |         print("USE LEARNABLE W2V EMBEDDING")
42 |     if options["is_bidirectional"]:
43 |         print("USE BI-DIRECTIONAL RNN")
44 |     if options["omit_eos"]:
45 |         print("<eos> IS OMITTED IN TESTING DATA")
46 |     if options["prediction_bytes_limitation"]:
47 |         print("MAXIMUM BYTES IN PREDICTION IS LIMITED")
48 |     print("RNN TYPE: " + options["cell"])
49 |     for k in consts:
50 |         print(k + ":", consts[k])
51 | 
52 | def init_modules():
53 | 
54 |     init_seeds()
55 | 
56 |     options = {}
57 | 
58 |     options["is_debugging"] = False
59 |     options["is_predicting"] = False
60 |     options["model_selection"] = False # When options["is_predicting"] = True, True means using the validation set for tuning; False means real testing.
141 | def beam_decode(fname, batch, model, modules, consts, options):
142 |     fname = str(fname)
143 | 
144 |     beam_size = consts["beam_size"]
145 |     num_live = 1
146 |     num_dead = 0
147 |     samples = []
148 |     sample_scores = np.zeros(beam_size)
149 | 
150 |     last_traces = [[]]
151 |     last_scores = torch.FloatTensor(np.zeros(1)).to(options["device"])
152 |     last_c_scores = torch.FloatTensor(np.zeros(1)).to(options["device"])
153 |     last_states = [[]]
154 | 
155 |     if options["copy"]:
156 |         x, x_mask, word_emb, padding_mask, y, len_y, ref_sents, max_ext_len, oovs = batch
157 |     else:
158 |         x, word_emb, padding_mask, y, len_y, ref_sents = batch
159 | 
160 |     ys = torch.LongTensor(np.ones((1, num_live), dtype="int64") * modules["bos_idx"]).to(options["device"])
161 |     x = x.unsqueeze(1)
162 |     word_emb = word_emb.unsqueeze(1)
163 |     padding_mask = padding_mask.unsqueeze(1)
164 |     if options["copy"]:
165 |         x_mask = x_mask.unsqueeze(1)
166 | 
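    # Note on the beam bookkeeping in the loop below (illustrative numbers): the
    # scores of all num_live hypotheses are flattened into one vector of length
    # num_live * dict_size, so for each top-k index, idx // dict_size recovers which
    # hypothesis it extends and idx % dict_size recovers the extending word id; e.g.
    # with dict_size = 5, flat index 7 means hypothesis 1 (7 // 5) plus word id 2 (7 % 5).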
167 |     for step in range(consts["max_len_predict"]):
168 |         tile_word_emb = word_emb.repeat(1, num_live, 1)
169 |         tile_padding_mask = padding_mask.repeat(1, num_live)
170 |         if options["copy"]:
171 |             tile_x = x.repeat(1, num_live)
172 |             tile_x_mask = x_mask.repeat(1, num_live, 1)
173 | 
174 |         if options["copy"]:
175 |             y_pred, attn_dist = model.decode(ys, tile_x_mask, None, tile_word_emb, tile_padding_mask, tile_x, max_ext_len)
176 |         else:
177 |             y_pred, attn_dist = model.decode(ys, None, None, tile_word_emb, tile_padding_mask)
178 | 
179 |         dict_size = y_pred.shape[-1]
180 |         y_pred = y_pred[-1, :, :]
181 |         if options["coverage"]:
182 |             attn_dist = attn_dist[-1, :, :]
183 | 
184 |         cand_y_scores = last_scores + torch.log(y_pred)  # larger is better
185 |         if options["coverage"]:
186 |             cand_scores = (cand_y_scores + last_c_scores).flatten()
187 |         else:
188 |             cand_scores = cand_y_scores.flatten()
189 |         idx_top_joint_scores = torch.topk(cand_scores, beam_size - num_dead)[1]
190 | 
191 |         idx_last_traces = idx_top_joint_scores // dict_size
192 |         idx_word_now = idx_top_joint_scores % dict_size
193 |         top_joint_scores = cand_y_scores.flatten()[idx_top_joint_scores]
194 | 
195 |         traces_now = []
196 |         scores_now = np.zeros((beam_size - num_dead))
197 |         states_now = []
198 | 
199 |         for i, [j, k] in enumerate(zip(idx_last_traces, idx_word_now)):
200 |             traces_now.append(last_traces[j] + [k])
201 |             scores_now[i] = copy.copy(top_joint_scores[i])
202 |             if options["coverage"]:
203 |                 states_now.append(last_states[j] + [copy.copy(attn_dist[j, :])])
204 | 
205 |         num_live = 0
206 |         last_traces = []
207 |         last_scores = []
208 |         last_states = []
209 |         last_c_scores = []
210 |         dead_ids = []
211 |         for i in range(len(traces_now)):
212 |             if traces_now[i][-1] == modules["eos_emb"] and len(traces_now[i]) >= consts["min_len_predict"]:
213 |                 samples.append([str(e.item()) for e in traces_now[i][:-1]])
214 |                 sample_scores[num_dead] = scores_now[i]
215 |                 num_dead += 1
216 |                 dead_ids += [i]
217 |             else:
218 |                 last_traces.append(traces_now[i])
219 |                 last_scores.append(scores_now[i])
220 | 
221 |                 if options["coverage"]:
222 |                     last_states.append(states_now[i])
223 |                     attns = torch.stack(states_now[i])
224 |                     m, n = attns.shape
225 |                     cp = torch.sum(attns, dim=0)
226 |                     cp = torch.max(cp, torch.ones_like(cp))
227 |                     cp = -consts["beta"] * (torch.sum(cp).item() - n)
228 |                     last_c_scores.append(cp)
229 | 
230 |                 num_live += 1
231 |         if num_live == 0 or num_dead >= beam_size:
232 |             break
233 | 
234 |         if options["coverage"]:
235 |             last_c_scores = torch.FloatTensor(np.array(last_c_scores).reshape((num_live, 1))).to(options["device"])
236 | 
237 |         last_scores = torch.FloatTensor(np.array(last_scores).reshape((num_live, 1))).to(options["device"])
238 |         next_y = []
239 |         for e in last_traces:
240 |             eid = e[-1].item()
241 |             if eid in modules["i2w"]:
242 |                 next_y.append(eid)
243 |             else:
244 |                 next_y.append(modules["lfw_emb"])  # unk for copy mechanism
245 | 
246 |         next_y = np.array(next_y).reshape((1, num_live))
247 |         next_y = torch.LongTensor(next_y).to(options["device"])
248 | 
249 |         if step == 0:
250 |             ys = ys.repeat(1, num_live)
251 |         ys_ = []
252 |         py_ = []  # unused
253 |         for i in range(ys.size(1)):
254 |             if i not in dead_ids:
255 |                 ys_.append(ys[:, i])
256 |         ys = torch.cat([torch.stack(ys_, dim=1), next_y], dim=0)
257 | 
258 |         assert num_live + num_dead == beam_size
259 | 
260 |     if num_live > 0:
261 |         for i in range(num_live):
262 |             samples.append([str(e.item()) for e in last_traces[i]])
263 |             sample_scores[num_dead] = last_scores[i]
264 |             num_dead += 1
265 | 
266 |     # weight by length
267 |     for i in range(len(sample_scores)):
268 |         sent_len = float(len(samples[i]))
269 |         lp = np.power(5 + sent_len, consts["alpha"]) / np.power(5 + 1, consts["alpha"])  # GNMT-style length penalty
270 |         sample_scores[i] /= lp
271 | 
272 |     idx_sorted_scores = np.argsort(sample_scores)  # ascending order
273 |     if options["has_y"]:
274 |         ly = len_y[0]
275 |         y_true = y[0 : ly].tolist()
276 |         y_true = [str(i) for i in y_true[:-1]]  # delete <eos>
277 | 
278 |     sorted_samples = []
279 |     sorted_scores = []
280 |     filter_idx = []
281 |     for e in idx_sorted_scores:
282 |         if len(samples[e]) >= consts["min_len_predict"]:
283 |             filter_idx.append(e)
284 |     if len(filter_idx) == 0:
285 |         filter_idx = idx_sorted_scores
286 |     for e in filter_idx:
287 |         sorted_samples.append(samples[e])
288 |         sorted_scores.append(sample_scores[e])
289 | 
290 |     num_samples = len(sorted_samples)
291 |     if len(sorted_samples) == 1:
292 |         sorted_samples = sorted_samples[0]
293 |         num_samples = 1
294 | 
295 |     # for tasks with a byte-length limitation
296 |     if options["prediction_bytes_limitation"]:
297 |         for i in range(len(sorted_samples)):
298 |             sample = sorted_samples[i]
299 |             b = 0
300 |             for j in range(len(sample)):
301 |                 e = int(sample[j])
302 |                 if e in modules["i2w"]:
303 |                     word = modules["i2w"][e]
304 |                 else:
305 |                     word = oovs[e - len(modules["i2w"])]
306 |                 if j == 0:
307 |                     b += len(word)
308 |                 else:
309 |                     b += len(word) + 1
310 |                 if b > consts["max_byte_predict"]:
311 |                     sorted_samples[i] = sorted_samples[i][0 : j]
312 |                     break
313 | 
314 |     dec_words = []
315 | 
316 |     for e in sorted_samples[-1]:
317 |         e = int(e)
318 |         if e in modules["i2w"]:  # if not copy, the words are all in the dictionary
319 |             dec_words.append(modules["i2w"][e])
320 |         else:
321 |             dec_words.append(oovs[e - len(modules["i2w"])])
322 | 
323 |     write_for_rouge(fname, ref_sents, dec_words, cfg)
324 | 
325 |     # beam search history for checking
326 |     if not options["copy"]:
327 |         oovs = None
328 |     write_summ("".join((cfg.cc.BEAM_SUMM_PATH, fname)), sorted_samples, num_samples, options, modules["i2w"], oovs, sorted_scores)
329 |     write_summ("".join((cfg.cc.BEAM_GT_PATH, fname)), y_true, 1, options, modules["i2w"], oovs)
330 | 
331 | 
332 | 
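Aside (illustration, not part of main.py): the re-ranking above divides each accumulated log-probability by the GNMT-style length penalty lp = (5 + |Y|)^alpha / (5 + 1)^alpha, while the coverage term computed inside the loop adds cp = -beta * (sum_j max(sum_t a_tj, 1) - n), penalizing source positions whose cumulative attention exceeds 1. A quick numeric check of the length penalty (alpha = 0.9 is an assumed value; the repo reads it from cfg.ALPHA):

import numpy as np

alpha = 0.9  # assumed for illustration; configured as cfg.ALPHA in this repo
for sent_len in (1, 10, 50):
    lp = np.power(5 + sent_len, alpha) / np.power(5 + 1, alpha)
    print(sent_len, round(float(lp), 2))
# prints 1 1.0, 10 2.28, 50 7.34: longer hypotheses are divided by a larger
# penalty, keeping scores comparable across candidate lengths.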
333 | def predict(model, modules, consts, options):
334 |     print("start predicting...")
335 |     model.eval()
336 |     options["has_y"] = TESTING_DATASET_CLS.HAS_Y
337 |     if options["beam_decoding"]:
338 |         print("using beam search")
339 |     else:
340 |         print("using greedy search")
341 |     rebuild_dir(cfg.cc.BEAM_SUMM_PATH)
342 |     rebuild_dir(cfg.cc.BEAM_GT_PATH)
343 |     rebuild_dir(cfg.cc.GROUND_TRUTH_PATH)
344 |     rebuild_dir(cfg.cc.SUMM_PATH)
345 | 
346 |     print("loading test set...")
347 |     if options["model_selection"]:
348 |         xy_list = pickle.load(open(cfg.cc.VALIDATE_DATA_PATH + "pj1000.pkl", "rb"))
349 |     else:
350 |         xy_list = pickle.load(open(cfg.cc.TESTING_DATA_PATH + "test.pkl", "rb"))
351 |     batch_list, num_files, num_batches = datar.batched(len(xy_list), options, consts)
352 | 
353 |     print("num_files = ", num_files, ", num_batches = ", num_batches)
354 | 
355 |     running_start = time.time()
356 |     partial_num = 0
357 |     total_num = 0
358 |     si = 0
359 |     for idx_batch in range(num_batches):
360 |         test_idx = batch_list[idx_batch]
361 |         batch_raw = [xy_list[xy_idx] for xy_idx in test_idx]
362 |         batch = datar.get_data(batch_raw, modules, consts, options)
363 | 
364 |         assert len(test_idx) == batch.x.shape[1]  # local_batch_size
365 | 
366 | 
367 |         word_emb, padding_mask = model.encode(torch.LongTensor(batch.x).to(options["device"]))
368 | 
369 |         if options["beam_decoding"]:
370 |             for idx_s in range(len(test_idx)):
371 |                 if options["copy"]:
372 |                     inputx = (torch.LongTensor(batch.x_ext[:, idx_s]).to(options["device"]), \
373 |                               torch.FloatTensor(batch.x_mask[:, idx_s, :]).to(options["device"]), \
374 |                               word_emb[:, idx_s, :], padding_mask[:, idx_s], \
375 |                               batch.y[:, idx_s], [batch.len_y[idx_s]], batch.original_summarys[idx_s], \
376 |                               batch.max_ext_len, batch.x_ext_words[idx_s])
377 |                 else:
378 |                     inputx = (torch.LongTensor(batch.x[:, idx_s]).to(options["device"]), word_emb[:, idx_s, :], padding_mask[:, idx_s], \
379 |                               batch.y[:, idx_s], [batch.len_y[idx_s]], batch.original_summarys[idx_s])
380 | 
381 |                 beam_decode(si, inputx, model, modules, consts, options)
382 |                 si += 1
383 |         else:
384 |             pass
385 |             #greedy_decode() is not implemented in this file; a minimal sketch follows this function
386 | 
387 |         testing_batch_size = len(test_idx)
388 |         partial_num += testing_batch_size
389 |         total_num += testing_batch_size
390 |         if partial_num >= consts["testing_print_size"]:
391 |             print(total_num, "summs are generated")
392 |             partial_num = 0
393 |     print(si, total_num)
394 | 
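Aside (illustration, not part of main.py): greedy_decode is referenced above but never implemented. A minimal sketch for the non-copy configuration, assuming the same model.decode signature and <eos> convention used in beam_decode (untested against the real Model):

def greedy_decode(fname, batch, model, modules, consts, options):
    # unpack exactly like beam_decode's non-copy branch
    x, word_emb, padding_mask, y, len_y, ref_sents = batch
    word_emb = word_emb.unsqueeze(1)
    padding_mask = padding_mask.unsqueeze(1)
    ys = torch.LongTensor([[modules["bos_idx"]]]).to(options["device"])
    trace = []
    for _ in range(consts["max_len_predict"]):
        y_pred, _ = model.decode(ys, None, None, word_emb, padding_mask)
        next_id = torch.argmax(y_pred[-1, 0, :]).item()  # greedy: take the argmax word
        if next_id == modules["eos_emb"] and len(trace) >= consts["min_len_predict"]:
            break
        trace.append(next_id)
        ys = torch.cat([ys, torch.LongTensor([[next_id]]).to(options["device"])], dim=0)
    dec_words = [modules["i2w"][e] for e in trace if e in modules["i2w"]]
    write_for_rouge(str(fname), ref_sents, dec_words, cfg)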
395 | def run(existing_model_name=None):
396 |     modules, consts, options = init_modules()
397 | 
398 |     if options["is_predicting"]:
399 |         need_load_model = True
400 |         training_model = False
401 |         predict_model = True
402 |     else:
403 |         need_load_model = False
404 |         training_model = True
405 |         predict_model = False
406 | 
407 |     print_basic_info(modules, consts, options)
408 | 
409 |     if training_model:
410 |         print("loading train set...")
411 |         if options["is_debugging"]:
412 |             xy_list = pickle.load(open(cfg.cc.TESTING_DATA_PATH + "test.pkl", "rb"))
413 |         else:
414 |             xy_list = pickle.load(open(cfg.cc.TRAINING_DATA_PATH + "train.pkl", "rb"))
415 |         batch_list, num_files, num_batches = datar.batched(len(xy_list), options, consts)
416 |         print("num_files = ", num_files, ", num_batches = ", num_batches)
417 | 
418 |     running_start = time.time()
419 |     if True:  # TODO: refactor
420 |         print("compiling model ...")
421 |         model = Model(modules, consts, options)
422 |         if options["cuda"]:
423 |             model.cuda()
424 |         optimizer = torch.optim.Adagrad(model.parameters(), lr=consts["lr"], initial_accumulator_value=0.1)
425 | 
426 |     model_name = "".join(["cnndm.s2s.", options["cell"]])
427 |     existing_epoch = 0
428 |     if need_load_model:
429 |         if existing_model_name is None:
430 |             existing_model_name = "cnndm.s2s.transformer.gpu0.epoch27.2"
431 |         print("loading existing model:", existing_model_name)
432 |         model, optimizer = load_model(cfg.cc.MODEL_PATH + existing_model_name, model, optimizer)
433 | 
434 |     if training_model:
435 |         print("start training model")
436 |         model.train()
437 |         print_size = num_files // consts["print_time"] if num_files >= consts["print_time"] else num_files
438 | 
439 |         last_total_error = float("inf")
440 |         print("max epoch:", consts["max_epoch"])
441 |         for epoch in range(0, consts["max_epoch"]):
442 |             print("epoch: ", epoch + existing_epoch)
443 |             num_partial = 1
444 |             total_error = 0.0
445 |             error_c = 0.0
446 |             partial_num_files = 0
447 |             epoch_start = time.time()
448 |             partial_start = time.time()
449 |             # reshuffle the training set each epoch
450 |             batch_list, num_files, num_batches = datar.batched(len(xy_list), options, consts)
451 |             used_batch = 0.
452 |             for idx_batch in range(num_batches):
453 |                 train_idx = batch_list[idx_batch]
454 |                 batch_raw = [xy_list[xy_idx] for xy_idx in train_idx]
455 |                 if len(batch_raw) != consts["batch_size"]:
456 |                     continue
457 |                 local_batch_size = len(batch_raw)
458 |                 batch = datar.get_data(batch_raw, modules, consts, options)
459 | 
460 | 
461 |                 model.zero_grad()
462 | 
463 |                 y_pred, cost = model(torch.LongTensor(batch.x).to(options["device"]), \
464 |                                      torch.LongTensor(batch.y_inp).to(options["device"]), \
465 |                                      torch.LongTensor(batch.y).to(options["device"]), \
466 |                                      torch.FloatTensor(batch.x_mask).to(options["device"]), \
467 |                                      torch.FloatTensor(batch.y_mask).to(options["device"]), \
468 |                                      torch.LongTensor(batch.x_ext).to(options["device"]), \
469 |                                      torch.LongTensor(batch.y_ext).to(options["device"]), \
470 |                                      batch.max_ext_len)
471 | 
472 | 
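                # Standard update below: backpropagate the loss, clip the global
                # gradient norm to consts["norm_clip"] (cfg.NORM_CLIP) for stability,
                # then take an Adagrad step.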
473 |                 cost.backward()
474 |                 torch.nn.utils.clip_grad_norm_(model.parameters(), consts["norm_clip"])
475 |                 optimizer.step()
476 | 
477 | 
478 |                 cost = cost.item()
479 |                 total_error += cost
480 |                 used_batch += 1
481 |                 partial_num_files += consts["batch_size"]
482 |                 if partial_num_files // print_size == 1 and idx_batch < num_batches:
483 |                     print(idx_batch + 1, "/", num_batches, "batches have been processed,", \
484 |                           "average cost until now:", "cost =", total_error / used_batch, ",", \
485 |                           "cost_c =", error_c / used_batch, ",", \
486 |                           "time:", time.time() - partial_start)
487 |                     partial_num_files = 0
488 |                     if not options["is_debugging"]:
489 |                         print("save model... ")
490 |                         file_name = model_name + ".gpu" + str(consts["idx_gpu"]) + ".epoch" + str(epoch // consts["save_epoch"] + existing_epoch) + "." + str(num_partial)
491 |                         save_model(cfg.cc.MODEL_PATH + file_name, model, optimizer)
492 |                         if options["fire"]:
493 |                             shutil.move(cfg.cc.MODEL_PATH + file_name, "/out/")
494 | 
495 |                         print("finished")
496 |                     num_partial += 1
497 |             print("in this epoch, total average cost =", total_error / used_batch, ",", \
498 |                   "cost_c =", error_c / used_batch, ",", \
499 |                   "time:", time.time() - epoch_start)
500 | 
501 |             print_sent_dec(y_pred, batch.y, batch.y_mask, batch.x_ext_words, modules, consts, options, local_batch_size)
502 | 
503 |             if last_total_error > total_error or options["is_debugging"]:
504 |                 last_total_error = total_error
505 |                 if not options["is_debugging"]:
506 |                     print("save model... ")
507 |                     file_name = model_name + ".gpu" + str(consts["idx_gpu"]) + ".epoch" + str(epoch // consts["save_epoch"] + existing_epoch) + "." + str(num_partial)
508 |                     save_model(cfg.cc.MODEL_PATH + file_name, model, optimizer)
509 |                     if options["fire"]:
510 |                         shutil.move(cfg.cc.MODEL_PATH + file_name, "/out/")
511 | 
512 |                     print("finished")
513 |             else:
514 |                 print("optimization finished")
515 |                 break
516 | 
517 |         print("save final model... ")
518 |         file_name = model_name + ".final.gpu" + str(consts["idx_gpu"]) + ".epoch" + str(epoch // consts["save_epoch"] + existing_epoch) + "." + str(num_partial)
519 |         save_model(cfg.cc.MODEL_PATH + file_name, model, optimizer)
520 |         if options["fire"]:
521 |             shutil.move(cfg.cc.MODEL_PATH + file_name, "/out/")
522 | 
523 |         print("finished")
524 |     else:
525 |         print("skip training model")
526 | 
527 |     if predict_model:
528 |         predict(model, modules, consts, options)
529 |     print("Finished, time:", time.time() - running_start)
530 | 
531 | if __name__ == "__main__":
532 |     np.set_printoptions(threshold=np.inf)
533 |     existing_model_name = sys.argv[1] if len(sys.argv) > 1 else None
534 |     run(existing_model_name)
535 | 
--------------------------------------------------------------------------------
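Usage note: with options["is_predicting"] = False (the default set in init_modules), `python main.py` trains from scratch and saves checkpoints to cfg.cc.MODEL_PATH; after flipping is_predicting to True, `python main.py <checkpoint>` (for example the hard-coded default cnndm.s2s.transformer.gpu0.epoch27.2) loads that checkpoint and decodes the test set.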