├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── NOTICE ├── README.md ├── THIRD-PARTY.md ├── environment.yml ├── lm_eval ├── base.py ├── evaluator.py ├── models │ ├── 66b_device_map.txt │ ├── __init__.py │ └── opt.py ├── prefix_matching_copying.py ├── tasks │ ├── arc.py │ ├── glue.py │ ├── hellaswag.py │ ├── lambada.py │ ├── mathqa.py │ ├── openbookqa.py │ ├── piqa.py │ ├── superglue.py │ └── winogrande.py └── utils.py ├── main.py ├── scripts ├── get_fc_ranking.py └── plotting │ ├── combined_pruning.py │ ├── crosstask_accuracy.py │ ├── fc_importance.py │ ├── heatmap.py │ ├── iterative_pruning.py │ ├── prefix_copying_pruning.py │ ├── prefix_copying_task_specific.py │ ├── spearman_rankings.py │ └── style.py └── transformers └── models └── opt └── modeling_opt.py /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *main* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 
38 | 39 | GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 60 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 
40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Rethinking the Role of Scale for In-Context Learning: An Interpretability-based Case Study at 66 Billion Scale 2 | 3 | This repository contains code to reproduce the experiments in the paper "[Rethinking the Role of Scale for In-Context Learning: An Interpretability-based Case Study at 66 Billion Scale](https://arxiv.org/abs/2212.09095)", 4 | published in the main proceedings of ACL 2023. 5 | 6 | ## Setup 7 | 8 | Set up and activate an initial conda environment using the provided `environment.yml` file. 9 | ``` 10 | conda env create -f environment.yml 11 | conda activate opt 12 | ``` 13 | 14 | Install [PyTorch](https://pytorch.org/) based on your system configuration. We used the following with 15 | AWS EC2 p4 instances: 16 | 17 | ``` 18 | conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch 19 | ``` 20 | 21 | ## Getting Started 22 | 23 | Our code is based off 🤗Hugging Face's [transformers](https://github.com/huggingface/transformers) 24 | and Eleuther AI's [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness) 25 | libraries. 26 | 27 | Run the following sequence of commands (in that order) to clone and set up both libraries in your file 28 | system. We point to particular hashes associated with our runs, but it may be possible that our code is 29 | forward-compatible with newer versions. 
30 | 31 | ``` 32 | git clone https://github.com/EleutherAI/lm-evaluation-harness.git 33 | cd lm-evaluation-harness 34 | git checkout 11fa0bf4394998634e6c6e0c9fc2fc8211415042 35 | git clone https://github.com/huggingface/transformers.git 36 | cd transformers 37 | git checkout 9832ac7c736519fcfeedb88c8368cf0ab08b2b58 38 | ``` 39 | 40 | ### Changes to 🤗Transformers 41 | 42 | We modified the implementation of the Open Pre-Trained Transformer (OPT) in 🤗Transformers to allow 43 | for importance score computations. Specifically: 44 | 1. we use hooks to store the gradient of the loss w.r.t. the output of attention heads (see `context_layer_val` and `context_layer_val_grad`) 45 | 2. we define masks to "knock off" particular feed forward networks (see `fc_mask` and `layer_fc_mask`) 46 | 47 | The modified implementation is located at [transformers/models/opt/modeling_opt.py](transformers/models/opt/modeling_opt.py) 48 | in this repo. 49 | 50 | Copy this script to the corresponding location for OPT in the local clone of `transformers`. 51 | 52 | ### Changes to `lm-evaluation-harness` 53 | 54 | We added support for OPT in `lm-evaluation-harness` following the existing example for GPT-2, 55 | see [lm_eval/models/opt.py](lm_eval/models/opt.py). This utilizes the core modifications to OPT in the 56 | local clone of `transformers` described above. We used a custom device map to shard the model 57 | parameters across our compute capacity; the device map can be modified to suit your own compute resources. 58 | 59 | We also adapted other existing scripts from `lm-evaluation-harness` in the `lm_eval` directory: 60 | 1. [lm_eval/base.py](lm_eval/base.py) has the core logic of computing attention head importance scores, 61 | see the `calculate_importance()` method. 62 | 2. [lm_eval/evaluator.py](lm_eval/evaluator.py) contains the control flow that allows for both the original evaluation 63 | and attention head importance score computation. The computed head importance scores are dumped 64 | in pickle files. 65 | 3. [lm_eval/utils.py](lm_eval/utils.py) contains methods for dataset and data loader creation used 66 | for attention head importance score computation, see the `create_dataloader()` and 67 | `get_dataloader_from_dataset()` methods. 68 | 4. Each task defined in [lm_eval/tasks/](lm_eval/tasks/) is updated to create the associated data 69 | loader via `utils.py` as described above and to define a getter method for the data loader, see the 70 | `get_dataloader()` method. 71 | 72 | The driver script `main.py` is also adapted to leverage these changes. Note that this 73 | script dumps the evaluation results into JSON-formatted text files, which are necessary to create some of the plots in 74 | our paper. 75 | 76 | Copy these scripts to their corresponding locations in the local clone of `lm-evaluation-harness`. 77 | 78 | 79 | ### Induction Heads: Prefix Matching and Copying 80 | [lm_eval/prefix_matching_copying.py](lm_eval/prefix_matching_copying.py) contains our 81 | implementation for computing prefix matching and copying scores for attention heads, 82 | also described in detail with pseudocode in our paper's Appendix. The original 83 | algorithm by Anthropic is described in the Additional Details section of the Transformer Circuits Thread post [here](https://transformer-circuits.pub/2022/in-context-learning-and-induction-heads/index.html#data-collection). 84 | Please refer to our paper's Appendix for a description of the modifications we made 85 | to their algorithm. 
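For intuition only, here is a minimal sketch of the underlying prefix-matching idea, not our actual implementation: feed the model a sequence of random tokens repeated twice and, for each attention head, average the attention that every token in the second half pays to the token that followed its earlier occurrence. The checkpoint (`facebook/opt-125m`) and the naive attention-based score below are simplifying assumptions made purely for illustration; `prefix_matching_copying.py` implements the full (modified) procedure.

```
import torch
from transformers import AutoModelForCausalLM

# Small OPT checkpoint used only to keep the sketch runnable; the paper uses facebook/opt-66b.
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
model.eval()

# A block of random tokens repeated once: [t_0 ... t_{L-1}, t_0 ... t_{L-1}]
L = 25
half = torch.randint(low=4, high=model.config.vocab_size, size=(1, L))
input_ids = torch.cat([half, half], dim=1)

with torch.no_grad():
    out = model(input_ids, output_attentions=True)

# out.attentions holds one (batch, num_heads, seq_len, seq_len) tensor per layer
scores = torch.zeros(model.config.num_hidden_layers, model.config.num_attention_heads)
for layer, attn in enumerate(out.attentions):
    for i in range(L, 2 * L - 1):
        # the token at position i previously occurred at i - L; an induction head
        # attends to the token that followed that earlier occurrence (i - L + 1)
        scores[layer] += attn[0, :, i, i - L + 1]
    scores[layer] /= L - 1

print(scores)  # higher values suggest stronger prefix-matching behaviour
```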
86 | 87 | Copy this script to the `lm_eval` directory in the local clone of `lm-evaluation-harness`. 88 | 89 | ### Plotting 90 | We provide the scripts used to create the plots in our paper in the 91 | [scripts/](scripts/) directory. These scripts assume that the importance scores 92 | are already computed and dumped in pickle files and the task-specific evaluation 93 | results are dumped in JSON-formatted text files using the code described above. 94 | 95 | Note that you may have to edit these scripts a bit according to the naming convention 96 | you adopt for the importance score pickle and evaluation result text files you create. 97 | 98 | 99 | ## Sample Commands 100 | In this section, we provide sample commands leveraging the code described above 101 | for a few use-cases. We recommend diving into the code and understanding the 102 | supported args to be able to leverage all supported functionality. 103 | 104 | ### Model and Tokenizer Caching 105 | Load the pre-trained model and tokenizer into explicitly defined cache directories 106 | as a one-time operation: 107 | ``` 108 | cd lm-evaluation-harness 109 | python 110 | >>> from transformers import AutoModel, AutoTokenizer 111 | >>> model = AutoModel.from_pretrained('facebook/opt-66b', cache_dir='opt66b_checkpoints/') 112 | >>> tokenizer = AutoTokenizer.from_pretrained('facebook/opt-66b', cache_dir='opt66b_tokenizer/') 113 | ``` 114 | 115 | ### Attention Head Importance Scores 116 | The following command computes and saves attention head importance scores for the 117 | Physical IQA (PIQA) task in the 1-shot setting: 118 | ``` 119 | python main.py --model opt --model_args pretrained=facebook/opt-66b,model_cache_dir=opt66b_checkpoints,tokenizer_cache_dir=opt66b_tokenizer --tasks piqa --head_importance_calc --save_importance_path logs/head_importance/opt66b/1shot_piqa.pkl --num_fewshot 1 120 | ``` 121 | 122 | ### Masking A Feed Forward Network 123 | 124 | To mask a particular feed forward network (FFN) and evaluate the model on a 125 | particular task, the following sample command can be used. OPT has 64 layers and 126 | in this case, we are masking the FFN in layer 10 (indexing starting from 0) when 127 | evaluating the model on the PIQA task in the 5-shot setting. 128 | 129 | ``` 130 | python main.py --model opt --model_args pretrained=facebook/opt-66b,model_cache_dir=opt66b_checkpoints,tokenizer_cache_dir=opt66b_tokenizer,mask_fc=10 --tasks piqa --output_path results/66b/5shot_fc_pruning/piqa/5shot_fc_10.txt --batch_size 2 --num_fewshot 5 131 | ``` 132 | 133 | ### Iterative Pruning of Attention Heads 134 | 135 | To mask unimportant attention heads and evaluate the model on a particular task, 136 | the following sample command can be used. In this case, we are masking 20% (range: 0-90%) 137 | of the task and shot-specific unimportant attention heads and evaluating the model 138 | on the PIQA task in the 1-shot setting. 
139 | 140 | ``` 141 | python main.py --model opt --model_args pretrained=facebook/opt-66b,model_cache_dir=opt66b_checkpoints,tokenizer_cache_dir=opt66b_tokenizer,mask_heads=1,head_importance_path=logs/head_importance/opt66b/1shot_piqa.pkl,head_percent_mask=20 --tasks piqa --output_path results/66b/piqa/1shot_piqa_percent.txt --batch_size 2 --num_fewshot 1 142 | ``` 143 | 144 | ### FFN Importance Scores 145 | 146 | The following command leverages `fc_importance.py`, which computes importance 147 | scores for each FFN as the difference between the baseline accuracy and the 148 | accuracy after masking the FFN for each task, and dumps them to pickle files. 149 | The accuracy upon independently masking each FFN is assumed to have already been 150 | computed as described above with an earlier sample command. 151 | 152 | ``` 153 | python scripts/plotting/fc_importance.py --results_path results/66b/5shot_fc_pruning/ --base_results_path results/66b/ --shot 5-shot --save_plot_path paper_plots/fc_importance/5-shot.png --dump_fc_importance --dump_fc_importance_path logs/fc_knocking_importance/ 154 | ``` 155 | 156 | ### Iterative Pruning of FFNs 157 | 158 | To mask unimportant FFNs and evaluate the model on a particular task, the following 159 | sample command can be used. In this case, we are masking 20% (range: 0-90%) of the 160 | task and shot-specific unimportant FFNs and evaluating the model on the PIQA task 161 | in the 5-shot setting. 162 | 163 | ``` 164 | python main.py --model opt --model_args pretrained=facebook/opt-66b,model_cache_dir=opt66b_checkpoints,tokenizer_cache_dir=opt66b_tokenizer,mask_iterative_fc=1,fc_importance_path=logs/fc_knocking_importance/5shot_piqa.pkl,fc_percent_mask=20 --tasks piqa --output_path results/66b/piqa/5shot_20_fc_percent.txt --batch_size 1 --num_fewshot 5 165 | ``` 166 | 167 | ### Combined Pruning of Heads and FFNs 168 | 169 | To evaluate the model on a particular task after combined pruning of attention heads 170 | and FFNs, the following sample command can be used. In this case, we are masking 171 | 20% of the unimportant attention heads and 30% of the unimportant FFNs and evaluating 172 | the model on the PIQA task in the 1-shot setting. 173 | 174 | ``` 175 | python main.py --model opt --model_args pretrained=facebook/opt-66b,model_cache_dir=opt66b_checkpoints,tokenizer_cache_dir=opt66b_tokenizer,mask_iterative_fc=1,fc_importance_path=logs/fc_knocking_importance/1shot_piqa.pkl,fc_percent_mask=30,mask_heads=1,head_importance_path=logs/head_importance/opt66b/1shot_piqa.pkl,head_percent_mask=20 --tasks piqa --output_path results/66b/piqa/1shot_30_fc_20_head_percent.txt --batch_size 2 --num_fewshot 1 176 | ``` 177 | 178 | ### Prefix Matching and Copying 179 | 180 | To compute, plot and save prefix matching and copying scores, the following pair of 181 | sample commands can be used. 
182 | 183 | Prefix Matching: 184 | ``` 185 | python -m lm_eval.prefix_matching_copying --prefix_matching --pretrained facebook/opt-66b --model_cache_dir opt66b_checkpoints/ --tokenizer_cache_dir opt66b_tokenizer/ --save_plot_path_mean paper_plots/induction_heads/pfx_matching_mean.png --save_plot_path_var paper_plots/induction_heads/pfx_matching_var.png --save_outputs paper_plots/induction_heads/pfx_matching.pkl 186 | ``` 187 | 188 | Copying: 189 | ``` 190 | python -m lm_eval.prefix_matching_copying --copying_score --pretrained facebook/opt-66b --model_cache_dir opt66b_checkpoints/ --tokenizer_cache_dir opt66b_tokenizer/ --save_plot_path_mean paper_plots/induction_heads/copying_mean.png --save_plot_path_var paper_plots/induction_heads/copying_var.png --save_outputs paper_plots/induction_heads/copying.pkl 191 | ``` 192 | 193 | 194 | ## Citation 195 | 196 | If you find our work useful, please consider citing using the following: 197 | ``` 198 | @misc{bansal2022rethinking, 199 | title={Rethinking the Role of Scale for In-Context Learning: An Interpretability-based Case Study at 66 Billion Scale}, 200 | author={Hritik Bansal and Karthik Gopalakrishnan and Saket Dingliwal and Sravan Bodapati and Katrin Kirchhoff and Dan Roth}, 201 | year={2022}, 202 | eprint={2212.09095}, 203 | archivePrefix={arXiv}, 204 | primaryClass={cs.CL} 205 | } 206 | ``` 207 | 208 | ## Security 209 | 210 | See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information. 211 | 212 | ## License 213 | 214 | This project is licensed under the Apache-2.0 License. 215 | 216 | See [THIRD-PARTY](THIRD-PARTY.md) for a summary of changes made to third-party libraries, 217 | described in the **Getting Started** section in detail, along with the associated licenses. 218 | 219 | -------------------------------------------------------------------------------- /THIRD-PARTY.md: -------------------------------------------------------------------------------- 1 | # THIRD PARTY NOTICE 2 | 3 | We extracted and used code from the following packages: 4 | * transformers 5 | * Package: https://github.com/huggingface/transformers 6 | * License: Apache 2.0, https://github.com/huggingface/transformers/blob/9832ac7c736519fcfeedb88c8368cf0ab08b2b58/LICENSE 7 | * Copyright: "Copyright 2018- The Hugging Face team. 
All rights reserved.", https://github.com/huggingface/transformers/blob/9832ac7c736519fcfeedb88c8368cf0ab08b2b58/LICENSE 8 | * Usage: 9 | * We adapted the code in `src/transformers/models/opt/modeling_opt.py` to allow for importance score computations, as seen in `transformers/models/opt/modeling_opt.py` 10 | 11 | * lm-evaluation-harness 12 | * Package: https://github.com/EleutherAI/lm-evaluation-harness 13 | * License: MIT License, https://github.com/EleutherAI/lm-evaluation-harness/blob/11fa0bf4394998634e6c6e0c9fc2fc8211415042/LICENSE.md 14 | * Copyright: "Copyright (c) 2020 EleutherAI", https://github.com/EleutherAI/lm-evaluation-harness/blob/11fa0bf4394998634e6c6e0c9fc2fc8211415042/LICENSE.md 15 | * Usage: 16 | * We added support for the OPT model in `lm_eval/models/opt.py` by following the example of `lm_eval/models/gpt2.py` 17 | * We adapted `lm_eval/base.py` to add the core logic of computing attention head importance scores 18 | * We adapted `lm_eval/evaluator.py` to allow for original evaluation alongside head importance score computation 19 | * We adapted `lm_eval/utils.py` to include two methods for dataset creation used during head importance score computation (`get_dataloader_from_dataset` and `create_dataloader`) 20 | * We adapted task-specific files in `lm_eval/tasks/` to use the new data loader defined in our updated `utils.py` as described above, as seen in `lm_eval/tasks/.py` 21 | * We adapted the driver script `main.py` to account for the aforementioned changes 22 | 23 | Full copyright with original license files of the aforementioned packages are pasted below. 24 | 25 | ------------------ 26 | 27 | ** HuggingFace transformers - https://github.com/huggingface/transformers 28 | 29 | Copyright 2018- The Hugging Face team. All rights reserved. 30 | 31 | Apache License 32 | Version 2.0, January 2004 33 | http://www.apache.org/licenses/ 34 | 35 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 36 | 37 | 1. Definitions. 38 | 39 | "License" shall mean the terms and conditions for use, reproduction, 40 | and distribution as defined by Sections 1 through 9 of this document. 41 | 42 | "Licensor" shall mean the copyright owner or entity authorized by 43 | the copyright owner that is granting the License. 44 | 45 | "Legal Entity" shall mean the union of the acting entity and all 46 | other entities that control, are controlled by, or are under common 47 | control with that entity. For the purposes of this definition, 48 | "control" means (i) the power, direct or indirect, to cause the 49 | direction or management of such entity, whether by contract or 50 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 51 | outstanding shares, or (iii) beneficial ownership of such entity. 52 | 53 | "You" (or "Your") shall mean an individual or Legal Entity 54 | exercising permissions granted by this License. 55 | 56 | "Source" form shall mean the preferred form for making modifications, 57 | including but not limited to software source code, documentation 58 | source, and configuration files. 59 | 60 | "Object" form shall mean any form resulting from mechanical 61 | transformation or translation of a Source form, including but 62 | not limited to compiled object code, generated documentation, 63 | and conversions to other media types. 
64 | 65 | "Work" shall mean the work of authorship, whether in Source or 66 | Object form, made available under the License, as indicated by a 67 | copyright notice that is included in or attached to the work 68 | (an example is provided in the Appendix below). 69 | 70 | "Derivative Works" shall mean any work, whether in Source or Object 71 | form, that is based on (or derived from) the Work and for which the 72 | editorial revisions, annotations, elaborations, or other modifications 73 | represent, as a whole, an original work of authorship. For the purposes 74 | of this License, Derivative Works shall not include works that remain 75 | separable from, or merely link (or bind by name) to the interfaces of, 76 | the Work and Derivative Works thereof. 77 | 78 | "Contribution" shall mean any work of authorship, including 79 | the original version of the Work and any modifications or additions 80 | to that Work or Derivative Works thereof, that is intentionally 81 | submitted to Licensor for inclusion in the Work by the copyright owner 82 | or by an individual or Legal Entity authorized to submit on behalf of 83 | the copyright owner. For the purposes of this definition, "submitted" 84 | means any form of electronic, verbal, or written communication sent 85 | to the Licensor or its representatives, including but not limited to 86 | communication on electronic mailing lists, source code control systems, 87 | and issue tracking systems that are managed by, or on behalf of, the 88 | Licensor for the purpose of discussing and improving the Work, but 89 | excluding communication that is conspicuously marked or otherwise 90 | designated in writing by the copyright owner as "Not a Contribution." 91 | 92 | "Contributor" shall mean Licensor and any individual or Legal Entity 93 | on behalf of whom a Contribution has been received by Licensor and 94 | subsequently incorporated within the Work. 95 | 96 | 2. Grant of Copyright License. Subject to the terms and conditions of 97 | this License, each Contributor hereby grants to You a perpetual, 98 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 99 | copyright license to reproduce, prepare Derivative Works of, 100 | publicly display, publicly perform, sublicense, and distribute the 101 | Work and such Derivative Works in Source or Object form. 102 | 103 | 3. Grant of Patent License. Subject to the terms and conditions of 104 | this License, each Contributor hereby grants to You a perpetual, 105 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 106 | (except as stated in this section) patent license to make, have made, 107 | use, offer to sell, sell, import, and otherwise transfer the Work, 108 | where such license applies only to those patent claims licensable 109 | by such Contributor that are necessarily infringed by their 110 | Contribution(s) alone or by combination of their Contribution(s) 111 | with the Work to which such Contribution(s) was submitted. If You 112 | institute patent litigation against any entity (including a 113 | cross-claim or counterclaim in a lawsuit) alleging that the Work 114 | or a Contribution incorporated within the Work constitutes direct 115 | or contributory patent infringement, then any patent licenses 116 | granted to You under this License for that Work shall terminate 117 | as of the date such litigation is filed. 118 | 119 | 4. Redistribution. 
You may reproduce and distribute copies of the 120 | Work or Derivative Works thereof in any medium, with or without 121 | modifications, and in Source or Object form, provided that You 122 | meet the following conditions: 123 | 124 | (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and 125 | 126 | (b) You must cause any modified files to carry prominent notices stating that You changed the files; and 127 | 128 | (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and 129 | 130 | (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. 131 | 132 | You may add Your own copyright statement to Your modifications and 133 | may provide additional or different license terms and conditions 134 | for use, reproduction, or distribution of Your modifications, or 135 | for any such Derivative Works as a whole, provided Your use, 136 | reproduction, and distribution of the Work otherwise complies with 137 | the conditions stated in this License. 138 | 139 | 5. Submission of Contributions. Unless You explicitly state otherwise, 140 | any Contribution intentionally submitted for inclusion in the Work 141 | by You to the Licensor shall be under the terms and conditions of 142 | this License, without any additional terms or conditions. 143 | Notwithstanding the above, nothing herein shall supersede or modify 144 | the terms of any separate license agreement you may have executed 145 | with Licensor regarding such Contributions. 146 | 147 | 6. Trademarks. This License does not grant permission to use the trade 148 | names, trademarks, service marks, or product names of the Licensor, 149 | except as required for reasonable and customary use in describing the 150 | origin of the Work and reproducing the content of the NOTICE file. 151 | 152 | 7. Disclaimer of Warranty. Unless required by applicable law or 153 | agreed to in writing, Licensor provides the Work (and each 154 | Contributor provides its Contributions) on an "AS IS" BASIS, 155 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 156 | implied, including, without limitation, any warranties or conditions 157 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 158 | PARTICULAR PURPOSE. You are solely responsible for determining the 159 | appropriateness of using or redistributing the Work and assume any 160 | risks associated with Your exercise of permissions under this License. 161 | 162 | 8. 
Limitation of Liability. In no event and under no legal theory, 163 | whether in tort (including negligence), contract, or otherwise, 164 | unless required by applicable law (such as deliberate and grossly 165 | negligent acts) or agreed to in writing, shall any Contributor be 166 | liable to You for damages, including any direct, indirect, special, 167 | incidental, or consequential damages of any character arising as a 168 | result of this License or out of the use or inability to use the 169 | Work (including but not limited to damages for loss of goodwill, 170 | work stoppage, computer failure or malfunction, or any and all 171 | other commercial damages or losses), even if such Contributor 172 | has been advised of the possibility of such damages. 173 | 174 | 9. Accepting Warranty or Additional Liability. While redistributing 175 | the Work or Derivative Works thereof, You may choose to offer, 176 | and charge a fee for, acceptance of support, warranty, indemnity, 177 | or other liability obligations and/or rights consistent with this 178 | License. However, in accepting such obligations, You may act only 179 | on Your own behalf and on Your sole responsibility, not on behalf 180 | of any other Contributor, and only if You agree to indemnify, 181 | defend, and hold each Contributor harmless for any liability 182 | incurred by, or claims asserted against, such Contributor by reason 183 | of your accepting any such warranty or additional liability. 184 | 185 | END OF TERMS AND CONDITIONS 186 | 187 | APPENDIX: How to apply the Apache License to your work. 188 | 189 | To apply the Apache License to your work, attach the following 190 | boilerplate notice, with the fields enclosed by brackets "[]" 191 | replaced with your own identifying information. (Don't include 192 | the brackets!) The text should be enclosed in the appropriate 193 | comment syntax for the file format. We also recommend that a 194 | file or class name and description of purpose be included on the 195 | same "printed page" as the copyright notice for easier 196 | identification within third-party archives. 197 | 198 | Copyright [yyyy] [name of copyright owner] 199 | 200 | Licensed under the Apache License, Version 2.0 (the "License"); 201 | you may not use this file except in compliance with the License. 202 | You may obtain a copy of the License at 203 | 204 | http://www.apache.org/licenses/LICENSE-2.0 205 | 206 | Unless required by applicable law or agreed to in writing, software 207 | distributed under the License is distributed on an "AS IS" BASIS, 208 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 209 | See the License for the specific language governing permissions and 210 | limitations under the License. 211 | 212 | 213 | ------------------ 214 | 215 | ** lm-evaluation-harness - https://github.com/EleutherAI/lm-evaluation-harness 216 | 217 | MIT License 218 | 219 | Copyright (c) 2020 EleutherAI 220 | 221 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 222 | 223 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 
224 | 225 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 226 | 227 | ------------------ 228 | 229 | 230 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: opt 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - pip 8 | - python=3.9.0 9 | - pip: 10 | - absl-py 11 | - accelerate 12 | - aiohttp 13 | - aiosignal 14 | - async-timeout 15 | - attrs 16 | - best-download 17 | - black 18 | - chardet 19 | - click 20 | - colorama 21 | - cycler 22 | - cython 23 | - dataproperty 24 | - datasets 25 | - dill 26 | - dynet38 27 | - einops 28 | - filelock 29 | - fonttools 30 | - frozenlist 31 | - fsspec 32 | - hjson 33 | - huggingface-hub 34 | - iniconfig 35 | - jieba 36 | - joblib 37 | - jsonlines 38 | - kiwisolver 39 | - lm-dataformat 40 | - matplotlib 41 | - mbstrdecoder 42 | - mock 43 | - msgfy 44 | - multidict 45 | - multiprocess 46 | - mypy-extensions 47 | - nagisa 48 | - ninja 49 | - nltk 50 | - numexpr 51 | - packaging 52 | - pandas 53 | - pathspec 54 | - pathvalidate 55 | - platformdirs 56 | - pluggy 57 | - portalocker 58 | - psutil 59 | - py 60 | - py-cpuinfo 61 | - pyarrow 62 | - pycountry 63 | - pydantic 64 | - pyparsing 65 | - pytablewriter 66 | - pytest 67 | - python-dateutil 68 | - pytz 69 | - pyyaml 70 | - regex 71 | - rehash 72 | - responses 73 | - rouge-score 74 | - sacrebleu 75 | - scikit-learn 76 | - scipy 77 | - seaborn 78 | - sqlitedict 79 | - tabledata 80 | - tabulate 81 | - tcolorpy 82 | - threadpoolctl 83 | - tokenizers==0.12 84 | - toml 85 | - tomli 86 | - torchsummary 87 | - tqdm 88 | - tqdm-multiprocess 89 | - tueplots 90 | - typepy 91 | - ujson 92 | - xxhash 93 | - yarl 94 | - zstandard 95 | -------------------------------------------------------------------------------- /lm_eval/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import Iterable 3 | import numpy as np 4 | import random 5 | import re 6 | import os 7 | import json 8 | import hashlib 9 | import datasets 10 | from sqlitedict import SqliteDict 11 | from tqdm import tqdm 12 | import torch 13 | import torch.nn.functional as F 14 | from torch.utils.data import Dataset, DataLoader 15 | from lm_eval.metrics import mean, weighted_perplexity, weighted_mean, bits_per_byte 16 | from lm_eval import utils 17 | from abc import abstractmethod 18 | from einops import rearrange 19 | from tqdm import tqdm 20 | from collections import defaultdict 21 | import pickle 22 | 23 | class LM(abc.ABC): 24 | def __init__(self): 25 | self.cache_hook = CacheHook(None) 26 | 27 | @abstractmethod 28 | def loglikelihood(self, requests): 29 | """Compute log-likelihood of generating a continuation from a context. 30 | Downstream tasks should attempt to use loglikelihood instead of other 31 | LM calls whenever possible. 32 | 33 | :param requests: list 34 | A list of pairs (context, continuation) 35 | context: str 36 | Context string. 
Implementations of LM must be able to handle an 37 | empty context string. 38 | continuation: str 39 | The continuation over which log likelihood will be calculated. If 40 | there is a word boundary, the space should be in the continuation. 41 | For example, context="hello" continuation=" world" is correct. 42 | :return: list 43 | A list of pairs (logprob, isgreedy) 44 | logprob: float 45 | The log probability of `continuation` 46 | isgreedy: 47 | Whether `continuation` would be generated by greedy sampling from `context` 48 | """ 49 | pass 50 | 51 | @abstractmethod 52 | def loglikelihood_rolling(self, requests): 53 | """Compute full log-likelihood of a string, with no truncation, for perplexity computation 54 | - We will use the full max context length of the model. 55 | - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to 56 | the max context length. 57 | - IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations 58 | which may simply concatenate multiple documents together. 59 | - IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into 60 | multiple chunks, the last input will still a full-sized context. 61 | Example: 62 | Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ] 63 | Prefix: EOT 64 | Max context length: 4 65 | Resulting input/prediction pairs: 66 | 67 | INPUT: EOT 0 1 2 68 | PRED: 0 1 2 3 69 | 70 | INPUT: 3 4 5 6 71 | PRED: 4 5 6 7 72 | 73 | INPUT: 5 6 7 8 74 | PRED: 8 9 75 | 76 | Observe that: 77 | 1. Each token is predicted exactly once 78 | 2. For the last pair, we provide the full context, but only score the last two tokens 79 | 80 | :param requests: list 81 | A list of strings 82 | string: str 83 | String for which we are computing per-toke loglikelihood 84 | :return: list 85 | A list of pairs (logprob, isgreedy) 86 | logprob: float 87 | The log probability of `continuation` 88 | isgreedy: 89 | Whether `continuation` would be generated by greedy sampling from `context` 90 | """ 91 | pass 92 | 93 | # TODO: Add an optional max length 94 | @abstractmethod 95 | def greedy_until(self, requests): 96 | """Generate greedily until a stopping sequence 97 | 98 | :param requests: list 99 | A list of pairs (context, until) 100 | context: str 101 | Context string 102 | until: [str] 103 | The string sequences to generate until. These string sequences 104 | may each span across multiple tokens, or may be part of one token. 105 | :return: list 106 | A list of strings continuation 107 | continuation: str 108 | The generated continuation. 
109 | """ 110 | pass 111 | 112 | @abstractmethod 113 | def calculate_importance(self, dataloader): 114 | ''' 115 | Docstring 116 | ''' 117 | pass 118 | 119 | @classmethod 120 | def create_from_arg_string(cls, arg_string, additional_config=None): 121 | additional_config = {} if additional_config is None else additional_config 122 | args = utils.simple_parse_args_string(arg_string) 123 | args2 = {k: v for k, v in additional_config.items() if v is not None} 124 | return cls(**args, **args2) 125 | 126 | def set_cache_hook(self, cache_hook): 127 | self.cache_hook = cache_hook 128 | 129 | def print_active_bytes(): 130 | stats = torch.cuda.memory_stats() 131 | current_active_byte = stats["active_bytes.all.current"] 132 | print(current_active_byte) 133 | 134 | class BaseLM(LM): 135 | @property 136 | @abstractmethod 137 | def eot_token_id(self): 138 | pass 139 | 140 | @property 141 | @abstractmethod 142 | def max_length(self): 143 | pass 144 | 145 | @property 146 | @abstractmethod 147 | def max_gen_toks(self): 148 | pass 149 | 150 | @property 151 | @abstractmethod 152 | def batch_size(self): 153 | pass 154 | 155 | @property 156 | @abstractmethod 157 | def device(self): 158 | pass 159 | 160 | @abstractmethod 161 | def tok_encode(self, string: str): 162 | pass 163 | 164 | @abstractmethod 165 | def tok_decode(self, tokens: Iterable[int]): 166 | pass 167 | 168 | @abstractmethod 169 | def _model_generate(self, context, max_length, eos_token_id): 170 | pass 171 | 172 | @abstractmethod 173 | def _model_call(self, inps): 174 | """ 175 | inps: a torch tensor of shape [batch, sequence] 176 | the size of sequence may vary from call to call 177 | 178 | returns: a torch tensor of shape [batch, sequence, vocab] with the 179 | logits returned from the model 180 | """ 181 | pass 182 | 183 | def dataset_analysis(self, dataloader): 184 | l = [] 185 | for ctx, cont, inp, l_ctx, l_cont in tqdm(dataloader): 186 | batch_max_length = torch.max(l_ctx + l_cont).item() 187 | l.append(batch_max_length) 188 | return l 189 | 190 | def calculate_importance(self, dataloader): 191 | num_hidden_layers = self.opt.config.num_hidden_layers 192 | num_heads = self.opt.config.num_attention_heads 193 | tot_tokens, eff_tokens = 0, 0 194 | importance_score = torch.zeros(num_hidden_layers, num_heads).to('cpu') 195 | ## disable dropout 196 | self.opt.eval() 197 | 198 | for ctx, cont, inp, l_ctx, l_cont in tqdm(dataloader): 199 | batch_max_length = torch.max(l_ctx + l_cont).item() 200 | inp = inp[:, :batch_max_length] 201 | attn_mask = torch.ones((len(l_ctx), batch_max_length), dtype=torch.long) 202 | labels = torch.empty((len(l_ctx), batch_max_length), dtype=torch.long).fill_(-100.) 
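        # Note: -100 is the default ignore_index of the Hugging Face LM cross-entropy loss, so only the
        # continuation positions filled in by the loop below contribute to the loss (and hence to the
        # head importance gradients computed from it).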
203 | for i in range(len(l_ctx)): 204 | attn_mask[i][l_ctx[i]+l_cont[i]:] = torch.zeros(batch_max_length - (l_ctx[i]+l_cont[i])) 205 | labels[i][l_ctx[i]: l_ctx[i] + l_cont[i]] = inp[i][l_ctx[i]:l_ctx[i]+l_cont[i]] 206 | tot_tokens += attn_mask[i].float().cpu().detach().sum().data - 1 ## if the length of the sequence is N, then the gradient is calculated over N - 1 tokens 207 | eff_tokens += (labels[i] != -100.).cpu().detach().sum().data ## we won't have gradients for non-label positions in the last layer 208 | ll = self._model_call(inp.to(self.device), attn_mask.to(self.device), labels.to(self.device)) 209 | ll.backward() 210 | for layer in range(num_hidden_layers): 211 | self_attention = self.opt.get_decoder().layers[layer].self_attn 212 | attn_x = self_attention.context_layer_val 213 | grad_attn_x = self_attention.context_layer_val_grad 214 | dim = attn_x.shape[-1] 215 | attn_x, grad_attn_x = map(lambda x: rearrange(x, 'b l (h d) -> b h l d', h=num_heads, d=dim//num_heads), (attn_x, grad_attn_x)) # shape = bs, num_heads, seq_len, dim_per_head 216 | dot = torch.einsum("bhli,bhli->bhl", [grad_attn_x, attn_x]).to('cpu') # not all layers are on the same device hence make sure dot is on the self.device 217 | importance_score[layer] += dot.abs().sum(-1).sum(0).detach() 218 | ## helps in reducing the memory footprint 219 | self.opt.zero_grad() 220 | del attn_x, grad_attn_x, dot, ll 221 | # self.ds_engine.module.zero_grad() 222 | for param in self.opt.parameters(): 223 | param.grad = None 224 | 225 | importance_score[:-1] /= tot_tokens 226 | importance_score[-1] /= eff_tokens 227 | return importance_score 228 | 229 | def loglikelihood(self, requests): 230 | new_reqs = [] 231 | for context, continuation in requests: 232 | if context == "": 233 | # end of text as context 234 | context_enc = [self.eot_token_id] 235 | else: 236 | context_enc = self.tok_encode(context) 237 | 238 | continuation_enc = self.tok_encode(continuation) 239 | 240 | new_reqs.append(((context, continuation), context_enc, continuation_enc)) 241 | 242 | return self._loglikelihood_tokens(new_reqs) 243 | 244 | def loglikelihood_rolling(self, requests): 245 | # TODO: Implement caching once we've confirmed the perplexity implementation 246 | # TODO: automatic batch size detection for vectorization 247 | 248 | loglikelihoods = [] 249 | for (string,) in tqdm(requests): 250 | rolling_token_windows = list( 251 | map( 252 | utils.make_disjoint_window, 253 | utils.get_rolling_token_windows( 254 | token_list=self.tok_encode(string), 255 | prefix_token=self.eot_token_id, 256 | max_seq_len=self.max_length, 257 | context_len=1, 258 | ), 259 | ) 260 | ) 261 | 262 | rolling_token_windows = [(None,) + x for x in rolling_token_windows] 263 | 264 | # TODO: extract out this call so it only gets called once and also somehow figure out partial caching for 265 | # that 266 | string_nll = self._loglikelihood_tokens( 267 | rolling_token_windows, disable_tqdm=True 268 | ) 269 | 270 | # discard is_greedy 271 | string_nll = [x[0] for x in string_nll] 272 | 273 | string_nll = sum(string_nll) 274 | loglikelihoods.append(string_nll) 275 | 276 | return loglikelihoods 277 | 278 | def _loglikelihood_tokens(self, requests, disable_tqdm=False): 279 | # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context 280 | res = [] 281 | 282 | def _collate(x): 283 | # the negative sign on len(toks) sorts descending - this has a few advantages: 284 | # - time estimates will always be over not underestimates, which is more 
useful for planning 285 | # - to know the size of a batch when going through the list, you know the first one is always the batch 286 | # padded context length. this is useful to simplify the batching logic and more importantly to make 287 | # automatic adaptive batches much much easier to implement 288 | # - any OOMs will happen right away rather than near the end 289 | 290 | toks = x[1] + x[2] 291 | return -len(toks), tuple(toks) 292 | 293 | # TODO: automatic (variable) batch size detection for vectorization 294 | re_ord = utils.Reorderer(requests, _collate) 295 | for chunk in utils.chunks( 296 | tqdm(re_ord.get_reordered(), disable=disable_tqdm), self.batch_size 297 | ): 298 | inps = [] 299 | cont_toks_list = [] 300 | inplens = [] 301 | 302 | padding_length = None 303 | 304 | # because vectorizing is annoying, we first convert each (context, continuation) pair to padded 305 | # tensors, then we pack them together into a batch, call the model, and then pick it all apart 306 | # again because vectorizing is annoying 307 | 308 | for _, context_enc, continuation_enc in chunk: 309 | # sanity check 310 | assert len(context_enc) > 0 311 | assert len(continuation_enc) > 0 312 | assert len(continuation_enc) <= self.max_length 313 | 314 | # how this all works: 315 | # CTX CONT 316 | # inp 0 1 2 3|4 5 6 7 8 9 <- last token is deleted by inp[:, :-1] 317 | # gpt2 \ \ 318 | # logits 1 2 3|4 5 6 7 8 9 <- the ctx half gets tossed out by the 319 | # cont_toks 4 5 6 7 8 9 [:, -len(continuation_enc):, :self.vocab_size] slice 320 | 321 | # when too long to fit in context, truncate from the left 322 | inp = torch.tensor( 323 | (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1], 324 | dtype=torch.long, 325 | ).to(self.device) 326 | (inplen,) = inp.shape 327 | 328 | cont = continuation_enc 329 | 330 | # since in _collate we make sure length is descending, the longest is always the first one. 
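                # padding_length is set once, from this first (longest) sequence, and then reused to pad
                # every remaining sequence in the chunk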
331 | padding_length = ( 332 | padding_length if padding_length is not None else inplen 333 | ) 334 | 335 | # pad length from seq to padding_length 336 | inp = torch.cat( 337 | [ 338 | inp, # [seq] 339 | torch.zeros(padding_length - inplen, dtype=torch.long).to( 340 | inp.device 341 | ), # [padding_length - seq] 342 | ], 343 | dim=0, 344 | ) 345 | 346 | inps.append(inp.unsqueeze(0)) # [1, padding_length] 347 | cont_toks_list.append(cont) 348 | inplens.append(inplen) 349 | 350 | batched_inps = torch.cat(inps, dim=0) # [batch, padding_length 351 | multi_logits = F.log_softmax( 352 | self._model_call(batched_inps), dim=-1 353 | ).cpu() # [batch, padding_length, vocab] 354 | 355 | for (cache_key, _, _), logits, inp, inplen, cont_toks in zip( 356 | chunk, multi_logits, inps, inplens, cont_toks_list 357 | ): 358 | 359 | # Slice to original seq length 360 | contlen = len(cont_toks) 361 | logits = logits[inplen - contlen : inplen].unsqueeze( 362 | 0 363 | ) # [1, seq, vocab] 364 | 365 | # Check if per-token argmax is exactly equal to continuation 366 | greedy_tokens = logits.argmax(dim=-1) 367 | cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze( 368 | 0 369 | ) # [1, seq] 370 | max_equal = (greedy_tokens == cont_toks).all() 371 | 372 | # Obtain log-probs at the corresponding continuation token indices 373 | # last_token_slice = logits[:, -1, :].squeeze(0).tolist() 374 | logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze( 375 | -1 376 | ) # [1, seq] 377 | 378 | # Answer: (log prob, is-exact-match) 379 | answer = (float(logits.sum()), bool(max_equal)) 380 | 381 | # partial caching 382 | if cache_key is not None: 383 | self.cache_hook.add_partial("loglikelihood", cache_key, answer) 384 | 385 | res.append(answer) 386 | 387 | return re_ord.get_original(res) 388 | 389 | def greedy_until(self, requests): 390 | # TODO: implement fully general `until` that handles until that are 391 | # multiple tokens or that span multiple tokens correctly 392 | 393 | # TODO: extract to TokenizedLM? 394 | res = [] 395 | 396 | def _collate(x): 397 | toks = self.tok_encode(x[0]) 398 | return len(toks), x[0] 399 | 400 | re_ord = utils.Reorderer(requests, _collate) 401 | 402 | for context, until in tqdm(re_ord.get_reordered()): 403 | if isinstance(until, str): 404 | until = [until] 405 | 406 | (primary_until,) = self.tok_encode(until[0]) 407 | 408 | context_enc = torch.tensor( 409 | [self.tok_encode(context)[self.max_gen_toks - self.max_length :]] 410 | ).to(self.device) 411 | 412 | cont = self._model_generate( 413 | context_enc, context_enc.shape[1] + self.max_gen_toks, primary_until 414 | ) 415 | 416 | s = self.tok_decode(cont[0].tolist()[context_enc.shape[1] :]) 417 | 418 | for term in until: 419 | s = s.split(term)[0] 420 | 421 | # partial caching 422 | self.cache_hook.add_partial("greedy_until", (context, until), s) 423 | 424 | res.append(s) 425 | 426 | return re_ord.get_original(res) 427 | 428 | 429 | class Task(abc.ABC): 430 | """A task represents an entire benchmark including its dataset, problems, 431 | answers, and evaluation methods. See BoolQ for a simple example implementation 432 | 433 | A `doc` can be any python object which represents one instance of evaluation. 434 | This is usually a dictionary e.g. 435 | {"question": ..., "answer": ...} or 436 | {"question": ..., question, answer) 437 | """ 438 | 439 | # The name of the `Task` benchmark as denoted in the HuggingFace datasets Hub 440 | # or a path to a custom `datasets` loading script. 
441 | DATASET_PATH: str = None 442 | 443 | # The name of a subset within `DATASET_PATH`. 444 | DATASET_NAME: str = None 445 | 446 | def __init__(self, data_dir=None, cache_dir=None, download_mode=None): 447 | """ 448 | :param data_dir: str 449 | Stores the path to a local folder containing the `Task`'s data files. 450 | Use this to specify the path to manually downloaded data (usually when 451 | the dataset is not publicly accessible). 452 | :param cache_dir: str 453 | The directory to read/write the `Task` dataset. This follows the 454 | HuggingFace `datasets` API with the default cache directory located at: 455 | `~/.cache/huggingface/datasets` 456 | NOTE: You can change the cache location globally for a given process 457 | by setting the shell environment variable, `HF_DATASETS_CACHE`, 458 | to another directory: 459 | `export HF_DATASETS_CACHE="/path/to/another/directory"` 460 | :param download_mode: datasets.DownloadMode 461 | How to treat pre-existing `Task` downloads and data. 462 | - `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS` 463 | Reuse download and reuse dataset. 464 | - `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS` 465 | Reuse download with fresh dataset. 466 | - `datasets.DownloadMode.FORCE_REDOWNLOAD` 467 | Fresh download and fresh dataset. 468 | """ 469 | self.download(data_dir, cache_dir, download_mode) 470 | self._training_docs = None 471 | self._validation_docs = None 472 | self._fewshot_docs = None 473 | 474 | def download(self, data_dir=None, cache_dir=None, download_mode=None): 475 | """Downloads and returns the task dataset. 476 | Override this method to download the dataset from a custom API. 477 | 478 | :param data_dir: str 479 | Stores the path to a local folder containing the `Task`'s data files. 480 | Use this to specify the path to manually downloaded data (usually when 481 | the dataset is not publicly accessible). 482 | :param cache_dir: str 483 | The directory to read/write the `Task` dataset. This follows the 484 | HuggingFace `datasets` API with the default cache directory located at: 485 | `~/.cache/huggingface/datasets` 486 | NOTE: You can change the cache location globally for a given process 487 | by setting the shell environment variable, `HF_DATASETS_CACHE`, 488 | to another directory: 489 | `export HF_DATASETS_CACHE="/path/to/another/directory"` 490 | :param download_mode: datasets.DownloadMode 491 | How to treat pre-existing `Task` downloads and data. 492 | - `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS` 493 | Reuse download and reuse dataset. 494 | - `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS` 495 | Reuse download with fresh dataset. 496 | - `datasets.DownloadMode.FORCE_REDOWNLOAD` 497 | Fresh download and fresh dataset. 
498 | """ 499 | self.dataset = datasets.load_dataset( 500 | path=self.DATASET_PATH, 501 | name=self.DATASET_NAME, 502 | data_dir=data_dir, 503 | cache_dir=cache_dir, 504 | download_mode=download_mode, 505 | ) 506 | 507 | def should_decontaminate(self): 508 | """Whether this task supports decontamination against model training set.""" 509 | return False 510 | 511 | @abstractmethod 512 | def has_training_docs(self): 513 | """Whether the task has a training set""" 514 | pass 515 | 516 | @abstractmethod 517 | def has_validation_docs(self): 518 | """Whether the task has a validation set""" 519 | pass 520 | 521 | @abstractmethod 522 | def has_test_docs(self): 523 | """Whether the task has a test set""" 524 | pass 525 | 526 | def training_docs(self): 527 | """ 528 | :return: Iterable[obj] 529 | A iterable of any object, that doc_to_text can handle 530 | """ 531 | return [] 532 | 533 | def validation_docs(self): 534 | """ 535 | :return: Iterable[obj] 536 | A iterable of any object, that doc_to_text can handle 537 | """ 538 | return [] 539 | 540 | def test_docs(self): 541 | """ 542 | :return: Iterable[obj] 543 | A iterable of any object, that doc_to_text can handle 544 | """ 545 | return [] 546 | 547 | def _process_doc(self, doc): 548 | """ 549 | Override this to process (detokenize, strip, replace, etc.) individual 550 | documents. This can be used in a map over documents of a data split. 551 | E.g. `map(self._process_doc, self.dataset["validation"])` 552 | 553 | :return: dict 554 | The processed version of the specified `doc`. 555 | """ 556 | return doc 557 | 558 | def fewshot_examples(self, k, rnd): 559 | if self._training_docs is None: 560 | self._training_docs = list(self.training_docs()) 561 | 562 | return rnd.sample(self._training_docs, k) 563 | 564 | def doc_to_decontamination_query(self, doc): 565 | print( 566 | "Override doc_to_decontamination_query with document specific decontamination query." 567 | ) 568 | assert False 569 | 570 | @abstractmethod 571 | def doc_to_text(self, doc): 572 | pass 573 | 574 | @abstractmethod 575 | def doc_to_target(self, doc): 576 | pass 577 | 578 | def get_dataloader(self, tokenizer, split='train', batch_size=1): 579 | pass 580 | 581 | @abstractmethod 582 | def construct_requests(self, doc, ctx): 583 | """Uses RequestFactory to construct Requests and returns an iterable of 584 | Requests which will be sent to the LM. 585 | 586 | :param doc: 587 | The document as returned from training_docs, validation_docs, or test_docs. 588 | :param ctx: str 589 | The context string, generated by fewshot_context. This includes the natural 590 | language description, as well as the few shot examples, and the question 591 | part of the document for `doc`. 592 | """ 593 | pass 594 | 595 | @abstractmethod 596 | def process_results(self, doc, results): 597 | """Take a single document and the LM results and evaluates, returning a 598 | dict where keys are the names of submetrics and values are the values of 599 | the metric for that one document 600 | 601 | :param doc: 602 | The document as returned from training_docs, validation_docs, or test_docs. 603 | :param results: 604 | The results of the requests created in construct_requests. 
605 | """ 606 | pass 607 | 608 | @abstractmethod 609 | def aggregation(self): 610 | """ 611 | :returns: {str: [metric_score] -> float} 612 | A dictionary where keys are the names of submetrics and values are 613 | functions that aggregate a list of metric scores 614 | """ 615 | pass 616 | 617 | @abstractmethod 618 | def higher_is_better(self): 619 | """ 620 | :returns: {str: bool} 621 | A dictionary where keys are the names of submetrics and values are 622 | whether a higher value of the submetric is better 623 | """ 624 | pass 625 | 626 | def fewshot_description(self): 627 | import warnings 628 | 629 | warnings.warn( 630 | "`fewshot_description` will be removed in futures versions. Pass " 631 | "any custom descriptions to the `evaluate` function instead.", 632 | DeprecationWarning, 633 | ) 634 | return "" 635 | 636 | @utils.positional_deprecated 637 | def fewshot_context( 638 | self, doc, num_fewshot, provide_description=None, rnd=random, description=None 639 | ): 640 | """Returns a fewshot context string that is made up of a prepended description 641 | (if provided), the `num_fewshot` number of examples, and an appended prompt example. 642 | 643 | :param doc: str 644 | The document as returned from training_docs, validation_docs, or test_docs. 645 | :param num_fewshot: int 646 | The number of fewshot examples to provide in the returned context string. 647 | :param provide_description: bool 648 | Not implemented, and this option is deprecated and will be removed in a future version in favor of a different description providing method 649 | :param rnd: random.Random 650 | The pseudo-random number generator used to randomly sample examples. 651 | WARNING: This is currently a required arg although it's optionalized with a default `None`. 652 | :param description: str 653 | The task's description that will be prepended to the fewshot examples. 654 | :returns: str 655 | The fewshot context. 656 | """ 657 | 658 | assert ( 659 | rnd is not None 660 | ), "A `random.Random` generator argument must be provided to `rnd`" 661 | assert not provide_description, ( 662 | "The `provide_description` arg will be removed in future versions. To prepend " 663 | "a custom description to the context, supply the corresponding string via the " 664 | "`description` arg." 
665 | ) 666 | if provide_description is not None: 667 | # nudge people to not specify it at all 668 | print( 669 | "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict" 670 | ) 671 | 672 | description = description + "\n\n" if description else "" 673 | 674 | if num_fewshot == 0: 675 | labeled_examples = "" 676 | else: 677 | # for sets with no training docs, draw from other set *but ensure no overlap with current doc* 678 | if self.has_training_docs(): 679 | fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd) 680 | else: 681 | if self._fewshot_docs is None: 682 | self._fewshot_docs = list( 683 | self.validation_docs() 684 | if self.has_validation_docs() 685 | else self.test_docs() 686 | ) 687 | 688 | fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1) 689 | 690 | # get rid of the doc that's the one we're evaluating, if it's in the fewshot 691 | fewshotex = [x for x in fewshotex if x != doc][:num_fewshot] 692 | 693 | labeled_examples = ( 694 | "\n\n".join( 695 | [ 696 | self.doc_to_text(doc) + self.doc_to_target(doc) 697 | for doc in fewshotex 698 | ] 699 | ) 700 | + "\n\n" 701 | ) 702 | 703 | example = self.doc_to_text(doc) 704 | return description + labeled_examples + example 705 | 706 | 707 | class MultipleChoiceTask(Task): 708 | def doc_to_target(self, doc): 709 | return " " + doc["choices"][doc["gold"]] 710 | 711 | def construct_requests(self, doc, ctx): 712 | lls = [ 713 | rf.loglikelihood(ctx, " {}".format(choice))[0] for choice in doc["choices"] 714 | ] 715 | 716 | return lls 717 | 718 | def process_results(self, doc, results): 719 | gold = doc["gold"] 720 | 721 | acc = 1.0 if np.argmax(results) == gold else 0.0 722 | completion_len = np.array([float(len(i)) for i in doc["choices"]]) 723 | acc_norm = 1.0 if np.argmax(results / completion_len) == gold else 0.0 724 | 725 | return { 726 | "acc": acc, 727 | "acc_norm": acc_norm, 728 | } 729 | 730 | def higher_is_better(self): 731 | return { 732 | "acc": True, 733 | "acc_norm": True, 734 | } 735 | 736 | def aggregation(self): 737 | return { 738 | "acc": mean, 739 | "acc_norm": mean, 740 | } 741 | 742 | 743 | class PerplexityTask(Task, abc.ABC): 744 | def should_decontaminate(self): 745 | """Whether this task supports decontamination against model training set.""" 746 | return True 747 | 748 | def has_training_docs(self): 749 | return False 750 | 751 | def fewshot_examples(self, k, rnd): 752 | assert k == 0 753 | return [] 754 | 755 | def fewshot_context( 756 | self, doc, num_fewshot, provide_description=None, rnd=None, description=None 757 | ): 758 | assert ( 759 | num_fewshot == 0 760 | ), "The number of fewshot examples must be 0 for perplexity tasks." 761 | assert ( 762 | rnd is not None 763 | ), "A `random.Random` generator argument must be provided to `rnd`." 764 | assert not provide_description, ( 765 | "The `provide_description` arg will be removed in future versions. To prepend " 766 | "a custom description to the context, supply the corresponding string via the " 767 | "`description` arg." 
768 | ) 769 | if provide_description is not None: 770 | # nudge people to not specify it at all 771 | print( 772 | "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict" 773 | ) 774 | 775 | return "" 776 | 777 | def higher_is_better(self): 778 | return { 779 | "word_perplexity": False, 780 | "byte_perplexity": False, 781 | "bits_per_byte": False, 782 | } 783 | 784 | def doc_to_decontamination_query(self, doc): 785 | return doc 786 | 787 | def doc_to_text(self, doc): 788 | return "" 789 | 790 | def doc_to_target(self, doc): 791 | return doc 792 | 793 | def construct_requests(self, doc, ctx): 794 | assert not ctx 795 | req = rf.loglikelihood_rolling(self.doc_to_target(doc)) 796 | return req 797 | 798 | def process_results(self, doc, results): 799 | (loglikelihood,) = results 800 | words = self.count_words(doc) 801 | bytes_ = self.count_bytes(doc) 802 | return { 803 | "word_perplexity": (loglikelihood, words), 804 | "byte_perplexity": (loglikelihood, bytes_), 805 | "bits_per_byte": (loglikelihood, bytes_), 806 | } 807 | 808 | def aggregation(self): 809 | return { 810 | "word_perplexity": weighted_perplexity, 811 | "byte_perplexity": weighted_perplexity, 812 | "bits_per_byte": bits_per_byte, 813 | } 814 | 815 | @classmethod 816 | def count_bytes(cls, doc): 817 | return len(doc.encode("utf-8")) 818 | 819 | @classmethod 820 | def count_words(cls, doc): 821 | """Downstream tasks with custom word boundaries should override this!""" 822 | return len(re.split(r"\s+", doc)) 823 | 824 | 825 | def hash_args(attr, args): 826 | dat = json.dumps([attr] + list(args)) 827 | return hashlib.sha256(dat.encode("utf-8")).hexdigest() 828 | 829 | 830 | class CacheHook: 831 | def __init__(self, cachinglm): 832 | if cachinglm is None: 833 | self.dbdict = None 834 | return 835 | 836 | self.dbdict = cachinglm.dbdict 837 | 838 | def add_partial(self, attr, req, res): 839 | if self.dbdict is None: 840 | return 841 | hsh = hash_args(attr, req) 842 | self.dbdict[hsh] = res 843 | 844 | 845 | class CachingLM: 846 | def __init__(self, lm, cache_db): 847 | """LM wrapper that returns cached results if they exist, and uses the underlying LM if not. 
848 | 849 | :param lm: LM 850 | Underlying LM 851 | :param cache_db: str 852 | Path to cache db 853 | """ 854 | self.lm = lm 855 | self.cache_db = cache_db 856 | if os.path.dirname(cache_db): 857 | os.makedirs(os.path.dirname(cache_db), exist_ok=True) 858 | self.dbdict = SqliteDict(cache_db, autocommit=True) 859 | 860 | # add hook to lm 861 | lm.set_cache_hook(self.get_cache_hook()) 862 | 863 | def __getattr__(self, attr): 864 | def fn(requests): 865 | res = [] 866 | remaining_reqs = [] 867 | 868 | # figure out which ones are cached and which ones are new 869 | for req in requests: 870 | hsh = hash_args(attr, req) 871 | if hsh in self.dbdict: 872 | ob = self.dbdict[hsh] 873 | 874 | assert ob is not None 875 | 876 | res.append(ob) 877 | else: 878 | res.append(None) 879 | remaining_reqs.append(req) 880 | 881 | # actually run the LM on the requests that do not have cached results 882 | rem_res = getattr(self.lm, attr)(remaining_reqs) 883 | 884 | # stick the new ones back into the list and also cache any of the new ones 885 | resptr = 0 886 | for req, r in zip(remaining_reqs, rem_res): 887 | while res[resptr] is not None: 888 | resptr += 1 889 | 890 | res[resptr] = r 891 | 892 | # caching 893 | hsh = hash_args(attr, req) 894 | self.dbdict[hsh] = r 895 | self.dbdict.commit() 896 | 897 | return res 898 | 899 | return fn 900 | 901 | def get_cache_hook(self): 902 | return CacheHook(self) 903 | 904 | 905 | REQUEST_RETURN_LENGTHS = { 906 | "loglikelihood": 2, 907 | "greedy_until": None, 908 | "loglikelihood_rolling": None, 909 | } 910 | 911 | 912 | class Request: 913 | def __init__(self, request_type, args, index=None): 914 | if request_type not in REQUEST_RETURN_LENGTHS.keys(): 915 | raise NotImplementedError( 916 | "The request type {} is not implemented!".format(request_type) 917 | ) 918 | 919 | self.request_type = request_type 920 | self.args = args 921 | self.index = index 922 | 923 | def __iter__(self): 924 | if REQUEST_RETURN_LENGTHS[self.request_type] is None: 925 | raise IndexError("This request type does not return multiple arguments!") 926 | for i in range(REQUEST_RETURN_LENGTHS[self.request_type]): 927 | yield Request(self.request_type, self.args, i) 928 | 929 | def __getitem__(self, i): 930 | if REQUEST_RETURN_LENGTHS[self.request_type] is None: 931 | raise IndexError("This request type does not return multiple arguments!") 932 | return Request(self.request_type, self.args, i) 933 | 934 | def __eq__(self, other): 935 | return ( 936 | self.request_type == other.request_type 937 | and self.args == other.args 938 | and self.index == other.index 939 | ) 940 | 941 | def __repr__(self): 942 | return f"Req_{self.request_type}{self.args}[{self.index}]\n" 943 | 944 | 945 | class RequestFactory: 946 | def __getattr__(self, attr): 947 | def fn(*args): 948 | return Request(attr, args) 949 | 950 | return fn 951 | 952 | 953 | rf = RequestFactory() 954 | -------------------------------------------------------------------------------- /lm_eval/evaluator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import collections 3 | import itertools 4 | import numpy as np 5 | import random 6 | import torch 7 | import pickle 8 | import lm_eval.metrics 9 | import lm_eval.models 10 | import lm_eval.tasks 11 | import lm_eval.base 12 | from lm_eval.utils import positional_deprecated, run_task_tests 13 | from transformers import AutoTokenizer 14 | 15 | @positional_deprecated 16 | def simple_evaluate( 17 | model, 18 | model_args=None, 19 | tasks=[], 20 | 
num_fewshot=0, 21 | batch_size=None, 22 | device=None, 23 | no_cache=False, 24 | limit=None, 25 | bootstrap_iters=100000, 26 | description_dict=None, 27 | check_integrity=False, 28 | decontamination_ngrams_path=None, 29 | head_importance_calc=False, 30 | save_importance_path=None, 31 | ): 32 | 33 | """Instantiate and evaluate a model on a list of tasks. 34 | 35 | :param model: Union[str, LM] 36 | Name of model or LM object, see lm_eval.models.get_model 37 | :param model_args: Optional[str] 38 | String arguments for each model class, see LM.create_from_arg_string. 39 | Ignored if `model` argument is a LM object. 40 | :param tasks: list[Union[str, Task]] 41 | List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise. 42 | :param num_fewshot: int 43 | Number of examples in few-shot context 44 | :param batch_size: int, optional 45 | Batch size for model 46 | :param device: str, optional 47 | PyTorch device (e.g. "cpu" or "cuda:0") for running models 48 | :param no_cache: bool 49 | Whether or not to cache 50 | :param limit: int, optional 51 | Limit the number of examples per task (only use this for testing) 52 | :param bootstrap_iters: 53 | Number of iterations for bootstrap statistics 54 | :param description_dict: dict[str, str] 55 | Dictionary of custom task descriptions of the form: `task_name: description` 56 | :param check_integrity: bool 57 | Whether to run the relevant part of the test suite for the tasks 58 | :return 59 | Dictionary of results 60 | """ 61 | random.seed(1234) 62 | np.random.seed(1234) 63 | 64 | assert tasks != [], "No tasks specified" 65 | 66 | if isinstance(model, str): 67 | if model_args is None: 68 | model_args = "" 69 | lm = lm_eval.models.get_model(model).create_from_arg_string( 70 | model_args, {"batch_size": batch_size, "device": device} 71 | ) 72 | else: 73 | assert isinstance(model, lm_eval.base.LM) 74 | lm = model 75 | 76 | task_dict = lm_eval.tasks.get_task_dict(tasks) 77 | 78 | if check_integrity: 79 | run_task_tests(task_list=tasks) 80 | 81 | if not head_importance_calc: 82 | results = evaluate( 83 | lm=lm, 84 | task_dict=task_dict, 85 | num_fewshot=num_fewshot, 86 | limit=limit, 87 | bootstrap_iters=bootstrap_iters, 88 | description_dict=description_dict, 89 | decontamination_ngrams_path=decontamination_ngrams_path, 90 | ) 91 | results["config"] = { 92 | "model": model, 93 | "model_args": model_args, 94 | "num_fewshot": num_fewshot, 95 | "batch_size": batch_size, 96 | "device": device, 97 | "no_cache": no_cache, 98 | "limit": limit, 99 | "bootstrap_iters": bootstrap_iters, 100 | "description_dict": description_dict, 101 | } 102 | else: 103 | if os.path.exists(save_importance_path): 104 | with open(save_importance_path, 'rb') as handle: 105 | results = pickle.load(handle) 106 | else: 107 | results = head_importance( 108 | lm=lm, 109 | task_dict=task_dict, 110 | num_fewshot=num_fewshot, 111 | description_dict=description_dict, 112 | save_importance_path=save_importance_path, 113 | ) 114 | 115 | return results 116 | 117 | 118 | decontaminate_suffix = "_decontaminate" 119 | 120 | 121 | def head_importance( 122 | lm, 123 | task_dict, 124 | num_fewshot=0, 125 | description_dict=None, 126 | save_importance_path=None, 127 | ): 128 | ''' 129 | Docstring 130 | ''' 131 | for name, task in task_dict.items(): 132 | print(name) 133 | print(task) 134 | split = 'train' if task.has_training_docs() else 'valid' if task.has_validation_docs() else None 135 | if not split: 136 | raise 
RuntimeError("Task has neither train nor validation") 137 | 138 | tokenizer = lm.get_tokenizer() 139 | dataloader = task.get_dataloader(tokenizer, split, subset_size = 2500, batch_size = lm.batch_size, num_fewshot = num_fewshot) 140 | result = lm.calculate_importance(dataloader) 141 | os.makedirs(os.path.dirname(save_importance_path), exist_ok = True) 142 | with open(save_importance_path, 'wb') as handle: 143 | pickle.dump(result, handle) 144 | print('Importance numbers saved!!') 145 | 146 | return result 147 | 148 | @positional_deprecated 149 | def evaluate( 150 | lm, 151 | task_dict, 152 | provide_description=None, 153 | num_fewshot=0, 154 | limit=10000, 155 | bootstrap_iters=100000, 156 | description_dict=None, 157 | decontamination_ngrams_path=None, 158 | ): 159 | """Instantiate and evaluate a model on a list of tasks. 160 | 161 | :param lm: obj 162 | Language Model 163 | :param task_dict: dict[str, Task] 164 | Dictionary of tasks. Tasks will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise. 165 | :param provide_description: bool 166 | Not implemented, and this option is deprecated and will be removed in a future version in favor of a different description providing method 167 | :param num_fewshot: int 168 | Number of examples in few-shot context 169 | :param limit: int, optional 170 | Limit the number of examples per task (only use this for testing) 171 | :param bootstrap_iters: 172 | Number of iterations for bootstrap statistics 173 | :param description_dict: dict[str, str] 174 | Dictionary of custom task descriptions of the form: `task_name: description` 175 | :return 176 | Dictionary of results 177 | """ 178 | # TODO: completely refactor this entire function to not be a huge mess, ideally breaking it down into smaller pieces 179 | 180 | # TODO: todo: implement proper description-providing system 181 | assert not provide_description # not implemented. 
182 | if provide_description is not None: 183 | # nudge people to not specify it at all 184 | print( 185 | "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict" 186 | ) 187 | 188 | decontaminate = decontamination_ngrams_path is not None 189 | 190 | task_dict_items = [ 191 | (name, task) 192 | for name, task in task_dict.items() 193 | if (task.has_validation_docs() or task.has_test_docs()) 194 | ] 195 | 196 | task_dict_items = [ 197 | (name, task) 198 | for name, task in task_dict.items() 199 | if (task.has_validation_docs() or task.has_test_docs()) 200 | ] 201 | results = collections.defaultdict(dict) 202 | versions = collections.defaultdict(dict) 203 | 204 | requests = collections.defaultdict(list) 205 | requests_origin = collections.defaultdict(list) 206 | 207 | overlaps = collections.defaultdict(list) # {task_name: contaminated_docs} 208 | 209 | # If we ever run into issues where the eval tasks don't fit in memory and we can't afford a machine with bigger 210 | # memory, we can always modify this plumbing to support that, but I didn't want to include it just yet because 211 | # over-engineering is bad (or we could make it write the requests to disk and then read them back out again 212 | # - probably using an sqlite db because of all the moving parts we have 213 | 214 | # TODO: we need unit tests & sanity checks or something to ensure that the return of `validation_docs` is stable 215 | docs = {} 216 | 217 | docs_for_decontamination = collections.defaultdict(list) 218 | 219 | # get lists of each type of request 220 | for task_name, task in task_dict_items: 221 | versions[task_name] = task.VERSION 222 | # default to test doc, fall back to val doc if validation unavailable 223 | # TODO: the test-fallback-to-val system isn't final, we should revisit it at some point 224 | if task.has_test_docs(): 225 | task_doc_func = task.test_docs 226 | task_set = "test" # Required for caching in the decontamination 227 | elif task.has_validation_docs(): 228 | task_set = "val" # Required for caching in the decontamination 229 | task_doc_func = task.validation_docs 230 | else: 231 | raise RuntimeError("Task has neither test_docs nor validation_docs") 232 | 233 | # deterministically shuffle docs and chop off the first `limit` because sometimes docs are in some kind of order 234 | # can remove 1000 during actual evaluation 235 | task_docs = list(task_doc_func()) 236 | rnd = random.Random() 237 | rnd.seed(42) 238 | rnd.shuffle(task_docs) 239 | 240 | description = ( 241 | description_dict[task_name] 242 | if description_dict and task_name in description_dict 243 | else "" 244 | ) 245 | 246 | for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)): 247 | 248 | if decontaminate and task.should_decontaminate(): 249 | docs_for_decontamination[(task_name, task_set)].append( 250 | task.doc_to_decontamination_query(doc) 251 | ) 252 | 253 | docs[(task_name, doc_id)] = doc 254 | ctx = task.fewshot_context( 255 | doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description 256 | ) 257 | reqs = task.construct_requests(doc, ctx) 258 | if not isinstance(reqs, (list, tuple)): 259 | reqs = [reqs] 260 | for i, req in enumerate(reqs): 261 | requests[req.request_type].append(req) 262 | # i: index in requests for a single task instance 263 | # doc_id: unique id that we can get back to a doc using `docs` 264 | requests_origin[req.request_type].append((i, task_name, doc, doc_id)) 265 | 266 | # Compare all tasks/sets at once to ensure a single training 
set scan 267 | if decontaminate: 268 | from lm_eval.decontamination.decontaminate import get_train_overlap 269 | 270 | print("Finding train/test overlap, please wait...") 271 | overlaps = get_train_overlap( 272 | docs_for_decontamination, decontamination_ngrams_path, limit 273 | ) 274 | 275 | # all responses for each (task, doc) 276 | process_res_queue = collections.defaultdict(list) 277 | 278 | # execute each type of request 279 | for reqtype, reqs in requests.items(): 280 | # TODO: right now, this code runs multiple separate LM requests for multiple Requests differing 281 | # only in index. We could implement some kind of caching, but that would be more of a band-aid 282 | # solution. we could also implement some kind of auto-grouping here; 283 | # they should end up next to each other. 284 | 285 | print("Running", reqtype, "requests") 286 | resps = getattr(lm, reqtype)([req.args for req in reqs]) 287 | resps = [ 288 | x if req.index is None else x[req.index] for x, req in zip(resps, reqs) 289 | ] 290 | 291 | for resp, (i, task_name, doc, doc_id) in zip(resps, requests_origin[reqtype]): 292 | process_res_queue[(task_name, doc_id)].append((i, resp)) 293 | 294 | vals = collections.defaultdict(list) 295 | 296 | # unpack results and sort back in order and return control to Task 297 | for (task_name, doc_id), requests in process_res_queue.items(): 298 | requests.sort(key=lambda x: x[0]) 299 | requests = [x[1] for x in requests] 300 | 301 | task = task_dict[task_name] 302 | doc = docs[(task_name, doc_id)] 303 | 304 | metrics = task.process_results(doc, requests) 305 | for metric, value in metrics.items(): 306 | vals[(task_name, metric)].append(value) 307 | 308 | # Re-use the evaluation for the decontaminated set by just ignoring the overlaps 309 | if decontaminate and task_name in overlaps: 310 | if doc_id not in overlaps[task_name]: 311 | vals[(task_name, metric + decontaminate_suffix)].append(value) 312 | 313 | # aggregate results 314 | for (task_name, metric), items in vals.items(): 315 | task = task_dict[task_name] 316 | real_metric = metric # key when looking up the metric with task.aggregation 317 | if metric.endswith(decontaminate_suffix): 318 | real_metric = metric.replace( 319 | decontaminate_suffix, "" 320 | ) # decontaminated still uses the same metric 321 | results[task_name][metric] = task.aggregation()[real_metric](items) 322 | 323 | # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap 324 | # so we run them less iterations. 
still looking for a cleaner way to do this 325 | # if metric == 'bleu': 326 | # stderr = lm_eval.metrics.stderr_for_metric( 327 | # metric=task.aggregation()[real_metric], 328 | # bootstrap_iters=min(bootstrap_iters, 1000) 329 | # if metric in ["bleu", "chrf", "ter"] 330 | # else bootstrap_iters, 331 | # ) 332 | 333 | # if stderr is not None: 334 | # results[task_name][metric + "_stderr"] = stderr(items) 335 | 336 | return {"results": dict(results), "versions": dict(versions)} 337 | 338 | 339 | def make_table(result_dict): 340 | """Generate table of results.""" 341 | from pytablewriter import MarkdownTableWriter, LatexTableWriter 342 | 343 | md_writer = MarkdownTableWriter() 344 | latex_writer = LatexTableWriter() 345 | md_writer.headers = ["Task", "Version", "Metric", "Value", "", "Stderr"] 346 | latex_writer.headers = ["Task", "Version", "Metric", "Value", "", "Stderr"] 347 | 348 | values = [] 349 | 350 | for k, dic in result_dict["results"].items(): 351 | version = result_dict["versions"][k] 352 | for m, v in dic.items(): 353 | if m.endswith("_stderr"): 354 | continue 355 | 356 | if m + "_stderr" in dic: 357 | se = dic[m + "_stderr"] 358 | values.append([k, version, m, "%.4f" % v, "±", "%.4f" % se]) 359 | else: 360 | values.append([k, version, m, "%.4f" % v, "", ""]) 361 | k = "" 362 | version = "" 363 | md_writer.value_matrix = values 364 | latex_writer.value_matrix = values 365 | 366 | # todo: make latex table look good 367 | # print(latex_writer.dumps()) 368 | 369 | return md_writer.dumps() 370 | -------------------------------------------------------------------------------- /lm_eval/models/66b_device_map.txt: -------------------------------------------------------------------------------- 1 | {'model.decoder.embed_tokens': 0, 'lm_head': 0, 'model.decoder.embed_positions': 0, 'model.decoder.final_layer_norm': 0, 'model.decoder.layers.0': 0, 'model.decoder.layers.1': 0, 'model.decoder.layers.2': 0,'model.decoder.layers.3': 0, 'model.decoder.layers.4': 0, 'model.decoder.layers.5': 0, 'model.decoder.layers.6': 0, 'model.decoder.layers.7': 0, 'model.decoder.layers.8': 0, 'model.decoder.layers.9': 0, 'model.decoder.layers.10': 0, 'model.decoder.layers.11': 0, 'model.decoder.layers.12': 0, 'model.decoder.layers.13': 0, 'model.decoder.layers.14': 0, 'model.decoder.layers.15': 0, 'model.decoder.layers.16': 0, 'model.decoder.layers.17': 0, 'model.decoder.layers.18': 1, 'model.decoder.layers.19': 1, 'model.decoder.layers.20': 1,'model.decoder.layers.21': 1, 'model.decoder.layers.22': 1, 'model.decoder.layers.23': 1, 'model.decoder.layers.24': 1, 'model.decoder.layers.25': 1, 'model.decoder.layers.26': 1, 'model.decoder.layers.27': 1, 'model.decoder.layers.28': 1, 'model.decoder.layers.29': 1,'model.decoder.layers.30': 1, 'model.decoder.layers.31': 1, 'model.decoder.layers.32': 1, 'model.decoder.layers.33': 1, 'model.decoder.layers.34': 1, 'model.decoder.layers.35': 1,'model.decoder.layers.36': 1, 'model.decoder.layers.37': 1, 'model.decoder.layers.38': 2,'model.decoder.layers.39': 2, 'model.decoder.layers.40': 2, 'model.decoder.layers.41': 2,'model.decoder.layers.42': 2, 'model.decoder.layers.43': 2, 'model.decoder.layers.44': 2,'model.decoder.layers.45': 2, 'model.decoder.layers.46': 2, 'model.decoder.layers.47': 2,'model.decoder.layers.48': 2, 'model.decoder.layers.49': 2, 'model.decoder.layers.50': 2, 'model.decoder.layers.51': 2, 'model.decoder.layers.52': 2, 'model.decoder.layers.53': 2, 'model.decoder.layers.54': 2, 'model.decoder.layers.55': 2, 'model.decoder.layers.56': 
2,'model.decoder.layers.57': 2, 'model.decoder.layers.58': 3, 'model.decoder.layers.59': 3,'model.decoder.layers.60': 3, 'model.decoder.layers.61': 3, 'model.decoder.layers.62': 3, 'model.decoder.layers.63': 3} -------------------------------------------------------------------------------- /lm_eval/models/__init__.py: -------------------------------------------------------------------------------- 1 | from . import gpt2 2 | from . import gpt3 3 | from . import opt 4 | from . import dummy 5 | 6 | MODEL_REGISTRY = { 7 | "hf": gpt2.HFLM, 8 | "gpt2": gpt2.GPT2LM, 9 | "gpt3": gpt3.GPT3LM, 10 | "opt": opt.OPTLM, 11 | "dummy": dummy.DummyLM, 12 | } 13 | 14 | 15 | def get_model(model_name): 16 | return MODEL_REGISTRY[model_name] 17 | -------------------------------------------------------------------------------- /lm_eval/models/opt.py: -------------------------------------------------------------------------------- 1 | import transformers 2 | import torch 3 | import os 4 | import random 5 | import pickle 6 | from lm_eval.base import BaseLM 7 | # from transformers.deepspeed import HfDeepSpeedConfig 8 | # import deepspeed 9 | 10 | class HFLM(BaseLM): 11 | def __init__( 12 | self, 13 | device="cuda", 14 | pretrained="facebook/opt-125m", 15 | revision="main", 16 | subfolder=None, 17 | tokenizer=None, 18 | batch_size=1, 19 | model_cache_dir = None, 20 | tokenizer_cache_dir = None, 21 | mask_single_head=0, 22 | mask_heads=0, 23 | mask_fc=0, 24 | mask_iterative_fc=0, 25 | head_percent_mask=0, 26 | head_importance_path=None, 27 | fc_percent_mask=0, 28 | fc_importance_path=None, 29 | ): 30 | super().__init__() 31 | 32 | assert isinstance(device, str) 33 | assert isinstance(pretrained, str) 34 | assert isinstance(batch_size, int) 35 | 36 | if device: 37 | if device not in ["cuda", "cpu"]: 38 | device = int(device) 39 | self._device = torch.device(device) 40 | print(f"Using device '{device}'") 41 | else: 42 | print("Device not specified") 43 | print(f"Cuda Available? 
{torch.cuda.is_available()}") 44 | self._device = ( 45 | torch.device("cuda") 46 | if torch.cuda.is_available() 47 | else torch.device("cpu") 48 | ) 49 | 50 | device_map = 'auto' 51 | if '66b' in pretrained: 52 | device_map = eval(open('66b_device_map.txt', 'r').readlines()[0]) 53 | for key in device_map: 54 | if 'layers' in key: 55 | layer_num = int(key.split('.')[-1]) 56 | if layer_num <= 3: 57 | device_map[key] = 0 58 | elif layer_num >= 60: 59 | device_map[key] = 'cpu' 60 | else: 61 | device_map[key] = ((layer_num - 4) // 8) + 1 62 | 63 | self.opt = transformers.AutoModelForCausalLM.from_pretrained( 64 | pretrained, 65 | device_map = device_map, 66 | cache_dir = model_cache_dir, 67 | torch_dtype=torch.float16 68 | ) 69 | 70 | self.opt.get_decoder().embed_tokens.weight.requires_grad = False 71 | self.opt.get_decoder().embed_positions.weight.requires_grad = False 72 | # self.opt = transformers.AutoModelForCausalLM.from_pretrained( 73 | # pretrained, 74 | # cache_dir = model_cache_dir, 75 | # ) 76 | # self.ds_engine = deepspeed.initialize(model=self.opt, model_parameters = self.opt.parameters(), config_params=ds_config)[0] 77 | # self.ds_engine.module.eval() # inference 78 | 79 | self.tokenizer = transformers.AutoTokenizer.from_pretrained( 80 | pretrained, 81 | cache_dir = tokenizer_cache_dir if tokenizer_cache_dir else 'tokenizer_cache/', 82 | use_fast = False 83 | ) if tokenizer is None else tokenizer 84 | 85 | assert isinstance( 86 | self.tokenizer, 87 | ( 88 | transformers.GPT2Tokenizer, 89 | transformers.GPT2TokenizerFast, 90 | transformers.T5Tokenizer, 91 | transformers.T5TokenizerFast, 92 | ), 93 | ), "this tokenizer has not been checked for compatibility yet!" 94 | 95 | self.vocab_size = self.tokenizer.vocab_size 96 | 97 | # multithreading and batching 98 | self.batch_size_per_gpu = batch_size # todo: adaptive batch size 99 | 100 | num_hidden_layers = self.opt.config.num_hidden_layers 101 | num_heads = self.opt.config.num_attention_heads 102 | self.head_mask = torch.ones(num_hidden_layers * num_heads, dtype = torch.half) 103 | self.fc_mask = torch.ones(num_hidden_layers, dtype = torch.half) 104 | 105 | if int(mask_heads): 106 | with open(head_importance_path, 'rb') as f: 107 | importance = pickle.load(f) 108 | _, head_indices = torch.sort(importance.view(-1)) 109 | head_indices = list(head_indices.numpy()) 110 | head_indices = head_indices[: int(head_percent_mask) * len(head_indices) // 100] 111 | self.head_mask[head_indices] = 0. 112 | elif int(mask_single_head): #Only performing it on OPT125M 113 | self.head_mask[int(mask_single_head)-1] = 0. 114 | 115 | self.head_mask = self.head_mask.view(num_hidden_layers, num_heads).contiguous() 116 | 117 | if mask_fc: 118 | self.fc_mask[int(mask_fc)] = 0. 119 | elif int(mask_iterative_fc): 120 | with open(fc_importance_path, 'rb') as f: 121 | fc_indices = list(pickle.load(f)) 122 | fc_indices = fc_indices[: int(fc_percent_mask) * len(fc_indices) // 100] 123 | self.fc_mask[fc_indices] = 0. 
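# Illustrative sketch of the masking arithmetic above, assuming facebook/opt-125m (12 hidden layers x 12 attention heads, so 144 head_mask entries): torch.sort orders head indices from least to most important, so head_percent_mask=10 keeps the 10 * 144 // 100 = 14 least-important indices and zeroes them before the mask is reshaped to [num_hidden_layers, num_heads] and passed as head_mask on every forward call in _model_call. fc_mask has one entry per decoder layer instead: either a single layer is zeroed (mask_fc) or the first fc_percent_mask percent of the layer ranking loaded from fc_importance_path is zeroed (mask_iterative_fc).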
124 | 125 | 126 | @property 127 | def eot_token_id(self): 128 | return self.tokenizer.eos_token_id 129 | 130 | @property 131 | def max_length(self): 132 | return self.opt.config.max_position_embeddings 133 | 134 | @property 135 | def max_gen_toks(self): 136 | return 256 137 | 138 | @property 139 | def batch_size(self): 140 | # TODO: fix multi-gpu 141 | return self.batch_size_per_gpu # * gpus 142 | 143 | @property 144 | def device(self): 145 | # TODO: fix multi-gpu 146 | return self._device 147 | 148 | def get_tokenizer(self): 149 | return self.tokenizer 150 | 151 | def tok_encode(self, string: str): 152 | return self.tokenizer.encode(string, add_special_tokens=False) 153 | 154 | def tok_decode(self, tokens): 155 | return self.tokenizer.decode(tokens) 156 | 157 | def _model_call(self, inps, attn_mask = None, labels = None): 158 | """ 159 | inps: a torch tensor of shape [batch, sequence] 160 | the size of sequence may vary from call to call 161 | 162 | returns: a torch tensor of shape [batch, sequence, vocab] with the 163 | logits returned from the model 164 | """ 165 | 166 | if labels == None: 167 | with torch.no_grad(): 168 | return self.opt(input_ids = inps, head_mask = self.head_mask, fc_mask = self.fc_mask)[0][:, :, :50265] 169 | # rank = int(os.getenv("LOCAL_RANK", "0")) 170 | # return self.ds_engine.module(input_ids = inps, head_mask = self.head_mask.to(rank))[0][:, :, :50265] 171 | else: 172 | return self.opt(input_ids = inps, attention_mask = attn_mask, labels = labels).loss 173 | # return self.ds_engine.module(input_ids = inps, attention_mask = attn_mask, labels = labels).loss 174 | 175 | def _model_generate(self, context, max_length, eos_token_id): 176 | return self.opt.generate( 177 | context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False 178 | ) 179 | 180 | 181 | # for backwards compatibility 182 | OPTLM = HFLM -------------------------------------------------------------------------------- /lm_eval/prefix_matching_copying.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pickle 4 | import argparse 5 | from tqdm import tqdm 6 | import numpy as np 7 | import torch.nn.functional as F 8 | from scripts.plotting.style import * 9 | from transformers import AutoModelForCausalLM, AutoTokenizer 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("--prefix_matching", action = 'store_true') 13 | parser.add_argument("--copying_score", action = 'store_true') 14 | parser.add_argument("--random_model", action = 'store_true') 15 | parser.add_argument("--frequent_exclude_ratio", type=float, default = 0.04) 16 | parser.add_argument("--pretrained", type = str, default = 'facebook/opt-66b') 17 | parser.add_argument("--model_cache_dir", type = str, default = None) 18 | parser.add_argument("--tokenizer_cache_dir", type = str, default = None) 19 | parser.add_argument("--num_seeds", type = int, default = 100) 20 | parser.add_argument("--save_plot_path_mean", type=str, default=None) 21 | parser.add_argument("--save_plot_path_var", type=str, default=None) 22 | parser.add_argument("--save_outputs", type=str, default=None) 23 | parser.add_argument("--use_save_outputs", action = 'store_true') 24 | 25 | args = parser.parse_args() 26 | 27 | if not args.use_save_outputs: 28 | device_map = 'auto' 29 | if not args.random_model: 30 | model = AutoModelForCausalLM.from_pretrained(args.pretrained, cache_dir = args.model_cache_dir, device_map = device_map) 31 | else: 32 | config = 
AutoConfig.from_pretrained(args.pretrained) 33 | model = AutoModelForCausalLM.from_config(config) 34 | model.eval() 35 | tokenizer = AutoTokenizer.from_pretrained(args.pretrained, use_fast = False, cache_dir = args.tokenizer_cache_dir) 36 | ## create a ranking of bpe tokens using tokenizer.bpe_ranks that stores bpe token merges based on frequency in pretraining text 37 | ## BPE tokens are saved as per merging order in the dict bpe_ranks 38 | ## more details about merging at https://huggingface.co/docs/transformers/tokenizer_summary 39 | ranked_dict = dict() 40 | ranked_vocab_size = len(list(tokenizer.bpe_ranks.keys())) 41 | check_all_ranks = [0]*ranked_vocab_size 42 | for merge_tuple,rank in tokenizer.bpe_ranks.items(): 43 | bpe_token = ''.join(merge_tuple) 44 | ranked_dict[rank] = tokenizer.encoder[bpe_token] 45 | check_all_ranks[rank] = 1 46 | assert sum(check_all_ranks) == ranked_vocab_size 47 | ## exclude fraction of frequent bpe tokens from random sequences 48 | frequent_excluded_ranks = int(args.frequent_exclude_ratio * ranked_vocab_size) 49 | ## exclude both most and least frequent tokens 50 | rank_start, rank_end = frequent_excluded_ranks, ranked_vocab_size - frequent_excluded_ranks 51 | assert rank_start < rank_end and rank_end > 0 52 | rank_choice_list = np.arange(rank_start, rank_end) 53 | num_layers = model.config.num_hidden_layers 54 | num_heads = model.config.num_attention_heads 55 | final = [] 56 | 57 | with torch.no_grad(): 58 | for seed in tqdm(range(args.num_seeds)): 59 | torch.manual_seed(seed) 60 | ## ensures final length of the generated sequence is in the range (25,~900) 61 | length = seed * 2 + 25 62 | ## sequence is not repeated for copying score 63 | if args.copying_score: 64 | length = 4 * length 65 | ## choose a random sequence excluding most frequent and least frequent bpe tokens 66 | ## generate tokens without replacement to ensure all chosen tokens are unique 67 | ## uniqueness ensures prefix matching score to only capture explicit repeats ie repeat of the whole sequence 68 | generate_ranks = np.random.choice(rank_choice_list, size=length, replace=False) 69 | ## append a bos_token in the beginning to ensure normal model behaviour 70 | generate_ids = torch.tensor([tokenizer.bos_token_id] + [ranked_dict[rank] for rank in generate_ranks]) 71 | generate_ids = torch.unsqueeze(generate_ids, 0) 72 | if not args.random_model: 73 | generate_ids = generate_ids.to(0) 74 | if args.prefix_matching: 75 | ## repeat the sequence excluding the bos token 76 | new_generated = torch.cat([generate_ids, generate_ids[:,1:].repeat(3, 1).view(-1).unsqueeze(0)], dim = -1) 77 | if not args.random_model: 78 | new_generated = new_generated.to(0) 79 | assert new_generated.shape[1] == 4*length + 1 80 | out = model(input_ids = new_generated) 81 | decoder = model.get_decoder() 82 | attn_matrix = torch.zeros((num_layers, num_heads)) 83 | for layer in range(num_layers): 84 | attn_probs = decoder.layers[layer].self_attn.attn_probs 85 | for head in range(num_heads): 86 | attn_prob = attn_probs[head] 87 | c = 0 88 | for j in range(length+1, 4*length+1): 89 | for num in range(j//length): 90 | attn_matrix[layer][head] += attn_prob[j][(num*length)+(j%length)+1].item() 91 | c += 1 92 | attn_matrix[layer][head] = attn_matrix[layer][head] / c 93 | final.append(attn_matrix.unsqueeze(0)) 94 | 95 | elif args.copying_score: 96 | new_generated = generate_ids 97 | decoder = model.get_decoder() 98 | input_shape = new_generated.size() 99 | input_ids = new_generated.view(-1, input_shape[-1]) 100 | 
past_key_values_length = 0 101 | 102 | inputs_embeds = decoder.embed_tokens(input_ids) 103 | attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.bool, device=inputs_embeds.device) 104 | pos_embeds = decoder.embed_positions(attention_mask, past_key_values_length) 105 | attention_mask = decoder._prepare_decoder_attention_mask( 106 | attention_mask, input_shape, inputs_embeds, past_key_values_length 107 | ) 108 | hidden_states = inputs_embeds + pos_embeds 109 | 110 | copying_matrix = torch.zeros((num_layers, num_heads)) 111 | for layer in tqdm(range(num_layers)): 112 | layer_ = decoder.layers[layer] 113 | hs = layer_.self_attn_layer_norm(hidden_states) 114 | layer_self_attn = layer_.self_attn 115 | attn_probs = layer_self_attn(hidden_states = hs, attention_mask = attention_mask, output_attentions = True)[1].squeeze(0) 116 | value_states = layer_self_attn._shape(layer_self_attn.v_proj(hs), -1, 1).squeeze(0) #n_heads, length, dim_head 117 | h, l, d = value_states.shape 118 | # convert h, l, d_h -> h, l, d_e so that it can be fed to out_proj directly 119 | value_states = [torch.cat([torch.zeros((1,l,i*d), dtype = value_states.dtype, device = value_states.device), value_states[i,:,:].unsqueeze(0), \ 120 | torch.zeros((1,l,(h-i-1)*d), dtype = value_states.dtype, device = value_states.device)], dim = -1) for i in range(len(value_states))] 121 | value_states = torch.cat(value_states, dim = 0) # h, l, d_e 122 | output = layer_self_attn.out_proj(value_states) 123 | 124 | logits = model.lm_head(output).contiguous() # h, l, vocab_size 125 | logits = F.softmax(logits, dim = -1) 126 | 127 | for head in range(num_heads): 128 | attn_prob = attn_probs[head] 129 | _, ind = torch.sort(attn_prob, dim = 1) 130 | max_ind = ind[:, -1] 131 | c = 0 132 | ## iterate the complete random sequence 133 | for j in range(1, length + 1): 134 | c += 1 135 | assert (max_ind[j] <= j) 136 | ## tokens that can be attended to in the current time step ie 0 to j 137 | attendable_input = input_ids[0][:(j+1)] 138 | ## logits of attendable tokens 139 | attendable_logits = logits[head][j][attendable_input] 140 | ## mean of the logits 141 | mean_of_logits = attendable_logits.mean() 142 | ## raise logits 143 | raised_logits = attendable_logits - mean_of_logits 144 | ## relu over raised logits 145 | relu_raised_logits = torch.nn.functional.relu(raised_logits) 146 | relu_raised_logit_max_ind = relu_raised_logits[max_ind[j]].item() 147 | relu_raised_logit_all = relu_raised_logits.sum().item() 148 | ## ratio of raised logit 149 | copying_score = 0 150 | ## edgecase: if all logits are of equal value then relu_raised_logit_all can be 0 151 | if relu_raised_logit_all != 0: 152 | copying_score = relu_raised_logit_max_ind / relu_raised_logit_all 153 | copying_matrix[layer][head] += copying_score 154 | copying_matrix[layer][head] = copying_matrix[layer][head] / c 155 | final.append(copying_matrix.unsqueeze(0)) 156 | else: 157 | raise RuntimeError("Neither prefix matching nor copying score selected") 158 | final = torch.cat(final, dim = 0) 159 | mean = final.mean(dim = 0) 160 | variance = final.var(dim = 0) 161 | os.makedirs(os.path.dirname(args.save_outputs), exist_ok = True) 162 | 163 | with open(args.save_outputs, 'wb') as f: 164 | pickle.dump({'mean': mean, 'variance': variance}, f) 165 | 166 | if args.use_save_outputs: 167 | with open(args.save_outputs, 'rb') as f: 168 | res = pickle.load(f) 169 | mean, variance = res['mean'], res['variance'] 170 | num_layers, num_heads = mean.shape 171 | 172 | 173 | max_, min_ = mean.max(), 
mean.min() 174 | print(max_, min_) 175 | print(mean.shape) 176 | ## changed the range for best visualization of copying score 177 | ax = sns.heatmap(mean.numpy(), xticklabels = [(i+1) if i%2==0 else None for i in range(num_heads)], yticklabels = [(i+1) if i%2==0 else None for i in range(num_layers)], vmin = min_, vmax = max_) 178 | plt.ylabel('Layers') 179 | plt.xlabel('Heads') 180 | plt.title('Prefix Matching Score' if args.prefix_matching else 'Copying Score') 181 | ax.invert_yaxis() 182 | 183 | os.makedirs(os.path.dirname(args.save_plot_path_mean), exist_ok = True) 184 | plt.savefig(args.save_plot_path_mean) 185 | plt.savefig(args.save_plot_path_mean[:-4]+'.pdf') 186 | 187 | plt.close() 188 | 189 | max_, min_ = variance.max(), variance.min() 190 | print(max_, min_) 191 | ax = sns.heatmap(variance.numpy(), xticklabels = [(i+1) if i%2==0 else None for i in range(num_heads)], yticklabels = [(i+1) if i%2==0 else None for i in range(num_layers)], vmin = min_, vmax = max_) 192 | plt.ylabel('Layers') 193 | plt.xlabel('Heads') 194 | plt.title('Prefix Matching Score' if args.prefix_matching else 'Copying Score') 195 | ax.invert_yaxis() 196 | 197 | os.makedirs(os.path.dirname(args.save_plot_path_var), exist_ok = True) 198 | plt.savefig(args.save_plot_path_var) 199 | plt.savefig(args.save_plot_path_var[:-4]+'.pdf') 200 | -------------------------------------------------------------------------------- /lm_eval/tasks/arc.py: -------------------------------------------------------------------------------- 1 | """ 2 | Think you have Solved Question Answering? Try ARC, the AI2 Reasoning Challenge 3 | https://arxiv.org/pdf/1803.05457.pdf 4 | 5 | The ARC dataset consists of 7,787 science exam questions drawn from a variety 6 | of sources, including science questions provided under license by a research 7 | partner affiliated with AI2. These are text-only, English language exam questions 8 | that span several grade levels as indicated in the files. Each question has a 9 | multiple choice structure (typically 4 answer options). The questions are sorted 10 | into a Challenge Set of 2,590 “hard” questions (those that both a retrieval and 11 | a co-occurrence method fail to answer correctly) and an Easy Set of 5,197 questions. 12 | 13 | Homepage: https://allenai.org/data/arc 14 | """ 15 | from lm_eval.base import MultipleChoiceTask 16 | from lm_eval.utils import create_dataloader 17 | 18 | _CITATION = """ 19 | @article{Clark2018ThinkYH, 20 | title={Think you have Solved Question Answering? 
Try ARC, the AI2 Reasoning Challenge}, 21 | author={Peter Clark and Isaac Cowhey and Oren Etzioni and Tushar Khot and Ashish Sabharwal and Carissa Schoenick and Oyvind Tafjord}, 22 | journal={ArXiv}, 23 | year={2018}, 24 | volume={abs/1803.05457} 25 | } 26 | """ 27 | 28 | 29 | class ARCEasy(MultipleChoiceTask): 30 | VERSION = 0 31 | DATASET_PATH = "ai2_arc" 32 | DATASET_NAME = "ARC-Easy" 33 | 34 | def has_training_docs(self): 35 | return True 36 | 37 | def has_validation_docs(self): 38 | return True 39 | 40 | def has_test_docs(self): 41 | return True 42 | 43 | def training_docs(self): 44 | if self._training_docs is None: 45 | self._training_docs = list(map(self._process_doc, self.dataset["train"])) 46 | return self._training_docs 47 | 48 | def validation_docs(self): 49 | return map(self._process_doc, self.dataset["validation"]) 50 | 51 | def test_docs(self): 52 | return map(self._process_doc, self.dataset["test"]) 53 | 54 | def _process_doc(self, doc): 55 | # NOTE: Some `doc["answerKey"]`s are in numeric string format being one 56 | # of {'1', '2', '3', '4', '5'}. We map them back to letters. 57 | num_to_letter = {"1": "A", "2": "B", "3": "C", "4": "D", "5": "E"} 58 | doc["answerKey"] = num_to_letter.get(doc["answerKey"], doc["answerKey"]) 59 | out_doc = { 60 | "id": doc["id"], 61 | "query": "Question: " + doc["question"] + "\nAnswer:", 62 | "choices": doc["choices"]["text"], 63 | "gold": ["A", "B", "C", "D", "E"].index(doc["answerKey"]), 64 | } 65 | return out_doc 66 | 67 | def get_dataloader(self, tokenizer, split = 'train', subset_size = None, batch_size = 1, num_fewshot = 0): 68 | docs = self.training_docs() if split == 'train' else self.validation_docs() if split == 'validation' else self.test_docs() 69 | return create_dataloader(tokenizer, docs, self.fewshot_context, self.doc_to_cont, subset_size = subset_size, batch_size = batch_size, num_fewshot = num_fewshot) 70 | 71 | def doc_to_cont(self, doc): 72 | return doc['choices'][doc['gold']] 73 | 74 | def doc_to_text(self, doc): 75 | return doc["query"] 76 | 77 | def should_decontaminate(self): 78 | return True 79 | 80 | def doc_to_decontamination_query(self, doc): 81 | return doc["query"] 82 | 83 | 84 | class ARCChallenge(ARCEasy): 85 | DATASET_PATH = "ai2_arc" 86 | DATASET_NAME = "ARC-Challenge" 87 | -------------------------------------------------------------------------------- /lm_eval/tasks/glue.py: -------------------------------------------------------------------------------- 1 | """ 2 | GLUE: A Multi-Task Benchmark and Analysis Platform for Natural Language Understanding 3 | https://openreview.net/pdf?id=rJ4km2R5t7 4 | 5 | The General Language Understanding Evaluation (GLUE) benchmark is a collection of 6 | resources for training, evaluating, and analyzing natural language understanding 7 | systems. GLUE consists of: 8 | - A benchmark of nine sentence- or sentence-pair language understanding tasks built 9 | on established existing datasets and selected to cover a diverse range of dataset 10 | sizes, text genres, and degrees of difficulty, and 11 | - A diagnostic dataset designed to evaluate and analyze model performance with 12 | respect to a wide range of linguistic phenomena found in natural language. 
13 | 14 | Homepage: https://gluebenchmark.com/ 15 | """ 16 | import numpy as np 17 | from lm_eval.base import rf, Task 18 | from lm_eval.metrics import mean, matthews_corrcoef, f1_score, yesno 19 | from lm_eval.utils import general_detokenize 20 | from lm_eval.utils import create_dataloader 21 | 22 | # TODO(jon-tow): Add citations for the individual datasets/tasks that make up GLUE. 23 | _CITATION = """ 24 | @inproceedings{wang-etal-2018-glue, 25 | title = "{GLUE}: A Multi-Task Benchmark and Analysis Platform for Natural Language Understanding", 26 | author = "Wang, Alex and 27 | Singh, Amanpreet and 28 | Michael, Julian and 29 | Hill, Felix and 30 | Levy, Omer and 31 | Bowman, Samuel", 32 | booktitle = "Proceedings of the 2018 {EMNLP} Workshop {B}lackbox{NLP}: Analyzing and Interpreting Neural Networks for {NLP}", 33 | month = nov, 34 | year = "2018", 35 | address = "Brussels, Belgium", 36 | publisher = "Association for Computational Linguistics", 37 | url = "https://aclanthology.org/W18-5446", 38 | doi = "10.18653/v1/W18-5446", 39 | pages = "353--355", 40 | abstract = "Human ability to understand language is \textit{general, flexible, and robust}. In contrast, most NLU models above the word level are designed for a specific task and struggle with out-of-domain data. If we aspire to develop models with understanding beyond the detection of superficial correspondences between inputs and outputs, then it is critical to develop a unified model that can execute a range of linguistic tasks across different domains. To facilitate research in this direction, we present the General Language Understanding Evaluation (GLUE, gluebenchmark.com): a benchmark of nine diverse NLU tasks, an auxiliary dataset for probing models for understanding of specific linguistic phenomena, and an online platform for evaluating and comparing models. For some benchmark tasks, training data is plentiful, but for others it is limited or does not match the genre of the test set. GLUE thus favors models that can represent linguistic knowledge in a way that facilitates sample-efficient learning and effective knowledge-transfer across tasks. While none of the datasets in GLUE were created from scratch for the benchmark, four of them feature privately-held test data, which is used to ensure that the benchmark is used fairly. We evaluate baselines that use ELMo (Peters et al., 2018), a powerful transfer learning technique, as well as state-of-the-art sentence representation models. The best models still achieve fairly low absolute scores. 
Analysis with our diagnostic dataset yields similarly weak performance over all phenomena tested, with some exceptions.", 41 | } 42 | """ 43 | 44 | 45 | # Single-Sentence Tasks 46 | 47 | 48 | class CoLA(Task): 49 | VERSION = 0 50 | DATASET_PATH = "glue" 51 | DATASET_NAME = "cola" 52 | 53 | def has_training_docs(self): 54 | return True 55 | 56 | def has_validation_docs(self): 57 | return True 58 | 59 | def has_test_docs(self): 60 | return False 61 | 62 | def training_docs(self): 63 | if self._training_docs is None: 64 | self._training_docs = list(self.dataset["train"]) 65 | return self._training_docs 66 | 67 | def validation_docs(self): 68 | return self.dataset["validation"] 69 | 70 | def doc_to_text(self, doc): 71 | return "{}\nQuestion: Does this sentence make sense?\nAnswer:".format( 72 | doc["sentence"] 73 | ) 74 | 75 | def get_dataloader(self, tokenizer, split = 'train', subset_size = None, batch_size = 1, num_fewshot = 0): 76 | docs = self.training_docs() if split == 'train' else self.validation_docs() if split == 'validation' else RuntimeError(f'Data not available for {split} split.') 77 | return create_dataloader(tokenizer, docs, self.fewshot_context, self.doc_to_target, subset_size = subset_size, batch_size = batch_size, num_fewshot = num_fewshot) 78 | 79 | def should_decontaminate(self): 80 | return True 81 | 82 | def doc_to_decontamination_query(self, doc): 83 | return doc["sentence"] 84 | 85 | def doc_to_target(self, doc): 86 | return " {}".format({1: "yes", 0: "no"}[doc["label"]]) 87 | 88 | def construct_requests(self, doc, ctx): 89 | ll_true, _ = rf.loglikelihood(ctx, " yes") 90 | ll_false, _ = rf.loglikelihood(ctx, " no") 91 | return ll_true, ll_false 92 | 93 | def process_results(self, doc, results): 94 | ll_true, ll_false = results 95 | pred = ll_true > ll_false 96 | gold = doc["label"] 97 | return {"mcc": (gold, pred)} 98 | 99 | def higher_is_better(self): 100 | return {"mcc": True} 101 | 102 | def aggregation(self): 103 | return {"mcc": matthews_corrcoef} 104 | 105 | 106 | class SST(Task): 107 | VERSION = 0 108 | DATASET_PATH = "glue" 109 | DATASET_NAME = "sst2" 110 | 111 | def has_training_docs(self): 112 | return True 113 | 114 | def has_validation_docs(self): 115 | return True 116 | 117 | def has_test_docs(self): 118 | return False 119 | 120 | def training_docs(self): 121 | if self._training_docs is None: 122 | self._training_docs = list(self.dataset["train"]) 123 | return self._training_docs 124 | 125 | def validation_docs(self): 126 | return self.dataset["validation"] 127 | 128 | def doc_to_text(self, doc): 129 | return "{}\nQuestion: Is this sentence positive or negative?\nAnswer:".format( 130 | general_detokenize(doc["sentence"]), 131 | ) 132 | 133 | def doc_to_target(self, doc): 134 | return " {}".format({1: "positive", 0: "negative"}[doc["label"]]) 135 | 136 | def construct_requests(self, doc, ctx): 137 | ll_positive, _ = rf.loglikelihood(ctx, " positive") 138 | ll_negative, _ = rf.loglikelihood(ctx, " negative") 139 | return ll_positive, ll_negative 140 | 141 | def process_results(self, doc, results): 142 | ll_positive, ll_negative = results 143 | pred = ll_positive > ll_negative 144 | gold = doc["label"] 145 | return {"acc": pred == gold} 146 | 147 | def higher_is_better(self): 148 | return {"acc": True} 149 | 150 | def aggregation(self): 151 | return {"acc": mean} 152 | 153 | 154 | # Inference Tasks 155 | 156 | 157 | class MNLI(Task): 158 | VERSION = 0 159 | DATASET_PATH = "glue" 160 | DATASET_NAME = "mnli" 161 | 162 | def has_training_docs(self): 163 | return 
True 164 | 165 | def has_validation_docs(self): 166 | return True 167 | 168 | def has_test_docs(self): 169 | return False 170 | 171 | def training_docs(self): 172 | if self._training_docs is None: 173 | self._training_docs = list(self.dataset["train"]) 174 | return self._training_docs 175 | 176 | def validation_docs(self): 177 | if self.has_validation_docs(): 178 | return self.dataset["validation_matched"] 179 | 180 | def test_docs(self): 181 | if self.has_test_docs(): 182 | return self.dataset["test_matched"] 183 | 184 | def get_dataloader(self, tokenizer, split = 'train', subset_size = None, batch_size = 1, num_fewshot = 0): 185 | docs = self.training_docs() if split == 'train' else self.validation_docs() if split == 'validation' else RuntimeError(f'Data not available for {split} split.') 186 | return create_dataloader(tokenizer, docs, self.fewshot_context, self.doc_to_target, subset_size = subset_size, batch_size = batch_size, num_fewshot = num_fewshot) 187 | 188 | def doc_to_text(self, doc): 189 | return "{}\nQuestion: {} True, False or Neither?\nAnswer:".format( 190 | doc["premise"], 191 | doc["hypothesis"].strip() 192 | + ("" if doc["hypothesis"].strip().endswith(".") else "."), 193 | ) 194 | 195 | def doc_to_target(self, doc): 196 | # True = entailment 197 | # False = contradiction 198 | # Neither = neutral 199 | return " {}".format({0: "True", 1: "Neither", 2: "False"}[doc["label"]]) 200 | 201 | def construct_requests(self, doc, ctx): 202 | ll_true, _ = rf.loglikelihood(ctx, " True") 203 | ll_neither, _ = rf.loglikelihood(ctx, " Neither") 204 | ll_false, _ = rf.loglikelihood(ctx, " False") 205 | return ll_true, ll_neither, ll_false 206 | 207 | def process_results(self, doc, results): 208 | gold = doc["label"] 209 | pred = np.argmax(results) 210 | return {"acc": pred == gold} 211 | 212 | def higher_is_better(self): 213 | return {"acc": True} 214 | 215 | def aggregation(self): 216 | return {"acc": mean} 217 | 218 | 219 | class MNLIMismatched(MNLI): 220 | VERSION = 0 221 | 222 | def validation_docs(self): 223 | if self.has_validation_docs(): 224 | return self.dataset["validation_mismatched"] 225 | 226 | def test_docs(self): 227 | if self.has_test_docs(): 228 | return self.dataset["test_mismatched"] 229 | 230 | 231 | class QNLI(Task): 232 | VERSION = 0 233 | DATASET_PATH = "glue" 234 | DATASET_NAME = "qnli" 235 | 236 | def has_training_docs(self): 237 | return True 238 | 239 | def has_validation_docs(self): 240 | return True 241 | 242 | def has_test_docs(self): 243 | return False 244 | 245 | def training_docs(self): 246 | if self._training_docs is None: 247 | self._training_docs = list(self.dataset["train"]) 248 | return self._training_docs 249 | 250 | def validation_docs(self): 251 | return self.dataset["validation"] 252 | 253 | def doc_to_text(self, doc): 254 | return ( 255 | "{}\n{}\nQuestion: Does this response answer the question?\nAnswer:".format( 256 | doc["question"], 257 | doc["sentence"], 258 | ) 259 | ) 260 | 261 | def doc_to_target(self, doc): 262 | # True = entailment 263 | # False = not entailment 264 | return " {}".format({0: "yes", 1: "no"}[doc["label"]]) 265 | 266 | def get_dataloader(self, tokenizer, split = 'train', subset_size = None, batch_size = 1, num_fewshot = 0): 267 | docs = self.training_docs() if split == 'train' else self.validation_docs() if split == 'validation' else RuntimeError(f'Data not available for {split} split.') 268 | return create_dataloader(tokenizer, docs, self.fewshot_context, self.doc_to_target, subset_size = subset_size, batch_size = 
batch_size, num_fewshot = num_fewshot) 269 | 270 | def construct_requests(self, doc, ctx): 271 | ll_yes, _ = rf.loglikelihood(ctx, " yes") 272 | ll_no, _ = rf.loglikelihood(ctx, " no") 273 | return ll_yes, ll_no 274 | 275 | def process_results(self, doc, results): 276 | ll_yes, ll_no = results 277 | pred = ll_no > ll_yes 278 | gold = doc["label"] 279 | return {"acc": pred == gold} 280 | 281 | def higher_is_better(self): 282 | return {"acc": True} 283 | 284 | def aggregation(self): 285 | return {"acc": mean} 286 | 287 | 288 | class WNLI(Task): 289 | VERSION = 1 290 | DATASET_PATH = "glue" 291 | DATASET_NAME = "wnli" 292 | 293 | def has_training_docs(self): 294 | return True 295 | 296 | def has_validation_docs(self): 297 | return True 298 | 299 | def has_test_docs(self): 300 | return False 301 | 302 | def training_docs(self): 303 | if self._training_docs is None: 304 | self._training_docs = list(self.dataset["train"]) 305 | return self._training_docs 306 | 307 | def validation_docs(self): 308 | return self.dataset["validation"] 309 | 310 | def doc_to_text(self, doc): 311 | return "{}\nQuestion: {} True or False?\nAnswer:".format( 312 | doc["sentence1"], 313 | doc["sentence2"], 314 | ) 315 | 316 | def get_dataloader(self, tokenizer, split = 'train', subset_size = None, batch_size = 1, num_fewshot = 0): 317 | docs = self.training_docs() if split == 'train' else self.validation_docs() if split == 'validation' else RuntimeError(f'Data not available for {split} split.') 318 | return create_dataloader(tokenizer, docs, self.fewshot_context, self.doc_to_target, subset_size = subset_size, batch_size = batch_size, num_fewshot = num_fewshot) 319 | 320 | def doc_to_target(self, doc): 321 | # True = entailment 322 | # False = not_entailment 323 | return " {}".format({0: "False", 1: "True"}[doc["label"]]) 324 | 325 | def construct_requests(self, doc, ctx): 326 | ll_true, _ = rf.loglikelihood(ctx, " True") 327 | ll_false, _ = rf.loglikelihood(ctx, " False") 328 | return ll_true, ll_false 329 | 330 | def process_results(self, doc, results): 331 | ll_true, ll_false = results 332 | pred = ll_true > ll_false 333 | gold = doc["label"] 334 | return {"acc": pred == gold} 335 | 336 | def higher_is_better(self): 337 | return {"acc": True} 338 | 339 | def aggregation(self): 340 | return {"acc": mean} 341 | 342 | 343 | class RTE(Task): 344 | VERSION = 0 345 | DATASET_PATH = "glue" 346 | DATASET_NAME = "rte" 347 | 348 | def has_training_docs(self): 349 | return True 350 | 351 | def has_validation_docs(self): 352 | return True 353 | 354 | def has_test_docs(self): 355 | return False 356 | 357 | def training_docs(self): 358 | if self._training_docs is None: 359 | self._training_docs = list(self.dataset["train"]) 360 | return self._training_docs 361 | 362 | def validation_docs(self): 363 | return self.dataset["validation"] 364 | 365 | def doc_to_text(self, doc): 366 | return "{}\nQuestion: {} True or False?\nAnswer:".format( 367 | doc["sentence1"], 368 | doc["sentence2"], 369 | ) 370 | 371 | def doc_to_target(self, doc): 372 | # 0 = entailment 373 | # 1 = not_entailment 374 | return " {}".format({0: "True", 1: "False"}[doc["label"]]) 375 | 376 | def get_dataloader(self, tokenizer, split = 'train', subset_size = None, batch_size = 1, num_fewshot = 0): 377 | docs = self.training_docs() if split == 'train' else self.validation_docs() if split == 'validation' else RuntimeError(f'Data not available for {split} split.') 378 | return create_dataloader(tokenizer, docs, self.fewshot_context, self.doc_to_target, subset_size = 
subset_size, batch_size = batch_size, num_fewshot = num_fewshot) 379 | 380 | def construct_requests(self, doc, ctx): 381 | ll_true, _ = rf.loglikelihood(ctx, " True") 382 | ll_false, _ = rf.loglikelihood(ctx, " False") 383 | return ll_true, ll_false 384 | 385 | def process_results(self, doc, results): 386 | ll_true, ll_false = results 387 | pred = ll_false > ll_true 388 | gold = doc["label"] 389 | return {"acc": pred == gold} 390 | 391 | def higher_is_better(self): 392 | return {"acc": True} 393 | 394 | def aggregation(self): 395 | return {"acc": mean} 396 | 397 | 398 | # Similarity and Paraphrase Tasks 399 | 400 | 401 | class MRPC(Task): 402 | VERSION = 0 403 | DATASET_PATH = "glue" 404 | DATASET_NAME = "mrpc" 405 | 406 | def has_training_docs(self): 407 | return True 408 | 409 | def has_validation_docs(self): 410 | return True 411 | 412 | def has_test_docs(self): 413 | return False 414 | 415 | def training_docs(self): 416 | if self._training_docs is None: 417 | self._training_docs = list(self.dataset["train"]) 418 | return self._training_docs 419 | 420 | def validation_docs(self): 421 | return self.dataset["validation"] 422 | 423 | def get_dataloader(self, tokenizer, split = 'train', subset_size = None, batch_size = 1, num_fewshot = 0): 424 | docs = self.training_docs() if split == 'train' else self.validation_docs() if split == 'validation' else RuntimeError(f'Data not available for {split} split.') 425 | return create_dataloader(tokenizer, docs, self.fewshot_context, self.doc_to_target, subset_size = subset_size, batch_size = batch_size, num_fewshot = num_fewshot) 426 | 427 | def doc_to_text(self, doc): 428 | return "Sentence 1: {}\nSentence 2: {}\nQuestion: Do both sentences mean the same thing?\nAnswer:".format( 429 | general_detokenize(doc["sentence1"]), 430 | general_detokenize(doc["sentence2"]), 431 | ) 432 | 433 | def doc_to_target(self, doc): 434 | return " {}".format(yesno(doc["label"])) 435 | 436 | def construct_requests(self, doc, ctx): 437 | ll_yes, _ = rf.loglikelihood(ctx, " yes") 438 | ll_no, _ = rf.loglikelihood(ctx, " no") 439 | return ll_yes, ll_no 440 | 441 | def process_results(self, doc, results): 442 | ll_yes, ll_no = results 443 | gold = doc["label"] 444 | pred = ll_yes > ll_no 445 | return { 446 | "acc": pred == gold, 447 | "f1": (gold, pred), 448 | } 449 | 450 | def higher_is_better(self): 451 | return {"acc": True, "f1": True} 452 | 453 | def aggregation(self): 454 | return {"acc": mean, "f1": f1_score} 455 | 456 | 457 | class QQP(Task): 458 | VERSION = 0 459 | DATASET_PATH = "glue" 460 | DATASET_NAME = "qqp" 461 | 462 | def has_training_docs(self): 463 | return True 464 | 465 | def has_validation_docs(self): 466 | return True 467 | 468 | def has_test_docs(self): 469 | return False 470 | 471 | def training_docs(self): 472 | if self._training_docs is None: 473 | self._training_docs = list(self.dataset["train"]) 474 | return self._training_docs 475 | 476 | def validation_docs(self): 477 | return self.dataset["validation"] 478 | 479 | def doc_to_text(self, doc): 480 | return "Question 1: {}\nQuestion 2: {}\nQuestion: Do both questions ask the same thing?\nAnswer:".format( 481 | doc["question1"], 482 | doc["question2"], 483 | ) 484 | 485 | def doc_to_target(self, doc): 486 | return " {}".format(yesno(doc["label"])) 487 | 488 | def get_dataloader(self, tokenizer, split = 'train', subset_size = None, batch_size = 1, num_fewshot = 0): 489 | docs = self.training_docs() if split == 'train' else self.validation_docs() if split == 'validation' else RuntimeError(f'Data not 
available for {split} split.') 490 | return create_dataloader(tokenizer, docs, self.fewshot_context, self.doc_to_target, subset_size = subset_size, batch_size = batch_size, num_fewshot = num_fewshot) 491 | 492 | def construct_requests(self, doc, ctx): 493 | ll_yes, _ = rf.loglikelihood(ctx, " yes") 494 | ll_no, _ = rf.loglikelihood(ctx, " no") 495 | return ll_yes, ll_no 496 | 497 | def process_results(self, doc, results): 498 | ll_yes, ll_no = results 499 | gold = doc["label"] 500 | pred = ll_yes > ll_no 501 | return { 502 | "acc": pred == gold, 503 | "f1": (gold, pred), 504 | } 505 | 506 | def higher_is_better(self): 507 | return {"acc": True, "f1": True} 508 | 509 | def aggregation(self): 510 | return {"acc": mean, "f1": f1_score} 511 | 512 | 513 | class STSB(Task): 514 | VERSION = 0 515 | DATASET_PATH = "glue" 516 | DATASET_NAME = "stsb" 517 | 518 | def has_training_docs(self): 519 | return True 520 | 521 | def has_validation_docs(self): 522 | return True 523 | 524 | def has_test_docs(self): 525 | return True 526 | 527 | def training_docs(self): 528 | if self._training_docs is None: 529 | self._training_docs = list(self.dataset["train"]) 530 | return self._training_docs 531 | 532 | def validation_docs(self): 533 | return self.dataset["validation"] 534 | 535 | def test_docs(self): 536 | return self.dataset["test"] 537 | 538 | def doc_to_text(self, doc): 539 | return "sentence 1: {}\nsentence 2: {}\nAnswer:".format( 540 | doc["sentence1"], 541 | doc["sentence2"], 542 | ) 543 | 544 | def doc_to_target(self, doc): 545 | return " {}".format(doc["label"]) 546 | 547 | def construct_requests(self, doc, ctx): 548 | """Uses RequestFactory to construct Requests and returns an iterable of 549 | Requests which will be sent to the LM. 550 | 551 | :param doc: 552 | The document as returned from training_docs, validation_docs, or test_docs. 553 | :param ctx: str 554 | The context string, generated by fewshot_context. This includes the natural 555 | language description, as well as the few shot examples, and the question 556 | part of the document for `doc`. 557 | """ 558 | # TODO: implement evaluation. 559 | raise NotImplementedError("Evaluation not implemented") 560 | 561 | def process_results(self, doc, results): 562 | """Take a single document and the LM results and evaluates, returning a 563 | dict where keys are the names of submetrics and values are the values of 564 | the metric for that one document 565 | 566 | :param doc: 567 | The document as returned from training_docs, validation_docs, or test_docs. 568 | :param results: 569 | The results of the requests created in construct_requests. 570 | """ 571 | # TODO: implement evaluation. 572 | raise NotImplementedError("Evaluation not implemented") 573 | 574 | def aggregation(self): 575 | """ 576 | :returns: {str: [float] -> float} 577 | A dictionary where keys are the names of submetrics and values are 578 | functions that aggregate a list of metrics 579 | """ 580 | # TODO: implement evaluation. 581 | raise NotImplementedError("Evaluation not implemented") 582 | 583 | def higher_is_better(self): 584 | """ 585 | :returns: {str: bool} 586 | A dictionary where keys are the names of submetrics and values are 587 | whether a higher value of the submetric is better 588 | """ 589 | # TODO: implement evaluation. 
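# (In GLUE, STS-B is a regression task scored with Pearson/Spearman correlation
# over a 1-5 similarity scale, so both submetrics would be higher-is-better once
# evaluation is implemented.)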
590 | raise NotImplementedError("Evaluation not implemented") 591 | -------------------------------------------------------------------------------- /lm_eval/tasks/hellaswag.py: -------------------------------------------------------------------------------- 1 | """ 2 | HellaSwag: Can a Machine Really Finish Your Sentence? 3 | https://arxiv.org/pdf/1905.07830.pdf 4 | 5 | Hellaswag is a commonsense inference challenge dataset. Though its questions are 6 | trivial for humans (>95% accuracy), state-of-the-art models struggle (<48%). This is 7 | achieved via Adversarial Filtering (AF), a data collection paradigm wherein a 8 | series of discriminators iteratively select an adversarial set of machine-generated 9 | wrong answers. AF proves to be surprisingly robust. The key insight is to scale up 10 | the length and complexity of the dataset examples towards a critical 'Goldilocks' 11 | zone wherein generated text is ridiculous to humans, yet often misclassified by 12 | state-of-the-art models. 13 | 14 | Homepage: https://rowanzellers.com/hellaswag/ 15 | """ 16 | import re 17 | import torch 18 | from lm_eval.base import MultipleChoiceTask 19 | from lm_eval.utils import create_dataloader 20 | 21 | _CITATION = """ 22 | @inproceedings{zellers2019hellaswag, 23 | title={HellaSwag: Can a Machine Really Finish Your Sentence?}, 24 | author={Zellers, Rowan and Holtzman, Ari and Bisk, Yonatan and Farhadi, Ali and Choi, Yejin}, 25 | booktitle ={Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics}, 26 | year={2019} 27 | } 28 | """ 29 | 30 | 31 | class HellaSwag(MultipleChoiceTask): 32 | VERSION = 0 33 | DATASET_PATH = "hellaswag" 34 | DATASET_NAME = None 35 | 36 | def has_training_docs(self): 37 | return True 38 | 39 | def has_validation_docs(self): 40 | return True 41 | 42 | def has_test_docs(self): 43 | return False 44 | 45 | def training_docs(self): 46 | if self._training_docs is None: 47 | self._training_docs = list(map(self._process_doc, self.dataset["train"])) 48 | return self._training_docs 49 | 50 | def validation_docs(self): 51 | return map(self._process_doc, self.dataset["validation"]) 52 | 53 | def _process_doc(self, doc): 54 | ctx = doc["ctx_a"] + " " + doc["ctx_b"].capitalize() 55 | out_doc = { 56 | "query": self.preprocess(doc["activity_label"] + ": " + ctx), 57 | "choices": [self.preprocess(ending) for ending in doc["endings"]], 58 | "gold": int(doc["label"]), 59 | } 60 | return out_doc 61 | 62 | def get_dataloader(self, tokenizer, split = 'train', subset_size = None, batch_size = 1, num_fewshot = 0): 63 | docs = self.training_docs() if split == 'train' else self.validation_docs() if split == 'validation' else RuntimeError(f'Data not available for {split} split.') 64 | return create_dataloader(tokenizer, docs, self.fewshot_context, self.doc_to_cont, subset_size = subset_size, batch_size = batch_size, num_fewshot = num_fewshot) 65 | 66 | def doc_to_cont(self, doc): 67 | return doc['choices'][doc['gold']] 68 | 69 | @classmethod 70 | def preprocess(cls, text): 71 | text = text.strip() 72 | # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag. 73 | text = text.replace(" [title]", ". 
") 74 | text = re.sub("\\[.*?\\]", "", text) 75 | text = text.replace(" ", " ") 76 | return text 77 | 78 | def doc_to_text(self, doc): 79 | return doc["query"] 80 | 81 | def should_decontaminate(self): 82 | return True 83 | 84 | def doc_to_decontamination_query(self, doc): 85 | return doc["query"] 86 | -------------------------------------------------------------------------------- /lm_eval/tasks/lambada.py: -------------------------------------------------------------------------------- 1 | """ 2 | The LAMBADA dataset: Word prediction requiring a broad discourse context∗ 3 | https://arxiv.org/pdf/1606.06031.pdf 4 | 5 | LAMBADA is a dataset to evaluate the capabilities of computational models for text 6 | understanding by means of a word prediction task. LAMBADA is a collection of narrative 7 | passages sharing the characteristic that human subjects are able to guess their last 8 | word if they are exposed to the whole passage, but not if they only see the last 9 | sentence preceding the target word. To succeed on LAMBADA, computational models 10 | cannot simply rely on local context, but must be able to keep track of information 11 | in the broader discourse. 12 | 13 | Homepage: https://zenodo.org/record/2630551#.X4Xzn5NKjUI 14 | """ 15 | import inspect 16 | import lm_eval.datasets.lambada.lambada 17 | from lm_eval.base import Task, rf 18 | from lm_eval.metrics import mean, perplexity 19 | from lm_eval.utils import create_dataloader 20 | 21 | _CITATION = """ 22 | @misc{ 23 | author={Paperno, Denis and Kruszewski, Germán and Lazaridou, Angeliki and Pham, Quan Ngoc and Bernardi, Raffaella and Pezzelle, Sandro and Baroni, Marco and Boleda, Gemma and Fernández, Raquel}, 24 | title={The LAMBADA dataset}, 25 | DOI={10.5281/zenodo.2630551}, 26 | publisher={Zenodo}, 27 | year={2016}, 28 | month={Aug} 29 | } 30 | """ 31 | 32 | 33 | class LAMBADA(Task): 34 | VERSION = 0 35 | DATASET_PATH = inspect.getfile(lm_eval.datasets.lambada.lambada) 36 | 37 | def has_training_docs(self): 38 | return False 39 | 40 | def has_validation_docs(self): 41 | return True 42 | 43 | def has_test_docs(self): 44 | return False 45 | 46 | def training_docs(self): 47 | pass 48 | 49 | def validation_docs(self): 50 | return self.dataset["validation"] 51 | 52 | def test_docs(self): 53 | pass 54 | 55 | def doc_to_text(self, doc): 56 | return doc["text"].rsplit(" ", 1)[0] 57 | 58 | def should_decontaminate(self): 59 | return True 60 | 61 | def doc_to_decontamination_query(self, doc): 62 | return doc["text"] 63 | 64 | def doc_to_target(self, doc): 65 | return " " + doc["text"].rsplit(" ", 1)[1] 66 | 67 | def construct_requests(self, doc, ctx): 68 | ll, is_greedy = rf.loglikelihood(ctx, self.doc_to_target(doc)) 69 | 70 | return ll, is_greedy 71 | 72 | def get_dataloader(self, tokenizer, split = 'validation', subset_size = None, batch_size = 1, num_fewshot = 0): 73 | docs = self.validation_docs() #if split == 'validation' else RuntimeError(f'Data not available for {split} split.') 74 | return create_dataloader(tokenizer, docs, self.fewshot_context, self.doc_to_target, subset_size = subset_size, batch_size = batch_size, num_fewshot = num_fewshot) 75 | 76 | def process_results(self, doc, results): 77 | ll, is_greedy = results 78 | 79 | return {"ppl": ll, "acc": int(is_greedy)} 80 | 81 | def aggregation(self): 82 | return {"ppl": perplexity, "acc": mean} 83 | 84 | def higher_is_better(self): 85 | return {"ppl": False, "acc": True} 86 | -------------------------------------------------------------------------------- 
/lm_eval/tasks/mathqa.py: -------------------------------------------------------------------------------- 1 | """ 2 | MathQA: Towards Interpretable Math Word Problem Solving with Operation-Based Formalisms 3 | https://arxiv.org/pdf/1905.13319.pdf 4 | 5 | MathQA is a large-scale dataset of 37k English multiple-choice math word problems 6 | covering multiple math domain categories by modeling operation programs corresponding 7 | to word problems in the AQuA dataset (Ling et al., 2017). 8 | 9 | Homepage: https://math-qa.github.io/math-QA/ 10 | """ 11 | import re 12 | from lm_eval.base import MultipleChoiceTask 13 | from lm_eval.utils import create_dataloader 14 | 15 | 16 | _CITATION = """ 17 | @misc{amini2019mathqa, 18 | title={MathQA: Towards Interpretable Math Word Problem Solving with Operation-Based Formalisms}, 19 | author={Aida Amini and Saadia Gabriel and Peter Lin and Rik Koncel-Kedziorski and Yejin Choi and Hannaneh Hajishirzi}, 20 | year={2019}, 21 | eprint={1905.13319}, 22 | archivePrefix={arXiv}, 23 | primaryClass={cs.CL} 24 | } 25 | """ 26 | 27 | 28 | class MathQA(MultipleChoiceTask): 29 | VERSION = 0 30 | DATASET_PATH = "math_qa" 31 | DATASET_NAME = None 32 | 33 | def has_training_docs(self): 34 | return True 35 | 36 | def has_validation_docs(self): 37 | return True 38 | 39 | def has_test_docs(self): 40 | return True 41 | 42 | def training_docs(self): 43 | if self._training_docs is None: 44 | self._training_docs = list(map(self._process_doc, self.dataset["train"])) 45 | return self._training_docs 46 | 47 | def validation_docs(self): 48 | return map(self._process_doc, self.dataset["validation"]) 49 | 50 | def test_docs(self): 51 | return map(self._process_doc, self.dataset["test"]) 52 | 53 | def get_dataloader(self, tokenizer, split = 'train', subset_size = None, batch_size = 1, num_fewshot = 0): 54 | docs = self.training_docs() if split == 'train' else self.validation_docs() if split == 'validation' else self.test_docs() 55 | return create_dataloader(tokenizer, docs, self.fewshot_context, self.doc_to_target, subset_size = subset_size, batch_size = batch_size, num_fewshot = num_fewshot) 56 | 57 | def _process_doc(self, doc): 58 | answer_idx = ["a", "b", "c", "d", "e"].index(doc["correct"]) 59 | choices = [ 60 | c[4:].rstrip(" ,") 61 | for c in re.findall(r"[abcd] \) .*?, |e \) .*?$", doc["options"]) 62 | ] 63 | 64 | out_doc = { 65 | "query": "Question: " + doc["Problem"] + "\nAnswer:", 66 | "choices": choices, 67 | "gold": answer_idx, 68 | } 69 | return out_doc 70 | 71 | def doc_to_text(self, doc): 72 | return doc["query"] 73 | 74 | def doc_to_target(self, doc): 75 | return doc['choices'][doc['gold']] 76 | 77 | def should_decontaminate(self): 78 | return True 79 | 80 | def doc_to_decontamination_query(self, doc): 81 | return doc["query"] 82 | -------------------------------------------------------------------------------- /lm_eval/tasks/openbookqa.py: -------------------------------------------------------------------------------- 1 | """ 2 | Can a Suit of Armor Conduct Electricity? A New Dataset for Open Book Question Answering 3 | https://arxiv.org/pdf/1809.02789.pdf 4 | 5 | OpenBookQA is a question-answering dataset modeled after open book exams for 6 | assessing human understanding of a subject. It consists of 5,957 multiple-choice 7 | elementary-level science questions (4,957 train, 500 dev, 500 test), which probe 8 | the understanding of a small “book” of 1,326 core science facts and the application 9 | of these facts to novel situations. 
For training, the dataset includes a mapping 10 | from each question to the core science fact it was designed to probe. Answering 11 | OpenBookQA questions requires additional broad common knowledge, not contained 12 | in the book. The questions, by design, are answered incorrectly by both a retrieval- 13 | based algorithm and a word co-occurrence algorithm. 14 | 15 | Homepage: https://allenai.org/data/open-book-qa 16 | """ 17 | from lm_eval.base import MultipleChoiceTask 18 | from lm_eval.utils import create_dataloader 19 | 20 | _CITATION = """ 21 | @inproceedings{OpenBookQA2018, 22 | title={Can a Suit of Armor Conduct Electricity? A New Dataset for Open Book Question Answering}, 23 | author={Todor Mihaylov and Peter Clark and Tushar Khot and Ashish Sabharwal}, 24 | booktitle={EMNLP}, 25 | year={2018} 26 | } 27 | """ 28 | 29 | 30 | class OpenBookQA(MultipleChoiceTask): 31 | VERSION = 0 32 | DATASET_PATH = "openbookqa" 33 | DATASET_NAME = "main" 34 | 35 | def has_training_docs(self): 36 | return True 37 | 38 | def has_validation_docs(self): 39 | return True 40 | 41 | def has_test_docs(self): 42 | return True 43 | 44 | def training_docs(self): 45 | if self._training_docs is None: 46 | self._training_docs = list(map(self._process_doc, self.dataset["train"])) 47 | return self._training_docs 48 | 49 | def validation_docs(self): 50 | return map(self._process_doc, self.dataset["validation"]) 51 | 52 | def test_docs(self): 53 | return map(self._process_doc, self.dataset["test"]) 54 | 55 | def _process_doc(self, doc): 56 | out_doc = { 57 | "id": doc["id"], 58 | "query": doc["question_stem"], 59 | "choices": doc["choices"]["text"], 60 | "gold": ["A", "B", "C", "D"].index(doc["answerKey"].strip()), 61 | } 62 | return out_doc 63 | 64 | def get_dataloader(self, tokenizer, split = 'train', subset_size = None, batch_size = 1, num_fewshot = 0): 65 | docs = self.training_docs() if split == 'train' else self.validation_docs() if split == 'validation' else self.test_docs() 66 | return create_dataloader(tokenizer, docs, self.fewshot_context, self.doc_to_cont, subset_size = subset_size, batch_size = batch_size, num_fewshot = num_fewshot) 67 | 68 | def doc_to_cont(self, doc): 69 | return doc['choices'][doc['gold']] 70 | 71 | def doc_to_text(self, doc): 72 | return doc["query"] 73 | 74 | def should_decontaminate(self): 75 | return True 76 | 77 | def doc_to_decontamination_query(self, doc): 78 | return doc["query"] 79 | -------------------------------------------------------------------------------- /lm_eval/tasks/piqa.py: -------------------------------------------------------------------------------- 1 | """ 2 | PIQA: Reasoning about Physical Commonsense in Natural Language 3 | https://arxiv.org/pdf/1911.11641.pdf 4 | 5 | Physical Interaction: Question Answering (PIQA) is a physical commonsense 6 | reasoning and a corresponding benchmark dataset. PIQA was designed to investigate 7 | the physical knowledge of existing models. To what extent are current approaches 8 | actually learning about the world? 
9 | 10 | Homepage: https://yonatanbisk.com/piqa/ 11 | """ 12 | from lm_eval.base import MultipleChoiceTask 13 | from lm_eval.utils import create_dataloader 14 | 15 | _CITATION = """ 16 | @inproceedings{Bisk2020, 17 | author = {Yonatan Bisk and Rowan Zellers and 18 | Ronan Le Bras and Jianfeng Gao 19 | and Yejin Choi}, 20 | title = {PIQA: Reasoning about Physical Commonsense in 21 | Natural Language}, 22 | booktitle = {Thirty-Fourth AAAI Conference on 23 | Artificial Intelligence}, 24 | year = {2020}, 25 | } 26 | """ 27 | 28 | 29 | class PiQA(MultipleChoiceTask): 30 | VERSION = 0 31 | DATASET_PATH = "piqa" 32 | DATASET_NAME = None 33 | 34 | def has_training_docs(self): 35 | return True 36 | 37 | def has_validation_docs(self): 38 | return True 39 | 40 | def has_test_docs(self): 41 | return False 42 | 43 | def training_docs(self): 44 | if self._training_docs is None: 45 | self._training_docs = list(map(self._process_doc, self.dataset["train"])) 46 | return self._training_docs 47 | 48 | def validation_docs(self): 49 | return map(self._process_doc, self.dataset["validation"]) 50 | 51 | def _process_doc(self, doc): 52 | out_doc = { 53 | "goal": doc["goal"], 54 | "choices": [doc["sol1"], doc["sol2"]], 55 | "gold": doc["label"], 56 | } 57 | return out_doc 58 | 59 | def get_dataloader(self, tokenizer, split = 'train', subset_size = None, batch_size = 1, num_fewshot = 0): 60 | docs = self.training_docs() if split == 'train' else self.validation_docs() if split == 'validation' else RuntimeError(f'Data not available for {split} split.') 61 | return create_dataloader(tokenizer, docs, self.fewshot_context, self.doc_to_cont, subset_size = subset_size, batch_size = batch_size, num_fewshot = num_fewshot) 62 | 63 | def doc_to_cont(self, doc): 64 | return doc['choices'][doc['gold']] 65 | 66 | def doc_to_text(self, doc): 67 | return "Question: " + doc["goal"] + "\nAnswer:" 68 | 69 | def should_decontaminate(self): 70 | return True 71 | 72 | def doc_to_decontamination_query(self, doc): 73 | return doc["goal"] 74 | -------------------------------------------------------------------------------- /lm_eval/tasks/superglue.py: -------------------------------------------------------------------------------- 1 | """ 2 | SuperGLUE: A Stickier Benchmark for General-Purpose Language Understanding Systems 3 | https://w4ngatang.github.io/static/papers/superglue.pdf 4 | 5 | SuperGLUE is a benchmark styled after GLUE with a new set of more difficult language 6 | understanding tasks. 7 | 8 | Homepage: https://super.gluebenchmark.com/ 9 | 10 | TODO: WSC requires free-form generation. 11 | """ 12 | import numpy as np 13 | import sklearn 14 | import transformers.data.metrics.squad_metrics as squad_metrics 15 | from lm_eval.base import rf, Task 16 | from lm_eval.metrics import mean, acc_all, metric_max_over_ground_truths, yesno 17 | from lm_eval.utils import general_detokenize 18 | from lm_eval.utils import create_dataloader 19 | 20 | _CITATION = """ 21 | @inproceedings{NEURIPS2019_4496bf24, 22 | author = {Wang, Alex and Pruksachatkun, Yada and Nangia, Nikita and Singh, Amanpreet and Michael, Julian and Hill, Felix and Levy, Omer and Bowman, Samuel}, 23 | booktitle = {Advances in Neural Information Processing Systems}, 24 | editor = {H. Wallach and H. Larochelle and A. Beygelzimer and F. d\textquotesingle Alch\'{e}-Buc and E. Fox and R. 
Garnett}, 25 | pages = {}, 26 | publisher = {Curran Associates, Inc.}, 27 | title = {SuperGLUE: A Stickier Benchmark for General-Purpose Language Understanding Systems}, 28 | url = {https://proceedings.neurips.cc/paper/2019/file/4496bf24afe7fab6f046bf4923da8de6-Paper.pdf}, 29 | volume = {32}, 30 | year = {2019} 31 | } 32 | """ 33 | 34 | 35 | class BoolQ(Task): 36 | VERSION = 1 37 | DATASET_PATH = "super_glue" 38 | DATASET_NAME = "boolq" 39 | 40 | def has_training_docs(self): 41 | return True 42 | 43 | def has_validation_docs(self): 44 | return True 45 | 46 | def has_test_docs(self): 47 | return False 48 | 49 | def training_docs(self): 50 | if self._training_docs is None: 51 | self._training_docs = list(self.dataset["train"]) 52 | return self._training_docs 53 | 54 | def validation_docs(self): 55 | return self.dataset["validation"] 56 | 57 | def get_dataloader(self, tokenizer, split = 'train', subset_size = None, batch_size = 1, num_fewshot = 0): 58 | docs = self.training_docs() if split == 'train' else self.validation_docs() if split == 'validation' else RuntimeError(f'Data not available for {split} split.') 59 | return create_dataloader(tokenizer, docs, self.fewshot_context, self.doc_to_target, subset_size = subset_size, batch_size = batch_size, num_fewshot = num_fewshot) 60 | 61 | def doc_to_text(self, doc): 62 | return f"{doc['passage']}\nQuestion: {doc['question']}?\nAnswer:" 63 | 64 | def should_decontaminate(self): 65 | return True 66 | 67 | def doc_to_decontamination_query(self, doc): 68 | return doc["passage"] 69 | 70 | def doc_to_target(self, doc): 71 | return " " + yesno(doc["label"]) 72 | 73 | def construct_requests(self, doc, ctx): 74 | 75 | ll_yes, _ = rf.loglikelihood(ctx, " yes") 76 | ll_no, _ = rf.loglikelihood(ctx, " no") 77 | 78 | return ll_yes, ll_no 79 | 80 | def process_results(self, doc, results): 81 | ll_yes, ll_no = results 82 | gold = doc["label"] 83 | 84 | acc = 1.0 if (ll_yes > ll_no) == gold else 0.0 85 | 86 | return {"acc": acc} 87 | 88 | def higher_is_better(self): 89 | return {"acc": True} 90 | 91 | def aggregation(self): 92 | return {"acc": mean} 93 | 94 | 95 | class CommitmentBank(Task): 96 | VERSION = 1 97 | DATASET_PATH = "super_glue" 98 | DATASET_NAME = "cb" 99 | 100 | def has_training_docs(self): 101 | return True 102 | 103 | def has_validation_docs(self): 104 | return True 105 | 106 | def has_test_docs(self): 107 | return False 108 | 109 | def training_docs(self): 110 | if self._training_docs is None: 111 | self._training_docs = list(self.dataset["train"]) 112 | return self._training_docs 113 | 114 | def validation_docs(self): 115 | return self.dataset["validation"] 116 | 117 | def doc_to_text(self, doc): 118 | return "{}\nQuestion: {}. 
True, False or Neither?\nAnswer:".format( 119 | doc["premise"], 120 | doc["hypothesis"], 121 | ) 122 | 123 | def doc_to_target(self, doc): 124 | # True = entailment 125 | # False = contradiction 126 | # Neither = neutral 127 | return " {}".format({0: "True", 1: "False", 2: "Neither"}[doc["label"]]) 128 | 129 | def get_dataloader(self, tokenizer, split = 'train', subset_size = None, batch_size = 1, num_fewshot = 0): 130 | docs = self.training_docs() if split == 'train' else self.validation_docs() if split == 'validation' else RuntimeError(f'Data not available for {split} split.') 131 | return create_dataloader(tokenizer, docs, self.fewshot_context, self.doc_to_target, subset_size = subset_size, batch_size = batch_size, num_fewshot = num_fewshot) 132 | 133 | def construct_requests(self, doc, ctx): 134 | ll_true, _ = rf.loglikelihood(ctx, " True") 135 | ll_false, _ = rf.loglikelihood(ctx, " False") 136 | ll_neither, _ = rf.loglikelihood(ctx, " Neither") 137 | 138 | return ll_true, ll_false, ll_neither 139 | 140 | def process_results(self, doc, results): 141 | gold = doc["label"] 142 | pred = np.argmax(results) 143 | acc = 1.0 if pred == gold else 0.0 144 | 145 | return {"acc": acc, "f1": (pred, gold)} 146 | 147 | def higher_is_better(self): 148 | return {"acc": True, "f1": True} 149 | 150 | @classmethod 151 | def cb_multi_fi(cls, items): 152 | preds, golds = zip(*items) 153 | preds = np.array(preds) 154 | golds = np.array(golds) 155 | f11 = sklearn.metrics.f1_score(y_true=golds == 0, y_pred=preds == 0) 156 | f12 = sklearn.metrics.f1_score(y_true=golds == 1, y_pred=preds == 1) 157 | f13 = sklearn.metrics.f1_score(y_true=golds == 2, y_pred=preds == 2) 158 | avg_f1 = mean([f11, f12, f13]) 159 | return avg_f1 160 | 161 | def aggregation(self): 162 | return { 163 | "acc": mean, 164 | "f1": self.cb_multi_fi, 165 | } 166 | 167 | 168 | class Copa(Task): 169 | VERSION = 0 170 | DATASET_PATH = "super_glue" 171 | DATASET_NAME = "copa" 172 | 173 | def has_training_docs(self): 174 | return True 175 | 176 | def has_validation_docs(self): 177 | return True 178 | 179 | def has_test_docs(self): 180 | return False 181 | 182 | def training_docs(self): 183 | if self._training_docs is None: 184 | self._training_docs = list(self.dataset["train"]) 185 | return self._training_docs 186 | 187 | def validation_docs(self): 188 | return self.dataset["validation"] 189 | 190 | def get_dataloader(self, tokenizer, split = 'train', subset_size = None, batch_size = 1, num_fewshot = 0): 191 | docs = self.training_docs() if split == 'train' else self.validation_docs() if split == 'validation' else RuntimeError(f'Data not available for {split} split.') 192 | return create_dataloader(tokenizer, docs, self.fewshot_context, self.doc_to_target, subset_size = subset_size, batch_size = batch_size, num_fewshot = num_fewshot) 193 | 194 | def doc_to_text(self, doc): 195 | # Drop the period 196 | connector = { 197 | "cause": "because", 198 | "effect": "therefore", 199 | }[doc["question"]] 200 | return doc["premise"].strip()[:-1] + f" {connector}" 201 | 202 | def doc_to_target(self, doc): 203 | correct_choice = doc["choice1"] if doc["label"] == 0 else doc["choice2"] 204 | # Connect the sentences 205 | return " " + self.convert_choice(correct_choice) 206 | 207 | def construct_requests(self, doc, ctx): 208 | choice1 = " " + self.convert_choice(doc["choice1"]) 209 | choice2 = " " + self.convert_choice(doc["choice2"]) 210 | 211 | ll_choice1, _ = rf.loglikelihood(ctx, choice1) 212 | ll_choice2, _ = rf.loglikelihood(ctx, choice2) 213 | 214 | 
return ll_choice1, ll_choice2 215 | 216 | def process_results(self, doc, results): 217 | gold = doc["label"] 218 | pred = np.argmax(results) 219 | acc = 1.0 if pred == gold else 0.0 220 | 221 | return {"acc": acc} 222 | 223 | def higher_is_better(self): 224 | return {"acc": True} 225 | 226 | def aggregation(self): 227 | return {"acc": mean} 228 | 229 | @staticmethod 230 | def convert_choice(choice): 231 | return choice[0].lower() + choice[1:] 232 | 233 | 234 | class MultiRC(Task): 235 | VERSION = 1 236 | DATASET_PATH = "super_glue" 237 | DATASET_NAME = "multirc" 238 | 239 | def has_training_docs(self): 240 | return True 241 | 242 | def has_validation_docs(self): 243 | return True 244 | 245 | def has_test_docs(self): 246 | return False 247 | 248 | def training_docs(self): 249 | if self._training_docs is None: 250 | self._training_docs = list(self.dataset["train"]) 251 | return self._training_docs 252 | 253 | def validation_docs(self): 254 | return list(self.dataset["validation"]) 255 | 256 | def doc_to_text(self, doc): 257 | return f"{doc['paragraph']}\nQuestion: {doc['question']}\nAnswer:" 258 | 259 | def doc_to_target(self, doc): 260 | return " " + self.format_answer(answer=doc["answer"], label=doc["label"]) 261 | 262 | def get_dataloader(self, tokenizer, split = 'train', subset_size = None, batch_size = 1, num_fewshot = 0): 263 | docs = self.training_docs() if split == 'train' else self.validation_docs() if split == 'validation' else RuntimeError(f'Data not available for {split} split.') 264 | return create_dataloader(tokenizer, docs, self.fewshot_context, self.doc_to_target, subset_size = subset_size, batch_size = batch_size, num_fewshot = num_fewshot) 265 | 266 | @staticmethod 267 | def format_answer(answer, label): 268 | label_str = "yes" if label else "no" 269 | return f"{answer}\nIs the answer correct? {label_str}" 270 | 271 | def construct_requests(self, doc, ctx): 272 | true_choice = self.format_answer(answer=doc["answer"], label=True) 273 | false_choice = self.format_answer(answer=doc["answer"], label=False) 274 | 275 | ll_true_choice, _ = rf.loglikelihood(ctx, f" {true_choice}") 276 | ll_false_choice, _ = rf.loglikelihood(ctx, f" {false_choice}") 277 | 278 | return ll_true_choice, ll_false_choice 279 | 280 | def process_results(self, doc, results): 281 | ll_true_choice, ll_false_choice = results 282 | pred = ll_true_choice > ll_false_choice 283 | return {"acc": (pred, doc)} 284 | 285 | def higher_is_better(self): 286 | return {"acc": True} 287 | 288 | def aggregation(self): 289 | return {"acc": acc_all} 290 | 291 | 292 | class ReCoRD(Task): 293 | VERSION = 0 294 | DATASET_PATH = "super_glue" 295 | DATASET_NAME = "record" 296 | 297 | def has_training_docs(self): 298 | return True 299 | 300 | def has_validation_docs(self): 301 | return True 302 | 303 | def has_test_docs(self): 304 | return False 305 | 306 | def training_docs(self): 307 | # In ReCoRD, each doc manifests multiple "examples" in the context of few shot example packing. 308 | # Each doc consists of multiple answer candidates, each of which is scored yes/no. 
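# Each doc is pre-processed up front so duplicate entities/answers are dropped
# and the remaining candidates are sorted, keeping request order deterministic.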
309 | if self._training_docs is None: 310 | self._training_docs = [] 311 | for doc in self.dataset["train"]: 312 | self._training_docs.append(self._process_doc(doc)) 313 | return self._training_docs 314 | 315 | def validation_docs(self): 316 | # See: training_docs 317 | for doc in self.dataset["validation"]: 318 | yield self._process_doc(doc) 319 | 320 | 321 | @classmethod 322 | def _process_doc(cls, doc): 323 | return { 324 | "passage": doc["passage"], 325 | "query": doc["query"], 326 | "entities": sorted(list(set(doc["entities"]))), 327 | "answers": sorted(list(set(doc["answers"]))), 328 | } 329 | 330 | def get_dataloader(self, tokenizer, split = 'train', subset_size = None, batch_size = 1, num_fewshot = 0): 331 | docs = self.training_docs() if split == 'train' else self.validation_docs() if split == 'validation' else RuntimeError(f'Data not available for {split} split.') 332 | return create_dataloader(tokenizer, docs, self.fewshot_context, self.doc_to_target, subset_size = subset_size, batch_size = batch_size, num_fewshot = num_fewshot) 333 | 334 | def doc_to_text(self, doc): 335 | initial_text, *highlights = doc["passage"].strip().split("\n@highlight\n") 336 | text = initial_text + "\n\n" 337 | for highlight in highlights: 338 | text += f" - {highlight}.\n" 339 | return text 340 | 341 | @classmethod 342 | def format_answer(cls, query, entity): 343 | return f" - {query}".replace("@placeholder", entity) 344 | 345 | def doc_to_target(self, doc): 346 | # We only output the first correct entity in a doc 347 | return self.format_answer(query=doc["query"], entity=doc["answers"][0]) 348 | 349 | def construct_requests(self, doc, ctx): 350 | requests = [ 351 | rf.loglikelihood(ctx, self.format_answer(query=doc["query"], entity=entity)) 352 | for entity in doc["entities"] 353 | ] 354 | return requests 355 | 356 | def process_results(self, doc, results): 357 | # ReCoRD's evaluation is actually deceptively simple: 358 | # - Pick the maximum likelihood prediction entity 359 | # - Evaluate the accuracy and token F1 PER EXAMPLE 360 | # - Average over all examples 361 | max_idx = np.argmax(np.array([result[0] for result in results])) 362 | 363 | prediction = doc["entities"][max_idx] 364 | gold_label_set = doc["answers"] 365 | f1 = metric_max_over_ground_truths( 366 | squad_metrics.compute_f1, prediction, gold_label_set 367 | ) 368 | em = metric_max_over_ground_truths( 369 | squad_metrics.compute_exact, prediction, gold_label_set 370 | ) 371 | 372 | return { 373 | "f1": f1, 374 | "em": em, 375 | } 376 | 377 | def higher_is_better(self): 378 | return { 379 | "f1": True, 380 | "em": True, 381 | } 382 | 383 | def aggregation(self): 384 | return { 385 | "f1": mean, 386 | "em": mean, 387 | } 388 | 389 | 390 | class WordsInContext(Task): 391 | VERSION = 0 392 | DATASET_PATH = "super_glue" 393 | DATASET_NAME = "wic" 394 | 395 | def has_training_docs(self): 396 | return True 397 | 398 | def has_validation_docs(self): 399 | return True 400 | 401 | def has_test_docs(self): 402 | return False 403 | 404 | def training_docs(self): 405 | if self._training_docs is None: 406 | self._training_docs = list(self.dataset["train"]) 407 | return self._training_docs 408 | 409 | def validation_docs(self): 410 | return self.dataset["validation"] 411 | 412 | def doc_to_text(self, doc): 413 | return ( 414 | "Sentence 1: {}\nSentence 2: {}\nQuestion: Is the word '{}' used in the same way in the" 415 | " two sentences above?\nAnswer:".format( 416 | doc["sentence1"], 417 | doc["sentence2"], 418 | doc["sentence1"][doc["start1"] : 
doc["end1"]], 419 | ) 420 | ) 421 | 422 | def doc_to_target(self, doc): 423 | return " {}".format({0: "no", 1: "yes"}[doc["label"]]) 424 | 425 | def get_dataloader(self, tokenizer, split = 'train', subset_size = None, batch_size = 1, num_fewshot = 0): 426 | docs = self.training_docs() if split == 'train' else self.validation_docs() if split == 'validation' else RuntimeError(f'Data not available for {split} split.') 427 | return create_dataloader(tokenizer, docs, self.fewshot_context, self.doc_to_target, subset_size = subset_size, batch_size = batch_size, num_fewshot = num_fewshot) 428 | 429 | def construct_requests(self, doc, ctx): 430 | ll_yes, _ = rf.loglikelihood(ctx, " yes") 431 | ll_no, _ = rf.loglikelihood(ctx, " no") 432 | 433 | return ll_yes, ll_no 434 | 435 | def process_results(self, doc, results): 436 | ll_yes, ll_no = results 437 | gold = doc["label"] 438 | 439 | acc = 1.0 if (ll_yes > ll_no) == gold else 0.0 440 | 441 | return {"acc": acc} 442 | 443 | def higher_is_better(self): 444 | return {"acc": True} 445 | 446 | def aggregation(self): 447 | return {"acc": mean} 448 | 449 | 450 | class SGWinogradSchemaChallenge(Task): 451 | VERSION = 0 452 | # Note: This implementation differs from Fig G.32 because this is the SuperGLUE, 453 | # binary version of the task. 454 | DATASET_PATH = "super_glue" 455 | DATASET_NAME = "wsc" 456 | 457 | def has_training_docs(self): 458 | return True 459 | 460 | def has_validation_docs(self): 461 | return True 462 | 463 | def has_test_docs(self): 464 | return False 465 | 466 | def training_docs(self): 467 | if self.has_training_docs(): 468 | if self._training_docs is None: 469 | # GPT-3 Paper's format only uses positive examples for fewshot "training" 470 | self._training_docs = [ 471 | doc for doc in self.dataset["train"] if doc["label"] 472 | ] 473 | return self._training_docs 474 | 475 | def validation_docs(self): 476 | return self.dataset["validation"] 477 | 478 | def doc_to_text(self, doc): 479 | raw_passage = doc["text"] 480 | # NOTE: HuggingFace span indices are word-based not character-based. 
481 | pre = " ".join(raw_passage.split()[: doc["span2_index"]]) 482 | post = raw_passage[len(pre) + len(doc["span2_text"]) + 1 :] 483 | passage = general_detokenize(pre + " *{}*".format(doc["span2_text"]) + post) 484 | noun = doc["span1_text"] 485 | pronoun = doc["span2_text"] 486 | text = ( 487 | f"Passage: {passage}\n" 488 | + f'Question: In the passage above, does the pronoun "*{pronoun}*" refer to "*{noun}*"?\n' 489 | + "Answer:" 490 | ) 491 | return text 492 | 493 | def doc_to_target(self, doc): 494 | return " " + yesno(doc["label"]) 495 | 496 | def get_dataloader(self, tokenizer, split = 'train', subset_size = None, batch_size = 1, num_fewshot = 0): 497 | docs = self.training_docs() if split == 'train' else self.validation_docs() if split == 'validation' else RuntimeError(f'Data not available for {split} split.') 498 | return create_dataloader(tokenizer, docs, self.fewshot_context, self.doc_to_target, subset_size = subset_size, batch_size = batch_size, num_fewshot = num_fewshot) 499 | 500 | def construct_requests(self, doc, ctx): 501 | 502 | ll_yes, _ = rf.loglikelihood(ctx, " yes") 503 | ll_no, _ = rf.loglikelihood(ctx, " no") 504 | 505 | return ll_yes, ll_no 506 | 507 | def process_results(self, doc, results): 508 | ll_yes, ll_no = results 509 | gold = doc["label"] 510 | 511 | acc = 1.0 if (ll_yes > ll_no) == gold else 0.0 512 | 513 | return {"acc": acc} 514 | 515 | def higher_is_better(self): 516 | return {"acc": True} 517 | 518 | def aggregation(self): 519 | return {"acc": mean} 520 | -------------------------------------------------------------------------------- /lm_eval/tasks/winogrande.py: -------------------------------------------------------------------------------- 1 | """ 2 | WinoGrande: An Adversarial Winograd Schema Challenge at Scale 3 | https://arxiv.org/pdf/1907.10641.pdf 4 | 5 | WinoGrande is a collection of 44k problems, inspired by Winograd Schema Challenge 6 | (Levesque, Davis, and Morgenstern 2011), but adjusted to improve the scale and 7 | robustness against the dataset-specific bias. Formulated as a fill-in-a-blank 8 | task with binary options, the goal is to choose the right option for a given 9 | sentence which requires commonsense reasoning. 10 | 11 | NOTE: This evaluation of Winogrande uses partial evaluation as described by 12 | Trinh & Le in Simple Method for Commonsense Reasoning (2018). 
13 | See: https://arxiv.org/abs/1806.02847 14 | 15 | Homepage: https://leaderboard.allenai.org/winogrande/submissions/public 16 | """ 17 | import numpy as np 18 | from lm_eval.base import rf, Task 19 | from lm_eval.metrics import mean 20 | from lm_eval.utils import create_dataloader 21 | 22 | _CITATION = """ 23 | @article{sakaguchi2019winogrande, 24 | title={WinoGrande: An Adversarial Winograd Schema Challenge at Scale}, 25 | author={Sakaguchi, Keisuke and Bras, Ronan Le and Bhagavatula, Chandra and Choi, Yejin}, 26 | journal={arXiv preprint arXiv:1907.10641}, 27 | year={2019} 28 | } 29 | """ 30 | 31 | 32 | class Winogrande(Task): 33 | VERSION = 0 34 | DATASET_PATH = "winogrande" 35 | DATASET_NAME = "winogrande_xl" 36 | 37 | answer_to_num = {"1": 0, "2": 1} 38 | 39 | def has_training_docs(self): 40 | return True 41 | 42 | def has_validation_docs(self): 43 | return True 44 | 45 | def has_test_docs(self): 46 | return False 47 | 48 | def training_docs(self): 49 | if self._training_docs is None: 50 | self._training_docs = list(self.dataset["train"]) 51 | return self._training_docs 52 | 53 | def validation_docs(self): 54 | return self.dataset["validation"] 55 | 56 | def doc_to_text(self, doc): 57 | return self.partial_context(doc, doc["option" + doc["answer"]]) 58 | 59 | def get_dataloader(self, tokenizer, split = 'train', subset_size = None, batch_size = 1, num_fewshot = 0): 60 | docs = self.training_docs() if split == 'train' else self.validation_docs() if split == 'validation' else RuntimeError(f'Data not available for {split} split.') 61 | return create_dataloader(tokenizer, docs, self.fewshot_context, self.doc_to_target, subset_size = subset_size, batch_size = batch_size, num_fewshot = num_fewshot) 62 | 63 | def should_decontaminate(self): 64 | return True 65 | 66 | def doc_to_decontamination_query(self, doc): 67 | return doc["sentence"] 68 | 69 | @classmethod 70 | def partial_context(cls, doc, option): 71 | # Substitute the pronoun in the sentence with the specified option 72 | # and ignore everything after. 73 | pronoun_loc = doc["sentence"].index("_") 74 | return doc["sentence"][:pronoun_loc] + option 75 | 76 | def doc_to_target(self, doc): 77 | return self.partial_target(doc) 78 | 79 | @classmethod 80 | def partial_target(cls, doc): 81 | # The target is everything after the document specified pronoun. 82 | pronoun_loc = doc["sentence"].index("_") + 1 83 | return " " + doc["sentence"][pronoun_loc:].strip() 84 | 85 | def construct_requests(self, doc, ctx): 86 | """Uses RequestFactory to construct Requests and returns an iterable of 87 | Requests which will be sent to the LM. 88 | 89 | :param doc: 90 | The document as returned from training_docs, validation_docs, or test_docs. 91 | :param ctx: str 92 | The context string, generated by fewshot_context. This includes the natural 93 | language description, as well as the few shot examples, and the question 94 | part of the document for `doc`. 95 | """ 96 | target = self.partial_target(doc) 97 | lls = [] 98 | for option in [doc["option1"], doc["option2"]]: 99 | partial_ctx = self.partial_context(doc, option) 100 | full_ctx = self.append_context(ctx, partial_ctx) 101 | lls.append(rf.loglikelihood(full_ctx, target)[0]) 102 | return lls 103 | 104 | @classmethod 105 | def append_context(cls, ctx, partial_ctx): 106 | ctx = ctx.split("\n\n") # Each fewshot context is on its own new line. 107 | ctx.pop() # Remove the correct context put in by `doc_to_text`. 
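# Splice the option-specific partial context back in; with no fewshot examples
# the list is empty after the pop, so the partial context is returned on its own.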
108 | return "\n\n".join([*ctx, partial_ctx]) if ctx else partial_ctx 109 | 110 | def process_results(self, doc, results): 111 | """Take a single document and the LM results and evaluates, returning a 112 | dict where keys are the names of submetrics and values are the values of 113 | the metric for that one document 114 | 115 | :param doc: 116 | The document as returned from training_docs, validation_docs, or test_docs. 117 | :param results: 118 | The results of the requests created in construct_requests. 119 | """ 120 | return {"acc": np.argmax(results) == self.answer_to_num[doc["answer"]]} 121 | 122 | def aggregation(self): 123 | """ 124 | :returns: {str: [float] -> float} 125 | A dictionary where keys are the names of submetrics and values are 126 | functions that aggregate a list of metrics 127 | """ 128 | return {"acc": mean} 129 | 130 | def higher_is_better(self): 131 | """ 132 | :returns: {str: bool} 133 | A dictionary where keys are the names of submetrics and values are 134 | whether a higher value of the submetric is better 135 | """ 136 | return {"acc": True} 137 | -------------------------------------------------------------------------------- /lm_eval/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pathlib 4 | import re 5 | import collections 6 | import functools 7 | import inspect 8 | import sys 9 | import pytest 10 | from typing import List 11 | from torch.utils.data import RandomSampler, DataLoader, Subset, Dataset 12 | 13 | class ExitCodeError(Exception): 14 | pass 15 | 16 | def sh(x): 17 | if os.system(x): 18 | raise ExitCodeError() 19 | 20 | 21 | def simple_parse_args_string(args_string): 22 | """ 23 | Parses something like 24 | args1=val1,arg2=val2 25 | Into a dictionary 26 | """ 27 | args_string = args_string.strip() 28 | if not args_string: 29 | return {} 30 | arg_list = args_string.split(",") 31 | args_dict = {} 32 | for arg in arg_list: 33 | k, v = arg.split("=") 34 | args_dict[k] = v 35 | return args_dict 36 | 37 | 38 | def join_iters(iters): 39 | for iter in iters: 40 | yield from iter 41 | 42 | 43 | def chunks(iter, n): 44 | arr = [] 45 | for x in iter: 46 | arr.append(x) 47 | if len(arr) == n: 48 | yield arr 49 | arr = [] 50 | 51 | if arr: 52 | yield arr 53 | 54 | 55 | def group(arr, fn): 56 | res = collections.defaultdict(list) 57 | 58 | for ob in arr: 59 | res[fn(ob)].append(ob) 60 | 61 | return list(res.values()) 62 | 63 | 64 | def general_detokenize(string): 65 | string = string.replace(" n't", "n't") 66 | string = string.replace(" )", ")") 67 | string = string.replace("( ", "(") 68 | string = string.replace('" ', '"') 69 | string = string.replace(' "', '"') 70 | string = re.sub(r" (['.,])", r"\1", string) 71 | return string 72 | 73 | 74 | def get_rolling_token_windows(token_list, prefix_token, max_seq_len, context_len): 75 | """ 76 | - context_len allows for a rolling window context, allowing each prediction window to potentially 77 | condition on some context 78 | 79 | :param token_list: list 80 | List of tokens to be PREDICTED 81 | :param max_seq_len: int 82 | max_seq_len of model (or max_seq_len we want to use) 83 | :param context_len: int 84 | Amount of desired token context for prediction. Needs to be at least 1. 
85 | :param prefix_token: token 86 | Dummy token like so the first token has something to condition on 87 | :return: generator 88 | Generator of tuples 89 | (input_tokens, pred_tokens) 90 | Note: Score only the last len(pred_tokens) logits of the LM 91 | """ 92 | assert 1 <= context_len <= max_seq_len 93 | if not token_list: 94 | return 95 | # +1 offset, going from input->preds 96 | pred_len = max_seq_len - context_len + 1 97 | predicted = 0 98 | 99 | # Special handling for first window: predict all tokens 100 | first_seq_len = min(max_seq_len, len(token_list)) 101 | yield ([prefix_token] + token_list[: first_seq_len - 1], token_list[:first_seq_len]) 102 | predicted += first_seq_len 103 | 104 | while predicted < len(token_list): 105 | window_pred_len = min(len(token_list) - predicted, pred_len) 106 | window_end = predicted + window_pred_len 107 | 108 | yield ( 109 | token_list[window_end - max_seq_len - 1 : window_end - 1], 110 | token_list[window_end - window_pred_len : window_end], 111 | ) 112 | predicted += window_pred_len 113 | 114 | 115 | def make_disjoint_window(pair): 116 | """Takes output from get_rolling_token_windows and makes the context not overlap with the continuation""" 117 | 118 | a, b = pair 119 | 120 | return a[: -(len(b) - 1)], b 121 | 122 | 123 | class Reorderer: 124 | def __init__(self, arr, fn): 125 | self.size = len(arr) 126 | arr = list(enumerate(arr)) 127 | arr = group(arr, lambda x: fn(x[1])) 128 | arr = [([y[0] for y in x], x[0][1]) for x in arr] 129 | arr.sort(key=lambda x: fn(x[1])) 130 | 131 | self.arr = arr 132 | 133 | def get_reordered(self): 134 | return [x[1] for x in self.arr] 135 | 136 | def get_original(self, newarr): 137 | res = [None] * self.size 138 | cov = [False] * self.size 139 | 140 | for (inds, _), v in zip(self.arr, newarr): 141 | for ind in inds: 142 | res[ind] = v 143 | cov[ind] = True 144 | 145 | assert all(cov) 146 | 147 | return res 148 | 149 | 150 | def positional_deprecated(fn): 151 | """ 152 | A decorator to nudge users into passing only keyword args (`kwargs`) to the 153 | wrapped function, `fn`. 154 | """ 155 | 156 | @functools.wraps(fn) 157 | def _wrapper(*args, **kwargs): 158 | if len(args) != 1 if inspect.ismethod(fn) else 0: 159 | print( 160 | f"WARNING: using {fn.__name__} with positional arguments is " 161 | "deprecated and will be disallowed in a future version of " 162 | "lm-evaluation-harness!" 
163 |             )
164 |         return fn(*args, **kwargs)
165 | 
166 |     return _wrapper
167 | 
168 | 
169 | @positional_deprecated
170 | def find_test_root(start_path: pathlib.Path) -> pathlib.Path:
171 |     """
172 |     Search upward in the directory tree to a maximum of three layers
173 |     to find and return the package root (containing the 'tests' folder)
174 |     """
175 |     cur_path = start_path.resolve()
176 |     max_layers = 3
177 |     for _ in range(max_layers):
178 |         if (cur_path / "tests" / "test_version_stable.py").exists():
179 |             return cur_path
180 |         else:
181 |             cur_path = cur_path.parent.resolve()
182 |     raise FileNotFoundError(
183 |         f"Unable to find package root within {max_layers} layers upwards of {start_path}"
184 |     )
185 | 
186 | 
187 | @positional_deprecated
188 | def run_task_tests(task_list: List[str]):
189 |     """
190 |     Find the package root and run the tests for the given tasks
191 |     """
192 |     package_root = find_test_root(start_path=pathlib.Path(__file__))
193 |     task_string = " or ".join(task_list)
194 |     args = [
195 |         f"{package_root}/tests/test_version_stable.py",
196 |         f"--rootdir={package_root}",
197 |         "-k",
198 |         f"{task_string}",
199 |     ]
200 |     sys.path.append(str(package_root))
201 |     pytest_return_val = pytest.main(args)
202 |     if pytest_return_val:
203 |         raise ValueError(
204 |             f"Not all tests for the specified tasks ({task_list}) ran successfully! Error code: {pytest_return_val}"
205 |         )
206 | 
207 | def get_input_from_ctx_cont(ctx, continuation, tokenizer):
208 |     # Encode context and continuation; an empty context is replaced by the EOS token
209 |     ctx_enc = [tokenizer.eos_token_id] if ctx=="" else tokenizer.encode(ctx, add_special_tokens=False)
210 |     continuation_enc = tokenizer.encode(continuation, add_special_tokens=False)
211 |     max_length = 2048
212 |     inp = torch.tensor(ctx_enc + continuation_enc, dtype = torch.long)
213 |     if max_length > len(inp):
214 |         inp = torch.cat([inp, torch.zeros(max_length - len(inp), dtype = torch.long)], dim = 0)
215 |     return inp, len(ctx_enc), len(continuation_enc)
216 | 
217 | def get_dataloader_from_dataset(dataset, subset_size = None, batch_size = 1):
218 | 
219 |     if subset_size and subset_size < len(dataset):
220 |         dataset = Subset(dataset, list(range(subset_size)))
221 |     sampler = RandomSampler(dataset)
222 |     return DataLoader(dataset, sampler=sampler, batch_size=batch_size)
223 | 
224 | def create_dataloader(tokenizer, docs, fewshot_context, doc_to_cont, subset_size = None, batch_size = 1, num_fewshot = 0):
225 | 
226 |     class NewDataset(Dataset):
227 |         def __init__(self, tokenizer, docs, fewshot_context, doc_to_cont, num_fewshot):
228 |             super().__init__()
229 |             self.tokenizer = tokenizer
230 |             self.docs = docs
231 |             self.fewshot_context = fewshot_context
232 |             self.doc_to_cont = doc_to_cont
233 |             self.num_fewshot = num_fewshot
234 |         def __getitem__(self, i):
235 |             out_doc = self.docs[i]
236 |             ctx = self.fewshot_context(out_doc, self.num_fewshot)
237 |             continuation = self.doc_to_cont(out_doc)
238 |             inp, l_ctx_enc, l_cont_enc = get_input_from_ctx_cont(ctx, continuation, self.tokenizer)
239 |             return ctx, continuation, inp, l_ctx_enc, l_cont_enc
240 | 
241 |         def __len__(self):
242 |             return len(self.docs)
243 | 
244 |     ds = NewDataset(tokenizer, docs, fewshot_context, doc_to_cont, num_fewshot)
245 |     return get_dataloader_from_dataset(ds, subset_size = subset_size, batch_size = batch_size)
246 | 
-------------------------------------------------------------------------------- /main.py: --------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import logging
4 |
import fnmatch 5 | import os 6 | from lm_eval import tasks, evaluator 7 | 8 | logging.getLogger("openai").setLevel(logging.WARNING) 9 | 10 | 11 | class MultiChoice: 12 | def __init__(self, choices): 13 | self.choices = choices 14 | 15 | # Simple wildcard support (linux filename patterns) 16 | def __contains__(self, values): 17 | for value in values.split(","): 18 | if len(fnmatch.filter(self.choices, value)) == 0: 19 | return False 20 | 21 | return True 22 | 23 | def __iter__(self): 24 | for choice in self.choices: 25 | yield choice 26 | 27 | 28 | def parse_args(): 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument("--model", required=True) 31 | parser.add_argument("--model_args", default="") 32 | parser.add_argument("--tasks", default=None, choices=MultiChoice(tasks.ALL_TASKS)) 33 | parser.add_argument("--provide_description", action="store_true") 34 | parser.add_argument("--num_fewshot", type=int, default=0) 35 | parser.add_argument("--batch_size", type=int, default=None) 36 | parser.add_argument("--device", type=str, default=None) 37 | parser.add_argument("--output_path", default=None) 38 | parser.add_argument("--limit", type=int, default=None) 39 | parser.add_argument("--no_cache", action="store_true") 40 | parser.add_argument("--head_importance_calc", action="store_true", default = False) 41 | parser.add_argument("--local_rank", type=int, default=None) 42 | parser.add_argument("--save_importance_path", type=str, default=None) 43 | parser.add_argument("--decontamination_ngrams_path", default=None) 44 | parser.add_argument("--description_dict_path", default=None) 45 | parser.add_argument("--check_integrity", action="store_true") 46 | 47 | return parser.parse_args() 48 | 49 | 50 | # Returns a list containing all values of the source_list that 51 | # match at least one of the patterns 52 | def pattern_match(patterns, source_list): 53 | task_names = set() 54 | for pattern in patterns: 55 | for matching in fnmatch.filter(source_list, pattern): 56 | task_names.add(matching) 57 | return list(task_names) 58 | 59 | 60 | def main(): 61 | args = parse_args() 62 | 63 | assert not args.provide_description # not implemented 64 | 65 | if args.limit: 66 | print( 67 | "WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT." 
68 |         )
69 | 
70 |     if args.tasks is None:
71 |         task_names = tasks.ALL_TASKS
72 |     else:
73 |         task_names = pattern_match(args.tasks.split(","), tasks.ALL_TASKS)
74 | 
75 |     print(f"Selected Tasks: {task_names}")
76 | 
77 |     description_dict = {}
78 |     if args.description_dict_path:
79 |         with open(args.description_dict_path, "r") as f:
80 |             description_dict = json.load(f)
81 | 
82 |     results = evaluator.simple_evaluate(
83 |         model=args.model,
84 |         model_args=args.model_args,
85 |         tasks=task_names,
86 |         num_fewshot=args.num_fewshot,
87 |         batch_size=args.batch_size,
88 |         device=args.device,
89 |         no_cache=args.no_cache,
90 |         limit=args.limit,
91 |         description_dict=description_dict,
92 |         decontamination_ngrams_path=args.decontamination_ngrams_path,
93 |         check_integrity=args.check_integrity,
94 |         head_importance_calc=args.head_importance_calc,
95 |         save_importance_path=args.save_importance_path,
96 |     )
97 | 
98 |     # if args.local_rank == 0:
99 |     if not args.head_importance_calc:
100 |         dumped = json.dumps(results, indent=2)
101 |         print(dumped)
102 |         if args.output_path:
103 |             os.makedirs(os.path.dirname(args.output_path) or ".", exist_ok = True)
104 |             with open(args.output_path, "w") as f:
105 |                 f.write(dumped)
106 | 
107 |         print(
108 |             f"{args.model} ({args.model_args}), limit: {args.limit}, provide_description: {args.provide_description}, "
109 |             f"num_fewshot: {args.num_fewshot}, batch_size: {args.batch_size}"
110 |         )
111 |         print(evaluator.make_table(results))
112 | 
113 | 
114 | if __name__ == "__main__":
115 |     main()
116 | 
-------------------------------------------------------------------------------- /scripts/get_fc_ranking.py: --------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import torch
4 | import pickle
5 | import random
6 | import argparse
7 | import numpy as np
8 | 
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument("--results_dir", type=str, default=None)
11 | parser.add_argument("--save_dir", type=str, default=None)
12 | 
13 | 
14 | args = parser.parse_args()
15 | 
16 | def get_accuracy(path, task):
17 |     with open(path) as f:
18 |         x = json.load(f)
19 |     return x['results'][f'{task}']['acc'] * 100
20 | 
21 | datasets = os.listdir(args.results_dir)
22 | 
23 | for dataset in datasets:
24 |     fc = []
25 |     baseline_result = get_accuracy(os.path.join(args.results_dir, dataset, 'none.txt'), dataset)
26 |     for i in range(64):
27 |         fc_perf = get_accuracy(os.path.join(args.results_dir, dataset, f'fc_{i}.txt'), dataset)
28 |         fc.append(((baseline_result - fc_perf), i))
29 |     ordered = sorted(fc, key = lambda x: x[0])
30 |     order = list(zip(*ordered))[1]
31 |     os.makedirs(args.save_dir, exist_ok = True)
32 |     with open(os.path.join(args.save_dir, f'{dataset}.pkl'), 'wb') as f:
33 |         pickle.dump(order, f)
34 |     print(f'{dataset} saved!')
35 | 
36 | print('All done')
37 | 
38 | 
-------------------------------------------------------------------------------- /scripts/plotting/combined_pruning.py: --------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import torch
4 | import argparse
5 | import numpy as np
6 | from style import *
7 | 
8 | parser = argparse.ArgumentParser()
9 | parser.add_argument("--results_path", type=str, default=None)
10 | parser.add_argument("--save_plot_path", type=str, default=None)
11 | parser.add_argument("--shot", type=str, default="0-shot", choices=["0-shot","1-shot","5-shot"])
12 | 
13 | args = parser.parse_args()
14 | 
15 | fc_percents = [0, 10, 20, 30, 40, 50]
16 |
head_percents = [0, 10, 20, 30, 40, 50, 60, 70] 17 | 18 | def get_accuracy(path, task): 19 | with open(path) as f: 20 | x = json.load(f) 21 | return x['results'][f'{task}']['acc'] * 100 if task != "record" else x['results'][f'{task}']['em'] * 100 22 | 23 | 24 | datasets = ['hellaswag','piqa', 'arc_easy', 'arc_challenge', 'openbookqa', 'winogrande', 'boolq', 'cb', 'copa', 'wic', 'wsc', 'multirc', 'rte', 'record'] 25 | prefix = "1shot_" if '1-shot' in args.shot else "5shot_" if '5-shot' in args.shot else "" 26 | 27 | matrix = [] 28 | for dataset in datasets: 29 | # files = os.listdir(os.path.join(args.results_path, dataset)) 30 | # num_files_fc = list(filter(lambda x: '_fc_percent.txt' in x, files)) 31 | # num_files_head = list(filter(lambda x: '_percent.txt' in x and len(x) == 14, files)) 32 | # num_files_head_fc = list(filter(lambda x: '_head_percent.txt' in x, files)) 33 | # if len(num_files_fc) == 9 and len(num_files_head) == 11 and len(num_files_head_fc) == 35 and 'none.txt' in files: 34 | print(dataset) 35 | performance_matrix = np.zeros((len(head_percents), len(fc_percents))) 36 | for fc_percent in fc_percents: 37 | for head_percent in head_percents: 38 | if fc_percent == 0 and head_percent == 0: 39 | performance_matrix[0][0] = get_accuracy(os.path.join(args.results_path, dataset, f'{prefix}0_percent.txt'), dataset) 40 | elif fc_percent == 0: 41 | performance_matrix[head_percent//10][0] = get_accuracy(os.path.join(args.results_path, dataset, f'{prefix}{head_percent}_percent.txt'), dataset) 42 | elif head_percent == 0: 43 | performance_matrix[0][fc_percent//10] = get_accuracy(os.path.join(args.results_path, dataset, f'{prefix}{fc_percent}_fc_percent.txt'), dataset) 44 | else: 45 | performance_matrix[head_percent//10][fc_percent//10] = get_accuracy(os.path.join(args.results_path, dataset, f'{prefix}{fc_percent}_fc_{head_percent}_head_percent.txt'), dataset) 46 | matrix.append(np.expand_dims(performance_matrix, axis = 0)) 47 | max_, min_ = np.amax(performance_matrix), np.amin(performance_matrix) 48 | ax = sns.heatmap(performance_matrix, xticklabels = fc_percents, yticklabels = head_percents, cmap="YlGnBu", annot = True, vmax = max_, vmin = min_) 49 | ax.xaxis.tick_top() 50 | ax.xaxis.set_label_position('top') 51 | plt.title(f'Performance on {dataset} with Removal of Heads + FFN') 52 | plt.xlabel('Pruning of FFN (%)') 53 | plt.ylabel('Pruning of Attention Heads (%)') 54 | os.makedirs(os.path.dirname(args.save_plot_path), exist_ok = True) 55 | plt.savefig(os.path.join(args.save_plot_path, f'{dataset}.png')) 56 | plt.savefig(os.path.join(args.save_plot_path, f'{dataset}.pdf')) 57 | plt.close() 58 | 59 | matrix = np.concatenate(matrix, axis = 0) 60 | matrix = np.mean(matrix, axis = 0) 61 | max_, min_ = np.amax(matrix), np.amin(matrix) 62 | ax = sns.heatmap(matrix, xticklabels = fc_percents, yticklabels = head_percents, cmap="YlGnBu", annot = True, vmax = max_, vmin = min_) 63 | ax.xaxis.tick_top() 64 | ax.xaxis.set_label_position('top') 65 | plt.title(f'Average Performance after Combined Pruning of Heads and FFNs ({args.shot})') 66 | plt.xlabel('Pruning of FFNs (%)') 67 | plt.ylabel('Pruning of Attention Heads (%)') 68 | os.makedirs(os.path.dirname(args.save_plot_path), exist_ok = True) 69 | plt.savefig(os.path.join(args.save_plot_path, f'averaged.png')) 70 | plt.savefig(os.path.join(args.save_plot_path, f'averaged.pdf')) 71 | plt.close() 72 | -------------------------------------------------------------------------------- /scripts/plotting/crosstask_accuracy.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import json 4 | import random 5 | import argparse 6 | import numpy as np 7 | from style import * 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("--results_path", type=str, default=None) 11 | parser.add_argument("--save_plot_path", type=str, default=None) 12 | parser.add_argument("--shot", type=str, default="0-shot", choices=["0-shot","1-shot","5-shot"]) 13 | parser.add_argument("--ood", action='store_true') 14 | 15 | args = parser.parse_args() 16 | 17 | ## edit this to change the dataset 18 | datasets = ['record'] 19 | 20 | cross_task_map = {'0-shot':{'copa':{'high':'winogrande', 'low': 'record'},'winogrande':{'high':'copa', 'low':'record'}, 21 | 'record':{'high':'rte', 'low': 'openbookqa'}}, 22 | '1-shot':{'copa':{'high':'wsc', 'low': 'record'},'winogrande':{'high':'copa', 'low':'record'}, 23 | 'record':{'high':'hellaswag', 'low': 'wic'}}, 24 | '5-shot':{'copa':{'high':'wsc', 'low': 'record'},'winogrande':{'high':'boolq', 'low':'record'}, 25 | 'record':{'high':'hellaswag', 'low': 'wic'}}} 26 | 27 | files = [10, 30, 50, 70, 90] 28 | prefix = "1shot_" if '1-shot' in args.shot else "5shot_" if '5-shot' in args.shot else "" 29 | 30 | for dataset in datasets: 31 | original_path = os.path.join(args.results_path, dataset) 32 | l_n, l_h, l_l, l_a = [], [], [], [] 33 | for fname in files: 34 | filename_normal = os.path.join(original_path, f'{prefix}{fname}_percent.txt') 35 | if not args.ood: 36 | filename_high = os.path.join(args.results_path, 'cross_task', f'{prefix}{dataset}_high', f'{fname}_percent.txt') 37 | filename_low = os.path.join(args.results_path, 'cross_task', f'{prefix}{dataset}_low', f'{fname}_percent.txt') 38 | filename_agg = os.path.join(args.results_path, 'cross_task', f'{prefix}{dataset}_aggregate', f'{fname}_percent.txt') 39 | metric = 'em' if dataset=='record' else 'acc' 40 | with open(filename_normal, 'rb') as f: 41 | res = json.load(f) 42 | l_n.append(res['results'][f'{dataset}'][metric] * 100) 43 | if not args.ood: 44 | with open(filename_high, 'rb') as f: 45 | res = json.load(f) 46 | l_h.append(res['results'][f'{dataset}'][metric] * 100) 47 | with open(filename_low, 'rb') as f: 48 | res = json.load(f) 49 | l_l.append(res['results'][f'{dataset}'][metric] * 100) 50 | with open(filename_agg, 'rb') as f: 51 | res = json.load(f) 52 | l_a.append(res['results'][f'{dataset}'][metric] * 100) 53 | plt.plot(files, l_n, marker = 'o', label = f'Using {DATASET_TO_OFFICIAL[dataset]} ranking ({args.shot})') 54 | if not args.ood: 55 | high_dataset, low_dataset = cross_task_map[args.shot][dataset]['high'], cross_task_map[args.shot][dataset]['low'] 56 | plt.plot(files, l_h, marker = 'o', label = f'Using {DATASET_TO_OFFICIAL[high_dataset]} dataset ranking ({args.shot})') 57 | plt.plot(files, l_l, marker = 'o', label = f'Using {DATASET_TO_OFFICIAL[low_dataset]} dataset ranking ({args.shot})') 58 | plt.plot(files, l_a, marker = 'o', label = f'Using Aggregate Ranking ({args.shot})') 59 | plt.legend() 60 | plt.xlabel('Percentage Pruned (%)') 61 | plt.ylabel('Accuracy (%)') 62 | plt.title(f'Cross Task Transfer for {DATASET_TO_OFFICIAL[dataset]}') 63 | plt.grid() 64 | plt.tight_layout() 65 | os.makedirs(os.path.dirname(args.save_plot_path), exist_ok = True) 66 | plt.savefig(os.path.join(args.save_plot_path, f'{prefix}{dataset}.png')) 67 | plt.savefig(os.path.join(args.save_plot_path, f'{prefix}{dataset}.pdf')) 68 | plt.close() 
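# Example invocation (a sketch; the paths below are illustrative and not shipped with the repository):
#
#   python scripts/plotting/crosstask_accuracy.py \
#       --results_path logs/results/opt66b \
#       --save_plot_path plots/cross_task/ \
#       --shot 1-shot
#
# The script reads <results_path>/<dataset>/{prefix}{percent}_percent.txt for the task-specific
# ranking and, unless --ood is passed, the matching high/low/aggregate-ranking runs under
# <results_path>/cross_task/, then saves the accuracy-vs-pruning curves to --save_plot_path.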
-------------------------------------------------------------------------------- /scripts/plotting/fc_importance.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import torch 5 | import pickle 6 | import argparse 7 | import numpy as np 8 | from style import * 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("--results_path", type=str, default=None) 12 | parser.add_argument("--base_results_path", type=str, default=None) 13 | parser.add_argument("--save_plot_path", type=str, default=None) 14 | parser.add_argument("--dump_fc_importance", action='store_true') 15 | parser.add_argument("--dump_fc_importance_path", type=str, default=None) 16 | parser.add_argument("--shot", type=str, default="0-shot", choices=["0-shot","1-shot","5-shot"]) 17 | 18 | args = parser.parse_args() 19 | 20 | def get_accuracy(path, dataset): 21 | with open(path) as f: 22 | x = json.load(f) 23 | return x['results'][f'{dataset}']['acc'] * 100 if dataset != "record" else x['results'][f'{dataset}']['em'] * 100 24 | 25 | # datasets = ['hellaswag','piqa', 'arc_easy', 'arc_challenge', 'openbookqa', 'winogrande', 'boolq', 'cb', 'copa', 'wic', 'wsc', 'multirc', 'rte', 'record'] 26 | datasets = ['wic', 'multirc', 'record'] 27 | li_avg = [] 28 | c = 0 29 | prefix = "1shot_" if '1-shot' in args.shot else "5shot_" if '5-shot' in args.shot else "" 30 | for dataset in datasets: 31 | if args.base_results_path == None: 32 | args.base_results_path = args.results_path 33 | baseline_path = os.path.join(args.base_results_path, dataset, f'{prefix}0_percent.txt') 34 | baseline_result = get_accuracy(baseline_path, dataset) 35 | results_collector = [] 36 | for i in range(64): 37 | fc_score = get_accuracy(os.path.join(args.results_path, dataset, f'{prefix}fc_{i}.txt'), dataset) 38 | results_collector.append(baseline_result - fc_score) 39 | print(dataset) 40 | if args.dump_fc_importance: 41 | os.makedirs(os.path.dirname(args.dump_fc_importance_path), exist_ok = True) 42 | with open(f'{args.dump_fc_importance_path}{prefix}{dataset}.pkl', 'wb') as f: 43 | zipped = list(zip(list(range(64)), results_collector)) 44 | temp = sorted(zipped, key = lambda x: x[1]) 45 | fc_knocking_importance = list(list(zip(*temp))[0]) 46 | print(fc_knocking_importance) 47 | pickle.dump(fc_knocking_importance, f) 48 | print(f"Dumped {dataset}") 49 | 50 | li_avg.append(results_collector) 51 | plt.plot(results_collector, alpha = 0.5, label = DATASET_TO_OFFICIAL[dataset], color=DATASET_TO_COLOR[dataset]) 52 | 53 | if args.dump_fc_importance: 54 | sys.exit() 55 | 56 | average = np.mean(np.array(li_avg), axis = 0) 57 | plt.plot(average, linewidth=4, label = 'Average', color = 'k') 58 | plt.xticks(list(range(1,65,3))) 59 | # fig = plt.gcf() 60 | # fig.set_size_inches(25, 15) 61 | 62 | plt.xlabel('Layer Number', fontsize = 15) 63 | plt.ylabel(f'Accuracy Difference (%)', fontsize = 15) 64 | plt.title(f'FFN Oracle Pruning ({args.shot})', fontsize = 15) 65 | plt.legend(bbox_to_anchor=[1.0, 0.75]) 66 | plt.tight_layout() 67 | 68 | os.makedirs(os.path.dirname(args.save_plot_path), exist_ok = True) 69 | 70 | plt.savefig(args.save_plot_path) 71 | plt.savefig(args.save_plot_path.replace('png', 'pdf')) 72 | plt.close() -------------------------------------------------------------------------------- /scripts/plotting/heatmap.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import pickle 4 | import torch 5 | import 
numpy as np 6 | from style import * 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("--saved_head_importance_path", type=str, default=None) 10 | parser.add_argument("--dataset", type=str, default=None) 11 | parser.add_argument("--save_plot_path", type=str, default=None) 12 | parser.add_argument("--aggregate", action = "store_true") 13 | parser.add_argument("--shot", type=str, default="0-shot", choices=["0-shot","1-shot","5-shot"]) 14 | 15 | args = parser.parse_args() 16 | 17 | prefix = "1shot_" if '1-shot' in args.shot else "5shot_" if '5-shot' in args.shot else "" 18 | 19 | if args.aggregate: 20 | datasets = ['hellaswag','piqa', 'arc_easy', 'arc_challenge', 'openbookqa', 'winogrande', 'boolq', 'cb', 'copa', 'wic', 'wsc', 'multirc', 'rte', 'record'] 21 | himp = [] 22 | for dataset in datasets: 23 | pth = f'{dataset}.pkl' if "0-shot" in args.shot else f'1shot_{dataset}.pkl' if "1-shot" in args.shot else f'5shot_{dataset}.pkl' 24 | file_path = os.path.join(args.saved_head_importance_path, pth) 25 | print(file_path) 26 | with open(file_path, 'rb') as handle: 27 | himp.append(pickle.load(handle).unsqueeze(0)) 28 | results = torch.mean(torch.cat(himp, dim = 0), dim = 0) 29 | with open(os.path.join(os.path.dirname(args.saved_head_importance_path),f'{prefix}aggregate.pkl'), 'wb') as f: 30 | pickle.dump(results, f) 31 | print('aggregate saved!') 32 | else: 33 | with open(args.saved_head_importance_path, 'rb') as handle: 34 | results = pickle.load(handle) 35 | 36 | 37 | results = results.view(64, 72) 38 | layers, heads = results.shape 39 | min_, max_ = torch.min(results).item(), torch.max(results).item() 40 | if args.aggregate: 41 | ax = sns.heatmap(results, xticklabels = [(i+1) if i%2==0 else None for i in range(heads)], yticklabels = [(i+1)if i%2==0 else None for i in range(layers)], vmax = 0.002 if '1-shot' in args.shot else 0.004 if '0-shot' in args.shot else 0.0014) 42 | else: 43 | ax = sns.heatmap(results, xticklabels = [(i+1) if i%2==0 else None for i in range(heads)], yticklabels = [(i+1)if i%2==0 else None for i in range(layers)], vmin = min_ , vmax = max_/10) 44 | if not args.aggregate: 45 | dataset_name = DATASET_TO_OFFICIAL[args.dataset] 46 | else: 47 | dataset_name = 'Aggregate' 48 | plt.title(f'Head Importance Score ({args.shot} {dataset_name})') # title with fontsize 20 49 | plt.xlabel("Heads") # x-axis label with fontsize 15 50 | plt.ylabel("Layers") # y-axis label with fontsize 15 51 | ax.invert_yaxis() 52 | 53 | 54 | os.makedirs(os.path.dirname(args.save_plot_path), exist_ok = True) 55 | plt.savefig(args.save_plot_path) -------------------------------------------------------------------------------- /scripts/plotting/iterative_pruning.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import argparse 5 | import numpy as np 6 | from style import * 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("--results_path", type=str, default=None) 10 | parser.add_argument("--title", type=str, default=None) 11 | parser.add_argument("--save_plot_path", type=str, default=None) 12 | parser.add_argument("--shot", type=str, default="0-shot", choices=["0-shot","1-shot","5-shot"]) 13 | 14 | args = parser.parse_args() 15 | 16 | def get_accuracy(path, task): 17 | with open(path) as f: 18 | x = json.load(f) 19 | return x['results'][f'{task}']['acc'] * 100 if task != "record" else x['results'][f'{task}']['em'] * 100 20 | 21 | # datasets = ["arc_easy", "wsc", "openbookqa", "piqa", 
"rte", "cb", "hellaswag", "copa", "wic", "arc_challenge", "winogrande"]# "boolq", "multirc"] 22 | # datasets = ["piqa", "arc_challenge", "openbookqa", "winogrande", "cb", "copa", "wic", "wsc", "rte"] 23 | datasets = ['hellaswag','piqa', 'arc_easy', 'arc_challenge', 'openbookqa', 'winogrande', 'boolq', 'cb', 'copa', 'wic', 'wsc', 'multirc', 'rte', 'record'] 24 | prefix = "1shot_" if '1-shot' in args.shot else "5shot_" if '5-shot' in args.shot else "" 25 | 26 | percents = [10, 20, 30, 40, 50, 60, 70, 80, 90] 27 | li = [] 28 | 29 | for dataset in datasets: 30 | files = os.listdir(os.path.join(args.results_path, dataset)) 31 | num_files = list(filter(lambda x: '_percent.txt' in x and len(x)==14, files)) 32 | # num_files = list(filter(lambda x: '_percent.txt' in x and '1shot_' in x, files)) 33 | if len(num_files) >= 9: 34 | # if len(num_files) == 10: 35 | print(dataset) 36 | pruning_nums = [] 37 | baseline_path = os.path.join(args.results_path, dataset, f'{prefix}0_percent.txt') 38 | baseline_result = get_accuracy(baseline_path, dataset) 39 | pruning_nums.append(baseline_result) 40 | for percent in percents: 41 | result = get_accuracy(os.path.join(args.results_path, dataset, f'{prefix}{percent}_fc_percent.txt'), dataset) 42 | # result = get_accuracy(os.path.join(args.results_path, dataset, f'1shot_{percent}_percent.txt'), dataset) 43 | pruning_nums.append(result) 44 | li.append(pruning_nums) 45 | plt.plot([0]+percents, pruning_nums, label = DATASET_TO_OFFICIAL[dataset], alpha = 0.6, marker = DATASET_TO_MARKER[dataset], color=DATASET_TO_COLOR[dataset]) 46 | 47 | average = np.mean(np.array(li), axis = 0) 48 | plt.plot([0]+percents, average, 'k', linewidth = 3, label = 'Average', marker = 'o') 49 | print(average) 50 | plt.ylabel('Accuracy (%)', fontsize = 20) 51 | plt.xlabel('Percentage pruned (%)', fontsize = 20) 52 | plt.xticks(list(range(0,100,10))) 53 | plt.yticks(list(range(0,100,10))) 54 | plt.title(args.title) 55 | plt.legend(bbox_to_anchor=(1.0,0.75)) 56 | # plt.axvline(x=70, linestyle = '--') 57 | plt.axvline(x=10, linestyle = '--') 58 | plt.tight_layout() 59 | 60 | os.makedirs(os.path.dirname(args.save_plot_path), exist_ok = True) 61 | plt.savefig(args.save_plot_path) 62 | plt.savefig(args.save_plot_path[:-3]+'pdf') -------------------------------------------------------------------------------- /scripts/plotting/prefix_copying_pruning.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pickle 4 | import argparse 5 | import numpy as np 6 | from style import * 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("--aggregate_path_0_shot", type=str, default=None) 10 | parser.add_argument("--aggregate_path_1_shot", type=str, default=None) 11 | parser.add_argument("--aggregate_path_5_shot", type=str, default=None) 12 | parser.add_argument("--prefix_matching_path", type=str, default=None) 13 | parser.add_argument("--copying_path", type=str, default=None) 14 | parser.add_argument("--save_prefix_plot_path", type=str, default=None) 15 | parser.add_argument("--save_copying_plot_path", type=str, default=None) 16 | 17 | args = parser.parse_args() 18 | 19 | def open_file(path): 20 | with open(path, 'rb') as f: 21 | res = pickle.load(f) 22 | return res 23 | 24 | prefix_matching_scores = open_file(args.prefix_matching_path)['mean'].view(-1) 25 | copying_scores = open_file(args.copying_path)['mean'].view(-1) 26 | total_prefix = prefix_matching_scores.sum() 27 | total_copying = copying_scores.sum() 28 | 29 | shots = 
[0, 1, 5] 30 | aggregate_paths = [args.aggregate_path_0_shot, args.aggregate_path_1_shot, args.aggregate_path_5_shot] 31 | 32 | for shot, aggregate_path in zip(shots,aggregate_paths): 33 | aggregate_scores = open_file(aggregate_path).view(-1) 34 | assert aggregate_scores.shape == prefix_matching_scores.shape == copying_scores.shape 35 | _, ranking = torch.sort(aggregate_scores) 36 | sum_prefix = [] 37 | x_axis = [] 38 | for i in range(len(ranking)): 39 | sum_prefix.append((prefix_matching_scores[ranking[i:]].sum().item() * 100) / total_prefix) 40 | x_axis.append((100 * i) / len(ranking)) 41 | plt.plot(x_axis, sum_prefix, label = f'{shot}-shot') 42 | 43 | plt.xlabel('Percentage pruned (%)', fontsize=20) 44 | plt.ylabel('% of Total Prefix Matching Score Retained', fontsize=20) 45 | plt.title('Impact of Pruning Attention Heads on Prefix Matching', wrap=True) 46 | plt.legend() 47 | plt.grid() 48 | os.makedirs(os.path.dirname(args.save_prefix_plot_path), exist_ok = True) 49 | plt.savefig(args.save_prefix_plot_path) 50 | plt.savefig(args.save_prefix_plot_path[:-4]+'.pdf') 51 | plt.close() 52 | 53 | 54 | for shot, aggregate_path in zip(shots,aggregate_paths): 55 | aggregate_scores = open_file(aggregate_path).view(-1) 56 | assert aggregate_scores.shape == prefix_matching_scores.shape == copying_scores.shape 57 | _, ranking = torch.sort(aggregate_scores) 58 | sum_copying = [] 59 | x_axis = [] 60 | for i in range(len(ranking)): 61 | sum_copying.append((copying_scores[ranking[i:]].sum().item() * 100) / total_copying) 62 | x_axis.append((100 * i) / len(ranking)) 63 | plt.plot(x_axis, sum_copying, label = f'{shot}-shot') 64 | 65 | 66 | plt.xlabel('Percentage pruned (%)', fontsize=20) 67 | plt.ylabel("% of Total Copying Score Retained", fontsize=20) 68 | plt.title('Impact of Pruning Attention Heads on Copying', wrap=True) 69 | plt.legend() 70 | plt.grid() 71 | os.makedirs(os.path.dirname(args.save_copying_plot_path), exist_ok = True) 72 | plt.savefig(args.save_copying_plot_path) 73 | plt.savefig(args.save_copying_plot_path[:-4]+'.pdf') 74 | plt.close() -------------------------------------------------------------------------------- /scripts/plotting/prefix_copying_task_specific.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pickle 4 | import argparse 5 | import numpy as np 6 | from style import * 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("--results_path", type=str, default=None) 10 | parser.add_argument("--shot", type=str, default="0-shot", choices=["0-shot","1-shot","5-shot"]) 11 | parser.add_argument("--prefix_matching_path", type=str, default=None) 12 | parser.add_argument("--copying_path", type=str, default=None) 13 | parser.add_argument("--save_prefix_plot_path", type=str, default=None) 14 | parser.add_argument("--save_copying_plot_path", type=str, default=None) 15 | 16 | args = parser.parse_args() 17 | 18 | 19 | datasets = ['hellaswag','piqa', 'arc_easy', 'arc_challenge', 'openbookqa', 'winogrande', 'boolq', 'cb', 'copa', 'wic', 'wsc', 'multirc', 'rte', 'record'] 20 | prefix = "1shot_" if '1-shot' in args.shot else "5shot_" if '5-shot' in args.shot else "" 21 | 22 | 23 | def open_file(path): 24 | with open(path, 'rb') as f: 25 | res = pickle.load(f) 26 | return res 27 | 28 | prefix_matching_scores = open_file(args.prefix_matching_path)['mean'].view(-1) 29 | copying_scores = open_file(args.copying_path)['mean'].view(-1) 30 | total_prefix = prefix_matching_scores.sum() 31 | total_copying = 
copying_scores.sum() 32 | 33 | # shots = [0, 1, 5] 34 | # aggregate_paths = [args.aggregate_path_0_shot, args.aggregate_path_1_shot, args.aggregate_path_5_shot] 35 | 36 | for dataset in datasets: 37 | rank_path = os.path.join(args.results_path, f'{prefix}{dataset}.pkl') 38 | aggregate_scores = open_file(rank_path).view(-1) 39 | assert aggregate_scores.shape == prefix_matching_scores.shape == copying_scores.shape 40 | _, ranking = torch.sort(aggregate_scores) 41 | sum_prefix = [] 42 | x_axis = [] 43 | for i in range(len(ranking)): 44 | sum_prefix.append((prefix_matching_scores[ranking[i:]].sum().item() * 100) / total_prefix) 45 | x_axis.append((100 * i) / len(ranking)) 46 | plt.plot(x_axis, sum_prefix, label = DATASET_TO_OFFICIAL[dataset], color=DATASET_TO_COLOR[dataset]) 47 | 48 | plt.xlabel('Percentage pruned (%)', fontsize=20) 49 | plt.ylabel('% of Total Prefix Matching Score Retained', fontsize=20) 50 | plt.title(f'Impact of Pruning Attention Heads on Prefix Matching ({args.shot})', wrap=True) 51 | plt.legend() 52 | plt.grid() 53 | os.makedirs(os.path.dirname(args.save_prefix_plot_path), exist_ok = True) 54 | plt.savefig(args.save_prefix_plot_path) 55 | plt.savefig(args.save_prefix_plot_path[:-4]+'.pdf') 56 | plt.close() 57 | 58 | 59 | for dataset in datasets: 60 | rank_path = os.path.join(args.results_path, f'{prefix}{dataset}.pkl') 61 | aggregate_scores = open_file(rank_path).view(-1) 62 | _, ranking = torch.sort(aggregate_scores) 63 | sum_copying = [] 64 | x_axis = [] 65 | for i in range(len(ranking)): 66 | sum_copying.append((copying_scores[ranking[i:]].sum().item() * 100) / total_copying) 67 | x_axis.append((100 * i) / len(ranking)) 68 | plt.plot(x_axis, sum_copying, label = DATASET_TO_OFFICIAL[dataset], color=DATASET_TO_COLOR[dataset]) 69 | 70 | 71 | plt.xlabel('Percentage pruned (%)', fontsize=20) 72 | plt.ylabel("% of Total Copying Score Retained", fontsize=20) 73 | plt.title(f'Impact of Pruning Attention Heads on Copying ({args.shot})', wrap=True) 74 | plt.legend() 75 | plt.grid() 76 | os.makedirs(os.path.dirname(args.save_copying_plot_path), exist_ok = True) 77 | plt.savefig(args.save_copying_plot_path) 78 | plt.savefig(args.save_copying_plot_path[:-4]+'.pdf') 79 | plt.close() -------------------------------------------------------------------------------- /scripts/plotting/spearman_rankings.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pickle 4 | import random 5 | import argparse 6 | import numpy as np 7 | from scipy import stats 8 | from style import * 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("--results_path", type=str, default=None) 12 | parser.add_argument("--save_plot_path", type=str, default=None) 13 | parser.add_argument("--random", action = "store_true") 14 | parser.add_argument("--aggregate", action = "store_true") 15 | parser.add_argument("--shot", type=str, default="0-shot", choices=["0-shot","1-shot","5-shot"]) 16 | 17 | args = parser.parse_args() 18 | 19 | datasets = ['hellaswag','piqa', 'arc_easy', 'arc_challenge', 'openbookqa', 'winogrande', 'boolq', 'cb', 'copa', 'wic', 'wsc', 'multirc', 'rte', 'record'] 20 | prefix = "1shot_" if '1-shot' in args.shot else "5shot_" if '5-shot' in args.shot else "" 21 | 22 | # datalengths = [5000, 5000, 635, 5000, 5000, 400, 3668, 5000, 2251, 1119, 250, 5000, 5000, 4957, 2490, 5000, 259, 5000] 23 | 24 | rankings = [] 25 | aggregate, count = torch.zeros(4608), 0 26 | new_xticks = [] 27 | for i in range(len(datasets)): 28 | 
pth = os.path.join(args.results_path, f'{prefix}{datasets[i]}.pkl') 29 | with open(pth, 'rb') as f: 30 | res = pickle.load(f).view(-1) 31 | new_xticks.append(f'{DATASET_TO_OFFICIAL[datasets[i]]}') 32 | print(datasets[i]) 33 | aggregate = aggregate + (res) 34 | count += 1 35 | ranking = torch.sort(res)[1].unsqueeze(0) 36 | rankings.append(ranking) 37 | 38 | if args.aggregate: 39 | aggregate = aggregate / count 40 | new_xticks.append('Aggregate') 41 | rankings.append(torch.sort(aggregate)[1].unsqueeze(0)) 42 | with open(f'logs/head_importance/opt66b/{prefix}aggregate.pkl', 'wb') as f: 43 | pickle.dump(aggregate, f) 44 | print('Aggregated Ranking Saved!') 45 | 46 | if args.random: 47 | random_li = list(range(4608)) 48 | random.shuffle(random_li) 49 | random_li = torch.tensor(random_li) 50 | new_xticks.append('Random') 51 | rankings.append(torch.sort(random_li)[1].unsqueeze(0)) 52 | 53 | rankings = torch.cat(rankings).numpy() # Num_tasks, Dimension 54 | print(rankings.shape) 55 | num_tasks, _ = rankings.shape 56 | 57 | matrix = stats.spearmanr(rankings, rankings, axis = 1).correlation 58 | print(matrix.shape) 59 | matrix = matrix[:num_tasks, :num_tasks] 60 | 61 | ax = sns.heatmap(matrix, xticklabels = new_xticks, yticklabels = new_xticks, cmap="YlGnBu", annot = True, vmax = 0.5) 62 | 63 | plt.title(f'Spearman Rank Correlations between Importance Score Orders') 64 | plt.yticks(rotation=0) 65 | plt.xticks(rotation=90) 66 | if args.aggregate: 67 | ax.hlines([len(datasets)], *ax.get_xlim(), colors = 'C1', linewidth=3) 68 | ax.vlines([len(datasets)], *ax.get_ylim(), colors = 'C1', linewidth=3) 69 | # ax.figure.tight_layout() 70 | # fig = plt.gcf() 71 | # fig.set_size_inches(20, 12.5) 72 | # fig.set_dpi(100) 73 | plt.tight_layout() 74 | 75 | os.makedirs(os.path.dirname(args.save_plot_path), exist_ok = True) 76 | plt.savefig(args.save_plot_path) 77 | plt.savefig(args.save_plot_path[:-3]+'pdf') -------------------------------------------------------------------------------- /scripts/plotting/style.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import seaborn as sns 3 | 4 | params = { 5 | "axes.titlesize": 22, 6 | "legend.fontsize": 16, 7 | "figure.figsize": (12, 8), 8 | "axes.labelsize": 16, 9 | "axes.titlesize": 20, 10 | "xtick.labelsize": 16, 11 | "ytick.labelsize": 16, 12 | "figure.titlesize": 22, 13 | "font.family": "Liberation Mono" 14 | } 15 | 16 | plt.rcParams.update(params) 17 | plt.style.use("seaborn-whitegrid") 18 | sns.set_style("white") 19 | 20 | DATASET_TO_OFFICIAL = {'hellaswag': 'HellaSwag', 21 | 'piqa': 'PIQA', 22 | 'arc_easy': 'ARC (Easy)', 23 | 'arc_challenge': 'ARC (Challenge)', 24 | 'openbookqa': 'OpenBookQA', 25 | 'winogrande': 'Winogrande', 26 | 'boolq': 'BoolQ', 27 | 'cb': 'CB', 28 | 'wic': 'WIC', 29 | 'wsc': 'WSC', 30 | 'multirc': 'MultiRC', 31 | 'rte': 'RTE', 32 | 'record': 'ReCoRD', 33 | 'copa': 'COPA', 34 | 'lambada': 'LAMBADA', 35 | 'mathqa': 'MathQA' 36 | } 37 | 38 | DATASET_TO_MARKER = {'hellaswag': 'o', 39 | 'piqa': 'o', 40 | 'arc_easy': 'o', 41 | 'arc_challenge': 'o', 42 | 'openbookqa': 'o', 43 | 'winogrande': 'o', 44 | 'boolq': 'o', 45 | 'cb': 'o', 46 | 'wic': '^', 47 | 'wsc': '^', 48 | 'multirc': '^', 49 | 'rte': '^', 50 | 'record': '^', 51 | 'copa': '^', 52 | 'lambada': '^', 53 | 'mathqa': '^' 54 | } 55 | 56 | DATASET_TO_COLOR = {'hellaswag': 'maroon', 57 | 'piqa': 'chocolate', 58 | 'arc_easy': 'greenyellow', 59 | 'arc_challenge': 'violet', 60 | 'openbookqa': 'royalblue', 61 | 'winogrande': 
'crimson', 62 | 'boolq': 'slategrey', 63 | 'cb': 'darkkhaki', 64 | 'wic': 'orangered', 65 | 'wsc': 'gold', 66 | 'multirc': 'lime', 67 | 'rte': 'steelblue', 68 | 'record': 'cyan', 69 | 'copa': 'magenta', 70 | 'lambada': 'saddlebrown', 71 | 'mathqa': 'gray' 72 | } --------------------------------------------------------------------------------
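For reference, the plotting scripts above (e.g. iterative_pruning.py, fc_importance.py) pull in this module with `from style import *` and look up per-dataset labels, markers, and colors. A minimal usage sketch, assuming style.py is importable (e.g. when run from scripts/plotting/) and using made-up numbers purely for illustration:

# Minimal sketch: plot one pruning curve with the shared per-dataset styling.
from style import plt, DATASET_TO_OFFICIAL, DATASET_TO_MARKER, DATASET_TO_COLOR

dataset = "piqa"
percents = [0, 10, 20, 30]        # percentage of attention heads pruned
accs = [78.0, 77.5, 75.0, 70.0]   # illustrative values only, not measured results

plt.plot(percents, accs,
         label=DATASET_TO_OFFICIAL[dataset],
         marker=DATASET_TO_MARKER[dataset],
         color=DATASET_TO_COLOR[dataset])
plt.xlabel('Percentage pruned (%)')
plt.ylabel('Accuracy (%)')
plt.legend()
plt.savefig('example.png')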