├── .DS_Store
├── .gitignore
├── README.md
├── bin
│   ├── download_model
│   │   ├── bloom.sh
│   │   ├── gpt-neox-20b.sh
│   │   ├── opt-30b.sh
│   │   └── opt-66b.sh
│   └── setup.sh
├── dataset_LICENSE
├── img
│   └── intro.png
├── multiple_choice-dataset
│   ├── cnn_dm
│   │   ├── factcc
│   │   │   ├── binary_choice-using_dateswp.jsonl
│   │   │   ├── binary_choice-using_entswp.jsonl
│   │   │   ├── binary_choice-using_negation.jsonl
│   │   │   ├── binary_choice-using_numswp.jsonl
│   │   │   └── binary_choice-using_pronoun.jsonl
│   │   ├── factually_consistent-model_generated
│   │   │   ├── binary_choice-using_banditsumm_distractors.jsonl
│   │   │   ├── binary_choice-using_bert_lstm_pn_rl_distractors.jsonl
│   │   │   ├── binary_choice-using_heter_graph_distractors.jsonl
│   │   │   ├── binary_choice-using_lead3_distractors.jsonl
│   │   │   ├── binary_choice-using_matchsumm_distractors.jsonl
│   │   │   ├── binary_choice-using_mi_unsup_distractors.jsonl
│   │   │   ├── binary_choice-using_neusumm_distractors.jsonl
│   │   │   ├── binary_choice-using_oracle_disco_distractors.jsonl
│   │   │   ├── binary_choice-using_oracle_distractors.jsonl
│   │   │   ├── binary_choice-using_pacsum_bert_distractors.jsonl
│   │   │   ├── binary_choice-using_pacsum_tfidf_distractors.jsonl
│   │   │   ├── binary_choice-using_refresh_distractors.jsonl
│   │   │   ├── binary_choice-using_rnn_ext_rl_distractors.jsonl
│   │   │   ├── binary_choice-using_textrank_distractors.jsonl
│   │   │   └── binary_choice-using_textrank_st_distractors.jsonl
│   │   ├── fib
│   │   │   ├── binary_choice-using_banditsumm_distractors.jsonl
│   │   │   ├── binary_choice-using_bert_lstm_pn_rl_distractors.jsonl
│   │   │   ├── binary_choice-using_heter_graph_distractors.jsonl
│   │   │   ├── binary_choice-using_lead3_distractors.jsonl
│   │   │   ├── binary_choice-using_matchsumm_distractors.jsonl
│   │   │   ├── binary_choice-using_mi_unsup_distractors.jsonl
│   │   │   ├── binary_choice-using_neusumm_distractors.jsonl
│   │   │   ├── binary_choice-using_oracle_disco_distractors.jsonl
│   │   │   ├── binary_choice-using_oracle_distractors.jsonl
│   │   │   ├── binary_choice-using_pacsum_bert_distractors.jsonl
│   │   │   ├── binary_choice-using_pacsum_tfidf_distractors.jsonl
│   │   │   ├── binary_choice-using_refresh_distractors.jsonl
│   │   │   ├── binary_choice-using_rnn_ext_rl_distractors.jsonl
│   │   │   ├── binary_choice-using_textrank_distractors.jsonl
│   │   │   └── binary_choice-using_textrank_st_distractors.jsonl
│   │   ├── fir
│   │   │   └── binary_choice-using_non_factual_gold_distractors.jsonl
│   │   └── mfma
│   │       └── binary_choice-using_bart-base.jsonl
│   ├── fib.json
│   └── xsum
│       ├── factcc
│       │   ├── binary_choice-using_dateswp.jsonl
│       │   ├── binary_choice-using_entswp.jsonl
│       │   ├── binary_choice-using_negation.jsonl
│       │   ├── binary_choice-using_numswp.jsonl
│       │   └── binary_choice-using_pronoun.jsonl
│       ├── factually_consistent-model_generated
│       │   ├── binary_choice-using_bart-base_distractors.jsonl
│       │   ├── binary_choice-using_bart-large_distractors.jsonl
│       │   ├── binary_choice-using_bloom-560m_distractors.jsonl
│       │   ├── binary_choice-using_distil-bart_distractors.jsonl
│       │   ├── binary_choice-using_distil-pegasus_distractors.jsonl
│       │   ├── binary_choice-using_pegasus_distractors.jsonl
│       │   └── binary_choice-using_t5-large_distractors.jsonl
│       ├── fib
│       │   ├── binary_choice-using_bart-base_distractors.jsonl
│       │   ├── binary_choice-using_bart-large_distractors.jsonl
│       │   ├── binary_choice-using_bloom-560m_distractors.jsonl
│       │   ├── binary_choice-using_distil-bart_distractors.jsonl
│       │   ├── binary_choice-using_distil-pegasus_distractors.jsonl
│       │   ├── binary_choice-using_pegasus_distractors.jsonl
│       │   └── binary_choice-using_t5-large_distractors.jsonl
│       ├── fir
│       │   └── binary_choice-using_non_factual_gold_distractors.jsonl
│       └── mfma
│           ├── binary_choice-using_bart-base.jsonl
│           └── binary_choice-using_t5-base.jsonl
├── requirements.txt
├── software_LICENSE
└── src
    ├── compute_fib_results.py
    ├── constructors.py
    ├── data
    │   ├── Batcher.py
    │   ├── Dataset.py
    │   ├── multiple_choice.py
    │   ├── preprocess_data.py
    │   ├── preprocess_data_test.py
    │   └── templates.py
    ├── eval
    │   ├── PredictionLogger.py
    │   └── Scorer.py
    ├── evaluate_mulChoice.py
    ├── evaluate_mulChoice_test.py
    ├── get_results.py
    ├── models
    │   ├── DecoderWrappers_forMulChoice.py
    │   ├── DecoderWrappers_forMulChoice_test.py
    │   ├── EncoderDecoderWrappers_forMulChoice.py
    │   ├── EncoderDecoderWrappers_forMulChoice_test.py
    │   ├── device_maps.py
    │   ├── model_flags.py
    │   └── utils.py
    └── utils
        ├── CONSTANTS.py
        ├── deepspeed.py
        ├── test_helpers.py
        └── util.py

--------------------------------------------------------------------------------
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/r-three/fib/1d48ee2e52ac3f8ded69f4593db255ae2ba12200/.DS_Store

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
runs/
tmp/
bart_large
url_lists/
checkpoints/
data/
env/
exp_out
results/
wandb/
lib/
output/
data/*
pretrained_models
pretrained_models/
.idea/
eche_
runs/
*.pyc
slurm*
.installed.cfg
develop-eggs
dist
downloads
eggs
parts
src/*.egg-info
lib
lib64
!src/data
multiple_choice-score.jsonl
multiple_choice-predictions/

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# FIB

This repository contains the code for ["Evaluating the Factual Consistency of Large Language Models Through Summarization"](https://arxiv.org/abs/2211.08412).

## FIB Benchmark

The dataset is now on [HuggingFace](https://huggingface.co/datasets/r-three/fib) :hugs:
Note that multiple-choice accuracy is computed in a slightly different way in our work. See [below](#evaluating-models-on-fib) for more details.

## Evaluating Models

### Setup

1. Create a virtual environment and activate it.
```
python3 -m venv env
source env/bin/activate
```
2. Install the dependencies.
```
python -m pip install -r requirements.txt -f https://download.pytorch.org/whl/cu113/torch_stable.html
```
3. Set the environment variables. (This step has to be done every session.)
```
source bin/setup.sh
```

### Running Models

The following command is used to evaluate models:
```
python src/evaluate_mulChoice.py -f {multiple_choice-dataset_filepath} -m {model}
```

For example,
```commandline
python src/evaluate_mulChoice.py -f multiple_choice-dataset/xsum/fib/binary_choice-using_bart-base_distractors.jsonl -m facebook/opt-1.3b
```
Our code has only been tested on evaluating models from the BLOOM, OPT, GPT, and T5 families.

Note that although DeepSpeed support is implemented, we did not use it in our experiments, so our DeepSpeed integration might have bugs.
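
For intuition, the binary-choice evaluation boils down to checking whether the model assigns a higher score to the factually consistent summary than to the distractor. Below is a minimal, self-contained sketch of that idea using Hugging Face `transformers`. The `facebook/opt-125m` checkpoint, the `" Summary:"` prompt, and the use of average per-token log-likelihood as the scoring function are all illustrative assumptions; the repo's actual prompts and scoring live in `src/data/templates.py`, `src/models/`, and `src/evaluate_mulChoice.py` and may differ.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative only -- small placeholder model, not one used in the paper.
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
model.eval()

def avg_loglik(document: str, summary: str) -> float:
    """Average log-likelihood of the summary tokens, conditioned on the document.

    The prompt format is an assumption; masking at the prompt/summary boundary
    is approximate because tokenization can merge tokens across it.
    """
    prompt = document + " Summary:"
    prompt_len = tokenizer(prompt, return_tensors="pt").input_ids.shape[1]
    full_ids = tokenizer(prompt + " " + summary, return_tensors="pt").input_ids
    labels = full_ids.clone()
    labels[:, :prompt_len] = -100  # ignore prompt tokens; score only the summary
    with torch.no_grad():
        out = model(input_ids=full_ids, labels=labels)
    return -out.loss.item()  # loss is the mean NLL over the unmasked tokens

def predict(example: dict) -> int:
    """Index of the choice the model scores highest."""
    scores = [avg_loglik(example["input"], c) for c in example["list_choices"]]
    return scores.index(max(scores))

# An example counts as correct when predict(example) == example["lbl"].
```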

### Get Results

The following command is used to gather multiple results and get the median score:
```
python src/get_results.py -e {all_experiment_directories_of_datasets} -m {list_models}
```

For example,
```
python src/get_results.py -e exp_out/multiple_choice/xsum/fib/* -m bigscience-T0_3B
```

## Evaluating Models on FIB

The differences between the FIB dataset released above and the evaluation here are:
- Here, we take the median accuracy of the model across 3 prompts for each distractor model used. Then, we take a weighted average of the median accuracies across the different distractor models.
- In the FIB dataset, we combine all the examples from each distractor model, and from both XSum and CNN/DM, into one file to simplify it. Users can use any prompt they want.

The following commands run this evaluation:
```
python src/evaluate_mulChoice.py -f multiple_choice-dataset/{dataset}/fib/binary_choice-* -m {model}
python src/compute_fib_results.py -m {model} -d {dataset}
```

## Other Binary Multiple-Choice Datasets

The datasets are under ``multiple_choice-dataset/xsum`` and ``multiple_choice-dataset/cnn_dm`` for XSum and CNN/DM respectively.

The different alternative choices include:
1. FIB - our benchmark of factually inconsistent model-generated summaries
2. [FactCC](https://github.com/salesforce/factCC.git)
3. [MFMA](https://github.com/hwanheelee1993/MFMA)
4. FIR - factually inconsistent reference summaries (i.e., reference summaries from XSum or CNN/DM that were annotated as factually inconsistent)
5. Factually consistent model-generated summaries

Each example is a JSON object with the following keys: `{id, input, correct_choice, list_choices, lbl}`, where `lbl` is the index of `correct_choice` within `list_choices`.
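As a quick sanity check of the schema, here is a small snippet (ours, not part of the repo; the path is one example file from the tree above) that loads a binary-choice file and verifies that `lbl` points at `correct_choice`:

```python
import json

# Hypothetical inspection script, not part of the repo.
path = "multiple_choice-dataset/xsum/fib/binary_choice-using_bart-base_distractors.jsonl"

with open(path) as f:
    for line in f:
        example = json.loads(line)
        # `lbl` indexes the factually consistent summary within `list_choices`.
        assert example["list_choices"][example["lbl"]] == example["correct_choice"]
        print(example["id"], "| num choices:", len(example["list_choices"]))
```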
Linguistics", 130 | url = "https://aclanthology.org/2022.findings-naacl.76", 131 | doi = "10.18653/v1/2022.findings-naacl.76", 132 | pages = "1019--1030", 133 | } 134 | ``` 135 | -------------------------------------------------------------------------------- /bin/download_model/bloom.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | wget "https://huggingface.co/bigscience/bloom/raw/main/config.json" 4 | wget "https://huggingface.co/bigscience/bloom/raw/main/pytorch_model.bin.index.json" 5 | wget "https://huggingface.co/bigscience/bloom/raw/main/special_tokens_map.json" 6 | wget "https://huggingface.co/bigscience/bloom/resolve/main/tokenizer.json" 7 | wget "https://huggingface.co/bigscience/bloom/raw/main/tokenizer_config.json" 8 | 9 | 10 | for i in {1..9} 11 | do 12 | echo "$i" 13 | wget "https://huggingface.co/bigscience/bloom/resolve/main/pytorch_model_0000$i-of-00072.bin" 14 | done 15 | 16 | for i in {10..72} 17 | do 18 | echo "$i" 19 | wget "https://huggingface.co/bigscience/bloom/resolve/main/pytorch_model_000$i-of-00072.bin" 20 | done -------------------------------------------------------------------------------- /bin/download_model/gpt-neox-20b.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for i in {1..9} 4 | do 5 | echo "$i" 6 | wget "https://huggingface.co/EleutherAI/gpt-neox-20b/resolve/main/pytorch_model-0000$i-of-00046.bin" 7 | done 8 | 9 | for i in {10..46} 10 | do 11 | echo "$i" 12 | wget "https://huggingface.co/EleutherAI/gpt-neox-20b/resolve/main/pytorch_model-000$i-of-00046.bin" 13 | done -------------------------------------------------------------------------------- /bin/download_model/opt-30b.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | wget "https://huggingface.co/facebook/opt-30b/raw/main/config.json" 4 | wget "https://huggingface.co/facebook/opt-30b/raw/main/merges.txt" 5 | wget "https://huggingface.co/facebook/opt-30b/raw/main/pytorch_model.bin.index.json" 6 | wget "https://huggingface.co/facebook/opt-30b/raw/main/special_tokens_map.json" 7 | wget "https://huggingface.co/facebook/opt-30b/raw/main/tokenizer_config.json" 8 | wget "https://huggingface.co/facebook/opt-30b/raw/main/vocab.json" 9 | 10 | 11 | for i in {1..7} 12 | do 13 | echo "$i" 14 | wget "https://huggingface.co/facebook/opt-30b/resolve/main/pytorch_model-0000$i-of-00007.bin" 15 | done 16 | -------------------------------------------------------------------------------- /bin/download_model/opt-66b.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | wget "https://huggingface.co/facebook/opt-66b/raw/main/config.json" 4 | wget "https://huggingface.co/facebook/opt-66b/raw/main/merges.txt" 5 | wget "https://huggingface.co/facebook/opt-66b/raw/main/pytorch_model.bin.index.json" 6 | wget "https://huggingface.co/facebook/opt-66b/raw/main/special_tokens_map.json" 7 | wget "https://huggingface.co/facebook/opt-66b/raw/main/tokenizer_config.json" 8 | wget "https://huggingface.co/facebook/opt-66b/raw/main/vocab.json" 9 | 10 | 11 | for i in {1..9} 12 | do 13 | echo "$i" 14 | wget "https://huggingface.co/facebook/opt-66b/resolve/main/pytorch_model-0000$i-of-00014.bin" 15 | done 16 | 17 | for i in {10..14} 18 | do 19 | echo "$i" 20 | wget "https://huggingface.co/facebook/opt-66b/resolve/main/pytorch_model-000$i-of-00014.bin" 21 | done 
-------------------------------------------------------------------------------- /bin/setup.sh: -------------------------------------------------------------------------------- 1 | source env/bin/activate 2 | export LFQA_FAC_ROOT=`pwd` 3 | export PYTHONPATH=$LFQA_FAC_ROOT:$PYTHONPATH 4 | export PYTHON_EXEC=python 5 | -------------------------------------------------------------------------------- /dataset_LICENSE: -------------------------------------------------------------------------------- 1 | Attribution 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. 
More considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution 4.0 International Public License 58 | 59 | By exercising the Licensed Rights (defined below), You accept and agree 60 | to be bound by the terms and conditions of this Creative Commons 61 | Attribution 4.0 International Public License ("Public License"). To the 62 | extent this Public License may be interpreted as a contract, You are 63 | granted the Licensed Rights in consideration of Your acceptance of 64 | these terms and conditions, and the Licensor grants You such rights in 65 | consideration of benefits the Licensor receives from making the 66 | Licensed Material available under these terms and conditions. 67 | 68 | 69 | Section 1 -- Definitions. 70 | 71 | a. Adapted Material means material subject to Copyright and Similar 72 | Rights that is derived from or based upon the Licensed Material 73 | and in which the Licensed Material is translated, altered, 74 | arranged, transformed, or otherwise modified in a manner requiring 75 | permission under the Copyright and Similar Rights held by the 76 | Licensor. For purposes of this Public License, where the Licensed 77 | Material is a musical work, performance, or sound recording, 78 | Adapted Material is always produced where the Licensed Material is 79 | synched in timed relation with a moving image. 80 | 81 | b. Adapter's License means the license You apply to Your Copyright 82 | and Similar Rights in Your contributions to Adapted Material in 83 | accordance with the terms and conditions of this Public License. 84 | 85 | c. Copyright and Similar Rights means copyright and/or similar rights 86 | closely related to copyright including, without limitation, 87 | performance, broadcast, sound recording, and Sui Generis Database 88 | Rights, without regard to how the rights are labeled or 89 | categorized. For purposes of this Public License, the rights 90 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 91 | Rights. 92 | 93 | d. Effective Technological Measures means those measures that, in the 94 | absence of proper authority, may not be circumvented under laws 95 | fulfilling obligations under Article 11 of the WIPO Copyright 96 | Treaty adopted on December 20, 1996, and/or similar international 97 | agreements. 98 | 99 | e. Exceptions and Limitations means fair use, fair dealing, and/or 100 | any other exception or limitation to Copyright and Similar Rights 101 | that applies to Your use of the Licensed Material. 102 | 103 | f. Licensed Material means the artistic or literary work, database, 104 | or other material to which the Licensor applied this Public 105 | License. 106 | 107 | g. Licensed Rights means the rights granted to You subject to the 108 | terms and conditions of this Public License, which are limited to 109 | all Copyright and Similar Rights that apply to Your use of the 110 | Licensed Material and that the Licensor has authority to license. 111 | 112 | h. Licensor means the individual(s) or entity(ies) granting rights 113 | under this Public License. 114 | 115 | i. 
Share means to provide material to the public by any means or 116 | process that requires permission under the Licensed Rights, such 117 | as reproduction, public display, public performance, distribution, 118 | dissemination, communication, or importation, and to make material 119 | available to the public including in ways that members of the 120 | public may access the material from a place and at a time 121 | individually chosen by them. 122 | 123 | j. Sui Generis Database Rights means rights other than copyright 124 | resulting from Directive 96/9/EC of the European Parliament and of 125 | the Council of 11 March 1996 on the legal protection of databases, 126 | as amended and/or succeeded, as well as other essentially 127 | equivalent rights anywhere in the world. 128 | 129 | k. You means the individual or entity exercising the Licensed Rights 130 | under this Public License. Your has a corresponding meaning. 131 | 132 | 133 | Section 2 -- Scope. 134 | 135 | a. License grant. 136 | 137 | 1. Subject to the terms and conditions of this Public License, 138 | the Licensor hereby grants You a worldwide, royalty-free, 139 | non-sublicensable, non-exclusive, irrevocable license to 140 | exercise the Licensed Rights in the Licensed Material to: 141 | 142 | a. reproduce and Share the Licensed Material, in whole or 143 | in part; and 144 | 145 | b. produce, reproduce, and Share Adapted Material. 146 | 147 | 2. Exceptions and Limitations. For the avoidance of doubt, where 148 | Exceptions and Limitations apply to Your use, this Public 149 | License does not apply, and You do not need to comply with 150 | its terms and conditions. 151 | 152 | 3. Term. The term of this Public License is specified in Section 153 | 6(a). 154 | 155 | 4. Media and formats; technical modifications allowed. The 156 | Licensor authorizes You to exercise the Licensed Rights in 157 | all media and formats whether now known or hereafter created, 158 | and to make technical modifications necessary to do so. The 159 | Licensor waives and/or agrees not to assert any right or 160 | authority to forbid You from making technical modifications 161 | necessary to exercise the Licensed Rights, including 162 | technical modifications necessary to circumvent Effective 163 | Technological Measures. For purposes of this Public License, 164 | simply making modifications authorized by this Section 2(a) 165 | (4) never produces Adapted Material. 166 | 167 | 5. Downstream recipients. 168 | 169 | a. Offer from the Licensor -- Licensed Material. Every 170 | recipient of the Licensed Material automatically 171 | receives an offer from the Licensor to exercise the 172 | Licensed Rights under the terms and conditions of this 173 | Public License. 174 | 175 | b. No downstream restrictions. You may not offer or impose 176 | any additional or different terms or conditions on, or 177 | apply any Effective Technological Measures to, the 178 | Licensed Material if doing so restricts exercise of the 179 | Licensed Rights by any recipient of the Licensed 180 | Material. 181 | 182 | 6. No endorsement. Nothing in this Public License constitutes or 183 | may be construed as permission to assert or imply that You 184 | are, or that Your use of the Licensed Material is, connected 185 | with, or sponsored, endorsed, or granted official status by, 186 | the Licensor or others designated to receive attribution as 187 | provided in Section 3(a)(1)(A)(i). 188 | 189 | b. Other rights. 190 | 191 | 1. 
Moral rights, such as the right of integrity, are not 192 | licensed under this Public License, nor are publicity, 193 | privacy, and/or other similar personality rights; however, to 194 | the extent possible, the Licensor waives and/or agrees not to 195 | assert any such rights held by the Licensor to the limited 196 | extent necessary to allow You to exercise the Licensed 197 | Rights, but not otherwise. 198 | 199 | 2. Patent and trademark rights are not licensed under this 200 | Public License. 201 | 202 | 3. To the extent possible, the Licensor waives any right to 203 | collect royalties from You for the exercise of the Licensed 204 | Rights, whether directly or through a collecting society 205 | under any voluntary or waivable statutory or compulsory 206 | licensing scheme. In all other cases the Licensor expressly 207 | reserves any right to collect such royalties. 208 | 209 | 210 | Section 3 -- License Conditions. 211 | 212 | Your exercise of the Licensed Rights is expressly made subject to the 213 | following conditions. 214 | 215 | a. Attribution. 216 | 217 | 1. If You Share the Licensed Material (including in modified 218 | form), You must: 219 | 220 | a. retain the following if it is supplied by the Licensor 221 | with the Licensed Material: 222 | 223 | i. identification of the creator(s) of the Licensed 224 | Material and any others designated to receive 225 | attribution, in any reasonable manner requested by 226 | the Licensor (including by pseudonym if 227 | designated); 228 | 229 | ii. a copyright notice; 230 | 231 | iii. a notice that refers to this Public License; 232 | 233 | iv. a notice that refers to the disclaimer of 234 | warranties; 235 | 236 | v. a URI or hyperlink to the Licensed Material to the 237 | extent reasonably practicable; 238 | 239 | b. indicate if You modified the Licensed Material and 240 | retain an indication of any previous modifications; and 241 | 242 | c. indicate the Licensed Material is licensed under this 243 | Public License, and include the text of, or the URI or 244 | hyperlink to, this Public License. 245 | 246 | 2. You may satisfy the conditions in Section 3(a)(1) in any 247 | reasonable manner based on the medium, means, and context in 248 | which You Share the Licensed Material. For example, it may be 249 | reasonable to satisfy the conditions by providing a URI or 250 | hyperlink to a resource that includes the required 251 | information. 252 | 253 | 3. If requested by the Licensor, You must remove any of the 254 | information required by Section 3(a)(1)(A) to the extent 255 | reasonably practicable. 256 | 257 | 4. If You Share Adapted Material You produce, the Adapter's 258 | License You apply must not prevent recipients of the Adapted 259 | Material from complying with this Public License. 260 | 261 | 262 | Section 4 -- Sui Generis Database Rights. 263 | 264 | Where the Licensed Rights include Sui Generis Database Rights that 265 | apply to Your use of the Licensed Material: 266 | 267 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 268 | to extract, reuse, reproduce, and Share all or a substantial 269 | portion of the contents of the database; 270 | 271 | b. if You include all or a substantial portion of the database 272 | contents in a database in which You have Sui Generis Database 273 | Rights, then the database in which You have Sui Generis Database 274 | Rights (but not its individual contents) is Adapted Material; and 275 | 276 | c. 
You must comply with the conditions in Section 3(a) if You Share 277 | all or a substantial portion of the contents of the database. 278 | 279 | For the avoidance of doubt, this Section 4 supplements and does not 280 | replace Your obligations under this Public License where the Licensed 281 | Rights include other Copyright and Similar Rights. 282 | 283 | 284 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 285 | 286 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 287 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 288 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 289 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 290 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 291 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 292 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 293 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 294 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 295 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 296 | 297 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 298 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 299 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 300 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 301 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 302 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 303 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 304 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 305 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 306 | 307 | c. The disclaimer of warranties and limitation of liability provided 308 | above shall be interpreted in a manner that, to the extent 309 | possible, most closely approximates an absolute disclaimer and 310 | waiver of all liability. 311 | 312 | 313 | Section 6 -- Term and Termination. 314 | 315 | a. This Public License applies for the term of the Copyright and 316 | Similar Rights licensed here. However, if You fail to comply with 317 | this Public License, then Your rights under this Public License 318 | terminate automatically. 319 | 320 | b. Where Your right to use the Licensed Material has terminated under 321 | Section 6(a), it reinstates: 322 | 323 | 1. automatically as of the date the violation is cured, provided 324 | it is cured within 30 days of Your discovery of the 325 | violation; or 326 | 327 | 2. upon express reinstatement by the Licensor. 328 | 329 | For the avoidance of doubt, this Section 6(b) does not affect any 330 | right the Licensor may have to seek remedies for Your violations 331 | of this Public License. 332 | 333 | c. For the avoidance of doubt, the Licensor may also offer the 334 | Licensed Material under separate terms or conditions or stop 335 | distributing the Licensed Material at any time; however, doing so 336 | will not terminate this Public License. 337 | 338 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 339 | License. 340 | 341 | 342 | Section 7 -- Other Terms and Conditions. 343 | 344 | a. The Licensor shall not be bound by any additional or different 345 | terms or conditions communicated by You unless expressly agreed. 346 | 347 | b. 
Any arrangements, understandings, or agreements regarding the 348 | Licensed Material not stated herein are separate from and 349 | independent of the terms and conditions of this Public License. 350 | 351 | 352 | Section 8 -- Interpretation. 353 | 354 | a. For the avoidance of doubt, this Public License does not, and 355 | shall not be interpreted to, reduce, limit, restrict, or impose 356 | conditions on any use of the Licensed Material that could lawfully 357 | be made without permission under this Public License. 358 | 359 | b. To the extent possible, if any provision of this Public License is 360 | deemed unenforceable, it shall be automatically reformed to the 361 | minimum extent necessary to make it enforceable. If the provision 362 | cannot be reformed, it shall be severed from this Public License 363 | without affecting the enforceability of the remaining terms and 364 | conditions. 365 | 366 | c. No term or condition of this Public License will be waived and no 367 | failure to comply consented to unless expressly agreed to by the 368 | Licensor. 369 | 370 | d. Nothing in this Public License constitutes or may be interpreted 371 | as a limitation upon, or waiver of, any privileges and immunities 372 | that apply to the Licensor or You, including from the legal 373 | processes of any jurisdiction or authority. 374 | 375 | 376 | ======================================================================= 377 | 378 | Creative Commons is not a party to its public 379 | licenses. Notwithstanding, Creative Commons may elect to apply one of 380 | its public licenses to material it publishes and in those instances 381 | will be considered the “Licensor.” The text of the Creative Commons 382 | public licenses is dedicated to the public domain under the CC0 Public 383 | Domain Dedication. Except for the limited purpose of indicating that 384 | material is shared under a Creative Commons public license or as 385 | otherwise permitted by the Creative Commons policies published at 386 | creativecommons.org/policies, Creative Commons does not authorize the 387 | use of the trademark "Creative Commons" or any other trademark or logo 388 | of Creative Commons without its prior written consent including, 389 | without limitation, in connection with any unauthorized modifications 390 | to any of its public licenses or any other arrangements, 391 | understandings, or agreements concerning use of licensed material. For 392 | the avoidance of doubt, this paragraph does not form part of the 393 | public licenses. 394 | 395 | Creative Commons may be contacted at creativecommons.org. 396 | -------------------------------------------------------------------------------- /img/intro.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/r-three/fib/1d48ee2e52ac3f8ded69f4593db255ae2ba12200/img/intro.png -------------------------------------------------------------------------------- /multiple_choice-dataset/cnn_dm/fib/binary_choice-using_lead3_distractors.jsonl: -------------------------------------------------------------------------------- 1 | {"id": "f632a7540e8641033ce0ff316bee6fe38df6a972", "input": "( cnn ) the last time muhammadu buhari came to power in nigeria , it was by force . this time it was by the ballot box . here are five reasons why one of the most fiercely-contested elections in the country 's history is so important . for the first time in nigeria 's history , the opposition defeated the ruling party in democratic elections . 
muhammadu buhari , 72 , won nigeria 's presidential election , defeating incumbent goodluck jonathan by about two million votes . nigeria is significant because it is the biggest economy and most populous country in africa ; it is also one of africa 's largest oil producers and is a major supplier of crude oil to the united states . this is n't buhari 's first time leading nigeria -- but it 's his first time in nearly 30 years . the reformed dictator is a sunni muslim from nigeria 's poorer north , while jonathan comes from a christian and animist south that is rich with oil . buhari 's win comes after a long history of military rule , coups and botched attempts at democracy in the country . many nigerians told cnn that they saw president jonathan as an ineffectual leader who was indecisive in dealing with the terror group boko haram -- and weak on corruption . buhari , who was campaigning for the fourth time , capitalized on these weaknesses and some analysts believe that his military background was an advantage for him . nigerians wanted a strong leader who could keep them safe from boko haram 's murderous raids -- and buhari also campaigned as a born-again democrat to allay fears about his strict military regime the last time around . he stressed that nigeria 's security needs to be the next government 's focus . his campaign was also fiercely anti-corruption -- he ran under the slogan of `` new broom , '' and his supporters were often pictured holding brooms in the lead-up to the vote . the elections were largely predicted to be violent and everyone , nigerians included , expected the worst . some families moved abroad and there was sporadic violence across the country in the lead up to the election . but those fears turned out to be mostly unfounded , and the elections held relatively peacefully -- with the exception of attacks in the north of the country , where around 11 people died . many also praised president jonathan 's gracious and quick concession of defeat as it almost certainly prevented post-election violence . president-elect buhari said wednesday in a speech to the nation : `` the eyes of the world were focused on us to see if we can vote in a peaceful way and carry out elections in an orderly manner . `` we have proven to the world that we are a people who have embraced democracy and a people who seek a government by , for and for the people . '' on election day , nigerians queued for hours in hot weather to cast their vote . some of the biometric reader machines malfunctioned -- including the one at president jonathan 's polling station -- and voting had to be extended into the following day . but the technical issues did n't keep people from voting -- and in lagos , some voters cast their ballots with the aid of the light from their mobile phones . and even though some card readers did n't work in some places , many say they helped to cut down on vote rigging . boko haram is n't the only obstacle facing the new president . the economy , crime and generating enough power to light up the country are other major issues . the pressure will soon be on buhari to deliver and there will be no excuses . if he fails , nigerians will be waiting for him at the polls just four short years from now .", "correct_choice": " muhammadu buhari 's win marks the first democratic transition of power from a ruling party to the opposition . nigeria , the most populous country in africa , is grappling with violent boko haram extremists . 
", "list_choices": [" ( cnn ) the last time muhammadu buhari came to power in nigeria , it was by force . this time it was by the ballot box . here are five reasons why one of the most fiercely-contested elections in the country 's history is so important . ", " muhammadu buhari 's win marks the first democratic transition of power from a ruling party to the opposition . nigeria , the most populous country in africa , is grappling with violent boko haram extremists . "], "lbl": 1} 2 | {"id": "1a04fa5c48398379b500e1336c81decb2e9f82e1", "input": "( cnn ) it was all set for a fairytale ending for record breaking jockey ap mccoy . in the end it was a different but familiar name who won the grand national on saturday . 25-1 outsider many clouds , who had shown little form going into the race , won by a length and a half , ridden by jockey leighton aspell . aspell won last year 's grand national too , making him the first jockey since the 1950s to ride back-to-back winners on different horses . `` it feels wonderful , i asked big questions , '' aspell said of many clouds , moments after his victory . `` over the fences he was awesome . i was just hoping his batteries would last and they did , '' he added . no fairytale . yet for much of the grand national -- arguably the world 's most famous and certainly the sport 's most prestigious jump race -- it looked as if ap mccoy was about to write an ending befitting the career of a man who has dominated jump racing for two decades . his horse shutthefrontdoor was in the leading group as it negotiated the likes for becher 's brooke and the chair , some of the toughest jumps in racing . last week the 40-year-old ulsterman , who has won an astonishing 4,356 races , announced he would retire if he won the grand national for the second time in his career . shutthefrontdoor was heavily backed by the betting public sensing a storybook conclusion to mccoy 's career . uk and irish betting firms even predicted they would lose as much as $ 73 million if mccoy won . he was well placed going into the final straight but just could n't keep up after many clouds cut lose , and finished back in fifth . third time winner . but for trevor hemmings , the owner of many clouds , it was his third victory in the grand national . `` i always dreamed of winning my first national , '' a shocked hemmings told channel 4 . `` then along comes a second . that 's special . and when a third comes along , it 's such a wonderful , wonderful feeling . '' hemming went on to praise aspell 's performance . `` this morning talking we talked about the achievers , '' said hemmings . `` they are quiet , confident and experienced . he has all of them . '' mccoy 's fifth placed finish means he will race again at least once more , in two weeks time at sandown .", "correct_choice": " 25-1 shot many clouds wins grand national . second win a row for jockey leighton aspell . first jockey to win two in a row on different horses since 1950s . ", "list_choices": [" 25-1 shot many clouds wins grand national . second win a row for jockey leighton aspell . first jockey to win two in a row on different horses since 1950s . ", " ( cnn ) it was all set for a fairytale ending for record breaking jockey ap mccoy . in the end it was a different but familiar name who won the grand national on saturday . 25-1 outsider many clouds , who had shown little form going into the race , won by a length and a half , ridden by jockey leighton aspell . 
"], "lbl": 0} 3 | {"id": "37e94bbc8fbced5a08a1d869f8b752b43c605680", "input": "( cnn ) call it a little piece of heaven for a family torn apart by tragedy . back in july , sierra sharry and lane smith were just about to become parents . sharry was eight months pregnant . but then smith fell and hit his head . he was taken to the ou medical center in oklahoma city . smith never recovered . `` july 13th 2014 was the absolute worst day of my life , '' sharry posted on facebook . `` i lost my best friend . the father of my unborn child . '' their son taos arrived a few weeks later . when it was time for his 6-month pictures , sharry had a special request . maybe the photographer could make their family complete , just for one picture . `` they asked me if i would be willing to ' play around ' with capturing their first family photo by editing taos ' daddy in one of their pictures , '' kayli rene ' photography posted on facebook . `` i just got to thinking , we do n't have a picture with lane in it , '' the new mom told cnn affilaite koco . the photographer was n't sure it would work , but they found just the right picture of smith -- one that has him looking over his family 's shoulder . `` lane 's not physically here with us , of course , but that picture represents to us that he is always watching over us and he will always be there for us no matter what , '' sharry said . the family photo has become a social media sensation after appearing on the photographer 's facebook page this week . it has some 193,000 likes and more than 24,000 shares . `` i ca n't believe she actually did this , '' sharry said . `` it 's like amazing and apparently everyone else thinks it is too . ''", "correct_choice": " sierra sharry was eight months pregnant when her son 's father died . a photographer was able to add lane smith to the family photo . ", "list_choices": [" ( cnn ) call it a little piece of heaven for a family torn apart by tragedy . back in july , sierra sharry and lane smith were just about to become parents . sharry was eight months pregnant . ", " sierra sharry was eight months pregnant when her son 's father died . a photographer was able to add lane smith to the family photo . "], "lbl": 1} 4 | {"id": "9383c2b3fd48fd3445bdfa3adf26485c1126c356", "input": "( cnn ) it did n't seem like a fair fight . on one side were hulking football players and pro wrestlers , competing as teams of two to eat as many pounds of steak as they could , combined , in one hour . on another was a lone 124-pound mother of four . and sure enough , in the end , sunday 's contest at big texan steak ranch in amarillo , texas , was n't even close . molly schuyler scarfed down three 72-ounce steaks , three baked potatoes , three side salads , three rolls and three shrimp cocktails -- far outpacing her heftier rivals . that 's more than 13 pounds of steak , not counting the sides . and she did it all in 20 minutes , setting a record in the process . `` we 've been doing this contest since 1960 , and in all that time we 've never had anybody come in to actually eat that many steaks at one time , '' bobby lee , who co-owns the big texan , told cnn affiliate kvii . `` so this is a first for us , and after 55 years of it , it 's a big deal . '' in fairness , schuyler is n't your typical 124-pound person . the nebraska native , 35 , is a professional on the competitive-eating circuit and once gobbled 363 chicken wings in 30 minutes . 
wearing shades and a black hoodie , schuyler beat four other teams on sunday , including pairs of football players and pro wrestlers and two married competitive eaters . she also broke her own big texan record of two 72-ounce steaks and sides , set last year , when she bested previous record-holder joey `` jaws '' chestnut . the landmark big texan restaurant offers its `` 72-ounce challenge '' daily to anyone who can eat the massive steak , plus fixings , in under an hour . those who ca n't do so must pay $ 72 for the meal . schuyler , who now lives in sacramento , california , won $ 5,000 for her efforts . her feat will be submitted to guinness world records . but mostly , she just seemed pleased to enjoy a hearty meal on the house . `` it 's free , so i 'm pretty happy about that , '' she told kvii . `` otherwise it would have cost me about 300 bucks . ''", "correct_choice": " molly schuyler scarfed down three 72-ounce steaks sunday in amarillo , texas . the sacramento woman , 35 , is a professional on the competitive-eating circuit . ", "list_choices": [" ( cnn ) it did n't seem like a fair fight . on one side were hulking football players and pro wrestlers , competing as teams of two to eat as many pounds of steak as they could , combined , in one hour . on another was a lone 124-pound mother of four . ", " molly schuyler scarfed down three 72-ounce steaks sunday in amarillo , texas . the sacramento woman , 35 , is a professional on the competitive-eating circuit . "], "lbl": 1} 5 | {"id": "1b02f51e4c599cdc6acf88b41f4c7ff0bc02eafa", "input": "( cnn ) people magazine has anointed sandra bullock the world 's most beautiful woman of 2015 , the publication revealed on wednesday . bullock , 50 , joins a long line of actresses to receive the honor , including last year 's cover girl , lupita nyong'o , and gwyneth paltrow in 2013 . she seems to be taking it all in stride , calling the whole thing `` ridiculous . '' `` real beauty is quiet . especially in this town , it 's just so hard not to say , ' oh , i need to look like that , ' `` she told people . `` no , be a good person ; be a good mom ; do a good job with the lunch ; let someone cut in front of you who looks like they 're in a bigger hurry . the people i find most beautiful are the ones who are n't trying . '' the cover story focuses on bullock 's home life with her son , louis , 5 , and her efforts to stay healthy and fit past her 40s . `` i was putting him to bed and told him that even when i 'm old and gray and more wrinkly than i am now , i 'll still love him and want to tuck him in , '' she said . `` and he asked why i have wrinkles , and i said , ' well , i hope some of them are from laughing so much . ' and he touched my face and said , ' you 're not old , you 're just happy . ' `` the oscar-winning star of movies including `` gravity , '' `` the blind side '' and `` crash '' said she 's happy with who she is . `` as long as i 'm healthy and strong and i do n't let this mind of mine run amok with insecurities about what i am not , i can look in the mirror and like who i see . '' the selection of bullock , the oldest woman to receive top honors in the history of the list , is a sign that beauty knows no age , say some . `` great choice ! gorgeous , talented , over 50 and fabulous ! that 's the way it 's done ! '' wrote one fan on people 's facebook page . also making the `` most beautiful '' cut this year : gabrielle union , ariana grande and laverne cox . 
the issue hits newsstands friday .", "correct_choice": " people magazine has named actress sandra bullock the most beautiful woman in the world . `` be a good person ; be a good mom ; do a good job with the lunch , '' she says . ", "list_choices": [" ( cnn ) people magazine has anointed sandra bullock the world 's most beautiful woman of 2015 , the publication revealed on wednesday . bullock , 50 , joins a long line of actresses to receive the honor , including last year 's cover girl , lupita nyong'o , and gwyneth paltrow in 2013 . she seems to be taking it all in stride , calling the whole thing `` ridiculous . '' ", " people magazine has named actress sandra bullock the most beautiful woman in the world . `` be a good person ; be a good mom ; do a good job with the lunch , '' she says . "], "lbl": 1} 6 | -------------------------------------------------------------------------------- /multiple_choice-dataset/xsum/factually_consistent-model_generated/binary_choice-using_t5-large_distractors.jsonl: -------------------------------------------------------------------------------- 1 | {"id": "38279745", "input": "Dozens of others were injured in the explosion early on Sunday.\nThe bomber struck at the entrance of the city's main port facilities. Residents say the blast could be heard across Mogadishu.\nNo group has said it carried out the attack, but the Somali Islamist group al-Shabab often carries out such bombings in the capital.\n\"We assisted 48 wounded people and carried 16 others who were killed in the blast,\" said Abdikadir Abdirahman Adem, head Mogadishu's Amin ambulance service.\nThe death toll is expected to rise further.", "correct_choice": "A bomber has killed at least 16 people in the city, officials say.", "list_choices": ["A bomber has killed at least 16 people in the city, officials say.", "A bomb in a major port in the capital of Mogadishu, which killed at least a dozen people, has killed at least a dozen people."], "lbl": 0} 2 | {"id": "34527912", "input": "Broken swords and spearheads were found by archaeologists on the RSPB Scotland nature reserve.\nTwelve pieces excavated from several different weapons have been handed over to Kilmartin Museum in Argyll.\nRSPB Scotland reserves archaeologist Jill Harden said they had probably been deliberately broken and thrown into a loch as part of a religious ceremony.\n\"This is the first discovery of this size from Argyll for many years,\" she said.\n\"The items were recovered from what had once been a freshwater loch - it seems that they had been purposely broken and cast into the waters as part of a ceremony, most likely as offerings or gifts to the gods or goddesses of the time.\n\"It is recorded that bronze swords were found on Coll in the 19th Century during drainage works, but their whereabouts today are unknown.\"\nThe archaeological investigation was directed by the Treasure Trove Unit, National Museums Scotland and RSPB Scotland.\nTrevor Cowie, from National Museums Scotland's department of Scottish history and archaeology, said: \"While a fair number of objects from this period have been discovered in the west of Scotland in the past, we generally know very little about the precise places where they were found.\n\"Archaeological techniques have developed dramatically since those 19th Century discoveries were made, so we have a great opportunity here to resolve many unanswered questions about life on Coll some 3,000 years ago.\"\nThe weapons can be viewed at the the Isle of Coll's An Cridhe community centre on Thursday 
and Friday.", "correct_choice": "Bronze weapons have been discovered on a Scotland nature reserve.", "list_choices": ["A collection of bronze swords and spearheads have been found in a loch in Argyll.", "Bronze weapons have been discovered on a Scotland nature reserve."], "lbl": 1} 3 | {"id": "35217021", "input": "Zabair Hussain, 41, was discovered with multiple injuries to his head and body in Staniforth Road, Darnall, Sheffield, at about 23:20 GMT. He later died at the scene.\nThe 28-year-old arrested man has been taken into police custody.\nOfficers believe a number of men were involved in an assault and have appealed for witnesses to come forward.\nDet Ch Insp Steve Handley, from South Yorkshire Police, said: \"We are still in the very early stages of the investigation and we're carrying out numerous enquiries to get to the bottom of what happened - from reviewing CCTV footage to speaking to potential witnesses.\n\"While I understand that incidents like this are worrying for those living locally, we have increased patrols by neighbourhood officers to reassure residents.\"", "correct_choice": "A man has been arrested after another man's body was found in a street.", "list_choices": ["A man has been arrested after another man's body was found in a street.", "A man has been arrested in connection with the death of a man in Sheffield."], "lbl": 0} 4 | {"id": "34971770", "input": "Official numbers showed revenues down 32.2% for the period to 16.4bn Macau patacas ($2.05bn; \u00c2\u00a31.36bn).\nExpectations were for a fall in revenues of just over 31%.\nMacau is the world's largest gaming centre - ahead of Las Vegas - and the only place in China where casinos are allowed.\nA special administrative region of China, Macau's economy relies heavily on gambling and shopping - especially by big spending tourists from the mainland.\nBut Chinese President Xi Jinping's campaign against corruption and luxury spending, which began in December 2012, has seen officials and others from the mainland more wary of gaming and spending in the city.\nChina's Communist Party prohibits officials from gambling, but until the 2012 crackdown, officials had reportedly visited Macau's casinos to gamble and spend.\nChina has emphasised Macau's need to diversify its economy away from gambling. The city's build up of new resorts and hotels is expected to help drive general tourism, however, analysts have said Macau will be hard-pressed to build up non-gaming streams of revenue in the near future.\nOfficial numbers released on Monday showed the city's economy shrank by 24.2% year-on-year during three months to September, the city's Statistics and Census Service said.\n\"Economic contraction in the third quarter was attributable to the continuous decline in exports of services, of which exports of gaming services decreased by 37.4% year-on-year and exports of other tourism services dropped by 15.3%,\" it added.\nOnce a Portuguese colony, gaming has taken place in Macau for more than 300 years. For many years it was referred to as the Monte Carlo of the Orient. 
The city was returned to Chinese rule in 1999.", "correct_choice": "Revenue in Macau fell by more than a third as China's corruption crackdown continued to drive away some tourists.", "list_choices": ["Revenue in Macau fell by more than a third as China's corruption crackdown continued to drive away some tourists.", "The world's biggest gambling centre, Macau, has reported a fall in revenues in the third quarter of the year."], "lbl": 0} 5 | {"id": "36348210", "input": "The 51-year-old had been negotiating a release from his contract following a rift with the board over his budget.\nHughes has been with the Highlanders since December 2013 and won the Scottish Cup last year, the club's first major honour.\n\"John will be remembered as a member of a great winning team,\" read a brief statement from Inverness CT.\nHughes had become increasingly frustrated at the loss of key squad members and spoke of his disappointment when an approach from Dundee United was blocked earlier this season.\nHaving previously managed at Falkirk, Hibernian, Hartlepool and Livingston, he replaced Terry Butcher at the Caledonian Stadium.\nAs well as lifting the Scottish Cup, Hughes steered Inverness to a third place finish in the Premiership last season, with this campaign opening with their first taste of European football.\nIn March 2014, Inverness reached the League Cup final, losing on penalties to Aberdeen.\nThe Inverness statement contained a message on behalf of Hughes, saying: \"I will look back on my time in the Highlands with a genuine fondness and warm affection for the club, the area and the community.\n\"The welcome I received from the fans and the response I got from the players throughout my two-and-a-half years there will live long in the memory as will everything else we shared in some of the ground-breaking successes we all enjoyed together during that period.\n\"I can readily assure my successor that they will inherit an excellent group of players and to each and every one of them could I also say a huge thanks for making my time with them so successful and so memorable - I wish them and the club every success in the future.\"", "correct_choice": "Inverness have confirmed the departure of John Hughes.", "list_choices": ["John Hughes has been named as Inverness CT manager after being sacked by Inverness CT.", "Inverness have confirmed the departure of John Hughes."], "lbl": 1} 6 | {"id": "31920680", "input": "The CQC previously rated the Penberthy home in Newquay as inadequate.\nNew reports highlight problems at three other homes run by Cornwall Care: Headlands in Carbis Bay, Trevern in Falmouth and Blackwood in Camborne.\nCornwall Care said it was rare for an inspection not to point out areas for improvement.\nThe CQC said Headlands was \"unsafe\" and overall \"was not caring\".\nAt Trevern \"one person had not been able to have a bath or shower for eleven months due to the home not obtaining the appropriate bathing equipment to meet the person's needs,\" the report stated.\nAction was also needed to address the \"care and welfare of people who use services\" and the \"safety and suitability of premises,\" it was claimed.\nThe report on Blackwood said \"people did not always have access to meaningful activities\" and action was needed regarding the \"safety and suitability of premises\".\nDue to changes in CQC reporting procedures the reports did not give an overall rating as it has done for Penberthy.\nAdrian Hughes, the commission's deputy chief inspector of adult social care, said there had 
been \"slippage\" in services provided by Cornwall Care.\nHe said: \"They have taken their eye off the ball in some aspects of that care.\"\nA spokesman for Cornwall Care said: \"We have worked closely with CQC and commissioners for many years and it is rare that an inspection of any care service does not point out areas for improvement.\n\"We welcome that feedback and always act quickly to make sure we are offering the best possible service to our clients.\"", "correct_choice": "Action is needed at homes for the elderly run by Cornwall Care, after the company took its \"eye off the ball\", the CQC said.", "list_choices": ["The CQC has rated a home in Penberthy as inadequate, according to a report.", "Action is needed at homes for the elderly run by Cornwall Care, after the company took its \"eye off the ball\", the CQC said."], "lbl": 1} 7 | {"id": "14370062", "input": "Staff in Jobcentres, banks, building societies and utility companies in England could also be trained to spot - and counsel - vulnerable people.\nThe ideas are raised in a consultation paper on suicide prevention.\nThe Samaritans said councils should have a mandatory responsibility to try to prevent suicides in their areas.\nSome 4,400 people killed themselves in England in 2009.\nClaire Wylie, head of policy and research at the Samaritans, told the BBC News website that many suicide attempts were made on impulse, so trying to restrict access to potentially lethal means was important.\n\"We know that people who are feeling suicidal are often very ambivalent about actually ending their lives,\" she said.\n\"If you can interrupt them at that moment you can prevent them going ahead.\"\nPreventing deaths by jumping is a key aim of the consultation and it suggests a number of ways of doing that.\nThey include:\nOverall, the number of suicides has steadily fallen in recent years, but the number of deaths on Britain's rail network had been rising until last year.\nHowever, specialist training from Samaritans for rail staff was key to an 11% fall in 2010, according to the Rail Safety and Standards Board.\nLondon Underground is also rolling out training to all of its staff after a pilot project at one station close to a psychiatric inpatient unit helped reduce suicides.\nThe government wants to see that sort of training given to a much wider range of people who come into contact with individuals who could be vulnerable because of their social or economic circumstances.\nJobcentre and benefit office staff, as well as employees in banks, building societies and utility firms are among those suggested in the consultation.\nMs Wylie said: \"More training for all frontline staff is really important, but that needs investment and money is tight.\n\"In general, we really welcome the government's strategy, but there needs to be a lot more actual commitment to action.\n\"There's also an issue about local implementation because things like putting up signs and barriers depend on the individual local authority actually caring about suicide prevention.\n\"We would like to see a mandatory responsibility placed on local authorities to take this seriously.\"\nThe consultation closes on 11 October.", "correct_choice": "Staff could be trained more to prevent suicides, under proposals to save lives.", "list_choices": ["The number of suicide attempts in England and Wales should be reduced, according to the government.", "Staff could be trained more to prevent suicides, under proposals to save lives."], "lbl": 1} 8 | {"id": "41091477", "input": "The 
White Garden, at Kensington Palace, was planted to mark 20 years since Princess Diana died in a car crash.\nThe Duchess of Cambridge joined the princes on the garden tour.\nA spokeswoman for Kensington Palace said: \"The engagement will allow the princes to pay tribute to the life and work of their mother.\"\nThey met representatives from the causes and charities supported by Diana, including the Royal Marsden and Great Ormond Street hospitals, the National Aids Trust, Centrepoint youth homelessness charity and the Leprosy Mission.\nMembers of the public have been leaving tributes and flowers at the gates of the palace to mark the anniversary of Diana's death.\nThe Princess of Wales died on 31 August 1997 in Paris, when William, now the Duke of Cambridge, was 15 and his brother was 12.\nThe garden at their mother's former home has been inspired by memories of her life, style and image, such as her white \"Elvis\" Catherine Walker dress.\nThe White Garden, as it is known, follows a tradition first established at Sissinghurst Castle in Kent, famous for its own white garden created in the 1930s.\nTheir Royal Highnesses met gardener Sean Harkin who designed the display and Graham Dillamore who knew the princess when he worked there some 30 years ago.\nThe garden has been open since spring and will continue into September with white roses, lilies, gladioli and cosmos.\nIt is the fourth London memorial created in tribute to Diana - the others are the Diana Memorial Playground at Kensington Palace, the Diana Memorial Fountain in Hyde Park, and the Diana Memorial Walk at St James's Palace.", "correct_choice": "Prince William and his brother have visited a London memorial garden for their mother on the eve of the 20th anniversary of her death.", "list_choices": ["The Duke and Duchess of Cambridge have visited a garden in commemorating the death of Princess Diana.", "Prince William and his brother have visited a London memorial garden for their mother on the eve of the 20th anniversary of her death."], "lbl": 1} 9 | {"id": "35516044", "input": "A member of the public raised the alarm after seeing the woman, aged in her 50s, fall at Peveril Point, near Swanage, on Saturday afternoon.\nShe was airlifted by the coastguard helicopter to King George's Field park where she was treated by paramedics.\nThe injured woman, who is from the Swanage area, was taken to Southampton General Hospital by air ambulance.\nCh Insp Bob Acaster, of Dorset Police, said: \"Emergency services worked hard in very difficult weather to rescue the woman from the cliff and bring her to safety.\"\nPolice said the woman's family had been informed.", "correct_choice": "A woman has suffered injuries falling from the cliff near Swanage.", "list_choices": ["A woman has been rescued from a cliff after being rescued by helicopter pilots.", "A woman has suffered injuries falling from the cliff near Swanage."], "lbl": 1} 10 | {"id": "36712508", "input": "Chelsey Lee, 26, played for Bucheon KEB Hana Bank in the Women's Korean Basketball League (WKBL), whose teams are allowed only two foreign players.\nProsecutors were asked to investigate after the Korean Olympic Committee pushed for Lee's naturalisation.\nThe WKBL says Lee will be suspended for life and her records annulled.\nThe Miami-born centre won the league's rookie of the year award in the 2015-16 season after helping her team reach the championship series.\nHowever, Lee and her two agents are suspected of fabricating her and her father's birth certificates to show she had a South 
Korean grandmother.\nBucheon KEB Hana Bank issued a public apology, vowing to take legal action against Lee and her agents.\nThe club's owner and head coach will step down.\nWKBL commissioner Shin Sun-woo said the team's records and ranking will be nullified and the league will scrap the extra quota for international players with a Korean parent or grandparent.", "correct_choice": "An Miami-born basketball player has been banned from South Korea's domestic league for life after prosecutors said she forged her birth documents.", "list_choices": ["An Miami-born basketball player has been banned from South Korea's domestic league for life after prosecutors said she forged her birth documents.", "Chelsey Lee has been suspended for life for a second time for a sex abuse scandal in the Women's Basketball League."], "lbl": 0} 11 | {"id": "39771057", "input": "The RSPB said 2,270 black-tailed godwits spent time on the island this spring, almost double the previous record of 1,320 in 2013.\nThe majority of the birds this year were found in a tiny field in Kilmoluaig.\nGodwits often stop off in the Hebrides to refuel during their migration to Iceland, where they breed.\nSpotters identified some of the birds as having come from France, Portugal and Spain due to the rings fitted on their legs.\nJohn Bowler, Tiree officer for RSPB Scotland, said: \"Black-tailed godwits are known to stop off here for food on their way to Iceland, particularly when adverse northerly winds hamper their progress across the North Atlantic.\n\"So, with huge numbers of golden plover already noted on Tiree during pretty windy conditions, it wasn't a huge surprise when black-tailed godwits started turning up, too. However, to see flocks of this size is just incredible.\n\"Hopefully they will enjoy a good breeding season this year and I'm already looking forward to seeing them pass back through Tiree in the autumn.\"", "correct_choice": "A record-breaking number of migrating birds have been recorded in the Hebrides in 2014.", "list_choices": ["A record-breaking number of migrating birds have been recorded in the Hebrides in 2014.", "The RSPB has spotted a huge number of black-tailed godwits on the North Atlantic."], "lbl": 0} 12 | {"id": "34017987", "input": "Betsi Cadwaladr health board has suggested downgrading services at one of the area's three district hospitals due to a staffing shortage.\nA legal challenge blocked the plan to downgrade maternity care at Glan Clwyd Hospital in Bodelwyddan, Denbighshire.\nThat prompted the consultation, which includes a series of public meetings.\nResidents are unhappy with the plans, suggesting removing the service at hospitals like Wrexham Maelor and Ysbyty Gwynedd in Bangor will mean women having to travel further for care.\nHowever, bosses said any changes would be temporary and are needed to ensure the safety of mothers and babies.\nA dedicated health board website was launched on Monday to collate public reaction to the options, which also includes retaining all services.\nSeveral public meetings are due to take place in September.", "correct_choice": "A consultation about plans which could see maternity care downgraded from a district hospital in Denbighshire has begun.", "list_choices": ["A consultation about plans which could see maternity care downgraded from a district hospital in Denbighshire has begun.", "A health board has urged the council to scrap maternity care in Denbighshire."], "lbl": 0} 13 | {"id": "38247303", "input": "Kuba Moczyk, 22, died in hospital after he was 
knocked out in an unlicensed fight at the Tower Complex, Great Yarmouth, Norfolk, on 19 November.\nA memorial mass has been held at St Mary's Church in the town.\nFather Philip Shryane told the congregation Mr Moczyk' was a \"good man\" whose \"life was boxing\".\nMore on this story and others from Norfolk\nHe said Mr Moczyk was \"a young man with a good heart, with so much to give and so much to look forward to... but always a gentle smile\".\nHis uncle, Marcin Smigaj gave a tribute, in Polish, on behalf of the family. Mr Moczyk was due to be cremated.\nMr Moczyk, originally from Poland, worked at a chicken factory and lived in the town.\nHis trainer Scott Osinski said earlier that Mr Moczyk was winning the fight when he took the fatal blow.\nHis opponent is believed to be aged 17.", "correct_choice": "Friends and family of a boxer with a \"gentle smile\", who died after being knocked out in a fight, have attended a memorial mass.", "list_choices": ["A man has been killed in a boxing match in Norfolk.", "Friends and family of a boxer with a \"gentle smile\", who died after being knocked out in a fight, have attended a memorial mass."], "lbl": 1} 14 | {"id": "27991390", "input": "Local MP Ian Lucas said people were concerned about the impact it could have if the prison on Wrexham Industrial Estate assumes a local name.\nIn a letter, prisons minister Jeremy Wright says local names are \"generally avoided as most local people object\".\nHe said it was likely people would be invited to propose names for the \u00c2\u00a3212m prison which is due to open in 2017.\nWork is expected to start in August, creating up to 1,000 jobs, to build the prison which will house 2,100 inmates, making it the largest prison in the UK.\nThe overall project spend is lower than the original \u00c2\u00a3250m estimate and the construction will involve local business and enterprises, with 100 apprenticeships created.", "correct_choice": "Wrexham Industrial Estate's new prison is unlikely to be named after local name, says the prison minister.", "list_choices": ["Wrexham Industrial Estate's new prison is unlikely to be named after local name, says the prison minister.", "A prison in the UK that will house a large number of inmates in a prison in the UK is to open in August 2017, according to local MPs."], "lbl": 0} 15 | {"id": "39962189", "input": "Natural Resources Wales (NRW) said the impact on sites of special scientific interest (SSSIs) \"could not be fully mitigated\".\nThe \u00c2\u00a31.1bn M4 proposal would cross four SSSIs along the Gwent Levels.\nWelsh Government lawyers argued environmental concerns had to be balanced against other interests.\nThe inquiry in Newport heard the scheme would mean about 105 hectares of designated land, set aside for the protection of water invertebrates, would have to be lost.\nThe Gwent Levels' unique network of ditches, known as reens, were dug during Roman times and have since become a habitat for a range of rare species.\nThe Welsh Government has pledged to replace lost reens with new ones.\nDr Jessica Poole, of conservation body Natural Resources Wales (NRW), told the inquiry discussions between the regulator and the Welsh Government meant she was content with the proposed design of the new reens.\nBut she said there was no guarantee they would work, and it could be some time before they supported the aquatic insects the sites are meant to conserve.\nReplicating a complex ecology that has developed over centuries would be \"challenging\", she said.\nNRW said the Welsh 
Government had not demonstrated the project would comply with its statutory duty to promote sustainable development.\nShould the alternative blue route, suggested by transport expert Prof Stuart Cole, be adopted - the motorway's impact on SSSI land would be \"significantly reduced\", Dr Poole said.\nBut the inquiry heard several issues NRW had raised in letters responding to the project's draft plans had been addressed and it was now satisfied on matters including water quality, drainage and some protected species such as otters and bats.\nMorag Ellis QC, acting on behalf of the Welsh Government, said it was for Welsh ministers to balance any potential impact on SSSI land with other public interests related to the new motorway.\nClaiming adverse effects were \"fully mitigated for\" was to apply a standard not in accordance with the law, she said.\nShe described the changes NRW had made to its initial objections after extensive discussions with Welsh Government as \"a major step forward\".", "correct_choice": "The scale of loss of conservation land caused by the proposed M4 relief road would be unacceptable, a public inquiry has heard.", "list_choices": ["The scale of loss of conservation land caused by the proposed M4 relief road would be unacceptable, a public inquiry has heard.", "The Gwent Levels motorway has been scrapped after a regulator ruled that it could be a threat to the Gwent Levels."], "lbl": 0} 16 | {"id": "34843387", "input": "Katari Anuradha was shot and stabbed by at least three men wearing burkas, Indian media reported, quoting police. A motive has yet to be established.\nHer husband, who was with her, is in a critical condition with bullet and stab injuries.\nThe attack took place at the Chittoor Municipal Corporation office, where the staff tried to stop the attackers.\nSenior police official G Srinivas told the Indian Express newspaper that they were exploring several angles, including old rivalry and new enemies.\nThe assailants fled the scene after the attack, although reports say two people later handed themselves into police.\nThe attackers had been wearing burkas, one-piece veils that cover the face and body, as they forced their way into Ms Anuradha's office, media reports said.\nSecurity has been tightened in Chittoor and state police are closing borders with neighbouring Tamil Nadu state in an attempt to find the killers.", "correct_choice": "Katari Anuradha of Chittoor has been killed by unknown attackers.", "list_choices": ["Indian police have arrested a man in connection with the murder of a woman in Chittoor.", "Katari Anuradha of Chittoor has been killed by unknown attackers."], "lbl": 1} 17 | {"id": "21712349", "input": "It works by looking for a combination of \"markers\" in the blood which are different in healthy people and those with the disease.\nDelegates at the Alzheimer's Research UK Conference heard that the University of Nottingham is now developing a quick and easy test to do in clinics.\nIt could mean much earlier diagnosis and better treatments, they said.\nThe test uses some proteins that have been strongly linked with Alzheimer's disease, such as amyloid and APOE.\nBut through careful analysis of blood from people with the disease, as well as those with early-stage memory problems, the researchers detected some other markers that were suggestive of the disease.\nMost notably, some proteins related to inflammation seem to have been added to increase the power of the test.\nProf Kevin Morgan from the University of Nottingham said they still 
had to validate the test and it could be a decade before it was used in patients.\nBut he added that the combination of markers they had found was looking very promising.\n\"Our findings are exciting because they show that it is technically possible to distinguish between healthy people and those with Alzheimer's using a blood test.\n\"As blood tests are a fast and easy way of aiding diagnosis, we are really encouraged by these findings and the potential they hold for the future.\"\nHe said there were several ways the test could benefit patients, including giving people a definitive diagnosis, which was not always possible at the moment.\nIt could also direct future therapies to make sure patients were getting the most appropriate treatment, he explained.\nPotentially, it could be a \"cheap and easy pre-screen\" test which enabled Alzheimer's to be picked up before symptoms appeared, he said.\n\"The way we see it working is you can test people and it will tell them if they have the all-clear, or if they are medium- or high-risk.\n\"If they are medium-risk, they can be monitored closely and high-risk patients can be referred to a specialist for more in-depth testing.\"\nDr Eric Karran, director of Research at Alzheimer's Research UK, said: \"Giving people with dementia an accurate diagnosis is not always easy, and so building up our armoury of diagnostic techniques is vital.\n\"While there is still some way to go before a test like this could become available, the results are promising.\n\"When used alongside other diagnostic techniques, a blood test like this could be a real help.\"", "correct_choice": "UK researchers have developed a test to detect Alzheimer's disease in its earliest stages.", "list_choices": ["UK researchers have developed a test to detect Alzheimer's disease in its earliest stages.", "A blood test that helps Alzheimer's disease patients detect the most important markers in their blood is being developed, according to research."], "lbl": 0} 18 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.12.1 2 | rouge-metric 3 | ipdb 4 | sentencepiece 5 | deepspeed 6 | accelerate 7 | rouge-score 8 | datasets 9 | fairseq 10 | rouge 11 | tqdm 12 | protobuf==3.19.0 13 | git+https://github.com/huggingface/transformers.git@6690ba3f4d036bc39bdf29ec98daf2c693442503 14 | bitsandbytes -------------------------------------------------------------------------------- /software_LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | Copyright [2022] [Derek Tam] 179 | 180 | Licensed under the Apache License, Version 2.0 (the "License"); 181 | you may not use this file except in compliance with the License. 182 | You may obtain a copy of the License at 183 | 184 | http://www.apache.org/licenses/LICENSE-2.0 185 | 186 | Unless required by applicable law or agreed to in writing, software 187 | distributed under the License is distributed on an "AS IS" BASIS, 188 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 189 | See the License for the specific language governing permissions and 190 | limitations under the License. 
-------------------------------------------------------------------------------- /src/compute_fib_results.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import statistics 4 | import os 5 | import numpy as np 6 | 7 | from src.get_results import get_medianScore_perDataset 8 | 9 | 10 | 11 | def get_fibResult(model, datasets, dataset_len): 12 | 13 | 14 | acc_acrossDataset = [] 15 | for dataset in datasets: 16 | acc_perDataset = get_medianScore_perDataset(dataset, [model]) 17 | acc_acrossDataset.append(acc_perDataset[0]) 18 | 19 | 20 | acc_acrossDataset = np.asarray(acc_acrossDataset) * 100 21 | final_score = float(np.dot(acc_acrossDataset, dataset_len) / np.sum(dataset_len)) 22 | print(f"The final score is {round(final_score, 3)}") 23 | 24 | def getXSum_fibResults(model): 25 | 26 | datasets = ["exp_out/multiple_choice/xsum/fib/binary_choice-using_bart-base_distractors", 27 | "exp_out/multiple_choice/xsum/fib/binary_choice-using_bart-large_distractors", 28 | "exp_out/multiple_choice/xsum/fib/binary_choice-using_bloom-560m_distractors", 29 | "exp_out/multiple_choice/xsum/fib/binary_choice-using_distil-bart_distractors", 30 | "exp_out/multiple_choice/xsum/fib/binary_choice-using_distil-pegasus_distractors", 31 | "exp_out/multiple_choice/xsum/fib/binary_choice-using_pegasus_distractors", 32 | "exp_out/multiple_choice/xsum/fib/binary_choice-using_t5-large_distractors"] 33 | 34 | dataset_len = np.asarray([463, 414, 479, 410, 437, 438, 483]) 35 | 36 | get_fibResult(model, datasets, dataset_len) 37 | 38 | 39 | def getCNNDM_fibResults(model): 40 | 41 | datasets = ["exp_out/multiple_choice/cnn_dm/fib/binary_choice-using_banditsumm_distractors", 42 | "exp_out/multiple_choice/cnn_dm/fib/binary_choice-using_bert_lstm_pn_rl_distractors", 43 | "exp_out/multiple_choice/cnn_dm/fib/binary_choice-using_heter_graph_distractors", 44 | "exp_out/multiple_choice/cnn_dm/fib/binary_choice-using_lead3_distractors", 45 | "exp_out/multiple_choice/cnn_dm/fib/binary_choice-using_matchsumm_distractors", 46 | "exp_out/multiple_choice/cnn_dm/fib/binary_choice-using_mi_unsup_distractors", 47 | "exp_out/multiple_choice/cnn_dm/fib/binary_choice-using_neusumm_distractors", 48 | "exp_out/multiple_choice/cnn_dm/fib/binary_choice-using_oracle_disco_distractors", 49 | "exp_out/multiple_choice/cnn_dm/fib/binary_choice-using_oracle_distractors", 50 | "exp_out/multiple_choice/cnn_dm/fib/binary_choice-using_pacsum_bert_distractors", 51 | "exp_out/multiple_choice/cnn_dm/fib/binary_choice-using_pacsum_tfidf_distractors", 52 | "exp_out/multiple_choice/cnn_dm/fib/binary_choice-using_refresh_distractors", 53 | "exp_out/multiple_choice/cnn_dm/fib/binary_choice-using_rnn_ext_rl_distractors", 54 | "exp_out/multiple_choice/cnn_dm/fib/binary_choice-using_textrank_distractors", 55 | "exp_out/multiple_choice/cnn_dm/fib/binary_choice-using_textrank_st_distractors"] 56 | 57 | dataset_len = np.asarray([26, 23, 22, 5, 21, 34, 24, 72, 54, 12, 27, 31, 24, 36, 46]) 58 | 59 | get_fibResult(model, datasets, dataset_len) 60 | 61 | if __name__ == "__main__": 62 | parser = argparse.ArgumentParser() 63 | parser.add_argument('-m', "--model", type=str, required=True) 64 | parser.add_argument('-d', "--dataset", choices=["xsum", "cnn_dm"]) 65 | args = parser.parse_args() 66 | 67 | if args.dataset == "xsum": 68 | getXSum_fibResults(args.model) 69 | else: 70 | getCNNDM_fibResults(args.model) -------------------------------------------------------------------------------- /src/constructors.py: 
-------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | import logging 4 | 5 | from transformers import AutoTokenizer, AutoConfig, AutoModelForSeq2SeqLM, AutoModelForCausalLM 6 | from src.utils.util import get_value_from_key_matching_regex 7 | from src.models.EncoderDecoderWrappers_forMulChoice import EncoderDecoderWrappers_forMulChoice 8 | from src.models.DecoderWrappers_forMulChoice import DecoderWrappers_forMulChoice 9 | from src.models.model_flags import DICT_REGEX_OF_MODEL_TYPE, DICT_REGEX_OF_DEVICE_MAP, DICT_REGEX_OF_TOKENIZERS 10 | 11 | 12 | def log_parameter_count(model): 13 | total_numParam = 0 14 | for name, parameter in model.named_parameters(): 15 | total_numParam += parameter.numel() 16 | logging.info(f"Total number of parameters in model: {total_numParam}") 17 | 18 | 19 | def construct_hugFace_objects(model_name, max_seq_len): 20 | ''' 21 | Constructs the HuggingFace config, tokenizer, and input prefix for a model. 22 | 23 | Args: 24 | model_name: HuggingFace model name or checkpoint path 25 | max_seq_len: maximum sequence length, stored on the tokenizer 26 | 27 | Returns: 28 | hugFaceConfig_forModel: the HuggingFace config of the model 29 | tokenizer: the tokenizer, with max_seq_len attached 30 | input_prefix: summarization prefix to prepend to inputs, or None. 31 | Depends on how the model was trained. 32 | ''' 33 | tokenizer = get_value_from_key_matching_regex(DICT_REGEX_OF_TOKENIZERS, model_name)(model_name) 34 | tokenizer.max_seq_len = max_seq_len 35 | 36 | hugFaceConfig_forModel = AutoConfig.from_pretrained(model_name) 37 | 38 | # Use the summarization input prefix from the model config if one exists; ignore it for FLAN models 39 | if hasattr(hugFaceConfig_forModel, "task_specific_params") and \ 40 | hugFaceConfig_forModel.task_specific_params is not None and \ 41 | "summarization" in hugFaceConfig_forModel.task_specific_params and \ 42 | "prefix" in hugFaceConfig_forModel.task_specific_params["summarization"]: 43 | if "flan" not in model_name: 44 | input_prefix = hugFaceConfig_forModel.task_specific_params["summarization"]["prefix"] 45 | logging.info('Input Prefix: '+input_prefix) 46 | else: 47 | input_prefix = None 48 | logging.info('Evaluating FLAN but ignoring prompt') 49 | else: 50 | input_prefix = None 51 | 52 | return hugFaceConfig_forModel, tokenizer, input_prefix 53 | 54 | def construct_models(model_name, use_hugFace_parallelism, use_bitsandbytes): 55 | 56 | model_type = get_value_from_key_matching_regex(DICT_REGEX_OF_MODEL_TYPE, model_name) 57 | device_map = get_value_from_key_matching_regex(DICT_REGEX_OF_DEVICE_MAP, model_name) 58 | logging.info('Model Type: ' + model_type) 59 | logging.info('Loading Model : ' + model_name) 60 | 61 | if model_type == "encoder_decoder": 62 | if use_hugFace_parallelism: 63 | logging.info('Using HuggingFace Parallelism') 64 | assert use_bitsandbytes == False 65 | transformer = AutoModelForSeq2SeqLM.from_pretrained(model_name, device_map=device_map) 66 | logging.info(transformer.hf_device_map) 67 | elif use_bitsandbytes: 68 | logging.info('Using BitsAndBytes') 69 | assert use_hugFace_parallelism == False 70 | transformer = AutoModelForSeq2SeqLM.from_pretrained(model_name, device_map=device_map, load_in_8bit=True) 71 | logging.info(transformer.hf_device_map) 72 | else: 73 | transformer = AutoModelForSeq2SeqLM.from_pretrained(model_name) 74 | model = EncoderDecoderWrappers_forMulChoice(transformer) 75 | else: 76 | assert model_type == "decoder" 77 | if use_hugFace_parallelism: 78 | logging.info('Using HuggingFace Parallelism') 79 | assert use_bitsandbytes == False 80 | transformer = AutoModelForCausalLM.from_pretrained(model_name, device_map=device_map) 81 | logging.info(transformer.hf_device_map) 82 | elif use_bitsandbytes: 83 | logging.info('Using BitsAndBytes') 84 | assert
use_hugFace_parallelism == False 85 | transformer = AutoModelForCausalLM.from_pretrained(model_name, device_map=device_map, load_in_8bit=True) 86 | logging.info(transformer.hf_device_map) 87 | else: 88 | transformer = AutoModelForCausalLM.from_pretrained(model_name) 89 | model = DecoderWrappers_forMulChoice(transformer) 90 | 91 | log_parameter_count(transformer) 92 | 93 | return model, transformer -------------------------------------------------------------------------------- /src/data/Batcher.py: -------------------------------------------------------------------------------- 1 | from torch.utils import data 2 | from src.data.Dataset import Dataset 3 | 4 | 5 | class Batcher(object): 6 | ''' 7 | Batcher is responsible for returning batches of data 8 | ''' 9 | def __init__(self, datasetReader, createDataset_fn, train_batchSize, eval_batchSize): 10 | 11 | self.datasetReader = datasetReader 12 | self.createPytorchDataset_fn = createDataset_fn 13 | 14 | self.train_batchSize = train_batchSize 15 | self.eval_batchSize = eval_batchSize 16 | 17 | self.trainLoader = None 18 | self.devLoader = None 19 | self.testLoader = None 20 | self.mulChoiceLoader = None 21 | 22 | def _init_trainLoader(self): 23 | trainData = self.datasetReader.read_origData("train") 24 | train_pytorchDatasetClass = self.createPytorchDataset_fn(trainData) 25 | self.trainLoader = data.DataLoader(train_pytorchDatasetClass, 26 | batch_size=self.train_batchSize, 27 | shuffle=True, 28 | collate_fn=train_pytorchDatasetClass.collate_fn) 29 | 30 | def _init_devLoader(self): 31 | devData = self.datasetReader.read_origData("dev") 32 | dev_pytorchDatasetClass = self.createPytorchDataset_fn(devData) 33 | self.devLoader = data.DataLoader(dev_pytorchDatasetClass, 34 | batch_size=self.eval_batchSize, 35 | shuffle=False, 36 | collate_fn=dev_pytorchDatasetClass.collate_fn) 37 | 38 | def _init_testLoader(self): 39 | testData = self.datasetReader.read_origData("test") 40 | test_pytorchDatasetClass = self.createPytorchDataset_fn(testData) 41 | self.testLoader = data.DataLoader(test_pytorchDatasetClass, 42 | batch_size=self.eval_batchSize, 43 | shuffle=False, 44 | collate_fn=test_pytorchDatasetClass.collate_fn) 45 | 46 | def _init_mulChoiceLoader(self, mulChoiceFilepath): 47 | mulChoiceData = self.datasetReader.read_mulChoiceData(mulChoiceFilepath) 48 | mulChoice_pytorchDatasetClass = self.createPytorchDataset_fn(mulChoiceData) 49 | self.mulChoiceLoader = data.DataLoader(mulChoice_pytorchDatasetClass, 50 | batch_size=self.eval_batchSize, 51 | shuffle=False, 52 | collate_fn=mulChoice_pytorchDatasetClass.collate_fn) 53 | 54 | def get_trainBatches(self): 55 | if self.trainLoader is None: 56 | self._init_trainLoader() 57 | 58 | while True: 59 | for x in self.trainLoader: 60 | yield x 61 | 62 | def get_devBatches(self): 63 | if self.devLoader is None: 64 | self._init_devLoader() 65 | 66 | for x in self.devLoader: 67 | yield x 68 | 69 | def get_testBatches(self): 70 | if self.testLoader is None: 71 | self._init_testLoader() 72 | 73 | for x in self.testLoader: 74 | yield x 75 | 76 | def get_mulChoiceBatches(self, mulChoiceFilepath): 77 | if self.mulChoiceLoader is None: 78 | self._init_mulChoiceLoader(mulChoiceFilepath) 79 | 80 | for x in self.mulChoiceLoader: 81 | yield x 82 | -------------------------------------------------------------------------------- /src/data/Dataset.py: -------------------------------------------------------------------------------- 1 | 2 | from torch.utils import data 3 | 4 | class Dataset(data.Dataset): 5 | def __init__(self, 
data): 6 | self.data = data 7 | 8 | def __len__(self): 9 | return len(self.data) 10 | 11 | def __getitem__(self, get_idx): 12 | return self.data[get_idx] -------------------------------------------------------------------------------- /src/data/multiple_choice.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import json 3 | from torch.utils import data 4 | import copy 5 | import math 6 | import re 7 | import logging 8 | 9 | from src.utils.CONSTANTS import NULL_STRING 10 | from src.data.templates import SUMMARIZATION_PROMPT_TEMPLATES 11 | from src.data.preprocess_data import tokenize_prompted_input_text, does_tokenizer_addBosEosTokens 12 | 13 | class MultipleChoiceReader(object): 14 | ''' 15 | MultipleChoiceReader reads any multiple choice dataset 16 | ''' 17 | 18 | def read_mulChoiceData(self, mulChoiceFilepath): 19 | ''' 20 | Read a multiple-choice dataset from a JSONL file. 21 | 22 | Args: 23 | mulChoiceFilepath: path to a JSONL file with one datapoint per line 24 | 25 | Returns: 26 | listOfDatapoints: list of datapoint dictionaries 27 | ''' 28 | fd = open(mulChoiceFilepath, 'r') 29 | 30 | listOfDatapoints = [] 31 | for line in fd.readlines(): 32 | datapoint = json.loads(line) 33 | listOfDatapoints.append(datapoint) 34 | 35 | return listOfDatapoints 36 | 37 | NULL_DATAPOINT = { 38 | "id": NULL_STRING, 39 | "input": NULL_STRING, 40 | "list_choices": [NULL_STRING, NULL_STRING], 41 | "correct_choice": NULL_STRING, 42 | "lbl": 0 43 | } 44 | 45 | 46 | class MultipleChoiceDataset(data.Dataset): 47 | def __init__(self, 48 | data, 49 | tokenizer, 50 | promptTemplate_idx, 51 | input_prefix, 52 | device, 53 | world_size): 54 | 55 | # if the device is an integer, then this means we are using parallelism and have to split the dataset among each device. 56 | if isinstance(device, int): 57 | 58 | num_datapoints_per_split = math.ceil(len(data) / world_size) 59 | 60 | device_data = [] 61 | for idx, datapoint in enumerate(data): 62 | if idx % world_size == device: 63 | device_data.append(datapoint) 64 | 65 | # We ensure each device sees the same number of samples, so that the number of batches is the same per device. 66 | # If the batch size is 1 and world_size=2, the number of batches could otherwise differ per device. 67 | # That would cause a race condition for parallelism. 68 | if len(device_data) < num_datapoints_per_split: 69 | device_data.append(NULL_DATAPOINT) 70 | assert len(device_data) == num_datapoints_per_split 71 | self.data = device_data 72 | # For non-parallelism 73 | else: 74 | self.data = data 75 | 76 | self.tokenizer = tokenizer 77 | 78 | # Uses no template and adds the input prefix only. 79 | # Note that the 0 template is just the data. 80 | if input_prefix is not None: 81 | assert promptTemplate_idx == 0 82 | self.prompt_template = input_prefix + SUMMARIZATION_PROMPT_TEMPLATES[0] 83 | 84 | # Create template from prompt template idx 85 | else: 86 | self.prompt_template = SUMMARIZATION_PROMPT_TEMPLATES[promptTemplate_idx] 87 | 88 | if promptTemplate_idx == 0: 89 | # If the tokenizer does not insert a BOS or EOS token for an empty string, we need to add an empty space 90 | # so that we can have a null input when computing PMI. This holds for BLOOM. 91 | # This only has to be done for the zero prompt since there is no additional text in the prompt. 92 | # Though bloom was not pretrained to insert this empty space, it should not affect performance much.
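# Editorial illustration (not in the original file) of the check below, under the stated assumption
# about tokenizer behavior: a BLOOM-style tokenizer maps "" to an empty id list because it adds no
# special tokens, so " " is prepended and the PMI null input gets at least one token; a T5-style
# tokenizer returns [eos_token_id] for "", so the template is left unchanged.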
93 | if len(tokenizer("")["input_ids"]) == 0: 94 | self.prompt_template = " " + self.prompt_template 95 | 96 | logging.info('Prompt Template: '+self.prompt_template) 97 | self.device = device 98 | self.add_bosToken, self.add_eosToken = does_tokenizer_addBosEosTokens(self.tokenizer) 99 | 100 | 101 | def __len__(self): 102 | return len(self.data) 103 | 104 | def __getitem__(self, get_idx): 105 | datapoint = self.data[get_idx] 106 | 107 | input_ids, input_masks, input_txt, nullInput_txt = tokenize_prompted_input_text(self.tokenizer, 108 | datapoint, 109 | self.prompt_template, 110 | self.add_bosToken, 111 | self.add_eosToken) 112 | nullInput_dict = self.tokenizer(nullInput_txt, 113 | return_tensors="pt", 114 | truncation=True) 115 | nullInput_ids = nullInput_dict["input_ids"][0] 116 | nullInput_masks = nullInput_dict["attention_mask"][0] 117 | 118 | allChoices_ids = [] 119 | allChoices_masks = [] 120 | 121 | for choice in datapoint["list_choices"]: 122 | choiceDict = self.tokenizer(choice, return_tensors="pt", truncation=True) 123 | # Skip BOS token for choices since it is a continuation of the input 124 | # TODO Currently this assumes that a BOS token is not added for encoder-decoder models. 125 | # TODO add logic to NOT ignore the BOS token in the choices for encoder-decoder model 126 | # Note that all T5 variants do not add a BOS token. 127 | if self.add_bosToken: 128 | start_idx = 1 129 | else: 130 | start_idx = 0 131 | allChoices_ids.append(choiceDict["input_ids"][0][start_idx:]) 132 | allChoices_masks.append(choiceDict["attention_mask"][0][start_idx:]) 133 | 134 | return {"id": datapoint["id"], 135 | "input": input_txt, 136 | "input_ids": input_ids, 137 | "input_masks": input_masks, 138 | "null_input_ids": nullInput_ids, 139 | "null_input_masks": nullInput_masks, 140 | "list_choices": datapoint["list_choices"], 141 | "all_choices_ids": allChoices_ids, 142 | "all_choices_lbls": copy.deepcopy(allChoices_ids), 143 | "all_choices_masks": allChoices_masks, 144 | "correct_choice": datapoint["correct_choice"], 145 | "lbl": datapoint["lbl"]} 146 | 147 | def collate_fn(self, batch_ofDatapoints): 148 | ''' 149 | Convert a batch of datapoints into a datapoint that is batched. This is meant to 150 | override the default collate function in pytorch. 151 | 152 | Args: 153 | batch_ofDatapoints: 154 | 155 | Returns: 156 | 157 | ''' 158 | datapoint_batched = {} 159 | 160 | for datapoint in batch_ofDatapoints: 161 | for (k, v) in datapoint.items(): 162 | if k in datapoint_batched: 163 | # Each value in all_choices is already a list, so we extend and not append. 164 | if "all_choices" in k: 165 | datapoint_batched[k].extend(v) 166 | else: 167 | datapoint_batched[k].append(v) 168 | else: 169 | # Each value in all_choices is already a list, so we do not need to 170 | # initialize a list with v in it, and can just use v. 
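# Editorial illustration (not in the original file) of the flattening below: for a batch of two
# datapoints with two choices each, "all_choices_ids" is extended into a flat list of four tensors
# [d0_c0, d0_c1, d1_c0, d1_c1], while per-datapoint fields such as "lbl" become lists of length two.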
171 | if "all_choices" in k: 172 | datapoint_batched[k] = v 173 | else: 174 | datapoint_batched[k] = [v] 175 | 176 | for (k, batch_ofValues) in datapoint_batched.items(): 177 | # If id or mask is in key, this means we need to pad to the longest sequence length 178 | if ("ids" in k) or ("masks" in k) or (k == "all_choices_lbls"): 179 | if "ids" in k: 180 | padToken_id = self.tokenizer.pad_token_id 181 | if padToken_id is None: 182 | padToken_id = self.tokenizer.eos_token_id 183 | elif "masks" in k: 184 | padToken_id = 0 185 | elif k == "all_choices_lbls": 186 | padToken_id = -100 187 | else: 188 | raise ValueError(f"The key {k} has ids or masks but is not recognized") 189 | datapoint_batched[k] = torch.nn.utils.rnn.pad_sequence( 190 | batch_ofValues, 191 | batch_first=True, 192 | padding_value=padToken_id) 193 | 194 | if self.device is not None: 195 | datapoint_batched[k] = datapoint_batched[k].to(self.device) 196 | 197 | elif isinstance(batch_ofValues[0], int): 198 | datapoint_batched[k] = torch.tensor(batch_ofValues) 199 | 200 | if self.device is not None: 201 | datapoint_batched[k] = datapoint_batched[k].to(self.device) 202 | 203 | 204 | 205 | return datapoint_batched 206 | 207 | 208 | -------------------------------------------------------------------------------- /src/data/preprocess_data.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | 4 | 5 | def does_tokenizer_addBosEosTokens(tokenizer): 6 | # Compute whether to add BOS or EOS tokens by tokenizing a dummy input. 7 | filler_ids = tokenizer("hello")["input_ids"] 8 | add_bosToken = False 9 | if filler_ids[0] == tokenizer.bos_token_id: 10 | add_bosToken = True 11 | 12 | add_eosToken = False 13 | if filler_ids[-1] == tokenizer.eos_token_id: 14 | add_eosToken = True 15 | 16 | return add_bosToken, add_eosToken 17 | 18 | def tokenize_prompted_input_text(tokenizer, datapoint, prompt_template, add_bosToken, add_eosToken): 19 | ''' 20 | Builds the prompted input text and tokenizes it. 21 | 22 | Assumes the datapoint is a dictionary and the prompt template specifies 23 | which value of the datapoint to use based on the key wrapped in []. 24 | For example, [input] should be used to specify the input. 25 | Note that the prompt template cannot use [ or ] in any other locations. 26 | 27 | 28 | Args: 29 | tokenizer: tokenizer with a max_seq_len attribute 30 | datapoint: dictionary holding the values referenced by the template 31 | prompt_template: template string with data keys wrapped in [], e.g. "[input]" 32 | add_bosToken: whether to prepend the tokenizer's BOS token 33 | add_eosToken: whether to append the tokenizer's EOS token 34 | 35 | Returns: 36 | input ids, input masks, the (possibly truncated) input text, and the null input text 37 | ''' 38 | template_nonDataKeys = re.split(r"\[.*?\]", prompt_template) 39 | template_dataKeys = re.findall(r"\[.*?\]", prompt_template) 40 | 41 | assert len(template_nonDataKeys) == len(template_dataKeys) + 1 42 | 43 | remaining_seqLen = tokenizer.max_seq_len 44 | num_dataKeys = len(template_dataKeys) 45 | 46 | list_nonDataKeys_txt = [] 47 | list_nonDataKeys_ids = [] 48 | list_nonDataKeys_mask = [] 49 | 50 | for nonDataKey in template_nonDataKeys: 51 | if len(nonDataKey) > 0: 52 | list_nonDataKeys_txt.append(nonDataKey) 53 | nonDataKey_dict = tokenizer(nonDataKey, add_special_tokens=False) 54 | list_nonDataKeys_ids.append(nonDataKey_dict["input_ids"]) 55 | list_nonDataKeys_mask.append(nonDataKey_dict["attention_mask"]) 56 | remaining_seqLen -= len(nonDataKey_dict["input_ids"]) 57 | else: 58 | list_nonDataKeys_txt.append("") 59 | list_nonDataKeys_ids.append([]) 60 | list_nonDataKeys_mask.append([]) 61 | 62 | # This list will recombine the nonDataKeys and dataKeys in the correct order.
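# Editorial worked example (not in the original file) for the recombination below: with
# prompt_template = 'The summary of "[input]" is ', the regexes above give
# template_nonDataKeys == ['The summary of "', '" is '] and template_dataKeys == ['[input]'],
# so the loop interleaves the quoted prefix, the (possibly truncated) tokenized
# datapoint["input"], and the trailing '" is '.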
63 | list_split_txt = [] 64 | list_split_ids = [] 65 | list_split_masks = [] 66 | 67 | if add_bosToken: 68 | list_split_ids.append(tokenizer.bos_token_id) 69 | list_split_masks.append(1) 70 | remaining_seqLen -= 1 71 | 72 | # We have to compute remaining sequence length at the beginning 73 | # to know how much is left over. 74 | if add_eosToken: 75 | remaining_seqLen -= 1 76 | 77 | # Add any text in template that appears before the first data key. 78 | list_split_txt.append(list_nonDataKeys_txt[0]) 79 | list_split_ids.extend(list_nonDataKeys_ids[0]) 80 | list_split_masks.extend(list_nonDataKeys_mask[0]) 81 | 82 | 83 | for i in range(num_dataKeys): 84 | dataKey = template_dataKeys[i].replace("[", "").replace("]", "") 85 | dataValue = datapoint[dataKey] 86 | 87 | dataValue_dict = tokenizer(dataValue, add_special_tokens=False) 88 | 89 | value_ids = dataValue_dict["input_ids"] 90 | value_mask = dataValue_dict["attention_mask"] 91 | 92 | len_value = len(dataValue_dict["input_ids"]) 93 | if len_value > remaining_seqLen: 94 | value_txt = tokenizer.decode(value_ids[:remaining_seqLen]) 95 | value_ids = value_ids[:remaining_seqLen] 96 | value_mask = value_mask[:remaining_seqLen] 97 | remaining_seqLen = 0 98 | else: 99 | value_txt = tokenizer.decode(value_ids) 100 | remaining_seqLen -= len_value 101 | 102 | # Add tokenized values from data 103 | list_split_txt.append(value_txt) 104 | list_split_ids.extend(value_ids) 105 | list_split_masks.extend(value_mask) 106 | 107 | # Add tokenized text between data 108 | # Increment by 1 since we add non-data key text at the very beginning 109 | list_split_txt.append(list_nonDataKeys_txt[i+1]) 110 | list_split_ids.extend(list_nonDataKeys_ids[i+1]) 111 | list_split_masks.extend(list_nonDataKeys_mask[i+1]) 112 | 113 | if add_eosToken: 114 | list_split_ids.append(tokenizer.eos_token_id) 115 | list_split_masks.append(1) 116 | 117 | return torch.tensor(list_split_ids), torch.tensor(list_split_masks), "".join(list_split_txt), "".join(template_nonDataKeys) 118 | -------------------------------------------------------------------------------- /src/data/preprocess_data_test.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | from src.data.preprocess_data import tokenize_prompted_input_text, does_tokenizer_addBosEosTokens 4 | from transformers import AutoTokenizer 5 | 6 | 7 | from src.data.templates import SUMMARIZATION_PROMPT_TEMPLATES 8 | from src.utils.test_helpers import check_string_ends_with_another, check_string_starts_with_another, check_string_subset_of_another, check_string_equality 9 | 10 | SHORT_DATAPOINT = { 11 | "input": "Forbes said Vergara's role as Gloria in Modern Family and some lucrative product endorsements helped her earn $43m (\u00a332.6m) in the last 12 months.\n It marks the fifth year the Colombian-American actress has topped the chart.\n Forbes also said she earned more than any of her male counterparts in the past year.\n The Big Bang Theory's Kaley Cuoco was the second-highest paid actress, earning $24.5m (\u00a318.6m).\n Cuoco tied with Vergara at the top of last year's Forbes list, when both actresses earned $28.5m (\u00a321.6m).\n The Mindy Project's Mindy Kaling is the biggest climber in this year's chart.
Her earnings of $15m (\u00a311.4m) helped her to rise from eighth place in 2015 to third this year.\n Mariska Hargitay, who appears in Law & Order: Special Victims Unit, and Grey's Anatomy star Ellen Pompeo rounded off the top five.\n Source: Forbes\n This year's highest new entry on the Forbes list was Priyanka Chopra, who appears in ABC drama Quantico. She was the eighth highest earner with $11m (\u00a38.4m).\n Chopra, who is well known in India, is set to become more familiar to western audiences next year when she stars in Baywatch alongside Dwayne Johnson - the world's highest paid actor.\n Scandal star Kerry Washington, Stana Katic from Castle, The Good Wife's Julianna Margulies and Vergara's Modern Family co-star Julie Bowen also featured in this year's top 10.\n Follow us on Twitter @BBCNewsEnts, on Instagram, or if you have a story suggestion email entertainment.news@bbc.co.uk.", 12 | "choice": "Modern Family star Sofia Vergara has retained her title as the highest paid actress on US television, according to the latest Forbes magazine rich list." 13 | } 14 | 15 | LONG_DATAPOINT = { 16 | "input": "Archery, fencing, weightlifting and wheelchair rugby have also missed out.\n Cycling - which brought Team GB 12 medals in Rio - has had its funding cut by more than \u00a34m to \u00a325.98m.\n Badminton England chief executive Adrian Christy said he was \"staggered\" by the \"incomprehensible\" decision to remove the sport's funding.\n A total of \u00a3345m will be invested in 31 Olympic and Paralympic sports - \u00a32m less than the record \u00a3347m allocated for the Rio Games.\n As a result, UK Sport has set Team GB a target of winning 51-85 Olympic medals, and 115-162 Paralympic medals in Tokyo.\n Britain enjoyed unprecedented success at Rio 2016, with the Olympics yielding 67 medals and the Paralympics 147.\n Chair of UK Sport Rod Carr said the government, which provides funding alongside National Lottery money, has \"confirmed its commitment\" for Tokyo 2020.\n He added: \"These are critical funding decisions for sports to take them on their journey to Tokyo 2020 and beyond so the historic success at Rio can be maintained.\"\n Badminton, which was set a target of winning a medal in Rio, is the only sport that earned a podium place in the summer to have its funding removed.\n Marcus Ellis and Chris Langridge took bronze in the men's doubles after the sport was given \u00a35.74m in the last cycle.\n Christy said the decision represents a \"catastrophic impact on the sport\" and Badminton England would \"fight for the hopes and dreams\" of its players.\n \"How can you return from the best Games for more than a decade, in a year where our players have demonstrated world-class performances and where we can demonstrate the journey to Tokyo is on track, only be to have every penny of investment withdrawn?\" he said.\n \"What have we done wrong?\" added GB Badminton's performance director Jon Austin.\n Judo, which was given the same target as badminton and also claimed one bronze medal, has had its funding increased slightly.\n Liz Nicholl, CEO of UK Sport, said the decision to cut funding was not taken lightly.\n \"We would like to invest in every sport but the reality is we have to prioritise to protect and enhance the medal potential,\" she said.\n \"If we under-invest across the board then the British teams will ultimately underperform at the Games and medal success will be put at risk.\"\n Sports minister Tracey Crouch added: \"UK Sport's approach to elite sport has proven successful 
in Beijing, London and Rio and the ambition to win more medals in Tokyo is a bold one that, if achieved, would mean a sensational summer of sport in 2020.\"\n Basketball had its funding withdrawn in 2014 - and handball and volleyball lost theirs in 2012 - but say a UK Sport review last year to build \"performance pathways for future success\" was supposed to be aimed at such sports.\n A British Basketball statement, in conjunction with volleyball and handball, said: \"It appears that UK Sport has no interest in team sports and in particular refuses to take responsibility for the need to fund their performance development, which was identified in its own review.\n \"With UK Sport's investment budget approaching \u00a3350m, it borders on intransigence to pass responsibility to government and other funding bodies who are not set up to fund the development of high-performance sport.\"\n UK Sport says investment in the five Olympic sports and two Paralympic sports added for Tokyo 2020 is yet to be confirmed.\n Baseball/softball will return to the programme, with karate, skateboard, sports climbing and surfing also added, while Para-taekwondo and Para-badminton join the Paralympic programme.\n UK Sport says funding will be determined \"following further exploration of medal potential\", with \u00a39m of the \u00a3345m total still to be allocated.\n Liam Carroll, head coach of the GB baseball team, said: \"The key to unlocking our potential is investment and I'm pleased that UK Sport has left the door open.\n \"We look forward to the opportunity to impress upon them that getting behind Great Britain Baseball can extend their tremendous track record of investing in Olympic medal contenders.\"", 17 | "choice": "Badminton is one of five sports to lose all UK Sport funding for the 2020 Olympics in Tokyo - after Britain claimed a bronze in the sport in Rio." 
18 | } 19 | 20 | BASIC_PROMPT_TEMPLATE = "[input]" 21 | PROMPT_TEMPLATE_WITH_TXT_ON_BOTH_SIDES = "The summary of \"[input]\" is " 22 | PROMPT_TEMPLATE_WITH_TXT_ON_LEFT_SIDE = "The summary of \"[input]" 23 | PROMPT_TEMPLATE_WITH_TXT_ON_RIGHT_SIDE = "[input]\" is " 24 | 25 | 26 | def test_tokenize_input(tokenizer): 27 | add_bosToken, add_eosToken = does_tokenizer_addBosEosTokens(tokenizer) 28 | 29 | 30 | input_ids, input_masks, input_txt, nullInput_txt = tokenize_prompted_input_text(tokenizer, 31 | SHORT_DATAPOINT, 32 | BASIC_PROMPT_TEMPLATE, 33 | add_bosToken, 34 | add_eosToken) 35 | print("Length of Input Ids: ", len(input_ids)) 36 | if add_bosToken: 37 | assert input_ids[0] == tokenizer.bos_token_id 38 | if add_eosToken: 39 | assert input_ids[-1] == tokenizer.eos_token_id 40 | reconstructed_inputTxt = tokenizer.decode(input_ids, skip_special_tokens=True) 41 | check_string_equality(reconstructed_inputTxt, input_txt) 42 | check_string_equality(reconstructed_inputTxt, SHORT_DATAPOINT["input"]) 43 | 44 | input_ids, input_masks, input_txt, nullInput_txt = tokenize_prompted_input_text(tokenizer, 45 | SHORT_DATAPOINT, 46 | PROMPT_TEMPLATE_WITH_TXT_ON_BOTH_SIDES, 47 | add_bosToken, 48 | add_eosToken) 49 | print("Length of Input Ids: ", len(input_ids)) 50 | if add_bosToken: 51 | assert input_ids[0] == tokenizer.bos_token_id 52 | if add_eosToken: 53 | assert input_ids[-1] == tokenizer.eos_token_id 54 | reconstructed_inputTxt = tokenizer.decode(input_ids, skip_special_tokens=True) 55 | check_string_equality(reconstructed_inputTxt, input_txt) 56 | check_string_equality(reconstructed_inputTxt, f"The summary of \"{SHORT_DATAPOINT['input']}\" is ") 57 | 58 | input_ids, input_masks, input_txt, nullInput_txt = tokenize_prompted_input_text(tokenizer, 59 | SHORT_DATAPOINT, 60 | PROMPT_TEMPLATE_WITH_TXT_ON_LEFT_SIDE, 61 | add_bosToken, 62 | add_eosToken) 63 | print("Length of Input Ids: ", len(input_ids)) 64 | if add_bosToken: 65 | assert input_ids[0] == tokenizer.bos_token_id 66 | if add_eosToken: 67 | assert input_ids[-1] == tokenizer.eos_token_id 68 | reconstructed_inputTxt = tokenizer.decode(input_ids, skip_special_tokens=True) 69 | check_string_equality(reconstructed_inputTxt, input_txt) 70 | check_string_equality(reconstructed_inputTxt, f"The summary of \"{SHORT_DATAPOINT['input']}") 71 | 72 | input_ids, input_masks, input_txt, nullInput_txt = tokenize_prompted_input_text(tokenizer, 73 | SHORT_DATAPOINT, 74 | PROMPT_TEMPLATE_WITH_TXT_ON_RIGHT_SIDE, 75 | add_bosToken, 76 | add_eosToken) 77 | print("Length of Input Ids: ", len(input_ids)) 78 | if add_bosToken: 79 | assert input_ids[0] == tokenizer.bos_token_id 80 | if add_eosToken: 81 | assert input_ids[-1] == tokenizer.eos_token_id 82 | reconstructed_inputTxt = tokenizer.decode(input_ids, skip_special_tokens=True) 83 | check_string_equality(reconstructed_inputTxt, input_txt) 84 | check_string_equality(reconstructed_inputTxt, f"{SHORT_DATAPOINT['input']}\" is ") 85 | 86 | 87 | input_ids, input_masks, input_txt, nullInput_txt = tokenize_prompted_input_text(tokenizer, 88 | LONG_DATAPOINT, 89 | BASIC_PROMPT_TEMPLATE, 90 | add_bosToken, 91 | add_eosToken) 92 | print("Length of Input Ids: ", len(input_ids)) 93 | if add_bosToken: 94 | assert input_ids[0] == tokenizer.bos_token_id 95 | if add_eosToken: 96 | assert input_ids[-1] == tokenizer.eos_token_id 97 | reconstructed_inputTxt = tokenizer.decode(input_ids, skip_special_tokens=True) 98 | check_string_equality(reconstructed_inputTxt, input_txt) 99 | check_string_subset_of_another(reconstructed_inputTxt, LONG_DATAPOINT["input"]) 100 |
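# [Editor's note] The LONG_DATAPOINT checks below only assert that the decoded text is a *subset* of the
# original article: the article is longer than tokenizer.max_seq_len, so tokenize_prompted_input_text
# truncates it. A minimal sketch of the expected length bookkeeping, assuming remaining_seqLen starts at
# tokenizer.max_seq_len (512 in __main__); the names below are illustrative, not part of the test file:
#     ids = tokenizer(LONG_DATAPOINT["input"], add_special_tokens=False)["input_ids"]
#     budget = tokenizer.max_seq_len - int(add_bosToken) - int(add_eosToken)
#     expected_len = min(len(ids), budget) + int(add_bosToken) + int(add_eosToken)
#     assert len(input_ids) == expected_len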
101 | input_ids, input_masks, input_txt, nullInput_txt = tokenize_prompted_input_text(tokenizer, 102 | LONG_DATAPOINT, 103 | PROMPT_TEMPLATE_WITH_TXT_ON_BOTH_SIDES, 104 | add_bosToken, 105 | add_eosToken) 106 | print("Length of Input Ids: ", len(input_ids)) 107 | if add_bosToken: 108 | assert input_ids[0] == tokenizer.bos_token_id 109 | if add_eosToken: 110 | assert input_ids[-1] == tokenizer.eos_token_id 111 | reconstructed_inputTxt = tokenizer.decode(input_ids, skip_special_tokens=True) 112 | check_string_equality(reconstructed_inputTxt, input_txt) 113 | check_string_subset_of_another(reconstructed_inputTxt\ 114 | .replace("The summary of \"", "")\ 115 | .replace("\" is ", ""), 116 | LONG_DATAPOINT["input"]) 117 | check_string_starts_with_another(reconstructed_inputTxt, "The summary of \"") 118 | check_string_ends_with_another(reconstructed_inputTxt, "\" is ") 119 | 120 | input_ids, input_masks, input_txt, nullInput_txt = tokenize_prompted_input_text(tokenizer, 121 | LONG_DATAPOINT, 122 | PROMPT_TEMPLATE_WITH_TXT_ON_LEFT_SIDE, 123 | add_bosToken, 124 | add_eosToken) 125 | print("Length of Input Ids: ", len(input_ids)) 126 | if add_bosToken: 127 | assert input_ids[0] == tokenizer.bos_token_id 128 | if add_eosToken: 129 | assert input_ids[-1] == tokenizer.eos_token_id 130 | reconstructed_inputTxt = tokenizer.decode(input_ids, skip_special_tokens=True) 131 | check_string_equality(reconstructed_inputTxt, input_txt) 132 | check_string_subset_of_another(reconstructed_inputTxt\ 133 | .replace("The summary of \"", ""), 134 | LONG_DATAPOINT["input"]) 135 | check_string_starts_with_another(reconstructed_inputTxt, "The summary of \"") 136 | 137 | input_ids, input_masks, input_txt, nullInput_txt = tokenize_prompted_input_text(tokenizer, 138 | LONG_DATAPOINT, 139 | PROMPT_TEMPLATE_WITH_TXT_ON_RIGHT_SIDE, 140 | add_bosToken, 141 | add_eosToken) 142 | print("Length of Input Ids: ", len(input_ids)) 143 | if add_bosToken: 144 | assert input_ids[0] == tokenizer.bos_token_id 145 | if add_eosToken: 146 | assert input_ids[-1] == tokenizer.eos_token_id 147 | reconstructed_inputTxt = tokenizer.decode(input_ids, skip_special_tokens=True) 148 | check_string_equality(reconstructed_inputTxt, input_txt) 149 | check_string_subset_of_another(reconstructed_inputTxt\ 150 | .replace("\" is ", ""), 151 | LONG_DATAPOINT["input"]) 152 | check_string_ends_with_another(reconstructed_inputTxt, "\" is ") 153 | 154 | if __name__ == "__main__": 155 | 156 | for tokenizer_name in ["bigscience/bloom-560m", 157 | "facebook/opt-125m", 158 | "gpt2-xl"]: 159 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) 160 | tokenizer.max_seq_len = 512 161 | test_tokenize_input(tokenizer) -------------------------------------------------------------------------------- /src/data/templates.py: -------------------------------------------------------------------------------- 1 | SUMMARIZATION_PROMPT_TEMPLATES = { 2 | 0: "[input]", 3 | 1: "The summary of \"[input]\" is ", 4 | 2: "Summarize: [input]" 5 | } 6 | -------------------------------------------------------------------------------- /src/eval/PredictionLogger.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | class PredictionLogger(object): 4 | def __init__(self, logger_fp): 5 | self.logger_fp = logger_fp 6 | self.logger_file = open(self.logger_fp, 'w+') 7 | 8 | # From https://stackoverflow.com/questions/5558418/list-of-dicts-to-from-dict-of-lists 9 | def _convert_dictOfLists_to_listOfDicts(self,
dictOfLists): 10 | listOfDicts = [] 11 | for datapoint_values in zip(*dictOfLists.values()): 12 | listOfDicts.append(dict(zip(dictOfLists, datapoint_values))) 13 | return listOfDicts 14 | 15 | def log_batch(self, batchOf_evalInfo): 16 | listOf_evalInfo = self._convert_dictOfLists_to_listOfDicts(batchOf_evalInfo) 17 | for eval_info in listOf_evalInfo: 18 | self.logger_file.write(json.dumps(eval_info) + '\n') -------------------------------------------------------------------------------- /src/eval/Scorer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class Scorer(object): 4 | 5 | def __init__(self, metric): 6 | self.metric = metric 7 | 8 | if metric == "multiple_choice": 9 | self.total_numCorrectDatapoints = 0 10 | self.total_numDatapoints = 0 11 | else: 12 | raise ValueError(f"Invalid metric {metric}") 13 | 14 | def add_batch(self, batchOf_evalInfo): 15 | pred_choice = np.asarray(batchOf_evalInfo["pred_choice"]) 16 | lbl = np.asarray(batchOf_evalInfo["lbl"]) 17 | 18 | which_datapointsCorrect = pred_choice == lbl 19 | self.total_numCorrectDatapoints += np.sum(which_datapointsCorrect) 20 | self.total_numDatapoints += which_datapointsCorrect.shape[0] 21 | 22 | return which_datapointsCorrect.tolist() 23 | 24 | def get_score(self): 25 | mulChoice_acc = float(round(self.total_numCorrectDatapoints / self.total_numDatapoints,3)) 26 | return {"multiple-choice-accuracy": mulChoice_acc} 27 | 28 | 29 | -------------------------------------------------------------------------------- /src/evaluate_mulChoice.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | import torch 4 | import tqdm 5 | import os 6 | import logging 7 | import torch.distributed as dist 8 | 9 | import deepspeed 10 | from transformers.deepspeed import HfDeepSpeedConfig 11 | 12 | from src.data.multiple_choice import MultipleChoiceDataset, MultipleChoiceReader 13 | from src.data.Batcher import Batcher 14 | 15 | 16 | from src.eval.Scorer import Scorer 17 | from src.eval.PredictionLogger import PredictionLogger 18 | 19 | from src.constructors import construct_hugFace_objects, construct_models 20 | 21 | from src.utils.util import set_seeds, reduce_gatheredOutput, get_mulChoice_outputDir 22 | from src.utils.deepspeed import get_deepspeedConfig 23 | 24 | def evaluate_dataset(mul_choice_filepath, 25 | deepspeed_engine, 26 | model, 27 | world_size, 28 | eval_batchSize, 29 | device, 30 | ignore_pointwise_mutual_information, 31 | ignore_length_normalization, 32 | compute_choices_iteratively, 33 | model_name, 34 | tokenizer, 35 | prompt_template_idx, 36 | input_prefix, 37 | debug): 38 | 39 | mcReader = MultipleChoiceReader() 40 | createDataset_fn = lambda data: MultipleChoiceDataset(data, tokenizer, prompt_template_idx, input_prefix, device, world_size) 41 | batcher = Batcher(mcReader, createDataset_fn, train_batchSize=None, eval_batchSize=eval_batchSize) 42 | 43 | if world_size is None or dist.get_rank() == 0: 44 | scorer = Scorer("multiple_choice") 45 | output_dir = get_mulChoice_outputDir(mul_choice_filepath, model_name, ignore_pointwise_mutual_information, ignore_length_normalization) 46 | if debug: 47 | output_fp = os.path.join("exp_out", "multiple_choice", "debug.json") 48 | else: 49 | output_fp = os.path.join(output_dir, f"predictions-prompt_{prompt_template_idx}.json") 50 | if os.path.exists(output_fp): 51 | return 52 | prediction_logger = PredictionLogger(output_fp) 53 | 54 | 55 | for batch in 
tqdm.tqdm(batcher.get_mulChoiceBatches(mul_choice_filepath)): 56 | with torch.no_grad(): 57 | # Uses deepspeed 58 | if world_size is not None: 59 | pred_choice, score_ofChoices, logProbs_ofAllChoicesIds, len_allChoices, logProbs_ofAllChoicesIds_condOnNullInput = deepspeed_engine.module.predict_mulChoice(batch, 60 | not ignore_pointwise_mutual_information, 61 | not ignore_length_normalization, 62 | compute_choices_iteratively) 63 | else: 64 | pred_choice, score_ofChoices, logProbs_ofAllChoicesIds, len_allChoices, logProbs_ofAllChoicesIds_condOnNullInput = model.predict_mulChoice(batch, 65 | not ignore_pointwise_mutual_information, 66 | not ignore_length_normalization, 67 | compute_choices_iteratively) 68 | 69 | batchOf_evalInfo = { 70 | "pred_choice": pred_choice, 71 | "score_of_choices": score_ofChoices, 72 | "log_probs_of_all_choices_ids": logProbs_ofAllChoicesIds, 73 | "len_all_choices": len_allChoices, 74 | "log_prob_of_all_choices_ids_cond_null_input": logProbs_ofAllChoicesIds_condOnNullInput if logProbs_ofAllChoicesIds_condOnNullInput is not None else [0] * len(logProbs_ofAllChoicesIds), 75 | "input": batch["input"], 76 | "list_choices": batch["list_choices"], 77 | "lbl": batch["lbl"].cpu().numpy().tolist() 78 | } 79 | 80 | if world_size is not None: 81 | listOf_batchOf_evalInfo = [{}] * world_size 82 | dist.gather_object( 83 | batchOf_evalInfo, 84 | listOf_batchOf_evalInfo if dist.get_rank() == 0 else None, 85 | dst=0 86 | ) 87 | if dist.get_rank() == 0: 88 | batchOf_evalInfo = reduce_gatheredOutput(listOf_batchOf_evalInfo) 89 | 90 | if world_size is None or dist.get_rank() == 0: 91 | whichDatapoints_correct = scorer.add_batch(batchOf_evalInfo) 92 | batchOf_evalInfo.update({ 93 | "is_datapoint_correct": whichDatapoints_correct 94 | }) 95 | prediction_logger.log_batch(batchOf_evalInfo) 96 | 97 | if not debug: 98 | if world_size is None or dist.get_rank() == 0: 99 | with open(os.path.join(output_dir, "scores.jsonl"), 'a+') as f_out: 100 | dict_score = scorer.get_score() 101 | dict_score.update({ 102 | "pointwise_mutual_information": not ignore_pointwise_mutual_information, 103 | "length_normalization": not ignore_length_normalization, 104 | "dataset_filepath": mul_choice_filepath, 105 | "model": model_name, 106 | "prompt_template_idx": prompt_template_idx 107 | }) 108 | f_out.write(json.dumps(dict_score) + '\n') 109 | 110 | 111 | def evaluate_mulChoice(args): 112 | 113 | # Uses deepspeed 114 | if args.world_size is not None: 115 | hugFace_config, tokenizer, input_prefix = construct_hugFace_objects(args.model_name, args.max_seq_len) 116 | if hasattr(hugFace_config, "d_model"): 117 | model_dim = hugFace_config.d_model 118 | elif hasattr(hugFace_config, "hidden_size"): 119 | model_dim = hugFace_config.hidden_size 120 | else: 121 | raise ValueError("Cannot get model dimension from hugging face config") 122 | 123 | deepspeed_config = get_deepspeedConfig(args.eval_batch_size, args.world_size, model_dim) 124 | 125 | dschf = HfDeepSpeedConfig(deepspeed_config) # keep this object alive and create it before initializing the model 126 | model, _ = construct_models(args.model_name, args.use_hugFace_parallelism, args.use_bitsandbytes) 127 | 128 | deepspeed_engine = deepspeed.init_inference(model, 129 | mp_size=args.world_size, 130 | dtype=torch.float, 131 | replace_method='auto', 132 | replace_with_kernel_inject=True) 133 | deepspeed_engine.module.eval() # inference 134 | model = None 135 | else: 136 | hugFace_config, tokenizer, input_prefix = construct_hugFace_objects(args.model_name,
args.max_seq_len) 137 | model, _ = construct_models(args.model_name, args.use_hugFace_parallelism, args.use_bitsandbytes) 138 | 139 | if not args.use_hugFace_parallelism and not args.use_bitsandbytes: 140 | model = model.to(args.device) 141 | 142 | model.eval() 143 | deepspeed_engine = None 144 | 145 | for mul_choice_filepath in args.mul_choice_filepath: 146 | if args.prompt_template_idx == -1: 147 | for prompt_template_idx in range(3): 148 | evaluate_dataset(mul_choice_filepath, 149 | deepspeed_engine, 150 | model, 151 | args.world_size, 152 | args.eval_batch_size, 153 | args.device, 154 | args.ignore_pointwise_mutual_information, 155 | args.ignore_length_normalization, 156 | args.compute_choices_iteratively, 157 | args.model_name, 158 | tokenizer, 159 | prompt_template_idx, 160 | input_prefix, 161 | args.debug) 162 | else: 163 | evaluate_dataset(mul_choice_filepath, 164 | deepspeed_engine, 165 | model, 166 | args.world_size, 167 | args.eval_batch_size, 168 | args.device, 169 | args.ignore_pointwise_mutual_information, 170 | args.ignore_length_normalization, 171 | args.compute_choices_iteratively, 172 | args.model_name, 173 | tokenizer, 174 | args.prompt_template_idx, 175 | input_prefix, 176 | args.debug) 177 | if __name__ == "__main__": 178 | parser = argparse.ArgumentParser() 179 | parser.add_argument('-f', "--mul_choice_filepath", action='store', type=str, nargs='*', required=True) 180 | parser.add_argument("--max_seq_len", type=int, default=512) 181 | parser.add_argument('-m', "--model_name", required=True) 182 | parser.add_argument("--use_deepspeed", action="store_true") 183 | parser.add_argument("--use_bitsandbytes", action="store_true") 184 | parser.add_argument("--use_hugFace_parallelism", action="store_true") 185 | parser.add_argument('-b', "--eval_batch_size", type=int, default=1) 186 | parser.add_argument('-p', "--prompt_template_idx", type=int, default=0) 187 | parser.add_argument("--debug", action="store_true") 188 | parser.add_argument("--local_rank", type=int, default=0) 189 | parser.add_argument('--ignore_pointwise_mutual_information', 190 | action="store_true", 191 | help="Whether to use the pointwise mutual information or regular log " 192 | "likelihood for scoring candidates") 193 | parser.add_argument('--ignore_length_normalization', 194 | action="store_true", 195 | help="Whether to use length normalization when scoring the candidates") 196 | parser.add_argument('--compute_choices_iteratively', 197 | action="store_true", 198 | help="Whether to compute the log probs of the choices together or iteratively") 199 | args = parser.parse_args() 200 | 201 | logging.basicConfig(level=logging.INFO) 202 | logging.info('Starting evaluate multiple choice') 203 | 204 | if args.use_deepspeed: 205 | logging.info('Using Deepspeed') 206 | # The device is the local_rank since it specifies the GPU to use. 207 | args.device = args.local_rank 208 | args.world_size = int(os.getenv('WORLD_SIZE', '1')) 209 | deepspeed.init_distributed() 210 | else: 211 | # This device is where the input_ids will be loaded.
212 | # It must be 0 since using huggingface parallelism assumes the logits should be back on device 0 to compute the 213 | # loss with the input_ids 214 | args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 215 | args.world_size = None 216 | 217 | evaluate_mulChoice(args) -------------------------------------------------------------------------------- /src/evaluate_mulChoice_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | import logging 5 | import re 6 | from tqdm import tqdm 7 | import copy 8 | import json 9 | 10 | 11 | from transformers import AutoTokenizer, GPT2Tokenizer 12 | 13 | from src.multiple_choice.utils import read_jsonl, write_jsonl 14 | from src.data.preprocess_data import does_tokenizer_addBosEosTokens 15 | 16 | from src.models.model_flags import DICT_REGEX_OF_MODEL_TYPE, DICT_REGEX_OF_DEVICE_MAP 17 | from src.utils.util import get_value_from_key_matching_regex 18 | 19 | 20 | def compute_totalLogProb(logProb_ofIds): 21 | ''' 22 | Assumes log probs will be zero for ids which are supposed to be masked out, ie. pad tokens 23 | 24 | Args: 25 | logProb_ofIds: 26 | 27 | Returns: 28 | 29 | ''' 30 | return torch.sum(logProb_ofIds, dim=1) 31 | 32 | def compute_avgLogProb(logProb_ofIds, max_len): 33 | ''' 34 | Assumes log probs will be zero for ids which are supposed to be masked out, ie. pad tokens 35 | 36 | Args: 37 | logProb_ofIds: 38 | 39 | Returns: 40 | 41 | ''' 42 | return torch.sum(logProb_ofIds, dim=1) / max_len 43 | 44 | def get_model(prediction_filepath): 45 | listOf_directories = prediction_filepath.split("/") 46 | return listOf_directories[-2] 47 | 48 | def check_scoreMatches_recomputingFromCache(predictionPrompt_filepath): 49 | list_json = read_jsonl(predictionPrompt_filepath) 50 | 51 | for json in list_json: 52 | logProb_ofAllChoiceIds = torch.tensor(json["log_probs_of_all_choices_ids"]) 53 | score_ofChoices = torch.tensor(json["score_of_choices"]) 54 | logProb_ofAllChoiceIds_condNullInput = torch.tensor(json["log_prob_of_all_choices_ids_cond_null_input"]) 55 | len_allChoices = torch.tensor(json["len_all_choices"]) 56 | pred_choice = json["pred_choice"] 57 | 58 | allChoices_logProb = compute_avgLogProb(logProb_ofAllChoiceIds, len_allChoices) 59 | allChoices_logProb_condNullInput = compute_avgLogProb(logProb_ofAllChoiceIds_condNullInput, len_allChoices) 60 | allChoices_logProb -= allChoices_logProb_condNullInput 61 | 62 | if not torch.allclose(allChoices_logProb, score_ofChoices, atol=1e-4): 63 | print(predictionPrompt_filepath) 64 | import ipdb; ipdb.set_trace() 65 | 66 | # Handle case where predicted probabilities are same for both choices, so 67 | # argmax might not be consistent 68 | if not torch.argmax(allChoices_logProb) == pred_choice and\ 69 | not torch.allclose(allChoices_logProb[0], allChoices_logProb[1], atol=1e-4): 70 | print(predictionPrompt_filepath) 71 | import ipdb; ipdb.set_trace() 72 | 73 | dictOfModel_toDictOfInput_toGoldSummaryLogProb = {} 74 | def check_correctSummaryScoreMatches_acrossDifferentDistractors(predictionPrompt_filepath): 75 | list_json = read_jsonl(predictionPrompt_filepath) 76 | 77 | model = get_model(predictionPrompt_filepath) 78 | 79 | if model not in dictOfModel_toDictOfInput_toGoldSummaryLogProb: 80 | dictOfModel_toDictOfInput_toGoldSummaryLogProb[model] = {} 81 | 82 | for json in list_json: 83 | score_ofChoices = torch.tensor(json["score_of_choices"]) 84 | correctChoice_logProb = score_ofChoices[json["lbl"]] 85 | 
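# [Editor's note] dictOfModel_toDictOfInput_toGoldSummaryLogProb caches, per model, the gold summary's
# (log prob, text) keyed by the article. Every distractor file pairs the same gold summary with a
# different distractor, so for a fixed model and article the gold score must be identical across files;
# a mismatch below points at a caching or batching bug upstream. Illustrative (hypothetical) cache entry:
#     {"facebook-opt-125m": {"<article text>": (tensor(-1.2345), "<gold summary>")}}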
input = json["input"] 86 | 87 | if input in dictOfModel_toDictOfInput_toGoldSummaryLogProb[model]: 88 | if not torch.allclose(dictOfModel_toDictOfInput_toGoldSummaryLogProb[model][input][0], correctChoice_logProb, atol=1e-4): 89 | print(predictionPrompt_filepath) 90 | else: 91 | dictOfModel_toDictOfInput_toGoldSummaryLogProb[model][input] = correctChoice_logProb, json["list_choices"][json["lbl"]] 92 | 93 | def check_accuraciesCorrect_andExistsForEachPrompt(mulChoice_experiment): 94 | ''' 95 | 96 | 97 | Returns: 98 | 99 | ''' 100 | scores_filepath = os.path.join(mulChoice_experiment, "scores.jsonl") 101 | 102 | # Check there are 3 scores for 3 prompts 103 | if os.path.exists(scores_filepath): 104 | list_scoresJson = read_jsonl(scores_filepath) 105 | if len(list_scoresJson) != 3: 106 | # t5-large finetuned only use 1 prompt, so it should have 1 score 107 | assert len(list_scoresJson) == 1 and \ 108 | ("sysresearch101-t5-large-finetuned-xsum" in mulChoice_experiment),\ 109 | mulChoice_experiment 110 | else: 111 | print(scores_filepath) 112 | import ipdb; ipdb.set_trace() 113 | 114 | for score_json in list_scoresJson: 115 | prompt_idx = score_json["prompt_template_idx"] 116 | predictionPrompt_filepath = os.path.join(mulChoice_experiment, f"predictions-prompt_{prompt_idx}.json") 117 | 118 | list_json = read_jsonl(predictionPrompt_filepath) 119 | num_correct = 0 120 | for json in list_json: 121 | if json["pred_choice"] == json["lbl"]: 122 | num_correct += 1 123 | 124 | computed_acc = round(num_correct / len(list_json), 3) 125 | assert computed_acc == score_json["multiple-choice-accuracy"] 126 | 127 | def test_experiment(mulChoice_experiment): 128 | check_accuraciesCorrect_andExistsForEachPrompt(mulChoice_experiment) 129 | 130 | for prompt_idx in range(3): 131 | predictionPrompt_filepath = os.path.join(mulChoice_experiment, f"predictions-prompt_{prompt_idx}.json") 132 | 133 | if os.path.exists(predictionPrompt_filepath): 134 | check_scoreMatches_recomputingFromCache(predictionPrompt_filepath) 135 | check_correctSummaryScoreMatches_acrossDifferentDistractors(predictionPrompt_filepath) 136 | 137 | 138 | if __name__ == "__main__": 139 | parser = argparse.ArgumentParser() 140 | parser.add_argument('-e', "--list_mulChoiceExperiments", action='store', type=str, nargs='*', required=True) 141 | args = parser.parse_args() 142 | 143 | for mulChoice_experiment in tqdm(args.list_mulChoiceExperiments): 144 | test_experiment(mulChoice_experiment) 145 | -------------------------------------------------------------------------------- /src/get_results.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import statistics 4 | import os 5 | 6 | 7 | def get_medianScore_perExperiment(modelWithScoringFunction_dir): 8 | score_fp = os.path.join(modelWithScoringFunction_dir, "scores.jsonl") 9 | 10 | dict_promptTemplateIdx_toMCAccuracy = {} 11 | 12 | with open(score_fp, "r") as f: 13 | for line in f.readlines(): 14 | score_json = json.loads(line.strip("\n")) 15 | dict_promptTemplateIdx_toMCAccuracy[score_json["prompt_template_idx"]] = score_json["multiple-choice-accuracy"] 16 | 17 | return statistics.median(list(dict_promptTemplateIdx_toMCAccuracy.values())) 18 | 19 | def get_medianScore_perModel(model_dir): 20 | avg_pmi_acc = get_medianScore_perExperiment(model_dir) 21 | return [avg_pmi_acc] 22 | 23 | def get_medianScore_perDataset(dataset_dir, list_models): 24 | list_acc = [] 25 | 26 | for model in list_models: 27 | if model is None: 28 | 
list_acc.extend([0] * 4) 29 | else: 30 | model_dir = os.path.join(dataset_dir, model) 31 | list_acc.extend(get_medianScore_perModel(model_dir)) 32 | 33 | return list_acc 34 | 35 | def get_medianScore_acrossDatasets(datasets, list_models): 36 | 37 | print("Using the following datasets ... ") 38 | acc_acrossDataset = [] 39 | for dataset in datasets: 40 | print(dataset) 41 | acc_perDataset = get_medianScore_perDataset(dataset, list_models) 42 | acc_acrossDataset.append(acc_perDataset) 43 | 44 | print("The median accuracy per model across different distractor models is ... ") 45 | for idx, acc_perModel in enumerate(list(map(list, zip(*acc_acrossDataset)))): 46 | formattedAcc_perModel = list(map(lambda x: str(round(x * 100, 3)), acc_perModel)) 47 | print(list_models[idx] + ": " + ",".join(formattedAcc_perModel)) 48 | 49 | 50 | if __name__ == "__main__": 51 | parser = argparse.ArgumentParser() 52 | parser.add_argument('-e', "--exp_dir_of_datasets", action='store', type=str, nargs='*', required=True) 53 | parser.add_argument('-m', "--list_models", action='store', type=str, nargs='*', required=True) 54 | args = parser.parse_args() 55 | 56 | get_medianScore_acrossDatasets(args.exp_dir_of_datasets, args.list_models) -------------------------------------------------------------------------------- /src/models/DecoderWrappers_forMulChoice.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from src.utils.util import get_value_from_key_matching_regex 7 | from src.models.model_flags import DICT_REGEX_OF_WHETHER_MODEL_USES_POSITION_IDS 8 | from src.models.utils import compute_logProb 9 | 10 | class DecoderWrappers_forMulChoice(nn.Module): 11 | ''' 12 | 13 | ''' 14 | 15 | def __init__(self, transformer): 16 | super().__init__() 17 | self.transformer = transformer 18 | 19 | self.use_position_ids = get_value_from_key_matching_regex(DICT_REGEX_OF_WHETHER_MODEL_USES_POSITION_IDS, self.transformer._get_name().lower()) 20 | 21 | if "gptneox" in self.transformer._get_name().lower(): 22 | self.use_position_ids = False 23 | print("WARNING! NeoX has a bug with padding the input when caching key,values. Use a batch size of 1.") 24 | assert self.use_position_ids is not None 25 | 26 | 27 | def _broadcast_tensors(self, input_masks, past_key_values, num_choices): 28 | ''' 29 | Broadcast the input masks and encoder outputs to account for multiple choices per input 30 | 31 | Args: 32 | input_masks: [batch_size, max_input_len] 33 | past_key_values: Tuple of keys and values for each layer. 34 | The first index of the tuple is the layer index, and the second index 35 | of the tuple is whether it is a key or value. Each element in tuple 36 | has shape [batch_size, max_input_len, num_heads, head_dim] or [batch_size x num_heads, head_dim, max_input_len]. 37 | num_choices: 38 | 39 | Returns: 40 | input_masks: [batch_size x num_choices, max_input_len] 41 | past_key_values: Tuple of keys and values for each layer. 42 | The first index of the tuple is the layer index, and the second index 43 | of the tuple is whether it is a key or value. Each element in tuple 44 | has shape [batch_size x num_choices, max_input_len, num_heads, head_dim] 45 | or [batch_size x num_heads x num_choices, head_dim, max_input_len]. 
46 | ''' 47 | batch_size, max_input_len = input_masks.shape 48 | input_masks = torch.repeat_interleave(input_masks, num_choices, dim=0) 49 | 50 | list_broadcast_pastKeyValues = [] 51 | for pastKeyValues_perLayer in past_key_values: 52 | 53 | list_broadcast_pastKeyValues_perLayer = [] 54 | for key_or_value in pastKeyValues_perLayer: 55 | # This is for keys or values which have dimension [batch_size, max_input_len, num_heads, head_dim] 56 | # This is the standard for Hugging Face. 57 | if len(key_or_value.shape) == 4: 58 | list_broadcast_pastKeyValues_perLayer.append(torch.repeat_interleave(key_or_value, num_choices, dim=0)) 59 | # This is for keys or values which have dimension [batch_size x num_heads, head_dim, max_input_len]. 60 | # This is what is used for BLOOM in transformers == 4.22.0 61 | elif len(key_or_value.shape) == 3: 62 | num_heads = key_or_value.shape[0] // batch_size 63 | flatten_keyOrValue = key_or_value.reshape(((batch_size, num_heads) + key_or_value.shape[1:])) 64 | broadcast_flatten_keyOrValue = torch.repeat_interleave(flatten_keyOrValue, num_choices, dim=0) 65 | list_broadcast_pastKeyValues_perLayer.append(broadcast_flatten_keyOrValue.flatten(0, 1)) 66 | else: 67 | raise ValueError(f"Invalid cached key or value shape: ", key_or_value.shape) 68 | 69 | list_broadcast_pastKeyValues.append(tuple(list_broadcast_pastKeyValues_perLayer)) 70 | 71 | return input_masks, tuple(list_broadcast_pastKeyValues) 72 | 73 | def compute_allChoices_logProb_fromDecoderOutput(self, 74 | input_masks, 75 | past_key_values, 76 | allChoices_ids, 77 | allChoices_masks, 78 | lengthNormalization): 79 | ''' 80 | 81 | Args: 82 | input_masks: [batch_size, max_input_len] 83 | past_key_values: Tuple of keys and values for each layer. 84 | The first index of the tuple is the layer index, and the second index 85 | of the tuple is whether it is a key or value. Each element in tuple 86 | has shape [batch_size, max_input_len, num_heads, head_dim]. 87 | allChoices_ids: [batch_size x num_choices, max_choice_len] 88 | allChoices_masks: [batch_size x num_choices, max_choice_len] 89 | 90 | Returns: 91 | logProbs_forAllChoices: [batch_size, num_choices] 92 | logProbs_forAllChoicesIds_zeroOutPadIds: [batch_size, num_choices, max_choice_len] 93 | len_allChoices: [batch_size ] 94 | ''' 95 | num_choices = allChoices_ids.shape[0] // input_masks.shape[0] 96 | input_masks, past_key_values = self._broadcast_tensors(input_masks, past_key_values, num_choices) 97 | 98 | # Combine the input mask and choice mask so the model knows which cached input representations 99 | # are padded when conditioning on the cached input representations. 100 | # [batch_size x num_choices, max_input_len + max_choice_len] 101 | combined_mask = torch.cat([input_masks, allChoices_masks], dim=1) 102 | 103 | if self.use_position_ids: 104 | # Construct initial position ids solely based on choice lengths 105 | # [1, max_choice_len] 106 | allChoices_positionIds = torch.arange(0, allChoices_ids.shape[-1], dtype=torch.long, device=allChoices_ids.device)[None, :] 107 | input_len = torch.sum(input_masks, dim=1)[:, None] 108 | # Increment the position id to account for the input len. 
109 | allChoices_positionIds = allChoices_positionIds + input_len 110 | 111 | # WARNING: The loss at transformer_outputs[0] is not valid, since allChoices_ids uses a 112 | # pad token of 0 and so the loss will not be ignored for the pad tokens 113 | transformer_outputs = self.transformer(input_ids=allChoices_ids, 114 | attention_mask=combined_mask, 115 | position_ids=allChoices_positionIds, 116 | past_key_values=past_key_values, 117 | use_cache=True, 118 | labels=allChoices_ids) 119 | else: 120 | # WARNING: The loss at transformer_outputs[0] is not valid, since allChoices_ids uses a 121 | # pad token of 0 and so the loss will not be ignored for the pad tokens 122 | transformer_outputs = self.transformer(input_ids=allChoices_ids, 123 | attention_mask=combined_mask, 124 | past_key_values=past_key_values, 125 | use_cache=True, 126 | labels=allChoices_ids) 127 | 128 | 129 | # We used the logits for all choices to compute the log probs per example since 130 | # the loss returned in transformer_outputs will average the negative log probs across 131 | # examples 132 | # [batch_size x num_choices, max_choice_len, vocab_size] 133 | logits_ofAllChoices = transformer_outputs[1].float() 134 | 135 | # Shift the ids, masks, logits to handle predicting the next token for the decoder. 136 | # Note that we need to pass in the input_ids and cannot rely on HuggingFace automatically 137 | # constructing the ids from the labels, since we need to pass in an attention mask to handle 138 | # the cached input representations. 139 | shiftedLogits_ofAllChoices = logits_ofAllChoices[..., :-1, :].contiguous() 140 | shiftedIds_ofAllChoices = allChoices_ids[..., 1:].contiguous() 141 | shiftedMasks_ofAllChoices = allChoices_masks[..., 1:].contiguous() 142 | 143 | maxChoice_len = shiftedLogits_ofAllChoices.shape[1] 144 | vocab_size = shiftedLogits_ofAllChoices.shape[-1] 145 | 146 | # Compute the log probability of the ids for all choices with respect to the logits 147 | # [batch_size x num_choices x (max_choice_len-1)] 148 | logProbs_forAllChoices_ids = - F.cross_entropy(shiftedLogits_ofAllChoices.view(-1, vocab_size), 149 | shiftedIds_ofAllChoices.view(-1), 150 | reduction="none") 151 | 152 | return compute_logProb(logProbs_forAllChoices_ids, 153 | shiftedMasks_ofAllChoices, 154 | num_choices, 155 | maxChoice_len, 156 | lengthNormalization) 157 | 158 | def compute_allChoices_logProb_fromDecoderOutput_iteratively(self, 159 | input_masks, 160 | past_key_values, 161 | allChoices_ids, 162 | allChoices_masks, 163 | lengthNormalization): 164 | ''' 165 | Args: 166 | input_masks: [batch_size, max_input_len] 167 | past_key_values: Tuple of keys and values for each layer. 168 | The first index of the tuple is the layer index, and the second index 169 | of the tuple is whether it is a key or value. Each element in tuple 170 | has shape [batch_size, max_input_len, num_heads, head_dim]. 
171 | allChoices_ids: [batch_size x num_choices, max_choice_len] 172 | allChoices_masks: [batch_size x num_choices, max_choice_len] 173 | lengthNormalization: 174 | Returns: 175 | logProbs_forAllChoices: [batch_size, num_choices] 176 | logProbs_forAllChoicesIds_zeroOutPadIds: [batch_size, max_choice_len, ] 177 | len_allChoices: [batch_size ] 178 | ''' 179 | batch_size = input_masks.shape[0] 180 | assert batch_size == 1, "No need to score choices iteratively if batch size can be larger than 1" 181 | num_choices = allChoices_ids.shape[0] // input_masks.shape[0] 182 | 183 | list_logProbs_ofAllChoices = [] 184 | list_logProbs_ofAllChoicesIds_zeroOutPadIds = [] 185 | list_lenAllChoices = [] 186 | 187 | for choice_idx in range(num_choices): 188 | # [1, max_choice_len] 189 | curChoice_ids = allChoices_ids[choice_idx:choice_idx + 1, :] 190 | curChoice_mask = allChoices_masks[choice_idx:choice_idx + 1, :] 191 | 192 | # Remove pad tokens 193 | assert curChoice_mask.shape[0] == 1 194 | num_nonPadTokens = torch.sum(curChoice_mask) 195 | num_PadTokens = curChoice_mask.shape[1] - num_nonPadTokens 196 | 197 | curChoice_ids = curChoice_ids[:,:num_nonPadTokens] 198 | curChoice_mask = curChoice_mask[:,:num_nonPadTokens] 199 | 200 | assert curChoice_mask[0,-1] == 1 201 | 202 | # Combine the input mask and choice mask so the model knows which cached input representations 203 | # are padded when conditioning on the cached input representations. 204 | # [batch_size, max_input_len + max_choice_len] 205 | combined_mask = torch.cat([input_masks, curChoice_mask], dim=1) 206 | 207 | if self.use_position_ids: 208 | # Construct initial position ids solely based on choice lengths 209 | # [1, max_choice_len] 210 | curChoice_positionIds = torch.arange(0, curChoice_ids.shape[-1], dtype=torch.long, 211 | device=curChoice_ids.device)[None, :] 212 | input_len = torch.sum(input_masks, dim=1)[:, None] 213 | # Increment the position id to account for the input len. 214 | curChoice_positionIds = curChoice_positionIds + input_len 215 | 216 | # WARNING: The loss at transformer_outputs[0] is not valid, since allChoices_ids uses a 217 | # pad token of 0 and so the loss will not be ignored for the pad tokens 218 | transformer_outputs = self.transformer(input_ids=curChoice_ids, 219 | attention_mask=combined_mask, 220 | position_ids=curChoice_positionIds, 221 | past_key_values=past_key_values, 222 | use_cache=True, 223 | labels=curChoice_ids) 224 | else: 225 | # WARNING: The loss at transformer_outputs[0] is not valid, since allChoices_ids uses a 226 | # pad token of 0 and so the loss will not be ignored for the pad tokens 227 | transformer_outputs = self.transformer(input_ids=curChoice_ids, 228 | attention_mask=combined_mask, 229 | past_key_values=past_key_values, 230 | use_cache=True, 231 | labels=curChoice_ids) 232 | 233 | # We used the logits for all choices to compute the log probs per example since 234 | # the loss returned in transformer_outputs will average the negative log probs across 235 | # examples 236 | # [batch_size, max_choice_len, vocab_size] 237 | logits_ofCurChoice = transformer_outputs[1].float() 238 | 239 | # Shift the ids, masks, logits to handle predicting the next token for the decoder. 240 | # Note that we need to pass in the input_ids and cannot rely on HuggingFace automatically 241 | # constructing the ids from the labels, since we need to pass in an attention mask to handle 242 | # the cached input representations. 
243 | shiftedLogits_ofCurChoice = logits_ofCurChoice[..., :-1, :].contiguous() 244 | shifted_curChoice_ids = curChoice_ids[..., 1:].contiguous() 245 | shifted_curChoice_mask = curChoice_mask[..., 1:].contiguous() 246 | 247 | maxChoice_len = shiftedLogits_ofCurChoice.shape[1] 248 | vocab_size = shiftedLogits_ofCurChoice.shape[-1] 249 | 250 | # Compute the log probability of the ids for all choices with respect to the logits 251 | # [batch_size x (max_choice_len-1)] 252 | logProbs_ofCurChoice_ids = - F.cross_entropy(shiftedLogits_ofCurChoice.view(-1, vocab_size), 253 | shifted_curChoice_ids.view(-1), 254 | reduction="none") 255 | 256 | # Compute the log probabilities of all the choices by averaging the log probabilities of 257 | # the ids and zeroing out the pad ids 258 | # [batch_size, (max_choice_len-1)] 259 | logProbs_ofCurChoice_ids = logProbs_ofCurChoice_ids.reshape(-1, maxChoice_len) 260 | shifted_curChoice_mask = shifted_curChoice_mask > 0 261 | logProbs_ofCurChoiceIds_zeroOutPadIds = logProbs_ofCurChoice_ids * shifted_curChoice_mask 262 | 263 | logProb_ofCurChoice = torch.sum(logProbs_ofCurChoiceIds_zeroOutPadIds, dim=1) 264 | len_curChoice = torch.sum(shifted_curChoice_mask, dim=1) 265 | 266 | if lengthNormalization: 267 | logProb_ofCurChoice = logProb_ofCurChoice / len_curChoice 268 | 269 | list_logProbs_ofAllChoices.append(logProb_ofCurChoice) 270 | list_logProbs_ofAllChoicesIds_zeroOutPadIds.append(torch.cat([ 271 | logProbs_ofCurChoiceIds_zeroOutPadIds, 272 | torch.zeros((1, num_PadTokens)).to(logProbs_ofCurChoiceIds_zeroOutPadIds.device) 273 | ], dim=1)) 274 | list_lenAllChoices.append(len_curChoice) 275 | 276 | # Since batch size was 1, the batch size will be flattened and we have to add back the extra dimension with stack 277 | return torch.stack(list_logProbs_ofAllChoices, dim=1), \ 278 | torch.stack(list_logProbs_ofAllChoicesIds_zeroOutPadIds, dim=1), \ 279 | torch.stack(list_lenAllChoices, dim=1) 280 | 281 | def compute_allChoices_logProb(self, 282 | input_ids, 283 | input_masks, 284 | allChoices_ids, 285 | allChoices_masks, 286 | lengthNormalization, 287 | iterativelyComputeChoices): 288 | ''' 289 | 290 | 291 | Args: 292 | input_ids: [batch_size, max_input_len] 293 | input_masks: [batch_size, max_input_len] 294 | allChoices_ids: [batch_size x num_choices, max_choice_len] 295 | allChoices_masks: [batch_size x num_choices, max_choice_len] 296 | lengthNormalization: 297 | iterativelyComputeChoices 298 | 299 | Returns: 300 | log_prob: [batch_size, num_choices] 301 | ''' 302 | output = self.transformer(input_ids=input_ids, attention_mask=input_masks) 303 | past_key_values = output.past_key_values 304 | 305 | if iterativelyComputeChoices: 306 | return self.compute_allChoices_logProb_fromDecoderOutput_iteratively(input_masks, 307 | past_key_values, 308 | allChoices_ids, 309 | allChoices_masks, 310 | lengthNormalization) 311 | else: 312 | return self.compute_allChoices_logProb_fromDecoderOutput(input_masks, 313 | past_key_values, 314 | allChoices_ids, 315 | allChoices_masks, 316 | lengthNormalization) 317 | 318 | def predict_mulChoice(self, batch, pointMutualInfo, lengthNormalization, iterativelyComputeChoices): 319 | ''' 320 | 321 | Args: 322 | batch: 323 | pointMutualInfo: 324 | lengthNormalization: 325 | 326 | Returns: 327 | pred_choice: [batch_size, ] 328 | score_ofChoices: [batch_size, num_choices] 329 | logProbs_ofAllChoicesIds: [batch_size, num_choices, max_choice_len] 330 | len_allChoices: [batch_size] 331 | logProbs_ofAllChoicesIds_condOnNullInput: [batch_size, 
num_choices, max_choice_len] 332 | ''' 333 | # Compute log p(y|x) 334 | score_ofChoices, logProbs_ofAllChoicesIds, len_allChoices = self.compute_allChoices_logProb( 335 | batch["input_ids"], 336 | batch["input_masks"], 337 | batch["all_choices_ids"], 338 | batch["all_choices_masks"], 339 | lengthNormalization, 340 | iterativelyComputeChoices) 341 | 342 | logProbs_ofAllChoicesIds_condOnNullInput = None 343 | 344 | # For computing pointwise mutual information, we need to compute log p(y|x) - log p(y). 345 | # To compute p(y), we condition the choices on the null input. 346 | if pointMutualInfo: 347 | logProb_ofChoices_condOnNullInput, logProbs_ofAllChoicesIds_condOnNullInput, _ = self.compute_allChoices_logProb( 348 | batch["null_input_ids"], 349 | batch["null_input_masks"], 350 | batch["all_choices_ids"], 351 | batch["all_choices_masks"], 352 | lengthNormalization, 353 | iterativelyComputeChoices) 354 | score_ofChoices -= logProb_ofChoices_condOnNullInput 355 | 356 | _, pred_choice = torch.max(score_ofChoices, dim=1) 357 | return pred_choice.cpu().numpy().tolist(), \ 358 | score_ofChoices.cpu().numpy().tolist(), \ 359 | logProbs_ofAllChoicesIds.cpu().numpy().tolist(), \ 360 | len_allChoices.cpu().numpy().tolist(), \ 361 | logProbs_ofAllChoicesIds_condOnNullInput.cpu().numpy().tolist() if logProbs_ofAllChoicesIds_condOnNullInput is not None else None 362 | -------------------------------------------------------------------------------- /src/models/DecoderWrappers_forMulChoice_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from src.data.multiple_choice import MultipleChoiceDataset, MultipleChoiceReader 4 | from src.data.Batcher import Batcher 5 | 6 | from src.constructors import construct_hugFace_objects, construct_models 7 | from src.models.DecoderWrappers_forMulChoice import DecoderWrappers_forMulChoice 8 | 9 | class Test_DecoderWrappers(DecoderWrappers_forMulChoice): 10 | ''' 11 | 12 | ''' 13 | 14 | def __init__(self, transformer): 15 | super().__init__(transformer) 16 | self.transformer = transformer 17 | 18 | def _test_broadcast_tensor(self, input_masks, past_key_values, num_choices): 19 | ''' 20 | Test that when repeating the tensors by num_choices times, the repetitions will 21 | be in the same block. 22 | 23 | Args: 24 | input_masks: [batch_size, max_input_len] 25 | past_key_values: Tuple of keys and values for each layer. 26 | The first index of the tuple is the layer index, and the second index 27 | of the tuple is whether it is a key or value. Each element in tuple 28 | has shape [batch_size, max_input_len, num_heads, head_dim] or [batch_size x num_heads, head_dim, max_input_len]. 29 | num_choices: 30 | ''' 31 | new_inputMask, new_pastKeyValues = super()._broadcast_tensors(input_masks, past_key_values, num_choices) 32 | 33 | batch_size = input_masks.shape[0] 34 | for i in range(batch_size): 35 | assert torch.equal(input_masks[i:i+1].repeat(num_choices, 1) , 36 | new_inputMask[i*num_choices:(i+1)*num_choices, ]), \ 37 | f"Test of broadcasting input_masks failed." 
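# [Editor's note] torch.repeat_interleave with dim=0 keeps each datapoint's copies in one contiguous
# block, which is exactly what the i*num_choices:(i+1)*num_choices slicing above assumes. A quick
# illustration of the layout difference:
#     torch.repeat_interleave(torch.tensor([[1], [2]]), 2, dim=0)  # -> [[1], [1], [2], [2]]
#     torch.tensor([[1], [2]]).repeat(2, 1)                        # -> [[1], [2], [1], [2]]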
38 | for old_keyValues_perLayer, new_keyValues_perLayer in zip(past_key_values, new_pastKeyValues): 39 | for old_keyOrValue, new_keyOrValue in zip(old_keyValues_perLayer, new_keyValues_perLayer): 40 | batch_size = old_keyOrValue.shape[0] 41 | for i in range(batch_size): 42 | num_heads = past_key_values[0][0].shape[0] // input_masks.shape[0] 43 | 44 | # This means the keys or values are of shape [batch_size, max_input_len, num_heads, head_dim] 45 | if num_heads == 1: 46 | assert torch.equal(old_keyOrValue[i:i + 1].repeat(num_choices, 1, 1, 1), 47 | new_keyOrValue[i * num_choices:(i + 1) * num_choices]), \ 48 | f"Test of broadcasting key,values failed." 49 | # This means the keys and values are of shape [batch_size x num_heads, head_dim, max_input_len]. 50 | else: 51 | assert torch.equal(old_keyOrValue[i * num_heads : (i + 1) * num_heads].repeat((num_choices, 1, 1)), 52 | new_keyOrValue[i * num_heads * num_choices : (i + 1) * num_heads * num_choices]), \ 53 | f"Test of broadcasting key,values failed." 54 | 55 | 56 | def test_compute_allChoices_logProb_fromDecoderOutput(self, 57 | input_ids, 58 | input_masks, 59 | past_key_values, 60 | allChoices_ids, 61 | allChoices_masks, 62 | allChoices_lbls): 63 | ''' 64 | 65 | Args: 66 | input_ids: [batch_size, max_input_len] 67 | input_masks: [batch_size, max_input_len] 68 | past_key_values: Tuple of keys and values for each layer. 69 | The first index of the tuple is the layer index, and the second index 70 | of the tuple is whether it is a key or value. Each element in tuple 71 | has shape [batch_size, max_input_len, num_heads, head_dim] or [batch_size x num_heads, head_dim, max_input_len]. 72 | allChoices_ids: [batch_size x num_choices, max_choice_len] 73 | allChoices_masks: [batch_size x num_choices, max_choice_len] 74 | allChoices_lbls: [batch_size x num_choices, max_choice_len] 75 | 76 | Returns: 77 | 78 | ''' 79 | num_choices = allChoices_lbls.shape[0] // input_masks.shape[0] 80 | batch_size = input_masks.shape[0] 81 | self._test_broadcast_tensor(input_masks, past_key_values, num_choices) 82 | 83 | # Iterate over every datapoint and every choice to compute the log prob using the loss 84 | # returned from HuggingFace. Since HuggingFace averages the loss per batch, 85 | # we use batch_size=1 to get the log prob for each choice of each datapoint. 86 | listOf_logProb = [] 87 | for datapoint_idx in range(batch_size): 88 | datapoint_ids = input_ids[datapoint_idx:datapoint_idx + 1] 89 | datapoint_mask = input_masks[datapoint_idx:datapoint_idx + 1] 90 | 91 | for choice_idx in range(num_choices): 92 | choiceLbls_idx = datapoint_idx*num_choices + choice_idx 93 | choice_lbls = allChoices_lbls[choiceLbls_idx:choiceLbls_idx+1] 94 | choice_ids = allChoices_ids[choiceLbls_idx:choiceLbls_idx+1] 95 | choice_mask = allChoices_masks[choiceLbls_idx:choiceLbls_idx+1] 96 | 97 | # Note the batch size is 1. Have to filter the datapoint_ids 98 | # to remove the pad ids in between the datapoint and the choice when we combined them. 99 | datapoint_len = torch.sum(datapoint_mask) 100 | filtered_datapointIds = datapoint_ids[:,:datapoint_len] 101 | combined_ids = torch.cat([filtered_datapointIds, choice_ids], dim=1) 102 | combined_mask = torch.cat([datapoint_mask[:,:datapoint_len], choice_mask], dim=1) 103 | 104 | # We want to ignore the loss for the datapoint and only compute the loss for the choices. 
105 | datapoint_lbls = torch.ones_like(filtered_datapointIds).to(datapoint_ids.device) * -100 106 | # We ignore the first token in choice labels since HuggingFace will shift the labels 107 | # one over to the left, but since we concatenate the datapoint labels and choice labels, 108 | # the first choice id will not be shifted over. 109 | choice_lbls[:,0] = -100 110 | combined_lbls = torch.cat([datapoint_lbls, choice_lbls], dim=1) 111 | transformer_outputs = self.transformer(input_ids=combined_ids, 112 | attention_mask=combined_mask, 113 | labels=combined_lbls) 114 | choice_logProb = - transformer_outputs[0] 115 | listOf_logProb.append(choice_logProb) 116 | 117 | logProb_forAllChoices = torch.stack(listOf_logProb, dim=0).reshape(batch_size, num_choices) 118 | 119 | assert torch.allclose(logProb_forAllChoices, 120 | super().compute_allChoices_logProb_fromDecoderOutput( 121 | input_masks, 122 | past_key_values, 123 | allChoices_ids, 124 | allChoices_masks, 125 | True)[0], 126 | atol=1e-4), \ 127 | "Test of computing log probs from decoder output failed." 128 | 129 | def test_compute_allChoices_logProb_fromDecoderOutput_iteratively(self, 130 | input_masks, 131 | past_key_values, 132 | allChoices_ids, 133 | allChoices_masks): 134 | ''' 135 | Args: 136 | input_masks: [batch_size, max_input_len] 137 | past_key_values: Tuple of keys and values for each layer. 138 | The first index of the tuple is the layer index, and the second index 139 | of the tuple is whether it is a key or value. Each element in tuple 140 | has shape [batch_size, max_input_len, num_heads, head_dim] or [batch_size x num_heads, head_dim, max_input_len]. 141 | allChoices_ids: [batch_size x num_choices, max_choice_len] 142 | allChoices_masks: [batch_size x num_choices, max_choice_len] 143 | Returns: 144 | ''' 145 | assert torch.allclose(super().compute_allChoices_logProb_fromDecoderOutput_iteratively( 146 | input_masks, 147 | past_key_values, 148 | allChoices_ids, 149 | allChoices_masks, 150 | True)[0], 151 | super().compute_allChoices_logProb_fromDecoderOutput( 152 | input_masks, 153 | past_key_values, 154 | allChoices_ids, 155 | allChoices_masks, 156 | True)[0], 157 | atol=1e-4), \ 158 | "Test of computing log probs from decoder output failed." 
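# [Editor's note] A minimal, self-contained sketch of the invariant the tests in this class check:
# scoring a continuation against cached past_key_values should match scoring the full concatenated
# sequence. The checkpoint and strings below are illustrative assumptions, not part of this repo, and
# the sketch assumes an unpadded single-sequence input:
#     from transformers import AutoModelForCausalLM, AutoTokenizer
#     model = AutoModelForCausalLM.from_pretrained("gpt2"); model.eval()
#     tok = AutoTokenizer.from_pretrained("gpt2")
#     inp = tok("The summary is", return_tensors="pt")
#     cont = tok(" short.", return_tensors="pt")
#     past = model(**inp, use_cache=True).past_key_values      # cache the input representations
#     cached_logits = model(input_ids=cont.input_ids, past_key_values=past).logits
#     full_ids = torch.cat([inp.input_ids, cont.input_ids], dim=1)
#     full_logits = model(input_ids=full_ids).logits[:, inp.input_ids.shape[1]:, :]
#     assert torch.allclose(cached_logits, full_logits, atol=1e-4)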
159 | def test_compute_allChoices_logProb(self, 160 | input_ids, 161 | input_masks, 162 | allChoices_ids, 163 | allChoices_masks, 164 | allChoices_lbls): 165 | ''' 166 | 167 | 168 | Args: 169 | input_ids: [batch_size, max_input_len] 170 | input_masks: [batch_size, max_input_len] 171 | allChoices_ids: [batch_size x num_choices, max_choice_len] 172 | allChoices_lbls: [batch_size x num_choices, max_choice_len] 173 | 174 | Returns: 175 | log_prob: [batch_size x num_choices, max_choice_len] 176 | ''' 177 | output = self.transformer(input_ids=input_ids, attention_mask=input_masks) 178 | past_key_values = output.past_key_values 179 | 180 | self.test_compute_allChoices_logProb_fromDecoderOutput(input_ids, 181 | input_masks, 182 | past_key_values, 183 | allChoices_ids, 184 | allChoices_masks, 185 | allChoices_lbls) 186 | 187 | self.test_compute_allChoices_logProb_fromDecoderOutput_iteratively(input_masks, 188 | past_key_values, 189 | allChoices_ids, 190 | allChoices_masks) 191 | 192 | def test_predict_mulChoice(self, batch): 193 | ''' 194 | 195 | Args: 196 | batch: 197 | pointMutualInfo: 198 | 199 | Returns: 200 | predChoice: [batch_size, ] 201 | predProb: [batch_size, ] 202 | ''' 203 | 204 | # Compute log p(y|x) 205 | self.test_compute_allChoices_logProb( 206 | batch["input_ids"], 207 | batch["input_masks"], 208 | batch["all_choices_ids"], 209 | batch["all_choices_masks"], 210 | batch["all_choices_lbls"]) 211 | 212 | 213 | if __name__ == "__main__": 214 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 215 | 216 | # This only tests all models for 6 batches. 217 | for model_name in ["bigscience/bloom-560m", "gpt2", "facebook/opt-125m"]: 218 | hugFace_config, tokenizer, input_prefix = construct_hugFace_objects(model_name, 512) 219 | _, transformer = construct_models(model_name, False, False) 220 | 221 | model = Test_DecoderWrappers(transformer).to(device) 222 | model.eval() 223 | 224 | mcReader = MultipleChoiceReader() 225 | createDataset_fn = lambda data: MultipleChoiceDataset(data, tokenizer, 0, input_prefix, device, world_size=None) 226 | batcher = Batcher(mcReader, createDataset_fn, train_batchSize=None, eval_batchSize=1) 227 | 228 | for i, batch in enumerate(batcher.get_mulChoiceBatches("multiple_choice-dataset/xsum/random_distractors/binary_choice-using_random_distractors.jsonl")): 229 | with torch.no_grad(): 230 | model.test_predict_mulChoice(batch) 231 | if i > 4: 232 | break -------------------------------------------------------------------------------- /src/models/EncoderDecoderWrappers_forMulChoice.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from src.models.utils import compute_logProb 6 | 7 | 8 | class EncoderDecoderWrappers_forMulChoice(nn.Module): 9 | ''' 10 | 11 | ''' 12 | 13 | def __init__(self, transformer): 14 | super().__init__() 15 | self.transformer = transformer 16 | 17 | def _broadcast_tensors(self, input_masks, encoder_outputs, num_choices): 18 | ''' 19 | Broadcast the input masks and encoder outputs to account for multiple choices per input 20 | 21 | Args: 22 | input_masks: [batch_size, max_input_len] 23 | encoder_outputs: BaseModelOutput object from HuggingFace where the first element is 24 | the hidden states of the encoder at the last layer 25 | [batch_size, max_input_len, ff_dim] 26 | num_choices: 27 | 28 | Returns: 29 | input_masks: [batch_size x num_choices, max_input_len] 30 | encoder_outputs: BaseModelOutput object 
from HuggingFace where the first element is 31 | the hidden states of the encoder at the last layer 32 | [batch_size x num_choices, max_input_len, ff_dim] 33 | ''' 34 | input_masks = torch.repeat_interleave(input_masks, num_choices, dim=0) 35 | encoder_outputs = (torch.repeat_interleave(encoder_outputs[0], num_choices, dim=0), ) 36 | return input_masks, encoder_outputs 37 | 38 | def compute_allChoices_logProb_fromEncoderOutput(self, 39 | input_masks, 40 | encoder_outputs, 41 | allChoices_ids, 42 | allChoices_masks, 43 | lengthNormalization): 44 | ''' 45 | 46 | Args: 47 | input_masks: [batch_size, max_input_len] 48 | encoder_outputs: BaseModelOutput object from HuggingFace where the first element is 49 | the hidden states of the encoder at the last layer 50 | [batch_size, max_input_len, ff_dim] 51 | allChoices_ids: [batch_size x num_choices, max_choice_len] 52 | allChoices_masks: [batch_size x num_choices, max_choice_len] 53 | lengthNormalization: 54 | 55 | Returns: 56 | logProbs_forAllChoices: [batch_size, num_choices] 57 | logProbs_forAllChoicesIds_zeroOutPadIds: [batch_size, num_choices, max_choice_len] 58 | ''' 59 | assert allChoices_ids.shape[0] % input_masks.shape[0] == 0, \ 60 | f"The batch size {allChoices_ids.shape[0]} of allChoices_ids is not a multiple of " \ 61 | f"the batch size {input_masks.shape[0]} of input_masks" 62 | num_choices = allChoices_ids.shape[0] // input_masks.shape[0] 63 | 64 | input_masks, encoder_outputs = self._broadcast_tensors(input_masks, encoder_outputs, num_choices) 65 | 66 | # WARNING: The loss at transformer_outputs[0] is not valid, since allChoices_ids uses a 67 | # pad token of 0 and so the loss will not be ignored for the pad tokens 68 | # The input mask is passed in for the cross encoder-decoder attention. 69 | transformer_outputs = self.transformer(attention_mask=input_masks, 70 | encoder_outputs=encoder_outputs, 71 | labels=allChoices_ids) 72 | 73 | # We used the logits for all choices to compute the log probs per example since 74 | # the loss returned in transformer_outputs will average the negative log probs across 75 | # examples 76 | # [batch_size x num_choices, max_choice_len, vocab_size] 77 | logits_ofAllChoices = transformer_outputs[1].float() 78 | maxChoice_len = logits_ofAllChoices.shape[1] 79 | vocab_size = logits_ofAllChoices.shape[-1] 80 | 81 | # Compute the log probability of the ids for all choices with respect to the logits 82 | # [batch_size x num_choices x max_choice_len] 83 | logProbs_ofAllChoices_ids = - F.cross_entropy(logits_ofAllChoices.view(-1, vocab_size), 84 | allChoices_ids.view(-1), 85 | reduction="none") 86 | 87 | return compute_logProb(logProbs_ofAllChoices_ids, 88 | allChoices_masks, 89 | num_choices, 90 | maxChoice_len, 91 | lengthNormalization) 92 | 93 | 94 | def compute_allChoices_logProb(self, 95 | input_ids, 96 | input_masks, 97 | allChoices_ids, 98 | allChoices_masks, 99 | lengthNormalization): 100 | ''' 101 | 102 | 103 | Args: 104 | input_ids: [batch_size, max_input_len] 105 | input_masks: [batch_size, max_input_len] 106 | allChoices_ids: [batch_size x num_choices, max_choice_len] 107 | allChoices_masks: [batch_size x num_choices, max_choice_len] 108 | lengthNormalization: 109 | 110 | Returns: 111 | log_prob: [batch_size x num_choices, max_choice_len] 112 | ''' 113 | # Search for encoder function 114 | if hasattr(self.transformer, "encoder"): 115 | encoder_outputs = self.transformer.encoder(input_ids, input_masks) 116 | elif hasattr(self.transformer, "model") and hasattr(self.transformer.model, "encoder"): 117 
    def compute_allChoices_logProb(self,
                                   input_ids,
                                   input_masks,
                                   allChoices_ids,
                                   allChoices_masks,
                                   lengthNormalization):
        '''
        Encodes the input once and scores all answer choices against it.

        Args:
            input_ids: [batch_size, max_input_len]
            input_masks: [batch_size, max_input_len]
            allChoices_ids: [batch_size x num_choices, max_choice_len]
            allChoices_masks: [batch_size x num_choices, max_choice_len]
            lengthNormalization: whether to divide each choice's log probability by its length

        Returns:
            logProbs_ofAllChoices: [batch_size, num_choices]
            logProbs_ofAllChoicesIds_zeroOutPadIds: [batch_size, num_choices, max_choice_len]
            len_allChoices: [batch_size, num_choices]
        '''
        # Search for the encoder, whose attribute name differs across architectures
        # (e.g. T5 exposes .encoder while BART exposes .model.encoder).
        if hasattr(self.transformer, "encoder"):
            encoder_outputs = self.transformer.encoder(input_ids, input_masks)
        elif hasattr(self.transformer, "model") and hasattr(self.transformer.model, "encoder"):
            encoder_outputs = self.transformer.model.encoder(input_ids, input_masks)
        else:
            raise ValueError("Cannot find encoder function in transformer")

        return self.compute_allChoices_logProb_fromEncoderOutput(input_masks,
                                                                 encoder_outputs,
                                                                 allChoices_ids,
                                                                 allChoices_masks,
                                                                 lengthNormalization)

    def predict_mulChoice(self, batch, pointMutualInfo, lengthNormalization, iterativelyComputeChoices):
        '''
        Predicts the answer choice with the highest score for each datapoint.

        Args:
            batch: dictionary of input ids/masks and the ids/masks of all answer choices
            pointMutualInfo: whether to score choices by pointwise mutual information
            lengthNormalization: whether to length-normalize the log probability of each choice
            iterativelyComputeChoices: Not used. Added to be consistent with DecoderWrappers_forMulChoice

        Returns:
            pred_choice: [batch_size, ]
            score_ofChoices: [batch_size, num_choices]
            logProbs_ofAllChoicesIds: [batch_size, num_choices, max_choice_len]
            len_allChoices: [batch_size, num_choices]
            logProbs_ofAllChoicesIds_condOnNullInput: [batch_size, num_choices, max_choice_len]
        '''
        # Compute log p(y|x)
        score_ofChoices, logProbs_ofAllChoicesIds, len_allChoices = self.compute_allChoices_logProb(
            batch["input_ids"],
            batch["input_masks"],
            batch["all_choices_ids"],
            batch["all_choices_masks"],
            lengthNormalization)

        logProbs_ofAllChoicesIds_condOnNullInput = None

        # For computing pointwise mutual information, we need to compute log p(y|x) - log p(y).
        # To compute p(y), we condition the choices on the null input.
        if pointMutualInfo:
            logProb_ofChoices_condOnNullInput, logProbs_ofAllChoicesIds_condOnNullInput, _ = self.compute_allChoices_logProb(
                batch["null_input_ids"],
                batch["null_input_masks"],
                batch["all_choices_ids"],
                batch["all_choices_masks"],
                lengthNormalization)
            score_ofChoices -= logProb_ofChoices_condOnNullInput

        _, pred_choice = torch.max(score_ofChoices, dim=1)

        return pred_choice.cpu().numpy().tolist(), \
            score_ofChoices.cpu().numpy().tolist(), \
            logProbs_ofAllChoicesIds.cpu().numpy().tolist(), \
            len_allChoices.cpu().numpy().tolist(), \
            logProbs_ofAllChoicesIds_condOnNullInput.cpu().numpy().tolist() if logProbs_ofAllChoicesIds_condOnNullInput is not None else None
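
# A minimal usage sketch (illustrative; it assumes the small "t5-small" checkpoint and
# builds the batch dictionary by hand rather than through the repo's Batcher). T5's pad
# token id is 0, which matches the pad convention the WARNING above relies on, and the
# choices are flattened to [batch_size x num_choices, max_choice_len] per the docstrings.
if __name__ == "__main__":
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    wrapper = EncoderDecoderWrappers_forMulChoice(AutoModelForSeq2SeqLM.from_pretrained("t5-small"))
    wrapper.eval()

    # One datapoint with two candidate summaries.
    source = tokenizer(["summarize: The cat sat on the mat."], return_tensors="pt")
    choices = tokenizer(["A cat sat on a mat.", "A dog chased a ball."],
                        return_tensors="pt", padding=True)
    batch = {"input_ids": source["input_ids"],
             "input_masks": source["attention_mask"],
             "all_choices_ids": choices["input_ids"],
             "all_choices_masks": choices["attention_mask"]}

    with torch.no_grad():
        pred_choice, score_ofChoices, _, _, _ = wrapper.predict_mulChoice(
            batch, pointMutualInfo=False, lengthNormalization=True,
            iterativelyComputeChoices=False)
    print(pred_choice, score_ofChoices)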
-------------------------------------------------------------------------------- /src/models/EncoderDecoderWrappers_forMulChoice_test.py: --------------------------------------------------------------------------------
import torch

from src.data.multiple_choice import MultipleChoiceDataset, MultipleChoiceReader
from src.data.Batcher import Batcher

from src.constructors import construct_hugFace_objects, construct_models

from src.models.EncoderDecoderWrappers_forMulChoice import EncoderDecoderWrappers_forMulChoice

class Test_EncoderDecoderWrappers(EncoderDecoderWrappers_forMulChoice):
    '''
    Test harness that checks the vectorized multiple-choice scoring of
    EncoderDecoderWrappers_forMulChoice against a slow, per-datapoint reference
    computed with HuggingFace's own loss.
    '''

    def __init__(self, transformer):
        super().__init__(transformer)
        self.transformer = transformer

    def _test_broadcast_tensor(self, input_mask, encoder_outputs, num_choices):
        '''
        Test that when repeating the tensors by num_choices times, the repetitions will
        be in the same block.

        Args:
            input_mask: [batch_size, max_input_len]
            encoder_outputs: BaseModelOutput object from HuggingFace where the first element is
                the hidden states of the encoder at the last layer
                [batch_size, max_input_len, ff_dim]
            num_choices: number of answer choices per datapoint
        '''
        new_inputMask, new_encoderOutputs = \
            super()._broadcast_tensors(input_mask, encoder_outputs, num_choices)

        batch_size = input_mask.shape[0]
        for i in range(batch_size):
            assert torch.equal(input_mask[i:i+1].repeat(num_choices, 1),
                               new_inputMask[i*num_choices:(i+1)*num_choices, ]), \
                "Test of broadcasting input_mask failed."
            assert torch.equal(encoder_outputs[0][i:i+1].repeat(num_choices, 1, 1),
                               new_encoderOutputs[0][i*num_choices:(i+1)*num_choices]), \
                "Test of broadcasting encoder_outputs failed."

    def test_compute_allChoices_logProb_fromEncoderOutput(self,
                                                          input_ids,
                                                          input_masks,
                                                          encoder_outputs,
                                                          allChoices_ids,
                                                          allChoices_masks,
                                                          allChoices_lbls):
        '''
        Checks the vectorized log probabilities against scoring each (datapoint, choice)
        pair one at a time with HuggingFace's loss.

        Args:
            input_ids: [batch_size, max_input_len]
            input_masks: [batch_size, max_input_len]
            encoder_outputs: BaseModelOutput object from HuggingFace where the first element is
                the hidden states of the encoder at the last layer
                [batch_size, max_input_len, ff_dim]
            allChoices_ids: [batch_size x num_choices, max_choice_len]
            allChoices_masks: [batch_size x num_choices, max_choice_len]
            allChoices_lbls: [batch_size x num_choices, max_choice_len]
        '''
        num_choices = allChoices_lbls.shape[0] // input_masks.shape[0]
        batch_size = input_masks.shape[0]
        self._test_broadcast_tensor(input_masks, encoder_outputs, num_choices)

        # Iterate over every datapoint and every choice to compute the log prob using the loss
        # returned from HuggingFace. Since HuggingFace averages the loss per batch,
        # we use batch_size=1 to get the log prob for each choice of each datapoint.
        listOf_logProb = []
        for datapoint_idx in range(batch_size):
            datapoint_ids = input_ids[datapoint_idx:datapoint_idx + 1]
            datapoint_mask = input_masks[datapoint_idx:datapoint_idx + 1]

            for choice_idx in range(num_choices):
                choiceLbls_idx = datapoint_idx*num_choices + choice_idx
                choice_lbls = allChoices_lbls[choiceLbls_idx:choiceLbls_idx+1]
                transformer_outputs = self.transformer(input_ids=datapoint_ids,
                                                       attention_mask=datapoint_mask,
                                                       labels=choice_lbls,
                                                       output_hidden_states=True)
                choice_logProb = - transformer_outputs[0]
                listOf_logProb.append(choice_logProb)

        logProb_forAllChoices = torch.stack(listOf_logProb, dim=0).reshape(batch_size, num_choices)

        assert torch.allclose(logProb_forAllChoices,
                              super().compute_allChoices_logProb_fromEncoderOutput(
                                  input_masks,
                                  encoder_outputs,
                                  allChoices_ids,
                                  allChoices_masks,
                                  True)[0],
                              atol=1e-4), \
            "Test of computing log probs from encoder output failed."
    def test_compute_allChoices_logProb(self,
                                        input_ids,
                                        input_masks,
                                        allChoices_ids,
                                        allChoices_masks,
                                        allChoices_lbls):
        '''
        Tests computing the log probabilities of all choices end to end, starting from
        the raw input ids.

        Args:
            input_ids: [batch_size, max_input_len]
            input_masks: [batch_size, max_input_len]
            allChoices_ids: [batch_size x num_choices, max_choice_len]
            allChoices_masks: [batch_size x num_choices, max_choice_len]
            allChoices_lbls: [batch_size x num_choices, max_choice_len]
        '''
        # Search for the encoder, whose attribute name differs across architectures.
        if hasattr(self.transformer, "encoder"):
            encoder_outputs = self.transformer.encoder(input_ids, input_masks)
        elif hasattr(self.transformer, "model") and hasattr(self.transformer.model, "encoder"):
            encoder_outputs = self.transformer.model.encoder(input_ids, input_masks)
        else:
            raise ValueError("Cannot find encoder function in transformer")

        self.test_compute_allChoices_logProb_fromEncoderOutput(input_ids,
                                                               input_masks,
                                                               encoder_outputs,
                                                               allChoices_ids,
                                                               allChoices_masks,
                                                               allChoices_lbls)

    def test_predict_mulChoice(self, batch):
        '''
        Runs the end-to-end test on one batch.

        Args:
            batch: dictionary of input ids/masks and the ids/masks/labels of all answer choices
        '''
        # Compute log p(y|x)
        self.test_compute_allChoices_logProb(
            batch["input_ids"],
            batch["input_masks"],
            batch["all_choices_ids"],
            batch["all_choices_masks"],
            batch["all_choices_lbls"])


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # This only tests all models for 6 batches.
    for model_name in ["Frederick0291/t5-small-finetuned-xsum", "facebook/bart-large-xsum"]:
        hugFace_config, tokenizer, input_prefix = construct_hugFace_objects(model_name, 512)
        _, transformer = construct_models(model_name, False, False)

        model = Test_EncoderDecoderWrappers(transformer).to(device)
        model.eval()

        mcReader = MultipleChoiceReader()
        createDataset_fn = lambda data: MultipleChoiceDataset(data, tokenizer, 0, input_prefix, device, world_size=None)
        batcher = Batcher(mcReader, createDataset_fn, train_batchSize=None, eval_batchSize=2)

        for i, batch in enumerate(batcher.get_mulChoiceBatches("multiple_choice-dataset/xsum/random_distractors/binary_choice-using_random_distractors.jsonl")):
            with torch.no_grad():
                model.test_predict_mulChoice(batch)
            if i > 4:
                break
-------------------------------------------------------------------------------- /src/models/device_maps.py: --------------------------------------------------------------------------------
# Hand-written maps that pin each named module of a model to a GPU index; runs of
# consecutive transformer blocks are expressed as dict comprehensions.

BLOOM_DEVICE_MAP = {
    'transformer.word_embeddings': 0,
    'transformer.word_embeddings_layernorm': 0,
    **{f'transformer.h.{i}': 0 for i in range(0, 16)},
    **{f'transformer.h.{i}': 1 for i in range(16, 35)},
    **{f'transformer.h.{i}': 2 for i in range(35, 53)},
    **{f'transformer.h.{i}': 3 for i in range(53, 70)},
    'transformer.ln_f': 0,
    'lm_head': 0,
}
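# How a map like BLOOM_DEVICE_MAP is typically consumed (a sketch, assuming the checkpoint
# is loaded through HuggingFace transformers + accelerate, which accept a
# module-name-to-GPU-index dict via device_map; loading the real 176B checkpoint is left
# commented out since it requires the multi-GPU setup these maps were tuned for):
#
#     from transformers import AutoModelForCausalLM
#     model = AutoModelForCausalLM.from_pretrained(
#         "bigscience/bloom", device_map=BLOOM_DEVICE_MAP, torch_dtype="auto")
#
# device_map="auto" would instead let accelerate infer a split; DICT_REGEX_OF_DEVICE_MAP
# in src/models/model_flags.py uses that as the default for models without a hand-written map.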
GPT_NEOX_DEVICE_MAP = {
    'gpt_neox.embed_in': 0,
    **{f'gpt_neox.layers.{i}': 0 for i in range(0, 22)},
    **{f'gpt_neox.layers.{i}': 1 for i in range(22, 44)},
    'gpt_neox.final_layer_norm': 1,
    'embed_out': 0,
}

T0_DEVICE_MAP = {
    'shared': 0,
    'decoder.embed_tokens': 0,
    'encoder': 0,
    **{f'decoder.block.{i}': 0 for i in range(0, 3)},
    **{f'decoder.block.{i}': 1 for i in range(3, 24)},
    'decoder.final_layer_norm': 1,
    'decoder.dropout': 1,
    'lm_head': 0,
}

OPT_66B_DEVICE_MAP = {
    'model.decoder.embed_tokens': 0,
    'lm_head': 0,
    'model.decoder.embed_positions': 0,
    'model.decoder.final_layer_norm': 0,
    **{f'model.decoder.layers.{i}': 0 for i in range(0, 21)},
    **{f'model.decoder.layers.{i}': 1 for i in range(21, 42)},
    **{f'model.decoder.layers.{i}': 2 for i in range(42, 63)},
    'model.decoder.layers.63': 0,
}


OPT_175B_DEVICE_MAP = {
    'model.decoder.embed_tokens': 0,
    'lm_head': 0,
    'model.decoder.embed_positions': 0,
    'model.decoder.final_layer_norm': 0,
    **{f'model.decoder.layers.{i}': 0 for i in range(0, 25)},
    **{f'model.decoder.layers.{i}': 1 for i in range(25, 49)},
    **{f'model.decoder.layers.{i}': 2 for i in range(49, 73)},
    **{f'model.decoder.layers.{i}': 3 for i in range(73, 87)},
    **{f'model.decoder.layers.{i}': 0 for i in range(87, 96)},
}
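
# A small self-check sketch (illustrative, not required by the pipeline): running this
# module directly prints how many modules each map places on each GPU index, a quick way
# to eyeball whether a split is balanced.
if __name__ == "__main__":
    from collections import Counter

    for name, device_map in [("BLOOM", BLOOM_DEVICE_MAP),
                             ("GPT-NeoX", GPT_NEOX_DEVICE_MAP),
                             ("T0", T0_DEVICE_MAP),
                             ("OPT-66B", OPT_66B_DEVICE_MAP),
                             ("OPT-175B", OPT_175B_DEVICE_MAP)]:
        print(name, dict(sorted(Counter(device_map.values()).items())))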
-------------------------------------------------------------------------------- /src/models/model_flags.py: --------------------------------------------------------------------------------
from src.models.device_maps import BLOOM_DEVICE_MAP, GPT_NEOX_DEVICE_MAP, T0_DEVICE_MAP, OPT_66B_DEVICE_MAP, OPT_175B_DEVICE_MAP

from transformers import AutoTokenizer, AutoConfig, AutoModelForSeq2SeqLM, AutoModelForCausalLM
from transformers.models.gpt_neox.tokenization_gpt_neox_fast import GPTNeoXTokenizerFast

DICT_REGEX_OF_MODEL_TYPE = {
    ".*T0.*": "encoder_decoder",
    ".*pegasus.*": "encoder_decoder",
    ".*t5.*": "encoder_decoder",
    ".*bart.*": "encoder_decoder",
    ".*bloom.*": "decoder",
    ".*gpt.*": "decoder",
    ".*opt.*": "decoder",
    ".*T5.*": "encoder_decoder",
}
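
# These regex-keyed tables are resolved by get_value_from_key_matching_regex in
# src/utils/util.py, which keeps the value of the LAST pattern that matches, so catch-all
# entries like ".*" must come first to act as defaults. A hypothetical lookup (not a call
# made by the original code at this spot):
#
#     from src.utils.util import get_value_from_key_matching_regex
#     model_type = get_value_from_key_matching_regex(
#         DICT_REGEX_OF_MODEL_TYPE, "facebook/bart-large-xsum")   # -> "encoder_decoder"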
DICT_REGEX_OF_WHETHER_MODEL_USES_POSITION_IDS = {
    ".*bloom.*": False,
    ".*gpt.*": True,
    ".*opt.*": False
}

DICT_REGEX_OF_DEVICE_MAP = {
    ".*": "auto",
    ".*bloom": BLOOM_DEVICE_MAP,
    ".*gpt-neox-20b": GPT_NEOX_DEVICE_MAP,
    ".*T0|.*t5-xxl.*": T0_DEVICE_MAP,
    ".*opt-66b": OPT_66B_DEVICE_MAP,
    ".*opt-175b": OPT_175B_DEVICE_MAP
}

DICT_REGEX_OF_TOKENIZERS = {
    ".*": lambda model_name: AutoTokenizer.from_pretrained(model_name),
    ".*opt.*": lambda model_name: AutoTokenizer.from_pretrained(model_name, use_fast=False),
    ".*gpt-neox-20b": lambda model_name: GPTNeoXTokenizerFast.from_pretrained(model_name)
}
-------------------------------------------------------------------------------- /src/models/utils.py: --------------------------------------------------------------------------------
import torch

def compute_logProb(logProbs_ofAllChoices_ids, allChoices_masks, num_choices, maxChoice_len, lengthNormalization):
    '''
    Computes the log probability of each choice by summing the per-token log
    probabilities, zeroing out the pad ids via the masks.

    Args:
        logProbs_ofAllChoices_ids: flattened, [batch_size x num_choices x max_choice_len]
        allChoices_masks: [batch_size, num_choices, max_choice_len]
        num_choices: number of answer choices per datapoint
        maxChoice_len: maximum tokenized length of a choice
        lengthNormalization: whether to divide each choice's log probability by its length

    Returns:
        logProbs_ofAllChoices: [batch_size, num_choices]
        logProbs_ofAllChoicesIds_zeroOutPadIds: [batch_size, num_choices, max_choice_len]
        len_allChoices: [batch_size, num_choices]
    '''
    # [batch_size, num_choices, max_choice_len]
    logProbs_ofAllChoices_ids = logProbs_ofAllChoices_ids.reshape(-1, num_choices, maxChoice_len)
    allChoices_masks = allChoices_masks.reshape(-1, num_choices, maxChoice_len) > 0
    logProbs_ofAllChoicesIds_zeroOutPadIds = logProbs_ofAllChoices_ids * allChoices_masks
    logProbs_ofAllChoices = torch.sum(logProbs_ofAllChoicesIds_zeroOutPadIds, dim=2)
    len_allChoices = torch.sum(allChoices_masks, dim=2)

    if lengthNormalization:
        logProbs_ofAllChoices = logProbs_ofAllChoices / len_allChoices

    return logProbs_ofAllChoices, \
        logProbs_ofAllChoicesIds_zeroOutPadIds, \
        len_allChoices
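
# A hand-checkable sketch (illustrative, with made-up probabilities): two datapoints,
# two choices each, max_choice_len of 3, where the last token of the second choice is
# padding. With lengthNormalization=True the summed log probs are divided by the
# unpadded choice lengths, so the scores come out to log(0.5) and log(0.4).
if __name__ == "__main__":
    logProbs = torch.log(torch.tensor([0.5, 0.5, 0.5,
                                       0.4, 0.4, 1.0])).repeat(2)  # flat, 2 x 2 x 3 values
    masks = torch.tensor([[1, 1, 1], [1, 1, 0]]).repeat(2, 1)      # [4, 3]
    scores, _, lengths = compute_logProb(logProbs, masks, num_choices=2,
                                         maxChoice_len=3, lengthNormalization=True)
    print(scores)   # tensor([[-0.6931, -0.9163], [-0.6931, -0.9163]])
    print(lengths)  # tensor([[3, 2], [3, 2]])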
"wall_clock_breakdown": False 38 | } 39 | 40 | return deepspeed_config -------------------------------------------------------------------------------- /src/utils/test_helpers.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | def check_string_equality(string_one, string_two): 4 | assert string_one == string_two, \ 5 | f"\n{string_one}\n" + \ 6 | "="*100 + '\n' +\ 7 | f"{string_two}" 8 | 9 | def check_string_subset_of_another(string_one, string_two): 10 | assert string_one in string_two, \ 11 | f"\n{string_one}\n" + \ 12 | "="*100 + '\n' +\ 13 | f"{string_two}" 14 | 15 | def check_string_starts_with_another(string_one, string_two): 16 | assert string_one.startswith(string_two), \ 17 | f"\n{string_one}\n" + \ 18 | "="*100 + '\n' +\ 19 | f"{string_two}" 20 | 21 | def check_string_ends_with_another(string_one, string_two): 22 | assert string_one.endswith(string_two), \ 23 | f"\n{string_one}\n" + \ 24 | "="*100 + '\n' +\ 25 | f"{string_two}" -------------------------------------------------------------------------------- /src/utils/util.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import re 4 | import random 5 | import datetime 6 | import os 7 | import subprocess 8 | import numpy as np 9 | import torch 10 | 11 | from shutil import copytree, ignore_patterns 12 | from src.utils.CONSTANTS import NULL_STRING 13 | 14 | 15 | def set_global_logging_level(level=logging.ERROR, prefices=[""]): 16 | """ 17 | Override logging levels of different modules based on their name as a prefix. 18 | It needs to be invoked after the modules have been loaded so that their loggers have been initialized. 19 | 20 | Args: 21 | - level: desired level. e.g. logging.INFO. Optional. Default is logging.ERROR 22 | - prefices: list of one or more str prefices to match (e.g. ["transformers", "torch"]). Optional. 23 | Default is `[""]` to match all active loggers. 
class ParseKwargs(argparse.Action):
    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, dict())
        for value in values:
            key, value = value.split('=')
            getattr(namespace, self.dest)[key] = value


def update_dict_val_store(dict_val_store, dict_update_val):
    if dict_val_store is None:
        dict_val_store = {}
        for k in dict_update_val.keys():
            dict_val_store[k] = dict_update_val[k].detach().cpu().item()
    else:
        for k in dict_val_store.keys():
            dict_val_store[k] += dict_update_val[k].detach().cpu().item()

    return dict_val_store

def get_avg_dict_val_store(config, dict_val_store, eval_every):
    dict_avg_val = {}

    for k in dict_val_store.keys():
        old_val = dict_val_store[k]
        dict_avg_val[k] = float('%.3f' % (old_val / eval_every))
        dict_val_store[k] = 0

    return dict_val_store, dict_avg_val

def save_gcp(filepath):
    subprocess.call(f"gsutil -m -o GSUtil:parallel_composite_upload_threshold=150M \
                    cp -r {filepath} \
                    gs://abs_sum/{filepath}", shell=True)

def set_seeds(seed):
    "set random seeds"
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def make_dir(dir_name):
    '''
    Makes a directory if it doesn't exist yet

    Args:
        dir_name: directory name
    '''
    if not os.path.exists(dir_name):
        os.makedirs(dir_name)


def make_exp_dir(base_exp_dir):
    '''
    Makes an experiment directory with a timestamp

    Args:
        base_exp_dir: base output directory name

    Returns:
        exp_dir_name: experiment directory name
    '''
    now = datetime.datetime.now()
    ts = "{:04d}-{:02d}-{:02d}-{:02d}-{:02d}-{:02d}".format(now.year, now.month, now.day, now.hour, now.minute,
                                                            now.second)
    exp_dir_name = os.path.join(base_exp_dir, ts)
    make_dir(exp_dir_name)

    src_file = os.path.join(exp_dir_name, 'src')

    copytree(os.path.join(os.environ['LFQA_FAC_ROOT'], "src"), src_file, ignore=ignore_patterns('*.pyc', 'tmp*'))

    return exp_dir_name

def reduce_gatheredOutput(listOfDict):
    '''
    Reduces the output from multiple devices to have the same format as for a single device.
    Also, removes the NULL datapoint, which is a dummy payload to handle model parallelism.

    Args:
        listOfDict: list of result dictionaries, one gathered from each device

    Returns:
        dictOfList: dictionary mapping each key to the flattened list of values across devices
    '''
    dictOfList = {}

    # Form list of values at each key
    for iterate_dict in listOfDict:

        # Find indices of NULL datapoint to ignore later
        idx_toRemove = {}
        for idx, datapoint_input in enumerate(iterate_dict["input"]):
            if datapoint_input == NULL_STRING:
                idx_toRemove[idx] = True

        for (k, v) in iterate_dict.items():

            # Filter out NULL datapoints based on indices.
            filtered_v = []
            for idx, datapoint_v in enumerate(v):
                if idx not in idx_toRemove:
                    filtered_v.append(datapoint_v)

            if k in dictOfList:
                dictOfList[k].append(filtered_v)
            else:
                dictOfList[k] = [filtered_v]

    # Flatten lists of list to form a list, or concatenate list of tensors to form a tensor
    for (k, batch_ofValues) in dictOfList.items():
        dictOfList[k] = [item for sublist in batch_ofValues for item in sublist]

    return dictOfList
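
# A small walkthrough of reduce_gatheredOutput (illustrative, with made-up fields): two
# per-device result dicts are merged, and the dummy NULL datapoint that padded out the
# second device's batch is dropped from every field.
if __name__ == "__main__":
    gathered = [{"input": ["doc a", "doc b"], "pred_choice": [0, 1]},
                {"input": ["doc c", NULL_STRING], "pred_choice": [1, 0]}]
    print(reduce_gatheredOutput(gathered))
    # {'input': ['doc a', 'doc b', 'doc c'], 'pred_choice': [0, 1, 1]}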

def get_value_from_key_matching_regex(dict_regex_keyToValue, key_toMatch):
    # Later patterns take precedence: the value of the LAST regex that matches is returned.
    matching_value = None
    for regex_key, value in dict_regex_keyToValue.items():
        if re.search(regex_key, key_toMatch) is not None:
            matching_value = value
    return matching_value

def get_mulChoice_outputDir(mulChoice_fp, model_name, ignore_pointMutualInfo, ignore_lengthNormalization):
    '''
    Gets the output directory, where we assume the filepath of the multiple-choice dataset
    is of the format multiple_choice-dataset/{}.jsonl and we flatten all subdirectories.

    Args:
        mulChoice_fp: filepath of the multiple-choice dataset
        model_name: name of the model being evaluated
        ignore_pointMutualInfo: whether pointwise mutual information was ignored
        ignore_lengthNormalization: whether length normalization was ignored

    Returns:
        output_dir: directory to write the predictions to
    '''
    mulChoice_datasetName = mulChoice_fp \
        .replace("multiple_choice-dataset/", "") \
        .replace(".jsonl", "")
    model_name = model_name.replace("/fruitbasket/models/", "").replace("/", "-")
    output_dir = os.path.join("exp_out", "multiple_choice", mulChoice_datasetName)

    if ignore_pointMutualInfo:
        ignorePointMutualInfo_str = "-ignore_pointwise_mutual_info"
    else:
        ignorePointMutualInfo_str = ""

    if ignore_lengthNormalization:
        ignoreLengthNormalizationInfo_str = "-ignore_length_normalization"
    else:
        ignoreLengthNormalizationInfo_str = ""

    output_dir = os.path.join(output_dir, model_name + ignorePointMutualInfo_str + ignoreLengthNormalizationInfo_str)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    return output_dir
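
# An illustrative trace of get_mulChoice_outputDir (hypothetical call, using a dataset
# path from this repo; note the function also creates the directory on disk):
#
#     get_mulChoice_outputDir(
#         "multiple_choice-dataset/xsum/fib/binary_choice-using_pegasus_distractors.jsonl",
#         "facebook/bart-large-xsum",
#         ignore_pointMutualInfo=False,
#         ignore_lengthNormalization=True)
#     # -> "exp_out/multiple_choice/xsum/fib/binary_choice-using_pegasus_distractors/
#     #     facebook-bart-large-xsum-ignore_length_normalization"
--------------------------------------------------------------------------------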