├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── cover.gif
│   └── cover.mp4
├── base_model
│   ├── __init__.py
│   ├── base.py
│   ├── llama3instruct.py
│   └── mistral03instruct.py
├── cfgs
│   ├── base_model
│   │   ├── llama3i8b.yaml
│   │   └── mistral03i7b.yaml
│   ├── config.yaml
│   ├── mode
│   │   ├── eval.yaml
│   │   └── training.yaml
│   ├── optimization
│   │   ├── cem.yaml
│   │   ├── reinforce.yaml
│   │   └── rsm.yaml
│   ├── policy
│   │   ├── default.yaml
│   │   └── wcomb.yaml
│   └── task
│       ├── ablation_tasks
│       │   ├── few_shot_arc_challenge_20.yaml
│       │   ├── few_shot_arc_challenge_3.yaml
│       │   └── few_shot_arc_challenge_5.yaml
│       ├── ai2_arc.yaml
│       ├── cls.yaml
│       ├── few_shot_arc_challenge.yaml
│       ├── few_shot_humaneval.yaml
│       ├── few_shot_math.yaml
│       ├── gsm8k.yaml
│       ├── math.yaml
│       └── mbpp2.yaml
├── evaluation
│   └── fishfarm
│       ├── fishfarm
│       │   ├── __init__.py
│       │   ├── chat_templates.py
│       │   ├── imports.py
│       │   ├── logging.py
│       │   ├── models
│       │   │   ├── __init__.py
│       │   │   ├── base.py
│       │   │   ├── tokenization_utils.py
│       │   │   └── vllm_model.py
│       │   ├── tasks
│       │   │   ├── __init__.py
│       │   │   ├── ai2_arc.py
│       │   │   ├── base.py
│       │   │   ├── competation_math.py
│       │   │   ├── evalplus
│       │   │   │   ├── __init__.py
│       │   │   │   ├── data.py
│       │   │   │   ├── evaluation.py
│       │   │   │   ├── generation.py
│       │   │   │   ├── sanitization.py
│       │   │   │   └── task.py
│       │   │   └── language_restricted_math.py
│       │   └── version.py
│       ├── pyproject.toml
│       └── tox.ini
├── logging_utils.py
├── optim_modules.py
├── policy
│   ├── __init__.py
│   ├── base.py
│   └── weighted_combination.py
├── requirements.txt
├── scripts
│   ├── eval_few_shot.sh
│   ├── eval_prompt_based.sh
│   └── train_task_expert.sh
├── svd_reinforce_hydra.py
├── tasks
│   ├── __init__.py
│   ├── arc.py
│   ├── base.py
│   ├── cls.py
│   ├── gsm8k.py
│   ├── math.py
│   └── mbpp2.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | playground*
2 | wandb
3 | results
4 | results_eval
5 | saved_models
6 | outputs
7 | reference_code
8 | messy_scripts
9 | *_decomposed_params.pt
10 | **/__pycache__/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Transformer²: Self-adaptive LLMs 🐙
3 |
4 |
5 | 📚 [Paper](https://arxiv.org/abs/2501.06252) |
6 | 📄 [Blog]
7 |
8 |
9 | Self-adaptive large language models (LLMs) aim to solve the challenges posed by traditional fine-tuning methods, which are often computationally intensive and static in their ability to handle diverse tasks.
10 |
11 | We are excited to introduce Transformer², a novel self-adaptation framework that adapts LLMs for unseen tasks in real-time by selectively adjusting only the singular components of their weight matrices.
12 | During inference, Transformer² employs a two-pass mechanism: first, a dispatch system identifies the task properties, and then task-specific "expert" vectors, trained using reinforcement learning, are dynamically mixed to obtain targeted behavior for the incoming prompt.
13 |
14 |
15 | 
16 |
17 |
18 |
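The adaptation step described above is easy to picture in isolation. Below is a toy, self-contained sketch (PyTorch only); the shapes, dispatch weights, and expert vectors are made up for illustration and this is **not** the repository's actual code:

```python
import torch

# A toy weight matrix and its singular components.
W = torch.randn(8, 8)
U, S, Vh = torch.linalg.svd(W)

# Hypothetical task-expert vectors: one scaling factor per singular value.
experts = {"math": torch.rand(8), "code": torch.rand(8)}

# Pass 1 (dispatch): suppose the incoming prompt is judged 70% math, 30% code.
weights = {"math": 0.7, "code": 0.3}
z = sum(w * experts[task] for task, w in weights.items())

# Pass 2: rescale the singular values and rebuild the adapted weight matrix.
W_adapted = U @ torch.diag(S * z) @ Vh
print(W_adapted.shape)  # torch.Size([8, 8])
```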
19 | ## Installation
20 |
21 | ### 1. Clone the Repo
22 | ```bash
23 | git clone https://github.com/SakanaAI/self-adaptive-llms
24 | cd self-adaptive-llms
25 | ```
26 |
27 | ### 2. Install Libraries
28 | ```bash
29 | conda create -n t2 python=3.11 -y
30 | conda activate t2
31 | pip install --upgrade pip
32 | pip install -r requirements.txt
33 | ```
34 |
35 | ### 3. Install Tasks Evaluator
36 | ```bash
37 | cd evaluation/fishfarm
38 | pip install -e .
39 | ```
40 |
41 | ## Usage
42 | We provide example scripts for both training and evaluation.
43 |
44 | Please change the arguments in the provided scripts to choose among models and tasks.
45 |
46 | ### Training
47 |
48 | ```bash
49 | bash scripts/train_task_expert.sh
50 | ```
51 |
52 | ### Evaluation
53 |
54 | #### Prompt-based evaluation
55 | Classification experts can be loaded by specifying the CLS_EXPERT_PATH in the script.
56 | ```bash
57 | bash scripts/eval_prompt_based.sh
58 | ```
59 |
60 | #### Few-shot evaluation
61 | ```bash
62 | bash scripts/eval_few_shot.sh
63 | ```
64 |
65 | ## Citation
66 | If you find **Transformer²** useful for your research, please cite it using this BibTeX:
67 | ```
68 | @misc{sun2025transformersquaredselfadaptivellms,
69 | title={Transformer-Squared: Self-adaptive LLMs},
70 | author={Qi Sun and Edoardo Cetin and Yujin Tang},
71 | year={2025},
72 | eprint={2501.06252},
73 | archivePrefix={arXiv},
74 | primaryClass={cs.LG},
75 | url={https://arxiv.org/abs/2501.06252},
76 | }
77 | ```
78 |
--------------------------------------------------------------------------------
/assets/cover.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SakanaAI/self-adaptive-llms/03a41aed1cfc57276e72ad5a42845a04b356db1e/assets/cover.gif
--------------------------------------------------------------------------------
/assets/cover.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SakanaAI/self-adaptive-llms/03a41aed1cfc57276e72ad5a42845a04b356db1e/assets/cover.mp4
--------------------------------------------------------------------------------
/base_model/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import BaseModel
2 | from .llama3instruct import Llama3Instruct8B
3 | from .mistral03instruct import MistralV03Instruct7B
4 |
--------------------------------------------------------------------------------
/base_model/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 |
4 | class BaseModel(ABC):
5 | def __init__(self):
6 | pass
7 |
8 | @abstractmethod
9 | def get_model_id(self):
10 | raise NotImplementedError
11 |
12 | @abstractmethod
13 | def get_model_name(self):
14 | raise NotImplementedError
15 |
16 | @abstractmethod
17 | def get_param_file(self, param_folder_path=""):
18 | raise NotImplementedError
19 |
--------------------------------------------------------------------------------
/base_model/llama3instruct.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from .base import BaseModel
4 |
5 |
6 | class Llama3Instruct8B(BaseModel):
7 | def __init__(self):
8 | self.model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
9 | self.dec_param_file_n = "llama3_decomposed_params.pt"
10 |
11 | def get_model_id(self):
12 | return self.model_id
13 |
14 | def get_model_name(self):
15 | return self.model_id.split("/")[1]
16 |
17 | def get_param_file(self, param_folder_path=""):
18 | return os.path.join(param_folder_path, self.dec_param_file_n)
19 |
--------------------------------------------------------------------------------
/base_model/mistral03instruct.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from .base import BaseModel
4 |
5 |
6 | class MistralV03Instruct7B(BaseModel):
7 | def __init__(self):
8 | self.model_id = "mistralai/Mistral-7B-Instruct-v0.3"
9 | self.dec_param_file_n = "mistral_decomposed_params.pt"
10 |
11 | def get_model_id(self):
12 | return self.model_id
13 |
14 | def get_model_name(self):
15 | return self.model_id.split("/")[1]
16 |
17 | def get_param_file(self, param_folder_path=""):
18 | return os.path.join(param_folder_path, self.dec_param_file_n)
19 |
--------------------------------------------------------------------------------
/cfgs/base_model/llama3i8b.yaml:
--------------------------------------------------------------------------------
1 | base_model:
2 | _target_: base_model.Llama3Instruct8B
3 |
4 |
5 | base_model_name: llama3i8b
6 |
7 | # reference_params_results:
8 | # - 'saved_models/llama3i8b/gsm8k/learnable_params.pt'
9 | # - 'saved_models/llama3i8b/mbpp/learnable_params.pt'
10 | # - 'saved_models/llama3i8b/ai2arc/learnable_params.pt'
11 |
12 | reference_params_results:
13 | - "ckpts/learnable_params/llama3_8b_instruct_gsm8k_svd_pg_mlp.pt"
14 | - "ckpts/learnable_params/llama3_8b_instruct_mbpp_pro_svd_pg_mlp.pt"
15 | - "ckpts/learnable_params/llama3_8b_instruct_gsm8k_svd_pg_mlp.pt"
--------------------------------------------------------------------------------
/cfgs/base_model/mistral03i7b.yaml:
--------------------------------------------------------------------------------
1 | base_model:
2 | _target_: base_model.MistralV03Instruct7B
3 |
4 |
5 | base_model_name: mistral03i7b
6 |
7 | reference_params_results:
8 | - 'saved_models/mistral03i7b/gsm8k/policy_params.pt'
9 | - 'saved_models/mistral03i7b/mbpp/policy_params.pt'
10 | - 'saved_models/mistral03i7b/ai2arc/policy_params.pt'
--------------------------------------------------------------------------------
/cfgs/config.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_
3 | - policy@_global_: default
4 | - task@_global_: gsm8k
5 | - base_model@_global_: llama3i8b
6 | - optimization@_global_: reinforce
7 | - mode@_global_: training
8 |
9 | num_iters: 2000
10 | test_interval: 10
11 | lr: 2e-3
12 | batch_size: 256
13 | seed: 42
14 | init_val: 0.1
15 | test_only: false
16 | model_dir: null
17 | save_legacy_params: false
18 | use_lora: false
19 | prompt_based_eval: false
20 | experts_path_dict: null
21 |
22 | run_name: null
23 |
24 | load_ckpt: null
25 | exp_suffix: 'st'
26 |
27 | exp_name: ${base_model_name}/${optim_name}-${exp_suffix}
28 |
29 | wandb_log: true # enabled by default
30 | wandb_project: shakeoff
31 | wandb_group_name: ${exp_name}
32 | extract_svd: false
33 |
34 | out_dir: results
35 |
36 | hydra:
37 | run:
38 | dir: ${out_dir}/
--------------------------------------------------------------------------------
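The `defaults` list in the config above pulls the `policy`, `task`, `base_model`, `optimization`, and `mode` groups into the global namespace. A minimal sketch of composing it programmatically with Hydra and swapping groups via overrides (assuming the snippet sits at the repository root; the provided shell scripts may pass their overrides differently):

```python
from hydra import compose, initialize

# Compose cfgs/config.yaml and replace the task and mode groups, mirroring the
# `group@_global_: option` entries in the defaults list above.
with initialize(version_base=None, config_path="cfgs"):
    cfg = compose(
        config_name="config",
        overrides=["task@_global_=math", "mode@_global_=eval"],
    )

print(cfg.task_name, cfg.test_only)  # "math", True
```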
/cfgs/mode/eval.yaml:
--------------------------------------------------------------------------------
1 | exp_name: eval_${base_model_name}/temp-lr${lr}-mGN${max_grad_norm}-klC${kl_ref_coeff}-r${rw_strategy}-${exp_suffix}-r
2 |
3 | test_only: true
4 | load_ckpt: null
5 | use_lora: false
6 |
7 | prompt_based_eval: false
8 | experts_path_dict:
9 | code: null
10 | math: null
11 | reasoning: null
12 | other: null
13 |
14 | wandb_project: T^2_eval
15 | wandb_group_name: ${exp_name}
16 | out_dir: results_eval
--------------------------------------------------------------------------------
/cfgs/mode/training.yaml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SakanaAI/self-adaptive-llms/03a41aed1cfc57276e72ad5a42845a04b356db1e/cfgs/mode/training.yaml
--------------------------------------------------------------------------------
/cfgs/optimization/cem.yaml:
--------------------------------------------------------------------------------
1 |
2 | optimization_algorithm:
3 | _target_: optim_modules.CEM
4 | elite_ratio: ${elite_ratio}
5 | pop_size: ${pop_size}
6 | min_trainable_param: ${min_trainable_param}
7 | max_trainable_param: ${max_trainable_param}
8 | optim_ema: ${optim_ema}
9 | re_eval_best: ${re_eval_best}
10 | use_loglikelihood_for_ties: ${use_loglikelihood_for_ties}
11 |
12 |
13 | pop_size: 32
14 | elite_ratio: 0.2
15 | min_trainable_param: 0
16 | max_trainable_param: 1
17 | optim_ema: 0
18 | re_eval_best: True
19 | use_loglikelihood_for_ties: true
20 | optim_name: CEM-pop${pop_size}e${elite_ratio}-[${min_trainable_param}-${max_trainable_param}]-tieswLL${use_loglikelihood_for_ties}
--------------------------------------------------------------------------------
/cfgs/optimization/reinforce.yaml:
--------------------------------------------------------------------------------
1 |
2 | optimization_algorithm:
3 | _target_: optim_modules.Reinforce
4 | # policy: ${policy}
5 | # gpu: ${gpu}
6 | max_grad_norm: ${max_grad_norm}
7 | lr: ${lr}
8 | rw_norm: ${rw_norm}
9 | rw_clip: ${rw_clip}
10 | kl_ref_coeff: ${kl_ref_coeff}
11 |
12 |
13 | # policy:
14 | # gpu:
15 | max_grad_norm: 1e-3
16 | lr: 2e-3
17 | rw_norm: 0
18 | rw_clip: null
19 | kl_ref_coeff: 0
20 | rw_strategy: rN${rw_norm}C${rw_clip}
21 | optim_name: RL-lr${lr}-mGN${max_grad_norm}-klC${kl_ref_coeff}-r${rw_strategy}
--------------------------------------------------------------------------------
/cfgs/optimization/rsm.yaml:
--------------------------------------------------------------------------------
1 |
2 | optimization_algorithm:
3 | _target_: optim_modules.RandomShooting
4 | # policy: ${policy}
5 | # gpu: ${gpu}
6 | pop_size: ${pop_size}
7 | min_trainable_param: ${min_trainable_param}
8 | max_trainable_param: ${max_trainable_param}
9 | optim_ema: ${optim_ema}
10 | re_eval_best: ${re_eval_best}
11 | use_loglikelihood_for_ties: ${use_loglikelihood_for_ties}
12 |
13 |
14 | pop_size: 32
15 | min_trainable_param: 0
16 | max_trainable_param: 1
17 | optim_ema: 0
18 | re_eval_best: True
19 | use_loglikelihood_for_ties: false
20 | optim_name: RSML-pop${pop_size}-[${min_trainable_param}-${max_trainable_param}]-tieswLL${use_loglikelihood_for_ties}
--------------------------------------------------------------------------------
/cfgs/policy/default.yaml:
--------------------------------------------------------------------------------
1 | shakeoff_policy:
2 | _target_: policy.Policy
3 | init_val: ${init_val}
4 | mode: ${policy_mode}
5 | max_mult: ${max_mult}
6 |
7 | policy_mode: 1
8 | max_mult: 1
9 | policy_name: ${policy_mode}_mm${max_mult}
10 |
--------------------------------------------------------------------------------
/cfgs/policy/wcomb.yaml:
--------------------------------------------------------------------------------
1 |
2 |
3 | shakeoff_policy:
4 | _target_: policy.WeightedCombination
5 | base_policy_cfg: null
6 | params_paths: ${reference_params_results}
7 | norm_coeffs: ${norm_coeffs}
8 | per_layer: ${per_layer}
9 | init_values: ${init_values}
10 |
11 | norm_coeffs: true
12 | per_layer: false
13 | init_values: null
14 |
15 | policy_name: Wcomb_n${norm_coeffs}_p${per_layer}
16 |
--------------------------------------------------------------------------------
/cfgs/task/ablation_tasks/few_shot_arc_challenge_20.yaml:
--------------------------------------------------------------------------------
1 | task_loader:
2 | _target_: tasks.FewShotTask
3 | wrapped_task:
4 | _target_: tasks.AI2ArcTask
5 | wrapped_split: ${wrapped_split}
6 | shots: ${task_shots}
7 | seed: ${task_loader_seed}
8 |
9 |
10 | wrapped_split: transfer
11 | task_shots: 20
12 | task_loader_seed: 38
13 |
14 | task_name: arc_chal_${task_shots}shots
15 |
16 |
--------------------------------------------------------------------------------
/cfgs/task/ablation_tasks/few_shot_arc_challenge_3.yaml:
--------------------------------------------------------------------------------
1 | task_loader:
2 | _target_: tasks.FewShotTask
3 | wrapped_task:
4 | _target_: tasks.AI2ArcTask
5 | wrapped_split: ${wrapped_split}
6 | shots: ${task_shots}
7 | seed: ${task_loader_seed}
8 |
9 |
10 | wrapped_split: transfer
11 | task_shots: 3
12 | task_loader_seed: 38
13 |
14 | task_name: arc_chal_${task_shots}shots
15 |
16 |
--------------------------------------------------------------------------------
/cfgs/task/ablation_tasks/few_shot_arc_challenge_5.yaml:
--------------------------------------------------------------------------------
1 | task_loader:
2 | _target_: tasks.FewShotTask
3 | wrapped_task:
4 | _target_: tasks.AI2ArcTask
5 | wrapped_split: ${wrapped_split}
6 | shots: ${task_shots}
7 | seed: ${task_loader_seed}
8 |
9 |
10 | wrapped_split: transfer
11 | task_shots: 5
12 | task_loader_seed: 38
13 |
14 | task_name: arc_chal_${task_shots}shots
15 |
16 |
--------------------------------------------------------------------------------
/cfgs/task/ai2_arc.yaml:
--------------------------------------------------------------------------------
1 | task_loader:
2 | _target_: tasks.AI2ArcTask
3 |
4 |
5 | task_name: ai2_arc
6 |
7 |
--------------------------------------------------------------------------------
/cfgs/task/cls.yaml:
--------------------------------------------------------------------------------
1 | task_loader:
2 | _target_: tasks.ClsTask
3 |
4 |
5 | task_name: Cls
6 |
7 |
--------------------------------------------------------------------------------
/cfgs/task/few_shot_arc_challenge.yaml:
--------------------------------------------------------------------------------
1 | task_loader:
2 | _target_: tasks.FewShotTask
3 | wrapped_task:
4 | _target_: tasks.AI2ArcTask
5 | wrapped_split: ${wrapped_split}
6 | shots: ${task_shots}
7 | seed: ${task_loader_seed}
8 |
9 |
10 | wrapped_split: transfer
11 | task_shots: 10
12 | task_loader_seed: 38
13 |
14 | task_name: arc_chal_${task_shots}shots
15 |
16 |
--------------------------------------------------------------------------------
/cfgs/task/few_shot_humaneval.yaml:
--------------------------------------------------------------------------------
1 | task_loader:
2 | _target_: tasks.FewShotTask
3 | wrapped_task:
4 | _target_: tasks.Mbpp2Task2
5 | wrapped_split: ${wrapped_split}
6 | shots: ${task_shots}
7 | seed: ${task_loader_seed}
8 |
9 |
10 | wrapped_split: transfer
11 | task_shots: 10
12 | task_loader_seed: 16
13 |
14 | task_name: humaneval_${task_shots}shots
15 |
--------------------------------------------------------------------------------
/cfgs/task/few_shot_math.yaml:
--------------------------------------------------------------------------------
1 | task_loader:
2 | _target_: tasks.FewShotTask
3 | wrapped_task:
4 | _target_: tasks.MathTask
5 | wrapped_split: ${wrapped_split}
6 | shots: ${task_shots}
7 | seed: ${task_loader_seed}
8 |
9 |
10 | wrapped_split: test
11 | task_shots: 10
12 | task_loader_seed: 27
13 |
14 | task_name: math_${task_shots}shots
15 |
16 |
--------------------------------------------------------------------------------
/cfgs/task/gsm8k.yaml:
--------------------------------------------------------------------------------
1 | task_loader:
2 | _target_: tasks.Gsm8kTask
3 |
4 |
5 | task_name: gsm8k
6 |
7 |
--------------------------------------------------------------------------------
/cfgs/task/math.yaml:
--------------------------------------------------------------------------------
1 | task_loader:
2 | _target_: tasks.MathTask
3 |
4 |
5 | task_name: math
6 |
7 |
--------------------------------------------------------------------------------
/cfgs/task/mbpp2.yaml:
--------------------------------------------------------------------------------
1 | task_loader:
2 | _target_: tasks.Mbpp2Task
3 |
4 |
5 | task_name: mbpp
6 |
7 |
--------------------------------------------------------------------------------
/evaluation/fishfarm/fishfarm/__init__.py:
--------------------------------------------------------------------------------
1 | from . import chat_templates, models, tasks
2 | from .models import Message, Model, Role
3 | from .tasks import Task, TaskResult
4 |
5 | __all__ = [
6 | "chat_templates",
7 | "tasks",
8 | "models",
9 | "Task",
10 | "TaskResult",
11 | "Model",
12 | "Message",
13 | "Role",
14 | ]
15 |
--------------------------------------------------------------------------------
/evaluation/fishfarm/fishfarm/chat_templates.py:
--------------------------------------------------------------------------------
1 | LLAMA3 = (
2 | "{% set loop_messages = messages %}"
3 | "{% for message in loop_messages %}"
4 | "{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>"
5 | "\n\n'+ message['content'] | trim + '<|eot_id|>' %}"
6 | "{% if loop.index0 == 0 %}{% set content = bos_token + content %}"
7 | "{% endif %}"
8 | "{{ content }}"
9 | "{% endfor %}"
10 | "{% if add_generation_prompt %}"
11 | "{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}"
12 | "{% endif %}"
13 | )
14 |
--------------------------------------------------------------------------------
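A short sketch of rendering a conversation with the `LLAMA3` template above through `transformers` (the model name is only an example; any tokenizer with `apply_chat_template` works the same way):

```python
from transformers import AutoTokenizer

from fishfarm.chat_templates import LLAMA3

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
prompt = tokenizer.apply_chat_template(
    conversation=[{"role": "user", "content": "What is 2 + 2?"}],
    chat_template=LLAMA3,
    tokenize=False,
    add_generation_prompt=True,
)
print(prompt)  # the rendered prompt, ending with the assistant header
```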
/evaluation/fishfarm/fishfarm/imports.py:
--------------------------------------------------------------------------------
1 | from types import TracebackType
2 | from typing import Optional, Tuple, Type
3 |
4 |
5 | class _DeferredImportExceptionContextManager:
6 | """Context manager to defer exceptions from imports.
7 |
8 | Catches :exc:`ImportError` and :exc:`SyntaxError`.
9 |     If any exception is caught, this class raises an :exc:`ImportError` when `check()` is called.
10 |
11 | """
12 |
13 | def __init__(self) -> None:
14 | self._deferred: Optional[Tuple[Exception, str]] = None
15 |
16 | def __enter__(self) -> "_DeferredImportExceptionContextManager":
17 | """Enter the context manager.
18 |
19 | Returns:
20 | Itself.
21 |
22 | """
23 | return self
24 |
25 | def __exit__(
26 | self,
27 | exc_type: Optional[Type[Exception]],
28 | exc_value: Optional[Exception],
29 | traceback: Optional[TracebackType],
30 | ) -> Optional[bool]:
31 | """Exit the context manager.
32 |
33 | Args:
34 | exc_type:
35 | Raised exception type. :obj:`None` if nothing is raised.
36 | exc_value:
37 | Raised exception object. :obj:`None` if nothing is raised.
38 | traceback:
39 | Associated traceback. :obj:`None` if nothing is raised.
40 |
41 | Returns:
42 | :obj:`None` if nothing is deferred, otherwise :obj:`True`.
43 | :obj:`True` will suppress any exceptions avoiding them from propagating.
44 |
45 | """
46 | if isinstance(exc_value, (ImportError, SyntaxError)):
47 | if isinstance(exc_value, ImportError):
48 | message = (
49 | "Tried to import '{}' but failed. Please make sure that the package is "
50 | "installed correctly to use this feature. Actual error: {}."
51 | ).format(exc_value.name, exc_value)
52 | elif isinstance(exc_value, SyntaxError):
53 | message = (
54 | "Tried to import a package but failed due to a syntax error in {}. Please "
55 | "make sure that the Python version is correct to use this feature. Actual "
56 | "error: {}."
57 | ).format(exc_value.filename, exc_value)
58 | else:
59 | assert False
60 |
61 | self._deferred = (exc_value, message)
62 | return True
63 | return None
64 |
65 | def is_successful(self) -> bool:
66 | """Return whether the context manager has caught any exceptions.
67 |
68 | Returns:
69 | :obj:`True` if no exceptions are caught, :obj:`False` otherwise.
70 |
71 | """
72 | return self._deferred is None
73 |
74 | def check(self) -> None:
75 | """Check whether the context manager has caught any exceptions.
76 |
77 | Raises:
78 | :exc:`ImportError`:
79 |                 If any exception was caught. The raised error is chained from the originally caught exception.
80 |
81 | """
82 | if self._deferred is not None:
83 | exc_value, message = self._deferred
84 | raise ImportError(message) from exc_value
85 |
86 |
87 | def try_import() -> _DeferredImportExceptionContextManager:
88 | """Create a context manager that can wrap imports of optional packages to defer exceptions.
89 |
90 | Returns:
91 | Deferred import context manager.
92 |
93 | """
94 | return _DeferredImportExceptionContextManager()
95 |
--------------------------------------------------------------------------------
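Usage mirrors how `models/vllm_model.py` below wraps its optional `vllm` dependency: the import error is only raised when `check()` is called, not at import time.

```python
from fishfarm.imports import try_import

# Nothing is raised here even if vllm is not installed.
with try_import() as _imports:
    import vllm  # noqa: F401

# Raise a descriptive ImportError only when the feature is actually needed.
_imports.check()
```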
/evaluation/fishfarm/fishfarm/logging.py:
--------------------------------------------------------------------------------
1 | """
2 | Copied from Optuna repo:
3 | https://github.com/optuna/optuna/blob/2595653638506e1b7e025a966a220984a59ab936/optuna/logging.py
4 | Removed some comments for less verbosity.
5 |
6 | In general, `logger.info` is preferred over `print` since it includes the module name and a timestamp;
7 | we recommend that fishfarm developers use the logger object.
8 |
9 | Inside fishfarm, we can call `get_logger(__name__)` from each Python file.
10 | Then the root logger's format and level are applied to that logger object.
11 | """
12 |
13 | from __future__ import annotations
14 |
15 | import logging
16 | import os
17 | import sys
18 | import threading
19 | from logging import CRITICAL, DEBUG, ERROR, FATAL, INFO, WARN, WARNING
20 |
21 | import colorlog
22 |
23 | __all__ = [
24 | "CRITICAL",
25 | "DEBUG",
26 | "ERROR",
27 | "FATAL",
28 | "INFO",
29 | "WARN",
30 | "WARNING",
31 | ]
32 |
33 | _lock: threading.Lock = threading.Lock()
34 | _default_handler: logging.Handler | None = None
35 |
36 |
37 | def create_default_formatter() -> logging.Formatter:
38 | """Create a default formatter of log messages.
39 |
40 | This function is not supposed to be directly accessed by library users.
41 | """
42 | header = "[%(levelname)1.1s %(asctime)s %(name)s]"
43 | message = "%(message)s"
44 | if _color_supported():
45 | return colorlog.ColoredFormatter(
46 | f"%(log_color)s{header}%(reset)s {message}",
47 | )
48 | return logging.Formatter(f"{header} {message}")
49 |
50 |
51 | def _color_supported() -> bool:
52 | """Detection of color support."""
53 | # NO_COLOR environment variable:
54 | if os.environ.get("NO_COLOR", None):
55 | return False
56 |
57 | if not hasattr(sys.stderr, "isatty") or not sys.stderr.isatty():
58 | return False
59 | else:
60 | return True
61 |
62 |
63 | def _get_library_name() -> str:
64 | return __name__.split(".")[0]
65 |
66 |
67 | def _get_library_root_logger() -> logging.Logger:
68 | return logging.getLogger(_get_library_name())
69 |
70 |
71 | def _configure_library_root_logger() -> None:
72 | global _default_handler
73 |
74 | with _lock:
75 | if _default_handler:
76 | # This library has already configured the library root logger.
77 | return
78 | _default_handler = logging.StreamHandler() # Set sys.stderr as stream.
79 | _default_handler.setFormatter(create_default_formatter())
80 |
81 | # Apply our default configuration to the library root logger.
82 | library_root_logger: logging.Logger = _get_library_root_logger()
83 | library_root_logger.addHandler(_default_handler)
84 | library_root_logger.setLevel(logging.INFO)
85 | library_root_logger.propagate = False
86 |
87 |
88 | def _reset_library_root_logger() -> None:
89 | global _default_handler
90 |
91 | with _lock:
92 | if not _default_handler:
93 | return
94 |
95 | library_root_logger: logging.Logger = _get_library_root_logger()
96 | library_root_logger.removeHandler(_default_handler)
97 | library_root_logger.setLevel(logging.NOTSET)
98 | _default_handler = None
99 |
100 |
101 | def get_logger(name: str) -> logging.Logger:
102 | """Return a logger with the specified name.
103 |     The name's prefix should be `fishfarm.` (just like the __name__ variable),
104 |     otherwise the root logger settings will not be reflected.
105 | This function is not supposed to be directly accessed by library users.
106 | """
107 |
108 | _configure_library_root_logger()
109 | return logging.getLogger(name)
110 |
111 |
112 | def get_verbosity() -> int:
113 | """Return the current level for the fishfarm's root logger.
114 |
115 | Returns:
116 | Logging level, e.g., ``fishfarm.logging.DEBUG`` and ``fishfarm.logging.INFO``.
117 |
118 | .. note::
119 |         fishfarm has the following logging levels:
120 |
121 | - ``fishfarm.logging.CRITICAL``, ``fishfarm.logging.FATAL``
122 | - ``fishfarm.logging.ERROR``
123 | - ``fishfarm.logging.WARNING``, ``fishfarm.logging.WARN``
124 | - ``fishfarm.logging.INFO``
125 | - ``fishfarm.logging.DEBUG``
126 | """
127 |
128 | _configure_library_root_logger()
129 | return _get_library_root_logger().getEffectiveLevel()
130 |
131 |
132 | def set_verbosity(verbosity: int) -> None:
133 | """Set the level for the fishfarm's root logger.
134 |
135 | Args:
136 | verbosity:
137 | Logging level, e.g., ``fishfarm.logging.DEBUG`` and ``fishfarm.logging.INFO``.
138 |
139 | .. note::
140 |         fishfarm has the following logging levels:
141 |
142 | - ``fishfarm.logging.CRITICAL``, ``fishfarm.logging.FATAL``
143 | - ``fishfarm.logging.ERROR``
144 | - ``fishfarm.logging.WARNING``, ``fishfarm.logging.WARN``
145 | - ``fishfarm.logging.INFO``
146 | - ``fishfarm.logging.DEBUG``
147 | """
148 |
149 | _configure_library_root_logger()
150 | _get_library_root_logger().setLevel(verbosity)
151 |
152 |
153 | def disable_default_handler() -> None:
154 | """Disable the default handler of the fishfarm's root logger."""
155 |
156 | _configure_library_root_logger()
157 |
158 | assert _default_handler is not None
159 | _get_library_root_logger().removeHandler(_default_handler)
160 |
161 |
162 | def enable_default_handler() -> None:
163 | """Enable the default handler of the fishfarm's root logger."""
164 |
165 | _configure_library_root_logger()
166 |
167 | assert _default_handler is not None
168 | _get_library_root_logger().addHandler(_default_handler)
169 |
170 |
171 | def disable_propagation() -> None:
172 | """Disable propagation of the library log outputs.
173 |
174 | Note that log propagation is disabled by default. You only need to use this function
175 | to stop log propagation when you use :func:`~fishfarm.logging.enable_propagation()`.
176 | """
177 |
178 | _configure_library_root_logger()
179 | _get_library_root_logger().propagate = False
180 |
181 |
182 | def enable_propagation() -> None:
183 | """Enable propagation of the library log outputs.
184 |
185 | Please disable the fishfarm's default handler to prevent double logging if the root logger has
186 | been configured.
187 | """
188 |
189 | _configure_library_root_logger()
190 | _get_library_root_logger().propagate = True
191 |
--------------------------------------------------------------------------------
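A brief usage sketch of the helpers above (the logger name is illustrative):

```python
from fishfarm import logging as fishfarm_logging

# Loggers should be named with the `fishfarm.` prefix so the root settings apply.
logger = fishfarm_logging.get_logger("fishfarm.example")
fishfarm_logging.set_verbosity(fishfarm_logging.DEBUG)
logger.info("evaluation started")
```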
/evaluation/fishfarm/fishfarm/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import (GenerationRequest, GenerationResult, Message, Model,
2 | NLLRequest, NLLResult, Role)
3 |
4 | __all__ = [
5 | "GenerationRequest",
6 | "GenerationResult",
7 | "NLLRequest",
8 | "NLLResult",
9 | "Model",
10 | "Role",
11 | "Message",
12 | ]
13 |
--------------------------------------------------------------------------------
/evaluation/fishfarm/fishfarm/models/base.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from dataclasses import dataclass
4 | from typing import Iterable, Literal, Optional, Sequence
5 |
6 | Role = Literal["system", "user", "assistant", "assistant_prefill"]
7 |
8 |
9 | @dataclass
10 | class Message:
11 |
12 | role: Role
13 | content: str
14 |
15 |
16 | @dataclass
17 | class GenerationRequest:
18 |
19 | messages: list[Message]
20 |
21 | max_tokens: Optional[int] = None
22 | stop: Sequence[str] = ()
23 |
24 |
25 | @dataclass
26 | class GenerationResult:
27 |
28 | request: GenerationRequest
29 | generation: str
30 |
31 |
32 | @dataclass
33 | class NLLRequest:
34 |
35 | messages: list[Message]
36 |
37 |
38 | @dataclass
39 | class NLLResult:
40 |
41 | request: NLLRequest
42 | sum_nll: float
43 | num_considered_tokens: int
44 |
45 |
46 | class Model:
47 |
48 | def generate(
49 | self, requests: Sequence[GenerationRequest]
50 | ) -> Iterable[GenerationResult]:
51 | raise NotImplementedError()
52 |
53 | def nll(self, requests: Sequence[NLLRequest]) -> Iterable[NLLResult]:
54 | raise NotImplementedError()
55 |
--------------------------------------------------------------------------------
/evaluation/fishfarm/fishfarm/models/tokenization_utils.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | from typing import Optional
3 |
4 | from transformers import PreTrainedTokenizerBase
5 |
6 | from .base import Message
7 |
8 |
9 | class MaskedTokens:
10 |
11 | text: str
12 | token_ids: list[int]
13 | mask: list[bool]
14 |
15 | def __init__(self) -> None:
16 | self.text = ""
17 | self.token_ids = []
18 | self.mask = []
19 |
20 | def extend(
21 | self,
22 | messages: list[Message],
23 | mask_value: bool,
24 | tokenizer: PreTrainedTokenizerBase,
25 | chat_template: Optional[str],
26 | add_generation_prompt: bool,
27 | ) -> None:
28 | if len(messages) == 0:
29 | # `tokenizer.apply_chat_template` does not accept an empty list.
30 | raise ValueError("At least one message is required.")
31 |
32 | all_text: str = tokenizer.apply_chat_template(
33 | conversation=[dataclasses.asdict(message) for message in messages],
34 | chat_template=chat_template,
35 | tokenize=False,
36 | add_generation_prompt=add_generation_prompt,
37 | )
38 | assert all_text.startswith(self.text)
39 | new_text = all_text[len(self.text) :]
40 | new_token_ids: list[int] = tokenizer.encode(new_text, add_special_tokens=False)
41 |
42 | self.token_ids.extend(new_token_ids)
43 | self.mask.extend([mask_value] * len(new_token_ids))
44 | self.text = all_text
45 |
46 |
47 | def tokenize_messages(
48 | messages: list[Message],
49 | tokenizer: PreTrainedTokenizerBase,
50 | chat_template: Optional[str],
51 | ) -> MaskedTokens:
52 | masked_tokens = MaskedTokens()
53 |
54 | for i, message in enumerate(messages):
55 | if message.role != "assistant":
56 | continue
57 |
58 | masked_tokens.extend(messages[:i], False, tokenizer, chat_template, True)
59 | masked_tokens.extend(messages[: i + 1], True, tokenizer, chat_template, False)
60 |
61 | masked_tokens.extend(messages, False, tokenizer, chat_template, True)
62 | return masked_tokens
63 |
--------------------------------------------------------------------------------
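A sketch of what `tokenize_messages` produces: only tokens belonging to assistant messages get `mask=True`, which is exactly what the NLL computation in `vllm_model.py` below sums over. The tokenizer name is an example, and `chat_template=None` falls back to the tokenizer's built-in template:

```python
from transformers import AutoTokenizer

from fishfarm.models.base import Message
from fishfarm.models.tokenization_utils import tokenize_messages

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
masked = tokenize_messages(
    [
        Message(role="user", content="What is 2 + 2?"),
        Message(role="assistant", content="4"),
    ],
    tokenizer,
    chat_template=None,
)
# Count how many tokens would be scored by the NLL evaluation.
print(sum(masked.mask), "of", len(masked.mask), "tokens are scored")
```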
/evaluation/fishfarm/fishfarm/models/vllm_model.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import dataclasses
3 | from typing import Any, Iterable, Optional, Sequence
4 |
5 | from fishfarm.models.base import NLLRequest, NLLResult
6 | from transformers import PreTrainedTokenizerBase
7 |
8 | from ..imports import try_import
9 | from .base import GenerationRequest, GenerationResult, Message, Model
10 | from .tokenization_utils import tokenize_messages
11 |
12 | with try_import() as _imports:
13 | import vllm
14 |
15 | _imports.check()
16 |
17 |
18 | class VLLMModel(Model):
19 |
20 | def __init__(
21 | self,
22 | llm: vllm.LLM,
23 | sampling_params: vllm.SamplingParams,
24 | chat_template: Optional[str],
25 | ) -> None:
26 | self.llm = llm
27 | self.chat_template = chat_template
28 | self.sampling_params = sampling_params
29 |
30 | def get_tokenizer(self) -> PreTrainedTokenizerBase:
31 | tokenizer = self.llm.get_tokenizer()
32 |
33 | if not hasattr(tokenizer, "apply_chat_template"):
34 | if hasattr(tokenizer, "tokenizer"):
35 | tokenizer = tokenizer.tokenizer
36 | else:
37 | raise ValueError(
38 | "The tokenizer does not have the 'apply_chat_template' method. "
39 | "This is likely because of the versions of vLLM or transformers."
40 | )
41 |
42 | return tokenizer
43 |
44 | def _into_prompt(self, messages: Sequence[Message]) -> str:
45 | tokenizer = self.get_tokenizer()
46 | prefill_text = ""
47 | n_assistant_prefill = sum([m.role == "assistant_prefill" for m in messages])
48 | if n_assistant_prefill > 1:
49 | raise ValueError(
50 | f"There must be at most one assistant_prefill role, but got {n_assistant_prefill}",
51 | )
52 | if n_assistant_prefill:
53 | assert (
54 | messages[-1].role == "assistant_prefill"
55 | ), "assistant_prefill role must be the last message"
56 | prefill_text = messages[-1].content
57 | messages = messages[:-1]
58 | prompt: str = tokenizer.apply_chat_template(
59 | conversation=[dataclasses.asdict(message) for message in messages],
60 | chat_template=self.chat_template,
61 | tokenize=False,
62 | add_generation_prompt=True,
63 | )
64 | prompt += prefill_text
65 | return prompt
66 |
67 | def _predict_log_probs(self, token_ids_list: list[list[int]]) -> list[list[float]]:
68 | sampling_params = copy.copy(self.sampling_params)
69 | sampling_params.prompt_logprobs = 1
70 | sampling_params.max_tokens = 1
71 |
72 | completions = self.llm.generate(
73 | prompt_token_ids=token_ids_list,
74 | sampling_params=sampling_params,
75 | )
76 |
77 | log_probs_list = []
78 | for token_ids, completion in zip(token_ids_list, completions):
79 | log_probs = []
80 | assert completion.prompt_logprobs is not None
81 | assert token_ids == completion.prompt_token_ids
82 | assert len(token_ids) == len(completion.prompt_logprobs)
83 | for token_id, logprob_dict in zip(token_ids, completion.prompt_logprobs):
84 | if logprob_dict is None:
85 | log_probs.append(0.0)
86 | else:
87 | logprob_entry: Any = logprob_dict[token_id]
88 |
89 | if isinstance(logprob_entry, float):
90 | log_probs.append(logprob_entry)
91 | else:
92 | log_probs.append(logprob_entry.logprob)
93 |
94 | log_probs_list.append(log_probs)
95 |
96 | return log_probs_list
97 |
98 | def generate(
99 | self, requests: Sequence[GenerationRequest]
100 | ) -> Iterable[GenerationResult]:
101 |
102 | prompts = [self._into_prompt(request.messages) for request in requests]
103 | completions = self.llm.generate(
104 | prompts=prompts,
105 | sampling_params=self.sampling_params,
106 | )
107 |
108 | for request, completion in zip(requests, completions):
109 | yield GenerationResult(
110 | request=request, generation=completion.outputs[0].text
111 | )
112 |
113 | def nll(self, requests: Sequence[NLLRequest]) -> Iterable[NLLResult]:
114 | masked_tokens_list = [
115 | tokenize_messages(
116 | request.messages, self.get_tokenizer(), self.chat_template
117 | )
118 | for request in requests
119 | ]
120 | log_probs_list = self._predict_log_probs(
121 | [masked_tokens.token_ids for masked_tokens in masked_tokens_list]
122 | )
123 |
124 | results = []
125 | for log_probs, masked_tokens, request in zip(
126 | log_probs_list, masked_tokens_list, requests
127 | ):
128 | assert len(log_probs) == len(masked_tokens.mask)
129 |
130 | sum_nll = 0.0
131 | num_considered_tokens = 0
132 | for log_prob, mask_value in zip(log_probs, masked_tokens.mask):
133 | if mask_value:
134 | sum_nll += -log_prob
135 | num_considered_tokens += 1
136 |
137 | results.append(
138 | NLLResult(
139 | request=request,
140 | sum_nll=sum_nll,
141 | num_considered_tokens=num_considered_tokens,
142 | )
143 | )
144 |
145 | return results
146 |
--------------------------------------------------------------------------------
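A minimal sketch of running generation through this wrapper (the model name and sampling settings are illustrative, and running it requires vLLM with a suitable GPU):

```python
import vllm

from fishfarm.chat_templates import LLAMA3
from fishfarm.models.base import GenerationRequest, Message
from fishfarm.models.vllm_model import VLLMModel

llm = vllm.LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")
model = VLLMModel(
    llm=llm,
    sampling_params=vllm.SamplingParams(temperature=0.0, max_tokens=256),
    chat_template=LLAMA3,
)

requests = [GenerationRequest(messages=[Message(role="user", content="What is 2 + 2?")])]
for result in model.generate(requests):
    print(result.generation)
```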
/evaluation/fishfarm/fishfarm/tasks/__init__.py:
--------------------------------------------------------------------------------
1 | from . import base
2 | from .base import Task, TaskResult
3 |
4 | __all__ = [
5 | "base",
6 | "TaskResult",
7 | "Task",
8 | ]
9 |
--------------------------------------------------------------------------------
/evaluation/fishfarm/fishfarm/tasks/ai2_arc.py:
--------------------------------------------------------------------------------
1 | import random
2 | import re
3 | from dataclasses import dataclass
4 | from typing import Iterable, Optional, Sequence
5 |
6 | from ..models import GenerationRequest, Message, Model
7 | from .base import Task, TaskResult
8 |
9 |
10 | def extract_answer(text: str) -> Optional[str]:
11 | pattern = r"answer is \(?([A-J])\)?"
12 | match = re.search(pattern, text)
13 | if match:
14 | return match.group(1)
15 | else:
16 | return extract_again(text)
17 |
18 |
19 | def extract_again(text: str) -> Optional[str]:
20 | match = re.search(r".*[aA]nswer:\s*([A-J])", text)
21 | if match:
22 | return match.group(1)
23 | else:
24 | return extract_final(text)
25 |
26 |
27 | def extract_final(text: str) -> Optional[str]:
28 | pattern = r"\b[A-J]\b(?!.*\b[A-J]\b)"
29 | match = re.search(pattern, text, re.DOTALL)
30 | if match:
31 | return match.group(0)
32 | else:
33 | return None
34 |
35 |
36 | def is_correct(pred: Optional[str], answer: str, options: list[str]) -> bool:
37 | if not pred:
38 | random.seed(42)
39 | x = random.randint(0, len(options) - 1)
40 | if ["A", "B", "C", "D", "E"][x] == answer:
41 | return True
42 | else:
43 | return False
44 | elif pred == answer:
45 | return True
46 | else:
47 | return False
48 |
49 |
50 | @dataclass
51 | class Ai2ArcSample:
52 |
53 | question: str
54 | question_id: str
55 | options: list[str]
56 | answer: str
57 |
58 |
59 | def mean(iterable: Iterable[float]) -> float:
60 | total, count = 0.0, 0
61 | for x in iterable:
62 | total += x
63 | count += 1
64 | return total / count
65 |
66 |
67 | class Ai2ArcTask(Task):
68 | def __init__(
69 | self,
70 | samples: Sequence[Ai2ArcSample],
71 | context_messages: Sequence[Message] = (),
72 | ):
73 | self.samples = list(samples)
74 | self.context_messages = context_messages
75 |
76 | @property
77 | def num_samples(self) -> int:
78 | return len(self.samples)
79 |
80 | def evaluate(
81 | self,
82 | model: Model,
83 | sample_ids: Optional[Sequence[int]] = None,
84 | ) -> TaskResult:
85 | if sample_ids is None:
86 | sample_ids = range(len(self.samples))
87 | samples = [self.samples[sample_id] for sample_id in sample_ids]
88 |
89 | requests = []
90 | for sample in samples:
91 | messages = list(self.context_messages)
92 | messages.append(Message(role="user", content=sample.question))
93 | requests.append(GenerationRequest(messages=messages))
94 |
95 | sample_details = []
96 | for sample, result in zip(samples, model.generate(requests)):
97 | output = result.generation
98 | prediction = extract_answer(result.generation)
99 |
100 | sample_details.append(
101 | dict(
102 | problem=sample.question,
103 | output=output,
104 | answer=sample.answer,
105 | prediction=prediction,
106 | correct=is_correct(prediction, sample.answer, sample.options),
107 | )
108 | )
109 |
110 | aggregate_metrics = {
111 | "acc": mean(
112 |                 float(sd["correct"]) if isinstance(sd["correct"], bool) else 0.0
113 | for sd in sample_details
114 | )
115 | }
116 | return TaskResult(
117 | aggregate_metrics=aggregate_metrics, sample_details=sample_details
118 | )
119 |
--------------------------------------------------------------------------------
/evaluation/fishfarm/fishfarm/tasks/base.py:
--------------------------------------------------------------------------------
1 | import abc
2 | from dataclasses import dataclass
3 | from typing import Any, Optional, Sequence
4 |
5 | from ..models import Model
6 |
7 |
8 | @dataclass
9 | class TaskResult:
10 |
11 | aggregate_metrics: dict[str, float]
12 | sample_details: list[dict[str, Any]]
13 |
14 |
15 | class Task(abc.ABC):
16 |
17 | @property
18 | @abc.abstractmethod
19 | def num_samples(self) -> int:
20 | raise NotImplementedError()
21 |
22 | @abc.abstractmethod
23 | def evaluate(
24 | self,
25 | model: Model,
26 | sample_ids: Optional[Sequence[int]] = None,
27 | ) -> TaskResult:
28 | raise NotImplementedError()
29 |
--------------------------------------------------------------------------------
/evaluation/fishfarm/fishfarm/tasks/competation_math.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from math import isclose
3 | from typing import Any, Iterable, Optional, Sequence, Union
4 |
5 | from sympy import N, simplify
6 | from sympy.parsing.latex import parse_latex
7 | from sympy.parsing.sympy_parser import parse_expr
8 |
9 | from ..models import GenerationRequest, Message, Model
10 | from .base import Task, TaskResult
11 |
12 |
13 | def _fix_fracs(string: str) -> str:
14 | substrs = string.split("\\frac")
15 | new_str = substrs[0]
16 | if len(substrs) > 1:
17 | substrs = substrs[1:]
18 | for substr in substrs:
19 | new_str += "\\frac"
20 | if substr[0] == "{":
21 | new_str += substr
22 | else:
23 | try:
24 | assert len(substr) >= 2
25 | except AssertionError:
26 | return string
27 | a = substr[0]
28 | b = substr[1]
29 | if b != "{":
30 | if len(substr) > 2:
31 | post_substr = substr[2:]
32 | new_str += "{" + a + "}{" + b + "}" + post_substr
33 | else:
34 | new_str += "{" + a + "}{" + b + "}"
35 | else:
36 | if len(substr) > 2:
37 | post_substr = substr[2:]
38 | new_str += "{" + a + "}" + b + post_substr
39 | else:
40 | new_str += "{" + a + "}" + b
41 | string = new_str
42 | return string
43 |
44 |
45 | def _fix_a_slash_b(string: str) -> str:
46 | if len(string.split("/")) != 2:
47 | return string
48 | a: str = string.split("/")[0]
49 | b: str = string.split("/")[1]
50 | try:
51 | a_int: int = int(a)
52 | b_int: int = int(b)
53 | assert string == "{}/{}".format(a_int, b_int)
54 | new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
55 | return new_string
56 | except (AssertionError, ValueError):
57 | return string
58 |
59 |
60 | def _remove_right_units(string: str) -> str:
61 | if "\\text{ " in string:
62 | splits = string.split("\\text{ ")
63 | assert len(splits) == 2
64 | return splits[0]
65 | else:
66 | return string
67 |
68 |
69 | def _fix_sqrt(string: str) -> str:
70 | if "\\sqrt" not in string:
71 | return string
72 | splits = string.split("\\sqrt")
73 | new_string = splits[0]
74 | for split in splits[1:]:
75 | if split[0] != "{":
76 | a = split[0]
77 | new_substr = "\\sqrt{" + a + "}" + split[1:]
78 | else:
79 | new_substr = "\\sqrt" + split
80 | new_string += new_substr
81 | return new_string
82 |
83 |
84 | def _strip_string(string: str) -> str:
85 | string = string.replace("\n", "")
86 |
87 | string = string.replace("\\!", "")
88 |
89 | string = string.replace("\\\\", "\\")
90 |
91 | string = string.replace("tfrac", "frac")
92 | string = string.replace("dfrac", "frac")
93 |
94 | string = string.replace("\\left", "")
95 | string = string.replace("\\right", "")
96 |
97 | string = string.replace("^{\\circ}", "")
98 | string = string.replace("^\\circ", "")
99 |
100 | string = string.replace("\\$", "")
101 |
102 | string = _remove_right_units(string)
103 |
104 | string = string.replace(r"\\%", "")
105 | string = string.replace(r"\%", "")
106 |
107 | string = string.replace(" .", " 0.")
108 | string = string.replace("{.", "{0.")
109 | if len(string) == 0:
110 | return string
111 | if string[0] == ".":
112 | string = "0" + string
113 |
114 | if len(string.split("=")) == 2:
115 | if len(string.split("=")[0]) <= 2:
116 | string = string.split("=")[1]
117 |
118 | string = _fix_sqrt(string)
119 |
120 | string = string.replace(" ", "")
121 |
122 | string = _fix_fracs(string)
123 |
124 | if string == "0.5":
125 | string = "\\frac{1}{2}"
126 |
127 | string = _fix_a_slash_b(string)
128 |
129 | return string
130 |
131 |
132 | def is_digit(s: Union[bool, float, str]) -> bool:
133 | try:
134 | float(str(s).replace(",", ""))
135 | return True
136 | except ValueError:
137 | return False
138 |
139 |
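# Symbolic fallback used by math_equal: parse both strings with sympy (LaTeX first, then a plain
# expression) and treat them as equal if the difference simplifies to 0 or the numeric values
# agree within a relative tolerance of 1e-3.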
140 | def symbolic_equal(a: str, b: str) -> bool:
141 | def _parse(s: str) -> Any:
142 | for f in [parse_latex, parse_expr]:
143 | try:
144 | return f(s)
145 | except Exception:
146 | pass
147 | return s
148 |
149 | a = _parse(a)
150 | b = _parse(b)
151 |
152 | try:
153 | if simplify(a - b) == 0:
154 | return True
155 | except Exception:
156 | pass
157 |
158 | try:
159 | if isclose(N(a), N(b), rel_tol=1e-3):
160 | return True
161 | except Exception:
162 | pass
163 | return False
164 |
165 |
166 | def math_equal(
167 | prediction: Union[bool, float, str],
168 | reference: Union[float, str],
169 | include_percentage: bool = True,
170 | is_close: bool = True,
171 | ) -> bool:
172 | """
173 | Exact match of math if and only if:
174 | 1. numerical equal: both can convert to float and are equal
175 | 2. symbolic equal: both can convert to sympy expression and are equal
176 | """
177 | try:
178 | if is_digit(prediction) and is_digit(reference):
179 | prediction = float(str(prediction).replace(",", ""))
180 | reference = float(str(reference).replace(",", ""))
181 | if include_percentage:
182 | gt_result = [reference / 100, reference, reference * 100]
183 | else:
184 | gt_result = [reference]
185 | for item in gt_result:
186 | try:
187 | if is_close:
188 | if isclose(item, prediction, rel_tol=1e-4):
189 | return True
190 | else:
191 | if item == prediction:
192 | return True
193 | except Exception:
194 | continue
195 | return False
196 | except Exception:
197 | pass
198 |
199 | if not prediction and prediction not in [0, False]:
200 | return False
201 |
202 | reference = str(reference).strip()
203 | prediction = str(prediction).strip()
204 |
205 | pred_str, ref_str = prediction, reference
206 | if (
207 | prediction.startswith("[")
208 | and prediction.endswith("]")
209 | and not reference.startswith("(")
210 | ) or (
211 | prediction.startswith("(")
212 | and prediction.endswith(")")
213 | and not reference.startswith("[")
214 | ):
215 | pred_str = pred_str.strip("[]()")
216 | ref_str = ref_str.strip("[]()")
217 | for s in ["{", "}", "(", ")"]:
218 | ref_str = ref_str.replace(s, "")
219 | pred_str = pred_str.replace(s, "")
220 | if pred_str == ref_str:
221 | return True
222 |
223 | if (
224 | (prediction.startswith("[") and prediction.endswith("]"))
225 | and (reference.startswith("[") and reference.endswith("]"))
226 | or (prediction.startswith("(") and prediction.endswith(")"))
227 | and (reference.startswith("(") and reference.endswith(")"))
228 | ):
229 | pred_parts = prediction[1:-1].split(",")
230 | ref_parts = reference[1:-1].split(",")
231 | if len(pred_parts) == len(ref_parts):
232 | if all(
233 | [
234 | math_equal(
235 | pred_parts[i], ref_parts[i], include_percentage, is_close
236 | )
237 | for i in range(len(pred_parts))
238 | ]
239 | ):
240 | return True
241 |
242 | if symbolic_equal(prediction, reference):
243 | return True
244 |
245 | return False
246 |
247 |
248 | def is_equiv(str1: Optional[str], str2: Optional[str]) -> bool:
249 | if str1 is None and str2 is None:
250 | return True
251 | if str1 is None or str2 is None:
252 | return False
253 |
254 | try:
255 | ss1 = _strip_string(str1)
256 | ss2 = _strip_string(str2)
257 | return math_equal(ss1, ss2) or ss1 == ss2
258 | except (AssertionError, TypeError, ValueError):
259 | return math_equal(str1, str2) or str1 == str2
260 |
261 |
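# Returns the last "\boxed{...}" (or "\fbox") span in the string, located by brace matching.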
262 | def last_boxed_only_string(string: str) -> Optional[str]:
263 | idx = string.rfind("\\boxed")
264 | if idx < 0:
265 | idx = string.rfind("\\fbox")
266 | if idx < 0:
267 | return None
268 |
269 | i = idx
270 | right_brace_idx: Optional[int] = None
271 |
272 | num_left_braces_open = 0
273 | while i < len(string):
274 | if string[i] == "{":
275 | num_left_braces_open += 1
276 | if string[i] == "}":
277 | num_left_braces_open -= 1
278 | if num_left_braces_open == 0:
279 | right_brace_idx = i
280 | break
281 | i += 1
282 |
283 | if right_brace_idx is None:
284 | retval = None
285 | else:
286 | assert right_brace_idx is not None
287 | retval = string[idx : right_brace_idx + 1]
288 |
289 | return retval
290 |
291 |
292 | def remove_boxed(s: Optional[str]) -> Optional[str]:
293 | left = "\\boxed{"
294 | if s is None:
295 | return None
296 | else:
297 | try:
298 | assert s[: len(left)] == left
299 | assert s[-1] == "}"
300 | return s[len(left) : -1]
301 | except (AssertionError, TypeError, ValueError):
302 | return None
303 |
304 |
305 | @dataclass
306 | class MathSample:
307 |
308 | problem: str
309 | answer: Optional[str] = None
310 | type: Optional[str] = None
311 |
312 |
313 | def mean(iterable: Iterable[float]) -> float:
314 | total, count = 0.0, 0
315 | for x in iterable:
316 | total += x
317 | count += 1
318 | return total / count
319 |
320 |
321 | def extract_ans(completion: str) -> Optional[str]:
322 |
323 | split_ans = completion.split("The answer is: ")
324 | if len(split_ans) > 1:
325 | ans = split_ans[-1]
326 | extract_ans_temp = ans.split(".\n")[0]
327 | extract_ans_temp = extract_ans_temp.strip()
328 | if len(extract_ans_temp) > 0 and extract_ans_temp[-1] == ".":
329 | extract_ans = extract_ans_temp[0:-1]
330 | else:
331 | extract_ans = extract_ans_temp
332 | extract_ans = extract_ans.strip()
333 | return extract_ans
334 | else:
335 | return remove_boxed(last_boxed_only_string(completion))
336 |
337 |
338 | class LatexFormatMathTask(Task):
339 | def __init__(
340 | self,
341 | samples: Sequence[MathSample],
342 | context_messages: Sequence[Message] = (),
343 | ):
344 | self.samples = list(samples)
345 | self.context_messages = context_messages
346 |
347 | @property
348 | def num_samples(self) -> int:
349 | return len(self.samples)
350 |
351 | def evaluate(
352 | self,
353 | model: Model,
354 | sample_ids: Optional[Sequence[int]] = None,
355 | ) -> TaskResult:
356 | if sample_ids is None:
357 | sample_ids = range(len(self.samples))
358 | samples = [self.samples[sample_id] for sample_id in sample_ids]
359 |
360 | requests = []
361 | for sample in samples:
362 | messages = list(self.context_messages)
363 | messages.append(Message(role="user", content=sample.problem))
364 | requests.append(GenerationRequest(messages=messages))
365 |
366 | sample_details = []
367 | for sample, result in zip(samples, model.generate(requests)):
368 | output = result.generation
369 | prediction = extract_ans(output)
370 |
371 | sample_details.append(
372 | dict(
373 | problem=sample.problem,
374 | output=output,
375 | answer=sample.answer,
376 | type=sample.type,
377 | prediction=prediction,
378 | correct=is_equiv(sample.answer, prediction),
379 | )
380 | )
381 |
382 | aggregate_metrics = {
383 | "acc": mean(
384 |                 float(sd["correct"]) if isinstance(sd["correct"], bool) else 0.0

385 | for sd in sample_details
386 | )
387 | }
388 |
389 | return TaskResult(
390 | aggregate_metrics=aggregate_metrics, sample_details=sample_details
391 | )
392 |
--------------------------------------------------------------------------------
/evaluation/fishfarm/fishfarm/tasks/evalplus/__init__.py:
--------------------------------------------------------------------------------
1 | from .data import load_dataset
2 | from .task import EvalplusTask
3 |
4 | __all__ = ["EvalplusTask", "load_dataset"]
5 |
--------------------------------------------------------------------------------
/evaluation/fishfarm/fishfarm/tasks/evalplus/data.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 | from evalplus.data import get_human_eval_plus, get_mbpp_plus
4 |
5 |
6 | @dataclass
7 | class TextToCodeProblem:
8 | id: str
9 | instruction: str
10 | response_prefix: str
11 |
12 |
13 | def get_mbpp_raw_problems() -> list[dict]:
14 | problems = get_mbpp_plus()
15 | return list(problems.values())
16 |
17 |
18 | def get_humaneval_raw_problems() -> list[dict]:
19 | problems = get_human_eval_plus()
20 | return list(problems.values())
21 |
22 |
23 | def read_mbpp_plus(
24 | plus_path: str, err_incomplete: bool = True, mini: bool = False
25 | ) -> dict[str, dict]:
26 | from evalplus.data.mbpp import (completeness_check,
27 | mbpp_deserialize_inputs, stream_jsonl)
28 |
29 | plus = {task["task_id"]: task for task in stream_jsonl(plus_path)}
30 | for task_id, task in plus.items():
31 | task["base_input"] = mbpp_deserialize_inputs(task_id, task["base_input"])
32 | task["plus_input"] = mbpp_deserialize_inputs(task_id, task["plus_input"])
33 |
34 | if err_incomplete:
35 | completeness_check("MBPP+", plus)
36 | return plus
37 |
38 |
39 | def map_mbpp_problem(p: dict) -> TextToCodeProblem:
40 | id = p["task_id"]
41 | prompt = p["prompt"]
42 | start_index = prompt.index('"""')
43 | end_index = prompt.rindex('"""')
44 | prompt = prompt[start_index + 3 : end_index]
45 | assert_index = prompt.index("assert")
46 | instruction = prompt[:assert_index].strip()
47 | if not instruction.endswith("."):
48 | instruction += "."
49 | assertion = prompt[assert_index:].strip()
50 | instruction = f"""{instruction} Your code should satisfy the following assertion:
51 | ```python
52 | {assertion}
53 | ```"""
54 | response_prefix = """```python"""
55 | return TextToCodeProblem(
56 | id=str(id), instruction=instruction, response_prefix=response_prefix
57 | )
58 |
59 |
60 | def map_humaneval_problem(p: dict) -> TextToCodeProblem:
61 | id = p["task_id"]
62 | prompt = p["prompt"]
63 | prompt = prompt.strip()
64 | instruction = f"""Write a solution to the following problem:
65 | ```python
66 | {prompt}
67 | ```"""
68 | response_prefix = f"""```python
69 | {prompt}"""
70 | return TextToCodeProblem(
71 | id=id, instruction=instruction, response_prefix=response_prefix
72 | )
73 |
74 |
75 | def load_dataset(source_dataset: str) -> list[TextToCodeProblem]:
76 | if source_dataset not in ("humaneval", "mbpp"):
77 | raise ValueError(f"Unknown source_dataset: {source_dataset}")
78 |
79 | raw_problem_fn = {
80 | "humaneval": get_humaneval_raw_problems,
81 | "mbpp": get_mbpp_raw_problems,
82 | }[source_dataset]
83 |
84 | if source_dataset.startswith("humaneval"):
85 | map_problem_fn = map_humaneval_problem
86 | elif source_dataset.startswith("mbpp"):
87 | map_problem_fn = map_mbpp_problem
88 | else:
89 | raise ValueError(f"Unknown source_dataset: {source_dataset}")
90 |
91 | raw_problems = raw_problem_fn()
92 | problems = list(map(map_problem_fn, raw_problems))
93 |
94 | return problems
95 |
--------------------------------------------------------------------------------
/evaluation/fishfarm/fishfarm/tasks/evalplus/evaluation.py:
--------------------------------------------------------------------------------
1 | import json
2 | import multiprocessing
3 | import os
4 | import threading
5 | import time
6 | from collections import Counter, defaultdict
7 | from concurrent.futures import ProcessPoolExecutor, as_completed
8 | from datetime import datetime
9 | from typing import Any
10 | from warnings import warn
11 |
12 | import numpy as np
13 | from evalplus.data import (get_human_eval_plus, get_human_eval_plus_hash,
14 | get_mbpp_plus, get_mbpp_plus_hash, load_solutions)
15 | from evalplus.eval import SUCCESS, estimate_pass_at_k, untrusted_check
16 | from evalplus.eval._special_oracle import MBPP_OUTPUT_NOT_NONE_TASKS
17 | from evalplus.evaluate import Result, get_groundtruth
18 | from termcolor import cprint
19 | from tqdm.auto import tqdm
20 |
21 | from ...logging import get_logger
22 |
23 | logger = get_logger(__name__)
24 |
25 |
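# Runs a single generated solution against the base test inputs (and, unless base_only is set,
# the extended "plus" inputs) via evalplus' untrusted_check, returning one result per test suite.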
26 | def check_correctness(
27 | dataset: str,
28 | completion_id: int,
29 | problem: dict[str, Any],
30 | solution: str,
31 | expected_output: dict[str, list],
32 | base_only: bool = False,
33 | fast_check: bool = False,
34 | identifier: str = "HumanEval/0_0",
35 | min_time_limit: float = 0.1,
36 | gt_time_limit_factor: float = 2.0,
37 | ) -> dict[str, Result]:
38 | ret = {
39 | "completion_id": completion_id,
40 | "task_id": problem["task_id"],
41 | "_identifier": identifier,
42 | "solution": solution,
43 | }
44 | ret["base"] = untrusted_check(
45 | dataset,
46 | solution,
47 | problem["base_input"],
48 | problem["entry_point"],
49 | expected=expected_output["base"],
50 | atol=problem["atol"],
51 | ref_time=expected_output["base_time"],
52 | fast_check=fast_check,
53 | min_time_limit=min_time_limit,
54 | gt_time_limit_factor=gt_time_limit_factor,
55 | )
56 |
57 | if not base_only:
58 | ret["plus"] = untrusted_check(
59 | dataset,
60 | solution,
61 | problem["plus_input"],
62 | problem["entry_point"],
63 | expected=expected_output["plus"],
64 | atol=problem["atol"],
65 | ref_time=expected_output["plus_time"],
66 | fast_check=fast_check,
67 | min_time_limit=min_time_limit,
68 | gt_time_limit_factor=gt_time_limit_factor,
69 | )
70 | return ret
71 |
72 |
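# Scores the generated solutions with evalplus: each sample is checked in a process pool, then
# per-task results are aggregated into pass@k for the base tests and, unless base_only is set,
# the extended "plus" tests.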
73 | def evaluate(
74 | source_dataset: str,
75 | output_path: str,
76 | base_only: bool = False,
77 | parallel: int = 0,
78 | i_just_wanna_run: bool = False,
79 | test_details: bool = False,
80 | min_time_limit: float = 0.2,
81 | gt_time_limit_factor: float = 4.0,
82 | mini: bool = False,
83 | ) -> tuple[Any, list[dict[str, Any]]]:
84 | if parallel == 0:
85 | n_workers = max(1, multiprocessing.cpu_count() // 2)
86 | else:
87 | n_workers = parallel
88 |
89 | if os.path.isdir(output_path):
90 | result_path = os.path.join(output_path, "eval_results.json")
91 | else:
92 | assert output_path.endswith(".jsonl")
93 | result_path = output_path.replace(".jsonl", "_eval_results.json")
94 |
95 | if source_dataset == "humaneval":
96 | problems = get_human_eval_plus(mini=mini)
97 | dataset_hash = get_human_eval_plus_hash()
98 | expected_output = get_groundtruth(problems, dataset_hash, [])
99 | elif source_dataset == "mbpp":
100 | problems = get_mbpp_plus(mini=mini)
101 | dataset_hash = get_mbpp_plus_hash()
102 | expected_output = get_groundtruth(
103 | problems,
104 | dataset_hash,
105 | MBPP_OUTPUT_NOT_NONE_TASKS,
106 | )
107 |
108 | results = {
109 | "date": datetime.now().strftime("%Y-%m-%d %H:%M"),
110 | "hash": dataset_hash,
111 | "eval": {},
112 | }
113 |
114 | with ProcessPoolExecutor(max_workers=n_workers) as executor:
115 | futures = []
116 | completion_id: Counter[str] = Counter()
117 | n_samples = 0
118 | eval_results = defaultdict(list)
119 | remainings = set()
120 | sample_details = []
121 |
122 | logger.info("Reading samples...")
123 | for sample in tqdm(load_solutions(output_path)):
124 | task_id = sample["task_id"]
125 | explanation = sample.get("explanation", "")
126 | solution = (
127 | sample["solution"]
128 | if "solution" in sample
129 | else problems[task_id]["prompt"] + sample["completion"]
130 | )
131 | remainings.add(sample["_identifier"])
132 |
133 | args = (
134 | source_dataset,
135 | completion_id[task_id],
136 | problems[task_id],
137 | solution,
138 | expected_output[task_id],
139 | base_only,
140 | not test_details,
141 | sample["_identifier"],
142 | min_time_limit,
143 | gt_time_limit_factor,
144 | )
145 |
146 | futures.append(executor.submit(check_correctness, *args))
147 | completion_id[task_id] += 1
148 | n_samples += 1
149 |
150 | sample_details.append(
151 | dict(
152 | task_id=task_id,
153 | solution=solution,
154 | explanation=explanation,
155 | problems=problems[task_id],
156 | expected_output=expected_output[task_id],
157 | )
158 | )
159 |
160 | assert n_samples == len(remainings), "Missing problems in unfinished"
161 | if len(completion_id) != len(problems):
162 | logger.warning("Warning: Missing problems in samples")
163 |
164 | def stucking_checker() -> None:
165 | while remainings:
166 | last_size = len(remainings)
167 | time.sleep(20)
168 | if last_size != len(remainings) or len(remainings) == 0:
169 | continue
170 | warn("No samples had finished testing in the last 20s")
171 | warn(f"{len(remainings)} samples to be tested: {remainings}")
172 |
173 | threading.Thread(target=stucking_checker).start()
174 |
175 | for future in tqdm(as_completed(futures), total=n_samples):
176 | result = future.result()
177 | remainings.remove(result["_identifier"])
178 | eval_results[result["task_id"]].append(result)
179 |
180 | for task_id, task_results in eval_results.items():
181 | task_results.sort(key=lambda x: x["completion_id"])
182 | results["eval"][task_id] = {
183 | "nfiles": len(task_results),
184 | "base": [x["base"] for x in task_results],
185 | "plus": ([x["plus"] for x in task_results] if not base_only else []),
186 | }
187 |
188 | if os.path.isfile(result_path) and i_just_wanna_run:
189 | decision = ""
190 | while decision.lower() not in ["y", "n"]:
191 | logger.info(
192 | f"{result_path} already exists. Press [Y/N] to overwrite or exit..."
193 | )
194 | decision = input()
195 |
196 | if decision.lower() == "y":
197 | new_path = result_path + ".bak"
198 | while os.path.isfile(new_path):
199 | new_path += ".bak"
200 | os.rename(result_path, new_path)
201 | logger.info(f"Backup {result_path} to {new_path}")
202 |
203 | if not os.path.isfile(result_path):
204 | with open(result_path, "w") as f:
205 | json.dump(results, f)
206 |
207 | total = np.array([r["nfiles"] for r in results["eval"].values()])
208 | base_correct = []
209 | new_correct = []
210 |
211 | for key, res in results["eval"].items():
212 | elements = [element for element in sample_details if element["task_id"] == key]
213 | assert (
214 | len(elements) == 1
215 | ), f"Expected an element with task_id {key}, found {len(elements)}"
216 | element = elements[0]
217 |
218 | bc = sum([r[0] == SUCCESS for r in res["base"]])
219 | base_correct.append(bc)
220 | element["base_correct"] = bc
221 | if res["plus"]:
222 | new_bc = sum(
223 | [
224 | res["plus"][i][0] == res["base"][i][0] == SUCCESS
225 | for i in range(len(res["plus"]))
226 | ]
227 | )
228 | new_correct.append(new_bc)
229 | element["plus_correct"] = new_bc
230 |
231 | base_correct_array = np.array(base_correct)
232 |
233 | pass_at_k = {
234 | f"pass@{k}": estimate_pass_at_k(total, base_correct_array, k).mean()
235 | for k in [1, 10, 100]
236 | if total.min() >= k
237 | }
238 |
239 | result = {f"{source_dataset}_base_{key}": value for key, value in pass_at_k.items()}
240 | cprint(f"{source_dataset} (base tests)", "red")
241 | for k, v in pass_at_k.items():
242 | cprint(f"{k}:\t{v:.3f}", "red")
243 |
244 | if new_correct:
245 | cprint(f"{source_dataset}+ (base + extra tests)", "green")
246 | pass_at_k = {
247 | f"pass@{k}": estimate_pass_at_k(total, np.array(new_correct), k).mean()
248 | for k in [1, 10, 100]
249 | if (total >= k).all()
250 | }
251 | result.update(
252 | {f"{source_dataset}_plus_{key}": value for key, value in pass_at_k.items()}
253 | )
254 | for k, v in pass_at_k.items():
255 | cprint(f"{k}:\t{v:.3f}", "green")
256 |
257 | return result, sample_details
258 |
--------------------------------------------------------------------------------
/evaluation/fishfarm/fishfarm/tasks/evalplus/generation.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | from pathlib import Path
3 | from typing import Iterable, List, Sequence, TypeVar
4 |
5 | from evalplus.data import write_jsonl
6 | from tqdm.auto import tqdm
7 |
8 | from ...models import GenerationRequest, Message, Model
9 | from .data import TextToCodeProblem
10 |
11 | _T = TypeVar("_T")
12 |
13 |
14 | def chunked(seq: Sequence[_T], n: int) -> Iterable[Sequence[_T]]:
15 | """Yield successive n-sized chunks from seq."""
16 | return (seq[i : i + n] for i in range(0, len(seq), n))
17 |
18 |
19 | def generate(
20 | model: Model,
21 | problems: list[TextToCodeProblem],
22 | context_messages: Sequence[Message],
23 | output_path: str,
24 | n_batches: int = 1,
25 | n_problems_per_batch: int = 1_000_000_000,
26 | n_samples_per_problem: int = 1,
27 | ) -> List[str]:
28 | problems_chunked = list(chunked(list(problems), n_problems_per_batch))
29 | iter = itertools.product(problems_chunked, range(n_batches))
30 | n_total = len(problems_chunked) * n_batches
31 |
32 | Path(output_path).write_text("")
33 | for problems, batch_idx in tqdm(iter, total=n_total):
34 | task_ids = [problem.id for problem in problems]
35 | all_task_ids = task_ids * n_samples_per_problem
36 |
37 | requests = []
38 | for problem in problems:
39 | messages = list(context_messages)
40 | messages.append(Message(role="user", content=problem.instruction))
41 | messages.append(
42 | Message(role="assistant_prefill", content=problem.response_prefix)
43 | )
44 | requests.append(GenerationRequest(messages=messages))
45 | completes = model.generate(requests)
46 | completions = [c.generation for c in completes]
47 |
48 | assert len(problems) <= n_problems_per_batch
49 | assert len(completions) == len(problems) * n_samples_per_problem
50 |
51 | samples = []
52 | for task_id, completion in zip(all_task_ids, completions):
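            # Text before the first closing ``` fence is kept as the code completion; whatever
            # follows the fence is stored separately as the model's natural-language explanation.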
53 | completion_body = completion[
54 | : (
55 | index
56 | if (index := completion.find("```")) != -1
57 | else len(completion)
58 | )
59 | ]
60 | explanation = completion[
61 | (
62 |                     index + 3
63 |                     if (index := completion.find("```")) != -1
64 | else len(completion)
65 | ) :
66 | ].strip()
67 |
68 | samples.append(
69 | dict(
70 | task_id=task_id,
71 | completion=completion_body,
72 | explanation=explanation,
73 | )
74 | )
75 |
76 | write_jsonl(output_path, samples, append=True)
77 | return completions
78 |
--------------------------------------------------------------------------------
/evaluation/fishfarm/fishfarm/tasks/evalplus/sanitization.py:
--------------------------------------------------------------------------------
1 | import ast
2 | import os
3 | import pathlib
4 | import re
5 | import traceback
6 | from typing import Optional
7 |
8 | from evalplus.data import (get_human_eval_plus, get_mbpp_plus, load_solutions,
9 | write_directory, write_jsonl)
10 | from tqdm.auto import tqdm
11 |
12 | from ...logging import get_logger
13 |
14 | logger = get_logger(__name__)
15 |
16 |
17 | def syntax_check(code: str, verbose: bool = False) -> bool:
18 | try:
19 | ast.parse(code)
20 | return True
21 | except (SyntaxError, MemoryError):
22 | if verbose:
23 | traceback.print_exc()
24 | return False
25 |
26 |
27 | def remove_unindented_lines(
28 |     code: str, protect_before: str, exceptions: list[str], trim_tails: list[str]
29 | ) -> str:
30 | lines = code.splitlines()
31 | cut_idx = []
32 | cut_enabled = False
33 | for i, line in enumerate(lines):
34 | if not cut_enabled and line.startswith(protect_before):
35 | cut_enabled = True
36 | continue
37 | if line.strip() == "":
38 | continue
39 |         if any(line.startswith(e) for e in exceptions):
40 | continue
41 |
42 | lspace = len(line) - len(line.lstrip())
43 | if lspace == 0:
44 | cut_idx.append(i)
45 |
46 | if any(line.rstrip().startswith(t) for t in trim_tails):
47 | cut_idx.extend(list(range(i, len(lines))))
48 | break
49 |
50 | return "\n".join([line for i, line in enumerate(lines) if i not in cut_idx])
51 |
52 |
53 | def to_four_space_indents(old_code: str) -> str:
54 | new_code = ""
55 | for line in old_code.splitlines():
56 | lspace = len(line) - len(line.lstrip())
57 | if lspace == 3:
58 | new_code += " "
59 | new_code += line + "\n"
60 | return new_code
61 |
62 |
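# Best-effort cleanup of a generated solution: isolate the code chunk containing the target
# entry point, normalize indentation, drop unindented trailing text, and keep additional
# function definitions only if they still parse.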
63 | def sanitize_code(
64 | old_code: str,
65 | entry_point: str,
66 | rm_prefix_lines: Optional[str] = None,
67 | eofs: list = [],
68 | ) -> str:
69 | new_code = old_code
70 | if rm_prefix_lines is not None:
71 | new_code = "\n".join(
72 | [
73 | line
74 | for line in old_code.splitlines()
75 | if not line.startswith(rm_prefix_lines)
76 | ]
77 | )
78 |
79 | new_code = "\n" + new_code
80 | def_left = "def " + entry_point
81 |
82 | new_code = new_code.replace("\n```python\n", "\n```\n")
83 | for chunk in new_code.split("\n```\n"):
84 | if def_left in chunk:
85 | new_code = chunk
86 | break
87 |
88 | chunks = [chunk for chunk in re.split(rf"{def_left}\s*\(", new_code)]
89 | bodies = [chunk for chunk in chunks[1:] if " return " in chunk.split("\ndef")[0]]
90 | def_left = def_left + "("
91 | new_code = def_left + def_left.join(bodies) if len(bodies) > 0 else ""
92 | new_code = to_four_space_indents(new_code)
93 |
94 | for eof in eofs or []:
95 | new_code = new_code.split(eof)[0]
96 |
97 | new_code = remove_unindented_lines(
98 | new_code,
99 | protect_before=def_left,
100 |         exceptions=["def ", "import ", "from "],
101 | trim_tails=['"""', "if", "print"],
102 | )
103 | new_code = chunks[0] + new_code
104 |
105 | parts = new_code.split("\ndef ")
106 | includes = [parts[0]]
107 | for fn in new_code.split("\ndef ")[1:]:
108 | if (
109 | fn.strip().startswith(entry_point + " ")
110 | or fn.strip().startswith(entry_point + "(")
111 | or syntax_check("\ndef " + fn)
112 | ):
113 | includes.append(fn)
114 | new_code = "\ndef ".join(includes)
115 | return new_code.strip()
116 |
117 |
118 | def sanitize(
119 | source_dataset: str,
120 | input_path: str,
121 | eofs: list = [],
122 | inplace: bool = False,
123 | rm_prefix_lines: Optional[str] = None,
124 | debug_task: Optional[str] = None,
125 | ) -> str:
126 | entry_point = {}
127 |
128 | if source_dataset == "humaneval":
129 | dataset = get_human_eval_plus()
130 | elif source_dataset == "mbpp":
131 | dataset = get_mbpp_plus()
132 |
133 | for task_id, problem in dataset.items():
134 | entry_point[task_id] = problem["entry_point"]
135 |
136 | is_folder = os.path.isdir(input_path)
137 | target_path = pathlib.Path(input_path)
138 | if not inplace:
139 | if is_folder:
140 | new_name = target_path.name + "-sanitized"
141 | else:
142 | new_name = target_path.name.replace(".jsonl", "-sanitized.jsonl")
143 | target_path = target_path.parent / new_name
144 | output_path = str(target_path)
145 |
146 | nsan = 0
147 | ntotal = 0
148 |
149 | new_solutions = []
150 |
151 | for solution in tqdm(load_solutions(input_path)):
152 | task_id = solution["task_id"]
153 | dbg_identifier = solution["_identifier"]
154 | if debug_task is not None and task_id != debug_task:
155 | continue
156 |
157 | ntotal += 1
158 | if "solution" in solution:
159 | old_code = solution["solution"]
160 | else:
161 | assert "completion" in solution
162 | old_code = dataset[task_id]["prompt"] + "\n" + solution["completion"]
163 |
164 | old_code = old_code.strip()
165 |
166 | new_code = sanitize_code(
167 | old_code=old_code,
168 | entry_point=entry_point[task_id],
169 | rm_prefix_lines=rm_prefix_lines,
170 | eofs=eofs,
171 | ).strip()
172 |
173 | if new_code != old_code:
174 | msg = "Sanitized: " + dbg_identifier
175 | if is_folder:
176 | msg += " -> " + dbg_identifier.replace(input_path, output_path)
177 | logger.info(msg)
178 | nsan += 1
179 |
180 | new_solutions.append(
181 | {
182 | "task_id": task_id,
183 | "solution": new_code,
184 | "explanation": solution["explanation"],
185 | }
186 | )
187 |
188 | if is_folder:
189 | write_directory(output_path, new_solutions)
190 | else:
191 | write_jsonl(output_path, new_solutions)
192 |
193 | logger.info(f"Sanitized {nsan} out of {ntotal} files.")
194 |
195 | return output_path
196 |
--------------------------------------------------------------------------------
/evaluation/fishfarm/fishfarm/tasks/evalplus/task.py:
--------------------------------------------------------------------------------
1 | import tempfile
2 | from typing import Literal, Optional, Sequence
3 |
4 | from ...models import Message, Model
5 | from ..base import Task, TaskResult
6 | from . import evaluation, generation, sanitization
7 | from .data import TextToCodeProblem
8 |
9 |
10 | class EvalplusTask(Task):
11 |
12 | def __init__(
13 | self,
14 | samples: Sequence[TextToCodeProblem],
15 | context_messages: Sequence[Message] = (),
16 | source_dataset: Literal["humaneval", "mbpp"] = "humaneval",
17 | ):
18 | self.samples = list(samples)
19 | self.context_messages = context_messages
20 | self.source_dataset = source_dataset
21 | if source_dataset not in ("humaneval", "mbpp"):
22 | raise ValueError(f"Unknown source_dataset: {source_dataset}")
23 |
24 | @property
25 | def num_samples(self) -> int:
26 | return len(self.samples)
27 |
28 | def evaluate(
29 | self,
30 | model: Model,
31 | sample_ids: Optional[Sequence[int]] = None,
32 | ) -> TaskResult:
33 | if sample_ids is None:
34 | sample_ids = range(len(self.samples))
35 | samples = [self.samples[sample_id] for sample_id in sample_ids]
36 |
37 | with tempfile.TemporaryDirectory() as save_dir:
38 | output_path = f"{save_dir}/outputs.jsonl"
39 |
40 | completions = generation.generate(
41 | model, samples, self.context_messages, output_path
42 | )
43 |
44 | if self.source_dataset == "mbpp":
45 | output_path = sanitization.sanitize(self.source_dataset, output_path)
46 |
47 | result, sample_details = evaluation.evaluate(
48 | self.source_dataset, output_path
49 | )
50 |
51 | for i, completion in enumerate(completions):
52 | sample_details[i]["output"] = completion
53 |
54 | return TaskResult(aggregate_metrics=result, sample_details=sample_details)
55 |
--------------------------------------------------------------------------------
/evaluation/fishfarm/fishfarm/tasks/language_restricted_math.py:
--------------------------------------------------------------------------------
1 | import re
2 | from dataclasses import dataclass
3 | from typing import Iterable, Optional, Sequence
4 |
5 | import huggingface_hub
6 |
7 | from ..imports import try_import
8 | from ..models import GenerationRequest, Message, Model
9 | from .base import Task, TaskResult
10 |
11 | with try_import() as _imports:
12 | import fasttext
13 |
14 | _imports.check()
15 |
16 |
17 | @dataclass
18 | class MathSample:
19 |
20 | problem: str
21 | answer: int
22 |
23 |
24 | def mean(iterable: Iterable[float]) -> float:
25 | total, count = 0.0, 0
26 | for x in iterable:
27 | total += x
28 | count += 1
29 | return total / count
30 |
31 |
32 | def extract_answer_number(completion: str) -> Optional[float]:
33 | matches = re.findall(r"\d*\.?\d+", completion)
34 | if not matches:
35 | return None
36 | text = matches[-1]
37 | return float(text.replace(",", ""))
38 |
39 |
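# Math task that also runs a fastText language-ID model (lid.176.ftz) over each completion and
# reports accuracy restricted to responses detected as each target language.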
40 | class LanguageRestrictedMathTask(Task):
41 | def __init__(
42 | self,
43 | samples: Sequence[MathSample],
44 | context_messages: Sequence[Message] = (),
45 | languages: Sequence[str] = ("ja", "en"),
46 | ):
47 | self.samples = list(samples)
48 | self.languages = languages
49 | self.context_messages = context_messages
50 | if len(self.languages) != 0:
51 | lid176ftz_path = huggingface_hub.hf_hub_download(
52 | "julien-c/fasttext-language-id", "lid.176.ftz"
53 | )
54 | self.lid_model = fasttext.load_model(lid176ftz_path)
55 |
56 | @property
57 | def num_samples(self) -> int:
58 | return len(self.samples)
59 |
60 | def evaluate(
61 | self,
62 | model: Model,
63 | sample_ids: Optional[Sequence[int]] = None,
64 | ) -> TaskResult:
65 | if sample_ids is None:
66 | sample_ids = range(len(self.samples))
67 | samples = [self.samples[sample_id] for sample_id in sample_ids]
68 |
69 | requests = []
70 | for sample in samples:
71 | messages = list(self.context_messages)
72 | messages.append(Message(role="user", content=sample.problem))
73 | requests.append(GenerationRequest(messages=messages))
74 |
75 | sample_details = []
76 | for sample, result in zip(samples, model.generate(requests)):
77 | output = result.generation
78 | prediction = extract_answer_number(result.generation)
79 | if len(self.languages) != 0:
80 | lid_probs = dict(
81 | zip(*self.lid_model.predict(output.replace("\n", ""), k=-1))
82 | )
83 |
84 | sample_details.append(
85 | dict(
86 | problem=sample.problem,
87 | output=output,
88 | answer=sample.answer,
89 | prediction=prediction,
90 | correct=sample.answer == prediction,
91 | **{
92 | f"lang_{lang}": lid_probs.get(f"__label__{lang}", 0.0)
93 | for lang in self.languages
94 | },
95 | )
96 | )
97 |
98 | aggregate_metrics = {"acc": mean(sd["correct"] for sd in sample_details)}
99 | for lang in self.languages:
100 | aggregate_metrics[f"acc_{lang}"] = mean(
101 | (sd["correct"] and sd[f"lang_{lang}"] > 0.5) for sd in sample_details
102 | )
103 |
104 | return TaskResult(
105 | aggregate_metrics=aggregate_metrics, sample_details=sample_details
106 | )
107 |
--------------------------------------------------------------------------------
/evaluation/fishfarm/fishfarm/version.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.1.0dev"
2 |
--------------------------------------------------------------------------------
/evaluation/fishfarm/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "fishfarm"
3 | description = ""
4 | readme = "README.md"
5 | license = {file = "LICENSE"}
6 | authors = [
7 | {name = "Takuya Akiba"},
8 | {email = "takiba@sakana.ai"}
9 | ]
10 | classifiers = [
11 | "Development Status :: 2 - Pre-Alpha",
12 | "Intended Audience :: Science/Research",
13 | "Intended Audience :: Developers",
14 | "License :: OSI Approved :: MIT License",
15 | "Programming Language :: Python :: 3",
16 | "Programming Language :: Python :: 3.8",
17 | "Programming Language :: Python :: 3.9",
18 | "Programming Language :: Python :: 3.10",
19 | "Programming Language :: Python :: 3.11",
20 | "Programming Language :: Python :: 3 :: Only",
21 | "Topic :: Scientific/Engineering",
22 | "Topic :: Scientific/Engineering :: Mathematics",
23 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
24 | "Topic :: Software Development",
25 | "Topic :: Software Development :: Libraries",
26 | "Topic :: Software Development :: Libraries :: Python Modules",
27 | ]
28 | requires-python = ">=3.10"
29 | dependencies = [
30 | "huggingface_hub",
31 | "transformers",
32 | "pydantic",
33 | "colorlog"
34 | ]
35 | dynamic = ["version"]
36 |
37 | [project.optional-dependencies]
38 | development = [
39 | "black",
40 | "blackdoc",
41 | "flake8",
42 | "isort",
43 | "mypy",
44 | "pytest",
45 | "pytest-mock",
46 | "types-PyYAML",
47 | ]
48 |
49 | full = [
50 | "vllm",
51 | "langchain",
52 | "langchain-openai",
53 | "fasttext-wheel",
54 | "datasets",
55 | "mysql-connector-python==8.0.32",
56 | "docker==6.1.2",
57 | "evalplus @ git+https://github.com/evalplus/evalplus@1895d2f6aa8895044a7cf69defc24bd57695e885",
58 | "rouge-score"
59 | ]
60 |
61 | [project.urls]
62 | repository = "https://github.com/SakanaAI/fishfarm"
63 |
64 | [tool.setuptools.packages.find]
65 | include = ["fishfarm*"]
66 |
67 | [tool.setuptools.dynamic]
68 | version = {attr = "fishfarm.version.__version__"}
69 |
70 | [tool.black]
71 | line-length = 99
72 | target-version = ['py310']
73 | exclude = '''
74 | /(
75 | \.eggs
76 | | \.git
77 | | \.hg
78 | | \.mypy_cache
79 | | \.venv
80 | | venv
81 | | _build
82 | | buck-out
83 | | build
84 | | dist
85 | | docs
86 | | data
87 | )/
88 | '''
89 |
90 | [tool.isort]
91 | profile = 'black'
92 | src_paths = ['fishfarm', 'tests']
93 | line_length = 99
94 | lines_after_imports = 2
95 |
96 | [tool.mypy]
97 | python_version = "3.10"
98 | strict = true
99 | ignore_missing_imports = true
100 | warn_unused_configs = true
101 | disallow_untyped_defs = true
102 | warn_redundant_casts = true
103 | warn_unused_ignores = true
104 | warn_unreachable = true
105 | disallow_any_generics = false
106 | exclude = ".venv|venv|build|docs|tutorial|data"
107 |
108 | [tool.pytest]
109 | mock_use_standalone_module = true
110 |
--------------------------------------------------------------------------------
/evaluation/fishfarm/tox.ini:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 99
3 | statistics = True
4 | exclude = .venv,venv,build,notebooks,.asv,data
5 | ignore =
6 | E203,
7 | W503,
8 | E704
--------------------------------------------------------------------------------
/logging_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def get_mean_std_max_min_dict(array, prefix):
5 | res = {}
6 | res[prefix + "/mean"] = np.mean(array)
7 | res[prefix + "/std"] = np.std(array)
8 | res[prefix + "/min"] = np.amin(array)
9 | res[prefix + "/max"] = np.amax(array)
10 | return res
11 |
12 |
13 | class Metrics:
14 | """Object keeping running average/latest of relevant metrics to log."""
15 |
16 | def __init__(self, *args):
17 | self.metrics = {arg: 0 for arg in args}
18 | self.latest_metrics = {arg: 0 for arg in args}
19 | self.samples = {arg: 1e-8 for arg in args}
20 | self.logged_metrics = [arg for arg in args]
21 |
22 | def reset(self):
23 | for arg in self.metrics:
24 | self.metrics[arg] = 0
25 | self.samples[arg] = 1e-8
26 |
27 | def add(self, *args):
28 | for arg in args:
29 | if arg not in self.metrics:
30 | self.logged_metrics.append(arg)
31 | self.metrics[arg] = 0
32 | self.latest_metrics[arg] = 0
33 | self.samples[arg] = 1e-8
34 |
35 | def update(self, **kwargs):
36 | for arg, val in kwargs.items():
37 | if arg not in self.metrics:
38 |                 self.logged_metrics.append(arg)
39 | self.metrics[arg] = 0
40 | self.latest_metrics[arg] = 0
41 | self.samples[arg] = 1e-8
42 | self.metrics[arg] += val
43 | self.samples[arg] += 1
44 |
45 | def set(self, **kwargs):
46 | for arg, val in kwargs.items():
47 | if arg not in self.metrics:
48 |                 self.logged_metrics.append(arg)
49 | self.metrics[arg] = val
50 | self.samples[arg] = 1
51 | self.metrics[arg] = val
52 | self.samples[arg] = 1
53 |
54 | def get(self):
55 | for arg, metric_agg in self.metrics.items():
56 | samples = self.samples[arg]
57 | if samples >= 1:
58 | self.latest_metrics[arg] = metric_agg / samples
59 | return self.latest_metrics
60 |
--------------------------------------------------------------------------------
/optim_modules.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | from logging_utils import get_mean_std_max_min_dict
9 | from utils import (backward, eval_model, forward, load_base_params,
10 | load_hf_params_to_vllm)
11 |
12 |
13 | class OptimizationAlgorithm(abc.ABC):
14 | def __init__(self, **kwargs):
15 | nn.Module.__init__(self=self)
16 |
17 | @abc.abstractmethod
18 | def step_optimization(
19 | self,
20 | model_id,
21 | model,
22 | tokenizer,
23 | policy,
24 | task_loader,
25 | batch_ix,
26 | train_data,
27 | train_eval,
28 | base_params,
29 | decomposed_params,
30 | original_model_params,
31 | metrics_to_log,
32 | vllm_model=None,
33 | **kwargs,
34 | ):
35 |         raise NotImplementedError()
36 |
37 | @abc.abstractmethod
38 | def update(self, policy):
39 | raise NotImplementedError
40 |
41 | def log_optim(self, metrics_to_log):
42 | pass
43 |
44 |
45 | class Reinforce(OptimizationAlgorithm, nn.Module):
46 | def __init__(
47 | self, policy, gpu, max_grad_norm, lr, rw_norm, rw_clip, kl_ref_coeff, **kwargs
48 | ):
49 | nn.Module.__init__(self=self)
50 | self.gpu = gpu
51 | self.kl_ref_coeff = kl_ref_coeff
52 | self.use_kl_loss = kl_ref_coeff > 0.0
53 | self.max_grad_norm = float(max_grad_norm)
54 | self.lr = lr
55 | self.rw_norm = rw_norm
56 | self.rw_clip = rw_clip
57 | self.optimizer = torch.optim.Adam(policy.trainable_params, lr=lr)
58 |
59 | def compute_ref_logprobs(
60 | self,
61 | model,
62 | tokenizer,
63 | prompts,
64 | res,
65 | ):
66 | ref_log_probs_list = []
67 | print("Computing reference log probs...")
68 | for j, prompt in enumerate(prompts):
69 | input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(self.gpu)
70 | prompt_length = input_ids.shape[-1]
71 | output_ids = tokenizer(
72 | prompt + res.sample_details[j]["output"],
73 | return_tensors="pt",
74 | ).input_ids.to(self.gpu)
75 | outputs = model(output_ids)
76 | logits = outputs.logits[:, prompt_length - 1 : -1]
77 | log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
78 | ref_log_probs_list.append(log_probs.detach().cpu())
79 | return ref_log_probs_list
80 |
81 | def get_rewards(self, task_loader, res):
82 | rw_norm = self.rw_norm
83 | rw_clip = self.rw_clip
84 | rewards = task_loader.get_rewards(res=res)
85 |
86 | if rw_norm:
87 | rewards = np.array(rewards)
88 | mean_rw = np.mean(rewards)
89 | std_rw = np.clip(np.std(rewards), a_min=1e-7, a_max=None)
90 | rewards = (rewards - mean_rw) / std_rw
91 | if rw_clip is not None:
92 | if rw_clip > 0:
93 | rewards = np.array(rewards)
94 | rewards = np.clip(rewards, a_min=-rw_clip, a_max=rw_clip)
95 | return rewards
96 |
97 | def step_optimization(
98 | self,
99 | model_id,
100 | model,
101 | tokenizer,
102 | policy,
103 | task_loader,
104 | batch_ix,
105 | train_data,
106 | train_eval,
107 | base_params,
108 | decomposed_params,
109 | original_model_params,
110 | metrics_to_log,
111 | vllm_model=None,
112 | **kwargs,
113 | ):
114 | use_kl_loss = self.use_kl_loss
115 | kl_ref_coeff = self.kl_ref_coeff
116 |
117 | gpu = self.gpu
118 |
119 | prompts = [
120 | task_loader.get_prompt(
121 | tokenizer,
122 | train_data,
123 | i,
124 | model_id=model_id,
125 | )
126 | for i in batch_ix
127 | ]
128 |
129 | clipped_batch_size = len(prompts)
130 |
131 | learnable_params = policy.get_learnable_params()
132 | new_params = forward(
133 | policy, model, base_params, decomposed_params, learnable_params
134 | )
135 |
136 | print("Loading weights and getting completions with VLLM")
137 | load_hf_params_to_vllm(new_params, vllm_model.llm)
138 | res = eval_model(vllm_model, train_eval, batch_ix)
139 | rewards = self.get_rewards(task_loader=task_loader, res=res)
140 |
141 | rw_stats = get_mean_std_max_min_dict(array=rewards, prefix="rewards")
142 | metrics_to_log.update(**rw_stats)
143 |
144 | if use_kl_loss:
145 | with torch.no_grad():
146 | load_base_params(model=model, base_params=original_model_params)
147 | ref_log_probs_list = self.compute_ref_logprobs(
148 | model=model,
149 | tokenizer=tokenizer,
150 | prompts=prompts,
151 | res=res,
152 | )
153 | new_params = forward(
154 | policy, model, base_params, decomposed_params, learnable_params
155 | )
156 |
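        # REINFORCE-style update: for each sampled completion, the loss is
        # -log p(completion | prompt) * reward, optionally regularized towards the frozen
        # reference model with a KL term; gradients are accumulated over the batch.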
157 | print("Computing the policy gradient...")
158 | for j, prompt in enumerate(prompts):
159 | input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(gpu)
160 | prompt_length = input_ids.shape[-1]
161 | output_ids = tokenizer(
162 | prompt + res.sample_details[j]["output"],
163 | return_tensors="pt",
164 | ).input_ids.to(gpu)
165 | generated_ids = output_ids[:, prompt_length:]
166 |
167 | outputs = model(output_ids)
168 | logits = outputs.logits[:, prompt_length - 1 : -1]
169 | log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
170 | selected_log_probs = log_probs.gather(
171 | 2, generated_ids.unsqueeze(-1)
172 | ).squeeze(-1)
173 | log_likelihood = selected_log_probs.sum(axis=-1)
174 |
175 | pg = -log_likelihood * rewards[j]
176 | loss = pg
177 |
178 | if use_kl_loss:
179 | ref_log_probs = ref_log_probs_list[j].to(gpu)
180 | kl_div = F.kl_div(
181 | input=log_probs,
182 | target=ref_log_probs,
183 | log_target=True,
184 | reduction="sum",
185 | )
186 | loss = loss + kl_ref_coeff * kl_div
187 | scaled_loss = loss / clipped_batch_size
188 | scaled_loss.backward()
189 | log_dict = {
190 | "pg": pg.item(),
191 | "loss": loss.item(),
192 | }
193 | if use_kl_loss:
194 | log_dict["kl_div"] = kl_div.item()
195 | metrics_to_log.update(**log_dict)
196 | backward(policy, model, base_params, decomposed_params, learnable_params)
197 |
198 | def update(self, policy):
199 | max_grad_norm = self.max_grad_norm
200 | torch.nn.utils.clip_grad_norm_(policy.trainable_params, max_grad_norm)
201 | self.optimizer.step()
202 | self.optimizer.zero_grad()
203 |
204 | def log_optim(self, metrics_to_log):
205 | metrics_dict = metrics_to_log.get()
206 | pg = metrics_dict["pg"]
207 | print(f"PG={pg}")
208 | if self.use_kl_loss:
209 | kl_div = metrics_dict["kl_div"]
210 | print(f"kl_div={kl_div}")
211 |
212 |
213 | class RandomShooting(OptimizationAlgorithm, nn.Module):
214 | def __init__(
215 | self,
216 | policy,
217 | gpu,
218 | pop_size,
219 | min_trainable_param,
220 | max_trainable_param,
221 | optim_ema=0,
222 | re_eval_best=True,
223 | use_loglikelihood_for_ties=False,
224 | **kwargs,
225 | ):
226 |
227 | nn.Module.__init__(self=self)
228 | self.gpu = gpu
229 | trainable_params = policy.trainable_params
230 | self.pop_size = pop_size
231 | self.min_trainable_param = min_trainable_param
232 | self.max_trainable_param = max_trainable_param
233 | self.range_trainable_param = max_trainable_param - min_trainable_param
234 | assert optim_ema >= 0 and optim_ema < 1
235 | self.optim_ema = optim_ema
236 | self.re_eval_best = re_eval_best
237 | self.use_loglikelihood_for_ties = use_loglikelihood_for_ties
238 |
239 | self.trainable_params_shapes = [p.shape for p in trainable_params]
240 | self.trainable_params_nums = [torch.numel(p) for p in trainable_params]
241 | self.trainable_params_dtype = trainable_params[0].dtype
242 | self.total_trainable_params = sum(self.trainable_params_nums)
243 | self.best_idx = 0
244 |
245 | initial_values = (
246 | torch.rand(size=[pop_size, self.total_trainable_params])
247 | * self.range_trainable_param
248 | ) + self.min_trainable_param
249 | init_values_flat = [
250 | torch.flatten(torch.detach_copy(p.data)) for p in trainable_params
251 | ]
252 | init_soln = torch.concat(init_values_flat, dim=0)
253 | if self.re_eval_best:
254 | initial_values[0] = torch.clone(init_soln)
255 |
256 | self.pop_params = nn.Parameter(
257 | initial_values,
258 | requires_grad=False,
259 | ).cpu()
260 | self.best_soln = nn.Parameter(init_soln, requires_grad=False).cpu()
261 |
262 | def compute_logprobs(
263 | self,
264 | model,
265 | tokenizer,
266 | prompts,
267 | generated_outputs,
268 | ):
269 | selected_log_probs_list = []
270 | for j, prompt in enumerate(prompts):
271 | input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(self.gpu)
272 | prompt_length = input_ids.shape[-1]
273 | output_ids = tokenizer(
274 | prompt + generated_outputs[j],
275 | return_tensors="pt",
276 | ).input_ids.to(self.gpu)
277 | generated_ids = output_ids[:, prompt_length:]
278 |
279 | outputs = model(output_ids)
280 | logits = outputs.logits[:, prompt_length - 1 : -1]
281 | log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
282 | selected_log_probs = log_probs.gather(
283 | 2, generated_ids.unsqueeze(-1)
284 | ).squeeze(-1)
285 | selected_log_probs_list.append(selected_log_probs.detach().cpu())
286 | return selected_log_probs_list
287 |
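    # Draws a fresh population uniformly in [min_trainable_param, max_trainable_param]; when
    # re_eval_best is set, member 0 is replaced by the current best solution so it is re-scored.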
288 | @torch.no_grad
289 | def sample_new_params(
290 | self,
291 | ):
292 | pop_values = (
293 | torch.rand(size=[self.pop_size, self.total_trainable_params])
294 | * self.range_trainable_param
295 | ) + self.min_trainable_param
296 | if self.re_eval_best:
297 | pop_values[0] = torch.detach_copy(self.best_soln)
298 |
299 | self.pop_params.data.copy_(pop_values)
300 |
301 | def split_and_convert(self, flat_params):
302 | split_flat_params = torch.split_with_sizes(
303 | flat_params, split_sizes=self.trainable_params_nums
304 | )
305 | split_params = [
306 | torch.reshape(p, shape=s).to(dtype=self.trainable_params_dtype).to(self.gpu)
307 | for p, s in zip(split_flat_params, self.trainable_params_shapes)
308 | ]
309 | return split_params
310 |
311 | def get_params_for_pop_member(self, pop_idx):
312 | return self.split_and_convert(self.pop_params[pop_idx])
313 |
314 | @torch.no_grad
315 | def step_optimization(
316 | self,
317 | model_id,
318 | model,
319 | tokenizer,
320 | policy,
321 | task_loader,
322 | batch_ix,
323 | train_data,
324 | train_eval,
325 | base_params,
326 | decomposed_params,
327 | metrics_to_log,
328 | vllm_model=None,
329 | **kwargs,
330 | ):
331 | self.sample_new_params()
332 | perf_per_pop = []
333 | avg_log_likelihoods_per_pop = []
334 | for pop_idx in range(self.pop_size):
335 | pop_idx_params = self.split_and_convert(
336 | flat_params=self.pop_params[pop_idx]
337 | )
338 | policy.set_trainable_params_values(new_values=pop_idx_params)
339 | learnable_params = policy.get_learnable_params()
340 | new_params = forward(
341 | policy, model, base_params, decomposed_params, learnable_params
342 | )
343 |
344 | print("Loading weights and getting completions with VLLM")
345 | load_hf_params_to_vllm(new_params, vllm_model.llm)
346 | res = eval_model(vllm_model, train_eval, batch_ix)
347 | if self.use_loglikelihood_for_ties:
348 |                 print("Storing log likelihoods")
349 | rewards = task_loader.get_rewards(res=res)
350 | correct = [int(r > 0) for r in rewards]
351 | correct_batch_ix = [i for i, c in zip(batch_ix, correct) if c]
352 | if len(correct_batch_ix) > 0:
353 | avg_log_likelihoods = []
354 | correct_prompts = [
355 | task_loader.get_prompt(
356 | tokenizer,
357 | train_data,
358 | i,
359 | model_id=model_id,
360 | )
361 | for i in correct_batch_ix
362 | ]
363 | correct_outputs = [
364 | res.sample_details[j]["output"]
365 | for j, c in enumerate(correct)
366 | if c
367 | ]
368 | selected_log_probs_list = self.compute_logprobs(
369 | model=model,
370 | tokenizer=tokenizer,
371 | prompts=correct_prompts,
372 | generated_outputs=correct_outputs,
373 | )
374 | for selected_log_probs in selected_log_probs_list:
375 | avg_log_likelihood = selected_log_probs.mean(axis=-1)
376 | avg_log_likelihoods.append(avg_log_likelihood.item())
377 | avg_log_likelihoods_per_pop.append(np.mean(avg_log_likelihoods))
378 | else:
379 | avg_log_likelihoods_per_pop.append(0.0)
380 |
381 | perf = res.aggregate_metrics[task_loader.target_metric_train]
382 | perf_per_pop.append(perf)
383 |
384 | perf_stats = get_mean_std_max_min_dict(array=perf_per_pop, prefix="pop_perf")
385 | metrics_to_log.update(**perf_stats)
386 |
387 | if self.use_loglikelihood_for_ties:
388 | perf_per_pop_array = np.array(perf_per_pop)
389 | loglikelihood_array = np.array(avg_log_likelihoods_per_pop)
390 | max_perf = perf_per_pop_array == np.max(perf_per_pop_array)
391 | max_perf_idxs = np.flatnonzero(max_perf)
392 | max_perf_logprobs = loglikelihood_array[max_perf_idxs]
393 |             print("Tie-breaking by log-likelihood among best-performing members:")
394 | print(perf_per_pop)
395 | print(loglikelihood_array)
396 | print(max_perf_idxs)
397 | best_logprob_idx = np.argmax(max_perf_logprobs)
398 | best_member_idx = max_perf_idxs[best_logprob_idx]
399 | print(best_logprob_idx)
400 | print(best_member_idx)
401 | logprobs_stats = get_mean_std_max_min_dict(
402 | array=max_perf_logprobs, prefix="logprobs_correct"
403 | )
404 | metrics_to_log.update(**logprobs_stats)
405 | else:
406 | best_member_idx = np.argmax(perf_per_pop)
407 | self.best_idx = best_member_idx
408 | best_params = self.pop_params[best_member_idx].cpu()
409 | self.best_soln.data.copy_(
410 | best_params * (1 - self.optim_ema) + self.optim_ema * self.best_soln.cpu()
411 | )
412 |
413 | def update(self, policy):
414 | policy.set_trainable_params_values(
415 | new_values=self.split_and_convert(self.best_soln)
416 | )
417 |
418 | def log_optim(self, metrics_to_log):
419 | pass
420 |
421 |
422 | class CEM(RandomShooting):
423 | def __init__(
424 | self,
425 | policy,
426 | gpu,
427 | elite_ratio,
428 | pop_size,
429 | min_trainable_param,
430 | max_trainable_param,
431 | optim_ema=0,
432 | re_eval_best=True,
433 | use_loglikelihood_for_ties=False,
434 | **kwargs,
435 | ):
436 |
437 | RandomShooting.__init__(
438 | self=self,
439 | policy=policy,
440 | gpu=gpu,
441 | pop_size=pop_size,
442 | min_trainable_param=min_trainable_param,
443 | max_trainable_param=max_trainable_param,
444 | optim_ema=optim_ema,
445 | re_eval_best=re_eval_best,
446 | use_loglikelihood_for_ties=use_loglikelihood_for_ties,
447 | **kwargs,
448 | )
449 |
450 | self.elite_ratio = elite_ratio
451 | self.num_elites = int(elite_ratio * pop_size)
452 | self.dist_mean = nn.Parameter(
453 | torch.detach_copy(self.best_soln), requires_grad=False
454 | ).cpu()
455 | init_stdev = (
456 | torch.ones([self.total_trainable_params]) * self.range_trainable_param / 2
457 | )
458 | self.dist_std = nn.Parameter(init_stdev, requires_grad=False).cpu()
459 |
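    # CEM sampling: draw the population from a diagonal Gaussian with the current distribution
    # mean/std, clamp to the allowed parameter range, and optionally keep the best solution.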
460 | @torch.no_grad
461 | def sample_new_params(
462 | self,
463 | ):
464 | pop_values = (
465 | torch.randn(size=[self.pop_size, self.total_trainable_params])
466 | * self.dist_std
467 | ) + self.dist_mean
468 | pop_values = torch.clamp(
469 | pop_values,
470 | min=self.min_trainable_param,
471 | max=self.max_trainable_param,
472 | )
473 | if self.re_eval_best:
474 | pop_values[0] = torch.detach_copy(self.best_soln)
475 |
476 | self.pop_params.data.copy_(pop_values)
477 |
478 | @torch.no_grad
479 | def step_optimization(
480 | self,
481 | model_id,
482 | model,
483 | tokenizer,
484 | policy,
485 | task_loader,
486 | batch_ix,
487 | train_data,
488 | train_eval,
489 | base_params,
490 | decomposed_params,
491 | metrics_to_log,
492 | vllm_model=None,
493 | **kwargs,
494 | ):
495 | self.sample_new_params()
496 | perf_per_pop = []
497 | avg_log_likelihoods_per_pop = []
498 | for pop_idx in range(self.pop_size):
499 | pop_idx_params = self.split_and_convert(
500 | flat_params=self.pop_params[pop_idx]
501 | )
502 | policy.set_trainable_params_values(new_values=pop_idx_params)
503 | learnable_params = policy.get_learnable_params()
504 | new_params = forward(
505 | policy, model, base_params, decomposed_params, learnable_params
506 | )
507 |
508 | print("Loading weights and getting completions with VLLM")
509 | load_hf_params_to_vllm(new_params, vllm_model.llm)
510 | res = eval_model(vllm_model, train_eval, batch_ix)
511 | if self.use_loglikelihood_for_ties:
512 |                 print("Storing log likelihoods")
513 | rewards = task_loader.get_rewards(res=res)
514 | correct = [int(r > 0) for r in rewards]
515 | correct_batch_ix = [i for i, c in zip(batch_ix, correct) if c]
516 | if len(correct_batch_ix) > 0:
517 | avg_log_likelihoods = []
518 | correct_prompts = [
519 | task_loader.get_prompt(
520 | tokenizer,
521 | train_data,
522 | i,
523 | model_id=model_id,
524 | )
525 | for i in correct_batch_ix
526 | ]
527 | correct_outputs = [
528 | res.sample_details[j]["output"]
529 | for j, c in enumerate(correct)
530 | if c
531 | ]
532 |                     print("Computing log probabilities for correct completions...")
533 | selected_log_probs_list = self.compute_logprobs(
534 | model=model,
535 | tokenizer=tokenizer,
536 | prompts=correct_prompts,
537 | generated_outputs=correct_outputs,
538 | )
539 | for selected_log_probs in selected_log_probs_list:
540 | avg_log_likelihood = selected_log_probs.mean(axis=-1)
541 | avg_log_likelihoods.append(avg_log_likelihood.item())
542 | avg_log_likelihoods_per_pop.append(np.mean(avg_log_likelihoods))
543 | else:
544 | avg_log_likelihoods_per_pop.append(0.0)
545 |
546 | perf = res.aggregate_metrics[task_loader.target_metric_train]
547 | perf_per_pop.append(perf)
548 |
549 | perf_stats = get_mean_std_max_min_dict(array=perf_per_pop, prefix="pop_perf")
550 | metrics_to_log.update(**perf_stats)
551 |
552 | if self.use_loglikelihood_for_ties:
553 | perf_per_pop_array = np.array(perf_per_pop)
554 | loglikelihood_array = np.array(avg_log_likelihoods_per_pop)
555 | max_perf = perf_per_pop_array == np.max(perf_per_pop_array)
556 | max_perf_idxs = np.flatnonzero(max_perf)
557 | max_perf_logprobs = loglikelihood_array[max_perf_idxs]
558 | best_logprob_idx = np.argmax(max_perf_logprobs)
559 | best_member_idx = max_perf_idxs[best_logprob_idx]
560 | logprobs_stats = get_mean_std_max_min_dict(
561 | array=max_perf_logprobs, prefix="logprobs_correct"
562 | )
563 | metrics_to_log.update(**logprobs_stats)
564 | else:
565 | best_member_idx = np.argmax(perf_per_pop)
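        # Cross-entropy method update: refit the sampling distribution to the elite members,
        # blending the new elite mean/std with the previous ones via the optim_ema coefficient.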
566 | elite_idxs = np.argpartition(perf_per_pop, -self.num_elites)[-self.num_elites :]
567 |
568 | elite_params = self.pop_params[elite_idxs]
569 | elite_mean = torch.mean(elite_params, dim=0)
570 | elite_std = torch.std(elite_params, dim=0)
571 | self.best_idx = best_member_idx
572 | best_params = self.pop_params[best_member_idx].cpu()
573 | self.best_soln.data.copy_(best_params)
574 | self.dist_mean.copy_(
575 | elite_mean.cpu() * (1 - self.optim_ema)
576 | + self.optim_ema * self.dist_mean.cpu()
577 | )
578 | self.dist_std.copy_(
579 | elite_std.cpu() * (1 - self.optim_ema)
580 | + self.optim_ema * self.dist_std.cpu()
581 | )
582 |
583 | cem_mean_stats = get_mean_std_max_min_dict(
584 | array=self.dist_mean.detach().cpu().numpy(),
585 | prefix="cem_mean",
586 | )
587 | metrics_to_log.update(**cem_mean_stats)
588 |
589 | cem_std_stats = get_mean_std_max_min_dict(
590 | array=self.dist_std.detach().cpu().numpy(),
591 | prefix="cem_std",
592 | )
593 | metrics_to_log.update(**cem_std_stats)
594 |
--------------------------------------------------------------------------------
/policy/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import Policy
2 | from .weighted_combination import WeightedCombination
3 |
--------------------------------------------------------------------------------
/policy/base.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | def get_soft_mask(n, fraction):
6 | indices = torch.linspace(0, n - 1, n, dtype=torch.bfloat16) + 1
7 | scaled_indices = indices.to(fraction.device) - fraction * n
8 | result = torch.clamp(scaled_indices, 0, 1)
9 | return 1.0 - result
10 |
11 |
12 | class Policy(nn.Module):
13 | def __init__(self, base_params, gpu, init_val, max_mult=1, **kwargs):
14 | # Create learnable parameters.
15 | super().__init__()
16 | self.learnable_params = {}
17 | self.num_params = 0
18 | self.max_mult = max_mult
19 | for k, v in base_params.items():
20 | # each param initialized with small gaussian noise
21 | if "mlp" in k:
22 | self.learnable_params[k] = torch.nn.Parameter(
23 | data=(
24 | torch.randn(
25 | min(v.shape),
26 | device=gpu,
27 | dtype=torch.bfloat16,
28 | )
29 | * 0.01
30 | + init_val
31 | ),
32 | requires_grad=True,
33 | )
34 | self.num_params += self.learnable_params[k].numel()
35 | print(f"#params={self.num_params}")
36 | self.learnable_params_list = list(self.learnable_params.values())
37 | self.trainable_params = self.learnable_params_list
38 | self.learnable_params_module_list = nn.ParameterList(self.learnable_params_list)
39 |
40 | def get_learnable_params(self, detach=False):
41 | return self.learnable_params
42 |
43 | def set_trainable_params_values(self, new_values):
44 | with torch.no_grad():
45 | for p, v in zip(self.trainable_params, new_values):
46 | p.data.copy_(v)
47 |
48 | def get_mask(self, p):
49 | return torch.sigmoid(p).to(torch.bfloat16) * self.max_mult
50 |
51 | def record_state(self, metrics_to_log):
52 | pass
53 |
--------------------------------------------------------------------------------
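Note: the base Policy keeps one learnable scalar per singular value of each MLP weight (hence min(v.shape) entries) and maps it through a sigmoid, scaled by max_mult, to obtain a non-negative multiplier. get_soft_mask builds a soft step that keeps roughly the first fraction*n entries at 1 and ramps the rest to 0. A quick illustration with made-up values (float32 here instead of the bfloat16 used above):

import torch

p = torch.zeros(8)                     # raw learnable vector, one entry per singular value
mask = torch.sigmoid(p) * 1.0          # get_mask with max_mult=1 -> every multiplier starts at 0.5
print(mask)

frac = torch.tensor(0.5)
idx = torch.linspace(0, 7, 8) + 1      # same construction as get_soft_mask(n=8, fraction=0.5)
soft = 1.0 - torch.clamp(idx - frac * 8, 0, 1)
print(soft)                            # tensor([1., 1., 1., 1., 0., 0., 0., 0.])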
/policy/weighted_combination.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Optional, Union
2 |
3 | import hydra
4 | import torch
5 | import torch.nn as nn
6 | from omegaconf import DictConfig
7 |
8 | from .base import Policy
9 |
10 |
11 | class WeightedCombination(Policy):
12 | def __init__(
13 | self,
14 | base_params,
15 | decomposed_params,
16 | base_policy_cfg: Optional[Union[DictConfig, int]],
17 | params_paths: List[str],
18 | gpu,
19 | norm_coeffs,
20 | per_layer,
21 | init_values: Optional[List[float]] = None,
22 | **kwargs,
23 | ):
24 | # Create learnable parameters.
25 | nn.Module.__init__(self=self)
26 | weights_dict_list: List[Dict[str, torch.Tensor]] = []
27 | if base_policy_cfg is None:
28 | base_policy = Policy(base_params=base_params, gpu=gpu, init_val=0)
29 | elif isinstance(base_policy_cfg, DictConfig):
30 | base_policy: Policy = hydra.utils.instantiate(
31 | base_policy_cfg,
32 | base_params=base_params,
33 | decomposed_params=decomposed_params,
34 | gpu=gpu,
35 | )
36 | else:
37 | raise NotImplementedError
38 |
39 | with torch.no_grad():
40 | for i, load_ckpt in enumerate(params_paths):
41 | print(f"Loading checkpoint {i} at {load_ckpt}...")
42 | if "learnable_params" in load_ckpt:
43 | learnable_params = torch.load(load_ckpt)
44 | else:
45 | state_dict = torch.load(load_ckpt, weights_only=True)
46 | base_policy.load_state_dict(state_dict=state_dict)
47 | learnable_params = base_policy.get_learnable_params()
48 | weights_dict_list.append(
49 | {k: torch.detach_copy(p) for k, p in learnable_params.items()}
50 | )
51 |
52 | self.num_weights_dict = len(weights_dict_list)
53 |
54 | self.num_params_per_weight_dict = 0
55 | for _ in weights_dict_list[0]:
56 | self.num_params_per_weight_dict += 1
57 |
58 | self.num_params = self.num_weights_dict * self.num_params_per_weight_dict
59 | if init_values is None:
60 | init_values = torch.Tensor(
61 | [1 / self.num_weights_dict for _ in range(self.num_weights_dict)]
62 | )
63 | else:
64 | assert len(init_values) == self.num_weights_dict
65 | init_values = torch.Tensor(init_values)
66 | self.learned_params_per_weight_dict = 1
67 | if per_layer:
68 | self.learned_params_per_weight_dict = self.num_params_per_weight_dict
69 | init_values = torch.stack(
70 | [init_values for _ in range(self.learned_params_per_weight_dict)], dim=1
71 | )
72 | if norm_coeffs:
73 | # Normalize across different weight idxs (for all layers)
74 | init_values = init_values / torch.sum(init_values, axis=0)
75 |
76 | # Num weight idxs x learned params_per_weight_idx
77 | self.adaptive_weights = torch.nn.Parameter(
78 | data=init_values,
79 | requires_grad=True,
80 | )
81 |
82 | self.parameter_keys = []
83 | self.original_params = {}
84 | for k, v in weights_dict_list[0].items():
85 | self.parameter_keys.append(k)
86 | self.original_params[k] = []
87 | for i, weight_dict in enumerate(weights_dict_list):
88 | weight_tensor = self.get_mask(p=weight_dict[k])
89 | new_key = k.replace(".", "_")
90 | self.register_buffer(
91 | f"weights_{i}_k_{new_key}",
92 | tensor=weight_tensor,
93 | )
94 | self.original_params[k].append(weight_tensor.to(device=gpu))
95 |
96 | self.norm = norm_coeffs
97 | self.per_layer = per_layer
98 | self.trainable_params = [self.adaptive_weights]
99 |
100 | def get_weight_to_combine(self, k, weights_dict_idx):
101 | new_key = k.replace(".", "_")
102 | return getattr(self, f"weights_{weights_dict_idx}_k_{new_key}")
103 |
104 | def get_coeff_per_layer(self):
105 | if self.norm:
106 | adaptive_weights = self.adaptive_weights / self.adaptive_weights.sum(0)
107 | else:
108 | adaptive_weights = self.adaptive_weights
109 | weights_per_layer = adaptive_weights.expand(
110 | [
111 | self.num_weights_dict,
112 | self.num_params_per_weight_dict,
113 | ]
114 | )
115 | return weights_per_layer
116 |
117 | def get_learnable_params(self):
118 | adaptive_coeff_per_layer = self.get_coeff_per_layer()
119 | output_params = {}
120 | for i, (k, vs) in enumerate(self.original_params.items()):
121 | cs_coeff = adaptive_coeff_per_layer[:, i]
122 | out = vs[0] * cs_coeff[0]
123 | for j, other_v in enumerate(vs[1:]):
124 | v_idx = j + 1
125 | out = out + other_v * cs_coeff[v_idx]
126 | output_params[k] = out
127 | return output_params
128 |
129 | def get_mask(self, p):
130 | return p
131 |
132 | def record_state(self, metrics_to_log):
133 | avg_weights = self.adaptive_weights.mean(-1).detach().cpu().numpy()
134 | dict_to_log = {
135 | f"adaptive_weight/mean_across_params_w{i}": w
136 | for i, w in enumerate(avg_weights.tolist())
137 | }
138 | metrics_to_log.update(**dict_to_log)
139 |
--------------------------------------------------------------------------------
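Note: the combination performed by WeightedCombination is a per-layer blend: for every parameter key, the stored expert tensors are summed with the learned coefficients, optionally renormalized to sum to 1 across experts. A stripped-down sketch of get_learnable_params for a single key (shapes and values are hypothetical):

import torch

experts = [torch.full((4,), 1.0), torch.full((4,), 3.0)]   # two expert tensors for one parameter key
weights = torch.tensor([0.25, 0.75])                       # this layer's column of adaptive_weights
weights = weights / weights.sum()                          # the norm_coeffs=True branch

combined = sum(w * t for w, t in zip(weights, experts))
print(combined)                                            # tensor([2.5000, 2.5000, 2.5000, 2.5000])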
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.27.2
2 | datasets==2.14.6
3 | einops==0.7.0
4 | evaluate==0.4.1
5 | fasttext==0.9.2
6 | pandas==2.1.2
7 | Pillow==10.1.0
8 | torch==2.3.1
9 | torchvision==0.18.1
10 | tornado==6.4
11 | tqdm==4.66.1
12 | traitlets==5.13.0
13 | transformers==4.43.1
14 | trl==0.8.6
15 | vllm==0.5.3.post1
16 | hydra-core==1.3.2
17 | evalplus @ git+https://github.com/evalplus/evalplus@1895d2f6aa8895044a7cf69defc24bd57695e885
18 | peft
19 | fire
20 | matplotlib
21 | wandb
22 |
23 |
--------------------------------------------------------------------------------
/scripts/eval_few_shot.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Prerequisites:
4 | # - Ensure reference_params_results path is set in cfgs/base_model/*.yaml
5 |
6 | # Task Selection
7 | TASK="few_shot_math" # Available options: few_shot_arc_challenge, few_shot_humaneval
8 |
9 | # Evaluation Setting
10 | PER_LAYER=true
11 | NORM_COEFFS=false
12 |
13 | # Start evaluation!
14 | CUDA_VISIBLE_DEVICES=0,1 python svd_reinforce_hydra.py \
15 | base_model@_global_=llama3i8b \
16 | optimization@_global_=cem \
17 | policy@_global_=wcomb \
18 | use_loglikelihood_for_ties=true \
19 | per_layer=$PER_LAYER \
20 | task@_global_=$TASK \
21 | norm_coeffs=$NORM_COEFFS \
22 | wandb_log=true \
23 | num_iters=50
--------------------------------------------------------------------------------
/scripts/eval_prompt_based.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Task Selection
4 | TASK="mbpp2" # Available options: mbpp2, math, ai2_arc
5 |
6 | # First Stage Inference: Classification Expert
7 | # Set to 'None' if not using cls expert
8 | CLS_EXPERT_PATH="None"
9 |
10 | # Second Stage: Expert Models
11 | # Replace these paths with your actual model paths
12 | CODE_EXPERT_PATH="your_path_to_code_expert"
13 | MATH_EXPERT_PATH="your_path_to_math_expert"
14 | REASONING_EXPERT_PATH="your_path_to_reasoning_expert"
15 |
16 | # Start evaluation!
17 | CUDA_VISIBLE_DEVICES=0,1 python svd_reinforce_hydra.py \
18 | base_model@_global_=llama3i8b \
19 | task@_global_=$TASK \
20 | mode@_global_=eval \
21 | prompt_based_eval=True \
22 | experts_path_dict.code=$CODE_EXPERT_PATH \
23 | experts_path_dict.math=$MATH_EXPERT_PATH \
24 | experts_path_dict.reasoning=$REASONING_EXPERT_PATH \
25 | load_ckpt=$CLS_EXPERT_PATH
--------------------------------------------------------------------------------
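Note: this script drives the two-stage dispatch implemented by eval_model_experts_prompt_based in utils.py: a classification expert first labels each test question as code / math / reasoning, then each group is re-evaluated with the matching task expert's weights. A rough sketch of that routing, with the helpers passed in as stand-ins rather than the repository's actual functions:

def dispatch_eval(vllm_model, test_eval, experts_path_dict, classify_fn, load_expert_fn):
    """Stage 1: label every sample; stage 2: evaluate each label group under the matching expert."""
    labels = classify_fn(vllm_model, test_eval)                # e.g. ["math", "code", ...], one label per sample
    results = {}
    for expert_name, ckpt_path in experts_path_dict.items():
        ids = [i for i, lab in enumerate(labels) if lab == expert_name]
        if not ids or ckpt_path == "None":
            continue
        load_expert_fn(ckpt_path)                              # swap that expert's weights into the vLLM engine
        results[expert_name] = test_eval.evaluate(vllm_model, sample_ids=ids)
    return results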
/scripts/train_task_expert.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Task Selection
4 | TASK="mbpp2" # Available options: mbpp2, gsm8k, ai2_arc, cls
5 |
6 | # Training Setting
7 | NUM_ITERS=200
8 |
9 | # This script needs 2 GPUs
10 | CUDA_VISIBLE_DEVICES=0,1 python svd_reinforce_hydra.py \
11 | base_model@_global_=llama3i8b \
12 | task@_global_=$TASK \
13 | mode@_global_=training \
14 | num_iters=$NUM_ITERS
--------------------------------------------------------------------------------
/svd_reinforce_hydra.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import json
3 | import os
4 | from datetime import datetime
5 | from typing import Dict
6 |
7 | import hydra
8 | import numpy as np
9 | import torch
10 | from omegaconf import OmegaConf
11 | from transformers import AutoModelForCausalLM, AutoTokenizer
12 |
13 | from base_model import BaseModel
14 | from logging_utils import Metrics, get_mean_std_max_min_dict
15 | from optim_modules import OptimizationAlgorithm
16 | from policy import Policy
17 | from tasks import Task
18 | from utils import (eval_model, eval_model_experts_prompt_based, forward,
19 | load_hf_params_to_vllm)
20 |
21 |
22 | def wandb_init(cfg, run_name: str, group_name: str, log_dir: str):
23 | import wandb
24 |
25 | config_dict = OmegaConf.to_container(
26 | cfg,
27 | resolve=True,
28 | throw_on_missing=False,
29 | )
30 | config_dict["log_dir"] = log_dir
31 | config_dict["wandb_run_name"] = run_name
32 | config_dict["wandb_group_name"] = group_name
33 |
34 | # wandb has a 128-size character limit on the group name
35 | wandb.init(
36 | project=cfg.wandb_project,
37 | group=group_name[:127],
38 | name=run_name[:127],
39 | config=config_dict,
40 | )
41 | return wandb
42 |
43 |
44 | @hydra.main(version_base=None, config_path="cfgs", config_name="config")
45 | def main(cfg):
46 | """Main function."""
47 |
48 | num_iters = cfg.num_iters
49 | test_interval = cfg.test_interval
50 |
51 | batch_size = cfg.batch_size
52 | seed = cfg.seed
53 | policy_name = cfg.policy_name
54 | test_only = cfg.test_only
55 | save_legacy_params = cfg.save_legacy_params
56 | exp_name = cfg.exp_name
57 | run_name = cfg.run_name
58 |
59 | task_name = cfg.task_name
60 |
61 | load_ckpt = cfg.load_ckpt
62 | use_lora = cfg.use_lora
63 | prompt_based_eval = cfg.prompt_based_eval
64 | experts_path_dict = cfg.experts_path_dict
65 |
66 | resuming_from_ckpt = False
67 | if load_ckpt is not None:
68 | if load_ckpt == "scratch" or load_ckpt == "base":
69 | resuming_from_ckpt = False
70 | else:
71 | resuming_from_ckpt = True
72 |
73 | # Create task
74 | task_loader: Task = hydra.utils.instantiate(cfg.task_loader)
75 |
76 | base_model: BaseModel = hydra.utils.instantiate(cfg.base_model)
77 |
78 | model_id = base_model.get_model_id()
79 | decomposed_param_file = base_model.get_param_file(param_folder_path="")
80 |
81 | extract_svd = cfg.extract_svd or (not os.path.exists(decomposed_param_file))
82 |
83 | has_training_split = task_loader.has_training_split
84 | has_transfer_split = task_loader.has_transfer_split
85 |
86 | if not has_training_split:
87 | assert test_only, "Cannot train on a task with no training split"
88 |
89 | if exp_name is None:
90 | exp_name = "temp"
91 |
92 | metrics_to_log = Metrics()
93 |
94 | # Create log dir.
95 | if run_name is None:
96 | now = datetime.now()
97 | run_name = now.strftime("%Y%m%d-%H%M%S")
98 | if test_only and (not resuming_from_ckpt):
99 | log_dir = f"{cfg.out_dir}/{task_name}/{cfg.base_model_name}_base"
100 | group_name = cfg.base_model_name
101 | else:
102 | log_dir = f"{cfg.out_dir}/{task_name}/{policy_name}/{exp_name}/{run_name}"
103 | group_name = cfg.wandb_group_name
104 | os.makedirs(log_dir, exist_ok=True)
105 |
106 | vllm_model = task_loader.get_vllm_model(model_id=model_id)
107 |
108 | train_eval, *test_evals = task_loader.get_evaluator()
109 | if task_loader.has_transfer_split:
110 | test_eval, transfer_eval = test_evals
111 | else:
112 | test_eval = test_evals[0]
113 |
114 | train_data, train_ix, valid_ix = task_loader.get_train_data()
115 | gpu = torch.device("cuda:1")
116 | np_random = np.random.RandomState(seed)
117 |
118 | # cpu + float32 for initial SVD decomposition
119 | if extract_svd:
120 | model = AutoModelForCausalLM.from_pretrained(
121 | model_id, device_map="cpu", torch_dtype=torch.float32
122 | )
123 | else:
124 | # Load model and tokenizer.
125 | model = AutoModelForCausalLM.from_pretrained(
126 | model_id, device_map="cuda:1", torch_dtype=torch.bfloat16
127 | )
128 | tokenizer = AutoTokenizer.from_pretrained(model_id)
129 | base_params = model.state_dict()
130 |
131 | original_model_params = {
132 | k: v.clone().detach().cpu() for k, v in base_params.items() if "mlp" in k
133 | }
134 |
135 | # Load decomposed parameters.
136 | if not os.path.exists(decomposed_param_file):
137 | print("Decomposed params not found. Decomposing...")
138 | decomposed_params = {}
139 | for k, v in base_params.items():
140 | if "norm" not in k:
141 | print(k)
142 | U, S, V = torch.svd(v)
143 | decomposed_params[f"{k}.U"] = U
144 | decomposed_params[f"{k}.S"] = S
145 | decomposed_params[f"{k}.V"] = V
146 | torch.save(decomposed_params, decomposed_param_file)
147 | print("successfully decomposed model - returning")
148 | return
149 | elif extract_svd:
150 | print(f"ERROR: SVD file already exists at {decomposed_param_file}")
151 | else:
152 | print("Decomposed params found. Loading...")
153 | assert not extract_svd
154 | decomposed_params = torch.load(decomposed_param_file)
155 | for k, v in decomposed_params.items():
156 | decomposed_params[k] = v.to(torch.bfloat16).to(gpu)
157 |
158 | if cfg.wandb_log:
159 | wandb = wandb_init(
160 | cfg=cfg, group_name=group_name, run_name=run_name, log_dir=log_dir
161 | )
162 |
163 | policy: Policy = hydra.utils.instantiate(
164 | cfg.shakeoff_policy,
165 | base_params=base_params,
166 | decomposed_params=decomposed_params,
167 | gpu=gpu,
168 | )
169 |
170 | optimization_algorithm: OptimizationAlgorithm = hydra.utils.instantiate(
171 | cfg.optimization_algorithm,
172 | policy=policy,
173 | gpu=gpu,
174 | )
175 |
176 | if resuming_from_ckpt and os.path.exists(load_ckpt):
177 | print(f"Starting from checkpoint at: {load_ckpt}")
178 | # load the lora weight
179 | if use_lora:
180 | assert os.path.isdir(load_ckpt), "ckpt for lora must be dir to lora adapter"
181 | from peft import PeftModel
182 |
183 | lora_model = PeftModel.from_pretrained(model, load_ckpt)
184 | merged_model = lora_model.merge_and_unload()
185 | new_params = merged_model.state_dict()
186 | # load svd expert
187 | elif "learnable_params" in load_ckpt:
188 | learnable_params = torch.load(load_ckpt)
189 | for k, v in learnable_params.items():
190 | learnable_params[k] = v.to(gpu)
191 | assert test_only
192 | new_params = forward(
193 | policy, model, base_params, decomposed_params, learnable_params
194 | )
195 | else:
196 | state_dict = torch.load(load_ckpt, weights_only=True)
197 | policy.load_state_dict(state_dict=state_dict)
198 | if test_only:
199 | learnable_params = policy.get_learnable_params()
200 | new_params = forward(
201 | policy, model, base_params, decomposed_params, learnable_params
202 | )
203 | load_hf_params_to_vllm(new_params, vllm_model.llm)
204 | else:
205 | print(f"Starting from the base model as load_ckpt=={load_ckpt}")
206 |
207 | model.eval()
208 |
209 | # Prompt-based and cls dispatcher evaluation.
210 | if test_only and prompt_based_eval:
211 | test_data_dict = eval_model_experts_prompt_based(
212 | vllm_model,
213 | test_eval,
214 | experts_path_dict,
215 | policy,
216 | model,
217 | base_params,
218 | decomposed_params,
219 | task_loader.target_metric_test,
220 | )
221 | test_data_dict["type"] = "test"
222 | # Log the results.
223 | if cfg.wandb_log:
224 | wandb.log(test_data_dict)
225 | with open(f"{log_dir}/eval_results.json", "w") as f:
226 | json.dump(test_data_dict, f, indent=4)
227 | print(f"Test evaluation results: {test_data_dict}")
228 |
229 | # Eval the transfer set if available
230 | if has_transfer_split:
231 | transfer_data_dict = eval_model_experts_prompt_based(
232 | vllm_model,
233 | transfer_eval,
234 | experts_path_dict,
235 | policy,
236 | model,
237 | base_params,
238 | decomposed_params,
239 | task_loader.target_metric_transfer,
240 | )
241 | transfer_data_dict["type"] = "transfer"
242 | # Log the results.
243 | if cfg.wandb_log:
244 | wandb.log(transfer_data_dict)
245 | with open(f"{log_dir}/eval_results.json", "w") as f:
246 | json.dump(transfer_data_dict, f, indent=4)
247 | print(f"Transfer evaluation results: {transfer_data_dict}")
248 |
249 | return
250 |
251 | # Non-adaptive evaluation on train, val, test set.
252 | if test_only and not prompt_based_eval:
253 | data_dict = {}
254 | details_dict = {}
255 | if has_training_split:
256 | train_res = eval_model(vllm_model, train_eval, train_ix)
257 | valid_res = eval_model(vllm_model, train_eval, valid_ix)
258 | data_dict["train_acc"] = train_res.aggregate_metrics[
259 | task_loader.target_metric_train
260 | ]
261 | data_dict["valid_acc"] = valid_res.aggregate_metrics[
262 | task_loader.target_metric_valid
263 | ]
264 | details_dict["train"] = train_res.sample_details
265 | details_dict["valid"] = valid_res.sample_details
266 | test_res = eval_model(vllm_model, test_eval)
267 | data_dict["test_acc"] = test_res.aggregate_metrics[
268 | task_loader.target_metric_test
269 | ]
270 | details_dict["test"] = test_res.sample_details
271 | if has_transfer_split:
272 | transfer_res = eval_model(vllm_model, transfer_eval)
273 | data_dict["transfer_acc"] = transfer_res.aggregate_metrics[
274 | task_loader.target_metric_transfer
275 | ]
276 | details_dict["transfer"] = transfer_res.sample_details
277 | if cfg.wandb_log:
278 | wandb.log(data_dict)
279 | with open(f"{log_dir}/eval_results.json", "w") as f:
280 | json.dump(data_dict, f, indent=4)
281 | print(f"Evaluation results: {data_dict}")
282 | return
283 |
284 | learnable_params = policy.get_learnable_params()
285 | for k in learnable_params:
286 | model.get_parameter(k).requires_grad_(True)
287 |
288 | # Training loop.
289 | if batch_size is None:
290 | clipped_batch_size = len(list(train_ix))
291 | else:
292 | clipped_batch_size = min(batch_size, len(list(train_ix)))
293 | best_val_acc = 0.0
294 | test_at_best = 0.0
295 | transfer_at_best = 0.0
296 | for i in range(num_iters):
297 |
298 | batch_ix = np_random.choice(train_ix, size=clipped_batch_size, replace=False)
299 |
300 | optimization_algorithm.step_optimization(
301 | model_id=model_id,
302 | model=model,
303 | tokenizer=tokenizer,
304 | policy=policy,
305 | task_loader=task_loader,
306 | batch_ix=batch_ix,
307 | train_data=train_data,
308 | train_eval=train_eval,
309 | base_params=base_params,
310 | decomposed_params=decomposed_params,
311 | original_model_params=original_model_params,
312 | metrics_to_log=metrics_to_log,
313 | vllm_model=vllm_model,
314 | )
315 |
316 | with torch.no_grad():
317 | lists_to_log = {}
318 | grads = [p.grad for p in policy.trainable_params]
319 | if grads[0] is not None:
320 | grad_mean = [g.mean().item() for g in grads]
321 | grad_mags = [torch.linalg.vector_norm(g).item() for g in grads]
322 | lists_to_log["grad_mean"] = grad_mean
323 | lists_to_log["grad_mags"] = grad_mags
324 |
325 | param_mags = [
326 | torch.linalg.vector_norm(p).item() for p in policy.trainable_params
327 | ]
328 | lists_to_log["policy_param_mag"] = param_mags
329 |
330 | generated_params_list = list(learnable_params.values())
331 |
332 | generated_param_mean = [p.mean().item() for p in generated_params_list]
333 | generated_param_mags = [
334 | torch.linalg.vector_norm(p).item() for p in generated_params_list
335 | ]
336 | lists_to_log["generated_param_mean"] = generated_param_mean
337 | lists_to_log["generated_param_mags"] = generated_param_mags
338 |
339 | list_stats = {}
340 | for k, v in lists_to_log.items():
341 | list_stats.update(get_mean_std_max_min_dict(array=v, prefix=k))
342 | metrics_to_log.update(**list_stats)
343 |
344 | optimization_algorithm.update(policy=policy)
345 |
346 | # Make sure old params are deleted and garbage-collected
347 | gc.collect()
348 | torch.cuda.empty_cache()
349 | model.zero_grad()
350 |
351 | # More accurate logging.
352 | value_mean = list_stats.get("generated_param_mean/mean", None)
353 | grad_mean_mag = list_stats.get("grad_mags/mean", None)
354 | print(
355 | f"Iter {i}: "
356 | + f"param_mean={value_mean}, "
357 | + f"grad_mean_mag={grad_mean_mag}"
358 | )
359 | optimization_algorithm.log_optim(metrics_to_log=metrics_to_log)
360 |
361 | # Test and save.
362 | if i % test_interval == 0:
363 | learnable_params = policy.get_learnable_params()
364 | forward(policy, model, base_params, decomposed_params, learnable_params)
365 | load_hf_params_to_vllm(model.state_dict(), vllm_model.llm)
366 |
367 | train_res = eval_model(vllm_model, train_eval, train_ix)
368 | valid_res = eval_model(vllm_model, train_eval, valid_ix)
369 | test_res = eval_model(vllm_model, test_eval)
370 | if has_transfer_split:
371 | transfer_res = eval_model(vllm_model, transfer_eval)
372 | if (
373 | valid_res.aggregate_metrics[task_loader.target_metric_valid]
374 | > best_val_acc
375 | ):
376 | best_val_acc = valid_res.aggregate_metrics[
377 | task_loader.target_metric_valid
378 | ]
379 | test_at_best = test_res.aggregate_metrics[
380 | task_loader.target_metric_test
381 | ]
382 | if has_transfer_split:
383 | transfer_at_best = transfer_res.aggregate_metrics[
384 | task_loader.target_metric_transfer
385 | ]
386 | print("best_val_acc updated")
387 | path = f"{log_dir}/policy_params.pt"
388 | torch.save(policy.state_dict(), path)
389 | if save_legacy_params:
390 | torch.save(learnable_params, f"{log_dir}/learnable_params.pt")
391 |
392 | path = f"{log_dir}/policy_params_latest.pt"
393 | torch.save(policy.state_dict(), path)
394 | if save_legacy_params:
395 | torch.save(learnable_params, f"{log_dir}/learnable_params_latest.pt")
396 |
397 | policy.record_state(metrics_to_log=metrics_to_log)
398 | data_dict = {
399 | "iter": i,
400 | "best_val_acc": best_val_acc,
401 | "test_at_best_val": test_at_best,
402 | "train_acc": train_res.aggregate_metrics[
403 | task_loader.target_metric_train
404 | ],
405 | "valid_acc": valid_res.aggregate_metrics[
406 | task_loader.target_metric_valid
407 | ],
408 | "test_acc": test_res.aggregate_metrics[task_loader.target_metric_test],
409 | **metrics_to_log.get(),
410 | }
411 | if has_transfer_split:
412 | data_dict["transfer_acc"] = transfer_res.aggregate_metrics[
413 | task_loader.target_metric_transfer
414 | ]
415 | data_dict["transfer_at_best_val"] = transfer_at_best
416 | if cfg.wandb_log:
417 | wandb.log(data_dict)
418 | with open(f"{log_dir}/reinforce_log.json", "a") as f:
419 | json_data = json.dumps(data_dict, indent=4)
420 | f.write(json_data)
421 | f.write("\n")
422 | metrics_to_log.reset()
423 |
424 |
425 | if __name__ == "__main__":
426 | main()
427 |
--------------------------------------------------------------------------------
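Note: the script above rests on a one-off SVD of every MLP weight (the extract_svd branch) followed, at every step, by recomposition with rescaled singular values (forward/compose_new_params in utils.py). A self-contained sketch of that round trip on a toy matrix; the rescaling factor mirrors compose_new_params:

import torch

W = torch.randn(6, 4)
U, S, V = torch.svd(W)                         # same call the script uses for the one-off decomposition
assert torch.allclose(U @ torch.diag(S) @ V.T, W, atol=1e-4)

m = torch.sigmoid(torch.zeros_like(S))         # a mask over singular values (all 0.5 here)
W_new = (U @ torch.diag(S * m) @ V.T) * (S.sum() / (S * m).sum())
print(W_new.shape)                             # torch.Size([6, 4]): same shape, rescaled spectrum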
/tasks/__init__.py:
--------------------------------------------------------------------------------
1 | from .arc import AI2ArcTask
2 | from .base import FewShotTask, Task
3 | from .cls import ClsTask
4 | from .gsm8k import Gsm8kTask
5 | from .math import MathTask
6 | from .mbpp2 import Mbpp2Task
7 |
--------------------------------------------------------------------------------
/tasks/arc.py:
--------------------------------------------------------------------------------
1 | import fishfarm
2 | import vllm
3 | from datasets import load_dataset
4 | from fishfarm.models.vllm_model import VLLMModel
5 | from fishfarm.tasks.ai2_arc import Ai2ArcSample, Ai2ArcTask
6 |
7 | from .base import LLAMA3_COT, Task, get_download_dir
8 |
9 | choices = ["A", "B", "C", "D", "E"]
10 |
11 |
12 | class AI2ArcTask(Task):
13 | def __init__(
14 | self,
15 | ):
16 | self.model_to_template = {
17 | "meta-llama/Meta-Llama-3-8B-Instruct": LLAMA3_COT,
18 | "mistralai/Mistral-7B-Instruct-v0.3": None,
19 | }
20 | self.system_msg = (
21 | "The following are multiple choice questions (with answers). "
22 | "Think step by step and then finish your answer "
23 | 'with "the answer is (X)" where X is the correct letter choice.'
24 | )
25 | self.target_metric_train = "acc"
26 | self.target_metric_valid = self.target_metric_train
27 | self.target_metric_test = self.target_metric_train
28 | self.target_metric_transfer = self.target_metric_train
29 | self.has_transfer_split = True
30 | self.has_training_split = True
31 |
32 | def get_train_data(
33 | self,
34 | ):
35 | train_eval, *test_evals = self.get_evaluator()
36 | train_data = train_eval.samples
37 | train_size = len(train_data)
38 | train_ix = range(0, train_size, 2)
39 | valid_ix = range(1, train_size, 2)
40 | return train_data, train_ix, valid_ix
41 |
42 | def get_rewards(self, res):
43 | rewards = [1.0 if x["correct"] else -1.0 for x in res.sample_details]
44 | return rewards
45 |
46 | def get_evaluator(
47 | self,
48 | ):
49 | res = []
50 | for split in ["train", "test"]:
51 | dataset = load_dataset("allenai/ai2_arc", "ARC-Easy", split=split)
52 | samples = []
53 | for sample in dataset:
54 | options = []
55 | for opt in sample["choices"]["text"]:
56 | options.append(opt)
57 | # add options to the question
58 | question = sample["question"] + "\n"
59 | question += "Options:\n"
60 | for i, opt in enumerate(options):
61 | question += "{}. {}\n".format(choices[i], opt)
62 | samples.append(
63 | Ai2ArcSample(
64 | question=question,
65 | answer=sample["answerKey"],
66 | options=options,
67 | question_id=sample["id"],
68 | )
69 | )
70 | res.append(
71 | Ai2ArcTask(
72 | samples=samples,
73 | context_messages=[
74 | fishfarm.Message("system", self.system_msg),
75 | ],
76 | )
77 | )
78 | dataset = load_dataset("allenai/ai2_arc", "ARC-Challenge", split="test")
79 | samples = []
80 | for sample in dataset:
81 | options = []
82 | for opt in sample["choices"]["text"]:
83 | options.append(opt)
84 | # add options to the question
85 | question = sample["question"] + "\n"
86 | question += "Options:\n"
87 | for i, opt in enumerate(options):
88 | question += "{}. {}\n".format(choices[i], opt)
89 | samples.append(
90 | Ai2ArcSample(
91 | question=question,
92 | answer=sample["answerKey"],
93 | options=options,
94 | question_id=sample["id"],
95 | )
96 | )
97 | res.append(
98 | Ai2ArcTask(
99 | samples=samples,
100 | context_messages=[
101 | fishfarm.Message("system", self.system_msg),
102 | ],
103 | )
104 | )
105 | return tuple(res)
106 |
107 | def get_prompt(self, tokenizer, samples, ix, model_id):
108 | chat_template = self.model_to_template[model_id]
109 | context_msg = {"role": "system", "content": self.system_msg}
110 | user_msg = {"role": "user", "content": samples[ix].question}
111 | prompt = tokenizer.apply_chat_template(
112 | conversation=[context_msg, user_msg],
113 | chat_template=chat_template,
114 | tokenize=False,
115 | add_generation_prompt=True,
116 | )
117 | return prompt
118 |
119 | def get_vllm_model(self, model_id) -> VLLMModel:
120 | """Load a vLLM model."""
121 | model = vllm.LLM(
122 | model_id,
123 | max_model_len=1024,
124 | gpu_memory_utilization=0.8,
125 | enforce_eager=True,
126 | dtype="bfloat16",
127 | download_dir=get_download_dir(),
128 | )
129 | chat_template = self.model_to_template[model_id]
130 | # This may change with vLLM versions.
131 | m = model.llm_engine.model_executor.driver_worker.model_runner.model
132 | for _, param in m.named_parameters():
133 | param.requires_grad = False
134 | vllm_model = VLLMModel(
135 | model,
136 | sampling_params=vllm.SamplingParams(
137 | temperature=0,
138 | top_p=1,
139 | max_tokens=512,
140 | stop=["Instruction:", "Instruction", "Response:", "Response"],
141 | repetition_penalty=1.0,
142 | ),
143 | chat_template=chat_template,
144 | )
145 | return vllm_model
146 |
--------------------------------------------------------------------------------
/tasks/base.py:
--------------------------------------------------------------------------------
1 | import os
2 | from abc import ABC, abstractmethod
3 | from typing import Union
4 |
5 | import hydra
6 | from omegaconf import DictConfig
7 |
8 | LLAMA3_COT = (
9 | "{% set loop_messages = messages %}"
10 | "{% for message in loop_messages %}"
11 | "{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>"
12 | "\n\n'+ message['content'] | trim + '<|eot_id|>' %}"
13 | "{% if loop.index0 == 0 %}{% set content = bos_token + content %}"
14 | "{% endif %}"
15 | "{{ content }}"
16 | "{% endfor %}"
17 | "{% if add_generation_prompt %}"
18 | "{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' + 'Let\\'s think step by step' }}"
19 | "{% endif %}"
20 | )
21 |
22 |
23 | CODE_PROMPT = r"""
24 | {% if messages[0]['role'] == 'system' %}
25 | {% set loop_messages = messages[1:] %}
26 | {% set system_message = messages[0]['content'].strip() + '\n\n' %}
27 | {% else %}
28 | {% set loop_messages = messages %}
29 | {% set system_message = '' %}
30 | {% endif %}
31 |
32 | {{ system_message }}
33 | {% for message in loop_messages %}
34 | {% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}
35 | {{ raise_exception(
36 | 'Conversation roles must alternate user/assistant/user/assistant/...')}}
37 | {% endif %}
38 |
39 | {% if message['role'] == 'user' %}
40 | {{ '@@ Instruction:\n' + message['content'].strip() + '\n\n' }}
41 | {% elif message['role'] == 'assistant' %}
42 | {{ '@@ Response:\n' + message['content'].strip() }}
43 | {% endif %}
44 |
45 | {% if loop.last and message['role'] == 'user' and add_generation_prompt %}
46 | {{ '@@ Response:' }}
47 | {% endif %}
48 | {% endfor %}
49 | """.replace(
50 | " ", ""
51 | ).replace(
52 | "\n", ""
53 | )
54 |
55 |
56 | def get_download_dir():
57 | if "HF_HOME" in os.environ:
58 | return os.environ["HF_HOME"] + "/models"
59 | else:
60 | return os.path.expanduser("~") + "/.cache/huggingface/models"
61 |
62 |
63 | class Task(ABC):
64 | def __init__(
65 | self,
66 | ):
67 | self.model_to_template = {}
68 | self.system_msg = ()
69 | self.target_metric_train = None
70 | self.target_metric_valid = self.target_metric_train
71 | self.target_metric_test = self.target_metric_train
72 | self.target_metric_transfer = None
73 | self.has_transfer_split = True
74 | self.has_training_split = True
75 |
76 | @abstractmethod
77 | def get_train_data(
78 | self,
79 | ):
80 | raise NotImplementedError
81 |
82 | @abstractmethod
83 | def get_rewards(self, res):
84 | raise NotImplementedError
85 |
86 | @abstractmethod
87 | def get_evaluator(
88 | self,
89 | ):
90 | raise NotImplementedError
91 |
92 | @abstractmethod
93 | def get_prompt(self, tokenizer, samples, ix, model_id):
94 | raise NotImplementedError
95 |
96 | @abstractmethod
97 | def get_vllm_model(self, model_id):
98 | raise NotImplementedError
99 |
100 |
101 | class FewShotTask(Task):
102 | def __init__(
103 | self,
104 | wrapped_task: Union[Task, DictConfig],
105 | wrapped_split: str = "test",
106 | shots=5,
107 | seed=16,
108 | ):
109 | if isinstance(wrapped_task, Task):
110 | self.wrapped_task: Task = wrapped_task
111 | else:
112 | self.wrapped_task: Task = hydra.utils.instantiate(wrapped_task)
113 |
114 | self.wrapped_split = wrapped_split
115 | self.shots = shots
116 | self.seed = seed
117 | self.model_to_template = wrapped_task.model_to_template
118 | self.system_msg = wrapped_task.system_msg
119 | if wrapped_split == "train":
120 | self.target_metric_train = wrapped_task.target_metric_train
121 | self.target_metric_valid = wrapped_task.target_metric_train
122 | self.target_metric_test = wrapped_task.target_metric_train
123 | assert wrapped_task.has_training_split
124 | elif wrapped_split == "test":
125 | self.target_metric_train = wrapped_task.target_metric_test
126 | self.target_metric_valid = wrapped_task.target_metric_test
127 | self.target_metric_test = wrapped_task.target_metric_test
128 | elif wrapped_split == "transfer":
129 | self.target_metric_train = wrapped_task.target_metric_transfer
130 | self.target_metric_valid = wrapped_task.target_metric_transfer
131 | self.target_metric_test = wrapped_task.target_metric_transfer
132 | assert wrapped_task.has_transfer_split
133 | else:
134 | raise NotImplementedError
135 | self.target_metric_transfer = wrapped_task.target_metric_transfer
136 | self.has_transfer_split = False
137 | self.has_training_split = True
138 |
139 | def get_train_data(
140 | self,
141 | ):
142 | train_eval, *test_evals = self.get_evaluator()
143 | train_data = train_eval.samples
144 | train_size = len(train_data)
145 | total_ix = list(range(train_size))
146 | import random
147 |
148 | random.seed(self.seed) # fix random seed for reproducibility
149 | random.shuffle(total_ix)
150 | train_ix = total_ix[: self.shots]
151 | valid_ix = total_ix[self.shots :]
152 | return train_data, train_ix, valid_ix
153 |
154 | def get_rewards(self, res):
155 | return self.wrapped_task.get_rewards(res=res)
156 |
157 | def get_evaluator(
158 | self,
159 | ):
160 | evaluators = self.wrapped_task.get_evaluator()
161 | if self.wrapped_split == "train":
162 | target_eval = evaluators[0]
163 | elif self.wrapped_split == "test":
164 | target_eval = evaluators[1]
165 | elif self.wrapped_split == "transfer":
166 | target_eval = evaluators[2]
167 | return target_eval, target_eval
168 |
169 | def get_prompt(self, tokenizer, samples, ix, model_id):
170 | return self.wrapped_task.get_prompt(
171 | tokenizer=tokenizer,
172 | samples=samples,
173 | ix=ix,
174 | model_id=model_id,
175 | )
176 |
177 | def get_vllm_model(self, model_id):
178 | return self.wrapped_task.get_vllm_model(
179 | model_id=model_id,
180 | )
181 |
--------------------------------------------------------------------------------
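Note: FewShotTask turns one split of an existing task into a tiny training set: the wrapped split is shuffled with a fixed seed, the first shots indices become the train set and the rest the validation set. A minimal direct construction, without Hydra (this downloads the ARC dataset on first use):

from tasks import AI2ArcTask, FewShotTask

few_shot = FewShotTask(wrapped_task=AI2ArcTask(), wrapped_split="test", shots=5, seed=16)
train_data, train_ix, valid_ix = few_shot.get_train_data()
print(len(train_ix), len(valid_ix))   # 5 shots for training, the remaining samples for validation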
/tasks/cls.py:
--------------------------------------------------------------------------------
1 | import re
2 | from dataclasses import dataclass
3 | from typing import Iterable, Tuple
4 |
5 | import datasets
6 | import fishfarm
7 | import vllm
8 | from fishfarm.models.vllm_model import VLLMModel
9 | from fishfarm.tasks.base import TaskResult
10 | from fishfarm.tasks.evalplus import load_dataset
11 |
12 | from .base import Task, get_download_dir
13 |
14 |
15 | def mean(iterable: Iterable[float]) -> float:
16 | total, count = 0.0, 0
17 | for x in iterable:
18 | total += x
19 | count += 1
20 | return total / count
21 |
22 |
23 | def extract_ans(text):
24 | """Fetch the string within \\boxed{}."""
25 | match = re.search(r"\\boxed{([^}]*)}", text)
26 | if match:
27 | return match.group(1) # Return the content inside the \boxed{}
28 | else:
29 | return None # Return None if no match is found
30 |
31 |
32 | @dataclass
33 | class CategorySample:
34 | question: str
35 | label: str
36 |
37 |
38 | class CategoryClassficiationTask(fishfarm.tasks.base.Task):
39 | def __init__(
40 | self,
41 | samples,
42 | context_messages,
43 | ):
44 | self.samples = list(samples)
45 | self.context_messages = context_messages
46 |
47 | @property
48 | def num_samples(self) -> int:
49 | return len(self.samples)
50 |
51 | def evaluate(
52 | self,
53 | model,
54 | sample_ids,
55 | ):
56 | if sample_ids is None:
57 | sample_ids = range(len(self.samples))
58 | samples = [self.samples[sample_id] for sample_id in sample_ids]
59 |
60 | requests = []
61 | for sample in samples:
62 | messages = list(self.context_messages)
63 | messages.append(fishfarm.Message(role="user", content=sample.question))
64 | requests.append(fishfarm.models.GenerationRequest(messages=messages))
65 |
66 | sample_details = []
67 | for sample, result in zip(samples, model.generate(requests)):
68 | output = result.generation
69 | prediction = extract_ans(output)
70 |
71 | sample_details.append(
72 | dict(
73 | question=sample.question,
74 | label=sample.label,
75 | output=output,
76 | prediction=prediction,
77 | correct=sample.label == prediction,
78 | )
79 | )
80 |
81 | aggregate_metrics = {
82 | "acc": mean(
83 | float(sd["correct"]) if isinstance(sd["correct"], (bool)) else 0.0
84 | for sd in sample_details
85 | )
86 | }
87 | return TaskResult(
88 | aggregate_metrics=aggregate_metrics, sample_details=sample_details
89 | )
90 |
91 |
92 | class ClsTask(Task):
93 | def __init__(self):
94 | self.model_to_template = {
95 | "meta-llama/Meta-Llama-3-8B-Instruct": (
96 | "{% set loop_messages = messages %}"
97 | "{% for message in loop_messages %}"
98 | "{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>"
99 | "\n\n'+ message['content'] | trim + '<|eot_id|>' %}"
100 | "{% if loop.index0 == 0 %}{% set content = bos_token + content %}"
101 | "{% endif %}"
102 | "{{ content }}"
103 | "{% endfor %}"
104 | "{% if add_generation_prompt %}"
105 | "{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}"
106 | "{% endif %}"
107 | ),
108 | "mistralai/Mistral-7B-Instruct-v0.3": None,
109 | }
110 | self.system_msg = """
111 | # Analyze the given question and classify it into one of four categories: 'code', 'math', 'reasoning' or 'other'. Follow these guidelines:
112 |
113 | 1. Code: Questions asking for programming solutions, functions, algorithms. Often includes specific programming terms, language syntax, or data structures.
114 | 2. Math: Questions involving mathematical calculations, formulas, statistics. Often includes numbers, equations, or mathematical operations.
115 | 3. Reasoning: Questions requiring logical thinking, application of scientific knowledge, or critical analysis of information. Often presents statements that need evaluation based on general understanding.
116 | 4. Other: Questions not clearly fit into above categories.
117 |
118 | Instructions:
119 | - Consider the primary focus, skills, and knowledge required to answer the question.
120 | - If a question spans multiple categories, choose the most dominant one.
121 | - Provide your final classification within \\boxed{} notation. Example: \\boxed{reasoning}
122 |
123 | Format your response as follows:
124 | Classification: \\boxed{category}
125 | """
126 | self.target_metric_train = "acc"
127 | self.target_metric_valid = self.target_metric_train
128 | self.target_metric_test = self.target_metric_train
129 | self.target_metric_transfer = self.target_metric_train
130 | self.has_transfer_split = False
131 | self.has_training_split = True
132 | self.num_samples_per_task = 400 # Hard code 400 samples per task
133 | self.task_datasets = [
134 | datasets.load_dataset("gsm8k", "main", split="test"),
135 | load_dataset(source_dataset="mbpp"),
136 | datasets.load_dataset("allenai/ai2_arc", "ARC-Challenge", split="test"),
137 | ]
138 | self.train_samples, self.test_samples = self.build_samples()
139 |
140 | def split_samples(self, samples):
141 | """Split samples into train and test sets with rate 4 : 1."""
142 | train_samples = []
143 | test_samples = []
144 | for i, sample in enumerate(samples):
145 | if i % 5 < 4:
146 | train_samples.append(sample)
147 | else:
148 | test_samples.append(sample)
149 | return train_samples, test_samples
150 |
151 | def get_train_data(self):
152 | train_ix = range(0, len(self.train_samples), 2)
153 | valid_ix = range(1, len(self.train_samples), 2)
154 |
155 | return self.train_samples, train_ix, valid_ix
156 |
157 | def build_samples(self):
158 | task_labels = ["math", "code", "reasoning"]
159 |
160 | samples = []
161 | choices = ["A", "B", "C", "D", "E"]
162 | for dataset, label in zip(self.task_datasets, task_labels):
163 | counter = 0
164 | for sample in dataset:
165 | counter += 1
166 | if counter >= self.num_samples_per_task:
167 | break
168 | if label == "math":
169 | samples.append(
170 | CategorySample(
171 | question=sample["question"],
172 | label="math",
173 | )
174 | )
175 | elif label == "code":
176 | samples.append(
177 | CategorySample(
178 | question=sample.instruction,
179 | label="code",
180 | )
181 | )
182 | else: # reasoning
183 | question = sample["question"] + "\n"
184 | question += "Options:\n"
185 | options = []
186 | for opt in sample["choices"]["text"]:
187 | options.append(opt)
188 | for i, opt in enumerate(options):
189 | question += "{}. {}\n".format(choices[i], opt)
190 | samples.append(
191 | CategorySample(
192 | question=question,
193 | label="reasoning",
194 | )
195 | )
196 |
197 | return self.split_samples(samples)
198 |
199 | def get_rewards(self, res):
200 | rewards = [1.0 if x["correct"] else -1.0 for x in res.sample_details]
201 | return rewards
202 |
203 | def get_evaluator(self) -> Tuple:
204 | # Build cls dataset here with training tasks.
205 | res = []
206 | for samples in [self.train_samples, self.test_samples]:
207 | res.append(
208 | CategoryClassficiationTask(
209 | samples=samples,
210 | context_messages=[
211 | fishfarm.Message("system", self.system_msg),
212 | ],
213 | )
214 | )
215 |
216 | return tuple(res)
217 |
218 | def get_prompt(self, tokenizer, samples, ix, model_id):
219 | chat_template = self.model_to_template[model_id]
220 | context_msg = {"role": "system", "content": self.system_msg}
221 | user_msg = {"role": "user", "content": samples[ix].question}
222 | prompt = tokenizer.apply_chat_template(
223 | conversation=[context_msg, user_msg],
224 | chat_template=chat_template,
225 | tokenize=False,
226 | add_generation_prompt=True,
227 | )
228 | return prompt
229 |
230 | def get_vllm_model(self, model_id) -> VLLMModel:
231 | """Load a vLLM model."""
232 | model = vllm.LLM(
233 | model_id,
234 | max_model_len=2048,
235 | gpu_memory_utilization=0.8,
236 | enforce_eager=True,
237 | dtype="bfloat16",
238 | download_dir=get_download_dir(),
239 | )
240 | chat_template = self.model_to_template[model_id]
241 | # This may change with vLLM versions.
242 | m = model.llm_engine.model_executor.driver_worker.model_runner.model
243 | for _, param in m.named_parameters():
244 | param.requires_grad = False
245 | vllm_model = VLLMModel(
246 | model,
247 | sampling_params=vllm.SamplingParams(
248 | temperature=0,
249 | top_p=1,
250 | max_tokens=1024,
251 | stop=["Instruction:", "Instruction", "Response:", "Response"],
252 | repetition_penalty=1.0,
253 | ),
254 | chat_template=chat_template,
255 | )
256 | return vllm_model
257 |
--------------------------------------------------------------------------------
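Note: the classifier above is graded by exact string match on whatever appears inside \boxed{...} in the model's output, so the only parsing step is extract_ans. A quick illustration with a made-up output string:

import re

def extract_ans(text):
    """Fetch the string within \\boxed{} (same regex as tasks/cls.py)."""
    match = re.search(r"\\boxed{([^}]*)}", text)
    return match.group(1) if match else None

print(extract_ans("Classification: \\boxed{math}"))   # -> math
print(extract_ans("no boxed answer here"))            # -> None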
/tasks/gsm8k.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import fishfarm
4 | import vllm
5 | from datasets import load_dataset
6 | from fishfarm.models.vllm_model import VLLMModel
7 | from fishfarm.tasks.language_restricted_math import (
8 | LanguageRestrictedMathTask, MathSample, extract_answer_number)
9 |
10 | from .base import Task, get_download_dir
11 |
12 |
13 | class Gsm8kTask(Task):
14 | def __init__(
15 | self,
16 | ):
17 | self.model_to_template = {
18 | "meta-llama/Meta-Llama-3-8B-Instruct": (
19 | "{% set loop_messages = messages %}"
20 | "{% for message in loop_messages %}"
21 | "{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>"
22 | "\n\n'+ message['content'] | trim + '<|eot_id|>' %}"
23 | "{% if loop.index0 == 0 %}{% set content = bos_token + content %}"
24 | "{% endif %}"
25 | "{{ content }}"
26 | "{% endfor %}"
27 | "{% if add_generation_prompt %}"
28 | "{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}"
29 | "{% endif %}"
30 | ),
31 | "mistralai/Mistral-7B-Instruct-v0.3": None,
32 | }
33 | self.system_msg = (
34 | "Below is an instruction that describes a task."
35 | " Write a response that appropriately completes the request.\n\n"
36 | )
37 |
38 | self.target_metric_train = "acc"
39 | self.target_metric_valid = self.target_metric_train
40 | self.target_metric_test = self.target_metric_train
41 | self.target_metric_transfer = self.target_metric_train
42 | self.has_transfer_split = False
43 | self.has_training_split = True
44 |
45 | def get_train_data(
46 | self,
47 | ):
48 | train_data = load_dataset("gsm8k", "main", split="train")
49 | train_size = len(train_data)
50 | train_ix = range(0, train_size, 2)
51 | valid_ix = range(1, train_size, 2)
52 | return train_data, train_ix, valid_ix
53 |
54 | def get_rewards(self, res):
55 | rewards = [1.0 if x["correct"] else -1.0 for x in res.sample_details]
56 | return rewards
57 |
58 | def get_evaluator(self) -> Tuple:
59 | res = []
60 | for split in ["train", "test"]:
61 | dataset = load_dataset("gsm8k", "main", split=split)
62 | samples = []
63 | for sample in dataset:
64 | answer = sample["answer"]
65 | answer = extract_answer_number(answer)
66 | answer = int(answer) if answer is not None else None
67 | samples.append(
68 | MathSample(
69 | problem=sample["question"],
70 | answer=answer,
71 | )
72 | )
73 | res.append(
74 | LanguageRestrictedMathTask(
75 | samples=samples,
76 | context_messages=[
77 | fishfarm.Message("system", self.system_msg),
78 | ],
79 | languages=[],
80 | )
81 | )
82 | return tuple(res)
83 |
84 | def get_prompt(self, tokenizer, samples, ix, model_id):
85 | chat_template = self.model_to_template[model_id]
86 | context_msg = {"role": "system", "content": self.system_msg}
87 | user_msg = {"role": "user", "content": samples["question"][ix]}
88 | prompt = tokenizer.apply_chat_template(
89 | conversation=[context_msg, user_msg],
90 | chat_template=chat_template,
91 | tokenize=False,
92 | add_generation_prompt=True,
93 | )
94 | return prompt
95 |
96 | def get_vllm_model(self, model_id) -> VLLMModel:
97 | """Load a vLLM model."""
98 | model = vllm.LLM(
99 | model_id,
100 | max_model_len=1024,
101 | gpu_memory_utilization=0.8,
102 | enforce_eager=True,
103 | dtype="bfloat16",
104 | download_dir=get_download_dir(),
105 | )
106 | chat_template = self.model_to_template[model_id]
107 | # This may change with vLLM versions.
108 | m = model.llm_engine.model_executor.driver_worker.model_runner.model
109 | for _, param in m.named_parameters():
110 | param.requires_grad = False
111 | vllm_model = VLLMModel(
112 | model,
113 | sampling_params=vllm.SamplingParams(
114 | temperature=0,
115 | top_p=1,
116 | max_tokens=512,
117 | stop=["Instruction:", "Instruction", "Response:", "Response"],
118 | repetition_penalty=1.0,
119 | ),
120 | chat_template=chat_template,
121 | )
122 | return vllm_model
123 |
--------------------------------------------------------------------------------
/tasks/math.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import fishfarm
4 | import vllm
5 | from datasets import load_dataset
6 | from fishfarm.models.vllm_model import VLLMModel
7 | from fishfarm.tasks.competation_math import (LatexFormatMathTask, MathSample,
8 | last_boxed_only_string,
9 | remove_boxed)
10 |
11 | from .base import Task, get_download_dir
12 |
13 |
14 | class MathTask(Task):
15 | def __init__(self):
16 | self.model_to_template = {
17 | "meta-llama/Meta-Llama-3-8B-Instruct": (
18 | "{% set loop_messages = messages %}"
19 | "{% for message in loop_messages %}"
20 | "{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>"
21 | "\n\n'+ message['content'] | trim + '<|eot_id|>' %}"
22 | "{% if loop.index0 == 0 %}{% set content = bos_token + content %}"
23 | "{% endif %}"
24 | "{{ content }}"
25 | "{% endfor %}"
26 | "{% if add_generation_prompt %}"
27 | "{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}"
28 | "{% endif %}"
29 | ),
30 | "mistralai/Mistral-7B-Instruct-v0.3": None,
31 | }
32 | self.system_msg = (
33 | "Solve the question below by reasoning step by step,"
34 | "and put the final answer within \\boxed{}."
35 | )
36 |
37 | self.target_metric_train = "acc"
38 | self.target_metric_valid = self.target_metric_train
39 | self.target_metric_test = self.target_metric_train
40 | self.target_metric_transfer = self.target_metric_train
41 | self.has_transfer_split = False
42 | self.has_training_split = False
43 |
44 | def get_train_data(
45 | self,
46 | ):
47 | return None, None, None
48 |
49 | def get_rewards(self, res):
50 | rewards = [1.0 if x["correct"] else -1.0 for x in res.sample_details]
51 | return rewards
52 |
53 | def get_evaluator(
54 | self,
55 | ) -> Tuple:
56 | res = [None]
57 | dataset = load_dataset("hendrycks/competition_math", "main", split="test")
58 |
59 | samples = []
60 | for sample in dataset:
61 | answer = remove_boxed(last_boxed_only_string(sample["solution"]))
62 | samples.append(
63 | MathSample(
64 | problem=sample["problem"], answer=answer, type=sample["type"]
65 | )
66 | )
67 |
68 | test_eval = LatexFormatMathTask(
69 | samples=samples,
70 | context_messages=[
71 | fishfarm.Message("system", self.system_msg),
72 | ],
73 | )
74 | res.append(test_eval)
75 | return tuple(res)
76 |
77 | def get_prompt(self, tokenizer, samples, ix, model_id):
78 | chat_template = self.model_to_template[model_id]
79 | context_msg = {"role": "system", "content": self.system_msg}
80 | user_msg = {"role": "user", "content": samples[ix].problem}
81 | prompt = tokenizer.apply_chat_template(
82 | conversation=[context_msg, user_msg],
83 | chat_template=chat_template,
84 | tokenize=False,
85 | add_generation_prompt=True,
86 | )
87 | return prompt
88 |
89 | def get_vllm_model(self, model_id) -> VLLMModel:
90 | """Load a vLLM model."""
91 | model = vllm.LLM(
92 | model_id,
93 | max_model_len=2048,
94 | gpu_memory_utilization=0.8,
95 | enforce_eager=True,
96 | dtype="bfloat16",
97 | download_dir=get_download_dir(),
98 | )
99 | chat_template = self.model_to_template[model_id]
100 | # This may change with vLLM versions.
101 | m = model.llm_engine.model_executor.driver_worker.model_runner.model
102 | for _, param in m.named_parameters():
103 | param.requires_grad = False
104 | vllm_model = VLLMModel(
105 | model,
106 | sampling_params=vllm.SamplingParams(
107 | temperature=0,
108 | top_p=1,
109 | max_tokens=1024,
110 | stop=["Instruction:", "Instruction", "Response:", "Response"],
111 | repetition_penalty=1.0,
112 | ),
113 | chat_template=chat_template,
114 | )
115 | return vllm_model
116 |
--------------------------------------------------------------------------------
/tasks/mbpp2.py:
--------------------------------------------------------------------------------
1 | import fishfarm
2 | import vllm
3 | from fishfarm.models.vllm_model import VLLMModel
4 | from fishfarm.tasks.evalplus import EvalplusTask, load_dataset
5 |
6 | from .base import CODE_PROMPT, Task, get_download_dir
7 |
8 |
9 | class Mbpp2Task(Task):
10 | def __init__(
11 | self,
12 | ):
13 | self.model_to_template = {
14 | "meta-llama/Meta-Llama-3-8B-Instruct": CODE_PROMPT,
15 | "mistralai/Mistral-7B-Instruct-v0.3": CODE_PROMPT,
16 | }
17 | self.system_msg = (
18 | "You are an exceptionally intelligent coding assistant that "
19 | " consistently delivers accurate and reliable responses to user "
20 | "instructions."
21 | )
22 |
23 | self.target_metric_train = "mbpp_base_pass@1"
24 | self.target_metric_valid = self.target_metric_train
25 | self.target_metric_test = self.target_metric_train
26 | self.target_metric_transfer = "humaneval_base_pass@1"
27 | self.has_transfer_split = True
28 | self.has_training_split = True
29 |
30 | def get_train_data(
31 | self,
32 | ):
33 | train_eval, *test_evals = self.get_evaluator()
34 | train_data = train_eval.samples
35 | train_size = len(train_data)
36 | total_ix = list(range(train_size))
37 | import random
38 |
39 | random.seed(16) # fix random seed for reproducibility
40 | random.shuffle(total_ix)
41 | train_ix = total_ix[:200]
42 | valid_ix = total_ix[200:]
43 | return train_data, train_ix, valid_ix
44 |
45 | def get_rewards(self, res):
46 | rewards = [1.0 if x["base_correct"] == 1 else -1.0 for x in res.sample_details]
47 | return rewards
48 |
49 | def get_evaluator(
50 | self,
51 | ):
52 | res = []
53 | samples = load_dataset(source_dataset="mbpp")
54 | res.append(
55 | EvalplusTask(
56 | samples[:300],
57 | context_messages=[
58 | fishfarm.Message("system", self.system_msg),
59 | ],
60 | source_dataset="mbpp",
61 | )
62 | )
63 | res.append(
64 | EvalplusTask(
65 | samples[300:],
66 | context_messages=[
67 | fishfarm.Message("system", self.system_msg),
68 | ],
69 | source_dataset="mbpp",
70 | )
71 | )
72 | samples = load_dataset(source_dataset="humaneval")
73 | res.append(
74 | EvalplusTask(
75 | samples,
76 | context_messages=[
77 | fishfarm.Message("system", self.system_msg),
78 | ],
79 | source_dataset="humaneval",
80 | )
81 | )
82 | return tuple(res)
83 |
84 | def get_prompt(self, tokenizer, samples, ix, model_id):
85 | chat_template = self.model_to_template[model_id]
86 | context_msg = {"role": "system", "content": self.system_msg}
87 | user_msg = {"role": "user", "content": samples[ix].instruction}
88 | assistant_msg = {"role": "assistant", "content": samples[ix].response_prefix}
89 | return tokenizer.apply_chat_template(
90 | conversation=[context_msg, user_msg, assistant_msg],
91 | chat_template=chat_template,
92 | tokenize=False,
93 | add_generation_prompt=True,
94 | )
95 |
96 | def get_vllm_model(self, model_id) -> VLLMModel:
97 | """Load a vLLM model."""
98 | model = vllm.LLM(
99 | model_id,
100 | max_model_len=1024,
101 | gpu_memory_utilization=0.8,
102 | enforce_eager=True,
103 | dtype="bfloat16",
104 | download_dir=get_download_dir(),
105 | )
106 | chat_template = self.model_to_template[model_id]
107 | # This may change with vLLM versions.
108 | m = model.llm_engine.model_executor.driver_worker.model_runner.model
109 | for _, param in m.named_parameters():
110 | param.requires_grad = False
111 | vllm_model = VLLMModel(
112 | model,
113 | sampling_params=vllm.SamplingParams(
114 | temperature=0,
115 | top_p=1,
116 | max_tokens=512,
117 | stop=["Instruction:", "Instruction", "Response:", "Response"],
118 | repetition_penalty=1.0,
119 | ),
120 | chat_template=chat_template,
121 | )
122 | return vllm_model
123 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import re
2 | from copy import deepcopy
3 | from typing import Dict, Optional
4 |
5 | import fishfarm
6 | import torch
7 | import torch.utils
8 | import vllm
9 |
10 |
11 | def load_hf_params_to_vllm(param: Dict, llm: vllm.LLM) -> None:
12 | """Load weights from HF transformer model to vLLM model."""
13 |
14 | model = llm.llm_engine.model_executor.driver_worker.model_runner.model
15 | num_layers = model.config.num_hidden_layers
16 |
17 | # Load embeddings layer weights.
18 | model_param = model.get_parameter("model.embed_tokens.weight")
19 | model_param.copy_(
20 | param["model.embed_tokens.weight"][: model_param.shape[0]]
21 | .to(model_param.dtype)
22 | .to(model_param.device)
23 | )
24 | model_param = model.get_parameter("lm_head.weight")
25 | model_param.copy_(
26 | param["lm_head.weight"][: model_param.shape[0]]
27 | .to(model_param.dtype)
28 | .to(model_param.device)
29 | )
30 |
31 | # Load the final layernorm weights.
32 | model_param = model.get_parameter("model.norm.weight")
33 | model_param.copy_(
34 | param["model.norm.weight"].to(model_param.dtype).to(model_param.device)
35 | )
36 |
37 | for i in range(num_layers):
38 | # Load qkv_proj weights.
39 | model_param = model.get_parameter(f"model.layers.{i}.self_attn.qkv_proj.weight")
40 | model_param.copy_(
41 | torch.cat(
42 | [
43 | param[f"model.layers.{i}.self_attn.q_proj.weight"],
44 | param[f"model.layers.{i}.self_attn.k_proj.weight"],
45 | param[f"model.layers.{i}.self_attn.v_proj.weight"],
46 | ],
47 | dim=0,
48 | )
49 | .to(model_param.dtype)
50 | .to(model_param.device)
51 | )
52 | # Load gate_up_proj weights.
53 | model_param = model.get_parameter(f"model.layers.{i}.mlp.gate_up_proj.weight")
54 | model_param.copy_(
55 | torch.cat(
56 | [
57 | param[f"model.layers.{i}.mlp.gate_proj.weight"],
58 | param[f"model.layers.{i}.mlp.up_proj.weight"],
59 | ],
60 | dim=0,
61 | )
62 | .to(model_param.dtype)
63 | .to(model_param.device)
64 | )
65 | # Load o_proj and down_proj weights.
66 | model_param = model.get_parameter(f"model.layers.{i}.self_attn.o_proj.weight")
67 | model_param.copy_(
68 | param[f"model.layers.{i}.self_attn.o_proj.weight"]
69 | .to(model_param.dtype)
70 | .to(model_param.device)
71 | )
72 | model_param = model.get_parameter(f"model.layers.{i}.mlp.down_proj.weight")
73 | model_param.copy_(
74 | param[f"model.layers.{i}.mlp.down_proj.weight"]
75 | .to(model_param.dtype)
76 | .to(model_param.device)
77 | )
78 | # Load layer_norm weights.
79 | model_param = model.get_parameter(f"model.layers.{i}.input_layernorm.weight")
80 | model_param.copy_(
81 | param[f"model.layers.{i}.input_layernorm.weight"]
82 | .to(model_param.dtype)
83 | .to(model_param.device)
84 | )
85 | model_param = model.get_parameter(
86 | f"model.layers.{i}.post_attention_layernorm.weight"
87 | )
88 | model_param.copy_(
89 | param[f"model.layers.{i}.post_attention_layernorm.weight"]
90 | .to(model_param.dtype)
91 | .to(model_param.device)
92 | )
93 |
94 |
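# Hedged usage sketch (the checkpoint path and model id are hypothetical; the state dict
# must use HF parameter names and match the architecture served by vLLM):
# llm = vllm.LLM("mistralai/Mistral-7B-Instruct-v0.3", dtype="bfloat16", enforce_eager=True)
# hf_state_dict = torch.load("finetuned_state_dict.pt", map_location="cpu")
# with torch.no_grad():
#     load_hf_params_to_vllm(hf_state_dict, llm)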
95 | def eval_model(vllm_model, evaluator, ix=None):
96 | result = evaluator.evaluate(vllm_model, sample_ids=ix)
97 | return result
98 |
99 |
100 | def compose_new_params(
101 | policy,
102 | param_name,
103 | decomposed_params,
104 | learnable_params,
105 | ):
106 | """Compose new parameters from decomposed parameters."""
107 | mm = policy.get_mask(learnable_params[param_name])
108 | return (
109 | decomposed_params[f"{param_name}.U"]
110 | @ torch.diag_embed(decomposed_params[f"{param_name}.S"] * mm)
111 | @ decomposed_params[f"{param_name}.V"].T
112 | ) * (
113 | decomposed_params[f"{param_name}.S"].sum()
114 | / (decomposed_params[f"{param_name}.S"] * mm).sum()
115 | )
116 |
117 |
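# Minimal sketch of the decomposition this assumes (shapes illustrative): decomposed_params
# is expected to hold an SVD of each weight, W ≈ U @ diag(S) @ V.T, so masking S and then
# rescaling by S.sum() / (S * mask).sum() keeps the total singular mass unchanged.
# W = torch.randn(64, 32)
# U, S, Vh = torch.linalg.svd(W, full_matrices=False)  # Vh corresponds to V.T above
# mask = (torch.rand_like(S) > 0.5).float()
# W_new = (U @ torch.diag_embed(S * mask) @ Vh) * (S.sum() / (S * mask).sum())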
118 | @torch.no_grad()
119 | def forward(policy, model, base_params, decomposed_params, learnable_params):
120 | """Forward pass."""
121 | new_params = {}
122 | for k in base_params:
123 | if "mlp" in k:
124 | new_params[k] = compose_new_params(
125 | policy, k, decomposed_params, learnable_params
126 | )
127 | model.get_parameter(k).copy_(new_params[k])
128 | else:
129 | new_params[k] = base_params[k]
130 | return new_params
131 |
132 |
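# Note: only parameter names containing "mlp" are recomposed; attention and norm weights
# keep their base values. With HF naming (as used in load_hf_params_to_vllm above), the
# recomposed keys look like:
#   model.layers.0.mlp.gate_proj.weight
#   model.layers.0.mlp.up_proj.weight
#   model.layers.0.mlp.down_proj.weight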
133 | @torch.no_grad()
134 | def load_base_params(
135 | model,
136 | base_params,
137 | ):
138 | for k in base_params:
139 | if "mlp" in k:
140 | model.get_parameter(k).copy_(base_params[k].cuda())
141 |
142 |
143 | def backward(
144 | policy,
145 | model,
146 | base_params,
147 | decomposed_params,
148 | learnable_params,
149 | ):
150 | """Backward pass."""
151 | keys_to_backprop = [k for k in base_params if "mlp" in k]
152 | last_key = keys_to_backprop[-1]
153 | for k in keys_to_backprop[:-1]:
154 | compose_new_params(policy, k, decomposed_params, learnable_params).backward(
155 | model.get_parameter(k).grad, retain_graph=True
156 | )
157 | # release graph
158 | compose_new_params(policy, last_key, decomposed_params, learnable_params).backward(
159 | model.get_parameter(last_key).grad, retain_graph=False
160 | )
161 |
162 |
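# Hedged sketch of how forward/backward could be paired in one gradient step (the optimizer
# and loss wiring are illustrative; the actual training loop lives elsewhere in this repo):
# learnable_params = policy.get_learnable_params()
# forward(policy, model, base_params, decomposed_params, learnable_params)
# loss = ...  # any objective that populates .grad on the model's MLP weight matrices
# loss.backward()
# backward(policy, model, base_params, decomposed_params, learnable_params)
# optimizer.step()
# optimizer.zero_grad()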
163 | def classify_samples(vllm_model, test_eval):
164 | """Classify samples."""
165 |
166 | CLASSIFICATION_PROMPT = """
167 | # Analyze the given question and classify it into one of four categories: 'code', 'math', 'reasoning' or 'other'. Follow these guidelines:
168 |
169 | 1. Code: Questions asking for programming solutions, functions, algorithms. Often includes specific programming terms, language syntax, or data structures.
170 | 2. Math: Questions involving mathematical calculations, formulas, statistics. Often includes numbers, equations, or mathematical operations.
171 | 3. Reasoning: Questions requiring logical thinking, application of scientific knowledge, or critical analysis of information. Often presents statements that need evaluation based on general understanding.
172 | 4. Other: Questions that do not clearly fit into the above categories.
173 |
174 | Instructions:
175 | - Consider the primary focus, skills, and knowledge required to answer the question.
176 | - If a question spans multiple categories, choose the most dominant one.
177 | - Provide your final classification within \\boxed{} notation. Example: \\boxed{reasoning}
178 |
179 | Format your response as follows:
180 | Classification: \\boxed{category}
181 | """
182 |
183 | def extract_classification(text: str) -> Optional[str]:
184 | """
185 | Extract the classification from the model's output using regex.
186 | """
187 | match = re.search(r"\\boxed{([^}]*)}", text)
188 | return match.group(1) if match else None
189 |
190 | # Identify the key in the samples that contains the problem text
191 | problem_key = None
192 | for key in ("problem", "question", "instruction"):
193 | if (
194 | hasattr(test_eval.samples[0], key)
195 | and getattr(test_eval.samples[0], key) is not None
196 | ):
197 | problem_key = key
198 | break
199 | assert problem_key is not None, "Could not find problem text in the samples"
200 |
201 | # Prepare classification requests
202 | classification_requests = [
203 | fishfarm.models.GenerationRequest(
204 | messages=[
205 | fishfarm.Message("system", CLASSIFICATION_PROMPT),
206 | fishfarm.Message("user", getattr(sample, problem_key)),
207 | ]
208 | )
209 | for sample in test_eval.samples
210 | ]
211 |
212 | # Generate classifications using the model
213 | model_outputs = vllm_model.generate(classification_requests)
214 |
215 | # Process results and update samples
216 | classified_samples = []
217 | for sample, result in zip(test_eval.samples, model_outputs):
218 | prediction = extract_classification(result.generation)
219 | if prediction not in ["code", "math", "reasoning"]:
220 | prediction = "other"
221 | sample.expert_label = prediction
222 | classified_samples.append(sample)
223 |
224 | return classified_samples
225 |
226 |
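# The boxed-answer parsing above works like this (strings illustrative):
# m = re.search(r"\\boxed{([^}]*)}", "Classification: \\boxed{math}")
# m.group(1)  # -> "math"
# When no \boxed{...} is present, re.search returns None and the sample falls back to "other".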
227 | def eval_model_experts_prompt_based(
228 | vllm_model,
229 | evaluator,
230 | experts_path_dict,
231 | policy,
232 | model,
233 | base_params,
234 | decomposed_params,
235 | task_metric,
236 | ):
237 | """Evaluate the model using expert models and prompt-based classification."""
238 | results_by_expert: Dict[str, Dict] = {}
239 |
240 | # Classify all test samples
241 | classified_samples = classify_samples(vllm_model, evaluator)
242 |
243 | # Evaluate samples for each expert model
244 | for expert_label, expert_model_path in experts_path_dict.items():
245 | # Filter samples for current expert
246 | expert_samples = [
247 | sample
248 | for sample in classified_samples
249 | if sample.expert_label == expert_label
250 | ]
251 | if not expert_samples:
252 | continue
253 |
254 | # Update test evaluation with filtered samples
255 | evaluator.samples = expert_samples
256 |
257 | # Load and apply expert model parameters if available
258 | if expert_model_path:
259 | policy.load_state_dict(torch.load(expert_model_path))
260 | expert_params = policy.get_learnable_params()
261 | updated_params = forward(
262 | policy=policy,
263 | model=model,
264 | base_params=base_params,
265 | decomposed_params=decomposed_params,
266 | learnable_params=expert_params,
267 | )
268 | load_hf_params_to_vllm(updated_params, vllm_model.llm)
269 |
270 | # Evaluate current expert model
271 | evaluation_results = eval_model(vllm_model, evaluator)
272 |
273 | # Store results for current expert
274 | results_by_expert[expert_label] = {
275 | "num_samples": len(expert_samples),
276 | "test_acc": evaluation_results.aggregate_metrics[task_metric],
277 | }
278 |
279 | # Compute the overall accuracy as a sample-weighted average across experts.
280 | data_dict = deepcopy(results_by_expert)
281 | data_dict["final_test_acc"] = 0.0
282 | for label in results_by_expert.keys():
283 | data_dict["final_test_acc"] += (
284 | results_by_expert[label]["test_acc"]
285 | * results_by_expert[label]["num_samples"]
286 | )
287 | data_dict["final_test_acc"] /= len(classified_samples)
288 |
289 | return data_dict
290 |
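# Illustrative shape of the returned dictionary (labels and numbers hypothetical;
# final_test_acc is the sample-weighted average over all classified samples):
# {
#     "math": {"num_samples": 120, "test_acc": 0.62},
#     "code": {"num_samples": 80, "test_acc": 0.41},
#     "final_test_acc": 0.536,
# }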
--------------------------------------------------------------------------------