├── README.md
├── assets
│   └── general_fig.png
├── eval.py
├── launch_experiment.py
├── pred.py
├── qalign
│   ├── __init__.py
│   ├── rm
│   │   ├── __init__.py
│   │   ├── create_dataset.py
│   │   ├── data
│   │   │   ├── dataset_info.json
│   │   │   ├── llama-factory_gsm8k_llama-3.1-8b-instruct_128_1_train.json
│   │   │   ├── llama-factory_gsm8k_llama-3.1-8b-instruct_test.json
│   │   │   ├── llama-factory_gsm8k_test.json
│   │   │   ├── raw_gsm8k_llama-3.1-8b-instruct_128_1_train.json
│   │   │   ├── raw_gsm8k_llama-3.1-8b-instruct_test.json
│   │   │   └── raw_gsm8k_test.json
│   │   ├── train_configs
│   │   │   ├── deepspeed
│   │   │   │   ├── ds_z0_config.json
│   │   │   │   ├── ds_z2_config.json
│   │   │   │   ├── ds_z2_offload_config.json
│   │   │   │   ├── ds_z3_config.json
│   │   │   │   └── ds_z3_offload_config.json
│   │   │   ├── explain.txt
│   │   │   ├── gsm8k
│   │   │   │   ├── full
│   │   │   │   │   ├── gemma2_full_dpo.yaml
│   │   │   │   │   ├── gemma2_full_dpo05.yaml
│   │   │   │   │   ├── gemma2_full_dpo10.yaml
│   │   │   │   │   ├── gemma2_full_dpo_sft05.yaml
│   │   │   │   │   ├── gemma2_full_dpo_sft10.yaml
│   │   │   │   │   ├── gemma2_full_reward.yaml
│   │   │   │   │   ├── llama3_1b1b_full_reward.yaml
│   │   │   │   │   ├── llama3_1b8b_full_reward.yaml
│   │   │   │   │   ├── llama3_3b1b_full_reward.yaml
│   │   │   │   │   ├── llama3_3b8b_full_reward.yaml
│   │   │   │   │   ├── llama3_3b_full_reward.yaml
│   │   │   │   │   ├── llama3_8b1b_full_reward.yaml
│   │   │   │   │   ├── llama3_8b8b_full_dpo.yaml
│   │   │   │   │   ├── llama3_8b8b_full_reward.yaml
│   │   │   │   │   ├── olmo_full_reward.yaml
│   │   │   │   │   ├── train.sh
│   │   │   │   │   └── tulu_8b8b_full_reward.yaml
│   │   │   │   └── lora
│   │   │   │       ├── gemma2_lora_dpo.yaml
│   │   │   │       ├── gemma2_lora_reward.yaml
│   │   │   │       ├── llama2_lora_reward.yaml
│   │   │   │       ├── llama2_lora_sft.yaml
│   │   │   │       └── olmo_lora_reward.yaml
│   │   │   └── math
│   │   │       └── full
│   │   │           ├── llama3_1b8b_full_reward_math.yaml
│   │   │           ├── llama3_8b8b_full_reward_math.yaml
│   │   │           ├── llama3_8b8b_full_reward_math_cot.yaml
│   │   │           └── tulu_8b8b_full_reward_math.yaml
│   │   └── valid_check.py
│   └── utils
│       ├── data.py
│       ├── eval.py
│       ├── examples.py
│       ├── experiment.py
│       ├── ifeval
│       │   ├── __init__.py
│       │   ├── __pycache__
│       │   │   ├── __init__.cpython-38.pyc
│       │   │   ├── eval.cpython-38.pyc
│       │   │   ├── instructions.cpython-38.pyc
│       │   │   ├── instructions_util.cpython-38.pyc
│       │   │   └── registry.cpython-38.pyc
│       │   ├── eval.py
│       │   ├── instructions.py
│       │   ├── instructions_util.py
│       │   └── registry.py
│       ├── math.py
│       ├── mbr.py
│       └── pred.py
├── resume_experiment.py
├── resume_experiment_remote.py
└── scripts
    ├── create_all_general_experiments.sh
    ├── create_all_task_experiments.sh
    ├── run_eval.sh
    ├── run_local_experiments.sh
    ├── run_pred.sh
    ├── run_remote_experiments.sh
    └── run_rm_eval.sh
/README.md:
--------------------------------------------------------------------------------
1 | # Sample, Don't Search: Rethinking Test-Time Alignment for Language Models
2 |
3 | Gonçalo Faria, Noah A. Smith
4 |
5 | **Paper**: https://arxiv.org/abs/2504.03790
6 |
7 | **TL;DR:** QAlign is a new test-time alignment approach that improves language model performance by using Markov chain Monte Carlo methods.
8 |
9 | ### Abstract:
10 | Increasing test-time computation has emerged as a promising direction for improving language model performance, particularly in scenarios where model finetuning is impractical or impossible due to computational constraints or private model weights. However, existing test-time search methods using a reward model (RM) often degrade in quality as compute scales, due to the over-optimization of what are inherently imperfect reward proxies. We introduce QAlign, a new test-time alignment approach. As we scale test-time compute, QAlign converges to sampling from the optimal aligned distribution for each individual prompt. By adopting recent advances in Markov chain Monte Carlo for text generation, our method enables better-aligned outputs without modifying the underlying model or even requiring logit access. We demonstrate the effectiveness of QAlign on mathematical reasoning benchmarks (GSM8K and GSM-Symbolic) using a task-specific RM, showing consistent improvements over existing test-time compute methods like best-of-n and majority voting. Furthermore, when applied with more realistic RMs trained on the Tulu 3 preference dataset, QAlign outperforms direct preference optimization (DPO), best-of-n, majority voting, and weighted majority voting on a diverse range of datasets (GSM8K, MATH500, IFEval, MMLU-Redux, and TruthfulQA). A practical solution to aligning language models at test time using additional computation without degradation, our approach expands the limits of the capability that can be obtained from off-the-shelf language models without further training.
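
In other words, the "optimal aligned distribution" referenced above is the per-prompt KL-regularized target familiar from RLHF. In notation of our own (a sketch, not copied from the paper): for a base model $\pi_{\text{ref}}$, a reward model $r$, and a scale $\beta > 0$,

$$\pi^{*}(y \mid x) \;\propto\; \pi_{\text{ref}}(y \mid x)\,\exp\!\left(\frac{r(x, y)}{\beta}\right),$$

and QAlign runs an MCMC chain (with accept/reject steps; see the `accept` field in the Quick Start below) whose samples approach $\pi^{*}(\cdot \mid x)$ as the number of steps grows.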
11 |
12 |
13 |
14 | 
15 |
16 | Average error rate across multiple evaluation datasets (GSM8K, MATH500, MMLU-Redux, TruthfulQA, and IFEval) as a function of floating-point operations (FLOPs), shown on a log scale. We compare QAlign applied to Tülu3-8B-SFT against four baselines: majority voting (MV) with Tülu3-8B-DPO, and, applied to Tülu3-8B-SFT, best-of-n (BoN), MV, and weighted MV (WMV). All experiments use temperature 1.0 with reasoning included in model outputs. Note that the Tülu3-8B-DPO model is the result of preference finetuning Tülu3-8B-SFT on 271k preference pairs; the cost of that training is not accounted for in this plot.
17 |
18 |
19 | -----
20 | ## Dependencies
21 |
22 | This project relies heavily on the following external libraries:
23 | - [deepspin/quest-decoding](https://github.com/deep-spin/quest-decoding)
24 | - [goncalorafaria/expkit](https://github.com/goncalorafaria/expkit-core)
25 | - [goncalorafaria/literegistry](https://github.com/goncalorafaria/literegistry)
26 |
27 | ```bash
28 | pip install quest-decoding
29 | pip install expkit-core
30 | pip install literegistry
31 | ```
32 |
33 | Install the required packages:
34 | ```bash
35 | pip install -r requirements.txt
36 | ```
37 |
38 | -----
39 | ## Reproducing the work
40 |
41 | To replicate the experiments, follow the steps below:
42 |
43 | ### Experiment Setup
44 | 1. **Create Configuration Files**
45 | ```bash
46 | # Create configs for general experiments
47 | scripts/create_all_general_experiments.sh
48 |
49 | # Create configs for task-specific experiments
50 | scripts/create_all_task_experiments.sh
51 | ```
52 |
53 | ### Running Experiments
54 | 2. **Execute Experiments**
55 | ```bash
56 | # Run experiments locally
57 | scripts/run_local_experiments.sh
58 |
59 | # Run experiments on remote server
60 | scripts/run_remote_experiments.sh
61 | ```
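
For a single configuration, you can also call `launch_experiment.py` directly (its options are exposed via `fire`); a minimal sketch with illustrative argument values:
```python
# Minimal sketch: launch one QAlign run directly.
# The model, reward model, and dataset below are illustrative values.
from launch_experiment import main

main(
    variant="quest-rlhf",
    model_path="allenai/Llama-3.1-Tulu-3-8B-SFT",
    reward_model_path="allenai/Llama-3.1-Tulu-3-8B-RM",
    dataset_path="openai/gsm8k",
    steps=128,        # number of chain steps
    temperature=1.0,
    save_path="outputs/",
)
```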
62 |
63 | ### Evaluation & Analysis
64 | 3. **Evaluate Results**
65 | ```bash
66 | # Compare responses against ground truth answers
67 | scripts/run_eval.sh
68 |
69 | # Evaluate reward model for ancestral predictions (remote by default)
70 | scripts/run_rm_eval.sh
71 | ```
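
Both evaluation entry points are thin `fire` wrappers, so they can also be called from Python; for example, scoring stored GSM8K generations with the exact-last-number oracle (a sketch with illustrative values; `base_dir` should match the `save_path` used when launching):
```python
# Minimal sketch: score stored generations with the last-number oracle.
from eval import main

main(
    base_dir="outputs/",             # directory written by the experiments
    reward_model_path="lastnumber",  # selects ExactLastNumberEval in eval.py
    query_args={"dataset": "openai/gsm8k"},
)
```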
72 |
73 | 4. **Generate Final Predictions**
74 | ```bash
75 | # Run WMV, BoN, and MV final prediction methods
76 | scripts/run_pred.sh
77 | ```
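
The same applies to `pred.py`; for instance, plain majority voting over the stored chains with the MATH answer extractor (a sketch using defaults from `pred.py`, with illustrative query values):
```python
# Minimal sketch: majority-vote predictions over stored outputs.
from pred import main

main(
    base_dir="outputs/",
    strategy="voting",
    extract="lastmath",  # answer extractor for MATH-style outputs
    query_args={"dataset": "HuggingFaceH4/MATH-500", "temperature": 1.0},
)
```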
78 |
79 |
80 | -----
81 |
82 | ## Quick Start
83 |
84 | This guide will help you get started running QAlign.
85 |
86 | ## Basic Usage
87 |
88 | ```python
89 | import os
90 | from quest.core import Quest
91 | from quest.reward.model import ContextualRewardModel, ValueHead
92 | from quest.qalign import QAlign
93 | from quest.model.vllm import VLLM
94 |
95 | # Model configuration
96 | model_path = "allenai/Llama-3.1-Tulu-3-8B-SFT"
97 | model_args = {
98 | "model_path": model_path,
99 | "download_dir": os.environ.get("HF_HOME", "/tmp/"),
100 | "stop_tokens": ["", "<|im_end|>"],
101 | "temperature": 0.7,
102 | "gpu_memory_utilization": 0.9,
103 | "dtype": "bfloat16",
104 | "max_new_tokens": 512,
105 | "max_prompt_length": 4096,
106 | "tensor_parallel_size": 1, # Number of GPUs
107 | "enable_prefix_caching": True,
108 | "enforce_eager": True,
109 | }
110 |
111 | # Initialize the model
112 | model = VLLM(**model_args)
113 |
114 | # Initialize the reward model
115 | reward = ContextualRewardModel(
116 | model_path="allenai/Llama-3.1-Tulu-3-8B-RM",
117 | device=1, ## second gpu
118 | device_count=1,
119 | )
120 |
121 | # Prepare your data
122 | data_batch = [
123 | {
124 | "prompt": "<|user|>\nJanet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?\n<|assistant|>\n"
125 | },
126 | # Add more examples as needed
127 | ]
128 |
129 | # Create markov chain
130 | chain = QAlign(
131 | input_data=data_batch,
132 | model=model,
133 | reward=reward,
134 | beta=1.0, # Controls exploration vs exploitation
135 | )
136 |
137 | # Run
138 | chain_outputs = chain.run(
139 | steps=10, # Number of steps
140 | use_tqdm=True, # Show progress bar
141 | )
142 |
143 | # Print the accepted outputs
144 | print(f"Original prompt: {chain_outputs[0]['input']['prompt']}")
145 | for output in chain_outputs[0]["outputs"]:
146 | if output["accept"]:
147 | print(f"Response: {output['text']}")
148 | print("-" * 50)
149 |
150 | ```
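
A common way to turn a chain into a single answer is to majority-vote over the accepted samples; a minimal sketch (the `extract_answer` helper below is a hypothetical, task-specific stand-in):
```python
import re
from collections import Counter

def extract_answer(text):
    # Hypothetical helper: grab the last number mentioned in the response.
    numbers = re.findall(r"-?\d+(?:\.\d+)?", text.replace(",", ""))
    return numbers[-1] if numbers else None

accepted = [o["text"] for o in chain_outputs[0]["outputs"] if o["accept"]]
answers = [a for a in map(extract_answer, accepted) if a is not None]
prediction, count = Counter(answers).most_common(1)[0]
print(f"Majority-vote answer: {prediction} ({count}/{len(answers)} accepted samples)")
```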
151 |
152 | -----
153 |
154 | ## Contact
155 |
156 | For bugs and feature requests please visit [GitHub Issues](https://github.com/goncalorafaria/qalign/issues). For business inquiries or
157 | professional support requests please send an [e-mail](mailto:goncalofaria.research@gmail.com).
158 |
159 | -----
160 |
161 | ## Citation
162 |
163 | ````
164 | @misc{faria2025sampledontsearchrethinking,
165 | title={Sample, Don't Search: Rethinking Test-Time Alignment for Language Models},
166 | author={Gonçalo Faria and Noah A. Smith},
167 | year={2025},
168 | eprint={2504.03790},
169 | archivePrefix={arXiv},
170 | primaryClass={cs.CL},
171 | url={https://arxiv.org/abs/2504.03790},
172 | }
173 | ````
174 |
175 |
--------------------------------------------------------------------------------
/assets/general_fig.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/goncalorafaria/qalign/39b1df52c913bda4f94e656018f68915fd06299b/assets/general_fig.png
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | from typing import *
2 |
3 | ## quest
4 | from quest.reward.model import ValueHead
5 | from quest.reward.remote import RemoteReward
6 | from quest import (
7 | RewardModel,
8 | ContextualRewardModel,
9 | )
10 |
11 | ## expkit
12 | from expkit.setup import ExpSetup
13 | from expkit.storage import DiskStorage
14 |
15 | ## qalign
16 | from qalign.utils.eval import *
17 |
18 | ## literegistry
19 | from literegistry import RegistryClient, FileSystemKVStore
20 |
21 |
22 | def main(
23 | base_dir="remote-outputs-llama/",
24 | reward_model_path="lastnumber",
25 | model_path="allenai/tulu-2-7b",
26 | batch_size=248,
27 | context=True,
28 | query_args={},
29 | device_count=8,
30 | value_head=True,
31 | remote=False,
32 | n=None,
33 | ):
34 |
35 | print("Query Args:", query_args)
36 | setup = ExpSetup(storage=DiskStorage(base_dir=base_dir, mode="rw"))
37 | # print("Exp:", setup.query(query_args))
38 |
39 | setup = setup.query(query_args).filter(lambda x: x.has_data())
40 |
41 | print("That match the query:\n", setup)
42 |
43 | if len(setup.experiments) == 0:
44 | raise FileNotFoundError("The experiment has no data!")
45 |
46 | if reward_model_path == "likelihood":
47 |
48 | ps_eval = LikelihoodEval(model_path=model_path)
49 |
50 | elif reward_model_path == "lastnumber":
51 | ps_eval = ExactLastNumberEval()
52 |
53 | elif reward_model_path == "lastmath":
54 | ps_eval = ExactMATHEval()
55 |
56 | elif reward_model_path == "lastoption":
57 | ps_eval = ExactQAEval()
58 |
59 | elif reward_model_path == "ifeval":
60 | ps_eval = IFEval()
61 |
62 | elif remote:
63 |
64 | if value_head:
65 | reward_type = "value"
66 | elif context:
67 | reward_type = "contextual"
68 | else:
69 | reward_type = "reward"
70 |
71 | registry = RegistryClient(
72 | store=FileSystemKVStore("/gscratch/ark/graf/registry"),
73 | max_history=3600,
74 | cache_ttl=60,
75 | service_type="model_path",
76 | )
77 |
78 | reward = RemoteReward(
79 | model_path=reward_model_path,
80 | registry=registry,
81 | reward_type=reward_type,
82 | # batch_size=batch_size,
83 | # max_parallel_requests=32,
84 | batch_size=32,
85 | max_parallel_requests=64,
86 | )
87 | ps_eval = RewardEval(
88 | reward=reward,
89 | n=n,
90 | chunk_size=256,
91 | )
92 | elif value_head:
93 |
94 | reward = ValueHead(
95 | model_path=reward_model_path,
96 | batch_size=batch_size,
97 | device_count=device_count,
98 | )
99 |
100 | ps_eval = RewardEval(reward=reward, n=n)
101 |
102 | else:
103 | if context:
104 | reward = ContextualRewardModel(
105 | model_path=reward_model_path,
106 | batch_size=batch_size,
107 | device_count=device_count,
108 | )
109 |
110 | else:
111 | reward = RewardModel(
112 | model_path=reward_model_path,
113 | batch_size=batch_size,
114 | device_count=device_count,
115 | )
116 |
117 | ps_eval = RewardEval(reward=reward, n=n)
118 |
119 | # setup = setup.filter(lambda x: not x.has_eval(ps_eval.eval_name))
120 | # setup = setup.filter(lambda x: x.get("n") == len(x.instances()))
121 | print("That haven't done the eval:", setup)
122 |
123 | def func(experiment):
124 |
125 | try:
126 | return (
127 | ps_eval(experiment)
128 | # if not experiment.has_eval(ps_eval.eval_name)
129 | # else experiment
130 | )
131 | except FileNotFoundError:
132 | return experiment
133 |
134 | except Exception as e:
135 | raise e
136 | # return experiment
137 |
138 | setup = setup.map(func)
139 |
140 | # new_setup.save()
141 |
142 |
143 | if __name__ == "__main__":
144 |
145 | import fire
146 |
147 | fire.Fire(main)
148 |
--------------------------------------------------------------------------------
/launch_experiment.py:
--------------------------------------------------------------------------------
1 | from qalign.utils.experiment import create_experiment
2 |
3 |
4 | def main(
5 | variant="quest-rlhf",
6 | beta: float = 1.0,
7 | steps: int = 4096,
8 | temperature: float = 1.0,
9 | n: int = 64,
10 | model_path: str = "meta-llama/Llama-3.1-8B-Instruct",
11 | reward_model_path: str = "/gscratch/ark/graf/quest-rlhf/qflow/rm/artifacts/llama3/8b8b/mathcotfix/full/reward", # "/gscratch/ark/graf/LLaMA-Factory/saves/llama3/8b/full/reward/",
12 | reward_model_batch_size: int = 4,
13 | save_path: str = "remote-outputs-llama/",
14 | gpu_memory_utilization: float = 0.8,
15 | max_new_tokens: int = 800,
16 | max_prompt_length: int = 1200,
17 | dataset_path="HuggingFaceH4/MATH-500", # "apple/GSM-Symbolic-p1",
18 | batch_size=64,
19 | stop_tokens=[],
20 | prompt_template="{prompt}",
21 | num_chains: int = 1,
22 | use_few_shot: bool = False,
23 | format: str = "chat",
24 | split: str = "test",
25 | remote: bool = True,
26 | reward_type="value",
27 | ):
28 |
29 | additional_meta = {
30 | "remote": remote,
31 | }
32 |
33 | if variant == "quest-rlhf":
34 |
35 | # Create experiment with additional meta data
36 | additional_meta = {
37 | "beta": beta,
38 | "reward_model_path": reward_model_path,
39 | "reward_type": reward_type,
40 | # "i": start_index,
41 | "num_chains": num_chains,
42 | **additional_meta,
43 | }
44 |
45 | experiment = create_experiment(
46 | save_path=save_path,
47 | variant=variant,
48 | model_path=model_path,
49 | dataset_path=dataset_path,
50 | n=n,
51 | temperature=temperature,
52 | steps=steps,
53 | max_new_tokens=max_new_tokens,
54 | max_prompt_length=max_prompt_length,
55 | stop_tokens=stop_tokens,
56 | prompt_template=prompt_template,
57 | additional_meta=additional_meta,
58 | batch_size=batch_size,
59 | use_few_shot=use_few_shot,
60 | format=format,
61 | split=split,
62 | )
63 |
64 |
65 | if __name__ == "__main__":
66 | import fire
67 |
68 | fire.Fire(main)
69 |
--------------------------------------------------------------------------------
/pred.py:
--------------------------------------------------------------------------------
1 | from typing import *
2 | from multiprocessing import Pool
3 |
4 |
5 | ## expkit
6 | from expkit.setup import ExpSetup
7 | from expkit.storage import DiskStorage
8 |
9 | ## qalign
10 | from qalign.utils.pred import *
11 | from qalign.utils.eval import *
12 |
13 |
14 | def main(
15 | base_dir="remote-outputs/", # "llama3.2-outputs/", # "tqa-outputs/",
16 | strategy="voting",
17 | key="crm:allenai-Llama-3", # "vh:-gscratch-ark-graf-quest-rlhf-qflow-rm-artifacts-tulu-8b8b-gsm8k-full-reward-", # "vh:-gscratch-ark-graf-quest-rlhf-qflow-rm-artifacts-tulu-8b8b-math-full-reward-", # "vh:-gscratch-ark-graf-quest-rlhf-qflow-rm-artifacts-tulu-8b8b-gsm8k-full-reward-", # "crm:allenai-Llama-3", # "vh:-gscratch-ark-graf-quest-rlhf-qflow-rm-artifacts-llama3-8b8b-mathcot-full-reward", # "vh:-gscratch-ark-graf-quest-rlhf-qflow-rm-artifacts-llama3-8b8b-math-full-reward-", # "vh:-gscratch-ark-graf-quest-rlhf-qflow-rm-artifacts-llama3-8b8b-math-full-reward-", # "vh:-gscratch-ark-graf-LLaMA-Factory-saves-llama3-8b8b-full-reward-", # "vh:-gscratch-ark-graf-LLaMA-Factory-saves-llama3-3b-full-reward-",
18 | query_args={
19 | # "steps": 4096,
20 | # "split": "test",
21 | # "split": "validation",
22 | "temperature": 1.0,
23 | # "beta": 0.5,
24 | # "model_path": "meta-llama/Llama-3.1-8B-Instruct",
25 | "dataset": "HuggingFaceH4/MATH-500", # "HuggingFaceH4/MATH-500", # "openai/gsm8k", # HuggingFaceH4/MATH-500
26 | # "dataset": "lighteval/MATH",
27 | # "model_path": "allenai/Llama-3.1-Tulu-3-8B-SFT",
28 | # "reward_model_path": "/gscratch/ark/graf/LLaMA-Factory/saves/llama3/8b1b/full/reward/",
29 | # "reward_model_path": "allenai/Llama-3.1-Tulu-3-8B-RM",
30 | # "variant": "ancestral",
31 | # "n": 128,
32 | },
33 | beta: float = 1.0,
34 | c: float = 1.0,
35 | extract="lastmath",
36 | trials=512,
37 | r=32,
38 | gaps=64,
39 | exp_rate=False,
40 | n=None,
41 | ):
42 |
43 | setup = (
44 | ExpSetup(storage=DiskStorage(base_dir=base_dir, mode="rw"))
45 | .query(query_args)
46 | .filter(lambda x: x.has_data())
47 | )
48 |
49 | print(setup)
50 | # setup.experiments[0].evals()
51 |
52 | print(
53 | len(setup.experiments),
54 | )
55 | if len(setup.experiments) == 0:
56 | raise FileNotFoundError("The configuration has no data!")
57 |
58 | # strategy, reward_model_path = strategy.split("-")
59 |
60 | with Pool() as p:
61 | pick = get_strategy(
62 | strategy,
63 | key=key,
64 | p=p,
65 | beta=beta,
66 | c=c,
67 | extract=extract,
68 | trials=trials,
69 | r=r,
70 | gaps=gaps,
71 | exp_rate=exp_rate,
72 | n=n,
73 | ) # msft , nvdia
74 |
75 | # setup = setup.filter(lambda x: not x.has_eval(pick.eval_name))
76 |
77 | def func(experiment):
78 |
79 | try:
80 | # if len(experiment.instances()) == experiment.meta["n"]:
81 |
82 | # print(len(experiment.instances()))
83 | return (
84 | pick(experiment)
85 | # if not experiment.has_eval(pick.eval_name)
86 | # else experiment
87 | )
88 | # else:
89 | # return experiment
90 |
91 | except FileNotFoundError:
92 | print(experiment.name)
93 | return experiment
94 |
95 | except Exception as e:
96 | print(experiment.name)
97 | print(e)
98 | raise e
99 | return experiment
100 |
101 | setup = setup.map(func)
102 |
103 | # new_setup.save()
104 |
105 |
106 | if __name__ == "__main__":
107 |
108 | import fire
109 |
110 | fire.Fire(main)
111 |
--------------------------------------------------------------------------------
/qalign/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/goncalorafaria/qalign/39b1df52c913bda4f94e656018f68915fd06299b/qalign/__init__.py
--------------------------------------------------------------------------------
/qalign/rm/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/goncalorafaria/qalign/39b1df52c913bda4f94e656018f68915fd06299b/qalign/rm/__init__.py
--------------------------------------------------------------------------------
/qalign/rm/create_dataset.py:
--------------------------------------------------------------------------------
1 | from expkit import (
2 | ExpSetup,
3 | )
4 | import logging
5 | from expkit.storage import CachedRO, ZipStorage, DiskStorage
6 | import os
7 | import shutil
8 | import numpy as np
9 | import re
10 | import json
11 |
12 | from transformers import AutoTokenizer
13 |
14 | from qalign.utils.data import get_data_iterable, general_process_data
15 | from qalign.utils.math import get_last_number
16 | from datasets import load_dataset
17 |
18 | from qalign.utils.data import FlexiblePromptTemplate
19 |
20 |
21 | def sample_pairs(positive, negatives, n=1):
22 | nmin = min(len(positive), len(negatives), n)
23 |
24 | if nmin > 0:
25 | np.random.shuffle(positive)
26 | np.random.shuffle(negatives)
27 |
28 | if len(positive) < n:
29 | positive = np.random.choice(
30 | positive,
31 | size=n,
32 | replace=True,
33 | )
34 | if len(negatives) < n:
35 | negatives = np.random.choice(
36 | negatives,
37 | size=n,
38 | replace=True,
39 | )
40 |
41 | return [
42 | {
43 | "chosen": c,
44 | "rejected": r,
45 | }
46 | for c, r in zip(
47 | positive[:n],
48 | negatives[:n],
49 | )
50 | ]
51 | else:
52 | return []
53 |
54 |
55 | def fake_synth_data(i, po):
56 |
57 | number = get_last_number(i["input"]["answer"])
58 |
59 | alt_results = [
60 | np.round(
61 | np.random.rand() * float(number),
62 | 1,
63 | )
64 | for _ in range(len(po))
65 | ]
66 |
67 | if "." not in number:
68 | alt_results = np.ceil(alt_results).astype(int)
69 |
70 | no = [
71 | answer.replace(
72 | number,
73 | str(pseudo_result),
74 | )
75 | for answer, pseudo_result in zip(po, alt_results)
76 | ]
77 |
78 | return no
79 |
80 |
81 | def linearize(setup, oracle_key="lastnumber"):
82 | dataset = []
83 | evals = []
84 |
85 | if len(setup) > 1:
86 |
87 | setup = setup.unique("i").sort("i")
88 |
89 | for e in setup:
90 | try:
91 | print("Processing", e.meta.get("i", 0))
92 | print("elements:", len(e.instances()))
93 | print("evals:", len(e.get_eval(oracle_key)))
94 |
95 | dataset.extend(e.instances())
96 | evals.extend(e.get_eval(oracle_key))
97 |
98 |         except Exception as err:
99 |             logging.error("An error occurred during linearization: %s", err)
100 | pass
101 |
102 | return {"data": dataset, "evals": evals}
103 |
104 |
105 | def remove_tags(text, tags):
106 | ptext = re.sub("|".join(map(re.escape, tags)), "", text)
107 |
108 | return ptext
109 |
110 |
111 | def create_positive_and_negative_pairs(dataset_stack, stop_tokens):
112 | dataset = []
113 | answers = []
114 |
115 | for i, gt in zip(dataset_stack["data"], dataset_stack["evals"]):
116 |
117 | values = np.array(gt["scores"])
118 | pind = np.where(values == 1)[0] # extract index of correct
119 | nind = np.where(values == 0)[0] # extract index of incorrect
120 |
121 | po = [
122 | remove_tags(i["outputs"][j]["text"], stop_tokens) for j in pind
123 | ] # get text without special tokens
124 | no = [
125 | remove_tags(i["outputs"][j]["text"], stop_tokens) for j in nind
126 | ] # get text without special tokens
127 |
128 | all_neg, all_po = (
129 | len(po) == 0,
130 | len(no) == 0,
131 | )
132 |
133 | if all_po: ## if all positive
134 | no = fake_synth_data(i, po) # create fake wrongs .. hopefully this is rare.
135 |
136 | dataset.append((po + [i["input"]["answer"]], no))
137 |
138 | return dataset
139 |
140 |
141 | def sample_pairs_dataset_chat(
142 | split_dataset,
143 | epochs=1,
144 | dataset_path="openai/gsm8k",
145 | model_path="allenai/OLMo-7B-0724-Instruct-hf",
146 | prompt_template="{prompt}",
147 | ):
148 | dataset = []
149 |
150 | for (po, no), i in zip(
151 | split_dataset,
152 | get_data_iterable(
153 | model_path=model_path,
154 | dataset_path=dataset_path,
155 | split="train",
156 | prompt_template=prompt_template,
157 | ),
158 | ):
159 |
160 | chat_template = i["chat_template_prompt"]
161 |
162 | po, no = list(set(po)), list(set(no))
163 | for pair in sample_pairs(
164 | po,
165 | no,
166 | n=epochs,
167 | ):
168 |
169 | dataset.append(
170 | {
171 | k: chat_template
172 | + [
173 | {
174 | "role": "assistant",
175 | "content": v,
176 | }
177 | ]
178 | for k, v in pair.items()
179 | }
180 | )
181 |
182 | return dataset
183 |
184 |
185 | def sample_pairs_dataset_prompt(
186 | split_dataset,
187 | epochs=1,
188 | dataset_path="openai/gsm8k",
189 | model_path="allenai/OLMo-7B-0724-Instruct-hf",
190 | use_few_shot=False,
191 | ):
192 | dataset = []
193 |
194 | for (po, no), i in zip(
195 | split_dataset,
196 | get_data_iterable(
197 | model_path=model_path,
198 | dataset_path=dataset_path,
199 | split="train",
200 | format="prompt",
201 | use_few_shot=use_few_shot,
202 | ),
203 | ):
204 |
205 | chat_template = [
206 | {
207 | "role": "user",
208 | "content": i["prompt"],
209 | }
210 | ]
211 |
212 | po, no = list(set(po)), list(set(no))
213 | for pair in sample_pairs(
214 | po,
215 | no,
216 | n=epochs,
217 | ):
218 |
219 | dataset.append(
220 | {
221 | k: chat_template
222 | + [
223 | {
224 | "role": "assistant",
225 | "content": v,
226 | }
227 | ]
228 | for k, v in pair.items()
229 | }
230 | )
231 |
232 | return dataset
233 |
234 |
235 | def create_rm_data_test_pairs(
236 | model_path,
237 | dataset_path="openai/gsm8k",
238 | ):  # note: this fake synthetic data is not representative
239 | dataset = []
240 | # templates = []
241 |
242 | for i in get_data_iterable(model_path=model_path, dataset_path=dataset_path):
243 |
244 | chat_template = i["chat_template_prompt"]
245 |
246 | # templates.append(chat_template)
247 | number = get_last_number(i["answer"])
248 |
249 | alt_number = np.round(
250 | np.random.rand() * float(number),
251 | 1,
252 | )
253 |
254 | if "." not in number:
255 | alt_number = np.ceil(alt_number).astype(int)
256 |
257 | wrong_answer = i["answer"].replace(
258 | number,
259 | str(alt_number),
260 | )
261 |
262 | answer = i["answer"]
263 |
264 | dataset.append(
265 | {
266 | "chosen": chat_template
267 | + [
268 | {
269 | "role": "assistant",
270 | "content": answer,
271 | }
272 | ],
273 | "rejected": chat_template
274 | + [
275 | {
276 | "role": "assistant",
277 | "content": wrong_answer,
278 | }
279 | ],
280 | }
281 | )
282 |
283 | return dataset
284 |
285 |
286 | def convert_llamafactory_format(x):
287 |
288 | return {
289 | "instruction": x["chosen"][0]["content"],
290 | "chosen": x["chosen"][-1]["content"],
291 | "rejected": x["rejected"][-1]["content"],
292 | }
293 |
294 |
295 | # /gscratch/ark/graf/rmlearn/data
296 |
297 |
298 | def get_sharded_setup(
299 | storage,
300 | model_path,
301 | temperature=1.0,
302 | steps=64,
303 | dataset="openai/gsm8k",
304 | oracle_key="lastnumber",
305 | prompt_template="{prompt}",
306 | ):
307 | setup = (
308 | ExpSetup(storage).query(
309 | {
310 | "steps": steps,
311 | "temperature": temperature,
312 | # "prompt_template": "Solve the following grade school math problem step-by-step: {prompt}",
313 | "model_path": model_path,
314 | "split": "train",
315 | "dataset": dataset,
316 | # "prompt_template": prompt_template,
317 | # "n": n,
318 | }
319 | )
320 | # .filter(lambda x: "i" in x.meta)
321 | # .filter(lambda x: x.islocked())
322 | )
323 |
324 | if len(setup) == 0:
325 | raise ValueError(
326 | f"No data found for model {model_path} and dataset {dataset} with steps {steps} and temperature {temperature}"
327 | )
328 |
329 | setup = setup.filter(lambda x: x.has_eval(oracle_key))
330 | setup = setup.filter(lambda x: x.get("n") == len(x.instances()))
331 |
332 | if len(setup) == 0:
333 | raise ValueError(f"No eval data with key:{oracle_key} found.")
334 |
335 | setup = setup.sort("n", reverse=False)
336 |
337 | dataset_stack = linearize(setup, oracle_key=oracle_key)
338 |
339 | return dataset_stack
340 |
341 |
342 | def file_name(dataset, model_path, steps=128, epochs=1, format="raw", split="train"):
343 | d = dataset.split("/")[-1].lower()
344 | m = model_path.split("/")[-1].lower()
345 |
346 | if split == "test":
347 | return f"{format}_{d}_{m}_{split}.json"
348 | else:
349 | return f"{format}_{d}_{m}_{steps}_{epochs}_{split}.json"
350 |
351 |
352 | def generate_rm_dataset(
353 | model_path="meta-llama/Llama-3.1-8B-Instruct",
354 | base_dir="outputs/llama3gsm8ktrain/",
355 | dataset="openai/gsm8k",
356 | save_path="/gscratch/ark/graf/quest-rlhf/qflow/rm/data/",
357 | epochs=[1, 2, 4, 8],
358 | steps=128,
359 | temperature=1.0,
360 | stop_tokens=None,
361 | oracle_key="lastnumber",
362 | prompt_template="{prompt}",
363 | prefix="",
364 | ):
365 |
366 | tokenizer = AutoTokenizer.from_pretrained(model_path)
367 |
368 | if stop_tokens is None:
369 | stop_tokens = [tokenizer.eos_token]
370 | else:
371 | stop_tokens = stop_tokens + [tokenizer.eos_token]
372 |
373 | try:
374 | storage = ZipStorage(
375 | base_dir=base_dir,
376 | mode="r",
377 | )
378 | except ValueError as e:
379 | storage = DiskStorage(
380 | base_dir=base_dir,
381 | mode="r",
382 | )
383 |
384 | storage = CachedRO(storage)
385 |
386 | dataset_stack = get_sharded_setup(
387 | storage,
388 | model_path=model_path,
389 | temperature=temperature,
390 | steps=steps,
391 | dataset=dataset,
392 | oracle_key=oracle_key,
393 | prompt_template=prompt_template,
394 | )
395 |
396 | all_pairs = create_positive_and_negative_pairs(
397 | dataset_stack,
398 | stop_tokens=stop_tokens,
399 | )
400 |
401 | for ep in epochs:
402 | sampled_pairs = sample_pairs_dataset_chat(
403 | all_pairs,
404 | epochs=ep,
405 | model_path=model_path,
406 | dataset_path=dataset,
407 | prompt_template=prompt_template,
408 | # use_few_shot=True,
409 | )
410 |
411 | sampled_pairs_factory = [convert_llamafactory_format(x) for x in sampled_pairs]
412 |
413 | with open(
414 | save_path
415 | + prefix
416 | + file_name(
417 | dataset,
418 | model_path,
419 | steps,
420 | ep,
421 | format="raw",
422 | ),
423 | "w",
424 | ) as fn:
425 | json.dump(sampled_pairs, fn)
426 |
427 | with open(
428 | save_path
429 | + prefix
430 | + file_name(
431 | dataset,
432 | model_path,
433 | steps,
434 | ep,
435 | format="llama-factory",
436 | ),
437 | "w",
438 | ) as f:
439 | json.dump(sampled_pairs_factory, f)
440 |
441 | print(f"Saved {prefix+file_name(dataset, model_path, steps, ep, format='raw')}")
442 | print(
443 | f"Saved {prefix+file_name(dataset, model_path, steps, ep, format='llama-factory')}"
444 | )
445 |
446 |
447 | def generate_rm_test_dataset(
448 | model_path="meta-llama/Llama-3.2-3B-Instruct",
449 | base_dir="outputs/llama3gsm8ktrain/",
450 | dataset="openai/gsm8k",
451 | save_path="/gscratch/ark/graf/quest-rlhf/qflow/rm/data/",
452 | ):
453 |
454 | test_dataset = create_rm_data_test_pairs(
455 | model_path=model_path, dataset_path=dataset
456 | )
457 |
458 | with open(
459 | save_path + file_name(dataset, model_path, format="raw", split="test"),
460 | "w",
461 | ) as fn:
462 | json.dump(test_dataset, fn)
463 |
464 | test_data = [convert_llamafactory_format(x) for x in test_dataset]
465 |
466 | with open(
467 | save_path
468 | + file_name(dataset, model_path, format="llama-factory", split="test"),
469 | "w",
470 | ) as fn:
471 | json.dump(test_data, fn)
472 |
473 |
474 | def main(
475 | model_path="meta-llama/Llama-3.1-8B-Instruct", # "allenai/Llama-3.1-Tulu-3-8B-SFT", # "allenai/Llama-3.1-Tulu-3-8B-SFT", # "meta-llama/Llama-3.1-8B-Instruct",
476 | dataset="lighteval/MATH", # "lighteval/MATH", # "openai/gsm8k",
477 | save_path="/gscratch/ark/graf/quest-rlhf/qflow/rm/data/",
478 | base_dir="remote-outputs-llama/", # "outputs/llama3gsm8ktrain/",
479 | steps=128,
480 | prompt_template="Solve the following math problem step-by-step: {prompt}\n\nPresent the answer in LaTex format: \\boxed{Your answer}", # "Solve the following math problem step-by-step: {prompt}",
481 | prefix="cotnewextractmath14325",
482 | ): # 90d44980-8ad9-48f2-8c34-99747991f571
483 |
484 | dataset_oracles = {
485 | "openai/gsm8k": "lastnumber",
486 | "lighteval/MATH": "lastmath",
487 | }
488 |
489 | generate_rm_dataset(
490 | model_path=model_path,
491 | base_dir=base_dir,
492 | dataset=dataset,
493 | save_path=save_path,
494 | epochs=[1, 2, 4, 8],
495 | steps=steps,
496 | temperature=1.0,
497 | oracle_key=dataset_oracles[dataset],
498 | prompt_template=FlexiblePromptTemplate(prompt_template),
499 | prefix=prefix,
500 | )
501 |
502 | """generate_rm_test_dataset(
503 | model_path=model_path,
504 | base_dir=base_dir,
505 | dataset=dataset,
506 | save_path=save_path,
507 | )"""  # This has to be improved: we need to generate from the base model for the test set.
508 |
509 | info = {}
510 | for fn in os.listdir(save_path):
511 |
512 | if "llama-factory" not in fn:
513 | continue
514 |
515 | key = fn.replace(".json", "")
516 | info[key] = {
517 | "file_name": fn,
518 | "ranking": True,
519 | "columns": {
520 | "prompt": "instruction",
521 | "chosen": "chosen",
522 | "rejected": "rejected",
523 | },
524 | }
525 |
526 | with open(save_path + "dataset_info.json", "w") as f:
527 | json.dump(info, f)
528 |
529 |
530 | if __name__ == "__main__":
531 | main()
532 |
--------------------------------------------------------------------------------
/qalign/rm/data/dataset_info.json:
--------------------------------------------------------------------------------
1 | {"llama-factory_gsm8k_llama-3.1-8b-instruct_128_4_train": {"file_name": "llama-factory_gsm8k_llama-3.1-8b-instruct_128_4_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_gsm8k_llama-3.2-3b-instruct_128_2_train": {"file_name": "llama-factory_gsm8k_llama-3.2-3b-instruct_128_2_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "cotnewextractmath14325llama-factory_math_llama-3.1-8b-instruct_128_1_train": {"file_name": "cotnewextractmath14325llama-factory_math_llama-3.1-8b-instruct_128_1_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_gsm8k_llama-3.1-tulu-3-8b-sft_64_2_train": {"file_name": "llama-factory_gsm8k_llama-3.1-tulu-3-8b-sft_64_2_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "cotllama-factory_math_llama-3.1-tulu-3-8b-sft_128_1_train": {"file_name": "cotllama-factory_math_llama-3.1-tulu-3-8b-sft_128_1_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "cotllama-factory_gsm8k_llama-3.1-tulu-3-8b-sft_64_4_train": {"file_name": "cotllama-factory_gsm8k_llama-3.1-tulu-3-8b-sft_64_4_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_gsm8k_llama-3.1-tulu-3-8b-sft_64_1_train": {"file_name": "llama-factory_gsm8k_llama-3.1-tulu-3-8b-sft_64_1_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_math_llama-3.1-tulu-3-8b-sft_128_1_train": {"file_name": "llama-factory_math_llama-3.1-tulu-3-8b-sft_128_1_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_gsm8k_llama-3.1-8b-instruct_128_1_train": {"file_name": "llama-factory_gsm8k_llama-3.1-8b-instruct_128_1_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "cotnewextractmath14325llama-factory_math_llama-3.1-8b-instruct_128_8_train": {"file_name": "cotnewextractmath14325llama-factory_math_llama-3.1-8b-instruct_128_8_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_gsm8k_llama-3.2-1b-instruct_128_1_train": {"file_name": "llama-factory_gsm8k_llama-3.2-1b-instruct_128_1_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_math_llama-3.1-8b_128_4_train": {"file_name": "llama-factory_math_llama-3.1-8b_128_4_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_math_llama-3.1-8b_128_8_train": {"file_name": "llama-factory_math_llama-3.1-8b_128_8_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_math_llama-3.1-tulu-3-8b-sft_128_4_train": {"file_name": "llama-factory_math_llama-3.1-tulu-3-8b-sft_128_4_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_math_llama-3.1-8b-instruct_128_4_train": {"file_name": "llama-factory_math_llama-3.1-8b-instruct_128_4_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": 
"rejected"}}, "llama-factory_math_llama-3.2-1b-instruct_128_1_train": {"file_name": "llama-factory_math_llama-3.2-1b-instruct_128_1_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_gsm8k_test": {"file_name": "llama-factory_gsm8k_test.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_math_llama-3.2-1b-instruct_128_8_train": {"file_name": "llama-factory_math_llama-3.2-1b-instruct_128_8_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_math_llama-3.1-8b_128_1_train": {"file_name": "llama-factory_math_llama-3.1-8b_128_1_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_math_llama-3.1-8b-instruct_128_1_train": {"file_name": "llama-factory_math_llama-3.1-8b-instruct_128_1_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_gsm8k_llama-3.2-3b-instruct_128_1_train": {"file_name": "llama-factory_gsm8k_llama-3.2-3b-instruct_128_1_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_gsm8k_llama-3.2-1b-instruct_128_2_train": {"file_name": "llama-factory_gsm8k_llama-3.2-1b-instruct_128_2_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_gsm8k_llama-3.2-1b-instruct_128_8_train": {"file_name": "llama-factory_gsm8k_llama-3.2-1b-instruct_128_8_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_math_llama-3.2-1b-instruct_128_4_train": {"file_name": "llama-factory_math_llama-3.2-1b-instruct_128_4_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_gsm8k_llama-3.2-3b-instruct_128_8_train": {"file_name": "llama-factory_gsm8k_llama-3.2-3b-instruct_128_8_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_gsm8k_llama-3.2-1b-instruct_128_4_train": {"file_name": "llama-factory_gsm8k_llama-3.2-1b-instruct_128_4_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "cotnewextractmath14325llama-factory_math_llama-3.1-8b-instruct_128_2_train": {"file_name": "cotnewextractmath14325llama-factory_math_llama-3.1-8b-instruct_128_2_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_math_llama-3.1-8b-instruct_128_2_train": {"file_name": "llama-factory_math_llama-3.1-8b-instruct_128_2_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_gsm8k_llama-3.1-tulu-3-8b-sft_64_4_train": {"file_name": "llama-factory_gsm8k_llama-3.1-tulu-3-8b-sft_64_4_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "cotllama-factory_math_llama-3.1-8b-instruct_128_8_train": {"file_name": "cotllama-factory_math_llama-3.1-8b-instruct_128_8_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_math_llama-3.2-1b-instruct_128_2_train": {"file_name": 
"llama-factory_math_llama-3.2-1b-instruct_128_2_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "cotllama-factory_math_llama-3.1-tulu-3-8b-sft_128_8_train": {"file_name": "cotllama-factory_math_llama-3.1-tulu-3-8b-sft_128_8_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "cotllama-factory_gsm8k_llama-3.1-tulu-3-8b-sft_64_2_train": {"file_name": "cotllama-factory_gsm8k_llama-3.1-tulu-3-8b-sft_64_2_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "cotllama-factory_math_llama-3.1-8b-instruct_128_4_train": {"file_name": "cotllama-factory_math_llama-3.1-8b-instruct_128_4_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_gsm8k_llama-3.1-8b-instruct_test": {"file_name": "llama-factory_gsm8k_llama-3.1-8b-instruct_test.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_math_llama-3.1-8b_128_2_train": {"file_name": "llama-factory_math_llama-3.1-8b_128_2_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "cotllama-factory_gsm8k_llama-3.1-tulu-3-8b-sft_64_1_train": {"file_name": "cotllama-factory_gsm8k_llama-3.1-tulu-3-8b-sft_64_1_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "cotllama-factory_math_llama-3.1-8b-instruct_128_2_train": {"file_name": "cotllama-factory_math_llama-3.1-8b-instruct_128_2_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_math_llama-3.1-tulu-3-8b-sft_128_2_train": {"file_name": "llama-factory_math_llama-3.1-tulu-3-8b-sft_128_2_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "cotllama-factory_math_llama-3.1-8b-instruct_128_1_train": {"file_name": "cotllama-factory_math_llama-3.1-8b-instruct_128_1_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_gsm8k_llama-3.1-8b-instruct_128_8_train": {"file_name": "llama-factory_gsm8k_llama-3.1-8b-instruct_128_8_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_gsm8k_llama-3.2-1b-instruct_test": {"file_name": "llama-factory_gsm8k_llama-3.2-1b-instruct_test.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "cotllama-factory_math_llama-3.1-tulu-3-8b-sft_128_4_train": {"file_name": "cotllama-factory_math_llama-3.1-tulu-3-8b-sft_128_4_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "cotllama-factory_math_llama-3.1-tulu-3-8b-sft_128_2_train": {"file_name": "cotllama-factory_math_llama-3.1-tulu-3-8b-sft_128_2_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_gsm8k_llama-3.2-3b-instruct_test": {"file_name": "llama-factory_gsm8k_llama-3.2-3b-instruct_test.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "cotllama-factory_gsm8k_llama-3.1-tulu-3-8b-sft_64_8_train": {"file_name": 
"cotllama-factory_gsm8k_llama-3.1-tulu-3-8b-sft_64_8_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_gsm8k_llama-3.1-tulu-3-8b-sft_64_8_train": {"file_name": "llama-factory_gsm8k_llama-3.1-tulu-3-8b-sft_64_8_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_gsm8k_llama-3.2-3b-instruct_128_4_train": {"file_name": "llama-factory_gsm8k_llama-3.2-3b-instruct_128_4_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "cotnewextractmath14325llama-factory_math_llama-3.1-8b-instruct_128_4_train": {"file_name": "cotnewextractmath14325llama-factory_math_llama-3.1-8b-instruct_128_4_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_math_llama-3.1-8b-instruct_128_8_train": {"file_name": "llama-factory_math_llama-3.1-8b-instruct_128_8_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_math_llama-3.1-tulu-3-8b-sft_128_8_train": {"file_name": "llama-factory_math_llama-3.1-tulu-3-8b-sft_128_8_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}, "llama-factory_gsm8k_llama-3.1-8b-instruct_128_2_train": {"file_name": "llama-factory_gsm8k_llama-3.1-8b-instruct_128_2_train.json", "ranking": true, "columns": {"prompt": "instruction", "chosen": "chosen", "rejected": "rejected"}}}
--------------------------------------------------------------------------------
/qalign/rm/train_configs/deepspeed/ds_z0_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_batch_size": "auto",
3 | "train_micro_batch_size_per_gpu": "auto",
4 | "gradient_accumulation_steps": "auto",
5 | "gradient_clipping": "auto",
6 | "zero_allow_untested_optimizer": true,
7 | "fp16": {
8 | "enabled": "auto",
9 | "loss_scale": 0,
10 | "loss_scale_window": 1000,
11 | "initial_scale_power": 16,
12 | "hysteresis": 2,
13 | "min_loss_scale": 1
14 | },
15 | "bf16": {
16 | "enabled": "auto"
17 | },
18 | "zero_optimization": {
19 | "stage": 0,
20 | "allgather_partitions": true,
21 | "allgather_bucket_size": 5e8,
22 | "overlap_comm": true,
23 | "reduce_scatter": true,
24 | "reduce_bucket_size": 5e8,
25 | "contiguous_gradients": true,
26 | "round_robin_gradients": true
27 | }
28 | }
29 |
--------------------------------------------------------------------------------
/qalign/rm/train_configs/deepspeed/ds_z2_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_batch_size": "auto",
3 | "train_micro_batch_size_per_gpu": "auto",
4 | "gradient_accumulation_steps": "auto",
5 | "gradient_clipping": "auto",
6 | "zero_allow_untested_optimizer": true,
7 | "fp16": {
8 | "enabled": "auto",
9 | "loss_scale": 0,
10 | "loss_scale_window": 1000,
11 | "initial_scale_power": 16,
12 | "hysteresis": 2,
13 | "min_loss_scale": 1
14 | },
15 | "bf16": {
16 | "enabled": "auto"
17 | },
18 | "zero_optimization": {
19 | "stage": 2,
20 | "allgather_partitions": true,
21 | "allgather_bucket_size": 5e8,
22 | "overlap_comm": true,
23 | "reduce_scatter": true,
24 | "reduce_bucket_size": 5e8,
25 | "contiguous_gradients": true,
26 | "round_robin_gradients": true
27 | }
28 | }
29 |
--------------------------------------------------------------------------------
/qalign/rm/train_configs/deepspeed/ds_z2_offload_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_batch_size": "auto",
3 | "train_micro_batch_size_per_gpu": "auto",
4 | "gradient_accumulation_steps": "auto",
5 | "gradient_clipping": "auto",
6 | "zero_allow_untested_optimizer": true,
7 | "fp16": {
8 | "enabled": "auto",
9 | "loss_scale": 0,
10 | "loss_scale_window": 1000,
11 | "initial_scale_power": 16,
12 | "hysteresis": 2,
13 | "min_loss_scale": 1
14 | },
15 | "bf16": {
16 | "enabled": "auto"
17 | },
18 | "zero_optimization": {
19 | "stage": 2,
20 | "offload_optimizer": {
21 | "device": "cpu",
22 | "pin_memory": true
23 | },
24 | "allgather_partitions": true,
25 | "allgather_bucket_size": 5e8,
26 | "overlap_comm": true,
27 | "reduce_scatter": true,
28 | "reduce_bucket_size": 5e8,
29 | "contiguous_gradients": true,
30 | "round_robin_gradients": true
31 | }
32 | }
33 |
--------------------------------------------------------------------------------
/qalign/rm/train_configs/deepspeed/ds_z3_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_batch_size": "auto",
3 | "train_micro_batch_size_per_gpu": "auto",
4 | "gradient_accumulation_steps": "auto",
5 | "gradient_clipping": "auto",
6 | "zero_allow_untested_optimizer": true,
7 | "fp16": {
8 | "enabled": false
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "zero_optimization": {
14 | "stage": 3,
15 | "overlap_comm": true,
16 | "contiguous_gradients": true,
17 | "sub_group_size": 1e9,
18 | "reduce_bucket_size": "auto",
19 | "stage3_prefetch_bucket_size": "auto",
20 | "stage3_param_persistence_threshold": "auto",
21 | "stage3_max_live_parameters": 1e9,
22 | "stage3_max_reuse_distance": 1e9,
23 | "stage3_gather_16bit_weights_on_model_save": true
24 | }
25 | }
26 |
--------------------------------------------------------------------------------
/qalign/rm/train_configs/deepspeed/ds_z3_offload_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_batch_size": "auto",
3 | "train_micro_batch_size_per_gpu": "auto",
4 | "gradient_accumulation_steps": "auto",
5 | "gradient_clipping": "auto",
6 | "zero_allow_untested_optimizer": true,
7 | "fp16": {
8 | "enabled": "auto",
9 | "loss_scale": 0,
10 | "loss_scale_window": 1000,
11 | "initial_scale_power": 16,
12 | "hysteresis": 2,
13 | "min_loss_scale": 1
14 | },
15 | "bf16": {
16 | "enabled": "auto"
17 | },
18 | "zero_optimization": {
19 | "stage": 3,
20 | "offload_optimizer": {
21 | "device": "cpu",
22 | "pin_memory": true
23 | },
24 | "offload_param": {
25 | "device": "cpu",
26 | "pin_memory": true
27 | },
28 | "overlap_comm": true,
29 | "contiguous_gradients": true,
30 | "sub_group_size": 1e9,
31 | "reduce_bucket_size": "auto",
32 | "stage3_prefetch_bucket_size": "auto",
33 | "stage3_param_persistence_threshold": "auto",
34 | "stage3_max_live_parameters": 1e9,
35 | "stage3_max_reuse_distance": 1e9,
36 | "stage3_gather_16bit_weights_on_model_save": true
37 | }
38 | }
39 |
--------------------------------------------------------------------------------
/qalign/rm/train_configs/explain.txt:
--------------------------------------------------------------------------------
1 |
2 |
3 | llamafactory-cli train qalign/rm/train_configs/gsm8k/full/llama3_8b1b_full_reward.yaml
4 |
5 |
6 | llamafactory-cli train qalign/rm/train_configs/math/full/llama3_8b8b_full_reward_math_cot.yaml
--------------------------------------------------------------------------------
/qalign/rm/train_configs/gsm8k/full/gemma2_full_dpo.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path: google/gemma-2-2b-it
3 |
4 | ### method
5 | stage: dpo
6 | do_train: true
7 | finetuning_type: full
8 | pref_beta: 5.0
9 | #pref_ftx: 1.0
10 | flash_attn: "disabled"
11 | pref_loss: sigmoid # choices: [sigmoid (dpo), orpo, simpo]
12 | deepspeed: examples/deepspeed/ds_z3_config.json
13 |
14 | ## dataset
15 | dataset: gsm8k_gemma2_264_4
16 | template: gemma
17 | cutoff_len: 1024
18 | max_samples: 1000000
19 | overwrite_cache: true
20 | preprocessing_num_workers: 16
21 |
22 | ### output
23 | output_dir: saves/gemma-2-2b-it/full/dpo
24 | logging_steps: 1
25 | save_steps: 500
26 | plot_loss: true
27 | overwrite_output_dir: true
28 | report_to: wandb
29 |
30 | ### train
31 | per_device_train_batch_size: 1
32 | gradient_accumulation_steps: 16
33 | learning_rate: 5.0e-7
34 | num_train_epochs: 1.0
35 | lr_scheduler_type: linear
36 | warmup_ratio: 0.03
37 | bf16: true
38 | ddp_timeout: 180000000
39 |
40 | ### eval
41 | val_size: 0.1
42 | per_device_eval_batch_size: 4
43 | eval_strategy: steps
44 | eval_steps: 5
45 |
--------------------------------------------------------------------------------
/qalign/rm/train_configs/gsm8k/full/gemma2_full_dpo05.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path: google/gemma-2-2b-it
3 |
4 | ### method
5 | stage: dpo
6 | do_train: true
7 | finetuning_type: full
8 | pref_beta: 0.05
9 | pref_ftx: 0.0
10 | pref_loss: sigmoid # choices: [sigmoid (dpo), orpo, simpo]
11 | deepspeed: examples/deepspeed/ds_z3_config.json
12 |
13 | ## dataset
14 | dataset: gsm8k_gemma2_264_4
15 | template: gemma
16 | cutoff_len: 1024
17 | max_samples: 1000000
18 | overwrite_cache: true
19 | preprocessing_num_workers: 16
20 |
21 | ### output
22 | output_dir: saves/gemma-2-2b-it/full/dpo05
23 | logging_steps: 1
24 | save_steps: 500
25 | plot_loss: true
26 | overwrite_output_dir: true
27 | report_to: wandb
28 |
29 | ### train
30 | per_device_train_batch_size: 2
31 | gradient_accumulation_steps: 32
32 | learning_rate: 5.0e-6
33 | num_train_epochs: 1.0
34 | lr_scheduler_type: linear
35 | warmup_ratio: 0.03
36 | bf16: true
37 | ddp_timeout: 180000000
38 |
39 | ### eval
40 | val_size: 0.1
41 | per_device_eval_batch_size: 4
42 | eval_strategy: steps
43 | eval_steps: 5
44 |
--------------------------------------------------------------------------------
/qalign/rm/train_configs/gsm8k/full/gemma2_full_dpo10.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path: google/gemma-2-2b-it
3 |
4 | ### method
5 | stage: dpo
6 | do_train: true
7 | finetuning_type: full
8 | pref_beta: 0.10
9 | pref_ftx: 0.0
10 | pref_loss: sigmoid # choices: [sigmoid (dpo), orpo, simpo]
11 | deepspeed: examples/deepspeed/ds_z3_config.json
12 |
13 | ## dataset
14 | dataset: gsm8k_gemma2_264_4
15 | template: gemma
16 | cutoff_len: 1024
17 | max_samples: 1000000
18 | overwrite_cache: true
19 | preprocessing_num_workers: 16
20 |
21 | ### output
22 | output_dir: saves/gemma-2-2b-it/full/dpo10
23 | logging_steps: 1
24 | save_steps: 500
25 | plot_loss: true
26 | overwrite_output_dir: true
27 | report_to: wandb
28 |
29 | ### train
30 | per_device_train_batch_size: 2
31 | gradient_accumulation_steps: 32
32 | learning_rate: 5.0e-6
33 | num_train_epochs: 1.0
34 | lr_scheduler_type: linear
35 | warmup_ratio: 0.03
36 | bf16: true
37 | ddp_timeout: 180000000
38 |
39 | ### eval
40 | val_size: 0.1
41 | per_device_eval_batch_size: 4
42 | eval_strategy: steps
43 | eval_steps: 5
44 |
--------------------------------------------------------------------------------
/qalign/rm/train_configs/gsm8k/full/gemma2_full_dpo_sft05.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path: google/gemma-2-2b-it
3 |
4 | ### method
5 | stage: dpo
6 | do_train: true
7 | finetuning_type: full
8 | pref_beta: 0.05
9 | pref_ftx: 1.0
10 | pref_loss: sigmoid # choices: [sigmoid (dpo), orpo, simpo]
11 | deepspeed: examples/deepspeed/ds_z3_config.json
12 |
13 | ## dataset
14 | dataset: gsm8k_gemma2_264_4
15 | template: gemma
16 | cutoff_len: 1024
17 | max_samples: 1000000
18 | overwrite_cache: true
19 | preprocessing_num_workers: 16
20 |
21 | ### output
22 | output_dir: saves/gemma-2-2b-it/full/dposft05
23 | logging_steps: 1
24 | save_steps: 500
25 | plot_loss: true
26 | overwrite_output_dir: true
27 | report_to: wandb
28 |
29 | ### train
30 | per_device_train_batch_size: 2
31 | gradient_accumulation_steps: 32
32 | learning_rate: 5.0e-6
33 | num_train_epochs: 1.0
34 | lr_scheduler_type: linear
35 | warmup_ratio: 0.03
36 | bf16: true
37 | ddp_timeout: 180000000
38 |
39 | ### eval
40 | val_size: 0.1
41 | per_device_eval_batch_size: 4
42 | eval_strategy: steps
43 | eval_steps: 5
44 |
--------------------------------------------------------------------------------
/qalign/rm/train_configs/gsm8k/full/gemma2_full_dpo_sft10.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path: google/gemma-2-2b-it
3 |
4 | ### method
5 | stage: dpo
6 | do_train: true
7 | finetuning_type: full
8 | pref_beta: 1.0
9 | pref_ftx: 1.0
10 | pref_loss: sigmoid # choices: [sigmoid (dpo), orpo, simpo]
11 | deepspeed: examples/deepspeed/ds_z3_config.json
12 | flash_attn: "disabled"
13 |
14 | ## dataset
15 | dataset: gsm8k_gemma2_264_4
16 | template: gemma
17 | cutoff_len: 1024
18 | max_samples: 1000000
19 | overwrite_cache: true
20 | preprocessing_num_workers: 16
21 |
22 | ### output
23 | output_dir: saves/gemma-2-2b-it/full/dposft10
24 | logging_steps: 1
25 | save_steps: 500
26 | plot_loss: true
27 | overwrite_output_dir: true
28 | report_to: wandb
29 |
30 | ### train
31 | per_device_train_batch_size: 2
32 | gradient_accumulation_steps: 32
33 | learning_rate: 5.0e-6
34 | num_train_epochs: 1.0
35 | lr_scheduler_type: linear
36 | warmup_ratio: 0.03
37 | bf16: true
38 | ddp_timeout: 180000000
39 |
40 | ### eval
41 | val_size: 0.1
42 | per_device_eval_batch_size: 4
43 | eval_strategy: steps
44 | eval_steps: 5
45 |
46 |
--------------------------------------------------------------------------------
/qalign/rm/train_configs/gsm8k/full/gemma2_full_reward.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path: google/gemma-2-2b-it
3 | flash_attn: fa2
4 | #attn_implementation: eager
5 | ### method
6 | stage: rm
7 | do_train: true
8 | finetuning_type: full
9 | deepspeed: examples/deepspeed/ds_z3_config.json
10 |
11 |
12 | ### dataset
13 | dataset: gsm8k_gemma2_264
14 | template: gemma
15 | cutoff_len: 1024
16 | max_samples: 1000
17 | overwrite_cache: true
18 | preprocessing_num_workers: 16
19 |
20 | ### output
21 | output_dir: saves/gemma-2-2b-it/full/reward
22 | logging_steps: 10
23 | save_steps: 500
24 | plot_loss: true
25 | overwrite_output_dir: true
26 | report_to: wandb
27 |
28 | ### train
29 | per_device_train_batch_size: 1
30 | gradient_accumulation_steps: 32
31 | learning_rate: 1.0e-5
32 | num_train_epochs: 1.0
33 | lr_scheduler_type: linear
34 | weight_decay: 0.0
35 | warmup_ratio: 0.03
36 | bf16: true
37 | ddp_timeout: 180000000
38 |
39 | ### eval
40 | val_size: 0.1
41 | per_device_eval_batch_size: 4
42 | eval_strategy: steps
43 | eval_steps: 2
44 |
--------------------------------------------------------------------------------
/qalign/rm/train_configs/gsm8k/full/llama3_1b1b_full_reward.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path: meta-llama/Llama-3.2-1B-Instruct
3 | flash_attn: fa2
4 | #attn_implementation: eager
5 | ### method
6 | stage: rm
7 | do_train: true
8 | finetuning_type: full
9 | deepspeed: examples/deepspeed/ds_z3_config.json
10 | save_safetensors: False
11 |
12 |
13 | ### dataset
14 | dataset: gsm8k_llama3.2-1B_128_1ep
15 | eval_dataset: gsm8k_llama3.2-1B_test
16 | template: llama3
17 | cutoff_len: 1024
18 | max_samples: 300000
19 | overwrite_cache: true
20 | preprocessing_num_workers: 16
21 |
22 | ### output
23 | output_dir: saves/llama3/1b1b/full/reward
24 | logging_steps: 1
25 | save_steps: 500
26 | plot_loss: true
27 | overwrite_output_dir: true
28 | report_to: wandb
29 |
30 | ### train
31 | per_device_train_batch_size: 2
32 | gradient_accumulation_steps: 8
33 | learning_rate: 1.0e-5
34 | num_train_epochs: 1.0
35 | lr_scheduler_type: linear
36 | weight_decay: 0.0
37 | warmup_ratio: 0.03
38 | bf16: true
39 | ddp_timeout: 180000000
40 |
41 | ### eval
42 | #val_size: 0.05
43 | per_device_eval_batch_size: 4
44 | eval_strategy: steps
45 | eval_steps: 5
46 |
--------------------------------------------------------------------------------
/qalign/rm/train_configs/gsm8k/full/llama3_1b8b_full_reward.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path: meta-llama/Llama-3.1-8B-Instruct
3 | flash_attn: fa2
4 | #attn_implementation: eager
5 | ### method
6 | stage: rm
7 | do_train: true
8 | finetuning_type: full
9 | deepspeed: examples/deepspeed/ds_z3_config.json
10 | save_safetensors: False
11 |
12 | ### dataset
13 | dataset: gsm8k_llama3.2-1B_128_1ep
14 | eval_dataset: gsm8k_llama3.2-1B_test
15 | template: llama3
16 | cutoff_len: 1024
17 | max_samples: 300000
18 | overwrite_cache: true
19 | preprocessing_num_workers: 16
20 |
21 | ### output
22 | output_dir: saves/llama3/1b8b/full/reward
23 | logging_steps: 1
24 | save_steps: 500
25 | plot_loss: true
26 | overwrite_output_dir: true
27 | report_to: wandb
28 |
29 | ### train
30 | per_device_train_batch_size: 1
31 | gradient_accumulation_steps: 16
32 | learning_rate: 1.0e-5
33 | num_train_epochs: 1.0
34 | lr_scheduler_type: linear
35 | weight_decay: 0.0
36 | warmup_ratio: 0.03
37 | bf16: true
38 | ddp_timeout: 180000000
39 |
40 | ### eval
41 | #val_size: 0.05
42 | per_device_eval_batch_size: 4
43 | eval_strategy: steps
44 | eval_steps: 5
45 |
--------------------------------------------------------------------------------
/qalign/rm/train_configs/gsm8k/full/llama3_3b1b_full_reward.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path: meta-llama/Llama-3.2-1B-Instruct
3 | flash_attn: fa2
4 | #attn_implementation: eager
5 | ### method
6 | stage: rm
7 | do_train: true
8 | finetuning_type: full
9 | deepspeed: examples/deepspeed/ds_z3_config.json
10 | save_safetensors: False
11 |
12 |
13 | ### dataset
14 | dataset: gsm8k_llama3.2-3B_128_1ep
15 | eval_dataset: gsm8k_llama3.2-3B_test
16 | template: llama3
17 | cutoff_len: 1024
18 | max_samples: 300000
19 | overwrite_cache: true
20 | preprocessing_num_workers: 16
21 |
22 | ### output
23 | output_dir: saves/llama3/3b1b/full/reward
24 | logging_steps: 1
25 | save_steps: 500
26 | plot_loss: true
27 | overwrite_output_dir: true
28 | report_to: wandb
29 |
30 | ### train
31 | per_device_train_batch_size: 2
32 | gradient_accumulation_steps: 8
33 | learning_rate: 1.0e-5
34 | num_train_epochs: 1.0
35 | lr_scheduler_type: linear
36 | weight_decay: 0.0
37 | warmup_ratio: 0.03
38 | bf16: true
39 | ddp_timeout: 180000000
40 |
41 | ### eval
42 | #val_size: 0.05
43 | per_device_eval_batch_size: 4
44 | eval_strategy: steps
45 | eval_steps: 5
46 |
--------------------------------------------------------------------------------
/qalign/rm/train_configs/gsm8k/full/llama3_3b8b_full_reward.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path: meta-llama/Llama-3.1-8B-Instruct
3 | flash_attn: fa2
4 | #attn_implementation: eager
5 | ### method
6 | stage: rm
7 | do_train: true
8 | finetuning_type: full
9 | deepspeed: examples/deepspeed/ds_z3_config.json
10 | save_safetensors: False
11 |
12 | ### dataset
13 | dataset: gsm8k_llama3.2-3B_128_1ep
14 | eval_dataset: gsm8k_llama3.2-3B_test
15 | template: llama3
16 | cutoff_len: 1024
17 | max_samples: 300000
18 | overwrite_cache: true
19 | preprocessing_num_workers: 16
20 |
21 | ### output
22 | output_dir: saves/llama3/3b8b/full/reward
23 | logging_steps: 1
24 | save_steps: 500
25 | plot_loss: true
26 | overwrite_output_dir: true
27 | report_to: wandb
28 |
29 | ### train
30 | per_device_train_batch_size: 1
31 | gradient_accumulation_steps: 32
32 | learning_rate: 1.0e-5
33 | num_train_epochs: 1.0
34 | lr_scheduler_type: linear
35 | weight_decay: 0.0
36 | warmup_ratio: 0.03
37 | bf16: true
38 | ddp_timeout: 180000000
39 |
40 | ### eval
41 | #val_size: 0.05
42 | per_device_eval_batch_size: 4
43 | eval_strategy: steps
44 | eval_steps: 5
45 |
--------------------------------------------------------------------------------
/qalign/rm/train_configs/gsm8k/full/llama3_3b_full_reward.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path: meta-llama/Llama-3.2-3B-Instruct
3 | flash_attn: fa2
4 | #attn_implementation: eager
5 | ### method
6 | stage: rm
7 | do_train: true
8 | finetuning_type: full
9 | deepspeed: examples/deepspeed/ds_z3_config.json
10 | save_safetensors: False
11 |
12 | ### dataset
13 | dataset: gsm8k_llama3.2-3B_128_1ep
14 | eval_dataset: gsm8k_llama3.2-3B_test
15 | template: llama3
16 | cutoff_len: 1024
17 | max_samples: 300000
18 | overwrite_cache: true
19 | preprocessing_num_workers: 16
20 |
21 | ### output
22 | output_dir: saves/llama3/3b/full/reward
23 | logging_steps: 1
24 | save_steps: 500
25 | plot_loss: true
26 | overwrite_output_dir: true
27 | report_to: wandb
28 |
29 | ### train
30 | per_device_train_batch_size: 2
31 | gradient_accumulation_steps: 8
32 | learning_rate: 1.0e-5
33 | num_train_epochs: 1.0
34 | lr_scheduler_type: linear
35 | weight_decay: 0.0
36 | warmup_ratio: 0.03
37 | bf16: true
38 | ddp_timeout: 180000000
39 |
40 | ### eval
41 | #val_size: 0.05
42 | per_device_eval_batch_size: 4
43 | eval_strategy: steps
44 | eval_steps: 5
45 |
--------------------------------------------------------------------------------
/qalign/rm/train_configs/gsm8k/full/llama3_8b1b_full_reward.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path: meta-llama/Llama-3.2-1B-Instruct
3 | flash_attn: fa2
4 | #attn_implementation: eager
5 | ### method
6 | stage: rm
7 | do_train: true
8 | finetuning_type: full
9 | deepspeed: examples/deepspeed/ds_z3_config.json
10 | save_safetensors: False
11 |
12 |
13 | ### dataset
14 | dataset: gsm8k_llama3.1-8B_128_1ep
15 | eval_dataset: gsm8k_llama3.1-8B_test
16 | template: llama3
17 | cutoff_len: 1024
18 | max_samples: 300000
19 | overwrite_cache: true
20 | preprocessing_num_workers: 16
21 |
22 | ### output
23 | output_dir: saves/llama3/8b1b/full/reward
24 | logging_steps: 1
25 | save_steps: 500
26 | plot_loss: true
27 | overwrite_output_dir: true
28 | report_to: wandb
29 |
30 | ### train
31 | per_device_train_batch_size: 2
32 | gradient_accumulation_steps: 8
33 | learning_rate: 1.0e-5
34 | num_train_epochs: 1.0
35 | lr_scheduler_type: linear
36 | weight_decay: 0.0
37 | warmup_ratio: 0.03
38 | bf16: true
39 | ddp_timeout: 180000000
40 |
41 | ### eval
42 | #val_size: 0.05
43 | per_device_eval_batch_size: 4
44 | eval_strategy: steps
45 | eval_steps: 5
46 |
--------------------------------------------------------------------------------
/qalign/rm/train_configs/gsm8k/full/llama3_8b8b_full_dpo.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path: meta-llama/Llama-3.1-8B-Instruct
3 | #flash_attn: fa2  # duplicate key; flash_attn is set to "disabled" below
4 | #attn_implementation: eager
5 |
6 | ### method
7 | stage: dpo
8 | do_train: true
9 | finetuning_type: full
10 | pref_beta: 5.0
11 | #pref_ftx: 1.0
12 | flash_attn: "disabled"
13 | pref_loss: sigmoid # choices: [sigmoid (dpo), orpo, simpo]
14 | deepspeed: examples/deepspeed/ds_z3_config.json
15 |
16 | ### dataset
17 | dataset: gsm8k_llama3.1-8B_128_1ep
18 | eval_dataset: gsm8k_llama3.1-8B_test
19 | template: llama3
20 | cutoff_len: 1024
21 | max_samples: 300000
22 | overwrite_cache: true
23 | preprocessing_num_workers: 16
24 |
25 | ### output
26 | output_dir: saves/llama3/8b8b/full/dpoe6
27 | logging_steps: 1
28 | save_steps: 500
29 | plot_loss: true
30 | overwrite_output_dir: true
31 | report_to: wandb
32 |
33 | ### train
34 | per_device_train_batch_size: 1
35 | gradient_accumulation_steps: 16
36 | learning_rate: 1.0e-6
37 | num_train_epochs: 1.0
38 | lr_scheduler_type: linear
39 | weight_decay: 0.0
40 | warmup_ratio: 0.03
41 | bf16: true
42 | ddp_timeout: 180000000
43 |
44 | ### eval
45 | #val_size: 0.05
46 | per_device_eval_batch_size: 4
47 | eval_strategy: steps
48 | eval_steps: 5
49 |
--------------------------------------------------------------------------------
/qalign/rm/train_configs/gsm8k/full/llama3_8b8b_full_reward.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path: meta-llama/Llama-3.1-8B-Instruct
3 | flash_attn: fa2
4 | #attn_implementation: eager
5 | ### method
6 | stage: rm
7 | do_train: true
8 | finetuning_type: full
9 | deepspeed: examples/deepspeed/ds_z3_config.json
10 | save_safetensors: False
11 |
12 | ### dataset
13 | dataset: gsm8k_llama3.1-8B_128_1ep
14 | eval_dataset: gsm8k_llama3.1-8B_test
15 | template: llama3
16 | cutoff_len: 1024
17 | max_samples: 300000
18 | overwrite_cache: true
19 | preprocessing_num_workers: 16
20 |
21 | ### output
22 | output_dir: saves/llama3/8b8b/full/reward
23 | logging_steps: 1
24 | save_steps: 500
25 | plot_loss: true
26 | overwrite_output_dir: true
27 | report_to: wandb
28 |
29 | ### train
30 | per_device_train_batch_size: 1
31 | gradient_accumulation_steps: 16
32 | learning_rate: 1.0e-5
33 | num_train_epochs: 1.0
34 | lr_scheduler_type: linear
35 | weight_decay: 0.0
36 | warmup_ratio: 0.03
37 | bf16: true
38 | ddp_timeout: 180000000
39 |
40 | ### eval
41 | #val_size: 0.05
42 | per_device_eval_batch_size: 4
43 | eval_strategy: steps
44 | eval_steps: 5
45 |
--------------------------------------------------------------------------------
/qalign/rm/train_configs/gsm8k/full/olmo_full_reward.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path: allenai/OLMo-7B-0724-Instruct-hf
3 | flash_attn: fa2
4 | #attn_implementation: eager
5 | ### method
6 | stage: rm
7 | do_train: true
8 | finetuning_type: full
9 | deepspeed: examples/deepspeed/ds_z3_config.json
10 |
11 |
12 | ### dataset
13 | dataset: gsm8k_olmo_128_4ep
14 | template: olmo
15 | cutoff_len: 1024
16 | max_samples: 300000
17 | overwrite_cache: true
18 | preprocessing_num_workers: 16
19 |
20 | ### output
21 | output_dir: saves/olmo/full/reward4e-5
22 | logging_steps: 1
23 | save_steps: 500
24 | plot_loss: true
25 | overwrite_output_dir: true
26 | report_to: wandb
27 |
28 | ### train
29 | per_device_train_batch_size: 2
30 | gradient_accumulation_steps: 32
31 | learning_rate: 1.0e-5
32 | num_train_epochs: 1.0
33 | lr_scheduler_type: linear
34 | weight_decay: 0.0
35 | warmup_ratio: 0.03
36 | bf16: true
37 | ddp_timeout: 180000000
38 |
39 | ### eval
40 | val_size: 0.05
41 | per_device_eval_batch_size: 4
42 | eval_strategy: steps
43 | eval_steps: 5
44 |
--------------------------------------------------------------------------------
/qalign/rm/train_configs/gsm8k/full/train.sh:
--------------------------------------------------------------------------------
1 | # llamafactory-cli train quest/full/llama3_1b_full_reward.yaml
2 | # llamafactory-cli train quest/full/llama3_3b_full_reward.yaml
3 |
4 | llamafactory-cli train quest/full/llama3_8b1b_full_reward.yaml && llamafactory-cli train quest/full/llama3_3b1b_full_reward.yaml
5 |
--------------------------------------------------------------------------------
/qalign/rm/train_configs/gsm8k/full/tulu_8b8b_full_reward.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path: allenai/Llama-3.1-Tulu-3-8B-SFT
3 | flash_attn: fa2
4 | #attn_implementation: eager
5 | ### method
6 | stage: rm
7 | do_train: true
8 | finetuning_type: full
9 | deepspeed: qalign/rm/train_configs/deepspeed/ds_z3_config.json
10 | save_safetensors: False
11 |
12 | ### dataset
13 | dataset_dir: qalign/rm/data/
14 | dataset: llama-factory_gsm8k_llama-3.1-tulu-3-8b-sft_64_1_train
15 | #eval_dataset: gsm8k_llama3.1-8B_test
16 | template: llama3
17 | cutoff_len: 1024
18 | max_samples: 300000
19 | overwrite_cache: true
20 | preprocessing_num_workers: 16
21 |
22 | ### output
23 | output_dir: qalign/rm/artifacts/tulu/8b8b/gsm8k/full/reward
24 | logging_steps: 1
25 | save_steps: 500
26 | plot_loss: true
27 | overwrite_output_dir: true
28 | report_to: wandb
29 |
30 | ### train
31 | per_device_train_batch_size: 1
32 | gradient_accumulation_steps: 32
33 | learning_rate: 1.0e-5
34 | num_train_epochs: 1.0
35 | lr_scheduler_type: linear
36 | weight_decay: 0.0
37 | warmup_ratio: 0.03
38 | bf16: true
39 | ddp_timeout: 180000000
40 |
41 | ### eval
42 | val_size: 0.05
43 | per_device_eval_batch_size: 2
44 | eval_strategy: steps
45 | eval_steps: 5
46 |
--------------------------------------------------------------------------------
/qalign/rm/train_configs/gsm8k/lora/gemma2_lora_dpo.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path: google/gemma-2-2b-it
3 |
4 | ### method
5 | stage: dpo
6 | do_train: true
7 | finetuning_type: lora
8 | pref_beta: 0.1
9 | #pref_ftx: 0.5
10 | pref_loss: sigmoid # choices: [sigmoid (dpo), orpo, simpo]
11 | lora_rank: 8
12 | deepspeed: examples/deepspeed/ds_z3_config.json
13 |
14 | # finetuning_type: full
15 |
16 |
17 | ### dataset
18 | dataset: gsm8k_gemma2_264_4
19 | template: gemma
20 | cutoff_len: 1024
21 | max_samples: 1000000
22 | overwrite_cache: true
23 | preprocessing_num_workers: 16
24 |
25 | ### output
26 | output_dir: saves/gemma-2-2b-it/lora/dpo
27 | logging_steps: 1
28 | save_steps: 500
29 | plot_loss: true
30 | overwrite_output_dir: true
31 | report_to: wandb
32 |
33 | ### train
34 | per_device_train_batch_size: 2
35 | gradient_accumulation_steps: 32
36 | learning_rate: 5.0e-6
37 | num_train_epochs: 1.0
38 | lr_scheduler_type: linear
39 | warmup_ratio: 0.03
40 | bf16: true
41 | ddp_timeout: 180000000
42 |
43 | ### eval
44 | val_size: 0.1
45 | per_device_eval_batch_size: 4
46 | eval_strategy: steps
47 | eval_steps: 5
48 |
--------------------------------------------------------------------------------
/qalign/rm/train_configs/gsm8k/lora/gemma2_lora_reward.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path: google/gemma-2-2b-it
3 |
4 | ### method
5 | stage: rm
6 | do_train: true
7 | finetuning_type: lora
8 | lora_target: all
9 | deepspeed: examples/deepspeed/ds_z3_config.json
10 | lora_rank: 8
11 |
12 |
13 | ### dataset
14 | dataset: gsm8k_gemma2_264
15 | template: gemma
16 | cutoff_len: 1024
17 | max_samples: 1000
18 | overwrite_cache: true
19 | preprocessing_num_workers: 16
20 |
21 | ### output
22 | output_dir: saves/gemma-2-2b-it/lora/reward
23 | logging_steps: 10
24 | save_steps: 500
25 | plot_loss: true
26 | overwrite_output_dir: true
27 | report_to: wandb
28 |
29 | ### train
30 | per_device_train_batch_size: 2
31 | gradient_accumulation_steps: 8
32 | learning_rate: 1.0e-4
33 | num_train_epochs: 3.0
34 | lr_scheduler_type: cosine
35 | warmup_ratio: 0.1
36 | bf16: true
37 | ddp_timeout: 180000000
38 |
39 | ### eval
40 | val_size: 0.1
41 | per_device_eval_batch_size: 4
42 | eval_strategy: steps
43 | eval_steps: 10
44 |
--------------------------------------------------------------------------------
/qalign/rm/train_configs/gsm8k/lora/llama2_lora_reward.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path: meta-llama/Llama-2-7b-hf
3 |
4 | ### method
5 | stage: rm
6 | do_train: true
7 | finetuning_type: lora
8 | lora_target: all
9 | deepspeed: examples/deepspeed/ds_z3_config.json
10 | lora_rank: 16
11 |
12 |
13 | ### dataset
14 | dataset: gsm8k_olmo_128_4ep
15 | template: llama2
16 | cutoff_len: 1024
17 | max_samples: 1000000
18 | overwrite_cache: true
19 | preprocessing_num_workers: 16
20 |
21 | ### output
22 | output_dir: saves/llama2/lora/reward
23 | logging_steps: 10
24 | save_steps: 500
25 | plot_loss: true
26 | overwrite_output_dir: true
27 | report_to: wandb
28 |
29 | ### train
30 | per_device_train_batch_size: 2
31 | gradient_accumulation_steps: 32
32 | learning_rate: 1.0e-4
33 | num_train_epochs: 1.0
34 | lr_scheduler_type: linear
35 | warmup_ratio: 0.03
36 | bf16: true
37 | ddp_timeout: 180000000
38 |
39 | ### eval
40 | val_size: 0.05
41 | per_device_eval_batch_size: 4
42 | eval_strategy: steps
43 | eval_steps: 10
44 |
--------------------------------------------------------------------------------
/qalign/rm/train_configs/gsm8k/lora/llama2_lora_sft.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path: meta-llama/Llama-2-7b-hf
3 |
4 | ### method
5 | stage: sft
6 | do_train: true
7 | finetuning_type: lora
8 | lora_target: all
9 | deepspeed: examples/deepspeed/ds_z3_config.json
10 | lora_rank: 16
11 |
12 |
13 | ### dataset
14 | dataset: gsm8k_gemma2_264
15 | template: llama2
16 | cutoff_len: 1024
17 | max_samples: 1000
18 | overwrite_cache: true
19 | preprocessing_num_workers: 16
20 |
21 | ### output
22 | output_dir: saves/llama2/lora/reward
23 | logging_steps: 10
24 | save_steps: 500
25 | plot_loss: true
26 | overwrite_output_dir: true
27 | report_to: wandb
28 |
29 | ### train
30 | per_device_train_batch_size: 2
31 | gradient_accumulation_steps: 8
32 | learning_rate: 1.0e-4
33 | num_train_epochs: 3.0
34 | lr_scheduler_type: cosine
35 | warmup_ratio: 0.1
36 | bf16: true
37 | ddp_timeout: 180000000
38 |
39 | ### eval
40 | val_size: 0.1
41 | per_device_eval_batch_size: 4
42 | eval_strategy: steps
43 | eval_steps: 10
44 |
--------------------------------------------------------------------------------
/qalign/rm/train_configs/gsm8k/lora/olmo_lora_reward.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path: allenai/OLMo-7B-0724-Instruct-hf
3 | ### method
4 | stage: rm
5 | do_train: true
6 | finetuning_type: lora
7 | lora_target: all
8 | deepspeed: examples/deepspeed/ds_z3_config.json
9 | lora_rank: 32
10 | ### dataset
11 | dataset: gsm8k_olmo_128_4ep
12 | template: olmo
13 | cutoff_len: 1024
14 | max_samples: 300000
15 | overwrite_cache: true
16 | preprocessing_num_workers: 16
17 |
18 | ### output
19 | output_dir: saves/olmo/lora/reward4e-5
20 | logging_steps: 1
21 | save_steps: 500
22 | plot_loss: true
23 | overwrite_output_dir: true
24 | report_to: wandb
25 |
26 | ### train
27 | per_device_train_batch_size: 2
28 | gradient_accumulation_steps: 32
29 | learning_rate: 1.0e-5
30 | num_train_epochs: 1.0
31 | lr_scheduler_type: linear
32 | weight_decay: 0.0
33 | warmup_ratio: 0.03
34 | bf16: true
35 | ddp_timeout: 180000000
36 |
37 | ### eval
38 | val_size: 0.05
39 | per_device_eval_batch_size: 4
40 | eval_strategy: steps
41 | eval_steps: 5
42 |
--------------------------------------------------------------------------------
/qalign/rm/train_configs/math/full/llama3_1b8b_full_reward_math.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path: meta-llama/Llama-3.1-8B-Instruct
3 | flash_attn: fa2
4 | #attn_implementation: eager
5 | ### method
6 | stage: rm
7 | do_train: true
8 | finetuning_type: full
9 | deepspeed: qalign/rm/train_configs/deepspeed/ds_z3_config.json
10 | save_safetensors: False
11 |
12 | ### dataset
13 | dataset_dir: qalign/rm/data/
14 | dataset: llama-factory_math_llama-3.2-1b-instruct_128_1_train
15 | #eval_dataset: gsm8k_llama3.1-8B_test
16 | template: llama3
17 | cutoff_len: 1024
18 | max_samples: 300000
19 | overwrite_cache: true
20 | preprocessing_num_workers: 16
21 |
22 | ### output
23 | output_dir: qalign/rm/artifacts/llama3/1b8b/math/full/reward
24 | logging_steps: 1
25 | save_steps: 500
26 | plot_loss: true
27 | overwrite_output_dir: true
28 | report_to: wandb
29 |
30 | ### train
31 | per_device_train_batch_size: 1
32 | gradient_accumulation_steps: 32
33 | learning_rate: 1.0e-5
34 | num_train_epochs: 1.0
35 | lr_scheduler_type: linear
36 | weight_decay: 0.0
37 | warmup_ratio: 0.03
38 | bf16: true
39 | ddp_timeout: 180000000
40 |
41 | ### eval
42 | val_size: 0.05
43 | per_device_eval_batch_size: 2
44 | eval_strategy: steps
45 | eval_steps: 5
46 |
--------------------------------------------------------------------------------
/qalign/rm/train_configs/math/full/llama3_8b8b_full_reward_math.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path: meta-llama/Llama-3.1-8B-Instruct
3 | flash_attn: fa2
4 | #attn_implementation: eager
5 | ### method
6 | stage: rm
7 | do_train: true
8 | finetuning_type: full
9 | deepspeed: qalign/rm/train_configs/deepspeed/ds_z3_config.json
10 | save_safetensors: False
11 |
12 | ### dataset
13 | dataset_dir: qalign/rm/data/
14 | dataset: llama-factory_math_llama-3.1-8b-instruct_128_1_train
15 | #eval_dataset: gsm8k_llama3.1-8B_test
16 | template: llama3
17 | cutoff_len: 1024
18 | max_samples: 300000
19 | overwrite_cache: true
20 | preprocessing_num_workers: 16
21 |
22 | ### output
23 | output_dir: qalign/rm/artifacts/llama3/8b8b/math/full/reward
24 | logging_steps: 1
25 | save_steps: 500
26 | plot_loss: true
27 | overwrite_output_dir: true
28 | report_to: wandb
29 |
30 | ### train
31 | per_device_train_batch_size: 1
32 | gradient_accumulation_steps: 16
33 | learning_rate: 1.0e-5
34 | num_train_epochs: 1.0
35 | lr_scheduler_type: linear
36 | weight_decay: 0.0
37 | warmup_ratio: 0.03
38 | bf16: true
39 | ddp_timeout: 180000000
40 |
41 | ### eval
42 | val_size: 0.05
43 | per_device_eval_batch_size: 2
44 | eval_strategy: steps
45 | eval_steps: 5
46 |
--------------------------------------------------------------------------------
/qalign/rm/train_configs/math/full/llama3_8b8b_full_reward_math_cot.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path: meta-llama/Llama-3.1-8B-Instruct
3 | flash_attn: fa2
4 | #attn_implementation: eager
5 | ### method
6 | stage: rm
7 | do_train: true
8 | finetuning_type: full
9 | deepspeed: qalign/rm/train_configs/deepspeed/ds_z3_config.json
10 | save_safetensors: False
11 |
12 | ### dataset
13 | dataset_dir: qalign/rm/data/
14 | dataset: cotnewextractmath14325llama-factory_math_llama-3.1-8b-instruct_128_1_train
15 |
16 | #eval_dataset: gsm8k_llama3.1-8B_test
17 | template: llama3
18 | cutoff_len: 1024
19 | max_samples: 300000
20 | overwrite_cache: true
21 | preprocessing_num_workers: 16
22 |
23 | ### output
24 | output_dir: qalign/rm/artifacts/llama3/8b8b/mathcotnewextractmath14325/full/reward
25 | logging_steps: 1
26 | save_steps: 500
27 | plot_loss: true
28 | overwrite_output_dir: true
29 | report_to: wandb
30 |
31 | ### train
32 | per_device_train_batch_size: 1
33 | gradient_accumulation_steps: 32
34 | learning_rate: 1.0e-5
35 | num_train_epochs: 1.0
36 | lr_scheduler_type: linear
37 | weight_decay: 0.0
38 | warmup_ratio: 0.03
39 | bf16: true
40 | ddp_timeout: 180000000
41 |
42 | ### eval
43 | val_size: 0.05
44 | per_device_eval_batch_size: 2
45 | eval_strategy: steps
46 | eval_steps: 5
47 |
--------------------------------------------------------------------------------
/qalign/rm/train_configs/math/full/tulu_8b8b_full_reward_math.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path: allenai/Llama-3.1-Tulu-3-8B-SFT
3 | flash_attn: fa2
4 | #attn_implementation: eager
5 | ### method
6 | stage: rm
7 | do_train: true
8 | finetuning_type: full
9 | deepspeed: qalign/rm/train_configs/deepspeed/ds_z3_config.json
10 | save_safetensors: False
11 |
12 | ### dataset
13 | dataset_dir: qalign/rm/data/
14 | dataset: llama-factory_math_llama-3.1-tulu-3-8b-sft_128_1_train
15 | #eval_dataset: gsm8k_llama3.1-8B_test
16 | template: llama3
17 | cutoff_len: 1024
18 | max_samples: 300000
19 | overwrite_cache: true
20 | preprocessing_num_workers: 16
21 |
22 | ### output
23 | output_dir: qalign/rm/artifacts/tulu/8b8b/math/full/reward
24 | logging_steps: 1
25 | save_steps: 500
26 | plot_loss: true
27 | overwrite_output_dir: true
28 | report_to: wandb
29 |
30 | ### train
31 | per_device_train_batch_size: 1
32 | gradient_accumulation_steps: 32
33 | learning_rate: 1.0e-5
34 | num_train_epochs: 1.0
35 | lr_scheduler_type: linear
36 | weight_decay: 0.0
37 | warmup_ratio: 0.03
38 | bf16: true
39 | ddp_timeout: 180000000
40 |
41 | ### eval
42 | val_size: 0.05
43 | per_device_eval_batch_size: 2
44 | eval_strategy: steps
45 | eval_steps: 5
46 |
--------------------------------------------------------------------------------
/qalign/rm/valid_check.py:
--------------------------------------------------------------------------------
1 | import json
2 | from qalign.utils.math import get_last_math
3 |
4 | file_path = "qalign/rm/data/cotllama-factory_math_llama-3.1-8b-instruct_128_1_train.json"
5 |
6 | with open(file_path, "r") as f:
7 | data = json.load(f)
8 |
9 |
10 | count = 0
11 | for entry in data:
12 | right = entry["chosen"]
13 | wrong = entry["rejected"]
14 |
15 | right_answer = get_last_math(right)
16 | wrong_last_line = wrong.split("\n")[-1]
17 |
18 | count += int(right_answer in wrong_last_line)
19 |
20 | if right_answer in wrong_last_line:
21 | print(f"right: [{right_answer}], wrong: >>>\n{'*'*20}\n {wrong_last_line}")
22 | # import pdb
23 |
24 | # pdb.set_trace()
25 |
26 | print(f"{count} / {len(data)} rejected completions contain the chosen final answer")
27 | import pdb
28 |
29 | pdb.set_trace()
30 |
--------------------------------------------------------------------------------
/qalign/utils/data.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset, Value
2 |
3 | from langchain.prompts import PromptTemplate
4 |
5 | from functools import partial
6 |
7 | from transformers import AutoTokenizer
8 |
9 | import numpy as np
10 | from datasets import Dataset
11 |
12 | DEFAULT_USER_PROMPT = PromptTemplate.from_template("Question: {content}\n")
13 | DEFAULT_AI_PROMPT = PromptTemplate.from_template("Answer: {content}\n")
14 |
15 | ## qalign
16 | from qalign.utils.examples import MATH_EXAMPLARS, GSM8K_EXEMPLARS
17 |
18 | ## quest
19 | from quest.model.base import (
20 | DEFAULT_TEMPLATE,
21 | )
22 |
23 |
24 | class FlexiblePromptTemplate:
25 | def __init__(self, template):
26 | self.template = template
27 |
28 | def format(self, **kwargs):
29 | result = self.template
30 | # Only replace variables that exist in the template
31 | for key, value in kwargs.items():
32 | placeholder = "{" + key + "}"
33 | if placeholder in result:
34 | result = result.replace(placeholder, str(value))
35 | return result
36 |
37 |
38 | def parsehh_to_template(
39 | text: str,
40 | br="\n\n",
41 | available_roles={"Human", "Assistant"},
42 | transfer_roles={
43 | "Human": "user",
44 | "Assistant": "assistant",
45 | },
46 | ):
47 |
48 | breaks = text.split(br)
49 |
50 | chlist = []
51 | for ch in breaks:
52 | try:
53 |             subbreaks = ch.split(":", 1)
54 | role = subbreaks[0]
55 | if role in available_roles:
56 |
57 | chlist.append(
58 | {
59 | "role": transfer_roles[role], # .lower(),
60 | "content": subbreaks[1].strip(),
61 | }
62 | )
63 | else:
64 |
65 | if len(chlist) > 0:
66 | chlist[-1]["content"] += br + ch
67 |
68 | except:
69 | pass
70 |
71 | return chlist
72 |
73 |
74 | def general_process_data_chat(
75 | chat_template_prompt,
76 | tokenizer,
77 | **extra_data,
78 | ):
79 |
80 | return {
81 | "prompt": tokenizer.apply_chat_template(
82 | chat_template_prompt, # [ {"role": "user", "content": ... }, {"role": "assistant", "content": ... }, ... ],
83 | tokenize=False,
84 | add_generation_prompt=True,
85 | ),
86 | "chat_template_prompt": chat_template_prompt,
87 | **extra_data,
88 | }
89 |
90 |
91 | def general_process_data_prompt(
92 | chat_template_prompt,
93 | tokenizer,
94 | user_prompt=DEFAULT_USER_PROMPT,
95 | ai_prompt=DEFAULT_AI_PROMPT,
96 | **extra_data,
97 | ):
98 | concatenated_prompt = tokenizer.bos_token
99 |
100 | for message in chat_template_prompt:
101 | role = message["role"]
102 | content = message["content"].strip()
103 |
104 | if role == "user":
105 | concatenated_prompt += user_prompt.format(content=content)
106 | elif role == "assistant":
107 | concatenated_prompt += ai_prompt.format(content=content)
108 |
109 | concatenated_prompt += "\n" # Add extra newline between Q&A pairs
110 |
111 | return {
112 | "prompt": concatenated_prompt,
113 | "chat_template_prompt": chat_template_prompt,
114 | **extra_data,
115 | }
116 |
117 |
118 | def general_process_data(chat_template_prompt, format="chat", **extra_data):
119 |
120 | if format == "chat":
121 | return general_process_data_chat(chat_template_prompt, **extra_data)
122 | else:
123 | return general_process_data_prompt(chat_template_prompt, **extra_data)
124 |
125 |
126 | def processhh_data(entry, tokenizer, format="chat", prompt_template=None, **kwargs):
127 |
128 | br = "\n\n"
129 | available_roles = {"Human", "Assistant"}
130 | # breaks = entry["chosen"].split(br)
131 |
132 | chat_template = parsehh_to_template(
133 | entry["chosen"],
134 | br=br,
135 | available_roles=available_roles,
136 | )
137 | chat_template_reject = parsehh_to_template(
138 | entry["rejected"],
139 | br=br,
140 | available_roles=available_roles,
141 | )
142 |
143 | return general_process_data(
144 | chat_template_prompt=chat_template[:-1],
145 | tokenizer=tokenizer,
146 | format=format,
147 | answer=chat_template[-1]["content"],
148 | bad_answer=chat_template_reject[-1]["content"],
149 | )
150 |
151 |
152 | def processgsm_data(
153 | entry, tokenizer, prompt_template, format="chat", use_few_shot=False, **kwargs
154 | ):
155 | if use_few_shot:
156 | gsm_messages = [
157 | [
158 | {
159 | "role": "user",
160 | "content": prompt_template.format(prompt=sample["question"]),
161 | },
162 | {"role": "assistant", "content": sample["cot_answer"]},
163 | ]
164 | for sample in GSM8K_EXEMPLARS
165 | ]
166 | # flatten
167 | gsm_messages = [item for sublist in gsm_messages for item in sublist]
168 | else:
169 | gsm_messages = []
170 |
171 |     chat_template = gsm_messages + [
172 | {
173 | "role": "user",
174 | "content": prompt_template.format(prompt=entry["question"]),
175 | },
176 | {
177 | "role": "assistant",
178 | "content": entry["answer"],
179 | },
180 | ]
181 |
182 |     return general_process_data(
183 | chat_template_prompt=chat_template[:-1],
184 | format=format,
185 | tokenizer=tokenizer,
186 | answer=chat_template[-1]["content"],
187 | )
188 |
189 |
190 | def processmath_data(
191 | entry, tokenizer, prompt_template, format="chat", use_few_shot=False, **kwargs
192 | ):
193 |
194 | if use_few_shot:
195 | math_messages = [
196 | [
197 | {
198 | "role": "user",
199 | "content": prompt_template.format(prompt=sample["question"]),
200 | },
201 | {"role": "assistant", "content": sample["cot_answer"]},
202 | ]
203 | for sample in MATH_EXAMPLARS
204 | ]
205 | # flatten
206 | math_messages = [item for sublist in math_messages for item in sublist]
207 | else:
208 | math_messages = []
209 |
210 | chat_template = math_messages + [
211 | {
212 | "role": "user",
213 | "content": prompt_template.format(prompt=entry["problem"]),
214 | },
215 | {
216 | "role": "assistant",
217 | "content": entry["solution"],
218 | },
219 | ]
220 |
221 | # Get the final answer for evaluation
222 | # answer = get_last_math(chat_template[-1]["content"])
223 | # if answer is None:
224 | # return None
225 |
226 | return general_process_data(
227 | chat_template_prompt=chat_template[:-1],
228 | format=format,
229 | tokenizer=tokenizer,
230 | answer=entry["solution"], # answer,
231 | )
232 |
233 |
234 | def processalpaca_data(
235 | entry, tokenizer, prompt_template, format="chat", use_few_shot=False, **kwargs
236 | ):
237 | # import pdb; pdb.set_trace()
238 |
239 | chat_template = [
240 | {
241 | "role": "user",
242 | "content": prompt_template.format(prompt=entry["instruction"]),
243 | },
244 | {
245 | "role": "assistant",
246 | "content": entry["output"],
247 | },
248 | ]
249 |
250 | return general_process_data(
251 | chat_template_prompt=chat_template[:-1],
252 | format=format,
253 | tokenizer=tokenizer,
254 | answer=chat_template[-1]["content"],
255 | )
256 |
257 |
258 | def processifeval_data(
259 | entry, tokenizer, prompt_template, format="chat", use_few_shot=False, **kwargs
260 | ):
261 | # import pdb; pdb.set_trace()
262 |
263 | chat_template = [
264 | {
265 | "role": "user",
266 | "content": prompt_template.format(prompt=entry["prompt"]),
267 | }
268 | ]
269 |
270 | return general_process_data(
271 | chat_template_prompt=chat_template,
272 | format=format,
273 | tokenizer=tokenizer,
274 | answer="no answer",
275 | )
276 |
277 |
278 | def processmmlu_data(
279 | entry, tokenizer, prompt_template, format="chat", use_few_shot=False, **kwargs
280 | ):
281 |
282 | choice_options = ["A", "B", "C", "D"]
283 |
284 | l = entry["subject"].split("_")
285 | subject = ""
286 | for x in l:
287 | subject += " " + x
288 |
289 | instruction = f"Choose the correct answer to the following multiple-choice question about {subject}.\n\n"
290 |
291 | prompt = entry["question"]
292 | choices = entry["choices"]
293 | answer = int(entry["answer"])
294 |
295 | instruction += "Question: {}\n\n".format(prompt)
296 |
297 | for j in range(len(choices)):
298 | instruction += "{}). {}\n".format(choice_options[j], choices[j])
299 |
300 | instruction += "\nProvide your reasoning about the answer and finish your answer with the letter corresponding to the correct option (e.g., A, B, C, or D).\n\n"
301 |
302 | prompt_ans = "\nAnswer:"
303 |
304 | # if include_answer:
305 | prompt_ans += " {}\n\n".format(choice_options[answer])
306 |
307 | entry = {"prompt": instruction, "answer": prompt_ans}
308 |
309 | chat_template = [
310 | {
311 | "role": "user",
312 | "content": prompt_template.format(prompt=instruction),
313 | },
314 | {
315 | "role": "assistant",
316 | "content": prompt_ans,
317 | },
318 | ]
319 |
320 | # Get the final answer for evaluation
321 | # answer = get_last_math(chat_template[-1]["content"])
322 | # if answer is None:
323 | # return None
324 |
325 | return general_process_data(
326 | chat_template_prompt=chat_template[:-1],
327 | format=format,
328 | tokenizer=tokenizer,
329 | answer=prompt_ans, # answer,
330 | )
331 |
332 |
333 | def processtruthfulqa_data(
334 | entry, tokenizer, prompt_template, format="chat", use_few_shot=False, **kwargs
335 | ):
336 | choices_key = "mc1_targets"
337 |
338 | choice_options = [chr(i) for i in range(ord("A"), ord("Z") + 1)]
339 |
340 | instruction = (
341 | "Choose the correct answer to the following multiple-choice question.\n\n"
342 | )
343 |
344 | prompt = entry["question"]
345 | choices = entry[choices_key]["choices"]
346 | answer = np.argmax(entry[choices_key]["labels"])
347 |
348 | # answer = int(entry["answer"])
349 |
350 | instruction += "Question: {}\n\n".format(prompt)
351 |
352 | for j in range(len(choices)):
353 | instruction += "{}). {}\n".format(choice_options[j], choices[j])
354 |
355 | instruction += "\nProvide your reasoning about the answer and finish your answer with the letter corresponding to the correct option (e.g., A, B, C, or D).\n\n"
356 |
357 | prompt_ans = "\nAnswer:"
358 |
359 | # if include_answer:
360 | prompt_ans += " {}\n\n".format(choice_options[answer])
361 |
362 | entry = {"prompt": instruction, "answer": prompt_ans}
363 |
364 | chat_template = [
365 | {
366 | "role": "user",
367 | "content": prompt_template.format(prompt=instruction),
368 | },
369 | {
370 | "role": "assistant",
371 | "content": prompt_ans,
372 | },
373 | ]
374 |
375 | return general_process_data(
376 | chat_template_prompt=chat_template[:-1],
377 | format=format,
378 | tokenizer=tokenizer,
379 | answer=prompt_ans, # answer,
380 | )
381 |
382 |
383 | def process_all_redux_data():
384 |
385 | subjects = [
386 | "anatomy",
387 | "business_ethics",
388 | "clinical_knowledge",
389 | "college_chemistry",
390 | "college_computer_science",
391 | "college_mathematics",
392 | "college_medicine",
393 | "college_physics",
394 | "econometrics",
395 | "electrical_engineering",
396 | "formal_logic",
397 | "global_facts",
398 | "high_school_chemistry",
399 | "high_school_mathematics",
400 | "high_school_physics",
401 | "high_school_statistics",
402 | "human_aging",
403 | "logical_fallacies",
404 | "machine_learning",
405 | "miscellaneous",
406 | "philosophy",
407 | "professional_accounting",
408 | "public_relations",
409 | "virology",
410 | "conceptual_physics",
411 | "high_school_us_history",
412 | "astronomy",
413 | "high_school_geography",
414 | "high_school_macroeconomics",
415 | "professional_law",
416 | ]
417 |
418 | ds = []
419 | for subject in subjects:
420 |
421 | dsi = load_dataset("edinburgh-dawg/mmlu-redux", subject, split="test")
422 | ds.extend([{"subject": subject, **x} for x in dsi])
423 |
424 | dataset = Dataset.from_list(ds)
425 |
426 | return dataset
427 |
428 |
429 | SUPPORTED_DATASETS = {
430 | "openai/gsm8k": ({"name": "socratic"}, processgsm_data, "openai/gsm8k"),
431 | "apple/GSM-Symbolic-p1": ({"name": "p1"}, processgsm_data, "apple/GSM-Symbolic"),
432 | "apple/GSM-Symbolic-p2": ({"name": "p2"}, processgsm_data, "apple/GSM-Symbolic"),
433 | "Anthropic/hh-rlhf": ({}, processhh_data, "Anthropic/hh-rlhf"),
434 | "lighteval/MATH": ({}, processmath_data, "lighteval/MATH"),
435 | "cais/mmlu": ({"name": "all"}, processmmlu_data, "cais/mmlu"),
436 | "truthfulqa/truthful_qa": (
437 | {"name": "multiple_choice"},
438 | processtruthfulqa_data,
439 | "truthfulqa/truthful_qa",
440 | ),
441 | "HuggingFaceH4/MATH-500": (
442 | {},
443 | processmath_data,
444 | "HuggingFaceH4/MATH-500",
445 | ),
446 | "tatsu-lab/alpaca_eval": (
447 | {},
448 | processalpaca_data,
449 | "tatsu-lab/alpaca_eval",
450 | ),
451 | "google/IFEval": (
452 | {},
453 | processifeval_data,
454 | "google/IFEval",
455 | ),
456 | "edinburgh-dawg/mmlu-redux": (
457 | {},
458 | processmmlu_data,
459 | "edinburgh-dawg/mmlu-redux",
460 | process_all_redux_data,
461 | ),
462 | # eval
463 | # "tatsu-lab/alpaca_eval", "alpaca_eval")["eval"]
464 | # processifeval_data
465 | }
466 |
467 | # edinburgh-dawg/mmlu-redux
468 |
469 |
470 | # ds = load_dataset("cais/mmlu", "abstract_algebra", split="test")
471 |
472 |
473 | def get_data_iterable(
474 | model_path: str,
475 | dataset_path: str,
476 | split: str = "test",
477 | n=None,
478 | prompt_template=DEFAULT_TEMPLATE,
479 | use_few_shot=False,
480 | format="chat",
481 | **kwargs,
482 | ):
483 |
484 | tokenizer = AutoTokenizer.from_pretrained(model_path)
485 |
486 | if dataset_path in SUPPORTED_DATASETS:
487 |
488 | resources = SUPPORTED_DATASETS[dataset_path]
489 |
490 | if len(resources) == 3:
491 | kwargs, pfunc, dataset_path = SUPPORTED_DATASETS[dataset_path]
492 |
493 | ds = load_dataset(
494 | dataset_path,
495 | split=split,
496 | # desc=None,
497 | **kwargs,
498 | )
499 | else:
500 | kwargs, pfunc, dataset_path, load_func = SUPPORTED_DATASETS[dataset_path]
501 | ds = load_func()
502 |
503 | if "answer" in ds.column_names:
504 | ds = ds.cast_column("answer", Value("string"))
505 |
506 | # limit the size of the dataset
507 |
508 | if n is not None:
509 | ds = ds.select(list(range(n)))
510 |
511 | ds = ds.map(
512 | partial(
513 | pfunc,
514 | tokenizer=tokenizer,
515 | prompt_template=prompt_template,
516 | use_few_shot=use_few_shot,
517 | format=format,
518 | ),
519 | load_from_cache_file=False,
520 | # desc=None, # Disable tqdm progress bar
521 | )
522 |
523 | else:
524 | ds = load_dataset(dataset_path, split=split, trust_remote_code=True, **kwargs)
525 |
526 | if n is not None:
527 | ds = list(ds)[:n]
528 |
529 | return list(ds)
530 |
--------------------------------------------------------------------------------
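A minimal usage sketch for `get_data_iterable` above (not part of the repository; the model name is only an illustrative Hugging Face checkpoint, and running it assumes the `qalign` package and its `quest` dependency are installed):

```python
# Load the first few GSM8K test prompts, chat-templated for a given model.
from qalign.utils.data import get_data_iterable

examples = get_data_iterable(
    model_path="meta-llama/Llama-3.1-8B-Instruct",  # any HF model with a chat template
    dataset_path="openai/gsm8k",                    # must be a key in SUPPORTED_DATASETS
    split="test",
    n=4,                 # keep only the first 4 instances
    use_few_shot=False,
    format="chat",
)

for ex in examples:
    print(ex["prompt"][:200])  # chat-templated prompt string fed to the sampler
    print(ex["answer"][:80])   # gold answer kept alongside for evaluation
```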
/qalign/utils/eval.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import multiprocessing as mp
4 |
5 |
6 | from typing import *
7 |
8 |
9 | from tqdm import tqdm
10 |
11 | import numpy as np
12 |
13 | ## qalign
14 | from qalign.utils.ifeval.eval import test_instruction_following_loose
15 | from qalign.utils.math import get_last_math, get_last_option, get_last_number
16 |
17 | ## quest
18 | from quest.reward.model import ValueHead
19 | from quest.reward.remote import RemoteReward
20 | from quest.utils.list import (
21 | flatten_list,
22 | unflatten_list,
23 | get_unique_mapping,
24 | invert_unique_mapping,
25 | chunked,
26 | )
27 | from quest import (
28 | Reward,
29 | RewardModel,
30 | ContextualRewardModel,
31 | )
32 |
33 | ## expkit
34 | from expkit import (
35 | ExpSetup,
36 | Exp,
37 | Evalutor,
38 | )
39 | from expkit.ops import proj
40 |
41 |
42 | def compress_predictions(instances, n=None):
43 | prompts, tokens, vocabs, counts = (
44 | [],
45 | [],
46 | [],
47 | [],
48 | )
49 |
50 | # instances = experiment.instances()
51 | if n is not None:
52 | instances = instances[:n]
53 |
54 | for i in instances:
55 |
56 | tksi, vcbi = get_unique_mapping(
57 | list(
58 | map(
59 | proj("text"),
60 | i["outputs"],
61 | )
62 | )
63 | )
64 |
65 | promptsi = [i["input"]["prompt"]] * len(vcbi)
66 |
67 | prompts.append(promptsi)
68 | vocabs.append(vcbi)
69 | tokens.append(tksi)
70 | counts.append(len(vcbi))
71 |
72 | return prompts, tokens, vocabs, counts
73 |
74 |
75 | class ExactEval(Evalutor):
76 |
77 | def __init__(self, process_fn, name="exact"):
78 | super().__init__(name)
79 |
80 | self.process_fn = process_fn
81 |
82 | def eval(self, experiment: Exp):
83 |
84 | results = []
85 |
86 | instances = experiment.instances()
87 |
88 | if not isinstance(instances, list):
89 | instances = [instances]
90 |
91 | for i in tqdm(instances):
92 |
93 | preds = [self.process_fn(x["text"]) for x in i["outputs"]]
94 |
95 | target = self.process_fn(i["input"]["answer"])
96 |
97 | scores = list(
98 | map(
99 | lambda p: int(p == target),
100 | preds,
101 | )
102 | )
103 |
104 | results.append({"scores": scores})
105 |
106 | return results
107 |
108 |
109 | class IFEval(Evalutor):
110 |
111 | def __init__(self, name="ifeval"):
112 | super().__init__(name)
113 |
114 | def create_eval_pairs(self, experiment):
115 | """Create all input-output pairs for evaluation."""
116 | instances = experiment.instances()
117 | if not isinstance(instances, list):
118 | instances = [instances]
119 |
120 | eval_pairs = []
121 | for idx, instance in enumerate(instances):
122 | for output in instance["outputs"]:
123 | eval_pairs.append(
124 | {
125 | "instance_idx": idx,
126 | "input": instance["input"],
127 | "output": {
128 | "text": output["text"].replace("<|end_of_text|>", "")
129 | },
130 | }
131 | )
132 | return eval_pairs, len(instances)
133 |
134 | def process_single_pair(self, pair):
135 | """Process a single input-output pair."""
136 |
137 | return {
138 | "instance_idx": pair["instance_idx"],
139 | "score": test_instruction_following_loose(pair["input"], pair["output"])[
140 | "follow_all_instructions"
141 | ],
142 | }
143 |
144 | def eval(self, experiment: Exp):
145 | # Step 1: Create all evaluation pairs
146 | eval_pairs, num_instances = self.create_eval_pairs(experiment)
147 |
148 | # Step 2: Process pairs in parallel
149 | with mp.Pool() as pool:
150 | results_flat = list(
151 |                 tqdm(
152 |                     pool.imap(self.process_single_pair, eval_pairs),
153 | total=len(eval_pairs),
154 | desc="Processing pairs",
155 | )
156 | )
157 |
158 | # Step 3: Reorganize results into original structure
159 | results = []
160 | for i in range(num_instances):
161 | instance_scores = [
162 | r["score"] for r in results_flat if r["instance_idx"] == i
163 | ]
164 | results.append({"scores": instance_scores})
165 |
166 | return results
167 |
168 |
169 | class ExactLastNumberEval(ExactEval):
170 |
171 | def __init__(self):
172 | super().__init__(
173 | process_fn=get_last_number,
174 | name="lastnumber",
175 | )
176 |
177 |
178 | class ExactMATHEval(ExactEval):
179 |
180 | def __init__(self):
181 | super().__init__(
182 | process_fn=get_last_math,
183 | name="lastmath",
184 | )
185 |
186 |
187 | class ExactQAEval(ExactEval):
188 |
189 | def __init__(self):
190 | super().__init__(
191 | process_fn=get_last_option,
192 | name="lastoption",
193 | )
194 |
195 |
196 | class RewardEval(Evalutor):
197 |
198 | def __init__(self, reward: Reward, n=None, chunk_size=128):
199 | super().__init__(reward.get_name())
200 | self.reward = reward
201 | self.n = n
202 | self.chunk_size = chunk_size
203 |
204 | def eval(self, experiment: Exp):
205 |
206 | all_duplicated_scores = []
207 |
208 | for instance_chunk in chunked(
209 | experiment.instances(lazy_iterable=True), self.chunk_size
210 | ):
211 | # Process current chunk of instances
212 | prompts, tokens, vocabs, counts = compress_predictions(
213 | instance_chunk, n=self.n
214 | )
215 |
216 | # Set context for this chunk only
217 | if isinstance(
218 | self.reward,
219 | (ContextualRewardModel, ValueHead, RemoteReward),
220 | ):
221 | self.reward.set_context(context=flatten_list(prompts))
222 |
223 | # Evaluate just this chunk's candidates
224 | chunk_flat_vocabs = flatten_list(vocabs)
225 | chunk_scores = self.reward.evaluate(
226 | candidates=chunk_flat_vocabs,
227 |                 use_tqdm=True,  # show a progress bar for this chunk
228 | )
229 |
230 | # Process scores for this chunk only
231 | chunk_unflattened = unflatten_list(chunk_scores, counts)
232 | all_duplicated_scores.extend(
233 | [
234 | {"scores": invert_unique_mapping(tks, rsi)}
235 | for tks, rsi in zip(tokens, chunk_unflattened)
236 | ]
237 | )
238 |
239 | del instance_chunk
240 | del prompts
241 | del tokens
242 | del vocabs
243 | del counts
244 |
245 | return all_duplicated_scores
246 |
--------------------------------------------------------------------------------
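A minimal sketch of how the exact-match evaluators above score sampled outputs (illustrative only: the real pipeline passes an `expkit` `Exp`, so the stand-in class below merely mimics the `.instances()` shape these evaluators read):

```python
from qalign.utils.eval import ExactLastNumberEval

class FakeExp:
    # Stand-in for expkit.Exp: one prompt with two sampled completions.
    def instances(self):
        return [
            {
                "input": {"answer": "... so the answer is 42"},
                "outputs": [
                    {"text": "Adding them up gives 42. So the answer is 42"},
                    {"text": "Adding them up gives 41. So the answer is 41"},
                ],
            }
        ]

results = ExactLastNumberEval().eval(FakeExp())
print(results)  # expected: [{'scores': [1, 0]}] -- one 0/1 score per sampled output
```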
/qalign/utils/examples.py:
--------------------------------------------------------------------------------
1 | # Exemplars used to few-shot prompt the model.
2 | # These exemplars are from Table 20 of the CoT paper (https://arxiv.org/pdf/2201.11903.pdf).
3 | GSM8K_EXEMPLARS = [
4 | {
5 | "question": "There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?",
6 | "cot_answer": "There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6. So the answer is 6.",
7 | "short_answer": "6",
8 | },
9 | {
10 | "question": "If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?",
11 | "cot_answer": "There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. So the answer is 5.",
12 | "short_answer": "5",
13 | },
14 | {
15 | "question": "Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?",
16 | "cot_answer": "Originally, Leah had 32 chocolates. Her sister had 42. So in total they had 32 + 42 = 74. After eating 35, they had 74 - 35 = 39. So the answer is 39.",
17 | "short_answer": "39",
18 | },
19 | {
20 | "question": "Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?",
21 | "cot_answer": "Jason started with 20 lollipops. Then he had 12 after giving some to Denny. So he gave Denny 20 - 12 = 8. So the answer is 8.",
22 | "short_answer": "8",
23 | },
24 | {
25 | "question": "Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?",
26 | "cot_answer": "Shawn started with 5 toys. If he got 2 toys each from his mom and dad, then that is 4 more toys. 5 + 4 = 9. So the answer is 9.",
27 | "short_answer": "9",
28 | },
29 | {
30 | "question": "There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?",
31 | "cot_answer": "There were originally 9 computers. For each of 4 days, 5 more computers were added. So 5 * 4 = 20 computers were added. 9 + 20 is 29. So the answer is 29.",
32 | "short_answer": "29",
33 | },
34 | {
35 | "question": "Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?",
36 | "cot_answer": "Michael started with 58 golf balls. After losing 23 on tuesday, he had 58 - 23 = 35. After losing 2 more, he had 35 - 2 = 33 golf balls. So the answer is 33.",
37 | "short_answer": "33",
38 | },
39 | {
40 | "question": "Olivia has $23. She bought five bagels for $3 each. How much money does she have left?",
41 | "cot_answer": "Olivia had 23 dollars. 5 bagels for 3 dollars each will be 5 x 3 = 15 dollars. So she has 23 - 15 dollars left. 23 - 15 is 8. So the answer is 8.",
42 | "short_answer": "8",
43 | },
44 | ]
45 |
46 | # These exemplars are from the DeepSeekMath GitHub repository (https://github.com/deepseek-ai/DeepSeek-Math/tree/main/evaluation/few_shot_prompts)
47 | MATH_EXAMPLARS = [
48 | {
49 | "question": "Find the domain of the expression $\\frac{\\sqrt{x-2}}{\\sqrt{5-x}}$.}",
50 | "cot_answer": "The expressions inside each square root must be non-negative.\nTherefore, $x-2 \\ge 0$, so $x\\ge2$, and $5 - x \\ge 0$, so $x \\le 5$.\nAlso, the denominator cannot be equal to zero, so $5-x>0$, which gives $x<5$.\nTherefore, the domain of the expression is $\\boxed{[2,5)}$.",
51 | "short_answer": "[2,5)",
52 | },
53 | {
54 | "question": "If $\\det \\mathbf{A} = 2$ and $\\det \\mathbf{B} = 12,$ then find $\\det (\\mathbf{A} \\mathbf{B}).$",
55 | "cot_answer": "We have that $\\det (\\mathbf{A} \\mathbf{B}) = (\\det \\mathbf{A})(\\det \\mathbf{B}) = (2)(12) = \\boxed{24}.$",
56 | "short_answer": "24",
57 | },
58 | {
59 | "question": "Terrell usually lifts two 20-pound weights 12 times. If he uses two 15-pound weights instead, how many times must Terrell lift them in order to lift the same total weight?",
60 | "cot_answer": "If Terrell lifts two 20-pound weights 12 times, he lifts a total of $2\\cdot 12\\cdot20=480$ pounds of weight. If he lifts two 15-pound weights instead for $n$ times, he will lift a total of $2\\cdot15\\cdot n=30n$ pounds of weight. Equating this to 480 pounds, we can solve for $n$: \\begin{align*}\n30n&=480\\\\\n\\Rightarrow\\qquad n&=480/30=\\boxed{16}\n\\end{align*}",
61 | "short_answer": "16",
62 | },
63 | {
64 | "question": "If the system of equations\n\n\\begin{align*}\n6x-4y&=a,\\\\\n6y-9x &=b.\n\\end{align*}has a solution $(x, y)$ where $x$ and $y$ are both nonzero, find $\\frac{a}{b},$ assuming $b$ is nonzero.",
65 | "cot_answer": "If we multiply the first equation by $-\\frac{3}{2}$, we obtain\n\n$$6y-9x=-\\frac{3}{2}a.$$Since we also know that $6y-9x=b$, we have\n\n$$-\\frac{3}{2}a=b\\Rightarrow\\frac{a}{b}=\\boxed{-\\frac{2}{3}}.$$",
66 | "short_answer": "-\\frac{2}{3}",
67 | },
68 | ]
69 |
--------------------------------------------------------------------------------
/qalign/utils/experiment.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Dict, Any, List
2 | from datetime import datetime
3 | from tqdm import tqdm
4 | import os
5 | import copy
6 |
7 | ## expkit
8 | from expkit.exp import Exp
9 | from expkit.storage import DiskStorage
10 |
11 | ## quest
12 | from quest.reward.model import ContextualRewardModel, ValueHead
13 | from quest.reward.remote import RemoteReward
14 | from quest.proposal import RLHFSuffixProposal
15 | from quest.core import Quest
16 |
17 | ## qalign
18 | from qalign.utils.data import FlexiblePromptTemplate
19 | from qalign.utils.data import get_data_iterable
20 | from quest.utils.list import chunked
21 |
22 |
23 | ## literegistry
24 | from literegistry import RegistryClient, FileSystemKVStore
25 |
26 |
27 | def create_experiment(
28 | save_path: str,
29 | variant: str,
30 | model_path: str,
31 | dataset_path: str,
32 | n: int,
33 | temperature: float,
34 | steps: int,
35 | max_new_tokens: int = 100,
36 | max_prompt_length: int = 600,
37 | batch_size: int = 64,
38 | split: str = "test",
39 | prompt_template: str = "{prompt}",
40 | stop_tokens: Optional[List[str]] = None,
41 | additional_meta: Optional[Dict[str, Any]] = None,
42 | format: str = "chat", # either "chat" or "prompt"
43 | use_few_shot: bool = False,
44 | ) -> Exp:
45 | """
46 | Creates a standardized experiment with common metadata.
47 | """
48 | if stop_tokens is None:
49 | stop_tokens = []
50 |
51 | meta = {
52 | "steps": steps,
53 | "temperature": temperature,
54 | "n": n,
55 | "model_path": model_path,
56 | "variant": variant,
57 | "stop_tokens": stop_tokens,
58 | "max_new_tokens": max_new_tokens,
59 | "max_prompt_length": max_prompt_length,
60 | "at": datetime.now().isoformat(),
61 | "dataset": dataset_path,
62 | "split": split,
63 | "prompt_template": prompt_template,
64 | "batch_size": batch_size,
65 | "format": format,
66 | "use_few_shot": use_few_shot,
67 | }
68 |
69 | if additional_meta:
70 | meta.update(additional_meta)
71 |
72 | return Exp(
73 | storage=DiskStorage(save_path, "rw"),
74 | meta=meta,
75 | )
76 |
77 |
78 | def create_extension_experiment(storage, experiment, new_steps=1024):
79 | samples = [
80 | {
81 | "input": (
82 | data["input"]["input"] if "input" in data["input"] else data["input"]
83 | ),
84 | "completion": data["outputs"][-1]["text"],
85 | "reward": float(data["outputs"][-1]["reward"]),
86 | }
87 | for data in experiment.instances(lazy_iterable=True)
88 | ]
89 |
90 | meta = copy.deepcopy(experiment.meta)
91 |
92 | meta["steps"] = new_steps
93 | meta["bootstrap"] = samples
94 | meta["link"] = experiment.get_name()
95 |
96 | new_exp = Exp(
97 | storage=storage,
98 | name=experiment.get_name() + "-extension",
99 | meta=meta,
100 | )
101 |
102 | return new_exp
103 |
104 |
105 | def create_vllm_model(
106 | model_path: str,
107 | temperature: float,
108 | max_new_tokens: int = 100,
109 | max_prompt_length: int = 600,
110 | stop_tokens: Optional[List[str]] = None,
111 | device_count: int = 1,
112 | gpu_memory_utilization: float = 0.8,
113 | # prompt_template: Optional[str] = None,
114 | enforce_eager: bool = False,
115 | remote: bool = False,
116 | ):
117 | """
118 | Creates a standardized VLLM model instance with common configurations.
119 | """
120 | if stop_tokens is None:
121 | stop_tokens = [""]
122 |
123 | extra_args = {}
124 | if "bnb" in model_path:
125 | extra_args.update(
126 | {
127 | "trust_remote_code": True,
128 | "quantization": "bitsandbytes",
129 | "load_format": "bitsandbytes",
130 | }
131 | )
132 |
133 | model_args = {
134 | "model_path": model_path,
135 | "download_dir": os.environ.get("HF_HOME", "/tmp/"),
136 | "stop_tokens": stop_tokens,
137 | "temperature": temperature,
138 | "gpu_memory_utilization": gpu_memory_utilization,
139 | "dtype": "bfloat16",
140 | "max_new_tokens": max_new_tokens,
141 | "max_prompt_length": max_prompt_length,
142 | "tensor_parallel_size": device_count,
143 | "enable_prefix_caching": True,
144 | "enforce_eager": enforce_eager,
145 | **extra_args,
146 | }
147 |
148 | if remote:
149 | from quest.model.remote import RemoteVLLM
150 |
151 | registry = RegistryClient(
152 | store=FileSystemKVStore("/gscratch/ark/graf/registry"),
153 | max_history=3600,
154 | cache_ttl=60,
155 | service_type="model_path",
156 | )
157 |
158 | return RemoteVLLM(registry=registry, **model_args)
159 | else:
160 | from quest.model.vllm import VLLM
161 |
162 | return VLLM(**model_args)
163 |
164 |
165 | def process_batch_outputs(
166 | chain_outputs: Any, batch_size: int
167 | ) -> List[List[Dict[str, Any]]]:
168 | """
169 | Processes batch outputs from a Quest chain into a standardized format.
170 | """
171 | outputs = []
172 | for i in range(batch_size):
173 | outputs.append(
174 | [
175 | {
176 | "t": s["t"],
177 | **{k: v[i] for k, v in s.items() if k != "t"},
178 | }
179 | for s in chain_outputs.state_path
180 | ]
181 | )
182 | return outputs
183 |
184 |
185 | def get_batched_data(
186 | model,
187 | dataset_path: str,
188 | split: str,
189 | n: int,
190 | batch_size: int,
191 | prompt_template: str,
192 | start_index: int = 0,
193 | num_chains: int = 1,
194 | completed: int = 0,
195 | format="chat",
196 | use_few_shot=False,
197 | ) -> List[Any]:
198 | """
199 | Gets batched data from a dataset using standard configurations.
200 | """
201 | data_iterable = get_data_iterable(
202 | model_path=model.model_path,
203 | dataset_path=dataset_path,
204 | split=split,
205 | n=start_index + n,
206 | tokenizer=model.tokenizer,
207 | prompt_template=FlexiblePromptTemplate(prompt_template),
208 | format=format,
209 | use_few_shot=use_few_shot,
210 | )
211 |
212 | if start_index > 0:
213 | data_iterable = data_iterable[start_index : start_index + n]
214 |
215 | if num_chains > 1:
216 | data_iterable = [x for x in data_iterable for _ in range(num_chains)]
217 |
218 | data_iterable = data_iterable[completed:]
219 |
220 | batches = []
221 | for i in range(0, len(data_iterable), batch_size):
222 | batches.append(data_iterable[i : i + batch_size])
223 |
224 | return batches
225 |
226 |
227 | def calculate_reward_scores(
228 | experiment: Exp,
229 | reward_key: Optional[str] = None,
230 | ) -> List[Dict[str, List[float]]]:
231 | """
232 | Calculates reward scores for experiment instances.
233 | """
234 | beta = experiment.meta.get("beta", 1.0)
235 |
236 | if not reward_key and "reward_model_path" in experiment.meta:
237 | reward_key = "crm:" + experiment.meta["reward_model_path"].split(".")[
238 | 0
239 | ].replace("/", "-")
240 |
241 | return [
242 | {"scores": [float(o["reward"]) * beta for o in i["outputs"]]}
243 | for i in experiment.instances(lazy_iterable=True)
244 | ]
245 |
246 |
247 | def run_ancestral(experiment, model, steps, data_batches):
248 |
249 |     steps = experiment.meta["steps"]  # overrides the caller-provided value with the experiment metadata
250 |
251 | # Process each batch
252 | for data_batch in tqdm(data_batches):
253 | completions_txt = model.ancestral(data_batch, n=steps)
254 | outputs = [
255 | [{"text": state_t} for state_t in instance_txt]
256 | for instance_txt in completions_txt
257 | ]
258 |
259 | experiment.add_instances(
260 | inputs=data_batch,
261 | outputs=outputs,
262 | )
263 |
264 |
265 | def run_quest(
266 | experiment,
267 | model,
268 | steps,
269 | data_batches,
270 | reward_model_batch_size=64,
271 | reward_device=1,
272 | reward_device_count=1,
273 | remote=False,
274 | ):
275 |
276 | reward_type = experiment.meta["reward_type"]
277 |
278 | if remote:
279 | registry = RegistryClient(
280 | store=FileSystemKVStore("/gscratch/ark/graf/registry"),
281 | max_history=3600,
282 | cache_ttl=60,
283 | service_type="model_path",
284 | )
285 |
286 | reward = RemoteReward(
287 | registry=registry,
288 | model_path=experiment.meta["reward_model_path"],
289 | reward_type=reward_type,
290 | # batch_size=reward_model_batch_size,
291 | )
292 |
293 | else:
294 | if reward_type == "contextual":
295 | reward = ContextualRewardModel(
296 | model_path=experiment.meta["reward_model_path"],
297 | # batch_size=reward_model_batch_size,
298 | device=reward_device,
299 | device_count=reward_device_count,
300 | )
301 | elif reward_type == "value":
302 | reward = ValueHead(
303 | model_path=experiment.meta["reward_model_path"],
304 | # batch_size=reward_model_batch_size,
305 | device=reward_device,
306 | device_count=reward_device_count,
307 | ) # sentiment model.
308 | else:
309 | raise ValueError(f"Unknown reward type: {reward_type}")
310 |
311 | # Process each batch
312 | for data_batch in data_batches:
313 | context = [model.get_prompt(**data) for data in data_batch]
314 | reward.set_context(context)
315 |
316 | chain = Quest(
317 | input_data=data_batch,
318 | proposal=RLHFSuffixProposal(
319 | model=model,
320 | reward=reward,
321 | ),
322 | beta=experiment.meta["beta"],
323 | )
324 |
325 | chain_outputs = chain.run(
326 | steps=steps,
327 | use_tqdm=True,
328 | )
329 |
330 | outputs = process_batch_outputs(chain_outputs, len(data_batch))
331 | experiment.add_instances(
332 | inputs=data_batch,
333 | outputs=outputs,
334 | )
335 |
336 | # Calculate and add reward scores
337 | scores = calculate_reward_scores(experiment)
338 | experiment.add_eval(reward.get_name(), scores)
339 |
340 | return experiment
341 |
342 |
343 | def run_quest_bootstrap(
344 | experiment,
345 | model,
346 | steps,
347 | reward_device=1,
348 | reward_device_count=1,
349 | remote=False,
350 | ):
351 |
352 | reward_type = experiment.meta["reward_type"]
353 |
354 | if remote:
355 | registry = RegistryClient(
356 | store=FileSystemKVStore("/gscratch/ark/graf/registry"),
357 | max_history=3600,
358 | cache_ttl=60,
359 | service_type="model_path",
360 | )
361 |
362 | reward = RemoteReward(
363 | registry=registry,
364 | model_path=experiment.meta["reward_model_path"],
365 | reward_type=reward_type,
366 | # batch_size=reward_model_batch_size,
367 | )
368 |
369 | else:
370 | if reward_type == "contextual":
371 | reward = ContextualRewardModel(
372 | model_path=experiment.meta["reward_model_path"],
373 | # batch_size=reward_model_batch_size,
374 | device=reward_device,
375 | device_count=reward_device_count,
376 | )
377 | elif reward_type == "value":
378 | reward = ValueHead(
379 | model_path=experiment.meta["reward_model_path"],
380 | # batch_size=reward_model_batch_size,
381 | device=reward_device,
382 | device_count=reward_device_count,
383 | ) # sentiment model.
384 | else:
385 | raise ValueError(f"Unknown reward type: {reward_type}")
386 |
387 | for data_batch in chunked(
388 | experiment.meta["bootstrap"], experiment.get("batch_size")
389 | ):
390 | context = [data["input"]["prompt"] for data in data_batch]
391 |
392 | reward.set_context(context)
393 |
394 | chain = Quest(
395 | input_data=[data["input"] for data in data_batch],
396 | proposal=RLHFSuffixProposal(
397 | model=model,
398 | reward=reward,
399 | ),
400 | beta=experiment.meta["beta"],
401 | )
402 |
403 | chain_outputs = chain.run(
404 | steps=steps,
405 | use_tqdm=True,
406 | warm_start=[
407 | {"completion": data["completion"], "reward": data["reward"]}
408 | for data in data_batch
409 | ],
410 | )
411 |
412 | outputs = process_batch_outputs(chain_outputs, len(data_batch))
413 |
414 | experiment.add_instances(
415 | inputs=data_batch,
416 | outputs=outputs,
417 | )
418 |
419 | # Calculate and add reward scores
420 | scores = calculate_reward_scores(experiment)
421 | experiment.add_eval(reward.get_name(), scores)
422 |
423 | return experiment
424 |
425 |
426 | # Create model
427 | def run_experiment(
428 | experiment,
429 | gpu_memory_utilization=0.95,
430 | device_count=1,
431 | reward_model_batch_size=64,
432 | reward_device=1,
433 | reward_device_count=1,
434 | remote=False,
435 | ):
436 |
437 | # Create model
438 | model = create_vllm_model(
439 | model_path=experiment.meta["model_path"],
440 | temperature=experiment.meta["temperature"],
441 | max_new_tokens=experiment.meta["max_new_tokens"],
442 | max_prompt_length=experiment.meta["max_prompt_length"],
443 | device_count=device_count,
444 | gpu_memory_utilization=gpu_memory_utilization,
445 | stop_tokens=experiment.meta["stop_tokens"],
446 | remote=remote,
447 | )
448 |
449 | completed = len(experiment.instances())
450 |
451 | # Get batched data with start index
452 | data_batches = get_batched_data(
453 | model=model,
454 | dataset_path=experiment.meta["dataset"],
455 | split=experiment.meta["split"],
456 | n=experiment.meta["n"],
457 | batch_size=experiment.meta.get("batch_size", 64),
458 | prompt_template=experiment.meta["prompt_template"],
459 | start_index=experiment.meta.get("i", 0),
460 | num_chains=experiment.meta.get("num_chains", 1),
461 | completed=completed,
462 | format=experiment.meta.get("format", "chat"),
463 | use_few_shot=experiment.meta.get("use_few_shot", False),
464 | )
465 |
466 | if experiment.meta["variant"] == "ancestral":
467 | run_ancestral(
468 | experiment=experiment,
469 | model=model,
470 | steps=experiment.meta["steps"],
471 | data_batches=data_batches,
472 | )
473 | elif experiment.meta["variant"] == "quest-rlhf":
474 | run_quest(
475 | experiment=experiment,
476 | model=model,
477 | steps=experiment.meta["steps"],
478 | data_batches=data_batches,
479 | reward_model_batch_size=reward_model_batch_size,
480 | reward_device=reward_device,
481 | reward_device_count=reward_device_count,
482 |             remote=remote,
483 | )
484 |
485 |
486 | # Create model
487 | def run_experiment_remote(
488 | experiment,
489 | ):
490 |
491 | # Create model
492 | model = create_vllm_model(
493 | model_path=experiment.meta["model_path"],
494 | temperature=experiment.meta["temperature"],
495 | max_new_tokens=experiment.meta["max_new_tokens"],
496 | max_prompt_length=experiment.meta["max_prompt_length"],
497 | device_count=1,
498 | gpu_memory_utilization=0.95,
499 | stop_tokens=experiment.meta["stop_tokens"],
500 | remote=True,
501 | )
502 |
503 | completed = len(experiment.instances())
504 |
505 | # Get batched data with start index
506 | if "bootstrap" in experiment.meta:
507 |
508 | if "bootstrap" in experiment.meta:
509 | run_quest_bootstrap(
510 | experiment=experiment,
511 | model=model,
512 | steps=experiment.meta["steps"],
513 | reward_device=0,
514 | reward_device_count=1,
515 | remote=True,
516 | )
517 |
518 | else:
519 |
520 | data_batches = get_batched_data(
521 | model=model,
522 | dataset_path=experiment.meta["dataset"],
523 | split=experiment.meta["split"],
524 | n=experiment.meta["n"],
525 | batch_size=experiment.meta.get("batch_size", 64),
526 | prompt_template=experiment.meta["prompt_template"],
527 | start_index=experiment.meta.get("i", 0),
528 | num_chains=experiment.meta.get("num_chains", 1),
529 | completed=completed,
530 | format=experiment.meta.get("format", "chat"),
531 | use_few_shot=experiment.meta.get("use_few_shot", False),
532 | )
533 |
534 | if experiment.meta["variant"] == "ancestral":
535 | run_ancestral(
536 | experiment=experiment,
537 | model=model,
538 | steps=experiment.meta["steps"],
539 | data_batches=data_batches,
540 | )
541 | elif experiment.meta["variant"] == "quest-rlhf":
542 |
543 | run_quest(
544 | experiment=experiment,
545 | model=model,
546 | steps=experiment.meta["steps"],
547 | data_batches=data_batches,
548 | # reward_model_batch_size=experiment.meta.get("batch_size", 64),
549 | reward_device=0,
550 | reward_device_count=1,
551 | remote=True,
552 | )
553 |
--------------------------------------------------------------------------------
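
A minimal driver sketch for the helpers in this module (all paths, dataset names, and budgets below are placeholder assumptions; the repository's actual entry points are launch_experiment.py and resume_experiment.py). create_experiment records the run metadata, and run_experiment builds the vLLM sampler, batches the dataset, and dispatches to the ancestral or Quest variant:

from qalign.utils.experiment import create_experiment, run_experiment

# Hypothetical values; adjust model, dataset, reward model, and budgets as needed.
exp = create_experiment(
    save_path="outputs/",
    variant="quest-rlhf",                 # or "ancestral"
    model_path="meta-llama/Llama-3.1-8B-Instruct",
    dataset_path="gsm8k",
    n=128,                                # number of prompts
    temperature=1.0,
    steps=256,                            # MCMC steps per prompt
    max_new_tokens=512,
    additional_meta={                     # fields read by run_quest
        "beta": 1.0,
        "reward_type": "contextual",
        "reward_model_path": "path/to/reward-model",
    },
)

run_experiment(exp, device_count=1, reward_device=1)
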
/qalign/utils/ifeval/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/goncalorafaria/qalign/39b1df52c913bda4f94e656018f68915fd06299b/qalign/utils/ifeval/__init__.py
--------------------------------------------------------------------------------
/qalign/utils/ifeval/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/goncalorafaria/qalign/39b1df52c913bda4f94e656018f68915fd06299b/qalign/utils/ifeval/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/qalign/utils/ifeval/__pycache__/eval.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/goncalorafaria/qalign/39b1df52c913bda4f94e656018f68915fd06299b/qalign/utils/ifeval/__pycache__/eval.cpython-38.pyc
--------------------------------------------------------------------------------
/qalign/utils/ifeval/__pycache__/instructions.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/goncalorafaria/qalign/39b1df52c913bda4f94e656018f68915fd06299b/qalign/utils/ifeval/__pycache__/instructions.cpython-38.pyc
--------------------------------------------------------------------------------
/qalign/utils/ifeval/__pycache__/instructions_util.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/goncalorafaria/qalign/39b1df52c913bda4f94e656018f68915fd06299b/qalign/utils/ifeval/__pycache__/instructions_util.cpython-38.pyc
--------------------------------------------------------------------------------
/qalign/utils/ifeval/__pycache__/registry.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/goncalorafaria/qalign/39b1df52c913bda4f94e656018f68915fd06299b/qalign/utils/ifeval/__pycache__/registry.cpython-38.pyc
--------------------------------------------------------------------------------
/qalign/utils/ifeval/eval.py:
--------------------------------------------------------------------------------
1 | from expkit import ExpSetup
2 | from expkit.storage import DiskStorage, CachedRODiskStorage
3 | from expkit import Exp
4 | from tqdm import tqdm
5 | import qalign.utils.ifeval.registry as instructions_registry
6 |
7 |
8 | def test_instruction_following_loose(
9 | inp,
10 | oup,
11 | ):
12 | """Tests response for an upper bound for following instructions."""
13 | response = oup["text"]
14 | r = response.split("\n")
15 | response_remove_first = "\n".join(r[1:]).strip()
16 | response_remove_last = "\n".join(r[:-1]).strip()
17 | response_remove_both = "\n".join(r[1:-1]).strip()
18 | revised_response = response.replace("*", "")
19 | revised_response_remove_first = response_remove_first.replace("*", "")
20 | revised_response_remove_last = response_remove_last.replace("*", "")
21 | revised_response_remove_both = response_remove_both.replace("*", "")
22 |
23 | all_responses = [
24 | response,
25 | revised_response,
26 | response_remove_first,
27 | response_remove_last,
28 | response_remove_both,
29 | revised_response_remove_first,
30 | revised_response_remove_last,
31 | revised_response_remove_both,
32 | ]
33 | instruction_list = inp["instruction_id_list"]
34 | is_following_list = []
35 |
36 | for index, instruction_id in enumerate(instruction_list):
37 | instruction_cls = instructions_registry.INSTRUCTION_DICT[instruction_id]
38 | instruction = instruction_cls(instruction_id)
39 |
40 | kwargs = {k: v for k, v in inp["kwargs"][index].items() if v}
41 |
42 | instruction.build_description(**kwargs)
43 | args = instruction.get_instruction_args()
44 |
45 | if args and "prompt" in args:
46 | datasetprompt = inp["chat_template_prompt"][0]["content"]
47 | instruction.build_description(prompt=datasetprompt)
48 |
49 | is_following = False
50 | for r in all_responses:
51 | if r.strip() and instruction.check_following(r):
52 | is_following = True
53 | break
54 |
55 | is_following_list.append(is_following)
56 |
57 | return {
58 | "instruction_id_list": inp["instruction_id_list"],
59 | "follow_all_instructions": int(all(is_following_list)),
60 | "follow_instruction_list": is_following_list,
61 | }
62 |
63 |
64 | """
65 | storage = CachedRODiskStorage(
66 | base_dir="remote-outputs/",
67 | # mode="rw"
68 | )
69 |
70 |
71 | setup = ExpSetup(
72 | storage=storage,
73 | ).query({"dataset": "google/IFEval"})
74 |
75 | exp = setup[0]
76 |
77 |
78 | instances = exp.instances()
79 | results = {"scores": []}
80 | for i in tqdm(instances):
81 |
82 | scoresi = [
83 | test_instruction_following_loose(
84 | i["input"],
85 | oi,
86 | )["follow_all_instructions"]
87 | for oi in i["outputs"]
88 | ]
89 |
90 | results["scores"].append(scoresi)
91 |
92 | # instruction_id_list = data["input"]["instruction_id_list"]
93 | # prompt = data["input"]["prompt"]
94 | import pdb
95 |
96 | pdb.set_trace()
97 | """
98 |
--------------------------------------------------------------------------------
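
A small sketch of the per-instance schema that test_instruction_following_loose expects (field values are illustrative and assume the standard IFEval checker semantics, e.g. the no_comma checker passes when the response contains no commas):

from qalign.utils.ifeval.eval import test_instruction_following_loose

inp = {
    "instruction_id_list": ["punctuation:no_comma"],
    "kwargs": [{}],
    "chat_template_prompt": [
        {"role": "user", "content": "Describe a sunset without using commas."}
    ],
}
oup = {"text": "The sky glowed orange and slowly faded into night."}

result = test_instruction_following_loose(inp, oup)
# result["follow_all_instructions"] == 1 if every listed instruction is satisfied
# by at least one of the stripped response variants.
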
/qalign/utils/ifeval/registry.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Registry of all instructions."""
17 | from qalign.utils.ifeval import instructions
18 |
19 | _KEYWORD = "keywords:"
20 |
21 | _LANGUAGE = "language:"
22 |
23 | _LENGTH = "length_constraints:"
24 |
25 | _CONTENT = "detectable_content:"
26 |
27 | _FORMAT = "detectable_format:"
28 |
29 | _MULTITURN = "multi-turn:"
30 |
31 | _COMBINATION = "combination:"
32 |
33 | _STARTEND = "startend:"
34 |
35 | _CHANGE_CASES = "change_case:"
36 |
37 | _PUNCTUATION = "punctuation:"
38 |
39 | INSTRUCTION_DICT = {
40 | _KEYWORD + "existence": instructions.KeywordChecker,
41 | _KEYWORD + "frequency": instructions.KeywordFrequencyChecker,
42 | # TODO(jeffreyzhou): make a proper set of sentences to choose from
43 | # _KEYWORD + "key_sentences": instructions.KeySentenceChecker,
44 | _KEYWORD + "forbidden_words": instructions.ForbiddenWords,
45 | _KEYWORD + "letter_frequency": instructions.LetterFrequencyChecker,
46 | _LANGUAGE + "response_language": instructions.ResponseLanguageChecker,
47 | _LENGTH + "number_sentences": instructions.NumberOfSentences,
48 | _LENGTH + "number_paragraphs": instructions.ParagraphChecker,
49 | _LENGTH + "number_words": instructions.NumberOfWords,
50 | _LENGTH + "nth_paragraph_first_word": instructions.ParagraphFirstWordCheck,
51 | _CONTENT + "number_placeholders": instructions.PlaceholderChecker,
52 | _CONTENT + "postscript": instructions.PostscriptChecker,
53 | _FORMAT + "number_bullet_lists": instructions.BulletListChecker,
54 | # TODO(jeffreyzhou): Pre-create paragraph or use prompt to replace
55 | # _CONTENT + "rephrase_paragraph": instructions.RephraseParagraph,
56 | _FORMAT + "constrained_response": instructions.ConstrainedResponseChecker,
57 | _FORMAT + "number_highlighted_sections": (instructions.HighlightSectionChecker),
58 | _FORMAT + "multiple_sections": instructions.SectionChecker,
59 | # TODO(tianjianlu): Re-enable rephrasing with preprocessing the message.
60 | # _FORMAT + "rephrase": instructions.RephraseChecker,
61 | _FORMAT + "json_format": instructions.JsonFormat,
62 | _FORMAT + "title": instructions.TitleChecker,
63 | # TODO(tianjianlu): Re-enable with specific prompts.
64 | # _MULTITURN + "constrained_start": instructions.ConstrainedStartChecker,
65 | _COMBINATION + "two_responses": instructions.TwoResponsesChecker,
66 | _COMBINATION + "repeat_prompt": instructions.RepeatPromptThenAnswer,
67 | _STARTEND + "end_checker": instructions.EndChecker,
68 | _CHANGE_CASES + "capital_word_frequency": instructions.CapitalWordFrequencyChecker,
69 | _CHANGE_CASES + "english_capital": instructions.CapitalLettersEnglishChecker,
70 | _CHANGE_CASES + "english_lowercase": instructions.LowercaseLettersEnglishChecker,
71 | _PUNCTUATION + "no_comma": instructions.CommaChecker,
72 | _STARTEND + "quotation": instructions.QuotationChecker,
73 | }
74 |
75 | INSTRUCTION_CONFLICTS = {
76 | _KEYWORD + "existence": {_KEYWORD + "existence"},
77 | _KEYWORD + "frequency": {_KEYWORD + "frequency"},
78 | # TODO(jeffreyzhou): make a proper set of sentences to choose from
79 | # _KEYWORD + "key_sentences": instructions.KeySentenceChecker,
80 | _KEYWORD + "forbidden_words": {_KEYWORD + "forbidden_words"},
81 | _KEYWORD + "letter_frequency": {_KEYWORD + "letter_frequency"},
82 | _LANGUAGE
83 | + "response_language": {
84 | _LANGUAGE + "response_language",
85 | _FORMAT + "multiple_sections",
86 | _KEYWORD + "existence",
87 | _KEYWORD + "frequency",
88 | _KEYWORD + "forbidden_words",
89 | _STARTEND + "end_checker",
90 | _CHANGE_CASES + "english_capital",
91 | _CHANGE_CASES + "english_lowercase",
92 | },
93 | _LENGTH + "number_sentences": {_LENGTH + "number_sentences"},
94 | _LENGTH
95 | + "number_paragraphs": {
96 | _LENGTH + "number_paragraphs",
97 | _LENGTH + "nth_paragraph_first_word",
98 | _LENGTH + "number_sentences",
99 | _LENGTH + "nth_paragraph_first_word",
100 | },
101 | _LENGTH + "number_words": {_LENGTH + "number_words"},
102 | _LENGTH
103 | + "nth_paragraph_first_word": {
104 | _LENGTH + "nth_paragraph_first_word",
105 | _LENGTH + "number_paragraphs",
106 | },
107 | _CONTENT + "number_placeholders": {_CONTENT + "number_placeholders"},
108 | _CONTENT + "postscript": {_CONTENT + "postscript"},
109 | _FORMAT + "number_bullet_lists": {_FORMAT + "number_bullet_lists"},
110 | # TODO(jeffreyzhou): Pre-create paragraph or use prompt to replace
111 | # _CONTENT + "rephrase_paragraph": instructions.RephraseParagraph,
112 | _FORMAT + "constrained_response": set(INSTRUCTION_DICT.keys()),
113 | _FORMAT + "number_highlighted_sections": {_FORMAT + "number_highlighted_sections"},
114 | _FORMAT
115 | + "multiple_sections": {
116 | _FORMAT + "multiple_sections",
117 | _LANGUAGE + "response_language",
118 | _FORMAT + "number_highlighted_sections",
119 | },
120 | # TODO(tianjianlu): Re-enable rephrasing with preprocessing the message.
121 | # _FORMAT + "rephrase": instructions.RephraseChecker,
122 | _FORMAT
123 | + "json_format": set(INSTRUCTION_DICT.keys()).difference(
124 | {_KEYWORD + "forbidden_words", _KEYWORD + "existence"}
125 | ),
126 | _FORMAT + "title": {_FORMAT + "title"},
127 | # TODO(tianjianlu): Re-enable with specific prompts.
128 | # _MULTITURN + "constrained_start": instructions.ConstrainedStartChecker,
129 | _COMBINATION
130 | + "two_responses": set(INSTRUCTION_DICT.keys()).difference(
131 | {
132 | _KEYWORD + "forbidden_words",
133 | _KEYWORD + "existence",
134 | _LANGUAGE + "response_language",
135 | _FORMAT + "title",
136 | _PUNCTUATION + "no_comma",
137 | }
138 | ),
139 | _COMBINATION
140 | + "repeat_prompt": set(INSTRUCTION_DICT.keys()).difference(
141 | {_KEYWORD + "existence", _FORMAT + "title", _PUNCTUATION + "no_comma"}
142 | ),
143 | _STARTEND + "end_checker": {_STARTEND + "end_checker"},
144 | _CHANGE_CASES
145 | + "capital_word_frequency": {
146 | _CHANGE_CASES + "capital_word_frequency",
147 | _CHANGE_CASES + "english_lowercase",
148 | _CHANGE_CASES + "english_capital",
149 | },
150 | _CHANGE_CASES + "english_capital": {_CHANGE_CASES + "english_capital"},
151 | _CHANGE_CASES
152 | + "english_lowercase": {
153 | _CHANGE_CASES + "english_lowercase",
154 | _CHANGE_CASES + "english_capital",
155 | },
156 | _PUNCTUATION + "no_comma": {_PUNCTUATION + "no_comma"},
157 | _STARTEND + "quotation": {_STARTEND + "quotation", _FORMAT + "title"},
158 | }
159 |
160 |
161 | def conflict_make(conflicts):
162 | """Makes sure if A conflicts with B, B will conflict with A.
163 |
164 | Args:
165 | conflicts: Dictionary of potential conflicts where key is instruction id
166 | and value is set of instruction ids that it conflicts with.
167 |
168 | Returns:
169 | Revised version of the dictionary. All instructions conflict with
170 | themselves. If A conflicts with B, B will conflict with A.
171 | """
172 | for key in conflicts:
173 | for k in conflicts[key]:
174 | conflicts[k].add(key)
175 | conflicts[key].add(key)
176 | return conflicts
177 |
--------------------------------------------------------------------------------
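
A toy illustration of conflict_make's symmetric closure (the keys below are made-up ids, not real instruction ids):

from qalign.utils.ifeval.registry import conflict_make

conflicts = {"a": {"b"}, "b": set(), "c": set()}
conflict_make(conflicts)
# conflicts is updated in place (and also returned):
# {"a": {"a", "b"}, "b": {"a", "b"}, "c": {"c"}}
# i.e. every id conflicts with itself, and the "a" <-> "b" conflict is now symmetric.
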
/qalign/utils/math.py:
--------------------------------------------------------------------------------
1 | import re
2 | import json
3 | import signal
4 | import logging
5 | from typing import Optional
6 |
7 | import sympy
8 | from sympy.parsing.latex import parse_latex
9 |
10 |
11 | ## from open instruct
12 |
13 |
14 | def generate_axis(limit, k):
15 | sequence = []
16 | i = 0
17 |
18 | # Generate powers of 2 up to K
19 | while 2**i < k:
20 | sequence.append(2**i)
21 | i += 1
22 |
23 | # Continue with increments of K
24 | value = k
25 | while value <= limit:
26 | sequence.append(value)
27 | value += k
28 |
29 | return sequence
30 |
31 |
32 | def get_last_number(output):
33 |
34 | output = re.sub(r"(\d),(\d)", r"\1\2", output)
35 | numbers = re.findall(r"[-+]?\d*\.\d+|\d+", output)
36 | if numbers:
37 | return numbers[-1]
38 | else:
39 | return "NaN"
40 |
41 |
42 | def get_last_math(output):
43 |
44 | if "\\boxed" in output:
45 | last_box = last_boxed_only_string(output)
46 |
47 | try:
48 | final = remove_boxed(last_box)
49 | final = normalize_final_answer(final)
50 | except Exception as e:
51 | # print("no number in box")
52 | final = get_last_number(output)
53 |
54 | else:
55 | # print("no number in box")
56 | final = get_last_number(output)
57 |
58 | return final
59 |
60 |
61 | import re
62 | import signal
63 | import logging
64 | from typing import Optional
65 |
66 | import sympy
67 | from sympy.parsing.latex import parse_latex
68 |
69 |
70 | eval_logger = logging.getLogger("math_utils")
71 |
72 |
73 | # from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/minerva_math/utils.py#L187
74 | def last_boxed_only_string(string: str) -> Optional[str]:
75 | idx = string.rfind("\\boxed")
76 | if "\\boxed " in string:
77 | return "\\boxed " + string.split("\\boxed ")[-1].split("$")[0]
78 | if idx < 0:
79 | idx = string.rfind("\\fbox")
80 | if idx < 0:
81 | return "NaN"
82 |
83 | i = idx
84 | right_brace_idx = None
85 | num_left_braces_open = 0
86 | while i < len(string):
87 | if string[i] == "{":
88 | num_left_braces_open += 1
89 | if string[i] == "}":
90 | num_left_braces_open -= 1
91 | if num_left_braces_open == 0:
92 | right_brace_idx = i
93 | break
94 | i += 1
95 |
96 | if right_brace_idx is None:
97 | retval = None
98 | else:
99 | retval = string[idx : right_brace_idx + 1]
100 |
101 | return retval
102 |
103 |
104 | def remove_boxed(s: str):
105 | left = "\\boxed{"
106 | try:
107 | assert s[: len(left)] == left
108 | assert s[-1] == "}"
109 | return s[len(left) : -1]
110 | except:
111 | return ""
112 |
113 |
114 | def get_unnormalized_answer(text: str) -> str:
115 | INVALID_ANSWER = "[invalidanswer]"
116 | end_seq = "I hope it is correct."
117 | text += end_seq
118 | match = re.search(
119 | r"Final Answer: The final answer is(.*?). I hope it is correct.",
120 | text,
121 | )
122 | if match:
123 | return match.group(1).strip()
124 | else:
125 | return INVALID_ANSWER
126 |
127 |
128 | SUBSTITUTIONS = [
129 | ("an ", ""),
130 | ("a ", ""),
131 | (".$", "$"),
132 | ("\\$", ""),
133 | (r"\ ", ""),
134 | (" ", ""),
135 | ("\\dfrac", "\\frac"),
136 | ("mbox", "text"),
137 | (",\\text{and}", ","),
138 | ("\\text{and}", ","),
139 | ("\\text{m}", "\\text{}"),
140 | ("π", "\\pi"),
141 | ]
142 | REMOVED_EXPRESSIONS = [
143 | "square",
144 | "ways",
145 | "integers",
146 | "dollars",
147 | "mph",
148 | "inches",
149 | "\\left",
150 | "\\big",
151 | "\\Big",
152 | "\\Bigg",
153 | "\\bigg",
154 | "\\right",
155 | "ft",
156 | "hours",
157 | "km",
158 | "units",
159 | "\\ldots",
160 | "sue",
161 | "points",
162 | "feet",
163 | "minutes",
164 | "digits",
165 | "cents",
166 | "degrees",
167 | "cm",
168 | "gm",
169 | "pounds",
170 | "meters",
171 | "meals",
172 | "edges",
173 | "students",
174 | "childrentickets",
175 | "multiples",
176 | "\\text{s}",
177 | "\\text{.}",
178 | "\\text{\ns}",
179 | "\\text{}^2",
180 | "\\text{}^3",
181 | "\\text{\n}",
182 | "\\text{}",
183 | r"\mathrm{th}",
184 | r"^\circ",
185 | r"^{\circ}",
186 | r"\;",
187 | r",\!",
188 | "{,}",
189 | '"',
190 | "\\dots",
191 | "\%",
192 | ]
193 |
194 |
195 | def normalize_final_answer(final_answer: str) -> str:
196 | """
197 | Normalize a final answer to a quantitative reasoning question.
198 |
199 | Copied character for character from appendix D of Lewkowycz et al. (2022)
200 | """
201 | final_answer = final_answer.split("=")[-1]
202 |
203 | for before, after in SUBSTITUTIONS:
204 | final_answer = final_answer.replace(before, after)
205 | for expr in REMOVED_EXPRESSIONS:
206 | final_answer = final_answer.replace(expr, "")
207 |
208 | # Extract answer that is in LaTeX math, is bold,
209 | # is surrounded by a box, etc.
210 | final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer)
211 | final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer)
212 | final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer)
213 | final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer)
214 | final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer)
215 |
216 | # Normalize shorthand TeX:
217 | # \fracab -> \frac{a}{b}
218 | # \frac{abc}{bef} -> \frac{abc}{bef}
219 | # \fracabc -> \frac{a}{b}c
220 | # \sqrta -> \sqrt{a}
221 | # \sqrtab -> sqrt{a}b
222 | final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer)
223 | final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer)
224 | final_answer = final_answer.replace("$", "")
225 |
226 | # Normalize 100,000 -> 100000
227 | if final_answer.replace(",", "").isdigit():
228 | final_answer = final_answer.replace(",", "")
229 |
230 | return final_answer
231 |
232 |
233 | class timeout:
234 | def __init__(self, seconds=1, error_message="Timeout"):
235 | self.seconds = seconds
236 | self.error_message = error_message
237 |
238 | def handle_timeout(self, signum, frame):
239 | raise TimeoutError(self.error_message)
240 |
241 | def __enter__(self):
242 | signal.signal(signal.SIGALRM, self.handle_timeout)
243 | signal.alarm(self.seconds)
244 |
245 | def __exit__(self, type, value, traceback):
246 | signal.alarm(0)
247 |
248 |
249 | def is_equiv(x1: str, x2: str) -> bool:
250 | """
251 | x1 and x2 are normalized latex string
252 | """
253 | try:
254 | with timeout(seconds=5):
255 | try:
256 | parsed_x1 = parse_latex(x1)
257 | parsed_x2 = parse_latex(x2)
258 | except (
259 | sympy.parsing.latex.errors.LaTeXParsingError,
260 | sympy.SympifyError,
261 | TypeError,
262 | ):
263 | eval_logger.debug(f"couldn't parse one of {x1} or {x2}")
264 | return False
265 |
266 | try:
267 | diff = parsed_x1 - parsed_x2
268 | except TypeError:
269 | eval_logger.debug(f"couldn't subtract {x1} and {x2}")
270 | return False
271 |
272 | try:
273 | if sympy.simplify(diff) == 0:
274 | return True
275 | else:
276 | return False
277 | except ValueError:
278 | eval_logger.debug(
279 | f"Had some trouble simplifying when comparing {x1} and {x2}"
280 | )
281 | except TimeoutError:
282 | eval_logger.debug(f"Timed out comparing {x1} and {x2}")
283 | return False
284 | except ImportError as e:
285 | eval_logger.error(e)
286 | raise
287 | except Exception as e:
288 | eval_logger.debug(f"Failed comparing {x1} and {x2} with {e}")
289 | return False
290 |
291 |
292 | def fix_fracs(string):
293 | substrs = string.split("\\frac")
294 | new_str = substrs[0]
295 | if len(substrs) > 1:
296 | substrs = substrs[1:]
297 | for substr in substrs:
298 | new_str += "\\frac"
299 | if substr[0] == "{":
300 | new_str += substr
301 | else:
302 | try:
303 | assert len(substr) >= 2
304 | except AssertionError:
305 | return string
306 | a = substr[0]
307 | b = substr[1]
308 | if b != "{":
309 | if len(substr) > 2:
310 | post_substr = substr[2:]
311 | new_str += "{" + a + "}{" + b + "}" + post_substr
312 | else:
313 | new_str += "{" + a + "}{" + b + "}"
314 | else:
315 | if len(substr) > 2:
316 | post_substr = substr[2:]
317 | new_str += "{" + a + "}" + b + post_substr
318 | else:
319 | new_str += "{" + a + "}" + b
320 | string = new_str
321 | return string
322 |
323 |
324 | def fix_a_slash_b(string):
325 | if len(string.split("/")) != 2:
326 | return string
327 | a = string.split("/")[0]
328 | b = string.split("/")[1]
329 | try:
330 | a = int(a)
331 | b = int(b)
332 | assert string == "{}/{}".format(a, b)
333 | new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
334 | return new_string
335 | except AssertionError:
336 | return string
337 |
338 |
339 | def remove_right_units(string):
340 | # "\\text{ " only ever occurs (at least in the val set) when describing units
341 | if "\\text{ " in string:
342 | splits = string.split("\\text{ ")
343 | assert len(splits) == 2
344 | return splits[0]
345 | else:
346 | return string
347 |
348 |
349 | def fix_sqrt(string):
350 | if "\\sqrt" not in string:
351 | return string
352 | splits = string.split("\\sqrt")
353 | new_string = splits[0]
354 | for split in splits[1:]:
355 | if split[0] != "{":
356 | a = split[0]
357 | new_substr = "\\sqrt{" + a + "}" + split[1:]
358 | else:
359 | new_substr = "\\sqrt" + split
360 | new_string += new_substr
361 | return new_string
362 |
363 |
364 | def strip_string(string):
365 | # linebreaks
366 | string = string.replace("\n", "")
367 |
368 | # remove inverse spaces
369 | string = string.replace("\\!", "")
370 |
371 | # replace \\ with \
372 | string = string.replace("\\\\", "\\")
373 |
374 | # replace tfrac and dfrac with frac
375 | string = string.replace("tfrac", "frac")
376 | string = string.replace("dfrac", "frac")
377 |
378 | # remove \left and \right
379 | string = string.replace("\\left", "")
380 | string = string.replace("\\right", "")
381 |
382 | # Remove circ (degrees)
383 | string = string.replace("^{\\circ}", "")
384 | string = string.replace("^\\circ", "")
385 |
386 | # remove dollar signs
387 | string = string.replace("\\$", "")
388 |
389 | # remove units (on the right)
390 | string = remove_right_units(string)
391 |
392 | # remove percentage
393 | string = string.replace("\\%", "")
394 | string = string.replace("\%", "") # noqa: W605
395 |
396 | # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
397 | string = string.replace(" .", " 0.")
398 | string = string.replace("{.", "{0.")
399 | # if empty, return empty string
400 | if len(string) == 0:
401 | return string
402 | if string[0] == ".":
403 | string = "0" + string
404 |
405 | # to consider: get rid of e.g. "k = " or "q = " at beginning
406 | if len(string.split("=")) == 2:
407 | if len(string.split("=")[0]) <= 2:
408 | string = string.split("=")[1]
409 |
410 | # fix sqrt3 --> sqrt{3}
411 | string = fix_sqrt(string)
412 |
413 | # remove spaces
414 | string = string.replace(" ", "")
415 |
416 | # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
417 | string = fix_fracs(string)
418 |
419 | # manually change 0.5 --> \frac{1}{2}
420 | if string == "0.5":
421 | string = "\\frac{1}{2}"
422 |
423 | # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
424 | string = fix_a_slash_b(string)
425 |
426 | return string
427 |
428 |
429 | def hendrycks_is_equiv(str1, str2, verbose=False):
430 | if str1 is None and str2 is None:
431 | print("WARNING: Both None")
432 | return True
433 | if str1 is None or str2 is None:
434 | return False
435 |
436 | try:
437 | ss1 = strip_string(str1)
438 | ss2 = strip_string(str2)
439 | if verbose:
440 | print(ss1, ss2)
441 | return ss1 == ss2
442 | except Exception:
443 | return str1 == str2
444 |
445 |
446 | def get_last_option(text):
447 | pattern = r"\b[A-J]\b(?!.*\b[A-J]\b)"
448 | match = re.search(pattern, text, re.DOTALL)
449 | if match:
450 | return match.group(0)
451 | else:
452 | return None
453 |
454 |
455 | def verify_ifeval_sample(model_output, constraint):
456 | model_output = model_output.split("<|assistant|>\n")[-1].strip()
457 | # TODO: just pass in final answer. this should be fine for other evals too.
458 | answer = model_output.split("<|assistant|>\n")[-1].strip()
459 | if isinstance(constraint, str):
460 | constraint = json.loads(constraint)
461 | if "func_name" not in constraint:
462 | print("WARNING: constraint missing func_name")
463 | print(constraint)
464 | return False
465 | # first, parse out the constraint string.
466 | func_name = constraint.pop("func_name")
467 | # get the function
468 |     func = IF_FUNCTIONS_MAP[func_name]  # NOTE: IF_FUNCTIONS_MAP is not defined in this module; it must be provided by the IFEval utilities.
469 | # now, run the function
470 | # pop out any none args
471 | non_none_args = {k: v for k, v in constraint.items() if v is not None}
472 | # sometimes we have extra args, sometimes not.
473 | if len(constraint) == 0:
474 | return func(model_output)
475 | return func(answer, **non_none_args)
476 |
--------------------------------------------------------------------------------
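
A short sketch of how the extraction and equivalence helpers above compose (example strings are chosen for illustration; is_equiv depends on sympy's LaTeX parser, which needs the antlr4 runtime, and on a Unix-only SIGALRM timeout):

from qalign.utils.math import get_last_math, get_last_number, is_equiv

output = "So the answer is \\boxed{\\frac{3}{4}}."
pred = get_last_math(output)            # "\\frac{3}{4}" after unboxing and normalization
print(is_equiv(pred, "\\frac{3}{4}"))   # True: the symbolic difference simplifies to 0

# Commas between digits are stripped before the last number is taken.
print(get_last_number("That leaves 1,250 - 50 = 1,200 apples."))  # "1200"
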
/qalign/utils/mbr.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from collections import Counter
3 |
4 | from typing import *
5 |
6 | from multiprocessing import Pool
7 | import numpy as np
8 | import os
9 |
10 | from tqdm import tqdm
11 | from transformers import AutoTokenizer
12 |
13 | from scipy.sparse import csr_matrix
14 |
15 |
16 | ## qalign
17 | from qalign.utils.math import generate_axis
18 | from quest.utils.list import (
19 | chunked,
20 | )
21 |
22 |
23 | class UCS:
24 |
25 | def __init__(
26 | self,
27 | process_num=None,
28 | vocab_size=1000,
29 | ):
30 |
31 | self.process_num = os.cpu_count() if process_num is None else process_num
32 | self.vocab_size = vocab_size
33 | self.func = ucs_sparse
34 |
35 | def _process_chunk(self, chunk):
36 | """
37 | Process a chunk of data - this is needed for multiprocessing.
38 |
39 | Args:
40 | chunk: A subset of samples to process
41 |
42 | Returns:
43 | UCS matrix for the chunk
44 | """
45 | return self.func(chunk, vocab_size=self.vocab_size)
46 |
47 | def get_score(self, samples, compute_in_parallel=True):
48 |
49 | if compute_in_parallel:
50 | with Pool(self.process_num) as executor:
51 | results = list(
52 | tqdm(
53 | executor.map(
54 | self._process_chunk,
55 | samples,
56 | ),
57 | total=len(samples),
58 | desc="Processing chunks",
59 | )
60 | )
61 |
62 | mat = np.array(results)
63 | else:
64 | mat = np.array(
65 | [
66 | self.func(generations, vocab_size=self.vocab_size)
67 | for generations in samples
68 | ]
69 | )
70 |
71 | return mat
72 |
73 |
74 | class FastROUGE(UCS):
75 | def __init__(
76 | self,
77 | process_num=None,
78 | vocab_size=1000,
79 | ):
80 |
81 | self.process_num = os.cpu_count() if process_num is None else process_num
82 | self.vocab_size = vocab_size
83 | self.func = rouge1_sparse
84 |
85 |
86 | def join_instances_accepts(i, is_quest=True):
87 |
88 | if is_quest:
89 | accepted_outputs = []
90 | accepted_index = []
91 | last_accepted = None
92 | last_accept_index = None
93 | for j, output in enumerate(i["outputs"]):
94 | if output["accept"]:
95 | accepted_outputs.append(output)
96 | accepted_index.append(j)
97 | last_accepted = output
98 | last_accept_index = j
99 | elif last_accepted is not None:
100 | accepted_outputs.append(last_accepted)
101 | accepted_index.append(last_accept_index)
102 |
103 | else:
104 | accepted_outputs = i["outputs"]
105 | accepted_index = list(range(len(i["outputs"])))
106 |
107 | # print(len(outputs))
108 | return {"input": i["input"], "outputs": accepted_outputs, "index": accepted_index}
109 |
110 |
111 | def get_mats(completions, compute_in_parallel=True, metric="bleu", vocab_size=1000):
112 |
113 | if "ucs" in metric:
114 | s = UCS(vocab_size=vocab_size)
115 | elif "rouge" in metric:
116 | s = FastROUGE(vocab_size=vocab_size) # PairwiseRouge(metric=metric)
117 | else:
118 |         raise ValueError("metric must contain 'ucs' or 'rouge'")
119 |
120 | mats = s.get_score(completions, compute_in_parallel=compute_in_parallel)
121 |
122 | return mats
123 |
124 |
125 | def ucs_sparse(generations, vocab_size):
126 | """
127 | Compute Unigram Consistency Score (UCS) matrix for lists of integers using
128 | a vectorized approach for better efficiency.
129 |
130 | Args:
131 |         generations: List of integer (token-id) lists, one per generation
132 | vocab_size: Size of the vocabulary |V| to use in the UCS calculation
133 |
134 | Returns:
135 | UCS matrix where UCS[i, j] represents the consistency score between
136 | lists i and j
137 | """
138 | n = len(generations)
139 |
140 | # Check if provided vocab_size is valid
141 | if vocab_size <= 0:
142 | raise ValueError("Vocabulary size must be greater than 0")
143 |
144 | # Create a binary matrix of shape (n, vocab_size) where binary_matrix[i, j] = 1
145 | # if element j is in list i, 0 otherwise
146 |
147 | # Method 1: Using sparse matrices (more efficient for large vocab_size)
148 | rows = []
149 | cols = []
150 | data = []
151 |
152 | for i, int_list in enumerate(generations):
153 | for elem in set(int_list): # Use set to count each unique element once
154 | if 0 <= elem < vocab_size: # Ensure the element is in range
155 | rows.append(i)
156 | cols.append(elem)
157 | data.append(1)
158 |
159 | # Create sparse matrix
160 | binary_matrix = csr_matrix((data, (rows, cols)), shape=(n, vocab_size))
161 |
162 | # Compute UCS matrix using matrix multiplication
163 | # (binary_matrix @ binary_matrix.T) gives the dot product of each pair of rows
164 | dot_products = binary_matrix @ binary_matrix.T
165 |
166 | # Convert to dense array and divide by vocab_size
167 | ucs_matrix = dot_products.toarray() / vocab_size
168 |
169 | return ucs_matrix
170 |
171 |
172 | def rouge1_sparse(generations, vocab_size):
173 | """
174 | Compute Unigram Consistency Score (UCS) matrix for lists of integers using
175 | a vectorized approach for better efficiency.
176 |
177 | Args:
178 | int_lists: List of integer lists where each list represents a sequence of elements
179 | vocab_size: Size of the vocabulary |V| to use in the UCS calculation
180 |
181 | Returns:
182 | UCS matrix where UCS[i, j] represents the consistency score between
183 | lists i and j
184 | """
185 | n = len(generations)
186 |
187 | # Check if provided vocab_size is valid
188 | if vocab_size <= 0:
189 | raise ValueError("Vocabulary size must be greater than 0")
190 |
191 | # Create a binary matrix of shape (n, vocab_size) where binary_matrix[i, j] = 1
192 | # if element j is in list i, 0 otherwise
193 |
194 | # Method 1: Using sparse matrices (more efficient for large vocab_size)
195 | rows = []
196 | cols = []
197 | data = []
198 | # Store total word counts for each generation (for normalization)
199 | word_counts = np.zeros(n)
200 |
201 | for i, int_list in enumerate(generations):
202 | # Count total words in this generation (including duplicates)
203 | word_counts[i] = len(int_list)
204 |
205 | # Count frequency of each word in this generation
206 | # word_freq = {}
207 | # for word in int_list:
208 | # if 0 <= word < vocab_size: # Ensure word is in valid range
209 | # word_freq[word] = word_freq.get(word, 0) + 1
210 |
211 | word_freq = Counter(int_list)
212 |
213 | # Add to sparse matrix data
214 | for word, freq in word_freq.items():
215 | if 0 <= word < vocab_size:
216 | rows.append(i)
217 | cols.append(word)
218 | data.append(freq) # S
219 |
220 | count_matrix = csr_matrix((data, (rows, cols)), shape=(n, vocab_size))
221 |
222 | overlap_matrix = np.zeros((n, n), dtype=np.float32)
223 |
224 | # Convert to binary matrix first (1 where count > 0)
225 | binary_matrix = count_matrix.copy()
226 | binary_matrix.data = np.ones_like(binary_matrix.data)
227 |
228 | # Get common token indicators using matrix multiplication
229 | common_tokens = binary_matrix @ binary_matrix.T
230 |
231 | # For each pair of rows, calculate the overlap
232 | for i in range(n):
233 | row_i = count_matrix.getrow(i)
234 |
235 | for j in range(i, n):
236 | # Only process if there are any common tokens
237 | if common_tokens[i, j] > 0:
238 | row_j = count_matrix.getrow(j)
239 |
240 | # Get common indices
241 | i_indices = row_i.indices
242 | i_data = row_i.data
243 | j_indices = row_j.indices
244 | j_data = row_j.data
245 |
246 | # Find the intersection of indices
247 | # This is more efficient than multiplying sparse matrices
248 | i_dict = dict(zip(i_indices, i_data))
249 | j_dict = dict(zip(j_indices, j_data))
250 |
251 | common_indices = set(i_indices).intersection(set(j_indices))
252 |
253 | # Calculate the sum of minimums
254 | overlap = sum(min(i_dict[idx], j_dict[idx]) for idx in common_indices)
255 |
256 | overlap_matrix[i, j] = overlap
257 | if i != j:
258 | overlap_matrix[j, i] = overlap
259 |
260 | # Calculate ROUGE-1 recall (overlap / words in reference)
261 | # We can use broadcasting to divide by word_counts
262 | recall_matrix = np.zeros((n, n))
263 | nonzero_counts = word_counts > 0
264 | if np.any(nonzero_counts):
265 | recall_matrix[nonzero_counts, :] = (
266 | overlap_matrix[nonzero_counts, :] / word_counts[nonzero_counts, np.newaxis]
267 | )
268 |
269 | # Calculate ROUGE-1 precision (overlap / words in candidate)
270 | precision_matrix = np.zeros((n, n))
271 | if np.any(nonzero_counts):
272 | precision_matrix[:, nonzero_counts] = (
273 | overlap_matrix[:, nonzero_counts] / word_counts[np.newaxis, nonzero_counts]
274 | )
275 |
276 | # Calculate ROUGE-1 F1 score
277 | f1_matrix = np.zeros((n, n))
278 | nonzero = (precision_matrix + recall_matrix) > 0
279 | f1_matrix[nonzero] = (
280 | 2
281 | * (precision_matrix[nonzero] * recall_matrix[nonzero])
282 | / (precision_matrix[nonzero] + recall_matrix[nonzero])
283 | )
284 |
285 | return {
286 | "recall": recall_matrix,
287 | "precision": precision_matrix,
288 | "f1": f1_matrix,
289 | "overlap": overlap_matrix,
290 | "word_counts": word_counts,
291 | }["f1"]
292 |
293 |
294 | def mbr_mat_progression(
295 | exp,
296 | compute_in_parallel=True,
297 | k=1,
298 | n=None,
299 | max_steps=None,
300 | metric="bleu",
301 | ):
302 |
303 | if max_steps is None:
304 | max_steps = exp.get("steps")
305 |
306 | # total_instances = exp.instances(lazy_iterable=True)
307 |
308 | # if n is not None:
309 | # total_instances = total_instances[:n]
310 |
311 | # if "quest" not in exp.get("variant"):
312 | tokenizer = AutoTokenizer.from_pretrained(exp.get("model_path"))
313 |
314 | mats = []
315 |
316 | repeat_inds = []
317 |
318 | for instances in tqdm(chunked(exp.instances(lazy_iterable=True), 32)):
319 | # instances = [instance]
320 |
321 | instances = [
322 | join_instances_accepts(i, is_quest="quest" in exp.get("variant"))
323 | for i in instances
324 | ]
325 |
326 | repeat_inds.extend([i["index"] for i in instances])
327 |
328 | if "quest" not in exp.get("variant"):
329 |
330 | texts = [
331 | o["text"]
332 | for instance in instances
333 | for o in instance["outputs"][:max_steps] # Already uniform length
334 | ]
335 |
336 | # Batch tokenize everything at once
337 | batch_ids = tokenizer(texts)["input_ids"]
338 |
339 | # Reshape into [num_instances, max_steps, ...] using fixed chunk size
340 | completions = [
341 | batch_ids[i : i + max_steps]
342 | for i in range(0, len(batch_ids), max_steps)
343 | ]
344 |
345 | else:
346 | completions = [
347 | [o["completion"] for o in instance["outputs"][:max_steps]]
348 | for instance in instances
349 | ]
350 |
351 | mat = get_mats(
352 | completions,
353 | compute_in_parallel=compute_in_parallel,
354 | metric=metric,
355 | vocab_size=tokenizer.vocab_size,
356 | )
357 |
358 | mats.append(mat)
359 |
360 | del instances
361 | del completions
362 |
363 | mat = np.concatenate(mats, axis=0)
364 |
365 | return mat, repeat_inds
366 |
367 |
368 | def weighted_mbr_pick_progression(
369 | exp,
370 | reward_key,
371 | compute_in_parallel=True,
372 | k=1,
373 | n=None,
374 | max_steps=None,
375 | metric="bleu",
376 | ):
377 |
378 | if max_steps is None:
379 | max_steps = exp.get("steps")
380 |
381 | mat, repeat_inds = mbr_mat_progression(
382 | exp,
383 | compute_in_parallel=compute_in_parallel,
384 | k=k,
385 | n=n,
386 | max_steps=max_steps,
387 | metric=metric,
388 | )
389 |
390 | rewards = exp.get_eval(reward_key)
391 |
392 | repeat_rewards = np.array(
393 | [
394 | [r["scores"][i] for i in inds][:max_steps]
395 | for inds, r in zip(repeat_inds, rewards)
396 | ]
397 | )
398 |
399 | repeat_rewards = np.expand_dims(repeat_rewards, axis=1)
400 |
401 | mat *= repeat_rewards
402 |
403 | axis = generate_axis(max_steps + 1, k)
404 |
405 | return pick_mat(mat, axis, repeat_inds)
406 |
407 |
408 | def mbr_pick_progression(
409 | exp,
410 | compute_in_parallel=True,
411 | k=1,
412 | n=None,
413 | max_steps=None,
414 | metric="bleu",
415 | ):
416 |
417 | if max_steps is None:
418 | max_steps = exp.get("steps")
419 |
420 | mat, repeat_inds = mbr_mat_progression(
421 | exp,
422 | compute_in_parallel=compute_in_parallel,
423 | k=k,
424 | n=n,
425 | max_steps=max_steps,
426 | metric=metric,
427 | )
428 |
429 | axis = generate_axis(max_steps + 1, k)
430 |
431 | # 1,2,4,8,16,32,64,96,128,160,192,224,256
432 |
433 | return pick_mat(mat, axis, repeat_inds)
434 |
435 |
436 | def pick_mat(
437 | mat,
438 | axis,
439 | repeat_inds,
440 | ):
441 |
442 | values = []
443 |
444 | preds = {}
445 | for ni in axis:
446 | pick_batch = np.argmax(mat[:, :ni, :ni].mean(axis=-1), axis=-1)
447 | # acc = []
448 |
449 | preds[ni] = []
450 | for i, pick in enumerate(
451 | pick_batch,
452 | ):
453 | # true_answer = extract_func(instance["input"]["answer"])
454 | pred_index = repeat_inds[i][pick]
455 | preds[ni].append(pred_index)
456 | # acc.append(correct)
457 |
458 | return {"preds": preds, "axis": axis}
459 |
--------------------------------------------------------------------------------
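
A toy illustration of the pairwise similarity matrices used for the MBR picks above (token ids and vocab_size are made up for readability):

import numpy as np
from qalign.utils.mbr import ucs_sparse, rouge1_sparse

generations = [
    [1, 2, 3, 4],   # token ids of generation 0
    [1, 2, 3, 9],   # shares three unigrams with generation 0
    [7, 8, 9, 9],
]

ucs = ucs_sparse(generations, vocab_size=10)
# ucs[i, j] = |unique shared tokens| / vocab_size, e.g. ucs[0, 1] == 0.3
rouge = rouge1_sparse(generations, vocab_size=10)
# rouge[i, j] is the unigram-overlap F1, e.g. rouge[0, 1] == 0.75

# pick_mat then selects, for each prompt, the candidate with the highest
# mean similarity to the others (optionally reward-weighted).
print(np.round(ucs, 2))
print(np.round(rouge, 2))
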
/qalign/utils/pred.py:
--------------------------------------------------------------------------------
1 | from collections import Counter
2 | import re
3 | from functools import partial
4 | from tqdm import tqdm
5 | from multiprocessing import Pool
6 | import numpy as np
7 |
8 |
9 | from typing import *
10 |
11 | from transformers import AutoTokenizer
12 |
13 | ## quest
14 | from quest.utils.list import chunked
15 |
16 | ## qalign
17 | from qalign.utils.mbr import mbr_pick_progression, weighted_mbr_pick_progression
18 | from qalign.utils.math import (
19 | get_last_number,
20 | get_last_math,
21 | get_last_option,
22 | generate_axis,
23 | )
24 |
25 | ## expkit
26 | from expkit import (
27 | Exp,
28 | Evalutor,
29 | )
30 |
31 |
32 | def get_strategy(
33 | strategy: str,
34 | key: str,
35 | p=None,
36 | beta=1.0,
37 | c=1.0,
38 | gaps=64,
39 | exp_rate=False,
40 | **kwargs,
41 | ) -> Evalutor:
42 |
43 | if strategy == "voting":
44 | pick = VotingPick(gaps=gaps, exp_rate=exp_rate, multiprocessing=p, **kwargs)
45 |
46 | elif strategy == "bon":
47 | pick = BonPick(
48 | key=key, gaps=gaps, exp_rate=exp_rate, multiprocessing=p, **kwargs
49 | )
50 |
51 | # setup = setup.filter(lambda x: x.has_eval(key))
52 |
53 | elif strategy == "weighted-voting":
54 | pick = WeightedVotingPick(
55 | key=key,
56 | gaps=gaps,
57 | exp_rate=exp_rate,
58 | multiprocessing=p,
59 | beta=beta,
60 | **kwargs,
61 | )
62 | # setup = setup.filter(lambda x: x.has_eval(key))
63 | elif strategy == "last":
64 | pick = LastPick(gaps=gaps, exp_rate=exp_rate, multiprocessing=p, **kwargs)
65 | elif strategy == "mbr":
66 | pick = MBRPick(gaps=gaps, **kwargs)
67 | elif strategy == "wmbr":
68 | pick = WMBRPick(gaps=gaps, key=key, exp_rate=exp_rate, **kwargs)
69 | elif strategy == "bonresample":
70 | pick = BonResamplePick(
71 | key=key, gaps=gaps, exp_rate=exp_rate, multiprocessing=p, **kwargs
72 | )
73 | else:
74 | raise ValueError("Invalid strategy")
75 |
76 | return pick
77 |
78 |
79 | def get_param_count(exp):
80 | # example : model_path = "meta-llama/Llama-3.1-8B-Instruct"
81 | # or model_path = "allenai/olmo-2-2b-it"
82 | # I need to look for B or b and then get the number before it
83 |
84 | # get the number before B or b
85 | # get the number before B or b
86 |
87 | model_path = exp.meta["model_path"]
88 | match = re.search(r"(\d+)[Bb]", model_path)
89 | if match:
90 | return int(match.group(1))
91 | else:
92 | raise ValueError("No match found")
93 |
94 |
95 | def get_flops_by_scores(exp, key="", strategy="voting"):
96 |
97 | with Pool() as p:
98 | eval_result = get_strategy(
99 | strategy=strategy,
100 | p=p,
101 | exp_rate=True,
102 | gaps=2,
103 | key=key,
104 | ).eval(exp)
105 |
106 | axis = eval_result[0]["axis"]
107 |
108 | tokenizer = AutoTokenizer.from_pretrained(exp.meta["model_path"])
109 | token_counts = []
110 | for instance in exp.instances():
111 |
112 | if exp.meta["variant"] == "ancestral":
113 |
114 | ids = tokenizer([o["text"] for o in instance["outputs"]], padding=False)
115 |
116 | token_counts.append(list(map(len, ids["input_ids"])))
117 | else:
118 |
119 | token_counts.append(
120 | [len(o["completion"][o["index"] :]) for o in instance["outputs"]]
121 | )
122 | # all_criteria.append(x)
123 | token_counts = np.array(token_counts)
124 |
125 | # sample a few points from axis ...
126 |
127 | D_test = []
128 | for i in axis:
129 | D_test.append(np.sum(token_counts[:, :i], axis=1).mean(0))
130 |
131 | ## billion in scientific notation
132 | ##
133 | IFLOPS = [
134 | 2 * 2 * float(get_param_count(exp)) * v for v in D_test
135 | ] # one extra 2 for reward.
136 | return {
137 | "IFLOPS": IFLOPS,
138 | "D_test": D_test,
139 | "scores": eval_result[0]["scores"],
140 | }
141 |
142 |
143 | def get_flops_prev(exp, axis):
144 |
145 | tokenizer = AutoTokenizer.from_pretrained(exp.meta["model_path"])
146 | token_counts = []
147 | for instance in exp.instances():
148 |
149 | if exp.meta["variant"] == "ancestral":
150 |
151 | ids = tokenizer([o["text"] for o in instance["outputs"]], padding=False)
152 |
153 | token_counts.append(list(map(len, ids["input_ids"])))
154 | else:
155 |
156 | token_counts.append(
157 | [len(o["completion"][o["index"] :]) for o in instance["outputs"]]
158 | )
159 | # all_criteria.append(x)
160 | token_counts = np.array(token_counts)
161 |
162 | # sample a few points from axis ...
163 |
164 | D_test = []
165 | for i in axis:
166 | D_test.append(np.sum(token_counts[:, :i], axis=1).mean(0))
167 |
168 | ## billion in scientific notation
169 | ##
170 | IFLOPS = [
171 | 2 * float(get_param_count(exp)) * v for v in D_test
172 |     ]  # ≈ 2 * params * tokens per forward pass.
173 | return {
174 | "IFLOPS": IFLOPS,
175 | "D_test": D_test,
176 | }
177 |
178 |
179 | def real_avg_token_counts(exp, sweep, k=64):
180 | tokenizer = AutoTokenizer.from_pretrained(exp.meta["model_path"])
181 | token_counts = []
182 |
183 | instances = exp.instances(lazy_iterable=True)
184 |
185 | for i, instance in enumerate(
186 | tqdm(instances, desc="Processing Instances", total=k + 1)
187 | ):
188 |
189 | if exp.meta["variant"] == "ancestral":
190 |
191 | ids = tokenizer([o["text"] for o in instance["outputs"]], padding=False)
192 |
193 | token_counts.append(list(map(len, ids["input_ids"])))
194 | else:
195 |
196 | token_counts.append(
197 | [len(o["completion"][o["index"] :]) for o in instance["outputs"]]
198 | )
199 |
200 | if i > k:
201 | break
202 |
203 | token_counts = np.array(token_counts)
204 |
205 | avg_token_counts = np.mean(token_counts, axis=0)
206 |
207 | if exp.meta["variant"] == "ancestral":
208 | if ("bon" in sweep) or ("weighted" in sweep) or ("wmbr" in sweep):
209 | avg_token_counts *= 2
210 |
211 | else: # questrlhf
212 | avg_token_counts *= 2
213 |
214 | return avg_token_counts
215 |
216 |
217 | def fake_avg_token_counts(exp, sweep, d=500):
218 | avg_token_counts = np.array([d] * exp.get("steps")) # * n
219 |
220 | if "DPO" in exp.meta["model_path"]:
221 | avg_token_counts = avg_token_counts * 1.5
222 |
223 | else:
224 |
225 | # sample a few points from axis ...
226 | if exp.meta["variant"] == "ancestral":
227 | if ("bon" in sweep) or ("weighted" in sweep) or ("wmbr" in sweep):
228 | avg_token_counts *= 2
229 |
230 | else: # questrlhf
231 | avg_token_counts[0] *= 2
232 | # avg_token_counts[1:] = avg_token_counts[1:] /
233 |
234 | return avg_token_counts
235 |
236 |
237 | def get_flops(exp, axis, sweep, n=None, fake=True):
238 |
239 | # tokenizer = AutoTokenizer.from_pretrained(exp.meta["model_path"])
240 |
241 | if n is None:
242 | n = exp.meta["n"]
243 |
244 | if fake:
245 | avg_token_counts = fake_avg_token_counts(exp, sweep=sweep)
246 | else:
247 | avg_token_counts = real_avg_token_counts(exp, sweep=sweep)
248 |
249 | D_test = []
250 | for i in axis:
251 | D_test.append(np.sum(avg_token_counts[:i]))
252 |
253 | ## billion in scientific notation
254 | ##
255 | IFLOPS = [
256 | 2 * float(get_param_count(exp)) * v for v in D_test
257 |     ]  # the extra factor of 2 for the reward model is already folded into avg_token_counts.
258 | return {
259 | "IFLOPS": IFLOPS,
260 | "D_test": D_test,
261 | }
262 |
263 |
264 | def softmax(x, temp=1.0):
265 |     x = np.array(x) / temp
266 |     e_x = np.exp(x - np.max(x))  # subtract the max for numerical stability
267 |     return e_x / e_x.sum()
268 |
269 |
270 | def log_softmax(x, temp=1.0):
271 | x = np.array(x)
272 | x = x / temp
273 | x_max = np.max(x)
274 | log_sum_exp = np.log(np.sum(np.exp(x - x_max))) + x_max
275 | return x - log_sum_exp
276 |
277 |
278 | def is_weighted_mode_correct(
279 | instance_and_score, n=512, beta=1.0, extract_func=get_last_number
280 | ):
281 | instance, scores = instance_and_score
282 |
283 | r = scores[:n]
284 | p = softmax(r, temp=beta)
285 | answer = extract_func(instance["input"]["answer"])
286 | outputs = instance["outputs"][:n]
287 |
288 | responses = [
289 | (extract_func(output["text"]), float(pi)) for output, pi in zip(outputs, p)
290 | ]
291 |
292 | response_dict = {}
293 | for response, prob in responses:
294 | if response in response_dict:
295 | response_dict[response] += prob
296 | else:
297 | response_dict[response] = prob
298 |
299 | if answer in response_dict:
300 | correct_estimate = response_dict[answer]
301 | else:
302 | correct_estimate = 0
303 |
304 | max_response = max(response_dict, key=response_dict.get)
305 |
306 | return int(max_response == answer), correct_estimate
307 |
308 |
309 | def get_last(instance, n=512, extract_func=get_last_number):
310 |
311 | outputs = instance["outputs"][:n]
312 | responses = [extract_func(output["text"]) for output in outputs]
313 | answer = extract_func(instance["input"]["answer"])
314 |
315 | return int(responses[-1] == answer), 0.0
316 |
317 |
318 | def is_mode_correct(instance, n=512, burnin=0, extract_func=get_last_number):
319 |
320 | outputs = instance["outputs"][burnin:n]
321 |
322 | answer = extract_func(instance["input"]["answer"])
323 |
324 | responses = [extract_func(output["text"]) for output in outputs]
325 |
326 | # compute the most common element in accepted
327 |
328 | c = Counter(responses)
329 |
330 | if answer in c:
331 | correct_estimate = c[answer] / len(responses)
332 |
333 | else:
334 | correct_estimate = 0
335 |
336 | if len(c) == 0:
337 | most_common = "NaN"
338 | else:
339 | most_common = c.most_common(1)[0][0]
340 |
341 | return int(most_common == answer), correct_estimate
342 |
343 |
344 | def is_max_correct(instance_and_score, n=512):
345 | scores, gt = instance_and_score
346 | max_index = np.argmax(scores["scores"][:n])
347 |
348 | return gt["scores"][max_index]
349 |
350 |
351 | # max_response = extract_func(instance["outputs"][max_index]["text"])
352 | # # compute the most common element in accepted
353 | # return int(max_response == extract_func(instance["input"]["answer"]))
354 |
355 |
356 | def is_resample_correct(
357 | instance_and_score, n=512, extract_func=get_last_number, r=128, trials=16
358 | ):
359 | instance, scores = instance_and_score
360 | outputs = instance["outputs"][:n]
361 | scores = np.array(scores[:n])
362 | answer = extract_func(instance["input"]["answer"])
363 |
364 | def shuffle(x):
365 | np.random.shuffle(x)
366 | return x
367 |
368 | # sample 16 times a sublist of scores of size "r"
369 |
370 | inds = [shuffle(np.arange(n))[:r].tolist() for _ in range(trials)]
371 |
372 | pick_ind = [indsi[np.argmax(scores[indsi])] for indsi in inds]
373 | responses = [extract_func(outputs[i]["text"]) for i in pick_ind]
374 |
375 | c = Counter(responses)
376 |
377 | if answer in c:
378 | correct_estimate = c[answer] / len(responses)
379 |
380 | else:
381 | correct_estimate = 0
382 |
383 | if len(c) == 0:
384 | most_common = "NaN"
385 | else:
386 | most_common = c.most_common(1)[0][0]
387 |
388 | return int(most_common == answer)
389 |
390 |
391 | def join_instances_and_repeat_accepts(instances, is_quest=True):
392 | return _join_instances_and_repeat_accepts(instances, is_quest=is_quest)[0]
393 |
394 |
395 | def join_instances_and_repeat_accepts_with_scores(
396 | instances,
397 | evals_scores,
398 | is_quest=True,
399 | ):
400 | return _join_instances_and_repeat_accepts(
401 | instances, is_quest=is_quest, evals_scores=evals_scores
402 | )
403 |
404 |
405 | def _join_instances_and_repeat_accepts(instances, is_quest=True, evals_scores=None):
406 | inputs = instances[0]["input"]
407 | outputs = []
408 | scores = []
409 |
410 | if is_quest:
411 | outputs_cat = []
412 | scores_cat = []
413 | for i, es in zip(instances, evals_scores or [None] * len(instances)):
414 | accepted_outputs = []
415 | accepted_scores = []
416 | last_accepted = None
417 | for output, s in zip(
418 | i["outputs"], es["scores"] if es else [None] * len(i["outputs"])
419 | ):
420 | if output["accept"]:
421 | accepted_outputs.append(output)
422 | if es:
423 | accepted_scores.append(s)
424 | last_accepted = (output, s)
425 | elif last_accepted is not None:
426 | accepted_outputs.append(last_accepted[0])
427 | if es:
428 | accepted_scores.append(last_accepted[1])
429 |
430 | outputs_cat.append(accepted_outputs)
431 | if es:
432 | scores_cat.append(accepted_scores)
433 |
434 | for i in range(len(outputs_cat[0])):
435 | for o in outputs_cat:
436 | outputs.append(o[i])
437 | if evals_scores:
438 | for s in scores_cat:
439 | scores.append(s[i])
440 | else:
441 | outputs = [j for i in instances for j in i["outputs"]]
442 | if evals_scores:
443 | scores = [s for es in evals_scores for s in es["scores"]]
444 |
445 | # print(len(outputs))
446 | return {"input": inputs, "outputs": outputs}, scores
447 |
448 |
449 | class VotingPick(Evalutor):
450 |
451 | def __init__(
452 | self,
453 | gaps=2,
454 | exp_rate=True,
455 | max_num_chains=None,
456 | multiprocessing=None,
457 | burnin=0,
458 | extract="lastnumber",
459 | n=None,
460 | **kwargs,
461 | ):
462 |
463 | super().__init__(
464 | extract + "-voting" if n is None else extract + "-voting-" + str(n)
465 | )
466 | self.gaps = gaps
467 | self.pmap = multiprocessing.map if multiprocessing else map
468 | self.exp_rate = exp_rate
469 | self.max_num_chains = max_num_chains
470 | self.burnin = burnin
471 | self.n = n
472 | self.chunk_size = 64
473 |
474 | if extract == "lastnumber":
475 | self.extract = get_last_number
476 | elif extract == "lastmath":
477 | self.extract = get_last_math
478 |
479 | elif extract == "lastoption":
480 | self.extract = get_last_option
481 |         else:
482 |             raise ValueError("Invalid extract function")
483 |
484 | def eval(self, exp: Exp):
485 |
486 | if self.gaps > 0:
487 | if not self.exp_rate:
488 | x_axis = generate_axis(exp.meta["steps"] + 1, self.gaps)
489 | else:
490 | x_axis = generate_axis(exp.meta["steps"] + 1, exp.meta["steps"])
491 |
492 | else:
493 | x_axis = [exp.meta["steps"]]
494 |
495 | # num_chains = exp.meta.get("num_chains", 1)
496 | # max_num_chains = self.max_num_chains if self.max_num_chains else num_chains
497 |
498 | # instances are in groups of num_chains .. i.e.
499 | # [0,0,0,1,1,1,2,2,2,3,3,3,4,4,4]
500 | # I want to group into :
501 | # [[0,0,0],[1,1,1],[2,2,2],[3,3,3],[4,4,4]]
502 | # so that I can compute the mean of the estimates
503 | # for each group
504 |
505 | accs_full = {xi: [] for xi in x_axis}
506 |
507 | c = 0
508 | for instance_chunk in tqdm(
509 | chunked(exp.instances(lazy_iterable=True), self.chunk_size)
510 | ):
511 |
512 | c += len(instance_chunk)
513 |
514 |             if self.n is not None and c > self.n:
515 |                 break
516 |
517 | is_quest = "quest" in exp.meta["variant"]
518 |
519 | # instances = total_instances
520 |
521 | instances = [
522 | join_instances_and_repeat_accepts([i], is_quest=is_quest)
523 | for i in instance_chunk
524 | ]
525 | # print(x_axis)
526 | # print(len(instances))
527 |
528 | for i in x_axis:
529 |
530 | results = self.pmap(
531 | partial(
532 | is_mode_correct,
533 | n=i * 1,
534 | burnin=0,
535 | extract_func=self.extract,
536 | ),
537 | instances,
538 | )
539 |
540 | accs, estimates = zip(*results)
541 |
542 | accs_full[i].extend(accs)
543 |
544 | # scores.append({"axis": x_axis, "scores": results})
545 | # x_results.append(float(np.mean(list((accs)))))
546 | # n_estimates.append((estimates))
547 | # accs_full.append(accs)
548 |
549 | # final_estimate = n_estimates[-1]
550 | # estimate = [float(np.mean(np.abs(i - final_estimate))) for i in n_estimates]
551 |
552 | results = [
553 | {
554 | "axis": x_axis,
555 | "scores": [float(np.mean(accs_full[i])) for i in x_axis],
556 | "accs": [accs_full[i] for i in x_axis],
557 | }
558 | ]
559 | return results
560 |
561 |
562 | class LastPick(Evalutor):
563 |
564 | def __init__(
565 | self,
566 | gaps=2,
567 | exp_rate=True,
568 | max_num_chains=None,
569 | multiprocessing=None,
570 | extract="lastnumber",
571 | n=None,
572 | **kwargs,
573 | ):
574 | super().__init__("last-pick" if n is None else "last-pick-" + str(n))
575 | self.gaps = gaps
576 | self.pmap = multiprocessing.map if multiprocessing else map
577 | self.exp_rate = exp_rate
578 | self.max_num_chains = max_num_chains
579 | self.n = n
580 | self.extract_key = extract
581 |
582 | if extract == "lastnumber":
583 | self.extract = get_last_number
584 | elif extract == "lastmath":
585 | self.extract = get_last_math
586 | elif extract == "lastoption":
587 | self.extract = get_last_option
588 |         else:
589 |             raise ValueError("Invalid extract function")
590 |
591 | def eval(self, exp: Exp):
592 |
593 | if self.gaps > 0:
594 | if not self.exp_rate:
595 | x_axis = generate_axis(exp.meta["steps"] + 1, self.gaps)
596 | else:
597 | x_axis = generate_axis(exp.meta["steps"] + 1, exp.meta["steps"])
598 |
599 | else:
600 | x_axis = [exp.meta["steps"]]
601 |
602 | num_chains = exp.meta.get("num_chains", 1)
603 | max_num_chains = self.max_num_chains if self.max_num_chains else num_chains
604 |
605 | # instances are in groups of num_chains .. i.e.
606 | # [0,0,0,1,1,1,2,2,2,3,3,3,4,4,4]
607 | # I want to group into :
608 | # [[0,0,0],[1,1,1],[2,2,2],[3,3,3],[4,4,4]]
609 | # so that I can compute the mean of the estimates
610 | # for each group
611 |
612 | is_quest = "quest" in exp.meta["variant"]
613 |         instances = exp.instances()
614 | instances = [
615 | join_instances_and_repeat_accepts(
616 | instances[i : i + num_chains][:max_num_chains], is_quest=is_quest
617 | )
618 | for i in range(0, len(instances), num_chains)
619 | ]
620 | # print(x_axis)
621 | # print(len(instances))
622 |
623 | x_results = []
624 | n_estimates = []
625 | accs_full = []
626 | for i in tqdm(x_axis):
627 | is_quest = "quest" in exp.meta["variant"]
628 |
629 | results = self.pmap(
630 | partial(
631 | get_last,
632 | n=i * max_num_chains,
633 | extract_func=self.extract,
634 | ),
635 | instances,
636 | )
637 |
638 | accs, estimates = zip(*results)
639 |
640 | # scores.append({"axis": x_axis, "scores": results})
641 | x_results.append(float(np.mean(list((accs)))))
642 | n_estimates.append((estimates))
643 | accs_full.append(accs)
644 |
645 | # final_estimate = n_estimates[-1]
646 | # estimate = [float(np.mean(np.abs(i - final_estimate))) for i in n_estimates]
647 |
648 | return [
649 | {
650 | "axis": x_axis,
651 | "scores": x_results,
652 | "estimates": n_estimates,
653 | "accs": accs_full,
654 | }
655 | ]
656 |
657 |
658 | class BonPick(Evalutor):
659 |
660 | def __init__(
661 | self,
662 | key,
663 | gaps=2,
664 | exp_rate=True,
665 | multiprocessing=None,
666 | extract="lastnumber",
667 | n=None,
668 | **kwargs,
669 | ):
670 | super().__init__(key + "-bon" if n is None else key + "-bon-" + str(n))
671 | self.reward_key = key
672 | self.gaps = gaps
673 | self.pmap = multiprocessing.map if multiprocessing else map
674 | self.exp_rate = exp_rate
675 | self.n = n
676 | self.extract_key = extract
677 |
678 | def eval(self, exp: Exp):
679 |
680 | assert exp.has_eval(
681 | self.reward_key
682 | ), f"Experiment does not have {self.reward_key} eval"
683 |
684 | if self.gaps > 0:
685 | if not self.exp_rate:
686 | x_axis = generate_axis(exp.meta["steps"] + 1, self.gaps)
687 | else:
688 | x_axis = generate_axis(exp.meta["steps"] + 1, exp.meta["steps"])
689 |
690 | else:
691 | x_axis = [exp.meta["steps"]]
692 |
693 | # instances = exp.instances()
694 |
695 | if self.n is not None:
696 | n = self.n
697 | else:
698 | n = exp.meta["n"]
699 | # instances = instances[: self.n]
700 |
701 | x_results = []
702 | for i in tqdm(x_axis):
703 | # spawn pool of workers
704 |
705 | results = list(
706 | self.pmap(
707 | partial(
708 | is_max_correct,
709 | n=i,
710 | ),
711 | zip(
712 | exp.get_eval(self.reward_key)[:n],
713 | exp.get_eval(self.extract_key)[:n],
714 | ),
715 | )
716 | )
717 |
718 | # scores.append({"axis": x_axis, "scores": results})
719 | x_results.append(float(np.mean(list((results)))))
720 |
721 | return [{"axis": x_axis, "scores": x_results}]
722 |
723 |
724 | class BonResamplePick(Evalutor):
725 |
726 | def __init__(
727 | self,
728 | key,
729 | gaps=2,
730 | exp_rate=True,
731 | multiprocessing=None,
732 | r=128,
733 | trials=32,
734 | extract="lastnumber",
735 | n=None,
736 | **kwargs,
737 | ):
738 | super().__init__(
739 | key + "-bonresample-" + str(r) + "-" + str(trials)
740 | if n is None
741 | else key + "-bonresample-" + str(r) + "-" + str(trials) + "-" + str(n)
742 | )
743 | self.reward_key = key
744 | self.gaps = gaps
745 | self.pmap = multiprocessing.map if multiprocessing else map
746 | self.exp_rate = exp_rate
747 | self.r = r
748 | self.trials = trials
749 | self.n = n
750 |
751 | if extract == "lastnumber":
752 | self.extract = get_last_number
753 | elif extract == "lastmath":
754 | self.extract = get_last_math
755 | elif extract == "lastoption":
756 | self.extract = get_last_option
757 |         else:
758 |             raise ValueError("Invalid extract function")
759 |
760 | def eval(self, exp: Exp):
761 |
762 | assert exp.has_eval(
763 | self.reward_key
764 | ), f"Experiment does not have {self.reward_key} eval"
765 |
766 | if self.gaps > 0:
767 | if not self.exp_rate:
768 | x_axis = generate_axis(exp.meta["steps"] + 1, self.gaps)
769 | else:
770 | x_axis = generate_axis(exp.meta["steps"] + 1, exp.meta["steps"])
771 |
772 | else:
773 | x_axis = [exp.meta["steps"]]
774 |
775 | is_quest = "quest" in exp.meta["variant"]
776 | num_chains = exp.meta.get("num_chains", 1)
777 | instances = exp.instances()
778 | if self.n is not None:
779 | instances = instances[: self.n]
780 | evals_scores = exp.get_eval(self.reward_key)
781 | instances_and_scores = [
782 | join_instances_and_repeat_accepts_with_scores(
783 | instances[i : i + num_chains],
784 | evals_scores=evals_scores[i : i + num_chains],
785 | is_quest=is_quest,
786 | )
787 | for i in range(0, len(instances), num_chains)
788 | ]
789 |
790 | x_results = []
791 | for i in tqdm(x_axis):
792 | # spawn pool of workers
793 |
794 | results = list(
795 | self.pmap(
796 | partial(
797 | is_resample_correct,
798 | n=i,
799 | extract_func=self.extract,
800 | r=self.r,
801 | trials=self.trials,
802 | ),
803 | instances_and_scores,
804 | )
805 | )
806 |
807 | # scores.append({"axis": x_axis, "scores": results})
808 | x_results.append(float(np.mean(list((results)))))
809 |
810 | return [{"axis": x_axis, "scores": x_results}]
811 |
812 |
813 | class WeightedVotingPick(Evalutor):
814 |
815 | def __init__(
816 | self,
817 | key,
818 | gaps=2,
819 | exp_rate=False,
820 | multiprocessing=False,
821 | beta=1.0,
822 | extract="lastnumber",
823 | n=None,
824 | **kwargs,
825 | ):
826 | super().__init__(
827 | key + "-weighted-voting-" + str(beta).replace(".", ";")
828 | if n is None
829 | else key + "-weighted-voting-" + str(beta).replace(".", ";") + "-" + str(n)
830 | )
831 | self.reward_key = key
832 | self.gaps = gaps
833 | self.exp_rate = exp_rate
834 | self.pmap = multiprocessing.map if multiprocessing else map
835 | self.beta = beta
836 | self.n = n
837 | self.chunk_size = 64
838 |
839 | if extract == "lastnumber":
840 | self.extract = get_last_number
841 | elif extract == "lastmath":
842 | self.extract = get_last_math
843 | elif extract == "lastoption":
844 | self.extract = get_last_option
845 |         else:
846 |             raise ValueError("Invalid extract function")
847 |
848 | def eval(self, exp: Exp):
849 |
850 | assert exp.has_eval(
851 | self.reward_key
852 | ), f"Experiment does not have {self.reward_key} eval"
853 |
854 | if self.gaps > 0:
855 | if not self.exp_rate:
856 | x_axis = generate_axis(exp.meta["steps"] + 1, self.gaps)
857 | else:
858 | x_axis = generate_axis(exp.meta["steps"] + 1, exp.meta["steps"])
859 |
860 | else:
861 | x_axis = [exp.meta["steps"]]
862 |
863 | accs_full = {xi: [] for xi in x_axis}
864 |
865 | evals_scores = exp.get_eval(self.reward_key)
866 | it_idx = 0
867 | for instance_chunk in tqdm(
868 | chunked(exp.instances(lazy_iterable=True), self.chunk_size)
869 | ):
870 |
871 | evals_scores_chunk = evals_scores[it_idx : it_idx + len(instance_chunk)]
872 | it_idx += len(instance_chunk)
873 |
874 | is_quest = "quest" in exp.meta["variant"]
875 |
876 | instances_and_scores = [
877 | join_instances_and_repeat_accepts_with_scores(
878 | [ich], evals_scores=[ech], is_quest=is_quest
879 | )
880 | for ich, ech in zip(instance_chunk, evals_scores_chunk)
881 | ]
882 |
883 | for i in x_axis:
884 |
885 | results = self.pmap(
886 | partial(
887 | is_weighted_mode_correct, # is_weighted_mode_correct
888 | n=i * 1,
889 | extract_func=self.extract,
890 | beta=self.beta,
891 | ),
892 | instances_and_scores,
893 | )
894 |
895 | accs, estimates = zip(*results)
896 |
897 | accs_full[i].extend(accs)
898 |
899 | results = [
900 | {
901 | "axis": x_axis,
902 | "scores": [float(np.mean(accs_full[i])) for i in x_axis],
903 | "accs": [accs_full[i] for i in x_axis],
904 | }
905 | ]
906 | return results
907 |
908 |
909 | class MBRPick(Evalutor):
910 |
911 | def __init__(
912 | self,
913 | gaps=2,
914 | extract="lastnumber",
915 | n=None,
916 | metric="rouge",
917 | **kwargs,
918 | ):
919 | super().__init__(f"{metric}-mbr" if n is None else f"{metric}-mbr-" + str(n))
920 | self.gaps = gaps
921 |
922 | self.metric = metric
923 | self.n = n
924 |         print("Filename:", self.eval_name)
925 |
926 | self.extract_key = extract
927 |
928 | def eval(self, exp: Exp):
929 |
930 | if not self.n:
931 | n = exp.meta["n"]
932 | else:
933 | n = self.n
934 |
935 | pairwise_pick_results = mbr_pick_progression(
936 | exp,
937 | compute_in_parallel=True,
938 | k=self.gaps,
939 | n=n,
940 | metric=self.metric,
941 | )
942 |
943 | gt = exp.get_eval(self.extract_key)[:n]
944 |
945 | scores = {}
946 | for ax, preds in pairwise_pick_results["preds"].items():
947 | pred_scores = [gt[j]["scores"][pred] for j, pred in enumerate(preds)]
948 | scores[ax] = float(np.mean(pred_scores))
949 |
950 | output = {
951 | "axis": pairwise_pick_results["axis"],
952 | "scores": [scores[ax] for ax in pairwise_pick_results["axis"]],
953 | }
954 |
955 | print(output)
956 |
957 | return [output]
958 |
959 |
960 | class WMBRPick(Evalutor):
961 |
962 | def __init__(
963 | self,
964 | key,
965 | gaps=2,
966 | extract="lastnumber",
967 | n=None,
968 | metric="rouge",
969 | **kwargs,
970 | ):
971 | super().__init__(f"{metric}-wmbr" if n is None else f"{metric}-wmbr-" + str(n))
972 | self.gaps = gaps
973 | self.metric = metric
974 | self.n = n
975 | self.reward_key = key
976 |         print("Filename:", self.eval_name)
977 |
978 | self.extract_key = extract
979 |
980 | def eval(self, exp: Exp):
981 |
982 | if not self.n:
983 | n = exp.meta["n"]
984 | else:
985 | n = self.n
986 |
987 | pairwise_pick_results = weighted_mbr_pick_progression(
988 | exp,
989 | reward_key=self.reward_key,
990 | compute_in_parallel=True,
991 | k=self.gaps,
992 | n=n,
993 | metric=self.metric,
994 | )
995 |
996 | gt = exp.get_eval(self.extract_key)[:n]
997 |
998 | scores = {}
999 | for ax, preds in pairwise_pick_results["preds"].items():
1000 | pred_scores = [gt[j]["scores"][pred] for j, pred in enumerate(preds)]
1001 | scores[ax] = float(np.mean(pred_scores))
1002 |
1003 | output = {
1004 | "axis": pairwise_pick_results["axis"],
1005 | "scores": [scores[ax] for ax in pairwise_pick_results["axis"]],
1006 | }
1007 |
1008 | print(output)
1009 |
1010 | return [output]
1011 |
--------------------------------------------------------------------------------
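Note on the FLOPs accounting above: `get_flops` prices test-time compute at roughly 2 FLOPs per model parameter per generated token, and the `*= 2` in `real_avg_token_counts` / `fake_avg_token_counts` accounts for every sample also being scored by a reward model of comparable size. A minimal sketch of that arithmetic, with made-up numbers purely for illustration:

```python
# Rough inference-FLOPs accounting mirroring get_flops: ~2 * params * tokens per forward pass.
num_params = 8e9           # e.g. an 8B-parameter model (illustrative)
tokens_per_sample = 500    # average generated length per sample (illustrative)
num_samples = 128          # samples drawn per prompt

generation_flops = 2 * num_params * tokens_per_sample * num_samples
# A same-size reward model scoring every sample effectively doubles the token count.
total_flops = 2 * generation_flops

print(f"{total_flops:.3e}")  # ~2.0e+15 FLOPs for these illustrative numbers
```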
/resume_experiment.py:
--------------------------------------------------------------------------------
1 | ## qalign
2 | from qalign.utils.experiment import run_experiment
3 |
4 | ## expkit
5 | from expkit.storage import DiskStorage
6 | from expkit import Exp
7 |
8 |
9 | def main(
10 | experiment_name="",
11 | save_path: str = "llama3.2-outputs/",
12 | reward_model_batch_size: int = 128,
13 | gpu_memory_utilization: float = 0.95,
14 | reward_device=1,
15 | device_count: int = 1,
16 | remote=False,
17 | ):
18 |
19 | storage = DiskStorage(save_path, "rw")
20 |
21 | if storage.exists(experiment_name):
22 | experiment = Exp.load(storage=storage, name=experiment_name)
23 |
24 | run_experiment(
25 | experiment=experiment,
26 | gpu_memory_utilization=gpu_memory_utilization,
27 | device_count=device_count,
28 | reward_model_batch_size=reward_model_batch_size,
29 | reward_device=reward_device,
30 | remote=remote,
31 | )
32 |
33 | else:
34 | raise ValueError(f"Experiment {experiment_name} does not exist.")
35 |
36 |
37 | if __name__ == "__main__":
38 | import fire
39 |
40 | fire.Fire(main)
41 |
--------------------------------------------------------------------------------
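`resume_experiment.py` exposes `main` through `fire`, so the same resume step can be invoked from Python as well as from the command line; a minimal sketch, where the experiment id is a placeholder for one created by `launch_experiment.py`:

```python
# Hypothetical programmatic use of the entry point above; "<experiment-id>" is a placeholder.
from resume_experiment import main

main(
    experiment_name="<experiment-id>",
    save_path="llama3.2-outputs/",
    reward_model_batch_size=128,
    device_count=1,
)
```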
/resume_experiment_remote.py:
--------------------------------------------------------------------------------
1 | ## qalign
2 | from qalign.utils.experiment import run_experiment_remote
3 |
4 | ## expkit
5 | from expkit.storage import DiskStorage
6 | from expkit import Exp
7 |
8 |
9 | # 280bd4d8-9e11-48d4-812b-fd2c5666684d
10 | def main(
11 | experiment_name="",
12 | save_path: str = "llama3.2-outputs/",
13 | ):
14 |
15 | storage = DiskStorage(save_path, "rw")
16 |
17 | if storage.exists(experiment_name):
18 | experiment = Exp.load(storage=storage, name=experiment_name)
19 |
20 | run_experiment_remote(
21 | experiment=experiment,
22 | )
23 |
24 | else:
25 | raise ValueError(f"Experiment {experiment_name} does not exist.")
26 |
27 |
28 | if __name__ == "__main__":
29 | import fire
30 |
31 | fire.Fire(main)
32 |
--------------------------------------------------------------------------------
/scripts/create_all_general_experiments.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 | MODELPATH=allenai/Llama-3.1-Tulu-3-8B-SFT
5 | DIR="outputs/"
6 |
7 | TEMP=1.0
8 | N=128
9 |
10 | #### QUEST
11 | RM_PATH=allenai/Llama-3.1-Tulu-3-8B-RM
12 | BETA=0.5
13 | STEPS=1024
14 |
15 | DATASET="HuggingFaceH4/MATH-500"
16 | PROMPT="Solve the following math problem step-by-step: {prompt}\n\nPresent the answer in LaTex format: \\boxed{Your answer}"
17 |
18 | TOKENIZERS_PARALLELISM=false python launch_experiment.py --variant "quest-rlhf" --reward_type contextual --split "test" --num_chains 1 --batch_size 32 --beta $BETA --steps $STEPS --temperature $TEMP --n $N --save_path $DIR --gpu_memory_utilization 0.95 --model_path $MODELPATH --reward_model_path $RM_PATH --reward_model_batch_size 32 --max_new_tokens 800 --max_prompt_length 1200 --dataset_path $DATASET --prompt_template "$PROMPT"
19 |
20 | DATASET="openai/gsm8k"
21 | PROMPT="Solve the following grade school math problem step-by-step: {prompt}"
22 | TOKENIZERS_PARALLELISM=false python launch_experiment.py --variant "quest-rlhf" --reward_type contextual --split "test" --num_chains 1 --batch_size 32 --beta $BETA --steps $STEPS --temperature $TEMP --n $N --save_path $DIR --gpu_memory_utilization 0.95 --model_path $MODELPATH --reward_model_path $RM_PATH --reward_model_batch_size 32 --max_new_tokens 800 --max_prompt_length 1200 --dataset_path $DATASET --prompt_template "$PROMPT"
23 |
24 | DATASET="google/IFEval"
25 | TOKENIZERS_PARALLELISM=false python launch_experiment.py --variant "quest-rlhf" --reward_type contextual --split "train" --num_chains 1 --batch_size 32 --beta $BETA --steps $STEPS --temperature $TEMP --n $N --save_path $DIR --gpu_memory_utilization 0.95 --model_path $MODELPATH --reward_model_path $RM_PATH --reward_model_batch_size 32 --max_new_tokens 800 --max_prompt_length 1200 --dataset_path $DATASET
26 |
27 | DATASET="edinburgh-dawg/mmlu-redux"
28 | TOKENIZERS_PARALLELISM=false python launch_experiment.py --variant "quest-rlhf" --reward_type contextual --split "test" --num_chains 1 --batch_size 32 --beta $BETA --steps $STEPS --temperature $TEMP --n $N --save_path $DIR --gpu_memory_utilization 0.95 --model_path $MODELPATH --reward_model_path $RM_PATH --reward_model_batch_size 32 --max_new_tokens 800 --max_prompt_length 1200 --dataset_path $DATASET
29 |
30 |
31 | ## ANCESTRAL
32 | DATASET="HuggingFaceH4/MATH-500"
33 | PROMPT="Solve the following math problem step-by-step: {prompt}\n\nPresent the answer in LaTex format: \\boxed{Your answer}"
34 |
35 | TOKENIZERS_PARALLELISM=false python launch_experiment.py --variant "ancestral" --split "test" --batch_size 1 --steps $STEPS --temperature $TEMP --n $N --save_path $DIR --gpu_memory_utilization 0.95 --model_path $MODELPATH --max_new_tokens 800 --max_prompt_length 1200 --dataset_path $DATASET --prompt_template "$PROMPT"
36 |
37 | DATASET="openai/gsm8k"
38 | PROMPT="Solve the following grade school math problem step-by-step: {prompt}"
39 |
40 | TOKENIZERS_PARALLELISM=false python launch_experiment.py --variant "ancestral" --split "test" --batch_size 1 --steps $STEPS --temperature $TEMP --n $N --save_path $DIR --gpu_memory_utilization 0.95 --model_path $MODELPATH --max_new_tokens 800 --max_prompt_length 1200 --dataset_path $DATASET --prompt_template "$PROMPT"
41 |
42 | DATASET="google/IFEval"
43 | TOKENIZERS_PARALLELISM=false python launch_experiment.py --variant "ancestral" --split "train" --batch_size 1 --steps $STEPS --temperature $TEMP --n $N --save_path $DIR --gpu_memory_utilization 0.95 --model_path $MODELPATH --max_new_tokens 800 --max_prompt_length 1200 --dataset_path $DATASET
44 |
45 | DATASET="edinburgh-dawg/mmlu-redux"
46 | TOKENIZERS_PARALLELISM=false python launch_experiment.py --variant "ancestral" --split "test" --batch_size 1 --steps $STEPS --temperature $TEMP --n $N --save_path $DIR --gpu_memory_utilization 0.95 --model_path $MODELPATH --max_new_tokens 800 --max_prompt_length 1200 --dataset_path $DATASET
47 |
48 | ## DPO ANCESTRAL
49 | MODELPATH=allenai/Llama-3.1-Tulu-3-8B-DPO
50 |
51 | DATASET="HuggingFaceH4/MATH-500"
52 | PROMPT="Solve the following math problem step-by-step: {prompt}\n\nPresent the answer in LaTex format: \\boxed{Your answer}"
53 |
54 | TOKENIZERS_PARALLELISM=false python launch_experiment.py --variant "ancestral" --split "test" --batch_size 1 --steps $STEPS --temperature $TEMP --n $N --save_path $DIR --gpu_memory_utilization 0.95 --model_path $MODELPATH --max_new_tokens 800 --max_prompt_length 1200 --dataset_path $DATASET --prompt_template "$PROMPT"
55 |
56 | DATASET="openai/gsm8k"
57 | PROMPT="Solve the following grade school math problem step-by-step: {prompt}"
58 |
59 | TOKENIZERS_PARALLELISM=false python launch_experiment.py --variant "ancestral" --split "test" --batch_size 1 --steps $STEPS --temperature $TEMP --n $N --save_path $DIR --gpu_memory_utilization 0.95 --model_path $MODELPATH --max_new_tokens 800 --max_prompt_length 1200 --dataset_path $DATASET --prompt_template "$PROMPT"
60 |
61 |
62 | DATASET="google/IFEval"
63 | TOKENIZERS_PARALLELISM=false python launch_experiment.py --variant "ancestral" --split "train" --batch_size 1 --steps $STEPS --temperature $TEMP --n $N --save_path $DIR --gpu_memory_utilization 0.95 --model_path $MODELPATH --max_new_tokens 800 --max_prompt_length 1200 --dataset_path $DATASET
64 |
65 | DATASET="edinburgh-dawg/mmlu-redux"
66 | TOKENIZERS_PARALLELISM=false python launch_experiment.py --variant "ancestral" --split "test" --batch_size 1 --steps $STEPS --temperature $TEMP --n $N --save_path $DIR --gpu_memory_utilization 0.95 --model_path $MODELPATH --max_new_tokens 800 --max_prompt_length 1200 --dataset_path $DATASET
67 |
68 |
69 | # Loop through all files in the directory
70 | for file in "$DIR"/*; do
71 | # Skip if it's a directory
72 | if [ -f "$file" ]; then
73 | # Get just the filename without path
74 | fileid=$(basename "$file")
75 |
76 | echo "Processing: $fileid"
77 |
78 | # Launch Python script with the filename as experiment_name
79 | python resume_experiment.py --experiment_name "$fileid" --save_path "$DIR"
80 |
81 | echo ""
82 | fi
83 | done
84 |
85 | echo "All experiments have been processed."
--------------------------------------------------------------------------------
/scripts/create_all_task_experiments.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 | MODELPATH=meta-llama/Llama-3.1-8B-Instruct
5 | DIR="outputs-task/"
6 | RM_PATH="/gscratch/ark/graf/LLaMA-Factory/saves/llama3/8b/full/reward/"
7 | TEMP=1.0
8 | N=128
9 | STEPS=4096
10 | BATCH_SIZE=128
11 |
12 | ## ANCESTRAL
13 | DATASET="openai/gsm8k"
14 | TOKENIZERS_PARALLELISM=false python launch_experiment.py --variant "ancestral" --split "test" --batch_size 1 --steps $STEPS --temperature $TEMP --n $N --save_path $DIR --gpu_memory_utilization 0.95 --model_path $MODELPATH --max_new_tokens 800 --max_prompt_length 1200 --dataset_path $DATASET
15 |
16 | ## QUEST
17 | TOKENIZERS_PARALLELISM=false python launch_experiment.py --variant "quest-rlhf" --reward_type value --split "test" --num_chains 1 --batch_size $BATCH_SIZE --beta $BETA --steps $STEPS --temperature $TEMP --n $N --save_path $DIR --gpu_memory_utilization 0.95 --model_path $MODELPATH --reward_model_path $RM_PATH --reward_model_batch_size 32 --max_new_tokens 800 --max_prompt_length 1200 --dataset_path $DATASET
18 |
19 | DATASET="apple/GSM-Symbolic-p1"
20 | ## ANCESTRAL
21 | TOKENIZERS_PARALLELISM=false python launch_experiment.py --variant "ancestral" --split "test" --batch_size 1 --steps $STEPS --temperature $TEMP --n $N --save_path $DIR --gpu_memory_utilization 0.95 --model_path $MODELPATH --max_new_tokens 800 --max_prompt_length 1200 --dataset_path $DATASET
22 |
23 | ## QUEST
24 | TOKENIZERS_PARALLELISM=false python launch_experiment.py --variant "quest-rlhf" --reward_type value --split "test" --num_chains 1 --batch_size $BATCH_SIZE --beta $BETA --steps $STEPS --temperature $TEMP --n $N --save_path $DIR --gpu_memory_utilization 0.95 --model_path $MODELPATH --reward_model_path $RM_PATH --reward_model_batch_size 32 --max_new_tokens 800 --max_prompt_length 1200 --dataset_path $DATASET
25 |
--------------------------------------------------------------------------------
/scripts/run_eval.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 | directories=("outputs/" "outputs-task/")
5 |
6 | echo "Comparing responses to GT for all datasets"
7 |
8 | DIR="outputs/"
9 |
10 | RM="lastnumber"
11 | query_args='{"dataset":"openai/gsm8k"}'
12 | python eval.py --base_dir "$DIR" --reward_model_path $RM --query_args $query_args
13 |
14 | RM="lastmath"
15 | query_args='{"dataset":"HuggingFaceH4/MATH-500"}'
16 | python eval.py --base_dir "$DIR" --reward_model_path $RM --query_args $query_args
17 |
18 |
19 | RM="ifeval"
20 | query_args='{"dataset":"google/IFEval"}'
21 | python eval.py --base_dir "$DIR" --reward_model_path $RM --query_args $query_args
22 |
23 |
24 | RM="lastoption"
25 |
26 | query_args='{"dataset":"edinburgh-dawg/mmlu-redux"}'
27 | python eval.py --base_dir "$DIR" --reward_model_path $RM --query_args $query_args
28 |
29 | query_args='{"dataset":"truthfulqa/truthful_qa"}'
30 | python eval.py --base_dir "$DIR" --reward_model_path $RM --query_args $query_args
31 |
32 |
33 | DIR="outputs-task/"
34 |
35 | RM="lastnumber"
36 | query_args='{}'
37 | python eval.py --base_dir "$DIR" --reward_model_path $RM --query_args $query_args
--------------------------------------------------------------------------------
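The `query_args` JSON above is passed to `eval.py` without quotes, which works only because these particular strings contain no spaces. A small sketch of building the same call from Python with `json.dumps`, so more complex filters survive quoting (flags exactly as used in the script):

```python
import json
import subprocess

# Build the metadata filter programmatically; json.dumps takes care of quoting.
query_args = json.dumps({"dataset": "openai/gsm8k"})

subprocess.run(
    [
        "python", "eval.py",
        "--base_dir", "outputs/",
        "--reward_model_path", "lastnumber",
        "--query_args", query_args,
    ],
    check=True,
)
```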
/scripts/run_local_experiments.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 | DIR="outputs-task/"
5 |
6 | # Loop through all files in the directory
7 | for file in "$DIR"/*; do
8 | # Skip if it's a directory
9 | if [ -f "$file" ]; then
10 | # Get just the filename without path
11 | fileid=$(basename "$file")
12 |
13 | echo "Processing: $fileid"
14 |
15 | # Launch Python script with the filename as experiment_name
16 | python resume_experiment.py --experiment_name "$fileid" --save_path "$DIR"
17 |
18 | echo ""
19 | fi
20 | done
21 |
22 | echo "All experiments have been processed."
--------------------------------------------------------------------------------
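The resume loop above can also be written in Python; a minimal sketch that, like the shell version, treats every regular file in the output directory as an experiment id:

```python
import os
import subprocess

save_path = "outputs-task/"

for fileid in sorted(os.listdir(save_path)):
    # Skip sub-directories, mirroring the `[ -f "$file" ]` check in the shell loop.
    if not os.path.isfile(os.path.join(save_path, fileid)):
        continue
    print(f"Processing: {fileid}")
    subprocess.run(
        ["python", "resume_experiment.py",
         "--experiment_name", fileid,
         "--save_path", save_path],
        check=True,
    )

print("All experiments have been processed.")
```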
/scripts/run_pred.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 | echo "Running MV, WMV, and BoN for all datasets"
5 |
6 |
7 | ## general exps
8 | DIR="outputs/"
9 |
10 | RM="lastnumber"
11 | query_args='{"dataset":"openai/gsm8k"}'
12 | python pred.py --base_dir "$DIR" --extract $RM --strategy "voting" --query_args $query_args
13 |
14 | RM="lastmath"
15 |
16 | query_args='{"dataset":"HuggingFaceH4/MATH-500"}'
17 | python pred.py --base_dir "$DIR" --extract $RM --strategy "voting" --query_args $query_args
18 |
19 |
20 | RM="ifeval"
21 |
22 | query_args='{"dataset":"google/IFEval"}'
23 | python pred.py --base_dir "$DIR" --extract $RM --strategy "mbr" --query_args $query_args
24 |
25 |
26 | RM="lastoption"
27 |
28 | query_args='{"dataset":"edinburgh-dawg/mmlu-redux"}'
29 | python pred.py --base_dir "$DIR" --extract $RM --strategy "voting" --query_args $query_args
30 |
31 | query_args='{"dataset":"truthfulqa/truthful_qa"}'
32 | python pred.py --base_dir "$DIR" --extract $RM --strategy "voting" --query_args $query_args
33 |
34 |
35 |
36 | ## bon and WMV on general exps
37 |
38 | RM="crm:allenai-Llama-3"
39 |
40 | query_args='{"dataset":"google/IFEval","variant":"ancestral","model_path":"allenai/Llama-3.1-Tulu-3-8B-SFT"}'
41 |
42 | python pred.py --base_dir $DIR --strategy "bon" --extract "ifeval" --query_args $query_args --key $RM
43 | python pred.py --base_dir $DIR --strategy "wmbr" --extract "ifeval" --query_args $query_args --key $RM
44 |
45 | query_args='{"dataset":"HuggingFaceH4/MATH-500","variant":"ancestral","model_path":"allenai/Llama-3.1-Tulu-3-8B-SFT"}'
46 |
47 | python pred.py --base_dir $DIR --strategy "bon" --extract "lastmath" --query_args $query_args --key $RM
48 | python pred.py --base_dir $DIR --strategy "weighted-voting" --extract "lastmath" --query_args $query_args --key $RM
49 |
50 | query_args='{"dataset":"openai/gsm8k","variant":"ancestral","model_path":"allenai/Llama-3.1-Tulu-3-8B-SFT"}'
51 |
52 | python pred.py --base_dir $DIR --strategy "bon" --extract "lastnumber" --query_args $query_args --key $RM
53 | python pred.py --base_dir $DIR --strategy "weighted-voting" --extract "lastnumber" --query_args $query_args --key $RM
54 |
55 | query_args='{"dataset":"truthfulqa/truthful_qa","variant":"ancestral","model_path":"allenai/Llama-3.1-Tulu-3-8B-SFT"}'
56 |
57 | python pred.py --base_dir $DIR --strategy "bon" --extract "lastoption" --query_args $query_args --key $RM
58 | python pred.py --base_dir $DIR --strategy "weighted-voting" --extract "lastoption" --query_args $query_args --key $RM
59 |
60 | query_args='{"dataset":"edinburgh-dawg/mmlu-redux","variant":"ancestral","model_path":"allenai/Llama-3.1-Tulu-3-8B-SFT"}'
61 |
62 | python pred.py --base_dir $DIR --strategy "bon" --extract "lastoption" --query_args $query_args --key $RM
63 | python pred.py --base_dir $DIR --strategy "weighted-voting" --extract "lastoption" --query_args $query_args --key $RM
64 |
65 |
66 | ### task specific exps
67 | RM="vh:-gscratch-ark-graf-LLaMA-Factory-saves-llama3-8b-full-reward-"
68 |
69 | DIR="outputs-task/"
70 |
71 | query_args='{}'
72 |
73 | python pred.py --base_dir "$DIR" --strategy "voting" --extract "lastnumber" --query_args $query_args
74 |
75 | query_args='{"variant":"ancestral"}'
76 |
77 | python pred.py --base_dir $DIR --strategy "bon" --extract "lastnumber" --query_args $query_args --key $RM
78 | python pred.py --base_dir $DIR --strategy "weighted-voting" --extract "lastnumber" --query_args $query_args --key $RM
79 |
--------------------------------------------------------------------------------
/scripts/run_remote_experiments.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | DIR="outputs-task/"
4 |
5 | # Loop through all files in the directory
6 | for file in "$DIR"/*; do
7 | # Skip if it's a directory
8 | if [ -f "$file" ]; then
9 | # Get just the filename without path
10 | fileid=$(basename "$file")
11 |
12 | echo "Processing: $fileid"
13 |
14 | # Launch Python script with the filename as experiment_name
15 | python resume_experiment_remote.py --experiment_name "$fileid" --save_path "$DIR"
16 |
17 | echo ""
18 | fi
19 | done
20 |
21 | echo "All experiments have been processed."
--------------------------------------------------------------------------------
/scripts/run_rm_eval.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 | RM=allenai/Llama-3.1-Tulu-3-8B-RM
5 | DIR="outputs/"
6 | query_args='{"variant":"ancestral","model_path":"allenai/Llama-3.1-Tulu-3-8B-SFT"}'
7 | python eval.py --base_dir $DIR --remote True --reward_model_path $RM --value_head False --batch_size 1024 --query_args $query_args
8 |
9 | RM="/gscratch/ark/graf/LLaMA-Factory/saves/llama3/8b/full/reward/"
10 | DIR="outputs-task/"
11 | query_args='{"variant":"ancestral"}'
12 | python eval.py --base_dir $DIR --remote True --reward_model_path $RM --value_head True --batch_size 1024 --query_args $query_args
13 |
--------------------------------------------------------------------------------