├── img ├── method.png └── results.png ├── src ├── train_question_generator │ ├── scripts │ │ ├── zero3.yaml │ │ ├── run_qwen2math_qft.sh │ │ ├── run_dsmath_qft.sh │ │ ├── run_dscode_qft.sh │ │ └── run_qwen2code_qft.sh │ ├── question_optim │ │ ├── optim_difficulty.py │ │ ├── optim_solvability.py │ │ ├── openai_gen.py │ │ └── prompts.py │ ├── qft_train │ │ └── train.py │ └── qpo_train │ │ ├── train.py │ │ └── dpo.py └── data_generation │ ├── run.sh │ ├── question_filtering │ ├── generate_difficulty_score.py │ ├── language_solvability_check.py │ └── vllm_gen.py │ ├── reward_filtering │ └── rm_score.py │ ├── vllm_gen.py │ └── gen.py ├── docker └── Dockerfile ├── requirements.txt ├── .gitignore ├── README.md └── LICENSE /img/method.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yyDing1/ScaleQuest/HEAD/img/method.png -------------------------------------------------------------------------------- /img/results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yyDing1/ScaleQuest/HEAD/img/results.png -------------------------------------------------------------------------------- /src/train_question_generator/scripts/zero3.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | deepspeed_config: 3 | deepspeed_multinode_launcher: standard 4 | offload_optimizer_device: none 5 | offload_param_device: none 6 | zero3_init_flag: true 7 | zero3_save_16bit_model: true 8 | zero_stage: 3 9 | distributed_type: DEEPSPEED 10 | downcast_bf16: 'no' 11 | machine_rank: 0 12 | main_training_function: main 13 | mixed_precision: bf16 14 | num_machines: 1 15 | num_processes: 8 16 | rdzv_backend: static 17 | same_network: true 18 | tpu_env: [] 19 | tpu_use_cluster: false 20 | tpu_use_sudo: false 21 | use_cpu: false -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/pytorch:24.04-py3 2 | 3 | RUN sed -i "s|archive.ubuntu.com|mirrors.tuna.tsinghua.edu.cn|g" /etc/apt/sources.list \ 4 | && pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple \ 5 | && ln -sf /usr/share/zoneinfo/Asia/Shanghai /etc/localtime \ 6 | && echo "Asia/Shanghai" > /etc/timezone 7 | 8 | RUN apt update \ 9 | && apt install build-essential openssh-server net-tools psmisc tzdata tmux lsof libaio-dev -y \ 10 | && service ssh start 11 | 12 | RUN ssh-keygen -t ed25519 -f ~/.ssh/id_ed25519 -P "" \ 13 | && cat ~/.ssh/id_ed25519.pub >> ~/.ssh/authorized_keys \ 14 | && echo "StrictHostKeyChecking no" > ~/.ssh/config \ 15 | && chmod 600 ~/.ssh/{config,authorized_keys} 16 | 17 | COPY requirements.txt /tmp/requirements.txt 18 | 19 | RUN pip install -r /tmp/requirements.txt \ 20 | && pip install flash-attn --no-build-isolation \ 21 | && pip install -U nvitop 22 | 23 | ENTRYPOINT ["/usr/sbin/sshd", "-D"] 24 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=2.1.2 2 | transformers<4.43 3 | accelerate 4 | deepspeed==0.14.4 5 | numpy>=1.24.2,<2.0.0 # deepspeed 1.3.0 not compatible with numpy 2.0.0 yet 6 | vllm 7 | # math_eval 8 | sympy==1.12 9 | antlr4-python3-runtime==4.11.1 # ! The version needs to be compatible with sympy. 
10 | word2number 11 | Pebble 12 | timeout-decorator 13 | git+https://github.com/ZubinGou/latex2sympy 14 | 15 | # For flash-attn build 16 | einops>=0.6.1 17 | packaging>=23.0 18 | ninja>=1.11.1 19 | pebble 20 | ipython 21 | ray 22 | 23 | bitsandbytes==0.42.0 24 | black>=24.4.2 25 | datasets>=2.18.0 26 | evaluate==0.4.0 27 | flake8>=6.0.0 28 | hf-doc-builder>=0.4.0 29 | hf_transfer>=0.1.4 30 | huggingface-hub>=0.19.2,<1.0 31 | isort>=5.12.0 32 | parameterized>=0.9.0 33 | peft>=0.9.0 34 | protobuf<=3.20.2 # Needed to avoid conflicts with `transformers` 35 | pytest 36 | safetensors>=0.3.3 37 | sentencepiece>=0.1.99 38 | scipy 39 | tensorboard 40 | trl>=0.9.6 41 | jinja2>=3.0.0 42 | tqdm>=4.64.1 43 | matplotlib 44 | seaborn 45 | orjson 46 | -------------------------------------------------------------------------------- /src/train_question_generator/scripts/run_qwen2math_qft.sh: -------------------------------------------------------------------------------- 1 | # Step 1: QFT 2 | ACCELERATE_LOG_LEVEL=info accelerate launch \ 3 | --config_file ./zero3.yaml \ 4 | --main_process_port 29500 \ 5 | train.py \ 6 | --model_path Qwen/Qwen2-Math-7B-Instruct \ 7 | --dataset_path /path/to/gsm8k_math_15k \ 8 | --prompt_type qwen2-math \ 9 | --num_train_epochs 1 \ 10 | --gradient_checkpointing false \ 11 | --max_length 256 \ 12 | --output_dir models/Qwen2-Math-7B-QFT \ 13 | --per_device_train_batch_size 1 \ 14 | --per_device_eval_batch_size 1 \ 15 | --gradient_accumulation_steps 4 \ 16 | 17 | # Step 2: QPO 18 | ACCELERATE_LOG_LEVEL=info accelerate launch \ 19 | --config_file ./zero3.yaml \ 20 | --main_process_port 29051 \ 21 | train.py \ 22 | --model_path /path/to/models/Qwen2-Math-7B-QFT \ 23 | --ref_model /path/to/models/Qwen2-Math-7B-QFT \ 24 | --dataset_path /path/to/qpo_data \ 25 | --prompt_type qwen2-math \ 26 | --run_name qwen2-math-qgen-sft-dpo \ 27 | --learning_rate 5e-7 \ 28 | --lr_scheduler_type cosine \ 29 | --loss_type sigmoid \ 30 | --warmup_steps 20 \ 31 | --num_train_epochs 1 \ 32 | --gradient_checkpointing true \ 33 | --max_length 1024 \ 34 | --output_dir models/Qwen2-Math-7B-QGen \ 35 | --per_device_train_batch_size 8 \ 36 | --per_device_eval_batch_size 8 \ 37 | --gradient_accumulation_steps 2 \ -------------------------------------------------------------------------------- /src/train_question_generator/scripts/run_dsmath_qft.sh: -------------------------------------------------------------------------------- 1 | # Step 1: QFT 2 | ACCELERATE_LOG_LEVEL=info accelerate launch \ 3 | --config_file ./zero3.yaml \ 4 | --main_process_port 29600 \ 5 | qft_train/train.py \ 6 | --model_path deepseek-ai/deepseek-math-7b-rl \ 7 | --dataset_path /path/to/gsm8k_math_15k \ 8 | --prompt_type deepseek-math \ 9 | --num_train_epochs 1 \ 10 | --gradient_checkpointing false \ 11 | --max_length 256 \ 12 | --output_dir models/Deepseek-Math-7B-QFT \ 13 | --per_device_train_batch_size 1 \ 14 | --per_device_eval_batch_size 1 \ 15 | --gradient_accumulation_steps 4 \ 16 | 17 | # Step 2: QPO 18 | ACCELERATE_LOG_LEVEL=info accelerate launch \ 19 | --config_file ./zero3.yaml \ 20 | --main_process_port 29601 \ 21 | train.py \ 22 | --model_path models/Deepseek-Math-7B-QFT \ 23 | --ref_model models/Deepseek-Math-7B-QFT \ 24 | --dataset_path /path/to/qpo_data \ 25 | --prompt_type deepseek-math \ 26 | --run_name deepseek-math-qgen-sft-dpo \ 27 | --learning_rate 5e-7 \ 28 | --lr_scheduler_type cosine \ 29 | --loss_type sigmoid \ 30 | --warmup_steps 20 \ 31 | --num_train_epochs 1 \ 32 | --gradient_checkpointing true \ 33 | 
--max_length 1024 \ 34 | --output_dir models/Deepseek-Math-7B-QGen \ 35 | --per_device_train_batch_size 8 \ 36 | --per_device_eval_batch_size 8 \ 37 | --gradient_accumulation_steps 2 \ -------------------------------------------------------------------------------- /src/train_question_generator/scripts/run_dscode_qft.sh: -------------------------------------------------------------------------------- 1 | # Step 1: QFT 2 | ACCELERATE_LOG_LEVEL=info accelerate launch \ 3 | --config_file ./zero3.yaml \ 4 | --main_process_port 29600 \ 5 | qft_train/train.py \ 6 | --model_path deepseek-ai/deepseek-coder-6.7b-instruct \ 7 | --dataset_path /path/to/CodeFeedback-Filtered-Instruction \ 8 | --prompt_type deepseek-code \ 9 | --num_train_epochs 1 \ 10 | --gradient_checkpointing false \ 11 | --max_length 256 \ 12 | --output_dir models/Deepseek-Coder-7B-QFT \ 13 | --per_device_train_batch_size 1 \ 14 | --per_device_eval_batch_size 1 \ 15 | --gradient_accumulation_steps 4 \ 16 | 17 | Step 2: QPO 18 | ACCELERATE_LOG_LEVEL=info accelerate launch \ 19 | --config_file ./zero3.yaml \ 20 | --main_process_port 29601 \ 21 | train.py \ 22 | --model_path models/Deepseek-Coder-7B-QFT \ 23 | --ref_model models/Deepseek-Coder-7B-QFT \ 24 | --dataset_path /path/to/qpo_data \ 25 | --prompt_type deepseek-code \ 26 | --run_name deepseek-code-qgen-sft-dpo \ 27 | --learning_rate 5e-7 \ 28 | --lr_scheduler_type cosine \ 29 | --loss_type sigmoid \ 30 | --warmup_steps 20 \ 31 | --num_train_epochs 1 \ 32 | --gradient_checkpointing true \ 33 | --max_length 1024 \ 34 | --output_dir models/Deepseek-Coder-7B-QGen \ 35 | --per_device_train_batch_size 8 \ 36 | --per_device_eval_batch_size 8 \ 37 | --gradient_accumulation_steps 2 \ -------------------------------------------------------------------------------- /src/train_question_generator/scripts/run_qwen2code_qft.sh: -------------------------------------------------------------------------------- 1 | # Step 1: QFT 2 | ACCELERATE_LOG_LEVEL=info accelerate launch \ 3 | --config_file scripts/zero3.yaml \ 4 | --main_process_port 29500 \ 5 | qft_train/train.py \ 6 | --model_path Qwen/Qwen2.5-Coder-7B-Instruct \ 7 | --dataset_path /path/to/CodeFeedback-Filtered-Instruction \ 8 | --prompt_type qwen2.5-code \ 9 | --num_train_epochs 1 \ 10 | --gradient_checkpointing false \ 11 | --max_length 256 \ 12 | --max_training_samples 20000 \ 13 | --output_dir models/Qwen2.5-Coder-7B-QFT \ 14 | --per_device_train_batch_size 1 \ 15 | --per_device_eval_batch_size 1 \ 16 | --gradient_accumulation_steps 4 \ 17 | 18 | # Step 2: QPO 19 | ACCELERATE_LOG_LEVEL=info accelerate launch \ 20 | --config_file ./zero3.yaml \ 21 | --main_process_port 29051 \ 22 | train.py \ 23 | --model_path models/Qwen2.5-Coder-7B-QFT \ 24 | --ref_model models/Qwen2.5-Coder-7B-QFT \ 25 | --dataset_path /path/to/qpo_data \ 26 | --prompt_type qwen2.5-code \ 27 | --run_name qwen2-code-qgen-sft-dpo \ 28 | --learning_rate 5e-7 \ 29 | --lr_scheduler_type cosine \ 30 | --loss_type sigmoid \ 31 | --warmup_steps 20 \ 32 | --num_train_epochs 1 \ 33 | --gradient_checkpointing true \ 34 | --max_length 1024 \ 35 | --output_dir models/Qwen2-Coder-7B-QGen \ 36 | --per_device_train_batch_size 8 \ 37 | --per_device_eval_batch_size 8 \ 38 | --gradient_accumulation_steps 2 \ -------------------------------------------------------------------------------- /src/data_generation/run.sh: -------------------------------------------------------------------------------- 1 | # export CUDA_VISIBLE_DEVICES=0,1,2,3 2 | # Query Gen 3 | qry_num=100 4 | 
qry_prompt_type="qwen2-math-sft" 5 | qry_model_path="/path/to/Qwen2-Math-7B-QGen" 6 | qry_temp=1.0 7 | qry_top_p=0.99 8 | 9 | # Response Gen 10 | res_num_per_query=5 11 | res_prompt_type="qwen2-math" 12 | res_model_path="Qwen/Qwen2-Math-7B-Instruct" 13 | res_temp=0.7 14 | res_top_p=0.95 15 | 16 | output_folder="generation/" 17 | 18 | # for query generation 19 | python gen.py \ 20 | --qry_gen \ 21 | --qry_num $qry_num \ 22 | --qry_prompt_type $qry_prompt_type \ 23 | --qry_model_path $qry_model_path \ 24 | --qry_temperature $qry_temp \ 25 | --qry_top_p $qry_top_p \ 26 | --res_num_per_query $res_num_per_query \ 27 | --res_prompt_type $res_prompt_type \ 28 | --res_model_path $res_model_path \ 29 | --res_temperature $res_temp \ 30 | --res_top_p $res_top_p \ 31 | --output_folder $output_folder \ 32 | --swap_space 32 33 | 34 | # for response generation 35 | python gen.py \ 36 | --qry_num $qry_num \ 37 | --qry_prompt_type $qry_prompt_type \ 38 | --qry_model_path $qry_model_path \ 39 | --qry_temperature $qry_temp \ 40 | --qry_top_p $qry_top_p \ 41 | --res_gen \ 42 | --res_num_per_query $res_num_per_query \ 43 | --res_prompt_type $res_prompt_type \ 44 | --res_model_path $res_model_path \ 45 | --res_temperature $res_temp \ 46 | --res_top_p $res_top_p \ 47 | --output_folder $output_folder \ 48 | --swap_space 32 49 | -------------------------------------------------------------------------------- /src/data_generation/question_filtering/generate_difficulty_score.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Optional 3 | import torch 4 | from datasets import load_dataset, Dataset 5 | from tqdm import tqdm 6 | from transformers import AutoTokenizer, HfArgumentParser, pipeline 7 | import numpy as np 8 | import pandas as pd 9 | from torch.utils.data import DataLoader, Dataset as TorchDataset 10 | 11 | 12 | def process_data_on_device(device, sub_dataset, score_model_path, score_tokenizer): 13 | score_pipe = pipeline( 14 | "sentiment-analysis", 15 | model=score_model_path, 16 | device=device, 17 | tokenizer=score_tokenizer, 18 | model_kwargs={"torch_dtype": torch.bfloat16}, 19 | truncation=True 20 | ) 21 | 22 | pipe_kwargs = { 23 | "return_all_scores": True, 24 | "function_to_apply": "none", 25 | "batch_size": 1, 26 | } 27 | 28 | def get_reward(test_texts): 29 | pipe_outputs = score_pipe(test_texts, **pipe_kwargs) 30 | rewards = [output[0]["score"] for output in pipe_outputs] 31 | return rewards 32 | 33 | all_data = [] 34 | for line_data in tqdm(sub_dataset): 35 | score = get_reward(line_data["query"]) 36 | line_data["model_score"] = score 37 | line_data["score_model"] = score_model_path 38 | all_data.append(line_data) 39 | 40 | return all_data 41 | 42 | from multiprocessing import Pool 43 | 44 | def process_wrapper(args): 45 | return process_data_on_device(*args) 46 | 47 | def generate_score( 48 | dataset, 49 | model_path="/path/to/difficulty_score_model", 50 | tokenizer_path="/path/to/difficulty_score_model" 51 | ): 52 | num_gpus = torch.cuda.device_count() 53 | 54 | sub_datasets = [dataset.shard(num_shards=num_gpus, index=i) for i in range(num_gpus)] 55 | 56 | score_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) 57 | 58 | with Pool(num_gpus) as p: 59 | results = p.map(process_wrapper, [(i, sub_datasets[i], model_path, score_tokenizer) for i in range(num_gpus)]) 60 | 61 | all_data = [item for sublist in results for item in sublist] 62 | final_dataset = Dataset.from_list(all_data) 63 | return 
final_dataset 64 | 65 | -------------------------------------------------------------------------------- /src/train_question_generator/question_optim/optim_difficulty.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | from datasets import load_dataset 4 | import time 5 | 6 | from prompts import createDifficultyPrompt 7 | from openai_gen import run_openai_inference 8 | 9 | 10 | datasets = load_dataset("/path/to/qft_generated_questions", split="train") 11 | 12 | ds_len = len(datasets) 13 | eval_config = { 14 | "model": "gpt-4o-mini", 15 | "max_tokens": 4096, 16 | "num_generations": 1, 17 | "temperature": 0.0, 18 | "top_p": 1.0, 19 | "openai_timeout": 45, 20 | } 21 | 22 | def process_batch(batch): 23 | requests = [] 24 | for line_data in batch: 25 | instruction = createDifficultyPrompt(line_data["query"]) 26 | requests.append({ 27 | "query": line_data["query"], 28 | "query_gen_model": line_data["qry_gen_model"], 29 | "messages": [ 30 | {"role": "system", "content": "You are a helpful assistant."}, 31 | {"role": "user", "content": instruction}, 32 | ] 33 | }) 34 | 35 | results = run_openai_inference(requests, **eval_config) 36 | match_parten_list = [ 37 | "**FINAL QUESTION**:", 38 | "**FINAL QUESTION:**", 39 | "FINAL QUESTION:", 40 | "### FINAL QUESTION:", 41 | "### FINAL QUESTION", 42 | ] 43 | 44 | for result in results: 45 | result["rewritten"] = "" 46 | for match_partern in match_parten_list: 47 | if match_partern in result["generation"]: 48 | result["rewritten"] = result["generation"].split(match_partern)[-1].strip() 49 | break 50 | 51 | return results 52 | 53 | 54 | def save_results(results, file_path): 55 | try: 56 | with open(file_path, 'r') as f: 57 | existing_data = json.load(f) 58 | except FileNotFoundError: 59 | existing_data = [] 60 | 61 | existing_data.extend(results) 62 | 63 | with open(file_path, 'w') as f: 64 | json.dump(existing_data, f, indent=4) 65 | 66 | 67 | batch_size = ds_len 68 | output_file = 'output/gpt-4o-mini_optimize_qwen2-math-qgen_difficulty.json' 69 | 70 | for i in range(0, ds_len, batch_size): 71 | batch = datasets.select(range(i, i+batch_size)) 72 | results = process_batch(batch) 73 | 74 | save_results(results, output_file) 75 | print(f"Processed and saved batch {i // batch_size + 1} / {ds_len // batch_size}") 76 | 77 | time.sleep(30) 78 | 79 | print("All batches processed and saved.") 80 | -------------------------------------------------------------------------------- /src/train_question_generator/question_optim/optim_solvability.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | from datasets import load_dataset 4 | import time 5 | 6 | from prompts import createSolvabilityPrompt 7 | from openai_gen import run_openai_inference 8 | 9 | 10 | datasets = load_dataset("/path/to/qft_generated_questions", split="train") 11 | 12 | ds_len = len(datasets) 13 | eval_config = { 14 | "model": "gpt-4o-mini", 15 | "max_tokens": 4096, 16 | "num_generations": 1, 17 | "temperature": 0.0, 18 | "top_p": 1.0, 19 | "openai_timeout": 45, 20 | } 21 | 22 | def process_batch(batch): 23 | requests = [] 24 | for line_data in batch: 25 | instruction = createSolvabilityPrompt(line_data["query"]) 26 | requests.append({ 27 | "query": line_data["query"], 28 | "query_gen_model": line_data["qry_gen_model"], 29 | "messages": [ 30 | {"role": "system", "content": "You are a helpful assistant."}, 31 | {"role": "user", "content": instruction}, 32 | ] 33 | 
}) 34 | 35 | results = run_openai_inference(requests, **eval_config) 36 | match_parten_list = [ 37 | "#Finally Rewritten Problem#:", 38 | "**Finally Rewritten Problem**:", 39 | "**Finally Rewritten Problem:**", 40 | "Finally Rewritten Problem:", 41 | "### Finally Rewritten Problem:", 42 | "### Finally Rewritten Problem", 43 | ] 44 | 45 | for result in results: 46 | result["rewritten"] = "" 47 | for match_partern in match_parten_list: 48 | if match_partern in result["generation"]: 49 | result["rewritten"] = result["generation"].split(match_partern)[-1].strip() 50 | break 51 | 52 | return results 53 | 54 | 55 | def save_results(results, file_path): 56 | try: 57 | with open(file_path, 'r') as f: 58 | existing_data = json.load(f) 59 | except FileNotFoundError: 60 | existing_data = [] 61 | 62 | existing_data.extend(results) 63 | 64 | with open(file_path, 'w') as f: 65 | json.dump(existing_data, f, indent=4) 66 | 67 | 68 | batch_size = ds_len 69 | output_file = 'output/gpt-4o-mini_optimize_qwen2-math-qgen_solvability.json' 70 | 71 | for i in range(0, ds_len, batch_size): 72 | batch = datasets.select(range(i, i+batch_size)) 73 | results = process_batch(batch) 74 | 75 | save_results(results, output_file) 76 | print(f"Processed and saved batch {i // batch_size + 1} / {ds_len // batch_size}") 77 | 78 | time.sleep(30) 79 | 80 | print("All batches processed and saved.") 81 | -------------------------------------------------------------------------------- /src/data_generation/reward_filtering/rm_score.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.multiprocessing as mp 4 | from typing import Dict, List 5 | from datasets import Dataset, load_dataset 6 | from tqdm import tqdm 7 | from transformers import AutoTokenizer, AutoModel 8 | 9 | 10 | def process_data_on_device(dataset, score_model_path, score_tokenizer, gpu_id): 11 | os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id) 12 | 13 | if "internlm2-7b-reward" in score_model_path.lower(): 14 | model = AutoModel.from_pretrained( 15 | score_model_path, 16 | device_map="cuda", 17 | torch_dtype=torch.float16, 18 | trust_remote_code=True, 19 | ) 20 | else: 21 | raise NotImplementedError(f"Model {score_model_path} not supported") 22 | 23 | all_data = [] 24 | for line_data in tqdm(dataset, desc=f"GPU {gpu_id}"): 25 | chat_messages = [ 26 | {"role": "user", "content": line_data["query"]}, 27 | {"role": "assistant", "content": line_data["response"]}, 28 | ] 29 | score = model.get_score(score_tokenizer, chat_messages) 30 | line_data["score"] = score 31 | line_data["reward_model"] = score_model_path 32 | all_data.append(line_data) 33 | 34 | return all_data 35 | 36 | 37 | 38 | def generate_score_parallel(dataset, model_path, tokenizer_path, num_gpus=8): 39 | score_tokenizer = AutoTokenizer.from_pretrained( 40 | tokenizer_path, trust_remote_code=True 41 | ) 42 | 43 | dataset_splits = [ 44 | dataset.shard(num_shards=num_gpus, index=i) for i in range(num_gpus) 45 | ] 46 | 47 | mp.set_start_method("spawn", force=True) 48 | pool = mp.Pool(processes=num_gpus) 49 | 50 | results = pool.starmap( 51 | process_data_on_device, 52 | [ 53 | (split, model_path, score_tokenizer, i) 54 | for i, split in enumerate(dataset_splits) 55 | ], 56 | ) 57 | 58 | all_results = [item for sublist in results for item in sublist] 59 | 60 | final_dataset = Dataset.from_list(all_results) 61 | return final_dataset 62 | 63 | 64 | if __name__ == "__main__": 65 | data_dir = 
"/path/to/data_generation/generation/qwen2-math_resgen600000x5_temp0.7_topp0.95" 66 | rm_path = "internlm/internlm2-7b-reward" 67 | 68 | save_dir = f"/path/to/data_generation/generation/qwen2-math_resgen600000x5_temp0.7_topp0.95_rm_score/output.jsonl" 69 | ds = load_dataset(data_dir, split="train") 70 | ds = generate_score_parallel( 71 | ds, 72 | model_path=rm_path, 73 | tokenizer_path=rm_path, 74 | num_gpus=os.getenv("CUDA_VISIBLE_DEVICES").count(",") + 1, 75 | ) 76 | ds.to_json(save_dir) 77 | -------------------------------------------------------------------------------- /src/data_generation/question_filtering/language_solvability_check.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | import re 3 | from tqdm import tqdm 4 | import ray.data 5 | from transformers import AutoTokenizer 6 | from vllm_gen import run_vllm_inference_distributed 7 | from generate_difficulty_score import generate_score 8 | 9 | 10 | # Step 1: Filter Language 11 | data_path = "/path/to/data_generation/generation/qwen2-math-sft-dpo_querygen1000000_temp1.0_topp0.99" 12 | dataset = ray.data.read_json(data_path) 13 | 14 | def filter_query(line): 15 | ord_index = list(map(ord, line["query"])) 16 | if ord_index and max(ord_index) <= 127 and line["query"] != "": 17 | return True 18 | else: 19 | return False 20 | 21 | dataset = dataset.filter(filter_query, concurrency=8) 22 | 23 | 24 | # Step 2: Filter Solvability 25 | model_path = "Qwen/Qwen2-Math-7B-Instruct" 26 | stop_tokens = ["<|im_start|>", "<|im_end|>", "<|endoftext|>"] 27 | instruction = """ 28 | Please act as a professional math teacher. 29 | Your goal is to determine if the given problem is a valuable math problem. You need to consider two aspects: 30 | 1. The given problem is a math problem. 31 | 2. The given math problem can be solved based on the conditions provided in the problem (You can first try to solve it and then judge its solvability). 32 | 33 | Please reason step by step and conclude with either 'Yes' or 'No'. 
34 | 35 | Given Problem: {problem} 36 | """.strip() 37 | 38 | tokenizer = AutoTokenizer.from_pretrained(model_path) 39 | def construct_solvability_check_prompt(line): 40 | messages = [ 41 | {"role": "system", "content": "You are a helpful assistant."}, 42 | {"role": "user", "content": instruction.format(problem=line["query"])} 43 | ] 44 | prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) 45 | return {**line, "prompt_for_solvability_check": prompt} 46 | 47 | dataset = dataset.map(construct_solvability_check_prompt) 48 | dataset = run_vllm_inference_distributed( 49 | ds=dataset, 50 | model_path=model_path, 51 | tokenizer_path=model_path, 52 | prompt_key="prompt_for_solvability_check", 53 | generation_key="generation_for_solvability_check", 54 | max_tokens=2048, 55 | max_model_len=4096, 56 | num_generations=1, 57 | temperature=0.0, 58 | top_p=1.0, 59 | stop_tokens=stop_tokens, 60 | tensor_parallel_size=1, 61 | swap_space=32, 62 | ) 63 | 64 | def filter_answer(line): 65 | return "yes" in line["generation_for_solvability_check"][0].lower() 66 | dataset = dataset.filter(filter_answer) 67 | 68 | 69 | # Step 3: Generate Difficulty Score 70 | dataset = generate_score(dataset) 71 | 72 | dataset.write_json("final_data/question_filtering_data") 73 | -------------------------------------------------------------------------------- /src/train_question_generator/question_optim/openai_gen.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List 2 | 3 | import os 4 | from time import sleep 5 | import numpy as np 6 | 7 | import openai 8 | from openai import AsyncOpenAI 9 | import asyncio 10 | from tqdm.asyncio import tqdm 11 | 12 | 13 | class AsyncOpenAIPredictor: 14 | def __init__( 15 | self, 16 | model, 17 | max_tokens, 18 | num_generations=1, 19 | temperature=0.0, 20 | top_p=1.0, 21 | openai_timeout=45, 22 | ): 23 | self.client = AsyncOpenAI( 24 | api_key=os.getenv("OPENAI_KEY"), 25 | base_url="https://api.openai.com/v1/", 26 | ) 27 | self.client_kwargs: dict[str | str] = { 28 | "model": model, 29 | "temperature": temperature, 30 | "max_tokens": max_tokens, 31 | "top_p": top_p, 32 | "n": num_generations, 33 | "timeout": openai_timeout, 34 | # "stop": args.stop, --> stop is only used for base models currently 35 | } 36 | 37 | async def __call__(self, item: Dict[str, np.ndarray]) -> Dict[str, list]: 38 | assert isinstance(item["messages"], list) 39 | 40 | max_retries = 5 41 | for attempt in range(max_retries): 42 | try: 43 | response = await self.client.chat.completions.create( 44 | messages=item["messages"], **self.client_kwargs 45 | ) 46 | return {**item, "generation": response.choices[0].message.content} 47 | except ( 48 | openai.APIError, 49 | openai.RateLimitError, 50 | openai.InternalServerError, 51 | openai.OpenAIError, 52 | openai.APIStatusError, 53 | openai.APITimeoutError, 54 | openai.InternalServerError, 55 | openai.APIConnectionError, 56 | ) as e: 57 | print(f"[Attempt {attempt + 1}] Exception: {repr(e)}") 58 | print(f"[Attempt {attempt + 1}] Sleeping for 30 seconds...") 59 | await asyncio.sleep(30) 60 | except Exception as e: 61 | print(f"Failed to run the model for {item['messages']}!") 62 | print("Exception: ", repr(e)) 63 | return None 64 | 65 | 66 | async def run_openai_inference_async(requests, **kwargs): 67 | predictor = AsyncOpenAIPredictor(**kwargs) 68 | 69 | results = [] 70 | tasks = [predictor(request) for request in requests] 71 | 72 | for task in tqdm( 73 | 
asyncio.as_completed(tasks), total=len(tasks), desc="Processing items" 74 | ): 75 | result = await task 76 | if result: 77 | results.append(result) 78 | 79 | return results 80 | 81 | 82 | def run_openai_inference(requests, **kwargs): 83 | requests = asyncio.run(run_openai_inference_async(requests, **kwargs)) 84 | return requests 85 | 86 | 87 | def test(): 88 | import ray 89 | import ray.data 90 | ds = ray.data.from_items([ 91 | {"messages": [{"content": "hello", "role": "user"}]}, 92 | {"messages": [{"content": "hi", "role": "user"}]}, 93 | ]) 94 | ds = run_openai_inference(ds, model="gpt-4o-mini", max_tokens=1024) 95 | x = ds.take_all() 96 | 97 | 98 | if __name__ == "__main__": 99 | test() 100 | -------------------------------------------------------------------------------- /src/train_question_generator/question_optim/prompts.py: -------------------------------------------------------------------------------- 1 | solvability_optimization_prompt = """ 2 | You are an Math Problem Rewriter that rewrites the given #Problem# into a more complex version. 3 | Please follow the steps below to rewrite the given "#Problem#" into a more complex version. 4 | 5 | Step 1: Please read the "#Problem#" carefully and list all the possible methods to make this problem more complex (to make it a bit harder for well-known AI assistants such as ChatGPT and GPT4 to handle). Note that the problem itself might be erroneous, and you need to first correct the errors within it. 6 | 7 | Step 2: Please create a comprehensive plan based on the #Methods List# generated in Step 1 to make the #Problem# more complex. The plan should include several methods from the #Methods List#. 8 | 9 | Step 3: Please execute the plan step by step and provide the #Rewritten Problem#. #Rewritten Problem# can only add 10 to 20 words into the "#Problem#". 10 | 11 | Step 4: Please carefully review the #Rewritten Problem# and identify any unreasonable parts. Ensure that the #Rewritten Problem# is only a more complex version of the #Problem#. Just provide the #Finally Rewritten Problem# without any explanation and step-by-step reasoning guidance. 12 | 13 | Please reply strictly in the following format: 14 | Step 1 #Methods List#: 15 | Step 2 #Plan#: 16 | Step 3 #Rewritten Problem#: 17 | Step 4 #Finally Rewritten Problem#: 18 | 19 | #Problem#: 20 | {problem} 21 | """.strip() 22 | 23 | difficulty_optimization_prompt = """ 24 | Please act as a professional math teacher. 25 | Your goal is to create high quality math word problems to help students learn math. 26 | You will be given a math question. Please optimize the Given Question and following instructions. 27 | To achieve the goal, please follow the steps: 28 | # Please check that the given question is a math question and write detailed solution to the Given Question. 29 | # Based on the problem-solving process, double check the question is solvable. 30 | # If you feel that the given question is not a meaningful math question, rewrite one that makes sense to you. Otherwise, modify the Given question according to your checking comment to ensure it is solvable and of high quality. 31 | # If the question can be solved with just a few simple thinking processes, you can rewrite it to explicitly request multiple-step reasoning. 32 | 33 | You have five principles to do this: 34 | # Ensure the optimized question only asks for one thing, be reasonable and solvable, be based on the Given Question (if possible), and can be answered with only a number (float or integer). 
For example, DO NOT ask, 'what is the amount of A, B and C?'. 35 | # Ensure the optimized question is in line with common sense of life. For example, the amount someone has or pays must be a positive number, and the number of people must be an integer. 36 | # Ensure your student can answer the optimized question without the given question. If you want to use some numbers, conditions or background in the given question, please restate them to ensure no information is omitted in your optimized question. 37 | # Please DO NOT include solution in your question. 38 | 39 | Given Question: {problem} 40 | Your output should be in the following format: 41 | CREATED QUESTION: 42 | VERIFICATION AND MODIFICATION: 43 | FINAL QUESTION: 44 | """.strip() 45 | 46 | 47 | def createSolvabilityPrompt(problem): 48 | prompt = solvability_optimization_prompt.format( 49 | problem=problem 50 | ) 51 | return prompt 52 | 53 | 54 | def createDifficultyPrompt(problem): 55 | prompt = difficulty_optimization_prompt.format( 56 | problem=problem 57 | ) 58 | return prompt 59 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | src/data_generation/generation 162 | src/train_question_generator/models 163 | -------------------------------------------------------------------------------- /src/data_generation/vllm_gen.py: -------------------------------------------------------------------------------- 1 | """ 2 | This example shows how to use Ray Data for running offline batch inference 3 | distributively on a multi-nodes cluster. 4 | 5 | Learn more about Ray Data in https://docs.ray.io/en/latest/data/data.html 6 | """ 7 | 8 | from typing import Any, Dict, List 9 | 10 | import os 11 | import time 12 | import numpy as np 13 | import ray 14 | import ray.data 15 | import torch 16 | from packaging.version import Version 17 | from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy 18 | 19 | from vllm import LLM, SamplingParams 20 | 21 | assert Version(ray.__version__) >= Version( 22 | "2.22.0" 23 | ), "Ray version must be at least 2.22.0" 24 | 25 | 26 | # Create a class to do batch inference. 27 | class LLMPredictor: 28 | def __init__( 29 | self, 30 | model_path, 31 | tokenizer_path, 32 | prompt_key="prompt", 33 | generation_key="generation", 34 | max_tokens=2048, 35 | max_model_len=4096, 36 | num_generations=1, 37 | temperature=0.0, 38 | top_p=1.0, 39 | stop_tokens=None, 40 | stop_token_ids=None, 41 | tensor_parallel_size=1, 42 | enable_prefix_caching=False, 43 | swap_space=16, 44 | ): 45 | seed = int(time.time() * 1e6) % int(1e9) 46 | # Create an LLM. 
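# The time-derived seed above gives each Ray actor a different sampling seed, so replicas
# produce different questions when temperature > 0. gpu_memory_utilization is fixed at 0.95,
# while swap_space comes from the caller (run.sh passes --swap_space 32 for large sampling runs).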
47 | self.prompt_key = prompt_key 48 | self.generation_key = generation_key 49 | self.llm = LLM( 50 | model=model_path, 51 | tokenizer=tokenizer_path, 52 | tensor_parallel_size=tensor_parallel_size, 53 | max_model_len=max_model_len, 54 | enable_prefix_caching=enable_prefix_caching, 55 | trust_remote_code=True, 56 | swap_space=swap_space, 57 | gpu_memory_utilization=0.95, 58 | seed=seed, 59 | ) 60 | self.sampling_params = SamplingParams( 61 | n=num_generations, 62 | max_tokens=max_tokens, 63 | temperature=temperature, 64 | top_p=top_p, 65 | stop=stop_tokens, 66 | stop_token_ids=stop_token_ids, 67 | ) 68 | 69 | def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, list]: 70 | # Generate texts from the prompts. 71 | # The output is a list of RequestOutput objects that contain the prompt, 72 | # generated text, and other information. 73 | outputs = self.llm.generate(batch[self.prompt_key], self.sampling_params) 74 | generated_text: List[str] = [] 75 | for output in outputs: 76 | generated_text.append([o.text for o in output.outputs]) 77 | return {**batch, self.generation_key: generated_text} 78 | 79 | 80 | def run_vllm_inference_distributed( 81 | ds, 82 | **kwargs, 83 | ): 84 | tensor_parallel_size = kwargs.get("tensor_parallel_size", 1) 85 | 86 | # Guarentee the compute resources is available 87 | if torch.cuda.device_count() < tensor_parallel_size: 88 | raise MemoryError( 89 | "Insufficient GPUs: tensor_parallel_size ({}) < available gpus ({})".format( 90 | tensor_parallel_size, torch.cuda.device_count() 91 | ) 92 | ) 93 | 94 | # Set number of instances. Each instance will use tensor_parallel_size GPUs. 95 | num_instances = torch.cuda.device_count() // tensor_parallel_size 96 | print("Launch {} instances for vllm inference.".format(num_instances)) 97 | 98 | # For tensor_parallel_size > 1, we need to create placement groups for vLLM 99 | # to use. Every actor has to have its own placement group. 100 | def scheduling_strategy_fn(): 101 | # One bundle per tensor parallel worker 102 | pg = ray.util.placement_group( 103 | [{"GPU": 1, "CPU": 1}] * tensor_parallel_size, strategy="STRICT_PACK" 104 | ) 105 | return dict( 106 | scheduling_strategy=PlacementGroupSchedulingStrategy( 107 | pg, placement_group_capture_child_tasks=True 108 | ) 109 | ) 110 | 111 | resources_kwarg: Dict[str, Any] = {} 112 | if tensor_parallel_size == 1: 113 | # For tensor_parallel_size == 1, we simply set num_gpus=1. 114 | resources_kwarg["num_gpus"] = 1 115 | else: 116 | # Otherwise, we have to set num_gpus=0 and provide 117 | # a function that will create a placement group for 118 | # each instance. 119 | resources_kwarg["num_gpus"] = 0 120 | resources_kwarg["ray_remote_args_fn"] = scheduling_strategy_fn 121 | 122 | batch_size = min(ds.count() // num_instances + 1, 10000) 123 | # Apply batch inference for all input data. 124 | ds = ds.map_batches( 125 | LLMPredictor, 126 | # Set the concurrency to the number of LLM instances. 127 | concurrency=num_instances, 128 | # Specify the batch size for inference. 129 | batch_size=batch_size, 130 | fn_constructor_kwargs=kwargs, 131 | **resources_kwarg, 132 | ) 133 | 134 | return ds 135 | -------------------------------------------------------------------------------- /src/data_generation/question_filtering/vllm_gen.py: -------------------------------------------------------------------------------- 1 | """ 2 | This example shows how to use Ray Data for running offline batch inference 3 | distributively on a multi-nodes cluster. 
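(This helper is identical to src/data_generation/vllm_gen.py.)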
4 | 5 | Learn more about Ray Data in https://docs.ray.io/en/latest/data/data.html 6 | """ 7 | 8 | from typing import Any, Dict, List 9 | 10 | import os 11 | import time 12 | import numpy as np 13 | import ray 14 | import ray.data 15 | import torch 16 | from packaging.version import Version 17 | from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy 18 | 19 | from vllm import LLM, SamplingParams 20 | 21 | assert Version(ray.__version__) >= Version( 22 | "2.22.0" 23 | ), "Ray version must be at least 2.22.0" 24 | 25 | 26 | # Create a class to do batch inference. 27 | class LLMPredictor: 28 | def __init__( 29 | self, 30 | model_path, 31 | tokenizer_path, 32 | prompt_key="prompt", 33 | generation_key="generation", 34 | max_tokens=2048, 35 | max_model_len=4096, 36 | num_generations=1, 37 | temperature=0.0, 38 | top_p=1.0, 39 | stop_tokens=None, 40 | stop_token_ids=None, 41 | tensor_parallel_size=1, 42 | enable_prefix_caching=False, 43 | swap_space=16, 44 | ): 45 | seed = int(time.time() * 1e6) % int(1e9) 46 | # Create an LLM. 47 | self.prompt_key = prompt_key 48 | self.generation_key = generation_key 49 | self.llm = LLM( 50 | model=model_path, 51 | tokenizer=tokenizer_path, 52 | tensor_parallel_size=tensor_parallel_size, 53 | max_model_len=max_model_len, 54 | enable_prefix_caching=enable_prefix_caching, 55 | trust_remote_code=True, 56 | swap_space=swap_space, 57 | gpu_memory_utilization=0.95, 58 | seed=seed, 59 | ) 60 | self.sampling_params = SamplingParams( 61 | n=num_generations, 62 | max_tokens=max_tokens, 63 | temperature=temperature, 64 | top_p=top_p, 65 | stop=stop_tokens, 66 | stop_token_ids=stop_token_ids, 67 | ) 68 | 69 | def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, list]: 70 | # Generate texts from the prompts. 71 | # The output is a list of RequestOutput objects that contain the prompt, 72 | # generated text, and other information. 73 | outputs = self.llm.generate(batch[self.prompt_key], self.sampling_params) 74 | generated_text: List[str] = [] 75 | for output in outputs: 76 | generated_text.append([o.text for o in output.outputs]) 77 | return {**batch, self.generation_key: generated_text} 78 | 79 | 80 | def run_vllm_inference_distributed( 81 | ds, 82 | **kwargs, 83 | ): 84 | tensor_parallel_size = kwargs.get("tensor_parallel_size", 1) 85 | 86 | # Guarentee the compute resources is available 87 | if torch.cuda.device_count() < tensor_parallel_size: 88 | raise MemoryError( 89 | "Insufficient GPUs: tensor_parallel_size ({}) < available gpus ({})".format( 90 | tensor_parallel_size, torch.cuda.device_count() 91 | ) 92 | ) 93 | 94 | # Set number of instances. Each instance will use tensor_parallel_size GPUs. 95 | num_instances = torch.cuda.device_count() // tensor_parallel_size 96 | print("Launch {} instances for vllm inference.".format(num_instances)) 97 | 98 | # For tensor_parallel_size > 1, we need to create placement groups for vLLM 99 | # to use. Every actor has to have its own placement group. 100 | def scheduling_strategy_fn(): 101 | # One bundle per tensor parallel worker 102 | pg = ray.util.placement_group( 103 | [{"GPU": 1, "CPU": 1}] * tensor_parallel_size, strategy="STRICT_PACK" 104 | ) 105 | return dict( 106 | scheduling_strategy=PlacementGroupSchedulingStrategy( 107 | pg, placement_group_capture_child_tasks=True 108 | ) 109 | ) 110 | 111 | resources_kwarg: Dict[str, Any] = {} 112 | if tensor_parallel_size == 1: 113 | # For tensor_parallel_size == 1, we simply set num_gpus=1. 
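# (Each map_batches actor then reserves exactly one GPU, so all num_instances replicas can run concurrently.)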
114 | resources_kwarg["num_gpus"] = 1 115 | else: 116 | # Otherwise, we have to set num_gpus=0 and provide 117 | # a function that will create a placement group for 118 | # each instance. 119 | resources_kwarg["num_gpus"] = 0 120 | resources_kwarg["ray_remote_args_fn"] = scheduling_strategy_fn 121 | 122 | batch_size = min(ds.count() // num_instances + 1, 10000) 123 | # Apply batch inference for all input data. 124 | ds = ds.map_batches( 125 | LLMPredictor, 126 | # Set the concurrency to the number of LLM instances. 127 | concurrency=num_instances, 128 | # Specify the batch size for inference. 129 | batch_size=batch_size, 130 | fn_constructor_kwargs=kwargs, 131 | **resources_kwarg, 132 | ) 133 | 134 | return ds 135 | -------------------------------------------------------------------------------- /src/train_question_generator/qft_train/train.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Optional 3 | 4 | import torch 5 | 6 | from datasets import load_dataset 7 | from transformers import ( 8 | AutoModelForCausalLM, 9 | AutoTokenizer, 10 | HfArgumentParser, 11 | TrainingArguments, 12 | ) 13 | from trl import SFTTrainer 14 | 15 | 16 | # Define and parse arguments. 17 | @dataclass 18 | class ScriptArguments: 19 | """ 20 | These arguments vary depending on how many GPUs you have, \ 21 | what their capacity and features are, and what size model you want to train. 22 | """ 23 | 24 | per_device_train_batch_size: Optional[int] = field(default=1) 25 | per_device_eval_batch_size: Optional[int] = field(default=1) 26 | gradient_accumulation_steps: Optional[int] = field(default=4) 27 | learning_rate: Optional[float] = field(default=2e-5) 28 | weight_decay: Optional[float] = field(default=0.0) 29 | model_path: Optional[str] = field( 30 | default="meta-llama/Meta-Llama-3-8B", 31 | metadata={ 32 | "help": "The model that you want to train from the Hugging Face hub. E.g. gpt2, gpt2-xl, bert, etc." 33 | }, 34 | ) 35 | dataset_path: Optional[str] = field( 36 | default="openai/gsm8k", 37 | metadata={ 38 | "help": "", 39 | }, 40 | ) 41 | prompt_type: Optional[str] = field( 42 | default="qwen2-math", 43 | ) 44 | bf16: Optional[bool] = field( 45 | default=True, 46 | metadata={ 47 | "help": "This essentially cuts the training time in half if you want to sacrifice a little precision and have a supported GPU." 
48 | }, 49 | ) 50 | tf32: Optional[bool] = field( 51 | default=None, 52 | ) 53 | num_train_epochs: Optional[float] = field( 54 | default=1, 55 | metadata={"help": "The number of training epochs for the reward model."}, 56 | ) 57 | gradient_checkpointing: Optional[bool] = field( 58 | default=True, 59 | metadata={"help": "Enables gradient checkpointing."}, 60 | ) 61 | optim: Optional[str] = field( 62 | # default="adamw_hf", 63 | default="paged_adamw_32bit", 64 | # default="adamw_torch_fused", 65 | metadata={"help": "The optimizer to use."}, 66 | ) 67 | lr_scheduler_type: Optional[str] = field( 68 | default="cosine", 69 | metadata={"help": "The lr scheduler"}, 70 | ) 71 | 72 | max_training_samples: Optional[int] = field( 73 | default=-1, metadata={"help": "the maximum sample size"} 74 | ) 75 | 76 | max_length: Optional[int] = field(default=4096) 77 | output_dir: Optional[str] = field(default="./models/sft_model_llama3") 78 | 79 | 80 | parser = HfArgumentParser(ScriptArguments) 81 | script_args = parser.parse_args_into_dataclasses()[0] 82 | 83 | 84 | training_args = TrainingArguments( 85 | output_dir=script_args.output_dir, 86 | learning_rate=script_args.learning_rate, 87 | per_device_train_batch_size=script_args.per_device_train_batch_size, 88 | per_device_eval_batch_size=script_args.per_device_eval_batch_size, 89 | num_train_epochs=script_args.num_train_epochs, 90 | weight_decay=script_args.weight_decay, 91 | save_strategy="epoch", 92 | eval_strategy="epoch", 93 | gradient_accumulation_steps=script_args.gradient_accumulation_steps, 94 | gradient_checkpointing=script_args.gradient_checkpointing, 95 | remove_unused_columns=True, 96 | bf16=script_args.bf16, 97 | tf32=script_args.tf32, 98 | logging_strategy="steps", 99 | logging_steps=1, 100 | optim=script_args.optim, 101 | lr_scheduler_type=script_args.lr_scheduler_type, 102 | warmup_ratio=0.03, 103 | report_to="tensorboard", 104 | ) 105 | 106 | 107 | model = AutoModelForCausalLM.from_pretrained( 108 | script_args.model_path, 109 | torch_dtype=torch.bfloat16, 110 | use_flash_attention_2=True, 111 | trust_remote_code=True, 112 | ).to("cuda") 113 | tokenizer = AutoTokenizer.from_pretrained( 114 | script_args.model_path, trust_remote_code=True 115 | ) 116 | tokenizer.pad_token = tokenizer.eos_token 117 | print("We set the pad token as the eos token by default....") 118 | # tokenizer.truncation_side = "left" 119 | tokenizer.model_max_length = script_args.max_length 120 | # tokenizer.chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}" 121 | 122 | 123 | dataset = load_dataset(script_args.dataset_path) 124 | 125 | 126 | def formatting_prompts_func(example): 127 | if script_args.prompt_type == "qwen2-math": 128 | text = f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{example['query'].strip()}<|im_end|>" 129 | elif script_args.prompt_type == "deepseek-math": 130 | text = f"User: {example['query'].strip()}\n\n<|end▁of▁sentence|>" 131 | elif script_args.prompt_type == "deepseek-code": 132 | text = f"You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. 
For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\n" \ 133 | f"### Instruction:\n{example['query'].strip()}\n### Response:\n" 134 | elif script_args.prompt_type == "qwen2.5-code": 135 | text = f"<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n{example['query'].strip()}<|im_end|>" 136 | else: 137 | raise NotImplementedError( 138 | f"Prompt type {script_args.prompt_type} not implemented." 139 | ) 140 | 141 | return {"text": text} 142 | 143 | 144 | dataset = dataset.map(formatting_prompts_func, batched=False) 145 | train_dataset = dataset["train"] 146 | eval_dataset = dataset["test"] if "test" in dataset else None 147 | if script_args.max_training_samples > 0: 148 | train_dataset = train_dataset.shuffle(seed=42).select(range(script_args.max_training_samples)) 149 | 150 | 151 | # formatting_prompts_func 152 | 153 | trainer = SFTTrainer( 154 | model=model, 155 | tokenizer=tokenizer, 156 | train_dataset=train_dataset, 157 | eval_dataset=eval_dataset, 158 | args=training_args, 159 | # formatting_func=, 160 | dataset_text_field="text", 161 | max_seq_length=script_args.max_length, 162 | packing=True, 163 | ) 164 | 165 | trainer.train() 166 | print("Saving last checkpoint of the model") 167 | 168 | trainer.save_model(script_args.output_dir) 169 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 |

Unleashing Reasoning Capability of LLMs
via Scalable Question Synthesis from Scratch

3 | 4 |

5 | GitHub license 6 | Pretrained Models 7 | Blog 8 | Paper 9 | 10 |

11 | 12 | We introduce ScaleQuest, a scalable, cost-effective, and novel data synthesis method that utilizes small-size open-source models to generate questions from scratch without the need for seed data with complex augmentation constraints. 13 | 14 | ![](img/results.png) 15 | 16 | We release two question generator models and four problem-solving models. 17 | 18 | | Model | Type | MATH | Olympiad Bench | 🤗 HuggingFace
Download Link | 19 | | - | :-: | :-: | :-: | :-: | 20 | | ScaleQuest-DeepSeekMath-7B-QGen | question generator | - | - | [link](https://huggingface.co/dyyyyyyyy/ScaleQuest-DeepSeekMath-7B-QGen) | 21 | | ScaleQuest-Qwen2-Math-7B-QGen | question generator | - | - | [link](https://huggingface.co/dyyyyyyyy/ScaleQuest-Qwen2-Math-7B-QGen) | 22 | | Mistral-7B-ScaleQuest | problem solver | 62.9 | 26.8 | [link](https://huggingface.co/dyyyyyyyy/Mistral-7B-ScaleQuest) | 23 | | Llama3-8B-ScaleQuest | problem solver | 64.4 | 25.3 | [link](https://huggingface.co/dyyyyyyyy/Llama3-8B-ScaleQuest) | 24 | | DeepSeekMath-7B-ScaleQuest | problem solver | 66.6 | 29.9 | [link](https://huggingface.co/dyyyyyyyy/DeepSeekMath-7B-ScaleQuest) | 25 | | Qwen2-Math-7B-ScaleQuest | problem solver | 73.4 | 38.5 | [link](https://huggingface.co/dyyyyyyyy/Qwen2-Math-7B-ScaleQuest) | 26 | 27 | This repository contains our complete data synthesis method, organized into the following steps: 28 | 29 | ## Step 0: Requirements 30 | 31 | You should install the dependencies: 32 | 33 | ```bash 34 | conda create -n scalequest python=3.11 35 | conda activate scalequest 36 | pip install -r requirements.txt 37 | pip install flash-attn --no-build-isolation 38 | ``` 39 | 40 | 41 | ## Demo Usage 42 | 43 | Below is a question generator example using `ScaleQuest-Qwen2-Math-7B-QGen`: 44 | ```python 45 | from vllm import LLM, SamplingParams 46 | 47 | model_name = "dyyyyyyyy/ScaleQuest-Qwen2-Math-7B-QGen" 48 | pre_query_template = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n" 49 | stop_tokens = ["<|im_start|>", "<|im_end|>", "<|endoftext|>"] 50 | llm = LLM( 51 | model=model_name, 52 | tokenizer=model_name, 53 | tensor_parallel_size=1, 54 | max_model_len=4096, 55 | enable_prefix_caching=True, 56 | trust_remote_code=True, 57 | swap_space=16, 58 | gpu_memory_utilization=0.95, 59 | ) 60 | sampling_params = SamplingParams( 61 | n=4, 62 | max_tokens=1024, 63 | temperature=1.0, 64 | top_p=0.99, 65 | stop=stop_tokens, 66 | ) 67 | outputs = llm.generate(pre_query_template, sampling_params) 68 | # Print the outputs. 69 | for output in outputs: 70 | prompt = output.prompt 71 | for idx, generated_output in enumerate(output.outputs): 72 | generated_text = generated_output.text 73 | print(f"Sample {idx + 1}:") 74 | print(f"Prompt: {prompt!r}") 75 | print(f"Generated text: {generated_text!r}") 76 | print("-" * 50) 77 | ``` 78 | 79 | Below is a problem solver example using `Qwen2-Math-7B-ScaleQuest`: 80 | 81 | ```python 82 | import torch 83 | from transformers import AutoModelForCausalLM, AutoTokenizer 84 | 85 | model_name = "dyyyyyyyy/Qwen2-Math-7B-ScaleQuest" 86 | 87 | model = AutoModelForCausalLM.from_pretrained( 88 | model_name, 89 | torch_dtype=torch.bfloat16, 90 | device_map="auto" 91 | ) 92 | tokenizer = AutoTokenizer.from_pretrained(model_name) 93 | 94 | question = "Find the value of $x$ that satisfies the equation $4x+5 = 6x+7$."
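# The lines below rebuild the Qwen2-Math chat format by hand (system and user turns, the
# step-by-step / \boxed{} instruction, then the assistant header that the model completes).
# Assuming the released tokenizer ships the standard Qwen2 chat template, an equivalent
# prompt could also be built with:
#   messages = [
#       {"role": "system", "content": "You are a helpful assistant."},
#       {"role": "user", "content": question + "\nPlease reason step by step, and put your final answer within \\boxed{}."},
#   ]
#   full_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)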
95 | 96 | sys_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" 97 | query_prompt = "<|im_start|>user" + "\n" 98 | # {query} 99 | prompt_after_query = "\n" + "Please reason step by step, and put your final answer within \\boxed{}.<|im_end|>" + "\n" 100 | resp_prompt = "<|im_start|>assistant" + "\n" 101 | prompt_before_resp = "" 102 | # {resp} 103 | delim = "<|im_end|>" + "\n" 104 | 105 | prefix_prompt = f"{query_prompt}{question}{prompt_after_query}{resp_prompt}{prompt_before_resp}".rstrip(" ") 106 | full_prompt = sys_prompt + delim.join([prefix_prompt]) 107 | 108 | # print(full_prompt) 109 | 110 | inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device) 111 | outputs = model.generate(**inputs, max_new_tokens=512, do_sample=False) 112 | print(tokenizer.decode(outputs[0][len(inputs.input_ids[0]):], skip_special_tokens=True)) 113 | ``` 114 | 115 | ## Step 1: Train Query Generators 116 | 117 | 1. Training a question generator through question fine-tuning (code in the `src/train_question_generator/qft_train` folder). 118 | 2. Constructing preference data (code in the `src/train_question_generator/question_optim` folder) and performing question preference optimization (code in the `src/train_question_generator/qpo_train` folder). 119 | 120 | You can run QFT and QPO with the following commands: 121 | 122 | ```bash 123 | cd src/train_question_generator && bash scripts/run_dsmath_qft.sh 124 | cd src/train_question_generator && bash scripts/run_qwen2math_qft.sh 125 | ``` 126 | 127 | ## Step 2: Question Synthesis 128 | 129 | 1. Using the trained question generator to synthesize questions (code in the `src/data_generation` folder). 130 | 2. Applying a filtering process to the generated questions (code in the `src/data_generation/question_filtering` folder). 131 | 132 | ```bash 133 | cd src/data_generation && bash run.sh 134 | ``` 135 | 136 | ## Step 3: Response Synthesis 137 | 138 | 1. Generating responses (code in the `src/data_generation` folder). 139 | 2. Applying a reward filtering strategy (code in the `src/data_generation/reward_filtering` folder). 140 | 141 | ```bash 142 | cd src/data_generation && bash run.sh 143 | ``` 144 | 145 | ## Step 4: Instruction-Tuning & Evaluation 146 | 147 | We use the [DART-Math](https://github.com/hkust-nlp/dart-math) framework for instruction tuning and evaluation.
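If you want to feed the synthesized data into your own SFT stack instead, the sketch below shows one way to turn the reward-filtered output into a chat-style instruction-tuning file. It is a minimal example rather than part of the released pipeline: the input and output paths and the keep-best-response rule are assumptions, while the `query` / `response` / `score` field names follow `reward_filtering/rm_score.py`.

```python
import json
from collections import defaultdict

# Hypothetical paths; rm_score.py writes JSONL records with "query", "response", "score".
scored_path = "generation/rm_score/output.jsonl"
sft_path = "final_data/sft_data.jsonl"

# Group the sampled responses by question.
by_query = defaultdict(list)
with open(scored_path) as f:
    for line in f:
        record = json.loads(line)
        by_query[record["query"]].append(record)

# Keep the highest-reward response for each question (one simple selection rule;
# the released pipeline may apply different thresholds).
with open(sft_path, "w") as f:
    for query, records in by_query.items():
        best = max(records, key=lambda r: r["score"])
        example = {
            "messages": [
                {"role": "user", "content": query},
                {"role": "assistant", "content": best["response"]},
            ]
        }
        f.write(json.dumps(example) + "\n")
```

The resulting `messages`-format file can be consumed by most chat-template-based SFT trainers.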
148 | 149 | ## Citation 150 | 151 | ``` 152 | @article{ding2024unleashing, 153 | title={Unleashing Reasoning Capability of LLMs via Scalable Question Synthesis from Scratch}, 154 | author={Ding, Yuyang and Shi, Xinyu and Liang, Xiaobo and Li, Juntao and Zhu, Qiaoming and Zhang, Min}, 155 | journal={arXiv preprint arXiv:2410.18693}, 156 | year={2024} 157 | } 158 | ``` 159 | -------------------------------------------------------------------------------- /src/data_generation/gen.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import sys 4 | import argparse 5 | import json 6 | import time 7 | import random 8 | import numpy as np 9 | from tqdm import tqdm 10 | import ray 11 | import ray.data 12 | from vllm_gen import run_vllm_inference_distributed 13 | 14 | 15 | def get_args(): 16 | # Experiment Settings 17 | parser = argparse.ArgumentParser(description="Instruction Generation Manager.") 18 | 19 | # Query Generation Parameters 20 | parser.add_argument("--qry_gen", action="store_true") 21 | parser.add_argument( 22 | "--qry_num", 23 | type=int, 24 | default=1000, 25 | help="Total number of prompts to generate. If specified, repeat will be ignored.", 26 | ) 27 | parser.add_argument("--qry_prompt_type", type=str, default="qwen2-math") 28 | parser.add_argument( 29 | "--qry_model_path", type=str, default="Qwen/Qwen2-Math-1.5B-Instruct" 30 | ) 31 | parser.add_argument( 32 | "--qry_model_tp", 33 | type=int, 34 | default=1, 35 | help="Number of GPUs to use for tensor parallelism. Only used for Llama 70B models.", 36 | ) 37 | # parser.add_argument("--qry_model_dtype", type=str, default="bfloat16", choices=["float16", "bfloat16"]) 38 | parser.add_argument("--qry_temperature", type=float, default=1.0) 39 | parser.add_argument("--qry_top_p", type=float, default=1.0) 40 | parser.add_argument("--qry_max_tokens", type=int, default=1024) 41 | 42 | # Response Generation Parameters 43 | parser.add_argument("--res_gen", action="store_true") 44 | parser.add_argument( 45 | "--res_num_per_query", 46 | type=int, 47 | default=5, 48 | help="Number of samples to generate for one time.", 49 | ) 50 | parser.add_argument("--res_prompt_type", type=str, default="qwen2-math") 51 | parser.add_argument( 52 | "--res_model_path", type=str, default="Qwen/Qwen2-Math-1.5B-Instruct" 53 | ) 54 | parser.add_argument( 55 | "--res_model_tp", 56 | type=int, 57 | default=1, 58 | help="Number of GPUs to use for tensor parallelism. Only used for Llama 70B models.", 59 | ) 60 | # parser.add_argument("--res_model_dtype", type=str, default="bfloat16", choices=["float16", "bfloat16"]) 61 | parser.add_argument("--res_temperature", type=float, default=0.0) 62 | parser.add_argument("--res_top_p", type=float, default=1.0) 63 | parser.add_argument("--res_max_tokens", type=int, default=2048) 64 | 65 | # System Settings 66 | parser.add_argument("--max_model_len", type=int, default=4096) 67 | parser.add_argument("--swap_space", type=float, default=4) 68 | parser.add_argument("--output_folder", type=str, default="./data") 69 | parser.add_argument("--seed", type=int, default=None, help="Random seed.") 70 | 71 | return parser.parse_args() 72 | 73 | 74 | args = get_args() 75 | print(f"Instruction Generation Manager. 
Arguments: {args}") # For logging 76 | 77 | if args.qry_gen: 78 | if "qwen2" in args.qry_prompt_type: 79 | pre_query_template = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n" 80 | stop_tokens = ["<|im_start|>", "<|im_end|>", "<|endoftext|>"] 81 | elif "deepseek" in args.qry_prompt_type: 82 | pre_query_template = "<|begin▁of▁sentence|>User: " 83 | stop_tokens = ["<|begin▁of▁sentence|>", "<|end▁of▁sentence|>"] 84 | elif "qwen2.5-code" in args.qry_prompt_type: 85 | pre_query_template = "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n" 86 | stop_tokens = ["<|im_start|>", "<|im_end|>", "<|endoftext|>"] 87 | else: 88 | raise NotImplementedError( 89 | f"Query prompt type {args.qry_prompt_type} is not implemented" 90 | ) 91 | 92 | # Generate Query 93 | 94 | dataset = ray.data.from_items( 95 | [ 96 | { 97 | "query_id": query_idx, 98 | "qry_gen_model": args.qry_model_path, 99 | "prompt_query_gen": pre_query_template, 100 | "query_gen_temp": args.qry_temperature, 101 | "qry_top_p": args.qry_top_p, 102 | "qry_max_tokens": args.qry_max_tokens, 103 | } 104 | for query_idx in range(args.qry_num) 105 | ] 106 | ) 107 | dataset = run_vllm_inference_distributed( 108 | ds=dataset, 109 | model_path=args.qry_model_path, 110 | tokenizer_path=args.qry_model_path, 111 | prompt_key="prompt_query_gen", 112 | generation_key="generation_query_list", 113 | max_tokens=args.qry_max_tokens, 114 | max_model_len=args.max_model_len, 115 | num_generations=1, 116 | temperature=args.qry_temperature, 117 | top_p=args.qry_top_p, 118 | stop_tokens=stop_tokens, 119 | tensor_parallel_size=args.qry_model_tp, 120 | enable_prefix_caching=True, 121 | swap_space=args.swap_space, 122 | ) 123 | 124 | def flatten_batch_and_strip(line_data): 125 | generation_query_list = line_data.pop("generation_query_list") 126 | expanded_rows = [ 127 | {**line_data, "query": generation_query.strip()} 128 | for generation_query in generation_query_list 129 | ] 130 | return expanded_rows 131 | 132 | dataset = dataset.flat_map(flatten_batch_and_strip, concurrency=4) 133 | qry_gen_output_path = os.path.join( 134 | args.output_folder, 135 | f"{args.qry_prompt_type}_querygen{args.qry_num}_temp{args.qry_temperature}_topp{args.qry_top_p}", 136 | ) 137 | dataset.write_json(qry_gen_output_path) 138 | 139 | # Generate Response 140 | if args.res_gen: 141 | time.sleep(30) # wait for the GPU resources to be released 142 | qry_gen_output_path = os.path.join( 143 | args.output_folder, 144 | f"{args.qry_prompt_type}_querygen{args.qry_num}_temp{args.qry_temperature}_topp{args.qry_top_p}", 145 | ) 146 | dataset = ray.data.read_json(qry_gen_output_path) 147 | 148 | if "qwen2-math" in args.res_prompt_type: 149 | res_generation_template = ( 150 | "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" 151 | "<|im_start|>user\n{input}\nPlease reason step by step, and put your final answer within \\boxed{{}}.<|im_end|>\n" 152 | "<|im_start|>assistant\n" 153 | ) 154 | stop_tokens = ["<|im_start|>", "<|im_end|>", "<|endoftext|>"] 155 | elif "deepseek-math" in args.res_prompt_type: 156 | res_generation_template = ( 157 | "<|begin▁of▁sentence|>User: {input}\nPlease reason step by step, " 158 | "and put your final answer within \\boxed{{}}.\n\nAssistant:" 159 | ) 160 | stop_tokens = ["<|begin▁of▁sentence|>", "<|end▁of▁sentence|>"] 161 | elif "qwen2.5-code" in args.res_prompt_type: 162 | res_generation_template = ( 163 | "<|im_start|>system\nYou are Qwen, created by Alibaba 
Cloud. You are a helpful assistant.<|im_end|>\n" 164 | "<|im_start|>user\n{input}<|im_end|>\n" 165 | "<|im_start|>assistant\n" 166 | ) 167 | stop_tokens = ["<|im_start|>", "<|im_end|>", "<|endoftext|>"] 168 | else: 169 | raise NotImplementedError( 170 | f"Response prompt type {args.res_prompt_type} is not implemented" 171 | ) 172 | 173 | def preprocess_response_template(line_data): 174 | prompt_res_gen = res_generation_template.format(input=line_data["query"]) 175 | line_data.update( 176 | { 177 | "res_gen_model": args.res_model_path, 178 | "prompt_res_gen": prompt_res_gen, 179 | "res_gen_temp": args.res_temperature, 180 | "res_top_p": args.res_top_p, 181 | "res_max_tokens": args.res_max_tokens, 182 | } 183 | ) 184 | expanded_rows = [ 185 | {**line_data, "sample_idx": sample_idx} 186 | for sample_idx in range(args.res_num_per_query) 187 | ] 188 | return expanded_rows 189 | 190 | dataset = dataset.flat_map(preprocess_response_template, concurrency=4) 191 | dataset = run_vllm_inference_distributed( 192 | ds=dataset, 193 | model_path=args.res_model_path, 194 | tokenizer_path=args.res_model_path, 195 | prompt_key="prompt_res_gen", 196 | generation_key="response", 197 | max_tokens=args.res_max_tokens, 198 | max_model_len=args.max_model_len, 199 | num_generations=1, 200 | temperature=args.res_temperature, 201 | top_p=args.res_top_p, 202 | stop_tokens=stop_tokens, 203 | tensor_parallel_size=args.res_model_tp, 204 | swap_space=args.swap_space, 205 | ) 206 | 207 | def strip_data(line_data): 208 | line_data["response"] = line_data["response"][0].strip() 209 | return line_data 210 | 211 | def filter_data(line_data): 212 | response = line_data["response"] 213 | has_answer = "boxed" in response or "he answer is" in response or "final answer is" in response 214 | return has_answer 215 | 216 | dataset = dataset.map(strip_data, concurrency=4) 217 | if "math" in args.res_prompt_type: 218 | dataset = dataset.filter(filter_data, concurrency=4) 219 | 220 | res_gen_output_path = os.path.join( 221 | args.output_folder, 222 | f"{args.res_prompt_type}_resgen{args.qry_num}x{args.res_num_per_query}_temp{args.res_temperature}_topp{args.res_top_p}", 223 | ) 224 | dataset.write_json(res_gen_output_path) 225 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /src/train_question_generator/qpo_train/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass, field 3 | from typing import Optional 4 | 5 | import numpy as np 6 | import torch 7 | from datasets import Dataset, load_dataset 8 | from dpo import PreferenceTrainer, PreferenceDataCollatorWithPadding 9 | from transformers import ( 10 | AutoModelForCausalLM, 11 | AutoTokenizer, 12 | HfArgumentParser, 13 | ) 14 | from trl import DPOConfig, DPOTrainer 15 | 16 | 17 | @dataclass 18 | class ScriptArguments: 19 | """ 20 | The arguments for the DPO training script. 21 | """ 22 | 23 | # data parameters, i.e., the KL penalty in the paper 24 | beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"}) 25 | 26 | # training parameters 27 | model_path: Optional[str] = field( 28 | default="HuggingFaceH4/mistral-7b-sft-beta", 29 | metadata={"help": "the location of the model name or path"}, 30 | ) 31 | ref_model: Optional[str] = field( 32 | default="", 33 | metadata={"help": "the location of the SFT model name or path"}, 34 | ) 35 | dataset_path: Optional[str] = field( 36 | default="./data/uf_split0_responses_K8_reward.json", 37 | metadata={"help": "the location of the dataset name or path"}, 38 | ) 39 | prompt_type: Optional[str] = field( 40 | default="qwen2-math", 41 | ) 42 | dpo_type: Optional[str] = field( 43 | default="query", 44 | ) 45 | learning_rate: Optional[float] = field(default=5e-7, metadata={"help": "optimizer learning rate"}) 46 | lr_scheduler_type: Optional[str] = field( 47 | default="constant_with_warmup", metadata={"help": "the lr scheduler type"} 48 | ) 49 | warmup_steps: Optional[int] = field(default=100, metadata={"help": "the number of warmup steps"}) 50 | weight_decay: Optional[float] = field(default=0.01, metadata={"help": "the weight decay"}) 51 | optimizer_type: Optional[str] = field(default="paged_adamw_32bit", metadata={"help": "the optimizer type"}) 52 | 53 | per_device_train_batch_size: Optional[int] = field(default=1, metadata={"help": "train batch size per device"}) 54 | per_device_eval_batch_size: Optional[int] = field(default=1, metadata={"help": "eval batch size per device"}) 55 | gradient_accumulation_steps: Optional[int] = field( 56 | default=16, metadata={"help": "the number of gradient accumulation steps"} 57 | ) 58 | gradient_checkpointing: Optional[bool] = field( 59 | default=True, metadata={"help": "whether to use gradient checkpointing"} 60 | ) 61 | 62 | eos_padding: Optional[bool] = field(default=True, metadata={"help": "whether to pad with eos token"}) 63 | lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"}) 64 | lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"}) 65 | lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"}) 66 | 67 | margin_scale: Optional[float] = field(default=1.0, metadata={"help": "the margin scale"}) 68 | 69 | max_prompt_length: Optional[int] = field(default=1000, metadata={"help": "the maximum prompt length"}) 70 | max_length: Optional[int] = field(default=2048, metadata={"help": "the maximum sequence length"}) 71 | max_steps: Optional[int] = field(default=20, metadata={"help": "max number of training steps"}) 72 | num_train_epochs: Optional[int] = field(default=2, metadata={"help": "max number of training epochs"}) 
73 | logging_steps: Optional[int] = field(default=2, metadata={"help": "the logging frequency"}) 74 | save_strategy: Optional[str] = field(default="epoch", metadata={"help": "the saving strategy"}) 75 | save_steps: Optional[int] = field(default=50000, metadata={"help": "the saving frequency"}) 76 | eval_steps: Optional[int] = field(default=100, metadata={"help": "the evaluation frequency"}) 77 | run_name: Optional[str] = field(default="dpo_soft", metadata={"help": "the run name"}) 78 | loss_type: Optional[str] = field(default="sigmoid", metadata={"help": "the loss type"}) 79 | output_dir: Optional[str] = field(default="./dpo_soft", metadata={"help": "the output directory"}) 80 | log_freq: Optional[int] = field(default=1, metadata={"help": "the logging frequency"}) 81 | 82 | # instrumentation 83 | sanity_check: Optional[bool] = field(default=False, metadata={"help": "only train on 1000 samples"}) 84 | 85 | max_training_samples: Optional[int] = field(default=-1, metadata={"help": "the maximum sample size"}) 86 | 87 | choose_type: Optional[str] = field(default="max_random", metadata={"help": "the choose type"}) 88 | 89 | report_to: Optional[str] = field( 90 | default="tensorboard", 91 | metadata={ 92 | "help": 'The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,' 93 | '`"comet_ml"`, `"mlflow"`, `"neptune"`, `"tensorboard"`,`"clearml"` and `"wandb"`. ' 94 | 'Use `"all"` to report to all integrations installed, `"none"` for no integrations.' 95 | }, 96 | ) 97 | # debug argument for distributed training 98 | ignore_bias_buffers: Optional[bool] = field( 99 | default=False, 100 | metadata={ 101 | "help": "fix for DDP issues with LM bias/mask buffers - invalid scalar type,`inplace operation. See" 102 | "https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992" 103 | }, 104 | ) 105 | eot_token: Optional[str] = field(default="", metadata={"help": "the end of text token"}) 106 | mask_prompt: Optional[bool] = field(default=False, metadata={"help": "mask prompt"}) 107 | len_penalty: Optional[float] = field(default=0, metadata={"help": "the length penalty"}) 108 | 109 | 110 | def prepare_data( 111 | dataset, 112 | sanity_check: bool = False, 113 | cache_dir: str = None, 114 | num_proc=24, 115 | margin_scale=1, 116 | choose_type="random", 117 | eot_token="", 118 | length_penalty=0, 119 | ) -> Dataset: 120 | """Prepare the dataset for DPO training by rejection sampling. 121 | We implement different strategies to select pairs, including 122 | max_min: best v.s. worst 123 | max_random: best v.s. random from the remaining; 124 | max_max: best v.s. second best 125 | max_min_p: best v.s. 
worst but we additionally add a length penalty in the reward value 126 | """ 127 | 128 | pos = [] 129 | neg = [] 130 | prompts = [] 131 | 132 | margin = [] 133 | for sample in dataset: 134 | if choose_type == "random": 135 | idx0 = 0 136 | idx1 = 1 137 | elif choose_type == "max_random": 138 | idx0 = np.argmax(sample["rewards"]) 139 | if idx0 == 0: 140 | idx1 = 1 141 | else: 142 | idx1 = 0 143 | elif choose_type == "max_min": 144 | idx0 = np.argmax(sample["rewards"]) 145 | idx1 = np.argmin(sample["rewards"]) 146 | elif choose_type == "max_max": 147 | sorted_indices = np.argsort(sample["rewards"]) 148 | idx0 = sorted_indices[-1] 149 | idx1 = sorted_indices[-2] 150 | elif choose_type == "max_min_p": 151 | r = [ 152 | sample["rewards"][i] - length_penalty * len(sample["responses"][i]) 153 | for i in range(len(sample["rewards"])) 154 | ] 155 | idx0 = np.argmax(r) 156 | idx1 = np.argmin(r) 157 | else: 158 | raise NotImplementedError 159 | 160 | if type(idx0) == np.ndarray or type(idx0) == list: 161 | assert len(idx0) == len(idx1) 162 | for i in range(len(idx0)): 163 | prompts.append(sample["prompt"]) 164 | pos.append(sample["responses"][idx0[i]] + eot_token) 165 | neg.append(sample["responses"][idx1[i]] + eot_token) 166 | margin.append((sample["rewards"][idx0[i]] - sample["rewards"][idx1[i]]) * margin_scale) 167 | else: 168 | if sample["rewards"][idx0] > sample["rewards"][idx1]: 169 | prompts.append(sample["prompt"]) 170 | pos.append(sample["responses"][idx0] + eot_token) 171 | neg.append(sample["responses"][idx1] + eot_token) 172 | margin.append((sample["rewards"][idx0] - sample["rewards"][idx1]) * margin_scale) 173 | elif sample["rewards"][idx0] < sample["rewards"][idx1]: 174 | prompts.append(sample["prompt"]) 175 | pos.append(sample["responses"][idx1] + eot_token) 176 | neg.append(sample["responses"][idx0] + eot_token) 177 | margin.append((-sample["rewards"][idx0] + sample["rewards"][idx1]) * margin_scale) 178 | dataset = Dataset.from_dict({"prompt": prompts, "chosen": pos, "rejected": neg, "margin": margin}) 179 | 180 | if sanity_check: 181 | dataset = dataset.select(range(min(len(dataset), 100))) 182 | 183 | return dataset 184 | 185 | 186 | if __name__ == "__main__": 187 | parser = HfArgumentParser(ScriptArguments) 188 | script_args = parser.parse_args_into_dataclasses()[0] 189 | 190 | # 1. 
load a pretrained model 191 | model = AutoModelForCausalLM.from_pretrained( 192 | script_args.model_path, 193 | use_flash_attention_2=True, 194 | torch_dtype=torch.float16, 195 | ) 196 | model.config.use_cache = False 197 | 198 | if script_args.ignore_bias_buffers: 199 | # torch distributed hack 200 | model._ddp_params_and_buffers_to_ignore = [ 201 | name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool 202 | ] 203 | 204 | if script_args.ref_model: 205 | ref_name = script_args.ref_model 206 | else: 207 | ref_name = script_args.model_path 208 | 209 | model_ref = AutoModelForCausalLM.from_pretrained( 210 | ref_name, 211 | torch_dtype=torch.bfloat16, 212 | use_flash_attention_2=True, 213 | ) 214 | tokenizer = AutoTokenizer.from_pretrained(script_args.model_path) 215 | 216 | if script_args.eos_padding: 217 | tokenizer.pad_token = tokenizer.eos_token 218 | else: 219 | tokenizer.add_special_tokens({"pad_token": "[PAD]"}) 220 | model.config.vocab_size += 1 221 | model_ref.config.vocab_size += 1 222 | model.config.pad_token_id = tokenizer.pad_token_id 223 | model_ref.config.pad_token_id = tokenizer.pad_token_id 224 | model.resize_token_embeddings(len(tokenizer)) 225 | model_ref.resize_token_embeddings(len(tokenizer)) 226 | 227 | def tokenize(sample): 228 | tokenized_pos = tokenizer(sample["prompt"].replace("", "") + "\n" + sample["chosen"]) 229 | tokenized_neg = tokenizer(sample["prompt"].replace("", "") + "\n" + sample["rejected"]) 230 | prompt_id = tokenizer(sample["prompt"]) 231 | sample["tprompdt_ids"] = prompt_id["input_ids"] 232 | sample["tchosen_input_ids"] = tokenized_pos["input_ids"] 233 | sample["trejected_input_ids"] = tokenized_neg["input_ids"] 234 | return sample 235 | 236 | def process_dataset(sample): 237 | if script_args.dpo_type == "query": 238 | if "deepseek" in script_args.prompt_type: 239 | prompt = "<|begin▁of▁sentence|>User: " 240 | elif "qwen2" in script_args.prompt_type: 241 | prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n" 242 | elif script_args.dpo_type == "response": 243 | if "deepseek" in script_args.prompt_type: 244 | prompt = "<|begin▁of▁sentence|>User: " + sample["query"] + "\nPlease reason step by step, and put your final answer within \\boxed{}.\n\nAssistant:" 245 | elif "qwen2" in script_args.prompt_type: 246 | prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n" + sample["query"] + "\nPlease reason step by step, and put your final answer within \\boxed{}.<|im_end|>\n<|im_start|>assistant\n" 247 | return {"prompt": prompt} 248 | 249 | # 2. Load the Stack-exchange paired dataset 250 | dataset = load_dataset(script_args.dataset_path, split="train").train_test_split(test_size=100) 251 | train_dataset = dataset["train"].map(process_dataset) 252 | # train_dataset = train_dataset.map(process_dataset) 253 | # train_dataset = prepare_data( 254 | # dataset=train_dataset, 255 | # margin_scale=script_args.margin_scale, 256 | # sanity_check=script_args.sanity_check, 257 | # choose_type=script_args.choose_type, 258 | # eot_token=script_args.eot_token, 259 | # length_penalty=script_args.len_penalty, 260 | # ) 261 | 262 | if script_args.max_training_samples > 0: 263 | train_dataset = train_dataset.select(range(script_args.max_training_samples)) 264 | 265 | # 3. 
Load evaluation dataset 266 | eval_dataset = dataset["test"].map(process_dataset) 267 | # eval_dataset = prepare_data( 268 | # dataset=eval_dataset, 269 | # sanity_check=True, 270 | # margin_scale=script_args.margin_scale, 271 | # eot_token=script_args.eot_token, 272 | # ) 273 | 274 | # 4. initialize training arguments: 275 | 276 | training_args = DPOConfig( 277 | per_device_train_batch_size=script_args.per_device_train_batch_size, 278 | per_device_eval_batch_size=script_args.per_device_eval_batch_size, 279 | # max_steps=script_args.max_steps, 280 | num_train_epochs=script_args.num_train_epochs, 281 | save_strategy=script_args.save_strategy, 282 | logging_steps=script_args.logging_steps, 283 | save_steps=script_args.save_steps, 284 | gradient_accumulation_steps=script_args.gradient_accumulation_steps, 285 | gradient_checkpointing=script_args.gradient_checkpointing, 286 | learning_rate=script_args.learning_rate, 287 | evaluation_strategy="steps", 288 | eval_steps=script_args.eval_steps, 289 | output_dir=script_args.output_dir, 290 | report_to=script_args.report_to, 291 | lr_scheduler_type=script_args.lr_scheduler_type, 292 | warmup_steps=script_args.warmup_steps, 293 | # optim=script_args.optimizer_type, 294 | bf16=True, 295 | remove_unused_columns=False, 296 | run_name=script_args.run_name, 297 | ) 298 | print(training_args) 299 | 300 | # 5. initialize the DPO trainer 301 | dpo_trainer = PreferenceTrainer( 302 | model, 303 | model_ref, 304 | args=training_args, 305 | beta=script_args.beta, 306 | train_dataset=train_dataset, 307 | eval_dataset=eval_dataset, 308 | tokenizer=tokenizer, 309 | loss_type=script_args.loss_type, 310 | max_prompt_length=script_args.max_prompt_length, 311 | max_length=script_args.max_length, 312 | mask_prompt=script_args.mask_prompt, 313 | len_penalty=script_args.len_penalty, 314 | ) 315 | # data_collator = PreferenceDataCollatorWithPadding( 316 | # tokenizer, 317 | # max_length=script_args.max_length, 318 | # max_prompt_length=script_args.max_prompt_length, 319 | # is_encoder_decoder=False, 320 | # ) 321 | # dpo_trainer = DPOTrainer( 322 | # model, 323 | # model_ref, 324 | # args=training_args, 325 | # beta=script_args.beta, 326 | # train_dataset=train_dataset, 327 | # eval_dataset=eval_dataset, 328 | # tokenizer=tokenizer, 329 | # max_length=script_args.max_length, 330 | # max_prompt_length=script_args.max_prompt_length, 331 | # loss_type=script_args.loss_type, 332 | # data_collator=data_collator, 333 | # ) 334 | print("begin to train") 335 | 336 | # 6. train 337 | dpo_trainer.train() 338 | dpo_trainer.save_model(script_args.output_dir) 339 | 340 | # 7. 
save 341 | output_dir = os.path.join(script_args.output_dir, "final_checkpoint") 342 | dpo_trainer.model.save_pretrained(output_dir) 343 | -------------------------------------------------------------------------------- /src/train_question_generator/qpo_train/dpo.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from datasets import Dataset 7 | # from peft import AutoPeftModelForCausalLM, LoraConfig 8 | from torch import nn 9 | from torch.nn.utils.rnn import pad_sequence 10 | from transformers import ( 11 | DataCollator, 12 | PreTrainedModel, 13 | PreTrainedTokenizerBase, 14 | TrainerCallback, 15 | TrainingArguments, 16 | ) 17 | from transformers.trainer_callback import TrainerCallback 18 | from transformers.trainer_utils import EvalLoopOutput 19 | from trl import DPOTrainer 20 | 21 | # Define and parse arguments. 22 | 23 | @dataclass 24 | class PreferenceDataCollatorWithPadding: 25 | tokenizer: PreTrainedTokenizerBase 26 | model: Optional[PreTrainedModel] = None 27 | padding: Union[bool, str] = True 28 | max_length: Optional[int] = None 29 | max_prompt_length: Optional[int] = None 30 | label_pad_token_id: int = -100 31 | padding_value: int = 0 32 | truncation_mode: str = "keep_end" 33 | is_encoder_decoder: Optional[bool] = False 34 | max_target_length: Optional[int] = None 35 | mask_prompt: Optional[bool] = False 36 | 37 | def tokenize_batch_element( 38 | self, 39 | prompt: str, 40 | chosen: str, 41 | rejected: str, 42 | ) -> Dict: 43 | """Tokenize a single batch element. 44 | 45 | At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation 46 | in case the prompt + chosen or prompt + rejected responses is/are too long. First 47 | we truncate the prompt; if we're still too long, we truncate the chosen/rejected. 48 | 49 | We also create the labels for the chosen/rejected responses, which are of length equal to 50 | the sum of the length of the prompt and the chosen/rejected response, with 51 | label_pad_token_id for the prompt tokens. 
52 | """ 53 | batch = {} 54 | 55 | if not self.is_encoder_decoder: 56 | chosen_tokens = self.tokenizer(chosen, add_special_tokens=False) 57 | rejected_tokens = self.tokenizer(rejected, add_special_tokens=False) 58 | prompt_tokens = self.tokenizer(prompt, add_special_tokens=False) 59 | 60 | eos_token_id = self.tokenizer.eos_token_id 61 | # Get indices in list prompt_tokens["input_ids"] that equals the EOS token (often 0) 62 | eos_indices_prompt = [i for i, x in enumerate(prompt_tokens["input_ids"]) if x == eos_token_id] 63 | # attention mask these indices to eos_token_id 64 | if self.mask_prompt: 65 | new_attention_mask = [0 for i, p in enumerate(prompt_tokens["attention_mask"])] 66 | else: 67 | new_attention_mask = [ 68 | 0 if i in eos_indices_prompt else p for i, p in enumerate(prompt_tokens["attention_mask"]) 69 | ] 70 | prompt_tokens["attention_mask"] = new_attention_mask 71 | 72 | # do the same for chosen and rejected 73 | eos_indices_chosen = [i for i, x in enumerate(chosen_tokens["input_ids"]) if x == eos_token_id] 74 | new_attention_mask_c = [ 75 | 0 if i in eos_indices_chosen else p for i, p in enumerate(chosen_tokens["attention_mask"]) 76 | ] 77 | chosen_tokens["attention_mask"] = new_attention_mask_c 78 | 79 | eos_indices_rejected = [i for i, x in enumerate(rejected_tokens["input_ids"]) if x == eos_token_id] 80 | new_attention_mask_r = [ 81 | 0 if i in eos_indices_rejected else p for i, p in enumerate(rejected_tokens["attention_mask"]) 82 | ] 83 | rejected_tokens["attention_mask"] = new_attention_mask_r 84 | 85 | # add EOS token to end of prompt 86 | 87 | chosen_tokens["input_ids"].append(self.tokenizer.eos_token_id) 88 | chosen_tokens["attention_mask"].append(1) 89 | 90 | rejected_tokens["input_ids"].append(self.tokenizer.eos_token_id) 91 | rejected_tokens["attention_mask"].append(1) 92 | 93 | longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"])) 94 | 95 | # if combined sequence is too long, truncate the prompt 96 | if len(prompt_tokens["input_ids"]) + longer_response_length > self.max_length: 97 | if self.truncation_mode == "keep_start": 98 | prompt_tokens = {k: v[: self.max_prompt_length] for k, v in prompt_tokens.items()} 99 | elif self.truncation_mode == "keep_end": 100 | prompt_tokens = {k: v[-self.max_prompt_length :] for k, v in prompt_tokens.items()} 101 | else: 102 | raise ValueError(f"Unknown truncation mode: {self.truncation_mode}") 103 | 104 | # if that's still too long, truncate the response 105 | if len(prompt_tokens["input_ids"]) + longer_response_length > self.max_length: 106 | chosen_tokens = {k: v[: self.max_length - self.max_prompt_length] for k, v in chosen_tokens.items()} 107 | rejected_tokens = { 108 | k: v[: self.max_length - self.max_prompt_length] for k, v in rejected_tokens.items() 109 | } 110 | 111 | # Create labels 112 | chosen_sequence_tokens = {k: prompt_tokens[k] + chosen_tokens[k] for k in chosen_tokens} 113 | rejected_sequence_tokens = {k: prompt_tokens[k] + rejected_tokens[k] for k in rejected_tokens} 114 | chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:] 115 | chosen_sequence_tokens["labels"][: len(prompt_tokens["input_ids"])] = [self.label_pad_token_id] * len( 116 | prompt_tokens["input_ids"] 117 | ) 118 | rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:] 119 | rejected_sequence_tokens["labels"][: len(prompt_tokens["input_ids"])] = [self.label_pad_token_id] * len( 120 | prompt_tokens["input_ids"] 121 | ) 122 | 123 | for k, toks in { 124 | "chosen": 
chosen_sequence_tokens, 125 | "rejected": rejected_sequence_tokens, 126 | "prompt": prompt_tokens, 127 | }.items(): 128 | for type_key, tokens in toks.items(): 129 | if type_key == "token_type_ids": 130 | continue 131 | batch[f"{k}_{type_key}"] = tokens 132 | 133 | else: 134 | raise NotImplementedError 135 | 136 | batch["prompt"] = prompt 137 | batch["chosen"] = prompt + chosen 138 | batch["rejected"] = prompt + rejected 139 | batch["chosen_response_only"] = chosen 140 | batch["rejected_response_only"] = rejected 141 | 142 | return batch 143 | 144 | def collate(self, batch): 145 | # first, pad everything to the same length 146 | padded_batch = {} 147 | for k in batch[0].keys(): 148 | if k.endswith("_input_ids") or k.endswith("_attention_mask") or k.endswith("_labels"): 149 | if self.is_encoder_decoder: 150 | to_pad = [torch.LongTensor(ex[k]) for ex in batch] 151 | 152 | if (k.startswith("prompt")) and (k.endswith("input_ids")): 153 | padding_value = self.tokenizer.pad_token_id 154 | elif k.endswith("_attention_mask"): 155 | padding_value = 0 156 | elif (k.startswith("chosen")) or (k.startswith("rejected")) or ("decoder" in k): 157 | padding_value = self.label_pad_token_id 158 | else: 159 | raise ValueError(f"Unexpected key in batch '{k}'") 160 | padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value) 161 | else: 162 | # adapted from https://stackoverflow.com/questions/73256206 163 | if "prompt" in k: 164 | to_pad = [torch.LongTensor(ex[k][::-1]) for ex in batch] 165 | else: 166 | to_pad = [torch.LongTensor(ex[k]) for ex in batch] 167 | if k.endswith("_input_ids"): 168 | padding_value = self.tokenizer.pad_token_id 169 | elif k.endswith("_labels"): 170 | padding_value = self.label_pad_token_id 171 | elif k.endswith("_attention_mask"): 172 | padding_value = self.padding_value 173 | else: 174 | raise ValueError(f"Unexpected key in batch '{k}'") 175 | 176 | padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value) 177 | # for the prompt, flip back so padding is on left side 178 | if "prompt" in k: 179 | padded_batch[k] = padded_batch[k].flip(dims=[1]) 180 | else: 181 | padded_batch[k] = [ex[k] for ex in batch] 182 | 183 | return padded_batch 184 | 185 | def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: 186 | tokenized_batch = [] 187 | 188 | for feature in features: 189 | prompt = feature["prompt"] 190 | chosen = feature["chosen"] 191 | rejected = feature["rejected"] 192 | 193 | batch_element = self.tokenize_batch_element(prompt, chosen, rejected) 194 | # batch_element["margin"] = feature["margin"] 195 | tokenized_batch.append(batch_element) 196 | 197 | # return collated batch 198 | return self.collate(tokenized_batch) 199 | 200 | 201 | class PreferenceTrainer(DPOTrainer): 202 | def __init__( 203 | self, 204 | model: Union[PreTrainedModel, nn.Module] = None, 205 | ref_model: Optional[Union[PreTrainedModel, nn.Module]] = None, 206 | beta: float = 0.1, 207 | loss_type: Literal["sigmoid", "hinge", "cross_entropy", "kl", "rev_kl", "raft"] = "rev_kl", 208 | args: TrainingArguments = None, 209 | data_collator: Optional[DataCollator] = None, 210 | label_pad_token_id: int = -100, 211 | padding_value: int = 0, 212 | truncation_mode: str = "keep_end", 213 | train_dataset: Optional[Dataset] = None, 214 | eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, 215 | tokenizer: Optional[PreTrainedTokenizerBase] = None, 216 | model_init: Optional[Callable[[], PreTrainedModel]] = None, 217 | callbacks: 
Optional[List[TrainerCallback]] = None, 218 | optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = ( 219 | None, 220 | None, 221 | ), 222 | preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, 223 | max_length: Optional[int] = None, 224 | max_prompt_length: Optional[int] = None, 225 | max_target_length: Optional[int] = None, 226 | peft_config: Optional[Dict] = None, 227 | is_encoder_decoder: Optional[bool] = None, 228 | disable_dropout: bool = True, 229 | generate_during_eval: bool = False, 230 | compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None, 231 | mask_prompt: Optional[bool] = False, 232 | len_penalty: float = 0, 233 | ): 234 | 235 | if data_collator is None: 236 | data_collator = PreferenceDataCollatorWithPadding( 237 | tokenizer, 238 | max_length=max_length, 239 | max_prompt_length=max_prompt_length, 240 | label_pad_token_id=label_pad_token_id, 241 | padding_value=padding_value, 242 | truncation_mode=truncation_mode, 243 | is_encoder_decoder=False, 244 | max_target_length=max_target_length, 245 | mask_prompt=mask_prompt, 246 | ) 247 | super().__init__( 248 | model=model, 249 | ref_model=ref_model, 250 | beta=beta, 251 | loss_type=loss_type, 252 | args=args, 253 | data_collator=data_collator, 254 | label_pad_token_id=label_pad_token_id, 255 | padding_value=padding_value, 256 | truncation_mode=truncation_mode, 257 | train_dataset=train_dataset, 258 | eval_dataset=eval_dataset, 259 | tokenizer=tokenizer, 260 | model_init=model_init, 261 | callbacks=callbacks, 262 | optimizers=optimizers, 263 | preprocess_logits_for_metrics=preprocess_logits_for_metrics, 264 | max_length=max_length, 265 | max_prompt_length=max_prompt_length, 266 | max_target_length=max_target_length, 267 | peft_config=peft_config, 268 | is_encoder_decoder=is_encoder_decoder, 269 | disable_dropout=disable_dropout, 270 | generate_during_eval=generate_during_eval, 271 | compute_metrics=compute_metrics, 272 | ) 273 | self.use_dpo_data_collator = True 274 | self.len_penalty = len_penalty 275 | 276 | def dpo_loss( 277 | self, 278 | policy_chosen_logps: torch.FloatTensor, 279 | policy_rejected_logps: torch.FloatTensor, 280 | reference_chosen_logps: torch.FloatTensor, 281 | reference_rejected_logps: torch.FloatTensor, 282 | reference_free: bool = False, 283 | margin: Optional[torch.FloatTensor] = None, 284 | len_penalty: float = 0, 285 | ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: 286 | """Compute the DPO loss for a batch of policy and reference model log probabilities. 287 | 288 | Args: 289 | policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,) 290 | policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,) 291 | reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,) 292 | reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,) 293 | beta: Temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5. We ignore the reference model as beta -> 0. 294 | reference_free: If True, we ignore the _provided_ reference model and implicitly use a reference model that assigns equal probability to all responses. 295 | 296 | Returns: 297 | A tuple of three tensors: (losses, chosen_rewards, rejected_rewards). 298 | The losses tensor contains the DPO loss for each example in the batch. 
299 | The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively. 300 | """ 301 | pi_logratios = policy_chosen_logps - policy_rejected_logps 302 | ref_logratios = reference_chosen_logps - reference_rejected_logps + len_penalty 303 | 304 | if reference_free: 305 | ref_logratios = 0 306 | 307 | if self.loss_type == "sigmoid": 308 | logits = pi_logratios - ref_logratios 309 | losses = -F.logsigmoid(self.beta * logits) 310 | elif self.loss_type == "hinge": 311 | logits = pi_logratios - ref_logratios 312 | losses = torch.relu(1 - self.beta * logits) 313 | elif self.loss_type == "cross_entropy": 314 | logits = policy_chosen_logps - reference_chosen_logps 315 | losses = -F.logsigmoid(self.beta * logits) 316 | elif self.loss_type == "raft": 317 | losses = -policy_chosen_logps # F.logsigmoid(self.beta * logits) 318 | elif self.loss_type == "ipo": 319 | logits = pi_logratios - ref_logratios 320 | # eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper. 321 | losses = (logits - 1 / (2 * self.beta)) ** 2 322 | elif self.loss_type == "kl": 323 | logits = pi_logratios - ref_logratios 324 | p = F.sigmoid(self.beta * logits) 325 | p = torch.minimum(p, torch.ones_like(p) * 0.999) 326 | p_gt = torch.exp(margin) / (1 + torch.exp(margin) + 1e-3) 327 | losses = p * (torch.log(p) - torch.log(p_gt)) + (1 - p) * (torch.log(1 - p) - torch.log(1 - p_gt)) 328 | elif self.loss_type == "tv": 329 | logits = pi_logratios - ref_logratios 330 | p = F.sigmoid(self.beta * logits) 331 | p_gt = torch.exp(margin) / (1 + torch.exp(margin)) 332 | losses = torch.abs(p - p_gt) 333 | elif self.loss_type == "hellinger": 334 | logits = pi_logratios - ref_logratios 335 | p = F.sigmoid(self.beta * logits) 336 | p = torch.minimum(p, torch.ones_like(p) * 0.999) 337 | p_gt = torch.exp(margin) / (1 + torch.exp(margin)) 338 | losses = 0.5 * ((p**0.5 - p_gt**0.5) ** 2 + ((1 - p) ** 0.5 - (1 - p_gt) ** 0.5) ** 2) 339 | elif self.loss_type == "rev_kl": 340 | logits = pi_logratios - ref_logratios 341 | logp = F.logsigmoid(self.beta * logits) 342 | logp_neg = F.logsigmoid(-self.beta * logits) 343 | p_gt = F.sigmoid(margin) 344 | losses = -p_gt * (logp) - (1 - p_gt) * logp_neg 345 | else: 346 | raise ValueError(f"Unknown loss type: {self.loss_type}.") 347 | 348 | chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach() 349 | rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach() 350 | 351 | return losses, chosen_rewards, rejected_rewards 352 | 353 | def get_batch_loss_metrics( 354 | self, 355 | model, 356 | batch: Dict[str, Union[List, torch.LongTensor]], 357 | train_eval: Literal["train", "eval"] = "train", 358 | ): 359 | return self.get_batch_metrics(model, batch, train_eval) 360 | 361 | def get_batch_metrics( 362 | self, 363 | model, 364 | batch: Dict[str, Union[List, torch.LongTensor]], 365 | train_eval: Literal["train", "eval"] = "train", 366 | ): 367 | """Compute the DPO loss and other metrics for the given batch of inputs for train or test.""" 368 | metrics = {} 369 | ( 370 | policy_chosen_logps, 371 | policy_rejected_logps, 372 | policy_chosen_logits, 373 | policy_rejected_logits, 374 | ) = self.concatenated_forward(model, batch)[:4] 375 | with torch.no_grad(): 376 | if self.ref_model is None: 377 | with self.accelerator.unwrap_model(self.model).disable_adapter(): 378 | ( 379 | reference_chosen_logps, 380 | reference_rejected_logps, 381 | _, 382 | _, 383 | _, 
384 | ) = self.concatenated_forward(self.model, batch) 385 | else: 386 | ( 387 | reference_chosen_logps, 388 | reference_rejected_logps, 389 | _, 390 | _, 391 | _, 392 | ) = self.concatenated_forward(self.ref_model, batch) 393 | if self.len_penalty > 0: 394 | chosen_len = batch["chosen_input_ids"].shape[1] * self.len_penalty 395 | rejected_len = batch["rejected_input_ids"].shape[1] * self.len_penalty 396 | len_penalty = chosen_len - rejected_len 397 | else: 398 | chosen_len = 1 399 | rejected_len = 1 400 | len_penalty = 0 401 | 402 | # margin = torch.tensor(batch["margin"], dtype=policy_chosen_logps.dtype).to(self.accelerator.device) 403 | margin = None 404 | losses, chosen_rewards, rejected_rewards = self.dpo_loss( 405 | policy_chosen_logps, 406 | policy_rejected_logps, 407 | reference_chosen_logps, 408 | reference_rejected_logps, 409 | margin=margin, 410 | len_penalty=len_penalty, 411 | ) 412 | reward_accuracies = (chosen_rewards > rejected_rewards).float() 413 | 414 | prefix = "eval_" if train_eval == "eval" else "" 415 | metrics[f"{prefix}rewards/chosen"] = chosen_rewards.cpu().mean() 416 | metrics[f"{prefix}rewards/rejected"] = rejected_rewards.cpu().mean() 417 | metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.cpu().mean() 418 | metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).cpu().mean() 419 | metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().cpu().mean() 420 | metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().cpu().mean() 421 | metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().cpu().mean() 422 | metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().cpu().mean() 423 | 424 | return losses.mean(), metrics --------------------------------------------------------------------------------
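As a quick sanity check of the arithmetic in `PreferenceTrainer.dpo_loss` above, the standalone sketch below reproduces the `sigmoid` branch on made-up log-probabilities. The tensor values are invented purely for illustration, and nothing here depends on the trainer or collator classes.

```python
import torch
import torch.nn.functional as F

beta = 0.1  # same default as ScriptArguments.beta

# Made-up per-example log-probabilities (batch of 2), just to exercise the formula.
policy_chosen_logps = torch.tensor([-12.0, -9.5])
policy_rejected_logps = torch.tensor([-14.0, -9.0])
reference_chosen_logps = torch.tensor([-12.5, -9.6])
reference_rejected_logps = torch.tensor([-13.5, -9.1])

pi_logratios = policy_chosen_logps - policy_rejected_logps
ref_logratios = reference_chosen_logps - reference_rejected_logps
logits = pi_logratios - ref_logratios             # implicit reward margin before scaling
losses = -F.logsigmoid(beta * logits)             # "sigmoid" (standard DPO) loss, per example

chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps)
rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps)
print(losses.mean().item(), (chosen_rewards > rejected_rewards).float().mean().item())
```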