├── data_util
│   ├── __init__.py
│   ├── config.py
│   ├── utils.py
│   ├── data.py
│   └── batcher.py
├── .gitignore
├── training_ptr_gen
│   ├── __init__.py
│   ├── train_util.py
│   ├── eval.py
│   ├── train.py
│   └── model.py
└── README.md
/data_util/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.bin
3 |
--------------------------------------------------------------------------------
/training_ptr_gen/__init__.py:
--------------------------------------------------------------------------------
1 | # Empty File
2 |
--------------------------------------------------------------------------------
/data_util/config.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 |
4 | root_dir = os.path.expanduser("~")
5 |
6 | train_data_path = "/dataset_path/chunked/train_*"
7 | eval_data_path = "/dataset_path/finished_files/val.bin"
8 | decode_data_path = "/dataset_path/finished_files/test.bin"
9 | vocab_path = "/dataset_path/finished_files/vocab"
10 | log_root = "/log_path/pointer_summarizer"
11 |
12 | # Hyperparameters
13 | hidden_dim= 256
14 | emb_dim= 128
15 | batch_size= 64
16 | max_enc_steps=400
17 | max_dec_steps=100
18 | beam_size=4
19 | min_dec_steps=35
20 | vocab_size=50000
21 |
22 | lr=0.001
23 | adagrad_init_acc=0.1
24 | rand_unif_init_mag=0.02
25 | trunc_norm_init_std=1e-4
26 | max_grad_norm=2.0
27 |
28 | pointer_gen = True
29 | is_coverage = False
30 | cov_loss_wt = 1.0
31 |
32 | eps = 1e-12
33 | max_iterations = 50000
34 |
35 | use_gpu=True
36 |
37 | lr_coverage=0.15
38 |
--------------------------------------------------------------------------------
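Note: the path values above are placeholders. A minimal sketch of how they might be filled in, assuming (hypothetically) that the dataset archive from the README is extracted to ~/dataset and logs are kept under ~/log:

```python
import os

root_dir = os.path.expanduser("~")

# Hypothetical locations; point these at your extracted train/val/test/vocab files.
train_data_path = os.path.join(root_dir, "dataset/chunked/train_*")
eval_data_path = os.path.join(root_dir, "dataset/finished_files/val.bin")
decode_data_path = os.path.join(root_dir, "dataset/finished_files/test.bin")
vocab_path = os.path.join(root_dir, "dataset/finished_files/vocab")
log_root = os.path.join(root_dir, "log/pointer_summarizer")
```
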
/README.md:
--------------------------------------------------------------------------------
1 | PyTorch implementation of *[Automatic Fact-guided Sentence Modification](https://arxiv.org/pdf/1909.13838.pdf)* (AAAI 2020).
2 |
3 | This repository implements the split-encoder pointer-generator module of our method.
4 |
5 |
6 | The code for the Masker is [here](https://github.com/TalSchuster/TokenMasker).
7 |
8 |
9 | ______________________________________________________________________________________________________________________________
10 |
11 | Repository cloned and updated from https://github.com/atulkum/pointer_summarizer.
12 |
13 |
14 |
15 | Note:
16 | * Tested on PyTorch 0.4 with Python 2.7.
17 |
18 | ______________________________________________________________________________________________________________________________
19 |
20 | Training the Model:
21 |
22 | export PYTHONPATH=`pwd` &&
23 | python training_ptr_gen/train.py
24 |
25 |
26 | ______________________________________________________________________________________________________________________________
27 |
28 |
29 | Dataset:
30 |
31 | The dataset for training this model can be found here: https://drive.google.com/open?id=1aOMEUksFpZwJDtQcgsrJ0rjC7nO2J1kr
32 |
33 | (Download the data and edit the config file so that the train, val, test, and vocab paths point to the downloaded files.)
34 |
35 |
36 | ______________________________________________________________________________________________________________________________
37 |
38 | Evaluation:
39 |
40 |
41 | export PYTHONPATH=`pwd` &&
42 | python training_ptr_gen/eval.py _path_of_model_checkpoint
43 |
44 |
45 | This will generate the outputs for the eval file specified by eval_data_path in the config.
46 |
47 | ______________________________________________________________________________________________________________________________
48 |
49 | If you find this repository helpful, please cite our paper:
50 | ```
51 | @inproceedings{shah2020automatic,
52 | title={Automatic Fact-guided Sentence Modification},
53 | author={Darsh J Shah and Tal Schuster and Regina Barzilay},
54 | booktitle={Association for the Advancement of Artificial Intelligence ({AAAI})},
55 | year={2020},
56 | url={https://arxiv.org/pdf/1909.13838.pdf}
57 | }
58 | ```
59 |
--------------------------------------------------------------------------------
/training_ptr_gen/train_util.py:
--------------------------------------------------------------------------------
1 | from torch.autograd import Variable
2 | import numpy as np
3 | import torch
4 | from data_util import config
5 |
6 | def get_input_from_batch(batch, use_cuda):
7 | batch_size = len(batch.enc_lens)
8 |
9 | enc_batch = Variable(torch.from_numpy(batch.enc_batch).long())
10 | enc_padding_mask = Variable(torch.from_numpy(batch.enc_padding_mask)).float()
11 | enc_lens = batch.enc_lens
12 |
13 | enc_batch_2 = Variable(torch.from_numpy(batch.enc_batch_2).long())
14 | enc_padding_mask_2 = Variable(torch.from_numpy(batch.enc_padding_mask_2).float())
15 | enc_lens_2 = batch.enc_lens_2
16 |
17 | extra_zeros = None
18 | enc_batch_extend_vocab = None
19 |
20 | extra_zeros_2 = None
21 | enc_batch_extend_vocab_2 = None
22 |
23 | if config.pointer_gen:
24 | enc_batch_extend_vocab = Variable(torch.from_numpy(batch.enc_batch_extend_vocab).long())
25 | enc_batch_extend_vocab_2 = Variable(torch.from_numpy(batch.enc_batch_extend_vocab_2).long())
26 | # max_art_oovs is the max over all the article OOV lists in the batch
27 | if batch.max_art_oovs > 0 or batch.max_art_oovs_2 > 0:
28 | extra_zeros = Variable(torch.zeros((batch_size, batch.max_art_oovs)))
29 | extra_zeros_2 = Variable(torch.zeros((batch_size, batch.max_art_oovs_2)))
30 |
31 | c_t_1 = Variable(torch.zeros((batch_size, 2 * config.hidden_dim)))
32 | c_t_1_2 = Variable(torch.zeros((batch_size, 2 * config.hidden_dim)))
33 |
34 | coverage = None
35 | coverage_2 = None
36 | if config.is_coverage:
37 | coverage = Variable(torch.zeros(enc_batch.size()))
38 | coverage_2 = Variable(torch.zeros(enc_batch_2.size()))
39 |
40 | if use_cuda:
41 | enc_batch = enc_batch.cuda()
42 | enc_padding_mask = enc_padding_mask.cuda()
43 | enc_batch_2 = enc_batch_2.cuda()
44 | enc_padding_mask_2 = enc_padding_mask_2.cuda()
45 |
46 | if enc_batch_extend_vocab is not None:
47 | enc_batch_extend_vocab = enc_batch_extend_vocab.cuda()
48 | enc_batch_extend_vocab_2 = enc_batch_extend_vocab_2.cuda()
49 | if extra_zeros is not None:
50 | extra_zeros = extra_zeros.cuda()
51 | extra_zeros_2 = extra_zeros_2.cuda()
52 | c_t_1 = c_t_1.cuda()
53 | c_t_1_2 = c_t_1_2.cuda()
54 |
55 | if coverage is not None:
56 | coverage = coverage.cuda()
57 | coverage_2 = coverage_2.cuda()
58 |
59 | return [enc_batch, enc_batch_2], [enc_padding_mask, enc_padding_mask_2], [enc_lens, enc_lens_2], [enc_batch_extend_vocab, enc_batch_extend_vocab_2], [extra_zeros, extra_zeros_2], [c_t_1, c_t_1_2], [coverage, coverage_2]
60 |
61 | def get_output_from_batch(batch, use_cuda):
62 | dec_batch = Variable(torch.from_numpy(batch.dec_batch).long())
63 | dec_padding_mask = Variable(torch.from_numpy(batch.dec_padding_mask)).float()
64 | dec_lens = batch.dec_lens
65 | max_dec_len = np.max(dec_lens)
66 | dec_lens_var = Variable(torch.from_numpy(dec_lens)).float()
67 |
68 | target_batch = Variable(torch.from_numpy(batch.target_batch)).long()
69 |
70 | if use_cuda:
71 | dec_batch = dec_batch.cuda()
72 | dec_padding_mask = dec_padding_mask.cuda()
73 | dec_lens_var = dec_lens_var.cuda()
74 | target_batch = target_batch.cuda()
75 |
76 |
77 | return dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch
78 |
79 |
--------------------------------------------------------------------------------
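Note: every value returned by get_input_from_batch is a two-element list, one entry per encoder stream (source1 and source2). A minimal sketch of how a caller unpacks the pairs, mirroring train.py and eval.py (batch and use_cuda come from the caller):

```python
from train_util import get_input_from_batch

# Sketch: consuming the paired outputs inside a training/eval step.
enc_batch_list, enc_padding_mask_list, enc_lens_list, \
    enc_batch_extend_vocab_list, extra_zeros_list, c_t_1_list, coverage_list = \
    get_input_from_batch(batch, use_cuda)

enc_batch, enc_batch_2 = enc_batch_list    # token ids for source1 / source2
c_t_1, c_t_1_2 = c_t_1_list                # initial context vectors, B x 2*hidden_dim
coverage, coverage_2 = coverage_list       # both None unless config.is_coverage
```
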
/data_util/utils.py:
--------------------------------------------------------------------------------
1 | #Content of this file is copied from https://github.com/abisee/pointer-generator/blob/master/
2 | import os
3 | import pyrouge
4 | import logging
5 | import tensorflow as tf
6 |
7 | def print_results(article, abstract, decoded_output):
8 | print ("")
9 | print('ARTICLE: %s', article)
10 | print('REFERENCE SUMMARY: %s', abstract)
11 | print('GENERATED SUMMARY: %s', decoded_output)
12 | print( "")
13 |
14 |
15 | def make_html_safe(s):
16 |   s = s.replace("<", "&lt;")
17 |   s = s.replace(">", "&gt;")
18 |   return s
19 |
20 |
21 | def rouge_eval(ref_dir, dec_dir):
22 | r = pyrouge.Rouge155()
23 | r.model_filename_pattern = '#ID#_reference.txt'
24 | r.system_filename_pattern = r'(\d+)_decoded.txt'
25 | r.model_dir = ref_dir
26 | r.system_dir = dec_dir
27 | logging.getLogger('global').setLevel(logging.WARNING) # silence pyrouge logging
28 | rouge_results = r.convert_and_evaluate()
29 | return r.output_to_dict(rouge_results)
30 |
31 |
32 | def rouge_log(results_dict, dir_to_write):
33 | log_str = ""
34 | for x in ["1","2","l"]:
35 | log_str += "\nROUGE-%s:\n" % x
36 | for y in ["f_score", "recall", "precision"]:
37 | key = "rouge_%s_%s" % (x,y)
38 | key_cb = key + "_cb"
39 | key_ce = key + "_ce"
40 | val = results_dict[key]
41 | val_cb = results_dict[key_cb]
42 | val_ce = results_dict[key_ce]
43 | log_str += "%s: %.4f with confidence interval (%.4f, %.4f)\n" % (key, val, val_cb, val_ce)
44 | print(log_str)
45 | results_file = os.path.join(dir_to_write, "ROUGE_results.txt")
46 | print("Writing final ROUGE results to %s..."%(results_file))
47 | with open(results_file, "w") as f:
48 | f.write(log_str)
49 |
50 |
51 | def calc_running_avg_loss(loss, running_avg_loss, summary_writer, step, decay=0.99):
52 | if running_avg_loss == 0: # on the first iteration just take the loss
53 | running_avg_loss = loss
54 | else:
55 | running_avg_loss = running_avg_loss * decay + (1 - decay) * loss
56 | running_avg_loss = min(running_avg_loss, 12) # clip
57 | loss_sum = tf.Summary()
58 | tag_name = 'running_avg_loss/decay=%f' % (decay)
59 | loss_sum.value.add(tag=tag_name, simple_value=running_avg_loss)
60 | summary_writer.add_summary(loss_sum, step)
61 | return running_avg_loss
62 |
63 |
64 | def write_for_rouge(reference_sents, decoded_words, ex_index,
65 | _rouge_ref_dir, _rouge_dec_dir):
66 | decoded_sents = []
67 | while len(decoded_words) > 0:
68 | try:
69 | fst_period_idx = decoded_words.index(".")
70 | except ValueError:
71 | fst_period_idx = len(decoded_words)
72 | sent = decoded_words[:fst_period_idx + 1]
73 | decoded_words = decoded_words[fst_period_idx + 1:]
74 | decoded_sents.append(' '.join(sent))
75 |
76 | # pyrouge calls a perl script that puts the data into HTML files.
77 | # Therefore we need to make our output HTML safe.
78 | decoded_sents = [make_html_safe(w) for w in decoded_sents]
79 | reference_sents = [make_html_safe(w) for w in reference_sents]
80 |
81 | ref_file = os.path.join(_rouge_ref_dir, "%06d_reference.txt" % ex_index)
82 | decoded_file = os.path.join(_rouge_dec_dir, "%06d_decoded.txt" % ex_index)
83 |
84 | with open(ref_file, "w") as f:
85 | for idx, sent in enumerate(reference_sents):
86 | f.write(sent) if idx == len(reference_sents) - 1 else f.write(sent + "\n")
87 | with open(decoded_file, "w") as f:
88 | for idx, sent in enumerate(decoded_sents):
89 | f.write(sent) if idx == len(decoded_sents) - 1 else f.write(sent + "\n")
90 |
91 | #print("Wrote example %i to file" % ex_index)
92 |
--------------------------------------------------------------------------------
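Note: a minimal sketch of how these helpers chain together when scoring decoded summaries. The names examples, ref_dir, dec_dir and out_dir are hypothetical; write_for_rouge is called once per example before ROUGE runs:

```python
# Sketch: write each example to disk, then run ROUGE over the two directories.
for i, (ref_sents, dec_words) in enumerate(examples):  # hypothetical iterable
    write_for_rouge(ref_sents, dec_words, i, ref_dir, dec_dir)

results_dict = rouge_eval(ref_dir, dec_dir)  # runs pyrouge over the written files
rouge_log(results_dict, out_dir)             # prints and writes ROUGE_results.txt
```
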
/data_util/data.py:
--------------------------------------------------------------------------------
1 | #Most of this file is copied from https://github.com/abisee/pointer-generator/blob/master/data.py
2 |
3 | import glob
4 | import random
5 | import struct
6 | import csv
7 | from tensorflow.core.example import example_pb2
8 |
9 | # <s> and </s> are used in the data files to segment the abstracts into sentences. They don't receive vocab ids.
10 | SENTENCE_START = '<s>'
11 | SENTENCE_END = '</s>'
12 |
13 | PAD_TOKEN = '[PAD]' # This has a vocab id, which is used to pad the encoder input, decoder input and target sequence
14 | UNKNOWN_TOKEN = '[UNK]' # This has a vocab id, which is used to represent out-of-vocabulary words
15 | START_DECODING = '[START]' # This has a vocab id, which is used at the start of every decoder input sequence
16 | STOP_DECODING = '[STOP]' # This has a vocab id, which is used at the end of untruncated target sequences
17 |
18 | # Note: none of <s>, </s>, [PAD], [UNK], [START], [STOP] should appear in the vocab file.
19 |
20 |
21 | class Vocab(object):
22 |
23 | def __init__(self, vocab_file, max_size):
24 | self.word_to_id = {}
25 | self._word_to_id = {}
26 | self._id_to_word = {}
27 | self._count = 0 # keeps track of total number of words in the Vocab
28 |
29 | # [UNK], [PAD], [START] and [STOP] get the ids 0,1,2,3.
30 | for w in [UNKNOWN_TOKEN, PAD_TOKEN, START_DECODING, STOP_DECODING]:
31 | self._word_to_id[w] = self._count
32 | self._id_to_word[self._count] = w
33 | self._count += 1
34 |
35 | # Read the vocab file and add words up to max_size
36 | with open(vocab_file, 'r') as vocab_f:
37 | for line in vocab_f:
38 | pieces = line.split()
39 | if len(pieces) != 2:
40 | print 'Warning: incorrectly formatted line in vocabulary file: %s\n' % line
41 | continue
42 | w = pieces[0]
43 | if w in [SENTENCE_START, SENTENCE_END, UNKNOWN_TOKEN, PAD_TOKEN, START_DECODING, STOP_DECODING]:
44 | raise Exception('<s>, </s>, [UNK], [PAD], [START] and [STOP] shouldn\'t be in the vocab file, but %s is' % w)
45 | if w in self._word_to_id:
46 | raise Exception('Duplicated word in vocabulary file: %s' % w)
47 | self._word_to_id[w] = self._count
48 | self._id_to_word[self._count] = w
49 | self._count += 1
50 | if max_size != 0 and self._count >= max_size:
51 | print "max_size of vocab was specified as %i; we now have %i words. Stopping reading." % (max_size, self._count)
52 | break
53 |
54 | print "Finished constructing vocabulary of %i total words. Last word added: %s" % (self._count, self._id_to_word[self._count-1])
55 | self.word_to_id = self._word_to_id
56 |
57 |
58 | def word2id(self, word):
59 | if word not in self._word_to_id:
60 | return self._word_to_id[UNKNOWN_TOKEN]
61 | return self._word_to_id[word]
62 |
63 | def id2word(self, word_id):
64 | if word_id not in self._id_to_word:
65 | raise ValueError('Id not found in vocab: %d' % word_id)
66 | return self._id_to_word[word_id]
67 |
68 | def size(self):
69 | return self._count
70 |
71 | def write_metadata(self, fpath):
72 | print "Writing word embedding metadata file to %s..." % (fpath)
73 | with open(fpath, "w") as f:
74 | fieldnames = ['word']
75 | writer = csv.DictWriter(f, delimiter="\t", fieldnames=fieldnames)
76 | for i in xrange(self.size()):
77 | writer.writerow({"word": self._id_to_word[i]})
78 |
79 |
80 | def example_generator(data_path, single_pass):
81 | while True:
82 | filelist = glob.glob(data_path) # get the list of datafiles
83 | assert filelist, ('Error: Empty filelist at %s' % data_path) # check filelist isn't empty
84 | if single_pass:
85 | filelist = sorted(filelist)
86 | else:
87 | random.shuffle(filelist)
88 | for f in filelist:
89 | reader = open(f, 'rb')
90 | while True:
91 | len_bytes = reader.read(8)
92 | if not len_bytes: break # finished reading this file
93 | str_len = struct.unpack('q', len_bytes)[0]
94 | example_str = struct.unpack('%ds' % str_len, reader.read(str_len))[0]
95 | yield example_pb2.Example.FromString(example_str)
96 | if single_pass:
97 | print "example_generator completed reading all datafiles. No more data."
98 | break
99 |
100 |
101 | def article2ids(article_words, vocab):
102 | ids = []
103 | oovs = []
104 | unk_id = vocab.word2id(UNKNOWN_TOKEN)
105 | for w in article_words:
106 | i = vocab.word2id(w)
107 | if i == unk_id: # If w is OOV
108 | if w not in oovs: # Add to list of OOVs
109 | oovs.append(w)
110 | oov_num = oovs.index(w) # This is 0 for the first article OOV, 1 for the second article OOV...
111 | ids.append(vocab.size() + oov_num) # This is e.g. 50000 for the first article OOV, 50001 for the second...
112 | else:
113 | ids.append(i)
114 | return ids, oovs
115 |
116 |
117 | def abstract2ids(abstract_words, vocab, article_oovs):
118 | ids = []
119 | unk_id = vocab.word2id(UNKNOWN_TOKEN)
120 | for w in abstract_words:
121 | i = vocab.word2id(w)
122 | if i == unk_id: # If w is an OOV word
123 | if w in article_oovs: # If w is an in-article OOV
124 | vocab_idx = vocab.size() + article_oovs.index(w) # Map to its temporary article OOV number
125 | ids.append(vocab_idx)
126 | else: # If w is an out-of-article OOV
127 | ids.append(unk_id) # Map to the UNK token id
128 | else:
129 | ids.append(i)
130 | return ids
131 |
132 |
133 | def outputids2words(id_list, vocab, article_oovs):
134 | words = []
135 | for i in id_list:
136 | try:
137 | w = vocab.id2word(i) # might be [UNK]
138 | except ValueError as e: # w is OOV
139 | assert article_oovs is not None, "Error: model produced a word ID that isn't in the vocabulary. This should not happen in baseline (no pointer-generator) mode"
140 | article_oov_idx = i - vocab.size()
141 | try:
142 | w = article_oovs[article_oov_idx]
143 | except ValueError as e: # i doesn't correspond to an article oov
144 | raise ValueError('Error: model produced word ID %i which corresponds to article OOV %i but this example only has %i article OOVs' % (i, article_oov_idx, len(article_oovs)))
145 | words.append(w)
146 | return words
147 |
148 |
149 | def abstract2sents(abstract):
150 | cur = 0
151 | sents = []
152 | while True:
153 | try:
154 | start_p = abstract.index(SENTENCE_START, cur)
155 | end_p = abstract.index(SENTENCE_END, start_p + 1)
156 | cur = end_p + len(SENTENCE_END)
157 | sents.append(abstract[start_p+len(SENTENCE_START):end_p])
158 | except ValueError as e: # no more sentences
159 | return sents
160 |
161 |
162 | def show_art_oovs(article, vocab):
163 | unk_token = vocab.word2id(UNKNOWN_TOKEN)
164 | words = article.split(' ')
165 | words = [("__%s__" % w) if vocab.word2id(w)==unk_token else w for w in words]
166 | out_str = ' '.join(words)
167 | return out_str
168 |
169 |
170 | def show_abs_oovs(abstract, vocab, article_oovs):
171 | unk_token = vocab.word2id(UNKNOWN_TOKEN)
172 | words = abstract.split(' ')
173 | new_words = []
174 | for w in words:
175 | if vocab.word2id(w) == unk_token: # w is oov
176 | if article_oovs is None: # baseline mode
177 | new_words.append("__%s__" % w)
178 | else: # pointer-generator mode
179 | if w in article_oovs:
180 | new_words.append("__%s__" % w)
181 | else:
182 | new_words.append("!!__%s__!!" % w)
183 | else: # w is in-vocab word
184 | new_words.append(w)
185 | out_str = ' '.join(new_words)
186 | return out_str
187 |
--------------------------------------------------------------------------------
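Note: a toy illustration of the temporary OOV ids produced above, assuming a constructed vocab with vocab.size() == 50000 and "rhinoceros" out of vocabulary (the concrete in-vocab ids are illustrative):

```python
ids, oovs = article2ids("the rhinoceros escaped".split(), vocab)
# ids  -> [13, 50000, 241]   50000 == vocab.size() + 0, the first article OOV
# oovs -> ["rhinoceros"]

abs_ids = abstract2ids("the rhinoceros".split(), vocab, oovs)
# abs_ids -> [13, 50000]     in-article OOVs keep their temporary id

words = outputids2words(abs_ids, vocab, oovs)
# words -> ["the", "rhinoceros"]
```
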
/training_ptr_gen/eval.py:
--------------------------------------------------------------------------------
1 | from __future__ import unicode_literals, print_function, division
2 |
3 | import os
4 | import time
5 | import sys
6 |
7 | import tensorflow as tf
8 | import torch
9 |
10 | from data_util import config
11 | from data_util.batcher import Batcher
12 | from data_util.data import Vocab
13 |
14 | from data_util.utils import calc_running_avg_loss
15 | from train_util import get_input_from_batch, get_output_from_batch
16 | from model import Model
17 | reload(sys)  # Python 2 hack: restore setdefaultencoding (sys is imported above)
18 | sys.setdefaultencoding('utf8')
19 |
20 |
21 | use_cuda = config.use_gpu and torch.cuda.is_available()
22 |
23 | class Evaluate(object):
24 | def __init__(self, model_file_path):
25 | self.vocab = Vocab(config.vocab_path, config.vocab_size)
26 | self.batcher = Batcher(config.eval_data_path, self.vocab, mode='eval',
27 | batch_size=config.batch_size, single_pass=True)
28 | time.sleep(15)
29 | model_name = os.path.basename(model_file_path)
30 |
31 | eval_dir = os.path.join(config.log_root, 'eval_%s' % (model_name))
32 | if not os.path.exists(eval_dir):
33 | os.mkdir(eval_dir)
34 | self.summary_writer = tf.summary.FileWriter(eval_dir)
35 |
36 | self.model = Model(model_file_path, is_eval=True)
37 |
38 | def eval_one_batch(self, batch):
39 | enc_batch_list, enc_padding_mask_list, enc_lens_list, enc_batch_extend_vocab_list, extra_zeros_list, c_t_1_list, coverage_list = \
40 | get_input_from_batch(batch, use_cuda)
41 | dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
42 | get_output_from_batch(batch, use_cuda)
43 |
44 | encoder_outputs_list = []
45 | encoder_feature_list = []
46 | s_t_1 = None
47 | s_t_1_0 = None
48 | s_t_1_1 = None
49 | for enc_batch,enc_lens in zip(enc_batch_list, enc_lens_list):
50 | sorted_indices = sorted(range(len(enc_lens)),key=enc_lens.__getitem__)
51 | sorted_indices.reverse()
52 | inverse_sorted_indices = [-1 for _ in range(len(sorted_indices))]
53 | for index,position in enumerate(sorted_indices):
54 | inverse_sorted_indices[position] = index
55 | sorted_enc_batch = torch.index_select(enc_batch, 0, torch.LongTensor(sorted_indices) if not use_cuda else torch.LongTensor(sorted_indices).cuda())
56 | sorted_enc_lens = enc_lens[sorted_indices]
57 | sorted_encoder_outputs, sorted_encoder_feature, sorted_encoder_hidden = self.model.encoder(sorted_enc_batch
58 | , sorted_enc_lens)
59 | encoder_outputs = torch.index_select(sorted_encoder_outputs, 0, torch.LongTensor(inverse_sorted_indices) if
60 | not use_cuda else torch.LongTensor(inverse_sorted_indices).cuda())
61 | encoder_feature = torch.index_select(sorted_encoder_feature.view(encoder_outputs.shape), 0, torch.LongTensor(inverse_sorted_indices) if not use_cuda else torch.LongTensor(inverse_sorted_indices).cuda()).view(sorted_encoder_feature.shape)
62 | encoder_hidden = tuple([torch.index_select(sorted_encoder_hidden[0], 1, torch.LongTensor(inverse_sorted_indices) if not use_cuda else torch.LongTensor(inverse_sorted_indices).cuda()), torch.index_select(sorted_encoder_hidden[1], 1, torch.LongTensor(inverse_sorted_indices) if not use_cuda else torch.LongTensor(inverse_sorted_indices).cuda())])
63 | #encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(enc_batch, enc_lens)
64 | encoder_outputs_list.append(encoder_outputs)
65 | encoder_feature_list.append(encoder_feature)
66 | if s_t_1 is None:
67 | s_t_1 = self.model.reduce_state(encoder_hidden)
68 | s_t_1_0, s_t_1_1 = s_t_1
69 | else:
70 | s_t_1_new = self.model.reduce_state(encoder_hidden)
71 | s_t_1_0 = s_t_1_0 + s_t_1_new[0]
72 | s_t_1_1 = s_t_1_1 + s_t_1_new[1]
73 | s_t_1 = tuple([s_t_1_0, s_t_1_1])
74 |
75 | #encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(enc_batch, enc_lens)
76 | #s_t_1 = self.model.reduce_state(encoder_hidden)
77 |
78 | step_losses = []
79 | target_words = []
80 | output_words = []
81 | id_to_words = {v: k for k, v in self.vocab.word_to_id.iteritems()}
82 | for di in range(min(max_dec_len, config.max_dec_steps)):
83 | y_t_1 = dec_batch[:, di] # Teacher forcing
84 | final_dist, s_t_1, c_t_1_list,attn_dist_list, p_gen, next_coverage_list = self.model.decoder(y_t_1, s_t_1, encoder_outputs_list, encoder_feature_list, enc_padding_mask_list, c_t_1_list, extra_zeros_list, enc_batch_extend_vocab_list, coverage_list, di)
85 | target = target_batch[:, di]
86 | gold_probs = torch.gather(final_dist, 1, target.unsqueeze(1)).squeeze()
87 | output_ids = final_dist.max(1)[1]
88 | output_2_candidates = final_dist.topk(2,1)[1]
89 | for ind in range(output_ids.shape[0]):
90 | if self.vocab.word_to_id['X'] == output_ids[ind].item():
91 | output_ids[ind] = output_2_candidates[ind][1]
92 | target_step = []
93 | output_step = []
94 | step_mask = dec_padding_mask[:, di]
95 | for i in range(target.shape[0]):
96 | if target[i].item() >= len(id_to_words) or step_mask[i].item() == 0:
97 | target[i] = 0
98 | target_step.append(id_to_words[target[i].item()])
99 | if output_ids[i].item() >= len(id_to_words) or step_mask[i].item() == 0:
100 | output_ids[i] = 0
101 | output_step.append(id_to_words[output_ids[i].item()])
102 | target_words.append(target_step)
103 | output_words.append(output_step)
104 | step_loss = -torch.log(gold_probs + config.eps)
105 | if config.is_coverage:
106 | #step_coverage_loss = torch.sum(torch.min(attn_dist, coverage), 1)
107 | #step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
108 | #coverage = next_coverage
109 | step_coverage_loss = 0.0
110 | for ind in range(len(coverage_list)):
111 | step_coverage_loss += torch.sum(torch.min(attn_dist_list[ind], coverage_list[ind]), 1)
112 | coverage_list[ind] = next_coverage_list[ind]
113 | step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
114 |
115 | step_mask = dec_padding_mask[:, di]
116 | step_loss = step_loss * step_mask
117 | step_losses.append(step_loss)
118 |
119 | self.write_words(output_words,"output.txt")
120 | self.write_words(target_words,"input.txt")
121 |
122 | sum_step_losses = torch.sum(torch.stack(step_losses, 1), 1)
123 | batch_avg_loss = sum_step_losses / dec_lens_var
124 | loss = torch.mean(batch_avg_loss)
125 |
126 | return loss.item()
127 |
128 | def write_words(self, len_batch_words, output_file):
129 | batch_sentences = ["" for _ in range(len(len_batch_words[0]))]
130 | batch_sentences_done = [False for _ in range(len(len_batch_words[0]))]
131 | for i in range(len(len_batch_words)):
132 | for j in range(len(len_batch_words[i])):
133 | if len_batch_words[i][j] == "[STOP]":
134 | batch_sentences_done[j] = True
135 | if not batch_sentences_done[j]:
136 | batch_sentences[j] += len_batch_words[i][j] + " "
137 | f = open(output_file, "a")
138 | for sentence in batch_sentences:
139 | f.write(sentence.strip() + "\n")
140 | f.close()
141 |
142 |
143 | def run_eval(self):
144 | running_avg_loss, iter = 0, 0
145 | start = time.time()
146 | batch = self.batcher.next_batch()
147 | while batch is not None:
148 | loss = self.eval_one_batch(batch)
149 |
150 | running_avg_loss = calc_running_avg_loss(loss, running_avg_loss, self.summary_writer, iter)
151 | iter += 1
152 |
153 | if iter % 100 == 0:
154 | self.summary_writer.flush()
155 | print_interval = 1
156 | if iter % print_interval == 0:
157 | print('steps %d, seconds for %d batch: %.2f , loss: %f' % (
158 | iter, print_interval, time.time() - start, running_avg_loss))
159 | start = time.time()
160 | batch = self.batcher.next_batch()
161 |
162 |
163 | if __name__ == '__main__':
164 | model_filename = sys.argv[1]
165 | eval_processor = Evaluate(model_filename)
166 | eval_processor.run_eval()
167 |
168 |
169 |
--------------------------------------------------------------------------------
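Note: model.py also reads an optional integer from sys.argv[2]; when it is present, p_enc and p_dec are replaced by (original + value)/101, so a value near 100 pushes almost all of the copy and context mass onto source1. A hypothetical invocation with that extra argument (checkpoint path illustrative):

export PYTHONPATH=`pwd` &&
python training_ptr_gen/eval.py /path/to/model_checkpoint 100
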
/training_ptr_gen/train.py:
--------------------------------------------------------------------------------
1 | from __future__ import unicode_literals, print_function, division
2 |
3 | import os
4 | import time
5 | import argparse
6 |
7 | import tensorflow as tf
8 | import torch
9 | from model import Model
10 | from torch.nn.utils import clip_grad_norm_
11 |
12 | from torch.optim import Adagrad, Adam
13 |
14 | from data_util import config
15 | from data_util.batcher import Batcher
16 | from data_util.data import Vocab
17 | from data_util.utils import calc_running_avg_loss
18 | from train_util import get_input_from_batch, get_output_from_batch
19 |
20 | use_cuda = config.use_gpu and torch.cuda.is_available()
21 |
22 | class Train(object):
23 | def __init__(self):
24 | self.vocab = Vocab(config.vocab_path, config.vocab_size)
25 | self.batcher = Batcher(config.train_data_path, self.vocab, mode='train',
26 | batch_size=config.batch_size, single_pass=False)
27 | time.sleep(15)
28 |
29 | train_dir = os.path.join(config.log_root, 'train_%d' % (int(time.time())))
30 | if not os.path.exists(train_dir):
31 | os.mkdir(train_dir)
32 |
33 | self.model_dir = os.path.join(train_dir, 'model')
34 | if not os.path.exists(self.model_dir):
35 | os.mkdir(self.model_dir)
36 |
37 | self.summary_writer = tf.summary.FileWriter(train_dir)
38 |
39 | def save_model(self, running_avg_loss, iter):
40 | state = {
41 | 'iter': iter,
42 | 'encoder_state_dict': self.model.encoder.state_dict(),
43 | 'decoder_state_dict': self.model.decoder.state_dict(),
44 | 'reduce_state_dict': self.model.reduce_state.state_dict(),
45 | 'optimizer': self.optimizer.state_dict(),
46 | 'current_loss': running_avg_loss
47 | }
48 | model_save_path = os.path.join(self.model_dir, 'model_%d_%d' % (iter, int(time.time())))
49 | torch.save(state, model_save_path)
50 |
51 | def setup_train(self, model_file_path=None):
52 | self.model = Model(model_file_path)
53 |
54 | params = list(self.model.encoder.parameters()) + list(self.model.decoder.parameters()) + \
55 | list(self.model.reduce_state.parameters())
56 | initial_lr = config.lr_coverage if config.is_coverage else config.lr
57 | self.optimizer = Adam(params, lr=initial_lr)#Adagrad(params, lr=initial_lr, initial_accumulator_value=config.adagrad_init_acc)
58 |
59 | start_iter, start_loss = 0, 0
60 |
61 | if model_file_path is not None:
62 | state = torch.load(model_file_path, map_location= lambda storage, location: storage)
63 | start_iter = state['iter']
64 | start_loss = state['current_loss']
65 |
66 | if not config.is_coverage:
67 | self.optimizer.load_state_dict(state['optimizer'])
68 | if use_cuda:
69 | for state in self.optimizer.state.values():
70 | for k, v in state.items():
71 | if torch.is_tensor(v):
72 | state[k] = v.cuda()
73 |
74 | return start_iter, start_loss
75 |
76 | def train_one_batch(self, batch):
77 | enc_batch_list, enc_padding_mask_list, enc_lens_list, enc_batch_extend_vocab_list, extra_zeros_list, c_t_1_list, coverage_list = \
78 | get_input_from_batch(batch, use_cuda)
79 | dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
80 | get_output_from_batch(batch, use_cuda)
81 |
82 | self.optimizer.zero_grad()
83 |
84 | encoder_outputs_list = []
85 | encoder_feature_list = []
86 | s_t_1 = None
87 | s_t_1_0 = None
88 | s_t_1_1 = None
89 | for enc_batch,enc_lens in zip(enc_batch_list, enc_lens_list):
90 | sorted_indices = sorted(range(len(enc_lens)),key=enc_lens.__getitem__)
91 | sorted_indices.reverse()
92 | inverse_sorted_indices = [-1 for _ in range(len(sorted_indices))]
93 | for index,position in enumerate(sorted_indices):
94 | inverse_sorted_indices[position] = index
95 | sorted_enc_batch = torch.index_select(enc_batch, 0, torch.LongTensor(sorted_indices) if not use_cuda else torch.LongTensor(sorted_indices).cuda())
96 | sorted_enc_lens = enc_lens[sorted_indices]
97 | sorted_encoder_outputs, sorted_encoder_feature, sorted_encoder_hidden = self.model.encoder(sorted_enc_batch, sorted_enc_lens)
98 | encoder_outputs = torch.index_select(sorted_encoder_outputs, 0, torch.LongTensor(inverse_sorted_indices) if not use_cuda else torch.LongTensor(inverse_sorted_indices).cuda())
99 | encoder_feature = torch.index_select(sorted_encoder_feature.view(encoder_outputs.shape), 0, torch.LongTensor(inverse_sorted_indices) if not use_cuda else torch.LongTensor(inverse_sorted_indices).cuda()).view(sorted_encoder_feature.shape)
100 | encoder_hidden = tuple([torch.index_select(sorted_encoder_hidden[0], 1, torch.LongTensor(inverse_sorted_indices) if not use_cuda else torch.LongTensor(inverse_sorted_indices).cuda()), torch.index_select(sorted_encoder_hidden[1], 1, torch.LongTensor(inverse_sorted_indices) if not use_cuda else torch.LongTensor(inverse_sorted_indices).cuda())])
101 | #encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(enc_batch, enc_lens)
102 | encoder_outputs_list.append(encoder_outputs)
103 | encoder_feature_list.append(encoder_feature)
104 | if s_t_1 is None:
105 | s_t_1 = self.model.reduce_state(encoder_hidden)
106 | s_t_1_0, s_t_1_1 = s_t_1
107 | else:
108 | s_t_1_new = self.model.reduce_state(encoder_hidden)
109 | s_t_1_0 = s_t_1_0 + s_t_1_new[0]
110 | s_t_1_1 = s_t_1_1 + s_t_1_new[1]
111 | s_t_1 = tuple([s_t_1_0, s_t_1_1])
112 |
113 | #c_t_1_list = [c_t_1]
114 | #coverage_list = [coverage]
115 |
116 | step_losses = []
117 | for di in range(min(max_dec_len, config.max_dec_steps)):
118 | y_t_1 = dec_batch[:, di] # Teacher forcing
119 | final_dist, s_t_1, c_t_1_list, attn_dist_list, p_gen, next_coverage_list = self.model.decoder(y_t_1, s_t_1, encoder_outputs_list, encoder_feature_list, enc_padding_mask_list, c_t_1_list, extra_zeros_list, enc_batch_extend_vocab_list, coverage_list, di)
120 | target = target_batch[:, di]
121 | gold_probs = torch.gather(final_dist, 1, target.unsqueeze(1)).squeeze()
122 | step_loss = -torch.log(gold_probs + config.eps)
123 | if config.is_coverage:
124 | step_coverage_loss = 0.0
125 | for ind in range(len(coverage_list)):
126 | step_coverage_loss += torch.sum(torch.min(attn_dist_list[ind], coverage_list[ind]), 1)
127 | coverage_list[ind] = next_coverage_list[ind]
128 | step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
129 |
130 | step_mask = dec_padding_mask[:, di]
131 | step_loss = step_loss * step_mask
132 | step_losses.append(step_loss)
133 |
134 | sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
135 | batch_avg_loss = sum_losses/dec_lens_var
136 | loss = torch.mean(batch_avg_loss)
137 |
138 | loss.backward()
139 |
140 | self.norm = clip_grad_norm_(self.model.encoder.parameters(), config.max_grad_norm)
141 | clip_grad_norm_(self.model.decoder.parameters(), config.max_grad_norm)
142 | clip_grad_norm_(self.model.reduce_state.parameters(), config.max_grad_norm)
143 |
144 | self.optimizer.step()
145 |
146 | return loss.item()
147 |
148 | def trainIters(self, n_iters, model_file_path=None):
149 | iter, running_avg_loss = self.setup_train(model_file_path)
150 | start = time.time()
151 | while iter < n_iters:
152 | batch = self.batcher.next_batch()
153 | loss = self.train_one_batch(batch)
154 |
155 | running_avg_loss = calc_running_avg_loss(loss, running_avg_loss, self.summary_writer, iter)
156 | iter += 1
157 |
158 | if iter % 100 == 0:
159 | self.summary_writer.flush()
160 | print_interval = 500
161 | if iter % print_interval == 0:
162 | print('steps %d, seconds for %d batch: %.2f , loss: %f' % (iter, print_interval,
163 | time.time() - start, loss))
164 | start = time.time()
165 | if iter % 500 == 0:
166 | self.save_model(running_avg_loss, iter)
167 |
168 | if __name__ == '__main__':
169 | parser = argparse.ArgumentParser(description="Train script")
170 | parser.add_argument("-m",
171 | dest="model_file_path",
172 | required=False,
173 | default=None,
174 | help="Model file for retraining (default: None).")
175 | args = parser.parse_args()
176 |
177 | train_processor = Train()
178 | train_processor.trainIters(config.max_iterations, args.model_file_path)
179 |
--------------------------------------------------------------------------------
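Note: training can be resumed from a saved checkpoint via the -m flag defined above (checkpoint path hypothetical):

export PYTHONPATH=`pwd` &&
python training_ptr_gen/train.py -m /path/to/model_checkpoint
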
/training_ptr_gen/model.py:
--------------------------------------------------------------------------------
1 | from __future__ import unicode_literals, print_function, division
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
7 | from data_util import config
8 | from numpy import random
9 | import sys
10 |
11 | use_cuda = config.use_gpu and torch.cuda.is_available()
12 |
13 | random.seed(123)
14 | torch.manual_seed(123)
15 | if torch.cuda.is_available():
16 | torch.cuda.manual_seed_all(123)
17 |
18 | def init_lstm_wt(lstm):
19 | for names in lstm._all_weights:
20 | for name in names:
21 | if name.startswith('weight_'):
22 | wt = getattr(lstm, name)
23 | wt.data.uniform_(-config.rand_unif_init_mag, config.rand_unif_init_mag)
24 | elif name.startswith('bias_'):
25 | # set forget bias to 1
26 | bias = getattr(lstm, name)
27 | n = bias.size(0)
28 | start, end = n // 4, n // 2
29 | bias.data.fill_(0.)
30 | bias.data[start:end].fill_(1.)
31 |
32 | def init_linear_wt(linear):
33 | linear.weight.data.normal_(std=config.trunc_norm_init_std)
34 | if linear.bias is not None:
35 | linear.bias.data.normal_(std=config.trunc_norm_init_std)
36 |
37 | def init_wt_normal(wt):
38 | wt.data.normal_(std=config.trunc_norm_init_std)
39 |
40 | def init_wt_unif(wt):
41 | wt.data.uniform_(-config.rand_unif_init_mag, config.rand_unif_init_mag)
42 |
43 | class Encoder(nn.Module):
44 | def __init__(self):
45 | super(Encoder, self).__init__()
46 | self.embedding = nn.Embedding(config.vocab_size, config.emb_dim)
47 | init_wt_normal(self.embedding.weight)
48 |
49 | self.lstm = nn.LSTM(config.emb_dim, config.hidden_dim, num_layers=1, batch_first=True, bidirectional=True)
50 | init_lstm_wt(self.lstm)
51 |
52 | self.W_h = nn.Linear(config.hidden_dim * 2, config.hidden_dim * 2, bias=False)
53 |
54 | #seq_lens should be in descending order
55 | def forward(self, input, seq_lens):
56 | embedded = self.embedding(input)
57 |
58 | packed = pack_padded_sequence(embedded, seq_lens, batch_first=True)
59 | output, hidden = self.lstm(packed)
60 |
61 | encoder_outputs, _ = pad_packed_sequence(output, batch_first=True) # h dim = B x t_k x n
62 | encoder_outputs = encoder_outputs.contiguous()
63 |
64 | encoder_feature = encoder_outputs.view(-1, 2*config.hidden_dim) # B * t_k x 2*hidden_dim
65 | encoder_feature = self.W_h(encoder_feature)
66 |
67 | return encoder_outputs, encoder_feature, hidden
68 |
69 | class ReduceState(nn.Module):
70 | def __init__(self):
71 | super(ReduceState, self).__init__()
72 |
73 | self.reduce_h = nn.Linear(config.hidden_dim * 2, config.hidden_dim)
74 | init_linear_wt(self.reduce_h)
75 | self.reduce_c = nn.Linear(config.hidden_dim * 2, config.hidden_dim)
76 | init_linear_wt(self.reduce_c)
77 |
78 | def forward(self, hidden):
79 | h, c = hidden # h, c dim = 2 x b x hidden_dim
80 | h_in = h.transpose(0, 1).contiguous().view(-1, config.hidden_dim * 2)
81 | hidden_reduced_h = F.relu(self.reduce_h(h_in))
82 | c_in = c.transpose(0, 1).contiguous().view(-1, config.hidden_dim * 2)
83 | hidden_reduced_c = F.relu(self.reduce_c(c_in))
84 |
85 | return (hidden_reduced_h.unsqueeze(0), hidden_reduced_c.unsqueeze(0)) # h, c dim = 1 x b x hidden_dim
86 |
87 | class Attention(nn.Module):
88 | def __init__(self):
89 | super(Attention, self).__init__()
90 | # attention
91 | if config.is_coverage:
92 | self.W_c = nn.Linear(1, config.hidden_dim * 2, bias=False)
93 | self.decode_proj = nn.Linear(config.hidden_dim * 2, config.hidden_dim * 2)
94 | self.v = nn.Linear(config.hidden_dim * 2, 1, bias=False)
95 |
96 | def forward(self, s_t_hat, encoder_outputs, encoder_feature, enc_padding_mask, coverage):
97 | b, t_k, n = list(encoder_outputs.size())
98 |
99 | dec_fea = self.decode_proj(s_t_hat) # B x 2*hidden_dim
100 | dec_fea_expanded = dec_fea.unsqueeze(1).expand(b, t_k, n).contiguous() # B x t_k x 2*hidden_dim
101 | dec_fea_expanded = dec_fea_expanded.view(-1, n) # B * t_k x 2*hidden_dim
102 |
103 | att_features = encoder_feature + dec_fea_expanded # B * t_k x 2*hidden_dim
104 | if config.is_coverage:
105 | coverage_input = coverage.view(-1, 1) # B * t_k x 1
106 | coverage_feature = self.W_c(coverage_input) # B * t_k x 2*hidden_dim
107 | att_features = att_features + coverage_feature
108 |
109 | e = F.tanh(att_features) # B * t_k x 2*hidden_dim
110 | scores = self.v(e) # B * t_k x 1
111 | scores = scores.view(-1, t_k) # B x t_k
112 |
113 | attn_dist_ = F.softmax(scores, dim=1)*enc_padding_mask # B x t_k
114 | normalization_factor = attn_dist_.sum(1, keepdim=True)
115 | attn_dist = attn_dist_ / normalization_factor
116 |
117 | attn_dist = attn_dist.unsqueeze(1) # B x 1 x t_k
118 | c_t = torch.bmm(attn_dist, encoder_outputs) # B x 1 x n
119 | c_t = c_t.view(-1, config.hidden_dim * 2) # B x 2*hidden_dim
120 |
121 | attn_dist = attn_dist.view(-1, t_k) # B x t_k
122 |
123 | if config.is_coverage:
124 | coverage = coverage.view(-1, t_k)
125 | coverage = coverage + attn_dist
126 |
127 | return c_t, attn_dist, coverage
128 |
129 | class Decoder(nn.Module):
130 | def __init__(self):
131 | super(Decoder, self).__init__()
132 | self.attention_network = Attention()
133 | # decoder
134 | self.embedding = nn.Embedding(config.vocab_size, config.emb_dim)
135 | init_wt_normal(self.embedding.weight)
136 |
137 | self.x_context = nn.Linear(config.hidden_dim * 2 + config.emb_dim, config.emb_dim)
138 |
139 | self.lstm = nn.LSTM(config.emb_dim, config.hidden_dim, num_layers=1, batch_first=True, bidirectional=False)
140 | init_lstm_wt(self.lstm)
141 |
142 | self.encoder_division = nn.Linear(config.hidden_dim * 4, 1)
143 |
144 | if config.pointer_gen:
145 | self.p_gen_linear = nn.Linear(config.hidden_dim * 4 + config.emb_dim, 1)
146 | self.p_gen_encoder = nn.Linear(config.hidden_dim * 4 + config.emb_dim, 1)
147 |
148 | #p_vocab
149 | self.out1 = nn.Linear(config.hidden_dim * 3, config.hidden_dim)
150 | self.out2 = nn.Linear(config.hidden_dim, config.vocab_size)
151 | init_linear_wt(self.out2)
152 |
153 | def forward(self, y_t_1, s_t_1, encoder_outputs_list, encoder_feature_list, enc_padding_mask_list,
154 | c_t_1_list, extra_zeros_list, enc_batch_extend_vocab_list, coverage_list, step):
155 |
156 | c_t_list = []
157 | attn_dist_list = []
158 | if not self.training and step == 0:
159 | h_decoder, c_decoder = s_t_1
160 | s_t_hat = torch.cat((h_decoder.view(-1, config.hidden_dim),
161 | c_decoder.view(-1, config.hidden_dim)), 1) # B x 2*hidden_dim
162 | for ind in range(len(coverage_list)):
163 | c_t, _, coverage_next = self.attention_network(s_t_hat, encoder_outputs_list[ind], encoder_feature_list[ind], enc_padding_mask_list[ind], coverage_list[ind])
164 | c_t_list.append(c_t)
165 | coverage_list[ind] = coverage_next
166 |
167 | y_t_1_embd = self.embedding(y_t_1)
168 | c_t_1 = None
169 | p_dec = F.sigmoid(self.encoder_division(torch.cat((c_t_1_list[0], c_t_1_list[1]),-1)))
170 | if len(sys.argv) > 2:
171 | p_dec = (p_dec + int(sys.argv[2]))/101
172 | for ind,value in enumerate(c_t_1_list):
173 | if c_t_1 is None:
174 | #c_t_1 = c_t_1_list[ind]
175 | c_t_1 = (p_dec) * c_t_1_list[ind]
176 | else:
177 | #c_t_1 += c_t_1_list[ind]
178 | c_t_1 += (1-p_dec) *c_t_1_list[ind]
179 | x = self.x_context(torch.cat((c_t_1, y_t_1_embd), 1))
180 | lstm_out, s_t = self.lstm(x.unsqueeze(1), s_t_1)
181 |
182 | h_decoder, c_decoder = s_t
183 | s_t_hat = torch.cat((h_decoder.view(-1, config.hidden_dim),
184 | c_decoder.view(-1, config.hidden_dim)), 1) # B x 2*hidden_dim
185 | c_t_list = []
186 | for ind in range(len(coverage_list)):
187 | c_t, attn_dist, coverage_next = self.attention_network(s_t_hat, encoder_outputs_list[ind], encoder_feature_list[ind], enc_padding_mask_list[ind], coverage_list[ind])
188 | c_t_list.append(c_t)
189 | attn_dist_list.append(attn_dist)
190 | if self.training or step > 0:
191 | coverage_list[ind] = coverage_next
192 |
193 | p_gen = None
194 | p_enc = None
195 | if config.pointer_gen:
196 | c_t = None
197 | assert len(c_t_list) > 0
198 | for ind in range(len(c_t_list)):
199 | if c_t is None:
200 | c_t = c_t_list[ind]
201 | else:
202 | c_t += c_t_list[ind]
203 | p_gen_input = torch.cat((c_t, s_t_hat, x), 1) # B x (2*2*hidden_dim + emb_dim)
204 | p_gen = self.p_gen_linear(p_gen_input)
205 | p_gen = F.sigmoid(p_gen)
206 |
207 | p_enc = self.p_gen_encoder(p_gen_input)
208 | p_enc = F.sigmoid(p_enc)
209 | if len(sys.argv) > 2:
210 | p_enc = (p_enc + int(sys.argv[2]))/101
211 |
212 | c_t = None
213 | for ind in range(len(c_t_list)):
214 | if c_t is None:
215 | c_t = c_t_list[ind] #p_enc * c_t_list[ind]
216 | else:
217 | c_t = c_t + c_t_list[ind] #c_t + (1-p_enc) * c_t_list[ind]
218 | output = torch.cat((lstm_out.view(-1, config.hidden_dim), c_t), 1) # B x hidden_dim * 3
219 | output = self.out1(output) # B x hidden_dim
220 |
221 | #output = F.relu(output)
222 |
223 | output = self.out2(output) # B x vocab_size
224 | vocab_dist = F.softmax(output, dim=1)
225 |
226 |
227 | attn_dist_list_ = []
228 | if config.pointer_gen:
229 | vocab_dist_ = p_gen * vocab_dist
230 | for ind,attn_dist in enumerate(attn_dist_list):
231 | if ind == 0:
232 | attn_dist_list_.append(p_enc * (1 - p_gen) * attn_dist)
233 | else:
234 | attn_dist_list_.append((1 - p_enc) * (1 - p_gen) * attn_dist)
235 |
236 | for extra_zeros in extra_zeros_list:
237 | if extra_zeros is not None:
238 | vocab_dist_ = torch.cat([vocab_dist_, extra_zeros], 1)
239 |
240 | for enc_batch_extend_vocab,attn_dist_ in zip(enc_batch_extend_vocab_list, attn_dist_list_):
241 | vocab_dist_ = vocab_dist_.scatter_add(1, enc_batch_extend_vocab, attn_dist_)
242 | final_dist = vocab_dist_
243 | else:
244 | final_dist = vocab_dist
245 |
246 | return final_dist, s_t, c_t_list, attn_dist_list, p_gen, coverage_list
247 |
248 | class Model(object):
249 | def __init__(self, model_file_path=None, is_eval=False):
250 | encoder = Encoder()
251 | decoder = Decoder()
252 | reduce_state = ReduceState()
253 |
254 | # share the embedding between encoder and decoder
255 | decoder.embedding.weight = encoder.embedding.weight
256 | if is_eval:
257 | encoder = encoder.eval()
258 | decoder = decoder.eval()
259 | reduce_state = reduce_state.eval()
260 |
261 | if use_cuda:
262 | encoder = encoder.cuda()
263 | decoder = decoder.cuda()
264 | reduce_state = reduce_state.cuda()
265 |
266 | self.encoder = encoder
267 | self.decoder = decoder
268 | self.reduce_state = reduce_state
269 |
270 | if model_file_path is not None:
271 | state = torch.load(model_file_path, map_location= lambda storage, location: storage)
272 | self.encoder.load_state_dict(state['encoder_state_dict'])
273 | self.decoder.load_state_dict(state['decoder_state_dict'], strict=False)
274 | self.reduce_state.load_state_dict(state['reduce_state_dict'])
275 |
--------------------------------------------------------------------------------
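Note: in pointer-generator mode the decoder above mixes three distributions: p_gen gates generating from the fixed vocabulary versus copying, and p_enc splits the copy mass between the two encoder attentions. A self-contained sketch of that mixture, ignoring the extended-vocabulary scatter (all tensors are random stand-ins):

```python
import torch
import torch.nn.functional as F

B, V, T = 2, 10, 5                            # batch, vocab size, source length
vocab_dist = F.softmax(torch.randn(B, V), dim=1)
attn_1 = F.softmax(torch.randn(B, T), dim=1)  # attention over source1
attn_2 = F.softmax(torch.randn(B, T), dim=1)  # attention over source2
p_gen = torch.sigmoid(torch.randn(B, 1))      # generate vs. copy
p_enc = torch.sigmoid(torch.randn(B, 1))      # copy from source1 vs. source2

gen = p_gen * vocab_dist                      # mass on the fixed vocabulary
copy_1 = (1 - p_gen) * p_enc * attn_1         # mass copied from source1
copy_2 = (1 - p_gen) * (1 - p_enc) * attn_2   # mass copied from source2

# The three parts form one probability distribution: total mass is 1 per example.
total = gen.sum(1) + copy_1.sum(1) + copy_2.sum(1)
assert torch.allclose(total, torch.ones(B))
```
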
/data_util/batcher.py:
--------------------------------------------------------------------------------
1 | #Most of this file is copied from https://github.com/abisee/pointer-generator/blob/master/batcher.py
2 |
3 | import Queue
4 | import time
5 | from random import shuffle
6 | from threading import Thread
7 |
8 | import numpy as np
9 | import tensorflow as tf
10 |
11 | import config
12 | import data
13 |
14 | import random
15 | random.seed(1234)
16 |
17 |
18 | class Example2(object):
19 |
20 | def __init__(self, article, article_extra, abstract_sentences, vocab):
21 | # Get ids of special tokens
22 | start_decoding = vocab.word2id(data.START_DECODING)
23 | stop_decoding = vocab.word2id(data.STOP_DECODING)
24 |
25 | # Process the article
26 | article_words = article.split()
27 | if len(article_words) > config.max_enc_steps:
28 | article_words = article_words[:config.max_enc_steps]
29 | self.enc_len = len(article_words) # store the length after truncation but before padding
30 | self.enc_input = [vocab.word2id(w) for w in article_words] # list of word ids; OOVs are represented by the id for UNK token
31 |
32 | article_extra_words = article_extra.split()
33 | if len(article_extra_words) > config.max_enc_steps:
34 | article_extra_words = article_extra_words[:config.max_enc_steps]
35 | self.enc_len_2 = len(article_extra_words)
36 | self.enc_input_2 = [vocab.word2id(w) for w in article_extra_words]
37 |
38 | # Process the abstract
39 | abstract = ' '.join(abstract_sentences) # string
40 | abstract_words = abstract.split() # list of strings
41 | abs_ids = [vocab.word2id(w) for w in abstract_words] # list of word ids; OOVs are represented by the id for UNK token
42 |
43 | # Get the decoder input sequence and target sequence
44 | self.dec_input, self.target = self.get_dec_inp_targ_seqs(abs_ids, config.max_dec_steps, start_decoding, stop_decoding)
45 | self.dec_len = len(self.dec_input)
46 |
47 | # If using pointer-generator mode, we need to store some extra info
48 | if config.pointer_gen:
49 | # Store a version of the enc_input where in-article OOVs are represented by their temporary OOV id; also store the in-article OOVs words themselves
50 | self.enc_input_extend_vocab, self.article_oovs = data.article2ids(article_words, vocab)
51 | self.enc_input_extend_vocab_2, self.article_oovs_2 = data.article2ids(article_extra_words, vocab)
52 |
53 | # Get a version of the reference summary where in-article OOVs are represented by their temporary article OOV id
54 | abs_ids_extend_vocab = data.abstract2ids(abstract_words, vocab, self.article_oovs+self.article_oovs_2)
55 |
56 | # Overwrite decoder target sequence so it uses the temp article OOV ids
57 | _, self.target = self.get_dec_inp_targ_seqs(abs_ids_extend_vocab, config.max_dec_steps, start_decoding, stop_decoding)
58 |
59 | # Store the original strings
60 | self.original_article = article
61 | self.original_abstract = abstract
62 | self.original_abstract_sents = abstract_sentences
63 |
64 |
65 | def get_dec_inp_targ_seqs(self, sequence, max_len, start_id, stop_id):
66 | inp = [start_id] + sequence[:]
67 | target = sequence[:]
68 | if len(inp) > max_len: # truncate
69 | inp = inp[:max_len]
70 | target = target[:max_len] # no end_token
71 | else: # no truncation
72 | target.append(stop_id) # end token
73 | assert len(inp) == len(target)
74 | return inp, target
75 |
76 |
77 | def pad_decoder_inp_targ(self, max_len, pad_id):
78 | while len(self.dec_input) < max_len:
79 | self.dec_input.append(pad_id)
80 | while len(self.target) < max_len:
81 | self.target.append(pad_id)
82 |
83 |
84 | def pad_encoder_input(self, max_len, pad_id):
85 | while len(self.enc_input) < max_len:
86 | self.enc_input.append(pad_id)
87 | if config.pointer_gen:
88 | while len(self.enc_input_extend_vocab) < max_len:
89 | self.enc_input_extend_vocab.append(pad_id)
90 |
91 |
92 | def pad_encoder_input_2(self, max_len, pad_id):
93 | while len(self.enc_input_2) < max_len:
94 | self.enc_input_2.append(pad_id)
95 | if config.pointer_gen:
96 | while len(self.enc_input_extend_vocab_2) < max_len:
97 | self.enc_input_extend_vocab_2.append(pad_id)
98 |
99 |
100 | class Example(object):
101 |
102 | def __init__(self, article, abstract_sentences, vocab):
103 | # Get ids of special tokens
104 | start_decoding = vocab.word2id(data.START_DECODING)
105 | stop_decoding = vocab.word2id(data.STOP_DECODING)
106 |
107 | # Process the article
108 | article_words = article.split()
109 | if len(article_words) > config.max_enc_steps:
110 | article_words = article_words[:config.max_enc_steps]
111 | self.enc_len = len(article_words) # store the length after truncation but before padding
112 | self.enc_input = [vocab.word2id(w) for w in article_words] # list of word ids; OOVs are represented by the id for UNK token
113 |
114 | # Process the abstract
115 | abstract = ' '.join(abstract_sentences) # string
116 | abstract_words = abstract.split() # list of strings
117 | abs_ids = [vocab.word2id(w) for w in abstract_words] # list of word ids; OOVs are represented by the id for UNK token
118 |
119 | # Get the decoder input sequence and target sequence
120 | self.dec_input, self.target = self.get_dec_inp_targ_seqs(abs_ids, config.max_dec_steps, start_decoding, stop_decoding)
121 | self.dec_len = len(self.dec_input)
122 |
123 | # If using pointer-generator mode, we need to store some extra info
124 | if config.pointer_gen:
125 | # Store a version of the enc_input where in-article OOVs are represented by their temporary OOV id; also store the in-article OOVs words themselves
126 | self.enc_input_extend_vocab, self.article_oovs = data.article2ids(article_words, vocab)
127 |
128 | # Get a version of the reference summary where in-article OOVs are represented by their temporary article OOV id
129 | abs_ids_extend_vocab = data.abstract2ids(abstract_words, vocab, self.article_oovs)
130 |
131 | # Overwrite decoder target sequence so it uses the temp article OOV ids
132 | _, self.target = self.get_dec_inp_targ_seqs(abs_ids_extend_vocab, config.max_dec_steps, start_decoding, stop_decoding)
133 |
134 | # Store the original strings
135 | self.original_article = article
136 | self.original_abstract = abstract
137 | self.original_abstract_sents = abstract_sentences
138 |
139 |
140 | def get_dec_inp_targ_seqs(self, sequence, max_len, start_id, stop_id):
141 | inp = [start_id] + sequence[:]
142 | target = sequence[:]
143 | if len(inp) > max_len: # truncate
144 | inp = inp[:max_len]
145 | target = target[:max_len] # no end_token
146 | else: # no truncation
147 | target.append(stop_id) # end token
148 | assert len(inp) == len(target)
149 | return inp, target
150 |
151 |
152 | def pad_decoder_inp_targ(self, max_len, pad_id):
153 | while len(self.dec_input) < max_len:
154 | self.dec_input.append(pad_id)
155 | while len(self.target) < max_len:
156 | self.target.append(pad_id)
157 |
158 |
159 | def pad_encoder_input(self, max_len, pad_id):
160 | while len(self.enc_input) < max_len:
161 | self.enc_input.append(pad_id)
162 | if config.pointer_gen:
163 | while len(self.enc_input_extend_vocab) < max_len:
164 | self.enc_input_extend_vocab.append(pad_id)
165 |
166 |
167 | class Batch(object):
168 | def __init__(self, example_list, vocab, batch_size):
169 | self.batch_size = batch_size
170 | self.pad_id = vocab.word2id(data.PAD_TOKEN) # id of the PAD token used to pad sequences
171 | self.init_encoder_seq(example_list) # initialize the input to the encoder
172 | self.init_decoder_seq(example_list) # initialize the input and targets for the decoder
173 | self.store_orig_strings(example_list) # store the original strings
174 |
175 |
176 | def init_encoder_seq(self, example_list):
177 | # Determine the maximum length of the encoder input sequence in this batch
178 | max_enc_seq_len = max([ex.enc_len for ex in example_list])
179 | max_enc_seq_len_2 = max([ex.enc_len_2 for ex in example_list])
180 |
181 | # Pad the encoder input sequences up to the length of the longest sequence
182 | for ex in example_list:
183 | ex.pad_encoder_input(max_enc_seq_len, self.pad_id)
184 | ex.pad_encoder_input_2(max_enc_seq_len_2, self.pad_id)
185 |
186 | # Initialize the numpy arrays
187 | # Note: our enc_batch can have different length (second dimension) for each batch because we use dynamic_rnn for the encoder.
188 | self.enc_batch = np.zeros((self.batch_size, max_enc_seq_len), dtype=np.int32)
189 | self.enc_lens = np.zeros((self.batch_size), dtype=np.int32)
190 | self.enc_padding_mask = np.zeros((self.batch_size, max_enc_seq_len), dtype=np.float32)
191 |
192 | self.enc_batch_2 = np.zeros((self.batch_size, max_enc_seq_len_2), dtype=np.int32)
193 | self.enc_lens_2 = np.zeros((self.batch_size), dtype=np.int32)
194 | self.enc_padding_mask_2 = np.zeros((self.batch_size, max_enc_seq_len_2), dtype=np.float32)
195 |
196 | # Fill in the numpy arrays
197 | for i, ex in enumerate(example_list):
198 | self.enc_batch[i, :] = ex.enc_input[:]
199 | self.enc_lens[i] = ex.enc_len
200 | for j in xrange(ex.enc_len):
201 | self.enc_padding_mask[i][j] = 1
202 | self.enc_batch_2[i, :] = ex.enc_input_2[:]
203 | self.enc_lens_2[i] = ex.enc_len_2
204 | for j in xrange(ex.enc_len_2):
205 | self.enc_padding_mask_2[i][j] = 1
206 |
207 | # For pointer-generator mode, need to store some extra info
208 | if config.pointer_gen:
209 | # Determine the max number of in-article OOVs in this batch
210 | self.max_art_oovs = max([len(ex.article_oovs) for ex in example_list])
211 | self.max_art_oovs_2 = max([len(ex.article_oovs_2) for ex in example_list])
212 | # Store the in-article OOVs themselves
213 | self.art_oovs = [ex.article_oovs for ex in example_list]
214 | self.art_oovs_2 = [ex.article_oovs_2 for ex in example_list]
215 | # Store the version of the enc_batch that uses the article OOV ids
216 | self.enc_batch_extend_vocab = np.zeros((self.batch_size, max_enc_seq_len), dtype=np.int32)
217 | self.enc_batch_extend_vocab_2 = np.zeros((self.batch_size, max_enc_seq_len_2), dtype=np.int32)
218 | for i, ex in enumerate(example_list):
219 | self.enc_batch_extend_vocab[i, :] = ex.enc_input_extend_vocab[:]
220 | self.enc_batch_extend_vocab_2[i, :] = ex.enc_input_extend_vocab_2[:]
221 |
222 | def init_decoder_seq(self, example_list):
223 | # Pad the inputs and targets
224 | for ex in example_list:
225 | ex.pad_decoder_inp_targ(config.max_dec_steps, self.pad_id)
226 |
227 | # Initialize the numpy arrays.
228 | self.dec_batch = np.zeros((self.batch_size, config.max_dec_steps), dtype=np.int32)
229 | self.target_batch = np.zeros((self.batch_size, config.max_dec_steps), dtype=np.int32)
230 | self.dec_padding_mask = np.zeros((self.batch_size, config.max_dec_steps), dtype=np.float32)
231 | self.dec_lens = np.zeros((self.batch_size), dtype=np.int32)
232 |
233 | # Fill in the numpy arrays
234 | for i, ex in enumerate(example_list):
235 | self.dec_batch[i, :] = ex.dec_input[:]
236 | self.target_batch[i, :] = ex.target[:]
237 | self.dec_lens[i] = ex.dec_len
238 | for j in xrange(ex.dec_len):
239 | self.dec_padding_mask[i][j] = 1
240 |
241 | def store_orig_strings(self, example_list):
242 | self.original_articles = [ex.original_article for ex in example_list] # list of lists
243 | self.original_abstracts = [ex.original_abstract for ex in example_list] # list of lists
244 | self.original_abstracts_sents = [ex.original_abstract_sents for ex in example_list] # list of list of lists
245 |
246 |
247 | class Batcher(object):
248 | BATCH_QUEUE_MAX = 100 # max number of batches the batch_queue can hold
249 |
250 | def __init__(self, data_path, vocab, mode, batch_size, single_pass):
251 | self._data_path = data_path
252 | self._vocab = vocab
253 | self._single_pass = single_pass
254 | self.mode = mode
255 | self.batch_size = batch_size
256 | # Initialize a queue of Batches waiting to be used, and a queue of Examples waiting to be batched
257 | self._batch_queue = Queue.Queue(self.BATCH_QUEUE_MAX)
258 | self._example_queue = Queue.Queue(self.BATCH_QUEUE_MAX * self.batch_size)
259 |
260 | # Different settings depending on whether we're in single_pass mode or not
261 | if single_pass:
262 | self._num_example_q_threads = 1 # just one thread, so we read through the dataset just once
263 | self._num_batch_q_threads = 1 # just one thread to batch examples
264 | self._bucketing_cache_size = 1 # only load one batch's worth of examples before bucketing; this essentially means no bucketing
265 | self._finished_reading = False # this will tell us when we're finished reading the dataset
266 | else:
267 | self._num_example_q_threads = 1 #16 # num threads to fill example queue
268 | self._num_batch_q_threads = 1 #4 # num threads to fill batch queue
269 | self._bucketing_cache_size = 1 #100 # how many batches-worth of examples to load into cache before bucketing
270 |
271 | # Start the threads that load the queues
272 | self._example_q_threads = []
273 | for _ in xrange(self._num_example_q_threads):
274 | self._example_q_threads.append(Thread(target=self.fill_example_queue))
275 | self._example_q_threads[-1].daemon = True
276 | self._example_q_threads[-1].start()
277 | self._batch_q_threads = []
278 | for _ in xrange(self._num_batch_q_threads):
279 | self._batch_q_threads.append(Thread(target=self.fill_batch_queue))
280 | self._batch_q_threads[-1].daemon = True
281 | self._batch_q_threads[-1].start()
282 |
283 | # Start a thread that watches the other threads and restarts them if they're dead
284 | if not single_pass: # We don't want a watcher in single_pass mode because the threads shouldn't run forever
285 | self._watch_thread = Thread(target=self.watch_threads)
286 | self._watch_thread.daemon = True
287 | self._watch_thread.start()
288 |
289 | def next_batch(self):
290 | # If the batch queue is empty, print a warning
291 | if self._batch_queue.qsize() == 0:
292 | tf.logging.warning('Bucket input queue is empty when calling next_batch. Bucket queue size: %i, Input queue size: %i', self._batch_queue.qsize(), self._example_queue.qsize())
293 | if self._single_pass and self._finished_reading:
294 | tf.logging.info("Finished reading dataset in single_pass mode.")
295 | return None
296 |
297 | batch = self._batch_queue.get() # get the next Batch
298 | return batch
299 |
300 | def fill_example_queue(self):
301 | input_gen = self.text_generator(data.example_generator(self._data_path, self._single_pass))
302 | f = open("inputs.txt", "a")
303 | while True:
304 | try:
305 | (source1, source2, target) = input_gen.next() # read the next example from file; source1, source2 and target are all strings.
306 | f.write(source1+"\t"+source2+"\t"+target+"\n")
307 | except StopIteration: # if there are no more examples:
308 | tf.logging.info("The example generator for this example queue filling thread has exhausted data.")
309 | if self._single_pass:
310 | tf.logging.info("single_pass mode is on, so we've finished reading dataset. This thread is stopping.")
311 | self._finished_reading = True
312 | break
313 | else:
314 | raise Exception("single_pass mode is off but the example generator is out of data; error.")
315 |
316 | abstract_sentences = [sent.strip() for sent in data.abstract2sents(target)] # Use the <s> and </s> tags in abstract to get a list of sentences.
317 | #example = Example(article, abstract_sentences, self._vocab) # Process into an Example.
318 | #example = Example2(article, ' '.join(abstract_sentences), abstract_sentences, self._vocab)
319 | #example = Example2(' '.join(abstract_sentences), article, abstract_sentences, self._vocab)
320 | #example = Example2(article, article, abstract_sentences, self._vocab)
321 | #example = Example2(' '.join(abstract_sentences), ' '.join(abstract_sentences), abstract_sentences, self._vocab)
322 | example = Example2(source1, source2, target.split(), self._vocab)
323 | self._example_queue.put(example) # place the Example in the example queue.
324 | f.close()
325 |
326 | def fill_batch_queue(self):
327 | while True:
328 | if self.mode == 'decode':
329 | # beam search decode mode single example repeated in the batch
330 | ex = self._example_queue.get()
331 | b = [ex for _ in xrange(self.batch_size)]
332 | self._batch_queue.put(Batch(b, self._vocab, self.batch_size))
333 | else:
334 | # Get bucketing_cache_size-many batches of Examples into a list, then sort
335 | inputs = []
336 | for _ in xrange(self.batch_size * self._bucketing_cache_size):
337 | inputs.append(self._example_queue.get())
338 | #inputs = sorted(inputs, key=lambda inp: inp.enc_len, reverse=True) # sort by length of encoder sequence
339 |
340 | # Group the sorted Examples into batches, optionally shuffle the batches, and place in the batch queue.
341 | batches = []
342 | for i in xrange(0, len(inputs), self.batch_size):
343 | batches.append(inputs[i:i + self.batch_size])
344 | if not self._single_pass:
345 | shuffle(batches)
346 | for b in batches: # each b is a list of Example objects
347 | self._batch_queue.put(Batch(b, self._vocab, self.batch_size))
348 |
349 | def watch_threads(self):
350 | while True:
351 | tf.logging.info(
352 | 'Bucket queue size: %i, Input queue size: %i',
353 | self._batch_queue.qsize(), self._example_queue.qsize())
354 |
355 | time.sleep(60)
356 | for idx,t in enumerate(self._example_q_threads):
357 | if not t.is_alive(): # if the thread is dead
358 | tf.logging.error('Found example queue thread dead. Restarting.')
359 | new_t = Thread(target=self.fill_example_queue)
360 | self._example_q_threads[idx] = new_t
361 | new_t.daemon = True
362 | new_t.start()
363 | for idx,t in enumerate(self._batch_q_threads):
364 | if not t.is_alive(): # if the thread is dead
365 | tf.logging.error('Found batch queue thread dead. Restarting.')
366 | new_t = Thread(target=self.fill_batch_queue)
367 | self._batch_q_threads[idx] = new_t
368 | new_t.daemon = True
369 | new_t.start()
370 |
371 |
372 | def text_generator(self, example_generator):
373 | while True:
374 | e = example_generator.next() # e is a tf.Example
375 | try:
376 | article_text = e.features.feature['source1'].bytes_list.value[0] # the source1 text was saved under the key 'source1' in the data files
377 | abstract_text = e.features.feature['target'].bytes_list.value[0] # the target text was saved under the key 'target' in the data files
378 | abstract_text_extra = e.features.feature['source2'].bytes_list.value[0] # the source2 text was saved under the key 'source2'
379 | except ValueError:
380 | tf.logging.error('Failed to get article or abstract from example')
381 | continue
382 | if len(article_text)==0: # See https://github.com/abisee/pointer-generator/issues/1
383 | #tf.logging.warning('Found an example with empty article text. Skipping it.')
384 | continue
385 | else:
386 | yield (article_text, abstract_text_extra, abstract_text)
387 |
--------------------------------------------------------------------------------
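Note: example_generator (in data.py) reads length-prefixed serialized tf.Example records, and text_generator above expects the byte features source1, source2 and target. A minimal sketch of writing one compatible record (opening and closing the binary file is up to the caller):

```python
import struct
from tensorflow.core.example import example_pb2

def write_example(writer, source1, source2, target):
    # Build a tf.Example carrying the three byte features the batcher reads.
    ex = example_pb2.Example()
    ex.features.feature['source1'].bytes_list.value.extend([source1])
    ex.features.feature['source2'].bytes_list.value.extend([source2])
    ex.features.feature['target'].bytes_list.value.extend([target])
    s = ex.SerializeToString()
    writer.write(struct.pack('q', len(s)))        # native 8-byte length prefix, matching struct.unpack('q', ...)
    writer.write(struct.pack('%ds' % len(s), s))  # the serialized record itself
```
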