├── requirements.txt
├── README.md
├── download_data
│   ├── download_film_scripts.py
│   ├── download_gutenberg_titles.py
│   ├── list_gutenberg_titles.py
│   ├── data_gutenberg.py
│   ├── data_film_scripts.py
│   └── download_wikipedia.py
├── embeddings.py
├── optimisers.py
├── main.py
└── enc_dec.py
/requirements.txt:
--------------------------------------------------------------------------------
1 | gensim
2 | theano==0.8.2
3 | nltk
4 | 
5 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Abstractive Text Summarization
2 | 
3 | ## The Algorithm
4 | 
5 | 
6 | Our implementation differs in that the context (full-text) and summary
7 | embedding matrices are kept fixed during training rather than learned.
8 | 
9 | - Both embedding matrices are initialised from GloVe.
10 | 
11 | To list all command-line options:
12 | 
13 | ```bash
14 | python3 main.py -h
15 | ```
16 | 
17 | The training datasets are under `data/`. Each JSON file contains three fields: `title`, `full_text`, and `summary`. They are downloaded with the scripts in `download_data/`.
18 | 
19 | GloVe data needs to be downloaded and unzipped under `glove/`. The code uses the 10k most frequent tokens by default. To extract the vectors for these tokens:
20 | 
21 | ```bash
22 | cd glove
23 | head -n 10000 glove.6B.300d.txt >glove.10k.300d.txt
24 | ```
25 | 
--------------------------------------------------------------------------------
/download_data/download_film_scripts.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import wikipedia
3 | import bs4
4 | import requests
5 | import argparse
6 | import re
7 | import json
8 | import time
9 | import os
10 | from data_film_scripts import downloadFilm
11 | 
12 | def download_film_scripts(titles_file_path, output_path, sleep=0):
13 |     with open(titles_file_path, 'r') as titles_file:
14 |         for line in titles_file:
15 |             m = re.match(r'"(.*)"', line.strip())
16 |             if not m:
17 |                 continue
18 |             try:
19 |                 downloadFilm(m.group(1), output_path)
20 |                 time.sleep(sleep)
21 |             except Exception as e:
22 |                 print(e)
23 |                 continue
24 | 
25 | def main():
26 |     parser = argparse.ArgumentParser()
27 |     parser.add_argument('titles_file',
28 |                         help='file with all the film titles to download')
29 |     parser.add_argument('output_path',
30 |                         help='directory for output')
31 |     parser.add_argument('--sleep', type=int,
32 |                         default=0,
33 |                         help='time to sleep after download of each title (in seconds)')
34 | 
35 |     args = parser.parse_args()
36 |     download_film_scripts(args.titles_file, args.output_path, args.sleep)
37 | 
38 | if __name__ == '__main__':
39 |     main()
40 | 
41 | 
--------------------------------------------------------------------------------
/download_data/download_gutenberg_titles.py:
--------------------------------------------------------------------------------
1 | from __future__ import (division, absolute_import,
2 |                         print_function, unicode_literals)
3 | 
4 | import requests
5 | from bs4 import BeautifulSoup
6 | from gutenberg.acquire import load_etext
7 | from gutenberg.cleanup import strip_headers
8 | from data_film_scripts import downloadSummary
9 | from data_gutenberg import download_book
10 | import time
11 | import json
12 | import argparse
13 | import os
14 | import re
15 | 
16 | def download_books(titles_file_path, output_path, sleep=0):
17 |     with open(titles_file_path, 'r') as titles_file:
18 |         for line in titles_file:
19 |             m = re.match(r'"(.*)",(\d+)', line.strip())
20 |             if not m:
21 |                 continue
22 |             try:
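                # Note: each line of the titles file is expected to look like the
                # output of list_gutenberg_titles.py, i.e. a quoted title followed
                # by a numeric Gutenberg id, for example (illustrative values):
                #   "Pride and Prejudice",1342
                #   "Frankenstein; Or, The Modern Prometheus",84
                # Lines that do not match r'"(.*)",(\d+)' are skipped above.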
23 | download_book(m.group(1), int(m.group(2)), output_path, sleep) 24 | except Exception as e: 25 | print(e) 26 | continue 27 | 28 | def main(): 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument('titles_file', 31 | help='the file containing the Gutenberg titles for download') 32 | parser.add_argument('output_path', 33 | help='output directory') 34 | parser.add_argument('--sleep', type=int, 35 | default=0, 36 | help='time to sleep (in seconds)') 37 | args = parser.parse_args() 38 | 39 | download_books(args.titles_file, args.output_path, args.sleep) 40 | 41 | if __name__ == '__main__': 42 | main() 43 | 44 | -------------------------------------------------------------------------------- /download_data/list_gutenberg_titles.py: -------------------------------------------------------------------------------- 1 | from __future__ import (division, absolute_import, 2 | print_function, unicode_literals) 3 | 4 | import requests 5 | from bs4 import BeautifulSoup 6 | from gutenberg.acquire import load_etext 7 | from gutenberg.cleanup import strip_headers 8 | from data_scripts import downloadSummary 9 | import time 10 | import argparse 11 | 12 | def list_most_popular_titles(n, start_index=1, sleep=0): 13 | INDEX_PAGE = 'https://www.gutenberg.org/ebooks/search/?sort_order=downloads&start_index={:}' 14 | 15 | n_batch = int(n / 25) 16 | for i in range(n_batch): 17 | index_page = INDEX_PAGE.format(i * 25 + start_index) 18 | r = requests.get(index_page) 19 | soup = BeautifulSoup(r.text, 'lxml') 20 | 21 | links = soup.findAll('li', {'class': 'booklink'}) 22 | book_titles = [link.a.find('span', {'class': 'title'}).text.strip() for link in links] 23 | gutenberg_ids = [int(link.a['href'].split('/')[2]) for link in links] 24 | 25 | for title, gutenberg_id in zip(book_titles, gutenberg_ids): 26 | print('"{:}",{:}'.format(title, gutenberg_id)) 27 | 28 | time.sleep(sleep) 29 | 30 | def main(): 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument('n', type=int, 33 | help='number of titles to list') 34 | parser.add_argument('--start-index', type=int, 35 | default=1, 36 | help='start index') 37 | parser.add_argument('--sleep', type=int, 38 | default=0, 39 | help='sleep between page load (in seconds)') 40 | args = parser.parse_args() 41 | 42 | list_most_popular_titles(args.n, args.start_index, args.sleep) 43 | 44 | if __name__ == '__main__': 45 | main() 46 | 47 | 48 | 49 | 50 | 51 | 52 | -------------------------------------------------------------------------------- /download_data/data_gutenberg.py: -------------------------------------------------------------------------------- 1 | from __future__ import (division, absolute_import, 2 | print_function, unicode_literals) 3 | 4 | import requests 5 | from bs4 import BeautifulSoup 6 | from gutenberg.acquire import load_etext 7 | from gutenberg.cleanup import strip_headers 8 | from data_film_scripts import downloadSummary 9 | import time 10 | import json 11 | import argparse 12 | import os 13 | 14 | def download_book(title, gutenberg_id, data_path, sleep=0): 15 | print('downloading {:}'.format(title)) 16 | 17 | full_text = strip_headers(load_etext(gutenberg_id)).strip() 18 | summary = downloadSummary(title) 19 | 20 | if full_text is None: 21 | print('Full text is None. Skipping {:}'.format(title)) 22 | return 23 | if summary is None: 24 | print('Summary is None. 
Skipping {:}'.format(title)) 25 | return 26 | 27 | output_data = {'title': title, 28 | 'full_text': full_text, 29 | 'summary': summary} 30 | 31 | output_file = os.path.join(data_path, 32 | '{:}.json'.format(gutenberg_id)) 33 | with open(output_file, 'w') as f: 34 | json.dump(output_data, f, ensure_ascii=False) 35 | 36 | time.sleep(sleep) 37 | 38 | def main(): 39 | parser = argparse.ArgumentParser() 40 | parser.add_argument('title', 41 | help='book title') 42 | parser.add_argument('gutenberg_id', type=int, 43 | help='Gutenberg book id') 44 | parser.add_argument('data_path', 45 | help='output directory') 46 | parser.add_argument('--sleep', type=int, 47 | default=0, 48 | help='time to sleep (in seconds)') 49 | args = parser.parse_args() 50 | 51 | download_book(args.title, args.gutenberg_id, args.data_path, args.sleep) 52 | 53 | if __name__ == '__main__': 54 | main() 55 | 56 | 57 | 58 | 59 | -------------------------------------------------------------------------------- /download_data/data_film_scripts.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import wikipedia 3 | import bs4 4 | import requests 5 | import argparse 6 | import re 7 | import json 8 | import time 9 | import os 10 | 11 | def findWikipediaPage(film_title): 12 | possibles = (wikipedia.search(film_title) 13 | + wikipedia.search(film_title + ' (film)')) 14 | for m in possibles: 15 | a = re.match(r'(.*) \((\d+ )?film\)', m.lower()) 16 | if (a and 17 | (a.group(1).lower() in film_title.lower() 18 | or film_title.lower() in a.group(1).lower())): 19 | return m 20 | try: 21 | if (possibles[0].lower() in film_title.lower() 22 | or film_title.lower() in possibles[0].lower()): 23 | return possibles[0] 24 | else: 25 | return None 26 | except Exception as e: return None 27 | 28 | def downloadSummary(title): 29 | if title is None: 30 | return None 31 | wikipedia_page = wikipedia.page(title) 32 | content = wikipedia_page.content 33 | 34 | headers = re.findall(r'\n== ([\w\s]+) ==\n', content) 35 | for i in range(len(headers)): 36 | if ('plot' in headers[i].lower() 37 | or 'synopsis' in headers[i].lower() 38 | or 'summary' in headers[i].lower()): 39 | p1 = content.find('\n== {:} ==\n'.format(headers[i])) 40 | if i < len(headers) - 1: 41 | p2 = content.find('\n== {:} ==\n'.format(headers[i + 1])) 42 | summary_ = content[(p1 + 8 + len(headers[i])):p2] 43 | else: 44 | summary_ = content[(p1 + 8 + len(headers[i])):] 45 | # strip the subsection titles in summary_ 46 | return re.sub(r'\n==+ .* ==+\n', '', summary_) 47 | 48 | return None 49 | 50 | def downloadScript(film_title): 51 | r = requests.get('http://www.imsdb.com/scripts/' + film_title.title().replace(" ", "-") + '.html') 52 | soup = bs4.BeautifulSoup(r.text, 'lxml') 53 | try: 54 | return soup.find('td', {'class': 'scrtext'}).text 55 | except Exception as e: 56 | return None 57 | 58 | def downloadFilm(film_title, output_path): 59 | wikipedia_page_title = findWikipediaPage(film_title) 60 | print('downloading {:}'.format(film_title)) 61 | print('Wikipedia page {:}'.format(wikipedia_page_title)) 62 | 63 | script = downloadScript(film_title) 64 | summary = downloadSummary(wikipedia_page_title) 65 | 66 | if script is None or summary is None: 67 | print('Error occurred. 
Skipping {:}'.format(film_title)) 68 | return 69 | 70 | script = script.strip() 71 | summary = summary.strip() 72 | output_json = {'title': film_title, 73 | 'full_text': script, 74 | 'summary': summary} 75 | 76 | filename = film_title.replace(' ', '_') + '.json' 77 | output_file = os.path.join(output_path, filename) 78 | with open(output_file, 'w') as f: 79 | json.dump(output_json, f) 80 | 81 | def main(): 82 | parser = argparse.ArgumentParser() 83 | parser.add_argument('title', 84 | help='film title for download') 85 | parser.add_argument('output_path', 86 | help='path for output') 87 | args = parser.parse_args() 88 | 89 | downloadFilm(args.title, args.output_path) 90 | 91 | if __name__ == "__main__": 92 | main() 93 | -------------------------------------------------------------------------------- /embeddings.py: -------------------------------------------------------------------------------- 1 | from __future__ import (division, absolute_import, 2 | print_function, unicode_literals) 3 | import nltk 4 | from nltk.util import ngrams 5 | import os.path 6 | from nltk.collocations import * 7 | from nltk.tokenize import word_tokenize, wordpunct_tokenize 8 | from gensim.models.word2vec import Word2Vec 9 | from gensim.models import Phrases 10 | import sys 11 | import numpy as np 12 | 13 | 14 | def _get_least_common_word_vector(num_vecs=0, glove_filename="glove/glove.6B.50d.txt"): 15 | if num_vecs == 0: 16 | with open(glove_filename, 'r') as f: 17 | for line in f: pass 18 | vec = line.split(' ')[1:] 19 | return list(map(float, vec)) 20 | else: 21 | total = 0 22 | with open(glove_filename, 'r') as f: 23 | for line in f: total += 1 24 | count = 0 25 | uncommon_vectors = [] 26 | with open(glove_filename, 'r') as f: 27 | for line in f: 28 | if total - count < num_vecs: 29 | uncommon_vectors.append(list(map(float, line.split(' ')[1:]))) 30 | count += 1 31 | return np.mean(np.array(uncommon_vectors), axis=0) 32 | 33 | 34 | def map_sentence_to_glove(sentence, glove_filename="glove/glove.6B.50d.txt", cache={}, default_vec=[]): 35 | vectorized_sentence = [] 36 | indexed_sentence = [] 37 | for word in sentence: 38 | lcw = word.lower() 39 | index = None 40 | if lcw not in cache.keys(): 41 | with open(glove_filename, 'r') as f: 42 | v = None 43 | count = 1 # so we can make 0 UNK 44 | for line in f: 45 | first_word = line.split(' ', 1)[0] 46 | if first_word == lcw: 47 | v = line.split(' ')[1:] 48 | v = map(float, v) 49 | cache[lcw] = v, count 50 | index = count 51 | break 52 | count += 1 53 | if v is None: 54 | v = default_vec 55 | index = count 56 | else: 57 | v, index = cache[lcw] 58 | vectorized_sentence.append(v) 59 | indexed_sentence.append(index) 60 | return vectorized_sentence, cache, indexed_sentence 61 | 62 | 63 | class gloveDocumentParser(object): 64 | def __init__(self, glove_file_name, unk_size=100, pad_token="--PAD--", unk_token="--UNK--"): 65 | self.word_to_vector, self.word_to_id, self.id_to_word, \ 66 | self.vocab_size, vec_length = self.loadGloveFromFile(glove_file_name) 67 | padding_vector = [0] * vec_length 68 | PAD = pad_token 69 | UNK = unk_token 70 | 71 | self.unk = UNK 72 | self.pad = PAD 73 | self.word_to_vector[UNK] = _get_least_common_word_vector(unk_size, glove_filename=glove_file_name) 74 | self.word_to_id[UNK] = 0 75 | self.id_to_word[0] = UNK 76 | 77 | self.word_to_vector[PAD] = padding_vector 78 | self.word_to_id[PAD] = 1 79 | self.id_to_word[1] = PAD 80 | self.word_to_vector_matrix = [] 81 | for i in range(self.vocab_size): 82 | 
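            # Row i of word_to_vector_matrix holds the vector for token id i, so
            # the matrix lines up with word_to_id / id_to_word: row 0 is the UNK
            # vector (a mean over the least-common GloVe vectors), row 1 is the
            # all-zero PAD vector, and rows 2.. follow the order of the GloVe file.
            # enc_dec.init_params later uses this matrix to initialise the
            # Xemb / Yemb embedding parameters.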
self.word_to_vector_matrix.append(self.word_to_vector[self.id_to_word[i]]) 83 | self.word_to_vector_matrix = np.array(self.word_to_vector_matrix) 84 | 85 | self.embedding_n_tokens = self.word_to_vector_matrix.shape[0] 86 | self.token_dim = self.word_to_vector_matrix.shape[1] 87 | 88 | def loadGloveFromFile(self, glove_file_name): 89 | word_to_vec = {} 90 | word_to_id = {} 91 | id_to_word = {} 92 | index = 2 93 | with open(glove_file_name, 'r') as f: 94 | for line in f: 95 | split_line = line.split(' ') 96 | word_key = split_line[0] 97 | vector = list(map(float, split_line[1:])) 98 | vector_dim = len(vector) 99 | word_to_vec[word_key] = vector 100 | word_to_id[word_key] = index 101 | id_to_word[index] = word_key 102 | index += 1 103 | return word_to_vec, word_to_id, id_to_word, index, vector_dim 104 | 105 | def parseDocument(self, document): 106 | tokens = wordpunct_tokenize(document) 107 | index_matrix = [self.word_to_id[token.lower()] if token.lower() in self.word_to_id 108 | else self.word_to_id[self.unk] for token in tokens] 109 | return index_matrix 110 | 111 | def documentFromVector(self, id_vector): 112 | doc = [self.id_to_word[_id] for _id in id_vector] 113 | return doc 114 | 115 | 116 | def rouge_score(reference, hypothesis, n): 117 | ref_ngrams = list(ngrams(wordpunct_tokenize(reference), n)) 118 | hyp_ngrams = list(ngrams(wordpunct_tokenize(hypothesis), n)) 119 | matching_ngrams = [x for x in hyp_ngrams if x in ref_ngrams] 120 | return 1.0 * len(matching_ngrams) / len(ref_ngrams) 121 | 122 | 123 | def rouge_test(): 124 | reference = "The book was very good" 125 | hyp = "The book was very interesting" 126 | print(rouge_score(reference, hyp, 1)) 127 | 128 | 129 | def doc_parser_test(): 130 | g = gloveDocumentParser("glove.6B.50d.txt") 131 | print(g.parseDocument("This is a sentence. This is another sentence. What can you do about it, glove?")) 132 | print(g.documentFromVector([39, 16, 9, 2424, 4])) 133 | 134 | 135 | if __name__ == "__main__": 136 | # doc_parser_test() 137 | rouge_test() 138 | -------------------------------------------------------------------------------- /optimisers.py: -------------------------------------------------------------------------------- 1 | from __future__ import (division, absolute_import, 2 | print_function, unicode_literals) 3 | 4 | import numpy 5 | import theano 6 | import theano.tensor as tensor 7 | 8 | profile=False 9 | 10 | 11 | def itemlist(tparams): 12 | return [v for k, v in tparams.items()] 13 | 14 | 15 | def adam(lr, tparams, grads, inp, cost): 16 | gshared = [theano.shared(p.get_value() * 0., 17 | name='%s_grad' % k) 18 | for k, p in tparams.items()] 19 | gsup = [(gs, g) for gs, g in zip(gshared, grads)] 20 | 21 | f_grad_shared = theano.function(inp, cost, 22 | updates=gsup, profile=profile, 23 | allow_input_downcast=True) 24 | 25 | lr0 = 0.0002 26 | b1 = 0.1 27 | b2 = 0.001 28 | e = 1e-8 29 | 30 | updates = [] 31 | 32 | i = theano.shared(numpy.float32(0.)) 33 | i_t = i + 1. 34 | fix1 = 1. - b1**(i_t) 35 | fix2 = 1. - b2**(i_t) 36 | lr_t = lr0 * (tensor.sqrt(fix2) / fix1) 37 | 38 | for p, g in zip(tparams.values(), gshared): 39 | m = theano.shared(p.get_value() * 0.) 40 | v = theano.shared(p.get_value() * 0.) 41 | m_t = (b1 * g) + ((1. - b1) * m) 42 | v_t = (b2 * tensor.sqr(g)) + ((1. 
- b2) * v) 43 | g_t = m_t / (tensor.sqrt(v_t) + e) 44 | p_t = p - (lr_t * g_t) 45 | updates.append((m, m_t)) 46 | updates.append((v, v_t)) 47 | updates.append((p, p_t)) 48 | updates.append((i, i_t)) 49 | 50 | f_update = theano.function([lr], [], updates=updates, 51 | on_unused_input='ignore', profile=profile, 52 | allow_input_downcast=True) 53 | 54 | return f_grad_shared, f_update 55 | 56 | 57 | def adadelta(lr, tparams, grads, inp, cost): 58 | zipped_grads = [theano.shared(p.get_value() * numpy.float32(0.), 59 | name='%s_grad' % k) 60 | for k, p in tparams.items()] 61 | running_up2 = [theano.shared(p.get_value() * numpy.float32(0.), 62 | name='%s_rup2' % k) 63 | for k, p in tparams.items()] 64 | running_grads2 = [theano.shared(p.get_value() * numpy.float32(0.), 65 | name='%s_rgrad2' % k) 66 | for k, p in tparams.items()] 67 | 68 | zgup = [(zg, g) for zg, g in zip(zipped_grads, grads)] 69 | rg2up = [(rg2, 0.95 * rg2 + 0.05 * (g ** 2)) 70 | for rg2, g in zip(running_grads2, grads)] 71 | 72 | f_grad_shared = theano.function(inp, cost, updates=zgup+rg2up, 73 | profile=profile, 74 | allow_input_downcast=True) 75 | 76 | updir = [-tensor.sqrt(ru2 + 1e-6) / tensor.sqrt(rg2 + 1e-6) * zg 77 | for zg, ru2, rg2 in 78 | zip(zipped_grads, running_up2, running_grads2)] 79 | ru2up = [(ru2, 0.95 * ru2 + 0.05 * (ud ** 2)) 80 | for ru2, ud in zip(running_up2, updir)] 81 | param_up = [(p, p + ud) for p, ud in zip(itemlist(tparams), updir)] 82 | 83 | f_update = theano.function([lr], [], updates=ru2up+param_up, 84 | on_unused_input='ignore', profile=profile, 85 | allow_input_downcast=True) 86 | 87 | return f_grad_shared, f_update 88 | 89 | 90 | def rmsprop(lr, tparams, grads, inp, cost): 91 | zipped_grads = [theano.shared(p.get_value() * numpy.float32(0.), 92 | name='%s_grad' % k) 93 | for k, p in tparams.items()] 94 | running_grads = [theano.shared(p.get_value() * numpy.float32(0.), 95 | name='%s_rgrad' % k) 96 | for k, p in tparams.items()] 97 | running_grads2 = [theano.shared(p.get_value() * numpy.float32(0.), 98 | name='%s_rgrad2' % k) 99 | for k, p in tparams.items()] 100 | 101 | zgup = [(zg, g) for zg, g in zip(zipped_grads, grads)] 102 | rgup = [(rg, 0.95 * rg + 0.05 * g) for rg, g in zip(running_grads, grads)] 103 | rg2up = [(rg2, 0.95 * rg2 + 0.05 * (g ** 2)) 104 | for rg2, g in zip(running_grads2, grads)] 105 | 106 | f_grad_shared = theano.function(inp, cost, updates=zgup+rgup+rg2up, 107 | profile=profile, 108 | allow_input_downcast=True) 109 | 110 | updir = [theano.shared(p.get_value() * numpy.float32(0.), 111 | name='%s_updir' % k) 112 | for k, p in tparams.iteritems()] 113 | updir_new = [(ud, 0.9 * ud - 1e-4 * zg / tensor.sqrt(rg2 - rg ** 2 + 1e-4)) 114 | for ud, zg, rg, rg2 in zip(updir, zipped_grads, running_grads, 115 | running_grads2)] 116 | param_up = [(p, p + udn[1]) 117 | for p, udn in zip(itemlist(tparams), updir_new)] 118 | f_update = theano.function([lr], [], updates=updir_new+param_up, 119 | on_unused_input='ignore', profile=profile, 120 | allow_input_downcast=True) 121 | 122 | return f_grad_shared, f_update 123 | 124 | 125 | def sgd(lr, tparams, grads, x, mask, y, cost): 126 | gshared = [theano.shared(p.get_value() * 0., name='%s_grad' % k) 127 | for k, p in tparams.items()] 128 | gsup = [(gs, g) for gs, g in zip(gshared, grads)] 129 | 130 | f_grad_shared = theano.function([x, mask, y], cost, updates=gsup, 131 | profile=profile, 132 | allow_input_downcast=True) 133 | 134 | pup = [(p, p - lr * g) for p, g in zip(itemlist(tparams), gshared)] 135 | f_update = theano.function([lr], [], 
updates=pup, profile=profile, 136 | allow_input_downcast=True) 137 | 138 | return f_grad_shared, f_update 139 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from __future__ import (division, absolute_import, 2 | print_function, unicode_literals) 3 | import os 4 | import sys 5 | import argparse 6 | import logging 7 | from enc_dec import train 8 | 9 | 10 | def main(): 11 | parser = argparse.ArgumentParser( 12 | description='Abstractive sentence summariser' 13 | ) 14 | # parser.add_argument('--log', 15 | # default='DEBUG', 16 | # choices=['DEBUG', 'WARNING', 'ERROR'], 17 | # help='log level for Python logging') 18 | parser.add_argument('--context-encoder', 19 | default='attention', 20 | choices=['baseline', 'attention'], 21 | help='context encoder name') 22 | 23 | # parser.add_argument('--corpus', required=True, 24 | # help='directory of the corpus e.g data/wikipedia/') 25 | 26 | parser.add_argument('--corpus', 27 | default='./data/wikipedia', 28 | help='directory of the corpus e.g data/wikipedia/') 29 | 30 | # optimiser 31 | parser.add_argument('--optimizer', 32 | default='adam', 33 | choices=['adam', 'adadelta', 'rmsprop', 'sgd'], 34 | help='optimizing algorithm') 35 | parser.add_argument('--learning-rate', type=float, 36 | default=0.001, 37 | help='learning rate for the optimizer') 38 | 39 | # model params 40 | parser.add_argument('--embed-full-text-by', 41 | choices=['word', 'sentence'], 42 | default='word', 43 | help='embed full text by word or sentence') 44 | parser.add_argument('--seq-maxlen', type=int, 45 | default=500, 46 | help='max length of input full text') 47 | parser.add_argument('--summary-maxlen', type=int, 48 | default=200, 49 | help='max length of each summary') 50 | parser.add_argument('--summary-context-length', type=int, 51 | default=5, 52 | help='summary context length used for training') 53 | parser.add_argument('--internal-representation-dim', type=int, 54 | default=2000, 55 | help='internal representation dimension') 56 | parser.add_argument('--attention-weight-max-roll', type=int, 57 | default=1, 58 | help='max roll for the attention weight vector in attention encoder') 59 | 60 | # training params 61 | parser.add_argument('--l2-penalty-coeff', type=float, 62 | default=0.00, 63 | help='penalty coefficient to the L2-norms of the model params') 64 | parser.add_argument('--train-split', type=float, 65 | default=0.75, 66 | help='weight of training corpus in the entire corpus, the rest for validation') 67 | parser.add_argument('--epochs', type=int, 68 | default=10000, 69 | help='number of epochs for training') 70 | parser.add_argument('--minibatch-size', type=int, 71 | default=20, 72 | help='mini batch size') 73 | parser.add_argument('--seed', type=int, 74 | default=None, 75 | help='seed for the random stream') 76 | parser.add_argument('--dropout-rate', type=float, 77 | default=None, 78 | help='dropout rate in (0,1)') 79 | 80 | # model load/save 81 | parser.add_argument('--save-params', 82 | default='ass_params.pkl', 83 | help='file for saving params') 84 | parser.add_argument('--save-params-every', type=int, 85 | default=5, 86 | help='save params every epochs') 87 | parser.add_argument('--validate-every', type=int, 88 | default=5, 89 | help='validate every epochs') 90 | parser.add_argument('--print-every', type=int, 91 | default=5, 92 | help='print info every batches') 93 | 94 | # summary generation on the validation set 95 | 
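    # Example invocation (hypothetical paths and values, using only the flags
    # defined in this parser; --generate-summary and the beam size below control
    # summary generation during validation):
    #
    #   python3 main.py --corpus ./data/wikipedia --context-encoder attention \
    #       --optimizer adam --learning-rate 0.001 --minibatch-size 20 \
    #       --epochs 100 --generate-summary --summary-search-beam-size 2
    #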
parser.add_argument('--generate-summary', 96 | action='store_true', 97 | default=False, 98 | help='whether to generate summaries when validating') 99 | parser.add_argument('--summary-search-beam-size', type=int, 100 | default=2, 101 | help='beam size for the summary search') 102 | 103 | args = parser.parse_args() 104 | 105 | # logging.basicConfig(level=args.log.upper()) 106 | assert args.learning_rate > 0 107 | assert args.seq_maxlen > 0 108 | assert args.summary_maxlen > 0 109 | assert args.summary_context_length > 0 110 | assert args.internal_representation_dim > 0 111 | assert args.attention_weight_max_roll >= 0 112 | assert args.l2_penalty_coeff >= 0 113 | assert (args.train_split > 0 and args.train_split <= 1) 114 | assert args.epochs >= 0 115 | assert args.minibatch_size > 0 116 | assert (args.seed is None or args.seed >= 0) 117 | assert (args.dropout_rate is None 118 | or (args.dropout_rate > 0 and args.dropout_rate < 1)) 119 | assert args.save_params_every > 0 120 | assert args.validate_every > 0 121 | assert args.print_every > 0 122 | assert args.summary_search_beam_size > 0 123 | 124 | args_dict = vars(args) 125 | print('Args', args_dict) 126 | train(**args_dict) 127 | 128 | # train( 129 | # model=args.model, 130 | # corpus=args.corpus, 131 | # optimizer=args.optimizer, 132 | # learning_rate=args.learning_rate, 133 | # embed_full_text_by=args.embed_full_text_by, 134 | 135 | # summary_maxlen=args.summary_maxlen, 136 | 137 | # summary_context_length=args.summary_context_length, 138 | # l2_penalty_coeff=args.L2_penalty_coeff, 139 | # minibatch_size=args.minibatch_size, 140 | # epochs=args.epochs, 141 | # train_split=args.train_split, 142 | # seed=args.seed, 143 | # dropout_rate=args.dropout_rate, 144 | # internal_representation_dim=args.internal_representation_dim, 145 | # attention_weight_max_roll=args.attention_weight_max_roll, 146 | # load_params=args.load_params, 147 | # save_params=args.save_params, 148 | # save_params_every=args.save_params_every, 149 | # validate_every=args.validate_every, 150 | # generate_summary=args.generate_summary, 151 | # summary_search_beam_size=args.summary_search_beam_size 152 | # ) 153 | 154 | 155 | if __name__ == '__main__': 156 | main() 157 | -------------------------------------------------------------------------------- /download_data/download_wikipedia.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, unicode_literals 2 | from BeautifulSoup import BeautifulSoup 3 | import urllib 4 | import urllib2 5 | import nltk 6 | import re 7 | import argparse 8 | import os 9 | import json 10 | import codecs 11 | from Queue import Queue 12 | import wikipedia 13 | 14 | def main(): 15 | #simple_test() 16 | #article_list_test() 17 | #w = WikipediaDownloader() 18 | #w.downloadArticleList() 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('page', help='wikipedia page title') 21 | parser.add_argument('data_path', help='output directory') 22 | parser.add_argument('--sleep', type=int, default=0, help='time to sleep in seconds') 23 | args = parser.parse_args() 24 | WikipediaDownloader().downloadArticle(args.page, args.data_path) 25 | 26 | def simple_test(): 27 | article_list = ["Chin", "Albert_Einstein", "America", "Utopia", "adsfasdfasdf"] 28 | for a in article_list: 29 | s = download_articles(a) 30 | if s is not None: 31 | print( s[0] ) 32 | print( s[1] ) 33 | 34 | class WikipediaDownloader(object): 35 | def __init__(self): 36 | self.article_list = [] 37 | 38 | def 
downloadArticleList(self, num_articles): 39 | seed_article = "Albert Einstein" 40 | self.article_list.append(seed_article) 41 | wikipedia.set_lang('simple') 42 | 43 | index = 0 44 | 45 | while len(self.article_list) < num_articles: 46 | if index >= len(self.article_list): break 47 | article = self.article_list[index] 48 | try: 49 | wp = wikipedia.page(article) 50 | except wikipedia.exceptions.PageError, e: 51 | index += 1 52 | continue 53 | except wikipedia.exceptions.DisambiguationError, e: 54 | index += 1 55 | continue 56 | links = wp.links 57 | self.article_list.extend(links) 58 | index += 1 59 | 60 | print ( self.article_list ) 61 | with open("all.subjects", 'w') as f: 62 | for s in self.article_list: 63 | try: 64 | print(s, file=f) 65 | except UnicodeEncodeError, e: 66 | pass 67 | f.flush() 68 | 69 | def downloadArticle(self, article_title, data_path): 70 | output_file = os.path.join(data_path, '{:}.json'.format(article_title)) 71 | if os.path.isfile( output_file ): 72 | print(article_title + "exists, skipping") 73 | return 74 | try: 75 | wikipedia.set_lang('simple') 76 | summary = wikipedia.summary(article_title, sentences=1) 77 | wikipedia.set_lang('en') 78 | text = wikipedia.summary(article_title) 79 | except wikipedia.exceptions.PageError, e: 80 | return 81 | summary = remove_brackets(summary) 82 | text = remove_brackets(text) 83 | if not suitable_for_training(text, summary): return 84 | 85 | output = {'title': article_title, 'full_text': text, 'summary': summary} 86 | with codecs.open('_{}'.format(output_file), 'w', encoding="utf-8") as outfile: json.dump(output, f, ensure_ascii=False) 87 | #with open(output_file, 'w') as f: json.dump(unicode(output), f, ensure_ascii=False) 88 | 89 | # define whether or not we want to put a pair of 90 | # sentences into our training corpus 91 | def suitable_for_training(normal, simple): 92 | if len(normal) <= len(simple): 93 | return False 94 | if "|" in normal or "|" in simple: # we use these to separate sentences in our files, so we don't want them in our sentences 95 | return False 96 | 97 | # some phrases that indicate we've hit a disambiguation page... skipping these for now 98 | # need to make this into a list and just check the list... 99 | a = "For other uses" 100 | b = "usually refers to" 101 | c = "may refer to" 102 | d = "See List of" 103 | e = "Index of" 104 | f = "oordinates" 105 | g = '.' 106 | if a in normal or a in simple: return False 107 | if b in normal or b in simple: return False 108 | if c in normal or c in simple: return False 109 | if d in normal or d in simple: return False 110 | if e in normal or e in simple: return False 111 | if f in normal or f in simple: return False 112 | if g not in normal or g not in simple: return False 113 | return True 114 | 115 | def download_articles(article_title): 116 | article_title = article_title.replace(" ", "_") 117 | title = urllib.quote(article_title) 118 | opener = urllib2.build_opener() 119 | opener.addheaders = [('User-agent', 'Mozilla/5.0')] 120 | try: 121 | normal_sentence, _ = download_wikipedia(opener, title) 122 | simple_sentence, next_pages = download_wikipedia(opener, title, subwiki="simple") 123 | if normal_sentence is None or simple_sentence is None: return None, None 124 | except urllib2.HTTPError, e: 125 | print("No article found for " + article_title + "on one of the wikipedias... 
returning none") 126 | return None, None 127 | return (normal_sentence, simple_sentence), next_pages 128 | 129 | def download_wikipedia(opener, article, subwiki='en'): 130 | r = opener.open("http://" + subwiki + ".wikipedia.org/wiki/" + article) 131 | data = r.read() 132 | r.close() 133 | soup = BeautifulSoup(data) 134 | text = soup.find('div', id='bodyContent').p.getText(separator=u' ') 135 | 136 | links = soup.findAll('a', href=True) 137 | possible_pages = process_links(links, subwiki) 138 | 139 | # we use this to pick sentences out of our text... hard to do 140 | # without a full natural language parser... 141 | sent_detector = nltk.data.load('tokenizers/punkt/english.pickle') 142 | try: first_sentence = sent_detector.tokenize(text)[0] 143 | except IndexError, e: return None, None 144 | 145 | # then we remove anything between brackets... 146 | processed_first_sentence = remove_brackets(first_sentence) 147 | processed_first_sentence = processed_first_sentence.replace(" ", " ") 148 | processed_first_sentence = processed_first_sentence.replace(" ", " ") 149 | return processed_first_sentence, possible_pages 150 | 151 | def process_links(links, subwiki): 152 | skip_if_contains = ['wikipedia.org', 'wikimedia.org', 'Wikipedia', 'Portal', 'Privacy_policy', 153 | 'Category', 'Special', 'Main_Page', 'Help', 'File', 'Terms_of_Use', 'Talk', 'Template', 154 | '#sitelinks-wikipedia'] 155 | possible_next_pages = [] 156 | for link in links: 157 | #print( link, link['href'] ) 158 | href = link['href'] 159 | cont = False 160 | for s in skip_if_contains: 161 | if s in href: 162 | cont = True 163 | break 164 | if cont: continue 165 | article_sub_url = href.split('/wiki/') 166 | if len(article_sub_url) > 1: 167 | article_title = article_sub_url[-1] 168 | if len(article_title) <= 1: continue 169 | if article_title[0] == 'Q': continue 170 | possible_next_pages.append(article_title) 171 | 172 | return possible_next_pages 173 | 174 | def remove_brackets(text, brackets="()[]"): 175 | count = [0] * (len(brackets) // 2) # count open/close brackets 176 | saved_chars = [] 177 | for character in text: 178 | for i, b in enumerate(brackets): 179 | if character == b: # found bracket 180 | kind, is_close = divmod(i, 2) 181 | count[kind] += (-1)**is_close # `+1`: open, `-1`: close 182 | if count[kind] < 0: 183 | count[kind] = 0 184 | break 185 | else: 186 | if not any(count): 187 | saved_chars.append(character) 188 | return ''.join(saved_chars) 189 | 190 | 191 | if __name__=="__main__": 192 | main() 193 | 194 | 195 | -------------------------------------------------------------------------------- /enc_dec.py: -------------------------------------------------------------------------------- 1 | from __future__ import (division, absolute_import, 2 | print_function, unicode_literals) 3 | import argparse 4 | import logging 5 | import sys 6 | import warnings 7 | 8 | import os 9 | import glob 10 | import json 11 | import pickle 12 | import heapq 13 | import numpy as np 14 | from random import sample 15 | import theano 16 | import theano.tensor as T 17 | from theano.tensor.shared_randomstreams import RandomStreams 18 | from optimisers import adam, adadelta, rmsprop, sgd 19 | from embeddings import gloveDocumentParser 20 | from sklearn.model_selection import train_test_split 21 | 22 | warnings.filterwarnings("ignore") 23 | 24 | EPSILON_FOR_LOG = 1e-8 25 | 26 | 27 | def get_encoder(context_encoder): 28 | def baseline_encoder(x, y, x_mask, y_pos, params, tparams): 29 | ''' baseline context encoder given one piece of text 30 | 31 | 
Returns ctx each row for a training instance 32 | ''' 33 | if x.ndim == 1: 34 | mb_size = 1 35 | elif x.ndim == 2: 36 | mb_size = x.shape[0] 37 | seq_len = params['seq_maxlen'] 38 | wv_size = params['full_text_word_vector_size'] 39 | 40 | x_emb = tparams['Xemb'][x.flatten(), :] 41 | x_emb_masked = T.batched_dot(x_emb, x_mask.flatten()) 42 | 43 | if x.ndim == 1: 44 | ctx = x_emb_masked.sum(axis=0) / x_mask.sum() 45 | elif x.ndim == 2: 46 | ctx = T.batched_dot( 47 | x_emb_masked.reshape((mb_size, seq_len, wv_size)).sum(axis=1), 48 | 1 / x_mask.sum(axis=1) 49 | ) 50 | 51 | return T.cast(ctx, theano.config.floatX) 52 | 53 | def attention_encoder(x, y, x_mask, y_pos, params, tparams): 54 | ''' attention-based context encoder given one piece of text 55 | 56 | ''' 57 | if x.ndim == 1: 58 | mb_size = 1 59 | else: 60 | mb_size = params['minibatch_size'] 61 | l = params['seq_maxlen'] 62 | C = params['summary_context_length'] 63 | Q = params['attention_weight_max_roll'] 64 | wv_size_x = params['full_text_word_vector_size'] 65 | wv_size_y = params['summary_word_vector_size'] 66 | P = tparams['att_P'] 67 | m = tparams['att_P_conv'] 68 | 69 | if x.ndim == 1: 70 | x_emb = tparams['Xemb'][x, :] 71 | y_emb = tparams['Yemb'][y[(y_pos - C):y_pos], :] 72 | p = T.nnet.softmax( 73 | T.dot(x_emb, T.dot(P, y_emb.flatten())) 74 | ).flatten() 75 | p_masked = p * x_mask 76 | p_masked_n = p_masked / p_masked.sum() 77 | ctx = T.batched_dot(x_emb, T.dot(m, p_masked_n)).sum(axis=0) 78 | 79 | elif x.ndim == 2: 80 | x_emb = tparams['Xemb'][x.flatten(), :] 81 | x_emb = x_emb.reshape((mb_size, l, wv_size_x)) 82 | y_emb = tparams['Yemb'][y[:, (y_pos - C):y_pos].flatten(), :] 83 | y_emb = y_emb.flatten().reshape((mb_size, C * wv_size_y)).T 84 | p = T.nnet.softmax( 85 | T.batched_dot(x_emb, T.dot(P, y_emb).T) 86 | ) 87 | p_masked = p * x_mask 88 | p_masked_n = p_masked / p_masked.norm(1, axis=1).reshape((mb_size, 1)) 89 | ctx = T.batched_dot(T.dot(m, p_masked_n.T).T, x_emb) 90 | 91 | return T.cast(ctx, theano.config.floatX) 92 | 93 | if context_encoder == 'baseline': 94 | return baseline_encoder 95 | elif context_encoder == 'attention': 96 | return attention_encoder 97 | else: 98 | raise ValueError('Invalide context encoder {:}'.format(context_encoder)) 99 | 100 | 101 | def conditional_distribution(x, y, x_mask, y_pos, params, tparams): 102 | ''' Return the conditional distribution of next summary word index 103 | 104 | Given the input text tensor and summary tensor, returns the distribution for the next summary word index 105 | ''' 106 | enc = get_encoder(params['context_encoder']) 107 | C = params['summary_context_length'] 108 | wv_size = params['summary_word_vector_size'] 109 | 110 | if x.ndim == 1: 111 | y_emb = tparams['Yemb'][y[(y_pos - C):y_pos].flatten(), :].flatten() 112 | h = T.tanh(T.dot(tparams['U'], y_emb) + tparams['b']).flatten() 113 | ctx = enc(x, y, x_mask, y_pos, params, tparams) 114 | 115 | u = T.dot(tparams['V'], h) + T.dot(tparams['W'], ctx) 116 | y_next = T.nnet.softmax(u).flatten() 117 | 118 | elif x.ndim == 2: 119 | mb_size = x.shape[0] 120 | y_emb = tparams['Yemb'][y[:, (y_pos - C):y_pos].flatten(), :] 121 | # each column for a training instance 122 | y_emb = y_emb.flatten().reshape((mb_size, C * wv_size)).T 123 | # each row for a training instance 124 | # (in order to broadcast the vector b along the row axis) 125 | h = T.tanh((T.dot(tparams['U'], y_emb)).T + tparams['b']) 126 | # each row for a training instance 127 | ctx = enc(x, y, x_mask, y_pos, params, tparams) 128 | 129 | # each column for a 
training instance 130 | u = T.dot(tparams['V'], h.T) + T.dot(tparams['W'], ctx.T) 131 | # softmax works row-wise 132 | y_next = T.nnet.softmax(u.T) 133 | 134 | return y_next 135 | 136 | 137 | def conditional_score(x, y, x_mask, y_pos, params, tparams): 138 | ''' Return conditional score of the (j+1)-th word index of the summary i.e. y[j] 139 | 140 | ''' 141 | dist = conditional_distribution(x, y, x_mask, y_pos, params, tparams) 142 | if x.ndim == 1: 143 | return dist[y[y_pos]] 144 | elif x.ndim == 2: 145 | return dist[T.arange(x.shape[0]), y[:, y_pos]] 146 | 147 | 148 | def training_model_output(x, y, x_mask, y_mask, params, tparams, y_embedder): 149 | ''' Return tensors for training model 150 | 151 | ''' 152 | mb_size = params['minibatch_size'] 153 | C = params['summary_context_length'] 154 | l = params['summary_maxlen'] 155 | 156 | # pad y 157 | id_pad = y_embedder.word_to_id[y_embedder.pad] 158 | y_padded = T.concatenate([T.alloc(id_pad, mb_size, C), y], axis=1) 159 | 160 | # compute the model probabilities for each encoded token in y 161 | fn = lambda y_pos, x, y, x_mask: conditional_score(x, y, x_mask, y_pos, params, tparams) 162 | y_pos_range = T.arange(C, l + C, dtype='int32') 163 | 164 | prob_, _ = theano.scan(fn, 165 | sequences=y_pos_range, 166 | non_sequences=[x, y_padded, x_mask], 167 | n_steps=l) 168 | # prob = T.concatenate([v.reshape((mb_size, 2)) for v in prob_], axis=1) 169 | prob = prob_.T 170 | 171 | # masked negative log-likelihood 172 | nll_per_token = - T.log(prob + EPSILON_FOR_LOG) * y_mask 173 | nll_per_text = T.sum(nll_per_token, axis=1) / T.sum(y_mask, axis=1) 174 | return T.cast(nll_per_text, theano.config.floatX) 175 | 176 | 177 | def tfunc_best_candidate_tokens(params, tparams): 178 | ''' Returns a Theano function that computes the best k candidate terms for the next position in the summary 179 | 180 | ''' 181 | k = params['summary_search_beam_size'] 182 | 183 | x = T.cast(T.vector(dtype=theano.config.floatX), 'int32') 184 | x_mask = T.vector(dtype=theano.config.floatX) 185 | y = T.cast(T.vector(dtype=theano.config.floatX), 'int32') 186 | y_pos = T.cast(T.scalar(dtype=theano.config.floatX), 'int32') 187 | 188 | dist = conditional_distribution(x, y, x_mask, y_pos, params, tparams) 189 | best_candidate_ids = dist.argsort()[-k:] 190 | f = theano.function([x, y, x_mask, y_pos], 191 | [best_candidate_ids, dist[best_candidate_ids]], 192 | allow_input_downcast=True) 193 | return f 194 | 195 | 196 | def summarize(x, x_mask, f_best_candidates, params, tparams, y_embedder): 197 | ''' Generate summary for a single text using beam search 198 | 199 | Parameters 200 | ----------- 201 | x : numpy vector (not Theano variable) 202 | encoded single text to summarize 203 | x_mask : numpy vector (not Theano variable) 204 | mask vector for the text 205 | ''' 206 | C = params['summary_context_length'] 207 | k = params['summary_search_beam_size'] 208 | id_pad = y_embedder.word_to_id[y_embedder.pad] 209 | 210 | # initialise the summary and the beams for search 211 | y = [y_embedder.word_to_id[y_embedder.pad]] * C 212 | beams = [(0.0, y)] 213 | 214 | for j in range(params['summary_maxlen']): 215 | # for each (score, y) in the current beam, expand with the 216 | # k best candidates for the next position in the summary 217 | new_beams = [] 218 | for (base_score, y) in beams: 219 | token_ids, token_probs = f_best_candidates(x, y, x_mask, len(y)) 220 | for (token_id, token_prob) in zip(token_ids, token_probs): 221 | # add a small constant before taking log to increase 222 | # numerical 
stability 223 | new_score = base_score - np.log(EPSILON_FOR_LOG + token_prob) 224 | heapq.heappush(new_beams, (new_score, y + [token_id])) 225 | 226 | # Now we retain the k best summaries after all expansions 227 | # for the next position 228 | beams = heapq.nsmallest(k, new_beams) 229 | 230 | (best_nll_score, summary) = heapq.heappop(beams) 231 | return summary[C:] 232 | 233 | 234 | def load_params_(params, tparams, file_path): 235 | with open(file_path, 'rb') as f: 236 | params = pickle.load(f) 237 | tparams = pickle.load(f) 238 | 239 | 240 | def save_params_(params, tparams, file_path): 241 | with open(file_path, 'wb') as f: 242 | pickle.dump(params, f) 243 | pickle.dump(tparams, f) 244 | 245 | 246 | def init_params(**kwargs): 247 | def init_shared_tparam_(name, shape, value=None, 248 | borrow=True, dtype=theano.config.floatX): 249 | if value is None: 250 | value = np.random.uniform(low=-0.02, high=0.02, size=shape) 251 | return theano.shared(value=value.astype(dtype), 252 | name=name, 253 | borrow=borrow) 254 | 255 | def attention_prob_conv_matrix(Q, l): 256 | assert l >= Q 257 | m = np.diagflat([1.0] * l) 258 | for i in range(1, Q): 259 | m += np.diagflat([1.0] * (l - i), k=i) 260 | m += np.diagflat([1.0] * (l - i), k=-i) 261 | m = m / np.sum(m, axis=0) 262 | return m 263 | 264 | params = kwargs.copy() 265 | params.update({'rng': np.random.RandomState(seed=params['seed']), 266 | 'trng': RandomStreams(seed=params['seed'])}) 267 | 268 | if params['embed_full_text_by'] == 'word': 269 | x_embedder = gloveDocumentParser('glove/glove.10k.300d.txt') 270 | y_embedder = x_embedder 271 | else: 272 | x_embedder = None 273 | y_embedder = None 274 | params.update({'full_text_word_vector_size': x_embedder.token_dim, 275 | 'summary_word_vector_size': y_embedder.token_dim}) 276 | 277 | h = params['internal_representation_dim'] 278 | C = params['summary_context_length'] 279 | l = params['seq_maxlen'] 280 | V_x = x_embedder.embedding_n_tokens 281 | V_y = y_embedder.embedding_n_tokens 282 | d_x = x_embedder.token_dim # full text word vector size 283 | d_y = y_embedder.token_dim # summary word vector size 284 | 285 | tparams = { 286 | 'U': init_shared_tparam_('U', (h, C * d_y)), 287 | 'b': init_shared_tparam_('b', (h,)), 288 | 'V': init_shared_tparam_('V', (V_y, h)), 289 | 'W': init_shared_tparam_('W', (V_y, d_x)), 290 | 'Xemb': init_shared_tparam_('Xemb', (V_x, d_x), 291 | value=x_embedder.word_to_vector_matrix), 292 | 'Yemb': init_shared_tparam_('Yemb', (V_y, d_y), 293 | value=y_embedder.word_to_vector_matrix) 294 | } 295 | if params['context_encoder'] == 'attention': 296 | Q = params['attention_weight_max_roll'] 297 | m = attention_prob_conv_matrix(Q, l) 298 | tparams.update({ 299 | 'att_P': init_shared_tparam_('att_P', (d_x, C * d_y)), 300 | 'att_P_conv': init_shared_tparam_('att_P_conv', (l, l), 301 | value=m) 302 | }) 303 | 304 | return params, tparams, x_embedder, y_embedder 305 | 306 | 307 | def load_corpus(params, tparams, x_embedder, y_embedder): 308 | def pad_to_length(v, pad, l): 309 | return np.pad(v, (0, l - len(v)), 'constant', 310 | constant_values=(pad, pad)) 311 | 312 | def mask_vector(v, l): 313 | return [1] * len(v) + [0] * (l - len(v)) 314 | 315 | C = params['summary_context_length'] 316 | l_x = params['seq_maxlen'] 317 | l_y = params['summary_maxlen'] 318 | id_pad_x = x_embedder.word_to_id[x_embedder.pad] 319 | id_pad_y = y_embedder.word_to_id[y_embedder.pad] 320 | 321 | x_ = [] 322 | y_ = [] 323 | x_mask_ = [] 324 | y_mask_ = [] 325 | for file_path in 
glob.iglob(os.path.join(params['corpus'], '*.json')): 326 | try: 327 | with open(file_path, 'r') as f: 328 | document = json.load(f) 329 | full_text_vector = x_embedder.parseDocument(document['full_text']) 330 | summary_vector = y_embedder.parseDocument(document['summary']) 331 | 332 | if not len(full_text_vector) or not len(summary_vector): 333 | continue 334 | 335 | x_.append(pad_to_length(full_text_vector[:l_x], id_pad_x, l_x)) 336 | y_.append(pad_to_length(summary_vector[:l_y], id_pad_y, l_y)) 337 | x_mask_.append(mask_vector(full_text_vector[:l_x], l_x)) 338 | y_mask_.append(mask_vector(summary_vector[:l_y], l_y)) 339 | except Exception as e: 340 | continue 341 | print('Loaded {:} files'.format(len(x_))) 342 | 343 | x = np.array(x_, dtype='int32') 344 | y = np.array(y_, dtype='int32') 345 | x_mask = np.array(x_mask_) 346 | y_mask = np.array(y_mask_) 347 | 348 | x_train, x_test, y_train, y_test, \ 349 | x_mask_train, x_mask_test, \ 350 | y_mask_train, y_mask_test = \ 351 | train_test_split(x, y, x_mask, y_mask, 352 | train_size=params['train_split'], 353 | random_state=params['rng']) 354 | 355 | return x_train, x_test, y_train, y_test, \ 356 | x_mask_train, x_mask_test, \ 357 | y_mask_train, y_mask_test 358 | 359 | 360 | def train(context_encoder='baseline', 361 | corpus=None, 362 | # optimiser 363 | optimizer='adam', 364 | learning_rate=0.001, 365 | # model params 366 | embed_full_text_by='word', 367 | seq_maxlen=500, 368 | summary_maxlen=200, 369 | summary_context_length=10, 370 | internal_representation_dim=2000, 371 | attention_weight_max_roll=5, 372 | # training params 373 | l2_penalty_coeff=0.0, 374 | train_split=0.75, 375 | epochs=float('inf'), 376 | minibatch_size=20, 377 | seed=None, 378 | dropout_rate=None, 379 | # model load/save 380 | save_params='ass_params.pkl', 381 | save_params_every=5, 382 | validate_every=5, 383 | print_every=5, 384 | # summary generation on the validation set 385 | generate_summary=False, 386 | summary_search_beam_size=2): 387 | params, tparams, x_embedder, y_embedder = init_params( 388 | context_encoder=context_encoder, 389 | corpus=corpus, 390 | optimizer=optimizer, 391 | learning_rate=learning_rate, 392 | embed_full_text_by=embed_full_text_by, 393 | seq_maxlen=seq_maxlen, 394 | summary_maxlen=summary_maxlen, 395 | summary_context_length=summary_context_length, 396 | internal_representation_dim=internal_representation_dim, 397 | attention_weight_max_roll=attention_weight_max_roll, 398 | l2_penalty_coeff=l2_penalty_coeff, 399 | train_split=train_split, 400 | epochs=epochs, 401 | minibatch_size=minibatch_size, 402 | seed=seed, 403 | dropout_rate=dropout_rate, 404 | summary_search_beam_size=summary_search_beam_size 405 | ) 406 | 407 | # minibatch of encoded texts 408 | # size batchsize-by-seq_maxlen 409 | x = T.cast(T.matrix(dtype=theano.config.floatX), 'int32') 410 | x_mask = T.matrix(dtype=theano.config.floatX) 411 | 412 | # summaries for the minibatch of texts 413 | y = T.cast(T.matrix(dtype=theano.config.floatX), 'int32') 414 | y_mask = T.matrix(dtype=theano.config.floatX) 415 | 416 | nll = training_model_output(x, y, x_mask, y_mask, 417 | params, tparams, y_embedder) 418 | cost = nll.mean() 419 | 420 | tparams_to_optimise = {key: tparams[key] for key in tparams 421 | if (not key.endswith('emb')) and key != 'att_P_conv'} 422 | cost += params['l2_penalty_coeff'] * sum([(p ** 2).sum() 423 | for k, p in tparams_to_optimise.items()]) 424 | inputs = [x, y, x_mask, y_mask] 425 | 426 | # after all regularizers - compile the computational graph for cost 
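    # The compiled cost is the minibatch mean of the per-summary masked negative
    # log-likelihood from training_model_output, plus an L2 penalty:
    #   cost = mean(nll) + l2_penalty_coeff * sum_k ||theta_k||^2
    # where theta_k ranges over tparams_to_optimise (the word embeddings and the
    # fixed attention convolution matrix att_P_conv are excluded above).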
427 | print('Building f_cost... ', end='') 428 | f_cost = theano.function(inputs, cost, allow_input_downcast=True) 429 | print('Done') 430 | 431 | print('Computing gradient... ', end='') 432 | grads = T.grad(cost, list(tparams_to_optimise.values())) 433 | print('Done') 434 | 435 | # compile the optimizer, the actual computational graph is compiled here 436 | lr = T.scalar(name='lr') 437 | print('Building optimizers... ', end='') 438 | f_grad_shared, f_update = eval(optimizer)(lr, tparams_to_optimise, grads, inputs, cost) 439 | print('Done') 440 | 441 | print('Building summary candidate token generator... ', end='') 442 | f_best_candidates = tfunc_best_candidate_tokens(params, tparams) 443 | print('Done') 444 | 445 | print('Loading corpus... ', end='') 446 | x_train, x_test, y_train, y_test, \ 447 | x_mask_train, x_mask_test, \ 448 | y_mask_train, y_mask_test \ 449 | = load_corpus(params, tparams, x_embedder, y_embedder) 450 | n_train_batches = int(x_train.shape[0] / params['minibatch_size']) 451 | n_test_batches = int(x_test.shape[0] / params['minibatch_size']) 452 | print('Done') 453 | 454 | print('Optimization') 455 | test_ids_to_summarize = sample(range(x_test.shape[0]), 5) 456 | for epoch in range(epochs): 457 | print('Epoch', epoch) 458 | 459 | # training of all minibatches 460 | params['phase'] = 'training' 461 | training_costs = [] 462 | for batch_id in range(n_train_batches): 463 | if batch_id % print_every == 0: 464 | print('Batch {:} '.format(batch_id), end='') 465 | # compute cost, grads and copy grads to shared variables 466 | # use_noise.set_value(1.) 467 | current_batch = range(batch_id * params['minibatch_size'], 468 | (batch_id + 1) * params['minibatch_size']) 469 | cost = f_grad_shared(x_train[current_batch, :], 470 | y_train[current_batch, :], 471 | x_mask_train[current_batch, :], 472 | y_mask_train[current_batch, :]) 473 | cost = np.asscalar(cost) 474 | training_costs.append(cost) 475 | # do the update on parameters 476 | f_update(learning_rate) 477 | if batch_id % print_every == 0: 478 | print('Cost {:.4f}'.format(cost)) 479 | print('Epoch {:} mean training cost {:.4f}'.format( 480 | epoch, np.mean(training_costs) 481 | )) 482 | 483 | # save the params 484 | if epoch % save_params_every == 0: 485 | print('Saving... 
', end='') 486 | save_params_(params, tparams, save_params) 487 | print('Done') 488 | 489 | # validate 490 | # compute the metrics and generate summaries (if requested) 491 | params['phase'] = 'test' 492 | if epoch % validate_every == 0: 493 | print('Validating') 494 | validate_costs = [] 495 | for batch_id in range(n_test_batches): 496 | if batch_id % print_every == 0: 497 | print('Batch {:} '.format(batch_id), end='') 498 | current_batch = range(batch_id * params['minibatch_size'], 499 | (batch_id + 1) * params['minibatch_size']) 500 | validate_cost = f_cost(x_test[current_batch, :], 501 | y_test[current_batch, :], 502 | x_mask_test[current_batch, :], 503 | y_mask_test[current_batch, :]) 504 | validate_cost = np.asscalar(validate_cost) 505 | validate_costs.append(validate_cost) 506 | if batch_id % print_every == 0: 507 | print('Validation cost {:.4f}'.format(validate_cost)) 508 | print('Epoch {:} mean validation cost {:.4f}'.format( 509 | epoch, np.mean(validate_costs) 510 | )) 511 | 512 | if generate_summary: 513 | print('Generating summary') 514 | for i in test_ids_to_summarize: 515 | summary_token_ids = summarize( 516 | x_test[i, :].flatten(), x_mask_test[i, :].flatten(), 517 | f_best_candidates, 518 | params, tparams, 519 | y_embedder) 520 | print('Sample :', y_embedder.documentFromVector(summary_token_ids)) 521 | print('Truth :', y_embedder.documentFromVector(y_test[i, :])[:20]) 522 | --------------------------------------------------------------------------------
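The validation loop above prints generated and reference summaries but does not score them. Below is a minimal sketch (not part of the repository; the helper name `rouge_for_generated` and the padding filter are assumptions) of how the ROUGE-N recall helper from `embeddings.py` could be applied to the decoded token ids:

```python
from embeddings import rouge_score

def rouge_for_generated(summary_token_ids, reference_token_ids, y_embedder, n=1):
    """Score a generated summary against its reference with ROUGE-N recall."""
    # documentFromVector maps token ids back to words; drop padding tokens and
    # join into strings, because rouge_score tokenises its inputs itself.
    hyp_words = [w for w in y_embedder.documentFromVector(summary_token_ids)
                 if w != y_embedder.pad]
    ref_words = [w for w in y_embedder.documentFromVector(reference_token_ids)
                 if w != y_embedder.pad]
    return rouge_score(' '.join(ref_words), ' '.join(hyp_words), n)

# Example use inside the validation loop (hypothetical):
#   score = rouge_for_generated(summary_token_ids, y_test[i, :], y_embedder)
```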