├── LICENSE ├── Makefile ├── README.md ├── VERSION ├── bin ├── multidoc_jsonl_dataset_to_parallel_dataset.py └── run_bart_sum.py ├── data └── test_dataset │ ├── train.source │ ├── train.target │ ├── val.source │ └── val.target ├── requirements.txt ├── research ├── coverage_and_density_analysis │ ├── analyse_dataset.py │ ├── plots.py │ └── utils.py ├── debug_multinews_dataset.ipynb ├── duc2004_to_jsonl.ipynb ├── filter_multinews_duplicate_articles.ipynb ├── follow_lebanoff_et_al_2018_evaluation.ipynb ├── multidoc_jsonl_dataset_to_parallel_dataset.ipynb ├── multinews_to_mds_jsonl.ipynb ├── multinews_to_single_doc_parallel.ipynb ├── prototype_bart_transformer_integration.ipynb ├── prototype_multidoc_summarization_evaluation.ipynb ├── prototype_with_presumm.ipynb └── visualize_per_input_scores_with_heatmap.ipynb ├── setup.py └── transformer_decoding ├── __init__.py ├── bart_utils.py ├── decoding_utils.py ├── evaluate.py ├── finetune.py ├── log.py ├── test_decoding.py └── transformer_base.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Chris Hokamp 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # Dynamic Ensembling Makefile 2 | 3 | # FINE-TUNING (Training BART with in-domain data) 4 | # summarization datadir (TODO: format of this data?) 
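# (judging by the repo's own data/test_dataset: plain-text parallel files {train,val}.{source,target}, one example per line)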
5 | DATADIR ?= 'data/test_dataset' 6 | BASE_MODEL_NAME_OR_PATH ?= 'bart-large-cnn' 7 | OUTPUT_DIR ?= 'fine-tuned-model' 8 | N_GPU ?= 0 9 | MAX_SOURCE_LEN ?= 512 10 | MAX_TARGET_LEN ?= 60 11 | TRAIN_BATCH_SIZE ?= 1 12 | EVAL_BATCH_SIZE ?= 1 13 | 14 | # EVALUATION ARGS 15 | EVALUATION_DATASET ?= data/WCEP/test.jsonl 16 | MODEL_ID ?= bart-large-cnn 17 | MAX_ARTICLES_IN_CLUSTER ?= 5 18 | 19 | # used for flags and additional script args 20 | RUN_FLAGS ?= 21 | 22 | 23 | ########### 24 | ## TASKS ## 25 | ########### 26 | .PHONY: predict 27 | predict: 28 | python transformer_decoding/evaluate.py \ 29 | --evaluation-dataset $(EVALUATION_DATASET) \ 30 | --model-id $(MODEL_ID) \ 31 | $(RUN_FLAGS) 32 | 33 | .PHONY: evaluate 34 | evaluate: 35 | python transformer_decoding/evaluate.py \ 36 | --evaluation-dataset $(EVALUATION_DATASET) \ 37 | --model-id $(MODEL_ID) \ 38 | $(RUN_FLAGS) 39 | 40 | .PHONY: fine-tune-bart 41 | fine-tune-bart: 42 | mkdir -p $(OUTPUT_DIR) 43 | python transformer_decoding/finetune.py \ 44 | --data_dir $(DATADIR) \ 45 | --model_type bart \ 46 | --model_name_or_path $(BASE_MODEL_NAME_OR_PATH) \ 47 | --learning_rate 3e-5 \ 48 | --train_batch_size $(TRAIN_BATCH_SIZE) \ 49 | --eval_batch_size $(EVAL_BATCH_SIZE) \ 50 | --max_source_length $(MAX_SOURCE_LEN) \ 51 | --max_target_length $(MAX_TARGET_LEN) \ 52 | --output_dir $(OUTPUT_DIR) \ 53 | --n_gpu $(N_GPU) \ 54 | --do_train 55 | 56 | #.PHONY: fine-tune-bart 57 | #fine-tune-bart: 58 | # mkdir -p $(OUTPUT_DIR) 59 | # python bin/run_bart_sum.py \ 60 | # --data_dir $(DATADIR) \ 61 | # --model_type bart \ 62 | # --model_name_or_path $(BASE_MODEL_NAME_OR_PATH) \ 63 | # --learning_rate 3e-5 \ 64 | # --train_batch_size $(TRAIN_BATCH_SIZE) \ 65 | # --eval_batch_size $(EVAL_BATCH_SIZE) \ 66 | # --max_seq_length $(MAX_SEQ_LEN) \ 67 | # --output_dir $(OUTPUT_DIR) \ 68 | # --n_gpu $(N_GPU) \ 69 | # --do_train 70 | 71 | resources/$(TEST_RESOURCES_VERSION): 72 | mkdir -p ./resources 73 | gsutil cp -r $(RESOURCES_ROOT)/$(TEST_RESOURCES_VERSION) ./resources 74 | 75 | .PHONY: test 76 | test: resources/$(TEST_RESOURCES_VERSION) 77 | RESOURCES=resources/$(TEST_RESOURCES_VERSION) python -Wignore -m unittest discover 78 | pycodestyle transformer_decoding 79 | 80 | .PHONY: clean 81 | clean: 82 | rm -f *.pyc *.pkl *.npy 83 | rm -rf *.egg-info 84 | 85 | ################# 86 | ## DEVELOPMENT ## 87 | ################# 88 | 89 | .PHONY: dev 90 | dev: 91 | pip install -e . 92 | pip install -r requirements.txt 93 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## DynE: Dynamic Ensemble Decoding for Multi-Document Summarization 2 | 3 | This repo contains the code for [DynE: Dynamic Ensemble Decoding for Multi-Document Summarization](https://arxiv.org/abs/2006.08748). 4 | 5 | This code base can be used to add dynamic ensembling capability to models from the [Huggingface transformers library](https://github.com/huggingface/transformers).
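The core idea, sketched below: a single summarization model is run once per input document, all ensemble "members" share the same partial output, and their next-token distributions are averaged at every decoding step. This is a minimal greedy-search illustration only, not the repo's actual code path (the beam-search implementation lives in `transformer_decoding/decoding_utils.py`); the model id matches the Makefile default, and `dyne_greedy_decode` and the length limits are placeholders.

```
import torch
from transformers import BartTokenizer, BartForConditionalGeneration

tokenizer = BartTokenizer.from_pretrained('bart-large-cnn')
model = BartForConditionalGeneration.from_pretrained('bart-large-cnn').eval()

def dyne_greedy_decode(documents, max_len=60):
    # one encoder input per document in the cluster
    inputs = [tokenizer.encode(doc, return_tensors='pt', max_length=512)
              for doc in documents]
    # every ensemble member conditions on the same shared output prefix
    prefix = torch.tensor([[model.config.decoder_start_token_id]])
    with torch.no_grad():
        for _ in range(max_len):
            # next-token log-probs of one model, conditioned on each document
            step_scores = [
                torch.log_softmax(
                    model(input_ids=input_ids, decoder_input_ids=prefix)[0][:, -1, :],
                    dim=-1)
                for input_ids in inputs
            ]
            # dynamic ensembling: average the per-document distributions
            next_token = torch.stack(step_scores).mean(dim=0).argmax(dim=-1, keepdim=True)
            prefix = torch.cat([prefix, next_token], dim=-1)
            if next_token.item() == tokenizer.eos_token_id:
                break
    return tokenizer.decode(prefix[0], skip_special_tokens=True)
```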
6 | 7 | ## Setup / Installation 8 | 9 | ``` 10 | # make a fresh environment 11 | conda create -n dynamic-ensembles python=3.6 12 | conda activate dynamic-ensembles 13 | 14 | # Installation 15 | make dev 16 | ``` 17 | 18 | 19 | ### Multi-Document Summarization (MDS) Datasets 20 | 21 | MDS datasets in the format required by the scripts in this repo: 22 | - [WCEP](https://drive.google.com/drive/folders/1KSxlIx9Hq6l3pTTvsrbug-gpeuQIrQgW?usp=sharing) (train, val, test) 23 | - [MultiNews](https://drive.google.com/drive/folders/1nuBM8aMjauA7bKOdPeQf6DeiR8-TeMaR?usp=sharing) (train, val, test) 24 | - [DUC2004](https://drive.google.com/drive/folders/1q11LDSGqan-zHiMgA8IiB-vnfIXz39IJ?usp=sharing) (test) 25 | 26 | The original WCEP dataset used to generate the flat training data: 27 | - [WCEP in `.jsonl` format](https://drive.google.com/drive/folders/1PJufMEOdogIaKQq-PlB4vvawLa6tvEG6) 28 | 29 | ---------------------- 30 | 31 | ### Model Checkpoints and Outputs 32 | 33 | ##### Model Checkpoints 34 | 35 | We fine-tune the `bart-large-cnn` single-document summarization model from the [transformers library](https://github.com/huggingface/transformers). 36 | - The best fine-tuned model checkpoints for WCEP and MultiNews are [here](https://drive.google.com/drive/folders/1B449P6kwm6_6AjpaASduGMi3Ff6Z1IBd?usp=sharing) 37 | 38 | ##### Fine-tuned Model Outputs 39 | 40 | - Download the outputs of fine-tuned models on the test sets of WCEP and MultiNews [here](https://drive.google.com/drive/folders/1dCwg-sd0bPiZZV7nDLOO2ZoUcCDRiO3V?usp=sharing) 41 | 42 | ---------------------- 43 | 44 | ### Evaluation 45 | Prediction and evaluation are done by the script `transformer_decoding/evaluate.py`. 46 | There is also a `make` task for evaluation which simply calls this script. 47 | 48 | For example, to predict using a model id from `transformers`, or with a fine-tuned model checkpoint, 49 | and evaluate with the Ghalandari et al. 2020 evaluation workflow: 50 | ``` 51 | MODEL_ID=model_checkpoints/wcep_fine-tune-bart-large/checkpointepoch\=1.ckpt \ 52 | RUN_FLAGS='--max-articles-in-cluster 5 --max-src-length 512 --max-tgt-length 64 --num-beams 5 --eval-prefix wcep_5_articles_' \ 53 | make evaluate 54 | ``` 55 | - Pretrained model checkpoints can be downloaded from the links above.
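The `--evaluation-dataset` file is in the cluster-level `.jsonl` format from the Datasets section above: one JSON object per line, one cluster per object. A minimal sketch of the shape, showing only the fields the scripts here actually read (real records carry additional metadata):
```
{"summary": "...", "articles": [{"title": "...", "text": "..."}, {"title": "...", "text": "..."}]}
```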
56 | 57 | For a quick test, use the `--rows-to-eval` argument, which will only predict the first `N` rows from the dataset: 58 | ``` 59 | MODEL_ID=model_checkpoints/wcep_fine-tune-bart-large/checkpointepoch\=1.ckpt \ 60 | RUN_FLAGS='--max-articles-in-cluster 5 --max-src-length 512 --max-tgt-length 64 --num-beams 5 --rows-to-eval 10 --eval-prefix wcep_5_articles_' \ 61 | make evaluate 62 | ``` 63 | 64 | To run evaluation only, using previously generated predictions, supply the `--predictions` argument to `transformer_decoding/evaluate.py`: 65 | ``` 66 | EVALUATION_DATASET=data/WCEP/test.jsonl \ 67 | RUN_FLAGS='--predictions outputs/wcep/wcep_5_articles_eval_predicted_summaries.out' \ 68 | make evaluate 69 | ``` 70 | 71 | ##### Scoring Gold Summaries by Forced Decoding 72 | 73 | ``` 74 | 75 | EVALUATION_DATASET=data/WCEP/test.jsonl \ 76 | RUN_FLAGS='--force-decode-gold --max-articles-in-cluster 5 --max-src-length 512 --max-tgt-length 512 --num-beams 1 --rows-to-eval 10 --eval-prefix wcep_5_articles_' \ 77 | make evaluate 78 | 79 | ``` 80 | 81 | ---------------------- 82 | 83 | ### Citing 84 | 85 | If you use ideas or code from this project, please cite: 86 | ``` 87 | @article{DynamicEnsembles, 88 | title = {DynE: Dynamic Ensemble Decoding for Multi-Document Summarization}, 89 | author = {Chris Hokamp and Demian Gholipour Ghalandari and Nghia The Pham 90 | and John Glover}, 91 | journal = {arXiv preprint arXiv:2006.08748}, 92 | year = {2020}, 93 | } 94 | 95 | ``` 96 | 97 | -------------------------------------------------------------------------------- /VERSION: -------------------------------------------------------------------------------- 1 | 0.0.1 2 | -------------------------------------------------------------------------------- /bin/multidoc_jsonl_dataset_to_parallel_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # In[ ]: 5 | 6 | 7 | # flatten a multidoc summarization dataset in .jsonl format to a parallel dataset that uses 8 | # the *.sources *.targets format from cnn-dm 9 | 10 | # NOTE: cluster items are sequential by default, so the flattened examples are shuffled below (shuffle = True) 11 | 12 | # TODO: support formatting with special tokens to indicate document structure (i.e. a separator
token between Title and Body) 13 | 14 | 15 | 16 | # In[1]: 17 | 18 | 19 | from pathlib import Path 20 | import json 21 | import tqdm 22 | 23 | import numpy as np 24 | 25 | from transformer_decoding.evaluate import article_to_text 26 | 27 | 28 | DATADIR = Path('/home/chris/projects/aylien/dynamic-ensembles/data/WCEP') 29 | prefixes = ['train', 'val'] 30 | shuffle = True 31 | separator_token = ' [SEP] ' 32 | 33 | 34 | for dataset_prefix in prefixes: 35 | sources_and_targets = [] 36 | cluster_cnt = 0 37 | print('loading clusters') 38 | for cluster in tqdm.tqdm((json.loads(l) for l in open(DATADIR / (dataset_prefix + '.jsonl')))): 39 | for article in cluster['articles']: 40 | sources_and_targets.append((article_to_text(article, separator_token=separator_token), cluster['summary'])) 41 | cluster_cnt += 1 42 | 43 | output_idxs = np.arange(len(sources_and_targets)) 44 | if shuffle: 45 | np.random.shuffle(output_idxs) 46 | 47 | with open(DATADIR / (dataset_prefix + '.sources'), 'w') as srcs, open(DATADIR / (dataset_prefix + '.targets'), 'w') as tgts: 48 | for idx in tqdm.tqdm(output_idxs): 49 | src = sources_and_targets[idx][0] 50 | tgt = sources_and_targets[idx][1] 51 | srcs.write(f'{src}\n') 52 | tgts.write(f'{tgt}\n') 53 | print(f'wrote {len(sources_and_targets)} segments from {cluster_cnt} clusters to {srcs.name} and {tgts.name}') 54 | 55 | 56 | 57 | 58 | 59 | # In[ ]: 60 | 61 | 62 | 63 | 64 | -------------------------------------------------------------------------------- /bin/run_bart_sum.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import logging 4 | import os 5 | import time 6 | 7 | import torch 8 | from torch.utils.data import DataLoader 9 | 10 | from transformer_decoding.transformer_base import (BaseTransformer, 11 | add_generic_args, 12 | generic_train, 13 | get_linear_schedule_with_warmup) 14 | 15 | from transformer_decoding.bart_utils import SummarizationDataset 16 | 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | class BartSystem(BaseTransformer): 22 | 23 | mode = "language-modeling" 24 | 25 | def __init__(self, hparams): 26 | super(BartSystem, self).__init__(hparams, num_labels=None, mode=self.mode) 27 | 28 | def forward( 29 | self, input_ids, attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, lm_labels=None 30 | ): 31 | return self.model( 32 | input_ids, 33 | attention_mask=attention_mask, 34 | decoder_input_ids=decoder_input_ids, 35 | decoder_attention_mask=decoder_attention_mask, 36 | lm_labels=lm_labels, 37 | ) 38 | 39 | def _step(self, batch): 40 | y = batch["target_ids"] 41 | y_ids = y[:, :-1].contiguous() 42 | lm_labels = y[:, 1:].clone() 43 | lm_labels[y[:, 1:] == self.tokenizer.pad_token_id] = -100 44 | outputs = self( 45 | input_ids=batch["source_ids"], 46 | attention_mask=batch["source_mask"], 47 | decoder_input_ids=y_ids, 48 | lm_labels=lm_labels, 49 | ) 50 | 51 | loss = outputs[0] 52 | 53 | return loss 54 | 55 | def training_step(self, batch, batch_idx): 56 | loss = self._step(batch) 57 | 58 | tensorboard_logs = {"train_loss": loss} 59 | return {"loss": loss, "log": tensorboard_logs} 60 | 61 | def validation_step(self, batch, batch_idx): 62 | loss = self._step(batch) 63 | return {"val_loss": loss} 64 | 65 | def validation_end(self, outputs): 66 | avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean() 67 | tensorboard_logs = {"val_loss": avg_loss} 68 | return {"avg_val_loss": avg_loss, "log": tensorboard_logs} 69 | 70 | def test_step(self, batch, 
batch_idx): 71 | generated_ids = self.model.generate( 72 | batch["source_ids"], 73 | attention_mask=batch["source_mask"], 74 | num_beams=1, 75 | max_length=80, 76 | repetition_penalty=2.5, 77 | length_penalty=1.0, 78 | early_stopping=True, 79 | ) 80 | preds = [ 81 | self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) 82 | for g in generated_ids 83 | ] 84 | target = [ 85 | self.tokenizer.decode(t, skip_special_tokens=True, clean_up_tokenization_spaces=True) 86 | for t in batch["target_ids"] 87 | ] 88 | loss = self._step(batch) 89 | 90 | return {"val_loss": loss, "preds": preds, "target": target} 91 | 92 | def test_end(self, outputs): 93 | return self.validation_end(outputs) 94 | 95 | def test_epoch_end(self, outputs): 96 | output_test_predictions_file = os.path.join(self.hparams.output_dir, "test_predictions.txt") 97 | output_test_targets_file = os.path.join(self.hparams.output_dir, "test_targets.txt") 98 | # write predictions and targets for later rouge evaluation. 99 | with open(output_test_predictions_file, "w+") as p_writer, open(output_test_targets_file, "w+") as t_writer: 100 | for output_batch in outputs: 101 | p_writer.writelines(s + "\n" for s in output_batch["preds"]) 102 | t_writer.writelines(s + "\n" for s in output_batch["target"]) 103 | p_writer.close() 104 | t_writer.close() 105 | 106 | return self.test_end(outputs) 107 | 108 | def train_dataloader(self): 109 | train_dataset = SummarizationDataset( 110 | self.tokenizer, data_dir=self.hparams.data_dir, type_path="train", block_size=self.hparams.max_seq_length 111 | ) 112 | dataloader = DataLoader(train_dataset, batch_size=self.hparams.train_batch_size) 113 | t_total = ( 114 | (len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.n_gpu))) 115 | // self.hparams.gradient_accumulation_steps 116 | * float(self.hparams.num_train_epochs) 117 | ) 118 | scheduler = get_linear_schedule_with_warmup( 119 | self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total 120 | ) 121 | self.lr_scheduler = scheduler 122 | return dataloader 123 | 124 | def val_dataloader(self): 125 | val_dataset = SummarizationDataset( 126 | self.tokenizer, data_dir=self.hparams.data_dir, type_path="val", block_size=self.hparams.max_seq_length 127 | ) 128 | return DataLoader(val_dataset, batch_size=self.hparams.eval_batch_size) 129 | 130 | def test_dataloader(self): 131 | test_dataset = SummarizationDataset( 132 | self.tokenizer, data_dir=self.hparams.data_dir, type_path="test", block_size=self.hparams.max_seq_length 133 | ) 134 | return DataLoader(test_dataset, batch_size=self.hparams.eval_batch_size) 135 | 136 | @staticmethod 137 | def add_model_specific_args(parser, root_dir): 138 | BaseTransformer.add_model_specific_args(parser, root_dir) 139 | # Add BART specific options 140 | parser.add_argument( 141 | "--max_seq_length", 142 | default=1024, 143 | type=int, 144 | help="The maximum total input sequence length after tokenization. Sequences longer " 145 | "than this will be truncated, sequences shorter will be padded.", 146 | ) 147 | 148 | parser.add_argument( 149 | "--data_dir", 150 | default=None, 151 | type=str, 152 | required=True, 153 | help="The input data dir. 
Should contain the dataset files for the CNN/DM summarization task.", 154 | ) 155 | return parser 156 | 157 | 158 | if __name__ == "__main__": 159 | parser = argparse.ArgumentParser() 160 | add_generic_args(parser, os.getcwd()) 161 | parser = BartSystem.add_model_specific_args(parser, os.getcwd()) 162 | args = parser.parse_args() 163 | 164 | # If output_dir not provided, a folder will be generated in pwd 165 | if args.output_dir is None: 166 | args.output_dir = os.path.join("./results", f"{args.task}_{args.model_type}_{time.strftime('%Y%m%d_%H%M%S')}",) 167 | os.makedirs(args.output_dir) 168 | 169 | model = BartSystem(args) 170 | trainer = generic_train(model, args) 171 | 172 | # Optionally, predict on dev set and write to output_dir 173 | if args.do_predict: 174 | checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True))) 175 | BartSystem.load_from_checkpoint(checkpoints[-1]) 176 | trainer.test(model) 177 | -------------------------------------------------------------------------------- /data/test_dataset/val.target: -------------------------------------------------------------------------------- 1 | The U.S. House Committee on Ways and Means formally requests six years of President Donald Trump's personal and business tax returns from the Internal Revenue Service. 2 | Two reported Jaish-e-Mohammed militants, one civilian and four Indian soldiers are killed in a skirmish in the Pulwama district of Indian-administered Kashmir as police search for suspects in Thursday's suicide attack that killed 40 Indian paramilitary police. 3 | After announcing that it was canceling its New York City headquarters plans the day before, Amazon announces that it will keep its second headquarters in Virginia. 4 | Several European Union states, including the United Kingdom, Germany, Portugal and Spain, officially recognize Juan Guaidó as interim President of Venezuela after Nicolás Maduro rejects the European ultimatum to call a new snap election. Other European Union countries, such as Greece and Ireland, stop short of recognizing Guaidó, while Italy's leading coalition party, the 5 Star Movement, declares that it is not "for the EU to tell another nation what to do". 5 | Cyclone Fani, an extremely severe category 4 storm and one of the strongest in recent years, makes landfall at the coastal town of Puri in the Indian state of Odisha. Eight people have been killed in India, according to the Press Trust of India, and hundreds more injured. Severe damage and flooding has been reported. One million Indians and 2.1 million Bangladeshis have been evacuated. A storm surge possibly up to 1.5m (5ft) is expected. The storm, weakening as it travels northeast through India, is expected to reach Chittagong in Bangladesh Saturday. (Reuters) (BBC) 6 | Yale University researchers led by professor Nenad Sestan announce, through the Nature journal, that they successfully partly revived the brains of deceased pigs, four hours after death occurred. However, there were no signals from the brains that would indicate awareness or consciousness. 7 | Nationwide protests against animal cruelty take place in Australia. In the city of Melbourne, protesters blocked a major intersection for four hours before it was dispatched by police; several people were arrested. 8 | The United States, Canada and several Latin American nations recognize opposition leader Juan Guaidó as President of Venezuela. 9 | Peace talks begin in Qatar between the United States and the Taliban. 
10 | After a fire at a Consumers Energy natural gas compressor station, Michigan residents are asked to turn down their heat to conserve natural gas. 11 | An investigation by Houston Chronicle and the San Antonio Express-News found that since 1998, about 380 Southern Baptist church clerics, laypersons, and volunteers have faced credible accusations of sexual abuse and that of those, roughly 220 were convicted of sex crimes or received plea deals, in cases involving more than 700 victims. Many accusers were young men and women, who allegedly experienced everything from molestation to rape and impregnation at the hands of church members. 12 | Venezuelan National Assembly-declared interim President Juan Guaidó defies Venezuelan President Nicolás Maduro's threats and returns to Venezuela where he is received by tens of thousands of people in Caracas. 13 | At his trial in a Brooklyn federal court, Joaquín "El Chapo" Guzmán is found guilty on all 10 counts; the charges include engaging in a continuing criminal enterprise, conspiracy to launder narcotics proceeds, international distribution of cocaine, heroin, marijuana and other drugs, and use of firearms during the commission of a felony. 14 | Pakistani Prime Minister Imran Khan says that captured Indian Air Force pilot Abhinandan Varthaman will be released tomorrow as a gesture of peace. 15 | Dozens of people remain missing one day after a deadly attack on a popular hotel complex in Nairobi, according to the Kenya Red Cross Society. 16 | The Sudanese military surrounds the presidential palace in the capital Khartoum, and takes over state media buildings amid protests against President Omar al-Bashir. Several senior ministers are reportedly arrested, including former Defense Minister Abdel Rahim Mohammed Hussein. 17 | The Syrian Observatory for Human Rights reports that the U.S.-backed Syrian Democratic Forces have taken control of Al-Baghuz Fawqani, following the surrender of the last remaining ISIL militants in the town. 18 | Schools in the Denver, Colorado metropolitan area are closed as the FBI searched for an 18-year-old Miami Beach Senior High School student described as "infatuated" with the Columbine High School massacre. She traveled to Colorado and purchased a shotgun and ammunition upon arriving, and was deemed a "credible threat" to area schools. She was later confirmed dead. 19 | Algeria's Chief of Staff of the People's National Army Ahmed Gaid Salah, the highest-ranked military official in the country, gives a televised address, calling on President Abdelaziz Bouteflika to resign or be declared "unfit to serve" by the People's National Assembly. 20 | At least 17 people are killed in a fire at a hotel in Delhi, India. 21 | The Kansas City Chiefs suspend wide receiver and return specialist Tyreek Hill from team activities for an alleged child abuse case. 22 | Former Israeli minister Gonen Segev pleads guilty to spying for Iran, in exchange for an 11-year prison sentence. 23 | The first child of Prince Harry, Duke of Sussex, and Meghan, Duchess of Sussex, is born; the boy is seventh in line to the British throne. 24 | An indicative vote on holding more indicative votes for Brexit possibilities in the House of Commons achieves a vote of 310 Aye and 310 No, becoming the first British parliamentary vote to result in a draw in the 21st century; Speaker John Bercow breaks the tie by voting No. 25 | ISIL militants kill 50 Syrian soldiers in two days of clashes in the Syrian desert. 26 | AT&T Inc. 
sells its 10% stake in Hulu to Hulu LLC for $1.43 billion. 27 | A federal district court in Washington State issues a preliminary injunction against enforcement of an initiative by the Donald Trump administration ("gag order") that would have restricted doctor-patient communications about abortion in family planning clinics that receive U.S. taxpayer funding. 28 | The UK Parliament declares "an environment and climate emergency." (CNN) 29 | The Government of the Central African Republic sign a peace deal with 14 rebel groups. 30 | A federal judge rules that the exclusion of women from the Selective Service System is unconstitutional. 31 | A Hamas militant is killed and two others are injured when an Israeli tank fires into the Gaza Strip after a protest turned violent. The Israel Defense Forces action is retaliation for a shooting which lightly injured an officer, and for the brief incursion of two Palestinians into Israel. 32 | The parliament of Greece votes 153–146 in approval of changing the name of the Republic of Macedonia to the Republic of North Macedonia, ending a 27-year naming dispute. 33 | The U.S. Supreme Court allows, by a 5–4 vote, the Trump administration to begin implementing the policy that prohibits transgender persons who require or have undergone gender transition from serving. Unresolved challenges remain in lower courts. 34 | New England Patriots owner Robert Kraft is charged in a prostitution sting in the U.S. state of Florida. 35 | Following Vladimir Putin's warning that Russia would deploy nuclear missiles in Europe if the United States deploys intermediate-range nuclear missiles in Europe, Putin's ally Dmitry Kiselyov lists what he claims are the targets in the United States, which includes The Pentagon, Camp David, Fort Ritchie, McClellan Air Force Base, and Jim Creek Naval Radio Station. Kremlin spokesperson Dmitry Peskov denies the existence of the target list. 36 | The European Union confirms that UK Prime Minister Theresa May's requested short extension date of 30 June is too late due to pending EU elections. After lengthy discussions, EU leaders agree that if the Prime Minister's deal is passed next week a short extension until 22 May is available to pass the necessary legislation. If that deal is not passed, the UK is given until 12 April to define whether it will participate in EU elections. 37 | Scientists announce through the Nature journal, the discovery of 13 deep-space fast radio bursts (FRBs), named FRB 180814, by the CHIME radio telescope in British Columbia, Canada. 38 | The New Yorker reports that Fox News owner Rupert Murdoch prevented the release of a story about hush money which Donald Trump paid to Stormy Daniels before the 2016 election because Murdoch wanted Trump to win the election. 39 | Japanese architect Arata Isozaki is announced as the winner of the 2019 Pritzker Prize. 40 | Vajiralongkorn, King of Thailand, marries his head of security and mistress, General Suthida Tidjai, days before his coronation. 41 | Actor Jussie Smollett is indicted on 16 felony counts of disorderly conduct for allegedly filing a false hate crime police report in January in Chicago. 42 | German politician Frank Magnitz, a member of the AfD party, was beaten unconscious by three masked assailants in the city of Bremen on Monday. AfD party leader Alice Weidel calls the attack an "assassination attempt" and politicians from other German parties condemn the attack. 
43 | Two children are killed and twenty are injured in China after a dust devil lifts an inflatable castle off the ground. 44 | Democratic U.S. Senator Kirsten Gillibrand of New York announces she is running for president in 2020. 45 | Two reported Jaish-e-Mohammed militants, one civilian and four Indian soldiers are killed in a skirmish in the Pulwama district of Indian-administered Kashmir as police search for suspects in Thursday's suicide attack that killed 40 Indian paramilitary police. 46 | Opposition leader Juan Guaidó swears himself in as President of Venezuela, with de facto President Nicolás Maduro not recognizing this. 47 | Voters are called to the polls to elect the new members of the People's Majlis. Preliminary results give the victory to Maldivian Democratic Party, led by former President Mohamed Nasheed. 48 | At least 73 people have been killed in flash floods in the province of Papua, Indonesia. 49 | North Korean leader Kim Jong-un arrives in Beijing for his fourth summit meeting with Chinese leader Xi Jinping. 50 | Attorney General William Barr releases the "principal conclusions" of Mueller's investigation in a four-page public letter to the Congress's Judiciary Committee leadership. 51 | The Pakistani Air Force (PAF) conducted six airstrikes in Indian-administered Kashmir, shooting down one Indian aircraft and capturing its pilot following a dogfight. Pakistani officials claimed that they have shot down two Indian Air Force (IAF) jets. Wreckage of one aircraft fell in Azad Kashmir while the other fell in Indian-administered Kashmir. Indian officials initially rejected that any of their pilot was in captured, however, once the footage of Indian pilot in Pakistan's custody were released on the internet, Indian officials admitted that their pilot was captured and his aircraft was shot down. India officials also claimed to have shot down one Pakistani Air Force (PAF) jet that violated its airspace. However, Indian officials claim were rejected by Pakistan. 52 | A building in Lagos, Nigeria, collapses, killing ten people and leaving more than 100 others trapped under the rubble. 53 | FEMA Administrator Brock Long announces his resignation. 54 | The death toll from a North American cold wave caused by a polar vortex rises to at least 21 people. 55 | Avengers: Endgame breaks numerous box office records, including the biggest opening weekend in cinematic history, grossing over $1 billion worldwide. 56 | A propaganda video featuring Islamic State of Iraq and the Levant leader Abu Bakr al-Baghdadi is released, in which he references the 2019 Sri Lanka Easter bombings, indicating that he is still alive and that the video was shot very recently. It is the first time he's been seen on video since July 2014, when he addressed a crowd at the now destroyed Great Mosque of al-Nuri in Mosul. 57 | British Prime Minister Theresa May fires Gavin Williamson as Secretary of State for Defence, following the leaking of information relating to a National Security Council meeting, regarding the security risk posed by Chinese multinational telecommunications company Huawei. Secretary of State for International Development Penny Mordaunt is appointed the first-ever female British Defence Secretary. 58 | North Korean Chairman Kim Jong-un is to arrive in Vietnam on February 25 ahead of the second summit meeting between the two leaders on February 27–28 in Hanoi. 59 | The first death due to cholera is confirmed in Mozambique, with confirmed cases rising to 517. 
60 | Scientists announce that the Megachile pluto (Wallace's giant bee), the world's largest bee, has been rediscovered in North Maluku, Indonesia, after no confirmed sightings since 1981. The first ever pictures and videos are taken of the rare species. 61 | Pope Francis announces the Vatican's historical archives of Pope Pius XII's pontificate (1939–1958) will be accessible to scholars next year, effective 2 March 2020. 62 | Prime Minister of Spain Pedro Sánchez calls for a snap general election on 28 April and will dissolve the Cortes Generales on 5 March after failing to approve a government budget. 63 | A fuel tanker lorry overturns and explodes near the airport in Niamey, Niger, killing at least 58 people. 64 | American singer Mariah Carey performs her concert in Jeddah, Saudi Arabia, despite Saudi women's rights activists calling for her to cancel it. 65 | The Syrian Democratic Forces announce the capture of the last territory held by ISIL in Syria. 66 | At least 17 people are killed in a fire at a hotel in Delhi, India. 67 | Vajiralongkorn, King of Thailand, marries his head of security and mistress, General Suthida Tidjai, days before his coronation. 68 | The United Nations Security Council Sanctions Committee permits those North Koreans on its global travel ban list to travel to Hanoi, Vietnam, for the second summit meeting with U.S. President Donald Trump. 69 | An Israeli team of scientists claim to have developed a cure for cancer. This claim is criticized by other scientists, who say it is likely faked. 70 | David Saint-Jacques becomes the fourth Canadian astronaut to take part in a spacewalk and the first in 12 years as he begins a roughly seven-hour mission. 71 | The United States Justice Department, Department of Homeland Security, Department of Commerce and Federal Bureau of Investigation announces 23 criminal charges against China's telecom Huawei and its chief financial officer Wanzhou Meng, which include banking and financial fraud, money laundering, wire fraud, conspiracy to defraud the United States, theft of trade secret technology, provided bonus to workers who stole confidential information from companies around the world, obstruction of justice and sanctions violations. 72 | A fast-moving fire has swept through a historic district of Bangladesh's capital Dhaka, killing at least 80 and wounding 50 others. 73 | Approximately one million Britons assemble for the People's Vote March in London, United Kingdom advocating for an additional referendum on Brexit. 74 | Voters in Moldova go to the polls to elect the new members of Parliament. 75 | The European People's Party votes to suspend Hungary's ruling Fidesz party citing its anti-immigration stance, and personal attacks on Jean-Claude Juncker and George Soros. Hungarian Prime Minister and Fidesz leader Viktor Orbán had threatened to pull out of the EPP if it was suspended. 76 | U.S. President Donald Trump agrees to keep about 400 U.S. troops in Syria. 77 | 32 Afghan border security troops are killed by a Taliban attack in Kandahar, Afghanistan. 78 | Partially-recognized Venezuelan Acting President Juan Guaidó says the country has "truly collapsed already", while accusing the Nicolás Maduro-led government of murdering 17 people during the ongoing nationwide power blackout. 79 | About 42,000 active-duty Coast Guard members, who continue to work on essential operations, missed their scheduled paycheck Tuesday. 
This is the first time that United States Armed Forces servicemembers have not been paid during a shutdown or other lapse in government appropriation. 80 | Sixteen people are killed when a magnitude 6.1 earthquake strikes the municipalities of Porac and Lubao, Pampanga province, Philippines at 17:11 Philippine Standard Time. 81 | Voters in North Macedonia head to the polls for the second round of the 2019 presidential election. Stevo Pendarovski of the Social Democratic Union of Macedonia wins over his opponent Gordana Siljanovska-Davkova. 82 | First Vice President of Afghanistan Abdul Rashid Dostum survives an assassination attempt on his convoy while traveling to the Jowzjan Province in Afghanistan. The attack killed one bodyguard and injured two others. 83 | One person is killed and another injured in a shooting in a church in the British Columbia city of Salmon Arm. The suspected shooter was wrestled to the ground by churchgoers, police say. 84 | Supreme Court Justice Christian Zerpa [es] of Venezuela defects to the United States and denounces the Government of Venezuela for rigging the election. 85 | Police arrest five men, three from the mining company Vale and two engineers from a subsidiary company, in connection with the mine collapse. 86 | The Attorney General of Israel Avichai Mandelblit says that after more than two years of investigations he has decided to indict Prime Minister Benjamin Netanyahu on charges of bribery, fraud and breach of trust. 87 | Jody Wilson-Raybould resigns as Canada's Minister of Veterans Affairs less than one month into her term amid allegations she had been pressured by the Prime Minister's Office to go easy on SNC-Lavalin while she was Minister of Justice. 88 | Roger Federer wins the 2019 Dubai Tennis Championships, his 100th ATP singles title overall, defeating Stefanos Tsitsipas, 6–4, 6–4, in the final. Federer became the second male tennis player in the Open Era after Jimmy Connors to win 100 ATP singles titles. 89 | Nationwide protests organised by the French Confédération Générale du Travail union attract over 18,000 people in Paris, including many supporters of the Yellow vest movement. (France24) (Daily Mail) 90 | Asia Press reports that North Korean leader Kim Jong-un ordered the execution by firing squad of four foreign ministry officials following the failure of his February Hanoi summit with U.S. President Donald Trump, after accusing them of "selling information to the U.S." before the summit. 91 | The European Parliament approves two revisions to the controversial Directive on Copyright in the Digital Single Market. One resolution includes new requirements aimed at making companies pay licensing fees to publications such as newspapers whose work gets aggregated by online services. The second revision makes online platforms such as Google, Facebook and YouTube liable for the content posted on their services, meaning that all content providers must get permission from rights holders before uploading copyrighted material of any kind. 92 | The Ethiopian government releases a preliminary report on the investigation into the crash of Ethiopian Airlines Flight 302. 93 | Voters in the Japanese prefecture of Okinawa go to the polls in a referendum on the central government's plan to move the Futenma airbase to Henoko in northern Okinawa Island. 72% of voters oppose the plan. 94 | Astana, the capital of Kazakhstan, is renamed Nursultan, after the former President. 
95 | The Attorney General files a request to withdraw immunity to the presidential candidate of National Unity of Hope, former First Lady Sandra Torres, for alleged crimes of illicit electoral financing. Torres said it was a political persecution of the former Attorney General and presidential candidate Thelma Aldana. 96 | The Pakistani Air Force (PAF) conducted six airstrikes in Indian-administered Kashmir, shooting down one Indian aircraft and capturing its pilot following a dogfight. Pakistani officials claimed that they have shot down two Indian Air Force (IAF) jets. Wreckage of one aircraft fell in Azad Kashmir while the other fell in Indian-administered Kashmir. Indian officials initially rejected that any of their pilot was in captured, however, once the footage of Indian pilot in Pakistan's custody were released on the internet, Indian officials admitted that their pilot was captured and his aircraft was shot down. India officials also claimed to have shot down one Pakistani Air Force (PAF) jet that violated its airspace. However, Indian officials claim were rejected by Pakistan. 97 | A major power blackout leaves most of Venezuela without electricity, including the capital Caracas. At least 18 of Venezuela's 23 states have reported blackouts. Venezuelan news website El Pitazo blames failures at Simón Bolívar hydroelectric plant; state TV blames anti-government saboteurs. 98 | A Piper Malibu light aircraft, carrying two people on a flight from Nantes to Cardiff, goes missing off the coast of Alderney in the Channel Islands. A major search and rescue operation is underway. Cardiff City F.C. footballer Emiliano Sala is confirmed to have been on board the missing aircraft. 99 | At least 21 miners die after a roof collapses in a coal mine in Shenmu, China. 100 | In response to a possible accidental launching of two rockets from the Gaza Strip towards Tel Aviv the day before, Israel launches hundreds of counter strikes directed at the town of Khan Yunis. 101 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | newsroom @ git+https://github.com/lil-lab/newsroom.git#egg=newsroom 2 | torch==1.5.0 3 | transformers==2.9.0 4 | numpy>=1.17 5 | -------------------------------------------------------------------------------- /research/coverage_and_density_analysis/analyse_dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import utils 3 | import pathlib 4 | import collections 5 | from pprint import pprint 6 | from nltk import word_tokenize 7 | from nltk import sent_tokenize 8 | from scipy import stats 9 | import matplotlib.pyplot as plt 10 | import seaborn as sns 11 | import random 12 | import tarfile 13 | import os 14 | from newsroom.analyze import Fragments 15 | 16 | 17 | def read_wcep(path): 18 | for event in utils.read_jsonl(path): 19 | 20 | if 0: 21 | articles = [a for a in event['articles'] if a['origin'] == 'WCEP'] 22 | texts = [f'{a["title"]}. {a["text"]}' for a in articles] 23 | else: 24 | texts = [f'{a["title"]}. 
{a["text"]}' for a in event['articles']] 25 | random.shuffle(texts) 26 | texts = texts[:100] 27 | src_sents = [s for text in texts for s in sent_tokenize(text)] 28 | if len(src_sents) == 0: 29 | continue 30 | summary = event['summary'] 31 | tgt_sents = sent_tokenize(summary) 32 | yield src_sents, tgt_sents 33 | 34 | 35 | def read_multinews(path): 36 | indir = pathlib.Path(path) 37 | sep = 'story_separator_special_tag' 38 | indir = pathlib.Path(indir) 39 | src_file = open(indir / 'train.src.txt') 40 | tgt_file = open(indir / 'train.tgt.txt') 41 | for src_line, tgt_line in zip(src_file, tgt_file): 42 | docs = src_line.split(sep) 43 | src_sents = [s for doc in docs for s in sent_tokenize(doc)] 44 | tgt_sents = sent_tokenize(tgt_line) 45 | # print("*" * 100) 46 | # print("TARGET:") 47 | # print(tgt_line) 48 | # print("*"*100) 49 | yield src_sents, tgt_sents 50 | 51 | 52 | def read_duc_2004_(root_dir): 53 | root_dir = pathlib.Path(root_dir) 54 | docs_dir = root_dir / 'DUC2004_Summarization_Documents/duc2004_testdata/tasks1and2/duc2004_tasks1and2_docs/docs' 55 | result_dir = root_dir / 'duc2004_results' 56 | 57 | def get_duc_cluster_docs(cluster_id): 58 | docs = [] 59 | cluster_path = docs_dir / f'd{cluster_id}t' 60 | for fpath in cluster_path.iterdir(): 61 | with open(fpath) as f: 62 | raw = f.read() 63 | text = raw.split("")[1].split("")[0] 64 | text = " ".join(text.split()) 65 | doc = { 66 | 'fname': fpath.name, 67 | 'cluster_id': cluster_id, 68 | 'text': text 69 | } 70 | docs.append(doc) 71 | docs = sorted(docs, key=lambda x: x['fname']) 72 | return docs 73 | 74 | cid_to_clusters = {} 75 | # get reference (models) and peer (participant systems) summaries 76 | for group in ["models", "peers"]: 77 | 78 | gz_path = result_dir / f'ROUGE/duc2004.task2.ROUGE.{group}.tar.gz' 79 | tar = tarfile.open(gz_path, "r:gz") 80 | for member in tar.getmembers(): 81 | 82 | author_id = member.name.split(".")[-1] 83 | cluster_id = member.name.split("/")[-1].split(".")[0].lstrip("D") 84 | 85 | # print(member.name) 86 | # print('CID:', cluster_id) 87 | # print() 88 | 89 | 90 | with tar.extractfile(member) as f: 91 | text = str(f.read(), encoding="UTF-8") 92 | text = " ".join(text.split()) 93 | 94 | summary_item = { 95 | 'author_id': author_id, 96 | 'text': text, 97 | 'cluster_id': cluster_id 98 | } 99 | 100 | if cluster_id not in cid_to_clusters: 101 | cid_to_clusters[cluster_id] = { 102 | 'peer_summaries': [], 103 | 'ref_summaries': [], 104 | 'id': cluster_id 105 | } 106 | 107 | if group == "models": 108 | cid_to_clusters[cluster_id]['ref_summaries'].append(summary_item) 109 | elif group == "peers": 110 | cid_to_clusters[cluster_id]['peer_summaries'].append(summary_item) 111 | 112 | # get source documents 113 | clusters = [] 114 | for cid, c in cid_to_clusters.items(): 115 | docs = get_duc_cluster_docs(cid) 116 | c['documents'] = docs 117 | print('CLUSTER:', cid, len(c['documents'])) 118 | clusters.append(c) 119 | clusters = sorted(clusters, key=lambda x: x['id']) 120 | print('#clusters:', len(clusters)) 121 | return clusters 122 | 123 | 124 | def read_duc_2004(path): 125 | for c in read_duc_2004_(path): 126 | src_sents = [s for d in c['documents'] for s in sent_tokenize(d['text'])] 127 | summary = c['ref_summaries'][0]['text'] 128 | tgt_sents = sent_tokenize(summary) 129 | print(summary) 130 | yield src_sents, tgt_sents 131 | 132 | 133 | def read_cnn_dm(path): 134 | 135 | def parse_cnn_dmm_file(text): 136 | in_sents = [] 137 | out_sents = [] 138 | summary_start = False 139 | for line in text.split('\n'): 140 | 
if line.strip() != '': 141 | if line == '@highlight': 142 | summary_start = True 143 | else: 144 | if summary_start: 145 | out_sents.append(line) 146 | else: 147 | in_sents.append(line) 148 | return in_sents, out_sents 149 | 150 | indir = pathlib.Path(path) 151 | for fpath in indir.iterdir(): 152 | text = fpath.read_text() 153 | in_sents, out_sents = parse_cnn_dmm_file(text) 154 | yield in_sents, out_sents 155 | 156 | 157 | def reconstruct_fusion(fragments, a_sents): 158 | indices=[] 159 | for f in fragments: 160 | f_indices = [] 161 | f_ = ' '.join(f) 162 | for i, s in enumerate(a_sents): 163 | s_ = ' '.join(word_tokenize(s)) 164 | if f_ in s_: 165 | f_indices.append(i) 166 | indices.append(f_indices) 167 | return indices 168 | 169 | 170 | def extract_fragments(a_tokens, s_tokens): 171 | a_size = len(a_tokens) 172 | s_size = len(s_tokens) 173 | F = [] 174 | i, j = 0, 0 175 | # i: for each summary token 176 | while i < s_size: 177 | f = [] 178 | # j: for each article token 179 | while j < a_size: 180 | # if a&s tokens match: 181 | 182 | if s_tokens[i] == a_tokens[j]: 183 | i_, j_ = i, j 184 | # look further until tokens don't match 185 | while s_tokens[i_] == a_tokens[j_]: 186 | i_ += 1 187 | j_ += 1 188 | if i_ >= s_size or j_ >= a_size: 189 | break 190 | # if new span is larger than previous fragment 191 | if len(f) < (i_ - i ): # maybe instead: i_ - i - 1 192 | f = s_tokens[i: i_] # maybe i_ - 1 193 | j = j_ 194 | else: 195 | j += 1 196 | i += max(len(f), 1) 197 | j = 0 198 | if len(f) > 1: 199 | F.append(f) 200 | return F 201 | 202 | 203 | def compute_compression(a_tokens, s_tokens): 204 | return len(a_tokens) / len(s_tokens) 205 | 206 | 207 | def compute_density(s_tokens, fragments): 208 | d = 0 209 | for frag in fragments: 210 | d += len(frag)**2 211 | return d / len(s_tokens) 212 | 213 | 214 | def compute_coverage(s_tokens, fragments): 215 | c = 0 216 | for frag in fragments: 217 | c += len(frag) 218 | return c / len(s_tokens) 219 | 220 | 221 | def make_kde_plots2(results, outpath): 222 | x = results['coverage'] 223 | y = results['density'] 224 | ax = sns.kdeplot(x, y, cmap="Reds", shade=True, shade_lowest=False) 225 | ax.set_xlim((-0.2, 1.0)) 226 | ax.set_ylim((-0.2, 5.0)) 227 | plt.savefig(outpath) 228 | 229 | #ax.savefig(outpath) 230 | 231 | 232 | def make_kde_plots(results, outpath): 233 | x = results['coverage'] 234 | y = results['density'] 235 | plt.scatter(x, y) 236 | plt.xlabel('Coverage') 237 | plt.ylabel('Density') 238 | plt.savefig(outpath) 239 | plt.close() 240 | 241 | 242 | def run(examples, args): 243 | results = collections.defaultdict(list) 244 | n = 0 245 | for i, (a_sents, s_sents) in enumerate(examples): 246 | 247 | if n >= 1000: 248 | break 249 | # 250 | # if i % 10 != 0: 251 | # continue 252 | 253 | if i % 100 == 0: 254 | print(i, n) 255 | 256 | summary = ' '.join(s_sents) 257 | text = ' '.join(a_sents) 258 | fragments = Fragments(summary, text) 259 | 260 | coverage = fragments.coverage() 261 | density = fragments.density() 262 | compression = fragments.compression() 263 | # 264 | # a_tokens = [w for s in a_sents for w in word_tokenize(s)] 265 | # s_tokens = [w for s in s_sents for w in word_tokenize(s)] 266 | # 267 | # if len(s_tokens) == 0 or len(a_tokens) == 0: 268 | # continue 269 | # 270 | # fragments = extract_fragments(a_tokens, s_tokens) 271 | # compression = compute_compression(a_tokens, s_tokens) 272 | # density = compute_density(s_tokens, fragments) 273 | # coverage = compute_coverage(s_tokens, fragments) 274 | # 275 | # if density > 0: 276 | # 
density = density / len(s_tokens) 277 | 278 | # 279 | # print("frags", len(fragments)) 280 | # print('COV', coverage, 'DEN', density, 'COMP', compression) 281 | # 282 | # for f in fragments: 283 | # print(f) 284 | # print() 285 | 286 | # 287 | # if coverage == 0: 288 | # print('coverage:', coverage) 289 | # 290 | # print('*** S ***') 291 | # for s in s_sents: 292 | # print(s) 293 | # 294 | # print() 295 | # print('*** A ***') 296 | # for s in a_sents[:5]: 297 | # print(s) 298 | # 299 | # print() 300 | # print() 301 | # 302 | # print() 303 | # print('*** FRAGMENTS ***:') 304 | # for f in fragments: 305 | # print(' '.join(f)) 306 | # print() 307 | # 308 | print('compression:', compression) 309 | print('density:', density) 310 | print('coverage:', coverage) 311 | print('='*100) 312 | 313 | results['compression'].append(compression) 314 | results['density'].append(density) 315 | results['coverage'].append(coverage) 316 | n += 1 317 | 318 | utils.writejson(results, args.o) 319 | #make_kde_plots2(results, args.o + '/kde.png') 320 | 321 | 322 | def main(args): 323 | examples = [] 324 | if args.corpus == 'cnn-dm': 325 | examples = read_cnn_dm(args.i) 326 | elif args.corpus == 'multinews': 327 | examples = read_multinews(args.i) 328 | elif args.corpus == 'wcep': 329 | examples = read_wcep(args.i) 330 | elif args.corpus == 'duc': 331 | examples = read_duc_2004(args.i) 332 | run(examples, args) 333 | 334 | 335 | if __name__ == '__main__': 336 | parser = argparse.ArgumentParser() 337 | parser.add_argument('--i', required=True) 338 | parser.add_argument('--o', required=True) 339 | parser.add_argument('--corpus', default='wcep') 340 | main(parser.parse_args()) 341 | -------------------------------------------------------------------------------- /research/coverage_and_density_analysis/plots.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import utils # local utils.py, as in analyse_dataset.py 4 | import pathlib 5 | import collections 6 | from pprint import pprint 7 | from nltk import word_tokenize 8 | from nltk import sent_tokenize 9 | from scipy import stats 10 | import matplotlib.pyplot as plt 11 | import matplotlib 12 | import seaborn as sns 13 | 14 | 15 | def main(args): 16 | 17 | """ 18 | --wcep1 ~/Desktop/WCEP/analysis/wcep_stats.C10.json \ 19 | --cnn ~/Desktop/WCEP/analysis/cnn_stats.json \ 20 | --multinews ~/Desktop/WCEP/analysis/multinews_stats.json 21 | """ 22 | dir_ = pathlib.Path('/home/demian/Desktop/WCEP/analysis') 23 | 24 | cnn_stats = utils.readjson(dir_ / 'cnn_stats.json') 25 | mn_stats = utils.readjson(dir_ / 'multinews_stats.json') 26 | wcep_stats_orig = utils.readjson(dir_ / 'wcep_stats_original.json') 27 | wcep_stats_10 = utils.readjson(dir_ / 'wcep_stats.C10.json') 28 | #wcep_stats_50 = utils.readjson(dir_ / 'wcep_stats.C50.json') 29 | wcep_stats_100 = utils.readjson(dir_ / 'wcep_stats.C100.json') 30 | duc_stats = utils.readjson(dir_ / 'duc_stats.json') 31 | all_stats = [wcep_stats_orig, wcep_stats_10, wcep_stats_100, cnn_stats, mn_stats, duc_stats] 32 | colors = ['Reds', 'Reds', 'Reds', 'Greens', 'Blues', 'Purples'] 33 | #colors = ['b', 'b', 'b', 'r', 'g', 'm', 'c'] 34 | names = ['WCEP-original', 'WCEP-10', 'WCEP-100', 'CNN', 'MultiNews', 'DUC'] 35 | #fig, ax = plt.subplots(1, 3, sharey=True) 36 | #plt.style.use('dark_background') 37 | 38 | fig, ax = plt.subplots(3, 2, sharey=True) 39 | 40 | plt.rcParams["patch.force_edgecolor"] = True 41 | 42 | n_to_coord = { 43 | 0: (0, 0), 44 | 1: (1, 0), 45 | 2: (2, 0), 46 | 3: (0, 1), 47 |
4: (1, 1), 48 | 5: (2, 1), 49 | } 50 | 51 | font = {'family': 'normal', 52 | 'color': 'black', 53 | 'weight': 'normal', 54 | 'size': 11, 55 | } 56 | 57 | for n in range(6): 58 | name = names[n] 59 | i, j = n_to_coord[n] 60 | ax_i = ax[i, j] 61 | ax_i.set_facecolor('white') 62 | 63 | print('Dataset:', names[n]) 64 | 65 | stats = all_stats[n] 66 | coverage = np.array(stats['coverage']) 67 | density = np.array(stats['density']) 68 | print('Cov:', min(coverage), max(coverage), np.mean(coverage), np.median(coverage)) 69 | print('Dense:', min(density), max(density), np.mean(density), np.median(density)) 70 | ax_i.text(0.1, 8, name, fontdict=font) 71 | ax_i.set_ylim((0.0, 10.0)) 72 | ax_i.set_xlim((0.0, 1.0)) 73 | #ax_i.scatter(coverage, density, c=colors[n]) 74 | sns.kdeplot( 75 | coverage, 76 | density, 77 | ax=ax_i, 78 | cmap=colors[n], 79 | shade=True, 80 | shade_lowest=False, 81 | ) 82 | 83 | # ax_i.patch.set_edgecolor('black') 84 | # ax_i.patch.set_linewidth('2') 85 | 86 | ax_i.patch.set_edgecolor('black') 87 | ax_i.patch.set_linewidth(0.8) 88 | 89 | ax[2, 0].set_xlabel('Extractive fragment coverage') 90 | ax[1, 0].set_ylabel('Extractive fragment density') 91 | # # 92 | # plt.rcParams["axes.edgecolor"] = "black" 93 | # plt.rcParams["axes.linewidth"] = 1 94 | # sns.set_style("white") 95 | plt.show() 96 | 97 | 98 | if __name__ == '__main__': 99 | parser = argparse.ArgumentParser() 100 | # parser.add_argument('--cnn') 101 | # parser.add_argument('--wcep') 102 | # parser.add_argument('--multinews') 103 | main(parser.parse_args()) 104 | -------------------------------------------------------------------------------- /research/coverage_and_density_analysis/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import pickle 4 | import datetime 5 | import codecs 6 | import gzip 7 | import io 8 | import csv 9 | import random 10 | import shutil 11 | 12 | 13 | random.seed(24) 14 | csv.field_size_limit(1000000) 15 | 16 | 17 | def force_mkdir(path): 18 | if os.path.exists(path): 19 | shutil.rmtree(path) 20 | os.mkdir(path) 21 | 22 | 23 | def wipe_file(path): 24 | with open(path, 'w') as f: 25 | pass 26 | 27 | 28 | def abs_listdir(dir): 29 | paths = [os.path.join(dir, x) for x in os.listdir(dir)] 30 | return sorted(paths) 31 | 32 | 33 | def readfile(path): 34 | with open(path) as f: 35 | text = f.read() 36 | return text 37 | 38 | 39 | def readlines(path): 40 | with open(path) as f: 41 | text = f.read() 42 | return text.split('\n') 43 | 44 | 45 | def writefile(s, path): 46 | with open(path, 'w') as f: 47 | f.write(s) 48 | 49 | 50 | def writelines(lines, path): 51 | with open(path, 'w') as f: 52 | f.write('\n'.join(lines)) 53 | 54 | 55 | def readjson(path): 56 | text = readfile(path) 57 | return json.loads(text) 58 | 59 | 60 | def writejson(obj, path): 61 | with open(path, 'w') as f: 62 | json.dump(obj, f) 63 | 64 | 65 | def load_pkl(path): 66 | with open(path, 'rb') as f: 67 | obj = pickle.load(f) 68 | return obj 69 | 70 | 71 | def dump_pkl(obj, path): 72 | with open(path, 'wb') as f: 73 | pickle.dump(obj, f) 74 | 75 | 76 | def periodic_print(i, n=10): 77 | if i % n == 0: 78 | print(i) 79 | 80 | 81 | def append_to_file(s, path): 82 | with open(path, 'a') as f: 83 | f.write(s) 84 | 85 | 86 | def split_dataset(inpath, trainpath, testpath, ratio=0.5): 87 | lines = readlines(inpath) 88 | split_at = int(ratio * len(lines)) 89 | trainlines = lines[:split_at] 90 | testlines = lines[split_at:] 91 | writelines(trainlines, trainpath) 92 | 
writelines(testlines, testpath) 93 | 94 | 95 | def get_best_scored(d): 96 | return sorted(d, key=lambda x: x[1], reverse=True)[0][0] 97 | 98 | 99 | def parse_isodate(s): 100 | return datetime.datetime.strptime(s, "%Y-%m-%dT%H:%M:%S.%fZ") 101 | 102 | 103 | def readfile2(path): 104 | with codecs.open(path, "r", encoding='utf-8', errors='ignore') as f: 105 | text = f.read() 106 | return text 107 | 108 | 109 | def write_gzip(text, path): 110 | with gzip.open(path, 'wb') as output: 111 | with io.TextIOWrapper(output, encoding='utf-8') as enc: 112 | enc.write(text) 113 | 114 | 115 | def read_gzip(path): 116 | with gzip.open(path, 'rb') as input_file: 117 | with io.TextIOWrapper(input_file) as dec: 118 | content = dec.read() 119 | return content 120 | 121 | 122 | def select_all_nth(list_of_lists, n): 123 | return [x[n] for x in list_of_lists] 124 | 125 | 126 | def get_date_range(start, end): 127 | diff = end - start 128 | date_range = [] 129 | for n in range(diff.days + 1): 130 | t = start + datetime.timedelta(days=n) 131 | date_range.append(t) 132 | return date_range 133 | 134 | 135 | def read_sheet(path, delimiter=','): 136 | '''Read csv, tsv, etc. file''' 137 | with open(path, 'r') as f: 138 | row_dicts = [] 139 | rows = list(csv.reader(f, delimiter=delimiter)) 140 | header = rows[0] 141 | idx_map = dict((x, i) for i, x in enumerate(header)) 142 | for row in rows[1:]: 143 | row_dict = {} 144 | for x, i in idx_map.items(): 145 | if i < len(row): 146 | row_dict[x] = row[i] 147 | row_dicts.append(row_dict) 148 | return row_dicts 149 | 150 | 151 | def write_sheet(sheet, path, header=None, write_header=True): 152 | if header is None: 153 | header = sorted(sheet[0].keys()) 154 | with open(path, 'w') as f: 155 | writer = csv.writer(f) 156 | if write_header: 157 | writer.writerow(header) 158 | for item in sheet: 159 | row = [] 160 | for k in header: 161 | if k in item: 162 | v = item[k] 163 | else: 164 | v = '' 165 | row.append(v) 166 | writer.writerow(row) 167 | 168 | 169 | def read_jsonl(path, load=False, start=0, stop=None): 170 | 171 | def read_jsonl_stream(path): 172 | with open(path) as f: 173 | for i, line in enumerate(f): 174 | if (stop is not None) and (i >= stop): 175 | break 176 | if i >= start: 177 | yield json.loads(line) 178 | 179 | data = read_jsonl_stream(path) 180 | if load: 181 | data = list(data) 182 | return data 183 | 184 | 185 | def write_jsonl(items, path, batch_size=100, override=True): 186 | if override: 187 | with open(path, 'w'): 188 | pass 189 | 190 | batch = [] 191 | for i, x in enumerate(items): 192 | if i > 0 and i % batch_size == 0: 193 | with open(path, 'a') as f: 194 | output = '\n'.join(batch) + '\n' 195 | f.write(output) 196 | batch = [] 197 | raw = json.dumps(x) 198 | batch.append(raw) 199 | 200 | if batch: 201 | with open(path, 'a') as f: 202 | output = '\n'.join(batch) + '\n' 203 | f.write(output) 204 | 205 | 206 | def read_tap_dataset_csv(path): 207 | data = [] 208 | with open(path) as f: 209 | reader = csv.reader(f) 210 | for text, label in reader: 211 | item = { 212 | 'text': text, 213 | 'label': label 214 | } 215 | data.append(item) 216 | return data 217 | 218 | 219 | def shuffled(items): 220 | items = items.copy() 221 | random.shuffle(items) 222 | return items 223 | 224 | 225 | def sample(items, n): 226 | return shuffled(items)[:n] 227 | -------------------------------------------------------------------------------- /research/duc2004_to_jsonl.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 
| "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import tarfile\n", 10 | "import pathlib\n", 11 | "import json\n", 12 | "\n", 13 | "\n", 14 | "def read_duc_2004_(root_dir):\n", 15 | " root_dir = pathlib.Path(root_dir)\n", 16 | " docs_dir = root_dir / 'DUC2004_Summarization_Documents/duc2004_testdata/tasks1and2/duc2004_tasks1and2_docs/docs'\n", 17 | " result_dir = root_dir / 'duc2004_results'\n", 18 | "\n", 19 | " def get_duc_cluster_docs(cluster_id):\n", 20 | " docs = []\n", 21 | " cluster_path = docs_dir / f'd{cluster_id}t'\n", 22 | " for fpath in cluster_path.iterdir():\n", 23 | " with open(fpath) as f:\n", 24 | " raw = f.read()\n", 25 | " text = raw.split(\"\")[1].split(\"\")[0]\n", 26 | " text = \" \".join(text.split())\n", 27 | " doc = {\n", 28 | " 'fname': fpath.name,\n", 29 | " 'cluster_id': cluster_id,\n", 30 | " 'text': text\n", 31 | " }\n", 32 | " docs.append(doc)\n", 33 | " docs = sorted(docs, key=lambda x: x['fname'])\n", 34 | " return docs\n", 35 | "\n", 36 | " cid_to_clusters = {}\n", 37 | " # get reference (models) and peer (participant systems) summaries\n", 38 | " for group in [\"models\", \"peers\"]:\n", 39 | "\n", 40 | " gz_path = result_dir / f'ROUGE/duc2004.task2.ROUGE.{group}.tar.gz'\n", 41 | " tar = tarfile.open(gz_path, \"r:gz\")\n", 42 | " for member in tar.getmembers():\n", 43 | "\n", 44 | " author_id = member.name.split(\".\")[-1]\n", 45 | " cluster_id = member.name.split(\"/\")[-1].split(\".\")[0].lstrip(\"D\")\n", 46 | "\n", 47 | " # print(member.name)\n", 48 | " # print('CID:', cluster_id)\n", 49 | " # print()\n", 50 | "\n", 51 | " with tar.extractfile(member) as f:\n", 52 | " text = str(f.read(), encoding=\"UTF-8\")\n", 53 | " text = \" \".join(text.split())\n", 54 | "\n", 55 | " summary_item = {\n", 56 | " 'author_id': author_id,\n", 57 | " 'text': text,\n", 58 | " 'cluster_id': cluster_id\n", 59 | " }\n", 60 | "\n", 61 | " if cluster_id not in cid_to_clusters:\n", 62 | " cid_to_clusters[cluster_id] = {\n", 63 | " 'peer_summaries': [],\n", 64 | " 'ref_summaries': [],\n", 65 | " 'id': cluster_id\n", 66 | " }\n", 67 | "\n", 68 | " if group == \"models\":\n", 69 | " cid_to_clusters[cluster_id]['ref_summaries'].append(summary_item)\n", 70 | " elif group == \"peers\":\n", 71 | " cid_to_clusters[cluster_id]['peer_summaries'].append(summary_item)\n", 72 | "\n", 73 | " # get source documents\n", 74 | " clusters = []\n", 75 | " for cid, c in cid_to_clusters.items():\n", 76 | " docs = get_duc_cluster_docs(cid)\n", 77 | " c['documents'] = docs\n", 78 | " print('CLUSTER:', cid, len(c['documents']))\n", 79 | " clusters.append(c)\n", 80 | " clusters = sorted(clusters, key=lambda x: x['id'])\n", 81 | " print('#clusters:', len(clusters))\n", 82 | " return clusters\n", 83 | "\n", 84 | "\n", 85 | "def read_duc_2004(path):\n", 86 | " for c in read_duc_2004_(path):\n", 87 | " docs = [d['text'] for d in c['documents']]\n", 88 | " summaries = [s['text'] for s in c['ref_summaries']]\n", 89 | " yield docs, summaries\n" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": 2, 95 | "metadata": {}, 96 | "outputs": [ 97 | { 98 | "name": "stdout", 99 | "output_type": "stream", 100 | "text": [ 101 | "CLUSTER: 30001 10\n", 102 | "CLUSTER: 30002 10\n", 103 | "CLUSTER: 30003 10\n", 104 | "CLUSTER: 30005 10\n", 105 | "CLUSTER: 30006 10\n", 106 | "CLUSTER: 30007 10\n", 107 | "CLUSTER: 30008 10\n", 108 | "CLUSTER: 30010 10\n", 109 | "CLUSTER: 30011 10\n", 110 | "CLUSTER: 30015 10\n", 111 | "CLUSTER: 30017 10\n", 112 
| "CLUSTER: 30020 10\n", 113 | "CLUSTER: 30022 10\n", 114 | "CLUSTER: 30024 10\n", 115 | "CLUSTER: 30026 10\n", 116 | "CLUSTER: 30027 10\n", 117 | "CLUSTER: 30028 10\n", 118 | "CLUSTER: 30029 10\n", 119 | "CLUSTER: 30031 10\n", 120 | "CLUSTER: 30033 10\n", 121 | "CLUSTER: 30034 10\n", 122 | "CLUSTER: 30036 10\n", 123 | "CLUSTER: 30037 10\n", 124 | "CLUSTER: 30038 10\n", 125 | "CLUSTER: 30040 10\n", 126 | "CLUSTER: 30042 10\n", 127 | "CLUSTER: 30044 10\n", 128 | "CLUSTER: 30045 10\n", 129 | "CLUSTER: 30046 10\n", 130 | "CLUSTER: 30047 10\n", 131 | "CLUSTER: 30048 10\n", 132 | "CLUSTER: 30049 10\n", 133 | "CLUSTER: 30050 10\n", 134 | "CLUSTER: 30051 10\n", 135 | "CLUSTER: 30053 10\n", 136 | "CLUSTER: 30055 10\n", 137 | "CLUSTER: 30056 10\n", 138 | "CLUSTER: 30059 10\n", 139 | "CLUSTER: 31001 10\n", 140 | "CLUSTER: 31008 10\n", 141 | "CLUSTER: 31009 10\n", 142 | "CLUSTER: 31013 10\n", 143 | "CLUSTER: 31022 10\n", 144 | "CLUSTER: 31026 10\n", 145 | "CLUSTER: 31031 10\n", 146 | "CLUSTER: 31032 10\n", 147 | "CLUSTER: 31033 10\n", 148 | "CLUSTER: 31038 10\n", 149 | "CLUSTER: 31043 10\n", 150 | "CLUSTER: 31050 10\n", 151 | "#clusters: 50\n", 152 | "[(10, 50)]\n", 153 | "Input stats:\n", 154 | "(588.272, 191996.402016, 438.17394036615184)\n", 155 | "Summary stats:\n", 156 | "(104.57, 28.145100000000003, 5.305195566612036)\n" 157 | ] 158 | } 159 | ], 160 | "source": [ 161 | "from collections import Counter\n", 162 | "import numpy as np\n", 163 | "from pathlib import Path\n", 164 | "\n", 165 | "\n", 166 | "DATADIR = '/home/chris/projects/aylien/dynamic-ensembles/data/DUC2004'\n", 167 | "\n", 168 | "cluster_rows = []\n", 169 | "\n", 170 | "article_cnts = Counter()\n", 171 | "summary_lens = []\n", 172 | "source_lens = []\n", 173 | "\n", 174 | "# DUC only has a test set\n", 175 | "with open(Path(DATADIR) / ('DUC2004_test.jsonl'), 'w') as out:\n", 176 | " for srcs, tgts in read_duc_2004(DATADIR):\n", 177 | " articles = [{'title': '', 'text': t} for t in srcs]\n", 178 | " out.write(f'{json.dumps({\"articles\": articles, \"summary\": tgts})}\\n')\n", 179 | " article_cnts.update([len(articles)])\n", 180 | " source_lens.extend([len(a.split()) for a in srcs])\n", 181 | " summary_lens.extend([len(t.split()) for t in tgts])\n", 182 | "\n", 183 | "print(article_cnts.most_common())\n", 184 | "print('Input stats:')\n", 185 | "print((np.mean(source_lens), np.var(source_lens), np.std(source_lens)))\n", 186 | "print('Summary stats:')\n", 187 | "print((np.mean(summary_lens), np.var(summary_lens), np.std(summary_lens)))\n" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": null, 193 | "metadata": {}, 194 | "outputs": [], 195 | "source": [] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": null, 200 | "metadata": {}, 201 | "outputs": [], 202 | "source": [] 203 | } 204 | ], 205 | "metadata": { 206 | "kernelspec": { 207 | "display_name": "Python 3", 208 | "language": "python", 209 | "name": "python3" 210 | }, 211 | "language_info": { 212 | "codemirror_mode": { 213 | "name": "ipython", 214 | "version": 3 215 | }, 216 | "file_extension": ".py", 217 | "mimetype": "text/x-python", 218 | "name": "python", 219 | "nbconvert_exporter": "python", 220 | "pygments_lexer": "ipython3", 221 | "version": "3.6.10" 222 | } 223 | }, 224 | "nbformat": 4, 225 | "nbformat_minor": 4 226 | } 227 | -------------------------------------------------------------------------------- /research/filter_multinews_duplicate_articles.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# filter out duplicates in MultiNews test set" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 5, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import json\n", 19 | "import copy\n", 20 | "from collections import defaultdict" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 6, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "test_lines = [json.loads(l) for l in open('../data/multi-news/test.jsonl')] \n", 30 | "\n", 31 | "a_to_c = defaultdict(list)\n", 32 | "\n", 33 | "for c_idx, c in enumerate(test_lines): \n", 34 | " for a in c['articles']: \n", 35 | " a_to_c[a['text']].append(c_idx) \n", 36 | "\n", 37 | "\n", 38 | "duplicate_articles = set()\n", 39 | "\n", 40 | "for item in a_to_c.items(): \n", 41 | " if len(a_to_c[item[0]]) > 1:\n", 42 | "# print(item) \n", 43 | " duplicate_articles.update([item[0]])\n", 44 | " \n", 45 | "filtered_test_lines = []\n", 46 | "for c_idx, c in enumerate(test_lines): \n", 47 | " filtered_articles = []\n", 48 | " for a in c['articles']:\n", 49 | " if a['text'] not in duplicate_articles:\n", 50 | " filtered_articles.append(a)\n", 51 | " \n", 52 | " filtered_c = copy.deepcopy(c)\n", 53 | " filtered_c['articles'] = filtered_articles\n", 54 | " \n", 55 | " filtered_test_lines.append(filtered_c)" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 7, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "assert len(test_lines) == len(filtered_test_lines)" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 11, 70 | "metadata": {}, 71 | "outputs": [ 72 | { 73 | "data": { 74 | "text/plain": [ 75 | "[(733,\n", 76 | " {'articles': [],\n", 77 | " 'summary': '– By now, everyone should know not to get too close to hippos in the wild or in captivity, but a California man didn\\'t get the memo. Officials at the Los Angeles Zoo tell the Los Angeles Times they\\'ve recruited the LAPD to help look for a most unusual sort of trespasser: a man caught on tape scaling a barrier and slapping one of two hippos in the pen below. The now-viral video, recorded from the other side of the enclosure by what sounds to be a young woman, shows the man slowly climbing over a railing, then reaching down to slap the butt of 4-year-old Rosie, who appeared to be snacking alongside her mother, Mara. Mara looks up briefly, and the person shooting the video can be heard giggling as the man lifts his arms up in an apparent victory stretch and runs off. But zoo officials warn that what happened was anything but funny. \"Any unauthorized interaction with an animal is unsafe for the animal and potentially unsafe for the patron,\" a zoo spokeswoman says, noting something like this also breaks down the animals\\' trust that zookeepers have worked so hard to instill. The BBC notes hippos are responsible for more human deaths in Africa than any other big animal. 
The suspect could be hit with a misdemeanor charge or other infraction—California law bars anyone from climbing into zoo enclosures—but the case is being investigated as a trespassing violation, not an animal cruelty one, as the animals didn\\'t seem particularly fazed by the incident.\\n'}),\n", 78 | " (1245,\n", 79 | " {'articles': [],\n", 80 | " 'summary': '– An Italian WWII pilot who died battling US pilots 70 years ago has been found 13 feet underground, his remains still at the controls of a fighter plane armed with machine guns and cannons, Discovery reports. Lt. Guerrino Bortolani went down in a losing battle against Allied planes on March 11, 1944, and hit the ground so hard that he literally vanished into the countryside outside Padua in northern Italy. \"The crash site is now a cornfield,\" says a member of the wreck-hunting crew that found Bortolani. \"We were able to find the remains with the help of an elderly man, who on that day witnessed the fighter going into a nosedive and hit the ground.\" Bortolani was flying the best Italian fighter plane (the Macchi C.205) in a squadron led by the renowned Italian ace Adriano Visconti. But they went up against a daunting strike by the Mediterranean Allied Strategic Air Force—which had sent 111 B-17 planes over Padua to drop more than 300 tons of bombs. Allies said the Axis defense was \"aggressive,\" but five German planes and four Italian planes went down. Bortolani was \"dutiful until the end,\" the Week notes, sitting on his closed parachute and wearing a ring given him by a fighter pilot academy. Wreck-hunters found several parts of the plane as well, including the tail wheel, control stick, and pieces of the engine. Bortolani is expected to have a proper burial once relatives are found. (Read about a German U-boat found off North Carolina.)\\n'}),\n", 81 | " (1972,\n", 82 | " {'articles': [],\n", 83 | " 'summary': '– A Florida police union is calling for a national boycott of Arby\\'s after one of the restaurant\\'s employees allegedly refused to serve an officer this week, CBS Miami reports. According to Local 10 News, 19-year-old Kenneth Davenport failed to serve Sgt. Jennifer Martin, who was in uniform and driving a patrol car according to USA Today, after she ordered in the drive-thru. Davenport\\'s manager allegedly told Martin that Davenport didn\\'t want to serve her because she was an officer and laughed while telling her Davenport had the right to not serve her, but then gave her the food himself. She was ultimately given a refund, though, after deciding she didn\\'t want to eat the food. “I am offended and appalled that an individual within our community would treat a police officer in such a manner,\" the Pembroke Pines chief of police tells CBS. \"It is unacceptable.\" In response to the alleged slight, officers\\' wives protested outside the restaurant today, and the Dade County Police Benevolent Association called for a boycott of Arby\\'s until the employee or employees responsible are fired, Local 10 reports. \"This is yet another example of the hostile treatment of our brave men and women simply because they wear a badge,\" according to a statement from the union president, who tells Local 10 he blames Obama for the lack of respect shown officers. Arby\\'s executives have apologized to the chief of police.\\n'}),\n", 84 | " (4148,\n", 85 | " {'articles': [],\n", 86 | " 'summary': '– What was Michael Phelps doing before he was arrested on DUI charges early yesterday? 
According to TMZ and its \"casino sources,\" he was on an eight-hour \"gambling binge.\" The sources say Phelps was playing blackjack and drinking beer in a private VIP room at Baltimore\\'s Horseshoe Casino starting around 5pm Monday; he left around 1am was pulled over around 1:40am a few miles from the casino. He\\'s said to be a regular there, usually playing poker; it\\'s not clear whether he played other card games Monday. Phelps apologized on Twitter yesterday: \"I understand the severity of my actions and take full responsibility,\" he wrote. \"I know these words may not mean much right now but I am deeply sorry to everyone I have let down.\"\\n'}),\n", 87 | " (4442,\n", 88 | " {'articles': [],\n", 89 | " 'summary': '– A month after stunning Hollywood and sparking questions about her mental stability with the news of her retirement, Amanda Bynes wants a do-over. \"I\\'ve unretired,\" the 24-year-old on-again-off-again actress tweeted. She\\'s a co-star of the buzzed-about high school comedy Easy A, notes People—and within minutes of \"unretiring,\" Bynes was tweeting about the trailer.\\n'}),\n", 90 | " (4882,\n", 91 | " {'articles': [],\n", 92 | " 'summary': '– A \"mini-mammoth\" the size of a baby elephant has been identified on the island of Crete. Mammuthus creticus is the tiniest mammoth ever found, and is another example of \"dwarfism\" on islands, where scare resources can keep animals small, notes the Telegraph. Fossilized teeth of the three-foot-tall mammoth were first discovered in 1904, but were initially believed to be elephant teeth. Scientists only recently re-examined them and determined they were evidence of a miniature mammoth. They also returned to the spot in Crete and discovered a mini leg bone. \"Dwarfism is a well-known evolutionary response of large mammals to island environments,\" said lead researcher Victoria Herrige from London\\'s Natural History Museum. \"Our findings show that on Crete, island dwarfism occurred to an extreme degree, producing the smallest mammoth known so far.\" Researchers believe the animals may have evolved from regular-sized mammoths as long as 3.5 million years ago.\\n'}),\n", 93 | " (5091,\n", 94 | " {'articles': [],\n", 95 | " 'summary': '– If President Obama absolutely killed Al Green\\'s \"Let\\'s Stay Together,\" Mitt Romney has apparently responded by just plain murdering a patriotic standard. The Republican candidate was caught on camera last night at a Florida campaign rally warbling \"America the Beautiful,\" which the Washington Post notes he has often called his favorite patriotic hymn. The unguarded moment on the eve of the primary is a departure for Romney, but the clip prompts Pier Morgan over at CNN to joke, \"I think this could be an actual issue.\"\\n'}),\n", 96 | " (5119,\n", 97 | " {'articles': [],\n", 98 | " 'summary': '– The Telegraph reported yesterday on a crazy court case in the UK: After a pregnant Italian woman, in town for business, had a panic attack, social service workers in Essex got a court order allowing the woman to be forcibly sedated and undergo a C-section so they could take her baby. Fifteen months later, the little girl is still with social service workers, who won\\'t return her to her mother. The case is now \"an international legal row,\" the Telegraph says, and the anonymous woman\\'s lawyers call it \"unprecedented.\" The woman was in Britain in July 2012 for an airline training course, and called police when she suffered the panic attack. 
They arrived while she was on the phone with her mother, who told police the woman suffered from bipolar disorder and was off her medication, according to a Telegraph columnist. Police took her to a psychiatric facility, and restrained her under the Mental Health Act when she said she wanted to go back to her hotel. She underwent the C-section after having been there five weeks. The case is ongoing; the mother says she has made a full recovery, but a judge nonetheless ruled that her daughter should be put up for adoption. More on the case here and here.\\n'}),\n", 99 | " (5478,\n", 100 | " {'articles': [],\n", 101 | " 'summary': \"– They've been living separately for a while, and now Mariah Carey and Nick Cannon are making the split official, reports People. Cannon filed divorce papers last month, and TMZ reports that the couple has worked out a deal to split their property. (They've also got 3-year-old twins.) The website takes note of an unusual feature of their prenup: If Cannon violates a confidentiality clause and talks about their marriage, he has to pay up $250,000; if Carey talks, she has to pay $500,000. The reason is simple: She's got a lot more money than he does, and an upcoming artist-in-residence gig in Vegas will only widen the gap.\\n\"})]" 102 | ] 103 | }, 104 | "execution_count": 11, 105 | "metadata": {}, 106 | "output_type": "execute_result" 107 | } 108 | ], 109 | "source": [ 110 | "[(idx, c) for idx, c in enumerate(filtered_test_lines) if len(c['articles']) == 0]" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": 19, 116 | "metadata": {}, 117 | "outputs": [ 118 | { 119 | "name": "stdout", 120 | "output_type": "stream", 121 | "text": [ 122 | "wrote 5613 to ../data/multi-news/filtered_test.jsonl, dropped 9 because they contained no actual articles\n", 123 | "the articles in 780 clusters changed after the filtering process\n" 124 | ] 125 | } 126 | ], 127 | "source": [ 128 | "# write out filtered dataset, dropping broken lines\n", 129 | "\n", 130 | "with open('../data/multi-news/filtered_test.jsonl', 'w') as out:\n", 131 | " dropped_lines = 0\n", 132 | " output_lines = 0\n", 133 | " changed_clusters = 0\n", 134 | " for c_idx, c in enumerate(filtered_test_lines):\n", 135 | " if len(c['articles']) > 0:\n", 136 | " out.write(f'{json.dumps(c)}\\n')\n", 137 | " output_lines += 1\n", 138 | " if len(test_lines[c_idx]['articles']) != len(c['articles']):\n", 139 | " changed_clusters += 1\n", 140 | " else:\n", 141 | " dropped_lines += 1\n", 142 | " print(f'wrote {output_lines} to {out.name}, dropped {dropped_lines} because they contained no actual articles')\n", 143 | " print(f'the articles in {changed_clusters} clusters changed after the filtering process')\n" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": null, 149 | "metadata": {}, 150 | "outputs": [], 151 | "source": [] 152 | } 153 | ], 154 | "metadata": { 155 | "kernelspec": { 156 | "display_name": "Python 3", 157 | "language": "python", 158 | "name": "python3" 159 | }, 160 | "language_info": { 161 | "codemirror_mode": { 162 | "name": "ipython", 163 | "version": 3 164 | }, 165 | "file_extension": ".py", 166 | "mimetype": "text/x-python", 167 | "name": "python", 168 | "nbconvert_exporter": "python", 169 | "pygments_lexer": "ipython3", 170 | "version": "3.6.10" 171 | } 172 | }, 173 | "nbformat": 4, 174 | "nbformat_minor": 4 175 | } 176 | -------------------------------------------------------------------------------- /research/follow_lebanoff_et_al_2018_evaluation.ipynb: 
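The notebook below reproduces the Lebanoff et al. (2018) evaluation setup and cross-checks this repo's ROUGE implementation against the pyrouge-based scorer from the Multi-News codebase. The pyrouge protocol it relies on is file-based: candidate and reference summaries are written to parallel directories whose filenames share a numeric index, and the two filename patterns must agree on that index. A minimal sketch of just that protocol, assuming the directories are already populated (the directory arguments and function name are illustrative):

```python
import pyrouge

def rouge155_scores(cand_dir, ref_dir):
    """Score cand_dir/cand.N.txt against ref_dir/ref.N.txt with ROUGE-1.5.5."""
    r = pyrouge.Rouge155()
    r.system_dir = cand_dir                        # system (candidate) summaries
    r.model_dir = ref_dir                          # gold (reference) summaries
    r.system_filename_pattern = r'cand.(\d+).txt'  # regex group captures the shared index
    r.model_filename_pattern = 'ref.#ID#.txt'      # '#ID#' binds to the captured group
    return r.output_to_dict(r.convert_and_evaluate())
```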
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# recreate the lebanoff et al 2018 evaluation setup\n", 10 | "# adapted from here: \n", 11 | "# https://github.com/ucfnlp/multidoc_summarization/blob/ae30c9ee039d4ad5ff64fd2245faafc5a62c4dd7/decode.py\n", 12 | "\n", 13 | "# installing pyrouge\n", 14 | "# https://stackoverflow.com/questions/45894212/installing-pyrouge-gets-error-in-ubuntu" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 6, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "from pathlib import Path\n", 24 | "import json\n", 25 | "import os\n", 26 | "import logging\n", 27 | "\n", 28 | "from transformer_decoding.evaluate import lebanoff_2018_rouge, evaluate_rouge, print_mean" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 21, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "# TEMP_EVAL_DIR = Path('rouge_evaluation_tempdir')\n", 38 | "# os.makedirs(TEMP_EVAL_DIR, exist_ok=True)\n", 39 | "\n", 40 | "\n", 41 | "# DUC 2004\n", 42 | "# evaluation_dataset = '/home/chris/projects/aylien/dynamic-ensembles/data/DUC2004/DUC2004_test.jsonl'\n", 43 | "\n", 44 | "# system_hypotheses = '/home/chris/projects/aylien/dynamic-ensembles/data/DUC2004/system_hypotheses/2020-05-14/eval_predicted_summaries.out.1_doc_in_cluster'\n", 45 | "# system_hypotheses = '/home/chris/projects/aylien/dynamic-ensembles/data/DUC2004/system_hypotheses/2020-05-14/eval_predicted_summaries.out.5_docs_in_cluster'\n", 46 | "# system_hypotheses = '/home/chris/projects/aylien/dynamic-ensembles/data/DUC2004/system_hypotheses/2020-05-14/eval_predicted_summaries.out.8_docs_in_cluster'\n", 47 | "\n", 48 | "# MultiNews\n", 49 | "evaluation_dataset = '../data/downloads_from_MultiNews_gdrive/test.jsonl'\n", 50 | "\n", 51 | "# system_hypotheses = '../MultiNews_1_article_eval_predicted_summaries.out'\n", 52 | "# system_hypotheses = '../data/downloads_from_MultiNews_gdrive/export_output/transformer.txt'\n", 53 | "\n", 54 | "# system_hypotheses = '../data/downloads_from_MultiNews_gdrive/export_output/hi_map.txt'\n", 55 | "\n", 56 | "\n", 57 | "# WCEP\n", 58 | "evaluation_dataset = '../data/WCEP/test.jsonl'\n", 59 | "system_hypotheses = '../wcep_5_articles_eval_predicted_summaries.out'\n", 60 | "\n", 61 | "\n", 62 | "# TODO: rm tempdir after eval\n", 63 | "dataset_rows = [json.loads(l) for l in open(evaluation_dataset)]\n", 64 | "orig_system_hyps = [h.strip() for h in open(system_hypotheses)]" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 22, 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "# lebanoff_2018_rouge(system_hypotheses, evaluation_dataset)" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": 23, 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "\n", 83 | "# Hi-Map Lebanoff Eval\n", 84 | "# system_hypotheses = '../data/downloads_from_MultiNews_gdrive/export_output/hi_map.txt'\n", 85 | "\n", 86 | "# ROUGE-1:\n", 87 | "# rouge_1_f_score: 0.4028 with confidence interval (0.4002, 0.4056)\n", 88 | "# rouge_1_recall: 0.3986 with confidence interval (0.3960, 0.4013)\n", 89 | "# rouge_1_precision: 0.4086 with confidence interval (0.4061, 0.4115)\n", 90 | "\n", 91 | "# ROUGE-2:\n", 92 | "# rouge_2_f_score: 0.1374 with confidence interval (0.1344, 0.1406)\n", 93 | "# rouge_2_recall: 0.1361 with confidence interval 
(0.1331, 0.1394)\n", 94 | "# rouge_2_precision: 0.1392 with confidence interval (0.1360, 0.1424)\n", 95 | "\n", 96 | "# ROUGE-l:\n", 97 | "# rouge_l_f_score: 0.3460 with confidence interval (0.3433, 0.3488)\n", 98 | "# rouge_l_recall: 0.3424 with confidence interval (0.3398, 0.3452)\n", 99 | "# rouge_l_precision: 0.3511 with confidence interval (0.3484, 0.3539)\n", 100 | "\n", 101 | "# ROUGE-s4:\n", 102 | "# rouge_s4_f_score: 0.1145 with confidence interval (0.1117, 0.1174)\n", 103 | "# rouge_s4_recall: 0.1134 with confidence interval (0.1106, 0.1163)\n", 104 | "# rouge_s4_precision: 0.1160 with confidence interval (0.1132, 0.1190)\n", 105 | "\n", 106 | "# ROUGE-su4:\n", 107 | "# rouge_su4_f_score: 0.1633 with confidence interval (0.1607, 0.1663)\n", 108 | "# rouge_su4_recall: 0.1617 with confidence interval (0.1590, 0.1647)\n", 109 | "# rouge_su4_precision: 0.1657 with confidence interval (0.1629, 0.1686)\n", 110 | "\n" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": 24, 116 | "metadata": {}, 117 | "outputs": [], 118 | "source": [ 119 | "# dataset_rows = [json.loads(l) for l in open(evaluation_dataset)]\n", 120 | "gold_outputs = [json.loads(l)['summary'] for l in open(evaluation_dataset)]\n", 121 | "orig_system_hyps = [h.strip() for h in open(system_hypotheses)]" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": 25, 127 | "metadata": {}, 128 | "outputs": [ 129 | { 130 | "name": "stdout", 131 | "output_type": "stream", 132 | "text": [ 133 | "Evaluation dataset length: 1022\n" 134 | ] 135 | } 136 | ], 137 | "source": [ 138 | "assert len(gold_outputs) == len(orig_system_hyps)\n", 139 | "print(f'Evaluation dataset length: {len(gold_outputs)}')" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": 26, 145 | "metadata": {}, 146 | "outputs": [ 147 | { 148 | "name": "stdout", 149 | "output_type": "stream", 150 | "text": [ 151 | "rouge-1 p: 0.297 r: 0.495 f: 0.354\n", 152 | "rouge-2 p: 0.125 r: 0.218 f: 0.151\n", 153 | "rouge-l p: 0.213 r: 0.364 f: 0.256\n" 154 | ] 155 | } 156 | ], 157 | "source": [ 158 | "# our evaluation implementation (Ghalandari et al 2020)\n", 159 | "print_mean(*evaluate_rouge(orig_system_hyps, gold_outputs))" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": 18, 165 | "metadata": {}, 166 | "outputs": [], 167 | "source": [ 168 | "# double check multinews evaluation, code taken from: \n", 169 | "# https://github.com/Alex-Fabbri/Multi-News/blob/3675e7c422ae3b4020617a324ac264f50333357d/code/OpenNMT-py-baselines/tools/test_rouge.py\n", 170 | "# -*- encoding: utf-8 -*-\n", 171 | "import argparse\n", 172 | "import os\n", 173 | "import time\n", 174 | "import pyrouge\n", 175 | "import shutil\n", 176 | "import sys\n", 177 | "import codecs\n", 178 | "\n", 179 | "# from onmt.utils.logging import init_logger, logger\n", 180 | "\n", 181 | "\n", 182 | "def test_rouge(candidates, references):\n", 183 | " \"\"\"Calculate ROUGE scores of sequences passed as an iterator\n", 184 | " e.g. 
a list of str, an open file, StringIO or even sys.stdin\n", 185 | " \"\"\"\n", 186 | " current_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())\n", 187 | " tmp_dir = \".rouge-tmp-{}\".format(current_time)\n", 188 | " try:\n", 189 | " if not os.path.isdir(tmp_dir):\n", 190 | " os.mkdir(tmp_dir)\n", 191 | " os.mkdir(tmp_dir + \"/candidate\")\n", 192 | " os.mkdir(tmp_dir + \"/reference\")\n", 193 | "# candidates = [line.strip() for line in cand]\n", 194 | "# references = [line.strip() for line in ref]\n", 195 | " assert len(candidates) == len(references)\n", 196 | " cnt = len(candidates)\n", 197 | " for i in range(cnt):\n", 198 | " if len(references[i]) < 1:\n", 199 | " continue\n", 200 | " with open(tmp_dir + \"/candidate/cand.{}.txt\".format(i), \"w\",\n", 201 | " encoding=\"utf-8\") as f:\n", 202 | " f.write(candidates[i])\n", 203 | " with open(tmp_dir + \"/reference/ref.{}.txt\".format(i), \"w\",\n", 204 | " encoding=\"utf-8\") as f:\n", 205 | " f.write(references[i])\n", 206 | " r = pyrouge.Rouge155()\n", 207 | " r.model_dir = tmp_dir + \"/reference/\"\n", 208 | " r.system_dir = tmp_dir + \"/candidate/\"\n", 209 | " r.model_filename_pattern = 'ref.#ID#.txt'\n", 210 | " r.system_filename_pattern = 'cand.(\\d+).txt'\n", 211 | " rouge_results = r.convert_and_evaluate()\n", 212 | " results_dict = r.output_to_dict(rouge_results)\n", 213 | " return results_dict\n", 214 | " finally:\n", 215 | " pass\n", 216 | " if os.path.isdir(tmp_dir):\n", 217 | " shutil.rmtree(tmp_dir)\n", 218 | "\n", 219 | "\n", 220 | "def rouge_results_to_str(results_dict):\n", 221 | " return \">> ROUGE(1/2/3/L/SU4): {:.2f}/{:.2f}/{:.2f}/{:.2f}/{:.2f}\".format(\n", 222 | " results_dict[\"rouge_1_f_score\"] * 100,\n", 223 | " results_dict[\"rouge_2_f_score\"] * 100,\n", 224 | " results_dict[\"rouge_3_f_score\"] * 100,\n", 225 | " results_dict[\"rouge_l_f_score\"] * 100,\n", 226 | " results_dict[\"rouge_su*_f_score\"] * 100)\n", 227 | "\n", 228 | "\n" 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": 19, 234 | "metadata": {}, 235 | "outputs": [], 236 | "source": [ 237 | "fabbri_results = test_rouge(orig_system_hyps, gold_outputs)" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": 20, 243 | "metadata": {}, 244 | "outputs": [ 245 | { 246 | "data": { 247 | "text/plain": [ 248 | "{'rouge_1_recall': 0.40213,\n", 249 | " 'rouge_1_recall_cb': 0.39868,\n", 250 | " 'rouge_1_recall_ce': 0.40566,\n", 251 | " 'rouge_1_precision': 0.54339,\n", 252 | " 'rouge_1_precision_cb': 0.54037,\n", 253 | " 'rouge_1_precision_ce': 0.54632,\n", 254 | " 'rouge_1_f_score': 0.44168,\n", 255 | " 'rouge_1_f_score_cb': 0.4389,\n", 256 | " 'rouge_1_f_score_ce': 0.44449,\n", 257 | " 'rouge_2_recall': 0.14777,\n", 258 | " 'rouge_2_recall_cb': 0.14517,\n", 259 | " 'rouge_2_recall_ce': 0.15075,\n", 260 | " 'rouge_2_precision': 0.19443,\n", 261 | " 'rouge_2_precision_cb': 0.19144,\n", 262 | " 'rouge_2_precision_ce': 0.19754,\n", 263 | " 'rouge_2_f_score': 0.16047,\n", 264 | " 'rouge_2_f_score_cb': 0.15786,\n", 265 | " 'rouge_2_f_score_ce': 0.1633,\n", 266 | " 'rouge_3_recall': 0.08085,\n", 267 | " 'rouge_3_recall_cb': 0.07846,\n", 268 | " 'rouge_3_recall_ce': 0.0836,\n", 269 | " 'rouge_3_precision': 0.10344,\n", 270 | " 'rouge_3_precision_cb': 0.10063,\n", 271 | " 'rouge_3_precision_ce': 0.10641,\n", 272 | " 'rouge_3_f_score': 0.08713,\n", 273 | " 'rouge_3_f_score_cb': 0.08471,\n", 274 | " 'rouge_3_f_score_ce': 0.0899,\n", 275 | " 'rouge_4_recall': 0.05704,\n", 276 | " 'rouge_4_recall_cb': 
0.05477,\n", 277 | " 'rouge_4_recall_ce': 0.05962,\n", 278 | " 'rouge_4_precision': 0.07243,\n", 279 | " 'rouge_4_precision_cb': 0.06982,\n", 280 | " 'rouge_4_precision_ce': 0.07518,\n", 281 | " 'rouge_4_f_score': 0.06138,\n", 282 | " 'rouge_4_f_score_cb': 0.05906,\n", 283 | " 'rouge_4_f_score_ce': 0.06398,\n", 284 | " 'rouge_l_recall': 0.1952,\n", 285 | " 'rouge_l_recall_cb': 0.1929,\n", 286 | " 'rouge_l_recall_ce': 0.19779,\n", 287 | " 'rouge_l_precision': 0.26637,\n", 288 | " 'rouge_l_precision_cb': 0.26372,\n", 289 | " 'rouge_l_precision_ce': 0.269,\n", 290 | " 'rouge_l_f_score': 0.2138,\n", 291 | " 'rouge_l_f_score_cb': 0.21174,\n", 292 | " 'rouge_l_f_score_ce': 0.21614,\n", 293 | " 'rouge_w_1.2_recall': 0.04501,\n", 294 | " 'rouge_w_1.2_recall_cb': 0.04436,\n", 295 | " 'rouge_w_1.2_recall_ce': 0.04578,\n", 296 | " 'rouge_w_1.2_precision': 0.17602,\n", 297 | " 'rouge_w_1.2_precision_cb': 0.17376,\n", 298 | " 'rouge_w_1.2_precision_ce': 0.17825,\n", 299 | " 'rouge_w_1.2_f_score': 0.06859,\n", 300 | " 'rouge_w_1.2_f_score_cb': 0.06773,\n", 301 | " 'rouge_w_1.2_f_score_ce': 0.06959,\n", 302 | " 'rouge_s*_recall': 0.15868,\n", 303 | " 'rouge_s*_recall_cb': 0.15606,\n", 304 | " 'rouge_s*_recall_ce': 0.1615,\n", 305 | " 'rouge_s*_precision': 0.27614,\n", 306 | " 'rouge_s*_precision_cb': 0.27315,\n", 307 | " 'rouge_s*_precision_ce': 0.27936,\n", 308 | " 'rouge_s*_f_score': 0.17573,\n", 309 | " 'rouge_s*_f_score_cb': 0.17353,\n", 310 | " 'rouge_s*_f_score_ce': 0.17795,\n", 311 | " 'rouge_su*_recall': 0.16112,\n", 312 | " 'rouge_su*_recall_cb': 0.15848,\n", 313 | " 'rouge_su*_recall_ce': 0.16395,\n", 314 | " 'rouge_su*_precision': 0.28006,\n", 315 | " 'rouge_su*_precision_cb': 0.27705,\n", 316 | " 'rouge_su*_precision_ce': 0.28329,\n", 317 | " 'rouge_su*_f_score': 0.17844,\n", 318 | " 'rouge_su*_f_score_cb': 0.17623,\n", 319 | " 'rouge_su*_f_score_ce': 0.18067}" 320 | ] 321 | }, 322 | "execution_count": 20, 323 | "metadata": {}, 324 | "output_type": "execute_result" 325 | } 326 | ], 327 | "source": [ 328 | "fabbri_results" 329 | ] 330 | }, 331 | { 332 | "cell_type": "code", 333 | "execution_count": null, 334 | "metadata": {}, 335 | "outputs": [], 336 | "source": [ 337 | "# Fabbri results\n", 338 | "# system_hypotheses = '../data/downloads_from_MultiNews_gdrive/export_output/hi_map.txt'\n", 339 | "{'rouge_1_recall': 0.40213,\n", 340 | " 'rouge_1_recall_cb': 0.39868,\n", 341 | " 'rouge_1_recall_ce': 0.40566,\n", 342 | " 'rouge_1_precision': 0.54339,\n", 343 | " 'rouge_1_precision_cb': 0.54037,\n", 344 | " 'rouge_1_precision_ce': 0.54632,\n", 345 | " 'rouge_1_f_score': 0.44168,\n", 346 | " 'rouge_1_f_score_cb': 0.4389,\n", 347 | " 'rouge_1_f_score_ce': 0.44449,\n", 348 | " 'rouge_2_recall': 0.14777,\n", 349 | " 'rouge_2_recall_cb': 0.14517,\n", 350 | " 'rouge_2_recall_ce': 0.15075,\n", 351 | " 'rouge_2_precision': 0.19443,\n", 352 | " 'rouge_2_precision_cb': 0.19144,\n", 353 | " 'rouge_2_precision_ce': 0.19754,\n", 354 | " 'rouge_2_f_score': 0.16047,\n", 355 | " 'rouge_2_f_score_cb': 0.15786,\n", 356 | " 'rouge_2_f_score_ce': 0.1633,\n", 357 | " 'rouge_3_recall': 0.08085,\n", 358 | " 'rouge_3_recall_cb': 0.07846,\n", 359 | " 'rouge_3_recall_ce': 0.0836,\n", 360 | " 'rouge_3_precision': 0.10344,\n", 361 | " 'rouge_3_precision_cb': 0.10063,\n", 362 | " 'rouge_3_precision_ce': 0.10641,\n", 363 | " 'rouge_3_f_score': 0.08713,\n", 364 | " 'rouge_3_f_score_cb': 0.08471,\n", 365 | " 'rouge_3_f_score_ce': 0.0899,\n", 366 | " 'rouge_4_recall': 0.05704,\n", 367 | " 'rouge_4_recall_cb': 0.05477,\n", 368 | " 
'rouge_4_recall_ce': 0.05962,\n", 369 | " 'rouge_4_precision': 0.07243,\n", 370 | " 'rouge_4_precision_cb': 0.06982,\n", 371 | " 'rouge_4_precision_ce': 0.07518,\n", 372 | " 'rouge_4_f_score': 0.06138,\n", 373 | " 'rouge_4_f_score_cb': 0.05906,\n", 374 | " 'rouge_4_f_score_ce': 0.06398,\n", 375 | " 'rouge_l_recall': 0.1952,\n", 376 | " 'rouge_l_recall_cb': 0.1929,\n", 377 | " 'rouge_l_recall_ce': 0.19779,\n", 378 | " 'rouge_l_precision': 0.26637,\n", 379 | " 'rouge_l_precision_cb': 0.26372,\n", 380 | " 'rouge_l_precision_ce': 0.269,\n", 381 | " 'rouge_l_f_score': 0.2138,\n", 382 | " 'rouge_l_f_score_cb': 0.21174,\n", 383 | " 'rouge_l_f_score_ce': 0.21614,\n", 384 | " 'rouge_w_1.2_recall': 0.04501,\n", 385 | " 'rouge_w_1.2_recall_cb': 0.04436,\n", 386 | " 'rouge_w_1.2_recall_ce': 0.04578,\n", 387 | " 'rouge_w_1.2_precision': 0.17602,\n", 388 | " 'rouge_w_1.2_precision_cb': 0.17376,\n", 389 | " 'rouge_w_1.2_precision_ce': 0.17825,\n", 390 | " 'rouge_w_1.2_f_score': 0.06859,\n", 391 | " 'rouge_w_1.2_f_score_cb': 0.06773,\n", 392 | " 'rouge_w_1.2_f_score_ce': 0.06959,\n", 393 | " 'rouge_s*_recall': 0.15868,\n", 394 | " 'rouge_s*_recall_cb': 0.15606,\n", 395 | " 'rouge_s*_recall_ce': 0.1615,\n", 396 | " 'rouge_s*_precision': 0.27614,\n", 397 | " 'rouge_s*_precision_cb': 0.27315,\n", 398 | " 'rouge_s*_precision_ce': 0.27936,\n", 399 | " 'rouge_s*_f_score': 0.17573,\n", 400 | " 'rouge_s*_f_score_cb': 0.17353,\n", 401 | " 'rouge_s*_f_score_ce': 0.17795,\n", 402 | " 'rouge_su*_recall': 0.16112,\n", 403 | " 'rouge_su*_recall_cb': 0.15848,\n", 404 | " 'rouge_su*_recall_ce': 0.16395,\n", 405 | " 'rouge_su*_precision': 0.28006,\n", 406 | " 'rouge_su*_precision_cb': 0.27705,\n", 407 | " 'rouge_su*_precision_ce': 0.28329,\n", 408 | " 'rouge_su*_f_score': 0.17844,\n", 409 | " 'rouge_su*_f_score_cb': 0.17623,\n", 410 | " 'rouge_su*_f_score_ce': 0.18067}\n", 411 | "\n", 412 | "\n", 413 | "# Fabbri results\n", 414 | "# system_hypotheses = '../MultiNews_1_article_eval_predicted_summaries.out'\n", 415 | "{'rouge_1_recall': 0.44188,\n", 416 | " 'rouge_1_recall_cb': 0.43897,\n", 417 | " 'rouge_1_recall_ce': 0.44475,\n", 418 | " 'rouge_1_precision': 0.5006,\n", 419 | " 'rouge_1_precision_cb': 0.4972,\n", 420 | " 'rouge_1_precision_ce': 0.5039,\n", 421 | " 'rouge_1_f_score': 0.45928,\n", 422 | " 'rouge_1_f_score_cb': 0.45649,\n", 423 | " 'rouge_1_f_score_ce': 0.4619,\n", 424 | " 'rouge_2_recall': 0.15733,\n", 425 | " 'rouge_2_recall_cb': 0.15442,\n", 426 | " 'rouge_2_recall_ce': 0.16038,\n", 427 | " 'rouge_2_precision': 0.17746,\n", 428 | " 'rouge_2_precision_cb': 0.1743,\n", 429 | " 'rouge_2_precision_ce': 0.18095,\n", 430 | " 'rouge_2_f_score': 0.1633,\n", 431 | " 'rouge_2_f_score_cb': 0.16045,\n", 432 | " 'rouge_2_f_score_ce': 0.16641,\n", 433 | " 'rouge_3_recall': 0.08404,\n", 434 | " 'rouge_3_recall_cb': 0.08145,\n", 435 | " 'rouge_3_recall_ce': 0.08676,\n", 436 | " 'rouge_3_precision': 0.09447,\n", 437 | " 'rouge_3_precision_cb': 0.09159,\n", 438 | " 'rouge_3_precision_ce': 0.09757,\n", 439 | " 'rouge_3_f_score': 0.08718,\n", 440 | " 'rouge_3_f_score_cb': 0.08444,\n", 441 | " 'rouge_3_f_score_ce': 0.09007,\n", 442 | " 'rouge_4_recall': 0.05816,\n", 443 | " 'rouge_4_recall_cb': 0.05576,\n", 444 | " 'rouge_4_recall_ce': 0.06072,\n", 445 | " 'rouge_4_precision': 0.06546,\n", 446 | " 'rouge_4_precision_cb': 0.0628,\n", 447 | " 'rouge_4_precision_ce': 0.06827,\n", 448 | " 'rouge_4_f_score': 0.06042,\n", 449 | " 'rouge_4_f_score_cb': 0.05795,\n", 450 | " 'rouge_4_f_score_ce': 0.06308,\n", 451 | " 'rouge_l_recall': 
0.22003,\n", 452 | " 'rouge_l_recall_cb': 0.21755,\n", 453 | " 'rouge_l_recall_ce': 0.22262,\n", 454 | " 'rouge_l_precision': 0.24691,\n", 455 | " 'rouge_l_precision_cb': 0.24421,\n", 456 | " 'rouge_l_precision_ce': 0.24985,\n", 457 | " 'rouge_l_f_score': 0.22757,\n", 458 | " 'rouge_l_f_score_cb': 0.22522,\n", 459 | " 'rouge_l_f_score_ce': 0.23011,\n", 460 | " 'rouge_w_1.2_recall': 0.04994,\n", 461 | " 'rouge_w_1.2_recall_cb': 0.04921,\n", 462 | " 'rouge_w_1.2_recall_ce': 0.05076,\n", 463 | " 'rouge_w_1.2_precision': 0.16026,\n", 464 | " 'rouge_w_1.2_precision_cb': 0.15804,\n", 465 | " 'rouge_w_1.2_precision_ce': 0.16262,\n", 466 | " 'rouge_w_1.2_f_score': 0.07433,\n", 467 | " 'rouge_w_1.2_f_score_cb': 0.0733,\n", 468 | " 'rouge_w_1.2_f_score_ce': 0.07548,\n", 469 | " 'rouge_s*_recall': 0.18912,\n", 470 | " 'rouge_s*_recall_cb': 0.18649,\n", 471 | " 'rouge_s*_recall_ce': 0.19164,\n", 472 | " 'rouge_s*_precision': 0.24401,\n", 473 | " 'rouge_s*_precision_cb': 0.24069,\n", 474 | " 'rouge_s*_precision_ce': 0.2473,\n", 475 | " 'rouge_s*_f_score': 0.19723,\n", 476 | " 'rouge_s*_f_score_cb': 0.19477,\n", 477 | " 'rouge_s*_f_score_ce': 0.19965,\n", 478 | " 'rouge_su*_recall': 0.19164,\n", 479 | " 'rouge_su*_recall_cb': 0.189,\n", 480 | " 'rouge_su*_recall_ce': 0.19416,\n", 481 | " 'rouge_su*_precision': 0.24676,\n", 482 | " 'rouge_su*_precision_cb': 0.24344,\n", 483 | " 'rouge_su*_precision_ce': 0.25006,\n", 484 | " 'rouge_su*_f_score': 0.19978,\n", 485 | " 'rouge_su*_f_score_cb': 0.19732,\n", 486 | " 'rouge_su*_f_score_ce': 0.20221}" 487 | ] 488 | }, 489 | { 490 | "cell_type": "code", 491 | "execution_count": null, 492 | "metadata": {}, 493 | "outputs": [], 494 | "source": [ 495 | "if __name__ == \"__main__\":\n", 496 | "# init_logger('test_rouge.log')\n", 497 | " parser = argparse.ArgumentParser()\n", 498 | " parser.add_argument('-c', type=str, default=\"candidate.txt\",\n", 499 | " help='candidate file')\n", 500 | " parser.add_argument('-r', type=str, default=\"reference.txt\",\n", 501 | " help='reference file')\n", 502 | " args = parser.parse_args()\n", 503 | " if args.c.upper() == \"STDIN\":\n", 504 | " candidates = sys.stdin\n", 505 | " else:\n", 506 | " candidates = codecs.open(args.c, encoding=\"utf-8\")\n", 507 | " references = codecs.open(args.r, encoding=\"utf-8\")\n", 508 | "\n", 509 | " results_dict = test_rouge(candidates, references)\n", 510 | " logger.info(rouge_results_to_str(results_dict))" 511 | ] 512 | } 513 | ], 514 | "metadata": { 515 | "kernelspec": { 516 | "display_name": "Python 3", 517 | "language": "python", 518 | "name": "python3" 519 | }, 520 | "language_info": { 521 | "codemirror_mode": { 522 | "name": "ipython", 523 | "version": 3 524 | }, 525 | "file_extension": ".py", 526 | "mimetype": "text/x-python", 527 | "name": "python", 528 | "nbconvert_exporter": "python", 529 | "pygments_lexer": "ipython3", 530 | "version": "3.7.3" 531 | } 532 | }, 533 | "nbformat": 4, 534 | "nbformat_minor": 4 535 | } 536 | -------------------------------------------------------------------------------- /research/multidoc_jsonl_dataset_to_parallel_dataset.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# flatten a multidoc summarization dataset in .jsonl format to a parallel dataset that uses\n", 10 | "# the *.sources *.targets format from cnn-dm\n", 11 | "\n", 12 | "# TODO: support shuffling since the cluster 
items will be sequential by default\n", 13 | "\n", 14 | "# TODO: support formatting with special tokens to indicate document structure (i.e. token between Title and Body)\n" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 4, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "from pathlib import Path\n", 24 | "import json\n", 25 | "import tqdm\n", 26 | "\n", 27 | "import numpy as np\n", 28 | "\n", 29 | "from transformer_decoding.evaluate import article_to_text" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 5, 35 | "metadata": {}, 36 | "outputs": [ 37 | { 38 | "name": "stderr", 39 | "output_type": "stream", 40 | "text": [ 41 | "111it [00:00, 1109.16it/s]" 42 | ] 43 | }, 44 | { 45 | "name": "stdout", 46 | "output_type": "stream", 47 | "text": [ 48 | "loading clusters\n" 49 | ] 50 | }, 51 | { 52 | "name": "stderr", 53 | "output_type": "stream", 54 | "text": [ 55 | "8158it [00:08, 928.53it/s] \n", 56 | "100%|██████████| 8158/8158 [00:00<00:00, 143197.27it/s]\n", 57 | "84it [00:00, 838.08it/s]" 58 | ] 59 | }, 60 | { 61 | "name": "stdout", 62 | "output_type": "stream", 63 | "text": [ 64 | "wrote 8158 segments from 8158 clusters to /home/chris/projects/aylien/dynamic-ensembles/data/WCEP/train.source and /home/chris/projects/aylien/dynamic-ensembles/data/WCEP/train.target\n", 65 | "loading clusters\n" 66 | ] 67 | }, 68 | { 69 | "name": "stderr", 70 | "output_type": "stream", 71 | "text": [ 72 | "1020it [00:01, 737.69it/s]\n", 73 | "100%|██████████| 1020/1020 [00:00<00:00, 145893.81it/s]" 74 | ] 75 | }, 76 | { 77 | "name": "stdout", 78 | "output_type": "stream", 79 | "text": [ 80 | "wrote 1020 segments from 1020 clusters to /home/chris/projects/aylien/dynamic-ensembles/data/WCEP/val.source and /home/chris/projects/aylien/dynamic-ensembles/data/WCEP/val.target\n" 81 | ] 82 | }, 83 | { 84 | "name": "stderr", 85 | "output_type": "stream", 86 | "text": [ 87 | "\n" 88 | ] 89 | } 90 | ], 91 | "source": [ 92 | "DATADIR = Path('/home/chris/projects/aylien/dynamic-ensembles/data/WCEP')\n", 93 | "\n", 94 | "prefixes = ['train', 'val']\n", 95 | "# prefixes = ['val']\n", 96 | "\n", 97 | "\n", 98 | "shuffle = True\n", 99 | "actual_source_only = True\n", 100 | "\n", 101 | "\n", 102 | "for dataset_prefix in prefixes:\n", 103 | " sources_and_targets = []\n", 104 | " cluster_cnt = 0\n", 105 | " print('loading clusters')\n", 106 | " for cluster in tqdm.tqdm((json.loads(l) for l in open(DATADIR / (dataset_prefix + '.jsonl')))):\n", 107 | "\n", 108 | "\n", 109 | " for article in cluster['articles']:\n", 110 | " if actual_source_only:\n", 111 | " # only append one actual source per cluster\n", 112 | " if article['origin'] == 'WCEP':\n", 113 | " sources_and_targets.append((article_to_text(article), cluster['summary']))\n", 114 | " break\n", 115 | " else:\n", 116 | " # use all sources per cluster\n", 117 | " sources_and_targets.append((article_to_text(article), cluster['summary']))\n", 118 | " cluster_cnt += 1\n", 119 | " \n", 120 | " output_idxs = np.arange(len(sources_and_targets))\n", 121 | " if shuffle:\n", 122 | " np.random.shuffle(output_idxs)\n", 123 | " \n", 124 | " with open(DATADIR / (dataset_prefix + '.source'), 'w') as srcs, open(DATADIR / (dataset_prefix + '.target'), 'w') as tgts:\n", 125 | " for idx in tqdm.tqdm(output_idxs):\n", 126 | " src = sources_and_targets[idx][0]\n", 127 | " tgt = sources_and_targets[idx][1]\n", 128 | " srcs.write(f'{src}\\n')\n", 129 | " tgts.write(f'{tgt}\\n')\n", 130 | " print(f'wrote {len(sources_and_targets)} 
segments from {cluster_cnt} clusters to {srcs.name} and {tgts.name}')\n", 131 | " \n", 132 | " \n", 133 | " " 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": null, 139 | "metadata": {}, 140 | "outputs": [], 141 | "source": [] 142 | } 143 | ], 144 | "metadata": { 145 | "kernelspec": { 146 | "display_name": "Python 3", 147 | "language": "python", 148 | "name": "python3" 149 | }, 150 | "language_info": { 151 | "codemirror_mode": { 152 | "name": "ipython", 153 | "version": 3 154 | }, 155 | "file_extension": ".py", 156 | "mimetype": "text/x-python", 157 | "name": "python", 158 | "nbconvert_exporter": "python", 159 | "pygments_lexer": "ipython3", 160 | "version": "3.6.10" 161 | } 162 | }, 163 | "nbformat": 4, 164 | "nbformat_minor": 4 165 | } 166 | -------------------------------------------------------------------------------- /research/multinews_to_mds_jsonl.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# convert the multinews dataset to aylien's MDS jsonl format" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 3, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "from pathlib import Path\n", 19 | "import json\n" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 4, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "DATADIR = Path('/home/chris/projects/aylien/dynamic-ensembles/data/multi-news/')" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 6, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "\n", 38 | "\n", 39 | "source_targets = {\n", 40 | " 'test': (DATADIR / 'test.src.cleaned', DATADIR / 'test.tgt')\n", 41 | "}\n", 42 | "\n" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 10, 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "from collections import Counter\n", 52 | "\n", 53 | "import numpy as np" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 19, 59 | "metadata": {}, 60 | "outputs": [ 61 | { 62 | "name": "stdout", 63 | "output_type": "stream", 64 | "text": [ 65 | "[(2, 3022), (3, 1540), (4, 609), (5, 219), (6, 96), (1, 72), (7, 40), (8, 15), (9, 8), (10, 1)]\n", 66 | "(216.9884382781928, 4514.869485679133, 67.19277852328428)\n" 67 | ] 68 | } 69 | ], 70 | "source": [ 71 | "multinews_article_delimiter = ' ||||| '\n", 72 | "cluster_rows = []\n", 73 | "\n", 74 | "article_cnts = Counter()\n", 75 | "summary_lens = []\n", 76 | "\n", 77 | "for prefix, (srcs_f, tgt_f) in source_targets.items():\n", 78 | " with open(srcs_f) as c_srcs, open(tgt_f) as c_tgt, open(DATADIR / (prefix + '.jsonl'), 'w') as out:\n", 79 | " for srcs, tgt in zip(c_srcs, c_tgt):\n", 80 | " articles = [{'title': '', 'text': t} for t in srcs.split(multinews_article_delimiter)]\n", 81 | " out.write(f'{json.dumps({\"articles\": articles, \"summary\": tgt})}\\n')\n", 82 | " article_cnts.update([len(articles)])\n", 83 | " summary_lens.append(len(tgt.split()))\n", 84 | "\n", 85 | "print(article_cnts.most_common())\n", 86 | "print((np.mean(summary_lens), np.var(summary_lens), np.std(summary_lens)))" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": null, 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [] 95 | } 96 | ], 97 | "metadata": { 98 | "kernelspec": { 99 | "display_name": "Python 3", 100 | "language": "python", 101 | 
"name": "python3" 102 | }, 103 | "language_info": { 104 | "codemirror_mode": { 105 | "name": "ipython", 106 | "version": 3 107 | }, 108 | "file_extension": ".py", 109 | "mimetype": "text/x-python", 110 | "name": "python", 111 | "nbconvert_exporter": "python", 112 | "pygments_lexer": "ipython3", 113 | "version": "3.6.10" 114 | } 115 | }, 116 | "nbformat": 4, 117 | "nbformat_minor": 4 118 | } 119 | -------------------------------------------------------------------------------- /research/multinews_to_single_doc_parallel.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 5, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# multinews has documents in each cluster separated by ' ||||| ', we just take the first one \n", 10 | "\n", 11 | "from pathlib import Path\n", 12 | "import json\n", 13 | "from collections import Counter\n", 14 | "\n", 15 | "import numpy as np" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": 6, 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "DATADIR = Path('/home/chris/projects/aylien/dynamic-ensembles/data/multi-news/')" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 7, 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "source_targets = {\n", 34 | " 'train': (DATADIR / 'train.src.cleaned', DATADIR / 'train.tgt'),\n", 35 | " 'val': (DATADIR / 'val.src.cleaned', DATADIR / 'val.tgt')\n", 36 | "}" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 11, 42 | "metadata": {}, 43 | "outputs": [ 44 | { 45 | "name": "stdout", 46 | "output_type": "stream", 47 | "text": [ 48 | "train\n", 49 | "[(2, 23741), (3, 12577), (4, 4921), (5, 1846), (6, 706), (1, 506), (7, 371), (8, 194), (9, 81), (10, 29)]\n", 50 | "Input stats:\n", 51 | "(685.2630908519233, 3145179.089993763, 1773.4652773577957)\n", 52 | "Summary stats:\n", 53 | "(218.25813839722494, 4630.282181609593, 68.04617683315935)\n", 54 | "val\n", 55 | "[(2, 3066), (3, 1555), (4, 610), (5, 195), (6, 79), (1, 59), (7, 38), (8, 13), (9, 7)]\n", 56 | "Input stats:\n", 57 | "(684.0616682039795, 1994918.9491552613, 1412.415997203112)\n", 58 | "Summary stats:\n", 59 | "(216.71380291711134, 4577.909731215516, 67.6602522254796)\n" 60 | ] 61 | } 62 | ], 63 | "source": [ 64 | "# extract first doc from each source cluster and use cnn-dm filename convention\n", 65 | "\n", 66 | "multinews_article_delimiter = ' ||||| '\n", 67 | "\n", 68 | "\n", 69 | "\n", 70 | "for prefix, (srcs_f, tgt_f) in source_targets.items():\n", 71 | " article_cnts = Counter()\n", 72 | " source_lens = []\n", 73 | " summary_lens = []\n", 74 | " with open(srcs_f) as c_srcs, open(tgt_f) as c_tgts, open(DATADIR / (prefix + '.source'), 'w') as out:\n", 75 | " for srcs, tgt in zip(c_srcs, c_tgts):\n", 76 | " articles = srcs.split(multinews_article_delimiter)\n", 77 | " out.write(f'{articles[0].strip()}\\n')\n", 78 | " article_cnts.update([len(articles)])\n", 79 | " summary_lens.append(len(tgt.split()))\n", 80 | " source_lens.extend([len(s.split()) for s in articles])\n", 81 | " print(prefix)\n", 82 | " print(article_cnts.most_common())\n", 83 | " print('Input stats:')\n", 84 | " print((np.mean(source_lens), np.var(source_lens), np.std(source_lens)))\n", 85 | " print('Summary stats:')\n", 86 | " print((np.mean(summary_lens), np.var(summary_lens), np.std(summary_lens)))\n" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": 12, 92 | "metadata": 
{}, 93 | "outputs": [ 94 | { 95 | "data": { 96 | "text/plain": [ 97 | "768" 98 | ] 99 | }, 100 | "execution_count": 12, 101 | "metadata": {}, 102 | "output_type": "execute_result" 103 | } 104 | ], 105 | "source": [ 106 | "256 *3" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": null, 112 | "metadata": {}, 113 | "outputs": [], 114 | "source": [] 115 | } 116 | ], 117 | "metadata": { 118 | "kernelspec": { 119 | "display_name": "Python 3", 120 | "language": "python", 121 | "name": "python3" 122 | }, 123 | "language_info": { 124 | "codemirror_mode": { 125 | "name": "ipython", 126 | "version": 3 127 | }, 128 | "file_extension": ".py", 129 | "mimetype": "text/x-python", 130 | "name": "python", 131 | "nbconvert_exporter": "python", 132 | "pygments_lexer": "ipython3", 133 | "version": "3.6.10" 134 | } 135 | }, 136 | "nbformat": 4, 137 | "nbformat_minor": 4 138 | } 139 | -------------------------------------------------------------------------------- /research/prototype_with_presumm.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# using a pre-trained summarization model, create one instance for every input, then decode from the ensemble" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 1, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "# https://github.com/nlpyang/PreSumm" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": {}, 24 | "source": [ 25 | "#### Updates: For encoding a text longer than 512 tokens, for example 800. Set max_pos to 800 during both preprocessing and training.\n", 26 | "\n", 27 | "-mode can be {validate, test}, where validate will inspect the model directory and evaluate the model for each newly saved checkpoint, test need to be used with -test_from, indicating the checkpoint you want to use\n", 28 | "MODEL_PATH is the directory of saved checkpoints\n", 29 | "use -mode valiadte with -test_all, the system will load all saved checkpoints and select the top ones to generate summaries (this will take a while)" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "# the baseline setup results in memory error, try building on MT-GPU, or containerize for ease of use " 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "%bash\n", 48 | "\n", 49 | "# probably pytorch version in their requirements.txt\n", 50 | "# RuntimeError: cuda runtime error (38) : no CUDA-capable device is detected at /pytorch/aten/src/THC/THCGeneral.cpp:51\n", 51 | "\n", 52 | "\n", 53 | "cd ~/projects/PreSumm/src\n", 54 | "\n", 55 | "source activate presumm\n", 56 | "\n", 57 | "BATCH_SIZE=1\n", 58 | "# note last part of BERT_DATA_PATH is file prefix\n", 59 | "BERT_DATA_PATH=/data/PreSumm_data/bert_data/bert_data_cnndm_final/cnndm\n", 60 | "MODEL_PATH=/data/PreSumm_data/models\n", 61 | "\n", 62 | "python train.py \\\n", 63 | " -task abs \\\n", 64 | " -mode validate \\\n", 65 | " -batch_size ${BATCH_SIZE} \\\n", 66 | " -test_batch_size ${BATCH_SIZE} \\\n", 67 | " -bert_data_path ${BERT_DATA_PATH} \\\n", 68 | " -log_file ../logs/val_abs_bert_cnndm \\\n", 69 | " -model_path ${MODEL_PATH} \\\n", 70 | " -sep_optim true \\\n", 71 | " -use_interval true \\\n", 72 | " -visible_gpus 0 
\\\n", 73 | " -max_pos 512 \\\n", 74 | " -max_length 200 \\\n", 75 | " -alpha 0.95 \\\n", 76 | " -min_length 50 \\\n", 77 | " -result_path ../logs/abs_bert_cnndm \n" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": null, 83 | "metadata": {}, 84 | "outputs": [], 85 | "source": [ 86 | "# export CORENLP_HOME=/data/stanford_core_nlp/stanford-corenlp-full-2018-10-05" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": null, 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "%bash\n", 96 | "\n", 97 | "export CLASSPATH=/data/stanford_core\n", 98 | "\n", 99 | "\n", 100 | "java edu.stanford.nlp.pipeline.StanfordCoreNLP \\\n", 101 | " -annotators tokenize,ssplit \\\n", 102 | " -ssplit.newlineIsSentenceBreak always \\ \n", 103 | " -filelist mapping_for_corenlp.txt \\\n", 104 | " -outputFormat json \\\n", 105 | " -outputDirectory tokenized_stories_dir\n", 106 | "\n", 107 | "\n", 108 | "command = ['java', 'edu.stanford.nlp.pipeline.StanfordCoreNLP', '-annotators', 'tokenize,ssplit',\n", 109 | " '-ssplit.newlineIsSentenceBreak', 'always', '-filelist', 'mapping_for_corenlp.txt', '-outputFormat',\n", 110 | " 'json', '-outputDirectory', tokenized_stories_dir]\n", 111 | " print(\"Tokenizing %i files in %s and saving in %s...\" % (len(stories), stories_dir, tokenized_stories_dir)" 112 | ] 113 | }, 114 | { 115 | "cell_type": "markdown", 116 | "metadata": {}, 117 | "source": [ 118 | "```\n", 119 | "\n", 120 | "# NOTE: we still need to clean the multinews format (removing NEWLINE tokens and document separators, etc...)\n", 121 | "\n", 122 | "export CORENLP_HOME=/data/stanford_core_nlp/stanford-corenlp-full-2018-10-05\n", 123 | "\n", 124 | "# annotate -i val.src.100 -f json --annotators tokenize ssplit | jq '{src: [.[][] | [.tokens[].word]]}' > val.src.100.corenlp.json \n", 125 | "\n", 126 | "\n", 127 | "# WORKING one-liner\n", 128 | "jq -n \\\n", 129 | " --slurpfile o1 <(annotate -i val.src.50 -f json --annotators tokenize ssplit | jq '{src: [.[][] | [.tokens[].word]]}') \\\n", 130 | " --slurpfile o2 <(annotate -i val.tgt.50 -f json --annotators tokenize ssplit | jq '{tgt: [.[][] | [.tokens[].word]]}') \\\n", 131 | " 'reduce range(0; $o1|length) as $i ([]; . + [{ \"src\": $o1[$i].src, \"tgt\": $o2[$i].tgt}])' | less\n", 132 | "\n", 133 | "\n", 134 | "export CORENLP_HOME=/data/stanford_core_nlp/stanford-corenlp-full-2018-10-05\n", 135 | "DATADIR=/data/PreSumm_data/multi-news/preprocessed_truncated\n", 136 | "VALID_SRC=${DATADIR}/test.txt.src.tokenized.fixed.cleaned.final.truncated.txt\n", 137 | "VALID_TGT=${DATADIR}/test.txt.tgt.tokenized.fixed.cleaned.final.truncated.txt\n", 138 | "VALID_OUT=${DATADIR}/test.corenlp.json\n", 139 | "jq -n \\\n", 140 | " --slurpfile o1 <(annotate -i ${VALID_SRC} -f json --annotators tokenize ssplit | jq '{src: [.[][] | [.tokens[].word]]}') \\\n", 141 | " --slurpfile o2 <(annotate -i ${VALID_TGT} -f json --annotators tokenize ssplit | jq '{tgt: [.[][] | [.tokens[].word]]}') \\\n", 142 | " 'reduce range(0; $o1|length) as $i ([]; . 
+ [{ \"src\": $o1[$i].src, \"tgt\": $o2[$i].tgt}])' > ${VALID_OUT}\n", 143 | "\n", 144 | "\n", 145 | "\n", 146 | "\n", 147 | "# After the one-liner above we need to map into .pt files\n", 148 | "# Note file must have prefix in ['train', 'valid', 'test']\n", 149 | "\n", 150 | "source activate presumm\n", 151 | "PRESUM=/home/chrishokamp/projects/PreSumm\n", 152 | "JSON_DIR=/data/PreSumm_data/multi-news/preprocessed_truncated/presumm_json_input\n", 153 | "OUTPUT_DIR=${JSON_DIR}/bert_files_for_presumm\n", 154 | "mkdir -p ${OUTPUT_DIR}\n", 155 | "cd ${JSON_DIR}\n", 156 | "\n", 157 | "python $PRESUM/src/preprocess.py \\\n", 158 | " -mode format_to_bert \\\n", 159 | " -raw_path ${JSON_DIR} \\\n", 160 | " -save_path ${OUTPUT_DIR} \\\n", 161 | " -lower \\\n", 162 | " -n_cpus 1 \\\n", 163 | " -log_file preprocess.log\n", 164 | "\n", 165 | "\n", 166 | "# now rename files so that the prefixes work\n", 167 | "cp test.multinews.corenlp.bert.pt multinews.test.corenlp.bert.pt\n", 168 | "\n", 169 | "\n", 170 | "# Try summarizing the (flattened) multinews file\n", 171 | "# TODO: increase max length of summaries to fit with MultiNews dataset \n", 172 | "cd ~/projects/PreSumm/src\n", 173 | "\n", 174 | "source activate presumm\n", 175 | "\n", 176 | "BATCH_SIZE=32\n", 177 | "MAX_SUMMARY_LENGTH=128\n", 178 | "# note last part of BERT_DATA_PATH is file prefix\n", 179 | "BERT_DATA_PATH=/data/PreSumm_data/multi-news/preprocessed_truncated/presumm_json_input/bert_files_for_presumm/multinews\n", 180 | "MODEL_PATH=/data/PreSumm_data/models\n", 181 | "\n", 182 | "python train.py \\\n", 183 | " -task abs \\\n", 184 | " -mode validate \\\n", 185 | " -batch_size ${BATCH_SIZE} \\\n", 186 | " -test_batch_size ${BATCH_SIZE} \\\n", 187 | " -bert_data_path ${BERT_DATA_PATH} \\\n", 188 | " -log_file ../logs/val_abs_bert_cnndm \\\n", 189 | " -model_path ${MODEL_PATH} \\\n", 190 | " -sep_optim true \\\n", 191 | " -use_interval true \\\n", 192 | " -visible_gpus 0 \\\n", 193 | " -max_pos 512 \\\n", 194 | " -max_length ${MAX_SUMMARY_LENGTH} \\\n", 195 | " -alpha 0.95 \\\n", 196 | " -min_length 50 \\\n", 197 | " -result_path ../logs/abs_bert_cnndm \n", 198 | "\n", 199 | "```\n", 200 | "\n", 201 | "\n" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": null, 207 | "metadata": {}, 208 | "outputs": [], 209 | "source": [ 210 | "# multinews has rouge from opennmt, presumably this is what they used \n", 211 | "# https://github.com/Alex-Fabbri/Multi-News/blob/3675e7c422ae3b4020617a324ac264f50333357d/code/OpenNMT-py-baselines/tools/test_rouge.py" 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": null, 217 | "metadata": {}, 218 | "outputs": [], 219 | "source": [ 220 | "# split every multinews line into constituent story files\n", 221 | "\n", 222 | "# download Stanford NLP and set classpath accordingly\n", 223 | "\n", 224 | "# remember presumm does a lot of idiosyncratic things with the BERT special tokenss\n", 225 | "\n", 226 | "def multinews_to_presumm_json_format(multinews_file):\n", 227 | " \"\"\"Simplest possible thing: just flatten a multinews row into a single document\"\"\"\n", 228 | " pass" 229 | ] 230 | }, 231 | { 232 | "cell_type": "markdown", 233 | "metadata": {}, 234 | "source": [ 235 | "https://github.com/Alex-Fabbri/Multi-News\n", 236 | "\n" 237 | ] 238 | }, 239 | { 240 | "cell_type": "code", 241 | "execution_count": null, 242 | "metadata": {}, 243 | "outputs": [], 244 | "source": [ 245 | "# Preprocessing to prepare a new test dataset\n", 246 | "\n", 247 | "# Note we try to go 
around having to use their clunky preprocessing\n", 248 | "\n" 249 | ] 250 | }, 251 | { 252 | "cell_type": "code", 253 | "execution_count": null, 254 | "metadata": {}, 255 | "outputs": [], 256 | "source": [ 257 | "# (1) Format MultiNews to .json format of " 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": null, 263 | "metadata": {}, 264 | "outputs": [], 265 | "source": [ 266 | "# (2) Map json-formatted data to pytorch tensors for BERT, store them in a file that we can use \n", 267 | "# to get the summaries for the MultiNews dev+test sets" 268 | ] 269 | } 270 | ], 271 | "metadata": { 272 | "kernelspec": { 273 | "display_name": "Python 3", 274 | "language": "python", 275 | "name": "python3" 276 | }, 277 | "language_info": { 278 | "codemirror_mode": { 279 | "name": "ipython", 280 | "version": 3 281 | }, 282 | "file_extension": ".py", 283 | "mimetype": "text/x-python", 284 | "name": "python", 285 | "nbconvert_exporter": "python", 286 | "pygments_lexer": "ipython3", 287 | "version": "3.6.8" 288 | } 289 | }, 290 | "nbformat": 4, 291 | "nbformat_minor": 2 292 | } 293 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | with open('requirements.txt') as f: 4 | requirements = f.read().splitlines() 5 | 6 | with open('VERSION') as f: 7 | version = f.read().strip() 8 | 9 | setup( 10 | name="transformer_decoding", 11 | version=version, 12 | packages=['transformer_decoding'], 13 | install_requires=requirements 14 | ) 15 | -------------------------------------------------------------------------------- /transformer_decoding/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrishokamp/dynamic-transformer-ensembles/bcd91da4b70086a8f7a3a45bbfed03d4bbf497e7/transformer_decoding/__init__.py -------------------------------------------------------------------------------- /transformer_decoding/bart_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch.utils.data import Dataset 5 | 6 | from transformers.tokenization_utils import trim_batch 7 | 8 | 9 | def encode_file(tokenizer, data_path, max_length, pad_to_max_length=True, return_tensors="pt"): 10 | examples = [] 11 | with open(data_path, "r") as f: 12 | for text in f.readlines(): 13 | tokenized = tokenizer.batch_encode_plus( 14 | [text], max_length=max_length, pad_to_max_length=pad_to_max_length, return_tensors=return_tensors, 15 | ) 16 | examples.append(tokenized) 17 | return examples 18 | 19 | 20 | class SummarizationDataset(Dataset): 21 | def __init__( 22 | self, 23 | tokenizer, 24 | data_dir="./cnn-dailymail/cnn_dm/", 25 | type_path="train", 26 | max_source_length=1024, 27 | max_target_length=56, 28 | ): 29 | super().__init__() 30 | self.tokenizer = tokenizer 31 | self.source = encode_file(tokenizer, os.path.join(data_dir, type_path + ".source"), max_source_length) 32 | self.target = encode_file(tokenizer, os.path.join(data_dir, type_path + ".target"), max_target_length) 33 | 34 | def __len__(self): 35 | return len(self.source) 36 | 37 | def __getitem__(self, index): 38 | source_ids = self.source[index]["input_ids"].squeeze() 39 | target_ids = self.target[index]["input_ids"].squeeze() 40 | src_mask = self.source[index]["attention_mask"].squeeze() 41 | return {"source_ids": source_ids, "source_mask": 
src_mask, "target_ids": target_ids} 42 | 43 | @staticmethod 44 | def trim_seq2seq_batch(batch, pad_token_id): 45 | y = trim_batch(batch["target_ids"], pad_token_id) 46 | source_ids, source_mask = trim_batch(batch["source_ids"], pad_token_id, attention_mask=batch["source_mask"]) 47 | return source_ids, source_mask, y 48 | 49 | def collate_fn(self, batch): 50 | input_ids = torch.stack([x["source_ids"] for x in batch]) 51 | masks = torch.stack([x["source_mask"] for x in batch]) 52 | target_ids = torch.stack([x["target_ids"] for x in batch]) 53 | pad_token_id = self.tokenizer.pad_token_id 54 | y = trim_batch(target_ids, pad_token_id) 55 | source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks) 56 | return {"source_ids": source_ids, "source_mask": source_mask, "target_ids": y} 57 | -------------------------------------------------------------------------------- /transformer_decoding/decoding_utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import copy 3 | 4 | import torch 5 | from torch.nn import functional as F 6 | 7 | from transformers import modeling_utils 8 | 9 | from transformer_decoding import decoding_utils, log 10 | 11 | 12 | logger = log.create_logger(__name__) 13 | 14 | 15 | def generate(component_states, timesteps, ensemble_state=None, timestep_mask=None): 16 | """ 17 | Run generation for a number of timesteps 18 | """ 19 | if type(component_states) is not list: 20 | component_states = [component_states] 21 | 22 | if ensemble_state is None: 23 | assert len(component_states) == 1 24 | for step_idx in range(timesteps): 25 | component_states[0] = decoding_utils.beam_search_step(component_states[0]) 26 | else: 27 | step_mask = None 28 | for step_idx in range(timesteps): 29 | if timestep_mask is not None: 30 | if step_idx == timestep_mask.shape[0] - 1: 31 | break 32 | step_mask = timestep_mask[step_idx] 33 | 34 | component_states, ensemble_state = \ 35 | decoding_utils.ensembled_beam_search_step(component_states, ensemble_state, step_mask=step_mask) 36 | 37 | return component_states, ensemble_state 38 | 39 | 40 | class BeamHypotheses(object): 41 | def __init__(self, num_beams, max_length, length_penalty, early_stopping): 42 | """ 43 | Initialize n-best list of hypotheses. 44 | """ 45 | self.max_length = max_length - 1 # ignoring bos_token 46 | self.length_penalty = length_penalty 47 | self.early_stopping = early_stopping 48 | self.num_beams = num_beams 49 | self.beams = [] 50 | self.worst_score = 1e9 51 | 52 | def __len__(self): 53 | """ 54 | Number of hypotheses in the list. 55 | """ 56 | return len(self.beams) 57 | 58 | def add(self, hyp, sum_logprobs, metadata=None): 59 | """ 60 | Add a new hypothesis to the list. 61 | """ 62 | score = sum_logprobs / len(hyp) ** self.length_penalty 63 | if len(self) < self.num_beams or score > self.worst_score: 64 | self.beams.append((score, hyp, metadata)) 65 | if len(self) > self.num_beams: 66 | sorted_scores = sorted([(s, idx) for idx, (s, _, _) in enumerate(self.beams)]) 67 | del self.beams[sorted_scores[0][1]] 68 | self.worst_score = sorted_scores[1][0] 69 | else: 70 | self.worst_score = min(score, self.worst_score) 71 | 72 | def is_done(self, best_sum_logprobs, cur_len=None): 73 | """ 74 | If there are enough hypotheses and that none of the hypotheses being generated 75 | can become better than the worst one in the heap, then we are done with this sentence. 
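        `best_sum_logprobs` is the highest accumulated log-probability still on the beam;
        it is length-normalized by `cur_len ** length_penalty` (with `cur_len` falling back
        to `max_length`) before being compared against `self.worst_score`.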
76 | """ 77 | 78 | if len(self) < self.num_beams: 79 | return False 80 | elif self.early_stopping: 81 | return True 82 | else: 83 | if cur_len is None: 84 | cur_len = self.max_length 85 | cur_score = best_sum_logprobs / cur_len ** self.length_penalty 86 | ret = self.worst_score >= cur_score 87 | return ret 88 | 89 | 90 | def get_start_state(text, model, tokenizer, decoding_hyperparams): 91 | 92 | # set up state 93 | decoder_state = decoding_utils.get_initial_decoding_state( 94 | text=text, 95 | model=model, 96 | tokenizer=tokenizer, 97 | decoding_hyperparams=decoding_hyperparams 98 | ) 99 | # TODO: move to `decoding_utils.get_initial_decoding_state(?) 100 | if torch.cuda.is_available(): 101 | decoder_state['input_ids'] = decoder_state['input_ids'].to('cuda') 102 | 103 | # TODO: this logic may move to 104 | # `get_initial_decoding_state` 105 | decoder_state['generated_hyps'] = [ 106 | BeamHypotheses( 107 | decoder_state['num_beams'], 108 | decoder_state['max_length'], 109 | decoder_state['length_penalty'], 110 | early_stopping=decoder_state['early_stopping']) 111 | for _ in range(decoder_state['batch_size']) 112 | ] 113 | 114 | # scores for each sentence in the beam 115 | decoder_state['beam_scores'] = \ 116 | torch.zeros((decoder_state['batch_size'], decoder_state['num_beams']), 117 | dtype=torch.float, 118 | device=decoder_state['input_ids'].device) 119 | 120 | # for greedy decoding it is made sure that only tokens of the first beam are considered 121 | # to avoid sampling the exact same tokens three times 122 | if decoder_state['do_sample'] is False: 123 | decoder_state['beam_scores'][:, 1:] = -1e9 124 | decoder_state['beam_scores'] = decoder_state['beam_scores'].view(-1) # shape (batch_size * num_beams,) 125 | 126 | # cache compute states 127 | decoder_state['past'] = decoder_state[ 128 | 'encoder_outputs'] # defined for encoder-decoder models, None for decoder-only models 129 | 130 | # done sentences 131 | decoder_state['done'] = [False for _ in range(decoder_state['batch_size'])] 132 | 133 | return decoder_state 134 | 135 | 136 | def initialize_generation( 137 | model, 138 | input_ids=None, 139 | max_length=None, 140 | min_length=None, 141 | do_sample=None, 142 | early_stopping=None, 143 | num_beams=None, 144 | temperature=None, 145 | top_k=None, 146 | top_p=None, 147 | repetition_penalty=None, 148 | bad_words_ids=None, 149 | bos_token_id=None, 150 | pad_token_id=None, 151 | eos_token_id=None, 152 | length_penalty=None, 153 | no_repeat_ngram_size=None, 154 | num_return_sequences=None, 155 | attention_mask=None, 156 | decoder_start_token_id=None, 157 | **kwargs 158 | ): 159 | # We cannot generate if the model does not have a LM head 160 | if model.get_output_embeddings() is None: 161 | raise AttributeError( 162 | "You tried to generate sequences with a model that does not have a LM Head." 163 | "Please use another model class (e.g. 
`OpenAIGPTLMHeadModel`, `XLNetLMHeadModel`, `GPT2LMHeadModel`, `CTRLLMHeadModel`, `T5WithLMHeadModel`, `TransfoXLLMHeadModel`, `XLMWithLMHeadModel`, `BartForConditionalGeneration` )" 164 | ) 165 | 166 | max_length = max_length if max_length is not None else model.config.max_length 167 | min_length = min_length if min_length is not None else model.config.min_length 168 | do_sample = do_sample if do_sample is not None else model.config.do_sample 169 | early_stopping = early_stopping if early_stopping is not None else model.config.early_stopping 170 | num_beams = num_beams if num_beams is not None else model.config.num_beams 171 | temperature = temperature if temperature is not None else model.config.temperature 172 | top_k = top_k if top_k is not None else model.config.top_k 173 | top_p = top_p if top_p is not None else model.config.top_p 174 | repetition_penalty = repetition_penalty if repetition_penalty is not None else model.config.repetition_penalty 175 | bos_token_id = bos_token_id if bos_token_id is not None else model.config.bos_token_id 176 | pad_token_id = pad_token_id if pad_token_id is not None else model.config.pad_token_id 177 | eos_token_id = eos_token_id if eos_token_id is not None else model.config.eos_token_id 178 | length_penalty = length_penalty if length_penalty is not None else model.config.length_penalty 179 | no_repeat_ngram_size = ( 180 | no_repeat_ngram_size if no_repeat_ngram_size is not None else model.config.no_repeat_ngram_size 181 | ) 182 | bad_words_ids = bad_words_ids if bad_words_ids is not None else model.config.bad_words_ids 183 | num_return_sequences = ( 184 | num_return_sequences if num_return_sequences is not None else model.config.num_return_sequences 185 | ) 186 | decoder_start_token_id = ( 187 | decoder_start_token_id if decoder_start_token_id is not None else model.config.decoder_start_token_id 188 | ) 189 | 190 | if input_ids is not None: 191 | batch_size = input_ids.shape[0] # overriden by the input batch_size 192 | else: 193 | batch_size = 1 194 | 195 | assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer." 196 | assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer." 197 | assert isinstance(do_sample, bool), "`do_sample` should be a boolean." 198 | assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean." 199 | assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer." 200 | assert temperature > 0, "`temperature` should be strictly positive." 201 | assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer." 202 | assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1." 203 | assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1." 204 | assert input_ids is not None or ( 205 | isinstance(bos_token_id, int) and bos_token_id >= 0 206 | ), "If input_ids is not defined, `bos_token_id` should be a positive integer." 207 | assert pad_token_id is None or ( 208 | isinstance(pad_token_id, int) and (pad_token_id >= 0) 209 | ), "`pad_token_id` should be a positive integer." 210 | assert (eos_token_id is None) or ( 211 | isinstance(eos_token_id, int) and (eos_token_id >= 0) 212 | ), "`eos_token_id` should be a positive integer." 213 | assert length_penalty > 0, "`length_penalty` should be strictly positive." 
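    # (These assertions validate the decoding hyperparameters resolved above from kwargs
    # or model.config defaults, so that bad settings fail fast here rather than surfacing
    # as shape errors deep inside beam search.)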
214 | assert ( 215 | isinstance(no_repeat_ngram_size, int) and no_repeat_ngram_size >= 0 216 | ), "`no_repeat_ngram_size` should be a positive integer." 217 | assert ( 218 | isinstance(num_return_sequences, int) and num_return_sequences > 0 219 | ), "`num_return_sequences` should be a strictly positive integer." 220 | assert ( 221 | bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list) 222 | ), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated" 223 | 224 | if input_ids is None: 225 | assert isinstance(bos_token_id, int) and bos_token_id >= 0, ( 226 | "you should either supply a context to complete as `input_ids` input " 227 | "or a `bos_token_id` (integer >= 0) as a first token to start the generation." 228 | ) 229 | input_ids = torch.full( 230 | (batch_size, 1), bos_token_id, dtype=torch.long, device=next(model.parameters()).device, 231 | ) 232 | else: 233 | assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)." 234 | # Chris: added this line to fix error when running on GPU 235 | input_ids = input_ids.to(next(model.parameters()).device) 236 | 237 | # not allow to duplicate outputs when greedy decoding 238 | if do_sample is False: 239 | if num_beams == 1: 240 | # no_beam_search greedy generation conditions 241 | assert ( 242 | num_return_sequences == 1 243 | ), "Greedy decoding will always produce the same output for num_beams == 1 and num_return_sequences > 1. Please set num_return_sequences = 1" 244 | 245 | else: 246 | # beam_search greedy generation conditions 247 | assert ( 248 | num_beams >= num_return_sequences 249 | ), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences" 250 | 251 | # create attention mask if necessary 252 | # TODO (PVP): this should later be handled by the forward fn() in each model in the future see PR 3140 253 | if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids): 254 | attention_mask = input_ids.ne(pad_token_id).long() 255 | elif attention_mask is None: 256 | attention_mask = input_ids.new_ones(input_ids.shape) 257 | 258 | # set pad_token_id to eos_token_id if not set. 
Important that this is done after 259 | # attention_mask is created 260 | if pad_token_id is None and eos_token_id is not None: 261 | logger.warning( 262 | "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_id) 263 | ) 264 | pad_token_id = eos_token_id 265 | 266 | # current position and vocab size 267 | vocab_size = model.config.vocab_size 268 | 269 | # set effective batch size and effective batch multiplier according to do_sample 270 | if do_sample: 271 | effective_batch_size = batch_size * num_return_sequences 272 | effective_batch_mult = num_return_sequences 273 | else: 274 | effective_batch_size = batch_size 275 | effective_batch_mult = 1 276 | 277 | if model.config.is_encoder_decoder: 278 | if decoder_start_token_id is None: 279 | decoder_start_token_id = bos_token_id 280 | 281 | assert ( 282 | decoder_start_token_id is not None 283 | ), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation" 284 | assert hasattr(model, "get_encoder"), "{} should have a 'get_encoder' function defined".format(model) 285 | assert callable(model.get_encoder), "{} should be a method".format(model.get_encoder) 286 | 287 | # get encoder and store encoder outputs 288 | encoder = model.get_encoder() 289 | 290 | encoder_outputs = encoder(input_ids, attention_mask=attention_mask) 291 | 292 | # Expand input ids if num_beams > 1 or num_return_sequences > 1 293 | if num_return_sequences > 1 or num_beams > 1: 294 | input_ids_len = input_ids.shape[-1] 295 | input_ids = input_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len) 296 | attention_mask = attention_mask.unsqueeze(1).expand( 297 | batch_size, effective_batch_mult * num_beams, input_ids_len 298 | ) 299 | 300 | input_ids = input_ids.contiguous().view( 301 | effective_batch_size * num_beams, input_ids_len 302 | ) # shape: (batch_size * num_return_sequences * num_beams, cur_len) 303 | attention_mask = attention_mask.contiguous().view( 304 | effective_batch_size * num_beams, input_ids_len 305 | ) # shape: (batch_size * num_return_sequences * num_beams, cur_len) 306 | 307 | # Chris: note this is important distinction between decoder-only and 308 | # encoder-decoder architectures, for encoder-decoder models, decoder state 309 | # is initialized with decoder BOS token, for decoder-only models, it's the whole 310 | # input_ids as passed to this function 311 | # Note: in the current formulation this precludes prefix-completion usecases by definition 312 | if model.config.is_encoder_decoder: 313 | # create empty decoder_input_ids 314 | input_ids = torch.full( 315 | (effective_batch_size * num_beams, 1), 316 | decoder_start_token_id, 317 | dtype=torch.long, 318 | device=next(model.parameters()).device, 319 | ) 320 | cur_len = 1 321 | 322 | assert ( 323 | batch_size == encoder_outputs[0].shape[0] 324 | ), f"expected encoder_outputs[0] to have 1st dimension bs={batch_size}, got {encoder_outputs[0].shape[0]} " 325 | 326 | # expand batch_idx to assign correct encoder output for expanded input_ids (due to num_beams > 1 and num_return_sequences > 1) 327 | expanded_batch_idxs = ( 328 | torch.arange(batch_size) 329 | .view(-1, 1) 330 | .repeat(1, num_beams * effective_batch_mult) 331 | .view(-1) 332 | .to(input_ids.device) 333 | ) 334 | # expand encoder_outputs 335 | encoder_outputs = (encoder_outputs[0].index_select(0, expanded_batch_idxs), *encoder_outputs[1:]) 336 | 337 | else: 338 | encoder_outputs = None 339 | cur_len = input_ids.shape[-1] 340 | 341 | # Chris: 
return the outputs needed for model._generate_beam_search 342 | return OrderedDict([ 343 | ('model', model), 344 | ('input_ids', input_ids), 345 | ('cur_len', cur_len), 346 | ('max_length', max_length), 347 | ('min_length', min_length), 348 | ('do_sample', do_sample), 349 | ('early_stopping', early_stopping), 350 | ('temperature', temperature), 351 | ('top_k', top_k), 352 | ('top_p', top_p), 353 | ('repetition_penalty', repetition_penalty), 354 | ('no_repeat_ngram_size', no_repeat_ngram_size), 355 | ('bad_words_ids', bad_words_ids), 356 | ('bos_token_id', bos_token_id), 357 | ('pad_token_id', pad_token_id), 358 | ('decoder_start_token_id', decoder_start_token_id), 359 | ('eos_token_id', eos_token_id), 360 | ('batch_size', effective_batch_size), 361 | ('num_return_sequences', num_return_sequences), 362 | ('length_penalty', length_penalty), 363 | ('num_beams', num_beams), 364 | ('vocab_size', vocab_size), 365 | ('encoder_outputs', encoder_outputs), 366 | ('attention_mask', attention_mask) 367 | ]) 368 | 369 | 370 | def get_initial_decoding_state(text, model, tokenizer, decoding_hyperparams): 371 | """ 372 | Get the state needed to start decoding from an instance 373 | """ 374 | # convert text to tensor 375 | inputs = tokenizer.batch_encode_plus( 376 | [text], 377 | max_length=decoding_hyperparams['max_length'], 378 | pad_to_max_length=True, 379 | return_tensors='pt' 380 | ) 381 | input_ids = inputs['input_ids'] 382 | 383 | return initialize_generation( 384 | model, input_ids, 385 | **decoding_hyperparams 386 | ) 387 | 388 | 389 | def outputs_from_state(state): 390 | """ 391 | Run forward pass using a state, note this only works for states with a 'model' attribute 392 | """ 393 | model_inputs = state['model'].prepare_inputs_for_generation( 394 | state['input_ids'], 395 | past=state['past'], 396 | attention_mask=state['attention_mask'], 397 | use_cache=True 398 | ) 399 | outputs = state['model'](**model_inputs) # (batch_size * num_beams, cur_len, vocab_size) 400 | return outputs 401 | 402 | 403 | def logits_from_output(state): 404 | """ 405 | In the context of ensemble decoding, decoding parameters may be applied twice 406 | - once on individual states 407 | - once on the entire ensemble 408 | As currently implemented, some decoding heuristics are applied to the logits, 409 | some are applied to the scores (logits after softmax). 
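    Currently a stub: the logits-side heuristics live in `apply_heuristics_to_logits`
    below, and the score-side ones (eos masking, ngram and bad-word bans) are applied
    inline in `ensembled_beam_search_step`.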
410 | """ 411 | pass 412 | 413 | 414 | def apply_heuristics_to_logits(state): 415 | # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858) 416 | if state['repetition_penalty'] != 1.0: 417 | state['model'].enforce_repetition_penalty_( 418 | state['next_token_logits'], 419 | state['batch_size'], 420 | state['num_beams'], 421 | state['input_ids'], 422 | state['repetition_penalty'] 423 | ) 424 | 425 | if state['temperature'] != 1.0: 426 | state['next_token_logits'] = state['next_token_logits'] / state['temperature'] 427 | 428 | return state 429 | 430 | 431 | @torch.no_grad() 432 | def ensembled_beam_search_step(component_states, ensemble_state, step_mask=None): 433 | """ 434 | Decoding hyperparams live in ensemble_state 435 | """ 436 | 437 | if 'decoding_stats' not in ensemble_state: 438 | # fires on first decoding step 439 | ensemble_state['decoding_stats'] = [] 440 | for _ in range(len(component_states)): 441 | ensemble_state['decoding_stats'].append([[] for _ in range(ensemble_state['num_beams'])]) 442 | 443 | for state in component_states: 444 | 445 | state['outputs'] = outputs_from_state(state) 446 | state['next_token_logits'] = state['outputs'][0][:, -1, :] # (batch_size * num_beams, vocab_size) 447 | 448 | state = apply_heuristics_to_logits(state) 449 | # apply softmax to logits 450 | state['scores'] = F.log_softmax(state['next_token_logits'], dim=-1) # (batch_size * num_beams, vocab_size) 451 | 452 | if state['model'].config.is_encoder_decoder and ensemble_state['do_sample'] is False: 453 | # TODO (PVP) still a bit hacky here - there might be a better solution 454 | state['scores'] = state['model'].prepare_scores_for_generation( 455 | state['scores'], 456 | cur_len=state['cur_len'], 457 | max_length=state['max_length']) 458 | 459 | # set state's eos token prob to zero if min_length is not reached 460 | if ensemble_state['eos_token_id'] is not None and ensemble_state['cur_len'] < ensemble_state['min_length']: 461 | state['scores'][:, state['eos_token_id']] = -float("inf") 462 | 463 | if ensemble_state['no_repeat_ngram_size'] > 0: 464 | # calculate a list of banned tokens to prevent repetitively generating the same ngrams 465 | num_batch_hypotheses = ensemble_state['batch_size'] * ensemble_state['num_beams'] 466 | # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345 467 | banned_batch_tokens = modeling_utils.calc_banned_ngram_tokens( 468 | ensemble_state['input_ids'], 469 | num_batch_hypotheses, 470 | ensemble_state['no_repeat_ngram_size'], 471 | ensemble_state['cur_len'] 472 | ) 473 | for i, banned_tokens in enumerate(banned_batch_tokens): 474 | state['scores'][i, banned_tokens] = -float("inf") 475 | 476 | if ensemble_state['bad_words_ids'] is not None: 477 | # calculate a list of banned tokens according to bad words 478 | banned_tokens = modeling_utils.calc_banned_bad_words_ids( 479 | ensemble_state['input_ids'], 480 | ensemble_state['bad_words_ids'] 481 | ) 482 | 483 | for i, banned_tokens in enumerate(banned_tokens): 484 | state['scores'][i, banned_tokens] = -float("inf") 485 | 486 | # TODO: WORKING: after all that, we're just going use the user provided-mask if it's there 487 | # TODO: need to use numpy-style indexing or elementwise multiply for this 488 | if step_mask is not None: 489 | state['scores'] = state['scores'] * step_mask 490 | 491 | assert state['scores'].shape == ( 492 | ensemble_state['batch_size'] * ensemble_state['num_beams'], ensemble_state['vocab_size']), "Shapes of 
scores: {} != {}".format( 493 | state['scores'].shape, (ensemble_state['batch_size'] * ensemble_state['num_beams'], ensemble_state['vocab_size']) 494 | ) 495 | 496 | # if model has past, then set the past variable to speed up decoding 497 | if state['model']._use_cache(state['outputs'], use_cache=True): 498 | state['past'] = state['outputs'][1] 499 | 500 | # WORKING: get the shape of the scores 501 | # TODO WORKING: if there's a mask, use it (set everything else to `-float("inf")` 502 | # TODO: this is effectively the reverse of the "bad_words_ids" logic below, in the mask case, 503 | # almost all words are bad, and _which_ words are bad change at each timestep 504 | 505 | # just simple mean of logprobs as first try, later more sophisticated weighting 506 | # - TODO: inject reduce function with `torch.mean` as default 507 | ensemble_state['scores'] = torch.mean(torch.stack([s['scores'] for s in component_states]), dim=0) 508 | 509 | # TODO: WORKING: add flag in ensemble state to let user force decode 510 | 511 | # BEGIN: ways of selecting next token from scores 512 | if ensemble_state['do_sample']: 513 | raise AssertionError('sampling currently not supported') 514 | _scores = scores + state['beam_scores'][:, None].expand_as(scores) # (batch_size * num_beams, vocab_size) 515 | # Top-p/top-k filtering 516 | # Chris: note hard-coded `min_tokens_to_keep` 517 | _scores = modeling_utils.top_k_top_p_filtering( 518 | _scores, top_k=state['top_k'], top_p=state['top_p'], min_tokens_to_keep=2 519 | ) # (batch_size * num_beams, vocab_size) 520 | # re-organize to group the beam together to sample from all beam_idxs 521 | _scores = _scores.contiguous().view( 522 | state['batch_size'], state['num_beams'] * state['vocab_size'] 523 | ) # (batch_size, num_beams * vocab_size) 524 | 525 | # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search) 526 | probs = F.softmax(_scores, dim=-1) 527 | next_tokens = torch.multinomial(probs, num_samples=2 * state['num_beams']) # (batch_size, num_beams * 2) 528 | # Compute next scores 529 | next_scores = torch.gather(_scores, -1, next_tokens) # (batch_size, num_beams * 2) 530 | # sort the sampled vector to make sure that the first num_beams samples are the best 531 | next_scores, next_scores_indices = torch.sort(next_scores, descending=True, dim=1) 532 | next_tokens = torch.gather(next_tokens, -1, next_scores_indices) # (batch_size, num_beams * 2) 533 | 534 | else: 535 | next_scores = ensemble_state['scores'] + ensemble_state['beam_scores'][:, None].expand_as(ensemble_state['scores']) # (batch_size * num_beams, vocab_size) 536 | 537 | # re-organize to group the beam together (we are keeping top hypotheses across beams) 538 | next_scores = next_scores.view( 539 | ensemble_state['batch_size'], ensemble_state['num_beams'] * ensemble_state['vocab_size'] 540 | ) # (batch_size, num_beams * vocab_size) 541 | 542 | # Chris: there is a |vocab| * beam_idx offset 543 | next_scores, next_tokens = \ 544 | torch.topk( 545 | next_scores, 546 | 2 * ensemble_state['num_beams'], 547 | dim=1, 548 | largest=True, 549 | sorted=True 550 | ) 551 | 552 | assert next_scores.size() == next_tokens.size() == (ensemble_state['batch_size'], 2 * ensemble_state['num_beams']) 553 | # NEXT TOKEN CANDIDATES HAVE BEEN SELECTED 554 | 555 | # BEGIN: UPDATING SEARCH STATE(S) 556 | # next batch beam content 557 | next_batch_beam = [] 558 | 559 | # for each input (note currently if we are doing one multi-doc summary, batch_size is 1 for sure) 560 | for batch_idx 
in range(ensemble_state['batch_size']):

        # if we are done with this sentence
        if ensemble_state['done'][batch_idx]:
            assert (
                len(ensemble_state['generated_hyps'][batch_idx]) >= ensemble_state['num_beams']
            ), "Batch can only be done if at least {} beams have been generated".format(ensemble_state['num_beams'])
            assert (
                ensemble_state['eos_token_id'] is not None and ensemble_state['pad_token_id'] is not None
            ), "generated beams >= num_beams -> eos_token_id and pad_token_id have to be defined"
            next_batch_beam.extend([(0, ensemble_state['pad_token_id'], 0)] * ensemble_state['num_beams'])  # pad the batch
            continue

        # next sentence beam content
        next_sent_beam = []

        # next tokens for this sentence from each beam
        for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(
            zip(next_tokens[batch_idx], next_scores[batch_idx])
        ):
            # get beam and token IDs (undo beam offset)
            beam_id = beam_token_id // ensemble_state['vocab_size']
            token_id = beam_token_id % ensemble_state['vocab_size']

            effective_beam_id = batch_idx * ensemble_state['num_beams'] + beam_id

            # add to generated hypotheses if end of sentence or last iteration
            if (ensemble_state['eos_token_id'] is not None) and (token_id.item() == ensemble_state['eos_token_id']):
                # if beam_token does not belong to top num_beams tokens, it should not be added
                is_beam_token_worse_than_top_num_beams = beam_token_rank >= ensemble_state['num_beams']
                if is_beam_token_worse_than_top_num_beams:
                    continue
                # update beam hypotheses obj with finished hypothesis and score
                # we are storing metadata on ensemble_state['decoding_stats'][effective_beam_id] in the same way we're
                # updating ensemble_state['input_ids'] at each timestep
                # add metadata for this beam_idx for all states
                # metadata is ordered in the same way as component states
                hyp_metadata = []
                for state_idx in range(len(ensemble_state['decoding_stats'])):
                    hyp_metadata.append(ensemble_state['decoding_stats'][state_idx][effective_beam_id])
                ensemble_state['generated_hyps'][batch_idx].add(
                    ensemble_state['input_ids'][effective_beam_id].clone(), beam_token_score.item(), metadata=hyp_metadata
                )
            else:
                # add next predicted token if it is not eos_token
                next_sent_beam.append((beam_token_score, token_id, effective_beam_id))

            if len(next_sent_beam) == ensemble_state['num_beams']:
                # the beam for next step is now full
                break

        # Check if we're done so that we can save a pad step if all(done)
        ensemble_state['done'][batch_idx] = ensemble_state['done'][batch_idx] or ensemble_state['generated_hyps'][batch_idx].is_done(
            next_scores[batch_idx].max().item(), cur_len=ensemble_state['cur_len']
        )

        # update next beam content
        assert len(next_sent_beam) == ensemble_state['num_beams'], "Beam should always be full after loop above"
        next_batch_beam.extend(next_sent_beam)
        assert len(next_batch_beam) == ensemble_state['num_beams'] * (batch_idx + 1)

    # stop if we are done with every sentence
    if all(ensemble_state['done']):
        return component_states, ensemble_state

    # sanity check / prepare next timestep
    assert len(next_batch_beam) == ensemble_state['batch_size'] * ensemble_state['num_beams']

    # Note we shouldn't need to deal with beam scores on the component_states
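    # At this point next_batch_beam holds one (beam_token_score, token_id, effective_beam_id)
    # triple per live beam; effective_beam_id indexes rows of the flattened
    # (batch_size * num_beams, cur_len) input_ids tensor, e.g. with batch_size=1 and
    # num_beams=4 it takes values in [0, 4).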
    # Chris: the score of each item is this timestep's score + previous beam score
    # each next_batch_beam stores (beam_token_score, token_id, effective_beam_id)
    ensemble_state['beam_scores'] = ensemble_state['beam_scores'].new([x[0] for x in next_batch_beam])

    # re-order batch
    # each next_batch_beam stores (beam_token_score, token_id, effective_beam_id)
    beam_tokens = ensemble_state['input_ids'].new([x[1] for x in next_batch_beam])
    # this idx will be used to select the beam sequences to continue -- note the same sequence can be selected and continued in multiple ways
    beam_idx = ensemble_state['input_ids'].new([x[2] for x in next_batch_beam])

    # mirror the ensemble's reordering on each component state: select the surviving
    # beam sequences, then append this timestep's chosen tokens
    for state in component_states:
        state['input_ids'] = ensemble_state['input_ids'][beam_idx, :]
        state['input_ids'] = torch.cat([state['input_ids'], beam_tokens.unsqueeze(1)], dim=-1)

    # reorder input_ids according to beam_idx
    ensemble_state['input_ids'] = ensemble_state['input_ids'][beam_idx, :]
    # concat current timestep onto input_ids
    ensemble_state['input_ids'] = torch.cat([ensemble_state['input_ids'], beam_tokens.unsqueeze(1)], dim=-1)

    # reorder lists of decoding metadata according to beam_idx
    #ensemble_state['decoding_stats'] = ensemble_state['decoding_stats'][beam_idx]
    for state_idx, component_state in enumerate(component_states):
        # TODO: store in flat semantics for now, deal with batches later since edge cases of BeamHypotheses not totally evident yet
        # Note we don't need to store beam scores (accumulated scores) on component states since we can effectively force decode by summing logprobs at each timestep
        state_scores = component_state['scores'][beam_idx, beam_tokens]

        # reorder/replace existing state metadata
        next_decoding_stats = []
        for beam_id in beam_idx.cpu().numpy():
            next_decoding_stats.append(copy.deepcopy(ensemble_state['decoding_stats'][state_idx][beam_id]))

        # concat new state metadata horizontally
        state_metadata = [{'token': token.item(), 'score': score.item()} for token, score in zip(beam_tokens, state_scores)]
        for beam_id in range(ensemble_state['num_beams']):
            next_decoding_stats[beam_id].append(state_metadata[beam_id])

        ensemble_state['decoding_stats'][state_idx] = next_decoding_stats
        # TODO: do we want the score up to this point, or the softmax output of just this timestep?
-- double check this 667 | 668 | # re-order internal states 669 | # Note ensemble_state has no "past", this is only on component_states 670 | # TODO: Note in case batch size is 1 (beam can be larger), all 'past' should be identical, so this reordering shouldn't matter 671 | # TODO: confirm this as it could lead to very weird bugs 672 | for state in component_states: 673 | state['past'] = state['model']._reorder_cache(state['past'], beam_idx) 674 | 675 | # extend attention_mask for new generated input if only decoder 676 | # Chris: commented until we need a decoder-only model 677 | #if state['model'].config.is_encoder_decoder is False: 678 | # state['attention_mask'] = torch.cat( 679 | # [ 680 | # state['attention_mask'], 681 | # state['attention_mask'].new_ones((state['attention_mask'].shape[0], 1)) 682 | # ], 683 | # dim=-1 684 | # ) 685 | 686 | # update current length 687 | for state in component_states: 688 | state['cur_len'] = state['cur_len'] + 1 689 | 690 | ensemble_state['cur_len'] = ensemble_state['cur_len'] + 1 691 | 692 | print(f'beam_scores: {ensemble_state["beam_scores"]}') 693 | 694 | return component_states, ensemble_state 695 | 696 | @torch.no_grad() 697 | def beam_search_step(state): 698 | if state.get('outputs', None) is None: 699 | outputs = outputs_from_state(state) 700 | 701 | next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size) 702 | 703 | # if model has past, then set the past variable to speed up decoding 704 | if state['model']._do_output_past(outputs): 705 | state['past'] = outputs[1] 706 | 707 | # some heuristics are applied in-place to logits, others to scores 708 | # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858) 709 | if state['repetition_penalty'] != 1.0: 710 | state['model'].enforce_repetition_penalty_( 711 | next_token_logits, 712 | state['batch_size'], 713 | state['num_beams'], 714 | state['input_ids'], 715 | state['repetition_penalty'] 716 | ) 717 | 718 | if state['temperature'] != 1.0: 719 | next_token_logits = next_token_logits / state['temperature'] 720 | 721 | scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size) 722 | if state['model'].config.is_encoder_decoder and state['do_sample'] is False: 723 | # TODO (PVP) still a bit hacky here - there might be a better solution 724 | scores = state['model'].prepare_scores_for_generation( 725 | scores, 726 | cur_len=state['cur_len'], 727 | max_length=state['max_length']) 728 | 729 | # set eos token prob to zero if min_length is not reached 730 | if state['eos_token_id'] is not None and state['cur_len'] < state['min_length']: 731 | scores[:, state['eos_token_id']] = -float("inf") 732 | 733 | if state['no_repeat_ngram_size'] > 0: 734 | # calculate a list of banned tokens to prevent repetitively generating the same ngrams 735 | num_batch_hypotheses = state['batch_size'] * state['num_beams'] 736 | # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345 737 | banned_batch_tokens = modeling_utils.calc_banned_ngram_tokens( 738 | state['input_ids'], 739 | num_batch_hypotheses, 740 | state['no_repeat_ngram_size'], 741 | state['cur_len'] 742 | ) 743 | for i, banned_tokens in enumerate(banned_batch_tokens): 744 | scores[i, banned_tokens] = -float("inf") 745 | 746 | if state['bad_words_ids'] is not None: 747 | # calculate a list of banned tokens according to bad words 748 | banned_tokens = modeling_utils.calc_banned_bad_words_ids( 749 | state['input_ids'], 750 | 
state['bad_words_ids'] 751 | ) 752 | 753 | for i, banned_tokens in enumerate(banned_tokens): 754 | scores[i, banned_tokens] = -float("inf") 755 | 756 | assert scores.shape == ( 757 | state['batch_size'] * state['num_beams'], state['vocab_size']), "Shapes of scores: {} != {}".format( 758 | scores.shape, (state['batch_size'] * state['num_beams'], state['vocab_size']) 759 | ) 760 | 761 | # BEGIN: ways of selecting next token from scores 762 | if state['do_sample']: 763 | _scores = scores + state['beam_scores'][:, None].expand_as(scores) # (batch_size * num_beams, vocab_size) 764 | # Top-p/top-k filtering 765 | # Chris: note hard-coded `min_tokens_to_keep` 766 | _scores = modeling_utils.top_k_top_p_filtering( 767 | _scores, top_k=state['top_k'], top_p=state['top_p'], min_tokens_to_keep=2 768 | ) # (batch_size * num_beams, vocab_size) 769 | # re-organize to group the beam together to sample from all beam_idxs 770 | _scores = _scores.contiguous().view( 771 | state['batch_size'], state['num_beams'] * state['vocab_size'] 772 | ) # (batch_size, num_beams * vocab_size) 773 | 774 | # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search) 775 | probs = F.softmax(_scores, dim=-1) 776 | next_tokens = torch.multinomial(probs, num_samples=2 * state['num_beams']) # (batch_size, num_beams * 2) 777 | # Compute next scores 778 | next_scores = torch.gather(_scores, -1, next_tokens) # (batch_size, num_beams * 2) 779 | # sort the sampled vector to make sure that the first num_beams samples are the best 780 | next_scores, next_scores_indices = torch.sort(next_scores, descending=True, dim=1) 781 | next_tokens = torch.gather(next_tokens, -1, next_scores_indices) # (batch_size, num_beams * 2) 782 | 783 | else: 784 | next_scores = scores + state['beam_scores'][:, None].expand_as(scores) # (batch_size * num_beams, vocab_size) 785 | 786 | # re-organize to group the beam together (we are keeping top hypotheses across beams) 787 | next_scores = next_scores.view( 788 | state['batch_size'], state['num_beams'] * state['vocab_size'] 789 | ) # (batch_size, num_beams * vocab_size) 790 | 791 | next_scores, next_tokens = \ 792 | torch.topk( 793 | next_scores, 794 | 2 * state['num_beams'], 795 | dim=1, 796 | largest=True, 797 | sorted=True 798 | ) 799 | 800 | assert next_scores.size() == next_tokens.size() == (state['batch_size'], 2 * state['num_beams']) 801 | # NEXT TOKEN CANDIDATES HAVE BEEN SELECTED 802 | 803 | # BEGIN: UPDATING SEARCH STATE 804 | # next batch beam content 805 | next_batch_beam = [] 806 | 807 | # for each sentence 808 | for batch_idx in range(state['batch_size']): 809 | 810 | # if we are done with this sentence 811 | if state['done'][batch_idx]: 812 | assert ( 813 | len(state['generated_hyps'][batch_idx]) >= state['num_beams'] 814 | ), "Batch can only be done if at least {} beams have been generated".format(state['num_beams']) 815 | assert ( 816 | state['eos_token_id'] is not None and state['pad_token_id'] is not None 817 | ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined" 818 | next_batch_beam.extend([(0, state['pad_token_id'], 0)] * state['num_beams']) # pad the batch 819 | continue 820 | 821 | # next sentence beam content 822 | next_sent_beam = [] 823 | 824 | # next tokens for this sentence from each beam 825 | for beam_token_rank, (beam_token_id, beam_token_score) in enumerate( 826 | zip(next_tokens[batch_idx], next_scores[batch_idx]) 827 | ): 828 | # get beam and token IDs 829 | beam_id = beam_token_id // 
state['vocab_size'] 830 | token_id = beam_token_id % state['vocab_size'] 831 | 832 | effective_beam_id = batch_idx * state['num_beams'] + beam_id 833 | # add to generated hypotheses if end of sentence or last iteration 834 | if (state['eos_token_id'] is not None) and (token_id.item() == state['eos_token_id']): 835 | # if beam_token does not belong to top num_beams tokens, it should not be added 836 | is_beam_token_worse_than_top_num_beams = beam_token_rank >= state['num_beams'] 837 | if is_beam_token_worse_than_top_num_beams: 838 | continue 839 | # update beam hypotheses obj with finished hypothesis and score 840 | state['generated_hyps'][batch_idx].add( 841 | state['input_ids'][effective_beam_id].clone(), beam_token_score.item(), 842 | ) 843 | else: 844 | # add next predicted token if it is not eos_token 845 | next_sent_beam.append((beam_token_score, token_id, effective_beam_id)) 846 | 847 | # the beam for next step is now full 848 | if len(next_sent_beam) == state['num_beams']: 849 | break 850 | 851 | # Check if we're done so that we can save a pad step if all(done) 852 | state['done'][batch_idx] = state['done'][batch_idx] or state['generated_hyps'][batch_idx].is_done( 853 | next_scores[batch_idx].max().item(), cur_len=state['cur_len'] 854 | ) 855 | 856 | # update next beam content 857 | assert len(next_sent_beam) == state['num_beams'], "Beam should always be full after loop above" 858 | next_batch_beam.extend(next_sent_beam) 859 | assert len(next_batch_beam) == state['num_beams'] * (batch_idx + 1) 860 | 861 | # stop if are done with every sentence 862 | if all(state['done']): 863 | return state 864 | 865 | # sanity check / prepare next timestep 866 | assert len(next_batch_beam) == state['batch_size'] * state['num_beams'] 867 | state['beam_scores'] = state['beam_scores'].new([x[0] for x in next_batch_beam]) 868 | 869 | # re-order batch 870 | beam_tokens = state['input_ids'].new([x[1] for x in next_batch_beam]) 871 | beam_idx = state['input_ids'].new([x[2] for x in next_batch_beam]) 872 | 873 | state['input_ids'] = state['input_ids'][beam_idx, :] 874 | state['input_ids'] = torch.cat([state['input_ids'], beam_tokens.unsqueeze(1)], dim=-1) 875 | # re-order internal states 876 | if state['past'] is not None: 877 | state['past'] = state['model']._reorder_cache(state['past'], beam_idx) 878 | 879 | # extend attention_mask for new generated input if only decoder 880 | if state['model'].config.is_encoder_decoder is False: 881 | state['attention_mask'] = torch.cat( 882 | [ 883 | state['attention_mask'], 884 | state['attention_mask'].new_ones((state['attention_mask'].shape[0], 1)) 885 | ], 886 | dim=-1 887 | ) 888 | 889 | # update current length 890 | state['cur_len'] = state['cur_len'] + 1 891 | return state 892 | 893 | 894 | # this is def step() for model._generate_no_beam_search 895 | @torch.no_grad() 896 | def greedy_step(state): 897 | model_inputs = state['model'].prepare_inputs_for_generation( 898 | state['input_ids'], 899 | past=state['past'], 900 | attention_mask=state['attention_mask'] 901 | ) 902 | 903 | outputs = state['model'](**model_inputs) 904 | next_token_logits = outputs[0][:, -1, :] 905 | 906 | # if model has past, then set the past variable to speed up decoding 907 | if state['model']._do_output_past(outputs): 908 | state['past'] = outputs[1] 909 | 910 | # now update next_token_logits using various heuristics 911 | 912 | # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858) 913 | if state['repetition_penalty'] != 1.0: 914 | # Chris: note in-place modification 
side-effect 915 | state['model'].enforce_repetition_penalty_( 916 | next_token_logits, 917 | state['batch_size'], 1, state['input_ids'], state['repetition_penalty']) 918 | 919 | if state['no_repeat_ngram_size'] > 0: 920 | # calculate a list of banned tokens to prevent repetitively generating the same ngrams 921 | # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345 922 | 923 | banned_tokens = modeling_utils.calc_banned_ngram_tokens( 924 | state['input_ids'], 925 | state['batch_size'], 926 | state['no_repeat_ngram_size'], 927 | state['cur_len']) 928 | for batch_idx in range(state['batch_size']): 929 | next_token_logits[batch_idx, banned_tokens[batch_idx]] = -float("inf") 930 | 931 | if state['bad_words_ids'] is not None: 932 | # calculate a list of banned tokens according to bad words 933 | banned_tokens = modeling_utils.calc_banned_bad_words_ids(state['input_ids'], state['bad_words_ids']) 934 | 935 | for batch_idx in range(state['batch_size']): 936 | next_token_logits[batch_idx, banned_tokens[batch_idx]] = -float("inf") 937 | 938 | # Chris: WORKING: note any next token logic must live outside of the step function 939 | # Chris: put this into codebase first before proceeding with TDD 940 | 941 | # set eos token prob to zero if min_length is not reached 942 | if state['eos_token_id'] is not None and state['cur_len'] < state['min_length']: 943 | next_token_logits[:, state['eos_token_id']] = -float("inf") 944 | 945 | if state['do_sample']: 946 | # Temperature (higher temperature => more likely to sample low probability tokens) 947 | if state['temperature'] != 1.0: 948 | next_token_logits = next_token_logits / state['temperature'] 949 | # Top-p/top-k filtering 950 | next_token_logits = \ 951 | modeling_utils.top_k_top_p_filtering( 952 | next_token_logits, 953 | top_k=state['top_k'], 954 | top_p=state['top_p'] 955 | ) 956 | # Sample 957 | probs = F.softmax(next_token_logits, dim=-1) 958 | # Chris: TODO: note for ensembling all next token logic 959 | # needs to move outside of this function 960 | next_token = torch.multinomial(probs, num_samples=1).squeeze(1) 961 | else: 962 | # Greedy decoding 963 | next_token = torch.argmax(next_token_logits, dim=-1) 964 | 965 | # update generations and finished sentences 966 | if state['eos_token_id'] is not None: 967 | # pad finished sentences if eos_token_id exist 968 | tokens_to_add = next_token * state['unfinished_sents'] + (state['pad_token_id']) * ( 969 | 1 - state['unfinished_sents']) 970 | else: 971 | tokens_to_add = next_token 972 | 973 | # Chris: concat whatever was generated to input ids 974 | state['input_ids'] = torch.cat([state['input_ids'], tokens_to_add.unsqueeze(-1)], dim=-1) 975 | 976 | if state['eos_token_id'] is not None: 977 | eos_in_sents = tokens_to_add == state['eos_token_id'] 978 | # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length 979 | is_sents_unfinished_and_token_to_add_is_eos = state['unfinished_sents'].mul(eos_in_sents.long()).bool() 980 | state['sent_lengths'].masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, state['cur_len'] + 1) 981 | # unfinished_sents is set to zero if eos in sentence 982 | state['unfinished_sents'].mul_((~eos_in_sents).long()) 983 | 984 | # stop when there is a in each sentence, or if we exceed the maximal length 985 | if state['unfinished_sents'].max() == 0: 986 | return state 987 | 988 | # extend attention_mask for new generated input if only decoder 989 | if 
state['model'].config.is_encoder_decoder is False:
        state['attention_mask'] = torch.cat(
            [state['attention_mask'],
             state['attention_mask'].new_ones((state['attention_mask'].shape[0], 1))],
            dim=-1
        )

    state['cur_len'] = state['cur_len'] + 1

    return state
--------------------------------------------------------------------------------
/transformer_decoding/evaluate.py:
--------------------------------------------------------------------------------
import argparse
import json
from pathlib import Path
import tqdm
import os
import shutil
import time
from collections import defaultdict

import numpy as np
import torch
import spacy

import pyrouge
import logging

from transformers import (modeling_utils,
                          BartTokenizer,
                          BartForConditionalGeneration,
                          BartConfig)

from transformer_decoding import decoding_utils
import transformer_decoding.log as log

from newsroom.analyze.rouge import ROUGE_L, ROUGE_N

logger = log.create_logger(__name__)


np.random.seed(42)


# BEGIN: utils for Lebanoff 2018 rouge
# adapted from here:
# https://github.com/ucfnlp/multidoc_summarization/blob/ae30c9ee039d4ad5ff64fd2245faafc5a62c4dd7/decode.py

# installing pyrouge
# https://stackoverflow.com/questions/45894212/installing-pyrouge-gets-error-in-ubuntu
def make_html_safe(s):
    """Escape angled brackets in string s to avoid interfering with HTML attention visualizer."""
    s = s.replace("<", "&lt;")
    s = s.replace(">", "&gt;")
    return s


def rouge_eval(ref_dir, dec_dir):
    """Evaluate the files in ref_dir and dec_dir with pyrouge, returning results_dict"""
    r = pyrouge.Rouge155()
    r.model_filename_pattern = '#ID#_reference.[A-Z].txt'
    r.system_filename_pattern = '(\d+)_decoded.txt'
    r.model_dir = ref_dir
    r.system_dir = dec_dir
    logging.getLogger('global').setLevel(logging.WARNING)  # silence pyrouge logging
    rouge_args = ['-e', r._data_dir,
                  '-c',
                  '95',
                  '-2', '4',  # This is the only one we changed (changed the max skip from -1 to 4)
                  '-U',
                  '-r', '1000',
                  '-n', '4',
                  '-w', '1.2',
                  '-a',
                  '-l', '100']
    rouge_args = ' '.join(rouge_args)
    rouge_results = r.convert_and_evaluate(rouge_args=rouge_args)
    return r.output_to_dict(rouge_results)


def rouge_log(results_dict):
    """Log ROUGE results to screen and return the log string.
    Args:
        results_dict: the dictionary returned by pyrouge"""
    log_str = ""
    for x in ["1", "2", "l", "s4", "su4"]:
        log_str += "\nROUGE-%s:\n" % x
        for y in ["f_score", "recall", "precision"]:
            key = "rouge_%s_%s" % (x, y)
            key_cb = key + "_cb"
            key_ce = key + "_ce"
            val = results_dict[key]
            val_cb = results_dict[key_cb]
            val_ce = results_dict[key_ce]
            log_str += "%s: %.4f with confidence interval (%.4f, %.4f)\n" % (key, val, val_cb, val_ce)
    logging.info(log_str)  # log to screen
    return log_str


def write_for_rouge(all_reference_sents, decoded_words, ex_index, rouge_dec_dir, rouge_ref_dir, nlp):
    """Write output to file in correct format for eval with pyrouge. This is called in single_pass mode.
91 | Args: 92 | all_reference_sents: list of list of strings 93 | decoded_words: list of strings 94 | ex_index: int, the index with which to label the files 95 | """ 96 | 97 | # First, divide decoded output into sentences 98 | decoded_sents = [] 99 | while len(decoded_words) > 0: 100 | try: 101 | fst_period_idx = decoded_words.index(".") 102 | except ValueError: # there is text remaining that doesn't end in "." 103 | fst_period_idx = len(decoded_words) 104 | sent = decoded_words[:fst_period_idx + 1] # sentence up to and including the period 105 | decoded_words = decoded_words[fst_period_idx + 1:] # everything else 106 | decoded_sents.append(' '.join(sent)) 107 | 108 | # pyrouge calls a perl script that puts the data into HTML files. 109 | # Therefore we need to make our output HTML safe. 110 | decoded_sents = [make_html_safe(w) for w in decoded_sents] 111 | # note sentence splitting here 112 | all_reference_sents = [ 113 | [make_html_safe(' '.join([str(w) for w in s])) for s in nlp(abstract).sents] 114 | for abstract in all_reference_sents 115 | ] 116 | 117 | # Write to file 118 | decoded_file = os.path.join(rouge_dec_dir, "%06d_decoded.txt" % ex_index) 119 | 120 | for abs_idx, abs in enumerate(all_reference_sents): 121 | ref_file = os.path.join(rouge_ref_dir, "%06d_reference.%s.txt" % ( 122 | ex_index, chr(ord('A') + abs_idx))) 123 | with open(ref_file, "w") as f: 124 | # one long line 125 | # f.write(' '.join(abs).lower() + '\n') 126 | 127 | # one sentence on each line 128 | for idx, sent in enumerate(abs): 129 | f.write(sent + "\n") 130 | 131 | # f.write(sent) if idx==len(abs)-1 else f.write(sent+"\n") 132 | with open(decoded_file, "w") as f: 133 | # one long line 134 | # f.write(' '.join(decoded_sents).lower() + '\n') 135 | for idx, sent in enumerate(decoded_sents): 136 | f.write(sent + "\n") 137 | 138 | 139 | def lebanoff_2018_rouge(system_hyp_file, evaluation_dataset): 140 | TEMP_EVAL_DIR = Path('rouge_evaluation_tempdir') 141 | rouge_dec_dir = TEMP_EVAL_DIR / 'rouge_dec_dir' 142 | rouge_ref_dir = TEMP_EVAL_DIR / 'rouge_ref_dir' 143 | rouge_dec_dir.mkdir(parents=True, exist_ok=True) 144 | rouge_ref_dir.mkdir(parents=True, exist_ok=True) 145 | 146 | nlp = spacy.load("en_core_web_sm") 147 | 148 | # dataset needs to be in .jsonl 149 | dataset_rows = [json.loads(l) for l in open(evaluation_dataset)] 150 | 151 | # tokenize hyps to follow Lebanoff et al 2018 logic 152 | system_hyp_tokens = [[str(t) for t in nlp(h.strip())] for h in open(system_hyp_file)] 153 | 154 | # write the rouge files 155 | for idx, (row, h) in enumerate(zip(dataset_rows, system_hyp_tokens)): 156 | if type(row['summary']) is list: 157 | summaries = row['summary'] 158 | else: 159 | summaries = [row['summary']] 160 | # print(f'{len(summaries)} summaries available at row {idx}') 161 | write_for_rouge(summaries, h, idx, rouge_dec_dir, rouge_ref_dir, nlp) 162 | 163 | log_report = rouge_log(rouge_eval(rouge_ref_dir, rouge_dec_dir)) 164 | print(log_report) 165 | shutil.rmtree(TEMP_EVAL_DIR) 166 | 167 | # END: utils for Lebanoff 2018 rouge 168 | 169 | # BEGIN: utils for Fabbri 2019 rouge 170 | # Evaluation from: https://github.com/Alex-Fabbri/Multi-News/blob/3675e7c422ae3b4020617a324ac264f50333357d/code/OpenNMT-py-baselines/tools/test_rouge.py 171 | def test_rouge(candidates, references): 172 | """Calculate ROUGE scores of sequences passed as an iterator 173 | e.g. 
a list of str, an open file, StringIO or even sys.stdin
174 |     """
175 |     current_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())
176 |     tmp_dir = ".rouge-tmp-{}".format(current_time)
177 |     try:
178 |         if not os.path.isdir(tmp_dir):
179 |             os.mkdir(tmp_dir)
180 |             os.mkdir(tmp_dir + "/candidate")
181 |             os.mkdir(tmp_dir + "/reference")
182 |         # candidates = [line.strip() for line in cand]
183 |         # references = [line.strip() for line in ref]
184 |         assert len(candidates) == len(references)
185 |         cnt = len(candidates)
186 |         for i in range(cnt):
187 |             if len(references[i]) < 1:
188 |                 continue
189 |             with open(tmp_dir + "/candidate/cand.{}.txt".format(i), "w",
190 |                       encoding="utf-8") as f:
191 |                 f.write(candidates[i])
192 |             with open(tmp_dir + "/reference/ref.{}.txt".format(i), "w",
193 |                       encoding="utf-8") as f:
194 |                 f.write(references[i])
195 |         r = pyrouge.Rouge155()
196 |         r.model_dir = tmp_dir + "/reference/"
197 |         r.system_dir = tmp_dir + "/candidate/"
198 |         r.model_filename_pattern = 'ref.#ID#.txt'
199 |         r.system_filename_pattern = r'cand.(\d+).txt'
200 |         rouge_results = r.convert_and_evaluate()
201 |         results_dict = r.output_to_dict(rouge_results)
202 |         return results_dict
203 |     finally:
204 |         # always clean up the temp dir, whether or not scoring succeeded
205 |         if os.path.isdir(tmp_dir):
206 |             shutil.rmtree(tmp_dir)
207 | 
208 | 
209 | def rouge_results_to_str(results_dict):
210 |     return ">> ROUGE(1/2/3/L/SU4): {:.2f}/{:.2f}/{:.2f}/{:.2f}/{:.2f}".format(
211 |         results_dict["rouge_1_f_score"] * 100,
212 |         results_dict["rouge_2_f_score"] * 100,
213 |         results_dict["rouge_3_f_score"] * 100,
214 |         results_dict["rouge_l_f_score"] * 100,
215 |         results_dict["rouge_su*_f_score"] * 100)
216 | 
217 | # END: utils for Fabbri 2019 rouge
218 | 
219 | 
220 | def print_mean(results, rouge_types):
221 |     for rouge_type in rouge_types:
222 |         precs = results[rouge_type]['p']
223 |         recalls = results[rouge_type]['r']
224 |         fscores = results[rouge_type]['f']
225 |         p = round(np.mean(precs), 3)
226 |         r = round(np.mean(recalls), 3)
227 |         f = round(np.mean(fscores), 3)
228 |         print(rouge_type, 'p:', p, 'r:', r, 'f:', f)
229 | 
230 | 
231 | def evaluate_rouge(hyps, refs, lowercase=True):
232 |     if type(hyps) is str:
233 |         hyps = [l.strip() for l in open(hyps)]
234 |     if type(refs) is str:
235 |         assert refs.endswith('.jsonl'), 'reference summaries must be stored in the "summary" field of a .jsonl file'
236 |         refs = [json.loads(c)['summary'].strip() for c in open(refs)]
237 | 
238 |     assert len(hyps) == len(refs)
239 |     # Now evaluate
240 |     rouge_types = ['rouge-1', 'rouge-2', 'rouge-l']
241 |     results = dict((rouge_type, defaultdict(list))
242 |                    for rouge_type in rouge_types)
243 | 
244 |     for hyp, ref in zip(hyps, refs):
245 |         if lowercase:
246 |             hyp = hyp.lower()
247 |             ref = ref.lower()
248 | 
249 |         r1 = ROUGE_N(ref, hyp, n=1)
250 |         r2 = ROUGE_N(ref, hyp, n=2)
251 |         rl = ROUGE_L(ref, hyp)
252 | 
253 |         for (rouge_type, scores) in zip(rouge_types, [r1, r2, rl]):
254 |             results[rouge_type]['p'].append(scores.precision)
255 |             results[rouge_type]['r'].append(scores.recall)
256 |             results[rouge_type]['f'].append(scores.fscore)
257 | 
258 |     return results, rouge_types
259 | 
260 | 
261 | class BartSummarizerConfig:
262 |     def __init__(self, args):
263 |         """
264 |         currently we use the model `bart-large-cnn`
265 |         """
266 |         self.model = BartForConditionalGeneration.from_pretrained(args['model_id'])
267 |         self.tokenizer = BartTokenizer.from_pretrained(args['model_id'])
268 | 
269 | 
270 | class Summarizer:
271 | 
272 |     def __init__(self, config):
273 |         self.model = config['model']
274 |         self.tokenizer = config['tokenizer']
275 |         # NOTE: could factor out score computation vs reduction(?) -- wait for usecase
276 | 
277 | 
278 | # eventually we want to be able to ensemble
279 | # (1) different inputs, same model
280 | # (2) same inputs, different models
281 | # -- we should support arbitrary combinations of these, without too much
282 | # cruft from configuration
283 | 
284 | def summarize_articles(articles, args, gold_summary=None):
285 |     """
286 |     Ensembled summarization of a cluster of articles
287 |     """
288 |     model = args['model']
289 |     tokenizer = args['tokenizer']
290 |     decoding_hyperparams = {
291 |         'max_length': args['max_src_length'],
292 |         'max_tgt_length': args['max_tgt_length'],
293 |         'num_beams': args['num_beams']
294 |     }
295 | 
296 |     # TODO: WORKING: add flag in ensemble state to let user force decode
297 | 
298 |     component_states = [decoding_utils.get_start_state(a, model, tokenizer, decoding_hyperparams)
299 |                         for a in articles]
300 | 
301 |     # Note we just pass the first article in cluster when building the ensemble state
302 |     ensemble_state = decoding_utils.get_start_state(articles[0], model, tokenizer, decoding_hyperparams)
303 | 
304 |     # ((batch) x |vocab| x timesteps)
305 |     timestep_mask = None
306 |     if args['force_decode_gold']:
307 |         # convert text to tensor
308 |         # Note currently hard-coded max gold summary length
309 |         encoded_gold = tokenizer.batch_encode_plus(
310 |             [gold_summary],
311 |             max_length=512,
312 |             pad_to_max_length=False,
313 |             return_tensors='pt'
314 |         )
315 |         gold_ids = encoded_gold['input_ids']
316 | 
317 |         # (timesteps, |vocab|)
318 |         # set every position that is not a gold token id to `float("inf")`
319 |         # Note: since the mask is going to be elementwise-multiplied with logprobs, we set to float("inf") instead of
320 |         # -float("inf") so that the sign doesn't get flipped
321 |         # effectively we know our mask tensor for each timestep is (1, |vocab_size|),
322 |         # because batch size and beam size are 1
323 |         timestep_mask = torch.empty(gold_ids.shape[1], ensemble_state['vocab_size']).fill_(float("inf"))
324 |         timestep_mask = timestep_mask.scatter(-1, gold_ids.T, 1.)[:, None, :]
325 | 
326 |     # WORKING TODO: attach gold summary to ensemble state if user wants to force decode
327 |     # WORKING TODO: assert decoding hyperparams make sense if force-decoding (beam size = 1, etc...)
328 | 
329 |     component_states, ensemble_state = \
330 |         decoding_utils.generate(component_states, decoding_hyperparams['max_tgt_length'],
331 |                                 ensemble_state=ensemble_state, timestep_mask=timestep_mask)
332 | 
333 |     # NOTE: this logic might move to end of `generate` function(?)
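    # A hedged sketch (not the actual implementation, which lives in
    # decoding_utils.generate) of how `timestep_mask` is intended to combine
    # with the per-step log-probabilities; `logprobs` and `cur_len` are
    # illustrative names, not symbols defined here:
    #
    #     masked_logprobs = logprobs * timestep_mask[cur_len]
    #
    # Gold positions were scattered to 1., so their (negative) logprobs pass
    # through unchanged; every other position holds +inf, and a negative
    # logprob times +inf becomes -inf, which beam search can never select.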
334 |     # finalize all open beam hypotheses and add them to the generated hypotheses
335 |     for batch_idx in range(ensemble_state['batch_size']):
336 |         if ensemble_state['done'][batch_idx]:
337 |             continue
338 | 
339 |         # need to add best num_beams hypotheses to generated hyps
340 |         for beam_id in range(ensemble_state['num_beams']):
341 |             effective_beam_id = batch_idx * ensemble_state['num_beams'] + beam_id
342 |             final_score = ensemble_state['beam_scores'][effective_beam_id].item()
343 |             final_tokens = ensemble_state['input_ids'][effective_beam_id]
344 | 
345 |             hyp_metadata = []
346 |             for state_idx in range(len(ensemble_state['decoding_stats'])):
347 |                 hyp_metadata.append(ensemble_state['decoding_stats'][state_idx][effective_beam_id])
348 | 
349 |             ensemble_state['generated_hyps'][batch_idx].add(final_tokens, final_score, metadata=hyp_metadata)
350 | 
351 |     assert ensemble_state['batch_size'] == 1, 'current logic assumes batch size = 1'
352 | 
353 |     # sort hyps by score (0 index is first batch, and we're assuming batch_size always = 1 right now)
354 |     sorted_hyps = [(hyp, score, metadata) for score, hyp, metadata in sorted(ensemble_state['generated_hyps'][0].beams, key=lambda b: b[0], reverse=True)]
355 | 
356 |     print(f'Num hyps in BeamHypotheses: {len(sorted_hyps)}')
357 | 
358 |     # map token indexes back to strings
359 |     predictions = [tokenizer.decode(hyp,
360 |                                     skip_special_tokens=True,
361 |                                     clean_up_tokenization_spaces=False)
362 |                    for hyp, _, _ in sorted_hyps]
363 | 
364 |     return predictions, sorted_hyps
365 | 
366 | 
367 | def article_to_text(article, separator_token=' '):
368 |     # just be sure about whitespace
369 |     title = ' '.join(article["title"].strip().split())
370 |     text = ' '.join(article["text"].strip().split())
371 |     return f'{title} {separator_token} {text}'
372 | 
373 | 
374 | def main(args):
375 | 
376 |     if args['evaluation_dataset'].endswith('.jsonl'):
377 |         dataset = [json.loads(l) for l in open(args['evaluation_dataset'])][:args['rows_to_eval']]
378 |     else:
379 |         raise AssertionError('Right now we only know how to handle .jsonl evaluation datasets')
380 | 
381 |     eval_prefix = args['eval_prefix']
382 | 
383 |     if args['predictions'] is None:
384 |         # load pretrained or finetuned transformer model
385 |         print(f'loading pre-trained model: {args["model_id"]}')
386 | 
387 |         # we have to load fine-tuned models in a different way because of pytorch-lightning
388 |         if args['model_id'].endswith('.ckpt'):
389 |             from transformer_decoding.finetune import SummarizationTrainer
390 |             lightning_model = SummarizationTrainer.load_from_checkpoint(args['model_id'])
391 |             args['model'] = lightning_model.model
392 |             args['tokenizer'] = lightning_model.tokenizer
393 |         else:
394 |             # transformers pretrained
395 |             args['model'] = BartForConditionalGeneration.from_pretrained(args['model_id'])
396 |             args['tokenizer'] = BartTokenizer.from_pretrained(args['model_id'])
397 | 
398 |         # Set the model in evaluation mode to deactivate the DropOut modules
399 |         args['model'].eval()
400 | 
401 |         if torch.cuda.is_available():
402 |             args['model'].to('cuda')
403 | 
404 |         # summarize MDS / summarization dataset with model
405 |         preds_output = open(f'{eval_prefix}eval_predicted_summaries.out', 'w', buffering=1)
406 |         gold_output = open(f'{eval_prefix}eval_gold_summaries.out', 'w', buffering=1)
407 |         metadata_output = open(f'{eval_prefix}decoding_metadata.jsonl', 'w', buffering=1)
408 | 
409 |         summaries = []
410 |         # get summary for each cluster
411 |         # note here we have a macro-batch size of one cluster by definition
412 |         for cluster in tqdm.tqdm(dataset):
413 |             # shuffle articles before selecting topk to use in ensemble
414 |             articles = [article_to_text(a) for a in cluster['articles']]
415 |             np.random.shuffle(articles)
416 |             articles = articles[:args['max_articles_in_cluster']]
417 | 
418 |             if args['min_input_char_length'] is not None:
419 |                 articles_ = [a for a in articles if len(a) >= args['min_input_char_length']]
420 |                 if len(articles_) == 0:
421 |                     articles_ = [articles[0]]
422 |                 articles = articles_
423 | 
424 |             gold_summary = cluster['summary'].strip()
425 | 
426 |             predictions, sorted_hyps = summarize_articles(articles, args, gold_summary=gold_summary)
427 |             # sorted_hyps -- (token_idxs, score, metadata)
428 |             # they're in sorted order according to ensemble score, so first one is the best
429 |             # we will have one list of timestep metadata for each cluster input
430 |             length_penalty = args['length_penalty']
431 |             component_scores = []
432 |             for input_idx, state_metadata in enumerate(sorted_hyps[0][2]):
433 |                 timestep_scores = np.array([o['score'] for o in state_metadata])
434 |                 global_score = np.sum(timestep_scores) / len(timestep_scores) ** length_penalty
435 |                 component_scores.append(global_score)
436 | 
437 |             component_scores = np.array(component_scores)
438 |             for idx in np.argsort(component_scores)[::-1]:
439 |                 print(f'ARTICLE: {articles[idx][:1500]}')
440 |                 print(f'Input {idx} score: {component_scores[idx]}')
441 |                 print()
442 | 
443 |             print(f'Ensemble score: {sorted_hyps[0][1]}')
444 |             print(f'Gold: {cluster["summary"]}')
445 |             print(f'Predicted: {predictions[0]}')
446 |             print()
447 | 
448 |             # TODO: sometimes we hit -inf during forced decoding, debug this
449 |             # TODO: if big disparity between article scores, do something, store an input divergence score
450 |             # Note: reverse max / min because scores are logprobs
451 |             if len(component_scores) > 1 and component_scores.max() / sorted(component_scores)[-2] <= .65:
452 |                 logger.warning(f'Large disparity between component input scores: {component_scores}')
453 | 
454 |             predicted_summary = predictions[0]
455 |             summaries.append((predicted_summary, gold_summary))
456 |             preds_output.write(f'{predicted_summary}\n')
457 |             gold_output.write(f'{gold_summary}\n')
458 | 
459 |             sorted_hyps_ = []
460 |             for tok_idxs, score, tok_scores in sorted_hyps:
461 |                 tok_idxs = [int(idx) for idx in tok_idxs.cpu().numpy()]
462 |                 sorted_hyps_.append((tok_idxs, score, tok_scores))
463 |             sorted_hyps = sorted_hyps_
464 | 
465 |             metadata_output.write(
466 |                 json.dumps(
467 |                     {
468 |                         'cluster': cluster,
469 |                         'predictions': predictions,
470 |                         'inputs_used': articles,
471 |                         'component_scores': list(component_scores),
472 |                         'decoding_metadata': sorted_hyps
473 |                     })
474 |                 + '\n')
475 | 
476 |         preds_output.close()
477 |         gold_output.close()
478 | 
479 |         # Evaluation
480 |         hyps, refs = zip(*summaries)
481 |     else:
482 |         # Evaluate on user-supplied predictions
483 |         logger.info(f'Evaluating predictions in {args["predictions"]} '
484 |                     f'against gold summaries in {args["evaluation_dataset"]}')
485 |         hyps = [l.strip() for l in open(args['predictions'])]
486 |         # Note this is only single-reference currently
487 |         refs = [json.loads(c)['summary'].strip() for c in open(args['evaluation_dataset'])]
488 |         assert len(hyps) == len(refs)
489 | 
490 |     # TODO: working -- issue with multi vs single ref setups
491 |     # - Lebanoff eval requires tokenized predictions -- see what we can do to consolidate evals
492 | 
493 |     # Ghalandari et al 2020 evaluation
494 |     # TODO: print evaluation results to file
495 |     results, rouge_types = evaluate_rouge(hyps, refs)
496 |     print_mean(results, rouge_types)
497 | 
498 |     # End 
evaluation 499 | 500 | 501 | def parse_args(): 502 | parser = argparse.ArgumentParser() 503 | parser.add_argument( 504 | '--evaluation-dataset', 505 | type=str, 506 | required=True, 507 | help='filepath of evaluation data' 508 | ) 509 | parser.add_argument( 510 | '--predictions', 511 | type=str, 512 | required=False, 513 | default=None, 514 | help='if supplied, evaluation will be done on this output, and new predictions will not be generated' 515 | ) 516 | parser.add_argument( 517 | '--model-id', 518 | type=str, 519 | required=True, 520 | help='the model id string from the huggingface transformers library, or the path to a pytorch lightning fine-tuned .ckpt' 521 | ) 522 | parser.add_argument( 523 | '--num-beams', 524 | type=int, 525 | required=False, 526 | default=1, 527 | help='number of beam search beams' 528 | ) 529 | parser.add_argument( 530 | '--length-penalty', 531 | type=float, 532 | required=False, 533 | default=2., 534 | help='length penalty to use when computing final hypothesis scores' 535 | ) 536 | parser.add_argument( 537 | '--min-input-char-length', 538 | type=int, 539 | required=False, 540 | default=None, 541 | help='If specified, we will try to filter inputs to be at least this many characters' 542 | ) 543 | parser.add_argument( 544 | '--max-src-length', 545 | type=int, 546 | required=False, 547 | default=256, 548 | help='The maximum length of input sequences' 549 | ) 550 | parser.add_argument( 551 | '--max-tgt-length', 552 | type=int, 553 | required=False, 554 | default=64, 555 | help='The maximum length of decoded sequences' 556 | ) 557 | parser.add_argument( 558 | '--max-articles-in-cluster', 559 | type=int, 560 | required=False, 561 | default=5, 562 | help='take K articles from each cluster to use in the ensemble' 563 | ) 564 | parser.add_argument( 565 | '--rows-to-eval', 566 | type=int, 567 | required=False, 568 | default=None, 569 | help='if provided, truncate eval dataset to this many rows' 570 | ) 571 | parser.add_argument( 572 | '--eval-prefix', 573 | type=str, 574 | required=False, 575 | default='', 576 | help='If provided, prefix of output files' 577 | ) 578 | parser.add_argument( 579 | '--force-decode-gold', 580 | required=False, 581 | action='store_true', 582 | help='if this flag is true, we force generation of the gold summary for each cluster' 583 | ) 584 | 585 | return parser.parse_args() 586 | 587 | 588 | if __name__ == '__main__': 589 | main(vars(parse_args())) 590 | -------------------------------------------------------------------------------- /transformer_decoding/finetune.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import logging 4 | import os 5 | import time 6 | 7 | import torch 8 | from torch.utils.data import DataLoader 9 | 10 | from transformer_decoding.transformer_base import (BaseTransformer, 11 | add_generic_args, 12 | generic_train, 13 | get_linear_schedule_with_warmup) 14 | 15 | from transformer_decoding.bart_utils import SummarizationDataset 16 | 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | class SummarizationTrainer(BaseTransformer): 22 | 23 | mode = "language-modeling" 24 | 25 | def __init__(self, hparams): 26 | super().__init__(hparams, num_labels=None, mode=self.mode) 27 | self.dataset_kwargs: dict = dict( 28 | data_dir=self.hparams.data_dir, 29 | max_source_length=self.hparams.max_source_length, 30 | max_target_length=self.hparams.max_target_length, 31 | ) 32 | 33 | def forward(self, input_ids, attention_mask=None, decoder_input_ids=None, 
lm_labels=None): 34 | return self.model( 35 | input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, lm_labels=lm_labels, 36 | ) 37 | 38 | def _step(self, batch): 39 | pad_token_id = self.tokenizer.pad_token_id 40 | source_ids, source_mask, y = batch["source_ids"], batch["source_mask"], batch["target_ids"] 41 | y_ids = y[:, :-1].contiguous() 42 | lm_labels = y[:, 1:].clone() 43 | lm_labels[y[:, 1:] == pad_token_id] = -100 44 | outputs = self(source_ids, attention_mask=source_mask, decoder_input_ids=y_ids, lm_labels=lm_labels,) 45 | 46 | loss = outputs[0] 47 | 48 | return loss 49 | 50 | def training_step(self, batch, batch_idx): 51 | loss = self._step(batch) 52 | 53 | tensorboard_logs = {"train_loss": loss} 54 | return {"loss": loss, "log": tensorboard_logs} 55 | 56 | def validation_step(self, batch, batch_idx): 57 | loss = self._step(batch) 58 | return {"val_loss": loss} 59 | 60 | def validation_end(self, outputs): 61 | avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean() 62 | tensorboard_logs = {"val_loss": avg_loss} 63 | return {"avg_val_loss": avg_loss, "log": tensorboard_logs} 64 | 65 | def test_step(self, batch, batch_idx): 66 | pad_token_id = self.tokenizer.pad_token_id 67 | source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(batch, pad_token_id) 68 | # NOTE: the following kwargs get more speed and lower quality summaries than those in evaluate_cnn.py 69 | generated_ids = self.model.generate( 70 | input_ids=source_ids, 71 | attention_mask=source_mask, 72 | num_beams=1, 73 | max_length=80, 74 | repetition_penalty=2.5, 75 | length_penalty=1.0, 76 | early_stopping=True, 77 | use_cache=True, 78 | ) 79 | preds = [ 80 | self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) 81 | for g in generated_ids 82 | ] 83 | target = [self.tokenizer.decode(t, skip_special_tokens=True, clean_up_tokenization_spaces=True) for t in y] 84 | loss = self._step(batch) 85 | 86 | return {"val_loss": loss, "preds": preds, "target": target} 87 | 88 | def test_end(self, outputs): 89 | return self.validation_end(outputs) 90 | 91 | def test_epoch_end(self, outputs): 92 | output_test_predictions_file = os.path.join(self.hparams.output_dir, "test_predictions.txt") 93 | output_test_targets_file = os.path.join(self.hparams.output_dir, "test_targets.txt") 94 | # write predictions and targets for later rouge evaluation. 
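        # (one prediction / target per line, aligned by index; these files can
        # then be scored offline, e.g. with transformer_decoding/evaluate.py
        # and its --predictions flag)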
95 |         with open(output_test_predictions_file, "w+") as p_writer, open(output_test_targets_file, "w+") as t_writer:
96 |             for output_batch in outputs:
97 |                 p_writer.writelines(s + "\n" for s in output_batch["preds"])
98 |                 t_writer.writelines(s + "\n" for s in output_batch["target"])
99 |             p_writer.close()
100 |             t_writer.close()
101 | 
102 |         return self.test_end(outputs)
103 | 
104 |     def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
105 |         dataset = SummarizationDataset(self.tokenizer, type_path=type_path, **self.dataset_kwargs)
106 |         dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=dataset.collate_fn, shuffle=shuffle)
107 |         return dataloader
108 | 
109 |     def train_dataloader(self) -> DataLoader:
110 |         dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
111 |         t_total = (
112 |             (len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.n_gpu)))
113 |             // self.hparams.gradient_accumulation_steps
114 |             * float(self.hparams.num_train_epochs)
115 |         )
116 |         scheduler = get_linear_schedule_with_warmup(
117 |             self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
118 |         )
119 |         self.lr_scheduler = scheduler
120 |         return dataloader
121 | 
122 |     def val_dataloader(self) -> DataLoader:
123 |         return self.get_dataloader("val", batch_size=self.hparams.eval_batch_size)
124 | 
125 |     def test_dataloader(self) -> DataLoader:
126 |         return self.get_dataloader("test", batch_size=self.hparams.eval_batch_size)
127 | 
128 |     @staticmethod
129 |     def add_model_specific_args(parser, root_dir):
130 |         BaseTransformer.add_model_specific_args(parser, root_dir)
131 |         # Add BART specific options
132 |         parser.add_argument(
133 |             "--max_source_length",
134 |             default=1024,
135 |             type=int,
136 |             help="The maximum total input sequence length after tokenization. Sequences longer "
137 |             "than this will be truncated, sequences shorter will be padded.",
138 |         )
139 |         parser.add_argument(
140 |             "--max_target_length",
141 |             default=56,
142 |             type=int,
143 |             help="The maximum total target (summary) sequence length after tokenization. Sequences longer "
144 |             "than this will be truncated, sequences shorter will be padded.",
145 |         )
146 | 
147 |         parser.add_argument(
148 |             "--data_dir",
149 |             default=None,
150 |             type=str,
151 |             required=True,
152 |             help="The input data dir. 
Should contain the dataset files for the CNN/DM summarization task.", 153 | ) 154 | return parser 155 | 156 | 157 | def main(args): 158 | 159 | # If output_dir not provided, a folder will be generated in pwd 160 | if not args.output_dir: 161 | args.output_dir = os.path.join("./results", f"{args.task}_{time.strftime('%Y%m%d_%H%M%S')}",) 162 | os.makedirs(args.output_dir) 163 | model = SummarizationTrainer(args) 164 | trainer = generic_train(model, args) 165 | 166 | # Optionally, predict on dev set and write to output_dir 167 | if args.do_predict: 168 | # See https://github.com/huggingface/transformers/issues/3159 169 | # pl use this format to create a checkpoint: 170 | # https://github.com/PyTorchLightning/pytorch-lightning/blob/master\ 171 | # /pytorch_lightning/callbacks/model_checkpoint.py#L169 172 | checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True))) 173 | model = model.load_from_checkpoint(checkpoints[-1]) 174 | trainer.test(model) 175 | 176 | 177 | if __name__ == "__main__": 178 | parser = argparse.ArgumentParser() 179 | add_generic_args(parser, os.getcwd()) 180 | parser = SummarizationTrainer.add_model_specific_args(parser, os.getcwd()) 181 | args = parser.parse_args() 182 | 183 | main(args) 184 | -------------------------------------------------------------------------------- /transformer_decoding/log.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import logging 3 | 4 | 5 | def create_logger(name): 6 | logger = logging.getLogger(name) 7 | logger.setLevel(logging.INFO) 8 | formatter = logging.Formatter( 9 | '%(asctime)s %(name)s %(levelname)s: %(message)s' 10 | ) 11 | handler = logging.StreamHandler(sys.stdout) 12 | handler.setLevel(logging.INFO) 13 | handler.setFormatter(formatter) 14 | logger.addHandler(handler) 15 | logger.propagate = False 16 | return logger 17 | -------------------------------------------------------------------------------- /transformer_decoding/test_decoding.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test stepwise decoding 3 | 4 | Then test trivial ensembles (same instance input multiple times) 5 | 6 | Then test news summarization ensembles (same instance input multiple times) 7 | 8 | """ 9 | 10 | import os 11 | from pathlib import Path 12 | import copy 13 | 14 | from collections import OrderedDict 15 | 16 | import torch 17 | 18 | from transformers import (modeling_utils, 19 | BartTokenizer, 20 | BartForConditionalGeneration, 21 | BartConfig) 22 | 23 | import unittest 24 | 25 | from transformer_decoding import decoding_utils 26 | 27 | 28 | #path_to_file = Path(os.path.dirname(os.path.abspath(__file__))) 29 | #resources = Path( 30 | # os.environ.get('RESOURCES', path_to_file / '../resources/test') 31 | #) 32 | 33 | 34 | 35 | # Set up transformer model with LM head then assert things 36 | # TODO: which transformer models have encoder-->decoder 37 | class TestTransformerDecoding(unittest.TestCase): 38 | 39 | @classmethod 40 | def setUpClass(cls): 41 | # summarization 42 | # generate yes beam search 43 | # Note for BART summarization in transformers repo, beam search performs much better 44 | # than no beam search, but even their beam search with num_beams=1 is better, implying that something 45 | # is broken in the _generate_no_beam_search function 46 | 47 | # see ``examples/summarization/bart/evaluate_cnn.py`` for a longer example 48 | cls.model = 
BartForConditionalGeneration.from_pretrained('bart-large-cnn')
49 |         cls.tokenizer = BartTokenizer.from_pretrained('bart-large-cnn')
50 | 
51 |         cls.decoding_hyperparams = {
52 |             'max_length': 40,
53 |             'num_beams': 3
54 |         }
55 | 
56 |         cls.test_news_article_1 = 'New Zealand says it has stopped community transmission of Covid-19, ' \
57 |             'effectively eliminating the virus. With new cases in single figures for several days - one on Sunday ' \
58 |             '- Prime Minister Jacinda Ardern said the virus was "currently" eliminated. But officials have warned ' \
59 |             'against complacency, saying it does not mean a total end to new coronavirus cases. ' \
60 |             'The news comes hours before New Zealand is set to move out of its toughest level of social restrictions. ' \
61 |             'From Tuesday, some non-essential business, healthcare and education activity will be able to resume. ' \
62 |             'Most people will still be required to remain at home at all times and avoid all social interactions.'
63 | 
64 |         cls.test_news_article_2 = \
65 |             'But officials have warned against complacency, saying it does not mean a total end to new HIV cases. ' \
66 |             'Most people will still be required to remain at home at all times and avoid all social interactions. ' \
67 |             'Germany says it has stopped community transmission of HIV, ' \
68 |             'effectively eliminating the virus. With new cases in single figures for several days - one on Sunday ' \
69 |             '- Prime Minister Angela Merkel said the virus was "currently" eliminated. ' \
70 |             'From Tuesday, some non-essential business, healthcare and education activity will be able to resume. ' \
71 |             'The news comes hours before Germany is set to move out of its toughest level of social restrictions. '
72 | 
73 |     def test_obtaining_timestep_scores(self):
74 |         """
75 |         Test that we can get the scores out of a model in order to do things with them before deciding upon
76 |         a discrete representation of this timestep and proceeding to the next one.
77 |         """
78 |         # then we wish to step through decoding
79 |         # for summarization, args on initial state which are input-specific:
80 |         # decoder_state['model']
81 |         # decoder_state['encoder_outputs']
82 |         # decoder_state['past'] will also hold something model-specific(?) 
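        # For orientation: a decoder state is a plain dict built by
        # decoding_utils.get_start_state. Judging from how it is accessed in
        # this repo (an assumption, not an exhaustive spec), it holds roughly:
        #   {'model': ..., 'input_ids': ..., 'attention_mask': ...,
        #    'encoder_outputs': ..., 'past': ..., 'cur_len': ...,
        #    'num_beams': ..., 'beam_scores': ..., 'generated_hyps': ..., ...}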
83 | # Every other arg is a decoding hyperparam 84 | 85 | # as decoding proceeds, input_ids will hold current state 86 | # IDEA: pass a list of states, and one additional state to hold their combined outputs 87 | test_articles_1 = [self.test_news_article_1, self.test_news_article_2] 88 | component_states_1 = [decoding_utils.get_start_state(a, self.model, self.tokenizer, self.decoding_hyperparams) 89 | for a in test_articles_1] 90 | ensemble_state_1 = decoding_utils.get_start_state(test_articles_1[0], self.model, self.tokenizer, self.decoding_hyperparams) 91 | 92 | # TODO: at the beginning of decoding, the ensemble state doesn't know anything about the component states 93 | # - we should try to encode this explicitly by _not_ passing an input to initialize this state 94 | # TODO: remove past and encoder outputs from ensemble state 95 | # TODO: remove decoding hyperparams from component_states for sanity 96 | 97 | # run beam_search_step function 98 | # ok now we are ready to start stepping 99 | # step and decode with tokenizer at each step to visualize and understand decoding progress 100 | # for step_idx in range(decoding_hyperparams['max_length']): 101 | # print(f'STEP: {step_idx}') 102 | # print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in 103 | # decoder_state['input_ids']]) 104 | # decoder_state = decoding_utils.beam_search_step(decoder_state) 105 | # print() 106 | # import ipdb; ipdb.set_trace() 107 | 108 | # TODO: assert that it doesn't matter which state we initialize ensemble_state from 109 | component_states_1, ensemble_state_1 = \ 110 | decoding_utils.generate(component_states_1, self.decoding_hyperparams['max_length'], 111 | ensemble_state=ensemble_state_1) 112 | 113 | # reorder articles and run again 114 | test_articles_2 = [self.test_news_article_2, self.test_news_article_1] 115 | component_states_2 = [decoding_utils.get_start_state(a, self.model, self.tokenizer, self.decoding_hyperparams) 116 | for a in test_articles_2] 117 | ensemble_state_2 = decoding_utils.get_start_state(test_articles_2[0], self.model, self.tokenizer, self.decoding_hyperparams) 118 | 119 | component_states_2, ensemble_state_2 = \ 120 | decoding_utils.generate(component_states_2, self.decoding_hyperparams['max_length'], 121 | ensemble_state=ensemble_state_2) 122 | 123 | for o1_ids, o2_ids in zip(ensemble_state_1['input_ids'], ensemble_state_2['input_ids']): 124 | o1_text = self.tokenizer.decode(o1_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) 125 | o2_text = self.tokenizer.decode(o2_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) 126 | print(f'o1_text: {o1_text}') 127 | print(f'o1_ids: {o1_ids}') 128 | assert o1_text == o2_text 129 | 130 | 131 | if __name__ == '__main__': 132 | unittest.main() 133 | -------------------------------------------------------------------------------- /transformer_decoding/transformer_base.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import random 4 | 5 | import numpy as np 6 | import pytorch_lightning as pl 7 | import torch 8 | 9 | from transformers import ( 10 | ALL_PRETRAINED_MODEL_ARCHIVE_MAP, 11 | AdamW, 12 | AutoConfig, 13 | AutoModel, 14 | AutoModelForPreTraining, 15 | AutoModelForQuestionAnswering, 16 | AutoModelForSequenceClassification, 17 | AutoModelForTokenClassification, 18 | AutoModelWithLMHead, 19 | AutoTokenizer, 20 | get_linear_schedule_with_warmup, 21 | ) 22 | from transformers.modeling_auto import 
MODEL_MAPPING 23 | 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | 28 | ALL_MODELS = tuple(ALL_PRETRAINED_MODEL_ARCHIVE_MAP) 29 | MODEL_CLASSES = tuple(m.model_type for m in MODEL_MAPPING) 30 | 31 | MODEL_MODES = { 32 | "base": AutoModel, 33 | "sequence-classification": AutoModelForSequenceClassification, 34 | "question-answering": AutoModelForQuestionAnswering, 35 | "pretraining": AutoModelForPreTraining, 36 | "token-classification": AutoModelForTokenClassification, 37 | "language-modeling": AutoModelWithLMHead, 38 | } 39 | 40 | 41 | def set_seed(args): 42 | random.seed(args.seed) 43 | np.random.seed(args.seed) 44 | torch.manual_seed(args.seed) 45 | if args.n_gpu > 0: 46 | torch.cuda.manual_seed_all(args.seed) 47 | 48 | 49 | class BaseTransformer(pl.LightningModule): 50 | def __init__(self, hparams, num_labels=None, mode="base"): 51 | "Initialize a model." 52 | 53 | super(BaseTransformer, self).__init__() 54 | self.hparams = hparams 55 | self.hparams.model_type = self.hparams.model_type.lower() 56 | config = AutoConfig.from_pretrained( 57 | self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path, 58 | **({"num_labels": num_labels} if num_labels is not None else {}), 59 | cache_dir=self.hparams.cache_dir if self.hparams.cache_dir else None, 60 | ) 61 | tokenizer = AutoTokenizer.from_pretrained( 62 | self.hparams.tokenizer_name if self.hparams.tokenizer_name else self.hparams.model_name_or_path, 63 | do_lower_case=self.hparams.do_lower_case, 64 | cache_dir=self.hparams.cache_dir if self.hparams.cache_dir else None, 65 | ) 66 | model = MODEL_MODES[mode].from_pretrained( 67 | self.hparams.model_name_or_path, 68 | from_tf=bool(".ckpt" in self.hparams.model_name_or_path), 69 | config=config, 70 | cache_dir=self.hparams.cache_dir if self.hparams.cache_dir else None, 71 | ) 72 | self.config, self.tokenizer, self.model = config, tokenizer, model 73 | 74 | def is_logger(self): 75 | return self.trainer.proc_rank <= 0 76 | 77 | def configure_optimizers(self): 78 | "Prepare optimizer and schedule (linear warmup and decay)" 79 | 80 | model = self.model 81 | no_decay = ["bias", "LayerNorm.weight"] 82 | optimizer_grouped_parameters = [ 83 | { 84 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 85 | "weight_decay": self.hparams.weight_decay, 86 | }, 87 | { 88 | "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 89 | "weight_decay": 0.0, 90 | }, 91 | ] 92 | optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon) 93 | self.opt = optimizer 94 | return [optimizer] 95 | 96 | def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None): 97 | if self.trainer.use_tpu: 98 | xm.optimizer_step(optimizer) 99 | else: 100 | optimizer.step() 101 | optimizer.zero_grad() 102 | self.lr_scheduler.step() 103 | 104 | def get_tqdm_dict(self): 105 | tqdm_dict = {"loss": "{:.3f}".format(self.trainer.avg_loss), "lr": self.lr_scheduler.get_last_lr()[-1]} 106 | 107 | return tqdm_dict 108 | 109 | def test_step(self, batch, batch_nb): 110 | return self.validation_step(batch, batch_nb) 111 | 112 | def test_end(self, outputs): 113 | return self.validation_end(outputs) 114 | 115 | def train_dataloader(self): 116 | train_batch_size = self.hparams.train_batch_size 117 | dataloader = self.load_dataset("train", train_batch_size) 118 | 119 | t_total = ( 120 | (len(dataloader.dataset) // (train_batch_size * max(1, 
self.hparams.n_gpu))) 121 | // self.hparams.gradient_accumulation_steps 122 | * float(self.hparams.num_train_epochs) 123 | ) 124 | scheduler = get_linear_schedule_with_warmup( 125 | self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total 126 | ) 127 | self.lr_scheduler = scheduler 128 | return dataloader 129 | 130 | def val_dataloader(self): 131 | return self.load_dataset("dev", self.hparams.eval_batch_size) 132 | 133 | def test_dataloader(self): 134 | return self.load_dataset("test", self.hparams.eval_batch_size) 135 | 136 | def _feature_file(self, mode): 137 | return os.path.join( 138 | self.hparams.data_dir, 139 | "cached_{}_{}_{}".format( 140 | mode, 141 | list(filter(None, self.hparams.model_name_or_path.split("/"))).pop(), 142 | str(self.hparams.max_seq_length), 143 | ), 144 | ) 145 | 146 | @staticmethod 147 | def add_model_specific_args(parser, root_dir): 148 | parser.add_argument( 149 | "--model_type", 150 | default=None, 151 | type=str, 152 | required=True, 153 | help="Model type selected in the list: " + ", ".join(MODEL_CLASSES), 154 | ) 155 | parser.add_argument( 156 | "--model_name_or_path", 157 | default=None, 158 | type=str, 159 | required=True, 160 | help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS), 161 | ) 162 | parser.add_argument( 163 | "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name" 164 | ) 165 | parser.add_argument( 166 | "--tokenizer_name", 167 | default="", 168 | type=str, 169 | help="Pretrained tokenizer name or path if not the same as model_name", 170 | ) 171 | parser.add_argument( 172 | "--cache_dir", 173 | default="", 174 | type=str, 175 | help="Where do you want to store the pre-trained models downloaded from s3", 176 | ) 177 | parser.add_argument( 178 | "--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model." 179 | ) 180 | parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.") 181 | parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.") 182 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.") 183 | parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.") 184 | parser.add_argument( 185 | "--num_train_epochs", default=3, type=int, help="Total number of training epochs to perform." 
186 | ) 187 | 188 | parser.add_argument("--train_batch_size", default=32, type=int) 189 | parser.add_argument("--eval_batch_size", default=32, type=int) 190 | 191 | 192 | class LoggingCallback(pl.Callback): 193 | def on_validation_end(self, trainer, pl_module): 194 | logger.info("***** Validation results *****") 195 | if pl_module.is_logger(): 196 | metrics = trainer.callback_metrics 197 | # Log results 198 | for key in sorted(metrics): 199 | if key not in ["log", "progress_bar"]: 200 | logger.info("{} = {}\n".format(key, str(metrics[key]))) 201 | 202 | def on_test_end(self, trainer, pl_module): 203 | logger.info("***** Test results *****") 204 | 205 | if pl_module.is_logger(): 206 | metrics = trainer.callback_metrics 207 | 208 | # Log and save results to file 209 | output_test_results_file = os.path.join(pl_module.hparams.output_dir, "test_results.txt") 210 | with open(output_test_results_file, "w") as writer: 211 | for key in sorted(metrics): 212 | if key not in ["log", "progress_bar"]: 213 | logger.info("{} = {}\n".format(key, str(metrics[key]))) 214 | writer.write("{} = {}\n".format(key, str(metrics[key]))) 215 | 216 | 217 | def add_generic_args(parser, root_dir): 218 | parser.add_argument( 219 | "--output_dir", 220 | default=None, 221 | type=str, 222 | required=True, 223 | help="The output directory where the model predictions and checkpoints will be written.", 224 | ) 225 | 226 | parser.add_argument( 227 | "--fp16", 228 | action="store_true", 229 | help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit", 230 | ) 231 | 232 | parser.add_argument( 233 | "--fp16_opt_level", 234 | type=str, 235 | default="O1", 236 | help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." 237 | "See details at https://nvidia.github.io/apex/amp.html", 238 | ) 239 | 240 | parser.add_argument("--n_gpu", type=int, default=1) 241 | parser.add_argument("--n_tpu_cores", type=int, default=0) 242 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 243 | parser.add_argument("--do_train", action="store_true", help="Whether to run training.") 244 | parser.add_argument("--do_predict", action="store_true", help="Whether to run predictions on the test set.") 245 | parser.add_argument( 246 | "--gradient_accumulation_steps", 247 | type=int, 248 | default=1, 249 | help="Number of updates steps to accumulate before performing a backward/update pass.", 250 | ) 251 | 252 | parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.") 253 | parser.add_argument("--server_port", type=str, default="", help="For distant debugging.") 254 | parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") 255 | 256 | 257 | def generic_train(model, args): 258 | # init model 259 | set_seed(args) 260 | 261 | # Setup distant debugging if needed 262 | if args.server_ip and args.server_port: 263 | # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script 264 | import ptvsd 265 | 266 | print("Waiting for debugger attach") 267 | ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True) 268 | ptvsd.wait_for_attach() 269 | 270 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train: 271 | raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir)) 272 | 273 | checkpoint_callback = pl.callbacks.ModelCheckpoint( 274 | filepath=args.output_dir, 
prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=5 275 | ) 276 | 277 | train_params = dict( 278 | accumulate_grad_batches=args.gradient_accumulation_steps, 279 | gpus=args.n_gpu, 280 | max_epochs=args.num_train_epochs, 281 | early_stop_callback=False, 282 | gradient_clip_val=args.max_grad_norm, 283 | checkpoint_callback=checkpoint_callback, 284 | callbacks=[LoggingCallback()], 285 | ) 286 | 287 | if args.fp16: 288 | train_params["use_amp"] = args.fp16 289 | train_params["amp_level"] = args.fp16_opt_level 290 | 291 | if args.n_tpu_cores > 0: 292 | global xm 293 | import torch_xla.core.xla_model as xm 294 | 295 | train_params["num_tpu_cores"] = args.n_tpu_cores 296 | train_params["gpus"] = 0 297 | 298 | if args.n_gpu > 1: 299 | train_params["distributed_backend"] = "ddp" 300 | 301 | trainer = pl.Trainer(**train_params) 302 | 303 | if args.do_train: 304 | trainer.fit(model) 305 | 306 | return trainer 307 | --------------------------------------------------------------------------------