├── README.md ├── assets └── general_fig.png ├── eval.py ├── launch_experiment.py ├── pred.py ├── qalign ├── __init__.py ├── rm │ ├── __init__.py │ ├── create_dataset.py │ ├── data │ │ ├── dataset_info.json │ │ ├── llama-factory_gsm8k_llama-3.1-8b-instruct_128_1_train.json │ │ ├── llama-factory_gsm8k_llama-3.1-8b-instruct_test.json │ │ ├── llama-factory_gsm8k_test.json │ │ ├── raw_gsm8k_llama-3.1-8b-instruct_128_1_train.json │ │ ├── raw_gsm8k_llama-3.1-8b-instruct_test.json │ │ └── raw_gsm8k_test.json │ ├── train_configs │ │ ├── deepspeed │ │ │ ├── ds_z0_config.json │ │ │ ├── ds_z2_config.json │ │ │ ├── ds_z2_offload_config.json │ │ │ ├── ds_z3_config.json │ │ │ └── ds_z3_offload_config.json │ │ ├── explain.txt │ │ ├── gsm8k │ │ │ ├── full │ │ │ │ ├── gemma2_full_dpo.yaml │ │ │ │ ├── gemma2_full_dpo05.yaml │ │ │ │ ├── gemma2_full_dpo10.yaml │ │ │ │ ├── gemma2_full_dpo_sft05.yaml │ │ │ │ ├── gemma2_full_dpo_sft10.yaml │ │ │ │ ├── gemma2_full_reward.yaml │ │ │ │ ├── llama3_1b1b_full_reward.yaml │ │ │ │ ├── llama3_1b8b_full_reward.yaml │ │ │ │ ├── llama3_3b1b_full_reward.yaml │ │ │ │ ├── llama3_3b8b_full_reward.yaml │ │ │ │ ├── llama3_3b_full_reward.yaml │ │ │ │ ├── llama3_8b1b_full_reward.yaml │ │ │ │ ├── llama3_8b8b_full_dpo.yaml │ │ │ │ ├── llama3_8b8b_full_reward.yaml │ │ │ │ ├── olmo_full_reward.yaml │ │ │ │ ├── train.sh │ │ │ │ └── tulu_8b8b_full_reward.yaml │ │ │ └── lora │ │ │ │ ├── gemma2_lora_dpo.yaml │ │ │ │ ├── gemma2_lora_reward.yaml │ │ │ │ ├── llama2_lora_reward.yaml │ │ │ │ ├── llama2_lora_sft.yaml │ │ │ │ └── olmo_lora_reward.yaml │ │ └── math │ │ │ └── full │ │ │ ├── llama3_1b8b_full_reward_math.yaml │ │ │ ├── llama3_8b8b_full_reward_math.yaml │ │ │ ├── llama3_8b8b_full_reward_math_cot.yaml │ │ │ └── tulu_8b8b_full_reward_math.yaml │ └── valid_check.py └── utils │ ├── data.py │ ├── eval.py │ ├── examples.py │ ├── experiment.py │ ├── ifeval │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── eval.cpython-38.pyc │ │ ├── instructions.cpython-38.pyc │ │ ├── instructions_util.cpython-38.pyc │ │ └── registry.cpython-38.pyc │ ├── eval.py │ ├── instructions.py │ ├── instructions_util.py │ └── registry.py │ ├── math.py │ ├── mbr.py │ └── pred.py ├── resume_experiment.py ├── resume_experiment_remote.py └── scripts ├── create_all_general_experiments.sh ├── create_all_task_experiments.sh ├── run_eval.sh ├── run_local_experiments.sh ├── run_pred.sh ├── run_remote_experiments.sh └── run_rm_eval.sh /README.md: -------------------------------------------------------------------------------- 1 | # Sample, Don't Search: Rethinking Test-Time Alignment for Language Models 2 | 3 | Gonçalo Faria, Noah A. Smith 4 | 5 | **Paper**: https://arxiv.org/abs/2504.03790 6 | 7 | **TL;DR:** QAlign is a new test-time alignment approach that improves language model performance by using Markov chain Monte Carlo methods. 8 | 9 | ### Abstract: 10 | Increasing test-time computation has emerged as a promising direction for improving language model performance, particularly in scenarios where model finetuning is impractical or impossible due to computational constraints or private model weights. However, existing test-time search methods using a reward model (RM) often degrade in quality as compute scales, due to the over-optimization of what are inherently imperfect reward proxies. We introduce QAlign, a new test-time alignment approach. As we scale test-time compute, QAlign converges to sampling from the optimal aligned distribution for each individual prompt. 
By adopting recent advances in Markov chain Monte Carlo for text generation, our method enables better-aligned outputs without modifying the underlying model or even requiring logit access. We demonstrate the effectiveness of QAlign on mathematical reasoning benchmarks (GSM8K and GSM-Symbolic) using a task-specific RM, showing consistent improvements over existing test-time compute methods like best-of-n and majority voting. Furthermore, when applied with more realistic RMs trained on the Tulu 3 preference dataset, QAlign outperforms direct preference optimization (DPO), best-of-n, majority voting, and weighted majority voting on a diverse range of datasets (GSM8K, MATH500, IFEval, MMLU-Redux, and TruthfulQA). A practical solution to aligning language models at test time using additional computation without degradation, our approach expands the limits of the capability that can be obtained from off-the-shelf language models without further training. 11 | 12 | 13 | 14 | ![General Alignment Experiments](assets/general_fig.png) 15 |

Average error rate across multiple evaluation datasets (GSM8K, MATH500, MMLU-Redux, TruthfulQA, and IFEval) as a function of floating point operations (FLOPs), shown on a log scale. 16 | We compare QAlign applied to Tülu3-8B-SFT against four baselines: majority voting (MV) with Tülu3-8B-DPO, and best-of-n (BoN), MV, and weighted MV (WMV) applied to Tülu3-8B-SFT. All experiments use temperature 1.0 with reasoning included in model outputs. Note that the Tülu3-8B-DPO model is the result of preference finetuning Tülu3-8B-SFT on 271k preference pairs; the cost of that training is not accounted for in this plot.
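As a point of reference for the `beta` parameter used throughout the code: QAlign's sampling target is the standard KL-regularized ("optimal aligned") distribution from the RLHF literature. The sketch below uses the usual convention; the exact parameterization of $\beta$ is defined in the paper.

$$\pi_\beta(y \mid x) \;\propto\; p_{\text{LM}}(y \mid x)\,\exp\!\left(\tfrac{r(x,y)}{\beta}\right)$$

Here $p_{\text{LM}}$ is the off-the-shelf language model and $r$ is the reward model. QAlign runs an MCMC chain (via `quest-decoding`) whose stationary distribution is this target, so additional test-time compute moves samples toward $\pi_\beta$ instead of over-optimizing an imperfect reward proxy.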

17 | 18 | 19 | ----- 20 | ## Dependencies
21 | 22 | This project relies heavily on the following external libraries: 23 | - [deepspin/quest-decoding](https://github.com/deep-spin/quest-decoding) 24 | - [goncalorafaria/expkit](https://github.com/goncalorafaria/expkit-core) 25 | - [goncalorafaria/literegistry](https://github.com/goncalorafaria/literegistry) 26 | 27 | ```bash 28 | pip install quest-decoding 29 | pip install expkit-core 30 | pip install literegistry 31 | ``` 32 | 33 | Then install the remaining required packages: 34 | ```bash 35 | pip install -r requirements.txt 36 | ``` 37 | 38 | ----- 39 | ## Reproducing the work
40 | 41 | To replicate the results from the paper, follow the steps below: 42 | 43 | ### Experiment Setup 44 | 1. **Create Configuration Files** 45 | ```bash 46 | # Create configs for general experiments 47 | scripts/create_all_general_experiments.sh 48 | 49 | # Create configs for task-specific experiments 50 | scripts/create_all_task_experiments.sh 51 | ``` 52 | 53 | ### Running Experiments 54 | 2. **Execute Experiments** 55 | ```bash 56 | # Run experiments locally 57 | scripts/run_local_experiments.sh 58 | 59 | # Run experiments on remote server 60 | scripts/run_remote_experiments.sh 61 | ``` 62 | 63 | ### Evaluation & Analysis 64 | 3. **Evaluate Results** 65 | ```bash 66 | # Compare responses against ground truth answers 67 | scripts/run_eval.sh 68 | 69 | # Evaluate reward model for ancestral predictions (remote by default) 70 | scripts/run_rm_eval.sh 71 | ``` 72 | 73 | 4. **Generate Final Predictions** 74 | ```bash 75 | # Run WMV, BoN, and MV final prediction methods 76 | scripts/run_pred.sh 77 | ``` 78 | 79 | 80 | ----- 81 | 82 | ## Quick Start
83 | 84 | This guide will help you get started running QAlign. 85 | 86 | ## Basic Usage 87 | 88 | ```python 89 | import os 90 | from quest.core import Quest 91 | from quest.reward.model import ContextualRewardModel, ValueHead 92 | from quest.qalign import QAlign 93 | from quest.model.vllm import VLLM 94 | 95 | # Model configuration 96 | model_path = "allenai/Llama-3.1-Tulu-3-8B-SFT" 97 | model_args = { 98 | "model_path": model_path, 99 | "download_dir": os.environ.get("HF_HOME", "/tmp/"), 100 | "stop_tokens": ["", "<|im_end|>"], 101 | "temperature": 0.7, 102 | "gpu_memory_utilization": 0.9, 103 | "dtype": "bfloat16", 104 | "max_new_tokens": 512, 105 | "max_prompt_length": 4096, 106 | "tensor_parallel_size": 1, # Number of GPUs 107 | "enable_prefix_caching": True, 108 | "enforce_eager": True, 109 | } 110 | 111 | # Initialize the model 112 | model = VLLM(**model_args) 113 | 114 | # Initialize the reward model 115 | reward = ContextualRewardModel( 116 | model_path="allenai/Llama-3.1-Tulu-3-8B-RM", 117 | device=1, # use the second GPU 118 | device_count=1, 119 | ) 120 | 121 | # Prepare your data 122 | data_batch = [ 123 | { 124 | "prompt": "<|user|>\nJanet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?\n<|assistant|>\n" 125 | }, 126 | # Add more examples as needed 127 | ] 128 | 129 | # Create the Markov chain 130 | chain = QAlign( 131 | input_data=data_batch, 132 | model=model, 133 | reward=reward, 134 | beta=1.0, # Controls exploration vs exploitation 135 | ) 136 | 137 | # Run 138 | chain_outputs = chain.run( 139 | steps=10, # Number of MCMC steps 140 | use_tqdm=True, # Show progress bar 141 | ) 142 | 143 | # Print the accepted outputs 144 | print(f"Original prompt: {chain_outputs[0]['input']['prompt']}") 145 | for output in chain_outputs[0]["outputs"]: 146 | if output["accept"]: 147 | print(f"Response: {output['text']}") 148 | print("-" * 50) 149 | 150 | ``` 151 | 152 |
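## Using a Task-Specific Reward Model

The example above uses the general Tülu 3 reward model. If you train a task-specific reward model with a scalar value head (this is what the LLaMA-Factory configs under `qalign/rm/train_configs/` produce), you can swap in `ValueHead`, which is already imported above. The snippet below is a minimal sketch rather than a tested recipe: the checkpoint path is a placeholder, and the constructor arguments mirror how `ValueHead` is instantiated in `eval.py`.

```python
# Placeholder path: point this at a value-head reward checkpoint, e.g. one trained
# with the LLaMA-Factory configs in qalign/rm/train_configs/.
task_rm_path = "saves/llama3/8b8b/full/reward"

# Same keyword arguments as the ValueHead usage in eval.py; GPU placement may need
# adjusting for your setup.
reward = ValueHead(
    model_path=task_rm_path,
    batch_size=4,
    device_count=1,
)

# The rest of the pipeline is unchanged: build the chain with the new reward and run it.
chain = QAlign(
    input_data=data_batch,
    model=model,
    reward=reward,
    beta=1.0,
)
chain_outputs = chain.run(steps=10, use_tqdm=True)
```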
----- 153 | 154 | ## Contact
155 | 156 | For bugs and feature requests, please visit [GitHub Issues](https://github.com/goncalorafaria/qalign/issues). For business inquiries or 157 | professional support requests, please send an [e-mail](mailto:goncalofaria.research@gmail.com). 158 | 159 | ----- 160 | 161 | ## Citation
162 | 163 | ```` 164 | @misc{faria2025sampledontsearchrethinking, 165 | title={Sample, Don't Search: Rethinking Test-Time Alignment for Language Models}, 166 | author={Gonçalo Faria and Noah A. Smith}, 167 | year={2025}, 168 | eprint={2504.03790}, 169 | archivePrefix={arXiv}, 170 | primaryClass={cs.CL}, 171 | url={https://arxiv.org/abs/2504.03790}, 172 | } 173 | ```` 174 | 175 | -------------------------------------------------------------------------------- /assets/general_fig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/goncalorafaria/qalign/39b1df52c913bda4f94e656018f68915fd06299b/assets/general_fig.png -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | 3 | ## quest 4 | from quest.reward.model import ValueHead 5 | from quest.reward.remote import RemoteReward, RemoteReward 6 | from quest import ( 7 | RewardModel, 8 | ContextualRewardModel, 9 | ) 10 | 11 | ## expkit 12 | from expkit.setup import ExpSetup 13 | from expkit.storage import DiskStorage 14 | 15 | ## qalign 16 | from qalign.utils.eval import * 17 | 18 | ## literegistry 19 | from literegistry import RegistryClient, FileSystemKVStore 20 | 21 | 22 | def main( 23 | base_dir="remote-outputs-llama/", 24 | reward_model_path="lastnumber", 25 | model_path="allenai/tulu-2-7b", 26 | batch_size=248, 27 | context=True, 28 | query_args={}, 29 | device_count=8, 30 | value_head=True, 31 | remote=False, 32 | n=None, 33 | ): 34 | 35 | print("Query Args:", query_args) 36 | setup = ExpSetup(storage=DiskStorage(base_dir=base_dir, mode="rw")) 37 | # print("Exp:", setup.query(query_args)) 38 | 39 | setup = setup.query(query_args).filter(lambda x: x.has_data()) 40 | 41 | print("That match the query:\n", setup) 42 | 43 | if len(setup.experiments) == 0: 44 | raise FileNotFoundError("The experiment has no data!") 45 | 46 | if reward_model_path == "likelihood": 47 | 48 | ps_eval = LikelihoodEval(model_path=model_path) 49 | 50 | elif reward_model_path == "lastnumber": 51 | ps_eval = ExactLastNumberEval() 52 | 53 | elif reward_model_path == "lastmath": 54 | ps_eval = ExactMATHEval() 55 | 56 | elif reward_model_path == "lastoption": 57 | ps_eval = ExactQAEval() 58 | 59 | elif reward_model_path == "ifeval": 60 | ps_eval = IFEval() 61 | 62 | elif remote: 63 | 64 | if value_head: 65 | reward_type = "value" 66 | elif context: 67 | reward_type = "contextual" 68 | else: 69 | reward_type = "reward" 70 | 71 | registry = RegistryClient( 72 | store=FileSystemKVStore("/gscratch/ark/graf/registry"), 73 | max_history=3600, 74 | cache_ttl=60, 75 | service_type="model_path", 76 | ) 77 | 78 | reward = RemoteReward( 79 | model_path=reward_model_path, 80 | registry=registry, 81 | reward_type=reward_type, 82 | # batch_size=batch_size, 83 | # max_parallel_requests=32, 84 | batch_size=32, 85 | max_parallel_requests=64, 86 | ) 87 | ps_eval = RewardEval( 88 | reward=reward, 89 | n=n, 90 | chunk_size=256, 91 | ) 92 | elif value_head: 93 | 94 | reward = ValueHead( 95 | model_path=reward_model_path, 96 | batch_size=batch_size, 97 | device_count=device_count, 98 | ) 99 | 100 | ps_eval = RewardEval(reward=reward, n=n) 101 | 102 | else: 103 | if context: 104 | reward = ContextualRewardModel( 105 | model_path=reward_model_path, 106 | batch_size=batch_size, 107 | device_count=device_count, 108 | ) 109 | 110 | else: 111 | reward = RewardModel( 112 | 
model_path=reward_model_path, 113 | batch_size=batch_size, 114 | device_count=device_count, 115 | ) 116 | 117 | ps_eval = RewardEval(reward=reward, n=n) 118 | 119 | # setup = setup.filter(lambda x: not x.has_eval(ps_eval.eval_name)) 120 | # setup = setup.filter(lambda x: x.get("n") == len(x.instances())) 121 | print("That haven't done the eval:", setup) 122 | 123 | def func(experiment): 124 | 125 | try: 126 | return ( 127 | ps_eval(experiment) 128 | # if not experiment.has_eval(ps_eval.eval_name) 129 | # else experiment 130 | ) 131 | except FileNotFoundError: 132 | return experiment 133 | 134 | except Exception as e: 135 | raise e 136 | # return experiment 137 | 138 | setup = setup.map(func) 139 | 140 | # new_setup.save() 141 | 142 | 143 | if __name__ == "__main__": 144 | 145 | import fire 146 | 147 | fire.Fire(main) 148 | -------------------------------------------------------------------------------- /launch_experiment.py: -------------------------------------------------------------------------------- 1 | from qalign.utils.experiment import create_experiment 2 | 3 | 4 | def main( 5 | variant="quest-rlhf", 6 | beta: float = 1.0, 7 | steps: int = 4096, 8 | temperature: float = 1.0, 9 | n: int = 64, 10 | model_path: str = "meta-llama/Llama-3.1-8B-Instruct", 11 | reward_model_path: str = "/gscratch/ark/graf/quest-rlhf/qflow/rm/artifacts/llama3/8b8b/mathcotfix/full/reward", # "/gscratch/ark/graf/LLaMA-Factory/saves/llama3/8b/full/reward/", 12 | reward_model_batch_size: int = 4, 13 | save_path: str = "remote-outputs-llama/", 14 | gpu_memory_utilization: float = 0.8, 15 | max_new_tokens: int = 800, 16 | max_prompt_length: int = 1200, 17 | dataset_path="HuggingFaceH4/MATH-500", # "apple/GSM-Symbolic-p1", 18 | batch_size=64, 19 | stop_tokens=[], 20 | prompt_template="{prompt}", 21 | num_chains: int = 1, 22 | use_few_shot: bool = False, 23 | format: str = "chat", 24 | split: str = "test", 25 | remote: bool = True, 26 | reward_type="value", 27 | ): 28 | 29 | additional_meta = { 30 | "remote": remote, 31 | } 32 | 33 | if variant == "quest-rlhf": 34 | 35 | # Create experiment with additional meta data 36 | additional_meta = { 37 | "beta": beta, 38 | "reward_model_path": reward_model_path, 39 | "reward_type": reward_type, 40 | # "i": start_index, 41 | "num_chains": num_chains, 42 | **additional_meta, 43 | } 44 | 45 | experiment = create_experiment( 46 | save_path=save_path, 47 | variant=variant, 48 | model_path=model_path, 49 | dataset_path=dataset_path, 50 | n=n, 51 | temperature=temperature, 52 | steps=steps, 53 | max_new_tokens=max_new_tokens, 54 | max_prompt_length=max_prompt_length, 55 | stop_tokens=stop_tokens, 56 | prompt_template=prompt_template, 57 | additional_meta=additional_meta, 58 | batch_size=batch_size, 59 | use_few_shot=use_few_shot, 60 | format=format, 61 | split=split, 62 | ) 63 | 64 | 65 | if __name__ == "__main__": 66 | import fire 67 | 68 | fire.Fire(main) 69 | -------------------------------------------------------------------------------- /pred.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | from multiprocessing import Pool 3 | 4 | 5 | ## expkit 6 | from expkit.setup import ExpSetup 7 | from expkit.storage import DiskStorage 8 | 9 | ## qalign 10 | from qalign.utils.pred import * 11 | from qalign.utils.eval import * 12 | 13 | 14 | def main( 15 | base_dir="remote-outputs/", # "llama3.2-outputs/", # "tqa-outputs/", 16 | strategy="voting", 17 | key="crm:allenai-Llama-3", # 
"vh:-gscratch-ark-graf-quest-rlhf-qflow-rm-artifacts-tulu-8b8b-gsm8k-full-reward-", # "vh:-gscratch-ark-graf-quest-rlhf-qflow-rm-artifacts-tulu-8b8b-math-full-reward-", # "vh:-gscratch-ark-graf-quest-rlhf-qflow-rm-artifacts-tulu-8b8b-gsm8k-full-reward-", # "crm:allenai-Llama-3", # "vh:-gscratch-ark-graf-quest-rlhf-qflow-rm-artifacts-llama3-8b8b-mathcot-full-reward", # "vh:-gscratch-ark-graf-quest-rlhf-qflow-rm-artifacts-llama3-8b8b-math-full-reward-", # "vh:-gscratch-ark-graf-quest-rlhf-qflow-rm-artifacts-llama3-8b8b-math-full-reward-", # "vh:-gscratch-ark-graf-LLaMA-Factory-saves-llama3-8b8b-full-reward-", # "vh:-gscratch-ark-graf-LLaMA-Factory-saves-llama3-3b-full-reward-", 18 | query_args={ 19 | # "steps": 4096, 20 | # "split": "test", 21 | # "split": "validation", 22 | "temperature": 1.0, 23 | # "beta": 0.5, 24 | # "model_path": "meta-llama/Llama-3.1-8B-Instruct", 25 | "dataset": "HuggingFaceH4/MATH-500", # "HuggingFaceH4/MATH-500", # "openai/gsm8k", # HuggingFaceH4/MATH-500 26 | # "dataset": "lighteval/MATH", 27 | # "model_path": "allenai/Llama-3.1-Tulu-3-8B-SFT", 28 | # "reward_model_path": "/gscratch/ark/graf/LLaMA-Factory/saves/llama3/8b1b/full/reward/", 29 | # "reward_model_path": "allenai/Llama-3.1-Tulu-3-8B-RM", 30 | # "variant": "ancestral", 31 | # "n": 128, 32 | }, 33 | beta: float = 1.0, 34 | c: float = 1.0, 35 | extract="lastmath", 36 | trials=512, 37 | r=32, 38 | gaps=64, 39 | exp_rate=False, 40 | n=None, 41 | ): 42 | 43 | setup = ( 44 | ExpSetup(storage=DiskStorage(base_dir=base_dir, mode="rw")) 45 | .query(query_args) 46 | .filter(lambda x: x.has_data()) 47 | ) 48 | 49 | print(setup) 50 | # setup.experiments[0].evals() 51 | 52 | print( 53 | len(setup.experiments), 54 | ) 55 | if len(setup.experiments) == 0: 56 | raise FileNotFoundError("The configuration has no data!") 57 | 58 | # strategy, reward_model_path = strategy.split("-") 59 | 60 | with Pool() as p: 61 | pick = get_strategy( 62 | strategy, 63 | key=key, 64 | p=p, 65 | beta=beta, 66 | c=c, 67 | extract=extract, 68 | trials=trials, 69 | r=r, 70 | gaps=gaps, 71 | exp_rate=exp_rate, 72 | n=n, 73 | ) # msft , nvdia 74 | 75 | # setup = setup.filter(lambda x: not x.has_eval(pick.eval_name)) 76 | 77 | def func(experiment): 78 | 79 | try: 80 | # if len(experiment.instances()) == experiment.meta["n"]: 81 | 82 | # print(len(experiment.instances())) 83 | return ( 84 | pick(experiment) 85 | # if not experiment.has_eval(pick.eval_name) 86 | # else experiment 87 | ) 88 | # else: 89 | # return experiment 90 | 91 | except FileNotFoundError: 92 | print(experiment.name) 93 | return experiment 94 | 95 | except Exception as e: 96 | print(experiment.name) 97 | print(e) 98 | raise e 99 | return experiment 100 | 101 | setup = setup.map(func) 102 | 103 | # new_setup.save() 104 | 105 | 106 | if __name__ == "__main__": 107 | 108 | import fire 109 | 110 | fire.Fire(main) 111 | -------------------------------------------------------------------------------- /qalign/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/goncalorafaria/qalign/39b1df52c913bda4f94e656018f68915fd06299b/qalign/__init__.py -------------------------------------------------------------------------------- /qalign/rm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/goncalorafaria/qalign/39b1df52c913bda4f94e656018f68915fd06299b/qalign/rm/__init__.py 
-------------------------------------------------------------------------------- /qalign/rm/create_dataset.py: -------------------------------------------------------------------------------- 1 | from expkit import ( 2 | ExpSetup, 3 | ) 4 | import logging 5 | from expkit.storage import CachedRO, ZipStorage, DiskStorage 6 | import os 7 | import shutil 8 | import numpy as np 9 | import re 10 | import json 11 | 12 | from transformers import AutoTokenizer 13 | 14 | from qflow.utils.data import get_data_iterable, general_process_data 15 | from qflow.utils.math import get_last_number 16 | from datasets import load_dataset 17 | 18 | from qflow.utils.data import FlexiblePromptTemplate 19 | 20 | 21 | def sample_pairs(positive, negatives, n=1): 22 | nmin = min(len(positive), len(negatives), n) 23 | 24 | if nmin > 0: 25 | np.random.shuffle(positive) 26 | np.random.shuffle(negatives) 27 | 28 | if len(positive) < n: 29 | positive = np.random.choice( 30 | positive, 31 | size=n, 32 | replace=True, 33 | ) 34 | if len(negatives) < n: 35 | negatives = np.random.choice( 36 | negatives, 37 | size=n, 38 | replace=True, 39 | ) 40 | 41 | return [ 42 | { 43 | "chosen": c, 44 | "rejected": r, 45 | } 46 | for c, r in zip( 47 | positive[:n], 48 | negatives[:n], 49 | ) 50 | ] 51 | else: 52 | return [] 53 | 54 | 55 | def fake_synth_data(i, po): 56 | 57 | number = get_last_number(i["input"]["answer"]) 58 | 59 | alt_results = [ 60 | np.round( 61 | np.random.rand() * float(number), 62 | 1, 63 | ) 64 | for _ in range(len(po)) 65 | ] 66 | 67 | if "." not in number: 68 | alt_results = np.ceil(alt_results).astype(int) 69 | 70 | no = [ 71 | answer.replace( 72 | number, 73 | str(pseudo_result), 74 | ) 75 | for answer, pseudo_result in zip(po, alt_results) 76 | ] 77 | 78 | return no 79 | 80 | 81 | def linearize(setup, oracle_key="lastnumber"): 82 | dataset = [] 83 | evals = [] 84 | 85 | if len(setup) > 1: 86 | 87 | setup = setup.unique("i").sort("i") 88 | 89 | for e in setup: 90 | try: 91 | print("Processing", e.meta.get("i", 0)) 92 | print("elements:", len(e.instances())) 93 | print("evals:", len(e.get_eval(oracle_key))) 94 | 95 | dataset.extend(e.instances()) 96 | evals.extend(e.get_eval(oracle_key)) 97 | 98 | except Exception as e: 99 | logging.error("An error occurred during linearization:", e) 100 | pass 101 | 102 | return {"data": dataset, "evals": evals} 103 | 104 | 105 | def remove_tags(text, tags): 106 | ptext = re.sub("|".join(map(re.escape, tags)), "", text) 107 | 108 | return ptext 109 | 110 | 111 | def create_positive_and_negative_pairs(dataset_stack, stop_tokens): 112 | dataset = [] 113 | answers = [] 114 | 115 | for i, gt in zip(dataset_stack["data"], dataset_stack["evals"]): 116 | 117 | values = np.array(gt["scores"]) 118 | pind = np.where(values == 1)[0] # extract index of correct 119 | nind = np.where(values == 0)[0] # extract index of incorrect 120 | 121 | po = [ 122 | remove_tags(i["outputs"][j]["text"], stop_tokens) for j in pind 123 | ] # get text without special tokens 124 | no = [ 125 | remove_tags(i["outputs"][j]["text"], stop_tokens) for j in nind 126 | ] # get text without special tokens 127 | 128 | all_neg, all_po = ( 129 | len(po) == 0, 130 | len(no) == 0, 131 | ) 132 | 133 | if all_po: ## if all positive 134 | no = fake_synth_data(i, po) # create fake wrongs .. hopefully this is rare. 
135 | 136 | dataset.append((po + [i["input"]["answer"]], no)) 137 | 138 | return dataset 139 | 140 | 141 | def sample_pairs_dataset_chat( 142 | split_dataset, 143 | epochs=1, 144 | dataset_path="openai/gsm8k", 145 | model_path="allenai/OLMo-7B-0724-Instruct-hf", 146 | prompt_template="{prompt}", 147 | ): 148 | dataset = [] 149 | 150 | for (po, no), i in zip( 151 | split_dataset, 152 | get_data_iterable( 153 | model_path=model_path, 154 | dataset_path=dataset_path, 155 | split="train", 156 | prompt_template=prompt_template, 157 | ), 158 | ): 159 | 160 | chat_template = i["chat_template_prompt"] 161 | 162 | po, no = list(set(po)), list(set(no)) 163 | for pair in sample_pairs( 164 | po, 165 | no, 166 | n=epochs, 167 | ): 168 | 169 | dataset.append( 170 | { 171 | k: chat_template 172 | + [ 173 | { 174 | "role": "assistant", 175 | "content": v, 176 | } 177 | ] 178 | for k, v in pair.items() 179 | } 180 | ) 181 | 182 | return dataset 183 | 184 | 185 | def sample_pairs_dataset_prompt( 186 | split_dataset, 187 | epochs=1, 188 | dataset_path="openai/gsm8k", 189 | model_path="allenai/OLMo-7B-0724-Instruct-hf", 190 | use_few_shot=False, 191 | ): 192 | dataset = [] 193 | 194 | for (po, no), i in zip( 195 | split_dataset, 196 | get_data_iterable( 197 | model_path=model_path, 198 | dataset_path=dataset_path, 199 | split="train", 200 | format="prompt", 201 | use_few_shot=use_few_shot, 202 | ), 203 | ): 204 | 205 | chat_template = [ 206 | { 207 | "role": "user", 208 | "content": i["prompt"], 209 | } 210 | ] 211 | 212 | po, no = list(set(po)), list(set(no)) 213 | for pair in sample_pairs( 214 | po, 215 | no, 216 | n=epochs, 217 | ): 218 | 219 | dataset.append( 220 | { 221 | k: chat_template 222 | + [ 223 | { 224 | "role": "assistant", 225 | "content": v, 226 | } 227 | ] 228 | for k, v in pair.items() 229 | } 230 | ) 231 | 232 | return dataset 233 | 234 | 235 | def create_rm_data_test_pairs( 236 | model_path, 237 | dataset_path="openai/gsm8k", 238 | ): # this fake synth data .. is not representative ... 239 | dataset = [] 240 | # templates = [] 241 | 242 | for i in get_data_iterable(model_path=model_path, dataset_path=dataset_path): 243 | 244 | chat_template = i["chat_template_prompt"] 245 | 246 | # templates.append(chat_template) 247 | number = get_last_number(i["answer"]) 248 | 249 | alt_number = np.round( 250 | np.random.rand() * float(number), 251 | 1, 252 | ) 253 | 254 | if "." 
not in number: 255 | alt_number = np.ceil(alt_number).astype(int) 256 | 257 | wrong_answer = i["answer"].replace( 258 | number, 259 | str(alt_number), 260 | ) 261 | 262 | answer = i["answer"] 263 | 264 | dataset.append( 265 | { 266 | "chosen": chat_template 267 | + [ 268 | { 269 | "role": "assistant", 270 | "content": answer, 271 | } 272 | ], 273 | "rejected": chat_template 274 | + [ 275 | { 276 | "role": "assistant", 277 | "content": wrong_answer, 278 | } 279 | ], 280 | } 281 | ) 282 | 283 | return dataset 284 | 285 | 286 | def convert_llamafactory_format(x): 287 | 288 | return { 289 | "instruction": x["chosen"][0]["content"], 290 | "chosen": x["chosen"][-1]["content"], 291 | "rejected": x["rejected"][-1]["content"], 292 | } 293 | 294 | 295 | # /gscratch/ark/graf/rmlearn/data 296 | 297 | 298 | def get_sharded_setup( 299 | storage, 300 | model_path, 301 | temperature=1.0, 302 | steps=64, 303 | dataset="openai/gsm8k", 304 | oracle_key="lastnumber", 305 | prompt_template="{prompt}", 306 | ): 307 | setup = ( 308 | ExpSetup(storage).query( 309 | { 310 | "steps": steps, 311 | "temperature": temperature, 312 | # "prompt_template": "Solve the following grade school math problem step-by-step: {prompt}", 313 | "model_path": model_path, 314 | "split": "train", 315 | "dataset": dataset, 316 | # "prompt_template": prompt_template, 317 | # "n": n, 318 | } 319 | ) 320 | # .filter(lambda x: "i" in x.meta) 321 | # .filter(lambda x: x.islocked()) 322 | ) 323 | 324 | if len(setup) == 0: 325 | raise ValueError( 326 | f"No data found for model {model_path} and dataset {dataset} with steps {steps} and temperature {temperature}" 327 | ) 328 | 329 | setup = setup.filter(lambda x: x.has_eval(oracle_key)) 330 | setup = setup.filter(lambda x: x.get("n") == len(x.instances())) 331 | 332 | if len(setup) == 0: 333 | raise ValueError(f"No eval data with key:{oracle_key} found.") 334 | 335 | setup = setup.sort("n", reverse=False) 336 | 337 | dataset_stack = linearize(setup, oracle_key=oracle_key) 338 | 339 | return dataset_stack 340 | 341 | 342 | def file_name(dataset, model_path, steps=128, epochs=1, format="raw", split="train"): 343 | d = dataset.split("/")[-1].lower() 344 | m = model_path.split("/")[-1].lower() 345 | 346 | if split == "test": 347 | return f"{format}_{d}_{m}_{split}.json" 348 | else: 349 | return f"{format}_{d}_{m}_{steps}_{epochs}_{split}.json" 350 | 351 | 352 | def generate_rm_dataset( 353 | model_path="meta-llama/Llama-3.1-8B-Instruct", 354 | base_dir="outputs/llama3gsm8ktrain/", 355 | dataset="openai/gsm8k", 356 | save_path="/gscratch/ark/graf/quest-rlhf/qflow/rm/data/", 357 | epochs=[1, 2, 4, 8], 358 | steps=128, 359 | temperature=1.0, 360 | stop_tokens=None, 361 | oracle_key="lastnumber", 362 | prompt_template="{prompt}", 363 | prefix="", 364 | ): 365 | 366 | tokenizer = AutoTokenizer.from_pretrained(model_path) 367 | 368 | if stop_tokens is None: 369 | stop_tokens = [tokenizer.eos_token] 370 | else: 371 | stop_tokens = stop_tokens + [tokenizer.eos_token] 372 | 373 | try: 374 | storage = ZipStorage( 375 | base_dir=base_dir, 376 | mode="r", 377 | ) 378 | except ValueError as e: 379 | storage = DiskStorage( 380 | base_dir=base_dir, 381 | mode="r", 382 | ) 383 | 384 | storage = CachedRO(storage) 385 | 386 | dataset_stack = get_sharded_setup( 387 | storage, 388 | model_path=model_path, 389 | temperature=temperature, 390 | steps=steps, 391 | dataset=dataset, 392 | oracle_key=oracle_key, 393 | prompt_template=prompt_template, 394 | ) 395 | 396 | all_pairs = create_positive_and_negative_pairs( 397 | 
dataset_stack, 398 | stop_tokens=stop_tokens, 399 | ) 400 | 401 | for ep in epochs: 402 | sampled_pairs = sample_pairs_dataset_chat( 403 | all_pairs, 404 | epochs=ep, 405 | model_path=model_path, 406 | dataset_path=dataset, 407 | prompt_template=prompt_template, 408 | # use_few_shot=True, 409 | ) 410 | 411 | sampled_pairs_factory = [convert_llamafactory_format(x) for x in sampled_pairs] 412 | 413 | with open( 414 | save_path 415 | + prefix 416 | + file_name( 417 | dataset, 418 | model_path, 419 | steps, 420 | ep, 421 | format="raw", 422 | ), 423 | "w", 424 | ) as fn: 425 | json.dump(sampled_pairs, fn) 426 | 427 | with open( 428 | save_path 429 | + prefix 430 | + file_name( 431 | dataset, 432 | model_path, 433 | steps, 434 | ep, 435 | format="llama-factory", 436 | ), 437 | "w", 438 | ) as f: 439 | json.dump(sampled_pairs_factory, f) 440 | 441 | print(f"Saved {prefix+file_name(dataset, model_path, steps, ep, format='raw')}") 442 | print( 443 | f"Saved {prefix+file_name(dataset, model_path, steps, ep, format='llama-factory')}" 444 | ) 445 | 446 | 447 | def generate_rm_test_dataset( 448 | model_path="meta-llama/Llama-3.2-3B-Instruct", 449 | base_dir="outputs/llama3gsm8ktrain/", 450 | dataset="openai/gsm8k", 451 | save_path="/gscratch/ark/graf/quest-rlhf/qflow/rm/data/", 452 | ): 453 | 454 | test_dataset = create_rm_data_test_pairs( 455 | model_path=model_path, dataset_path=dataset 456 | ) 457 | 458 | with open( 459 | save_path + file_name(dataset, model_path, format="raw", split="test"), 460 | "w", 461 | ) as fn: 462 | json.dump(test_dataset, fn) 463 | 464 | test_data = [convert_llamafactory_format(x) for x in test_dataset] 465 | 466 | with open( 467 | save_path 468 | + file_name(dataset, model_path, format="llama-factory", split="test"), 469 | "w", 470 | ) as fn: 471 | json.dump(test_data, fn) 472 | 473 | 474 | def main( 475 | model_path="meta-llama/Llama-3.1-8B-Instruct", # "allenai/Llama-3.1-Tulu-3-8B-SFT", # "allenai/Llama-3.1-Tulu-3-8B-SFT", # "meta-llama/Llama-3.1-8B-Instruct", 476 | dataset="lighteval/MATH", # "lighteval/MATH", # "openai/gsm8k", 477 | save_path="/gscratch/ark/graf/quest-rlhf/qflow/rm/data/", 478 | base_dir="remote-outputs-llama/", # "outputs/llama3gsm8ktrain/", 479 | steps=128, 480 | prompt_template="Solve the following math problem step-by-step: {prompt}\n\nPresent the answer in LaTex format: \\boxed{Your answer}", # "Solve the following math problem step-by-step: {prompt}", 481 | prefix="cotnewextractmath14325", 482 | ): # 90d44980-8ad9-48f2-8c34-99747991f571 483 | 484 | dataset_oracles = { 485 | "openai/gsm8k": "lastnumber", 486 | "lighteval/MATH": "lastmath", 487 | } 488 | 489 | generate_rm_dataset( 490 | model_path=model_path, 491 | base_dir=base_dir, 492 | dataset=dataset, 493 | save_path=save_path, 494 | epochs=[1, 2, 4, 8], 495 | steps=steps, 496 | temperature=1.0, 497 | oracle_key=dataset_oracles[dataset], 498 | prompt_template=FlexiblePromptTemplate(prompt_template), 499 | prefix=prefix, 500 | ) 501 | 502 | """generate_rm_test_dataset( 503 | model_path=model_path, 504 | base_dir=base_dir, 505 | dataset=dataset, 506 | save_path=save_path, 507 | )""" # this has to improved. We need to generate from the base model for the test set. 
508 | 509 | info = {} 510 | for fn in os.listdir(save_path): 511 | 512 | if "llama-factory" not in fn: 513 | continue 514 | 515 | key = fn.replace(".json", "") 516 | info[key] = { 517 | "file_name": fn, 518 | "ranking": True, 519 | "columns": { 520 | "prompt": "instruction", 521 | "chosen": "chosen", 522 | "rejected": "rejected", 523 | }, 524 | } 525 | 526 | with open(save_path + "dataset_info.json", "w") as f: 527 | json.dump(info, f) 528 | 529 | 530 | if __name__ == "__main__": 531 | main() 532 | -------------------------------------------------------------------------------- /qalign/rm/data/dataset_info.json: -------------------------------------------------------------------------------- 1 | {"llama-factory_gsm8k_llama-3.1-8b-instruct_128_4_train": {"file_name": "llama-factory_gsm8k_llama-3.1-8b-instruct_128_4_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_gsm8k_llama-3.2-3b-instruct_128_2_train": {"file_name": "llama-factory_gsm8k_llama-3.2-3b-instruct_128_2_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "cotnewextractmath14325llama-factory_math_llama-3.1-8b-instruct_128_1_train": {"file_name": "cotnewextractmath14325llama-factory_math_llama-3.1-8b-instruct_128_1_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_gsm8k_llama-3.1-tulu-3-8b-sft_64_2_train": {"file_name": "llama-factory_gsm8k_llama-3.1-tulu-3-8b-sft_64_2_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "cotllama-factory_math_llama-3.1-tulu-3-8b-sft_128_1_train": {"file_name": "cotllama-factory_math_llama-3.1-tulu-3-8b-sft_128_1_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "cotllama-factory_gsm8k_llama-3.1-tulu-3-8b-sft_64_4_train": {"file_name": "cotllama-factory_gsm8k_llama-3.1-tulu-3-8b-sft_64_4_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_gsm8k_llama-3.1-tulu-3-8b-sft_64_1_train": {"file_name": "llama-factory_gsm8k_llama-3.1-tulu-3-8b-sft_64_1_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_math_llama-3.1-tulu-3-8b-sft_128_1_train": {"file_name": "llama-factory_math_llama-3.1-tulu-3-8b-sft_128_1_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_gsm8k_llama-3.1-8b-instruct_128_1_train": {"file_name": "llama-factory_gsm8k_llama-3.1-8b-instruct_128_1_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "cotnewextractmath14325llama-factory_math_llama-3.1-8b-instruct_128_8_train": {"file_name": "cotnewextractmath14325llama-factory_math_llama-3.1-8b-instruct_128_8_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_gsm8k_llama-3.2-1b-instruct_128_1_train": {"file_name": "llama-factory_gsm8k_llama-3.2-1b-instruct_128_1_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_math_llama-3.1-8b_128_4_train": {"file_name": "llama-factory_math_llama-3.1-8b_128_4_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": 
"chosen", "rejected": "rejected"}}, "llama-factory_math_llama-3.1-8b_128_8_train": {"file_name": "llama-factory_math_llama-3.1-8b_128_8_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_math_llama-3.1-tulu-3-8b-sft_128_4_train": {"file_name": "llama-factory_math_llama-3.1-tulu-3-8b-sft_128_4_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_math_llama-3.1-8b-instruct_128_4_train": {"file_name": "llama-factory_math_llama-3.1-8b-instruct_128_4_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_math_llama-3.2-1b-instruct_128_1_train": {"file_name": "llama-factory_math_llama-3.2-1b-instruct_128_1_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_gsm8k_test": {"file_name": "llama-factory_gsm8k_test.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_math_llama-3.2-1b-instruct_128_8_train": {"file_name": "llama-factory_math_llama-3.2-1b-instruct_128_8_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_math_llama-3.1-8b_128_1_train": {"file_name": "llama-factory_math_llama-3.1-8b_128_1_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_math_llama-3.1-8b-instruct_128_1_train": {"file_name": "llama-factory_math_llama-3.1-8b-instruct_128_1_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_gsm8k_llama-3.2-3b-instruct_128_1_train": {"file_name": "llama-factory_gsm8k_llama-3.2-3b-instruct_128_1_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_gsm8k_llama-3.2-1b-instruct_128_2_train": {"file_name": "llama-factory_gsm8k_llama-3.2-1b-instruct_128_2_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_gsm8k_llama-3.2-1b-instruct_128_8_train": {"file_name": "llama-factory_gsm8k_llama-3.2-1b-instruct_128_8_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_math_llama-3.2-1b-instruct_128_4_train": {"file_name": "llama-factory_math_llama-3.2-1b-instruct_128_4_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_gsm8k_llama-3.2-3b-instruct_128_8_train": {"file_name": "llama-factory_gsm8k_llama-3.2-3b-instruct_128_8_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_gsm8k_llama-3.2-1b-instruct_128_4_train": {"file_name": "llama-factory_gsm8k_llama-3.2-1b-instruct_128_4_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "cotnewextractmath14325llama-factory_math_llama-3.1-8b-instruct_128_2_train": {"file_name": "cotnewextractmath14325llama-factory_math_llama-3.1-8b-instruct_128_2_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_math_llama-3.1-8b-instruct_128_2_train": {"file_name": 
"llama-factory_math_llama-3.1-8b-instruct_128_2_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_gsm8k_llama-3.1-tulu-3-8b-sft_64_4_train": {"file_name": "llama-factory_gsm8k_llama-3.1-tulu-3-8b-sft_64_4_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "cotllama-factory_math_llama-3.1-8b-instruct_128_8_train": {"file_name": "cotllama-factory_math_llama-3.1-8b-instruct_128_8_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_math_llama-3.2-1b-instruct_128_2_train": {"file_name": "llama-factory_math_llama-3.2-1b-instruct_128_2_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "cotllama-factory_math_llama-3.1-tulu-3-8b-sft_128_8_train": {"file_name": "cotllama-factory_math_llama-3.1-tulu-3-8b-sft_128_8_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "cotllama-factory_gsm8k_llama-3.1-tulu-3-8b-sft_64_2_train": {"file_name": "cotllama-factory_gsm8k_llama-3.1-tulu-3-8b-sft_64_2_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "cotllama-factory_math_llama-3.1-8b-instruct_128_4_train": {"file_name": "cotllama-factory_math_llama-3.1-8b-instruct_128_4_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_gsm8k_llama-3.1-8b-instruct_test": {"file_name": "llama-factory_gsm8k_llama-3.1-8b-instruct_test.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_math_llama-3.1-8b_128_2_train": {"file_name": "llama-factory_math_llama-3.1-8b_128_2_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "cotllama-factory_gsm8k_llama-3.1-tulu-3-8b-sft_64_1_train": {"file_name": "cotllama-factory_gsm8k_llama-3.1-tulu-3-8b-sft_64_1_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "cotllama-factory_math_llama-3.1-8b-instruct_128_2_train": {"file_name": "cotllama-factory_math_llama-3.1-8b-instruct_128_2_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_math_llama-3.1-tulu-3-8b-sft_128_2_train": {"file_name": "llama-factory_math_llama-3.1-tulu-3-8b-sft_128_2_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "cotllama-factory_math_llama-3.1-8b-instruct_128_1_train": {"file_name": "cotllama-factory_math_llama-3.1-8b-instruct_128_1_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_gsm8k_llama-3.1-8b-instruct_128_8_train": {"file_name": "llama-factory_gsm8k_llama-3.1-8b-instruct_128_8_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_gsm8k_llama-3.2-1b-instruct_test": {"file_name": "llama-factory_gsm8k_llama-3.2-1b-instruct_test.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "cotllama-factory_math_llama-3.1-tulu-3-8b-sft_128_4_train": {"file_name": 
"cotllama-factory_math_llama-3.1-tulu-3-8b-sft_128_4_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "cotllama-factory_math_llama-3.1-tulu-3-8b-sft_128_2_train": {"file_name": "cotllama-factory_math_llama-3.1-tulu-3-8b-sft_128_2_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_gsm8k_llama-3.2-3b-instruct_test": {"file_name": "llama-factory_gsm8k_llama-3.2-3b-instruct_test.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "cotllama-factory_gsm8k_llama-3.1-tulu-3-8b-sft_64_8_train": {"file_name": "cotllama-factory_gsm8k_llama-3.1-tulu-3-8b-sft_64_8_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_gsm8k_llama-3.1-tulu-3-8b-sft_64_8_train": {"file_name": "llama-factory_gsm8k_llama-3.1-tulu-3-8b-sft_64_8_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_gsm8k_llama-3.2-3b-instruct_128_4_train": {"file_name": "llama-factory_gsm8k_llama-3.2-3b-instruct_128_4_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "cotnewextractmath14325llama-factory_math_llama-3.1-8b-instruct_128_4_train": {"file_name": "cotnewextractmath14325llama-factory_math_llama-3.1-8b-instruct_128_4_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_math_llama-3.1-8b-instruct_128_8_train": {"file_name": "llama-factory_math_llama-3.1-8b-instruct_128_8_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_math_llama-3.1-tulu-3-8b-sft_128_8_train": {"file_name": "llama-factory_math_llama-3.1-tulu-3-8b-sft_128_8_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_gsm8k_llama-3.1-8b-instruct_128_2_train": {"file_name": "llama-factory_gsm8k_llama-3.1-8b-instruct_128_2_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}} -------------------------------------------------------------------------------- /qalign/rm/train_configs/deepspeed/ds_z0_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_batch_size": "auto", 3 | "train_micro_batch_size_per_gpu": "auto", 4 | "gradient_accumulation_steps": "auto", 5 | "gradient_clipping": "auto", 6 | "zero_allow_untested_optimizer": true, 7 | "fp16": { 8 | "enabled": "auto", 9 | "loss_scale": 0, 10 | "loss_scale_window": 1000, 11 | "initial_scale_power": 16, 12 | "hysteresis": 2, 13 | "min_loss_scale": 1 14 | }, 15 | "bf16": { 16 | "enabled": "auto" 17 | }, 18 | "zero_optimization": { 19 | "stage": 0, 20 | "allgather_partitions": true, 21 | "allgather_bucket_size": 5e8, 22 | "overlap_comm": true, 23 | "reduce_scatter": true, 24 | "reduce_bucket_size": 5e8, 25 | "contiguous_gradients": true, 26 | "round_robin_gradients": true 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /qalign/rm/train_configs/deepspeed/ds_z2_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_batch_size": "auto", 3 | "train_micro_batch_size_per_gpu": "auto", 4 | 
"gradient_accumulation_steps": "auto", 5 | "gradient_clipping": "auto", 6 | "zero_allow_untested_optimizer": true, 7 | "fp16": { 8 | "enabled": "auto", 9 | "loss_scale": 0, 10 | "loss_scale_window": 1000, 11 | "initial_scale_power": 16, 12 | "hysteresis": 2, 13 | "min_loss_scale": 1 14 | }, 15 | "bf16": { 16 | "enabled": "auto" 17 | }, 18 | "zero_optimization": { 19 | "stage": 2, 20 | "allgather_partitions": true, 21 | "allgather_bucket_size": 5e8, 22 | "overlap_comm": true, 23 | "reduce_scatter": true, 24 | "reduce_bucket_size": 5e8, 25 | "contiguous_gradients": true, 26 | "round_robin_gradients": true 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /qalign/rm/train_configs/deepspeed/ds_z2_offload_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_batch_size": "auto", 3 | "train_micro_batch_size_per_gpu": "auto", 4 | "gradient_accumulation_steps": "auto", 5 | "gradient_clipping": "auto", 6 | "zero_allow_untested_optimizer": true, 7 | "fp16": { 8 | "enabled": "auto", 9 | "loss_scale": 0, 10 | "loss_scale_window": 1000, 11 | "initial_scale_power": 16, 12 | "hysteresis": 2, 13 | "min_loss_scale": 1 14 | }, 15 | "bf16": { 16 | "enabled": "auto" 17 | }, 18 | "zero_optimization": { 19 | "stage": 2, 20 | "offload_optimizer": { 21 | "device": "cpu", 22 | "pin_memory": true 23 | }, 24 | "allgather_partitions": true, 25 | "allgather_bucket_size": 5e8, 26 | "overlap_comm": true, 27 | "reduce_scatter": true, 28 | "reduce_bucket_size": 5e8, 29 | "contiguous_gradients": true, 30 | "round_robin_gradients": true 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /qalign/rm/train_configs/deepspeed/ds_z3_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_batch_size": "auto", 3 | "train_micro_batch_size_per_gpu": "auto", 4 | "gradient_accumulation_steps": "auto", 5 | "gradient_clipping": "auto", 6 | "zero_allow_untested_optimizer": true, 7 | "fp16": { 8 | "enabled": false 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "zero_optimization": { 14 | "stage": 3, 15 | "overlap_comm": true, 16 | "contiguous_gradients": true, 17 | "sub_group_size": 1e9, 18 | "reduce_bucket_size": "auto", 19 | "stage3_prefetch_bucket_size": "auto", 20 | "stage3_param_persistence_threshold": "auto", 21 | "stage3_max_live_parameters": 1e9, 22 | "stage3_max_reuse_distance": 1e9, 23 | "stage3_gather_16bit_weights_on_model_save": true 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /qalign/rm/train_configs/deepspeed/ds_z3_offload_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_batch_size": "auto", 3 | "train_micro_batch_size_per_gpu": "auto", 4 | "gradient_accumulation_steps": "auto", 5 | "gradient_clipping": "auto", 6 | "zero_allow_untested_optimizer": true, 7 | "fp16": { 8 | "enabled": "auto", 9 | "loss_scale": 0, 10 | "loss_scale_window": 1000, 11 | "initial_scale_power": 16, 12 | "hysteresis": 2, 13 | "min_loss_scale": 1 14 | }, 15 | "bf16": { 16 | "enabled": "auto" 17 | }, 18 | "zero_optimization": { 19 | "stage": 3, 20 | "offload_optimizer": { 21 | "device": "cpu", 22 | "pin_memory": true 23 | }, 24 | "offload_param": { 25 | "device": "cpu", 26 | "pin_memory": true 27 | }, 28 | "overlap_comm": true, 29 | "contiguous_gradients": true, 30 | "sub_group_size": 1e9, 31 | "reduce_bucket_size": 
"auto", 32 | "stage3_prefetch_bucket_size": "auto", 33 | "stage3_param_persistence_threshold": "auto", 34 | "stage3_max_live_parameters": 1e9, 35 | "stage3_max_reuse_distance": 1e9, 36 | "stage3_gather_16bit_weights_on_model_save": true 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /qalign/rm/train_configs/explain.txt: -------------------------------------------------------------------------------- 1 | 2 | 3 | llamafactory-cli train quest/full/llama3_8b1b_full_reward.yaml 4 | llamafactory-cli train qflow/rm/quest/full/llama3_8b1b_full_reward.yaml 5 | 6 | llamafactory-cli train qflow/rm/train_configs/math/full/llama3_8b8b_full_reward_math_cot.yaml -------------------------------------------------------------------------------- /qalign/rm/train_configs/gsm8k/full/gemma2_full_dpo.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: google/gemma-2-2b-it 3 | 4 | ### method 5 | stage: dpo 6 | do_train: true 7 | finetuning_type: full 8 | pref_beta: 5.0 9 | #pref_ftx: 1.0 10 | flash_attn: "disabled" 11 | pref_loss: sigmoid # choices: [sigmoid (dpo), orpo, simpo] 12 | deepspeed: examples/deepspeed/ds_z3_config.json 13 | 14 | ## dataset 15 | dataset: gsm8k_gemma2_264_4 16 | template: gemma 17 | cutoff_len: 1024 18 | max_samples: 1000000 19 | overwrite_cache: true 20 | preprocessing_num_workers: 16 21 | 22 | ### output 23 | output_dir: saves/gemma-2-2b-it/full/dpo 24 | logging_steps: 1 25 | save_steps: 500 26 | plot_loss: true 27 | overwrite_output_dir: true 28 | report_to: wandb 29 | 30 | ### train 31 | per_device_train_batch_size: 1 32 | gradient_accumulation_steps: 16 33 | learning_rate: 5.0e-7 34 | num_train_epochs: 1.0 35 | lr_scheduler_type: linear 36 | warmup_ratio: 0.03 37 | bf16: true 38 | ddp_timeout: 180000000 39 | 40 | ### eval 41 | val_size: 0.1 42 | per_device_eval_batch_size: 4 43 | eval_strategy: steps 44 | eval_steps: 5 45 | -------------------------------------------------------------------------------- /qalign/rm/train_configs/gsm8k/full/gemma2_full_dpo05.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: google/gemma-2-2b-it 3 | 4 | ### method 5 | stage: dpo 6 | do_train: true 7 | finetuning_type: full 8 | pref_beta: 0.05 9 | pref_ftx: 0.0 10 | pref_loss: sigmoid # choices: [sigmoid (dpo), orpo, simpo] 11 | deepspeed: examples/deepspeed/ds_z3_config.json 12 | 13 | ## dataset 14 | dataset: gsm8k_gemma2_264_4 15 | template: gemma 16 | cutoff_len: 1024 17 | max_samples: 1000000 18 | overwrite_cache: true 19 | preprocessing_num_workers: 16 20 | 21 | ### output 22 | output_dir: saves/gemma-2-2b-it/full/dpo05 23 | logging_steps: 1 24 | save_steps: 500 25 | plot_loss: true 26 | overwrite_output_dir: true 27 | report_to: wandb 28 | 29 | ### train 30 | per_device_train_batch_size: 2 31 | gradient_accumulation_steps: 32 32 | learning_rate: 5.0e-6 33 | num_train_epochs: 1.0 34 | lr_scheduler_type: linear 35 | warmup_ratio: 0.03 36 | bf16: true 37 | ddp_timeout: 180000000 38 | 39 | ### eval 40 | val_size: 0.1 41 | per_device_eval_batch_size: 4 42 | eval_strategy: steps 43 | eval_steps: 5 44 | -------------------------------------------------------------------------------- /qalign/rm/train_configs/gsm8k/full/gemma2_full_dpo10.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: google/gemma-2-2b-it 3 | 4 | ### method 5 | 
stage: dpo 6 | do_train: true 7 | finetuning_type: full 8 | pref_beta: 0.10 9 | pref_ftx: 0.0 10 | pref_loss: sigmoid # choices: [sigmoid (dpo), orpo, simpo] 11 | deepspeed: examples/deepspeed/ds_z3_config.json 12 | 13 | ## dataset 14 | dataset: gsm8k_gemma2_264_4 15 | template: gemma 16 | cutoff_len: 1024 17 | max_samples: 1000000 18 | overwrite_cache: true 19 | preprocessing_num_workers: 16 20 | 21 | ### output 22 | output_dir: saves/gemma-2-2b-it/full/dpo10 23 | logging_steps: 1 24 | save_steps: 500 25 | plot_loss: true 26 | overwrite_output_dir: true 27 | report_to: wandb 28 | 29 | ### train 30 | per_device_train_batch_size: 2 31 | gradient_accumulation_steps: 32 32 | learning_rate: 5.0e-6 33 | num_train_epochs: 1.0 34 | lr_scheduler_type: linear 35 | warmup_ratio: 0.03 36 | bf16: true 37 | ddp_timeout: 180000000 38 | 39 | ### eval 40 | val_size: 0.1 41 | per_device_eval_batch_size: 4 42 | eval_strategy: steps 43 | eval_steps: 5 44 | -------------------------------------------------------------------------------- /qalign/rm/train_configs/gsm8k/full/gemma2_full_dpo_sft05.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: google/gemma-2-2b-it 3 | 4 | ### method 5 | stage: dpo 6 | do_train: true 7 | finetuning_type: full 8 | pref_beta: 0.05 9 | pref_ftx: 1.0 10 | pref_loss: sigmoid # choices: [sigmoid (dpo), orpo, simpo] 11 | deepspeed: examples/deepspeed/ds_z3_config.json 12 | 13 | ## dataset 14 | dataset: gsm8k_gemma2_264_4 15 | template: gemma 16 | cutoff_len: 1024 17 | max_samples: 1000000 18 | overwrite_cache: true 19 | preprocessing_num_workers: 16 20 | 21 | ### output 22 | output_dir: saves/gemma-2-2b-it/full/dposft05 23 | logging_steps: 1 24 | save_steps: 500 25 | plot_loss: true 26 | overwrite_output_dir: true 27 | report_to: wandb 28 | 29 | ### train 30 | per_device_train_batch_size: 2 31 | gradient_accumulation_steps: 32 32 | learning_rate: 5.0e-6 33 | num_train_epochs: 1.0 34 | lr_scheduler_type: linear 35 | warmup_ratio: 0.03 36 | bf16: true 37 | ddp_timeout: 180000000 38 | 39 | ### eval 40 | val_size: 0.1 41 | per_device_eval_batch_size: 4 42 | eval_strategy: steps 43 | eval_steps: 5 44 | -------------------------------------------------------------------------------- /qalign/rm/train_configs/gsm8k/full/gemma2_full_dpo_sft10.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: google/gemma-2-2b-it 3 | 4 | ### method 5 | stage: dpo 6 | do_train: true 7 | finetuning_type: full 8 | pref_beta: 1.0 9 | pref_ftx: 1.0 10 | pref_loss: sigmoid # choices: [sigmoid (dpo), orpo, simpo] 11 | deepspeed: examples/deepspeed/ds_z3_config.json 12 | flash_attn: "disabled" 13 | 14 | ## dataset 15 | dataset: gsm8k_gemma2_264_4 16 | template: gemma 17 | cutoff_len: 1024 18 | max_samples: 1000000 19 | overwrite_cache: true 20 | preprocessing_num_workers: 16 21 | 22 | ### output 23 | output_dir: saves/gemma-2-2b-it/full/dposft10 24 | logging_steps: 1 25 | save_steps: 500 26 | plot_loss: true 27 | overwrite_output_dir: true 28 | report_to: wandb 29 | 30 | ### train 31 | per_device_train_batch_size: 2 32 | gradient_accumulation_steps: 32 33 | learning_rate: 5.0e-6 34 | num_train_epochs: 1.0 35 | lr_scheduler_type: linear 36 | warmup_ratio: 0.03 37 | bf16: true 38 | ddp_timeout: 180000000 39 | 40 | ### eval 41 | val_size: 0.1 42 | per_device_eval_batch_size: 4 43 | eval_strategy: steps 44 | eval_steps: 5 45 | 46 | 
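These YAML files are LLaMA-Factory training configurations, launched with `llamafactory-cli train` as shown in `explain.txt` above. Below is a hedged example launch; it assumes the datasets named in the config are registered in LLaMA-Factory's `dataset_info.json` (see `qalign/rm/data/` and `create_dataset.py`). Note that the paths in `explain.txt` still reference the older `qflow/rm/...` layout; the same configs live under `qalign/rm/...` in this repository.

```bash
# Train a full-parameter reward model for GSM8K using a config from this directory.
# The path follows this repository's layout; adjust it if running from a LLaMA-Factory checkout.
llamafactory-cli train qalign/rm/train_configs/gsm8k/full/llama3_8b8b_full_reward.yaml
```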
-------------------------------------------------------------------------------- /qalign/rm/train_configs/gsm8k/full/gemma2_full_reward.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: google/gemma-2-2b-it 3 | flash_attn: fa2 4 | #attn_implementation: eager 5 | ### method 6 | stage: rm 7 | do_train: true 8 | finetuning_type: full 9 | deepspeed: examples/deepspeed/ds_z3_config.json 10 | 11 | 12 | ### dataset 13 | dataset: gsm8k_gemma2_264 14 | template: gemma 15 | cutoff_len: 1024 16 | max_samples: 1000 17 | overwrite_cache: true 18 | preprocessing_num_workers: 16 19 | 20 | ### output 21 | output_dir: saves/gemma-2-2b-it/full/reward 22 | logging_steps: 10 23 | save_steps: 500 24 | plot_loss: true 25 | overwrite_output_dir: true 26 | report_to: wandb 27 | 28 | ### train 29 | per_device_train_batch_size: 1 30 | gradient_accumulation_steps: 32 31 | learning_rate: 1.0e-5 32 | num_train_epochs: 1.0 33 | lr_scheduler_type: linear 34 | weight_decay: 0.0 35 | warmup_ratio: 0.03 36 | bf16: true 37 | ddp_timeout: 180000000 38 | 39 | ### eval 40 | val_size: 0.1 41 | per_device_eval_batch_size: 4 42 | eval_strategy: steps 43 | eval_steps: 2 44 | -------------------------------------------------------------------------------- /qalign/rm/train_configs/gsm8k/full/llama3_1b1b_full_reward.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: meta-llama/Llama-3.2-1B-Instruct 3 | flash_attn: fa2 4 | #attn_implementation: eager 5 | ### method 6 | stage: rm 7 | do_train: true 8 | finetuning_type: full 9 | deepspeed: examples/deepspeed/ds_z3_config.json 10 | save_safetensors: False 11 | 12 | 13 | ### dataset 14 | dataset: gsm8k_llama3.2-1B_128_1ep 15 | eval_dataset: gsm8k_llama3.2-1B_test 16 | template: llama3 17 | cutoff_len: 1024 18 | max_samples: 300000 19 | overwrite_cache: true 20 | preprocessing_num_workers: 16 21 | 22 | ### output 23 | output_dir: saves/llama3/1b1b/full/reward 24 | logging_steps: 1 25 | save_steps: 500 26 | plot_loss: true 27 | overwrite_output_dir: true 28 | report_to: wandb 29 | 30 | ### train 31 | per_device_train_batch_size: 2 32 | gradient_accumulation_steps: 8 33 | learning_rate: 1.0e-5 34 | num_train_epochs: 1.0 35 | lr_scheduler_type: linear 36 | weight_decay: 0.0 37 | warmup_ratio: 0.03 38 | bf16: true 39 | ddp_timeout: 180000000 40 | 41 | ### eval 42 | #val_size: 0.05 43 | per_device_eval_batch_size: 4 44 | eval_strategy: steps 45 | eval_steps: 5 46 | -------------------------------------------------------------------------------- /qalign/rm/train_configs/gsm8k/full/llama3_1b8b_full_reward.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: meta-llama/Llama-3.1-8B-Instruct 3 | flash_attn: fa2 4 | #attn_implementation: eager 5 | ### method 6 | stage: rm 7 | do_train: true 8 | finetuning_type: full 9 | deepspeed: examples/deepspeed/ds_z3_config.json 10 | save_safetensors: False 11 | 12 | ### dataset 13 | dataset: gsm8k_llama3.2-1B_128_1ep 14 | eval_dataset: gsm8k_llama3.2-1B_test 15 | template: llama3 16 | cutoff_len: 1024 17 | max_samples: 300000 18 | overwrite_cache: true 19 | preprocessing_num_workers: 16 20 | 21 | ### output 22 | output_dir: saves/llama3/1b8b/full/reward 23 | logging_steps: 1 24 | save_steps: 500 25 | plot_loss: true 26 | overwrite_output_dir: true 27 | report_to: wandb 28 | 29 | ### train 30 | per_device_train_batch_size: 1 31 | 
gradient_accumulation_steps: 16 32 | learning_rate: 1.0e-5 33 | num_train_epochs: 1.0 34 | lr_scheduler_type: linear 35 | weight_decay: 0.0 36 | warmup_ratio: 0.03 37 | bf16: true 38 | ddp_timeout: 180000000 39 | 40 | ### eval 41 | #val_size: 0.05 42 | per_device_eval_batch_size: 4 43 | eval_strategy: steps 44 | eval_steps: 5 45 | -------------------------------------------------------------------------------- /qalign/rm/train_configs/gsm8k/full/llama3_3b1b_full_reward.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: meta-llama/Llama-3.2-1B-Instruct 3 | flash_attn: fa2 4 | #attn_implementation: eager 5 | ### method 6 | stage: rm 7 | do_train: true 8 | finetuning_type: full 9 | deepspeed: examples/deepspeed/ds_z3_config.json 10 | save_safetensors: False 11 | 12 | 13 | ### dataset 14 | dataset: gsm8k_llama3.2-3B_128_1ep 15 | eval_dataset: gsm8k_llama3.2-3B_test 16 | template: llama3 17 | cutoff_len: 1024 18 | max_samples: 300000 19 | overwrite_cache: true 20 | preprocessing_num_workers: 16 21 | 22 | ### output 23 | output_dir: saves/llama3/3b1b/full/reward 24 | logging_steps: 1 25 | save_steps: 500 26 | plot_loss: true 27 | overwrite_output_dir: true 28 | report_to: wandb 29 | 30 | ### train 31 | per_device_train_batch_size: 2 32 | gradient_accumulation_steps: 8 33 | learning_rate: 1.0e-5 34 | num_train_epochs: 1.0 35 | lr_scheduler_type: linear 36 | weight_decay: 0.0 37 | warmup_ratio: 0.03 38 | bf16: true 39 | ddp_timeout: 180000000 40 | 41 | ### eval 42 | #val_size: 0.05 43 | per_device_eval_batch_size: 4 44 | eval_strategy: steps 45 | eval_steps: 5 46 | -------------------------------------------------------------------------------- /qalign/rm/train_configs/gsm8k/full/llama3_3b8b_full_reward.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: meta-llama/Llama-3.1-8B-Instruct 3 | flash_attn: fa2 4 | #attn_implementation: eager 5 | ### method 6 | stage: rm 7 | do_train: true 8 | finetuning_type: full 9 | deepspeed: examples/deepspeed/ds_z3_config.json 10 | save_safetensors: False 11 | 12 | ### dataset 13 | dataset: gsm8k_llama3.2-3B_128_1ep 14 | eval_dataset: gsm8k_llama3.2-3B_test 15 | template: llama3 16 | cutoff_len: 1024 17 | max_samples: 300000 18 | overwrite_cache: true 19 | preprocessing_num_workers: 16 20 | 21 | ### output 22 | output_dir: saves/llama3/3b8b/full/reward 23 | logging_steps: 1 24 | save_steps: 500 25 | plot_loss: true 26 | overwrite_output_dir: true 27 | report_to: wandb 28 | 29 | ### train 30 | per_device_train_batch_size: 1 31 | gradient_accumulation_steps: 32 32 | learning_rate: 1.0e-5 33 | num_train_epochs: 1.0 34 | lr_scheduler_type: linear 35 | weight_decay: 0.0 36 | warmup_ratio: 0.03 37 | bf16: true 38 | ddp_timeout: 180000000 39 | 40 | ### eval 41 | #val_size: 0.05 42 | per_device_eval_batch_size: 4 43 | eval_strategy: steps 44 | eval_steps: 5 45 | -------------------------------------------------------------------------------- /qalign/rm/train_configs/gsm8k/full/llama3_3b_full_reward.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: meta-llama/Llama-3.2-3B-Instruct 3 | flash_attn: fa2 4 | #attn_implementation: eager 5 | ### method 6 | stage: rm 7 | do_train: true 8 | finetuning_type: full 9 | deepspeed: examples/deepspeed/ds_z3_config.json 10 | save_safetensors: False 11 | 12 | ### dataset 13 | dataset: gsm8k_llama3.2-3B_128_1ep 14 | 
eval_dataset: gsm8k_llama3.2-3B_test 15 | template: llama3 16 | cutoff_len: 1024 17 | max_samples: 300000 18 | overwrite_cache: true 19 | preprocessing_num_workers: 16 20 | 21 | ### output 22 | output_dir: saves/llama3/3b/full/reward 23 | logging_steps: 1 24 | save_steps: 500 25 | plot_loss: true 26 | overwrite_output_dir: true 27 | report_to: wandb 28 | 29 | ### train 30 | per_device_train_batch_size: 2 31 | gradient_accumulation_steps: 8 32 | learning_rate: 1.0e-5 33 | num_train_epochs: 1.0 34 | lr_scheduler_type: linear 35 | weight_decay: 0.0 36 | warmup_ratio: 0.03 37 | bf16: true 38 | ddp_timeout: 180000000 39 | 40 | ### eval 41 | #val_size: 0.05 42 | per_device_eval_batch_size: 4 43 | eval_strategy: steps 44 | eval_steps: 5 45 | -------------------------------------------------------------------------------- /qalign/rm/train_configs/gsm8k/full/llama3_8b1b_full_reward.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: meta-llama/Llama-3.2-1B-Instruct 3 | flash_attn: fa2 4 | #attn_implementation: eager 5 | ### method 6 | stage: rm 7 | do_train: true 8 | finetuning_type: full 9 | deepspeed: examples/deepspeed/ds_z3_config.json 10 | save_safetensors: False 11 | 12 | 13 | ### dataset 14 | dataset: gsm8k_llama3.1-8B_128_1ep 15 | eval_dataset: gsm8k_llama3.1-8B_test 16 | template: llama3 17 | cutoff_len: 1024 18 | max_samples: 300000 19 | overwrite_cache: true 20 | preprocessing_num_workers: 16 21 | 22 | ### output 23 | output_dir: saves/llama3/8b1b/full/reward 24 | logging_steps: 1 25 | save_steps: 500 26 | plot_loss: true 27 | overwrite_output_dir: true 28 | report_to: wandb 29 | 30 | ### train 31 | per_device_train_batch_size: 2 32 | gradient_accumulation_steps: 8 33 | learning_rate: 1.0e-5 34 | num_train_epochs: 1.0 35 | lr_scheduler_type: linear 36 | weight_decay: 0.0 37 | warmup_ratio: 0.03 38 | bf16: true 39 | ddp_timeout: 180000000 40 | 41 | ### eval 42 | #val_size: 0.05 43 | per_device_eval_batch_size: 4 44 | eval_strategy: steps 45 | eval_steps: 5 46 | -------------------------------------------------------------------------------- /qalign/rm/train_configs/gsm8k/full/llama3_8b8b_full_dpo.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: meta-llama/Llama-3.1-8B-Instruct 3 | flash_attn: fa2 4 | #attn_implementation: eager 5 | 6 | ### method 7 | stage: dpo 8 | do_train: true 9 | finetuning_type: full 10 | pref_beta: 5.0 11 | #pref_ftx: 1.0 12 | flash_attn: "disabled" 13 | pref_loss: sigmoid # choices: [sigmoid (dpo), orpo, simpo] 14 | deepspeed: examples/deepspeed/ds_z3_config.json 15 | 16 | ### dataset 17 | dataset: gsm8k_llama3.1-8B_128_1ep 18 | eval_dataset: gsm8k_llama3.1-8B_test 19 | template: llama3 20 | cutoff_len: 1024 21 | max_samples: 300000 22 | overwrite_cache: true 23 | preprocessing_num_workers: 16 24 | 25 | ### output 26 | output_dir: saves/llama3/8b8b/full/dpoe6 27 | logging_steps: 1 28 | save_steps: 500 29 | plot_loss: true 30 | overwrite_output_dir: true 31 | report_to: wandb 32 | 33 | ### train 34 | per_device_train_batch_size: 1 35 | gradient_accumulation_steps: 16 36 | learning_rate: 1.0e-6 37 | num_train_epochs: 1.0 38 | lr_scheduler_type: linear 39 | weight_decay: 0.0 40 | warmup_ratio: 0.03 41 | bf16: true 42 | ddp_timeout: 180000000 43 | 44 | ### eval 45 | #val_size: 0.05 46 | per_device_eval_batch_size: 4 47 | eval_strategy: steps 48 | eval_steps: 5 49 | 
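# Under the HF Trainer semantics that LLaMA-Factory builds on, the effective batch size of
# this run is per_device_train_batch_size (1) x gradient_accumulation_steps (16) x the
# number of GPUs, i.e. 16 sequences per optimizer step on a single device.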
-------------------------------------------------------------------------------- /qalign/rm/train_configs/gsm8k/full/llama3_8b8b_full_reward.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: meta-llama/Llama-3.1-8B-Instruct 3 | flash_attn: fa2 4 | #attn_implementation: eager 5 | ### method 6 | stage: rm 7 | do_train: true 8 | finetuning_type: full 9 | deepspeed: examples/deepspeed/ds_z3_config.json 10 | save_safetensors: False 11 | 12 | ### dataset 13 | dataset: gsm8k_llama3.1-8B_128_1ep 14 | eval_dataset: gsm8k_llama3.1-8B_test 15 | template: llama3 16 | cutoff_len: 1024 17 | max_samples: 300000 18 | overwrite_cache: true 19 | preprocessing_num_workers: 16 20 | 21 | ### output 22 | output_dir: saves/llama3/8b8b/full/reward 23 | logging_steps: 1 24 | save_steps: 500 25 | plot_loss: true 26 | overwrite_output_dir: true 27 | report_to: wandb 28 | 29 | ### train 30 | per_device_train_batch_size: 1 31 | gradient_accumulation_steps: 16 32 | learning_rate: 1.0e-5 33 | num_train_epochs: 1.0 34 | lr_scheduler_type: linear 35 | weight_decay: 0.0 36 | warmup_ratio: 0.03 37 | bf16: true 38 | ddp_timeout: 180000000 39 | 40 | ### eval 41 | #val_size: 0.05 42 | per_device_eval_batch_size: 4 43 | eval_strategy: steps 44 | eval_steps: 5 45 | -------------------------------------------------------------------------------- /qalign/rm/train_configs/gsm8k/full/olmo_full_reward.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: allenai/OLMo-7B-0724-Instruct-hf 3 | flash_attn: fa2 4 | #attn_implementation: eager 5 | ### method 6 | stage: rm 7 | do_train: true 8 | finetuning_type: full 9 | deepspeed: examples/deepspeed/ds_z3_config.json 10 | 11 | 12 | ### dataset 13 | dataset: gsm8k_olmo_128_4ep 14 | template: olmo 15 | cutoff_len: 1024 16 | max_samples: 300000 17 | overwrite_cache: true 18 | preprocessing_num_workers: 16 19 | 20 | ### output 21 | output_dir: saves/olmo/full/reward4e-5 22 | logging_steps: 1 23 | save_steps: 500 24 | plot_loss: true 25 | overwrite_output_dir: true 26 | report_to: wandb 27 | 28 | ### train 29 | per_device_train_batch_size: 2 30 | gradient_accumulation_steps: 32 31 | learning_rate: 1.0e-5 32 | num_train_epochs: 1.0 33 | lr_scheduler_type: linear 34 | weight_decay: 0.0 35 | warmup_ratio: 0.03 36 | bf16: true 37 | ddp_timeout: 180000000 38 | 39 | ### eval 40 | val_size: 0.05 41 | per_device_eval_batch_size: 4 42 | eval_strategy: steps 43 | eval_steps: 5 44 | -------------------------------------------------------------------------------- /qalign/rm/train_configs/gsm8k/full/train.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | # llamafactory-cli train quest/full/llama3_1b_full_reward.yaml 5 | # llamafactory-cli train quest/full/llama3_3b_full_reward.yaml 6 | 7 | 8 | llamafactory-cli train quest/full/llama3_8b1b_full_reward.yaml && llamafactory-cli train quest/full/llama3_3b1b_full_reward.yaml 9 | -------------------------------------------------------------------------------- /qalign/rm/train_configs/gsm8k/full/tulu_8b8b_full_reward.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: allenai/Llama-3.1-Tulu-3-8B-SFT 3 | flash_attn: fa2 4 | #attn_implementation: eager 5 | ### method 6 | stage: rm 7 | do_train: true 8 | finetuning_type: full 9 | deepspeed: qflow/rm/train_configs/deepspeed/ds_z3_config.json 10 | 
save_safetensors: False 11 | 12 | ### dataset 13 | dataset_dir: qflow/rm/data/ 14 | dataset: llama-factory_gsm8k_llama-3.1-tulu-3-8b-sft_64_1_train 15 | #eval_dataset: gsm8k_llama3.1-8B_test 16 | template: llama3 17 | cutoff_len: 1024 18 | max_samples: 300000 19 | overwrite_cache: true 20 | preprocessing_num_workers: 16 21 | 22 | ### output 23 | output_dir: qflow/rm/artifacts/tulu/8b8b/gsm8k/full/reward 24 | logging_steps: 1 25 | save_steps: 500 26 | plot_loss: true 27 | overwrite_output_dir: true 28 | report_to: wandb 29 | 30 | ### train 31 | per_device_train_batch_size: 1 32 | gradient_accumulation_steps: 32 33 | learning_rate: 1.0e-5 34 | num_train_epochs: 1.0 35 | lr_scheduler_type: linear 36 | weight_decay: 0.0 37 | warmup_ratio: 0.03 38 | bf16: true 39 | ddp_timeout: 180000000 40 | 41 | ### eval 42 | val_size: 0.05 43 | per_device_eval_batch_size: 2 44 | eval_strategy: steps 45 | eval_steps: 5 46 | -------------------------------------------------------------------------------- /qalign/rm/train_configs/gsm8k/lora/gemma2_lora_dpo.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: google/gemma-2-2b-it 3 | 4 | ### method 5 | stage: dpo 6 | do_train: true 7 | finetuning_type: lora 8 | pref_beta: 0.1 9 | #pref_ftx: 0.5 10 | pref_loss: sigmoid # choices: [sigmoid (dpo), orpo, simpo] 11 | lora_rank: 8 12 | deepspeed: examples/deepspeed/ds_z3_config.json 13 | 14 | # finetuning_type: full 15 | 16 | 17 | ### dataset 18 | dataset: gsm8k_gemma2_264_4 19 | template: gemma 20 | cutoff_len: 1024 21 | max_samples: 1000000 22 | overwrite_cache: true 23 | preprocessing_num_workers: 16 24 | 25 | ### output 26 | output_dir: saves/gemma-2-2b-it/lora/dpo 27 | logging_steps: 1 28 | save_steps: 500 29 | plot_loss: true 30 | overwrite_output_dir: true 31 | report_to: wandb 32 | 33 | ### train 34 | per_device_train_batch_size: 2 35 | gradient_accumulation_steps: 32 36 | learning_rate: 5.0e-6 37 | num_train_epochs: 1.0 38 | lr_scheduler_type: linear 39 | warmup_ratio: 0.03 40 | bf16: true 41 | ddp_timeout: 180000000 42 | 43 | ### eval 44 | val_size: 0.1 45 | per_device_eval_batch_size: 4 46 | eval_strategy: steps 47 | eval_steps: 5 48 | -------------------------------------------------------------------------------- /qalign/rm/train_configs/gsm8k/lora/gemma2_lora_reward.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: google/gemma-2-2b-it 3 | 4 | ### method 5 | stage: rm 6 | do_train: true 7 | finetuning_type: lora 8 | lora_target: all 9 | deepspeed: examples/deepspeed/ds_z3_config.json 10 | lora_rank: 8 11 | 12 | 13 | ### dataset 14 | dataset: gsm8k_gemma2_264 15 | template: gemma 16 | cutoff_len: 1024 17 | max_samples: 1000 18 | overwrite_cache: true 19 | preprocessing_num_workers: 16 20 | 21 | ### output 22 | output_dir: saves/gemma-2-2b-it/lora/reward 23 | logging_steps: 10 24 | save_steps: 500 25 | plot_loss: true 26 | overwrite_output_dir: true 27 | report_to: wandb 28 | 29 | ### train 30 | per_device_train_batch_size: 2 31 | gradient_accumulation_steps: 8 32 | learning_rate: 1.0e-4 33 | num_train_epochs: 3.0 34 | lr_scheduler_type: cosine 35 | warmup_ratio: 0.1 36 | bf16: true 37 | ddp_timeout: 180000000 38 | 39 | ### eval 40 | val_size: 0.1 41 | per_device_eval_batch_size: 4 42 | eval_strategy: steps 43 | eval_steps: 10 44 | -------------------------------------------------------------------------------- 
/qalign/rm/train_configs/gsm8k/lora/llama2_lora_reward.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: meta-llama/Llama-2-7b-hf 3 | 4 | ### method 5 | stage: rm 6 | do_train: true 7 | finetuning_type: lora 8 | lora_target: all 9 | deepspeed: examples/deepspeed/ds_z3_config.json 10 | lora_rank: 16 11 | 12 | 13 | ### dataset 14 | dataset: gsm8k_olmo_128_4ep 15 | template: llama2 16 | cutoff_len: 1024 17 | max_samples: 1000000 18 | overwrite_cache: true 19 | preprocessing_num_workers: 16 20 | 21 | ### output 22 | output_dir: saves/llama2/lora/reward 23 | logging_steps: 10 24 | save_steps: 500 25 | plot_loss: true 26 | overwrite_output_dir: true 27 | report_to: wandb 28 | 29 | ### train 30 | per_device_train_batch_size: 2 31 | gradient_accumulation_steps: 32 32 | learning_rate: 1.0e-4 33 | num_train_epochs: 1.0 34 | lr_scheduler_type: linear 35 | warmup_ratio: 0.03 36 | bf16: true 37 | ddp_timeout: 180000000 38 | 39 | ### eval 40 | val_size: 0.05 41 | per_device_eval_batch_size: 4 42 | eval_strategy: steps 43 | eval_steps: 10 44 | -------------------------------------------------------------------------------- /qalign/rm/train_configs/gsm8k/lora/llama2_lora_sft.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: meta-llama/Llama-2-7b-hf 3 | 4 | ### method 5 | stage: sft 6 | do_train: true 7 | finetuning_type: lora 8 | lora_target: all 9 | deepspeed: examples/deepspeed/ds_z3_config.json 10 | lora_rank: 16 11 | 12 | 13 | ### dataset 14 | dataset: gsm8k_gemma2_264 15 | template: llama2 16 | cutoff_len: 1024 17 | max_samples: 1000 18 | overwrite_cache: true 19 | preprocessing_num_workers: 16 20 | 21 | ### output 22 | output_dir: saves/llama2/lora/reward 23 | logging_steps: 10 24 | save_steps: 500 25 | plot_loss: true 26 | overwrite_output_dir: true 27 | report_to: wandb 28 | 29 | ### train 30 | per_device_train_batch_size: 2 31 | gradient_accumulation_steps: 8 32 | learning_rate: 1.0e-4 33 | num_train_epochs: 3.0 34 | lr_scheduler_type: cosine 35 | warmup_ratio: 0.1 36 | bf16: true 37 | ddp_timeout: 180000000 38 | 39 | ### eval 40 | val_size: 0.1 41 | per_device_eval_batch_size: 4 42 | eval_strategy: steps 43 | eval_steps: 10 44 | -------------------------------------------------------------------------------- /qalign/rm/train_configs/gsm8k/lora/olmo_lora_reward.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: allenai/OLMo-7B-0724-Instruct-hf 2 | ### method 3 | stage: rm 4 | do_train: true 5 | finetuning_type: lora 6 | lora_target: all 7 | deepspeed: examples/deepspeed/ds_z3_config.json 8 | lora_rank: 32 9 | ### model 10 | 11 | dataset: gsm8k_olmo_128_4ep 12 | template: olmo 13 | cutoff_len: 1024 14 | max_samples: 300000 15 | overwrite_cache: true 16 | preprocessing_num_workers: 16 17 | 18 | ### output 19 | output_dir: saves/olmo/lora/reward4e-5 20 | logging_steps: 1 21 | save_steps: 500 22 | plot_loss: true 23 | overwrite_output_dir: true 24 | report_to: wandb 25 | 26 | ### train 27 | per_device_train_batch_size: 2 28 | gradient_accumulation_steps: 32 29 | learning_rate: 1.0e-5 30 | num_train_epochs: 1.0 31 | lr_scheduler_type: linear 32 | weight_decay: 0.0 33 | warmup_ratio: 0.03 34 | bf16: true 35 | ddp_timeout: 180000000 36 | 37 | ### eval 38 | val_size: 0.05 39 | per_device_eval_batch_size: 4 40 | eval_strategy: steps 41 | eval_steps: 5 42 | 
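# The LoRA recipes in this directory keep the base model frozen and train only low-rank
# adapters (rank 8 for gemma2, 16 for llama2, 32 for olmo); `lora_target: all` is
# LLaMA-Factory's shorthand for attaching adapters to every supported linear module.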
-------------------------------------------------------------------------------- /qalign/rm/train_configs/math/full/llama3_1b8b_full_reward_math.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: meta-llama/Llama-3.1-8B-Instruct 3 | flash_attn: fa2 4 | #attn_implementation: eager 5 | ### method 6 | stage: rm 7 | do_train: true 8 | finetuning_type: full 9 | deepspeed: qflow/rm/train_configs/deepspeed/ds_z3_config.json 10 | save_safetensors: False 11 | 12 | ### dataset 13 | dataset_dir: qflow/rm/data/ 14 | dataset: llama-factory_math_llama-3.2-1b-instruct_128_1_train 15 | #eval_dataset: gsm8k_llama3.1-8B_test 16 | template: llama3 17 | cutoff_len: 1024 18 | max_samples: 300000 19 | overwrite_cache: true 20 | preprocessing_num_workers: 16 21 | 22 | ### output 23 | output_dir: qflow/rm/artifacts/llama3/1b8b/math/full/reward 24 | logging_steps: 1 25 | save_steps: 500 26 | plot_loss: true 27 | overwrite_output_dir: true 28 | report_to: wandb 29 | 30 | ### train 31 | per_device_train_batch_size: 1 32 | gradient_accumulation_steps: 32 33 | learning_rate: 1.0e-5 34 | num_train_epochs: 1.0 35 | lr_scheduler_type: linear 36 | weight_decay: 0.0 37 | warmup_ratio: 0.03 38 | bf16: true 39 | ddp_timeout: 180000000 40 | 41 | ### eval 42 | val_size: 0.05 43 | per_device_eval_batch_size: 2 44 | eval_strategy: steps 45 | eval_steps: 5 46 | -------------------------------------------------------------------------------- /qalign/rm/train_configs/math/full/llama3_8b8b_full_reward_math.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: meta-llama/Llama-3.1-8B-Instruct 3 | flash_attn: fa2 4 | #attn_implementation: eager 5 | ### method 6 | stage: rm 7 | do_train: true 8 | finetuning_type: full 9 | deepspeed: qflow/rm/train_configs/deepspeed/ds_z3_config.json 10 | save_safetensors: False 11 | 12 | ### dataset 13 | dataset_dir: qflow/rm/data/ 14 | dataset: llama-factory_math_llama-3.1-8b-instruct_128_1_train 15 | #eval_dataset: gsm8k_llama3.1-8B_test 16 | template: llama3 17 | cutoff_len: 1024 18 | max_samples: 300000 19 | overwrite_cache: true 20 | preprocessing_num_workers: 16 21 | 22 | ### output 23 | output_dir: qflow/rm/artifacts/llama3/8b8b/math/full/reward 24 | logging_steps: 1 25 | save_steps: 500 26 | plot_loss: true 27 | overwrite_output_dir: true 28 | report_to: wandb 29 | 30 | ### train 31 | per_device_train_batch_size: 1 32 | gradient_accumulation_steps: 16 33 | learning_rate: 1.0e-5 34 | num_train_epochs: 1.0 35 | lr_scheduler_type: linear 36 | weight_decay: 0.0 37 | warmup_ratio: 0.03 38 | bf16: true 39 | ddp_timeout: 180000000 40 | 41 | ### eval 42 | val_size: 0.05 43 | per_device_eval_batch_size: 2 44 | eval_strategy: steps 45 | eval_steps: 5 46 | -------------------------------------------------------------------------------- /qalign/rm/train_configs/math/full/llama3_8b8b_full_reward_math_cot.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: meta-llama/Llama-3.1-8B-Instruct 3 | flash_attn: fa2 4 | #attn_implementation: eager 5 | ### method 6 | stage: rm 7 | do_train: true 8 | finetuning_type: full 9 | deepspeed: qflow/rm/train_configs/deepspeed/ds_z3_config.json 10 | save_safetensors: False 11 | 12 | ### dataset 13 | dataset_dir: qflow/rm/data/ 14 | dataset: cotnewextractmath14325llama-factory_math_llama-3.1-8b-instruct_128_1_train 15 | 16 | #eval_dataset: gsm8k_llama3.1-8B_test 
17 | template: llama3 18 | cutoff_len: 1024 19 | max_samples: 300000 20 | overwrite_cache: true 21 | preprocessing_num_workers: 16 22 | 23 | ### output 24 | output_dir: qflow/rm/artifacts/llama3/8b8b/mathcotnewextractmath14325/full/reward 25 | logging_steps: 1 26 | save_steps: 500 27 | plot_loss: true 28 | overwrite_output_dir: true 29 | report_to: wandb 30 | 31 | ### train 32 | per_device_train_batch_size: 1 33 | gradient_accumulation_steps: 32 34 | learning_rate: 1.0e-5 35 | num_train_epochs: 1.0 36 | lr_scheduler_type: linear 37 | weight_decay: 0.0 38 | warmup_ratio: 0.03 39 | bf16: true 40 | ddp_timeout: 180000000 41 | 42 | ### eval 43 | val_size: 0.05 44 | per_device_eval_batch_size: 2 45 | eval_strategy: steps 46 | eval_steps: 5 47 | -------------------------------------------------------------------------------- /qalign/rm/train_configs/math/full/tulu_8b8b_full_reward_math.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: allenai/Llama-3.1-Tulu-3-8B-SFT 3 | flash_attn: fa2 4 | #attn_implementation: eager 5 | ### method 6 | stage: rm 7 | do_train: true 8 | finetuning_type: full 9 | deepspeed: qflow/rm/train_configs/deepspeed/ds_z3_config.json 10 | save_safetensors: False 11 | 12 | ### dataset 13 | dataset_dir: qflow/rm/data/ 14 | dataset: llama-factory_math_llama-3.1-tulu-3-8b-sft_128_1_train 15 | #eval_dataset: gsm8k_llama3.1-8B_test 16 | template: llama3 17 | cutoff_len: 1024 18 | max_samples: 300000 19 | overwrite_cache: true 20 | preprocessing_num_workers: 16 21 | 22 | ### output 23 | output_dir: qflow/rm/artifacts/tulu/8b8b/math/full/reward 24 | logging_steps: 1 25 | save_steps: 500 26 | plot_loss: true 27 | overwrite_output_dir: true 28 | report_to: wandb 29 | 30 | ### train 31 | per_device_train_batch_size: 1 32 | gradient_accumulation_steps: 32 33 | learning_rate: 1.0e-5 34 | num_train_epochs: 1.0 35 | lr_scheduler_type: linear 36 | weight_decay: 0.0 37 | warmup_ratio: 0.03 38 | bf16: true 39 | ddp_timeout: 180000000 40 | 41 | ### eval 42 | val_size: 0.05 43 | per_device_eval_batch_size: 2 44 | eval_strategy: steps 45 | eval_steps: 5 46 | -------------------------------------------------------------------------------- /qalign/rm/valid_check.py: -------------------------------------------------------------------------------- 1 | import json 2 | from qflow.utils.math import get_last_math 3 | 4 | file_path = "qflow/rm/data/cotllama-factory_math_llama-3.1-8b-instruct_128_1_train.json" 5 | 6 | with open(file_path, "r") as f: 7 | data = json.load(f) 8 | 9 | 10 | count = 0 11 | for entry in data: 12 | right = entry["chosen"] 13 | wrong = entry["rejected"] 14 | 15 | right_answer = get_last_math(right) 16 | wrong_last_line = wrong.split("\n")[-1] 17 | 18 | count += int(right_answer in wrong_last_line) 19 | 20 | if right_answer in wrong_last_line: 21 | print(f"right: [{right_answer}], wrong: >>>\n{'*'*20}\n {wrong_last_line}") 22 | # import pdb 23 | 24 | # pdb.set_trace() 25 | 26 | print(count, 7500) 27 | import pdb 28 | 29 | pdb.set_trace() 30 | -------------------------------------------------------------------------------- /qalign/utils/data.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset, Value 2 | 3 | from langchain.prompts import PromptTemplate 4 | 5 | from functools import partial 6 | 7 | from transformers import AutoTokenizer 8 | 9 | import numpy as np 10 | from datasets import Dataset 11 | 12 | DEFAULT_USER_PROMPT = 
PromptTemplate.from_template("Question: {content}\n") 13 | DEFAULT_AI_PROMPT = PromptTemplate.from_template("Answer: {content}\n") 14 | 15 | ## qalign 16 | from qalign.utils.examples import MATH_EXAMPLARS, GSM8K_EXEMPLARS 17 | 18 | ## quest 19 | from quest.model.base import ( 20 | DEFAULT_TEMPLATE, 21 | ) 22 | 23 | 24 | class FlexiblePromptTemplate: 25 | def __init__(self, template): 26 | self.template = template 27 | 28 | def format(self, **kwargs): 29 | result = self.template 30 | # Only replace variables that exist in the template 31 | for key, value in kwargs.items(): 32 | placeholder = "{" + key + "}" 33 | if placeholder in result: 34 | result = result.replace(placeholder, str(value)) 35 | return result 36 | 37 | 38 | def parsehh_to_template( 39 | text: str, 40 | br="\n\n", 41 | available_roles={"Human", "Assistant"}, 42 | transfer_roles={ 43 | "Human": "user", 44 | "Assistant": "assistant", 45 | }, 46 | ): 47 | 48 | breaks = text.split(br) 49 | 50 | chlist = [] 51 | for ch in breaks: 52 | try: 53 | subbreaks = ch.split(":") 54 | role = subbreaks[0] 55 | if role in available_roles: 56 | 57 | chlist.append( 58 | { 59 | "role": transfer_roles[role], # .lower(), 60 | "content": subbreaks[1].strip(), 61 | } 62 | ) 63 | else: 64 | 65 | if len(chlist) > 0: 66 | chlist[-1]["content"] += br + ch 67 | 68 | except: 69 | pass 70 | 71 | return chlist 72 | 73 | 74 | def general_process_data_chat( 75 | chat_template_prompt, 76 | tokenizer, 77 | **extra_data, 78 | ): 79 | 80 | return { 81 | "prompt": tokenizer.apply_chat_template( 82 | chat_template_prompt, # [ {"role": "user", "content": ... }, {"role": "assistant", "content": ... }, ... ], 83 | tokenize=False, 84 | add_generation_prompt=True, 85 | ), 86 | "chat_template_prompt": chat_template_prompt, 87 | **extra_data, 88 | } 89 | 90 | 91 | def general_process_data_prompt( 92 | chat_template_prompt, 93 | tokenizer, 94 | user_prompt=DEFAULT_USER_PROMPT, 95 | ai_prompt=DEFAULT_AI_PROMPT, 96 | **extra_data, 97 | ): 98 | concatenated_prompt = tokenizer.bos_token 99 | 100 | for message in chat_template_prompt: 101 | role = message["role"] 102 | content = message["content"].strip() 103 | 104 | if role == "user": 105 | concatenated_prompt += user_prompt.format(content=content) 106 | elif role == "assistant": 107 | concatenated_prompt += ai_prompt.format(content=content) 108 | 109 | concatenated_prompt += "\n" # Add extra newline between Q&A pairs 110 | 111 | return { 112 | "prompt": concatenated_prompt, 113 | "chat_template_prompt": chat_template_prompt, 114 | **extra_data, 115 | } 116 | 117 | 118 | def general_process_data(chat_template_prompt, format="chat", **extra_data): 119 | 120 | if format == "chat": 121 | return general_process_data_chat(chat_template_prompt, **extra_data) 122 | else: 123 | return general_process_data_prompt(chat_template_prompt, **extra_data) 124 | 125 | 126 | def processhh_data(entry, tokenizer, format="chat", prompt_template=None, **kwargs): 127 | 128 | br = "\n\n" 129 | available_roles = {"Human", "Assistant"} 130 | # breaks = entry["chosen"].split(br) 131 | 132 | chat_template = parsehh_to_template( 133 | entry["chosen"], 134 | br=br, 135 | available_roles=available_roles, 136 | ) 137 | chat_template_reject = parsehh_to_template( 138 | entry["rejected"], 139 | br=br, 140 | available_roles=available_roles, 141 | ) 142 | 143 | return general_process_data( 144 | chat_template_prompt=chat_template[:-1], 145 | tokenizer=tokenizer, 146 | format=format, 147 | answer=chat_template[-1]["content"], 148 | 
bad_answer=chat_template_reject[-1]["content"], 149 | ) 150 | 151 | 152 | def processgsm_data( 153 | entry, tokenizer, prompt_template, format="chat", use_few_shot=False, **kwargs 154 | ): 155 | if use_few_shot: 156 | gsm_messages = [ 157 | [ 158 | { 159 | "role": "user", 160 | "content": prompt_template.format(prompt=sample["question"]), 161 | }, 162 | {"role": "assistant", "content": sample["cot_answer"]}, 163 | ] 164 | for sample in GSM8K_EXEMPLARS 165 | ] 166 | # flatten 167 | gsm_messages = [item for sublist in gsm_messages for item in sublist] 168 | else: 169 | gsm_messages = [] 170 | 171 | chat_template = [ 172 | { 173 | "role": "user", 174 | "content": prompt_template.format(prompt=entry["question"]), 175 | }, 176 | { 177 | "role": "assistant", 178 | "content": entry["answer"], 179 | }, 180 | ] 181 | 182 | return general_process_data_chat( 183 | chat_template_prompt=chat_template[:-1], 184 | format=format, 185 | tokenizer=tokenizer, 186 | answer=chat_template[-1]["content"], 187 | ) 188 | 189 | 190 | def processmath_data( 191 | entry, tokenizer, prompt_template, format="chat", use_few_shot=False, **kwargs 192 | ): 193 | 194 | if use_few_shot: 195 | math_messages = [ 196 | [ 197 | { 198 | "role": "user", 199 | "content": prompt_template.format(prompt=sample["question"]), 200 | }, 201 | {"role": "assistant", "content": sample["cot_answer"]}, 202 | ] 203 | for sample in MATH_EXAMPLARS 204 | ] 205 | # flatten 206 | math_messages = [item for sublist in math_messages for item in sublist] 207 | else: 208 | math_messages = [] 209 | 210 | chat_template = math_messages + [ 211 | { 212 | "role": "user", 213 | "content": prompt_template.format(prompt=entry["problem"]), 214 | }, 215 | { 216 | "role": "assistant", 217 | "content": entry["solution"], 218 | }, 219 | ] 220 | 221 | # Get the final answer for evaluation 222 | # answer = get_last_math(chat_template[-1]["content"]) 223 | # if answer is None: 224 | # return None 225 | 226 | return general_process_data( 227 | chat_template_prompt=chat_template[:-1], 228 | format=format, 229 | tokenizer=tokenizer, 230 | answer=entry["solution"], # answer, 231 | ) 232 | 233 | 234 | def processalpaca_data( 235 | entry, tokenizer, prompt_template, format="chat", use_few_shot=False, **kwargs 236 | ): 237 | # import pdb; pdb.set_trace() 238 | 239 | chat_template = [ 240 | { 241 | "role": "user", 242 | "content": prompt_template.format(prompt=entry["instruction"]), 243 | }, 244 | { 245 | "role": "assistant", 246 | "content": entry["output"], 247 | }, 248 | ] 249 | 250 | return general_process_data( 251 | chat_template_prompt=chat_template[:-1], 252 | format=format, 253 | tokenizer=tokenizer, 254 | answer=chat_template[-1]["content"], 255 | ) 256 | 257 | 258 | def processifeval_data( 259 | entry, tokenizer, prompt_template, format="chat", use_few_shot=False, **kwargs 260 | ): 261 | # import pdb; pdb.set_trace() 262 | 263 | chat_template = [ 264 | { 265 | "role": "user", 266 | "content": prompt_template.format(prompt=entry["prompt"]), 267 | } 268 | ] 269 | 270 | return general_process_data( 271 | chat_template_prompt=chat_template, 272 | format=format, 273 | tokenizer=tokenizer, 274 | answer="no answer", 275 | ) 276 | 277 | 278 | def processmmlu_data( 279 | entry, tokenizer, prompt_template, format="chat", use_few_shot=False, **kwargs 280 | ): 281 | 282 | choice_options = ["A", "B", "C", "D"] 283 | 284 | l = entry["subject"].split("_") 285 | subject = "" 286 | for x in l: 287 | subject += " " + x 288 | 289 | instruction = f"Choose the correct answer to the 
following multiple-choice question about {subject}.\n\n" 290 | 291 | prompt = entry["question"] 292 | choices = entry["choices"] 293 | answer = int(entry["answer"]) 294 | 295 | instruction += "Question: {}\n\n".format(prompt) 296 | 297 | for j in range(len(choices)): 298 | instruction += "{}). {}\n".format(choice_options[j], choices[j]) 299 | 300 | instruction += "\nProvide your reasoning about the answer and finish your answer with the letter corresponding to the correct option (e.g., A, B, C, or D).\n\n" 301 | 302 | prompt_ans = "\nAnswer:" 303 | 304 | # if include_answer: 305 | prompt_ans += " {}\n\n".format(choice_options[answer]) 306 | 307 | entry = {"prompt": instruction, "answer": prompt_ans} 308 | 309 | chat_template = [ 310 | { 311 | "role": "user", 312 | "content": prompt_template.format(prompt=instruction), 313 | }, 314 | { 315 | "role": "assistant", 316 | "content": prompt_ans, 317 | }, 318 | ] 319 | 320 | # Get the final answer for evaluation 321 | # answer = get_last_math(chat_template[-1]["content"]) 322 | # if answer is None: 323 | # return None 324 | 325 | return general_process_data( 326 | chat_template_prompt=chat_template[:-1], 327 | format=format, 328 | tokenizer=tokenizer, 329 | answer=prompt_ans, # answer, 330 | ) 331 | 332 | 333 | def processtruthfulqa_data( 334 | entry, tokenizer, prompt_template, format="chat", use_few_shot=False, **kwargs 335 | ): 336 | choices_key = "mc1_targets" 337 | 338 | choice_options = [chr(i) for i in range(ord("A"), ord("Z") + 1)] 339 | 340 | instruction = ( 341 | "Choose the correct answer to the following multiple-choice question.\n\n" 342 | ) 343 | 344 | prompt = entry["question"] 345 | choices = entry[choices_key]["choices"] 346 | answer = np.argmax(entry[choices_key]["labels"]) 347 | 348 | # answer = int(entry["answer"]) 349 | 350 | instruction += "Question: {}\n\n".format(prompt) 351 | 352 | for j in range(len(choices)): 353 | instruction += "{}). 
{}\n".format(choice_options[j], choices[j]) 354 | 355 | instruction += "\nProvide your reasoning about the answer and finish your answer with the letter corresponding to the correct option (e.g., A, B, C, or D).\n\n" 356 | 357 | prompt_ans = "\nAnswer:" 358 | 359 | # if include_answer: 360 | prompt_ans += " {}\n\n".format(choice_options[answer]) 361 | 362 | entry = {"prompt": instruction, "answer": prompt_ans} 363 | 364 | chat_template = [ 365 | { 366 | "role": "user", 367 | "content": prompt_template.format(prompt=instruction), 368 | }, 369 | { 370 | "role": "assistant", 371 | "content": prompt_ans, 372 | }, 373 | ] 374 | 375 | return general_process_data( 376 | chat_template_prompt=chat_template[:-1], 377 | format=format, 378 | tokenizer=tokenizer, 379 | answer=prompt_ans, # answer, 380 | ) 381 | 382 | 383 | def process_all_redux_data(): 384 | 385 | subjects = [ 386 | "anatomy", 387 | "business_ethics", 388 | "clinical_knowledge", 389 | "college_chemistry", 390 | "college_computer_science", 391 | "college_mathematics", 392 | "college_medicine", 393 | "college_physics", 394 | "econometrics", 395 | "electrical_engineering", 396 | "formal_logic", 397 | "global_facts", 398 | "high_school_chemistry", 399 | "high_school_mathematics", 400 | "high_school_physics", 401 | "high_school_statistics", 402 | "human_aging", 403 | "logical_fallacies", 404 | "machine_learning", 405 | "miscellaneous", 406 | "philosophy", 407 | "professional_accounting", 408 | "public_relations", 409 | "virology", 410 | "conceptual_physics", 411 | "high_school_us_history", 412 | "astronomy", 413 | "high_school_geography", 414 | "high_school_macroeconomics", 415 | "professional_law", 416 | ] 417 | 418 | ds = [] 419 | for subject in subjects: 420 | 421 | dsi = load_dataset("edinburgh-dawg/mmlu-redux", subject, split="test") 422 | ds.extend([{"subject": subject, **x} for x in dsi]) 423 | 424 | dataset = Dataset.from_list(ds) 425 | 426 | return dataset 427 | 428 | 429 | SUPPORTED_DATASETS = { 430 | "openai/gsm8k": ({"name": "socratic"}, processgsm_data, "openai/gsm8k"), 431 | "apple/GSM-Symbolic-p1": ({"name": "p1"}, processgsm_data, "apple/GSM-Symbolic"), 432 | "apple/GSM-Symbolic-p2": ({"name": "p2"}, processgsm_data, "apple/GSM-Symbolic"), 433 | "Anthropic/hh-rlhf": ({}, processhh_data, "Anthropic/hh-rlhf"), 434 | "lighteval/MATH": ({}, processmath_data, "lighteval/MATH"), 435 | "cais/mmlu": ({"name": "all"}, processmmlu_data, "cais/mmlu"), 436 | "truthfulqa/truthful_qa": ( 437 | {"name": "multiple_choice"}, 438 | processtruthfulqa_data, 439 | "truthfulqa/truthful_qa", 440 | ), 441 | "HuggingFaceH4/MATH-500": ( 442 | {}, 443 | processmath_data, 444 | "HuggingFaceH4/MATH-500", 445 | ), 446 | "tatsu-lab/alpaca_eval": ( 447 | {}, 448 | processalpaca_data, 449 | "tatsu-lab/alpaca_eval", 450 | ), 451 | "google/IFEval": ( 452 | {}, 453 | processifeval_data, 454 | "google/IFEval", 455 | ), 456 | "edinburgh-dawg/mmlu-redux": ( 457 | {}, 458 | processmmlu_data, 459 | "edinburgh-dawg/mmlu-redux", 460 | process_all_redux_data, 461 | ), 462 | # eval 463 | # "tatsu-lab/alpaca_eval", "alpaca_eval")["eval"] 464 | # processifeval_data 465 | } 466 | 467 | # edinburgh-dawg/mmlu-redux 468 | 469 | 470 | # ds = load_dataset("cais/mmlu", "abstract_algebra", split="test") 471 | 472 | 473 | def get_data_iterable( 474 | model_path: str, 475 | dataset_path: str, 476 | split: str = "test", 477 | n=None, 478 | prompt_template=DEFAULT_TEMPLATE, 479 | use_few_shot=False, 480 | format="chat", 481 | **kwargs, 482 | ): 483 | 484 | tokenizer = 
AutoTokenizer.from_pretrained(model_path) 485 | 486 | if dataset_path in SUPPORTED_DATASETS: 487 | 488 | resources = SUPPORTED_DATASETS[dataset_path] 489 | 490 | if len(resources) == 3: 491 | kwargs, pfunc, dataset_path = SUPPORTED_DATASETS[dataset_path] 492 | 493 | ds = load_dataset( 494 | dataset_path, 495 | split=split, 496 | # desc=None, 497 | **kwargs, 498 | ) 499 | else: 500 | kwargs, pfunc, dataset_path, load_func = SUPPORTED_DATASETS[dataset_path] 501 | ds = load_func() 502 | 503 | if "answer" in ds.column_names: 504 | ds = ds.cast_column("answer", Value("string")) 505 | 506 | # limit the size of the dataset 507 | 508 | if n is not None: 509 | ds = ds.select(list(range(n))) 510 | 511 | ds = ds.map( 512 | partial( 513 | pfunc, 514 | tokenizer=tokenizer, 515 | prompt_template=prompt_template, 516 | use_few_shot=use_few_shot, 517 | format=format, 518 | ), 519 | load_from_cache_file=False, 520 | # desc=None, # Disable tqdm progress bar 521 | ) 522 | 523 | else: 524 | ds = load_dataset(dataset_path, split=split, trust_remote_code=True, **kwargs) 525 | 526 | if n is not None: 527 | ds = list(ds)[:n] 528 | 529 | return list(ds) 530 | -------------------------------------------------------------------------------- /qalign/utils/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import multiprocessing as mp 4 | 5 | 6 | from typing import * 7 | 8 | 9 | from tqdm import tqdm 10 | 11 | import numpy as np 12 | 13 | ## qalign 14 | from qalign.utils.ifeval.eval import test_instruction_following_loose 15 | from qalign.utils.math import get_last_math, get_last_option, get_last_number 16 | 17 | ## quest 18 | from quest.reward.model import ValueHead 19 | from quest.reward.remote import RemoteReward 20 | from quest.utils.list import ( 21 | flatten_list, 22 | unflatten_list, 23 | get_unique_mapping, 24 | invert_unique_mapping, 25 | chunked, 26 | ) 27 | from quest import ( 28 | Reward, 29 | RewardModel, 30 | ContextualRewardModel, 31 | ) 32 | 33 | ## expkit 34 | from expkit import ( 35 | ExpSetup, 36 | Exp, 37 | Evalutor, 38 | ) 39 | from expkit.ops import proj 40 | 41 | 42 | def compress_predictions(instances, n=None): 43 | prompts, tokens, vocabs, counts = ( 44 | [], 45 | [], 46 | [], 47 | [], 48 | ) 49 | 50 | # instances = experiment.instances() 51 | if n is not None: 52 | instances = instances[:n] 53 | 54 | for i in instances: 55 | 56 | tksi, vcbi = get_unique_mapping( 57 | list( 58 | map( 59 | proj("text"), 60 | i["outputs"], 61 | ) 62 | ) 63 | ) 64 | 65 | promptsi = [i["input"]["prompt"]] * len(vcbi) 66 | 67 | prompts.append(promptsi) 68 | vocabs.append(vcbi) 69 | tokens.append(tksi) 70 | counts.append(len(vcbi)) 71 | 72 | return prompts, tokens, vocabs, counts 73 | 74 | 75 | class ExactEval(Evalutor): 76 | 77 | def __init__(self, process_fn, name="exact"): 78 | super().__init__(name) 79 | 80 | self.process_fn = process_fn 81 | 82 | def eval(self, experiment: Exp): 83 | 84 | results = [] 85 | 86 | instances = experiment.instances() 87 | 88 | if not isinstance(instances, list): 89 | instances = [instances] 90 | 91 | for i in tqdm(instances): 92 | 93 | preds = [self.process_fn(x["text"]) for x in i["outputs"]] 94 | 95 | target = self.process_fn(i["input"]["answer"]) 96 | 97 | scores = list( 98 | map( 99 | lambda p: int(p == target), 100 | preds, 101 | ) 102 | ) 103 | 104 | results.append({"scores": scores}) 105 | 106 | return results 107 | 108 | 109 | class IFEval(Evalutor): 110 | 111 | def __init__(self, name="ifeval"): 112 | 
super().__init__(name) 113 | 114 | def create_eval_pairs(self, experiment): 115 | """Create all input-output pairs for evaluation.""" 116 | instances = experiment.instances() 117 | if not isinstance(instances, list): 118 | instances = [instances] 119 | 120 | eval_pairs = [] 121 | for idx, instance in enumerate(instances): 122 | for output in instance["outputs"]: 123 | eval_pairs.append( 124 | { 125 | "instance_idx": idx, 126 | "input": instance["input"], 127 | "output": { 128 | "text": output["text"].replace("<|end_of_text|>", "") 129 | }, 130 | } 131 | ) 132 | return eval_pairs, len(instances) 133 | 134 | def process_single_pair(self, pair): 135 | """Process a single input-output pair.""" 136 | 137 | return { 138 | "instance_idx": pair["instance_idx"], 139 | "score": test_instruction_following_loose(pair["input"], pair["output"])[ 140 | "follow_all_instructions" 141 | ], 142 | } 143 | 144 | def eval(self, experiment: Exp): 145 | # Step 1: Create all evaluation pairs 146 | eval_pairs, num_instances = self.create_eval_pairs(experiment) 147 | 148 | # Step 2: Process pairs in parallel 149 | with mp.Pool() as pool: 150 | results_flat = list( 151 | tqdm( # pool.i 152 | pool.map(self.process_single_pair, eval_pairs), 153 | total=len(eval_pairs), 154 | desc="Processing pairs", 155 | ) 156 | ) 157 | 158 | # Step 3: Reorganize results into original structure 159 | results = [] 160 | for i in range(num_instances): 161 | instance_scores = [ 162 | r["score"] for r in results_flat if r["instance_idx"] == i 163 | ] 164 | results.append({"scores": instance_scores}) 165 | 166 | return results 167 | 168 | 169 | class ExactLastNumberEval(ExactEval): 170 | 171 | def __init__(self): 172 | super().__init__( 173 | process_fn=get_last_number, 174 | name="lastnumber", 175 | ) 176 | 177 | 178 | class ExactMATHEval(ExactEval): 179 | 180 | def __init__(self): 181 | super().__init__( 182 | process_fn=get_last_math, 183 | name="lastmath", 184 | ) 185 | 186 | 187 | class ExactQAEval(ExactEval): 188 | 189 | def __init__(self): 190 | super().__init__( 191 | process_fn=get_last_option, 192 | name="lastoption", 193 | ) 194 | 195 | 196 | class RewardEval(Evalutor): 197 | 198 | def __init__(self, reward: Reward, n=None, chunk_size=128): 199 | super().__init__(reward.get_name()) 200 | self.reward = reward 201 | self.n = n 202 | self.chunk_size = chunk_size 203 | 204 | def eval(self, experiment: Exp): 205 | 206 | all_duplicated_scores = [] 207 | 208 | for instance_chunk in chunked( 209 | experiment.instances(lazy_iterable=True), self.chunk_size 210 | ): 211 | # Process current chunk of instances 212 | prompts, tokens, vocabs, counts = compress_predictions( 213 | instance_chunk, n=self.n 214 | ) 215 | 216 | # Set context for this chunk only 217 | if isinstance( 218 | self.reward, 219 | (ContextualRewardModel, ValueHead, RemoteReward), 220 | ): 221 | self.reward.set_context(context=flatten_list(prompts)) 222 | 223 | # Evaluate just this chunk's candidates 224 | chunk_flat_vocabs = flatten_list(vocabs) 225 | chunk_scores = self.reward.evaluate( 226 | candidates=chunk_flat_vocabs, 227 | use_tqdm=True, # Disable per-chunk progress bars 228 | ) 229 | 230 | # Process scores for this chunk only 231 | chunk_unflattened = unflatten_list(chunk_scores, counts) 232 | all_duplicated_scores.extend( 233 | [ 234 | {"scores": invert_unique_mapping(tks, rsi)} 235 | for tks, rsi in zip(tokens, chunk_unflattened) 236 | ] 237 | ) 238 | 239 | del instance_chunk 240 | del prompts 241 | del tokens 242 | del vocabs 243 | del counts 244 | 245 | return 
all_duplicated_scores 246 | -------------------------------------------------------------------------------- /qalign/utils/examples.py: -------------------------------------------------------------------------------- 1 | # exemplars we will use to prompt the model 2 | # These examplars are from the Table 20 of CoT paper (https://arxiv.org/pdf/2201.11903.pdf). 3 | GSM8K_EXEMPLARS = [ 4 | { 5 | "question": "There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?", 6 | "cot_answer": "There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6. So the answer is 6.", 7 | "short_answer": "6", 8 | }, 9 | { 10 | "question": "If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?", 11 | "cot_answer": "There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. So the answer is 5.", 12 | "short_answer": "5", 13 | }, 14 | { 15 | "question": "Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?", 16 | "cot_answer": "Originally, Leah had 32 chocolates. Her sister had 42. So in total they had 32 + 42 = 74. After eating 35, they had 74 - 35 = 39. So the answer is 39.", 17 | "short_answer": "39", 18 | }, 19 | { 20 | "question": "Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?", 21 | "cot_answer": "Jason started with 20 lollipops. Then he had 12 after giving some to Denny. So he gave Denny 20 - 12 = 8. So the answer is 8.", 22 | "short_answer": "8", 23 | }, 24 | { 25 | "question": "Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?", 26 | "cot_answer": "Shawn started with 5 toys. If he got 2 toys each from his mom and dad, then that is 4 more toys. 5 + 4 = 9. So the answer is 9.", 27 | "short_answer": "9", 28 | }, 29 | { 30 | "question": "There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?", 31 | "cot_answer": "There were originally 9 computers. For each of 4 days, 5 more computers were added. So 5 * 4 = 20 computers were added. 9 + 20 is 29. So the answer is 29.", 32 | "short_answer": "29", 33 | }, 34 | { 35 | "question": "Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?", 36 | "cot_answer": "Michael started with 58 golf balls. After losing 23 on tuesday, he had 58 - 23 = 35. After losing 2 more, he had 35 - 2 = 33 golf balls. So the answer is 33.", 37 | "short_answer": "33", 38 | }, 39 | { 40 | "question": "Olivia has $23. She bought five bagels for $3 each. How much money does she have left?", 41 | "cot_answer": "Olivia had 23 dollars. 5 bagels for 3 dollars each will be 5 x 3 = 15 dollars. So she has 23 - 15 dollars left. 23 - 15 is 8. 
So the answer is 8.", 42 | "short_answer": "8", 43 | }, 44 | ] 45 | 46 | # These examplars are from the DeepSeekMath GitHub repository (https://github.com/deepseek-ai/DeepSeek-Math/tree/main/evaluation/few_shot_prompts) 47 | MATH_EXAMPLARS = [ 48 | { 49 | "question": "Find the domain of the expression $\\frac{\\sqrt{x-2}}{\\sqrt{5-x}}$.}", 50 | "cot_answer": "The expressions inside each square root must be non-negative.\nTherefore, $x-2 \\ge 0$, so $x\\ge2$, and $5 - x \\ge 0$, so $x \\le 5$.\nAlso, the denominator cannot be equal to zero, so $5-x>0$, which gives $x<5$.\nTherefore, the domain of the expression is $\\boxed{[2,5)}$.", 51 | "short_answer": "[2,5)", 52 | }, 53 | { 54 | "question": "If $\\det \\mathbf{A} = 2$ and $\\det \\mathbf{B} = 12,$ then find $\\det (\\mathbf{A} \\mathbf{B}).$", 55 | "cot_answer": "We have that $\\det (\\mathbf{A} \\mathbf{B}) = (\\det \\mathbf{A})(\\det \\mathbf{B}) = (2)(12) = \\boxed{24}.$", 56 | "short_answer": "24", 57 | }, 58 | { 59 | "question": "Terrell usually lifts two 20-pound weights 12 times. If he uses two 15-pound weights instead, how many times must Terrell lift them in order to lift the same total weight?", 60 | "cot_answer": "If Terrell lifts two 20-pound weights 12 times, he lifts a total of $2\\cdot 12\\cdot20=480$ pounds of weight. If he lifts two 15-pound weights instead for $n$ times, he will lift a total of $2\\cdot15\\cdot n=30n$ pounds of weight. Equating this to 480 pounds, we can solve for $n$: \\begin{align*}\n30n&=480\\\\\n\\Rightarrow\\qquad n&=480/30=\\boxed{16}\n\\end{align*}", 61 | "short_answer": "16", 62 | }, 63 | { 64 | "question": "If the system of equations\n\n\\begin{align*}\n6x-4y&=a,\\\\\n6y-9x &=b.\n\\end{align*}has a solution $(x, y)$ where $x$ and $y$ are both nonzero, find $\\frac{a}{b},$ assuming $b$ is nonzero.", 65 | "cot_answer": "If we multiply the first equation by $-\\frac{3}{2}$, we obtain\n\n$$6y-9x=-\\frac{3}{2}a.$$Since we also know that $6y-9x=b$, we have\n\n$$-\\frac{3}{2}a=b\\Rightarrow\\frac{a}{b}=\\boxed{-\\frac{2}{3}}.$$", 66 | "short_answer": "-\\frac{2}{3}", 67 | }, 68 | ] 69 | -------------------------------------------------------------------------------- /qalign/utils/experiment.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Dict, Any, List 2 | from datetime import datetime 3 | from tqdm import tqdm 4 | import os 5 | import copy 6 | 7 | ## expkit 8 | from expkit.exp import Exp 9 | from expkit.storage import DiskStorage 10 | 11 | ## quest 12 | from quest.reward.model import ContextualRewardModel, ValueHead 13 | from quest.reward.remote import RemoteReward 14 | from quest.proposal import RLHFSuffixProposal 15 | from quest.core import Quest 16 | 17 | ## qalign 18 | from qalign.utils.data import FlexiblePromptTemplate 19 | from qalign.utils.data import get_data_iterable 20 | from quest.utils.list import chunked 21 | 22 | 23 | ## literegistry 24 | from literegistry import RegistryClient, FileSystemKVStore 25 | 26 | 27 | def create_experiment( 28 | save_path: str, 29 | variant: str, 30 | model_path: str, 31 | dataset_path: str, 32 | n: int, 33 | temperature: float, 34 | steps: int, 35 | max_new_tokens: int = 100, 36 | max_prompt_length: int = 600, 37 | batch_size: int = 64, 38 | split: str = "test", 39 | prompt_template: str = "{prompt}", 40 | stop_tokens: Optional[List[str]] = None, 41 | additional_meta: Optional[Dict[str, Any]] = None, 42 | format: str = "chat", # either "chat" or "prompt" 43 | use_few_shot: bool = False, 
44 | ) -> Exp: 45 | """ 46 | Creates a standardized experiment with common metadata. 47 | """ 48 | if stop_tokens is None: 49 | stop_tokens = [] 50 | 51 | meta = { 52 | "steps": steps, 53 | "temperature": temperature, 54 | "n": n, 55 | "model_path": model_path, 56 | "variant": variant, 57 | "stop_tokens": stop_tokens, 58 | "max_new_tokens": max_new_tokens, 59 | "max_prompt_length": max_prompt_length, 60 | "at": datetime.now().isoformat(), 61 | "dataset": dataset_path, 62 | "split": split, 63 | "prompt_template": prompt_template, 64 | "batch_size": batch_size, 65 | "format": format, 66 | "use_few_shot": use_few_shot, 67 | } 68 | 69 | if additional_meta: 70 | meta.update(additional_meta) 71 | 72 | return Exp( 73 | storage=DiskStorage(save_path, "rw"), 74 | meta=meta, 75 | ) 76 | 77 | 78 | def create_extension_experiment(storage, experiment, new_steps=1024): 79 | samples = [ 80 | { 81 | "input": ( 82 | data["input"]["input"] if "input" in data["input"] else data["input"] 83 | ), 84 | "completion": data["outputs"][-1]["text"], 85 | "reward": float(data["outputs"][-1]["reward"]), 86 | } 87 | for data in experiment.instances(lazy_iterable=True) 88 | ] 89 | 90 | meta = copy.deepcopy(experiment.meta) 91 | 92 | meta["steps"] = new_steps 93 | meta["bootstrap"] = samples 94 | meta["link"] = experiment.get_name() 95 | 96 | new_exp = Exp( 97 | storage=storage, 98 | name=experiment.get_name() + "-extension", 99 | meta=meta, 100 | ) 101 | 102 | return new_exp 103 | 104 | 105 | def create_vllm_model( 106 | model_path: str, 107 | temperature: float, 108 | max_new_tokens: int = 100, 109 | max_prompt_length: int = 600, 110 | stop_tokens: Optional[List[str]] = None, 111 | device_count: int = 1, 112 | gpu_memory_utilization: float = 0.8, 113 | # prompt_template: Optional[str] = None, 114 | enforce_eager: bool = False, 115 | remote: bool = False, 116 | ): 117 | """ 118 | Creates a standardized VLLM model instance with common configurations. 119 | """ 120 | if stop_tokens is None: 121 | stop_tokens = [""] 122 | 123 | extra_args = {} 124 | if "bnb" in model_path: 125 | extra_args.update( 126 | { 127 | "trust_remote_code": True, 128 | "quantization": "bitsandbytes", 129 | "load_format": "bitsandbytes", 130 | } 131 | ) 132 | 133 | model_args = { 134 | "model_path": model_path, 135 | "download_dir": os.environ.get("HF_HOME", "/tmp/"), 136 | "stop_tokens": stop_tokens, 137 | "temperature": temperature, 138 | "gpu_memory_utilization": gpu_memory_utilization, 139 | "dtype": "bfloat16", 140 | "max_new_tokens": max_new_tokens, 141 | "max_prompt_length": max_prompt_length, 142 | "tensor_parallel_size": device_count, 143 | "enable_prefix_caching": True, 144 | "enforce_eager": enforce_eager, 145 | **extra_args, 146 | } 147 | 148 | if remote: 149 | from quest.model.remote import RemoteVLLM 150 | 151 | registry = RegistryClient( 152 | store=FileSystemKVStore("/gscratch/ark/graf/registry"), 153 | max_history=3600, 154 | cache_ttl=60, 155 | service_type="model_path", 156 | ) 157 | 158 | return RemoteVLLM(registry=registry, **model_args) 159 | else: 160 | from quest.model.vllm import VLLM 161 | 162 | return VLLM(**model_args) 163 | 164 | 165 | def process_batch_outputs( 166 | chain_outputs: Any, batch_size: int 167 | ) -> List[List[Dict[str, Any]]]: 168 | """ 169 | Processes batch outputs from a Quest chain into a standardized format. 
170 | """ 171 | outputs = [] 172 | for i in range(batch_size): 173 | outputs.append( 174 | [ 175 | { 176 | "t": s["t"], 177 | **{k: v[i] for k, v in s.items() if k != "t"}, 178 | } 179 | for s in chain_outputs.state_path 180 | ] 181 | ) 182 | return outputs 183 | 184 | 185 | def get_batched_data( 186 | model, 187 | dataset_path: str, 188 | split: str, 189 | n: int, 190 | batch_size: int, 191 | prompt_template: str, 192 | start_index: int = 0, 193 | num_chains: int = 1, 194 | completed: int = 0, 195 | format="chat", 196 | use_few_shot=False, 197 | ) -> List[Any]: 198 | """ 199 | Gets batched data from a dataset using standard configurations. 200 | """ 201 | data_iterable = get_data_iterable( 202 | model_path=model.model_path, 203 | dataset_path=dataset_path, 204 | split=split, 205 | n=start_index + n, 206 | tokenizer=model.tokenizer, 207 | prompt_template=FlexiblePromptTemplate(prompt_template), 208 | format=format, 209 | use_few_shot=use_few_shot, 210 | ) 211 | 212 | if start_index > 0: 213 | data_iterable = data_iterable[start_index : start_index + n] 214 | 215 | if num_chains > 1: 216 | data_iterable = [x for x in data_iterable for _ in range(num_chains)] 217 | 218 | data_iterable = data_iterable[completed:] 219 | 220 | batches = [] 221 | for i in range(0, len(data_iterable), batch_size): 222 | batches.append(data_iterable[i : i + batch_size]) 223 | 224 | return batches 225 | 226 | 227 | def calculate_reward_scores( 228 | experiment: Exp, 229 | reward_key: Optional[str] = None, 230 | ) -> List[Dict[str, List[float]]]: 231 | """ 232 | Calculates reward scores for experiment instances. 233 | """ 234 | beta = experiment.meta.get("beta", 1.0) 235 | 236 | if not reward_key and "reward_model_path" in experiment.meta: 237 | reward_key = "crm:" + experiment.meta["reward_model_path"].split(".")[ 238 | 0 239 | ].replace("/", "-") 240 | 241 | return [ 242 | {"scores": [float(o["reward"]) * beta for o in i["outputs"]]} 243 | for i in experiment.instances(lazy_iterable=True) 244 | ] 245 | 246 | 247 | def run_ancestral(experiment, model, steps, data_batches): 248 | 249 | steps = experiment.meta["steps"] 250 | 251 | # Process each batch 252 | for data_batch in tqdm(data_batches): 253 | completions_txt = model.ancestral(data_batch, n=steps) 254 | outputs = [ 255 | [{"text": state_t} for state_t in instance_txt] 256 | for instance_txt in completions_txt 257 | ] 258 | 259 | experiment.add_instances( 260 | inputs=data_batch, 261 | outputs=outputs, 262 | ) 263 | 264 | 265 | def run_quest( 266 | experiment, 267 | model, 268 | steps, 269 | data_batches, 270 | reward_model_batch_size=64, 271 | reward_device=1, 272 | reward_device_count=1, 273 | remote=False, 274 | ): 275 | 276 | reward_type = experiment.meta["reward_type"] 277 | 278 | if remote: 279 | registry = RegistryClient( 280 | store=FileSystemKVStore("/gscratch/ark/graf/registry"), 281 | max_history=3600, 282 | cache_ttl=60, 283 | service_type="model_path", 284 | ) 285 | 286 | reward = RemoteReward( 287 | registry=registry, 288 | model_path=experiment.meta["reward_model_path"], 289 | reward_type=reward_type, 290 | # batch_size=reward_model_batch_size, 291 | ) 292 | 293 | else: 294 | if reward_type == "contextual": 295 | reward = ContextualRewardModel( 296 | model_path=experiment.meta["reward_model_path"], 297 | # batch_size=reward_model_batch_size, 298 | device=reward_device, 299 | device_count=reward_device_count, 300 | ) 301 | elif reward_type == "value": 302 | reward = ValueHead( 303 | model_path=experiment.meta["reward_model_path"], 304 | # 
batch_size=reward_model_batch_size, 305 | device=reward_device, 306 | device_count=reward_device_count, 307 | ) # sentiment model. 308 | else: 309 | raise ValueError(f"Unknown reward type: {reward_type}") 310 | 311 | # Process each batch 312 | for data_batch in data_batches: 313 | context = [model.get_prompt(**data) for data in data_batch] 314 | reward.set_context(context) 315 | 316 | chain = Quest( 317 | input_data=data_batch, 318 | proposal=RLHFSuffixProposal( 319 | model=model, 320 | reward=reward, 321 | ), 322 | beta=experiment.meta["beta"], 323 | ) 324 | 325 | chain_outputs = chain.run( 326 | steps=steps, 327 | use_tqdm=True, 328 | ) 329 | 330 | outputs = process_batch_outputs(chain_outputs, len(data_batch)) 331 | experiment.add_instances( 332 | inputs=data_batch, 333 | outputs=outputs, 334 | ) 335 | 336 | # Calculate and add reward scores 337 | scores = calculate_reward_scores(experiment) 338 | experiment.add_eval(reward.get_name(), scores) 339 | 340 | return experiment 341 | 342 | 343 | def run_quest_bootstrap( 344 | experiment, 345 | model, 346 | steps, 347 | reward_device=1, 348 | reward_device_count=1, 349 | remote=False, 350 | ): 351 | 352 | reward_type = experiment.meta["reward_type"] 353 | 354 | if remote: 355 | registry = RegistryClient( 356 | store=FileSystemKVStore("/gscratch/ark/graf/registry"), 357 | max_history=3600, 358 | cache_ttl=60, 359 | service_type="model_path", 360 | ) 361 | 362 | reward = RemoteReward( 363 | registry=registry, 364 | model_path=experiment.meta["reward_model_path"], 365 | reward_type=reward_type, 366 | # batch_size=reward_model_batch_size, 367 | ) 368 | 369 | else: 370 | if reward_type == "contextual": 371 | reward = ContextualRewardModel( 372 | model_path=experiment.meta["reward_model_path"], 373 | # batch_size=reward_model_batch_size, 374 | device=reward_device, 375 | device_count=reward_device_count, 376 | ) 377 | elif reward_type == "value": 378 | reward = ValueHead( 379 | model_path=experiment.meta["reward_model_path"], 380 | # batch_size=reward_model_batch_size, 381 | device=reward_device, 382 | device_count=reward_device_count, 383 | ) # sentiment model. 
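        # Note: "contextual" wraps ContextualRewardModel, which scores completions
        # against the prompt context set further below, while "value" wraps a scalar
        # ValueHead (e.g. a sentiment-style scorer); any other reward_type is rejected.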
384 | else: 385 | raise ValueError(f"Unknown reward type: {reward_type}") 386 | 387 | for data_batch in chunked( 388 | experiment.meta["bootstrap"], experiment.get("batch_size") 389 | ): 390 | context = [data["input"]["prompt"] for data in data_batch] 391 | 392 | reward.set_context(context) 393 | 394 | chain = Quest( 395 | input_data=[data["input"] for data in data_batch], 396 | proposal=RLHFSuffixProposal( 397 | model=model, 398 | reward=reward, 399 | ), 400 | beta=experiment.meta["beta"], 401 | ) 402 | 403 | chain_outputs = chain.run( 404 | steps=steps, 405 | use_tqdm=True, 406 | warm_start=[ 407 | {"completion": data["completion"], "reward": data["reward"]} 408 | for data in data_batch 409 | ], 410 | ) 411 | 412 | outputs = process_batch_outputs(chain_outputs, len(data_batch)) 413 | 414 | experiment.add_instances( 415 | inputs=data_batch, 416 | outputs=outputs, 417 | ) 418 | 419 | # Calculate and add reward scores 420 | scores = calculate_reward_scores(experiment) 421 | experiment.add_eval(reward.get_name(), scores) 422 | 423 | return experiment 424 | 425 | 426 | # Create model 427 | def run_experiment( 428 | experiment, 429 | gpu_memory_utilization=0.95, 430 | device_count=1, 431 | reward_model_batch_size=64, 432 | reward_device=1, 433 | reward_device_count=1, 434 | remote=False, 435 | ): 436 | 437 | # Create model 438 | model = create_vllm_model( 439 | model_path=experiment.meta["model_path"], 440 | temperature=experiment.meta["temperature"], 441 | max_new_tokens=experiment.meta["max_new_tokens"], 442 | max_prompt_length=experiment.meta["max_prompt_length"], 443 | device_count=device_count, 444 | gpu_memory_utilization=gpu_memory_utilization, 445 | stop_tokens=experiment.meta["stop_tokens"], 446 | remote=remote, 447 | ) 448 | 449 | completed = len(experiment.instances()) 450 | 451 | # Get batched data with start index 452 | data_batches = get_batched_data( 453 | model=model, 454 | dataset_path=experiment.meta["dataset"], 455 | split=experiment.meta["split"], 456 | n=experiment.meta["n"], 457 | batch_size=experiment.meta.get("batch_size", 64), 458 | prompt_template=experiment.meta["prompt_template"], 459 | start_index=experiment.meta.get("i", 0), 460 | num_chains=experiment.meta.get("num_chains", 1), 461 | completed=completed, 462 | format=experiment.meta.get("format", "chat"), 463 | use_few_shot=experiment.meta.get("use_few_shot", False), 464 | ) 465 | 466 | if experiment.meta["variant"] == "ancestral": 467 | run_ancestral( 468 | experiment=experiment, 469 | model=model, 470 | steps=experiment.meta["steps"], 471 | data_batches=data_batches, 472 | ) 473 | elif experiment.meta["variant"] == "quest-rlhf": 474 | run_quest( 475 | experiment=experiment, 476 | model=model, 477 | steps=experiment.meta["steps"], 478 | data_batches=data_batches, 479 | reward_model_batch_size=reward_model_batch_size, 480 | reward_device=reward_device, 481 | reward_device_count=reward_device_count, 482 | remote=True, # remote, 483 | ) 484 | 485 | 486 | # Create model 487 | def run_experiment_remote( 488 | experiment, 489 | ): 490 | 491 | # Create model 492 | model = create_vllm_model( 493 | model_path=experiment.meta["model_path"], 494 | temperature=experiment.meta["temperature"], 495 | max_new_tokens=experiment.meta["max_new_tokens"], 496 | max_prompt_length=experiment.meta["max_prompt_length"], 497 | device_count=1, 498 | gpu_memory_utilization=0.95, 499 | stop_tokens=experiment.meta["stop_tokens"], 500 | remote=True, 501 | ) 502 | 503 | completed = len(experiment.instances()) 504 | 505 | # Get batched data 
with start index 506 | if "bootstrap" in experiment.meta: 507 | 508 | if "bootstrap" in experiment.meta: 509 | run_quest_bootstrap( 510 | experiment=experiment, 511 | model=model, 512 | steps=experiment.meta["steps"], 513 | reward_device=0, 514 | reward_device_count=1, 515 | remote=True, 516 | ) 517 | 518 | else: 519 | 520 | data_batches = get_batched_data( 521 | model=model, 522 | dataset_path=experiment.meta["dataset"], 523 | split=experiment.meta["split"], 524 | n=experiment.meta["n"], 525 | batch_size=experiment.meta.get("batch_size", 64), 526 | prompt_template=experiment.meta["prompt_template"], 527 | start_index=experiment.meta.get("i", 0), 528 | num_chains=experiment.meta.get("num_chains", 1), 529 | completed=completed, 530 | format=experiment.meta.get("format", "chat"), 531 | use_few_shot=experiment.meta.get("use_few_shot", False), 532 | ) 533 | 534 | if experiment.meta["variant"] == "ancestral": 535 | run_ancestral( 536 | experiment=experiment, 537 | model=model, 538 | steps=experiment.meta["steps"], 539 | data_batches=data_batches, 540 | ) 541 | elif experiment.meta["variant"] == "quest-rlhf": 542 | 543 | run_quest( 544 | experiment=experiment, 545 | model=model, 546 | steps=experiment.meta["steps"], 547 | data_batches=data_batches, 548 | # reward_model_batch_size=experiment.meta.get("batch_size", 64), 549 | reward_device=0, 550 | reward_device_count=1, 551 | remote=True, 552 | ) 553 | -------------------------------------------------------------------------------- /qalign/utils/ifeval/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/goncalorafaria/qalign/39b1df52c913bda4f94e656018f68915fd06299b/qalign/utils/ifeval/__init__.py -------------------------------------------------------------------------------- /qalign/utils/ifeval/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/goncalorafaria/qalign/39b1df52c913bda4f94e656018f68915fd06299b/qalign/utils/ifeval/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /qalign/utils/ifeval/__pycache__/eval.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/goncalorafaria/qalign/39b1df52c913bda4f94e656018f68915fd06299b/qalign/utils/ifeval/__pycache__/eval.cpython-38.pyc -------------------------------------------------------------------------------- /qalign/utils/ifeval/__pycache__/instructions.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/goncalorafaria/qalign/39b1df52c913bda4f94e656018f68915fd06299b/qalign/utils/ifeval/__pycache__/instructions.cpython-38.pyc -------------------------------------------------------------------------------- /qalign/utils/ifeval/__pycache__/instructions_util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/goncalorafaria/qalign/39b1df52c913bda4f94e656018f68915fd06299b/qalign/utils/ifeval/__pycache__/instructions_util.cpython-38.pyc -------------------------------------------------------------------------------- /qalign/utils/ifeval/__pycache__/registry.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/goncalorafaria/qalign/39b1df52c913bda4f94e656018f68915fd06299b/qalign/utils/ifeval/__pycache__/registry.cpython-38.pyc -------------------------------------------------------------------------------- /qalign/utils/ifeval/eval.py: -------------------------------------------------------------------------------- 1 | from expkit import ExpSetup 2 | from expkit.storage import DiskStorage, CachedRODiskStorage 3 | from expkit import ExpSetup, Exp, ExpSetup 4 | from tqdm import tqdm 5 | import qflow.utils.ifeval.registry as instructions_registry 6 | 7 | 8 | def test_instruction_following_loose( 9 | inp, 10 | oup, 11 | ): 12 | """Tests response for an upper bound for following instructions.""" 13 | response = oup["text"] 14 | r = response.split("\n") 15 | response_remove_first = "\n".join(r[1:]).strip() 16 | response_remove_last = "\n".join(r[:-1]).strip() 17 | response_remove_both = "\n".join(r[1:-1]).strip() 18 | revised_response = response.replace("*", "") 19 | revised_response_remove_first = response_remove_first.replace("*", "") 20 | revised_response_remove_last = response_remove_last.replace("*", "") 21 | revised_response_remove_both = response_remove_both.replace("*", "") 22 | 23 | all_responses = [ 24 | response, 25 | revised_response, 26 | response_remove_first, 27 | response_remove_last, 28 | response_remove_both, 29 | revised_response_remove_first, 30 | revised_response_remove_last, 31 | revised_response_remove_both, 32 | ] 33 | instruction_list = inp["instruction_id_list"] 34 | is_following_list = [] 35 | 36 | for index, instruction_id in enumerate(instruction_list): 37 | instruction_cls = instructions_registry.INSTRUCTION_DICT[instruction_id] 38 | instruction = instruction_cls(instruction_id) 39 | 40 | kwargs = {k: v for k, v in inp["kwargs"][index].items() if v} 41 | 42 | instruction.build_description(**kwargs) 43 | args = instruction.get_instruction_args() 44 | 45 | if args and "prompt" in args: 46 | datasetprompt = inp["chat_template_prompt"][0]["content"] 47 | instruction.build_description(prompt=datasetprompt) 48 | 49 | is_following = False 50 | for r in all_responses: 51 | if r.strip() and instruction.check_following(r): 52 | is_following = True 53 | break 54 | 55 | is_following_list.append(is_following) 56 | 57 | return { 58 | "instruction_id_list": inp["instruction_id_list"], 59 | "follow_all_instructions": int(all(is_following_list)), 60 | "follow_instruction_list": is_following_list, 61 | } 62 | 63 | 64 | """ 65 | storage = CachedRODiskStorage( 66 | base_dir="remote-outputs/", 67 | # mode="rw" 68 | ) 69 | 70 | 71 | setup = ExpSetup( 72 | storage=storage, 73 | ).query({"dataset": "google/IFEval"}) 74 | 75 | exp = setup[0] 76 | 77 | 78 | instances = exp.instances() 79 | results = {"scores": []} 80 | for i in tqdm(instances): 81 | 82 | scoresi = [ 83 | test_instruction_following_loose( 84 | i["input"], 85 | oi, 86 | )["follow_all_instructions"] 87 | for oi in i["outputs"] 88 | ] 89 | 90 | results["scores"].append(scoresi) 91 | 92 | # instruction_id_list = data["input"]["instruction_id_list"] 93 | # prompt = data["input"]["prompt"] 94 | import pdb 95 | 96 | pdb.set_trace() 97 | """ 98 | -------------------------------------------------------------------------------- /qalign/utils/ifeval/registry.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Google Research Authors. 
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | """Registry of all instructions."""
17 | from qalign.utils.ifeval import instructions
18 | 
19 | _KEYWORD = "keywords:"
20 | 
21 | _LANGUAGE = "language:"
22 | 
23 | _LENGTH = "length_constraints:"
24 | 
25 | _CONTENT = "detectable_content:"
26 | 
27 | _FORMAT = "detectable_format:"
28 | 
29 | _MULTITURN = "multi-turn:"
30 | 
31 | _COMBINATION = "combination:"
32 | 
33 | _STARTEND = "startend:"
34 | 
35 | _CHANGE_CASES = "change_case:"
36 | 
37 | _PUNCTUATION = "punctuation:"
38 | 
39 | INSTRUCTION_DICT = {
40 |     _KEYWORD + "existence": instructions.KeywordChecker,
41 |     _KEYWORD + "frequency": instructions.KeywordFrequencyChecker,
42 |     # TODO(jeffreyzhou): make a proper set of sentences to choose from
43 |     # _KEYWORD + "key_sentences": instructions.KeySentenceChecker,
44 |     _KEYWORD + "forbidden_words": instructions.ForbiddenWords,
45 |     _KEYWORD + "letter_frequency": instructions.LetterFrequencyChecker,
46 |     _LANGUAGE + "response_language": instructions.ResponseLanguageChecker,
47 |     _LENGTH + "number_sentences": instructions.NumberOfSentences,
48 |     _LENGTH + "number_paragraphs": instructions.ParagraphChecker,
49 |     _LENGTH + "number_words": instructions.NumberOfWords,
50 |     _LENGTH + "nth_paragraph_first_word": instructions.ParagraphFirstWordCheck,
51 |     _CONTENT + "number_placeholders": instructions.PlaceholderChecker,
52 |     _CONTENT + "postscript": instructions.PostscriptChecker,
53 |     _FORMAT + "number_bullet_lists": instructions.BulletListChecker,
54 |     # TODO(jeffreyzhou): Pre-create paragraph or use prompt to replace
55 |     # _CONTENT + "rephrase_paragraph": instructions.RephraseParagraph,
56 |     _FORMAT + "constrained_response": instructions.ConstrainedResponseChecker,
57 |     _FORMAT + "number_highlighted_sections": (instructions.HighlightSectionChecker),
58 |     _FORMAT + "multiple_sections": instructions.SectionChecker,
59 |     # TODO(tianjianlu): Re-enable rephrasing with preprocessing the message.
60 |     # _FORMAT + "rephrase": instructions.RephraseChecker,
61 |     _FORMAT + "json_format": instructions.JsonFormat,
62 |     _FORMAT + "title": instructions.TitleChecker,
63 |     # TODO(tianjianlu): Re-enable with specific prompts.
64 | # _MULTITURN + "constrained_start": instructions.ConstrainedStartChecker, 65 | _COMBINATION + "two_responses": instructions.TwoResponsesChecker, 66 | _COMBINATION + "repeat_prompt": instructions.RepeatPromptThenAnswer, 67 | _STARTEND + "end_checker": instructions.EndChecker, 68 | _CHANGE_CASES + "capital_word_frequency": instructions.CapitalWordFrequencyChecker, 69 | _CHANGE_CASES + "english_capital": instructions.CapitalLettersEnglishChecker, 70 | _CHANGE_CASES + "english_lowercase": instructions.LowercaseLettersEnglishChecker, 71 | _PUNCTUATION + "no_comma": instructions.CommaChecker, 72 | _STARTEND + "quotation": instructions.QuotationChecker, 73 | } 74 | 75 | INSTRUCTION_CONFLICTS = { 76 | _KEYWORD + "existence": {_KEYWORD + "existence"}, 77 | _KEYWORD + "frequency": {_KEYWORD + "frequency"}, 78 | # TODO(jeffreyzhou): make a proper set of sentences to choose from 79 | # _KEYWORD + "key_sentences": instructions.KeySentenceChecker, 80 | _KEYWORD + "forbidden_words": {_KEYWORD + "forbidden_words"}, 81 | _KEYWORD + "letter_frequency": {_KEYWORD + "letter_frequency"}, 82 | _LANGUAGE 83 | + "response_language": { 84 | _LANGUAGE + "response_language", 85 | _FORMAT + "multiple_sections", 86 | _KEYWORD + "existence", 87 | _KEYWORD + "frequency", 88 | _KEYWORD + "forbidden_words", 89 | _STARTEND + "end_checker", 90 | _CHANGE_CASES + "english_capital", 91 | _CHANGE_CASES + "english_lowercase", 92 | }, 93 | _LENGTH + "number_sentences": {_LENGTH + "number_sentences"}, 94 | _LENGTH 95 | + "number_paragraphs": { 96 | _LENGTH + "number_paragraphs", 97 | _LENGTH + "nth_paragraph_first_word", 98 | _LENGTH + "number_sentences", 99 | _LENGTH + "nth_paragraph_first_word", 100 | }, 101 | _LENGTH + "number_words": {_LENGTH + "number_words"}, 102 | _LENGTH 103 | + "nth_paragraph_first_word": { 104 | _LENGTH + "nth_paragraph_first_word", 105 | _LENGTH + "number_paragraphs", 106 | }, 107 | _CONTENT + "number_placeholders": {_CONTENT + "number_placeholders"}, 108 | _CONTENT + "postscript": {_CONTENT + "postscript"}, 109 | _FORMAT + "number_bullet_lists": {_FORMAT + "number_bullet_lists"}, 110 | # TODO(jeffreyzhou): Pre-create paragraph or use prompt to replace 111 | # _CONTENT + "rephrase_paragraph": instructions.RephraseParagraph, 112 | _FORMAT + "constrained_response": set(INSTRUCTION_DICT.keys()), 113 | _FORMAT + "number_highlighted_sections": {_FORMAT + "number_highlighted_sections"}, 114 | _FORMAT 115 | + "multiple_sections": { 116 | _FORMAT + "multiple_sections", 117 | _LANGUAGE + "response_language", 118 | _FORMAT + "number_highlighted_sections", 119 | }, 120 | # TODO(tianjianlu): Re-enable rephrasing with preprocessing the message. 121 | # _FORMAT + "rephrase": instructions.RephraseChecker, 122 | _FORMAT 123 | + "json_format": set(INSTRUCTION_DICT.keys()).difference( 124 | {_KEYWORD + "forbidden_words", _KEYWORD + "existence"} 125 | ), 126 | _FORMAT + "title": {_FORMAT + "title"}, 127 | # TODO(tianjianlu): Re-enable with specific prompts. 
128 | # _MULTITURN + "constrained_start": instructions.ConstrainedStartChecker, 129 | _COMBINATION 130 | + "two_responses": set(INSTRUCTION_DICT.keys()).difference( 131 | { 132 | _KEYWORD + "forbidden_words", 133 | _KEYWORD + "existence", 134 | _LANGUAGE + "response_language", 135 | _FORMAT + "title", 136 | _PUNCTUATION + "no_comma", 137 | } 138 | ), 139 | _COMBINATION 140 | + "repeat_prompt": set(INSTRUCTION_DICT.keys()).difference( 141 | {_KEYWORD + "existence", _FORMAT + "title", _PUNCTUATION + "no_comma"} 142 | ), 143 | _STARTEND + "end_checker": {_STARTEND + "end_checker"}, 144 | _CHANGE_CASES 145 | + "capital_word_frequency": { 146 | _CHANGE_CASES + "capital_word_frequency", 147 | _CHANGE_CASES + "english_lowercase", 148 | _CHANGE_CASES + "english_capital", 149 | }, 150 | _CHANGE_CASES + "english_capital": {_CHANGE_CASES + "english_capital"}, 151 | _CHANGE_CASES 152 | + "english_lowercase": { 153 | _CHANGE_CASES + "english_lowercase", 154 | _CHANGE_CASES + "english_capital", 155 | }, 156 | _PUNCTUATION + "no_comma": {_PUNCTUATION + "no_comma"}, 157 | _STARTEND + "quotation": {_STARTEND + "quotation", _FORMAT + "title"}, 158 | } 159 | 160 | 161 | def conflict_make(conflicts): 162 | """Makes sure if A conflicts with B, B will conflict with A. 163 | 164 | Args: 165 | conflicts: Dictionary of potential conflicts where key is instruction id 166 | and value is set of instruction ids that it conflicts with. 167 | 168 | Returns: 169 | Revised version of the dictionary. All instructions conflict with 170 | themselves. If A conflicts with B, B will conflict with A. 171 | """ 172 | for key in conflicts: 173 | for k in conflicts[key]: 174 | conflicts[k].add(key) 175 | conflicts[key].add(key) 176 | return conflicts 177 | -------------------------------------------------------------------------------- /qalign/utils/math.py: -------------------------------------------------------------------------------- 1 | import re 2 | import re 3 | import signal 4 | import logging 5 | from typing import Optional 6 | 7 | import sympy 8 | from sympy.parsing.latex import parse_latex 9 | 10 | 11 | ## from open instruct 12 | 13 | 14 | def generate_axis(limit, k): 15 | sequence = [] 16 | i = 0 17 | 18 | # Generate powers of 2 up to K 19 | while 2**i < k: 20 | sequence.append(2**i) 21 | i += 1 22 | 23 | # Continue with increments of K 24 | value = k 25 | while value <= limit: 26 | sequence.append(value) 27 | value += k 28 | 29 | return sequence 30 | 31 | 32 | def get_last_number(output): 33 | 34 | output = re.sub(r"(\d),(\d)", r"\1\2", output) 35 | numbers = re.findall(r"[-+]?\d*\.\d+|\d+", output) 36 | if numbers: 37 | return numbers[-1] 38 | else: 39 | return "NaN" 40 | 41 | 42 | def get_last_math(output): 43 | 44 | if "\\boxed" in output: 45 | last_box = last_boxed_only_string(output) 46 | 47 | try: 48 | final = remove_boxed(last_box) 49 | final = normalize_final_answer(final) 50 | except Exception as e: 51 | # print("no number in box") 52 | final = get_last_number(output) 53 | 54 | else: 55 | # print("no number in box") 56 | final = get_last_number(output) 57 | 58 | return final 59 | 60 | 61 | import re 62 | import signal 63 | import logging 64 | from typing import Optional 65 | 66 | import sympy 67 | from sympy.parsing.latex import parse_latex 68 | 69 | 70 | eval_logger = logging.getLogger("math_utils") 71 | 72 | 73 | # from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/minerva_math/utils.py#L187 74 | def last_boxed_only_string(string: str) -> Optional[str]: 75 | idx = 
string.rfind("\\boxed") 76 | if "\\boxed " in string: 77 | return "\\boxed " + string.split("\\boxed ")[-1].split("$")[0] 78 | if idx < 0: 79 | idx = string.rfind("\\fbox") 80 | if idx < 0: 81 | return "NaN" 82 | 83 | i = idx 84 | right_brace_idx = None 85 | num_left_braces_open = 0 86 | while i < len(string): 87 | if string[i] == "{": 88 | num_left_braces_open += 1 89 | if string[i] == "}": 90 | num_left_braces_open -= 1 91 | if num_left_braces_open == 0: 92 | right_brace_idx = i 93 | break 94 | i += 1 95 | 96 | if right_brace_idx is None: 97 | retval = None 98 | else: 99 | retval = string[idx : right_brace_idx + 1] 100 | 101 | return retval 102 | 103 | 104 | def remove_boxed(s: str): 105 | left = "\\boxed{" 106 | try: 107 | assert s[: len(left)] == left 108 | assert s[-1] == "}" 109 | return s[len(left) : -1] 110 | except: 111 | return "" 112 | 113 | 114 | def get_unnormalized_answer(text: str) -> str: 115 | INVALID_ANSWER = "[invalidanswer]" 116 | end_seq = "I hope it is correct." 117 | text += end_seq 118 | match = re.search( 119 | r"Final Answer: The final answer is(.*?). I hope it is correct.", 120 | text, 121 | ) 122 | if match: 123 | return match.group(1).strip() 124 | else: 125 | return INVALID_ANSWER 126 | 127 | 128 | SUBSTITUTIONS = [ 129 | ("an ", ""), 130 | ("a ", ""), 131 | (".$", "$"), 132 | ("\\$", ""), 133 | (r"\ ", ""), 134 | (" ", ""), 135 | ("\\dfrac", "\\frac"), 136 | ("mbox", "text"), 137 | (",\\text{and}", ","), 138 | ("\\text{and}", ","), 139 | ("\\text{m}", "\\text{}"), 140 | ("π", "\\pi"), 141 | ] 142 | REMOVED_EXPRESSIONS = [ 143 | "square", 144 | "ways", 145 | "integers", 146 | "dollars", 147 | "mph", 148 | "inches", 149 | "\\left", 150 | "\\big", 151 | "\\Big", 152 | "\\Bigg", 153 | "\\bigg", 154 | "\\right", 155 | "ft", 156 | "hours", 157 | "km", 158 | "units", 159 | "\\ldots", 160 | "sue", 161 | "points", 162 | "feet", 163 | "minutes", 164 | "digits", 165 | "cents", 166 | "degrees", 167 | "cm", 168 | "gm", 169 | "pounds", 170 | "meters", 171 | "meals", 172 | "edges", 173 | "students", 174 | "childrentickets", 175 | "multiples", 176 | "\\text{s}", 177 | "\\text{.}", 178 | "\\text{\ns}", 179 | "\\text{}^2", 180 | "\\text{}^3", 181 | "\\text{\n}", 182 | "\\text{}", 183 | r"\mathrm{th}", 184 | r"^\circ", 185 | r"^{\circ}", 186 | r"\;", 187 | r",\!", 188 | "{,}", 189 | '"', 190 | "\\dots", 191 | "\%", 192 | ] 193 | 194 | 195 | def normalize_final_answer(final_answer: str) -> str: 196 | """ 197 | Normalize a final answer to a quantitative reasoning question. 198 | 199 | Copied character for character from appendix D of Lewkowycz et al. (2022) 200 | """ 201 | final_answer = final_answer.split("=")[-1] 202 | 203 | for before, after in SUBSTITUTIONS: 204 | final_answer = final_answer.replace(before, after) 205 | for expr in REMOVED_EXPRESSIONS: 206 | final_answer = final_answer.replace(expr, "") 207 | 208 | # Extract answer that is in LaTeX math, is bold, 209 | # is surrounded by a box, etc. 
210 | final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer) 211 | final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer) 212 | final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer) 213 | final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer) 214 | final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer) 215 | 216 | # Normalize shorthand TeX: 217 | # \fracab -> \frac{a}{b} 218 | # \frac{abc}{bef} -> \frac{abc}{bef} 219 | # \fracabc -> \frac{a}{b}c 220 | # \sqrta -> \sqrt{a} 221 | # \sqrtab -> sqrt{a}b 222 | final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer) 223 | final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer) 224 | final_answer = final_answer.replace("$", "") 225 | 226 | # Normalize 100,000 -> 100000 227 | if final_answer.replace(",", "").isdigit(): 228 | final_answer = final_answer.replace(",", "") 229 | 230 | return final_answer 231 | 232 | 233 | class timeout: 234 | def __init__(self, seconds=1, error_message="Timeout"): 235 | self.seconds = seconds 236 | self.error_message = error_message 237 | 238 | def handle_timeout(self, signum, frame): 239 | raise TimeoutError(self.error_message) 240 | 241 | def __enter__(self): 242 | signal.signal(signal.SIGALRM, self.handle_timeout) 243 | signal.alarm(self.seconds) 244 | 245 | def __exit__(self, type, value, traceback): 246 | signal.alarm(0) 247 | 248 | 249 | def is_equiv(x1: str, x2: str) -> bool: 250 | """ 251 | x1 and x2 are normalized latex string 252 | """ 253 | try: 254 | with timeout(seconds=5): 255 | try: 256 | parsed_x1 = parse_latex(x1) 257 | parsed_x2 = parse_latex(x2) 258 | except ( 259 | sympy.parsing.latex.errors.LaTeXParsingError, 260 | sympy.SympifyError, 261 | TypeError, 262 | ): 263 | eval_logger.debug(f"couldn't parse one of {x1} or {x2}") 264 | return False 265 | 266 | try: 267 | diff = parsed_x1 - parsed_x2 268 | except TypeError: 269 | eval_logger.debug(f"couldn't subtract {x1} and {x2}") 270 | return False 271 | 272 | try: 273 | if sympy.simplify(diff) == 0: 274 | return True 275 | else: 276 | return False 277 | except ValueError: 278 | eval_logger.debug( 279 | f"Had some trouble simplifying when comparing {x1} and {x2}" 280 | ) 281 | except TimeoutError: 282 | eval_logger.debug(f"Timed out comparing {x1} and {x2}") 283 | return False 284 | except ImportError as e: 285 | eval_logger.error(e) 286 | raise 287 | except Exception as e: 288 | eval_logger.debug(f"Failed comparing {x1} and {x2} with {e}") 289 | return False 290 | 291 | 292 | def fix_fracs(string): 293 | substrs = string.split("\\frac") 294 | new_str = substrs[0] 295 | if len(substrs) > 1: 296 | substrs = substrs[1:] 297 | for substr in substrs: 298 | new_str += "\\frac" 299 | if substr[0] == "{": 300 | new_str += substr 301 | else: 302 | try: 303 | assert len(substr) >= 2 304 | except AssertionError: 305 | return string 306 | a = substr[0] 307 | b = substr[1] 308 | if b != "{": 309 | if len(substr) > 2: 310 | post_substr = substr[2:] 311 | new_str += "{" + a + "}{" + b + "}" + post_substr 312 | else: 313 | new_str += "{" + a + "}{" + b + "}" 314 | else: 315 | if len(substr) > 2: 316 | post_substr = substr[2:] 317 | new_str += "{" + a + "}" + b + post_substr 318 | else: 319 | new_str += "{" + a + "}" + b 320 | string = new_str 321 | return string 322 | 323 | 324 | def fix_a_slash_b(string): 325 | if len(string.split("/")) != 2: 326 | return string 327 | a = string.split("/")[0] 328 | b = string.split("/")[1] 329 | try: 330 | a = int(a) 331 | b = int(b) 
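        # Only plain "<int>/<int>" strings are rewritten (e.g. "3/4" -> "\frac{3}{4}");
        # if the reassembled string does not match the input, the assert fails and the
        # original string is returned unchanged.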
332 | assert string == "{}/{}".format(a, b) 333 | new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" 334 | return new_string 335 | except AssertionError: 336 | return string 337 | 338 | 339 | def remove_right_units(string): 340 | # "\\text{ " only ever occurs (at least in the val set) when describing units 341 | if "\\text{ " in string: 342 | splits = string.split("\\text{ ") 343 | assert len(splits) == 2 344 | return splits[0] 345 | else: 346 | return string 347 | 348 | 349 | def fix_sqrt(string): 350 | if "\\sqrt" not in string: 351 | return string 352 | splits = string.split("\\sqrt") 353 | new_string = splits[0] 354 | for split in splits[1:]: 355 | if split[0] != "{": 356 | a = split[0] 357 | new_substr = "\\sqrt{" + a + "}" + split[1:] 358 | else: 359 | new_substr = "\\sqrt" + split 360 | new_string += new_substr 361 | return new_string 362 | 363 | 364 | def strip_string(string): 365 | # linebreaks 366 | string = string.replace("\n", "") 367 | 368 | # remove inverse spaces 369 | string = string.replace("\\!", "") 370 | 371 | # replace \\ with \ 372 | string = string.replace("\\\\", "\\") 373 | 374 | # replace tfrac and dfrac with frac 375 | string = string.replace("tfrac", "frac") 376 | string = string.replace("dfrac", "frac") 377 | 378 | # remove \left and \right 379 | string = string.replace("\\left", "") 380 | string = string.replace("\\right", "") 381 | 382 | # Remove circ (degrees) 383 | string = string.replace("^{\\circ}", "") 384 | string = string.replace("^\\circ", "") 385 | 386 | # remove dollar signs 387 | string = string.replace("\\$", "") 388 | 389 | # remove units (on the right) 390 | string = remove_right_units(string) 391 | 392 | # remove percentage 393 | string = string.replace("\\%", "") 394 | string = string.replace("\%", "") # noqa: W605 395 | 396 | # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string 397 | string = string.replace(" .", " 0.") 398 | string = string.replace("{.", "{0.") 399 | # if empty, return empty string 400 | if len(string) == 0: 401 | return string 402 | if string[0] == ".": 403 | string = "0" + string 404 | 405 | # to consider: get rid of e.g. "k = " or "q = " at beginning 406 | if len(string.split("=")) == 2: 407 | if len(string.split("=")[0]) <= 2: 408 | string = string.split("=")[1] 409 | 410 | # fix sqrt3 --> sqrt{3} 411 | string = fix_sqrt(string) 412 | 413 | # remove spaces 414 | string = string.replace(" ", "") 415 | 416 | # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). 
Also does a/b --> \\frac{a}{b} 417 | string = fix_fracs(string) 418 | 419 | # manually change 0.5 --> \frac{1}{2} 420 | if string == "0.5": 421 | string = "\\frac{1}{2}" 422 | 423 | # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y 424 | string = fix_a_slash_b(string) 425 | 426 | return string 427 | 428 | 429 | def hendrycks_is_equiv(str1, str2, verbose=False): 430 | if str1 is None and str2 is None: 431 | print("WARNING: Both None") 432 | return True 433 | if str1 is None or str2 is None: 434 | return False 435 | 436 | try: 437 | ss1 = strip_string(str1) 438 | ss2 = strip_string(str2) 439 | if verbose: 440 | print(ss1, ss2) 441 | return ss1 == ss2 442 | except Exception: 443 | return str1 == str2 444 | 445 | 446 | def get_last_option(text): 447 | pattern = r"\b[A-J]\b(?!.*\b[A-J]\b)" 448 | match = re.search(pattern, text, re.DOTALL) 449 | if match: 450 | return match.group(0) 451 | else: 452 | return None 453 | 454 | 455 | def verify_ifeval_sample(model_output, constraint): 456 | model_output = model_output.split("<|assistant|>\n")[-1].strip() 457 | # TODO: just pass in final answer. this should be fine for other evals too. 458 | answer = model_output.split("<|assistant|>\n")[-1].strip() 459 | if isinstance(constraint, str): 460 | constraint = json.loads(constraint) 461 | if "func_name" not in constraint: 462 | print("WARNING: constraint missing func_name") 463 | print(constraint) 464 | return False 465 | # first, parse out the constraint string. 466 | func_name = constraint.pop("func_name") 467 | # get the function 468 | func = IF_FUNCTIONS_MAP[func_name] 469 | # now, run the function 470 | # pop out any none args 471 | non_none_args = {k: v for k, v in constraint.items() if v is not None} 472 | # sometimes we have extra args, sometimes not. 473 | if len(constraint) == 0: 474 | return func(model_output) 475 | return func(answer, **non_none_args) 476 | -------------------------------------------------------------------------------- /qalign/utils/mbr.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from collections import Counter 3 | 4 | from typing import * 5 | 6 | from multiprocessing import Pool 7 | import numpy as np 8 | import os 9 | 10 | from tqdm import tqdm 11 | from transformers import AutoTokenizer 12 | 13 | from scipy.sparse import csr_matrix 14 | 15 | 16 | ## qalign 17 | from qalign.utils.math import generate_axis 18 | from quest.utils.list import ( 19 | chunked, 20 | ) 21 | 22 | 23 | class UCS: 24 | 25 | def __init__( 26 | self, 27 | process_num=None, 28 | vocab_size=1000, 29 | ): 30 | 31 | self.process_num = os.cpu_count() if process_num is None else process_num 32 | self.vocab_size = vocab_size 33 | self.func = ucs_sparse 34 | 35 | def _process_chunk(self, chunk): 36 | """ 37 | Process a chunk of data - this is needed for multiprocessing. 
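        Here a chunk is the list of tokenized generations produced for a single
        prompt; the return value is that prompt's pairwise score matrix.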
38 | 39 | Args: 40 | chunk: A subset of samples to process 41 | 42 | Returns: 43 | UCS matrix for the chunk 44 | """ 45 | return self.func(chunk, vocab_size=self.vocab_size) 46 | 47 | def get_score(self, samples, compute_in_parallel=True): 48 | 49 | if compute_in_parallel: 50 | with Pool(self.process_num) as executor: 51 | results = list( 52 | tqdm( 53 | executor.map( 54 | self._process_chunk, 55 | samples, 56 | ), 57 | total=len(samples), 58 | desc="Processing chunks", 59 | ) 60 | ) 61 | 62 | mat = np.array(results) 63 | else: 64 | mat = np.array( 65 | [ 66 | self.func(generations, vocab_size=self.vocab_size) 67 | for generations in samples 68 | ] 69 | ) 70 | 71 | return mat 72 | 73 | 74 | class FastROUGE(UCS): 75 | def __init__( 76 | self, 77 | process_num=None, 78 | vocab_size=1000, 79 | ): 80 | 81 | self.process_num = os.cpu_count() if process_num is None else process_num 82 | self.vocab_size = vocab_size 83 | self.func = rouge1_sparse 84 | 85 | 86 | def join_instances_accepts(i, is_quest=True): 87 | 88 | if is_quest: 89 | accepted_outputs = [] 90 | accepted_index = [] 91 | last_accepted = None 92 | last_accept_index = None 93 | for j, output in enumerate(i["outputs"]): 94 | if output["accept"]: 95 | accepted_outputs.append(output) 96 | accepted_index.append(j) 97 | last_accepted = output 98 | last_accept_index = j 99 | elif last_accepted is not None: 100 | accepted_outputs.append(last_accepted) 101 | accepted_index.append(last_accept_index) 102 | 103 | else: 104 | accepted_outputs = i["outputs"] 105 | accepted_index = list(range(len(i["outputs"]))) 106 | 107 | # print(len(outputs)) 108 | return {"input": i["input"], "outputs": accepted_outputs, "index": accepted_index} 109 | 110 | 111 | def get_mats(completions, compute_in_parallel=True, metric="bleu", vocab_size=1000): 112 | 113 | if "ucs" in metric: 114 | s = UCS(vocab_size=vocab_size) 115 | elif "rouge" in metric: 116 | s = FastROUGE(vocab_size=vocab_size) # PairwiseRouge(metric=metric) 117 | else: 118 | raise ValueError("metric must be bleu or rouge") 119 | 120 | mats = s.get_score(completions, compute_in_parallel=compute_in_parallel) 121 | 122 | return mats 123 | 124 | 125 | def ucs_sparse(generations, vocab_size): 126 | """ 127 | Compute Unigram Consistency Score (UCS) matrix for lists of integers using 128 | a vectorized approach for better efficiency. 
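    Concretely, each generation is reduced to its set of unique token ids, and
    UCS[i, j] is the size of the intersection of the unique-token sets of
    generations i and j, divided by vocab_size.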
129 | 130 | Args: 131 | int_lists: List of integer lists where each list represents a sequence of elements 132 | vocab_size: Size of the vocabulary |V| to use in the UCS calculation 133 | 134 | Returns: 135 | UCS matrix where UCS[i, j] represents the consistency score between 136 | lists i and j 137 | """ 138 | n = len(generations) 139 | 140 | # Check if provided vocab_size is valid 141 | if vocab_size <= 0: 142 | raise ValueError("Vocabulary size must be greater than 0") 143 | 144 | # Create a binary matrix of shape (n, vocab_size) where binary_matrix[i, j] = 1 145 | # if element j is in list i, 0 otherwise 146 | 147 | # Method 1: Using sparse matrices (more efficient for large vocab_size) 148 | rows = [] 149 | cols = [] 150 | data = [] 151 | 152 | for i, int_list in enumerate(generations): 153 | for elem in set(int_list): # Use set to count each unique element once 154 | if 0 <= elem < vocab_size: # Ensure the element is in range 155 | rows.append(i) 156 | cols.append(elem) 157 | data.append(1) 158 | 159 | # Create sparse matrix 160 | binary_matrix = csr_matrix((data, (rows, cols)), shape=(n, vocab_size)) 161 | 162 | # Compute UCS matrix using matrix multiplication 163 | # (binary_matrix @ binary_matrix.T) gives the dot product of each pair of rows 164 | dot_products = binary_matrix @ binary_matrix.T 165 | 166 | # Convert to dense array and divide by vocab_size 167 | ucs_matrix = dot_products.toarray() / vocab_size 168 | 169 | return ucs_matrix 170 | 171 | 172 | def rouge1_sparse(generations, vocab_size): 173 | """ 174 | Compute Unigram Consistency Score (UCS) matrix for lists of integers using 175 | a vectorized approach for better efficiency. 176 | 177 | Args: 178 | int_lists: List of integer lists where each list represents a sequence of elements 179 | vocab_size: Size of the vocabulary |V| to use in the UCS calculation 180 | 181 | Returns: 182 | UCS matrix where UCS[i, j] represents the consistency score between 183 | lists i and j 184 | """ 185 | n = len(generations) 186 | 187 | # Check if provided vocab_size is valid 188 | if vocab_size <= 0: 189 | raise ValueError("Vocabulary size must be greater than 0") 190 | 191 | # Create a binary matrix of shape (n, vocab_size) where binary_matrix[i, j] = 1 192 | # if element j is in list i, 0 otherwise 193 | 194 | # Method 1: Using sparse matrices (more efficient for large vocab_size) 195 | rows = [] 196 | cols = [] 197 | data = [] 198 | # Store total word counts for each generation (for normalization) 199 | word_counts = np.zeros(n) 200 | 201 | for i, int_list in enumerate(generations): 202 | # Count total words in this generation (including duplicates) 203 | word_counts[i] = len(int_list) 204 | 205 | # Count frequency of each word in this generation 206 | # word_freq = {} 207 | # for word in int_list: 208 | # if 0 <= word < vocab_size: # Ensure word is in valid range 209 | # word_freq[word] = word_freq.get(word, 0) + 1 210 | 211 | word_freq = Counter(int_list) 212 | 213 | # Add to sparse matrix data 214 | for word, freq in word_freq.items(): 215 | if 0 <= word < vocab_size: 216 | rows.append(i) 217 | cols.append(word) 218 | data.append(freq) # S 219 | 220 | count_matrix = csr_matrix((data, (rows, cols)), shape=(n, vocab_size)) 221 | 222 | overlap_matrix = np.zeros((n, n), dtype=np.float32) 223 | 224 | # Convert to binary matrix first (1 where count > 0) 225 | binary_matrix = count_matrix.copy() 226 | binary_matrix.data = np.ones_like(binary_matrix.data) 227 | 228 | # Get common token indicators using matrix multiplication 229 | 
common_tokens = binary_matrix @ binary_matrix.T 230 | 231 | # For each pair of rows, calculate the overlap 232 | for i in range(n): 233 | row_i = count_matrix.getrow(i) 234 | 235 | for j in range(i, n): 236 | # Only process if there are any common tokens 237 | if common_tokens[i, j] > 0: 238 | row_j = count_matrix.getrow(j) 239 | 240 | # Get common indices 241 | i_indices = row_i.indices 242 | i_data = row_i.data 243 | j_indices = row_j.indices 244 | j_data = row_j.data 245 | 246 | # Find the intersection of indices 247 | # This is more efficient than multiplying sparse matrices 248 | i_dict = dict(zip(i_indices, i_data)) 249 | j_dict = dict(zip(j_indices, j_data)) 250 | 251 | common_indices = set(i_indices).intersection(set(j_indices)) 252 | 253 | # Calculate the sum of minimums 254 | overlap = sum(min(i_dict[idx], j_dict[idx]) for idx in common_indices) 255 | 256 | overlap_matrix[i, j] = overlap 257 | if i != j: 258 | overlap_matrix[j, i] = overlap 259 | 260 | # Calculate ROUGE-1 recall (overlap / words in reference) 261 | # We can use broadcasting to divide by word_counts 262 | recall_matrix = np.zeros((n, n)) 263 | nonzero_counts = word_counts > 0 264 | if np.any(nonzero_counts): 265 | recall_matrix[nonzero_counts, :] = ( 266 | overlap_matrix[nonzero_counts, :] / word_counts[nonzero_counts, np.newaxis] 267 | ) 268 | 269 | # Calculate ROUGE-1 precision (overlap / words in candidate) 270 | precision_matrix = np.zeros((n, n)) 271 | if np.any(nonzero_counts): 272 | precision_matrix[:, nonzero_counts] = ( 273 | overlap_matrix[:, nonzero_counts] / word_counts[np.newaxis, nonzero_counts] 274 | ) 275 | 276 | # Calculate ROUGE-1 F1 score 277 | f1_matrix = np.zeros((n, n)) 278 | nonzero = (precision_matrix + recall_matrix) > 0 279 | f1_matrix[nonzero] = ( 280 | 2 281 | * (precision_matrix[nonzero] * recall_matrix[nonzero]) 282 | / (precision_matrix[nonzero] + recall_matrix[nonzero]) 283 | ) 284 | 285 | return { 286 | "recall": recall_matrix, 287 | "precision": precision_matrix, 288 | "f1": f1_matrix, 289 | "overlap": overlap_matrix, 290 | "word_counts": word_counts, 291 | }["f1"] 292 | 293 | 294 | def mbr_mat_progression( 295 | exp, 296 | compute_in_parallel=True, 297 | k=1, 298 | n=None, 299 | max_steps=None, 300 | metric="bleu", 301 | ): 302 | 303 | if max_steps is None: 304 | max_steps = exp.get("steps") 305 | 306 | # total_instances = exp.instances(lazy_iterable=True) 307 | 308 | # if n is not None: 309 | # total_instances = total_instances[:n] 310 | 311 | # if "quest" not in exp.get("variant"): 312 | tokenizer = AutoTokenizer.from_pretrained(exp.get("model_path")) 313 | 314 | mats = [] 315 | 316 | repeat_inds = [] 317 | 318 | for instances in tqdm(chunked(exp.instances(lazy_iterable=True), 32)): 319 | # instances = [instance] 320 | 321 | instances = [ 322 | join_instances_accepts(i, is_quest="quest" in exp.get("variant")) 323 | for i in instances 324 | ] 325 | 326 | repeat_inds.extend([i["index"] for i in instances]) 327 | 328 | if "quest" not in exp.get("variant"): 329 | 330 | texts = [ 331 | o["text"] 332 | for instance in instances 333 | for o in instance["outputs"][:max_steps] # Already uniform length 334 | ] 335 | 336 | # Batch tokenize everything at once 337 | batch_ids = tokenizer(texts)["input_ids"] 338 | 339 | # Reshape into [num_instances, max_steps, ...] 
using fixed chunk size 340 | completions = [ 341 | batch_ids[i : i + max_steps] 342 | for i in range(0, len(batch_ids), max_steps) 343 | ] 344 | 345 | else: 346 | completions = [ 347 | [o["completion"] for o in instance["outputs"][:max_steps]] 348 | for instance in instances 349 | ] 350 | 351 | mat = get_mats( 352 | completions, 353 | compute_in_parallel=compute_in_parallel, 354 | metric=metric, 355 | vocab_size=tokenizer.vocab_size, 356 | ) 357 | 358 | mats.append(mat) 359 | 360 | del instances 361 | del completions 362 | 363 | mat = np.concatenate(mats, axis=0) 364 | 365 | return mat, repeat_inds 366 | 367 | 368 | def weighted_mbr_pick_progression( 369 | exp, 370 | reward_key, 371 | compute_in_parallel=True, 372 | k=1, 373 | n=None, 374 | max_steps=None, 375 | metric="bleu", 376 | ): 377 | 378 | if max_steps is None: 379 | max_steps = exp.get("steps") 380 | 381 | mat, repeat_inds = mbr_mat_progression( 382 | exp, 383 | compute_in_parallel=compute_in_parallel, 384 | k=k, 385 | n=n, 386 | max_steps=max_steps, 387 | metric=metric, 388 | ) 389 | 390 | rewards = exp.get_eval(reward_key) 391 | 392 | repeat_rewards = np.array( 393 | [ 394 | [r["scores"][i] for i in inds][:max_steps] 395 | for inds, r in zip(repeat_inds, rewards) 396 | ] 397 | ) 398 | 399 | repeat_rewards = np.expand_dims(repeat_rewards, axis=1) 400 | 401 | mat *= repeat_rewards 402 | 403 | axis = generate_axis(max_steps + 1, k) 404 | 405 | return pick_mat(mat, axis, repeat_inds) 406 | 407 | 408 | def mbr_pick_progression( 409 | exp, 410 | compute_in_parallel=True, 411 | k=1, 412 | n=None, 413 | max_steps=None, 414 | metric="bleu", 415 | ): 416 | 417 | if max_steps is None: 418 | max_steps = exp.get("steps") 419 | 420 | mat, repeat_inds = mbr_mat_progression( 421 | exp, 422 | compute_in_parallel=compute_in_parallel, 423 | k=k, 424 | n=n, 425 | max_steps=max_steps, 426 | metric=metric, 427 | ) 428 | 429 | axis = generate_axis(max_steps + 1, k) 430 | 431 | # 1,2,4,8,16,32,64,96,128,160,192,224,256 432 | 433 | return pick_mat(mat, axis, repeat_inds) 434 | 435 | 436 | def pick_mat( 437 | mat, 438 | axis, 439 | repeat_inds, 440 | ): 441 | 442 | values = [] 443 | 444 | preds = {} 445 | for ni in axis: 446 | pick_batch = np.argmax(mat[:, :ni, :ni].mean(axis=-1), axis=-1) 447 | # acc = [] 448 | 449 | preds[ni] = [] 450 | for i, pick in enumerate( 451 | pick_batch, 452 | ): 453 | # true_answer = extract_func(instance["input"]["answer"]) 454 | pred_index = repeat_inds[i][pick] 455 | preds[ni].append(pred_index) 456 | # acc.append(correct) 457 | 458 | return {"preds": preds, "axis": axis} 459 | -------------------------------------------------------------------------------- /qalign/utils/pred.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | import re 3 | from functools import partial 4 | from tqdm import tqdm 5 | from multiprocessing import Pool 6 | import numpy as np 7 | 8 | 9 | from typing import * 10 | 11 | from transformers import AutoTokenizer 12 | 13 | ## quest 14 | from quest.utils.list import chunked 15 | 16 | ## qalign 17 | from qalign.utils.mbr import mbr_pick_progression, weighted_mbr_pick_progression 18 | from qalign.utils.math import ( 19 | get_last_number, 20 | get_last_math, 21 | get_last_option, 22 | generate_axis, 23 | ) 24 | 25 | ## expkit 26 | from expkit import ( 27 | Exp, 28 | Evalutor, 29 | ) 30 | 31 | 32 | def get_strategy( 33 | strategy: str, 34 | key: str, 35 | p=None, 36 | beta=1.0, 37 | c=1.0, 38 | gaps=64, 39 | exp_rate=False, 40 | 
**kwargs, 41 | ) -> Evalutor: 42 | 43 | if strategy == "voting": 44 | pick = VotingPick(gaps=gaps, exp_rate=exp_rate, multiprocessing=p, **kwargs) 45 | 46 | elif strategy == "bon": 47 | pick = BonPick( 48 | key=key, gaps=gaps, exp_rate=exp_rate, multiprocessing=p, **kwargs 49 | ) 50 | 51 | # setup = setup.filter(lambda x: x.has_eval(key)) 52 | 53 | elif strategy == "weighted-voting": 54 | pick = WeightedVotingPick( 55 | key=key, 56 | gaps=gaps, 57 | exp_rate=exp_rate, 58 | multiprocessing=p, 59 | beta=beta, 60 | **kwargs, 61 | ) 62 | # setup = setup.filter(lambda x: x.has_eval(key)) 63 | elif strategy == "last": 64 | pick = LastPick(gaps=gaps, exp_rate=exp_rate, multiprocessing=p, **kwargs) 65 | elif strategy == "mbr": 66 | pick = MBRPick(gaps=gaps, **kwargs) 67 | elif strategy == "wmbr": 68 | pick = WMBRPick(gaps=gaps, key=key, exp_rate=exp_rate, **kwargs) 69 | elif strategy == "bonresample": 70 | pick = BonResamplePick( 71 | key=key, gaps=gaps, exp_rate=exp_rate, multiprocessing=p, **kwargs 72 | ) 73 | else: 74 | raise ValueError("Invalid strategy") 75 | 76 | return pick 77 | 78 | 79 | def get_param_count(exp): 80 | # example : model_path = "meta-llama/Llama-3.1-8B-Instruct" 81 | # or model_path = "allenai/olmo-2-2b-it" 82 | # I need to look for B or b and then get the number before it 83 | 84 | # get the number before B or b 85 | # get the number before B or b 86 | 87 | model_path = exp.meta["model_path"] 88 | match = re.search(r"(\d+)[Bb]", model_path) 89 | if match: 90 | return int(match.group(1)) 91 | else: 92 | raise ValueError("No match found") 93 | 94 | 95 | def get_flops_by_scores(exp, key="", strategy="voting"): 96 | 97 | with Pool() as p: 98 | eval_result = get_strategy( 99 | strategy=strategy, 100 | p=p, 101 | exp_rate=True, 102 | gaps=2, 103 | key=key, 104 | ).eval(exp) 105 | 106 | axis = eval_result[0]["axis"] 107 | 108 | tokenizer = AutoTokenizer.from_pretrained(exp.meta["model_path"]) 109 | token_counts = [] 110 | for instance in exp.instances(): 111 | 112 | if exp.meta["variant"] == "ancestral": 113 | 114 | ids = tokenizer([o["text"] for o in instance["outputs"]], padding=False) 115 | 116 | token_counts.append(list(map(len, ids["input_ids"]))) 117 | else: 118 | 119 | token_counts.append( 120 | [len(o["completion"][o["index"] :]) for o in instance["outputs"]] 121 | ) 122 | # all_criteria.append(x) 123 | token_counts = np.array(token_counts) 124 | 125 | # sample a few points from axis ... 126 | 127 | D_test = [] 128 | for i in axis: 129 | D_test.append(np.sum(token_counts[:, :i], axis=1).mean(0)) 130 | 131 | ## billion in scientific notation 132 | ## 133 | IFLOPS = [ 134 | 2 * 2 * float(get_param_count(exp)) * v for v in D_test 135 | ] # one extra 2 for reward. 136 | return { 137 | "IFLOPS": IFLOPS, 138 | "D_test": D_test, 139 | "scores": eval_result[0]["scores"], 140 | } 141 | 142 | 143 | def get_flops_prev(exp, axis): 144 | 145 | tokenizer = AutoTokenizer.from_pretrained(exp.meta["model_path"]) 146 | token_counts = [] 147 | for instance in exp.instances(): 148 | 149 | if exp.meta["variant"] == "ancestral": 150 | 151 | ids = tokenizer([o["text"] for o in instance["outputs"]], padding=False) 152 | 153 | token_counts.append(list(map(len, ids["input_ids"]))) 154 | else: 155 | 156 | token_counts.append( 157 | [len(o["completion"][o["index"] :]) for o in instance["outputs"]] 158 | ) 159 | # all_criteria.append(x) 160 | token_counts = np.array(token_counts) 161 | 162 | # sample a few points from axis ... 
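    # D_test[i] is the mean (over prompts) of the total number of generated tokens in the
    # first axis[i] samples; FLOPs are then estimated with the usual 2 * N_params * N_tokens
    # forward-pass rule. Because get_param_count returns the parameter count in billions,
    # the IFLOPS values below are expressed in units of 1e9 FLOPs.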
163 | 164 | D_test = [] 165 | for i in axis: 166 | D_test.append(np.sum(token_counts[:, :i], axis=1).mean(0)) 167 | 168 | ## billion in scientific notation 169 | ## 170 | IFLOPS = [ 171 | 2 * float(get_param_count(exp)) * v for v in D_test 172 | ] # one extra 2 for reward. 173 | return { 174 | "IFLOPS": IFLOPS, 175 | "D_test": D_test, 176 | } 177 | 178 | 179 | def real_avg_token_counts(exp, sweep, k=64): 180 | tokenizer = AutoTokenizer.from_pretrained(exp.meta["model_path"]) 181 | token_counts = [] 182 | 183 | instances = exp.instances(lazy_iterable=True) 184 | 185 | for i, instance in enumerate( 186 | tqdm(instances, desc="Processing Instances", total=k + 1) 187 | ): 188 | 189 | if exp.meta["variant"] == "ancestral": 190 | 191 | ids = tokenizer([o["text"] for o in instance["outputs"]], padding=False) 192 | 193 | token_counts.append(list(map(len, ids["input_ids"]))) 194 | else: 195 | 196 | token_counts.append( 197 | [len(o["completion"][o["index"] :]) for o in instance["outputs"]] 198 | ) 199 | 200 | if i > k: 201 | break 202 | 203 | token_counts = np.array(token_counts) 204 | 205 | avg_token_counts = np.mean(token_counts, axis=0) 206 | 207 | if exp.meta["variant"] == "ancestral": 208 | if ("bon" in sweep) or ("weighted" in sweep) or ("wmbr" in sweep): 209 | avg_token_counts *= 2 210 | 211 | else: # questrlhf 212 | avg_token_counts *= 2 213 | 214 | return avg_token_counts 215 | 216 | 217 | def fake_avg_token_counts(exp, sweep, d=500): 218 | avg_token_counts = np.array([d] * exp.get("steps")) # * n 219 | 220 | if "DPO" in exp.meta["model_path"]: 221 | avg_token_counts = avg_token_counts * 1.5 222 | 223 | else: 224 | 225 | # sample a few points from axis ... 226 | if exp.meta["variant"] == "ancestral": 227 | if ("bon" in sweep) or ("weighted" in sweep) or ("wmbr" in sweep): 228 | avg_token_counts *= 2 229 | 230 | else: # questrlhf 231 | avg_token_counts[0] *= 2 232 | # avg_token_counts[1:] = avg_token_counts[1:] / 233 | 234 | return avg_token_counts 235 | 236 | 237 | def get_flops(exp, axis, sweep, n=None, fake=True): 238 | 239 | # tokenizer = AutoTokenizer.from_pretrained(exp.meta["model_path"]) 240 | 241 | if n is None: 242 | n = exp.meta["n"] 243 | 244 | if fake: 245 | avg_token_counts = fake_avg_token_counts(exp, sweep=sweep) 246 | else: 247 | avg_token_counts = real_avg_token_counts(exp, sweep=sweep) 248 | 249 | D_test = [] 250 | for i in axis: 251 | D_test.append(np.sum(avg_token_counts[:i])) 252 | 253 | ## billion in scientific notation 254 | ## 255 | IFLOPS = [ 256 | 2 * float(get_param_count(exp)) * v for v in D_test 257 | ] # one extra 2 for reward. 
258 | return { 259 | "IFLOPS": IFLOPS, 260 | "D_test": D_test, 261 | } 262 | 263 | 264 | def softmax(x, temp=1.0): 265 | x = np.array(x) 266 | e_x = np.exp(x / temp) 267 | return e_x / e_x.sum() 268 | 269 | 270 | def log_softmax(x, temp=1.0): 271 | x = np.array(x) 272 | x = x / temp 273 | x_max = np.max(x) 274 | log_sum_exp = np.log(np.sum(np.exp(x - x_max))) + x_max 275 | return x - log_sum_exp 276 | 277 | 278 | def is_weighted_mode_correct( 279 | instance_and_score, n=512, beta=1.0, extract_func=get_last_number 280 | ): 281 | instance, scores = instance_and_score 282 | 283 | r = scores[:n] 284 | p = softmax(r, temp=beta) 285 | answer = extract_func(instance["input"]["answer"]) 286 | outputs = instance["outputs"][:n] 287 | 288 | responses = [ 289 | (extract_func(output["text"]), float(pi)) for output, pi in zip(outputs, p) 290 | ] 291 | 292 | response_dict = {} 293 | for response, prob in responses: 294 | if response in response_dict: 295 | response_dict[response] += prob 296 | else: 297 | response_dict[response] = prob 298 | 299 | if answer in response_dict: 300 | correct_estimate = response_dict[answer] 301 | else: 302 | correct_estimate = 0 303 | 304 | max_response = max(response_dict, key=response_dict.get) 305 | 306 | return int(max_response == answer), correct_estimate 307 | 308 | 309 | def get_last(instance, n=512, extract_func=get_last_number): 310 | 311 | outputs = instance["outputs"][:n] 312 | responses = [extract_func(output["text"]) for output in outputs] 313 | answer = extract_func(instance["input"]["answer"]) 314 | 315 | return int(responses[-1] == answer), 0.0 316 | 317 | 318 | def is_mode_correct(instance, n=512, burnin=0, extract_func=get_last_number): 319 | 320 | outputs = instance["outputs"][burnin:n] 321 | 322 | answer = extract_func(instance["input"]["answer"]) 323 | 324 | responses = [extract_func(output["text"]) for output in outputs] 325 | 326 | # compute the most common element in accepted 327 | 328 | c = Counter(responses) 329 | 330 | if answer in c: 331 | correct_estimate = c[answer] / len(responses) 332 | 333 | else: 334 | correct_estimate = 0 335 | 336 | if len(c) == 0: 337 | most_common = "NaN" 338 | else: 339 | most_common = c.most_common(1)[0][0] 340 | 341 | return int(most_common == answer), correct_estimate 342 | 343 | 344 | def is_max_correct(instance_and_score, n=512): 345 | scores, gt = instance_and_score 346 | max_index = np.argmax(scores["scores"][:n]) 347 | 348 | return gt["scores"][max_index] 349 | 350 | 351 | # max_response = extract_func(instance["outputs"][max_index]["text"]) 352 | # # compute the most common element in accepted 353 | # return int(max_response == extract_func(instance["input"]["answer"])) 354 | 355 | 356 | def is_resample_correct( 357 | instance_and_score, n=512, extract_func=get_last_number, r=128, trials=16 358 | ): 359 | instance, scores = instance_and_score 360 | outputs = instance["outputs"][:n] 361 | scores = np.array(scores[:n]) 362 | answer = extract_func(instance["input"]["answer"]) 363 | 364 | def shuffle(x): 365 | np.random.shuffle(x) 366 | return x 367 | 368 | # sample 16 times a sublist of scores of size "r" 369 | 370 | inds = [shuffle(np.arange(n))[:r].tolist() for _ in range(trials)] 371 | 372 | pick_ind = [indsi[np.argmax(scores[indsi])] for indsi in inds] 373 | responses = [extract_func(outputs[i]["text"]) for i in pick_ind] 374 | 375 | c = Counter(responses) 376 | 377 | if answer in c: 378 | correct_estimate = c[answer] / len(responses) 379 | 380 | else: 381 | correct_estimate = 0 382 | 383 | if len(c) == 0: 
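        # no responses to vote over (empty outputs); fall back to a dummy prediction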
384 | most_common = "NaN" 385 | else: 386 | most_common = c.most_common(1)[0][0] 387 | 388 | return int(most_common == answer) 389 | 390 | 391 | def join_instances_and_repeat_accepts(instances, is_quest=True): 392 | return _join_instances_and_repeat_accepts(instances, is_quest=is_quest)[0] 393 | 394 | 395 | def join_instances_and_repeat_accepts_with_scores( 396 | instances, 397 | evals_scores, 398 | is_quest=True, 399 | ): 400 | return _join_instances_and_repeat_accepts( 401 | instances, is_quest=is_quest, evals_scores=evals_scores 402 | ) 403 | 404 | 405 | def _join_instances_and_repeat_accepts(instances, is_quest=True, evals_scores=None): 406 | inputs = instances[0]["input"] 407 | outputs = [] 408 | scores = [] 409 | 410 | if is_quest: 411 | outputs_cat = [] 412 | scores_cat = [] 413 | for i, es in zip(instances, evals_scores or [None] * len(instances)): 414 | accepted_outputs = [] 415 | accepted_scores = [] 416 | last_accepted = None 417 | for output, s in zip( 418 | i["outputs"], es["scores"] if es else [None] * len(i["outputs"]) 419 | ): 420 | if output["accept"]: 421 | accepted_outputs.append(output) 422 | if es: 423 | accepted_scores.append(s) 424 | last_accepted = (output, s) 425 | elif last_accepted is not None: 426 | accepted_outputs.append(last_accepted[0]) 427 | if es: 428 | accepted_scores.append(last_accepted[1]) 429 | 430 | outputs_cat.append(accepted_outputs) 431 | if es: 432 | scores_cat.append(accepted_scores) 433 | 434 | for i in range(len(outputs_cat[0])): 435 | for o in outputs_cat: 436 | outputs.append(o[i]) 437 | if evals_scores: 438 | for s in scores_cat: 439 | scores.append(s[i]) 440 | else: 441 | outputs = [j for i in instances for j in i["outputs"]] 442 | if evals_scores: 443 | scores = [s for es in evals_scores for s in es["scores"]] 444 | 445 | # print(len(outputs)) 446 | return {"input": inputs, "outputs": outputs}, scores 447 | 448 | 449 | class VotingPick(Evalutor): 450 | 451 | def __init__( 452 | self, 453 | gaps=2, 454 | exp_rate=True, 455 | max_num_chains=None, 456 | multiprocessing=None, 457 | burnin=0, 458 | extract="lastnumber", 459 | n=None, 460 | **kwargs, 461 | ): 462 | 463 | super().__init__( 464 | extract + "-voting" if n is None else extract + "-voting-" + str(n) 465 | ) 466 | self.gaps = gaps 467 | self.pmap = multiprocessing.map if multiprocessing else map 468 | self.exp_rate = exp_rate 469 | self.max_num_chains = max_num_chains 470 | self.burnin = burnin 471 | self.n = n 472 | self.chunk_size = 64 473 | 474 | if extract == "lastnumber": 475 | self.extract = get_last_number 476 | elif extract == "lastmath": 477 | self.extract = get_last_math 478 | 479 | elif extract == "lastoption": 480 | self.extract = get_last_option 481 | else: 482 | ValueError("Invalid extract function") 483 | 484 | def eval(self, exp: Exp): 485 | 486 | if self.gaps > 0: 487 | if not self.exp_rate: 488 | x_axis = generate_axis(exp.meta["steps"] + 1, self.gaps) 489 | else: 490 | x_axis = generate_axis(exp.meta["steps"] + 1, exp.meta["steps"]) 491 | 492 | else: 493 | x_axis = [exp.meta["steps"]] 494 | 495 | # num_chains = exp.meta.get("num_chains", 1) 496 | # max_num_chains = self.max_num_chains if self.max_num_chains else num_chains 497 | 498 | # instances are in groups of num_chains .. i.e. 
499 | # [0,0,0,1,1,1,2,2,2,3,3,3,4,4,4] 500 | # I want to group into : 501 | # [[0,0,0],[1,1,1],[2,2,2],[3,3,3],[4,4,4]] 502 | # so that I can compute the mean of the estimates 503 | # for each group 504 | 505 | accs_full = {xi: [] for xi in x_axis} 506 | 507 | c = 0 508 | for instance_chunk in tqdm( 509 | chunked(exp.instances(lazy_iterable=True), self.chunk_size) 510 | ): 511 | 512 | c += len(instance_chunk) 513 | 514 | if c > self.n: 515 | break 516 | 517 | is_quest = "quest" in exp.meta["variant"] 518 | 519 | # instances = total_instances 520 | 521 | instances = [ 522 | join_instances_and_repeat_accepts([i], is_quest=is_quest) 523 | for i in instance_chunk 524 | ] 525 | # print(x_axis) 526 | # print(len(instances)) 527 | 528 | for i in x_axis: 529 | 530 | results = self.pmap( 531 | partial( 532 | is_mode_correct, 533 | n=i * 1, 534 | burnin=0, 535 | extract_func=self.extract, 536 | ), 537 | instances, 538 | ) 539 | 540 | accs, estimates = zip(*results) 541 | 542 | accs_full[i].extend(accs) 543 | 544 | # scores.append({"axis": x_axis, "scores": results}) 545 | # x_results.append(float(np.mean(list((accs))))) 546 | # n_estimates.append((estimates)) 547 | # accs_full.append(accs) 548 | 549 | # final_estimate = n_estimates[-1] 550 | # estimate = [float(np.mean(np.abs(i - final_estimate))) for i in n_estimates] 551 | 552 | results = [ 553 | { 554 | "axis": x_axis, 555 | "scores": [float(np.mean(accs_full[i])) for i in x_axis], 556 | "accs": [accs_full[i] for i in x_axis], 557 | } 558 | ] 559 | return results 560 | 561 | 562 | class LastPick(Evalutor): 563 | 564 | def __init__( 565 | self, 566 | gaps=2, 567 | exp_rate=True, 568 | max_num_chains=None, 569 | multiprocessing=None, 570 | extract="lastnumber", 571 | n=None, 572 | **kwargs, 573 | ): 574 | super().__init__("last-pick" if n is None else "last-pick-" + str(n)) 575 | self.gaps = gaps 576 | self.pmap = multiprocessing.map if multiprocessing else map 577 | self.exp_rate = exp_rate 578 | self.max_num_chains = max_num_chains 579 | self.n = n 580 | self.extract_key = extract 581 | 582 | if extract == "lastnumber": 583 | self.extract = get_last_number 584 | elif extract == "lastmath": 585 | self.extract = get_last_math 586 | elif extract == "lastoption": 587 | self.extract = get_last_option 588 | else: 589 | ValueError("Invalid extract function") 590 | 591 | def eval(self, exp: Exp): 592 | 593 | if self.gaps > 0: 594 | if not self.exp_rate: 595 | x_axis = generate_axis(exp.meta["steps"] + 1, self.gaps) 596 | else: 597 | x_axis = generate_axis(exp.meta["steps"] + 1, exp.meta["steps"]) 598 | 599 | else: 600 | x_axis = [exp.meta["steps"]] 601 | 602 | num_chains = exp.meta.get("num_chains", 1) 603 | max_num_chains = self.max_num_chains if self.max_num_chains else num_chains 604 | 605 | # instances are in groups of num_chains .. i.e. 
606 |         # [0,0,0,1,1,1,2,2,2,3,3,3,4,4,4]
607 |         # I want to group into :
608 |         # [[0,0,0],[1,1,1],[2,2,2],[3,3,3],[4,4,4]]
609 |         # so that I can compute the mean of the estimates
610 |         # for each group
611 | 
612 |         is_quest = "quest" in exp.meta["variant"]
613 |         instances = exp.instances()  # load all instances before grouping them by chain
614 |         instances = [
615 |             join_instances_and_repeat_accepts(
616 |                 instances[i : i + num_chains][:max_num_chains], is_quest=is_quest
617 |             )
618 |             for i in range(0, len(instances), num_chains)
619 |         ]
620 |         # print(x_axis)
621 |         # print(len(instances))
622 | 
623 |         x_results = []
624 |         n_estimates = []
625 |         accs_full = []
626 |         for i in tqdm(x_axis):
627 |             is_quest = "quest" in exp.meta["variant"]
628 | 
629 |             results = self.pmap(
630 |                 partial(
631 |                     get_last,
632 |                     n=i * max_num_chains,
633 |                     extract_func=self.extract,
634 |                 ),
635 |                 instances,
636 |             )
637 | 
638 |             accs, estimates = zip(*results)
639 | 
640 |             # scores.append({"axis": x_axis, "scores": results})
641 |             x_results.append(float(np.mean(list((accs)))))
642 |             n_estimates.append((estimates))
643 |             accs_full.append(accs)
644 | 
645 |         # final_estimate = n_estimates[-1]
646 |         # estimate = [float(np.mean(np.abs(i - final_estimate))) for i in n_estimates]
647 | 
648 |         return [
649 |             {
650 |                 "axis": x_axis,
651 |                 "scores": x_results,
652 |                 "estimates": n_estimates,
653 |                 "accs": accs_full,
654 |             }
655 |         ]
656 | 
657 | 
658 | class BonPick(Evalutor):
659 | 
660 |     def __init__(
661 |         self,
662 |         key,
663 |         gaps=2,
664 |         exp_rate=True,
665 |         multiprocessing=None,
666 |         extract="lastnumber",
667 |         n=None,
668 |         **kwargs,
669 |     ):
670 |         super().__init__(key + "-bon" if n is None else key + "-bon-" + str(n))
671 |         self.reward_key = key
672 |         self.gaps = gaps
673 |         self.pmap = multiprocessing.map if multiprocessing else map
674 |         self.exp_rate = exp_rate
675 |         self.n = n
676 |         self.extract_key = extract
677 | 
678 |     def eval(self, exp: Exp):
679 | 
680 |         assert exp.has_eval(
681 |             self.reward_key
682 |         ), f"Experiment does not have {self.reward_key} eval"
683 | 
684 |         if self.gaps > 0:
685 |             if not self.exp_rate:
686 |                 x_axis = generate_axis(exp.meta["steps"] + 1, self.gaps)
687 |             else:
688 |                 x_axis = generate_axis(exp.meta["steps"] + 1, exp.meta["steps"])
689 | 
690 |         else:
691 |             x_axis = [exp.meta["steps"]]
692 | 
693 |         # instances = exp.instances()
694 | 
695 |         if self.n is not None:
696 |             n = self.n
697 |         else:
698 |             n = exp.meta["n"]
699 |             # instances = instances[: self.n]
700 | 
701 |         x_results = []
702 |         for i in tqdm(x_axis):
703 |             # spawn pool of workers
704 | 
705 |             results = list(
706 |                 self.pmap(
707 |                     partial(
708 |                         is_max_correct,
709 |                         n=i,
710 |                     ),
711 |                     zip(
712 |                         exp.get_eval(self.reward_key)[:n],
713 |                         exp.get_eval(self.extract_key)[:n],
714 |                     ),
715 |                 )
716 |             )
717 | 
718 |             # scores.append({"axis": x_axis, "scores": results})
719 |             x_results.append(float(np.mean(list((results)))))
720 | 
721 |         return [{"axis": x_axis, "scores": x_results}]
722 | 
723 | 
724 | class BonResamplePick(Evalutor):
725 | 
726 |     def __init__(
727 |         self,
728 |         key,
729 |         gaps=2,
730 |         exp_rate=True,
731 |         multiprocessing=None,
732 |         r=128,
733 |         trials=32,
734 |         extract="lastnumber",
735 |         n=None,
736 |         **kwargs,
737 |     ):
738 |         super().__init__(
739 |             key + "-bonresample-" + str(r) + "-" + str(trials)
740 |             if n is None
741 |             else key + "-bonresample-" + str(r) + "-" + str(trials) + "-" + str(n)
742 |         )
743 |         self.reward_key = key
744 |         self.gaps = gaps
745 |         self.pmap = multiprocessing.map if multiprocessing else map
746 |         self.exp_rate = exp_rate
747 |         self.r = r
748 |         self.trials = trials
749 |         self.n = n
750
| 751 | if extract == "lastnumber": 752 | self.extract = get_last_number 753 | elif extract == "lastmath": 754 | self.extract = get_last_math 755 | elif extract == "lastoption": 756 | self.extract = get_last_option 757 | else: 758 | ValueError("Invalid extract function") 759 | 760 | def eval(self, exp: Exp): 761 | 762 | assert exp.has_eval( 763 | self.reward_key 764 | ), f"Experiment does not have {self.reward_key} eval" 765 | 766 | if self.gaps > 0: 767 | if not self.exp_rate: 768 | x_axis = generate_axis(exp.meta["steps"] + 1, self.gaps) 769 | else: 770 | x_axis = generate_axis(exp.meta["steps"] + 1, exp.meta["steps"]) 771 | 772 | else: 773 | x_axis = [exp.meta["steps"]] 774 | 775 | is_quest = "quest" in exp.meta["variant"] 776 | num_chains = exp.meta.get("num_chains", 1) 777 | instances = exp.instances() 778 | if self.n is not None: 779 | instances = instances[: self.n] 780 | evals_scores = exp.get_eval(self.reward_key) 781 | instances_and_scores = [ 782 | join_instances_and_repeat_accepts_with_scores( 783 | instances[i : i + num_chains], 784 | evals_scores=evals_scores[i : i + num_chains], 785 | is_quest=is_quest, 786 | ) 787 | for i in range(0, len(instances), num_chains) 788 | ] 789 | 790 | x_results = [] 791 | for i in tqdm(x_axis): 792 | # spawn pool of workers 793 | 794 | results = list( 795 | self.pmap( 796 | partial( 797 | is_resample_correct, 798 | n=i, 799 | extract_func=self.extract, 800 | r=self.r, 801 | trials=self.trials, 802 | ), 803 | instances_and_scores, 804 | ) 805 | ) 806 | 807 | # scores.append({"axis": x_axis, "scores": results}) 808 | x_results.append(float(np.mean(list((results))))) 809 | 810 | return [{"axis": x_axis, "scores": x_results}] 811 | 812 | 813 | class WeightedVotingPick(Evalutor): 814 | 815 | def __init__( 816 | self, 817 | key, 818 | gaps=2, 819 | exp_rate=False, 820 | multiprocessing=False, 821 | beta=1.0, 822 | extract="lastnumber", 823 | n=None, 824 | **kwargs, 825 | ): 826 | super().__init__( 827 | key + "-weighted-voting-" + str(beta).replace(".", ";") 828 | if n is None 829 | else key + "-weighted-voting-" + str(beta).replace(".", ";") + "-" + str(n) 830 | ) 831 | self.reward_key = key 832 | self.gaps = gaps 833 | self.exp_rate = exp_rate 834 | self.pmap = multiprocessing.map if multiprocessing else map 835 | self.beta = beta 836 | self.n = n 837 | self.chunk_size = 64 838 | 839 | if extract == "lastnumber": 840 | self.extract = get_last_number 841 | elif extract == "lastmath": 842 | self.extract = get_last_math 843 | elif extract == "lastoption": 844 | self.extract = get_last_option 845 | else: 846 | ValueError("Invalid extract function") 847 | 848 | def eval(self, exp: Exp): 849 | 850 | assert exp.has_eval( 851 | self.reward_key 852 | ), f"Experiment does not have {self.reward_key} eval" 853 | 854 | if self.gaps > 0: 855 | if not self.exp_rate: 856 | x_axis = generate_axis(exp.meta["steps"] + 1, self.gaps) 857 | else: 858 | x_axis = generate_axis(exp.meta["steps"] + 1, exp.meta["steps"]) 859 | 860 | else: 861 | x_axis = [exp.meta["steps"]] 862 | 863 | accs_full = {xi: [] for xi in x_axis} 864 | 865 | evals_scores = exp.get_eval(self.reward_key) 866 | it_idx = 0 867 | for instance_chunk in tqdm( 868 | chunked(exp.instances(lazy_iterable=True), self.chunk_size) 869 | ): 870 | 871 | evals_scores_chunk = evals_scores[it_idx : it_idx + len(instance_chunk)] 872 | it_idx += len(instance_chunk) 873 | 874 | is_quest = "quest" in exp.meta["variant"] 875 | 876 | instances_and_scores = [ 877 | join_instances_and_repeat_accepts_with_scores( 878 | [ich], 
evals_scores=[ech], is_quest=is_quest 879 | ) 880 | for ich, ech in zip(instance_chunk, evals_scores_chunk) 881 | ] 882 | 883 | for i in x_axis: 884 | 885 | results = self.pmap( 886 | partial( 887 | is_weighted_mode_correct, # is_weighted_mode_correct 888 | n=i * 1, 889 | extract_func=self.extract, 890 | beta=self.beta, 891 | ), 892 | instances_and_scores, 893 | ) 894 | 895 | accs, estimates = zip(*results) 896 | 897 | accs_full[i].extend(accs) 898 | 899 | results = [ 900 | { 901 | "axis": x_axis, 902 | "scores": [float(np.mean(accs_full[i])) for i in x_axis], 903 | "accs": [accs_full[i] for i in x_axis], 904 | } 905 | ] 906 | return results 907 | 908 | 909 | class MBRPick(Evalutor): 910 | 911 | def __init__( 912 | self, 913 | gaps=2, 914 | extract="lastnumber", 915 | n=None, 916 | metric="rouge", 917 | **kwargs, 918 | ): 919 | super().__init__(f"{metric}-mbr" if n is None else f"{metric}-mbr-" + str(n)) 920 | self.gaps = gaps 921 | 922 | self.metric = metric 923 | self.n = n 924 | print(f"Filename:", self.eval_name) 925 | 926 | self.extract_key = extract 927 | 928 | def eval(self, exp: Exp): 929 | 930 | if not self.n: 931 | n = exp.meta["n"] 932 | else: 933 | n = self.n 934 | 935 | pairwise_pick_results = mbr_pick_progression( 936 | exp, 937 | compute_in_parallel=True, 938 | k=self.gaps, 939 | n=n, 940 | metric=self.metric, 941 | ) 942 | 943 | gt = exp.get_eval(self.extract_key)[:n] 944 | 945 | scores = {} 946 | for ax, preds in pairwise_pick_results["preds"].items(): 947 | pred_scores = [gt[j]["scores"][pred] for j, pred in enumerate(preds)] 948 | scores[ax] = float(np.mean(pred_scores)) 949 | 950 | output = { 951 | "axis": pairwise_pick_results["axis"], 952 | "scores": [scores[ax] for ax in pairwise_pick_results["axis"]], 953 | } 954 | 955 | print(output) 956 | 957 | return [output] 958 | 959 | 960 | class WMBRPick(Evalutor): 961 | 962 | def __init__( 963 | self, 964 | key, 965 | gaps=2, 966 | extract="lastnumber", 967 | n=None, 968 | metric="rouge", 969 | **kwargs, 970 | ): 971 | super().__init__(f"{metric}-wmbr" if n is None else f"{metric}-wmbr-" + str(n)) 972 | self.gaps = gaps 973 | self.metric = metric 974 | self.n = n 975 | self.reward_key = key 976 | print(f"Filename:", self.eval_name) 977 | 978 | self.extract_key = extract 979 | 980 | def eval(self, exp: Exp): 981 | 982 | if not self.n: 983 | n = exp.meta["n"] 984 | else: 985 | n = self.n 986 | 987 | pairwise_pick_results = weighted_mbr_pick_progression( 988 | exp, 989 | reward_key=self.reward_key, 990 | compute_in_parallel=True, 991 | k=self.gaps, 992 | n=n, 993 | metric=self.metric, 994 | ) 995 | 996 | gt = exp.get_eval(self.extract_key)[:n] 997 | 998 | scores = {} 999 | for ax, preds in pairwise_pick_results["preds"].items(): 1000 | pred_scores = [gt[j]["scores"][pred] for j, pred in enumerate(preds)] 1001 | scores[ax] = float(np.mean(pred_scores)) 1002 | 1003 | output = { 1004 | "axis": pairwise_pick_results["axis"], 1005 | "scores": [scores[ax] for ax in pairwise_pick_results["axis"]], 1006 | } 1007 | 1008 | print(output) 1009 | 1010 | return [output] 1011 | -------------------------------------------------------------------------------- /resume_experiment.py: -------------------------------------------------------------------------------- 1 | ## qalign 2 | from qalign.utils.experiment import run_experiment 3 | 4 | ## expkit 5 | from expkit.storage import DiskStorage 6 | from expkit import Exp 7 | 8 | 9 | def main( 10 | experiment_name="", 11 | save_path: str = "llama3.2-outputs/", 12 | reward_model_batch_size: int = 128, 
13 | gpu_memory_utilization: float = 0.95, 14 | reward_device=1, 15 | device_count: int = 1, 16 | remote=False, 17 | ): 18 | 19 | storage = DiskStorage(save_path, "rw") 20 | 21 | if storage.exists(experiment_name): 22 | experiment = Exp.load(storage=storage, name=experiment_name) 23 | 24 | run_experiment( 25 | experiment=experiment, 26 | gpu_memory_utilization=gpu_memory_utilization, 27 | device_count=device_count, 28 | reward_model_batch_size=reward_model_batch_size, 29 | reward_device=reward_device, 30 | remote=remote, 31 | ) 32 | 33 | else: 34 | raise ValueError(f"Experiment {experiment_name} does not exist.") 35 | 36 | 37 | if __name__ == "__main__": 38 | import fire 39 | 40 | fire.Fire(main) 41 | -------------------------------------------------------------------------------- /resume_experiment_remote.py: -------------------------------------------------------------------------------- 1 | ## qalign 2 | from qalign.utils.experiment import run_experiment_remote 3 | 4 | ## expkit 5 | from expkit.storage import DiskStorage 6 | from expkit import Exp 7 | 8 | 9 | # 280bd4d8-9e11-48d4-812b-fd2c5666684d 10 | def main( 11 | experiment_name="", 12 | save_path: str = "llama3.2-outputs/", 13 | ): 14 | 15 | storage = DiskStorage(save_path, "rw") 16 | 17 | if storage.exists(experiment_name): 18 | experiment = Exp.load(storage=storage, name=experiment_name) 19 | 20 | run_experiment_remote( 21 | experiment=experiment, 22 | ) 23 | 24 | else: 25 | raise ValueError(f"Experiment {experiment_name} does not exist.") 26 | 27 | 28 | if __name__ == "__main__": 29 | import fire 30 | 31 | fire.Fire(main) 32 | -------------------------------------------------------------------------------- /scripts/create_all_general_experiments.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | MODELPATH=allenai/Llama-3.1-Tulu-3-8B-SFT 5 | DIR="outputs/" 6 | 7 | TEMP=1.0 8 | N=128 9 | 10 | #### QUEST 11 | RM_PATH=allenai/Llama-3.1-Tulu-3-8B-RM 12 | BETA=0.5 13 | STEPS=1024 14 | 15 | DATASET="HuggingFaceH4/MATH-500" 16 | PROMPT="Solve the following math problem step-by-step: {prompt}\n\nPresent the answer in LaTex format: \\boxed{Your answer}" 17 | 18 | TOKENIZERS_PARALLELISM=false python launch_experiment.py --variant "quest-rlhf" --reward_type contextual --split "test" --num_chains 1 --batch_size 32 --beta $BETA --steps $STEPS --temperature $TEMP --n $N --save_path $DIR --gpu_memory_utilization 0.95 --model_path $MODELPATH --reward_model_path $RM_PATH --reward_model_batch_size 32 --max_new_tokens 800 --max_prompt_length 1200 --dataset_path $DATASET --prompt_template "$PROMPT" 19 | 20 | DATASET="openai/gsm8k" 21 | PROMPT="Solve the following grade school math problem step-by-step: {prompt}" 22 | TOKENIZERS_PARALLELISM=false python launch_experiment.py --variant "quest-rlhf" --reward_type contextual --split "test" --num_chains 1 --batch_size 32 --beta $BETA --steps $STEPS --temperature $TEMP --n $N --save_path $DIR --gpu_memory_utilization 0.95 --model_path $MODELPATH --reward_model_path $RM_PATH --reward_model_batch_size 32 --max_new_tokens 800 --max_prompt_length 1200 --dataset_path $DATASET --prompt_template "$PROMPT" 23 | 24 | DATASET="google/IFEval" 25 | TOKENIZERS_PARALLELISM=false python launch_experiment.py --variant "quest-rlhf" --reward_type contextual --split "train" --num_chains 1 --batch_size 32 --beta $BETA --steps $STEPS --temperature $TEMP --n $N --save_path $DIR --gpu_memory_utilization 0.95 --model_path $MODELPATH --reward_model_path $RM_PATH 
--reward_model_batch_size 32 --max_new_tokens 800 --max_prompt_length 1200 --dataset_path $DATASET 26 | 27 | DATASET="edinburgh-dawg/mmlu-redux" 28 | TOKENIZERS_PARALLELISM=false python launch_experiment.py --variant "quest-rlhf" --reward_type contextual --split "test" --num_chains 1 --batch_size 32 --beta $BETA --steps $STEPS --temperature $TEMP --n $N --save_path $DIR --gpu_memory_utilization 0.95 --model_path $MODELPATH --reward_model_path $RM_PATH --reward_model_batch_size 32 --max_new_tokens 800 --max_prompt_length 1200 --dataset_path $DATASET 29 | 30 | 31 | ## ANCESTRAL 32 | DATASET="HuggingFaceH4/MATH-500" 33 | PROMPT="Solve the following math problem step-by-step: {prompt}\n\nPresent the answer in LaTex format: \\boxed{Your answer}" 34 | 35 | TOKENIZERS_PARALLELISM=false python launch_experiment.py --variant "ancestral" --split "test" --batch_size 1 --steps $STEPS --temperature $TEMP --n $N --save_path $DIR --gpu_memory_utilization 0.95 --model_path $MODELPATH --max_new_tokens 800 --max_prompt_length 1200 --dataset_path $DATASET --prompt_template "$PROMPT" 36 | 37 | DATASET="openai/gsm8k" 38 | PROMPT="Solve the following grade school math problem step-by-step: {prompt}" 39 | 40 | TOKENIZERS_PARALLELISM=false python launch_experiment.py --variant "ancestral" --split "test" --batch_size 1 --steps $STEPS --temperature $TEMP --n $N --save_path $DIR --gpu_memory_utilization 0.95 --model_path $MODELPATH --max_new_tokens 800 --max_prompt_length 1200 --dataset_path $DATASET --prompt_template "$PROMPT" 41 | 42 | DATASET="google/IFEval" 43 | TOKENIZERS_PARALLELISM=false python launch_experiment.py --variant "ancestral" --split "train" --batch_size 1 --steps $STEPS --temperature $TEMP --n $N --save_path $DIR --gpu_memory_utilization 0.95 --model_path $MODELPATH --max_new_tokens 800 --max_prompt_length 1200 --dataset_path $DATASET 44 | 45 | DATASET="edinburgh-dawg/mmlu-redux" 46 | TOKENIZERS_PARALLELISM=false python launch_experiment.py --variant "ancestral" --split "test" --batch_size 1 --steps $STEPS --temperature $TEMP --n $N --save_path $DIR --gpu_memory_utilization 0.95 --model_path $MODELPATH --max_new_tokens 800 --max_prompt_length 1200 --dataset_path $DATASET 47 | 48 | ## DPO ANCESTRAL 49 | MODELPATH=allenai/Llama-3.1-Tulu-3-8B-DPO 50 | 51 | DATASET="HuggingFaceH4/MATH-500" 52 | PROMPT="Solve the following math problem step-by-step: {prompt}\n\nPresent the answer in LaTex format: \\boxed{Your answer}" 53 | 54 | TOKENIZERS_PARALLELISM=false python launch_experiment.py --variant "ancestral" --split "test" --batch_size 1 --steps $STEPS --temperature $TEMP --n $N --save_path $DIR --gpu_memory_utilization 0.95 --model_path $MODELPATH --max_new_tokens 800 --max_prompt_length 1200 --dataset_path $DATASET --prompt_template "$PROMPT" 55 | 56 | DATASET="openai/gsm8k" 57 | PROMPT="Solve the following grade school math problem step-by-step: {prompt}" 58 | 59 | TOKENIZERS_PARALLELISM=false python launch_experiment.py --variant "ancestral" --split "test" --batch_size 1 --steps $STEPS --temperature $TEMP --n $N --save_path $DIR --gpu_memory_utilization 0.95 --model_path $MODELPATH --max_new_tokens 800 --max_prompt_length 1200 --dataset_path $DATASET --prompt_template "$PROMPT" 60 | 61 | 62 | DATASET="google/IFEval" 63 | TOKENIZERS_PARALLELISM=false python launch_experiment.py --variant "ancestral" --split "train" --batch_size 1 --steps $STEPS --temperature $TEMP --n $N --save_path $DIR --gpu_memory_utilization 0.95 --model_path $MODELPATH --max_new_tokens 800 --max_prompt_length 1200 --dataset_path 
$DATASET
 64 | 
 65 | DATASET="edinburgh-dawg/mmlu-redux"
 66 | TOKENIZERS_PARALLELISM=false python launch_experiment.py --variant "ancestral" --split "test" --batch_size 1 --steps $STEPS --temperature $TEMP --n $N --save_path $DIR --gpu_memory_utilization 0.95 --model_path $MODELPATH --max_new_tokens 800 --max_prompt_length 1200 --dataset_path $DATASET
 67 | 
 68 | 
 69 | # Loop through all files in the directory
 70 | for file in "$DIR"/*; do
 71 |     # Skip if it's a directory
 72 |     if [ -f "$file" ]; then
 73 |         # Get just the filename without path
 74 |         fileid=$(basename "$file")
 75 | 
 76 |         echo "Processing: $fileid"
 77 | 
 78 |         # Launch Python script with the filename as experiment_name
 79 |         python resume_experiment.py --experiment_name "$fileid" --save_path "$DIR"
 80 | 
 81 |         echo ""
 82 |     fi
 83 | done
 84 | 
 85 | echo "All experiments have been processed."
--------------------------------------------------------------------------------
/scripts/create_all_task_experiments.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/bash
 2 | 
 3 | 
 4 | MODELPATH=meta-llama/Llama-3.1-8B-Instruct
 5 | DIR="outputs-task/"
 6 | RM_PATH="/gscratch/ark/graf/LLaMA-Factory/saves/llama3/8b/full/reward/"
 7 | TEMP=1.0
 8 | N=128
 9 | STEPS=4096
10 | BATCH_SIZE=128
11 | 
12 | ## ANCESTRAL
13 | DATASET="openai/gsm8k"
14 | TOKENIZERS_PARALLELISM=false python launch_experiment.py --variant "ancestral" --split "test" --batch_size 1 --steps $STEPS --temperature $TEMP --n $N --save_path $DIR --gpu_memory_utilization 0.95 --model_path $MODELPATH --max_new_tokens 800 --max_prompt_length 1200 --dataset_path $DATASET
15 | 
16 | ## QUEST
17 | TOKENIZERS_PARALLELISM=false python launch_experiment.py --variant "quest-rlhf" --reward_type value --split "test" --num_chains 1 --batch_size $BATCH_SIZE --beta $BETA --steps $STEPS --temperature $TEMP --n $N --save_path $DIR --gpu_memory_utilization 0.95 --model_path $MODELPATH --reward_model_path $RM_PATH --reward_model_batch_size 32 --max_new_tokens 800 --max_prompt_length 1200 --dataset_path $DATASET
18 | 
19 | DATASET="apple/GSM-Symbolic-p1"
20 | ## ANCESTRAL
21 | TOKENIZERS_PARALLELISM=false python launch_experiment.py --variant "ancestral" --split "test" --batch_size 1 --steps $STEPS --temperature $TEMP --n $N --save_path $DIR --gpu_memory_utilization 0.95 --model_path $MODELPATH --max_new_tokens 800 --max_prompt_length 1200 --dataset_path $DATASET
22 | 
23 | ## QUEST
24 | TOKENIZERS_PARALLELISM=false python launch_experiment.py --variant "quest-rlhf" --reward_type value --split "test" --num_chains 1 --batch_size $BATCH_SIZE --beta $BETA --steps $STEPS --temperature $TEMP --n $N --save_path $DIR --gpu_memory_utilization 0.95 --model_path $MODELPATH --reward_model_path $RM_PATH --reward_model_batch_size 32 --max_new_tokens 800 --max_prompt_length 1200 --dataset_path $DATASET
25 | 
--------------------------------------------------------------------------------
/scripts/run_eval.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/bash
 2 | 
 3 | 
 4 | directories=("outputs/" "outputs-task/")
 5 | 
 6 | echo "Comparing responses to GT for all datasets"
 7 | 
 8 | DIR="outputs/"
 9 | 
10 | RM="lastnumber"
11 | query_args='{"dataset":"openai/gsm8k"}'
12 | python eval.py --base_dir "$DIR" --reward_model_path $RM --query_args $query_args
13 | 
14 | RM="lastmath"
15 | query_args='{"dataset":"HuggingFaceH4/MATH-500"}'
16 | python eval.py --base_dir "$DIR" --reward_model_path $RM --query_args $query_args
17 | 
18 | 
19 | RM="ifeval"
20 | query_args='{"dataset":"google/IFEval"}'
21 | python eval.py --base_dir "$DIR" --reward_model_path $RM --query_args $query_args
22 | 
23 | 
24 | RM="lastoption"
25 | 
26 | query_args='{"dataset":"edinburgh-dawg/mmlu-redux"}'
27 | python eval.py --base_dir "$DIR" --reward_model_path $RM --query_args $query_args
28 | 
29 | query_args='{"dataset":"truthfulqa/truthful_qa"}'
30 | python eval.py --base_dir "$DIR" --reward_model_path $RM --query_args $query_args
31 | 
32 | 
33 | DIR="outputs-task/"
34 | 
35 | RM="lastnumber"
36 | query_args='{}'
37 | python eval.py --base_dir "$DIR" --reward_model_path $RM --query_args $query_args
--------------------------------------------------------------------------------
/scripts/run_local_experiments.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/bash
 2 | 
 3 | 
 4 | DIR="outputs-task/"
 5 | 
 6 | # Loop through all files in the directory
 7 | for file in "$DIR"/*; do
 8 |     # Skip if it's a directory
 9 |     if [ -f "$file" ]; then
10 |         # Get just the filename without path
11 |         fileid=$(basename "$file")
12 | 
13 |         echo "Processing: $fileid"
14 | 
15 |         # Launch Python script with the filename as experiment_name
16 |         python resume_experiment.py --experiment_name "$fileid" --save_path "$DIR"
17 | 
18 |         echo ""
19 |     fi
20 | done
21 | 
22 | echo "All experiments have been processed."
--------------------------------------------------------------------------------
/scripts/run_pred.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/bash
 2 | 
 3 | 
 4 | echo "Running MV, WMV, and BON for all datasets"
 5 | 
 6 | 
 7 | ## general exps
 8 | DIR="outputs/"
 9 | 
10 | RM="lastnumber"
11 | query_args='{"dataset":"openai/gsm8k"}'
12 | python pred.py --base_dir "$DIR" --extract $RM --strategy "voting" --query_args $query_args
13 | 
14 | RM="lastmath"
15 | 
16 | query_args='{"dataset":"HuggingFaceH4/MATH-500"}'
17 | python pred.py --base_dir "$DIR" --extract $RM --strategy "voting" --query_args $query_args
18 | 
19 | 
20 | RM="ifeval"
21 | 
22 | query_args='{"dataset":"google/IFEval"}'
23 | python pred.py --base_dir "$DIR" --extract $RM --strategy "mbr" --query_args $query_args
24 | 
25 | 
26 | RM="lastoption"
27 | 
28 | query_args='{"dataset":"edinburgh-dawg/mmlu-redux"}'
29 | python pred.py --base_dir "$DIR" --extract $RM --strategy "voting" --query_args $query_args
30 | 
31 | query_args='{"dataset":"truthfulqa/truthful_qa"}'
32 | python pred.py --base_dir "$DIR" --extract $RM --strategy "voting" --query_args $query_args
33 | 
34 | 
35 | 
36 | ## bon and WMV on general exps
37 | 
38 | RM="crm:allenai-Llama-3"
39 | 
40 | query_args='{"dataset":"google/IFEval","variant":"ancestral","model_path":"allenai/Llama-3.1-Tulu-3-8B-SFT"}'
41 | 
42 | python pred.py --base_dir $DIR --strategy "bon" --extract "ifeval" --query_args $query_args --key $RM
43 | python pred.py --base_dir $DIR --strategy "wmbr" --extract "ifeval" --query_args $query_args --key $RM
44 | 
45 | query_args='{"dataset":"HuggingFaceH4/MATH-500","variant":"ancestral","model_path":"allenai/Llama-3.1-Tulu-3-8B-SFT"}'
46 | 
47 | python pred.py --base_dir $DIR --strategy "bon" --extract "lastmath" --query_args $query_args --key $RM
48 | python pred.py --base_dir $DIR --strategy "weighted-voting" --extract "lastmath" --query_args $query_args --key $RM
49 | 
50 | query_args='{"dataset":"openai/gsm8k","variant":"ancestral","model_path":"allenai/Llama-3.1-Tulu-3-8B-SFT"}'
51 | 
52 | python pred.py --base_dir $DIR --strategy "bon" --extract "lastnumber" --query_args $query_args --key $RM
53 | python pred.py --base_dir $DIR --strategy "weighted-voting" --extract "lastnumber" --query_args $query_args --key $RM
54 | 
55 | query_args='{"dataset":"truthfulqa/truthful_qa","variant":"ancestral","model_path":"allenai/Llama-3.1-Tulu-3-8B-SFT"}'
56 | 
57 | python pred.py --base_dir $DIR --strategy "bon" --extract "lastoption" --query_args $query_args --key $RM
58 | python pred.py --base_dir $DIR --strategy "weighted-voting" --extract "lastoption" --query_args $query_args --key $RM
59 | 
60 | query_args='{"dataset":"edinburgh-dawg/mmlu-redux","variant":"ancestral","model_path":"allenai/Llama-3.1-Tulu-3-8B-SFT"}'
61 | 
62 | python pred.py --base_dir $DIR --strategy "bon" --extract "lastoption" --query_args $query_args --key $RM
63 | python pred.py --base_dir $DIR --strategy "weighted-voting" --extract "lastoption" --query_args $query_args --key $RM
64 | 
65 | 
66 | ### task specific exps
67 | RM="vh:-gscratch-ark-graf-LLaMA-Factory-saves-llama3-8b-full-reward-"
68 | 
69 | DIR="outputs-task/"
70 | 
71 | query_args='{}'
72 | 
73 | python pred.py --base_dir "$DIR" --strategy "voting" --extract "lastnumber" --query_args $query_args
74 | 
75 | query_args='{"variant":"ancestral"}'
76 | 
77 | python pred.py --base_dir $DIR --strategy "bon" --extract "lastnumber" --query_args $query_args --key $RM
78 | python pred.py --base_dir $DIR --strategy "weighted-voting" --extract "lastnumber" --query_args $query_args --key $RM
79 | 
--------------------------------------------------------------------------------
/scripts/run_remote_experiments.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/bash
 2 | 
 3 | DIR="outputs-task/"
 4 | 
 5 | # Loop through all files in the directory
 6 | for file in "$DIR"/*; do
 7 |     # Skip if it's a directory
 8 |     if [ -f "$file" ]; then
 9 |         # Get just the filename without path
10 |         fileid=$(basename "$file")
11 | 
12 |         echo "Processing: $fileid"
13 | 
14 |         # Launch Python script with the filename as experiment_name
15 |         python resume_experiment_remote.py --experiment_name "$fileid" --save_path "$DIR"
16 | 
17 |         echo ""
18 |     fi
19 | done
20 | 
21 | echo "All experiments have been processed."
--------------------------------------------------------------------------------
/scripts/run_rm_eval.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/bash
 2 | 
 3 | 
 4 | RM=allenai/Llama-3.1-Tulu-3-8B-RM
 5 | DIR="outputs/"
 6 | query_args='{"variant":"ancestral","model_path":"allenai/Llama-3.1-Tulu-3-8B-SFT"}'
 7 | python eval.py --base_dir $DIR --remote True --reward_model_path $RM --value_head False --batch_size 1024 --query_args $query_args
 8 | 
 9 | RM="/gscratch/ark/graf/LLaMA-Factory/saves/llama3/8b/full/reward/"
10 | DIR="outputs-task/"
11 | query_args='{"variant":"ancestral"}'
12 | python eval.py --base_dir $DIR --remote True --reward_model_path $RM --value_head True --batch_size 1024 --query_args $query_args
13 | 
--------------------------------------------------------------------------------