├── .gitignore ├── LICENSE ├── README.md ├── data_processing ├── __pycache__ │ ├── PubMedClient.cpython-36.pyc │ └── PubMedClient.cpython-37.pyc ├── data │ ├── bioasq_pubmed_articles.json │ └── medinfo_collection.json ├── prepare_training_data.py ├── prepare_validation_data.py ├── process_bioasq.py └── process_medinfo.py ├── evaluation ├── collection_statistics.py ├── data │ ├── bart │ │ └── chiqa_eval │ │ │ └── .gitignore │ ├── baselines │ │ └── chiqa_eval │ │ │ └── .gitignore │ ├── pointer_generator │ │ └── chiqa_eval │ │ │ └── .gitignore │ └── sentence_classifier │ │ └── chiqa_eval │ │ └── .gitignore ├── results │ └── .gitignore └── summarization_evaluation.py ├── models ├── bart │ ├── bart_config │ │ └── .gitignore │ ├── finetune_bart_bioasq.sh │ ├── make_datafiles_for_bart.py │ ├── process_bioasq_data.sh │ ├── run_chiqa.sh │ ├── run_inference_medsumm.py │ └── train.py ├── baselines.py ├── bilstm │ ├── __pycache__ │ │ ├── data.cpython-36.pyc │ │ ├── data.cpython-37.pyc │ │ ├── model.cpython-36.pyc │ │ └── model.cpython-37.pyc │ ├── data.py │ ├── model.py │ ├── requirements.txt │ ├── run_chiqa.sh │ ├── run_classifier.py │ └── train_sentence_classifier.sh └── pointer_generator │ ├── LICENSE.txt │ ├── __init__.py │ ├── __pycache__ │ ├── attention_decoder.cpython-36.pyc │ ├── batcher.cpython-36.pyc │ ├── beam_search.cpython-36.pyc │ ├── data.cpython-36.pyc │ ├── decode.cpython-36.pyc │ ├── model.cpython-36.pyc │ └── util.cpython-36.pyc │ ├── attention_decoder.py │ ├── batcher.py │ ├── beam_search.py │ ├── bioasq_abs2summ_vocab │ ├── data.py │ ├── decode.py │ ├── eval_medsumm.sh │ ├── inspect_checkpoint.py │ ├── make_asumm_pg_vocab.py │ ├── model.py │ ├── requirements.txt │ ├── run_chiqa.sh │ ├── run_medsumm.py │ ├── submit_sbatch_eval.sh │ ├── train_medsumm.sh │ └── util.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | data_processing/data/BioASQ-training7b 2 | data_processing/data/MedQuAD-master 3 | data_processing/data/MedInfo2019-QA-Medications.xlsx 4 | data_processing/data/*multi*.json 5 | data_processing/data/*single*.json 6 | data_processing/data/question_driven_answer_summarization_primary_dataset.json 7 | data_processing/data/bioasq_collection.json 8 | data_processing/data/bioasq_ideal_answers.json 9 | data_processing/data/bioasq_snippets.json 10 | data_processing/data/medinfo_section*.json 11 | data_processing/data/bioasq_abs2summ_binary_sent_classification_training.json 12 | data_processing/data/bioasq_abs2summ_training_data_with* 13 | data_processing/data/bioasq_snippets.json 14 | models/pointer_generator/bioasq_abs2summ* 15 | models/bart/dict.txt 16 | models/bart/apex 17 | models/bart/fairseq* 18 | models/bart/encoder.json 19 | models/bart/vocab.bpe 20 | models/bart/checkpoints* 21 | models/bart/bart.large 22 | models/bart/bart_config/without_question/* 23 | models/bart/bart_config/with_question/* 24 | models/bilstm/medsumm* 25 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Facebook, Inc. and its affiliates. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Question-Driven Summarization of Answers to Consumer Health Questions 2 | This repository contains the code to process the data and run the answer summarization systems presented in the paper Question-Driven Summarization of Answers to Consumer Health Questions, available at https://www.nature.com/articles/s41597-020-00667-z. 3 | 4 | If you are interested in just downloading the data, please refer to https://doi.org/10.17605/OSF.IO/FYG46. However, if you are interested in repeating the experiments reported in the paper, clone this repository and move the data found at https://doi.org/10.17605/OSF.IO/FYG46 to the evaluation/data directory. 5 | 6 | ## Environments 7 | All instructions provided here have been tested on a Linux operating system. To train the models and run the experiments, you will need to set up a few environments with anaconda: data processing and evaluation; the BiLSTM; BART; and the Pointer-Generator. Since we are going to be processing the data first, create the following environment to install the data processing and evaluation dependencies: 8 | ``` 9 | conda create -n qdriven_env python=3.7 10 | conda activate qdriven_env 11 | pip install -r requirements.txt 12 | ``` 13 | The requirements.txt file is found in the base directory of this repository. 14 | The spacy tokenizer model will be needed later on as well: 15 | ``` 16 | python -m spacy download en_core_web_sm 17 | ``` 18 | And because py-rouge uses nltk, we also need the nltk punkt tokenizer: 19 | ``` 20 | python 21 | >>> import nltk 22 | >>> nltk.download('punkt') 23 | ``` 24 | There are more details regarding the environments required to train each model in the following sections. 25 | 26 | 27 | ## Answer Summarization 28 | In the models directory, there are six systems: 29 | Deep learning: 30 | 1. BiLSTM 31 | 2. Pointer-Generator 32 | 3. BART 33 | 34 | Baselines: 35 | 1. LEAD-k 36 | 2. Random-k sentences 37 | 3. k ROUGE sentences 38 | 39 | 40 | ### Running baselines 41 | Running the baselines is simple and can be done while the qdriven_env environment is active: 42 | ``` 43 | python baselines.py --dataset=chiqa 44 | ``` 45 | This will run the baseline summarization methods on the two summarization tasks reported in the paper, as well as on the shorter passages. k (the number of sentences selected by the baselines) can be changed in the script.
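For reference, the LEAD-k baseline simply keeps the first k sentences of the source document. A minimal sketch of the idea (not the implementation in models/baselines.py, which also handles the dataset format and the other baselines) might look like:
```
import spacy

nlp = spacy.load("en_core_web_sm")

def lead_k(document, k=3):
    # Keep the first k sentences of the document as the candidate summary
    sentences = [sent.text.strip() for sent in nlp(document).sents]
    return " ".join(sentences[:k])
```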
46 | 47 | 48 | ### Running deep learning 49 | The following code is organized to train and run inference with all models first, and then use the summarization_evaluation.py script to evaluate all results at once. This section describes the steps for training and inference. 50 | 51 | 52 | #### Training Preprocessing 53 | First prepare the validation data for the Pointer-Generator and BART: 54 | ``` 55 | python prepare_validation_data.py --pg --bart 56 | ``` 57 | It is optional to include the --add-q option if you are interested in training models for question-driven summarization. 58 | 59 | To create the training data, the BioASQ data for training first has to be acquired. To do so, you have to register for an account at http://bioasq.org/participate. 60 | 61 | In the participants area of the website, you can find a list of all the datasets previously released for the BioASQ challenge. Download the 7b version of the task. Once the BioASQ data has been downloaded and unzipped, it should be placed in the data_processing/data directory of the cloned github repository, so that the path relative to the data_processing directory looks like ```data_processing/data/BioASQ-training7b/BioASQ-training7b/training7b.json```. Note that we used version 7b of BioASQ for training and testing. You are welcome to experiment with 8b or newer, but you will have to fix the paths in the code. 62 | 63 | Once the data is in the correct place, run the following scripts: 64 | ``` 65 | python process_bioasq.py -p 66 | python prepare_training_data.py -bt --bart-bioasq --bioasq-sent 67 | ``` 68 | This will prepare separate training sets for the three deep learning models. Include the ```--add-q``` option to create additional datasets with the question concatenated to the beginning of the documents, for question-driven summarization. This step will take a while to finish. Once it is done, you are ready for training and inference. 69 | 70 | 71 | #### BiLSTM (sentence classification) 72 | You will first need to set up a tensorflow2-gpu environment and install the dependencies for the model: 73 | ``` 74 | conda create -n tf2_env tensorflow-gpu=2.0 python=3.7 75 | conda activate tf2_env 76 | pip install -r requirements.txt 77 | python -m spacy download en_core_web_sm 78 | ``` 79 | Use the requirements.txt file located in the models/bilstm directory. 80 | Once the environment is set up, you are ready for training, which is simpler than for the generative models described in the following sections: 81 | ``` 82 | bash train_sentence_classifier.sh 83 | ``` 84 | The training script will automatically save the checkpoint that performs best on the validation set. The training will end after 10 epochs. The training script is configured for TensorBoard, and you can monitor the loss by running tensorboard in the medsumm_bioasq_abs2summ directory (an example command is given at the end of this section). 85 | 86 | Once the BiLSTM is trained, the following script is provided to run the model on all single-document summarization tasks in MEDIQA-AnS: 87 | ``` 88 | bash run_chiqa.sh 89 | ``` 90 | You are now able to evaluate the BiLSTM output with the evaluation script. During inference, the run_classifier.py script will also create output files that can be used as input for inference with the Pointer-Generator or BART. k can be changed in the run_chiqa.sh script to experiment with passing the top k sentences to the generative models.
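To monitor the BiLSTM loss curves mentioned above, a typical TensorBoard invocation (run from the models/bilstm directory; adjust the log directory if your experiment is saved under a different name) is:
```
tensorboard --logdir medsumm_bioasq_abs2summ
```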
91 | 92 | 93 | #### BART 94 | Download BART into the models/bart directory in this repository. 95 | ``` 96 | wget https://dl.fbaipublicfiles.com/fairseq/models/bart.large.tar.gz 97 | tar -xzvf bart.large.tar.gz 98 | ``` 99 | Navigate to the models/bart directory and prepare an environment for BART: 100 | ``` 101 | conda create -n pytorch_env python=3.7 pytorch torchvision cudatoolkit=10.1 -c pytorch 102 | conda activate pytorch_env 103 | pip install -r requirements.txt 104 | ``` 105 | To use automatic mixed precision (optional), follow these steps once pytorch is installed: 106 | ``` 107 | conda install -n pytorch_env -c anaconda nccl 108 | git clone https://github.com/NVIDIA/apex 109 | cd apex 110 | pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--deprecated_fused_adam" ./ 111 | ``` 112 | These instructions are provided in the main fairseq readme (https://github.com/pytorch/fairseq), but we have provided them here in condensed form. Note that to install apex, first make sure your GCC compiler is up-to-date. 113 | Once these dependencies have been installed, you are ready to install fairseq. This requires installing an editable version of an earlier commit of the repository. Navigate back to the models/bart directory of this repo and run: 114 | ``` 115 | git clone https://github.com/pytorch/fairseq 116 | cd fairseq 117 | git checkout 43cf9c977b8470ec493cc32d248fdcd9a984a9f6 118 | pip install --editable . 119 | ``` 120 | Because the fairseq repository contains research projects under continuous development, the commit checked out here should be used to recreate the results; using a current version of fairseq may require more troubleshooting. Once you have fairseq installed and BART downloaded to the bart directory, there are a few steps you have to take to get the BioASQ data into a suitable format for finetuning BART. 121 | First, run 122 | ``` 123 | bash process_bioasq_data.sh -b -f without_question 124 | ``` 125 | This will prepare the byte-pair encodings and run the fairseq preprocessing. with_question will prepare data for question-driven summarization, without_question for plain summarization. Once the processing is complete, you can finetune BART with 126 | ``` 127 | bash finetune_bart_bioasq.sh without_question 128 | ``` 129 | If you have been testing question-driven summarization, include with_question instead. The larger your computing cluster, the faster you will be able to train. For the experiments presented in the paper, we trained BART for two days on three V100-SXM2 GPUs with 32GB of memory each. The bash script provided here is currently configured to run with one GPU; however, the fairseq library supports multi-gpu training. 130 | 131 | Once you have finetuned the model, run inference on the MEDIQA-AnS dataset with 132 | ``` 133 | bash run_chiqa.sh without_question 134 | ``` 135 | Or use with_question if you have trained the appropriate model. 136 | 137 | 138 | #### Pointer-Generator 139 | Navigate to the models/pointer_generator directory and create a new environment: 140 | ``` 141 | conda create -n tf1_env python=3.6 142 | conda activate tf1_env 143 | wget https://files.pythonhosted.org/packages/cb/4d/c9c4da41c6d7b9a4949cb9e53c7032d7d9b7da0410f1226f7455209dd962/tensorflow_gpu-1.2.0-cp36-cp36m-manylinux1_x86_64.whl 144 | pip install tensorflow_gpu-1.2.0-cp36-cp36m-manylinux1_x86_64.whl 145 | pip install -r requirements.txt 146 | python -m spacy download en_core_web_sm 147 | ``` 148 | The tensorflow 1.2.0 GPU build is only available via direct download from pypi.org, hence the use of ```wget``` first.
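After installing, an optional sanity check confirms that the tensorflow-gpu build imports correctly (the device list printed will depend on your machine, and a GPU will only show up once the CUDA and cuDNN versions noted below are configured):
```
python -c "import tensorflow as tf; print(tf.__version__)"
python -c "from tensorflow.python.client import device_lib; print(device_lib.list_local_devices())"
```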
To train the model, you will have to install cuDNN 5 and CUDA 8. Once these are configured on your machine, you are ready for training. 149 | The Python 3 version of the Pointer-Generator code from https://github.com/becxer/pointer-generator/ (forked from https://github.com/abisee/pointer-generator) is provided in the models/pointer_generator directory here. The code has been customized to support the answer summarization data processing steps, involving changes to data.py, batcher.py, decode.py, and run_medsumm.py (the renamed run_summarization.py). However, the model (in model.py) remains the same. 150 | 151 | To use the Pointer-Generator, from the pointer_generator directory you will have to run 152 | ``` 153 | python make_asumm_pg_vocab.py --vocab_path=bioasq_abs2summ_vocab --data_file=../../data_processing/data/bioasq_abs2summ_training_data_without_question.json 154 | ``` 155 | first, to prepare the BioASQ vocab. If you are focusing on question-driven summarization, provide that dataset instead. 156 | 157 | Then, to train, you will need to run two jobs: one to train, and the other to evaluate the checkpoints simultaneously. Run these commands independently, on two different GPUs: 158 | ``` 159 | bash train_medsumm.sh without_question 160 | bash eval_medsumm.sh without_question 161 | ``` 162 | If you have access to a computing cluster that uses slurm, you may find it useful to use sbatch to submit these jobs. 163 | You will have to monitor the training of the Pointer-Generator via tensorboard and manually end the job once the loss has satisfactorily converged. The checkpoint that performs best on the MedInfo validation set will be saved to variable-name-of-experiment-directory/eval/checkpoint_best. 164 | 165 | Once it is properly trained (the MEDIQA-AnS paper reports results after 10,000 training steps), run inference on the full text of the web pages with 166 | ``` 167 | bash run_chiqa.sh 168 | ``` 169 | The question-driven option can be changed in the bash script. Note that the single-pass decoding in the original Pointer-Generator code is quite slow, and it will unfortunately take approximately 45 minutes per dataset to perform inference. 170 | Other experiments can be run if you have configured the bash script to generate summaries for the passages or multi-document datasets as well. 171 | 172 | 173 | ### Evaluation 174 | Once the models are trained and the baselines have been run on the summarization datasets you are interested in evaluating, activate the qdriven_env environment again and navigate to the evaluation directory. To run the evaluation script on the summarization models' predictions, you have a few options. 175 | For comparing all models on extractive and abstractive summaries of web pages: 176 | ``` 177 | python summarization_evaluation.py --dataset=chiqa --bleu --evaluate-models 178 | ``` 179 | Or, to compare the two versions of BART (the question-driven approach and the version without questions, if you have trained both), run 180 | ``` 181 | python summarization_evaluation.py --dataset=chiqa --bleu --q-driven 182 | ``` 183 | The same question-driven test can be applied to the Pointer-Generator as well, if you have trained the appropriate model with the correctly formatted question-driven dataset. 184 | 185 | More details about the options, such as saving scores per summary to file, or calculating Wilcoxon p-values, are described in the script.
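The summarization_evaluation.py script computes ROUGE (and optionally BLEU) over the prediction files automatically. If you just want to sanity-check a single prediction against a reference with py-rouge, a minimal example mirroring the evaluator configuration used elsewhere in this repository (see data_processing/prepare_training_data.py; the example sentences are made up) is:
```
python
>>> import rouge
>>> evaluator = rouge.Rouge(metrics=['rouge-l'], max_n=3, limit_length=False, length_limit_type='words', apply_avg=False, apply_best=True, alpha=1, weight_factor=1.2, stemming=False)
>>> prediction = "Metformin is usually the first medication prescribed for type 2 diabetes."
>>> reference = "Metformin is typically the first-line drug for treating type 2 diabetes."
>>> evaluator.get_scores(prediction, reference)['rouge-l']['f']
```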
186 | 187 | If you are interested in generating the statistics describing the collection, run 188 | ```python collection_statistics.py --tokenize``` 189 | in the evaluation directory. This will generate the statistics reported in the paper, along with some additional technical detail. 190 | 191 | 192 | That's it! Thank you for using this code, and please contact us if you find any issues with the repository or have questions about summarization. If you publish work related to this project, please cite 193 | ``` 194 | @article{saverysumm, 195 | title={Question-Driven Summarization of Answers to Consumer Health Questions}, 196 | author={Max Savery and Asma {Ben Abacha} and Soumya Gayen and Dina Demner{-}Fushman}, 197 | journal = {arXiv e-prints}, 198 | month = {May}, 199 | year={2020}, 200 | eprint={2005.09067}, 201 | archivePrefix={arXiv}, 202 | primaryClass={cs.CL}, 203 | url={https://arxiv.org/abs/2005.09067} 204 | } 205 | ``` -------------------------------------------------------------------------------- /data_processing/__pycache__/PubMedClient.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/saverymax/qdriven-chiqa-summarization/257a00133869db47807b9dd10761a6dd3aa15306/data_processing/__pycache__/PubMedClient.cpython-36.pyc -------------------------------------------------------------------------------- /data_processing/__pycache__/PubMedClient.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/saverymax/qdriven-chiqa-summarization/257a00133869db47807b9dd10761a6dd3aa15306/data_processing/__pycache__/PubMedClient.cpython-37.pyc -------------------------------------------------------------------------------- /data_processing/prepare_training_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Module for classes to prepare training datasets 3 | 4 | Prepare training data for pointer generator (add <s> tags and </s> tags to summaries): 5 | Prepare bioasq: 6 | python prepare_training_data.py -bt 7 | Or with the question: 8 | python prepare_training_data.py -bt --add-q 9 | 10 | Prepare training data for bart, with or without the question appended to the beginning of the abstract text: 11 | python prepare_training_data.py --bart-bioasq --add-q 12 | python prepare_training_data.py --bart-bioasq 13 | 14 | Prepare training data for sentence classification: 15 | python prepare_training_data.py --bioasq-sent 16 | """ 17 | 18 | import json 19 | import argparse 20 | import re 21 | import os 22 | 23 | from sklearn.utils import shuffle as sk_shuffle 24 | import rouge 25 | import spacy 26 | 27 | 28 | def get_args(): 29 | """ 30 | Argument definitions 31 | """ 32 | parser = argparse.ArgumentParser(description="Arguments for data exploration") 33 | parser.add_argument("-t", 34 | dest="tag_sentences", 35 | action="store_true", 36 | help="tag the sentences with <s> and </s>, for use with the pointer generator network") 37 | parser.add_argument("-e", 38 | dest="summ_end_tag", 39 | action="store_true", 40 | help="Add the summ end tag to the end of the summaries.
This was observed not to improve performance on the MedInfo evaluation set") 41 | parser.add_argument("-b", 42 | dest="bioasq_pg", 43 | action="store_true", 44 | help="Make the bioasq training set for pointer generator") 45 | parser.add_argument("--bioasq-sent", 46 | dest="bioasq_sc", 47 | action="store_true", 48 | help="Make the bioasq training set for sentence classification") 49 | parser.add_argument("--bart-bioasq", 50 | dest="bart_bioasq", 51 | action="store_true", 52 | help="Prepare the bioasq training set for bart") 53 | parser.add_argument("--add-q", 54 | dest="add_q", 55 | action="store_true", 56 | help="Concatenate the question to the beginning of the text. Currently only implemented as an option for bart data and the bioasq abs2summ data") 57 | return parser 58 | 59 | 60 | class BioASQ(): 61 | """ 62 | Class to create various versions of the BioASQ training dataset 63 | """ 64 | 65 | def __init__(self): 66 | """ 67 | Initiate spacy 68 | """ 69 | self.nlp = spacy.load('en_core_web_sm') 70 | self.Q_END = " [QUESTION?] " 71 | self.SUMM_END = " [END]" 72 | self.ARTICLE_END = " [ARTICLE_SEP] " 73 | 74 | def format_summary_sentences(self, summary): 75 | """ 76 | Split summary into sentences and add the sentence tags <s> and </s> to the strings 77 | """ 78 | tokenized_abs = self.nlp(summary) 79 | summary = " ".join(["<s> {s} </s>".format(s=s.text.strip()) for s in tokenized_abs.sents]) 80 | return summary 81 | 82 | def _load_bioasq(self): 83 | """ 84 | Load bioasq collection generated in process_bioasq.py 85 | """ 86 | with open("data/bioasq_collection.json", "r", encoding="utf8") as f: 87 | data = json.load(f) 88 | return data 89 | 90 | def create_abstract2snippet_dataset(self): 91 | """ 92 | Generate the bioasq abstract to snippet (1 to 1) training dataset. This function creates data using the same keys, summary and articles, for each summary-article pair 93 | as the MedlinePlus training data does. This allows compatibility with the answer summarization data loading function in the pointer generator network. 94 | 95 | This is currently the only dataset with the option to include the question. Since this dataset works the best, the add_q option was not added to the others. 96 | """ 97 | bioasq_collection = self._load_bioasq() 98 | training_data_dict = {} 99 | snip_id = 0 100 | for i, q in enumerate(bioasq_collection): 101 | question = q 102 | for snippet in bioasq_collection[q]['snippets']: 103 | training_data_dict[snip_id] = {} 104 | if args.summ_end_tag: 105 | snippet_text = snippet['snippet'] + self.SUMM_END 106 | else: 107 | snippet_text = snippet['snippet'] 108 | if args.tag_sentences: 109 | snippet_text = self.format_summary_sentences(snippet_text) 110 | training_data_dict[snip_id]['summary'] = snippet_text 111 | # Add the question with a special question separator token to the beginning of the article.
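# For illustration (made-up example), with --add-q the article passed to the model becomes:
#   "what causes anemia? [QUESTION?] Anemia is a condition in which ..."
# while the snippet text above remains the target summary.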
112 | abstract = snippet['article'] 113 | with_question = "_without_question" 114 | if args.add_q: 115 | abstract = question + self.Q_END + abstract 116 | with_question = "_with_question" 117 | training_data_dict[snip_id]['articles'] = abstract 118 | training_data_dict[snip_id]['question'] = question 119 | snip_id += 1 120 | 121 | with open("data/bioasq_abs2summ_training_data{}.json".format(with_question), "w", encoding="utf=8") as f: 122 | json.dump(training_data_dict, f, indent=4) 123 | 124 | def calculate_sentence_level_rouge(self, snip_sen, abs_sen, evaluator): 125 | """ 126 | For each pair of sentences, calculate rouge score 127 | """ 128 | rouge_score = evaluator.get_scores(abs_sen, snip_sen)['rouge-l']['f'] 129 | return rouge_score 130 | 131 | def create_binary_sentence_classification_dataset_with_rouge(self): 132 | """ 133 | Create a dataset for training a sentence classification model, where the binary y labels are assigned based on the 134 | best rouge score for a sentence in the article when compared to each sentence in the summary 135 | """ 136 | # Initiate rouge evaluator 137 | evaluator = rouge.Rouge(metrics=['rouge-l'], 138 | max_n=3, 139 | limit_length=False, 140 | length_limit_type='words', 141 | apply_avg=False, 142 | apply_best=True, 143 | alpha=1, 144 | weight_factor=1.2, 145 | stemming=False) 146 | 147 | bioasq_collection = self._load_bioasq() 148 | training_data_dict = {}# 149 | snip_id = 0 150 | for i, q in enumerate(bioasq_collection): 151 | question = q 152 | for snippet in bioasq_collection[q]['snippets']: 153 | training_data_dict[snip_id] = {} 154 | labels = [] 155 | # Sentencize snippet 156 | snippet_text = snippet['snippet'] 157 | tokenized_snip = self.nlp(snippet_text) 158 | snippet_sentences = [s.text.strip() for s in tokenized_snip.sents] 159 | # Sentencize abstract 160 | abstract_text = snippet['article'] 161 | tokenized_abs = self.nlp(abstract_text) 162 | abstract_sentences = [s.text.strip() for s in tokenized_abs.sents] 163 | rouge_scores = [] 164 | for abs_sen in abstract_sentences: 165 | best_rouge = 0 166 | for snip_sen in snippet_sentences: 167 | rouge_score = self.calculate_sentence_level_rouge(snip_sen, abs_sen, evaluator) 168 | if best_rouge < rouge_score: 169 | best_rouge = rouge_score 170 | if best_rouge > .9: 171 | label = 1 172 | else: 173 | label = 0 174 | labels.append(label) 175 | training_data_dict[snip_id]['question'] = q 176 | training_data_dict[snip_id]['sentences'] = abstract_sentences 177 | training_data_dict[snip_id]['labels'] = labels 178 | snip_id += 1 179 | 180 | with open("data/bioasq_abs2summ_binary_sent_classification_training.json", "w", encoding="utf=8") as f: 181 | json.dump(training_data_dict, f, indent=4) 182 | # For each sentence in each abstract, compare it to each sentence in answer. Record the best rouge score. 
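# Illustration of a single record written above (made-up values):
# {"question": "...", "sentences": ["first abstract sentence", "second abstract sentence"], "labels": [1, 0]}
# where labels[i] is 1 when abstract sentence i reaches a ROUGE-L F-score above 0.9 against some summary sentence.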
183 | 184 | def create_data_for_bart(self): 185 | """ 186 | Write the train and val data to file so that the processor and tokenizer for bart will read it, as per fairseqs design 187 | """ 188 | bioasq_collection = self._load_bioasq() 189 | # Additional string is added to the question of the beginning of the abstract text 190 | if args.add_q: 191 | q_name = "with_question" 192 | else: 193 | q_name = "without_question" 194 | 195 | # Open medinfo data preprocessed in prepare_validation_data.py 196 | with open("data/medinfo_section2answer_validation_data_{}.json".format(q_name), "r", encoding="utf-8") as f: 197 | medinfo_val = json.load(f) 198 | 199 | try: 200 | os.mkdir("../models/bart/bart_config/{}".format(q_name)) 201 | except FileExistsError: 202 | print("Directory ", q_name , " already exists") 203 | 204 | train_src = open("../models/bart/bart_config/{q}/bart.train_{q}.source".format(q=q_name), "w", encoding="utf8") 205 | train_tgt = open("../models/bart/bart_config/{q}/bart.train_{q}.target".format(q=q_name), "w", encoding="utf8") 206 | val_src = open("../models/bart/bart_config/{q}/bart.val_{q}.source".format(q=q_name), "w", encoding="utf8") 207 | val_tgt = open("../models/bart/bart_config/{q}/bart.val_{q}.target".format(q=q_name), "w", encoding="utf8") 208 | snippets_list = [] 209 | abstracts_list = [] 210 | for i, q in enumerate(bioasq_collection): 211 | for snippet in bioasq_collection[q]['snippets']: 212 | snippet_text = snippet['snippet'].strip() 213 | abstract_text = snippet['article'].strip() 214 | # Why is there whitespace in the question? 215 | question = q.replace("\n", " ") 216 | if args.add_q: 217 | abstract_text = question + self.Q_END + abstract_text 218 | abstracts_list.append(abstract_text) 219 | snippets_list.append(snippet_text) 220 | 221 | snp_cnt = 0 222 | print("Shuffling data") 223 | snippets_list, abstracts_list = sk_shuffle(snippets_list, abstracts_list, random_state=13) 224 | for snippet_text, abstract_text in zip(snippets_list, abstracts_list): 225 | snp_cnt += 1 226 | train_src.write("{}\n".format(abstract_text)) 227 | train_tgt.write("{}\n".format(snippet_text)) 228 | 229 | for q_id in medinfo_val: 230 | # The prepared medinfo data may have sentence tags in it for pointer generator. 231 | # There is an option in the prepare_validation_data.py script to not tag the data, 232 | # but it is easier to keep track of the datasets to just remove the tags here. 
233 | summ = medinfo_val[q_id]['summary'].strip() 234 | summ = summ.replace("<s>", "") 235 | summ = summ.replace("</s>", "") 236 | articles = medinfo_val[q_id]['articles'].strip() 237 | val_src.write("{}\n".format(articles)) 238 | val_tgt.write("{}\n".format(summ)) 239 | 240 | train_src.close() 241 | train_tgt.close() 242 | val_src.close() 243 | val_tgt.close() 244 | 245 | # Make sure there were no funny newlines added 246 | train_src = open("../models/bart/bart_config/{q}/bart.train_{q}.source".format(q=q_name), "r", encoding="utf8").readlines() 247 | train_tgt = open("../models/bart/bart_config/{q}/bart.train_{q}.target".format(q=q_name), "r", encoding="utf8").readlines() 248 | val_src = open("../models/bart/bart_config/{q}/bart.val_{q}.source".format(q=q_name), "r", encoding="utf8").readlines() 249 | val_tgt = open("../models/bart/bart_config/{q}/bart.val_{q}.target".format(q=q_name), "r", encoding="utf8").readlines() 250 | print("Number of snippets: ", snp_cnt) 251 | assert len(train_src) == snp_cnt, len(train_src) 252 | assert len(train_tgt) == snp_cnt 253 | assert len(val_src) == len(medinfo_val) 254 | assert len(val_tgt) == len(medinfo_val) 255 | 256 | 257 | def process_data(): 258 | """ 259 | Save training data sets 260 | """ 261 | if args.bioasq_pg: 262 | BioASQ().create_abstract2snippet_dataset() 263 | if args.bioasq_sc: 264 | BioASQ().create_binary_sentence_classification_dataset_with_rouge() 265 | if args.bart_bioasq: 266 | BioASQ().create_data_for_bart() 267 | 268 | if __name__ == "__main__": 269 | global args 270 | args = get_args().parse_args() 271 | process_data() 272 | -------------------------------------------------------------------------------- /data_processing/prepare_validation_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Module for classes to prepare the validation dataset from the MedInfo dataset. 3 | Data format will be {key: {'question': question, 'summary': summ, 'articles': articles} ...} 4 | 5 | Additionally, format for question-driven summarization.
For example: 6 | python prepare_validation_data.py --pg --add-q 7 | """ 8 | 9 | 10 | import json 11 | import argparse 12 | import re 13 | 14 | import spacy 15 | import rouge 16 | 17 | def get_args(): 18 | """ 19 | Argument definitions 20 | """ 21 | parser = argparse.ArgumentParser(description="Arguments for data exploration") 22 | parser.add_argument("--pg", 23 | dest="pg", 24 | action="store_true", 25 | help="tag the sentences with <s> and </s>, for use with the pointer generator network") 26 | parser.add_argument("--bart", 27 | dest="bart", 28 | action="store_true", 29 | help="Prepare data for BART") 30 | parser.add_argument("--add-q", 31 | dest="add_q", 32 | action="store_true", 33 | help="Concatenate the question to the beginning of the text for question-driven summarization") 34 | 35 | return parser 36 | 37 | 38 | class MedInfo(): 39 | 40 | def __init__(self): 41 | """ 42 | Initiate class for processing the medinfo collection 43 | """ 44 | self.nlp = spacy.load('en_core_web_sm') 45 | if args.add_q: 46 | self.q_name = "_with_question" 47 | else: 48 | self.q_name = "_without_question" 49 | 50 | def _load_collection(self): 51 | """ 52 | Load medinfo collection prepared in the process_medinfo.py script 53 | """ 54 | with open("data/medinfo_collection.json", "r", encoding="utf-8") as f: 55 | medinfo = json.load(f) 56 | 57 | return medinfo 58 | 59 | def _format_summary_sentences(self, summary): 60 | """ 61 | Split summary into sentences and add the sentence tags <s> and </s> to the strings 62 | """ 63 | tokenized_abs = self.nlp(summary) 64 | summary = " ".join(["<s> {s} </s>".format(s=s.text.strip()) for s in tokenized_abs.sents]) 65 | return summary 66 | 67 | def save_section2answer_validation_data(self, tag_sentences): 68 | """ 69 | For questions that have a corresponding section-answer pair, save the 70 | validation data in the following format: 71 | {'question': {'summary': text, 'articles': text}} 72 | """ 73 | dev_dict = {} 74 | medinfo = self._load_collection() 75 | data_pair = 0 76 | Q_END = " [QUESTION?] " 77 | for i, question in enumerate(medinfo): 78 | try: 79 | # There may be multiple answers per question, but for the sake of the validation set, 80 | # just use the first answer 81 | if 'section_text' in medinfo[question][0]: 82 | article = medinfo[question][0]['section_text'] 83 | summary = medinfo[question][0]['answer'] 84 | # Stripping of whitespace was done in the processing script for section and full page 85 | # but not for answer or question 86 | summary = re.sub(r"\s+", " ", summary) 87 | question = re.sub(r"\s+", " ", question) 88 | if args.add_q: 89 | article = question + Q_END + article 90 | assert len(summary) <= (len(article) + 10) 91 | if tag_sentences: 92 | summary = self._format_summary_sentences(summary) 93 | tag_string = "_s-tags" 94 | else: 95 | tag_string = "" 96 | data_pair += 1 97 | dev_dict[i] = {'question': question, 'summary': summary, 'articles': article} 98 | except AssertionError: 99 | print("Answer longer than summary.
Skipping element") 100 | 101 | print("Number of page-section pairs:", data_pair) 102 | 103 | with open("data/medinfo_section2answer_validation_data{0}{1}.json".format(self.q_name, tag_string), "w", encoding="utf-8") as f: 104 | json.dump(dev_dict, f, indent=4) 105 | 106 | 107 | def process_data(): 108 | """ 109 | Main function for saving data 110 | """ 111 | # Run once for each 112 | if args.pg: 113 | MedInfo().save_section2answer_validation_data(tag_sentences=True) 114 | if args.bart: 115 | MedInfo().save_section2answer_validation_data(tag_sentences=False) 116 | 117 | if __name__ == "__main__": 118 | global args 119 | args = get_args().parse_args() 120 | process_data() 121 | -------------------------------------------------------------------------------- /data_processing/process_bioasq.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script for processing BioASQ json data and saving 3 | 4 | to download the pubmed articles for each snippet run 5 | python process_bioasq.py -d 6 | then to process the questions, answers, and snippets, run: 7 | python process_bioasq.py -p 8 | """ 9 | 10 | 11 | import json 12 | import sys 13 | import os 14 | import argparse 15 | import lxml.etree as le 16 | import glob 17 | from collections import Counter 18 | 19 | import numpy as np 20 | 21 | 22 | def get_args(): 23 | """ 24 | Get command line arguments 25 | """ 26 | 27 | parser = argparse.ArgumentParser(description="Arguments for data exploration") 28 | parser.add_argument("-p", 29 | dest="process", 30 | action="store_true", 31 | help="Process bioasq data") 32 | return parser 33 | 34 | 35 | class BioASQ(): 36 | """ 37 | Class for processing and saving BioASQ data 38 | """ 39 | 40 | def _load_bioasq(self): 41 | """ 42 | Load bioasq dataset 43 | """ 44 | with open("data/BioASQ-training7b/BioASQ-training7b/training7b.json", "r", encoding="ascii") as f: 45 | bioasq_questions = json.load(f)['questions'] 46 | return bioasq_questions 47 | 48 | def bioasq(self): 49 | """ 50 | Process BioASQ training data. Generate summary stats. Save questions, ideal answers, snippets, articles, and question types. 51 | """ 52 | bioasq_questions = self._load_bioasq() 53 | with open("data/bioasq_pubmed_articles.json", "r", encoding="ascii") as f: 54 | articles = json.load(f) 55 | # Dictionary to save condensed json of bioasq 56 | bioasq_collection = {} 57 | questions = [] 58 | ideal_answers = [] 59 | ideal_answer_dict = {} 60 | exact_answers = [] 61 | snippet_dict = {} 62 | for i, q in enumerate(bioasq_questions): 63 | # Get the question 64 | bioasq_collection[q['body']] = {} 65 | questions.append(q['body']) 66 | # Get the references used to answer that question 67 | pmid_list= [d.split("/")[-1] for d in q['documents']] 68 | # Get the question type: list, summary, yes/no, or factoid 69 | q_type = q['type'] 70 | bioasq_collection[q['body']]['q_type'] = q_type 71 | # Take the first ideal answer 72 | assert isinstance(q['ideal_answer'], list) 73 | assert isinstance(q['ideal_answer'][0], str) 74 | ideal_answer_dict[i] = q['ideal_answer'][0] 75 | bioasq_collection[q['body']]['ideal_answer'] = q['ideal_answer'][0] 76 | # And get the first exact answer 77 | if q_type != "summary": 78 | # Yesno questions will have just a yes/no string in exact answer. 
79 | if q_type == "yesno": 80 | exact_answers.append(q['exact_answer'][0]) 81 | bioasq_collection[q['body']]['exact_answer'] = q['exact_answer'][0] 82 | else: 83 | if isinstance(q['exact_answer'], str): 84 | exact_answers.append(q['exact_answer']) 85 | bioasq_collection[q['body']]['exact_answer'] = q['exact_answer'] 86 | else: 87 | exact_answers.append(q['exact_answer'][0]) 88 | bioasq_collection[q['body']]['exact_answer'] = q['exact_answer'][0] 89 | # Then handle the snippets (the text extracted from the abstract) 90 | bioasq_collection[q['body']]['snippets'] = [] 91 | snippet_dict[q['body']] = [] 92 | for snippet in q['snippets']: 93 | pmid_match = False 94 | snippet_dict[q['body']].append(snippet['text']) 95 | doc_pmid = str(snippet['document'].split("/")[-1]) 96 | try: 97 | article = articles[doc_pmid] 98 | # Add the data to the dictionary containing the collection. 99 | bioasq_collection[q['body']]['snippets'].append({'snippet': snippet['text'], 'article': article, 'pmid': doc_pmid}) 100 | except KeyError as e: 101 | continue 102 | 103 | with open("data/bioasq_ideal_answers.json", "w", encoding="utf8") as f: 104 | json.dump(ideal_answer_dict, f, indent=4) 105 | with open("data/bioasq_snippets.json", "w", encoding="utf8") as f: 106 | json.dump(snippet_dict, f, indent=4) 107 | with open("data/bioasq_collection.json", "w", encoding="utf8") as f: 108 | json.dump(bioasq_collection, f, indent=4) 109 | 110 | 111 | def process_bioasq(): 112 | """ 113 | Main processing function for bioasq data 114 | """ 115 | bq = BioASQ() 116 | if args.process: 117 | bq.bioasq() 118 | 119 | if __name__ == "__main__": 120 | global args 121 | args = get_args().parse_args() 122 | process_bioasq() 123 | -------------------------------------------------------------------------------- /evaluation/collection_statistics.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script to compute statistics on collection and validate the data integrity 3 | 4 | To run with counts of sentences and tokens: 5 | python collection_statistics.py --tokenize 6 | Otherwise don't include --tokenize option 7 | """ 8 | 9 | import argparse 10 | import spacy 11 | import json 12 | import numpy as np 13 | from collections import Counter 14 | 15 | 16 | def get_args(): 17 | """ 18 | Argument parser for preparing chiqa data 19 | """ 20 | parser = argparse.ArgumentParser(description="Arguments for data exploration") 21 | parser.add_argument("--tokenize", 22 | dest="tokenize", 23 | action="store_true", 24 | help="Tokenize by words and sentences, counting averages/sd for each.") 25 | return parser 26 | 27 | 28 | class SummarizationDataStats(): 29 | """ 30 | Class for validating annotated collection 31 | """ 32 | 33 | def __init__(self): 34 | """ 35 | Init spacy and counting variables. 36 | """ 37 | self.nlp = spacy.load('en_core_web_sm') 38 | self.summary_file = open("results/question_driven_answer_summ_collection_stats.txt", "w", encoding="utf-8") 39 | 40 | def load_data(self, dataset, dataset_name): 41 | """ 42 | Given the path of the dataset, load it! 
43 | """ 44 | with open(dataset, "r", encoding="utf-8") as f: 45 | self.data = json.load(f) 46 | self.dataset_name = dataset_name 47 | 48 | def _get_token_cnts(self, doc, doc_type): 49 | """ 50 | Count tokens and in documents 51 | """ 52 | tokenized_doc = self.nlp(doc) 53 | self.stat_dict[doc_type][0].append(len([s for s in tokenized_doc.sents])) 54 | doc_len = len([t for t in tokenized_doc]) 55 | self.stat_dict[doc_type][1].append(doc_len) 56 | if doc_len < 50 and doc_type == "answer": 57 | print("Document less than 50 tokens:", url) 58 | 59 | def iterate_data(self): 60 | """ 61 | Count the number of examples in each dataset 62 | """ 63 | if "single" in self.dataset_name: 64 | # Index 0 for list of sentence lengths, index 1 for list of token lengths 65 | self.stat_dict = {'question': [[], []], 'summary': [[], []], 'article': [[], []]} 66 | for answer_id in self.data: 67 | summary = self.data[answer_id]['summary'] 68 | articles = self.data[answer_id]['articles'] 69 | question = self.data[answer_id]['question'] 70 | if args.tokenize: 71 | self._get_token_cnts(summary, 'summary') 72 | self._get_token_cnts(articles, 'article') 73 | self._get_token_cnts(question, 'question') 74 | self._write_stats("token_counts") 75 | 76 | if "multi" in self.dataset_name: 77 | self.stat_dict = {'question': [[], []], 'summary': [[], []], 'article': [[], []]} 78 | for q_id in self.data: 79 | summary = self.data[q_id]['summary'] 80 | question = self.data[q_id]['question'] 81 | if args.tokenize: 82 | self._get_token_cnts(summary, 'summary') 83 | self._get_token_cnts(question, 'question') 84 | question = self.data[q_id]['question'] 85 | for answer_id in self.data[q_id]['articles']: 86 | articles = self.data[q_id]['articles'][answer_id][0] 87 | if args.tokenize: 88 | self._get_token_cnts(articles, 'article') 89 | self._write_stats("token_counts") 90 | 91 | if self.dataset_name == "complete_dataset": 92 | self.stat_dict = {'urls': [], 'sites': []} 93 | article_dict = {} 94 | print("Counting answers, sites, unique urls, and tokenized counts of unique articles") 95 | answer_cnt = 0 96 | for q_id in self.data: 97 | for a_id in self.data[q_id]['answers']: 98 | answer_cnt += 1 99 | url = self.data[q_id]['answers'][a_id]['url'] 100 | article = self.data[q_id]['answers'][a_id]['article'] 101 | if url not in article_dict: 102 | article_dict[url] = article 103 | self.stat_dict['urls'].append(url) 104 | assert "//" in url, url 105 | site = url.split("//")[1].split("/") 106 | self.stat_dict['sites'].append(site[0]) 107 | print("# of Answers:", answer_cnt) 108 | print("Unique articles: ", len(article_dict)) # This should match up with count written to file 109 | self._write_stats("full collection") 110 | 111 | # Get token/sent averages of unique articles 112 | if args.tokenize: 113 | self.stat_dict = {'article': [[], []]} 114 | for a in article_dict: 115 | self._get_token_cnts(article_dict[a], 'article') 116 | self._write_stats("token_counts") 117 | 118 | def _write_stats(self, stat_type, user=None, summ_type=None): 119 | """ 120 | Return chiqa page stats 121 | """ 122 | if stat_type == "full collection": 123 | self.summary_file.write("\n\nDataset: {c}\n".format(c=self.dataset_name)) 124 | self.summary_file.write("Number of unique urls: {u}\nNumber of unique sites: {s}\n".format(u=len(set(self.stat_dict['urls'])), s=len(set(self.stat_dict['sites']))) 125 | ) 126 | site_cnts = Counter(self.stat_dict['sites']).most_common() 127 | for site in site_cnts: 128 | self.summary_file.write("{s}: {n}\n".format(s=site[0], n=site[1])) 129 | 
130 | if stat_type == "token_counts": 131 | self.summary_file.write("\n\nDataset: {c}\n".format(c=self.dataset_name)) 132 | for doc_type in self.stat_dict: 133 | if user is not None: 134 | self.summary_file.write("\n{0}, {1}\n".format(user, summ_type)) 135 | 136 | self.summary_file.write( 137 | "\nNumber of {d}s: {p}\nAverage tokens/{d}: {t}\nAverage sentences/{d}: {s}\n".format( 138 | d=doc_type, p=len(self.stat_dict[doc_type][0]), t=sum(self.stat_dict[doc_type][1])/len(self.stat_dict[doc_type][1]), s=sum(self.stat_dict[doc_type][0])/len(self.stat_dict[doc_type][0]) 139 | ) 140 | ) 141 | 142 | self.summary_file.write( 143 | "Median tokens/{d}: {p}\nStandard deviation tokens/{d}: {t}\n".format( 144 | d=doc_type, p=np.median(self.stat_dict[doc_type][1]), t=np.std(self.stat_dict[doc_type][1]) 145 | ) 146 | ) 147 | 148 | self.summary_file.write( 149 | "Median sentences/{d}: {p}\nStandard deviation sentences/{d}: {t}\n".format( 150 | d=doc_type, p=np.median(self.stat_dict[doc_type][0]), t=np.std(self.stat_dict[doc_type][0]) 151 | ) 152 | ) 153 | 154 | 155 | def get_stats(): 156 | """ 157 | Main function for getting CHiQA collection stats 158 | """ 159 | datasets = [ 160 | ("../data_processing/data/page2answer_single_abstractive_summ.json", "p2a-single-abs"), 161 | ("../data_processing/data/page2answer_single_extractive_summ.json", "p2a-single-ext"), 162 | ("../data_processing/data/section2answer_multi_abstractive_summ.json", "s2a-multi-abs"), 163 | ("../data_processing/data/page2answer_multi_extractive_summ.json", "p2a-multi-ext"), 164 | ("../data_processing/data/section2answer_single_abstractive_summ.json", "s2a-single-abs"), 165 | ("../data_processing/data/section2answer_single_extractive_summ.json", "s2a-single-ext"), 166 | ("../data_processing/data/section2answer_multi_extractive_summ.json", "s2a-multi-ext"), 167 | ("../data_processing/data/question_driven_answer_summarization_primary_dataset.json", "complete_dataset"), 168 | ] 169 | 170 | stats = SummarizationDataStats() 171 | for dataset in datasets: 172 | print(dataset[1]) 173 | stats.load_data(dataset[0], dataset[1]) 174 | stats.iterate_data() 175 | 176 | 177 | if __name__ == "__main__": 178 | global args 179 | args = get_args().parse_args() 180 | get_stats() 181 | -------------------------------------------------------------------------------- /evaluation/data/bart/chiqa_eval/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | -------------------------------------------------------------------------------- /evaluation/data/baselines/chiqa_eval/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | -------------------------------------------------------------------------------- /evaluation/data/pointer_generator/chiqa_eval/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | -------------------------------------------------------------------------------- /evaluation/data/sentence_classifier/chiqa_eval/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | -------------------------------------------------------------------------------- /evaluation/results/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | -------------------------------------------------------------------------------- 
/models/bart/bart_config/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | -------------------------------------------------------------------------------- /models/bart/finetune_bart_bioasq.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | TOTAL_NUM_UPDATES=20000 4 | WARMUP_UPDATES=500 5 | LR=3e-05 6 | MAX_TOKENS=1024 7 | UPDATE_FREQ=16 8 | BART_PATH=bart/bart.large/model.pt 9 | checkpoint_path=checkpoints_bioasq_$1 10 | asumm_data=bart_config/${1}/bart-bin 11 | 12 | #CUDA_VISIBLE_DEVICES=0 python train.py ${asumm_data} \ 13 | CUDA_VISIBLE_DEVICES=0 python fairseq/fairseq_cli/train.py ${asumm_data} \ 14 | --restore-file $BART_PATH \ 15 | --max-tokens $MAX_TOKENS \ 16 | --truncate-source \ 17 | --task translation \ 18 | --source-lang source --target-lang target \ 19 | --layernorm-embedding \ 20 | --share-all-embeddings \ 21 | --share-decoder-input-output-embed \ 22 | --reset-optimizer --reset-dataloader --reset-meters \ 23 | --required-batch-size-multiple 1 \ 24 | --arch bart_large \ 25 | --criterion label_smoothed_cross_entropy \ 26 | --label-smoothing 0.1 \ 27 | --dropout 0.1 --attention-dropout 0.1 \ 28 | --weight-decay 0.01 --optimizer adam --adam-betas "(0.9, 0.999)" --adam-eps 1e-08 \ 29 | --clip-norm 0.1 \ 30 | --lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \ 31 | --update-freq $UPDATE_FREQ \ 32 | --skip-invalid-size-inputs-valid-test \ 33 | --fp16 \ 34 | --ddp-backend=no_c10d \ 35 | --save-dir=${checkpoint_path} \ 36 | --keep-last-epochs=2 \ 37 | --find-unused-parameters; 38 | -------------------------------------------------------------------------------- /models/bart/make_datafiles_for_bart.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is the make_datafiles script from the pointgenerator cnn-dailymail repo 3 | https://github.com/abisee/cnn-dailymail/blob/b15ad0a2db0d407a84b8ca9b5731e1f1c4bd24b9/make_datafiles.py#L235 4 | but with modifications made here 5 | https://gist.github.com/zhaoguangxiang/45bf39c528cf7fb7853bffba7fe57c7e 6 | The script now saves the cnn dailymail files as test.source and test.target, etc 7 | 8 | It requires python 2 to run. Run the activate_py2_env.sh to activate. 
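Typical usage once the python 2 environment is active (a sketch: the CNN/DM story directories and url-list paths are hard-coded near the top of this script and will need to be adjusted for your setup):
    python make_datafiles_for_bart.py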
9 | """ 10 | 11 | import sys 12 | import os 13 | import hashlib 14 | import struct 15 | import subprocess 16 | import collections 17 | import codecs 18 | # import tensorflow as tf 19 | # from tensorflow.core.example import example_pb2 20 | # import sys 21 | # reload(sys) 22 | # sys.setdefaultencoding('utf8') 23 | # cnt_r=0 24 | dm_single_close_quote = u'\u2019' # unicode 25 | dm_double_close_quote = u'\u201d' 26 | END_TOKENS = ['.', '!', '?', '...', "'", "`", '"', dm_single_close_quote, dm_double_close_quote, ")"] # acceptable ways to end a sentence 27 | # We use these to separate the summary sentences in the .bin datafiles 28 | SENTENCE_START = '' 29 | SENTENCE_END = '' 30 | 31 | all_train_urls = "/data/saveryme/asumm/asumm_data/cnn-dailymail/url_lists/all_train.txt" 32 | all_val_urls = "/data/saveryme/asumm/asumm_data/cnn-dailymail/url_lists/all_val.txt" 33 | all_test_urls = "/data/saveryme/asumm/asumm_data/cnn-dailymail/url_lists/all_test.txt" 34 | 35 | # cnn_tokenized_stories_dir = "cnn_stories_tokenized" 36 | # dm_tokenized_stories_dir = "dm_stories_tokenized" 37 | cnn_tokenized_stories_dir = "/data/saveryme/asumm/asumm_data/cnn_dm/cnn/stories" 38 | dm_tokenized_stories_dir = "/data/saveryme/asumm/asumm_data/cnn_dm/dailymail/stories" 39 | # finished_files_dir = "finished_files" 40 | finished_files_dir = "cnn_dm_finished_files" 41 | # chunks_dir = os.path.join(finished_files_dir, "chunked") 42 | 43 | # These are the number of .story files we expect there to be in cnn_stories_dir and dm_stories_dir 44 | num_expected_cnn_stories = 92579 45 | num_expected_dm_stories = 219506 46 | 47 | VOCAB_SIZE = 200000 48 | CHUNK_SIZE = 1000 # num examples per chunk, for the chunked data 49 | 50 | 51 | 52 | def chunk_file(set_name): 53 | in_file = 'finished_files/%s.bin' % set_name 54 | reader = open(in_file, "rb") 55 | chunk = 0 56 | finished = False 57 | while not finished: 58 | chunk_fname = os.path.join(chunks_dir, '%s_%03d.bin' % (set_name, chunk)) # new chunk 59 | with open(chunk_fname, 'wb') as writer: 60 | for _ in range(CHUNK_SIZE): 61 | len_bytes = reader.read(8) 62 | if not len_bytes: 63 | finished = True 64 | break 65 | str_len = struct.unpack('q', len_bytes)[0] 66 | example_str = struct.unpack('%ds' % str_len, reader.read(str_len))[0] 67 | writer.write(struct.pack('q', str_len)) 68 | writer.write(struct.pack('%ds' % str_len, example_str)) 69 | chunk += 1 70 | 71 | 72 | def chunk_all(): 73 | # Make a dir to hold the chunks 74 | if not os.path.isdir(chunks_dir): 75 | os.mkdir(chunks_dir) 76 | # Chunk the data 77 | for set_name in ['train', 'val', 'test']: 78 | print "Splitting %s data into chunks..." % set_name 79 | chunk_file(set_name) 80 | print "Saved chunked data in %s" % chunks_dir 81 | 82 | 83 | def tokenize_stories(stories_dir, tokenized_stories_dir): 84 | """Maps a whole directory of .story files to a tokenized version using Stanford CoreNLP Tokenizer""" 85 | print "Preparing to tokenize %s to %s..." % (stories_dir, tokenized_stories_dir) 86 | stories = os.listdir(stories_dir) 87 | # make IO list file 88 | print "Making list of files to tokenize..." 89 | with open("mapping.txt", "w") as f: 90 | for s in stories: 91 | f.write("%s \t %s\n" % (os.path.join(stories_dir, s), os.path.join(tokenized_stories_dir, s))) 92 | command = ['java', 'edu.stanford.nlp.process.PTBTokenizer', '-ioFileList', '-preserveLines', 'mapping.txt'] 93 | print "Tokenizing %i files in %s and saving in %s..." 
% (len(stories), stories_dir, tokenized_stories_dir) 94 | subprocess.call(command) 95 | os.remove("mapping.txt") 96 | 97 | # Check that the tokenized stories directory contains the same number of files as the original directory 98 | num_orig = len(os.listdir(stories_dir)) 99 | num_tokenized = len(os.listdir(tokenized_stories_dir)) 100 | if num_orig != num_tokenized: 101 | raise Exception("The tokenized stories directory %s contains %i files, but it should contain the same number as %s (which has %i files). Was there an error during tokenization?" % (tokenized_stories_dir, num_tokenized, stories_dir, num_orig)) 102 | print "Successfully finished tokenizing %s to %s.\n" % (stories_dir, tokenized_stories_dir) 103 | 104 | 105 | def read_text_file(text_file): 106 | lines = [] 107 | with open(text_file, "r") as f: 108 | for line in f: 109 | lines.append(line.strip()) 110 | return lines 111 | 112 | 113 | def hashhex(s): 114 | """Returns a heximal formated SHA1 hash of the input string.""" 115 | h = hashlib.sha1() 116 | h.update(s) 117 | return h.hexdigest() 118 | 119 | 120 | def get_url_hashes(url_list): 121 | return [hashhex(url) for url in url_list] 122 | 123 | 124 | def fix_missing_period(line): 125 | """Adds a period to a line that is missing a period""" 126 | if "@highlight" in line: return line 127 | if line=="": return line 128 | if line[-1] in END_TOKENS: return line 129 | return line + "." 130 | 131 | 132 | def get_art_abs(story_file): 133 | lines = read_text_file(story_file) 134 | 135 | # Lowercase everything 136 | # lines = [line.lower() for line in lines] 137 | 138 | # Put periods on the ends of lines that are missing them (this is a problem in the dataset because many image captions don't end in periods; consequently they end up in the body of the article as run-on sentences) 139 | lines = [fix_missing_period(line) for line in lines] 140 | 141 | # Separate out article and abstract sentences 142 | article_lines = [] 143 | highlights = [] 144 | next_is_highlight = False 145 | for idx,line in enumerate(lines): 146 | if line == "": 147 | continue # empty line 148 | elif line.startswith("@highlight"): 149 | next_is_highlight = True 150 | elif next_is_highlight: 151 | highlights.append(line) 152 | else: 153 | article_lines.append(line) 154 | 155 | # Make article into a single string 156 | article = ' '.join(article_lines) 157 | 158 | # Make abstract into a signle string, putting and tags around the sentences 159 | # abstract = ' '.join(["%s %s %s" % (SENTENCE_START, sent, SENTENCE_END) for sent in highlights]) 160 | abstract = ' '.join([" %s " % (sent) for sent in highlights]) 161 | 162 | return article, abstract 163 | 164 | 165 | def write_to_bin(url_file, out_file, makevocab=False): 166 | """Reads the tokenized .story files corresponding to the urls listed in the url_file and writes them to a out_file.""" 167 | print "Making bin file for URLs listed in %s..." 
% url_file 168 | url_list = read_text_file(url_file) 169 | url_hashes = get_url_hashes(url_list) 170 | story_fnames = [s+".story" for s in url_hashes] 171 | num_stories = len(story_fnames) 172 | 173 | if makevocab: 174 | vocab_counter = collections.Counter() 175 | cnt_r=0 176 | cnt_n=0 177 | with open(out_file, 'wb') as writer: 178 | for idx,s in enumerate(story_fnames): 179 | if idx % 1000 == 0: 180 | print "Writing story %i of %i; %.2f percent done" % (idx, num_stories, float(idx)*100.0/float(num_stories)) 181 | 182 | # Look in the tokenized story dirs to find the .story file corresponding to this url 183 | if os.path.isfile(os.path.join(cnn_tokenized_stories_dir, s)): 184 | story_file = os.path.join(cnn_tokenized_stories_dir, s) 185 | elif os.path.isfile(os.path.join(dm_tokenized_stories_dir, s)): 186 | story_file = os.path.join(dm_tokenized_stories_dir, s) 187 | else: 188 | print "Error: Couldn't find tokenized story file %s in either tokenized story directories %s and %s. Was there an error during tokenization?" % (s, cnn_tokenized_stories_dir, dm_tokenized_stories_dir) 189 | # Check again if tokenized stories directories contain correct number of files 190 | print "Checking that the tokenized stories directories %s and %s contain correct number of files..." % (cnn_tokenized_stories_dir, dm_tokenized_stories_dir) 191 | check_num_stories(cnn_tokenized_stories_dir, num_expected_cnn_stories) 192 | check_num_stories(dm_tokenized_stories_dir, num_expected_dm_stories) 193 | raise Exception("Tokenized stories directories %s and %s contain correct number of files but story file %s found in neither." % (cnn_tokenized_stories_dir, dm_tokenized_stories_dir, s)) 194 | 195 | # Get the strings to write to .bin file 196 | article, abstract = get_art_abs(story_file) 197 | if os.path.isfile(os.path.join(cnn_tokenized_stories_dir, s)) and article[:5] == '(CNN)': 198 | article = article[5:] 199 | 200 | article.decode('utf-8','ignore').encode('utf-8') 201 | if '\r' in article: 202 | cnt_r += 1 203 | article.replace("\r", " ") 204 | print "there is a line contains r" 205 | print cnt_r 206 | if '\n' in article: 207 | cnt_n += 1 208 | article.replace("\n", " ") 209 | print "there is a line contains n" 210 | print cnt_n 211 | article = ' '.join(article.split()) 212 | with open(out_file+'.source', mode='a+') as src: 213 | src.write(article+'\n') 214 | with codecs.open(out_file+'.target', mode='a+') as tgt: 215 | tgt.write(abstract+'\n') 216 | # Write to tf.Example 217 | # tf_example = example_pb2.Example() 218 | # tf_example.features.feature['article'].bytes_list.value.extend([article]) 219 | # tf_example.features.feature['abstract'].bytes_list.value.extend([abstract]) 220 | # tf_example_str = tf_example.SerializeToString() 221 | # str_len = len(tf_example_str) 222 | # writer.write(struct.pack('q', str_len)) 223 | # writer.write(struct.pack('%ds' % str_len, tf_example_str)) 224 | 225 | # Write the vocab to file, if applicable 226 | if makevocab: 227 | art_tokens = article.split(' ') 228 | abs_tokens = abstract.split(' ') 229 | abs_tokens = [t for t in abs_tokens if t not in [SENTENCE_START, SENTENCE_END]] # remove these tags from vocab 230 | tokens = art_tokens + abs_tokens 231 | tokens = [t.strip() for t in tokens] # strip 232 | tokens = [t for t in tokens if t!=""] # remove empty 233 | vocab_counter.update(tokens) 234 | 235 | print "Finished writing file %s\n" % out_file 236 | 237 | # write vocab to file 238 | if makevocab: 239 | print "Writing vocab file..." 
240 | with open(os.path.join(finished_files_dir, "vocab"), 'w') as writer: 241 | for word, count in vocab_counter.most_common(VOCAB_SIZE): 242 | writer.write(word + ' ' + str(count) + '\n') 243 | print "Finished writing vocab file" 244 | 245 | 246 | def check_num_stories(stories_dir, num_expected): 247 | num_stories = len(os.listdir(stories_dir)) 248 | if num_stories != num_expected: 249 | raise Exception("stories directory %s contains %i files but should contain %i" % (stories_dir, num_stories, num_expected)) 250 | 251 | 252 | if __name__ == '__main__': 253 | # if len(sys.argv) != 3: 254 | # print "USAGE: python make_datafiles.py " 255 | # sys.exit() 256 | # cnn_stories_dir = sys.argv[1] 257 | # dm_stories_dir = sys.argv[2] 258 | 259 | 260 | # Check the stories directories contain the correct number of .story files 261 | # check_num_stories(cnn_stories_dir, num_expected_cnn_stories) 262 | # check_num_stories(dm_stories_dir, num_expected_dm_stories) 263 | 264 | # Create some new directories 265 | # if not os.path.exists(cnn_tokenized_stories_dir): os.makedirs(cnn_tokenized_stories_dir) 266 | # if not os.path.exists(dm_tokenized_stories_dir): os.makedirs(dm_tokenized_stories_dir) 267 | if not os.path.exists(finished_files_dir): os.makedirs(finished_files_dir) 268 | 269 | # Run stanford tokenizer on both stories dirs, outputting to tokenized stories directories 270 | # tokenize_stories(cnn_stories_dir, cnn_tokenized_stories_dir) 271 | # tokenize_stories(dm_stories_dir, dm_tokenized_stories_dir) 272 | 273 | # Read the tokenized stories, do a little postprocessing then write to bin files 274 | # cnt_r=0 275 | write_to_bin(all_test_urls, os.path.join(finished_files_dir, "test")) 276 | write_to_bin(all_val_urls, os.path.join(finished_files_dir, "val")) 277 | write_to_bin(all_train_urls, os.path.join(finished_files_dir, "train"), makevocab=True) 278 | # write_to_bin(all_train_urls, os.path.join(finished_files_dir, "train.bin")) 279 | 280 | # Chunk the data. This splits each of train.bin, val.bin and test.bin into smaller chunks, each containing e.g. 
1000 examples, and saves them in finished_files/chunks 281 | # chunk_all() 282 | -------------------------------------------------------------------------------- /models/bart/process_bioasq_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Process bioasq data generated in the prepare_training_data.py script 3 | # -b as first argument for byte pair encoding 4 | # -f as second argument for the rest of the fairseq processing 5 | # Specify with_question or without_question as third argument 6 | 7 | wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json' 8 | wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe' 9 | wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt' 10 | 11 | echo "Args: $@" 12 | asumm_data=bart_config/${3} 13 | echo "Data dir: $asumm_data" 14 | if [ "$1" == "-b" ] 15 | then 16 | echo "Running byte-pair encoding" 17 | for SPLIT in train val 18 | do 19 | for LANG in source target 20 | do 21 | python -m examples.roberta.multiprocessing_bpe_encoder \ 22 | --encoder-json encoder.json \ 23 | --vocab-bpe vocab.bpe \ 24 | --inputs ${asumm_data}/bart.${SPLIT}_${3}.$LANG \ 25 | --outputs ${asumm_data}/bart.${SPLIT}_${3}.bpe.$LANG \ 26 | --workers 60 \ 27 | --keep-empty; 28 | done 29 | done 30 | fi 31 | 32 | if [ "$2" == "-f" ] 33 | then 34 | echo "Running fairseq processing" 35 | fairseq-preprocess \ 36 | --source-lang "source" \ 37 | --target-lang "target" \ 38 | --trainpref ${asumm_data}/bart.train_${3}.bpe \ 39 | --validpref ${asumm_data}/bart.val_${3}.bpe \ 40 | --destdir ${asumm_data}/bart-bin \ 41 | --workers 60 \ 42 | --srcdict dict.txt \ 43 | --tgtdict dict.txt; 44 | fi 45 | -------------------------------------------------------------------------------- /models/bart/run_chiqa.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Include with_question or without_question when calling script 3 | model_config=bart_config/${1}/bart-bin 4 | model_path=checkpoints_bioasq_$1 5 | for summ_task in page2answer section2answer 6 | do 7 | for summ_type in single_abstractive single_extractive 8 | do 9 | data=${summ_task}_${summ_type}_summ.json 10 | input_file=../../data_processing/data/${data} 11 | prediction_file=bart_chiqa_${1}_${summ_task}_${summ_type}.json 12 | prediction_path=../../evaluation/data/bart/chiqa_eval/${prediction_file} 13 | echo $input_file 14 | echo $prediction_path 15 | python run_inference_medsumm.py \ 16 | --input_file=$input_file \ 17 | --question_driven=$1 \ 18 | --prediction_file=$prediction_path \ 19 | --model_path=$model_path \ 20 | --model_config=$model_config 21 | done 22 | done 23 | -------------------------------------------------------------------------------- /models/bart/run_inference_medsumm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script modified from fairseq example CNN-dm script, for performing 3 | summarization inference on input articles. 
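Example invocation, mirroring run_chiqa.sh for the with_question / section2answer / single_abstractive setting (paths are relative to models/bart and are illustrative only):

    python run_inference_medsumm.py \
        --input_file=../../data_processing/data/section2answer_single_abstractive_summ.json \
        --question_driven=with_question \
        --prediction_file=../../evaluation/data/bart/chiqa_eval/bart_chiqa_with_question_section2answer_single_abstractive.json \
        --model_path=checkpoints_bioasq_with_question \
        --model_config=bart_config/with_question/bart-bin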
4 | """ 5 | 6 | import argparse 7 | import json 8 | 9 | from tqdm import tqdm 10 | import torch 11 | from fairseq.models.bart import BARTModel 12 | 13 | 14 | def get_args(): 15 | """ 16 | Argument defnitions 17 | """ 18 | parser = argparse.ArgumentParser(description="Arguments for data exploration") 19 | parser.add_argument("--prediction_file", 20 | dest="prediction_file", 21 | help="File to save predictions") 22 | parser.add_argument("--input_file", 23 | dest="input_file", 24 | help="File with text to summarize") 25 | parser.add_argument("--question_driven", 26 | dest="question_driven", 27 | help="Whether to add question to beginning of article for question-driven summarization") 28 | parser.add_argument("--model_path", 29 | dest="model_path", 30 | help="Path to model checkpoints") 31 | parser.add_argument("--model_config", 32 | dest="model_config", 33 | help="Path to model vocab") 34 | parser.add_argument("--batch_size", 35 | dest="batch_size", 36 | default=32, 37 | help="Batch size for inference") 38 | return parser 39 | 40 | 41 | def run_inference(): 42 | """ 43 | Main function for running inference on given input text 44 | """ 45 | bart = BARTModel.from_pretrained( 46 | args.model_path, 47 | checkpoint_file='checkpoint_best.pt', 48 | data_name_or_path=args.model_config 49 | ) 50 | 51 | bart.cuda() 52 | bart.eval() 53 | bart.half() 54 | questions = [] 55 | ref_summaries = [] 56 | gen_summaries = [] 57 | articles = [] 58 | QUESTION_END = " [QUESTION?] " 59 | with open(args.input_file, 'r', encoding="utf-8") as f: 60 | source = json.load(f) 61 | batch_cnt = 0 62 | 63 | for q in tqdm(source): 64 | question = source[q]['question'] 65 | questions.append(question) 66 | # The data here may be prepared for the pointer generator, and it is currently easier to 67 | # clean the sentence tags out here, as opposed to making tagged and nontagged datasets. 
68 | ref_summary = source[q]['summary'] 69 | if "" in ref_summary: 70 | ref_summary = ref_summary.replace("", "") 71 | ref_summary = ref_summary.replace("", "") 72 | ref_summaries.append(ref_summary) 73 | article = source[q]['articles'] 74 | if args.question_driven == "with_question": 75 | article = question + QUESTION_END + article 76 | articles.append(article) 77 | # Once the article list fills up, run a batch 78 | if len(articles) == args.batch_size: 79 | batch_cnt += 1 80 | print("Running batch {}".format(batch_cnt)) 81 | # Hyperparameters as recommended here: https://github.com/pytorch/fairseq/issues/1364 82 | with torch.no_grad(): 83 | predictions = bart.sample(articles, beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3) 84 | for pred in predictions: 85 | #print(pred) 86 | gen_summaries.append(pred) 87 | articles = [] 88 | print("Done with batch {}".format(batch_cnt)) 89 | 90 | if len(articles) != 0: 91 | predictions = bart.sample(articles, beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3) 92 | for pred in predictions: 93 | print(pred) 94 | gen_summaries.append(pred) 95 | 96 | assert len(gen_summaries) == len(ref_summaries) 97 | prediction_dict = { 98 | 'question': questions, 99 | 'ref_summary': ref_summaries, 100 | 'gen_summary': gen_summaries 101 | } 102 | 103 | with open(args.prediction_file, "w", encoding="utf-8") as f: 104 | json.dump(prediction_dict, f, indent=4) 105 | 106 | 107 | if __name__ == "__main__": 108 | global args 109 | args = get_args().parse_args() 110 | run_inference() 111 | -------------------------------------------------------------------------------- /models/bart/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """ 7 | Train a new model on one or across multiple GPUs. 8 | """ 9 | 10 | import collections 11 | import math 12 | import random 13 | 14 | import numpy as np 15 | import torch 16 | 17 | from fairseq import checkpoint_utils, distributed_utils, options, progress_bar, tasks, utils 18 | from fairseq.data import iterators 19 | from fairseq.trainer import Trainer 20 | from fairseq.meters import AverageMeter, StopwatchMeter 21 | 22 | 23 | def main(args, init_distributed=False): 24 | utils.import_user_module(args) 25 | 26 | assert args.max_tokens is not None or args.max_sentences is not None, \ 27 | 'Must specify batch size either with --max-tokens or --max-sentences' 28 | 29 | # Initialize CUDA and distributed training 30 | if torch.cuda.is_available() and not args.cpu: 31 | torch.cuda.set_device(args.device_id) 32 | np.random.seed(args.seed) 33 | torch.manual_seed(args.seed) 34 | if init_distributed: 35 | args.distributed_rank = distributed_utils.distributed_init(args) 36 | 37 | if distributed_utils.is_master(args): 38 | checkpoint_utils.verify_checkpoint_directory(args.save_dir) 39 | 40 | # Print args 41 | print(args) 42 | 43 | # Setup task, e.g., translation, language modeling, etc. 
44 | task = tasks.setup_task(args) 45 | 46 | # Load valid dataset (we load training data below, based on the latest checkpoint) 47 | for valid_sub_split in args.valid_subset.split(','): 48 | task.load_dataset(valid_sub_split, combine=False, epoch=0) 49 | 50 | # Build model and criterion 51 | model = task.build_model(args) 52 | criterion = task.build_criterion(args) 53 | #print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__)) 54 | #print('| num. model params: {} (num. trained: {})'.format( 55 | #sum(p.numel() for p in model.parameters()), 56 | #sum(p.numel() for p in model.parameters() if p.requires_grad), 57 | #)) 58 | 59 | # Build trainer 60 | trainer = Trainer(args, task, model, criterion) 61 | print('| training on {} GPUs'.format(args.distributed_world_size)) 62 | print('| max tokens per GPU = {} and max sentences per GPU = {}'.format( 63 | args.max_tokens, 64 | args.max_sentences, 65 | )) 66 | 67 | # Load the latest checkpoint if one is available and restore the 68 | # corresponding train iterator 69 | extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer) 70 | 71 | # Train until the learning rate gets too small 72 | max_epoch = args.max_epoch or math.inf 73 | max_update = args.max_update or math.inf 74 | lr = trainer.get_lr() 75 | train_meter = StopwatchMeter() 76 | train_meter.start() 77 | valid_subsets = args.valid_subset.split(',') 78 | while ( 79 | lr > args.min_lr 80 | and ( 81 | epoch_itr.epoch < max_epoch 82 | # allow resuming training from the final checkpoint 83 | or epoch_itr._next_epoch_itr is not None 84 | ) 85 | and trainer.get_num_updates() < max_update 86 | ): 87 | # train for one epoch 88 | train(args, trainer, task, epoch_itr) 89 | 90 | if not args.disable_validation and epoch_itr.epoch % args.validate_interval == 0: 91 | valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets) 92 | else: 93 | valid_losses = [None] 94 | 95 | # only use first validation loss to update the learning rate 96 | lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0]) 97 | 98 | # save checkpoint 99 | if epoch_itr.epoch % args.save_interval == 0: 100 | checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0]) 101 | 102 | # early stop 103 | if should_stop_early(args, valid_losses[0]): 104 | print('| Early stop since valid performance hasn\'t improved for last {} runs'.format(args.patience)) 105 | break 106 | 107 | reload_dataset = ':' in getattr(args, 'data', '') 108 | # sharded data: get train iterator for next epoch 109 | epoch_itr = trainer.get_train_iterator(epoch_itr.epoch, load_dataset=reload_dataset) 110 | train_meter.stop() 111 | print('| done training in {:.1f} seconds'.format(train_meter.sum)) 112 | 113 | 114 | def should_stop_early(args, valid_loss): 115 | if args.patience <= 0: 116 | return False 117 | 118 | def is_better(a, b): 119 | return a > b if args.maximize_best_checkpoint_metric else a < b 120 | 121 | prev_best = getattr(should_stop_early, 'best', None) 122 | if prev_best is None or is_better(valid_loss, prev_best): 123 | should_stop_early.best = valid_loss 124 | should_stop_early.num_runs = 0 125 | return False 126 | else: 127 | should_stop_early.num_runs += 1 128 | return should_stop_early.num_runs > args.patience 129 | 130 | 131 | def train(args, trainer, task, epoch_itr): 132 | """Train the model for one epoch.""" 133 | # Initialize data iterator 134 | itr = epoch_itr.next_epoch_itr( 135 | fix_batches_to_gpus=args.fix_batches_to_gpus, 136 | shuffle=(epoch_itr.epoch >= args.curriculum), 137 | ) 
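    # update_freq implements gradient accumulation: the per-epoch value selected below comes from
    # --update-freq, and GroupedIterator then yields that many batches at a time so gradients are
    # aggregated before each optimizer step.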
138 | update_freq = ( 139 | args.update_freq[epoch_itr.epoch - 1] 140 | if epoch_itr.epoch <= len(args.update_freq) 141 | else args.update_freq[-1] 142 | ) 143 | itr = iterators.GroupedIterator(itr, update_freq) 144 | progress = progress_bar.build_progress_bar( 145 | args, itr, epoch_itr.epoch, no_progress_bar='simple', 146 | ) 147 | 148 | extra_meters = collections.defaultdict(lambda: AverageMeter()) 149 | valid_subsets = args.valid_subset.split(',') 150 | max_update = args.max_update or math.inf 151 | for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch): 152 | log_output = trainer.train_step(samples) 153 | if log_output is None: 154 | continue 155 | 156 | # log mid-epoch stats 157 | stats = get_training_stats(trainer) 158 | for k, v in log_output.items(): 159 | if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']: 160 | continue # these are already logged above 161 | if 'loss' in k or k == 'accuracy': 162 | extra_meters[k].update(v, log_output['sample_size']) 163 | else: 164 | extra_meters[k].update(v) 165 | stats[k] = extra_meters[k].avg 166 | progress.log(stats, tag='train', step=stats['num_updates']) 167 | 168 | # ignore the first mini-batch in words-per-second and updates-per-second calculation 169 | if i == 0: 170 | trainer.get_meter('wps').reset() 171 | trainer.get_meter('ups').reset() 172 | 173 | num_updates = trainer.get_num_updates() 174 | if ( 175 | not args.disable_validation 176 | and args.save_interval_updates > 0 177 | and num_updates % args.save_interval_updates == 0 178 | and num_updates > 0 179 | ): 180 | valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets) 181 | checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0]) 182 | 183 | if num_updates >= max_update: 184 | break 185 | 186 | # log end-of-epoch stats 187 | stats = get_training_stats(trainer) 188 | for k, meter in extra_meters.items(): 189 | stats[k] = meter.avg 190 | progress.print(stats, tag='train', step=stats['num_updates']) 191 | 192 | # reset training meters 193 | for k in [ 194 | 'train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'gnorm', 'clip', 195 | ]: 196 | meter = trainer.get_meter(k) 197 | if meter is not None: 198 | meter.reset() 199 | 200 | 201 | def get_training_stats(trainer): 202 | stats = collections.OrderedDict() 203 | stats['loss'] = trainer.get_meter('train_loss') 204 | if trainer.get_meter('train_nll_loss').count > 0: 205 | nll_loss = trainer.get_meter('train_nll_loss') 206 | stats['nll_loss'] = nll_loss 207 | else: 208 | nll_loss = trainer.get_meter('train_loss') 209 | stats['ppl'] = utils.get_perplexity(nll_loss.avg) 210 | stats['wps'] = trainer.get_meter('wps') 211 | stats['ups'] = trainer.get_meter('ups') 212 | stats['wpb'] = trainer.get_meter('wpb') 213 | stats['bsz'] = trainer.get_meter('bsz') 214 | stats['num_updates'] = trainer.get_num_updates() 215 | stats['lr'] = trainer.get_lr() 216 | stats['gnorm'] = trainer.get_meter('gnorm') 217 | stats['clip'] = trainer.get_meter('clip') 218 | stats['oom'] = trainer.get_meter('oom') 219 | if trainer.get_meter('loss_scale') is not None: 220 | stats['loss_scale'] = trainer.get_meter('loss_scale') 221 | stats['wall'] = round(trainer.get_meter('wall').elapsed_time) 222 | stats['train_wall'] = trainer.get_meter('train_wall') 223 | return stats 224 | 225 | 226 | def validate(args, trainer, task, epoch_itr, subsets): 227 | """Evaluate the model on the validation set(s) and return the losses.""" 228 | 229 | if args.fixed_validation_seed is not None: 230 | # set fixed 
seed for every validation 231 | utils.set_torch_seed(args.fixed_validation_seed) 232 | 233 | valid_losses = [] 234 | for subset in subsets: 235 | # Initialize data iterator 236 | itr = task.get_batch_iterator( 237 | dataset=task.dataset(subset), 238 | max_tokens=args.max_tokens_valid, 239 | max_sentences=args.max_sentences_valid, 240 | max_positions=utils.resolve_max_positions( 241 | task.max_positions(), 242 | trainer.get_model().max_positions(), 243 | ), 244 | ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, 245 | required_batch_size_multiple=args.required_batch_size_multiple, 246 | seed=args.seed, 247 | num_shards=args.distributed_world_size, 248 | shard_id=args.distributed_rank, 249 | num_workers=args.num_workers, 250 | ).next_epoch_itr(shuffle=False) 251 | progress = progress_bar.build_progress_bar( 252 | args, itr, epoch_itr.epoch, 253 | prefix='valid on \'{}\' subset'.format(subset), 254 | no_progress_bar='simple' 255 | ) 256 | 257 | # reset validation loss meters 258 | for k in ['valid_loss', 'valid_nll_loss']: 259 | meter = trainer.get_meter(k) 260 | if meter is not None: 261 | meter.reset() 262 | extra_meters = collections.defaultdict(lambda: AverageMeter()) 263 | 264 | for sample in progress: 265 | log_output = trainer.valid_step(sample) 266 | 267 | for k, v in log_output.items(): 268 | if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']: 269 | continue 270 | extra_meters[k].update(v) 271 | 272 | # log validation stats 273 | stats = get_valid_stats(trainer, args, extra_meters) 274 | for k, meter in extra_meters.items(): 275 | stats[k] = meter.avg 276 | progress.print(stats, tag=subset, step=trainer.get_num_updates()) 277 | 278 | valid_losses.append( 279 | stats[args.best_checkpoint_metric].avg 280 | if args.best_checkpoint_metric == 'loss' 281 | else stats[args.best_checkpoint_metric] 282 | ) 283 | return valid_losses 284 | 285 | 286 | def get_valid_stats(trainer, args, extra_meters=None): 287 | stats = collections.OrderedDict() 288 | stats['loss'] = trainer.get_meter('valid_loss') 289 | if trainer.get_meter('valid_nll_loss').count > 0: 290 | nll_loss = trainer.get_meter('valid_nll_loss') 291 | stats['nll_loss'] = nll_loss 292 | else: 293 | nll_loss = stats['loss'] 294 | stats['ppl'] = utils.get_perplexity(nll_loss.avg) 295 | stats['num_updates'] = trainer.get_num_updates() 296 | if hasattr(checkpoint_utils.save_checkpoint, 'best'): 297 | key = 'best_{0}'.format(args.best_checkpoint_metric) 298 | best_function = max if args.maximize_best_checkpoint_metric else min 299 | 300 | current_metric = None 301 | if args.best_checkpoint_metric == 'loss': 302 | current_metric = stats['loss'].avg 303 | elif args.best_checkpoint_metric in extra_meters: 304 | current_metric = extra_meters[args.best_checkpoint_metric].avg 305 | elif args.best_checkpoint_metric in stats: 306 | current_metric = stats[args.best_checkpoint_metric] 307 | else: 308 | raise ValueError("best_checkpoint_metric not found in logs") 309 | 310 | stats[key] = best_function( 311 | checkpoint_utils.save_checkpoint.best, 312 | current_metric, 313 | ) 314 | return stats 315 | 316 | 317 | def distributed_main(i, args, start_rank=0): 318 | args.device_id = i 319 | if args.distributed_rank is None: # torch.multiprocessing.spawn 320 | args.distributed_rank = start_rank + i 321 | main(args, init_distributed=True) 322 | 323 | 324 | def cli_main(): 325 | parser = options.get_training_parser() 326 | args = options.parse_args_and_arch(parser) 327 | 328 | if args.distributed_init_method is None: 329 | 
distributed_utils.infer_init_method(args) 330 | 331 | if args.distributed_init_method is not None: 332 | # distributed training 333 | if torch.cuda.device_count() > 1 and not args.distributed_no_spawn: 334 | start_rank = args.distributed_rank 335 | args.distributed_rank = None # assign automatically 336 | torch.multiprocessing.spawn( 337 | fn=distributed_main, 338 | args=(args, start_rank), 339 | nprocs=torch.cuda.device_count(), 340 | ) 341 | else: 342 | distributed_main(args.device_id, args) 343 | elif args.distributed_world_size > 1: 344 | # fallback for single node with multiple GPUs 345 | assert args.distributed_world_size <= torch.cuda.device_count() 346 | port = random.randint(10000, 20000) 347 | args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port) 348 | args.distributed_rank = None # set based on device id 349 | if max(args.update_freq) > 1 and args.ddp_backend != 'no_c10d': 350 | print('| NOTE: you may get better performance with: --ddp-backend=no_c10d') 351 | torch.multiprocessing.spawn( 352 | fn=distributed_main, 353 | args=(args, ), 354 | nprocs=args.distributed_world_size, 355 | ) 356 | else: 357 | # single GPU training 358 | main(args) 359 | 360 | 361 | if __name__ == '__main__': 362 | cli_main() 363 | -------------------------------------------------------------------------------- /models/baselines.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script for baseline approaches: 3 | 1. Take 10 random sentences 4 | 2. Take the topk 10 sentences with highest rouge score relative to the question 5 | 3. Pick the first 10 sentences 6 | 7 | To run 8 | python baselines.py --dataset=chiqa 9 | """ 10 | 11 | import json 12 | import numpy as np 13 | import random 14 | import requests 15 | import argparse 16 | 17 | import rouge 18 | import spacy 19 | 20 | 21 | def get_args(): 22 | """ 23 | Get command line arguments 24 | """ 25 | parser = argparse.ArgumentParser(description="Arguments for data exploration") 26 | parser.add_argument("--dataset", 27 | dest="dataset", 28 | help="Dataset to run baselines on. Only current option is MEDIQA-AnS.") 29 | return parser 30 | 31 | 32 | def calculate_sentence_level_rouge(question, doc_sen, evaluator): 33 | """ 34 | For each pair of sentences, calculate rouge score with py-rouge 35 | """ 36 | rouge_score = evaluator.get_scores(doc_sen, question)['rouge-l']['f'] 37 | return rouge_score 38 | 39 | 40 | def pick_k_best_rouge_sentences(k, questions, documents, summaries): 41 | """ 42 | Pick the k sentences that have the highest rouge scores when compared to the question. 
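    Each candidate sentence is scored against the question with ROUGE-L F-measure via py-rouge,
    and the k highest-scoring sentences are concatenated to form the extractive summary.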
43 | """ 44 | # Initiate rouge evaluator 45 | evaluator = rouge.Rouge(metrics=['rouge-l'], 46 | max_n=3, 47 | limit_length=False, 48 | length_limit_type='words', 49 | apply_avg=False, 50 | apply_best=True, 51 | alpha=1, 52 | weight_factor=1.2, 53 | stemming=False) 54 | 55 | pred_dict = { 56 | 'question': [], 57 | 'ref_summary': [], 58 | 'gen_summary': [] 59 | } 60 | for q, doc, summ, in zip(questions, documents, summaries): 61 | # Sentencize abstract 62 | rouge_scores = [] 63 | for sentence in doc: 64 | rouge_score = calculate_sentence_level_rouge(q, sentence, evaluator) 65 | rouge_scores.append(rouge_score) 66 | if len(doc) < k: 67 | top_k_rouge_scores = np.argsort(rouge_scores) 68 | else: 69 | top_k_rouge_scores = np.argsort(rouge_scores)[-k:] 70 | top_k_sentences = " ".join([doc[i] for i in top_k_rouge_scores]) 71 | summ = summ.replace("", "") 72 | summ = summ.replace("", "") 73 | pred_dict['question'].append(q) 74 | pred_dict['ref_summary'].append(summ) 75 | pred_dict['gen_summary'].append(top_k_sentences) 76 | 77 | return pred_dict 78 | 79 | 80 | def pick_first_k_sentences(k, questions, documents, summaries): 81 | """ 82 | Pick the first k sentences to use as summaries 83 | """ 84 | pred_dict = { 85 | 'question': [], 86 | 'ref_summary': [], 87 | 'gen_summary': [] 88 | } 89 | for q, doc, summ, in zip(questions, documents, summaries): 90 | if len(doc) < k: 91 | first_k_sentences = doc 92 | else: 93 | first_k_sentences = doc[0:k] 94 | first_k_sentences = " ".join(first_k_sentences) 95 | summ = summ.replace("", "") 96 | summ = summ.replace("", "") 97 | pred_dict['question'].append(q) 98 | pred_dict['ref_summary'].append(summ) 99 | pred_dict['gen_summary'].append(first_k_sentences) 100 | 101 | return pred_dict 102 | 103 | 104 | def pick_k_random_sentences(k, questions, documents, summaries): 105 | """ 106 | Pick k random sentences from the articles to use as summaries 107 | """ 108 | pred_dict = { 109 | 'question': [], 110 | 'ref_summary': [], 111 | 'gen_summary': [] 112 | } 113 | random.seed(13) 114 | for q, doc, summ, in zip(questions, documents, summaries): 115 | if len(doc) < k: 116 | random_sentences = " ".join(doc) 117 | else: 118 | random_sentences = random.sample(doc, k) 119 | random_sentences = " ".join(random_sentences) 120 | summ = summ.replace("", "") 121 | summ = summ.replace("", "") 122 | pred_dict['question'].append(q) 123 | pred_dict['ref_summary'].append(summ) 124 | pred_dict['gen_summary'].append(random_sentences) 125 | 126 | return pred_dict 127 | 128 | 129 | def load_dataset(path): 130 | """ 131 | Load the evaluation set 132 | """ 133 | with open(path, "r", encoding="utf-8") as f: 134 | asumm_data = json.load(f) 135 | 136 | summaries = [] 137 | questions = [] 138 | documents = [] 139 | nlp = spacy.load('en_core_web_sm') 140 | # Split sentences 141 | cnt = 0 142 | for q_id in asumm_data: 143 | questions.append(asumm_data[q_id]['question']) 144 | tokenized_art = nlp(asumm_data[q_id]['articles']) 145 | summaries.append(asumm_data[q_id]['summary']) 146 | article_sentences = [s.text.strip() for s in tokenized_art.sents] 147 | documents.append(article_sentences[0:]) 148 | return questions, documents, summaries 149 | 150 | 151 | def save_baseline(baseline, filename): 152 | """ 153 | Save baseline in format for rouge evaluation 154 | """ 155 | with open("../evaluation/data/baselines/chiqa_eval/baseline_{}.json".format(filename), "w", encoding="utf-8") as f: 156 | json.dump(baseline, f, indent=4) 157 | 158 | 159 | def run_baselines(): 160 | """ 161 | Generate the random baseline 
and the best rouge baseline 162 | """ 163 | # Load the MEDIQA-AnS datasets 164 | datasets = [ 165 | ("../data_processing/data/page2answer_single_abstractive_summ.json", "p2a-single-abs"), 166 | ("../data_processing/data/page2answer_single_extractive_summ.json", "p2a-single-ext"), 167 | ("../data_processing/data/section2answer_single_abstractive_summ.json", "s2a-single-abs"), 168 | ("../data_processing/data/section2answer_single_extractive_summ.json", "s2a-single-ext"), 169 | ] 170 | 171 | for data in datasets: 172 | task = data[1] 173 | print("Running baselines on {}".format(task)) 174 | # k can be determined from averages or medians of summary types of reference summaries. Alternatively, just use Lead-3 baseline. 175 | # Optional to use different k for extractive and abstractive summaries, as the manual summaries of the two types have different average lengths 176 | if task == "p2a-single-abs": 177 | k = 3 178 | if task == "p2a-single-ext": 179 | k = 3 180 | if task == "s2a-single-abs": 181 | k = 3 182 | if task == "s2a-single-ext": 183 | k = 3 184 | questions, documents, summaries = load_dataset(data[0]) 185 | k_sentences = pick_k_random_sentences(k, questions, documents, summaries) 186 | first_k_sentences = pick_first_k_sentences(k, questions, documents, summaries) 187 | k_best_rouge = pick_k_best_rouge_sentences(k, questions, documents, summaries) 188 | 189 | save_baseline(k_sentences, filename="random_sentences_k_{}_{}_{}".format(k, args.dataset, task)) 190 | save_baseline(first_k_sentences, filename="first_sentences_k_{}_{}_{}".format(k, args.dataset, task)) 191 | save_baseline(k_best_rouge, filename="best_rouge_k_{}_{}_{}".format(k, args.dataset, task)) 192 | 193 | 194 | if __name__ == "__main__": 195 | args = get_args().parse_args() 196 | run_baselines() 197 | -------------------------------------------------------------------------------- /models/bilstm/__pycache__/data.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/saverymax/qdriven-chiqa-summarization/257a00133869db47807b9dd10761a6dd3aa15306/models/bilstm/__pycache__/data.cpython-36.pyc -------------------------------------------------------------------------------- /models/bilstm/__pycache__/data.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/saverymax/qdriven-chiqa-summarization/257a00133869db47807b9dd10761a6dd3aa15306/models/bilstm/__pycache__/data.cpython-37.pyc -------------------------------------------------------------------------------- /models/bilstm/__pycache__/model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/saverymax/qdriven-chiqa-summarization/257a00133869db47807b9dd10761a6dd3aa15306/models/bilstm/__pycache__/model.cpython-36.pyc -------------------------------------------------------------------------------- /models/bilstm/__pycache__/model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/saverymax/qdriven-chiqa-summarization/257a00133869db47807b9dd10761a6dd3aa15306/models/bilstm/__pycache__/model.cpython-37.pyc -------------------------------------------------------------------------------- /models/bilstm/data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Module for data processing. 
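The loaders below expect JSON keyed by question id; illustrative entry shapes (key names taken
from load_data, values are placeholders) are:

    {"q_1": {"question": "...", "sentences": ["...", "..."], "labels": [1, 0]}}   # mode="train"
    {"q_1": {"question": "...", "articles": "...", "summary": "..."}}             # mode="infer" (chiqa/medinfo)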
3 | 4 | Includes classes for loading the data and generating examples 5 | """ 6 | 7 | import json 8 | import tensorflow_datasets as tfds 9 | import tensorflow as tf 10 | from sklearn.utils import shuffle as sk_shuffle 11 | import numpy as np 12 | import random 13 | import spacy 14 | 15 | 16 | class Vocab(): 17 | """ 18 | Class to generate vocab for model 19 | """ 20 | 21 | def train_tokenizer(self, tokenizer_filename, training_data, vocab_size): 22 | """ 23 | Train the subword tokenizer 24 | """ 25 | encoder = tfds.features.text.SubwordTextEncoder.build_from_corpus( 26 | (sentence for abstract in training_data for sentence in abstract), target_vocab_size=vocab_size) 27 | encoder.save_to_file(tokenizer_filename) 28 | 29 | def _load_tokenizer(self, tokenizer_filename): 30 | """ 31 | Load the trained subword tokenizer 32 | """ 33 | encoder = tfds.features.text.SubwordTextEncoder.load_from_file(tokenizer_filename) 34 | return encoder 35 | 36 | def encode_data(self, tokenizer_filename, data): 37 | encoder = self._load_tokenizer(tokenizer_filename) 38 | encoded_question = [encoder.encode(question) for question in data[0]] 39 | encoded_answer = [[encoder.encode(sentence) for sentence in abstract] for abstract in data[1]] 40 | assert len(encoded_question) == len(encoded_answer) 41 | data = [encoded_question, encoded_answer] 42 | return data 43 | 44 | 45 | class DataLoader(): 46 | """ 47 | Class to load data and generate train and validation datasets 48 | """ 49 | 50 | def __init__(self, data_path, max_tok_q, max_sentences, max_tok_sent, dataset, summary_type): 51 | """ 52 | Initiate loader 53 | """ 54 | self.data_path = data_path 55 | self.max_tok_q = max_tok_q 56 | self.max_sentences = max_sentences 57 | self.max_tok_sent = max_tok_sent 58 | self.dataset = dataset 59 | self.summary_type = summary_type 60 | 61 | def split_data(self, data): 62 | """ 63 | Shuffle and divide data up into training and validation sets 64 | """ 65 | questions, documents, scores = data 66 | # Shuffle data: 67 | assert len(questions) == len(documents) 68 | assert len(questions) == len(scores) 69 | documents, questions, scores = sk_shuffle(documents, questions, scores, random_state=13) 70 | training_index = int(len(questions) * .8) 71 | x_train = [questions[:training_index], documents[:training_index]] 72 | y_train = scores[:training_index] 73 | x_val = [questions[training_index:], documents[training_index:]] 74 | y_val = scores[training_index:] 75 | 76 | return x_train, y_train, x_val, y_val 77 | 78 | def pad_data(self, data, data_type, padding_type='float32', test=False): 79 | """ 80 | Method for padding data for training, validation, or inference. 81 | """ 82 | if data_type == "question": 83 | padded_data = tf.keras.preprocessing.sequence.pad_sequences(data, padding='post', maxlen=self.max_tok_q) 84 | elif data_type == "scores": 85 | padded_data = tf.keras.preprocessing.sequence.pad_sequences(data, padding='post', dtype=padding_type, maxlen=self.max_sentences) 86 | elif data_type == "article": 87 | padded_data = [] 88 | # Mask for sentences in document embedding 89 | #sentence_masks = [] 90 | for doc in data: 91 | if len(doc) > self.max_sentences: 92 | doc = doc[:self.max_sentences] 93 | #sentence_masks.append(np.ones(len(doc), dype=bool)) 94 | elif len(doc) < self.max_sentences: 95 | # Add the sentences that are missing. 
96 | extra_docs = [[] for i in range(self.max_sentences - len(doc))] 97 | doc.extend(extra_docs) 98 | #sentence_mask = [False if d == [] else True for d in doc] 99 | #sentence_masks.append(np.array(sentence_masks)) 100 | assert len(doc) == self.max_sentences 101 | padded_data.append(tf.keras.preprocessing.sequence.pad_sequences(doc, padding='post', maxlen=self.max_tok_sent)) 102 | 103 | # For asserting correct padding: 104 | if test: 105 | if data_type == "question": 106 | for q in padded_data: 107 | assert len(q) == self.max_tok_q, q 108 | if data_type == "article": 109 | #assert len(sentence_masks) == len(padded_data), len(sentence_masks) 110 | #for m in sentence_masks: 111 | # assert len(m) == self.max_sentences, len(m) 112 | for doc in padded_data: 113 | assert isinstance(doc, np.ndarray), type(doc) 114 | assert len(doc) == self.max_sentences, len(doc) 115 | for sent in doc: 116 | assert isinstance(sent, np.ndarray), type(sent) 117 | assert len(sent) == self.max_tok_sent, len(sent) 118 | assert isinstance(sent[0], np.int32), (sent[0], type(sent[0])) 119 | elif data_type == "scores": 120 | for doc in padded_data: 121 | assert isinstance(doc, np.ndarray), type(doc) 122 | assert len(doc) == self.max_sentences, len(doc) 123 | # Convert abstracts to np array here 124 | # pad_sequences converts lists of lists to np arrays, which is the form the questions 125 | # and sentences are in. 126 | # Reshape y_train and val from 2d > 3d 127 | if data_type == "scores": 128 | padded_data = np.reshape(padded_data, padded_data.shape + (1,)) 129 | return padded_data 130 | elif data_type == "article": 131 | return np.array(padded_data) 132 | else: 133 | return padded_data 134 | 135 | def load_data(self, mode, tag_sentences): 136 | """ 137 | Open the data and split into train/val. 138 | """ 139 | questions = [] 140 | documents = [] 141 | with open(self.data_path, "r", encoding="utf-8") as f: 142 | asumm_data = json.load(f) 143 | 144 | if mode == "train": 145 | scores = [] 146 | for q_id in asumm_data: 147 | questions.append(asumm_data[q_id]['question']) 148 | documents.append(asumm_data[q_id]['sentences']) 149 | scores.append(asumm_data[q_id]['labels']) 150 | return questions, documents, scores 151 | 152 | if mode == "infer": 153 | summaries = [] 154 | nlp = spacy.load('en_core_web_sm') 155 | if self.dataset == "chiqa" or self.dataset=="medinfo": 156 | question_ids = [] 157 | # There are multiple summary tasks withing the chiqa dataset: single and multi doc require different processing. 
158 | # Handle the single answer -> summary case: 159 | if "single" in self.summary_type: 160 | for q_id in asumm_data: 161 | question_ids.append(q_id) 162 | question = asumm_data[q_id]['question'] 163 | questions.append(question) 164 | summary = asumm_data[q_id]['summary'] 165 | tokenized_art = nlp(asumm_data[q_id]['articles']) 166 | tokenized_summ = nlp(summary) 167 | # Split sentences and tag with s if option included for pointer generator 168 | if tag_sentences: 169 | summary = " ".join([" {s} ".format(s=s.text.strip()) for s in tokenized_summ.sents]) 170 | summaries.append(summary) 171 | article_sentences = [s.text.strip() for s in tokenized_art.sents] 172 | documents.append(article_sentences) 173 | return questions, documents, summaries, question_ids 174 | 175 | 176 | if __name__ == "__main__": 177 | # Locally testing data classes: 178 | max_tok_q = 20 179 | max_sentences = 75 180 | max_tok_sent = 20 181 | hidden_dim = 256 182 | dataset = "bioasq" 183 | tag_sentences = False 184 | mode = "train" 185 | data_loader = DataLoader("/data/saveryme/asumm/asumm_data/training_data/bioasq_abs2summ_sent_classification_training.json", max_tok_q, max_sentences, max_tok_sent, dataset) 186 | data = data_loader.load_data(mode, tag_sentences) 187 | x_train, y_train, x_val, y_val = data_loader.split_data(data) 188 | # x_train[0] = x_train[0][:3] 189 | # x_train[1] = x_train[1][:3] 190 | # x_val[0] = x_val[0][:3] 191 | # x_val[1] = x_val[1][:3] 192 | # y_train = y_train[:3] 193 | # y_val = y_val[:3] 194 | 195 | # print("Text!") 196 | print("Questions:\n", x_train[0][0]) 197 | print("Sentences:\n", x_train[1][0]) 198 | print("encoding data") 199 | x_train = Vocab().encode_data("medsumm_bioasq_abs2summ/tokenizer", x_train) 200 | 201 | sent_cnt = 0 202 | sent_len = 0 203 | for doc in x_train[1]: 204 | for sentence in doc: 205 | sent_cnt += 1 206 | sent_len += len(sentence) 207 | 208 | x_val = Vocab().encode_data("medsumm_bioasq_abs2summ/tokenizer", x_val) 209 | print("padding data") 210 | x_train[0] = data_loader.pad_data(x_train[0], data_type="question", test=True) 211 | x_train[1] = data_loader.pad_data(x_train[1], data_type="article", test=True) 212 | x_val[0] = data_loader.pad_data(x_val[0], data_type="question", test=True) 213 | x_val[1] = data_loader.pad_data(x_val[1], data_type="article", test=True) 214 | y_train = data_loader.pad_data(y_train, data_type="scores", test=True) 215 | y_val = data_loader.pad_data(y_val, data_type="scores", test=True) 216 | assert x_train[0].shape == (30532, max_tok_q), x_train[0].shape 217 | assert x_train[1].shape == (30532, max_sentences, max_tok_sent), x_train[1].shape 218 | assert x_val[0].shape == (7634, max_tok_q), x_val[0].shape 219 | assert x_val[1].shape == (7634, max_sentences, max_tok_sent), x_val[1].shape 220 | assert y_train.shape == (30532, max_sentences, 1), y_train.shape 221 | assert y_val.shape == (7634, max_sentences, 1), y_val.shape 222 | print("Avg. 
subwords/sentence:", sent_len / sent_cnt) 223 | -------------------------------------------------------------------------------- /models/bilstm/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Model module for constructing the tensorflow graph for the LSTM sentence classifier 3 | """ 4 | 5 | import tensorflow as tf 6 | from tensorflow.keras import layers 7 | 8 | class SentenceClassificationModel(): 9 | """ 10 | Class for model selecting sentences from relevant documents for answer summarization 11 | """ 12 | 13 | def __init__(self, vocab_size, batch_size, hidden_dim, dropout, max_tok_q, max_sentences, max_tok_sent): 14 | """ 15 | Initiate the mode 16 | """ 17 | self.vocab_size = vocab_size 18 | self.batch_size = batch_size 19 | self.hidden_dim = hidden_dim 20 | self.max_tok_q = max_tok_q 21 | self.max_tok_sent = max_tok_sent 22 | self.max_sentences = max_sentences 23 | self.dropout = dropout 24 | 25 | def build_binary_model(self): 26 | """ 27 | Construct the graph of the model 28 | """ 29 | question_input = tf.keras.Input(shape=(self.max_tok_q, ), name='q_input') 30 | abstract_input = tf.keras.Input(shape=(self.max_sentences, self.max_tok_sent, ), name='abs_input') 31 | # NOT USING MASKING DUE TO CUDNN ERROR: https://github.com/tensorflow/tensorflow/issues/33148 32 | x1 = layers.Embedding(input_dim=self.vocab_size, output_dim=self.hidden_dim, mask_zero=False)(question_input) 33 | x1 = layers.Bidirectional(layers.LSTM(self.hidden_dim, dropout=self.dropout, kernel_regularizer=tf.keras.regularizers.l2(0.01)), input_shape=(self.max_tok_q, self.hidden_dim), name='q_bilstm')(x1) 34 | 35 | # Apply embedding to every sentence 36 | x2 = layers.TimeDistributed(layers.Embedding(input_dim=self.vocab_size, output_dim=self.hidden_dim, input_length=self.max_tok_sent, mask_zero=False), input_shape=(self.max_sentences, self.max_tok_sent))(abstract_input) 37 | # Apply lstm to every sentence embedding 38 | x2 = layers.TimeDistributed(layers.Bidirectional(layers.LSTM(self.hidden_dim, dropout=self.dropout, kernel_regularizer=tf.keras.regularizers.l2(0.01))), input_shape=(self.max_sentences, self.max_tok_sent, self.hidden_dim), name='sentence_distributed_bilstms')(x2) 39 | # Make lstm of document representation: 40 | # I could also just take this document representation and concatenate it to the single sentence representation, but I don't. 
41 | x2 = layers.Bidirectional(layers.LSTM(self.hidden_dim, return_sequences=True, dropout=self.dropout, kernel_regularizer=tf.keras.regularizers.l2(0.01)), input_shape=(self.max_sentences, self.hidden_dim * 2), name='document_bilstm')(x2) 42 | # Combine question and document 43 | x3 = layers.RepeatVector(self.max_sentences)(x1) 44 | x4 = layers.concatenate([x2, x3]) 45 | 46 | # If using integers as class labels, 1 target label can be provided be example (not 1 hot) and the number of labels can be defined here 47 | sent_output = layers.Dense(2, activation='sigmoid', name='sent_output', kernel_regularizer=tf.keras.regularizers.l2(0.01))(x4) 48 | 49 | model = tf.keras.Model(inputs=[question_input, abstract_input], outputs=sent_output) 50 | model.summary() 51 | model.compile(loss='sparse_categorical_crossentropy', 52 | optimizer=tf.keras.optimizers.Adam(1e-4) 53 | ) 54 | 55 | return model 56 | -------------------------------------------------------------------------------- /models/bilstm/requirements.txt: -------------------------------------------------------------------------------- 1 | spacy 2 | numpy 3 | tensorflow-datasets 4 | sklearn 5 | -------------------------------------------------------------------------------- /models/bilstm/run_chiqa.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | exp_name=medsumm_bioasq_abs2summ 3 | tensorboard_log=dropout_5_sent_200_tok_50_val_20_d_256_l2_reg_binary 4 | 5 | for summ_task in page2answer section2answer 6 | do 7 | for summ_type in single_abstractive single_extractive 8 | do 9 | for k in 3 10 | do 11 | # Optional to specify different k, for selecting top k sentences. If you are interested in using the output of the sentence classifier as input for a generative model, 12 | # you may want to increase k 13 | echo "k:" $k 14 | data=../../data_processing/data/${summ_task}_${summ_type}_summ.json 15 | echo $data 16 | prediction_file=predictions_chiqa_${summ_task}_${summ_type}_dropout_5_sent_200_tok_50_val_20_d_256_l2_reg_binary_topk${k}.json 17 | eval_file=../../evaluation/data/sentence_classifier/chiqa_eval/sent_class_chiqa_${summ_task}_${summ_type}_dropout_5_sent_200_tok_50_val_20_d_256_l2_reg_binary_topk${k}.json 18 | python run_classifier.py \ 19 | --exp_name=${exp_name} \ 20 | --mode=infer \ 21 | --data_path=$data \ 22 | --dataset=chiqa \ 23 | --summary_type=${summ_task}_${summ_type} \ 24 | --train_tokenizer=False \ 25 | --tokenizer_path=./${exp_name}/tokenizer \ 26 | --model_path=${exp_name}/${tensorboard_log}/bioasq_abs2summ_sent_class_model.h5 \ 27 | --batch_size=32 \ 28 | --max_sentences=200 \ 29 | --max_tok_sent=50 \ 30 | --max_tok_q=50 \ 31 | --dropout=.5 \ 32 | --hidden_dim=256 \ 33 | --binary_model=True \ 34 | --prediction_file=${exp_name}/${tensorboard_log}/${prediction_file} \ 35 | --eval_file=$eval_file \ 36 | --tag_sentences=False \ 37 | --top_k_sent=${k} 38 | done 39 | done 40 | done 41 | -------------------------------------------------------------------------------- /models/bilstm/run_classifier.py: -------------------------------------------------------------------------------- 1 | """ 2 | Runner script to train a simple lstm sentence classifier, 3 | using best rouge score for a sentence in the article compared to each sentence in the summary 4 | as the y labels. 5 | 6 | Should I create labels with rouge or with n sentences selected based on highes score. 
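Typical training invocation (mirroring train_sentence_classifier.sh, after creating the
experiment directories; the paths and hyperparameters shown are the ones used there):

    python run_classifier.py --mode=train --exp_name=medsumm_bioasq_abs2summ \
        --tensorboard_log=medsumm_bioasq_abs2summ/dropout_5_sent_200_tok_50_val_20_d_256_l2_reg_binary \
        --data_path=../../data_processing/data/bioasq_abs2summ_binary_sent_classification_training.json \
        --train_tokenizer=True --tokenizer_path=./medsumm_bioasq_abs2summ/tokenizer \
        --model_path=bioasq_abs2summ_sent_class_model.h5 \
        --batch_size=32 --max_sentences=200 --max_tok_sent=50 --max_tok_q=50 \
        --dropout=.5 --hidden_dim=256 --binary_model=True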
7 | """ 8 | 9 | import time 10 | from absl import app, flags 11 | import logging 12 | import os 13 | import sys 14 | import json 15 | import re 16 | 17 | import numpy as np 18 | import tensorflow as tf 19 | 20 | from data import Vocab, DataLoader 21 | from model import SentenceClassificationModel 22 | 23 | FLAGS = flags.FLAGS 24 | 25 | # Paths 26 | flags.DEFINE_string('data_path', '', 'Path expression to tf.Example datafiles. Can include wildcards to access multiple datafiles.') 27 | flags.DEFINE_string('mode', 'train', 'Select train/infer') 28 | flags.DEFINE_string('exp_name', '', 'Name for experiment. Tensorboard logs and models will be saved in a directories under this one') 29 | flags.DEFINE_string('tensorboard_log', '', 'Name of specific run for tb to log to. The experiment name will be the parent directory') 30 | flags.DEFINE_string('model_path', '', "Path to save model checkpoint") 31 | flags.DEFINE_string('dataset', '', "The dataset used for inference, used to specify how to load the data, e.g., medinfo") 32 | flags.DEFINE_string('summary_type', '', "The summary task within the chiqa dataset. The multi and single document tasks require different data handling") 33 | flags.DEFINE_string('prediction_file', '', "File to save predictions to be used for generative model") 34 | flags.DEFINE_string('eval_file', '', "File to save preds for direct evaluation") 35 | 36 | # Tokenizer training 37 | flags.DEFINE_string('tokenizer_path', '', 'Path to save tokenizer once trained') 38 | flags.DEFINE_boolean('train_tokenizer', False, "Flag to train a new tokenizer on the training corpus") 39 | 40 | # Data processing 41 | flags.DEFINE_boolean("tag_sentences", False, "For use with mode=infer. Tag the article sentences with and , when using pointer generator network as second step, if the data has not already been tagged") 42 | 43 | # Hyperparameters and such 44 | flags.DEFINE_boolean('binary_model', False, "Flag to use binary model or regression model") 45 | flags.DEFINE_integer('vocab_size', 2**15, 'Size of subword vocabulary. Not sure if this is what I am definitely going to do. May be better to use pretrained embeddings') 46 | flags.DEFINE_integer('batch_size', 32, 'batch size') 47 | flags.DEFINE_integer('n_epochs', 10, 'Max number of epochs to run') 48 | flags.DEFINE_integer('max_tok_q', 100, 'max number of tokens for question') 49 | flags.DEFINE_integer('max_sentences', 10, 'max number of sentences') 50 | flags.DEFINE_integer('max_tok_sent', 100, 'max number of subword tokens/sentence') 51 | flags.DEFINE_integer('hidden_dim', 128, 'dimension of lstm') 52 | flags.DEFINE_float('dropout', .2, 'dropout proportion') 53 | flags.DEFINE_float('decision_threshold', .3, "Threshold for selecting relevant sentences during inference. No current implementation") 54 | flags.DEFINE_integer('top_k_sent', 10, "Number of sentences to select.") 55 | 56 | 57 | def run_training(model, x_train, y_train, x_val, y_val): 58 | """ 59 | Run training in loop for n_epochs 60 | """ 61 | logging.info("Beginning training\n") 62 | tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=FLAGS.tensorboard_log, update_freq=5000, profile_batch=0) 63 | checkpoint_callback = tf.keras.callbacks.ModelCheckpoint( 64 | filepath='{0}/{1}'.format(FLAGS.tensorboard_log, FLAGS.model_path), 65 | # overwrite current checkpoint if `val_loss` has improved. 
66 | save_best_only=True, 67 | monitor='val_loss', 68 | save_freq='epoch', 69 | verbose=0) 70 | 71 | model.fit({'q_input': x_train[0], 'abs_input': x_train[1]}, {'sent_output': y_train}, 72 | shuffle=True, 73 | epochs=FLAGS.n_epochs, 74 | validation_data=(x_val, y_val), 75 | callbacks=[tensorboard_callback, checkpoint_callback]) 76 | logging.info("Training completed!\n") 77 | 78 | 79 | def run_inference(model, data, threshold, top_k_sent, binary_model): 80 | """ 81 | Given a set of documents and a question, predict relevant sentences 82 | """ 83 | predictions = model.predict({'q_input': data[0], 'abs_input': data[1]}) 84 | logging.info("Predictions shape: {}".format(predictions.shape)) 85 | if binary_model: 86 | # Select preds for just label 1 87 | reduced_preds = predictions[:, :, 1] 88 | reduced_preds = tf.squeeze(reduced_preds) 89 | filtered_predictions = tf.math.top_k(reduced_preds, k=top_k_sent, sorted=True) 90 | else: 91 | reduced_preds = tf.squeeze(predictions) 92 | filtered_predictions = tf.math.top_k(reduced_preds, k=top_k_sent) 93 | return filtered_predictions 94 | 95 | 96 | def save_predictions(predictions, data, binary_model): 97 | """ 98 | Get the indicies of sentences predicted to be above rouge threshold and save to json. 99 | """ 100 | pred_dict = {} 101 | # Also save the data in format for direct evaluation 102 | questions = [] 103 | ref_summaries = [] 104 | gen_summaries = [] 105 | question_ids = [] 106 | pred_dict = {} 107 | q_cnt = 0 108 | for p, indices in zip(predictions.values, predictions.indices): 109 | question = data[0][q_cnt] 110 | question_id = data[3][q_cnt] 111 | pred_dict[question_id] = {} 112 | # Get the human generated summary 113 | ref_summary = data[2][q_cnt] 114 | pred_dict[question_id]['summary'] = ref_summary 115 | pred_dict[question_id]['question'] = question 116 | # Remove any predicted indices that are out of the article's range 117 | indices = indices.numpy()[indices.numpy() < len(data[1][q_cnt])] 118 | sentences = " ".join(list(np.array(data[1][q_cnt])[indices])) 119 | pred_dict[question_id]['articles'] = sentences 120 | pred_dict[question_id]['predicted_score'] = p.numpy().tolist() 121 | # Format for rouge evaluation 122 | questions.append(question) 123 | question_ids.append(question_id) 124 | # Remove any sentence tags from data that will be evaluated directly and not passed to pg 125 | ref_summary = ref_summary.replace("", "") 126 | ref_summary = ref_summary.replace("", "") 127 | ref_summaries.append(ref_summary) 128 | gen_summaries.append(sentences) 129 | q_cnt += 1 130 | 131 | predictions_for_eval = {'question_id': question_ids, 'question': questions, 'ref_summary': ref_summaries, 'gen_summary': gen_summaries} 132 | 133 | with open(FLAGS.prediction_file, "w", encoding="utf-8") as f: 134 | json.dump(pred_dict, f, indent=4) 135 | 136 | with open(FLAGS.eval_file, "w", encoding="utf-8") as f: 137 | json.dump(predictions_for_eval, f, indent=4) 138 | 139 | 140 | def main(argv): 141 | """ 142 | Main function for running sentence classifier 143 | """ 144 | logging.info("Num GPUs Available: {}\n".format(len(tf.config.experimental.list_physical_devices('GPU')))) 145 | logging.basicConfig(filename="{}/medsumm.log".format(FLAGS.exp_name), filemode='w', level=logging.DEBUG) 146 | logging.info("Initiating sentence classication model...\n") 147 | logging.info("Loading data:") 148 | data_loader = DataLoader(FLAGS.data_path, FLAGS.max_tok_q, FLAGS.max_sentences, FLAGS.max_tok_sent, FLAGS.dataset, FLAGS.summary_type) 149 | # Returns tuple 150 | data = 
data_loader.load_data(FLAGS.mode, FLAGS.tag_sentences) 151 | if FLAGS.mode == "train": 152 | x_train, y_train, x_val, y_val = data_loader.split_data(data) 153 | logging.info("Questions:") 154 | logging.info(x_train[0][:2]) 155 | logging.info("Sentences:") 156 | logging.info(x_train[1][:2]) 157 | 158 | vocab_processor = Vocab() 159 | if FLAGS.train_tokenizer: 160 | logging.info("Training tokenizer\n") 161 | vocab_processor.train_tokenizer(FLAGS.tokenizer_path, data[1], FLAGS.vocab_size) 162 | # Once trained, get the subword tokenizer and encode the data 163 | logging.info("Encoding text") 164 | if FLAGS.mode == "train": 165 | logging.info("Encoding data") 166 | x_train = vocab_processor.encode_data(FLAGS.tokenizer_path, x_train) 167 | x_val = vocab_processor.encode_data(FLAGS.tokenizer_path, x_val) 168 | logging.info("Padding encodings") 169 | x_train[0] = data_loader.pad_data(x_train[0], data_type="question", test=True) 170 | x_train[1] = data_loader.pad_data(x_train[1], data_type="article", test=True) 171 | x_val[0] = data_loader.pad_data(x_val[0], data_type="question", test=True) 172 | x_val[1] = data_loader.pad_data(x_val[1], data_type="article", test=True) 173 | if FLAGS.binary_model: 174 | padding_type = "int32" 175 | else: 176 | padding_type = "float32" 177 | y_train = data_loader.pad_data(y_train, data_type="scores", padding_type=padding_type, test=True) 178 | y_val = data_loader.pad_data(y_val, data_type="scores", padding_type=padding_type, test=True) 179 | logging.info("Data shape:") 180 | logging.info("x_train questions: {}".format(x_train[0].shape)) 181 | logging.info("x_train documents: {}".format(x_train[1].shape)) 182 | logging.info("x_val questions: {}".format(x_val[0].shape)) 183 | logging.info("x_val documents: {}".format(x_val[1].shape)) 184 | logging.info("y_train: {}".format(y_train.shape)) 185 | logging.info("y_val: {}".format(y_val.shape)) 186 | 187 | logging.info("Question encoding") 188 | logging.info(x_train[0][:2]) 189 | logging.info("Sentence encoding") 190 | logging.info(x_train[1][:2]) 191 | logging.info("Rouge scores:") 192 | logging.info(y_train[0][:2]) 193 | 194 | model = SentenceClassificationModel( 195 | FLAGS.vocab_size, FLAGS.batch_size, FLAGS.hidden_dim, FLAGS.dropout, 196 | FLAGS.max_tok_q, FLAGS.max_sentences, FLAGS.max_tok_sent 197 | ) 198 | 199 | if FLAGS.binary_model: 200 | model = model.build_binary_model() 201 | run_training(model, x_train, y_train, x_val, y_val) 202 | else: 203 | model = model.build_model() 204 | run_training(model, x_train, y_train, x_val, y_val) 205 | 206 | if FLAGS.mode == "infer": 207 | encoded_data = vocab_processor.encode_data(FLAGS.tokenizer_path, data) 208 | padded_data = [] 209 | padded_data.append(data_loader.pad_data(encoded_data[0], "question")) 210 | padded_data.append(data_loader.pad_data(encoded_data[1], "article")) 211 | logging.info("N questions: {}".format(len(padded_data[0]))) 212 | logging.info("N documents: {}".format(len(padded_data[1]))) 213 | logging.info("Question encoding") 214 | logging.info(padded_data[0][:2]) 215 | logging.info("Sentence encoding") 216 | logging.info(padded_data[1][:2]) 217 | 218 | logging.info("Loading model") 219 | model = tf.keras.models.load_model(FLAGS.model_path) 220 | predictions = run_inference(model, padded_data, FLAGS.decision_threshold, FLAGS.top_k_sent, FLAGS.binary_model) 221 | save_predictions(predictions, data, FLAGS.binary_model) 222 | 223 | 224 | if __name__ == "__main__": 225 | app.run(main) 226 | 
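A minimal, self-contained sketch (toy data only, not part of the repository) of the top-k selection performed in run_inference() and save_predictions() above: probabilities for the positive class are ranked with tf.math.top_k, indices that point at padded sentence slots are dropped, and the remaining sentences are joined into an extractive summary.

    import numpy as np
    import tensorflow as tf

    sentences = ["Aspirin is an NSAID.", "It can reduce fever.", "Ask a pharmacist if unsure."]
    # Toy per-sentence probabilities for label 1, padded out to max_sentences=5
    class_1_probs = tf.constant([0.90, 0.15, 0.70, 0.00, 0.00])
    top_k = tf.math.top_k(class_1_probs, k=2, sorted=True)
    indices = top_k.indices.numpy()
    indices = indices[indices < len(sentences)]  # discard hits on padded positions
    extractive_summary = " ".join(np.array(sentences)[indices])
    print(extractive_summary)  # the two highest-scoring sentences, ordered by score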
-------------------------------------------------------------------------------- /models/bilstm/train_sentence_classifier.sh: -------------------------------------------------------------------------------- 1 | exp_name=medsumm_bioasq_abs2summ 2 | tensorboard_log=dropout_5_sent_200_tok_50_val_20_d_256_l2_reg_binary 3 | mkdir $exp_name 4 | mkdir ${exp_name}/${tensorboard_log} 5 | #training_data=/data/saveryme/asumm/asumm_data/training_data/bioasq_abs2summ_binary_sent_classification_training.json 6 | training_data=../../data_processing/data/bioasq_abs2summ_binary_sent_classification_training.json 7 | max_sent=200 8 | binary_model=True 9 | # Note that for first run the sub-word tokenizer has to be trained first. For all subsequent runs, can be left false. 10 | train_tokenizer=True 11 | 12 | #--tokenizer_path=/data/saveryme/asumm/models/sentence_classifier/${exp_name}/tokenizer \ 13 | python run_classifier.py \ 14 | --exp_name=${exp_name} \ 15 | --tensorboard_log=${exp_name}/${tensorboard_log} \ 16 | --mode=train \ 17 | --data_path=$training_data \ 18 | --train_tokenizer=$train_tokenizer \ 19 | --tokenizer_path=./${exp_name}/tokenizer \ 20 | --model_path=bioasq_abs2summ_sent_class_model.h5 \ 21 | --batch_size=32 \ 22 | --max_sentences=$max_sent \ 23 | --max_tok_sent=50 \ 24 | --max_tok_q=50 \ 25 | --dropout=.5 \ 26 | --hidden_dim=256 \ 27 | --binary_model=$binary_model 28 | -------------------------------------------------------------------------------- /models/pointer_generator/LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright 2017 The TensorFlow Authors. All rights reserved. 2 | Modifications Copyright 2017 Abigail See 3 | 4 | 5 | Apache License 6 | Version 2.0, January 2004 7 | http://www.apache.org/licenses/ 8 | 9 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 10 | 11 | 1. Definitions. 12 | 13 | "License" shall mean the terms and conditions for use, reproduction, 14 | and distribution as defined by Sections 1 through 9 of this document. 15 | 16 | "Licensor" shall mean the copyright owner or entity authorized by 17 | the copyright owner that is granting the License. 18 | 19 | "Legal Entity" shall mean the union of the acting entity and all 20 | other entities that control, are controlled by, or are under common 21 | control with that entity. For the purposes of this definition, 22 | "control" means (i) the power, direct or indirect, to cause the 23 | direction or management of such entity, whether by contract or 24 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 25 | outstanding shares, or (iii) beneficial ownership of such entity. 26 | 27 | "You" (or "Your") shall mean an individual or Legal Entity 28 | exercising permissions granted by this License. 29 | 30 | "Source" form shall mean the preferred form for making modifications, 31 | including but not limited to software source code, documentation 32 | source, and configuration files. 33 | 34 | "Object" form shall mean any form resulting from mechanical 35 | transformation or translation of a Source form, including but 36 | not limited to compiled object code, generated documentation, 37 | and conversions to other media types. 38 | 39 | "Work" shall mean the work of authorship, whether in Source or 40 | Object form, made available under the License, as indicated by a 41 | copyright notice that is included in or attached to the work 42 | (an example is provided in the Appendix below). 
43 | 44 | "Derivative Works" shall mean any work, whether in Source or Object 45 | form, that is based on (or derived from) the Work and for which the 46 | editorial revisions, annotations, elaborations, or other modifications 47 | represent, as a whole, an original work of authorship. For the purposes 48 | of this License, Derivative Works shall not include works that remain 49 | separable from, or merely link (or bind by name) to the interfaces of, 50 | the Work and Derivative Works thereof. 51 | 52 | "Contribution" shall mean any work of authorship, including 53 | the original version of the Work and any modifications or additions 54 | to that Work or Derivative Works thereof, that is intentionally 55 | submitted to Licensor for inclusion in the Work by the copyright owner 56 | or by an individual or Legal Entity authorized to submit on behalf of 57 | the copyright owner. For the purposes of this definition, "submitted" 58 | means any form of electronic, verbal, or written communication sent 59 | to the Licensor or its representatives, including but not limited to 60 | communication on electronic mailing lists, source code control systems, 61 | and issue tracking systems that are managed by, or on behalf of, the 62 | Licensor for the purpose of discussing and improving the Work, but 63 | excluding communication that is conspicuously marked or otherwise 64 | designated in writing by the copyright owner as "Not a Contribution." 65 | 66 | "Contributor" shall mean Licensor and any individual or Legal Entity 67 | on behalf of whom a Contribution has been received by Licensor and 68 | subsequently incorporated within the Work. 69 | 70 | 2. Grant of Copyright License. Subject to the terms and conditions of 71 | this License, each Contributor hereby grants to You a perpetual, 72 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 73 | copyright license to reproduce, prepare Derivative Works of, 74 | publicly display, publicly perform, sublicense, and distribute the 75 | Work and such Derivative Works in Source or Object form. 76 | 77 | 3. Grant of Patent License. Subject to the terms and conditions of 78 | this License, each Contributor hereby grants to You a perpetual, 79 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 80 | (except as stated in this section) patent license to make, have made, 81 | use, offer to sell, sell, import, and otherwise transfer the Work, 82 | where such license applies only to those patent claims licensable 83 | by such Contributor that are necessarily infringed by their 84 | Contribution(s) alone or by combination of their Contribution(s) 85 | with the Work to which such Contribution(s) was submitted. If You 86 | institute patent litigation against any entity (including a 87 | cross-claim or counterclaim in a lawsuit) alleging that the Work 88 | or a Contribution incorporated within the Work constitutes direct 89 | or contributory patent infringement, then any patent licenses 90 | granted to You under this License for that Work shall terminate 91 | as of the date such litigation is filed. 92 | 93 | 4. Redistribution. 
You may reproduce and distribute copies of the 94 | Work or Derivative Works thereof in any medium, with or without 95 | modifications, and in Source or Object form, provided that You 96 | meet the following conditions: 97 | 98 | (a) You must give any other recipients of the Work or 99 | Derivative Works a copy of this License; and 100 | 101 | (b) You must cause any modified files to carry prominent notices 102 | stating that You changed the files; and 103 | 104 | (c) You must retain, in the Source form of any Derivative Works 105 | that You distribute, all copyright, patent, trademark, and 106 | attribution notices from the Source form of the Work, 107 | excluding those notices that do not pertain to any part of 108 | the Derivative Works; and 109 | 110 | (d) If the Work includes a "NOTICE" text file as part of its 111 | distribution, then any Derivative Works that You distribute must 112 | include a readable copy of the attribution notices contained 113 | within such NOTICE file, excluding those notices that do not 114 | pertain to any part of the Derivative Works, in at least one 115 | of the following places: within a NOTICE text file distributed 116 | as part of the Derivative Works; within the Source form or 117 | documentation, if provided along with the Derivative Works; or, 118 | within a display generated by the Derivative Works, if and 119 | wherever such third-party notices normally appear. The contents 120 | of the NOTICE file are for informational purposes only and 121 | do not modify the License. You may add Your own attribution 122 | notices within Derivative Works that You distribute, alongside 123 | or as an addendum to the NOTICE text from the Work, provided 124 | that such additional attribution notices cannot be construed 125 | as modifying the License. 126 | 127 | You may add Your own copyright statement to Your modifications and 128 | may provide additional or different license terms and conditions 129 | for use, reproduction, or distribution of Your modifications, or 130 | for any such Derivative Works as a whole, provided Your use, 131 | reproduction, and distribution of the Work otherwise complies with 132 | the conditions stated in this License. 133 | 134 | 5. Submission of Contributions. Unless You explicitly state otherwise, 135 | any Contribution intentionally submitted for inclusion in the Work 136 | by You to the Licensor shall be under the terms and conditions of 137 | this License, without any additional terms or conditions. 138 | Notwithstanding the above, nothing herein shall supersede or modify 139 | the terms of any separate license agreement you may have executed 140 | with Licensor regarding such Contributions. 141 | 142 | 6. Trademarks. This License does not grant permission to use the trade 143 | names, trademarks, service marks, or product names of the Licensor, 144 | except as required for reasonable and customary use in describing the 145 | origin of the Work and reproducing the content of the NOTICE file. 146 | 147 | 7. Disclaimer of Warranty. Unless required by applicable law or 148 | agreed to in writing, Licensor provides the Work (and each 149 | Contributor provides its Contributions) on an "AS IS" BASIS, 150 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 151 | implied, including, without limitation, any warranties or conditions 152 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 153 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 154 | appropriateness of using or redistributing the Work and assume any 155 | risks associated with Your exercise of permissions under this License. 156 | 157 | 8. Limitation of Liability. In no event and under no legal theory, 158 | whether in tort (including negligence), contract, or otherwise, 159 | unless required by applicable law (such as deliberate and grossly 160 | negligent acts) or agreed to in writing, shall any Contributor be 161 | liable to You for damages, including any direct, indirect, special, 162 | incidental, or consequential damages of any character arising as a 163 | result of this License or out of the use or inability to use the 164 | Work (including but not limited to damages for loss of goodwill, 165 | work stoppage, computer failure or malfunction, or any and all 166 | other commercial damages or losses), even if such Contributor 167 | has been advised of the possibility of such damages. 168 | 169 | 9. Accepting Warranty or Additional Liability. While redistributing 170 | the Work or Derivative Works thereof, You may choose to offer, 171 | and charge a fee for, acceptance of support, warranty, indemnity, 172 | or other liability obligations and/or rights consistent with this 173 | License. However, in accepting such obligations, You may act only 174 | on Your own behalf and on Your sole responsibility, not on behalf 175 | of any other Contributor, and only if You agree to indemnify, 176 | defend, and hold each Contributor harmless for any liability 177 | incurred by, or claims asserted against, such Contributor by reason 178 | of your accepting any such warranty or additional liability. 179 | 180 | END OF TERMS AND CONDITIONS 181 | 182 | APPENDIX: How to apply the Apache License to your work. 183 | 184 | To apply the Apache License to your work, attach the following 185 | boilerplate notice, with the fields enclosed by brackets "[]" 186 | replaced with your own identifying information. (Don't include 187 | the brackets!) The text should be enclosed in the appropriate 188 | comment syntax for the file format. We also recommend that a 189 | file or class name and description of purpose be included on the 190 | same "printed page" as the copyright notice for easier 191 | identification within third-party archives. 192 | 193 | Copyright 2017, The TensorFlow Authors. 194 | Modifications Copyright 2017 Abigail See 195 | 196 | Licensed under the Apache License, Version 2.0 (the "License"); 197 | you may not use this file except in compliance with the License. 198 | You may obtain a copy of the License at 199 | 200 | http://www.apache.org/licenses/LICENSE-2.0 201 | 202 | Unless required by applicable law or agreed to in writing, software 203 | distributed under the License is distributed on an "AS IS" BASIS, 204 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 205 | See the License for the specific language governing permissions and 206 | limitations under the License. 
207 | -------------------------------------------------------------------------------- /models/pointer_generator/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/saverymax/qdriven-chiqa-summarization/257a00133869db47807b9dd10761a6dd3aa15306/models/pointer_generator/__init__.py -------------------------------------------------------------------------------- /models/pointer_generator/__pycache__/attention_decoder.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/saverymax/qdriven-chiqa-summarization/257a00133869db47807b9dd10761a6dd3aa15306/models/pointer_generator/__pycache__/attention_decoder.cpython-36.pyc -------------------------------------------------------------------------------- /models/pointer_generator/__pycache__/batcher.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/saverymax/qdriven-chiqa-summarization/257a00133869db47807b9dd10761a6dd3aa15306/models/pointer_generator/__pycache__/batcher.cpython-36.pyc -------------------------------------------------------------------------------- /models/pointer_generator/__pycache__/beam_search.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/saverymax/qdriven-chiqa-summarization/257a00133869db47807b9dd10761a6dd3aa15306/models/pointer_generator/__pycache__/beam_search.cpython-36.pyc -------------------------------------------------------------------------------- /models/pointer_generator/__pycache__/data.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/saverymax/qdriven-chiqa-summarization/257a00133869db47807b9dd10761a6dd3aa15306/models/pointer_generator/__pycache__/data.cpython-36.pyc -------------------------------------------------------------------------------- /models/pointer_generator/__pycache__/decode.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/saverymax/qdriven-chiqa-summarization/257a00133869db47807b9dd10761a6dd3aa15306/models/pointer_generator/__pycache__/decode.cpython-36.pyc -------------------------------------------------------------------------------- /models/pointer_generator/__pycache__/model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/saverymax/qdriven-chiqa-summarization/257a00133869db47807b9dd10761a6dd3aa15306/models/pointer_generator/__pycache__/model.cpython-36.pyc -------------------------------------------------------------------------------- /models/pointer_generator/__pycache__/util.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/saverymax/qdriven-chiqa-summarization/257a00133869db47807b9dd10761a6dd3aa15306/models/pointer_generator/__pycache__/util.cpython-36.pyc -------------------------------------------------------------------------------- /models/pointer_generator/attention_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
2 | # Modifications Copyright 2017 Abigail See 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | """This file defines the decoder""" 18 | 19 | import tensorflow as tf 20 | from tensorflow.python.ops import variable_scope 21 | from tensorflow.python.ops import array_ops 22 | from tensorflow.python.ops import nn_ops 23 | from tensorflow.python.ops import math_ops 24 | 25 | # Note: this function is based on tf.contrib.legacy_seq2seq_attention_decoder, which is now outdated. 26 | # In the future, it would make more sense to write variants on the attention mechanism using the new seq2seq library for tensorflow 1.0: https://www.tensorflow.org/api_guides/python/contrib.seq2seq#Attention 27 | def attention_decoder(decoder_inputs, initial_state, encoder_states, enc_padding_mask, cell, initial_state_attention=False, pointer_gen=True, use_coverage=False, prev_coverage=None): 28 | """ 29 | Args: 30 | decoder_inputs: A list of 2D Tensors [batch_size x input_size]. 31 | initial_state: 2D Tensor [batch_size x cell.state_size]. 32 | encoder_states: 3D Tensor [batch_size x attn_length x attn_size]. 33 | enc_padding_mask: 2D Tensor [batch_size x attn_length] containing 1s and 0s; indicates which of the encoder locations are padding (0) or a real token (1). 34 | cell: rnn_cell.RNNCell defining the cell function and size. 35 | initial_state_attention: 36 | Note that this attention decoder passes each decoder input through a linear layer with the previous step's context vector to get a modified version of the input. If initial_state_attention is False, on the first decoder step the "previous context vector" is just a zero vector. If initial_state_attention is True, we use initial_state to (re)calculate the previous step's context vector. We set this to False for train/eval mode (because we call attention_decoder once for all decoder steps) and True for decode mode (because we call attention_decoder once for each decoder step). 37 | pointer_gen: boolean. If True, calculate the generation probability p_gen for each decoder step. 38 | use_coverage: boolean. If True, use coverage mechanism. 39 | prev_coverage: 40 | If not None, a tensor with shape (batch_size, attn_length). The previous step's coverage vector. This is only not None in decode mode when using coverage. 41 | 42 | Returns: 43 | outputs: A list of the same length as decoder_inputs of 2D Tensors of 44 | shape [batch_size x cell.output_size]. The output vectors. 45 | state: The final state of the decoder. A tensor shape [batch_size x cell.state_size]. 46 | attn_dists: A list containing tensors of shape (batch_size,attn_length). 47 | The attention distributions for each decoder step. 48 | p_gens: List of scalars. The values of p_gen for each decoder step. Empty list if pointer_gen=False. 49 | coverage: Coverage vector on the last step computed. None if use_coverage=False. 
50 | """ 51 | with variable_scope.variable_scope("attention_decoder") as scope: 52 | batch_size = encoder_states.get_shape()[0].value # if this line fails, it's because the batch size isn't defined 53 | attn_size = encoder_states.get_shape()[2].value # if this line fails, it's because the attention length isn't defined 54 | 55 | # Reshape encoder_states (need to insert a dim) 56 | encoder_states = tf.expand_dims(encoder_states, axis=2) # now is shape (batch_size, attn_len, 1, attn_size) 57 | 58 | # To calculate attention, we calculate 59 | # v^T tanh(W_h h_i + W_s s_t + b_attn) 60 | # where h_i is an encoder state, and s_t a decoder state. 61 | # attn_vec_size is the length of the vectors v, b_attn, (W_h h_i) and (W_s s_t). 62 | # We set it to be equal to the size of the encoder states. 63 | attention_vec_size = attn_size 64 | 65 | # Get the weight matrix W_h and apply it to each encoder state to get (W_h h_i), the encoder features 66 | W_h = variable_scope.get_variable("W_h", [1, 1, attn_size, attention_vec_size]) 67 | encoder_features = nn_ops.conv2d(encoder_states, W_h, [1, 1, 1, 1], "SAME") # shape (batch_size,attn_length,1,attention_vec_size) 68 | 69 | # Get the weight vectors v and w_c (w_c is for coverage) 70 | v = variable_scope.get_variable("v", [attention_vec_size]) 71 | if use_coverage: 72 | with variable_scope.variable_scope("coverage"): 73 | w_c = variable_scope.get_variable("w_c", [1, 1, 1, attention_vec_size]) 74 | 75 | if prev_coverage is not None: # for beam search mode with coverage 76 | # reshape from (batch_size, attn_length) to (batch_size, attn_len, 1, 1) 77 | prev_coverage = tf.expand_dims(tf.expand_dims(prev_coverage,2),3) 78 | 79 | def attention(decoder_state, coverage=None): 80 | """Calculate the context vector and attention distribution from the decoder state. 81 | 82 | Args: 83 | decoder_state: state of the decoder 84 | coverage: Optional. Previous timestep's coverage vector, shape (batch_size, attn_len, 1, 1). 85 | 86 | Returns: 87 | context_vector: weighted sum of encoder_states 88 | attn_dist: attention distribution 89 | coverage: new coverage vector. shape (batch_size, attn_len, 1, 1) 90 | """ 91 | with variable_scope.variable_scope("Attention"): 92 | # Pass the decoder state through a linear layer (this is W_s s_t + b_attn in the paper) 93 | decoder_features = linear(decoder_state, attention_vec_size, True) # shape (batch_size, attention_vec_size) 94 | decoder_features = tf.expand_dims(tf.expand_dims(decoder_features, 1), 1) # reshape to (batch_size, 1, 1, attention_vec_size) 95 | 96 | def masked_attention(e): 97 | """Take softmax of e then apply enc_padding_mask and re-normalize""" 98 | attn_dist = nn_ops.softmax(e) # take softmax. shape (batch_size, attn_length) 99 | attn_dist *= enc_padding_mask # apply mask 100 | masked_sums = tf.reduce_sum(attn_dist, axis=1) # shape (batch_size) 101 | return attn_dist / tf.reshape(masked_sums, [-1, 1]) # re-normalize 102 | 103 | if use_coverage and coverage is not None: # non-first step of coverage 104 | # Multiply coverage vector by w_c to get coverage_features. 
105 | coverage_features = nn_ops.conv2d(coverage, w_c, [1, 1, 1, 1], "SAME") # c has shape (batch_size, attn_length, 1, attention_vec_size) 106 | 107 | # Calculate v^T tanh(W_h h_i + W_s s_t + w_c c_i^t + b_attn) 108 | e = math_ops.reduce_sum(v * math_ops.tanh(encoder_features + decoder_features + coverage_features), [2, 3]) # shape (batch_size,attn_length) 109 | 110 | # Calculate attention distribution 111 | attn_dist = masked_attention(e) 112 | 113 | # Update coverage vector 114 | coverage += array_ops.reshape(attn_dist, [batch_size, -1, 1, 1]) 115 | else: 116 | # Calculate v^T tanh(W_h h_i + W_s s_t + b_attn) 117 | e = math_ops.reduce_sum(v * math_ops.tanh(encoder_features + decoder_features), [2, 3]) # calculate e 118 | 119 | # Calculate attention distribution 120 | attn_dist = masked_attention(e) 121 | 122 | if use_coverage: # first step of training 123 | coverage = tf.expand_dims(tf.expand_dims(attn_dist,2),2) # initialize coverage 124 | 125 | # Calculate the context vector from attn_dist and encoder_states 126 | context_vector = math_ops.reduce_sum(array_ops.reshape(attn_dist, [batch_size, -1, 1, 1]) * encoder_states, [1, 2]) # shape (batch_size, attn_size). 127 | context_vector = array_ops.reshape(context_vector, [-1, attn_size]) 128 | 129 | return context_vector, attn_dist, coverage 130 | 131 | outputs = [] 132 | attn_dists = [] 133 | p_gens = [] 134 | state = initial_state 135 | coverage = prev_coverage # initialize coverage to None or whatever was passed in 136 | context_vector = array_ops.zeros([batch_size, attn_size]) 137 | context_vector.set_shape([None, attn_size]) # Ensure the second shape of attention vectors is set. 138 | if initial_state_attention: # true in decode mode 139 | # Re-calculate the context vector from the previous step so that we can pass it through a linear layer with this step's input to get a modified version of the input 140 | context_vector, _, coverage = attention(initial_state, coverage) # in decode mode, this is what updates the coverage vector 141 | for i, inp in enumerate(decoder_inputs): 142 | tf.logging.info("Adding attention_decoder timestep %i of %i", i, len(decoder_inputs)) 143 | if i > 0: 144 | variable_scope.get_variable_scope().reuse_variables() 145 | 146 | # Merge input and previous attentions into one vector x of the same size as inp 147 | input_size = inp.get_shape().with_rank(2)[1] 148 | if input_size.value is None: 149 | raise ValueError("Could not infer input size from input: %s" % inp.name) 150 | x = linear([inp] + [context_vector], input_size, True) 151 | 152 | # Run the decoder RNN cell. cell_output = decoder state 153 | cell_output, state = cell(x, state) 154 | 155 | # Run the attention mechanism. 156 | if i == 0 and initial_state_attention: # always true in decode mode 157 | with variable_scope.variable_scope(variable_scope.get_variable_scope(), reuse=True): # you need this because you've already run the initial attention(...) 
call 158 | context_vector, attn_dist, _ = attention(state, coverage) # don't allow coverage to update 159 | else: 160 | context_vector, attn_dist, coverage = attention(state, coverage) 161 | attn_dists.append(attn_dist) 162 | 163 | # Calculate p_gen 164 | if pointer_gen: 165 | with tf.variable_scope('calculate_pgen'): 166 | p_gen = linear([context_vector, state.c, state.h, x], 1, True) # a scalar 167 | p_gen = tf.sigmoid(p_gen) 168 | p_gens.append(p_gen) 169 | 170 | # Concatenate the cell_output (= decoder state) and the context vector, and pass them through a linear layer 171 | # This is V[s_t, h*_t] + b in the paper 172 | with variable_scope.variable_scope("AttnOutputProjection"): 173 | output = linear([cell_output] + [context_vector], cell.output_size, True) 174 | outputs.append(output) 175 | 176 | # If using coverage, reshape it 177 | if coverage is not None: 178 | coverage = array_ops.reshape(coverage, [batch_size, -1]) 179 | 180 | return outputs, state, attn_dists, p_gens, coverage 181 | 182 | 183 | 184 | def linear(args, output_size, bias, bias_start=0.0, scope=None): 185 | """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable. 186 | 187 | Args: 188 | args: a 2D Tensor or a list of 2D, batch x n, Tensors. 189 | output_size: int, second dimension of W[i]. 190 | bias: boolean, whether to add a bias term or not. 191 | bias_start: starting value to initialize the bias; 0 by default. 192 | scope: VariableScope for the created subgraph; defaults to "Linear". 193 | 194 | Returns: 195 | A 2D Tensor with shape [batch x output_size] equal to 196 | sum_i(args[i] * W[i]), where W[i]s are newly created matrices. 197 | 198 | Raises: 199 | ValueError: if some of the arguments has unspecified or wrong shape. 200 | """ 201 | if args is None or (isinstance(args, (list, tuple)) and not args): 202 | raise ValueError("`args` must be specified") 203 | if not isinstance(args, (list, tuple)): 204 | args = [args] 205 | 206 | # Calculate the total size of arguments on dimension 1. 207 | total_arg_size = 0 208 | shapes = [a.get_shape().as_list() for a in args] 209 | for shape in shapes: 210 | if len(shape) != 2: 211 | raise ValueError("Linear is expecting 2D arguments: %s" % str(shapes)) 212 | if not shape[1]: 213 | raise ValueError("Linear expects shape[1] of arguments: %s" % str(shapes)) 214 | else: 215 | total_arg_size += shape[1] 216 | 217 | # Now the computation. 218 | with tf.variable_scope(scope or "Linear"): 219 | matrix = tf.get_variable("Matrix", [total_arg_size, output_size]) 220 | if len(args) == 1: 221 | res = tf.matmul(args[0], matrix) 222 | else: 223 | res = tf.matmul(tf.concat(axis=1, values=args), matrix) 224 | if not bias: 225 | return res 226 | bias_term = tf.get_variable( 227 | "Bias", [output_size], initializer=tf.constant_initializer(bias_start)) 228 | return res + bias_term 229 | -------------------------------------------------------------------------------- /models/pointer_generator/beam_search.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # Modifications Copyright 2017 Abigail See 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | """This file contains code to run beam search decoding""" 18 | 19 | import tensorflow as tf 20 | import numpy as np 21 | import data 22 | 23 | FLAGS = tf.app.flags.FLAGS 24 | 25 | class Hypothesis(object): 26 | """Class to represent a hypothesis during beam search. Holds all the information needed for the hypothesis.""" 27 | 28 | def __init__(self, tokens, log_probs, state, attn_dists, p_gens, coverage): 29 | """Hypothesis constructor. 30 | 31 | Args: 32 | tokens: List of integers. The ids of the tokens that form the summary so far. 33 | log_probs: List, same length as tokens, of floats, giving the log probabilities of the tokens so far. 34 | state: Current state of the decoder, a LSTMStateTuple. 35 | attn_dists: List, same length as tokens, of numpy arrays with shape (attn_length). These are the attention distributions so far. 36 | p_gens: List, same length as tokens, of floats, or None if not using pointer-generator model. The values of the generation probability so far. 37 | coverage: Numpy array of shape (attn_length), or None if not using coverage. The current coverage vector. 38 | """ 39 | self.tokens = tokens 40 | self.log_probs = log_probs 41 | self.state = state 42 | self.attn_dists = attn_dists 43 | self.p_gens = p_gens 44 | self.coverage = coverage 45 | 46 | def extend(self, token, log_prob, state, attn_dist, p_gen, coverage): 47 | """Return a NEW hypothesis, extended with the information from the latest step of beam search. 48 | 49 | Args: 50 | token: Integer. Latest token produced by beam search. 51 | log_prob: Float. Log prob of the latest token. 52 | state: Current decoder state, a LSTMStateTuple. 53 | attn_dist: Attention distribution from latest step. Numpy array shape (attn_length). 54 | p_gen: Generation probability on latest step. Float. 55 | coverage: Latest coverage vector. Numpy array shape (attn_length), or None if not using coverage. 56 | Returns: 57 | New Hypothesis for next step. 58 | """ 59 | return Hypothesis(tokens = self.tokens + [token], 60 | log_probs = self.log_probs + [log_prob], 61 | state = state, 62 | attn_dists = self.attn_dists + [attn_dist], 63 | p_gens = self.p_gens + [p_gen], 64 | coverage = coverage) 65 | 66 | @property 67 | def latest_token(self): 68 | return self.tokens[-1] 69 | 70 | @property 71 | def log_prob(self): 72 | # the log probability of the hypothesis so far is the sum of the log probabilities of the tokens so far 73 | return sum(self.log_probs) 74 | 75 | @property 76 | def avg_log_prob(self): 77 | # normalize log probability by number of tokens (otherwise longer sequences always have lower probability) 78 | return self.log_prob / len(self.tokens) 79 | 80 | 81 | def run_beam_search(sess, model, vocab, batch): 82 | """Performs beam search decoding on the given example. 
83 | 84 | Args: 85 | sess: a tf.Session 86 | model: a seq2seq model 87 | vocab: Vocabulary object 88 | batch: Batch object that is the same example repeated across the batch 89 | 90 | Returns: 91 | best_hyp: Hypothesis object; the best hypothesis found by beam search. 92 | """ 93 | # Run the encoder to get the encoder hidden states and decoder initial state 94 | enc_states, dec_in_state = model.run_encoder(sess, batch) 95 | # dec_in_state is a LSTMStateTuple 96 | # enc_states has shape [batch_size, <=max_enc_steps, 2*hidden_dim]. 97 | 98 | # Initialize beam_size-many hyptheses 99 | hyps = [Hypothesis(tokens=[vocab.word2id(data.START_DECODING)], 100 | log_probs=[0.0], 101 | state=dec_in_state, 102 | attn_dists=[], 103 | p_gens=[], 104 | coverage=np.zeros([batch.enc_batch.shape[1]]) # zero vector of length attention_length 105 | ) for _ in range(FLAGS.beam_size)] 106 | results = [] # this will contain finished hypotheses (those that have emitted the [STOP] token) 107 | 108 | steps = 0 109 | while steps < FLAGS.max_dec_steps and len(results) < FLAGS.beam_size: 110 | latest_tokens = [h.latest_token for h in hyps] # latest token produced by each hypothesis 111 | latest_tokens = [t if t in range(vocab.size()) else vocab.word2id(data.UNKNOWN_TOKEN) for t in latest_tokens] # change any in-article temporary OOV ids to [UNK] id, so that we can lookup word embeddings 112 | states = [h.state for h in hyps] # list of current decoder states of the hypotheses 113 | prev_coverage = [h.coverage for h in hyps] # list of coverage vectors (or None) 114 | 115 | # Run one step of the decoder to get the new info 116 | (topk_ids, topk_log_probs, new_states, attn_dists, p_gens, new_coverage) = model.decode_onestep(sess=sess, 117 | batch=batch, 118 | latest_tokens=latest_tokens, 119 | enc_states=enc_states, 120 | dec_init_states=states, 121 | prev_coverage=prev_coverage) 122 | 123 | # Extend each hypothesis and collect them all in all_hyps 124 | all_hyps = [] 125 | num_orig_hyps = 1 if steps == 0 else len(hyps) # On the first step, we only had one original hypothesis (the initial hypothesis). On subsequent steps, all original hypotheses are distinct. 126 | for i in range(num_orig_hyps): 127 | h, new_state, attn_dist, p_gen, new_coverage_i = hyps[i], new_states[i], attn_dists[i], p_gens[i], new_coverage[i] # take the ith hypothesis and new decoder state info 128 | for j in range(FLAGS.beam_size * 2): # for each of the top 2*beam_size hyps: 129 | # Extend the ith hypothesis with the jth option 130 | new_hyp = h.extend(token=topk_ids[i, j], 131 | log_prob=topk_log_probs[i, j], 132 | state=new_state, 133 | attn_dist=attn_dist, 134 | p_gen=p_gen, 135 | coverage=new_coverage_i) 136 | all_hyps.append(new_hyp) 137 | 138 | # Filter and collect any hypotheses that have produced the end token. 139 | hyps = [] # will contain hypotheses for the next step 140 | for h in sort_hyps(all_hyps): # in order of most likely h 141 | if h.latest_token == vocab.word2id(data.STOP_DECODING): # if stop token is reached... 142 | # If this hypothesis is sufficiently long, put in results. Otherwise discard. 143 | if steps >= FLAGS.min_dec_steps: 144 | results.append(h) 145 | else: # hasn't reached stop token, so continue to extend this hypothesis 146 | hyps.append(h) 147 | if len(hyps) == FLAGS.beam_size or len(results) == FLAGS.beam_size: 148 | # Once we've collected beam_size-many hypotheses for the next step, or beam_size-many complete hypotheses, stop. 
149 | break 150 | 151 | steps += 1 152 | 153 | # At this point, either we've got beam_size results, or we've reached maximum decoder steps 154 | 155 | if len(results)==0: # if we don't have any complete results, add all current hypotheses (incomplete summaries) to results 156 | results = hyps 157 | 158 | # Sort hypotheses by average log probability 159 | hyps_sorted = sort_hyps(results) 160 | 161 | # Return the hypothesis with highest average log prob 162 | return hyps_sorted[0] 163 | 164 | def sort_hyps(hyps): 165 | """Return a list of Hypothesis objects, sorted by descending average log probability""" 166 | return sorted(hyps, key=lambda h: h.avg_log_prob, reverse=True) 167 | -------------------------------------------------------------------------------- /models/pointer_generator/data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # Modifications Copyright 2017 Abigail See 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | """This file contains code to read the train/eval/test data from file and process it, and read the vocab data from file and process it""" 18 | 19 | import glob 20 | import json 21 | import random 22 | import struct 23 | import csv 24 | from tensorflow.core.example import example_pb2 25 | 26 | # <s> and </s> are used in the data files to segment the abstracts into sentences. They don't receive vocab ids. 27 | SENTENCE_START = '<s>' 28 | SENTENCE_END = '</s>' 29 | 30 | PAD_TOKEN = '[PAD]' # This has a vocab id, which is used to pad the encoder input, decoder input and target sequence 31 | UNKNOWN_TOKEN = '[UNK]' # This has a vocab id, which is used to represent out-of-vocabulary words 32 | START_DECODING = '[START]' # This has a vocab id, which is used at the start of every decoder input sequence 33 | STOP_DECODING = '[STOP]' # This has a vocab id, which is used at the end of untruncated target sequences 34 | 35 | # Note: none of <s>, </s>, [PAD], [UNK], [START], [STOP] should appear in the vocab file. 36 | 37 | 38 | class Vocab(object): 39 | """Vocabulary class for mapping between words and ids (integers)""" 40 | 41 | def __init__(self, vocab_file, max_size): 42 | """Creates a vocab of up to max_size words, reading from the vocab_file. If max_size is 0, reads the entire vocab file. 43 | 44 | Args: 45 | vocab_file: path to the vocab file, which is assumed to contain "<word> <frequency>" on each line, sorted with most frequent word first. This code doesn't actually use the frequencies, though. 46 | max_size: integer. The maximum size of the resulting Vocabulary.""" 47 | self._word_to_id = {} 48 | self._id_to_word = {} 49 | self._count = 0 # keeps track of total number of words in the Vocab 50 | 51 | # [UNK], [PAD], [START] and [STOP] get the ids 0,1,2,3.
52 | for w in [UNKNOWN_TOKEN, PAD_TOKEN, START_DECODING, STOP_DECODING]: 53 | self._word_to_id[w] = self._count 54 | self._id_to_word[self._count] = w 55 | self._count += 1 56 | 57 | # Read the vocab file and add words up to max_size 58 | with open(vocab_file, 'r') as vocab_f: 59 | for line in vocab_f: 60 | pieces = line.split() 61 | if len(pieces) != 2: 62 | print('Warning: incorrectly formatted line in vocabulary file: %s\n' % line) 63 | continue 64 | w = pieces[0] 65 | if w in [SENTENCE_START, SENTENCE_END, UNKNOWN_TOKEN, PAD_TOKEN, START_DECODING, STOP_DECODING]: 66 | raise Exception(', , [UNK], [PAD], [START] and [STOP] shouldn\'t be in the vocab file, but %s is' % w) 67 | if w in self._word_to_id: 68 | raise Exception('Duplicated word in vocabulary file: %s' % w) 69 | self._word_to_id[w] = self._count 70 | self._id_to_word[self._count] = w 71 | self._count += 1 72 | if max_size != 0 and self._count >= max_size: 73 | print("max_size of vocab was specified as %i; we now have %i words. Stopping reading." % (max_size, self._count)) 74 | break 75 | 76 | print("Finished constructing vocabulary of %i total words. Last word added: %s" % (self._count, self._id_to_word[self._count-1])) 77 | 78 | def word2id(self, word): 79 | """Returns the id (integer) of a word (string). Returns [UNK] id if word is OOV.""" 80 | if word not in self._word_to_id: 81 | return self._word_to_id[UNKNOWN_TOKEN] 82 | return self._word_to_id[word] 83 | 84 | def id2word(self, word_id): 85 | """Returns the word (string) corresponding to an id (integer).""" 86 | if word_id not in self._id_to_word: 87 | raise ValueError('Id not found in vocab: %d' % word_id) 88 | return self._id_to_word[word_id] 89 | 90 | def size(self): 91 | """Returns the total size of the vocabulary""" 92 | return self._count 93 | 94 | def write_metadata(self, fpath): 95 | """Writes metadata file for Tensorboard word embedding visualizer as described here: 96 | https://www.tensorflow.org/get_started/embedding_viz 97 | 98 | Args: 99 | fpath: place to write the metadata file 100 | """ 101 | print("Writing word embedding metadata file to %s..." % (fpath)) 102 | with open(fpath, "w") as f: 103 | fieldnames = ['word'] 104 | writer = csv.DictWriter(f, delimiter="\t", fieldnames=fieldnames) 105 | for i in range(self.size()): 106 | writer.writerow({"word": self._id_to_word[i]}) 107 | 108 | 109 | def example_generator(data_path, single_pass): 110 | """Generates tf.Examples from data files. 111 | 112 | Binary data format: . represents the byte size 113 | of . is serialized tf.Example proto. The tf.Example contains 114 | the tokenized article text and summary. 115 | 116 | Args: 117 | data_path: 118 | Path to tf.Example data files. Can include wildcards, e.g. if you have several training data chunk files train_001.bin, train_002.bin, etc, then pass data_path=train_* to access them all. 119 | single_pass: 120 | Boolean. If True, go through the dataset exactly once, generating examples in the order they appear, then return. Otherwise, generate random examples indefinitely. 121 | 122 | Yields: 123 | Deserialized tf.Example. 
124 | """ 125 | while True: 126 | filelist = glob.glob(data_path) # get the list of datafiles 127 | assert filelist, ('Error: Empty filelist at %s' % data_path) # check filelist isn't empty 128 | if single_pass: 129 | filelist = sorted(filelist) 130 | else: 131 | random.shuffle(filelist) 132 | for f in filelist: 133 | reader = open(f, 'rb') 134 | while True: 135 | len_bytes = reader.read(8) 136 | if not len_bytes: break # finished reading this file 137 | str_len = struct.unpack('q', len_bytes)[0] 138 | example_str = struct.unpack('%ds' % str_len, reader.read(str_len))[0] 139 | yield example_pb2.Example.FromString(example_str) 140 | if single_pass: 141 | print("example_generator completed reading all datafiles. No more data.") 142 | break 143 | 144 | 145 | def asumm_example_generator(data_path, single_pass, question_driven, tag_sentences): 146 | """Generates tf.Examples from data files for answer summarization. Coppied from example_generator. 147 | 148 | Binary data format: . represents the byte size 149 | of . is serialized tf.Example proto. The tf.Example contains 150 | the tokenized article text and summary. 151 | 152 | Args: 153 | data_path: 154 | Path to tf.Example data files. Can include wildcards, e.g. if you have several training data chunk files train_001.bin, train_002.bin, etc, then pass data_path=train_* to access them all. 155 | single_pass: 156 | Boolean. If True, go through the dataset exactly once, generating examples in the order they appear, then return. Otherwise, generate random examples indefinitely. 157 | 158 | Yields: 159 | Deserialized tf.Example. 160 | 161 | Returns data until iterated to the end of the file 162 | """ 163 | while True: 164 | with open(data_path, "r", encoding="utf-8") as f: 165 | asumm_data = json.load(f) 166 | keys = list(asumm_data.keys()) 167 | random.shuffle(keys) 168 | for key in keys: 169 | summary = asumm_data[key]['summary'] 170 | articles = asumm_data[key]['articles'] 171 | question = asumm_data[key]['question'] 172 | # Only add question or tag sentences during inference, not during training 173 | if tag_sentences and single_pass: 174 | split_summ = summary.split(".") 175 | summary = ' '.join(["%s %s %s" % (SENTENCE_START, sent, SENTENCE_END) for sent in split_summ]) 176 | if question_driven and single_pass: 177 | articles = question + " [QUESTION?] " + articles 178 | yield (summary, articles, question) 179 | if single_pass: 180 | print("example_generator completed reading all datafiles. No more data.") 181 | break 182 | 183 | 184 | def article2ids(article_words, vocab): 185 | """Map the article words to their ids. Also return a list of OOVs in the article. 186 | 187 | Args: 188 | article_words: list of words (strings) 189 | vocab: Vocabulary object 190 | 191 | Returns: 192 | ids: 193 | A list of word ids (integers); OOVs are represented by their temporary article OOV number. If the vocabulary size is 50k and the article has 3 OOVs, then these temporary OOV numbers will be 50000, 50001, 50002. 194 | oovs: 195 | A list of the OOV words in the article (strings), in the order corresponding to their temporary article OOV numbers.""" 196 | ids = [] 197 | oovs = [] 198 | unk_id = vocab.word2id(UNKNOWN_TOKEN) 199 | for w in article_words: 200 | i = vocab.word2id(w) 201 | if i == unk_id: # If w is OOV 202 | if w not in oovs: # Add to list of OOVs 203 | oovs.append(w) 204 | oov_num = oovs.index(w) # This is 0 for the first article OOV, 1 for the second article OOV... 205 | ids.append(vocab.size() + oov_num) # This is e.g. 
50000 for the first article OOV, 50001 for the second... 206 | else: 207 | ids.append(i) 208 | return ids, oovs 209 | 210 | 211 | def abstract2ids(abstract_words, vocab, article_oovs): 212 | """Map the abstract words to their ids. In-article OOVs are mapped to their temporary OOV numbers. 213 | 214 | Args: 215 | abstract_words: list of words (strings) 216 | vocab: Vocabulary object 217 | article_oovs: list of in-article OOV words (strings), in the order corresponding to their temporary article OOV numbers 218 | 219 | Returns: 220 | ids: List of ids (integers). In-article OOV words are mapped to their temporary OOV numbers. Out-of-article OOV words are mapped to the UNK token id.""" 221 | ids = [] 222 | unk_id = vocab.word2id(UNKNOWN_TOKEN) 223 | for w in abstract_words: 224 | i = vocab.word2id(w) 225 | if i == unk_id: # If w is an OOV word 226 | if w in article_oovs: # If w is an in-article OOV 227 | vocab_idx = vocab.size() + article_oovs.index(w) # Map to its temporary article OOV number 228 | ids.append(vocab_idx) 229 | else: # If w is an out-of-article OOV 230 | ids.append(unk_id) # Map to the UNK token id 231 | else: 232 | ids.append(i) 233 | return ids 234 | 235 | 236 | def outputids2words(id_list, vocab, article_oovs): 237 | """Maps output ids to words, including mapping in-article OOVs from their temporary ids to the original OOV string (applicable in pointer-generator mode). 238 | 239 | Args: 240 | id_list: list of ids (integers) 241 | vocab: Vocabulary object 242 | article_oovs: list of OOV words (strings) in the order corresponding to their temporary article OOV ids (that have been assigned in pointer-generator mode), or None (in baseline mode) 243 | 244 | Returns: 245 | words: list of words (strings) 246 | """ 247 | words = [] 248 | for i in id_list: 249 | try: 250 | w = vocab.id2word(i) # might be [UNK] 251 | except ValueError as e: # w is OOV 252 | assert article_oovs is not None, "Error: model produced a word ID that isn't in the vocabulary. This should not happen in baseline (no pointer-generator) mode" 253 | article_oov_idx = i - vocab.size() 254 | try: 255 | w = article_oovs[article_oov_idx] 256 | except ValueError as e: # i doesn't correspond to an article oov 257 | raise ValueError('Error: model produced word ID %i which corresponds to article OOV %i but this example only has %i article OOVs' % (i, article_oov_idx, len(article_oovs))) 258 | words.append(w) 259 | return words 260 | 261 | 262 | def abstract2sents(abstract): 263 | """Splits abstract text from datafile into list of sentences. 
264 | 265 | Args: 266 | abstract: string containing and tags for starts and ends of sentences 267 | 268 | Returns: 269 | sents: List of sentence strings (no tags)""" 270 | cur = 0 271 | sents = [] 272 | while True: 273 | try: 274 | start_p = abstract.index(SENTENCE_START, cur) 275 | end_p = abstract.index(SENTENCE_END, start_p + 1) 276 | cur = end_p + len(SENTENCE_END) 277 | sents.append(abstract[start_p+len(SENTENCE_START):end_p]) 278 | except ValueError as e: # no more sentences 279 | return sents 280 | 281 | 282 | def show_art_oovs(article, vocab): 283 | """Returns the article string, highlighting the OOVs by placing __underscores__ around them""" 284 | unk_token = vocab.word2id(UNKNOWN_TOKEN) 285 | words = article.split(' ') 286 | words = [("__%s__" % w) if vocab.word2id(w)==unk_token else w for w in words] 287 | out_str = ' '.join(words) 288 | return out_str 289 | 290 | 291 | def show_abs_oovs(abstract, vocab, article_oovs): 292 | """Returns the abstract string, highlighting the article OOVs with __underscores__. 293 | 294 | If a list of article_oovs is provided, non-article OOVs are differentiated like !!__this__!!. 295 | 296 | Args: 297 | abstract: string 298 | vocab: Vocabulary object 299 | article_oovs: list of words (strings), or None (in baseline mode) 300 | """ 301 | unk_token = vocab.word2id(UNKNOWN_TOKEN) 302 | words = abstract.split(' ') 303 | new_words = [] 304 | for w in words: 305 | if vocab.word2id(w) == unk_token: # w is oov 306 | if article_oovs is None: # baseline mode 307 | new_words.append("__%s__" % w) 308 | else: # pointer-generator mode 309 | if w in article_oovs: 310 | new_words.append("__%s__" % w) 311 | else: 312 | new_words.append("!!__%s__!!" % w) 313 | else: # w is in-vocab word 314 | new_words.append(w) 315 | out_str = ' '.join(new_words) 316 | return out_str 317 | -------------------------------------------------------------------------------- /models/pointer_generator/decode.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # Modifications Copyright 2017 Abigail See 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | 17 | """This file contains code to run beam search decoding, including running ROUGE evaluation and producing JSON datafiles for the in-browser attention visualizer, which can be found here https://github.com/abisee/attn_vis""" 18 | 19 | import os 20 | import time 21 | import tensorflow as tf 22 | import beam_search 23 | import data 24 | import json 25 | import pyrouge 26 | # Using py-rouge, the pure python rouge evaluator 27 | #import rouge 28 | import util 29 | import logging 30 | import numpy as np 31 | import sys 32 | 33 | FLAGS = tf.app.flags.FLAGS 34 | 35 | SECS_UNTIL_NEW_CKPT = 60 # max number of seconds before loading new checkpoint 36 | 37 | 38 | class BeamSearchDecoder(object): 39 | """Beam search decoder.""" 40 | 41 | def __init__(self, model, batcher, vocab): 42 | """Initialize decoder. 43 | 44 | Args: 45 | model: a Seq2SeqAttentionModel object. 46 | batcher: a Batcher object. 47 | vocab: Vocabulary object 48 | """ 49 | self._model = model 50 | self._model.build_graph() 51 | self._batcher = batcher 52 | self._vocab = vocab 53 | self._saver = tf.train.Saver() # we use this to load checkpoints for decoding 54 | self._sess = tf.Session(config=util.get_config()) 55 | self._generated_answers = {"question": [], "ref_summary": [], "gen_summary": []} 56 | 57 | # For running decode for answer summarization, I have modified this to load the best model from the eval directory 58 | ckpt_path = util.load_ckpt(self._saver, self._sess) 59 | #ckpt_path = util.load_ckpt(self._saver, self._sess, "eval") 60 | 61 | if FLAGS.single_pass: 62 | # Make a descriptive decode directory name 63 | ckpt_name = "ckpt-" + ckpt_path.split('-')[-1] # this is something of the form "ckpt-123456" 64 | self._decode_dir = os.path.join(FLAGS.log_root, get_decode_dir_name(ckpt_name)) 65 | #if os.path.exists(self._decode_dir): 66 | # raise Exception("single_pass decode directory %s should not already exist" % self._decode_dir) 67 | 68 | else: # Generic decode dir name 69 | self._decode_dir = os.path.join(FLAGS.log_root, "decode") 70 | 71 | # Make the decode dir if necessary 72 | if not os.path.exists(self._decode_dir): os.mkdir(self._decode_dir) 73 | 74 | if FLAGS.single_pass: 75 | # Make the dirs to contain output written in the correct format for pyrouge 76 | self._rouge_ref_dir = os.path.join(self._decode_dir, "reference") 77 | if not os.path.exists(self._rouge_ref_dir): os.mkdir(self._rouge_ref_dir) 78 | self._rouge_dec_dir = os.path.join(self._decode_dir, "decoded") 79 | if not os.path.exists(self._rouge_dec_dir): os.mkdir(self._rouge_dec_dir) 80 | 81 | 82 | def decode(self): 83 | """Decode examples until data is exhausted (if FLAGS.single_pass) and return, or decode indefinitely, loading latest checkpoint at regular intervals""" 84 | t0 = time.time() 85 | counter = 0 86 | while True: 87 | batch = self._batcher.next_batch() # 1 example repeated across batch 88 | #if counter >= 2: 89 | # batch = None 90 | if batch is None: # finished decoding dataset in single_pass mode 91 | assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode" 92 | tf.logging.info("Decoder has finished reading dataset for single_pass.") 93 | if FLAGS.eval_type == "cnn": 94 | tf.logging.info("Output has been saved in %s and %s. 
Now starting ROUGE eval...", self._rouge_ref_dir, self._rouge_dec_dir) 95 | results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir) 96 | rouge_log(results_dict, self._decode_dir) 97 | if FLAGS.eval_type == "medsumm": 98 | tf.logging.info("Writing generated answer summaries to file...") 99 | with open(FLAGS.generated_data_file, "w", encoding="utf-8") as f: 100 | json.dump(self._generated_answers, f, indent=4) 101 | return 102 | 103 | question = batch.questions[0] 104 | original_article = batch.original_articles[0] # string 105 | original_abstract = batch.original_abstracts[0] # string 106 | original_abstract_sents = batch.original_abstracts_sents[0] # list of strings 107 | 108 | article_withunks = data.show_art_oovs(original_article, self._vocab) # string 109 | abstract_withunks = data.show_abs_oovs(original_abstract, self._vocab, (batch.art_oovs[0] if FLAGS.pointer_gen else None)) # string 110 | 111 | # Run beam search to get best Hypothesis 112 | best_hyp = beam_search.run_beam_search(self._sess, self._model, self._vocab, batch) 113 | 114 | # Extract the output ids from the hypothesis and convert back to words 115 | output_ids = [int(t) for t in best_hyp.tokens[1:]] 116 | decoded_words = data.outputids2words(output_ids, self._vocab, (batch.art_oovs[0] if FLAGS.pointer_gen else None)) 117 | 118 | # Remove the [STOP] token from decoded_words, if necessary 119 | try: 120 | fst_stop_idx = decoded_words.index(data.STOP_DECODING) # index of the (first) [STOP] symbol 121 | decoded_words = decoded_words[:fst_stop_idx] 122 | except ValueError: 123 | decoded_words = decoded_words 124 | decoded_output = ' '.join(decoded_words) # single string 125 | 126 | if FLAGS.single_pass: 127 | if FLAGS.eval_type == "medsumm": 128 | # Save generated data for answer summ evaluation: 129 | self.write_data_for_medsumm_eval(original_abstract_sents, decoded_words, question, counter) # write ref summary and decoded summary to file, for later evaluation. 130 | tf.logging.info("saving summary %i", counter) 131 | if FLAGS.eval_type == "cnn": 132 | # write data for See's original evaluation done with pyrouge. 133 | self.write_for_rouge(original_abstract_sents, decoded_words, counter) # write ref summary and decoded summary to file, to eval with pyrouge later 134 | counter += 1 # this is how many examples we've decoded 135 | else: 136 | print_results(article_withunks, abstract_withunks, decoded_output) # log output to screen 137 | self.write_for_attnvis(article_withunks, abstract_withunks, decoded_words, best_hyp.attn_dists, best_hyp.p_gens) # write info to .json file for visualization tool 138 | 139 | # Check if SECS_UNTIL_NEW_CKPT has elapsed; if so return so we can load a new checkpoint 140 | t1 = time.time() 141 | if t1-t0 > SECS_UNTIL_NEW_CKPT: 142 | tf.logging.info('We\'ve been decoding with same checkpoint for %i seconds. Time to load new checkpoint', t1-t0) 143 | _ = util.load_ckpt(self._saver, self._sess) 144 | t0 = time.time() 145 | 146 | def write_data_for_medsumm_eval(self, original_abstract_sents, decoded_words, question, counter): 147 | """ 148 | write reference summaries and generated summaries to file 149 | """ 150 | decoded_sents = "" 151 | while len(decoded_words) > 0: 152 | try: 153 | fst_period_idx = decoded_words.index(".") 154 | except ValueError: # there is text remaining that doesn't end in "." 
155 | fst_period_idx = len(decoded_words) 156 | sent = decoded_words[:fst_period_idx+1] # sentence up to and including the period 157 | decoded_words = decoded_words[fst_period_idx+1:] # everything else 158 | decoded_sents += (" " + ' '.join(sent)) 159 | abstract_sents = ". ".join(original_abstract_sents) 160 | 161 | self._generated_answers['question'].append(question) 162 | self._generated_answers['ref_summary'].append(abstract_sents) 163 | self._generated_answers['gen_summary'].append(decoded_sents) 164 | 165 | 166 | def write_for_rouge(self, reference_sents, decoded_words, ex_index): 167 | """Write output to file in correct format for eval with pyrouge. This is called in single_pass mode. 168 | 169 | Args: 170 | reference_sents: list of strings 171 | decoded_words: list of strings 172 | ex_index: int, the index with which to label the files 173 | """ 174 | # First, divide decoded output into sentences 175 | decoded_sents = [] 176 | while len(decoded_words) > 0: 177 | try: 178 | fst_period_idx = decoded_words.index(".") 179 | except ValueError: # there is text remaining that doesn't end in "." 180 | fst_period_idx = len(decoded_words) 181 | sent = decoded_words[:fst_period_idx+1] # sentence up to and including the period 182 | decoded_words = decoded_words[fst_period_idx+1:] # everything else 183 | decoded_sents.append(' '.join(sent)) 184 | 185 | # pyrouge calls a perl script that puts the data into HTML files. 186 | # Therefore we need to make our output HTML safe. 187 | decoded_sents = [make_html_safe(w) for w in decoded_sents] 188 | reference_sents = [make_html_safe(w) for w in reference_sents] 189 | 190 | # Write to file 191 | ref_file = os.path.join(self._rouge_ref_dir, "%06d_reference.txt" % ex_index) 192 | decoded_file = os.path.join(self._rouge_dec_dir, "%06d_decoded.txt" % ex_index) 193 | 194 | with open(ref_file, "w") as f: 195 | for idx,sent in enumerate(reference_sents): 196 | f.write(sent) if idx==len(reference_sents)-1 else f.write(sent+"\n") 197 | with open(decoded_file, "w") as f: 198 | for idx,sent in enumerate(decoded_sents): 199 | f.write(sent) if idx==len(decoded_sents)-1 else f.write(sent+"\n") 200 | 201 | tf.logging.info("Wrote example %i to file" % ex_index) 202 | 203 | 204 | def write_for_attnvis(self, article, abstract, decoded_words, attn_dists, p_gens): 205 | """Write some data to json file, which can be read into the in-browser attention visualizer tool: 206 | https://github.com/abisee/attn_vis 207 | 208 | Args: 209 | article: The original article string. 210 | abstract: The human (correct) abstract string. 211 | attn_dists: List of arrays; the attention distributions. 212 | decoded_words: List of strings; the words of the generated summary. 213 | p_gens: List of scalars; the p_gen values. If not running in pointer-generator mode, list of None. 
214 | """ 215 | article_lst = article.split() # list of words 216 | decoded_lst = decoded_words # list of decoded words 217 | to_write = { 218 | 'article_lst': [make_html_safe(t) for t in article_lst], 219 | 'decoded_lst': [make_html_safe(t) for t in decoded_lst], 220 | 'abstract_str': make_html_safe(abstract), 221 | 'attn_dists': attn_dists 222 | } 223 | if FLAGS.pointer_gen: 224 | to_write['p_gens'] = p_gens 225 | output_fname = os.path.join(self._decode_dir, 'attn_vis_data.json') 226 | with open(output_fname, 'w') as output_file: 227 | json.dump(to_write, output_file) 228 | tf.logging.info('Wrote visualization data to %s', output_fname) 229 | 230 | 231 | def print_results(article, abstract, decoded_output): 232 | """Prints the article, the reference summmary and the decoded summary to screen""" 233 | print("---------------------------------------------------------------------------") 234 | tf.logging.info('ARTICLE: %s', article) 235 | tf.logging.info('REFERENCE SUMMARY: %s', abstract) 236 | tf.logging.info('GENERATED SUMMARY: %s', decoded_output) 237 | print("---------------------------------------------------------------------------") 238 | 239 | 240 | def make_html_safe(s): 241 | """Replace any angled brackets in string s to avoid interfering with HTML attention visualizer.""" 242 | s.replace("<", "<") 243 | s.replace(">", ">") 244 | return s 245 | 246 | 247 | def rouge_eval(ref_dir, dec_dir): 248 | """Evaluate the files in ref_dir and dec_dir with pyrouge, returning results_dict""" 249 | r = pyrouge.Rouge155() 250 | r.model_filename_pattern = '#ID#_reference.txt' 251 | r.system_filename_pattern = '(\d+)_decoded.txt' 252 | r.model_dir = ref_dir 253 | r.system_dir = dec_dir 254 | logging.getLogger('global').setLevel(logging.WARNING) # silence pyrouge logging 255 | rouge_results = r.convert_and_evaluate() 256 | return r.output_to_dict(rouge_results) 257 | 258 | 259 | def rouge_log(results_dict, dir_to_write): 260 | """Log ROUGE results to screen and write to file. 261 | 262 | Args: 263 | results_dict: the dictionary returned by pyrouge 264 | dir_to_write: the directory where we will write the results to""" 265 | log_str = "" 266 | for x in ["1","2","l"]: 267 | log_str += "\nROUGE-%s:\n" % x 268 | for y in ["f_score", "recall", "precision"]: 269 | key = "rouge_%s_%s" % (x,y) 270 | key_cb = key + "_cb" 271 | key_ce = key + "_ce" 272 | val = results_dict[key] 273 | val_cb = results_dict[key_cb] 274 | val_ce = results_dict[key_ce] 275 | log_str += "%s: %.4f with confidence interval (%.4f, %.4f)\n" % (key, val, val_cb, val_ce) 276 | tf.logging.info(log_str) # log to screen 277 | results_file = os.path.join(dir_to_write, "ROUGE_results.txt") 278 | tf.logging.info("Writing final ROUGE results to %s...", results_file) 279 | with open(results_file, "w") as f: 280 | f.write(log_str) 281 | 282 | def get_decode_dir_name(ckpt_name): 283 | """Make a descriptive name for the decode dir, including the name of the checkpoint we use to decode. 
This is called in single_pass mode.""" 284 | 285 | if "train" in FLAGS.data_path: dataset = "train" 286 | elif "val" in FLAGS.data_path: dataset = "val" 287 | elif "test" in FLAGS.data_path: dataset = "test" 288 | elif "summ" in FLAGS.data_path: dataset = "test" 289 | else: raise ValueError("FLAGS.data_path %s should contain one of train, val, test, or summ" % (FLAGS.data_path)) 290 | dirname = "decode_%s_%imaxenc_%ibeam_%imindec_%imaxdec" % (dataset, FLAGS.max_enc_steps, FLAGS.beam_size, FLAGS.min_dec_steps, FLAGS.max_dec_steps) 291 | if ckpt_name is not None: 292 | dirname += "_%s" % ckpt_name 293 | return dirname 294 | -------------------------------------------------------------------------------- /models/pointer_generator/eval_medsumm.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | validation_data=../../data_processing/data/medinfo_section2answer_validation_data_${1}.json 4 | # Bioasq abs2summ evaluating on medinfo page > section with question 5 | if [ $1 == "with_question" ]; then 6 | echo $1 7 | experiment=bioasq_abs2summ_with_question/ 8 | python -u run_medsumm.py --mode=eval --data_path=$validation_data --vocab_path=bioasq_abs2summ_vocab --exp_name=$experiment 9 | fi 10 | 11 | # And same thing but without question 12 | if [ $1 == "without_question" ]; then 13 | echo $1 14 | experiment=bioasq_abs2summ_without_question/ 15 | python -u run_medsumm.py --mode=eval --data_path=$validation_data --vocab_path=bioasq_abs2summ_vocab --exp_name=$experiment 16 | fi 17 | -------------------------------------------------------------------------------- /models/pointer_generator/inspect_checkpoint.py: -------------------------------------------------------------------------------- 1 | """ 2 | Simple script that checks if a checkpoint is corrupted with any inf/NaN values. Run like this: 3 | python inspect_checkpoint.py model.12345 4 | """ 5 | 6 | import tensorflow as tf 7 | import sys 8 | import numpy as np 9 | 10 | 11 | if __name__ == '__main__': 12 | if len(sys.argv) != 2: 13 | raise Exception("Usage: python inspect_checkpoint.py \nNote: Do not include the .data .index or .meta part of the model checkpoint in file_name.") 14 | file_name = sys.argv[1] 15 | reader = tf.train.NewCheckpointReader(file_name) 16 | var_to_shape_map = reader.get_variable_to_shape_map() 17 | 18 | finite = [] 19 | all_infnan = [] 20 | some_infnan = [] 21 | 22 | for key in sorted(var_to_shape_map.keys()): 23 | tensor = reader.get_tensor(key) 24 | if np.all(np.isfinite(tensor)): 25 | finite.append(key) 26 | else: 27 | if not np.any(np.isfinite(tensor)): 28 | all_infnan.append(key) 29 | else: 30 | some_infnan.append(key) 31 | 32 | print("\nFINITE VARIABLES:") 33 | for key in finite: print(key) 34 | 35 | print("\nVARIABLES THAT ARE ALL INF/NAN:") 36 | for key in all_infnan: print(key) 37 | 38 | print("\nVARIABLES THAT CONTAIN SOME FINITE, SOME INF/NAN VALUES:") 39 | for key in some_infnan: print(key) 40 | 41 | if not all_infnan and not some_infnan: 42 | print("CHECK PASSED: checkpoint contains no inf/NaN values") 43 | else: 44 | print("CHECK FAILED: checkpoint contains some inf/NaN values") 45 | -------------------------------------------------------------------------------- /models/pointer_generator/make_asumm_pg_vocab.py: -------------------------------------------------------------------------------- 1 | """ 2 | Module to make vocab file for pointer generator network. Code modeled from make_datafiles.py 3 | in cnn-dailymail processing repository. 
4 | 5 | File will have format 6 | word1 n_occurences 7 | word2 n_occurences 8 | ... 9 | wordn n_occurences 10 | 11 | To run: 12 | python make_asumm_pg_vocab.py --vocab_path=bioasq_abs2summ_vocab --data_file=../../data_process/data/bioasq_abs2summ_training_data_without_question.json 13 | """ 14 | 15 | from collections import Counter 16 | import json 17 | import argparse 18 | 19 | from spacy.tokenizer import Tokenizer 20 | from spacy.lang.en import English 21 | 22 | def get_args(): 23 | """ 24 | Get command line arguments 25 | """ 26 | parser = argparse.ArgumentParser(description="Arguments for data exploration") 27 | parser.add_argument("--vocab_path", 28 | dest="vocab_path", 29 | help="Path to create vocab file") 30 | parser.add_argument("--data_file", 31 | dest="data_file", 32 | help="Path to load data to make vocab") 33 | return parser 34 | 35 | 36 | def make_vocab(vocab_counter, vocab_file, VOCAB_SIZE, article, abstract, tokenizer): 37 | """ 38 | For each page/summary pair, tokenize on spaces, do a word count, and save tokens to file 39 | """ 40 | art_tokens = [t.text.strip() for t in tokenizer(article)] 41 | abs_tokens = [t.text.strip() for t in tokenizer(abstract)] 42 | tokens = art_tokens + abs_tokens 43 | tokens = [t for t in tokens if t != "" and t != "" and t != ""] 44 | vocab_counter.update(tokens) 45 | 46 | 47 | def load_data(data_file, vocab_path): 48 | """ 49 | Load data and tokenize each file 50 | """ 51 | VOCAB_SIZE = 200000 52 | with open(data_file, "r", encoding="utf-8") as f: 53 | training_data = json.load(f) 54 | print("Writing vocab file...") 55 | vocab_file = open(vocab_path, 'w', encoding="utf-8") 56 | vocab_counter = Counter() 57 | # Initiate spacy tokenizer without bells and whistles 58 | nlp = English() 59 | tokenizer = Tokenizer(nlp.vocab) 60 | 61 | for url, topic in training_data.items(): 62 | article = topic['articles'] 63 | abstract = topic['summary'] 64 | make_vocab(vocab_counter, vocab_file, VOCAB_SIZE, article, abstract, tokenizer) 65 | # After updating counter, write counts/words 66 | print("20 most common words:", vocab_counter.most_common(20)) 67 | for word, count in vocab_counter.most_common(VOCAB_SIZE): 68 | vocab_file.write(word + ' ' + str(count) + '\n') 69 | 70 | vocab_file.close() 71 | print("Finished writing vocab file") 72 | 73 | 74 | if __name__ == "__main__": 75 | args = get_args().parse_args() 76 | load_data(args.data_file, args.vocab_path) 77 | -------------------------------------------------------------------------------- /models/pointer_generator/requirements.txt: -------------------------------------------------------------------------------- 1 | pyrouge 2 | spacy 3 | -------------------------------------------------------------------------------- /models/pointer_generator/run_chiqa.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --output=/data/saveryme/asumm/models/pointer-generator/slurm_logs/slurm_%j.out 3 | #SBATCH --error=/data/saveryme/asumm/models/pointer-generator/slurm_logs/slurm_%j.error 4 | #SBATCH --job-name=chiq_eval 5 | #SBATCH --partition=gpu 6 | #SBATCH --gres=gpu:v100x:1 7 | #SBATCH --mem=10g 8 | #SBATCH --cpus-per-task=12 9 | #SBATCH --time=2-00:00:00 10 | 11 | # Can also include the with_question string in for loop 12 | for q in without_question 13 | do 14 | experiment=bioasq_abs2summ_${q} 15 | if [ ${q} == "with_question" ]; then 16 | q_driven=True 17 | else 18 | q_driven=False 19 | fi 20 | for summ_task in page2answer section2answer 21 | do 22 | for 
summ_type in single_abstractive single_extractive 23 | do 24 | data=${summ_task}_${summ_type}_summ.json 25 | input_data=../../data_processing/data/${data} 26 | predict_file=pointergen_chiqa_bioasq_abs2summ_${q}_${summ_task}_${summ_type}.json 27 | echo ${q} ${q_driven} 28 | echo $input_data 29 | echo $predict_file 30 | python run_medsumm.py \ 31 | --mode=decode \ 32 | --data_path=${input_data} \ 33 | --vocab_path=./bioasq_abs2summ_vocab \ 34 | --exp_name=$experiment \ 35 | --single_pass=True \ 36 | --eval_type=medsumm \ 37 | --generated_data_file=../../evaluation/data/pointer_generator/chiqa_eval/${predict_file} \ 38 | --tag_sentences=True \ 39 | --question_driven=${q_driven} 40 | done 41 | done 42 | done 43 | -------------------------------------------------------------------------------- /models/pointer_generator/run_medsumm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # Modifications Copyright 2017 Abigail See 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | """ 18 | This is the top-level file to train, evaluate or test your summarization model 19 | 20 | This script has been converted from the original used for the "Get to the Point" paper so that 21 | it may be implemend for the summarization of consumer health answers 22 | """ 23 | 24 | import sys 25 | import time 26 | import os 27 | import tensorflow as tf 28 | import numpy as np 29 | from collections import namedtuple 30 | from data import Vocab 31 | from batcher import Batcher 32 | from model import SummarizationModel 33 | from decode import BeamSearchDecoder 34 | import util 35 | from tensorflow.python import debug as tf_debug 36 | 37 | FLAGS = tf.app.flags.FLAGS 38 | 39 | # Define whether to use answer summarization processing pipeline or original pipeline for CNN data: 40 | tf.app.flags.DEFINE_boolean('medsumm', True, "Definition for answer summarization data processing") 41 | tf.app.flags.DEFINE_boolean('question_driven', True, "Add question to beginning of text for question-driven summ. Option only applied for inference, as the training datasets are prepared appropriately") 42 | tf.app.flags.DEFINE_boolean('tag_sentences', False, "Add sentences to summaries when data processing for inference. All this does is allow the code to return the ref summaries properly.") 43 | 44 | # Where to find data 45 | tf.app.flags.DEFINE_string('data_path', '', 'Path expression to tf.Example datafiles. Can include wildcards to access multiple datafiles.') 46 | tf.app.flags.DEFINE_string('vocab_path', '', 'Path expression to text vocabulary file.') 47 | 48 | # Important settings 49 | tf.app.flags.DEFINE_string('mode', 'train', 'must be one of train/eval/decode') 50 | tf.app.flags.DEFINE_boolean('single_pass', False, 'For decode mode only. If True, run eval on the full dataset using a fixed checkpoint, i.e. 
take the current checkpoint, and use it to produce one summary for each example in the dataset, write the summaries to file and then get ROUGE scores for the whole dataset. If False (default), run concurrent decoding, i.e. repeatedly load latest checkpoint, use it to produce summaries for randomly-chosen examples and log the results to screen, indefinitely.') 51 | # If decode is true and single_pass is true, define whether to use original rouge evaluation code, or 52 | # write data for later answer summ evaluation 53 | tf.app.flags.DEFINE_string('eval_type', 'medsumm', 'must be medsumm or cnn') 54 | tf.app.flags.DEFINE_string('generated_data_file', '/data/saveryme/asumm/asumm_data/generated_data.json', 'json path for genearted data and reference summaries') 55 | 56 | # Where to save output 57 | tf.app.flags.DEFINE_string('log_root', '', 'Root directory for all logging.') 58 | tf.app.flags.DEFINE_string('exp_name', '', 'Name for experiment. Logs will be saved in a directory with this name, under log_root.') 59 | 60 | # Hyperparameters 61 | tf.app.flags.DEFINE_integer('hidden_dim', 256, 'dimension of RNN hidden states') 62 | tf.app.flags.DEFINE_integer('emb_dim', 128, 'dimension of word embeddings') 63 | tf.app.flags.DEFINE_integer('batch_size', 16, 'minibatch size') 64 | tf.app.flags.DEFINE_integer('max_enc_steps', 400, 'max timesteps of encoder (max source text tokens)') 65 | tf.app.flags.DEFINE_integer('max_dec_steps', 100, 'max timesteps of decoder (max summary tokens)') 66 | tf.app.flags.DEFINE_integer('beam_size', 4, 'beam size for beam search decoding.') 67 | tf.app.flags.DEFINE_integer('min_dec_steps', 35, 'Minimum sequence length of generated summary. Applies only for beam search decoding mode') 68 | tf.app.flags.DEFINE_integer('vocab_size', 50000, 'Size of vocabulary. These will be read from the vocabulary file in order. If the vocabulary file contains fewer words than this number, or if this number is set to 0, will take all words in the vocabulary file.') 69 | tf.app.flags.DEFINE_float('lr', 0.15, 'learning rate') 70 | tf.app.flags.DEFINE_float('adagrad_init_acc', 0.1, 'initial accumulator value for Adagrad') 71 | tf.app.flags.DEFINE_float('rand_unif_init_mag', 0.02, 'magnitude for lstm cells random uniform inititalization') 72 | tf.app.flags.DEFINE_float('trunc_norm_init_std', 1e-4, 'std of trunc norm init, used for initializing everything else') 73 | tf.app.flags.DEFINE_float('max_grad_norm', 2.0, 'for gradient clipping') 74 | 75 | # Pointer-generator or baseline model 76 | tf.app.flags.DEFINE_boolean('pointer_gen', True, 'If True, use pointer-generator model. If False, use baseline model.') 77 | 78 | # Coverage hyperparameters 79 | tf.app.flags.DEFINE_boolean('coverage', False, 'Use coverage mechanism. Note, the experiments reported in the ACL paper train WITHOUT coverage until converged, and then train for a short phase WITH coverage afterwards. i.e. to reproduce the results in the ACL paper, turn this off for most of training then turn on for a short phase at the end.') 80 | tf.app.flags.DEFINE_float('cov_loss_wt', 1.0, 'Weight of coverage loss (lambda in the paper). If zero, then no incentive to minimize coverage loss.') 81 | 82 | # Utility flags, for restoring and changing checkpoints 83 | tf.app.flags.DEFINE_boolean('convert_to_coverage_model', False, 'Convert a non-coverage model to a coverage model. Turn this on and run in train mode. 
Your current training model will be copied to a new version (same name with _cov_init appended) that will be ready to run with coverage flag turned on, for the coverage training stage.') 84 | tf.app.flags.DEFINE_boolean('restore_best_model', False, 'Restore the best model in the eval/ dir and save it in the train/ dir, ready to be used for further training. Useful for early stopping, or if your training checkpoint has become corrupted with e.g. NaN values.') 85 | 86 | # Debugging. See https://www.tensorflow.org/programmers_guide/debugger 87 | tf.app.flags.DEFINE_boolean('debug', False, "Run in tensorflow's debug mode (watches for NaN/inf values)") 88 | 89 | 90 | 91 | def calc_running_avg_loss(loss, running_avg_loss, summary_writer, step, decay=0.99): 92 | """Calculate the running average loss via exponential decay. 93 | This is used to implement early stopping w.r.t. a more smooth loss curve than the raw loss curve. 94 | 95 | Args: 96 | loss: loss on the most recent eval step 97 | running_avg_loss: running_avg_loss so far 98 | summary_writer: FileWriter object to write for tensorboard 99 | step: training iteration step 100 | decay: rate of exponential decay, a float between 0 and 1. Larger is smoother. 101 | 102 | Returns: 103 | running_avg_loss: new running average loss 104 | """ 105 | if running_avg_loss == 0: # on the first iteration just take the loss 106 | running_avg_loss = loss 107 | else: 108 | running_avg_loss = running_avg_loss * decay + (1 - decay) * loss 109 | running_avg_loss = min(running_avg_loss, 12) # clip 110 | loss_sum = tf.Summary() 111 | tag_name = 'running_avg_loss/decay=%f' % (decay) 112 | loss_sum.value.add(tag=tag_name, simple_value=running_avg_loss) 113 | summary_writer.add_summary(loss_sum, step) 114 | tf.logging.info('running_avg_loss: %f', running_avg_loss) 115 | return running_avg_loss 116 | 117 | 118 | def restore_best_model(): 119 | """Load bestmodel file from eval directory, add variables for adagrad, and save to train directory""" 120 | tf.logging.info("Restoring bestmodel for training...") 121 | 122 | # Initialize all vars in the model 123 | sess = tf.Session(config=util.get_config()) 124 | print("Initializing all variables...") 125 | sess.run(tf.initialize_all_variables()) 126 | 127 | # Restore the best model from eval dir 128 | saver = tf.train.Saver([v for v in tf.all_variables() if "Adagrad" not in v.name]) 129 | print("Restoring all non-adagrad variables from best model in eval dir...") 130 | curr_ckpt = util.load_ckpt(saver, sess, "eval") 131 | print ("Restored %s." % curr_ckpt) 132 | 133 | # Save this model to train dir and quit 134 | new_model_name = curr_ckpt.split("/")[-1].replace("bestmodel", "model") 135 | new_fname = os.path.join(FLAGS.log_root, "train", new_model_name) 136 | print ("Saving model to %s..." 
% (new_fname)) 137 | new_saver = tf.train.Saver() # this saver saves all variables that now exist, including Adagrad variables 138 | new_saver.save(sess, new_fname) 139 | print ("Saved.") 140 | exit() 141 | 142 | 143 | def convert_to_coverage_model(): 144 | """Load non-coverage checkpoint, add initialized extra variables for coverage, and save as new checkpoint""" 145 | tf.logging.info("converting non-coverage model to coverage model..") 146 | 147 | # initialize an entire coverage model from scratch 148 | sess = tf.Session(config=util.get_config()) 149 | print("initializing everything...") 150 | sess.run(tf.global_variables_initializer()) 151 | 152 | # load all non-coverage weights from checkpoint 153 | saver = tf.train.Saver([v for v in tf.global_variables() if "coverage" not in v.name and "Adagrad" not in v.name]) 154 | print("restoring non-coverage variables...") 155 | curr_ckpt = util.load_ckpt(saver, sess) 156 | print("restored.") 157 | 158 | # save this model and quit 159 | new_fname = curr_ckpt + '_cov_init' 160 | print("saving model to %s..." % (new_fname)) 161 | new_saver = tf.train.Saver() # this one will save all variables that now exist 162 | new_saver.save(sess, new_fname) 163 | print("saved.") 164 | exit() 165 | 166 | 167 | def setup_training(model, batcher): 168 | """Does setup before starting training (run_training)""" 169 | train_dir = os.path.join(FLAGS.log_root, "train") 170 | if not os.path.exists(train_dir): os.makedirs(train_dir) 171 | 172 | model.build_graph() # build the graph 173 | if FLAGS.convert_to_coverage_model: 174 | assert FLAGS.coverage, "To convert your non-coverage model to a coverage model, run with convert_to_coverage_model=True and coverage=True" 175 | convert_to_coverage_model() 176 | if FLAGS.restore_best_model: 177 | restore_best_model() 178 | saver = tf.train.Saver(max_to_keep=3) # keep 3 checkpoints at a time 179 | 180 | sv = tf.train.Supervisor(logdir=train_dir, 181 | is_chief=True, 182 | saver=saver, 183 | summary_op=None, 184 | save_summaries_secs=60, # save summaries for tensorboard every 60 secs 185 | save_model_secs=60, # checkpoint every 60 secs 186 | global_step=model.global_step) 187 | summary_writer = sv.summary_writer 188 | tf.logging.info("Preparing or waiting for session...") 189 | sess_context_manager = sv.prepare_or_wait_for_session(config=util.get_config()) 190 | tf.logging.info("Created session.") 191 | try: 192 | run_training(model, batcher, sess_context_manager, sv, summary_writer) # this is an infinite loop until interrupted 193 | except KeyboardInterrupt: 194 | tf.logging.info("Caught keyboard interrupt on worker. 
Stopping supervisor...") 195 | sv.stop() 196 | 197 | 198 | def run_training(model, batcher, sess_context_manager, sv, summary_writer): 199 | """Repeatedly runs training iterations, logging loss to screen and writing summaries""" 200 | tf.logging.info("starting run_training") 201 | with sess_context_manager as sess: 202 | if FLAGS.debug: # start the tensorflow debugger 203 | sess = tf_debug.LocalCLIDebugWrapperSession(sess) 204 | sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan) 205 | while True: # repeats until interrupted 206 | batch = batcher.next_batch() 207 | #tf.logging.info(vars(batch)) 208 | 209 | tf.logging.info('running training step...') 210 | t0=time.time() 211 | results = model.run_train_step(sess, batch) 212 | t1=time.time() 213 | tf.logging.info('seconds for training step: %.3f', t1-t0) 214 | 215 | loss = results['loss'] 216 | tf.logging.info('loss: %f', loss) # print the loss to screen 217 | 218 | if not np.isfinite(loss): 219 | raise Exception("Loss is not finite. Stopping.") 220 | 221 | if FLAGS.coverage: 222 | coverage_loss = results['coverage_loss'] 223 | tf.logging.info("coverage_loss: %f", coverage_loss) # print the coverage loss to screen 224 | 225 | # get the summaries and iteration number so we can write summaries to tensorboard 226 | summaries = results['summaries'] # we will write these summaries to tensorboard using summary_writer 227 | train_step = results['global_step'] # we need this to update our running average loss 228 | 229 | tf.logging.info("\nTrain step: {}\n".format(train_step)) 230 | summary_writer.add_summary(summaries, train_step) # write the summaries 231 | if train_step % 100 == 0: # flush the summary writer every so often 232 | summary_writer.flush() 233 | if train_step == 10000: 234 | break 235 | 236 | 237 | def run_eval(model, batcher, vocab): 238 | """Repeatedly runs eval iterations, logging to screen and writing summaries. 
Saves the model with the best loss seen so far.""" 239 | model.build_graph() # build the graph 240 | saver = tf.train.Saver(max_to_keep=3) # we will keep 3 best checkpoints at a time 241 | sess = tf.Session(config=util.get_config()) 242 | eval_dir = os.path.join(FLAGS.log_root, "eval") # make a subdir of the root dir for eval data 243 | bestmodel_save_path = os.path.join(eval_dir, 'bestmodel') # this is where checkpoints of best models are saved 244 | summary_writer = tf.summary.FileWriter(eval_dir) 245 | running_avg_loss = 0 # the eval job keeps a smoother, running average loss to tell it when to implement early stopping 246 | best_loss = None # will hold the best loss achieved so far 247 | 248 | while True: 249 | _ = util.load_ckpt(saver, sess) # load a new checkpoint 250 | batch = batcher.next_batch() # get the next batch 251 | 252 | # run eval on the batch 253 | t0=time.time() 254 | results = model.run_eval_step(sess, batch) 255 | t1=time.time() 256 | tf.logging.info('seconds for batch: %.2f', t1-t0) 257 | 258 | # print the loss and coverage loss to screen 259 | loss = results['loss'] 260 | tf.logging.info('loss: %f', loss) 261 | if FLAGS.coverage: 262 | coverage_loss = results['coverage_loss'] 263 | tf.logging.info("coverage_loss: %f", coverage_loss) 264 | 265 | # add summaries 266 | summaries = results['summaries'] 267 | train_step = results['global_step'] 268 | summary_writer.add_summary(summaries, train_step) 269 | 270 | running_avg_loss = calc_running_avg_loss(np.asscalar(loss), running_avg_loss, summary_writer, train_step) 271 | 272 | # If running_avg_loss is best so far, save this checkpoint (early stopping). 273 | # These checkpoints will appear as bestmodel- in the eval dir 274 | if best_loss is None or running_avg_loss < best_loss: 275 | tf.logging.info('Found new best model with %.3f running_avg_loss. Saving to %s', running_avg_loss, bestmodel_save_path) 276 | saver.save(sess, bestmodel_save_path, global_step=train_step, latest_filename='checkpoint_best') 277 | best_loss = running_avg_loss 278 | 279 | # flush the summary writer every so often 280 | if train_step % 100 == 0: 281 | summary_writer.flush() 282 | 283 | 284 | def main(unused_argv): 285 | #if len(unused_argv) != 1: # prints a message if you've entered flags incorrectly 286 | # raise Exception("Problem with flags: %s" % unused_argv) 287 | 288 | tf.logging.set_verbosity(tf.logging.INFO) # choose what level of logging you want 289 | tf.logging.info('Starting seq2seq_attention in %s mode...', (FLAGS.mode)) 290 | 291 | # Change log_root to FLAGS.log_root/FLAGS.exp_name and create the dir if necessary 292 | FLAGS.log_root = os.path.join(FLAGS.log_root, FLAGS.exp_name) 293 | if not os.path.exists(FLAGS.log_root): 294 | if FLAGS.mode=="train": 295 | os.makedirs(FLAGS.log_root) 296 | else: 297 | raise Exception("Logdir %s doesn't exist. Run in train mode to create it." % (FLAGS.log_root)) 298 | 299 | vocab = Vocab(FLAGS.vocab_path, FLAGS.vocab_size) # create a vocabulary 300 | 301 | # If in decode mode, set batch_size = beam_size 302 | # Reason: in decode mode, we decode one example at a time. 303 | # On each step, we have beam_size-many hypotheses in the beam, so we need to make a batch of these hypotheses. 
304 | if FLAGS.mode == 'decode': 305 | FLAGS.batch_size = FLAGS.beam_size 306 | 307 | # If single_pass=True, check we're in decode mode 308 | if FLAGS.single_pass and FLAGS.mode!='decode': 309 | raise Exception("The single_pass flag should only be True in decode mode") 310 | 311 | # Make a namedtuple hps, containing the values of the hyperparameters that the model needs 312 | hparam_list = ['mode', 'lr', 'adagrad_init_acc', 'rand_unif_init_mag', 'trunc_norm_init_std', 'max_grad_norm', 'hidden_dim', 'emb_dim', 'batch_size', 'max_dec_steps', 'max_enc_steps', 'coverage', 'cov_loss_wt', 'pointer_gen'] 313 | hps_dict = {} 314 | for key,val in FLAGS.__flags.items(): # for each flag 315 | if key in hparam_list: # if it's in the list 316 | hps_dict[key] = val # add it to the dict 317 | hps = namedtuple("HParams", hps_dict.keys())(**hps_dict) 318 | 319 | # Create a batcher object that will create minibatches of data 320 | batcher = Batcher(FLAGS.data_path, vocab, hps, single_pass=FLAGS.single_pass, medsumm=FLAGS.medsumm, question_driven=FLAGS.question_driven, tag_sentences=FLAGS.tag_sentences) 321 | 322 | tf.set_random_seed(111) # a seed value for randomness 323 | 324 | if hps.mode == 'train': 325 | print("creating model...") 326 | model = SummarizationModel(hps, vocab) 327 | setup_training(model, batcher) 328 | elif hps.mode == 'eval': 329 | model = SummarizationModel(hps, vocab) 330 | run_eval(model, batcher, vocab) 331 | elif hps.mode == 'decode': 332 | decode_model_hps = hps # This will be the hyperparameters for the decoder model 333 | decode_model_hps = hps._replace(max_dec_steps=1) # The model is configured with max_dec_steps=1 because we only ever run one step of the decoder at a time (to do beam search). Note that the batcher is initialized with max_dec_steps equal to e.g. 
100 because the batches need to contain the full summaries 334 | model = SummarizationModel(decode_model_hps, vocab) 335 | decoder = BeamSearchDecoder(model, batcher, vocab) 336 | decoder.decode() # decode indefinitely (unless single_pass=True, in which case deocde the dataset exactly once) 337 | else: 338 | raise ValueError("The 'mode' flag must be one of train/eval/decode") 339 | 340 | if __name__ == '__main__': 341 | tf.app.run() 342 | -------------------------------------------------------------------------------- /models/pointer_generator/submit_sbatch_eval.sh: -------------------------------------------------------------------------------- 1 | for q in with_question without_question 2 | do 3 | echo $q 4 | sbatch --job-name=${q}_eval --export=QDRIVEN=$q eval_medsumm.sh 5 | done 6 | -------------------------------------------------------------------------------- /models/pointer_generator/train_medsumm.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --output=/data/saveryme/asumm/models/pointer-generator/slurm_logs/slurm_%j.out 3 | #SBATCH --error=/data/saveryme/asumm/models/pointer-generator/slurm_logs/slurm_%j.error 4 | #SBATCH --partition=gpu 5 | #SBATCH --gres=gpu:p100:1 6 | #SBATCH --mem=40g 7 | #SBATCH --cpus-per-task=12 8 | #SBATCH --time=2-00:00:00 9 | 10 | training_data=../../data_processing/data/bioasq_abs2summ_training_data_${1}.json 11 | # With question: 12 | if [ $1 == "with_question" ]; then 13 | echo $1 14 | mkdir bioasq_abs2summ_with_question 15 | experiment=bioasq_abs2summ_with_question 16 | rm -r ${experiment}/* 17 | python -u run_medsumm.py --mode=train --data_path=$training_data --vocab_path=./bioasq_abs2summ_vocab --exp_name=$experiment 18 | fi 19 | 20 | # And without question: 21 | if [ $1 == "without_question" ]; then 22 | echo $1 23 | mkdir bioasq_abs2summ_without_question 24 | experiment=bioasq_abs2summ_without_question 25 | rm -r ${experiment}/* 26 | python -u run_medsumm.py --mode=train --data_path=$training_data --vocab_path=./bioasq_abs2summ_vocab --exp_name=$experiment 27 | fi 28 | -------------------------------------------------------------------------------- /models/pointer_generator/util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # Modifications Copyright 2017 Abigail See 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | 17 | """This file contains some utility functions""" 18 | 19 | import tensorflow as tf 20 | import time 21 | import os 22 | FLAGS = tf.app.flags.FLAGS 23 | 24 | def get_config(): 25 | """Returns config for tf.session""" 26 | config = tf.ConfigProto(allow_soft_placement=True) 27 | config.gpu_options.allow_growth=True 28 | return config 29 | 30 | def load_ckpt(saver, sess, ckpt_dir="train"): 31 | """Load checkpoint from the ckpt_dir (if unspecified, this is train dir) and restore it to saver and sess, waiting 10 secs in the case of failure. Also returns checkpoint name.""" 32 | while True: 33 | try: 34 | latest_filename = "checkpoint_best" if ckpt_dir=="eval" else None 35 | ckpt_dir = os.path.join(FLAGS.log_root, ckpt_dir) 36 | ckpt_state = tf.train.get_checkpoint_state(ckpt_dir, latest_filename=latest_filename) 37 | tf.logging.info('Loading checkpoint %s', ckpt_state.model_checkpoint_path) 38 | saver.restore(sess, ckpt_state.model_checkpoint_path) 39 | return ckpt_state.model_checkpoint_path 40 | except: 41 | tf.logging.info("Failed to load checkpoint from %s. Sleeping for %i secs...", ckpt_dir, 10) 42 | time.sleep(10) 43 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | lxml 2 | numpy 3 | spacy 4 | pandas 5 | sklearn 6 | py-rouge 7 | nltk 8 | tqdm 9 | openpyxl 10 | --------------------------------------------------------------------------------
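Note: when the pointer-generator is run through run_chiqa.sh above (decode mode with --eval_type=medsumm), each configuration writes its generated answers to evaluation/data/pointer_generator/chiqa_eval/ as a JSON object holding parallel lists under the keys "question", "ref_summary", and "gen_summary" (see write_data_for_medsumm_eval in decode.py). The snippet below is a minimal, illustrative sketch for spot-checking one of those files before running the evaluation scripts; it is not part of the repository, and the file name is just one of the names run_chiqa.sh constructs, so adjust the path to whichever run you produced.
```
# Illustrative sanity check (not part of the repository): load one generated-summary
# file written by run_chiqa.sh and print a few question / reference / generated triples.
import json

# Example file name as constructed in run_chiqa.sh, relative to the repo root;
# change it to match the task/question setting you actually decoded.
path = "evaluation/data/pointer_generator/chiqa_eval/pointergen_chiqa_bioasq_abs2summ_without_question_page2answer_single_abstractive.json"

with open(path, encoding="utf-8") as f:
    generated = json.load(f)

# The three lists are aligned by example index.
triples = zip(generated["question"], generated["ref_summary"], generated["gen_summary"])
for question, reference, prediction in list(triples)[:3]:
    print("QUESTION: ", question)
    print("REFERENCE:", reference)
    print("GENERATED:", prediction)
    print("-" * 80)
```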