├── .gitignore
├── LICENSE
├── README.md
├── data_processing
│   ├── __pycache__
│   │   ├── PubMedClient.cpython-36.pyc
│   │   └── PubMedClient.cpython-37.pyc
│   ├── data
│   │   ├── bioasq_pubmed_articles.json
│   │   └── medinfo_collection.json
│   ├── prepare_training_data.py
│   ├── prepare_validation_data.py
│   ├── process_bioasq.py
│   └── process_medinfo.py
├── evaluation
│   ├── collection_statistics.py
│   ├── data
│   │   ├── bart
│   │   │   └── chiqa_eval
│   │   │       └── .gitignore
│   │   ├── baselines
│   │   │   └── chiqa_eval
│   │   │       └── .gitignore
│   │   ├── pointer_generator
│   │   │   └── chiqa_eval
│   │   │       └── .gitignore
│   │   └── sentence_classifier
│   │       └── chiqa_eval
│   │           └── .gitignore
│   ├── results
│   │   └── .gitignore
│   └── summarization_evaluation.py
├── models
│   ├── bart
│   │   ├── bart_config
│   │   │   └── .gitignore
│   │   ├── finetune_bart_bioasq.sh
│   │   ├── make_datafiles_for_bart.py
│   │   ├── process_bioasq_data.sh
│   │   ├── run_chiqa.sh
│   │   ├── run_inference_medsumm.py
│   │   └── train.py
│   ├── baselines.py
│   ├── bilstm
│   │   ├── __pycache__
│   │   │   ├── data.cpython-36.pyc
│   │   │   ├── data.cpython-37.pyc
│   │   │   ├── model.cpython-36.pyc
│   │   │   └── model.cpython-37.pyc
│   │   ├── data.py
│   │   ├── model.py
│   │   ├── requirements.txt
│   │   ├── run_chiqa.sh
│   │   ├── run_classifier.py
│   │   └── train_sentence_classifier.sh
│   └── pointer_generator
│       ├── LICENSE.txt
│       ├── __init__.py
│       ├── __pycache__
│       │   ├── attention_decoder.cpython-36.pyc
│       │   ├── batcher.cpython-36.pyc
│       │   ├── beam_search.cpython-36.pyc
│       │   ├── data.cpython-36.pyc
│       │   ├── decode.cpython-36.pyc
│       │   ├── model.cpython-36.pyc
│       │   └── util.cpython-36.pyc
│       ├── attention_decoder.py
│       ├── batcher.py
│       ├── beam_search.py
│       ├── bioasq_abs2summ_vocab
│       ├── data.py
│       ├── decode.py
│       ├── eval_medsumm.sh
│       ├── inspect_checkpoint.py
│       ├── make_asumm_pg_vocab.py
│       ├── model.py
│       ├── requirements.txt
│       ├── run_chiqa.sh
│       ├── run_medsumm.py
│       ├── submit_sbatch_eval.sh
│       ├── train_medsumm.sh
│       └── util.py
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | data_processing/data/BioASQ-training7b
2 | data_processing/data/MedQuAD-master
3 | data_processing/data/MedInfo2019-QA-Medications.xlsx
4 | data_processing/data/*multi*.json
5 | data_processing/data/*single*.json
6 | data_processing/data/question_driven_answer_summarization_primary_dataset.json
7 | data_processing/data/bioasq_collection.json
8 | data_processing/data/bioasq_ideal_answers.json
9 | data_processing/data/bioasq_snippets.json
10 | data_processing/data/medinfo_section*.json
11 | data_processing/data/bioasq_abs2summ_binary_sent_classification_training.json
12 | data_processing/data/bioasq_abs2summ_training_data_with*
13 | data_processing/data/bioasq_snippets.json
14 | models/pointer_generator/bioasq_abs2summ*
15 | models/bart/dict.txt
16 | models/bart/apex
17 | models/bart/fairseq*
18 | models/bart/encoder.json
19 | models/bart/vocab.bpe
20 | models/bart/checkpoints*
21 | models/bart/bart.large
22 | models/bart/bart_config/without_question/*
23 | models/bart/bart_config/with_question/*
24 | models/bilstm/medsumm*
25 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Facebook, Inc. and its affiliates.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Question-Driven Summarization of Answers to Consumer Health Questions
2 | This repository contains the code to process the data and run the answer summarization systems presented in the paper Question-Driven Summarization of Answers to Consumer Health Questions, available at https://www.nature.com/articles/s41597-020-00667-z.
3 |
4 | If you are interested in just downloading the data, please refer to https://doi.org/10.17605/OSF.IO/FYG46. However, if you are interested in repeating the experiments reported in the paper, clone this repository and move the data found at https://doi.org/10.17605/OSF.IO/FYG46 to the evaluation/data directory.
5 |
6 | ## Environments
7 | All instructions provided here have been tested on a Linux operating system. To train the models and run the experiments, you will need to set up a few environments with anaconda: data processing and evaluation; the BiLSTM; BART; and the Pointer-Generator. Since we are going to be processing the data first, create the following environment to install the data processing and evaluation dependencies:
8 | ```
9 | conda create -n qdriven_env python=3.7
10 | conda activate qdriven_env
11 | pip install -r requirements.txt
12 | ```
13 | The requirements.txt file is found in the base directory of this repository.
14 | The spacy tokenizer model will come in handy later on as well:
15 | ```
16 | python -m spacy download en_core_web_sm
17 | ```
18 | And because py-rouge uses nltk, we also need nltk:
19 | ```
20 | python
21 | >>> import nltk
22 | >>> nltk.download('punkt')
23 | ```
24 | There are more details regarding the environments required to train each model in the following sections.
25 |
26 |
27 | ## Answer Summarization
28 | In the models directory, there are six systems:
29 | Deep learning:
30 | 1. BiLSTM
31 | 2. Pointer Generator
32 | 3. BART
33 |
34 | Baselines:
35 | 1. LEAD-k
36 | 2. Random-k sentences
37 | 3. k ROUGE sentences
38 |
39 |
40 | ### Running baselines
41 | Running the baselines is simple and can be done while the qdriven_env environment is active.
42 | ```
43 | python baselines.py --dataset=chiqa
44 | ```
45 | This will run the baseline summarization methods on the two summarization tasks reported in the paper, as well as on the shorter passages. k (the number of sentences selected by the baselines) can be changed in the script.
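For orientation, the baselines are simple sentence-selection heuristics. Below is a minimal sketch of the LEAD-k idea, assuming spaCy sentencization as used elsewhere in this repository (the actual implementation in baselines.py may differ):
```
import spacy

nlp = spacy.load("en_core_web_sm")

def lead_k(article, k=3):
    """Take the first k sentences of the article as the summary."""
    sentences = [sent.text.strip() for sent in nlp(article).sents]
    return " ".join(sentences[:k])

# Illustrative usage
print(lead_k("Aspirin is a common pain reliever. It also thins the blood. Overuse can cause ulcers.", k=2))
```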
46 |
47 |
48 | ### Running deep learning
49 | The following code is organized to train and run inference with all models first, and then use the summarization_evaluation.py script to evaluate all results at once. This section describes the steps for training and inference.
50 |
51 |
52 | #### Training Preprocessing
53 | First prepare the validation data for the Pointer-Generator and BART:
54 | ```
55 | python prepare_validation_data.py --pg --bart
56 | ```
57 | Optionally, include the --add-q option if you are interested in training models for question-driven summarization.
58 |
59 | To create the training data, the BioASQ data for training first has to be acquired. To do so, you have to register for an account at http://bioasq.org/participate.
60 |
61 | In the participants area of the website, you can find a list of all the datasets previously released for the BioASQ challenge. Download the 7b version of the task. Once the BioASQ data has been downloaded and unzipped, it should be placed in the data_processing/data directory of the cloned github repository, so that the path relative to the data_processing directory looks like ```data_processing/data/BioASQ-training7b/BioASQ-training7b/training7b.json```. Note that we used version 7b of BioASQ for training and testing. You are welcome to experiment with 8b or newer, but you will have to fix the paths in the code.
62 |
63 | Once the data is in the correct place, run the following scripts:
64 | ```
65 | python process_bioasq.py -p
66 | python prepare_training_data.py -bt --bart-bioasq --bioasq-sent
67 | ```
68 | This will prepare separate training sets for the three deep learning models. Include the ```--add-q``` option to create additional datasets with the question concatenated to the beginning of the documents, for question-driven summarization. This step will take a while to finish. Once it is done, you are ready for training and inference.
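With ```--add-q```, the question is simply prepended to the source document with a separator token, mirroring prepare_training_data.py (the example strings below are illustrative):
```
Q_END = " [QUESTION?] "  # separator token used in prepare_training_data.py

question = "What is the treatment for hypertension?"     # illustrative
abstract = "Hypertension is typically managed with ..."  # illustrative
source_text = question + Q_END + abstract
# -> "What is the treatment for hypertension? [QUESTION?] Hypertension is typically managed with ..."
```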
69 |
70 |
71 | #### BiLSTM (sentence classification)
72 | You will first need to set up a tensorflow2-gpu environment and install the dependencies for the model:
73 | ```
74 | conda create -n tf2_env tensorflow-gpu=2.0 python=3.7
75 | conda activate tf2_env
76 | pip install -r requirements.txt
77 | python -m spacy download en_core_web_sm
78 | ```
79 | Use the requirements.txt file located in the models/bilstm directory.
80 | Once the environment is set up, you are ready for training. Training the BiLSTM is considerably simpler than training the two generative models.
81 | ```
82 | train_sentence_classifier.sh
83 | ```
84 | The training script will automatically save the checkpoint that performs best on the validation set. Training will end after 10 epochs. The training script is configured for TensorBoard, and you can monitor the loss by running tensorboard in the medsumm_bioasq_abs2summ directory.
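For example, assuming the experiment directory created by the training script is named medsumm_bioasq_abs2summ:
```
tensorboard --logdir=medsumm_bioasq_abs2summ
```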
85 |
86 | Once the BiLSTM is trained, the following script is provided to run the model on all single-document summarization tasks in MEDIQA-AnS:
87 | ```
88 | run_chiqa.sh
89 | ```
90 | You are now able to evaluate the BiLSTM output with the evaluation script. During inference, the run_classifier.py script will also create output files that can be used as input for inference with the Pointer-Generator or BART. k can be changed in the run_chiqa.sh script to experiment with passing the top k sentences to the generative models.
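The idea behind passing the classifier output downstream is to keep only the k highest-scoring sentences as the new source document. A hedged sketch of that selection step (the names below are illustrative, not the actual run_classifier.py interface):
```
def select_top_k(sentences, scores, k=10):
    """Keep the k sentences with the highest classifier scores, preserving article order."""
    top_indices = sorted(sorted(range(len(sentences)), key=lambda i: scores[i], reverse=True)[:k])
    return " ".join(sentences[i] for i in top_indices)
```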
91 |
92 |
93 | #### BART
94 | Download BART into the models/bart directory in this repository.
95 | ```
96 | wget https://dl.fbaipublicfiles.com/fairseq/models/bart.large.tar.gz
97 | tar -xzvf bart.large.tar.gz
98 | ```
99 | Navigate to the models/bart directory and prepare an environment for BART.
100 | ```
101 | conda create -n pytorch_env python=3.7 pytorch torchvision cudatoolkit=10.1 -c pytorch
102 | conda activate pytorch_env
103 | pip install -r requirements.txt
104 | ```
105 | To use automatic mixed precision (optional), follow these steps once pytorch is installed:
106 | ```
107 | conda install -n pytorch_env -c anaconda nccl
108 | git clone https://github.com/NVIDIA/apex
109 | cd apex
110 | pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--deprecated_fused_adam" ./
111 | ```
112 | These instructions are provided in the main fairseq readme (https://github.com/pytorch/fairseq), but we have condensed them here. Note that to install apex, you should first make sure your GCC compiler is up to date.
113 | Once these dependencies have been installed, you are ready to install fairseq. This requires installing an editable version of an earlier commit of the repository. Navigate back to the models/bart directory of this repo and run:
114 | ```
115 | git clone https://github.com/pytorch/fairseq
116 | cd fairseq
117 | git checkout 43cf9c977b8470ec493cc32d248fdcd9a984a9f6
118 | pip install --editable .
119 | ```
120 | Because the fairseq repository contains research projects under continuous development, the earlier commit checked out here should be used to recreate the results; using a current version of fairseq may require more troubleshooting. Once you have fairseq installed and BART downloaded to the bart directory, there are a few steps you have to take to get the BioASQ data into a suitable format for finetuning BART.
121 | First, run
122 | ```
123 | bash process_bioasq_data.sh -b -f without_question
124 | ```
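This will prepare the byte-pair encodings and run the fairseq preprocessing. Under the hood, the script roughly follows fairseq's standard BART summarization recipe; a hedged sketch of the equivalent commands (the exact paths and options in process_bioasq_data.sh may differ):
```
# Encode the source/target text with the GPT-2 BPE (encoder.json and vocab.bpe are
# downloaded alongside bart.large), then binarize with fairseq-preprocess.
for SPLIT in train val; do
  for LANG in source target; do
    python -m examples.roberta.multiprocessing_bpe_encoder \
      --encoder-json encoder.json --vocab-bpe vocab.bpe \
      --inputs bart_config/without_question/bart.${SPLIT}_without_question.${LANG} \
      --outputs bart_config/without_question/${SPLIT}.bpe.${LANG} \
      --workers 8 --keep-empty
  done
done

fairseq-preprocess --source-lang source --target-lang target \
  --trainpref bart_config/without_question/train.bpe \
  --validpref bart_config/without_question/val.bpe \
  --destdir bart_config/without_question/bart-bin \
  --srcdict dict.txt --tgtdict dict.txt --workers 8
```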
125 | with_question will prepare data for question-driven summarization, without_question for plain summarization. Once the processing is complete, you can finetune BART with
126 | ```
127 | bash finetune_bart_bioasq.sh without_question
128 | ```
129 | If you have been testing question-driven summarization, include with_question instead. The larger your computing cluster, the faster you will be able to train. For the experiments presented in the paper, we trained BART for two days on three V100-SXM2 GPUs with 32GB of memory each. The bash script provided here is currently configured to run with one GPU; however, the fairseq library supports multi-gpu training.
130 |
131 | Once you have finetuned the model, run inference on the MEDIQA-AnS dataset with
132 | ```
133 | bash run_chiqa.sh without_question
134 | ```
135 | Or use with_question if you have trained the appropriate model.
136 |
137 |
138 | #### Pointer-Generator
139 | Navigate to the models/pointer_generator directory and create a new environment:
140 | ```
141 | conda create -n tf1_env python=3.6
142 | conda activate tf1_env
143 | wget https://files.pythonhosted.org/packages/cb/4d/c9c4da41c6d7b9a4949cb9e53c7032d7d9b7da0410f1226f7455209dd962/tensorflow_gpu-1.2.0-cp36-cp36m-manylinux1_x86_64.whl
144 | pip install tensorflow_gpu-1.2.0-cp36-cp36m-manylinux1_x86_64.whl
145 | pip install -r requirements.txt
146 | python -m spacy download en_core_web_sm
147 | ```
148 | The tensorflow 1.2.0 version is only available via download from pypi.org, hence the use of ```wget``` first. To train the model, you will have to install cuDNN 5 and CUDA 8. Once these are configured on your machine, you are ready for training.
149 | The Python 3 version of the Pointer-Generator code from https://github.com/becxer/pointer-generator/ (forked from https://github.com/abisee/pointer-generator) is provided in the models/pointer_generator directory here. The code has been customized to support answer summarization data processing steps, involving changes to data.py, batcher.py, decode.py, and run_summarization.py. However, the model (in model.py) remains the same.
150 |
151 | To use the Pointer-Generator, from the pointer_generator directory you will have to run
152 | ```
153 | python make_asumm_pg_vocab.py --vocab_path=bioasq_abs2summ_vocab --data_file=../../data_processing/data/bioasq_abs2summ_training_data_without_question.json
154 | ```
155 | first, to prepare the BioASQ vocab. If you are focusing on question-driven summarization, provide that dataset instead.
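The Pointer-Generator expects a plain-text vocabulary of space-separated word/count pairs, one per line. A hedged sketch of what make_asumm_pg_vocab.py roughly does (the script's exact tokenization and vocabulary size may differ):
```
import json
from collections import Counter

with open("../../data_processing/data/bioasq_abs2summ_training_data_without_question.json",
          encoding="utf-8") as f:
    data = json.load(f)

# Count tokens over every summary and source article in the training file
counts = Counter()
for example in data.values():
    counts.update(example["summary"].split())
    counts.update(example["articles"].split())

# Write "word count" pairs, most frequent first
with open("bioasq_abs2summ_vocab", "w", encoding="utf-8") as out:
    for word, count in counts.most_common(50000):
        out.write("{} {}\n".format(word, count))
```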
156 |
157 | Then, to train, you will need to run two jobs: one to train, and the other to evaluate the checkpoints simultaneously. Run these commands independently, on two different GPUs:
158 | ```
159 | bash train_medsumm.sh without_question
160 | bash eval_medsumm.sh without_question
161 | ```
162 | If you have access to a computing cluster that uses slurm, you may find it useful to use sbatch to submit these jobs.
163 | You will have to monitor the training of the Pointer-Generator via tensorboard and manually end the job once the loss has satisfactorily converged. The checkpoint that performs best on the MedInfo validation set will be saved to variable-name-of-experiment-directory/eval/checkpoint_best.
164 |
165 | Once it is properly trained (the MEDIQA-AnS paper reports results after 10,000 training steps), run inference on the full text of the web pages with
166 | ```
167 | run_chiqa.sh
168 | ```
169 | The question-driven option can be changed in the bash script. Note that the single-pass decoding in the original Pointer-Generator code is quite slow, and it will unfortunately take approximately 45 minutes per dataset to perform inference.
170 | Other experiments can be run if you have configured the bash script to generate summaries for the passages or multi-document datasets as well.
171 |
172 |
173 | ### Evaluation
174 | Once the models are trained and the baselines have been run on the summarization datasets you are interested in evaluating, activate the qdriven_env environment again and navigate to the evaluation directory. To run the evaluation script on the summarization models' predictions, you have a few options:
175 | For comparing all models on extractive and abstractive summaries of web pages:
176 | ```
177 | python summarization_evaluation.py --dataset=chiqa --bleu --evaluate-models
178 | ```
179 | Or, to compare the two versions of BART (the question-driven approach and the version without questions, if you have trained both), run
180 | ```
181 | python summarization_evaluation.py --dataset=chiqa --bleu --q-driven
182 | ```
183 | The same question-driven test can be applied to the Pointer-Generator as well, if you have trained the appropriate model with the correctly formatted question-driven dataset.
184 |
185 | More details about the options, such as saving scores per summary to file, or calculating Wilcoxon p-values, are described in the script.
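For reference, per-summary scoring and significance testing can be sketched with the same py-rouge evaluator configuration used in this repository plus scipy's Wilcoxon signed-rank test; this is a hedged illustration, not the actual summarization_evaluation.py interface:
```
import rouge
from scipy.stats import wilcoxon

# Mirrors the evaluator configuration used in data_processing/prepare_training_data.py
evaluator = rouge.Rouge(metrics=['rouge-l'], max_n=3, limit_length=False,
                        length_limit_type='words', apply_avg=False, apply_best=True,
                        alpha=1, weight_factor=1.2, stemming=False)

def per_summary_rouge_l(predictions, references):
    """Per-summary ROUGE-L F1 scores, one prediction/reference pair at a time."""
    return [evaluator.get_scores(pred, ref)['rouge-l']['f']
            for pred, ref in zip(predictions, references)]

def compare_systems(preds_a, preds_b, references):
    """Paired Wilcoxon signed-rank test between two systems' per-summary scores."""
    scores_a = per_summary_rouge_l(preds_a, references)
    scores_b = per_summary_rouge_l(preds_b, references)
    return wilcoxon(scores_a, scores_b)
```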
186 |
187 | If you are interested in generating the statistics describing the collection, run
188 | ```python collection_statistics.py --tokenize```
189 | in the evaluation directory. This will generate the statistics reported in the paper with more technical detail.
190 |
191 |
192 | That's it! Thank you for using this code, and please contact us if you find any issues with the repository or have questions about summarization. If you publish work related to this project, please cite
193 | ```
194 | @article{saverysumm,
195 | title={Question-Driven Summarization of Answers to Consumer Health Questions},
196 | author={Max Savery and Asma {Ben Abacha} and Soumya Gayen and Dina Demner{-}Fushman},
197 | journal = {arXiv e-prints},
198 | month = {May},
199 | year={2020},
200 | eprint={2005.09067},
201 | archivePrefix={arXiv},
202 | primaryClass={cs.CL},
203 | url={https://arxiv.org/abs/2005.09067}
204 | }
205 | ```
--------------------------------------------------------------------------------
/data_processing/__pycache__/PubMedClient.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/saverymax/qdriven-chiqa-summarization/257a00133869db47807b9dd10761a6dd3aa15306/data_processing/__pycache__/PubMedClient.cpython-36.pyc
--------------------------------------------------------------------------------
/data_processing/__pycache__/PubMedClient.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/saverymax/qdriven-chiqa-summarization/257a00133869db47807b9dd10761a6dd3aa15306/data_processing/__pycache__/PubMedClient.cpython-37.pyc
--------------------------------------------------------------------------------
/data_processing/prepare_training_data.py:
--------------------------------------------------------------------------------
1 | """
2 | Module for classes to prepare training datasets
3 |
4 | Prepare training data for pointer generator (add <s> and </s> sentence tags to summaries):
5 | Prepare bioasq
6 | python prepare_training_data.py -bt
7 | Or with the question:
8 | python prepare_training_data.py -bt --add-q
9 |
10 | Prepare training data for bart, with or without question appended to beginning of abstract text
11 | python prepare_training_data.py --bart-bioasq --add-q
12 | python prepare_training_data.py --bart-bioasq
13 |
14 | Prepare training data for sentence classification:
15 | python prepare_training_data.py --bioasq-sent
16 | """
17 |
18 | import json
19 | import argparse
20 | import re
21 | import os
22 |
23 | from sklearn.utils import shuffle as sk_shuffle
24 | import rouge
25 | import spacy
26 |
27 |
28 | def get_args():
29 | """
30 | Argument definitions
31 | """
32 | parser = argparse.ArgumentParser(description="Arguments for data exploration")
33 | parser.add_argument("-t",
34 | dest="tag_sentences",
35 | action="store_true",
36 | help="tag the sentences with <s> and </s>, for use with the pointer generator network")
37 | parser.add_argument("-e",
38 | dest="summ_end_tag",
39 | action="store_true",
40 | help="Add the summ end tag to the end of the summaries. This was observed not to improve performance on the MedInfo evaluation set")
41 | parser.add_argument("-b",
42 | dest="bioasq_pg",
43 | action="store_true",
44 | help="Make the bioasq training set for pointer generator")
45 | parser.add_argument("--bioasq-sent",
46 | dest="bioasq_sc",
47 | action="store_true",
48 | help="Make the bioasq training set for sentence classification")
49 | parser.add_argument("--bart-bioasq",
50 | dest="bart_bioasq",
51 | action="store_true",
52 | help="Prepare the bioasq training set for bart")
53 | parser.add_argument("--add-q",
54 | dest="add_q",
55 | action="store_true",
56 | help="Concatenate the question to the beginning of the text. Currently only implemented as an option for bart data and the bioasq abs2summ data")
57 | return parser
58 |
59 |
60 | class BioASQ():
61 | """
62 | Class to create various versions of BioASQ training dataset
63 | """
64 |
65 | def __init__(self):
66 | """
67 | Initiate spacy
68 | """
69 | self.nlp = spacy.load('en_core_web_sm')
70 | self.Q_END = " [QUESTION?] "
71 | self.SUMM_END = " [END]"
72 | self.ARTICLE_END = " [ARTICLE_SEP] "
73 |
74 | def format_summary_sentences(self, summary):
75 | """
76 | Split summary into sentences and add sentence tags to the strings: <s> and </s>
77 | """
78 | tokenized_abs = self.nlp(summary)
79 | summary = " ".join(["<s> {s} </s>".format(s=s.text.strip()) for s in tokenized_abs.sents])
80 | return summary
81 |
82 | def _load_bioasq(self):
83 | """
84 | Load bioasq collection generated in process_bioasq.py
85 | """
86 | with open("data/bioasq_collection.json", "r", encoding="utf8") as f:
87 | data = json.load(f)
88 | return data
89 |
90 | def create_abstract2snippet_dataset(self):
91 | """
92 | Generate the bioasq abstract to snippet (1 to 1) training dataset. This function uses the same keys, summary and articles, for each summary-article pair
93 | as the MedlinePlus training data does. This allows compatibility with the answer summarization data loading function in the pointer generator network.
94 |
95 | This is currently the only dataset with the option to include question. Since this dataset works the best, the add_q option was not added to the others.
96 | """
97 | bioasq_collection = self._load_bioasq()
98 | training_data_dict = {}
99 | snip_id = 0
100 | for i, q in enumerate(bioasq_collection):
101 | question = q
102 | for snippet in bioasq_collection[q]['snippets']:
103 | training_data_dict[snip_id] = {}
104 | if args.summ_end_tag:
105 | snippet_text = snippet['snippet'] + self.SUMM_END
106 | else:
107 | snippet_text = snippet['snippet']
108 | if args.tag_sentences:
109 | snippet_text = self.format_summary_sentences(snippet_text)
110 | training_data_dict[snip_id]['summary'] = snippet_text
111 | # Add the question with a special question separator token to the beginning of the article.
112 | abstract = snippet['article']
113 | with_question = "_without_question"
114 | if args.add_q:
115 | abstract = question + self.Q_END + abstract
116 | with_question = "_with_question"
117 | training_data_dict[snip_id]['articles'] = abstract
118 | training_data_dict[snip_id]['question'] = question
119 | snip_id += 1
120 |
121 | with open("data/bioasq_abs2summ_training_data{}.json".format(with_question), "w", encoding="utf-8") as f:
122 | json.dump(training_data_dict, f, indent=4)
123 |
124 | def calculate_sentence_level_rouge(self, snip_sen, abs_sen, evaluator):
125 | """
126 | For each pair of sentences, calculate rouge score
127 | """
128 | rouge_score = evaluator.get_scores(abs_sen, snip_sen)['rouge-l']['f']
129 | return rouge_score
130 |
131 | def create_binary_sentence_classification_dataset_with_rouge(self):
132 | """
133 | Create a dataset for training a sentence classification model, where the binary y labels are assigned based on the
134 | best rouge score for a sentence in the article when compared to each sentence in the summary
135 | """
136 | # Initiate rouge evaluator
137 | evaluator = rouge.Rouge(metrics=['rouge-l'],
138 | max_n=3,
139 | limit_length=False,
140 | length_limit_type='words',
141 | apply_avg=False,
142 | apply_best=True,
143 | alpha=1,
144 | weight_factor=1.2,
145 | stemming=False)
146 |
147 | bioasq_collection = self._load_bioasq()
148 | training_data_dict = {}
149 | snip_id = 0
150 | for i, q in enumerate(bioasq_collection):
151 | question = q
152 | for snippet in bioasq_collection[q]['snippets']:
153 | training_data_dict[snip_id] = {}
154 | labels = []
155 | # Sentencize snippet
156 | snippet_text = snippet['snippet']
157 | tokenized_snip = self.nlp(snippet_text)
158 | snippet_sentences = [s.text.strip() for s in tokenized_snip.sents]
159 | # Sentencize abstract
160 | abstract_text = snippet['article']
161 | tokenized_abs = self.nlp(abstract_text)
162 | abstract_sentences = [s.text.strip() for s in tokenized_abs.sents]
163 | rouge_scores = []
164 | for abs_sen in abstract_sentences:
165 | best_rouge = 0
166 | for snip_sen in snippet_sentences:
167 | rouge_score = self.calculate_sentence_level_rouge(snip_sen, abs_sen, evaluator)
168 | if best_rouge < rouge_score:
169 | best_rouge = rouge_score
170 | if best_rouge > .9:
171 | label = 1
172 | else:
173 | label = 0
174 | labels.append(label)
175 | training_data_dict[snip_id]['question'] = q
176 | training_data_dict[snip_id]['sentences'] = abstract_sentences
177 | training_data_dict[snip_id]['labels'] = labels
178 | snip_id += 1
179 |
180 | with open("data/bioasq_abs2summ_binary_sent_classification_training.json", "w", encoding="utf-8") as f:
181 | json.dump(training_data_dict, f, indent=4)
182 | # For each sentence in each abstract, compare it to each sentence in answer. Record the best rouge score.
183 |
184 | def create_data_for_bart(self):
185 | """
186 | Write the train and val data to file so that the processor and tokenizer for bart will read it, as per fairseqs design
187 | """
188 | bioasq_collection = self._load_bioasq()
189 | # The question plus a separator token is added to the beginning of the abstract text
190 | if args.add_q:
191 | q_name = "with_question"
192 | else:
193 | q_name = "without_question"
194 |
195 | # Open medinfo data preprocessed in prepare_validation_data.py
196 | with open("data/medinfo_section2answer_validation_data_{}.json".format(q_name), "r", encoding="utf-8") as f:
197 | medinfo_val = json.load(f)
198 |
199 | try:
200 | os.mkdir("../models/bart/bart_config/{}".format(q_name))
201 | except FileExistsError:
202 | print("Directory ", q_name , " already exists")
203 |
204 | train_src = open("../models/bart/bart_config/{q}/bart.train_{q}.source".format(q=q_name), "w", encoding="utf8")
205 | train_tgt = open("../models/bart/bart_config/{q}/bart.train_{q}.target".format(q=q_name), "w", encoding="utf8")
206 | val_src = open("../models/bart/bart_config/{q}/bart.val_{q}.source".format(q=q_name), "w", encoding="utf8")
207 | val_tgt = open("../models/bart/bart_config/{q}/bart.val_{q}.target".format(q=q_name), "w", encoding="utf8")
208 | snippets_list = []
209 | abstracts_list = []
210 | for i, q in enumerate(bioasq_collection):
211 | for snippet in bioasq_collection[q]['snippets']:
212 | snippet_text = snippet['snippet'].strip()
213 | abstract_text = snippet['article'].strip()
214 | # Why is there whitespace in the question?
215 | question = q.replace("\n", " ")
216 | if args.add_q:
217 | abstract_text = question + self.Q_END + abstract_text
218 | abstracts_list.append(abstract_text)
219 | snippets_list.append(snippet_text)
220 |
221 | snp_cnt = 0
222 | print("Shuffling data")
223 | snippets_list, abstracts_list = sk_shuffle(snippets_list, abstracts_list, random_state=13)
224 | for snippet_text, abstract_text in zip(snippets_list, abstracts_list):
225 | snp_cnt += 1
226 | train_src.write("{}\n".format(abstract_text))
227 | train_tgt.write("{}\n".format(snippet_text))
228 |
229 | for q_id in medinfo_val:
230 | # The prepared medinfo data may have sentence tags in it for pointer generator.
231 | # There is an option in the prepare_validation_data.py script to not tag the data,
232 | # but it is easier to keep track of the datasets to just remove the tags here.
233 | summ = medinfo_val[q_id]['summary'].strip()
234 | summ = summ.replace("<s>", "")
235 | summ = summ.replace("</s>", "")
236 | articles = medinfo_val[q_id]['articles'].strip()
237 | val_src.write("{}\n".format(articles))
238 | val_tgt.write("{}\n".format(summ))
239 |
240 | train_src.close()
241 | train_tgt.close()
242 | val_src.close()
243 | val_tgt.close()
244 |
245 | # Make sure there were no funny newlines added
246 | train_src = open("../models/bart/bart_config/{q}/bart.train_{q}.source".format(q=q_name), "r", encoding="utf8").readlines()
247 | train_tgt = open("../models/bart/bart_config/{q}/bart.train_{q}.target".format(q=q_name), "r", encoding="utf8").readlines()
248 | val_src = open("../models/bart/bart_config/{q}/bart.val_{q}.source".format(q=q_name), "r", encoding="utf8").readlines()
249 | val_tgt = open("../models/bart/bart_config/{q}/bart.val_{q}.target".format(q=q_name), "r", encoding="utf8").readlines()
250 | print("Number of snippets: ", snp_cnt)
251 | assert len(train_src) == snp_cnt, len(train_src)
252 | assert len(train_tgt) == snp_cnt
253 | assert len(val_src) == len(medinfo_val)
254 | assert len(val_tgt) == len(medinfo_val)
255 |
256 |
257 | def process_data():
258 | """
259 | Save training data sets
260 | """
261 | if args.bioasq_pg:
262 | BioASQ().create_abstract2snippet_dataset()
263 | if args.bioasq_sc:
264 | BioASQ().create_binary_sentence_classification_dataset_with_rouge()
265 | if args.bart_bioasq:
266 | BioASQ().create_data_for_bart()
267 |
268 | if __name__ == "__main__":
269 | global args
270 | args = get_args().parse_args()
271 | process_data()
272 |
--------------------------------------------------------------------------------
/data_processing/prepare_validation_data.py:
--------------------------------------------------------------------------------
1 | """
2 | Module for classes to prepare validation dataset from MedInfo dataset.
3 | Data format will be {key: {'question': question, 'summary':, summ, 'articles': articles} ...}
4 |
5 | Additionally, format for question driven summarization. For example:
6 | python prepare_validation_data.py -t --add-q
7 | """
8 |
9 |
10 | import json
11 | import argparse
12 | import re
13 |
14 | import spacy
15 | import rouge
16 |
17 | def get_args():
18 | """
19 | Argument definitions
20 | """
21 | parser = argparse.ArgumentParser(description="Arguments for data exploration")
22 | parser.add_argument("--pg",
23 | dest="pg",
24 | action="store_true",
25 | help="tag the sentences with <s> and </s>, for use with pointer generator network")
26 | parser.add_argument("--bart",
27 | dest="bart",
28 | action="store_true",
29 | help="Prepare data for BART")
30 | parser.add_argument("--add-q",
31 | dest="add_q",
32 | action="store_true",
33 | help="Concatenate the question to the beginning of the text for question driven summarization")
34 |
35 | return parser
36 |
37 |
38 | class MedInfo():
39 |
40 | def __init__(self):
41 | """
42 | Initiate class for processing medinfo collection
43 | """
44 | self.nlp = spacy.load('en_core_web_sm')
45 | if args.add_q:
46 | self.q_name = "_with_question"
47 | else:
48 | self.q_name = "_without_question"
49 |
50 | def _load_collection(self):
51 | """
52 | Load medinfo collection prepared in the process_medinfo.py script
53 | """
54 | with open("data/medinfo_collection.json", "r", encoding="utf-8") as f:
55 | medinfo = json.load(f)
56 |
57 | return medinfo
58 |
59 | def _format_summary_sentences(self, summary):
60 | """
61 | Split summary into sentences and add sentence tags to the strings: <s> and </s>
62 | """
63 | tokenized_abs = self.nlp(summary)
64 | summary = " ".join(["<s> {s} </s>".format(s=s.text.strip()) for s in tokenized_abs.sents])
65 | return summary
66 |
67 | def save_section2answer_validation_data(self, tag_sentences):
68 | """
69 | For questions that have a corresponding section-answer pair, save the
70 | validation data in following format
71 | {'question': {'summary': text, 'articles': text}}
72 | """
73 | dev_dict = {}
74 | medinfo = self._load_collection()
75 | data_pair = 0
76 | Q_END = " [QUESTION?] "
77 | for i, question in enumerate(medinfo):
78 | try:
79 | # There may be multiple answers per question, but for the sake of the validation set,
80 | # just use the first answer
81 | if 'section_text' in medinfo[question][0]:
82 | article = medinfo[question][0]['section_text']
83 | summary = medinfo[question][0]['answer']
84 | # Stripping of whitespace was done in processing script for section and full page
85 | # but not for answer or question
86 | summary = re.sub(r"\s+", " ", summary)
87 | question = re.sub(r"\s+", " ", question)
88 | if args.add_q:
89 | article = question + Q_END + article
90 | assert len(summary) <= (len(article) + 10)
91 | if tag_sentences:
92 | summary = self._format_summary_sentences(summary)
93 | tag_string = "_s-tags"
94 | else:
95 | tag_string = ""
96 | data_pair += 1
97 | dev_dict[i] = {'question': question, 'summary': summary, 'articles': article}
98 | except AssertionError:
99 | print("Answer longer than source section. Skipping element")
100 |
101 | print("Number of page-section pairs:", data_pair)
102 |
103 | with open("data/medinfo_section2answer_validation_data{0}{1}.json".format(self.q_name, tag_string), "w", encoding="utf-8") as f:
104 | json.dump(dev_dict, f, indent=4)
105 |
106 |
107 | def process_data():
108 | """
109 | Main function for saving data
110 | """
111 | # Run once for each
112 | if args.pg:
113 | MedInfo().save_section2answer_validation_data(tag_sentences=True)
114 | if args.bart:
115 | MedInfo().save_section2answer_validation_data(tag_sentences=False)
116 |
117 | if __name__ == "__main__":
118 | global args
119 | args = get_args().parse_args()
120 | process_data()
121 |
--------------------------------------------------------------------------------
/data_processing/process_bioasq.py:
--------------------------------------------------------------------------------
1 | """
2 | Script for processing BioASQ json data and saving
3 |
4 | to download the pubmed articles for each snippet run
5 | python process_bioasq.py -d
6 | then to process the questions, answers, and snippets, run:
7 | python process_bioasq.py -p
8 | """
9 |
10 |
11 | import json
12 | import sys
13 | import os
14 | import argparse
15 | import lxml.etree as le
16 | import glob
17 | from collections import Counter
18 |
19 | import numpy as np
20 |
21 |
22 | def get_args():
23 | """
24 | Get command line arguments
25 | """
26 |
27 | parser = argparse.ArgumentParser(description="Arguments for data exploration")
28 | parser.add_argument("-p",
29 | dest="process",
30 | action="store_true",
31 | help="Process bioasq data")
32 | return parser
33 |
34 |
35 | class BioASQ():
36 | """
37 | Class for processing and saving BioASQ data
38 | """
39 |
40 | def _load_bioasq(self):
41 | """
42 | Load bioasq dataset
43 | """
44 | with open("data/BioASQ-training7b/BioASQ-training7b/training7b.json", "r", encoding="ascii") as f:
45 | bioasq_questions = json.load(f)['questions']
46 | return bioasq_questions
47 |
48 | def bioasq(self):
49 | """
50 | Process BioASQ training data. Generate summary stats. Save questions, ideal answers, snippets, articles, and question types.
51 | """
52 | bioasq_questions = self._load_bioasq()
53 | with open("data/bioasq_pubmed_articles.json", "r", encoding="ascii") as f:
54 | articles = json.load(f)
55 | # Dictionary to save condensed json of bioasq
56 | bioasq_collection = {}
57 | questions = []
58 | ideal_answers = []
59 | ideal_answer_dict = {}
60 | exact_answers = []
61 | snippet_dict = {}
62 | for i, q in enumerate(bioasq_questions):
63 | # Get the question
64 | bioasq_collection[q['body']] = {}
65 | questions.append(q['body'])
66 | # Get the references used to answer that question
67 | pmid_list= [d.split("/")[-1] for d in q['documents']]
68 | # Get the question type: list, summary, yes/no, or factoid
69 | q_type = q['type']
70 | bioasq_collection[q['body']]['q_type'] = q_type
71 | # Take the first ideal answer
72 | assert isinstance(q['ideal_answer'], list)
73 | assert isinstance(q['ideal_answer'][0], str)
74 | ideal_answer_dict[i] = q['ideal_answer'][0]
75 | bioasq_collection[q['body']]['ideal_answer'] = q['ideal_answer'][0]
76 | # And get the first exact answer
77 | if q_type != "summary":
78 | # Yesno questions will have just a yes/no string in exact answer.
79 | if q_type == "yesno":
80 | exact_answers.append(q['exact_answer'][0])
81 | bioasq_collection[q['body']]['exact_answer'] = q['exact_answer'][0]
82 | else:
83 | if isinstance(q['exact_answer'], str):
84 | exact_answers.append(q['exact_answer'])
85 | bioasq_collection[q['body']]['exact_answer'] = q['exact_answer']
86 | else:
87 | exact_answers.append(q['exact_answer'][0])
88 | bioasq_collection[q['body']]['exact_answer'] = q['exact_answer'][0]
89 | # Then handle the snippets (the text extracted from the abstract)
90 | bioasq_collection[q['body']]['snippets'] = []
91 | snippet_dict[q['body']] = []
92 | for snippet in q['snippets']:
93 | pmid_match = False
94 | snippet_dict[q['body']].append(snippet['text'])
95 | doc_pmid = str(snippet['document'].split("/")[-1])
96 | try:
97 | article = articles[doc_pmid]
98 | # Add the data to the dictionary containing the collection.
99 | bioasq_collection[q['body']]['snippets'].append({'snippet': snippet['text'], 'article': article, 'pmid': doc_pmid})
100 | except KeyError as e:
101 | continue
102 |
103 | with open("data/bioasq_ideal_answers.json", "w", encoding="utf8") as f:
104 | json.dump(ideal_answer_dict, f, indent=4)
105 | with open("data/bioasq_snippets.json", "w", encoding="utf8") as f:
106 | json.dump(snippet_dict, f, indent=4)
107 | with open("data/bioasq_collection.json", "w", encoding="utf8") as f:
108 | json.dump(bioasq_collection, f, indent=4)
109 |
110 |
111 | def process_bioasq():
112 | """
113 | Main processing function for bioasq data
114 | """
115 | bq = BioASQ()
116 | if args.process:
117 | bq.bioasq()
118 |
119 | if __name__ == "__main__":
120 | global args
121 | args = get_args().parse_args()
122 | process_bioasq()
123 |
--------------------------------------------------------------------------------
/evaluation/collection_statistics.py:
--------------------------------------------------------------------------------
1 | """
2 | Script to compute statistics on collection and validate the data integrity
3 |
4 | To run with counts of sentences and tokens:
5 | python collection_statistics.py --tokenize
6 | Otherwise don't include --tokenize option
7 | """
8 |
9 | import argparse
10 | import spacy
11 | import json
12 | import numpy as np
13 | from collections import Counter
14 |
15 |
16 | def get_args():
17 | """
18 | Argument parser for preparing chiqa data
19 | """
20 | parser = argparse.ArgumentParser(description="Arguments for data exploration")
21 | parser.add_argument("--tokenize",
22 | dest="tokenize",
23 | action="store_true",
24 | help="Tokenize by words and sentences, counting averages/sd for each.")
25 | return parser
26 |
27 |
28 | class SummarizationDataStats():
29 | """
30 | Class for validating annotated collection
31 | """
32 |
33 | def __init__(self):
34 | """
35 | Init spacy and counting variables.
36 | """
37 | self.nlp = spacy.load('en_core_web_sm')
38 | self.summary_file = open("results/question_driven_answer_summ_collection_stats.txt", "w", encoding="utf-8")
39 |
40 | def load_data(self, dataset, dataset_name):
41 | """
42 | Given the path of the dataset, load it!
43 | """
44 | with open(dataset, "r", encoding="utf-8") as f:
45 | self.data = json.load(f)
46 | self.dataset_name = dataset_name
47 |
48 | def _get_token_cnts(self, doc, doc_type):
49 | """
50 | Count tokens and sentences in documents
51 | """
52 | tokenized_doc = self.nlp(doc)
53 | self.stat_dict[doc_type][0].append(len([s for s in tokenized_doc.sents]))
54 | doc_len = len([t for t in tokenized_doc])
55 | self.stat_dict[doc_type][1].append(doc_len)
56 | if doc_len < 50 and doc_type == "answer":
57 | print("Document less than 50 tokens")
58 |
59 | def iterate_data(self):
60 | """
61 | Count the number of examples in each dataset
62 | """
63 | if "single" in self.dataset_name:
64 | # Index 0 for list of sentence lengths, index 1 for list of token lengths
65 | self.stat_dict = {'question': [[], []], 'summary': [[], []], 'article': [[], []]}
66 | for answer_id in self.data:
67 | summary = self.data[answer_id]['summary']
68 | articles = self.data[answer_id]['articles']
69 | question = self.data[answer_id]['question']
70 | if args.tokenize:
71 | self._get_token_cnts(summary, 'summary')
72 | self._get_token_cnts(articles, 'article')
73 | self._get_token_cnts(question, 'question')
74 | self._write_stats("token_counts")
75 |
76 | if "multi" in self.dataset_name:
77 | self.stat_dict = {'question': [[], []], 'summary': [[], []], 'article': [[], []]}
78 | for q_id in self.data:
79 | summary = self.data[q_id]['summary']
80 | question = self.data[q_id]['question']
81 | if args.tokenize:
82 | self._get_token_cnts(summary, 'summary')
83 | self._get_token_cnts(question, 'question')
84 | question = self.data[q_id]['question']
85 | for answer_id in self.data[q_id]['articles']:
86 | articles = self.data[q_id]['articles'][answer_id][0]
87 | if args.tokenize:
88 | self._get_token_cnts(articles, 'article')
89 | self._write_stats("token_counts")
90 |
91 | if self.dataset_name == "complete_dataset":
92 | self.stat_dict = {'urls': [], 'sites': []}
93 | article_dict = {}
94 | print("Counting answers, sites, unique urls, and tokenized counts of unique articles")
95 | answer_cnt = 0
96 | for q_id in self.data:
97 | for a_id in self.data[q_id]['answers']:
98 | answer_cnt += 1
99 | url = self.data[q_id]['answers'][a_id]['url']
100 | article = self.data[q_id]['answers'][a_id]['article']
101 | if url not in article_dict:
102 | article_dict[url] = article
103 | self.stat_dict['urls'].append(url)
104 | assert "//" in url, url
105 | site = url.split("//")[1].split("/")
106 | self.stat_dict['sites'].append(site[0])
107 | print("# of Answers:", answer_cnt)
108 | print("Unique articles: ", len(article_dict)) # This should match up with count written to file
109 | self._write_stats("full collection")
110 |
111 | # Get token/sent averages of unique articles
112 | if args.tokenize:
113 | self.stat_dict = {'article': [[], []]}
114 | for a in article_dict:
115 | self._get_token_cnts(article_dict[a], 'article')
116 | self._write_stats("token_counts")
117 |
118 | def _write_stats(self, stat_type, user=None, summ_type=None):
119 | """
120 | Write chiqa page stats to the summary file
121 | """
122 | if stat_type == "full collection":
123 | self.summary_file.write("\n\nDataset: {c}\n".format(c=self.dataset_name))
124 | self.summary_file.write("Number of unique urls: {u}\nNumber of unique sites: {s}\n".format(u=len(set(self.stat_dict['urls'])), s=len(set(self.stat_dict['sites'])))
125 | )
126 | site_cnts = Counter(self.stat_dict['sites']).most_common()
127 | for site in site_cnts:
128 | self.summary_file.write("{s}: {n}\n".format(s=site[0], n=site[1]))
129 |
130 | if stat_type == "token_counts":
131 | self.summary_file.write("\n\nDataset: {c}\n".format(c=self.dataset_name))
132 | for doc_type in self.stat_dict:
133 | if user is not None:
134 | self.summary_file.write("\n{0}, {1}\n".format(user, summ_type))
135 |
136 | self.summary_file.write(
137 | "\nNumber of {d}s: {p}\nAverage tokens/{d}: {t}\nAverage sentences/{d}: {s}\n".format(
138 | d=doc_type, p=len(self.stat_dict[doc_type][0]), t=sum(self.stat_dict[doc_type][1])/len(self.stat_dict[doc_type][1]), s=sum(self.stat_dict[doc_type][0])/len(self.stat_dict[doc_type][0])
139 | )
140 | )
141 |
142 | self.summary_file.write(
143 | "Median tokens/{d}: {p}\nStandard deviation tokens/{d}: {t}\n".format(
144 | d=doc_type, p=np.median(self.stat_dict[doc_type][1]), t=np.std(self.stat_dict[doc_type][1])
145 | )
146 | )
147 |
148 | self.summary_file.write(
149 | "Median sentences/{d}: {p}\nStandard deviation sentences/{d}: {t}\n".format(
150 | d=doc_type, p=np.median(self.stat_dict[doc_type][0]), t=np.std(self.stat_dict[doc_type][0])
151 | )
152 | )
153 |
154 |
155 | def get_stats():
156 | """
157 | Main function for getting CHiQA collection stats
158 | """
159 | datasets = [
160 | ("../data_processing/data/page2answer_single_abstractive_summ.json", "p2a-single-abs"),
161 | ("../data_processing/data/page2answer_single_extractive_summ.json", "p2a-single-ext"),
162 | ("../data_processing/data/section2answer_multi_abstractive_summ.json", "s2a-multi-abs"),
163 | ("../data_processing/data/page2answer_multi_extractive_summ.json", "p2a-multi-ext"),
164 | ("../data_processing/data/section2answer_single_abstractive_summ.json", "s2a-single-abs"),
165 | ("../data_processing/data/section2answer_single_extractive_summ.json", "s2a-single-ext"),
166 | ("../data_processing/data/section2answer_multi_extractive_summ.json", "s2a-multi-ext"),
167 | ("../data_processing/data/question_driven_answer_summarization_primary_dataset.json", "complete_dataset"),
168 | ]
169 |
170 | stats = SummarizationDataStats()
171 | for dataset in datasets:
172 | print(dataset[1])
173 | stats.load_data(dataset[0], dataset[1])
174 | stats.iterate_data()
175 |
176 |
177 | if __name__ == "__main__":
178 | global args
179 | args = get_args().parse_args()
180 | get_stats()
181 |
--------------------------------------------------------------------------------
/evaluation/data/bart/chiqa_eval/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
3 |
--------------------------------------------------------------------------------
/evaluation/data/baselines/chiqa_eval/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
3 |
--------------------------------------------------------------------------------
/evaluation/data/pointer_generator/chiqa_eval/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
3 |
--------------------------------------------------------------------------------
/evaluation/data/sentence_classifier/chiqa_eval/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
3 |
--------------------------------------------------------------------------------
/evaluation/results/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
3 |
--------------------------------------------------------------------------------
/models/bart/bart_config/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
3 |
--------------------------------------------------------------------------------
/models/bart/finetune_bart_bioasq.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | TOTAL_NUM_UPDATES=20000
4 | WARMUP_UPDATES=500
5 | LR=3e-05
6 | MAX_TOKENS=1024
7 | UPDATE_FREQ=16
8 | BART_PATH=bart/bart.large/model.pt
9 | checkpoint_path=checkpoints_bioasq_$1
10 | asumm_data=bart_config/${1}/bart-bin
11 |
12 | #CUDA_VISIBLE_DEVICES=0 python train.py ${asumm_data} \
13 | CUDA_VISIBLE_DEVICES=0 python fairseq/fairseq_cli/train.py ${asumm_data} \
14 | --restore-file $BART_PATH \
15 | --max-tokens $MAX_TOKENS \
16 | --truncate-source \
17 | --task translation \
18 | --source-lang source --target-lang target \
19 | --layernorm-embedding \
20 | --share-all-embeddings \
21 | --share-decoder-input-output-embed \
22 | --reset-optimizer --reset-dataloader --reset-meters \
23 | --required-batch-size-multiple 1 \
24 | --arch bart_large \
25 | --criterion label_smoothed_cross_entropy \
26 | --label-smoothing 0.1 \
27 | --dropout 0.1 --attention-dropout 0.1 \
28 | --weight-decay 0.01 --optimizer adam --adam-betas "(0.9, 0.999)" --adam-eps 1e-08 \
29 | --clip-norm 0.1 \
30 | --lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \
31 | --update-freq $UPDATE_FREQ \
32 | --skip-invalid-size-inputs-valid-test \
33 | --fp16 \
34 | --ddp-backend=no_c10d \
35 | --save-dir=${checkpoint_path} \
36 | --keep-last-epochs=2 \
37 | --find-unused-parameters;
38 |
--------------------------------------------------------------------------------
/models/bart/make_datafiles_for_bart.py:
--------------------------------------------------------------------------------
1 | """
2 | This is the make_datafiles script from the pointer-generator cnn-dailymail repo
3 | https://github.com/abisee/cnn-dailymail/blob/b15ad0a2db0d407a84b8ca9b5731e1f1c4bd24b9/make_datafiles.py#L235
4 | but with modifications made here
5 | https://gist.github.com/zhaoguangxiang/45bf39c528cf7fb7853bffba7fe57c7e
6 | The script now saves the cnn dailymail files as test.source and test.target, etc
7 |
8 | It requires python 2 to run. Run the activate_py2_env.sh to activate.
9 | """
10 |
11 | import sys
12 | import os
13 | import hashlib
14 | import struct
15 | import subprocess
16 | import collections
17 | import codecs
18 | # import tensorflow as tf
19 | # from tensorflow.core.example import example_pb2
20 | # import sys
21 | # reload(sys)
22 | # sys.setdefaultencoding('utf8')
23 | # cnt_r=0
24 | dm_single_close_quote = u'\u2019' # unicode
25 | dm_double_close_quote = u'\u201d'
26 | END_TOKENS = ['.', '!', '?', '...', "'", "`", '"', dm_single_close_quote, dm_double_close_quote, ")"] # acceptable ways to end a sentence
27 | # We use these to separate the summary sentences in the .bin datafiles
28 | SENTENCE_START = '<s>'
29 | SENTENCE_END = '</s>'
30 |
31 | all_train_urls = "/data/saveryme/asumm/asumm_data/cnn-dailymail/url_lists/all_train.txt"
32 | all_val_urls = "/data/saveryme/asumm/asumm_data/cnn-dailymail/url_lists/all_val.txt"
33 | all_test_urls = "/data/saveryme/asumm/asumm_data/cnn-dailymail/url_lists/all_test.txt"
34 |
35 | # cnn_tokenized_stories_dir = "cnn_stories_tokenized"
36 | # dm_tokenized_stories_dir = "dm_stories_tokenized"
37 | cnn_tokenized_stories_dir = "/data/saveryme/asumm/asumm_data/cnn_dm/cnn/stories"
38 | dm_tokenized_stories_dir = "/data/saveryme/asumm/asumm_data/cnn_dm/dailymail/stories"
39 | # finished_files_dir = "finished_files"
40 | finished_files_dir = "cnn_dm_finished_files"
41 | # chunks_dir = os.path.join(finished_files_dir, "chunked")
42 |
43 | # These are the number of .story files we expect there to be in cnn_stories_dir and dm_stories_dir
44 | num_expected_cnn_stories = 92579
45 | num_expected_dm_stories = 219506
46 |
47 | VOCAB_SIZE = 200000
48 | CHUNK_SIZE = 1000 # num examples per chunk, for the chunked data
49 |
50 |
51 |
52 | def chunk_file(set_name):
53 | in_file = 'finished_files/%s.bin' % set_name
54 | reader = open(in_file, "rb")
55 | chunk = 0
56 | finished = False
57 | while not finished:
58 | chunk_fname = os.path.join(chunks_dir, '%s_%03d.bin' % (set_name, chunk)) # new chunk
59 | with open(chunk_fname, 'wb') as writer:
60 | for _ in range(CHUNK_SIZE):
61 | len_bytes = reader.read(8)
62 | if not len_bytes:
63 | finished = True
64 | break
65 | str_len = struct.unpack('q', len_bytes)[0]
66 | example_str = struct.unpack('%ds' % str_len, reader.read(str_len))[0]
67 | writer.write(struct.pack('q', str_len))
68 | writer.write(struct.pack('%ds' % str_len, example_str))
69 | chunk += 1
70 |
71 |
72 | def chunk_all():
73 | # Make a dir to hold the chunks
74 | if not os.path.isdir(chunks_dir):
75 | os.mkdir(chunks_dir)
76 | # Chunk the data
77 | for set_name in ['train', 'val', 'test']:
78 | print "Splitting %s data into chunks..." % set_name
79 | chunk_file(set_name)
80 | print "Saved chunked data in %s" % chunks_dir
81 |
82 |
83 | def tokenize_stories(stories_dir, tokenized_stories_dir):
84 | """Maps a whole directory of .story files to a tokenized version using Stanford CoreNLP Tokenizer"""
85 | print "Preparing to tokenize %s to %s..." % (stories_dir, tokenized_stories_dir)
86 | stories = os.listdir(stories_dir)
87 | # make IO list file
88 | print "Making list of files to tokenize..."
89 | with open("mapping.txt", "w") as f:
90 | for s in stories:
91 | f.write("%s \t %s\n" % (os.path.join(stories_dir, s), os.path.join(tokenized_stories_dir, s)))
92 | command = ['java', 'edu.stanford.nlp.process.PTBTokenizer', '-ioFileList', '-preserveLines', 'mapping.txt']
93 | print "Tokenizing %i files in %s and saving in %s..." % (len(stories), stories_dir, tokenized_stories_dir)
94 | subprocess.call(command)
95 | os.remove("mapping.txt")
96 |
97 | # Check that the tokenized stories directory contains the same number of files as the original directory
98 | num_orig = len(os.listdir(stories_dir))
99 | num_tokenized = len(os.listdir(tokenized_stories_dir))
100 | if num_orig != num_tokenized:
101 | raise Exception("The tokenized stories directory %s contains %i files, but it should contain the same number as %s (which has %i files). Was there an error during tokenization?" % (tokenized_stories_dir, num_tokenized, stories_dir, num_orig))
102 | print "Successfully finished tokenizing %s to %s.\n" % (stories_dir, tokenized_stories_dir)
103 |
104 |
105 | def read_text_file(text_file):
106 | lines = []
107 | with open(text_file, "r") as f:
108 | for line in f:
109 | lines.append(line.strip())
110 | return lines
111 |
112 |
113 | def hashhex(s):
114 | """Returns a heximal formated SHA1 hash of the input string."""
115 | h = hashlib.sha1()
116 | h.update(s)
117 | return h.hexdigest()
118 |
119 |
120 | def get_url_hashes(url_list):
121 | return [hashhex(url) for url in url_list]
122 |
123 |
124 | def fix_missing_period(line):
125 | """Adds a period to a line that is missing a period"""
126 | if "@highlight" in line: return line
127 | if line=="": return line
128 | if line[-1] in END_TOKENS: return line
129 | return line + "."
130 |
131 |
132 | def get_art_abs(story_file):
133 | lines = read_text_file(story_file)
134 |
135 | # Lowercase everything
136 | # lines = [line.lower() for line in lines]
137 |
138 | # Put periods on the ends of lines that are missing them (this is a problem in the dataset because many image captions don't end in periods; consequently they end up in the body of the article as run-on sentences)
139 | lines = [fix_missing_period(line) for line in lines]
140 |
141 | # Separate out article and abstract sentences
142 | article_lines = []
143 | highlights = []
144 | next_is_highlight = False
145 | for idx,line in enumerate(lines):
146 | if line == "":
147 | continue # empty line
148 | elif line.startswith("@highlight"):
149 | next_is_highlight = True
150 | elif next_is_highlight:
151 | highlights.append(line)
152 | else:
153 | article_lines.append(line)
154 |
155 | # Make article into a single string
156 | article = ' '.join(article_lines)
157 |
158 | # Make abstract into a single string, putting <s> and </s> tags around the sentences
159 | # abstract = ' '.join(["%s %s %s" % (SENTENCE_START, sent, SENTENCE_END) for sent in highlights])
160 | abstract = ' '.join([" %s " % (sent) for sent in highlights])
161 |
162 | return article, abstract
163 |
164 |
165 | def write_to_bin(url_file, out_file, makevocab=False):
166 | """Reads the tokenized .story files corresponding to the urls listed in the url_file and writes them to a out_file."""
167 | print "Making bin file for URLs listed in %s..." % url_file
168 | url_list = read_text_file(url_file)
169 | url_hashes = get_url_hashes(url_list)
170 | story_fnames = [s+".story" for s in url_hashes]
171 | num_stories = len(story_fnames)
172 |
173 | if makevocab:
174 | vocab_counter = collections.Counter()
175 | cnt_r=0
176 | cnt_n=0
177 | with open(out_file, 'wb') as writer:
178 | for idx,s in enumerate(story_fnames):
179 | if idx % 1000 == 0:
180 | print "Writing story %i of %i; %.2f percent done" % (idx, num_stories, float(idx)*100.0/float(num_stories))
181 |
182 | # Look in the tokenized story dirs to find the .story file corresponding to this url
183 | if os.path.isfile(os.path.join(cnn_tokenized_stories_dir, s)):
184 | story_file = os.path.join(cnn_tokenized_stories_dir, s)
185 | elif os.path.isfile(os.path.join(dm_tokenized_stories_dir, s)):
186 | story_file = os.path.join(dm_tokenized_stories_dir, s)
187 | else:
188 | print "Error: Couldn't find tokenized story file %s in either tokenized story directories %s and %s. Was there an error during tokenization?" % (s, cnn_tokenized_stories_dir, dm_tokenized_stories_dir)
189 | # Check again if tokenized stories directories contain correct number of files
190 | print "Checking that the tokenized stories directories %s and %s contain correct number of files..." % (cnn_tokenized_stories_dir, dm_tokenized_stories_dir)
191 | check_num_stories(cnn_tokenized_stories_dir, num_expected_cnn_stories)
192 | check_num_stories(dm_tokenized_stories_dir, num_expected_dm_stories)
193 | raise Exception("Tokenized stories directories %s and %s contain correct number of files but story file %s found in neither." % (cnn_tokenized_stories_dir, dm_tokenized_stories_dir, s))
194 |
195 | # Get the strings to write to .bin file
196 | article, abstract = get_art_abs(story_file)
197 | if os.path.isfile(os.path.join(cnn_tokenized_stories_dir, s)) and article[:5] == '(CNN)':
198 | article = article[5:]
199 |
200 |             article = article.decode('utf-8', 'ignore').encode('utf-8')
201 |             if '\r' in article:
202 |                 cnt_r += 1
203 |                 article = article.replace("\r", " ")
204 |                 print "found an article containing a carriage return"
205 |                 print cnt_r
206 |             if '\n' in article:
207 |                 cnt_n += 1
208 |                 article = article.replace("\n", " ")
209 |                 print "found an article containing a newline"
210 |                 print cnt_n
211 | article = ' '.join(article.split())
212 | with open(out_file+'.source', mode='a+') as src:
213 | src.write(article+'\n')
214 | with codecs.open(out_file+'.target', mode='a+') as tgt:
215 | tgt.write(abstract+'\n')
216 | # Write to tf.Example
217 | # tf_example = example_pb2.Example()
218 | # tf_example.features.feature['article'].bytes_list.value.extend([article])
219 | # tf_example.features.feature['abstract'].bytes_list.value.extend([abstract])
220 | # tf_example_str = tf_example.SerializeToString()
221 | # str_len = len(tf_example_str)
222 | # writer.write(struct.pack('q', str_len))
223 | # writer.write(struct.pack('%ds' % str_len, tf_example_str))
224 |
225 | # Write the vocab to file, if applicable
226 | if makevocab:
227 | art_tokens = article.split(' ')
228 | abs_tokens = abstract.split(' ')
229 | abs_tokens = [t for t in abs_tokens if t not in [SENTENCE_START, SENTENCE_END]] # remove these tags from vocab
230 | tokens = art_tokens + abs_tokens
231 | tokens = [t.strip() for t in tokens] # strip
232 | tokens = [t for t in tokens if t!=""] # remove empty
233 | vocab_counter.update(tokens)
234 |
235 | print "Finished writing file %s\n" % out_file
236 |
237 | # write vocab to file
238 | if makevocab:
239 | print "Writing vocab file..."
240 | with open(os.path.join(finished_files_dir, "vocab"), 'w') as writer:
241 | for word, count in vocab_counter.most_common(VOCAB_SIZE):
242 | writer.write(word + ' ' + str(count) + '\n')
243 | print "Finished writing vocab file"
244 |
245 |
246 | def check_num_stories(stories_dir, num_expected):
247 | num_stories = len(os.listdir(stories_dir))
248 | if num_stories != num_expected:
249 | raise Exception("stories directory %s contains %i files but should contain %i" % (stories_dir, num_stories, num_expected))
250 |
251 |
252 | if __name__ == '__main__':
253 | # if len(sys.argv) != 3:
254 | # print "USAGE: python make_datafiles.py "
255 | # sys.exit()
256 | # cnn_stories_dir = sys.argv[1]
257 | # dm_stories_dir = sys.argv[2]
258 |
259 |
260 | # Check the stories directories contain the correct number of .story files
261 | # check_num_stories(cnn_stories_dir, num_expected_cnn_stories)
262 | # check_num_stories(dm_stories_dir, num_expected_dm_stories)
263 |
264 | # Create some new directories
265 | # if not os.path.exists(cnn_tokenized_stories_dir): os.makedirs(cnn_tokenized_stories_dir)
266 | # if not os.path.exists(dm_tokenized_stories_dir): os.makedirs(dm_tokenized_stories_dir)
267 | if not os.path.exists(finished_files_dir): os.makedirs(finished_files_dir)
268 |
269 | # Run stanford tokenizer on both stories dirs, outputting to tokenized stories directories
270 | # tokenize_stories(cnn_stories_dir, cnn_tokenized_stories_dir)
271 | # tokenize_stories(dm_stories_dir, dm_tokenized_stories_dir)
272 |
273 | # Read the tokenized stories, do a little postprocessing then write to bin files
274 | # cnt_r=0
275 | write_to_bin(all_test_urls, os.path.join(finished_files_dir, "test"))
276 | write_to_bin(all_val_urls, os.path.join(finished_files_dir, "val"))
277 | write_to_bin(all_train_urls, os.path.join(finished_files_dir, "train"), makevocab=True)
278 | # write_to_bin(all_train_urls, os.path.join(finished_files_dir, "train.bin"))
279 |
280 | # Chunk the data. This splits each of train.bin, val.bin and test.bin into smaller chunks, each containing e.g. 1000 examples, and saves them in finished_files/chunks
281 | # chunk_all()
282 |
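283 | # Illustrative sketch of the .story format this script expects (the CNN/DM convention
284 | # that get_art_abs() parses): article sentences, then "@highlight" lines, each of which
285 | # marks the following line as a summary sentence.
286 | #
287 | #   First article sentence .
288 | #   Second article sentence .
289 | #   @highlight
290 | #   A reference summary sentence .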
--------------------------------------------------------------------------------
/models/bart/process_bioasq_data.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Process bioasq data generated in the prepare_training_data.py script
3 | # -b as first argument for byte pair encoding
4 | # -f as second argument for the rest of the fairseq processing
5 | # Specify with_question or without_question as third argument
6 |
7 | wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json'
8 | wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe'
9 | wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt'
10 |
11 | echo "Args: $@"
12 | asumm_data=bart_config/${3}
13 | echo "Data dir: $asumm_data"
14 | if [ "$1" == "-b" ]
15 | then
16 | echo "Running byte-pair encoding"
17 | for SPLIT in train val
18 | do
19 | for LANG in source target
20 | do
21 | python -m examples.roberta.multiprocessing_bpe_encoder \
22 | --encoder-json encoder.json \
23 | --vocab-bpe vocab.bpe \
24 | --inputs ${asumm_data}/bart.${SPLIT}_${3}.$LANG \
25 | --outputs ${asumm_data}/bart.${SPLIT}_${3}.bpe.$LANG \
26 | --workers 60 \
27 | --keep-empty;
28 | done
29 | done
30 | fi
31 |
32 | if [ "$2" == "-f" ]
33 | then
34 | echo "Running fairseq processing"
35 | fairseq-preprocess \
36 | --source-lang "source" \
37 | --target-lang "target" \
38 | --trainpref ${asumm_data}/bart.train_${3}.bpe \
39 | --validpref ${asumm_data}/bart.val_${3}.bpe \
40 | --destdir ${asumm_data}/bart-bin \
41 | --workers 60 \
42 | --srcdict dict.txt \
43 | --tgtdict dict.txt;
44 | fi
45 |
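46 | # Example invocation (illustrative; assumes fairseq is installed and the bart.* split files
47 | # referenced above already exist under bart_config/with_question):
48 | #   bash process_bioasq_data.sh -b -f with_question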
--------------------------------------------------------------------------------
/models/bart/run_chiqa.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Include with_question or without_question when calling script
3 | model_config=bart_config/${1}/bart-bin
4 | model_path=checkpoints_bioasq_$1
5 | for summ_task in page2answer section2answer
6 | do
7 | for summ_type in single_abstractive single_extractive
8 | do
9 | data=${summ_task}_${summ_type}_summ.json
10 | input_file=../../data_processing/data/${data}
11 | prediction_file=bart_chiqa_${1}_${summ_task}_${summ_type}.json
12 | prediction_path=../../evaluation/data/bart/chiqa_eval/${prediction_file}
13 | echo $input_file
14 | echo $prediction_path
15 | python run_inference_medsumm.py \
16 | --input_file=$input_file \
17 | --question_driven=$1 \
18 | --prediction_file=$prediction_path \
19 | --model_path=$model_path \
20 | --model_config=$model_config
21 | done
22 | done
23 |
--------------------------------------------------------------------------------
/models/bart/run_inference_medsumm.py:
--------------------------------------------------------------------------------
1 | """
2 | Script modified from fairseq example CNN-dm script, for performing
3 | summarization inference on input articles.
4 | """
5 |
6 | import argparse
7 | import json
8 |
9 | from tqdm import tqdm
10 | import torch
11 | from fairseq.models.bart import BARTModel
12 |
13 |
14 | def get_args():
15 | """
16 |     Argument definitions
17 | """
18 |     parser = argparse.ArgumentParser(description="Arguments for BART summarization inference")
19 | parser.add_argument("--prediction_file",
20 | dest="prediction_file",
21 | help="File to save predictions")
22 | parser.add_argument("--input_file",
23 | dest="input_file",
24 | help="File with text to summarize")
25 | parser.add_argument("--question_driven",
26 | dest="question_driven",
27 | help="Whether to add question to beginning of article for question-driven summarization")
28 | parser.add_argument("--model_path",
29 | dest="model_path",
30 | help="Path to model checkpoints")
31 | parser.add_argument("--model_config",
32 | dest="model_config",
33 | help="Path to model vocab")
34 | parser.add_argument("--batch_size",
35 | dest="batch_size",
36 |                         default=32, type=int,
37 | help="Batch size for inference")
38 | return parser
39 |
40 |
41 | def run_inference():
42 | """
43 | Main function for running inference on given input text
44 | """
45 | bart = BARTModel.from_pretrained(
46 | args.model_path,
47 | checkpoint_file='checkpoint_best.pt',
48 | data_name_or_path=args.model_config
49 | )
50 |
51 | bart.cuda()
52 | bart.eval()
53 | bart.half()
54 | questions = []
55 | ref_summaries = []
56 | gen_summaries = []
57 | articles = []
58 | QUESTION_END = " [QUESTION?] "
59 | with open(args.input_file, 'r', encoding="utf-8") as f:
60 | source = json.load(f)
61 | batch_cnt = 0
62 |
63 | for q in tqdm(source):
64 | question = source[q]['question']
65 | questions.append(question)
66 | # The data here may be prepared for the pointer generator, and it is currently easier to
67 | # clean the sentence tags out here, as opposed to making tagged and nontagged datasets.
68 | ref_summary = source[q]['summary']
69 |         if "<s>" in ref_summary:
70 |             ref_summary = ref_summary.replace("<s>", "")
71 |             ref_summary = ref_summary.replace("</s>", "")
72 | ref_summaries.append(ref_summary)
73 | article = source[q]['articles']
74 | if args.question_driven == "with_question":
75 | article = question + QUESTION_END + article
76 | articles.append(article)
77 | # Once the article list fills up, run a batch
78 | if len(articles) == args.batch_size:
79 | batch_cnt += 1
80 | print("Running batch {}".format(batch_cnt))
81 | # Hyperparameters as recommended here: https://github.com/pytorch/fairseq/issues/1364
82 | with torch.no_grad():
83 | predictions = bart.sample(articles, beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3)
84 | for pred in predictions:
85 | #print(pred)
86 | gen_summaries.append(pred)
87 | articles = []
88 | print("Done with batch {}".format(batch_cnt))
89 |
90 | if len(articles) != 0:
91 | predictions = bart.sample(articles, beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3)
92 | for pred in predictions:
93 | print(pred)
94 | gen_summaries.append(pred)
95 |
96 | assert len(gen_summaries) == len(ref_summaries)
97 | prediction_dict = {
98 | 'question': questions,
99 | 'ref_summary': ref_summaries,
100 | 'gen_summary': gen_summaries
101 | }
102 |
103 | with open(args.prediction_file, "w", encoding="utf-8") as f:
104 | json.dump(prediction_dict, f, indent=4)
105 |
106 |
107 | if __name__ == "__main__":
108 | global args
109 | args = get_args().parse_args()
110 | run_inference()
111 |
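112 | # Illustrative sketch of the input JSON layout this script reads (keys inferred from the
113 | # access pattern above; the files themselves are produced by the data_processing scripts):
114 | #   { "<question_id>": { "question": "...", "summary": "...", "articles": "..." }, ... }
115 | # For concrete invocations see run_chiqa.sh in this directory.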
--------------------------------------------------------------------------------
/models/bart/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3 -u
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | """
7 | Train a new model on one or across multiple GPUs.
8 | """
9 |
10 | import collections
11 | import math
12 | import random
13 |
14 | import numpy as np
15 | import torch
16 |
17 | from fairseq import checkpoint_utils, distributed_utils, options, progress_bar, tasks, utils
18 | from fairseq.data import iterators
19 | from fairseq.trainer import Trainer
20 | from fairseq.meters import AverageMeter, StopwatchMeter
21 |
22 |
23 | def main(args, init_distributed=False):
24 | utils.import_user_module(args)
25 |
26 | assert args.max_tokens is not None or args.max_sentences is not None, \
27 | 'Must specify batch size either with --max-tokens or --max-sentences'
28 |
29 | # Initialize CUDA and distributed training
30 | if torch.cuda.is_available() and not args.cpu:
31 | torch.cuda.set_device(args.device_id)
32 | np.random.seed(args.seed)
33 | torch.manual_seed(args.seed)
34 | if init_distributed:
35 | args.distributed_rank = distributed_utils.distributed_init(args)
36 |
37 | if distributed_utils.is_master(args):
38 | checkpoint_utils.verify_checkpoint_directory(args.save_dir)
39 |
40 | # Print args
41 | print(args)
42 |
43 | # Setup task, e.g., translation, language modeling, etc.
44 | task = tasks.setup_task(args)
45 |
46 | # Load valid dataset (we load training data below, based on the latest checkpoint)
47 | for valid_sub_split in args.valid_subset.split(','):
48 | task.load_dataset(valid_sub_split, combine=False, epoch=0)
49 |
50 | # Build model and criterion
51 | model = task.build_model(args)
52 | criterion = task.build_criterion(args)
53 | #print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
54 | #print('| num. model params: {} (num. trained: {})'.format(
55 | #sum(p.numel() for p in model.parameters()),
56 | #sum(p.numel() for p in model.parameters() if p.requires_grad),
57 | #))
58 |
59 | # Build trainer
60 | trainer = Trainer(args, task, model, criterion)
61 | print('| training on {} GPUs'.format(args.distributed_world_size))
62 | print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
63 | args.max_tokens,
64 | args.max_sentences,
65 | ))
66 |
67 | # Load the latest checkpoint if one is available and restore the
68 | # corresponding train iterator
69 | extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)
70 |
71 | # Train until the learning rate gets too small
72 | max_epoch = args.max_epoch or math.inf
73 | max_update = args.max_update or math.inf
74 | lr = trainer.get_lr()
75 | train_meter = StopwatchMeter()
76 | train_meter.start()
77 | valid_subsets = args.valid_subset.split(',')
78 | while (
79 | lr > args.min_lr
80 | and (
81 | epoch_itr.epoch < max_epoch
82 | # allow resuming training from the final checkpoint
83 | or epoch_itr._next_epoch_itr is not None
84 | )
85 | and trainer.get_num_updates() < max_update
86 | ):
87 | # train for one epoch
88 | train(args, trainer, task, epoch_itr)
89 |
90 | if not args.disable_validation and epoch_itr.epoch % args.validate_interval == 0:
91 | valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
92 | else:
93 | valid_losses = [None]
94 |
95 | # only use first validation loss to update the learning rate
96 | lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])
97 |
98 | # save checkpoint
99 | if epoch_itr.epoch % args.save_interval == 0:
100 | checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
101 |
102 | # early stop
103 | if should_stop_early(args, valid_losses[0]):
104 | print('| Early stop since valid performance hasn\'t improved for last {} runs'.format(args.patience))
105 | break
106 |
107 | reload_dataset = ':' in getattr(args, 'data', '')
108 | # sharded data: get train iterator for next epoch
109 | epoch_itr = trainer.get_train_iterator(epoch_itr.epoch, load_dataset=reload_dataset)
110 | train_meter.stop()
111 | print('| done training in {:.1f} seconds'.format(train_meter.sum))
112 |
113 |
114 | def should_stop_early(args, valid_loss):
115 | if args.patience <= 0:
116 | return False
117 |
118 | def is_better(a, b):
119 | return a > b if args.maximize_best_checkpoint_metric else a < b
120 |
121 | prev_best = getattr(should_stop_early, 'best', None)
122 | if prev_best is None or is_better(valid_loss, prev_best):
123 | should_stop_early.best = valid_loss
124 | should_stop_early.num_runs = 0
125 | return False
126 | else:
127 | should_stop_early.num_runs += 1
128 | return should_stop_early.num_runs > args.patience
129 |
130 |
131 | def train(args, trainer, task, epoch_itr):
132 | """Train the model for one epoch."""
133 | # Initialize data iterator
134 | itr = epoch_itr.next_epoch_itr(
135 | fix_batches_to_gpus=args.fix_batches_to_gpus,
136 | shuffle=(epoch_itr.epoch >= args.curriculum),
137 | )
138 | update_freq = (
139 | args.update_freq[epoch_itr.epoch - 1]
140 | if epoch_itr.epoch <= len(args.update_freq)
141 | else args.update_freq[-1]
142 | )
143 | itr = iterators.GroupedIterator(itr, update_freq)
144 | progress = progress_bar.build_progress_bar(
145 | args, itr, epoch_itr.epoch, no_progress_bar='simple',
146 | )
147 |
148 | extra_meters = collections.defaultdict(lambda: AverageMeter())
149 | valid_subsets = args.valid_subset.split(',')
150 | max_update = args.max_update or math.inf
151 | for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
152 | log_output = trainer.train_step(samples)
153 | if log_output is None:
154 | continue
155 |
156 | # log mid-epoch stats
157 | stats = get_training_stats(trainer)
158 | for k, v in log_output.items():
159 | if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
160 | continue # these are already logged above
161 | if 'loss' in k or k == 'accuracy':
162 | extra_meters[k].update(v, log_output['sample_size'])
163 | else:
164 | extra_meters[k].update(v)
165 | stats[k] = extra_meters[k].avg
166 | progress.log(stats, tag='train', step=stats['num_updates'])
167 |
168 | # ignore the first mini-batch in words-per-second and updates-per-second calculation
169 | if i == 0:
170 | trainer.get_meter('wps').reset()
171 | trainer.get_meter('ups').reset()
172 |
173 | num_updates = trainer.get_num_updates()
174 | if (
175 | not args.disable_validation
176 | and args.save_interval_updates > 0
177 | and num_updates % args.save_interval_updates == 0
178 | and num_updates > 0
179 | ):
180 | valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
181 | checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
182 |
183 | if num_updates >= max_update:
184 | break
185 |
186 | # log end-of-epoch stats
187 | stats = get_training_stats(trainer)
188 | for k, meter in extra_meters.items():
189 | stats[k] = meter.avg
190 | progress.print(stats, tag='train', step=stats['num_updates'])
191 |
192 | # reset training meters
193 | for k in [
194 | 'train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'gnorm', 'clip',
195 | ]:
196 | meter = trainer.get_meter(k)
197 | if meter is not None:
198 | meter.reset()
199 |
200 |
201 | def get_training_stats(trainer):
202 | stats = collections.OrderedDict()
203 | stats['loss'] = trainer.get_meter('train_loss')
204 | if trainer.get_meter('train_nll_loss').count > 0:
205 | nll_loss = trainer.get_meter('train_nll_loss')
206 | stats['nll_loss'] = nll_loss
207 | else:
208 | nll_loss = trainer.get_meter('train_loss')
209 | stats['ppl'] = utils.get_perplexity(nll_loss.avg)
210 | stats['wps'] = trainer.get_meter('wps')
211 | stats['ups'] = trainer.get_meter('ups')
212 | stats['wpb'] = trainer.get_meter('wpb')
213 | stats['bsz'] = trainer.get_meter('bsz')
214 | stats['num_updates'] = trainer.get_num_updates()
215 | stats['lr'] = trainer.get_lr()
216 | stats['gnorm'] = trainer.get_meter('gnorm')
217 | stats['clip'] = trainer.get_meter('clip')
218 | stats['oom'] = trainer.get_meter('oom')
219 | if trainer.get_meter('loss_scale') is not None:
220 | stats['loss_scale'] = trainer.get_meter('loss_scale')
221 | stats['wall'] = round(trainer.get_meter('wall').elapsed_time)
222 | stats['train_wall'] = trainer.get_meter('train_wall')
223 | return stats
224 |
225 |
226 | def validate(args, trainer, task, epoch_itr, subsets):
227 | """Evaluate the model on the validation set(s) and return the losses."""
228 |
229 | if args.fixed_validation_seed is not None:
230 | # set fixed seed for every validation
231 | utils.set_torch_seed(args.fixed_validation_seed)
232 |
233 | valid_losses = []
234 | for subset in subsets:
235 | # Initialize data iterator
236 | itr = task.get_batch_iterator(
237 | dataset=task.dataset(subset),
238 | max_tokens=args.max_tokens_valid,
239 | max_sentences=args.max_sentences_valid,
240 | max_positions=utils.resolve_max_positions(
241 | task.max_positions(),
242 | trainer.get_model().max_positions(),
243 | ),
244 | ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
245 | required_batch_size_multiple=args.required_batch_size_multiple,
246 | seed=args.seed,
247 | num_shards=args.distributed_world_size,
248 | shard_id=args.distributed_rank,
249 | num_workers=args.num_workers,
250 | ).next_epoch_itr(shuffle=False)
251 | progress = progress_bar.build_progress_bar(
252 | args, itr, epoch_itr.epoch,
253 | prefix='valid on \'{}\' subset'.format(subset),
254 | no_progress_bar='simple'
255 | )
256 |
257 | # reset validation loss meters
258 | for k in ['valid_loss', 'valid_nll_loss']:
259 | meter = trainer.get_meter(k)
260 | if meter is not None:
261 | meter.reset()
262 | extra_meters = collections.defaultdict(lambda: AverageMeter())
263 |
264 | for sample in progress:
265 | log_output = trainer.valid_step(sample)
266 |
267 | for k, v in log_output.items():
268 | if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
269 | continue
270 | extra_meters[k].update(v)
271 |
272 | # log validation stats
273 | stats = get_valid_stats(trainer, args, extra_meters)
274 | for k, meter in extra_meters.items():
275 | stats[k] = meter.avg
276 | progress.print(stats, tag=subset, step=trainer.get_num_updates())
277 |
278 | valid_losses.append(
279 | stats[args.best_checkpoint_metric].avg
280 | if args.best_checkpoint_metric == 'loss'
281 | else stats[args.best_checkpoint_metric]
282 | )
283 | return valid_losses
284 |
285 |
286 | def get_valid_stats(trainer, args, extra_meters=None):
287 | stats = collections.OrderedDict()
288 | stats['loss'] = trainer.get_meter('valid_loss')
289 | if trainer.get_meter('valid_nll_loss').count > 0:
290 | nll_loss = trainer.get_meter('valid_nll_loss')
291 | stats['nll_loss'] = nll_loss
292 | else:
293 | nll_loss = stats['loss']
294 | stats['ppl'] = utils.get_perplexity(nll_loss.avg)
295 | stats['num_updates'] = trainer.get_num_updates()
296 | if hasattr(checkpoint_utils.save_checkpoint, 'best'):
297 | key = 'best_{0}'.format(args.best_checkpoint_metric)
298 | best_function = max if args.maximize_best_checkpoint_metric else min
299 |
300 | current_metric = None
301 | if args.best_checkpoint_metric == 'loss':
302 | current_metric = stats['loss'].avg
303 | elif args.best_checkpoint_metric in extra_meters:
304 | current_metric = extra_meters[args.best_checkpoint_metric].avg
305 | elif args.best_checkpoint_metric in stats:
306 | current_metric = stats[args.best_checkpoint_metric]
307 | else:
308 | raise ValueError("best_checkpoint_metric not found in logs")
309 |
310 | stats[key] = best_function(
311 | checkpoint_utils.save_checkpoint.best,
312 | current_metric,
313 | )
314 | return stats
315 |
316 |
317 | def distributed_main(i, args, start_rank=0):
318 | args.device_id = i
319 | if args.distributed_rank is None: # torch.multiprocessing.spawn
320 | args.distributed_rank = start_rank + i
321 | main(args, init_distributed=True)
322 |
323 |
324 | def cli_main():
325 | parser = options.get_training_parser()
326 | args = options.parse_args_and_arch(parser)
327 |
328 | if args.distributed_init_method is None:
329 | distributed_utils.infer_init_method(args)
330 |
331 | if args.distributed_init_method is not None:
332 | # distributed training
333 | if torch.cuda.device_count() > 1 and not args.distributed_no_spawn:
334 | start_rank = args.distributed_rank
335 | args.distributed_rank = None # assign automatically
336 | torch.multiprocessing.spawn(
337 | fn=distributed_main,
338 | args=(args, start_rank),
339 | nprocs=torch.cuda.device_count(),
340 | )
341 | else:
342 | distributed_main(args.device_id, args)
343 | elif args.distributed_world_size > 1:
344 | # fallback for single node with multiple GPUs
345 | assert args.distributed_world_size <= torch.cuda.device_count()
346 | port = random.randint(10000, 20000)
347 | args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port)
348 | args.distributed_rank = None # set based on device id
349 | if max(args.update_freq) > 1 and args.ddp_backend != 'no_c10d':
350 | print('| NOTE: you may get better performance with: --ddp-backend=no_c10d')
351 | torch.multiprocessing.spawn(
352 | fn=distributed_main,
353 | args=(args, ),
354 | nprocs=args.distributed_world_size,
355 | )
356 | else:
357 | # single GPU training
358 | main(args)
359 |
360 |
361 | if __name__ == '__main__':
362 | cli_main()
363 |
--------------------------------------------------------------------------------
/models/baselines.py:
--------------------------------------------------------------------------------
1 | """
2 | Script for baseline approaches:
3 | 1. Take k random sentences
4 | 2. Take the top k sentences with the highest ROUGE score relative to the question
5 | 3. Pick the first k sentences
6 |
7 | To run:
8 | python baselines.py --dataset=chiqa
9 | """
10 |
11 | import json
12 | import numpy as np
13 | import random
14 | import requests
15 | import argparse
16 |
17 | import rouge
18 | import spacy
19 |
20 |
21 | def get_args():
22 | """
23 | Get command line arguments
24 | """
25 | parser = argparse.ArgumentParser(description="Arguments for data exploration")
26 | parser.add_argument("--dataset",
27 | dest="dataset",
28 | help="Dataset to run baselines on. Only current option is MEDIQA-AnS.")
29 | return parser
30 |
31 |
32 | def calculate_sentence_level_rouge(question, doc_sen, evaluator):
33 | """
34 | For each pair of sentences, calculate rouge score with py-rouge
35 | """
36 | rouge_score = evaluator.get_scores(doc_sen, question)['rouge-l']['f']
37 | return rouge_score
38 |
39 |
40 | def pick_k_best_rouge_sentences(k, questions, documents, summaries):
41 | """
42 | Pick the k sentences that have the highest rouge scores when compared to the question.
43 | """
44 | # Initiate rouge evaluator
45 | evaluator = rouge.Rouge(metrics=['rouge-l'],
46 | max_n=3,
47 | limit_length=False,
48 | length_limit_type='words',
49 | apply_avg=False,
50 | apply_best=True,
51 | alpha=1,
52 | weight_factor=1.2,
53 | stemming=False)
54 |
55 | pred_dict = {
56 | 'question': [],
57 | 'ref_summary': [],
58 | 'gen_summary': []
59 | }
60 | for q, doc, summ, in zip(questions, documents, summaries):
61 | # Sentencize abstract
62 | rouge_scores = []
63 | for sentence in doc:
64 | rouge_score = calculate_sentence_level_rouge(q, sentence, evaluator)
65 | rouge_scores.append(rouge_score)
66 | if len(doc) < k:
67 | top_k_rouge_scores = np.argsort(rouge_scores)
68 | else:
69 | top_k_rouge_scores = np.argsort(rouge_scores)[-k:]
70 | top_k_sentences = " ".join([doc[i] for i in top_k_rouge_scores])
71 |         summ = summ.replace("<s>", "")
72 |         summ = summ.replace("</s>", "")
73 | pred_dict['question'].append(q)
74 | pred_dict['ref_summary'].append(summ)
75 | pred_dict['gen_summary'].append(top_k_sentences)
76 |
77 | return pred_dict
78 |
79 |
80 | def pick_first_k_sentences(k, questions, documents, summaries):
81 | """
82 | Pick the first k sentences to use as summaries
83 | """
84 | pred_dict = {
85 | 'question': [],
86 | 'ref_summary': [],
87 | 'gen_summary': []
88 | }
89 | for q, doc, summ, in zip(questions, documents, summaries):
90 | if len(doc) < k:
91 | first_k_sentences = doc
92 | else:
93 | first_k_sentences = doc[0:k]
94 | first_k_sentences = " ".join(first_k_sentences)
95 |         summ = summ.replace("<s>", "")
96 |         summ = summ.replace("</s>", "")
97 | pred_dict['question'].append(q)
98 | pred_dict['ref_summary'].append(summ)
99 | pred_dict['gen_summary'].append(first_k_sentences)
100 |
101 | return pred_dict
102 |
103 |
104 | def pick_k_random_sentences(k, questions, documents, summaries):
105 | """
106 | Pick k random sentences from the articles to use as summaries
107 | """
108 | pred_dict = {
109 | 'question': [],
110 | 'ref_summary': [],
111 | 'gen_summary': []
112 | }
113 | random.seed(13)
114 | for q, doc, summ, in zip(questions, documents, summaries):
115 | if len(doc) < k:
116 | random_sentences = " ".join(doc)
117 | else:
118 | random_sentences = random.sample(doc, k)
119 | random_sentences = " ".join(random_sentences)
120 |         summ = summ.replace("<s>", "")
121 |         summ = summ.replace("</s>", "")
122 | pred_dict['question'].append(q)
123 | pred_dict['ref_summary'].append(summ)
124 | pred_dict['gen_summary'].append(random_sentences)
125 |
126 | return pred_dict
127 |
128 |
129 | def load_dataset(path):
130 | """
131 | Load the evaluation set
132 | """
133 | with open(path, "r", encoding="utf-8") as f:
134 | asumm_data = json.load(f)
135 |
136 | summaries = []
137 | questions = []
138 | documents = []
139 | nlp = spacy.load('en_core_web_sm')
140 | # Split sentences
141 | cnt = 0
142 | for q_id in asumm_data:
143 | questions.append(asumm_data[q_id]['question'])
144 | tokenized_art = nlp(asumm_data[q_id]['articles'])
145 | summaries.append(asumm_data[q_id]['summary'])
146 | article_sentences = [s.text.strip() for s in tokenized_art.sents]
147 | documents.append(article_sentences[0:])
148 | return questions, documents, summaries
149 |
150 |
151 | def save_baseline(baseline, filename):
152 | """
153 | Save baseline in format for rouge evaluation
154 | """
155 | with open("../evaluation/data/baselines/chiqa_eval/baseline_{}.json".format(filename), "w", encoding="utf-8") as f:
156 | json.dump(baseline, f, indent=4)
157 |
158 |
159 | def run_baselines():
160 | """
161 | Generate the random baseline and the best rouge baseline
162 | """
163 | # Load the MEDIQA-AnS datasets
164 | datasets = [
165 | ("../data_processing/data/page2answer_single_abstractive_summ.json", "p2a-single-abs"),
166 | ("../data_processing/data/page2answer_single_extractive_summ.json", "p2a-single-ext"),
167 | ("../data_processing/data/section2answer_single_abstractive_summ.json", "s2a-single-abs"),
168 | ("../data_processing/data/section2answer_single_extractive_summ.json", "s2a-single-ext"),
169 | ]
170 |
171 | for data in datasets:
172 | task = data[1]
173 | print("Running baselines on {}".format(task))
174 | # k can be determined from averages or medians of summary types of reference summaries. Alternatively, just use Lead-3 baseline.
175 | # Optional to use different k for extractive and abstractive summaries, as the manual summaries of the two types have different average lengths
176 | if task == "p2a-single-abs":
177 | k = 3
178 | if task == "p2a-single-ext":
179 | k = 3
180 | if task == "s2a-single-abs":
181 | k = 3
182 | if task == "s2a-single-ext":
183 | k = 3
184 | questions, documents, summaries = load_dataset(data[0])
185 | k_sentences = pick_k_random_sentences(k, questions, documents, summaries)
186 | first_k_sentences = pick_first_k_sentences(k, questions, documents, summaries)
187 | k_best_rouge = pick_k_best_rouge_sentences(k, questions, documents, summaries)
188 |
189 | save_baseline(k_sentences, filename="random_sentences_k_{}_{}_{}".format(k, args.dataset, task))
190 | save_baseline(first_k_sentences, filename="first_sentences_k_{}_{}_{}".format(k, args.dataset, task))
191 | save_baseline(k_best_rouge, filename="best_rouge_k_{}_{}_{}".format(k, args.dataset, task))
192 |
193 |
194 | if __name__ == "__main__":
195 | args = get_args().parse_args()
196 | run_baselines()
197 |
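198 | # Note on pick_k_best_rouge_sentences: np.argsort sorts ascending, so the [-k:] slice keeps
199 | # the indices of the k highest-scoring sentences, e.g. scores [0.1, 0.7, 0.3] with k=2
200 | # selects sentence indices [2, 1].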
--------------------------------------------------------------------------------
/models/bilstm/__pycache__/data.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/saverymax/qdriven-chiqa-summarization/257a00133869db47807b9dd10761a6dd3aa15306/models/bilstm/__pycache__/data.cpython-36.pyc
--------------------------------------------------------------------------------
/models/bilstm/__pycache__/data.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/saverymax/qdriven-chiqa-summarization/257a00133869db47807b9dd10761a6dd3aa15306/models/bilstm/__pycache__/data.cpython-37.pyc
--------------------------------------------------------------------------------
/models/bilstm/__pycache__/model.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/saverymax/qdriven-chiqa-summarization/257a00133869db47807b9dd10761a6dd3aa15306/models/bilstm/__pycache__/model.cpython-36.pyc
--------------------------------------------------------------------------------
/models/bilstm/__pycache__/model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/saverymax/qdriven-chiqa-summarization/257a00133869db47807b9dd10761a6dd3aa15306/models/bilstm/__pycache__/model.cpython-37.pyc
--------------------------------------------------------------------------------
/models/bilstm/data.py:
--------------------------------------------------------------------------------
1 | """
2 | Module for data processing.
3 |
4 | Includes classes for loading the data and generating examples
5 | """
6 |
7 | import json
8 | import tensorflow_datasets as tfds
9 | import tensorflow as tf
10 | from sklearn.utils import shuffle as sk_shuffle
11 | import numpy as np
12 | import random
13 | import spacy
14 |
15 |
16 | class Vocab():
17 | """
18 | Class to generate vocab for model
19 | """
20 |
21 | def train_tokenizer(self, tokenizer_filename, training_data, vocab_size):
22 | """
23 | Train the subword tokenizer
24 | """
25 | encoder = tfds.features.text.SubwordTextEncoder.build_from_corpus(
26 | (sentence for abstract in training_data for sentence in abstract), target_vocab_size=vocab_size)
27 | encoder.save_to_file(tokenizer_filename)
28 |
29 | def _load_tokenizer(self, tokenizer_filename):
30 | """
31 | Load the trained subword tokenizer
32 | """
33 | encoder = tfds.features.text.SubwordTextEncoder.load_from_file(tokenizer_filename)
34 | return encoder
35 |
36 | def encode_data(self, tokenizer_filename, data):
37 | encoder = self._load_tokenizer(tokenizer_filename)
38 | encoded_question = [encoder.encode(question) for question in data[0]]
39 | encoded_answer = [[encoder.encode(sentence) for sentence in abstract] for abstract in data[1]]
40 | assert len(encoded_question) == len(encoded_answer)
41 | data = [encoded_question, encoded_answer]
42 | return data
43 |
44 |
45 | class DataLoader():
46 | """
47 | Class to load data and generate train and validation datasets
48 | """
49 |
50 | def __init__(self, data_path, max_tok_q, max_sentences, max_tok_sent, dataset, summary_type):
51 | """
52 | Initiate loader
53 | """
54 | self.data_path = data_path
55 | self.max_tok_q = max_tok_q
56 | self.max_sentences = max_sentences
57 | self.max_tok_sent = max_tok_sent
58 | self.dataset = dataset
59 | self.summary_type = summary_type
60 |
61 | def split_data(self, data):
62 | """
63 | Shuffle and divide data up into training and validation sets
64 | """
65 | questions, documents, scores = data
66 | # Shuffle data:
67 | assert len(questions) == len(documents)
68 | assert len(questions) == len(scores)
69 | documents, questions, scores = sk_shuffle(documents, questions, scores, random_state=13)
70 | training_index = int(len(questions) * .8)
71 | x_train = [questions[:training_index], documents[:training_index]]
72 | y_train = scores[:training_index]
73 | x_val = [questions[training_index:], documents[training_index:]]
74 | y_val = scores[training_index:]
75 |
76 | return x_train, y_train, x_val, y_val
77 |
78 | def pad_data(self, data, data_type, padding_type='float32', test=False):
79 | """
80 | Method for padding data for training, validation, or inference.
81 | """
82 | if data_type == "question":
83 | padded_data = tf.keras.preprocessing.sequence.pad_sequences(data, padding='post', maxlen=self.max_tok_q)
84 | elif data_type == "scores":
85 | padded_data = tf.keras.preprocessing.sequence.pad_sequences(data, padding='post', dtype=padding_type, maxlen=self.max_sentences)
86 | elif data_type == "article":
87 | padded_data = []
88 | # Mask for sentences in document embedding
89 | #sentence_masks = []
90 | for doc in data:
91 | if len(doc) > self.max_sentences:
92 | doc = doc[:self.max_sentences]
93 | #sentence_masks.append(np.ones(len(doc), dype=bool))
94 | elif len(doc) < self.max_sentences:
95 | # Add the sentences that are missing.
96 | extra_docs = [[] for i in range(self.max_sentences - len(doc))]
97 | doc.extend(extra_docs)
98 | #sentence_mask = [False if d == [] else True for d in doc]
99 | #sentence_masks.append(np.array(sentence_masks))
100 | assert len(doc) == self.max_sentences
101 | padded_data.append(tf.keras.preprocessing.sequence.pad_sequences(doc, padding='post', maxlen=self.max_tok_sent))
102 |
103 | # For asserting correct padding:
104 | if test:
105 | if data_type == "question":
106 | for q in padded_data:
107 | assert len(q) == self.max_tok_q, q
108 | if data_type == "article":
109 | #assert len(sentence_masks) == len(padded_data), len(sentence_masks)
110 | #for m in sentence_masks:
111 | # assert len(m) == self.max_sentences, len(m)
112 | for doc in padded_data:
113 | assert isinstance(doc, np.ndarray), type(doc)
114 | assert len(doc) == self.max_sentences, len(doc)
115 | for sent in doc:
116 | assert isinstance(sent, np.ndarray), type(sent)
117 | assert len(sent) == self.max_tok_sent, len(sent)
118 | assert isinstance(sent[0], np.int32), (sent[0], type(sent[0]))
119 | elif data_type == "scores":
120 | for doc in padded_data:
121 | assert isinstance(doc, np.ndarray), type(doc)
122 | assert len(doc) == self.max_sentences, len(doc)
123 | # Convert abstracts to np array here
124 | # pad_sequences converts lists of lists to np arrays, which is the form the questions
125 | # and sentences are in.
126 | # Reshape y_train and val from 2d > 3d
127 | if data_type == "scores":
128 | padded_data = np.reshape(padded_data, padded_data.shape + (1,))
129 | return padded_data
130 | elif data_type == "article":
131 | return np.array(padded_data)
132 | else:
133 | return padded_data
134 |
135 | def load_data(self, mode, tag_sentences):
136 | """
137 | Open the data and split into train/val.
138 | """
139 | questions = []
140 | documents = []
141 | with open(self.data_path, "r", encoding="utf-8") as f:
142 | asumm_data = json.load(f)
143 |
144 | if mode == "train":
145 | scores = []
146 | for q_id in asumm_data:
147 | questions.append(asumm_data[q_id]['question'])
148 | documents.append(asumm_data[q_id]['sentences'])
149 | scores.append(asumm_data[q_id]['labels'])
150 | return questions, documents, scores
151 |
152 | if mode == "infer":
153 | summaries = []
154 | nlp = spacy.load('en_core_web_sm')
155 | if self.dataset == "chiqa" or self.dataset=="medinfo":
156 | question_ids = []
157 |             # There are multiple summary tasks within the chiqa dataset: single and multi doc require different processing.
158 | # Handle the single answer -> summary case:
159 | if "single" in self.summary_type:
160 | for q_id in asumm_data:
161 | question_ids.append(q_id)
162 | question = asumm_data[q_id]['question']
163 | questions.append(question)
164 | summary = asumm_data[q_id]['summary']
165 | tokenized_art = nlp(asumm_data[q_id]['articles'])
166 | tokenized_summ = nlp(summary)
167 |                     # Split sentences and tag them with <s> and </s> if the option is set for the pointer-generator
168 |                     if tag_sentences:
169 |                         summary = " ".join(["<s> {s} </s>".format(s=s.text.strip()) for s in tokenized_summ.sents])
170 | summaries.append(summary)
171 | article_sentences = [s.text.strip() for s in tokenized_art.sents]
172 | documents.append(article_sentences)
173 | return questions, documents, summaries, question_ids
174 |
175 |
176 | if __name__ == "__main__":
177 | # Locally testing data classes:
178 | max_tok_q = 20
179 | max_sentences = 75
180 | max_tok_sent = 20
181 | hidden_dim = 256
182 | dataset = "bioasq"
183 | tag_sentences = False
184 | mode = "train"
185 |     data_loader = DataLoader("/data/saveryme/asumm/asumm_data/training_data/bioasq_abs2summ_sent_classification_training.json", max_tok_q, max_sentences, max_tok_sent, dataset, summary_type=None)  # summary_type is only used in infer mode
186 | data = data_loader.load_data(mode, tag_sentences)
187 | x_train, y_train, x_val, y_val = data_loader.split_data(data)
188 | # x_train[0] = x_train[0][:3]
189 | # x_train[1] = x_train[1][:3]
190 | # x_val[0] = x_val[0][:3]
191 | # x_val[1] = x_val[1][:3]
192 | # y_train = y_train[:3]
193 | # y_val = y_val[:3]
194 |
195 | # print("Text!")
196 | print("Questions:\n", x_train[0][0])
197 | print("Sentences:\n", x_train[1][0])
198 | print("encoding data")
199 | x_train = Vocab().encode_data("medsumm_bioasq_abs2summ/tokenizer", x_train)
200 |
201 | sent_cnt = 0
202 | sent_len = 0
203 | for doc in x_train[1]:
204 | for sentence in doc:
205 | sent_cnt += 1
206 | sent_len += len(sentence)
207 |
208 | x_val = Vocab().encode_data("medsumm_bioasq_abs2summ/tokenizer", x_val)
209 | print("padding data")
210 | x_train[0] = data_loader.pad_data(x_train[0], data_type="question", test=True)
211 | x_train[1] = data_loader.pad_data(x_train[1], data_type="article", test=True)
212 | x_val[0] = data_loader.pad_data(x_val[0], data_type="question", test=True)
213 | x_val[1] = data_loader.pad_data(x_val[1], data_type="article", test=True)
214 | y_train = data_loader.pad_data(y_train, data_type="scores", test=True)
215 | y_val = data_loader.pad_data(y_val, data_type="scores", test=True)
216 | assert x_train[0].shape == (30532, max_tok_q), x_train[0].shape
217 | assert x_train[1].shape == (30532, max_sentences, max_tok_sent), x_train[1].shape
218 | assert x_val[0].shape == (7634, max_tok_q), x_val[0].shape
219 | assert x_val[1].shape == (7634, max_sentences, max_tok_sent), x_val[1].shape
220 | assert y_train.shape == (30532, max_sentences, 1), y_train.shape
221 | assert y_val.shape == (7634, max_sentences, 1), y_val.shape
222 | print("Avg. subwords/sentence:", sent_len / sent_cnt)
223 |
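224 | # Illustrative sketch of the training JSON layout consumed by load_data(mode="train"),
225 | # with keys inferred from the access pattern above (labels are 0/1 for the binary model,
226 | # ROUGE scores for the regression variant):
227 | #   { "<question_id>": { "question": "...", "sentences": ["sent 1", ...], "labels": [1, 0, ...] }, ... }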
--------------------------------------------------------------------------------
/models/bilstm/model.py:
--------------------------------------------------------------------------------
1 | """
2 | Model module for constructing the tensorflow graph for the LSTM sentence classifier
3 | """
4 |
5 | import tensorflow as tf
6 | from tensorflow.keras import layers
7 |
8 | class SentenceClassificationModel():
9 | """
10 | Class for model selecting sentences from relevant documents for answer summarization
11 | """
12 |
13 | def __init__(self, vocab_size, batch_size, hidden_dim, dropout, max_tok_q, max_sentences, max_tok_sent):
14 | """
15 |         Initiate the model
16 | """
17 | self.vocab_size = vocab_size
18 | self.batch_size = batch_size
19 | self.hidden_dim = hidden_dim
20 | self.max_tok_q = max_tok_q
21 | self.max_tok_sent = max_tok_sent
22 | self.max_sentences = max_sentences
23 | self.dropout = dropout
24 |
25 | def build_binary_model(self):
26 | """
27 | Construct the graph of the model
28 | """
29 | question_input = tf.keras.Input(shape=(self.max_tok_q, ), name='q_input')
30 | abstract_input = tf.keras.Input(shape=(self.max_sentences, self.max_tok_sent, ), name='abs_input')
31 | # NOT USING MASKING DUE TO CUDNN ERROR: https://github.com/tensorflow/tensorflow/issues/33148
32 | x1 = layers.Embedding(input_dim=self.vocab_size, output_dim=self.hidden_dim, mask_zero=False)(question_input)
33 | x1 = layers.Bidirectional(layers.LSTM(self.hidden_dim, dropout=self.dropout, kernel_regularizer=tf.keras.regularizers.l2(0.01)), input_shape=(self.max_tok_q, self.hidden_dim), name='q_bilstm')(x1)
34 |
35 | # Apply embedding to every sentence
36 | x2 = layers.TimeDistributed(layers.Embedding(input_dim=self.vocab_size, output_dim=self.hidden_dim, input_length=self.max_tok_sent, mask_zero=False), input_shape=(self.max_sentences, self.max_tok_sent))(abstract_input)
37 | # Apply lstm to every sentence embedding
38 | x2 = layers.TimeDistributed(layers.Bidirectional(layers.LSTM(self.hidden_dim, dropout=self.dropout, kernel_regularizer=tf.keras.regularizers.l2(0.01))), input_shape=(self.max_sentences, self.max_tok_sent, self.hidden_dim), name='sentence_distributed_bilstms')(x2)
39 | # Make lstm of document representation:
40 | # I could also just take this document representation and concatenate it to the single sentence representation, but I don't.
41 | x2 = layers.Bidirectional(layers.LSTM(self.hidden_dim, return_sequences=True, dropout=self.dropout, kernel_regularizer=tf.keras.regularizers.l2(0.01)), input_shape=(self.max_sentences, self.hidden_dim * 2), name='document_bilstm')(x2)
42 | # Combine question and document
43 | x3 = layers.RepeatVector(self.max_sentences)(x1)
44 | x4 = layers.concatenate([x2, x3])
45 |
46 |         # If using integer class labels, one target label can be provided per example (not one-hot) and the number of classes can be defined here
47 | sent_output = layers.Dense(2, activation='sigmoid', name='sent_output', kernel_regularizer=tf.keras.regularizers.l2(0.01))(x4)
48 |
49 | model = tf.keras.Model(inputs=[question_input, abstract_input], outputs=sent_output)
50 | model.summary()
51 | model.compile(loss='sparse_categorical_crossentropy',
52 | optimizer=tf.keras.optimizers.Adam(1e-4)
53 | )
54 |
55 | return model
56 |
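57 | if __name__ == "__main__":
58 |     # Minimal local smoke test (illustrative only: dimensions are arbitrary small values,
59 |     # not the hyperparameters used in the training scripts).
60 |     # Input shapes:  q_input (batch, max_tok_q), abs_input (batch, max_sentences, max_tok_sent)
61 |     # Output shape:  sent_output (batch, max_sentences, 2) per-sentence class scores
62 |     demo_builder = SentenceClassificationModel(
63 |         vocab_size=1000, batch_size=2, hidden_dim=8, dropout=0.2,
64 |         max_tok_q=10, max_sentences=5, max_tok_sent=10)
65 |     demo_model = demo_builder.build_binary_model()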
--------------------------------------------------------------------------------
/models/bilstm/requirements.txt:
--------------------------------------------------------------------------------
1 | spacy
2 | numpy
3 | tensorflow-datasets
4 | sklearn
5 |
--------------------------------------------------------------------------------
/models/bilstm/run_chiqa.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | exp_name=medsumm_bioasq_abs2summ
3 | tensorboard_log=dropout_5_sent_200_tok_50_val_20_d_256_l2_reg_binary
4 |
5 | for summ_task in page2answer section2answer
6 | do
7 | for summ_type in single_abstractive single_extractive
8 | do
9 | for k in 3
10 | do
11 | # Optional to specify different k, for selecting top k sentences. If you are interested in using the output of the sentence classifier as input for a generative model,
12 | # you may want to increase k
13 | echo "k:" $k
14 | data=../../data_processing/data/${summ_task}_${summ_type}_summ.json
15 | echo $data
16 | prediction_file=predictions_chiqa_${summ_task}_${summ_type}_dropout_5_sent_200_tok_50_val_20_d_256_l2_reg_binary_topk${k}.json
17 | eval_file=../../evaluation/data/sentence_classifier/chiqa_eval/sent_class_chiqa_${summ_task}_${summ_type}_dropout_5_sent_200_tok_50_val_20_d_256_l2_reg_binary_topk${k}.json
18 | python run_classifier.py \
19 | --exp_name=${exp_name} \
20 | --mode=infer \
21 | --data_path=$data \
22 | --dataset=chiqa \
23 | --summary_type=${summ_task}_${summ_type} \
24 | --train_tokenizer=False \
25 | --tokenizer_path=./${exp_name}/tokenizer \
26 | --model_path=${exp_name}/${tensorboard_log}/bioasq_abs2summ_sent_class_model.h5 \
27 | --batch_size=32 \
28 | --max_sentences=200 \
29 | --max_tok_sent=50 \
30 | --max_tok_q=50 \
31 | --dropout=.5 \
32 | --hidden_dim=256 \
33 | --binary_model=True \
34 | --prediction_file=${exp_name}/${tensorboard_log}/${prediction_file} \
35 | --eval_file=$eval_file \
36 | --tag_sentences=False \
37 | --top_k_sent=${k}
38 | done
39 | done
40 | done
41 |
--------------------------------------------------------------------------------
/models/bilstm/run_classifier.py:
--------------------------------------------------------------------------------
1 | """
2 | Runner script to train a simple LSTM sentence classifier,
3 | using the best ROUGE score of each article sentence compared to the sentences in the summary
4 | as the y labels.
5 |
6 | Open question: should labels be created with a ROUGE threshold or by selecting the n highest-scoring sentences?
7 | """
8 |
9 | import time
10 | from absl import app, flags
11 | import logging
12 | import os
13 | import sys
14 | import json
15 | import re
16 |
17 | import numpy as np
18 | import tensorflow as tf
19 |
20 | from data import Vocab, DataLoader
21 | from model import SentenceClassificationModel
22 |
23 | FLAGS = flags.FLAGS
24 |
25 | # Paths
26 | flags.DEFINE_string('data_path', '', 'Path expression to tf.Example datafiles. Can include wildcards to access multiple datafiles.')
27 | flags.DEFINE_string('mode', 'train', 'Select train/infer')
28 | flags.DEFINE_string('exp_name', '', 'Name for experiment. Tensorboard logs and models will be saved in a directories under this one')
29 | flags.DEFINE_string('tensorboard_log', '', 'Name of specific run for tb to log to. The experiment name will be the parent directory')
30 | flags.DEFINE_string('model_path', '', "Path to save model checkpoint")
31 | flags.DEFINE_string('dataset', '', "The dataset used for inference, used to specify how to load the data, e.g., medinfo")
32 | flags.DEFINE_string('summary_type', '', "The summary task within the chiqa dataset. The multi and single document tasks require different data handling")
33 | flags.DEFINE_string('prediction_file', '', "File to save predictions to be used for generative model")
34 | flags.DEFINE_string('eval_file', '', "File to save preds for direct evaluation")
35 |
36 | # Tokenizer training
37 | flags.DEFINE_string('tokenizer_path', '', 'Path to save tokenizer once trained')
38 | flags.DEFINE_boolean('train_tokenizer', False, "Flag to train a new tokenizer on the training corpus")
39 |
40 | # Data processing
41 | flags.DEFINE_boolean("tag_sentences", False, "For use with mode=infer. Tag the article sentences with <s> and </s>, when using the pointer-generator network as a second step, if the data has not already been tagged")
42 |
43 | # Hyperparameters and such
44 | flags.DEFINE_boolean('binary_model', False, "Flag to use binary model or regression model")
45 | flags.DEFINE_integer('vocab_size', 2**15, 'Size of subword vocabulary. Not sure if this is what I am definitely going to do. May be better to use pretrained embeddings')
46 | flags.DEFINE_integer('batch_size', 32, 'batch size')
47 | flags.DEFINE_integer('n_epochs', 10, 'Max number of epochs to run')
48 | flags.DEFINE_integer('max_tok_q', 100, 'max number of tokens for question')
49 | flags.DEFINE_integer('max_sentences', 10, 'max number of sentences')
50 | flags.DEFINE_integer('max_tok_sent', 100, 'max number of subword tokens/sentence')
51 | flags.DEFINE_integer('hidden_dim', 128, 'dimension of lstm')
52 | flags.DEFINE_float('dropout', .2, 'dropout proportion')
53 | flags.DEFINE_float('decision_threshold', .3, "Threshold for selecting relevant sentences during inference. No current implementation")
54 | flags.DEFINE_integer('top_k_sent', 10, "Number of sentences to select.")
55 |
56 |
57 | def run_training(model, x_train, y_train, x_val, y_val):
58 | """
59 | Run training in loop for n_epochs
60 | """
61 | logging.info("Beginning training\n")
62 | tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=FLAGS.tensorboard_log, update_freq=5000, profile_batch=0)
63 | checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
64 | filepath='{0}/{1}'.format(FLAGS.tensorboard_log, FLAGS.model_path),
65 | # overwrite current checkpoint if `val_loss` has improved.
66 | save_best_only=True,
67 | monitor='val_loss',
68 | save_freq='epoch',
69 | verbose=0)
70 |
71 | model.fit({'q_input': x_train[0], 'abs_input': x_train[1]}, {'sent_output': y_train},
72 | shuffle=True,
73 | epochs=FLAGS.n_epochs,
74 | validation_data=(x_val, y_val),
75 | callbacks=[tensorboard_callback, checkpoint_callback])
76 | logging.info("Training completed!\n")
77 |
78 |
79 | def run_inference(model, data, threshold, top_k_sent, binary_model):
80 | """
81 | Given a set of documents and a question, predict relevant sentences
82 | """
83 | predictions = model.predict({'q_input': data[0], 'abs_input': data[1]})
84 | logging.info("Predictions shape: {}".format(predictions.shape))
85 | if binary_model:
86 | # Select preds for just label 1
87 | reduced_preds = predictions[:, :, 1]
88 | reduced_preds = tf.squeeze(reduced_preds)
89 | filtered_predictions = tf.math.top_k(reduced_preds, k=top_k_sent, sorted=True)
90 | else:
91 | reduced_preds = tf.squeeze(predictions)
92 | filtered_predictions = tf.math.top_k(reduced_preds, k=top_k_sent)
93 | return filtered_predictions
94 |
95 |
96 | def save_predictions(predictions, data, binary_model):
97 | """
98 |     Get the indices of the top-k predicted sentences and save the selections to json.
99 | """
100 | pred_dict = {}
101 | # Also save the data in format for direct evaluation
102 | questions = []
103 | ref_summaries = []
104 | gen_summaries = []
105 | question_ids = []
106 | pred_dict = {}
107 | q_cnt = 0
108 | for p, indices in zip(predictions.values, predictions.indices):
109 | question = data[0][q_cnt]
110 | question_id = data[3][q_cnt]
111 | pred_dict[question_id] = {}
112 | # Get the human generated summary
113 | ref_summary = data[2][q_cnt]
114 | pred_dict[question_id]['summary'] = ref_summary
115 | pred_dict[question_id]['question'] = question
116 | # Remove any predicted indices that are out of the article's range
117 | indices = indices.numpy()[indices.numpy() < len(data[1][q_cnt])]
118 | sentences = " ".join(list(np.array(data[1][q_cnt])[indices]))
119 | pred_dict[question_id]['articles'] = sentences
120 | pred_dict[question_id]['predicted_score'] = p.numpy().tolist()
121 | # Format for rouge evaluation
122 | questions.append(question)
123 | question_ids.append(question_id)
124 |         # Remove any <s>/</s> sentence tags from data that will be evaluated directly and not passed to the pointer-generator
125 |         ref_summary = ref_summary.replace("<s>", "")
126 |         ref_summary = ref_summary.replace("</s>", "")
127 | ref_summaries.append(ref_summary)
128 | gen_summaries.append(sentences)
129 | q_cnt += 1
130 |
131 | predictions_for_eval = {'question_id': question_ids, 'question': questions, 'ref_summary': ref_summaries, 'gen_summary': gen_summaries}
132 |
133 | with open(FLAGS.prediction_file, "w", encoding="utf-8") as f:
134 | json.dump(pred_dict, f, indent=4)
135 |
136 | with open(FLAGS.eval_file, "w", encoding="utf-8") as f:
137 | json.dump(predictions_for_eval, f, indent=4)
138 |
139 |
140 | def main(argv):
141 | """
142 | Main function for running sentence classifier
143 | """
144 |     logging.basicConfig(filename="{}/medsumm.log".format(FLAGS.exp_name), filemode='w', level=logging.DEBUG)
145 |     logging.info("Num GPUs Available: {}\n".format(len(tf.config.experimental.list_physical_devices('GPU'))))
146 |     logging.info("Initiating sentence classification model...\n")
147 | logging.info("Loading data:")
148 | data_loader = DataLoader(FLAGS.data_path, FLAGS.max_tok_q, FLAGS.max_sentences, FLAGS.max_tok_sent, FLAGS.dataset, FLAGS.summary_type)
149 | # Returns tuple
150 | data = data_loader.load_data(FLAGS.mode, FLAGS.tag_sentences)
151 | if FLAGS.mode == "train":
152 | x_train, y_train, x_val, y_val = data_loader.split_data(data)
153 | logging.info("Questions:")
154 | logging.info(x_train[0][:2])
155 | logging.info("Sentences:")
156 | logging.info(x_train[1][:2])
157 |
158 | vocab_processor = Vocab()
159 | if FLAGS.train_tokenizer:
160 | logging.info("Training tokenizer\n")
161 | vocab_processor.train_tokenizer(FLAGS.tokenizer_path, data[1], FLAGS.vocab_size)
162 | # Once trained, get the subword tokenizer and encode the data
163 | logging.info("Encoding text")
164 | if FLAGS.mode == "train":
165 | logging.info("Encoding data")
166 | x_train = vocab_processor.encode_data(FLAGS.tokenizer_path, x_train)
167 | x_val = vocab_processor.encode_data(FLAGS.tokenizer_path, x_val)
168 | logging.info("Padding encodings")
169 | x_train[0] = data_loader.pad_data(x_train[0], data_type="question", test=True)
170 | x_train[1] = data_loader.pad_data(x_train[1], data_type="article", test=True)
171 | x_val[0] = data_loader.pad_data(x_val[0], data_type="question", test=True)
172 | x_val[1] = data_loader.pad_data(x_val[1], data_type="article", test=True)
173 | if FLAGS.binary_model:
174 | padding_type = "int32"
175 | else:
176 | padding_type = "float32"
177 | y_train = data_loader.pad_data(y_train, data_type="scores", padding_type=padding_type, test=True)
178 | y_val = data_loader.pad_data(y_val, data_type="scores", padding_type=padding_type, test=True)
179 | logging.info("Data shape:")
180 | logging.info("x_train questions: {}".format(x_train[0].shape))
181 | logging.info("x_train documents: {}".format(x_train[1].shape))
182 | logging.info("x_val questions: {}".format(x_val[0].shape))
183 | logging.info("x_val documents: {}".format(x_val[1].shape))
184 | logging.info("y_train: {}".format(y_train.shape))
185 | logging.info("y_val: {}".format(y_val.shape))
186 |
187 | logging.info("Question encoding")
188 | logging.info(x_train[0][:2])
189 | logging.info("Sentence encoding")
190 | logging.info(x_train[1][:2])
191 | logging.info("Rouge scores:")
192 | logging.info(y_train[0][:2])
193 |
194 | model = SentenceClassificationModel(
195 | FLAGS.vocab_size, FLAGS.batch_size, FLAGS.hidden_dim, FLAGS.dropout,
196 | FLAGS.max_tok_q, FLAGS.max_sentences, FLAGS.max_tok_sent
197 | )
198 |
199 | if FLAGS.binary_model:
200 | model = model.build_binary_model()
201 | run_training(model, x_train, y_train, x_val, y_val)
202 | else:
203 | model = model.build_model()
204 | run_training(model, x_train, y_train, x_val, y_val)
205 |
206 | if FLAGS.mode == "infer":
207 | encoded_data = vocab_processor.encode_data(FLAGS.tokenizer_path, data)
208 | padded_data = []
209 | padded_data.append(data_loader.pad_data(encoded_data[0], "question"))
210 | padded_data.append(data_loader.pad_data(encoded_data[1], "article"))
211 | logging.info("N questions: {}".format(len(padded_data[0])))
212 | logging.info("N documents: {}".format(len(padded_data[1])))
213 | logging.info("Question encoding")
214 | logging.info(padded_data[0][:2])
215 | logging.info("Sentence encoding")
216 | logging.info(padded_data[1][:2])
217 |
218 | logging.info("Loading model")
219 | model = tf.keras.models.load_model(FLAGS.model_path)
220 | predictions = run_inference(model, padded_data, FLAGS.decision_threshold, FLAGS.top_k_sent, FLAGS.binary_model)
221 | save_predictions(predictions, data, FLAGS.binary_model)
222 |
223 |
224 | if __name__ == "__main__":
225 | app.run(main)
226 |
--------------------------------------------------------------------------------
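For reference, run_classifier.py above writes two JSON outputs when saving predictions: the raw prediction dictionary (--prediction_file) and an evaluation file (--eval_file) holding four parallel lists keyed by 'question_id', 'question', 'ref_summary', and 'gen_summary'. A minimal sketch of reading that evaluation file back for downstream scoring; the filename below is a placeholder for whatever --eval_file was set to, not a file shipped with the repo:

    import json

    eval_path = "sentence_classifier_eval.json"  # hypothetical name; use the --eval_file value

    with open(eval_path, encoding="utf-8") as f:
        preds = json.load(f)

    # The four lists are parallel, so zip them back into per-question records.
    for qid, question, ref, gen in zip(preds["question_id"], preds["question"],
                                       preds["ref_summary"], preds["gen_summary"]):
        print(qid, question)  # ref and gen hold the reference and generated summaries for scoring

--------------------------------------------------------------------------------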
/models/bilstm/train_sentence_classifier.sh:
--------------------------------------------------------------------------------
1 | exp_name=medsumm_bioasq_abs2summ
2 | tensorboard_log=dropout_5_sent_200_tok_50_val_20_d_256_l2_reg_binary
3 | mkdir -p $exp_name
4 | mkdir -p ${exp_name}/${tensorboard_log}
5 | #training_data=/data/saveryme/asumm/asumm_data/training_data/bioasq_abs2summ_binary_sent_classification_training.json
6 | training_data=../../data_processing/data/bioasq_abs2summ_binary_sent_classification_training.json
7 | max_sent=200
8 | binary_model=True
9 | # Note: on the first run the sub-word tokenizer has to be trained. For all subsequent runs, train_tokenizer can be left False.
10 | train_tokenizer=True
11 |
12 | #--tokenizer_path=/data/saveryme/asumm/models/sentence_classifier/${exp_name}/tokenizer \
13 | python run_classifier.py \
14 | --exp_name=${exp_name} \
15 | --tensorboard_log=${exp_name}/${tensorboard_log} \
16 | --mode=train \
17 | --data_path=$training_data \
18 | --train_tokenizer=$train_tokenizer \
19 | --tokenizer_path=./${exp_name}/tokenizer \
20 | --model_path=bioasq_abs2summ_sent_class_model.h5 \
21 | --batch_size=32 \
22 | --max_sentences=$max_sent \
23 | --max_tok_sent=50 \
24 | --max_tok_q=50 \
25 | --dropout=.5 \
26 | --hidden_dim=256 \
27 | --binary_model=$binary_model
28 |
--------------------------------------------------------------------------------
/models/pointer_generator/LICENSE.txt:
--------------------------------------------------------------------------------
1 | Copyright 2017 The TensorFlow Authors. All rights reserved.
2 | Modifications Copyright 2017 Abigail See
3 |
4 |
5 | Apache License
6 | Version 2.0, January 2004
7 | http://www.apache.org/licenses/
8 |
9 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
10 |
11 | 1. Definitions.
12 |
13 | "License" shall mean the terms and conditions for use, reproduction,
14 | and distribution as defined by Sections 1 through 9 of this document.
15 |
16 | "Licensor" shall mean the copyright owner or entity authorized by
17 | the copyright owner that is granting the License.
18 |
19 | "Legal Entity" shall mean the union of the acting entity and all
20 | other entities that control, are controlled by, or are under common
21 | control with that entity. For the purposes of this definition,
22 | "control" means (i) the power, direct or indirect, to cause the
23 | direction or management of such entity, whether by contract or
24 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
25 | outstanding shares, or (iii) beneficial ownership of such entity.
26 |
27 | "You" (or "Your") shall mean an individual or Legal Entity
28 | exercising permissions granted by this License.
29 |
30 | "Source" form shall mean the preferred form for making modifications,
31 | including but not limited to software source code, documentation
32 | source, and configuration files.
33 |
34 | "Object" form shall mean any form resulting from mechanical
35 | transformation or translation of a Source form, including but
36 | not limited to compiled object code, generated documentation,
37 | and conversions to other media types.
38 |
39 | "Work" shall mean the work of authorship, whether in Source or
40 | Object form, made available under the License, as indicated by a
41 | copyright notice that is included in or attached to the work
42 | (an example is provided in the Appendix below).
43 |
44 | "Derivative Works" shall mean any work, whether in Source or Object
45 | form, that is based on (or derived from) the Work and for which the
46 | editorial revisions, annotations, elaborations, or other modifications
47 | represent, as a whole, an original work of authorship. For the purposes
48 | of this License, Derivative Works shall not include works that remain
49 | separable from, or merely link (or bind by name) to the interfaces of,
50 | the Work and Derivative Works thereof.
51 |
52 | "Contribution" shall mean any work of authorship, including
53 | the original version of the Work and any modifications or additions
54 | to that Work or Derivative Works thereof, that is intentionally
55 | submitted to Licensor for inclusion in the Work by the copyright owner
56 | or by an individual or Legal Entity authorized to submit on behalf of
57 | the copyright owner. For the purposes of this definition, "submitted"
58 | means any form of electronic, verbal, or written communication sent
59 | to the Licensor or its representatives, including but not limited to
60 | communication on electronic mailing lists, source code control systems,
61 | and issue tracking systems that are managed by, or on behalf of, the
62 | Licensor for the purpose of discussing and improving the Work, but
63 | excluding communication that is conspicuously marked or otherwise
64 | designated in writing by the copyright owner as "Not a Contribution."
65 |
66 | "Contributor" shall mean Licensor and any individual or Legal Entity
67 | on behalf of whom a Contribution has been received by Licensor and
68 | subsequently incorporated within the Work.
69 |
70 | 2. Grant of Copyright License. Subject to the terms and conditions of
71 | this License, each Contributor hereby grants to You a perpetual,
72 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
73 | copyright license to reproduce, prepare Derivative Works of,
74 | publicly display, publicly perform, sublicense, and distribute the
75 | Work and such Derivative Works in Source or Object form.
76 |
77 | 3. Grant of Patent License. Subject to the terms and conditions of
78 | this License, each Contributor hereby grants to You a perpetual,
79 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
80 | (except as stated in this section) patent license to make, have made,
81 | use, offer to sell, sell, import, and otherwise transfer the Work,
82 | where such license applies only to those patent claims licensable
83 | by such Contributor that are necessarily infringed by their
84 | Contribution(s) alone or by combination of their Contribution(s)
85 | with the Work to which such Contribution(s) was submitted. If You
86 | institute patent litigation against any entity (including a
87 | cross-claim or counterclaim in a lawsuit) alleging that the Work
88 | or a Contribution incorporated within the Work constitutes direct
89 | or contributory patent infringement, then any patent licenses
90 | granted to You under this License for that Work shall terminate
91 | as of the date such litigation is filed.
92 |
93 | 4. Redistribution. You may reproduce and distribute copies of the
94 | Work or Derivative Works thereof in any medium, with or without
95 | modifications, and in Source or Object form, provided that You
96 | meet the following conditions:
97 |
98 | (a) You must give any other recipients of the Work or
99 | Derivative Works a copy of this License; and
100 |
101 | (b) You must cause any modified files to carry prominent notices
102 | stating that You changed the files; and
103 |
104 | (c) You must retain, in the Source form of any Derivative Works
105 | that You distribute, all copyright, patent, trademark, and
106 | attribution notices from the Source form of the Work,
107 | excluding those notices that do not pertain to any part of
108 | the Derivative Works; and
109 |
110 | (d) If the Work includes a "NOTICE" text file as part of its
111 | distribution, then any Derivative Works that You distribute must
112 | include a readable copy of the attribution notices contained
113 | within such NOTICE file, excluding those notices that do not
114 | pertain to any part of the Derivative Works, in at least one
115 | of the following places: within a NOTICE text file distributed
116 | as part of the Derivative Works; within the Source form or
117 | documentation, if provided along with the Derivative Works; or,
118 | within a display generated by the Derivative Works, if and
119 | wherever such third-party notices normally appear. The contents
120 | of the NOTICE file are for informational purposes only and
121 | do not modify the License. You may add Your own attribution
122 | notices within Derivative Works that You distribute, alongside
123 | or as an addendum to the NOTICE text from the Work, provided
124 | that such additional attribution notices cannot be construed
125 | as modifying the License.
126 |
127 | You may add Your own copyright statement to Your modifications and
128 | may provide additional or different license terms and conditions
129 | for use, reproduction, or distribution of Your modifications, or
130 | for any such Derivative Works as a whole, provided Your use,
131 | reproduction, and distribution of the Work otherwise complies with
132 | the conditions stated in this License.
133 |
134 | 5. Submission of Contributions. Unless You explicitly state otherwise,
135 | any Contribution intentionally submitted for inclusion in the Work
136 | by You to the Licensor shall be under the terms and conditions of
137 | this License, without any additional terms or conditions.
138 | Notwithstanding the above, nothing herein shall supersede or modify
139 | the terms of any separate license agreement you may have executed
140 | with Licensor regarding such Contributions.
141 |
142 | 6. Trademarks. This License does not grant permission to use the trade
143 | names, trademarks, service marks, or product names of the Licensor,
144 | except as required for reasonable and customary use in describing the
145 | origin of the Work and reproducing the content of the NOTICE file.
146 |
147 | 7. Disclaimer of Warranty. Unless required by applicable law or
148 | agreed to in writing, Licensor provides the Work (and each
149 | Contributor provides its Contributions) on an "AS IS" BASIS,
150 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
151 | implied, including, without limitation, any warranties or conditions
152 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
153 | PARTICULAR PURPOSE. You are solely responsible for determining the
154 | appropriateness of using or redistributing the Work and assume any
155 | risks associated with Your exercise of permissions under this License.
156 |
157 | 8. Limitation of Liability. In no event and under no legal theory,
158 | whether in tort (including negligence), contract, or otherwise,
159 | unless required by applicable law (such as deliberate and grossly
160 | negligent acts) or agreed to in writing, shall any Contributor be
161 | liable to You for damages, including any direct, indirect, special,
162 | incidental, or consequential damages of any character arising as a
163 | result of this License or out of the use or inability to use the
164 | Work (including but not limited to damages for loss of goodwill,
165 | work stoppage, computer failure or malfunction, or any and all
166 | other commercial damages or losses), even if such Contributor
167 | has been advised of the possibility of such damages.
168 |
169 | 9. Accepting Warranty or Additional Liability. While redistributing
170 | the Work or Derivative Works thereof, You may choose to offer,
171 | and charge a fee for, acceptance of support, warranty, indemnity,
172 | or other liability obligations and/or rights consistent with this
173 | License. However, in accepting such obligations, You may act only
174 | on Your own behalf and on Your sole responsibility, not on behalf
175 | of any other Contributor, and only if You agree to indemnify,
176 | defend, and hold each Contributor harmless for any liability
177 | incurred by, or claims asserted against, such Contributor by reason
178 | of your accepting any such warranty or additional liability.
179 |
180 | END OF TERMS AND CONDITIONS
181 |
182 | APPENDIX: How to apply the Apache License to your work.
183 |
184 | To apply the Apache License to your work, attach the following
185 | boilerplate notice, with the fields enclosed by brackets "[]"
186 | replaced with your own identifying information. (Don't include
187 | the brackets!) The text should be enclosed in the appropriate
188 | comment syntax for the file format. We also recommend that a
189 | file or class name and description of purpose be included on the
190 | same "printed page" as the copyright notice for easier
191 | identification within third-party archives.
192 |
193 | Copyright 2017, The TensorFlow Authors.
194 | Modifications Copyright 2017 Abigail See
195 |
196 | Licensed under the Apache License, Version 2.0 (the "License");
197 | you may not use this file except in compliance with the License.
198 | You may obtain a copy of the License at
199 |
200 | http://www.apache.org/licenses/LICENSE-2.0
201 |
202 | Unless required by applicable law or agreed to in writing, software
203 | distributed under the License is distributed on an "AS IS" BASIS,
204 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
205 | See the License for the specific language governing permissions and
206 | limitations under the License.
207 |
--------------------------------------------------------------------------------
/models/pointer_generator/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/saverymax/qdriven-chiqa-summarization/257a00133869db47807b9dd10761a6dd3aa15306/models/pointer_generator/__init__.py
--------------------------------------------------------------------------------
/models/pointer_generator/attention_decoder.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | # Modifications Copyright 2017 Abigail See
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 |
17 | """This file defines the decoder"""
18 |
19 | import tensorflow as tf
20 | from tensorflow.python.ops import variable_scope
21 | from tensorflow.python.ops import array_ops
22 | from tensorflow.python.ops import nn_ops
23 | from tensorflow.python.ops import math_ops
24 |
25 | # Note: this function is based on tf.contrib.legacy_seq2seq_attention_decoder, which is now outdated.
26 | # In the future, it would make more sense to write variants on the attention mechanism using the new seq2seq library for tensorflow 1.0: https://www.tensorflow.org/api_guides/python/contrib.seq2seq#Attention
27 | def attention_decoder(decoder_inputs, initial_state, encoder_states, enc_padding_mask, cell, initial_state_attention=False, pointer_gen=True, use_coverage=False, prev_coverage=None):
28 | """
29 | Args:
30 | decoder_inputs: A list of 2D Tensors [batch_size x input_size].
31 | initial_state: 2D Tensor [batch_size x cell.state_size].
32 | encoder_states: 3D Tensor [batch_size x attn_length x attn_size].
33 | enc_padding_mask: 2D Tensor [batch_size x attn_length] containing 1s and 0s; indicates which of the encoder locations are padding (0) or a real token (1).
34 | cell: rnn_cell.RNNCell defining the cell function and size.
35 | initial_state_attention:
36 | Note that this attention decoder passes each decoder input through a linear layer with the previous step's context vector to get a modified version of the input. If initial_state_attention is False, on the first decoder step the "previous context vector" is just a zero vector. If initial_state_attention is True, we use initial_state to (re)calculate the previous step's context vector. We set this to False for train/eval mode (because we call attention_decoder once for all decoder steps) and True for decode mode (because we call attention_decoder once for each decoder step).
37 | pointer_gen: boolean. If True, calculate the generation probability p_gen for each decoder step.
38 | use_coverage: boolean. If True, use coverage mechanism.
39 | prev_coverage:
40 | If not None, a tensor with shape (batch_size, attn_length). The previous step's coverage vector. This is only not None in decode mode when using coverage.
41 |
42 | Returns:
43 | outputs: A list of the same length as decoder_inputs of 2D Tensors of
44 | shape [batch_size x cell.output_size]. The output vectors.
45 | state: The final state of the decoder. A tensor shape [batch_size x cell.state_size].
46 | attn_dists: A list containing tensors of shape (batch_size,attn_length).
47 | The attention distributions for each decoder step.
48 | p_gens: List of scalars. The values of p_gen for each decoder step. Empty list if pointer_gen=False.
49 | coverage: Coverage vector on the last step computed. None if use_coverage=False.
50 | """
51 | with variable_scope.variable_scope("attention_decoder") as scope:
52 | batch_size = encoder_states.get_shape()[0].value # if this line fails, it's because the batch size isn't defined
53 | attn_size = encoder_states.get_shape()[2].value # if this line fails, it's because the attention length isn't defined
54 |
55 | # Reshape encoder_states (need to insert a dim)
56 | encoder_states = tf.expand_dims(encoder_states, axis=2) # now is shape (batch_size, attn_len, 1, attn_size)
57 |
58 | # To calculate attention, we calculate
59 | # v^T tanh(W_h h_i + W_s s_t + b_attn)
60 | # where h_i is an encoder state, and s_t a decoder state.
61 | # attn_vec_size is the length of the vectors v, b_attn, (W_h h_i) and (W_s s_t).
62 | # We set it to be equal to the size of the encoder states.
63 | attention_vec_size = attn_size
64 |
65 | # Get the weight matrix W_h and apply it to each encoder state to get (W_h h_i), the encoder features
66 | W_h = variable_scope.get_variable("W_h", [1, 1, attn_size, attention_vec_size])
67 | encoder_features = nn_ops.conv2d(encoder_states, W_h, [1, 1, 1, 1], "SAME") # shape (batch_size,attn_length,1,attention_vec_size)
68 |
69 | # Get the weight vectors v and w_c (w_c is for coverage)
70 | v = variable_scope.get_variable("v", [attention_vec_size])
71 | if use_coverage:
72 | with variable_scope.variable_scope("coverage"):
73 | w_c = variable_scope.get_variable("w_c", [1, 1, 1, attention_vec_size])
74 |
75 | if prev_coverage is not None: # for beam search mode with coverage
76 | # reshape from (batch_size, attn_length) to (batch_size, attn_len, 1, 1)
77 | prev_coverage = tf.expand_dims(tf.expand_dims(prev_coverage,2),3)
78 |
79 | def attention(decoder_state, coverage=None):
80 | """Calculate the context vector and attention distribution from the decoder state.
81 |
82 | Args:
83 | decoder_state: state of the decoder
84 | coverage: Optional. Previous timestep's coverage vector, shape (batch_size, attn_len, 1, 1).
85 |
86 | Returns:
87 | context_vector: weighted sum of encoder_states
88 | attn_dist: attention distribution
89 | coverage: new coverage vector. shape (batch_size, attn_len, 1, 1)
90 | """
91 | with variable_scope.variable_scope("Attention"):
92 | # Pass the decoder state through a linear layer (this is W_s s_t + b_attn in the paper)
93 | decoder_features = linear(decoder_state, attention_vec_size, True) # shape (batch_size, attention_vec_size)
94 | decoder_features = tf.expand_dims(tf.expand_dims(decoder_features, 1), 1) # reshape to (batch_size, 1, 1, attention_vec_size)
95 |
96 | def masked_attention(e):
97 | """Take softmax of e then apply enc_padding_mask and re-normalize"""
98 | attn_dist = nn_ops.softmax(e) # take softmax. shape (batch_size, attn_length)
99 | attn_dist *= enc_padding_mask # apply mask
100 | masked_sums = tf.reduce_sum(attn_dist, axis=1) # shape (batch_size)
101 | return attn_dist / tf.reshape(masked_sums, [-1, 1]) # re-normalize
102 |
103 | if use_coverage and coverage is not None: # non-first step of coverage
104 | # Multiply coverage vector by w_c to get coverage_features.
105 | coverage_features = nn_ops.conv2d(coverage, w_c, [1, 1, 1, 1], "SAME") # c has shape (batch_size, attn_length, 1, attention_vec_size)
106 |
107 | # Calculate v^T tanh(W_h h_i + W_s s_t + w_c c_i^t + b_attn)
108 | e = math_ops.reduce_sum(v * math_ops.tanh(encoder_features + decoder_features + coverage_features), [2, 3]) # shape (batch_size,attn_length)
109 |
110 | # Calculate attention distribution
111 | attn_dist = masked_attention(e)
112 |
113 | # Update coverage vector
114 | coverage += array_ops.reshape(attn_dist, [batch_size, -1, 1, 1])
115 | else:
116 | # Calculate v^T tanh(W_h h_i + W_s s_t + b_attn)
117 | e = math_ops.reduce_sum(v * math_ops.tanh(encoder_features + decoder_features), [2, 3]) # calculate e
118 |
119 | # Calculate attention distribution
120 | attn_dist = masked_attention(e)
121 |
122 | if use_coverage: # first step of training
123 | coverage = tf.expand_dims(tf.expand_dims(attn_dist,2),2) # initialize coverage
124 |
125 | # Calculate the context vector from attn_dist and encoder_states
126 | context_vector = math_ops.reduce_sum(array_ops.reshape(attn_dist, [batch_size, -1, 1, 1]) * encoder_states, [1, 2]) # shape (batch_size, attn_size).
127 | context_vector = array_ops.reshape(context_vector, [-1, attn_size])
128 |
129 | return context_vector, attn_dist, coverage
130 |
131 | outputs = []
132 | attn_dists = []
133 | p_gens = []
134 | state = initial_state
135 | coverage = prev_coverage # initialize coverage to None or whatever was passed in
136 | context_vector = array_ops.zeros([batch_size, attn_size])
137 | context_vector.set_shape([None, attn_size]) # Ensure the second shape of attention vectors is set.
138 | if initial_state_attention: # true in decode mode
139 | # Re-calculate the context vector from the previous step so that we can pass it through a linear layer with this step's input to get a modified version of the input
140 | context_vector, _, coverage = attention(initial_state, coverage) # in decode mode, this is what updates the coverage vector
141 | for i, inp in enumerate(decoder_inputs):
142 | tf.logging.info("Adding attention_decoder timestep %i of %i", i, len(decoder_inputs))
143 | if i > 0:
144 | variable_scope.get_variable_scope().reuse_variables()
145 |
146 | # Merge input and previous attentions into one vector x of the same size as inp
147 | input_size = inp.get_shape().with_rank(2)[1]
148 | if input_size.value is None:
149 | raise ValueError("Could not infer input size from input: %s" % inp.name)
150 | x = linear([inp] + [context_vector], input_size, True)
151 |
152 | # Run the decoder RNN cell. cell_output = decoder state
153 | cell_output, state = cell(x, state)
154 |
155 | # Run the attention mechanism.
156 | if i == 0 and initial_state_attention: # always true in decode mode
157 | with variable_scope.variable_scope(variable_scope.get_variable_scope(), reuse=True): # you need this because you've already run the initial attention(...) call
158 | context_vector, attn_dist, _ = attention(state, coverage) # don't allow coverage to update
159 | else:
160 | context_vector, attn_dist, coverage = attention(state, coverage)
161 | attn_dists.append(attn_dist)
162 |
163 | # Calculate p_gen
164 | if pointer_gen:
165 | with tf.variable_scope('calculate_pgen'):
166 | p_gen = linear([context_vector, state.c, state.h, x], 1, True) # a scalar
167 | p_gen = tf.sigmoid(p_gen)
168 | p_gens.append(p_gen)
169 |
170 | # Concatenate the cell_output (= decoder state) and the context vector, and pass them through a linear layer
171 | # This is V[s_t, h*_t] + b in the paper
172 | with variable_scope.variable_scope("AttnOutputProjection"):
173 | output = linear([cell_output] + [context_vector], cell.output_size, True)
174 | outputs.append(output)
175 |
176 | # If using coverage, reshape it
177 | if coverage is not None:
178 | coverage = array_ops.reshape(coverage, [batch_size, -1])
179 |
180 | return outputs, state, attn_dists, p_gens, coverage
181 |
182 |
183 |
184 | def linear(args, output_size, bias, bias_start=0.0, scope=None):
185 | """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable.
186 |
187 | Args:
188 | args: a 2D Tensor or a list of 2D, batch x n, Tensors.
189 | output_size: int, second dimension of W[i].
190 | bias: boolean, whether to add a bias term or not.
191 | bias_start: starting value to initialize the bias; 0 by default.
192 | scope: VariableScope for the created subgraph; defaults to "Linear".
193 |
194 | Returns:
195 | A 2D Tensor with shape [batch x output_size] equal to
196 | sum_i(args[i] * W[i]), where W[i]s are newly created matrices.
197 |
198 | Raises:
199 | ValueError: if some of the arguments has unspecified or wrong shape.
200 | """
201 | if args is None or (isinstance(args, (list, tuple)) and not args):
202 | raise ValueError("`args` must be specified")
203 | if not isinstance(args, (list, tuple)):
204 | args = [args]
205 |
206 | # Calculate the total size of arguments on dimension 1.
207 | total_arg_size = 0
208 | shapes = [a.get_shape().as_list() for a in args]
209 | for shape in shapes:
210 | if len(shape) != 2:
211 | raise ValueError("Linear is expecting 2D arguments: %s" % str(shapes))
212 | if not shape[1]:
213 | raise ValueError("Linear expects shape[1] of arguments: %s" % str(shapes))
214 | else:
215 | total_arg_size += shape[1]
216 |
217 | # Now the computation.
218 | with tf.variable_scope(scope or "Linear"):
219 | matrix = tf.get_variable("Matrix", [total_arg_size, output_size])
220 | if len(args) == 1:
221 | res = tf.matmul(args[0], matrix)
222 | else:
223 | res = tf.matmul(tf.concat(axis=1, values=args), matrix)
224 | if not bias:
225 | return res
226 | bias_term = tf.get_variable(
227 | "Bias", [output_size], initializer=tf.constant_initializer(bias_start))
228 | return res + bias_term
229 |
--------------------------------------------------------------------------------
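The masked_attention helper inside attention() above takes a softmax over the raw scores e, zeroes out the padded encoder positions with enc_padding_mask, and re-normalizes so each attention distribution still sums to one. A small NumPy illustration of that same computation; this is purely illustrative and not part of the repo, and the toy scores and mask are made up:

    import numpy as np

    def masked_attention(e, enc_padding_mask):
        """Softmax over scores, zero out padding, then re-normalize each row to sum to 1."""
        exp_e = np.exp(e - e.max(axis=1, keepdims=True))      # numerically stable softmax numerator
        attn_dist = exp_e / exp_e.sum(axis=1, keepdims=True)  # softmax, shape (batch, attn_length)
        attn_dist = attn_dist * enc_padding_mask              # no attention on padded tokens
        masked_sums = attn_dist.sum(axis=1, keepdims=True)
        return attn_dist / masked_sums                        # rows sum to 1 again

    e = np.array([[2.0, 1.0, 0.5, -1.0]])    # scores for one example, attn_length = 4
    mask = np.array([[1.0, 1.0, 1.0, 0.0]])  # last encoder position is padding
    print(masked_attention(e, mask))         # sums to 1, with zero weight on the padded position

--------------------------------------------------------------------------------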
/models/pointer_generator/beam_search.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | # Modifications Copyright 2017 Abigail See
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 |
17 | """This file contains code to run beam search decoding"""
18 |
19 | import tensorflow as tf
20 | import numpy as np
21 | import data
22 |
23 | FLAGS = tf.app.flags.FLAGS
24 |
25 | class Hypothesis(object):
26 | """Class to represent a hypothesis during beam search. Holds all the information needed for the hypothesis."""
27 |
28 | def __init__(self, tokens, log_probs, state, attn_dists, p_gens, coverage):
29 | """Hypothesis constructor.
30 |
31 | Args:
32 | tokens: List of integers. The ids of the tokens that form the summary so far.
33 | log_probs: List, same length as tokens, of floats, giving the log probabilities of the tokens so far.
34 | state: Current state of the decoder, a LSTMStateTuple.
35 | attn_dists: List, same length as tokens, of numpy arrays with shape (attn_length). These are the attention distributions so far.
36 | p_gens: List, same length as tokens, of floats, or None if not using pointer-generator model. The values of the generation probability so far.
37 | coverage: Numpy array of shape (attn_length), or None if not using coverage. The current coverage vector.
38 | """
39 | self.tokens = tokens
40 | self.log_probs = log_probs
41 | self.state = state
42 | self.attn_dists = attn_dists
43 | self.p_gens = p_gens
44 | self.coverage = coverage
45 |
46 | def extend(self, token, log_prob, state, attn_dist, p_gen, coverage):
47 | """Return a NEW hypothesis, extended with the information from the latest step of beam search.
48 |
49 | Args:
50 | token: Integer. Latest token produced by beam search.
51 | log_prob: Float. Log prob of the latest token.
52 | state: Current decoder state, a LSTMStateTuple.
53 | attn_dist: Attention distribution from latest step. Numpy array shape (attn_length).
54 | p_gen: Generation probability on latest step. Float.
55 | coverage: Latest coverage vector. Numpy array shape (attn_length), or None if not using coverage.
56 | Returns:
57 | New Hypothesis for next step.
58 | """
59 | return Hypothesis(tokens = self.tokens + [token],
60 | log_probs = self.log_probs + [log_prob],
61 | state = state,
62 | attn_dists = self.attn_dists + [attn_dist],
63 | p_gens = self.p_gens + [p_gen],
64 | coverage = coverage)
65 |
66 | @property
67 | def latest_token(self):
68 | return self.tokens[-1]
69 |
70 | @property
71 | def log_prob(self):
72 | # the log probability of the hypothesis so far is the sum of the log probabilities of the tokens so far
73 | return sum(self.log_probs)
74 |
75 | @property
76 | def avg_log_prob(self):
77 | # normalize log probability by number of tokens (otherwise longer sequences always have lower probability)
78 | return self.log_prob / len(self.tokens)
79 |
80 |
81 | def run_beam_search(sess, model, vocab, batch):
82 | """Performs beam search decoding on the given example.
83 |
84 | Args:
85 | sess: a tf.Session
86 | model: a seq2seq model
87 | vocab: Vocabulary object
88 | batch: Batch object that is the same example repeated across the batch
89 |
90 | Returns:
91 | best_hyp: Hypothesis object; the best hypothesis found by beam search.
92 | """
93 | # Run the encoder to get the encoder hidden states and decoder initial state
94 | enc_states, dec_in_state = model.run_encoder(sess, batch)
95 | # dec_in_state is a LSTMStateTuple
96 | # enc_states has shape [batch_size, <=max_enc_steps, 2*hidden_dim].
97 |
98 | # Initialize beam_size-many hypotheses
99 | hyps = [Hypothesis(tokens=[vocab.word2id(data.START_DECODING)],
100 | log_probs=[0.0],
101 | state=dec_in_state,
102 | attn_dists=[],
103 | p_gens=[],
104 | coverage=np.zeros([batch.enc_batch.shape[1]]) # zero vector of length attention_length
105 | ) for _ in range(FLAGS.beam_size)]
106 | results = [] # this will contain finished hypotheses (those that have emitted the [STOP] token)
107 |
108 | steps = 0
109 | while steps < FLAGS.max_dec_steps and len(results) < FLAGS.beam_size:
110 | latest_tokens = [h.latest_token for h in hyps] # latest token produced by each hypothesis
111 | latest_tokens = [t if t in range(vocab.size()) else vocab.word2id(data.UNKNOWN_TOKEN) for t in latest_tokens] # change any in-article temporary OOV ids to [UNK] id, so that we can lookup word embeddings
112 | states = [h.state for h in hyps] # list of current decoder states of the hypotheses
113 | prev_coverage = [h.coverage for h in hyps] # list of coverage vectors (or None)
114 |
115 | # Run one step of the decoder to get the new info
116 | (topk_ids, topk_log_probs, new_states, attn_dists, p_gens, new_coverage) = model.decode_onestep(sess=sess,
117 | batch=batch,
118 | latest_tokens=latest_tokens,
119 | enc_states=enc_states,
120 | dec_init_states=states,
121 | prev_coverage=prev_coverage)
122 |
123 | # Extend each hypothesis and collect them all in all_hyps
124 | all_hyps = []
125 | num_orig_hyps = 1 if steps == 0 else len(hyps) # On the first step, we only had one original hypothesis (the initial hypothesis). On subsequent steps, all original hypotheses are distinct.
126 | for i in range(num_orig_hyps):
127 | h, new_state, attn_dist, p_gen, new_coverage_i = hyps[i], new_states[i], attn_dists[i], p_gens[i], new_coverage[i] # take the ith hypothesis and new decoder state info
128 | for j in range(FLAGS.beam_size * 2): # for each of the top 2*beam_size hyps:
129 | # Extend the ith hypothesis with the jth option
130 | new_hyp = h.extend(token=topk_ids[i, j],
131 | log_prob=topk_log_probs[i, j],
132 | state=new_state,
133 | attn_dist=attn_dist,
134 | p_gen=p_gen,
135 | coverage=new_coverage_i)
136 | all_hyps.append(new_hyp)
137 |
138 | # Filter and collect any hypotheses that have produced the end token.
139 | hyps = [] # will contain hypotheses for the next step
140 | for h in sort_hyps(all_hyps): # in order of most likely h
141 | if h.latest_token == vocab.word2id(data.STOP_DECODING): # if stop token is reached...
142 | # If this hypothesis is sufficiently long, put in results. Otherwise discard.
143 | if steps >= FLAGS.min_dec_steps:
144 | results.append(h)
145 | else: # hasn't reached stop token, so continue to extend this hypothesis
146 | hyps.append(h)
147 | if len(hyps) == FLAGS.beam_size or len(results) == FLAGS.beam_size:
148 | # Once we've collected beam_size-many hypotheses for the next step, or beam_size-many complete hypotheses, stop.
149 | break
150 |
151 | steps += 1
152 |
153 | # At this point, either we've got beam_size results, or we've reached maximum decoder steps
154 |
155 | if len(results)==0: # if we don't have any complete results, add all current hypotheses (incomplete summaries) to results
156 | results = hyps
157 |
158 | # Sort hypotheses by average log probability
159 | hyps_sorted = sort_hyps(results)
160 |
161 | # Return the hypothesis with highest average log prob
162 | return hyps_sorted[0]
163 |
164 | def sort_hyps(hyps):
165 | """Return a list of Hypothesis objects, sorted by descending average log probability"""
166 | return sorted(hyps, key=lambda h: h.avg_log_prob, reverse=True)
167 |
--------------------------------------------------------------------------------
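As the avg_log_prob comment above notes, finished hypotheses are ranked by their average per-token log probability rather than the raw sum, because the raw sum systematically penalizes longer sequences. A quick illustration with made-up numbers:

    # Made-up per-token log probabilities for two candidate summaries.
    short_hyp = [-1.5, -1.5, -1.5, -1.5]   # 4 tokens, each relatively unlikely
    long_hyp = [-1.2] * 10                 # 10 tokens, each more likely

    print(sum(short_hyp), sum(long_hyp))   # -6.0 -12.0: the raw sum prefers the short hypothesis
    print(sum(short_hyp) / len(short_hyp),
          sum(long_hyp) / len(long_hyp))   # -1.5 -1.2: the average prefers the longer one

--------------------------------------------------------------------------------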
/models/pointer_generator/data.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | # Modifications Copyright 2017 Abigail See
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 |
17 | """This file contains code to read the train/eval/test data from file and process it, and read the vocab data from file and process it"""
18 |
19 | import glob
20 | import json
21 | import random
22 | import struct
23 | import csv
24 | from tensorflow.core.example import example_pb2
25 |
26 | # <s> and </s> are used in the data files to segment the abstracts into sentences. They don't receive vocab ids.
27 | SENTENCE_START = '<s>'
28 | SENTENCE_END = '</s>'
29 |
30 | PAD_TOKEN = '[PAD]' # This has a vocab id, which is used to pad the encoder input, decoder input and target sequence
31 | UNKNOWN_TOKEN = '[UNK]' # This has a vocab id, which is used to represent out-of-vocabulary words
32 | START_DECODING = '[START]' # This has a vocab id, which is used at the start of every decoder input sequence
33 | STOP_DECODING = '[STOP]' # This has a vocab id, which is used at the end of untruncated target sequences
34 |
35 | # Note: none of <s>, </s>, [PAD], [UNK], [START], [STOP] should appear in the vocab file.
36 |
37 |
38 | class Vocab(object):
39 | """Vocabulary class for mapping between words and ids (integers)"""
40 |
41 | def __init__(self, vocab_file, max_size):
42 | """Creates a vocab of up to max_size words, reading from the vocab_file. If max_size is 0, reads the entire vocab file.
43 |
44 | Args:
45 | vocab_file: path to the vocab file, which is assumed to contain "<word> <frequency>" on each line, sorted with most frequent word first. This code doesn't actually use the frequencies, though.
46 | max_size: integer. The maximum size of the resulting Vocabulary."""
47 | self._word_to_id = {}
48 | self._id_to_word = {}
49 | self._count = 0 # keeps track of total number of words in the Vocab
50 |
51 | # [UNK], [PAD], [START] and [STOP] get the ids 0,1,2,3.
52 | for w in [UNKNOWN_TOKEN, PAD_TOKEN, START_DECODING, STOP_DECODING]:
53 | self._word_to_id[w] = self._count
54 | self._id_to_word[self._count] = w
55 | self._count += 1
56 |
57 | # Read the vocab file and add words up to max_size
58 | with open(vocab_file, 'r') as vocab_f:
59 | for line in vocab_f:
60 | pieces = line.split()
61 | if len(pieces) != 2:
62 | print('Warning: incorrectly formatted line in vocabulary file: %s\n' % line)
63 | continue
64 | w = pieces[0]
65 | if w in [SENTENCE_START, SENTENCE_END, UNKNOWN_TOKEN, PAD_TOKEN, START_DECODING, STOP_DECODING]:
66 | raise Exception('<s>, </s>, [UNK], [PAD], [START] and [STOP] shouldn\'t be in the vocab file, but %s is' % w)
67 | if w in self._word_to_id:
68 | raise Exception('Duplicated word in vocabulary file: %s' % w)
69 | self._word_to_id[w] = self._count
70 | self._id_to_word[self._count] = w
71 | self._count += 1
72 | if max_size != 0 and self._count >= max_size:
73 | print("max_size of vocab was specified as %i; we now have %i words. Stopping reading." % (max_size, self._count))
74 | break
75 |
76 | print("Finished constructing vocabulary of %i total words. Last word added: %s" % (self._count, self._id_to_word[self._count-1]))
77 |
78 | def word2id(self, word):
79 | """Returns the id (integer) of a word (string). Returns [UNK] id if word is OOV."""
80 | if word not in self._word_to_id:
81 | return self._word_to_id[UNKNOWN_TOKEN]
82 | return self._word_to_id[word]
83 |
84 | def id2word(self, word_id):
85 | """Returns the word (string) corresponding to an id (integer)."""
86 | if word_id not in self._id_to_word:
87 | raise ValueError('Id not found in vocab: %d' % word_id)
88 | return self._id_to_word[word_id]
89 |
90 | def size(self):
91 | """Returns the total size of the vocabulary"""
92 | return self._count
93 |
94 | def write_metadata(self, fpath):
95 | """Writes metadata file for Tensorboard word embedding visualizer as described here:
96 | https://www.tensorflow.org/get_started/embedding_viz
97 |
98 | Args:
99 | fpath: place to write the metadata file
100 | """
101 | print("Writing word embedding metadata file to %s..." % (fpath))
102 | with open(fpath, "w") as f:
103 | fieldnames = ['word']
104 | writer = csv.DictWriter(f, delimiter="\t", fieldnames=fieldnames)
105 | for i in range(self.size()):
106 | writer.writerow({"word": self._id_to_word[i]})
107 |
108 |
109 | def example_generator(data_path, single_pass):
110 | """Generates tf.Examples from data files.
111 |
112 | Binary data format: <length><blob>. <length> represents the byte size
113 | of <blob>. <blob> is serialized tf.Example proto. The tf.Example contains
114 | the tokenized article text and summary.
115 |
116 | Args:
117 | data_path:
118 | Path to tf.Example data files. Can include wildcards, e.g. if you have several training data chunk files train_001.bin, train_002.bin, etc, then pass data_path=train_* to access them all.
119 | single_pass:
120 | Boolean. If True, go through the dataset exactly once, generating examples in the order they appear, then return. Otherwise, generate random examples indefinitely.
121 |
122 | Yields:
123 | Deserialized tf.Example.
124 | """
125 | while True:
126 | filelist = glob.glob(data_path) # get the list of datafiles
127 | assert filelist, ('Error: Empty filelist at %s' % data_path) # check filelist isn't empty
128 | if single_pass:
129 | filelist = sorted(filelist)
130 | else:
131 | random.shuffle(filelist)
132 | for f in filelist:
133 | reader = open(f, 'rb')
134 | while True:
135 | len_bytes = reader.read(8)
136 | if not len_bytes: break # finished reading this file
137 | str_len = struct.unpack('q', len_bytes)[0]
138 | example_str = struct.unpack('%ds' % str_len, reader.read(str_len))[0]
139 | yield example_pb2.Example.FromString(example_str)
140 | if single_pass:
141 | print("example_generator completed reading all datafiles. No more data.")
142 | break
143 |
144 |
145 | def asumm_example_generator(data_path, single_pass, question_driven, tag_sentences):
146 | """Generates tf.Examples from data files for answer summarization. Coppied from example_generator.
147 |
148 | Binary data format: <length><blob>. <length> represents the byte size
149 | of <blob>. <blob> is serialized tf.Example proto. The tf.Example contains
150 | the tokenized article text and summary.
151 |
152 | Args:
153 | data_path:
154 | Path to tf.Example data files. Can include wildcards, e.g. if you have several training data chunk files train_001.bin, train_002.bin, etc, then pass data_path=train_* to access them all.
155 | single_pass:
156 | Boolean. If True, go through the dataset exactly once, generating examples in the order they appear, then return. Otherwise, generate random examples indefinitely.
157 |
158 | Yields:
159 | Deserialized tf.Example.
160 |
161 | Returns data until iterated to the end of the file
162 | """
163 | while True:
164 | with open(data_path, "r", encoding="utf-8") as f:
165 | asumm_data = json.load(f)
166 | keys = list(asumm_data.keys())
167 | random.shuffle(keys)
168 | for key in keys:
169 | summary = asumm_data[key]['summary']
170 | articles = asumm_data[key]['articles']
171 | question = asumm_data[key]['question']
172 | # Only add question or tag sentences during inference, not during training
173 | if tag_sentences and single_pass:
174 | split_summ = summary.split(".")
175 | summary = ' '.join(["%s %s %s" % (SENTENCE_START, sent, SENTENCE_END) for sent in split_summ])
176 | if question_driven and single_pass:
177 | articles = question + " [QUESTION?] " + articles
178 | yield (summary, articles, question)
179 | if single_pass:
180 | print("example_generator completed reading all datafiles. No more data.")
181 | break
182 |
183 |
184 | def article2ids(article_words, vocab):
185 | """Map the article words to their ids. Also return a list of OOVs in the article.
186 |
187 | Args:
188 | article_words: list of words (strings)
189 | vocab: Vocabulary object
190 |
191 | Returns:
192 | ids:
193 | A list of word ids (integers); OOVs are represented by their temporary article OOV number. If the vocabulary size is 50k and the article has 3 OOVs, then these temporary OOV numbers will be 50000, 50001, 50002.
194 | oovs:
195 | A list of the OOV words in the article (strings), in the order corresponding to their temporary article OOV numbers."""
196 | ids = []
197 | oovs = []
198 | unk_id = vocab.word2id(UNKNOWN_TOKEN)
199 | for w in article_words:
200 | i = vocab.word2id(w)
201 | if i == unk_id: # If w is OOV
202 | if w not in oovs: # Add to list of OOVs
203 | oovs.append(w)
204 | oov_num = oovs.index(w) # This is 0 for the first article OOV, 1 for the second article OOV...
205 | ids.append(vocab.size() + oov_num) # This is e.g. 50000 for the first article OOV, 50001 for the second...
206 | else:
207 | ids.append(i)
208 | return ids, oovs
209 |
210 |
211 | def abstract2ids(abstract_words, vocab, article_oovs):
212 | """Map the abstract words to their ids. In-article OOVs are mapped to their temporary OOV numbers.
213 |
214 | Args:
215 | abstract_words: list of words (strings)
216 | vocab: Vocabulary object
217 | article_oovs: list of in-article OOV words (strings), in the order corresponding to their temporary article OOV numbers
218 |
219 | Returns:
220 | ids: List of ids (integers). In-article OOV words are mapped to their temporary OOV numbers. Out-of-article OOV words are mapped to the UNK token id."""
221 | ids = []
222 | unk_id = vocab.word2id(UNKNOWN_TOKEN)
223 | for w in abstract_words:
224 | i = vocab.word2id(w)
225 | if i == unk_id: # If w is an OOV word
226 | if w in article_oovs: # If w is an in-article OOV
227 | vocab_idx = vocab.size() + article_oovs.index(w) # Map to its temporary article OOV number
228 | ids.append(vocab_idx)
229 | else: # If w is an out-of-article OOV
230 | ids.append(unk_id) # Map to the UNK token id
231 | else:
232 | ids.append(i)
233 | return ids
234 |
235 |
236 | def outputids2words(id_list, vocab, article_oovs):
237 | """Maps output ids to words, including mapping in-article OOVs from their temporary ids to the original OOV string (applicable in pointer-generator mode).
238 |
239 | Args:
240 | id_list: list of ids (integers)
241 | vocab: Vocabulary object
242 | article_oovs: list of OOV words (strings) in the order corresponding to their temporary article OOV ids (that have been assigned in pointer-generator mode), or None (in baseline mode)
243 |
244 | Returns:
245 | words: list of words (strings)
246 | """
247 | words = []
248 | for i in id_list:
249 | try:
250 | w = vocab.id2word(i) # might be [UNK]
251 | except ValueError as e: # w is OOV
252 | assert article_oovs is not None, "Error: model produced a word ID that isn't in the vocabulary. This should not happen in baseline (no pointer-generator) mode"
253 | article_oov_idx = i - vocab.size()
254 | try:
255 | w = article_oovs[article_oov_idx]
256 | except ValueError as e: # i doesn't correspond to an article oov
257 | raise ValueError('Error: model produced word ID %i which corresponds to article OOV %i but this example only has %i article OOVs' % (i, article_oov_idx, len(article_oovs)))
258 | words.append(w)
259 | return words
260 |
261 |
262 | def abstract2sents(abstract):
263 | """Splits abstract text from datafile into list of sentences.
264 |
265 | Args:
266 | abstract: string containing <s> and </s> tags for starts and ends of sentences
267 |
268 | Returns:
269 | sents: List of sentence strings (no tags)"""
270 | cur = 0
271 | sents = []
272 | while True:
273 | try:
274 | start_p = abstract.index(SENTENCE_START, cur)
275 | end_p = abstract.index(SENTENCE_END, start_p + 1)
276 | cur = end_p + len(SENTENCE_END)
277 | sents.append(abstract[start_p+len(SENTENCE_START):end_p])
278 | except ValueError as e: # no more sentences
279 | return sents
280 |
281 |
282 | def show_art_oovs(article, vocab):
283 | """Returns the article string, highlighting the OOVs by placing __underscores__ around them"""
284 | unk_token = vocab.word2id(UNKNOWN_TOKEN)
285 | words = article.split(' ')
286 | words = [("__%s__" % w) if vocab.word2id(w)==unk_token else w for w in words]
287 | out_str = ' '.join(words)
288 | return out_str
289 |
290 |
291 | def show_abs_oovs(abstract, vocab, article_oovs):
292 | """Returns the abstract string, highlighting the article OOVs with __underscores__.
293 |
294 | If a list of article_oovs is provided, non-article OOVs are differentiated like !!__this__!!.
295 |
296 | Args:
297 | abstract: string
298 | vocab: Vocabulary object
299 | article_oovs: list of words (strings), or None (in baseline mode)
300 | """
301 | unk_token = vocab.word2id(UNKNOWN_TOKEN)
302 | words = abstract.split(' ')
303 | new_words = []
304 | for w in words:
305 | if vocab.word2id(w) == unk_token: # w is oov
306 | if article_oovs is None: # baseline mode
307 | new_words.append("__%s__" % w)
308 | else: # pointer-generator mode
309 | if w in article_oovs:
310 | new_words.append("__%s__" % w)
311 | else:
312 | new_words.append("!!__%s__!!" % w)
313 | else: # w is in-vocab word
314 | new_words.append(w)
315 | out_str = ' '.join(new_words)
316 | return out_str
317 |
--------------------------------------------------------------------------------
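article2ids() and outputids2words() above implement the pointer-generator's temporary OOV numbering: words missing from the vocab are assigned ids vocab.size(), vocab.size()+1, ... for the current article, and those ids are mapped back to the original strings when decoding. A small round-trip sketch, assuming it is run from models/pointer_generator/ (so that `import data` picks up this file) with TensorFlow 1.x installed; the toy vocabulary and sentence are invented for illustration:

    import tempfile
    import data  # models/pointer_generator/data.py

    # Write a tiny vocab file in the "<word> <frequency>" format Vocab expects.
    with tempfile.NamedTemporaryFile("w", suffix=".vocab", delete=False) as f:
        for word, freq in [("the", 100), ("cat", 50), ("sat", 25), (".", 10)]:
            f.write("%s %d\n" % (word, freq))
        vocab_path = f.name

    vocab = data.Vocab(vocab_path, max_size=0)     # 4 special tokens + 4 words
    article = "the cat sat on the mat .".split()   # "on" and "mat" are OOV
    ids, oovs = data.article2ids(article, vocab)
    print(oovs)                                    # ['on', 'mat']
    print(ids)                                     # OOVs get temporary ids vocab.size() and vocab.size()+1
    print(data.outputids2words(ids, vocab, oovs))  # the temporary ids map back to the OOV strings

--------------------------------------------------------------------------------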
/models/pointer_generator/decode.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | # Modifications Copyright 2017 Abigail See
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 |
17 | """This file contains code to run beam search decoding, including running ROUGE evaluation and producing JSON datafiles for the in-browser attention visualizer, which can be found here https://github.com/abisee/attn_vis"""
18 |
19 | import os
20 | import time
21 | import tensorflow as tf
22 | import beam_search
23 | import data
24 | import json
25 | import pyrouge
26 | # Using py-rouge, the pure python rouge evaluator
27 | #import rouge
28 | import util
29 | import logging
30 | import numpy as np
31 | import sys
32 |
33 | FLAGS = tf.app.flags.FLAGS
34 |
35 | SECS_UNTIL_NEW_CKPT = 60 # max number of seconds before loading new checkpoint
36 |
37 |
38 | class BeamSearchDecoder(object):
39 | """Beam search decoder."""
40 |
41 | def __init__(self, model, batcher, vocab):
42 | """Initialize decoder.
43 |
44 | Args:
45 | model: a Seq2SeqAttentionModel object.
46 | batcher: a Batcher object.
47 | vocab: Vocabulary object
48 | """
49 | self._model = model
50 | self._model.build_graph()
51 | self._batcher = batcher
52 | self._vocab = vocab
53 | self._saver = tf.train.Saver() # we use this to load checkpoints for decoding
54 | self._sess = tf.Session(config=util.get_config())
55 | self._generated_answers = {"question": [], "ref_summary": [], "gen_summary": []}
56 |
57 | # For answer summarization decoding, this can be switched to load the best model from the eval directory (see the commented-out call below); by default the latest checkpoint is loaded here.
58 | ckpt_path = util.load_ckpt(self._saver, self._sess)
59 | #ckpt_path = util.load_ckpt(self._saver, self._sess, "eval")
60 |
61 | if FLAGS.single_pass:
62 | # Make a descriptive decode directory name
63 | ckpt_name = "ckpt-" + ckpt_path.split('-')[-1] # this is something of the form "ckpt-123456"
64 | self._decode_dir = os.path.join(FLAGS.log_root, get_decode_dir_name(ckpt_name))
65 | #if os.path.exists(self._decode_dir):
66 | # raise Exception("single_pass decode directory %s should not already exist" % self._decode_dir)
67 |
68 | else: # Generic decode dir name
69 | self._decode_dir = os.path.join(FLAGS.log_root, "decode")
70 |
71 | # Make the decode dir if necessary
72 | if not os.path.exists(self._decode_dir): os.mkdir(self._decode_dir)
73 |
74 | if FLAGS.single_pass:
75 | # Make the dirs to contain output written in the correct format for pyrouge
76 | self._rouge_ref_dir = os.path.join(self._decode_dir, "reference")
77 | if not os.path.exists(self._rouge_ref_dir): os.mkdir(self._rouge_ref_dir)
78 | self._rouge_dec_dir = os.path.join(self._decode_dir, "decoded")
79 | if not os.path.exists(self._rouge_dec_dir): os.mkdir(self._rouge_dec_dir)
80 |
81 |
82 | def decode(self):
83 | """Decode examples until data is exhausted (if FLAGS.single_pass) and return, or decode indefinitely, loading latest checkpoint at regular intervals"""
84 | t0 = time.time()
85 | counter = 0
86 | while True:
87 | batch = self._batcher.next_batch() # 1 example repeated across batch
88 | #if counter >= 2:
89 | # batch = None
90 | if batch is None: # finished decoding dataset in single_pass mode
91 | assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
92 | tf.logging.info("Decoder has finished reading dataset for single_pass.")
93 | if FLAGS.eval_type == "cnn":
94 | tf.logging.info("Output has been saved in %s and %s. Now starting ROUGE eval...", self._rouge_ref_dir, self._rouge_dec_dir)
95 | results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
96 | rouge_log(results_dict, self._decode_dir)
97 | if FLAGS.eval_type == "medsumm":
98 | tf.logging.info("Writing generated answer summaries to file...")
99 | with open(FLAGS.generated_data_file, "w", encoding="utf-8") as f:
100 | json.dump(self._generated_answers, f, indent=4)
101 | return
102 |
103 | question = batch.questions[0]
104 | original_article = batch.original_articles[0] # string
105 | original_abstract = batch.original_abstracts[0] # string
106 | original_abstract_sents = batch.original_abstracts_sents[0] # list of strings
107 |
108 | article_withunks = data.show_art_oovs(original_article, self._vocab) # string
109 | abstract_withunks = data.show_abs_oovs(original_abstract, self._vocab, (batch.art_oovs[0] if FLAGS.pointer_gen else None)) # string
110 |
111 | # Run beam search to get best Hypothesis
112 | best_hyp = beam_search.run_beam_search(self._sess, self._model, self._vocab, batch)
113 |
114 | # Extract the output ids from the hypothesis and convert back to words
115 | output_ids = [int(t) for t in best_hyp.tokens[1:]]
116 | decoded_words = data.outputids2words(output_ids, self._vocab, (batch.art_oovs[0] if FLAGS.pointer_gen else None))
117 |
118 | # Remove the [STOP] token from decoded_words, if necessary
119 | try:
120 | fst_stop_idx = decoded_words.index(data.STOP_DECODING) # index of the (first) [STOP] symbol
121 | decoded_words = decoded_words[:fst_stop_idx]
122 | except ValueError:
123 | pass # no [STOP] token was generated; keep all decoded words
124 | decoded_output = ' '.join(decoded_words) # single string
125 |
126 | if FLAGS.single_pass:
127 | if FLAGS.eval_type == "medsumm":
128 | # Save generated data for answer summ evaluation:
129 | self.write_data_for_medsumm_eval(original_abstract_sents, decoded_words, question, counter) # write ref summary and decoded summary to file, for later evaluation.
130 | tf.logging.info("saving summary %i", counter)
131 | if FLAGS.eval_type == "cnn":
132 | # write data for See's original evaluation done with pyrouge.
133 | self.write_for_rouge(original_abstract_sents, decoded_words, counter) # write ref summary and decoded summary to file, to eval with pyrouge later
134 | counter += 1 # this is how many examples we've decoded
135 | else:
136 | print_results(article_withunks, abstract_withunks, decoded_output) # log output to screen
137 | self.write_for_attnvis(article_withunks, abstract_withunks, decoded_words, best_hyp.attn_dists, best_hyp.p_gens) # write info to .json file for visualization tool
138 |
139 | # Check if SECS_UNTIL_NEW_CKPT has elapsed; if so return so we can load a new checkpoint
140 | t1 = time.time()
141 | if t1-t0 > SECS_UNTIL_NEW_CKPT:
142 | tf.logging.info('We\'ve been decoding with same checkpoint for %i seconds. Time to load new checkpoint', t1-t0)
143 | _ = util.load_ckpt(self._saver, self._sess)
144 | t0 = time.time()
145 |
146 | def write_data_for_medsumm_eval(self, original_abstract_sents, decoded_words, question, counter):
147 | """
148 | write reference summaries and generated summaries to file
149 | """
150 | decoded_sents = ""
151 | while len(decoded_words) > 0:
152 | try:
153 | fst_period_idx = decoded_words.index(".")
154 | except ValueError: # there is text remaining that doesn't end in "."
155 | fst_period_idx = len(decoded_words)
156 | sent = decoded_words[:fst_period_idx+1] # sentence up to and including the period
157 | decoded_words = decoded_words[fst_period_idx+1:] # everything else
158 | decoded_sents += (" " + ' '.join(sent))
159 | abstract_sents = ". ".join(original_abstract_sents)
160 |
161 | self._generated_answers['question'].append(question)
162 | self._generated_answers['ref_summary'].append(abstract_sents)
163 | self._generated_answers['gen_summary'].append(decoded_sents)
164 |
165 |
166 | def write_for_rouge(self, reference_sents, decoded_words, ex_index):
167 | """Write output to file in correct format for eval with pyrouge. This is called in single_pass mode.
168 |
169 | Args:
170 | reference_sents: list of strings
171 | decoded_words: list of strings
172 | ex_index: int, the index with which to label the files
173 | """
174 | # First, divide decoded output into sentences
175 | decoded_sents = []
176 | while len(decoded_words) > 0:
177 | try:
178 | fst_period_idx = decoded_words.index(".")
179 | except ValueError: # there is text remaining that doesn't end in "."
180 | fst_period_idx = len(decoded_words)
181 | sent = decoded_words[:fst_period_idx+1] # sentence up to and including the period
182 | decoded_words = decoded_words[fst_period_idx+1:] # everything else
183 | decoded_sents.append(' '.join(sent))
184 |
185 | # pyrouge calls a perl script that puts the data into HTML files.
186 | # Therefore we need to make our output HTML safe.
187 | decoded_sents = [make_html_safe(w) for w in decoded_sents]
188 | reference_sents = [make_html_safe(w) for w in reference_sents]
189 |
190 | # Write to file
191 | ref_file = os.path.join(self._rouge_ref_dir, "%06d_reference.txt" % ex_index)
192 | decoded_file = os.path.join(self._rouge_dec_dir, "%06d_decoded.txt" % ex_index)
193 |
194 | with open(ref_file, "w") as f:
195 | for idx,sent in enumerate(reference_sents):
196 | f.write(sent if idx==len(reference_sents)-1 else sent+"\n")
197 | with open(decoded_file, "w") as f:
198 | for idx,sent in enumerate(decoded_sents):
199 | f.write(sent if idx==len(decoded_sents)-1 else sent+"\n")
200 |
201 | tf.logging.info("Wrote example %i to file" % ex_index)
202 |
203 |
204 | def write_for_attnvis(self, article, abstract, decoded_words, attn_dists, p_gens):
205 | """Write some data to json file, which can be read into the in-browser attention visualizer tool:
206 | https://github.com/abisee/attn_vis
207 |
208 | Args:
209 | article: The original article string.
210 | abstract: The human (correct) abstract string.
211 | attn_dists: List of arrays; the attention distributions.
212 | decoded_words: List of strings; the words of the generated summary.
213 | p_gens: List of scalars; the p_gen values. If not running in pointer-generator mode, list of None.
214 | """
215 | article_lst = article.split() # list of words
216 | decoded_lst = decoded_words # list of decoded words
217 | to_write = {
218 | 'article_lst': [make_html_safe(t) for t in article_lst],
219 | 'decoded_lst': [make_html_safe(t) for t in decoded_lst],
220 | 'abstract_str': make_html_safe(abstract),
221 | 'attn_dists': attn_dists
222 | }
223 | if FLAGS.pointer_gen:
224 | to_write['p_gens'] = p_gens
225 | output_fname = os.path.join(self._decode_dir, 'attn_vis_data.json')
226 | with open(output_fname, 'w') as output_file:
227 | json.dump(to_write, output_file)
228 | tf.logging.info('Wrote visualization data to %s', output_fname)
229 |
230 |
231 | def print_results(article, abstract, decoded_output):
232 | """Prints the article, the reference summmary and the decoded summary to screen"""
233 | print("---------------------------------------------------------------------------")
234 | tf.logging.info('ARTICLE: %s', article)
235 | tf.logging.info('REFERENCE SUMMARY: %s', abstract)
236 | tf.logging.info('GENERATED SUMMARY: %s', decoded_output)
237 | print("---------------------------------------------------------------------------")
238 |
239 |
240 | def make_html_safe(s):
241 | """Replace any angled brackets in string s to avoid interfering with HTML attention visualizer."""
242 | s = s.replace("<", "&lt;")
243 | s = s.replace(">", "&gt;")
244 | return s
245 |
246 |
247 | def rouge_eval(ref_dir, dec_dir):
248 | """Evaluate the files in ref_dir and dec_dir with pyrouge, returning results_dict"""
249 | r = pyrouge.Rouge155()
250 | r.model_filename_pattern = '#ID#_reference.txt'
251 | r.system_filename_pattern = r'(\d+)_decoded.txt'
252 | r.model_dir = ref_dir
253 | r.system_dir = dec_dir
254 | logging.getLogger('global').setLevel(logging.WARNING) # silence pyrouge logging
255 | rouge_results = r.convert_and_evaluate()
256 | return r.output_to_dict(rouge_results)
257 |
258 |
259 | def rouge_log(results_dict, dir_to_write):
260 | """Log ROUGE results to screen and write to file.
261 |
262 | Args:
263 | results_dict: the dictionary returned by pyrouge
264 | dir_to_write: the directory where we will write the results to"""
265 | log_str = ""
266 | for x in ["1","2","l"]:
267 | log_str += "\nROUGE-%s:\n" % x
268 | for y in ["f_score", "recall", "precision"]:
269 | key = "rouge_%s_%s" % (x,y)
270 | key_cb = key + "_cb"
271 | key_ce = key + "_ce"
272 | val = results_dict[key]
273 | val_cb = results_dict[key_cb]
274 | val_ce = results_dict[key_ce]
275 | log_str += "%s: %.4f with confidence interval (%.4f, %.4f)\n" % (key, val, val_cb, val_ce)
276 | tf.logging.info(log_str) # log to screen
277 | results_file = os.path.join(dir_to_write, "ROUGE_results.txt")
278 | tf.logging.info("Writing final ROUGE results to %s...", results_file)
279 | with open(results_file, "w") as f:
280 | f.write(log_str)
281 |
282 | def get_decode_dir_name(ckpt_name):
283 | """Make a descriptive name for the decode dir, including the name of the checkpoint we use to decode. This is called in single_pass mode."""
284 |
285 | if "train" in FLAGS.data_path: dataset = "train"
286 | elif "val" in FLAGS.data_path: dataset = "val"
287 | elif "test" in FLAGS.data_path: dataset = "test"
288 | elif "summ" in FLAGS.data_path: dataset = "test"
289 | else: raise ValueError("FLAGS.data_path %s should contain one of train, val, test, or summ" % (FLAGS.data_path))
290 | dirname = "decode_%s_%imaxenc_%ibeam_%imindec_%imaxdec" % (dataset, FLAGS.max_enc_steps, FLAGS.beam_size, FLAGS.min_dec_steps, FLAGS.max_dec_steps)
291 | if ckpt_name is not None:
292 | dirname += "_%s" % ckpt_name
293 | return dirname
294 |
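295 | # Example sketch (illustrative only, not called anywhere in this repo): reading back the
296 | # JSON written by decode() above when --eval_type=medsumm. The key names mirror
297 | # self._generated_answers and the path is whatever --generated_data_file was set to;
298 | # json is already imported at the top of this module (decode() uses json.dump).
299 | def load_generated_answers(path):
300 |     """Return parallel lists of questions, reference summaries and generated summaries."""
301 |     with open(path, "r", encoding="utf-8") as f:
302 |         answers = json.load(f)
303 |     return answers["question"], answers["ref_summary"], answers["gen_summary"]
304 |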
--------------------------------------------------------------------------------
/models/pointer_generator/eval_medsumm.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | validation_data=../../data_processing/data/medinfo_section2answer_validation_data_${1}.json
4 | # Bioasq abs2summ evaluating on medinfo page > section with question
5 | if [ "$1" == "with_question" ]; then
6 | echo $1
7 | experiment=bioasq_abs2summ_with_question/
8 | python -u run_medsumm.py --mode=eval --data_path=$validation_data --vocab_path=bioasq_abs2summ_vocab --exp_name=$experiment
9 | fi
10 |
11 | # And same thing but without question
12 | if [ "$1" == "without_question" ]; then
13 | echo $1
14 | experiment=bioasq_abs2summ_without_question/
15 | python -u run_medsumm.py --mode=eval --data_path=$validation_data --vocab_path=bioasq_abs2summ_vocab --exp_name=$experiment
16 | fi
17 |
--------------------------------------------------------------------------------
/models/pointer_generator/inspect_checkpoint.py:
--------------------------------------------------------------------------------
1 | """
2 | Simple script that checks if a checkpoint is corrupted with any inf/NaN values. Run like this:
3 | python inspect_checkpoint.py model.12345
4 | """
5 |
6 | import tensorflow as tf
7 | import sys
8 | import numpy as np
9 |
10 |
11 | if __name__ == '__main__':
12 | if len(sys.argv) != 2:
13 | raise Exception("Usage: python inspect_checkpoint.py \nNote: Do not include the .data .index or .meta part of the model checkpoint in file_name.")
14 | file_name = sys.argv[1]
15 | reader = tf.train.NewCheckpointReader(file_name)
16 | var_to_shape_map = reader.get_variable_to_shape_map()
17 |
18 | finite = []
19 | all_infnan = []
20 | some_infnan = []
21 |
22 | for key in sorted(var_to_shape_map.keys()):
23 | tensor = reader.get_tensor(key)
24 | if np.all(np.isfinite(tensor)):
25 | finite.append(key)
26 | else:
27 | if not np.any(np.isfinite(tensor)):
28 | all_infnan.append(key)
29 | else:
30 | some_infnan.append(key)
31 |
32 | print("\nFINITE VARIABLES:")
33 | for key in finite: print(key)
34 |
35 | print("\nVARIABLES THAT ARE ALL INF/NAN:")
36 | for key in all_infnan: print(key)
37 |
38 | print("\nVARIABLES THAT CONTAIN SOME FINITE, SOME INF/NAN VALUES:")
39 | for key in some_infnan: print(key)
40 |
41 | if not all_infnan and not some_infnan:
42 | print("CHECK PASSED: checkpoint contains no inf/NaN values")
43 | else:
44 | print("CHECK FAILED: checkpoint contains some inf/NaN values")
45 |
--------------------------------------------------------------------------------
/models/pointer_generator/make_asumm_pg_vocab.py:
--------------------------------------------------------------------------------
1 | """
2 | Module to make the vocab file for the pointer-generator network. Code modeled on make_datafiles.py
3 | in the cnn-dailymail processing repository.
4 |
5 | File will have format
6 | word1 n_occurrences
7 | word2 n_occurrences
8 | ...
9 | wordn n_occurrences
10 |
11 | To run:
12 | python make_asumm_pg_vocab.py --vocab_path=bioasq_abs2summ_vocab --data_file=../../data_processing/data/bioasq_abs2summ_training_data_without_question.json
13 | """
14 |
15 | from collections import Counter
16 | import json
17 | import argparse
18 |
19 | from spacy.tokenizer import Tokenizer
20 | from spacy.lang.en import English
21 |
22 | def get_args():
23 | """
24 | Get command line arguments
25 | """
26 | parser = argparse.ArgumentParser(description="Arguments for building the pointer-generator vocab file")
27 | parser.add_argument("--vocab_path",
28 | dest="vocab_path",
29 | help="Path to create vocab file")
30 | parser.add_argument("--data_file",
31 | dest="data_file",
32 | help="Path to load data to make vocab")
33 | return parser
34 |
35 |
36 | def make_vocab(vocab_counter, vocab_file, VOCAB_SIZE, article, abstract, tokenizer):
37 | """
38 | For each article/summary pair, tokenize on whitespace and update the running word count (the vocab file itself is written in load_data)
39 | """
40 | art_tokens = [t.text.strip() for t in tokenizer(article)]
41 | abs_tokens = [t.text.strip() for t in tokenizer(abstract)]
42 | tokens = art_tokens + abs_tokens
43 | tokens = [t for t in tokens if t != "" and t != "<s>" and t != "</s>"] # remove empty strings and the <s>/</s> sentence tags
44 | vocab_counter.update(tokens)
45 |
46 |
47 | def load_data(data_file, vocab_path):
48 | """
49 | Load data and tokenize each file
50 | """
51 | VOCAB_SIZE = 200000
52 | with open(data_file, "r", encoding="utf-8") as f:
53 | training_data = json.load(f)
54 | print("Writing vocab file...")
55 | vocab_file = open(vocab_path, 'w', encoding="utf-8")
56 | vocab_counter = Counter()
57 | # Initiate spacy tokenizer without bells and whistles
58 | nlp = English()
59 | tokenizer = Tokenizer(nlp.vocab)
60 |
61 | for url, topic in training_data.items():
62 | article = topic['articles']
63 | abstract = topic['summary']
64 | make_vocab(vocab_counter, vocab_file, VOCAB_SIZE, article, abstract, tokenizer)
65 | # After updating counter, write counts/words
66 | print("20 most common words:", vocab_counter.most_common(20))
67 | for word, count in vocab_counter.most_common(VOCAB_SIZE):
68 | vocab_file.write(word + ' ' + str(count) + '\n')
69 |
70 | vocab_file.close()
71 | print("Finished writing vocab file")
72 |
73 |
74 | if __name__ == "__main__":
75 | args = get_args().parse_args()
76 | load_data(args.data_file, args.vocab_path)
77 |
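78 | # Example sketch (illustrative only): spot-check the vocab file written by load_data()
79 | # above. Each line is expected to hold "word count", as described in the module docstring.
80 | def check_vocab(vocab_path, n=10):
81 |     """Print the first n vocab entries, verifying the 'word count' line format."""
82 |     with open(vocab_path, "r", encoding="utf-8") as f:
83 |         for i, line in enumerate(f):
84 |             if i >= n:
85 |                 break
86 |             word, count = line.rstrip("\n").rsplit(" ", 1)
87 |             print(word, int(count))
88 |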
--------------------------------------------------------------------------------
/models/pointer_generator/requirements.txt:
--------------------------------------------------------------------------------
1 | pyrouge
2 | spacy
3 |
--------------------------------------------------------------------------------
/models/pointer_generator/run_chiqa.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH --output=/data/saveryme/asumm/models/pointer-generator/slurm_logs/slurm_%j.out
3 | #SBATCH --error=/data/saveryme/asumm/models/pointer-generator/slurm_logs/slurm_%j.error
4 | #SBATCH --job-name=chiq_eval
5 | #SBATCH --partition=gpu
6 | #SBATCH --gres=gpu:v100x:1
7 | #SBATCH --mem=10g
8 | #SBATCH --cpus-per-task=12
9 | #SBATCH --time=2-00:00:00
10 |
11 | # The with_question setting can also be included in the for loop below
12 | for q in without_question
13 | do
14 | experiment=bioasq_abs2summ_${q}
15 | if [ ${q} == "with_question" ]; then
16 | q_driven=True
17 | else
18 | q_driven=False
19 | fi
20 | for summ_task in page2answer section2answer
21 | do
22 | for summ_type in single_abstractive single_extractive
23 | do
24 | data=${summ_task}_${summ_type}_summ.json
25 | input_data=../../data_processing/data/${data}
26 | predict_file=pointergen_chiqa_bioasq_abs2summ_${q}_${summ_task}_${summ_type}.json
27 | echo ${q} ${q_driven}
28 | echo $input_data
29 | echo $predict_file
30 | python run_medsumm.py \
31 | --mode=decode \
32 | --data_path=${input_data} \
33 | --vocab_path=./bioasq_abs2summ_vocab \
34 | --exp_name=$experiment \
35 | --single_pass=True \
36 | --eval_type=medsumm \
37 | --generated_data_file=../../evaluation/data/pointer_generator/chiqa_eval/${predict_file} \
38 | --tag_sentences=True \
39 | --question_driven=${q_driven}
40 | done
41 | done
42 | done
43 |
--------------------------------------------------------------------------------
/models/pointer_generator/run_medsumm.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | # Modifications Copyright 2017 Abigail See
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 |
17 | """
18 | This is the top-level file to train, evaluate or test your summarization model
19 |
20 | This script has been converted from the original used for the "Get to the Point" paper so that
21 | it can be applied to the summarization of consumer health answers
22 | """
23 |
24 | import sys
25 | import time
26 | import os
27 | import tensorflow as tf
28 | import numpy as np
29 | from collections import namedtuple
30 | from data import Vocab
31 | from batcher import Batcher
32 | from model import SummarizationModel
33 | from decode import BeamSearchDecoder
34 | import util
35 | from tensorflow.python import debug as tf_debug
36 |
37 | FLAGS = tf.app.flags.FLAGS
38 |
39 | # Define whether to use answer summarization processing pipeline or original pipeline for CNN data:
40 | tf.app.flags.DEFINE_boolean('medsumm', True, "If True, use the answer summarization data processing pipeline rather than the original pipeline for CNN data")
41 | tf.app.flags.DEFINE_boolean('question_driven', True, "Prepend the question to the source text for question-driven summarization. This option only applies at inference; the training datasets are already prepared accordingly")
42 | tf.app.flags.DEFINE_boolean('tag_sentences', False, "Tag sentences in the summaries when processing data for inference. This only lets the code return the reference summaries properly.")
43 |
44 | # Where to find data
45 | tf.app.flags.DEFINE_string('data_path', '', 'Path expression to tf.Example datafiles. Can include wildcards to access multiple datafiles.')
46 | tf.app.flags.DEFINE_string('vocab_path', '', 'Path expression to text vocabulary file.')
47 |
48 | # Important settings
49 | tf.app.flags.DEFINE_string('mode', 'train', 'must be one of train/eval/decode')
50 | tf.app.flags.DEFINE_boolean('single_pass', False, 'For decode mode only. If True, run eval on the full dataset using a fixed checkpoint, i.e. take the current checkpoint, and use it to produce one summary for each example in the dataset, write the summaries to file and then get ROUGE scores for the whole dataset. If False (default), run concurrent decoding, i.e. repeatedly load latest checkpoint, use it to produce summaries for randomly-chosen examples and log the results to screen, indefinitely.')
51 | # If decode is true and single_pass is true, define whether to use original rouge evaluation code, or
52 | # write data for later answer summ evaluation
53 | tf.app.flags.DEFINE_string('eval_type', 'medsumm', 'must be medsumm or cnn')
54 | tf.app.flags.DEFINE_string('generated_data_file', '/data/saveryme/asumm/asumm_data/generated_data.json', 'JSON path for generated data and reference summaries')
55 |
56 | # Where to save output
57 | tf.app.flags.DEFINE_string('log_root', '', 'Root directory for all logging.')
58 | tf.app.flags.DEFINE_string('exp_name', '', 'Name for experiment. Logs will be saved in a directory with this name, under log_root.')
59 |
60 | # Hyperparameters
61 | tf.app.flags.DEFINE_integer('hidden_dim', 256, 'dimension of RNN hidden states')
62 | tf.app.flags.DEFINE_integer('emb_dim', 128, 'dimension of word embeddings')
63 | tf.app.flags.DEFINE_integer('batch_size', 16, 'minibatch size')
64 | tf.app.flags.DEFINE_integer('max_enc_steps', 400, 'max timesteps of encoder (max source text tokens)')
65 | tf.app.flags.DEFINE_integer('max_dec_steps', 100, 'max timesteps of decoder (max summary tokens)')
66 | tf.app.flags.DEFINE_integer('beam_size', 4, 'beam size for beam search decoding.')
67 | tf.app.flags.DEFINE_integer('min_dec_steps', 35, 'Minimum sequence length of generated summary. Applies only for beam search decoding mode')
68 | tf.app.flags.DEFINE_integer('vocab_size', 50000, 'Size of vocabulary. These will be read from the vocabulary file in order. If the vocabulary file contains fewer words than this number, or if this number is set to 0, will take all words in the vocabulary file.')
69 | tf.app.flags.DEFINE_float('lr', 0.15, 'learning rate')
70 | tf.app.flags.DEFINE_float('adagrad_init_acc', 0.1, 'initial accumulator value for Adagrad')
71 | tf.app.flags.DEFINE_float('rand_unif_init_mag', 0.02, 'magnitude for lstm cells random uniform initialization')
72 | tf.app.flags.DEFINE_float('trunc_norm_init_std', 1e-4, 'std of trunc norm init, used for initializing everything else')
73 | tf.app.flags.DEFINE_float('max_grad_norm', 2.0, 'for gradient clipping')
74 |
75 | # Pointer-generator or baseline model
76 | tf.app.flags.DEFINE_boolean('pointer_gen', True, 'If True, use pointer-generator model. If False, use baseline model.')
77 |
78 | # Coverage hyperparameters
79 | tf.app.flags.DEFINE_boolean('coverage', False, 'Use coverage mechanism. Note, the experiments reported in the ACL paper train WITHOUT coverage until converged, and then train for a short phase WITH coverage afterwards. i.e. to reproduce the results in the ACL paper, turn this off for most of training then turn on for a short phase at the end.')
80 | tf.app.flags.DEFINE_float('cov_loss_wt', 1.0, 'Weight of coverage loss (lambda in the paper). If zero, then no incentive to minimize coverage loss.')
81 |
82 | # Utility flags, for restoring and changing checkpoints
83 | tf.app.flags.DEFINE_boolean('convert_to_coverage_model', False, 'Convert a non-coverage model to a coverage model. Turn this on and run in train mode. Your current training model will be copied to a new version (same name with _cov_init appended) that will be ready to run with coverage flag turned on, for the coverage training stage.')
84 | tf.app.flags.DEFINE_boolean('restore_best_model', False, 'Restore the best model in the eval/ dir and save it in the train/ dir, ready to be used for further training. Useful for early stopping, or if your training checkpoint has become corrupted with e.g. NaN values.')
85 |
86 | # Debugging. See https://www.tensorflow.org/programmers_guide/debugger
87 | tf.app.flags.DEFINE_boolean('debug', False, "Run in tensorflow's debug mode (watches for NaN/inf values)")
88 |
89 |
90 |
91 | def calc_running_avg_loss(loss, running_avg_loss, summary_writer, step, decay=0.99):
92 | """Calculate the running average loss via exponential decay.
93 | This is used to implement early stopping w.r.t. a smoother loss curve than the raw loss curve.
94 |
95 | Args:
96 | loss: loss on the most recent eval step
97 | running_avg_loss: running_avg_loss so far
98 | summary_writer: FileWriter object to write for tensorboard
99 | step: training iteration step
100 | decay: rate of exponential decay, a float between 0 and 1. Larger is smoother.
101 |
102 | Returns:
103 | running_avg_loss: new running average loss
104 | """
105 | if running_avg_loss == 0: # on the first iteration just take the loss
106 | running_avg_loss = loss
107 | else:
108 | running_avg_loss = running_avg_loss * decay + (1 - decay) * loss
109 | running_avg_loss = min(running_avg_loss, 12) # clip
110 | loss_sum = tf.Summary()
111 | tag_name = 'running_avg_loss/decay=%f' % (decay)
112 | loss_sum.value.add(tag=tag_name, simple_value=running_avg_loss)
113 | summary_writer.add_summary(loss_sum, step)
114 | tf.logging.info('running_avg_loss: %f', running_avg_loss)
115 | return running_avg_loss
116 |
117 |
118 | def restore_best_model():
119 | """Load bestmodel file from eval directory, add variables for adagrad, and save to train directory"""
120 | tf.logging.info("Restoring bestmodel for training...")
121 |
122 | # Initialize all vars in the model
123 | sess = tf.Session(config=util.get_config())
124 | print("Initializing all variables...")
125 | sess.run(tf.initialize_all_variables())
126 |
127 | # Restore the best model from eval dir
128 | saver = tf.train.Saver([v for v in tf.all_variables() if "Adagrad" not in v.name])
129 | print("Restoring all non-adagrad variables from best model in eval dir...")
130 | curr_ckpt = util.load_ckpt(saver, sess, "eval")
131 | print ("Restored %s." % curr_ckpt)
132 |
133 | # Save this model to train dir and quit
134 | new_model_name = curr_ckpt.split("/")[-1].replace("bestmodel", "model")
135 | new_fname = os.path.join(FLAGS.log_root, "train", new_model_name)
136 | print ("Saving model to %s..." % (new_fname))
137 | new_saver = tf.train.Saver() # this saver saves all variables that now exist, including Adagrad variables
138 | new_saver.save(sess, new_fname)
139 | print ("Saved.")
140 | exit()
141 |
142 |
143 | def convert_to_coverage_model():
144 | """Load non-coverage checkpoint, add initialized extra variables for coverage, and save as new checkpoint"""
145 | tf.logging.info("converting non-coverage model to coverage model..")
146 |
147 | # initialize an entire coverage model from scratch
148 | sess = tf.Session(config=util.get_config())
149 | print("initializing everything...")
150 | sess.run(tf.global_variables_initializer())
151 |
152 | # load all non-coverage weights from checkpoint
153 | saver = tf.train.Saver([v for v in tf.global_variables() if "coverage" not in v.name and "Adagrad" not in v.name])
154 | print("restoring non-coverage variables...")
155 | curr_ckpt = util.load_ckpt(saver, sess)
156 | print("restored.")
157 |
158 | # save this model and quit
159 | new_fname = curr_ckpt + '_cov_init'
160 | print("saving model to %s..." % (new_fname))
161 | new_saver = tf.train.Saver() # this one will save all variables that now exist
162 | new_saver.save(sess, new_fname)
163 | print("saved.")
164 | exit()
165 |
166 |
167 | def setup_training(model, batcher):
168 | """Does setup before starting training (run_training)"""
169 | train_dir = os.path.join(FLAGS.log_root, "train")
170 | if not os.path.exists(train_dir): os.makedirs(train_dir)
171 |
172 | model.build_graph() # build the graph
173 | if FLAGS.convert_to_coverage_model:
174 | assert FLAGS.coverage, "To convert your non-coverage model to a coverage model, run with convert_to_coverage_model=True and coverage=True"
175 | convert_to_coverage_model()
176 | if FLAGS.restore_best_model:
177 | restore_best_model()
178 | saver = tf.train.Saver(max_to_keep=3) # keep 3 checkpoints at a time
179 |
180 | sv = tf.train.Supervisor(logdir=train_dir,
181 | is_chief=True,
182 | saver=saver,
183 | summary_op=None,
184 | save_summaries_secs=60, # save summaries for tensorboard every 60 secs
185 | save_model_secs=60, # checkpoint every 60 secs
186 | global_step=model.global_step)
187 | summary_writer = sv.summary_writer
188 | tf.logging.info("Preparing or waiting for session...")
189 | sess_context_manager = sv.prepare_or_wait_for_session(config=util.get_config())
190 | tf.logging.info("Created session.")
191 | try:
192 | run_training(model, batcher, sess_context_manager, sv, summary_writer) # this is an infinite loop until interrupted
193 | except KeyboardInterrupt:
194 | tf.logging.info("Caught keyboard interrupt on worker. Stopping supervisor...")
195 | sv.stop()
196 |
197 |
198 | def run_training(model, batcher, sess_context_manager, sv, summary_writer):
199 | """Repeatedly runs training iterations, logging loss to screen and writing summaries"""
200 | tf.logging.info("starting run_training")
201 | with sess_context_manager as sess:
202 | if FLAGS.debug: # start the tensorflow debugger
203 | sess = tf_debug.LocalCLIDebugWrapperSession(sess)
204 | sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)
205 | while True: # repeats until interrupted
206 | batch = batcher.next_batch()
207 | #tf.logging.info(vars(batch))
208 |
209 | tf.logging.info('running training step...')
210 | t0=time.time()
211 | results = model.run_train_step(sess, batch)
212 | t1=time.time()
213 | tf.logging.info('seconds for training step: %.3f', t1-t0)
214 |
215 | loss = results['loss']
216 | tf.logging.info('loss: %f', loss) # print the loss to screen
217 |
218 | if not np.isfinite(loss):
219 | raise Exception("Loss is not finite. Stopping.")
220 |
221 | if FLAGS.coverage:
222 | coverage_loss = results['coverage_loss']
223 | tf.logging.info("coverage_loss: %f", coverage_loss) # print the coverage loss to screen
224 |
225 | # get the summaries and iteration number so we can write summaries to tensorboard
226 | summaries = results['summaries'] # we will write these summaries to tensorboard using summary_writer
227 | train_step = results['global_step'] # we need this to update our running average loss
228 |
229 | tf.logging.info("\nTrain step: {}\n".format(train_step))
230 | summary_writer.add_summary(summaries, train_step) # write the summaries
231 | if train_step % 100 == 0: # flush the summary writer every so often
232 | summary_writer.flush()
233 | if train_step == 10000: # hard stop after 10,000 training steps
234 | break
235 |
236 |
237 | def run_eval(model, batcher, vocab):
238 | """Repeatedly runs eval iterations, logging to screen and writing summaries. Saves the model with the best loss seen so far."""
239 | model.build_graph() # build the graph
240 | saver = tf.train.Saver(max_to_keep=3) # we will keep 3 best checkpoints at a time
241 | sess = tf.Session(config=util.get_config())
242 | eval_dir = os.path.join(FLAGS.log_root, "eval") # make a subdir of the root dir for eval data
243 | bestmodel_save_path = os.path.join(eval_dir, 'bestmodel') # this is where checkpoints of best models are saved
244 | summary_writer = tf.summary.FileWriter(eval_dir)
245 | running_avg_loss = 0 # the eval job keeps a smoother, running average loss to tell it when to implement early stopping
246 | best_loss = None # will hold the best loss achieved so far
247 |
248 | while True:
249 | _ = util.load_ckpt(saver, sess) # load a new checkpoint
250 | batch = batcher.next_batch() # get the next batch
251 |
252 | # run eval on the batch
253 | t0=time.time()
254 | results = model.run_eval_step(sess, batch)
255 | t1=time.time()
256 | tf.logging.info('seconds for batch: %.2f', t1-t0)
257 |
258 | # print the loss and coverage loss to screen
259 | loss = results['loss']
260 | tf.logging.info('loss: %f', loss)
261 | if FLAGS.coverage:
262 | coverage_loss = results['coverage_loss']
263 | tf.logging.info("coverage_loss: %f", coverage_loss)
264 |
265 | # add summaries
266 | summaries = results['summaries']
267 | train_step = results['global_step']
268 | summary_writer.add_summary(summaries, train_step)
269 |
270 | running_avg_loss = calc_running_avg_loss(np.asscalar(loss), running_avg_loss, summary_writer, train_step)
271 |
272 | # If running_avg_loss is best so far, save this checkpoint (early stopping).
273 | # These checkpoints will appear as bestmodel- in the eval dir
274 | if best_loss is None or running_avg_loss < best_loss:
275 | tf.logging.info('Found new best model with %.3f running_avg_loss. Saving to %s', running_avg_loss, bestmodel_save_path)
276 | saver.save(sess, bestmodel_save_path, global_step=train_step, latest_filename='checkpoint_best')
277 | best_loss = running_avg_loss
278 |
279 | # flush the summary writer every so often
280 | if train_step % 100 == 0:
281 | summary_writer.flush()
282 |
283 |
284 | def main(unused_argv):
285 | #if len(unused_argv) != 1: # prints a message if you've entered flags incorrectly
286 | # raise Exception("Problem with flags: %s" % unused_argv)
287 |
288 | tf.logging.set_verbosity(tf.logging.INFO) # choose what level of logging you want
289 | tf.logging.info('Starting seq2seq_attention in %s mode...', (FLAGS.mode))
290 |
291 | # Change log_root to FLAGS.log_root/FLAGS.exp_name and create the dir if necessary
292 | FLAGS.log_root = os.path.join(FLAGS.log_root, FLAGS.exp_name)
293 | if not os.path.exists(FLAGS.log_root):
294 | if FLAGS.mode=="train":
295 | os.makedirs(FLAGS.log_root)
296 | else:
297 | raise Exception("Logdir %s doesn't exist. Run in train mode to create it." % (FLAGS.log_root))
298 |
299 | vocab = Vocab(FLAGS.vocab_path, FLAGS.vocab_size) # create a vocabulary
300 |
301 | # If in decode mode, set batch_size = beam_size
302 | # Reason: in decode mode, we decode one example at a time.
303 | # On each step, we have beam_size-many hypotheses in the beam, so we need to make a batch of these hypotheses.
304 | if FLAGS.mode == 'decode':
305 | FLAGS.batch_size = FLAGS.beam_size
306 |
307 | # If single_pass=True, check we're in decode mode
308 | if FLAGS.single_pass and FLAGS.mode!='decode':
309 | raise Exception("The single_pass flag should only be True in decode mode")
310 |
311 | # Make a namedtuple hps, containing the values of the hyperparameters that the model needs
312 | hparam_list = ['mode', 'lr', 'adagrad_init_acc', 'rand_unif_init_mag', 'trunc_norm_init_std', 'max_grad_norm', 'hidden_dim', 'emb_dim', 'batch_size', 'max_dec_steps', 'max_enc_steps', 'coverage', 'cov_loss_wt', 'pointer_gen']
313 | hps_dict = {}
314 | for key,val in FLAGS.__flags.items(): # for each flag
315 | if key in hparam_list: # if it's in the list
316 | hps_dict[key] = val # add it to the dict
317 | hps = namedtuple("HParams", hps_dict.keys())(**hps_dict)
318 |
319 | # Create a batcher object that will create minibatches of data
320 | batcher = Batcher(FLAGS.data_path, vocab, hps, single_pass=FLAGS.single_pass, medsumm=FLAGS.medsumm, question_driven=FLAGS.question_driven, tag_sentences=FLAGS.tag_sentences)
321 |
322 | tf.set_random_seed(111) # a seed value for randomness
323 |
324 | if hps.mode == 'train':
325 | print("creating model...")
326 | model = SummarizationModel(hps, vocab)
327 | setup_training(model, batcher)
328 | elif hps.mode == 'eval':
329 | model = SummarizationModel(hps, vocab)
330 | run_eval(model, batcher, vocab)
331 | elif hps.mode == 'decode':
332 | decode_model_hps = hps # This will be the hyperparameters for the decoder model
333 | decode_model_hps = hps._replace(max_dec_steps=1) # The model is configured with max_dec_steps=1 because we only ever run one step of the decoder at a time (to do beam search). Note that the batcher is initialized with max_dec_steps equal to e.g. 100 because the batches need to contain the full summaries
334 | model = SummarizationModel(decode_model_hps, vocab)
335 | decoder = BeamSearchDecoder(model, batcher, vocab)
336 | decoder.decode() # decode indefinitely (unless single_pass=True, in which case decode the dataset exactly once)
337 | else:
338 | raise ValueError("The 'mode' flag must be one of train/eval/decode")
339 |
340 | if __name__ == '__main__':
341 | tf.app.run()
342 |
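343 | # Worked example (illustrative only) of the update in calc_running_avg_loss above,
344 | # with the default decay=0.99. Starting from running_avg_loss=0, the first eval loss
345 | # is taken as-is, and each subsequent loss moves the average by only 1%:
346 | #   step 1: loss=6.0 -> running_avg = 6.0
347 | #   step 2: loss=4.0 -> running_avg = 0.99*6.0 + 0.01*4.0 = 5.98
348 | #   step 3: loss=4.0 -> running_avg = 0.99*5.98 + 0.01*4.0 = 5.9602
349 | # The average is then clipped at 12 before being written to tensorboard.
350 |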
--------------------------------------------------------------------------------
/models/pointer_generator/submit_sbatch_eval.sh:
--------------------------------------------------------------------------------
1 | for q in with_question without_question
2 | do
3 | echo $q
4 | sbatch --job-name=${q}_eval --export=QDRIVEN=$q eval_medsumm.sh ${q} # pass the question setting as $1, which eval_medsumm.sh reads
5 | done
6 |
--------------------------------------------------------------------------------
/models/pointer_generator/train_medsumm.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH --output=/data/saveryme/asumm/models/pointer-generator/slurm_logs/slurm_%j.out
3 | #SBATCH --error=/data/saveryme/asumm/models/pointer-generator/slurm_logs/slurm_%j.error
4 | #SBATCH --partition=gpu
5 | #SBATCH --gres=gpu:p100:1
6 | #SBATCH --mem=40g
7 | #SBATCH --cpus-per-task=12
8 | #SBATCH --time=2-00:00:00
9 |
10 | training_data=../../data_processing/data/bioasq_abs2summ_training_data_${1}.json
11 | # With question:
12 | if [ "$1" == "with_question" ]; then
13 | echo $1
14 | mkdir -p bioasq_abs2summ_with_question
15 | experiment=bioasq_abs2summ_with_question
16 | rm -r ${experiment}/*
17 | python -u run_medsumm.py --mode=train --data_path=$training_data --vocab_path=./bioasq_abs2summ_vocab --exp_name=$experiment
18 | fi
19 |
20 | # And without question:
21 | if [ "$1" == "without_question" ]; then
22 | echo $1
23 | mkdir -p bioasq_abs2summ_without_question
24 | experiment=bioasq_abs2summ_without_question
25 | rm -r ${experiment}/*
26 | python -u run_medsumm.py --mode=train --data_path=$training_data --vocab_path=./bioasq_abs2summ_vocab --exp_name=$experiment
27 | fi
28 |
--------------------------------------------------------------------------------
/models/pointer_generator/util.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | # Modifications Copyright 2017 Abigail See
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 |
17 | """This file contains some utility functions"""
18 |
19 | import tensorflow as tf
20 | import time
21 | import os
22 | FLAGS = tf.app.flags.FLAGS
23 |
24 | def get_config():
25 | """Returns config for tf.session"""
26 | config = tf.ConfigProto(allow_soft_placement=True)
27 | config.gpu_options.allow_growth=True
28 | return config
29 |
30 | def load_ckpt(saver, sess, ckpt_dir="train"):
31 | """Load checkpoint from the ckpt_dir (if unspecified, this is train dir) and restore it to saver and sess, waiting 10 secs in the case of failure. Also returns checkpoint name."""
32 | while True:
33 | try:
34 | latest_filename = "checkpoint_best" if ckpt_dir=="eval" else None
35 | full_ckpt_dir = os.path.join(FLAGS.log_root, ckpt_dir) # keep ckpt_dir unchanged so retries rebuild the same path
36 | ckpt_state = tf.train.get_checkpoint_state(full_ckpt_dir, latest_filename=latest_filename)
37 | tf.logging.info('Loading checkpoint %s', ckpt_state.model_checkpoint_path)
38 | saver.restore(sess, ckpt_state.model_checkpoint_path)
39 | return ckpt_state.model_checkpoint_path
40 | except:
41 | tf.logging.info("Failed to load checkpoint from %s. Sleeping for %i secs...", ckpt_dir, 10)
42 | time.sleep(10)
43 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | lxml
2 | numpy
3 | spacy
4 | pandas
5 | scikit-learn
6 | py-rouge
7 | nltk
8 | tqdm
9 | openpyxl
10 |
--------------------------------------------------------------------------------