├── CODEOWNERS ├── CODE_OF_CONDUCT.md ├── LICENSE.txt ├── README.md ├── SECURITY.md ├── extractors ├── README.md ├── add_rouge.py ├── postprocess_relreg.py ├── predict_qmsum_bart.sh ├── prep_data_relreg.py ├── run_relreg.sh ├── run_relreg_tt.sh ├── test_relreg_tt.py ├── train_qmsum_bart.sh └── train_relreg_tt.py ├── multiencoder ├── README.md ├── convert_qmsum.py ├── dataset.py ├── models.py ├── report_training_runs.py ├── scripts │ ├── predict_qmsum_16_1024_strided_val.sh │ ├── predict_qmsum_16_256_strided_val.sh │ ├── predict_qmsum_16_512_strided_val.sh │ ├── predict_qmsum_32_256_strided_val.sh │ ├── predict_qmsum_32_512_nostrided_val.sh │ ├── predict_qmsum_32_512_strided_test.sh │ ├── predict_qmsum_32_512_strided_val.sh │ ├── predict_qmsum_32_512_strided_wikisum_test.sh │ ├── predict_qmsum_32_512_strided_wikisum_val.sh │ ├── predict_qmsum_4_1024_strided_val.sh │ ├── predict_qmsum_64_256_strided_val.sh │ ├── predict_qmsum_8_1024_strided_val.sh │ ├── predict_qmsum_8_512_strided_val.sh │ ├── predict_val.sh │ ├── report_rouge_test.sh │ ├── report_rouge_val.sh │ ├── select_checkpoints.sh │ ├── train_qmsum_16_1024_strided.sh │ ├── train_qmsum_16_256_strided.sh │ ├── train_qmsum_16_256_strided_catchup.sh │ ├── train_qmsum_16_512_strided.sh │ ├── train_qmsum_32_256_strided.sh │ ├── train_qmsum_32_512_nostrided.sh │ ├── train_qmsum_32_512_strided.sh │ ├── train_qmsum_32_512_strided_wikisum.sh │ ├── train_qmsum_4_1024_strided.sh │ ├── train_qmsum_64_256_strided.sh │ ├── train_qmsum_8_1024_strided.sh │ └── train_qmsum_8_512_strided.sh ├── select_checkpoints.py ├── test │ ├── __init__.py │ ├── data │ │ └── dataset.jsonl │ ├── test_dataset.py │ └── test_model.py ├── train.py └── transformers │ ├── LICENSE │ └── changes ├── preprocessing ├── README.md └── prep_qmsum.py ├── requirements.txt └── rouge ├── __init__.py └── report_rouge.py /CODEOWNERS: -------------------------------------------------------------------------------- 1 | # Comment line immediately above ownership line is reserved for related gus information. Please be careful while editing. 2 | #ECCN:Open Source 3 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Salesforce Open Source Community Code of Conduct 2 | 3 | ## About the Code of Conduct 4 | 5 | Equality is a core value at Salesforce. We believe a diverse and inclusive 6 | community fosters innovation and creativity, and are committed to building a 7 | culture where everyone feels included. 8 | 9 | Salesforce open-source projects are committed to providing a friendly, safe, and 10 | welcoming environment for all, regardless of gender identity and expression, 11 | sexual orientation, disability, physical appearance, body size, ethnicity, nationality, 12 | race, age, religion, level of experience, education, socioeconomic status, or 13 | other similar personal characteristics. 14 | 15 | The goal of this code of conduct is to specify a baseline standard of behavior so 16 | that people with different social values and communication styles can work 17 | together effectively, productively, and respectfully in our open source community. 18 | It also establishes a mechanism for reporting issues and resolving conflicts. 
19 | 20 | All questions and reports of abusive, harassing, or otherwise unacceptable behavior 21 | in a Salesforce open-source project may be reported by contacting the Salesforce 22 | Open Source Conduct Committee at ossconduct@salesforce.com. 23 | 24 | ## Our Pledge 25 | 26 | In the interest of fostering an open and welcoming environment, we as 27 | contributors and maintainers pledge to making participation in our project and 28 | our community a harassment-free experience for everyone, regardless of gender 29 | identity and expression, sexual orientation, disability, physical appearance, 30 | body size, ethnicity, nationality, race, age, religion, level of experience, education, 31 | socioeconomic status, or other similar personal characteristics. 32 | 33 | ## Our Standards 34 | 35 | Examples of behavior that contributes to creating a positive environment 36 | include: 37 | 38 | * Using welcoming and inclusive language 39 | * Being respectful of differing viewpoints and experiences 40 | * Gracefully accepting constructive criticism 41 | * Focusing on what is best for the community 42 | * Showing empathy toward other community members 43 | 44 | Examples of unacceptable behavior by participants include: 45 | 46 | * The use of sexualized language or imagery and unwelcome sexual attention or 47 | advances 48 | * Personal attacks, insulting/derogatory comments, or trolling 49 | * Public or private harassment 50 | * Publishing, or threatening to publish, others' private information—such as 51 | a physical or electronic address—without explicit permission 52 | * Other conduct which could reasonably be considered inappropriate in a 53 | professional setting 54 | * Advocating for or encouraging any of the above behaviors 55 | 56 | ## Our Responsibilities 57 | 58 | Project maintainers are responsible for clarifying the standards of acceptable 59 | behavior and are expected to take appropriate and fair corrective action in 60 | response to any instances of unacceptable behavior. 61 | 62 | Project maintainers have the right and responsibility to remove, edit, or 63 | reject comments, commits, code, wiki edits, issues, and other contributions 64 | that are not aligned with this Code of Conduct, or to ban temporarily or 65 | permanently any contributor for other behaviors that they deem inappropriate, 66 | threatening, offensive, or harmful. 67 | 68 | ## Scope 69 | 70 | This Code of Conduct applies both within project spaces and in public spaces 71 | when an individual is representing the project or its community. Examples of 72 | representing a project or community include using an official project email 73 | address, posting via an official social media account, or acting as an appointed 74 | representative at an online or offline event. Representation of a project may be 75 | further defined and clarified by project maintainers. 76 | 77 | ## Enforcement 78 | 79 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 80 | reported by contacting the Salesforce Open Source Conduct Committee 81 | at ossconduct@salesforce.com. All complaints will be reviewed and investigated 82 | and will result in a response that is deemed necessary and appropriate to the 83 | circumstances. The committee is obligated to maintain confidentiality with 84 | regard to the reporter of an incident. Further details of specific enforcement 85 | policies may be posted separately. 
86 | 87 | Project maintainers who do not follow or enforce the Code of Conduct in good 88 | faith may face temporary or permanent repercussions as determined by other 89 | members of the project's leadership and the Salesforce Open Source Conduct 90 | Committee. 91 | 92 | ## Attribution 93 | 94 | This Code of Conduct is adapted from the [Contributor Covenant][contributor-covenant-home], 95 | version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html. 96 | It includes adaptions and additions from [Go Community Code of Conduct][golang-coc], 97 | [CNCF Code of Conduct][cncf-coc], and [Microsoft Open Source Code of Conduct][microsoft-coc]. 98 | 99 | This Code of Conduct is licensed under the [Creative Commons Attribution 3.0 License][cc-by-3-us]. 100 | 101 | [contributor-covenant-home]: https://www.contributor-covenant.org (https://www.contributor-covenant.org/) 102 | [golang-coc]: https://golang.org/conduct 103 | [cncf-coc]: https://github.com/cncf/foundation/blob/master/code-of-conduct.md 104 | [microsoft-coc]: https://opensource.microsoft.com/codeofconduct/ 105 | [cc-by-3-us]: https://creativecommons.org/licenses/by/3.0/us/ -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2021, Salesforce.com, Inc. 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 9 | 10 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 11 | 12 | 3. Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Exploring Neural Models for Query-Focused Summarization 2 | 3 | This is the official code repository for [Exploring Neural Models for Query-Focused Summarization](https://arxiv.org/abs/2112.07637) 4 | by [Jesse Vig*](https://twitter.com/jesse_vig), [Alexander R. 
Fabbri*](https://twitter.com/alexfabbri4), 5 | [Wojciech Kryściński*](https://twitter.com/iam_wkr), [Chien-Sheng Wu](https://twitter.com/jasonwu0731), and 6 | [Wenhao Liu](https://twitter.com/owenhaoliu) (*equal contribution). 7 | 8 | We present code and instructions for reproducing the paper experiments and running the models against your own datasets. 9 | 10 | ## Table of contents 11 | - [Introduction](#introduction) 12 | - [Two-stage models](#two-stage-models) 13 | - [Segment Encoder](#segment-encoder) 14 | - [Citation](#citation) 15 | - [License](#license) 16 | 17 | ## Introduction 18 | Query-focused summarization (QFS) aims to produce summaries that answer particular questions of interest, enabling greater user control and personalization. 19 | In [our paper](https://arxiv.org/abs/2112.07637) we conduct a systematic exploration of neural approaches to QFS, considering two general classes of methods: two-stage extractive-abstractive solutions and end-to-end models. 20 | Within those categories, we investigate existing methods and present two model extensions that achieve state-of-the-art performance on the QMSum dataset by a margin of up to 3.38 ROUGE-1, 3.72 ROUGE-2, and 3.28 ROUGE-L. 21 | 22 | ## Two-stage models 23 | 24 | Two-step approaches consist of an *extractor* model, which extracts parts of the source document relevant to the input query, and an *abstractor* model, 25 | which synthesizes the extracted segments into a final summary. 26 | 27 | See [extractors](extractors/README.md) directory for instructions and code for training and evaluating two-stage models. 28 | 29 | ## Segment Encoder 30 | 31 | The Segment Encoder is an end-to-end model that uses sparse local attention to achieve SOTA ROUGE scores on the QMSum dataset. 32 | 33 | To [replicate](multiencoder/README.md#reproducing-qmsum-experiments) the QMSum experiments, or train and evaluate Segment Encoder 34 | [on your own dataset](multiencoder/README.md#running-on-your-own-datasets), see the 35 | [multiencoder](multiencoder/README.md) directory. 36 | 37 | ## Citation 38 | 39 | When referencing this repository, please cite [this paper](https://arxiv.org/abs/2112.07637): 40 | 41 | ```bibtex 42 | @misc{vig-etal-2021-exploring, 43 | title={Exploring Neural Models for Query-Focused Summarization}, 44 | author={Jesse Vig and Alexander R. Fabbri and Wojciech Kry{\'s}ci{\'n}ski and Chien-Sheng Wu and Wenhao Liu}, 45 | year={2021}, 46 | eprint={2112.07637}, 47 | archivePrefix={arXiv}, 48 | primaryClass={cs.CL}, 49 | url={https://arxiv.org/abs/2112.07637} 50 | } 51 | ``` 52 | 53 | ## License 54 | 55 | This repository is released under the [BSD-3 License](LICENSE.txt). 56 | 57 | 58 | 59 | 60 | 61 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | ## Security 2 | 3 | Please report any security issue to [security@salesforce.com](mailto:security@salesforce.com) 4 | as soon as it is discovered. This library limits its runtime dependencies in 5 | order to reduce the total cost of ownership as much as can be, but all consumers 6 | should remain vigilant and have their security stakeholders review all third-party 7 | products (3PP) like this one and their dependencies. 
-------------------------------------------------------------------------------- /extractors/README.md: -------------------------------------------------------------------------------- 1 | # Two-Step Models 2 | 3 | NOTE: Run all of the following steps from `/extractors`. 4 | 5 | ## Table of contents 6 | - [Installation](#installation) 7 | - [Extractor Component](#extractor-component) 8 | * [1. Preprocess QMSum](#1-preprocess-qmsum) 9 | * [2. Download RelReg training code](#2-download-relreg-training-code) 10 | * [3. Run RelReg pipeline](#3-run-relreg-pipeline) 11 | * [4. Run RelRegTT pipeline](#4-run-relregtt-pipeline) 12 | - [Abstractor Component](#abstractor-component) 13 | * [1. Train models](#1-train-models) 14 | * [2. Choose Checkpoint](#2-choose-checkpoint) 15 | * [3. Generate Predictions](#3-generate-predictions) 16 | * [4. Report rouge scores](#4-report-rouge-scores) 17 | * [5. Pretrained Models](#5-pretrained-models) 18 | 19 | 20 | ## Installation 21 | ``` 22 | pip install -r ../requirements.txt 23 | ``` 24 | 25 | ## Extractor Component 26 | 27 | ### 1. Preprocess QMSum 28 | 29 | To perform the preprocessing of QMSum necessary to reproduce the experiments, follow the instructions in the 30 | [preprocessing](../preprocessing/README.md) directory. 31 | 32 | ### 2. Download RelReg training code 33 | 34 | ``` 35 | git clone https://github.com/huggingface/transformers.git 36 | cd transformers 37 | git checkout 65659a29cf5a079842e61a63d57fa24474288998 38 | cd .. 39 | ``` 40 | 41 | ### 3. Run RelReg pipeline 42 | 43 | ``` 44 | # Data prep, training, inference, postprocessing on utterance-level input 45 | # switch to 1 for segment-level; outputs files for seq2seq training to output-relreg-utt 46 | bash run_relreg.sh 0 output-relreg-utt 47 | ``` 48 | 49 | ### 4. Run RelRegTT pipeline 50 | 51 | ``` 52 | # switch to 1 for segment-level; outputs files to output-relregTT-utt 53 | bash run_relreg_tt.sh 0 output-relregTT-utt 54 | ``` 55 | 56 | ## Abstractor Component 57 | 58 | ### 1. Train models 59 | 60 | `bash train_qmsum_bart.sh` 61 | 62 | ### 2. Choose Checkpoint 63 | 64 | Select best checkpoints from runs in the previous step, where NAME is taken from the `train_qmsum_bart.sh` script:
65 | 66 | `python ../multiencoder/select_checkpoints.py NAME` 67 | 68 | ### 3. Generate Predictions 69 | 70 | To generate predictions on the validation set: 71 | 72 | `bash predict_qmsum_bart.sh` 73 | 74 | ### 4. Report rouge scores 75 | 76 | `python ../rouge/report_rouge.py --ref-path PATH_TO_REFERENCES --pred-paths PATH_TO_PREDICTIONS` 77 | 78 | ### 5. Pretrained Models 79 | 80 | We have included checkpoints for all 5 training runs of the RelReg-W model used in the final evaluation, along with their performance on the **validation** set: 81 | 82 | | Run | ROUGE-1 | ROUGE-2 | ROUGE-L | Checkpoint | 83 | |-----------|---------|----| --- |-------------------------------------------------------------------------------------------------------------------| 84 | | 1 | 37.03 | 12.47 | 32.47 | [download](https://storage.googleapis.com/sfr-query-focused-sum-research/relreg-qmsum-256-wikisum-1.tar.gz) | 85 | | 2 | 36.44 | 12.27 | 32.18 | [download](https://storage.googleapis.com/sfr-query-focused-sum-research/relreg-qmsum-256-wikisum-2.tar.gz) | 86 | | 3 | 37.10 | 12.47 | 32.61 | [download](https://storage.googleapis.com/sfr-query-focused-sum-research/relreg-qmsum-256-wikisum-3.tar.gz) | 87 | | 4 | 36.45 | 12.11 | 32.30 | [download](https://storage.googleapis.com/sfr-query-focused-sum-research/relreg-qmsum-256-wikisum-4.tar.gz) | 88 | | 5 | 36.82 | 11.91 | 32.43 | [download](https://storage.googleapis.com/sfr-query-focused-sum-research/relreg-qmsum-256-wikisum-5.tar.gz) | 89 | 90 | To generate predictions using these models, please download the above checkpoints and replace the `--model_name_or_path` line in `predict_qmsum_bart.sh` accordingly. -------------------------------------------------------------------------------- /extractors/add_rouge.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | 5 | from tqdm import tqdm 6 | from datasets import load_metric 7 | 8 | sys.setrecursionlimit(10000) 9 | metric = load_metric("rouge") 10 | 11 | if __name__ == "__main__": 12 | do_chunks = int(sys.argv[1]) 13 | if do_chunks: 14 | from transformers import AutoTokenizer 15 | from dataset import ChunkTokenizer 16 | tokenizer = AutoTokenizer.from_pretrained('facebook/bart-large') 17 | 18 | chunk_size = 256 19 | max_num_chunks = 128 20 | pad = False 21 | stride = True 22 | chunk_tokenizer = ChunkTokenizer( 23 | tokenizer=tokenizer, 24 | chunk_size=chunk_size, 25 | max_num_chunks=max_num_chunks, 26 | stride=stride, 27 | pad=pad 28 | ) 29 | 30 | for split in ["train", "val", "test"]: 31 | id2meetingsrc = {} 32 | fname = os.path.join(os.path.dirname( __file__ ), "..", "data", f'{split}-meetings.jsonl') 33 | with open(fname) as f: 34 | for line in f: 35 | meeting_data = json.loads(line) 36 | if do_chunks: 37 | utts_joined = " ".join(meeting_data['meeting_transcripts']) 38 | output = chunk_tokenizer( 39 | source=utts_joined, 40 | ) 41 | input_ids = output['input_ids'] 42 | tokens = tokenizer.batch_decode(input_ids) 43 | chunks = [x.replace("", "").replace("", "").strip() for x in tokens] 44 | id2meetingsrc[meeting_data['meeting_id']] = chunks 45 | else: 46 | id2meetingsrc[meeting_data['meeting_id']] = meeting_data['meeting_transcripts'] 47 | 48 | fname = os.path.join(os.path.dirname( __file__ ), "..", "data", f"{split}.jsonl") 49 | if do_chunks: 50 | fname_out = os.path.join(os.path.dirname( __file__ ), \ 51 | "..", "data", f"{split}.rouge.256.jsonl") 52 | else: 53 | fname_out = os.path.join(os.path.dirname( __file__ ), \ 54 | "..", 
"data", f"{split}.rouge.jsonl") 55 | totals = {"train": 1257, "val": 272, "test": 281} 56 | with open(fname) as f, open(fname_out, "w") as out: 57 | for line in tqdm(f, total=totals[split]): 58 | data = json.loads(line) 59 | meeting_utterances_final = id2meetingsrc[data['meeting_id']] 60 | target = data['answer'] 61 | query = data['query'] 62 | 63 | references = [target] * len(meeting_utterances_final) 64 | 65 | metric.add_batch(predictions=meeting_utterances_final, references=references) 66 | score = metric.compute(use_agregator=False) 67 | 68 | rouge_1 = score['rouge1'] 69 | scores = [x.fmeasure for x in rouge_1] 70 | if do_chunks: 71 | data["chunks"] = meeting_utterances_final 72 | data["utt_rouge_f1"] = scores 73 | 74 | json.dump(data, out) 75 | out.write("\n") 76 | -------------------------------------------------------------------------------- /extractors/postprocess_relreg.py: -------------------------------------------------------------------------------- 1 | """ 2 | * Copyright (c) 2021, salesforce.com, inc. 3 | * All rights reserved. 4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import sys 9 | import os 10 | import csv 11 | import json 12 | 13 | 14 | 15 | if __name__ == "__main__": 16 | do_chunks = int(sys.argv[1]) 17 | output_dir = sys.argv[2] 18 | 19 | 20 | totals = {"train": 1257, "val": 272, "test": 281} 21 | for split in ["train", "val", "test"]: 22 | pred_file = f"{output_dir}-{split}/predict_None.txt" # transformer output fname 23 | preds = [] 24 | with open(pred_file) as f_pred: 25 | next(f_pred) 26 | for line in f_pred: 27 | score = float(line.strip().split()[-1]) 28 | preds.append(score) 29 | 30 | id2meetingsrc = {} 31 | if do_chunks: 32 | fname = os.path.join(os.path.dirname( __file__ ), "..", \ 33 | "data", f"{split}.rouge.chunks.jsonl") 34 | else: 35 | fname = os.path.join(os.path.dirname( __file__ ), "..", \ 36 | "data", f'{split}-meetings.jsonl') 37 | with open(fname) as f: 38 | for line in f: 39 | meeting_data = json.loads(line) 40 | if do_chunks: 41 | meeting_source = meeting_data['chunks'] 42 | else: 43 | meeting_source = meeting_data['meeting_transcripts'] 44 | id2meetingsrc[meeting_data['meeting_id']] = meeting_source 45 | 46 | chunk_counter = 0 47 | with open(fname) as f, open(f"{output_dir}/{split}.csv", "w") as out, \ 48 | open(f"{output_dir}/{split}.source", "w") as outs, \ 49 | open(f"{output_dir}/{split}.target", "w") as outt, \ 50 | open(f"{output_dir}/{split}.locator.jsonl", "w") as outl: 51 | writer = csv.DictWriter(out, fieldnames=["text", "summary"]) 52 | writer.writeheader() 53 | for line in f: 54 | data = json.loads(line) 55 | 56 | meeting_utterances = id2meetingsrc[data['meeting_id']] 57 | cur_preds = preds[chunk_counter: chunk_counter + len(meeting_utterances)] 58 | chunk_counter += len(meeting_utterances) 59 | 60 | indices = sorted(list(range(len(cur_preds))), key = \ 61 | lambda x: cur_preds[x], reverse=True) 62 | sorted_chunks = [meeting_utterances[x] for x in indices] 63 | scores = [cur_preds[x] for x in indices] 64 | 65 | query = data["query"] 66 | 67 | assert len(meeting_utterances) == len(indices) 68 | 69 | data["indices"] = indices 70 | data["scores"] = scores 71 | json.dump(data, outl) 72 | outl.write("\n") 73 | 74 | utts_ordered = [meeting_utterances[x] for x in indices] 75 | meeting_source = " ".join(utts_ordered) 76 | 77 | target = data["answer"] 78 | source = f"{query} {meeting_source}" 79 | cur_data = 
{"text": source, "summary": target} 80 | writer.writerow(cur_data) 81 | 82 | outs.write(source + "\n") 83 | outt.write(target + "\n") 84 | -------------------------------------------------------------------------------- /extractors/predict_qmsum_bart.sh: -------------------------------------------------------------------------------- 1 | NAME=relreg-qmsum-256-wikisum 2 | SPLIT=val 3 | NUM_RUNS=5 4 | START=1 5 | for RUN in $(seq $START $NUM_RUNS) 6 | do 7 | OUTPUT_DIR=output/${NAME}_${RUN} 8 | CUDA_VISIBLE_DEVICES=0 python -u train.py 9 | --test_file $RELREG_OUTPUT_DIR/val.csv 10 | --do_predict 11 | --model_name_or_path $OUTPUT_DIR/selected_checkpoint 12 | --output_dir ${OUTPUT_DIR}/selected_checkpoint/predition_logs_${SPLIT} 13 | --prediction_path ${OUTPUT_DIR}/selected_checkpoint/predictions.${SPLIT} 14 | --max_source_length 512 15 | --generation_max_len 256 16 | --val_max_target_length 256 17 | --overwrite_output_dir 18 | --per_device_eval_batch_size 4 19 | --predict_with_generate 20 | done 21 | -------------------------------------------------------------------------------- /extractors/prep_data_relreg.py: -------------------------------------------------------------------------------- 1 | """ 2 | * Copyright (c) 2021, salesforce.com, inc. 3 | * All rights reserved. 4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import os 9 | import sys 10 | import csv 11 | import json 12 | 13 | from tqdm import tqdm 14 | from datasets import load_metric 15 | 16 | sys.setrecursionlimit(10000) 17 | metric = load_metric("rouge") 18 | 19 | if __name__ == "__main__": 20 | do_chunks = int(sys.argv[1]) 21 | 22 | for split in ["train", "val", "test"]: 23 | id2meetingsrc = {} 24 | if do_chunks: 25 | fname = os.path.join(os.path.dirname( __file__ ), \ 26 | "data", f'{split}.rouge.256.jsonl') 27 | else: 28 | fname = os.path.join(os.path.dirname( __file__ ), "..", \ 29 | "data", f'{split}.rouge.jsonl') 30 | with open(fname) as f: 31 | for line in f: 32 | meeting_data = json.loads(line) 33 | if do_chunks: 34 | id2meetingsrc[meeting_data['meeting_id']] = meeting_data['chunks'] 35 | else: 36 | id2meetingsrc[meeting_data['meeting_id']] = meeting_data['meeting_transcripts'] 37 | 38 | if do_chunks: 39 | fname = os.path.join(os.path.dirname( __file__ ), "data", f'{split}.rouge.256.jsonl') 40 | fname_out_csv = os.path.join(os.path.dirname( __file__ ), \ 41 | "..", "data", f"{split}.relreg.256.csv") 42 | else: 43 | fname = os.path.join(os.path.dirname( __file__ ), "..", "data", f"{split}.jsonl") 44 | fname_out_csv = os.path.join(os.path.dirname( __file__ ), \ 45 | "..", "data", f"{split}.relreg.csv") 46 | 47 | totals = {"train": 1257, "val": 272, "test": 281} 48 | with open(fname) as f, open(fname_out_csv, "w") as outr: 49 | writer = csv.DictWriter(outr, fieldnames=["sentence1", "sentence2", "label"]) 50 | writer.writeheader() 51 | for line in tqdm(f, total=totals[split]): 52 | data = json.loads(line) 53 | 54 | meeting_utterances = id2meetingsrc[data['meeting_id']] 55 | target = data['answer'] 56 | query = data['query'] 57 | scores = data["utt_rouge_f1"] 58 | 59 | for score, utt in zip(scores, meeting_utterances): 60 | sent1 = query 61 | sent2 = utt 62 | label = score 63 | cur_dict = {"sentence1": sent1, "sentence2": sent2, "label": label} 64 | writer.writerow(cur_dict) 65 | -------------------------------------------------------------------------------- /extractors/run_relreg.sh: 
-------------------------------------------------------------------------------- 1 | 2 | CHUNKS=$1 3 | OUTPUT_DIR=$2 4 | 5 | 6 | # Add chunk/utterance-level ROUGE and convert data to format required for RelReg training and inference; 0 for utterance-level data. 7 | python add_rouge.py $CHUNKS 8 | python prep_data_relreg.py $CHUNKS 9 | 10 | # Train RelReg on utterance-level input 11 | CUDA_VISIBLE_DEVICES=0 python transformers/examples/pytorch/text-classification/run_glue.py \ 12 | --model_name_or_path google/electra-large-discriminator \ 13 | --train_file ../data/train.relreg.csv \ 14 | --validation_file ../data/val.relreg.csv \ 15 | --save_steps 3000 \ 16 | --do_train \ 17 | --do_eval \ 18 | --max_seq_length 384 \ 19 | --per_device_train_batch_size 4 \ 20 | --gradient_accumulation_steps 32 \ 21 | --learning_rate 2e-5 \ 22 | --num_train_epochs 3 \ 23 | --save_total_limit 1 \ 24 | --output_dir ./${OUTPUT_DIR} ; 25 | 26 | # Run inference 27 | for split in 'train' 'val' 'test' 28 | do 29 | CUDA_VISIBLE_DEVICES=0 python transformers/examples/pytorch/text-classification/run_glue.py \ 30 | --model_name_or_path ./${OUTPUT_DIR} \ 31 | --train_file ../data/train.relreg.csv \ 32 | --validation_file ../data/val.relreg.csv \ 33 | --test_file ../data/${split}.relreg.csv \ 34 | --save_steps 3000 \ 35 | --do_predict \ 36 | --max_seq_length 384 \ 37 | --per_device_eval_batch_size 128 \ 38 | --learning_rate 2e-5 \ 39 | --num_train_epochs 3 \ 40 | --output_dir ./${OUTPUT_DIR}-${split} ; 41 | done 42 | 43 | # Collect predictions and process to format for seq2seq models; 0 signifies not using the segmented input 44 | python postprocess_relreg.py $CHUNKS $OUTPUT_DIR -------------------------------------------------------------------------------- /extractors/run_relreg_tt.sh: -------------------------------------------------------------------------------- 1 | CHUNKS=$1 2 | OUTPUT_DIR=$2 3 | 4 | # Run relreg-tt with (0/1) utterance/segments and max encoder length of 256 5 | python add_rouge.py $CHUNKS 6 | python train_relreg_tt.py nli-distilroberta-base-v2 $CHUNKS $OUTPUT_DIR 256 7 | python test_relreg_tt.py $OUTPUT_DIR $CHUNKS $OUTPUT_DIR 256 8 | -------------------------------------------------------------------------------- /extractors/test_relreg_tt.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | * Copyright (c) 2021, salesforce.com, inc. 4 | * All rights reserved. 
5 | * SPDX-License-Identifier: BSD-3-Clause 6 | * For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 7 | """ 8 | 9 | import os 10 | import sys 11 | import json 12 | import csv 13 | import torch 14 | from sentence_transformers import SentenceTransformer, util 15 | from train_relreg_tt import get_examples 16 | 17 | sys.setrecursionlimit(10000) 18 | 19 | 20 | def write_to_file(examples, meetings_dict, output_dir, split): 21 | id2meetingsrc_embed = {} 22 | for curid, utts in meetings_dict.items(): 23 | utts = [f"[DOC] {utt}" for utt in utts] 24 | meeting_embeddings = model.encode(utts, convert_to_tensor=True) 25 | id2meetingsrc_embed[curid] = meeting_embeddings 26 | 27 | if not os.path.exists(output_dir): 28 | os.mkdir(output_dir) 29 | 30 | with open(f"{output_dir}/{split}.csv", "w") as out, \ 31 | open(f"{output_dir}/{split}.source", "w") as outs, \ 32 | open(f"{output_dir}/{split}.target", "w") as outt, \ 33 | open(f"{output_dir}/{split}.locator.jsonl", "w") as outl: 34 | writer = csv.DictWriter(out, fieldnames=["text", "summary"]) 35 | writer.writeheader() 36 | for example in examples: 37 | query = example["query"] 38 | query_embedding = model.encode(f"[QRY] {query}", convert_to_tensor=True) 39 | 40 | cur_meeting_embeddings = id2meetingsrc_embed[example['meeting_id']] 41 | cur_meeting_utterances = meetings_dict[example['meeting_id']] 42 | cur_meeting_utterances = [x.replace("[DOC] ", "") for x in cur_meeting_utterances] 43 | 44 | cos_scores = util.pytorch_cos_sim(query_embedding, cur_meeting_embeddings)[0] 45 | top_results = torch.topk(cos_scores, k=len(cur_meeting_embeddings)) 46 | 47 | scores = top_results[0].cpu().tolist() 48 | indices = top_results[1].cpu().tolist() 49 | 50 | assert len(cur_meeting_utterances) == len(indices) 51 | 52 | example["indices"] = indices 53 | example["scores"] = scores 54 | json.dump(example, outl) 55 | outl.write("\n") 56 | 57 | utts_ordered = [cur_meeting_utterances[x] for x in indices] 58 | meeting_source = " ".join(utts_ordered) 59 | 60 | target = example["answer"] 61 | source = f"{query} {meeting_source}" 62 | cur_data = {"text": source, "summary": target} 63 | writer.writerow(cur_data) 64 | 65 | outs.write(source + "\n") 66 | outt.write(target + "\n") 67 | 68 | 69 | if __name__ == "__main__": 70 | model_name = sys.argv[1] 71 | do_chunks = int(sys.argv[2]) 72 | output_dir = sys.argv[3] 73 | max_seq_length = int(sys.argv[4]) 74 | 75 | model = SentenceTransformer(model_name, device='cuda') 76 | model.max_seq_length = max_seq_length 77 | model.eval() 78 | 79 | _, train_meetings_dict, train_examples = get_examples("train", do_chunks) 80 | _, val_meetings_dict, val_examples = get_examples("val", do_chunks) 81 | _, test_meetings_dict, test_examples = get_examples("test", do_chunks) 82 | 83 | write_to_file(test_examples, test_meetings_dict, output_dir, "test") 84 | write_to_file(train_examples, train_meetings_dict, output_dir, "train") 85 | write_to_file(val_examples, val_meetings_dict, output_dir, "val") 86 | -------------------------------------------------------------------------------- /extractors/train_qmsum_bart.sh: -------------------------------------------------------------------------------- 1 | NAME=relreg-qmsum-256-wikisum 2 | NUM_RUNS=5 3 | START=1 4 | for RUN in $(seq $START $NUM_RUNS) 5 | do 6 | CUDA_VISIBLE_DEVICES=0 python -u ../multiencoder/train.py \ 7 | --train_file $RELREG_OUTPUT_DIR/train.csv \ 8 | --validation_file $RELREG_OUTPUT_DIR/val.csv \ 9 | --do_train \ 10 | --do_eval \ 11 | 
--learning_rate 0.000005 \ 12 | --model_name_or_path $PATH_TO_CHECKPOINT \ 13 | --metric_for_best_model eval_mean_rouge \ 14 | --output_dir output/${NAME}_${RUN} \ 15 | --per_device_train_batch_size 4 \ 16 | --max_source_length 1024 \ 17 | --generation_max_len 256 \ 18 | --val_max_target_length 256 \ 19 | --overwrite_output_dir \ 20 | --per_device_eval_batch_size 4 \ 21 | --predict_with_generate \ 22 | --evaluation_strategy epoch \ 23 | --num_train_epochs 10 \ 24 | --save_strategy epoch \ 25 | --logging_strategy epoch \ 26 | --load_best_model_at_end \ 27 | --compute_rouge_for_train True \ 28 | --seed $RUN &> ${NAME}_${RUN}.out 29 | done -------------------------------------------------------------------------------- /extractors/train_relreg_tt.py: -------------------------------------------------------------------------------- 1 | """ 2 | * Copyright (c) 2021, salesforce.com, inc. 3 | * All rights reserved. 4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import os 9 | import sys 10 | import json 11 | import math 12 | 13 | from tqdm import tqdm 14 | from sentence_transformers import SentenceTransformer, InputExample, losses 15 | from torch.utils.data import DataLoader 16 | from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator 17 | 18 | sys.setrecursionlimit(10000) 19 | 20 | 21 | def get_examples(split, do_chunks): 22 | id2meetingsrc = {} 23 | if do_chunks: 24 | fname = os.path.join(os.path.dirname( __file__ ), "data", f"{split}.rouge.256.jsonl") 25 | else: 26 | fname = os.path.join(os.path.dirname( __file__ ), "..", "data", f'{split}.rouge.jsonl') 27 | with open(fname) as f: 28 | for line in f: 29 | meeting_data = json.loads(line) 30 | if do_chunks: 31 | id2meetingsrc[meeting_data['meeting_id']] = meeting_data['chunks'] 32 | else: 33 | id2meetingsrc[meeting_data['meeting_id']] = meeting_data['meeting_transcripts'] 34 | 35 | if do_chunks: 36 | fname = os.path.join(os.path.dirname( __file__ ), "data", f"{split}.rouge.256.jsonl") 37 | else: 38 | fname = os.path.join(os.path.dirname( __file__ ), "..", "data", f"{split}.rouge.jsonl") 39 | 40 | totals = {"train": 1257, "val": 272, "test": 281} 41 | cur_examples = [] 42 | cur_examples_meta = [] 43 | with open(fname) as f: 44 | for line in tqdm(f, total=totals[split]): 45 | data = json.loads(line) 46 | cur_examples_meta.append(data) 47 | 48 | meeting_utterances = id2meetingsrc[data['meeting_id']] 49 | scores = data['utt_rouge_f1'] 50 | query = data['query'] 51 | 52 | for utt, utt_score in zip(meeting_utterances, scores): 53 | if len(utt.strip()) == 0: 54 | continue 55 | label = float(utt_score) 56 | cur_example = InputExample(texts=[f"[QRY] {query}", f"[DOC] {utt}"], label=label) 57 | cur_examples.append(cur_example) 58 | return cur_examples, id2meetingsrc, cur_examples_meta 59 | 60 | 61 | if __name__ == "__main__": 62 | model_name = sys.argv[1] # 'nli-distilroberta-base-v2' 63 | do_chunks = int(sys.argv[2]) 64 | output_dir = sys.argv[3] 65 | max_seq_length = int(sys.argv[4]) 66 | 67 | model = SentenceTransformer(model_name, device='cuda') 68 | model.max_seq_length = max_seq_length 69 | 70 | word_embedding_model = model._first_module() 71 | 72 | tokens = ["[DOC]", "[QRY]"] 73 | word_embedding_model.tokenizer.add_tokens(tokens, special_tokens=True) 74 | word_embedding_model.auto_model.resize_token_embeddings(len(word_embedding_model.tokenizer)) 75 | 76 | train_batch_size = 16 77 | num_epochs = 4 78 | 79 | 
train_examples, _, _ = get_examples("train", do_chunks) 80 | dev_examples, _, _ = get_examples("val", do_chunks) 81 | 82 | 83 | train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=train_batch_size) 84 | train_loss = losses.CosineSimilarityLoss(model) 85 | 86 | evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_examples, name='qmsum-val') 87 | warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1) 88 | 89 | model.fit(train_objectives=[(train_dataloader, train_loss)], 90 | evaluator=evaluator, 91 | epochs=num_epochs, 92 | evaluation_steps=1, 93 | warmup_steps=warmup_steps, 94 | output_path=output_dir) 95 | # 1000 -------------------------------------------------------------------------------- /multiencoder/README.md: -------------------------------------------------------------------------------- 1 | # Segment Encoder 2 | 3 | NOTE: Run all of the following steps from `/multiencoder`. 4 | 5 | ## Table of contents 6 | - [Installation](#installation) 7 | - [Reproducing QMSum experiments](#reproducing-qmsum-experiments) 8 | * [1. Preprocess QMSum](#1-preprocess-qmsum) 9 | * [2. Convert to Segment Encoder format](#2-convert-to-segment-encoder-format) 10 | * [3. Train models](#3-train-models) 11 | * [4. Choose checkpoint for each run](#4-choose-checkpoint-for-each-run) 12 | * [5. Generate predictions from selected checkpoints](#5-generate-predictions-from-selected-checkpoints) 13 | * [6. Report rouge scores of all checkpoints](#6-report-rouge-scores-of-all-checkpoints) 14 | - [Pretrained models](#pretrained-models) 15 | * [Downloading checkpoints](#downloading-checkpoints) 16 | * [Using checkpoints](#using-checkpoints) 17 | * [Example](#example) 18 | - [Running on your own datasets](#running-on-your-own-datasets) 19 | * [1. Prepare data in appropriate format](#1-prepare-data-in-appropriate-format) 20 | * [2. Train your model](#2-train-your-model) 21 | * [3. Evaluate your model](#3-evaluate-your-model) 22 | + [HuggingFace rouge metric (simpler)](#huggingface-rouge-metric--simpler-) 23 | + [SummEval rouge metric](#summeval-rouge-metric) 24 | 25 | ## Installation 26 | ``` 27 | pip install -r ../requirements.txt 28 | ``` 29 | 30 | ## Reproducing QMSum experiments 31 | 32 | ### 1. Preprocess QMSum 33 | 34 | To perform the preprocessing of QMSum necessary to reproduce the experiments, follow the instructions in the 35 | [preprocessing](../preprocessing/README.md) directory. 36 | 37 | ### 2. Convert to Segment Encoder format 38 | To convert above files to a format that can be used by the Segment Encoder, run the following: 39 | ``` 40 | python convert_qmsum.py 41 | ``` 42 | 43 | The output files will be in `data/qmsum/preprocessed`. 44 | 45 | ### 3. Train models 46 | 47 | See `scripts/train_qmsum_*.sh` 48 | 49 | ### 4. Choose checkpoint for each run 50 | 51 | `bash scripts/select_checkpoints.sh`. 52 | 53 | Copies best checkpoint for each run (based on mean validation rouge) to `selected_checkpoint` directory. 54 | 55 | ### 5. Generate predictions from selected checkpoints 56 | 57 | `bash scripts/predict_val.sh` 58 | 59 | Writes out val predictions for all selected checkpoints to `selected_checkpoint/predictions.val`. 60 | 61 | `bash scripts/predict_test.sh` 62 | 63 | Writes out test predictions for all selected checkpoints to `selected_checkpoint/predictions.test`. 64 | 65 | ### 6. Report rouge scores of all checkpoints 66 | 67 | `bash scripts/report_rouge_val.sh` 68 | 69 | Reports mean rouge scores on validation set. 
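For a quick sanity check on a single predictions file, the gist of what these reporting scripts compute can be approximated in a few lines of Python. The sketch below uses the `rouge_score` package as a stand-in for the SummEval/PERL scorer that `report_rouge.py` wraps, so scores may differ slightly from the reported numbers; `PATH_TO_PREDICTIONS` and `PATH_TO_REFERENCES` are placeholders for line-aligned plain-text files.

```python
# Approximate mean ROUGE over one predictions file (sanity check only).
# Assumes one summary per line in both files; the paths are placeholders.
from statistics import mean
from rouge_score import rouge_scorer

scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeLsum"], use_stemmer=True)

def mean_rouge(pred_path, ref_path):
    preds = [line.strip() for line in open(pred_path)]
    refs = [line.strip() for line in open(ref_path)]
    assert len(preds) == len(refs)
    scores = [scorer.score(ref, pred) for ref, pred in zip(refs, preds)]
    return {t: 100 * mean(s[t].fmeasure for s in scores) for t in ["rouge1", "rouge2", "rougeLsum"]}

print(mean_rouge("PATH_TO_PREDICTIONS", "PATH_TO_REFERENCES"))
```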
70 | 71 | `bash scripts/report_rouge_test.sh` 72 | 73 | Reports mean rouge scores on test set. 74 | 75 | Note that these last scripts may prompt you with a small number of additional install steps. 76 | 77 | ## Pretrained models 78 | 79 | We have provided checkpoints for our best performing QMSum-finetuned Segment Encoder model as reported in our 80 | [paper](https://arxiv.org/pdf/2112.07637.pdf) (Table 5). The hyperparameters of note are: 81 | * Input size: 16384 82 | * Segment length: 512 83 | * Segment overlap: 256 84 | * Initial checkpoint: [Wikisum-pretrained](https://storage.googleapis.com/sfr-query-focused-sum-research/bart-wikisum.tar.gz) 85 | 86 | ### Downloading checkpoints 87 | 88 | We have included checkpoints for all 5 training runs of the model used in the final evaluation, along with their performance on the **validation** set: 89 | 90 | | Run | ROUGE-1 | ROUGE-2 | ROUGE-L | Checkpoint | 91 | |-----------|---------|---------|---------|-------------------------------------------------------------------------------------------------------------------| 92 | | 1 | 38.85 | 13.00 | 34.13 | [download](https://storage.googleapis.com/sfr-query-focused-sum-research/segenc-qmsum-16384-512-wikisum-1.tar.gz) | 93 | | 2 | 38.50 | 12.87 | 33.92 | [download](https://storage.googleapis.com/sfr-query-focused-sum-research/segenc-qmsum-16384-512-wikisum-2.tar.gz) | 94 | | 3 | 38.66 | 13.01 | 34.07 | [download](https://storage.googleapis.com/sfr-query-focused-sum-research/segenc-qmsum-16384-512-wikisum-3.tar.gz) | 95 | | 4 | 38.16 | 12.90 | 33.73 | [download](https://storage.googleapis.com/sfr-query-focused-sum-research/segenc-qmsum-16384-512-wikisum-4.tar.gz) | 96 | | 5 | 38.74 | 12.81 | 34.08 | [download](https://storage.googleapis.com/sfr-query-focused-sum-research/segenc-qmsum-16384-512-wikisum-5.tar.gz) | 97 | 98 | 99 | ### Using checkpoints 100 | 101 | To use a checkpoint, first download/untar it and then point the `--model_name_or_path` command-line 102 | argument in [train.py](train.py) to the top-level directory of the checkpoint. (See the 103 | [next section](#running-on-your-own-datasets) for examples of 104 | using [train.py](train.py) to train/evaluate a model.) When using our provided checkpoint, also be sure to set the following arguments 105 | as follows to be consistent with the fine-tuning hyperparameters: 106 | 107 | ```bash 108 | --multiencoder_max_num_chunks 32 \ 109 | --multiencoder_stride \ 110 | --max_source_len 512 111 | ``` 112 | 113 | (For an explanation of the command-line arguments, see [next section](#running-on-your-own-datasets).) 114 | 115 | #### Example 116 | 117 | The example below demonstrates how to evaluate a checkpoint against the validation set. 118 | Note that you will first need to perform 119 | Steps 1 and 2 from the [previous section](#reproducing-qmsum-experiments) to populate the `data/qmsum/preprocessed/` directory. 
120 | 121 | ```bash 122 | python train.py \ 123 | --do_predict \ 124 | --test_file data/qmsum/preprocessed/val.jsonl \ 125 | --model_name_or_path PATH_TO_CHECKPOINT \ 126 | --multiencoder_type bart \ 127 | --multiencoder_max_num_chunks 32 \ 128 | --multiencoder_stride \ 129 | --max_source_len 512 \ 130 | --output_dir PATH_TO_OUTPUT \ 131 | --generation_max_len 256 \ 132 | --val_max_target_length 256 \ 133 | --per_device_eval_batch_size 1 \ 134 | --predict_with_generate \ 135 | --prediction_path PATH_TO_PREDICTION_OUTPUT 136 | ``` 137 | 138 | Note: the ROUGE scores obtained from the above script (based on Huggingface ROUGE implementation) may differ slightly 139 | from those reported in the table above (based on SummEval ROUGE implementation, which is consistent with the paper). See discussion of these two implementations [below](#3-evaluate-your-model). 140 | 141 | ## Running on your own datasets 142 | 143 | ### 1. Prepare data in appropriate format 144 | 145 | The Segment Encoder data loaders expect a `.jsonl` file, with each line in the following format: 146 | 147 | ``` 148 | {"source": , "query": , "target": } 149 | ``` 150 | 151 | ### 2. Train your model 152 | 153 | You will need to execute [train.py](train.py) with the appropriate command-line arguments. Below is a template 154 | for executing train.py based on the hyperparameters for the best-performing model ([scripts/train_qmsum_16_512_strided.sh](scripts/train_qmsum_16_512_strided.sh)). 155 | You will need to set `train_file` and `validation_file` to point to `.jsonl` files in the format described in Step 1, and `output_dir` 156 | to point to the directory where the model checkpoints will be saved. 157 | 158 | ```bash 159 | python train.py \ 160 | --do_train \ 161 | --train_file PATH_TO_TRAIN_FILE \ 162 | --do_eval \ 163 | --validation_file PATH_TO_VALIDATION_FILE \ 164 | --model_name_or_path facebook/bart-large \ 165 | --multiencoder_type bart \ 166 | --multiencoder_max_num_chunks 32 \ 167 | --multiencoder_stride \ 168 | --max_source_len 512 \ 169 | --learning_rate 0.000005 \ 170 | --save_strategy epoch \ 171 | --num_train_epochs 10 \ 172 | --gradient_checkpointing \ 173 | --output_dir PATH_TO_SAVE_MODEL \ 174 | --per_device_train_batch_size 1 \ 175 | --generation_max_len 256 \ 176 | --val_max_target_length 256 \ 177 | --evaluation_strategy epoch \ 178 | --per_device_eval_batch_size 1 \ 179 | --metric_for_best_model eval_mean_rouge \ 180 | --compute_rouge_for_train \ 181 | --predict_with_generate \ 182 | --logging_strategy epoch \ 183 | --load_best_model_at_end \ 184 | --seed 1 185 | ``` 186 | 187 | Argument descriptions: 188 | * `do_train`: Required boolean flag 189 | * `train_file`: Path to your training file (in `.jsonl` format described above). 190 | * `do_eval`: Boolean flag to evaluate model on validation set during training 191 | * `validation_file`: Path to your optional validation file (in `.jsonl` format described above) 192 | * `model_name_or_path`: Name of or path to Huggingface model (recommend `facebook/bart-large`). Currently only supports BART checkpoints. 193 | * `multiencoder_type`: Set to `bart` 194 | * `multiencoder_max_num_chunks`: Number of segments 195 | * `multiencoder_stride`: Boolean flag to use 50%-overlap strides in segmentation. If not set, segments will be disjoint, which may degrade model performance. 
196 | * `max_source_len`: Segment length 197 | * `learning_rate`: Learning rate (recommend 0.000005 if replicating paper experiments) 198 | * `save_strategy`: Set to `epoch` to save checkpoint at end of each epoch 199 | * `num_train_epochs`: Number of epochs 200 | * `gradient_checkpointing` (recommended for larger models): Boolean flag to turn on gradient checkpointing, which reduces memory footprint and increases compute. 201 | This may be necessary for some models depending on number of segments, size of segments, and GPU memory available. 202 | * `output_dir`: Output directory for saved model checkpoints and logs 203 | * `per_device_train_batch_size`: Batch size, typically 1 for larger models 204 | * `generation_max_len` and `val_max_target_length`: Set to the maximum target length 205 | * `evaluation_strategy`: Set to `epoch` if you wish to evaluate at the end of each epoch 206 | * `per_device_eval_batch_size`: Evaluation batch size, typically 1 for larger models 207 | * `metric_for_best_model` (see also `compute_rouge_for_train` and `predict_with_generate` below): Set to `eval_mean_rouge` (recommended) if you wish use mean rouge as criterion for selecting checkpoint. Leave off to use cross entropy. 208 | * `compute_rouge_for_train`: Include if you wish compute rouge as part of the eval in training (necessary if `metric_for_best_model` = `eval_mean_rouge` ) 209 | * `predict_with_generate`: Required boolean flag if `compute_rouge_for_train` set to True 210 | * `logging_strategy`: Set to `epoch` to log results at end of each epoch 211 | * `overwrite_output_dir`: Boolean flag to overwrite output directory with multiple runs 212 | * `load_best_model_at_end`: Boolean flag to load the best checkpoint at the end 213 | * `seed`: Optional random seed 214 | * Optionally, other arguments for the Huggingface Seq2SeqTrainer specified in 215 | [Seq2SeqTrainingArguments](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainingArguments) 216 | 217 | See [train.py](train.py) for 218 | documentation on other arguments. Note that [train.py](train.py) is based on the standard HuggingFace [training script for summarization](https://github.com/huggingface/transformers/blob/master/examples/pytorch/summarization/run_summarization.py), 219 | and uses many of the same command-line arguments. 220 | 221 | ### 3. Evaluate your model 222 | 223 | There are two main options for evaluation, described below. 224 | 225 | #### HuggingFace rouge metric (simpler) 226 | This relies on [`datasets.load_metric()`](https://huggingface.co/docs/datasets/loading_metrics.html). 227 | 228 | Run [train.py](train.py) with appropriate arguments for testing. Example template consistent with training template from Step 2: 229 | 230 | ```bash 231 | python train.py \ 232 | --do_predict \ 233 | --test_file PATH_TO_TEST_FILE \ 234 | --model_name_or_path PATH_TO_SAVE_MODEL \ 235 | --multiencoder_type bart \ 236 | --multiencoder_max_num_chunks 32 \ 237 | --multiencoder_stride \ 238 | --max_source_len 512 \ 239 | --output_dir PATH_TO_TEST_OUTPUT \ 240 | --generation_max_len 256 \ 241 | --val_max_target_length 256 \ 242 | --per_device_eval_batch_size 1 \ 243 | --predict_with_generate \ 244 | --prediction_path PATH_TO_PREDICTION_OUTPUT 245 | ``` 246 | 247 | You will need to set `test_file` to a test file in the `.jsonl` format described in Step 1. 
Set `model_name_or_path` to the 248 | top-level `PATH_TO_SAVE_MODEL` specified in the training script; this top-level directory has the best-performing checkpoint 249 | according to the `metric_for_best_model` argument to the training script. Set `output_dir` 250 | to the directory where testing outputs will go and `prediction_path` to the file where generated predictions will go. 251 | If you change any model parameters in the training 252 | script be sure to update corresponding arguments in the test script (e.g. number of segments, segment length). 253 | 254 | #### SummEval rouge metric 255 | 256 | The [SummEval](https://github.com/Yale-LILY/SummEval) implementation uses the original PERL script for computing rouge. 257 | 258 | To run this, you will need to first run the test script above, and then additionally run 259 | [`report_rouge.py`](../rouge/report_rouge.py) based on the generated predictions from the test script. You 260 | can see examples of this in steps 5-6 in the [Reproducing Experiments section](#reproducing-qmsum-experiments). 261 | 262 | -------------------------------------------------------------------------------- /multiencoder/convert_qmsum.py: -------------------------------------------------------------------------------- 1 | """ 2 | * Copyright (c) 2021, salesforce.com, inc. 3 | * All rights reserved. 4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import argparse 9 | import json 10 | from pathlib import Path 11 | 12 | import tqdm 13 | 14 | 15 | def main(): 16 | parser = argparse.ArgumentParser('preprocess') 17 | parser.add_argument("--input_dir", type=str, help="inp directory", default="../data/") 18 | parser.add_argument("--output_dir", type=str, help="out directory", default="data/qmsum/preprocessed") 19 | args = parser.parse_args() 20 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 21 | for split in ["test", "val", "train"]: 22 | print(f"\nProcessing {split}") 23 | input_path_meetings = Path(args.input_dir, f"{split}-meetings.jsonl") 24 | meeting_lookup = {} 25 | print('Loading meetings') 26 | with open(input_path_meetings) as f: 27 | for line in tqdm.tqdm(f): 28 | data = json.loads(line) 29 | meeting_id = data['meeting_id'] 30 | source = ' '.join(data['meeting_transcripts']) 31 | meeting_lookup[meeting_id] = source 32 | input_path = Path(args.input_dir, f"{split}.jsonl") 33 | output_path = Path(args.output_dir, f"{split}.jsonl") 34 | print('Loading queries') 35 | with open(input_path) as inp, \ 36 | open(output_path, 'w') as out: 37 | for line in tqdm.tqdm(inp): 38 | data = json.loads(line) 39 | meeting_id = data['meeting_id'] 40 | source = meeting_lookup[meeting_id] 41 | query = data['query'] 42 | target = data['answer'] 43 | out.write( 44 | json.dumps( 45 | { 46 | 'source': source, 47 | 'query': query, 48 | 'target': target 49 | } 50 | ) + '\n' 51 | ) 52 | 53 | 54 | if __name__ == '__main__': 55 | main() 56 | -------------------------------------------------------------------------------- /multiencoder/dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | * Copyright (c) 2021, salesforce.com, inc. 3 | * All rights reserved. 
4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import datetime 9 | import json 10 | import math 11 | from concurrent.futures.process import ProcessPoolExecutor 12 | 13 | import torch 14 | from torch.utils.data.dataset import Dataset 15 | from transformers import PreTrainedTokenizerBase 16 | 17 | 18 | class MultiEncoderDataset(Dataset): 19 | 20 | def __init__( 21 | self, 22 | data_path: str, 23 | tokenizer: PreTrainedTokenizerBase, 24 | chunk_size: int, 25 | max_num_chunks: int, 26 | max_target_length: int, 27 | stride: bool = False, 28 | pad: bool = True, 29 | num_samples: int = None, 30 | verbose: bool = False, 31 | ignore_pad_token_for_loss: bool = True, 32 | max_workers=1, 33 | ): 34 | 35 | self.tokenizer = tokenizer 36 | self.chunk_tokenizer = ChunkTokenizer( 37 | tokenizer, 38 | chunk_size, 39 | max_num_chunks, 40 | stride, 41 | pad 42 | ) 43 | self.max_target_length = max_target_length 44 | self.num_samples = num_samples 45 | self.verbose = verbose 46 | self.ignore_pad_token_for_loss = ignore_pad_token_for_loss 47 | self._encode_data(data_path, max_workers) 48 | 49 | def __len__(self): 50 | return len(self.encodings) 51 | 52 | def __getitem__(self, index): 53 | return self.encodings[index] 54 | 55 | def _encode_data(self, file_path, max_workers): 56 | with open(file_path) as f: 57 | if max_workers == 1: 58 | encodings = list(map(self._process_line, enumerate(f))) 59 | else: 60 | with ProcessPoolExecutor(max_workers=max_workers) as executor: 61 | encodings = executor.map(self._process_line, enumerate(f)) 62 | self.encodings = [enc for enc in encodings if enc is not False] 63 | if self.num_samples is not None: 64 | assert self.num_samples == len(self.encodings) 65 | 66 | def _process_line(self, index_line): 67 | i, line = index_line 68 | if i % 100 == 0: 69 | print('Processed', i, 'records', datetime.datetime.now()) 70 | if self.num_samples is not None and i >= self.num_samples: 71 | return False 72 | data = json.loads(line) 73 | source = data['source'] 74 | query = data.get('query') 75 | target = data['target'] 76 | encoding = self._encode_example( 77 | source, 78 | target, 79 | query 80 | ) 81 | if self.verbose and i == 0: 82 | print('First record in dataset:') 83 | for token_ids in encoding['input_ids']: 84 | print() 85 | print(self.tokenizer.decode(token_ids)) 86 | return encoding 87 | 88 | def _encode_example(self, source, target, query=None): 89 | 90 | output = self.chunk_tokenizer(source, query) 91 | source_ids = output['input_ids'] 92 | source_attention_mask = output['attention_mask'] 93 | 94 | tokenized_answer = self.tokenizer( 95 | target, 96 | pad_to_max_length=True, 97 | max_length=self.max_target_length, 98 | return_tensors="pt", 99 | truncation=True 100 | ) 101 | target_ids = tokenized_answer['input_ids'].squeeze() 102 | if self.ignore_pad_token_for_loss: 103 | target_ids[target_ids == self.tokenizer.pad_token_id] = -100 104 | 105 | return { 106 | 'input_ids': source_ids, 107 | 'attention_mask': source_attention_mask, 108 | 'labels': target_ids, 109 | 'decoder_attention_mask': tokenized_answer['attention_mask'].squeeze(), 110 | } 111 | 112 | 113 | class ChunkTokenizer: 114 | """Chunks and tokenizes input text and optional query for input to multi-encoder model. 
Does both chunking and 115 | tokenizing because the chunking is based on tokenized text.""" 116 | 117 | def __init__( 118 | self, 119 | tokenizer: PreTrainedTokenizerBase, 120 | chunk_size: int, 121 | max_num_chunks: int, 122 | stride: bool = False, 123 | pad: bool = False 124 | ): 125 | """ 126 | Args: 127 | tokenizer: tokenizer used to tokenize text 128 | chunk_size: chunk size in number of tokens 129 | max_num_chunks: maximum number of chunks in total (optional) 130 | stride: whether to use striding 131 | pad: whether to "pad" chunks with empty strings to attain max_num_chunks chunks 132 | """ 133 | if pad and not max_num_chunks: 134 | raise ValueError("Cannot pad without specifying max_num_chunks") 135 | self.tokenizer = tokenizer 136 | self.chunk_size = chunk_size 137 | self.max_num_chunks = max_num_chunks 138 | self.stride = stride 139 | self.pad = pad 140 | 141 | def __call__( 142 | self, 143 | source: str, 144 | query: str = None 145 | ): 146 | """ 147 | Args: 148 | source: source text 149 | query: optional query text 150 | Returns: 151 | dictionary with tokenized chunks 152 | """ 153 | if query: 154 | prefix = f"{query}" 155 | else: 156 | prefix = f"" 157 | prefix_tokens = self.tokenizer( 158 | prefix, 159 | add_special_tokens=False, 160 | return_tensors="pt", 161 | max_length=self.chunk_size, 162 | truncation=True 163 | )['input_ids'] 164 | prefix_len = prefix_tokens.size(-1) 165 | 166 | suffix_chunk_size = self.chunk_size - prefix_len 167 | chunk_input_ids_all = [] 168 | chunk_attention_mask_all = [] 169 | 170 | suffix = f"{source}" 171 | suffix_total_size = self.max_num_chunks * suffix_chunk_size 172 | input_ids = self.tokenizer( 173 | suffix, 174 | add_special_tokens=False, 175 | truncation=True, 176 | max_length=suffix_total_size, 177 | )['input_ids'] 178 | 179 | if self.stride and self.max_num_chunks > 1: 180 | use_offset_list = [False, True] 181 | else: 182 | use_offset_list = [False] 183 | for use_offset in use_offset_list: 184 | if use_offset: 185 | offset = math.floor(suffix_chunk_size / 2) 186 | num_chunks = self.max_num_chunks - 1 187 | suffix_tokens = input_ids[offset: offset + num_chunks * suffix_chunk_size] 188 | else: 189 | suffix_tokens = input_ids[:suffix_total_size] 190 | num_chunks = self.max_num_chunks 191 | 192 | suffix_attention = [1] * len(suffix_tokens) 193 | if self.pad: # If padding chunks to num chunks, need to fill out suffix_total_size 194 | pad_length = max(num_chunks * suffix_chunk_size - len(suffix_tokens), 0) 195 | else: # Pad to next multiple of chunk size 196 | remainder = len(suffix_tokens) % suffix_chunk_size 197 | if remainder == 0: 198 | pad_length = 0 199 | else: 200 | pad_length = suffix_chunk_size - remainder 201 | suffix_tokens += [self.tokenizer.pad_token_id] * pad_length 202 | suffix_attention += [0] * pad_length 203 | 204 | suffix_tokens = torch.tensor(suffix_tokens) 205 | suffix_attention = torch.tensor(suffix_attention) 206 | 207 | suffix_chunks = suffix_tokens.view(-1, suffix_chunk_size) 208 | suffix_attention = suffix_attention.view(-1, suffix_chunk_size) 209 | 210 | prefix_chunks = prefix_tokens.expand(suffix_chunks.size(0), -1) 211 | prefix_attention = torch.ones_like(prefix_chunks) 212 | 213 | chunk_input_ids = torch.cat((prefix_chunks, suffix_chunks), dim=1) 214 | chunk_attention_mask = torch.cat((prefix_attention, suffix_attention), dim=1) 215 | 216 | chunk_input_ids_all.append(chunk_input_ids) 217 | chunk_attention_mask_all.append(chunk_attention_mask) 218 | 219 | return { 220 | "input_ids": torch.cat(chunk_input_ids_all, 
dim=0), 221 | "attention_mask": torch.cat(chunk_attention_mask_all, dim=0) 222 | } 223 | -------------------------------------------------------------------------------- /multiencoder/models.py: -------------------------------------------------------------------------------- 1 | """ 2 | * Copyright (c) 2021, salesforce.com, inc. 3 | * All rights reserved. 4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import torch 9 | from transformers import BartForConditionalGeneration 10 | from transformers.modeling_outputs import BaseModelOutput 11 | 12 | 13 | class BartForMultiConditionalGeneration(BartForConditionalGeneration): 14 | 15 | def multi_encode( 16 | self, 17 | input_ids=None, 18 | attention_mask=None, 19 | return_dict=None 20 | ): 21 | # (B, N, L) -> (B*N, L) -> (B*N, L, D) -> (B, N*L, D) 22 | # (B, N, L) -> (B*N, L) -> (B, N*L) 23 | B = input_ids.size(0) # batch-size 24 | N = input_ids.size(1) # num-docs 25 | L = input_ids.size(2) # max_len 26 | if input_ids.size() != attention_mask.size(): 27 | raise ValueError( 28 | f"Input ids different shape ({input_ids.size()}) than attention mask ({attention_mask.size()})" 29 | ) 30 | input_ids = input_ids.contiguous().view(B * N, L) 31 | attention_mask = attention_mask.contiguous().view(B * N, L) 32 | encoder_outputs = self.model.encoder( 33 | input_ids=input_ids, 34 | attention_mask=attention_mask, 35 | return_dict=return_dict 36 | ) 37 | if return_dict: 38 | hidden_states = encoder_outputs.last_hidden_state 39 | else: 40 | hidden_states = encoder_outputs[0] 41 | # hidden_states: (B * N, L, D) 42 | D = hidden_states.size(2) 43 | stacked_source_reps = hidden_states.contiguous().view(B, N * L, D) 44 | if return_dict: 45 | encoder_outputs = BaseModelOutput(last_hidden_state=stacked_source_reps) 46 | else: 47 | encoder_outputs = (stacked_source_reps,) 48 | stacked_source_mask = attention_mask.contiguous().view(B, N * L) 49 | return encoder_outputs, stacked_source_mask 50 | 51 | @torch.no_grad() 52 | def generate( 53 | self, 54 | input_ids=None, 55 | attention_mask=None, 56 | **kwargs, 57 | ): 58 | encoder_outputs, attention_mask = self.multi_encode( 59 | input_ids=input_ids, 60 | attention_mask=attention_mask, 61 | return_dict=True 62 | ) 63 | return super().generate( 64 | input_ids=None, 65 | attention_mask=attention_mask, 66 | encoder_outputs=encoder_outputs, 67 | **kwargs 68 | ) 69 | 70 | def forward( 71 | self, 72 | input_ids=None, 73 | attention_mask=None, 74 | decoder_input_ids=None, 75 | decoder_attention_mask=None, 76 | head_mask=None, 77 | decoder_head_mask=None, 78 | cross_attn_head_mask=None, 79 | encoder_outputs=None, 80 | past_key_values=None, 81 | inputs_embeds=None, 82 | decoder_inputs_embeds=None, 83 | labels=None, 84 | use_cache=None, 85 | output_attentions=None, 86 | output_hidden_states=None, 87 | return_dict=None 88 | ): 89 | 90 | if input_ids is None: 91 | if encoder_outputs is None: 92 | raise ValueError("Encoder outputs is required when no input ids passed") 93 | else: 94 | encoder_outputs, attention_mask = self.multi_encode( 95 | input_ids=input_ids, 96 | attention_mask=attention_mask, 97 | return_dict = return_dict 98 | # encoder_outputs=encoder_outputs 99 | ) 100 | 101 | output = super().forward( 102 | input_ids=None, 103 | attention_mask=attention_mask, 104 | decoder_input_ids=decoder_input_ids, 105 | decoder_attention_mask=decoder_attention_mask, 106 | head_mask=head_mask, 107 | 
decoder_head_mask=decoder_head_mask, 108 | cross_attn_head_mask=cross_attn_head_mask, 109 | encoder_outputs=encoder_outputs, 110 | past_key_values=past_key_values, 111 | inputs_embeds=inputs_embeds, 112 | decoder_inputs_embeds=decoder_inputs_embeds, 113 | labels=labels, 114 | use_cache=use_cache, 115 | output_attentions=output_attentions, 116 | output_hidden_states=output_hidden_states, 117 | return_dict=return_dict 118 | ) 119 | 120 | return output 121 | -------------------------------------------------------------------------------- /multiencoder/report_training_runs.py: -------------------------------------------------------------------------------- 1 | """ 2 | * Copyright (c) 2021, salesforce.com, inc. 3 | * All rights reserved. 4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | # Logs results from multiple training runs 9 | 10 | import argparse 11 | import glob 12 | import json 13 | import logging 14 | import os 15 | import re 16 | from collections import defaultdict 17 | from operator import itemgetter 18 | from statistics import mean 19 | import argparse 20 | import glob 21 | import os 22 | import re 23 | import json 24 | from collections import defaultdict 25 | from statistics import mean, stdev 26 | import sys 27 | 28 | 29 | if __name__ == "__main__": 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument( 32 | 'train_dir_prefix', 33 | help='prefix of output directories for training runs being reported on' 34 | ) 35 | parser.add_argument( 36 | '--sort_metric', 37 | type=str, 38 | default='mean', 39 | ) 40 | parser.add_argument( 41 | '--reverse_sort', 42 | type=bool, 43 | default=False, 44 | ) 45 | 46 | args = parser.parse_args() 47 | metrics = ['eval_rouge1', 'eval_rouge2', 'eval_rougeLsum', 'eval_loss', 'eval_gen_len'] 48 | 49 | scores = defaultdict(list) 50 | n_runs = 0 51 | results = [] 52 | print(f"****** {args.train_dir_prefix} ******") 53 | for filepath in sorted(glob.glob(f"{args.train_dir_prefix}*")): 54 | m = re.match(rf'{re.escape(args.train_dir_prefix)}_?(\d+)$', filepath) 55 | if m: 56 | run_index = int(m.group(1)) 57 | results.append((run_index, filepath)) 58 | if args.sort_metric == 'mean': 59 | sort_func = lambda x: mean([x['eval_rouge1'], x['eval_rouge2'], x['eval_rougeLsum']]) 60 | else: 61 | sort_func = itemgetter(f'eval_{args.sort_metric}') 62 | for run_index, filepath in sorted(results): 63 | try: 64 | with open(os.path.join(filepath, "trainer_state.json")) as f: 65 | data = json.load(f) 66 | epoch_logs = [log for log in data['log_history'] if 'eval_loss' in log] 67 | sorted_epochs = sorted( 68 | epoch_logs, 69 | key=sort_func, 70 | reverse=args.reverse_sort) 71 | best_epoch = sorted_epochs[-1] 72 | best_checkpoint = f'{filepath}/checkpoint-{best_epoch["step"]}' 73 | print(best_checkpoint) 74 | for metric in metrics: 75 | scores[metric].append(best_epoch[metric]) 76 | n_runs += 1 77 | except FileNotFoundError: 78 | pass 79 | 80 | print(f"Num runs: {n_runs}") 81 | for metric in metrics: 82 | print(metric) 83 | print(f"\tMean: {mean(scores[metric]):.2f}") 84 | if n_runs > 1: 85 | print(f"\tStd Dev: {stdev(scores[metric]):.2f}") 86 | print(f"\tRange: {min(scores[metric]):.2f}-{max(scores[metric]):.2f}") 87 | print("\tScores: " + ", ".join(f"{score:.2f}" for score in scores[metric])) 88 | 89 | -------------------------------------------------------------------------------- /multiencoder/scripts/predict_qmsum_16_1024_strided_val.sh: 
-------------------------------------------------------------------------------- 1 | NAME=qmsum_16_1024_strided 2 | SPLIT=val 3 | NUM_RUNS=5 4 | START=1 5 | for RUN in $(seq $START $NUM_RUNS) 6 | do 7 | OUTPUT_DIR=output/${NAME}_${RUN} 8 | python -u train.py \ 9 | --test_file data/qmsum/preprocessed/${SPLIT}.jsonl \ 10 | --do_predict \ 11 | --model_name_or_path $OUTPUT_DIR/selected_checkpoint \ 12 | --output_dir ${OUTPUT_DIR}/selected_checkpoint/predition_logs_${SPLIT} \ 13 | --prediction_path ${OUTPUT_DIR}/selected_checkpoint/predictions.${SPLIT} \ 14 | --max_source_length 1024 \ 15 | --generation_max_len 256 \ 16 | --val_max_target_length 256 \ 17 | --overwrite_output_dir \ 18 | --per_device_eval_batch_size 1 \ 19 | --multiencoder_type bart \ 20 | --multiencoder_max_num_chunks 16 \ 21 | --multiencoder_stride \ 22 | --predict_with_generate 23 | done -------------------------------------------------------------------------------- /multiencoder/scripts/predict_qmsum_16_256_strided_val.sh: -------------------------------------------------------------------------------- 1 | NAME=qmsum_16_256_strided 2 | SPLIT=val 3 | NUM_RUNS=5 4 | START=1 5 | for RUN in $(seq $START $NUM_RUNS) 6 | do 7 | OUTPUT_DIR=output/${NAME}_${RUN} 8 | python -u train.py \ 9 | --test_file data/qmsum/preprocessed/${SPLIT}.jsonl \ 10 | --do_predict \ 11 | --model_name_or_path $OUTPUT_DIR/selected_checkpoint \ 12 | --output_dir ${OUTPUT_DIR}/selected_checkpoint/predition_logs_${SPLIT} \ 13 | --prediction_path ${OUTPUT_DIR}/selected_checkpoint/predictions.${SPLIT} \ 14 | --max_source_length 256 \ 15 | --generation_max_len 256 \ 16 | --val_max_target_length 256 \ 17 | --overwrite_output_dir \ 18 | --per_device_eval_batch_size 1 \ 19 | --multiencoder_type bart \ 20 | --multiencoder_max_num_chunks 16 \ 21 | --multiencoder_stride \ 22 | --predict_with_generate 23 | done -------------------------------------------------------------------------------- /multiencoder/scripts/predict_qmsum_16_512_strided_val.sh: -------------------------------------------------------------------------------- 1 | NAME=qmsum_16_512_strided 2 | SPLIT=val 3 | NUM_RUNS=5 4 | START=1 5 | for RUN in $(seq $START $NUM_RUNS) 6 | do 7 | OUTPUT_DIR=output/${NAME}_${RUN} 8 | python -u train.py \ 9 | --test_file data/qmsum/preprocessed/${SPLIT}.jsonl \ 10 | --do_predict \ 11 | --model_name_or_path $OUTPUT_DIR/selected_checkpoint \ 12 | --output_dir ${OUTPUT_DIR}/selected_checkpoint/predition_logs_${SPLIT} \ 13 | --prediction_path ${OUTPUT_DIR}/selected_checkpoint/predictions.${SPLIT} \ 14 | --max_source_length 512 \ 15 | --generation_max_len 256 \ 16 | --val_max_target_length 256 \ 17 | --overwrite_output_dir \ 18 | --per_device_eval_batch_size 1 \ 19 | --multiencoder_type bart \ 20 | --multiencoder_max_num_chunks 16 \ 21 | --multiencoder_stride \ 22 | --predict_with_generate 23 | done -------------------------------------------------------------------------------- /multiencoder/scripts/predict_qmsum_32_256_strided_val.sh: -------------------------------------------------------------------------------- 1 | NAME=qmsum_32_256_strided 2 | SPLIT=val 3 | NUM_RUNS=5 4 | START=1 5 | for RUN in $(seq $START $NUM_RUNS) 6 | do 7 | OUTPUT_DIR=output/${NAME}_${RUN} 8 | python -u train.py \ 9 | --test_file data/qmsum/preprocessed/${SPLIT}.jsonl \ 10 | --do_predict \ 11 | --model_name_or_path $OUTPUT_DIR/selected_checkpoint \ 12 | --output_dir ${OUTPUT_DIR}/selected_checkpoint/predition_logs_${SPLIT} \ 13 | --prediction_path 
${OUTPUT_DIR}/selected_checkpoint/predictions.${SPLIT} \ 14 | --max_source_length 256 \ 15 | --generation_max_len 256 \ 16 | --val_max_target_length 256 \ 17 | --overwrite_output_dir \ 18 | --per_device_eval_batch_size 1 \ 19 | --multiencoder_type bart \ 20 | --multiencoder_max_num_chunks 32 \ 21 | --multiencoder_stride \ 22 | --predict_with_generate 23 | done -------------------------------------------------------------------------------- /multiencoder/scripts/predict_qmsum_32_512_nostrided_val.sh: -------------------------------------------------------------------------------- 1 | NAME=qmsum_32_512_nostrided 2 | SPLIT=val 3 | NUM_RUNS=5 4 | START=1 5 | for RUN in $(seq $START $NUM_RUNS) 6 | do 7 | OUTPUT_DIR=output/${NAME}_${RUN} 8 | python -u train.py \ 9 | --test_file data/qmsum/preprocessed/${SPLIT}.jsonl \ 10 | --do_predict \ 11 | --model_name_or_path $OUTPUT_DIR/selected_checkpoint \ 12 | --output_dir ${OUTPUT_DIR}/selected_checkpoint/predition_logs_${SPLIT} \ 13 | --prediction_path ${OUTPUT_DIR}/selected_checkpoint/predictions.${SPLIT} \ 14 | --max_source_length 512 \ 15 | --generation_max_len 256 \ 16 | --val_max_target_length 256 \ 17 | --overwrite_output_dir \ 18 | --per_device_eval_batch_size 1 \ 19 | --multiencoder_type bart \ 20 | --multiencoder_max_num_chunks 32 \ 21 | --predict_with_generate 22 | done -------------------------------------------------------------------------------- /multiencoder/scripts/predict_qmsum_32_512_strided_test.sh: -------------------------------------------------------------------------------- 1 | NAME=qmsum_32_512_strided 2 | SPLIT=test 3 | NUM_RUNS=5 4 | START=1 5 | for RUN in $(seq $START $NUM_RUNS) 6 | do 7 | OUTPUT_DIR=output/${NAME}_${RUN} 8 | python -u train.py \ 9 | --test_file data/qmsum/preprocessed/${SPLIT}.jsonl \ 10 | --do_predict \ 11 | --model_name_or_path $OUTPUT_DIR/selected_checkpoint \ 12 | --output_dir ${OUTPUT_DIR}/selected_checkpoint/predition_logs_${SPLIT} \ 13 | --prediction_path ${OUTPUT_DIR}/selected_checkpoint/predictions.${SPLIT} \ 14 | --max_source_length 512 \ 15 | --generation_max_len 256 \ 16 | --val_max_target_length 256 \ 17 | --overwrite_output_dir \ 18 | --per_device_eval_batch_size 1 \ 19 | --multiencoder_type bart \ 20 | --multiencoder_max_num_chunks 32 \ 21 | --multiencoder_stride \ 22 | --predict_with_generate 23 | done -------------------------------------------------------------------------------- /multiencoder/scripts/predict_qmsum_32_512_strided_val.sh: -------------------------------------------------------------------------------- 1 | NAME=qmsum_32_512_strided 2 | SPLIT=val 3 | NUM_RUNS=5 4 | START=1 5 | for RUN in $(seq $START $NUM_RUNS) 6 | do 7 | OUTPUT_DIR=output/${NAME}_${RUN} 8 | python -u train.py \ 9 | --test_file data/qmsum/preprocessed/${SPLIT}.jsonl \ 10 | --do_predict \ 11 | --model_name_or_path $OUTPUT_DIR/selected_checkpoint \ 12 | --output_dir ${OUTPUT_DIR}/selected_checkpoint/predition_logs_${SPLIT} \ 13 | --prediction_path ${OUTPUT_DIR}/selected_checkpoint/predictions.${SPLIT} \ 14 | --max_source_length 512 \ 15 | --generation_max_len 256 \ 16 | --val_max_target_length 256 \ 17 | --overwrite_output_dir \ 18 | --per_device_eval_batch_size 1 \ 19 | --multiencoder_type bart \ 20 | --multiencoder_max_num_chunks 32 \ 21 | --multiencoder_stride \ 22 | --predict_with_generate 23 | done -------------------------------------------------------------------------------- /multiencoder/scripts/predict_qmsum_32_512_strided_wikisum_test.sh: 
-------------------------------------------------------------------------------- 1 | NAME=qmsum_32_512_strided_wikisum 2 | SPLIT=test 3 | NUM_RUNS=5 4 | START=1 5 | for RUN in $(seq $START $NUM_RUNS) 6 | do 7 | OUTPUT_DIR=output/${NAME}_${RUN} 8 | python -u train.py \ 9 | --test_file data/qmsum/preprocessed/${SPLIT}.jsonl \ 10 | --do_predict \ 11 | --model_name_or_path $OUTPUT_DIR/selected_checkpoint \ 12 | --output_dir ${OUTPUT_DIR}/selected_checkpoint/predition_logs_${SPLIT} \ 13 | --prediction_path ${OUTPUT_DIR}/selected_checkpoint/predictions.${SPLIT} \ 14 | --max_source_length 512 \ 15 | --generation_max_len 256 \ 16 | --val_max_target_length 256 \ 17 | --overwrite_output_dir \ 18 | --per_device_eval_batch_size 1 \ 19 | --multiencoder_type bart \ 20 | --multiencoder_max_num_chunks 32 \ 21 | --multiencoder_stride \ 22 | --predict_with_generate 23 | done -------------------------------------------------------------------------------- /multiencoder/scripts/predict_qmsum_32_512_strided_wikisum_val.sh: -------------------------------------------------------------------------------- 1 | NAME=qmsum_32_512_strided_wikisum 2 | SPLIT=val 3 | NUM_RUNS=5 4 | START=1 5 | for RUN in $(seq $START $NUM_RUNS) 6 | do 7 | OUTPUT_DIR=output/${NAME}_${RUN} 8 | python -u train.py \ 9 | --test_file data/qmsum/preprocessed/${SPLIT}.jsonl \ 10 | --do_predict \ 11 | --model_name_or_path $OUTPUT_DIR/selected_checkpoint \ 12 | --output_dir ${OUTPUT_DIR}/selected_checkpoint/predition_logs_${SPLIT} \ 13 | --prediction_path ${OUTPUT_DIR}/selected_checkpoint/predictions.${SPLIT} \ 14 | --max_source_length 512 \ 15 | --generation_max_len 256 \ 16 | --val_max_target_length 256 \ 17 | --overwrite_output_dir \ 18 | --per_device_eval_batch_size 1 \ 19 | --multiencoder_type bart \ 20 | --multiencoder_max_num_chunks 32 \ 21 | --multiencoder_stride \ 22 | --predict_with_generate 23 | done -------------------------------------------------------------------------------- /multiencoder/scripts/predict_qmsum_4_1024_strided_val.sh: -------------------------------------------------------------------------------- 1 | NAME=qmsum_4_1024_strided 2 | SPLIT=val 3 | NUM_RUNS=5 4 | START=1 5 | for RUN in $(seq $START $NUM_RUNS) 6 | do 7 | OUTPUT_DIR=output/${NAME}_${RUN} 8 | python -u train.py \ 9 | --test_file data/qmsum/preprocessed/${SPLIT}.jsonl \ 10 | --do_predict \ 11 | --model_name_or_path $OUTPUT_DIR/selected_checkpoint \ 12 | --output_dir ${OUTPUT_DIR}/selected_checkpoint/predition_logs_${SPLIT} \ 13 | --prediction_path ${OUTPUT_DIR}/selected_checkpoint/predictions.${SPLIT} \ 14 | --max_source_length 1024 \ 15 | --generation_max_len 256 \ 16 | --val_max_target_length 256 \ 17 | --overwrite_output_dir \ 18 | --per_device_eval_batch_size 1 \ 19 | --multiencoder_type bart \ 20 | --multiencoder_max_num_chunks 4 \ 21 | --multiencoder_stride \ 22 | --predict_with_generate 23 | done -------------------------------------------------------------------------------- /multiencoder/scripts/predict_qmsum_64_256_strided_val.sh: -------------------------------------------------------------------------------- 1 | NAME=qmsum_64_256_strided 2 | SPLIT=val 3 | NUM_RUNS=5 4 | START=1 5 | for RUN in $(seq $START $NUM_RUNS) 6 | do 7 | OUTPUT_DIR=output/${NAME}_${RUN} 8 | python -u train.py \ 9 | --test_file data/qmsum/preprocessed/${SPLIT}.jsonl \ 10 | --do_predict \ 11 | --model_name_or_path $OUTPUT_DIR/selected_checkpoint \ 12 | --output_dir ${OUTPUT_DIR}/selected_checkpoint/predition_logs_${SPLIT} \ 13 | --prediction_path 
${OUTPUT_DIR}/selected_checkpoint/predictions.${SPLIT} \ 14 | --max_source_length 256 \ 15 | --generation_max_len 256 \ 16 | --val_max_target_length 256 \ 17 | --overwrite_output_dir \ 18 | --per_device_eval_batch_size 1 \ 19 | --multiencoder_type bart \ 20 | --multiencoder_max_num_chunks 64 \ 21 | --multiencoder_stride \ 22 | --predict_with_generate 23 | done -------------------------------------------------------------------------------- /multiencoder/scripts/predict_qmsum_8_1024_strided_val.sh: -------------------------------------------------------------------------------- 1 | NAME=qmsum_8_1024_strided 2 | SPLIT=val 3 | NUM_RUNS=5 4 | START=1 5 | for RUN in $(seq $START $NUM_RUNS) 6 | do 7 | OUTPUT_DIR=output/${NAME}_${RUN} 8 | python -u train.py \ 9 | --test_file data/qmsum/preprocessed/${SPLIT}.jsonl \ 10 | --do_predict \ 11 | --model_name_or_path $OUTPUT_DIR/selected_checkpoint \ 12 | --output_dir ${OUTPUT_DIR}/selected_checkpoint/predition_logs_${SPLIT} \ 13 | --prediction_path ${OUTPUT_DIR}/selected_checkpoint/predictions.${SPLIT} \ 14 | --max_source_length 1024 \ 15 | --generation_max_len 256 \ 16 | --val_max_target_length 256 \ 17 | --overwrite_output_dir \ 18 | --per_device_eval_batch_size 1 \ 19 | --multiencoder_type bart \ 20 | --multiencoder_max_num_chunks 8 \ 21 | --multiencoder_stride \ 22 | --predict_with_generate 23 | done -------------------------------------------------------------------------------- /multiencoder/scripts/predict_qmsum_8_512_strided_val.sh: -------------------------------------------------------------------------------- 1 | NAME=qmsum_8_512_strided 2 | SPLIT=val 3 | NUM_RUNS=5 4 | START=1 5 | for RUN in $(seq $START $NUM_RUNS) 6 | do 7 | OUTPUT_DIR=output/${NAME}_${RUN} 8 | python -u train.py \ 9 | --test_file data/qmsum/preprocessed/${SPLIT}.jsonl \ 10 | --do_predict \ 11 | --model_name_or_path $OUTPUT_DIR/selected_checkpoint \ 12 | --output_dir ${OUTPUT_DIR}/selected_checkpoint/predition_logs_${SPLIT} \ 13 | --prediction_path ${OUTPUT_DIR}/selected_checkpoint/predictions.${SPLIT} \ 14 | --max_source_length 512 \ 15 | --generation_max_len 256 \ 16 | --val_max_target_length 256 \ 17 | --overwrite_output_dir \ 18 | --per_device_eval_batch_size 1 \ 19 | --multiencoder_type bart \ 20 | --multiencoder_max_num_chunks 8 \ 21 | --multiencoder_stride \ 22 | --predict_with_generate 23 | done -------------------------------------------------------------------------------- /multiencoder/scripts/predict_val.sh: -------------------------------------------------------------------------------- 1 | SPLIT=val 2 | for NAME in \ 3 | qmsum_16_256_strided \ 4 | qmsum_32_256_strided \ 5 | qmsum_64_256_strided \ 6 | qmsum_8_512_strided \ 7 | qmsum_16_512_strided \ 8 | qmsum_32_512_strided \ 9 | qmsum_4_1024_strided \ 10 | qmsum_8_1024_strided \ 11 | qmsum_16_1024_strided \ 12 | qmsum_32_512_nostrided \ 13 | qmsum_32_512_strided_wikisum 14 | do 15 | bash scripts/predict_${NAME}_${SPLIT}.sh 16 | done -------------------------------------------------------------------------------- /multiencoder/scripts/report_rouge_test.sh: -------------------------------------------------------------------------------- 1 | SPLIT=test 2 | export ROUGE_HOME=/export/home/query-focused-conv-summ/rouge/ROUGE-1.5.5/ 3 | for NAME in \ 4 | qmsum_32_512_strided_wikisum \ 5 | qmsum_32_512_strided 6 | do 7 | echo "************************************************************" 8 | echo $NAME 9 | python ../rouge/report_rouge.py \ 10 | --ref-path ../data/${SPLIT}.target \ 11 | --pred-paths \ 12 | 
output/${NAME}_1/selected_checkpoint/predictions.${SPLIT} \ 13 | output/${NAME}_2/selected_checkpoint/predictions.${SPLIT} \ 14 | output/${NAME}_3/selected_checkpoint/predictions.${SPLIT} \ 15 | output/${NAME}_4/selected_checkpoint/predictions.${SPLIT} \ 16 | output/${NAME}_5/selected_checkpoint/predictions.${SPLIT} 17 | done -------------------------------------------------------------------------------- /multiencoder/scripts/report_rouge_val.sh: -------------------------------------------------------------------------------- 1 | SPLIT=val 2 | export ROUGE_HOME=/export/home/query-focused-conv-summ/rouge/ROUGE-1.5.5/ 3 | for NAME in \ 4 | qmsum_16_256_strided \ 5 | qmsum_32_256_strided \ 6 | qmsum_64_256_strided \ 7 | qmsum_8_512_strided \ 8 | qmsum_16_512_strided \ 9 | qmsum_32_512_strided \ 10 | qmsum_4_1024_strided \ 11 | qmsum_8_1024_strided \ 12 | qmsum_16_1024_strided \ 13 | qmsum_32_512_nostrided \ 14 | qmsum_32_512_strided_wikisum 15 | do 16 | echo "************************************************************" 17 | echo $NAME 18 | python ../rouge/report_rouge.py \ 19 | --ref-path ../data/${SPLIT}.target \ 20 | --pred-paths \ 21 | output/${NAME}_1/selected_checkpoint/predictions.${SPLIT} \ 22 | output/${NAME}_2/selected_checkpoint/predictions.${SPLIT} \ 23 | output/${NAME}_3/selected_checkpoint/predictions.${SPLIT} \ 24 | output/${NAME}_4/selected_checkpoint/predictions.${SPLIT} \ 25 | output/${NAME}_5/selected_checkpoint/predictions.${SPLIT} 26 | done -------------------------------------------------------------------------------- /multiencoder/scripts/select_checkpoints.sh: -------------------------------------------------------------------------------- 1 | for MODEL_NAME in \ 2 | qmsum_16_256_strided \ 3 | qmsum_32_256_strided \ 4 | qmsum_64_256_strided \ 5 | qmsum_8_512_strided \ 6 | qmsum_16_512_strided \ 7 | qmsum_32_512_strided \ 8 | qmsum_4_1024_strided \ 9 | qmsum_8_1024_strided \ 10 | qmsum_16_1024_strided \ 11 | qmsum_32_512_nostrided \ 12 | qmsum_32_512_strided_wikisum 13 | do 14 | python select_checkpoints.py output/$MODEL_NAME 15 | done -------------------------------------------------------------------------------- /multiencoder/scripts/train_qmsum_16_1024_strided.sh: -------------------------------------------------------------------------------- 1 | NAME=qmsum_16_1024_strided 2 | NUM_RUNS=5 3 | START=1 4 | for RUN in $(seq $START $NUM_RUNS) 5 | do 6 | python -u train.py \ 7 | --train_file data/qmsum/preprocessed/train.jsonl \ 8 | --validation_file data/qmsum/preprocessed/val.jsonl \ 9 | --do_train \ 10 | --do_eval \ 11 | --learning_rate 0.000005 \ 12 | --gradient_checkpointing \ 13 | --model_name_or_path facebook/bart-large \ 14 | --metric_for_best_model eval_mean_rouge \ 15 | --output_dir output/${NAME}_${RUN} \ 16 | --per_device_train_batch_size 1 \ 17 | --max_source_length 1024 \ 18 | --generation_max_len 256 \ 19 | --val_max_target_length 256 \ 20 | --overwrite_output_dir \ 21 | --per_device_eval_batch_size 1 \ 22 | --multiencoder_type bart \ 23 | --multiencoder_max_num_chunks 16 \ 24 | --multiencoder_stride \ 25 | --predict_with_generate \ 26 | --evaluation_strategy epoch \ 27 | --num_train_epochs 10 \ 28 | --save_strategy epoch \ 29 | --logging_strategy epoch \ 30 | --load_best_model_at_end \ 31 | --compute_rouge_for_train True \ 32 | --seed $RUN &> ${NAME}_${RUN}.out 33 | done 34 | # --metric_for_best_model rouge1_plus_rouge2 \ 35 | 36 | -------------------------------------------------------------------------------- 
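Note: the train_qmsum_* scripts in this directory all follow the same template and differ only in the chunk count (--multiencoder_max_num_chunks), the per-chunk source length (--max_source_length, judging by the <num_chunks>_<chunk_size> naming of the scripts), and whether --multiencoder_stride is passed. The sketch below is illustrative only and not part of the repository: it replays the flags visible in these scripts through subprocess, so a single parametrized launcher could stand in for the near-identical per-configuration scripts (assuming the same working directory, multiencoder/, and the data/qmsum/preprocessed/ layout).

# Illustrative launcher sketch -- not part of this repository.
import subprocess

def launch_run(num_chunks: int, chunk_size: int, stride: bool, run: int) -> None:
    name = f"qmsum_{num_chunks}_{chunk_size}_{'strided' if stride else 'nostrided'}"
    args = [
        "python", "-u", "train.py",
        "--train_file", "data/qmsum/preprocessed/train.jsonl",
        "--validation_file", "data/qmsum/preprocessed/val.jsonl",
        "--do_train", "--do_eval",
        "--learning_rate", "0.000005",
        "--gradient_checkpointing",
        "--model_name_or_path", "facebook/bart-large",
        "--metric_for_best_model", "eval_mean_rouge",
        "--output_dir", f"output/{name}_{run}",
        "--per_device_train_batch_size", "1",
        "--per_device_eval_batch_size", "1",
        "--max_source_length", str(chunk_size),
        "--generation_max_len", "256",
        "--val_max_target_length", "256",
        "--overwrite_output_dir",
        "--multiencoder_type", "bart",
        "--multiencoder_max_num_chunks", str(num_chunks),
        "--predict_with_generate",
        "--evaluation_strategy", "epoch",
        "--num_train_epochs", "10",
        "--save_strategy", "epoch",
        "--logging_strategy", "epoch",
        "--load_best_model_at_end",
        "--compute_rouge_for_train", "True",
        "--seed", str(run),
    ]
    if stride:
        args.append("--multiencoder_stride")
    subprocess.run(args, check=True)

if __name__ == "__main__":
    # Mirrors NUM_RUNS=5, START=1 in the shell scripts.
    for run in range(1, 6):
        launch_run(num_chunks=16, chunk_size=256, stride=True, run=run)
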
/multiencoder/scripts/train_qmsum_16_256_strided.sh: -------------------------------------------------------------------------------- 1 | NAME=qmsum_16_256_strided 2 | NUM_RUNS=5 3 | START=1 4 | for RUN in $(seq $START $NUM_RUNS) 5 | do 6 | python -u train.py \ 7 | --train_file data/qmsum/preprocessed/train.jsonl \ 8 | --validation_file data/qmsum/preprocessed/val.jsonl \ 9 | --do_train \ 10 | --do_eval \ 11 | --learning_rate 0.000005 \ 12 | --gradient_checkpointing \ 13 | --model_name_or_path facebook/bart-large \ 14 | --metric_for_best_model eval_mean_rouge \ 15 | --output_dir output/${NAME}_${RUN} \ 16 | --per_device_train_batch_size 1 \ 17 | --max_source_length 256 \ 18 | --generation_max_len 256 \ 19 | --val_max_target_length 256 \ 20 | --overwrite_output_dir \ 21 | --per_device_eval_batch_size 1 \ 22 | --multiencoder_type bart \ 23 | --multiencoder_max_num_chunks 16 \ 24 | --multiencoder_stride \ 25 | --predict_with_generate \ 26 | --evaluation_strategy epoch \ 27 | --num_train_epochs 10 \ 28 | --save_strategy epoch \ 29 | --logging_strategy epoch \ 30 | --load_best_model_at_end \ 31 | --compute_rouge_for_train True \ 32 | --seed $RUN &> ${NAME}_${RUN}.out 33 | done 34 | # --metric_for_best_model rouge1_plus_rouge2 \ 35 | 36 | -------------------------------------------------------------------------------- /multiencoder/scripts/train_qmsum_16_256_strided_catchup.sh: -------------------------------------------------------------------------------- 1 | NAME=qmsum_16_256_strided 2 | NUM_RUNS=5 3 | START=1 4 | for RUN in $(seq $START $NUM_RUNS) 5 | do 6 | python -u train.py \ 7 | --train_file data/qmsum/preprocessed/train.jsonl \ 8 | --validation_file data/qmsum/preprocessed/val.jsonl \ 9 | --do_train \ 10 | --do_eval \ 11 | --learning_rate 0.000005 \ 12 | --gradient_checkpointing \ 13 | --model_name_or_path facebook/bart-large \ 14 | --metric_for_best_model eval_mean_rouge \ 15 | --output_dir output/${NAME}_${RUN} \ 16 | --per_device_train_batch_size 1 \ 17 | --max_source_length 256 \ 18 | --generation_max_len 256 \ 19 | --val_max_target_length 256 \ 20 | --overwrite_output_dir \ 21 | --per_device_eval_batch_size 1 \ 22 | --multiencoder_type bart \ 23 | --multiencoder_max_num_chunks 16 \ 24 | --multiencoder_stride \ 25 | --predict_with_generate \ 26 | --evaluation_strategy epoch \ 27 | --num_train_epochs 10 \ 28 | --save_strategy epoch \ 29 | --logging_strategy epoch \ 30 | --load_best_model_at_end \ 31 | --compute_rouge_for_train True \ 32 | --seed $RUN &> ${NAME}_${RUN}.out 33 | done 34 | # --metric_for_best_model rouge1_plus_rouge2 \ 35 | 36 | -------------------------------------------------------------------------------- /multiencoder/scripts/train_qmsum_16_512_strided.sh: -------------------------------------------------------------------------------- 1 | NAME=qmsum_16_512_strided 2 | NUM_RUNS=5 3 | START=1 4 | for RUN in $(seq $START $NUM_RUNS) 5 | do 6 | python -u train.py \ 7 | --train_file data/qmsum/preprocessed/train.jsonl \ 8 | --validation_file data/qmsum/preprocessed/val.jsonl \ 9 | --do_train \ 10 | --do_eval \ 11 | --learning_rate 0.000005 \ 12 | --gradient_checkpointing \ 13 | --model_name_or_path facebook/bart-large \ 14 | --metric_for_best_model eval_mean_rouge \ 15 | --output_dir output/${NAME}_${RUN} \ 16 | --per_device_train_batch_size 1 \ 17 | --max_source_length 512 \ 18 | --generation_max_len 256 \ 19 | --val_max_target_length 256 \ 20 | --overwrite_output_dir \ 21 | --per_device_eval_batch_size 1 \ 22 | --multiencoder_type bart \ 23 | 
--multiencoder_max_num_chunks 16 \ 24 | --multiencoder_stride \ 25 | --predict_with_generate \ 26 | --evaluation_strategy epoch \ 27 | --num_train_epochs 10 \ 28 | --save_strategy epoch \ 29 | --logging_strategy epoch \ 30 | --load_best_model_at_end \ 31 | --compute_rouge_for_train True \ 32 | --seed $RUN &> ${NAME}_${RUN}.out 33 | done -------------------------------------------------------------------------------- /multiencoder/scripts/train_qmsum_32_256_strided.sh: -------------------------------------------------------------------------------- 1 | NAME=qmsum_32_256_strided 2 | NUM_RUNS=5 3 | START=1 4 | for RUN in $(seq $START $NUM_RUNS) 5 | do 6 | python -u train.py \ 7 | --train_file data/qmsum/preprocessed/train.jsonl \ 8 | --validation_file data/qmsum/preprocessed/val.jsonl \ 9 | --do_train \ 10 | --do_eval \ 11 | --learning_rate 0.000005 \ 12 | --gradient_checkpointing \ 13 | --model_name_or_path facebook/bart-large \ 14 | --metric_for_best_model eval_mean_rouge \ 15 | --output_dir output/${NAME}_${RUN} \ 16 | --per_device_train_batch_size 1 \ 17 | --max_source_length 256 \ 18 | --generation_max_len 256 \ 19 | --val_max_target_length 256 \ 20 | --overwrite_output_dir \ 21 | --per_device_eval_batch_size 1 \ 22 | --multiencoder_type bart \ 23 | --multiencoder_max_num_chunks 32 \ 24 | --multiencoder_stride \ 25 | --predict_with_generate \ 26 | --evaluation_strategy epoch \ 27 | --num_train_epochs 10 \ 28 | --save_strategy epoch \ 29 | --logging_strategy epoch \ 30 | --load_best_model_at_end \ 31 | --compute_rouge_for_train True \ 32 | --seed $RUN &> ${NAME}_${RUN}.out 33 | done 34 | # --metric_for_best_model rouge1_plus_rouge2 \ 35 | 36 | -------------------------------------------------------------------------------- /multiencoder/scripts/train_qmsum_32_512_nostrided.sh: -------------------------------------------------------------------------------- 1 | NAME=qmsum_32_512_nostrided 2 | NUM_RUNS=5 3 | START=1 4 | for RUN in $(seq $START $NUM_RUNS) 5 | do 6 | python -u train.py \ 7 | --train_file data/qmsum/preprocessed/train.jsonl \ 8 | --validation_file data/qmsum/preprocessed/val.jsonl \ 9 | --do_train \ 10 | --do_eval \ 11 | --learning_rate 0.000005 \ 12 | --gradient_checkpointing \ 13 | --model_name_or_path facebook/bart-large \ 14 | --metric_for_best_model eval_mean_rouge \ 15 | --output_dir output/${NAME}_${RUN} \ 16 | --per_device_train_batch_size 1 \ 17 | --max_source_length 512 \ 18 | --generation_max_len 256 \ 19 | --val_max_target_length 256 \ 20 | --overwrite_output_dir \ 21 | --per_device_eval_batch_size 1 \ 22 | --multiencoder_type bart \ 23 | --multiencoder_max_num_chunks 32 \ 24 | --predict_with_generate \ 25 | --evaluation_strategy epoch \ 26 | --num_train_epochs 10 \ 27 | --save_strategy epoch \ 28 | --logging_strategy epoch \ 29 | --load_best_model_at_end \ 30 | --compute_rouge_for_train True \ 31 | --seed $RUN &> ${NAME}_${RUN}.out 32 | done 33 | # --metric_for_best_model rouge1_plus_rouge2 \ 34 | 35 | -------------------------------------------------------------------------------- /multiencoder/scripts/train_qmsum_32_512_strided.sh: -------------------------------------------------------------------------------- 1 | NAME=qmsum_32_512_strided 2 | NUM_RUNS=5 3 | START=1 4 | for RUN in $(seq $START $NUM_RUNS) 5 | do 6 | python -u train.py \ 7 | --train_file data/qmsum/preprocessed/train.jsonl \ 8 | --validation_file data/qmsum/preprocessed/val.jsonl \ 9 | --do_train \ 10 | --do_eval \ 11 | --learning_rate 0.000005 \ 12 | --gradient_checkpointing \ 13 | 
--model_name_or_path facebook/bart-large \ 14 | --metric_for_best_model eval_mean_rouge \ 15 | --output_dir output/${NAME}_${RUN} \ 16 | --per_device_train_batch_size 1 \ 17 | --max_source_length 512 \ 18 | --generation_max_len 256 \ 19 | --val_max_target_length 256 \ 20 | --overwrite_output_dir \ 21 | --per_device_eval_batch_size 1 \ 22 | --multiencoder_type bart \ 23 | --multiencoder_max_num_chunks 32 \ 24 | --multiencoder_stride \ 25 | --predict_with_generate \ 26 | --evaluation_strategy epoch \ 27 | --num_train_epochs 10 \ 28 | --save_strategy epoch \ 29 | --logging_strategy epoch \ 30 | --load_best_model_at_end \ 31 | --compute_rouge_for_train True \ 32 | --seed $RUN &> ${NAME}_${RUN}.out 33 | done 34 | # --metric_for_best_model rouge1_plus_rouge2 \ 35 | 36 | -------------------------------------------------------------------------------- /multiencoder/scripts/train_qmsum_32_512_strided_wikisum.sh: -------------------------------------------------------------------------------- 1 | NAME=qmsum_32_512_strided_wikisum 2 | NUM_RUNS=5 3 | START=1 4 | for RUN in $(seq $START $NUM_RUNS) 5 | do 6 | python -u train.py \ 7 | --train_file data/qmsum/preprocessed/train.jsonl \ 8 | --validation_file data/qmsum/preprocessed/val.jsonl \ 9 | --do_train \ 10 | --do_eval \ 11 | --learning_rate 0.000005 \ 12 | --gradient_checkpointing \ 13 | --model_name_or_path bart-wikisum \ 14 | --tokenizer_name facebook/bart-large \ 15 | --output_dir output/${NAME}_${RUN} \ 16 | --per_device_train_batch_size 1 \ 17 | --max_source_length 512 \ 18 | --generation_max_len 256 \ 19 | --val_max_target_length 256 \ 20 | --overwrite_output_dir \ 21 | --per_device_eval_batch_size 1 \ 22 | --multiencoder_type bart \ 23 | --multiencoder_max_num_chunks 32 \ 24 | --multiencoder_stride \ 25 | --predict_with_generate \ 26 | --evaluation_strategy epoch \ 27 | --num_train_epochs 10 \ 28 | --save_strategy epoch \ 29 | --logging_strategy epoch \ 30 | --load_best_model_at_end \ 31 | --compute_rouge_for_train True \ 32 | --seed $RUN &> ${NAME}_${RUN}.out 33 | done 34 | # --metric_for_best_model rouge1_plus_rouge2 \ 35 | 36 | -------------------------------------------------------------------------------- /multiencoder/scripts/train_qmsum_4_1024_strided.sh: -------------------------------------------------------------------------------- 1 | NAME=qmsum_4_1024_strided 2 | NUM_RUNS=5 3 | START=1 4 | for RUN in $(seq $START $NUM_RUNS) 5 | do 6 | python -u train.py \ 7 | --train_file data/qmsum/preprocessed/train.jsonl \ 8 | --validation_file data/qmsum/preprocessed/val.jsonl \ 9 | --do_train \ 10 | --do_eval \ 11 | --learning_rate 0.000005 \ 12 | --gradient_checkpointing \ 13 | --model_name_or_path facebook/bart-large \ 14 | --metric_for_best_model eval_mean_rouge \ 15 | --output_dir output/${NAME}_${RUN} \ 16 | --per_device_train_batch_size 1 \ 17 | --max_source_length 1024 \ 18 | --generation_max_len 256 \ 19 | --val_max_target_length 256 \ 20 | --overwrite_output_dir \ 21 | --per_device_eval_batch_size 1 \ 22 | --multiencoder_type bart \ 23 | --multiencoder_max_num_chunks 4 \ 24 | --multiencoder_stride \ 25 | --predict_with_generate \ 26 | --evaluation_strategy epoch \ 27 | --num_train_epochs 10 \ 28 | --save_strategy epoch \ 29 | --logging_strategy epoch \ 30 | --load_best_model_at_end \ 31 | --compute_rouge_for_train True \ 32 | --seed $RUN &> ${NAME}_${RUN}.out 33 | done 34 | # --metric_for_best_model rouge1_plus_rouge2 \ 35 | 36 | -------------------------------------------------------------------------------- 
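The chunk-count / chunk-size pairs used by these scripts (e.g. 4x1024, 8x512, 16x256, ..., 64x256) trade off the number of encoded chunks against the per-chunk length. Per ChunkTokenizer in dataset.py, each chunk holds the tokenized query followed by chunk_size minus query-length source tokens, so total source coverage is roughly max_num_chunks * (chunk_size - query_length); striding adds up to max_num_chunks - 1 extra chunks offset by half a chunk. A small arithmetic sketch (illustrative only; the ~20-token query length is an assumption for concreteness):

# Rough source-token budget per configuration, following ChunkTokenizer above.
def source_budget(max_num_chunks: int, chunk_size: int, query_len: int = 20) -> int:
    suffix_chunk_size = chunk_size - query_len   # source tokens per chunk
    return max_num_chunks * suffix_chunk_size    # suffix_total_size in dataset.py

for n, size in [(4, 1024), (8, 512), (16, 256), (16, 512), (32, 512), (64, 256)]:
    print(f"{n:>2} chunks x {size}: ~{source_budget(n, size)} source tokens")
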
/multiencoder/scripts/train_qmsum_64_256_strided.sh: -------------------------------------------------------------------------------- 1 | NAME=qmsum_64_256_strided 2 | NUM_RUNS=5 3 | START=1 4 | for RUN in $(seq $START $NUM_RUNS) 5 | do 6 | python -u train.py \ 7 | --train_file data/qmsum/preprocessed/train.jsonl \ 8 | --validation_file data/qmsum/preprocessed/val.jsonl \ 9 | --do_train \ 10 | --do_eval \ 11 | --learning_rate 0.000005 \ 12 | --gradient_checkpointing \ 13 | --model_name_or_path facebook/bart-large \ 14 | --metric_for_best_model eval_mean_rouge \ 15 | --output_dir output/${NAME}_${RUN} \ 16 | --per_device_train_batch_size 1 \ 17 | --max_source_length 256 \ 18 | --generation_max_len 256 \ 19 | --val_max_target_length 256 \ 20 | --overwrite_output_dir \ 21 | --per_device_eval_batch_size 1 \ 22 | --multiencoder_type bart \ 23 | --multiencoder_max_num_chunks 64 \ 24 | --multiencoder_stride \ 25 | --predict_with_generate \ 26 | --evaluation_strategy epoch \ 27 | --num_train_epochs 10 \ 28 | --save_strategy epoch \ 29 | --logging_strategy epoch \ 30 | --load_best_model_at_end \ 31 | --compute_rouge_for_train True \ 32 | --seed $RUN &> ${NAME}_${RUN}.out 33 | done 34 | # --metric_for_best_model rouge1_plus_rouge2 \ 35 | 36 | -------------------------------------------------------------------------------- /multiencoder/scripts/train_qmsum_8_1024_strided.sh: -------------------------------------------------------------------------------- 1 | NAME=qmsum_8_1024_strided 2 | NUM_RUNS=5 3 | START=1 4 | for RUN in $(seq $START $NUM_RUNS) 5 | do 6 | python -u train.py \ 7 | --train_file data/qmsum/preprocessed/train.jsonl \ 8 | --validation_file data/qmsum/preprocessed/val.jsonl \ 9 | --do_train \ 10 | --do_eval \ 11 | --learning_rate 0.000005 \ 12 | --gradient_checkpointing \ 13 | --model_name_or_path facebook/bart-large \ 14 | --metric_for_best_model eval_mean_rouge \ 15 | --output_dir output/${NAME}_${RUN} \ 16 | --per_device_train_batch_size 1 \ 17 | --max_source_length 1024 \ 18 | --generation_max_len 256 \ 19 | --val_max_target_length 256 \ 20 | --overwrite_output_dir \ 21 | --per_device_eval_batch_size 1 \ 22 | --multiencoder_type bart \ 23 | --multiencoder_max_num_chunks 8 \ 24 | --multiencoder_stride \ 25 | --predict_with_generate \ 26 | --evaluation_strategy epoch \ 27 | --num_train_epochs 10 \ 28 | --save_strategy epoch \ 29 | --logging_strategy epoch \ 30 | --load_best_model_at_end \ 31 | --compute_rouge_for_train True \ 32 | --seed $RUN &> ${NAME}_${RUN}.out 33 | done 34 | # --metric_for_best_model rouge1_plus_rouge2 \ 35 | 36 | -------------------------------------------------------------------------------- /multiencoder/scripts/train_qmsum_8_512_strided.sh: -------------------------------------------------------------------------------- 1 | NAME=qmsum_8_512_strided 2 | NUM_RUNS=5 3 | START=1 4 | for RUN in $(seq $START $NUM_RUNS) 5 | do 6 | python -u train.py \ 7 | --train_file data/qmsum/preprocessed/train.jsonl \ 8 | --validation_file data/qmsum/preprocessed/val.jsonl \ 9 | --do_train \ 10 | --do_eval \ 11 | --learning_rate 0.000005 \ 12 | --gradient_checkpointing \ 13 | --model_name_or_path facebook/bart-large \ 14 | --metric_for_best_model eval_mean_rouge \ 15 | --output_dir output/${NAME}_${RUN} \ 16 | --per_device_train_batch_size 1 \ 17 | --max_source_length 512 \ 18 | --generation_max_len 256 \ 19 | --val_max_target_length 256 \ 20 | --overwrite_output_dir \ 21 | --per_device_eval_batch_size 1 \ 22 | --multiencoder_type bart \ 23 | --multiencoder_max_num_chunks 8 \ 24 
| --multiencoder_stride \ 25 | --predict_with_generate \ 26 | --evaluation_strategy epoch \ 27 | --num_train_epochs 10 \ 28 | --save_strategy epoch \ 29 | --logging_strategy epoch \ 30 | --load_best_model_at_end \ 31 | --compute_rouge_for_train True \ 32 | --seed $RUN &> ${NAME}_${RUN}.out 33 | done 34 | # --metric_for_best_model rouge1_plus_rouge2 \ 35 | 36 | -------------------------------------------------------------------------------- /multiencoder/select_checkpoints.py: -------------------------------------------------------------------------------- 1 | """ 2 | * Copyright (c) 2021, salesforce.com, inc. 3 | * All rights reserved. 4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | # Identifies best checkpoint and copies to selected_checkpoint directory 9 | 10 | import argparse 11 | import glob 12 | import json 13 | import logging 14 | import os 15 | import re 16 | import shutil 17 | from collections import defaultdict 18 | from operator import itemgetter 19 | from statistics import mean 20 | import argparse 21 | import glob 22 | import os 23 | import re 24 | import json 25 | from collections import defaultdict 26 | from statistics import mean, stdev 27 | import sys 28 | 29 | 30 | if __name__ == "__main__": 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument( 33 | 'train_dir_prefix', 34 | help='prefix of output directories for training runs being reported on' 35 | ) 36 | 37 | args = parser.parse_args() 38 | metrics = ['eval_rouge1', 'eval_rouge2', 'eval_rougeLsum', 'eval_loss', 'eval_gen_len'] 39 | 40 | scores = defaultdict(list) 41 | n_runs = 0 42 | results = [] 43 | print(f"****** {args.train_dir_prefix} ******") 44 | for filepath in sorted(glob.glob(f"{args.train_dir_prefix}*")): 45 | m = re.match(rf'{re.escape(args.train_dir_prefix)}_?(\d+)$', filepath) 46 | if m: 47 | run_index = int(m.group(1)) 48 | results.append((run_index, filepath)) 49 | sort_func = lambda x: mean([x['eval_rouge1'], x['eval_rouge2'], x['eval_rougeLsum']]) 50 | 51 | for run_index, filepath in sorted(results): 52 | try: 53 | with open(os.path.join(filepath, "trainer_state.json")) as f: 54 | data = json.load(f) 55 | epoch_logs = [log for log in data['log_history'] if 'eval_loss' in log] 56 | sorted_epochs = sorted( 57 | epoch_logs, 58 | key=sort_func) 59 | best_epoch = sorted_epochs[-1] 60 | best_checkpoint = f'{filepath}/checkpoint-{best_epoch["step"]}' 61 | if not(os.path.exists(best_checkpoint)): 62 | raise ValueError(f'Checkpoint {best_checkpoint} does not exist') 63 | print(best_epoch) 64 | selected_checkpoint_dir = os.path.join(filepath, 'selected_checkpoint') 65 | if os.path.exists(selected_checkpoint_dir): 66 | print('removing', selected_checkpoint_dir) 67 | shutil.rmtree(selected_checkpoint_dir) 68 | print('Copying from', best_checkpoint, 'to', selected_checkpoint_dir) 69 | shutil.copytree(best_checkpoint, selected_checkpoint_dir) 70 | except FileNotFoundError: 71 | pass 72 | 73 | -------------------------------------------------------------------------------- /multiencoder/test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/query-focused-sum/46cb3878ff9f55b963ad94728676189a6d421d60/multiencoder/test/__init__.py -------------------------------------------------------------------------------- /multiencoder/test/data/dataset.jsonl: 
-------------------------------------------------------------------------------- 1 | {"source": "a b c d e f", "query": "y z", "target": "1 2 3 4 5 6 7 8 9 10 11 12 13"} 2 | {"source": "a b c", "query": "y z", "target": "1 2 3 4 5 6 7 8 9 10"} -------------------------------------------------------------------------------- /multiencoder/test/test_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import unittest 3 | 4 | import torch 5 | from transformers import AutoTokenizer 6 | 7 | from dataset import ChunkTokenizer, MultiEncoderDataset 8 | 9 | 10 | class TestDataset(unittest.TestCase): 11 | 12 | def test_chunker(self): 13 | tokenizer = AutoTokenizer.from_pretrained('facebook/bart-base') 14 | 15 | chunk_size = 10 16 | max_num_chunks = 2 17 | pad = False 18 | chunk_tokenizer = ChunkTokenizer( 19 | tokenizer=tokenizer, 20 | chunk_size=chunk_size, 21 | max_num_chunks=max_num_chunks, 22 | pad=pad 23 | ) 24 | output = chunk_tokenizer( 25 | source="a b c d e", 26 | query="y z" 27 | ) 28 | input_ids = output['input_ids'] 29 | tokens = tokenizer.batch_decode(input_ids) 30 | self.assertEqual( 31 | tokens, 32 | ['y za b c d e'] 33 | ) 34 | 35 | output = chunk_tokenizer( 36 | source="a b c d e f g", 37 | query="y z" 38 | ) 39 | input_ids = output['input_ids'] 40 | tokens = tokenizer.batch_decode(input_ids) 41 | self.assertEqual( 42 | tokens, 43 | ['y za b c d e f', 44 | 'y z g'] 45 | ) 46 | 47 | output = chunk_tokenizer( 48 | source="a b", 49 | query="y z" 50 | ) 51 | input_ids = output['input_ids'] 52 | tokens = tokenizer.batch_decode(input_ids) 53 | self.assertEqual( 54 | tokens, 55 | ['y za b', ] 56 | ) 57 | 58 | pad = True 59 | chunk_tokenizer = ChunkTokenizer( 60 | tokenizer=tokenizer, 61 | chunk_size=chunk_size, 62 | max_num_chunks=max_num_chunks, 63 | pad=pad 64 | ) 65 | output = chunk_tokenizer( 66 | source="a b c d e", 67 | query="y z" 68 | ) 69 | input_ids = output['input_ids'] 70 | tokens = tokenizer.batch_decode(input_ids) 71 | self.assertEqual( 72 | tokens, 73 | ['y za b c d e', 74 | 'y z'] 75 | ) 76 | 77 | # Test with stride 78 | chunk_size = 10 79 | max_num_chunks = 2 80 | pad = False 81 | stride = True 82 | chunk_tokenizer = ChunkTokenizer( 83 | tokenizer=tokenizer, 84 | chunk_size=chunk_size, 85 | max_num_chunks=max_num_chunks, 86 | pad=pad, 87 | stride=stride 88 | ) 89 | output = chunk_tokenizer( 90 | source="a b c d e", 91 | query="y z" 92 | ) 93 | input_ids = output['input_ids'] 94 | tokens = tokenizer.batch_decode(input_ids) 95 | self.assertEqual( 96 | tokens, 97 | ['y za b c d e', 98 | 'y z d e'] 99 | ) 100 | 101 | max_num_chunks = 1 102 | chunk_tokenizer = ChunkTokenizer( 103 | tokenizer=tokenizer, 104 | chunk_size=chunk_size, 105 | max_num_chunks=max_num_chunks, 106 | pad=pad, 107 | stride=stride 108 | ) 109 | output = chunk_tokenizer( 110 | source="a b c d e", 111 | query="y z" 112 | ) 113 | input_ids = output['input_ids'] 114 | tokens = tokenizer.batch_decode(input_ids) 115 | self.assertEqual( 116 | tokens, 117 | ['y za b c d e'] 118 | ) 119 | 120 | max_num_chunks = 2 121 | pad = True 122 | chunk_tokenizer = ChunkTokenizer( 123 | tokenizer=tokenizer, 124 | chunk_size=chunk_size, 125 | max_num_chunks=max_num_chunks, 126 | pad=pad, 127 | stride=stride 128 | ) 129 | output = chunk_tokenizer( 130 | source="a b c d e", 131 | query="y z" 132 | ) 133 | input_ids = output['input_ids'] 134 | tokens = tokenizer.batch_decode(input_ids) 135 | self.assertEqual( 136 | tokens, 137 | ['y za b c d e', 138 | 'y z', 139 | 'y z d e' 
140 | ] 141 | 142 | ) 143 | 144 | output = chunk_tokenizer( 145 | source="a b", 146 | query="y z" 147 | ) 148 | input_ids = output['input_ids'] 149 | tokens = tokenizer.batch_decode(input_ids) 150 | self.assertEqual( 151 | tokens, 152 | ['y za b', 153 | 'y z', 154 | 'y z'] 155 | ) 156 | 157 | def test_multiencoder_dataset(self): 158 | data_path = "data/dataset.jsonl" 159 | tokenizer = AutoTokenizer.from_pretrained('facebook/bart-base') 160 | chunk_size = 10 161 | max_target_length = 14 162 | max_num_chunks = 2 163 | stride = False 164 | d = MultiEncoderDataset( 165 | data_path=data_path, 166 | tokenizer=tokenizer, 167 | chunk_size=chunk_size, 168 | max_target_length=max_target_length, 169 | max_num_chunks=max_num_chunks, 170 | stride=stride, 171 | ) 172 | 173 | # Basic test 174 | self.assertEqual(len(d), 2) 175 | 176 | # Test chunk_tokenizer init 177 | self.assertEqual(d.chunk_tokenizer.tokenizer, tokenizer) 178 | self.assertEqual(d.chunk_tokenizer.chunk_size, chunk_size) 179 | self.assertEqual(d.chunk_tokenizer.max_num_chunks, max_num_chunks) 180 | self.assertEqual(d.chunk_tokenizer.stride, stride) 181 | 182 | # Test source 183 | with open(data_path) as f: 184 | row = next(f) 185 | data = json.loads(row) 186 | actual_input_ids = d.chunk_tokenizer(data['source'], data['query'])['input_ids'] 187 | item = d[0] 188 | self.assertTrue(torch.equal(item['input_ids'], actual_input_ids)) 189 | 190 | # Test labels 191 | item = d[0] 192 | self.assertEqual( 193 | item['labels'].tolist(), 194 | tokenizer( 195 | "1 2 3 4 5 6 7 8 9 10 11 12", 196 | add_special_tokens=False 197 | )['input_ids'] 198 | ) 199 | item = d[1] 200 | self.assertEqual( 201 | item['labels'].tolist(), 202 | tokenizer( 203 | "1 2 3 4 5 6 7 8 9 10", 204 | add_special_tokens=False 205 | )['input_ids'] + [-100, -100] 206 | ) 207 | 208 | 209 | if __name__ == '__main__': 210 | unittest.main() 211 | -------------------------------------------------------------------------------- /multiencoder/test/test_model.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from torch.utils.data import DataLoader 4 | from transformers import AutoTokenizer 5 | 6 | from dataset import MultiEncoderDataset 7 | from models import BartForMultiConditionalGeneration 8 | 9 | 10 | class TestModel(unittest.TestCase): 11 | 12 | def test_bart(self): 13 | data_path = "data/dataset.jsonl" 14 | model_name = 'facebook/bart-base' 15 | tokenizer = AutoTokenizer.from_pretrained(model_name) 16 | chunk_size = 10 17 | max_num_chunks = 4 18 | max_target_length = 20 19 | pad = True 20 | 21 | d = MultiEncoderDataset( 22 | data_path=data_path, 23 | tokenizer=tokenizer, 24 | chunk_size=chunk_size, 25 | max_num_chunks=max_num_chunks, 26 | max_target_length=max_target_length, 27 | pad=pad 28 | ) 29 | 30 | batch_size = 2 31 | model = BartForMultiConditionalGeneration.from_pretrained(model_name) 32 | dataloader = DataLoader(d, batch_size=batch_size) 33 | batch = next(iter(dataloader)) 34 | self.assertTrue(batch['input_ids'].shape == (batch_size, max_num_chunks, chunk_size)) 35 | model.eval() 36 | output = model(**batch) 37 | self.assertTrue(output.encoder_last_hidden_state.shape == (batch_size, chunk_size * max_num_chunks, 768)) 38 | batch.pop('labels') 39 | output = model.generate(**batch, return_dict=True) 40 | self.assertTrue(output.shape[0] == batch_size) 41 | # for tokens in output: 42 | # print(tokenizer.decode(tokens)) 43 | 44 | 45 | if __name__ == '__main__': 46 | unittest.main() 47 | 
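
The tests above illustrate the intended flow: ChunkTokenizer turns a (query, source) pair into (num_chunks, chunk_size) input_ids/attention_mask tensors, MultiEncoderDataset wraps a JSONL file of such records, and BartForMultiConditionalGeneration flattens the chunk dimension before the shared BART encoder. A minimal inference sketch follows; it is illustrative only and not part of the repository, and the checkpoint path and example texts are assumptions (e.g. one of the selected_checkpoint directories produced by the scripts above).

# Illustrative inference sketch -- not part of this repository.
import torch
from transformers import AutoTokenizer

from dataset import ChunkTokenizer
from models import BartForMultiConditionalGeneration

checkpoint = "output/qmsum_32_512_strided_1/selected_checkpoint"  # assumed path
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")
model = BartForMultiConditionalGeneration.from_pretrained(checkpoint)
model.eval()

chunker = ChunkTokenizer(tokenizer, chunk_size=512, max_num_chunks=32, stride=True, pad=True)
encoded = chunker(
    source="<meeting transcript>",                      # placeholder text
    query="What did the group decide about the budget?" # placeholder query
)

with torch.no_grad():
    summary_ids = model.generate(
        input_ids=encoded["input_ids"].unsqueeze(0),          # (1, num_chunks, chunk_size)
        attention_mask=encoded["attention_mask"].unsqueeze(0),
        max_length=256,
    )
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))
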
-------------------------------------------------------------------------------- /multiencoder/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | * Copyright (c) 2021, salesforce.com, inc. 4 | * All rights reserved. 5 | * SPDX-License-Identifier: BSD-3-Clause 6 | * For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 7 | """ 8 | # 9 | # Based on https://github.com/huggingface/transformers/blob/master/examples/pytorch/summarization/run_summarization.py 10 | # 11 | # coding=utf-8 12 | # Copyright 2021 The HuggingFace Team. All rights reserved. 13 | # 14 | # Licensed under the Apache License, Version 2.0 (the "License"); 15 | # you may not use this file except in compliance with the License. 16 | # You may obtain a copy of the License at 17 | # 18 | # http://www.apache.org/licenses/LICENSE-2.0 19 | # 20 | # Unless required by applicable law or agreed to in writing, software 21 | # distributed under the License is distributed on an "AS IS" BASIS, 22 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 23 | # See the License for the specific language governing permissions and 24 | # limitations under the License. 25 | """ 26 | Fine-tuning the library models for sequence to sequence. 27 | """ 28 | # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments. 29 | 30 | # Change log 31 | # 10/29/2021 | Jesse Vig | Customized to multiencoder. See transformers/ directory. 32 | # 33 | 34 | import logging 35 | import os 36 | import pickle 37 | import sys 38 | from dataclasses import dataclass, field 39 | from statistics import mean 40 | from typing import Optional 41 | 42 | import datasets 43 | import nltk # Here to have a nice missing dependency error message early on 44 | import numpy as np 45 | import transformers 46 | from datasets import load_dataset, load_metric 47 | from filelock import FileLock 48 | from transformers import ( 49 | AutoConfig, 50 | AutoModelForSeq2SeqLM, 51 | AutoTokenizer, 52 | DataCollatorForSeq2Seq, 53 | HfArgumentParser, 54 | Seq2SeqTrainer, 55 | Seq2SeqTrainingArguments, 56 | set_seed, 57 | ) 58 | from transformers.file_utils import is_offline_mode 59 | from transformers.trainer_utils import get_last_checkpoint 60 | from transformers.utils.versions import require_version 61 | 62 | from dataset import MultiEncoderDataset 63 | from models import BartForMultiConditionalGeneration 64 | 65 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 66 | # check_min_version("4.12.0.dev0") 67 | 68 | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt") 69 | 70 | logger = logging.getLogger(__name__) 71 | 72 | try: 73 | nltk.data.find("tokenizers/punkt") 74 | except (LookupError, OSError): 75 | if is_offline_mode(): 76 | raise LookupError( 77 | "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files" 78 | ) 79 | with FileLock(".lock") as lock: 80 | nltk.download("punkt", quiet=True) 81 | 82 | 83 | @dataclass 84 | class ModelArguments: 85 | """ 86 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 
87 | """ 88 | 89 | model_name_or_path: str = field( 90 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 91 | ) 92 | config_name: Optional[str] = field( 93 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 94 | ) 95 | tokenizer_name: Optional[str] = field( 96 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 97 | ) 98 | cache_dir: Optional[str] = field( 99 | default=None, 100 | metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"}, 101 | ) 102 | use_fast_tokenizer: bool = field( 103 | default=True, 104 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, 105 | ) 106 | model_revision: str = field( 107 | default="main", 108 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 109 | ) 110 | use_auth_token: bool = field( 111 | default=False, 112 | metadata={ 113 | "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " 114 | "with private models)." 115 | }, 116 | ) 117 | resize_position_embeddings: Optional[bool] = field( 118 | default=None, 119 | metadata={ 120 | "help": "Whether to automatically resize the position embeddings if `max_source_length` exceeds " 121 | "the model's position embeddings." 122 | }, 123 | ) 124 | from_pt: bool = field( 125 | default=False, 126 | metadata={ 127 | "help": "Whether to load the model checkpoint from .pt file" 128 | }, 129 | ) 130 | multiencoder_type: Optional[str] = field( 131 | default=None, 132 | metadata={"help": "Currently only 'bart' is supported"}, 133 | ) 134 | multiencoder_max_num_chunks: Optional[int] = field( 135 | default=None, 136 | metadata={ 137 | "help": "Max passages/encoders to use in multiencoder" 138 | }, 139 | ) 140 | multiencoder_stride: bool = field( 141 | default=False, 142 | metadata={ 143 | "help": "Whether to stride" 144 | }, 145 | ) 146 | 147 | 148 | @dataclass 149 | class DataTrainingArguments: 150 | """ 151 | Arguments pertaining to what data we are going to input our model for training and eval. 152 | """ 153 | 154 | dataset_name: Optional[str] = field( 155 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} 156 | ) 157 | dataset_config_name: Optional[str] = field( 158 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 159 | ) 160 | text_column: Optional[str] = field( 161 | default=None, 162 | metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."}, 163 | ) 164 | summary_column: Optional[str] = field( 165 | default=None, 166 | metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."}, 167 | ) 168 | train_file: Optional[str] = field( 169 | default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."} 170 | ) 171 | validation_file: Optional[str] = field( 172 | default=None, 173 | metadata={ 174 | "help": "An optional input evaluation data file to evaluate the metrics (rouge) on " 175 | "(a jsonlines or csv file)." 176 | }, 177 | ) 178 | test_file: Optional[str] = field( 179 | default=None, 180 | metadata={ 181 | "help": "An optional input test data file to evaluate the metrics (rouge) on " "(a jsonlines or csv file)." 
182 | }, 183 | ) 184 | overwrite_cache: bool = field( 185 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 186 | ) 187 | preprocessing_num_workers: Optional[int] = field( 188 | default=None, 189 | metadata={"help": "The number of processes to use for the preprocessing."}, 190 | ) 191 | max_source_length: Optional[int] = field( 192 | default=1024, 193 | metadata={ 194 | "help": "The maximum total input sequence length after tokenization. Sequences longer " 195 | "than this will be truncated, sequences shorter will be padded." 196 | }, 197 | ) 198 | max_target_length: Optional[int] = field( 199 | default=128, 200 | metadata={ 201 | "help": "The maximum total sequence length for target text after tokenization. Sequences longer " 202 | "than this will be truncated, sequences shorter will be padded." 203 | }, 204 | ) 205 | val_max_target_length: Optional[int] = field( 206 | default=None, 207 | metadata={ 208 | "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer " 209 | "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`." 210 | "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used " 211 | "during ``evaluate`` and ``predict``." 212 | }, 213 | ) 214 | pad_to_max_length: bool = field( 215 | default=False, 216 | metadata={ 217 | "help": "Whether to pad all samples to model maximum sentence length. " 218 | "If False, will pad the samples dynamically when batching to the maximum length in the batch. More " 219 | "efficient on GPU but very bad for TPU." 220 | }, 221 | ) 222 | max_train_samples: Optional[int] = field( 223 | default=None, 224 | metadata={ 225 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this " 226 | "value if set." 227 | }, 228 | ) 229 | max_eval_samples: Optional[int] = field( 230 | default=None, 231 | metadata={ 232 | "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 233 | "value if set." 234 | }, 235 | ) 236 | max_predict_samples: Optional[int] = field( 237 | default=None, 238 | metadata={ 239 | "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this " 240 | "value if set." 241 | }, 242 | ) 243 | num_beams: Optional[int] = field( 244 | default=None, 245 | metadata={ 246 | "help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, " 247 | "which is used during ``evaluate`` and ``predict``." 248 | }, 249 | ) 250 | ignore_pad_token_for_loss: bool = field( 251 | default=True, 252 | metadata={ 253 | "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not." 
254 | }, 255 | ) 256 | source_prefix: Optional[str] = field( 257 | default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."} 258 | ) 259 | compute_rouge_for_train: bool = field( 260 | default=True, metadata={"help": "Whether to compute rouge in every eval cycle within training loop"} 261 | ) 262 | prediction_path: Optional[str] = field( 263 | default=None, metadata={"help": "Path to output prediction file that will be generated"} 264 | ) 265 | 266 | def __post_init__(self): 267 | if self.dataset_name is None and self.train_file is None and self.validation_file is None \ 268 | and self.test_file is None: 269 | raise ValueError("Need either a dataset name or a training/validation/test file.") 270 | else: 271 | if self.train_file is not None: 272 | extension = self.train_file.split(".")[-1] 273 | assert extension in ["csv", "json", "jsonl", "pickle"],\ 274 | "`train_file` should be a csv, json, jsonl, or pickle file." 275 | if self.validation_file is not None: 276 | extension = self.validation_file.split(".")[-1] 277 | assert extension in ["csv", "json", "jsonl", "pickle"],\ 278 | "`validation_file` should be a csv, json, jsonl, or pickle file." 279 | if self.val_max_target_length is None: 280 | self.val_max_target_length = self.max_target_length 281 | 282 | 283 | summarization_name_mapping = { 284 | "amazon_reviews_multi": ("review_body", "review_title"), 285 | "big_patent": ("description", "abstract"), 286 | "cnn_dailymail": ("article", "highlights"), 287 | "orange_sum": ("text", "summary"), 288 | "pn_summary": ("article", "summary"), 289 | "psc": ("extract_text", "summary_text"), 290 | "samsum": ("dialogue", "summary"), 291 | "thaisum": ("body", "summary"), 292 | "xglue": ("news_body", "news_title"), 293 | "xsum": ("document", "summary"), 294 | "wiki_summary": ("article", "highlights"), 295 | } 296 | 297 | 298 | def main(): 299 | # See all possible arguments in src/transformers/training_args.py 300 | # or by passing the --help flag to this script. 301 | # We now keep distinct sets of args, for a cleaner separation of concerns. 302 | print("got here") 303 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments)) 304 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 305 | # If we pass only one argument to the script and it's the path to a json file, 306 | # let's parse it to get our arguments. 
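# (Hypothetical example of this mode: `python train.py args.json`, where args.json holds the
# same keys as the command-line flags, e.g. {"model_name_or_path": "facebook/bart-large",
# "do_train": true}; the values shown here are illustrative only.)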
307 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 308 | else: 309 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 310 | 311 | # Setup logging 312 | logging.basicConfig( 313 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 314 | datefmt="%m/%d/%Y %H:%M:%S", 315 | handlers=[logging.StreamHandler(sys.stdout)], 316 | ) 317 | log_level = training_args.get_process_log_level() 318 | logger.setLevel(log_level) 319 | datasets.utils.logging.set_verbosity(log_level) 320 | transformers.utils.logging.set_verbosity(log_level) 321 | transformers.utils.logging.enable_default_handler() 322 | transformers.utils.logging.enable_explicit_format() 323 | 324 | # Log on each process the small summary: 325 | logger.warning( 326 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 327 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 328 | ) 329 | logger.info(f"Training/evaluation parameters {training_args}") 330 | 331 | if data_args.source_prefix is None and model_args.model_name_or_path in [ 332 | "t5-small", 333 | "t5-base", 334 | "t5-large", 335 | "t5-3b", 336 | "t5-11b", 337 | ]: 338 | logger.warning( 339 | "You're running a t5 model but didn't provide a source prefix, which is the expected, e.g. with " 340 | "`--source_prefix 'summarize: ' `" 341 | ) 342 | 343 | # Detecting last checkpoint. 344 | last_checkpoint = None 345 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 346 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 347 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 348 | raise ValueError( 349 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 350 | "Use --overwrite_output_dir to overcome." 351 | ) 352 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: 353 | logger.info( 354 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 355 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 356 | ) 357 | 358 | # Set seed before initializing model. 359 | set_seed(training_args.seed) 360 | 361 | config = AutoConfig.from_pretrained( 362 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 363 | cache_dir=model_args.cache_dir, 364 | revision=model_args.model_revision, 365 | use_auth_token=True if model_args.use_auth_token else None, 366 | ) 367 | tokenizer = AutoTokenizer.from_pretrained( 368 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 369 | cache_dir=model_args.cache_dir, 370 | use_fast=model_args.use_fast_tokenizer, 371 | revision=model_args.model_revision, 372 | use_auth_token=True if model_args.use_auth_token else None, 373 | ) 374 | 375 | # Load pretrained model and tokenizer 376 | # 377 | # Distributed training: 378 | # The .from_pretrained methods guarantee that only one local process can concurrently 379 | # download model & vocab. 
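# Descriptive note: when --multiencoder_type is set (only 'bart' is supported), the repo's
# multi-encoder class BartForMultiConditionalGeneration is used; otherwise a standard
# AutoModelForSeq2SeqLM is loaded from model_name_or_path.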
380 | 381 | if model_args.multiencoder_type is not None: 382 | if model_args.multiencoder_type == 'bart': 383 | model_class = BartForMultiConditionalGeneration 384 | else: 385 | raise ValueError(f"Invalid multiencoder_type: {model_args.multiencoder_type}") 386 | else: 387 | model_class = AutoModelForSeq2SeqLM 388 | model = model_class.from_pretrained( 389 | model_args.model_name_or_path, 390 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 391 | config=config, 392 | cache_dir=model_args.cache_dir, 393 | revision=model_args.model_revision, 394 | use_auth_token=True if model_args.use_auth_token else None 395 | ) 396 | 397 | model.resize_token_embeddings(len(tokenizer)) 398 | 399 | if model.config.decoder_start_token_id is None: 400 | raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined") 401 | 402 | if ( 403 | hasattr(model.config, "max_position_embeddings") 404 | and model.config.max_position_embeddings < data_args.max_source_length 405 | ): 406 | if model_args.resize_position_embeddings is None: 407 | logger.warning( 408 | f"Increasing the model's number of position embedding vectors from {model.config.max_position_embeddings} " 409 | f"to {data_args.max_source_length}." 410 | ) 411 | model.resize_position_embeddings(data_args.max_source_length) 412 | elif model_args.resize_position_embeddings: 413 | model.resize_position_embeddings(data_args.max_source_length) 414 | else: 415 | raise ValueError( 416 | f"`--max_source_length` is set to {data_args.max_source_length}, but the model only has {model.config.max_position_embeddings}" 417 | f" position encodings. Consider either reducing `--max_source_length` to {model.config.max_position_embeddings} or to automatically " 418 | "resize the model's position encodings by passing `--resize_position_embeddings`." 419 | ) 420 | 421 | prefix = data_args.source_prefix if data_args.source_prefix is not None else "" 422 | 423 | if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"): 424 | logger.warning( 425 | "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for" 426 | f"`{model.__class__.__name__}`. 
This will lead to loss being calculated twice and will take up more memory" 427 | ) 428 | 429 | # Load dataset 430 | if model_args.multiencoder_type is not None: 431 | #Load dataset for multiencoder model 432 | if data_args.pad_to_max_length: 433 | raise NotImplementedError 434 | if not model_args.multiencoder_max_num_chunks: 435 | raise ValueError("Max num chunks required for multiencoder") 436 | 437 | if training_args.do_train: 438 | logger.info("Loading training set") 439 | if data_args.train_file.endswith('.pickle'): 440 | with open(data_args.train_file, 'rb') as f: 441 | train_dataset = pickle.load(f) 442 | else: 443 | train_dataset = MultiEncoderDataset( 444 | data_path=data_args.train_file, 445 | tokenizer=tokenizer, 446 | chunk_size=data_args.max_source_length, 447 | max_num_chunks=model_args.multiencoder_max_num_chunks, 448 | max_target_length=data_args.max_target_length, 449 | stride=model_args.multiencoder_stride, 450 | num_samples=data_args.max_train_samples, 451 | ignore_pad_token_for_loss=data_args.ignore_pad_token_for_loss, 452 | ) 453 | 454 | if training_args.do_eval: 455 | logger.info("Loading eval set") 456 | if data_args.validation_file.endswith('.pickle'): 457 | with open(data_args.validation_file, 'rb') as f: 458 | eval_dataset = pickle.load(f) 459 | else: 460 | eval_dataset = MultiEncoderDataset( 461 | data_path=data_args.validation_file, 462 | tokenizer=tokenizer, 463 | chunk_size=data_args.max_source_length, 464 | max_num_chunks=model_args.multiencoder_max_num_chunks, 465 | max_target_length=data_args.val_max_target_length, 466 | stride=model_args.multiencoder_stride, 467 | num_samples=data_args.max_eval_samples, 468 | ignore_pad_token_for_loss=data_args.ignore_pad_token_for_loss 469 | ) 470 | 471 | if training_args.do_predict: 472 | logger.info("Loading predict set") 473 | if data_args.test_file.endswith('.pickle'): 474 | with open(data_args.test_file, 'rb') as f: 475 | predict_dataset = pickle.load(f) 476 | else: 477 | predict_dataset = MultiEncoderDataset( 478 | data_path=data_args.test_file, 479 | tokenizer=tokenizer, 480 | chunk_size=data_args.max_source_length, 481 | max_num_chunks=model_args.multiencoder_max_num_chunks, 482 | max_target_length=data_args.val_max_target_length, 483 | stride=model_args.multiencoder_stride, 484 | num_samples=data_args.max_predict_samples, 485 | ignore_pad_token_for_loss=data_args.ignore_pad_token_for_loss 486 | ) 487 | 488 | data_collator = None 489 | else: 490 | # Load datasets for standard models 491 | # 492 | # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below) 493 | # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ 494 | # (the dataset will be downloaded automatically from the datasets Hub). 495 | # 496 | # For CSV/JSON files this script will use the first column for the full texts and the second column for the 497 | # summaries (unless you specify column names for this with the `text_column` and `summary_column` arguments). 498 | # 499 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently 500 | # download the dataset. 501 | if data_args.dataset_name is not None: 502 | # Downloading and loading a dataset from the hub. 
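# (Illustrative example: passing --dataset_name xsum loads the XSum dataset from the hub, whose
# document/summary columns are listed in summarization_name_mapping above.)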
503 | raw_datasets = load_dataset( 504 | data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir 505 | ) 506 | else: 507 | data_files = {} 508 | if data_args.train_file is not None: 509 | data_files["train"] = data_args.train_file 510 | extension = data_args.train_file.split(".")[-1] 511 | if data_args.validation_file is not None: 512 | data_files["validation"] = data_args.validation_file 513 | extension = data_args.validation_file.split(".")[-1] 514 | if data_args.test_file is not None: 515 | data_files["test"] = data_args.test_file 516 | extension = data_args.test_file.split(".")[-1] 517 | raw_datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir) 518 | # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at 519 | # https://huggingface.co/docs/datasets/loading_datasets.html. 520 | 521 | # Preprocessing the datasets. 522 | # We need to tokenize inputs and targets. 523 | if training_args.do_train: 524 | column_names = raw_datasets["train"].column_names 525 | elif training_args.do_eval: 526 | column_names = raw_datasets["validation"].column_names 527 | elif training_args.do_predict: 528 | column_names = raw_datasets["test"].column_names 529 | else: 530 | logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.") 531 | return 532 | 533 | # Get the column names for input/target. 534 | dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None) 535 | if data_args.text_column is None: 536 | text_column = dataset_columns[0] if dataset_columns is not None else column_names[0] 537 | else: 538 | text_column = data_args.text_column 539 | if text_column not in column_names: 540 | raise ValueError( 541 | f"--text_column' value '{data_args.text_column}' needs to be one of: {', '.join(column_names)}" 542 | ) 543 | if data_args.summary_column is None: 544 | summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1] 545 | else: 546 | summary_column = data_args.summary_column 547 | if summary_column not in column_names: 548 | raise ValueError( 549 | f"--summary_column' value '{data_args.summary_column}' needs to be one of: {', '.join(column_names)}" 550 | ) 551 | 552 | # Temporarily set max_target_length for training. 553 | 554 | max_target_length = data_args.max_target_length 555 | padding = "max_length" if data_args.pad_to_max_length else False 556 | 557 | def preprocess_function(examples): 558 | inputs = examples[text_column] 559 | targets = examples[summary_column] 560 | inputs = [prefix + inp for inp in inputs] 561 | model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True) 562 | 563 | # Setup the tokenizer for targets 564 | with tokenizer.as_target_tokenizer(): 565 | labels = tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True) 566 | 567 | # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore 568 | # padding in the loss. 
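# (-100 is the ignore_index used by the cross-entropy loss in PyTorch/Transformers, so padded
# label positions are excluded from the loss computation.)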
569 | if padding == "max_length" and data_args.ignore_pad_token_for_loss: 570 | labels["input_ids"] = [ 571 | [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"] 572 | ] 573 | 574 | model_inputs["labels"] = labels["input_ids"] 575 | return model_inputs 576 | 577 | if training_args.do_train: 578 | if "train" not in raw_datasets: 579 | raise ValueError("--do_train requires a train dataset") 580 | train_dataset = raw_datasets["train"] 581 | if data_args.max_train_samples is not None: 582 | train_dataset = train_dataset.select(range(data_args.max_train_samples)) 583 | with training_args.main_process_first(desc="train dataset map pre-processing"): 584 | train_dataset = train_dataset.map( 585 | preprocess_function, 586 | batched=True, 587 | num_proc=data_args.preprocessing_num_workers, 588 | remove_columns=column_names, 589 | load_from_cache_file=not data_args.overwrite_cache, 590 | desc="Running tokenizer on train dataset", 591 | ) 592 | 593 | if training_args.do_eval: 594 | max_target_length = data_args.val_max_target_length 595 | if "validation" not in raw_datasets: 596 | raise ValueError("--do_eval requires a validation dataset") 597 | eval_dataset = raw_datasets["validation"] 598 | if data_args.max_eval_samples is not None: 599 | eval_dataset = eval_dataset.select(range(data_args.max_eval_samples)) 600 | with training_args.main_process_first(desc="validation dataset map pre-processing"): 601 | eval_dataset = eval_dataset.map( 602 | preprocess_function, 603 | batched=True, 604 | num_proc=data_args.preprocessing_num_workers, 605 | remove_columns=column_names, 606 | load_from_cache_file=not data_args.overwrite_cache, 607 | desc="Running tokenizer on validation dataset", 608 | ) 609 | print('size of labels tensor', max(len(item['labels']) for item in eval_dataset)) 610 | 611 | if training_args.do_predict: 612 | max_target_length = data_args.val_max_target_length 613 | if "test" not in raw_datasets: 614 | raise ValueError("--do_predict requires a test dataset") 615 | predict_dataset = raw_datasets["test"] 616 | if data_args.max_predict_samples is not None: 617 | predict_dataset = predict_dataset.select(range(data_args.max_predict_samples)) 618 | with training_args.main_process_first(desc="prediction dataset map pre-processing"): 619 | predict_dataset = predict_dataset.map( 620 | preprocess_function, 621 | batched=True, 622 | num_proc=data_args.preprocessing_num_workers, 623 | remove_columns=column_names, 624 | load_from_cache_file=not data_args.overwrite_cache, 625 | desc="Running tokenizer on prediction dataset", 626 | ) 627 | 628 | # Data collator 629 | label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id 630 | data_collator = DataCollatorForSeq2Seq( 631 | tokenizer, 632 | model=model, 633 | label_pad_token_id=label_pad_token_id, 634 | pad_to_multiple_of=8 if training_args.fp16 else None, 635 | ) 636 | 637 | # Metric 638 | metric = load_metric("rouge") 639 | 640 | def postprocess_text(preds, labels): 641 | preds = [pred.strip() for pred in preds] 642 | labels = [label.strip() for label in labels] 643 | 644 | # rougeLSum expects newline after each sentence 645 | preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds] 646 | labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels] 647 | 648 | return preds, labels 649 | 650 | def compute_metrics(eval_preds): 651 | preds, labels = eval_preds 652 | if isinstance(preds, tuple): 653 | preds = preds[0] 654 | decoded_preds = 
tokenizer.batch_decode(preds, skip_special_tokens=True) 655 | if data_args.ignore_pad_token_for_loss: 656 | # Replace -100 in the labels as we can't decode them. 657 | labels = np.where(labels != -100, labels, tokenizer.pad_token_id) 658 | decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) 659 | 660 | # Some simple post-processing 661 | decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels) 662 | 663 | result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True) 664 | result = {key: value.mid.fmeasure * 100 for key, value in result.items()} 665 | result['eval_mean_rouge'] = mean([result[k] for k in ['rouge1', 'rouge2', 'rougeLsum']]) 666 | 667 | prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds] 668 | result["gen_len"] = np.mean(prediction_lens) 669 | result = {k: round(v, 4) for k, v in result.items()} 670 | logger.info("computing metrics") 671 | return result 672 | 673 | # Initialize our Trainer 674 | trainer = Seq2SeqTrainer( 675 | model=model, 676 | args=training_args, 677 | train_dataset=train_dataset if training_args.do_train else None, 678 | eval_dataset=eval_dataset if training_args.do_eval else None, 679 | tokenizer=tokenizer, 680 | data_collator=data_collator, 681 | compute_metrics=( 682 | compute_metrics if training_args.predict_with_generate and data_args.compute_rouge_for_train else None 683 | ) 684 | ) 685 | 686 | # Training 687 | if training_args.do_train: 688 | logger.info("*** Train ***") 689 | checkpoint = None 690 | if training_args.resume_from_checkpoint is not None: 691 | checkpoint = training_args.resume_from_checkpoint 692 | elif last_checkpoint is not None: 693 | checkpoint = last_checkpoint 694 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 695 | trainer.save_model() # Saves the tokenizer too for easy upload 696 | 697 | metrics = train_result.metrics 698 | max_train_samples = ( 699 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 700 | ) 701 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 702 | 703 | trainer.log_metrics("train", metrics) 704 | trainer.save_metrics("train", metrics) 705 | trainer.save_state() 706 | 707 | # Evaluation 708 | results = {} 709 | max_length = ( 710 | training_args.generation_max_length 711 | if training_args.generation_max_length is not None 712 | else data_args.val_max_target_length 713 | ) 714 | num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams 715 | if training_args.do_eval: 716 | logger.info("*** Evaluate ***") 717 | 718 | # Set compute_metrics here in case it wasn't set above 719 | trainer.compute_metrics = compute_metrics if training_args.predict_with_generate else None 720 | metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, metric_key_prefix="eval") 721 | max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) 722 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 723 | 724 | trainer.log_metrics("eval", metrics) 725 | trainer.save_metrics("eval", metrics) 726 | 727 | if training_args.do_predict: 728 | logger.info("*** Predict ***") 729 | 730 | predict_results = trainer.predict( 731 | predict_dataset, metric_key_prefix="predict", max_length=max_length, num_beams=num_beams 732 | ) 733 | metrics = predict_results.metrics 734 | max_predict_samples = ( 735 | 
data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset) 736 | ) 737 | metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset)) 738 | 739 | trainer.log_metrics("predict", metrics) 740 | trainer.save_metrics("predict", metrics) 741 | 742 | if trainer.is_world_process_zero(): 743 | if training_args.predict_with_generate: 744 | predictions = tokenizer.batch_decode( 745 | predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True 746 | ) 747 | predictions = [pred.strip() for pred in predictions] 748 | if data_args.prediction_path: 749 | prediction_path = data_args.prediction_path 750 | else: 751 | prediction_path = os.path.join(training_args.output_dir, "test.predictions") 752 | # print('Writing prediction file to', output_prediction_file) 753 | with open(prediction_path, "w") as writer: 754 | writer.write("\n".join(predictions)) 755 | 756 | kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "summarization"} 757 | if data_args.dataset_name is not None: 758 | kwargs["dataset_tags"] = data_args.dataset_name 759 | if data_args.dataset_config_name is not None: 760 | kwargs["dataset_args"] = data_args.dataset_config_name 761 | kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}" 762 | else: 763 | kwargs["dataset"] = data_args.dataset_name 764 | 765 | if training_args.push_to_hub: 766 | trainer.push_to_hub(**kwargs) 767 | else: 768 | trainer.create_model_card(**kwargs) 769 | 770 | return results 771 | 772 | 773 | def _mp_fn(index): 774 | # For xla_spawn (TPUs) 775 | main() 776 | 777 | 778 | if __name__ == "__main__": 779 | main() -------------------------------------------------------------------------------- /multiencoder/transformers/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2018- The Hugging Face team. All rights reserved. 2 | 3 | Apache License 4 | Version 2.0, January 2004 5 | http://www.apache.org/licenses/ 6 | 7 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 8 | 9 | 1. Definitions. 10 | 11 | "License" shall mean the terms and conditions for use, reproduction, 12 | and distribution as defined by Sections 1 through 9 of this document. 13 | 14 | "Licensor" shall mean the copyright owner or entity authorized by 15 | the copyright owner that is granting the License. 16 | 17 | "Legal Entity" shall mean the union of the acting entity and all 18 | other entities that control, are controlled by, or are under common 19 | control with that entity. For the purposes of this definition, 20 | "control" means (i) the power, direct or indirect, to cause the 21 | direction or management of such entity, whether by contract or 22 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 23 | outstanding shares, or (iii) beneficial ownership of such entity. 24 | 25 | "You" (or "Your") shall mean an individual or Legal Entity 26 | exercising permissions granted by this License. 27 | 28 | "Source" form shall mean the preferred form for making modifications, 29 | including but not limited to software source code, documentation 30 | source, and configuration files. 31 | 32 | "Object" form shall mean any form resulting from mechanical 33 | transformation or translation of a Source form, including but 34 | not limited to compiled object code, generated documentation, 35 | and conversions to other media types. 
36 | 37 | "Work" shall mean the work of authorship, whether in Source or 38 | Object form, made available under the License, as indicated by a 39 | copyright notice that is included in or attached to the work 40 | (an example is provided in the Appendix below). 41 | 42 | "Derivative Works" shall mean any work, whether in Source or Object 43 | form, that is based on (or derived from) the Work and for which the 44 | editorial revisions, annotations, elaborations, or other modifications 45 | represent, as a whole, an original work of authorship. For the purposes 46 | of this License, Derivative Works shall not include works that remain 47 | separable from, or merely link (or bind by name) to the interfaces of, 48 | the Work and Derivative Works thereof. 49 | 50 | "Contribution" shall mean any work of authorship, including 51 | the original version of the Work and any modifications or additions 52 | to that Work or Derivative Works thereof, that is intentionally 53 | submitted to Licensor for inclusion in the Work by the copyright owner 54 | or by an individual or Legal Entity authorized to submit on behalf of 55 | the copyright owner. For the purposes of this definition, "submitted" 56 | means any form of electronic, verbal, or written communication sent 57 | to the Licensor or its representatives, including but not limited to 58 | communication on electronic mailing lists, source code control systems, 59 | and issue tracking systems that are managed by, or on behalf of, the 60 | Licensor for the purpose of discussing and improving the Work, but 61 | excluding communication that is conspicuously marked or otherwise 62 | designated in writing by the copyright owner as "Not a Contribution." 63 | 64 | "Contributor" shall mean Licensor and any individual or Legal Entity 65 | on behalf of whom a Contribution has been received by Licensor and 66 | subsequently incorporated within the Work. 67 | 68 | 2. Grant of Copyright License. Subject to the terms and conditions of 69 | this License, each Contributor hereby grants to You a perpetual, 70 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 71 | copyright license to reproduce, prepare Derivative Works of, 72 | publicly display, publicly perform, sublicense, and distribute the 73 | Work and such Derivative Works in Source or Object form. 74 | 75 | 3. Grant of Patent License. Subject to the terms and conditions of 76 | this License, each Contributor hereby grants to You a perpetual, 77 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 78 | (except as stated in this section) patent license to make, have made, 79 | use, offer to sell, sell, import, and otherwise transfer the Work, 80 | where such license applies only to those patent claims licensable 81 | by such Contributor that are necessarily infringed by their 82 | Contribution(s) alone or by combination of their Contribution(s) 83 | with the Work to which such Contribution(s) was submitted. If You 84 | institute patent litigation against any entity (including a 85 | cross-claim or counterclaim in a lawsuit) alleging that the Work 86 | or a Contribution incorporated within the Work constitutes direct 87 | or contributory patent infringement, then any patent licenses 88 | granted to You under this License for that Work shall terminate 89 | as of the date such litigation is filed. 90 | 91 | 4. Redistribution. 
You may reproduce and distribute copies of the 92 | Work or Derivative Works thereof in any medium, with or without 93 | modifications, and in Source or Object form, provided that You 94 | meet the following conditions: 95 | 96 | (a) You must give any other recipients of the Work or 97 | Derivative Works a copy of this License; and 98 | 99 | (b) You must cause any modified files to carry prominent notices 100 | stating that You changed the files; and 101 | 102 | (c) You must retain, in the Source form of any Derivative Works 103 | that You distribute, all copyright, patent, trademark, and 104 | attribution notices from the Source form of the Work, 105 | excluding those notices that do not pertain to any part of 106 | the Derivative Works; and 107 | 108 | (d) If the Work includes a "NOTICE" text file as part of its 109 | distribution, then any Derivative Works that You distribute must 110 | include a readable copy of the attribution notices contained 111 | within such NOTICE file, excluding those notices that do not 112 | pertain to any part of the Derivative Works, in at least one 113 | of the following places: within a NOTICE text file distributed 114 | as part of the Derivative Works; within the Source form or 115 | documentation, if provided along with the Derivative Works; or, 116 | within a display generated by the Derivative Works, if and 117 | wherever such third-party notices normally appear. The contents 118 | of the NOTICE file are for informational purposes only and 119 | do not modify the License. You may add Your own attribution 120 | notices within Derivative Works that You distribute, alongside 121 | or as an addendum to the NOTICE text from the Work, provided 122 | that such additional attribution notices cannot be construed 123 | as modifying the License. 124 | 125 | You may add Your own copyright statement to Your modifications and 126 | may provide additional or different license terms and conditions 127 | for use, reproduction, or distribution of Your modifications, or 128 | for any such Derivative Works as a whole, provided Your use, 129 | reproduction, and distribution of the Work otherwise complies with 130 | the conditions stated in this License. 131 | 132 | 5. Submission of Contributions. Unless You explicitly state otherwise, 133 | any Contribution intentionally submitted for inclusion in the Work 134 | by You to the Licensor shall be under the terms and conditions of 135 | this License, without any additional terms or conditions. 136 | Notwithstanding the above, nothing herein shall supersede or modify 137 | the terms of any separate license agreement you may have executed 138 | with Licensor regarding such Contributions. 139 | 140 | 6. Trademarks. This License does not grant permission to use the trade 141 | names, trademarks, service marks, or product names of the Licensor, 142 | except as required for reasonable and customary use in describing the 143 | origin of the Work and reproducing the content of the NOTICE file. 144 | 145 | 7. Disclaimer of Warranty. Unless required by applicable law or 146 | agreed to in writing, Licensor provides the Work (and each 147 | Contributor provides its Contributions) on an "AS IS" BASIS, 148 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 149 | implied, including, without limitation, any warranties or conditions 150 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 151 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 152 | appropriateness of using or redistributing the Work and assume any 153 | risks associated with Your exercise of permissions under this License. 154 | 155 | 8. Limitation of Liability. In no event and under no legal theory, 156 | whether in tort (including negligence), contract, or otherwise, 157 | unless required by applicable law (such as deliberate and grossly 158 | negligent acts) or agreed to in writing, shall any Contributor be 159 | liable to You for damages, including any direct, indirect, special, 160 | incidental, or consequential damages of any character arising as a 161 | result of this License or out of the use or inability to use the 162 | Work (including but not limited to damages for loss of goodwill, 163 | work stoppage, computer failure or malfunction, or any and all 164 | other commercial damages or losses), even if such Contributor 165 | has been advised of the possibility of such damages. 166 | 167 | 9. Accepting Warranty or Additional Liability. While redistributing 168 | the Work or Derivative Works thereof, You may choose to offer, 169 | and charge a fee for, acceptance of support, warranty, indemnity, 170 | or other liability obligations and/or rights consistent with this 171 | License. However, in accepting such obligations, You may act only 172 | on Your own behalf and on Your sole responsibility, not on behalf 173 | of any other Contributor, and only if You agree to indemnify, 174 | defend, and hold each Contributor harmless for any liability 175 | incurred by, or claims asserted against, such Contributor by reason 176 | of your accepting any such warranty or additional liability. 177 | 178 | END OF TERMS AND CONDITIONS 179 | 180 | APPENDIX: How to apply the Apache License to your work. 181 | 182 | To apply the Apache License to your work, attach the following 183 | boilerplate notice, with the fields enclosed by brackets "[]" 184 | replaced with your own identifying information. (Don't include 185 | the brackets!) The text should be enclosed in the appropriate 186 | comment syntax for the file format. We also recommend that a 187 | file or class name and description of purpose be included on the 188 | same "printed page" as the copyright notice for easier 189 | identification within third-party archives. 190 | 191 | Copyright [yyyy] [name of copyright owner] 192 | 193 | Licensed under the Apache License, Version 2.0 (the "License"); 194 | you may not use this file except in compliance with the License. 195 | You may obtain a copy of the License at 196 | 197 | http://www.apache.org/licenses/LICENSE-2.0 198 | 199 | Unless required by applicable law or agreed to in writing, software 200 | distributed under the License is distributed on an "AS IS" BASIS, 201 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 202 | See the License for the specific language governing permissions and 203 | limitations under the License. 
-------------------------------------------------------------------------------- /multiencoder/transformers/changes: -------------------------------------------------------------------------------- 1 | Modified https://github.com/huggingface/transformers/blob/master/examples/pytorch/summarization/run_summarization.py (see ../train.py) -------------------------------------------------------------------------------- /preprocessing/README.md: -------------------------------------------------------------------------------- 1 | # Preprocessing code 2 | 3 | Run all of the following steps from `/preprocessing` 4 | 5 | ## Download and format QMSum data for two-stage and Segment Encoder models 6 | ``` 7 | git clone https://github.com/Yale-LILY/QMSum.git 8 | mv QMSum ../ 9 | python prep_qmsum.py ../QMSum 10 | mv ../QMSum/data/ALL/jsonl/final ../data 11 | ``` 12 | 13 | This will produce three files for each of the train, val, and test splits in the `../data` folder: 14 | 15 | 16 | `.jsonl` 17 | 18 | contains one data point per line. Each data point consists of an individual query, query and meeting ids, a reference summary, and the general/specific query type label. 19 | 20 | `-meetings.jsonl` 21 | 22 | contains the meeting transcripts along with the associated `meeting_id` used to join the data with `.jsonl`. 23 | 24 | `.target` 25 | 26 | contains one data point per line in the format used for ROUGE evaluation. The order of the data points aligns with `.jsonl`. 27 | 28 | -------------------------------------------------------------------------------- /preprocessing/prep_qmsum.py: -------------------------------------------------------------------------------- 1 | """ 2 | * Copyright (c) 2021, salesforce.com, inc. 3 | * All rights reserved. 4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import sys 9 | import os 10 | import re 11 | import json 12 | from pathlib import Path 13 | 14 | def format_utterance(utterance): 15 | content = re.sub(r'\{[^\}]*\}', '', \ 16 | utterance['content']).strip() # Remove all content of form "{...}" 17 | content = content.replace('A_M_I_', 'AMI') 18 | content = content.replace('L_C_D_', 'LCD') 19 | content = content.replace('P_M_S', 'PMS') 20 | content = content.replace('T_V_', 'TV') 21 | return f"{utterance['speaker']}: {content}" if content else None 22 | 23 | def fix_spans(prev_relevant_text_spans, none_count_arr): 24 | relevant_text_spans = [] 25 | for relevant_text_span in prev_relevant_text_spans: 26 | start_, end_ = int(relevant_text_span[0]), int(relevant_text_span[1]) 27 | # move the span index to the left by the number 28 | # of blank utterances to the left of that index 29 | start = max(start_ - none_count_arr[start_], 0) 30 | end = max(end_ - none_count_arr[end_], 0) 31 | relevant_text_spans.append([str(start), str(end)]) 32 | # The start/end utterances should be the same, except in the case 33 | #where the start/end was a blank utterance (which does occur in the annotations) 34 | assert meeting_transcripts_w_none[end_] == meeting_transcripts[end] \ 35 | or meeting_transcripts_w_none[end_] is None 36 | assert meeting_transcripts_w_none[start_] == meeting_transcripts[start] \ 37 | or meeting_transcripts_w_none[start_] is None 38 | return relevant_text_spans 39 | 40 | 41 | if __name__ == "__main__": 42 | # path to qmsum github folder 43 | qmsum_path_str = sys.argv[1] 44 | qmsum_path = Path(qmsum_path_str) 45 | 46 | academic = qmsum_path 
/ "data/Academic/jsonl" 47 | academic_topics = set() 48 | for split in ["train", "val", "test"]: 49 | with open(os.path.join(academic, f"{split}.jsonl")) as f: 50 | for line in f: 51 | data = json.loads(line) 52 | academic_topics.add(str(data['topic_list'])) 53 | 54 | committee = qmsum_path / "data/Committee/jsonl" 55 | committee_topics = set() 56 | for split in ["train", "val", "test"]: 57 | with open(os.path.join(committee, f"{split}.jsonl")) as f: 58 | for line in f: 59 | data = json.loads(line) 60 | committee_topics.add(str(data['topic_list'])) 61 | 62 | product = qmsum_path / "data/Product/jsonl" 63 | product_topics = set() 64 | for split in ["train", "val", "test"]: 65 | with open(os.path.join(product, f"{split}.jsonl")) as f: 66 | for line in f: 67 | data = json.loads(line) 68 | product_topics.add(str(data['topic_list'])) 69 | 70 | 71 | for split in ["train", "val", "test"]: 72 | query_count = 0 73 | if not os.path.exists(f"{str(qmsum_path)}/data/ALL/jsonl/final"): 74 | os.mkdir(f"{str(qmsum_path)}/data/ALL/jsonl/final") 75 | 76 | with open(f"{qmsum_path_str}/data/ALL/jsonl/{split}.jsonl") as f, \ 77 | open(f"{qmsum_path_str}/data/ALL/jsonl/final/{split}.jsonl", "w") as outqf, \ 78 | open(f"{qmsum_path_str}/data/ALL/jsonl/final/{split}-meetings.jsonl", "w") as outmf, \ 79 | open(f"{qmsum_path_str}/data/ALL/jsonl/final/{split}.target", "w") as outt: 80 | for line_count, line in enumerate(f): 81 | data = json.loads(line) 82 | topic_str = str(data['topic_list']) 83 | if topic_str in product_topics: 84 | domain = "product" 85 | elif topic_str in committee_topics: 86 | domain = "committee" 87 | else: 88 | domain = "academic" 89 | 90 | # Write meeting to output file 91 | meeting_id = f"m_{split}_{line_count}" 92 | meeting_data = {} 93 | meeting_data["meeting_id"] = meeting_id 94 | meeting_data["domain"] = domain 95 | 96 | meeting_transcripts_w_none = [format_utterance(utt) for utt in data["meeting_transcripts"]] 97 | none_count_arr = [0] * len(meeting_transcripts_w_none) 98 | none_count = 0 99 | for i in range(len(meeting_transcripts_w_none)): 100 | if meeting_transcripts_w_none[i] is None: 101 | none_count += 1 102 | none_count_arr[i] = none_count 103 | meeting_transcripts = [x for x in meeting_transcripts_w_none if x is not None] 104 | meeting_data["meeting_transcripts"] = meeting_transcripts 105 | 106 | topic_list = [] 107 | for topic in data["topic_list"]: 108 | relevant_text_spans = fix_spans(topic['relevant_text_span'], none_count_arr) 109 | topic_list.append({"topic": topic["topic"], "relevant_text_span": relevant_text_spans}) 110 | meeting_data["topic_list"] = topic_list 111 | 112 | json.dump(meeting_data, outmf) 113 | outmf.write("\n") 114 | 115 | 116 | for gen_query in data['general_query_list']: 117 | query_data = {} 118 | query_id = f"q_{split}_{query_count}" 119 | query_data["query_id"] = query_id 120 | query_data["query"] = gen_query["query"] 121 | query_data["answer"] = gen_query["answer"] 122 | query_data["query_type"] = "general" 123 | query_data["meeting_id"] = meeting_id 124 | json.dump(query_data, outqf) 125 | outqf.write("\n") 126 | outt.write(gen_query["answer"] + "\n") 127 | query_count += 1 128 | 129 | for spec_query in data['specific_query_list']: 130 | query_data = {} 131 | query_id = f"q_{split}_{query_count}" 132 | query_data["query_id"] = query_id 133 | query_data["query"] = spec_query["query"] 134 | query_data["answer"] = spec_query["answer"] 135 | query_data["query_type"] = "specific" 136 | query_data["meeting_id"] = meeting_id 137 | 138 | 
relevant_text_spans = fix_spans(spec_query["relevant_text_span"], none_count_arr) 139 | query_data["relevant_text_span"] = relevant_text_spans 140 | 141 | json.dump(query_data, outqf) 142 | outqf.write("\n") 143 | outt.write(spec_query["answer"] + "\n") 144 | query_count += 1 145 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.11.3 2 | torch==1.9.1 3 | datasets==1.13.2 4 | nltk==3.6.6 5 | absl-py==0.14.0 6 | rouge-score==0.0.4 7 | summ-eval==0.89 -------------------------------------------------------------------------------- /rouge/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/query-focused-sum/46cb3878ff9f55b963ad94728676189a6d421d60/rouge/__init__.py -------------------------------------------------------------------------------- /rouge/report_rouge.py: -------------------------------------------------------------------------------- 1 | """ 2 | * Copyright (c) 2021, salesforce.com, inc. 3 | * All rights reserved. 4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | # Report mean rouge scores given a file of reference summaries and one or more files of predicted summaries 9 | 10 | import argparse 11 | from collections import defaultdict 12 | from statistics import mean 13 | 14 | import stanza 15 | from summ_eval import rouge_metric 16 | 17 | stanza.download('en') 18 | nlp = stanza.Pipeline(lang='en', processors='tokenize') 19 | 20 | def preprocess(text): 21 | doc = nlp(text) 22 | return '\n'.join( 23 | ' '.join(token.text for token in sentence.tokens) 24 | for sentence in doc.sentences 25 | ) 26 | 27 | 28 | def report_mean_rouge(ref_path, pred_paths): 29 | metric = rouge_metric.RougeMetric() 30 | 31 | with open(ref_path) as f: 32 | refs = [preprocess(line) for line in f] 33 | print('First ref') 34 | print(refs[0]) 35 | 36 | all_scores = defaultdict(list) 37 | for i, pred_path in enumerate(pred_paths): 38 | with open(pred_path) as f: 39 | preds = [preprocess(line) for line in f] 40 | if i == 0: 41 | print('First pred') 42 | print(preds[0]) 43 | results = metric.evaluate_batch(preds, refs, aggregate=True) 44 | # print(results) 45 | all_scores['rouge1'].append(results['rouge']['rouge_1_f_score'] * 100) 46 | all_scores['rouge2'].append(results['rouge']['rouge_2_f_score'] * 100) 47 | all_scores['rougeL'].append(results['rouge']['rouge_l_f_score'] * 100) 48 | for metric_name, scores in sorted(all_scores.items()): 49 | print() 50 | print('*' * 10) 51 | print(metric_name) 52 | print('Individual scores:', ', '.join(f'{score:.2f}' for score in scores)) 53 | print(f'Mean: {mean(scores):.2f}') 54 | 55 | 56 | if __name__ == '__main__': 57 | parser = argparse.ArgumentParser() 58 | parser.add_argument('--ref-path', help='path to file with reference summaries') 59 | parser.add_argument('--pred-paths', nargs='+', help='paths to prediction files') 60 | args = parser.parse_args() 61 | report_mean_rouge(args.ref_path, args.pred_paths) 62 | --------------------------------------------------------------------------------
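As a usage sketch for the ROUGE reporting script above (all paths are illustrative, not files shipped with the repository): once preprocessing has produced a `.target` reference file and a trained model has written its predictions (train.py writes `test.predictions` to the output directory by default), mean ROUGE over one or more prediction files can be reported with:

```
python rouge/report_rouge.py \
    --ref-path data/test.target \
    --pred-paths output/run1/test.predictions output/run2/test.predictions
```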