├── .gitignore ├── LICENSE ├── README.md ├── assets ├── cover.gif └── cover.mp4 ├── base_model ├── __init__.py ├── base.py ├── llama3instruct.py └── mistral03instruct.py ├── cfgs ├── base_model │ ├── llama3i8b.yaml │ └── mistral03i7b.yaml ├── config.yaml ├── mode │ ├── eval.yaml │ └── training.yaml ├── optimization │ ├── cem.yaml │ ├── reinforce.yaml │ └── rsm.yaml ├── policy │ ├── default.yaml │ └── wcomb.yaml └── task │ ├── ablation_tasks │ ├── few_shot_arc_challenge_20.yaml │ ├── few_shot_arc_challenge_3.yaml │ └── few_shot_arc_challenge_5.yaml │ ├── ai2_arc.yaml │ ├── cls.yaml │ ├── few_shot_arc_challenge.yaml │ ├── few_shot_humaneval.yaml │ ├── few_shot_math.yaml │ ├── gsm8k.yaml │ ├── math.yaml │ └── mbpp2.yaml ├── evaluation └── fishfarm │ ├── fishfarm │ ├── __init__.py │ ├── chat_templates.py │ ├── imports.py │ ├── logging.py │ ├── models │ │ ├── __init__.py │ │ ├── base.py │ │ ├── tokenization_utils.py │ │ └── vllm_model.py │ ├── tasks │ │ ├── __init__.py │ │ ├── ai2_arc.py │ │ ├── base.py │ │ ├── competation_math.py │ │ ├── evalplus │ │ │ ├── __init__.py │ │ │ ├── data.py │ │ │ ├── evaluation.py │ │ │ ├── generation.py │ │ │ ├── sanitization.py │ │ │ └── task.py │ │ └── language_restricted_math.py │ └── version.py │ ├── pyproject.toml │ └── tox.ini ├── logging_utils.py ├── optim_modules.py ├── policy ├── __init__.py ├── base.py └── weighted_combination.py ├── requirements.txt ├── scripts ├── eval_few_shot.sh ├── eval_prompt_based.sh └── train_task_expert.sh ├── svd_reinforce_hydra.py ├── tasks ├── __init__.py ├── arc.py ├── base.py ├── cls.py ├── gsm8k.py ├── math.py └── mbpp2.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | playground* 2 | wandb 3 | results 4 | results_eval 5 | saved_models 6 | outputs 7 | reference_code 8 | messy_scripts 9 | *_decomposed_params.pt 10 | **/__pycache__/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 |

# Transformer²: Self-adaptive LLMs 🐙

3 | 4 |

5 | 📚 [Paper] | 6 | 📄 [Blog] 7 |

8 | 9 | Self-adaptive large language models (LLMs) aim to solve the challenges posed by traditional fine-tuning methods, which are often computationally intensive and static in their ability to handle diverse tasks. 10 | 11 | We are excited to introduce Transformer², a novel self-adaptation framework that adapts LLMs for unseen tasks in real time by selectively adjusting only the singular components of their weight matrices. 12 | During inference, Transformer² employs a two-pass mechanism: first, a dispatch system identifies the task properties, and then task-specific "expert" vectors, trained using reinforcement learning, are dynamically mixed to obtain targeted behavior for the incoming prompt. 13 |

14 | 15 |
16 |
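The expert vectors act on the singular values of the base model's weight matrices, whose decompositions each `base_model` class references via a `*_decomposed_params.pt` file. The snippet below is a minimal, self-contained sketch of that mixing idea on a single matrix; the tensor names and the fixed mixing weight are illustrative assumptions, not the repository's actual API (see `policy/weighted_combination.py` and `svd_reinforce_hydra.py` for the real implementation).

```python
# Illustrative sketch only (not the repository's API): adapt one weight matrix
# by rescaling its singular values with a mixture of task-expert vectors.
import torch

W = torch.randn(64, 64)                        # a frozen base-model weight matrix
U, S, Vh = torch.linalg.svd(W, full_matrices=False)

z_math = torch.rand_like(S)                    # hypothetical expert vectors; in the
z_code = torch.rand_like(S)                    # real system they are trained with RL

alpha = 0.7                                    # mixing weight chosen after the first
z = alpha * z_math + (1 - alpha) * z_code      # (dispatch) pass identifies the task

W_adapted = U @ torch.diag(S * z) @ Vh         # only singular components are adjusted
```

The second inference pass then answers the prompt with the adapted weights in place of the originals.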
17 | 18 | 19 | ## Installation 20 | 21 | ### 1. Clone the Repo 22 | ```bash 23 | git clone https://github.com/SakanaAI/self-adaptive-llms 24 | cd self-adaptive-llms 25 | ``` 26 | 27 | ### 2. Install Libraries 28 | ```bash 29 | conda create -n t2 python=3.11 -y 30 | conda activate t2 31 | pip install --upgrade pip 32 | pip install -r requirements.txt 33 | ``` 34 | 35 | ### 3. Install Tasks Evaluator 36 | ```bash 37 | cd evaluation/fishfarm 38 | pip install -e . 39 | ``` 40 | 41 | ## Usage 42 | We provide example scripts for both training and evaluation. 43 | 44 | Please change the arguments in the provided scripts to choose among models and tasks. 45 | 46 | ### Training 47 | 48 | ```bash 49 | bash scripts/train_task_expert.sh 50 | ``` 51 | 52 | ### Evaluation 53 | 54 | #### Prompt-based evaluation 55 | Classification experts can be loaded by specifying `CLS_EXPERT_PATH` in the script. 56 | ```bash 57 | bash scripts/eval_prompt_based.sh 58 | ``` 59 | 60 | #### Few-shot evaluation 61 | ```bash 62 | bash scripts/eval_few_shot.sh 63 | ``` 64 | 65 | ## Citation 66 | If you find **Transformer²** useful for your research, please cite using this BibTeX: 67 | ``` 68 | @misc{sun2025transformersquaredselfadaptivellms, 69 | title={Transformer-Squared: Self-adaptive LLMs}, 70 | author={Qi Sun and Edoardo Cetin and Yujin Tang}, 71 | year={2025}, 72 | eprint={2501.06252}, 73 | archivePrefix={arXiv}, 74 | primaryClass={cs.LG}, 75 | url={https://arxiv.org/abs/2501.06252}, 76 | } 77 | ``` 78 | -------------------------------------------------------------------------------- /assets/cover.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SakanaAI/self-adaptive-llms/03a41aed1cfc57276e72ad5a42845a04b356db1e/assets/cover.gif -------------------------------------------------------------------------------- /assets/cover.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SakanaAI/self-adaptive-llms/03a41aed1cfc57276e72ad5a42845a04b356db1e/assets/cover.mp4 -------------------------------------------------------------------------------- /base_model/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import BaseModel 2 | from .llama3instruct import Llama3Instruct8B 3 | from .mistral03instruct import MistralV03Instruct7B 4 | -------------------------------------------------------------------------------- /base_model/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | 4 | class BaseModel(ABC): 5 | def __init__(self): 6 | pass 7 | 8 | @abstractmethod 9 | def get_model_id(self): 10 | raise NotImplementedError 11 | 12 | @abstractmethod 13 | def get_model_name(self): 14 | raise NotImplementedError 15 | 16 | @abstractmethod 17 | def get_param_file(self, param_folder_path=""): 18 | raise NotImplementedError 19 | -------------------------------------------------------------------------------- /base_model/llama3instruct.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .base import BaseModel 4 | 5 | 6 | class Llama3Instruct8B(BaseModel): 7 | def __init__(self): 8 | self.model_id = "meta-llama/Meta-Llama-3-8B-Instruct" 9 | self.dec_param_file_n = "llama3_decomposed_params.pt" 10 | 11 | def get_model_id(self): 12 | return self.model_id 13 | 14 | def get_model_name(self): 15 | return
self.model_id.split("/")[1] 16 | 17 | def get_param_file(self, param_folder_path=""): 18 | return os.path.join(param_folder_path, self.dec_param_file_n) 19 | -------------------------------------------------------------------------------- /base_model/mistral03instruct.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .base import BaseModel 4 | 5 | 6 | class MistralV03Instruct7B(BaseModel): 7 | def __init__(self): 8 | self.model_id = "mistralai/Mistral-7B-Instruct-v0.3" 9 | self.dec_param_file_n = "mistral_decomposed_params.pt" 10 | 11 | def get_model_id(self): 12 | return self.model_id 13 | 14 | def get_model_name(self): 15 | return self.model_id.split("/")[1] 16 | 17 | def get_param_file(self, param_folder_path=""): 18 | return os.path.join(param_folder_path, self.dec_param_file_n) 19 | -------------------------------------------------------------------------------- /cfgs/base_model/llama3i8b.yaml: -------------------------------------------------------------------------------- 1 | base_model: 2 | _target_: base_model.Llama3Instruct8B 3 | 4 | 5 | base_model_name: llama3i8b 6 | 7 | # reference_params_results: 8 | # - 'saved_models/llama3i8b/gsm8k/learnable_params.pt' 9 | # - 'saved_models/llama3i8b/mbpp/learnable_params.pt' 10 | # - 'saved_models/llama3i8b/ai2arc/learnable_params.pt' 11 | 12 | reference_params_results: 13 | - "ckpts/learnable_params/llama3_8b_instruct_gsm8k_svd_pg_mlp.pt" 14 | - "ckpts/learnable_params/llama3_8b_instruct_mbpp_pro_svd_pg_mlp.pt" 15 | - "ckpts/learnable_params/llama3_8b_instruct_gsm8k_svd_pg_mlp.pt" -------------------------------------------------------------------------------- /cfgs/base_model/mistral03i7b.yaml: -------------------------------------------------------------------------------- 1 | base_model: 2 | _target_: base_model.MistralV03Instruct7B 3 | 4 | 5 | base_model_name: mistral03i7b 6 | 7 | reference_params_results: 8 | - 'saved_models/mistral03i7b/gsm8k/policy_params.pt' 9 | - 'saved_models/mistral03i7b/mbpp/policy_params.pt' 10 | - 'saved_models/mistral03i7b/ai2arc/policy_params.pt' -------------------------------------------------------------------------------- /cfgs/config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - policy@_global_: default 4 | - task@_global_: gsm8k 5 | - base_model@_global_: llama3i8b 6 | - optimization@_global_: reinforce 7 | - mode@_global_: training 8 | 9 | num_iters: 2000 10 | test_interval: 10 11 | lr: 2e-3 12 | batch_size: 256 13 | seed: 42 14 | init_val: 0.1 15 | test_only: false 16 | model_dir: null 17 | save_legacy_params: false 18 | use_lora: false 19 | prompt_based_eval: false 20 | experts_path_dict: null 21 | 22 | run_name: null 23 | 24 | load_ckpt: null 25 | exp_suffix: 'st' 26 | 27 | exp_name: ${base_model_name}/${optim_name}-${exp_suffix} 28 | 29 | wandb_log: true # enabled by default 30 | wandb_project: shakeoff 31 | wandb_group_name: ${exp_name} 32 | extract_svd: false 33 | 34 | out_dir: results 35 | 36 | hydra: 37 | run: 38 | dir: ${out_dir}/ -------------------------------------------------------------------------------- /cfgs/mode/eval.yaml: -------------------------------------------------------------------------------- 1 | exp_name: eval_${base_model_name}/temp-lr${lr}-mGN${max_grad_norm}-klC${kl_ref_coeff}-r${rw_strategy}-${exp_suffix}-r 2 | 3 | test_only: true 4 | load_ckpt: null 5 | use_lora: false 6 | 7 | prompt_based_eval: false 8 | experts_path_dict: 9 | code: 
null 10 | math: null 11 | reasoning: null 12 | other: null 13 | 14 | wandb_project: T^2_eval 15 | wandb_group_name: ${exp_name} 16 | out_dir: results_eval -------------------------------------------------------------------------------- /cfgs/mode/training.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SakanaAI/self-adaptive-llms/03a41aed1cfc57276e72ad5a42845a04b356db1e/cfgs/mode/training.yaml -------------------------------------------------------------------------------- /cfgs/optimization/cem.yaml: -------------------------------------------------------------------------------- 1 | 2 | optimization_algorithm: 3 | _target_: optim_modules.CEM 4 | elite_ratio: ${elite_ratio} 5 | pop_size: ${pop_size} 6 | min_trainable_param: ${min_trainable_param} 7 | max_trainable_param: ${max_trainable_param} 8 | optim_ema: ${optim_ema} 9 | re_eval_best: ${re_eval_best} 10 | use_loglikelihood_for_ties: ${use_loglikelihood_for_ties} 11 | 12 | 13 | pop_size: 32 14 | elite_ratio: 0.2 15 | min_trainable_param: 0 16 | max_trainable_param: 1 17 | optim_ema: 0 18 | re_eval_best: True 19 | use_loglikelihood_for_ties: true 20 | optim_name: CEM-pop${pop_size}e${elite_ratio}-[${min_trainable_param}-${max_trainable_param}]-tieswLL${use_loglikelihood_for_ties} -------------------------------------------------------------------------------- /cfgs/optimization/reinforce.yaml: -------------------------------------------------------------------------------- 1 | 2 | optimization_algorithm: 3 | _target_: optim_modules.Reinforce 4 | # policy: ${policy} 5 | # gpu: ${gpu} 6 | max_grad_norm: ${max_grad_norm} 7 | lr: ${lr} 8 | rw_norm: ${rw_norm} 9 | rw_clip: ${rw_clip} 10 | kl_ref_coeff: ${kl_ref_coeff} 11 | 12 | 13 | # policy: 14 | # gpu: 15 | max_grad_norm: 1e-3 16 | lr: 2e-3 17 | rw_norm: 0 18 | rw_clip: null 19 | kl_ref_coeff: 0 20 | rw_strategy: rN${rw_norm}C${rw_clip} 21 | optim_name: RL-lr${lr}-mGN${max_grad_norm}-klC${kl_ref_coeff}-r${rw_strategy} -------------------------------------------------------------------------------- /cfgs/optimization/rsm.yaml: -------------------------------------------------------------------------------- 1 | 2 | optimization_algorithm: 3 | _target_: optim_modules.RandomShooting 4 | # policy: ${policy} 5 | # gpu: ${gpu} 6 | pop_size: ${pop_size} 7 | min_trainable_param: ${min_trainable_param} 8 | max_trainable_param: ${max_trainable_param} 9 | optim_ema: ${optim_ema} 10 | re_eval_best: ${re_eval_best} 11 | use_loglikelihood_for_ties: ${use_loglikelihood_for_ties} 12 | 13 | 14 | pop_size: 32 15 | min_trainable_param: 0 16 | max_trainable_param: 1 17 | optim_ema: 0 18 | re_eval_best: True 19 | use_loglikelihood_for_ties: false 20 | optim_name: RSML-pop${pop_size}-[${min_trainable_param}-${max_trainable_param}]-tieswLL${use_loglikelihood_for_ties} -------------------------------------------------------------------------------- /cfgs/policy/default.yaml: -------------------------------------------------------------------------------- 1 | shakeoff_policy: 2 | _target_: policy.Policy 3 | init_val: ${init_val} 4 | mode: ${policy_mode} 5 | max_mult: ${max_mult} 6 | 7 | policy_mode: 1 8 | max_mult: 1 9 | policy_name: ${policy_mode}_mm${max_mult} 10 | -------------------------------------------------------------------------------- /cfgs/policy/wcomb.yaml: -------------------------------------------------------------------------------- 1 | 2 | 3 | shakeoff_policy: 4 | _target_: policy.WeightedCombination 5 | base_policy_cfg: 
null 6 | params_paths: ${reference_params_results} 7 | norm_coeffs: ${norm_coeffs} 8 | per_layer: ${per_layer} 9 | init_values: ${init_values} 10 | 11 | norm_coeffs: true 12 | per_layer: false 13 | init_values: null 14 | 15 | policy_name: Wcomb_n${norm_coeffs}_p${per_layer} 16 | -------------------------------------------------------------------------------- /cfgs/task/ablation_tasks/few_shot_arc_challenge_20.yaml: -------------------------------------------------------------------------------- 1 | task_loader: 2 | _target_: tasks.FewShotTask 3 | wrapped_task: 4 | _target_: tasks.AI2ArcTask 5 | wrapped_split: ${wrapped_split} 6 | shots: ${task_shots} 7 | seed: ${task_loader_seed} 8 | 9 | 10 | wrapped_split: transfer 11 | task_shots: 20 12 | task_loader_seed: 38 13 | 14 | task_name: arc_chal_${task_shots}shots 15 | 16 | -------------------------------------------------------------------------------- /cfgs/task/ablation_tasks/few_shot_arc_challenge_3.yaml: -------------------------------------------------------------------------------- 1 | task_loader: 2 | _target_: tasks.FewShotTask 3 | wrapped_task: 4 | _target_: tasks.AI2ArcTask 5 | wrapped_split: ${wrapped_split} 6 | shots: ${task_shots} 7 | seed: ${task_loader_seed} 8 | 9 | 10 | wrapped_split: transfer 11 | task_shots: 3 12 | task_loader_seed: 38 13 | 14 | task_name: arc_chal_${task_shots}shots 15 | 16 | -------------------------------------------------------------------------------- /cfgs/task/ablation_tasks/few_shot_arc_challenge_5.yaml: -------------------------------------------------------------------------------- 1 | task_loader: 2 | _target_: tasks.FewShotTask 3 | wrapped_task: 4 | _target_: tasks.AI2ArcTask 5 | wrapped_split: ${wrapped_split} 6 | shots: ${task_shots} 7 | seed: ${task_loader_seed} 8 | 9 | 10 | wrapped_split: transfer 11 | task_shots: 5 12 | task_loader_seed: 38 13 | 14 | task_name: arc_chal_${task_shots}shots 15 | 16 | -------------------------------------------------------------------------------- /cfgs/task/ai2_arc.yaml: -------------------------------------------------------------------------------- 1 | task_loader: 2 | _target_: tasks.AI2ArcTask 3 | 4 | 5 | task_name: ai2_arc 6 | 7 | -------------------------------------------------------------------------------- /cfgs/task/cls.yaml: -------------------------------------------------------------------------------- 1 | task_loader: 2 | _target_: tasks.ClsTask 3 | 4 | 5 | task_name: Cls 6 | 7 | -------------------------------------------------------------------------------- /cfgs/task/few_shot_arc_challenge.yaml: -------------------------------------------------------------------------------- 1 | task_loader: 2 | _target_: tasks.FewShotTask 3 | wrapped_task: 4 | _target_: tasks.AI2ArcTask 5 | wrapped_split: ${wrapped_split} 6 | shots: ${task_shots} 7 | seed: ${task_loader_seed} 8 | 9 | 10 | wrapped_split: transfer 11 | task_shots: 10 12 | task_loader_seed: 38 13 | 14 | task_name: arc_chal_${task_shots}shots 15 | 16 | -------------------------------------------------------------------------------- /cfgs/task/few_shot_humaneval.yaml: -------------------------------------------------------------------------------- 1 | task_loader: 2 | _target_: tasks.FewShotTask 3 | wrapped_task: 4 | _target_: tasks.Mbpp2Task2 5 | wrapped_split: ${wrapped_split} 6 | shots: ${task_shots} 7 | seed: ${task_loader_seed} 8 | 9 | 10 | wrapped_split: transfer 11 | task_shots: 10 12 | task_loader_seed: 16 13 | 14 | task_name: humaneval_${task_shots}shots 15 | 
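The configuration files above follow Hydra's `_target_` convention, so the objects they describe (task loaders, policies, optimization algorithms) can be built directly from the config tree. A rough, illustrative sketch of how such a config is typically consumed is shown below; the override string and variable names are assumptions for illustration, not taken from the repository's entry point (`svd_reinforce_hydra.py`).

```python
# Illustrative only: turning a `_target_` config such as the one above into objects.
from hydra import compose, initialize
from hydra.utils import instantiate

with initialize(config_path="cfgs", version_base=None):
    # Select the few-shot HumanEval task group defined in cfgs/task/.
    cfg = compose(config_name="config",
                  overrides=["task@_global_=few_shot_humaneval"])

task_loader = instantiate(cfg.task_loader)  # tasks.FewShotTask wrapping Mbpp2Task2
```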
-------------------------------------------------------------------------------- /cfgs/task/few_shot_math.yaml: -------------------------------------------------------------------------------- 1 | task_loader: 2 | _target_: tasks.FewShotTask 3 | wrapped_task: 4 | _target_: tasks.MathTask 5 | wrapped_split: ${wrapped_split} 6 | shots: ${task_shots} 7 | seed: ${task_loader_seed} 8 | 9 | 10 | wrapped_split: test 11 | task_shots: 10 12 | task_loader_seed: 27 13 | 14 | task_name: math_${task_shots}shots 15 | 16 | -------------------------------------------------------------------------------- /cfgs/task/gsm8k.yaml: -------------------------------------------------------------------------------- 1 | task_loader: 2 | _target_: tasks.Gsm8kTask 3 | 4 | 5 | task_name: gsm8k 6 | 7 | -------------------------------------------------------------------------------- /cfgs/task/math.yaml: -------------------------------------------------------------------------------- 1 | task_loader: 2 | _target_: tasks.MathTask 3 | 4 | 5 | task_name: math 6 | 7 | -------------------------------------------------------------------------------- /cfgs/task/mbpp2.yaml: -------------------------------------------------------------------------------- 1 | task_loader: 2 | _target_: tasks.Mbpp2Task 3 | 4 | 5 | task_name: mbpp 6 | 7 | -------------------------------------------------------------------------------- /evaluation/fishfarm/fishfarm/__init__.py: -------------------------------------------------------------------------------- 1 | from . import chat_templates, models, tasks 2 | from .models import Message, Model, Role 3 | from .tasks import Task, TaskResult 4 | 5 | __all__ = [ 6 | "chat_templates", 7 | "tasks", 8 | "models", 9 | "Task", 10 | "TaskResult", 11 | "Model", 12 | "Message", 13 | "Role", 14 | ] 15 | -------------------------------------------------------------------------------- /evaluation/fishfarm/fishfarm/chat_templates.py: -------------------------------------------------------------------------------- 1 | LLAMA3 = ( 2 | "{% set loop_messages = messages %}" 3 | "{% for message in loop_messages %}" 4 | "{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>" 5 | "\n\n'+ message['content'] | trim + '<|eot_id|>' %}" 6 | "{% if loop.index0 == 0 %}{% set content = bos_token + content %}" 7 | "{% endif %}" 8 | "{{ content }}" 9 | "{% endfor %}" 10 | "{% if add_generation_prompt %}" 11 | "{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}" 12 | "{% endif %}" 13 | ) 14 | -------------------------------------------------------------------------------- /evaluation/fishfarm/fishfarm/imports.py: -------------------------------------------------------------------------------- 1 | from types import TracebackType 2 | from typing import Optional, Tuple, Type 3 | 4 | 5 | class _DeferredImportExceptionContextManager: 6 | """Context manager to defer exceptions from imports. 7 | 8 | Catches :exc:`ImportError` and :exc:`SyntaxError`. 9 | If any exception is caught, this class raises an :exc:`ImportError` when being checked. 10 | 11 | """ 12 | 13 | def __init__(self) -> None: 14 | self._deferred: Optional[Tuple[Exception, str]] = None 15 | 16 | def __enter__(self) -> "_DeferredImportExceptionContextManager": 17 | """Enter the context manager. 18 | 19 | Returns: 20 | Itself. 
21 | 22 | """ 23 | return self 24 | 25 | def __exit__( 26 | self, 27 | exc_type: Optional[Type[Exception]], 28 | exc_value: Optional[Exception], 29 | traceback: Optional[TracebackType], 30 | ) -> Optional[bool]: 31 | """Exit the context manager. 32 | 33 | Args: 34 | exc_type: 35 | Raised exception type. :obj:`None` if nothing is raised. 36 | exc_value: 37 | Raised exception object. :obj:`None` if nothing is raised. 38 | traceback: 39 | Associated traceback. :obj:`None` if nothing is raised. 40 | 41 | Returns: 42 | :obj:`None` if nothing is deferred, otherwise :obj:`True`. 43 | :obj:`True` will suppress any exceptions avoiding them from propagating. 44 | 45 | """ 46 | if isinstance(exc_value, (ImportError, SyntaxError)): 47 | if isinstance(exc_value, ImportError): 48 | message = ( 49 | "Tried to import '{}' but failed. Please make sure that the package is " 50 | "installed correctly to use this feature. Actual error: {}." 51 | ).format(exc_value.name, exc_value) 52 | elif isinstance(exc_value, SyntaxError): 53 | message = ( 54 | "Tried to import a package but failed due to a syntax error in {}. Please " 55 | "make sure that the Python version is correct to use this feature. Actual " 56 | "error: {}." 57 | ).format(exc_value.filename, exc_value) 58 | else: 59 | assert False 60 | 61 | self._deferred = (exc_value, message) 62 | return True 63 | return None 64 | 65 | def is_successful(self) -> bool: 66 | """Return whether the context manager has caught any exceptions. 67 | 68 | Returns: 69 | :obj:`True` if no exceptions are caught, :obj:`False` otherwise. 70 | 71 | """ 72 | return self._deferred is None 73 | 74 | def check(self) -> None: 75 | """Check whether the context manager has caught any exceptions. 76 | 77 | Raises: 78 | :exc:`ImportError`: 79 | If any exception was caught from the caught exception. 80 | 81 | """ 82 | if self._deferred is not None: 83 | exc_value, message = self._deferred 84 | raise ImportError(message) from exc_value 85 | 86 | 87 | def try_import() -> _DeferredImportExceptionContextManager: 88 | """Create a context manager that can wrap imports of optional packages to defer exceptions. 89 | 90 | Returns: 91 | Deferred import context manager. 92 | 93 | """ 94 | return _DeferredImportExceptionContextManager() 95 | -------------------------------------------------------------------------------- /evaluation/fishfarm/fishfarm/logging.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copied from Optuna repo: 3 | https://github.com/optuna/optuna/blob/2595653638506e1b7e025a966a220984a59ab936/optuna/logging.py 4 | Removed some comments for less verbosity. 5 | 6 | In general, `logger.info` is preferred over `print` since it contains module name and timestamp; 7 | We recommend the use of logger object for the fishfarm developers. 8 | 9 | Inside fishfarm, we can call `get_logger(__name__)` from each python file. 10 | Then the root logger format and level are applied to that logger object. 
11 | """ 12 | 13 | from __future__ import annotations 14 | 15 | import logging 16 | import os 17 | import sys 18 | import threading 19 | from logging import CRITICAL, DEBUG, ERROR, FATAL, INFO, WARN, WARNING 20 | 21 | import colorlog 22 | 23 | __all__ = [ 24 | "CRITICAL", 25 | "DEBUG", 26 | "ERROR", 27 | "FATAL", 28 | "INFO", 29 | "WARN", 30 | "WARNING", 31 | ] 32 | 33 | _lock: threading.Lock = threading.Lock() 34 | _default_handler: logging.Handler | None = None 35 | 36 | 37 | def create_default_formatter() -> logging.Formatter: 38 | """Create a default formatter of log messages. 39 | 40 | This function is not supposed to be directly accessed by library users. 41 | """ 42 | header = "[%(levelname)1.1s %(asctime)s %(name)s]" 43 | message = "%(message)s" 44 | if _color_supported(): 45 | return colorlog.ColoredFormatter( 46 | f"%(log_color)s{header}%(reset)s {message}", 47 | ) 48 | return logging.Formatter(f"{header} {message}") 49 | 50 | 51 | def _color_supported() -> bool: 52 | """Detection of color support.""" 53 | # NO_COLOR environment variable: 54 | if os.environ.get("NO_COLOR", None): 55 | return False 56 | 57 | if not hasattr(sys.stderr, "isatty") or not sys.stderr.isatty(): 58 | return False 59 | else: 60 | return True 61 | 62 | 63 | def _get_library_name() -> str: 64 | return __name__.split(".")[0] 65 | 66 | 67 | def _get_library_root_logger() -> logging.Logger: 68 | return logging.getLogger(_get_library_name()) 69 | 70 | 71 | def _configure_library_root_logger() -> None: 72 | global _default_handler 73 | 74 | with _lock: 75 | if _default_handler: 76 | # This library has already configured the library root logger. 77 | return 78 | _default_handler = logging.StreamHandler() # Set sys.stderr as stream. 79 | _default_handler.setFormatter(create_default_formatter()) 80 | 81 | # Apply our default configuration to the library root logger. 82 | library_root_logger: logging.Logger = _get_library_root_logger() 83 | library_root_logger.addHandler(_default_handler) 84 | library_root_logger.setLevel(logging.INFO) 85 | library_root_logger.propagate = False 86 | 87 | 88 | def _reset_library_root_logger() -> None: 89 | global _default_handler 90 | 91 | with _lock: 92 | if not _default_handler: 93 | return 94 | 95 | library_root_logger: logging.Logger = _get_library_root_logger() 96 | library_root_logger.removeHandler(_default_handler) 97 | library_root_logger.setLevel(logging.NOTSET) 98 | _default_handler = None 99 | 100 | 101 | def get_logger(name: str) -> logging.Logger: 102 | """Return a logger with the specified name. 103 | name's prefix should be `fishfarm.` (just like __name__ variable), 104 | otherwise root logger settings will be not reflected. 105 | This function is not supposed to be directly accessed by library users. 106 | """ 107 | 108 | _configure_library_root_logger() 109 | return logging.getLogger(name) 110 | 111 | 112 | def get_verbosity() -> int: 113 | """Return the current level for the fishfarm's root logger. 114 | 115 | Returns: 116 | Logging level, e.g., ``fishfarm.logging.DEBUG`` and ``fishfarm.logging.INFO``. 117 | 118 | .. 
note:: 119 | fishfarm has following logging levels: 120 | 121 | - ``fishfarm.logging.CRITICAL``, ``fishfarm.logging.FATAL`` 122 | - ``fishfarm.logging.ERROR`` 123 | - ``fishfarm.logging.WARNING``, ``fishfarm.logging.WARN`` 124 | - ``fishfarm.logging.INFO`` 125 | - ``fishfarm.logging.DEBUG`` 126 | """ 127 | 128 | _configure_library_root_logger() 129 | return _get_library_root_logger().getEffectiveLevel() 130 | 131 | 132 | def set_verbosity(verbosity: int) -> None: 133 | """Set the level for the fishfarm's root logger. 134 | 135 | Args: 136 | verbosity: 137 | Logging level, e.g., ``fishfarm.logging.DEBUG`` and ``fishfarm.logging.INFO``. 138 | 139 | .. note:: 140 | fishfarm has following logging levels: 141 | 142 | - ``fishfarm.logging.CRITICAL``, ``fishfarm.logging.FATAL`` 143 | - ``fishfarm.logging.ERROR`` 144 | - ``fishfarm.logging.WARNING``, ``fishfarm.logging.WARN`` 145 | - ``fishfarm.logging.INFO`` 146 | - ``fishfarm.logging.DEBUG`` 147 | """ 148 | 149 | _configure_library_root_logger() 150 | _get_library_root_logger().setLevel(verbosity) 151 | 152 | 153 | def disable_default_handler() -> None: 154 | """Disable the default handler of the fishfarm's root logger.""" 155 | 156 | _configure_library_root_logger() 157 | 158 | assert _default_handler is not None 159 | _get_library_root_logger().removeHandler(_default_handler) 160 | 161 | 162 | def enable_default_handler() -> None: 163 | """Enable the default handler of the fishfarm's root logger.""" 164 | 165 | _configure_library_root_logger() 166 | 167 | assert _default_handler is not None 168 | _get_library_root_logger().addHandler(_default_handler) 169 | 170 | 171 | def disable_propagation() -> None: 172 | """Disable propagation of the library log outputs. 173 | 174 | Note that log propagation is disabled by default. You only need to use this function 175 | to stop log propagation when you use :func:`~fishfarm.logging.enable_propagation()`. 176 | """ 177 | 178 | _configure_library_root_logger() 179 | _get_library_root_logger().propagate = False 180 | 181 | 182 | def enable_propagation() -> None: 183 | """Enable propagation of the library log outputs. 184 | 185 | Please disable the fishfarm's default handler to prevent double logging if the root logger has 186 | been configured. 
187 | """ 188 | 189 | _configure_library_root_logger() 190 | _get_library_root_logger().propagate = True 191 | -------------------------------------------------------------------------------- /evaluation/fishfarm/fishfarm/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import (GenerationRequest, GenerationResult, Message, Model, 2 | NLLRequest, NLLResult, Role) 3 | 4 | __all__ = [ 5 | "GenerationRequest", 6 | "GenerationResult", 7 | "NLLRequest", 8 | "NLLResult", 9 | "Model", 10 | "Role", 11 | "Message", 12 | ] 13 | -------------------------------------------------------------------------------- /evaluation/fishfarm/fishfarm/models/base.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from dataclasses import dataclass 4 | from typing import Iterable, Literal, Optional, Sequence 5 | 6 | Role = Literal["system", "user", "assistant", "assistant_prefill"] 7 | 8 | 9 | @dataclass 10 | class Message: 11 | 12 | role: Role 13 | content: str 14 | 15 | 16 | @dataclass 17 | class GenerationRequest: 18 | 19 | messages: list[Message] 20 | 21 | max_tokens: Optional[int] = None 22 | stop: Sequence[str] = () 23 | 24 | 25 | @dataclass 26 | class GenerationResult: 27 | 28 | request: GenerationRequest 29 | generation: str 30 | 31 | 32 | @dataclass 33 | class NLLRequest: 34 | 35 | messages: list[Message] 36 | 37 | 38 | @dataclass 39 | class NLLResult: 40 | 41 | request: NLLRequest 42 | sum_nll: float 43 | num_considered_tokens: int 44 | 45 | 46 | class Model: 47 | 48 | def generate( 49 | self, requests: Sequence[GenerationRequest] 50 | ) -> Iterable[GenerationResult]: 51 | raise NotImplementedError() 52 | 53 | def nll(self, requests: Sequence[NLLRequest]) -> Iterable[NLLResult]: 54 | raise NotImplementedError() 55 | -------------------------------------------------------------------------------- /evaluation/fishfarm/fishfarm/models/tokenization_utils.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from typing import Optional 3 | 4 | from transformers import PreTrainedTokenizerBase 5 | 6 | from .base import Message 7 | 8 | 9 | class MaskedTokens: 10 | 11 | text: str 12 | token_ids: list[int] 13 | mask: list[bool] 14 | 15 | def __init__(self) -> None: 16 | self.text = "" 17 | self.token_ids = [] 18 | self.mask = [] 19 | 20 | def extend( 21 | self, 22 | messages: list[Message], 23 | mask_value: bool, 24 | tokenizer: PreTrainedTokenizerBase, 25 | chat_template: Optional[str], 26 | add_generation_prompt: bool, 27 | ) -> None: 28 | if len(messages) == 0: 29 | # `tokenizer.apply_chat_template` does not accept an empty list. 
30 | raise ValueError("At least one message is required.") 31 | 32 | all_text: str = tokenizer.apply_chat_template( 33 | conversation=[dataclasses.asdict(message) for message in messages], 34 | chat_template=chat_template, 35 | tokenize=False, 36 | add_generation_prompt=add_generation_prompt, 37 | ) 38 | assert all_text.startswith(self.text) 39 | new_text = all_text[len(self.text) :] 40 | new_token_ids: list[int] = tokenizer.encode(new_text, add_special_tokens=False) 41 | 42 | self.token_ids.extend(new_token_ids) 43 | self.mask.extend([mask_value] * len(new_token_ids)) 44 | self.text = all_text 45 | 46 | 47 | def tokenize_messages( 48 | messages: list[Message], 49 | tokenizer: PreTrainedTokenizerBase, 50 | chat_template: Optional[str], 51 | ) -> MaskedTokens: 52 | masked_tokens = MaskedTokens() 53 | 54 | for i, message in enumerate(messages): 55 | if message.role != "assistant": 56 | continue 57 | 58 | masked_tokens.extend(messages[:i], False, tokenizer, chat_template, True) 59 | masked_tokens.extend(messages[: i + 1], True, tokenizer, chat_template, False) 60 | 61 | masked_tokens.extend(messages, False, tokenizer, chat_template, True) 62 | return masked_tokens 63 | -------------------------------------------------------------------------------- /evaluation/fishfarm/fishfarm/models/vllm_model.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import dataclasses 3 | from typing import Any, Iterable, Optional, Sequence 4 | 5 | from fishfarm.models.base import NLLRequest, NLLResult 6 | from transformers import PreTrainedTokenizerBase 7 | 8 | from ..imports import try_import 9 | from .base import GenerationRequest, GenerationResult, Message, Model 10 | from .tokenization_utils import tokenize_messages 11 | 12 | with try_import() as _imports: 13 | import vllm 14 | 15 | _imports.check() 16 | 17 | 18 | class VLLMModel(Model): 19 | 20 | def __init__( 21 | self, 22 | llm: vllm.LLM, 23 | sampling_params: vllm.SamplingParams, 24 | chat_template: Optional[str], 25 | ) -> None: 26 | self.llm = llm 27 | self.chat_template = chat_template 28 | self.sampling_params = sampling_params 29 | 30 | def get_tokenizer(self) -> PreTrainedTokenizerBase: 31 | tokenizer = self.llm.get_tokenizer() 32 | 33 | if not hasattr(tokenizer, "apply_chat_template"): 34 | if hasattr(tokenizer, "tokenizer"): 35 | tokenizer = tokenizer.tokenizer 36 | else: 37 | raise ValueError( 38 | "The tokenizer does not have the 'apply_chat_template' method. " 39 | "This is likely because of the versions of vLLM or transformers." 
40 | ) 41 | 42 | return tokenizer 43 | 44 | def _into_prompt(self, messages: Sequence[Message]) -> str: 45 | tokenizer = self.get_tokenizer() 46 | prefill_text = "" 47 | n_assistant_prefill = sum([m.role == "assistant_prefill" for m in messages]) 48 | if n_assistant_prefill > 1: 49 | raise ValueError( 50 | f"There must be at most one assistant_prefill role, but got {n_assistant_prefill}", 51 | ) 52 | if n_assistant_prefill: 53 | assert ( 54 | messages[-1].role == "assistant_prefill" 55 | ), "assistant_prefill role must be the last message" 56 | prefill_text = messages[-1].content 57 | messages = messages[:-1] 58 | prompt: str = tokenizer.apply_chat_template( 59 | conversation=[dataclasses.asdict(message) for message in messages], 60 | chat_template=self.chat_template, 61 | tokenize=False, 62 | add_generation_prompt=True, 63 | ) 64 | prompt += prefill_text 65 | return prompt 66 | 67 | def _predict_log_probs(self, token_ids_list: list[list[int]]) -> list[list[float]]: 68 | sampling_params = copy.copy(self.sampling_params) 69 | sampling_params.prompt_logprobs = 1 70 | sampling_params.max_tokens = 1 71 | 72 | completions = self.llm.generate( 73 | prompt_token_ids=token_ids_list, 74 | sampling_params=sampling_params, 75 | ) 76 | 77 | log_probs_list = [] 78 | for token_ids, completion in zip(token_ids_list, completions): 79 | log_probs = [] 80 | assert completion.prompt_logprobs is not None 81 | assert token_ids == completion.prompt_token_ids 82 | assert len(token_ids) == len(completion.prompt_logprobs) 83 | for token_id, logprob_dict in zip(token_ids, completion.prompt_logprobs): 84 | if logprob_dict is None: 85 | log_probs.append(0.0) 86 | else: 87 | logprob_entry: Any = logprob_dict[token_id] 88 | 89 | if isinstance(logprob_entry, float): 90 | log_probs.append(logprob_entry) 91 | else: 92 | log_probs.append(logprob_entry.logprob) 93 | 94 | log_probs_list.append(log_probs) 95 | 96 | return log_probs_list 97 | 98 | def generate( 99 | self, requests: Sequence[GenerationRequest] 100 | ) -> Iterable[GenerationResult]: 101 | 102 | prompts = [self._into_prompt(request.messages) for request in requests] 103 | completions = self.llm.generate( 104 | prompts=prompts, 105 | sampling_params=self.sampling_params, 106 | ) 107 | 108 | for request, completion in zip(requests, completions): 109 | yield GenerationResult( 110 | request=request, generation=completion.outputs[0].text 111 | ) 112 | 113 | def nll(self, requests: Sequence[NLLRequest]) -> Iterable[NLLResult]: 114 | masked_tokens_list = [ 115 | tokenize_messages( 116 | request.messages, self.get_tokenizer(), self.chat_template 117 | ) 118 | for request in requests 119 | ] 120 | log_probs_list = self._predict_log_probs( 121 | [masked_tokens.token_ids for masked_tokens in masked_tokens_list] 122 | ) 123 | 124 | results = [] 125 | for log_probs, masked_tokens, request in zip( 126 | log_probs_list, masked_tokens_list, requests 127 | ): 128 | assert len(log_probs) == len(masked_tokens.mask) 129 | 130 | sum_nll = 0.0 131 | num_considered_tokens = 0 132 | for log_prob, mask_value in zip(log_probs, masked_tokens.mask): 133 | if mask_value: 134 | sum_nll += -log_prob 135 | num_considered_tokens += 1 136 | 137 | results.append( 138 | NLLResult( 139 | request=request, 140 | sum_nll=sum_nll, 141 | num_considered_tokens=num_considered_tokens, 142 | ) 143 | ) 144 | 145 | return results 146 | -------------------------------------------------------------------------------- /evaluation/fishfarm/fishfarm/tasks/__init__.py: 
-------------------------------------------------------------------------------- 1 | from . import base 2 | from .base import Task, TaskResult 3 | 4 | __all__ = [ 5 | "base", 6 | "TaskResult", 7 | "Task", 8 | ] 9 | -------------------------------------------------------------------------------- /evaluation/fishfarm/fishfarm/tasks/ai2_arc.py: -------------------------------------------------------------------------------- 1 | import random 2 | import re 3 | from dataclasses import dataclass 4 | from typing import Iterable, Optional, Sequence 5 | 6 | from ..models import GenerationRequest, Message, Model 7 | from .base import Task, TaskResult 8 | 9 | 10 | def extract_answer(text: str) -> Optional[str]: 11 | pattern = r"answer is \(?([A-J])\)?" 12 | match = re.search(pattern, text) 13 | if match: 14 | return match.group(1) 15 | else: 16 | return extract_again(text) 17 | 18 | 19 | def extract_again(text: str) -> Optional[str]: 20 | match = re.search(r".*[aA]nswer:\s*([A-J])", text) 21 | if match: 22 | return match.group(1) 23 | else: 24 | return extract_final(text) 25 | 26 | 27 | def extract_final(text: str) -> Optional[str]: 28 | pattern = r"\b[A-J]\b(?!.*\b[A-J]\b)" 29 | match = re.search(pattern, text, re.DOTALL) 30 | if match: 31 | return match.group(0) 32 | else: 33 | return None 34 | 35 | 36 | def is_correct(pred: Optional[str], answer: str, options: list[str]) -> bool: 37 | if not pred: 38 | random.seed(42) 39 | x = random.randint(0, len(options) - 1) 40 | if ["A", "B", "C", "D", "E"][x] == answer: 41 | return True 42 | else: 43 | return False 44 | elif pred == answer: 45 | return True 46 | else: 47 | return False 48 | 49 | 50 | @dataclass 51 | class Ai2ArcSample: 52 | 53 | question: str 54 | question_id: str 55 | options: list[str] 56 | answer: str 57 | 58 | 59 | def mean(iterable: Iterable[float]) -> float: 60 | total, count = 0.0, 0 61 | for x in iterable: 62 | total += x 63 | count += 1 64 | return total / count 65 | 66 | 67 | class Ai2ArcTask(Task): 68 | def __init__( 69 | self, 70 | samples: Sequence[Ai2ArcSample], 71 | context_messages: Sequence[Message] = (), 72 | ): 73 | self.samples = list(samples) 74 | self.context_messages = context_messages 75 | 76 | @property 77 | def num_samples(self) -> int: 78 | return len(self.samples) 79 | 80 | def evaluate( 81 | self, 82 | model: Model, 83 | sample_ids: Optional[Sequence[int]] = None, 84 | ) -> TaskResult: 85 | if sample_ids is None: 86 | sample_ids = range(len(self.samples)) 87 | samples = [self.samples[sample_id] for sample_id in sample_ids] 88 | 89 | requests = [] 90 | for sample in samples: 91 | messages = list(self.context_messages) 92 | messages.append(Message(role="user", content=sample.question)) 93 | requests.append(GenerationRequest(messages=messages)) 94 | 95 | sample_details = [] 96 | for sample, result in zip(samples, model.generate(requests)): 97 | output = result.generation 98 | prediction = extract_answer(result.generation) 99 | 100 | sample_details.append( 101 | dict( 102 | problem=sample.question, 103 | output=output, 104 | answer=sample.answer, 105 | prediction=prediction, 106 | correct=is_correct(prediction, sample.answer, sample.options), 107 | ) 108 | ) 109 | 110 | aggregate_metrics = { 111 | "acc": mean( 112 | float(sd["correct"]) if isinstance(sd["correct"], (bool)) else 0.0 113 | for sd in sample_details 114 | ) 115 | } 116 | return TaskResult( 117 | aggregate_metrics=aggregate_metrics, sample_details=sample_details 118 | ) 119 | -------------------------------------------------------------------------------- 
/evaluation/fishfarm/fishfarm/tasks/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from dataclasses import dataclass 3 | from typing import Any, Optional, Sequence 4 | 5 | from ..models import Model 6 | 7 | 8 | @dataclass 9 | class TaskResult: 10 | 11 | aggregate_metrics: dict[str, float] 12 | sample_details: list[dict[str, Any]] 13 | 14 | 15 | class Task(abc.ABC): 16 | 17 | @property 18 | @abc.abstractmethod 19 | def num_samples(self) -> int: 20 | raise NotImplementedError() 21 | 22 | @abc.abstractmethod 23 | def evaluate( 24 | self, 25 | model: Model, 26 | sample_ids: Optional[Sequence[int]] = None, 27 | ) -> TaskResult: 28 | raise NotImplementedError() 29 | -------------------------------------------------------------------------------- /evaluation/fishfarm/fishfarm/tasks/competation_math.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from math import isclose 3 | from typing import Any, Iterable, Optional, Sequence, Union 4 | 5 | from sympy import N, simplify 6 | from sympy.parsing.latex import parse_latex 7 | from sympy.parsing.sympy_parser import parse_expr 8 | 9 | from ..models import GenerationRequest, Message, Model 10 | from .base import Task, TaskResult 11 | 12 | 13 | def _fix_fracs(string: str) -> str: 14 | substrs = string.split("\\frac") 15 | new_str = substrs[0] 16 | if len(substrs) > 1: 17 | substrs = substrs[1:] 18 | for substr in substrs: 19 | new_str += "\\frac" 20 | if substr[0] == "{": 21 | new_str += substr 22 | else: 23 | try: 24 | assert len(substr) >= 2 25 | except AssertionError: 26 | return string 27 | a = substr[0] 28 | b = substr[1] 29 | if b != "{": 30 | if len(substr) > 2: 31 | post_substr = substr[2:] 32 | new_str += "{" + a + "}{" + b + "}" + post_substr 33 | else: 34 | new_str += "{" + a + "}{" + b + "}" 35 | else: 36 | if len(substr) > 2: 37 | post_substr = substr[2:] 38 | new_str += "{" + a + "}" + b + post_substr 39 | else: 40 | new_str += "{" + a + "}" + b 41 | string = new_str 42 | return string 43 | 44 | 45 | def _fix_a_slash_b(string: str) -> str: 46 | if len(string.split("/")) != 2: 47 | return string 48 | a: str = string.split("/")[0] 49 | b: str = string.split("/")[1] 50 | try: 51 | a_int: int = int(a) 52 | b_int: int = int(b) 53 | assert string == "{}/{}".format(a_int, b_int) 54 | new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" 55 | return new_string 56 | except (AssertionError, ValueError): 57 | return string 58 | 59 | 60 | def _remove_right_units(string: str) -> str: 61 | if "\\text{ " in string: 62 | splits = string.split("\\text{ ") 63 | assert len(splits) == 2 64 | return splits[0] 65 | else: 66 | return string 67 | 68 | 69 | def _fix_sqrt(string: str) -> str: 70 | if "\\sqrt" not in string: 71 | return string 72 | splits = string.split("\\sqrt") 73 | new_string = splits[0] 74 | for split in splits[1:]: 75 | if split[0] != "{": 76 | a = split[0] 77 | new_substr = "\\sqrt{" + a + "}" + split[1:] 78 | else: 79 | new_substr = "\\sqrt" + split 80 | new_string += new_substr 81 | return new_string 82 | 83 | 84 | def _strip_string(string: str) -> str: 85 | string = string.replace("\n", "") 86 | 87 | string = string.replace("\\!", "") 88 | 89 | string = string.replace("\\\\", "\\") 90 | 91 | string = string.replace("tfrac", "frac") 92 | string = string.replace("dfrac", "frac") 93 | 94 | string = string.replace("\\left", "") 95 | string = string.replace("\\right", "") 96 | 97 | string = 
string.replace("^{\\circ}", "") 98 | string = string.replace("^\\circ", "") 99 | 100 | string = string.replace("\\$", "") 101 | 102 | string = _remove_right_units(string) 103 | 104 | string = string.replace(r"\\%", "") 105 | string = string.replace(r"\%", "") 106 | 107 | string = string.replace(" .", " 0.") 108 | string = string.replace("{.", "{0.") 109 | if len(string) == 0: 110 | return string 111 | if string[0] == ".": 112 | string = "0" + string 113 | 114 | if len(string.split("=")) == 2: 115 | if len(string.split("=")[0]) <= 2: 116 | string = string.split("=")[1] 117 | 118 | string = _fix_sqrt(string) 119 | 120 | string = string.replace(" ", "") 121 | 122 | string = _fix_fracs(string) 123 | 124 | if string == "0.5": 125 | string = "\\frac{1}{2}" 126 | 127 | string = _fix_a_slash_b(string) 128 | 129 | return string 130 | 131 | 132 | def is_digit(s: Union[bool, float, str]) -> bool: 133 | try: 134 | float(str(s).replace(",", "")) 135 | return True 136 | except ValueError: 137 | return False 138 | 139 | 140 | def symbolic_equal(a: str, b: str) -> bool: 141 | def _parse(s: str) -> Any: 142 | for f in [parse_latex, parse_expr]: 143 | try: 144 | return f(s) 145 | except Exception: 146 | pass 147 | return s 148 | 149 | a = _parse(a) 150 | b = _parse(b) 151 | 152 | try: 153 | if simplify(a - b) == 0: 154 | return True 155 | except Exception: 156 | pass 157 | 158 | try: 159 | if isclose(N(a), N(b), rel_tol=1e-3): 160 | return True 161 | except Exception: 162 | pass 163 | return False 164 | 165 | 166 | def math_equal( 167 | prediction: Union[bool, float, str], 168 | reference: Union[float, str], 169 | include_percentage: bool = True, 170 | is_close: bool = True, 171 | ) -> bool: 172 | """ 173 | Exact match of math if and only if: 174 | 1. numerical equal: both can convert to float and are equal 175 | 2. 
symbolic equal: both can convert to sympy expression and are equal 176 | """ 177 | try: 178 | if is_digit(prediction) and is_digit(reference): 179 | prediction = float(str(prediction).replace(",", "")) 180 | reference = float(str(reference).replace(",", "")) 181 | if include_percentage: 182 | gt_result = [reference / 100, reference, reference * 100] 183 | else: 184 | gt_result = [reference] 185 | for item in gt_result: 186 | try: 187 | if is_close: 188 | if isclose(item, prediction, rel_tol=1e-4): 189 | return True 190 | else: 191 | if item == prediction: 192 | return True 193 | except Exception: 194 | continue 195 | return False 196 | except Exception: 197 | pass 198 | 199 | if not prediction and prediction not in [0, False]: 200 | return False 201 | 202 | reference = str(reference).strip() 203 | prediction = str(prediction).strip() 204 | 205 | pred_str, ref_str = prediction, reference 206 | if ( 207 | prediction.startswith("[") 208 | and prediction.endswith("]") 209 | and not reference.startswith("(") 210 | ) or ( 211 | prediction.startswith("(") 212 | and prediction.endswith(")") 213 | and not reference.startswith("[") 214 | ): 215 | pred_str = pred_str.strip("[]()") 216 | ref_str = ref_str.strip("[]()") 217 | for s in ["{", "}", "(", ")"]: 218 | ref_str = ref_str.replace(s, "") 219 | pred_str = pred_str.replace(s, "") 220 | if pred_str == ref_str: 221 | return True 222 | 223 | if ( 224 | (prediction.startswith("[") and prediction.endswith("]")) 225 | and (reference.startswith("[") and reference.endswith("]")) 226 | or (prediction.startswith("(") and prediction.endswith(")")) 227 | and (reference.startswith("(") and reference.endswith(")")) 228 | ): 229 | pred_parts = prediction[1:-1].split(",") 230 | ref_parts = reference[1:-1].split(",") 231 | if len(pred_parts) == len(ref_parts): 232 | if all( 233 | [ 234 | math_equal( 235 | pred_parts[i], ref_parts[i], include_percentage, is_close 236 | ) 237 | for i in range(len(pred_parts)) 238 | ] 239 | ): 240 | return True 241 | 242 | if symbolic_equal(prediction, reference): 243 | return True 244 | 245 | return False 246 | 247 | 248 | def is_equiv(str1: Optional[str], str2: Optional[str]) -> bool: 249 | if str1 is None and str2 is None: 250 | return True 251 | if str1 is None or str2 is None: 252 | return False 253 | 254 | try: 255 | ss1 = _strip_string(str1) 256 | ss2 = _strip_string(str2) 257 | return math_equal(ss1, ss2) or ss1 == ss2 258 | except (AssertionError, TypeError, ValueError): 259 | return math_equal(str1, str2) or str1 == str2 260 | 261 | 262 | def last_boxed_only_string(string: str) -> Optional[str]: 263 | idx = string.rfind("\\boxed") 264 | if idx < 0: 265 | idx = string.rfind("\\fbox") 266 | if idx < 0: 267 | return None 268 | 269 | i = idx 270 | right_brace_idx: Optional[int] = None 271 | 272 | num_left_braces_open = 0 273 | while i < len(string): 274 | if string[i] == "{": 275 | num_left_braces_open += 1 276 | if string[i] == "}": 277 | num_left_braces_open -= 1 278 | if num_left_braces_open == 0: 279 | right_brace_idx = i 280 | break 281 | i += 1 282 | 283 | if right_brace_idx is None: 284 | retval = None 285 | else: 286 | assert right_brace_idx is not None 287 | retval = string[idx : right_brace_idx + 1] 288 | 289 | return retval 290 | 291 | 292 | def remove_boxed(s: Optional[str]) -> Optional[str]: 293 | left = "\\boxed{" 294 | if s is None: 295 | return None 296 | else: 297 | try: 298 | assert s[: len(left)] == left 299 | assert s[-1] == "}" 300 | return s[len(left) : -1] 301 | except (AssertionError, TypeError, 
ValueError): 302 | return None 303 | 304 | 305 | @dataclass 306 | class MathSample: 307 | 308 | problem: str 309 | answer: Optional[str] = None 310 | type: Optional[str] = None 311 | 312 | 313 | def mean(iterable: Iterable[float]) -> float: 314 | total, count = 0.0, 0 315 | for x in iterable: 316 | total += x 317 | count += 1 318 | return total / count 319 | 320 | 321 | def extract_ans(completion: str) -> Optional[str]: 322 | 323 | split_ans = completion.split("The answer is: ") 324 | if len(split_ans) > 1: 325 | ans = split_ans[-1] 326 | extract_ans_temp = ans.split(".\n")[0] 327 | extract_ans_temp = extract_ans_temp.strip() 328 | if len(extract_ans_temp) > 0 and extract_ans_temp[-1] == ".": 329 | extract_ans = extract_ans_temp[0:-1] 330 | else: 331 | extract_ans = extract_ans_temp 332 | extract_ans = extract_ans.strip() 333 | return extract_ans 334 | else: 335 | return remove_boxed(last_boxed_only_string(completion)) 336 | 337 | 338 | class LatexFormatMathTask(Task): 339 | def __init__( 340 | self, 341 | samples: Sequence[MathSample], 342 | context_messages: Sequence[Message] = (), 343 | ): 344 | self.samples = list(samples) 345 | self.context_messages = context_messages 346 | 347 | @property 348 | def num_samples(self) -> int: 349 | return len(self.samples) 350 | 351 | def evaluate( 352 | self, 353 | model: Model, 354 | sample_ids: Optional[Sequence[int]] = None, 355 | ) -> TaskResult: 356 | if sample_ids is None: 357 | sample_ids = range(len(self.samples)) 358 | samples = [self.samples[sample_id] for sample_id in sample_ids] 359 | 360 | requests = [] 361 | for sample in samples: 362 | messages = list(self.context_messages) 363 | messages.append(Message(role="user", content=sample.problem)) 364 | requests.append(GenerationRequest(messages=messages)) 365 | 366 | sample_details = [] 367 | for sample, result in zip(samples, model.generate(requests)): 368 | output = result.generation 369 | prediction = extract_ans(output) 370 | 371 | sample_details.append( 372 | dict( 373 | problem=sample.problem, 374 | output=output, 375 | answer=sample.answer, 376 | type=sample.type, 377 | prediction=prediction, 378 | correct=is_equiv(sample.answer, prediction), 379 | ) 380 | ) 381 | 382 | aggregate_metrics = { 383 | "acc": mean( 384 | float(sd["correct"]) if isinstance(sd["correct"], (bool)) else 0.0 385 | for sd in sample_details 386 | ) 387 | } 388 | 389 | return TaskResult( 390 | aggregate_metrics=aggregate_metrics, sample_details=sample_details 391 | ) 392 | -------------------------------------------------------------------------------- /evaluation/fishfarm/fishfarm/tasks/evalplus/__init__.py: -------------------------------------------------------------------------------- 1 | from .data import load_dataset 2 | from .task import EvalplusTask 3 | 4 | __all__ = ["EvalplusTask", "load_dataset"] 5 | -------------------------------------------------------------------------------- /evaluation/fishfarm/fishfarm/tasks/evalplus/data.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from evalplus.data import get_human_eval_plus, get_mbpp_plus 4 | 5 | 6 | @dataclass 7 | class TextToCodeProblem: 8 | id: str 9 | instruction: str 10 | response_prefix: str 11 | 12 | 13 | def get_mbpp_raw_problems() -> list[dict]: 14 | problems = get_mbpp_plus() 15 | return list(problems.values()) 16 | 17 | 18 | def get_humaneval_raw_problems() -> list[dict]: 19 | problems = get_human_eval_plus() 20 | return list(problems.values()) 21 | 22 | 23 | def 
read_mbpp_plus( 24 | plus_path: str, err_incomplete: bool = True, mini: bool = False 25 | ) -> dict[str, dict]: 26 | from evalplus.data.mbpp import (completeness_check, 27 | mbpp_deserialize_inputs, stream_jsonl) 28 | 29 | plus = {task["task_id"]: task for task in stream_jsonl(plus_path)} 30 | for task_id, task in plus.items(): 31 | task["base_input"] = mbpp_deserialize_inputs(task_id, task["base_input"]) 32 | task["plus_input"] = mbpp_deserialize_inputs(task_id, task["plus_input"]) 33 | 34 | if err_incomplete: 35 | completeness_check("MBPP+", plus) 36 | return plus 37 | 38 | 39 | def map_mbpp_problem(p: dict) -> TextToCodeProblem: 40 | id = p["task_id"] 41 | prompt = p["prompt"] 42 | start_index = prompt.index('"""') 43 | end_index = prompt.rindex('"""') 44 | prompt = prompt[start_index + 3 : end_index] 45 | assert_index = prompt.index("assert") 46 | instruction = prompt[:assert_index].strip() 47 | if not instruction.endswith("."): 48 | instruction += "." 49 | assertion = prompt[assert_index:].strip() 50 | instruction = f"""{instruction} Your code should satisfy the following assertion: 51 | ```python 52 | {assertion} 53 | ```""" 54 | response_prefix = """```python""" 55 | return TextToCodeProblem( 56 | id=str(id), instruction=instruction, response_prefix=response_prefix 57 | ) 58 | 59 | 60 | def map_humaneval_problem(p: dict) -> TextToCodeProblem: 61 | id = p["task_id"] 62 | prompt = p["prompt"] 63 | prompt = prompt.strip() 64 | instruction = f"""Write a solution to the following problem: 65 | ```python 66 | {prompt} 67 | ```""" 68 | response_prefix = f"""```python 69 | {prompt}""" 70 | return TextToCodeProblem( 71 | id=id, instruction=instruction, response_prefix=response_prefix 72 | ) 73 | 74 | 75 | def load_dataset(source_dataset: str) -> list[TextToCodeProblem]: 76 | if source_dataset not in ("humaneval", "mbpp"): 77 | raise ValueError(f"Unknown source_dataset: {source_dataset}") 78 | 79 | raw_problem_fn = { 80 | "humaneval": get_humaneval_raw_problems, 81 | "mbpp": get_mbpp_raw_problems, 82 | }[source_dataset] 83 | 84 | if source_dataset.startswith("humaneval"): 85 | map_problem_fn = map_humaneval_problem 86 | elif source_dataset.startswith("mbpp"): 87 | map_problem_fn = map_mbpp_problem 88 | else: 89 | raise ValueError(f"Unknown source_dataset: {source_dataset}") 90 | 91 | raw_problems = raw_problem_fn() 92 | problems = list(map(map_problem_fn, raw_problems)) 93 | 94 | return problems 95 | -------------------------------------------------------------------------------- /evaluation/fishfarm/fishfarm/tasks/evalplus/evaluation.py: -------------------------------------------------------------------------------- 1 | import json 2 | import multiprocessing 3 | import os 4 | import threading 5 | import time 6 | from collections import Counter, defaultdict 7 | from concurrent.futures import ProcessPoolExecutor, as_completed 8 | from datetime import datetime 9 | from typing import Any 10 | from warnings import warn 11 | 12 | import numpy as np 13 | from evalplus.data import (get_human_eval_plus, get_human_eval_plus_hash, 14 | get_mbpp_plus, get_mbpp_plus_hash, load_solutions) 15 | from evalplus.eval import SUCCESS, estimate_pass_at_k, untrusted_check 16 | from evalplus.eval._special_oracle import MBPP_OUTPUT_NOT_NONE_TASKS 17 | from evalplus.evaluate import Result, get_groundtruth 18 | from termcolor import cprint 19 | from tqdm.auto import tqdm 20 | 21 | from ...logging import get_logger 22 | 23 | logger = get_logger(__name__) 24 | 25 | 26 | def check_correctness( 27 | dataset: str, 28 | 
completion_id: int, 29 | problem: dict[str, Any], 30 | solution: str, 31 | expected_output: dict[str, list], 32 | base_only: bool = False, 33 | fast_check: bool = False, 34 | identifier: str = "HumanEval/0_0", 35 | min_time_limit: float = 0.1, 36 | gt_time_limit_factor: float = 2.0, 37 | ) -> dict[str, Result]: 38 | ret = { 39 | "completion_id": completion_id, 40 | "task_id": problem["task_id"], 41 | "_identifier": identifier, 42 | "solution": solution, 43 | } 44 | ret["base"] = untrusted_check( 45 | dataset, 46 | solution, 47 | problem["base_input"], 48 | problem["entry_point"], 49 | expected=expected_output["base"], 50 | atol=problem["atol"], 51 | ref_time=expected_output["base_time"], 52 | fast_check=fast_check, 53 | min_time_limit=min_time_limit, 54 | gt_time_limit_factor=gt_time_limit_factor, 55 | ) 56 | 57 | if not base_only: 58 | ret["plus"] = untrusted_check( 59 | dataset, 60 | solution, 61 | problem["plus_input"], 62 | problem["entry_point"], 63 | expected=expected_output["plus"], 64 | atol=problem["atol"], 65 | ref_time=expected_output["plus_time"], 66 | fast_check=fast_check, 67 | min_time_limit=min_time_limit, 68 | gt_time_limit_factor=gt_time_limit_factor, 69 | ) 70 | return ret 71 | 72 | 73 | def evaluate( 74 | source_dataset: str, 75 | output_path: str, 76 | base_only: bool = False, 77 | parallel: int = 0, 78 | i_just_wanna_run: bool = False, 79 | test_details: bool = False, 80 | min_time_limit: float = 0.2, 81 | gt_time_limit_factor: float = 4.0, 82 | mini: bool = False, 83 | ) -> tuple[Any, list[dict[str, Any]]]: 84 | if parallel == 0: 85 | n_workers = max(1, multiprocessing.cpu_count() // 2) 86 | else: 87 | n_workers = parallel 88 | 89 | if os.path.isdir(output_path): 90 | result_path = os.path.join(output_path, "eval_results.json") 91 | else: 92 | assert output_path.endswith(".jsonl") 93 | result_path = output_path.replace(".jsonl", "_eval_results.json") 94 | 95 | if source_dataset == "humaneval": 96 | problems = get_human_eval_plus(mini=mini) 97 | dataset_hash = get_human_eval_plus_hash() 98 | expected_output = get_groundtruth(problems, dataset_hash, []) 99 | elif source_dataset == "mbpp": 100 | problems = get_mbpp_plus(mini=mini) 101 | dataset_hash = get_mbpp_plus_hash() 102 | expected_output = get_groundtruth( 103 | problems, 104 | dataset_hash, 105 | MBPP_OUTPUT_NOT_NONE_TASKS, 106 | ) 107 | 108 | results = { 109 | "date": datetime.now().strftime("%Y-%m-%d %H:%M"), 110 | "hash": dataset_hash, 111 | "eval": {}, 112 | } 113 | 114 | with ProcessPoolExecutor(max_workers=n_workers) as executor: 115 | futures = [] 116 | completion_id: Counter[str] = Counter() 117 | n_samples = 0 118 | eval_results = defaultdict(list) 119 | remainings = set() 120 | sample_details = [] 121 | 122 | logger.info("Reading samples...") 123 | for sample in tqdm(load_solutions(output_path)): 124 | task_id = sample["task_id"] 125 | explanation = sample.get("explanation", "") 126 | solution = ( 127 | sample["solution"] 128 | if "solution" in sample 129 | else problems[task_id]["prompt"] + sample["completion"] 130 | ) 131 | remainings.add(sample["_identifier"]) 132 | 133 | args = ( 134 | source_dataset, 135 | completion_id[task_id], 136 | problems[task_id], 137 | solution, 138 | expected_output[task_id], 139 | base_only, 140 | not test_details, 141 | sample["_identifier"], 142 | min_time_limit, 143 | gt_time_limit_factor, 144 | ) 145 | 146 | futures.append(executor.submit(check_correctness, *args)) 147 | completion_id[task_id] += 1 148 | n_samples += 1 149 | 150 | sample_details.append( 151 | dict( 
152 | task_id=task_id, 153 | solution=solution, 154 | explanation=explanation, 155 | problems=problems[task_id], 156 | expected_output=expected_output[task_id], 157 | ) 158 | ) 159 | 160 | assert n_samples == len(remainings), "Missing problems in unfinished" 161 | if len(completion_id) != len(problems): 162 | logger.warning("Warning: Missing problems in samples") 163 | 164 | def stucking_checker() -> None: 165 | while remainings: 166 | last_size = len(remainings) 167 | time.sleep(20) 168 | if last_size != len(remainings) or len(remainings) == 0: 169 | continue 170 | warn("No samples had finished testing in the last 20s") 171 | warn(f"{len(remainings)} samples to be tested: {remainings}") 172 | 173 | threading.Thread(target=stucking_checker).start() 174 | 175 | for future in tqdm(as_completed(futures), total=n_samples): 176 | result = future.result() 177 | remainings.remove(result["_identifier"]) 178 | eval_results[result["task_id"]].append(result) 179 | 180 | for task_id, task_results in eval_results.items(): 181 | task_results.sort(key=lambda x: x["completion_id"]) 182 | results["eval"][task_id] = { 183 | "nfiles": len(task_results), 184 | "base": [x["base"] for x in task_results], 185 | "plus": ([x["plus"] for x in task_results] if not base_only else []), 186 | } 187 | 188 | if os.path.isfile(result_path) and i_just_wanna_run: 189 | decision = "" 190 | while decision.lower() not in ["y", "n"]: 191 | logger.info( 192 | f"{result_path} already exists. Press [Y/N] to overwrite or exit..." 193 | ) 194 | decision = input() 195 | 196 | if decision.lower() == "y": 197 | new_path = result_path + ".bak" 198 | while os.path.isfile(new_path): 199 | new_path += ".bak" 200 | os.rename(result_path, new_path) 201 | logger.info(f"Backup {result_path} to {new_path}") 202 | 203 | if not os.path.isfile(result_path): 204 | with open(result_path, "w") as f: 205 | json.dump(results, f) 206 | 207 | total = np.array([r["nfiles"] for r in results["eval"].values()]) 208 | base_correct = [] 209 | new_correct = [] 210 | 211 | for key, res in results["eval"].items(): 212 | elements = [element for element in sample_details if element["task_id"] == key] 213 | assert ( 214 | len(elements) == 1 215 | ), f"Expected an element with task_id {key}, found {len(elements)}" 216 | element = elements[0] 217 | 218 | bc = sum([r[0] == SUCCESS for r in res["base"]]) 219 | base_correct.append(bc) 220 | element["base_correct"] = bc 221 | if res["plus"]: 222 | new_bc = sum( 223 | [ 224 | res["plus"][i][0] == res["base"][i][0] == SUCCESS 225 | for i in range(len(res["plus"])) 226 | ] 227 | ) 228 | new_correct.append(new_bc) 229 | element["plus_correct"] = new_bc 230 | 231 | base_correct_array = np.array(base_correct) 232 | 233 | pass_at_k = { 234 | f"pass@{k}": estimate_pass_at_k(total, base_correct_array, k).mean() 235 | for k in [1, 10, 100] 236 | if total.min() >= k 237 | } 238 | 239 | result = {f"{source_dataset}_base_{key}": value for key, value in pass_at_k.items()} 240 | cprint(f"{source_dataset} (base tests)", "red") 241 | for k, v in pass_at_k.items(): 242 | cprint(f"{k}:\t{v:.3f}", "red") 243 | 244 | if new_correct: 245 | cprint(f"{source_dataset}+ (base + extra tests)", "green") 246 | pass_at_k = { 247 | f"pass@{k}": estimate_pass_at_k(total, np.array(new_correct), k).mean() 248 | for k in [1, 10, 100] 249 | if (total >= k).all() 250 | } 251 | result.update( 252 | {f"{source_dataset}_plus_{key}": value for key, value in pass_at_k.items()} 253 | ) 254 | for k, v in pass_at_k.items(): 255 | cprint(f"{k}:\t{v:.3f}", "green") 
256 | 257 | return result, sample_details 258 | -------------------------------------------------------------------------------- /evaluation/fishfarm/fishfarm/tasks/evalplus/generation.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from pathlib import Path 3 | from typing import Iterable, List, Sequence, TypeVar 4 | 5 | from evalplus.data import write_jsonl 6 | from tqdm.auto import tqdm 7 | 8 | from ...models import GenerationRequest, Message, Model 9 | from .data import TextToCodeProblem 10 | 11 | _T = TypeVar("_T") 12 | 13 | 14 | def chunked(seq: Sequence[_T], n: int) -> Iterable[Sequence[_T]]: 15 | """Yield successive n-sized chunks from seq.""" 16 | return (seq[i : i + n] for i in range(0, len(seq), n)) 17 | 18 | 19 | def generate( 20 | model: Model, 21 | problems: list[TextToCodeProblem], 22 | context_messages: Sequence[Message], 23 | output_path: str, 24 | n_batches: int = 1, 25 | n_problems_per_batch: int = 1_000_000_000, 26 | n_samples_per_problem: int = 1, 27 | ) -> List[str]: 28 | problems_chunked = list(chunked(list(problems), n_problems_per_batch)) 29 | iter = itertools.product(problems_chunked, range(n_batches)) 30 | n_total = len(problems_chunked) * n_batches 31 | 32 | Path(output_path).write_text("") 33 | for problems, batch_idx in tqdm(iter, total=n_total): 34 | task_ids = [problem.id for problem in problems] 35 | all_task_ids = task_ids * n_samples_per_problem 36 | 37 | requests = [] 38 | for problem in problems: 39 | messages = list(context_messages) 40 | messages.append(Message(role="user", content=problem.instruction)) 41 | messages.append( 42 | Message(role="assistant_prefill", content=problem.response_prefix) 43 | ) 44 | requests.append(GenerationRequest(messages=messages)) 45 | completes = model.generate(requests) 46 | completions = [c.generation for c in completes] 47 | 48 | assert len(problems) <= n_problems_per_batch 49 | assert len(completions) == len(problems) * n_samples_per_problem 50 | 51 | samples = [] 52 | for task_id, completion in zip(all_task_ids, completions): 53 | completion_body = completion[ 54 | : ( 55 | index 56 | if (index := completion.find("```")) != -1 57 | else len(completion) 58 | ) 59 | ] 60 | explanation = completion[ 61 | ( 62 | index 63 | if (index := completion.find("```") + 3) != -1 64 | else len(completion) 65 | ) : 66 | ].strip() 67 | 68 | samples.append( 69 | dict( 70 | task_id=task_id, 71 | completion=completion_body, 72 | explanation=explanation, 73 | ) 74 | ) 75 | 76 | write_jsonl(output_path, samples, append=True) 77 | return completions 78 | -------------------------------------------------------------------------------- /evaluation/fishfarm/fishfarm/tasks/evalplus/sanitization.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import os 3 | import pathlib 4 | import re 5 | import traceback 6 | from typing import Optional 7 | 8 | from evalplus.data import (get_human_eval_plus, get_mbpp_plus, load_solutions, 9 | write_directory, write_jsonl) 10 | from tqdm.auto import tqdm 11 | 12 | from ...logging import get_logger 13 | 14 | logger = get_logger(__name__) 15 | 16 | 17 | def syntax_check(code: str, verbose: bool = False) -> bool: 18 | try: 19 | ast.parse(code) 20 | return True 21 | except (SyntaxError, MemoryError): 22 | if verbose: 23 | traceback.print_exc() 24 | return False 25 | 26 | 27 | def remove_unindented_lines( 28 | code: str, protect_before: str, execeptions: list[str], trim_tails: list[str] 29 | ) -> str: 
30 | lines = code.splitlines() 31 | cut_idx = [] 32 | cut_enabled = False 33 | for i, line in enumerate(lines): 34 | if not cut_enabled and line.startswith(protect_before): 35 | cut_enabled = True 36 | continue 37 | if line.strip() == "": 38 | continue 39 | if any(line.startswith(e) for e in execeptions): 40 | continue 41 | 42 | lspace = len(line) - len(line.lstrip()) 43 | if lspace == 0: 44 | cut_idx.append(i) 45 | 46 | if any(line.rstrip().startswith(t) for t in trim_tails): 47 | cut_idx.extend(list(range(i, len(lines)))) 48 | break 49 | 50 | return "\n".join([line for i, line in enumerate(lines) if i not in cut_idx]) 51 | 52 | 53 | def to_four_space_indents(old_code: str) -> str: 54 | new_code = "" 55 | for line in old_code.splitlines(): 56 | lspace = len(line) - len(line.lstrip()) 57 | if lspace == 3: 58 | new_code += " " 59 | new_code += line + "\n" 60 | return new_code 61 | 62 | 63 | def sanitize_code( 64 | old_code: str, 65 | entry_point: str, 66 | rm_prefix_lines: Optional[str] = None, 67 | eofs: list = [], 68 | ) -> str: 69 | new_code = old_code 70 | if rm_prefix_lines is not None: 71 | new_code = "\n".join( 72 | [ 73 | line 74 | for line in old_code.splitlines() 75 | if not line.startswith(rm_prefix_lines) 76 | ] 77 | ) 78 | 79 | new_code = "\n" + new_code 80 | def_left = "def " + entry_point 81 | 82 | new_code = new_code.replace("\n```python\n", "\n```\n") 83 | for chunk in new_code.split("\n```\n"): 84 | if def_left in chunk: 85 | new_code = chunk 86 | break 87 | 88 | chunks = [chunk for chunk in re.split(rf"{def_left}\s*\(", new_code)] 89 | bodies = [chunk for chunk in chunks[1:] if " return " in chunk.split("\ndef")[0]] 90 | def_left = def_left + "(" 91 | new_code = def_left + def_left.join(bodies) if len(bodies) > 0 else "" 92 | new_code = to_four_space_indents(new_code) 93 | 94 | for eof in eofs or []: 95 | new_code = new_code.split(eof)[0] 96 | 97 | new_code = remove_unindented_lines( 98 | new_code, 99 | protect_before=def_left, 100 | execeptions=["def ", "import ", "from "], 101 | trim_tails=['"""', "if", "print"], 102 | ) 103 | new_code = chunks[0] + new_code 104 | 105 | parts = new_code.split("\ndef ") 106 | includes = [parts[0]] 107 | for fn in new_code.split("\ndef ")[1:]: 108 | if ( 109 | fn.strip().startswith(entry_point + " ") 110 | or fn.strip().startswith(entry_point + "(") 111 | or syntax_check("\ndef " + fn) 112 | ): 113 | includes.append(fn) 114 | new_code = "\ndef ".join(includes) 115 | return new_code.strip() 116 | 117 | 118 | def sanitize( 119 | source_dataset: str, 120 | input_path: str, 121 | eofs: list = [], 122 | inplace: bool = False, 123 | rm_prefix_lines: Optional[str] = None, 124 | debug_task: Optional[str] = None, 125 | ) -> str: 126 | entry_point = {} 127 | 128 | if source_dataset == "humaneval": 129 | dataset = get_human_eval_plus() 130 | elif source_dataset == "mbpp": 131 | dataset = get_mbpp_plus() 132 | 133 | for task_id, problem in dataset.items(): 134 | entry_point[task_id] = problem["entry_point"] 135 | 136 | is_folder = os.path.isdir(input_path) 137 | target_path = pathlib.Path(input_path) 138 | if not inplace: 139 | if is_folder: 140 | new_name = target_path.name + "-sanitized" 141 | else: 142 | new_name = target_path.name.replace(".jsonl", "-sanitized.jsonl") 143 | target_path = target_path.parent / new_name 144 | output_path = str(target_path) 145 | 146 | nsan = 0 147 | ntotal = 0 148 | 149 | new_solutions = [] 150 | 151 | for solution in tqdm(load_solutions(input_path)): 152 | task_id = solution["task_id"] 153 | dbg_identifier = 
solution["_identifier"] 154 | if debug_task is not None and task_id != debug_task: 155 | continue 156 | 157 | ntotal += 1 158 | if "solution" in solution: 159 | old_code = solution["solution"] 160 | else: 161 | assert "completion" in solution 162 | old_code = dataset[task_id]["prompt"] + "\n" + solution["completion"] 163 | 164 | old_code = old_code.strip() 165 | 166 | new_code = sanitize_code( 167 | old_code=old_code, 168 | entry_point=entry_point[task_id], 169 | rm_prefix_lines=rm_prefix_lines, 170 | eofs=eofs, 171 | ).strip() 172 | 173 | if new_code != old_code: 174 | msg = "Sanitized: " + dbg_identifier 175 | if is_folder: 176 | msg += " -> " + dbg_identifier.replace(input_path, output_path) 177 | logger.info(msg) 178 | nsan += 1 179 | 180 | new_solutions.append( 181 | { 182 | "task_id": task_id, 183 | "solution": new_code, 184 | "explanation": solution["explanation"], 185 | } 186 | ) 187 | 188 | if is_folder: 189 | write_directory(output_path, new_solutions) 190 | else: 191 | write_jsonl(output_path, new_solutions) 192 | 193 | logger.info(f"Sanitized {nsan} out of {ntotal} files.") 194 | 195 | return output_path 196 | -------------------------------------------------------------------------------- /evaluation/fishfarm/fishfarm/tasks/evalplus/task.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | from typing import Literal, Optional, Sequence 3 | 4 | from ...models import Message, Model 5 | from ..base import Task, TaskResult 6 | from . import evaluation, generation, sanitization 7 | from .data import TextToCodeProblem 8 | 9 | 10 | class EvalplusTask(Task): 11 | 12 | def __init__( 13 | self, 14 | samples: Sequence[TextToCodeProblem], 15 | context_messages: Sequence[Message] = (), 16 | source_dataset: Literal["humaneval", "mbpp"] = "humaneval", 17 | ): 18 | self.samples = list(samples) 19 | self.context_messages = context_messages 20 | self.source_dataset = source_dataset 21 | if source_dataset not in ("humaneval", "mbpp"): 22 | raise ValueError(f"Unknown source_dataset: {source_dataset}") 23 | 24 | @property 25 | def num_samples(self) -> int: 26 | return len(self.samples) 27 | 28 | def evaluate( 29 | self, 30 | model: Model, 31 | sample_ids: Optional[Sequence[int]] = None, 32 | ) -> TaskResult: 33 | if sample_ids is None: 34 | sample_ids = range(len(self.samples)) 35 | samples = [self.samples[sample_id] for sample_id in sample_ids] 36 | 37 | with tempfile.TemporaryDirectory() as save_dir: 38 | output_path = f"{save_dir}/outputs.jsonl" 39 | 40 | completions = generation.generate( 41 | model, samples, self.context_messages, output_path 42 | ) 43 | 44 | if self.source_dataset == "mbpp": 45 | output_path = sanitization.sanitize(self.source_dataset, output_path) 46 | 47 | result, sample_details = evaluation.evaluate( 48 | self.source_dataset, output_path 49 | ) 50 | 51 | for i, completion in enumerate(completions): 52 | sample_details[i]["output"] = completion 53 | 54 | return TaskResult(aggregate_metrics=result, sample_details=sample_details) 55 | -------------------------------------------------------------------------------- /evaluation/fishfarm/fishfarm/tasks/language_restricted_math.py: -------------------------------------------------------------------------------- 1 | import re 2 | from dataclasses import dataclass 3 | from typing import Iterable, Optional, Sequence 4 | 5 | import huggingface_hub 6 | 7 | from ..imports import try_import 8 | from ..models import GenerationRequest, Message, Model 9 | from .base import Task, 
TaskResult 10 | 11 | with try_import() as _imports: 12 | import fasttext 13 | 14 | _imports.check() 15 | 16 | 17 | @dataclass 18 | class MathSample: 19 | 20 | problem: str 21 | answer: int 22 | 23 | 24 | def mean(iterable: Iterable[float]) -> float: 25 | total, count = 0.0, 0 26 | for x in iterable: 27 | total += x 28 | count += 1 29 | return total / count 30 | 31 | 32 | def extract_answer_number(completion: str) -> Optional[float]: 33 | matches = re.findall(r"\d*\.?\d+", completion) 34 | if not matches: 35 | return None 36 | text = matches[-1] 37 | return float(text.replace(",", "")) 38 | 39 | 40 | class LanguageRestrictedMathTask(Task): 41 | def __init__( 42 | self, 43 | samples: Sequence[MathSample], 44 | context_messages: Sequence[Message] = (), 45 | languages: Sequence[str] = ("ja", "en"), 46 | ): 47 | self.samples = list(samples) 48 | self.languages = languages 49 | self.context_messages = context_messages 50 | if len(self.languages) != 0: 51 | lid176ftz_path = huggingface_hub.hf_hub_download( 52 | "julien-c/fasttext-language-id", "lid.176.ftz" 53 | ) 54 | self.lid_model = fasttext.load_model(lid176ftz_path) 55 | 56 | @property 57 | def num_samples(self) -> int: 58 | return len(self.samples) 59 | 60 | def evaluate( 61 | self, 62 | model: Model, 63 | sample_ids: Optional[Sequence[int]] = None, 64 | ) -> TaskResult: 65 | if sample_ids is None: 66 | sample_ids = range(len(self.samples)) 67 | samples = [self.samples[sample_id] for sample_id in sample_ids] 68 | 69 | requests = [] 70 | for sample in samples: 71 | messages = list(self.context_messages) 72 | messages.append(Message(role="user", content=sample.problem)) 73 | requests.append(GenerationRequest(messages=messages)) 74 | 75 | sample_details = [] 76 | for sample, result in zip(samples, model.generate(requests)): 77 | output = result.generation 78 | prediction = extract_answer_number(result.generation) 79 | if len(self.languages) != 0: 80 | lid_probs = dict( 81 | zip(*self.lid_model.predict(output.replace("\n", ""), k=-1)) 82 | ) 83 | 84 | sample_details.append( 85 | dict( 86 | problem=sample.problem, 87 | output=output, 88 | answer=sample.answer, 89 | prediction=prediction, 90 | correct=sample.answer == prediction, 91 | **{ 92 | f"lang_{lang}": lid_probs.get(f"__label__{lang}", 0.0) 93 | for lang in self.languages 94 | }, 95 | ) 96 | ) 97 | 98 | aggregate_metrics = {"acc": mean(sd["correct"] for sd in sample_details)} 99 | for lang in self.languages: 100 | aggregate_metrics[f"acc_{lang}"] = mean( 101 | (sd["correct"] and sd[f"lang_{lang}"] > 0.5) for sd in sample_details 102 | ) 103 | 104 | return TaskResult( 105 | aggregate_metrics=aggregate_metrics, sample_details=sample_details 106 | ) 107 | -------------------------------------------------------------------------------- /evaluation/fishfarm/fishfarm/version.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.1.0dev" 2 | -------------------------------------------------------------------------------- /evaluation/fishfarm/pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "fishfarm" 3 | description = "" 4 | readme = "README.md" 5 | license = {file = "LICENSE"} 6 | authors = [ 7 | {name = "Takuya Akiba"}, 8 | {email = "takiba@sakana.ai"} 9 | ] 10 | classifiers = [ 11 | "Development Status :: 2 - Pre-Alpha", 12 | "Intended Audience :: Science/Research", 13 | "Intended Audience :: Developers", 14 | "License :: OSI Approved :: MIT License", 15 | "Programming 
Language :: Python :: 3", 16 | "Programming Language :: Python :: 3.8", 17 | "Programming Language :: Python :: 3.9", 18 | "Programming Language :: Python :: 3.10", 19 | "Programming Language :: Python :: 3.11", 20 | "Programming Language :: Python :: 3 :: Only", 21 | "Topic :: Scientific/Engineering", 22 | "Topic :: Scientific/Engineering :: Mathematics", 23 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 24 | "Topic :: Software Development", 25 | "Topic :: Software Development :: Libraries", 26 | "Topic :: Software Development :: Libraries :: Python Modules", 27 | ] 28 | requires-python = ">=3.10" 29 | dependencies = [ 30 | "huggingface_hub", 31 | "transformers", 32 | "pydantic", 33 | "colorlog" 34 | ] 35 | dynamic = ["version"] 36 | 37 | [project.optional-dependencies] 38 | development = [ 39 | "black", 40 | "blackdoc", 41 | "flake8", 42 | "isort", 43 | "mypy", 44 | "pytest", 45 | "pytest-mock", 46 | "types-PyYAML", 47 | ] 48 | 49 | full = [ 50 | "vllm", 51 | "langchain", 52 | "langchain-openai", 53 | "fasttext-wheel", 54 | "datasets", 55 | "mysql-connector-python==8.0.32", 56 | "docker==6.1.2", 57 | "evalplus @ git+https://github.com/evalplus/evalplus@1895d2f6aa8895044a7cf69defc24bd57695e885", 58 | "rouge-score" 59 | ] 60 | 61 | [project.urls] 62 | repository = "https://github.com/SakanaAI/fishfarm" 63 | 64 | [tool.setuptools.packages.find] 65 | include = ["fishfarm*"] 66 | 67 | [tool.setuptools.dynamic] 68 | version = {attr = "fishfarm.version.__version__"} 69 | 70 | [tool.black] 71 | line-length = 99 72 | target-version = ['py310'] 73 | exclude = ''' 74 | /( 75 | \.eggs 76 | | \.git 77 | | \.hg 78 | | \.mypy_cache 79 | | \.venv 80 | | venv 81 | | _build 82 | | buck-out 83 | | build 84 | | dist 85 | | docs 86 | | data 87 | )/ 88 | ''' 89 | 90 | [tool.isort] 91 | profile = 'black' 92 | src_paths = ['fishfarm', 'tests'] 93 | line_length = 99 94 | lines_after_imports = 2 95 | 96 | [tool.mypy] 97 | python_version = "3.10" 98 | strict = true 99 | ignore_missing_imports = true 100 | warn_unused_configs = true 101 | disallow_untyped_defs = true 102 | warn_redundant_casts = true 103 | warn_unused_ignores = true 104 | warn_unreachable = true 105 | disallow_any_generics = false 106 | exclude = ".venv|venv|build|docs|tutorial|data" 107 | 108 | [tool.pytest] 109 | mock_use_standalone_module = true 110 | -------------------------------------------------------------------------------- /evaluation/fishfarm/tox.ini: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 99 3 | statistics = True 4 | exclude = .venv,venv,build,notebooks,.asv,data 5 | ignore = 6 | E203, 7 | W503, 8 | E704 -------------------------------------------------------------------------------- /logging_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def get_mean_std_max_min_dict(array, prefix): 5 | res = {} 6 | res[prefix + "/mean"] = np.mean(array) 7 | res[prefix + "/std"] = np.std(array) 8 | res[prefix + "/min"] = np.amin(array) 9 | res[prefix + "/max"] = np.amax(array) 10 | return res 11 | 12 | 13 | class Metrics: 14 | """Object keeping running average/latest of relevant metrics to log.""" 15 | 16 | def __init__(self, *args): 17 | self.metrics = {arg: 0 for arg in args} 18 | self.latest_metrics = {arg: 0 for arg in args} 19 | self.samples = {arg: 1e-8 for arg in args} 20 | self.logged_metrics = [arg for arg in args] 21 | 22 | def reset(self): 23 | for arg in self.metrics: 
24 | self.metrics[arg] = 0 25 | self.samples[arg] = 1e-8 26 | 27 | def add(self, *args): 28 | for arg in args: 29 | if arg not in self.metrics: 30 | self.logged_metrics.append(arg) 31 | self.metrics[arg] = 0 32 | self.latest_metrics[arg] = 0 33 | self.samples[arg] = 1e-8 34 | 35 | def update(self, **kwargs): 36 | for arg, val in kwargs.items(): 37 | if arg not in self.metrics: 38 | self.logged_metrics.append(arg) 39 | self.metrics[arg] = 0 40 | self.latest_metrics[arg] = 0 41 | self.samples[arg] = 1e-8 42 | self.metrics[arg] += val 43 | self.samples[arg] += 1 44 | 45 | def set(self, **kwargs): 46 | for arg, val in kwargs.items(): 47 | if arg not in self.metrics: 48 | self.logged_metrics.append(arg) 49 | self.metrics[arg] = val 50 | self.samples[arg] = 1 51 | self.metrics[arg] = val 52 | self.samples[arg] = 1 53 | 54 | def get(self): 55 | for arg, metric_agg in self.metrics.items(): 56 | samples = self.samples[arg] 57 | if samples >= 1: 58 | self.latest_metrics[arg] = metric_agg / samples 59 | return self.latest_metrics 60 | -------------------------------------------------------------------------------- /optim_modules.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from logging_utils import get_mean_std_max_min_dict 9 | from utils import (backward, eval_model, forward, load_base_params, 10 | load_hf_params_to_vllm) 11 | 12 | 13 | class OptimizationAlgorithm(abc.ABC): 14 | def __init__(self, **kwargs): 15 | nn.Module.__init__(self=self) 16 | 17 | @abc.abstractmethod 18 | def step_optimization( 19 | self, 20 | model_id, 21 | model, 22 | tokenizer, 23 | policy, 24 | task_loader, 25 | batch_ix, 26 | train_data, 27 | train_eval, 28 | base_params, 29 | decomposed_params, 30 | original_model_params, 31 | metrics_to_log, 32 | vllm_model=None, 33 | **kwargs, 34 | ): 35 | raise NotImplementedError 36 | 37 | @abc.abstractmethod 38 | def update(self, policy): 39 | raise NotImplementedError 40 | 41 | def log_optim(self, metrics_to_log): 42 | pass 43 | 44 | 45 | class Reinforce(OptimizationAlgorithm, nn.Module): 46 | def __init__( 47 | self, policy, gpu, max_grad_norm, lr, rw_norm, rw_clip, kl_ref_coeff, **kwargs 48 | ): 49 | nn.Module.__init__(self=self) 50 | self.gpu = gpu 51 | self.kl_ref_coeff = kl_ref_coeff 52 | self.use_kl_loss = kl_ref_coeff > 0.0 53 | self.max_grad_norm = float(max_grad_norm) 54 | self.lr = lr 55 | self.rw_norm = rw_norm 56 | self.rw_clip = rw_clip 57 | self.optimizer = torch.optim.Adam(policy.trainable_params, lr=lr) 58 | 59 | def compute_ref_logprobs( 60 | self, 61 | model, 62 | tokenizer, 63 | prompts, 64 | res, 65 | ): 66 | ref_log_probs_list = [] 67 | print("Computing reference log probs...") 68 | for j, prompt in enumerate(prompts): 69 | input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(self.gpu) 70 | prompt_length = input_ids.shape[-1] 71 | output_ids = tokenizer( 72 | prompt + res.sample_details[j]["output"], 73 | return_tensors="pt", 74 | ).input_ids.to(self.gpu) 75 | outputs = model(output_ids) 76 | logits = outputs.logits[:, prompt_length - 1 : -1] 77 | log_probs = torch.nn.functional.log_softmax(logits, dim=-1) 78 | ref_log_probs_list.append(log_probs.detach().cpu()) 79 | return ref_log_probs_list 80 | 81 | def get_rewards(self, task_loader, res): 82 | rw_norm = self.rw_norm 83 | rw_clip = self.rw_clip 84 | rewards = task_loader.get_rewards(res=res) 85 | 86 | if rw_norm: 87 | rewards = 
np.array(rewards) 88 | mean_rw = np.mean(rewards) 89 | std_rw = np.clip(np.std(rewards), a_min=1e-7, a_max=None) 90 | rewards = (rewards - mean_rw) / std_rw 91 | if rw_clip is not None: 92 | if rw_clip > 0: 93 | rewards = np.array(rewards) 94 | rewards = np.clip(rewards, a_min=-rw_clip, a_max=rw_clip) 95 | return rewards 96 | 97 | def step_optimization( 98 | self, 99 | model_id, 100 | model, 101 | tokenizer, 102 | policy, 103 | task_loader, 104 | batch_ix, 105 | train_data, 106 | train_eval, 107 | base_params, 108 | decomposed_params, 109 | original_model_params, 110 | metrics_to_log, 111 | vllm_model=None, 112 | **kwargs, 113 | ): 114 | use_kl_loss = self.use_kl_loss 115 | kl_ref_coeff = self.kl_ref_coeff 116 | 117 | gpu = self.gpu 118 | 119 | prompts = [ 120 | task_loader.get_prompt( 121 | tokenizer, 122 | train_data, 123 | i, 124 | model_id=model_id, 125 | ) 126 | for i in batch_ix 127 | ] 128 | 129 | clipped_batch_size = len(prompts) 130 | 131 | learnable_params = policy.get_learnable_params() 132 | new_params = forward( 133 | policy, model, base_params, decomposed_params, learnable_params 134 | ) 135 | 136 | print("Loading weights and getting completions with VLLM") 137 | load_hf_params_to_vllm(new_params, vllm_model.llm) 138 | res = eval_model(vllm_model, train_eval, batch_ix) 139 | rewards = self.get_rewards(task_loader=task_loader, res=res) 140 | 141 | rw_stats = get_mean_std_max_min_dict(array=rewards, prefix="rewards") 142 | metrics_to_log.update(**rw_stats) 143 | 144 | if use_kl_loss: 145 | with torch.no_grad(): 146 | load_base_params(model=model, base_params=original_model_params) 147 | ref_log_probs_list = self.compute_ref_logprobs( 148 | model=model, 149 | tokenizer=tokenizer, 150 | prompts=prompts, 151 | res=res, 152 | ) 153 | new_params = forward( 154 | policy, model, base_params, decomposed_params, learnable_params 155 | ) 156 | 157 | print("Computing the policy gradient...") 158 | for j, prompt in enumerate(prompts): 159 | input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(gpu) 160 | prompt_length = input_ids.shape[-1] 161 | output_ids = tokenizer( 162 | prompt + res.sample_details[j]["output"], 163 | return_tensors="pt", 164 | ).input_ids.to(gpu) 165 | generated_ids = output_ids[:, prompt_length:] 166 | 167 | outputs = model(output_ids) 168 | logits = outputs.logits[:, prompt_length - 1 : -1] 169 | log_probs = torch.nn.functional.log_softmax(logits, dim=-1) 170 | selected_log_probs = log_probs.gather( 171 | 2, generated_ids.unsqueeze(-1) 172 | ).squeeze(-1) 173 | log_likelihood = selected_log_probs.sum(axis=-1) 174 | 175 | pg = -log_likelihood * rewards[j] 176 | loss = pg 177 | 178 | if use_kl_loss: 179 | ref_log_probs = ref_log_probs_list[j].to(gpu) 180 | kl_div = F.kl_div( 181 | input=log_probs, 182 | target=ref_log_probs, 183 | log_target=True, 184 | reduction="sum", 185 | ) 186 | loss = loss + kl_ref_coeff * kl_div 187 | scaled_loss = loss / clipped_batch_size 188 | scaled_loss.backward() 189 | log_dict = { 190 | "pg": pg.item(), 191 | "loss": loss.item(), 192 | } 193 | if use_kl_loss: 194 | log_dict["kl_div"] = kl_div.item() 195 | metrics_to_log.update(**log_dict) 196 | backward(policy, model, base_params, decomposed_params, learnable_params) 197 | 198 | def update(self, policy): 199 | max_grad_norm = self.max_grad_norm 200 | torch.nn.utils.clip_grad_norm_(policy.trainable_params, max_grad_norm) 201 | self.optimizer.step() 202 | self.optimizer.zero_grad() 203 | 204 | def log_optim(self, metrics_to_log): 205 | metrics_dict = metrics_to_log.get() 206 | pg 
= metrics_dict["pg"] 207 | print(f"PG={pg}") 208 | if self.use_kl_loss: 209 | kl_div = metrics_dict["kl_div"] 210 | print(f"kl_div={kl_div}") 211 | 212 | 213 | class RandomShooting(OptimizationAlgorithm, nn.Module): 214 | def __init__( 215 | self, 216 | policy, 217 | gpu, 218 | pop_size, 219 | min_trainable_param, 220 | max_trainable_param, 221 | optim_ema=0, 222 | re_eval_best=True, 223 | use_loglikelihood_for_ties=False, 224 | **kwargs, 225 | ): 226 | 227 | nn.Module.__init__(self=self) 228 | self.gpu = gpu 229 | trainable_params = policy.trainable_params 230 | self.pop_size = pop_size 231 | self.min_trainable_param = min_trainable_param 232 | self.max_trainable_param = max_trainable_param 233 | self.range_trainable_param = max_trainable_param - min_trainable_param 234 | assert optim_ema >= 0 and optim_ema < 1 235 | self.optim_ema = optim_ema 236 | self.re_eval_best = re_eval_best 237 | self.use_loglikelihood_for_ties = use_loglikelihood_for_ties 238 | 239 | self.trainable_params_shapes = [p.shape for p in trainable_params] 240 | self.trainable_params_nums = [torch.numel(p) for p in trainable_params] 241 | self.trainable_params_dtype = trainable_params[0].dtype 242 | self.total_trainable_params = sum(self.trainable_params_nums) 243 | self.best_idx = 0 244 | 245 | initial_values = ( 246 | torch.rand(size=[pop_size, self.total_trainable_params]) 247 | * self.range_trainable_param 248 | ) + self.min_trainable_param 249 | init_values_flat = [ 250 | torch.flatten(torch.detach_copy(p.data)) for p in trainable_params 251 | ] 252 | init_soln = torch.concat(init_values_flat, dim=0) 253 | if self.re_eval_best: 254 | initial_values[0] = torch.clone(init_soln) 255 | 256 | self.pop_params = nn.Parameter( 257 | initial_values, 258 | requires_grad=False, 259 | ).cpu() 260 | self.best_soln = nn.Parameter(init_soln, requires_grad=False).cpu() 261 | 262 | def compute_logprobs( 263 | self, 264 | model, 265 | tokenizer, 266 | prompts, 267 | generated_outputs, 268 | ): 269 | selected_log_probs_list = [] 270 | for j, prompt in enumerate(prompts): 271 | input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(self.gpu) 272 | prompt_length = input_ids.shape[-1] 273 | output_ids = tokenizer( 274 | prompt + generated_outputs[j], 275 | return_tensors="pt", 276 | ).input_ids.to(self.gpu) 277 | generated_ids = output_ids[:, prompt_length:] 278 | 279 | outputs = model(output_ids) 280 | logits = outputs.logits[:, prompt_length - 1 : -1] 281 | log_probs = torch.nn.functional.log_softmax(logits, dim=-1) 282 | selected_log_probs = log_probs.gather( 283 | 2, generated_ids.unsqueeze(-1) 284 | ).squeeze(-1) 285 | selected_log_probs_list.append(selected_log_probs.detach().cpu()) 286 | return selected_log_probs_list 287 | 288 | @torch.no_grad 289 | def sample_new_params( 290 | self, 291 | ): 292 | pop_values = ( 293 | torch.rand(size=[self.pop_size, self.total_trainable_params]) 294 | * self.range_trainable_param 295 | ) + self.min_trainable_param 296 | if self.re_eval_best: 297 | pop_values[0] = torch.detach_copy(self.best_soln) 298 | 299 | self.pop_params.data.copy_(pop_values) 300 | 301 | def split_and_convert(self, flat_params): 302 | split_flat_params = torch.split_with_sizes( 303 | flat_params, split_sizes=self.trainable_params_nums 304 | ) 305 | split_params = [ 306 | torch.reshape(p, shape=s).to(dtype=self.trainable_params_dtype).to(self.gpu) 307 | for p, s in zip(split_flat_params, self.trainable_params_shapes) 308 | ] 309 | return split_params 310 | 311 | def get_params_for_pop_member(self, pop_idx): 312 | 
return self.split_and_convert(self.pop_params[pop_idx]) 313 | 314 | @torch.no_grad 315 | def step_optimization( 316 | self, 317 | model_id, 318 | model, 319 | tokenizer, 320 | policy, 321 | task_loader, 322 | batch_ix, 323 | train_data, 324 | train_eval, 325 | base_params, 326 | decomposed_params, 327 | metrics_to_log, 328 | vllm_model=None, 329 | **kwargs, 330 | ): 331 | self.sample_new_params() 332 | perf_per_pop = [] 333 | avg_log_likelihoods_per_pop = [] 334 | for pop_idx in range(self.pop_size): 335 | pop_idx_params = self.split_and_convert( 336 | flat_params=self.pop_params[pop_idx] 337 | ) 338 | policy.set_trainable_params_values(new_values=pop_idx_params) 339 | learnable_params = policy.get_learnable_params() 340 | new_params = forward( 341 | policy, model, base_params, decomposed_params, learnable_params 342 | ) 343 | 344 | print("Loading weights and getting completions with VLLM") 345 | load_hf_params_to_vllm(new_params, vllm_model.llm) 346 | res = eval_model(vllm_model, train_eval, batch_ix) 347 | if self.use_loglikelihood_for_ties: 348 | print("Storing log likelihhods") 349 | rewards = task_loader.get_rewards(res=res) 350 | correct = [int(r > 0) for r in rewards] 351 | correct_batch_ix = [i for i, c in zip(batch_ix, correct) if c] 352 | if len(correct_batch_ix) > 0: 353 | avg_log_likelihoods = [] 354 | correct_prompts = [ 355 | task_loader.get_prompt( 356 | tokenizer, 357 | train_data, 358 | i, 359 | model_id=model_id, 360 | ) 361 | for i in correct_batch_ix 362 | ] 363 | correct_outputs = [ 364 | res.sample_details[j]["output"] 365 | for j, c in enumerate(correct) 366 | if c 367 | ] 368 | selected_log_probs_list = self.compute_logprobs( 369 | model=model, 370 | tokenizer=tokenizer, 371 | prompts=correct_prompts, 372 | generated_outputs=correct_outputs, 373 | ) 374 | for selected_log_probs in selected_log_probs_list: 375 | avg_log_likelihood = selected_log_probs.mean(axis=-1) 376 | avg_log_likelihoods.append(avg_log_likelihood.item()) 377 | avg_log_likelihoods_per_pop.append(np.mean(avg_log_likelihoods)) 378 | else: 379 | avg_log_likelihoods_per_pop.append(0.0) 380 | 381 | perf = res.aggregate_metrics[task_loader.target_metric_train] 382 | perf_per_pop.append(perf) 383 | 384 | perf_stats = get_mean_std_max_min_dict(array=perf_per_pop, prefix="pop_perf") 385 | metrics_to_log.update(**perf_stats) 386 | 387 | if self.use_loglikelihood_for_ties: 388 | perf_per_pop_array = np.array(perf_per_pop) 389 | loglikelihood_array = np.array(avg_log_likelihoods_per_pop) 390 | max_perf = perf_per_pop_array == np.max(perf_per_pop_array) 391 | max_perf_idxs = np.flatnonzero(max_perf) 392 | max_perf_logprobs = loglikelihood_array[max_perf_idxs] 393 | print("SC CHECK") 394 | print(perf_per_pop) 395 | print(loglikelihood_array) 396 | print(max_perf_idxs) 397 | best_logprob_idx = np.argmax(max_perf_logprobs) 398 | best_member_idx = max_perf_idxs[best_logprob_idx] 399 | print(best_logprob_idx) 400 | print(best_member_idx) 401 | logprobs_stats = get_mean_std_max_min_dict( 402 | array=max_perf_logprobs, prefix="logprobs_correct" 403 | ) 404 | metrics_to_log.update(**logprobs_stats) 405 | else: 406 | best_member_idx = np.argmax(perf_per_pop) 407 | self.best_idx = best_member_idx 408 | best_params = self.pop_params[best_member_idx].cpu() 409 | self.best_soln.data.copy_( 410 | best_params * (1 - self.optim_ema) + self.optim_ema * self.best_soln.cpu() 411 | ) 412 | 413 | def update(self, policy): 414 | policy.set_trainable_params_values( 415 | new_values=self.split_and_convert(self.best_soln) 416 | ) 417 | 
418 | def log_optim(self, metrics_to_log): 419 | pass 420 | 421 | 422 | class CEM(RandomShooting): 423 | def __init__( 424 | self, 425 | policy, 426 | gpu, 427 | elite_ratio, 428 | pop_size, 429 | min_trainable_param, 430 | max_trainable_param, 431 | optim_ema=0, 432 | re_eval_best=True, 433 | use_loglikelihood_for_ties=False, 434 | **kwargs, 435 | ): 436 | 437 | RandomShooting.__init__( 438 | self=self, 439 | policy=policy, 440 | gpu=gpu, 441 | pop_size=pop_size, 442 | min_trainable_param=min_trainable_param, 443 | max_trainable_param=max_trainable_param, 444 | optim_ema=optim_ema, 445 | re_eval_best=re_eval_best, 446 | use_loglikelihood_for_ties=use_loglikelihood_for_ties, 447 | **kwargs, 448 | ) 449 | 450 | self.elite_ratio = elite_ratio 451 | self.num_elites = int(elite_ratio * pop_size) 452 | self.dist_mean = nn.Parameter( 453 | torch.detach_copy(self.best_soln), requires_grad=False 454 | ).cpu() 455 | init_stdev = ( 456 | torch.ones([self.total_trainable_params]) * self.range_trainable_param / 2 457 | ) 458 | self.dist_std = nn.Parameter(init_stdev, requires_grad=False).cpu() 459 | 460 | @torch.no_grad 461 | def sample_new_params( 462 | self, 463 | ): 464 | pop_values = ( 465 | torch.randn(size=[self.pop_size, self.total_trainable_params]) 466 | * self.dist_std 467 | ) + self.dist_mean 468 | pop_values = torch.clamp( 469 | pop_values, 470 | min=self.min_trainable_param, 471 | max=self.max_trainable_param, 472 | ) 473 | if self.re_eval_best: 474 | pop_values[0] = torch.detach_copy(self.best_soln) 475 | 476 | self.pop_params.data.copy_(pop_values) 477 | 478 | @torch.no_grad 479 | def step_optimization( 480 | self, 481 | model_id, 482 | model, 483 | tokenizer, 484 | policy, 485 | task_loader, 486 | batch_ix, 487 | train_data, 488 | train_eval, 489 | base_params, 490 | decomposed_params, 491 | metrics_to_log, 492 | vllm_model=None, 493 | **kwargs, 494 | ): 495 | self.sample_new_params() 496 | perf_per_pop = [] 497 | avg_log_likelihoods_per_pop = [] 498 | for pop_idx in range(self.pop_size): 499 | pop_idx_params = self.split_and_convert( 500 | flat_params=self.pop_params[pop_idx] 501 | ) 502 | policy.set_trainable_params_values(new_values=pop_idx_params) 503 | learnable_params = policy.get_learnable_params() 504 | new_params = forward( 505 | policy, model, base_params, decomposed_params, learnable_params 506 | ) 507 | 508 | print("Loading weights and getting completions with VLLM") 509 | load_hf_params_to_vllm(new_params, vllm_model.llm) 510 | res = eval_model(vllm_model, train_eval, batch_ix) 511 | if self.use_loglikelihood_for_ties: 512 | print("Storing log likelihhods") 513 | rewards = task_loader.get_rewards(res=res) 514 | correct = [int(r > 0) for r in rewards] 515 | correct_batch_ix = [i for i, c in zip(batch_ix, correct) if c] 516 | if len(correct_batch_ix) > 0: 517 | avg_log_likelihoods = [] 518 | correct_prompts = [ 519 | task_loader.get_prompt( 520 | tokenizer, 521 | train_data, 522 | i, 523 | model_id=model_id, 524 | ) 525 | for i in correct_batch_ix 526 | ] 527 | correct_outputs = [ 528 | res.sample_details[j]["output"] 529 | for j, c in enumerate(correct) 530 | if c 531 | ] 532 | print("lalala, I am hitting the selected_log_probs!") 533 | selected_log_probs_list = self.compute_logprobs( 534 | model=model, 535 | tokenizer=tokenizer, 536 | prompts=correct_prompts, 537 | generated_outputs=correct_outputs, 538 | ) 539 | for selected_log_probs in selected_log_probs_list: 540 | avg_log_likelihood = selected_log_probs.mean(axis=-1) 541 | 
avg_log_likelihoods.append(avg_log_likelihood.item()) 542 | avg_log_likelihoods_per_pop.append(np.mean(avg_log_likelihoods)) 543 | else: 544 | avg_log_likelihoods_per_pop.append(0.0) 545 | 546 | perf = res.aggregate_metrics[task_loader.target_metric_train] 547 | perf_per_pop.append(perf) 548 | 549 | perf_stats = get_mean_std_max_min_dict(array=perf_per_pop, prefix="pop_perf") 550 | metrics_to_log.update(**perf_stats) 551 | 552 | if self.use_loglikelihood_for_ties: 553 | perf_per_pop_array = np.array(perf_per_pop) 554 | loglikelihood_array = np.array(avg_log_likelihoods_per_pop) 555 | max_perf = perf_per_pop_array == np.max(perf_per_pop_array) 556 | max_perf_idxs = np.flatnonzero(max_perf) 557 | max_perf_logprobs = loglikelihood_array[max_perf_idxs] 558 | best_logprob_idx = np.argmax(max_perf_logprobs) 559 | best_member_idx = max_perf_idxs[best_logprob_idx] 560 | logprobs_stats = get_mean_std_max_min_dict( 561 | array=max_perf_logprobs, prefix="logprobs_correct" 562 | ) 563 | metrics_to_log.update(**logprobs_stats) 564 | else: 565 | best_member_idx = np.argmax(perf_per_pop) 566 | elite_idxs = np.argpartition(perf_per_pop, -self.num_elites)[-self.num_elites :] 567 | 568 | elite_params = self.pop_params[elite_idxs] 569 | elite_mean = torch.mean(elite_params, dim=0) 570 | elite_std = torch.std(elite_params, dim=0) 571 | self.best_idx = best_member_idx 572 | best_params = self.pop_params[best_member_idx].cpu() 573 | self.best_soln.data.copy_(best_params) 574 | self.dist_mean.copy_( 575 | elite_mean.cpu() * (1 - self.optim_ema) 576 | + self.optim_ema * self.dist_mean.cpu() 577 | ) 578 | self.dist_std.copy_( 579 | elite_std.cpu() * (1 - self.optim_ema) 580 | + self.optim_ema * self.dist_std.cpu() 581 | ) 582 | 583 | cem_mean_stats = get_mean_std_max_min_dict( 584 | array=self.dist_mean.detach().cpu().numpy(), 585 | prefix="cem_mean", 586 | ) 587 | metrics_to_log.update(**cem_mean_stats) 588 | 589 | cem_std_stats = get_mean_std_max_min_dict( 590 | array=self.dist_std.detach().cpu().numpy(), 591 | prefix="cem_std", 592 | ) 593 | metrics_to_log.update(**cem_std_stats) 594 | -------------------------------------------------------------------------------- /policy/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import Policy 2 | from .weighted_combination import WeightedCombination 3 | -------------------------------------------------------------------------------- /policy/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def get_soft_mask(n, fraction): 6 | indices = torch.linspace(0, n - 1, n, dtype=torch.bfloat16) + 1 7 | scaled_indices = indices.to(fraction.device) - fraction * n 8 | result = torch.clamp(scaled_indices, 0, 1) 9 | return 1.0 - result 10 | 11 | 12 | class Policy(nn.Module): 13 | def __init__(self, base_params, gpu, init_val, max_mult=1, **kwargs): 14 | # Create learnable parameters. 
15 | super().__init__() 16 | self.learnable_params = {} 17 | self.num_params = 0 18 | self.max_mult = max_mult 19 | for k, v in base_params.items(): 20 | # each param initialized with small gaussian noise 21 | if "mlp" in k: 22 | self.learnable_params[k] = torch.nn.Parameter( 23 | data=( 24 | torch.randn( 25 | min(v.shape), 26 | device=gpu, 27 | dtype=torch.bfloat16, 28 | ) 29 | * 0.01 30 | + init_val 31 | ), 32 | requires_grad=True, 33 | ) 34 | self.num_params += self.learnable_params[k].numel() 35 | print(f"#params={self.num_params}") 36 | self.learnable_params_list = list(self.learnable_params.values()) 37 | self.trainable_params = self.learnable_params_list 38 | self.learnable_params_module_list = nn.ParameterList(self.learnable_params_list) 39 | 40 | def get_learnable_params(self, detach=False): 41 | return self.learnable_params 42 | 43 | def set_trainable_params_values(self, new_values): 44 | with torch.no_grad(): 45 | for p, v in zip(self.trainable_params, new_values): 46 | p.data.copy_(v) 47 | 48 | def get_mask(self, p): 49 | return torch.sigmoid(p).to(torch.bfloat16) * self.max_mult 50 | 51 | def record_state(self, metrics_to_log): 52 | pass 53 | -------------------------------------------------------------------------------- /policy/weighted_combination.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional, Union 2 | 3 | import hydra 4 | import torch 5 | import torch.nn as nn 6 | from omegaconf import DictConfig 7 | 8 | from .base import Policy 9 | 10 | 11 | class WeightedCombination(Policy): 12 | def __init__( 13 | self, 14 | base_params, 15 | decomposed_params, 16 | base_policy_cfg: Optional[Union[DictConfig, int]], 17 | params_paths: List[str], 18 | gpu, 19 | norm_coeffs, 20 | per_layer, 21 | init_values: Optional[List[float]] = None, 22 | **kwargs, 23 | ): 24 | # Create learnable parameters. 
25 | nn.Module.__init__(self=self) 26 | weights_dict_list: List[Dict[str, torch.Tensor]] = [] 27 | if base_policy_cfg is None: 28 | base_policy = Policy(base_params=base_params, gpu=gpu, init_val=0) 29 | elif isinstance(base_policy_cfg, DictConfig): 30 | base_policy: Policy = hydra.utils.instantiate( 31 | base_policy_cfg, 32 | base_params=base_params, 33 | decomposed_params=decomposed_params, 34 | gpu=gpu, 35 | ) 36 | else: 37 | raise NotImplementedError 38 | 39 | with torch.no_grad(): 40 | for i, load_ckpt in enumerate(params_paths): 41 | print(f"Loading checkpoint {i} at {load_ckpt}...") 42 | if "learnable_params" in load_ckpt: 43 | learnable_params = torch.load(load_ckpt) 44 | else: 45 | state_dict = torch.load(load_ckpt, weights_only=True) 46 | base_policy.load_state_dict(state_dict=state_dict) 47 | learnable_params = base_policy.get_learnable_params() 48 | weights_dict_list.append( 49 | {k: torch.detach_copy(p) for k, p in learnable_params.items()} 50 | ) 51 | 52 | self.num_weights_dict = len(weights_dict_list) 53 | 54 | self.num_params_per_weight_dict = 0 55 | for _ in weights_dict_list[0]: 56 | self.num_params_per_weight_dict += 1 57 | 58 | self.num_params = self.num_weights_dict * self.num_params_per_weight_dict 59 | if init_values is None: 60 | init_values = torch.Tensor( 61 | [1 / self.num_weights_dict for _ in range(self.num_weights_dict)] 62 | ) 63 | else: 64 | assert len(init_values) == self.num_weights_dict 65 | init_values = torch.Tensor(init_values) 66 | self.learned_params_per_weight_dict = 1 67 | if per_layer: 68 | self.learned_params_per_weight_dict = self.num_params_per_weight_dict 69 | init_values = torch.stack( 70 | [init_values for _ in range(self.learned_params_per_weight_dict)], dim=1 71 | ) 72 | if norm_coeffs: 73 | # Normalize across different weight idxs (for all layers) 74 | init_values = init_values / torch.sum(init_values, axis=0) 75 | 76 | # Num weight idxs x learned params_per_weight_idx 77 | self.adaptive_weights = torch.nn.Parameter( 78 | data=init_values, 79 | requires_grad=True, 80 | ) 81 | 82 | self.parameter_keys = [] 83 | self.original_params = {} 84 | for k, v in weights_dict_list[0].items(): 85 | self.parameter_keys.append(k) 86 | self.original_params[k] = [] 87 | for i, weight_dict in enumerate(weights_dict_list): 88 | weight_tensor = self.get_mask(p=weight_dict[k]) 89 | new_key = k.replace(".", "_") 90 | self.register_buffer( 91 | f"weights_{i}_k_{new_key}", 92 | tensor=weight_tensor, 93 | ) 94 | self.original_params[k].append(weight_tensor.to(device=gpu)) 95 | 96 | self.norm = norm_coeffs 97 | self.per_layer = per_layer 98 | self.trainable_params = [self.adaptive_weights] 99 | 100 | def get_weight_to_combine(self, k, weights_dict_idx): 101 | new_key = k.replace(".", "_") 102 | return getattr(self, f"weights_{weights_dict_idx}_k_{new_key}") 103 | 104 | def get_coeff_per_layer(self): 105 | if self.norm: 106 | adaptive_weights = self.adaptive_weights / self.adaptive_weights.sum(0) 107 | else: 108 | adaptive_weights = self.adaptive_weights 109 | weights_per_layer = adaptive_weights.expand( 110 | [ 111 | self.num_weights_dict, 112 | self.num_params_per_weight_dict, 113 | ] 114 | ) 115 | return weights_per_layer 116 | 117 | def get_learnable_params(self): 118 | adaptive_coeff_per_layer = self.get_coeff_per_layer() 119 | output_params = {} 120 | for i, (k, vs) in enumerate(self.original_params.items()): 121 | cs_coeff = adaptive_coeff_per_layer[:, i] 122 | out = vs[0] * cs_coeff[0] 123 | for j, other_v in enumerate(vs[1:]): 124 | v_idx = j + 1 125 | 
out = out + other_v * cs_coeff[v_idx] 126 | output_params[k] = out 127 | return output_params 128 | 129 | def get_mask(self, p): 130 | return p 131 | 132 | def record_state(self, metrics_to_log): 133 | avg_weights = self.adaptive_weights.mean(-1).detach().cpu().numpy() 134 | dict_to_log = { 135 | f"adaptive_weight/mean_across_params_w{i}": w 136 | for i, w in enumerate(avg_weights.tolist()) 137 | } 138 | metrics_to_log.update(**dict_to_log) 139 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.27.2 2 | datasets==2.14.6 3 | einops==0.7.0 4 | evaluate==0.4.1 5 | fasttext==0.9.2 6 | pandas==2.1.2 7 | Pillow==10.1.0 8 | torch==2.3.1 9 | torchvision==0.18.1 10 | tornado==6.4 11 | tqdm==4.66.1 12 | traitlets==5.13.0 13 | transformers==4.43.1 14 | trl==0.8.6 15 | vllm==0.5.3.post1 16 | hydra-core==1.3.2 17 | evalplus @ git+https://github.com/evalplus/evalplus@1895d2f6aa8895044a7cf69defc24bd57695e885 18 | peft 19 | fire 20 | matplotlib 21 | wandb 22 | 23 | -------------------------------------------------------------------------------- /scripts/eval_few_shot.sh: -------------------------------------------------------------------------------- 1 | # !/bin/bash 2 | 3 | # Prerequisites: 4 | # - Ensure reference_params_results path is set in cfgs/base_model/*.yaml 5 | 6 | # Task Selection 7 | TASK="few_shot_math" # Available options: few_shot_arc_challenge, few_shot_humaneval 8 | 9 | # Evaluation Setting 10 | PER_LAYER=true 11 | NORM_COEFFS=false 12 | 13 | # Start evaluation! 14 | CUDA_VISIBLE_DEVICES=0,1 python svd_reinforce_hydra.py \ 15 | base_model@_global_=llama3i8b \ 16 | optimization@_global_=cem \ 17 | policy@_global_=wcomb \ 18 | use_loglikelihood_for_ties=true \ 19 | per_layer=$PER_LAYER \ 20 | task@_global_=$TASK \ 21 | norm_coeffs=$NORM_COEFFS \ 22 | wandb_log=true \ 23 | num_iters=50 -------------------------------------------------------------------------------- /scripts/eval_prompt_based.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Task Selection 4 | TASK="mbpp2" # Available options: mbpp2, math, ai2_arc 5 | 6 | # First Stage Inference: Classification Expert 7 | # Set to 'None' if not using cls expert 8 | CLS_EXPERT_PATH="None" 9 | 10 | # Second Stage: Expert Models 11 | # Replace these paths with your actual model paths 12 | CODE_EXPERT_PATH="your_path_to_code_expert" 13 | MATH_EXPERT_PATH="your_path_to_math_expert" 14 | REASONING_EXPERT_PATH="your_path_to_reasoning_expert" 15 | 16 | # Start evaluation! 
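# The command below runs the two-stage evaluation: the model (optionally with
# the cls expert from CLS_EXPERT_PATH loaded) first classifies each test
# question as code / math / reasoning, then the matching expert's parameters
# are loaded before scoring the routed samples, and the per-expert accuracies
# are combined into a sample-weighted final_test_acc.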
17 | CUDA_VISIBLE_DEVICES=0,1 python svd_reinforce_hydra.py \ 18 | base_model@_global_=llama3i8b \ 19 | task@_global_=$TASK \ 20 | mode@_global_=eval \ 21 | prompt_based_eval=True \ 22 | experts_path_dict.code=$CODE_EXPERT_PATH \ 23 | experts_path_dict.math=$MATH_EXPERT_PATH \ 24 | experts_path_dict.reasoning=$REASONING_EXPERT_PATH \ 25 | load_ckpt=$CLS_EXPERT_PATH -------------------------------------------------------------------------------- /scripts/train_task_expert.sh: -------------------------------------------------------------------------------- 1 | # !/bin/bash 2 | 3 | # Task Selection 4 | TASK="mbpp2" # Available options: mbpp2, gsm8k, ai2_arc, cls 5 | 6 | # Training Setting 7 | NUM_ITERS=200 8 | 9 | # This script needs 2 gpus 10 | CUDA_VISIBLE_DEVICES=0,1 python svd_reinforce_hydra.py \ 11 | base_model@_global_=llama3i8b \ 12 | task@_global_=$TASK \ 13 | mode@_global_=training \ 14 | num_iters=$NUM_ITERS -------------------------------------------------------------------------------- /svd_reinforce_hydra.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import json 3 | import os 4 | from datetime import datetime 5 | from typing import Dict 6 | 7 | import hydra 8 | import numpy as np 9 | import torch 10 | from omegaconf import OmegaConf 11 | from transformers import AutoModelForCausalLM, AutoTokenizer 12 | 13 | from base_model import BaseModel 14 | from logging_utils import Metrics, get_mean_std_max_min_dict 15 | from optim_modules import OptimizationAlgorithm 16 | from policy import Policy 17 | from tasks import Task 18 | from utils import (eval_model, eval_model_experts_prompt_based, forward, 19 | load_hf_params_to_vllm) 20 | 21 | 22 | def wandb_init(cfg, run_name: str, group_name: str, log_dir: str): 23 | import wandb 24 | 25 | config_dict = OmegaConf.to_container( 26 | cfg, 27 | resolve=True, 28 | throw_on_missing=False, 29 | ) 30 | config_dict["log_dir"] = log_dir 31 | config_dict["wandb_run_name"] = run_name 32 | config_dict["wandb_group_name"] = group_name 33 | 34 | # wandb has a 128-size character limit on the group name 35 | wandb.init( 36 | project=cfg.wandb_project, 37 | group=group_name[:127], 38 | name=run_name[:127], 39 | config=config_dict, 40 | ) 41 | return wandb 42 | 43 | 44 | @hydra.main(version_base=None, config_path="cfgs", config_name="config") 45 | def main(cfg): 46 | """Main function.""" 47 | 48 | num_iters = cfg.num_iters 49 | test_interval = cfg.test_interval 50 | 51 | batch_size = cfg.batch_size 52 | seed = cfg.seed 53 | policy_name = cfg.policy_name 54 | test_only = cfg.test_only 55 | save_legacy_params = cfg.save_legacy_params 56 | exp_name = cfg.exp_name 57 | run_name = cfg.run_name 58 | 59 | task_name = cfg.task_name 60 | 61 | load_ckpt = cfg.load_ckpt 62 | use_lora = cfg.use_lora 63 | prompt_based_eval = cfg.prompt_based_eval 64 | experts_path_dict = cfg.experts_path_dict 65 | 66 | resuming_from_ckpt = False 67 | if load_ckpt is not None: 68 | if load_ckpt == "scratch" or load_ckpt == "base": 69 | resuming_from_ckpt = False 70 | else: 71 | resuming_from_ckpt = True 72 | 73 | # Create task 74 | task_loader: Task = hydra.utils.instantiate(cfg.task_loader) 75 | 76 | base_model: BaseModel = hydra.utils.instantiate(cfg.base_model) 77 | 78 | model_id = base_model.get_model_id() 79 | decomposed_param_file = base_model.get_param_file(param_folder_path="") 80 | 81 | extract_svd = cfg.extract_svd or (not os.path.exists(decomposed_param_file)) 82 | 83 | has_training_split = task_loader.has_training_split 
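    # Tasks advertise which splits they provide; MathTask, for instance, has no
    # training split, so it can only be run with test_only=true (the assert
    # below enforces this).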
84 | has_transfer_split = task_loader.has_transfer_split 85 | 86 | if not has_training_split: 87 | assert test_only, "Cannot train on a task with no training split" 88 | 89 | if exp_name is None: 90 | exp_name = "temp" 91 | 92 | metrics_to_log = Metrics() 93 | 94 | # Create log dir. 95 | if run_name is None: 96 | now = datetime.now() 97 | run_name = now.strftime("%Y%m%d-%H%M%S") 98 | if test_only and (not resuming_from_ckpt): 99 | log_dir = f"{cfg.out_dir}/{task_name}/{cfg.base_model_name}_base" 100 | group_name = cfg.base_model_name 101 | else: 102 | log_dir = f"{cfg.out_dir}/{task_name}/{policy_name}/{exp_name}/{run_name}" 103 | group_name = cfg.wandb_group_name 104 | os.makedirs(log_dir, exist_ok=True) 105 | 106 | vllm_model = task_loader.get_vllm_model(model_id=model_id) 107 | 108 | train_eval, *test_evals = task_loader.get_evaluator() 109 | if task_loader.has_transfer_split: 110 | test_eval, transfer_eval = test_evals 111 | else: 112 | test_eval = test_evals[0] 113 | 114 | train_data, train_ix, valid_ix = task_loader.get_train_data() 115 | gpu = torch.device("cuda:1") 116 | np_random = np.random.RandomState(seed) 117 | 118 | # cpu + float32 for initial SVD decomposition 119 | if extract_svd: 120 | model = AutoModelForCausalLM.from_pretrained( 121 | model_id, device_map="cpu", torch_dtype=torch.float32 122 | ) 123 | else: 124 | # Load model and tokenizer. 125 | model = AutoModelForCausalLM.from_pretrained( 126 | model_id, device_map="cuda:1", torch_dtype=torch.bfloat16 127 | ) 128 | tokenizer = AutoTokenizer.from_pretrained(model_id) 129 | base_params = model.state_dict() 130 | 131 | original_model_params = { 132 | k: v.clone().detach().cpu() for k, v in base_params.items() if "mlp" in k 133 | } 134 | 135 | # Load decomposed parameters. 136 | if not os.path.exists(decomposed_param_file): 137 | print("Decomposed params not found. Decomposing...") 138 | decomposed_params = {} 139 | for k, v in base_params.items(): 140 | if "norm" not in k: 141 | print(k) 142 | U, S, V = torch.svd(v) 143 | decomposed_params[f"{k}.U"] = U 144 | decomposed_params[f"{k}.S"] = S 145 | decomposed_params[f"{k}.V"] = V 146 | torch.save(decomposed_params, decomposed_param_file) 147 | print("successfully decomposed model - returning") 148 | return 149 | elif extract_svd: 150 | print(f"ERROR: SVD file already exists at {decomposed_param_file}") 151 | else: 152 | print("Decomposed params found. 
Loading...") 153 | assert not extract_svd 154 | decomposed_params = torch.load(decomposed_param_file) 155 | for k, v in decomposed_params.items(): 156 | decomposed_params[k] = v.to(torch.bfloat16).to(gpu) 157 | 158 | if cfg.wandb_log: 159 | wandb = wandb_init( 160 | cfg=cfg, group_name=group_name, run_name=run_name, log_dir=log_dir 161 | ) 162 | 163 | policy: Policy = hydra.utils.instantiate( 164 | cfg.shakeoff_policy, 165 | base_params=base_params, 166 | decomposed_params=decomposed_params, 167 | gpu=gpu, 168 | ) 169 | 170 | optimization_algorithm: OptimizationAlgorithm = hydra.utils.instantiate( 171 | cfg.optimization_algorithm, 172 | policy=policy, 173 | gpu=gpu, 174 | ) 175 | 176 | if resuming_from_ckpt and os.path.exists(load_ckpt): 177 | print(f"Starting from checkpoint at: {load_ckpt}") 178 | # load the lora weight 179 | if use_lora: 180 | assert os.path.isdir(load_ckpt), "ckpt for lora must be dir to lora adapter" 181 | from peft import PeftModel 182 | 183 | lora_model = PeftModel.from_pretrained(model, load_ckpt) 184 | merged_model = lora_model.merge_and_unload() 185 | new_params = merged_model.state_dict() 186 | # load svd expert 187 | elif "learnable_params" in load_ckpt: 188 | learnable_params = torch.load(load_ckpt) 189 | for k, v in learnable_params.items(): 190 | learnable_params[k] = v.to(gpu) 191 | assert test_only 192 | new_params = forward( 193 | policy, model, base_params, decomposed_params, learnable_params 194 | ) 195 | else: 196 | state_dict = torch.load(load_ckpt, weights_only=True) 197 | policy.load_state_dict(state_dict=state_dict) 198 | if test_only: 199 | learnable_params = policy.get_learnable_params() 200 | new_params = forward( 201 | policy, model, base_params, decomposed_params, learnable_params 202 | ) 203 | load_hf_params_to_vllm(new_params, vllm_model.llm) 204 | else: 205 | print(f"Starting from the base model as load_ckpt=={load_ckpt}") 206 | 207 | model.eval() 208 | 209 | # Prompt-based and cls dispatcher evaluation. 210 | if test_only and prompt_based_eval: 211 | test_data_dict = eval_model_experts_prompt_based( 212 | vllm_model, 213 | test_eval, 214 | experts_path_dict, 215 | policy, 216 | model, 217 | base_params, 218 | decomposed_params, 219 | task_loader.target_metric_test, 220 | ) 221 | test_data_dict["type"] = "test" 222 | # Log the results. 223 | if cfg.wandb_log: 224 | wandb.log(test_data_dict) 225 | with open(f"{log_dir}/eval_results.json", "w") as f: 226 | json.dump(test_data_dict, f, indent=4) 227 | print(f"Test evaluation results: {test_data_dict}") 228 | 229 | # Eval the transfer set if available 230 | if has_transfer_split: 231 | transfer_data_dict = eval_model_experts_prompt_based( 232 | vllm_model, 233 | transfer_eval, 234 | experts_path_dict, 235 | policy, 236 | model, 237 | base_params, 238 | decomposed_params, 239 | task_loader.target_metric_transfer, 240 | ) 241 | transfer_data_dict["type"] = "transfer" 242 | # Log the results. 243 | if cfg.wandb_log: 244 | wandb.log(transfer_data_dict) 245 | with open(f"{log_dir}/eval_results.json", "w") as f: 246 | json.dump(transfer_data_dict, f, indent=4) 247 | print(f"Transfer evaluation results: {transfer_data_dict}") 248 | 249 | return 250 | 251 | # Non-adaptive evaluation on train, val, test set. 
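    # (Reached when test_only is set without prompt_based_eval: the currently
    # loaded parameters are scored directly on every available split and the
    # accuracies are written to eval_results.json, mirroring the prompt-based
    # branch above.)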
252 | if test_only and not prompt_based_eval: 253 | data_dict = {} 254 | details_dict = {} 255 | if has_training_split: 256 | train_res = eval_model(vllm_model, train_eval, train_ix) 257 | valid_res = eval_model(vllm_model, train_eval, valid_ix) 258 | data_dict["train_acc"] = train_res.aggregate_metrics[ 259 | task_loader.target_metric_train 260 | ] 261 | data_dict["valid_acc"] = valid_res.aggregate_metrics[ 262 | task_loader.target_metric_valid 263 | ] 264 | details_dict["train"] = train_res.sample_details 265 | details_dict["valid"] = valid_res.sample_details 266 | test_res = eval_model(vllm_model, test_eval) 267 | data_dict["test_acc"] = test_res.aggregate_metrics[ 268 | task_loader.target_metric_test 269 | ] 270 | details_dict["test"] = test_res.sample_details 271 | if has_transfer_split: 272 | transfer_res = eval_model(vllm_model, transfer_eval) 273 | data_dict["transfer_acc"] = transfer_res.aggregate_metrics[ 274 | task_loader.target_metric_transfer 275 | ] 276 | details_dict["transfer"] = transfer_res.sample_details 277 | if cfg.wandb_log: 278 | wandb.log(data_dict) 279 | with open(f"{log_dir}/eval_results.json", "w") as f: 280 | json.dump(data_dict, f, indent=4) 281 | print(f"Evaluation results: {data_dict}") 282 | return 283 | 284 | learnable_params = policy.get_learnable_params() 285 | for k in learnable_params: 286 | model.get_parameter(k).requires_grad_(True) 287 | 288 | # Training loop. 289 | if batch_size is None: 290 | clipped_batch_size = len(list(train_ix)) 291 | else: 292 | clipped_batch_size = min(batch_size, len(list(train_ix))) 293 | best_val_acc = 0.0 294 | test_at_best = 0.0 295 | transfer_at_best = 0.0 296 | for i in range(num_iters): 297 | 298 | batch_ix = np_random.choice(train_ix, size=clipped_batch_size, replace=False) 299 | 300 | optimization_algorithm.step_optimization( 301 | model_id=model_id, 302 | model=model, 303 | tokenizer=tokenizer, 304 | policy=policy, 305 | task_loader=task_loader, 306 | batch_ix=batch_ix, 307 | train_data=train_data, 308 | train_eval=train_eval, 309 | base_params=base_params, 310 | decomposed_params=decomposed_params, 311 | original_model_params=original_model_params, 312 | metrics_to_log=metrics_to_log, 313 | vllm_model=vllm_model, 314 | ) 315 | 316 | with torch.no_grad(): 317 | lists_to_log = {} 318 | grads = [p.grad for p in policy.trainable_params] 319 | if grads[0] is not None: 320 | grad_mean = [g.mean().item() for g in grads] 321 | grad_mags = [torch.linalg.vector_norm(g).item() for g in grads] 322 | lists_to_log["grad_mean"] = grad_mean 323 | lists_to_log["grad_mags"] = grad_mags 324 | 325 | param_mags = [ 326 | torch.linalg.vector_norm(p).item() for p in policy.trainable_params 327 | ] 328 | lists_to_log["policy_param_mag"] = param_mags 329 | 330 | generated_params_list = list(learnable_params.values()) 331 | 332 | generated_param_mean = [p.mean().item() for p in generated_params_list] 333 | generated_param_mags = [ 334 | torch.linalg.vector_norm(p).item() for p in generated_params_list 335 | ] 336 | lists_to_log["generated_param_mean"] = generated_param_mean 337 | lists_to_log["generated_param_mags"] = generated_param_mags 338 | 339 | list_stats = {} 340 | for k, v in lists_to_log.items(): 341 | list_stats.update(get_mean_std_max_min_dict(array=v, prefix=k)) 342 | metrics_to_log.update(**list_stats) 343 | 344 | optimization_algorithm.update(policy=policy) 345 | 346 | # Make sure old params are deleted and garbage-collected 347 | gc.collect() 348 | torch.cuda.empty_cache() 349 | model.zero_grad() 350 | 351 | # More accurate 
logging. 352 | value_mean = list_stats.get("generated_param_mean/mean", None) 353 | grad_mean_mag = list_stats.get("grad_mags/mean", None) 354 | print( 355 | f"Iter {i}: " 356 | + f"param_mean={value_mean}, " 357 | + f"grad_mean_mag={grad_mean_mag}" 358 | ) 359 | optimization_algorithm.log_optim(metrics_to_log=metrics_to_log) 360 | 361 | # Test and save. 362 | if i % test_interval == 0: 363 | learnable_params = policy.get_learnable_params() 364 | forward(policy, model, base_params, decomposed_params, learnable_params) 365 | load_hf_params_to_vllm(model.state_dict(), vllm_model.llm) 366 | 367 | train_res = eval_model(vllm_model, train_eval, train_ix) 368 | valid_res = eval_model(vllm_model, train_eval, valid_ix) 369 | test_res = eval_model(vllm_model, test_eval) 370 | if has_transfer_split: 371 | transfer_res = eval_model(vllm_model, transfer_eval) 372 | if ( 373 | valid_res.aggregate_metrics[task_loader.target_metric_valid] 374 | > best_val_acc 375 | ): 376 | best_val_acc = valid_res.aggregate_metrics[ 377 | task_loader.target_metric_valid 378 | ] 379 | test_at_best = test_res.aggregate_metrics[ 380 | task_loader.target_metric_test 381 | ] 382 | if has_transfer_split: 383 | transfer_at_best = transfer_res.aggregate_metrics[ 384 | task_loader.target_metric_transfer 385 | ] 386 | print("best_val_acc updated") 387 | path = f"{log_dir}/policy_params.pt" 388 | torch.save(policy.state_dict(), path) 389 | if save_legacy_params: 390 | torch.save(learnable_params, f"{log_dir}/learnable_params.pt") 391 | 392 | path = f"{log_dir}/policy_params_latest.pt" 393 | torch.save(policy.state_dict(), path) 394 | if save_legacy_params: 395 | torch.save(learnable_params, f"{log_dir}/learnable_params_latest.pt") 396 | 397 | policy.record_state(metrics_to_log=metrics_to_log) 398 | data_dict = { 399 | "iter": i, 400 | "best_val_acc": best_val_acc, 401 | "test_at_best_val": test_at_best, 402 | "train_acc": train_res.aggregate_metrics[ 403 | task_loader.target_metric_train 404 | ], 405 | "valid_acc": valid_res.aggregate_metrics[ 406 | task_loader.target_metric_valid 407 | ], 408 | "test_acc": test_res.aggregate_metrics[task_loader.target_metric_test], 409 | **metrics_to_log.get(), 410 | } 411 | if has_transfer_split: 412 | data_dict["transfer_acc"] = transfer_res.aggregate_metrics[ 413 | task_loader.target_metric_transfer 414 | ] 415 | data_dict["transfer_at_best_val"] = transfer_at_best 416 | if cfg.wandb_log: 417 | wandb.log(data_dict) 418 | with open(f"{log_dir}/reinforce_log.json", "a") as f: 419 | json_data = json.dumps(data_dict, indent=4) 420 | f.write(json_data) 421 | f.write("\n") 422 | metrics_to_log.reset() 423 | 424 | 425 | if __name__ == "__main__": 426 | main() 427 | -------------------------------------------------------------------------------- /tasks/__init__.py: -------------------------------------------------------------------------------- 1 | from .arc import AI2ArcTask 2 | from .base import FewShotTask, Task 3 | from .cls import ClsTask 4 | from .gsm8k import Gsm8kTask 5 | from .math import MathTask 6 | from .mbpp2 import Mbpp2Task 7 | -------------------------------------------------------------------------------- /tasks/arc.py: -------------------------------------------------------------------------------- 1 | import fishfarm 2 | import vllm 3 | from datasets import load_dataset 4 | from fishfarm.models.vllm_model import VLLMModel 5 | from fishfarm.tasks.ai2_arc import Ai2ArcSample, Ai2ArcTask 6 | 7 | from .base import LLAMA3_COT, Task, get_download_dir 8 | 9 | choices = ["A", "B", "C", "D", 
"E"] 10 | 11 | 12 | class AI2ArcTask(Task): 13 | def __init__( 14 | self, 15 | ): 16 | self.model_to_template = { 17 | "meta-llama/Meta-Llama-3-8B-Instruct": LLAMA3_COT, 18 | "mistralai/Mistral-7B-Instruct-v0.3": None, 19 | } 20 | self.system_msg = ( 21 | "The following are multiple choice questions (with answers). " 22 | "Think step by step and then finish your answer " 23 | 'with "the answer is (X)" where X is the correct letter choice.' 24 | ) 25 | self.target_metric_train = "acc" 26 | self.target_metric_valid = self.target_metric_train 27 | self.target_metric_test = self.target_metric_train 28 | self.target_metric_transfer = self.target_metric_train 29 | self.has_transfer_split = True 30 | self.has_training_split = True 31 | 32 | def get_train_data( 33 | self, 34 | ): 35 | train_eval, *test_evals = self.get_evaluator() 36 | train_data = train_eval.samples 37 | train_size = len(train_data) 38 | train_ix = range(0, train_size, 2) 39 | valid_ix = range(1, train_size, 2) 40 | return train_data, train_ix, valid_ix 41 | 42 | def get_rewards(self, res): 43 | rewards = [1.0 if x["correct"] else -1.0 for x in res.sample_details] 44 | return rewards 45 | 46 | def get_evaluator( 47 | self, 48 | ): 49 | res = [] 50 | for split in ["train", "test"]: 51 | dataset = load_dataset("allenai/ai2_arc", "ARC-Easy", split=split) 52 | samples = [] 53 | for sample in dataset: 54 | options = [] 55 | for opt in sample["choices"]["text"]: 56 | options.append(opt) 57 | # add options to the question 58 | question = sample["question"] + "\n" 59 | question += "Options:\n" 60 | for i, opt in enumerate(options): 61 | question += "{}. {}\n".format(choices[i], opt) 62 | samples.append( 63 | Ai2ArcSample( 64 | question=question, 65 | answer=sample["answerKey"], 66 | options=options, 67 | question_id=sample["id"], 68 | ) 69 | ) 70 | res.append( 71 | Ai2ArcTask( 72 | samples=samples, 73 | context_messages=[ 74 | fishfarm.Message("system", self.system_msg), 75 | ], 76 | ) 77 | ) 78 | dataset = load_dataset("allenai/ai2_arc", "ARC-Challenge", split="test") 79 | samples = [] 80 | for sample in dataset: 81 | options = [] 82 | for opt in sample["choices"]["text"]: 83 | options.append(opt) 84 | # add options to the question 85 | question = sample["question"] + "\n" 86 | question += "Options:\n" 87 | for i, opt in enumerate(options): 88 | question += "{}. 
{}\n".format(choices[i], opt) 89 | samples.append( 90 | Ai2ArcSample( 91 | question=question, 92 | answer=sample["answerKey"], 93 | options=options, 94 | question_id=sample["id"], 95 | ) 96 | ) 97 | res.append( 98 | Ai2ArcTask( 99 | samples=samples, 100 | context_messages=[ 101 | fishfarm.Message("system", self.system_msg), 102 | ], 103 | ) 104 | ) 105 | return tuple(res) 106 | 107 | def get_prompt(self, tokenizer, samples, ix, model_id): 108 | chat_template = self.model_to_template[model_id] 109 | context_msg = {"role": "system", "content": self.system_msg} 110 | user_msg = {"role": "user", "content": samples[ix].question} 111 | prompt = tokenizer.apply_chat_template( 112 | conversation=[context_msg, user_msg], 113 | chat_template=chat_template, 114 | tokenize=False, 115 | add_generation_prompt=True, 116 | ) 117 | return prompt 118 | 119 | def get_vllm_model(self, model_id) -> VLLMModel: 120 | """Load a vLLM model.""" 121 | model = vllm.LLM( 122 | model_id, 123 | max_model_len=1024, 124 | gpu_memory_utilization=0.8, 125 | enforce_eager=True, 126 | dtype="bfloat16", 127 | download_dir=get_download_dir(), 128 | ) 129 | chat_template = self.model_to_template[model_id] 130 | # This may change with vLLM versions. 131 | m = model.llm_engine.model_executor.driver_worker.model_runner.model 132 | for _, param in m.named_parameters(): 133 | param.requires_grad = False 134 | vllm_model = VLLMModel( 135 | model, 136 | sampling_params=vllm.SamplingParams( 137 | temperature=0, 138 | top_p=1, 139 | max_tokens=512, 140 | stop=["Instruction:", "Instruction", "Response:", "Response"], 141 | repetition_penalty=1.0, 142 | ), 143 | chat_template=chat_template, 144 | ) 145 | return vllm_model 146 | -------------------------------------------------------------------------------- /tasks/base.py: -------------------------------------------------------------------------------- 1 | import os 2 | from abc import ABC, abstractmethod 3 | from typing import Union 4 | 5 | import hydra 6 | from omegaconf import DictConfig 7 | 8 | LLAMA3_COT = ( 9 | "{% set loop_messages = messages %}" 10 | "{% for message in loop_messages %}" 11 | "{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>" 12 | "\n\n'+ message['content'] | trim + '<|eot_id|>' %}" 13 | "{% if loop.index0 == 0 %}{% set content = bos_token + content %}" 14 | "{% endif %}" 15 | "{{ content }}" 16 | "{% endfor %}" 17 | "{% if add_generation_prompt %}" 18 | "{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' + 'Let\\'s think step by step' }}" 19 | "{% endif %}" 20 | ) 21 | 22 | 23 | CODE_PROMPT = r""" 24 | {% if messages[0]['role'] == 'system' %} 25 | {% set loop_messages = messages[1:] %} 26 | {% set system_message = messages[0]['content'].strip() + '\n\n' %} 27 | {% else %} 28 | {% set loop_messages = messages %} 29 | {% set system_message = '' %} 30 | {% endif %} 31 | 32 | {{ system_message }} 33 | {% for message in loop_messages %} 34 | {% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %} 35 | {{ raise_exception( 36 | 'Conversation roles must alternate user/assistant/user/assistant/...')}} 37 | {% endif %} 38 | 39 | {% if message['role'] == 'user' %} 40 | {{ '@@ Instruction:\n' + message['content'].strip() + '\n\n' }} 41 | {% elif message['role'] == 'assistant' %} 42 | {{ '@@ Response:\n' + message['content'].strip() }} 43 | {% endif %} 44 | 45 | {% if loop.last and message['role'] == 'user' and add_generation_prompt %} 46 | {{ '@@ Response:' }} 47 | {% endif %} 48 | {% endfor %} 49 | """.replace( 50 | " ", "" 51 | 
).replace( 52 | "\n", "" 53 | ) 54 | 55 | 56 | def get_download_dir(): 57 | if "HF_HOME" in os.environ: 58 | return os.environ["HF_HOME"] + "/models" 59 | else: 60 | return os.path.expanduser("~") + "/.cache/huggingface/models" 61 | 62 | 63 | class Task(ABC): 64 | def __init__( 65 | self, 66 | ): 67 | self.model_to_template = {} 68 | self.system_msg = () 69 | self.target_metric_train = None 70 | self.target_metric_valid = self.target_metric_train 71 | self.target_metric_test = self.target_metric_train 72 | self.target_metric_transfer = None 73 | self.has_transfer_split = True 74 | self.has_training_split = True 75 | 76 | @abstractmethod 77 | def get_train_data( 78 | self, 79 | ): 80 | raise NotImplementedError 81 | 82 | @abstractmethod 83 | def get_rewards(self, res): 84 | raise NotImplementedError 85 | 86 | @abstractmethod 87 | def get_evaluator( 88 | self, 89 | ): 90 | raise NotImplementedError 91 | 92 | @abstractmethod 93 | def get_prompt(self, tokenizer, samples, ix, model_id): 94 | raise NotImplementedError 95 | 96 | @abstractmethod 97 | def get_vllm_model(self, model_id): 98 | raise NotImplementedError 99 | 100 | 101 | class FewShotTask(Task): 102 | def __init__( 103 | self, 104 | wrapped_task: Union[Task, DictConfig], 105 | wrapped_split: str = "test", 106 | shots=5, 107 | seed=16, 108 | ): 109 | if isinstance(wrapped_task, Task): 110 | self.wrapped_task: Task = wrapped_task 111 | else: 112 | self.wrapped_task: Task = hydra.utils.instantiate(wrapped_task) 113 | 114 | self.wrapped_split = wrapped_split 115 | self.shots = shots 116 | self.seed = seed 117 | self.model_to_template = wrapped_task.model_to_template 118 | self.system_msg = wrapped_task.system_msg 119 | if wrapped_split == "train": 120 | self.target_metric_train = wrapped_task.target_metric_train 121 | self.target_metric_valid = wrapped_task.target_metric_train 122 | self.target_metric_test = wrapped_task.target_metric_train 123 | assert wrapped_task.has_training_split 124 | elif wrapped_split == "test": 125 | self.target_metric_train = wrapped_task.target_metric_test 126 | self.target_metric_valid = wrapped_task.target_metric_test 127 | self.target_metric_test = wrapped_task.target_metric_test 128 | elif wrapped_split == "transfer": 129 | self.target_metric_train = wrapped_task.target_metric_transfer 130 | self.target_metric_valid = wrapped_task.target_metric_transfer 131 | self.target_metric_test = wrapped_task.target_metric_transfer 132 | assert wrapped_task.has_transfer_split 133 | else: 134 | raise NotImplementedError 135 | self.target_metric_transfer = wrapped_task.target_metric_transfer 136 | self.has_transfer_split = False 137 | self.has_training_split = True 138 | 139 | def get_train_data( 140 | self, 141 | ): 142 | train_eval, *test_evals = self.get_evaluator() 143 | train_data = train_eval.samples 144 | train_size = len(train_data) 145 | total_ix = list(range(train_size)) 146 | import random 147 | 148 | random.seed(self.seed) # fix random seed for reproducibility 149 | random.shuffle(total_ix) 150 | train_ix = total_ix[: self.shots] 151 | valid_ix = total_ix[self.shots :] 152 | return train_data, train_ix, valid_ix 153 | 154 | def get_rewards(self, res): 155 | return self.wrapped_task.get_rewards(res=res) 156 | 157 | def get_evaluator( 158 | self, 159 | ): 160 | evaluators = self.wrapped_task.get_evaluator() 161 | if self.wrapped_split == "train": 162 | target_eval = evaluators[0] 163 | elif self.wrapped_split == "test": 164 | target_eval = evaluators[1] 165 | elif self.wrapped_split == "transfer": 166 | 
target_eval = evaluators[2] 167 | return target_eval, target_eval 168 | 169 | def get_prompt(self, tokenizer, samples, ix, model_id): 170 | return self.wrapped_task.get_prompt( 171 | tokenizer=tokenizer, 172 | samples=samples, 173 | ix=ix, 174 | model_id=model_id, 175 | ) 176 | 177 | def get_vllm_model(self, model_id): 178 | return self.wrapped_task.get_vllm_model( 179 | model_id=model_id, 180 | ) 181 | -------------------------------------------------------------------------------- /tasks/cls.py: -------------------------------------------------------------------------------- 1 | import re 2 | from dataclasses import dataclass 3 | from typing import Iterable, Tuple 4 | 5 | import datasets 6 | import fishfarm 7 | import vllm 8 | from fishfarm.models.vllm_model import VLLMModel 9 | from fishfarm.tasks.base import TaskResult 10 | from fishfarm.tasks.evalplus import load_dataset 11 | 12 | from .base import Task, get_download_dir 13 | 14 | 15 | def mean(iterable: Iterable[float]) -> float: 16 | total, count = 0.0, 0 17 | for x in iterable: 18 | total += x 19 | count += 1 20 | return total / count 21 | 22 | 23 | def extract_ans(text): 24 | """Fetch the string within \\boxed{}.""" 25 | match = re.search(r"\\boxed{([^}]*)}", text) 26 | if match: 27 | return match.group(1) # Return the content inside the \boxed{} 28 | else: 29 | return None # Return None if no match is found 30 | 31 | 32 | @dataclass 33 | class CategorySample: 34 | question: str 35 | label: str 36 | 37 | 38 | class CategoryClassficiationTask(fishfarm.tasks.base.Task): 39 | def __init__( 40 | self, 41 | samples, 42 | context_messages, 43 | ): 44 | self.samples = list(samples) 45 | self.context_messages = context_messages 46 | 47 | @property 48 | def num_samples(self) -> int: 49 | return len(self.samples) 50 | 51 | def evaluate( 52 | self, 53 | model, 54 | sample_ids, 55 | ): 56 | if sample_ids is None: 57 | sample_ids = range(len(self.samples)) 58 | samples = [self.samples[sample_id] for sample_id in sample_ids] 59 | 60 | requests = [] 61 | for sample in samples: 62 | messages = list(self.context_messages) 63 | messages.append(fishfarm.Message(role="user", content=sample.question)) 64 | requests.append(fishfarm.models.GenerationRequest(messages=messages)) 65 | 66 | sample_details = [] 67 | for sample, result in zip(samples, model.generate(requests)): 68 | output = result.generation 69 | prediction = extract_ans(output) 70 | 71 | sample_details.append( 72 | dict( 73 | question=sample.question, 74 | label=sample.label, 75 | output=output, 76 | prediction=prediction, 77 | correct=sample.label == prediction, 78 | ) 79 | ) 80 | 81 | aggregate_metrics = { 82 | "acc": mean( 83 | float(sd["correct"]) if isinstance(sd["correct"], (bool)) else 0.0 84 | for sd in sample_details 85 | ) 86 | } 87 | return TaskResult( 88 | aggregate_metrics=aggregate_metrics, sample_details=sample_details 89 | ) 90 | 91 | 92 | class ClsTask(Task): 93 | def __init__(self): 94 | self.model_to_template = { 95 | "meta-llama/Meta-Llama-3-8B-Instruct": ( 96 | "{% set loop_messages = messages %}" 97 | "{% for message in loop_messages %}" 98 | "{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>" 99 | "\n\n'+ message['content'] | trim + '<|eot_id|>' %}" 100 | "{% if loop.index0 == 0 %}{% set content = bos_token + content %}" 101 | "{% endif %}" 102 | "{{ content }}" 103 | "{% endfor %}" 104 | "{% if add_generation_prompt %}" 105 | "{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}" 106 | "{% endif %}" 107 | ), 108 | 
"mistralai/Mistral-7B-Instruct-v0.3": None, 109 | } 110 | self.system_msg = """ 111 | # Analyze the given question and classify it into one of four categories: 'code', 'math', 'reasoning' or 'other'. Follow these guidelines: 112 | 113 | 1. Code: Questions asking for programming solutions, functions, algorithms. Often includes specific programming terms, language syntax, or data structures. 114 | 2. Math: Questions involving mathematical calculations, formulas, statistics. Often includes numbers, equations, or mathematical operations. 115 | 3. Reasoning: Questions requiring logical thinking, application of scientific knowledge, or critical analysis of information. Often presents statements that need evaluation based on general understanding. 116 | 4. Other: Questions not clearly fit into above categories. 117 | 118 | Instructions: 119 | - Consider the primary focus, skills, and knowledge required to answer the question. 120 | - If a question spans multiple categories, choose the most dominant one. 121 | - Provide your final classification within \\boxed{} notation. Example: \\boxed{reasoning} 122 | 123 | Format your response as follows: 124 | Classification: \\boxed{category} 125 | """ 126 | self.target_metric_train = "acc" 127 | self.target_metric_valid = self.target_metric_train 128 | self.target_metric_test = self.target_metric_train 129 | self.target_metric_transfer = self.target_metric_train 130 | self.has_transfer_split = False 131 | self.has_training_split = True 132 | self.num_samples_per_task = 400 # Hard code 400 samples per task 133 | self.task_datasets = [ 134 | datasets.load_dataset("gsm8k", "main", split="test"), 135 | load_dataset(source_dataset="mbpp"), 136 | datasets.load_dataset("allenai/ai2_arc", "ARC-Challenge", split="test"), 137 | ] 138 | self.train_samples, self.test_samples = self.build_samples() 139 | 140 | def split_samples(self, samples): 141 | """Split samples into train and test sets with rate 4 : 1.""" 142 | train_samples = [] 143 | test_samples = [] 144 | for i, sample in enumerate(samples): 145 | if i % 5 < 4: 146 | train_samples.append(sample) 147 | else: 148 | test_samples.append(sample) 149 | return train_samples, test_samples 150 | 151 | def get_train_data(self=400): 152 | train_ix = range(0, len(self.train_samples), 2) 153 | valid_ix = range(1, len(self.train_samples), 2) 154 | 155 | return self.train_samples, train_ix, valid_ix 156 | 157 | def build_samples(self): 158 | task_labels = ["math", "code", "reasoning"] 159 | 160 | samples = [] 161 | choices = ["A", "B", "C", "D", "E"] 162 | for dataset, label in zip(self.task_datasets, task_labels): 163 | counter = 0 164 | for sample in dataset: 165 | counter += 1 166 | if counter >= self.num_samples_per_task: 167 | break 168 | if label == "math": 169 | samples.append( 170 | CategorySample( 171 | question=sample["question"], 172 | label="math", 173 | ) 174 | ) 175 | elif label == "code": 176 | samples.append( 177 | CategorySample( 178 | question=sample.instruction, 179 | label="code", 180 | ) 181 | ) 182 | else: # reasoning 183 | question = sample["question"] + "\n" 184 | question += "Options:\n" 185 | options = [] 186 | for opt in sample["choices"]["text"]: 187 | options.append(opt) 188 | for i, opt in enumerate(options): 189 | question += "{}. 
{}\n".format(choices[i], opt) 190 | samples.append( 191 | CategorySample( 192 | question=question, 193 | label="reasoning", 194 | ) 195 | ) 196 | 197 | return self.split_samples(samples) 198 | 199 | def get_rewards(self, res): 200 | rewards = [1.0 if x["correct"] else -1.0 for x in res.sample_details] 201 | return rewards 202 | 203 | def get_evaluator(self) -> Tuple: 204 | # Build cls dataset here with training tasks. 205 | res = [] 206 | for samples in [self.train_samples, self.test_samples]: 207 | res.append( 208 | CategoryClassficiationTask( 209 | samples=samples, 210 | context_messages=[ 211 | fishfarm.Message("system", self.system_msg), 212 | ], 213 | ) 214 | ) 215 | 216 | return tuple(res) 217 | 218 | def get_prompt(self, tokenizer, samples, ix, model_id): 219 | chat_template = self.model_to_template[model_id] 220 | context_msg = {"role": "system", "content": self.system_msg} 221 | user_msg = {"role": "user", "content": samples[ix].question} 222 | prompt = tokenizer.apply_chat_template( 223 | conversation=[context_msg, user_msg], 224 | chat_template=chat_template, 225 | tokenize=False, 226 | add_generation_prompt=True, 227 | ) 228 | return prompt 229 | 230 | def get_vllm_model(self, model_id) -> VLLMModel: 231 | """Load a vLLM model.""" 232 | model = vllm.LLM( 233 | model_id, 234 | max_model_len=2048, 235 | gpu_memory_utilization=0.8, 236 | enforce_eager=True, 237 | dtype="bfloat16", 238 | download_dir=get_download_dir(), 239 | ) 240 | chat_template = self.model_to_template[model_id] 241 | # This may change with vLLM versions. 242 | m = model.llm_engine.model_executor.driver_worker.model_runner.model 243 | for _, param in m.named_parameters(): 244 | param.requires_grad = False 245 | vllm_model = VLLMModel( 246 | model, 247 | sampling_params=vllm.SamplingParams( 248 | temperature=0, 249 | top_p=1, 250 | max_tokens=1024, 251 | stop=["Instruction:", "Instruction", "Response:", "Response"], 252 | repetition_penalty=1.0, 253 | ), 254 | chat_template=chat_template, 255 | ) 256 | return vllm_model 257 | -------------------------------------------------------------------------------- /tasks/gsm8k.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import fishfarm 4 | import vllm 5 | from datasets import load_dataset 6 | from fishfarm.models.vllm_model import VLLMModel 7 | from fishfarm.tasks.language_restricted_math import ( 8 | LanguageRestrictedMathTask, MathSample, extract_answer_number) 9 | 10 | from .base import Task, get_download_dir 11 | 12 | 13 | class Gsm8kTask(Task): 14 | def __init__( 15 | self, 16 | ): 17 | self.model_to_template = { 18 | "meta-llama/Meta-Llama-3-8B-Instruct": ( 19 | "{% set loop_messages = messages %}" 20 | "{% for message in loop_messages %}" 21 | "{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>" 22 | "\n\n'+ message['content'] | trim + '<|eot_id|>' %}" 23 | "{% if loop.index0 == 0 %}{% set content = bos_token + content %}" 24 | "{% endif %}" 25 | "{{ content }}" 26 | "{% endfor %}" 27 | "{% if add_generation_prompt %}" 28 | "{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}" 29 | "{% endif %}" 30 | ), 31 | "mistralai/Mistral-7B-Instruct-v0.3": None, 32 | } 33 | self.system_msg = ( 34 | "Below is an instruction that describes a task." 
35 | " Write a response that appropriately completes the request.\n\n" 36 | ) 37 | 38 | self.target_metric_train = "acc" 39 | self.target_metric_valid = self.target_metric_train 40 | self.target_metric_test = self.target_metric_train 41 | self.target_metric_transfer = self.target_metric_train 42 | self.has_transfer_split = False 43 | self.has_training_split = True 44 | 45 | def get_train_data( 46 | self, 47 | ): 48 | train_data = load_dataset("gsm8k", "main", split="train") 49 | train_size = len(train_data) 50 | train_ix = range(0, train_size, 2) 51 | valid_ix = range(1, train_size, 2) 52 | return train_data, train_ix, valid_ix 53 | 54 | def get_rewards(self, res): 55 | rewards = [1.0 if x["correct"] else -1.0 for x in res.sample_details] 56 | return rewards 57 | 58 | def get_evaluator(self) -> Tuple: 59 | res = [] 60 | for split in ["train", "test"]: 61 | dataset = load_dataset("gsm8k", "main", split=split) 62 | samples = [] 63 | for sample in dataset: 64 | answer = sample["answer"] 65 | answer = extract_answer_number(answer) 66 | answer = int(answer) if answer is not None else None 67 | samples.append( 68 | MathSample( 69 | problem=sample["question"], 70 | answer=answer, 71 | ) 72 | ) 73 | res.append( 74 | LanguageRestrictedMathTask( 75 | samples=samples, 76 | context_messages=[ 77 | fishfarm.Message("system", self.system_msg), 78 | ], 79 | languages=[], 80 | ) 81 | ) 82 | return tuple(res) 83 | 84 | def get_prompt(self, tokenizer, samples, ix, model_id): 85 | chat_template = self.model_to_template[model_id] 86 | context_msg = {"role": "system", "content": self.system_msg} 87 | user_msg = {"role": "user", "content": samples["question"][ix]} 88 | prompt = tokenizer.apply_chat_template( 89 | conversation=[context_msg, user_msg], 90 | chat_template=chat_template, 91 | tokenize=False, 92 | add_generation_prompt=True, 93 | ) 94 | return prompt 95 | 96 | def get_vllm_model(self, model_id) -> VLLMModel: 97 | """Load a vLLM model.""" 98 | model = vllm.LLM( 99 | model_id, 100 | max_model_len=1024, 101 | gpu_memory_utilization=0.8, 102 | enforce_eager=True, 103 | dtype="bfloat16", 104 | download_dir=get_download_dir(), 105 | ) 106 | chat_template = self.model_to_template[model_id] 107 | # This may change with vLLM versions. 
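        # Reaching into vLLM internals exposes the underlying torch module so
        # that utils.load_hf_params_to_vllm can overwrite its weights in place
        # after each policy update; requires_grad is disabled because this copy
        # is only ever written to, never trained.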
108 | m = model.llm_engine.model_executor.driver_worker.model_runner.model 109 | for _, param in m.named_parameters(): 110 | param.requires_grad = False 111 | vllm_model = VLLMModel( 112 | model, 113 | sampling_params=vllm.SamplingParams( 114 | temperature=0, 115 | top_p=1, 116 | max_tokens=512, 117 | stop=["Instruction:", "Instruction", "Response:", "Response"], 118 | repetition_penalty=1.0, 119 | ), 120 | chat_template=chat_template, 121 | ) 122 | return vllm_model 123 | -------------------------------------------------------------------------------- /tasks/math.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import fishfarm 4 | import vllm 5 | from datasets import load_dataset 6 | from fishfarm.models.vllm_model import VLLMModel 7 | from fishfarm.tasks.competation_math import (LatexFormatMathTask, MathSample, 8 | last_boxed_only_string, 9 | remove_boxed) 10 | 11 | from .base import Task, get_download_dir 12 | 13 | 14 | class MathTask(Task): 15 | def __init__(self): 16 | self.model_to_template = { 17 | "meta-llama/Meta-Llama-3-8B-Instruct": ( 18 | "{% set loop_messages = messages %}" 19 | "{% for message in loop_messages %}" 20 | "{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>" 21 | "\n\n'+ message['content'] | trim + '<|eot_id|>' %}" 22 | "{% if loop.index0 == 0 %}{% set content = bos_token + content %}" 23 | "{% endif %}" 24 | "{{ content }}" 25 | "{% endfor %}" 26 | "{% if add_generation_prompt %}" 27 | "{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}" 28 | "{% endif %}" 29 | ), 30 | "mistralai/Mistral-7B-Instruct-v0.3": None, 31 | } 32 | self.system_msg = ( 33 | "Solve the question below by reasoning step by step," 34 | "and put the final answer within \\boxed{}." 
35 | ) 36 | 37 | self.target_metric_train = "acc" 38 | self.target_metric_valid = self.target_metric_train 39 | self.target_metric_test = self.target_metric_train 40 | self.target_metric_transfer = self.target_metric_train 41 | self.has_transfer_split = False 42 | self.has_training_split = False 43 | 44 | def get_train_data( 45 | self, 46 | ): 47 | return None, None, None 48 | 49 | def get_rewards(self, res): 50 | rewards = [1.0 if x["correct"] else -1.0 for x in res.sample_details] 51 | return rewards 52 | 53 | def get_evaluator( 54 | self, 55 | ) -> Tuple: 56 | res = [None] 57 | dataset = load_dataset("hendrycks/competition_math", "main", split="test") 58 | 59 | samples = [] 60 | for sample in dataset: 61 | answer = remove_boxed(last_boxed_only_string((sample["solution"]))) 62 | samples.append( 63 | MathSample( 64 | problem=sample["problem"], answer=answer, type=sample["type"] 65 | ) 66 | ) 67 | 68 | test_eval = LatexFormatMathTask( 69 | samples=samples, 70 | context_messages=[ 71 | fishfarm.Message("system", self.system_msg), 72 | ], 73 | ) 74 | res.append(test_eval) 75 | return tuple(res) 76 | 77 | def get_prompt(self, tokenizer, samples, ix, model_id): 78 | chat_template = self.model_to_template[model_id] 79 | context_msg = {"role": "system", "content": self.system_msg} 80 | user_msg = {"role": "user", "content": samples[ix].problem} 81 | prompt = tokenizer.apply_chat_template( 82 | conversation=[context_msg, user_msg], 83 | chat_template=chat_template, 84 | tokenize=False, 85 | add_generation_prompt=True, 86 | ) 87 | return prompt 88 | 89 | def get_vllm_model(self, model_id) -> VLLMModel: 90 | """Load a vLLM model.""" 91 | model = vllm.LLM( 92 | model_id, 93 | max_model_len=2048, 94 | gpu_memory_utilization=0.8, 95 | enforce_eager=True, 96 | dtype="bfloat16", 97 | download_dir=get_download_dir(), 98 | ) 99 | chat_template = self.model_to_template[model_id] 100 | # This may change with vLLM versions. 101 | m = model.llm_engine.model_executor.driver_worker.model_runner.model 102 | for _, param in m.named_parameters(): 103 | param.requires_grad = False 104 | vllm_model = VLLMModel( 105 | model, 106 | sampling_params=vllm.SamplingParams( 107 | temperature=0, 108 | top_p=1, 109 | max_tokens=1024, 110 | stop=["Instruction:", "Instruction", "Response:", "Response"], 111 | repetition_penalty=1.0, 112 | ), 113 | chat_template=chat_template, 114 | ) 115 | return vllm_model 116 | -------------------------------------------------------------------------------- /tasks/mbpp2.py: -------------------------------------------------------------------------------- 1 | import fishfarm 2 | import vllm 3 | from fishfarm.models.vllm_model import VLLMModel 4 | from fishfarm.tasks.evalplus import EvalplusTask, load_dataset 5 | 6 | from .base import CODE_PROMPT, Task, get_download_dir 7 | 8 | 9 | class Mbpp2Task(Task): 10 | def __init__( 11 | self, 12 | ): 13 | self.model_to_template = { 14 | "meta-llama/Meta-Llama-3-8B-Instruct": CODE_PROMPT, 15 | "mistralai/Mistral-7B-Instruct-v0.3": CODE_PROMPT, 16 | } 17 | self.system_msg = ( 18 | "You are an exceptionally intelligent coding assistant that " 19 | " consistently delivers accurate and reliable responses to user " 20 | "instructions." 
21 | ) 22 | 23 | self.target_metric_train = "mbpp_base_pass@1" 24 | self.target_metric_valid = self.target_metric_train 25 | self.target_metric_test = self.target_metric_train 26 | self.target_metric_transfer = "humaneval_base_pass@1" 27 | self.has_transfer_split = True 28 | self.has_training_split = True 29 | 30 | def get_train_data( 31 | self, 32 | ): 33 | train_eval, *test_evals = self.get_evaluator() 34 | train_data = train_eval.samples 35 | train_size = len(train_data) 36 | total_ix = list(range(train_size)) 37 | import random 38 | 39 | random.seed(16) # fix random seed for reproducibility 40 | random.shuffle(total_ix) 41 | train_ix = total_ix[:200] 42 | valid_ix = total_ix[200:] 43 | return train_data, train_ix, valid_ix 44 | 45 | def get_rewards(self, res): 46 | rewards = [1.0 if x["base_correct"] == 1 else -1.0 for x in res.sample_details] 47 | return rewards 48 | 49 | def get_evaluator( 50 | self, 51 | ): 52 | res = [] 53 | samples = load_dataset(source_dataset="mbpp") 54 | res.append( 55 | EvalplusTask( 56 | samples[:300], 57 | context_messages=[ 58 | fishfarm.Message("system", self.system_msg), 59 | ], 60 | source_dataset="mbpp", 61 | ) 62 | ) 63 | res.append( 64 | EvalplusTask( 65 | samples[300:], 66 | context_messages=[ 67 | fishfarm.Message("system", self.system_msg), 68 | ], 69 | source_dataset="mbpp", 70 | ) 71 | ) 72 | samples = load_dataset(source_dataset="humaneval") 73 | res.append( 74 | EvalplusTask( 75 | samples, 76 | context_messages=[ 77 | fishfarm.Message("system", self.system_msg), 78 | ], 79 | source_dataset="humaneval", 80 | ) 81 | ) 82 | return tuple(res) 83 | 84 | def get_prompt(self, tokenizer, samples, ix, model_id): 85 | chat_template = self.model_to_template[model_id] 86 | context_msg = {"role": "system", "content": self.system_msg} 87 | user_msg = {"role": "user", "content": samples[ix].instruction} 88 | assistant_msg = {"role": "assistant", "content": samples[ix].response_prefix} 89 | return tokenizer.apply_chat_template( 90 | conversation=[context_msg, user_msg, assistant_msg], 91 | chat_template=chat_template, 92 | tokenize=False, 93 | add_generation_prompt=True, 94 | ) 95 | 96 | def get_vllm_model(self, model_id) -> VLLMModel: 97 | """Load a vLLM model.""" 98 | model = vllm.LLM( 99 | model_id, 100 | max_model_len=1024, 101 | gpu_memory_utilization=0.8, 102 | enforce_eager=True, 103 | dtype="bfloat16", 104 | download_dir=get_download_dir(), 105 | ) 106 | chat_template = self.model_to_template[model_id] 107 | # This may change with vLLM versions. 
108 | m = model.llm_engine.model_executor.driver_worker.model_runner.model 109 | for _, param in m.named_parameters(): 110 | param.requires_grad = False 111 | vllm_model = VLLMModel( 112 | model, 113 | sampling_params=vllm.SamplingParams( 114 | temperature=0, 115 | top_p=1, 116 | max_tokens=512, 117 | stop=["Instruction:", "Instruction", "Response:", "Response"], 118 | repetition_penalty=1.0, 119 | ), 120 | chat_template=chat_template, 121 | ) 122 | return vllm_model 123 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | from copy import deepcopy 3 | from typing import Dict, Optional 4 | 5 | import fishfarm 6 | import torch 7 | import torch.utils 8 | import vllm 9 | 10 | 11 | def load_hf_params_to_vllm(param: Dict, llm: vllm.LLM) -> None: 12 | """Load weights from HF transformer model to vLLM model.""" 13 | 14 | model = llm.llm_engine.model_executor.driver_worker.model_runner.model 15 | num_layers = model.config.num_hidden_layers 16 | 17 | # Load embeddings layer weights. 18 | model_param = model.get_parameter("model.embed_tokens.weight") 19 | model_param.copy_( 20 | param["model.embed_tokens.weight"][: model_param.shape[0]] 21 | .to(model_param.dtype) 22 | .to(model_param.device) 23 | ) 24 | model_param = model.get_parameter("lm_head.weight") 25 | model_param.copy_( 26 | param["lm_head.weight"][: model_param.shape[0]] 27 | .to(model_param.dtype) 28 | .to(model_param.device) 29 | ) 30 | 31 | # Load the final layernorm weights. 32 | model_param = model.get_parameter("model.norm.weight") 33 | model_param.copy_( 34 | param["model.norm.weight"].to(model_param.dtype).to(model_param.device) 35 | ) 36 | 37 | for i in range(num_layers): 38 | # Load qkv_proj weights. 39 | model_param = model.get_parameter(f"model.layers.{i}.self_attn.qkv_proj.weight") 40 | model_param.copy_( 41 | torch.cat( 42 | [ 43 | param[f"model.layers.{i}.self_attn.q_proj.weight"], 44 | param[f"model.layers.{i}.self_attn.k_proj.weight"], 45 | param[f"model.layers.{i}.self_attn.v_proj.weight"], 46 | ], 47 | dim=0, 48 | ) 49 | .to(model_param.dtype) 50 | .to(model_param.device) 51 | ) 52 | # Load gate_up_proj weights. 53 | model_param = model.get_parameter(f"model.layers.{i}.mlp.gate_up_proj.weight") 54 | model_param.copy_( 55 | torch.cat( 56 | [ 57 | param[f"model.layers.{i}.mlp.gate_proj.weight"], 58 | param[f"model.layers.{i}.mlp.up_proj.weight"], 59 | ], 60 | dim=0, 61 | ) 62 | .to(model_param.dtype) 63 | .to(model_param.device) 64 | ) 65 | # Load o_proj and down_proj weights. 66 | model_param = model.get_parameter(f"model.layers.{i}.self_attn.o_proj.weight") 67 | model_param.copy_( 68 | param[f"model.layers.{i}.self_attn.o_proj.weight"] 69 | .to(model_param.dtype) 70 | .to(model_param.device) 71 | ) 72 | model_param = model.get_parameter(f"model.layers.{i}.mlp.down_proj.weight") 73 | model_param.copy_( 74 | param[f"model.layers.{i}.mlp.down_proj.weight"] 75 | .to(model_param.dtype) 76 | .to(model_param.device) 77 | ) 78 | # Load layer_norm weights. 
79 | model_param = model.get_parameter(f"model.layers.{i}.input_layernorm.weight") 80 | model_param.copy_( 81 | param[f"model.layers.{i}.input_layernorm.weight"] 82 | .to(model_param.dtype) 83 | .to(model_param.device) 84 | ) 85 | model_param = model.get_parameter( 86 | f"model.layers.{i}.post_attention_layernorm.weight" 87 | ) 88 | model_param.copy_( 89 | param[f"model.layers.{i}.post_attention_layernorm.weight"] 90 | .to(model_param.dtype) 91 | .to(model_param.device) 92 | ) 93 | 94 | 95 | def eval_model(vllm_model, evaluator, ix=None): 96 | result = evaluator.evaluate(vllm_model, sample_ids=ix) 97 | return result 98 | 99 | 100 | def compose_new_params( 101 | policy, 102 | param_name, 103 | decomposed_params, 104 | learnable_params, 105 | ): 106 | """Compose new parameters from decomposed parameters.""" 107 | mm = policy.get_mask(learnable_params[param_name]) 108 | return ( 109 | decomposed_params[f"{param_name}.U"] 110 | @ torch.diag_embed(decomposed_params[f"{param_name}.S"] * mm) 111 | @ decomposed_params[f"{param_name}.V"].T 112 | ) * ( 113 | decomposed_params[f"{param_name}.S"].sum() 114 | / (decomposed_params[f"{param_name}.S"] * mm).sum() 115 | ) 116 | 117 | 118 | @torch.no_grad() 119 | def forward(policy, model, base_params, decomposed_params, learnable_params): 120 | """Forward pass.""" 121 | new_params = {} 122 | for k in base_params: 123 | if "mlp" in k: 124 | new_params[k] = compose_new_params( 125 | policy, k, decomposed_params, learnable_params 126 | ) 127 | model.get_parameter(k).copy_(new_params[k]) 128 | else: 129 | new_params[k] = base_params[k] 130 | return new_params 131 | 132 | 133 | @torch.no_grad() 134 | def load_base_params( 135 | model, 136 | base_params, 137 | ): 138 | for k in base_params: 139 | if "mlp" in k: 140 | model.get_parameter(k).copy_(base_params[k].cuda()) 141 | 142 | 143 | def backward( 144 | policy, 145 | model, 146 | base_params, 147 | decomposed_params, 148 | learnable_params, 149 | ): 150 | """Backward pass.""" 151 | keys_to_backprop = [k for k in base_params if "mlp" in k] 152 | last_key = keys_to_backprop[-1] 153 | for k in keys_to_backprop[:-1]: 154 | compose_new_params(policy, k, decomposed_params, learnable_params).backward( 155 | model.get_parameter(k).grad, retain_graph=True 156 | ) 157 | # release graph 158 | compose_new_params(policy, last_key, decomposed_params, learnable_params).backward( 159 | model.get_parameter(last_key).grad, retain_graph=False 160 | ) 161 | 162 | 163 | def classify_samples(vllm_model, test_eval): 164 | """Classify samples.""" 165 | 166 | CLASSIFICATION_PROMPT = """ 167 | # Analyze the given question and classify it into one of four categories: 'code', 'math', 'reasoning' or 'other'. Follow these guidelines: 168 | 169 | 1. Code: Questions asking for programming solutions, functions, algorithms. Often includes specific programming terms, language syntax, or data structures. 170 | 2. Math: Questions involving mathematical calculations, formulas, statistics. Often includes numbers, equations, or mathematical operations. 171 | 3. Reasoning: Questions requiring logical thinking, application of scientific knowledge, or critical analysis of information. Often presents statements that need evaluation based on general understanding. 172 | 4. Other: Questions not clearly fit into above categories. 173 | 174 | Instructions: 175 | - Consider the primary focus, skills, and knowledge required to answer the question. 176 | - If a question spans multiple categories, choose the most dominant one. 
177 | - Provide your final classification within \\boxed{} notation. Example: \\boxed{reasoning} 178 | 179 | Format your response as follows: 180 | Classification: \\boxed{category} 181 | """ 182 | 183 | def extract_classification(text: str) -> Optional[str]: 184 | """ 185 | Extract the classification from the model's output using regex. 186 | """ 187 | match = re.search(r"\\boxed{([^}]*)}", text) 188 | return match.group(1) if match else None 189 | 190 | # Identify the key in the samples that contains the problem text 191 | problem_key = None 192 | for key in ("problem", "question", "instruction"): 193 | if ( 194 | hasattr(test_eval.samples[0], key) 195 | and getattr(test_eval.samples[0], key) is not None 196 | ): 197 | problem_key = key 198 | break 199 | assert problem_key is not None, "Could not find problem text in the samples" 200 | 201 | # Prepare classification requests 202 | classification_requests = [ 203 | fishfarm.models.GenerationRequest( 204 | messages=[ 205 | fishfarm.Message("system", CLASSIFICATION_PROMPT), 206 | fishfarm.Message("user", getattr(sample, problem_key)), 207 | ] 208 | ) 209 | for sample in test_eval.samples 210 | ] 211 | 212 | # Generate classifications using the model 213 | model_outputs = vllm_model.generate(classification_requests) 214 | 215 | # Process results and update samples 216 | classified_samples = [] 217 | for sample, result in zip(test_eval.samples, model_outputs): 218 | prediction = extract_classification(result.generation) 219 | if prediction not in ["code", "math", "reasoning"]: 220 | prediction = "other" 221 | sample.expert_label = prediction 222 | classified_samples.append(sample) 223 | 224 | return classified_samples 225 | 226 | 227 | def eval_model_experts_prompt_based( 228 | vllm_model, 229 | evaluator, 230 | experts_path_dict, 231 | policy, 232 | model, 233 | base_params, 234 | decomposed_params, 235 | task_metric, 236 | ): 237 | """Evaluate the model using expert models and prompt-based classification.""" 238 | results_by_expert: Dict[str, Dict] = {} 239 | 240 | # Classify all test samples 241 | classified_samples = classify_samples(vllm_model, evaluator) 242 | 243 | # Evaluate samples for each expert model 244 | for expert_label, expert_model_path in experts_path_dict.items(): 245 | # Filter samples for current expert 246 | expert_samples = [ 247 | sample 248 | for sample in classified_samples 249 | if sample.expert_label == expert_label 250 | ] 251 | if not expert_samples: 252 | continue 253 | 254 | # Update test evaluation with filtered samples 255 | evaluator.samples = expert_samples 256 | 257 | # Load and apply expert model parameters if available 258 | if expert_model_path: 259 | policy.load_state_dict(torch.load(expert_model_path)) 260 | expert_params = policy.get_learnable_params() 261 | updated_params = forward( 262 | policy=policy, 263 | model=model, 264 | base_params=base_params, 265 | decomposed_params=decomposed_params, 266 | learnable_params=expert_params, 267 | ) 268 | load_hf_params_to_vllm(updated_params, vllm_model.llm) 269 | 270 | # Evaluate current expert model 271 | evaluation_results = eval_model(vllm_model, evaluator) 272 | 273 | # Store results for current expert 274 | results_by_expert[expert_label] = { 275 | "num_samples": len(expert_samples), 276 | "test_acc": evaluation_results.aggregate_metrics[task_metric], 277 | } 278 | 279 | # Compute the overall accuracy. 
280 | data_dict = deepcopy(results_by_expert) 281 | data_dict["final_test_acc"] = 0.0 282 | for label in results_by_expert.keys(): 283 | data_dict["final_test_acc"] += ( 284 | results_by_expert[label]["test_acc"] 285 | * results_by_expert[label]["num_samples"] 286 | ) 287 | data_dict["final_test_acc"] /= len(classified_samples) 288 | 289 | return data_dict 290 | --------------------------------------------------------------------------------
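Below is a minimal, self-contained sketch (not part of the repository) of what compose_new_params in utils.py does: it rebuilds a weight matrix from stored SVD factors, scales the singular values by the policy's mask, and renormalizes so the sum of singular values is preserved. The IdentityPolicy class, the random weight, and the parameter name are hypothetical stand-ins for the real policy classes in policy/ and the decomposed tensors produced by the training script; running it only assumes the repository's own requirements (torch, vllm, fishfarm) are installed so that utils imports cleanly.

import torch

from utils import compose_new_params


class IdentityPolicy:
    """Hypothetical policy whose mask keeps every singular value."""

    def get_mask(self, learnable_param: torch.Tensor) -> torch.Tensor:
        return torch.ones_like(learnable_param)


# Fake base weight and its SVD factors, mirroring the "{name}.U" / "{name}.S" /
# "{name}.V" layout that compose_new_params expects (W = U @ diag(S) @ V.T).
name = "model.layers.0.mlp.gate_proj.weight"
w = torch.randn(8, 8, dtype=torch.float64)
u, s, vh = torch.linalg.svd(w)
decomposed = {f"{name}.U": u, f"{name}.S": s, f"{name}.V": vh.T}
learnable = {name: s.clone()}  # placeholder; the real shape depends on the policy

rebuilt = compose_new_params(IdentityPolicy(), name, decomposed, learnable)
print(torch.allclose(rebuilt, w))  # True: an all-ones mask recovers the original weight

With a non-trivial mask, the same renormalization term, S.sum() / (S * mask).sum(), rescales the surviving singular values so their total mass matches the original, which keeps the magnitude of the recomposed MLP weight comparable to the base weight even when many components are masked out.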