├── .gitignore
├── README.md
├── build_vocabularies.py
├── compute_rouge.py
├── convert_rcdata.py
├── query_offset_files.py
└── vocabulary.py

/.gitignore:
--------------------------------------------------------------------------------
# Compiled Python
*.pyc

# PyCharm project
.idea

rouge
output
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Querysum dataset

Code for the dataset presented in the thesis *Query-Based Abstractive Summarization Using Neural Networks* by Johan Hasselqvist and Niklas Helmertz. The code for the model can be found in a [separate repository](https://github.com/helmertz/querysum).

## Requirements

- Python 2.7 or 3.5

### NLTK

The NLTK (Natural Language Toolkit) package can be installed using `pip`:

```
pip install nltk
```

Additionally, the NLTK data package `punkt` needs to be downloaded. For instructions on installing NLTK data packages, see the official guide [Installing NLTK Data](http://www.nltk.org/data.html).

## Getting CNN/Daily Mail data

The dataset for query-based abstractive summarization is created by converting an existing dataset for question answering, released by [DeepMind](https://github.com/deepmind/rc-data). We used the archives containing the processed DeepMind dataset, available at [http://cs.nyu.edu/~kcho/DMQA/](http://cs.nyu.edu/~kcho/DMQA/). Both the `stories` and `questions` archives are required for the conversion, from either news organization or from both. To use both, merge the extracted `questions` directories into one directory and the extracted `stories` directories into another.

## Conversion

Replacing the parts in angle brackets, the dataset can be constructed by running:

```
python convert_rcdata.py \
    <path to stories directory> \
    <path to questions directory> \
    <output directory>
```

This creates separate directories in the output directory for the training, validation and test sets.
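
Each set directory contains subdirectories for `documents`, `queries`, `references` and `entities`; the remaining entries only appear when the corresponding flags are passed to `convert_rcdata.py`. As a rough orientation (layout read off `write_data` in `convert_rcdata.py`):

```
<output directory>/
├── training/
│   ├── documents/               # <document id>.txt, tokenized article text
│   ├── entities/                # <document id>.txt, entity names found in the article
│   ├── queries/                 # <document id>.<query id>.txt
│   ├── references/              # A.<document id>.<query id>.txt, B.… for additional references
│   ├── first_query_sentences/   # only with --save_first_query_sentences
│   ├── synthetic_references/    # only with --save_synthetic_references
│   └── input_lengths.txt        # only with --save_document_lengths
├── validation/                  # same structure as training
└── test/                        # same structure as training
```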

## Creating vocabularies

The repository contains a script for generating vocabularies, sorted by word frequency. They can be constructed by running:

```
python build_vocabularies.py \
    <path to a generated set directory, e.g. the training set> \
    <output directory>
```
--------------------------------------------------------------------------------
/build_vocabularies.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import, division, print_function, unicode_literals

import argparse
import io
import os
import sys

from vocabulary import Vocabulary


class Vocabularies:
    def __init__(self):
        self.document_vocabulary = Vocabulary()
        self.summary_vocabulary = Vocabulary()
        self.full_vocabulary = Vocabulary()


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('input_dir')
    parser.add_argument('output_dir')
    options = parser.parse_args()

    print("Counting words...")
    sys.stdout.flush()
    vocabularies = extract_vocabularies(options)
    print("Done")

    print("Saving to output directory...")
    sys.stdout.flush()
    write_vocabularies(options, vocabularies)
    print("Done")


def extract_vocabularies(options):
    vocabularies = Vocabularies()

    document_vocabulary = extract_vocabulary(options, 'documents')
    summary_vocabulary = extract_vocabulary(options, 'references')

    vocabularies.document_vocabulary = document_vocabulary
    vocabularies.summary_vocabulary = summary_vocabulary
    vocabularies.full_vocabulary = document_vocabulary.merge(summary_vocabulary)

    return vocabularies


def extract_vocabulary(options, dir_name):
    vocabulary = Vocabulary()

    dir_path = os.path.join(options.input_dir, dir_name)
    for filename in os.listdir(dir_path):
        file_path = os.path.join(dir_path, filename)
        if os.path.isfile(file_path):
            with io.open(file_path, 'r', encoding='utf-8') as file:
                raw_text = file.read()
                vocabulary.expand_vocab(raw_text.split())
    return vocabulary


def write_vocabularies(options, vocabularies):
    # Write all three types of vocabulary to file
    for vocabulary_name, vocabulary in vocabularies.__dict__.items():
        write_vocabulary(options, '{}.txt'.format(vocabulary_name), vocabulary)


def write_vocabulary(options, name, vocabulary):
    sorted_vocabulary = vocabulary.get_sorted_vocabulary()
    with io.open(os.path.join(options.output_dir, name), 'w', encoding='utf-8') as file:
        for word, count in sorted_vocabulary:
            file.write('{} {}\n'.format(word, count))


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/compute_rouge.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import, division, print_function, unicode_literals

import argparse
from os import path

from pyrouge import Rouge155


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('system_dir')
    parser.add_argument('reference_dir')
    parser.add_argument('rouge_dir')
    options = parser.parse_args()

    rouge_score = compute_rouge(options)
    print(rouge_score)


def compute_rouge(options):
    system_dir = options.system_dir
    reference_dir = options.reference_dir

    rouge_data = path.join(options.rouge_dir, 'data')
    rouge_args = '-e {} -c 95 -2 4 -U -n 4 -w 1.2 -a'.format(rouge_data)

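    # Flag meanings, as commonly documented for ROUGE-1.5.5 (worth double-checking against the local install):
    # -e <dir>  use <dir> as the ROUGE data directory
    # -c 95     report 95% confidence intervals
    # -2 4 -U   skip-bigram (ROUGE-S/SU) with a maximum gap of 4, also counting unigrams
    # -n 4      ROUGE-N up to 4-grams
    # -w 1.2    ROUGE-W weighting factor
    # -a        evaluate all systems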
    rouge = Rouge155(rouge_dir=options.rouge_dir, rouge_args=rouge_args)
    rouge.system_dir = system_dir
    rouge.model_dir = reference_dir
    rouge.system_filename_pattern = r'(\d+\.\d+)\.txt'
    rouge.model_filename_pattern = '[A-Z].#ID#.txt'

    return rouge.convert_and_evaluate()


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/convert_rcdata.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import, division, print_function, unicode_literals

import argparse
import hashlib
import io
import math
import os
import re
import shutil
import sys

import nltk


class ArticleData:
    def __init__(self, article_text=None, query_to_summaries=None, entities=None):
        self.article_text = article_text
        if query_to_summaries is None:
            self.query_to_summaries = {}
        else:
            self.query_to_summaries = query_to_summaries
        self.entities = entities


class Summaries:
    def __init__(self, first_query_sentence=None, reference_summaries=None, synthetic_summary=None):
        self.first_query_sentence = first_query_sentence
        self.reference_summaries = reference_summaries
        self.synthetic_summary = synthetic_summary


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('stories_dir')
    parser.add_argument('questions_dir')
    parser.add_argument('output_dir')
    parser.add_argument('--validation_test_fraction', type=float, default=0.015)
    parser.add_argument('--save_first_query_sentences', action='store_true')
    parser.add_argument('--save_document_lengths', action='store_true')
    parser.add_argument('--save_synthetic_references', action='store_true')
    options = parser.parse_args()

    print("Extracting summarization data...")
    sys.stdout.flush()
    article_lookup = extract_article_lookup(options)
    print("Done")

    print("Saving to output directory...")
    sys.stdout.flush()
    write_data(article_lookup, options)
    print("Done")


def extract_article_lookup(options):
    # Initialize dictionary to contain both reference summaries and the first query sentence
    article_lookup = {}

    # Look through all question files
    num_processed_items = 0
    for root, dirs, files in os.walk(options.questions_dir):
        for question_file_name in files:
            if not question_file_name.endswith('.question'):
                continue
            with io.open(os.path.join(root, question_file_name), 'r', encoding='utf-8') as question_file:
                question_text = question_file.read()
            url, query, entities = extract_from_question(question_text)

            article_data = article_lookup.get(url)

            if article_data is None:
                # First time the article is processed
                article_data = ArticleData(entities=entities)
                article_lookup[url] = article_data

            # Check if summaries for the document-query pair have already been found
            summaries = article_data.query_to_summaries.get(join(query))
            if summaries is not None:
                continue

            extract_from_story(query, article_data, options.stories_dir, url)

            # Print progress
            num_processed_items += 1
            if num_processed_items % 1000 == 0:
                print('{} items processed...'.format(num_processed_items))
    return article_lookup


def extract_from_question(question_text):
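    """Extract the article URL, the tokenized query and the entity mapping from one .question file.

    Assumes the DeepMind rc-data layout relied on below: line 0 holds the URL, line 6 the answer
    entity placeholder (whose mapped name becomes the query), and lines 8 onwards the
    '@entityN:name' mappings.
    """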
    lines = question_text.splitlines()

    url = lines[0]
    placeholder = lines[6]
    entity_mapping_lines = lines[8:]

    entity_dictionary = get_entity_dictionary(entity_mapping_lines)

    query = entity_dictionary[placeholder]
    tokenized_query = tokenize(query)
    entities = '\n'.join([join(tokenize(entity)) for entity in entity_dictionary.values()])

    return url, tokenized_query, entities


def get_entity_dictionary(entity_mapping_lines):
    entity_dictionary = {}
    for mapping in entity_mapping_lines:
        entity, name = mapping.split(':', 1)
        entity_dictionary[entity] = name
    return entity_dictionary


def generate_synthetic_summary(document, highlight):
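    """Keep the highlight tokens that also occur in the document, in highlight order.

    Illustrative example (hypothetical tokens):
        generate_synthetic_summary(['the', 'cat', 'sat'], ['a', 'cat', 'sat', 'down']) -> ['cat', 'sat']
    """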
    return [word for word in highlight if word in document]


def extract_from_story(query, article_data, stories_path, url):
    # Find the original story file, which is named using the URL hash
    url_hash = hash_hex(url)
    with io.open(os.path.join(stories_path, '{}.story'.format(url_hash)), 'r', encoding='utf-8') as file:
        raw_article_text = file.read()

    highlight_start_index = raw_article_text.find('@highlight')

    article_text = raw_article_text[:highlight_start_index].strip()
    highlight_text = raw_article_text[highlight_start_index:].strip()

    if len(article_text) == 0:
        # There are stories with only highlights, skip these
        return

    # Extract all highlights
    highlights = re.findall('@highlight\n\n(.*)', highlight_text)

    tokenized_highlights = map(tokenize, highlights)
    tokenized_query_highlights = []

    for highlight in tokenized_highlights:
        if contains_sublist(highlight, query):
            tokenized_query_highlights.append(highlight)

    if len(tokenized_query_highlights) == 0:
        # For now, ignore the pair if the query's token sequence is not found in any highlight. This happens, for
        # example, when the query is "American" and a highlight only contains "Asian-American".
        return

    first_query_sentence = get_first_query_sentence(query, article_text)

    synthetic_summary = generate_synthetic_summary(first_query_sentence, tokenized_query_highlights[0])

    summaries = Summaries(join(first_query_sentence), map(join, tokenized_query_highlights), join(synthetic_summary))

    if article_data.article_text is None:
        article_data.article_text = join(tokenize(article_text))

    article_data.query_to_summaries[join(query)] = summaries


def contains_sublist(list_, sublist):
    for i in range(len(list_)):
        if list_[i:(i + len(sublist))] == sublist:
            return True
    return False


def get_first_query_sentence(query, text):
    # Find the first sentence containing the query
    sentences = []
    for paragraph in text.splitlines():
        sentences.extend(nltk.sent_tokenize(paragraph))

    for sentence in sentences:
        tokenized_sentence = tokenize(sentence)
        if contains_sublist(tokenized_sentence, query):
            first_query_sentence = tokenized_sentence
            break
    else:
        # Query text not found in the document, fall back to the first sentence (tokenized, to match the branch above)
        first_query_sentence = tokenize(sentences[0])

    # If ending with a period, remove it, to match most of the highlights
    # (some are, however, single sentences ending with a period)
    if first_query_sentence[-1] == '.':
        first_query_sentence = first_query_sentence[:-1]

    return first_query_sentence


apostrophe_words = {
    "''",
    "'s",
    "'re",
    "'ve",
    "'m",
    "'ll",
    "'d",
    "'em",
    "'n'",
    "'n",
    "'cause",
    "'til",
    "'twas",
    "'till"
}


def lower_and_fix_apostrophe_words(word):
    regex = re.compile(r"^'\D|'\d+[^s]$")  # 'g | not '90s
    word = word.lower()

    if regex.match(word) and word not in apostrophe_words:
        word = "' " + word[1:]
    return word


def tokenize(text):
    # The Stanford tokenizer may be preferable since it was used for the pre-trained GloVe embeddings. However, it
    # appears to be unreasonably slow through the NLTK wrapper.
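    # Note that nltk.tokenize.word_tokenize relies on the NLTK 'punkt' model listed in the README requirements.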

    tokenized = nltk.tokenize.word_tokenize(text)
    tokenized = [lower_and_fix_apostrophe_words(word) for word in tokenized]
    return tokenized


def join(text):
    return " ".join(text)


def hash_hex(string):
    hash_ = hashlib.sha1()
    hash_.update(string.encode('utf-8'))
    return hash_.hexdigest()


def write_data(articles, options):
    output_dir = options.output_dir

    shutil.rmtree(output_dir, True)

    total_reference_count = 0

    # Ignore articles where no query was found
    filtered_articles = [item for item in articles.items() if len(item[1].query_to_summaries) > 0]

    # Get articles ordered by hash values to break possible patterns in the ordering
    sorted_articles = sorted(filtered_articles, key=lambda article_tuple: hash_hex(article_tuple[0]))

    # int() keeps the result usable as a slice index on Python 2, where math.ceil returns a float
    num_validation_test_documents = int(math.ceil(options.validation_test_fraction * len(sorted_articles)))

    validation_articles = sorted_articles[:num_validation_test_documents]
    test_articles = sorted_articles[num_validation_test_documents:(2 * num_validation_test_documents)]
    training_articles = sorted_articles[(2 * num_validation_test_documents):]
    out_sets = [('validation', validation_articles),
                ('test', test_articles),
                ('training', training_articles)]

    for set_name, articles in out_sets:
        output_set_dir = os.path.join(output_dir, set_name)

        out_names = ['queries', 'documents', 'references', 'entities']

        if options.save_first_query_sentences:
            out_names.append('first_query_sentences')
        if options.save_synthetic_references:
            out_names.append('synthetic_references')

        out_paths = [os.path.join(output_set_dir, x) for x in out_names]
        out_name_to_path = dict(zip(out_names, out_paths))

        for out_path in out_paths:
            os.makedirs(out_path)

        document_id = 1
        reference_num_tokens_lines = []

        for url, article_data in articles:
            query_to_summaries = article_data.query_to_summaries

            # Ignore if no queries were found for the article
            if len(query_to_summaries) == 0:
                continue

            article_dir_content_mapping = [(out_name_to_path['documents'], article_data.article_text),
                                           (out_name_to_path['entities'], article_data.entities)]

            document_filename = '{}.txt'.format(document_id)

            for dir_, content in article_dir_content_mapping:
                with io.open(os.path.join(dir_, document_filename), 'w', encoding='utf-8') as file:
                    file.write(content)

            query_id = 1
            for query, summaries in sorted(query_to_summaries.items(),
                                           key=lambda query_tuple: hash_hex(query_tuple[0])):
                query_filename = '{}.{}.txt'.format(document_id, query_id)

                dir_content_mapping = [(out_name_to_path['queries'], query)]
                if options.save_first_query_sentences:
                    dir_content_mapping.append(
                        (out_name_to_path['first_query_sentences'], summaries.first_query_sentence))

                for dir_, content in dir_content_mapping:
                    with io.open(os.path.join(dir_, query_filename), 'w', encoding='utf-8') as file:
                        file.write(content)

                if options.save_synthetic_references:
                    synthetic_filename = 'A.{}.{}.txt'.format(document_id, query_id)
                    with io.open(os.path.join(out_name_to_path['synthetic_references'], synthetic_filename), 'w',
                                 encoding='utf-8') as file:
                        file.write(summaries.synthetic_summary)

                # Save reference summaries
                reference_id = 0
                for reference_summary in sorted(summaries.reference_summaries, key=hash_hex):
                    reference_filename = '{}.{}.{}.txt'.format(chr(ord('A') + reference_id), document_id, query_id)
                    with io.open(
                            os.path.join(out_name_to_path['references'], reference_filename),
                            'w',
                            encoding='utf-8') as file:
                        file.write(reference_summary)
                    num_tokens_line = '{} {}'.format(reference_filename, len(article_data.article_text.split()))
                    reference_num_tokens_lines.append(num_tokens_line)

                    reference_id += 1
                    total_reference_count += 1

                    # Print progress
                    if total_reference_count % 1000 == 0:
                        print('{} items processed...'.format(total_reference_count))
                query_id += 1
            document_id += 1

        if options.save_document_lengths:
            with io.open(os.path.join(output_set_dir, 'input_lengths.txt'), 'w', encoding='utf-8') as file:
                for line in reference_num_tokens_lines:
                    file.write(line)
                    file.write('\n')


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/query_offset_files.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import, division, print_function, unicode_literals

import argparse
import io
import os
import re
from os import path

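# Copies every '<document id>.<query id>.<ending>' file in source_dir to out_dir, shifting the
# query id by one within each document (cyclically, so the last query wraps around to 1 and
# documents with a single query keep their id). Presumably intended to produce deliberately
# mismatched document-query pairs; that purpose is an assumption, the renaming is what the code
# below actually does.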
pattern = re.compile(r'(\d+)\.(\d+)\.(.*)')


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('source_dir')
    parser.add_argument('--out_dir', default='query_offset_files')
    options = parser.parse_args()

    document_to_num_queries = {}

    for source_filename in os.listdir(options.source_dir):
        source_path = path.join(options.source_dir, source_filename)
        if path.isfile(source_path):
            document_id, _, _ = pattern.search(source_filename).groups()
            if document_id in document_to_num_queries:
                document_to_num_queries[document_id] += 1
            else:
                document_to_num_queries[document_id] = 1

    if not path.isdir(options.out_dir):
        os.makedirs(options.out_dir)

    for source_filename in os.listdir(options.source_dir):
        source_path = path.join(options.source_dir, source_filename)
        if path.isfile(source_path):
            document_id, query_id, file_ending = pattern.search(source_filename).groups()
            num_queries = document_to_num_queries[document_id]
            offset_query_id = int(query_id) % num_queries + 1
            out_filepath = path.join(options.out_dir,
                                     '{}.{}.{}'.format(document_id, offset_query_id, file_ending))
            with io.open(source_path, 'r', encoding='utf-8') as source_file:
                content = source_file.read()

            with io.open(out_filepath, 'w', encoding='utf-8') as out_file:
                out_file.write(content)


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/vocabulary.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import, division, print_function, unicode_literals

from collections import defaultdict


class Vocabulary:
    def __init__(self):
        self.word_to_count = defaultdict(int)

    def expand_vocab(self, tokens):
        for token in tokens:
            self.word_to_count[token] += 1

    def get_sorted_vocabulary(self):
        # Primarily sort by descending occurrence count and secondarily alphabetically
        return sorted(self.word_to_count.items(),
                      key=lambda word_count_tuple: (-word_count_tuple[1], word_count_tuple[0]))

    def merge(self, other_vocabulary):
        merged = Vocabulary()
        merged.word_to_count = self.word_to_count.copy()
        for word, count in other_vocabulary.word_to_count.items():
            merged.word_to_count[word] += count
        return merged
--------------------------------------------------------------------------------