├── .gitignore ├── LICENSE ├── README.md ├── analysis.Rmd ├── data └── arc-challenge-easy-annotations.json ├── globals.py ├── main.py ├── models ├── __init__.py └── probes.py ├── requirements.txt ├── run_jobs.py └── utils ├── LM_utils.py ├── data_utils.py ├── metrics.py ├── modeling_utils.py ├── plotting_utils.py ├── prompt.py ├── training_logger.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | outputs/ 2 | data/ 3 | result_sheets/ 4 | training_logs/ 5 | temp/ 6 | tmp_example.txt 7 | 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | cover/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | db.sqlite3-journal 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | .pybuilder/ 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | 88 | # IPython 89 | profile_default/ 90 | ipython_config.py 91 | 92 | # pyenv 93 | # For a library or package, you might want to ignore these files since the code is 94 | # intended to run in multiple environments; otherwise, check them in: 95 | # .python-version 96 | 97 | # pipenv 98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 101 | # install all needed dependencies. 102 | #Pipfile.lock 103 | 104 | # poetry 105 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 106 | # This is especially recommended for binary packages to ensure reproducibility, and is more 107 | # commonly ignored for libraries. 108 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 109 | #poetry.lock 110 | 111 | # pdm 112 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 113 | #pdm.lock 114 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 115 | # in version control. 116 | # https://pdm.fming.dev/#use-with-ide 117 | .pdm.toml 118 | 119 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 120 | __pypackages__/ 121 | 122 | # Celery stuff 123 | celerybeat-schedule 124 | celerybeat.pid 125 | 126 | # SageMath parsed files 127 | *.sage.py 128 | 129 | # Environments 130 | .env 131 | .venv 132 | env/ 133 | venv/ 134 | ENV/ 135 | env.bak/ 136 | venv.bak/ 137 | 138 | # Spyder project settings 139 | .spyderproject 140 | .spyproject 141 | 142 | # Rope project settings 143 | .ropeproject 144 | 145 | # mkdocs documentation 146 | /site 147 | 148 | # mypy 149 | .mypy_cache/ 150 | .dmypy.json 151 | dmypy.json 152 | 153 | # Pyre type checker 154 | .pyre/ 155 | 156 | # pytype static type analyzer 157 | .pytype/ 158 | 159 | # Cython debug symbols 160 | cython_debug/ 161 | 162 | # PyCharm 163 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 164 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 165 | # and can be added to the global gitignore or merged into this file. For a more nuclear 166 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 167 | #.idea/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This is the codebase for the paper [The Unreasonable Effectiveness of Easy Training Data for Hard Tasks](https://arxiv.org/pdf/2401.06751.pdf). 2 | 3 | Below, we describe how to replicate the main experimental results in our paper. 4 | 5 | ### Experiment Commands 6 | 7 | We begin with a few examples of experiments that one should be able to run with the codebase. 
Note that to use Llama-2 models, the `llama2_path` variable must be set in `utils/utils.py`. 8 | 9 | #### Install Requirements 10 | 11 | ``` 12 | pip install -r requirements.txt 13 | ``` 14 | 15 | #### Run llama-7b on ARC Challenge test with a zero-shot prompt 16 | 17 | Note that you must supply the `--model_dir` and `--cache_dir` args for saving/storing models by setting the `MODEL_DIR` and `CACHE_DIR` environment variables. Lowering the eval batch size (`-ebs`) to 4 (the minimum value given that ARC is 4-way multiple-choice) should help fit onto a smaller GPU. 18 | 19 | ``` 20 | python main.py --model huggyllama/llama-7b --do_eval true -llm true --probing_method decoding --dataset ai2_arc/ARC-Challenge --hardness_var_name NA --specific_prompt 0040 -pb 1 -np 1 --stratify_hardness false --k_shot 0 -ebs 8 --all_data_to_test true --model_dir $MODEL_DIR --cache_dir $CACHE_DIR 21 | ``` 22 | 23 | #### Run llama-13b on our combined ARC data with a zero-shot prompt 24 | 25 | ``` 26 | python main.py --model huggyllama/llama-13b --do_eval true -llm true --probing_method decoding --dataset ai2_arc --hardness_var_name NA --specific_prompt 0040 -pb 1 -np 1 --stratify_hardness false --k_shot 0 -ebs 10 --all_data_to_test true --model_dir $MODEL_DIR --cache_dir $CACHE_DIR 27 | ``` 28 | 29 | #### Run Mixtral-8x7B on college-level MMLU-STEM-5 data with a 10-shot prompt containing high school examples, using 5 random seeds 30 | 31 | ``` 32 | python main.py --model mistralai/Mixtral-8x7B-v0.1 --do_eval true -llm true --probing_method decoding --dataset mmlu_STEM-5 --hardness_var_name human_hardness --specific_prompt 0040 -pb 5 -np 1 --stratify_hardness true --train_on easy --test_on hard --k_shot 10 -ebs 8 --all_data_to_test true --model_dir $MODEL_DIR --cache_dir $CACHE_DIR 33 | ``` 34 | 35 | ### Paper Research Question Experiments 36 | 37 | Now we describe how to replicate the main results in our paper using the `run_jobs.py` file. In general, you have to edit the `use_models` and `use_methods` variables in this file if you do *not* want to run experiments across Llama-2-7b, Llama-2-13b, and Llama-2-70b with all relevant training methods, including ICL, ICL+CoT, linear probing, QLoRA, and QLoRA+CoT. Note that using `Llama-2-70b` requires four 48GB GPUs to load in 8-bit quantization. 38 | 39 | First, if you want to use linear models later on, write model hidden states to file; this is a precursor to linear modeling. 40 | 41 | ``` 42 | python run_jobs.py -e write_hidden_states --dataset ai2_arc --model_dir $MODEL_DIR --cache_dir $CACHE_DIR 43 | python run_jobs.py -e write_hidden_states --dataset mmlu_STEM-5 --model_dir $MODEL_DIR --cache_dir $CACHE_DIR 44 | python run_jobs.py -e write_hidden_states --dataset strategyQA --model_dir $MODEL_DIR --cache_dir $CACHE_DIR 45 | ``` 46 | 47 | If you want to use model-based MDL metrics later on, estimate model-based hardness for these datasets. To use fewer than our four default 7b models, edit `globals.hardness_models`. 48 | 49 | ``` 50 | python run_jobs.py -e estimate_hardness --dataset ai2_arc --model_dir $MODEL_DIR --cache_dir $CACHE_DIR 51 | python run_jobs.py -e estimate_hardness --dataset strategy-qa --model_dir $MODEL_DIR --cache_dir $CACHE_DIR 52 | python run_jobs.py -e estimate_hardness --dataset mmlu_STEM-5 --model_dir $MODEL_DIR --cache_dir $CACHE_DIR 53 | ``` 54 | 55 | To get all-to-all performance (comparable to paper Table 4), run the following commands.
56 | 57 | ``` 58 | python run_jobs.py -e all_to_all_table --dataset ai2_arc -nb 5 -lc 0 --n_train 160 --k_shot 10 --model_dir $MODEL_DIR --cache_dir $CACHE_DIR 59 | python run_jobs.py -e all_to_all_table --dataset mmlu_STEM-5 -nb 5 -lc 0 --n_train 160 --k_shot 10 --model_dir $MODEL_DIR --cache_dir $CACHE_DIR 60 | python run_jobs.py -e all_to_all_table --dataset strategy-qa -nb 5 -lc 0 --n_train 160 --k_shot 8 --model_dir $MODEL_DIR --cache_dir $CACHE_DIR 61 | python run_jobs.py -e all_to_all_table --dataset gsm8k_main -nb 5 -lc 0 --n_train 160 --k_shot 8 --model_dir $MODEL_DIR --cache_dir $CACHE_DIR 62 | ``` 63 | 64 | Now, to get the main easy-to-hard generalization results (RQ2 in the paper), run the commands below. To change which hardness measures are used for dataset stratification, adjust the value of `stratify_var_names`. 65 | 66 | ``` 67 | python run_jobs.py -e get_population_table --dataset ai2_arc -nb 5 -lc 0 --n_train 160 --k_shot 10 --model_dir $MODEL_DIR --cache_dir $CACHE_DIR 68 | python run_jobs.py -e get_population_table --dataset mmlu_STEM-5 -nb 5 -lc 0 --n_train 160 --k_shot 10 --model_dir $MODEL_DIR --cache_dir $CACHE_DIR 69 | python run_jobs.py -e get_population_table --dataset strategy-qa -nb 5 -lc 0 --n_train 160 --k_shot 8 --model_dir $MODEL_DIR --cache_dir $CACHE_DIR 70 | python run_jobs.py -e get_population_table --dataset gsm8k_main -nb 5 -lc 0 --n_train 160 --k_shot 8 --model_dir $MODEL_DIR --cache_dir $CACHE_DIR 71 | ``` 72 | 73 | To get our Figure 1 plot, which measures college test performance for a model prompted with 3rd grade / 8th grade / high school data, run: 74 | 75 | ``` 76 | python run_jobs.py -e third_grade_to_college -nb 5 -lc 0 --n_train 160 --k_shot 10 -rj 0 --model_dir $MODEL_DIR --cache_dir $CACHE_DIR 77 | ``` 78 | 79 | To get results with noisy training labels (RQ3), run: 80 | 81 | ``` 82 | python run_jobs.py -e noisy_labels_table --dataset mmlu_STEM-5 -nb 5 -lc 0 --n_train 160 --model_dir $MODEL_DIR --cache_dir $CACHE_DIR 83 | ``` 84 | 85 | To get learning curves with linear probes, in order to estimate performance w.r.t. training cost (RQ4), first set `use_methods=['learned_CoT=False']` and `use_models = ['Llama-2-70b']` in `get_population_table`, then run: 86 | 87 | ``` 88 | python run_jobs.py -e get_population_table --dataset ai2_arc -nb 10 -lc 1 --model_dir $MODEL_DIR --cache_dir $CACHE_DIR 89 | python run_jobs.py -e get_population_table --dataset mmlu_STEM-5 -nb 10 -lc 1 --model_dir $MODEL_DIR --cache_dir $CACHE_DIR 90 | python run_jobs.py -e get_population_table --dataset strategy-qa -nb 10 -lc 1 --model_dir $MODEL_DIR --cache_dir $CACHE_DIR 91 | ``` 92 | 93 | ## Data Analysis 94 | 95 | We provide the R Markdown file used for data analysis. The `run_jobs.py` experiments above will output .csv files into a `result_sheets` directory. The `analysis.Rmd` file loads results from this directory for plotting.
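If you want a quick look at these outputs in Python before running `analysis.Rmd`, a minimal sketch like the one below loads and summarizes the `result_sheets` .csv files with pandas. The glob pattern and the `acc` column name are assumptions for illustration; adjust them to match the sheets your runs actually produce.

```
# Minimal sketch (assumes pandas is installed and that result sheets contain
# an accuracy-like column named "acc"; adjust names to your actual outputs).
import glob
import pandas as pd

frames = []
for path in glob.glob("result_sheets/*.csv"):
    df = pd.read_csv(path)        # load one result sheet
    df["source_file"] = path      # remember which sheet each row came from
    frames.append(df)

if frames:
    results = pd.concat(frames, ignore_index=True)
    print(results.head())
    if "acc" in results.columns:
        # mean accuracy per result sheet
        print(results.groupby("source_file")["acc"].mean())
```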
96 | 97 | ### bibtex 98 | 99 | To cite this work, you can use 100 | 101 | ``` 102 | @misc{hase2024unreasonable, 103 | title={The Unreasonable Effectiveness of Easy Training Data for Hard Tasks}, 104 | author={Peter Hase and Mohit Bansal and Peter Clark and Sarah Wiegreffe}, 105 | year={2024}, 106 | eprint={2401.06751}, 107 | archivePrefix={arXiv}, 108 | primaryClass={cs.CL}, 109 | url={https://arxiv.org/pdf/2401.06751.pdf} 110 | } 111 | ``` 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | -------------------------------------------------------------------------------- /globals.py: -------------------------------------------------------------------------------- 1 | # burns_datasets = ["imdb", "amazon_polarity", "ag_news", "dbpedia_14", "copa", "rte", "boolq", "piqa", "qnli", "story_cloze"] 2 | mmlu_datasets = ['abstract_algebra', 'anatomy', 'astronomy', 'business_ethics', 'clinical_knowledge', 'college_biology', 'college_chemistry', 'college_computer_science', 'college_mathematics', 'college_medicine', 'college_physics', 'computer_security', 'conceptual_physics', 'econometrics', 'electrical_engineering', 'elementary_mathematics', 'formal_logic', 'global_facts', 'high_school_biology', 'high_school_chemistry', 'high_school_computer_science', 'high_school_european_history', 'high_school_geography', 'high_school_government_and_politics', 'high_school_macroeconomics', 'high_school_mathematics', 'high_school_microeconomics', 'high_school_physics', 'high_school_psychology', 'high_school_statistics', 'high_school_us_history', 'high_school_world_history', 'human_aging', 'human_sexuality', 'international_law', 'jurisprudence', 'logical_fallacies', 'machine_learning', 'management', 'marketing', 'medical_genetics', 'miscellaneous', 'moral_disputes', 'moral_scenarios', 'nutrition', 'philosophy', 'prehistory', 'professional_accounting', 'professional_law', 'professional_medicine', 'professional_psychology', 'public_relations', 'security_studies', 'sociology', 'us_foreign_policy', 'virology', 'world_religions'] 3 | arc_datasets = ['ai2_arc/ARC-Easy', 'ai2_arc/ARC-Challenge'] 4 | 5 | # mmlu globals 6 | mmlu_subject_levels = ['college_biology', 'college_chemistry', 'college_computer_science', 'college_mathematics', 'college_physics', 7 | 'high_school_biology', 'high_school_chemistry', 'high_school_computer_science', 'high_school_mathematics', 'high_school_physics'] 8 | mmlu_subjects = ['mmlu_biology', 'mmlu_chemistry', 'mmlu_computer_science', 'mmlu_mathematics', 'mmlu_physics'] 9 | mmlu_combined = ['mmlu_STEM-5'] 10 | third_grade_to_college = ['ai2_arc_all', 'mmlu_STEM-5'] 11 | 12 | # hardness vs. probing data globals 13 | known_hardness_data = ['ai2_arc', 'ai2_arc_all', 'strategy-qa', 'strategy-qa-dev', 'gsm8k_main', 'gsm8k', 'gsm8k_socratic', 'gsm8k_main_test', 'mmlu_subjects'] + mmlu_combined 14 | still_make_hardness_data = ['ai2_arc', 'gsm8k_main', 'gsm8k_socratic'] 15 | probing_data_only = ['strategy-qa', 'strategy-qa-dev', 'gsm8k_main_test'] # don't make a separate split for hardness model estimation 16 | # other eligible datasets 17 | eligible_datasets = mmlu_datasets + arc_datasets 18 | 19 | # define easy/medium/hard ranges/bounds for the data. 
default to 30/40/30 percentile chunks if set to None 20 | data_x_hardness_var_to_cutoffs = { 21 | 'ai2_arc': { 22 | 'human_bloom': (2,4), 23 | 'human_difficulty': (1,3), 24 | 'human_grade': (5,8), 25 | 'human_depth_of_knowledge': (1,3), 26 | }, 27 | 'ai2_arc_all': { 28 | 'human_bloom': (2,4), 29 | 'human_difficulty': (1,3), 30 | 'human_grade': (5,8), 31 | 'human_depth_of_knowledge': (1,3), 32 | }, 33 | 'mmlu_STEM-5': { 34 | 'human_hardness': (0,1), 35 | }, 36 | 'strategy-qa': { 37 | 'num_steps': (2,4), 38 | }, 39 | 'gsm8k_main': { 40 | 'num_steps': (4,7), 41 | }, 42 | } 43 | # mmlu extra stats to record 44 | mmlu_subject_stat_cols = [ 45 | 'math_prop_TRAIN', 46 | 'physics_prop_TRAIN', 47 | 'chem_prop_TRAIN', 48 | 'bio_prop_TRAIN', 49 | 'cs_prop_TRAIN', 50 | 'math_prop_TEST', 51 | 'physics_prop_TEST', 52 | 'chem_prop_TEST', 53 | 'bio_prop_TEST', 54 | 'cs_prop_TEST', 55 | ] 56 | # average hardness scores over these models 57 | hardness_models = [ 58 | "huggyllama/llama-7b", 59 | "tiiuae/falcon-7b", 60 | "mistralai/Mistral-7B-v0.1", 61 | "mosaicml/mpt-7b", 62 | ] 63 | llama_models = ['Llama-2-7b', 'Llama-2-13b', 'Llama-2-70b', 'Llama-2-7b-chat', 'Llama-2-13b-chat', 'Llama-2-70b-chat'] 64 | base_llama_models = ['Llama-2-70b', 'Llama-2-13b', 'Llama-2-7b'] 65 | llama_one_gpu_models = ['Llama-2-7b', 'Llama-2-13b'] 66 | one_gpu_models = [model for model in hardness_models + llama_models if not '70b' in model] 67 | four_gpu_models = [model for model in hardness_models + llama_models if '70b' in model] 68 | 69 | replicate_models = ['Llama-2-70b', 'Llama-2-70b-chat', 'mistralai/Mixtral-8x7B-v0.1', 'Qwen/Qwen-72B'] 70 | 71 | # don't use EleutherAI/ or facebook/ etc. prefixes below 72 | model_to_hidden_size = { 73 | 'gpt2-medium': 1024, 74 | 'gpt2-xl': 1600, 75 | 'gpt-j-6B': 4096, 76 | 't5-xl': 1024, # t5 not tested 77 | 't5-xxl': 1024, 78 | 'flan-t5-xl': 1024, 79 | 'flan-t5-xxl': 1024, 80 | 'llama-7b': 4096, 81 | 'llama-13b': 5120, 82 | 'llama-30b': 6656, # really 33b 83 | 'llama-65b': 8192, 84 | 'Llama-2-7b': 4096, 85 | 'Llama-2-13b': 5120, 86 | 'Llama-2-70b': 8192, 87 | 'Llama-2-7b-chat': 4096, 88 | 'Llama-2-13b-chat': 5120, 89 | 'Llama-2-70b-chat': 8192, 90 | 'falcon-7b': 4544, 91 | 'falcon-7b-instruct': 4544, 92 | 'falcon-40b': 8192, 93 | 'falcon-40b-instruct': 8192, 94 | 'persimmon-8b-base': 4096, # this is a 9.3b parameter model... 95 | 'mpt-7b': 4096, 96 | 'Mistral-7B-v0.1': 4096, 97 | 'opt-13b': 5120, 98 | 'Qwen-72B': 8192, 99 | 'Mixtral-8x7B-v0.1': 8192, # double check? 
100 | } 101 | label_dict = { 102 | "imdb": ["negative", "positive"], # This is for normal IMDB 103 | "amazon_polarity": ["negative", "positive"], 104 | "ag_news": ["politics", "sports", "business", "technology"], 105 | "dbpedia_14": ["company", "educational institution", "artist", "athlete", "office holder", "mean of transportation", "building", "natural place", "village", "animal", "plant", "album", "film", "written work"], 106 | "copa": ["choice 1", "choice 2"], 107 | "rte": ["yes", "no"], # whether entail 108 | "boolq": ["false", "true"], 109 | "qnli": ["yes", "no"], # represent whether entail 110 | "piqa": ["solution 1", "solution 2"], 111 | "story_cloze": ["choice 1", "choice 2"], 112 | } -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/easy-to-hard-generalization/ad8358e42904da169e8a3e3f70b64290d739b37a/models/__init__.py -------------------------------------------------------------------------------- /models/probes.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | import pandas as pd 5 | import os 6 | import time 7 | import gc 8 | 9 | import bitsandbytes as bnb 10 | from transformers import AutoConfig, BertModel 11 | from transformers import get_scheduler 12 | from torch.optim import SGD, AdamW, Adam, LBFGS 13 | from torch.utils.data import Dataset, DataLoader 14 | 15 | from peft import AutoPeftModelForCausalLM 16 | 17 | from utils import utils 18 | from utils import LM_utils 19 | from utils import modeling_utils 20 | from utils import data_utils 21 | import copy 22 | 23 | class Probe(nn.Module): 24 | ''' 25 | This is a class for handling both few-shot prompting and supervised probing of LLMs. 26 | - initialized with a dataset of text inputs, multiple-choice answers, and labels 27 | - initialized with a Prompt object for formatting the data 28 | ''' 29 | def __init__(self, args, probing_method, probe_loss, 30 | tokenizer, 31 | probing_config=None, 32 | normalize_representations=False, calibrate_probe=False, 33 | model=None, num_classes=None): 34 | ''' 35 | args: 36 | args is from argparse in main.py 37 | datasets is a nested dict of {dataname: dataset}, where dataset contains pd dfs with inputs, multiple-choice answers, and labels. 38 | prompt is Prompt object from prompt.py 39 | normalize_representations: z-normalize hidden states used with a parametric probe. see main.py args.normalize_representations 40 | num_fits: number of times to fit the classifier. 
used with CCS to select the run with lowest unsupervised loss 41 | ''' 42 | super().__init__() 43 | self.args = args 44 | self.tokenizer = tokenizer 45 | self.normalize_representations = normalize_representations 46 | self.calibrate_probe = calibrate_probe # kind of a misnomer, this adjusts model preds to be mostly uniform across classes based on a provided dataset 47 | self.model = model 48 | self.num_classes = num_classes # num classes when doing classification of question_end hidden state 49 | # default params, reset by .fit() but used in self.loss, which may be called for prompt selection purposes with ICL 50 | self.prior_reg, self.prior = 0, 0 51 | self.l2_reg = 0 if args.optimize_weights != 'ABCD_embeddings' else args.l2_reg 52 | if args.probing_token_state == 'question_end_token': 53 | self.l2_prior = self.get_ABCD_embeddings(tokenizer, 'lm_head') # used when doing a regression on question_end_token hidden states 54 | else: 55 | self.l2_prior = None 56 | # set probing strategy 57 | self.probing_method = probing_method 58 | # normalization variables, set from train data 59 | self.mean_for_norming = None 60 | self.std_for_norming = None 61 | # calibration param making predicted distribution uniform in aggregate 62 | self.probs_centroid = None 63 | # for unsupervised probing 64 | self.probs_mean = None # will be 'true' or 'false', determine whether pred = argmax(probs) or argmin(probs) 65 | # probe_loss is used for prompt selection and fitting probes, so it applies to both ICL and probing 66 | self.probe_loss = probe_loss 67 | # probing args 68 | self.probing_config = probing_config 69 | if self.probing_method == "learned": 70 | self.probe_model = probing_config['probe_model'] 71 | self.hidden_size = probing_config['hidden_size'] 72 | # extend hidden size by the number of layers we pull representations, and double the size if we pull both encoder and decoder reps 73 | self.hidden_size *= len(probing_config['features_layers']) 74 | if len(probing_config['features_enc_dec']) == 2: 75 | self.hidden_size *= 2 76 | print(f" probe feature dimensionality is {self.hidden_size}") 77 | if self.probe_model == 'linear' and args.probing_token_state == 'answer_end_token': 78 | probe_model = nn.Linear(in_features=self.hidden_size, out_features=1, bias=False) 79 | if self.probe_model == 'linear' and args.probing_token_state == 'question_end_token': 80 | probe_model = nn.Linear(in_features=self.hidden_size, out_features=self.num_classes, bias=True) 81 | if self.probe_model == 'MLP': 82 | MLP_hidden_size = probing_config['hidden_size'] # use orig model hidden size (as opposed to self.hidden_size) to control parameter growth here 83 | probe_model = MLP(MLP_hidden_size, in_features=self.hidden_size, dropout_prob=0, num_classes=1) 84 | if self.probe_model == 'transformer': 85 | transformer_config = AutoConfig.from_pretrained('bert-base-uncased', cache_dir=args.cache_dir) 86 | transformer_config.hidden_size = self.hidden_size 87 | transformer_config.num_hidden_layers = 1 88 | transformer_config.hidden_dropout_prob = 0 89 | MLP_hidden_size = probing_config['hidden_size'] # use orig model hidden size (as opposed to self.hidden_size) to control parameter growth here 90 | probe_model = TransformerProbe(transformer_config, MLP_hidden_size=MLP_hidden_size, in_features=self.hidden_size) 91 | proper_normalization = (self.probe_loss != 'CCS') # needed for exact CCS replication 92 | if args.probing_token_state == 'answer_end_token': 93 | self.probe = MultipleChoiceClassifier(probe_model, 94 | 
proper_normalization=proper_normalization) 95 | elif args.probing_token_state == 'question_end_token': 96 | self.probe = LinearClassifier(probe_model, 97 | num_classes=self.num_classes) 98 | if self.args.n_gpu > 0: 99 | self.probe = self.probe.cuda() 100 | 101 | def set_calibration_params(self, probs=None, dataloader=None, verbose=False): 102 | ''' 103 | Used to calibrate predictions to be uniformly distributed over label space 104 | ''' 105 | if verbose: 106 | print("Calibrating probabilities to be uniform over classes...") 107 | if dataloader is not None: 108 | all_probs = [] 109 | for batch in dataloader: 110 | with torch.no_grad(): 111 | probs = self.forward(batch) # compute probs here 112 | all_probs.append(probs.detach().cpu()) 113 | all_probs = torch.concatenate(all_probs) 114 | if probs is not None: 115 | all_probs = probs 116 | all_preds = torch.argmax(all_probs, dim=1) 117 | probs_centroid, _ = torch.median(all_probs, dim=0) 118 | new_probs = all_probs - probs_centroid 119 | all_new_preds = torch.argmax(new_probs, dim=1) 120 | old_pred_distr = {y: round(torch.mean((all_preds==y).float()).item(), 2) for y in set(all_preds.cpu().numpy())} 121 | new_pred_distr = {y: round(torch.mean((all_new_preds==y).float()).item(), 2) for y in set(all_new_preds.cpu().numpy())} 122 | if verbose: 123 | print("Old pred distr: ", old_pred_distr) 124 | print("New pred distr: ", new_pred_distr) 125 | self.probs_centroid = probs_centroid.cuda() 126 | 127 | def set_normalization_params(self, dataloader): 128 | states = [] 129 | for batch in dataloader: 130 | states.append(batch['precomputed_hidden_states']) 131 | states = torch.concatenate(states, dim=0) 132 | states = self.select_hidden_states(states) # shape: n_items x n_answers x self.hidden_size 133 | # if dataset has fixed label space, make per-label norming params 134 | if self.args.data_source == 'burns': 135 | self.mean_for_norming = torch.mean(states, dim=0, keepdim=True) 136 | self.std_for_norming = torch.std(states, dim=0, keepdim=True) 137 | # otherwise, for multiple-choice problems, share information across answer choices 138 | else: 139 | states = states.view(-1, self.hidden_size) # first collapse the n_items and answer choices dimensions 140 | self.mean_for_norming = torch.mean(states, dim=0, keepdim=True) 141 | self.std_for_norming = torch.std(states, dim=0, keepdim=True) 142 | 143 | def safe_fit(self, args, log, dataloader, optimizer_name, epochs=100, l2_reg=1, max_grad_norm=1, prior=None, verbose=False, patience=5): 144 | # sometimes fit gives nan result, so this refits the model if the final model has nan weightss 145 | done = False 146 | counter = 0 147 | while not done: 148 | loss = self.fit(args, log, dataloader, optimizer_name, epochs=epochs, l2_reg=l2_reg, max_grad_norm=max_grad_norm, prior=prior, verbose=verbose) 149 | if not utils.check_nan_weights(self.probe): 150 | done = True 151 | elif counter > patience: 152 | done = True 153 | print(f"WARNING: COULD NOT FIT MODEL IN LESS THAN {patience} ATTEMPTS") 154 | else: 155 | print(f"WARNING: MODEL FAILED TO FIT. 
RETRYING...") 156 | counter += 1 157 | max_grad_norm /= 4 # cut grad clipping size 158 | l2_reg *= 4 # large increase to regularization 159 | return loss 160 | 161 | def repeated_fit(self, args, log, dataloader, optimizer_name, num_fits, prior=None, epochs=100, l2_reg=1, max_grad_norm=1, 162 | verbose=False, safe_fit=True): 163 | # fits a probe num_fits times to dataset, and selects model with best loss 164 | # for use with CCS 165 | best_loss = np.inf 166 | losses = [] 167 | for _ in range(num_fits): 168 | if safe_fit: 169 | loss = self.safe_fit(args, log, dataloader, optimizer_name, epochs=epochs, l2_reg=l2_reg, max_grad_norm=max_grad_norm, prior=prior, verbose=verbose) 170 | else: 171 | loss = self.fit(args, log, dataloader, optimizer_name, epochs=epochs, l2_reg=l2_reg, max_grad_norm=max_grad_norm, prior=prior, verbose=verbose) 172 | if loss < best_loss: 173 | self.best_probe = copy.deepcopy(self.probe) 174 | prob_meaning = self.probs_mean 175 | losses.append(loss) 176 | best_loss = np.argmin(losses) 177 | losses = sorted(losses) 178 | self.probe = self.best_probe 179 | self.probs_mean = prob_meaning 180 | del self.best_probe 181 | # print(f"Selecting probe with loss {best_loss} from {num_fits} fits (losses: {losses})") 182 | 183 | def fit(self, args, log, dataloader, optimizer_name, epochs=100, 184 | l2_reg=1, max_grad_norm=1, prior=None, 185 | verbose=False): 186 | ''' 187 | args 188 | dataloader: MCTextDataset dataloader 189 | - ideally includes precomputed hidden states, but we recompute as needed here 190 | epochs: number of epochs to run over data 191 | ''' 192 | assert self.probing_method == 'learned', "do not fit a probe if using decoding probing rather than learned probe" 193 | assert epochs > 0, "Epochs<=0 passed to probe.fit" 194 | num_answers = data_utils.get_max_num_answers(dataloader.dataset.dataframe) 195 | if verbose: 196 | print(f"Fitting probe to data...", end='\r') 197 | self.l2_reg = l2_reg # used in self.loss 198 | self.prior = prior # used in self.loss when self.probe_loss == 'unsupervised' 199 | self.prior_reg = 1 200 | self.probs_mean = None # indicates whether p(answer) is prob answer is true or prob answer is false 201 | forward_time = 0 202 | backward_time = 0 203 | n_steps = 0 204 | loss_history = [] 205 | acc_history = [] 206 | mem_history = [] 207 | # re-initialize probe and .train() -- temporarily swap torch seed from args.seed to the provided seed (which should be inherited from boot_idx) 208 | # THIS IS WHERE WE SET THE PRIOR FOR REGRESSION WEIGHTS TO ABCD EMBEDDINGS WHEN APPLICABLE 209 | self.probe.apply(self.weights_init) 210 | self.probe.train() 211 | # if doing a random probe, do one epoch just to get an accuracy, but don't step 212 | # random probes will get "flipped" to have accuracy >= .5, for fair comparison with CCS 213 | if self.probe_loss == 'random': 214 | epochs = 1 215 | 216 | # compute hidden_states if they are not precomputed 217 | precomputed_hidden_states = dataloader.dataset.precomputed_hidden_states 218 | if precomputed_hidden_states is None and self.probing_method == 'learned': 219 | mc_dataset = dataloader.dataset 220 | print(f"\nPrecomputing hidden representations for dataset of size {len(mc_dataset)}...") 221 | bs = 32 // num_answers 222 | unshuffled_dataloader = DataLoader(mc_dataset, shuffle=False, collate_fn=mc_dataset.collate_fn, pin_memory=False, num_workers=0, batch_size=bs) 223 | eval_output = modeling_utils.evaluate_model(self.args, log, self.model, unshuffled_dataloader, tokenizer=None, num_answers=num_answers, 224 | 
return_hidden_states=True) 225 | hidden_states = eval_output['hidden_states'] 226 | # assign hidden states to main dataloader 227 | dataloader.precomputed_hidden_states = hidden_states 228 | 229 | # define optimizer and scheduler 230 | params = self.probe.parameters() 231 | num_training_steps = epochs * len(dataloader) 232 | if self.probing_config['probe_model'] == 'linear': 233 | lr = 5e-2 234 | else: 235 | raise NotImplementedError("Check LRs for more expressive probes than linear") 236 | if optimizer_name == 'sgd': 237 | optimizer = SGD(params, lr=lr) 238 | if optimizer_name == 'adamw': 239 | optimizer = AdamW(params, lr=lr) 240 | if optimizer_name == 'adam': 241 | optimizer = Adam(params, lr=lr) 242 | if optimizer_name == 'LBFGS': 243 | optimizer = LBFGS(params, max_iter=1, history_size=10) 244 | if optimizer_name != 'LBFGS': 245 | lr_decay = self.probing_config['lr_decay'] 246 | if lr_decay == "constant": 247 | scheduler = get_scheduler("constant", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps) 248 | elif lr_decay in ['linear', '10-percent']: 249 | percent_of_orig_value = .1 if lr_decay == '10-percent' else 0 250 | multiplier = 1 / (1-percent_of_orig_value) 251 | scheduler = get_scheduler("linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=multiplier*num_training_steps) 252 | 253 | # prepare for normalizing representations as requested (see self.forward) 254 | if self.normalize_representations: 255 | self.set_normalization_params(dataloader) 256 | 257 | # pre-emptively unpack dataloader if it contains only one batch 258 | if len(dataloader) == 1: 259 | dataloader = [batch for batch in iter(dataloader)] 260 | start_fit_time = time.time() 261 | for e in range(epochs): 262 | losses = [] 263 | all_preds = [] 264 | all_labels = [] 265 | for batch in dataloader: 266 | labels = batch['label_idx'] 267 | if self.args.n_gpu > 0: 268 | labels = labels.cuda() 269 | with torch.enable_grad(): 270 | start = time.time() 271 | probs = self.forward(batch) # compute probs here 272 | forward_time += time.time() - start 273 | # compute loss and step. 
split by optimizer type 274 | if optimizer_name == 'LBFGS': 275 | def closure(): 276 | optimizer.zero_grad() 277 | loss = self.loss(labels, probs) 278 | loss.backward() 279 | torch.nn.utils.clip_grad_norm_(self.probe.parameters(), max_grad_norm) 280 | return loss 281 | start = time.time() 282 | if self.probe_loss != 'random': 283 | optimizer.step(closure) 284 | loss = self.loss(labels, probs).detach() 285 | backward_time += time.time() - start 286 | else: 287 | loss = self.loss(labels, probs) 288 | if self.probe_loss != 'random': 289 | start = time.time() 290 | optimizer.zero_grad() 291 | loss.backward() 292 | # loss.backward(retain_graph=(self.args.probing_token_state=='question_end_token')) 293 | torch.nn.utils.clip_grad_norm_(self.probe.parameters(), max_grad_norm) 294 | backward_time += time.time() - start 295 | optimizer.step() 296 | scheduler.step() # only step on non-lbfgs 297 | n_steps += 1 298 | preds = torch.argmax(probs, dim=-1) 299 | all_preds.extend(preds.tolist()) 300 | all_labels.extend(labels.tolist()) 301 | losses.append(loss.item()) 302 | del loss, probs 303 | # compute acc 304 | acc = np.mean(np.array(all_labels)==np.array(all_preds)) 305 | loss_history.append(round(np.mean(losses),2)) 306 | acc_history.append(round(acc,2)) 307 | gpu_mem = utils.get_gpu_utilization() if "cuda" in str(args.device) else None 308 | mem_history.append(gpu_mem) 309 | # if doing unsupervised, set whether probs for an answer choice represent prob true or prob false 310 | # this is some amount of supervision...Burns suggests you could do this step in an unsupervised way 311 | if self.probe_loss != 'supervised': 312 | self.probs_mean = 'false' if acc < .5 else 'true' 313 | acc_history = np.array(acc_history) 314 | acc_history = 1-acc_history if self.probs_mean == 'false' else acc 315 | if verbose: 316 | print(f"Fitting probe to data...took {(time.time() - start_fit_time):.2f} seconds", end='\n') 317 | # print(f"Loss history: ", loss_history) 318 | # print(f"Acc history: ", acc_history) 319 | if len(loss_history) == 0: 320 | loss_history.append(-1) 321 | self.probe.eval() 322 | return loss_history[-1] 323 | 324 | def finetune(self, args, log, train_dataloader, tokenizer, epochs=100, grad_accumulation_factor=1, 325 | dev_dataloader=None, eval_every_n_epochs=None, 326 | model_selection='NA', 327 | verbose=False): 328 | ''' 329 | For full model finetuning, or parameter-efficient finetuning 330 | args 331 | dataloader: MCTextDataset dataloader 332 | epochs: number of epochs to run over data 333 | model_selection: pick best model epoch based on this statistic in log_stats 334 | break_after_e_epochs: break early after a selected number of epochs 335 | ''' 336 | assert self.probing_method == 'finetuned' 337 | num_batches = len(train_dataloader) 338 | num_items = len(train_dataloader.dataset) 339 | num_training_steps = epochs * int(np.ceil(num_batches / grad_accumulation_factor)) 340 | best_acc = -1 341 | # set tmp save/load path 342 | if model_selection != 'NA': 343 | tmp_save_load_path = os.path.join(args.model_dir, 'tmp') 344 | if verbose: 345 | effective_num_answers = train_dataloader.dataset.get_effective_num_answers() 346 | print(f"Fitting probe to data...", end='\r') 347 | print(f"Epochs: {epochs} | Num items: {num_items} | Num answers: {effective_num_answers} | Num points {num_items*effective_num_answers}") 348 | print(f"Batch size: {args.train_batch_size} | Batches per epoch: {num_batches} | Total opt steps: {num_training_steps}") 349 | n_items_per_batch = train_dataloader.batch_size 350 | 
n_batches_per_step = grad_accumulation_factor 351 | print("NUM ITEMS PER GRADIENT STEP:", n_items_per_batch * n_batches_per_step) 352 | self.l2_reg = 0 # used in self.loss 353 | compute_mc_probs = args.finetuning_objective == 'MC' 354 | train_stats = { 355 | 'n_batches': 1, 356 | 'forward_time_sum' : 0, 357 | 'backward_time_sum' : 0, 358 | 'acc': -1, 359 | 'loss': -1, 360 | } 361 | total_batches = len(train_dataloader) 362 | self.model.train() 363 | # define optimizer and schedules 364 | if self.args.optimize_weights in ['all', 'LORA']: 365 | decay_parameters = [name for name, p in self.model.named_parameters() if 'layernorm' not in name.lower() and "bias" not in name.lower()] 366 | params = [ 367 | { 368 | "params": [p for n, p in self.model.named_parameters() if n in decay_parameters], 369 | "weight_decay": self.args.weight_decay, 370 | }, 371 | { 372 | "params": [p for n, p in self.model.named_parameters() if n not in decay_parameters], 373 | "weight_decay": 0.0, 374 | }, 375 | ] 376 | elif self.args.optimize_weights == 'embeddings': 377 | embed_param_names = [n for n,p in self.model.named_parameters() if 'embed' in n or 'lm_head' in n] 378 | print(" Optimizing these params: ", embed_param_names) 379 | params = [p for n,p in self.model.named_parameters() if n in embed_param_names] 380 | elif self.args.optimize_weights == 'ABCD_embeddings': 381 | embed_param_names = [n for n,p in self.model.named_parameters() if 'embed' in n or 'lm_head' in n] 382 | assert len(embed_param_names) > 0, f"Couldn't find the lm_head params for {self.args.model}" 383 | print(" Optimizing these params: ", embed_param_names) 384 | params = [p for n,p in self.model.named_parameters() if n in embed_param_names] 385 | ABCD_token_ids = [tokenizer.encode(x, add_special_tokens=False)[0] for x in ['A', 'B', 'C', 'D']] 386 | ABCD_rows = np.array(ABCD_token_ids) 387 | zero_grad_rows = np.setdiff1d(np.arange(len(self.tokenizer)), ABCD_rows) 388 | zero_grad_rows = torch.tensor(zero_grad_rows).cuda() 389 | lr = args.probing_lr 390 | if args.quantization == '8bit': 391 | optimizer = bnb.optim.Adam8bit(params, lr=lr) 392 | else: 393 | optimizer = AdamW(params, lr=lr) 394 | # get scheduler 395 | lr_decay = self.probing_config['lr_decay'] 396 | if lr_decay == "constant": 397 | scheduler = get_scheduler("constant", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps) 398 | elif lr_decay in ['linear', '10-percent']: 399 | percent_of_orig_value = .1 if lr_decay == '10-percent' else 0 400 | multiplier = 1 / (1-percent_of_orig_value) 401 | scheduler = get_scheduler("linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=multiplier*num_training_steps) 402 | 403 | # start training 404 | start_fit_time = time.time() 405 | # pre-emptively unpack dataloader if it contains only one batch 406 | if len(train_dataloader) == 1: 407 | train_dataloader = [batch for batch in iter(train_dataloader)] 408 | for e in range(1, epochs+1): 409 | epoch_stats = { 410 | 'acc_sum': 0, 411 | 'loss_sum': 0, 412 | 'probe_loss_sum': 0, # used for model selection, may be unsupervised objective 413 | 'n_data_points': 0, 414 | } 415 | for batch_num, batch in enumerate(train_dataloader): 416 | running_time = (time.time()-start_fit_time) 417 | est_run_time = (running_time/train_stats['n_batches']*total_batches*epochs) 418 | forward_time = train_stats['forward_time_sum'] / train_stats['n_batches'] 419 | if verbose: 420 | gpu_mem = utils.get_gpu_utilization() if "cuda" in str(args.device) else None 421 | 
log.print_training_prog(train_stats, e, epochs, batch_num, len(train_dataloader), running_time, est_run_time, forward_time, gpu_mem=gpu_mem) 422 | labels = batch['label_idx'].cuda() 423 | with torch.enable_grad(): 424 | start = time.time() 425 | probs = self.forward(batch, compute_mc_probs=compute_mc_probs) # compute probs here 426 | probs = probs.cuda() # syncs probs with labels in multi-gpu case 427 | train_stats['forward_time_sum'] += time.time() - start 428 | # compute loss, scale by batch size, and step. batch size scaling keeps grad norms consistent when last batch size is smaller than others 429 | bs_as_frac_of_max_bs = batch['input_ids'].size(0) / args.train_batch_size 430 | labels = labels.cpu() 431 | probs = probs.cpu() 432 | loss = self.loss(labels, probs) / grad_accumulation_factor * bs_as_frac_of_max_bs 433 | start = time.time() 434 | loss.backward() 435 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.) 436 | train_stats['backward_time_sum'] += time.time() - start 437 | if train_stats['n_batches'] % grad_accumulation_factor == 0: 438 | # zero non A/B/C/D token embedding rows if necessary 439 | if args.optimize_weights == 'ABCD_embeddings': 440 | self.model.lm_head.weight.grad[zero_grad_rows,:] = 0 441 | optimizer.step() 442 | optimizer.zero_grad() 443 | scheduler.step() 444 | # end of epoch stats 445 | epoch_stats['loss_sum'] += loss.item() 446 | epoch_stats['n_data_points'] += len(batch['items']) 447 | if compute_mc_probs: 448 | preds = torch.argmax(probs, dim=-1) 449 | binary_correct = preds==labels 450 | n_correct = torch.sum(binary_correct).item() 451 | epoch_stats['acc_sum'] += n_correct 452 | train_stats['acc'] = epoch_stats['acc_sum'] / epoch_stats['n_data_points'] 453 | # update eval stats 454 | train_stats['loss'] = epoch_stats['loss_sum'] / (batch_num+1) 455 | train_stats['probe_loss'] = epoch_stats['probe_loss_sum'] / (batch_num+1) 456 | train_stats['n_batches'] += 1 457 | # print examples 458 | if verbose: 459 | if (batch_num == 0 and e == 1 and args.num_print > 0): 460 | print_idx = list(range(min(args.num_print, len(batch['items'])))) 461 | else: 462 | print_idx = [] 463 | if len(print_idx) > 0: 464 | print("\n" + "-"*20 + f"\nPrinting examples:") 465 | if e == 1: 466 | print(f" Input 0 : {tokenizer.decode(batch['input_ids'][0])}") 467 | for i in print_idx: 468 | answer_choices = ['A', 'B', 'C', 'D'] if args.use_letter_labels else batch['answers_list'][i] 469 | prompt = batch['prompts'][i] 470 | print(f" point {i}") 471 | print(f" Prompt : \n{prompt}") 472 | if compute_mc_probs: 473 | item_probs = [np.round(x.item(), 4) for x in probs[i].cpu()] 474 | print(f" Preds : {[x for x in zip(answer_choices, item_probs)]}") 475 | pred = answer_choices[preds[i]] 476 | print(f" Pred : {pred}") 477 | print(f" Label : {batch['label_strs'][i]}") 478 | print(f" Correct : {binary_correct[i].item()}") 479 | if i != print_idx[-1]: 480 | print() 481 | print("-"*20 + '\n') 482 | del loss, probs, batch 483 | # eval model on dev data 484 | if eval_every_n_epochs > 0 and e % eval_every_n_epochs == 0: 485 | print(" Evaluating model...") 486 | dev_stats = modeling_utils.evaluate_model(args=args, 487 | log=log, 488 | model=self, 489 | dataloader=dev_dataloader, 490 | tokenizer=tokenizer, 491 | verbose=verbose) 492 | self.model.train() 493 | log_stats = { 494 | 'LR': scheduler.get_last_lr()[0], # two param groups...just take one, 495 | 'epoch': e, 496 | 'train_loss': train_stats['loss'], 497 | 'train_acc': train_stats['acc'] if compute_mc_probs else -1, 498 | 'dev_loss': 
dev_stats['loss'], 499 | 'dev_acc': dev_stats['acc'], 500 | } 501 | log.print_epoch_scores(epoch=e, scores=log_stats) 502 | log.add_to_log(log_stats) 503 | log.save_plots(n_train=num_items) 504 | # save best model 505 | if model_selection != 'NA': 506 | if model_selection == 'train_acc': 507 | select_acc = train_stats['acc'] 508 | elif model_selection == 'dev_acc': 509 | assert eval_every_n_epochs >= 1 510 | dev_stats_just_calculated = (e % eval_every_n_epochs == 0) 511 | select_acc = dev_stats['acc'] if dev_stats_just_calculated else -1 512 | if select_acc > best_acc: 513 | best_acc = select_acc 514 | print(f"Saving new best probe at: ", tmp_save_load_path) 515 | save_start = time.time() 516 | self.save_probe(tmp_save_load_path) 517 | print("Saving probe took: ", utils.format_time(time.time()-save_start)) 518 | if verbose: 519 | print(f"Fitting model to data...took {(time.time() - start_fit_time):.2f} seconds", end='\n') 520 | if args.dev_eval_every_epochs > 0: 521 | log.reset_log() 522 | self.model.eval() 523 | # load best model 524 | if model_selection != 'NA': 525 | self.load_probe(tmp_save_load_path) 526 | 527 | def loss(self, Y, probs): 528 | # assumes probs of shape n x m, and Y of shape n containing label ids 529 | if self.probe_loss == 'supervised': 530 | label_probs = torch.gather(probs, 1, Y.view(-1, 1)) 531 | nll = -torch.log(label_probs).mean() 532 | l2_norm = 0 533 | if self.probing_method == 'learned': 534 | if self.l2_prior is None: 535 | for param in self.probe.parameters(): 536 | l2_norm += torch.linalg.vector_norm(param) 537 | else: 538 | assert self.probing_config['probe_model'] == 'linear' 539 | # l2_norm += torch.linalg.norm(self.l2_prior - self.probe.probe_model.weight) 540 | l2_norm += self.probe.probe_model.weight.norm() 541 | print(l2_norm) 542 | return nll + self.l2_reg * l2_norm 543 | elif self.probe_loss in ['LM_loss']: 544 | nll = -probs.mean() # assuming that finetuning_objective=seq2seq, the probs are already loglikes here 545 | return nll 546 | elif self.probe_loss in ['CCS', 'CCS_ours', 'random']: # 'random' condition never actually steps optimizer 547 | min_probs, _ = torch.min(probs, dim=1) 548 | informative_loss = (min_probs**2).mean(0) 549 | consistent_loss = ((1 - probs.sum(1))**2).mean() # always 0 when MC classifier uses proper_normalization=True 550 | return informative_loss + consistent_loss 551 | elif self.probe_loss == 'unsupervised': 552 | max_probs, _ = torch.max(probs, dim=1) 553 | confidence_loss = -torch.log(max_probs).mean() 554 | prior_loss = (max_probs.mean() - self.prior)**2 # this is a calibration loss 555 | l2_norm = 0 556 | if self.probing_method == 'learned': 557 | for param in self.probe.parameters(): 558 | l2_norm += torch.linalg.vector_norm(param) 559 | return confidence_loss + self.prior_reg * prior_loss + self.l2_reg * l2_norm 560 | 561 | def get_ABCD_embeddings(self, tokenizer, embeds_name='lm_head'): 562 | assert self.model is not None 563 | ABCD_token_ids = np.array([tokenizer.encode(x, add_special_tokens=False)[0] for x in ['A', 'B', 'C', 'D']]) 564 | params = [p for n,p in self.model.named_parameters() if embeds_name in n] 565 | assert len(params) == 1, f"Looking for the {embeds_name} params for {self.args.model} gave != 1 matching value" 566 | embeds = params[0] 567 | return embeds[ABCD_token_ids] 568 | 569 | def weights_init(self, m): 570 | if isinstance(m, nn.Linear): 571 | if self.args.probing_token_state == 'question_end_token': 572 | init_weights = self.get_ABCD_embeddings(self.tokenizer, 'lm_head') 573 | 
init_weights = init_weights.to(torch.float32) 574 | m.weight = torch.nn.Parameter(init_weights) 575 | else: 576 | m.reset_parameters() 577 | else: 578 | if not self.probing_config['probe_model'] in ['linear', 'MLP']: 579 | print("not re-initializing: ", m) 580 | import pdb; pdb.set_trace() 581 | 582 | def forward(self, batch, compute_mc_probs=True): 583 | ''' 584 | branch function output based on self.probing_method 585 | ''' 586 | if self.probing_method in ['decoding', 'finetuned']: 587 | assert self.model is not None, "if decoding to score answers, model must be provided to Probe at init" 588 | main_kwargs = { 589 | 'input_ids': batch['input_ids'], 590 | 'attention_mask': batch['attention_mask'], 591 | 'labels': batch['input_ids'], 592 | 'targets_mask': batch['targets_mask'], 593 | 'answer_choices': batch['answer_choices'], 594 | } 595 | utils.move_kwargs_to_gpu(main_kwargs) 596 | if compute_mc_probs: 597 | num_answers_list = batch['num_answers_list'] 598 | assert not all(x==1 for x in num_answers_list), "Trying to compute mc probs but num_answers are all 1" 599 | else: 600 | num_answers_list = None 601 | probs = LM_utils.compute_probs_from_batch(self.model, 602 | main_kwargs, 603 | return_value=self.args.answer_scoring, 604 | num_answers_list=num_answers_list) 605 | if self.probing_method == 'learned': 606 | # first get hidden states. may compute LLM forward pass as needed 607 | if 'precomputed_hidden_states' in batch: 608 | hidden_states = batch['precomputed_hidden_states'] # shape: n_items x n_answers x enc_dec x num_layers x hidden_size 609 | else: 610 | assert self.model is not None, "if computing hidden states for data on the fly, model must be provided to Probe at init" 611 | main_kwargs = { 612 | 'input_ids': batch['input_ids'], 613 | 'attention_mask': batch['attention_mask'], 614 | 'labels': batch['input_ids'], 615 | 'targets_mask': batch['targets_mask'], 616 | } 617 | utils.move_kwargs_to_gpu(main_kwargs) 618 | _, hidden_states_dict = LM_utils.compute_probs_from_batch(self.model, main_kwargs, return_value='probs', return_hidden_states=True) 619 | hidden_states = LM_utils.get_last_token_hidden_states(hidden_states_dict, 620 | max_num_answers = max(batch['num_answers_list']), 621 | num_answers_list=batch['num_answers_list']) 622 | hidden_states = torch.tensor(hidden_states) 623 | # select hidden states based on probing feature space config 624 | hidden_states = self.select_hidden_states(hidden_states) 625 | # normalize using params obtained during .fit() 626 | if self.normalize_representations: 627 | assert self.mean_for_norming is not None, "need to call probe.set_normalization_params before applying this probe" 628 | hidden_states = (hidden_states - self.mean_for_norming) / self.std_for_norming 629 | # now send to gpu and do forward pass 630 | if self.args.n_gpu > 0: 631 | hidden_states = hidden_states.cuda() 632 | probs = self.probe(hidden_states) 633 | # flip and renormalize probs if prob = prob(false) rather than prob(true) 634 | if self.probs_mean == 'false': 635 | flip_probs = 1 - probs 636 | probs = flip_probs / torch.sum(flip_probs, dim=-1, keepdim=True) 637 | if self.calibrate_probe and self.probs_centroid is not None: 638 | probs = probs - self.probs_centroid 639 | min_probs, _ = torch.min(probs, dim=-1, keepdim=True) 640 | # this step recalibrates the predicted label distribution to be near uniform (see self.set_calibration_params) 641 | # but the probabilities themselves have to be artificially renormed+smoothed, which we do in a very arbitrary way 642 | probs = 
probs - min_probs # first make non-negative 643 | probs = probs + .01 # artificial choice for smoothing 644 | probs = probs / torch.sum(probs, dim=1, keepdim=True) # renormalize 645 | return probs 646 | 647 | def select_hidden_states(self, hidden_states): 648 | # select hidden states before moving to gpu 649 | enc_dec_idx = [] 650 | for component in self.probing_config['features_enc_dec']: 651 | if component == 'decoder': 652 | enc_dec_idx.append(0) 653 | if component == 'encoder': 654 | enc_dec_idx.append(1) 655 | enc_dec_idx = torch.tensor(enc_dec_idx) 656 | layer_idx = torch.tensor([int(x) for x in self.probing_config['features_layers']]) 657 | hidden_states = torch.index_select(hidden_states, 2, enc_dec_idx) # this indexing keeps indexed dimension 658 | hidden_states = torch.index_select(hidden_states, 3, layer_idx) # this indexing keeps indexed dimension 659 | # reshape to n_items x n_answer x probe_hidden_size 660 | num_items = hidden_states.size(0) 661 | num_answers = hidden_states.size(1) 662 | hidden_states = hidden_states.view(num_items, num_answers, self.hidden_size) 663 | return hidden_states 664 | 665 | def save_probe(self, save_path): 666 | if self.probing_method == 'learned': 667 | state_dict = self.probe.state_dict() 668 | torch.save(state_dict, save_path) 669 | elif self.probing_method == 'finetuned': 670 | self.model.save_pretrained(save_path) 671 | 672 | def load_probe(self, load_path): 673 | ''' 674 | This is for loading lightweight probing classifiers or LORA weights 675 | ''' 676 | assert os.path.exists(load_path), f"Trying to load state dict from {load_path} but does not exist" 677 | if self.probing_method == 'learned': 678 | state_dict = torch.load(load_path) 679 | self.probe.load_state_dict(state_dict, strict=True) 680 | self.probe.eval() 681 | elif self.probing_method == 'finetuned': 682 | # delete the existing model before loading new one 683 | # if self.args.optimize_weights != 'LORA': 684 | for x in gc.get_referrers(self.model): 685 | del x 686 | del self.model 687 | self.model = utils.load_model(self.args, load_path) 688 | self.model.eval() 689 | 690 | class MultipleChoiceClassifier(nn.Module): 691 | ''' 692 | Expects data of shape n x m x d for n samples, m answer choices, and d dimensional features 693 | Returns probs in forward pass, of shape n x m 694 | ''' 695 | def __init__(self, probe_model, proper_normalization=True): 696 | super().__init__() 697 | self.probe_model = probe_model 698 | self.proper_normalization = proper_normalization 699 | 700 | def forward(self, X): 701 | # assume X of shape n x m x d -- n items with m answer choices, d representation size 702 | scores = self.probe_model(X) 703 | scores = scores.view(X.size(0), X.size(1)) # drop last hidden dimension, which is now 1 704 | if self.proper_normalization: 705 | probs = torch.softmax(scores, dim=-1) 706 | else: 707 | probs = torch.sigmoid(scores) # used to exactly replicate CCS 708 | return probs 709 | 710 | class LinearClassifier(nn.Module): 711 | ''' 712 | Expects data of shape n x d for n samples with d dimensional features 713 | Returns probs in forward pass, of shape n x num_classes 714 | ''' 715 | def __init__(self, probe_model, num_classes): 716 | super().__init__() 717 | self.probe_model = probe_model 718 | self.num_classes = num_classes 719 | 720 | def forward(self, X): 721 | # assume X of shape n x d -- n items, d representation size 722 | scores = self.probe_model(X) 723 | scores = scores.view(X.size(0), self.num_classes) # drop last hidden dimension, which is now 1 724 | probs = 
torch.softmax(scores, dim=-1) 725 | return probs 726 | 727 | 728 | class MLP(nn.Module): 729 | def __init__(self, hidden_size, in_features=None, dropout_prob=.1, num_classes=2): 730 | super().__init__() 731 | if in_features is None: 732 | in_features = hidden_size 733 | self.classifier = nn.Sequential( 734 | nn.Linear(in_features, hidden_size), 735 | nn.Tanh(), 736 | nn.Dropout(p=dropout_prob), 737 | nn.Linear(hidden_size, num_classes), 738 | ) 739 | def forward(self, hidden_states, **kwargs): 740 | return self.classifier(hidden_states) 741 | 742 | class IndexingModule(nn.Module): 743 | def __init__(self): 744 | super().__init__() 745 | def forward(self, hidden_states, **kwargs): 746 | return hidden_states[:,-1,...] 747 | 748 | class TransformerProbe(nn.Module): 749 | def __init__(self, config, num_answers): 750 | super().__init__() 751 | transformer = BertModel(config=config) # random weights 752 | self.indexing = IndexingModule() 753 | self.transformer = transformer 754 | self.classifier = MLP(config.hidden_size, dropout_prob=config.hidden_dropout_prob, num_classes=1) # pass dropout by keyword; MLP's second positional arg is in_features 755 | def forward(self, hidden_states, **kwargs): 756 | outputs = self.transformer(inputs_embeds=hidden_states, attention_mask=kwargs['attention_mask']) 757 | last_index_rep = self.indexing(outputs.hidden_states[-1]) 758 | scores = self.classifier(last_index_rep) 759 | return scores -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.24.2 2 | torch==2.0.0 3 | transformers==4.36.2 4 | accelerate==0.23.0 5 | pandas==1.5.3 6 | scipy 7 | jsonlines 8 | deepspeed==0.8.3 9 | nvidia-ml-py3 10 | matplotlib==3.7.1 11 | datasets==2.12.0 12 | pynvml==11.5 13 | seaborn==0.12.2 14 | einops==0.6.1 15 | peft==0.5.0 16 | bitsandbytes==0.41.0 17 | sentencepiece==0.1.99 18 | tiktoken==0.5.2 19 | transformers_stream_generator==0.0.4 20 | # promptsource==0.2.3 -------------------------------------------------------------------------------- /utils/LM_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import CrossEntropyLoss 3 | import numpy as np 4 | import transformers 5 | import sys, os 6 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '.'))) 7 | import utils 8 | import metrics 9 | from itertools import chain 10 | from copy import deepcopy 11 | import re 12 | 13 | def str_clean(data): 14 | if data is not None: 15 | return data.strip().lower() 16 | else: 17 | return None 18 | 19 | def renormalize_mc_pred_probs(pred_probs, use_softmax=False): 20 | # assumes pred_probs of shape num_items x num_answers 21 | # make sure no elements are negative for sum calculation (e.g., need to shift log_probs) 22 | if use_softmax: 23 | pred_probs = torch.softmax(pred_probs, dim=-1) 24 | else: 25 | sums = pred_probs.sum(1, keepdim=True) 26 | pred_probs = pred_probs / sums 27 | return pred_probs 28 | 29 | def compute_mc_loss(pred_probs, answer_idx): 30 | # assumes pred_probs of shape num_items x num_answers 31 | # assumes answer_idx of shape [num_items], containing idx of answer in range(0,m) assuming m answer choices 32 | label_probs = torch.gather(pred_probs, 1, answer_idx.view(-1, 1)).squeeze(-1) 33 | nll = -torch.log(label_probs).mean() 34 | return nll, label_probs 35 | 36 | def compute_probs_from_batch(model, batch, return_value='log_probs', pad_token_id=None, 37 | return_hidden_states=False, num_answers_list=None): 38 | ''' 39 | 
Compute label probabilities for decoder-only model, where labels are shifted by one from input ids 40 | Always returns one value per sequence 41 | - the reason that we get and write hidden states to file through this function is to do a sanity check that zero-shot accuracy is similar to later experiments 42 | ''' 43 | assert return_value in ['probs', 'log_probs', 'log_probs_token_normed', 'log_probs_char_normed'] 44 | model_batch = { 45 | 'input_ids' : batch['input_ids'], 46 | 'attention_mask' : batch['attention_mask'] 47 | } 48 | target_tokens = batch['labels'] 49 | if 'targets_mask' in batch and batch['targets_mask'] is not None: 50 | target_mask = batch['targets_mask'] 51 | else: 52 | target_mask = target_tokens != pad_token_id 53 | outputs = model(**model_batch, output_hidden_states=return_hidden_states) 54 | logits = outputs.logits 55 | loss_fct = CrossEntropyLoss(reduction='none') 56 | shift_logits = logits[..., :-1, :] 57 | shift_labels = target_tokens[..., 1:] 58 | shift_mask = target_mask[...,1:] 59 | nll = loss_fct(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1)) 60 | nll = nll.reshape(logits.size(0), -1) # batch size x num tokens 61 | if return_value == 'log_probs': 62 | nll = (shift_mask * nll).sum(-1) # sum over token dimension 63 | probs = -nll # one value per sequence 64 | elif return_value == 'probs': 65 | nll = (shift_mask * nll).sum(-1) # sum over token dimension 66 | probs = torch.exp(-nll) # one probability per input sequence 67 | elif return_value == 'log_probs_token_normed': 68 | nll = (shift_mask * nll).sum(-1) # sum over token dimension 69 | seq_lens = shift_mask.sum(-1) # for shifted targets 70 | probs = -nll / seq_lens 71 | elif return_value == 'log_probs_char_normed': 72 | nll = (shift_mask * nll).sum(-1) # sum over token dimension 73 | seq_num_chars = [len(answer) for answer in batch['answer_choices']] 74 | seq_num_chars = torch.tensor(seq_num_chars).to(nll.device) 75 | probs = -nll / seq_num_chars 76 | # reshape probs to be num_items x num_answers 77 | if num_answers_list is not None: 78 | assert not all(x==1 for x in num_answers_list), "Only single choice items passed to compute_probs as num_answers_list" 79 | num_items = len(num_answers_list) 80 | all_same_num_answers = all([num_answers_list[i] == num_answers_list[0] for i in range(num_items)]) 81 | if all_same_num_answers: 82 | num_answers = num_answers_list[0] 83 | num_items = batch['input_ids'].size(0) // num_answers 84 | probs = probs.reshape(num_items, num_answers) 85 | probs = renormalize_mc_pred_probs(probs, use_softmax=(return_value!='probs')) 86 | else: 87 | max_num_answers = max(num_answers_list) # need to pad probs with fewer than max_num_answers 88 | split_probs = torch.split(probs, num_answers_list) 89 | all_probs = [] 90 | for item_probs in split_probs: 91 | item_probs = item_probs.view(1, -1) 92 | _probs = renormalize_mc_pred_probs(item_probs, use_softmax=(return_value!='probs')) 93 | padding = torch.zeros((1,max_num_answers-item_probs.shape[1])).cuda() 94 | _probs = torch.concatenate([_probs, padding], dim=-1) 95 | all_probs.append(_probs) 96 | probs = torch.concatenate(all_probs) 97 | else: 98 | probs = probs.reshape(len(probs), 1) 99 | if return_hidden_states: 100 | hidden_states_dict = {states_name: getattr(outputs, states_name) for states_name in ['decoder_hidden_states', 'encoder_hidden_states', 'hidden_states'] if hasattr(outputs, states_name)} 101 | return probs, hidden_states_dict 102 | else: 103 | return probs 104 | 105 | def 
get_hidden_states_from_batch(model, batch): 106 | model_batch = { 107 | 'input_ids' : batch['input_ids'], 108 | 'attention_mask' : batch['attention_mask'] 109 | } 110 | outputs = model(**model_batch, output_hidden_states=True) 111 | hidden_states_dict = {states_name: getattr(outputs, states_name) for states_name in ['decoder_hidden_states', 'encoder_hidden_states', 'hidden_states'] if hasattr(outputs, states_name)} 112 | return hidden_states_dict 113 | 114 | def get_last_token_hidden_states(hidden_states_dict, num_answers=1, num_answers_list=None, max_num_answers=None): 115 | ''' 116 | This function gathers hidden states from model output, reshapes based on the number of answers, and stacks/concats into a single array to return 117 | Items could have different numbers of answers, so that is handled with num_answers_list with padding up to max_num_answers 118 | 119 | args: 120 | hidden_states_dict: output k,v pairs from model forward pass when 'hidden_states' in k 121 | num_answers: used to reshape, assuming that the model forward pass was 'flattened' but originally contained num_answers sequences per item 122 | returns 123 | new_hidden_states: np ndarray of shape: bs x num_answers x enc/dec x num_layers x hidden_size 124 | ''' 125 | # add num_answers dimension and stack layers 126 | # new shape is bs x num_answers x seq_len x num_layers x hidden_size 127 | if num_answers_list is None: # this way is never used in our code, because we always pass num_answers_list for generality 128 | for k,v in hidden_states_dict.items(): # iterate across decoder/encoder hidden states 129 | v = torch.stack(v, dim=-2) # stack layers of hidden states in second to last dimension 130 | hidden_shape = list(v.shape) # bs x seq_len x num_layers x hidden_size 131 | hidden_shape[0] = hidden_shape[0] // num_answers # cut batch size by num_answers 132 | hidden_shape.insert(1, num_answers) # insert num_answers after bs (really, now num_items rather than orig bs) 133 | v = v.reshape(*hidden_shape) 134 | hidden_states_dict[k] = v.cpu() 135 | else: 136 | for k,v in hidden_states_dict.items(): # iterate across decoder/encoder hidden states 137 | v = torch.stack(v, dim=-2) # stack layers of hidden states in second to last dimension 138 | padded_hidden_states = [] 139 | per_item_hidden_states = torch.split(v, num_answers_list) 140 | for item_hidden_states in per_item_hidden_states: 141 | num_answers = len(item_hidden_states) 142 | hidden_shape = list(item_hidden_states.shape) # item_num_answers x seq_len x num_layers x hidden_size 143 | hidden_shape.insert(0, 1) # insert bs dim of 1 144 | item_hidden_states = item_hidden_states.view(*hidden_shape) 145 | # need to pad v with zeros up to max_num_answers 146 | zeros_shape = deepcopy(hidden_shape) 147 | zeros_shape[1] = max_num_answers - num_answers 148 | zeros = torch.zeros(*zeros_shape) 149 | item_hidden_states = torch.concatenate((item_hidden_states.cpu(), zeros), dim=1) 150 | padded_hidden_states.append(item_hidden_states) 151 | hidden_states_dict[k] = torch.concatenate(padded_hidden_states) 152 | # stack enc/dec and grab last token index 153 | # new shape is bs x num_answers x enc/dec x num_layers x hidden_size 154 | if 'decoder_hidden_states' in hidden_states_dict: 155 | new_hidden_states = hidden_states_dict['decoder_hidden_states'][:, :, -1, :, :].unsqueeze(2).numpy() 156 | elif 'hidden_states' in hidden_states_dict: 157 | new_hidden_states = hidden_states_dict['hidden_states'][:, :, -1, :, :].unsqueeze(2).numpy() 158 | # stack hidden states if there is an encoder. 
ENCODER HIDDEN STATES WILL BE SECOND INDEX. DECODER STATES ARE ALWAYS FIRST INDEX 159 | if 'encoder_hidden_states' in hidden_states_dict: 160 | stack_hidden_states = hidden_states_dict['encoder_hidden_states'][:, :, -1, :, :].unsqueeze(2).numpy() 161 | new_hidden_states = np.concatenate([stack_hidden_states, new_hidden_states], axis=2) 162 | # select only middle and last layer 163 | num_layers = new_hidden_states.shape[-2] 164 | middle_layer = np.ceil(num_layers/2) # the embedding layer gets counted as a layer, so round up for odd num_layers 165 | middle_and_last_idx = torch.tensor([middle_layer, num_layers-1]).to(torch.int) 166 | new_hidden_states = new_hidden_states[:, :, :, middle_and_last_idx, :] 167 | return new_hidden_states 168 | 169 | def make_LM_batch(tokenizer, prompts, label_strs, label_idx=None, padding_side='left', add_eos_token=False, 170 | max_len=None, generative_batch=False, reasoning_chains=None): 171 | ''' 172 | This makes inputs for computing LM probabilities of labels given prompts, when generative_batch=False 173 | e.g. for prompts = ["I like", "I like", "I do not like", "I do not like"] and answer_choices = ["dogs", "cats", "birds", "fish"] 174 | with labels = [1,1] (repeating of prompts and flattening of nested answer choice list expected to be done prior to this method) 175 | This returns a dict 176 | { 177 | "input_ids": tokenized ["I like dogs", "I like cats", "I do not like birds", "I do not like fish"] 178 | "attention_mask": normal attention mask for the tokenizer 179 | "targets_mask": tensor, 0 where a token belonged in the prompt, 1 where it belonged in answer_choice, 0 for padding 180 | "label_idx": indices of answers without modification, i.e. [1,1], that can index the probabilities after reshaping into orig_batch_size x num_answers 181 | } 182 | intended for use with compute_probs_from_batch 183 | args: 184 | generative_batch: when true, input_ids contains only the prompts (answer_choices are not appended) 185 | reasoning_chains: when provided, these are included with labels as "targets" for the batch (for finetuning a model to do CoT) 186 | ''' 187 | # set pad and bos tokens as needed 188 | if tokenizer.pad_token_id is None: 189 | tokenizer.pad_token_id = tokenizer.eos_token_id # note we never supervise model with eos token 190 | if tokenizer.bos_token_id is None: 191 | bos_token_list = [] 192 | else: 193 | bos_token_list = [tokenizer.bos_token_id] 194 | # tokenize inputs. DO NOT ADD SPACE BEFORE ANSWER. 
this is handled in prompt.py with spacing at end of prompt 195 | prompt_ids = [tokenizer.encode(prompt, add_special_tokens=False) for prompt in prompts] 196 | label_ids = [tokenizer.encode(f"{answer}", add_special_tokens=False) for answer in label_strs] 197 | if reasoning_chains is not None: 198 | reasoning_ids = [tokenizer.encode(reasoning, add_special_tokens=False) for reasoning in reasoning_chains] 199 | if generative_batch: 200 | lm_inputs = [bos_token_list + _prompt_ids for _prompt_ids in prompt_ids] 201 | else: 202 | if reasoning_chains is None: 203 | lm_inputs = [bos_token_list + _prompt_ids + _label_ids for _prompt_ids, _label_ids in zip(prompt_ids, label_ids)] 204 | else: 205 | lm_inputs = [bos_token_list + _prompt_ids + _reasoning_ids + _label_ids for _prompt_ids, _reasoning_ids, _label_ids in zip(prompt_ids, reasoning_ids, label_ids)] 206 | # add eos tokens if requested 207 | if add_eos_token: 208 | lm_inputs = [x + [tokenizer.eos_token_id] for x in lm_inputs] 209 | # pad lm inputs 210 | if max_len is not None and max_len > 0: 211 | assert not max([len(input_ids) for input_ids in lm_inputs]) > max_len, f"Trying to make LM batch with inputs that are too long for max len {max_len}. Fix this in data_utils.py" 212 | use_max_len = max([len(input_ids) for input_ids in lm_inputs]) 213 | # left-pad inputs to max len of batch 214 | for lm_input in lm_inputs: 215 | short_by = use_max_len - len(lm_input) 216 | if padding_side == 'left': 217 | lm_input[:0] = [tokenizer.pad_token_id]*short_by # somehow this is proper indexing... 218 | else: 219 | lm_input += [tokenizer.pad_token_id]*short_by 220 | # now get label masks 221 | if generative_batch: 222 | targets_mask = None 223 | else: 224 | targets_mask = [] 225 | reasoning_ids = [[] for _ in range(len(prompt_ids))] if reasoning_chains is None else reasoning_ids 226 | for _prompt_ids, _reasoning_ids, _label_ids in zip(prompt_ids, reasoning_ids, label_ids): 227 | num_tokens = len(_prompt_ids) + len(_reasoning_ids) + len(_label_ids) + add_eos_token + (tokenizer.bos_token_id is not None) 228 | num_target_tokens = len(_reasoning_ids) + len(_label_ids) + add_eos_token 229 | if padding_side == 'left': 230 | label_mask = [0]*(use_max_len-num_target_tokens) + [1]*(num_target_tokens) 231 | elif padding_side == 'right': 232 | label_mask = [0]*(num_tokens-num_target_tokens) + [1]*num_target_tokens + [0]*(use_max_len-num_tokens) 233 | targets_mask.append(label_mask) 234 | targets_mask = torch.tensor(targets_mask) 235 | # and an attention mask 236 | lm_inputs = torch.tensor(lm_inputs) 237 | attention_mask = lm_inputs != tokenizer.pad_token_id 238 | batch = { 239 | 'input_ids': lm_inputs, 240 | 'attention_mask': attention_mask, 241 | 'targets_mask': targets_mask, 242 | 'label_idx': torch.tensor(label_idx) if label_idx else None, 243 | } 244 | return batch 245 | 246 | def postprocess_generations(tokenizer, preds, prompts): 247 | """ 248 | model generations include the prompts by default. 
this removes these from the generation 249 | also checks for bad degenerations of alternating stop tokens and real tokens 250 | """ 251 | if type(preds) is torch.Tensor: 252 | preds = [tokenizer.decode(pred, skip_special_tokens=True) for pred in preds] 253 | if type(prompts) is torch.Tensor: 254 | prompts = [tokenizer.decode(x, skip_special_tokens=True) for x in prompts] 255 | assert len(preds) == len(prompts) 256 | preds = [pred.replace(prompt, "") for pred, prompt in zip(preds, prompts)] 257 | return preds 258 | 259 | def pull_prompt_from_data(data, k): 260 | prompt_idx = np.random.choice(np.arange(len(data)), size=k, replace=False) 261 | prompt_ex = data.iloc[prompt_idx] 262 | remaining_idx = np.setdiff1d(np.arange(len(data)), prompt_idx) 263 | remaining_data = data.iloc[remaining_idx] 264 | return prompt_ex, remaining_data 265 | 266 | def score_seq_probs_from_strings(model, tokenizer, strings, breaking_batch_size=None): 267 | ''' 268 | Helper for scoring straight from a set of strings. computes a string log prob, starting with a bos token 269 | ''' 270 | all_probs = [] 271 | if breaking_batch_size is not None: 272 | list_of_strings = np.array_split(strings, len(strings) // breaking_batch_size + 1) 273 | else: 274 | list_of_strings = [strings] 275 | print() 276 | for idx, _strings in enumerate(list_of_strings): 277 | empty_prompts = [""] * len(_strings) 278 | batch = make_LM_batch(tokenizer, prompts=empty_prompts, label_strs=_strings) 279 | model_input = { 280 | 'input_ids': batch['input_ids'], 281 | 'attention_mask': batch['attention_mask'], 282 | 'targets_mask': batch['targets_mask'], 283 | 'labels': batch['input_ids'], 284 | } 285 | probs = compute_probs_from_batch(model, model_input) 286 | all_probs.extend(probs.reshape(-1).tolist()) 287 | del probs 288 | print(f" Batch: {idx}/{len(list_of_strings)} | mem use: {utils.get_mem()}", end='\r') 289 | return all_probs 290 | 291 | 292 | 293 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import sys 4 | import os 5 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '.'))) 6 | import LM_utils 7 | import data_utils 8 | import re 9 | 10 | def p_value(betas): 11 | # calculate p-value for two-sided difference from 0 test with a bootstrapped distribution of statistics, betas 12 | abs_mean_beta = np.abs(np.mean(betas)) 13 | centered_betas = betas - np.mean(betas) 14 | outside_prop = np.mean(centered_betas < -abs_mean_beta) + np.mean(centered_betas > abs_mean_beta) 15 | return outside_prop 16 | 17 | def grid_bootstrap(ndarray, summary_function, boot_times=10000): 18 | ''' 19 | This function performs bootstrap resampling on an ndarray of dim <=2, applying the summary_function, and returns a mean estimate and 95\% CI for the estimate 20 | - note we filter our all-nan rows before starting the bootstrap, because those are entirely missing data 21 | args: 22 | ndarray: np.ndarray of data of up to ndim=2. Usually two-dimensional data would be of the form n_items x n_models 23 | summary_function: a function that converts an ndarray to a scalar value (the summary statistic of interest) 24 | boot_times: number of resamples on the ndarray to perform. 
The higher, the more precise the bootstrap 25 | ''' 26 | assert ndarray.ndim <= 2 27 | # filter out totally missing rows 28 | n_cols = ndarray.shape[1] 29 | where_all_nan_rows = np.isnan(ndarray).sum(-1) == n_cols 30 | ndarray = ndarray[np.argwhere(1-where_all_nan_rows).squeeze()] 31 | n_rows = ndarray.shape[0] 32 | # collect stats 33 | n_observed_rows = [] 34 | n_observations = [] 35 | boot_stats = [] 36 | for _ in range(boot_times): 37 | # pick number of columns to sample 38 | col_idx = np.random.choice(np.arange(n_cols), size=n_cols, replace=True) 39 | row_idx = np.random.choice(np.arange(n_rows), size=n_rows, replace=True) 40 | resampled_data = ndarray[:,col_idx] 41 | resampled_data = resampled_data[row_idx, :] 42 | n_observed_rows.append(len(resampled_data) - (np.isnan(resampled_data).sum(1) == n_cols).sum()) 43 | n_observations.append((1-np.isnan(resampled_data)).sum()) 44 | boot_stat = summary_function(resampled_data) 45 | boot_stats.append(boot_stat) 46 | # get mean and 95% quantiles on boot distribution 47 | mean_estimate = np.mean(boot_stats) 48 | quantiles = np.quantile(boot_stats, [0.025, .975]) 49 | avg_diff_from_mean = np.mean(np.abs(quantiles - mean_estimate)) 50 | if np.abs(summary_function(ndarray) - mean_estimate) > .01: 51 | print(f"WARNING: Bootstrap mean estimate error greater than .01, please use more boot_times") 52 | return_dict = { 53 | 'mean': mean_estimate, 54 | 'error_bar': avg_diff_from_mean, 55 | 'str_format': f"{100*mean_estimate:5.2f} \u00b1 {100*avg_diff_from_mean:5.2f}", 56 | 'p_value': p_value(boot_stats), 57 | 'sample_size': n_rows, 58 | 'effective_sample_size': f"{np.mean(n_observed_rows):.2f}", 59 | } 60 | return return_dict 61 | 62 | def force_not_dimensionless(data): 63 | if type(data) is torch.Tensor: 64 | if data.dim()==0: 65 | data = data.view(1) 66 | return data 67 | 68 | def safe_seq(seq): 69 | # filter to non -100 values in seq, which is a list. 
-100 is the default ignore_index in pytorch 70 | return [x for x in seq if x >= 0] 71 | 72 | def em_accuracy_sum(preds, labels, return_where_correct=False): 73 | assert len(preds) == len(labels) 74 | # strict calculation of accuracy for predictions from fewshot model 75 | preds = np.array([x for x in preds]) 76 | labels = np.array([label for label in labels]) 77 | correct = (preds==labels) 78 | if return_where_correct: 79 | return correct.sum(), correct 80 | else: 81 | return correct.sum() 82 | 83 | def standardize_preds_or_labels(data, tokenizer=None): 84 | """ 85 | takes tensors, arrays, and lists, and returns standardized pred/label strs 86 | IF there are multiple labels per item, then we return a list of lists 87 | ELSE, we return an np array 88 | args: 89 | data: should be 1-d np.array, 1-d torch.tensor, or list of these things 90 | tokenizer: model tokenizer 91 | """ 92 | # unravel data into list or list of lists 93 | if type(data) is list and type(data[0]) is torch.Tensor or type(data[0]) is np.ndarray: 94 | data = [item.tolist() for item in data] 95 | if type(data) is not list: 96 | data = data.tolist() 97 | if type(data) in [int, torch.int, str, np.str_]: 98 | data = [data] 99 | # decode if elements are not already strings, or lists of strings (which would suggest it had been decoded already) 100 | need_to_decode = not (type(data[0]) is str or type(data[0]) is np.str_ or (type(data) is list and type(data[0][0]) is str)) 101 | if need_to_decode: 102 | data = [tokenizer.decode(safe_seq(seq), skip_special_tokens=True) for seq in data] 103 | # lowercase and strip the strs 104 | multiple_eligible_labels = type(data[0]) is list 105 | if multiple_eligible_labels: 106 | data = [[x.lower().strip().strip('.') for x in eligible_labels] for eligible_labels in data] 107 | else: 108 | data = [x.lower().strip().strip('.') for x in data] 109 | # convert to np array or list of lists 110 | if type(data) is torch.Tensor: 111 | data = data.detach().cpu().numpy() 112 | elif type(data) is list and type(data[0]) is list: 113 | data = data # skip the array formatting here as it will not be used in downstream metrics 114 | else: 115 | data = np.array(data) 116 | return data 117 | 118 | def first_appearance_fewshot_accuracy_sum(preds, labels, extract_answers, trigger_phrase=None, return_vec=False): 119 | """ 120 | calculated accuracy of model generations against labels, optionally given answer_choices and a 'trigger phrase' used in CoT 121 | an answer is 'predicted' if it appears in the pred str 122 | - this is VERY GENEROUS scoring for some tasks. Use generative_exact_match_accuracy for math tasks 123 | - this function also faces issues when labels/answers are subsets of one another 124 | - if multiple answers are mentioned, count which answer appears most. tie breaking is done randomly 125 | returns acc sum, optionally the vector of binary 0/1 accs per points 126 | args: 127 | preds and labels should be list, 1-d np.array, or 1-d torch.tensor of ints or strs 128 | answer_choices: optional list of answer choices to count 129 | trigger_phrase: a phrase that could separate e.g. 
reasoning from a final answer, like "Therefore, the answer is" 130 | """ 131 | assert len(preds) == len(labels) 132 | preds = standardize_preds_or_labels(preds) 133 | labels = standardize_preds_or_labels(labels) 134 | extract_answers = standardize_preds_or_labels(extract_answers) 135 | if trigger_phrase is not None: 136 | trigger_phrase = standardize_preds_or_labels([trigger_phrase]).item() 137 | n_correct = 0 138 | use_preds = [] 139 | correct_indicators = [] 140 | for pred, label in zip(preds, labels): 141 | answer_positions = {answer : 2e8 for answer in extract_answers} 142 | # extract part of pred after trigger phrase 143 | if trigger_phrase is not None and trigger_phrase in pred: 144 | pred = pred.split(trigger_phrase)[1] 145 | else: 146 | pred = pred 147 | # take first appearance of an answer in the pred 148 | # note this faces difficulty when answers are subsets of one another 149 | for answer in extract_answers: 150 | if answer in pred: 151 | answer_positions[answer] = pred.index(answer) 152 | min_position = min(answer_positions.values()) 153 | earliest_pred = list(filter(lambda tup: tup[1] == min_position, list(answer_positions.items()))) 154 | if len(earliest_pred) == 1: 155 | use_pred = earliest_pred[0][0] 156 | else: 157 | use_pred = 'NA' 158 | correct = (use_pred == label) 159 | n_correct += correct 160 | correct_indicators.append(correct) 161 | use_preds.append(use_pred) 162 | if not return_vec: 163 | return n_correct 164 | else: 165 | return n_correct, use_preds, np.array(correct_indicators) 166 | 167 | def extract_numeric(pred): 168 | # extracts the numeric prediction from a string (should already have string.split(trigger_phrase)[1] applied as necessary) 169 | pred_end = " ".join(pred.split()[-2:]) # sometimes what follows the top string is something like "2 + 3 = 5" or "5 apples" 170 | numeric = re.sub(r"[^0-9.]", "", pred_end) 171 | if numeric == "" or numeric == ".": 172 | for word in reversed(pred.split(" ")): 173 | if bool(re.search(r"\d", word)): 174 | numeric = re.sub(r"[^0-9.]", "", word) 175 | pred = numeric 176 | return numeric 177 | 178 | def generative_exact_match_accuracy_sum(preds, labels, trigger_phrase=None, stop_string=None, numeric_filter=False, return_vec=False): 179 | """ 180 | Checks that predictions exactly match label during CoT. 181 | Preds are extracted by taking text after the "trigger phrase" that always prefaces answers 182 | args: 183 | preds and labels should be list, 1-d np.array, or 1-d torch.tensor of ints or strs 184 | trigger_phrase: a phrase that must always separate e.g. reasoning from a final answer, like "the answer is". 
Used in CoT 185 | stop_string: answers always 'end' at this string, like a line break or eos token 186 | """ 187 | assert len(preds) == len(labels) 188 | preds = standardize_preds_or_labels(preds) 189 | labels = standardize_preds_or_labels(labels) 190 | processed_preds = [] 191 | for pred in preds: 192 | # break up preds based on trigger phrase 193 | if trigger_phrase is None: 194 | pass 195 | elif trigger_phrase is not None: 196 | if trigger_phrase in pred: 197 | pred = pred.split(trigger_phrase)[1] 198 | else: 199 | pred = f"[PRED DID NOT HAVE TRIGGER PHRASE]: ...{pred[-30:]}" 200 | # stop preds at the stop_string 201 | if stop_string is not None and stop_string in pred: 202 | pred = pred[:pred.index(stop_string)] 203 | pred = pred.strip().strip('\n').strip().strip('.') 204 | if numeric_filter and not "PRED DID NOT HAVE TRIGGER PHRASE" in pred: 205 | pred = extract_numeric(pred) 206 | processed_preds.append(pred) 207 | # compute EM 208 | n_correct, binary_correct = em_accuracy_sum(processed_preds, labels, return_where_correct=return_vec) 209 | if not return_vec: 210 | return n_correct 211 | else: 212 | return n_correct, processed_preds, binary_correct 213 | -------------------------------------------------------------------------------- /utils/modeling_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import pandas as pd 4 | from transformers import get_scheduler 5 | from torch.optim import SGD, AdamW 6 | import time 7 | 8 | from models.probes import Probe 9 | from utils import utils, metrics, LM_utils, data_utils 10 | import globals 11 | 12 | def load_optimizer_and_scheduler(args, model, num_training_steps): 13 | named_parameters = model.named_parameters() 14 | optimizer_grouped_parameters = [ 15 | {"params": [p for n, p in named_parameters if not any(nd in n for nd in ["bias", "LayerNorm.weight"])], 16 | "weight_decay": args.weight_decay, 17 | 'lr' : args.lr}, 18 | {"params": [p for n, p in named_parameters if any(nd in n for nd in ["bias", "LayerNorm.weight"])], 19 | "weight_decay": 0.0, 20 | 'lr' : args.lr} 21 | ] 22 | optimizer_to_class = {'adamw' : AdamW, 'sgd' : SGD} 23 | optimizer_class = optimizer_to_class[args.optimizer] 24 | optimizer = optimizer_class(optimizer_grouped_parameters) 25 | if args.lr_decay == "constant": 26 | scheduler = get_scheduler("constant", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps) 27 | elif args.lr_decay in ['linear', '10-percent']: 28 | percent_of_orig_value = .1 if args.lr_decay == '10-percent' else 0 29 | multiplier = 1 / (1-percent_of_orig_value) 30 | scheduler = get_scheduler("linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=multiplier*num_training_steps) 31 | return (optimizer, scheduler) 32 | 33 | def evaluate_model(args, 34 | log, 35 | model, 36 | dataloader, 37 | tokenizer, 38 | return_hidden_states=False, 39 | calibrate_probe=False, 40 | verbose=False): 41 | ''' 42 | main train_and_eval function that evaluates models on data from a MCTextDataset 43 | returns eval_stats, which may include the hidden_states as requested 44 | ''' 45 | # condition names 46 | gathering_question_end_states = return_hidden_states and args.probing_token_state == 'question_end_token' 47 | MC_or_classification = not args.generative_eval and not gathering_question_end_states 48 | # init stats dicts. 
epochs_stats will be running statistics, used to compute values for eval_stats 49 | eval_stats = { 50 | 'n_batches': 0, 51 | 'forward_time_sum' : 0, 52 | 'acc': -1, 53 | 'loss': -1, 54 | 'probe_loss': -1, # use for model selection, may be unsupervised objective 55 | 'modal_label': '', 56 | } 57 | epoch_stats = { 58 | 'acc_sum': 0, 59 | 'loss_sum': 0, 60 | 'probe_loss_sum': 0, # used for model selection, may be unsupervised objective 61 | 'n_data_points': 0, 62 | } 63 | start_time = time.time() 64 | model.eval() 65 | total_batches = len(dataloader) 66 | # set other generative eval args 67 | if args.generative_eval and hasattr(dataloader, 'dataset'): 68 | trigger_phrase = 'the answer is' if args.use_cot else None 69 | if args.probing_method == 'decoding': 70 | stop_string = "\n" 71 | if args.probing_method == 'finetuned': 72 | stop_string = None # would be tokenizer.eos_token, but tokenizer.decode should stop decoding at the eos_token 73 | all_pd_index = [] 74 | all_probs = [] 75 | all_preds = [] 76 | all_labels = [] 77 | all_binary_correct = [] 78 | label_confidence = [] 79 | 80 | # prepare collection of hidden states 81 | if return_hidden_states: 82 | # if doing probing like scoring f(x,a) pairs, store one hidden states per "x a" input. But when doing classifation, we classify f(x) 83 | answers_dim = 1 if args.probing_token_state == 'question_end_token' else data_utils.get_max_num_answers(dataloader.dataset.dataframe) 84 | hidden_size = globals.model_to_hidden_size[args.short_model_name] 85 | num_layers = 2 # including last layer and a middle layer. middle layer is idx 0, last layer is idx 1 86 | encoder_decoder = 2 if args.model_type == 'encoder_decoder' else 1 # will stack encoder/decoder representations 87 | running_hidden_states = np.ndarray((0, answers_dim, encoder_decoder, num_layers, hidden_size)) 88 | for batch_num, batch in enumerate(dataloader): 89 | running_time = (time.time()-start_time) 90 | est_run_time = (running_time/(batch_num if batch_num > 0 else 1)*total_batches) 91 | forward_time = (eval_stats['forward_time_sum'] / eval_stats['n_batches'] if batch_num > 0 else 0) 92 | if verbose: 93 | gpu_mem = utils.get_gpu_utilization() if "cuda" in str(args.device) else None 94 | log.print_training_prog(eval_stats, 'EVAL', 'EVAL', batch_num, total_batches, running_time, est_run_time, forward_time, gpu_mem=gpu_mem) 95 | # forward pass 96 | with torch.no_grad(): 97 | if args.generative_eval: 98 | _model = getattr(model, 'model', None) if isinstance(model, Probe) else model 99 | assert args.padding_side == 'left' 100 | main_kwargs = { 101 | 'input_ids': batch['input_ids'], 102 | 'attention_mask': batch['attention_mask'], 103 | 'do_sample': False, 104 | 'max_new_tokens': args.max_gen_len, 105 | 'pad_token_id': tokenizer.pad_token_id, 106 | 'eos_token_id': tokenizer.eos_token_id 107 | } 108 | utils.move_kwargs_to_gpu(main_kwargs) 109 | forward_begin = time.time() 110 | preds = _model.generate(**main_kwargs) 111 | eval_stats['forward_time_sum'] += (time.time() - forward_begin) 112 | preds = LM_utils.postprocess_generations(tokenizer, preds, main_kwargs['input_ids']) 113 | labels = batch['label_strs'] 114 | n_correct, trimmed_preds, binary_correct = metrics.generative_exact_match_accuracy_sum(preds, labels, 115 | trigger_phrase=trigger_phrase, 116 | stop_string=stop_string, 117 | numeric_filter=('gsm8k' in args.dataset), 118 | return_vec=True) 119 | all_preds.extend(trimmed_preds) 120 | all_labels.extend(labels) 121 | all_pd_index.extend(batch['pd_index']) 122 | 
all_binary_correct.extend(binary_correct.tolist()) 123 | # in this condition, we're just gathering hidden states from the ends of the questions 124 | elif gathering_question_end_states: 125 | main_kwargs = { 126 | 'input_ids': batch['input_ids'], 127 | 'attention_mask': batch['attention_mask'], 128 | } 129 | hidden_states_dict = LM_utils.get_hidden_states_from_batch(model, main_kwargs) 130 | # in this condition, we evaluate the model in a multiple choice or classification fashion 131 | elif MC_or_classification: 132 | if isinstance(model, Probe): 133 | forward_begin = time.time() 134 | num_answers_list = batch['num_answers_list'] 135 | compute_mc_probs = not all(x==1 for x in num_answers_list) 136 | answer_probs = model(batch, compute_mc_probs=compute_mc_probs) # handle cuda() inside forward pass due to moving different data to gpu 137 | eval_stats['forward_time_sum'] += (time.time() - forward_begin) 138 | assert not return_hidden_states, "If collecting hidden states, use model as an arg and not Probe class as arg" 139 | else: 140 | forward_begin = time.time() 141 | main_kwargs = { 142 | 'input_ids': batch['input_ids'], 143 | 'attention_mask': batch['attention_mask'], 144 | 'labels': batch['input_ids'], 145 | 'targets_mask': batch['targets_mask'], 146 | 'answer_choices': batch['answer_choices'], 147 | } 148 | if return_hidden_states: 149 | answer_probs, hidden_states_dict = LM_utils.compute_probs_from_batch(model, main_kwargs, return_value=args.answer_scoring, num_answers_list=batch['num_answers_list'], 150 | return_hidden_states=True) 151 | elif not return_hidden_states: 152 | answer_probs = LM_utils.compute_probs_from_batch(model, main_kwargs, return_value=args.answer_scoring, num_answers_list=batch['num_answers_list']) 153 | eval_stats['forward_time_sum'] += (time.time() - forward_begin) 154 | # get preds, loss, and acc 155 | preds = torch.argmax(answer_probs, dim=1) 156 | label_idx = batch['label_idx'] 157 | if args.n_gpu > 0: 158 | label_idx = label_idx.cuda() 159 | answer_probs = answer_probs.cuda() # syncs probs with labels in multi-gpu case 160 | loss, label_probs = LM_utils.compute_mc_loss(answer_probs, label_idx) 161 | if hasattr(model, 'loss'): # i.e., if model is one of our probes. 
Useful for getting unsupervised loss 162 | probe_loss = model.loss(label_idx, answer_probs) 163 | else: 164 | probe_loss = None 165 | all_probs.append(answer_probs.detach().cpu()) 166 | preds = preds.cpu().numpy() 167 | label_idx = label_idx.cpu().numpy() 168 | all_preds.extend(preds.tolist()) 169 | all_labels.extend(label_idx.tolist()) 170 | all_pd_index.extend(batch['pd_index']) 171 | n_correct, binary_correct = metrics.em_accuracy_sum(preds, label_idx, return_where_correct=True) 172 | label_confidence.extend(label_probs.tolist()) 173 | all_binary_correct.extend(binary_correct.tolist()) 174 | # update epoch stats 175 | epoch_stats['loss_sum'] += loss.item() 176 | if probe_loss is not None: 177 | epoch_stats['probe_loss_sum'] += probe_loss.item() 178 | del loss, probe_loss 179 | # update epoch stats 180 | if not gathering_question_end_states: 181 | epoch_stats['acc_sum'] += n_correct 182 | epoch_stats['n_data_points'] += len(batch['items']) 183 | # update eval stats 184 | eval_stats['loss'] = epoch_stats['loss_sum'] / (batch_num+1) 185 | eval_stats['probe_loss'] = epoch_stats['probe_loss_sum'] / (batch_num+1) 186 | eval_stats['acc'] = epoch_stats['acc_sum'] / epoch_stats['n_data_points'] 187 | eval_stats['n_batches'] += 1 188 | # accumulate hidden states 189 | if return_hidden_states: 190 | if args.probing_token_state == 'answer_end_token': 191 | num_answers_list = batch['num_answers_list'] 192 | max_num_answers = data_utils.get_max_num_answers(dataloader.dataset.dataframe) 193 | else: 194 | num_answers_list = max_num_answers = None 195 | new_hidden_states = LM_utils.get_last_token_hidden_states(hidden_states_dict, 196 | num_answers_list=num_answers_list, 197 | max_num_answers=max_num_answers) 198 | running_hidden_states = np.concatenate([running_hidden_states, new_hidden_states], axis=0) 199 | # print examples 200 | if verbose: 201 | # if args.num_print > 0: 202 | # print_idx = list(range(min(args.num_print, len(batch['items'])))) 203 | if (batch_num == 0 and args.num_print > 0): 204 | print_idx = list(range(min(args.num_print, len(batch['items'])))) 205 | else: 206 | print_idx = [] 207 | if len(print_idx) > 0: 208 | print("\n" + "-"*20 + f"\nPrinting examples:") 209 | print(f" Exact Input 0 : {tokenizer.decode(batch['input_ids'][0])}") 210 | for i in print_idx: 211 | prompt = batch['prompts'][i] 212 | label = batch['label_strs'][i] 213 | answer_choices = ['A', 'B', 'C', 'D'] if args.use_letter_labels else batch['answers_list'][i] 214 | print(f" point {i}") 215 | print(f" Prompt : \n{prompt}") 216 | if MC_or_classification: 217 | probs = [np.round(x.item(), 4) for x in answer_probs[i].cpu()] 218 | print(f" Preds : {[x for x in zip(answer_choices, probs)]}") 219 | pred = answer_choices[preds[i]] 220 | correct = binary_correct[i] 221 | elif args.use_cot: 222 | print(f" Full pred : {preds[i]}") 223 | pred = trimmed_preds[i] 224 | correct = binary_correct[i] 225 | elif args.generative_eval: 226 | pred = preds[i] 227 | correct = binary_correct[i] 228 | elif gathering_question_end_states: 229 | pred = "No pred; just gathering hidden states" 230 | correct = "N/A" 231 | print(f" Pred : {pred}") 232 | print(f" Label : {label}") 233 | print(f" Correct : {correct}") 234 | if args.dataset == 'gsm8k_main': 235 | print(f"steps: {batch['items'][i][1].num_steps} | {batch['items'][i][1].reasoning}") 236 | if args.dataset == 'mmlu_STEM-5' or args.dataset == 'third_grade_to_college': 237 | print(f"subject: {batch['items'][i][1].subject} | {batch['items'][i][1].human_hardness}") 238 | if correct: 239 | 
write_to_file = prompt + "\n" + str([x for x in zip(answer_choices, probs)]) + f"\n {correct}" 240 | with open('tmp_example.txt', 'w') as f: 241 | f.write(write_to_file) 242 | if i != print_idx[-1]: 243 | print() 244 | print("-"*20 + '\n') 245 | del batch 246 | 247 | # calibrate preds and overwrite values in eval_stats. On future calls of evaluate_model, the forward pass will automatically do this all_probs-model.probs_centroid step 248 | if calibrate_probe: 249 | all_probs = torch.concatenate(all_probs) 250 | model.set_calibration_params(probs=all_probs, verbose=True) 251 | all_probs = all_probs - model.probs_centroid.cpu() 252 | all_preds = torch.argmax(all_probs, dim=-1).numpy() 253 | n_correct, binary_correct = metrics.em_accuracy_sum(all_preds, all_labels, return_where_correct=True) 254 | eval_stats['acc'] = n_correct / epoch_stats['n_data_points'] 255 | 256 | # make item level stats df 257 | item_level_stats = pd.DataFrame({ 258 | 'label_confidence': label_confidence if not args.generative_eval else None, 259 | 'accuracy': 1*np.array(all_binary_correct), 260 | }, index=all_pd_index) 261 | eval_stats['item_level_stats'] = item_level_stats 262 | 263 | # add hidden_states, label confidence 264 | if return_hidden_states: 265 | eval_stats['hidden_states'] = running_hidden_states.astype(np.float32) 266 | # add model proportion as 'random' performance 267 | if total_batches > 0: 268 | label_props = [(label, np.mean(np.array(all_labels) == label)) for label in set(all_labels)] 269 | label_props = sorted(label_props, key=lambda x: -x[1]) 270 | eval_stats['modal_label'] = f"{label_props[0][0]}: {label_props[0][1]:.2f}" if len(set(all_labels)) > 1 else "NA" 271 | if len(set(all_preds)) < 10: 272 | eval_stats['pred_distr'] = {y: round(np.mean(np.array(all_preds)==y), 2) for y in set(all_preds)} 273 | eval_stats['label_distr'] = {y: round(np.mean(np.array(all_labels)==y), 2) for y in set(all_labels)} 274 | else: 275 | eval_stats['pred_distr'] = {} 276 | eval_stats['label_distr'] = {} 277 | if verbose: 278 | print(" Pred distr: ", eval_stats['pred_distr']) 279 | print(" Label distr: ", eval_stats['label_distr']) 280 | return eval_stats -------------------------------------------------------------------------------- /utils/plotting_utils.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | import utils 5 | import seaborn as sns 6 | import PIL 7 | from matplotlib.transforms import Affine2D 8 | 9 | def plot_corr_matrix(corr_matrix, save_name, title): 10 | plt.figure(figsize=(10, 8)) # Set the size of the figure 11 | sns.heatmap(corr_matrix, annot=True, cmap='coolwarm_r', fmt=".2f") # Create the heatmap with annotations 12 | plt.title(title) # Add a title 13 | plt.xticks(rotation=45) # Rotate the x-axis labels for better readability 14 | # plt.yticks(rotation=45) # Rotate the y-axis labels for better readability 15 | ax = plt.gca() 16 | labels = ax.get_xticklabels() 17 | # Create an offset transform 18 | dx = -1.5 # shift by -0.5 19 | offset = Affine2D().translate(dx, 0) 20 | # Apply the offset transform to each label 21 | for label in labels: 22 | label.set_transform(label.get_transform() + offset) 23 | filepath = f'result_sheets/{save_name}' 24 | plt.tight_layout() 25 | plt.savefig(filepath + '.png', format='png', dpi=300) 26 | plt.clf() 27 | 28 | def grid_arrange_pngs(pngs, save_name): 29 | # takes a list of png plots opened as PIL.Image.open(x), and arranges them in a grid 30 | 
n_cols = 4 31 | n_rows = int(np.ceil(len(pngs) / n_cols)) # Ceil division 32 | fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(9, 9), dpi=300) 33 | axes_flat = axes.flatten() 34 | 35 | # Loop through each subplot and each filepath to display images 36 | for ax, img in zip(axes_flat, pngs): 37 | ax.imshow(np.array(img), interpolation='bilinear') 38 | ax.axis('off') # Hide axes 39 | 40 | # remove empty subplots 41 | for remaining_ax in axes_flat[len(pngs):]: 42 | remaining_ax.axis('off') 43 | 44 | # Display the grid of images 45 | plt.show() 46 | filepath = f'result_sheets/{save_name}' 47 | plt.tight_layout() 48 | plt.savefig(filepath + '.png', format='png', dpi=300) 49 | plt.clf() 50 | 51 | 52 | def compare_method_learning_curves(df_results, save_prefix, hardness_var_name, test_split='hard'): 53 | ''' 54 | To be used on the results of experiments in run_jobs.py (applied to the resulting dfs). 55 | This function plots the learning curves of different prompting/probing methods, for a particular test split of the data 56 | args: 57 | test_split: show test_acc for df rows where test_on==test_split 58 | ''' 59 | save_dir = 'outputs' 60 | df_results = df_results.copy() # suppress slicing warnings 61 | cot_col = df_results['use_cot'].apply(lambda x: '-CoT' if bool(x) else '') 62 | sup_col = df_results['probe_loss'].apply(lambda x: '-unsup' if 'unsup' in x else '') 63 | df_results['method'] = df_results['probing_method'] + cot_col + '-' + df_results['train_on'].astype(str) + sup_col 64 | max_n = 400 65 | # subset df 66 | subset = df_results[df_results['test_on'] == test_split] 67 | subset = subset[subset['n_train'] <= max_n] 68 | sns.set(style='whitegrid') 69 | plt.figure(figsize=(14, 7)) 70 | # add error bar 71 | subset['lower_bound'] = subset['test_acc'] - subset['error_bar'] 72 | subset['upper_bound'] = subset['test_acc'] + subset['error_bar'] 73 | # plot 74 | sns.lineplot(x='n_train', y='test_acc', hue='method', data=subset, marker='o', markersize=8, linewidth=2, errorbar=None) 75 | plt.xlabel('Number of Training Samples (n_train)') 76 | plt.ylabel('Test Accuracy') 77 | plt.title(f'{hardness_var_name} {test_split.capitalize()} Test Acc') 78 | plt.legend(title='Method', loc='center left', bbox_to_anchor=(1.05, 0.5)) 79 | plt.tight_layout() 80 | plt.show() 81 | save_path = f'{save_dir}/{save_prefix}_test-{test_split}' + '.png' 82 | plt.savefig(save_path, format='png') 83 | plt.clf() 84 | # open image to return 85 | img = PIL.Image.open(save_path) 86 | # now just show individual methods 87 | for probing_method in ['decoding', 'learned', 'finetuned']: 88 | max_n = 100 if probing_method == 'decoding' else 1000 89 | subset = df_results[df_results['test_on'] == test_split] 90 | subset = subset[subset['n_train'] <= max_n] 91 | probe_data = subset[subset['probing_method']==probing_method] 92 | sns.lineplot(x='n_train', y='test_acc', hue='method', data=probe_data, marker='o', markersize=8, linewidth=2, errorbar=None) 93 | plt.xlabel('Number of Training Samples (n_train)') 94 | plt.ylabel('Test Accuracy') 95 | plt.title(f'{hardness_var_name} {test_split.capitalize()} Test Acc') 96 | plt.legend(title='Method', loc='center left', bbox_to_anchor=(1.05, 0.5)) 97 | plt.tight_layout() 98 | plt.show() 99 | save_path = f'{save_dir}/{save_prefix}_test-{test_split}_{probing_method}' + '.png' 100 | plt.savefig(save_path, format='png') 101 | plt.clf() 102 | return img 103 | 104 | 105 | def plot_acc_vs_hardness(test_stats_df, save_name, hardness_var_name): 106 | ''' 107 | Plots accuracy against item 
level hardness for a given dataset 108 | args: 109 | test_stats_df: expected to have a hardness variable name and one or more columns with 'accuracy' in the name (possibly representing accuracies from different bootstraps or models) 110 | accuracy columns may be sparse 111 | hardness_var_name: name of hardness variable to use 112 | ''' 113 | plot_df = test_stats_df.copy() 114 | n_hardness_levels = len(set(plot_df[hardness_var_name].values)) 115 | acc_cols = filter(lambda x : 'acc' in x, plot_df.columns) 116 | do_bar_plot = n_hardness_levels < 10 117 | do_line_plot = not do_bar_plot 118 | # nanmean across the accuracy columns 119 | plot_df['mean_acc'] = np.nanmean(plot_df.loc[:, acc_cols].copy(), axis=1) 120 | plot_df = plot_df.loc[:, [hardness_var_name, 'mean_acc']].dropna() 121 | # make binned hardness variable 122 | if do_line_plot: 123 | n_bins = 5 124 | plot_df['hardness_binned'] = pd.cut(plot_df[hardness_var_name], bins=n_bins, labels=False) 125 | elif do_bar_plot: 126 | hardness_min = min(plot_df[hardness_var_name]) 127 | hardness_max = max(plot_df[hardness_var_name]) 128 | plot_df['hardness_binned'] = plot_df[hardness_var_name] 129 | # calculate group means and CIs 130 | grouped = plot_df.groupby('hardness_binned')['mean_acc'].agg(['mean', 'std', 'count']).reset_index() 131 | grouped['CI'] = 1.96 * grouped['std'] / np.sqrt(grouped['count']) 132 | # drop rows where less than n points 133 | at_least_n_points = 20 134 | grouped = grouped[grouped['count'] >= at_least_n_points] 135 | # plot 136 | color = '#4287f5' 137 | if do_bar_plot: 138 | # make x variable categorical in case there are missing categories after low n filtering 139 | grouped['hardness_binned'] = pd.Categorical(grouped['hardness_binned'], categories=list(range(hardness_min,hardness_max+1))) 140 | sns.barplot(x='hardness_binned', y='mean', data=grouped, color=color, errorbar=None) 141 | plt.errorbar(x=grouped.index, y=grouped['mean'], yerr=grouped['CI'], fmt='none', capsize=5, color='black') 142 | if do_line_plot: 143 | # Calculate bin centers 144 | bins = np.linspace(plot_df[hardness_var_name].min(), plot_df[hardness_var_name].max(), 6) 145 | bin_centers = bins[:-1] + np.diff(bins) / 2 146 | grouped['x_center'] = bin_centers[grouped.hardness_binned.values] 147 | # set x_axis ticks to the halfway points of the lower and upper hardness levels, assuming hardness var is 0-1 148 | grouped['x_axis'] = (grouped.hardness_binned + 1) / n_bins - (.5 / n_bins) 149 | sns.lineplot(x='x_center', y='mean', data=grouped, color=color, errorbar=None) 150 | plt.errorbar(x=grouped['x_center'], y=grouped['mean'], yerr=grouped['CI'], fmt='none', capsize=5, color=color) 151 | plt.xlabel('Hardness') 152 | plt.ylabel('Average Accuracy') 153 | plt.title(f"{hardness_var_name}") 154 | save_path = f'outputs/{save_name}' + '.png' 155 | plt.savefig(save_path, format='png') 156 | plt.clf() 157 | 158 | def plot_sample_efficiency(df_results, save_prefix, outcome='eval_acc', x_var_list=None, no_prompt_avg_plot=False, no_multiprompt_plot=False): 159 | ''' 160 | We plot eval_acc vs. 
n_train for a given dataset, to visualize hardness estimation results 161 | args: 162 | - x_var_list: can include 'n_train' and 'log_x' 163 | ''' 164 | df_results = df_results.copy() # suppress slicing warnings 165 | df_results['log_x'] = df_results['n_train'].apply(lambda x: np.log10(x) if x != 0 else 0) 166 | if x_var_list is None: 167 | x_var_list = ['n_train', 'log_x'] 168 | def log_axis(x_var_max): 169 | if x_var_max > 3: 170 | plt.xticks([1, 2, 3, 3.5], [10, 100, 1000, int(10**3.5)]) 171 | elif x_var_max < 2.5: 172 | plt.xticks([0, 1, 2, 2.5], [1, 10, 100, int(10**2.5)]) 173 | else: 174 | plt.xticks([1, 2, 3], [10, 100, 1000]) 175 | # average results over boot idx and prompt idx 176 | if not no_prompt_avg_plot: 177 | per_n_train = df_results.groupby(['n_train', 'log_x'])[outcome].mean().reset_index() 178 | # binomial proportional confidence interval 179 | if 'error_bar' not in df_results.columns: 180 | n_dev = df_results['n_dev'][0] 181 | per_n_train['se'] = np.sqrt(per_n_train[outcome] * (1-per_n_train[outcome]) / n_dev) 182 | per_n_train['CI'] = 1.96 * per_n_train['se'] 183 | # use pre-calculated error_bar 184 | else: 185 | per_n_train = pd.merge(per_n_train, df_results.loc[:,['n_train', 'error_bar']]) 186 | per_n_train['CI'] = per_n_train['error_bar'] 187 | for x_var in x_var_list: 188 | plt.errorbar(per_n_train[x_var], per_n_train[outcome], yerr=per_n_train['CI'], capsize=5, fmt='o-') 189 | if x_var == 'log_x': 190 | log_axis(df_results[x_var].max()) 191 | plt.xlabel('n_train') 192 | plt.ylabel(outcome) 193 | plt.title(f"{outcome} vs n_train") 194 | save_name = f"{save_prefix}" + ('_log' if x_var == 'log_x' else '') 195 | filepath = f'outputs/{save_name}' 196 | plt.savefig(filepath + '.png', format='png') 197 | plt.clf() 198 | # plot multiple trajectories for different prompts 199 | if not no_multiprompt_plot: 200 | for x_var in x_var_list: 201 | sns.lineplot(x=x_var, y=outcome, data=df_results, hue='prompt_id', palette='flare', errorbar=None) #, errorbar=('ci', 95), err_style='bars') 202 | if x_var == 'log_x': 203 | log_axis(df_results[x_var].max()) 204 | plt.xlabel('n_train') 205 | plt.ylabel(outcome) 206 | plt.title(f"{outcome} vs n_train\n{save_prefix}") 207 | plt.legend(title='Prompt ID') 208 | save_name = f"{save_prefix}_by_prompt" + ('_log' if x_var == 'log_x' else '') 209 | filepath = f'outputs/{save_name}' 210 | plt.savefig(filepath + '.png', format='png') 211 | plt.clf() 212 | 213 | def plot_hardness_distribution(hardness_scores, name='hardness_distribution'): 214 | # plot a single provided vector of per-item hardness scores 215 | plt.hist(hardness_scores, bins=20, edgecolor='black') 216 | if 'NORMED' in name: 217 | plt.xlim(0,1) 218 | plt.xlabel(name) 219 | plt.ylabel("Frequency") 220 | plt.title(f"Histogram of {name}") 221 | filepath = f'outputs/{name}' 222 | plt.savefig(filepath + '.png', format='png') 223 | plt.clf() 224 | 225 | def plot_hardness_distributions_facet(hardness_scores, plot_name): 226 | # facet plot of multiple per-item hardness scores 227 | n_cols = 4 228 | n_rows = int(-(-len(hardness_scores.columns) // n_cols)) # Ceil division 229 | fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(10, 10)) 230 | 231 | # make subplots 232 | for ax, name in zip(axes.flatten(), hardness_scores.columns): 233 | ax.hist(hardness_scores[name], bins=20, edgecolor='black') 234 | 235 | if 'NORMED' in name: 236 | ax.set_xlim(0, 1) 237 | 238 | ax.set_xlabel(name) 239 | ax.set_ylabel("Frequency") 240 | ax.set_title(f"{name}") 241 | 242 | # remove empty subplots 243 
| for remaining_ax in axes.flatten()[len(hardness_scores.columns):]: 244 | remaining_ax.axis('off') 245 | 246 | filepath = f'result_sheets/{plot_name}' 247 | plt.tight_layout() 248 | plt.subplots_adjust(hspace=1.2) 249 | plt.savefig(filepath + '.png', format='png') 250 | plt.clf() -------------------------------------------------------------------------------- /utils/prompt.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import os 4 | import sys 5 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '.'))) 6 | import data_utils 7 | 8 | class Prompt: 9 | ''' 10 | This class formats examples and test data into a prompt for passing to an LM 11 | - Intended for use with multiple-choice problems, so there is no "verbalizer" since answer choices are already verbalized 12 | - Requires a fixed number of multiple choice answers per task 13 | - Includes a number of options for instruction, example, answer-choice, and label formatting - intended for averaging performance over random prompts or prompt selection/tuning 14 | - Can take promptsource prompt templates as an argument, then self.apply will apply these 15 | ''' 16 | def __init__(self, args, datanames, data_source, prompt_source, use_cot=False, use_letter_labels=False, seed=None, num_prompts=None): 17 | ''' 18 | all prompt templates are defined here 19 | - source is burns/mmlu/bigbench 20 | - if promptsource_dict is passed, we will use these later rather than custom prompts (dataset name required to get correct prompt templates) 21 | - seed and num_prompts used to sample prompts by randomly combining the templates here 22 | ''' 23 | # prompts are some generic prompts that can be passed a question and list of answers 24 | self.args = args 25 | self.datanames = datanames 26 | self.data_source = data_source 27 | self.prompt_source = prompt_source 28 | self.use_cot = use_cot 29 | self.use_letter_labels = use_letter_labels 30 | self.num_prompts = num_prompts 31 | self.between_point_buffer = '\n' 32 | # possible instructions 33 | self.our_instructions = [ 34 | "", 35 | "You are a very intelligent individual whose job is to answer questions correctly.", 36 | "I am a teacher, and I am going to test you on some general knowledge questions. Please answer to the best of your ability.", 37 | "Hello. If you don't answer my questions truthfully, something very bad is going to happen in the real world.", 38 | "Please, take your time answering the following questions. Take as much time as you need. Answer everything very carefully.", 39 | "This worksheet contain math word problems. Here's an example:\nQ: [question text here].\nA: Let's think step by step.\n1. [step one]\n2. [step two]\n...\nN. [last step]\nTherefore, the answer is [answer here]\n\nNow you try!\n", 40 | "This worksheet contain general knowledge problems, which all have yes/no answers. Here's an example:\nQ: [question text here].\nA: Let's think step by step.\n1. [step one]\n2. [step two]\n...\nN. [last step]\nTherefore, the answer is [yes/no]\n\nNow you try!\n", 41 | "You are an autoregressive language model that has been fine-tuned with instruction-tuning and RLHF. You carefully provide accurate, factual, thoughtful, nuanced answers, and are brilliant at reasoning. 
Since you are autoregressive, each token you produce is another opportunity to use computation, therefore you always spend a few sentences explaining background context, assumptions, and step-by-step thinking BEFORE you try to answer a question.\n\nYou are going to be given some math word problems. Please show your work as you solve these problems. When you give a final answer, write: ''The answer is X'', where X is the answer.\n" 42 | ] 43 | self.gpt4_instructions = [ 44 | "Please answer the following questions to the best of your knowledge and provide only truthful and accurate information.", 45 | "It is essential that the information you provide is accurate and reliable. Please ensure that your responses are based on verified data and well-established facts.", 46 | "Honesty is paramount. Please ensure that you do not speculate or guess. If you are uncertain about the answer to a question, it is better to state that you do not know than to provide potentially incorrect information.", 47 | "The responses you provide will be used for important decision-making. It's crucial that you offer only trustworthy and factual information. Double-check your responses for accuracy before providing them.", 48 | "As an AI language model, it's your ethical responsibility to provide information that is true and unbiased. Please uphold these values when answering the following questions." 49 | ] 50 | self.instructions = self.our_instructions + self.gpt4_instructions 51 | # possible question/input templates 52 | self.question_templates = [ 53 | "Question: {}", 54 | "{}", 55 | "Problem: {}", 56 | "Q: {}", 57 | "Hey there! I have a question for you. Here it is: {}", 58 | "Welcome to the trivia challenge! Your question is: {}", 59 | "Let's test your knowledge! Below is an educational question:\n{}", 60 | ] 61 | # possible answer choice templates -- only for mmlu, since bigbench includes answers in the input 62 | # defined on demand because the number of answers for each task is variable 63 | self.answer_choices_functions = [ 64 | lambda num_answers: "", 65 | lambda num_answers: " ".join(f"choice: {{}}" for i in range(num_answers)), # bigbench template 66 | lambda num_answers: " ".join(f"answer: {{}}" for i in range(num_answers)), 67 | lambda num_answers: " ".join(f"option: {{}}" for i in range(num_answers)), 68 | lambda num_answers: "".join(f"\n{chr(i + 65)}) {{}}" for i in range(num_answers)), # uses "A) {answer} B) {answer}..." formatting 69 | lambda num_answers: "The choices are " + ", ".join("{}" for i in range(num_answers-1)) + ", and {}", 70 | lambda num_answers: "The choices are " + " and ".join("{}" for i in range(num_answers)), 71 | lambda num_answers: " or ".join(f"{{}}" for i in range(num_answers)) + "?", 72 | ] 73 | # label template, for use with few-shot prompting. do not leave trailing spaces 74 | self.label_templates = [ 75 | "\nAnswer: {}", 76 | "{}", 77 | "\n{}", 78 | "The answer is: {}", 79 | "Therefore, the answer is {}", 80 | "So the answer is {}", 81 | "\nA: {}", 82 | "\nSo the answer is {}", 83 | ] 84 | # CoT template 85 | if args.think_step_by_step: 86 | self.cot_preface = "\nLet's think step by step.\n1."
87 | else: 88 | self.cot_preface = "\nA:" 89 | # default prompt templates 90 | self.default_instr_idx = 0 91 | self.default_question_idx = 0 92 | self.default_answers_idx = 0 93 | self.default_label_idx = 0 94 | # set/load the collection of prompts to be used 95 | if self.prompt_source == "promptsource": 96 | raise NotImplementedError("Must install promptsource as follows: git clone https://github.com/bigscience-workshop/promptsource.git\ncd promptsource\nremove the python requirement in setup.py\npip install -e . \n\nThen adjust the code to import promptsource and remove this NotImplementedError") 97 | self.promptsource_dict = self.load_promptsource_prompts(num_prompts=num_prompts) 98 | self.dataname_to_prompt_ids = {dataname: list(range(len(prompts))) for dataname, prompts in self.promptsource_dict.items()} 99 | elif self.prompt_source == "custom": 100 | if not args.specific_prompt: 101 | self.my_prompt_ids = self.get_random_prompts(seed, num_prompts, 102 | clamp_templates={'answers_idx': 0}) 103 | self.dataname_to_prompt_ids = {dataname: self.my_prompt_ids for dataname in self.datanames} # same prompt_ids for each task 104 | print("Using these prompts: ", self.my_prompt_ids) 105 | elif args.specific_prompt: 106 | assert num_prompts == 1, "If requesting a specific prompt, expect num_prompts==1" 107 | self.dataname_to_prompt_ids = {dataname: [args.specific_prompt] for dataname in self.datanames} 108 | 109 | def load_promptsource_prompts(self, num_prompts): 110 | prompts_dict = {} 111 | promptsource_datasets = ["imdb", "amazon_polarity", "ag_news", "dbpedia_14", "copa", "rte", "boolq", "piqa", "qnli"] 112 | for dataname in promptsource_datasets: 113 | load_name = data_utils.get_load_name(dataname) 114 | prompts = DatasetTemplates(*load_name) 115 | prompt_name_list = list(prompts.name_to_id_mapping.keys()) 116 | if num_prompts > 0: 117 | prompts = [prompts[name] for name in prompt_name_list[:num_prompts]] 118 | else: 119 | prompts = [prompts[name] for name in prompt_name_list] 120 | prompts_dict[dataname] = prompts 121 | return prompts_dict 122 | 123 | def get_str_prompt_templates(self, dataname=None, promptsource_idx=None, instr_idx=None, question_idx=None, answers_idx=None, label_idx=None): 124 | if self.prompt_source == "promptsource": 125 | x = {"text" : "", "label": ""} 126 | return self.promptsource_dict[dataname][promptsource_idx].apply(x) 127 | if self.prompt_source == "custom": 128 | return { 129 | "instructions": self.instructions[instr_idx], 130 | "question_template": self.question_templates[question_idx], 131 | "answer_choices_template": self.answer_choices_functions[answers_idx](4), # arbitrarily passing num_answers=4 to this function, bc this is only used for printing 132 | "label_template": self.label_templates[label_idx], 133 | } 134 | 135 | def get_random_prompts(self, seed, num_prompts = 10, clamp_templates=None): 136 | ''' 137 | returns list of dicts of prompt template idx, not template strs 138 | - always include a prompt with 'empty' instructions, question template, and label template, and which uses the bigbench answer_choices_template 139 | args: 140 | clamp_templates: dict like prompt_templates below that will clamp one of the templates to a specified value. e.g.
{'answers_idx': 0} means that no prompts will repeat all the answer choices in the prompt 141 | ''' 142 | prompt_rng = np.random.default_rng(seed) 143 | if num_prompts == 1: 144 | print("Note with only 1 prompt, it will always be the 'default' prompt template") 145 | prompt_templates = [ 146 | { 147 | "instr_idx": 0, 148 | "question_idx": 0, 149 | "answers_idx": 0, 150 | "label_idx": 0 151 | } 152 | ] 153 | seen_prompts = {"0000"} 154 | while len(prompt_templates) < num_prompts: 155 | instr_idx = prompt_rng.choice(np.arange(len(self.instructions))) 156 | question_idx = prompt_rng.choice(np.arange(len(self.question_templates))) 157 | answers_idx = prompt_rng.choice(np.arange(len(self.answer_choices_functions))) 158 | label_idx = prompt_rng.choice(np.arange(len(self.label_templates))) 159 | prompt_template = { 160 | "instr_idx": instr_idx, 161 | "question_idx": question_idx, 162 | "answers_idx": answers_idx, 163 | "label_idx": label_idx, 164 | } 165 | # override clamped values 166 | if clamp_templates is not None: 167 | assert all([k in prompt_template for k in clamp_templates.keys()]), "Clamp template keys do not match prompt_template keys" 168 | prompt_template.update(clamp_templates) 169 | combo_id = "".join([str(x) for x in prompt_template.values()]) 170 | if combo_id not in seen_prompts: 171 | prompt_templates.append(prompt_template) 172 | seen_prompts.add(combo_id) 173 | prompt_ids = ["".join([str(x) for x in templates_dict.values()]) for templates_dict in prompt_templates] 174 | return prompt_ids 175 | 176 | def get_prompt_kwargs_from_id(self, prompt_id, dataname=None): 177 | # get kwargs for format_prompt_from_df. prompt_idx is either a single int or a str of 4 ints 178 | if self.prompt_source == "promptsource": 179 | prompt_kwargs = {"promptsource_idx": int(prompt_id), "dataname": dataname} 180 | else: 181 | prompt_idx = [int(x) for x in str(prompt_id)] 182 | try: 183 | prompt_kwargs = { 184 | "instr_idx": prompt_idx[0], 185 | "question_idx": prompt_idx[1], 186 | "answers_idx": prompt_idx[2], 187 | "label_idx": prompt_idx[3], 188 | } 189 | except: 190 | import pdb; pdb.set_trace() 191 | return prompt_kwargs 192 | 193 | def format_example(self, question_template, answer_choices_function, label_template, question, answer_choices=None, label_str=None, cot_reason=None): 194 | ''' 195 | formats question/answer/label strs into the provided question/answer/label templates, combines them and returns a single string 196 | ''' 197 | # format the individual elements 198 | question = question_template.format(question) 199 | num_answers = len(answer_choices) 200 | answer_choices_template = answer_choices_function(num_answers) 201 | answer_choices = answer_choices_template.format(*answer_choices) 202 | if label_str is not None: 203 | label = label_template.format(label_str) 204 | # begin assembling input 205 | text = question 206 | if answer_choices != "": 207 | text += f" {answer_choices}" 208 | if self.use_cot: 209 | if cot_reason is not None and label is not None: 210 | text += f"{self.cot_preface} {cot_reason} {label}" 211 | else: 212 | text += f"{self.cot_preface}" 213 | else: 214 | if label_str is not None: 215 | text += f" {label}" 216 | return text 217 | 218 | def format_prompt(self, instructions, question_template, answer_choices_template, label_template, examples, test_input): 219 | ''' 220 | takes instructions, list of standardized examples, test_input, and templates for each of these, and formats an entire prompt for passing to an LM 221 | ''' 222 | # first format k examples 223 |
if len(examples) > 0: 224 | formatted_examples = [] 225 | for example in examples: 226 | question = example["question"] 227 | answer_choices = example["answer_choices"] 228 | label = ['A', 'B', 'C', 'D'][example["label_idx"]] if self.use_letter_labels else example["label"] 229 | cot_reason = example["reasoning"] if self.use_cot else None 230 | formatted_example = self.format_example(question_template, answer_choices_template, label_template, question, answer_choices, 231 | label, cot_reason) 232 | formatted_examples.append(formatted_example) 233 | formatted_examples = f"\n{self.between_point_buffer}".join(formatted_examples) + f'\n{self.between_point_buffer}' 234 | else: 235 | formatted_examples = "" 236 | # INSERT MANUAL/CUSTOM PROMPT HERE AS NEEDED (e.g. for debugging) 237 | # formatted_examples = """Q: A whole pizza was cut into 8 slices. Angeli and Marlon ate 3/2 slices each. How many slices of pizza are left? 238 | # A: Angeli and Marlon ate a total of 3/2 x 2 = 3 slices of pizza. Thus, 8 - 3 = 5 slices of pizza are left. So the answer is 5 239 | 240 | # Q: Every time she goes to the store, Felicity gets a lollipop. After she finishes them, she uses the sticks to build a fort. The fort needs 400 sticks to finish it. Her family goes to the store three times a week and she always goes. If the fort is 60% complete, how many weeks has Felicity been collecting lollipops for? 241 | # A: She has 240 sticks because 400 x .6 = 240. She has been going to the store for 80 weeks because 240 / 3 = 80. So the answer is 80 242 | 243 | # Q: Jane, Kyla, and Anthony have summer jobs in a resort. Their task is to fold guests' towels. Jane can fold 3 towels in 5 minutes. Kyla can fold 5 towels in 10 minutes, and Anthony can fold 7 towels in 20 minutes. If they all fold towels together, how many towels can they fold in one hour? 244 | # A: There are 1 x 60 minutes = 60 minutes in 1 hour. There are 60/5 = 12 sets of 5 minutes in 1 hour. So, Jane can fold 3 x 12 = 36 towels in an hour. There are 60/10 = 6 sets of 10 minutes in 1 hour. So, Kyla can fold 5 x 6 = 30 towels in an hour. There are 60/20 = 3 sets of 20 minutes in 1 hour. So, Anthony can fold 7 x 3 = 21 towels in an hour. Therefore, the 3 of them can fold a total of 36 + 30 + 21 = 87 towels in 1 hour. So the answer is 87 245 | 246 | # Q: At Sunshine Orchard, there are 12 more than three times the number of pumpkins at Moonglow Orchard. If Moonglow Orchard has 14 pumpkins how many are there at Sunshine Orchard? 247 | # A: Three times the number of pumpkins at Moonglow Orchard is 14*3= 42. Sunshine Orchard has 12+42= 54 pumpkins. 
So the answer is 54 248 | 249 | # """ 250 | # format test input 251 | question = test_input["question"] 252 | answer_choices = test_input["answer_choices"] 253 | formatted_test_input = self.format_example(question_template, answer_choices_template, label_template, question, answer_choices, 254 | label_str=None, cot_reason=None) 255 | prompt = formatted_examples + formatted_test_input 256 | # add empty label template to prompt if not doing CoT -- if doing CoT, the model will generate the text of the label template 257 | if not self.use_cot: 258 | label_template_no_label = label_template.format("") 259 | # avoid buffer space if label template starts with a line break 260 | if label_template_no_label[0] == '\n': 261 | prompt = prompt + label_template_no_label 262 | else: 263 | prompt = prompt + " " + label_template_no_label 264 | # remove trailing space if using a llama model 265 | if 'llama' in self.args.model.lower() and prompt[-1] == " ": 266 | prompt = prompt[:-1] 267 | # optionally add instructions 268 | if instructions != "": 269 | prompt = instructions + "\n" + prompt 270 | return prompt 271 | 272 | def standardize(self, point): 273 | return { 274 | 'question': point.input_text, 275 | 'answer_choices': point.answer_choices, 276 | 'label': point.answer_choices[point.label_idx], 277 | 'label_idx': point.label_idx, 278 | 'reasoning': getattr(point, 'reasoning', None), 279 | } 280 | 281 | def format_prompt_from_df(self, test_input_df, 282 | examples_df=None, 283 | dataname=None, promptsource_idx=None, 284 | instr_idx=None, question_idx=None, answers_idx=None, label_idx=None): 285 | ''' 286 | format data from a pd df for passing to an LM tokenizer. 287 | - uses default prompt templates if args are None 288 | - applies promptsource prompts if those were provided at init 289 | - intended for use inside Probe class 290 | - all_answers choices: return a list of prompts 291 | returns 292 | strings containing fully formatted prompt for passing to LM 293 | ''' 294 | # first option is to use promptsource prompts 295 | if self.prompt_source == "promptsource": 296 | prompt = self.promptsource_dict[dataname][promptsource_idx] 297 | # format test input 298 | test_input_df = test_input_df.drop('answer_choices') # need to drop standardized 'answer_choices' column at this point... 299 | test_input = prompt.apply(test_input_df)[0] # does not include label in input 300 | # combine with examples 301 | if examples_df is not None: 302 | # need to drop standardized 'answer_choices' column at this point... 
303 | q_and_a_s = [prompt.apply(example.drop('answer_choices')) for _, example in examples_df.iterrows()] 304 | formatted_examples = [f"{q} {a}" for q,a in q_and_a_s] 305 | prepend_examples = "\n".join(formatted_examples) + "\n" 306 | else: 307 | prepend_examples = "" 308 | full_prompt = prepend_examples + test_input 309 | # second option is to use our prompts 310 | elif self.prompt_source == "custom": 311 | # get template to use 312 | instr_idx = instr_idx if instr_idx is not None else self.default_instr_idx 313 | question_idx = question_idx if question_idx is not None else self.default_question_idx 314 | answers_idx = answers_idx if answers_idx is not None else self.default_answers_idx 315 | label_idx = label_idx if label_idx is not None else self.default_label_idx 316 | instructions = self.instructions[instr_idx] 317 | question_template = self.question_templates[question_idx] 318 | answer_template = self.answer_choices_functions[answers_idx] 319 | label_template = self.label_templates[label_idx] 320 | examples = [] 321 | if examples_df is not None: 322 | for _, example in examples_df.iterrows(): 323 | assert hasattr(example, 'input_text'), "Need to standardize columns of this data to have input_text" 324 | examples.append(self.standardize(example)) 325 | test_input = self.standardize(test_input_df) 326 | full_prompt = self.format_prompt(instructions, question_template, answer_template, label_template, examples, test_input) 327 | return full_prompt 328 | 329 | def format_reasoning_target_from_df(self, test_input_df, 330 | dataname=None, promptsource_idx=None, 331 | instr_idx=None, question_idx=None, answers_idx=None, label_idx=None): 332 | ''' 333 | formats the CoT reasons with a suffix according to the label template 334 | returns 335 | CoT ids fully formatted for passing as reasoning_chains to make_LM_batch 336 | ''' 337 | # get template to use 338 | assert self.prompt_source == 'custom' 339 | label_template = self.label_templates[label_idx] 340 | label_template_no_label = label_template.format("") 341 | reasoning = test_input_df.reasoning 342 | reasoning_target = f"{reasoning} {label_template_no_label}" 343 | # remove trailing space if using a llama model 344 | if 'llama' in self.args.model.lower() and reasoning_target[-1] == " ": 345 | reasoning_target = reasoning_target[:-1] 346 | return reasoning_target -------------------------------------------------------------------------------- /utils/training_logger.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import matplotlib.pyplot as plt 3 | import utils 4 | import os 5 | 6 | class TrainingLogger(): 7 | """ 8 | Report stores evaluation results during the training process as text files. 
9 | """ 10 | 11 | def __init__(self, args, file_path, experiment_name, overwrite_existing=True): 12 | self.fn = file_path 13 | self.args = args 14 | self.experiment_name = experiment_name 15 | self.max_len = 10 16 | self.old_running_time = 0 17 | self.curr_speed = 0 18 | self.columns = ["LR", "epoch", "train_loss", "dev_loss", "train_acc", "dev_acc"] 19 | self.log_records = [] 20 | # make training_logs dir if not made yet 21 | if not os.path.exists(args.log_dir): 22 | os.mkdir(args.log_dir) 23 | 24 | def add_to_log(self, stats): 25 | assert sorted(list(stats.keys())) == sorted(self.columns), f"please pass stats dict to log.add_to_log with keys: {self.columns} (found keys: {sorted(stats.keys())})" 26 | self.log_records.append(stats) 27 | self.log_df = pd.DataFrame.from_records(self.log_records) 28 | self.save_log() 29 | 30 | def get_last_eval_result(self): 31 | return self.log_df.iloc[-1] 32 | 33 | def save_log(self): 34 | self.log_df.to_csv(self.fn, index=False) 35 | 36 | def reset_log(self): 37 | self.log_records = [] 38 | self.log_df = pd.DataFrame(columns=self.columns) 39 | 40 | def print_training_prog(self, train_stats, epoch, num_epochs, batch_num, 41 | n_batches, running_time, est_epoch_run_time, forward_time, current_LR=None, gpu_mem=None): 42 | last_batch = batch_num == n_batches-1 43 | print_str = f" Epoch {epoch}/{num_epochs} | Batch: {batch_num+1}/{n_batches}" 44 | for k, v in train_stats.items(): 45 | if k.lower() == 'loss' or 'acc' in k: 46 | print_str += f" | {k.capitalize()}: {v:.2f}" 47 | if current_LR is not None: 48 | print_str += f" | LR: {current_LR:.6f} | Runtime: {running_time/60:.1f} min. / {est_epoch_run_time/60:.1f} min. | Forward time: {forward_time:.3f} sec." 49 | else: 50 | print_str += f" | Runtime: {running_time/60:.1f} min. / {est_epoch_run_time/60:.1f} min. | Forward time: {forward_time:.3f} sec." 51 | if gpu_mem: 52 | print_str += f" | Mem: {gpu_mem}" 53 | print(print_str, end='\r' if not last_batch else '\n') 54 | 55 | def print_epoch_scores(self, epoch, scores): 56 | epoch_text = ' %6s ' % 'epoch' 57 | scores = {k:v for k,v in scores.items() if k not in ['n_batches', 'forward_time_sum', 'label_confidence']} 58 | for n, score_name in enumerate(scores.keys()): 59 | len_name = len(score_name) 60 | if len_name > self.max_len: 61 | score_name = score_name[:self.max_len] 62 | epoch_text += '| %10s' % score_name if 'acc' not in score_name else '| %11s' % score_name 63 | epoch_text += '\n %6s ' % str(epoch) 64 | for score_name, score in scores.items(): 65 | if 'acc' in score_name: 66 | score *= 100 67 | epoch_text += '| %10s' % ('%3.2f' % score) + '%' 68 | elif not isinstance(score, list): 69 | epoch_text += '| %10s' % ('%1.2f' % score) 70 | print(epoch_text) 71 | 72 | def save_plots(self, n_train=None): 73 | # save plots of train_loss/eval_loss/eval_acc vs. 
steps/n_tokens/n_sentences/epochs 74 | # outcomes = ['train_loss', 'dev_loss', 'dev_acc'] 75 | outcomes = ['train_acc', 'dev_acc', 'train_loss', 'dev_loss'] 76 | x_vars = ['epoch'] 77 | # make single y vs x plots 78 | # for outcome in outcomes: 79 | # for x_var in x_vars: 80 | # plot_name = f"plt_{self.args.dataset}_{self.experiment_name}_{outcome}_vs_{x_var}" 81 | # plt.plot(self.log_df[x_var], self.log_df[outcome]) 82 | # plt.xlabel(x_var) 83 | # plt.ylabel(outcome) 84 | # plt.title(f"{outcome} vs {x_var}") 85 | # # save the plot to a PDF file 86 | # filepath = f'training_logs/{plot_name}.pdf' 87 | # plt.savefig(filepath) 88 | # plt.clf() 89 | 90 | # overlay the eval_acc, train_loss, and eval_loss variables 91 | for x_var in x_vars: 92 | fig, ax1 = plt.subplots() 93 | ax2 = ax1.twinx() 94 | n_train_insert = f"_n-{n_train}" if n_train else '' 95 | plot_name = f"plt_{self.args.dataset}{n_train_insert}_{self.experiment_name}_results_by_{x_var}" 96 | for outcome in outcomes: 97 | if 'loss' in outcome: 98 | ax1.plot(self.log_df[x_var], self.log_df[outcome], label=outcome) 99 | else: 100 | ax2.plot(self.log_df[x_var], self.log_df[outcome], label=outcome) 101 | ax1.set_xlabel(x_var) 102 | ax1.set_ylabel("Loss", rotation=0) 103 | max_loss = 10 if self.log_df['dev_loss'].min() > 5 else 5 104 | ax1.set_ylim(0, max_loss) 105 | ax2.set_xlabel(x_var) 106 | ax2.set_ylabel("Acc", rotation=0) 107 | ax2.set_ylim(0, 1.04) 108 | lines, labels = ax1.get_legend_handles_labels() 109 | lines2, labels2 = ax2.get_legend_handles_labels() 110 | ax2.legend(lines + lines2, labels + labels2, loc='best') 111 | fig.suptitle(f"Model Performance vs. {x_var}") 112 | 113 | # find the peak value of the curve 114 | peak_value = self.log_df['dev_acc'].max() 115 | peak_index = self.log_df['dev_acc'].idxmax() 116 | peak_x_val = self.log_df[x_var].iloc[peak_index] 117 | ax2.text(.5, 1.03, f'acc: {peak_value:.2f} (at x={int(peak_x_val)})', 118 | transform=ax2.transAxes, horizontalalignment='center') 119 | ax2.axhline(y=1, color='black', linestyle='--', linewidth=0.5) 120 | ax2.axvline(x=peak_x_val, color='red', linestyle='--', linewidth=0.5) 121 | 122 | # save the plot to a PDF file 123 | filepath = f'training_logs/{plot_name}' 124 | # plt.savefig(filepath+'.pdf', format='pdf') 125 | plt.savefig(filepath+'.png', format='png') 126 | plt.clf() 127 | ax1.cla() 128 | ax2.cla() 129 | plt.close() -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import pandas as pd 4 | import torch 5 | import os 6 | import re 7 | import pynvml 8 | import globals 9 | from itertools import chain 10 | 11 | from peft import get_peft_model, LoraConfig, TaskType, IA3Config, AutoPeftModelForCausalLM 12 | from transformers import AutoConfig, LlamaForCausalLM, AutoModelForSeq2SeqLM, AutoModelForCausalLM 13 | 14 | def gather_item_level_stats_df(id_no_boot_to_item_accs, n_train, dataname, prompt_id, hardness_col_name=None): 15 | # get all the item level test accs across bootstrapped test samples 16 | test_stats_df = id_no_boot_to_item_accs[dataname][prompt_id][n_train][0].loc[:, ['accuracy']].copy() 17 | for i in range(1, len(id_no_boot_to_item_accs[dataname][prompt_id][n_train])): 18 | join_df = id_no_boot_to_item_accs[dataname][prompt_id][n_train][i].loc[:, ['accuracy']].rename(columns={'accuracy': f'accuracy{i}'}) 19 | test_stats_df = pd.merge(test_stats_df, join_df, left_index=True, 
right_index=True, how='outer') 20 | # add hardness column variable for item level stats 21 | if hardness_col_name is not None: 22 | hardness_col_level = hardness_col_name + '_level' 23 | test_stats_df[hardness_col_name] = id_no_boot_to_item_accs[dataname][prompt_id][n_train][0][hardness_col_name] 24 | test_stats_df[hardness_col_level] = id_no_boot_to_item_accs[dataname][prompt_id][n_train][0][hardness_col_level] 25 | return test_stats_df 26 | 27 | def get_hardness_col_names(model_name, normed=False): 28 | ''' 29 | Gets the final list of model-based hardness cols that are written to probing_data, based on the argument supplied as args.hardness_var_name 30 | ''' 31 | human_variables = ['human_hardness', 32 | 'human_grade', 'human_difficulty', 'human_bloom', # 'human_depth_of_knowledge', 33 | 'num_steps', 'question_num_words', 'answer_num_words', 'reasoning_num_words', 'answer_num_chars'] 34 | hardness_metrics = [ 35 | f'MDL_finetuned_{model_name}', 36 | f'MDL_learned_{model_name}', 37 | f'MDL_decoding_{model_name}', 38 | f'question_prob_{model_name}', 39 | f'answer_prob_{model_name}', 40 | f'reasoning_prob_{model_name}', 41 | f'MDL_finetuned_model-avg', 42 | f'MDL_learned_model-avg', 43 | f'MDL_decoding_model-avg', 44 | f'question_prob_model-avg', 45 | f'answer_prob_model-avg', 46 | f'reasoning_prob_model-avg' 47 | ] 48 | if normed: 49 | hardness_metrics = [x + "_NORMED" for x in hardness_metrics] 50 | return human_variables + hardness_metrics 51 | 52 | def get_all_possible_hardness_col_names(model_name): 53 | hardness_col_names = get_hardness_col_names(model_name, normed=False) 54 | hardness_col_names = list(chain(*[[hardness_col_name + "_mean", hardness_col_name + "_std"] for hardness_col_name in hardness_col_names])) 55 | hardness_col_names = list(chain(*[[hardness_col_name + "_TRAIN", hardness_col_name + "_TEST"] for hardness_col_name in hardness_col_names])) 56 | return hardness_col_names 57 | 58 | def get_mean_std_metrics_from_df(text_data, hardness_var_names, postfix=""): 59 | # used for adding avg and std of each hardness variable to the train and test dataframes, for saving with probing results 60 | if text_data is None: 61 | hardness_properties = { 62 | hardness_var_name + "_mean" + postfix: None 63 | for hardness_var_name in hardness_var_names 64 | } 65 | hardness_properties.update({ 66 | hardness_var_name + "_std" + postfix: None 67 | for hardness_var_name in hardness_var_names 68 | }) 69 | else: 70 | hardness_properties = { 71 | hardness_var_name + "_mean" + postfix: text_data[hardness_var_name].mean() 72 | for hardness_var_name in hardness_var_names if hardness_var_name in text_data.columns 73 | } 74 | hardness_properties.update({ 75 | hardness_var_name + "_std" + postfix: text_data[hardness_var_name].std() 76 | for hardness_var_name in hardness_var_names if hardness_var_name in text_data.columns 77 | }) 78 | return hardness_properties 79 | 80 | def average_df_over_metrics(df, grouping_vars, metrics_vars): 81 | # averages the metrics_vars columns in a df, while keeping grouping_vars 82 | collapsed_dfs = [] 83 | for metric in metrics_vars: 84 | if metric in df.columns: 85 | avg_df = df.groupby(grouping_vars)[metric].mean().reset_index() 86 | collapsed_dfs.append(avg_df) 87 | joined_df = collapsed_dfs[0] 88 | for collapsed_df in collapsed_dfs[1:]: 89 | joined_df = joined_df.merge(collapsed_df) 90 | return joined_df 91 | 92 | def get_model_size(model): 93 | match = re.search(r'(\d+[bB])', model) 94 | if match: 95 | size = re.search(r'(\d+[bB])', model).group(1) 96 | match = 
re.search(r'(\d+)x(\d+)[bB]', model) 97 | if match: 98 | num1, num2 = map(int, match.groups()) 99 | size = f"{num1 * num2}b" 100 | return size.lower() 101 | return size.lower() 102 | 103 | def get_hardness_col_name(metric, model_name, model_avg=False): 104 | # maps from args.hadness_var_name to the column in the dataset, based on model_name 105 | if 'human' in metric or 'num_steps' in metric or 'words' in metric or 'chars' in metric: 106 | return metric 107 | else: 108 | metric_short = metric.replace("_avg", "") 109 | if 'model_based' in metric_short: 110 | metric_short = metric_short.replace("model_based", "MDL") 111 | if model_avg: 112 | short_model_name = 'model-avg' 113 | else: 114 | short_model_name = model_name.split('/')[-1] 115 | hardness_col_name = f"{metric_short}_{short_model_name}" 116 | return hardness_col_name 117 | 118 | def PEFT_wrap_model(args, model): 119 | manually_add = ['mistral', 'falcon', 'persimmon', 'mpt', 'qwen'] 120 | if args.optimize_weights in ['LORA', 'IA3']: 121 | task_type = TaskType.SEQ_2_SEQ_LM if args.model_type == 'encoder_decoder' else TaskType.CAUSAL_LM 122 | if args.optimize_weights == 'LORA': 123 | peft_config = LoraConfig( 124 | task_type=task_type, inference_mode=False, r=16, lora_alpha=32, lora_dropout=0.1 125 | ) 126 | if args.optimize_weights == 'IA3': 127 | peft_config = IA3Config(task_type=task_type, inference_mode=False) 128 | if any(x in args.model.lower() for x in manually_add): 129 | if 'mistral' in args.model.lower() or 'mixtral' in args.model.lower(): 130 | peft_config.target_modules = ["q_proj", "v_proj"] 131 | elif 'persimmon' in args.model.lower(): 132 | peft_config.target_modules = ["query_key_value", "dense"] 133 | elif 'mpt' in args.model.lower(): 134 | peft_config.target_modules = ["Wqkv"] 135 | elif 'falcon' in args.model.lower(): 136 | peft_config.target_modules = ["query_key_value"] 137 | elif 'qwen' in args.model.lower(): 138 | peft_config.target_modules = ["c_attn"] 139 | assert any([x in args.model.lower() for x in ['llama', 'gpt-j', 'mistral', 'persimmon', 'mpt', 'falcon', 'qwen']]), f"\nNeed to add QLoRA params to peft_config manually -- add exact q_proj and v_proj layer paths to peft_config.target_modules = [paths] from the model: \n{model} \n(SEE MESSAGE ABOVE)" 140 | model = get_peft_model(model, peft_config) 141 | return model 142 | 143 | def load_model(args, save_load_path=None, first_load=False): 144 | model_type_dict = {'encoder-decoder': AutoModelForSeq2SeqLM, 'decoder': AutoModelForCausalLM} 145 | model_type = model_type_dict[args.model_type] 146 | short_model_name = shorten_model_name(args.model) 147 | size = get_model_size(args.model) 148 | load_8bit = args.quantization == '8bit' 149 | load_4bit = args.quantization == '4bit' 150 | # load from a trained model 151 | load_from_trained_model = not first_load and save_load_path is not None and os.path.exists(save_load_path) 152 | if load_from_trained_model: 153 | print(f"Loading from path: {save_load_path}") 154 | final_folder = size.upper() 155 | if 'chat' in args.model: 156 | final_folder += '-chat' 157 | llama2_path = f"" 158 | maybe_get_config_here = llama2_path if 'Llama-2' in args.model else None 159 | model_config = get_custom_config(args, maybe_get_config_here) 160 | if args.quantization == 'NA': 161 | model = model_type.from_config(model_config) 162 | state_dict = torch.load(save_load_path) 163 | model.load_state_dict(state_dict) 164 | elif args.quantization in ['4bit', '8bit']: 165 | if args.optimize_weights=='LORA': 166 | task_type = TaskType.SEQ_2_SEQ_LM 
if args.model_type == 'encoder_decoder' else TaskType.CAUSAL_LM 167 | peft_config = LoraConfig(task_type=task_type, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1) 168 | model = AutoPeftModelForCausalLM.from_pretrained(pretrained_model_name_or_path=save_load_path, config=peft_config, cache_dir=None, device_map='auto', load_in_4bit=load_4bit, load_in_8bit=load_8bit, low_cpu_mem_usage=True) 169 | else: 170 | model = model_type.from_pretrained(save_load_path, config=model_config, cache_dir=None, device_map='auto', load_in_4bit=load_4bit, load_in_8bit=load_8bit, low_cpu_mem_usage=True) 171 | elif args.quantization == '16bit': 172 | model = model_type.from_pretrained(save_load_path, config=model_config, cache_dir=None, device_map='auto', torch_dtype=torch.float16) 173 | # load local Llama-2 weights 174 | elif 'Llama-2' in args.model: 175 | final_folder = size.upper() 176 | if 'chat' in args.model: 177 | final_folder += '-chat' 178 | llama2_path = f"" 179 | model_config = get_custom_config(args, llama2_path) 180 | model_type = LlamaForCausalLM 181 | if args.quantization in ['NA', '4bit', '8bit']: 182 | model = model_type.from_pretrained(llama2_path, config=model_config, cache_dir=None, device_map='auto', 183 | load_in_4bit=load_4bit, load_in_8bit=load_8bit, low_cpu_mem_usage=True) 184 | elif args.quantization == '16bit': 185 | model = model_type.from_pretrained(save_load_path, config=model_config, cache_dir=None, device_map='auto', torch_dtype=torch.float16) 186 | # load new or pretrained model 187 | else: 188 | # load config (has some custom adjustments to set dropout to 0) 189 | model_config = get_custom_config(args) 190 | if args.quantization in ['NA', '4bit', '8bit']: 191 | model = model_type.from_pretrained(args.model, config=model_config, cache_dir=args.cache_dir, device_map='auto', 192 | load_in_4bit=load_4bit, load_in_8bit=load_8bit, low_cpu_mem_usage=True, trust_remote_code=True) 193 | elif args.quantization == '16bit': 194 | model = model_type.from_pretrained(args.model, config=model_config, cache_dir=args.cache_dir, device_map='auto', 195 | torch_dtype=torch.float16, low_cpu_mem_usage=True, trust_remote_code=True) 196 | if short_model_name not in globals.model_to_hidden_size: 197 | print(f"You should add model hidden size of {model_config.hidden_size} to globals.model_to_hidden_size for {args.model}") 198 | model = model.eval() 199 | if not args.quantization: 200 | model = model.to(args.device) 201 | return model 202 | 203 | def standardize_optimization_config(n_train, num_answers, max_batch_size, 204 | grad_accumulation_factor=None, probing_epochs=None, 205 | minimum_grad_updates=10): 206 | ''' 207 | We want to standardize the amount and manner of optimization applied during finetuning, for different n_train sizes. 208 | The goal is to get an effective batch size of up to 50, and generally apply 5 epochs of finetuning. 
209 | The effective batch size is capped by the number of items in the training data 210 | ''' 211 | num_points = n_train*num_answers 212 | if num_points < max_batch_size: 213 | print("Dataset size * num_answers smaller than requested batch size...capping batch size by train data size and setting grad accumulation factor to 1.") 214 | dataloader_batch_size = num_points 215 | effective_batch_size = num_points 216 | gaf = 1 217 | else: 218 | items_per_batch = max_batch_size // num_answers 219 | effective_batch_size = min(n_train, 50) 220 | dataloader_batch_size = max_batch_size 221 | gaf = effective_batch_size // items_per_batch 222 | gaf = gaf if not grad_accumulation_factor else grad_accumulation_factor 223 | # initial set of probing_epochs 224 | probing_epochs = 3 if probing_epochs is None or probing_epochs <= 0 else probing_epochs 225 | # now get the number of updates per batch, and ensure it is at least minimum_grad_updates 226 | updates_per_epoch = int(np.ceil(n_train / effective_batch_size)) 227 | total_updates = updates_per_epoch * probing_epochs 228 | if total_updates < minimum_grad_updates: 229 | probing_epochs = int(np.ceil(minimum_grad_updates / updates_per_epoch)) 230 | optimization_config = { 231 | 'probing_epochs' : probing_epochs, 232 | 'train_batch_size': dataloader_batch_size, 233 | 'grad_accumulation_factor': gaf 234 | } 235 | return optimization_config 236 | 237 | def get_hardness_experiment_name(args, model_override=None, method_override=None): 238 | # make hardness experiment name. Also used in naming hardness variable columns 239 | model_name = shorten_model_name(args.model) if not model_override else shorten_model_name(model_override) 240 | use_method = method_override if method_override else args.hardness_method 241 | if use_method == 'learned': 242 | probe_enc_dec = "-".join([x[:3] for x in args.model_type.split('_')]) 243 | probe_layers = 'midlast' if args.probing_layers == 'middle_and_last' else args.probing_layers 244 | probing_insert = f"_{args.hardness_probe_model}" 245 | probing_insert += f"_{probe_enc_dec}" 246 | probing_insert += f"_lyrs-{probe_layers}" 247 | probing_insert += f"_mt-{str(args.probing_multitask)[0]}_mp-{str(args.probing_multiprompt)[0]}" 248 | _probing_style = 'learned' if args.probe_loss != 'random' else 'random' 249 | elif use_method == 'decoding': 250 | _probing_style = 'decoding' 251 | probing_insert = "" 252 | elif use_method == 'finetuned': 253 | _probing_style = 'finetuned' 254 | probing_insert = f"-{args.optimize_weights}" 255 | experiment_name = f"{model_name}_{_probing_style}" + \ 256 | f"{probing_insert}" + \ 257 | f"_prompts-{args.num_prompts}" + \ 258 | f"_boots-{args.hardness_bootstraps}" + \ 259 | f"_sd{args.seed}" 260 | if args.debug: 261 | experiment_name += "_DEBUG" 262 | return experiment_name 263 | 264 | def get_experiment_name(args): 265 | # make experiment name 266 | model_name = shorten_model_name(args.model) 267 | if args.probing_method == 'learned': 268 | probe_enc_dec = "-".join([x[:3] for x in args.model_type.split('_')]) 269 | probe_layers = 'midlast' if args.probing_layers == 'middle_and_last' else args.probing_layers 270 | probing_insert = f"_{args.probe_model}" 271 | probing_insert += f"_{probe_enc_dec}" 272 | probing_insert += f"_lyrs-{probe_layers}" 273 | probing_insert += f"_mt-{str(args.probing_multitask)[0]}_mp-{str(args.probing_multiprompt)[0]}" 274 | _probing_style = 'learned' if args.probe_loss != 'random' else 'random' 275 | if args.probe_loss == 'supervised': 276 | loss_insert = "" 277 | elif 
args.probe_loss == 'CCS': 278 | loss_insert = "_CCS" 279 | elif args.probe_loss == 'CCS_ours': 280 | loss_insert = "_CCS-ours" 281 | elif args.probe_loss == 'unsupervised': 282 | loss_insert = "_unsup" 283 | elif args.probe_loss == 'random': 284 | loss_insert = "" # probing_style gets edited above 285 | elif args.probe_loss == 'mixed-supervision': 286 | loss_insert = '_mixed-sup' 287 | elif args.probe_loss == "LM_loss": 288 | loss_insert = "" 289 | elif args.probing_method == 'decoding': 290 | probing_insert = "" 291 | if args.k_shot == 0: 292 | probing_insert += "_ZS" 293 | _probing_style = 'decoding' 294 | loss_insert = "" 295 | elif args.probing_method == 'finetuned': 296 | _probing_style = 'finetuned' 297 | probing_insert = f"-{args.optimize_weights}" 298 | loss_insert = f"-{args.finetuning_objective}" 299 | if args.use_cot: 300 | probing_insert += '-CoT' 301 | if args.probing_learning_curve: 302 | probing_insert += "-LCs" 303 | if args.noise_labels_p > 0: 304 | probing_insert += f"_noise-{args.noise_labels_p}" 305 | if args.force_test_dataname != 'NA': 306 | probing_insert += f"_test-{args.force_test_dataname}" 307 | if args.stratify_hardness: 308 | if args.record_results_by_hardness: 309 | strat_insert = f"_train-{args.train_on}_test-all-splits" 310 | else: 311 | strat_insert = f"_train-{args.train_on}_test-{args.test_on}" 312 | if args.hardness_var_name != 'model-based': 313 | strat_insert += f"_{args.hardness_var_name}" 314 | if args.standardize_sample_sizes: 315 | strat_insert += "_sss" 316 | if args.use_extra_easy_data: 317 | strat_insert += "_extra-easy" 318 | else: 319 | strat_insert = f"_full-distr" 320 | experiment_name = f"{model_name}_{_probing_style}" + \ 321 | f"{probing_insert}" + \ 322 | f"{loss_insert}" + \ 323 | f"{strat_insert}" + \ 324 | f"_prompts-{args.num_prompts}" + \ 325 | f"_boots-{args.probing_bootstraps}" + \ 326 | f"_sd{args.seed}" 327 | if args.debug: 328 | experiment_name += "_DEBUG" 329 | return experiment_name 330 | 331 | def shorten_model_name(model_name): 332 | model_name = model_name.replace('facebook/', '') 333 | model_name = model_name.replace('meta-llama/', '') 334 | model_name = model_name.replace('tiiuae/', '') 335 | model_name = model_name.replace('EleutherAI/', '') 336 | if '/' in model_name: 337 | model_name = model_name.split('/')[-1] 338 | return model_name 339 | 340 | def get_gpu_utilization(): 341 | pynvml.nvmlInit() 342 | handle = pynvml.nvmlDeviceGetHandleByIndex(0) 343 | info = pynvml.nvmlDeviceGetMemoryInfo(handle) 344 | return f"{info.used//1024**3} GB." 345 | 346 | def get_mem(): 347 | pynvml.nvmlInit() 348 | handle = pynvml.nvmlDeviceGetHandleByIndex(0) 349 | info = pynvml.nvmlDeviceGetMemoryInfo(handle) 350 | return f"{info.used//1024**3} GB." 
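# --- Editor's note: the block below is an illustrative sketch added to this listing, not part of the
# --- original utils.py. It works through standardize_optimization_config (defined earlier in this file)
# --- on two hypothetical input sizes, to make the batch-size / grad-accumulation / epoch logic concrete.
def _example_standardize_optimization_config():
    # 160 training items with 4 answer choices each, and a GPU that fits 16 sequences per forward pass:
    # num_points = 640 >= 16, so items_per_batch = 16 // 4 = 4, the effective batch size is min(160, 50) = 50,
    # grad accumulation is 50 // 4 = 12, and the default 3 epochs already give ceil(160/50) * 3 = 12 >= 10 updates.
    cfg = standardize_optimization_config(n_train=160, num_answers=4, max_batch_size=16)
    assert cfg == {'probing_epochs': 3, 'train_batch_size': 16, 'grad_accumulation_factor': 12}
    # A tiny training set (2 items): the dataloader batch is capped at n_train * num_answers = 8, and epochs
    # are raised to 10 so that at least minimum_grad_updates = 10 gradient steps are still taken.
    cfg = standardize_optimization_config(n_train=2, num_answers=4, max_batch_size=16)
    assert cfg == {'probing_epochs': 10, 'train_batch_size': 8, 'grad_accumulation_factor': 1}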
351 | 352 | def check_nan_weights(model): 353 | for param in model.parameters(): 354 | if torch.isnan(param.data).any(): 355 | return True 356 | return False 357 | 358 | def get_custom_config(args, weights_location=None): 359 | # define variables for custom model configs 360 | if weights_location is not None: 361 | from_pretrained_or_path = os.path.join(weights_location, 'config.json') 362 | else: 363 | from_pretrained_or_path = args.model 364 | config = AutoConfig.from_pretrained(from_pretrained_or_path, cache_dir=args.cache_dir, trust_remote_code=True) 365 | # edit config 366 | if args.dropout >= 0: 367 | allowed_models = ['facebook/opt', 'gpt2', 'gpt-j', 'falcon', 'llama', 'Llama', 'mpt', 'adept', 'persimmon', 'mistral', 'qwen'] 368 | assert any([x in args.model.lower() for x in allowed_models]), f"If overriding dropout during training, need to use model in {allowed_models} or extend this in utils.get_custom_config. See config options: {config}" 369 | for k,v in config.__dict__.items(): 370 | if 'pdrop' in k or 'dropout' in k: 371 | setattr(config, k, args.dropout) 372 | return config 373 | 374 | def str2bool(v): 375 | # used for boolean argparse values 376 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 377 | return True 378 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 379 | return False 380 | else: 381 | raise argparse.ArgumentTypeError('Boolean value expected.') 382 | 383 | def print_rounded_array(array): 384 | print([round(x,2) for x in array]) 385 | 386 | def chunk_array(array, size): 387 | # return chunks from the array of size=size, in left to right order 388 | # if array % size != 0, then last components of the array are also added, but will not be of size=size 389 | if len(array) <= size: 390 | return [array] 391 | start_idx = 0 392 | chunks = [] 393 | for end_idx in range(1, len(array)+1): 394 | if end_idx % size == 0 or end_idx == len(array): 395 | chunks.append(array[start_idx:end_idx]) 396 | start_idx = end_idx 397 | return chunks 398 | 399 | def min_max_mean(array): 400 | return {'min': round(np.min(array),2), 'mean:': round(np.mean(array),2), 'max:': round(np.max(array),2)} 401 | 402 | def move_kwargs_to_gpu(kwargs): 403 | for k,v in kwargs.items(): 404 | if type(v) is torch.Tensor: 405 | kwargs[k] = v.cuda(non_blocking=True) 406 | 407 | def get_model_save_load_path(args): 408 | experiment = get_experiment_name(args) 409 | model_path = os.path.join(args.model_dir, f"{experiment}.pt") 410 | return model_path 411 | 412 | def format_time(x): 413 | time_diff = x / 60 414 | unit = 'minutes' if time_diff < 60 else 'hours' 415 | time_diff = time_diff if time_diff < 60 else time_diff / 60 416 | time_msg = f"{time_diff:.2f} {unit}" 417 | return time_msg 418 | 419 | def str2arg(v): 420 | if v.lower() in ('yes', 'true', 't', 'y') + ('no', 'false', 'f', 'n'): 421 | return str2bool(v) 422 | else: 423 | try: 424 | if float(v) % 1 == 0: 425 | return int(float(v)) 426 | else: 427 | return float(v) 428 | except: 429 | return v 430 | 431 | def args_from_cli_command(command): 432 | class DummyArgs: 433 | pass 434 | dummy_args = DummyArgs() 435 | command_dict = {} 436 | items = command.split() 437 | for idx, item in enumerate(items): 438 | if idx == len(items)-1: 439 | break 440 | if item[:2] == '--': 441 | k = item[2:] 442 | v = items[idx+1] 443 | command_dict[k] = str2arg(v) 444 | elif item[0] == '-': 445 | k = item[1:] 446 | v = items[idx+1] 447 | command_dict[k] = str2arg(v) 448 | for k,v in command_dict.items(): 449 | setattr(dummy_args, k, v) 450 | return dummy_args 451 | 452 
| def get_experiment_name_from_command(command): 453 | args = args_from_cli_command(command) 454 | experiment_name = get_experiment_name(args) 455 | args.short_model_name = args.model.split('/')[-1] 456 | return experiment_name, args 457 | 458 | def get_hardness_experiment_name_from_command(command): 459 | args = args_from_cli_command(command) 460 | experiment_name = get_hardness_experiment_name(args) 461 | args.short_model_name = args.model.split('/')[-1] 462 | return experiment_name, args --------------------------------------------------------------------------------
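Editor's usage note (not part of the repository): the short sketch below shows how args_from_cli_command and str2arg from utils/utils.py turn a run_jobs.py-style command string into a lightweight args object. The command string and its flags are hypothetical, and the import assumes utils/utils.py is importable as utils, as in utils/training_logger.py.

import utils

# Hypothetical command string; each "--flag value" or "-flag value" pair becomes an attribute on the returned object.
command = "python main.py --model meta-llama/Llama-2-7b-hf --n_train 100 --use_cot false -lr 1e-4"
args = utils.args_from_cli_command(command)
# str2arg coerces each value: 'false' -> False, '100' -> 100, '1e-4' -> 0.0001; other strings pass through unchanged.
print(args.model)    # meta-llama/Llama-2-7b-hf
print(args.n_train)  # 100
print(args.use_cot)  # False
print(args.lr)       # 0.0001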