├── .env.example ├── .gitignore ├── README.md ├── arguments.py ├── configs └── deepspeed │ ├── ds_z0_config.json │ ├── ds_z2_config.json │ ├── ds_z2_offload_config.json │ ├── ds_z3_config.json │ └── ds_z3_offload_config.json ├── evaluation.py ├── exps └── exp-qwen0.5b-r1-zero-example │ ├── exp_config.yaml │ └── run_train.sh ├── infer_workers.py ├── pyproject.toml ├── requirements.txt ├── res └── imgs │ ├── Qwen2.5-0.5B_R1-Zero_Reproduction.png │ ├── sample_generation_using_ray_server_and_vllm.png │ ├── training_figures.png │ └── weight_synchronization.png ├── rewards.py ├── run_eval.sh ├── run_eval_multi.sh ├── train.py └── utils.py /.env.example: -------------------------------------------------------------------------------- 1 | WANDB_PROJECT="" 2 | WANDB_ENTITY="" 3 | WANDB_API_KEY="" 4 | 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # VSCode configs 2 | .vscode/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # db 62 | *.log 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | .pybuilder/ 71 | target/ 72 | 73 | # Jupyter Notebook 74 | .ipynb_checkpoints 75 | 76 | # IPython 77 | profile_default/ 78 | ipython_config.py 79 | 80 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 81 | __pypackages__/ 82 | 83 | # Celery stuff 84 | celerybeat-schedule 85 | celerybeat.pid 86 | 87 | # SageMath parsed files 88 | *.sage.py 89 | 90 | # Environments 91 | .env 92 | .venv 93 | env/ 94 | venv/ 95 | ENV/ 96 | env.bak/ 97 | venv.bak/ 98 | 99 | # Spyder project settings 100 | .spyderproject 101 | .spyproject 102 | 103 | # Rope project settings 104 | .ropeproject 105 | 106 | # mkdocs documentation 107 | /site 108 | 109 | # mypy 110 | .mypy_cache/ 111 | .dmypy.json 112 | dmypy.json 113 | 114 | # Pyre type checker 115 | .pyre/ 116 | 117 | # pytype static type analyzer 118 | .pytype/ 119 | 120 | # Cython debug symbols 121 | cython_debug/ 122 | 123 | # PyCharm 124 | .idea/ 125 | 126 | data/ 127 | ckpts/ 128 | ckpt/ 129 | runs/ 130 | *.wav 131 | *.pt 132 | .aider* 133 | res/ 134 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Simple-R1 2 | 3 | A **hackable, simple, and efficient** DeepSeek R1-Zero Reproduction with high speed weight synchronization in a multinode environment. 
4 | 
5 | ## Features
6 | 
7 | - **High-Speed Weight Synchronization between Training Process and Inference Workers**
8 | Unlike traditional RLHF frameworks (e.g., Open-R1), which combine training and inference within a single process and incur high memory overhead, **Simple-R1 decouples inference from training**.
9 | We achieve **extremely fast weight updates** for vLLM-based inference workers via **direct NCCL communication** among distributed nodes.
10 | - **High-Performance Inference with Ray Serve**
11 | Ray Serve is a high-performance, scalable serving framework that provides load balancing for inference workers.
12 | We use Ray Serve to efficiently sample generated text from vLLM.
13 | - **Hackable**: No Hugging Face Trainer. You can fully customize your training loop.
14 | - **Simple**: Minimal abstraction, minimal files, minimal dependencies.
15 | 
16 | ## Architecture
17 | ![sample_generation_using_ray_server_and_vllm](res/imgs/sample_generation_using_ray_server_and_vllm.png)
18 | ![weight_synchronization](res/imgs/weight_synchronization.png)
19 | 
20 | 
21 | ## TODO
22 | - [X] Implement a basic training loop to reproduce DeepSeek R1-Zero.
23 | - [X] Implement high-speed weight synchronization using NCCL between training and inference nodes.
24 | - [ ] Improve code readability, enhance documentation, and refactor the code.
25 | - [ ] Test distributed training with multiple training and inference nodes.
26 | - [ ] Test and support large models.
27 | 
28 | 
29 | 
30 | ## R1 Zero Simple Reproduction
31 | 
32 | ### GSM8K (Qwen2.5-0.5B)
33 | 
34 | ![GSM8K Training Step vs Accuracy Plot](res/imgs/Qwen2.5-0.5B_R1-Zero_Reproduction.png)
35 | 
36 | ![Training Figures](res/imgs/training_figures.png)
37 | 
38 | | Model | Description | GSM8K |
39 | | --- | --- | --- |
40 | | Qwen2.5-0.5B | Qwen2.5-0.5B baseline | TBU |
41 | | Qwen2.5-0.5B-r1-zero-oneshot-800step | R1-Zero training of Qwen2.5-0.5B with a one-shot example | 0.3464 |
42 | 
43 | 
44 | ## Working Environment
45 | 
46 | - Python 3.10
47 | - PyTorch 2.5.1
48 | - CUDA Toolkit 12.1 ~ 12.4
49 | - (Due to vLLM and Ray compatibility issues, the CUDA version must be between 12.1 and 12.4.)
50 | 
51 | ## Installation
52 | 
53 | ### Install PyTorch and Dependencies
54 | ```bash
55 | pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu121
56 | pip install -r requirements.txt
57 | ```
58 | 
59 | ## Training the Model
60 | 
61 | ### Single Node Setup
62 | 
63 | (TBU)
64 | 
65 | ### 1 Training Node, N Inference Nodes Setup
66 | 
67 | One node for training, multiple nodes for inference workers.
68 | 
69 | #### Setting Up Inference Workers
70 | 
71 | ##### Step 1. Launch Ray Master
72 | 
73 | ```bash
74 | RAY_MASTER_PORT=1234
75 | RAY_DASHBOARD_PORT=2345
76 | RAY_CLIENT_SERVER_PORT=3456
77 | 
78 | RAY_HEAD_NODE_ONLY=1 ray start --head \
79 |     --port=$RAY_MASTER_PORT \
80 |     --num-gpus=0 \
81 |     --num-cpus=0 \
82 |     --min-worker-port=13000 \
83 |     --max-worker-port=14000 \
84 |     --dashboard-port=$RAY_DASHBOARD_PORT \
85 |     --ray-client-server-port=$RAY_CLIENT_SERVER_PORT \
86 |     --resources '{"head": 1}'
87 | ```
88 | 
89 | ##### Step 2. Attach Inference Workers to Ray Master
90 | 
91 | The inference worker nodes must connect to the Ray Master.
92 | 
93 | ```bash
94 | # On an Inference Worker Node
95 | ray start --address="$RAY_MASTER_ADDRESS:$RAY_MASTER_PORT" --block
96 | ```
97 | 
98 | ##### Step 3. Launch Inference Workers
99 | 
100 | ```bash
101 | # On the Master Node
102 | export MODEL_NAME_OR_PATH="Qwen/Qwen2.5-0.5B"
103 | export NUM_INFER_WORKERS=8
104 | serve run infer_workers:inference_app
105 | ```
106 | 
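Once the inference app is up, you can sanity-check it with a plain HTTP request before starting training. This is a minimal sketch: it assumes the default Ray Serve HTTP endpoint (`http://127.0.0.1:8000/`) and relies on `InferenceWorker.__call__` in `infer_workers.py`, which reads a JSON body with a `prompts` list and returns the sampled `text` and `token_ids`.

```python
# Minimal smoke test for the deployed inference app.
# Assumes the default Ray Serve HTTP address and root route (http://127.0.0.1:8000/).
import json
from urllib import request

payload = json.dumps({"prompts": ["What is 2 + 2?"]}).encode("utf-8")
req = request.Request("http://127.0.0.1:8000/",
                      data=payload,
                      headers={"Content-Type": "application/json"})

with request.urlopen(req) as resp:
    result = json.loads(resp.read())

# result["text"][i] is the list of sampled completions for the i-th prompt.
print(result["text"][0])
```

If this returns generated text, the workers are reachable and you can proceed to training.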
107 | ##### Step 4. Launch Training
108 | 
109 | The training master node and the Ray master node must be the same machine.
110 | 
111 | ```bash
112 | # On the Master Node
113 | 
114 | # Configure Accelerate for Training
115 | accelerate config
116 | 
117 | # Run Training
118 | ./exps/exp-qwen0.5b-r1-zero-example/run_train.sh
119 | ```
120 | 
121 | ### N Training Nodes, K Inference Nodes Setup
122 | 
123 | (TBU)
124 | 
125 | 
126 | 
127 | 
--------------------------------------------------------------------------------
/arguments.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | from dataclasses import dataclass, field
3 | 
4 | 
5 | @dataclass
6 | class BaseArgs:
7 |     exp_name: str = "exp"
8 |     logging_methods: list = field(default_factory=list)
9 |     wandb_project: str = "default"
10 |     wandb_entity: str = "default"
11 | 
12 |     # Dataset Args
13 |     dataset_name_or_path: str = "AI-MO/NuminaMath-TIR"
14 |     tokenized_dataset_path: str = "./data/NuminaMath-TIR_tokenized"
15 |     overwrite_preprocess: bool = False
16 | 
17 |     # Preprocessing Args
18 |     batch_size_for_preproc: int = 3000
19 |     num_proc_for_preproc: int = 16
20 | 
21 |     # Model Args
22 |     model_name_or_path: str = "HuggingFaceTB/SmolLM2-135M"
23 | 
24 |     # Training Args
25 |     max_length: int = 1024
26 |     num_train_epochs: int = 3
27 |     num_warmup_steps: int = 500
28 |     lr_scheduler_type: str = "cosine"
29 |     learning_rate: float = 1.e-5
30 |     max_grad_value: float = 1.0
31 |     train_batch_size_per_proc: int = 2
32 |     eval_batch_size_per_proc: int = 2
33 |     gradient_accumulation_steps: int = 1
34 | 
35 |     # Rollout Args
36 |     rollout_per_sample: int = 3
37 |     rollout_temperature: float = 1.0
38 |     rollout_max_tokens: int = 512
39 |     kl_coef: float = 0.01
40 | 
41 |     eval_interval: int = 500
42 |     log_interval: int = 100
43 |     save_interval: int = 500
44 |     save_dir: str = './ckpts/exp'
45 | 
46 |     @classmethod
47 |     def from_yaml(cls, config_path: str):
48 |         with open(config_path, 'rt') as f:
49 |             config = yaml.safe_load(f)
50 |         return cls(**config)
51 | 
52 |     def to_dict(self):
53 |         return self.__dict__
54 | 
55 |     def __str__(self):
56 |         return str(self.to_dict())
57 | 
--------------------------------------------------------------------------------
/configs/deepspeed/ds_z0_config.json:
--------------------------------------------------------------------------------
1 | {
2 |   "train_batch_size": "auto",
3 |   "train_micro_batch_size_per_gpu": "auto",
4 |   "gradient_accumulation_steps": "auto",
5 |   "gradient_clipping": "auto",
6 |   "zero_allow_untested_optimizer": true,
7 |   "fp16": {
8 |     "enabled": "auto",
9 |     "loss_scale": 0,
10 |     "loss_scale_window": 1000,
11 |     "initial_scale_power": 16,
12 |     "hysteresis": 2,
13 |     "min_loss_scale": 1
14 |   },
15 |   "bf16": {
16 |     "enabled": "auto"
17 |   },
18 |   "zero_optimization": {
19 |     "stage": 0,
20 |     "allgather_partitions": true,
21 |     "allgather_bucket_size": 5e8,
22 |     "overlap_comm": true,
23 |     "reduce_scatter": true,
24 |     "reduce_bucket_size": 5e8,
25 |     "contiguous_gradients": true,
26 |     "round_robin_gradients": true
27 |   }
28 | }
29 | 
--------------------------------------------------------------------------------
/configs/deepspeed/ds_z2_config.json:
--------------------------------------------------------------------------------
1 | {
2 |   "train_batch_size": "auto",
3 | 
"train_micro_batch_size_per_gpu": "auto", 4 | "gradient_accumulation_steps": "auto", 5 | "gradient_clipping": "auto", 6 | "zero_allow_untested_optimizer": true, 7 | "fp16": { 8 | "enabled": "auto", 9 | "loss_scale": 0, 10 | "loss_scale_window": 1000, 11 | "initial_scale_power": 16, 12 | "hysteresis": 2, 13 | "min_loss_scale": 1 14 | }, 15 | "bf16": { 16 | "enabled": "auto" 17 | }, 18 | "zero_optimization": { 19 | "stage": 2, 20 | "allgather_partitions": true, 21 | "allgather_bucket_size": 5e8, 22 | "overlap_comm": true, 23 | "reduce_scatter": true, 24 | "reduce_bucket_size": 5e8, 25 | "contiguous_gradients": true, 26 | "round_robin_gradients": true 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /configs/deepspeed/ds_z2_offload_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_batch_size": "auto", 3 | "train_micro_batch_size_per_gpu": "auto", 4 | "gradient_accumulation_steps": "auto", 5 | "gradient_clipping": "auto", 6 | "zero_allow_untested_optimizer": true, 7 | "fp16": { 8 | "enabled": "auto", 9 | "loss_scale": 0, 10 | "loss_scale_window": 1000, 11 | "initial_scale_power": 16, 12 | "hysteresis": 2, 13 | "min_loss_scale": 1 14 | }, 15 | "bf16": { 16 | "enabled": "auto" 17 | }, 18 | "zero_optimization": { 19 | "stage": 2, 20 | "offload_optimizer": { 21 | "device": "cpu", 22 | "pin_memory": true 23 | }, 24 | "allgather_partitions": true, 25 | "allgather_bucket_size": 5e8, 26 | "overlap_comm": true, 27 | "reduce_scatter": true, 28 | "reduce_bucket_size": 5e8, 29 | "contiguous_gradients": true, 30 | "round_robin_gradients": true 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /configs/deepspeed/ds_z3_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_batch_size": "auto", 3 | "train_micro_batch_size_per_gpu": "auto", 4 | "gradient_accumulation_steps": "auto", 5 | "gradient_clipping": "auto", 6 | "zero_allow_untested_optimizer": true, 7 | "bf16": { 8 | "enabled": "auto" 9 | }, 10 | "optimizer": { 11 | "type": "AdamW", 12 | "params": { 13 | "lr": "auto", 14 | "betas": "auto", 15 | "eps": "auto", 16 | "weight_decay": 0.0 17 | } 18 | }, 19 | "scheduler": { 20 | "type": "WarmupCosineLR", 21 | "params": { 22 | "last_batch_iteration": -1, 23 | "total_num_steps": "auto", 24 | "warmup_num_steps": "auto", 25 | "warmup_min_ratio": 0.05, 26 | "cos_min_ratio": 0.1 27 | } 28 | }, 29 | "zero_optimization": { 30 | "stage": 3, 31 | "overlap_comm": true, 32 | "contiguous_gradients": true, 33 | "sub_group_size": 1e9, 34 | "reduce_bucket_size": "auto", 35 | "stage3_prefetch_bucket_size": "auto", 36 | "stage3_param_persistence_threshold": "auto", 37 | "stage3_max_live_parameters": 1e9, 38 | "stage3_max_reuse_distance": 1e9, 39 | "stage3_gather_16bit_weights_on_model_save": true, 40 | "memory_efficient_linear": true 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /configs/deepspeed/ds_z3_offload_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_batch_size": "auto", 3 | "train_micro_batch_size_per_gpu": "auto", 4 | "gradient_accumulation_steps": "auto", 5 | "gradient_clipping": "auto", 6 | "zero_allow_untested_optimizer": true, 7 | "bf16": { 8 | "enabled": "auto" 9 | }, 10 | "optimizer": { 11 | "type": "Adam", 12 | "params": { 13 | "lr": "auto", 14 | "betas": "auto", 15 | "eps": 
"auto", 16 | "weight_decay": "auto" 17 | } 18 | }, 19 | "scheduler": { 20 | "type": "WarmupDecayLR", 21 | "params": { 22 | "last_batch_iteration": -1, 23 | "total_num_steps": "auto", 24 | "warmup_min_lr": "auto", 25 | "warmup_max_lr": "auto", 26 | "warmup_num_steps": "auto" 27 | } 28 | }, 29 | "zero_optimization": { 30 | "stage": 3, 31 | "offload_optimizer": { 32 | "device": "cpu", 33 | "pin_memory": true 34 | }, 35 | "overlap_comm": true, 36 | "contiguous_gradients": true, 37 | "sub_group_size": 1e9, 38 | "reduce_bucket_size": "auto", 39 | "stage3_prefetch_bucket_size": "auto", 40 | "stage3_param_persistence_threshold": "auto", 41 | "stage3_max_live_parameters": 1e9, 42 | "stage3_max_reuse_distance": 1e9, 43 | "stage3_gather_16bit_weights_on_model_save": true, 44 | "memory_efficient_linear": true 45 | } 46 | } -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from textwrap import dedent 4 | 5 | import tqdm 6 | from datasets import load_dataset 7 | from vllm import LLM, SamplingParams 8 | from transformers import AutoTokenizer 9 | 10 | from utils import extract_answer, extract_numbers, compare_numbers 11 | 12 | 13 | def prepare_tokenizer(model_name_or_path): 14 | tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) 15 | if tokenizer.pad_token_id is None: 16 | tokenizer.pad_token = tokenizer.eos_token 17 | tokenizer.pad_token_id = tokenizer.eos_token_id 18 | 19 | chat_template = dedent(""" 20 | {{- eos_token }} 21 | {%- for message in messages %} 22 | {{- '' + message['role'] + '\n' + message['content'] + '' + '\n' }} 23 | {%- endfor %} 24 | {%- if add_generation_prompt %} 25 | {{- 'assistant\n' }} 26 | {%- endif %}""").strip() 27 | tokenizer.chat_template = chat_template 28 | return tokenizer 29 | 30 | 31 | def build_gsm8k_input_and_output(tokenizer, question, answer): 32 | system_prompt = ("A conversation between User and Assistant. " 33 | "The user asks a question, and the Assistant solves it. " 34 | "The assistant first thinks about the reasoning process " 35 | "in the mind and then provides the user with the answer. " 36 | "The reasoning process and answer are enclosed within " 37 | " and tags, " 38 | "respectively, i.e., reasoning process here " 39 | "\n answer here ") 40 | fewshot_question_1 = ("Natalia sold clips to 48 of her friends in April, " 41 | "and then she sold half as many clips in May. " 42 | "How many clips did Natalia sell altogether in April and May?") 43 | fewshot_answer_1 = dedent(""" 44 | 45 | Natalia sold 48/2 = 24 clips in May. 46 | Natalia sold 48+24 = 72 clips altogether in April and May. 
47 | 48 | 49 | 72 50 | """).strip() 51 | 52 | messages = [{"role": "system", "content": system_prompt}, 53 | {"role": "user", "content": fewshot_question_1}, 54 | {"role": "assistant", "content": fewshot_answer_1}, 55 | {"role": "user", "content": question}] 56 | 57 | input_text = tokenizer.apply_chat_template(messages, 58 | tokenize=False, 59 | add_generation_prompt=True) 60 | 61 | gold = answer.split("####")[-1].strip() 62 | 63 | return input_text, gold 64 | 65 | 66 | def load_gsm8k(split='test'): 67 | dataset = load_dataset("gsm8k", "main", split=split) 68 | return dataset 69 | 70 | 71 | def generate_answer(llm, input_text): 72 | sampling_params = SamplingParams(max_tokens=1024, # 384 73 | temperature=0.0, 74 | stop=[""] 75 | ) 76 | outputs = llm.generate([input_text], sampling_params) 77 | generated_text = outputs[0].outputs[0].text.strip() 78 | return generated_text 79 | 80 | 81 | def evaluate_gsm8k(model_name_or_path, output_path): 82 | llm = LLM(model=model_name_or_path) 83 | dataset = load_gsm8k() 84 | tokenizer = prepare_tokenizer(model_name_or_path=model_name_or_path) 85 | 86 | n_total = 0 87 | n_exact_correct = 0 88 | n_within_tolerance_correct = 0 89 | 90 | pbar = tqdm.tqdm(dataset, total=len(dataset)) 91 | for data in pbar: 92 | n_total += 1 93 | input_text, gold = build_gsm8k_input_and_output(tokenizer, data["question"], data["answer"]) 94 | 95 | pred_raw = generate_answer(llm, input_text) 96 | 97 | print(pred_raw) 98 | 99 | pred_answer_block = extract_answer(pred_raw) 100 | pred_answer_number = extract_numbers(pred_answer_block) 101 | pred = pred_answer_number[0] if pred_answer_number else None 102 | 103 | result = compare_numbers(pred, gold) 104 | if result['exact_match']: 105 | n_exact_correct += 1 106 | if result['within_tolerance']: 107 | n_within_tolerance_correct += 1 108 | 109 | pbar.set_description(f"Exact: {n_exact_correct/n_total:.2f}, Tolerance: {n_within_tolerance_correct/n_total:.2f}") 110 | 111 | exact_accuracy = n_exact_correct / n_total 112 | within_tolerance_accuracy = n_within_tolerance_correct / n_total 113 | 114 | metrics = {"gsm8k_accuracy_exact": exact_accuracy, 115 | "gsm8k_accuracy_within_tolerance": within_tolerance_accuracy, 116 | "output_path": output_path, 117 | "model_name_or_path": model_name_or_path, 118 | } 119 | print(metrics) 120 | 121 | # make dir for ojutput_path file 122 | os.makedirs(os.path.dirname(output_path), exist_ok=True) 123 | with open(output_path, "at") as f: 124 | metrics_jsonline = json.dumps(metrics, ensure_ascii=False) 125 | f.write(metrics_jsonline) 126 | f.write("\n") 127 | 128 | return metrics 129 | 130 | 131 | if __name__ == "__main__": 132 | from argparse import ArgumentParser 133 | parser = ArgumentParser() 134 | parser.add_argument("--model_name_or_path", type=str, default="gpt2") 135 | parser.add_argument("--output_path", type=str, default="eval_outs/evaluation.jsonl") 136 | parser.add_argument("--task", type=str, default="gsm8k") 137 | args = parser.parse_args() 138 | 139 | if args.task == "gsm8k": 140 | evaluate_gsm8k(args.model_name_or_path, args.output_path) 141 | -------------------------------------------------------------------------------- /exps/exp-qwen0.5b-r1-zero-example/exp_config.yaml: -------------------------------------------------------------------------------- 1 | # Experiment Args 2 | exp_name: "exp-qwen0.5b-r1-zero-example" 3 | 4 | logging_methods: ["tensorboard"] 5 | wandb_project: "" 6 | wandb_entity: "" 7 | 8 | # Dataset Args 9 | dataset_name_or_path: "openai/gsm8k" 10 | tokenized_dataset_path: 
"./data/openai/gsm8k_tokenized_for_qwen2.5_0.5b_oneshot" 11 | overwrite_preprocess: true 12 | 13 | # Preprocessing Args 14 | batch_size_for_preproc: 3000 15 | num_proc_for_preproc: 8 16 | 17 | # Model Args 18 | model_name_or_path: "Qwen/Qwen2.5-0.5B" 19 | 20 | # Training Args 21 | max_length: 340 22 | num_train_epochs: 30 23 | num_warmup_steps: 100 24 | learning_rate: 1.e-7 25 | lr_scheduler_type: "cosine" 26 | max_grad_value: 1.0 27 | train_batch_size_per_proc: 2 28 | eval_batch_size_per_proc: 3 29 | gradient_accumulation_steps: 1 30 | 31 | # GRPO Args 32 | rollout_per_sample: 3 33 | rollout_temperature: 0.5 34 | rollout_max_tokens: 384 35 | kl_coef: 0.01 36 | 37 | # Logging and Saving Args 38 | eval_interval: 32 39 | log_interval: 1 40 | save_interval: 160 41 | save_dir: './ckpts/exp-qwen0.5b-r1-zero-example' 42 | -------------------------------------------------------------------------------- /exps/exp-qwen0.5b-r1-zero-example/run_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" &> /dev/null && pwd) 4 | REPO_ROOT_DIR=$SCRIPT_DIR/../.. 5 | cd $REPO_ROOT_DIR 6 | 7 | YAML_FILE=$(find "$SCRIPT_DIR" -maxdepth 1 -name "*.yaml" | head -n 1) 8 | 9 | if [ -z "$YAML_FILE" ]; then 10 | echo "No YAML file found in the current directory: $(pwd)" 11 | exit 1 12 | fi 13 | 14 | ABS_YAML_PATH=$(realpath "$YAML_FILE") 15 | echo "Using YAML file: $ABS_YAML_PATH" 16 | 17 | # offline 18 | # export WANDB_MODE=offline 19 | # export HF_DATASETS_OFFLINE=1 20 | # export HF_HUB_OFFLINE=1 21 | 22 | # cache 23 | # export HF_HOME="${REPO_ROOT_DIR}/cache" 24 | # export HF_DATASETS_CACHE="${REPO_ROOT_DIR}/cache/datasets" 25 | 26 | # Train Workers 27 | NUM_PROCESSES=8 28 | NUM_MACHINES=1 29 | MAIN_PROCESS_IP=$(hostname -I | awk '{print $1}') 30 | MAIN_PROCESS_PORT=23457 31 | MACHINE_RANK=0 32 | 33 | # Inference Workers 34 | export RAY_MASTER_ADDRESS=$MAIN_PROCESS_IP 35 | export RAY_CLIENT_SERVER_PORT=5000 36 | export RAY_MASTER_PG_PORT=5001 37 | 38 | accelerate launch \ 39 | --num_processes $NUM_PROCESSES \ 40 | --num_machines $NUM_MACHINES \ 41 | --main_process_ip $MAIN_PROCESS_IP \ 42 | --main_process_port $MAIN_PROCESS_PORT \ 43 | --machine_rank $MACHINE_RANK \ 44 | train.py \ 45 | --exp-config-path $ABS_YAML_PATH 46 | 47 | 48 | -------------------------------------------------------------------------------- /infer_workers.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from ray import serve 5 | from vllm import LLM, SamplingParams 6 | from vllm.worker.worker import Worker 7 | from starlette.requests import Request 8 | 9 | MODEL_NAME_OR_PATH = os.environ.get("MODEL_NAME_OR_PATH", "Qwen/Qwen2.5-0.5B") 10 | # MODEL_NAME_OR_PATH = os.environ.get("MODEL_NAME_OR_PATH", "Qwen/Qwen2.5-3B") 11 | # MODEL_NAME_OR_PATH = os.environ.get("MODEL_NAME_OR_PATH", "HuggingFaceTB/SmolLM2-360M") 12 | NUM_INFER_WORKERS = os.environ.get("NUM_INFER_WORKERS", 8) 13 | print(f"MODEL_NAME_OR_PATH: {MODEL_NAME_OR_PATH}") 14 | print(f"NUM_INFER_WORKERS: {NUM_INFER_WORKERS}") 15 | 16 | 17 | def stateless_init_process_group(master_address, master_port, rank, world_size, 18 | device): 19 | """ 20 | vLLM provides `StatelessProcessGroup` to create a process group 21 | without considering the global process group in torch.distributed. 
22 | It is recommended to create `StatelessProcessGroup`, and then initialize 23 | the data-plane communication (NCCL) between external (train processes) 24 | and vLLM workers. 25 | """ 26 | from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator 27 | from vllm.distributed.utils import StatelessProcessGroup 28 | pg = StatelessProcessGroup.create(host=master_address, 29 | port=master_port, 30 | rank=rank, 31 | world_size=world_size) 32 | pynccl = PyNcclCommunicator(pg, device=device) 33 | return pynccl 34 | 35 | 36 | class WrappedWorker(Worker): 37 | 38 | def init_weight_update_group(self, master_address, master_port, 39 | rank, world_size): 40 | from vllm.distributed.parallel_state import get_world_group 41 | 42 | print(f"{get_world_group().rank=}, {rank=}") 43 | self.model_update_group = stateless_init_process_group( 44 | master_address, 45 | master_port, 46 | rank, 47 | world_size, 48 | self.device, 49 | ) 50 | 51 | def update_weight(self, name, dtype, shape): 52 | weight = torch.empty(shape, dtype=dtype, device="cuda") 53 | self.model_update_group.broadcast(weight, 54 | src=0, 55 | stream=torch.cuda.current_stream()) 56 | 57 | self.model_runner.model.load_weights(weights=[(name, weight)]) 58 | del weight 59 | 60 | 61 | @serve.deployment(num_replicas=NUM_INFER_WORKERS, 62 | ray_actor_options={"num_gpus": 1}) 63 | class InferenceWorker: 64 | def __init__(self): 65 | self.llm = LLM(model=MODEL_NAME_OR_PATH, 66 | enforce_eager=True, 67 | worker_cls=WrappedWorker, 68 | dtype="half" 69 | ) 70 | self.worker_id = os.getpid() 71 | 72 | def init_weight_update_group(self, master_address, master_port, rank, world_size): 73 | self.llm.collective_rpc("init_weight_update_group", 74 | args=(master_address, master_port, rank, world_size)) 75 | return "Weight update group initialized." 76 | 77 | def update_weight(self, name, dtype, shape): 78 | self.llm.collective_rpc("update_weight", 79 | args=(name, dtype, shape)) 80 | return "Weight updated." 
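    # Weight-synchronization path (driven by train.py): the training process is rank 0
    # of the stateless NCCL group created in `init_weight_update_group`. For each
    # parameter it broadcasts the tensor; `WrappedWorker.update_weight` above receives
    # it into a freshly allocated GPU buffer, loads it into the vLLM model with
    # `load_weights`, and then frees the buffer.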
81 | 82 | def who_you_are(self, val: int) -> str: 83 | return f"Worker {self.worker_id} processed value: {val}" 84 | 85 | def generate_text(self, prompts, sample_params=None): 86 | if sample_params is None: 87 | sample_params = {'temperature': 0.7, 88 | 'max_tokens': 512, 89 | 'n': 4} 90 | 91 | outputs = self.llm.generate(prompts, 92 | SamplingParams(**sample_params)) 93 | 94 | return {"text": [[out.text for out in output.outputs] 95 | for output in outputs], 96 | "token_ids": [[list(out.token_ids) for out in output.outputs] 97 | for output in outputs] 98 | } 99 | 100 | 101 | async def __call__(self, http_request: Request): 102 | data = await http_request.json() 103 | prompts = data.get("prompts", []) 104 | results = self.generate_text(prompts) 105 | return results 106 | 107 | # Ray Serve App 108 | inference_app = InferenceWorker.bind() 109 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "simple-r1" 3 | version = "0.1.0" 4 | description = "Add your description here" 5 | readme = "README.md" 6 | requires-python = ">=3.12" 7 | dependencies = [ 8 | "accelerate==1.3.0", 9 | "datasets==3.2.0", 10 | "deepspeed==0.16.3", 11 | "math-verify==0.4.1", 12 | "pudb==2024.1.3", 13 | "python-dotenv==1.0.1", 14 | "ray==2.41.0", 15 | "tensorboard>=2.19.0", 16 | "torch==2.5.1", 17 | "torchaudio==2.5.1", 18 | "torchvision==0.20.1", 19 | "transformers==4.48.2", 20 | "vllm==0.7.0", 21 | "wandb==0.19.6", 22 | ] 23 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.48.2 2 | ray==2.41.0 3 | vllm==0.7.0 4 | accelerate==1.3.0 5 | datasets==3.2.0 6 | math-verify==0.4.1 7 | torch==2.5.1 8 | torchaudio==2.5.1 9 | torchvision==0.20.1 10 | deepspeed==0.16.3 11 | wandb==0.19.6 12 | python-dotenv==1.0.1 13 | pudb==2024.1.3 14 | -------------------------------------------------------------------------------- /res/imgs/Qwen2.5-0.5B_R1-Zero_Reproduction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/goddoe/simple-r1/070724f5af68087bb0101f05995bea4bcc752d73/res/imgs/Qwen2.5-0.5B_R1-Zero_Reproduction.png -------------------------------------------------------------------------------- /res/imgs/sample_generation_using_ray_server_and_vllm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/goddoe/simple-r1/070724f5af68087bb0101f05995bea4bcc752d73/res/imgs/sample_generation_using_ray_server_and_vllm.png -------------------------------------------------------------------------------- /res/imgs/training_figures.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/goddoe/simple-r1/070724f5af68087bb0101f05995bea4bcc752d73/res/imgs/training_figures.png -------------------------------------------------------------------------------- /res/imgs/weight_synchronization.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/goddoe/simple-r1/070724f5af68087bb0101f05995bea4bcc752d73/res/imgs/weight_synchronization.png -------------------------------------------------------------------------------- /rewards.py: 
--------------------------------------------------------------------------------
1 | import re
2 | 
3 | from math_verify import LatexExtractionConfig, parse, verify
4 | from latex2sympy2_extended import NormalizationConfig
5 | 
6 | from utils import compare_numbers, extract_answer, extract_numbers
7 | 
8 | 
9 | def format_reward_func(completion, end_of_turn_token="", **kwargs):
10 |     def count_substring_and_calc_reward(substring, completion):
11 |         count = completion.count(substring)
12 |         if count == 0:
13 |             return 0.
14 |         elif count == 1:
15 |             return 1.
16 |         return max((10 - count) * 0.1, -1.)
17 | 
18 |     reward = 0.
19 | 
20 |     keywords = ["<think>", "</think>", "<answer>", "</answer>", end_of_turn_token]
21 |     for keyword in keywords:
22 |         reward += count_substring_and_calc_reward(keyword, completion)
23 | 
24 |     # if reward == 0.:
25 |     #     return -1.
26 | 
27 |     # for keyword in keywords:
28 |     #     if completion.count(keyword) != 1:
29 |     #         return 0.
30 | 
31 |     if completion.startswith("<think>"):
32 |         reward += 1.
33 | 
34 |     if completion.endswith(end_of_turn_token):
35 |         reward += 1.
36 | 
37 |     pattern = r"^<think>(.*?)</think>\n<answer>(.*?)</answer>" + end_of_turn_token + r"$"
38 |     if re.match(pattern, completion, re.DOTALL):
39 |         reward += 3.0
40 | 
41 |     # possible max value is 10
42 |     scale = 1./10
43 |     # scale = 1./5
44 |     reward = reward * scale
45 |     return reward
46 | 
47 | 
48 | def math_reward_func(completion, solution, **kwargs):
49 |     answer_block = extract_answer(completion)
50 |     answer_number = extract_numbers(answer_block)
51 | 
52 |     if answer_number:
53 |         result = compare_numbers(answer_number[0], solution, tolerance=1e-5)
54 |         if result["within_tolerance"]:
55 |             return 1.0
56 | 
57 |     # Reference : https://github.com/huggingface/open-r1/blob/1fc8d425a995ddf8dbc6f8ef239d8161acdb7fc1/src/open_r1/grpo.py#L53-L82C1
58 |     gold_parsed = parse(solution, extraction_mode="first_match",
59 |                         extraction_config=[LatexExtractionConfig()])
60 | 
61 |     if len(gold_parsed) != 0:
62 |         # We require the answer to be provided in correct latex (no malformed operators)
63 |         answer_parsed = parse(
64 |             completion,
65 |             extraction_config=[
66 |                 LatexExtractionConfig(
67 |                     normalization_config=NormalizationConfig(
68 |                         nits=False,
69 |                         malformed_operators=False,
70 |                         basic_latex=True,
71 |                         equations=True,
72 |                         boxed=True,
73 |                         units=True,
74 |                     ),
75 |                     # Ensures that boxed is tried first
76 |                     boxed_match_priority=0,
77 |                     try_extract_without_anchor=False,
78 |                 )
79 |             ],
80 |             extraction_mode="first_match",
81 |         )
82 |         # Reward 1 if the content is the same as the ground truth, 0 otherwise
83 |         reward = float(verify(answer_parsed, gold_parsed))
84 |         return reward
85 | 
86 |     reward = 0.0
87 |     return reward
88 | 
--------------------------------------------------------------------------------
/run_eval.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | 
3 | python evaluation.py \
4 |     --model_name_or_path "./ckpts/exp-qwen0.5b-r1-zero-example/ckpt_0" \
5 |     --output_path "./eval_outs/evaluation.jsonl"
6 | 
7 | 
--------------------------------------------------------------------------------
/run_eval_multi.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | BASE_DIR="./ckpts/exp-qwen0.5b-r1-zero-example"
4 | 
5 | NUM_GPUS=4
6 | 
7 | SAVE_PATH="./eval_outs/evaluation.jsonl"
8 | 
9 | COMMANDS=()
10 | 
11 | for ckpt in $(find "$BASE_DIR" -maxdepth 1 -type d -name "ckpt_*" | sort); do
12 |     COMMANDS+=("python evaluation.py --model_name_or_path \"$ckpt\" --output_path $SAVE_PATH")
13 | done
14 | 
15 | 
TOTAL_JOBS=${#COMMANDS[@]} 16 | 17 | gpu_idx=0 18 | 19 | for (( i=0; i" + message["role"] + "\n" + message["content"] + "" + "\n" }} 76 | {%- endfor %} 77 | {%- if add_generation_prompt %} 78 | {{- "assistant\n" }} 79 | {%- endif %}""").strip() 80 | tokenizer.chat_template = chat_template 81 | 82 | def tokenize_function(examples): 83 | system_prompt = "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here \n answer here " 84 | fewshot_question_1 = "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?" 85 | fewshot_answer_1 = dedent(""" 86 | 87 | Natalia sold 48/2 = 24 clips in May. 88 | Natalia sold 48+24 = 72 clips altogether in April and May. 89 | 90 | 91 | 72 92 | """).strip() 93 | 94 | gold_answer_list = [] 95 | 96 | new_messages_list = [] 97 | for q, a in zip(examples["question"], examples["answer"]): 98 | new_messages = [ 99 | {"role": "system", "content": system_prompt}, 100 | {"role": "user", "content": fewshot_question_1}, 101 | {"role": "assistant", "content": fewshot_answer_1}, 102 | {"role": "user", "content": q} 103 | ] 104 | new_messages_list.append(new_messages) 105 | gold_answer_list.append(a.split("####")[-1].strip()) 106 | 107 | batch = {} 108 | batch["gold_answer"] = gold_answer_list 109 | batch["solution"] = gold_answer_list 110 | 111 | batch["user_input_ids"] = tokenizer.apply_chat_template( 112 | new_messages_list, 113 | add_generation_prompt=True, 114 | return_tensors="pt", 115 | padding=True, 116 | truncation=True, 117 | max_length=exp_args.max_length 118 | ).tolist() 119 | batch["user_input_text"] = tokenizer.apply_chat_template( 120 | new_messages_list, 121 | tokenize=False, 122 | add_generation_prompt=True 123 | ) 124 | 125 | return batch 126 | 127 | 128 | ############################################################### 129 | # Prepare Dataset 130 | is_cache_exist = os.path.exists(exp_args.tokenized_dataset_path) 131 | if accelerator.is_main_process and (not is_cache_exist or exp_args.overwrite_preprocess): 132 | dataset = load_dataset(exp_args.dataset_name_or_path, "main") 133 | tokenized_datasets = dataset.map( 134 | tokenize_function, 135 | batched=True, 136 | batch_size=exp_args.batch_size_for_preproc, 137 | num_proc=8 138 | ) 139 | tokenized_datasets.save_to_disk(exp_args.tokenized_dataset_path) 140 | accelerator.wait_for_everyone() 141 | 142 | tokenized_datasets = load_from_disk(exp_args.tokenized_dataset_path) 143 | 144 | if dist.is_available() and dist.is_initialized(): 145 | train_sampler = DistributedSampler(tokenized_datasets["train"], shuffle=True) 146 | else: 147 | train_sampler = RandomSampler(tokenized_datasets["train"]) 148 | valid_sampler = SequentialSampler(tokenized_datasets["test"]) 149 | 150 | def collate_fn_all(batch): 151 | keys = [key for key in batch[0].keys()] 152 | data = {key: [] for key in keys} 153 | for item in batch: 154 | for key in keys: 155 | data[key].append(item[key]) 156 | if "user_input_ids" in data: 157 | user_input = tokenizer.pad({"input_ids": data["user_input_ids"]}, 158 | return_tensors="pt", 159 | padding=True, 160 | padding_side="left") 161 | data["user_input_ids"] = user_input.input_ids 162 | return data 163 | 164 | train_dataloader = 
DataLoader(tokenized_datasets["train"], 165 | sampler=train_sampler, 166 | batch_size=exp_args.train_batch_size_per_proc, 167 | collate_fn=collate_fn_all, 168 | drop_last=True) 169 | valid_dataloader = DataLoader(tokenized_datasets["test"], 170 | sampler=valid_sampler, 171 | batch_size=exp_args.eval_batch_size_per_proc, 172 | collate_fn=collate_fn_all) 173 | 174 | ############################################################### 175 | # Prepare Model 176 | model = AutoModelForCausalLM.from_pretrained(exp_args.model_name_or_path) 177 | 178 | ref_model = None 179 | if exp_args.kl_coef > 0.: 180 | ref_model = AutoModelForCausalLM.from_pretrained(exp_args.model_name_or_path) 181 | ref_model.eval() 182 | 183 | 184 | ############################################################### 185 | # Prepare Optimizer and Scheduler 186 | optimizer_cls = ( 187 | AdamW 188 | if accelerator.state.deepspeed_plugin is None 189 | or "optimizer" not in accelerator.state.deepspeed_plugin.deepspeed_config 190 | else DummyOptim 191 | ) 192 | optimizer = optimizer_cls(model.parameters(), lr=exp_args.learning_rate) 193 | 194 | num_processes = accelerator.num_processes 195 | accelerator.print("Number of processes (GPUs):", num_processes) 196 | 197 | num_training_steps = math.ceil(len(tokenized_datasets["train"]) / (exp_args.train_batch_size_per_proc * num_processes)) * exp_args.num_train_epochs 198 | accelerator.print("Number of training steps:", num_training_steps) 199 | 200 | # Creates Dummy Scheduler if `scheduler` was specified in the config file else creates `args.lr_scheduler_type` Scheduler 201 | if ( 202 | accelerator.state.deepspeed_plugin is None 203 | or "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config 204 | ): 205 | lr_scheduler = get_scheduler( 206 | name=exp_args.lr_scheduler_type, 207 | optimizer=optimizer, 208 | num_warmup_steps=exp_args.num_warmup_steps, 209 | num_training_steps=num_training_steps, 210 | ) 211 | else: 212 | lr_scheduler = DummyScheduler( 213 | optimizer, total_num_steps=num_training_steps, warmup_num_steps=exp_args.num_warmup_steps 214 | ) 215 | 216 | model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 217 | model, optimizer, train_dataloader, lr_scheduler 218 | ) 219 | 220 | 221 | if exp_args.kl_coef > 0.: 222 | if accelerator.state.deepspeed_plugin is None: 223 | ref_model = accelerator.prepare_model(ref_model, evaluation_mode=True) 224 | else: 225 | ref_model = prepare_deepspeed(ref_model, accelerator) 226 | 227 | ############################################################### 228 | # Prepare Reward function 229 | reward_func_list = [format_reward_func, math_reward_func] 230 | 231 | ############################################################### 232 | # Prepare Inference Workers 233 | ray_master_address = os.environ["RAY_MASTER_ADDRESS"] 234 | ray_client_server_port = int(os.environ["RAY_CLIENT_SERVER_PORT"]) 235 | ray_master_pg_port = int(os.environ["RAY_MASTER_PG_PORT"]) 236 | 237 | if accelerator.is_main_process: 238 | ray.init(address="auto") 239 | else: 240 | ray.init(address=f"ray://{ray_master_address}:{ray_client_server_port}") 241 | 242 | handle = serve.get_deployment_handle("InferenceWorker", 243 | app_name="default") 244 | 245 | print(f"result: {handle.generate_text.remote(['hello']).result()}") 246 | 247 | num_infer_workers = -1 248 | model_update_group = None 249 | # init weight update group 250 | if accelerator.is_main_process: 251 | actor_handle_list = get_all_inference_actors() 252 | num_infer_workers = len(actor_handle_list) 253 
| accelerator.print(actor_handle_list) 254 | 255 | worker_weight_init_handle_list = [] 256 | for i, actor_handle in enumerate(actor_handle_list): 257 | worker_weight_init_handle = call_func_using_actor_handle(actor_handle, 258 | "init_weight_update_group", 259 | master_address=ray_master_address, 260 | master_port=ray_master_pg_port, 261 | rank=i+1, 262 | world_size=num_infer_workers + 1) 263 | worker_weight_init_handle_list.append(worker_weight_init_handle) 264 | 265 | model_update_group = stateless_init_process_group( 266 | ray_master_address, 267 | ray_master_pg_port, 268 | rank=0, 269 | world_size=num_infer_workers + 1, 270 | device=torch.device("cuda:0") 271 | ) 272 | 273 | ray.get(worker_weight_init_handle_list) 274 | accelerator.wait_for_everyone() 275 | 276 | global_i = 0 277 | os.makedirs(exp_args.save_dir, exist_ok=True) 278 | model.train() 279 | pbar = tqdm(range(num_training_steps), total=num_training_steps) 280 | 281 | accelerator.print("Start training") 282 | for epoch in range(exp_args.num_train_epochs): 283 | for batch in train_dataloader: 284 | # context = nullcontext() 285 | context = accelerator.accumulate(model) 286 | with context: 287 | ############################################################### 288 | # Rollout 289 | sample_params = {"temperature": exp_args.rollout_temperature, 290 | "max_tokens": exp_args.rollout_max_tokens, 291 | "n": exp_args.rollout_per_sample, 292 | "include_stop_str_in_output": True, 293 | "stop": [stop_token]} 294 | 295 | future_policy_rollout_batch = handle.generate_text.remote( 296 | batch["user_input_text"], 297 | sample_params=sample_params 298 | ) 299 | 300 | policy_rollout_batch = future_policy_rollout_batch.result() 301 | 302 | text_compl_sample_list_batch = policy_rollout_batch["text"] # [batch_size, rollout_per_sample] 303 | reward_list = [] # [batch_size, rollout_per_sample, num_reward_func] 304 | 305 | ############################################################### 306 | # Calc Reward 307 | for j, (text_compl_sample_list, solution) in enumerate(zip(text_compl_sample_list_batch, batch["solution"])): 308 | curr_compl_reward_list = [] # [rollout_per_sample, num_reward_func] 309 | 310 | for k, text_compl_sample in enumerate(text_compl_sample_list): # [rollout_per_sample] 311 | curr_sample_reward_list = [] 312 | for l, reward_func in enumerate(reward_func_list): 313 | reward = reward_func(text_compl_sample, solution=solution) 314 | curr_sample_reward_list.append(reward) 315 | curr_compl_reward_list.append(curr_sample_reward_list) 316 | reward_list.append(curr_compl_reward_list) 317 | 318 | rewards = torch.tensor(reward_list) 319 | total_reward_by_each_compl = torch.sum(rewards, dim=2) # [batch_size, rollout_per_sample] 320 | reward_mean = torch.mean(total_reward_by_each_compl, dim=1) # [batch_size] 321 | reward_std = torch.std(total_reward_by_each_compl, dim=1) # [batch_size] 322 | 323 | ############################################################### 324 | # Calc Advantages 325 | # [batch_size, rollout_per_sample] 326 | advantages = (total_reward_by_each_compl - reward_mean.unsqueeze(1)) / (reward_std.unsqueeze(1) + 1e-4) 327 | advantages = advantages.to(model.device) 328 | 329 | # [batch_size, rollout_per_sample, not fixed length ] 330 | raw_completion_ids_batch = policy_rollout_batch["token_ids"] 331 | 332 | ############################################################### 333 | # Calc KL divergence 334 | 335 | batch_size = len(raw_completion_ids_batch) 336 | rollout_per_sample = len(raw_completion_ids_batch[0]) 337 | 338 | # 
[batch_size * rollout_per_sample, length] 339 | completion_ids_list = [] 340 | for raw_completion_ids in raw_completion_ids_batch: 341 | completion_ids_list.extend(raw_completion_ids) 342 | 343 | # [batch_size * rollout_per_sample, max_length] 344 | completion_padded= tokenizer.pad({ 345 | "input_ids": completion_ids_list}, 346 | return_tensors="pt", 347 | padding=True, 348 | padding_side="right" 349 | ) 350 | completion_ids = completion_padded.input_ids 351 | 352 | # [batch_size, rollout_per_sample, max_length] 353 | completion_ids = completion_ids.view(batch_size, rollout_per_sample, -1) 354 | completion_ids = completion_ids.to(model.device) 355 | 356 | # [batch_size, max_length] 357 | user_input_ids = batch["user_input_ids"] 358 | 359 | # [batch_size, rollout_per_sample, max_length] 360 | user_input_ids_expanded = user_input_ids.unsqueeze(1).expand(-1, rollout_per_sample, -1) 361 | 362 | # [batch_size, rollout_per_sample, max_length] 363 | prompt_completion_ids = torch.cat([user_input_ids_expanded, 364 | completion_ids], dim=-1) 365 | 366 | logits_to_keep = completion_ids[0].size(1) 367 | 368 | # [batch_size * rollout_per_sample, max_length] 369 | flatten_prompt_completion_ids = prompt_completion_ids.view(batch_size * rollout_per_sample, -1) 370 | 371 | # [batch_size, rollout_per_sample, max_length] 372 | flatten_prompt_completion_attention_mask = (flatten_prompt_completion_ids != pad_token_id).view(batch_size* rollout_per_sample, -1).long() 373 | 374 | 375 | policy_per_token_logps = get_per_token_logps( 376 | model, 377 | flatten_prompt_completion_ids, 378 | flatten_prompt_completion_attention_mask, 379 | logits_to_keep 380 | ) 381 | 382 | per_token_kl = 0. 383 | if exp_args.kl_coef > 0.: 384 | # Calc KLD 385 | with torch.no_grad(): 386 | ref_per_token_logps = get_per_token_logps( 387 | ref_model, 388 | flatten_prompt_completion_ids, 389 | flatten_prompt_completion_attention_mask, 390 | logits_to_keep 391 | ) 392 | 393 | # Compute the KL divergence between the model and the reference model 394 | per_token_kl = torch.exp(ref_per_token_logps - policy_per_token_logps) - (ref_per_token_logps - policy_per_token_logps) - 1 395 | 396 | # x - x.detach() allows for preserving gradients from x 397 | # It is equivalent to updating the old policy model at every step. 
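            # Because there is exactly one optimizer step per rollout batch, the ratio
            # exp(logp - logp.detach()) is always 1 in the forward pass; only its gradient
            # (the gradient of logp, scaled by the advantage) remains, i.e. the usual GRPO
            # surrogate with the "old" policy taken to be the current policy.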
398 | # [batch_size * rollout_per_sample, max_length] 399 | per_token_loss = torch.exp(policy_per_token_logps - policy_per_token_logps.detach()) * advantages.view(-1, 1) 400 | 401 | per_token_loss = -(per_token_loss - exp_args.kl_coef * per_token_kl) 402 | 403 | # [batch_size * rollout_per_sample, max_length] 404 | completion_attention_mask = (completion_ids != pad_token_id).view(batch_size* rollout_per_sample, -1).long() 405 | train_loss = ((per_token_loss * completion_attention_mask).sum(dim=1) / completion_attention_mask.sum(dim=1)).mean() 406 | 407 | accelerator.backward(train_loss) 408 | if not is_deepspeed and accelerator.sync_gradients: 409 | accelerator.clip_grad_value_(model.parameters(), exp_args.max_grad_value) 410 | 411 | optimizer.step() 412 | lr_scheduler.step() 413 | optimizer.zero_grad() 414 | 415 | ############################################################### 416 | # Update Policy model 417 | actor_handle_list = get_all_inference_actors(class_name="InferenceWorker", state="ALIVE") 418 | 419 | unwrapped_model = accelerator.unwrap_model(model) 420 | if accelerator.is_main_process: 421 | start_time = time.time() 422 | for name, p in unwrapped_model.named_parameters(): 423 | worker_update_weight_handle_list = [] 424 | for i, actor_handle in enumerate(actor_handle_list): 425 | worker_update_weight_handle = call_func_using_actor_handle(actor_handle, 426 | "update_weight", 427 | name=name, 428 | dtype=p.dtype, 429 | shape=p.shape) 430 | worker_update_weight_handle_list.append(worker_update_weight_handle) 431 | model_update_group.broadcast(p, src=0, stream=torch.cuda.current_stream()) 432 | ray.get(worker_update_weight_handle_list) 433 | 434 | print(f"Time for weight update: {time.time() - start_time}") 435 | accelerator.wait_for_everyone() 436 | print(f"{accelerator.process_index} Train loss:", train_loss.item()) 437 | 438 | ############################################################### 439 | # Logging 440 | if global_i % exp_args.log_interval == 0: 441 | # Collect metrics 442 | global_train_loss = accelerator.reduce(train_loss.detach(), reduction="mean").item() 443 | 444 | # [batch_size * world_size, rollout_per_sample, num_reward_func] 445 | global_rewards = accelerator.gather_for_metrics(rewards.to(model.device)).detach() 446 | global_reward_mean = torch.mean(global_rewards).item() 447 | 448 | if accelerator.is_main_process: 449 | length_list = [] 450 | for text_compl_sample_list in text_compl_sample_list_batch: 451 | length_list.append([len(text_compl_sample) 452 | for text_compl_sample in text_compl_sample_list]) 453 | 454 | length_mean = np.mean(length_list) 455 | length_std = np.std(length_list) 456 | 457 | reward_func_to_reward_map = {} 458 | for i, reward_func in enumerate(reward_func_list): 459 | reward_func_name = reward_func.__name__ 460 | all_rewards = global_rewards[:, :, i] 461 | curr_reward_mean = torch.sum(all_rewards) / torch.numel(all_rewards) 462 | reward_func_to_reward_map[reward_func_name] = curr_reward_mean.item() 463 | 464 | metrics = {"epoch": epoch, 465 | "global_step": global_i, 466 | "reward_mean": global_reward_mean, 467 | "train_loss": global_train_loss, 468 | "lr": lr_scheduler.get_last_lr()[0], 469 | "length_mean": length_mean, 470 | "length_std": length_std, 471 | **reward_func_to_reward_map 472 | } 473 | 474 | print(metrics) 475 | 476 | if is_wandb_logging: 477 | wandb.log(metrics) 478 | 479 | if is_tb_logging: 480 | for k, v in metrics.items(): 481 | tb_writer.add_scalar(f"train/{k}", v, global_i) 482 | 483 | print("="*60) 484 | for item_list 
in completion_ids: 485 | for item in item_list: 486 | sample_completion = tokenizer.decode(item.cpu().tolist(), 487 | skip_special_tokens=True) 488 | print(sample_completion) 489 | print("-"*30) 490 | 491 | 492 | 493 | if accelerator.is_main_process and global_i % exp_args.eval_interval == 0: 494 | pred_raw_list = [] 495 | pred_list = [] 496 | gold_list = [] 497 | 498 | batch_result_list = [] 499 | 500 | eval_sample = 30 501 | for batch in valid_dataloader: 502 | if len(gold_list) > eval_sample: 503 | break 504 | # inference 505 | gold_list.extend(batch["gold_answer"]) 506 | 507 | sample_params = {"temperature": 0.1, 508 | "max_tokens": exp_args.rollout_max_tokens, 509 | "n": 1, 510 | "include_stop_str_in_output": True, 511 | "stop": [stop_token]} 512 | 513 | future_policy_rollout_batch = handle.generate_text.remote( 514 | batch["user_input_text"], 515 | sample_params=sample_params 516 | ) 517 | batch_result_list.append(future_policy_rollout_batch) 518 | 519 | if len(batch_result_list) >= num_infer_workers: 520 | continue 521 | 522 | for future_policy_rollout_batch in batch_result_list: 523 | policy_rollout_batch = future_policy_rollout_batch.result() 524 | for preds in policy_rollout_batch["text"]: 525 | pred_raw_list.append(preds[0]) 526 | 527 | batch_result_list = [] 528 | 529 | if batch_result_list: 530 | for future_policy_rollout_batch in batch_result_list: 531 | policy_rollout_batch = future_policy_rollout_batch.result() 532 | for preds in policy_rollout_batch["text"]: 533 | pred_raw_list.append(preds[0]) 534 | 535 | gold_list = gold_list[:eval_sample] 536 | pred_raw_list = pred_raw_list[:eval_sample] 537 | 538 | for pred_raw in pred_raw_list: 539 | # extract answer from tag 540 | answer_block = extract_answer(pred_raw) 541 | answer_number = extract_numbers(answer_block) 542 | pred = answer_number[0] if answer_number else None 543 | pred_list.append(pred) 544 | 545 | n_exact_correct = 0 546 | n_within_tolerance_correct = 0 547 | n_total = len(pred_list) 548 | 549 | for pred, gold in zip(pred_list, gold_list): 550 | result = compare_numbers(pred, gold) 551 | if result["exact_match"]: 552 | n_exact_correct += 1 553 | if result["within_tolerance"]: 554 | n_within_tolerance_correct += 1 555 | 556 | # Calc Accuracy 557 | exact_accuracy = n_exact_correct / n_total 558 | within_tolerance_accuracy = n_within_tolerance_correct / n_total 559 | 560 | metrics = { 561 | f"gsm8k_accuracy_exact_{eval_sample}s": exact_accuracy, 562 | f"gsm8k_accuracy_within_tolerance_{eval_sample}s": within_tolerance_accuracy, 563 | } 564 | 565 | if is_wandb_logging: 566 | wandb.log(metrics) 567 | 568 | if is_tb_logging: 569 | for k, v in metrics.items(): 570 | tb_writer.add_scalar(f"valid/{k}", v, global_i) 571 | 572 | accelerator.print( 573 | f"global_step: {global_i}, epoch: {epoch}, " 574 | f"gsm8k_accuracy_exact_{eval_sample}s: {exact_accuracy:0.4f}, " 575 | f"gsm8k_accuracy_within_tolerance_{eval_sample}s: {within_tolerance_accuracy:0.4f}" 576 | ) 577 | 578 | accelerator.wait_for_everyone() 579 | 580 | if global_i % exp_args.save_interval == 0: 581 | if accelerator.is_main_process: 582 | unwrapped_model = accelerator.unwrap_model(model) 583 | unwrapped_model.save_pretrained( 584 | f"{exp_args.save_dir}/ckpt_{global_i}", 585 | is_main_process=accelerator.is_main_process, 586 | save_function=accelerator.save, 587 | state_dict=accelerator.get_state_dict(model) 588 | ) 589 | 590 | tokenizer.save_pretrained(f"{exp_args.save_dir}/ckpt_{global_i}") 591 | torch.save({"epoch": epoch, "global_step": global_i}, 
f"{exp_args.save_dir}/ckpt_{global_i}/training_state.pt") 592 | 593 | accelerator.wait_for_everyone() 594 | 595 | pbar.update(1) 596 | global_i += 1 597 | 598 | pbar.close() 599 | 600 | 601 | if __name__ == "__main__": 602 | main() 603 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import pickle 3 | from copy import deepcopy 4 | 5 | import ray 6 | import yaml 7 | import torch 8 | import torch.nn.functional as F 9 | from ray.util.state import list_actors 10 | from ray.serve._private.common import RequestMetadata 11 | 12 | 13 | def read_config(config_path): 14 | with open(config_path, 'rt') as f: 15 | config = yaml.safe_load(f) 16 | return config 17 | 18 | 19 | def stateless_init_process_group(master_address, master_port, 20 | rank, world_size, device): 21 | """ 22 | vLLM provides `StatelessProcessGroup` to create a process group 23 | without considering the global process group in torch.distributed. 24 | It is recommended to create `StatelessProcessGroup`, and then initialize 25 | the data-plane communication (NCCL) between external (train processes) 26 | and vLLM workers. 27 | """ 28 | from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator 29 | from vllm.distributed.utils import StatelessProcessGroup 30 | pg = StatelessProcessGroup.create(host=master_address, 31 | port=master_port, 32 | rank=rank, 33 | world_size=world_size) 34 | pynccl = PyNcclCommunicator(pg, device=device) 35 | return pynccl 36 | 37 | 38 | def get_all_inference_actors(class_name='InferenceWorker', state='ALIVE' 39 | ) -> list[ray.actor.ActorHandle]: 40 | actor_state_list = [] 41 | 42 | for actor in list_actors(): 43 | if class_name in actor.class_name and actor.state == state: 44 | actor_state_list.append(actor) 45 | 46 | actor_handle_list = [] 47 | for actor_state in actor_state_list: 48 | actor_handle = ray.get_actor(name=actor_state['name'], namespace=actor_state['ray_namespace']) 49 | actor_handle_list.append(actor_handle) 50 | 51 | return actor_handle_list 52 | 53 | 54 | def call_func_using_actor_handle(actor_handle: ray.actor.ActorHandle, 55 | method_name: str, 56 | *method_args, **method_kwargs) -> ray.ObjectRef: 57 | request_metadata = RequestMetadata( 58 | request_id="dummy", 59 | internal_request_id="dummy", 60 | call_method=method_name 61 | ) 62 | serialized_metadata = pickle.dumps(request_metadata) 63 | result = actor_handle.handle_request.remote(serialized_metadata, *method_args, **method_kwargs) 64 | return result 65 | 66 | 67 | 68 | def prepare_deepspeed(model, accelerator): 69 | # Copy From: https://github.com/huggingface/trl/blob/af4ad47035529164799be10f3fe558ee642a9880/trl/models/utils.py#L199-L230 70 | 71 | # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473 72 | import deepspeed 73 | deepspeed_plugin = accelerator.state.deepspeed_plugin 74 | config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config) 75 | stage = config_kwargs["zero_optimization"]["stage"] 76 | 77 | 78 | if model is not None: 79 | hidden_size = ( 80 | max(model.config.hidden_sizes) 81 | if getattr(model.config, "hidden_sizes", None) 82 | else getattr(model.config, "hidden_size", None) 83 | ) 84 | if hidden_size is not None and stage == 3: 85 | # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache 86 | # @ step 0: expected 
module 1, but got module 0` 87 | # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081 88 | config_kwargs.update( 89 | { 90 | "zero_optimization.reduce_bucket_size": hidden_size * hidden_size, 91 | "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size, 92 | "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size, 93 | } 94 | ) 95 | 96 | 97 | # If ZeRO-3 is used, we shard both the active and reference model. 98 | # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO 99 | # disabled (stage 0) 100 | if stage != 3: 101 | config_kwargs["zero_optimization"]["stage"] = 0 102 | model, *_ = deepspeed.initialize(model=model, config=config_kwargs) 103 | model.eval() 104 | return model 105 | 106 | 107 | def get_per_token_logps(model, input_ids, attention_mask, logits_to_keep): 108 | """Calculate per-token log probabilities""" 109 | outputs = model(input_ids=input_ids, attention_mask=attention_mask) 110 | logits = outputs.logits 111 | 112 | # Shift logits and labels for next token prediction 113 | logits = logits[:, :-1, :].contiguous() 114 | labels = input_ids[:, 1:].contiguous() 115 | 116 | # Get the last logits_to_keep tokens 117 | if logits_to_keep > 0: 118 | logits = logits[:, -logits_to_keep:, :] 119 | labels = labels[:, -logits_to_keep:] 120 | 121 | # Calculate log probabilities 122 | log_probs = F.log_softmax(logits, dim=-1) 123 | 124 | # Gather log probs for actual tokens 125 | labels_expanded = labels.unsqueeze(-1) 126 | per_token_logps = log_probs.gather(dim=-1, index=labels_expanded).squeeze(-1) 127 | 128 | return per_token_logps 129 | 130 | 131 | def create_keyword_mask_from_offsets(tokenizer, input_texts, keywords): 132 | tokenized_inputs = tokenizer(input_texts, padding=True, truncation=True, 133 | return_tensors="pt", return_offsets_mapping=True) 134 | token_ids = tokenized_inputs["input_ids"] 135 | offset_mappings = tokenized_inputs["offset_mapping"] 136 | 137 | batch_size, seq_len = token_ids.shape 138 | 139 | mask = torch.zeros_like(token_ids, dtype=torch.float32) 140 | 141 | keyword_positions = [] 142 | for keyword in keywords: 143 | for text in input_texts: 144 | start_idx = text.find(keyword) 145 | if start_idx != -1: 146 | keyword_positions.append((text, start_idx, start_idx + len(keyword))) 147 | 148 | for b in range(batch_size): 149 | text = input_texts[b] 150 | for _, start_pos, end_pos in keyword_positions: 151 | if text != _: 152 | continue 153 | for i in range(seq_len): 154 | token_start, token_end = offset_mappings[b, i] 155 | if token_start >= start_pos and token_end <= end_pos: 156 | mask[b, i] = 1 157 | 158 | return mask 159 | 160 | 161 | def extract_numbers(text): 162 | if text is None: 163 | return [] 164 | 165 | text = text.replace(",", "") 166 | numbers = re.findall(r"[-+]?\d*\.?\d+", text) 167 | 168 | return [float(num) for num in numbers] if numbers else [] 169 | 170 | 171 | def compare_numbers(pred, gold, tolerance=1e-5): 172 | if not pred or not gold: 173 | return { 174 | "exact_match": False, 175 | "within_tolerance": False, 176 | "pred": pred, 177 | "gold":gold 178 | } 179 | 180 | if isinstance(gold, str): 181 | gold = gold.replace(",", "") 182 | if isinstance(pred, str): 183 | pred = pred.replace(",", "") 184 | 185 | pred = float(pred) 186 | gold = float(gold) 187 | 188 | exact_match = pred == gold 189 | within_tolerance = abs(pred - gold) <= tolerance 190 | 191 | return { 192 | "exact_match": exact_match, 193 | 
"within_tolerance": within_tolerance, 194 | "pred": pred, 195 | "gold":gold 196 | } 197 | 198 | 199 | def extract_answer(text): 200 | match = re.search(r"(.*?)", text, re.DOTALL) 201 | return match.group(1) if match else None 202 | --------------------------------------------------------------------------------