├── data
│   ├── sft_train.jsonl
│   ├── grpo_train.jsonl
│   └── dpo_train.jsonl
├── configs
│   ├── accelerate_deepspeed.yaml
│   ├── accelerate_fsdp.yaml
│   └── ds_zero3.json
├── requirements.txt
├── src
│   ├── rewarding
│   │   └── rules.py
│   ├── utils.py
│   └── algos
│       ├── dpo_runner.py
│       ├── sft_runner.py
│       ├── grpo_runner.py
│       └── dapo_runner.py
├── README.md
└── train.py
/data/sft_train.jsonl:
--------------------------------------------------------------------------------
{"prompt":"写个祝福语。","response":"祝你事事顺心,天天开心!"}
{"prompt":"解释为什么星星会发光(简短)。","response":"恒星通过核聚变释放能量,产生光与热。"}
--------------------------------------------------------------------------------
/data/grpo_train.jsonl:
--------------------------------------------------------------------------------
{"prompt":"Q: 2+3 等于几?请在末尾输出 final_answer: ","ref_answer":"5"}
{"prompt":"Q: 12*7 ?请逐步思考并输出 final_answer: ","ref_answer":"84"}
--------------------------------------------------------------------------------
/data/dpo_train.jsonl:
--------------------------------------------------------------------------------
{"prompt":"说个笑话","chosen":"今天的太阳打卡迟到了,所以天空多睡了一会儿。","rejected":"笑话。"}
{"prompt":"给面试官写一封感谢信","chosen":"感谢您抽空面试我,收获颇丰。期待后续交流。","rejected":"谢谢。"}
--------------------------------------------------------------------------------
/configs/accelerate_deepspeed.yaml:
--------------------------------------------------------------------------------
compute_environment: LOCAL_MACHINE
distributed_type: DEEPSPEED
deepspeed_config:
  deepspeed_config_file: configs/ds_zero3.json
mixed_precision: bf16
gpu_ids: all
num_processes: null
dynamo_backend: NO
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
transformers>=4.44.0
accelerate>=0.33.0
trl>=0.9.6
peft>=0.12.0
datasets>=2.18.0
bitsandbytes; platform_system != "Windows"
einops
wandb
sympy
numpy
tensorboard
torch # install your CUDA build separately
--------------------------------------------------------------------------------
/configs/accelerate_fsdp.yaml:
--------------------------------------------------------------------------------
compute_environment: LOCAL_MACHINE
distributed_type: FSDP
mixed_precision: bf16
gpu_ids: all
num_processes: null
fsdp_config:
  sharding_strategy: FULL_SHARD
  offload_params: false
  auto_wrap_policy: TRANSFORMER_BASED_WRAP
  activation_checkpointing: true
  sync_module_states: true
  state_dict_type: FULL_STATE_DICT
  limit_all_gathers: true
  use_orig_params: false
  forward_prefetch: true
  backward_prefetch: BACKWARD_PRE
  param_init_fn: true
  cpu_checkpointing: false
  pure_bf16: false
  xla: false
dynamo_backend: NO
--------------------------------------------------------------------------------
/configs/ds_zero3.json:
--------------------------------------------------------------------------------
{
  "train_batch_size": "auto",
  "train_micro_batch_size_per_gpu": "auto",
  "gradient_accumulation_steps": "auto",
  "bf16": {
    "enabled": true
  },
"zero_optimization": { 9 | "stage": 3, 10 | "overlap_comm": true, 11 | "contiguous_gradients": true, 12 | "reduce_bucket_size": 500000000, 13 | "stage3_prefetch_bucket_size": 500000000, 14 | "stage3_param_persistence_threshold": 1000000, 15 | "offload_param": { 16 | "device": "none" 17 | }, 18 | "offload_optimizer": { 19 | "device": "none" 20 | } 21 | }, 22 | "activation_checkpointing": { 23 | "partition_activations": true, 24 | "cpu_checkpointing": false 25 | } 26 | } -------------------------------------------------------------------------------- /src/rewarding/rules.py: -------------------------------------------------------------------------------- 1 | 2 | import re 3 | from typing import Any, Dict, List, Optional 4 | 5 | FINAL_ANS_RE = re.compile(r"final_answer\s*:\s*([-\d\.]+)", re.IGNORECASE) 6 | 7 | def extract_final_answer(text: str) -> Optional[str]: 8 | m = FINAL_ANS_RE.search(text) 9 | return m.group(1) if m else None 10 | 11 | def compute_rule_reward(prompt: str, output: str, ref_answer: Optional[str]) -> float: 12 | score = 0.0 13 | if output and "final_answer" in output.lower(): 14 | score += 0.2 15 | L = len((output or "").split()) 16 | if 10 <= L <= 800: 17 | score += 0.1 18 | else: 19 | score -= 0.05 20 | if ref_answer is not None: 21 | pred = extract_final_answer(output or "") 22 | if pred is not None and pred.strip() == str(ref_answer).strip(): 23 | score += 1.0 24 | if any(k in (output or "").lower() for k in ["因此", "所以", "we have", "推理", "hence", "because"]): 25 | score += 0.1 26 | return max(0.0, min(1.0, score)) 27 | 28 | def reward_fn(samples: List[Dict[str, Any]]) -> List[float]: 29 | rewards = [] 30 | for s in samples: 31 | prompt = s.get("prompt", "") 32 | output = s.get("output", "") 33 | ref = s.get("meta", {}).get("ref_answer", None) 34 | r = compute_rule_reward(prompt, output, ref) 35 | rewards.append(r) 36 | return rewards 37 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig 4 | from peft import LoraConfig, get_peft_model 5 | 6 | def build_model_and_tokenizer(model_name: str, qlora: bool = True, use_lora: bool = True, 7 | lora_r: int = 16, lora_alpha: int = 32, lora_dropout: float = 0.05, 8 | target_modules=None): 9 | bnb_cfg = None 10 | if qlora: 11 | bnb_cfg = BitsAndBytesConfig( 12 | load_in_4bit=True, bnb_4bit_quant_type="nf4", 13 | bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True, 14 | ) 15 | 16 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True) 17 | if tokenizer.pad_token is None: 18 | tokenizer.pad_token = tokenizer.eos_token 19 | 20 | model = AutoModelForCausalLM.from_pretrained( 21 | model_name, 22 | torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32, 23 | device_map="auto", 24 | quantization_config=bnb_cfg, 25 | ) 26 | 27 | if use_lora: 28 | if target_modules is None: 29 | target_modules = ["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"] 30 | lora = LoraConfig( 31 | r=lora_r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, 32 | bias="none", task_type="CAUSAL_LM", target_modules=target_modules, 33 | ) 34 | model = get_peft_model(model, lora) 35 | return model, tokenizer 36 | -------------------------------------------------------------------------------- /src/algos/dpo_runner.py: 
-------------------------------------------------------------------------------- 1 | 2 | from dataclasses import dataclass 3 | from datasets import Dataset 4 | from trl import DPOTrainer, DPOConfig 5 | from ..utils import build_model_and_tokenizer 6 | 7 | @dataclass 8 | class DPOArgs: 9 | model_name: str 10 | dataset_path: str 11 | output_dir: str = "outputs/dpo" 12 | dpo_beta: float = 0.1 13 | learning_rate: float = 5e-6 14 | per_device_train_batch_size: int = 1 15 | gradient_accumulation_steps: int = 4 16 | max_steps: int = 1000 17 | logging_steps: int = 10 18 | save_steps: int = 200 19 | warmup_ratio: float = 0.05 20 | weight_decay: float = 0.0 21 | 22 | def run_dpo(a: DPOArgs): 23 | import json 24 | rows = [] 25 | with open(a.dataset_path, "r", encoding="utf-8") as f: 26 | for line in f: 27 | j = json.loads(line) 28 | rows.append({"prompt": j["prompt"], "chosen": j["chosen"], "rejected": j["rejected"]}) 29 | ds = Dataset.from_list(rows) 30 | 31 | model, tokenizer = build_model_and_tokenizer(a.model_name, qlora=True, use_lora=True) 32 | 33 | cfg = DPOConfig( 34 | output_dir=a.output_dir, 35 | per_device_train_batch_size=a.per_device_train_batch_size, 36 | gradient_accumulation_steps=a.gradient_accumulation_steps, 37 | learning_rate=a.learning_rate, 38 | warmup_ratio=a.warmup_ratio, 39 | weight_decay=a.weight_decay, 40 | bf16=True, 41 | logging_steps=a.logging_steps, 42 | save_steps=a.save_steps, 43 | max_steps=a.max_steps, 44 | beta=a.dpo_beta, 45 | report_to=["tensorboard"], 46 | logging_dir=f"{a.output_dir}/tb", 47 | ) 48 | 49 | trainer = DPOTrainer(model=model, processing_class=tokenizer, args=cfg, train_dataset=ds) 50 | trainer.train() 51 | trainer.save_model(a.output_dir) 52 | tokenizer.save_pretrained(a.output_dir) 53 | -------------------------------------------------------------------------------- /src/algos/sft_runner.py: -------------------------------------------------------------------------------- 1 | 2 | from dataclasses import dataclass 3 | from typing import Dict 4 | from datasets import Dataset 5 | from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling 6 | from ..utils import build_model_and_tokenizer 7 | 8 | @dataclass 9 | class SFTArgs: 10 | model_name: str 11 | dataset_path: str 12 | output_dir: str = "outputs/sft" 13 | sft_max_length: int = 2048 14 | learning_rate: float = 2e-5 15 | num_train_epochs: int = 1 16 | per_device_train_batch_size: int = 1 17 | gradient_accumulation_steps: int = 4 18 | warmup_ratio: float = 0.05 19 | weight_decay: float = 0.0 20 | logging_steps: int = 10 21 | save_steps: int = 200 22 | 23 | def _load_sft(path: str) -> Dataset: 24 | import json 25 | rows = [] 26 | with open(path, "r", encoding="utf-8") as f: 27 | for line in f: 28 | j = json.loads(line) 29 | rows.append(j) 30 | return Dataset.from_list(rows) 31 | 32 | def run_sft(a: SFTArgs): 33 | model, tokenizer = build_model_and_tokenizer(a.model_name, qlora=True, use_lora=True) 34 | ds = _load_sft(a.dataset_path) 35 | 36 | EOS = tokenizer.eos_token or "" 37 | 38 | def to_example(ex: Dict): 39 | text = f"{ex['prompt']}{EOS}{ex['response']}{EOS}" 40 | tok = tokenizer(text, truncation=True, max_length=a.sft_max_length) 41 | input_ids = tok["input_ids"] 42 | # mask loss on prompt 43 | prompt_ids = tokenizer(f"{ex['prompt']}{EOS}", truncation=True, max_length=a.sft_max_length)["input_ids"] 44 | labels = [-100]*len(prompt_ids) + input_ids[len(prompt_ids):] 45 | labels = labels[:len(input_ids)] 46 | return {"input_ids": input_ids, "labels": labels} 47 | 48 | ds = 
ds.map(to_example, remove_columns=ds.column_names) 49 | ds.set_format(type="torch") 50 | 51 | collator = DataCollatorForLanguageModeling(tokenizer, mlm=False) 52 | args = TrainingArguments( 53 | output_dir=a.output_dir, 54 | per_device_train_batch_size=a.per_device_train_batch_size, 55 | gradient_accumulation_steps=a.gradient_accumulation_steps, 56 | num_train_epochs=a.num_train_epochs, 57 | learning_rate=a.learning_rate, 58 | warmup_ratio=a.warmup_ratio, 59 | weight_decay=a.weight_decay, 60 | logging_steps=a.logging_steps, 61 | save_steps=a.save_steps, 62 | bf16=True, 63 | report_to=["tensorboard"], 64 | logging_dir=f"{a.output_dir}/tb", 65 | ) 66 | trainer = Trainer(model=model, args=args, train_dataset=ds, data_collator=collator, tokenizer=tokenizer) 67 | trainer.train() 68 | trainer.save_model(a.output_dir) 69 | tokenizer.save_pretrained(a.output_dir) 70 | -------------------------------------------------------------------------------- /src/algos/grpo_runner.py: -------------------------------------------------------------------------------- 1 | 2 | from dataclasses import dataclass 3 | from datasets import Dataset 4 | from trl import GRPOConfig, GRPOTrainer 5 | from ..utils import build_model_and_tokenizer 6 | from ..rewarding.rules import reward_fn 7 | 8 | @dataclass 9 | class GRPOArgs: 10 | model_name: str 11 | dataset_path: str 12 | output_dir: str = "outputs/grpo" 13 | max_prompt_len: int = 1024 14 | max_new_tokens: int = 512 15 | temperature: float = 0.7 16 | top_p: float = 0.9 17 | top_k: int = 50 18 | num_generations: int = 4 19 | learning_rate: float = 5e-6 20 | warmup_ratio: float = 0.05 21 | weight_decay: float = 0.0 22 | gradient_accumulation_steps: int = 4 23 | per_device_train_batch_size: int = 1 24 | max_steps: int = 1000 25 | logging_steps: int = 10 26 | save_steps: int = 200 27 | eval_steps: int = 0 28 | kl_coeff: float = 0.02 29 | target_kl: float = 1.0 30 | use_lora: bool = True 31 | qlora: bool = True 32 | use_vllm: bool = False 33 | vllm_gpu_memory_utilization: float = 0.90 34 | 35 | def run_grpo(a: GRPOArgs): 36 | import json 37 | rows = [] 38 | with open(a.dataset_path, "r", encoding="utf-8") as f: 39 | for line in f: 40 | j = json.loads(line) 41 | rows.append({"prompt": j["prompt"], "meta": {"ref_answer": j.get("ref_answer")}}) 42 | ds = Dataset.from_list(rows) 43 | 44 | model, tokenizer = build_model_and_tokenizer(a.model_name, qlora=a.qlora, use_lora=a.use_lora) 45 | 46 | cfg = GRPOConfig( 47 | output_dir=a.output_dir, 48 | per_device_train_batch_size=a.per_device_train_batch_size, 49 | gradient_accumulation_steps=a.gradient_accumulation_steps, 50 | learning_rate=a.learning_rate, 51 | weight_decay=a.weight_decay, 52 | warmup_ratio=a.warmup_ratio, 53 | logging_steps=a.logging_steps, 54 | save_steps=a.save_steps, 55 | max_steps=a.max_steps, 56 | dataloader_num_workers=2, 57 | bf16=True, 58 | remove_unused_columns=False, 59 | max_prompt_length=a.max_prompt_len, 60 | max_new_tokens=a.max_new_tokens, 61 | temperature=a.temperature, 62 | top_p=a.top_p, 63 | top_k=a.top_k, 64 | num_generations=a.num_generations, 65 | kl_coef=a.kl_coeff, 66 | target_kl=a.target_kl, 67 | use_vllm=a.use_vllm, 68 | vllm_gpu_memory_utilization=a.vllm_gpu_memory_utilization, 69 | report_to=["tensorboard"], 70 | logging_dir=f"{a.output_dir}/tb", 71 | ) 72 | 73 | trainer = GRPOTrainer(model=model, processing_class=tokenizer, reward_func=reward_fn, args=cfg, train_dataset=ds) 74 | trainer.train() 75 | trainer.save_model(a.output_dir) 76 | tokenizer.save_pretrained(a.output_dir) 77 | 
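# --- Editor's note (not part of the original grpo_runner.py) --------------------
# GRPOConfig/GRPOTrainer keyword names vary across TRL releases: recent versions
# expect `reward_funcs=[...]`, `max_completion_length`, and `beta` (the KL weight)
# rather than `reward_func`, `max_new_tokens`, and `kl_coef`/`target_kl` as used
# above, so verify the names against the TRL version you install. If your TRL
# calls reward functions as `fn(prompts, completions, **kwargs)` with extra
# dataset columns forwarded as keyword lists, a thin adapter over the rule-based
# reward in src/rewarding/rules.py could look like this sketch (hypothetical,
# kept commented out):
#
#   def trl_reward_adapter(prompts, completions, meta=None, **kwargs):
#       metas = meta if meta is not None else [{} for _ in prompts]
#       samples = [{"prompt": p, "output": c, "meta": m}
#                  for p, c, m in zip(prompts, completions, metas)]
#       return reward_fn(samples)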
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# RL4LM Suite — SFT / DPO / GRPO / DAPO (FSDP & DeepSpeed, TensorBoard Logs)

A unified LLM training repository that supports four training modes:

- **SFT** (supervised fine-tuning)
- **DPO** (preference pairs, offline)
- **GRPO** (online RL, based on TRL's GRPOTrainer)
- **DAPO** (online RL, **no preference pairs**; an improvement on the GRPO idea: **asymmetric clipping + dynamic sampling**, with an optional lightweight KL guardrail)

TensorBoard logging is **enabled by default**; every training run writes its event files under `--output_dir/tb`.

---

## Installation

```bash
pip install -r requirements.txt
# (optional) configure accelerate interactively
accelerate config
```

> Install the `torch` CUDA build that matches your GPU.

---

## Data formats

- **SFT** (`data/sft_train.jsonl`)
  ```json
  {"prompt": "写个祝福语。", "response": "祝你事事顺心,天天开心!"}
  ```
- **DPO** (`data/dpo_train.jsonl`)
  ```json
  {"prompt":"...","chosen":"...","rejected":"..."}
  ```
- **GRPO / DAPO** (`data/grpo_train.jsonl`)
  ```json
  {"prompt":"Q: 2+3 等于?请在末尾输出 final_answer: ", "ref_answer":"5"}
  ```

`ref_answer` is optional and only used when your reward function needs it. The reward function lives in `src/rewarding/rules.py` and can be customized.

---

## Launching (FSDP)

```bash
# SFT
accelerate launch --config_file configs/accelerate_fsdp.yaml train.py --algo sft --model_name Qwen2.5-7B-Instruct --dataset_path data/sft_train.jsonl --output_dir outputs/sft_fsdp

# DPO
accelerate launch --config_file configs/accelerate_fsdp.yaml train.py --algo dpo --model_name Qwen2.5-7B-Instruct --dataset_path data/dpo_train.jsonl --output_dir outputs/dpo_fsdp

# GRPO
accelerate launch --config_file configs/accelerate_fsdp.yaml train.py --algo grpo --model_name Qwen2.5-7B-Instruct --dataset_path data/grpo_train.jsonl --output_dir outputs/grpo_fsdp

# DAPO (online, no preference pairs)
accelerate launch --config_file configs/accelerate_fsdp.yaml train.py --algo dapo --model_name Qwen2.5-7B-Instruct --dataset_path data/grpo_train.jsonl --output_dir outputs/dapo_fsdp
```

## Launching (DeepSpeed ZeRO-3)

```bash
accelerate launch --config_file configs/accelerate_deepspeed.yaml train.py --algo dpo --model_name Qwen2.5-7B-Instruct --dataset_path data/dpo_train.jsonl --output_dir outputs/dpo_ds
```

> For the other algorithms, just switch `--algo`.

---

## Viewing logs (TensorBoard)

All algorithms write to `--output_dir/tb/` by default. For example:

```bash
tensorboard --logdir outputs/sft_fsdp/tb   # SFT
tensorboard --logdir outputs/dpo_fsdp/tb   # DPO
tensorboard --logdir outputs/grpo_fsdp/tb  # GRPO
tensorboard --logdir outputs/dapo_fsdp/tb  # DAPO
```

Open the local address printed in the terminal to view the curves in your browser.

**Metric conventions**:
- SFT/DPO/GRPO (built on the HF/TRL Trainer): `train/loss`, `lr`, `epoch/step`, etc.
- DAPO (custom loop): `dapo/loss`, `dapo/reward_mean`, `dapo/ratio_mean`, `dapo/K`, `dapo/kl`, etc.

---

## Key arguments (unified entry point)

- Common: `--model_name --dataset_path --output_dir --seed --use_lora/--no_lora --qlora/--no_qlora`
- SFT: `--sft_max_length --sft_learning_rate --sft_num_train_epochs --sft_per_device_train_batch_size`
- DPO: `--dpo_beta --max_steps --learning_rate --per_device_train_batch_size --gradient_accumulation_steps`
- GRPO (online): `--num_generations --max_prompt_len --max_new_tokens --kl_coeff --target_kl --use_vllm`
- **DAPO (online)**:
  - Sampling: `--dapo_k_min --dapo_k_max --dapo_sched_interval`
  - Clipping: `--dapo_clip_low --dapo_clip_high` (asymmetric)
  - Regularization: `--dapo_len_norm --dapo_kl`
  - Generation: `--max_prompt_len --max_new_tokens --temperature --top_p --top_k`

---

## Structure

```
.
├── configs/
│   ├── accelerate_fsdp.yaml
│   ├── accelerate_deepspeed.yaml
│   └── ds_zero3.json
├── data/
│   ├── sft_train.jsonl
│   ├── dpo_train.jsonl
│   └── grpo_train.jsonl
├── src/
│   ├── utils.py
│   ├── rewarding/rules.py
│   └── algos/
│       ├── sft_runner.py
│       ├── dpo_runner.py
│       ├── grpo_runner.py
│       └── dapo_runner.py
├── train.py
└── requirements.txt
```

---

## Notes

- **DAPO** is an **online** method and does not need `chosen/rejected` pairs; it takes the same data input as TRL's GRPO but changes the **ratio clipping and sampling strategy**.
- To speed up GRPO sampling with **vLLM**, install it in your environment and add `--use_vllm`.
- Large models / long contexts: for full-parameter training at very large scale, prefer DeepSpeed ZeRO-3 and consider offloading.

Good luck!
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import argparse, random
import numpy as np
import torch
from src.algos.sft_runner import run_sft, SFTArgs
from src.algos.dpo_runner import run_dpo, DPOArgs
from src.algos.grpo_runner import run_grpo, GRPOArgs
from src.algos.dapo_runner import run_dapo, DAPOArgs

def set_seed(seed: int):
    random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
    if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--algo", choices=["sft","dpo","grpo","dapo"], required=True)
    parser.add_argument("--model_name", type=str, required=True)
    parser.add_argument("--dataset_path", type=str, required=True)
    parser.add_argument("--output_dir", type=str, default="outputs/run")
    parser.add_argument("--seed", type=int, default=42)
    # LoRA/QLoRA flags (on by default in the model builder; toggle with the switches below)
    parser.add_argument("--use_lora", action="store_true", default=True)
    parser.add_argument("--no_lora", dest="use_lora", action="store_false")
    parser.add_argument("--qlora", action="store_true", default=True)
    parser.add_argument("--no_qlora", dest="qlora", action="store_false")

    # SFT
    parser.add_argument("--sft_max_length", type=int, default=2048)
    parser.add_argument("--sft_learning_rate", type=float, default=2e-5)
    parser.add_argument("--sft_num_train_epochs", type=int, default=1)
    parser.add_argument("--sft_per_device_train_batch_size", type=int, default=1)

    # DPO/GRPO/DAPO shared
    parser.add_argument("--max_steps", type=int, default=1000)
    parser.add_argument("--per_device_train_batch_size", type=int, default=1)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    parser.add_argument("--logging_steps", type=int, default=10)
    parser.add_argument("--save_steps", type=int, default=200)
    parser.add_argument("--warmup_ratio", type=float, default=0.05)
    parser.add_argument("--weight_decay", type=float, default=0.0)
    parser.add_argument("--learning_rate", type=float, default=5e-6)

    # DPO
    parser.add_argument("--dpo_beta", type=float, default=0.1)

    # GRPO/DAPO generation
    parser.add_argument("--max_prompt_len", type=int, default=1024)
    parser.add_argument("--max_new_tokens", type=int, default=512)
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--top_p", type=float, default=0.9)
parser.add_argument("--top_k", type=int, default=50) 52 | parser.add_argument("--num_generations", type=int, default=4) 53 | parser.add_argument("--kl_coeff", type=float, default=0.02) 54 | parser.add_argument("--target_kl", type=float, default=1.0) 55 | parser.add_argument("--use_vllm", action="store_true", default=False) 56 | parser.add_argument("--vllm_gpu_memory_utilization", type=float, default=0.9) 57 | 58 | # DAPO extras 59 | parser.add_argument("--dapo_k_min", type=int, default=2) 60 | parser.add_argument("--dapo_k_max", type=int, default=6) 61 | parser.add_argument("--dapo_sched_interval", type=int, default=200) 62 | parser.add_argument("--dapo_clip_low", type=float, default=0.2) 63 | parser.add_argument("--dapo_clip_high", type=float, default=0.2) 64 | parser.add_argument("--dapo_len_norm", action="store_true", default=True) 65 | parser.add_argument("--dapo_kl", type=float, default=0.005) 66 | 67 | args = parser.parse_args() 68 | set_seed(args.seed) 69 | 70 | if args.algo == "sft": 71 | a = SFTArgs( 72 | model_name=args.model_name, 73 | dataset_path=args.dataset_path, 74 | output_dir=args.output_dir, 75 | sft_max_length=args.sft_max_length, 76 | learning_rate=args.sft_learning_rate, 77 | num_train_epochs=args.sft_num_train_epochs, 78 | per_device_train_batch_size=args.sft_per_device_train_batch_size, 79 | ) 80 | run_sft(a) 81 | 82 | elif args.algo == "dpo": 83 | a = DPOArgs( 84 | model_name=args.model_name, 85 | dataset_path=args.dataset_path, 86 | output_dir=args.output_dir, 87 | dpo_beta=args.dpo_beta, 88 | learning_rate=args.learning_rate, 89 | per_device_train_batch_size=args.per_device_train_batch_size, 90 | gradient_accumulation_steps=args.gradient_accumulation_steps, 91 | max_steps=args.max_steps, 92 | logging_steps=args.logging_steps, 93 | save_steps=args.save_steps, 94 | warmup_ratio=args.warmup_ratio, 95 | weight_decay=args.weight_decay, 96 | ) 97 | run_dpo(a) 98 | 99 | elif args.algo == "grpo": 100 | a = GRPOArgs( 101 | model_name=args.model_name, 102 | dataset_path=args.dataset_path, 103 | output_dir=args.output_dir, 104 | max_prompt_len=args.max_prompt_len, 105 | max_new_tokens=args.max_new_tokens, 106 | temperature=args.temperature, 107 | top_p=args.top_p, 108 | top_k=args.top_k, 109 | num_generations=args.num_generations, 110 | learning_rate=args.learning_rate, 111 | warmup_ratio=args.warmup_ratio, 112 | weight_decay=args.weight_decay, 113 | gradient_accumulation_steps=args.gradient_accumulation_steps, 114 | per_device_train_batch_size=args.per_device_train_batch_size, 115 | max_steps=args.max_steps, 116 | logging_steps=args.logging_steps, 117 | save_steps=args.save_steps, 118 | kl_coeff=args.kl_coeff, 119 | target_kl=args.target_kl, 120 | use_lora=args.use_lora, 121 | qlora=args.qlora, 122 | use_vllm=args.use_vllm, 123 | vllm_gpu_memory_utilization=args.vllm_gpu_memory_utilization, 124 | ) 125 | run_grpo(a) 126 | 127 | elif args.algo == "dapo": 128 | a = DAPOArgs( 129 | model_name=args.model_name, 130 | dataset_path=args.dataset_path, 131 | output_dir=args.output_dir, 132 | dapo_k_min=args.dapo_k_min, 133 | dapo_k_max=args.dapo_k_max, 134 | dapo_sched_interval=args.dapo_sched_interval, 135 | max_prompt_len=args.max_prompt_len, 136 | max_new_tokens=args.max_new_tokens, 137 | temperature=args.temperature, 138 | top_p=args.top_p, 139 | top_k=args.top_k, 140 | learning_rate=args.learning_rate, 141 | per_device_train_batch_size=args.per_device_train_batch_size, 142 | gradient_accumulation_steps=args.gradient_accumulation_steps, 143 | warmup_ratio=args.warmup_ratio, 144 
| max_steps=args.max_steps, 145 | logging_steps=args.logging_steps, 146 | save_steps=args.save_steps, 147 | weight_decay=args.weight_decay, 148 | dapo_clip_low=args.dapo_clip_low, 149 | dapo_clip_high=args.dapo_clip_high, 150 | dapo_len_norm=args.dapo_len_norm, 151 | dapo_kl=args.dapo_kl, 152 | use_lora=args.use_lora, 153 | qlora=args.qlora, 154 | ) 155 | run_dapo(a) 156 | 157 | if __name__ == "__main__": 158 | main() 159 | -------------------------------------------------------------------------------- /src/algos/dapo_runner.py: -------------------------------------------------------------------------------- 1 | 2 | from dataclasses import dataclass 3 | from typing import List, Dict 4 | import os, copy, torch 5 | from torch.utils.data import DataLoader 6 | from torch.utils.tensorboard import SummaryWriter 7 | from datasets import Dataset 8 | from transformers import get_scheduler 9 | from ..utils import build_model_and_tokenizer 10 | from ..rewarding.rules import reward_fn 11 | 12 | @dataclass 13 | class DAPOArgs: 14 | model_name: str 15 | dataset_path: str 16 | output_dir: str = "outputs/dapo" 17 | # generation 18 | dapo_k_min: int = 2 19 | dapo_k_max: int = 6 20 | dapo_sched_interval: int = 200 # steps to grow K 21 | max_prompt_len: int = 1024 22 | max_new_tokens: int = 256 23 | temperature: float = 0.8 24 | top_p: float = 0.9 25 | top_k: int = 50 26 | # optimization 27 | learning_rate: float = 5e-6 28 | per_device_train_batch_size: int = 1 29 | gradient_accumulation_steps: int = 4 30 | warmup_ratio: float = 0.05 31 | max_steps: int = 1000 32 | logging_steps: int = 10 33 | save_steps: int = 200 34 | weight_decay: float = 0.0 35 | # clipping & reg 36 | dapo_clip_low: float = 0.2 # for negative advantages: ratio >= 1 - clip_low 37 | dapo_clip_high: float = 0.2 # for positive advantages: ratio <= 1 + clip_high 38 | dapo_len_norm: bool = True 39 | dapo_kl: float = 0.005 # tiny KL to frozen ref (optional) 40 | # LoRA/QLoRA 41 | use_lora: bool = True 42 | qlora: bool = True 43 | 44 | def _load_prompts(path: str) -> Dataset: 45 | import json 46 | rows = [] 47 | with open(path, "r", encoding="utf-8") as f: 48 | for line in f: 49 | j = json.loads(line) 50 | rows.append({"prompt": j["prompt"], "meta": {"ref_answer": j.get("ref_answer")}}) 51 | return Dataset.from_list(rows) 52 | 53 | def _seq_logprob_and_len(model, tokenizer, prompt: str, resp: str, max_length: int): 54 | model.eval() 55 | device = next(model.parameters()).device 56 | with torch.no_grad(): 57 | p = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=max_length).to(device) 58 | r = tokenizer(resp, return_tensors="pt", truncation=True, max_length=max_length).to(device) 59 | input_ids = torch.cat([p.input_ids, r.input_ids[:,1:]], dim=1) 60 | attn = torch.ones_like(input_ids) 61 | out = model(input_ids=input_ids, attention_mask=attn) 62 | logits = out.logits[:, :-1, :] 63 | labels = input_ids[:, 1:] 64 | resp_start = p.input_ids.shape[1]-1 65 | resp_logits = logits[:, resp_start:, :] 66 | resp_labels = labels[:, resp_start:] 67 | logprobs = torch.log_softmax(resp_logits, dim=-1) 68 | token_lp = logprobs.gather(-1, resp_labels.unsqueeze(-1)).squeeze(-1) 69 | return token_lp.sum(dim=1), resp_labels.shape[1] 70 | 71 | def _generate_k(model, tokenizer, prompt: str, K: int, max_new_tokens: int, temperature: float, top_p: float, top_k: int) -> List[str]: 72 | device = next(model.parameters()).device 73 | p = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=4096).to(device) 74 | gens = [] 75 | for _ in 
range(K): 76 | out = model.generate( 77 | **p, 78 | max_new_tokens=max_new_tokens, 79 | do_sample=True, 80 | temperature=temperature, 81 | top_p=top_p, 82 | top_k=top_k, 83 | pad_token_id=tokenizer.eos_token_id, 84 | eos_token_id=tokenizer.eos_token_id, 85 | ) 86 | text = tokenizer.decode(out[0], skip_special_tokens=True) 87 | base = tokenizer.decode(p.input_ids[0], skip_special_tokens=True) 88 | gens.append(text[len(base):].strip()) 89 | return gens 90 | 91 | def run_dapo(a: DAPOArgs): 92 | device = "cuda" if torch.cuda.is_available() else "cpu" 93 | os.makedirs(a.output_dir, exist_ok=True) 94 | writer = SummaryWriter(log_dir=os.path.join(a.output_dir, "tb")) 95 | 96 | policy, tokenizer = build_model_and_tokenizer(a.model_name, qlora=a.qlora, use_lora=a.use_lora) 97 | policy.to(device) 98 | behavior = copy.deepcopy(policy).eval() 99 | for p in behavior.parameters(): p.requires_grad_(False) 100 | ref = copy.deepcopy(policy).eval() 101 | for p in ref.parameters(): p.requires_grad_(False) 102 | 103 | ds = _load_prompts(a.dataset_path) 104 | dl = DataLoader(ds, batch_size=a.per_device_train_batch_size, shuffle=True) 105 | 106 | optim = torch.optim.AdamW([p for p in policy.parameters() if p.requires_grad], lr=a.learning_rate, weight_decay=a.weight_decay) 107 | total_updates = a.max_steps 108 | warmup = int(total_updates * a.warmup_ratio) 109 | sched = get_scheduler("cosine", optimizer=optim, num_warmup_steps=warmup, num_training_steps=total_updates) 110 | 111 | step = 0 112 | policy.train() 113 | while step < a.max_steps: 114 | for batch in dl: 115 | if step >= a.max_steps: break 116 | loss_acc = 0.0 117 | reward_mean = 0.0 118 | ratio_mean = 0.0 119 | K_now = min(a.dapo_k_max, a.dapo_k_min + step // max(1, a.dapo_sched_interval)) 120 | micro = max(1, a.gradient_accumulation_steps) 121 | 122 | for m in range(micro): 123 | if m >= len(batch["prompt"]): break 124 | prompt = batch["prompt"][m] 125 | meta = batch.get("meta", [{}])[m] 126 | 127 | candidates = _generate_k(policy, tokenizer, prompt, K_now, a.max_new_tokens, a.temperature, a.top_p, a.top_k) 128 | samples = [{"prompt": prompt, "output": c, "meta": meta} for c in candidates] 129 | rewards = reward_fn(samples) 130 | r = torch.tensor(rewards, dtype=torch.float32, device=device) 131 | adv = r - r.mean() 132 | reward_mean += r.mean().item() / micro 133 | 134 | lp_new_list, lp_old_list, lp_ref_list, lens = [], [], [], [] 135 | for c in candidates: 136 | lp_new, L = _seq_logprob_and_len(policy, tokenizer, prompt, c, a.max_prompt_len) 137 | lp_old, _ = _seq_logprob_and_len(behavior, tokenizer, prompt, c, a.max_prompt_len) 138 | lp_ref, _ = _seq_logprob_and_len(ref, tokenizer, prompt, c, a.max_prompt_len) 139 | lp_new_list.append(lp_new.squeeze(0)); lp_old_list.append(lp_old.squeeze(0)); lp_ref_list.append(lp_ref.squeeze(0)); lens.append(L) 140 | lp_new = torch.stack(lp_new_list) 141 | lp_old = torch.stack(lp_old_list) 142 | lp_ref = torch.stack(lp_ref_list) 143 | lens_t = torch.tensor(lens, dtype=torch.float32, device=device) 144 | 145 | ratio = torch.exp(lp_new - lp_old) 146 | ratio_mean += ratio.mean().item() / micro 147 | 148 | pos = adv > 0 149 | neg = ~pos 150 | ratio_clipped = ratio.clone() 151 | if a.dapo_clip_high > 0: 152 | ratio_clipped[pos] = torch.minimum(ratio[pos], torch.tensor(1.0 + a.dapo_clip_high, device=device)) 153 | if a.dapo_clip_low > 0: 154 | ratio_clipped[neg] = torch.maximum(ratio[neg], torch.tensor(1.0 - a.dapo_clip_low, device=device)) 155 | 156 | weight = adv / (lens_t if a.dapo_len_norm else 1.0) 157 | obj = 
ratio_clipped * weight 158 | loss = -obj.mean() 159 | 160 | if a.dapo_kl > 0: 161 | kl = (lp_new - lp_ref).mean() 162 | loss = loss + a.dapo_kl * kl 163 | writer.add_scalar("dapo/kl", kl.item(), step) 164 | 165 | loss = loss / micro 166 | loss.backward() 167 | loss_acc += float(loss.item()) 168 | 169 | optim.step(); sched.step(); optim.zero_grad(set_to_none=True) 170 | step += 1 171 | 172 | if step % a.logging_steps == 0: 173 | writer.add_scalar("dapo/loss", loss_acc, step) 174 | writer.add_scalar("dapo/reward_mean", reward_mean, step) 175 | writer.add_scalar("dapo/ratio_mean", ratio_mean, step) 176 | writer.add_scalar("dapo/K", K_now, step) 177 | 178 | if step % a.save_steps == 0: 179 | policy.save_pretrained(a.output_dir) 180 | 181 | if step % max(50, a.dapo_sched_interval) == 0: 182 | behavior.load_state_dict(policy.state_dict()) 183 | 184 | if step >= a.max_steps: break 185 | 186 | policy.save_pretrained(a.output_dir) 187 | writer.close() 188 | --------------------------------------------------------------------------------
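The asymmetric clipping inside `run_dapo` is easier to follow in isolation. The sketch below is not part of the repository: the tensors are toy values, and the `0.28` upper clip is only an illustration of a "clip-higher" setting (the CLI defaults are a symmetric `0.2`/`0.2`). It reproduces the rule used above: cap the importance ratio at `1 + clip_high` only where the advantage is positive, floor it at `1 - clip_low` only where it is negative, then descend on the negated mean of `clipped_ratio * advantage`.

```python
import torch

# Toy values for one prompt with K = 4 sampled completions.
clip_low, clip_high = 0.2, 0.28                 # mirrors --dapo_clip_low / --dapo_clip_high
ratio = torch.tensor([0.5, 0.9, 1.1, 1.6])      # exp(logp_policy - logp_behavior) per completion
adv = torch.tensor([-1.0, 0.5, -0.3, 0.8])      # group-centred rewards: r - r.mean()

pos = adv > 0
neg = ~pos
clipped = ratio.clone()
clipped[pos] = torch.minimum(ratio[pos], torch.tensor(1.0 + clip_high))  # cap upside for positive advantages
clipped[neg] = torch.maximum(ratio[neg], torch.tensor(1.0 - clip_low))   # floor downside for negative advantages

loss = -(clipped * adv).mean()  # run_dapo also divides adv by response length when --dapo_len_norm is set
print(clipped.tolist(), float(loss))
```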