├── .gitignore
├── README.md
├── bin
│   ├── download_model
│   │   ├── bloom.sh
│   │   ├── gpt-neox-20b.sh
│   │   ├── opt-30b.sh
│   │   └── opt-66b.sh
│   └── setup.sh
├── dataset_LICENSE
├── img
│   └── intro.png
├── multiple_choice-dataset
│   ├── cnn_dm
│   │   ├── factcc
│   │   │   ├── binary_choice-using_dateswp.jsonl
│   │   │   ├── binary_choice-using_entswp.jsonl
│   │   │   ├── binary_choice-using_negation.jsonl
│   │   │   ├── binary_choice-using_numswp.jsonl
│   │   │   └── binary_choice-using_pronoun.jsonl
│   │   ├── factually_consistent-model_generated
│   │   │   ├── binary_choice-using_banditsumm_distractors.jsonl
│   │   │   ├── binary_choice-using_bert_lstm_pn_rl_distractors.jsonl
│   │   │   ├── binary_choice-using_heter_graph_distractors.jsonl
│   │   │   ├── binary_choice-using_lead3_distractors.jsonl
│   │   │   ├── binary_choice-using_matchsumm_distractors.jsonl
│   │   │   ├── binary_choice-using_mi_unsup_distractors.jsonl
│   │   │   ├── binary_choice-using_neusumm_distractors.jsonl
│   │   │   ├── binary_choice-using_oracle_disco_distractors.jsonl
│   │   │   ├── binary_choice-using_oracle_distractors.jsonl
│   │   │   ├── binary_choice-using_pacsum_bert_distractors.jsonl
│   │   │   ├── binary_choice-using_pacsum_tfidf_distractors.jsonl
│   │   │   ├── binary_choice-using_refresh_distractors.jsonl
│   │   │   ├── binary_choice-using_rnn_ext_rl_distractors.jsonl
│   │   │   ├── binary_choice-using_textrank_distractors.jsonl
│   │   │   └── binary_choice-using_textrank_st_distractors.jsonl
│   │   ├── fib
│   │   │   ├── binary_choice-using_banditsumm_distractors.jsonl
│   │   │   ├── binary_choice-using_bert_lstm_pn_rl_distractors.jsonl
│   │   │   ├── binary_choice-using_heter_graph_distractors.jsonl
│   │   │   ├── binary_choice-using_lead3_distractors.jsonl
│   │   │   ├── binary_choice-using_matchsumm_distractors.jsonl
│   │   │   ├── binary_choice-using_mi_unsup_distractors.jsonl
│   │   │   ├── binary_choice-using_neusumm_distractors.jsonl
│   │   │   ├── binary_choice-using_oracle_disco_distractors.jsonl
│   │   │   ├── binary_choice-using_oracle_distractors.jsonl
│   │   │   ├── binary_choice-using_pacsum_bert_distractors.jsonl
│   │   │   ├── binary_choice-using_pacsum_tfidf_distractors.jsonl
│   │   │   ├── binary_choice-using_refresh_distractors.jsonl
│   │   │   ├── binary_choice-using_rnn_ext_rl_distractors.jsonl
│   │   │   ├── binary_choice-using_textrank_distractors.jsonl
│   │   │   └── binary_choice-using_textrank_st_distractors.jsonl
│   │   ├── fir
│   │   │   └── binary_choice-using_non_factual_gold_distractors.jsonl
│   │   └── mfma
│   │       └── binary_choice-using_bart-base.jsonl
│   ├── fib.json
│   └── xsum
│       ├── factcc
│       │   ├── binary_choice-using_dateswp.jsonl
│       │   ├── binary_choice-using_entswp.jsonl
│       │   ├── binary_choice-using_negation.jsonl
│       │   ├── binary_choice-using_numswp.jsonl
│       │   └── binary_choice-using_pronoun.jsonl
│       ├── factually_consistent-model_generated
│       │   ├── binary_choice-using_bart-base_distractors.jsonl
│       │   ├── binary_choice-using_bart-large_distractors.jsonl
│       │   ├── binary_choice-using_bloom-560m_distractors.jsonl
│       │   ├── binary_choice-using_distil-bart_distractors.jsonl
│       │   ├── binary_choice-using_distil-pegasus_distractors.jsonl
│       │   ├── binary_choice-using_pegasus_distractors.jsonl
│       │   └── binary_choice-using_t5-large_distractors.jsonl
│       ├── fib
│       │   ├── binary_choice-using_bart-base_distractors.jsonl
│       │   ├── binary_choice-using_bart-large_distractors.jsonl
│       │   ├── binary_choice-using_bloom-560m_distractors.jsonl
│       │   ├── binary_choice-using_distil-bart_distractors.jsonl
│       │   ├── binary_choice-using_distil-pegasus_distractors.jsonl
│       │   ├── binary_choice-using_pegasus_distractors.jsonl
│       │   └── binary_choice-using_t5-large_distractors.jsonl
│       ├── fir
│       │   └── binary_choice-using_non_factual_gold_distractors.jsonl
│       └── mfma
│           ├── binary_choice-using_bart-base.jsonl
│           └── binary_choice-using_t5-base.jsonl
├── requirements.txt
├── software_LICENSE
└── src
    ├── compute_fib_results.py
    ├── constructors.py
    ├── data
    │   ├── Batcher.py
    │   ├── Dataset.py
    │   ├── multiple_choice.py
    │   ├── preprocess_data.py
    │   ├── preprocess_data_test.py
    │   └── templates.py
    ├── eval
    │   ├── PredictionLogger.py
    │   └── Scorer.py
    ├── evaluate_mulChoice.py
    ├── evaluate_mulChoice_test.py
    ├── get_results.py
    ├── models
    │   ├── DecoderWrappers_forMulChoice.py
    │   ├── DecoderWrappers_forMulChoice_test.py
    │   ├── EncoderDecoderWrappers_forMulChoice.py
    │   ├── EncoderDecoderWrappers_forMulChoice_test.py
    │   ├── device_maps.py
    │   ├── model_flags.py
    │   └── utils.py
    └── utils
        ├── CONSTANTS.py
        ├── deepspeed.py
        ├── test_helpers.py
        └── util.py
/.gitignore:
--------------------------------------------------------------------------------
1 | runs/
2 | tmp/
3 | bart_large
4 | url_lists/
5 | checkpoints/
6 | data/
7 | env/
8 | exp_out
9 | results/
10 | wandb/
11 | lib/
12 | output/
13 | data/*
14 | pretrained_models
15 | pretrained_models/
16 | .idea/
17 | eche_
18 | runs/
19 | *.pyc
20 | slurm*
21 | .installed.cfg
22 | develop-eggs
23 | dist
24 | downloads
25 | eggs
26 | parts
27 | src/*.egg-info
28 | lib
29 | lib64
30 | !src/data
31 | multiple_choice-score.jsonl
32 | multiple_choice-predictions/
33 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FIB
2 |
3 | This repository contains the code for ["Evaluating the Factual Consistency of Large Language Models Through Summarization"](https://arxiv.org/abs/2211.08412)
4 |
5 |
6 |
7 |
8 | ## FIB Benchmark
9 |
10 | The dataset is now on [HuggingFace](https://huggingface.co/datasets/r-three/fib) :hugs:
11 | Note that our work computes the multiple-choice accuracy in a slightly different way than the HuggingFace release. See [below](#evaluating-models-on-fib) for more details.
12 |
13 |
14 | ## Evaluating Models
15 |
16 | ### Setup
17 |
18 | 1. Create a virtual environment and activate it.
19 | ```
20 | python3 -m venv env
21 | source env/bin/activate
22 | ```
23 | 2. Install dependencies
24 | ```
25 | python -m pip install -r requirements.txt -f https://download.pytorch.org/whl/cu113/torch_stable.html
26 | ```
27 | 3. Set environment variables (This step has to be done every session.)
28 | ```
29 | source bin/setup.sh
30 | ```
31 |
32 | ### Running Models
33 |
34 | The following command is used to evaluate models:
35 | ```
36 | python src/evaluate_mulChoice.py -f {multiple_choice-dataset_filepath} -m {model}
37 | ```
38 |
39 | For example,
40 | ```commandline
41 | python src/evaluate_mulChoice.py -f multiple_choice-dataset/xsum/fib/binary_choice-using_bart-base_distractors.jsonl -m facebook/opt-1.3b
42 | ```
43 | Our code has only been tested on evaluating models from the BLOOM, OPT, GPT, and T5 families.
44 |
45 | Note that although DeepSpeed support is implemented, we did not use it in our experiments, so our DeepSpeed code path may contain bugs.
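
Concretely, the evaluation scores each of the two candidate summaries with the model and predicts the higher-scoring one. Below is a minimal sketch of this style of binary-choice scoring, assuming average per-token log-likelihood as the ranking criterion and an illustrative `" Summary:"` prompt (both are assumptions, not necessarily this repo's exact setup); the actual scoring lives in the wrappers under `src/models/`.

```python
# Minimal sketch only: assumes length-normalized log-likelihood scoring and a
# made-up prompt; see src/models/ for the wrappers this repo actually uses.
import json

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
model.eval()


def choice_logprob(document: str, choice: str) -> float:
    """Average log-probability of the tokens of `choice`, given `document`."""
    prompt_ids = tokenizer(document + " Summary:", return_tensors="pt").input_ids
    choice_ids = tokenizer(choice, add_special_tokens=False,
                           return_tensors="pt").input_ids
    input_ids = torch.cat([prompt_ids, choice_ids], dim=1)
    with torch.no_grad():
        logprobs = torch.log_softmax(model(input_ids).logits, dim=-1)
    # Logits at position i predict token i + 1, so the choice tokens at
    # positions [start, end) are predicted at positions [start - 1, end - 1).
    start = prompt_ids.shape[1]
    token_logprobs = logprobs[0, start - 1:-1].gather(
        1, choice_ids[0].unsqueeze(-1)).squeeze(-1)
    return token_logprobs.mean().item()


with open("multiple_choice-dataset/xsum/fib/"
          "binary_choice-using_bart-base_distractors.jsonl") as f:
    example = json.loads(f.readline())

scores = [choice_logprob(example["input"], c) for c in example["list_choices"]]
print("predicted:", scores.index(max(scores)), "gold:", example["lbl"])
```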
46 |
47 | ### Get Results
48 | The following command is used to gather multiple results and get the median score:
49 | ```
50 | python src/get_results.py -e {all_experiment_directories_of_datasets} -m {list_models}
51 | ```
52 |
53 | For example,
54 | ```
55 | python src/get_results.py -e exp_out/multiple_choice/xsum/fib/* -m bigscience-T0_3B
56 | ```
57 |
58 | ## Evaluating Models on FIB
59 |
60 | The differences between the FIB dataset released above and the evaluation here are:
61 | - Here, we take the median accuracy of the model across 3 prompts for each distractor model used, and then take a weighted average of the median accuracies across the different distractor models.
62 | - In the FIB dataset, we combine the examples from every distractor model, across both XSum and CNN/DM, into one file for simplicity. Users can use any prompt they want.
63 |
64 | The following commands run this evaluation:
65 | ```
66 | python src/evaluate_mulChoice.py -f multiple_choice-dataset/{dataset}/fib/binary_choice-* -m {model}
67 | python src/compute_fib_results.py -m {model} -d {dataset}
68 | ```
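
For intuition, here is a self-contained sketch of the aggregation described above, using made-up accuracy numbers; weighting each distractor model by its example count is an assumption, and `src/compute_fib_results.py` implements the actual computation.

```python
# Hypothetical numbers for illustration only.
from statistics import median

# distractor model -> (accuracy under each of the 3 prompts, # of examples)
results = {
    "bart-base": ([0.71, 0.74, 0.69], 500),
    "pegasus":   ([0.63, 0.66, 0.61], 450),
}

# Median across prompts, then a weighted average across distractor models.
medians = {name: median(accs) for name, (accs, _) in results.items()}
total = sum(n for _, n in results.values())
fib_score = sum(medians[name] * n / total for name, (_, n) in results.items())
print(f"FIB score: {fib_score:.3f}")
```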
69 |
70 |
71 |
72 | ## Other Binary Multiple-Choice Datasets
73 |
74 | The datasets are under ``multiple_choice-dataset/xsum`` and ``multiple_choice-dataset/cnn_dm`` for XSum and CNN/DM respectively.
75 |
76 | The different sources of alternative choices include:
77 | 1. FIB - our benchmark of factually inconsistent model-generated summaries
78 | 2. [FactCC](https://github.com/salesforce/factCC.git)
79 | 3. [MFMA](https://github.com/hwanheelee1993/MFMA)
80 | 4. FIR - factually inconsistent reference summaries (i.e., reference summaries from XSum or CNN/DM that were annotated as factually inconsistent)
81 | 5. Factually consistent model-generated summaries
82 |
83 | Each example is a JSON object with the following keys: `{id, input, correct_choice, list_choices, lbl}`
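
For example, a minimal loader that reads one of the files shipped in this repo and checks the relationship between these keys:

```python
import json

# In these files, `lbl` is the index of the factually consistent summary
# within `list_choices`, so `list_choices[lbl]` equals `correct_choice`.
path = ("multiple_choice-dataset/xsum/fib/"
        "binary_choice-using_bart-base_distractors.jsonl")
with open(path) as f:
    for line in f:
        ex = json.loads(line)
        assert ex["list_choices"][ex["lbl"]] == ex["correct_choice"]
        print(ex["id"], "->", len(ex["list_choices"]), "choices")
```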
84 |
85 | ## Citation ##
86 |
87 |
88 | If you find this repo helpful, please consider citing our work:
89 |
90 | ```
91 | @article{tam2022fib,
92 | title={Evaluating the Factual Consistency of Large Language Models Through Summarization},
93 | author={Tam, Derek and Mascarenhas, Anisha and Zhang, Shiyue and Kwan, Sarah and Bansal, Mohit and Raffel, Colin},
94 | journal={arXiv preprint arXiv:2211.08412},
95 | year={2022}
96 | }
97 | ```
98 |
99 | We use code from the following works:
100 |
101 | ```
102 | @inproceedings{kryscinski-etal-2020-evaluating,
103 | title = "Evaluating the Factual Consistency of Abstractive Text Summarization",
104 | author = "Kryscinski, Wojciech and
105 | McCann, Bryan and
106 | Xiong, Caiming and
107 | Socher, Richard",
108 | booktitle = "Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing (EMNLP)",
109 | month = nov,
110 | year = "2020",
111 | address = "Online",
112 | publisher = "Association for Computational Linguistics",
113 | url = "https://aclanthology.org/2020.emnlp-main.750",
114 | doi = "10.18653/v1/2020.emnlp-main.750",
115 | pages = "9332--9346",
116 | }
117 |
118 | @inproceedings{lee-etal-2022-masked,
119 | title = "Masked Summarization to Generate Factually Inconsistent Summaries for Improved Factual Consistency Checking",
120 | author = "Lee, Hwanhee and
121 | Yoo, Kang Min and
122 | Park, Joonsuk and
123 | Lee, Hwaran and
124 | Jung, Kyomin",
125 | booktitle = "Findings of the Association for Computational Linguistics: NAACL 2022",
126 | month = jul,
127 | year = "2022",
128 | address = "Seattle, United States",
129 | publisher = "Association for Computational Linguistics",
130 | url = "https://aclanthology.org/2022.findings-naacl.76",
131 | doi = "10.18653/v1/2022.findings-naacl.76",
132 | pages = "1019--1030",
133 | }
134 | ```
135 |
--------------------------------------------------------------------------------
/bin/download_model/bloom.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | wget "https://huggingface.co/bigscience/bloom/raw/main/config.json"
4 | wget "https://huggingface.co/bigscience/bloom/raw/main/pytorch_model.bin.index.json"
5 | wget "https://huggingface.co/bigscience/bloom/raw/main/special_tokens_map.json"
6 | wget "https://huggingface.co/bigscience/bloom/resolve/main/tokenizer.json"
7 | wget "https://huggingface.co/bigscience/bloom/raw/main/tokenizer_config.json"
8 |
9 |
10 | for i in {1..9}
11 | do
12 | echo "$i"
13 | wget "https://huggingface.co/bigscience/bloom/resolve/main/pytorch_model_0000$i-of-00072.bin"
14 | done
15 |
16 | for i in {10..72}
17 | do
18 | echo "$i"
19 | wget "https://huggingface.co/bigscience/bloom/resolve/main/pytorch_model_000$i-of-00072.bin"
20 | done
--------------------------------------------------------------------------------
/bin/download_model/gpt-neox-20b.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | for i in {1..9}
4 | do
5 | echo "$i"
6 | wget "https://huggingface.co/EleutherAI/gpt-neox-20b/resolve/main/pytorch_model-0000$i-of-00046.bin"
7 | done
8 |
9 | for i in {10..46}
10 | do
11 | echo "$i"
12 | wget "https://huggingface.co/EleutherAI/gpt-neox-20b/resolve/main/pytorch_model-000$i-of-00046.bin"
13 | done
--------------------------------------------------------------------------------
/bin/download_model/opt-30b.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | wget "https://huggingface.co/facebook/opt-30b/raw/main/config.json"
4 | wget "https://huggingface.co/facebook/opt-30b/raw/main/merges.txt"
5 | wget "https://huggingface.co/facebook/opt-30b/raw/main/pytorch_model.bin.index.json"
6 | wget "https://huggingface.co/facebook/opt-30b/raw/main/special_tokens_map.json"
7 | wget "https://huggingface.co/facebook/opt-30b/raw/main/tokenizer_config.json"
8 | wget "https://huggingface.co/facebook/opt-30b/raw/main/vocab.json"
9 |
10 |
11 | for i in {1..7}
12 | do
13 | echo "$i"
14 | wget "https://huggingface.co/facebook/opt-30b/resolve/main/pytorch_model-0000$i-of-00007.bin"
15 | done
16 |
--------------------------------------------------------------------------------
/bin/download_model/opt-66b.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | wget "https://huggingface.co/facebook/opt-66b/raw/main/config.json"
4 | wget "https://huggingface.co/facebook/opt-66b/raw/main/merges.txt"
5 | wget "https://huggingface.co/facebook/opt-66b/raw/main/pytorch_model.bin.index.json"
6 | wget "https://huggingface.co/facebook/opt-66b/raw/main/special_tokens_map.json"
7 | wget "https://huggingface.co/facebook/opt-66b/raw/main/tokenizer_config.json"
8 | wget "https://huggingface.co/facebook/opt-66b/raw/main/vocab.json"
9 |
10 |
11 | for i in {1..9}
12 | do
13 | echo "$i"
14 | wget "https://huggingface.co/facebook/opt-66b/resolve/main/pytorch_model-0000$i-of-00014.bin"
15 | done
16 |
17 | for i in {10..14}
18 | do
19 | echo "$i"
20 | wget "https://huggingface.co/facebook/opt-66b/resolve/main/pytorch_model-000$i-of-00014.bin"
21 | done
--------------------------------------------------------------------------------
/bin/setup.sh:
--------------------------------------------------------------------------------
1 | source env/bin/activate
2 | export LFQA_FAC_ROOT=`pwd`
3 | export PYTHONPATH=$LFQA_FAC_ROOT:$PYTHONPATH
4 | export PYTHON_EXEC=python
5 |
--------------------------------------------------------------------------------
/dataset_LICENSE:
--------------------------------------------------------------------------------
1 | Attribution 4.0 International
2 |
3 | =======================================================================
4 |
5 | Creative Commons Corporation ("Creative Commons") is not a law firm and
6 | does not provide legal services or legal advice. Distribution of
7 | Creative Commons public licenses does not create a lawyer-client or
8 | other relationship. Creative Commons makes its licenses and related
9 | information available on an "as-is" basis. Creative Commons gives no
10 | warranties regarding its licenses, any material licensed under their
11 | terms and conditions, or any related information. Creative Commons
12 | disclaims all liability for damages resulting from their use to the
13 | fullest extent possible.
14 |
15 | Using Creative Commons Public Licenses
16 |
17 | Creative Commons public licenses provide a standard set of terms and
18 | conditions that creators and other rights holders may use to share
19 | original works of authorship and other material subject to copyright
20 | and certain other rights specified in the public license below. The
21 | following considerations are for informational purposes only, are not
22 | exhaustive, and do not form part of our licenses.
23 |
24 | Considerations for licensors: Our public licenses are
25 | intended for use by those authorized to give the public
26 | permission to use material in ways otherwise restricted by
27 | copyright and certain other rights. Our licenses are
28 | irrevocable. Licensors should read and understand the terms
29 | and conditions of the license they choose before applying it.
30 | Licensors should also secure all rights necessary before
31 | applying our licenses so that the public can reuse the
32 | material as expected. Licensors should clearly mark any
33 | material not subject to the license. This includes other CC-
34 | licensed material, or material used under an exception or
35 | limitation to copyright. More considerations for licensors:
36 | wiki.creativecommons.org/Considerations_for_licensors
37 |
38 | Considerations for the public: By using one of our public
39 | licenses, a licensor grants the public permission to use the
40 | licensed material under specified terms and conditions. If
41 | the licensor's permission is not necessary for any reason--for
42 | example, because of any applicable exception or limitation to
43 | copyright--then that use is not regulated by the license. Our
44 | licenses grant only permissions under copyright and certain
45 | other rights that a licensor has authority to grant. Use of
46 | the licensed material may still be restricted for other
47 | reasons, including because others have copyright or other
48 | rights in the material. A licensor may make special requests,
49 | such as asking that all changes be marked or described.
50 | Although not required by our licenses, you are encouraged to
51 | respect those requests where reasonable. More considerations
52 | for the public:
53 | wiki.creativecommons.org/Considerations_for_licensees
54 |
55 | =======================================================================
56 |
57 | Creative Commons Attribution 4.0 International Public License
58 |
59 | By exercising the Licensed Rights (defined below), You accept and agree
60 | to be bound by the terms and conditions of this Creative Commons
61 | Attribution 4.0 International Public License ("Public License"). To the
62 | extent this Public License may be interpreted as a contract, You are
63 | granted the Licensed Rights in consideration of Your acceptance of
64 | these terms and conditions, and the Licensor grants You such rights in
65 | consideration of benefits the Licensor receives from making the
66 | Licensed Material available under these terms and conditions.
67 |
68 |
69 | Section 1 -- Definitions.
70 |
71 | a. Adapted Material means material subject to Copyright and Similar
72 | Rights that is derived from or based upon the Licensed Material
73 | and in which the Licensed Material is translated, altered,
74 | arranged, transformed, or otherwise modified in a manner requiring
75 | permission under the Copyright and Similar Rights held by the
76 | Licensor. For purposes of this Public License, where the Licensed
77 | Material is a musical work, performance, or sound recording,
78 | Adapted Material is always produced where the Licensed Material is
79 | synched in timed relation with a moving image.
80 |
81 | b. Adapter's License means the license You apply to Your Copyright
82 | and Similar Rights in Your contributions to Adapted Material in
83 | accordance with the terms and conditions of this Public License.
84 |
85 | c. Copyright and Similar Rights means copyright and/or similar rights
86 | closely related to copyright including, without limitation,
87 | performance, broadcast, sound recording, and Sui Generis Database
88 | Rights, without regard to how the rights are labeled or
89 | categorized. For purposes of this Public License, the rights
90 | specified in Section 2(b)(1)-(2) are not Copyright and Similar
91 | Rights.
92 |
93 | d. Effective Technological Measures means those measures that, in the
94 | absence of proper authority, may not be circumvented under laws
95 | fulfilling obligations under Article 11 of the WIPO Copyright
96 | Treaty adopted on December 20, 1996, and/or similar international
97 | agreements.
98 |
99 | e. Exceptions and Limitations means fair use, fair dealing, and/or
100 | any other exception or limitation to Copyright and Similar Rights
101 | that applies to Your use of the Licensed Material.
102 |
103 | f. Licensed Material means the artistic or literary work, database,
104 | or other material to which the Licensor applied this Public
105 | License.
106 |
107 | g. Licensed Rights means the rights granted to You subject to the
108 | terms and conditions of this Public License, which are limited to
109 | all Copyright and Similar Rights that apply to Your use of the
110 | Licensed Material and that the Licensor has authority to license.
111 |
112 | h. Licensor means the individual(s) or entity(ies) granting rights
113 | under this Public License.
114 |
115 | i. Share means to provide material to the public by any means or
116 | process that requires permission under the Licensed Rights, such
117 | as reproduction, public display, public performance, distribution,
118 | dissemination, communication, or importation, and to make material
119 | available to the public including in ways that members of the
120 | public may access the material from a place and at a time
121 | individually chosen by them.
122 |
123 | j. Sui Generis Database Rights means rights other than copyright
124 | resulting from Directive 96/9/EC of the European Parliament and of
125 | the Council of 11 March 1996 on the legal protection of databases,
126 | as amended and/or succeeded, as well as other essentially
127 | equivalent rights anywhere in the world.
128 |
129 | k. You means the individual or entity exercising the Licensed Rights
130 | under this Public License. Your has a corresponding meaning.
131 |
132 |
133 | Section 2 -- Scope.
134 |
135 | a. License grant.
136 |
137 | 1. Subject to the terms and conditions of this Public License,
138 | the Licensor hereby grants You a worldwide, royalty-free,
139 | non-sublicensable, non-exclusive, irrevocable license to
140 | exercise the Licensed Rights in the Licensed Material to:
141 |
142 | a. reproduce and Share the Licensed Material, in whole or
143 | in part; and
144 |
145 | b. produce, reproduce, and Share Adapted Material.
146 |
147 | 2. Exceptions and Limitations. For the avoidance of doubt, where
148 | Exceptions and Limitations apply to Your use, this Public
149 | License does not apply, and You do not need to comply with
150 | its terms and conditions.
151 |
152 | 3. Term. The term of this Public License is specified in Section
153 | 6(a).
154 |
155 | 4. Media and formats; technical modifications allowed. The
156 | Licensor authorizes You to exercise the Licensed Rights in
157 | all media and formats whether now known or hereafter created,
158 | and to make technical modifications necessary to do so. The
159 | Licensor waives and/or agrees not to assert any right or
160 | authority to forbid You from making technical modifications
161 | necessary to exercise the Licensed Rights, including
162 | technical modifications necessary to circumvent Effective
163 | Technological Measures. For purposes of this Public License,
164 | simply making modifications authorized by this Section 2(a)
165 | (4) never produces Adapted Material.
166 |
167 | 5. Downstream recipients.
168 |
169 | a. Offer from the Licensor -- Licensed Material. Every
170 | recipient of the Licensed Material automatically
171 | receives an offer from the Licensor to exercise the
172 | Licensed Rights under the terms and conditions of this
173 | Public License.
174 |
175 | b. No downstream restrictions. You may not offer or impose
176 | any additional or different terms or conditions on, or
177 | apply any Effective Technological Measures to, the
178 | Licensed Material if doing so restricts exercise of the
179 | Licensed Rights by any recipient of the Licensed
180 | Material.
181 |
182 | 6. No endorsement. Nothing in this Public License constitutes or
183 | may be construed as permission to assert or imply that You
184 | are, or that Your use of the Licensed Material is, connected
185 | with, or sponsored, endorsed, or granted official status by,
186 | the Licensor or others designated to receive attribution as
187 | provided in Section 3(a)(1)(A)(i).
188 |
189 | b. Other rights.
190 |
191 | 1. Moral rights, such as the right of integrity, are not
192 | licensed under this Public License, nor are publicity,
193 | privacy, and/or other similar personality rights; however, to
194 | the extent possible, the Licensor waives and/or agrees not to
195 | assert any such rights held by the Licensor to the limited
196 | extent necessary to allow You to exercise the Licensed
197 | Rights, but not otherwise.
198 |
199 | 2. Patent and trademark rights are not licensed under this
200 | Public License.
201 |
202 | 3. To the extent possible, the Licensor waives any right to
203 | collect royalties from You for the exercise of the Licensed
204 | Rights, whether directly or through a collecting society
205 | under any voluntary or waivable statutory or compulsory
206 | licensing scheme. In all other cases the Licensor expressly
207 | reserves any right to collect such royalties.
208 |
209 |
210 | Section 3 -- License Conditions.
211 |
212 | Your exercise of the Licensed Rights is expressly made subject to the
213 | following conditions.
214 |
215 | a. Attribution.
216 |
217 | 1. If You Share the Licensed Material (including in modified
218 | form), You must:
219 |
220 | a. retain the following if it is supplied by the Licensor
221 | with the Licensed Material:
222 |
223 | i. identification of the creator(s) of the Licensed
224 | Material and any others designated to receive
225 | attribution, in any reasonable manner requested by
226 | the Licensor (including by pseudonym if
227 | designated);
228 |
229 | ii. a copyright notice;
230 |
231 | iii. a notice that refers to this Public License;
232 |
233 | iv. a notice that refers to the disclaimer of
234 | warranties;
235 |
236 | v. a URI or hyperlink to the Licensed Material to the
237 | extent reasonably practicable;
238 |
239 | b. indicate if You modified the Licensed Material and
240 | retain an indication of any previous modifications; and
241 |
242 | c. indicate the Licensed Material is licensed under this
243 | Public License, and include the text of, or the URI or
244 | hyperlink to, this Public License.
245 |
246 | 2. You may satisfy the conditions in Section 3(a)(1) in any
247 | reasonable manner based on the medium, means, and context in
248 | which You Share the Licensed Material. For example, it may be
249 | reasonable to satisfy the conditions by providing a URI or
250 | hyperlink to a resource that includes the required
251 | information.
252 |
253 | 3. If requested by the Licensor, You must remove any of the
254 | information required by Section 3(a)(1)(A) to the extent
255 | reasonably practicable.
256 |
257 | 4. If You Share Adapted Material You produce, the Adapter's
258 | License You apply must not prevent recipients of the Adapted
259 | Material from complying with this Public License.
260 |
261 |
262 | Section 4 -- Sui Generis Database Rights.
263 |
264 | Where the Licensed Rights include Sui Generis Database Rights that
265 | apply to Your use of the Licensed Material:
266 |
267 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right
268 | to extract, reuse, reproduce, and Share all or a substantial
269 | portion of the contents of the database;
270 |
271 | b. if You include all or a substantial portion of the database
272 | contents in a database in which You have Sui Generis Database
273 | Rights, then the database in which You have Sui Generis Database
274 | Rights (but not its individual contents) is Adapted Material; and
275 |
276 | c. You must comply with the conditions in Section 3(a) if You Share
277 | all or a substantial portion of the contents of the database.
278 |
279 | For the avoidance of doubt, this Section 4 supplements and does not
280 | replace Your obligations under this Public License where the Licensed
281 | Rights include other Copyright and Similar Rights.
282 |
283 |
284 | Section 5 -- Disclaimer of Warranties and Limitation of Liability.
285 |
286 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
287 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
288 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
289 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
290 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
291 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
292 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
293 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
294 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
295 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
296 |
297 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
298 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
299 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
300 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
301 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
302 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
303 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
304 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
305 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
306 |
307 | c. The disclaimer of warranties and limitation of liability provided
308 | above shall be interpreted in a manner that, to the extent
309 | possible, most closely approximates an absolute disclaimer and
310 | waiver of all liability.
311 |
312 |
313 | Section 6 -- Term and Termination.
314 |
315 | a. This Public License applies for the term of the Copyright and
316 | Similar Rights licensed here. However, if You fail to comply with
317 | this Public License, then Your rights under this Public License
318 | terminate automatically.
319 |
320 | b. Where Your right to use the Licensed Material has terminated under
321 | Section 6(a), it reinstates:
322 |
323 | 1. automatically as of the date the violation is cured, provided
324 | it is cured within 30 days of Your discovery of the
325 | violation; or
326 |
327 | 2. upon express reinstatement by the Licensor.
328 |
329 | For the avoidance of doubt, this Section 6(b) does not affect any
330 | right the Licensor may have to seek remedies for Your violations
331 | of this Public License.
332 |
333 | c. For the avoidance of doubt, the Licensor may also offer the
334 | Licensed Material under separate terms or conditions or stop
335 | distributing the Licensed Material at any time; however, doing so
336 | will not terminate this Public License.
337 |
338 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
339 | License.
340 |
341 |
342 | Section 7 -- Other Terms and Conditions.
343 |
344 | a. The Licensor shall not be bound by any additional or different
345 | terms or conditions communicated by You unless expressly agreed.
346 |
347 | b. Any arrangements, understandings, or agreements regarding the
348 | Licensed Material not stated herein are separate from and
349 | independent of the terms and conditions of this Public License.
350 |
351 |
352 | Section 8 -- Interpretation.
353 |
354 | a. For the avoidance of doubt, this Public License does not, and
355 | shall not be interpreted to, reduce, limit, restrict, or impose
356 | conditions on any use of the Licensed Material that could lawfully
357 | be made without permission under this Public License.
358 |
359 | b. To the extent possible, if any provision of this Public License is
360 | deemed unenforceable, it shall be automatically reformed to the
361 | minimum extent necessary to make it enforceable. If the provision
362 | cannot be reformed, it shall be severed from this Public License
363 | without affecting the enforceability of the remaining terms and
364 | conditions.
365 |
366 | c. No term or condition of this Public License will be waived and no
367 | failure to comply consented to unless expressly agreed to by the
368 | Licensor.
369 |
370 | d. Nothing in this Public License constitutes or may be interpreted
371 | as a limitation upon, or waiver of, any privileges and immunities
372 | that apply to the Licensor or You, including from the legal
373 | processes of any jurisdiction or authority.
374 |
375 |
376 | =======================================================================
377 |
378 | Creative Commons is not a party to its public
379 | licenses. Notwithstanding, Creative Commons may elect to apply one of
380 | its public licenses to material it publishes and in those instances
381 | will be considered the “Licensor.” The text of the Creative Commons
382 | public licenses is dedicated to the public domain under the CC0 Public
383 | Domain Dedication. Except for the limited purpose of indicating that
384 | material is shared under a Creative Commons public license or as
385 | otherwise permitted by the Creative Commons policies published at
386 | creativecommons.org/policies, Creative Commons does not authorize the
387 | use of the trademark "Creative Commons" or any other trademark or logo
388 | of Creative Commons without its prior written consent including,
389 | without limitation, in connection with any unauthorized modifications
390 | to any of its public licenses or any other arrangements,
391 | understandings, or agreements concerning use of licensed material. For
392 | the avoidance of doubt, this paragraph does not form part of the
393 | public licenses.
394 |
395 | Creative Commons may be contacted at creativecommons.org.
396 |
--------------------------------------------------------------------------------
/img/intro.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/r-three/fib/1d48ee2e52ac3f8ded69f4593db255ae2ba12200/img/intro.png
--------------------------------------------------------------------------------
/multiple_choice-dataset/cnn_dm/fib/binary_choice-using_lead3_distractors.jsonl:
--------------------------------------------------------------------------------
1 | {"id": "f632a7540e8641033ce0ff316bee6fe38df6a972", "input": "( cnn ) the last time muhammadu buhari came to power in nigeria , it was by force . this time it was by the ballot box . here are five reasons why one of the most fiercely-contested elections in the country 's history is so important . for the first time in nigeria 's history , the opposition defeated the ruling party in democratic elections . muhammadu buhari , 72 , won nigeria 's presidential election , defeating incumbent goodluck jonathan by about two million votes . nigeria is significant because it is the biggest economy and most populous country in africa ; it is also one of africa 's largest oil producers and is a major supplier of crude oil to the united states . this is n't buhari 's first time leading nigeria -- but it 's his first time in nearly 30 years . the reformed dictator is a sunni muslim from nigeria 's poorer north , while jonathan comes from a christian and animist south that is rich with oil . buhari 's win comes after a long history of military rule , coups and botched attempts at democracy in the country . many nigerians told cnn that they saw president jonathan as an ineffectual leader who was indecisive in dealing with the terror group boko haram -- and weak on corruption . buhari , who was campaigning for the fourth time , capitalized on these weaknesses and some analysts believe that his military background was an advantage for him . nigerians wanted a strong leader who could keep them safe from boko haram 's murderous raids -- and buhari also campaigned as a born-again democrat to allay fears about his strict military regime the last time around . he stressed that nigeria 's security needs to be the next government 's focus . his campaign was also fiercely anti-corruption -- he ran under the slogan of `` new broom , '' and his supporters were often pictured holding brooms in the lead-up to the vote . the elections were largely predicted to be violent and everyone , nigerians included , expected the worst . some families moved abroad and there was sporadic violence across the country in the lead up to the election . but those fears turned out to be mostly unfounded , and the elections held relatively peacefully -- with the exception of attacks in the north of the country , where around 11 people died . many also praised president jonathan 's gracious and quick concession of defeat as it almost certainly prevented post-election violence . president-elect buhari said wednesday in a speech to the nation : `` the eyes of the world were focused on us to see if we can vote in a peaceful way and carry out elections in an orderly manner . `` we have proven to the world that we are a people who have embraced democracy and a people who seek a government by , for and for the people . '' on election day , nigerians queued for hours in hot weather to cast their vote . some of the biometric reader machines malfunctioned -- including the one at president jonathan 's polling station -- and voting had to be extended into the following day . but the technical issues did n't keep people from voting -- and in lagos , some voters cast their ballots with the aid of the light from their mobile phones . and even though some card readers did n't work in some places , many say they helped to cut down on vote rigging . boko haram is n't the only obstacle facing the new president . the economy , crime and generating enough power to light up the country are other major issues . 
the pressure will soon be on buhari to deliver and there will be no excuses . if he fails , nigerians will be waiting for him at the polls just four short years from now .", "correct_choice": " muhammadu buhari 's win marks the first democratic transition of power from a ruling party to the opposition . nigeria , the most populous country in africa , is grappling with violent boko haram extremists . ", "list_choices": [" ( cnn ) the last time muhammadu buhari came to power in nigeria , it was by force . this time it was by the ballot box . here are five reasons why one of the most fiercely-contested elections in the country 's history is so important . ", " muhammadu buhari 's win marks the first democratic transition of power from a ruling party to the opposition . nigeria , the most populous country in africa , is grappling with violent boko haram extremists . "], "lbl": 1}
2 | {"id": "1a04fa5c48398379b500e1336c81decb2e9f82e1", "input": "( cnn ) it was all set for a fairytale ending for record breaking jockey ap mccoy . in the end it was a different but familiar name who won the grand national on saturday . 25-1 outsider many clouds , who had shown little form going into the race , won by a length and a half , ridden by jockey leighton aspell . aspell won last year 's grand national too , making him the first jockey since the 1950s to ride back-to-back winners on different horses . `` it feels wonderful , i asked big questions , '' aspell said of many clouds , moments after his victory . `` over the fences he was awesome . i was just hoping his batteries would last and they did , '' he added . no fairytale . yet for much of the grand national -- arguably the world 's most famous and certainly the sport 's most prestigious jump race -- it looked as if ap mccoy was about to write an ending befitting the career of a man who has dominated jump racing for two decades . his horse shutthefrontdoor was in the leading group as it negotiated the likes for becher 's brooke and the chair , some of the toughest jumps in racing . last week the 40-year-old ulsterman , who has won an astonishing 4,356 races , announced he would retire if he won the grand national for the second time in his career . shutthefrontdoor was heavily backed by the betting public sensing a storybook conclusion to mccoy 's career . uk and irish betting firms even predicted they would lose as much as $ 73 million if mccoy won . he was well placed going into the final straight but just could n't keep up after many clouds cut lose , and finished back in fifth . third time winner . but for trevor hemmings , the owner of many clouds , it was his third victory in the grand national . `` i always dreamed of winning my first national , '' a shocked hemmings told channel 4 . `` then along comes a second . that 's special . and when a third comes along , it 's such a wonderful , wonderful feeling . '' hemming went on to praise aspell 's performance . `` this morning talking we talked about the achievers , '' said hemmings . `` they are quiet , confident and experienced . he has all of them . '' mccoy 's fifth placed finish means he will race again at least once more , in two weeks time at sandown .", "correct_choice": " 25-1 shot many clouds wins grand national . second win a row for jockey leighton aspell . first jockey to win two in a row on different horses since 1950s . ", "list_choices": [" 25-1 shot many clouds wins grand national . second win a row for jockey leighton aspell . first jockey to win two in a row on different horses since 1950s . ", " ( cnn ) it was all set for a fairytale ending for record breaking jockey ap mccoy . in the end it was a different but familiar name who won the grand national on saturday . 25-1 outsider many clouds , who had shown little form going into the race , won by a length and a half , ridden by jockey leighton aspell . "], "lbl": 0}
3 | {"id": "37e94bbc8fbced5a08a1d869f8b752b43c605680", "input": "( cnn ) call it a little piece of heaven for a family torn apart by tragedy . back in july , sierra sharry and lane smith were just about to become parents . sharry was eight months pregnant . but then smith fell and hit his head . he was taken to the ou medical center in oklahoma city . smith never recovered . `` july 13th 2014 was the absolute worst day of my life , '' sharry posted on facebook . `` i lost my best friend . the father of my unborn child . '' their son taos arrived a few weeks later . when it was time for his 6-month pictures , sharry had a special request . maybe the photographer could make their family complete , just for one picture . `` they asked me if i would be willing to ' play around ' with capturing their first family photo by editing taos ' daddy in one of their pictures , '' kayli rene ' photography posted on facebook . `` i just got to thinking , we do n't have a picture with lane in it , '' the new mom told cnn affilaite koco . the photographer was n't sure it would work , but they found just the right picture of smith -- one that has him looking over his family 's shoulder . `` lane 's not physically here with us , of course , but that picture represents to us that he is always watching over us and he will always be there for us no matter what , '' sharry said . the family photo has become a social media sensation after appearing on the photographer 's facebook page this week . it has some 193,000 likes and more than 24,000 shares . `` i ca n't believe she actually did this , '' sharry said . `` it 's like amazing and apparently everyone else thinks it is too . ''", "correct_choice": " sierra sharry was eight months pregnant when her son 's father died . a photographer was able to add lane smith to the family photo . ", "list_choices": [" ( cnn ) call it a little piece of heaven for a family torn apart by tragedy . back in july , sierra sharry and lane smith were just about to become parents . sharry was eight months pregnant . ", " sierra sharry was eight months pregnant when her son 's father died . a photographer was able to add lane smith to the family photo . "], "lbl": 1}
4 | {"id": "9383c2b3fd48fd3445bdfa3adf26485c1126c356", "input": "( cnn ) it did n't seem like a fair fight . on one side were hulking football players and pro wrestlers , competing as teams of two to eat as many pounds of steak as they could , combined , in one hour . on another was a lone 124-pound mother of four . and sure enough , in the end , sunday 's contest at big texan steak ranch in amarillo , texas , was n't even close . molly schuyler scarfed down three 72-ounce steaks , three baked potatoes , three side salads , three rolls and three shrimp cocktails -- far outpacing her heftier rivals . that 's more than 13 pounds of steak , not counting the sides . and she did it all in 20 minutes , setting a record in the process . `` we 've been doing this contest since 1960 , and in all that time we 've never had anybody come in to actually eat that many steaks at one time , '' bobby lee , who co-owns the big texan , told cnn affiliate kvii . `` so this is a first for us , and after 55 years of it , it 's a big deal . '' in fairness , schuyler is n't your typical 124-pound person . the nebraska native , 35 , is a professional on the competitive-eating circuit and once gobbled 363 chicken wings in 30 minutes . wearing shades and a black hoodie , schuyler beat four other teams on sunday , including pairs of football players and pro wrestlers and two married competitive eaters . she also broke her own big texan record of two 72-ounce steaks and sides , set last year , when she bested previous record-holder joey `` jaws '' chestnut . the landmark big texan restaurant offers its `` 72-ounce challenge '' daily to anyone who can eat the massive steak , plus fixings , in under an hour . those who ca n't do so must pay $ 72 for the meal . schuyler , who now lives in sacramento , california , won $ 5,000 for her efforts . her feat will be submitted to guinness world records . but mostly , she just seemed pleased to enjoy a hearty meal on the house . `` it 's free , so i 'm pretty happy about that , '' she told kvii . `` otherwise it would have cost me about 300 bucks . ''", "correct_choice": " molly schuyler scarfed down three 72-ounce steaks sunday in amarillo , texas . the sacramento woman , 35 , is a professional on the competitive-eating circuit . ", "list_choices": [" ( cnn ) it did n't seem like a fair fight . on one side were hulking football players and pro wrestlers , competing as teams of two to eat as many pounds of steak as they could , combined , in one hour . on another was a lone 124-pound mother of four . ", " molly schuyler scarfed down three 72-ounce steaks sunday in amarillo , texas . the sacramento woman , 35 , is a professional on the competitive-eating circuit . "], "lbl": 1}
5 | {"id": "1b02f51e4c599cdc6acf88b41f4c7ff0bc02eafa", "input": "( cnn ) people magazine has anointed sandra bullock the world 's most beautiful woman of 2015 , the publication revealed on wednesday . bullock , 50 , joins a long line of actresses to receive the honor , including last year 's cover girl , lupita nyong'o , and gwyneth paltrow in 2013 . she seems to be taking it all in stride , calling the whole thing `` ridiculous . '' `` real beauty is quiet . especially in this town , it 's just so hard not to say , ' oh , i need to look like that , ' `` she told people . `` no , be a good person ; be a good mom ; do a good job with the lunch ; let someone cut in front of you who looks like they 're in a bigger hurry . the people i find most beautiful are the ones who are n't trying . '' the cover story focuses on bullock 's home life with her son , louis , 5 , and her efforts to stay healthy and fit past her 40s . `` i was putting him to bed and told him that even when i 'm old and gray and more wrinkly than i am now , i 'll still love him and want to tuck him in , '' she said . `` and he asked why i have wrinkles , and i said , ' well , i hope some of them are from laughing so much . ' and he touched my face and said , ' you 're not old , you 're just happy . ' `` the oscar-winning star of movies including `` gravity , '' `` the blind side '' and `` crash '' said she 's happy with who she is . `` as long as i 'm healthy and strong and i do n't let this mind of mine run amok with insecurities about what i am not , i can look in the mirror and like who i see . '' the selection of bullock , the oldest woman to receive top honors in the history of the list , is a sign that beauty knows no age , say some . `` great choice ! gorgeous , talented , over 50 and fabulous ! that 's the way it 's done ! '' wrote one fan on people 's facebook page . also making the `` most beautiful '' cut this year : gabrielle union , ariana grande and laverne cox . the issue hits newsstands friday .", "correct_choice": " people magazine has named actress sandra bullock the most beautiful woman in the world . `` be a good person ; be a good mom ; do a good job with the lunch , '' she says . ", "list_choices": [" ( cnn ) people magazine has anointed sandra bullock the world 's most beautiful woman of 2015 , the publication revealed on wednesday . bullock , 50 , joins a long line of actresses to receive the honor , including last year 's cover girl , lupita nyong'o , and gwyneth paltrow in 2013 . she seems to be taking it all in stride , calling the whole thing `` ridiculous . '' ", " people magazine has named actress sandra bullock the most beautiful woman in the world . `` be a good person ; be a good mom ; do a good job with the lunch , '' she says . "], "lbl": 1}
6 |
--------------------------------------------------------------------------------
/multiple_choice-dataset/xsum/factually_consistent-model_generated/binary_choice-using_t5-large_distractors.jsonl:
--------------------------------------------------------------------------------
1 | {"id": "38279745", "input": "Dozens of others were injured in the explosion early on Sunday.\nThe bomber struck at the entrance of the city's main port facilities. Residents say the blast could be heard across Mogadishu.\nNo group has said it carried out the attack, but the Somali Islamist group al-Shabab often carries out such bombings in the capital.\n\"We assisted 48 wounded people and carried 16 others who were killed in the blast,\" said Abdikadir Abdirahman Adem, head Mogadishu's Amin ambulance service.\nThe death toll is expected to rise further.", "correct_choice": "A bomber has killed at least 16 people in the city, officials say.", "list_choices": ["A bomber has killed at least 16 people in the city, officials say.", "A bomb in a major port in the capital of Mogadishu, which killed at least a dozen people, has killed at least a dozen people."], "lbl": 0}
2 | {"id": "34527912", "input": "Broken swords and spearheads were found by archaeologists on the RSPB Scotland nature reserve.\nTwelve pieces excavated from several different weapons have been handed over to Kilmartin Museum in Argyll.\nRSPB Scotland reserves archaeologist Jill Harden said they had probably been deliberately broken and thrown into a loch as part of a religious ceremony.\n\"This is the first discovery of this size from Argyll for many years,\" she said.\n\"The items were recovered from what had once been a freshwater loch - it seems that they had been purposely broken and cast into the waters as part of a ceremony, most likely as offerings or gifts to the gods or goddesses of the time.\n\"It is recorded that bronze swords were found on Coll in the 19th Century during drainage works, but their whereabouts today are unknown.\"\nThe archaeological investigation was directed by the Treasure Trove Unit, National Museums Scotland and RSPB Scotland.\nTrevor Cowie, from National Museums Scotland's department of Scottish history and archaeology, said: \"While a fair number of objects from this period have been discovered in the west of Scotland in the past, we generally know very little about the precise places where they were found.\n\"Archaeological techniques have developed dramatically since those 19th Century discoveries were made, so we have a great opportunity here to resolve many unanswered questions about life on Coll some 3,000 years ago.\"\nThe weapons can be viewed at the the Isle of Coll's An Cridhe community centre on Thursday and Friday.", "correct_choice": "Bronze weapons have been discovered on a Scotland nature reserve.", "list_choices": ["A collection of bronze swords and spearheads have been found in a loch in Argyll.", "Bronze weapons have been discovered on a Scotland nature reserve."], "lbl": 1}
3 | {"id": "35217021", "input": "Zabair Hussain, 41, was discovered with multiple injuries to his head and body in Staniforth Road, Darnall, Sheffield, at about 23:20 GMT. He later died at the scene.\nThe 28-year-old arrested man has been taken into police custody.\nOfficers believe a number of men were involved in an assault and have appealed for witnesses to come forward.\nDet Ch Insp Steve Handley, from South Yorkshire Police, said: \"We are still in the very early stages of the investigation and we're carrying out numerous enquiries to get to the bottom of what happened - from reviewing CCTV footage to speaking to potential witnesses.\n\"While I understand that incidents like this are worrying for those living locally, we have increased patrols by neighbourhood officers to reassure residents.\"", "correct_choice": "A man has been arrested after another man's body was found in a street.", "list_choices": ["A man has been arrested after another man's body was found in a street.", "A man has been arrested in connection with the death of a man in Sheffield."], "lbl": 0}
4 | {"id": "34971770", "input": "Official numbers showed revenues down 32.2% for the period to 16.4bn Macau patacas ($2.05bn; \u00c2\u00a31.36bn).\nExpectations were for a fall in revenues of just over 31%.\nMacau is the world's largest gaming centre - ahead of Las Vegas - and the only place in China where casinos are allowed.\nA special administrative region of China, Macau's economy relies heavily on gambling and shopping - especially by big spending tourists from the mainland.\nBut Chinese President Xi Jinping's campaign against corruption and luxury spending, which began in December 2012, has seen officials and others from the mainland more wary of gaming and spending in the city.\nChina's Communist Party prohibits officials from gambling, but until the 2012 crackdown, officials had reportedly visited Macau's casinos to gamble and spend.\nChina has emphasised Macau's need to diversify its economy away from gambling. The city's build up of new resorts and hotels is expected to help drive general tourism, however, analysts have said Macau will be hard-pressed to build up non-gaming streams of revenue in the near future.\nOfficial numbers released on Monday showed the city's economy shrank by 24.2% year-on-year during three months to September, the city's Statistics and Census Service said.\n\"Economic contraction in the third quarter was attributable to the continuous decline in exports of services, of which exports of gaming services decreased by 37.4% year-on-year and exports of other tourism services dropped by 15.3%,\" it added.\nOnce a Portuguese colony, gaming has taken place in Macau for more than 300 years. For many years it was referred to as the Monte Carlo of the Orient. The city was returned to Chinese rule in 1999.", "correct_choice": "Revenue in Macau fell by more than a third as China's corruption crackdown continued to drive away some tourists.", "list_choices": ["Revenue in Macau fell by more than a third as China's corruption crackdown continued to drive away some tourists.", "The world's biggest gambling centre, Macau, has reported a fall in revenues in the third quarter of the year."], "lbl": 0}
5 | {"id": "36348210", "input": "The 51-year-old had been negotiating a release from his contract following a rift with the board over his budget.\nHughes has been with the Highlanders since December 2013 and won the Scottish Cup last year, the club's first major honour.\n\"John will be remembered as a member of a great winning team,\" read a brief statement from Inverness CT.\nHughes had become increasingly frustrated at the loss of key squad members and spoke of his disappointment when an approach from Dundee United was blocked earlier this season.\nHaving previously managed at Falkirk, Hibernian, Hartlepool and Livingston, he replaced Terry Butcher at the Caledonian Stadium.\nAs well as lifting the Scottish Cup, Hughes steered Inverness to a third place finish in the Premiership last season, with this campaign opening with their first taste of European football.\nIn March 2014, Inverness reached the League Cup final, losing on penalties to Aberdeen.\nThe Inverness statement contained a message on behalf of Hughes, saying: \"I will look back on my time in the Highlands with a genuine fondness and warm affection for the club, the area and the community.\n\"The welcome I received from the fans and the response I got from the players throughout my two-and-a-half years there will live long in the memory as will everything else we shared in some of the ground-breaking successes we all enjoyed together during that period.\n\"I can readily assure my successor that they will inherit an excellent group of players and to each and every one of them could I also say a huge thanks for making my time with them so successful and so memorable - I wish them and the club every success in the future.\"", "correct_choice": "Inverness have confirmed the departure of John Hughes.", "list_choices": ["John Hughes has been named as Inverness CT manager after being sacked by Inverness CT.", "Inverness have confirmed the departure of John Hughes."], "lbl": 1}
6 | {"id": "31920680", "input": "The CQC previously rated the Penberthy home in Newquay as inadequate.\nNew reports highlight problems at three other homes run by Cornwall Care: Headlands in Carbis Bay, Trevern in Falmouth and Blackwood in Camborne.\nCornwall Care said it was rare for an inspection not to point out areas for improvement.\nThe CQC said Headlands was \"unsafe\" and overall \"was not caring\".\nAt Trevern \"one person had not been able to have a bath or shower for eleven months due to the home not obtaining the appropriate bathing equipment to meet the person's needs,\" the report stated.\nAction was also needed to address the \"care and welfare of people who use services\" and the \"safety and suitability of premises,\" it was claimed.\nThe report on Blackwood said \"people did not always have access to meaningful activities\" and action was needed regarding the \"safety and suitability of premises\".\nDue to changes in CQC reporting procedures the reports did not give an overall rating as it has done for Penberthy.\nAdrian Hughes, the commission's deputy chief inspector of adult social care, said there had been \"slippage\" in services provided by Cornwall Care.\nHe said: \"They have taken their eye off the ball in some aspects of that care.\"\nA spokesman for Cornwall Care said: \"We have worked closely with CQC and commissioners for many years and it is rare that an inspection of any care service does not point out areas for improvement.\n\"We welcome that feedback and always act quickly to make sure we are offering the best possible service to our clients.\"", "correct_choice": "Action is needed at homes for the elderly run by Cornwall Care, after the company took its \"eye off the ball\", the CQC said.", "list_choices": ["The CQC has rated a home in Penberthy as inadequate, according to a report.", "Action is needed at homes for the elderly run by Cornwall Care, after the company took its \"eye off the ball\", the CQC said."], "lbl": 1}
7 | {"id": "14370062", "input": "Staff in Jobcentres, banks, building societies and utility companies in England could also be trained to spot - and counsel - vulnerable people.\nThe ideas are raised in a consultation paper on suicide prevention.\nThe Samaritans said councils should have a mandatory responsibility to try to prevent suicides in their areas.\nSome 4,400 people killed themselves in England in 2009.\nClaire Wylie, head of policy and research at the Samaritans, told the BBC News website that many suicide attempts were made on impulse, so trying to restrict access to potentially lethal means was important.\n\"We know that people who are feeling suicidal are often very ambivalent about actually ending their lives,\" she said.\n\"If you can interrupt them at that moment you can prevent them going ahead.\"\nPreventing deaths by jumping is a key aim of the consultation and it suggests a number of ways of doing that.\nThey include:\nOverall, the number of suicides has steadily fallen in recent years, but the number of deaths on Britain's rail network had been rising until last year.\nHowever, specialist training from Samaritans for rail staff was key to an 11% fall in 2010, according to the Rail Safety and Standards Board.\nLondon Underground is also rolling out training to all of its staff after a pilot project at one station close to a psychiatric inpatient unit helped reduce suicides.\nThe government wants to see that sort of training given to a much wider range of people who come into contact with individuals who could be vulnerable because of their social or economic circumstances.\nJobcentre and benefit office staff, as well as employees in banks, building societies and utility firms are among those suggested in the consultation.\nMs Wylie said: \"More training for all frontline staff is really important, but that needs investment and money is tight.\n\"In general, we really welcome the government's strategy, but there needs to be a lot more actual commitment to action.\n\"There's also an issue about local implementation because things like putting up signs and barriers depend on the individual local authority actually caring about suicide prevention.\n\"We would like to see a mandatory responsibility placed on local authorities to take this seriously.\"\nThe consultation closes on 11 October.", "correct_choice": "Staff could be trained more to prevent suicides, under proposals to save lives.", "list_choices": ["The number of suicide attempts in England and Wales should be reduced, according to the government.", "Staff could be trained more to prevent suicides, under proposals to save lives."], "lbl": 1}
8 | {"id": "41091477", "input": "The White Garden, at Kensington Palace, was planted to mark 20 years since Princess Diana died in a car crash.\nThe Duchess of Cambridge joined the princes on the garden tour.\nA spokeswoman for Kensington Palace said: \"The engagement will allow the princes to pay tribute to the life and work of their mother.\"\nThey met representatives from the causes and charities supported by Diana, including the Royal Marsden and Great Ormond Street hospitals, the National Aids Trust, Centrepoint youth homelessness charity and the Leprosy Mission.\nMembers of the public have been leaving tributes and flowers at the gates of the palace to mark the anniversary of Diana's death.\nThe Princess of Wales died on 31 August 1997 in Paris, when William, now the Duke of Cambridge, was 15 and his brother was 12.\nThe garden at their mother's former home has been inspired by memories of her life, style and image, such as her white \"Elvis\" Catherine Walker dress.\nThe White Garden, as it is known, follows a tradition first established at Sissinghurst Castle in Kent, famous for its own white garden created in the 1930s.\nTheir Royal Highnesses met gardener Sean Harkin who designed the display and Graham Dillamore who knew the princess when he worked there some 30 years ago.\nThe garden has been open since spring and will continue into September with white roses, lilies, gladioli and cosmos.\nIt is the fourth London memorial created in tribute to Diana - the others are the Diana Memorial Playground at Kensington Palace, the Diana Memorial Fountain in Hyde Park, and the Diana Memorial Walk at St James's Palace.", "correct_choice": "Prince William and his brother have visited a London memorial garden for their mother on the eve of the 20th anniversary of her death.", "list_choices": ["The Duke and Duchess of Cambridge have visited a garden in commemorating the death of Princess Diana.", "Prince William and his brother have visited a London memorial garden for their mother on the eve of the 20th anniversary of her death."], "lbl": 1}
9 | {"id": "35516044", "input": "A member of the public raised the alarm after seeing the woman, aged in her 50s, fall at Peveril Point, near Swanage, on Saturday afternoon.\nShe was airlifted by the coastguard helicopter to King George's Field park where she was treated by paramedics.\nThe injured woman, who is from the Swanage area, was taken to Southampton General Hospital by air ambulance.\nCh Insp Bob Acaster, of Dorset Police, said: \"Emergency services worked hard in very difficult weather to rescue the woman from the cliff and bring her to safety.\"\nPolice said the woman's family had been informed.", "correct_choice": "A woman has suffered injuries falling from the cliff near Swanage.", "list_choices": ["A woman has been rescued from a cliff after being rescued by helicopter pilots.", "A woman has suffered injuries falling from the cliff near Swanage."], "lbl": 1}
10 | {"id": "36712508", "input": "Chelsey Lee, 26, played for Bucheon KEB Hana Bank in the Women's Korean Basketball League (WKBL), whose teams are allowed only two foreign players.\nProsecutors were asked to investigate after the Korean Olympic Committee pushed for Lee's naturalisation.\nThe WKBL says Lee will be suspended for life and her records annulled.\nThe Miami-born centre won the league's rookie of the year award in the 2015-16 season after helping her team reach the championship series.\nHowever, Lee and her two agents are suspected of fabricating her and her father's birth certificates to show she had a South Korean grandmother.\nBucheon KEB Hana Bank issued a public apology, vowing to take legal action against Lee and her agents.\nThe club's owner and head coach will step down.\nWKBL commissioner Shin Sun-woo said the team's records and ranking will be nullified and the league will scrap the extra quota for international players with a Korean parent or grandparent.", "correct_choice": "An Miami-born basketball player has been banned from South Korea's domestic league for life after prosecutors said she forged her birth documents.", "list_choices": ["An Miami-born basketball player has been banned from South Korea's domestic league for life after prosecutors said she forged her birth documents.", "Chelsey Lee has been suspended for life for a second time for a sex abuse scandal in the Women's Basketball League."], "lbl": 0}
11 | {"id": "39771057", "input": "The RSPB said 2,270 black-tailed godwits spent time on the island this spring, almost double the previous record of 1,320 in 2013.\nThe majority of the birds this year were found in a tiny field in Kilmoluaig.\nGodwits often stop off in the Hebrides to refuel during their migration to Iceland, where they breed.\nSpotters identified some of the birds as having come from France, Portugal and Spain due to the rings fitted on their legs.\nJohn Bowler, Tiree officer for RSPB Scotland, said: \"Black-tailed godwits are known to stop off here for food on their way to Iceland, particularly when adverse northerly winds hamper their progress across the North Atlantic.\n\"So, with huge numbers of golden plover already noted on Tiree during pretty windy conditions, it wasn't a huge surprise when black-tailed godwits started turning up, too. However, to see flocks of this size is just incredible.\n\"Hopefully they will enjoy a good breeding season this year and I'm already looking forward to seeing them pass back through Tiree in the autumn.\"", "correct_choice": "A record-breaking number of migrating birds have been recorded in the Hebrides in 2014.", "list_choices": ["A record-breaking number of migrating birds have been recorded in the Hebrides in 2014.", "The RSPB has spotted a huge number of black-tailed godwits on the North Atlantic."], "lbl": 0}
12 | {"id": "34017987", "input": "Betsi Cadwaladr health board has suggested downgrading services at one of the area's three district hospitals due to a staffing shortage.\nA legal challenge blocked the plan to downgrade maternity care at Glan Clwyd Hospital in Bodelwyddan, Denbighshire.\nThat prompted the consultation, which includes a series of public meetings.\nResidents are unhappy with the plans, suggesting removing the service at hospitals like Wrexham Maelor and Ysbyty Gwynedd in Bangor will mean women having to travel further for care.\nHowever, bosses said any changes would be temporary and are needed to ensure the safety of mothers and babies.\nA dedicated health board website was launched on Monday to collate public reaction to the options, which also includes retaining all services.\nSeveral public meetings are due to take place in September.", "correct_choice": "A consultation about plans which could see maternity care downgraded from a district hospital in Denbighshire has begun.", "list_choices": ["A consultation about plans which could see maternity care downgraded from a district hospital in Denbighshire has begun.", "A health board has urged the council to scrap maternity care in Denbighshire."], "lbl": 0}
13 | {"id": "38247303", "input": "Kuba Moczyk, 22, died in hospital after he was knocked out in an unlicensed fight at the Tower Complex, Great Yarmouth, Norfolk, on 19 November.\nA memorial mass has been held at St Mary's Church in the town.\nFather Philip Shryane told the congregation Mr Moczyk' was a \"good man\" whose \"life was boxing\".\nMore on this story and others from Norfolk\nHe said Mr Moczyk was \"a young man with a good heart, with so much to give and so much to look forward to... but always a gentle smile\".\nHis uncle, Marcin Smigaj gave a tribute, in Polish, on behalf of the family. Mr Moczyk was due to be cremated.\nMr Moczyk, originally from Poland, worked at a chicken factory and lived in the town.\nHis trainer Scott Osinski said earlier that Mr Moczyk was winning the fight when he took the fatal blow.\nHis opponent is believed to be aged 17.", "correct_choice": "Friends and family of a boxer with a \"gentle smile\", who died after being knocked out in a fight, have attended a memorial mass.", "list_choices": ["A man has been killed in a boxing match in Norfolk.", "Friends and family of a boxer with a \"gentle smile\", who died after being knocked out in a fight, have attended a memorial mass."], "lbl": 1}
14 | {"id": "27991390", "input": "Local MP Ian Lucas said people were concerned about the impact it could have if the prison on Wrexham Industrial Estate assumes a local name.\nIn a letter, prisons minister Jeremy Wright says local names are \"generally avoided as most local people object\".\nHe said it was likely people would be invited to propose names for the \u00c2\u00a3212m prison which is due to open in 2017.\nWork is expected to start in August, creating up to 1,000 jobs, to build the prison which will house 2,100 inmates, making it the largest prison in the UK.\nThe overall project spend is lower than the original \u00c2\u00a3250m estimate and the construction will involve local business and enterprises, with 100 apprenticeships created.", "correct_choice": "Wrexham Industrial Estate's new prison is unlikely to be named after local name, says the prison minister.", "list_choices": ["Wrexham Industrial Estate's new prison is unlikely to be named after local name, says the prison minister.", "A prison in the UK that will house a large number of inmates in a prison in the UK is to open in August 2017, according to local MPs."], "lbl": 0}
15 | {"id": "39962189", "input": "Natural Resources Wales (NRW) said the impact on sites of special scientific interest (SSSIs) \"could not be fully mitigated\".\nThe \u00c2\u00a31.1bn M4 proposal would cross four SSSIs along the Gwent Levels.\nWelsh Government lawyers argued environmental concerns had to be balanced against other interests.\nThe inquiry in Newport heard the scheme would mean about 105 hectares of designated land, set aside for the protection of water invertebrates, would have to be lost.\nThe Gwent Levels' unique network of ditches, known as reens, were dug during Roman times and have since become a habitat for a range of rare species.\nThe Welsh Government has pledged to replace lost reens with new ones.\nDr Jessica Poole, of conservation body Natural Resources Wales (NRW), told the inquiry discussions between the regulator and the Welsh Government meant she was content with the proposed design of the new reens.\nBut she said there was no guarantee they would work, and it could be some time before they supported the aquatic insects the sites are meant to conserve.\nReplicating a complex ecology that has developed over centuries would be \"challenging\", she said.\nNRW said the Welsh Government had not demonstrated the project would comply with its statutory duty to promote sustainable development.\nShould the alternative blue route, suggested by transport expert Prof Stuart Cole, be adopted - the motorway's impact on SSSI land would be \"significantly reduced\", Dr Poole said.\nBut the inquiry heard several issues NRW had raised in letters responding to the project's draft plans had been addressed and it was now satisfied on matters including water quality, drainage and some protected species such as otters and bats.\nMorag Ellis QC, acting on behalf of the Welsh Government, said it was for Welsh ministers to balance any potential impact on SSSI land with other public interests related to the new motorway.\nClaiming adverse effects were \"fully mitigated for\" was to apply a standard not in accordance with the law, she said.\nShe described the changes NRW had made to its initial objections after extensive discussions with Welsh Government as \"a major step forward\".", "correct_choice": "The scale of loss of conservation land caused by the proposed M4 relief road would be unacceptable, a public inquiry has heard.", "list_choices": ["The scale of loss of conservation land caused by the proposed M4 relief road would be unacceptable, a public inquiry has heard.", "The Gwent Levels motorway has been scrapped after a regulator ruled that it could be a threat to the Gwent Levels."], "lbl": 0}
16 | {"id": "34843387", "input": "Katari Anuradha was shot and stabbed by at least three men wearing burkas, Indian media reported, quoting police. A motive has yet to be established.\nHer husband, who was with her, is in a critical condition with bullet and stab injuries.\nThe attack took place at the Chittoor Municipal Corporation office, where the staff tried to stop the attackers.\nSenior police official G Srinivas told the Indian Express newspaper that they were exploring several angles, including old rivalry and new enemies.\nThe assailants fled the scene after the attack, although reports say two people later handed themselves into police.\nThe attackers had been wearing burkas, one-piece veils that cover the face and body, as they forced their way into Ms Anuradha's office, media reports said.\nSecurity has been tightened in Chittoor and state police are closing borders with neighbouring Tamil Nadu state in an attempt to find the killers.", "correct_choice": "Katari Anuradha of Chittoor has been killed by unknown attackers.", "list_choices": ["Indian police have arrested a man in connection with the murder of a woman in Chittoor.", "Katari Anuradha of Chittoor has been killed by unknown attackers."], "lbl": 1}
17 | {"id": "21712349", "input": "It works by looking for a combination of \"markers\" in the blood which are different in healthy people and those with the disease.\nDelegates at the Alzheimer's Research UK Conference heard that the University of Nottingham is now developing a quick and easy test to do in clinics.\nIt could mean much earlier diagnosis and better treatments, they said.\nThe test uses some proteins that have been strongly linked with Alzheimer's disease, such as amyloid and APOE.\nBut through careful analysis of blood from people with the disease, as well as those with early-stage memory problems, the researchers detected some other markers that were suggestive of the disease.\nMost notably, some proteins related to inflammation seem to have been added to increase the power of the test.\nProf Kevin Morgan from the University of Nottingham said they still had to validate the test and it could be a decade before it was used in patients.\nBut he added that the combination of markers they had found was looking very promising.\n\"Our findings are exciting because they show that it is technically possible to distinguish between healthy people and those with Alzheimer's using a blood test.\n\"As blood tests are a fast and easy way of aiding diagnosis, we are really encouraged by these findings and the potential they hold for the future.\"\nHe said there were several ways the test could benefit patients, including giving people a definitive diagnosis, which was not always possible at the moment.\nIt could also direct future therapies to make sure patients were getting the most appropriate treatment, he explained.\nPotentially, it could be a \"cheap and easy pre-screen\" test which enabled Alzheimer's to be picked up before symptoms appeared, he said.\n\"The way we see it working is you can test people and it will tell them if they have the all-clear, or if they are medium- or high-risk.\n\"If they are medium-risk, they can be monitored closely and high-risk patients can be referred to a specialist for more in-depth testing.\"\nDr Eric Karran, director of Research at Alzheimer's Research UK, said: \"Giving people with dementia an accurate diagnosis is not always easy, and so building up our armoury of diagnostic techniques is vital.\n\"While there is still some way to go before a test like this could become available, the results are promising.\n\"When used alongside other diagnostic techniques, a blood test like this could be a real help.\"", "correct_choice": "UK researchers have developed a test to detect Alzheimer's disease in its earliest stages.", "list_choices": ["UK researchers have developed a test to detect Alzheimer's disease in its earliest stages.", "A blood test that helps Alzheimer's disease patients detect the most important markers in their blood is being developed, according to research."], "lbl": 0}
18 |
--------------------------------------------------------------------------------
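Each record in the multiple-choice files above follows a single schema: the source article under "input", two candidate summaries under "list_choices", the factually consistent summary repeated under "correct_choice", and its index under "lbl". A minimal loading-and-sanity-check sketch (the helper below is illustrative, not part of this repo):

    import json

    def load_binary_choice(path):
        # Illustrative helper: read one .jsonl file of binary-choice records.
        records = []
        with open(path) as f:
            for line in f:
                record = json.loads(line)
                # "lbl" indexes the factually consistent choice.
                assert record["list_choices"][record["lbl"]] == record["correct_choice"]
                records.append(record)
        return records

    records = load_binary_choice("multiple_choice-dataset/xsum/fib/binary_choice-using_bart-base_distractors.jsonl")
    print(len(records), records[0]["id"])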
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.12.1
2 | rouge-metric
3 | ipdb
4 | sentencepiece
5 | deepspeed
6 | accelerate
7 | rouge-score
8 | datasets
9 | fairseq
10 | rouge
11 | tqdm
12 | protobuf==3.19.0
13 | git+https://github.com/huggingface/transformers.git@6690ba3f4d036bc39bdf29ec98daf2c693442503
14 | bitsandbytes
--------------------------------------------------------------------------------
/software_LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | Copyright [2022] [Derek Tam]
179 |
180 | Licensed under the Apache License, Version 2.0 (the "License");
181 | you may not use this file except in compliance with the License.
182 | You may obtain a copy of the License at
183 |
184 | http://www.apache.org/licenses/LICENSE-2.0
185 |
186 | Unless required by applicable law or agreed to in writing, software
187 | distributed under the License is distributed on an "AS IS" BASIS,
188 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
189 | See the License for the specific language governing permissions and
190 | limitations under the License.
--------------------------------------------------------------------------------
/src/compute_fib_results.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import statistics
4 | import os
5 | import numpy as np
6 |
7 | from src.get_results import get_medianScore_perDataset
8 |
9 |
10 |
11 | def get_fibResult(model, datasets, dataset_len):
12 |
13 |
14 | acc_acrossDataset = []
15 | for dataset in datasets:
16 | acc_perDataset = get_medianScore_perDataset(dataset, [model])
17 | acc_acrossDataset.append(acc_perDataset[0])
18 |
19 |
20 | acc_acrossDataset = np.asarray(acc_acrossDataset) * 100
21 | final_score = float(np.dot(acc_acrossDataset, dataset_len) / np.sum(dataset_len))
22 | print(f"The final score is {round(final_score, 3)}")
23 |
24 | def getXSum_fibResults(model):
25 |
26 | datasets = ["exp_out/multiple_choice/xsum/fib/binary_choice-using_bart-base_distractors",
27 | "exp_out/multiple_choice/xsum/fib/binary_choice-using_bart-large_distractors",
28 | "exp_out/multiple_choice/xsum/fib/binary_choice-using_bloom-560m_distractors",
29 | "exp_out/multiple_choice/xsum/fib/binary_choice-using_distil-bart_distractors",
30 | "exp_out/multiple_choice/xsum/fib/binary_choice-using_distil-pegasus_distractors",
31 | "exp_out/multiple_choice/xsum/fib/binary_choice-using_pegasus_distractors",
32 | "exp_out/multiple_choice/xsum/fib/binary_choice-using_t5-large_distractors"]
33 |
34 | dataset_len = np.asarray([463, 414, 479, 410, 437, 438, 483])
35 |
36 | get_fibResult(model, datasets, dataset_len)
37 |
38 |
39 | def getCNNDM_fibResults(model):
40 |
41 | datasets = ["exp_out/multiple_choice/cnn_dm/fib/binary_choice-using_banditsumm_distractors",
42 | "exp_out/multiple_choice/cnn_dm/fib/binary_choice-using_bert_lstm_pn_rl_distractors",
43 | "exp_out/multiple_choice/cnn_dm/fib/binary_choice-using_heter_graph_distractors",
44 | "exp_out/multiple_choice/cnn_dm/fib/binary_choice-using_lead3_distractors",
45 | "exp_out/multiple_choice/cnn_dm/fib/binary_choice-using_matchsumm_distractors",
46 | "exp_out/multiple_choice/cnn_dm/fib/binary_choice-using_mi_unsup_distractors",
47 | "exp_out/multiple_choice/cnn_dm/fib/binary_choice-using_neusumm_distractors",
48 | "exp_out/multiple_choice/cnn_dm/fib/binary_choice-using_oracle_disco_distractors",
49 | "exp_out/multiple_choice/cnn_dm/fib/binary_choice-using_oracle_distractors",
50 | "exp_out/multiple_choice/cnn_dm/fib/binary_choice-using_pacsum_bert_distractors",
51 | "exp_out/multiple_choice/cnn_dm/fib/binary_choice-using_pacsum_tfidf_distractors",
52 | "exp_out/multiple_choice/cnn_dm/fib/binary_choice-using_refresh_distractors",
53 | "exp_out/multiple_choice/cnn_dm/fib/binary_choice-using_rnn_ext_rl_distractors",
54 | "exp_out/multiple_choice/cnn_dm/fib/binary_choice-using_textrank_distractors",
55 | "exp_out/multiple_choice/cnn_dm/fib/binary_choice-using_textrank_st_distractors"]
56 |
57 | dataset_len = np.asarray([26, 23, 22, 5, 21, 34, 24, 72, 54, 12, 27, 31, 24, 36, 46])
58 |
59 | get_fibResult(model, datasets, dataset_len)
60 |
61 | if __name__ == "__main__":
62 | parser = argparse.ArgumentParser()
63 | parser.add_argument('-m', "--model", type=str, required=True)
64 | parser.add_argument('-d', "--dataset", choices=["xsum", "cnn_dm"])
65 | args = parser.parse_args()
66 |
67 | if args.dataset == "xsum":
68 | getXSum_fibResults(args.model)
69 |     else:  # falls back to CNN/DM when --dataset is omitted
70 | getCNNDM_fibResults(args.model)
--------------------------------------------------------------------------------
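The aggregation in get_fibResult above is a length-weighted mean: each dataset's median accuracy is scaled to a percentage and weighted by its number of examples. A self-contained sketch of just that step, using made-up accuracies:

    import numpy as np

    # Hypothetical per-dataset median accuracies (as fractions) and dataset sizes.
    acc_acrossDataset = np.asarray([0.80, 0.60, 0.70]) * 100
    dataset_len = np.asarray([463, 414, 479])

    # Weighted mean, not a simple mean: larger datasets count for more.
    final_score = float(np.dot(acc_acrossDataset, dataset_len) / np.sum(dataset_len))
    print(round(final_score, 3))  # 70.361

Invocation follows the argparse flags above, e.g. something like: python -m src.compute_fib_results -m <model> -d xsum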
/src/constructors.py:
--------------------------------------------------------------------------------
1 | import re
2 | import torch
3 | import logging
4 |
5 | from transformers import AutoTokenizer, AutoConfig, AutoModelForSeq2SeqLM, AutoModelForCausalLM
6 | from src.utils.util import get_value_from_key_matching_regex
7 | from src.models.EncoderDecoderWrappers_forMulChoice import EncoderDecoderWrappers_forMulChoice
8 | from src.models.DecoderWrappers_forMulChoice import DecoderWrappers_forMulChoice
9 | from src.models.model_flags import DICT_REGEX_OF_MODEL_TYPE, DICT_REGEX_OF_DEVICE_MAP, DICT_REGEX_OF_TOKENIZERS
10 |
11 |
12 | def log_parameter_count(model):
13 | total_numParam = 0
14 | for name, parameter in model.named_parameters():
15 | total_numParam += parameter.numel()
16 | logging.info(f"Total number of parameters in model: {total_numParam}")
17 |
18 |
19 | def construct_hugFace_objects(model_name, max_seq_len):
20 |     '''
21 |     Construct the HuggingFace config, tokenizer, and input prefix for a model.
22 | 
23 |     Args:
24 |         model_name: name of the pretrained checkpoint to load
25 |         max_seq_len: maximum sequence length, stored on the tokenizer
26 | 
27 |     Returns:
28 |         hugFaceConfig_forModel: the model's HuggingFace config
29 |         tokenizer: tokenizer with max_seq_len attached
30 |         input_prefix: task prefix to prepend to inputs, or None.
31 |             Depends on how the model was trained.
32 |     '''
33 | tokenizer = get_value_from_key_matching_regex(DICT_REGEX_OF_TOKENIZERS, model_name)(model_name)
34 | tokenizer.max_seq_len = max_seq_len
35 |
36 | hugFaceConfig_forModel = AutoConfig.from_pretrained(model_name)
37 |
38 |     # Use the input prefix from the model config's summarization task params, if one is defined.
39 | if hasattr(hugFaceConfig_forModel, "task_specific_params") and \
40 | hugFaceConfig_forModel.task_specific_params is not None and \
41 | "summarization" in hugFaceConfig_forModel.task_specific_params and \
42 | "prefix" in hugFaceConfig_forModel.task_specific_params["summarization"]:
43 | if "flan" not in model_name:
44 | input_prefix = hugFaceConfig_forModel.task_specific_params["summarization"]["prefix"]
45 | logging.info('Input Prefix: '+input_prefix)
46 | else:
47 | input_prefix = None
48 | logging.info('Evaluating FLAN but ignoring prompt')
49 | else:
50 | input_prefix = None
51 |
52 | return hugFaceConfig_forModel, tokenizer, input_prefix
53 |
54 | def construct_models(model_name, use_hugFace_parallelism, use_bitsandbytes):
55 |
56 | model_type = get_value_from_key_matching_regex(DICT_REGEX_OF_MODEL_TYPE, model_name)
57 | device_map = get_value_from_key_matching_regex(DICT_REGEX_OF_DEVICE_MAP, model_name)
58 | logging.info('Model Type: ' + model_type)
59 | logging.info('Loading Model : ' + model_name)
60 |
61 | if model_type == "encoder_decoder":
62 | if use_hugFace_parallelism:
63 | logging.info('Using HuggingFace Parallelism')
64 | assert use_bitsandbytes == False
65 | transformer = AutoModelForSeq2SeqLM.from_pretrained(model_name, device_map=device_map)
66 | logging.info(transformer.hf_device_map)
67 | elif use_bitsandbytes:
68 | logging.info('Using BitsAndBytes')
69 | assert use_hugFace_parallelism == False
70 | transformer = AutoModelForSeq2SeqLM.from_pretrained(model_name, device_map=device_map, load_in_8bit=True)
71 | logging.info(transformer.hf_device_map)
72 | else:
73 | transformer = AutoModelForSeq2SeqLM.from_pretrained(model_name)
74 | model = EncoderDecoderWrappers_forMulChoice(transformer)
75 | else:
76 | assert model_type == "decoder"
77 | if use_hugFace_parallelism:
78 | logging.info('Using HuggingFace Parallelism')
79 | assert use_bitsandbytes == False
80 | transformer = AutoModelForCausalLM.from_pretrained(model_name, device_map=device_map)
81 | logging.info(transformer.hf_device_map)
82 | elif use_bitsandbytes:
83 | logging.info('Using BitsAndBytes')
84 | assert use_hugFace_parallelism == False
85 | transformer = AutoModelForCausalLM.from_pretrained(model_name, device_map=device_map, load_in_8bit=True)
86 | logging.info(transformer.hf_device_map)
87 | else:
88 | transformer = AutoModelForCausalLM.from_pretrained(model_name)
89 | model = DecoderWrappers_forMulChoice(transformer)
90 |
91 | log_parameter_count(transformer)
92 |
93 | return model, transformer
--------------------------------------------------------------------------------
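A minimal usage sketch for the two constructors above. The checkpoint name is illustrative, and it is assumed the regexes in src/models/model_flags.py cover it:

    from src.constructors import construct_hugFace_objects, construct_models

    model_name = "facebook/bart-large-xsum"  # illustrative checkpoint

    # Config, tokenizer (with max_seq_len attached), and any task prefix from the model config.
    hugFaceConfig_forModel, tokenizer, input_prefix = construct_hugFace_objects(model_name, max_seq_len=1024)

    # With both flags off, the transformer is loaded onto a single device,
    # then wrapped for multiple-choice scoring.
    model, transformer = construct_models(model_name, use_hugFace_parallelism=False, use_bitsandbytes=False)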
/src/data/Batcher.py:
--------------------------------------------------------------------------------
1 | from torch.utils import data
2 | from src.data.Dataset import Dataset
3 |
4 |
5 | class Batcher(object):
6 | '''
7 | Batcher is responsible for returning batches of data
8 | '''
9 | def __init__(self, datasetReader, createDataset_fn, train_batchSize, eval_batchSize):
10 |
11 | self.datasetReader = datasetReader
12 | self.createPytorchDataset_fn = createDataset_fn
13 |
14 | self.train_batchSize = train_batchSize
15 | self.eval_batchSize = eval_batchSize
16 |
17 | self.trainLoader = None
18 | self.devLoader = None
19 | self.testLoader = None
20 | self.mulChoiceLoader = None
21 |
22 | def _init_trainLoader(self):
23 | trainData = self.datasetReader.read_origData("train")
24 | train_pytorchDatasetClass = self.createPytorchDataset_fn(trainData)
25 | self.trainLoader = data.DataLoader(train_pytorchDatasetClass,
26 | batch_size=self.train_batchSize,
27 | shuffle=True,
28 | collate_fn=train_pytorchDatasetClass.collate_fn)
29 |
30 | def _init_devLoader(self):
31 | devData = self.datasetReader.read_origData("dev")
32 | dev_pytorchDatasetClass = self.createPytorchDataset_fn(devData)
33 | self.devLoader = data.DataLoader(dev_pytorchDatasetClass,
34 | batch_size=self.eval_batchSize,
35 | shuffle=False,
36 | collate_fn=dev_pytorchDatasetClass.collate_fn)
37 |
38 | def _init_testLoader(self):
39 | testData = self.datasetReader.read_origData("test")
40 | test_pytorchDatasetClass = self.createPytorchDataset_fn(testData)
41 | self.testLoader = data.DataLoader(test_pytorchDatasetClass,
42 | batch_size=self.eval_batchSize,
43 | shuffle=False,
44 | collate_fn=test_pytorchDatasetClass.collate_fn)
45 |
46 | def _init_mulChoiceLoader(self, mulChoiceFilepath):
47 | mulChoiceData = self.datasetReader.read_mulChoiceData(mulChoiceFilepath)
48 | mulChoice_pytorchDatasetClass = self.createPytorchDataset_fn(mulChoiceData)
49 | self.mulChoiceLoader = data.DataLoader(mulChoice_pytorchDatasetClass,
50 | batch_size=self.eval_batchSize,
51 | shuffle=False,
52 | collate_fn=mulChoice_pytorchDatasetClass.collate_fn)
53 |
54 | def get_trainBatches(self):
55 | if self.trainLoader is None:
56 | self._init_trainLoader()
57 |
58 | while True:
59 | for x in self.trainLoader:
60 | yield x
61 |
62 | def get_devBatches(self):
63 | if self.devLoader is None:
64 | self._init_devLoader()
65 |
66 | for x in self.devLoader:
67 | yield x
68 |
69 | def get_testBatches(self):
70 | if self.testLoader is None:
71 | self._init_testLoader()
72 |
73 | for x in self.testLoader:
74 | yield x
75 |
76 | def get_mulChoiceBatches(self, mulChoiceFilepath):
77 |         if self.mulChoiceLoader is None:  # cached after the first call; later filepaths are ignored
78 | self._init_mulChoiceLoader(mulChoiceFilepath)
79 |
80 | for x in self.mulChoiceLoader:
81 | yield x
82 |
--------------------------------------------------------------------------------
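A sketch of wiring Batcher to the multiple-choice reader defined in src/data/multiple_choice.py. The createDataset_fn closure is an assumption consistent with MultipleChoiceDataset's constructor; the tokenizer setup mirrors construct_hugFace_objects:

    from transformers import AutoTokenizer
    from src.data.Batcher import Batcher
    from src.data.multiple_choice import MultipleChoiceReader, MultipleChoiceDataset

    tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-xsum")  # illustrative
    tokenizer.max_seq_len = 1024  # attribute expected by tokenize_prompted_input_text

    # Assumed closure: bind everything the dataset needs except the raw data.
    createDataset_fn = lambda data: MultipleChoiceDataset(
        data, tokenizer, promptTemplate_idx=0, input_prefix=None, device=None, world_size=1)

    batcher = Batcher(MultipleChoiceReader(), createDataset_fn, train_batchSize=8, eval_batchSize=8)

    for batch in batcher.get_mulChoiceBatches("path/to/binary_choice.jsonl"):
        ...  # each batch is the dict built by MultipleChoiceDataset.collate_fn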
/src/data/Dataset.py:
--------------------------------------------------------------------------------
1 |
2 | from torch.utils import data
3 |
4 | class Dataset(data.Dataset):
5 | def __init__(self, data):
6 | self.data = data
7 |
8 | def __len__(self):
9 | return len(self.data)
10 |
11 | def __getitem__(self, get_idx):
12 | return self.data[get_idx]
--------------------------------------------------------------------------------
/src/data/multiple_choice.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import json
3 | from torch.utils import data
4 | import copy
5 | import math
6 | import re
7 | import logging
8 |
9 | from src.utils.CONSTANTS import NULL_STRING
10 | from src.data.templates import SUMMARIZATION_PROMPT_TEMPLATES
11 | from src.data.preprocess_data import tokenize_prompted_input_text, does_tokenizer_addBosEosTokens
12 |
13 | class MultipleChoiceReader(object):
14 | '''
15 | MultipleChoiceReader reads any multiple choice dataset
16 | '''
17 |
18 | def read_mulChoiceData(self, mulChoiceFilepath):
19 | '''
20 | Read dataset
21 |
22 | Args:
23 |             mulChoiceFilepath:
24 |
25 | Returns:
26 |             listOfDatapoints:
27 | '''
28 |         listOfDatapoints = []
29 | 
30 |         # Use a context manager so the file handle is closed after reading.
31 |         with open(mulChoiceFilepath, 'r') as fd:
32 |             for line in fd:
33 |                 listOfDatapoints.append(json.loads(line))
34 | 
35 |         return listOfDatapoints
36 |
37 | NULL_DATAPOINT = {
38 | "id": NULL_STRING,
39 | "input": NULL_STRING,
40 | "list_choices": [NULL_STRING, NULL_STRING],
41 | "correct_choice": NULL_STRING,
42 | "lbl": 0
43 | }
44 |
45 |
46 | class MultipleChoiceDataset(data.Dataset):
47 | def __init__(self,
48 | data,
49 | tokenizer,
50 | promptTemplate_idx,
51 | input_prefix,
52 | device,
53 | world_size):
54 |
55 |         # If the device is an integer, we are using parallelism and must split the dataset across the devices.
56 | if isinstance(device, int):
57 |
58 | num_datapoints_per_split = math.ceil(len(data) / world_size)
59 |
60 | device_data = []
61 | for idx, datapoint in enumerate(data):
62 | if idx % world_size == device:
63 | device_data.append(datapoint)
64 |
65 |             # We ensure each device sees the same number of samples, so that the number of batches is the same per device.
66 |             # Otherwise (e.g. batch size 1 with world_size=2 and an odd dataset size), devices would see
67 |             # different numbers of batches, causing a race condition during parallelism.
68 | if len(device_data) < num_datapoints_per_split:
69 | device_data.append(NULL_DATAPOINT)
70 | assert len(device_data) == num_datapoints_per_split
71 | self.data = device_data
72 | # For non-parallelism
73 | else:
74 | self.data = data
75 |
76 | self.tokenizer = tokenizer
77 |
78 | # Uses no template and adds the input prefix only.
79 | # Note that the 0 template is just the data.
80 | if input_prefix is not None:
81 | assert promptTemplate_idx == 0
82 | self.prompt_template = input_prefix + SUMMARIZATION_PROMPT_TEMPLATES[0]
83 |
84 | # Create template from prompt template idx
85 | else:
86 | self.prompt_template = SUMMARIZATION_PROMPT_TEMPLATES[promptTemplate_idx]
87 |
88 | if promptTemplate_idx == 0:
89 |             # If the tokenizer does not insert a BOS or EOS token for an empty string, we prepend a single space
90 |             # so that the null input used when computing PMI is non-empty. This holds for BLOOM.
91 |             # This only has to be done for the zero prompt since it contains no additional text.
92 |             # Though BLOOM was not pretrained with this extra space, it should not affect performance much.
93 | if len(tokenizer("")["input_ids"]) == 0:
94 | self.prompt_template = " " + self.prompt_template
95 |
96 | logging.info('Prompt Template: '+self.prompt_template)
97 | self.device = device
98 | self.add_bosToken, self.add_eosToken = does_tokenizer_addBosEosTokens(self.tokenizer)
99 |
100 |
101 | def __len__(self):
102 | return len(self.data)
103 |
104 | def __getitem__(self, get_idx):
105 | datapoint = self.data[get_idx]
106 |
107 | input_ids, input_masks, input_txt, nullInput_txt = tokenize_prompted_input_text(self.tokenizer,
108 | datapoint,
109 | self.prompt_template,
110 | self.add_bosToken,
111 | self.add_eosToken)
112 | nullInput_dict = self.tokenizer(nullInput_txt,
113 | return_tensors="pt",
114 | truncation=True)
115 | nullInput_ids = nullInput_dict["input_ids"][0]
116 | nullInput_masks = nullInput_dict["attention_mask"][0]
117 |
118 | allChoices_ids = []
119 | allChoices_masks = []
120 |
121 | for choice in datapoint["list_choices"]:
122 | choiceDict = self.tokenizer(choice, return_tensors="pt", truncation=True)
123 | # Skip BOS token for choices since it is a continuation of the input
124 | # TODO Currently this assumes that a BOS token is not added for encoder-decoder models.
125 | # TODO add logic to NOT ignore the BOS token in the choices for encoder-decoder model
126 | # Note that all T5 variants do not add a BOS token.
127 | if self.add_bosToken:
128 | start_idx = 1
129 | else:
130 | start_idx = 0
131 | allChoices_ids.append(choiceDict["input_ids"][0][start_idx:])
132 | allChoices_masks.append(choiceDict["attention_mask"][0][start_idx:])
133 |
134 | return {"id": datapoint["id"],
135 | "input": input_txt,
136 | "input_ids": input_ids,
137 | "input_masks": input_masks,
138 | "null_input_ids": nullInput_ids,
139 | "null_input_masks": nullInput_masks,
140 | "list_choices": datapoint["list_choices"],
141 | "all_choices_ids": allChoices_ids,
142 | "all_choices_lbls": copy.deepcopy(allChoices_ids),
143 | "all_choices_masks": allChoices_masks,
144 | "correct_choice": datapoint["correct_choice"],
145 | "lbl": datapoint["lbl"]}
146 |
147 | def collate_fn(self, batch_ofDatapoints):
148 | '''
149 | Convert a batch of datapoints into a datapoint that is batched. This is meant to
150 | override the default collate function in pytorch.
151 |
152 | Args:
153 | batch_ofDatapoints:
154 |
155 | Returns:
156 |
157 | '''
158 | datapoint_batched = {}
159 |
160 | for datapoint in batch_ofDatapoints:
161 | for (k, v) in datapoint.items():
162 | if k in datapoint_batched:
163 | # Each value in all_choices is already a list, so we extend and not append.
164 | if "all_choices" in k:
165 | datapoint_batched[k].extend(v)
166 | else:
167 | datapoint_batched[k].append(v)
168 | else:
169 | # Each value in all_choices is already a list, so we do not need to
170 | # initialize a list with v in it, and can just use v.
171 | if "all_choices" in k:
172 | datapoint_batched[k] = v
173 | else:
174 | datapoint_batched[k] = [v]
175 |
176 | for (k, batch_ofValues) in datapoint_batched.items():
177 | # If id or mask is in key, this means we need to pad to the longest sequence length
178 | if ("ids" in k) or ("masks" in k) or (k == "all_choices_lbls"):
179 | if "ids" in k:
180 | padToken_id = self.tokenizer.pad_token_id
181 | if padToken_id is None:
182 | padToken_id = self.tokenizer.eos_token_id
183 | elif "masks" in k:
184 | padToken_id = 0
185 | elif k == "all_choices_lbls":
186 | padToken_id = -100
187 | else:
188 | raise ValueError(f"The key {k} has ids or masks but is not recognized")
189 | datapoint_batched[k] = torch.nn.utils.rnn.pad_sequence(
190 | batch_ofValues,
191 | batch_first=True,
192 | padding_value=padToken_id)
193 |
194 | if self.device is not None:
195 | datapoint_batched[k] = datapoint_batched[k].to(self.device)
196 |
197 | elif isinstance(batch_ofValues[0], int):
198 | datapoint_batched[k] = torch.tensor(batch_ofValues)
199 |
200 | if self.device is not None:
201 | datapoint_batched[k] = datapoint_batched[k].to(self.device)
202 |
203 |
204 |
205 | return datapoint_batched
206 |
207 |
208 |
--------------------------------------------------------------------------------
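collate_fn above pads every tensor field to the longest sequence in the batch: token ids with the tokenizer's pad id (falling back to the EOS id), attention masks with 0, and all_choices_lbls with -100, the ignore_index of PyTorch's cross-entropy loss, so padded positions never contribute to the loss. The mechanics reduce to pad_sequence; a self-contained illustration:

    import torch

    lbls = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]

    # Pad label sequences with -100 so the loss skips the padding.
    padded = torch.nn.utils.rnn.pad_sequence(lbls, batch_first=True, padding_value=-100)
    print(padded)  # tensor([[   5,    6,    7],
                   #         [   8,    9, -100]])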
/src/data/preprocess_data.py:
--------------------------------------------------------------------------------
1 | import re
2 | import torch
3 |
4 |
5 | def does_tokenizer_addBosEosTokens(tokenizer):
6 | # Compute whether to add BOS or EOS tokens by tokenizing a dummy input.
7 | filler_ids = tokenizer("hello")["input_ids"]
8 | add_bosToken = False
9 | if filler_ids[0] == tokenizer.bos_token_id:
10 | add_bosToken = True
11 |
12 | add_eosToken = False
13 | if filler_ids[-1] == tokenizer.eos_token_id:
14 | add_eosToken = True
15 |
16 | return add_bosToken, add_eosToken
17 |
18 | def tokenize_prompted_input_text(tokenizer, datapoint, prompt_template, add_bosToken, add_eosToken):
19 | '''
20 | Gets the input text and tokenizes it from prompt.
21 |
22 | Assumes the datapoint is a dictionary and the prompt template specifies
23 | which value of the datapoint to use based on the key wrapped in [].
24 | For example, [input] should be used to specify the input.
25 | Note, the prompt template cannot use [] in any other locations.
26 |
27 |
28 | Args:
29 | tokenizer:
30 | datapoint:
31 | prompt_template:
32 | add_bosToken:
33 | add_eosToken:
34 |
35 | Returns:
36 |
37 | '''
38 |     template_nonDataKeys = re.split(r"\[.*?\]", prompt_template)
39 |     template_dataKeys = re.findall(r"\[.*?\]", prompt_template)
40 |
41 | assert len(template_nonDataKeys) == len(template_dataKeys) + 1
42 |
43 | remaining_seqLen = tokenizer.max_seq_len
44 | num_dataKeys = len(template_dataKeys)
45 |
46 | list_nonDataKeys_txt = []
47 | list_nonDataKeys_ids = []
48 | list_nonDataKeys_mask = []
49 |
50 | for nonDataKey in template_nonDataKeys:
51 | if len(nonDataKey) > 0:
52 | list_nonDataKeys_txt.append(nonDataKey)
53 | nonDataKey_dict = tokenizer(nonDataKey, add_special_tokens=False)
54 | list_nonDataKeys_ids.append(nonDataKey_dict["input_ids"])
55 | list_nonDataKeys_mask.append(nonDataKey_dict["attention_mask"])
56 | remaining_seqLen -= len(nonDataKey_dict["input_ids"])
57 | else:
58 | list_nonDataKeys_txt.append("")
59 | list_nonDataKeys_ids.append([])
60 | list_nonDataKeys_mask.append([])
61 |
62 | # This list will recombine the nonDataKeys and dataKeys in the correct order.
63 | list_split_txt = []
64 | list_split_ids = []
65 | list_split_masks = []
66 |
67 | if add_bosToken:
68 | list_split_ids.append(tokenizer.bos_token_id)
69 | list_split_masks.append(1)
70 | remaining_seqLen -= 1
71 |
72 | # We have to compute remaining sequence length at the beginning
73 | # to know how much is left over.
74 | if add_eosToken:
75 | remaining_seqLen -= 1
76 |
77 | # Add any text in template that appears before the first data key.
78 | list_split_txt.append(list_nonDataKeys_txt[0])
79 | list_split_ids.extend(list_nonDataKeys_ids[0])
80 | list_split_masks.extend(list_nonDataKeys_mask[0])
81 |
82 |
83 | for i in range(num_dataKeys):
84 | dataKey = template_dataKeys[i].replace("[", "").replace("]", "")
85 | dataValue = datapoint[dataKey]
86 |
87 | dataValue_dict = tokenizer(dataValue, add_special_tokens=False)
88 |
89 | value_ids = dataValue_dict["input_ids"]
90 | value_mask = dataValue_dict["attention_mask"]
91 |
92 | len_value = len(dataValue_dict["input_ids"])
93 | if len_value > remaining_seqLen:
94 |             value_txt = tokenizer.decode(value_ids[:remaining_seqLen])
95 | value_ids = value_ids[:remaining_seqLen]
96 | value_mask = value_mask[:remaining_seqLen]
97 | remaining_seqLen = 0
98 | else:
99 |             value_txt = tokenizer.decode(value_ids)
100 | remaining_seqLen -= len_value
101 |
102 | # Add tokenized values from data
103 | list_split_txt.append(value_txt)
104 | list_split_ids.extend(value_ids)
105 | list_split_masks.extend(value_mask)
106 |
107 | # Add tokenized text between data
108 | # Increment by 1 since we add non-data key text at the very beginning
109 | list_split_txt.append(list_nonDataKeys_txt[i+1])
110 | list_split_ids.extend(list_nonDataKeys_ids[i+1])
111 | list_split_masks.extend(list_nonDataKeys_mask[i+1])
112 |
113 | if add_eosToken:
114 | list_split_ids.append(tokenizer.eos_token_id)
115 | list_split_masks.append(1)
116 |
117 | return torch.tensor(list_split_ids), torch.tensor(list_split_masks), "".join(list_split_txt), "".join(template_nonDataKeys)
118 |
--------------------------------------------------------------------------------
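tokenize_prompted_input_text above splits the prompt template on bracketed data keys, tokenizes literal template text and data values separately, and truncates only the data values once the token budget is exhausted; the template with all data keys removed is returned as the null-input text (used for PMI, per the comment in src/data/multiple_choice.py). A quick illustration of just the split step, with an illustrative template:

    import re

    prompt_template = "Article: [input]\nSummary:"  # illustrative template
    nonDataKeys = re.split(r"\[.*?\]", prompt_template)
    dataKeys = re.findall(r"\[.*?\]", prompt_template)
    print(nonDataKeys)  # ['Article: ', '\nSummary:']
    print(dataKeys)     # ['[input]']
    # "".join(nonDataKeys) is the null-input text.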
/src/data/preprocess_data_test.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | from src.data.preprocess_data import tokenize_prompted_input_text, does_tokenizer_addBosEosTokens
4 | from transformers import AutoTokenizer
5 |
6 |
7 | from src.data.templates import SUMMARIZATION_PROMPT_TEMPLATES
8 | from src.utils.test_helpers import check_string_ends_with_another, check_string_starts_with_another, check_string_subset_of_another, check_string_equality
9 |
10 | SHORT_DATAPOINT = {
11 | "input": "Forbes said Vergara's role as Gloria in Modern Family and some lucrative product endorsements helped her earn $43m (\u00a332.6m) in the last 12 months.\n It marks the fifth year the Colombian-American actress has topped the chart.\n Forbes also said she earned more than any of her male counterparts in the past year.\n The Big Bang Theory's Kaley Cuoco was the second-highest paid actress, earning $24.5m (\u00a318.6m).\n Cuoco tied with Vergara at the top of last year's Forbes list, when both actresses earned $28.5m (\u00a321.6m).\n The Mindy Project's Mindy Kaling is the biggest climber in this year's chart. Her earnings of $15m (\u00a311.4m) helped her to rise from eighth place in 2015 to third this year.\n Mariska Hargitay, who appears in Law & Order: Special Victims Unit, and Grey's Anatomy star Ellen Pompeo rounded off the top five.\n Source: Forbes\n This year's highest new entry on the Forbes list was Priyanka Chopra, who appears in ABC drama Quantico. She was the eighth highest earner with $11m (\u00a38.4m).\n Chopra, who is well known in India, is set to become more familiar to western audiences next year when she stars in Baywatch alongside Dwayne Johnson - the world's highest paid actor.\n Scandal star Kerry Washington, Stana Katic from Castle, The Good Wife's Julianna Margulies and Vergara's Modern Family co-star Julie Bowen also featured in this year's top 10.\n Follow us on Twitter @BBCNewsEnts, on Instagram, or if you have a story suggestion email entertainment.news@bbc.co.uk.",
12 | "choice": "Modern Family star Sofia Vergara has retained her title as the highest paid actress on US television, according to the latest Forbes magazine rich list."
13 | }
14 |
15 | LONG_DATAPOINT = {
16 | "input": "Archery, fencing, weightlifting and wheelchair rugby have also missed out.\n Cycling - which brought Team GB 12 medals in Rio - has had its funding cut by more than \u00a34m to \u00a325.98m.\n Badminton England chief executive Adrian Christy said he was \"staggered\" by the \"incomprehensible\" decision to remove the sport's funding.\n A total of \u00a3345m will be invested in 31 Olympic and Paralympic sports - \u00a32m less than the record \u00a3347m allocated for the Rio Games.\n As a result, UK Sport has set Team GB a target of winning 51-85 Olympic medals, and 115-162 Paralympic medals in Tokyo.\n Britain enjoyed unprecedented success at Rio 2016, with the Olympics yielding 67 medals and the Paralympics 147.\n Chair of UK Sport Rod Carr said the government, which provides funding alongside National Lottery money, has \"confirmed its commitment\" for Tokyo 2020.\n He added: \"These are critical funding decisions for sports to take them on their journey to Tokyo 2020 and beyond so the historic success at Rio can be maintained.\"\n Badminton, which was set a target of winning a medal in Rio, is the only sport that earned a podium place in the summer to have its funding removed.\n Marcus Ellis and Chris Langridge took bronze in the men's doubles after the sport was given \u00a35.74m in the last cycle.\n Christy said the decision represents a \"catastrophic impact on the sport\" and Badminton England would \"fight for the hopes and dreams\" of its players.\n \"How can you return from the best Games for more than a decade, in a year where our players have demonstrated world-class performances and where we can demonstrate the journey to Tokyo is on track, only be to have every penny of investment withdrawn?\" he said.\n \"What have we done wrong?\" added GB Badminton's performance director Jon Austin.\n Judo, which was given the same target as badminton and also claimed one bronze medal, has had its funding increased slightly.\n Liz Nicholl, CEO of UK Sport, said the decision to cut funding was not taken lightly.\n \"We would like to invest in every sport but the reality is we have to prioritise to protect and enhance the medal potential,\" she said.\n \"If we under-invest across the board then the British teams will ultimately underperform at the Games and medal success will be put at risk.\"\n Sports minister Tracey Crouch added: \"UK Sport's approach to elite sport has proven successful in Beijing, London and Rio and the ambition to win more medals in Tokyo is a bold one that, if achieved, would mean a sensational summer of sport in 2020.\"\n Basketball had its funding withdrawn in 2014 - and handball and volleyball lost theirs in 2012 - but say a UK Sport review last year to build \"performance pathways for future success\" was supposed to be aimed at such sports.\n A British Basketball statement, in conjunction with volleyball and handball, said: \"It appears that UK Sport has no interest in team sports and in particular refuses to take responsibility for the need to fund their performance development, which was identified in its own review.\n \"With UK Sport's investment budget approaching \u00a3350m, it borders on intransigence to pass responsibility to government and other funding bodies who are not set up to fund the development of high-performance sport.\"\n UK Sport says investment in the five Olympic sports and two Paralympic sports added for Tokyo 2020 is yet to be confirmed.\n Baseball/softball will return to the programme, with karate, skateboard, sports climbing 
and surfing also added, while Para-taekwondo and Para-badminton join the Paralympic programme.\n UK Sport says funding will be determined \"following further exploration of medal potential\", with \u00a39m of the \u00a3345m total still to be allocated.\n Liam Carroll, head coach of the GB baseball team, said: \"The key to unlocking our potential is investment and I'm pleased that UK Sport has left the door open.\n \"We look forward to the opportunity to impress upon them that getting behind Great Britain Baseball can extend their tremendous track record of investing in Olympic medal contenders.\"",
17 | "choice": "Badminton is one of five sports to lose all UK Sport funding for the 2020 Olympics in Tokyo - after Britain claimed a bronze in the sport in Rio."
18 | }
19 |
20 | BASIC_PROMPT_TEMPLATE = "[input]"
21 | PROMPT_TEMPLATE_WITH_TXT_ON_BOTH_SIDES = "The summary of \"[input]\" is "
22 | PROMPT_TEMPLATE_WITH_TXT_ON_LEFT_SIDE = "The summary of \"[input]"
23 | PROMPT_TEMPLATE_WITH_TXT_ON_RIGHT_SIDE = "[input]\" is "
24 |
25 |
26 | def test_tokenize_input(tokenizer):
27 | add_bosToken, add_eosToken = does_tokenizer_addBosEosTokens(tokenizer)
28 |
29 |
30 | input_ids, input_masks, input_txt, nullInput_txt = tokenize_prompted_input_text(tokenizer,
31 | SHORT_DATAPOINT,
32 | BASIC_PROMPT_TEMPLATE,
33 | add_bosToken,
34 | add_eosToken)
35 | print("Length of Input Ids: ", len(input_ids))
36 | if add_bosToken:
37 | assert input_ids[0] == tokenizer.bos_token_id
38 | if add_eosToken:
39 | assert input_ids[-1] == tokenizer.eos_token_id
40 | reconstructed_inputTxt = tokenizer.decode(input_ids, skip_special_tokens=True)
41 | check_string_equality(reconstructed_inputTxt, input_txt)
42 | check_string_equality(reconstructed_inputTxt, SHORT_DATAPOINT["input"])
43 |
44 | input_ids, input_masks, input_txt, nullInput_txt = tokenize_prompted_input_text(tokenizer,
45 | SHORT_DATAPOINT,
46 | PROMPT_TEMPLATE_WITH_TXT_ON_BOTH_SIDES,
47 | add_bosToken,
48 | add_eosToken)
49 | print("Length of Input Ids: ", len(input_ids))
50 | if add_bosToken:
51 | assert input_ids[0] == tokenizer.bos_token_id
52 | if add_eosToken:
53 | assert input_ids[-1] == tokenizer.eos_token_id
54 | reconstructed_inputTxt = tokenizer.decode(input_ids, skip_special_tokens=True)
55 | check_string_equality(reconstructed_inputTxt, input_txt)
56 | check_string_equality(reconstructed_inputTxt, f"The summary of \"{SHORT_DATAPOINT['input']}\" is ")
57 |
58 | input_ids, input_masks, input_txt, nullInput_txt = tokenize_prompted_input_text(tokenizer,
59 | SHORT_DATAPOINT,
60 | PROMPT_TEMPLATE_WITH_TXT_ON_LEFT_SIDE,
61 | add_bosToken,
62 | add_eosToken)
63 | print("Length of Input Ids: ", len(input_ids))
64 | if add_bosToken:
65 | assert input_ids[0] == tokenizer.bos_token_id
66 | if add_eosToken:
67 | assert input_ids[-1] == tokenizer.eos_token_id
68 | reconstructed_inputTxt = tokenizer.decode(input_ids, skip_special_tokens=True)
69 | check_string_equality(reconstructed_inputTxt, input_txt)
70 | check_string_equality(reconstructed_inputTxt, f"The summary of \"{SHORT_DATAPOINT['input']}")
71 |
72 | input_ids, input_masks, input_txt, nullInput_txt = tokenize_prompted_input_text(tokenizer,
73 | SHORT_DATAPOINT,
74 | PROMPT_TEMPLATE_WITH_TXT_ON_RIGHT_SIDE,
75 | add_bosToken,
76 | add_eosToken)
77 | print("Length of Input Ids: ", len(input_ids))
78 | if add_bosToken:
79 | assert input_ids[0] == tokenizer.bos_token_id
80 | if add_eosToken:
81 | assert input_ids[-1] == tokenizer.eos_token_id
82 | reconstructed_inputTxt = tokenizer.decode(input_ids, skip_special_tokens=True)
83 | check_string_equality(reconstructed_inputTxt, input_txt)
84 | check_string_equality(reconstructed_inputTxt, f"{SHORT_DATAPOINT['input']}\" is ")
85 |
86 |
87 | input_ids, input_masks, input_txt, nullInput_txt = tokenize_prompted_input_text(tokenizer,
88 | LONG_DATAPOINT,
89 | BASIC_PROMPT_TEMPLATE,
90 | add_bosToken,
91 | add_eosToken)
92 | print("Length of Input Ids: ", len(input_ids))
93 | if add_bosToken:
94 | assert input_ids[0] == tokenizer.bos_token_id
95 | if add_eosToken:
96 | assert input_ids[-1] == tokenizer.eos_token_id
97 | reconstructed_inputTxt = tokenizer.decode(input_ids, skip_special_tokens=True)
98 | check_string_equality(reconstructed_inputTxt, input_txt)
99 | check_string_subset_of_another(reconstructed_inputTxt, LONG_DATAPOINT["input"])
100 |
101 | input_ids, input_masks, input_txt, nullInput_txt = tokenize_prompted_input_text(tokenizer,
102 | LONG_DATAPOINT,
103 | PROMPT_TEMPLATE_WITH_TXT_ON_BOTH_SIDES,
104 | add_bosToken,
105 | add_eosToken)
106 | print("Length of Input Ids: ", len(input_ids))
107 | if add_bosToken:
108 | assert input_ids[0] == tokenizer.bos_token_id
109 | if add_eosToken:
110 | assert input_ids[-1] == tokenizer.eos_token_id
111 | reconstructed_inputTxt = tokenizer.decode(input_ids, skip_special_tokens=True)
112 | check_string_equality(reconstructed_inputTxt, input_txt)
113 | check_string_subset_of_another(reconstructed_inputTxt\
114 | .replace("The summary of \"", "")\
115 | .replace("\" is ", ""),
116 | LONG_DATAPOINT["input"])
117 | check_string_starts_with_another(reconstructed_inputTxt, "The summary of \"")
118 | check_string_ends_with_another(reconstructed_inputTxt, "\" is ")
119 |
120 | input_ids, input_masks, input_txt, nullInput_txt = tokenize_prompted_input_text(tokenizer,
121 | LONG_DATAPOINT,
122 | PROMPT_TEMPLATE_WITH_TXT_ON_LEFT_SIDE,
123 | add_bosToken,
124 | add_eosToken)
125 | print("Length of Input Ids: ", len(input_ids))
126 | if add_bosToken:
127 | assert input_ids[0] == tokenizer.bos_token_id
128 | if add_eosToken:
129 | assert input_ids[-1] == tokenizer.eos_token_id
130 | reconstructed_inputTxt = tokenizer.decode(input_ids, skip_special_tokens=True)
131 | check_string_equality(reconstructed_inputTxt, input_txt)
132 | check_string_subset_of_another(reconstructed_inputTxt\
133 | .replace("The summary of \"", ""),
134 | LONG_DATAPOINT["input"])
135 | check_string_starts_with_another(reconstructed_inputTxt, "The summary of \"")
136 |
137 | input_ids, input_masks, input_txt, nullInput_txt = tokenize_prompted_input_text(tokenizer,
138 | LONG_DATAPOINT,
139 | PROMPT_TEMPLATE_WITH_TXT_ON_RIGHT_SIDE,
140 | add_bosToken,
141 | add_eosToken)
142 | print("Length of Input Ids: ", len(input_ids))
143 | if add_bosToken:
144 | assert input_ids[0] == tokenizer.bos_token_id
145 | if add_eosToken:
146 | assert input_ids[-1] == tokenizer.eos_token_id
147 | reconstructed_inputTxt = tokenizer.decode(input_ids, skip_special_tokens=True)
148 | check_string_equality(reconstructed_inputTxt, input_txt)
149 | check_string_subset_of_another(reconstructed_inputTxt\
150 | .replace("\" is ", ""),
151 | LONG_DATAPOINT["input"])
152 | check_string_ends_with_another(reconstructed_inputTxt, "\" is ")
153 |
154 | if __name__ == "__main__":
155 |
156 | for tokenizer_name in ["bigscience/bloom-560m",
157 | "facebook/opt-125m",
158 | "gpt2-xl"]:
159 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
160 | tokenizer.max_seq_len = 512
161 | test_tokenize_input(tokenizer)
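
For reference, a minimal sketch of the kind of check does_tokenizer_addBosEosTokens can perform; this is an assumption about its behavior, not the repo's actual implementation. Of the tokenizers exercised above, OPT's prepends its BOS token while GPT-2's and BLOOM's add no special tokens.

    # Sketch only; the real implementation lives in src/data/data_preprocess.py.
    def sketch_does_tokenizer_addBosEosTokens(tokenizer):
        ids = tokenizer("probe")["input_ids"]
        add_bos = tokenizer.bos_token_id is not None and ids[0] == tokenizer.bos_token_id
        add_eos = tokenizer.eos_token_id is not None and ids[-1] == tokenizer.eos_token_id
        return add_bos, add_eos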
--------------------------------------------------------------------------------
/src/data/templates.py:
--------------------------------------------------------------------------------
1 | SUMMARIZATION_PROMPT_TEMPLATES = {
2 | 0: "[input]",
3 | 1: "The summary of \"[input]\" is ",
4 | 2: "Summarize: [input]"
5 | }
6 |
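
Each template is applied by substituting the article text for the [input] placeholder, as in this minimal sketch (hypothetical datapoint):

    prompt = SUMMARIZATION_PROMPT_TEMPLATES[1].replace("[input]", "Some article text.")
    # -> 'The summary of "Some article text." is '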
--------------------------------------------------------------------------------
/src/eval/PredictionLogger.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | class PredictionLogger(object):
4 | def __init__(self, logger_fp):
5 | self.logger_fp = logger_fp
6 | self.logger_file = open(self.logger_fp, 'w+')
7 |
8 | # From https://stackoverflow.com/questions/5558418/list-of-dicts-to-from-dict-of-lists
9 | def _convert_dictOfLists_to_listOfDicts(self, dictOfLists):
10 | listOfDicts = []
11 | for datapoint_values in zip(*dictOfLists.values()):
12 | listOfDicts.append(dict(zip(dictOfLists, datapoint_values)))
13 | return listOfDicts
14 |
15 | def log_batch(self, batchOf_evalInfo):
16 | listOf_evalInfo = self._convert_dictOfLists_to_listOfDicts(batchOf_evalInfo)
17 | for eval_info in listOf_evalInfo:
18 | self.logger_file.write(json.dumps(eval_info) + '\n')
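
A usage sketch with hypothetical values: each batch arrives as a dict of lists, and one JSON line is written per datapoint.

    logger = PredictionLogger("predictions.jsonl")
    logger.log_batch({"pred_choice": [0, 1], "lbl": [0, 0]})
    # writes: {"pred_choice": 0, "lbl": 0}
    #         {"pred_choice": 1, "lbl": 0}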
--------------------------------------------------------------------------------
/src/eval/Scorer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | class Scorer(object):
4 |
5 | def __init__(self, metric):
6 | self.metric = metric
7 |
8 | if metric == "multiple_choice":
9 | self.total_numCorrectDatapoints = 0
10 | self.total_numDatapoints = 0
11 | else:
12 | raise ValueError(f"Invalid metric {metric}")
13 |
14 | def add_batch(self, batchOf_evalInfo):
15 | pred_choice = np.asarray(batchOf_evalInfo["pred_choice"])
16 | lbl = np.asarray(batchOf_evalInfo["lbl"])
17 |
18 | which_datapointsCorrect = pred_choice == lbl
19 | self.total_numCorrectDatapoints += np.sum(which_datapointsCorrect)
20 | self.total_numDatapoints += which_datapointsCorrect.shape[0]
21 |
22 | return which_datapointsCorrect.tolist()
23 |
24 | def get_score(self):
25 | mulChoice_acc = float(round(self.total_numCorrectDatapoints / self.total_numDatapoints, 3))
26 | return {"multiple-choice-accuracy": mulChoice_acc}
27 |
28 |
29 |
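
A usage sketch with hypothetical values: accuracy accumulates across batches, and add_batch returns the per-datapoint correctness that the evaluation loop logs.

    scorer = Scorer("multiple_choice")
    scorer.add_batch({"pred_choice": [0, 1, 1], "lbl": [0, 1, 0]})  # -> [True, True, False]
    scorer.get_score()  # -> {"multiple-choice-accuracy": 0.667}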
--------------------------------------------------------------------------------
/src/evaluate_mulChoice.py:
--------------------------------------------------------------------------------
1 | import json
2 | import argparse
3 | import torch
4 | import tqdm
5 | import os
6 | import logging
7 | import torch.distributed as dist
8 |
9 | import deepspeed
10 | from transformers.deepspeed import HfDeepSpeedConfig
11 |
12 | from src.data.multiple_choice import MultipleChoiceDataset, MultipleChoiceReader
13 | from src.data.Batcher import Batcher
14 |
15 |
16 | from src.eval.Scorer import Scorer
17 | from src.eval.PredictionLogger import PredictionLogger
18 |
19 | from src.constructors import construct_hugFace_objects, construct_models
20 |
21 | from src.utils.util import set_seeds, reduce_gatheredOutput, get_mulChoice_outputDir
22 | from src.utils.deepspeed import get_deepspeedConfig
23 |
24 | def evaluate_dataset(mul_choice_filepath,
25 | deepspeed_engine,
26 | model,
27 | world_size,
28 | eval_batchSize,
29 | device,
30 | ignore_pointwise_mutual_information,
31 | ignore_length_normalization,
32 | compute_choices_iteratively,
33 | model_name,
34 | tokenizer,
35 | prompt_template_idx,
36 | input_prefix,
37 | debug):
38 |
39 | mcReader = MultipleChoiceReader()
40 | createDataset_fn = lambda data: MultipleChoiceDataset(data, tokenizer, prompt_template_idx, input_prefix, device, world_size)
41 | batcher = Batcher(mcReader, createDataset_fn, train_batchSize=None, eval_batchSize=eval_batchSize)
42 |
43 | if world_size is None or dist.get_rank() == 0:
44 | scorer = Scorer("multiple_choice")
45 | output_dir = get_mulChoice_outputDir(mul_choice_filepath, model_name, ignore_pointwise_mutual_information, ignore_length_normalization)
46 | if debug:
47 | output_fp = os.path.join("exp_out", "multiple_choice", "debug.json")
48 | else:
49 | output_fp = os.path.join(output_dir, f"predictions-prompt_{prompt_template_idx}.json")
50 | if os.path.exists(output_fp):
51 | return
52 | prediction_logger = PredictionLogger(output_fp)
53 |
54 |
55 | for batch in tqdm.tqdm(batcher.get_mulChoiceBatches(mul_choice_filepath)):
56 | with torch.no_grad():
57 | # Uses deepspeed
58 | if world_size is not None:
59 | pred_choice, score_ofChoices, logProbs_ofAllChoicesIds, len_allChoices, logProbs_ofAllChoicesIds_condOnNullInput = deepspeed_engine.module.predict_mulChoice(batch,
60 | not ignore_pointwise_mutual_information,
61 | not ignore_length_normalization,
62 | compute_choices_iteratively)
63 | else:
64 | pred_choice, score_ofChoices, logProbs_ofAllChoicesIds, len_allChoices, logProbs_ofAllChoicesIds_condOnNullInput = model.predict_mulChoice(batch,
65 | not ignore_pointwise_mutual_information,
66 | not ignore_length_normalization,
67 | compute_choices_iteratively)
68 |
69 | batchOf_evalInfo = {
70 | "pred_choice": pred_choice,
71 | "score_of_choices": score_ofChoices,
72 | "log_probs_of_all_choices_ids": logProbs_ofAllChoicesIds,
73 | "len_all_choices": len_allChoices,
74 | "log_prob_of_all_choices_ids_cond_null_input": logProbs_ofAllChoicesIds_condOnNullInput if logProbs_ofAllChoicesIds_condOnNullInput is not None else [0 * len(logProbs_ofAllChoicesIds)],
75 | "input": batch["input"],
76 | "list_choices": batch["list_choices"],
77 | "lbl": batch["lbl"].cpu().numpy().tolist()
78 | }
79 |
80 | if world_size is not None:
81 | listOf_batchOf_evalInfo = [{}] * world_size
82 | dist.gather_object(
83 | batchOf_evalInfo,
84 | listOf_batchOf_evalInfo if dist.get_rank() == 0 else None,
85 | dst=0
86 | )
87 | if dist.get_rank() == 0:
88 | batchOf_evalInfo = reduce_gatheredOutput(listOf_batchOf_evalInfo)
89 |
90 | if world_size is None or dist.get_rank() == 0:
91 | whichDatapoints_correct = scorer.add_batch(batchOf_evalInfo)
92 | batchOf_evalInfo.update({
93 | "is_datapoint_correct": whichDatapoints_correct
94 | })
95 | prediction_logger.log_batch(batchOf_evalInfo)
96 |
97 | if not debug:
98 | if world_size is None or dist.get_rank() == 0:
99 | with open(os.path.join(output_dir, "scores.jsonl"), 'a+') as f_out:
100 | dict_score = scorer.get_score()
101 | dict_score.update({
102 | "pointwise_mutual_information": not ignore_pointwise_mutual_information,
103 | "length_normalization": not ignore_length_normalization,
104 | "dataset_filepath": mul_choice_filepath,
105 | "model": model_name,
106 | "prompt_template_idx": prompt_template_idx
107 | })
108 | f_out.write(json.dumps(dict_score) + '\n')
109 |
110 |
111 | def evaluate_mulChoice(args):
112 |
113 | # Uses deepspeed
114 | if args.world_size is not None:
115 | hugFace_config, tokenizer, input_prefix = construct_hugFace_objects(args.model_name, args.max_seq_len)
116 | if hasattr(hugFace_config, "d_model"):
117 | model_dim = hugFace_config.d_model
118 | elif hasattr(hugFace_config, "hidden_size"):
119 | model_dim = hugFace_config.hidden_size
120 | else:
121 | raise ValueError("Cannot get model dimension from hugging face config")
122 |
123 | deepspeed_config = get_deepspeedConfig(args.eval_batch_size, args.world_size, model_dim)
124 |
125 | dschf = HfDeepSpeedConfig(deepspeed_config)  # keep this object alive and create it before initializing the model
126 | model, _ = construct_models(args.model_name, args.use_hugFace_parallelism, args.use_bitsandbytes)
127 |
128 | deepspeed_engine = deepspeed.init_inference(model,
129 | mp_size=args.world_size,
130 | dtype=torch.float,
131 | replace_method='auto',
132 | replace_with_kernel_inject=True)
133 | deepspeed_engine.module.eval() # inference
134 | model = None
135 | else:
136 | hugFace_config, tokenizer, input_prefix = construct_hugFace_objects(args.model_name, args.max_seq_len)
137 | model, _ = construct_models(args.model_name, args.use_hugFace_parallelism, args.use_bitsandbytes)
138 |
139 | if not args.use_hugFace_parallelism and not args.use_bitsandbytes:
140 | model = model.to(args.device)
141 |
142 | model.eval()
143 | deepspeed_engine = None
144 |
145 | for mul_choice_filepath in args.mul_choice_filepath:
146 | if args.prompt_template_idx == -1:
147 | for prompt_template_idx in range(3):
148 | evaluate_dataset(mul_choice_filepath,
149 | deepspeed_engine,
150 | model,
151 | args.world_size,
152 | args.eval_batch_size,
153 | args.device,
154 | args.ignore_pointwise_mutual_information,
155 | args.ignore_length_normalization,
156 | args.compute_choices_iteratively,
157 | args.model_name,
158 | tokenizer,
159 | prompt_template_idx,
160 | input_prefix,
161 | args.debug)
162 | else:
163 | evaluate_dataset(mul_choice_filepath,
164 | deepspeed_engine,
165 | model,
166 | args.world_size,
167 | args.eval_batch_size,
168 | args.device,
169 | args.ignore_pointwise_mutual_information,
170 | args.ignore_length_normalization,
171 | args.compute_choices_iteratively,
172 | args.model_name,
173 | tokenizer,
174 | args.prompt_template_idx,
175 | input_prefix,
176 | args.debug)
177 | if __name__ == "__main__":
178 | parser = argparse.ArgumentParser()
179 | parser.add_argument('-f', "--mul_choice_filepath", action='store', type=str, nargs='*', required=True)
180 | parser.add_argument("--max_seq_len", type=int, default=512)
181 | parser.add_argument('-m', "--model_name", required=True)
182 | parser.add_argument("--use_deepspeed", action="store_true")
183 | parser.add_argument("--use_bitsandbytes", action="store_true")
184 | parser.add_argument("--use_hugFace_parallelism", action="store_true")
185 | parser.add_argument('-b', "--eval_batch_size", type=int, default=1)
186 | parser.add_argument('-p', "--prompt_template_idx", type=int, default=0)
187 | parser.add_argument("--debug", action="store_true")
188 | parser.add_argument("--local_rank", type=int, default=0)
189 | parser.add_argument('--ignore_pointwise_mutual_information',
190 | action="store_true",
191 | help="Whether to use the pointwise mutual information or regular log "
192 | "likelihood for scoring candidates")
193 | parser.add_argument('--ignore_length_normalization',
194 | action="store_true",
195 | help="Whether to use the whether to use length normalization when scoring the candidates ")
196 | parser.add_argument('--compute_choices_iteratively',
197 | action="store_true",
198 | help="Whether to use compute log probs of decoder choices together or iteratively")
199 | args = parser.parse_args()
200 |
201 | logging.basicConfig(level=logging.INFO)
202 | logging.info('Starting evaluate multiple choice')
203 |
204 | if args.use_deepspeed:
205 | logging.info('Using Deepspeed')
206 | # The device is the local_rank since it specifies the GPU to use.
207 | args.device = args.local_rank
208 | args.world_size = int(os.getenv('WORLD_SIZE', '1'))
209 | deepspeed.init_distributed()
210 | else:
211 | # This device is where the input_ids will be loaded.
212 | # It must be device 0 since HuggingFace parallelism moves the logits back to device 0 to compute the
213 | # loss with the input_ids
214 | args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
215 | args.world_size = None
216 |
217 | evaluate_mulChoice(args)
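
An illustrative invocation, using a dataset file from this repo and the flags defined above; with -p -1 the script cycles through all three prompt templates:

    python -m src.evaluate_mulChoice \
        -f multiple_choice-dataset/xsum/fib/binary_choice-using_bart-base_distractors.jsonl \
        -m facebook/opt-125m \
        -b 4 -p -1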
--------------------------------------------------------------------------------
/src/evaluate_mulChoice_test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import torch
4 | import logging
5 | import re
6 | from tqdm import tqdm
7 | import copy
8 | import json
9 |
10 |
11 | from transformers import AutoTokenizer, GPT2Tokenizer
12 |
13 | from src.multiple_choice.utils import read_jsonl, write_jsonl
14 | from src.data.preprocess_data import does_tokenizer_addBosEosTokens
15 |
16 | from src.models.model_flags import DICT_REGEX_OF_MODEL_TYPE, DICT_REGEX_OF_DEVICE_MAP
17 | from src.utils.util import get_value_from_key_matching_regex
18 |
19 |
20 | def compute_totalLogProb(logProb_ofIds):
21 | '''
22 | Assumes log probs will be zero for ids which are supposed to be masked out, i.e. pad tokens
23 |
24 | Args:
25 | logProb_ofIds: [batch_size, max_choice_len]
26 |
27 | Returns:
28 | Summed log prob per example: [batch_size]
29 | '''
30 | return torch.sum(logProb_ofIds, dim=1)
31 |
32 | def compute_avgLogProb(logProb_ofIds, max_len):
33 | '''
34 | Assumes log probs will be zero for ids which are supposed to be masked out, i.e. pad tokens
35 |
36 | Args:
37 | logProb_ofIds: [batch_size, max_choice_len]
38 | max_len: length used for normalization, e.g. the number of non-pad ids per example
39 | Returns:
40 | Length-normalized log prob per example: [batch_size]
41 | '''
42 | return torch.sum(logProb_ofIds, dim=1) / max_len
43 |
44 | def get_model(prediction_filepath):
45 | listOf_directories = prediction_filepath.split("/")
46 | return listOf_directories[-2]
47 |
48 | def check_scoreMatches_recomputingFromCache(predictionPrompt_filepath):
49 | list_json = read_jsonl(predictionPrompt_filepath)
50 |
51 | for prediction_json in list_json:
52 | logProb_ofAllChoiceIds = torch.tensor(prediction_json["log_probs_of_all_choices_ids"])
53 | score_ofChoices = torch.tensor(prediction_json["score_of_choices"])
54 | logProb_ofAllChoiceIds_condNullInput = torch.tensor(prediction_json["log_prob_of_all_choices_ids_cond_null_input"])
55 | len_allChoices = torch.tensor(prediction_json["len_all_choices"])
56 | pred_choice = prediction_json["pred_choice"]
57 |
58 | allChoices_logProb = compute_avgLogProb(logProb_ofAllChoiceIds, len_allChoices)
59 | allChoices_logProb_condNullInput = compute_avgLogProb(logProb_ofAllChoiceIds_condNullInput, len_allChoices)
60 | allChoices_logProb -= allChoices_logProb_condNullInput
61 |
62 | if not torch.allclose(allChoices_logProb, score_ofChoices, atol=1e-4):
63 | print(predictionPrompt_filepath)
64 | import ipdb; ipdb.set_trace()
65 |
66 | # Handle case where predicted probabilities are same for both choices, so
67 | # argmax might not be consistent
68 | if not torch.argmax(allChoices_logProb) == pred_choice and\
69 | not torch.allclose(allChoices_logProb[0], allChoices_logProb[1], atol=1e-4):
70 | print(predictionPrompt_filepath)
71 | import ipdb; ipdb.set_trace()
72 |
73 | dictOfModel_toDictOfInput_toGoldSummaryLogProb = {}
74 | def check_correctSummaryScoreMatches_acrossDifferentDistractors(predictionPrompt_filepath):
75 | list_json = read_jsonl(predictionPrompt_filepath)
76 |
77 | model = get_model(predictionPrompt_filepath)
78 |
79 | if model not in dictOfModel_toDictOfInput_toGoldSummaryLogProb:
80 | dictOfModel_toDictOfInput_toGoldSummaryLogProb[model] = {}
81 |
82 | for prediction_json in list_json:
83 | score_ofChoices = torch.tensor(prediction_json["score_of_choices"])
84 | correctChoice_logProb = score_ofChoices[prediction_json["lbl"]]
85 | input_txt = prediction_json["input"]
86 |
87 | if input_txt in dictOfModel_toDictOfInput_toGoldSummaryLogProb[model]:
88 | if not torch.allclose(dictOfModel_toDictOfInput_toGoldSummaryLogProb[model][input_txt][0], correctChoice_logProb, atol=1e-4):
89 | print(predictionPrompt_filepath)
90 | else:
91 | dictOfModel_toDictOfInput_toGoldSummaryLogProb[model][input_txt] = correctChoice_logProb, prediction_json["list_choices"][prediction_json["lbl"]]
92 |
93 | def check_accuraciesCorrect_andExistsForEachPrompt(mulChoice_experiment):
94 | '''
95 | Check that scores.jsonl exists with one accuracy per prompt template, and that
96 | each accuracy matches the accuracy recomputed from the prediction file.
97 |
98 |
99 | '''
100 | scores_filepath = os.path.join(mulChoice_experiment, "scores.jsonl")
101 |
102 | # Check there are 3 scores for 3 prompts
103 | if os.path.exists(scores_filepath):
104 | list_scoresJson = read_jsonl(scores_filepath)
105 | if len(list_scoresJson) != 3:
106 | # t5-large finetuned only use 1 prompt, so it should have 1 score
107 | assert len(list_scoresJson) == 1 and \
108 | ("sysresearch101-t5-large-finetuned-xsum" in mulChoice_experiment),\
109 | mulChoice_experiment
110 | else:
111 | print(scores_filepath)
112 | import ipdb; ipdb.set_trace()
113 |
114 | for score_json in list_scoresJson:
115 | prompt_idx = score_json["prompt_template_idx"]
116 | predictionPrompt_filepath = os.path.join(mulChoice_experiment, f"predictions-prompt_{prompt_idx}.json")
117 |
118 | list_json = read_jsonl(predictionPrompt_filepath)
119 | num_correct = 0
120 | for prediction_json in list_json:
121 | if prediction_json["pred_choice"] == prediction_json["lbl"]:
122 | num_correct += 1
123 |
124 | computed_acc = round(num_correct / len(list_json), 3)
125 | assert computed_acc == score_json["multiple-choice-accuracy"]
126 |
127 | def test_experiment(mulChoice_experiment):
128 | check_accuraciesCorrect_andExistsForEachPrompt(mulChoice_experiment)
129 |
130 | for prompt_idx in range(3):
131 | predictionPrompt_filepath = os.path.join(mulChoice_experiment, f"predictions-prompt_{prompt_idx}.json")
132 |
133 | if os.path.exists(predictionPrompt_filepath):
134 | check_scoreMatches_recomputingFromCache(predictionPrompt_filepath)
135 | check_correctSummaryScoreMatches_acrossDifferentDistractors(predictionPrompt_filepath)
136 |
137 |
138 | if __name__ == "__main__":
139 | parser = argparse.ArgumentParser()
140 | parser.add_argument('-e', "--list_mulChoiceExperiments", action='store', type=str, nargs='*', required=True)
141 | args = parser.parse_args()
142 |
143 | for mulChoice_experiment in tqdm(args.list_mulChoiceExperiments):
144 | test_experiment(mulChoice_experiment)
145 |
--------------------------------------------------------------------------------
/src/get_results.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import statistics
4 | import os
5 |
6 |
7 | def get_medianScore_perExperiment(modelWithScoringFunction_dir):
8 | score_fp = os.path.join(modelWithScoringFunction_dir, "scores.jsonl")
9 |
10 | dict_promptTemplateIdx_toMCAccuracy = {}
11 |
12 | with open(score_fp, "r") as f:
13 | for line in f.readlines():
14 | score_json = json.loads(line.strip("\n"))
15 | dict_promptTemplateIdx_toMCAccuracy[score_json["prompt_template_idx"]] = score_json["multiple-choice-accuracy"]
16 |
17 | return statistics.median(list(dict_promptTemplateIdx_toMCAccuracy.values()))
18 |
19 | def get_medianScore_perModel(model_dir):
20 | median_acc = get_medianScore_perExperiment(model_dir)
21 | return [median_acc]
22 |
23 | def get_medianScore_perDataset(dataset_dir, list_models):
24 | list_acc = []
25 |
26 | for model in list_models:
27 | if model is None:
28 | list_acc.extend([0] * 4)
29 | else:
30 | model_dir = os.path.join(dataset_dir, model)
31 | list_acc.extend(get_medianScore_perModel(model_dir))
32 |
33 | return list_acc
34 |
35 | def get_medianScore_acrossDatasets(datasets, list_models):
36 |
37 | print("Using the following datasets ... ")
38 | acc_acrossDataset = []
39 | for dataset in datasets:
40 | print(dataset)
41 | acc_perDataset = get_medianScore_perDataset(dataset, list_models)
42 | acc_acrossDataset.append(acc_perDataset)
43 |
44 | print("The median accuracy per model across different distractor models is ... ")
45 | for idx, acc_perModel in enumerate(list(map(list, zip(*acc_acrossDataset)))):
46 | formattedAcc_perModel = list(map(lambda x: str(round(x * 100, 3)), acc_perModel))
47 | print(list_models[idx] + ": " + ",".join(formattedAcc_perModel))
48 |
49 |
50 | if __name__ == "__main__":
51 | parser = argparse.ArgumentParser()
52 | parser.add_argument('-e', "--exp_dir_of_datasets", action='store', type=str, nargs='*', required=True)
53 | parser.add_argument('-m', "--list_models", action='store', type=str, nargs='*', required=True)
54 | args = parser.parse_args()
55 |
56 | get_medianScore_acrossDatasets(args.exp_dir_of_datasets, args.list_models)
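
A worked example of the per-experiment aggregation, with hypothetical accuracies: the median is taken over the prompt templates recorded in scores.jsonl.

    import statistics
    dict_promptTemplateIdx_toMCAccuracy = {0: 0.58, 1: 0.64, 2: 0.61}
    statistics.median(list(dict_promptTemplateIdx_toMCAccuracy.values()))  # -> 0.61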
--------------------------------------------------------------------------------
/src/models/DecoderWrappers_forMulChoice.py:
--------------------------------------------------------------------------------
1 | import re
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from src.utils.util import get_value_from_key_matching_regex
7 | from src.models.model_flags import DICT_REGEX_OF_WHETHER_MODEL_USES_POSITION_IDS
8 | from src.models.utils import compute_logProb
9 |
10 | class DecoderWrappers_forMulChoice(nn.Module):
11 | '''
12 | Wraps a decoder-only transformer to score multiple answer choices per input, caching the input once and reusing its key/values across choices.
13 | '''
14 |
15 | def __init__(self, transformer):
16 | super().__init__()
17 | self.transformer = transformer
18 |
19 | self.use_position_ids = get_value_from_key_matching_regex(DICT_REGEX_OF_WHETHER_MODEL_USES_POSITION_IDS, self.transformer._get_name().lower())
20 |
21 | if "gptneox" in self.transformer._get_name().lower():
22 | self.use_position_ids = False
23 | print("WARNING! NeoX has a bug with padding the input when caching key,values. Use a batch size of 1.")
24 | assert self.use_position_ids is not None
25 |
26 |
27 | def _broadcast_tensors(self, input_masks, past_key_values, num_choices):
28 | '''
29 | Broadcast the input masks and encoder outputs to account for multiple choices per input
30 |
31 | Args:
32 | input_masks: [batch_size, max_input_len]
33 | past_key_values: Tuple of keys and values for each layer.
34 | The first index of the tuple is the layer index, and the second index
35 | of the tuple is whether it is a key or value. Each element in tuple
36 | has shape [batch_size, max_input_len, num_heads, head_dim] or [batch_size x num_heads, head_dim, max_input_len].
37 | num_choices:
38 |
39 | Returns:
40 | input_masks: [batch_size x num_choices, max_input_len]
41 | past_key_values: Tuple of keys and values for each layer.
42 | The first index of the tuple is the layer index, and the second index
43 | of the tuple is whether it is a key or value. Each element in tuple
44 | has shape [batch_size x num_choices, max_input_len, num_heads, head_dim]
45 | or [batch_size x num_heads x num_choices, head_dim, max_input_len].
46 | '''
47 | batch_size, max_input_len = input_masks.shape
48 | input_masks = torch.repeat_interleave(input_masks, num_choices, dim=0)
49 |
50 | list_broadcast_pastKeyValues = []
51 | for pastKeyValues_perLayer in past_key_values:
52 |
53 | list_broadcast_pastKeyValues_perLayer = []
54 | for key_or_value in pastKeyValues_perLayer:
55 | # This is for keys or values which have dimension [batch_size, max_input_len, num_heads, head_dim]
56 | # This is the standard for Hugging Face.
57 | if len(key_or_value.shape) == 4:
58 | list_broadcast_pastKeyValues_perLayer.append(torch.repeat_interleave(key_or_value, num_choices, dim=0))
59 | # This is for keys or values which have dimension [batch_size x num_heads, head_dim, max_input_len].
60 | # This is what is used for BLOOM in transformers == 4.22.0
61 | elif len(key_or_value.shape) == 3:
62 | num_heads = key_or_value.shape[0] // batch_size
63 | flatten_keyOrValue = key_or_value.reshape(((batch_size, num_heads) + key_or_value.shape[1:]))
64 | broadcast_flatten_keyOrValue = torch.repeat_interleave(flatten_keyOrValue, num_choices, dim=0)
65 | list_broadcast_pastKeyValues_perLayer.append(broadcast_flatten_keyOrValue.flatten(0, 1))
66 | else:
67 | raise ValueError(f"Invalid cached key or value shape: {key_or_value.shape}")
68 |
69 | list_broadcast_pastKeyValues.append(tuple(list_broadcast_pastKeyValues_perLayer))
70 |
71 | return input_masks, tuple(list_broadcast_pastKeyValues)
72 |
73 | def compute_allChoices_logProb_fromDecoderOutput(self,
74 | input_masks,
75 | past_key_values,
76 | allChoices_ids,
77 | allChoices_masks,
78 | lengthNormalization):
79 | '''
80 |
81 | Args:
82 | input_masks: [batch_size, max_input_len]
83 | past_key_values: Tuple of keys and values for each layer.
84 | The first index of the tuple is the layer index, and the second index
85 | of the tuple is whether it is a key or value. Each element in tuple
86 | has shape [batch_size, max_input_len, num_heads, head_dim].
87 | allChoices_ids: [batch_size x num_choices, max_choice_len]
88 | allChoices_masks: [batch_size x num_choices, max_choice_len]
89 |
90 | Returns:
91 | logProbs_forAllChoices: [batch_size, num_choices]
92 | logProbs_forAllChoicesIds_zeroOutPadIds: [batch_size, num_choices, max_choice_len]
93 | len_allChoices: [batch_size ]
94 | '''
95 | num_choices = allChoices_ids.shape[0] // input_masks.shape[0]
96 | input_masks, past_key_values = self._broadcast_tensors(input_masks, past_key_values, num_choices)
97 |
98 | # Combine the input mask and choice mask so the model knows which cached input representations
99 | # are padded when conditioning on the cached input representations.
100 | # [batch_size x num_choices, max_input_len + max_choice_len]
101 | combined_mask = torch.cat([input_masks, allChoices_masks], dim=1)
102 |
103 | if self.use_position_ids:
104 | # Construct initial position ids solely based on choice lengths
105 | # [1, max_choice_len]
106 | allChoices_positionIds = torch.arange(0, allChoices_ids.shape[-1], dtype=torch.long, device=allChoices_ids.device)[None, :]
107 | input_len = torch.sum(input_masks, dim=1)[:, None]
108 | # Increment the position id to account for the input len.
109 | allChoices_positionIds = allChoices_positionIds + input_len
110 |
111 | # WARNING: The loss at transformer_outputs[0] is not valid, since allChoices_ids uses a
112 | # pad token of 0 and so the loss will not be ignored for the pad tokens
113 | transformer_outputs = self.transformer(input_ids=allChoices_ids,
114 | attention_mask=combined_mask,
115 | position_ids=allChoices_positionIds,
116 | past_key_values=past_key_values,
117 | use_cache=True,
118 | labels=allChoices_ids)
119 | else:
120 | # WARNING: The loss at transformer_outputs[0] is not valid, since allChoices_ids uses a
121 | # pad token of 0 and so the loss will not be ignored for the pad tokens
122 | transformer_outputs = self.transformer(input_ids=allChoices_ids,
123 | attention_mask=combined_mask,
124 | past_key_values=past_key_values,
125 | use_cache=True,
126 | labels=allChoices_ids)
127 |
128 |
129 | # We used the logits for all choices to compute the log probs per example since
130 | # the loss returned in transformer_outputs will average the negative log probs across
131 | # examples
132 | # [batch_size x num_choices, max_choice_len, vocab_size]
133 | logits_ofAllChoices = transformer_outputs[1].float()
134 |
135 | # Shift the ids, masks, logits to handle predicting the next token for the decoder.
136 | # Note that we need to pass in the input_ids and cannot rely on HuggingFace automatically
137 | # constructing the ids from the labels, since we need to pass in an attention mask to handle
138 | # the cached input representations.
139 | shiftedLogits_ofAllChoices = logits_ofAllChoices[..., :-1, :].contiguous()
140 | shiftedIds_ofAllChoices = allChoices_ids[..., 1:].contiguous()
141 | shiftedMasks_ofAllChoices = allChoices_masks[..., 1:].contiguous()
142 |
143 | maxChoice_len = shiftedLogits_ofAllChoices.shape[1]
144 | vocab_size = shiftedLogits_ofAllChoices.shape[-1]
145 |
146 | # Compute the log probability of the ids for all choices with respect to the logits
147 | # [batch_size x num_choices x (max_choice_len-1)]
148 | logProbs_forAllChoices_ids = - F.cross_entropy(shiftedLogits_ofAllChoices.view(-1, vocab_size),
149 | shiftedIds_ofAllChoices.view(-1),
150 | reduction="none")
151 |
152 | return compute_logProb(logProbs_forAllChoices_ids,
153 | shiftedMasks_ofAllChoices,
154 | num_choices,
155 | maxChoice_len,
156 | lengthNormalization)
157 |
158 | def compute_allChoices_logProb_fromDecoderOutput_iteratively(self,
159 | input_masks,
160 | past_key_values,
161 | allChoices_ids,
162 | allChoices_masks,
163 | lengthNormalization):
164 | '''
165 | Args:
166 | input_masks: [batch_size, max_input_len]
167 | past_key_values: Tuple of keys and values for each layer.
168 | The first index of the tuple is the layer index, and the second index
169 | of the tuple is whether it is a key or value. Each element in tuple
170 | has shape [batch_size, max_input_len, num_heads, head_dim].
171 | allChoices_ids: [batch_size x num_choices, max_choice_len]
172 | allChoices_masks: [batch_size x num_choices, max_choice_len]
173 | lengthNormalization:
174 | Returns:
175 | logProbs_forAllChoices: [batch_size, num_choices]
176 | logProbs_forAllChoicesIds_zeroOutPadIds: [batch_size, num_choices, max_choice_len]
177 | len_allChoices: [batch_size]
178 | '''
179 | batch_size = input_masks.shape[0]
180 | assert batch_size == 1, "Iterative scoring of the choices assumes a batch size of 1"
181 | num_choices = allChoices_ids.shape[0] // input_masks.shape[0]
182 |
183 | list_logProbs_ofAllChoices = []
184 | list_logProbs_ofAllChoicesIds_zeroOutPadIds = []
185 | list_lenAllChoices = []
186 |
187 | for choice_idx in range(num_choices):
188 | # [1, max_choice_len]
189 | curChoice_ids = allChoices_ids[choice_idx:choice_idx + 1, :]
190 | curChoice_mask = allChoices_masks[choice_idx:choice_idx + 1, :]
191 |
192 | # Remove pad tokens
193 | assert curChoice_mask.shape[0] == 1
194 | num_nonPadTokens = torch.sum(curChoice_mask)
195 | num_PadTokens = curChoice_mask.shape[1] - num_nonPadTokens
196 |
197 | curChoice_ids = curChoice_ids[:,:num_nonPadTokens]
198 | curChoice_mask = curChoice_mask[:,:num_nonPadTokens]
199 |
200 | assert curChoice_mask[0,-1] == 1
201 |
202 | # Combine the input mask and choice mask so the model knows which cached input representations
203 | # are padded when conditioning on the cached input representations.
204 | # [batch_size, max_input_len + max_choice_len]
205 | combined_mask = torch.cat([input_masks, curChoice_mask], dim=1)
206 |
207 | if self.use_position_ids:
208 | # Construct initial position ids solely based on choice lengths
209 | # [1, max_choice_len]
210 | curChoice_positionIds = torch.arange(0, curChoice_ids.shape[-1], dtype=torch.long,
211 | device=curChoice_ids.device)[None, :]
212 | input_len = torch.sum(input_masks, dim=1)[:, None]
213 | # Increment the position id to account for the input len.
214 | curChoice_positionIds = curChoice_positionIds + input_len
215 |
216 | # WARNING: The loss at transformer_outputs[0] is not valid, since allChoices_ids uses a
217 | # pad token of 0 and so the loss will not be ignored for the pad tokens
218 | transformer_outputs = self.transformer(input_ids=curChoice_ids,
219 | attention_mask=combined_mask,
220 | position_ids=curChoice_positionIds,
221 | past_key_values=past_key_values,
222 | use_cache=True,
223 | labels=curChoice_ids)
224 | else:
225 | # WARNING: The loss at transformer_outputs[0] is not valid, since allChoices_ids uses a
226 | # pad token of 0 and so the loss will not be ignored for the pad tokens
227 | transformer_outputs = self.transformer(input_ids=curChoice_ids,
228 | attention_mask=combined_mask,
229 | past_key_values=past_key_values,
230 | use_cache=True,
231 | labels=curChoice_ids)
232 |
233 | # We used the logits for all choices to compute the log probs per example since
234 | # the loss returned in transformer_outputs will average the negative log probs across
235 | # examples
236 | # [batch_size, max_choice_len, vocab_size]
237 | logits_ofCurChoice = transformer_outputs[1].float()
238 |
239 | # Shift the ids, masks, logits to handle predicting the next token for the decoder.
240 | # Note that we need to pass in the input_ids and cannot rely on HuggingFace automatically
241 | # constructing the ids from the labels, since we need to pass in an attention mask to handle
242 | # the cached input representations.
243 | shiftedLogits_ofCurChoice = logits_ofCurChoice[..., :-1, :].contiguous()
244 | shifted_curChoice_ids = curChoice_ids[..., 1:].contiguous()
245 | shifted_curChoice_mask = curChoice_mask[..., 1:].contiguous()
246 |
247 | maxChoice_len = shiftedLogits_ofCurChoice.shape[1]
248 | vocab_size = shiftedLogits_ofCurChoice.shape[-1]
249 |
250 | # Compute the log probability of the ids for all choices with respect to the logits
251 | # [batch_size x (max_choice_len-1)]
252 | logProbs_ofCurChoice_ids = - F.cross_entropy(shiftedLogits_ofCurChoice.view(-1, vocab_size),
253 | shifted_curChoice_ids.view(-1),
254 | reduction="none")
255 |
256 | # Compute the log probabilities of all the choices by averaging the log probabilities of
257 | # the ids and zeroing out the pad ids
258 | # [batch_size, (max_choice_len-1)]
259 | logProbs_ofCurChoice_ids = logProbs_ofCurChoice_ids.reshape(-1, maxChoice_len)
260 | shifted_curChoice_mask = shifted_curChoice_mask > 0
261 | logProbs_ofCurChoiceIds_zeroOutPadIds = logProbs_ofCurChoice_ids * shifted_curChoice_mask
262 |
263 | logProb_ofCurChoice = torch.sum(logProbs_ofCurChoiceIds_zeroOutPadIds, dim=1)
264 | len_curChoice = torch.sum(shifted_curChoice_mask, dim=1)
265 |
266 | if lengthNormalization:
267 | logProb_ofCurChoice = logProb_ofCurChoice / len_curChoice
268 |
269 | list_logProbs_ofAllChoices.append(logProb_ofCurChoice)
270 | list_logProbs_ofAllChoicesIds_zeroOutPadIds.append(torch.cat([
271 | logProbs_ofCurChoiceIds_zeroOutPadIds,
272 | torch.zeros((1, num_PadTokens)).to(logProbs_ofCurChoiceIds_zeroOutPadIds.device)
273 | ], dim=1))
274 | list_lenAllChoices.append(len_curChoice)
275 |
276 | # Since the batch size is 1, each per-choice result has a batch dimension of 1; stack along dim=1 to add back the num_choices dimension
277 | return torch.stack(list_logProbs_ofAllChoices, dim=1), \
278 | torch.stack(list_logProbs_ofAllChoicesIds_zeroOutPadIds, dim=1), \
279 | torch.stack(list_lenAllChoices, dim=1)
280 |
281 | def compute_allChoices_logProb(self,
282 | input_ids,
283 | input_masks,
284 | allChoices_ids,
285 | allChoices_masks,
286 | lengthNormalization,
287 | iterativelyComputeChoices):
288 | '''
289 |
290 |
291 | Args:
292 | input_ids: [batch_size, max_input_len]
293 | input_masks: [batch_size, max_input_len]
294 | allChoices_ids: [batch_size x num_choices, max_choice_len]
295 | allChoices_masks: [batch_size x num_choices, max_choice_len]
296 | lengthNormalization:
297 | iterativelyComputeChoices
298 |
299 | Returns:
300 | log_prob: [batch_size, num_choices]
301 | '''
302 | output = self.transformer(input_ids=input_ids, attention_mask=input_masks)
303 | past_key_values = output.past_key_values
304 |
305 | if iterativelyComputeChoices:
306 | return self.compute_allChoices_logProb_fromDecoderOutput_iteratively(input_masks,
307 | past_key_values,
308 | allChoices_ids,
309 | allChoices_masks,
310 | lengthNormalization)
311 | else:
312 | return self.compute_allChoices_logProb_fromDecoderOutput(input_masks,
313 | past_key_values,
314 | allChoices_ids,
315 | allChoices_masks,
316 | lengthNormalization)
317 |
318 | def predict_mulChoice(self, batch, pointMutualInfo, lengthNormalization, iterativelyComputeChoices):
319 | '''
320 |
321 | Args:
322 | batch:
323 | pointMutualInfo:
324 | lengthNormalization:
325 |
326 | Returns:
327 | pred_choice: [batch_size, ]
328 | score_ofChoices: [batch_size, num_choices]
329 | logProbs_ofAllChoicesIds: [batch_size, num_choices, max_choice_len]
330 | len_allChoices: [batch_size]
331 | logProbs_ofAllChoicesIds_condOnNullInput: [batch_size, num_choices, max_choice_len]
332 | '''
333 | # Compute log p(y|x)
334 | score_ofChoices, logProbs_ofAllChoicesIds, len_allChoices = self.compute_allChoices_logProb(
335 | batch["input_ids"],
336 | batch["input_masks"],
337 | batch["all_choices_ids"],
338 | batch["all_choices_masks"],
339 | lengthNormalization,
340 | iterativelyComputeChoices)
341 |
342 | logProbs_ofAllChoicesIds_condOnNullInput = None
343 |
344 | # For computing pointwise mutual information, we need to compute log p(y|x) - log p(y).
345 | # To compute p(y), we condition the choices on the null input.
346 | if pointMutualInfo:
347 | logProb_ofChoices_condOnNullInput, logProbs_ofAllChoicesIds_condOnNullInput, _ = self.compute_allChoices_logProb(
348 | batch["null_input_ids"],
349 | batch["null_input_masks"],
350 | batch["all_choices_ids"],
351 | batch["all_choices_masks"],
352 | lengthNormalization,
353 | iterativelyComputeChoices)
354 | score_ofChoices -= logProb_ofChoices_condOnNullInput
355 |
356 | _, pred_choice = torch.max(score_ofChoices, dim=1)
357 | return pred_choice.cpu().numpy().tolist(), \
358 | score_ofChoices.cpu().numpy().tolist(), \
359 | logProbs_ofAllChoicesIds.cpu().numpy().tolist(), \
360 | len_allChoices.cpu().numpy().tolist(), \
361 | logProbs_ofAllChoicesIds_condOnNullInput.cpu().numpy().tolist() if logProbs_ofAllChoicesIds_condOnNullInput is not None else None
362 |
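
A minimal numeric sketch of the PMI scoring in predict_mulChoice, with made-up log probs: each choice is scored as log p(y|x) - log p(y), where log p(y) comes from conditioning the choices on the null input.

    import torch
    logProb_condOnInput = torch.tensor([[-4.0, -6.0]])        # log p(y|x), one row per datapoint
    logProb_condOnNullInput = torch.tensor([[-5.0, -5.5]])    # log p(y)
    score_ofChoices = logProb_condOnInput - logProb_condOnNullInput  # tensor([[1.0, -0.5]])
    _, pred_choice = torch.max(score_ofChoices, dim=1)        # -> tensor([0])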
--------------------------------------------------------------------------------
/src/models/DecoderWrappers_forMulChoice_test.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from src.data.multiple_choice import MultipleChoiceDataset, MultipleChoiceReader
4 | from src.data.Batcher import Batcher
5 |
6 | from src.constructors import construct_hugFace_objects, construct_models
7 | from src.models.DecoderWrappers_forMulChoice import DecoderWrappers_forMulChoice
8 |
9 | class Test_DecoderWrappers(DecoderWrappers_forMulChoice):
10 | '''
11 | Test double for DecoderWrappers_forMulChoice that checks the batched log-prob computation against a slow per-choice reference built from the HuggingFace loss.
12 | '''
13 |
14 | def __init__(self, transformer):
15 | super().__init__(transformer)
16 | self.transformer = transformer
17 |
18 | def _test_broadcast_tensor(self, input_masks, past_key_values, num_choices):
19 | '''
20 | Test that when repeating the tensors by num_choices times, the repetitions will
21 | be in the same block.
22 |
23 | Args:
24 | input_masks: [batch_size, max_input_len]
25 | past_key_values: Tuple of keys and values for each layer.
26 | The first index of the tuple is the layer index, and the second index
27 | of the tuple is whether it is a key or value. Each element in tuple
28 | has shape [batch_size, max_input_len, num_heads, head_dim] or [batch_size x num_heads, head_dim, max_input_len].
29 | num_choices:
30 | '''
31 | new_inputMask, new_pastKeyValues = super()._broadcast_tensors(input_masks, past_key_values, num_choices)
32 |
33 | batch_size = input_masks.shape[0]
34 | for i in range(batch_size):
35 | assert torch.equal(input_masks[i:i+1].repeat(num_choices, 1) ,
36 | new_inputMask[i*num_choices:(i+1)*num_choices, ]), \
37 | f"Test of broadcasting input_masks failed."
38 | for old_keyValues_perLayer, new_keyValues_perLayer in zip(past_key_values, new_pastKeyValues):
39 | for old_keyOrValue, new_keyOrValue in zip(old_keyValues_perLayer, new_keyValues_perLayer):
40 | batch_size = old_keyOrValue.shape[0]
41 | for i in range(batch_size):
42 | num_heads = past_key_values[0][0].shape[0] // input_masks.shape[0]
43 |
44 | # This means the keys or values are of shape [batch_size, max_input_len, num_heads, head_dim]
45 | if num_heads == 1:
46 | assert torch.equal(old_keyOrValue[i:i + 1].repeat(num_choices, 1, 1, 1),
47 | new_keyOrValue[i * num_choices:(i + 1) * num_choices]), \
48 | f"Test of broadcasting key,values failed."
49 | # This means the keys and values are of shape [batch_size x num_heads, head_dim, max_input_len].
50 | else:
51 | assert torch.equal(old_keyOrValue[i * num_heads : (i + 1) * num_heads].repeat((num_choices, 1, 1)),
52 | new_keyOrValue[i * num_heads * num_choices : (i + 1) * num_heads * num_choices]), \
53 | f"Test of broadcasting key,values failed."
54 |
55 |
56 | def test_compute_allChoices_logProb_fromDecoderOutput(self,
57 | input_ids,
58 | input_masks,
59 | past_key_values,
60 | allChoices_ids,
61 | allChoices_masks,
62 | allChoices_lbls):
63 | '''
64 |
65 | Args:
66 | input_ids: [batch_size, max_input_len]
67 | input_masks: [batch_size, max_input_len]
68 | past_key_values: Tuple of keys and values for each layer.
69 | The first index of the tuple is the layer index, and the second index
70 | of the tuple is whether it is a key or value. Each element in tuple
71 | has shape [batch_size, max_input_len, num_heads, head_dim] or [batch_size x num_heads, head_dim, max_input_len].
72 | allChoices_ids: [batch_size x num_choices, max_choice_len]
73 | allChoices_masks: [batch_size x num_choices, max_choice_len]
74 | allChoices_lbls: [batch_size x num_choices, max_choice_len]
75 |
76 | Returns:
77 |
78 | '''
79 | num_choices = allChoices_lbls.shape[0] // input_masks.shape[0]
80 | batch_size = input_masks.shape[0]
81 | self._test_broadcast_tensor(input_masks, past_key_values, num_choices)
82 |
83 | # Iterate over every datapoint and every choice to compute the log prob using the loss
84 | # returned from HuggingFace. Since HuggingFace averages the loss per batch,
85 | # we use batch_size=1 to get the log prob for each choice of each datapoint.
86 | listOf_logProb = []
87 | for datapoint_idx in range(batch_size):
88 | datapoint_ids = input_ids[datapoint_idx:datapoint_idx + 1]
89 | datapoint_mask = input_masks[datapoint_idx:datapoint_idx + 1]
90 |
91 | for choice_idx in range(num_choices):
92 | choiceLbls_idx = datapoint_idx*num_choices + choice_idx
93 | choice_lbls = allChoices_lbls[choiceLbls_idx:choiceLbls_idx+1]
94 | choice_ids = allChoices_ids[choiceLbls_idx:choiceLbls_idx+1]
95 | choice_mask = allChoices_masks[choiceLbls_idx:choiceLbls_idx+1]
96 |
97 | # Note the batch size is 1. Have to filter the datapoint_ids
98 | # to remove the pad ids in between the datapoint and the choice when we combined them.
99 | datapoint_len = torch.sum(datapoint_mask)
100 | filtered_datapointIds = datapoint_ids[:,:datapoint_len]
101 | combined_ids = torch.cat([filtered_datapointIds, choice_ids], dim=1)
102 | combined_mask = torch.cat([datapoint_mask[:,:datapoint_len], choice_mask], dim=1)
103 |
104 | # We want to ignore the loss for the datapoint and only compute the loss for the choices.
105 | datapoint_lbls = torch.ones_like(filtered_datapointIds).to(datapoint_ids.device) * -100
106 | # We ignore the first token in choice labels since HuggingFace will shift the labels
107 | # one over to the left, but since we concatenate the datapoint labels and choice labels,
108 | # the first choice id will not be shifted over.
109 | choice_lbls[:,0] = -100
110 | combined_lbls = torch.cat([datapoint_lbls, choice_lbls], dim=1)
111 | transformer_outputs = self.transformer(input_ids=combined_ids,
112 | attention_mask=combined_mask,
113 | labels=combined_lbls)
114 | choice_logProb = - transformer_outputs[0]
115 | listOf_logProb.append(choice_logProb)
116 |
117 | logProb_forAllChoices = torch.stack(listOf_logProb, dim=0).reshape(batch_size, num_choices)
118 |
119 | assert torch.allclose(logProb_forAllChoices,
120 | super().compute_allChoices_logProb_fromDecoderOutput(
121 | input_masks,
122 | past_key_values,
123 | allChoices_ids,
124 | allChoices_masks,
125 | True)[0],
126 | atol=1e-4), \
127 | "Test of computing log probs from decoder output failed."
128 |
129 | def test_compute_allChoices_logProb_fromDecoderOutput_iteratively(self,
130 | input_masks,
131 | past_key_values,
132 | allChoices_ids,
133 | allChoices_masks):
134 | '''
135 | Args:
136 | input_masks: [batch_size, max_input_len]
137 | past_key_values: Tuple of keys and values for each layer.
138 | The first index of the tuple is the layer index, and the second index
139 | of the tuple is whether it is a key or value. Each element in tuple
140 | has shape [batch_size, max_input_len, num_heads, head_dim] or [batch_size x num_heads, head_dim, max_input_len].
141 | allChoices_ids: [batch_size x num_choices, max_choice_len]
142 | allChoices_masks: [batch_size x num_choices, max_choice_len]
143 | Returns:
144 | '''
145 | assert torch.allclose(super().compute_allChoices_logProb_fromDecoderOutput_iteratively(
146 | input_masks,
147 | past_key_values,
148 | allChoices_ids,
149 | allChoices_masks,
150 | True)[0],
151 | super().compute_allChoices_logProb_fromDecoderOutput(
152 | input_masks,
153 | past_key_values,
154 | allChoices_ids,
155 | allChoices_masks,
156 | True)[0],
157 | atol=1e-4), \
158 | "Test of computing log probs from decoder output failed."
159 | def test_compute_allChoices_logProb(self,
160 | input_ids,
161 | input_masks,
162 | allChoices_ids,
163 | allChoices_masks,
164 | allChoices_lbls):
165 | '''
166 |
167 |
168 | Args:
169 | input_ids: [batch_size, max_input_len]
170 | input_masks: [batch_size, max_input_len]
171 | allChoices_ids: [batch_size x num_choices, max_choice_len]
172 | allChoices_lbls: [batch_size x num_choices, max_choice_len]
173 |
174 | Returns:
175 | log_prob: [batch_size x num_choices, max_choice_len]
176 | '''
177 | output = self.transformer(input_ids=input_ids, attention_mask=input_masks)
178 | past_key_values = output.past_key_values
179 |
180 | self.test_compute_allChoices_logProb_fromDecoderOutput(input_ids,
181 | input_masks,
182 | past_key_values,
183 | allChoices_ids,
184 | allChoices_masks,
185 | allChoices_lbls)
186 |
187 | self.test_compute_allChoices_logProb_fromDecoderOutput_iteratively(input_masks,
188 | past_key_values,
189 | allChoices_ids,
190 | allChoices_masks)
191 |
192 | def test_predict_mulChoice(self, batch):
193 | '''
194 |
195 | Args:
196 | batch:
197 |
198 | Returns:
199 | None; asserts internally that the log-prob computations are consistent.
200 |
201 |
202 | '''
203 |
204 | # Compute log p(y|x)
205 | self.test_compute_allChoices_logProb(
206 | batch["input_ids"],
207 | batch["input_masks"],
208 | batch["all_choices_ids"],
209 | batch["all_choices_masks"],
210 | batch["all_choices_lbls"])
211 |
212 |
213 | if __name__ == "__main__":
214 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
215 |
216 | # This only tests all models for 6 batches.
217 | for model_name in ["bigscience/bloom-560m", "gpt2", "facebook/opt-125m"]:
218 | hugFace_config, tokenizer, input_prefix = construct_hugFace_objects(model_name, 512)
219 | _, transformer = construct_models(model_name, False, False)
220 |
221 | model = Test_DecoderWrappers(transformer).to(device)
222 | model.eval()
223 |
224 | mcReader = MultipleChoiceReader()
225 | createDataset_fn = lambda data: MultipleChoiceDataset(data, tokenizer, 0, input_prefix, device, world_size=None)
226 | batcher = Batcher(mcReader, createDataset_fn, train_batchSize=None, eval_batchSize=1)
227 |
228 | for i, batch in enumerate(batcher.get_mulChoiceBatches("multiple_choice-dataset/xsum/random_distractors/binary_choice-using_random_distractors.jsonl")):
229 | with torch.no_grad():
230 | model.test_predict_mulChoice(batch)
231 | if i > 4:
232 | break
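
The block ordering that the broadcast test above relies on, in miniature: torch.repeat_interleave keeps all copies of a datapoint contiguous, so choice j of datapoint i lands at row i * num_choices + j.

    import torch
    x = torch.tensor([[1], [2]])          # batch_size = 2
    torch.repeat_interleave(x, 3, dim=0)  # -> [[1], [1], [1], [2], [2], [2]]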
--------------------------------------------------------------------------------
/src/models/EncoderDecoderWrappers_forMulChoice.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from src.models.utils import compute_logProb
6 |
7 |
8 | class EncoderDecoderWrappers_forMulChoice(nn.Module):
9 | '''
10 | Wraps an encoder-decoder transformer to score multiple answer choices per input, reusing the encoder outputs across choices.
11 | '''
12 |
13 | def __init__(self, transformer):
14 | super().__init__()
15 | self.transformer = transformer
16 |
17 | def _broadcast_tensors(self, input_masks, encoder_outputs, num_choices):
18 | '''
19 | Broadcast the input masks and encoder outputs to account for multiple choices per input
20 |
21 | Args:
22 | input_masks: [batch_size, max_input_len]
23 | encoder_outputs: BaseModelOutput object from HuggingFace where the first element is
24 | the hidden states of the encoder at the last layer
25 | [batch_size, max_input_len, ff_dim]
26 | num_choices: number of answer choices per input
27 |
28 | Returns:
29 | input_masks: [batch_size x num_choices, max_input_len]
30 | encoder_outputs: BaseModelOutput object from HuggingFace where the first element is
31 | the hidden states of the encoder at the last layer
32 | [batch_size x num_choices, max_input_len, ff_dim]
33 | '''
34 | input_masks = torch.repeat_interleave(input_masks, num_choices, dim=0)
35 | encoder_outputs = (torch.repeat_interleave(encoder_outputs[0], num_choices, dim=0), )
36 | return input_masks, encoder_outputs
37 |
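# Editor's note (not in the original file): repeat_interleave keeps the num_choices
# copies of each example contiguous, which the per-example reshape in
# src/models/utils.py relies on. For instance:
#   >>> torch.repeat_interleave(torch.tensor([[1], [2]]), 2, dim=0)
#   tensor([[1], [1], [2], [2]])   # copies stay blocked per example
#   >>> torch.tensor([[1], [2]]).repeat(2, 1)
#   tensor([[1], [2], [1], [2]])   # tiling would interleave the examples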
38 | def compute_allChoices_logProb_fromEncoderOutput(self,
39 | input_masks,
40 | encoder_outputs,
41 | allChoices_ids,
42 | allChoices_masks,
43 | lengthNormalization):
44 | '''
45 | Computes the log probability of each choice given the encoder output.
46 | Args:
47 | input_masks: [batch_size, max_input_len]
48 | encoder_outputs: BaseModelOutput object from HuggingFace where the first element is
49 | the hidden states of the encoder at the last layer
50 | [batch_size, max_input_len, ff_dim]
51 | allChoices_ids: [batch_size x num_choices, max_choice_len]
52 | allChoices_masks: [batch_size x num_choices, max_choice_len]
53 | lengthNormalization: whether to divide each choice's log prob by its length
54 |
55 | Returns:
56 | logProbs_forAllChoices: [batch_size, num_choices]
57 | logProbs_forAllChoicesIds_zeroOutPadIds: [batch_size, num_choices, max_choice_len]
58 | '''
59 | assert allChoices_ids.shape[0] % input_masks.shape[0] == 0, \
60 | f"The batch size {allChoices_ids.shape[0]} of allChoices_ids is not a multiple of " \
61 | f"the batch size {input_masks.shape[0]} of input_masks"
62 | num_choices = allChoices_ids.shape[0] // input_masks.shape[0]
63 |
64 | input_masks, encoder_outputs = self._broadcast_tensors(input_masks, encoder_outputs, num_choices)
65 |
66 | # WARNING: The loss at transformer_outputs[0] is not valid: allChoices_ids uses 0 as the
67 | # pad token, so the loss is not ignored at pad positions.
68 | # The input mask is passed in for the encoder-decoder cross-attention.
69 | transformer_outputs = self.transformer(attention_mask=input_masks,
70 | encoder_outputs=encoder_outputs,
71 | labels=allChoices_ids)
72 |
73 | # We use the logits for all choices to compute the log probs per example, since
74 | # the loss returned in transformer_outputs averages the negative log probs across
75 | # examples.
76 | # [batch_size x num_choices, max_choice_len, vocab_size]
77 | logits_ofAllChoices = transformer_outputs[1].float()
78 | maxChoice_len = logits_ofAllChoices.shape[1]
79 | vocab_size = logits_ofAllChoices.shape[-1]
80 |
81 | # Compute the log probability of the ids for all choices with respect to the logits
82 | # [batch_size x num_choices x max_choice_len]
83 | logProbs_ofAllChoices_ids = - F.cross_entropy(logits_ofAllChoices.view(-1, vocab_size),
84 | allChoices_ids.view(-1),
85 | reduction="none")
86 |
87 | return compute_logProb(logProbs_ofAllChoices_ids,
88 | allChoices_masks,
89 | num_choices,
90 | maxChoice_len,
91 | lengthNormalization)
92 |
93 |
94 | def compute_allChoices_logProb(self,
95 | input_ids,
96 | input_masks,
97 | allChoices_ids,
98 | allChoices_masks,
99 | lengthNormalization):
100 | '''
101 | Computes the log probability of each choice given the input.
102 |
103 | Args:
104 | input_ids: [batch_size, max_input_len]
105 | input_masks: [batch_size, max_input_len]
106 | allChoices_ids: [batch_size x num_choices, max_choice_len]
107 | allChoices_masks: [batch_size x num_choices, max_choice_len]
108 | lengthNormalization: whether to divide each choice's log prob by its length
109 |
110 | Returns:
111 | the tuple from compute_logProb: per-choice log probs, per-id log probs, choice lengths
112 | '''
113 | # Search for encoder function
114 | if hasattr(self.transformer, "encoder"):
115 | encoder_outputs = self.transformer.encoder(input_ids, input_masks)
116 | elif hasattr(self.transformer, "model") and hasattr(self.transformer.model, "encoder"):
117 | encoder_outputs = self.transformer.model.encoder(input_ids, input_masks)
118 | else:
119 | raise ValueError("Cannot find encoder function in transformer")
120 |
121 | return self.compute_allChoices_logProb_fromEncoderOutput(input_masks,
122 | encoder_outputs,
123 | allChoices_ids,
124 | allChoices_masks,
125 | lengthNormalization)
126 |
127 | def predict_mulChoice(self, batch, pointMutualInfo, lengthNormalization, iterativelyComputeChoices):
128 | '''
129 | Predicts the highest-scoring choice for each example in the batch.
130 | Args:
131 | batch: dictionary of input and choice tensors produced by the Batcher
132 | pointMutualInfo: whether to rescore choices with pointwise mutual information
133 | lengthNormalization: whether to divide each choice's log prob by its length
134 | iterativelyComputeChoices: Not used. Added to be consistent with DecoderWrappers_forMulChoice
135 |
136 | Returns:
137 | pred_choice: [batch_size, ]
138 | score_ofChoices: [batch_size, num_choices]
139 | logProbs_ofAllChoicesIds: [batch_size, num_choices, max_choice_len]
140 | len_allChoices: [batch_size, num_choices]
141 | logProbs_ofAllChoicesIds_condOnNullInput: [batch_size, num_choices, max_choice_len]
142 | '''
143 | # Compute log p(y|x)
144 | score_ofChoices, logProbs_ofAllChoicesIds, len_allChoices = self.compute_allChoices_logProb(
145 | batch["input_ids"],
146 | batch["input_masks"],
147 | batch["all_choices_ids"],
148 | batch["all_choices_masks"],
149 | lengthNormalization)
150 |
151 | logProbs_ofAllChoicesIds_condOnNullInput = None
152 |
153 | # For computing pointwise mutual information, we need to compute log p(y|x) - log p(y).
154 | # To compute p(y), we condition the choices on the null input.
155 | if pointMutualInfo:
156 | logProb_ofChoices_condOnNullInput, logProbs_ofAllChoicesIds_condOnNullInput, _ = self.compute_allChoices_logProb(
157 | batch["null_input_ids"],
158 | batch["null_input_masks"],
159 | batch["all_choices_ids"],
160 | batch["all_choices_masks"],
161 | lengthNormalization)
162 | score_ofChoices -= logProb_ofChoices_condOnNullInput
163 |
164 | _, pred_choice = torch.max(score_ofChoices, dim=1)
165 |
166 | return pred_choice.cpu().numpy().tolist(), \
167 | score_ofChoices.cpu().numpy().tolist(), \
168 | logProbs_ofAllChoicesIds.cpu().numpy().tolist(), \
169 | len_allChoices.cpu().numpy().tolist(), \
170 | logProbs_ofAllChoicesIds_condOnNullInput.cpu().numpy().tolist() if logProbs_ofAllChoicesIds_condOnNullInput is not None else None
171 |
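# Editor's sketch (not part of the repository): a toy illustration of how the PMI
# rescoring above, score = log p(y|x) - log p(y), can flip a prediction. The numbers
# are made up:
#
#   import torch
#   log_p_y_given_x = torch.tensor([[-2.0, -2.5]])  # two choices for one example
#   log_p_y = torch.tensor([[-1.0, -4.0]])          # from conditioning on the null input
#
#   torch.max(log_p_y_given_x, dim=1).indices            # -> choice 0
#   torch.max(log_p_y_given_x - log_p_y, dim=1).indices  # -> choice 1
#
# PMI discounts choices that are probable regardless of the input: choice 0 is
# generically likely (high log p(y)), so PMI prefers choice 1.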
--------------------------------------------------------------------------------
/src/models/EncoderDecoderWrappers_forMulChoice_test.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from src.data.multiple_choice import MultipleChoiceDataset, MultipleChoiceReader
4 | from src.data.Batcher import Batcher
5 |
6 | from src.constructors import construct_hugFace_objects, construct_models
7 |
8 | from src.models.EncoderDecoderWrappers_forMulChoice import EncoderDecoderWrappers_forMulChoice
9 |
10 | class Test_EncoderDecoderWrappers(EncoderDecoderWrappers_forMulChoice):
11 | '''
12 | Tests EncoderDecoderWrappers_forMulChoice against per-example reference computations.
13 | '''
14 |
15 | def __init__(self, transformer):
16 | super().__init__(transformer)
17 | self.transformer = transformer
18 |
19 | def _test_broadcast_tensor(self, input_mask, encoder_outputs, num_choices):
20 | '''
21 | Test that when repeating the tensors by num_choices times, the repetitions will
22 | be in the same block.
23 |
24 | Args:
25 | input_mask: [batch_size, max_input_len]
26 | encoder_outputs: BaseModelOutput object from HuggingFace where the first element is
27 | the hidden states of the encoder at the last layer
28 | [batch_size, max_input_len, ff_dim]
29 | num_choices: number of answer choices per input
30 | '''
31 | new_inputMask, new_encoderOutputs = \
32 | super()._broadcast_tensors(input_mask, encoder_outputs, num_choices)
33 |
34 | batch_size = input_mask.shape[0]
35 | for i in range(batch_size):
36 | assert torch.equal(input_mask[i:i+1].repeat(num_choices, 1),
37 | new_inputMask[i*num_choices:(i+1)*num_choices]), \
38 | "Test of broadcasting input_mask failed."
39 | assert torch.equal(encoder_outputs[0][i:i+1].repeat(num_choices, 1, 1),
40 | new_encoderOutputs[0][i*num_choices:(i+1)*num_choices]), \
41 | "Test of broadcasting encoder_outputs failed."
42 |
43 |
44 |
45 | def test_compute_allChoices_logProb_fromEncoderOutput(self,
46 | input_ids,
47 | input_masks,
48 | encoder_outputs,
49 | allChoices_ids,
50 | allChoices_masks,
51 | allChoices_lbls):
52 | '''
53 | Tests the batched log-prob computation against a per-example reference.
54 | Args:
55 | input_ids: [batch_size, max_input_len]
56 | input_masks: [batch_size, max_input_len]
57 | encoder_outputs: BaseModelOutput object from HuggingFace where the first element is
58 | the hidden states of the encoder at the last layer
59 | [batch_size, max_input_len, ff_dim]
60 | allChoices_ids: [batch_size x num_choices, max_choice_len]
61 | allChoices_masks: [batch_size x num_choices, max_choice_len]
62 | allChoices_lbls: [batch_size x num_choices, max_choice_len]
63 |
64 | Returns:
65 | None; raises an AssertionError if the log probs do not match.
66 | '''
67 | num_choices = allChoices_lbls.shape[0] // input_masks.shape[0]
68 | batch_size = input_masks.shape[0]
69 | self._test_broadcast_tensor(input_masks, encoder_outputs, num_choices)
70 |
71 | # Iterate over every datapoint and every choice to compute the log prob using the loss
72 | # returned from HuggingFace. Since HuggingFace averages the loss per batch,
73 | # we use batch_size=1 to get the log prob for each choice of each datapoint.
74 | listOf_logProb = []
75 | for datapoint_idx in range(batch_size):
76 | datapoint_ids = input_ids[datapoint_idx:datapoint_idx + 1]
77 | datapoint_mask = input_masks[datapoint_idx:datapoint_idx + 1]
78 |
79 | for choice_idx in range(num_choices):
80 | choiceLbls_idx = datapoint_idx*num_choices + choice_idx
81 | choice_lbls = allChoices_lbls[choiceLbls_idx:choiceLbls_idx+1]
82 | transformer_outputs = self.transformer(input_ids=datapoint_ids,
83 | attention_mask=datapoint_mask,
84 | labels=choice_lbls,
85 | output_hidden_states=True)
86 | choice_logProb = - transformer_outputs[0]
87 | listOf_logProb.append(choice_logProb)
88 |
89 | logProb_forAllChoices = torch.stack(listOf_logProb, dim=0).reshape(batch_size, num_choices)
90 |
91 | assert torch.allclose(logProb_forAllChoices,
92 | super().compute_allChoices_logProb_fromEncoderOutput(
93 | input_masks,
94 | encoder_outputs,
95 | allChoices_ids,
96 | allChoices_masks,
97 | True)[0],
98 | atol=1e-4), \
99 | "Test of computing log probs from encoder output failed."
100 |
101 | def test_compute_allChoices_logProb(self,
102 | input_ids,
103 | input_masks,
104 | allChoices_ids,
105 | allChoices_masks,
106 | allChoices_lbls):
107 | '''
108 | Tests computing the log probabilities of all choices from the encoder output.
109 |
110 | Args:
111 | input_ids: [batch_size, max_input_len]
112 | input_masks: [batch_size, max_input_len]
113 | allChoices_ids: [batch_size x num_choices, max_choice_len]
114 | allChoices_masks, allChoices_lbls: [batch_size x num_choices, max_choice_len]
115 |
116 | Returns:
117 | None; the called test raises an AssertionError on failure.
118 | '''
119 | # Search for encoder function
120 | if hasattr(self.transformer, "encoder"):
121 | encoder_outputs = self.transformer.encoder(input_ids, input_masks)
122 | elif hasattr(self.transformer, "model") and hasattr(self.transformer.model, "encoder"):
123 | encoder_outputs = self.transformer.model.encoder(input_ids, input_masks)
124 | else:
125 | raise ValueError("Cannot find encoder function in transformer")
126 |
127 | self.test_compute_allChoices_logProb_fromEncoderOutput(input_ids,
128 | input_masks,
129 | encoder_outputs,
130 | allChoices_ids,
131 | allChoices_masks,
132 | allChoices_lbls)
133 |
134 | def test_predict_mulChoice(self, batch):
135 | '''
136 | Runs the multiple-choice prediction test on one batch.
137 |
138 | Args:
139 | batch: dictionary of input and choice tensors produced by the Batcher
140 |
141 | Returns:
142 | None; the underlying log-prob tests raise an AssertionError on failure.
143 |
144 | '''
145 | # Compute log p(y|x)
146 | self.test_compute_allChoices_logProb(
147 | batch["input_ids"],
148 | batch["input_masks"],
149 | batch["all_choices_ids"],
150 | batch["all_choices_masks"],
151 | batch["all_choices_lbls"])
152 |
153 |
154 | if __name__ == "__main__":
155 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
156 |
157 | # Only test each model on the first 6 batches.
158 | for model_name in ["Frederick0291/t5-small-finetuned-xsum", "facebook/bart-large-xsum"]:
159 | hugFace_config, tokenizer, input_prefix = construct_hugFace_objects(model_name, 512)
160 | _, transformer = construct_models(model_name, False, False)
161 |
162 | model = Test_EncoderDecoderWrappers(transformer).to(device)
163 | model.eval()
164 |
165 | mcReader = MultipleChoiceReader()
166 | createDataset_fn = lambda data: MultipleChoiceDataset(data, tokenizer, 0, input_prefix, device, world_size=None)
167 | batcher = Batcher(mcReader, createDataset_fn, train_batchSize=None, eval_batchSize=2)
168 |
169 | for i, batch in enumerate(batcher.get_mulChoiceBatches("multiple_choice-dataset/xsum/random_distractors/binary_choice-using_random_distractors.jsonl")):
170 | with torch.no_grad():
171 | model.test_predict_mulChoice(batch)
172 | if i > 4:
173 | break
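# Editor's sketch (not part of the repository): the per-example reference above relies
# on the fact that, with batch_size=1, the HuggingFace seq2seq loss is the mean
# cross-entropy over the label tokens, so -loss is the length-normalized log
# probability of that choice. A standalone version, assuming "t5-small":
#
#   import torch
#   from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
#
#   tok = AutoTokenizer.from_pretrained("t5-small")
#   lm = AutoModelForSeq2SeqLM.from_pretrained("t5-small").eval()
#   enc = tok("summarize: the cat sat on the mat", return_tensors="pt")
#   labels = tok("the cat sat", return_tensors="pt").input_ids
#
#   with torch.no_grad():
#       logProb_perToken = -lm(**enc, labels=labels).loss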
--------------------------------------------------------------------------------
/src/models/device_maps.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | BLOOM_DEVICE_MAP = {'transformer.word_embeddings': 0,
4 | 'transformer.word_embeddings_layernorm': 0,
5 | 'transformer.h.0': 0,
6 | 'transformer.h.1': 0,
7 | 'transformer.h.2': 0,
8 | 'transformer.h.3': 0,
9 | 'transformer.h.4': 0,
10 | 'transformer.h.5': 0,
11 | 'transformer.h.6': 0,
12 | 'transformer.h.7': 0,
13 | 'transformer.h.8': 0,
14 | 'transformer.h.9': 0,
15 | 'transformer.h.10': 0,
16 | 'transformer.h.11': 0,
17 | 'transformer.h.12': 0,
18 | 'transformer.h.13': 0,
19 | 'transformer.h.14': 0,
20 | 'transformer.h.15': 0,
21 | 'transformer.h.16': 1,
22 | 'transformer.h.17': 1,
23 | 'transformer.h.18': 1,
24 | 'transformer.h.19': 1,
25 | 'transformer.h.20': 1,
26 | 'transformer.h.21': 1,
27 | 'transformer.h.22': 1,
28 | 'transformer.h.23': 1,
29 | 'transformer.h.24': 1,
30 | 'transformer.h.25': 1,
31 | 'transformer.h.26': 1,
32 | 'transformer.h.27': 1,
33 | 'transformer.h.28': 1,
34 | 'transformer.h.29': 1,
35 | 'transformer.h.30': 1,
36 | 'transformer.h.31': 1,
37 | 'transformer.h.32': 1,
38 | 'transformer.h.33': 1,
39 | 'transformer.h.34': 1,
40 | 'transformer.h.35': 2,
41 | 'transformer.h.36': 2,
42 | 'transformer.h.37': 2,
43 | 'transformer.h.38': 2,
44 | 'transformer.h.39': 2,
45 | 'transformer.h.40': 2,
46 | 'transformer.h.41': 2,
47 | 'transformer.h.42': 2,
48 | 'transformer.h.43': 2,
49 | 'transformer.h.44': 2,
50 | 'transformer.h.45': 2,
51 | 'transformer.h.46': 2,
52 | 'transformer.h.47': 2,
53 | 'transformer.h.48': 2,
54 | 'transformer.h.49': 2,
55 | 'transformer.h.50': 2,
56 | 'transformer.h.51': 2,
57 | 'transformer.h.52': 2,
58 | 'transformer.h.53': 3,
59 | 'transformer.h.54': 3,
60 | 'transformer.h.55': 3,
61 | 'transformer.h.56': 3,
62 | 'transformer.h.57': 3,
63 | 'transformer.h.58': 3,
64 | 'transformer.h.59': 3,
65 | 'transformer.h.60': 3,
66 | 'transformer.h.61': 3,
67 | 'transformer.h.62': 3,
68 | 'transformer.h.63': 3,
69 | 'transformer.h.64': 3,
70 | 'transformer.h.65': 3,
71 | 'transformer.h.66': 3,
72 | 'transformer.h.67': 3,
73 | 'transformer.h.68': 3,
74 | 'transformer.h.69': 3,
75 | 'transformer.ln_f': 0,
76 | 'lm_head': 0,
77 | }
78 |
79 | GPT_NEOX_DEVICE_MAP = {'gpt_neox.embed_in': 0,
80 | 'gpt_neox.layers.0': 0,
81 | 'gpt_neox.layers.1': 0,
82 | 'gpt_neox.layers.2': 0,
83 | 'gpt_neox.layers.3': 0,
84 | 'gpt_neox.layers.4': 0,
85 | 'gpt_neox.layers.5': 0,
86 | 'gpt_neox.layers.6': 0,
87 | 'gpt_neox.layers.7': 0,
88 | 'gpt_neox.layers.8': 0,
89 | 'gpt_neox.layers.9': 0,
90 | 'gpt_neox.layers.10': 0,
91 | 'gpt_neox.layers.11': 0,
92 | 'gpt_neox.layers.12': 0,
93 | 'gpt_neox.layers.13': 0,
94 | 'gpt_neox.layers.14': 0,
95 | 'gpt_neox.layers.15': 0,
96 | 'gpt_neox.layers.16': 0,
97 | 'gpt_neox.layers.17': 0,
98 | 'gpt_neox.layers.18': 0,
99 | 'gpt_neox.layers.19': 0,
100 | 'gpt_neox.layers.20': 0,
101 | 'gpt_neox.layers.21': 0,
102 | 'gpt_neox.layers.22': 1,
103 | 'gpt_neox.layers.23': 1,
104 | 'gpt_neox.layers.24': 1,
105 | 'gpt_neox.layers.25': 1,
106 | 'gpt_neox.layers.26': 1,
107 | 'gpt_neox.layers.27': 1,
108 | 'gpt_neox.layers.28': 1,
109 | 'gpt_neox.layers.29': 1,
110 | 'gpt_neox.layers.30': 1,
111 | 'gpt_neox.layers.31': 1,
112 | 'gpt_neox.layers.32': 1,
113 | 'gpt_neox.layers.33': 1,
114 | 'gpt_neox.layers.34': 1,
115 | 'gpt_neox.layers.35': 1,
116 | 'gpt_neox.layers.36': 1,
117 | 'gpt_neox.layers.37': 1,
118 | 'gpt_neox.layers.38': 1,
119 | 'gpt_neox.layers.39': 1,
120 | 'gpt_neox.layers.40': 1,
121 | 'gpt_neox.layers.41': 1,
122 | 'gpt_neox.layers.42': 1,
123 | 'gpt_neox.layers.43': 1,
124 | 'gpt_neox.final_layer_norm': 1,
125 | 'embed_out': 0
126 | }
127 |
128 | T0_DEVICE_MAP = {
129 | 'shared': 0,
130 | 'decoder.embed_tokens': 0,
131 | 'encoder': 0,
132 | 'decoder.block.0': 0,
133 | 'decoder.block.1': 0,
134 | 'decoder.block.2': 0,
135 | 'decoder.block.3': 1,
136 | 'decoder.block.4': 1,
137 | 'decoder.block.5': 1,
138 | 'decoder.block.6': 1,
139 | 'decoder.block.7': 1,
140 | 'decoder.block.8': 1,
141 | 'decoder.block.9': 1,
142 | 'decoder.block.10': 1,
143 | 'decoder.block.11': 1,
144 | 'decoder.block.12': 1,
145 | 'decoder.block.13': 1,
146 | 'decoder.block.14': 1,
147 | 'decoder.block.15': 1,
148 | 'decoder.block.16': 1,
149 | 'decoder.block.17': 1,
150 | 'decoder.block.18': 1,
151 | 'decoder.block.19': 1,
152 | 'decoder.block.20': 1,
153 | 'decoder.block.21': 1,
154 | 'decoder.block.22': 1,
155 | 'decoder.block.23': 1,
156 | 'decoder.final_layer_norm': 1,
157 | 'decoder.dropout': 1,
158 | 'lm_head': 0
159 | }
160 |
161 | OPT_66B_DEVICE_MAP = {
162 | 'model.decoder.embed_tokens': 0,
163 | 'lm_head': 0,
164 | 'model.decoder.embed_positions': 0,
165 | 'model.decoder.final_layer_norm': 0,
166 | 'model.decoder.layers.0': 0,
167 | 'model.decoder.layers.1': 0,
168 | 'model.decoder.layers.2': 0,
169 | 'model.decoder.layers.3': 0,
170 | 'model.decoder.layers.4': 0,
171 | 'model.decoder.layers.5': 0,
172 | 'model.decoder.layers.6': 0,
173 | 'model.decoder.layers.7': 0,
174 | 'model.decoder.layers.8': 0,
175 | 'model.decoder.layers.9': 0,
176 | 'model.decoder.layers.10': 0,
177 | 'model.decoder.layers.11': 0,
178 | 'model.decoder.layers.12': 0,
179 | 'model.decoder.layers.13': 0,
180 | 'model.decoder.layers.14': 0,
181 | 'model.decoder.layers.15': 0,
182 | 'model.decoder.layers.16': 0,
183 | 'model.decoder.layers.17': 0,
184 | 'model.decoder.layers.18': 0,
185 | 'model.decoder.layers.19': 0,
186 | 'model.decoder.layers.20': 0,
187 | 'model.decoder.layers.21': 1,
188 | 'model.decoder.layers.22': 1,
189 | 'model.decoder.layers.23': 1,
190 | 'model.decoder.layers.24': 1,
191 | 'model.decoder.layers.25': 1,
192 | 'model.decoder.layers.26': 1,
193 | 'model.decoder.layers.27': 1,
194 | 'model.decoder.layers.28': 1,
195 | 'model.decoder.layers.29': 1,
196 | 'model.decoder.layers.30': 1,
197 | 'model.decoder.layers.31': 1,
198 | 'model.decoder.layers.32': 1,
199 | 'model.decoder.layers.33': 1,
200 | 'model.decoder.layers.34': 1,
201 | 'model.decoder.layers.35': 1,
202 | 'model.decoder.layers.36': 1,
203 | 'model.decoder.layers.37': 1,
204 | 'model.decoder.layers.38': 1,
205 | 'model.decoder.layers.39': 1,
206 | 'model.decoder.layers.40': 1,
207 | 'model.decoder.layers.41': 1,
208 | 'model.decoder.layers.42': 2,
209 | 'model.decoder.layers.43': 2,
210 | 'model.decoder.layers.44': 2,
211 | 'model.decoder.layers.45': 2,
212 | 'model.decoder.layers.46': 2,
213 | 'model.decoder.layers.47': 2,
214 | 'model.decoder.layers.48': 2,
215 | 'model.decoder.layers.49': 2,
216 | 'model.decoder.layers.50': 2,
217 | 'model.decoder.layers.51': 2,
218 | 'model.decoder.layers.52': 2,
219 | 'model.decoder.layers.53': 2,
220 | 'model.decoder.layers.54': 2,
221 | 'model.decoder.layers.55': 2,
222 | 'model.decoder.layers.56': 2,
223 | 'model.decoder.layers.57': 2,
224 | 'model.decoder.layers.58': 2,
225 | 'model.decoder.layers.59': 2,
226 | 'model.decoder.layers.60': 2,
227 | 'model.decoder.layers.61': 2,
228 | 'model.decoder.layers.62': 2,
229 | 'model.decoder.layers.63': 0
230 | }
231 |
232 |
233 | OPT_175B_DEVICE_MAP = {
234 | 'model.decoder.embed_tokens': 0,
235 | 'lm_head': 0,
236 | 'model.decoder.embed_positions': 0,
237 | 'model.decoder.final_layer_norm': 0,
238 | 'model.decoder.layers.0': 0,
239 | 'model.decoder.layers.1': 0,
240 | 'model.decoder.layers.2': 0,
241 | 'model.decoder.layers.3': 0,
242 | 'model.decoder.layers.4': 0,
243 | 'model.decoder.layers.5': 0,
244 | 'model.decoder.layers.6': 0,
245 | 'model.decoder.layers.7': 0,
246 | 'model.decoder.layers.8': 0,
247 | 'model.decoder.layers.9': 0,
248 | 'model.decoder.layers.10': 0,
249 | 'model.decoder.layers.11': 0,
250 | 'model.decoder.layers.12': 0,
251 | 'model.decoder.layers.13': 0,
252 | 'model.decoder.layers.14': 0,
253 | 'model.decoder.layers.15': 0,
254 | 'model.decoder.layers.16': 0,
255 | 'model.decoder.layers.17': 0,
256 | 'model.decoder.layers.18': 0,
257 | 'model.decoder.layers.19': 0,
258 | 'model.decoder.layers.20': 0,
259 | 'model.decoder.layers.21': 0,
260 | 'model.decoder.layers.22': 0,
261 | 'model.decoder.layers.23': 0,
262 | 'model.decoder.layers.24': 0,
263 | 'model.decoder.layers.25': 1,
264 | 'model.decoder.layers.26': 1,
265 | 'model.decoder.layers.27': 1,
266 | 'model.decoder.layers.28': 1,
267 | 'model.decoder.layers.29': 1,
268 | 'model.decoder.layers.30': 1,
269 | 'model.decoder.layers.31': 1,
270 | 'model.decoder.layers.32': 1,
271 | 'model.decoder.layers.33': 1,
272 | 'model.decoder.layers.34': 1,
273 | 'model.decoder.layers.35': 1,
274 | 'model.decoder.layers.36': 1,
275 | 'model.decoder.layers.37': 1,
276 | 'model.decoder.layers.38': 1,
277 | 'model.decoder.layers.39': 1,
278 | 'model.decoder.layers.40': 1,
279 | 'model.decoder.layers.41': 1,
280 | 'model.decoder.layers.42': 1,
281 | 'model.decoder.layers.43': 1,
282 | 'model.decoder.layers.44': 1,
283 | 'model.decoder.layers.45': 1,
284 | 'model.decoder.layers.46': 1,
285 | 'model.decoder.layers.47': 1,
286 | 'model.decoder.layers.48': 1,
287 | 'model.decoder.layers.49': 2,
288 | 'model.decoder.layers.50': 2,
289 | 'model.decoder.layers.51': 2,
290 | 'model.decoder.layers.52': 2,
291 | 'model.decoder.layers.53': 2,
292 | 'model.decoder.layers.54': 2,
293 | 'model.decoder.layers.55': 2,
294 | 'model.decoder.layers.56': 2,
295 | 'model.decoder.layers.57': 2,
296 | 'model.decoder.layers.58': 2,
297 | 'model.decoder.layers.59': 2,
298 | 'model.decoder.layers.60': 2,
299 | 'model.decoder.layers.61': 2,
300 | 'model.decoder.layers.62': 2,
301 | 'model.decoder.layers.63': 2,
302 | 'model.decoder.layers.64': 2,
303 | 'model.decoder.layers.65': 2,
304 | 'model.decoder.layers.66': 2,
305 | 'model.decoder.layers.67': 2,
306 | 'model.decoder.layers.68': 2,
307 | 'model.decoder.layers.69': 2,
308 | 'model.decoder.layers.70': 2,
309 | 'model.decoder.layers.71': 2,
310 | 'model.decoder.layers.72': 2,
311 | 'model.decoder.layers.73': 3,
312 | 'model.decoder.layers.74': 3,
313 | 'model.decoder.layers.75': 3,
314 | 'model.decoder.layers.76': 3,
315 | 'model.decoder.layers.77': 3,
316 | 'model.decoder.layers.78': 3,
317 | 'model.decoder.layers.79': 3,
318 | 'model.decoder.layers.80': 3,
319 | 'model.decoder.layers.81': 3,
320 | 'model.decoder.layers.82': 3,
321 | 'model.decoder.layers.83': 3,
322 | 'model.decoder.layers.84': 3,
323 | 'model.decoder.layers.85': 3,
324 | 'model.decoder.layers.86': 3,
325 | 'model.decoder.layers.87': 0,
326 | 'model.decoder.layers.88': 0,
327 | 'model.decoder.layers.89': 0,
328 | 'model.decoder.layers.90': 0,
329 | 'model.decoder.layers.91': 0,
330 | 'model.decoder.layers.92': 0,
331 | 'model.decoder.layers.93': 0,
332 | 'model.decoder.layers.94': 0,
333 | 'model.decoder.layers.95': 0
334 | }
335 |
336 |
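# Editor's usage sketch (not part of the repository): these hand-written maps are the
# kind of dictionary accepted by the device_map argument of from_pretrained, which
# places each listed module on the given GPU index (requires the accelerate package):
#
#   from transformers import AutoModelForCausalLM
#   model = AutoModelForCausalLM.from_pretrained("bigscience/bloom",  # assumed checkpoint
#                                                device_map=BLOOM_DEVICE_MAP)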
--------------------------------------------------------------------------------
/src/models/model_flags.py:
--------------------------------------------------------------------------------
1 | from src.models.device_maps import BLOOM_DEVICE_MAP, GPT_NEOX_DEVICE_MAP, T0_DEVICE_MAP, OPT_66B_DEVICE_MAP, OPT_175B_DEVICE_MAP
2 |
3 | from transformers import AutoTokenizer, AutoConfig, AutoModelForSeq2SeqLM, AutoModelForCausalLM
4 | from transformers.models.gpt_neox.tokenization_gpt_neox_fast import GPTNeoXTokenizerFast
5 |
6 | DICT_REGEX_OF_MODEL_TYPE = {
7 | ".*T0.*": "encoder_decoder",
8 | ".*pegasus.*": "encoder_decoder",
9 | ".*t5.*": "encoder_decoder",
10 | ".*bart.*": "encoder_decoder",
11 | ".*bloom.*": "decoder",
12 | ".*gpt.*": "decoder",
13 | ".*opt.*": "decoder",
14 | ".*T5.*": "encoder_decoder",
15 | }
16 |
17 | DICT_REGEX_OF_WHETHER_MODEL_USES_POSITION_IDS = {
18 | ".*bloom.*": False,
19 | ".*gpt.*": True,
20 | ".*opt.*": False
21 | }
22 |
23 | DICT_REGEX_OF_DEVICE_MAP = {
24 | ".*": "auto",
25 | ".*bloom": BLOOM_DEVICE_MAP,
26 | ".*gpt-neox-20b": GPT_NEOX_DEVICE_MAP,
27 | ".*T0|.*t5-xxl.*": T0_DEVICE_MAP,
28 | ".*opt-66b": OPT_66B_DEVICE_MAP,
29 | ".*opt-175b": OPT_175B_DEVICE_MAP
30 | }
31 |
32 | DICT_REGEX_OF_TOKENIZERS = {
33 | ".*": lambda model_name: AutoTokenizer.from_pretrained(model_name),
34 | ".*opt.*": lambda model_name: AutoTokenizer.from_pretrained(model_name, use_fast=False),
35 | ".*gpt-neox-20b": lambda model_name: GPTNeoXTokenizerFast.from_pretrained(model_name)
36 | }
--------------------------------------------------------------------------------
/src/models/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def compute_logProb(logProbs_ofAllChoices_ids, allChoices_masks, num_choices, maxChoice_len, lengthNormalization):
4 | '''
5 | Aggregates per-id log probabilities into per-choice log probabilities.
6 |
7 | Args:
8 | logProbs_ofAllChoices_ids: [batch_size x num_choices x max_choice_len]
9 | allChoices_masks: [batch_size x num_choices, max_choice_len]
10 | num_choices: number of answer choices per input
11 | maxChoice_len: maximum number of ids per choice
12 | lengthNormalization: whether to divide each choice's log prob by its length
13 |
14 | Returns:
15 | logProbs_ofAllChoices: [batch_size, num_choices]
16 | logProbs_ofAllChoicesIds_zeroOutPadIds: [batch_size, num_choices, max_choice_len]
17 | len_allChoices: [batch_size, num_choices]
18 | '''
19 | # Compute the log probabilities of all the choices by averaging the log probabilities of
20 | # the ids and zeroing out the pad ids
21 | # [batch_size, num_choices, max_choice_len]
22 | logProbs_ofAllChoices_ids = logProbs_ofAllChoices_ids.reshape(-1, num_choices, maxChoice_len)
23 | allChoices_masks = allChoices_masks.reshape(-1, num_choices, maxChoice_len) > 0
24 | logProbs_ofAllChoicesIds_zeroOutPadIds = logProbs_ofAllChoices_ids * allChoices_masks
25 | logProbs_ofAllChoices = torch.sum(logProbs_ofAllChoicesIds_zeroOutPadIds, dim=2)
26 | len_allChoices = torch.sum(allChoices_masks, dim=2)
27 |
28 | if lengthNormalization:
29 | logProbs_ofAllChoices = logProbs_ofAllChoices / len_allChoices
30 |
31 | return logProbs_ofAllChoices,\
32 | logProbs_ofAllChoicesIds_zeroOutPadIds, \
33 | len_allChoices
34 |
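# Editor's worked example (not part of the repository), with batch_size=1,
# num_choices=2, max_choice_len=3; masking zeroes out pad positions before summing,
# and length normalization divides by the number of real tokens per choice:
#
#   import torch
#   per_id = torch.tensor([[-1.0, -2.0, -5.0],   # choice 0: last id is padding
#                          [-1.0, -1.0, -1.0]])  # choice 1: no padding
#   masks = torch.tensor([[1, 1, 0],
#                         [1, 1, 1]])
#   logProbs, zeroed, lens = compute_logProb(per_id, masks, 2, 3, True)
#   # logProbs -> tensor([[-1.5000, -1.0000]]); lens -> tensor([[2, 3]])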
--------------------------------------------------------------------------------
/src/utils/CONSTANTS.py:
--------------------------------------------------------------------------------
1 |
2 | NULL_STRING = "NULL"
3 |
--------------------------------------------------------------------------------
/src/utils/deepspeed.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | def get_deepspeedConfig(eval_batchSize, world_size, model_dim):
4 | '''
5 | Builds a DeepSpeed ZeRO stage-3 config with CPU parameter offloading for evaluation.
6 | Args:
7 | eval_batchSize: evaluation batch size per GPU
8 | world_size: number of processes (GPUs)
9 | model_dim: hidden dimension of the model, used to size the ZeRO buckets
10 |
11 | Returns:
12 | deepspeed_config: configuration dictionary for DeepSpeed
13 | '''
14 | # https://github.com/huggingface/transformers/issues/15399
15 | deepspeed_config = {
16 | "fp16": {
17 | "enabled": False,
18 | },
19 | "bf16": {
20 | "enabled": False,
21 | },
22 | "zero_optimization": {
23 | "stage": 3,
24 | "offload_param": {
25 | "device": "cpu",
26 | "pin_memory": True
27 | },
28 | "overlap_comm": True,
29 | "contiguous_gradients": True,
30 | "reduce_bucket_size": model_dim * model_dim,
31 | "stage3_prefetch_bucket_size": 0.9 * model_dim * model_dim,
32 | "stage3_param_persistence_threshold": 10 * model_dim
33 | },
34 | "steps_per_print": 2000,
35 | "train_batch_size": eval_batchSize * world_size,
36 | "train_micro_batch_size_per_gpu": eval_batchSize,
37 | "wall_clock_breakdown": False
38 | }
39 |
40 | return deepspeed_config
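# Editor's usage sketch (not part of the repository): one way such a config is
# typically consumed, via deepspeed.initialize; the model and model_dim here are
# assumptions, not values taken from this repo:
#
#   import deepspeed
#   config = get_deepspeedConfig(eval_batchSize=1, world_size=1, model_dim=4096)
#   engine, _, _, _ = deepspeed.initialize(model=model, config_params=config)
#   engine.module.eval()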
--------------------------------------------------------------------------------
/src/utils/test_helpers.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | def check_string_equality(string_one, string_two):
4 | assert string_one == string_two, \
5 | f"\n{string_one}\n" + \
6 | "="*100 + '\n' +\
7 | f"{string_two}"
8 |
9 | def check_string_subset_of_another(string_one, string_two):
10 | assert string_one in string_two, \
11 | f"\n{string_one}\n" + \
12 | "="*100 + '\n' +\
13 | f"{string_two}"
14 |
15 | def check_string_starts_with_another(string_one, string_two):
16 | assert string_one.startswith(string_two), \
17 | f"\n{string_one}\n" + \
18 | "="*100 + '\n' +\
19 | f"{string_two}"
20 |
21 | def check_string_ends_with_another(string_one, string_two):
22 | assert string_one.endswith(string_two), \
23 | f"\n{string_one}\n" + \
24 | "="*100 + '\n' +\
25 | f"{string_two}"
--------------------------------------------------------------------------------
/src/utils/util.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import re
4 | import random
5 | import datetime
6 | import os
7 | import subprocess
8 | import numpy as np
9 | import torch
10 |
11 | from shutil import copytree, ignore_patterns
12 | from src.utils.CONSTANTS import NULL_STRING
13 |
14 |
15 | def set_global_logging_level(level=logging.ERROR, prefices=[""]):
16 | """
17 | Override logging levels of different modules based on their name as a prefix.
18 | It needs to be invoked after the modules have been loaded so that their loggers have been initialized.
19 |
20 | Args:
21 | - level: desired level. e.g. logging.INFO. Optional. Default is logging.ERROR
22 | - prefices: list of one or more str prefices to match (e.g. ["transformers", "torch"]). Optional.
23 | Default is `[""]` to match all active loggers.
24 | The match is a case-sensitive `module_name.startswith(prefix)`
25 | """
26 | prefix_re = re.compile(fr'^(?:{ "|".join(prefices) })')
27 | for name in logging.root.manager.loggerDict:
28 | if re.match(prefix_re, name):
29 | logging.getLogger(name).setLevel(level)
30 |
31 |
32 | class ParseKwargs(argparse.Action):
33 | def __call__(self, parser, namespace, values, option_string=None):
34 | setattr(namespace, self.dest, dict())
35 | for value in values:
36 | key, value = value.split('=')
37 | getattr(namespace, self.dest)[key] = value
38 |
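# Editor's usage note (not in the original file):
#   parser = argparse.ArgumentParser()
#   parser.add_argument("-k", "--kwargs", nargs="*", action=ParseKwargs, default={})
#   args = parser.parse_args(["-k", "lr=1e-4", "seed=42"])
#   args.kwargs  # -> {"lr": "1e-4", "seed": "42"}; values are kept as strings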
39 |
40 | def update_dict_val_store(dict_val_store, dict_update_val):
41 |
42 | if dict_val_store is None:
43 | dict_val_store = {}
44 | for k in dict_update_val.keys():
45 | dict_val_store[k] = dict_update_val[k].detach().cpu().item()
46 | else:
47 | for k in dict_val_store.keys():
48 | dict_val_store[k] += dict_update_val[k].detach().cpu().item()
49 |
50 | return dict_val_store
51 |
52 | def get_avg_dict_val_store(config, dict_val_store, eval_every):
53 |
54 | dict_avg_val = {}
55 |
56 | for k in dict_val_store.keys():
57 | old_val = dict_val_store[k]
58 | dict_avg_val[k] = float('%.3f' % (old_val / eval_every))
59 | dict_val_store[k] = 0
60 |
61 | return dict_val_store, dict_avg_val
62 |
63 | def save_gcp(filepath):
64 | subprocess.call(f"gsutil -m -o GSUtil:parallel_composite_upload_threshold=150M \
65 | cp -r {filepath} \
66 | gs://abs_sum/{filepath}", shell=True)
67 |
68 | def set_seeds(seed):
69 | "set random seeds"
70 | random.seed(seed)
71 | np.random.seed(seed)
72 | torch.manual_seed(seed)
73 | torch.cuda.manual_seed_all(seed)
74 |
75 | def make_dir(dir_name):
76 | '''
77 | Makes a directory if it doesn't exist yet
78 | Args:
79 | dir_name: directory name
80 | '''
81 | if not os.path.exists(dir_name):
82 | os.makedirs(dir_name)
83 |
84 |
85 | def make_exp_dir(base_exp_dir):
86 | '''
87 | Makes an experiment directory with timestamp
88 | Args:
89 | base_exp_dir: base experiment directory name
90 | Returns:
91 | exp_dir_name: experiment directory name
92 | '''
93 | now = datetime.datetime.now()
94 | ts = "{:04d}-{:02d}-{:02d}-{:02d}-{:02d}-{:02d}".format(now.year, now.month, now.day, now.hour, now.minute,
95 | now.second)
96 | exp_dir_name = os.path.join(base_exp_dir, ts)
97 | make_dir(exp_dir_name)
98 |
99 | src_file = os.path.join(exp_dir_name, 'src')
100 |
101 | copytree(os.path.join(os.environ['LFQA_FAC_ROOT'], "src"), src_file, ignore=ignore_patterns('*.pyc', 'tmp*'))
102 |
103 | return exp_dir_name
104 |
105 | def reduce_gatheredOutput(listOfDict):
106 | '''
107 | Reduces the output from multiple devices to have the same format as for a single device.
108 | Also, removes the NULL datapoint, which is a dummy payload to handle model parallelism.
109 |
110 | Args:
111 | listOfDict: list of output dictionaries gathered from each device
112 |
113 | Returns:
114 | dictOfList: dictionary mapping each key to a single flattened list of values
115 | '''
116 | dictOfList = {}
117 |
118 | # Form list of values at each key
119 | for iterate_dict in listOfDict:
120 |
121 | # Find indices of NULL datapoint to ignore later
122 | idx_toRemove = {}
123 | for idx, datapoint_input in enumerate(iterate_dict["input"]):
124 | if datapoint_input == NULL_STRING:
125 | idx_toRemove[idx] = True
126 |
127 | for (k, v) in iterate_dict.items():
128 |
129 | # Filter out NULL datapoints based on indices.
130 | filtered_v = []
131 | for idx, datapoint_v in enumerate(v):
132 | if idx not in idx_toRemove:
133 |
134 | filtered_v.append(datapoint_v)
135 |
136 | if k in dictOfList:
137 | dictOfList[k].append(filtered_v)
138 | else:
139 | dictOfList[k] = [filtered_v]
140 |
141 | # Flatten lists of list to form a list, or concatenate list of tensors to form a tensor
142 | for (k, batch_ofValues) in dictOfList.items():
143 | dictOfList[k] = [item for sublist in batch_ofValues for item in sublist]
144 |
145 | return dictOfList
146 |
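# Editor's worked example (not part of the repository):
#   gathered = [{"input": ["a", NULL_STRING], "pred": [0, 1]},
#               {"input": ["b", "c"], "pred": [1, 0]}]
#   reduce_gatheredOutput(gathered)
#   # -> {"input": ["a", "b", "c"], "pred": [0, 1, 0]}
#   # The NULL dummy datapoint and its prediction are dropped before flattening.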
147 |
148 | def get_value_from_key_matching_regex(dict_regex_keyToValue, key_toMatch):
149 | matching_value = None
150 | for regex_key, value in dict_regex_keyToValue.items():
151 | if re.search(regex_key, key_toMatch) is not None:
152 | matching_value = value
153 | return matching_value
154 |
155 | def get_mulChoice_outputDir(mulChoice_fp, model_name, ignore_pointMutualInfo, ignore_lengthNormalization):
156 | '''
157 | Gets the output directory, assuming the multiple-choice dataset filepath has the
158 | format multiple_choice-dataset/{...}.jsonl
159 | Args:
160 | mulChoice_fp: filepath of the multiple-choice dataset
161 | model_name: model name or path; slashes are flattened into dashes
162 | Returns: the output directory path, created if it does not exist
163 | '''
164 | mulChoice_datasetName = mulChoice_fp\
165 | .replace("multiple_choice-dataset/", "")\
166 | .replace(".jsonl", "")
167 | model_name = model_name.replace("/fruitbasket/models/", "").replace("/", "-")
168 | output_dir = os.path.join("exp_out", "multiple_choice", mulChoice_datasetName)
169 |
170 | if ignore_pointMutualInfo:
171 | ignorePointMutualInfo_str = "-ignore_pointwise_mutual_info"
172 | else:
173 | ignorePointMutualInfo_str = ""
174 |
175 | if ignore_lengthNormalization:
176 | ignoreLengthNormalizationInfo_str = "-ignore_length_normalization"
177 | else:
178 | ignoreLengthNormalizationInfo_str = ""
179 |
180 | output_dir = os.path.join(output_dir, model_name + ignorePointMutualInfo_str + ignoreLengthNormalizationInfo_str)
181 | if not os.path.exists(output_dir):
182 | os.makedirs(output_dir)
183 |
184 | return output_dir
185 |
--------------------------------------------------------------------------------