├── LICENSE
├── README.md
├── acl20-paper.pdf
├── acl20-slides.pdf
├── data
│   └── cc_storage
│       └── cc_files.txt
├── dataset_generation
│   ├── README.md
│   ├── step1_store_wcep_html.py
│   ├── step2_process_wcep_html.py
│   ├── step3_snapshot_source_urls.py
│   ├── step4_scrape_sources.py
│   └── step5_combine_dataset.py
├── dataset_reproduction
│   ├── combine_and_split.py
│   ├── extract_cc_articles.py
│   ├── extract_wcep_articles.py
│   ├── requirements.txt
│   └── utils.py
├── experiments
│   ├── baselines.py
│   ├── data.py
│   ├── evaluate.py
│   ├── oracles.py
│   ├── requirements.txt
│   ├── sent_splitter.py
│   ├── summarizer.py
│   └── utils.py
└── wcep_getting_started.ipynb

/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License
2 | 
3 | Copyright (c) 2020 Demian Gholipour Ghalandari
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
23 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## WCEP Dataset
2 | ### Overview
3 | The WCEP dataset for multi-document summarization (MDS) consists of short, human-written summaries about news events, obtained from the [Wikipedia Current Events Portal](https://en.wikipedia.org/wiki/Portal:Current_events "Wikipedia Current Events Portal") (WCEP), each paired with a cluster of news articles associated with an event. These articles consist of sources cited by editors on WCEP, and are extended with articles automatically obtained from the [Common Crawl News dataset](https://commoncrawl.org/2016/10/news-dataset-available/ "CommonCrawl News dataset"). For more information about the dataset and experiments, check out our ACL 2020 paper: *A Large-Scale Multi-Document Summarization Dataset from the Wikipedia Current Events Portal.* ([Paper](https://www.aclweb.org/anthology/2020.acl-main.120/), [Slides](acl20-slides.pdf))
4 | 
5 | ### Colab Notebook
6 | 
7 | 
8 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/complementizer/wcep-mds-dataset/blob/master/wcep_getting_started.ipynb)
9 | 
10 | You can use this notebook to
11 | * download & inspect the dataset
12 | * run extractive baselines & oracles
13 | * evaluate summaries
14 | 
15 | Otherwise, check out the instructions below.
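### Example Record

To make the format described in the overview concrete, below is a rough, illustrative sketch of a single record. The field names follow the dataset-construction scripts in this repository (e.g. `combine_and_split.py`); treat a downloaded file as the authoritative reference, since some versions carry additional fields such as `wiki_links` or `reference_urls`.

```python
# Illustrative sketch only — not an exhaustive or authoritative schema.
cluster = {
    'id': 1021,                        # event id (example value)
    'date': '2019-01-01',              # date of the event
    'category': 'Armed conflicts and attacks',
    'summary': 'Short human-written summary of the event ...',
    'articles': [                      # cluster of associated news articles
        {
            'title': 'Article headline',
            'text': 'Full article body ...',
            'time': '2019-01-01T12:00:00',
            'url': 'https://example.com/news/story',
            'origin': 'WCEP',          # 'WCEP' (cited source) or 'CommonCrawl'
        },
        # ...
    ],
}
```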
16 | 
17 | ### Download Dataset
18 | 
19 | Update 6.10.20: [an extracted version of the dataset can be downloaded here](https://drive.google.com/drive/folders/1T5wDxu4ajFwEq77dG88oE95e8ppREamg?usp=sharing)
20 | 
21 | ### Loading the Dataset
22 | The dataset is stored in gzipped JSONL format: each line corresponds to a news event and contains a human-written summary, a cluster of news articles, and metadata such as date and category. The summarization task is to generate the summary from the news articles.
23 | 
24 | ```python
25 | import json, gzip
26 | 
27 | def read_jsonl_gz(path):
28 |     with gzip.open(path) as f:
29 |         for l in f:
30 |             yield json.loads(l)
31 | 
32 | val_data = list(read_jsonl_gz('val.jsonl.gz'))  # path to the downloaded validation split
33 | c = val_data[404]
34 | summary = c['summary']    # human-written summary
35 | articles = c['articles']  # cluster of articles
36 | ```
37 | 
38 | ### Extractive Baselines and Evaluation
39 | 
40 | We also provide code to run several extractive baselines and to evaluate
41 | generated summaries. Note that we use the ROUGE wrapper of the [newsroom library](https://github.com/lil-lab/newsroom) to compute ROUGE scores.
42 | 
43 | Install dependencies:
44 | 
45 | `pip install -r experiments/requirements.txt`
46 | 
47 | `cd` into [experiments](experiments) to run this snippet:
48 | 
49 | ```python
50 | from utils import read_jsonl_gz
51 | from baselines import TextRankSummarizer
52 | from evaluate import evaluate
53 | from pprint import pprint
54 | 
55 | val_data = list(read_jsonl_gz('val.jsonl.gz'))  # path to the downloaded validation split
56 | 
57 | textrank = TextRankSummarizer()
58 | 
59 | settings = {
60 |     'max_len': 40, 'len_type': 'words',
61 |     'in_titles': False, 'out_titles': False,
62 |     'min_sent_tokens': 7, 'max_sent_tokens': 60,
63 | }
64 | 
65 | inputs = [c['articles'][:10] for c in val_data[:10]]
66 | ref_summaries = [c['summary'] for c in val_data[:10]]
67 | pred_summaries = [textrank.summarize(articles, **settings) for articles in inputs]
68 | results = evaluate(ref_summaries, pred_summaries)
69 | pprint(results)
70 | ```
71 | 
72 | ### Dataset Generation
73 | 
74 | **Note:** This is currently not required, as the extracted dataset is available for download (see above).
75 | 
76 | Originally, we did not provide the extracted dataset for download. Instead, we share the summaries from WCEP together with scripts (in [dataset_reproduction](dataset_reproduction)) that obtain the associated news articles. Set `--jobs` to the number of CPUs available to you to speed things up. Both extraction scripts can be interrupted and resumed by repeating the same command. To restart from scratch, add `--override`.
77 | 
78 | Install dependencies:
79 | ```bash
80 | pip install -r dataset_reproduction/requirements.txt
81 | ```
82 | 
83 | First, download the initial [dataset without articles](https://drive.google.com/file/d/1LGYFKGzCgvdllwIQHDF5qSxtan1Y0Re9/view?usp=sharing "dataset without articles") and place it, unzipped, in `data/`.
84 | ##### 1) Extracting articles from WCEP
85 | This script extracts the news articles cited on WCEP from the [Internet Archive Wayback Machine](https://archive.org/) using [newspaper3k](https://github.com/codelucas/newspaper "newspaper3k"). We previously requested snapshots of all source articles that were not archived yet.
86 | 
87 | ```bash
88 | python extract_wcep_articles.py \
89 |     --i data/initial_dataset.jsonl \
90 |     --o data/wcep_articles.jsonl \
91 |     --batchsize 200 \
92 |     --jobs 16 \
93 |     --repeat-failed
94 | ```
95 | If any downloads fail due to timeouts, simply repeat the same command; it will only attempt to extract the missing articles.
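To check how far the extraction has progressed, you can count the per-article `state` values in the output file. This is a minimal sketch; it assumes the `--o` path used above and the `state` field (`'successful'` / `'failed'`) that `extract_wcep_articles.py` writes for each record:

```python
import json
from collections import Counter

# Tally scraped WCEP articles by state ('successful' / 'failed').
# Assumes the output path from the command above: data/wcep_articles.jsonl
counts = Counter()
with open('data/wcep_articles.jsonl') as f:
    for line in f:
        counts[json.loads(line)['state']] += 1

print(counts)
```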
96 | ##### 2) Extracting articles from Common Crawl 97 | This script extracts articles from Common Crawl News, which is divided into ~6000 files of 1GB size each. These are downloaded and searched one at a time. The relevant articles are extracted from HTML in parallel using newspaper3k. 98 | ```bash 99 | python extract_cc_articles.py \ 100 | --storage data/cc_storage \ 101 | --dataset data/initial_dataset.jsonl \ 102 | --batchsize 200 \ 103 | --max-cluster-size 100 \ 104 | --jobs 16 105 | ``` 106 | 107 | This process takes a long time (few days!). We are working on speeding it up. 108 | `--max-cluster-size 100` already reduces the time: only up to 100 articles of each cluster in the dataset are extracted. This corresponds to the dataset version used in the experiments in our paper ("WCEP-100"). 109 | ##### 3) Combine and split 110 | Finally, we need to group articles and summaries belonging together, and split the dataset into separate train/validation/test files. If `--max-cluster-size` was used in the previous step, use that here accordingly. 111 | ```bash 112 | python combine_and_split.py \ 113 | --dataset data/initial_dataset.jsonl \ 114 | --cc-articles data/cc_storage/cc_articles.jsonl \ 115 | --wcep-articles data/wcep_articles.jsonl \ 116 | --max-cluster-size 100 \ 117 | --o data/wcep_dataset 118 | ``` 119 | 120 | ### Citation 121 | 122 | If you find this dataset useful, please cite: 123 | ``` 124 | @inproceedings{gholipour-ghalandari-etal-2020-large, 125 | title = "A Large-Scale Multi-Document Summarization Dataset from the {W}ikipedia Current Events Portal", 126 | author = "Gholipour Ghalandari, Demian and 127 | Hokamp, Chris and 128 | Pham, Nghia The and 129 | Glover, John and 130 | Ifrim, Georgiana", 131 | booktitle = "Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics", 132 | month = jul, 133 | year = "2020", 134 | address = "Online", 135 | publisher = "Association for Computational Linguistics", 136 | url = "https://www.aclweb.org/anthology/2020.acl-main.120", 137 | pages = "1302--1308", 138 | } 139 | ``` 140 | -------------------------------------------------------------------------------- /acl20-paper.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/complementizer/wcep-mds-dataset/8aa8418d55a4e6710277bed18bfc5349b50a5d96/acl20-paper.pdf -------------------------------------------------------------------------------- /acl20-slides.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/complementizer/wcep-mds-dataset/8aa8418d55a4e6710277bed18bfc5349b50a5d96/acl20-slides.pdf -------------------------------------------------------------------------------- /dataset_generation/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/complementizer/wcep-mds-dataset/8aa8418d55a4e6710277bed18bfc5349b50a5d96/dataset_generation/README.md -------------------------------------------------------------------------------- /dataset_generation/step1_store_wcep_html.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import argparse 3 | import pathlib 4 | from bs4 import BeautifulSoup 5 | 6 | ROOT_URL = 'https://en.wikipedia.org/wiki/Portal:Current_events' 7 | 8 | 9 | def extract_month_urls(): 10 | html = requests.get(ROOT_URL).text 11 | soup = BeautifulSoup(html, 'html.parser') 12 | e = 
soup.find('div', class_='NavContent hlist') 13 | urls = [x['href'] for x in e.find_all('a')] 14 | urls = [url for url in urls if url.count('/') == 3] 15 | urls = ['https://en.wikipedia.org' + url for url in urls] 16 | return urls 17 | 18 | 19 | def main(args): 20 | out_dir = pathlib.Path(args.o) 21 | if not out_dir.exists(): 22 | out_dir.mkdir() 23 | 24 | month_urls = extract_month_urls() 25 | print(f'Storing {len(month_urls)} WCEP month pages:') 26 | 27 | for url in month_urls: 28 | print(url) 29 | fname = url.split('/')[-1] + '.html' 30 | html = requests.get(url).text 31 | fpath = out_dir / fname 32 | with open(fpath, 'w') as f: 33 | f.write(html) 34 | 35 | 36 | if __name__ == '__main__': 37 | parser = argparse.ArgumentParser() 38 | parser.add_argument('--o', type=str, help='output directory', required=True) 39 | main(parser.parse_args()) 40 | -------------------------------------------------------------------------------- /dataset_generation/step2_process_wcep_html.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import calendar 4 | import pathlib 5 | import collections 6 | import arrow 7 | import json 8 | import uuid 9 | from bs4 import BeautifulSoup 10 | 11 | 12 | def make_month_to_int(): 13 | month_to_int = {} 14 | for i, month in enumerate(calendar.month_name): 15 | if i > 0: 16 | month_to_int[month] = i 17 | return month_to_int 18 | 19 | 20 | EVENTS = [] 21 | TOPIC_TO_SUB = collections.defaultdict(set) 22 | TOPIC_TO_SUPER = collections.defaultdict(set) 23 | EVENT_ID_COUNTER = 0 24 | MONTH_TO_INT = make_month_to_int() 25 | 26 | 27 | class Event: 28 | def __init__(self, text, id, date, category=None, stories=None, 29 | wiki_links=None, references=None): 30 | 31 | # print(f'[{date}] {text}') 32 | # print(stories) 33 | # for url in references: 34 | # print(url) 35 | # print() 36 | self.text = text 37 | self.id = id 38 | self.date = date 39 | self.category = category 40 | self.stories = stories if stories else [] 41 | self.wiki_links = wiki_links if wiki_links else [] 42 | self.references = references if references else [] 43 | 44 | def to_json_dict(self): 45 | return { 46 | 'text': self.text, 47 | 'id': self.id, 48 | 'date': str(self.date), 49 | 'category': self.category, 50 | 'stories': self.stories, 51 | 'wiki_links': self.wiki_links, 52 | 'references': self.references, 53 | } 54 | 55 | 56 | def url_to_time(url, month_to_num): 57 | tail = url.split('/')[-1] 58 | month, year = tail.split('_') 59 | m = month_to_num[month] 60 | y = int(year) 61 | return datetime.datetime(year=y, month=m, day=1) 62 | 63 | 64 | def extract_date(date_div): 65 | date = date_div.find('span', class_='summary') 66 | date = date.text.split('(')[1].split(')')[0] 67 | date = arrow.get(date) 68 | date = datetime.date(date.year, date.month, date.day) 69 | return date 70 | 71 | 72 | def wiki_link_to_id(s): 73 | return s.split('/wiki/')[1] 74 | 75 | 76 | def recursively_extract_bullets(e, 77 | date, 78 | category, 79 | prev_stories, 80 | is_root=False): 81 | global EVENT_ID_COUNTER 82 | if is_root: 83 | lis = e.find_all('li', recursive=False) 84 | result = [recursively_extract_bullets(li, date, category, []) 85 | for li in lis] 86 | return result 87 | else: 88 | ul = e.find('ul') 89 | if ul: 90 | # intermediate "node", e.g. 
a story an event is assigned to 91 | 92 | links = e.find_all('a', recursive=False) 93 | new_stories = [] 94 | for link in links: 95 | try: 96 | new_stories.append(wiki_link_to_id(link.get('href'))) 97 | except: 98 | print("not a wiki link:", link) 99 | lis = ul.find_all('li', recursive=False) 100 | 101 | for prev_story in prev_stories: 102 | for new_story in new_stories: 103 | TOPIC_TO_SUB[prev_story].add(new_story) 104 | TOPIC_TO_SUPER[new_story].add(prev_story) 105 | 106 | stories = prev_stories + new_stories 107 | for li in lis: 108 | recursively_extract_bullets(li, date, category, stories) 109 | 110 | else: 111 | # reached the "leaf", i.e. event summary 112 | text = e.text 113 | wiki_links = [] 114 | references = [] 115 | for link in e.find_all('a'): 116 | url = link.get('href') 117 | if link.get('rel') == ['nofollow']: 118 | references.append(url) 119 | elif url.startswith('/wiki'): 120 | wiki_links.append(url) 121 | event = Event(text=text, id=EVENT_ID_COUNTER, date=date, 122 | category=category, stories=prev_stories, 123 | wiki_links=wiki_links, references=references) 124 | EVENTS.append(event) 125 | EVENT_ID_COUNTER += 1 126 | 127 | 128 | def process_month_page_2004_to_2017(html): 129 | soup = BeautifulSoup(html, 'html.parser') 130 | days = soup.find_all('table', class_='vevent') 131 | for day in days: 132 | date = extract_date(day) 133 | #print('DATE:', date) 134 | category = None 135 | desc = day.find('td', class_='description') 136 | for e in desc.children: 137 | if e.name == 'dl': 138 | category = e.text 139 | elif e.name == 'ul': 140 | recursively_extract_bullets(e, date, category, [], is_root=True) 141 | 142 | 143 | def process_month_page_from_2018(html): 144 | soup = BeautifulSoup(html, 'html.parser') 145 | days = soup.find_all('div', class_='vevent') 146 | for day in days: 147 | date = extract_date(day) 148 | #print('DATE:', date) 149 | category = None 150 | desc = day.find('div', class_='description') 151 | for e in desc.children: 152 | if e.name == 'div' and e.get('role') == 'heading': 153 | category = e.text 154 | #print('-'*25, 'CATEGORY:', category, '-'*25, '\n') 155 | elif e.name == 'ul': 156 | recursively_extract_bullets(e, date, category, [], is_root=True) 157 | 158 | 159 | def file_to_date(path): 160 | fname = str(path.name) 161 | month, year = fname.split('.')[0].split('_') 162 | month = MONTH_TO_INT[month] 163 | year = int(year) 164 | date = datetime.date(year, month, 1) 165 | return date 166 | 167 | 168 | def main(args): 169 | in_dir = pathlib.Path(args.i) 170 | for fpath in sorted(in_dir.iterdir(), key=file_to_date): 171 | fname = fpath.name 172 | 173 | with open(fpath) as f: 174 | html = f.read() 175 | 176 | year = int(fname.split('.')[0].split('_')[1]) 177 | 178 | if 2004 <= year < 2018: 179 | 180 | print(fname) 181 | process_month_page_2004_to_2017(html) 182 | 183 | elif 2018 <= year : 184 | print(fname) 185 | process_month_page_from_2018(html) 186 | 187 | EVENTS.sort(key=lambda x: x.date) 188 | 189 | with open(args.o, 'w') as f: 190 | for e in EVENTS: 191 | e_json = json.dumps(e.to_json_dict()) 192 | f.write(e_json + '\n') 193 | 194 | 195 | if __name__ == '__main__': 196 | parser = argparse.ArgumentParser() 197 | parser.add_argument('--i', type=str, help='input directory', required=True) 198 | parser.add_argument('--o', type=str, help='output file', required=True) 199 | main(parser.parse_args()) -------------------------------------------------------------------------------- /dataset_generation/step3_snapshot_source_urls.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import savepagenow 3 | import json 4 | import os 5 | import random 6 | import time 7 | from requests.exceptions import ConnectionError 8 | 9 | 10 | def read_jsonl(path): 11 | with open(path) as f: 12 | for line in f: 13 | yield json.loads(line) 14 | 15 | 16 | def write_jsonl(items, path, batch_size=100, override=True): 17 | if override: 18 | with open(path, 'w'): 19 | pass 20 | 21 | batch = [] 22 | for i, x in enumerate(items): 23 | if i > 0 and i % batch_size == 0: 24 | with open(path, 'a') as f: 25 | output = '\n'.join(batch) + '\n' 26 | f.write(output) 27 | batch = [] 28 | raw = json.dumps(x) 29 | batch.append(raw) 30 | 31 | if batch: 32 | with open(path, 'a') as f: 33 | output = '\n'.join(batch) + '\n' 34 | f.write(output) 35 | 36 | 37 | def main(args): 38 | n_done = 0 39 | n_captured = 0 40 | n_success = 0 41 | done_url_set = set() 42 | if not args.override and os.path.exists(args.o): 43 | with open(args.o) as f: 44 | for line in f: 45 | url, archive_url = line.split() 46 | n_done += 1 47 | if archive_url != 'None': 48 | n_success += 1 49 | done_url_set.add(url) 50 | 51 | events = read_jsonl(args.i) 52 | urls = [url for e in events for url in e['references'] 53 | if url not in done_url_set] 54 | if args.shuffle: 55 | random.shuffle(urls) 56 | n_total = len(urls) + len(done_url_set) 57 | 58 | batch = [] 59 | for url in urls: 60 | 61 | repeat = True 62 | archive_url, captured = None, None 63 | while repeat: 64 | try: 65 | archive_url, captured = savepagenow.capture_or_cache(url) 66 | repeat = False 67 | if captured: 68 | n_captured += 1 69 | n_success += 1 70 | except Exception as e: 71 | if isinstance(e, ConnectionError): 72 | print('Too many requests, waiting a bit...') 73 | repeat = True 74 | else: 75 | repeat = False 76 | if repeat: 77 | time.sleep(60) 78 | else: 79 | time.sleep(1) 80 | 81 | if archive_url is not None: 82 | batch.append((url, archive_url, captured)) 83 | n_done += 1 84 | 85 | print(f'total: {n_total}, done: {n_done}, ' 86 | f'successful: {n_success}, captured: {n_captured}\n') 87 | 88 | if len(batch) < args.batchsize: 89 | lines = [f'{url} {archive_url}' for (url, archive_url, _) in batch] 90 | if len(lines) > 0: 91 | with open(args.o, 'a') as f: 92 | f.write('\n'.join(lines) + '\n') 93 | batch = [] 94 | 95 | 96 | if __name__ == '__main__': 97 | parser = argparse.ArgumentParser() 98 | parser.add_argument('--i', required=True) 99 | parser.add_argument('--o', required=True) 100 | parser.add_argument('--batchsize', type=int, default=20) 101 | parser.add_argument('--override', action='store_true') 102 | parser.add_argument('--shuffle', action='store_true') 103 | main(parser.parse_args()) 104 | -------------------------------------------------------------------------------- /dataset_generation/step4_scrape_sources.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import multiprocessing 3 | import json 4 | import os 5 | import time 6 | import pathlib 7 | import random 8 | import newspaper 9 | import json 10 | import numpy as np 11 | 12 | 13 | def scrape_article(url): 14 | a = newspaper.Article(url) 15 | error = None 16 | try: 17 | a.download() 18 | a.parse() 19 | 20 | if a.publish_date is None: 21 | time = None 22 | else: 23 | time = a.publish_date.isoformat() 24 | 25 | article = { 26 | 'time': time, 27 | 'title': a.title, 28 | 'text': a.text, 29 | 'url': url, 30 | 'state': 'successful', 31 | 'error': None, 32 
| } 33 | except Exception as e: 34 | print(e) 35 | article = { 36 | 'url': url, 37 | 'state': 'failed', 38 | 'error': str(e), 39 | } 40 | error = e 41 | return url, article, error 42 | 43 | 44 | def write_articles(articles, path): 45 | lines = [json.dumps(a) for a in articles] 46 | with open(path, 'a') as f: 47 | f.write('\n'.join(lines) + '\n') 48 | 49 | 50 | def batches(iterable, n=1): 51 | l = len(iterable) 52 | for ndx in range(0, l, n): 53 | yield iterable[ndx:min(ndx + n, l)] 54 | 55 | 56 | def load_urls(path): 57 | urls = [] 58 | with open(path) as f: 59 | for line in f: 60 | original_url, archive_url = line.split() 61 | if archive_url != 'None': 62 | urls.append(archive_url) 63 | return urls 64 | 65 | 66 | def main(args): 67 | outpath = pathlib.Path(args.o) 68 | done_urls = set() 69 | failed_urls = [] 70 | n_success = 0 71 | 72 | if args.override and outpath.exists(): 73 | outpath.unlink() 74 | 75 | elif outpath.exists(): 76 | with open(outpath) as f: 77 | for line in f: 78 | a = json.loads(line) 79 | url = a['url'] 80 | if a['state'] == 'successful': 81 | n_success += 1 82 | else: 83 | failed_urls.append(url) 84 | done_urls.add(url) 85 | 86 | 87 | urls = load_urls(args.i) 88 | 89 | if args.repeat_failed: 90 | todo_urls = failed_urls + [url for url in urls if url not in done_urls] 91 | else: 92 | todo_urls = [url for url in urls if url not in done_urls] 93 | if args.shuffle: 94 | random.shuffle(todo_urls) 95 | 96 | n_done = len(done_urls) 97 | n_total = len(urls) 98 | durations = [] 99 | t1 = time.time() 100 | 101 | for url_batch in batches(todo_urls, args.batchsize): 102 | 103 | pool = multiprocessing.Pool(processes=args.jobs) 104 | output = pool.map(scrape_article, url_batch) 105 | pool.close() 106 | 107 | articles = [] 108 | for url, a, error in output: 109 | if a['state'] == 'successful': 110 | n_success += 1 111 | articles.append(a) 112 | n_done += 1 113 | done_urls.add(url) 114 | 115 | if articles: 116 | write_articles(articles, outpath) 117 | 118 | t2 = time.time() 119 | elapsed = t2 - t1 120 | durations.append(elapsed) 121 | t1 = t2 122 | 123 | 124 | 125 | 126 | print(f'{n_done}/{n_total} done') 127 | print(f'total: {n_total}, done: {n_done}, successful: {n_success}') 128 | 129 | print('TIME (seconds):') 130 | print('last batch:', elapsed) 131 | print('last 5:', np.mean(durations[-10:])) 132 | print('overall 5:', np.mean(durations)) 133 | print() 134 | 135 | 136 | if __name__ == '__main__': 137 | parser = argparse.ArgumentParser() 138 | parser.add_argument('--i', required=True) 139 | parser.add_argument('--o', required=True) 140 | parser.add_argument('--batchsize', type=int, default=20) 141 | parser.add_argument('--jobs', type=int, default=2) 142 | parser.add_argument('--override', action='store_true') 143 | parser.add_argument('--shuffle', action='store_true') 144 | parser.add_argument('--repeat-failed', action='store_true') 145 | main(parser.parse_args()) -------------------------------------------------------------------------------- /dataset_generation/step5_combine_dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from general import utils 3 | 4 | 5 | def load_urls(path): 6 | url_to_arc = {} 7 | arc_to_url = {} 8 | with open(path) as f: 9 | for line in f: 10 | parts = line.split() 11 | if len(parts) == 2: 12 | url, arc_url = parts 13 | url_to_arc[url] = arc_url 14 | arc_to_url[arc_url] = url 15 | return url_to_arc, arc_to_url 16 | 17 | 18 | def main(args): 19 | 20 | articles = 
list(utils.read_jsonl(args.articles)) 21 | events = list(utils.read_jsonl(args.events)) 22 | 23 | url_to_arc, arc_to_url = load_urls(args.urls) 24 | 25 | url_to_article = {} 26 | for a in articles: 27 | arc_url = a['url'] 28 | if arc_url in arc_to_url: 29 | url = arc_to_url[arc_url] 30 | url_to_article[url] = a 31 | a['archive_url'] = arc_url 32 | a['url'] = url 33 | 34 | new_events = [] 35 | for e in events: 36 | 37 | e_urls = e['references'] 38 | e_articles = [url_to_article[url] 39 | for url in e_urls if url in url_to_article] 40 | e['articles'] = e_articles 41 | 42 | if len(e_articles) > 0: 43 | new_events.append(e) 44 | 45 | print('original events:', len(events)) 46 | print('new events:', len(new_events)) 47 | utils.write_jsonl(new_events, args.o) 48 | 49 | 50 | if __name__ == '__main__': 51 | parser = argparse.ArgumentParser() 52 | parser.add_argument('--articles', required=True) 53 | parser.add_argument('--events', required=True) 54 | parser.add_argument('--urls', required=True) 55 | parser.add_argument('--o', required=True) 56 | main(parser.parse_args()) 57 | -------------------------------------------------------------------------------- /dataset_reproduction/combine_and_split.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import pathlib 4 | import shutil 5 | import utils 6 | from collections import defaultdict 7 | 8 | 9 | def get_article_to_cluster_mappings(clusters): 10 | url_to_cluster_idxs = defaultdict(list) 11 | id_to_cluster_idx = {} 12 | for i, c in enumerate(clusters): 13 | for a in c['wcep_articles']: 14 | url_to_cluster_idxs[a['archive_url']].append(i) 15 | for a in c['cc_articles']: 16 | id_to_cluster_idx[a['id']] = i 17 | return url_to_cluster_idxs, id_to_cluster_idx 18 | 19 | 20 | def add_wcep_articles_to_clusters(wcep_path, url_to_cluster_idxs, clusters): 21 | print('adding articles from WCEP to clusters') 22 | for a in utils.read_jsonl(wcep_path): 23 | for i in url_to_cluster_idxs[a['archive_url']]: 24 | c = clusters[i] 25 | c.setdefault('wcep_articles_filled', []) 26 | c['wcep_articles_filled'].append(a) 27 | 28 | 29 | def add_cc_articles_to_clusters(clusters, cc_path, id_to_cluster_idx, tmp_clusters_path): 30 | print('adding articles from CommonCrawl to clusters') 31 | n_clusters = len(clusters) 32 | n_clusters_done = 0 33 | for i, a in enumerate(utils.read_jsonl(cc_path)): 34 | if i % 10000 == 0: 35 | print(f'{i} cc articles done, {n_clusters_done}/{n_clusters} clusters done') 36 | cluster_idx = id_to_cluster_idx[a['id']] 37 | c = clusters[cluster_idx] 38 | 39 | if c is not None: 40 | c['cc_articles_filled'].append(a) 41 | c['cc_ids_filled'].add(a['id']) 42 | if c['cc_ids'] == c['cc_ids_filled']: 43 | del c['cc_ids'], c['cc_ids_filled'] 44 | utils.write_jsonl([c], tmp_clusters_path, mode='a') 45 | clusters[cluster_idx] = None 46 | n_clusters_done += 1 47 | 48 | # remaining few clusters that only have WCEP but not CC articles 49 | for c in clusters: 50 | if c is not None and c['cc_ids'] == c['cc_ids_filled']: 51 | print("Hmm") 52 | del c['cc_ids'], c['cc_ids_filled'] 53 | utils.write_jsonl([c], tmp_clusters_path, mode='a') 54 | clusters[cluster_idx] = None 55 | n_clusters_done += 1 56 | 57 | print(f'{i} cc articles done, {n_clusters_done}/{n_clusters} clusters done') 58 | 59 | 60 | def split_dataset(outdir, tmp_clusters_path): 61 | print('splitting dataset into train/val/test...') 62 | for i, c in enumerate(utils.read_jsonl(tmp_clusters_path)): 63 | if i % 1000 == 0: 64 | print(i, 
'clusters done') 65 | outpath = outdir / (c['collection'] + '.jsonl') 66 | utils.write_jsonl([c], outpath, mode='a') 67 | 68 | 69 | def cleanup_clusters(path, tmp_path): 70 | print('cleaning up:', path.name) 71 | for i, c in enumerate(utils.read_jsonl(path)): 72 | if i % 1000 == 0: 73 | print(i, 'clusters done') 74 | articles = [] 75 | if 'wcep_articles_filled' in c: 76 | for a in c['wcep_articles_filled']: 77 | a['origin'] = 'WCEP' 78 | articles.append(a) 79 | if 'cc_articles_filled' in c: 80 | for a in c['cc_articles_filled']: 81 | a['origin'] = 'CommonCrawl' 82 | articles.append(a) 83 | 84 | c = { 85 | 'id': c['id'], 86 | 'date': c['date'], 87 | 'summary': c['summary'], 88 | 'articles': articles, 89 | 'collection': c['collection'], 90 | 'wiki_links': c['wiki_links'], 91 | 'reference_urls': c['reference_urls'], 92 | 'category': c['category'] 93 | } 94 | 95 | utils.write_jsonl([c], tmp_path, mode='a') 96 | 97 | shutil.move(tmp_path, path) 98 | 99 | 100 | def main(args): 101 | outdir = pathlib.Path(args.o) 102 | if outdir.exists(): 103 | shutil.rmtree(outdir) 104 | outdir.mkdir() 105 | tmp_clusters_path = outdir / 'tmp_clusters.jsonl' 106 | if tmp_clusters_path.exists(): 107 | tmp_clusters_path.unlink() 108 | 109 | # get article -> cluster mappings 110 | clusters = list(utils.read_jsonl(args.dataset)) 111 | for c in clusters: 112 | if args.max_cluster_size != -1: 113 | l = args.max_cluster_size - len(c['wcep_articles']) 114 | c['cc_articles'] = c['cc_articles'][:l] 115 | 116 | c['cc_ids'] = set([a['id'] for a in c['cc_articles']]) 117 | c['cc_ids_filled'] = set() 118 | c['cc_articles_filled'] = [] 119 | 120 | url_to_cluster_idxs, id_to_cluster_idx = get_article_to_cluster_mappings( 121 | clusters 122 | ) 123 | 124 | # add articles from WCEP to clusters, using URLs 125 | add_wcep_articles_to_clusters( 126 | args.wcep_articles, url_to_cluster_idxs, clusters 127 | ) 128 | 129 | # add articles from CommonCrawl to clusters, using IDs 130 | add_cc_articles_to_clusters( 131 | clusters, args.cc_articles, id_to_cluster_idx, tmp_clusters_path 132 | ) 133 | 134 | # split clusters into separate train/val/test files 135 | split_dataset(outdir, tmp_clusters_path) 136 | tmp_clusters_path.unlink() 137 | 138 | for fn in ['train.jsonl', 'val.jsonl', 'test.jsonl']: 139 | cleanup_clusters(outdir / fn, tmp_clusters_path) 140 | 141 | 142 | 143 | if __name__ == '__main__': 144 | parser = argparse.ArgumentParser() 145 | parser.add_argument('--dataset', required=True) 146 | parser.add_argument('--wcep-articles', required=True) 147 | parser.add_argument('--cc-articles', required=True) 148 | parser.add_argument('--max-cluster-size', type=int, default=-1) 149 | parser.add_argument('--o', required=True) 150 | main(parser.parse_args()) 151 | -------------------------------------------------------------------------------- /dataset_reproduction/extract_cc_articles.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pathlib 3 | import logging 4 | import json 5 | import subprocess 6 | import multiprocessing 7 | import newspaper 8 | import sys 9 | import time 10 | import utils 11 | from warcio.archiveiterator import ArchiveIterator 12 | 13 | 14 | def read_warc_gz(path): 15 | with open(path, 'rb') as f: 16 | for record in ArchiveIterator(f): 17 | # records are queries followed by response, we only need response 18 | if record.content_type == 'application/http; msgtype=response': 19 | yield record 20 | 21 | 22 | def get_record_id(record): 23 | id = 
record.rec_headers.get_header('WARC-Record-ID') 24 | id = id.split('uuid:')[1].split('>')[0] 25 | return id 26 | 27 | 28 | def get_record_url(record): 29 | return record.rec_headers.get_header('WARC-Target-URI') 30 | 31 | 32 | def download_cc_file(cc_path, local_cc_path): 33 | aws_path = f's3://commoncrawl/{cc_path}' 34 | cmd = f'aws s3 cp {aws_path} {local_cc_path} --no-sign-request' 35 | logging.debug(cmd) 36 | cmd = cmd.split() 37 | while not local_cc_path.exists(): 38 | p = subprocess.Popen(cmd) 39 | try: 40 | p.wait() 41 | except KeyboardInterrupt: 42 | p.terminate() 43 | if local_cc_path.exists(): 44 | break 45 | logging.info(f'file download failed, retrying: {cc_path}') 46 | time.sleep(5) 47 | 48 | 49 | def read_article_ids(path, max_cluster_size): 50 | id_to_collection = {} 51 | ids = set() 52 | for cluster in utils.read_jsonl(path): 53 | articles = cluster['cc_articles'] 54 | if max_cluster_size != -1: 55 | l = max_cluster_size - len(cluster['wcep_articles']) 56 | articles = articles[:l] 57 | for a in articles: 58 | ids.add(a['id']) 59 | id_to_collection[a['id']] = cluster['collection'] 60 | return ids, id_to_collection 61 | 62 | 63 | def extract_article(item): 64 | html = item['html'] 65 | extracted = newspaper.Article(item['url']) 66 | try: 67 | extracted.download(input_html=html) 68 | extracted.parse() 69 | 70 | if extracted.publish_date is None: 71 | time = None 72 | else: 73 | time = extracted.publish_date.isoformat() 74 | 75 | article = { 76 | 'id': item['id'], 77 | 'cc_file': item['cc_file'], 78 | 'time': time, 79 | 'title': extracted.title, 80 | 'text': extracted.text, 81 | 'url': item['url'], 82 | 'collection': item['collection'], 83 | } 84 | 85 | except Exception as e: 86 | logging.error(f'record-id: {item["id"]}, error:{e}') 87 | article = None 88 | return article 89 | 90 | 91 | def process_batch(items, out_path, jobs): 92 | logging.debug('extracting articles...') 93 | pool = multiprocessing.Pool(processes=jobs) 94 | try: 95 | articles = pool.map(extract_article, items) 96 | articles = [a for a in articles if a is not None] 97 | pool.close() 98 | logging.debug('extracting articles done') 99 | except KeyboardInterrupt: 100 | pool.terminate() 101 | sys.exit() 102 | utils.write_jsonl(articles, out_path, mode='a') 103 | new_record_ids = [x['id'] for x in items] 104 | logging.info(f'done-record-ids:{" ".join(new_record_ids)}') 105 | return articles 106 | 107 | 108 | def parse_logged_record_ids(line): 109 | ids = line.split('done-cc-ids:')[1] 110 | ids = ids.split() 111 | return set(ids) 112 | 113 | 114 | def parse_logged_cc_file(line): 115 | return line.split('done-cc-file:')[1].strip() 116 | 117 | 118 | def read_log(path): 119 | done_cc_files = set() 120 | done_record_ids = set() 121 | with open(path) as f: 122 | for line in f: 123 | if 'done-cc-file' in line: 124 | done_cc_files.add(parse_logged_cc_file(line)) 125 | elif 'done-cc-ids' in line: 126 | done_record_ids |= parse_logged_record_ids(line) 127 | return done_cc_files, done_record_ids 128 | 129 | 130 | def mute_other_loggers(): 131 | logging.getLogger('urllib3').setLevel(logging.WARNING) 132 | logging.getLogger('PIL').setLevel(logging.WARNING) 133 | logging.getLogger('newspaper').setLevel(logging.WARNING) 134 | logging.getLogger('chardet.charsetprober').setLevel(logging.WARNING) 135 | 136 | 137 | def main(args): 138 | storage = pathlib.Path(args.storage) 139 | logpath = storage / 'log.txt' 140 | cc_files_path = storage / 'cc_files.txt' 141 | out_path = storage / 'cc_articles.jsonl' 142 | 143 | if not 
storage.exists(): 144 | storage.mkdir() 145 | 146 | if args.override and out_path.exists(): 147 | out_path.unlink() 148 | if args.override and logpath.exists(): 149 | logpath.unlink() 150 | 151 | logging.basicConfig( 152 | level=logging.DEBUG, 153 | filename=logpath, 154 | filemode=('w' if args.override else 'a'), 155 | format='%(asctime)s %(levelname)-8s %(message)s' 156 | ) 157 | logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) 158 | mute_other_loggers() 159 | 160 | if logpath.exists(): 161 | done_cc_files, done_record_ids = read_log(logpath) 162 | else: 163 | done_cc_files, done_record_ids = set(), set() 164 | 165 | cc_files = list(utils.read_lines(cc_files_path)) 166 | todo_record_ids, id_to_collection = read_article_ids( 167 | args.dataset, args.max_cluster_size) 168 | n_files = len(cc_files) 169 | 170 | for i, cc_file in enumerate(cc_files): 171 | if cc_file in done_cc_files: 172 | continue 173 | 174 | logging.debug(f'file {i+1}/{n_files}') 175 | 176 | local_cc_path = storage / cc_file.split('/')[-1] 177 | if not local_cc_path.exists(): 178 | download_cc_file(cc_file, local_cc_path) 179 | 180 | batch = [] 181 | n_found_articles = 0 182 | for i, record in enumerate(read_warc_gz(local_cc_path)): 183 | if i % 10000 == 0: 184 | logging.debug( 185 | f'{i} records checked, {n_found_articles} articles found') 186 | id = get_record_id(record) 187 | if id in todo_record_ids: 188 | n_found_articles += 1 189 | item = { 190 | 'id': id, 191 | 'html': record.content_stream().read(), 192 | 'url': get_record_url(record), 193 | 'collection': id_to_collection[id], 194 | 'cc_file': cc_file 195 | } 196 | batch.append(item) 197 | 198 | if len(batch) >= args.batchsize: 199 | process_batch(batch, out_path, args.jobs) 200 | batch = [] 201 | 202 | if batch: 203 | process_batch(batch, out_path, args.jobs) 204 | 205 | logging.info(f'done-cc-file:{cc_file}') 206 | local_cc_path.unlink() 207 | 208 | 209 | if __name__ == '__main__': 210 | parser = argparse.ArgumentParser() 211 | parser.add_argument('--dataset', required=True) 212 | parser.add_argument('--storage', required=True) 213 | parser.add_argument('--override', action='store_true') 214 | parser.add_argument('--max-cluster-size', type=int, default=-1) 215 | parser.add_argument('--batchsize', type=int, default=1000) 216 | parser.add_argument('--jobs', type=int, default=4) 217 | main(parser.parse_args()) 218 | -------------------------------------------------------------------------------- /dataset_reproduction/extract_wcep_articles.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import multiprocessing 3 | import time 4 | import pathlib 5 | import random 6 | import newspaper 7 | import json 8 | import numpy as np 9 | import utils 10 | 11 | 12 | def extract_article(todo_article): 13 | url = todo_article['archive_url'] 14 | extracted = newspaper.Article(url) 15 | try: 16 | extracted.download() 17 | extracted.parse() 18 | 19 | if extracted.publish_date is None: 20 | time = None 21 | else: 22 | time = extracted.publish_date.isoformat() 23 | 24 | article = { 25 | 'time': time, 26 | 'title': extracted.title, 27 | 'text': extracted.text, 28 | 'url': todo_article['url'], 29 | 'archive_url': url, 30 | 'collection': todo_article['collection'], 31 | 'state': 'successful', 32 | 'error': None, 33 | } 34 | 35 | except Exception as e: 36 | print(e) 37 | article = { 38 | 'archive_url': url, 39 | 'state': 'failed', 40 | 'error': str(e), 41 | } 42 | 43 | return article 44 | 45 | 46 | def 
batches(iterable, n=1): 47 | l = len(iterable) 48 | for i in range(0, l, n): 49 | yield iterable[i:min(i + n, l)] 50 | 51 | 52 | def read_input(path): 53 | articles = [] 54 | with open(path) as f: 55 | for line in f: 56 | cluster = json.loads(line) 57 | for a in cluster['wcep_articles']: 58 | a['collection'] = cluster['collection'] 59 | articles.append(a) 60 | return articles 61 | 62 | 63 | def main(args): 64 | 65 | outpath = pathlib.Path(args.o) 66 | done_urls = set() 67 | failed_articles = [] 68 | n_done = 0 69 | n_success = 0 70 | 71 | if args.override and outpath.exists(): 72 | outpath.unlink() 73 | 74 | elif outpath.exists(): 75 | with open(outpath) as f: 76 | for line in f: 77 | a = json.loads(line) 78 | url = a['archive_url'] 79 | if a['state'] == 'successful': 80 | n_success += 1 81 | else: 82 | failed_articles.append(a) 83 | n_done += 1 84 | done_urls.add(url) 85 | 86 | todo_articles = read_input(args.i) 87 | n_total = len(todo_articles) 88 | todo_articles = [a for a in todo_articles if a['archive_url'] 89 | not in done_urls] 90 | 91 | print('failed articles from last run:', len(failed_articles)) 92 | print('articles todo:', len(todo_articles)) 93 | 94 | 95 | if args.repeat_failed: 96 | todo_articles = failed_articles + todo_articles 97 | 98 | if args.shuffle: 99 | random.shuffle(todo_articles) 100 | 101 | durations = [] 102 | t1 = time.time() 103 | for todo_batch in batches(todo_articles, args.batchsize): 104 | 105 | pool = multiprocessing.Pool(processes=args.jobs) 106 | output = pool.map(extract_article, todo_batch) 107 | pool.close() 108 | 109 | articles = [] 110 | for a in output: 111 | if a['state'] == 'successful': 112 | n_success += 1 113 | articles.append(a) 114 | done_urls.add(a['archive_url']) 115 | n_done += 1 116 | 117 | if articles: 118 | utils.write_jsonl(articles, outpath, mode='a') 119 | 120 | t2 = time.time() 121 | elapsed = t2 - t1 122 | durations.append(elapsed) 123 | t1 = t2 124 | 125 | print(f'{n_done}/{n_total} done, {n_success}/{n_done} successful') 126 | print('Average per-batch time (seconds):') 127 | print('last batch:', elapsed) 128 | print('last 10:', np.mean(durations[-10:])) 129 | print('overall:', np.mean(durations)) 130 | print() 131 | 132 | 133 | if __name__ == '__main__': 134 | parser = argparse.ArgumentParser() 135 | parser.add_argument('--i', required=True) 136 | parser.add_argument('--o', required=True) 137 | parser.add_argument('--batchsize', type=int, default=20) 138 | parser.add_argument('--jobs', type=int, default=2) 139 | parser.add_argument('--override', action='store_true') 140 | parser.add_argument('--shuffle', action='store_true') 141 | parser.add_argument('--repeat-failed', action='store_true') 142 | main(parser.parse_args()) 143 | -------------------------------------------------------------------------------- /dataset_reproduction/requirements.txt: -------------------------------------------------------------------------------- 1 | warcio==1.7.1 2 | newspaper3k==0.2.8 3 | -------------------------------------------------------------------------------- /dataset_reproduction/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | 4 | def read_lines(path): 5 | with open(path) as f: 6 | for line in f: 7 | yield line.strip() 8 | 9 | 10 | def read_jsonl(path): 11 | with open(path) as f: 12 | for line in f: 13 | yield json.loads(line) 14 | 15 | 16 | def write_jsonl(items, path, mode='a'): 17 | assert mode in ['w', 'a'] 18 | lines = [json.dumps(x) for x in items] 19 | with 
open(path, mode) as f: 20 | f.write('\n'.join(lines) + '\n') 21 | 22 | -------------------------------------------------------------------------------- /experiments/baselines.py: -------------------------------------------------------------------------------- 1 | import utils 2 | import random 3 | import collections 4 | import numpy as np 5 | import networkx as nx 6 | import warnings 7 | from sklearn.feature_extraction.text import TfidfVectorizer 8 | from sklearn.metrics.pairwise import cosine_similarity 9 | from sklearn.cluster import MiniBatchKMeans 10 | from summarizer import Summarizer 11 | 12 | 13 | warnings.filterwarnings('ignore', category=RuntimeWarning) 14 | random.seed(24) 15 | 16 | 17 | class RandomBaseline(Summarizer): 18 | 19 | def summarize(self, 20 | articles, 21 | max_len=40, 22 | len_type='words', 23 | in_titles=False, 24 | out_titles=False, 25 | min_sent_tokens=7, 26 | max_sent_tokens=40): 27 | 28 | articles = self._preprocess(articles) 29 | sents = [s for a in articles for s in a.sents] 30 | if in_titles == False or out_titles == False: 31 | sents = [s for s in sents if not s.is_title] 32 | sents = self._deduplicate(sents) 33 | sent_lens = [self._sent_len(s, len_type) for s in sents] 34 | 35 | current_len = 0 36 | remaining = list(range(len(sents))) 37 | random.shuffle(remaining) 38 | 39 | selected = [] 40 | for i in remaining: 41 | new_len = current_len + sent_lens[i] 42 | if new_len <= max_len: 43 | if not (min_sent_tokens <= len( 44 | sents[i].words) <= max_sent_tokens): 45 | continue 46 | selected.append(i) 47 | current_len = new_len 48 | if current_len >= max_len: 49 | break 50 | 51 | summary_sents = [sents[i].text for i in selected] 52 | return ' '.join(summary_sents) 53 | 54 | 55 | class RandomLead(Summarizer): 56 | 57 | def summarize(self, 58 | articles, 59 | max_len=40, 60 | len_type='words', 61 | in_titles=False, 62 | out_titles=False, 63 | min_sent_tokens=7, 64 | max_sent_tokens=40): 65 | 66 | article_idxs = list(range(len(articles))) 67 | random.shuffle(article_idxs) 68 | summary = '' 69 | for i in article_idxs: 70 | a = articles[i] 71 | a = self._preprocess([a])[0] 72 | sents = a.sents 73 | if in_titles == False or out_titles == False: 74 | sents = [s for s in sents if not s.is_title] 75 | current_len = 0 76 | selected_sents = [] 77 | for s in sents: 78 | l = self._sent_len(s, len_type) 79 | new_len = current_len + l 80 | if new_len <= max_len: 81 | if not (min_sent_tokens <= len(s.words) <= max_sent_tokens): 82 | continue 83 | selected_sents.append(s.text) 84 | current_len = new_len 85 | if new_len > max_len: 86 | break 87 | if len(selected_sents) >= 1: 88 | summary = ' '.join(selected_sents) 89 | break 90 | return summary 91 | 92 | 93 | class TextRankSummarizer(Summarizer): 94 | def __init__(self, max_redundancy=0.5): 95 | self.max_redundancy = max_redundancy 96 | 97 | def _compute_page_rank(self, S): 98 | nodes = list(range(S.shape[0])) 99 | graph = nx.from_numpy_matrix(S) 100 | pagerank = nx.pagerank(graph, weight='weight') 101 | scores = [pagerank[i] for i in nodes] 102 | return scores 103 | 104 | def summarize(self, 105 | articles, 106 | max_len=40, 107 | len_type='words', 108 | in_titles=False, 109 | out_titles=False, 110 | min_sent_tokens=7, 111 | max_sent_tokens=40): 112 | 113 | articles = self._preprocess(articles) 114 | sents = [s for a in articles for s in a.sents] 115 | if in_titles == False: 116 | sents = [s for s in sents if not s.is_title] 117 | sents = self._deduplicate(sents) 118 | sent_lens = [self._sent_len(s, len_type) for s in sents] 
119 | raw_sents = [s.text for s in sents] 120 | 121 | vectorizer = TfidfVectorizer(lowercase=True, stop_words='english') 122 | X = vectorizer.fit_transform(raw_sents) 123 | S = cosine_similarity(X) 124 | 125 | scores = self._compute_page_rank(S) 126 | scored = sorted(enumerate(scores), key=lambda x: x[1], reverse=True) 127 | 128 | if not out_titles: 129 | scored = [(i, score) for (i, score) in scored 130 | if not sents[i].is_title] 131 | 132 | current_len = 0 133 | selected = [] 134 | for i, _ in scored: 135 | new_len = current_len + sent_lens[i] 136 | if new_len <= max_len: 137 | if self._is_redundant( 138 | sents, selected, i, self.max_redundancy): 139 | continue 140 | if not (min_sent_tokens <= len( 141 | sents[i].words) <= max_sent_tokens): 142 | continue 143 | 144 | selected.append(i) 145 | current_len = new_len 146 | 147 | summary_sents = [sents[i].text for i in selected] 148 | return ' '.join(summary_sents) 149 | 150 | 151 | class CentroidSummarizer(Summarizer): 152 | def __init__(self, max_redundancy=0.5): 153 | self.max_redundancy = max_redundancy 154 | 155 | def summarize(self, 156 | articles, 157 | max_len=40, 158 | len_type='words', 159 | in_titles=False, 160 | out_titles=False, 161 | min_sent_tokens=7, 162 | max_sent_tokens=40): 163 | 164 | articles = self._preprocess(articles) 165 | sents = [s for a in articles for s in a.sents] 166 | if in_titles == False: 167 | sents = [s for s in sents if not s.is_title] 168 | sents = self._deduplicate(sents) 169 | sent_lens = [self._sent_len(s, len_type) for s in sents] 170 | raw_sents = [s.text for s in sents] 171 | 172 | vectorizer = TfidfVectorizer(lowercase=True, stop_words='english') 173 | try: 174 | X = vectorizer.fit_transform(raw_sents) 175 | except: 176 | return '' 177 | 178 | centroid = X.mean(0) 179 | scores = cosine_similarity(X, centroid) 180 | scored = sorted(enumerate(scores), key=lambda x: x[1], reverse=True) 181 | 182 | if not out_titles: 183 | scored = [(i, score) for (i, score) in scored 184 | if not sents[i].is_title] 185 | 186 | current_len = 0 187 | selected = [] 188 | for i, _ in scored: 189 | new_len = current_len + sent_lens[i] 190 | if new_len <= max_len: 191 | if self._is_redundant( 192 | sents, selected, i, self.max_redundancy): 193 | continue 194 | if not (min_sent_tokens <= len( 195 | sents[i].words) <= max_sent_tokens): 196 | continue 197 | 198 | selected.append(i) 199 | current_len = new_len 200 | 201 | summary_sents = [sents[i].text for i in selected] 202 | return ' '.join(summary_sents) 203 | 204 | 205 | class SubmodularSummarizer(Summarizer): 206 | """ 207 | Selects a combination of sentences as a summary by greedily optimizing 208 | a submodular function, in this case two functions representing 209 | coverage and diversity of the sentence combination. 
210 | """ 211 | def __init__(self, a=5, div_weight=6, cluster_factor=0.2): 212 | self.a = a 213 | self.div_weight = div_weight 214 | self.cluster_factor = cluster_factor 215 | 216 | def cluster_sentences(self, X): 217 | n = X.shape[0] 218 | n_clusters = round(self.cluster_factor * n) 219 | if n_clusters <= 1 or n <= 2: 220 | return dict((i, 1) for i in range(n)) 221 | clusterer = MiniBatchKMeans( 222 | n_clusters=n_clusters, 223 | init_size=3 * n_clusters 224 | ) 225 | labels = clusterer.fit_predict(X) 226 | i_to_label = dict((i, l) for i, l in enumerate(labels)) 227 | return i_to_label 228 | 229 | def compute_summary_coverage(self, 230 | alpha, 231 | summary_indices, 232 | sent_coverages, 233 | pairwise_sims): 234 | cov = 0 235 | for i, i_generic_cov in enumerate(sent_coverages): 236 | i_summary_cov = sum([pairwise_sims[i, j] for j in summary_indices]) 237 | i_cov = min(i_summary_cov, alpha * i_generic_cov) 238 | cov += i_cov 239 | return cov 240 | 241 | def compute_summary_diversity(self, 242 | summary_indices, 243 | ix_to_label, 244 | avg_sent_sims): 245 | 246 | cluster_to_ixs = collections.defaultdict(list) 247 | for i in summary_indices: 248 | l = ix_to_label[i] 249 | cluster_to_ixs[l].append(i) 250 | div = 0 251 | for l, l_indices in cluster_to_ixs.items(): 252 | cluster_score = sum([avg_sent_sims[i] for i in l_indices]) 253 | cluster_score = np.sqrt(cluster_score) 254 | div += cluster_score 255 | return div 256 | 257 | def optimize(self, 258 | sents, 259 | max_len, 260 | len_type, 261 | ix_to_label, 262 | pairwise_sims, 263 | sent_coverages, 264 | avg_sent_sims, 265 | out_titles, 266 | min_sent_tokens, 267 | max_sent_tokens): 268 | 269 | alpha = self.a / len(sents) 270 | sent_lens = [self._sent_len(s, len_type) for s in sents] 271 | current_len = 0 272 | remaining = set(range(len(sents))) 273 | 274 | for i, s in enumerate(sents): 275 | bad_length = not (min_sent_tokens <= len(sents[i].words) 276 | <= max_sent_tokens) 277 | if bad_length: 278 | remaining.remove(i) 279 | elif out_titles == False and s.is_title: 280 | remaining.remove(i) 281 | 282 | selected = [] 283 | scored_selections = [] 284 | 285 | while current_len < max_len and len(remaining) > 0: 286 | scored = [] 287 | for i in remaining: 288 | new_len = current_len + sent_lens[i] 289 | if new_len <= max_len: 290 | summary_indices = selected + [i] 291 | cov = self.compute_summary_coverage( 292 | alpha, summary_indices, sent_coverages, pairwise_sims) 293 | div = self.compute_summary_diversity( 294 | summary_indices, ix_to_label, avg_sent_sims) 295 | score = cov + self.div_weight * div 296 | scored.append((i, score)) 297 | 298 | if len(scored) == 0: 299 | break 300 | scored.sort(key=lambda x: x[1], reverse=True) 301 | best_idx, best_score = scored[0] 302 | scored_selections.append((selected + [best_idx], best_score)) 303 | current_len += sent_lens[best_idx] 304 | selected.append(best_idx) 305 | remaining.remove(best_idx) 306 | 307 | scored_selections.sort(key=lambda x: x[1], reverse=True) 308 | best_selection = scored_selections[0][0] 309 | return best_selection 310 | 311 | def summarize(self, 312 | articles, 313 | max_len=40, 314 | len_type='words', 315 | in_titles=False, 316 | out_titles=False, 317 | min_sent_tokens=7, 318 | max_sent_tokens=40): 319 | 320 | articles = self._preprocess(articles) 321 | sents = [s for a in articles for s in a.sents] 322 | if in_titles == False: 323 | sents = [s for s in sents if not s.is_title] 324 | sents = self._deduplicate(sents) 325 | raw_sents = [s.text for s in sents] 326 | 327 | vectorizer 
= TfidfVectorizer(lowercase=True, stop_words='english') 328 | X = vectorizer.fit_transform(raw_sents) 329 | 330 | ix_to_label = self.cluster_sentences(X) 331 | pairwise_sims = cosine_similarity(X) 332 | sent_coverages = pairwise_sims.sum(0) 333 | avg_sent_sims = sent_coverages / len(sents) 334 | 335 | selected = self.optimize( 336 | sents, max_len, len_type, ix_to_label, 337 | pairwise_sims, sent_coverages, avg_sent_sims, 338 | out_titles, min_sent_tokens, max_sent_tokens 339 | ) 340 | 341 | summary = [sents[i].text for i in selected] 342 | return ' '.join(summary) 343 | -------------------------------------------------------------------------------- /experiments/data.py: -------------------------------------------------------------------------------- 1 | import string 2 | from spacy.lang.en import STOP_WORDS 3 | STOP_WORDS |= set(string.punctuation) 4 | 5 | 6 | class Article: 7 | def __init__(self, title, sents): 8 | self.title = title 9 | self.sents = sents 10 | 11 | def words(self): 12 | if self.title is None: 13 | return [w for s in self.sents for w in s.words] 14 | else: 15 | return [w for s in [self.title] + self.sents for w in s.words] 16 | 17 | 18 | class Sentence: 19 | def __init__(self, text, words, position, is_title=False): 20 | self.text = text 21 | self.words = words 22 | self.position = position 23 | self.content_words = [w for w in words if w not in STOP_WORDS] 24 | self.is_title = is_title 25 | 26 | def __len__(self): 27 | return len(self.words) -------------------------------------------------------------------------------- /experiments/evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import collections 3 | import numpy as np 4 | import utils 5 | from newsroom.analyze.rouge import ROUGE_L, ROUGE_N 6 | 7 | 8 | def print_mean(results, rouge_types): 9 | for rouge_type in rouge_types: 10 | precs = results[rouge_type]['p'] 11 | recalls = results[rouge_type]['r'] 12 | fscores = results[rouge_type]['f'] 13 | p = round(np.mean(precs), 3) 14 | r = round(np.mean(recalls), 3) 15 | f = round(np.mean(fscores), 3) 16 | print(rouge_type, 'p:', p, 'r:', r, 'f:', f) 17 | 18 | 19 | def evaluate(ref_summaries, pred_summaries, lowercase=False): 20 | 21 | rouge_types = ['rouge-1', 'rouge-2', 'rouge-l'] 22 | results = dict((rouge_type, collections.defaultdict(list)) 23 | for rouge_type in rouge_types) 24 | 25 | for ref, pred in zip(ref_summaries, pred_summaries): 26 | 27 | if lowercase: 28 | pred = pred.lower() 29 | ref = ref.lower() 30 | 31 | r1 = ROUGE_N(ref, pred, n=1) 32 | r2 = ROUGE_N(ref, pred, n=2) 33 | rl = ROUGE_L(ref, pred) 34 | 35 | for (rouge_type, scores) in zip(rouge_types, [r1, r2, rl]): 36 | results[rouge_type]['p'].append(scores.precision) 37 | results[rouge_type]['r'].append(scores.recall) 38 | results[rouge_type]['f'].append(scores.fscore) 39 | 40 | mean_results = {} 41 | for rouge_type in rouge_types: 42 | precs = results[rouge_type]['p'] 43 | recalls = results[rouge_type]['r'] 44 | fscores = results[rouge_type]['f'] 45 | mean_results[rouge_type] = { 46 | 'p': round(np.mean(precs), 3), 47 | 'r': round(np.mean(recalls), 3), 48 | 'f': round(np.mean(fscores), 3) 49 | } 50 | 51 | return mean_results 52 | 53 | 54 | def evaluate_from_path(dataset_path, pred_path, start, stop, lowercase=False): 55 | 56 | dataset = utils.read_jsonl(dataset_path) 57 | predictions = utils.read_jsonl(pred_path) 58 | 59 | rouge_types = ['rouge-1', 'rouge-2', 'rouge-l'] 60 | results = dict((rouge_type, collections.defaultdict(list)) 61 
| for rouge_type in rouge_types) 62 | 63 | for i, cluster in enumerate(dataset): 64 | if start > -1 and i < start: 65 | continue 66 | if stop > -1 and i >= stop: 67 | break 68 | 69 | prediction = next(predictions) 70 | assert prediction['cluster_id'] == cluster['id'] 71 | 72 | hyp = prediction['summary'] 73 | ref = cluster['summary'] 74 | 75 | if lowercase: 76 | hyp = hyp.lower() 77 | ref = ref.lower() 78 | 79 | r1 = ROUGE_N(ref, hyp, n=1) 80 | r2 = ROUGE_N(ref, hyp, n=2) 81 | rl = ROUGE_L(ref, hyp) 82 | 83 | for (rouge_type, scores) in zip(rouge_types, [r1, r2, rl]): 84 | results[rouge_type]['p'].append(scores.precision) 85 | results[rouge_type]['r'].append(scores.recall) 86 | results[rouge_type]['f'].append(scores.fscore) 87 | 88 | if i % 100 == 0: 89 | print(i) 90 | # print_mean(results, rouge_types) 91 | 92 | print('Final Average:') 93 | print_mean(results, rouge_types) 94 | return results 95 | 96 | 97 | def main(args): 98 | results = evaluate(args.dataset, args.preds, args.start, args.stop, 99 | args.lowercase) 100 | utils.write_json(results, args.o) 101 | 102 | 103 | if __name__ == '__main__': 104 | parser = argparse.ArgumentParser() 105 | parser.add_argument('--dataset') 106 | parser.add_argument('--preds') 107 | parser.add_argument('--o') 108 | parser.add_argument('--start', type=int, default=-1) 109 | parser.add_argument('--stop', type=int, default=-1) 110 | parser.add_argument('--lowercase', action='store_true') 111 | main(parser.parse_args()) 112 | -------------------------------------------------------------------------------- /experiments/oracles.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from collections import Counter 3 | from nltk import word_tokenize, ngrams 4 | from summarizer import Summarizer 5 | import utils 6 | 7 | 8 | def compute_rouge_n(hyp, ref, rouge_n=1, tokenize=True): 9 | hyp_words = word_tokenize(hyp) if tokenize else hyp 10 | ref_words = word_tokenize(ref) if tokenize else ref 11 | 12 | if rouge_n > 1: 13 | hyp_items = list(ngrams(hyp_words, n=rouge_n)) 14 | ref_items = list(ngrams(ref_words, n=rouge_n)) 15 | else: 16 | hyp_items = hyp_words 17 | ref_items = ref_words 18 | 19 | if len(hyp_items) == 0 or len(ref_items) == 0: 20 | return {'p': 0., 'r': 0., 'f': 0.} 21 | 22 | hyp_counts = Counter(hyp_items) 23 | ref_counts = Counter(ref_items) 24 | 25 | match = 0 26 | for tok in hyp_counts: 27 | match += min(hyp_counts[tok], ref_counts[tok]) 28 | 29 | prec_denom = sum(hyp_counts.values()) 30 | if match == 0 or prec_denom == 0: 31 | precision = 0 32 | else: 33 | precision = match / prec_denom 34 | 35 | rec_denom = sum(ref_counts.values()) 36 | if match == 0 or rec_denom == 0: 37 | recall = 0 38 | else: 39 | recall = match / rec_denom 40 | 41 | if precision == 0 or recall == 0: 42 | fscore = 0 43 | else: 44 | fscore = 2 * precision * recall / (precision + recall) 45 | 46 | return {'p': precision, 'r': recall, 'f': fscore} 47 | 48 | 49 | class Oracle(): 50 | def __init__(self, rouge_n=1, metric='f', early_stopping=True): 51 | self.rouge_n = rouge_n 52 | self.metric = metric 53 | self.early_stopping = early_stopping 54 | self.summarizer = Summarizer() 55 | 56 | def summarize(self, 57 | ref, 58 | articles, 59 | max_len=40, 60 | len_type='words', 61 | in_titles=False, 62 | out_titles=False, 63 | min_sent_tokens=7, 64 | max_sent_tokens=40): 65 | 66 | articles = self.summarizer._preprocess(articles) 67 | sents = [s for a in articles for s in a.sents] 68 | sents = self.summarizer._deduplicate(sents) 69 | 
if not (in_titles and out_titles): 70 | sents = [s for s in sents if not s.is_title] 71 | sent_lens = [self.summarizer._sent_len(s, len_type) for s in sents] 72 | current_len = 0 73 | remaining = list(range(len(sents))) 74 | selected = [] 75 | scored_selections = [] 76 | ref_words = word_tokenize(ref) 77 | 78 | while current_len < max_len and len(remaining) > 0: 79 | scored = [] 80 | current_summary_words = [ 81 | tok for i in selected for tok in sents[i].words 82 | ] 83 | for i in remaining: 84 | new_len = current_len + sent_lens[i] 85 | if new_len <= max_len: 86 | try: 87 | summary_words = current_summary_words + sents[i].words 88 | rouge_scores = compute_rouge_n( 89 | summary_words, 90 | ref_words, 91 | rouge_n=self.rouge_n, 92 | tokenize=False 93 | ) 94 | score = rouge_scores[self.metric] 95 | scored.append((i, score)) 96 | except Exception: 97 | pass 98 | if len(scored) == 0: 99 | break 100 | scored.sort(key=lambda x: x[1], reverse=True) 101 | best_idx, best_score = scored[0] 102 | scored_selections.append((selected + [best_idx], best_score)) 103 | current_len += sent_lens[best_idx] 104 | selected.append(best_idx) 105 | remaining.remove(best_idx) 106 | 107 | if not self.early_stopping: 108 | # remove shorter (prefix) summaries, keep only full-length selections 109 | max_sents = max([len(x[0]) for x in scored_selections]) 110 | scored_selections = [x for x in scored_selections 111 | if len(x[0]) == max_sents] 112 | 113 | 114 | scored_selections.sort(key=lambda x: x[1], reverse=True) 115 | if len(scored_selections) == 0: 116 | return '' 117 | best_selection = scored_selections[0][0] 118 | summary_sents = [sents[i].text for i in best_selection] 119 | return ' '.join(summary_sents) 120 | 121 | 122 | class SingleOracle(): 123 | def __init__(self, rouge_n=1, metric='f', early_stopping=True): 124 | self.rouge_n = rouge_n 125 | self.metric = metric 126 | self.oracle = Oracle(rouge_n, metric, early_stopping) 127 | 128 | def summarize(self, 129 | ref, 130 | articles, 131 | max_len=40, 132 | len_type='words', 133 | in_titles=False, 134 | out_titles=False, 135 | min_sent_tokens=7, 136 | max_sent_tokens=40): 137 | 138 | scored_oracles = [] 139 | for a in articles: 140 | summary = self.oracle.summarize( 141 | ref, [a], max_len, len_type, in_titles, out_titles, 142 | min_sent_tokens, max_sent_tokens 143 | ) 144 | rouge_scores = compute_rouge_n( 145 | summary, 146 | ref, 147 | rouge_n=self.rouge_n, 148 | tokenize=True 149 | ) 150 | score = rouge_scores[self.metric] 151 | scored_oracles.append((summary, score)) 152 | scored_oracles.sort(key=lambda x: x[1], reverse=True) 153 | return scored_oracles[0][0] 154 | 155 | 156 | class LeadOracle(): 157 | def __init__(self, rouge_n=1, metric='f'): 158 | self.rouge_n = rouge_n 159 | self.metric = metric 160 | self.summarizer = Summarizer() 161 | 162 | def summarize(self, 163 | ref, 164 | articles, 165 | max_len=40, 166 | len_type='words', 167 | in_titles=False, 168 | out_titles=False, 169 | min_sent_tokens=7, 170 | max_sent_tokens=40): 171 | 172 | articles = self.summarizer._preprocess(articles) 173 | scored_summaries = [] 174 | for a in articles: 175 | selected_sents = [] 176 | current_len = 0 177 | sents = a.sents 178 | if not (in_titles and out_titles): 179 | sents = [s for s in sents if not s.is_title] 180 | for s in sents: 181 | l = self.summarizer._sent_len(s, len_type) 182 | new_len = current_len + l 183 | if new_len <= max_len: 184 | selected_sents.append(s.text) 185 | current_len = new_len 186 | if new_len > max_len: 187 | break 188 | if len(selected_sents) >= 1: 189 | 
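# Join this article's lead sentences and score them against the reference summary.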
summary = ' '.join(selected_sents) 190 | rouge_scores = compute_rouge_n( 191 | summary, 192 | ref, 193 | self.rouge_n, 194 | tokenize=True 195 | ) 196 | score = rouge_scores[self.metric] 197 | scored_summaries.append((summary, score)) 198 | scored_summaries.sort(key=lambda x: x[1], reverse=True) 199 | summary = scored_summaries[0][0] if scored_summaries else '' 200 | return summary 201 | 202 | 203 | def main(args): 204 | if args.mode == 'predict-lead-oracle': 205 | summarizer = LeadOracle( 206 | rouge_n=args.rouge_n, 207 | metric=args.metric 208 | ) 209 | elif args.mode == 'predict-oracle': 210 | summarizer = Oracle( 211 | rouge_n=args.rouge_n, 212 | metric=args.metric, early_stopping=args.early_stopping 213 | ) 214 | elif args.mode == 'predict-oracle-single': 215 | summarizer = SingleOracle( 216 | rouge_n=args.rouge_n, 217 | metric=args.metric, early_stopping=args.early_stopping 218 | ) 219 | else: 220 | raise ValueError('Unknown or unspecified --mode: {}'.format(args.mode)) 221 | 222 | summarize_settings = utils.args_to_summarize_settings(args) 223 | Summarizer.summarize_dataset( 224 | summarizer, 225 | dataset_path=args.dataset, 226 | pred_path=args.preds, 227 | summarize_settings=summarize_settings, 228 | start=args.start, 229 | stop=args.stop, 230 | batchsize=args.batchsize, 231 | jobs=args.jobs, 232 | oracle=True 233 | ) 234 | 235 | 236 | if __name__ == '__main__': 237 | parser = argparse.ArgumentParser() 238 | parser.add_argument('--mode') 239 | parser.add_argument('--dataset') 240 | parser.add_argument('--preds') 241 | parser.add_argument('--start', type=int, default=-1) 242 | parser.add_argument('--stop', type=int, default=-1) 243 | parser.add_argument('--max-len', type=int, default=40) 244 | parser.add_argument('--len-type', default='words') 245 | parser.add_argument('--in-titles', action='store_true') 246 | parser.add_argument('--out-titles', action='store_true') 247 | # min/max sent tokens have no effect for oracles 248 | parser.add_argument('--min-sent-tokens', type=int, default=7) 249 | parser.add_argument('--max-sent-tokens', type=int, default=60) 250 | parser.add_argument('--rouge-n', type=int, default=1) 251 | parser.add_argument('--metric', default='f') 252 | parser.add_argument('--batchsize', type=int, default=32) 253 | parser.add_argument('--jobs', type=int, default=4) 254 | parser.add_argument('--early-stopping', action='store_true') 255 | main(parser.parse_args()) 256 | -------------------------------------------------------------------------------- /experiments/requirements.txt: -------------------------------------------------------------------------------- 1 | scikit-learn==0.23.1 2 | networkx==2.4 3 | nltk==3.6.6 4 | numpy>=1.18.5 5 | git+https://github.com/lil-lab/newsroom.git#egg=newsroom 6 | -------------------------------------------------------------------------------- /experiments/sent_splitter.py: -------------------------------------------------------------------------------- 1 | import re 2 | from nltk import sent_tokenize 3 | 4 | 5 | class SentenceSplitter: 6 | """ 7 | NLTK sent_tokenize + some fixes for common errors in news articles. 
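Handles sentences glued together without whitespace (e.g. 'end.Start') and sentences broken across newlines.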
8 | """ 9 | def unglue(self, x): 10 | g = x.group(0) 11 | fixed = '{} {}'.format(g[0], g[1]) 12 | return fixed 13 | 14 | def fix_glued_sents(self, text): 15 | return re.sub(r'\.[A-Z]', self.unglue, text) 16 | 17 | def fix_line_broken_sents(self, sents): 18 | new_sents = [] 19 | for s in sents: 20 | new_sents += [s_.strip() for s_ in s.split('\n')] 21 | return new_sents 22 | 23 | def split_sents(self, text): 24 | text = self.fix_glued_sents(text) 25 | sents = sent_tokenize(text) 26 | sents = self.fix_line_broken_sents(sents) 27 | sents = [s for s in sents if s != ''] 28 | return sents 29 | -------------------------------------------------------------------------------- /experiments/summarizer.py: -------------------------------------------------------------------------------- 1 | import utils 2 | from nltk import word_tokenize, bigrams 3 | from sent_splitter import SentenceSplitter 4 | from data import Sentence, Article 5 | 6 | 7 | class Summarizer: 8 | 9 | def _deduplicate(self, sents): 10 | seen = set() 11 | uniq_sents = [] 12 | for s in sents: 13 | if s.text not in seen: 14 | seen.add(s.text) 15 | uniq_sents.append(s) 16 | return uniq_sents 17 | 18 | def _sent_len(self, sent, len_type): 19 | if len_type == 'chars': 20 | return len(sent.text) 21 | elif len_type == 'words': 22 | return len(sent.words) 23 | elif len_type == 'sents': 24 | return 1 25 | else: 26 | raise ValueError('len_type must be in (chars|words|sents)') 27 | 28 | def _is_redundant(self, sents, selected, new, max_redundancy): 29 | new_bigrams = list(bigrams(sents[new].words)) 30 | l = len(new_bigrams) 31 | for i in selected: 32 | old_bigrams = list(bigrams(sents[i].words)) 33 | n_matching = len([x for x in new_bigrams if x in old_bigrams]) 34 | if n_matching == 0: 35 | continue 36 | else: 37 | overlap = n_matching / l 38 | if overlap >= max_redundancy: 39 | return True 40 | return False 41 | 42 | def _preprocess(self, articles): 43 | sent_splitter = SentenceSplitter() 44 | processed_articles = [] 45 | for a in articles: 46 | body_sents = sent_splitter.split_sents(a['text']) 47 | processed_title = Sentence( 48 | text=a['title'], 49 | words=word_tokenize(a['title']), 50 | position=-1, 51 | is_title=True 52 | ) 53 | processed_sents = [] 54 | for position, s in enumerate(body_sents): 55 | processed_sent = Sentence( 56 | text=s, 57 | words=word_tokenize(s), 58 | position=position 59 | ) 60 | processed_sents.append(processed_sent) 61 | 62 | processed_article = Article(processed_title, processed_sents) 63 | processed_articles.append(processed_article) 64 | 65 | return processed_articles 66 | 67 | def _preprocess_sents(self, raw_sents): 68 | processed_sents = [] 69 | for s in raw_sents: 70 | processed_sent = Sentence( 71 | text=s, 72 | words=word_tokenize(s), 73 | position=None 74 | ) 75 | processed_sents.append(processed_sent) 76 | return processed_sents 77 | 78 | def summarize(self, 79 | articles, 80 | max_len=40, 81 | len_type='words', 82 | in_titles=False, 83 | out_titles=False, 84 | min_sent_tokens=7, 85 | max_sent_tokens=60): 86 | raise NotImplementedError 87 | -------------------------------------------------------------------------------- /experiments/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import gzip 3 | import pickle 4 | 5 | 6 | def read_lines(path): 7 | with open(path) as f: 8 | for line in f: 9 | yield line 10 | 11 | 12 | def read_json(path): 13 | with open(path) as f: 14 | object = json.loads(f.read()) 15 | return object 16 | 17 | 18 | def 
write_json(object, path): 19 | with open(path, 'w') as f: 20 | f.write(json.dumps(object)) 21 | 22 | 23 | def read_jsonl(path, load=False, start=0, stop=None): 24 | 25 | def read_jsonl_gen(path): 26 | with open(path) as f: 27 | for i, line in enumerate(f): 28 | if (stop is not None) and (i >= stop): 29 | break 30 | if i >= start: 31 | yield json.loads(line) 32 | 33 | data = read_jsonl_gen(path) 34 | if load: 35 | data = list(data) 36 | return data 37 | 38 | 39 | def read_jsonl_gz(path): 40 | with gzip.open(path) as f: 41 | for l in f: 42 | yield json.loads(l) 43 | 44 | 45 | def write_jsonl(items, path, batch_size=100, override=True): 46 | if override: 47 | with open(path, 'w'): 48 | pass 49 | 50 | batch = [] 51 | for i, x in enumerate(items): 52 | if i > 0 and i % batch_size == 0: 53 | with open(path, 'a') as f: 54 | output = '\n'.join(batch) + '\n' 55 | f.write(output) 56 | batch = [] 57 | raw = json.dumps(x) 58 | batch.append(raw) 59 | 60 | if batch: 61 | with open(path, 'a') as f: 62 | output = '\n'.join(batch) + '\n' 63 | f.write(output) 64 | 65 | 66 | def load_pkl(path): 67 | with open(path, 'rb') as f: 68 | obj = pickle.load(f) 69 | return obj 70 | 71 | 72 | def dump_pkl(obj, path): 73 | with open(path, 'wb') as f: 74 | pickle.dump(obj, f) 75 | 76 | 77 | def args_to_summarize_settings(args): 78 | args = vars(args) 79 | settings = {} 80 | for k in ['len_type', 'max_len', 81 | 'min_sent_tokens', 'max_sent_tokens', 82 | 'in_titles', 'out_titles']: 83 | settings[k] = args[k] 84 | return settings 85 | -------------------------------------------------------------------------------- /wcep_getting_started.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "wcep-getting-started.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [], 9 | "toc_visible": true 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | }, 15 | "language_info": { 16 | "name": "python" 17 | } 18 | }, 19 | "cells": [ 20 | { 21 | "cell_type": "markdown", 22 | "metadata": { 23 | "id": "ZwApybTwmNZ-" 24 | }, 25 | "source": [ 26 | "# Getting started with the WCEP dataset" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "metadata": { 32 | "id": "fHPqe76-mNkZ" 33 | }, 34 | "source": [ 35 | "## Clone repository & install dependencies" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "metadata": { 41 | "id": "3dXMMBBLWpBS" 42 | }, 43 | "source": [ 44 | "!git clone https://github.com/complementizer/wcep-mds-dataset" 45 | ], 46 | "execution_count": null, 47 | "outputs": [] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "metadata": { 52 | "id": "h3CmFmFdQ-8K" 53 | }, 54 | "source": [ 55 | "cd wcep-mds-dataset/experiments" 56 | ], 57 | "execution_count": null, 58 | "outputs": [] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "metadata": { 63 | "id": "Jiw_6QQFmDfi" 64 | }, 65 | "source": [ 66 | "!pip install -r requirements.txt\n", 67 | "!python -m nltk.downloader punkt" 68 | ], 69 | "execution_count": null, 70 | "outputs": [] 71 | }, 72 | { 73 | "cell_type": "markdown", 74 | "metadata": { 75 | "id": "Ox9g3nTdmvo3" 76 | }, 77 | "source": [ 78 | "## Download dataset\n" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "metadata": { 84 | "id": "zmlyseD3m4JK" 85 | }, 86 | "source": [ 87 | "!mkdir WCEP\n", 88 | "!gdown https://drive.google.com/uc?id=1kUjSRXzKnTYdJ732BkKVLg3CFxDKo25u -O WCEP/train.jsonl.gz\n", 89 | "!gdown 
https://drive.google.com/uc?id=1_kHTZ32jazTbXaFRg0vBeIsVcpI7CTmy -O WCEP/val.jsonl.gz\n", 90 | "!gdown https://drive.google.com/uc?id=1qsd5pOCpeSXsaqNobXCrcAzhcjtG1wA1 -O WCEP/test.jsonl.gz" 91 | ], 92 | "execution_count": null, 93 | "outputs": [] 94 | }, 95 | { 96 | "cell_type": "markdown", 97 | "metadata": { 98 | "id": "8x1SC8oysCbd" 99 | }, 100 | "source": [ 101 | "## Load dataset\n", 102 | "\n", 103 | "We use the WCEP validation data as an example.
Each item in the dataset corresponds to a cluster of news articles about a news event and contains some metadata, most importantly the ground-truth summary for the cluster." 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "metadata": { 109 | "id": "U_pgsRlisGVY" 110 | }, 111 | "source": [ 112 | "import utils\n", 113 | "\n", 114 | "val_data = list(utils.read_jsonl_gz('WCEP/val.jsonl.gz'))\n", 115 | "\n", 116 | "print(val_data[0].keys())" 117 | ], 118 | "execution_count": null, 119 | "outputs": [] 120 | }, 121 | { 122 | "cell_type": "markdown", 123 | "metadata": { 124 | "id": "SBtm1KELsmP0" 125 | }, 126 | "source": [ 127 | "## Run extractive baselines & oracles\n", 128 | "\n" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "metadata": { 134 | "id": "s6ygQepSs9mu" 135 | }, 136 | "source": [ 137 | "from baselines import RandomBaseline, TextRankSummarizer, CentroidSummarizer, SubmodularSummarizer\n", 138 | "from oracles import Oracle" 139 | ], 140 | "execution_count": null, 141 | "outputs": [] 142 | }, 143 | { 144 | "cell_type": "markdown", 145 | "metadata": { 146 | "id": "Q7DOOKFFvbvD" 147 | }, 148 | "source": [ 149 | "First we create summarizer objects and set their hyperparameters." 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "metadata": { 155 | "id": "Wv-rJXfns_VJ" 156 | }, 157 | "source": [ 158 | "random_sum = RandomBaseline()\n", 159 | "textrank = TextRankSummarizer(max_redundancy=0.5)\n", 160 | "centroid = CentroidSummarizer(max_redundancy=0.5)\n", 161 | "submod = SubmodularSummarizer(a=5, div_weight=6, cluster_factor=0.2) # div_weight encourages diversity/non-redundancy\n", 162 | "oracle = Oracle()" 163 | ], 164 | "execution_count": null, 165 | "outputs": [] 166 | }, 167 | { 168 | "cell_type": "markdown", 169 | "metadata": { 170 | "id": "wqQFXerDvziU" 171 | }, 172 | "source": [ 173 | "Below we pick one set of settings for extractive summarization that we will use for all baselines.
 \n", 174 | "* `in_titles` means we add article titles as sentences in the input, and `out_titles` means we also allow these titles to be part of a summary\n", 175 | "* we set a minimum sentence length (`min_sent_tokens`) because short broken sentences appear frequently and are usually not desirable\n", 176 | "* you can set the length constraint to `words`, `sents` or `chars`" 177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "metadata": { 182 | "id": "vlnPI2ZXuQVq" 183 | }, 184 | "source": [ 185 | "settings = {\n", 186 | " 'max_len': 40, 'len_type': 'words',\n", 187 | " 'in_titles': False, 'out_titles': False,\n", 188 | " 'min_sent_tokens': 7, 'max_sent_tokens': 40, \n", 189 | "}\n", 190 | "max_articles = 20" 191 | ], 192 | "execution_count": null, 193 | "outputs": [] 194 | }, 195 | { 196 | "cell_type": "markdown", 197 | "metadata": { 198 | "id": "eab554ZmwfuQ" 199 | }, 200 | "source": [ 201 | "For a quick experiment, we only select the first 10 clusters of the WCEP validation data and use the first `max_articles` (20) articles of each cluster as inputs." 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "metadata": { 207 | "id": "IQjaFzGxuSfQ" 208 | }, 209 | "source": [ 210 | "example_clusters = [c['articles'][:max_articles] for c in val_data[:10]]\n", 211 | "ref_summaries = [c['summary'] for c in val_data[:10]]" 212 | ], 213 | "execution_count": null, 214 | "outputs": [] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "metadata": { 219 | "id": "Idc7XbPluS8u" 220 | }, 221 | "source": [ 222 | "textrank_summaries = [textrank.summarize(articles, **settings) for articles in example_clusters]\n", 223 | "centroid_summaries = [centroid.summarize(articles, **settings) for articles in example_clusters]\n", 224 | "submod_summaries = [submod.summarize(articles, **settings) for articles in example_clusters]\n", 225 | "random_summaries = [random_sum.summarize(articles, **settings) for articles in example_clusters]" 226 | ], 227 | "execution_count": null, 228 | "outputs": [] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "metadata": { 233 | "id": "KsI0U_GFxVO2" 234 | }, 235 | "source": [ 236 | "oracle_summaries = [oracle.summarize(ref, articles, **settings)\n", 237 | " for (ref, articles) in zip(ref_summaries, example_clusters)]" 238 | ], 239 | "execution_count": null, 240 | "outputs": [] 241 | }, 242 | { 243 | "cell_type": "markdown", 244 | "metadata": { 245 | "id": "Z0uLcSCJuwfJ" 246 | }, 247 | "source": [ 248 | "## Evaluate summaries\n", 249 | "\n", 250 | "**Note:** our `evaluate` function uses a wrapper from the [newsroom library](https://github.com/lil-lab/newsroom) to compute ROUGE scores. 
\n" 251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "metadata": { 256 | "id": "MdEhYT7vuy2a" 257 | }, 258 | "source": [ 259 | "from pprint import pprint\n", 260 | "from evaluate import evaluate" 261 | ], 262 | "execution_count": null, 263 | "outputs": [] 264 | }, 265 | { 266 | "cell_type": "code", 267 | "metadata": { 268 | "id": "YMtIIDhhu0ui" 269 | }, 270 | "source": [ 271 | "names = ['TextRank', 'Centroid', 'Submodular', 'Oracle', 'Random']\n", 272 | "outputs = [textrank_summaries, centroid_summaries, submod_summaries, oracle_summaries, random_summaries]\n", 273 | "\n", 274 | "for preds, name in zip(outputs, names):\n", 275 | " print(name)\n", 276 | " results = evaluate(ref_summaries, preds, lowercase=True)\n", 277 | " pprint(results)\n", 278 | " print()" 279 | ], 280 | "execution_count": null, 281 | "outputs": [] 282 | }, 283 | { 284 | "cell_type": "markdown", 285 | "metadata": { 286 | "id": "L_6WUfV-wwBd" 287 | }, 288 | "source": [ 289 | "Let's look at some example summaries." 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "metadata": { 295 | "id": "b0hgVBJ8wvGT" 296 | }, 297 | "source": [ 298 | "cluster_idx = 6\n", 299 | "print('Ground-truth')\n", 300 | "print(ref_summaries[cluster_idx])\n", 301 | "print()\n", 302 | "\n", 303 | "for preds, name in zip(outputs, names):\n", 304 | " print(name)\n", 305 | " print(preds[cluster_idx])\n", 306 | " print()" 307 | ], 308 | "execution_count": null, 309 | "outputs": [] 310 | }, 311 | { 312 | "cell_type": "markdown", 313 | "metadata": { 314 | "id": "_7qLELVUXlZX" 315 | }, 316 | "source": [ 317 | "### Blog Example\n" 318 | ] 319 | }, 320 | { 321 | "cell_type": "code", 322 | "metadata": { 323 | "id": "IwqdZ5R-YQ9s" 324 | }, 325 | "source": [ 326 | "from utils import read_jsonl_gz\n", 327 | "from baselines import TextRankSummarizer\n", 328 | "from evaluate import evaluate\n", 329 | "from pprint import pprint\n", 330 | "\n", 331 | "textrank = TextRankSummarizer()\n", 332 | "\n", 333 | "dataset = list(read_jsonl_gz('WCEP/val.jsonl.gz'))\n", 334 | "cluster = dataset[954]\n", 335 | "articles = cluster['articles'][:10]\n", 336 | "\n", 337 | "human_summary = cluster['summary']\n", 338 | "automatic_summary = textrank.summarize(articles)\n", 339 | "results = evaluate([human_summary], [automatic_summary])\n", 340 | "\n", 341 | "print('Summary:')\n", 342 | "print(automatic_summary)\n", 343 | "print()\n", 344 | "print('ROUGE scores:')\n", 345 | "pprint(results)" 346 | ], 347 | "execution_count": null, 348 | "outputs": [] 349 | } 350 | ] 351 | } --------------------------------------------------------------------------------