├── .gitignore
├── tuning.sh
├── optim.py
├── LICENSE
├── label_smoothing.py
├── prepare_rouge.py
├── word_prob_layer.py
├── configs.py
├── README.md
├── data.py
├── model.py
├── bleu.py
├── utils_pg.py
├── prepare_data.py
├── transformer.py
└── main.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.log
3 | /data*/*
4 | /model*/*
5 | /cnndm/*
6 |
--------------------------------------------------------------------------------
/tuning.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | FILES=./cnndm/model/*
3 | for f in $FILES; do
4 | echo "==========================" ${f##*/}
5 | python -u main.py ${f##*/}
6 | python prepare_rouge.py
7 | cd ./deepmind/result/
8 | perl /home/pijili/tools/ROUGE-1.5.5/ROUGE-1.5.5.pl -n 4 -w 1.2 -m -2 4 -u -c 95 -r 1000 -f A -p 0.5 -t 0 myROUGE_Config.xml C
9 | cd ../../
10 | done
11 |
--------------------------------------------------------------------------------
/optim.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | class Optim:
4 | "Optim wrapper that implements rate."
5 | def __init__(self, model_size, factor, warmup, optimizer):
6 | self.optimizer = optimizer
7 | self._step = 0
8 | self.warmup = warmup
9 | self.factor = factor
10 | self.model_size = model_size
11 | self._rate = 0
12 |
13 | def step(self):
14 | "Update parameters and rate"
15 | self._step += 1
16 | rate = self.rate()
17 | for p in self.optimizer.param_groups:
18 | p['lr'] = rate
19 | self._rate = rate
20 | self.optimizer.step()
21 |
22 | def rate(self, step = None):
23 | "Implement `lrate` above"
24 | if step is None:
25 | step = self._step
26 | return self.factor * (self.model_size ** (-0.5) * min(step ** (-0.5), step * self.warmup ** (-1.5)))
27 |
28 | def state_dict(self):
29 | return self.optimizer.state_dict()
30 |
31 | def load_state_dict(self, m):
32 | self.optimizer.load_state_dict(m)
33 |
--------------------------------------------------------------------------------
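
The Optim wrapper above implements the Noam schedule from "Attention Is All You Need": the rate grows linearly for `warmup` steps and then decays as the inverse square root of the step. A minimal usage sketch (illustrative, not part of the repository; the Adam hyperparameters follow the paper):

```
import torch
from optim import Optim

model = torch.nn.Linear(512, 512)  # stand-in for the real model
noam = Optim(model_size=512, factor=1.0, warmup=4000,
             optimizer=torch.optim.Adam(model.parameters(), lr=0,
                                        betas=(0.9, 0.98), eps=1e-9))

# rate rises linearly during warmup, peaks near step == warmup,
# then decays as step ** -0.5
for step in (100, 4000, 40000):
    print(step, noam.rate(step))
```
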
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Piji Li
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/label_smoothing.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class LabelSmoothing(nn.Module):
6 | "Implement label smoothing."
7 | def __init__(self, device, size, padding_idx, label_smoothing=0.0):
8 | super(LabelSmoothing, self).__init__()
9 | assert 0.0 < label_smoothing <= 1.0
10 | self.padding_idx = padding_idx
11 | self.size = size
12 | self.device = device
13 |
14 | self.smoothing_value = label_smoothing / (size - 2)
15 | self.one_hot = torch.full((1, size), self.smoothing_value).to(device)
16 | self.one_hot[0, self.padding_idx] = 0
17 |
18 | self.confidence = 1.0 - label_smoothing
19 |
20 | def forward(self, output, target):
21 | real_size = output.size(1)
22 | if real_size > self.size:
23 | real_size -= self.size
24 | else:
25 | real_size = 0
26 |
27 | model_prob = self.one_hot.repeat(target.size(0), 1)
28 | if real_size > 0:
29 | ext_zeros = torch.full((model_prob.size(0), real_size), self.smoothing_value).to(self.device)
30 | model_prob = torch.cat((model_prob, ext_zeros), -1)
31 | model_prob.scatter_(1, target, self.confidence)
32 | model_prob.masked_fill_((target == self.padding_idx), 0.)
33 |
34 | return F.kl_div(output, model_prob, reduction='sum')
35 |
--------------------------------------------------------------------------------
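
A toy invocation of the criterion above (illustrative, not part of the repository). It expects log-probabilities and builds a smoothed target distribution, so the returned value is a summed KL divergence; rows whose target is the padding index contribute zero:

```
import torch
import torch.nn.functional as F
from label_smoothing import LabelSmoothing

crit = LabelSmoothing(device="cpu", size=5, padding_idx=0, label_smoothing=0.1)
log_probs = F.log_softmax(torch.randn(4, 5), dim=-1)  # (positions, vocab) log-probs
target = torch.tensor([[2], [3], [1], [0]])           # last position is padding
print(crit(log_probs, target).item())
```
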
/prepare_rouge.py:
--------------------------------------------------------------------------------
1 | #pylint: skip-file
2 | import sys
3 | import os
4 | from configs import *
5 |
6 | cfg = DeepmindConfigs()
7 |
8 | # config file for ROUGE
9 | ROUGE_PATH = cfg.cc.RESULT_PATH
10 | SUMM_PATH = cfg.cc.SUMM_PATH
11 | MODEL_PATH = cfg.cc.GROUND_TRUTH_PATH
12 | i2summ = {}
13 | summ2i = {}
14 | i2model = {}
15 |
16 | # for result
17 | flist = os.listdir(SUMM_PATH)
18 | i = 0
19 | for fname in flist:
20 | i2summ[str(i)] = fname
21 | summ2i[fname] = str(i)
22 | i += 1
23 |
24 | # for models
25 | flist = os.listdir(MODEL_PATH)
26 | i2model = {}
27 | for fname in flist:
28 | if fname not in summ2i:
29 |         raise IOError(fname + " has no corresponding system summary")
30 |
31 | i = summ2i[fname]
32 | i2model[i] = fname
33 |
34 | assert len(i2model) == len(i2summ)
35 |
36 | # write to config file
37 | rouge_s = ""
38 | file_id = 0
39 | for file_id, fsumm in i2summ.items():
40 | rouge_s += "\n" \
41 | + "\n" \
42 | + SUMM_PATH \
43 | + "\n" \
44 | + "\n" \
45 | + "\n" + MODEL_PATH \
46 | + "\n" \
47 | + "\n" \
48 | + "\n" \
49 | + "\n" \
50 | + "\n" + fsumm + "
" \
51 | + "\n" \
52 | + "\n"
53 | rouge_s += "\n" + i2model[file_id] + ""
54 | rouge_s += "\n\n"
55 |
56 | rouge_s += "\n"
57 |
58 | with open(ROUGE_PATH + "myROUGE_Config.xml", "w") as f_rouge:
59 | f_rouge.write(rouge_s)
60 |
--------------------------------------------------------------------------------
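
Assuming the tag reconstruction above (the XML markup was stripped during extraction), one entry of the generated myROUGE_Config.xml looks roughly as follows. The literal below is only illustrative (SAMPLE_EVAL_ENTRY is not a name from the repo; paths and file names depend on RESULT_PATH and the summary directory):

```
SAMPLE_EVAL_ENTRY = """
<ROUGE-EVAL version="1.0">
<EVAL ID="0">
<PEER-ROOT>./cnndm/result/summary/
</PEER-ROOT>
<MODEL-ROOT>
./cnndm/result/ground_truth/
</MODEL-ROOT>
<INPUT-FORMAT TYPE="SPL">
</INPUT-FORMAT>
<PEERS>
<P ID="C">0.summary</P>
</PEERS>
<MODELS>
<M ID="0">0.summary</M>
</MODELS>
</EVAL>
</ROUGE-EVAL>
"""
```
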
/word_prob_layer.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #pylint: skip-file
3 | import torch
4 | import torch as T
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from torch.autograd import Variable
8 |
9 | from utils_pg import *
10 | from transformer import MultiheadAttention
11 |
12 | class WordProbLayer(nn.Module):
13 | def __init__(self, hidden_size, dict_size, device, copy, coverage, dropout):
14 | super(WordProbLayer, self).__init__()
15 | self.hidden_size = hidden_size
16 | self.dict_size = dict_size
17 | self.device = device
18 | self.copy = copy
19 | self.coverage = coverage
20 | self.dropout = dropout
21 |
22 | if self.copy:
23 | self.external_attn = MultiheadAttention(self.hidden_size, 1, self.dropout, weights_dropout=False)
24 | self.proj = nn.Linear(self.hidden_size * 3, self.dict_size)
25 | self.v = nn.Parameter(torch.Tensor(1, self.hidden_size * 3))
26 | self.bv = nn.Parameter(torch.Tensor(1))
27 | else:
28 | self.proj = nn.Linear(self.hidden_size, self.dict_size)
29 |
30 | self.init_weights()
31 |
32 | def init_weights(self):
33 | init_linear_weight(self.proj)
34 | if self.copy:
35 | init_xavier_weight(self.v)
36 | init_bias(self.bv)
37 |
38 | def forward(self, h, y_emb=None, memory=None, mask_x=None, xids=None, max_ext_len=None):
39 |
40 | if self.copy:
41 | atts, dists = self.external_attn(query=h, key=memory, value=memory, key_padding_mask=mask_x, need_weights = True)
42 | pred = T.softmax(self.proj(T.cat([h, y_emb, atts], -1)), dim=-1)
43 | if max_ext_len > 0:
44 | ext_zeros = Variable(torch.zeros(pred.size(0), pred.size(1), max_ext_len)).to(self.device)
45 | pred = T.cat((pred, ext_zeros), -1)
46 | g = T.sigmoid(F.linear(T.cat([h, y_emb, atts], -1), self.v, self.bv))
47 | xids = xids.transpose(0, 1).unsqueeze(0).repeat(pred.size(0), 1, 1)
48 | pred = (g * pred).scatter_add(2, xids, (1 - g) * dists)
49 | else:
50 | pred = T.softmax(self.proj(h), dim=-1)
51 | dists = None
52 | return pred, dists
53 |
--------------------------------------------------------------------------------
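
The copy branch above follows the pointer-generator recipe: a gate g scales the vocabulary distribution (zero-padded with max_ext_len slots for in-article OOVs), and the attention weights are scatter-added onto the source tokens' extended ids. A tiny numeric sketch of that final mixing line (illustrative, not part of the repository):

```
import torch

vocab_p = torch.tensor([[[0.5, 0.3, 0.2, 0.0]]])  # (tgt_len=1, bsz=1, |V|+ext=4)
attn    = torch.tensor([[[0.6, 0.4]]])            # attention over 2 source tokens
xids    = torch.tensor([[[1, 3]]])                # source positions -> extended ids
g       = torch.tensor([[[0.7]]])                 # generation probability

pred = (g * vocab_p).scatter_add(2, xids, (1 - g) * attn)
print(pred)        # id 1 gains 0.3 * 0.6; id 3 (an OOV slot) gains 0.3 * 0.4
print(pred.sum())  # still a probability distribution: sums to 1.0
```
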
/configs.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #pylint: skip-file
3 | import os
4 |
5 | class CommonConfigs(object):
6 | def __init__(self, d_type):
7 | self.ROOT_PATH = os.getcwd() + "/"
8 | self.TRAINING_DATA_PATH = self.ROOT_PATH + d_type + "/train_set/"
9 | self.VALIDATE_DATA_PATH = self.ROOT_PATH + d_type + "/validate_set/"
10 | self.TESTING_DATA_PATH = self.ROOT_PATH + d_type + "/test_set/"
11 | self.RESULT_PATH = self.ROOT_PATH + d_type + "/result/"
12 | self.MODEL_PATH = self.ROOT_PATH + d_type + "/model/"
13 | self.BEAM_SUMM_PATH = self.RESULT_PATH + "/beam_summary/"
14 | self.BEAM_GT_PATH = self.RESULT_PATH + "/beam_ground_truth/"
15 | self.GROUND_TRUTH_PATH = self.RESULT_PATH + "/ground_truth/"
16 | self.SUMM_PATH = self.RESULT_PATH + "/summary/"
17 | self.TMP_PATH = self.ROOT_PATH + d_type + "/tmp/"
18 |
19 |
20 | class DeepmindTraining(object):
21 | IS_UNICODE = False
22 | REMOVES_PUNCTION = False
23 | HAS_Y = True
24 | BATCH_SIZE = 20
25 |
26 | class DeepmindTesting(object):
27 | IS_UNICODE = False
28 | HAS_Y = True
29 | BATCH_SIZE = 80
30 | MIN_LEN_PREDICT = 35
31 | MAX_LEN_PREDICT = 120
32 | MAX_BYTE_PREDICT = None
33 | PRINT_SIZE = 500
34 | REMOVES_PUNCTION = False
35 |
36 | class DeepmindConfigs():
37 |
38 | cc = CommonConfigs("cnndm")
39 | FIRE = False
40 |
41 | CELL = "transformer"
42 | CUDA = True
43 | COPY = True
44 | COVERAGE = True
45 |
46 | BI_RNN = False
47 | AVG_NLL = True
48 | NORM_CLIP = 2
49 | if not AVG_NLL:
50 | NORM_CLIP = 5
51 | LR = 0.15
52 | SMOOTHING = 0.1
53 |
54 | BEAM_SEARCH = True
55 | BEAM_SIZE = 4
56 | ALPHA = 0.9 # length penalty
57 | BETA = 5 # coverage during beamsearch
58 |
59 | DIM_X = 512
60 | DIM_Y = DIM_X
61 | HIDDEN_SIZE = 512
62 | FF_SIZE = 1024
63 | NUM_H = 8 # multi-head attention
64 | DROPOUT = 0.2
65 | NUM_L = 4 # num of layers
66 | MIN_LEN_X = 10
67 | MIN_LEN_Y = 10
68 | MAX_LEN_X = 400
69 | MAX_LEN_Y = 100
70 | MIN_NUM_X = 1
71 | MAX_NUM_X = 1
72 | MAX_NUM_Y = None
73 |
74 | NUM_Y = 1
75 |
76 | UNI_LOW_FREQ_THRESHOLD = 10
77 |
78 | PG_DICT_SIZE = 50000 # dict for acl17 paper: pointer-generator
79 |
80 |     W_UNK = "<unk>"
81 |     W_BOS = "<bos>"
82 |     W_EOS = "<eos>"
83 |     W_PAD = "<pad>"
84 |     W_LS = "<s>"
85 |     W_RS = "</s>"
86 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TranSummar
2 | Transformer for abstractive summarization
3 |
4 | #### cnndm (with copy and coverage), epoch 57:
5 | ```
6 | ---------------------------------------------
7 | C ROUGE-1 Average_R: 0.41097 (95%-conf.int. 0.40861 - 0.41346)
8 | C ROUGE-1 Average_P: 0.40874 (95%-conf.int. 0.40619 - 0.41141)
9 | C ROUGE-1 Average_F: 0.39656 (95%-conf.int. 0.39451 - 0.39871)
10 | ---------------------------------------------
11 | C ROUGE-2 Average_R: 0.17821 (95%-conf.int. 0.17590 - 0.18049)
12 | C ROUGE-2 Average_P: 0.17781 (95%-conf.int. 0.17540 - 0.18037)
13 | C ROUGE-2 Average_F: 0.17208 (95%-conf.int. 0.16990 - 0.17433)
14 | ---------------------------------------------
15 | C ROUGE-3 Average_R: 0.09845 (95%-conf.int. 0.09640 - 0.10064)
16 | C ROUGE-3 Average_P: 0.09844 (95%-conf.int. 0.09627 - 0.10069)
17 | C ROUGE-3 Average_F: 0.09505 (95%-conf.int. 0.09307 - 0.09713)
18 | ---------------------------------------------
19 | C ROUGE-4 Average_R: 0.06297 (95%-conf.int. 0.06109 - 0.06499)
20 | C ROUGE-4 Average_P: 0.06329 (95%-conf.int. 0.06137 - 0.06537)
21 | C ROUGE-4 Average_F: 0.06086 (95%-conf.int. 0.05908 - 0.06275)
22 | ---------------------------------------------
23 | C ROUGE-L Average_R: 0.37912 (95%-conf.int. 0.37682 - 0.38160)
24 | C ROUGE-L Average_P: 0.37726 (95%-conf.int. 0.37484 - 0.37987)
25 | C ROUGE-L Average_F: 0.36593 (95%-conf.int. 0.36388 - 0.36810)
26 | ---------------------------------------------
27 | C ROUGE-W-1.2 Average_R: 0.16602 (95%-conf.int. 0.16487 - 0.16715)
28 | C ROUGE-W-1.2 Average_P: 0.27554 (95%-conf.int. 0.27362 - 0.27758)
29 | C ROUGE-W-1.2 Average_F: 0.20031 (95%-conf.int. 0.19902 - 0.20156)
30 | ---------------------------------------------
31 | C ROUGE-SU4 Average_R: 0.18191 (95%-conf.int. 0.17981 - 0.18403)
32 | C ROUGE-SU4 Average_P: 0.18101 (95%-conf.int. 0.17890 - 0.18320)
33 | C ROUGE-SU4 Average_F: 0.17496 (95%-conf.int. 0.17308 - 0.17693)
34 | ```
35 | #### Gigaword (no copy, no coverage), epoch 18:
36 | ```
37 | ---------------------------------------------
38 | C ROUGE-1 Average_R: 0.36144 (95%-conf.int. 0.34958 - 0.37330)
39 | C ROUGE-1 Average_P: 0.37213 (95%-conf.int. 0.36018 - 0.38460)
40 | C ROUGE-1 Average_F: 0.35586 (95%-conf.int. 0.34433 - 0.36747)
41 | ---------------------------------------------
42 | C ROUGE-2 Average_R: 0.17568 (95%-conf.int. 0.16614 - 0.18606)
43 | C ROUGE-2 Average_P: 0.18536 (95%-conf.int. 0.17463 - 0.19625)
44 | C ROUGE-2 Average_F: 0.17467 (95%-conf.int. 0.16489 - 0.18467)
45 | ---------------------------------------------
46 | C ROUGE-3 Average_R: 0.09628 (95%-conf.int. 0.08782 - 0.10555)
47 | C ROUGE-3 Average_P: 0.10448 (95%-conf.int. 0.09482 - 0.11429)
48 | C ROUGE-3 Average_F: 0.09643 (95%-conf.int. 0.08763 - 0.10558)
49 | ---------------------------------------------
50 | C ROUGE-4 Average_R: 0.05583 (95%-conf.int. 0.04812 - 0.06380)
51 | C ROUGE-4 Average_P: 0.06323 (95%-conf.int. 0.05447 - 0.07197)
52 | C ROUGE-4 Average_F: 0.05653 (95%-conf.int. 0.04871 - 0.06425)
53 | ---------------------------------------------
54 | C ROUGE-L Average_R: 0.33497 (95%-conf.int. 0.32382 - 0.34678)
55 | C ROUGE-L Average_P: 0.34521 (95%-conf.int. 0.33414 - 0.35770)
56 | C ROUGE-L Average_F: 0.33005 (95%-conf.int. 0.31905 - 0.34173)
57 | ---------------------------------------------
58 | C ROUGE-W-1.2 Average_R: 0.20852 (95%-conf.int. 0.20134 - 0.21619)
59 | C ROUGE-W-1.2 Average_P: 0.32259 (95%-conf.int. 0.31225 - 0.33452)
60 | C ROUGE-W-1.2 Average_F: 0.24404 (95%-conf.int. 0.23594 - 0.25286)
61 | ---------------------------------------------
62 | C ROUGE-SU4 Average_R: 0.20404 (95%-conf.int. 0.19456 - 0.21456)
63 | C ROUGE-SU4 Average_P: 0.21410 (95%-conf.int. 0.20406 - 0.22493)
64 | C ROUGE-SU4 Average_F: 0.19664 (95%-conf.int. 0.18736 - 0.20654)
65 | ```
66 |
67 | ### How to run:
68 | - Python 3.7, PyTorch 0.4+
69 | - Download the processed dataset from https://drive.google.com/file/d/1EUuEMBSlrlnf_J2jcAVl1v4owSvw_8ZF/view?usp=sharing , or download the original finished files from https://github.com/JafferWilson/Process-Data-of-CNN-DailyMail and process them yourself.
70 | - Modify the data path in prepare_data.py, then run it: python prepare_data.py
71 | - Training: python -u main.py | tee train.log
72 | - Tuning: in main.py, set is_predicting = True and model_selection = True, then run "bash tuning.sh | tee tune.log"
73 | - Testing: in main.py, set is_predicting = True and model_selection = False, then run "python main.py your-best-model" (say cnndm.s2s.gpu4.epoch7.1); go to "./deepmind/result/" and run $ROUGE$ myROUGE_Config.xml C to get the results.
74 | - The Perl ROUGE package is enough; I did not use pyrouge.
75 |
76 | ### Reference:
77 | - fairseq: https://github.com/pytorch/fairseq
78 | - The Annotated Transformer: http://nlp.seas.harvard.edu/2018/04/03/attention.html
79 | - BERT: https://github.com/jcyk/BERT
80 | - Rush-Gigaword: https://drive.google.com/open?id=0B6N7tANPyVeBNmlSX19Ld2xDU1E
81 | - Rush-CNN/Dailymail: https://s3.amazonaws.com/opennmt-models/Summary/cnndm.tar.gz
82 |
83 |
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #pylint: skip-file
3 | import sys
4 | import os
5 | import os.path
6 | import time
7 | from operator import itemgetter
8 | import numpy as np
9 | import pickle
10 | from random import shuffle
11 |
12 | class BatchData:
13 | def __init__(self, flist, modules, consts, options):
14 | self.batch_size = len(flist)
15 | self.x = np.zeros((consts["len_x"], self.batch_size), dtype = np.int64)
16 | self.x_ext = np.zeros((consts["len_x"], self.batch_size), dtype = np.int64)
17 | self.y = np.zeros((consts["len_y"], self.batch_size), dtype = np.int64)
18 | self.y_inp = np.zeros((consts["len_y"], self.batch_size), dtype = np.int64)
19 | self.y_ext = np.zeros((consts["len_y"], self.batch_size), dtype = np.int64)
20 | self.x_mask = np.zeros((consts["len_x"], self.batch_size, 1), dtype = np.int64)
21 | self.y_mask = np.zeros((consts["len_y"], self.batch_size, 1), dtype = np.int64)
22 | self.len_x = []
23 | self.len_y = []
24 | self.original_contents = []
25 | self.original_summarys = []
26 | self.x_ext_words = []
27 | self.max_ext_len = 0
28 |
29 | w2i = modules["w2i"]
30 | i2w = modules["i2w"]
31 | dict_size = len(w2i)
32 |
33 | for idx_doc in range(len(flist)):
34 | if len(flist[idx_doc]) == 2:
35 | contents, summarys = flist[idx_doc]
36 | else:
37 | print ("ERROR!")
38 | return
39 |
40 | content, original_content = contents
41 | summary, original_summary = summarys
42 | self.original_contents.append(original_content)
43 | self.original_summarys.append(original_summary)
44 |
45 | xi_oovs = []
46 | for idx_word in range(len(content)):
47 |                 # some sentences in DUC are longer than len_x
48 | if idx_word == consts["len_x"]:
49 | break
50 | w = content[idx_word]
51 |
52 | if w not in w2i: # OOV
53 | if w not in xi_oovs:
54 | xi_oovs.append(w)
55 |                     self.x_ext[idx_word, idx_doc] = dict_size + xi_oovs.index(w) # temporary in-batch id: dict_size + OOV index
56 | w = i2w[modules["lfw_emb"]]
57 | else:
58 | self.x_ext[idx_word, idx_doc] = w2i[w]
59 |
60 | self.x[idx_word, idx_doc] = w2i[w]
61 | self.x_mask[idx_word, idx_doc, 0] = 1
62 | self.len_x.append(np.sum(self.x_mask[:, idx_doc, :]))
63 | self.x_ext_words.append(xi_oovs)
64 | if self.max_ext_len < len(xi_oovs):
65 | self.max_ext_len = len(xi_oovs)
66 |
67 | if options["has_y"]:
68 | for idx_word in range(len(summary)):
69 | w = summary[idx_word]
70 |
71 | if w not in w2i:
72 | if w in xi_oovs:
73 | self.y_ext[idx_word, idx_doc] = dict_size + xi_oovs.index(w)
74 | else:
75 | self.y_ext[idx_word, idx_doc] = w2i[i2w[modules["lfw_emb"]]] # unk
76 | w = i2w[modules["lfw_emb"]]
77 | else:
78 | self.y_ext[idx_word, idx_doc] = w2i[w]
79 | self.y[idx_word, idx_doc] = w2i[w]
80 |
81 | if idx_word == 0:
82 | self.y_inp[idx_word, idx_doc] = modules["bos_idx"]
83 | if idx_word < (len(summary) - 1):
84 | self.y_inp[idx_word + 1, idx_doc] = w2i[w]
85 |
86 | if not options["is_predicting"]:
87 | self.y_mask[idx_word, idx_doc, 0] = 1
88 | self.len_y.append(len(summary))
89 | else:
90 | self.y = self.y_mask = None
91 |
92 | max_len_x = int(np.max(self.len_x))
93 | max_len_y = int(np.max(self.len_y))
94 |
95 | self.x = self.x[0:max_len_x, :]
96 | self.x_ext = self.x_ext[0:max_len_x, :]
97 | self.x_mask = self.x_mask[0:max_len_x, :, :]
98 | self.y = self.y[0:max_len_y, :]
99 | self.y_inp = self.y_inp[0:max_len_y, :]
100 | self.y_ext = self.y_ext[0:max_len_y, :]
101 | self.y_mask = self.y_mask[0:max_len_y, :, :]
102 |
103 | def get_data(xy_list, modules, consts, options):
104 | return BatchData(xy_list, modules, consts, options)
105 |
106 | def batched(x_size, options, consts):
107 | batch_size = consts["testing_batch_size"] if options["is_predicting"] else consts["batch_size"]
108 | if options["is_debugging"]:
109 | x_size = 13
110 | ids = [i for i in range(x_size)]
111 | if not options["is_predicting"]:
112 | shuffle(ids)
113 | batch_list = []
114 | batch_ids = []
115 | for i in range(x_size):
116 | idx = ids[i]
117 | batch_ids.append(idx)
118 | if len(batch_ids) == batch_size or i == (x_size - 1):
119 | batch_list.append(batch_ids)
120 | batch_ids = []
121 | return batch_list, len(ids), len(batch_list)
122 |
123 |
--------------------------------------------------------------------------------
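
A pure-Python sketch of the x / x_ext bookkeeping above (illustrative, not part of the repository; the toy vocabulary is made up). In-vocabulary words keep their dictionary id in both tensors; an OOV word becomes <unk> in x but receives a temporary id dict_size + k in x_ext, where k is its index in the per-batch OOV list, so the copy mechanism can still point at it:

```
w2i = {"<pad>": 0, "<unk>": 1, "the": 2, "bank": 3}
dict_size = len(w2i)
content = ["the", "zloty", "bank"]

xi_oovs, x, x_ext = [], [], []
for w in content:
    if w not in w2i:                       # OOV: remember it, index past the dict
        if w not in xi_oovs:
            xi_oovs.append(w)
        x_ext.append(dict_size + xi_oovs.index(w))
        w = "<unk>"
    else:
        x_ext.append(w2i[w])
    x.append(w2i[w])

print(x)        # [2, 1, 3] -- "zloty" collapsed to <unk>
print(x_ext)    # [2, 4, 3] -- "zloty" kept as extended id 4
print(xi_oovs)  # ['zloty']
```
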
/model.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #pylint: skip-file
3 | import sys
4 | import numpy as np
5 | import torch
6 | import torch as T
7 | import torch.nn as nn
8 | from torch.autograd import Variable
9 | import torch.nn.functional as F
10 |
11 | from utils_pg import *
12 |
13 | from transformer import TransformerLayer, Embedding, LearnedPositionalEmbedding, gelu, LayerNorm, SelfAttentionMask
14 | from word_prob_layer import *
15 | from label_smoothing import LabelSmoothing
16 |
17 | class Model(nn.Module):
18 | def __init__(self, modules, consts, options):
19 | super(Model, self).__init__()
20 |
21 | self.has_learnable_w2v = options["has_learnable_w2v"]
22 | self.is_predicting = options["is_predicting"]
23 | self.is_bidirectional = options["is_bidirectional"]
24 | self.beam_decoding = options["beam_decoding"]
25 | self.cell = options["cell"]
26 | self.device = options["device"]
27 | self.copy = options["copy"]
28 | self.coverage = options["coverage"]
29 | self.avg_nll = options["avg_nll"]
30 |
31 | self.dim_x = consts["dim_x"]
32 | self.dim_y = consts["dim_y"]
33 | self.len_x = consts["len_x"]
34 | self.len_y = consts["len_y"]
35 | self.hidden_size = consts["hidden_size"]
36 | self.dict_size = consts["dict_size"]
37 | self.pad_token_idx = consts["pad_token_idx"]
38 | self.ctx_size = self.hidden_size * 2 if self.is_bidirectional else self.hidden_size
39 | self.num_layers = consts["num_layers"]
40 | self.d_ff = consts["d_ff"]
41 | self.num_heads = consts["num_heads"]
42 | self.dropout = consts["dropout"]
43 | self.smoothing_factor = consts["label_smoothing"]
44 |
45 | self.tok_embed = nn.Embedding(self.dict_size, self.dim_x, self.pad_token_idx)
46 | self.pos_embed = LearnedPositionalEmbedding(self.dim_x, device=self.device)
47 |
48 | self.enc_layers = nn.ModuleList()
49 | for i in range(self.num_layers):
50 | self.enc_layers.append(TransformerLayer(self.dim_x, self.d_ff, self.num_heads, self.dropout))
51 |
52 | self.dec_layers = nn.ModuleList()
53 | for i in range(self.num_layers):
54 | self.dec_layers.append(TransformerLayer(self.dim_x, self.d_ff, self.num_heads, self.dropout, with_external=True))
55 |
56 | self.attn_mask = SelfAttentionMask(device=self.device)
57 |
58 | self.emb_layer_norm = LayerNorm(self.dim_x)
59 |
60 | self.word_prob = WordProbLayer(self.hidden_size, self.dict_size, self.device, self.copy, self.coverage, self.dropout)
61 |
62 | self.smoothing = LabelSmoothing(self.device, self.dict_size, self.pad_token_idx, self.smoothing_factor)
63 |
64 | self.init_weights()
65 |
66 | def init_weights(self):
67 | init_uniform_weight(self.tok_embed.weight)
68 |
69 |
70 |     def label_smoothing_loss(self, y_pred, y, y_mask, avg=True):
71 | seq_len, bsz = y.size()
72 |
73 | y_pred = T.log(y_pred.clamp(min=1e-8))
74 | loss = self.smoothing(y_pred.view(seq_len * bsz, -1), y.view(seq_len * bsz, -1))
75 | if avg:
76 | return loss / T.sum(y_mask)
77 | else:
78 | return loss / bsz
79 |
80 | def nll_loss(self, y_pred, y, y_mask, avg=True):
81 | cost = -T.log(T.gather(y_pred, 2, y.view(y.size(0), y.size(1), 1)))
82 | cost = cost.view(y.shape)
83 | y_mask = y_mask.view(y.shape)
84 | if avg:
85 | cost = T.sum(cost * y_mask, 0) / T.sum(y_mask, 0)
86 | else:
87 | cost = T.sum(cost * y_mask, 0)
88 | cost = cost.view((y.size(1), -1))
89 | return T.mean(cost)
90 |
91 | def encode(self, inp):
92 | seq_len, bsz = inp.size()
93 | x = self.tok_embed(inp) + self.pos_embed(inp)
94 | x = self.emb_layer_norm(x)
95 | x = F.dropout(x, p=self.dropout, training=self.training)
96 | padding_mask = torch.eq(inp, self.pad_token_idx)
97 | if not padding_mask.any():
98 | padding_mask = None
99 |
100 | xs = []
101 | for layer_id, layer in enumerate(self.enc_layers):
102 | x, _ ,_ = layer(x, self_padding_mask=padding_mask)
103 | xs.append(x)
104 |
105 | return x, padding_mask
106 |
107 |
108 | def decode(self, inp, mask_x, mask_y, src, src_padding_mask, xids=None, max_ext_len=None):
109 | seq_len, bsz = inp.size()
110 | x = self.tok_embed(inp) + self.pos_embed(inp)
111 | x = self.emb_layer_norm(x)
112 | x = F.dropout(x, p=self.dropout, training=self.training)
113 | h = x
114 | if not self.is_predicting:
115 | mask_y = mask_y.view((seq_len, bsz))
116 | padding_mask = torch.eq(mask_y, self.pad_token_idx)
117 | if not padding_mask.any():
118 | padding_mask = None
119 | else:
120 | padding_mask = None
121 |
122 | self_attn_mask = self.attn_mask(seq_len)
123 |
124 | for layer_id, layer in enumerate(self.dec_layers):
125 | x, _, _ = layer(x, self_padding_mask=padding_mask,\
126 | self_attn_mask = self_attn_mask,\
127 | external_memories = src,\
128 | external_padding_mask = src_padding_mask,\
129 | need_weights = False)
130 | if self.copy:
131 | y_dec, attn_dist = self.word_prob(x, h, src, src_padding_mask, xids, max_ext_len)
132 | else:
133 | y_dec, attn_dist = self.word_prob(x)
134 |
135 | return y_dec, attn_dist
136 |
137 |
138 | def forward(self, x, y_inp, y_tgt, mask_x, mask_y, x_ext, y_ext, max_ext_len):
139 | hs, src_padding_mask = self.encode(x)
140 | if self.copy:
141 | y_pred, _ = self.decode(y_inp, mask_x, mask_y, hs, src_padding_mask, x_ext, max_ext_len)
142 |             cost = self.label_smoothing_loss(y_pred, y_ext, mask_y, self.avg_nll)
143 | else:
144 | y_pred, _ = self.decode(y_inp, mask_x, mask_y, hs, src_padding_mask)
145 | cost = self.nll_loss(y_pred, y_tgt, mask_y, self.avg_nll)
146 |
147 | return y_pred, cost
148 |
149 |
--------------------------------------------------------------------------------
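
A toy trace of Model.nll_loss above (illustrative, not part of the repository): gather picks the probability of each gold id, padding steps are zeroed by the mask, and with avg=True each sequence is normalized by its own length:

```
import torch as T

y_pred = T.tensor([[[0.7, 0.3]], [[0.4, 0.6]]])  # (seq_len=2, bsz=1, vocab=2)
y      = T.tensor([[0], [1]])                    # gold token ids
y_mask = T.tensor([[[1.]], [[0.]]])              # second step is padding

cost = -T.log(T.gather(y_pred, 2, y.view(2, 1, 1))).view(y.shape)
mask = y_mask.view(y.shape)
print(T.sum(cost * mask, 0) / T.sum(mask, 0))    # tensor([0.3567]) = -log(0.7)
```
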
/bleu.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 |
17 |
18 | """Python implementation of BLEU and smooth-BLEU.
19 |
20 | copy from: https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py
21 |
22 | This module provides a Python implementation of BLEU and smooth-BLEU.
23 | Smooth BLEU is computed following the method outlined in the paper:
24 | Chin-Yew Lin, Franz Josef Och. ORANGE: a method for evaluating automatic
25 | evaluation metrics for machine translation. COLING 2004.
26 | """
27 |
28 | import collections
29 | import math
30 | import os
31 | import argparse
32 |
33 | def load_lines(f_path):
34 | lines = []
35 | with open(f_path, "r") as f:
36 | for line in f:
37 | line = line.strip('\n').strip('\r')
38 | fs = line.split()
39 | lines.append(fs)
40 | return lines
41 |
42 | def _get_ngrams(segment, max_order):
43 | """Extracts all n-grams upto a given maximum order from an input segment.
44 |
45 | Args:
46 | segment: text segment from which n-grams will be extracted.
47 | max_order: maximum length in tokens of the n-grams returned by this
48 | methods.
49 |
50 | Returns:
51 | The Counter containing all n-grams upto max_order in segment
52 | with a count of how many times each n-gram occurred.
53 | """
54 | ngram_counts = collections.Counter()
55 | for order in range(1, max_order + 1):
56 | for i in range(0, len(segment) - order + 1):
57 | ngram = tuple(segment[i:i+order])
58 | ngram_counts[ngram] += 1
59 | return ngram_counts
60 |
61 |
62 | def compute_bleu(reference_corpus, translation_corpus, max_order=4,
63 | smooth=False):
64 | """Computes BLEU score of translated segments against one or more references.
65 |
66 | Args:
67 | reference_corpus: list of lists of references for each translation. Each
68 | reference should be tokenized into a list of tokens.
69 | translation_corpus: list of translations to score. Each translation
70 | should be tokenized into a list of tokens.
71 | max_order: Maximum n-gram order to use when computing BLEU score.
72 | smooth: Whether or not to apply Lin et al. 2004 smoothing.
73 |
74 | Returns:
75 | 3-Tuple with the BLEU score, n-gram precisions, geometric mean of n-gram
76 | precisions and brevity penalty.
77 | """
78 | matches_by_order = [0] * max_order
79 | possible_matches_by_order = [0] * max_order
80 | reference_length = 0
81 | translation_length = 0
82 | for (references, translation) in zip(reference_corpus,
83 | translation_corpus):
84 | reference_length += min(len(r) for r in references)
85 | translation_length += len(translation)
86 |
87 | merged_ref_ngram_counts = collections.Counter()
88 | for reference in references:
89 | merged_ref_ngram_counts |= _get_ngrams(reference, max_order)
90 | translation_ngram_counts = _get_ngrams(translation, max_order)
91 | overlap = translation_ngram_counts & merged_ref_ngram_counts
92 | for ngram in overlap:
93 | matches_by_order[len(ngram)-1] += overlap[ngram]
94 | for order in range(1, max_order+1):
95 | possible_matches = len(translation) - order + 1
96 | if possible_matches > 0:
97 | possible_matches_by_order[order-1] += possible_matches
98 |
99 | precisions = [0] * max_order
100 | for i in range(0, max_order):
101 | if smooth:
102 | precisions[i] = ((matches_by_order[i] + 1.) /
103 | (possible_matches_by_order[i] + 1.))
104 | else:
105 | if possible_matches_by_order[i] > 0:
106 | precisions[i] = (float(matches_by_order[i]) /
107 | possible_matches_by_order[i])
108 | else:
109 | precisions[i] = 0.0
110 |
111 | if min(precisions) > 0:
112 | p_log_sum = sum((1. / max_order) * math.log(p) for p in precisions)
113 | geo_mean = math.exp(p_log_sum)
114 | else:
115 | geo_mean = 0
116 |
117 | ratio = float(translation_length) / reference_length
118 |
119 | if ratio > 1.0:
120 | bp = 1.
121 | else:
122 | bp = math.exp(1 - 1. / ratio)
123 |
124 | bleu = geo_mean * bp
125 |
126 | return (bleu, precisions, bp, ratio, translation_length, reference_length)
127 |
128 |
129 |
130 | def bleu(ref_path, pred_path, smooth=True, n = 1):
131 | id2f_ref = {}
132 | id2f_pred = {}
133 |
134 | flist = os.listdir(ref_path)
135 | for fname in flist:
136 | id_ = fname
137 | id2f_ref[id_] = ref_path + fname
138 |
139 | flist = os.listdir(pred_path)
140 | for fname in flist:
141 | id_ = fname
142 | id2f_pred[id_] = pred_path + fname
143 |
144 | assert len(id2f_ref) == len(id2f_pred)
145 |
146 | ref_lists = []
147 | pred_lists = []
148 | for fid, fpath in id2f_ref.items():
149 | ref_list = load_lines(fpath)
150 | assert len(ref_list) == n
151 | ref_lists.append(ref_list)
152 |
153 | pred_list = load_lines(id2f_pred[fid])
154 | assert len(pred_list) == n
155 | pred_lists.append(pred_list[0])
156 |
157 |
158 | return compute_bleu(ref_lists, pred_lists, smooth=smooth)
159 |
160 | bleu("./weibo/result/ground_truth/", "./weibo/result/summary/", smooth=True)
161 |
162 | if __name__ == "__main__":
163 | parser = argparse.ArgumentParser()
164 | parser.add_argument("-r", "--ref", help="reference path")
165 | parser.add_argument("-p", "--pred", help="prediction path")
166 | args = parser.parse_args()
167 |
168 |     bleu_score, precisions, bp, ratio, translation_length, reference_length = bleu(args.ref, args.pred)
169 |     print("BLEU =", bleu_score)
170 |     print("BLEU1 =", precisions[0])
171 |     print("BLEU2 =", precisions[1])
172 |     print("BLEU3 =", precisions[2])
173 |     print("BLEU4 =", precisions[3])
174 |     print("ratio =", ratio)
175 |
--------------------------------------------------------------------------------
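
A minimal call to compute_bleu above (illustrative, not part of the repository). Note the nesting: reference_corpus is a list of reference lists, one list per translation:

```
from bleu import compute_bleu

refs = [[["the", "cat", "sat", "on", "the", "mat"]]]  # one translation, one reference
hyps = [["the", "cat", "sat", "on", "the", "mat"]]
score, precisions, bp, ratio, _, _ = compute_bleu(refs, hyps, smooth=True)
print(score)  # 1.0 for an exact match
```
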
/utils_pg.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #pylint: skip-file
3 | import numpy as np
4 | from numpy.random import random as rand
5 | import pickle
6 | import sys
7 | import os
8 | import shutil
9 | from copy import deepcopy
10 | import random
11 |
12 | import torch
13 | from torch import nn
14 |
15 |
16 | def init_seeds():
17 | random.seed(123)
18 | torch.manual_seed(123)
19 | if torch.cuda.is_available():
20 | torch.cuda.manual_seed_all(123)
21 |
22 | def init_lstm_weight(lstm):
23 | for param in lstm.parameters():
24 | if len(param.shape) >= 2: # weights
25 | init_ortho_weight(param.data)
26 | else: # bias
27 | init_bias(param.data)
28 |
29 | def init_gru_weight(gru):
30 | for param in gru.parameters():
31 | if len(param.shape) >= 2: # weights
32 | init_ortho_weight(param.data)
33 | else: # bias
34 | init_bias(param.data)
35 |
36 | def init_linear_weight(linear):
37 | init_xavier_weight(linear.weight)
38 | if linear.bias is not None:
39 | init_bias(linear.bias)
40 |
41 | def init_normal_weight(w):
42 | nn.init.normal_(w, mean=0, std=0.01)
43 |
44 | def init_uniform_weight(w):
45 | nn.init.uniform_(w, -0.1, 0.1)
46 |
47 | def init_ortho_weight(w):
48 | nn.init.orthogonal_(w)
49 |
50 | def init_xavier_weight(w):
51 | nn.init.xavier_normal_(w)
52 |
53 | def init_bias(b):
54 | nn.init.constant_(b, 0.)
55 |
56 | def rebuild_dir(path):
57 | if os.path.exists(path):
58 | try:
59 | shutil.rmtree(path)
60 | except OSError:
61 | pass
62 | os.mkdir(path)
63 |
64 | def save_model(f, model, optimizer):
65 | torch.save({"model_state_dict" : model.state_dict(),
66 | "optimizer_state_dict" : optimizer.state_dict()},
67 | f)
68 |
69 | def load_model(f, model, optimizer):
70 | checkpoint = torch.load(f)
71 | model.load_state_dict(checkpoint["model_state_dict"])
72 | optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
73 | return model, optimizer
74 |
75 | def sort_samples(x, len_x, mask_x, y, len_y, \
76 | mask_y, oys, x_ext, y_ext, oovs):
77 | sorted_x_idx = np.argsort(len_x)[::-1]
78 |
79 | sorted_x_len = np.array(len_x)[sorted_x_idx]
80 | sorted_x = x[:, sorted_x_idx]
81 | sorted_x_mask = mask_x[:, sorted_x_idx, :]
82 | sorted_oovs = [oovs[i] for i in sorted_x_idx]
83 |
84 | sorted_y_len = np.array(len_y)[sorted_x_idx]
85 | sorted_y = y[:, sorted_x_idx]
86 | sorted_y_mask = mask_y[:, sorted_x_idx, :]
87 | sorted_oys = [oys[i] for i in sorted_x_idx]
88 | sorted_x_ext = x_ext[:, sorted_x_idx]
89 | sorted_y_ext = y_ext[:, sorted_x_idx]
90 |
91 | return sorted_x, sorted_x_len, sorted_x_mask, sorted_y, \
92 | sorted_y_len, sorted_y_mask, sorted_oys, \
93 | sorted_x_ext, sorted_y_ext, sorted_oovs
94 |
95 | def print_sent_dec(y_pred, y, y_mask, oovs, modules, consts, options, batch_size):
96 | print("golden truth and prediction samples:")
97 | max_y_words = np.sum(y_mask, axis = 0)
98 | max_y_words = max_y_words.reshape((batch_size))
99 | max_num_docs = 16 if batch_size > 16 else batch_size
100 | is_unicode = options["is_unicode"]
101 | dict_size = len(modules["i2w"])
102 | for idx_doc in range(max_num_docs):
103 | print(idx_doc + 1, "----------------------------------------------------------------------------------------------------")
104 | sent_true= ""
105 | for idx_word in range(max_y_words[idx_doc]):
106 | i = y[idx_word, idx_doc] if options["has_learnable_w2v"] else np.argmax(y[idx_word, idx_doc])
107 | if i in modules["i2w"]:
108 | sent_true += modules["i2w"][i]
109 | else:
110 | sent_true += oovs[idx_doc][i - dict_size]
111 | if not is_unicode:
112 | sent_true += " "
113 |
114 | if is_unicode:
115 | print(sent_true.encode("utf-8"))
116 | else:
117 | print(sent_true)
118 |
119 | print()
120 |
121 | sent_pred = ""
122 | for idx_word in range(max_y_words[idx_doc]):
123 | i = torch.argmax(y_pred[idx_word, idx_doc, :]).item()
124 | if i in modules["i2w"]:
125 | sent_pred += modules["i2w"][i]
126 | else:
127 | sent_pred += oovs[idx_doc][i - dict_size]
128 | if not is_unicode:
129 | sent_pred += " "
130 | if is_unicode:
131 | print(sent_pred.encode("utf-8"))
132 | else:
133 | print(sent_pred)
134 | print("----------------------------------------------------------------------------------------------------")
135 | print()
136 |
137 |
138 | def write_for_rouge(fname, ref_sents, dec_words, cfg):
139 | dec_sents = []
140 | while len(dec_words) > 0:
141 | try:
142 | fst_period_idx = dec_words.index(".")
143 | except ValueError:
144 | fst_period_idx = len(dec_words)
145 | sent = dec_words[:fst_period_idx + 1]
146 | dec_words = dec_words[fst_period_idx + 1:]
147 | dec_sents.append(' '.join(sent))
148 |
149 | ref_file = "".join((cfg.cc.GROUND_TRUTH_PATH, fname))
150 | decoded_file = "".join((cfg.cc.SUMM_PATH, fname))
151 |
152 | with open(ref_file, "w") as f:
153 | for idx, sent in enumerate(ref_sents):
154 | sent = sent.strip()
155 | f.write(sent) if idx == len(ref_sents) - 1 else f.write(sent + "\n")
156 | with open(decoded_file, "w") as f:
157 | for idx, sent in enumerate(dec_sents):
158 | sent = sent.strip()
159 | f.write(sent) if idx == len(dec_sents) - 1 else f.write(sent + "\n")
160 |
161 | def write_summ(dst_path, summ_list, num_summ, options, i2w = None, oovs=None, score_list = None):
162 | assert num_summ > 0
163 | with open(dst_path, "w") as f_summ:
164 | if num_summ == 1:
165 | if score_list != None:
166 | f_summ.write(str(score_list[0]))
167 | f_summ.write("\t")
168 | if i2w != None:
169 | '''
170 | for e in summ_list:
171 | e = int(e)
172 | if e in i2w:
173 | print i2w[e],
174 | else:
175 | print oovs[e - len(i2w)],
176 | print "\n"
177 | '''
178 | s = []
179 | for e in summ_list:
180 | e = int(e)
181 | if e in i2w:
182 | s.append(i2w[e])
183 | else:
184 | s.append(oovs[e - len(i2w)])
185 | s = " ".join(s)
186 | else:
187 | s = " ".join(summ_list)
188 | f_summ.write(s)
189 | f_summ.write("\n")
190 | else:
191 | assert num_summ == len(summ_list)
192 | if score_list != None:
193 | assert num_summ == len(score_list)
194 |
195 | for i in range(num_summ):
196 | if score_list != None:
197 | f_summ.write(str(score_list[i]))
198 | f_summ.write("\t")
199 | if i2w != None:
200 | '''
201 | for e in summ_list[i]:
202 | e = int(e)
203 | if e in i2w:
204 | print i2w[e],
205 | else:
206 | print oovs[e - len(i2w)],
207 | print "\n"
208 | '''
209 | s = []
210 | for e in summ_list[i]:
211 | e = int(e)
212 | if e in i2w:
213 | s.append(i2w[e])
214 | else:
215 | s.append(oovs[e - len(i2w)])
216 | s = " ".join(s)
217 | else:
218 | s = " ".join(summ_list[i])
219 |
220 | f_summ.write(s)
221 | f_summ.write("\n")
222 |
223 |
224 |
--------------------------------------------------------------------------------
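
A checkpoint round trip with the helpers above (illustrative, not part of the repository; the path is arbitrary). save_model stores the model and optimizer state dicts in a single file, and load_model restores both in place:

```
import torch
from utils_pg import save_model, load_model

model = torch.nn.Linear(4, 4)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

save_model("/tmp/ckpt.pt", model, opt)
model, opt = load_model("/tmp/ckpt.pt", model, opt)
```
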
/prepare_data.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import operator
3 | from os import makedirs
4 | from os.path import exists
5 | import argparse
6 | from configs import *
7 | import pickle
8 | import numpy as np
9 | import re
10 | from random import shuffle
11 | import string
12 | import struct
13 |
14 | def run(d_type, d_path):
15 | prepare_deepmind(d_path)
16 |
17 | stop_words = {"-lrb-", "-rrb-", "-"}
18 | unk_words = {"unk", "<unk>"}
19 |
20 | def get_xy_tuple(cont, head, cfg):
21 | x = read_cont(cont, cfg)
22 | y = read_head(head, cfg)
23 |
24 | if x != None and y != None:
25 | return (x, y)
26 | else:
27 | return None
28 |
29 | def load_lines(d_path, f_name, configs):
30 | lines = []
31 | f_path = d_path + f_name
32 | with open(f_path, 'r') as f:
33 | for line in f:
34 | line = line.strip("\n").lower()
35 |             fs = line.split("<summ-content>")  # tag separating summary and article (reconstructed; the original tag was lost in extraction)
36 | if len(fs) == 2:
37 | xy_tuple = get_xy_tuple(fs[1], fs[0], configs)
38 | else:
39 | print("ERROR:" + line)
40 | continue
41 | if xy_tuple != None:
42 | lines.append(xy_tuple)
43 | return lines
44 |
45 | def load_dict(d_path, f_name, dic, dic_list):
46 | f_path = d_path + f_name
47 | f = open(f_path, "r")
48 | for line in f:
49 | line = line.strip('\n').strip('\r').lower()
50 | if line:
51 | tf = line.split()
52 | if len(tf) == 2:
53 | dic[tf[0]] = int(tf[1])
54 | dic_list.append(tf[0])
55 | else:
56 | print("warning in vocab:", line)
57 | return dic, dic_list
58 |
59 | def to_dict(xys, dic):
60 | # dict should not consider test set!!!!!
61 | for xy in xys:
62 | sents, summs = xy
63 | y = summs[0]
64 | for w in y:
65 | if w in dic:
66 | dic[w] += 1
67 | else:
68 | dic[w] = 1
69 |
70 | x = sents[0]
71 | for w in x:
72 | if w in dic:
73 | dic[w] += 1
74 | else:
75 | dic[w] = 1
76 | return dic
77 |
78 |
79 | def del_num(s):
80 | return re.sub(r"(\b|\s+\-?|^\-?)(\d+|\d*\.\d+)\b","#", s)
81 |
82 | def read_cont(f_cont, cfg):
83 | lines = []
84 | line = f_cont #del_num(f_cont)
85 | words = line.split()
86 | num_words = len(words)
87 | if num_words >= cfg.MIN_LEN_X and num_words < cfg.MAX_LEN_X:
88 | lines += words
89 | elif num_words >= cfg.MAX_LEN_X:
90 | lines += words[0:cfg.MAX_LEN_X]
91 | lines += [cfg.W_EOS]
92 | return (lines, f_cont) if len(lines) >= cfg.MIN_LEN_X and len(lines) <= cfg.MAX_LEN_X+1 else None
93 |
94 | def abstract2sents(abstract, cfg):
95 | cur = 0
96 | sents = []
97 | while True:
98 | try:
99 | start_p = abstract.index(cfg.W_LS, cur)
100 | end_p = abstract.index(cfg.W_RS, start_p + 1)
101 | cur = end_p + len(cfg.W_RS)
102 | sents.append(abstract[start_p+len(cfg.W_LS):end_p])
103 | except ValueError as e: # no more sentences
104 | return sents
105 |
106 | def read_head(f_head, cfg):
107 | lines = []
108 |
109 | sents = abstract2sents(f_head, cfg)
110 | line = ' '.join(sents)
111 | words = line.split()
112 | num_words = len(words)
113 | if num_words >= cfg.MIN_LEN_Y and num_words <= cfg.MAX_LEN_Y:
114 | lines += words
115 | lines += [cfg.W_EOS]
116 |     elif num_words > cfg.MAX_LEN_Y: # unclear whether the summary should simply be truncated here
117 |         lines = words[0 : cfg.MAX_LEN_Y + 1] # keep one extra word
118 |
119 | return (lines, sents) if len(lines) >= cfg.MIN_LEN_Y and len(lines) <= cfg.MAX_LEN_Y+1 else None
120 |
121 | def prepare_deepmind(d_path):
122 | configs = DeepmindConfigs()
123 | TRAINING_PATH = configs.cc.TRAINING_DATA_PATH
124 | VALIDATE_PATH = configs.cc.VALIDATE_DATA_PATH
125 | TESTING_PATH = configs.cc.TESTING_DATA_PATH
126 | RESULT_PATH = configs.cc.RESULT_PATH
127 | MODEL_PATH = configs.cc.MODEL_PATH
128 | BEAM_SUMM_PATH = configs.cc.BEAM_SUMM_PATH
129 | BEAM_GT_PATH = configs.cc.BEAM_GT_PATH
130 | GROUND_TRUTH_PATH = configs.cc.GROUND_TRUTH_PATH
131 | SUMM_PATH = configs.cc.SUMM_PATH
132 | TMP_PATH = configs.cc.TMP_PATH
133 |
134 | print ("train: " + TRAINING_PATH)
135 | print ("test: " + TESTING_PATH)
136 | print ("validate: " + VALIDATE_PATH)
137 | print ("result: " + RESULT_PATH)
138 | print ("model: " + MODEL_PATH)
139 | print ("tmp: " + TMP_PATH)
140 |
141 | if not exists(TRAINING_PATH):
142 | makedirs(TRAINING_PATH)
143 | if not exists(VALIDATE_PATH):
144 | makedirs(VALIDATE_PATH)
145 | if not exists(TESTING_PATH):
146 | makedirs(TESTING_PATH)
147 | if not exists(RESULT_PATH):
148 | makedirs(RESULT_PATH)
149 | if not exists(MODEL_PATH):
150 | makedirs(MODEL_PATH)
151 | if not exists(BEAM_SUMM_PATH):
152 | makedirs(BEAM_SUMM_PATH)
153 | if not exists(BEAM_GT_PATH):
154 | makedirs(BEAM_GT_PATH)
155 | if not exists(GROUND_TRUTH_PATH):
156 | makedirs(GROUND_TRUTH_PATH)
157 | if not exists(SUMM_PATH):
158 | makedirs(SUMM_PATH)
159 | if not exists(TMP_PATH):
160 | makedirs(TMP_PATH)
161 |
162 |
163 | print ("trainset...")
164 | train_xy_list = load_lines(d_path, "train.txt", configs)
165 |
166 | print ("dump train...")
167 | pickle.dump(train_xy_list, open(TRAINING_PATH + "train.pkl", "wb"), protocol = pickle.HIGHEST_PROTOCOL)
168 |
169 |
170 | print ("fitering and building dict...")
171 | use_abisee = True
172 | all_dic1 = {}
173 | all_dic2 = {}
174 | dic_list = []
175 | all_dic1, dic_list = load_dict(d_path, "vocab", all_dic1, dic_list)
176 | all_dic2 = to_dict(train_xy_list, all_dic2)
177 | for w, tf in all_dic2.items():
178 | if w not in all_dic1:
179 | all_dic1[w] = tf
180 |
181 |     candidate_list = dic_list[0:configs.PG_DICT_SIZE] # 50000
182 |     candidate_set = set(candidate_list)
183 |
184 | dic = {}
185 | w2i = {}
186 | i2w = {}
187 | w2w = {}
188 |
189 | for w in [configs.W_PAD, configs.W_UNK, configs.W_BOS, configs.W_EOS]:
190 | #for w in [configs.W_PAD, configs.W_UNK, configs.W_BOS, configs.W_EOS, configs.W_LS, configs.W_RS]:
191 | w2i[w] = len(dic)
192 | i2w[w2i[w]] = w
193 | dic[w] = 10000
194 | w2w[w] = w
195 |
196 | for w, tf in all_dic1.items():
197 |         if w in candidate_set:
198 | w2i[w] = len(dic)
199 | i2w[w2i[w]] = w
200 | dic[w] = tf
201 | w2w[w] = w
202 | else:
203 | w2w[w] = configs.W_UNK
204 | hfw = []
205 | sorted_x = sorted(dic.items(), key=operator.itemgetter(1), reverse=True)
206 | for w in sorted_x:
207 | hfw.append(w[0])
208 |
209 | assert len(hfw) == len(dic)
210 | assert len(w2i) == len(dic)
211 | print ("dump dict...")
212 | pickle.dump([all_dic1, dic, hfw, w2i, i2w, w2w], open(TRAINING_PATH + "dic.pkl", "wb"), protocol = pickle.HIGHEST_PROTOCOL)
213 |
214 | print ("testset...")
215 | test_xy_list = load_lines(d_path, "test.txt", configs)
216 |
217 | print ("validset...")
218 | valid_xy_list = load_lines(d_path, "val.txt", configs)
219 |
220 |
221 | print ("#train = ", len(train_xy_list))
222 | print ("#test = ", len(test_xy_list))
223 | print ("#validate = ", len(valid_xy_list))
224 |
225 | print ("#all_dic = ", len(all_dic1), ", #dic = ", len(dic), ", #hfw = ", len(hfw))
226 |
227 | print ("dump test...")
228 | pickle.dump(test_xy_list, open(TESTING_PATH + "test.pkl", "wb"), protocol = pickle.HIGHEST_PROTOCOL)
229 | shuffle(test_xy_list)
230 | pickle.dump(test_xy_list[0:2000], open(TESTING_PATH + "pj2000.pkl", "wb"), protocol = pickle.HIGHEST_PROTOCOL)
231 |
232 | print ("dump validate...")
233 | pickle.dump(valid_xy_list, open(VALIDATE_PATH + "valid.pkl", "wb"), protocol = pickle.HIGHEST_PROTOCOL)
234 | pickle.dump(valid_xy_list[0:1000], open(VALIDATE_PATH + "pj1000.pkl", "wb"), protocol = pickle.HIGHEST_PROTOCOL)
235 |
236 | print ("done.")
237 |
238 | if __name__ == "__main__":
239 | parser = argparse.ArgumentParser()
240 |     parser.add_argument("-d", "--data", default="cnndm", help="dataset name")
241 | args = parser.parse_args()
242 |
243 | data_type = "cnndm"
244 | raw_path = "./data/"
245 |
246 | print (data_type, raw_path)
247 | run(data_type, raw_path)
248 |
--------------------------------------------------------------------------------
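
A quick check of abstract2sents above (illustrative, not part of the repository; it assumes the <s> / </s> sentence markers reconstructed in configs.py):

```
from configs import DeepmindConfigs
from prepare_data import abstract2sents

cfg = DeepmindConfigs()
abstract = "<s> police arrest suspect . </s> <s> trial begins monday . </s>"
print(abstract2sents(abstract, cfg))
# [' police arrest suspect . ', ' trial begins monday . ']
```
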
/transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import Parameter
4 | import torch.nn.functional as F
5 | import math
6 |
7 | class TransformerLayer(nn.Module):
8 |
9 | def __init__(self, embed_dim, ff_embed_dim, num_heads, dropout, with_external=False, weights_dropout = True):
10 | super(TransformerLayer, self).__init__()
11 | self.self_attn = MultiheadAttention(embed_dim, num_heads, dropout, weights_dropout)
12 | self.fc1 = nn.Linear(embed_dim, ff_embed_dim)
13 | self.fc2 = nn.Linear(ff_embed_dim, embed_dim)
14 | self.attn_layer_norm = LayerNorm(embed_dim)
15 | self.ff_layer_norm = LayerNorm(embed_dim)
16 | self.with_external = with_external
17 | self.dropout = dropout
18 | if self.with_external:
19 | self.external_attn = MultiheadAttention(embed_dim, num_heads, dropout, weights_dropout)
20 | self.external_layer_norm = LayerNorm(embed_dim)
21 | self.reset_parameters()
22 |
23 | def reset_parameters(self):
24 | nn.init.normal_(self.fc1.weight, std=0.02)
25 | nn.init.normal_(self.fc2.weight, std=0.02)
26 | nn.init.constant_(self.fc1.bias, 0.)
27 | nn.init.constant_(self.fc2.bias, 0.)
28 |
29 | def forward(self, x, kv = None,
30 | self_padding_mask = None, self_attn_mask = None,
31 | external_memories = None, external_padding_mask=None,
32 | need_weights = False):
33 | # x: seq_len x bsz x embed_dim
34 | residual = x
35 | if kv is None:
36 | x, self_attn = self.self_attn(query=x, key=x, value=x, key_padding_mask=self_padding_mask, attn_mask=self_attn_mask, need_weights = need_weights)
37 | else:
38 | x, self_attn = self.self_attn(query=x, key=kv, value=kv, key_padding_mask=self_padding_mask, attn_mask=self_attn_mask, need_weights = need_weights)
39 |
40 | x = F.dropout(x, p=self.dropout, training=self.training)
41 | x = self.attn_layer_norm(residual + x)
42 |
43 | if self.with_external:
44 | residual = x
45 | x, external_attn = self.external_attn(query=x, key=external_memories, value=external_memories, key_padding_mask=external_padding_mask, need_weights = need_weights)
46 | x = F.dropout(x, p=self.dropout, training=self.training)
47 | x = self.external_layer_norm(residual + x)
48 | else:
49 | external_attn = None
50 |
51 | residual = x
52 | x = gelu(self.fc1(x))
53 | x = F.dropout(x, p=self.dropout, training=self.training)
54 | x = self.fc2(x)
55 | x = F.dropout(x, p=self.dropout, training=self.training)
56 | x = self.ff_layer_norm(residual + x)
57 |
58 | return x, self_attn, external_attn
59 |
60 | class MultiheadAttention(nn.Module):
61 |
62 | def __init__(self, embed_dim, num_heads, dropout=0., weights_dropout=True):
63 | super(MultiheadAttention, self).__init__()
64 | self.embed_dim = embed_dim
65 | self.num_heads = num_heads
66 | self.dropout = dropout
67 | self.head_dim = embed_dim // num_heads
68 | assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
69 | self.scaling = self.head_dim ** -0.5
70 |
71 | self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
72 | self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
73 |
74 | self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
75 | self.weights_dropout = weights_dropout
76 | self.reset_parameters()
77 |
78 | def reset_parameters(self):
79 | nn.init.normal_(self.in_proj_weight, std=0.02)
80 | nn.init.normal_(self.out_proj.weight, std=0.02)
81 | nn.init.constant_(self.in_proj_bias, 0.)
82 | nn.init.constant_(self.out_proj.bias, 0.)
83 |
84 | def forward(self, query, key, value, key_padding_mask=None, attn_mask=None, need_weights=False):
85 | """ Input shape: Time x Batch x Channel
86 | key_padding_mask: Time x batch
87 | attn_mask: tgt_len x src_len
88 | """
89 | qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr()
90 | kv_same = key.data_ptr() == value.data_ptr()
91 |
92 | tgt_len, bsz, embed_dim = query.size()
93 | assert key.size() == value.size()
94 |
95 | if qkv_same:
96 | # self-attention
97 | q, k, v = self.in_proj_qkv(query)
98 | elif kv_same:
99 | # encoder-decoder attention
100 | q = self.in_proj_q(query)
101 | k, v = self.in_proj_kv(key)
102 | else:
103 | q = self.in_proj_q(query)
104 | k = self.in_proj_k(key)
105 | v = self.in_proj_v(value)
106 | q = q*self.scaling
107 |
108 | q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
109 | k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
110 | v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
111 |
112 | src_len = k.size(1)
113 | # k,v: bsz*heads x src_len x dim
114 | # q: bsz*heads x tgt_len x dim
115 |
116 | attn_weights = torch.bmm(q, k.transpose(1, 2))
117 | assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
118 |
119 | if attn_mask is not None:
120 | attn_weights.masked_fill_(
121 | attn_mask.unsqueeze(0),
122 | float('-inf')
123 | )
124 |
125 | if key_padding_mask is not None:
126 | # don't attend to padding symbols
127 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
128 | attn_weights.masked_fill_(
129 | key_padding_mask.transpose(0, 1).unsqueeze(1).unsqueeze(2),
130 | float('-inf')
131 | )
132 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
133 |
134 |
135 | attn_weights = F.softmax(attn_weights, dim=-1)
136 |
137 | if self.weights_dropout:
138 | attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)
139 |
140 | attn = torch.bmm(attn_weights, v)
141 | if not self.weights_dropout:
142 | attn = F.dropout(attn, p=self.dropout, training=self.training)
143 |
144 | assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
145 |
146 | attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
147 | attn = self.out_proj(attn)
148 |
149 | if need_weights:
150 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
151 |
152 | #attn_weights, _ = attn_weights.max(dim=1)
153 | attn_weights = attn_weights[:, 0, :, :]
154 | #attn_weights = attn_weights.mean(dim=1)
155 | attn_weights = attn_weights.transpose(0, 1)
156 | else:
157 | attn_weights = None
158 |
159 | return attn, attn_weights
160 |
161 | def in_proj_qkv(self, query):
162 | return self._in_proj(query).chunk(3, dim=-1)
163 |
164 | def in_proj_kv(self, key):
165 | return self._in_proj(key, start=self.embed_dim).chunk(2, dim=-1)
166 |
167 | def in_proj_q(self, query):
168 | return self._in_proj(query, end=self.embed_dim)
169 |
170 | def in_proj_k(self, key):
171 | return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
172 |
173 | def in_proj_v(self, value):
174 | return self._in_proj(value, start=2 * self.embed_dim)
175 |
176 | def _in_proj(self, input, start=0, end=None):
177 | weight = self.in_proj_weight
178 | bias = self.in_proj_bias
179 | weight = weight[start:end, :]
180 | if bias is not None:
181 | bias = bias[start:end]
182 | return F.linear(input, weight, bias)
183 |
184 | def gelu(x):
185 | cdf = 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
186 | return cdf*x
187 |
188 | class LayerNorm(nn.Module):
189 | def __init__(self, hidden_size, eps=1e-12):
190 | super(LayerNorm, self).__init__()
191 | self.weight = nn.Parameter(torch.Tensor(hidden_size))
192 | self.bias = nn.Parameter(torch.Tensor(hidden_size))
193 | self.eps = eps
194 | self.reset_parameters()
195 | def reset_parameters(self):
196 | nn.init.constant_(self.weight, 1.)
197 | nn.init.constant_(self.bias, 0.)
198 |
199 | def forward(self, x):
200 | u = x.mean(-1, keepdim=True)
201 | s = (x - u).pow(2).mean(-1, keepdim=True)
202 | x = (x - u) / torch.sqrt(s + self.eps)
203 | return self.weight * x + self.bias
204 |
205 | def Embedding(num_embeddings, embedding_dim, padding_idx):
206 | m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
207 | nn.init.normal_(m.weight, std=0.02)
208 | nn.init.constant_(m.weight[padding_idx], 0)
209 | return m
210 |
211 | class SelfAttentionMask(nn.Module):
212 | def __init__(self, init_size = 100, device = 0):
213 | super(SelfAttentionMask, self).__init__()
214 | self.weights = SelfAttentionMask.get_mask(init_size)
215 | self.device = device
216 |
217 | @staticmethod
218 | def get_mask(size):
219 | weights = torch.triu(torch.ones((size, size), dtype = torch.bool), 1)
220 | return weights
221 |
222 | def forward(self, size):
223 | if self.weights is None or size > self.weights.size(0):
224 | self.weights = SelfAttentionMask.get_mask(size)
225 | res = self.weights[:size,:size].cuda(self.device).detach()
226 | return res
227 |
228 | class LearnedPositionalEmbedding(nn.Module):
229 | """This module produces LearnedPositionalEmbedding.
230 | """
231 | def __init__(self, embedding_dim, init_size=1024, device=0):
232 | super(LearnedPositionalEmbedding, self).__init__()
233 | self.weights = nn.Embedding(init_size, embedding_dim)
234 | self.device= device
235 | self.reset_parameters()
236 |
237 | def reset_parameters(self):
238 | nn.init.normal_(self.weights.weight, std=0.02)
239 |
240 | def forward(self, input, offset=0):
241 | """Input is expected to be of size [seq_len x bsz]."""
242 | seq_len, bsz = input.size()
243 | positions = (offset + torch.arange(seq_len)).cuda(self.device)
244 | res = self.weights(positions).unsqueeze(1).expand(-1, bsz, -1)
245 | return res
246 |
247 | class SinusoidalPositionalEmbedding(nn.Module):
248 | """This module produces sinusoidal positional embeddings of any length.
249 | """
250 | def __init__(self, embedding_dim, init_size=1024, device=0):
251 | super(SinusoidalPositionalEmbedding, self).__init__()
252 | self.embedding_dim = embedding_dim
253 | self.weights = SinusoidalPositionalEmbedding.get_embedding(
254 | init_size,
255 | embedding_dim
256 | )
257 | self.device= device
258 |
259 | @staticmethod
260 | def get_embedding(num_embeddings, embedding_dim):
261 | """Build sinusoidal embeddings.
262 | This matches the implementation in tensor2tensor, but differs slightly
263 | from the description in Section 3.5 of "Attention Is All You Need".
264 | """
265 | half_dim = embedding_dim // 2
266 | emb = math.log(10000) / (half_dim - 1)
267 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
268 | emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
269 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
270 | if embedding_dim % 2 == 1:
271 | emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
272 | return emb
273 |
274 | def forward(self, input, offset=0):
275 | """Input is expected to be of size [seq_len x bsz]."""
276 | seq_len, bsz = input.size()
277 | mx_position = seq_len + offset
278 | if self.weights is None or mx_position > self.weights.size(0):
279 | # recompute/expand embeddings if needed
280 | self.weights = SinusoidalPositionalEmbedding.get_embedding(
281 | mx_position,
282 | self.embedding_dim,
283 | )
284 |
285 | positions = offset + torch.arange(seq_len)
286 | res = self.weights.index_select(0, positions).unsqueeze(1).expand(-1, bsz, -1).cuda(self.device).detach()
287 | return res
288 |
--------------------------------------------------------------------------------
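
A quick CPU-only sketch (illustrative only, not part of the repo) of what the two static helpers above compute -- the causal mask from SelfAttentionMask.get_mask and the sinusoidal table from SinusoidalPositionalEmbedding.get_embedding:

    import math
    import torch

    # causal mask for a length-4 sequence: True marks future positions
    mask = torch.triu(torch.ones((4, 4), dtype=torch.bool), 1)
    # tensor([[False,  True,  True,  True],
    #         [False, False,  True,  True],
    #         [False, False, False,  True],
    #         [False, False, False, False]])

    # sinusoidal table for 6 positions with embedding_dim = 4:
    # each row is a position; first half sin, second half cos
    half_dim = 4 // 2
    freq = torch.exp(torch.arange(half_dim, dtype=torch.float)
                     * -(math.log(10000) / (half_dim - 1)))
    pos = torch.arange(6, dtype=torch.float).unsqueeze(1) * freq.unsqueeze(0)
    emb = torch.cat([torch.sin(pos), torch.cos(pos)], dim=1)  # shape (6, 4)

In the model itself, both tensors are sliced to the current sequence length and moved to the GPU on every forward call.
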
/main.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | cudaid = 6
4 | os.environ["CUDA_VISIBLE_DEVICES"] = str(cudaid)
5 |
6 | import sys
7 | import time
8 | import numpy as np
9 | import pickle
10 | import copy, shutil  # shutil is used by the "fire" option below; imported explicitly here
11 | import random
12 | from random import shuffle
13 | import math
14 |
15 | import torch
16 | import torch.nn as nn
17 | from torch.autograd import Variable
18 |
19 | import data as datar
20 | from model import *
21 | from utils_pg import *
22 | from configs import *
23 | from optim import Optim
24 |
25 | cfg = DeepmindConfigs()
26 | TRAINING_DATASET_CLS = DeepmindTraining
27 | TESTING_DATASET_CLS = DeepmindTesting
28 |
29 | def print_basic_info(modules, consts, options):
30 | if options["is_debugging"]:
31 | print("\nWARNING: IN DEBUGGING MODE\n")
32 | if options["copy"]:
33 | print("USE COPY MECHANISM")
34 | if options["coverage"]:
35 | print("USE COVERAGE MECHANISM")
36 | if options["avg_nll"]:
37 | print("USE AVG NLL as LOSS")
38 | else:
39 | print("USE NLL as LOSS")
40 | if options["has_learnable_w2v"]:
41 | print("USE LEARNABLE W2V EMBEDDING")
42 | if options["is_bidirectional"]:
43 | print("USE BI-DIRECTIONAL RNN")
44 | if options["omit_eos"]:
45 |         print("<eos> IS OMITTED IN TESTING DATA")
46 | if options["prediction_bytes_limitation"]:
47 | print("MAXIMUM BYTES IN PREDICTION IS LIMITED")
48 | print("RNN TYPE: " + options["cell"])
49 | for k in consts:
50 | print(k + ":", consts[k])
51 |
52 | def init_modules():
53 |
54 | init_seeds()
55 |
56 | options = {}
57 |
58 | options["is_debugging"] = False
59 | options["is_predicting"] = False
60 |     options["model_selection"] = False # When options["is_predicting"] is True: True = tune on the validation set, False = evaluate on the real test set.
61 |
62 | options["cuda"] = cfg.CUDA and torch.cuda.is_available()
63 | options["device"] = torch.device("cuda" if options["cuda"] else "cpu")
64 |
65 | #in config.py
66 | options["cell"] = cfg.CELL
67 | options["copy"] = cfg.COPY
68 | options["coverage"] = cfg.COVERAGE
69 | options["is_bidirectional"] = cfg.BI_RNN
70 | options["avg_nll"] = cfg.AVG_NLL
71 |
72 | options["beam_decoding"] = cfg.BEAM_SEARCH # False for greedy decoding
73 |
74 | assert TRAINING_DATASET_CLS.IS_UNICODE == TESTING_DATASET_CLS.IS_UNICODE
75 |     options["is_unicode"] = TRAINING_DATASET_CLS.IS_UNICODE # True for Chinese datasets
76 | options["has_y"] = TRAINING_DATASET_CLS.HAS_Y
77 |
78 | options["has_learnable_w2v"] = True
79 |     options["omit_eos"] = False # omit <eos> and keep decoding until the summary reaches MAX_LEN_PREDICT (for DUC testing data)
80 |     options["prediction_bytes_limitation"] = TESTING_DATASET_CLS.MAX_BYTE_PREDICT is not None
81 | options["fire"] = cfg.FIRE
82 |
83 | assert options["is_unicode"] == False
84 |
85 | consts = {}
86 |
87 | consts["idx_gpu"] = cudaid
88 |
89 | consts["norm_clip"] = cfg.NORM_CLIP
90 | consts["dim_x"] = cfg.DIM_X
91 | consts["dim_y"] = cfg.DIM_Y
92 | consts["len_x"] = cfg.MAX_LEN_X + 1 # plus 1 for eos
93 | consts["len_y"] = cfg.MAX_LEN_Y + 1
94 | consts["num_x"] = cfg.MAX_NUM_X
95 | consts["num_y"] = cfg.NUM_Y
96 | consts["hidden_size"] = cfg.HIDDEN_SIZE
97 | consts["d_ff"] = cfg.FF_SIZE
98 | consts["num_heads"] = cfg.NUM_H
99 | consts["dropout"] = cfg.DROPOUT
100 | consts["num_layers"] = cfg.NUM_L
101 | consts["label_smoothing"] = cfg.SMOOTHING
102 | consts["alpha"] = cfg.ALPHA
103 | consts["beta"] = cfg.BETA
104 |
105 | consts["batch_size"] = 5 if options["is_debugging"] else TRAINING_DATASET_CLS.BATCH_SIZE
106 | if options["is_debugging"]:
107 | consts["testing_batch_size"] = 1 if options["beam_decoding"] else 2
108 | else:
109 | #consts["testing_batch_size"] = 1 if options["beam_decoding"] else TESTING_DATASET_CLS.BATCH_SIZE
110 | consts["testing_batch_size"] = TESTING_DATASET_CLS.BATCH_SIZE
111 |
112 | consts["min_len_predict"] = TESTING_DATASET_CLS.MIN_LEN_PREDICT
113 | consts["max_len_predict"] = TESTING_DATASET_CLS.MAX_LEN_PREDICT
114 | consts["max_byte_predict"] = TESTING_DATASET_CLS.MAX_BYTE_PREDICT
115 | consts["testing_print_size"] = TESTING_DATASET_CLS.PRINT_SIZE
116 |
117 | consts["lr"] = cfg.LR
118 | consts["beam_size"] = cfg.BEAM_SIZE
119 |
120 | consts["max_epoch"] = 50 if options["is_debugging"] else 64
121 | consts["print_time"] = 2
122 | consts["save_epoch"] = 1
123 |
124 | assert consts["dim_x"] == consts["dim_y"]
125 | assert consts["beam_size"] >= 1
126 |
127 | modules = {}
128 |
129 | [_, dic, hfw, w2i, i2w, w2w] = pickle.load(open(cfg.cc.TRAINING_DATA_PATH + "dic.pkl", "rb"))
130 | consts["dict_size"] = len(dic)
131 | modules["dic"] = dic
132 | modules["w2i"] = w2i
133 | modules["i2w"] = i2w
134 | modules["lfw_emb"] = modules["w2i"][cfg.W_UNK]
135 | modules["eos_emb"] = modules["w2i"][cfg.W_EOS]
136 | modules["bos_idx"] = modules["w2i"][cfg.W_BOS]
137 | consts["pad_token_idx"] = modules["w2i"][cfg.W_PAD]
138 |
139 | return modules, consts, options
140 |
141 | def beam_decode(fname, batch, model, modules, consts, options):
142 | fname = str(fname)
143 |
144 | beam_size = consts["beam_size"]
145 | num_live = 1
146 | num_dead = 0
147 | samples = []
148 | sample_scores = np.zeros(beam_size)
149 |
150 | last_traces = [[]]
151 | last_scores = torch.FloatTensor(np.zeros(1)).to(options["device"])
152 | last_c_scores = torch.FloatTensor(np.zeros(1)).to(options["device"])
153 | last_states = [[]]
154 |
155 | if options["copy"]:
156 | x, x_mask, word_emb, padding_mask, y, len_y, ref_sents, max_ext_len, oovs = batch
157 | else:
158 | x, word_emb, padding_mask, y, len_y, ref_sents = batch
159 |
160 | ys = torch.LongTensor(np.ones((1, num_live), dtype="int64") * modules["bos_idx"]).to(options["device"])
161 | x = x.unsqueeze(1)
162 | word_emb = word_emb.unsqueeze(1)
163 | padding_mask = padding_mask.unsqueeze(1)
164 | if options["copy"]:
165 | x_mask = x_mask.unsqueeze(1)
166 |
167 | for step in range(consts["max_len_predict"]):
168 | tile_word_emb = word_emb.repeat(1, num_live, 1)
169 | tile_padding_mask = padding_mask.repeat(1, num_live)
170 | if options["copy"]:
171 | tile_x = x.repeat(1, num_live)
172 | tile_x_mask = x_mask.repeat(1, num_live, 1)
173 |
174 | if options["copy"]:
175 | y_pred, attn_dist = model.decode(ys, tile_x_mask, None, tile_word_emb, tile_padding_mask, tile_x, max_ext_len)
176 | else:
177 | y_pred, attn_dist = model.decode(ys, None, None, tile_word_emb, tile_padding_mask)
178 |
179 | dict_size = y_pred.shape[-1]
180 | y_pred = y_pred[-1, :, :]
181 | if options["coverage"]:
182 | attn_dist = attn_dist[-1, :, :]
183 |
184 | cand_y_scores = last_scores + torch.log(y_pred) # larger is better
185 | if options["coverage"]:
186 | cand_scores = (cand_y_scores + last_c_scores).flatten()
187 | else:
188 | cand_scores = cand_y_scores.flatten()
189 | idx_top_joint_scores = torch.topk(cand_scores, beam_size - num_dead)[1]
190 |
191 |             idx_last_traces = idx_top_joint_scores // dict_size  # which live hypothesis each candidate extends
192 |             idx_word_now = idx_top_joint_scores % dict_size  # which vocabulary id it appends
193 | top_joint_scores = cand_y_scores.flatten()[idx_top_joint_scores]
194 |
195 | traces_now = []
196 | scores_now = np.zeros((beam_size - num_dead))
197 | states_now = []
198 |
199 | for i, [j, k] in enumerate(zip(idx_last_traces, idx_word_now)):
200 | traces_now.append(last_traces[j] + [k])
201 | scores_now[i] = copy.copy(top_joint_scores[i])
202 | if options["coverage"]:
203 | states_now.append(last_states[j] + [copy.copy(attn_dist[j, :])])
204 |
205 | num_live = 0
206 | last_traces = []
207 | last_scores = []
208 | last_states = []
209 | last_c_scores = []
210 | dead_ids = []
211 | for i in range(len(traces_now)):
212 | if traces_now[i][-1] == modules["eos_emb"] and len(traces_now[i]) >= consts["min_len_predict"]:
213 | samples.append([str(e.item()) for e in traces_now[i][:-1]])
214 | sample_scores[num_dead] = scores_now[i]
215 | num_dead += 1
216 | dead_ids += [i]
217 | else:
218 | last_traces.append(traces_now[i])
219 | last_scores.append(scores_now[i])
220 |
221 | if options["coverage"]:
222 | last_states.append(states_now[i])
223 | attns = torch.stack(states_now[i])
224 | m, n = attns.shape
225 |                         cp = torch.sum(attns, dim=0)  # total attention each source position has received
226 |                         cp = torch.max(cp, torch.ones_like(cp))  # floor at 1: only over-attended positions add to the penalty
227 |                         cp = - consts["beta"] * (torch.sum(cp).item() - n)  # coverage penalty discourages repeated attention
228 | last_c_scores.append(cp)
229 |
230 | num_live += 1
231 | if num_live == 0 or num_dead >= beam_size:
232 | break
233 |
234 | if options["coverage"]:
235 | last_c_scores = torch.FloatTensor(np.array(last_c_scores).reshape((num_live, 1))).to(options["device"])
236 |
237 | last_scores = torch.FloatTensor(np.array(last_scores).reshape((num_live, 1))).to(options["device"])
238 | next_y = []
239 | for e in last_traces:
240 | eid = e[-1].item()
241 | if eid in modules["i2w"]:
242 | next_y.append(eid)
243 | else:
244 | next_y.append(modules["lfw_emb"]) # unk for copy mechanism
245 |
246 | next_y = np.array(next_y).reshape((1, num_live))
247 | next_y = torch.LongTensor(next_y).to(options["device"])
248 |
249 | if step == 0:
250 | ys = ys.repeat(1, num_live)
251 | ys_ = []
252 | py_ = []
253 | for i in range(ys.size(1)):
254 | if i not in dead_ids:
255 | ys_.append(ys[:, i])
256 | ys = torch.cat([torch.stack(ys_, dim=1), next_y], dim=0)
257 |
258 | assert num_live + num_dead == beam_size
259 |
260 | if num_live > 0:
261 | for i in range(num_live):
262 | samples.append([str(e.item()) for e in last_traces[i]])
263 | sample_scores[num_dead] = last_scores[i]
264 | num_dead += 1
265 |
266 |     # length-normalize the beam scores with the GNMT-style length penalty (Wu et al., 2016)
267 | for i in range(len(sample_scores)):
268 | sent_len = float(len(samples[i]))
269 | lp = np.power(5 + sent_len, consts["alpha"]) / np.power(5 + 1, consts["alpha"])
270 | sample_scores[i] /= lp
271 |
272 | idx_sorted_scores = np.argsort(sample_scores) # ascending order
273 | if options["has_y"]:
274 | ly = len_y[0]
275 | y_true = y[0 : ly].tolist()
276 |         y_true = [str(i) for i in y_true[:-1]] # delete <eos>
277 |
278 | sorted_samples = []
279 | sorted_scores = []
280 | filter_idx = []
281 | for e in idx_sorted_scores:
282 | if len(samples[e]) >= consts["min_len_predict"]:
283 | filter_idx.append(e)
284 | if len(filter_idx) == 0:
285 | filter_idx = idx_sorted_scores
286 | for e in filter_idx:
287 | sorted_samples.append(samples[e])
288 | sorted_scores.append(sample_scores[e])
289 |
290 | num_samples = len(sorted_samples)
291 | if len(sorted_samples) == 1:
292 | sorted_samples = sorted_samples[0]
293 | num_samples = 1
294 |
295 |     # for tasks with a byte-length limit on predictions
296 | if options["prediction_bytes_limitation"]:
297 | for i in range(len(sorted_samples)):
298 | sample = sorted_samples[i]
299 | b = 0
300 | for j in range(len(sample)):
301 | e = int(sample[j])
302 | if e in modules["i2w"]:
303 | word = modules["i2w"][e]
304 | else:
305 | word = oovs[e - len(modules["i2w"])]
306 | if j == 0:
307 | b += len(word)
308 | else:
309 | b += len(word) + 1
310 | if b > consts["max_byte_predict"]:
311 | sorted_samples[i] = sorted_samples[i][0 : j]
312 | break
313 |
314 | dec_words = []
315 |
316 | for e in sorted_samples[-1]:
317 | e = int(e)
318 |         if e in modules["i2w"]: # without the copy mechanism, all words are in the dict
319 | dec_words.append(modules["i2w"][e])
320 | else:
321 | dec_words.append(oovs[e - len(modules["i2w"])])
322 |
323 | write_for_rouge(fname, ref_sents, dec_words, cfg)
324 |
325 | # beam search history for checking
326 | if not options["copy"]:
327 | oovs = None
328 | write_summ("".join((cfg.cc.BEAM_SUMM_PATH, fname)), sorted_samples, num_samples, options, modules["i2w"], oovs, sorted_scores)
329 | write_summ("".join((cfg.cc.BEAM_GT_PATH, fname)), y_true, 1, options, modules["i2w"], oovs)
330 |
331 |
332 |
333 | def predict(model, modules, consts, options):
334 | print("start predicting,")
335 | model.eval()
336 | options["has_y"] = TESTING_DATASET_CLS.HAS_Y
337 | if options["beam_decoding"]:
338 | print("using beam search")
339 | else:
340 | print("using greedy search")
341 | rebuild_dir(cfg.cc.BEAM_SUMM_PATH)
342 | rebuild_dir(cfg.cc.BEAM_GT_PATH)
343 | rebuild_dir(cfg.cc.GROUND_TRUTH_PATH)
344 | rebuild_dir(cfg.cc.SUMM_PATH)
345 |
346 | print("loading test set...")
347 | if options["model_selection"]:
348 | xy_list = pickle.load(open(cfg.cc.VALIDATE_DATA_PATH + "pj1000.pkl", "rb"))
349 | else:
350 | xy_list = pickle.load(open(cfg.cc.TESTING_DATA_PATH + "test.pkl", "rb"))
351 | batch_list, num_files, num_batches = datar.batched(len(xy_list), options, consts)
352 |
353 | print("num_files = ", num_files, ", num_batches = ", num_batches)
354 |
355 | running_start = time.time()
356 | partial_num = 0
357 | total_num = 0
358 | si = 0
359 | for idx_batch in range(num_batches):
360 | test_idx = batch_list[idx_batch]
361 | batch_raw = [xy_list[xy_idx] for xy_idx in test_idx]
362 | batch = datar.get_data(batch_raw, modules, consts, options)
363 |
364 | assert len(test_idx) == batch.x.shape[1] # local_batch_size
365 |
366 |
367 | word_emb, padding_mask = model.encode(torch.LongTensor(batch.x).to(options["device"]))
368 |
369 | if options["beam_decoding"]:
370 | for idx_s in range(len(test_idx)):
371 | if options["copy"]:
372 | inputx = (torch.LongTensor(batch.x_ext[:, idx_s]).to(options["device"]), \
373 | torch.FloatTensor(batch.x_mask[:, idx_s, :]).to(options["device"]), \
374 | word_emb[:, idx_s, :], padding_mask[:, idx_s],\
375 | batch.y[:, idx_s], [batch.len_y[idx_s]], batch.original_summarys[idx_s],\
376 | batch.max_ext_len, batch.x_ext_words[idx_s])
377 | else:
378 | inputx = (torch.LongTensor(batch.x[:, idx_s]).to(options["device"]), word_emb[:, idx_s, :], padding_mask[:, idx_s],\
379 | batch.y[:, idx_s], [batch.len_y[idx_s]], batch.original_summarys[idx_s])
380 |
381 | beam_decode(si, inputx, model, modules, consts, options)
382 | si += 1
383 | else:
384 | pass
385 |             # greedy_decode(): greedy decoding is not implemented
386 |
387 | testing_batch_size = len(test_idx)
388 | partial_num += testing_batch_size
389 | total_num += testing_batch_size
390 | if partial_num >= consts["testing_print_size"]:
391 | print(total_num, "summs are generated")
392 | partial_num = 0
393 | print (si, total_num)
394 |
395 | def run(existing_model_name = None):
396 | modules, consts, options = init_modules()
397 |
398 | if options["is_predicting"]:
399 | need_load_model = True
400 | training_model = False
401 | predict_model = True
402 | else:
403 | need_load_model = False
404 | training_model = True
405 | predict_model = False
406 |
407 | print_basic_info(modules, consts, options)
408 |
409 | if training_model:
410 | print ("loading train set...")
411 | if options["is_debugging"]:
412 | xy_list = pickle.load(open(cfg.cc.TESTING_DATA_PATH + "test.pkl", "rb"))
413 | else:
414 | xy_list = pickle.load(open(cfg.cc.TRAINING_DATA_PATH + "train.pkl", "rb"))
415 | batch_list, num_files, num_batches = datar.batched(len(xy_list), options, consts)
416 | print ("num_files = ", num_files, ", num_batches = ", num_batches)
417 |
418 | running_start = time.time()
419 | if True: #TODO: refactor
420 | print ("compiling model ..." )
421 | model = Model(modules, consts, options)
422 | if options["cuda"]:
423 | model.cuda()
424 | optimizer = torch.optim.Adagrad(model.parameters(), lr=consts["lr"], initial_accumulator_value=0.1)
425 |
426 | model_name = "".join(["cnndm.s2s.", options["cell"]])
427 | existing_epoch = 0
428 | if need_load_model:
429 |         if existing_model_name is None:
430 | existing_model_name = "cnndm.s2s.transformer.gpu0.epoch27.2"
431 | print ("loading existed model:", existing_model_name)
432 | model, optimizer = load_model(cfg.cc.MODEL_PATH + existing_model_name, model, optimizer)
433 |
434 | if training_model:
435 | print ("start training model ")
436 | model.train()
437 | print_size = num_files // consts["print_time"] if num_files >= consts["print_time"] else num_files
438 |
439 | last_total_error = float("inf")
440 | print ("max epoch:", consts["max_epoch"])
441 | for epoch in range(0, consts["max_epoch"]):
442 | print ("epoch: ", epoch + existing_epoch)
443 | num_partial = 1
444 | total_error = 0.0
445 | error_c = 0.0
446 | partial_num_files = 0
447 | epoch_start = time.time()
448 | partial_start = time.time()
449 | # shuffle the trainset
450 | batch_list, num_files, num_batches = datar.batched(len(xy_list), options, consts)
451 | used_batch = 0.
452 | for idx_batch in range(num_batches):
453 | train_idx = batch_list[idx_batch]
454 | batch_raw = [xy_list[xy_idx] for xy_idx in train_idx]
455 | if len(batch_raw) != consts["batch_size"]:
456 | continue
457 | local_batch_size = len(batch_raw)
458 | batch = datar.get_data(batch_raw, modules, consts, options)
459 |
460 |
461 | model.zero_grad()
462 |
463 | y_pred, cost = model(torch.LongTensor(batch.x).to(options["device"]),\
464 | torch.LongTensor(batch.y_inp).to(options["device"]),\
465 | torch.LongTensor(batch.y).to(options["device"]),\
466 | torch.FloatTensor(batch.x_mask).to(options["device"]),\
467 | torch.FloatTensor(batch.y_mask).to(options["device"]),\
468 | torch.LongTensor(batch.x_ext).to(options["device"]),\
469 | torch.LongTensor(batch.y_ext).to(options["device"]),\
470 | batch.max_ext_len)
471 |
472 |
473 | cost.backward()
474 | torch.nn.utils.clip_grad_norm_(model.parameters(), consts["norm_clip"])
475 | optimizer.step()
476 |
477 |
478 | cost = cost.item()
479 | total_error += cost
480 | used_batch += 1
481 | partial_num_files += consts["batch_size"]
482 | if partial_num_files // print_size == 1 and idx_batch < num_batches:
483 | print (idx_batch + 1, "/" , num_batches, "batches have been processed,", \
484 | "average cost until now:", "cost =", total_error / used_batch, ",", \
485 | "cost_c =", error_c / used_batch, ",", \
486 | "time:", time.time() - partial_start)
487 | partial_num_files = 0
488 | if not options["is_debugging"]:
489 | print("save model... ",)
490 | file_name = model_name + ".gpu" + str(consts["idx_gpu"]) + ".epoch" + str(epoch // consts["save_epoch"] + existing_epoch) + "." + str(num_partial)
491 | save_model(cfg.cc.MODEL_PATH + file_name, model, optimizer)
492 | if options["fire"]:
493 | shutil.move(cfg.cc.MODEL_PATH + file_name, "/out/")
494 |
495 | print("finished")
496 | num_partial += 1
497 | print ("in this epoch, total average cost =", total_error / used_batch, ",", \
498 | "cost_c =", error_c / used_batch, ",",\
499 | "time:", time.time() - epoch_start)
500 |
501 | print_sent_dec(y_pred, batch.y, batch.y_mask, batch.x_ext_words, modules, consts, options, local_batch_size)
502 |
503 | if last_total_error > total_error or options["is_debugging"]:
504 | last_total_error = total_error
505 | if not options["is_debugging"]:
506 | print ("save model... ",)
507 | file_name = model_name + ".gpu" + str(consts["idx_gpu"]) + ".epoch" + str(epoch // consts["save_epoch"] + existing_epoch) + "." + str(num_partial)
508 | save_model(cfg.cc.MODEL_PATH + file_name, model, optimizer)
509 | if options["fire"]:
510 | shutil.move(cfg.cc.MODEL_PATH + file_name, "/out/")
511 |
512 | print ("finished")
513 | else:
514 | print ("optimization finished")
515 | break
516 |
517 |         print ("save final model... ")
518 | file_name = model_name + ".final.gpu" + str(consts["idx_gpu"]) + ".epoch" + str(epoch // consts["save_epoch"] + existing_epoch) + "." + str(num_partial)
519 | save_model(cfg.cc.MODEL_PATH + file_name, model, optimizer)
520 | if options["fire"]:
521 | shutil.move(cfg.cc.MODEL_PATH + file_name, "/out/")
522 |
523 | print ("finished")
524 | else:
525 | print ("skip training model")
526 |
527 | if predict_model:
528 | predict(model, modules, consts, options)
529 | print ("Finished, time:", time.time() - running_start)
530 |
531 | if __name__ == "__main__":
532 | np.set_printoptions(threshold = np.inf)
533 | existing_model_name = sys.argv[1] if len(sys.argv) > 1 else None
534 | run(existing_model_name)
535 |
--------------------------------------------------------------------------------
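
For reference, a self-contained sketch (made-up numbers; the real alpha and beta come from cfg.ALPHA and cfg.BETA) of the two rescoring terms used in beam_decode above -- the length penalty applied after decoding and the per-step coverage penalty:

    import numpy as np

    alpha = 0.9                          # illustrative stand-in for cfg.ALPHA
    scores = np.array([-4.2, -7.9])      # made-up summed log-probs of two hypotheses
    lens = np.array([5.0, 12.0])         # their lengths in tokens

    # length penalty: lp = (5 + len)^alpha / (5 + 1)^alpha
    lp = np.power(5 + lens, alpha) / np.power(5 + 1, alpha)
    print(scores / lp)                   # ~[-2.65, -3.09]

    beta = 1.0                           # illustrative stand-in for cfg.BETA
    attn = np.array([[0.7, 0.2, 0.1],    # made-up attention over 3 source positions,
                     [0.6, 0.3, 0.1],    # one row per decoded token
                     [0.8, 0.1, 0.1]])
    cov = attn.sum(axis=0)               # total attention per source position
    cp = -beta * (np.maximum(cov, 1.0).sum() - attn.shape[1])
    print(cp)                            # -1.1: excess attention above 1 is penalized

Without the length normalization, beam search systematically favors shorter outputs, since every extra token makes the summed log-probability more negative; the coverage term pushes hypotheses away from attending to the same source tokens over and over.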