├── mcts
│   ├── __init__.py
│   ├── util.py
│   ├── test.py
│   ├── train_policy_orpo.py
│   ├── generate.py
│   ├── train_policy_simpo.py
│   ├── train_policy_sft.py
│   ├── reward.py
│   ├── train_policy_sft_prime.py
│   ├── train_policy_sft_metamath.py
│   ├── greedy_sample.py
│   ├── train_policy_sft_qwq.py
│   ├── simple_sample.py
│   └── tree_search_light_mathrm.py
├── readme.md
├── vespa
│   └── Dockerfile
├── Dockerfile
├── dspy_reflect.py
├── local_reward.py
├── pyproject.toml
├── fly.toml
├── .github
│   └── workflows
│       └── fly-deploy.yml
├── modal_vespa.py
├── lib.py
├── modal_train_policy_sft.py
├── modal_train_policy_sft_metamath.py
├── modal_train_policy_sft_qwq.py
├── modal_train_policy_sft_prime.py
├── tweak_prime_rl_data.py
├── modal_train_prm_rlhf_flow.py
├── modal_train_prm_init.py
├── prm_rlhf_flow
│   └── qwen.yml
├── modal_train_prm_st.py
├── modal_train_policy_simpo.py
├── modal_prm_armorm.py
├── modal_train_policy_orpo.py
├── modal_tree_grpo.py
├── .gitignore
├── .dockerignore
├── modal_orm_reward.py
├── test_vllm_prm.py
├── modal_prm_reward.py
├── modal_vllm.py
├── prime_training_local.py
├── modal_vllm_prm.py
├── modal_prime.py
├── modal_vllm_chat.py
├── modal_vllm_qwq.py
├── backup
│   └── modal_reward_save.py
├── eval_reward.py
└── best_of_n.py

/mcts/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
various experiments for scaling inference-time compute with small reasoning models

high-throughput async MCTS implementation for a policy + PRM, hosted on serverless GPUs on Modal
--------------------------------------------------------------------------------
/vespa/Dockerfile:
--------------------------------------------------------------------------------
FROM vespaengine/vespa:latest

RUN echo '#!/usr/bin/env bash' > /entry.sh
RUN echo 'exec "$@"' >> /entry.sh
RUN chmod +x /entry.sh
# ENTRYPOINT /entry.sh
ENTRYPOINT tail -f /dev/null
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
FROM nvidia/cuda:12.6.1-cudnn-runtime-ubuntu24.04

RUN apt-get update && \
    apt-get install -y python3-pip python3-dev python-is-python3 && \
    rm -rf /var/lib/apt/lists/*

WORKDIR /app

ADD local_reward.py .
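
# Hedged sketch (assumption, not in the original Dockerfile): local_reward.py
# imports torch and transformers, so running it in this image would also need
# something like the following:
# RUN pip install torch transformers accelerate
# CMD ["python", "local_reward.py"]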
--------------------------------------------------------------------------------
/dspy_reflect.py:
--------------------------------------------------------------------------------
import dspy
from dspy.teleprompt import MIPROv2
from dspy.evaluate import Evaluate

from dotenv import load_dotenv
load_dotenv()

task_model = dspy.OpenAI(model="gpt-4o-mini", max_tokens=4000)

dspy.settings.configure(lm=task_model)
--------------------------------------------------------------------------------
/local_reward.py:
--------------------------------------------------------------------------------
from typing import List, Dict, Tuple
import asyncio
import torch
import copy
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# model id is an assumption: the same reward model used in modal_orm_reward.py
MODEL_ID = "RLHFlow/ArmoRM-Llama3-8B-v0.1"

dtype = torch.bfloat16
with torch.device("cuda"):
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID,
        trust_remote_code=True, torch_dtype=dtype, use_safetensors=True)
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
[tool.poetry]
name = "mirrorllm"
version = "0.1.0"
description = ""
authors = ["Robert Washbourne"]
readme = "README.md"

[tool.poetry.dependencies]
python = "^3.10"
dspy-ai = "^2.4.16"
modal = "^0.72.10"
cachetools = "^5.5.0"
aiomultiprocess = "^0.9.1"
tiktoken = "^0.8.0"


[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
--------------------------------------------------------------------------------
/fly.toml:
--------------------------------------------------------------------------------
# fly.toml app configuration file generated for mirrorllm on 2024-09-14T09:23:12Z
#
# See https://fly.io/docs/reference/configuration/ for information about how to use this file.
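#
# auto_stop_machines / auto_start_machines below let Fly stop idle machines and
# start them again on demand, so the app scales to zero between requests.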
#

app = 'mirrorllm'
primary_region = 'lax'

[build]

[http_service]
  internal_port = 8080
  force_https = true
  auto_stop_machines = 'stop'
  auto_start_machines = true
  min_machines_running = 0
  processes = ['app']

[[vm]]
  memory = '1gb'
  cpu_kind = 'shared'
  cpus = 1
--------------------------------------------------------------------------------
/.github/workflows/fly-deploy.yml:
--------------------------------------------------------------------------------
# See https://fly.io/docs/app-guides/continuous-deployment-with-github-actions/

name: Fly Deploy
on:
  push:
    branches:
      - main
jobs:
  deploy:
    name: Deploy app
    runs-on: ubuntu-latest
    concurrency: deploy-group # optional: ensure only one action runs at a time
    steps:
      - uses: actions/checkout@v4
      - uses: superfly/flyctl-actions/setup-flyctl@master
      - run: flyctl deploy --remote-only
        env:
          FLY_API_TOKEN: ${{ secrets.FLY_API_TOKEN }}
--------------------------------------------------------------------------------
/mcts/util.py:
--------------------------------------------------------------------------------

import re

SEED = 42

def split_and_clean_steps(text):
    # Use regex to split the text into steps
    steps = re.split(r'(?=##\s*Step\s+\d+:)', text)

    # Remove any leading/trailing whitespace, empty steps, and the "## Step n:" prefix
    cleaned_steps = []
    for step in steps:
        # Strip whitespace and check if step is not empty
        step = step.strip()
        if step:
            # Remove the "## Step n:" prefix
            step = re.sub(r'^##\s*Step\s+\d+:\s*', '', step)
            cleaned_steps.append(step)

    return cleaned_steps

def quality_filter(example):
    response_quality = example['score'] >= 0.32 # arbitrary threshold
    # TODO: check correctness of chain
    # math_and_reasoning = example['primary_tag'] in ['Math', 'Reasoning']
    instruction_quality = example['quality'] in ['excellent', 'good']
    response_format = "## Step 1: " in example['response']
    return response_quality and instruction_quality and response_format
--------------------------------------------------------------------------------
/modal_vespa.py:
--------------------------------------------------------------------------------
import modal


# vespa_image = modal.Image.from_registry("vespaengine/vespa", add_python="3.11")
vespa_image = modal.Image.from_dockerfile("Dockerfile", add_python="3.11")
app = modal.App("dankvespa", image=vespa_image)

@app.cls(
    enable_memory_snapshot=True,
    # volumes={"/my_vol": modal.Volume.from_name("my-test-volume")}
)
class Vespa:
    @modal.build()
    def build(self):
        # cache
        print("build")

    @modal.enter(snap=True)
    def load(self):
        # Create a memory snapshot with the model loaded in CPU memory.
        print("save state")

    @modal.enter(snap=False)
    def setup(self):
        # Post-snapshot setup: move any state to the GPU before doing work.
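        # With enable_memory_snapshot=True, @modal.enter(snap=True) runs once
        # before the CPU memory snapshot is taken, while @modal.enter(snap=False)
        # runs on every container start after the snapshot is restored.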
27 | print("loaded from snapshot") 28 | 29 | @modal.method() 30 | def search(self, query: str): 31 | print("search") 32 | 33 | 34 | @app.local_entrypoint() 35 | async def main(): 36 | # score the messages 37 | m1 = await Vespa().search.remote.aio("test") 38 | print(m1) -------------------------------------------------------------------------------- /lib.py: -------------------------------------------------------------------------------- 1 | def test(): 2 | return "ASDD" 3 | 4 | from typing import List, Dict 5 | import torch 6 | 7 | def extract_tensors(m: torch.nn.Module) -> Tuple[torch.nn.Module, List[Dict]]: 8 | """ 9 | Remove the tensors from a PyTorch model, convert them to NumPy 10 | arrays, and return the stripped model and tensors. 11 | """ 12 | tensors = [] 13 | for _, module in m.named_modules(): 14 | # Store the tensors in Python dictionaries 15 | params = { 16 | name: torch.clone(param).detach().numpy() 17 | for name, param in module.named_parameters(recurse=False) 18 | } 19 | buffers = { 20 | name: torch.clone(buf).detach().numpy() 21 | for name, buf in module.named_buffers(recurse=False) 22 | } 23 | tensors.append({"params": params, "buffers": buffers}) 24 | 25 | # Make a copy of the original model and strip all tensors and 26 | # temporary buffers out of the copy. 27 | m_copy = copy.deepcopy(m) 28 | for _, module in m_copy.named_modules(): 29 | for name in ( 30 | [name for name, _ in module.named_parameters(recurse=False)] 31 | + [name for name, _ in module.named_buffers(recurse=False)]): 32 | setattr(module, name, None) 33 | 34 | # Make sure the copy is configured for inference. 35 | m_copy.train(False) 36 | return m_copy, tensors -------------------------------------------------------------------------------- /mcts/test.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | def split_and_clean_steps(text): 4 | # Use regex to split the text into steps 5 | steps = re.split(r'(?=##\s*Step\s+\d+:)', text) 6 | 7 | # Remove any leading/trailing whitespace, empty steps, and the "## Step n:" prefix 8 | cleaned_steps = [] 9 | for step in steps: 10 | # Strip whitespace and check if step is not empty 11 | step = step.strip() 12 | if step: 13 | # Remove the "## Step n:" prefix 14 | step = re.sub(r'^##\s*Step\s+\d+:\s*', '', step) 15 | cleaned_steps.append(step) 16 | 17 | return cleaned_steps 18 | 19 | # Example usage 20 | text1 = """## Step 1: First step 21 | Content of first step. 22 | ## Step 2: Second step 23 | Content of second step. 24 | ## Step 10: Tenth step 25 | Content of tenth step. 26 | ## Step 11: Eleventh step 27 | Content of eleventh step. 28 | sdfsdfsdfsdf 29 | 30 | 31 | 32 | sdfsdfsd 33 | 34 | step ## Step 12: Test""" 35 | 36 | text2 = """## Step 1: Short step 37 | Brief content. 38 | ## Step 99: Large step number 39 | Content of step 99. 
## Step 100: Three-digit step
Content of step 100."""

# Test with both examples
for i, text in enumerate([text1, text2], 1):
    # print(f"Test case {i}:")
    result = split_and_clean_steps(text)
    for j, step in enumerate(result, 1):
        print(f"Step {j}:")
        print(step)
        print()
    print("---\n")
--------------------------------------------------------------------------------
/modal_train_policy_sft.py:
--------------------------------------------------------------------------------
import modal

cuda_version = "12.4.0"  # should be no greater than host CUDA version
flavor = "devel"  # includes full CUDA toolkit
operating_sys = "ubuntu22.04"
tag = f"{cuda_version}-{flavor}-{operating_sys}"

image = (
    # modal.Image.debian_slim()
    modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.11")
    .apt_install("git")
    .pip_install("torch")
    .pip_install("packaging")
    .pip_install("wheel")
    .run_commands("pip install flash-attn --no-build-isolation")
    .pip_install("transformers")
    .pip_install("accelerate")
    .pip_install("numpy")
    .pip_install("datasets")
    .pip_install("wandb")
    .pip_install("bitsandbytes")
    .pip_install("unsloth")
)
app = modal.App("train_policy_sft", image=image)

with image.imports():
    from mcts.train_policy_sft import train_sft

MINUTES = 60  # seconds
HOURS = 60 * MINUTES

@app.function(
    cpu=2.0,
    gpu=modal.gpu.A10G(),
    # gpu=modal.gpu.H100(),
    # gpu=modal.gpu.A100(size="40GB"),
    timeout=20 * HOURS,
    secrets=[
        modal.Secret.from_name("hf-token"),
        modal.Secret.from_name("wandb-token")
    ]
)
def train_policy_model_sft_upload_to_hf():
    train_sft()

@app.local_entrypoint()
def main():
    # run the function remotely on Modal
    train_policy_model_sft_upload_to_hf.remote()
--------------------------------------------------------------------------------
/modal_train_policy_sft_metamath.py:
--------------------------------------------------------------------------------
import modal

cuda_version = "12.4.0"  # should be no greater than host CUDA version
flavor = "devel"  # includes full CUDA toolkit
operating_sys = "ubuntu22.04"
tag = f"{cuda_version}-{flavor}-{operating_sys}"

image = (
    # modal.Image.debian_slim()
    modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.11")
    .apt_install("git")
    .pip_install("torch")
    .pip_install("packaging")
    .pip_install("wheel")
    .run_commands("pip install flash-attn --no-build-isolation")
    .pip_install("transformers")
    .pip_install("accelerate")
    .pip_install("numpy")
    .pip_install("datasets")
    .pip_install("wandb")
    .pip_install("bitsandbytes")
    .pip_install("unsloth @ git+https://github.com/unslothai/unsloth.git")
    .pip_install("unsloth_zoo")
    .pip_install("xformers")
)
app = modal.App("train_policy_sft", image=image)

with image.imports():
    from mcts.train_policy_sft_metamath import train_sft

MINUTES = 60  # seconds
HOURS = 60 * MINUTES

@app.function(
    cpu=2.0,
    # gpu=modal.gpu.A10G(),
    gpu=modal.gpu.H100(),
    # gpu=modal.gpu.A100(size="40GB"),
    timeout=20 * HOURS,
    secrets=[
        modal.Secret.from_name("hf-token"),
        modal.Secret.from_name("wandb-token")
    ]
)
def train_policy_model_sft_upload_to_hf():
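    """Run SFT for the MetaMath policy model on an H100 and push the result to the Hugging Face Hub."""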
    train_sft()

@app.local_entrypoint()
def main():
    # run the function remotely on Modal
    train_policy_model_sft_upload_to_hf.remote()
--------------------------------------------------------------------------------
/modal_train_policy_sft_qwq.py:
--------------------------------------------------------------------------------
import modal

cuda_version = "12.4.0"  # should be no greater than host CUDA version
flavor = "devel"  # includes full CUDA toolkit
operating_sys = "ubuntu22.04"
tag = f"{cuda_version}-{flavor}-{operating_sys}"

image = (
    # modal.Image.debian_slim()
    modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.11")
    .apt_install("git")
    .pip_install("torch")
    .pip_install("packaging")
    .pip_install("wheel")
    .run_commands("pip install flash-attn --no-build-isolation")
    .pip_install("transformers")
    .pip_install("accelerate")
    .pip_install("numpy")
    .pip_install("datasets")
    .pip_install("wandb")
    .pip_install("bitsandbytes")
    .pip_install("unsloth @ git+https://github.com/unslothai/unsloth.git")
    .pip_install("unsloth_zoo")
    .pip_install("xformers")
    .env({"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True"})
)
app = modal.App("train_policy_sft", image=image)

with image.imports():
    from mcts.train_policy_sft_qwq import train_sft

MINUTES = 60  # seconds
HOURS = 60 * MINUTES

@app.function(
    cpu=2.0,
    # gpu=modal.gpu.A10G(),
    gpu=modal.gpu.H100(),
    # gpu=modal.gpu.A100(size="40GB"),
    timeout=20 * HOURS,
    secrets=[
        modal.Secret.from_name("hf-token"),
        modal.Secret.from_name("wandb-token")
    ]
)
def train_policy_model_sft_upload_to_hf():
    train_sft()

@app.local_entrypoint()
def main():
    # run the function remotely on Modal
    train_policy_model_sft_upload_to_hf.remote()
--------------------------------------------------------------------------------
/modal_train_policy_sft_prime.py:
--------------------------------------------------------------------------------
import modal

cuda_version = "12.4.0"  # should be no greater than host CUDA version
flavor = "devel"  # includes full CUDA toolkit
operating_sys = "ubuntu22.04"
tag = f"{cuda_version}-{flavor}-{operating_sys}"

image = (
    # modal.Image.debian_slim()
    modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.11")
    .apt_install("git")
    .pip_install("torch")
    .pip_install("packaging")
    .pip_install("wheel")
    .run_commands("pip install flash-attn --no-build-isolation")
    .pip_install("transformers")
    .pip_install("accelerate")
    .pip_install("numpy")
    .pip_install("datasets")
    .pip_install("wandb")
    .pip_install("bitsandbytes")
    .pip_install("unsloth @ git+https://github.com/unslothai/unsloth.git")
    .pip_install("unsloth_zoo")
    .pip_install("xformers")
    .env({"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True"})
)
app = modal.App("train_policy_sft", image=image)

with image.imports():
    from mcts.train_policy_sft_prime import train_sft

MINUTES = 60  # seconds
HOURS = 60 * MINUTES

@app.function(
    cpu=2.0,
    # gpu=modal.gpu.A10G(),
    gpu=modal.gpu.H100(),
    # gpu=modal.gpu.A100(size="40GB"),
    timeout=20 * HOURS,
    secrets=[
        modal.Secret.from_name("hf-token"),
        modal.Secret.from_name("wandb-token")
    ]
)
def train_policy_model_sft_upload_to_hf():
    train_sft()

@app.local_entrypoint()
def main():
    # run the function remotely on Modal
    train_policy_model_sft_upload_to_hf.remote()
--------------------------------------------------------------------------------
/tweak_prime_rl_data.py:
--------------------------------------------------------------------------------
from datasets import load_dataset
import json
from huggingface_hub import login

# Verification-only short-circuit: inspect one sample of the already-uploaded
# dataset and exit before the processing pipeline below runs.
dataset = load_dataset("rawsh/Eurus-2-RL-Data-ProblemsOnly", split="validation")
sample = dataset[100]
print("\nSample processed chat trace:")
print(sample['prompt'][0]["content"])
import sys
sys.exit()

# Load the dataset
dataset = load_dataset("PRIME-RL/Eurus-2-RL-Data")

def process_chat_trace(example):
    """Process a single example by removing the system prompt from the chat trace."""
    try:
        # Parse the chat trace from string to list
        chat_trace = example['prompt']

        # Remove the first message (system prompt)
        modified_trace = chat_trace[1:]

        # Convert back to string
        example['prompt'] = modified_trace

        return example
    except (json.JSONDecodeError, IndexError, KeyError) as e:
        print(f"Error processing example: {e}")
        return example

# Process each split in the dataset
processed_dataset = {}
for split in dataset:
    processed_dataset[split] = dataset[split].map(process_chat_trace)

# You'll need to run this first and enter your token when prompted
login()

# Push to HF Hub
for split, data in processed_dataset.items():
    data.push_to_hub(
        "rawsh/Eurus-2-RL-Data-ProblemsOnly",
        split=split,
        private=False  # Set to True if you want a private repository
    )

# Print sample to verify
sample = processed_dataset['train'][0] if 'train' in processed_dataset else next(iter(processed_dataset.values()))[0]
print("\nSample processed chat trace:")
print(sample['prompt'])
--------------------------------------------------------------------------------
/modal_train_prm_rlhf_flow.py:
--------------------------------------------------------------------------------
# train.py
import modal
import yaml
import os
from pathlib import Path

# Axolotl image setup
AXOLOTL_REGISTRY_SHA = "9578c47333bdcc9ad7318e54506b9adaf283161092ae780353d506f7a656590a"
image = (
    modal.Image.from_registry(f"winglian/axolotl@sha256:{AXOLOTL_REGISTRY_SHA}")
    .pip_install(
        "huggingface_hub==0.23.2",
        "hf-transfer==0.1.5",
        "wandb==0.16.3",
        "fastapi==0.110.0",
        "pydantic==2.6.3",
    )
    .env(
        dict(
            HUGGINGFACE_HUB_CACHE="/pretrained",
            HF_HUB_ENABLE_HF_TRANSFER="1",
            AXOLOTL_NCCL_TIMEOUT="60",
        )
    )
    .entrypoint([])
)

app = modal.App("train-hf", image=image)

# Constants
MINUTES = 60
HOURS = 60 * MINUTES

# Create volume for persistent storage
training_vol = modal.Volume.from_name("training-data", create_if_missing=True)

@app.function(
    cpu=8,
    gpu=modal.gpu.H100(),
    timeout=20 * HOURS,
    volumes={"/training": training_vol},
    secrets=[
        modal.Secret.from_name("hf-token"),
        modal.Secret.from_name("wandb-token")
    ],
)
def run_training(config):
    import subprocess

    # Write the config to the container
    config_path = Path("/training/config.yml")
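    # Materialize the config on the shared /training volume so the axolotl CLI
    # invocation below can read it from inside the same container.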
    with open(config_path, 'w') as f:
        yaml.dump(config, f)

    # Run training - Axolotl will handle HF upload if push_to_hub is True
    subprocess.run(["python", "-m", "axolotl.cli.train", config_path])

@app.local_entrypoint()
def main():
    # Read the local config file
    with open("prm_rlhf_flow/qwen.yml", 'r') as f:
        config = yaml.safe_load(f)

    # Run the training
    run_training.remote(config)

if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/modal_train_prm_init.py:
--------------------------------------------------------------------------------
import modal

cuda_version = "12.4.0"  # should be no greater than host CUDA version
flavor = "devel"  # includes full CUDA toolkit
operating_sys = "ubuntu22.04"
tag = f"{cuda_version}-{flavor}-{operating_sys}"

image = (
    # modal.Image.debian_slim()
    modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.11")
    .apt_install("git")
    .pip_install("torch")
    .pip_install("packaging")
    .pip_install("wheel")
    .run_commands("pip install flash-attn --no-build-isolation")
    .pip_install("transformers")
    .pip_install("accelerate")
    .pip_install("numpy")
    .pip_install("datasets")
    .pip_install("wandb")
    .pip_install("bitsandbytes")
    .pip_install("matplotlib")
    .pip_install("seaborn")
)
app = modal.App("train_prm", image=image)

with image.imports():
    from mcts.train_reward import train_reward_model

MINUTES = 60  # seconds
HOURS = 60 * MINUTES

vol = modal.Volume.from_name("prm-tmp", create_if_missing=True)

@app.function(
    cpu=2.0,
    # gpu=modal.gpu.A10G(),
    gpu=modal.gpu.H100(),
    # gpu=modal.gpu.A100(count=4, size="40GB"),
    # gpu=modal.gpu.A100(size="40GB"),
    timeout=20 * HOURS,
    secrets=[
        modal.Secret.from_name("hf-token"),
        modal.Secret.from_name("wandb-token")
    ],
    volumes={"/out": vol},
)
def train_reward_model_upload_to_hf():
    train_reward_model(
        # add revision
        model_name="Qwen/Qwen2.5-0.5B",
        dataset_path="rawsh/magpie-ultra-v0.1-PRM-data-base",
        output_model_name="rawsh/mirrorqwen2.5-0.5b-prm",
        disable_binning=False
    )

@app.local_entrypoint()
def main():
    # run the function remotely on Modal
    train_reward_model_upload_to_hf.remote()
--------------------------------------------------------------------------------
/prm_rlhf_flow/qwen.yml:
--------------------------------------------------------------------------------
# config.yml
base_model: Qwen/Qwen2.5-0.5B
# base_model: rawsh/MetaMath-Qwen2.5-0.5b
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer

# HuggingFace settings
push_to_hub: true # Enable direct upload to HF
hub_model_id: "rawsh/MetaMath-Qwen2.5-0.5b-PRM" # Target repo name
hub_strategy: "every_save" # or "end", "checkpoint", "all_checkpoints"

# Model loading settings
load_in_8bit: false
load_in_4bit: false
strict: false

# # Dataset configuration
# chat_template: llama3
# datasets:
#   - path: RLHFlow/Mistral-PRM-Data
#     type: chat_template
#     split: "train"
#     train_on_split: "train"
#     field_messages: conversations
#     message_field_role: role
#     message_field_content: content


datasets:
  - path: RLHFlow/Mistral-PRM-Data
    conversation: llama3
    type: sharegpt
    split: "train"
    train_on_split: "train"

# Training settings
warmup_ratio: 0.05
val_set_size: 0.0
output_dir: /training/prm
train_on_inputs: false

# Weights & Biases settings
wandb_project: "preference-models"
wandb_name: "qwen2.5-0.5b-bs32_lr2e-6_prm"
# wandb_watch: false
# wandb_log_model: false

# Model saving settings
save_safetensors: true
dataset_prepared_path: /training/data/prepared

# Training hyperparameters
sequence_len: 8192
sample_packing: true
pad_to_sequence_len: true
trust_remote_code: true
gradient_checkpointing: true
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: paged_adamw_32bit
lr_scheduler: cosine
learning_rate: 2.0e-6
weight_decay: 0.0
max_grad_norm: 1.0

# Hardware settings
bf16: auto
fp16: false
tf32: true
flash_attention: true

# Logging and checkpointing
logging_steps: 2
save_strategy: "epoch"
save_total_limit: 4

# Special tokens
special_tokens:
  pad_token: <|endoftext|>
--------------------------------------------------------------------------------
/modal_train_prm_st.py:
--------------------------------------------------------------------------------
import modal

cuda_version = "12.4.0"  # should be no greater than host CUDA version
flavor = "devel"  # includes full CUDA toolkit
operating_sys = "ubuntu22.04"
tag = f"{cuda_version}-{flavor}-{operating_sys}"

image = (
    # modal.Image.debian_slim()
    modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.11")
    .apt_install("git")
    .pip_install("torch")
    .pip_install("packaging")
    .pip_install("wheel")
    .run_commands("pip install flash-attn --no-build-isolation")
    .pip_install("transformers")
    .pip_install("accelerate")
    .pip_install("numpy")
    .pip_install("datasets")
    .pip_install("wandb")
    .pip_install("bitsandbytes")
    .pip_install("matplotlib")
    .pip_install("seaborn")
)
app = modal.App("train_prm", image=image)

with image.imports():
    from mcts.train_reward import train_reward_model

MINUTES = 60  # seconds
HOURS = 60 * MINUTES

vol = modal.Volume.from_name("prm-tmp", create_if_missing=True)

@app.function(
    cpu=2.0,
    # gpu=modal.gpu.A10G(),
    gpu=modal.gpu.H100(),
    # gpu=modal.gpu.A100(count=4, size="40GB"),
    # gpu=modal.gpu.A100(size="40GB"),
    timeout=20 * HOURS,
    secrets=[
        modal.Secret.from_name("hf-token"),
        modal.Secret.from_name("wandb-token")
    ],
    volumes={"/out": vol},
)
def train_reward_model_upload_to_hf():
    train_reward_model(
        # add revision
        model_name="rawsh/mirrorqwen2.5-0.5b-prm",
        # model_revision="aed1bcf7d3d984272e329c3843f9c5fd0dfe5ca5", # base
        # model_revision="42e07d1b708282ac2aae338050d8116f8c69398d", # st0
        # model_revision="80da7ccc4f107e0cb6bf937d61be4702badfb96b", # st1
        # model_revision="4d618515c90069993f4b32e4201783efdeebbc22", # st2
        # NOTE: the orpo2 PRM was trained from the st0 base model by mistake.
        model_revision="e49e4ca7c847194be48c42c52ad8f871da204300", # orpo2
        dataset_path="rawsh/mirrorqwen2.5-0.5B-gsm8k-PRM-data-ORPO-2",
        output_model_name="rawsh/mirrorqwen2.5-0.5b-prm",
        disable_binning=False
    )

@app.local_entrypoint()
def main():
    # run the function remotely on Modal
    train_reward_model_upload_to_hf.remote()
--------------------------------------------------------------------------------
/modal_train_policy_simpo.py:
--------------------------------------------------------------------------------
import modal
import sys
import traceback

# Define CUDA specifications
cuda_version = "12.4.0"
flavor = "devel"
operating_sys = "ubuntu22.04"
tag = f"{cuda_version}-{flavor}-{operating_sys}"

# Create Modal image with all necessary dependencies
image = (
    modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.11")
    .apt_install("git")
    .pip_install("torch")
    .pip_install("transformers")
    .pip_install("accelerate")
    .pip_install("datasets")
    .pip_install("wandb")
    .pip_install("trl>=0.7.6")
    .pip_install("huggingface_hub")
    .pip_install("bitsandbytes")
)

with image.imports():
    from mcts.train_policy_simpo import train_simpo  # Import from our new simplified script

# Create Modal app
app = modal.App("train-policy-simpo", image=image)

@app.function(
    cpu=4.0,
    gpu=modal.gpu.H100(count=1),
    timeout=24 * 60 * 60,
    memory=32768,
    secrets=[
        modal.Secret.from_name("hf-token"),
        modal.Secret.from_name("wandb-token")
    ],
)
def train_policy_simpo():
    import os
    from huggingface_hub import HfFolder
    import wandb

    try:
        # Set up HuggingFace token
        hf_token = os.environ["HF_TOKEN"]
        HfFolder.save_token(hf_token)

        # Set up Weights & Biases
        wandb.login(key=os.environ["WANDB_API_KEY"])

        # Run training with specified parameters
        train_simpo(
            # model_name="rawsh/mirrorqwen2.5-0.5b-SFT",
            # model_name="rawsh/mirrorqwen2.5-0.5b-SimPO-0",
            # model_name="rawsh/mirrorqwen2.5-0.5b-SimPO-1",
            model_name="rawsh/mirrorqwen2.5-0.5b-SimPO-2",
            dataset_name="rawsh/mirrorqwen2.5-0.5B-gsm8k-policy-data-ST-3",
            output_model_name="rawsh/mirrorqwen2.5-0.5b-SimPO-3",
            hub_token=hf_token
        )
    except Exception as e:
        print(f"Error during training: {str(e)}", file=sys.stderr)
        print("Traceback:", file=sys.stderr)
        traceback.print_exc(file=sys.stderr)
        # Make sure to finish wandb run even on error
        try:
            wandb.finish()
        except:
            pass
        raise e

@app.local_entrypoint()
def main():
    print("Starting full model SimPO training on Modal...")
    try:
        train_policy_simpo.remote()
        print("Training job submitted to Modal. Check W&B dashboard for training progress.")
Check W&B dashboard for training progress.") 81 | except Exception as e: 82 | print(f"Error in training job: {str(e)}") 83 | sys.exit(1) -------------------------------------------------------------------------------- /modal_prm_armorm.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict 2 | import modal 3 | 4 | image = modal.Image.debian_slim().pip_install([ 5 | "torch", "transformers", "accelerate", "batched", "hf_transfer" 6 | ]).env({"HF_HUB_ENABLE_HF_TRANSFER": "1"}) 7 | 8 | app = modal.App("reward-api", image=image) 9 | 10 | MODEL_NAME = "RLHFlow/ArmoRM-Llama3-8B-v0.1" 11 | 12 | with image.imports(): 13 | import torch 14 | from transformers import AutoModelForSequenceClassification, AutoTokenizer 15 | from batched import inference 16 | 17 | def validate_messages(messages: List[Dict[str, str]]): 18 | if not messages or len(messages) < 2: 19 | raise ValueError("Messages must contain at least a user and assistant message") 20 | if not all(isinstance(m, dict) and 'role' in m and 'content' in m for m in messages): 21 | raise ValueError("Each message must have 'role' and 'content' fields") 22 | 23 | class RewardModelHelper: 24 | def __init__(self, model): 25 | self.model = model 26 | 27 | @inference.dynamically(batch_size=64, timeout_ms=20.0) 28 | def score_batch(self, features: dict[str, torch.Tensor]) -> torch.Tensor: 29 | with torch.no_grad(): 30 | # Move input to same device as model 31 | inputs = {k: v.to(self.model.device) for k, v in features.items()} 32 | return self.model(inputs["input_ids"]).score.float() 33 | 34 | @app.cls( 35 | gpu=modal.gpu.A10G(), 36 | allow_concurrent_inputs=1000, 37 | container_idle_timeout=300, 38 | ) 39 | class Model: 40 | def load_model(self): 41 | model = AutoModelForSequenceClassification.from_pretrained( 42 | MODEL_NAME, 43 | device_map="cuda", 44 | trust_remote_code=True, 45 | torch_dtype=torch.bfloat16, 46 | use_safetensors=True, 47 | ) 48 | return model 49 | 50 | @modal.build() 51 | def build(self): 52 | self.load_model() 53 | 54 | @modal.enter() 55 | def setup(self): 56 | self.model = self.load_model() 57 | self.tokenizer = AutoTokenizer.from_pretrained( 58 | MODEL_NAME, 59 | use_fast=True, 60 | ) 61 | self.score_batch = RewardModelHelper(self.model).score_batch 62 | 63 | @modal.web_endpoint(method="POST") 64 | async def score(self, messages_dict: Dict[str, List[Dict[str, str]]]): 65 | messages = messages_dict["messages"] 66 | validate_messages(messages) 67 | inputs = self.tokenizer.apply_chat_template( 68 | messages, 69 | return_tensors="pt", 70 | padding=True, 71 | truncation=True, 72 | tokenize=True 73 | ) 74 | score = await self.score_batch.acall({"input_ids": inputs}) 75 | return {"score": score[0].item()} -------------------------------------------------------------------------------- /modal_train_policy_orpo.py: -------------------------------------------------------------------------------- 1 | import modal 2 | import sys 3 | import traceback 4 | 5 | # Define CUDA specifications 6 | cuda_version = "12.4.0" 7 | flavor = "devel" 8 | operating_sys = "ubuntu22.04" 9 | tag = f"{cuda_version}-{flavor}-{operating_sys}" 10 | 11 | # Create Modal image with all necessary dependencies 12 | image = ( 13 | modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.11") 14 | .apt_install("git") 15 | .pip_install("torch") 16 | .pip_install("transformers") 17 | .pip_install("accelerate") 18 | .pip_install("datasets") 19 | .pip_install("wandb") 20 | .pip_install("trl>=0.7.6") 21 | 
.pip_install("huggingface_hub") 22 | .pip_install("bitsandbytes") 23 | ) 24 | 25 | with image.imports(): 26 | from mcts.train_policy_orpo import train_orpo # Import from our new simplified script 27 | 28 | # Create Modal app 29 | app = modal.App("train-policy-orpo", image=image) 30 | 31 | @app.function( 32 | cpu=4.0, 33 | gpu=modal.gpu.H100(count=1), 34 | timeout=24 * 60 * 60, 35 | # memory=32768, 36 | secrets=[ 37 | modal.Secret.from_name("hf-token"), 38 | modal.Secret.from_name("wandb-token") 39 | ], 40 | ) 41 | def train_policy_orpo(): 42 | import os 43 | from huggingface_hub import HfFolder 44 | import wandb 45 | 46 | try: 47 | # Set up HuggingFace token 48 | hf_token = os.environ["HF_TOKEN"] 49 | HfFolder.save_token(hf_token) 50 | 51 | # Set up Weights & Biases 52 | wandb.login(key=os.environ["WANDB_API_KEY"]) 53 | 54 | # Run training with specified parameters 55 | train_orpo( 56 | # model_name="rawsh/mirrorqwen2.5-0.5b-SFT", 57 | # model_name="rawsh/mirrorqwen2.5-0.5b-ORPO-1", 58 | model_name="rawsh/mirrorqwen2.5-0.5b-ORPO-2", 59 | # model_name="rawsh/mirrorqwen2.5-0.5b-SimPO-0", 60 | # model_name="rawsh/mirrorqwen2.5-0.5b-SimPO-1", 61 | # model_name="rawsh/mirrorqwen2.5-0.5b-SimPO-2", 62 | # dataset_name="rawsh/mirrorqwen2.5-0.5B-gsm8k-policy-data-ST-3", 63 | # dataset_name="rawsh/mirrorqwen2.5-0.5B-gsm8k-policy-data-ST-0", 64 | dataset_name="rawsh/mirrorqwen2.5-0.5B-gsm8k-policy-data-ORPO-2", 65 | output_model_name="rawsh/mirrorqwen2.5-0.5b-ORPO-3", 66 | hub_token=hf_token 67 | ) 68 | except Exception as e: 69 | print(f"Error during training: {str(e)}", file=sys.stderr) 70 | print("Traceback:", file=sys.stderr) 71 | traceback.print_exc(file=sys.stderr) 72 | # Make sure to finish wandb run even on error 73 | try: 74 | wandb.finish() 75 | except: 76 | pass 77 | raise e 78 | 79 | @app.local_entrypoint() 80 | def main(): 81 | print("Starting full model ORPO training on Modal...") 82 | try: 83 | train_policy_orpo.remote() 84 | print("Training job submitted to Modal. 
Check W&B dashboard for training progress.") 85 | except Exception as e: 86 | print(f"Error in training job: {str(e)}") 87 | sys.exit(1) -------------------------------------------------------------------------------- /modal_tree_grpo.py: -------------------------------------------------------------------------------- 1 | import modal 2 | import sys 3 | import traceback 4 | 5 | # Define CUDA specifications 6 | cuda_version = "12.4.0" 7 | flavor = "devel" 8 | operating_sys = "ubuntu22.04" 9 | tag = f"{cuda_version}-{flavor}-{operating_sys}" 10 | 11 | # Create Modal image with all necessary dependencies 12 | image = ( 13 | modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.11") 14 | .apt_install("git") 15 | .pip_install("torch") 16 | .pip_install("transformers") 17 | .pip_install("accelerate") 18 | .pip_install("datasets") 19 | .pip_install("wandb") 20 | .pip_install("trl>=0.7.6") 21 | .pip_install("huggingface_hub") 22 | .pip_install("bitsandbytes") 23 | ) 24 | 25 | with image.imports(): 26 | from reinforcement_learning.tree_grpo import GRPOConfig, GRPOTrainer 27 | 28 | # Create Modal app 29 | app = modal.App("train-policy-tree-grpo", image=image) 30 | 31 | @app.function( 32 | cpu=4.0, 33 | # gpu=modal.gpu.H100(count=1), 34 | gpu=modal.gpu.A10G(), 35 | timeout=24 * 60 * 60, 36 | # memory=32768, 37 | secrets=[ 38 | modal.Secret.from_name("hf-token"), 39 | modal.Secret.from_name("wandb-token") 40 | ], 41 | ) 42 | def train_policy_grpo(): 43 | import os 44 | from huggingface_hub import HfFolder 45 | from datasets import load_dataset 46 | import wandb 47 | 48 | try: 49 | # Set up HuggingFace token 50 | hf_token = os.environ["HF_TOKEN"] 51 | HfFolder.save_token(hf_token) 52 | 53 | # Set up Weights & Biases 54 | wandb.login(key=os.environ["WANDB_API_KEY"]) 55 | 56 | # Configuration 57 | config = GRPOConfig( 58 | exp_name="math_improvement", 59 | reward_model_path="rawsh/MetaMath-Qwen2.5-0.5b-PRM", 60 | num_grpo_epochs=4, 61 | sampling_group_size=8, 62 | sampling_strategy="top_p", 63 | sampling_temperature=0.7, 64 | # learning_rate=1e-5, 65 | # num_train_epochs=3, 66 | # per_device_train_batch_size=4, 67 | # gradient_accumulation_steps=4, 68 | # output_dir="./grpo_math_model", 69 | # report_to=["wandb"] 70 | ) 71 | 72 | # Initialize wandb 73 | wandb.init( 74 | project="grpo_math", 75 | name=config.exp_name, 76 | config=vars(config) 77 | ) 78 | 79 | # Load dataset 80 | train_dataset = load_dataset("lighteval/MATH", "all", split="train") 81 | eval_dataset = load_dataset("lighteval/MATH", "all", split="test") 82 | 83 | # Create trainer 84 | trainer = GRPOTrainer.from_pretrained( 85 | config=config, 86 | pretrained_model_name_or_path="rawsh/MetaMath-Qwen2.5-0.5b", 87 | train_dataset=train_dataset, 88 | eval_dataset=eval_dataset 89 | ) 90 | 91 | # Train 92 | trainer.train() 93 | 94 | # Save final model 95 | trainer.save_model() 96 | 97 | # Close wandb 98 | wandb.finish() 99 | except Exception as e: 100 | print(f"Error during training: {str(e)}", file=sys.stderr) 101 | print("Traceback:", file=sys.stderr) 102 | traceback.print_exc(file=sys.stderr) 103 | # Make sure to finish wandb run even on error 104 | try: 105 | wandb.finish() 106 | except: 107 | pass 108 | raise e 109 | 110 | @app.local_entrypoint() 111 | def main(): 112 | print("Starting full model GRPO training on Modal...") 113 | try: 114 | train_policy_grpo.remote() 115 | print("Training job submitted to Modal. 
Check W&B dashboard for training progress.") 116 | except Exception as e: 117 | print(f"Error in training job: {str(e)}") 118 | sys.exit(1) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. 
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
--------------------------------------------------------------------------------
/.dockerignore:
--------------------------------------------------------------------------------
# flyctl launch added from .gitignore
# Byte-compiled / optimized / DLL files
**/__pycache__
**/*.py[cod]
**/*$py.class

# C extensions
**/*.so

# Distribution / packaging
**/.Python
**/build
**/develop-eggs
**/dist
**/downloads
**/eggs
**/.eggs
**/lib
**/lib64
**/parts
**/sdist
**/var
**/wheels
**/share/python-wheels
**/*.egg-info
**/.installed.cfg
**/*.egg
**/MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
**/*.manifest
**/*.spec

# Installer logs
**/pip-log.txt
**/pip-delete-this-directory.txt

# Unit test / coverage reports
**/htmlcov
**/.tox
**/.nox
**/.coverage
**/.coverage.*
**/.cache
**/nosetests.xml
**/coverage.xml
**/*.cover
**/*.py,cover
**/.hypothesis
**/.pytest_cache
**/cover

# Translations
**/*.mo
**/*.pot

# Django stuff:
**/*.log
**/local_settings.py
**/db.sqlite3
**/db.sqlite3-journal

# Flask stuff:
**/instance
**/.webassets-cache

# Scrapy stuff:
**/.scrapy

# Sphinx documentation
**/docs/_build

# PyBuilder
**/.pybuilder
**/target

# Jupyter Notebook
**/.ipynb_checkpoints

# IPython
**/profile_default
**/ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
**/.pdm.toml
**/.pdm-python
**/.pdm-build

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
**/__pypackages__

# Celery stuff
**/celerybeat-schedule
**/celerybeat.pid

# SageMath parsed files
**/*.sage.py

# Environments
**/.env
**/.venv
**/env
**/venv
**/ENV
**/env.bak
**/venv.bak

# Spyder project settings
**/.spyderproject
**/.spyproject

# Rope project settings
**/.ropeproject

# mkdocs documentation
site

# mypy
**/.mypy_cache
**/.dmypy.json
**/dmypy.json

# Pyre type checker
**/.pyre

# pytype static type analyzer
**/.pytype

# Cython debug symbols
**/cython_debug

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
fly.toml
--------------------------------------------------------------------------------
/mcts/train_policy_orpo.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass, field
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
from trl import ModelConfig, ORPOConfig, ORPOTrainer, get_peft_config
import wandb

@dataclass
class ScriptArguments:
    model_name: str = field(default="Qwen/Qwen2-0.5B-Instruct")
    dataset_name: str = field(default="rawsh/mirrorqwen2.5-0.5B-gsm8k-policy-data-ST-0")
    dataset_train_split: str = field(default="train")
    dataset_test_split: str = field(default="train")  # Using train as test since original doesn't have test split
    output_model_name: str = field(default=None)
    hub_token: str = field(default=None)
    use_peft: bool = field(default=False)

@dataclass
class ModelArguments(ModelConfig):
    model_name_or_path: str = field(default="Qwen/Qwen2-0.5B-Instruct")
    trust_remote_code: bool = field(default=True)

def train_orpo(
    model_name=None,
    dataset_name=None,
    output_model_name=None,
    hub_token=None
):
    # Initialize wandb
    wandb.init(project="orpo-training")

    # Initialize base arguments
    script_args = ScriptArguments()
    if model_name:
        script_args.model_name = model_name
    if dataset_name:
        script_args.dataset_name = dataset_name
    if output_model_name:
        script_args.output_model_name = output_model_name
    if hub_token:
        script_args.hub_token = hub_token

    # Set up model arguments
    model_args = ModelArguments(
        model_name_or_path=script_args.model_name
    )

    # Set up training configuration
    training_args = ORPOConfig(
        output_dir="orpo-math-model",
        num_train_epochs=1,
        per_device_train_batch_size=8,
        gradient_accumulation_steps=8,
        # learning_rate=5e-7,
        # learning_rate=8e-6,
        lr_scheduler_type="linear",
        beta=0.1,
        learning_rate=3e-6,
        # max_steps
        max_length=2048,
        max_prompt_length=1024,
        gradient_checkpointing=True,
        push_to_hub=True,
        hub_model_id=script_args.output_model_name,
        hub_strategy="end",
        report_to=["wandb"],
        bf16=True,
        tf32=True,
        optim="paged_adamw_32bit",
        max_grad_norm=1.0,
        warmup_ratio=0.1,
        # lr_scheduler_type="cosine",
        do_eval=True,
        evaluation_strategy="steps",
        eval_steps=10,
        remove_unused_columns=False,
        logging_steps=10,
        logging_first_step=True
    )

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=model_args.trust_remote_code
    )
    tokenizer.pad_token = tokenizer.eos_token

    # Load model
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=model_args.trust_remote_code,
        torch_dtype=torch.float16,
        device_map="auto"
    )
    model.config.use_cache = False

    # Load and process dataset
    dataset = load_dataset(script_args.dataset_name, token=script_args.hub_token)
    train_dataset = dataset["train"].map(
        lambda examples: {
            "prompt": examples["question"],
            "chosen": ["\n\n".join(ex["steps"]) for ex in examples["positive"]],
            "rejected": ["\n\n".join(ex["steps"]) for ex in examples["negative"]]
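            # ORPOTrainer expects prompt/chosen/rejected columns; each tree-search
            # trace's steps are joined into a single completion string here.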
examples["negative"]] 104 | }, 105 | batched=True, 106 | remove_columns=dataset["train"].column_names 107 | ) 108 | 109 | # Initialize trainer 110 | trainer = ORPOTrainer( 111 | model=model, 112 | args=training_args, 113 | train_dataset=train_dataset, 114 | eval_dataset=train_dataset, 115 | processing_class=tokenizer, 116 | peft_config=get_peft_config(model_args) if script_args.use_peft else None, 117 | ) 118 | 119 | # Train the model 120 | trainer.train() 121 | trainer.save_model() 122 | 123 | wandb.finish() 124 | 125 | if __name__ == "__main__": 126 | train_orpo() -------------------------------------------------------------------------------- /mcts/generate.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from openai import AsyncOpenAI 3 | import json 4 | from typing import List, Tuple 5 | from datasets import load_dataset 6 | from util import split_and_clean_steps, quality_filter, SEED 7 | from tqdm import tqdm 8 | 9 | client = AsyncOpenAI( 10 | api_key="9FF74944EED19865193F979942FB1",adfghk 11 | base_url="https://rawsh--vllm-smollm-serve.modal.run/v1" 12 | ) 13 | 14 | def format_thoughts(thoughts: List[str]) -> str: 15 | return "\n".join(f"## Step {i}:\n{thought}" for i, thought in enumerate(thoughts, 1)) 16 | 17 | # template = "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n\ 18 | # <|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n{assistant_partial}" 19 | 20 | template = "<|im_start|>system\nYou are a helpful AI assistant named SmolLM, trained by Hugging Face<|im_end|>\n\ 21 | <|im_start|>human\n{user}<|im_end|>\n<|im_start|>assistant\n{assistant_partial}" 22 | 23 | class ReasoningTrace: 24 | def __init__(self, question: str, previous_thoughts: List[str], next_step: int): 25 | self.question = question 26 | self.previous_thoughts = previous_thoughts 27 | self.next_step = next_step 28 | 29 | class ProcessedReasoningTrace: 30 | def __init__(self, question: str, thoughts: List[str]): 31 | self.question = question 32 | self.thoughts = thoughts 33 | 34 | async def generate_thought_batched(batch: List[ReasoningTrace]) -> List[ProcessedReasoningTrace]: 35 | prompts = [] 36 | for trace in batch: 37 | formatted_thoughts = format_thoughts(trace.previous_thoughts) 38 | prompt = template.format(user=trace.question, assistant_partial=f"{formatted_thoughts}\n## Step {trace.next_step}:\n") 39 | prompts.append(prompt) 40 | 41 | params = { 42 | # "model": "Qwen/Qwen2.5-0.5B-Instruct", 43 | "model": "HuggingFaceTB/SmolLM2-135M-Instruct", 44 | "prompt": prompts, 45 | "max_tokens": 200, 46 | # "temperature": 0.7, 47 | "temperature": 0.0, 48 | "stop": ["\n## Step"], 49 | "timeout": 600 50 | } 51 | 52 | try: 53 | response = await client.completions.create(**params) 54 | processed = [ 55 | ProcessedReasoningTrace( 56 | question=batch[i].question, 57 | thoughts=batch[i].previous_thoughts + [response.choices[i].text.strip()] 58 | ) for i in range(len(batch)) 59 | ] 60 | return processed 61 | except Exception as e: 62 | print(f"An error occurred: {str(e)}") 63 | return None 64 | 65 | async def format_thought_chain(question: str, chain: List[str]) -> List[ReasoningTrace]: 66 | return [ReasoningTrace(question, chain[:i], i+1) for i in range(0, len(chain))] 67 | 68 | async def process_batch(batch: List[ReasoningTrace], semaphore: asyncio.Semaphore) -> List[ProcessedReasoningTrace]: 69 | async with semaphore: 70 | return await generate_thought_batched(batch) 71 | 72 | async def 
async def process_all_thought_chains_batched(thought_chains: List[Tuple[str, List[str]]]) -> List[ProcessedReasoningTrace]:
    batch_size = 200
    all_traces = []

    for question, chain in thought_chains:
        all_traces.extend(await format_thought_chain(question, chain))

    results = []
    semaphore = asyncio.Semaphore(10)  # Limit to 10 concurrent batches
    tasks = []

    for i in range(0, len(all_traces), batch_size):
        batch = all_traces[i:i + batch_size]
        task = asyncio.create_task(process_batch(batch, semaphore))
        tasks.append(task)

    for task in tqdm(asyncio.as_completed(tasks), total=len(tasks), desc="Processing batches"):
        processed_batch = await task
        if processed_batch:
            results.extend(processed_batch)

    return results

async def main():
    ds = load_dataset("argilla/magpie-ultra-v0.1")
    filtered_ds = ds.filter(quality_filter)
    split_ds = filtered_ds['train'].train_test_split(test_size=0.1, seed=SEED)
    train_ds = split_ds['train']
    correct_traces = [(row["instruction"], split_and_clean_steps(row["response"])) for row in train_ds]

    # correct_traces = correct_traces[:1000]
    generated_thoughts = await process_all_thought_chains_batched(correct_traces)

    with open("out.jsonl", "w") as f:
        for chain in generated_thoughts:
            json.dump(chain.__dict__, f)
            f.write("\n")

    print(f"Results written to out.jsonl")

if __name__ == "__main__":
    asyncio.run(main())
--------------------------------------------------------------------------------
/modal_orm_reward.py:
--------------------------------------------------------------------------------
import modal

image = (
    modal.Image.debian_slim()
    .pip_install("torch")
    .pip_install("transformers")
    .pip_install("accelerate")
)
app = modal.App("mirrorgemma-prm", image=image)


with image.imports():
    from typing import List, Dict, Tuple
    import asyncio
    import torch
    from time import perf_counter as pc
    import copy
    from transformers import AutoModelForSequenceClassification, AutoTokenizer
    import os
    # from lib import extract_tensors, test
    # print(test())

@app.cls(
    gpu=modal.gpu.A10G(),
    container_idle_timeout=30,
    # volumes={"/data": modal.Volume.from_name("my-test-volume")}
)
class Embedder:
    model_id = "RLHFlow/ArmoRM-Llama3-8B-v0.1"
    # model_id = "rawsh/mirrorgemma-2-2b-prm-base"
    device = "cuda"

    @modal.build()
    def build(self):
        # cache
        print("build")
        dtype = torch.bfloat16
        with torch.device("cuda"):
            print("[build] loading model")
            start = pc()
            model = AutoModelForSequenceClassification.from_pretrained(self.model_id,
                trust_remote_code=True, torch_dtype=dtype, use_safetensors=True)
            elapsed = pc() - start
            print(f"[build] loading model took {elapsed} seconds")

            print("[build] compile model")
            start = pc()
            # model = torch.compile(model)
            torch.compile(model)
            elapsed = pc() - start
            print(f"[build] compile model took {elapsed} seconds")

            # print("[build] save model")
            # start = pc()
            # model.save_pretrained("/data/saved_model", safe_serialization=True)
            # elapsed = pc() - start
            # print(f"[build] saving model took {elapsed} seconds")

    # @modal.enter(snap=False)
    @modal.enter()
    def setup(self):
        # Move the model to a GPU before doing any work.
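        # Note: PYTORCH_CUDA_ALLOC_CONF only takes effect if set before the first
        # CUDA allocation, which is why it is applied at the top of setup().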
any work. 63 | print("setup") 64 | os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" 65 | 66 | # faster model loading 67 | dtype = torch.bfloat16 68 | with torch.device("cuda"): 69 | print("[setup] loading model") 70 | start = pc() 71 | self.model = AutoModelForSequenceClassification.from_pretrained(self.model_id, 72 | trust_remote_code=True, torch_dtype=dtype, use_safetensors=True) 73 | elapsed = pc() - start 74 | print(f"[setup] loading model took {elapsed} seconds") 75 | 76 | # print("[setup] compile model") 77 | # start = pc() 78 | # self.model = torch.compile(self.model) 79 | # elapsed = pc() - start 80 | # print(f"[setup] compile model took {elapsed} seconds") 81 | 82 | print("[setup] loading tokenizer") 83 | start = pc() 84 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, use_fast=True) 85 | elapsed = pc() - start 86 | print(f"[setup] loading tokenizer took {elapsed} seconds") 87 | 88 | @modal.web_endpoint(method="POST", docs=True) 89 | def score_output(self, messages: List[Dict]):  # OpenAI-style chat messages to score 90 | print("score_output") 91 | input_ids = self.tokenizer.apply_chat_template( 92 | messages, 93 | return_tensors="pt", 94 | padding=True, 95 | truncation=True, 96 | # max_length=4096, 97 | max_length=8192, 98 | ).to("cuda") 99 | with torch.no_grad(): 100 | output = self.model(input_ids) 101 | print(output) 102 | float_output = output.score.float() 103 | print("Score:", float_output.item()) 104 | return float_output.item() 105 | 106 | 107 | # @app.local_entrypoint() 108 | # async def main(): 109 | # # score the messages 110 | # prompt = 'What are some synonyms for the word "beautiful"?' 111 | # response1 = 'Nicely, Beautifully, Handsome, Stunning, Wonderful, Gorgeous, Pretty, Stunning, Elegant' 112 | # response2 = 'bad' 113 | # messages1 = [{"role": "user", "content": prompt}, {"role": "assistant", "content": response1}] 114 | # messages2 = [{"role": "user", "content": prompt}, {"role": "assistant", "content": response2}] 115 | # m1 = Embedder().score_output(messages1) 116 | # m2 = Embedder().score_output(messages2) 117 | # res = await asyncio.gather(*[m1,m2]) 118 | # print(response1, res[0]) 119 | # print(response2, res[1]) -------------------------------------------------------------------------------- /test_vllm_prm.py: -------------------------------------------------------------------------------- 1 | from openai import AsyncOpenAI 2 | import asyncio 3 | import math 4 | 5 | class RewardModelClient: 6 | def __init__(self): 7 | self.client = AsyncOpenAI( 8 | base_url="https://rawsh--vllm-qwen-prm-serve.modal.run/v1/", 9 | api_key="9FF74944EED19865193F979942FB1" 10 | ) 11 | self.model_name = "MetaMath-Qwen2.5-0.5b-PRM" 12 | 13 | async def get_token_probability(self, response) -> float: 14 | """Extract the probability of the '+' token from the response logprobs, normalized against '-'""" 15 | logprobs = response.choices[0].logprobs.content[0].top_logprobs 16 | 17 | # Print tokens and their probabilities for debugging 18 | token_probs = {lp.token: math.exp(lp.logprob) for lp in logprobs} 19 | print("Available tokens and probs:", token_probs) 20 | 21 | # Get raw probabilities, defaulting to very small number if token not found 22 | prob_plus = next((math.exp(lp.logprob) for lp in logprobs if lp.token == "+"), 1e-10) 23 | prob_minus = next((math.exp(lp.logprob) for lp in logprobs if lp.token == "-"), 1e-10) 24 | 25 | # Normalize between + and - 26 | return prob_plus / (prob_plus + prob_minus) if (prob_plus + prob_minus) > 0 else 0.5 27 | 28 | async def evaluate_steps(self, question: str, steps: list[str]) -> list[float]: 29 | """ 
30 | Evaluate each step in the solution getting probabilities of + vs - 31 | Returns probability of + for each step 32 | """ 33 | probabilities = [] 34 | 35 | # First evaluate question + first step 36 | messages = [ 37 | {"role": "user", "content": f"{question}\n{steps[0]}"} 38 | ] 39 | 40 | try: 41 | response = await self.client.chat.completions.create( 42 | model=self.model_name, 43 | messages=messages, 44 | max_tokens=1, 45 | temperature=0, 46 | logprobs=True, 47 | top_logprobs=20 48 | ) 49 | prob = await self.get_token_probability(response) 50 | probabilities.append(prob) 51 | 52 | except Exception as e: 53 | print(f"Error evaluating first step: {str(e)}") 54 | probabilities.append(0.5) 55 | 56 | # For remaining steps 57 | for i in range(1, len(steps)): 58 | try: 59 | # Build conversation including previous steps 60 | messages = [ 61 | {"role": "user", "content": f"{question}\n{steps[0]}"} 62 | ] 63 | 64 | for prev_step in steps[1:i]: 65 | messages.extend([ 66 | {"role": "assistant", "content": "+"}, 67 | {"role": "user", "content": prev_step} 68 | ]) 69 | 70 | messages.append({"role": "assistant", "content": "+"}) 71 | messages.append({"role": "user", "content": steps[i]}) 72 | 73 | response = await self.client.chat.completions.create( 74 | model=self.model_name, 75 | messages=messages, 76 | max_tokens=1, 77 | temperature=0, 78 | logprobs=True, 79 | top_logprobs=20 80 | ) 81 | 82 | prob = await self.get_token_probability(response) 83 | probabilities.append(prob) 84 | 85 | except Exception as e: 86 | print(f"Error evaluating step {i+1}: {str(e)}") 87 | probabilities.append(0.5) 88 | 89 | return probabilities 90 | 91 | async def main(): 92 | # Initialize client 93 | reward_model = RewardModelClient() 94 | 95 | # Example problem 96 | question = "Janet has 3 apples and buys 2 more. How many apples does she have?" 97 | steps = [ 98 | "Step 1: If Janet has 3 apples and buys 2 more, total apples = 3 + 2 = 5.", 99 | "Step 2: Therefore, Janet has 5 apples. 
The answer is: 5", 100 | ] 101 | 102 | try: 103 | # Get evaluations 104 | probabilities = await reward_model.evaluate_steps(question, steps) 105 | 106 | # Print results 107 | print("\nResults:") 108 | print("Question:", question) 109 | print("\nStep Evaluations:") 110 | for step, prob in zip(steps, probabilities): 111 | print(f"P(+) = {prob:.3f}: {step}") 112 | 113 | except Exception as e: 114 | print(f"Error occurred: {str(e)}") 115 | 116 | if __name__ == "__main__": 117 | asyncio.run(main()) -------------------------------------------------------------------------------- /mcts/train_policy_simpo.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | import torch 3 | from datasets import load_dataset 4 | from transformers import AutoModelForCausalLM, AutoTokenizer 5 | from trl import CPOConfig, CPOTrainer 6 | import wandb 7 | 8 | @dataclass 9 | class ScriptArguments: 10 | model_name: str = field(default="Qwen/Qwen2-0.5B-Instruct") 11 | dataset_name: str = field(default="rawsh/mirrorqwen2.5-0.5B-gsm8k-policy-data-ST-0") 12 | output_dir: str = field(default="simpo-math-model") 13 | warmup_ratio: float = field(default=0.1) # 10% warmup 14 | lr_scheduler_type: str = field(default="cosine") # Cosine decay 15 | max_grad_norm: float = field(default=1.0) 16 | output_model_name: str = field(default=None) 17 | hub_token: str = field(default=None) 18 | push_to_hub: bool = field(default=True) 19 | # learning_rate: float = field(default=3e-7) 20 | learning_rate: float = field(default=5e-7) 21 | batch_size: int = field(default=8) 22 | num_train_epochs: int = field(default=10) 23 | # max_steps: int = field(default=-1) 24 | # max_steps: int = field(default=10) 25 | gradient_accumulation_steps: int = field(default=8) 26 | beta: float = field(default=2.0) 27 | simpo_gamma: float = field(default=0.5) 28 | 29 | # class CustomCPOTrainer(CPOTrainer): 30 | # def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): 31 | # loss, outputs = super().compute_loss(model, inputs, return_outputs=True) 32 | # wandb.log({"loss": loss.item()}, step=self.state.step) 33 | # if return_outputs: 34 | # return loss, outputs 35 | # return loss 36 | 37 | def train_simpo( 38 | model_name=None, 39 | dataset_name=None, 40 | output_model_name=None, 41 | hub_token=None 42 | ): 43 | args = ScriptArguments() 44 | if model_name: 45 | args.model_name = model_name 46 | if dataset_name: 47 | args.dataset_name = dataset_name 48 | if output_model_name: 49 | args.output_model_name = output_model_name 50 | if hub_token: 51 | args.hub_token = hub_token 52 | 53 | wandb.init(project="simpo-training") 54 | 55 | tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True) 56 | tokenizer.pad_token = tokenizer.eos_token 57 | 58 | model = AutoModelForCausalLM.from_pretrained( 59 | args.model_name, 60 | trust_remote_code=True, 61 | torch_dtype=torch.float16, 62 | device_map="auto" 63 | ) 64 | model.config.use_cache = False 65 | 66 | dataset = load_dataset(args.dataset_name, token=args.hub_token) 67 | train_dataset = dataset["train"].map( 68 | lambda examples: { 69 | "prompt": examples["question"], 70 | "chosen": ["\n\n".join(ex["steps"]) for ex in examples["positive"]], 71 | "rejected": ["\n\n".join(ex["steps"]) for ex in examples["negative"]] 72 | }, 73 | batched=True, 74 | remove_columns=dataset["train"].column_names 75 | ) 76 | 77 | training_args = CPOConfig( 78 | output_dir=args.output_dir, 79 | 
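# note: loss_type="simpo" below selects the SimPO objective inside CPOTrainer; | # a non-zero cpo_alpha additionally mixes in CPO's NLL (behavior-cloning) term. | 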
num_train_epochs=args.num_train_epochs, 80 | per_device_train_batch_size=args.batch_size, 81 | gradient_accumulation_steps=args.gradient_accumulation_steps, 82 | learning_rate=args.learning_rate, 83 | # max_steps=args.max_steps, 84 | remove_unused_columns=False, 85 | loss_type="simpo", 86 | cpo_alpha=0.5, 87 | beta=args.beta, 88 | simpo_gamma=args.simpo_gamma, 89 | max_length=2048, 90 | max_prompt_length=1024, 91 | gradient_checkpointing=True, 92 | push_to_hub=args.push_to_hub, 93 | hub_model_id=args.output_model_name, 94 | hub_token=args.hub_token, 95 | hub_strategy="end", 96 | report_to=["wandb"], 97 | # Mixed precision settings 98 | bf16=True, # Use bfloat16 instead of fp16 99 | tf32=True, 100 | optim="paged_adamw_32bit", # Use 32-bit optimizer 101 | max_grad_norm=args.max_grad_norm, 102 | warmup_ratio=args.warmup_ratio, 103 | lr_scheduler_type=args.lr_scheduler_type, 104 | do_eval=True, 105 | evaluation_strategy="steps", 106 | eval_steps=20, 107 | ) 108 | 109 | trainer = CPOTrainer( 110 | model=model, 111 | args=training_args, 112 | train_dataset=train_dataset, 113 | eval_dataset=train_dataset, 114 | processing_class=tokenizer 115 | ) 116 | 117 | trainer.train() 118 | trainer.save_model() 119 | 120 | # if args.push_to_hub and args.output_model_name: 121 | # print("saving model") 122 | # trainer.push_to_hub(repo_id=args.output_model_name, commit_message="Final SimPO model") 123 | # tokenizer.push_to_hub(repo_id=args.output_model_name) 124 | 125 | wandb.finish() 126 | 127 | if __name__ == "__main__": 128 | train_simpo() -------------------------------------------------------------------------------- /modal_prm_reward.py: -------------------------------------------------------------------------------- 1 | import modal 2 | 3 | image = ( 4 | modal.Image.debian_slim() 5 | .pip_install([ 6 | "torch", 7 | "transformers", 8 | "accelerate", 9 | "batched", 10 | ]) 11 | ) 12 | # app = modal.App("mirrorqwen-prm", image=image) 13 | app = modal.App("mirrorqwen-prm-st", image=image) 14 | 15 | with image.imports(): 16 | from typing import List, Dict, Tuple 17 | import asyncio 18 | import torch 19 | from time import perf_counter as pc 20 | from transformers import pipeline 21 | import os 22 | 23 | class BatchProcessor: 24 | def __init__(self): 25 | import batched 26 | self.batched = batched 27 | 28 | def create_batch_processor(self, pipeline_func): 29 | @self.batched.dynamically(batch_size=256, timeout_ms=200.0, small_batch_threshold=4) 30 | def _process_batch(prompts: List[str]) -> List[Dict]: 31 | return pipeline_func(prompts) 32 | return _process_batch 33 | 34 | @app.cls( 35 | # gpu=modal.gpu.T4(), 36 | gpu=modal.gpu.A10G(), 37 | # gpu=modal.gpu.H100(), 38 | # gpu=modal.gpu.A100(), 39 | container_idle_timeout=120, 40 | # allow_concurrent_inputs=1000, 41 | allow_concurrent_inputs=1000, 42 | secrets=[ 43 | modal.Secret.from_name("hf-token"), 44 | ], 45 | ) 46 | class Embedder: 47 | model_id = "rawsh/mirrorqwen2.5-0.5b-prm" 48 | # revision = "894341fbd81d0c1abdd98b4e0630de932aa63c6f" # base 49 | # revision = "42e07d1b708282ac2aae338050d8116f8c69398d" # st0 50 | # revision = "65f4a7601dffacc40e0ef7fa4733d346c926bd18" # st1 v1 51 | # revision = "80da7ccc4f107e0cb6bf937d61be4702badfb96b" # st1 v2 52 | # revision = "4d618515c90069993f4b32e4201783efdeebbc22" # st2 53 | # revision = "b052380b619e5c62ce9f407522362f5caf7b8346" # st3 54 | # note: orpo 1 st for prm used strong/weak to generate samples. 
55 | # inference pair to gen data for orpo 2 was orpo 1 policy + st0 56 | # revision = "e49e4ca7c847194be48c42c52ad8f871da204300" # orpo2 57 | revision = "ecae5a74ef094d6e839dcb2a32500c36e6786ad1" # orpo3 58 | device = "cuda" 59 | print(model_id) 60 | 61 | @modal.build() 62 | def build(self): 63 | print("build") 64 | dtype = torch.bfloat16 65 | with torch.device("cuda"): 66 | print("[build] loading model") 67 | start = pc() 68 | classifier = pipeline("sentiment-analysis", model=self.model_id, revision=self.revision, 69 | trust_remote_code=True, torch_dtype=dtype, device="cuda") 70 | elapsed = pc() - start 71 | print(f"[build] loading model took {elapsed} seconds") 72 | 73 | @modal.enter() 74 | def setup(self): 75 | print("setup") 76 | dtype = torch.bfloat16 77 | with torch.device("cuda"): 78 | print("[setup] loading model") 79 | start = pc() 80 | self.pipeline = pipeline("sentiment-analysis", model=self.model_id, revision=self.revision, 81 | trust_remote_code=True, torch_dtype=dtype, device="cuda", batch_size=256) 82 | elapsed = pc() - start 83 | print(f"[setup] loading model took {elapsed} seconds") 84 | 85 | # Initialize batch processor 86 | batch_processor = BatchProcessor() 87 | self._process_batch = batch_processor.create_batch_processor(self.pipeline) 88 | 89 | @modal.web_endpoint(method="POST", docs=True) 90 | async def score_output(self, inp: dict): 91 | prompt = inp["prompt"] 92 | # Handle both single inputs and lists of inputs 93 | if isinstance(prompt, str): 94 | prompts = [prompt] 95 | else: 96 | prompts = prompt 97 | 98 | try: 99 | # Use the batched processing method 100 | results = await self._process_batch.acall(prompts) 101 | 102 | # Return single result if input was single, otherwise return list 103 | if isinstance(inp["prompt"], str): 104 | return results[0] 105 | return results 106 | except Exception as e: 107 | return {"error": str(e)} 108 | 109 | @app.local_entrypoint() 110 | async def main(): 111 | embedder = Embedder() 112 | 113 | # Test with multiple prompts 114 | prompt = 'What are some synonyms for the word "beautiful"?' 
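| # note: each score_output call returns the sentiment-analysis pipeline's output for its input, typically a dict like {"label": ..., "score": ...} 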
115 | response1 = 'Nicely, Beautifully, Handsome, Stunning, Wonderful, Gorgeous, Pretty, Stunning, Elegant' 116 | response2 = 'bad' 117 | 118 | # Create batch of requests 119 | inputs = [ 120 | {"prompt": response1}, 121 | {"prompt": response2} 122 | ] 123 | 124 | # Process in parallel 125 | results = await asyncio.gather(*[ 126 | embedder.score_output(inp) for inp in inputs 127 | ]) 128 | 129 | # Print results 130 | for response, result in zip([response1, response2], results): 131 | print(f"Response: {response}\nResult: {result}\n") 132 | 133 | # Print batching statistics 134 | print("Batching stats:", embedder._process_batch.stats) -------------------------------------------------------------------------------- /mcts/train_policy_sft.py: -------------------------------------------------------------------------------- 1 | from unsloth import FastLanguageModel 2 | import torch 3 | 4 | from trl import SFTTrainer 5 | from transformers import TrainingArguments 6 | from unsloth import is_bfloat16_supported 7 | from unsloth import UnslothTrainer, UnslothTrainingArguments 8 | 9 | from datasets import load_dataset 10 | 11 | 12 | # DUPLICATED CODE FOR MODAL 13 | # --------------------- 14 | import re 15 | SEED = 42 16 | 17 | def split_and_clean_steps(text): 18 | # Use regex to split the text into steps 19 | steps = re.split(r'(?=##\s*Step\s+\d+:)', text) 20 | 21 | # Remove any leading/trailing whitespace, empty steps, and the "## Step n:" prefix 22 | cleaned_steps = [] 23 | for step in steps: 24 | # Strip whitespace and check if step is not empty 25 | step = step.strip() 26 | if step: 27 | # Remove the "## Step n:" prefix 28 | step = re.sub(r'^##\s*Step\s+\d+:\s*', '', step) 29 | cleaned_steps.append(step) 30 | 31 | return cleaned_steps 32 | 33 | def quality_filter(example): 34 | response_quality = example['score'] >= 0.32 # arbitrary af 35 | # TODO: check correctness of chain 36 | # math_and_reasoning = example['primary_tag'] in ['Math', 'Reasoning'] 37 | instruction_quality = example['quality'] in ['excellent', 'good'] 38 | response_format = "## Step 1: " in example['response'] 39 | return response_quality and instruction_quality and response_format 40 | # --------------------- 41 | 42 | 43 | def train_sft(): 44 | max_seq_length = 8192 # Choose any! We auto support RoPE Scaling internally! 45 | dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+ 46 | load_in_4bit = False # Use 4bit quantization to reduce memory usage. Can be False. 47 | 48 | model, tokenizer = FastLanguageModel.from_pretrained( 49 | # model_name = "unsloth/gemma-2-2b", 50 | model_name = "Qwen/Qwen2.5-0.5B", 51 | max_seq_length = max_seq_length, 52 | dtype = dtype, 53 | load_in_4bit = load_in_4bit, 54 | # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf 55 | ) 56 | 57 | model = FastLanguageModel.get_peft_model( 58 | model, 59 | r = 128, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128 60 | target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", 61 | "gate_proj", "up_proj", "down_proj", 62 | "embed_tokens", "lm_head",], # Add for continual pretraining 63 | lora_alpha = 32, 64 | lora_dropout = 0, # Supports any, but = 0 is optimized 65 | bias = "none", # Supports any, but = "none" is optimized 66 | # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes! 
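| # note: gradient checkpointing recomputes activations during the backward pass, trading extra compute for a large cut in VRAM 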
67 | use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context 68 | random_state = 3407, 69 | use_rslora = True, # We support rank stabilized LoRA 70 | loftq_config = None, # And LoftQ 71 | ) 72 | 73 | 74 | # dataset 75 | ds = load_dataset("argilla/magpie-ultra-v0.1") 76 | filtered_ds = ds.filter(quality_filter) 77 | split_ds = filtered_ds['train'].train_test_split(test_size=0.1, seed=SEED) 78 | train_ds = split_ds['train'] 79 | 80 | EOS_TOKEN = tokenizer.eos_token 81 | def formatting_prompts_func(examples): 82 | texts = [] 83 | for instruction, response in zip(examples['instruction'], examples['response']): 84 | clean_steps = split_and_clean_steps(response) 85 | all_steps = "\n\n".join(clean_steps) 86 | 87 | prompt = f"{instruction}\n\n{all_steps}{EOS_TOKEN}" 88 | texts.append(prompt) 89 | 90 | return {"text": texts} 91 | formatted_dataset = train_ds.map(formatting_prompts_func, batched = True,) 92 | 93 | 94 | trainer = UnslothTrainer( 95 | model = model, 96 | tokenizer = tokenizer, 97 | train_dataset = formatted_dataset, 98 | dataset_text_field = "text", 99 | max_seq_length = max_seq_length, 100 | dataset_num_proc = 8, 101 | packing = True, 102 | 103 | args = UnslothTrainingArguments( 104 | per_device_train_batch_size = 2, 105 | gradient_accumulation_steps = 8, 106 | 107 | warmup_ratio = 0.1, 108 | num_train_epochs = 1, 109 | 110 | learning_rate = 4e-4, 111 | embedding_learning_rate = 4e-5, 112 | 113 | fp16 = not is_bfloat16_supported(), 114 | bf16 = is_bfloat16_supported(), 115 | logging_steps = 1, 116 | optim = "adamw_torch_fused", 117 | weight_decay = 0.01, 118 | lr_scheduler_type = "cosine", 119 | seed = 3407, 120 | output_dir = "outputs", 121 | ), 122 | ) 123 | 124 | #@title Show current memory stats 125 | gpu_stats = torch.cuda.get_device_properties(0) 126 | start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3) 127 | max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3) 128 | print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.") 129 | print(f"{start_gpu_memory} GB of memory reserved.") 130 | 131 | trainer_stats = trainer.train() 132 | 133 | # model.push_to_hub_merged("rawsh/mirrorgemma-2-2b-SFT", tokenizer, save_method = "merged_16bit") 134 | model.push_to_hub_merged("rawsh/mirrorqwen2.5-0.5b-SFT", tokenizer, save_method = "merged_16bit") -------------------------------------------------------------------------------- /mcts/reward.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | import numpy as np 3 | from util import split_and_clean_steps, quality_filter, SEED 4 | import json 5 | 6 | def initialize_prm(traces, last_step_correct=True): 7 | """ 8 | Initialize the Process Reward Model (PRM) using sets of reasoning traces. 
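Each step k gets a quality value v_k = max(v_prev + w_k, 0), where the update w_k = (1 - v_prev) / (m_k + 1) * (1 - 2*r_k) shrinks as more steps remain (m_k) and turns negative when the final step is marked incorrect (r_k = 1).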
9 | 10 | Args: 11 | traces (list of (str, list of str)): (question, reasoning steps) pairs 12 | last_step_correct (bool): Whether the final step of each trace is correct (True) or incorrect (False) 13 | 14 | Returns: 15 | list of dict: Per-trace PRM examples with quality values and weighted rewards 16 | """ 17 | # prm = {} 18 | prm_data = [] 19 | 20 | for i, trace_tuple in enumerate(traces): 21 | question, trace = trace_tuple 22 | K = len(trace) # Total number of reasoning steps 23 | 24 | # Initialize trace 25 | prm_example = {"steps": [], "quality_values": [], "weighted_rewards": []} 26 | v_prev = 0 27 | for k, step in enumerate(trace, 1): 28 | penalize = (not last_step_correct) and k == len(trace) 29 | m_k = K - k if (not penalize) else K - k + 1 # One more step needed to correct mistake if incorrect 30 | r_s_k = 0 if (not penalize) else 1 # 0 for correct steps, 1 for incorrect steps 31 | w_s_k = (1 - v_prev) / (m_k + 1) * (1 - 2 * r_s_k) 32 | v_k = max(v_prev + w_s_k, 0) 33 | 34 | prm_example["question"] = question 35 | prm_example["steps"].append(step) 36 | prm_example["quality_values"].append(v_k) 37 | prm_example["weighted_rewards"].append(w_s_k) 38 | v_prev = v_k 39 | 40 | prm_data.append(prm_example) 41 | 42 | return prm_data 43 | 44 | 45 | # Load and filter the dataset, then apply the 90:10 split 46 | ds = load_dataset("argilla/magpie-ultra-v0.1") 47 | # Filter the dataset 48 | filtered_ds = ds.filter(quality_filter) 49 | # Apply the 90:10 split on the filtered training data 50 | split_ds = filtered_ds['train'].train_test_split(test_size=0.1, seed=SEED) 51 | train_ds = split_ds['train'] 52 | print(len(train_ds)) 53 | # "Correct" traces generated by 405B 54 | correct_traces = [(row["instruction"], split_and_clean_steps(row["response"])) for row in train_ds] 55 | 56 | # Example usage: 57 | # correct_traces = [ 58 | # ["Step 1: Correct", "Step 2: Correct", "Step 3: Correct"], 59 | # ["Step 1: Correct", "Step 2: Correct"] 60 | # ] 61 | 62 | with open('out.jsonl') as f: 63 | last_step_incorrect_data = [json.loads(line) for line in f] 64 | last_step_incorrect_traces = [(ex["question"], ex["thoughts"]) for ex in last_step_incorrect_data] 65 | 66 | # incorrect_traces = [['Identify all the possible outcomes of tossing four coins simultaneously. When tossing four coins simultaneously, each coin has 2 possible outcomes (heads or tails). Therefore, for four coins, the total number of possible outcomes is $2^4 = 16$.', 'List all the outcomes that result in more heads than tails. There are 4 outcomes that meet this criterion: HTHT, HHTT, THTH, TTHH. This gives us a total of 4 favorable outcomes.'], ['Identify all the possible outcomes of tossing four coins simultaneously. When tossing four coins simultaneously, each coin has 2 possible outcomes (heads or tails). Therefore, for four coins, the total number of possible outcomes is $2^4 = 16$.', 'Determine the favorable outcomes. We want more heads than tails, which means we need 3 heads and 1 tail, or 4 heads.', 'Count the number of outcomes with 3 heads and 1 tail. For 3 heads, there is only 1 way to arrange them (HHH). For 1 tail, there are 2 ways to arrange them (TTH and THT). 
So, there are a total of 1 + 2 = 3 favorable outcomes.'], ['Recognize that this is an arithmetic sequence with a common difference of 1.', 'To find the sum of the first 100 positive integers, we can use the formula for the sum of an arithmetic series, which is given by S = n/2 * (a1 + an), where n is the number of terms, a1 is the first term, and an is the last term.']] 67 | 68 | # initialized_prm = initialize_prm(correct_traces) 69 | # print(initialized_prm) 70 | # print(initialized_prm["trace_1000"]) 71 | 72 | correct_prm_data = initialize_prm(correct_traces, last_step_correct=True) 73 | print(len(correct_prm_data)) 74 | total_length = 0 75 | correct_prm_data_step_values = [] 76 | for ex in correct_prm_data: 77 | total_length += len(ex["steps"]) 78 | for i in range(len(ex["steps"])): 79 | question = ex["question"] 80 | partial_steps = ex["steps"][:i+1] 81 | partial_reward = ex["quality_values"][i] 82 | correct_prm_data_step_values.append({ 83 | "question": question, 84 | "steps": partial_steps, 85 | "final_step_reward": partial_reward 86 | }) 87 | 88 | print("corr total # step values", total_length) 89 | 90 | last_step_incorrect_prm_data = initialize_prm(last_step_incorrect_traces, last_step_correct=False) 91 | print(len(last_step_incorrect_prm_data)) 92 | 93 | last_step_incorrect_prm_data_step_values = [] 94 | for ex in last_step_incorrect_prm_data: 95 | i = len(ex["steps"]) - 1 96 | question = ex["question"] 97 | partial_steps = ex["steps"][:i+1] 98 | partial_reward = ex["quality_values"][i] 99 | last_step_incorrect_prm_data_step_values.append({ 100 | "question": question, 101 | "steps": partial_steps, 102 | "final_step_reward": partial_reward 103 | }) 104 | 105 | print("last step incorr total # step values", len(last_step_incorrect_prm_data_step_values)) 106 | 107 | # print(initialized_prm) 108 | # print(last_step_incorrect_prm_data[1000]) 109 | 110 | with open("reward.jsonl", "w") as f: 111 | for prm_examples in correct_prm_data_step_values: 112 | json.dump(prm_examples, f) 113 | f.write("\n") 114 | 115 | for prm_examples in last_step_incorrect_prm_data_step_values: 116 | json.dump(prm_examples, f) 117 | f.write("\n") 118 | 119 | print(f"Results written to reward.jsonl") -------------------------------------------------------------------------------- /modal_vllm.py: -------------------------------------------------------------------------------- 1 | import modal 2 | import asyncio 3 | from contextlib import asynccontextmanager 4 | 5 | def download_model_to_image(model_dir, model_name, model_revision): 6 | import os 7 | from huggingface_hub import snapshot_download 8 | from transformers.utils import move_cache 9 | 10 | os.makedirs(model_dir, exist_ok=True) 11 | snapshot_download( 12 | model_name, 13 | revision=model_revision, 14 | local_dir=model_dir, 15 | ignore_patterns=["*.pt", "*.bin"], # Using safetensors 16 | ) 17 | move_cache() 18 | 19 | MODEL_DIR = "/qwen" 20 | MODEL_NAME = "rawsh/MetaMath-Qwen2.5-0.5b" 21 | MODEL_REVISION = "286ca8b160074c923b89c318652ab4b979627550" 22 | # MODEL_NAME = "rawsh/mirrorqwen2.5-0.5b-ORPO-3" 23 | # MODEL_REVISION = "4b3e3eb18fe84477ee949058484ec951a5b8beb6" 24 | 25 | vllm_image = ( 26 | modal.Image.debian_slim(python_version="3.10") 27 | .pip_install( 28 | "vllm==0.6.2", 29 | "torch==2.4.0", 30 | "transformers>=4.45", 31 | "ray==2.36.0", 32 | "hf-transfer==0.1.8", 33 | "huggingface_hub==0.25.0", 34 | ) 35 | .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"}) 36 | .run_function( 37 | download_model_to_image, 38 | timeout=60 * 20, 39 | 
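# note: the hf-token secret lets snapshot_download authenticate to the Hub at build time (needed for gated or private weights) | 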
secrets=[modal.Secret.from_name("hf-token")], 40 | kwargs={ 41 | "model_dir": MODEL_DIR, 42 | "model_name": MODEL_NAME, 43 | "model_revision": MODEL_REVISION, 44 | }, 45 | ) 46 | .env({"VLLM_ALLOW_LONG_MAX_MODEL_LEN": "1"}) 47 | ) 48 | 49 | app = modal.App("vllm-qwen-metamath") 50 | 51 | N_GPU = 1 52 | MINUTES = 60 53 | HOURS = 60 * MINUTES 54 | 55 | async def get_model_config(engine): 56 | try: 57 | return await engine.get_model_config() 58 | except Exception as e: 59 | print(f"Error getting model config: {e}") 60 | raise 61 | 62 | @asynccontextmanager 63 | async def lifespan(app): 64 | # Startup 65 | try: 66 | await asyncio.sleep(0) # Give chance for event loop to start 67 | yield 68 | finally: 69 | # Shutdown: Cancel all pending tasks 70 | tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()] 71 | for task in tasks: 72 | task.cancel() 73 | await asyncio.gather(*tasks, return_exceptions=True) 74 | 75 | @app.function( 76 | image=vllm_image, 77 | gpu=modal.gpu.A10G(count=N_GPU), 78 | container_idle_timeout=2 * MINUTES, 79 | timeout=20 * MINUTES, 80 | allow_concurrent_inputs=1000, 81 | secrets=[modal.Secret.from_name("vllm-token")] 82 | ) 83 | @modal.asgi_app() 84 | def serve(): 85 | import os 86 | import fastapi 87 | import vllm.entrypoints.openai.api_server as api_server 88 | from vllm.engine.arg_utils import AsyncEngineArgs 89 | from vllm.engine.async_llm_engine import AsyncLLMEngine 90 | from vllm.entrypoints.logger import RequestLogger 91 | from vllm.entrypoints.openai.serving_chat import OpenAIServingChat 92 | from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion 93 | from vllm.entrypoints.openai.serving_engine import BaseModelPath 94 | from vllm.usage.usage_lib import UsageContext 95 | 96 | web_app = fastapi.FastAPI( 97 | title=f"OpenAI-compatible {MODEL_NAME} server", 98 | description="Run an OpenAI-compatible LLM server with vLLM on modal.com", 99 | version="0.0.1", 100 | docs_url="/docs", 101 | lifespan=lifespan 102 | ) 103 | 104 | http_bearer = fastapi.security.HTTPBearer( 105 | scheme_name="Bearer Token", 106 | description="See code for authentication details.", 107 | ) 108 | web_app.add_middleware( 109 | fastapi.middleware.cors.CORSMiddleware, 110 | allow_origins=["*"], 111 | allow_credentials=True, 112 | allow_methods=["*"], 113 | allow_headers=["*"], 114 | ) 115 | 116 | TOKEN = os.environ["API_TOKEN"] 117 | async def is_authenticated(api_key: str = fastapi.Security(http_bearer)): 118 | if api_key.credentials != TOKEN: 119 | raise fastapi.HTTPException( 120 | status_code=fastapi.status.HTTP_401_UNAUTHORIZED, 121 | detail="Invalid authentication credentials", 122 | ) 123 | return {"username": "authenticated_user"} 124 | 125 | router = fastapi.APIRouter(dependencies=[fastapi.Depends(is_authenticated)]) 126 | 127 | # wrap vllm's router in auth router 128 | router.include_router(api_server.router) 129 | # add authed vllm to our fastAPI app 130 | web_app.include_router(router) 131 | 132 | engine_args = AsyncEngineArgs( 133 | model=MODEL_DIR, 134 | tensor_parallel_size=N_GPU, 135 | gpu_memory_utilization=0.90, 136 | max_model_len=8096, 137 | enforce_eager=False, 138 | enable_prefix_caching=True 139 | ) 140 | 141 | engine = AsyncLLMEngine.from_engine_args( 142 | engine_args, usage_context=UsageContext.OPENAI_API_SERVER 143 | ) 144 | 145 | async def setup_engine(): 146 | model_config = await get_model_config(engine) 147 | return model_config 148 | 149 | # Use asyncio.run to properly handle the async setup 150 | model_config = 
asyncio.run(setup_engine()) 151 | request_logger = RequestLogger(max_log_len=2048) 152 | 153 | base_model_paths = [ 154 | BaseModelPath(name=MODEL_NAME.split("/")[1], model_path=MODEL_NAME) 155 | ] 156 | 157 | # Set up completion endpoint 158 | api_server.completion = lambda s: OpenAIServingCompletion( 159 | engine, 160 | model_config=model_config, 161 | base_model_paths=base_model_paths, 162 | lora_modules=[], 163 | prompt_adapters=[], 164 | request_logger=request_logger, 165 | ) 166 | 167 | # Set up chat endpoint 168 | api_server.chat = lambda s: OpenAIServingChat( 169 | engine, 170 | model_config=model_config, 171 | base_model_paths=base_model_paths, 172 | lora_modules=[], 173 | prompt_adapters=[], 174 | request_logger=request_logger, 175 | response_role="assistant" 176 | ) 177 | 178 | return web_app -------------------------------------------------------------------------------- /prime_training_local.py: -------------------------------------------------------------------------------- 1 | # - 2 | # Prereq 3 | # - 4 | # pip3 install torch==2.4.0 packaging wheel 5 | # pip3 install flash-attn 6 | # pip3 install vllm==0.6.3 ray transformers accelerate numpy datasets wandb bitsandbytes tensorboard tqdm evaluate pyext pylatexenc 7 | # export WANDB_API_KEY="" 8 | 9 | num_gpus = 8 10 | root_dir = "/home/ubuntu" 11 | model_save_dir = f"{root_dir}/save_dir" 12 | prime_dir = f"{root_dir}/prime/training" 13 | tmp_dir = f"{root_dir}/tmp" 14 | hf_model = "rawsh/SmallThinker-3B" 15 | hf_dataset = "rawsh/Eurus-2-RL-Data-ProblemsOnly" 16 | 17 | # sampling 18 | response_length = 9000 19 | 20 | def download_and_setup_prime(): 21 | """Download PRIME repository and install dependencies during image build""" 22 | import os 23 | import subprocess 24 | from datasets import load_dataset 25 | 26 | # Clone PRIME repository 27 | subprocess.run(["git", "clone", "https://github.com/PRIME-RL/PRIME.git", "/home/ubuntu/prime"], check=False) 28 | os.chdir(prime_dir) 29 | 30 | # Create data directory 31 | os.makedirs("data", exist_ok=True) 32 | 33 | # Create tmp dir 34 | os.makedirs(tmp_dir, exist_ok=True) 35 | 36 | # Download dataset from Hugging Face 37 | print("Downloading dataset from Hugging Face...") 38 | dataset = load_dataset(hf_dataset) 39 | 40 | # Save train and validation splits to parquet files 41 | print("Saving dataset splits to parquet files...") 42 | dataset['train'].to_parquet("data/train.parquet") 43 | dataset['validation'].to_parquet("data/validation.parquet") 44 | 45 | # Install training-specific requirements 46 | subprocess.run(["pip", "install", "-r", "requirements.txt"], check=True, cwd=prime_dir) 47 | 48 | # Add PRIME directories to Python path 49 | with open("/home/ubuntu/.profile", "a") as f: 50 | f.write(""" 51 | export PYTHONPATH="${PYTHONPATH}:/home/ubuntu/prime/training" 52 | export PYTHONPATH="${PYTHONPATH}:/home/ubuntu/prime" 53 | """) 54 | 55 | # Create modified training script with proper paths 56 | with open("/home/ubuntu/prime/training/run_train.sh", "w") as f: 57 | f.write(f"""#!/bin/bash 58 | set -x 59 | 60 | # Environment variables 61 | export NCCL_DEBUG=WARN 62 | export VLLM_ATTENTION_BACKEND=FLASH_ATTN 63 | export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True 64 | export TOKENIZERS_PARALLELISM=true 65 | 66 | # Project configuration 67 | export CKPT_PATH='{model_save_dir}' 68 | export PROJECT_NAME='PRIME' 69 | export EXPERIMENT_NAME='test-run' 70 | port=6379 71 | 72 | # Start Ray 73 | /home/ubuntu/.local/bin/ray start --head \ 74 | --port=$port \ 75 | --num-gpus={num_gpus} \ 76 
| --include-dashboard=false \ 77 | --temp-dir={tmp_dir} \ 78 | --block & 79 | 80 | cd /home/ubuntu/prime/training 81 | python3 -m verl.trainer.main_ppo \\ 82 | data.train_files=[/home/ubuntu/prime/training/data/train.parquet] \\ 83 | data.val_files=[/home/ubuntu/prime/training/data/validation.parquet] \\ 84 | data.train_batch_size=256 \\ 85 | data.val_batch_size=1024 \\ 86 | data.max_prompt_length=1024 \\ 87 | data.max_response_length={response_length} \\ 88 | actor_rollout_ref.model.path={hf_model} \\ 89 | actor_rollout_ref.actor.optim.lr=5e-7 \\ 90 | actor_rollout_ref.actor.ppo_mini_batch_size=256 \\ 91 | actor_rollout_ref.actor.ppo_micro_batch_size=8 \\ 92 | actor_rollout_ref.actor.fsdp_config.param_offload=True \\ 93 | actor_rollout_ref.actor.fsdp_config.grad_offload=True \\ 94 | actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \\ 95 | actor_rollout_ref.actor.entropy_coeff=0. \\ 96 | actor_rollout_ref.rollout.log_prob_micro_batch_size=32 \\ 97 | actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\ 98 | actor_rollout_ref.rollout.name=vllm \\ 99 | actor_rollout_ref.rollout.gpu_memory_utilization=0.75 \\ 100 | actor_rollout_ref.ref.log_prob_micro_batch_size=32 \\ 101 | actor_rollout_ref.ref.fsdp_config.param_offload=True \\ 102 | algorithm.kl_ctrl.kl_coef=0.00 \\ 103 | trainer.logger=['console','wandb'] \\ 104 | trainer.project_name=$PROJECT_NAME \\ 105 | trainer.experiment_name=$EXPERIMENT_NAME \\ 106 | trainer.default_local_dir=$CKPT_PATH/$PROJECT_NAME/$EXPERIMENT_NAME \\ 107 | trainer.n_gpus_per_node={num_gpus} \\ 108 | trainer.nnodes=1 \\ 109 | trainer.save_freq=8 \\ 110 | trainer.test_freq=8 \\ 111 | trainer.total_epochs=1 \\ 112 | +trainer.total_training_steps=300 \\ 113 | +trainer.val_before_train=True \\ 114 | data.n_samples=4 \\ 115 | data.filter_accuracy=True \\ 116 | data.accuracy_lower_bound=0.2 \\ 117 | data.accuracy_upper_bound=0.8 \\ 118 | algorithm.adv_estimator=rloo \\ 119 | algorithm.adv_params.verifier_gamma=1.0 \\ 120 | algorithm.adv_params.reward_model_gamma=1.0 \\ 121 | reward_model.rm_type=prime \\ 122 | reward_model.rm_coef=5 \\ 123 | reward_model.prime_granularity=token \\ 124 | reward_model.prime_norm=batch_norm \\ 125 | reward_model.prime_model.path={hf_model} \\ 126 | reward_model.prime_model.ref_path={hf_model} \\ 127 | reward_model.model.input_tokenizer=null \\ 128 | reward_model.micro_batch_size=8 \\ 129 | reward_model.prime_model.ref_type=freeze \\ 130 | reward_model.prime_model.update=after \\ 131 | reward_model.prime_model.beta_train=0.05 \\ 132 | reward_model.prime_model.optim.lr=1e-6 \\ 133 | reward_model.prime_model.optim.grad_clip=10.0 \\ 134 | reward_model.prime_model.input_tokenizer=null""") 135 | # Make executable 136 | os.chmod("/home/ubuntu/prime/training/run_train.sh", 0o755) 137 | 138 | 139 | def train_prime(): 140 | """Main training function for PRIME""" 141 | import os 142 | import subprocess 143 | 144 | print("Starting PRIME training...") 145 | print("\nCurrent working directory:", os.getcwd()) 146 | 147 | os.chdir("/home/ubuntu/prime/training") 148 | print("PYTHONPATH:", os.environ.get('PYTHONPATH')) 149 | print("\nDirectory contents:") 150 | subprocess.run(["ls", "-la"], check=True) 151 | subprocess.run(["/home/ubuntu/prime/training/run_train.sh"], check=True) 152 | 153 | if __name__ == "__main__": 154 | download_and_setup_prime() 155 | train_prime() -------------------------------------------------------------------------------- /modal_vllm_prm.py: 
-------------------------------------------------------------------------------- 1 | import modal 2 | import asyncio 3 | from contextlib import asynccontextmanager 4 | 5 | def download_model_to_image(model_dir, model_name, model_revision): 6 | import os 7 | from huggingface_hub import snapshot_download 8 | from transformers.utils import move_cache 9 | 10 | os.makedirs(model_dir, exist_ok=True) 11 | snapshot_download( 12 | model_name, 13 | revision=model_revision, 14 | local_dir=model_dir, 15 | ignore_patterns=["*.pt", "*.bin"], # Using safetensors 16 | ) 17 | move_cache() 18 | 19 | MODEL_DIR = "/qwen" 20 | MODEL_NAME = "rawsh/MetaMath-Qwen2.5-0.5b-PRM" 21 | MODEL_REVISION = "d230f00aa86b0967a4ee474df3c1f616f7ee7c57" 22 | 23 | vllm_image = ( 24 | modal.Image.debian_slim(python_version="3.10") 25 | .pip_install( 26 | "vllm==0.6.2", 27 | "torch==2.4.0", 28 | "transformers>=4.45", 29 | "ray==2.36.0", 30 | "hf-transfer==0.1.8", 31 | "huggingface_hub==0.25.0", 32 | ) 33 | .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"}) 34 | .run_function( 35 | download_model_to_image, 36 | timeout=60 * 20, 37 | secrets=[modal.Secret.from_name("hf-token")], 38 | kwargs={ 39 | "model_dir": MODEL_DIR, 40 | "model_name": MODEL_NAME, 41 | "model_revision": MODEL_REVISION, 42 | }, 43 | ) 44 | .env({"VLLM_ALLOW_LONG_MAX_MODEL_LEN": "1"}) 45 | ) 46 | 47 | app = modal.App("vllm-qwen-prm") 48 | 49 | N_GPU = 1 50 | MINUTES = 60 51 | HOURS = 60 * MINUTES 52 | 53 | async def get_model_config(engine): 54 | try: 55 | return await engine.get_model_config() 56 | except Exception as e: 57 | print(f"Error getting model config: {e}") 58 | raise 59 | 60 | @asynccontextmanager 61 | async def lifespan(app): 62 | try: 63 | await asyncio.sleep(0) 64 | yield 65 | finally: 66 | tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()] 67 | for task in tasks: 68 | task.cancel() 69 | await asyncio.gather(*tasks, return_exceptions=True) 70 | 71 | @app.function( 72 | image=vllm_image, 73 | gpu=modal.gpu.A10G(count=N_GPU), 74 | container_idle_timeout=5 * MINUTES, 75 | timeout=20 * MINUTES, 76 | allow_concurrent_inputs=1000, 77 | secrets=[modal.Secret.from_name("vllm-token")] 78 | ) 79 | @modal.asgi_app() 80 | def serve(): 81 | import os 82 | import fastapi 83 | import vllm.entrypoints.openai.api_server as api_server 84 | from vllm.engine.arg_utils import AsyncEngineArgs 85 | from vllm.engine.async_llm_engine import AsyncLLMEngine 86 | from vllm.entrypoints.logger import RequestLogger 87 | from vllm.entrypoints.openai.serving_chat import OpenAIServingChat 88 | from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion 89 | from vllm.entrypoints.openai.serving_engine import BaseModelPath 90 | from vllm.usage.usage_lib import UsageContext 91 | from transformers import AutoTokenizer 92 | 93 | web_app = fastapi.FastAPI( 94 | title=f"OpenAI-compatible {MODEL_NAME} server", 95 | description="Run an OpenAI-compatible LLM server with vLLM on modal.com", 96 | version="0.0.1", 97 | docs_url="/docs", 98 | lifespan=lifespan 99 | ) 100 | 101 | http_bearer = fastapi.security.HTTPBearer( 102 | scheme_name="Bearer Token", 103 | description="See code for authentication details.", 104 | ) 105 | web_app.add_middleware( 106 | fastapi.middleware.cors.CORSMiddleware, 107 | allow_origins=["*"], 108 | allow_credentials=True, 109 | allow_methods=["*"], 110 | allow_headers=["*"], 111 | ) 112 | 113 | TOKEN = os.environ["API_TOKEN"] 114 | async def is_authenticated(api_key: str = fastapi.Security(http_bearer)): 115 | if api_key.credentials 
!= TOKEN: 116 | raise fastapi.HTTPException( 117 | status_code=fastapi.status.HTTP_401_UNAUTHORIZED, 118 | detail="Invalid authentication credentials", 119 | ) 120 | return {"username": "authenticated_user"} 121 | 122 | router = fastapi.APIRouter(dependencies=[fastapi.Depends(is_authenticated)]) 123 | router.include_router(api_server.router) 124 | web_app.include_router(router) 125 | 126 | engine_args = AsyncEngineArgs( 127 | model=MODEL_DIR, 128 | tensor_parallel_size=N_GPU, 129 | gpu_memory_utilization=0.90, 130 | max_model_len=8096, 131 | enforce_eager=False, 132 | enable_prefix_caching=True 133 | ) 134 | 135 | engine = AsyncLLMEngine.from_engine_args( 136 | engine_args, usage_context=UsageContext.OPENAI_API_SERVER 137 | ) 138 | 139 | async def setup_engine(): 140 | model_config = await get_model_config(engine) 141 | return model_config 142 | 143 | model_config = asyncio.run(setup_engine()) 144 | request_logger = RequestLogger(max_log_len=2048) 145 | 146 | base_model_paths = [ 147 | BaseModelPath(name=MODEL_NAME.split("/")[1], model_path=MODEL_NAME) 148 | ] 149 | 150 | # Qwen chat template with exact formatting 151 | TEMPLATE = """<|im_start|>system 152 | You are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|> 153 | {% for message in messages %}<|im_start|>{{ message['role'] }} 154 | {{ message['content'] }}<|im_end|> 155 | {% endfor %}{% if add_generation_prompt %}<|im_start|>assistant 156 | {% endif %}""" 157 | 158 | # TEMPLATE = """{%- for message in messages %}{%- set content = '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' %}{%- if loop.last and message.role == 'assistant' %}{%- set content = '<|im_start|>' + message.role + '\n' + message.content %}{%- endif %}{{- content }}{%- endfor %}""" 159 | 160 | # Set up completion endpoint 161 | api_server.completion = lambda s: OpenAIServingCompletion( 162 | engine, 163 | model_config=model_config, 164 | base_model_paths=base_model_paths, 165 | lora_modules=[], 166 | prompt_adapters=[], 167 | request_logger=request_logger, 168 | ) 169 | 170 | # Set up chat endpoint with tokenizer's chat template 171 | api_server.chat = lambda s: OpenAIServingChat( 172 | engine, 173 | model_config=model_config, 174 | base_model_paths=base_model_paths, 175 | lora_modules=[], 176 | prompt_adapters=[], 177 | request_logger=request_logger, 178 | response_role="assistant", 179 | chat_template=TEMPLATE 180 | ) 181 | 182 | return web_app -------------------------------------------------------------------------------- /modal_prime.py: -------------------------------------------------------------------------------- 1 | 2 | # CUDA configuration 3 | cuda_version = "12.4.0" 4 | flavor = "devel" 5 | operating_sys = "ubuntu22.04" 6 | tag = f"{cuda_version}-{flavor}-{operating_sys}" 7 | 8 | num_gpus = 1 9 | model_save_dir = "/save_dir" 10 | 11 | 12 | import modal 13 | # TRAIN_GPU = modal.gpu.A10G() 14 | TRAIN_GPU = modal.gpu.H100() 15 | # TRAIN_GPU = modal.gpu.L40S() 16 | vol = modal.Volume.from_name("prime_savedir", create_if_missing=True) 17 | 18 | def download_and_setup_prime(): 19 | """Download PRIME repository and install dependencies during image build""" 20 | import os 21 | import subprocess 22 | from datasets import load_dataset 23 | 24 | # Clone PRIME repository 25 | subprocess.run(["git", "clone", "https://github.com/PRIME-RL/PRIME.git", "/PRIME"], check=True) 26 | os.chdir("/PRIME/training") 27 | 28 | # Create data directory 29 | os.makedirs("data", exist_ok=True) 30 | 31 | # Download dataset from Hugging Face 32 
| print("Downloading dataset from Hugging Face...") 33 | dataset = load_dataset("PRIME-RL/Eurus-2-RL-Data") 34 | 35 | # Save train and validation splits to parquet files 36 | print("Saving dataset splits to parquet files...") 37 | dataset['train'].to_parquet("data/train.parquet") 38 | dataset['validation'].to_parquet("data/validation.parquet") 39 | 40 | # Install training-specific requirements 41 | subprocess.run(["pip", "install", "-r", "requirements.txt"], check=True) 42 | 43 | # Add PRIME directories to Python path 44 | with open("/etc/profile.d/prime_paths.sh", "w") as f: 45 | f.write(""" 46 | export PYTHONPATH="${PYTHONPATH}:/PRIME/training" 47 | export PYTHONPATH="${PYTHONPATH}:/PRIME" 48 | """) 49 | 50 | # Create modified training script with proper paths 51 | with open("/PRIME/training/run_train.sh", "w") as f: 52 | f.write(f"""#!/bin/bash 53 | set -x 54 | 55 | # Environment variables 56 | export NCCL_DEBUG=WARN 57 | export VLLM_ATTENTION_BACKEND=FLASH_ATTN 58 | export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True 59 | export TOKENIZERS_PARALLELISM=true 60 | 61 | # Project configuration 62 | export CKPT_PATH='{model_save_dir}' 63 | export PROJECT_NAME='PRIME' 64 | export EXPERIMENT_NAME='test-run' 65 | port=6379 66 | 67 | # Start Ray 68 | ray start --head \ 69 | --port=$port \ 70 | --num-gpus=1 \ 71 | --include-dashboard=false \ 72 | --block & 73 | 74 | cd /PRIME/training 75 | python3 -m verl.trainer.main_ppo \\ 76 | data.train_files=[/PRIME/training/data/train.parquet] \\ 77 | data.val_files=[/PRIME/training/data/validation.parquet] \\ 78 | data.train_batch_size=256 \\ 79 | data.val_batch_size=1024 \\ 80 | data.max_prompt_length=1024 \\ 81 | data.max_response_length=3072 \\ 82 | actor_rollout_ref.model.path=rawsh/Qwen2.5-0.5b-Eurus-2-SFT \\ 83 | actor_rollout_ref.actor.optim.lr=5e-7 \\ 84 | actor_rollout_ref.actor.ppo_mini_batch_size=256 \\ 85 | actor_rollout_ref.actor.ppo_micro_batch_size=8 \\ 86 | actor_rollout_ref.actor.fsdp_config.param_offload=False \\ 87 | actor_rollout_ref.actor.fsdp_config.grad_offload=False \\ 88 | actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\ 89 | actor_rollout_ref.actor.entropy_coeff=0. 
\\ 90 | actor_rollout_ref.rollout.log_prob_micro_batch_size=64 \\ 91 | actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\ 92 | actor_rollout_ref.rollout.name=vllm \\ 93 | actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \\ 94 | actor_rollout_ref.ref.log_prob_micro_batch_size=64 \\ 95 | actor_rollout_ref.ref.fsdp_config.param_offload=False \\ 96 | algorithm.kl_ctrl.kl_coef=0.00 \\ 97 | trainer.logger=['console','wandb'] \\ 98 | trainer.project_name=$PROJECT_NAME \\ 99 | trainer.experiment_name=$EXPERIMENT_NAME \\ 100 | trainer.default_local_dir=$CKPT_PATH/$PROJECT_NAME/$EXPERIMENT_NAME \\ 101 | trainer.n_gpus_per_node={num_gpus} \\ 102 | trainer.nnodes=1 \\ 103 | trainer.save_freq=32 \\ 104 | trainer.test_freq=32 \\ 105 | trainer.total_epochs=1 \\ 106 | data.n_samples=4 \\ 107 | data.filter_accuracy=True \\ 108 | data.accuracy_lower_bound=0.2 \\ 109 | data.accuracy_upper_bound=0.8 \\ 110 | algorithm.adv_estimator=rloo \\ 111 | algorithm.adv_params.verifier_gamma=1.0 \\ 112 | algorithm.adv_params.reward_model_gamma=1.0 \\ 113 | reward_model.rm_type=prime \\ 114 | reward_model.rm_coef=5 \\ 115 | reward_model.prime_granularity=token \\ 116 | reward_model.prime_norm=batch_norm \\ 117 | reward_model.prime_model.path=rawsh/Qwen2.5-0.5b-Eurus-2-SFT \\ 118 | reward_model.prime_model.ref_path=rawsh/Qwen2.5-0.5b-Eurus-2-SFT \\ 119 | reward_model.model.input_tokenizer=null \\ 120 | reward_model.micro_batch_size=8 \\ 121 | reward_model.prime_model.update=after \\ 122 | reward_model.prime_model.beta_train=0.05 \\ 123 | reward_model.prime_model.optim.lr=1e-6 \\ 124 | reward_model.prime_model.optim.grad_clip=10.0 \\ 125 | reward_model.prime_model.input_tokenizer=null""") 126 | # Make executable 127 | os.chmod("/PRIME/training/run_train.sh", 0o755) 128 | 129 | prime_image = ( 130 | modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.10") 131 | .apt_install("git", "build-essential", "ninja-build") 132 | # First install torch and its dependencies 133 | .pip_install([ 134 | "torch==2.4.0", 135 | "packaging", 136 | "wheel", 137 | ]) 138 | # Then install flash-attention separately 139 | .pip_install("flash-attn") 140 | # Finally install the rest of the dependencies 141 | .pip_install([ 142 | "vllm==0.6.3", 143 | "ray", 144 | "transformers", 145 | "accelerate", 146 | "numpy", 147 | "datasets", 148 | "wandb", 149 | "bitsandbytes", 150 | "tensorboard", 151 | "tqdm", 152 | "evaluate" 153 | ]) 154 | .pip_install([ 155 | "pyext", 156 | "pylatexenc" 157 | ]) 158 | .run_function(download_and_setup_prime, gpu=TRAIN_GPU) 159 | ) 160 | 161 | app = modal.App("prime-training", image=prime_image) 162 | 163 | @app.function( 164 | cpu=2.0, 165 | gpu=TRAIN_GPU, 166 | volumes={model_save_dir: vol}, 167 | timeout=20 * 60 * 60, # 20 hours 168 | secrets=[ 169 | modal.Secret.from_name("hf-token"), 170 | modal.Secret.from_name("wandb-token") 171 | ] 172 | ) 173 | def train_prime(): 174 | """Main training function for PRIME""" 175 | import os 176 | import subprocess 177 | 178 | print("Starting PRIME training...") 179 | print("\nCurrent working directory:", os.getcwd()) 180 | 181 | os.chdir("/PRIME/training") 182 | print("PYTHONPATH:", os.environ.get('PYTHONPATH')) 183 | print("\nDirectory contents:") 184 | subprocess.run(["ls", "-la"], check=True) 185 | subprocess.run(["/PRIME/training/run_train.sh"], check=True) 186 | 187 | @app.local_entrypoint() 188 | def main(): 189 | train_prime.remote() -------------------------------------------------------------------------------- /modal_vllm_chat.py: 
-------------------------------------------------------------------------------- 1 | import modal 2 | import asyncio 3 | from contextlib import asynccontextmanager 4 | 5 | def download_model_to_image(model_dir, model_name, model_revision): 6 | import os 7 | from huggingface_hub import snapshot_download 8 | from transformers.utils import move_cache 9 | 10 | os.makedirs(model_dir, exist_ok=True) 11 | snapshot_download( 12 | model_name, 13 | revision=model_revision, 14 | local_dir=model_dir, 15 | ignore_patterns=["*.pt", "*.bin"], # Using safetensors 16 | ) 17 | move_cache() 18 | 19 | MODEL_DIR = "/qwen" 20 | MODEL_NAME = "rawsh/MetaMath-Qwen2.5-0.5b" 21 | MODEL_REVISION = "779b469ef1bb4ef8faac05e46b94c09d38112194" 22 | 23 | vllm_image = ( 24 | modal.Image.debian_slim(python_version="3.10") 25 | .pip_install( 26 | "vllm==0.6.2", 27 | "torch==2.4.0", 28 | "transformers>=4.45", 29 | "ray==2.36.0", 30 | "hf-transfer==0.1.8", 31 | "huggingface_hub==0.25.0", 32 | ) 33 | .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"}) 34 | .run_function( 35 | download_model_to_image, 36 | timeout=60 * 20, 37 | secrets=[modal.Secret.from_name("hf-token")], 38 | kwargs={ 39 | "model_dir": MODEL_DIR, 40 | "model_name": MODEL_NAME, 41 | "model_revision": MODEL_REVISION, 42 | }, 43 | ) 44 | .env({"VLLM_ALLOW_LONG_MAX_MODEL_LEN": "1"}) 45 | ) 46 | 47 | app = modal.App("vllm-qwen-metamath") 48 | 49 | N_GPU = 1 50 | MINUTES = 60 51 | HOURS = 60 * MINUTES 52 | 53 | async def get_model_config(engine): 54 | try: 55 | return await engine.get_model_config() 56 | except Exception as e: 57 | print(f"Error getting model config: {e}") 58 | raise 59 | 60 | @asynccontextmanager 61 | async def lifespan(app): 62 | try: 63 | await asyncio.sleep(0) 64 | yield 65 | finally: 66 | tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()] 67 | for task in tasks: 68 | task.cancel() 69 | await asyncio.gather(*tasks, return_exceptions=True) 70 | 71 | @app.function( 72 | image=vllm_image, 73 | gpu=modal.gpu.A10G(count=N_GPU), 74 | container_idle_timeout=5 * MINUTES, 75 | timeout=20 * MINUTES, 76 | allow_concurrent_inputs=1000, 77 | secrets=[modal.Secret.from_name("vllm-token")] 78 | ) 79 | @modal.asgi_app() 80 | def serve(): 81 | import os 82 | import fastapi 83 | import vllm.entrypoints.openai.api_server as api_server 84 | from vllm.engine.arg_utils import AsyncEngineArgs 85 | from vllm.engine.async_llm_engine import AsyncLLMEngine 86 | from vllm.entrypoints.logger import RequestLogger 87 | from vllm.entrypoints.openai.serving_chat import OpenAIServingChat 88 | from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion 89 | from vllm.entrypoints.openai.serving_engine import BaseModelPath 90 | from vllm.usage.usage_lib import UsageContext 91 | from transformers import AutoTokenizer 92 | 93 | web_app = fastapi.FastAPI( 94 | title=f"OpenAI-compatible {MODEL_NAME} server", 95 | description="Run an OpenAI-compatible LLM server with vLLM on modal.com", 96 | version="0.0.1", 97 | docs_url="/docs", 98 | lifespan=lifespan 99 | ) 100 | 101 | http_bearer = fastapi.security.HTTPBearer( 102 | scheme_name="Bearer Token", 103 | description="See code for authentication details.", 104 | ) 105 | web_app.add_middleware( 106 | fastapi.middleware.cors.CORSMiddleware, 107 | allow_origins=["*"], 108 | allow_credentials=True, 109 | allow_methods=["*"], 110 | allow_headers=["*"], 111 | ) 112 | 113 | TOKEN = os.environ["API_TOKEN"] 114 | async def is_authenticated(api_key: str = fastapi.Security(http_bearer)): 115 | if api_key.credentials 
!= TOKEN: 116 | raise fastapi.HTTPException( 117 | status_code=fastapi.status.HTTP_401_UNAUTHORIZED, 118 | detail="Invalid authentication credentials", 119 | ) 120 | return {"username": "authenticated_user"} 121 | 122 | router = fastapi.APIRouter(dependencies=[fastapi.Depends(is_authenticated)]) 123 | router.include_router(api_server.router) 124 | web_app.include_router(router) 125 | 126 | engine_args = AsyncEngineArgs( 127 | model=MODEL_DIR, 128 | tensor_parallel_size=N_GPU, 129 | gpu_memory_utilization=0.90, 130 | max_model_len=8096, 131 | enforce_eager=False, 132 | enable_prefix_caching=True 133 | ) 134 | 135 | engine = AsyncLLMEngine.from_engine_args( 136 | engine_args, usage_context=UsageContext.OPENAI_API_SERVER 137 | ) 138 | 139 | async def setup_engine(): 140 | model_config = await get_model_config(engine) 141 | return model_config 142 | 143 | model_config = asyncio.run(setup_engine()) 144 | request_logger = RequestLogger(max_log_len=2048) 145 | 146 | base_model_paths = [ 147 | BaseModelPath(name=MODEL_NAME.split("/")[1], model_path=MODEL_NAME) 148 | ] 149 | 150 | # Qwen chat template with exact formatting 151 | # TEMPLATE = """{%- for message in messages %} 152 | # {{- '<|im_start|>' + message.role + '\n' + message.content.strip() + '\n<|im_end|>\n' }} 153 | # {%- endfor %} 154 | # {%- if add_generation_prompt %} 155 | # {{- '<|im_start|>assistant\n' }} 156 | # {%- endif %}""" 157 | #NICEE 158 | # TEMPLATE = """{%- for message in messages %} 159 | # {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' }} 160 | # {%- endfor %} 161 | # <|im_start|>assistant 162 | # """ 163 | TEMPLATE = """{%- for message in messages %}{%- set content = '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' %}{%- if loop.last and message.role == 'assistant' %}{%- set content = '<|im_start|>' + message.role + '\n' + message.content %}{%- endif %}{{- content }}{%- endfor %}""" 164 | # TEMPLATE = """{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %} 165 | # {% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\n' }}{% endif %}""" 166 | 167 | # Set up completion endpoint 168 | api_server.completion = lambda s: OpenAIServingCompletion( 169 | engine, 170 | model_config=model_config, 171 | base_model_paths=base_model_paths, 172 | lora_modules=[], 173 | prompt_adapters=[], 174 | request_logger=request_logger, 175 | ) 176 | 177 | # Set up chat endpoint with tokenizer's chat template 178 | api_server.chat = lambda s: OpenAIServingChat( 179 | engine, 180 | model_config=model_config, 181 | base_model_paths=base_model_paths, 182 | lora_modules=[], 183 | prompt_adapters=[], 184 | request_logger=request_logger, 185 | response_role="assistant", 186 | chat_template=TEMPLATE, 187 | ) 188 | 189 | return web_app -------------------------------------------------------------------------------- /modal_vllm_qwq.py: -------------------------------------------------------------------------------- 1 | import modal 2 | import asyncio 3 | from contextlib import asynccontextmanager 4 | 5 | def download_model_to_image(model_dir, model_name, model_revision): 6 | import os 7 | from huggingface_hub import snapshot_download 8 | from transformers.utils import move_cache 9 | 10 | os.makedirs(model_dir, exist_ok=True) 11 | snapshot_download( 12 | model_name, 13 | revision=model_revision, 14 | 
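# note: pinning an exact model_revision keeps the baked-in weights reproducible across image rebuilds | 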
local_dir=model_dir, 15 | ignore_patterns=["*.pt", "*.bin"], # Using safetensors 16 | ) 17 | move_cache() 18 | 19 | MODEL_DIR = "/qwen" 20 | MODEL_NAME = "rawsh/q1-Qwen2.5-0.5b" 21 | MODEL_REVISION = "4b2cfa2ca5eff9c562f886bc16af888bdb8917cf" 22 | # MODEL_NAME = "rawsh/q1-Qwen2.5-0.5b-Instruct" 23 | # MODEL_REVISION = "ca6350a60de28f3f6c67f014188dd4b6f3642cde" 24 | # MODEL_NAME = "rawsh/q1-Qwen2.5-Math-1.5B" 25 | # MODEL_REVISION = "74ccf3b42e6e3495ed51b5f5308b5822733adb4d" 26 | 27 | vllm_image = ( 28 | modal.Image.debian_slim(python_version="3.10") 29 | .pip_install( 30 | "vllm==0.6.2", 31 | "torch==2.4.0", 32 | "transformers>=4.45", 33 | "ray==2.36.0", 34 | "hf-transfer==0.1.8", 35 | "huggingface_hub==0.25.0", 36 | ) 37 | .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"}) 38 | .run_function( 39 | download_model_to_image, 40 | timeout=60 * 20, 41 | secrets=[modal.Secret.from_name("hf-token")], 42 | kwargs={ 43 | "model_dir": MODEL_DIR, 44 | "model_name": MODEL_NAME, 45 | "model_revision": MODEL_REVISION, 46 | }, 47 | ) 48 | .env({"VLLM_ALLOW_LONG_MAX_MODEL_LEN": "1"}) 49 | ) 50 | 51 | app = modal.App("vllm-qwq-distill") 52 | 53 | N_GPU = 1 54 | MINUTES = 60 55 | HOURS = 60 * MINUTES 56 | 57 | async def get_model_config(engine): 58 | try: 59 | return await engine.get_model_config() 60 | except Exception as e: 61 | print(f"Error getting model config: {e}") 62 | raise 63 | 64 | @asynccontextmanager 65 | async def lifespan(app): 66 | try: 67 | await asyncio.sleep(0) 68 | yield 69 | finally: 70 | tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()] 71 | for task in tasks: 72 | task.cancel() 73 | await asyncio.gather(*tasks, return_exceptions=True) 74 | 75 | @app.function( 76 | image=vllm_image, 77 | gpu=modal.gpu.A10G(count=N_GPU), 78 | # gpu=modal.gpu.A100(), 79 | container_idle_timeout=5 * MINUTES, 80 | timeout=20 * MINUTES, 81 | allow_concurrent_inputs=1000, 82 | # allow_concurrent_inputs=32, 83 | secrets=[modal.Secret.from_name("vllm-token")] 84 | ) 85 | @modal.asgi_app() 86 | def serve(): 87 | import os 88 | import fastapi 89 | import vllm.entrypoints.openai.api_server as api_server 90 | from vllm.engine.arg_utils import AsyncEngineArgs 91 | from vllm.engine.async_llm_engine import AsyncLLMEngine 92 | from vllm.entrypoints.logger import RequestLogger 93 | from vllm.entrypoints.openai.serving_chat import OpenAIServingChat 94 | from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion 95 | from vllm.entrypoints.openai.serving_engine import BaseModelPath 96 | from vllm.usage.usage_lib import UsageContext 97 | from transformers import AutoTokenizer 98 | 99 | web_app = fastapi.FastAPI( 100 | title=f"OpenAI-compatible {MODEL_NAME} server", 101 | description="Run an OpenAI-compatible LLM server with vLLM on modal.com", 102 | version="0.0.1", 103 | docs_url="/docs", 104 | lifespan=lifespan 105 | ) 106 | 107 | http_bearer = fastapi.security.HTTPBearer( 108 | scheme_name="Bearer Token", 109 | description="See code for authentication details.", 110 | ) 111 | web_app.add_middleware( 112 | fastapi.middleware.cors.CORSMiddleware, 113 | allow_origins=["*"], 114 | allow_credentials=True, 115 | allow_methods=["*"], 116 | allow_headers=["*"], 117 | ) 118 | 119 | TOKEN = os.environ["API_TOKEN"] 120 | async def is_authenticated(api_key: str = fastapi.Security(http_bearer)): 121 | if api_key.credentials != TOKEN: 122 | raise fastapi.HTTPException( 123 | status_code=fastapi.status.HTTP_401_UNAUTHORIZED, 124 | detail="Invalid authentication credentials", 125 | ) 126 | return 
{"username": "authenticated_user"} 127 | 128 | router = fastapi.APIRouter(dependencies=[fastapi.Depends(is_authenticated)]) 129 | router.include_router(api_server.router) 130 | web_app.include_router(router) 131 | 132 | engine_args = AsyncEngineArgs( 133 | model=MODEL_DIR, 134 | tensor_parallel_size=N_GPU, 135 | gpu_memory_utilization=0.90, 136 | # max_model_len=8096, 137 | max_model_len=32384, 138 | enforce_eager=False, 139 | enable_prefix_caching=True 140 | ) 141 | 142 | engine = AsyncLLMEngine.from_engine_args( 143 | engine_args, usage_context=UsageContext.OPENAI_API_SERVER 144 | ) 145 | 146 | async def setup_engine(): 147 | model_config = await get_model_config(engine) 148 | return model_config 149 | 150 | model_config = asyncio.run(setup_engine()) 151 | request_logger = RequestLogger(max_log_len=2048) 152 | 153 | base_model_paths = [ 154 | BaseModelPath(name=MODEL_NAME.split("/")[1], model_path=MODEL_NAME) 155 | ] 156 | 157 | # Qwen chat template with exact formatting 158 | # TEMPLATE = """{%- for message in messages %} 159 | # {{- '<|im_start|>' + message.role + '\n' + message.content.strip() + '\n<|im_end|>\n' }} 160 | # {%- endfor %} 161 | # {%- if add_generation_prompt %} 162 | # {{- '<|im_start|>assistant\n' }} 163 | # {%- endif %}""" 164 | #NICEE 165 | # TEMPLATE = """{%- for message in messages %} 166 | # {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' }} 167 | # {%- endfor %} 168 | # <|im_start|>assistant 169 | # """ 170 | TEMPLATE = """{%- for message in messages %}{%- set content = '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' %}{%- if loop.last and message.role == 'assistant' %}{%- set content = '<|im_start|>' + message.role + '\n' + message.content %}{%- endif %}{{- content }}{%- endfor %}""" 171 | # TEMPLATE = """{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %} 172 | # {% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\n' }}{% endif %}""" 173 | 174 | # Set up completion endpoint 175 | api_server.completion = lambda s: OpenAIServingCompletion( 176 | engine, 177 | model_config=model_config, 178 | base_model_paths=base_model_paths, 179 | lora_modules=[], 180 | prompt_adapters=[], 181 | request_logger=request_logger, 182 | ) 183 | 184 | # Set up chat endpoint with tokenizer's chat template 185 | api_server.chat = lambda s: OpenAIServingChat( 186 | engine, 187 | model_config=model_config, 188 | base_model_paths=base_model_paths, 189 | lora_modules=[], 190 | prompt_adapters=[], 191 | request_logger=request_logger, 192 | response_role="assistant", 193 | chat_template=TEMPLATE, 194 | ) 195 | 196 | return web_app -------------------------------------------------------------------------------- /mcts/train_policy_sft_prime.py: -------------------------------------------------------------------------------- 1 | from unsloth import FastLanguageModel 2 | import torch 3 | from datasets import load_dataset 4 | from unsloth import is_bfloat16_supported 5 | from unsloth import UnslothTrainer, UnslothTrainingArguments 6 | from unsloth.chat_templates import get_chat_template 7 | from typing import List, Dict 8 | import json 9 | 10 | # Constants 11 | max_seq_length = 8192 12 | dtype = None 13 | load_in_4bit = False 14 | BATCH_SIZE = 10000 # Batch size for tokenization 15 | 16 | def process_conversation_batch(examples: Dict, tokenizer) -> 
List[str]: 17 | """Process a batch of conversations and return formatted chat templates.""" 18 | conversations = [] 19 | 20 | for system, conv_list in zip(examples['system'], examples['conversations']): 21 | try: 22 | # Basic validation 23 | if not conv_list or len(conv_list) < 2: 24 | continue 25 | if not (conv_list[0].get('from') == 'human' and conv_list[1].get('from') == 'gpt'): 26 | continue 27 | 28 | # Format messages 29 | formatted_msgs = [{"role": "system", "content": system}] 30 | formatted_msgs.extend([ 31 | {"role": "user" if msg['from'] == 'human' else "assistant", "content": msg['value']} 32 | for msg in conv_list 33 | ]) 34 | conversations.append(formatted_msgs) 35 | 36 | except (json.JSONDecodeError, AttributeError, KeyError): 37 | continue 38 | 39 | # Apply chat template without tokenization 40 | return [tokenizer.apply_chat_template(conv, tokenize=False, add_generation_prompt=False) 41 | for conv in conversations] 42 | 43 | def filter_by_length(texts: List[str], tokenizer, max_length: int) -> List[str]: 44 | """Filter texts by tokenized length.""" 45 | tokenized = tokenizer(texts, truncation=False, padding=False) 46 | return [text for i, text in enumerate(texts) 47 | if len(tokenized["input_ids"][i]) < max_length] 48 | 49 | def train_sft(): 50 | # Load base and instruct models 51 | model, tokenizer = FastLanguageModel.from_pretrained( 52 | model_name="unsloth/Qwen2.5-0.5B", 53 | max_seq_length=max_seq_length, 54 | dtype=dtype, 55 | load_in_4bit=load_in_4bit, 56 | ) 57 | model_instruct, tokenizer_instruct = FastLanguageModel.from_pretrained( 58 | model_name="unsloth/Qwen2.5-0.5B-Instruct", 59 | max_seq_length=max_seq_length, 60 | dtype=dtype, 61 | load_in_4bit=load_in_4bit, 62 | ) 63 | 64 | # Transfer embeddings 65 | TRANSFER = True 66 | if TRANSFER: 67 | base_embeddings = model.get_input_embeddings() 68 | instruct_embeddings = model_instruct.get_input_embeddings() 69 | chat_tokens = ["<|im_start|>", "<|im_end|>", "system", "assistant", "user"] 70 | 71 | with torch.no_grad(): 72 | for token in chat_tokens: 73 | try: 74 | instruct_id = tokenizer_instruct.convert_tokens_to_ids(token) 75 | base_id = tokenizer.convert_tokens_to_ids(token) 76 | if instruct_id != tokenizer_instruct.unk_token_id and base_id != tokenizer.unk_token_id: 77 | base_embeddings.weight[base_id] = instruct_embeddings.weight[instruct_id].clone() 78 | print(f"Transferred embedding for token: {token}") 79 | else: 80 | print(f"Warning: Token {token} not found in one of the vocabularies") 81 | except Exception as e: 82 | print(f"Error transferring token {token}: {str(e)}") 83 | 84 | # Cleanup 85 | import gc 86 | del model_instruct, tokenizer_instruct 87 | gc.collect() 88 | torch.cuda.empty_cache() 89 | 90 | # Add LoRA adapters 91 | model = FastLanguageModel.get_peft_model( 92 | model, 93 | r=128, 94 | target_modules=[ 95 | "q_proj", "k_proj", "v_proj", "o_proj", 96 | "gate_proj", "up_proj", "down_proj", 97 | "embed_tokens", "lm_head", 98 | ], 99 | lora_alpha=32, 100 | lora_dropout=0, 101 | bias="none", 102 | use_gradient_checkpointing="unsloth", 103 | random_state=3407, 104 | use_rslora=True, 105 | loftq_config=None, 106 | ) 107 | 108 | # Setup tokenizer 109 | tokenizer = get_chat_template(tokenizer, chat_template="qwen-2.5") 110 | tokenizer.eos_token = "<|im_end|>" 111 | 112 | # Load dataset 113 | dataset = load_dataset("PRIME-RL/Eurus-2-SFT-Data", split="train") 114 | 115 | def formatting_prompts_func(examples): 116 | # Process conversations in the current batch 117 | texts = 
process_conversation_batch(examples, tokenizer) 118 | 119 | # Filter by tokenized length 120 | texts_filtered = filter_by_length(texts, tokenizer, max_seq_length) 121 | 122 | if len(texts) != len(texts_filtered): 123 | print(f"Filtered {len(texts) - len(texts_filtered)} examples due to length") 124 | 125 | return {"text": texts_filtered} 126 | 127 | # Process dataset 128 | dataset = dataset.map( 129 | formatting_prompts_func, 130 | batched=True, 131 | batch_size=BATCH_SIZE, 132 | remove_columns=dataset.column_names 133 | ) 134 | 135 | print(f"Final dataset size: {len(dataset)}") 136 | 137 | # Configure trainer 138 | trainer = UnslothTrainer( 139 | model=model, 140 | tokenizer=tokenizer, 141 | train_dataset=dataset, 142 | dataset_text_field="text", 143 | max_seq_length=max_seq_length, 144 | dataset_num_proc=64, 145 | args=UnslothTrainingArguments( 146 | learning_rate=5e-6, 147 | embedding_learning_rate=5e-7, 148 | per_device_train_batch_size=8, 149 | gradient_accumulation_steps=8, 150 | lr_scheduler_type="cosine", 151 | num_train_epochs=3, 152 | warmup_ratio=0.1, 153 | max_seq_length=max_seq_length, 154 | fp16=not is_bfloat16_supported(), 155 | bf16=is_bfloat16_supported(), 156 | optim="adamw_8bit", 157 | weight_decay=0.01, 158 | logging_steps=1, 159 | seed=3407, 160 | output_dir="outputs", 161 | report_to="wandb", 162 | run_name="eurus-sft", 163 | hub_strategy="every_save", 164 | save_strategy="steps", 165 | save_steps=100, 166 | hub_model_id="rawsh/Qwen2.5-0.5b-Eurus-2-SFT" 167 | ), 168 | ) 169 | 170 | # Print GPU stats 171 | gpu_stats = torch.cuda.get_device_properties(0) 172 | start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3) 173 | max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3) 174 | print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.") 175 | print(f"{start_gpu_memory} GB of memory reserved.") 176 | 177 | # Train 178 | trainer_stats = trainer.train() 179 | 180 | # Show memory stats 181 | used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3) 182 | used_memory_for_lora = round(used_memory - start_gpu_memory, 3) 183 | used_percentage = round(used_memory/max_memory*100, 3) 184 | lora_percentage = round(used_memory_for_lora/max_memory*100, 3) 185 | 186 | print(f"Training time: {round(trainer_stats.metrics['train_runtime']/60, 2)} minutes") 187 | print(f"Peak memory usage: {used_memory} GB ({used_percentage}%)") 188 | print(f"LoRA training memory: {used_memory_for_lora} GB ({lora_percentage}%)") 189 | 190 | # Save to HuggingFace Hub 191 | model.push_to_hub_merged( 192 | "rawsh/Qwen2.5-0.5b-Eurus-2-SFT", 193 | tokenizer, 194 | save_method="merged_16bit", 195 | ) 196 | 197 | if __name__ == "__main__": 198 | train_sft() -------------------------------------------------------------------------------- /mcts/train_policy_sft_metamath.py: -------------------------------------------------------------------------------- 1 | from unsloth import FastLanguageModel 2 | import torch 3 | import wandb 4 | from datasets import load_dataset 5 | from unsloth import is_bfloat16_supported 6 | from unsloth import UnslothTrainer, UnslothTrainingArguments 7 | from unsloth.chat_templates import get_chat_template 8 | 9 | # Constants 10 | max_seq_length = 2048 # Choose any! We auto support RoPE Scaling internally! 11 | dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+ 12 | load_in_4bit = False # Use 4bit quantization to reduce memory usage. Can be False. 
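# --- editor's note (illustrative, not part of the original script) ---------------
# format_answer() below normalizes MetaMathQA responses. A hypothetical input
#   "Step one.\nStep two.\n#### 7\nThe answer is: 7"
# comes out as
#   "Step one.\n\nStep two.\nThe answer is: 7"
# i.e. the reasoning steps are re-joined with blank lines and the trailing
# '#### 7' marker is folded into a final "The answer is: 7" line. Responses
# without a '####' marker keep their text but have their last two lines merged;
# single-line responses without '####' are dropped (the function returns None).
# ----------------------------------------------------------------------------------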
13 | 14 | first_type1 = True 15 | first_type2 = True 16 | 17 | def format_answer(response): 18 | """Extract answer from #### pattern and format response.""" 19 | global first_type1 20 | global first_type2 21 | 22 | # Split at #### and get everything before it 23 | parts = response.split('####') 24 | if len(parts) < 2: 25 | # combine the last two steps 26 | steps = parts[0].strip().split("\n") 27 | if len(steps) > 1: 28 | steps[-2] = steps[-2] + f"\n{steps[-1]}" 29 | steps = steps[:-1] 30 | sol = "\n\n".join(steps) 31 | 32 | if (first_type1): 33 | print(response) 34 | first_type1 = False 35 | 36 | return sol 37 | else: 38 | return None 39 | 40 | solution = "\n\n".join(parts[0].strip().split("\n")) 41 | answer = parts[1].split('The answer is:') 42 | answer = answer[0].strip() 43 | sol = f"{solution}\nThe answer is: {answer}" 44 | 45 | if (first_type2): 46 | print(response) 47 | first_type2 = False 48 | 49 | return sol 50 | 51 | def train_sft(): 52 | # Load base and instruct models 53 | model, tokenizer = FastLanguageModel.from_pretrained( 54 | model_name = "unsloth/Qwen2.5-0.5B", 55 | max_seq_length = max_seq_length, 56 | dtype = dtype, 57 | load_in_4bit = load_in_4bit, 58 | ) 59 | 60 | model_instruct, tokenizer_instruct = FastLanguageModel.from_pretrained( 61 | model_name = "unsloth/Qwen2.5-0.5B-Instruct", 62 | max_seq_length = max_seq_length, 63 | dtype = dtype, 64 | load_in_4bit = load_in_4bit, 65 | ) 66 | 67 | # Transfer chat token embeddings from instruct to base model 68 | base_embeddings = model.get_input_embeddings() 69 | instruct_embeddings = model_instruct.get_input_embeddings() 70 | chat_tokens = ["<|im_start|>", "<|im_end|>", "system", "assistant", "user"] 71 | with torch.no_grad(): 72 | for token in chat_tokens: 73 | try: 74 | instruct_id = tokenizer_instruct.convert_tokens_to_ids(token) 75 | base_id = tokenizer.convert_tokens_to_ids(token) 76 | if instruct_id != tokenizer_instruct.unk_token_id and base_id != tokenizer.unk_token_id: 77 | base_embeddings.weight[base_id] = instruct_embeddings.weight[instruct_id].clone() 78 | print(f"Transferred embedding for token: {token}") 79 | else: 80 | print(f"Warning: Token {token} not found in one of the vocabularies") 81 | except Exception as e: 82 | print(f"Error transferring token {token}: {str(e)}") 83 | 84 | # Add LoRA adapters 85 | model = FastLanguageModel.get_peft_model( 86 | model, 87 | r = 128, # Choose any number > 0! 
Suggested 8, 16, 32, 64, 128 88 | target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", 89 | "gate_proj", "up_proj", "down_proj", 90 | "embed_tokens", "lm_head",], # Add for continual pretraining 91 | lora_alpha = 32, 92 | lora_dropout = 0, # Supports any, but = 0 is optimized 93 | bias = "none", # Supports any, but = "none" is optimized 94 | use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context 95 | random_state = 3407, 96 | use_rslora = True, # We support rank stabilized LoRA 97 | loftq_config = None, # And LoftQ 98 | ) 99 | 100 | # Set up tokenizer with chat template 101 | tokenizer = get_chat_template( 102 | tokenizer, 103 | chat_template = "qwen-2.5", 104 | ) 105 | tokenizer.eos_token = "<|im_end|>" 106 | print(tokenizer.eos_token) 107 | print(tokenizer.pad_token) 108 | 109 | # Load and process dataset 110 | dataset = load_dataset("meta-math/MetaMathQA", split="train") 111 | 112 | def formatting_prompts_func(examples): 113 | conversations = [] 114 | for query, response in zip(examples['query'], examples['response']): 115 | formatted_response = format_answer(response) 116 | if formatted_response is None: 117 | continue 118 | 119 | conversation = [ 120 | {"role": "user", "content": query}, 121 | {"role": "assistant", "content": formatted_response} 122 | ] 123 | conversations.append(conversation) 124 | 125 | texts = [tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False) 126 | for convo in conversations] 127 | return {"text": texts} 128 | 129 | dataset = dataset.map(formatting_prompts_func, batched=True, remove_columns=dataset.column_names) 130 | 131 | # Debug tokenizer output - show examples 132 | print("Example of tokenized output:") 133 | print(dataset[5]["text"]) 134 | print("\nAnother example:") 135 | print(dataset[100]["text"]) 136 | 137 | # Configure trainer 138 | trainer = UnslothTrainer( 139 | model = model, 140 | tokenizer = tokenizer, 141 | train_dataset = dataset, 142 | dataset_text_field = "text", 143 | max_seq_length = max_seq_length, 144 | dataset_num_proc = 8, 145 | 146 | args = UnslothTrainingArguments( 147 | learning_rate = 5e-5, 148 | embedding_learning_rate = 5e-6, 149 | per_device_train_batch_size = 8, # With gradient_accumulation_steps=8 this gives effective batch size 64 150 | gradient_accumulation_steps = 8, 151 | lr_scheduler_type = "cosine", 152 | num_train_epochs = 3, 153 | warmup_ratio = 0.1, 154 | max_seq_length = 2048, 155 | fp16 = not is_bfloat16_supported(), 156 | bf16 = is_bfloat16_supported(), 157 | optim = "adamw_8bit", 158 | weight_decay = 0.01, 159 | logging_steps = 1, 160 | seed = 3407, 161 | output_dir = "outputs", 162 | report_to = "wandb", 163 | run_name = "metamath", 164 | hub_strategy = "every_save", 165 | save_strategy = "steps", 166 | save_steps = 100, 167 | hub_model_id = "rawsh/MetaMath-Qwen2.5-0.5b" 168 | ), 169 | ) 170 | 171 | # Set up wandb 172 | # wandb.login(key="YOUR_WANDB_KEY") # Replace with your key 173 | # wandb.init(project='metamath') 174 | 175 | # Print initial GPU stats 176 | gpu_stats = torch.cuda.get_device_properties(0) 177 | start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3) 178 | max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3) 179 | print(f"GPU = {gpu_stats.name}. 
Max memory = {max_memory} GB.") 180 | print(f"{start_gpu_memory} GB of memory reserved.") 181 | 182 | # Train 183 | trainer_stats = trainer.train() 184 | 185 | # Show final memory and time stats 186 | used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3) 187 | used_memory_for_lora = round(used_memory - start_gpu_memory, 3) 188 | used_percentage = round(used_memory/max_memory*100, 3) 189 | lora_percentage = round(used_memory_for_lora/max_memory*100, 3) 190 | 191 | print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.") 192 | print(f"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.") 193 | print(f"Peak reserved memory = {used_memory} GB.") 194 | print(f"Peak reserved memory for training = {used_memory_for_lora} GB.") 195 | print(f"Peak reserved memory % of max memory = {used_percentage} %.") 196 | print(f"Peak reserved memory for training % of max memory = {lora_percentage} %.") 197 | 198 | # Save model to HuggingFace Hub 199 | model.push_to_hub_merged( 200 | "rawsh/MetaMath-Qwen2.5-0.5b", # Replace with your username 201 | tokenizer, 202 | save_method="merged_16bit", 203 | ) 204 | 205 | if __name__ == "__main__": 206 | train_sft() -------------------------------------------------------------------------------- /backup/modal_reward_save.py: -------------------------------------------------------------------------------- 1 | import modal 2 | 3 | image = ( 4 | modal.Image.debian_slim() 5 | .pip_install("torch") 6 | .pip_install("transformers") 7 | .pip_install("accelerate") 8 | ) 9 | app = modal.App("dankreward", image=image) 10 | 11 | 12 | with image.imports(): 13 | from typing import List, Dict, Tuple 14 | import asyncio 15 | import torch 16 | from time import perf_counter as pc 17 | import copy 18 | from transformers import AutoModelForSequenceClassification, AutoTokenizer 19 | # from lib import extract_tensors, test 20 | # print(test()) 21 | 22 | @app.function( 23 | keep_warm=1 24 | ) 25 | @modal.web_endpoint(method="POST") 26 | def upload(model_id): 27 | def extract_tensors(m: torch.nn.Module) -> Tuple[torch.nn.Module, List[Dict]]: 28 | """ 29 | Remove the tensors from a PyTorch model, convert them to NumPy 30 | arrays, and return the stripped model and tensors. 31 | """ 32 | tensors = [] 33 | for _, module in m.named_modules(): 34 | # Store the tensors in Python dictionaries 35 | params = { 36 | name: torch.clone(param).detach().numpy() 37 | for name, param in module.named_parameters(recurse=False) 38 | } 39 | buffers = { 40 | name: torch.clone(buf).detach().numpy() 41 | for name, buf in module.named_buffers(recurse=False) 42 | } 43 | tensors.append({"params": params, "buffers": buffers}) 44 | 45 | # Make a copy of the original model and strip all tensors and 46 | # temporary buffers out of the copy. 47 | m_copy = copy.deepcopy(m) 48 | for _, module in m_copy.named_modules(): 49 | for name in ( 50 | [name for name, _ in module.named_parameters(recurse=False)] 51 | + [name for name, _ in module.named_buffers(recurse=False)]): 52 | setattr(module, name, None) 53 | 54 | # Make sure the copy is configured for inference. 55 | m_copy.train(False) 56 | return m_copy, tensors 57 | 58 | # Create a memory snapshot with the model loaded in CPU memory. 
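    # --- editor's note (sketch, not in the original file) ------------------------
    # extract_tensors() above detaches every parameter and buffer into NumPy arrays
    # and returns a stripped module skeleton; the intent is that the skeleton plus
    # the arrays survive a CPU memory snapshot and can be rehydrated on the GPU via
    # replace_tensors() in the Embedder class below. Rough usage, with hypothetical
    # variable names:
    #
    #   skeleton, weights = extract_tensors(model)  # CPU side, snapshot friendly
    #   skeleton.to("cuda")                         # move the empty module graph
    #   replace_tensors(skeleton, weights)          # re-attach tensors on the GPU
    # ------------------------------------------------------------------------------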
59 | print("save state") 60 | 61 | # faster model loading 62 | dtype = torch.float16 63 | start = pc() 64 | self.model = AutoModelForSequenceClassification.from_pretrained(self.model_id, device_map="cpu", 65 | trust_remote_code=True, torch_dtype=dtype, use_safetensors=True) 66 | elapsed = pc() - start 67 | print(f"loading model on cpu took {elapsed} seconds") 68 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, use_fast=True) 69 | print("extracting tensors") 70 | m_copy, tensors = extract_tensors(self.model) 71 | print("save state") 72 | 73 | # faster model loading 74 | dtype = torch.float16 75 | start = pc() 76 | self.model = AutoModelForSequenceClassification.from_pretrained(self.model_id, device_map="cpu", 77 | trust_remote_code=True, torch_dtype=dtype, use_safetensors=True) 78 | elapsed = pc() - start 79 | print(f"loading model on cpu took {elapsed} seconds") 80 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, use_fast=True) 81 | print("extracting tensors") 82 | m_copy, tensors = extract_tensors(self.model) 83 | for _, module in m_copy.named_modules(): 84 | for name in ( 85 | [name for name, _ in module.named_parameters(recurse=False)] 86 | + [name for name, _ in module.named_buffers(recurse=False)]): 87 | setattr(module, name, None) 88 | 89 | # Make sure the copy is configured for inference. 90 | m_copy.train(False) 91 | return m_copy, tensors 92 | 93 | # Create a memory snapshot with the model loaded in CPU memory. 94 | print("save state") 95 | 96 | # faster model loading 97 | dtype = torch.float16 98 | start = pc() 99 | self.model = AutoModelForSequenceClassification.from_pretrained(self.model_id, device_map="cpu", 100 | trust_remote_code=True, torch_dtype=dtype, use_safetensors=True) 101 | elapsed = pc() - start 102 | print(f"loading model on cpu took {elapsed} seconds") 103 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, use_fast=True) 104 | print("extracting tensors") 105 | m_copy, tensors = extract_tensors(self.model) 106 | return True 107 | 108 | 109 | @app.cls( 110 | gpu=modal.gpu.L4(), 111 | container_idle_timeout=10, 112 | volume=modal.voume("my-test-volume") 113 | ) 114 | class Embedder: 115 | model_id = "RLHFlow/ArmoRM-Llama3-8B-v0.1" 116 | device = "cuda" 117 | 118 | @modal.build() 119 | def build(self): 120 | # cache 121 | print("build") 122 | dtype = torch.bfloat16 123 | with torch.device("cpu"): 124 | model = AutoModelForSequenceClassification.from_pretrained(self.model_id, 125 | trust_remote_code=True, torch_dtype=dtype, use_safetensors=True) 126 | 127 | # @modal.enter(snap=True) 128 | @modal.enter() 129 | def load(self): 130 | # move back 131 | def replace_tensors(m: torch.nn.Module, tensors: List[Dict]): 132 | """ 133 | Restore the tensors that extract_tensors() stripped out of a 134 | PyTorch model. 135 | :param no_parameters_objects: Skip wrapping tensors in 136 | ``torch.nn.Parameters`` objects (~20% speedup, may impact 137 | some models) 138 | """ 139 | with torch.device("cuda"): 140 | modules = [module for _, module in m.named_modules()] 141 | for module, tensor_dict in zip(modules, tensors): 142 | # There are separate APIs to set parameters and buffers. 
143 | for name, array in tensor_dict["params"].items(): 144 | module.register_parameter(name, 145 | torch.nn.Parameter(torch.as_tensor(array))) 146 | for name, array in tensor_dict["buffers"].items(): 147 | module.register_buffer(name, torch.as_tensor(array)) 148 | 149 | # Load tensors into the model's graph of Python objects 150 | 151 | # self.model = m_copy 152 | print("moving mock to cuda") 153 | start = pc() 154 | m_copy.to("cuda") 155 | elapsed = pc() - start 156 | print(f"moving mock to cuda took {elapsed} seconds") 157 | 158 | print("replacing tensors") 159 | start = pc() 160 | replace_tensors(m_copy, tensors) 161 | elapsed = pc() - start 162 | print(f"replacing took {elapsed} seconds") 163 | self.model = m_copy 164 | 165 | input_ids = self.tokenizer.apply_chat_template( 166 | [{"role": "user", "content": "test"}, {"role": "assistant", "content": "wow"}], 167 | return_tensors="pt", 168 | padding=True, 169 | truncation=True, 170 | max_length=4096, 171 | ).to("cuda") 172 | with torch.no_grad(): 173 | output = self.model(input_ids) 174 | print(output) 175 | 176 | # @modal.enter(snap=False) 177 | # @modal.enter() 178 | # def setup(self): 179 | # Move the model to a GPU before doing any work. 180 | # print("loaded from snapshot") 181 | # self.model.to("cuda") 182 | 183 | # # faster model loading 184 | # with torch.device("cuda"): 185 | # self.model = AutoModelForSequenceClassification.from_pretrained(self.model_id, 186 | # trust_remote_code=True, torch_dtype=dtype, use_safetensors=True) 187 | # self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, use_fast=True) 188 | 189 | @modal.web_endpoint(method="POST", docs=True) 190 | def score_output(self, messages: List[Dict[str, str]]): 191 | print("batched") 192 | input_ids = self.tokenizer.apply_chat_template( 193 | messages, 194 | return_tensors="pt", 195 | padding=True, 196 | truncation=True, 197 | max_length=4096, 198 | ).to("cuda") 199 | with torch.no_grad(): 200 | output = self.model(input_ids) 201 | print(output) 202 | float_output = output.score.float() 203 | print("Score:", float_output.item()) 204 | return float_output.item() 205 | 206 | 207 | # @app.local_entrypoint() 208 | # async def main(): 209 | # # score the messages 210 | # prompt = 'What are some synonyms for the word "beautiful"?' 
211 | # response1 = 'Nicely, Beautifully, Handsome, Stunning, Wonderful, Gorgeous, Pretty, Stunning, Elegant' 212 | # response2 = 'bad' 213 | # messages1 = [{"role": "user", "content": prompt}, {"role": "assistant", "content": response1}] 214 | # messages2 = [{"role": "user", "content": prompt}, {"role": "assistant", "content": response2}] 215 | # m1 = Embedder().score_output(messages1) 216 | # m2 = Embedder().score_output(messages2) 217 | # res = await asyncio.gather(*[m1,m2]) 218 | # print(response1, res[0]) 219 | # print(response2, res[1]) -------------------------------------------------------------------------------- /eval_reward.py: -------------------------------------------------------------------------------- 1 | import aiohttp 2 | import asyncio 3 | import json 4 | from tqdm.asyncio import tqdm_asyncio 5 | from tqdm import tqdm 6 | from datasets import load_dataset 7 | import random 8 | from datetime import datetime 9 | 10 | MODAL_ENDPOINT = "https://rawsh--reward-api-model-score.modal.run" 11 | MAX_CONCURRENT = 32 12 | BATCH_SIZE = 10 13 | 14 | async def get_score(sem, session, messages, question_id, option_idx, answer, is_correct): 15 | """Get reward model score for a completion.""" 16 | async with sem: 17 | try: 18 | async with session.post( 19 | MODAL_ENDPOINT, 20 | json={"messages": messages}, 21 | headers={"Content-Type": "application/json"}, 22 | timeout=aiohttp.ClientTimeout(total=20) 23 | ) as response: 24 | if response.status != 200: 25 | print(f"Error {response.status}: {await response.text()}") 26 | score = 0 27 | else: 28 | result = await response.json() 29 | score = result.get('score', 0) 30 | 31 | return { 32 | 'question_id': question_id, 33 | 'option_idx': option_idx, 34 | 'answer': answer, 35 | 'score': float(score), 36 | 'is_correct': is_correct 37 | } 38 | except Exception as e: 39 | print(f"Exception in get_score: {str(e)}") 40 | return { 41 | 'question_id': question_id, 42 | 'option_idx': option_idx, 43 | 'answer': answer, 44 | 'score': 0, 45 | 'is_correct': is_correct 46 | } 47 | 48 | async def evaluate_question(sem, session, question_data, question_idx): 49 | """Evaluate a single question with all its options.""" 50 | question = question_data['question'] 51 | choices = question_data['mc1_targets']['choices'] 52 | correct_idx = question_data['mc1_targets']['labels'].index(1) 53 | 54 | print(f"\nEvaluating Question {question_idx}:") 55 | print(f"Question: {question}") 56 | print(f"Correct Answer: {chr(65 + correct_idx)}) {choices[correct_idx]}") 57 | 58 | all_requests = [] 59 | 60 | # Evaluate each option 61 | for idx, choice in enumerate(choices): 62 | formatted_question = f"{question}\n\nOptions:\n" 63 | for i, c in enumerate(choices): 64 | formatted_question += f"{chr(65 + i)}) {c}\n" 65 | 66 | messages = [ 67 | {"role": "user", "content": ("You are an expert at truthful reasoning and you always pick the most accurate answer. " 68 | "Think step by step and output your reasoning followed by your final answer.\n\n" + 69 | formatted_question)}, 70 | {"role": "assistant", "content": f"After careful consideration, I choose option {chr(65 + idx)}. 
{choice}"} 71 | ] 72 | 73 | all_requests.append(get_score( 74 | sem, 75 | session, 76 | messages, 77 | question_idx, 78 | idx, 79 | choice, 80 | idx == correct_idx 81 | )) 82 | 83 | return await asyncio.gather(*all_requests) 84 | 85 | async def process_batch(sem, session, batch_data, start_idx): 86 | """Process a batch of questions.""" 87 | batch_requests = [ 88 | evaluate_question(sem, session, example, idx) 89 | for idx, example in enumerate(batch_data, start_idx) 90 | ] 91 | return await tqdm_asyncio.gather(*batch_requests) 92 | 93 | async def evaluate_all(session, dataset): 94 | """Evaluate all questions in the dataset using batching.""" 95 | sem = asyncio.Semaphore(MAX_CONCURRENT) 96 | 97 | # Convert dataset to list and take same subset as original code 98 | dataset_list = list(dataset) 99 | random.seed(42) # Same seed as original code 100 | random.shuffle(dataset_list) 101 | dataset_list = dataset_list[:100] # Same subset size as original code 102 | 103 | results = [] 104 | print(f"\nEvaluating {len(dataset_list)} questions...") 105 | 106 | # Process in batches 107 | for i in range(0, len(dataset_list), BATCH_SIZE): 108 | batch_data = dataset_list[i:i + BATCH_SIZE] 109 | print(f"\nProcessing batch {i//BATCH_SIZE + 1}/{(len(dataset_list) + BATCH_SIZE - 1)//BATCH_SIZE}") 110 | 111 | batch_results = await process_batch(sem, session, batch_data, i) 112 | results.extend(batch_results) 113 | 114 | await asyncio.sleep(1) # Small delay between batches 115 | 116 | return results, dataset_list 117 | 118 | async def main(): 119 | try: 120 | # Load TruthfulQA dataset 121 | dataset = load_dataset("truthful_qa", "multiple_choice") 122 | validation_set = dataset["validation"] 123 | print(f"Loaded {len(validation_set)} questions from TruthfulQA validation set") 124 | 125 | # Configure session 126 | connector = aiohttp.TCPConnector(limit=MAX_CONCURRENT, force_close=True) 127 | timeout = aiohttp.ClientTimeout(total=60) 128 | 129 | timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") 130 | 131 | async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session: 132 | all_results, dataset_list = await evaluate_all(session, validation_set) 133 | 134 | if all_results: 135 | # Process results by question 136 | results_by_question = {} 137 | for question_results in all_results: 138 | for result in question_results: 139 | qid = result['question_id'] 140 | if qid not in results_by_question: 141 | results_by_question[qid] = [] 142 | results_by_question[qid].append(result) 143 | 144 | # Calculate statistics 145 | total_questions = len(results_by_question) 146 | rank_1_count = 0 147 | total_correct_rank = 0 148 | total_score_diff = 0 149 | total_correct_score = 0 150 | total_best_score = 0 151 | 152 | print("\nDetailed Results:") 153 | for qid, scores in results_by_question.items(): 154 | # Sort by score 155 | scores.sort(key=lambda x: x['score'], reverse=True) 156 | 157 | # Find correct answer details 158 | correct_scores = [s for s in scores if s['is_correct']] 159 | if correct_scores: 160 | correct_score = correct_scores[0] 161 | correct_rank = scores.index(correct_score) + 1 162 | 163 | if correct_rank == 1: 164 | rank_1_count += 1 165 | 166 | total_correct_rank += correct_rank 167 | total_score_diff += scores[0]['score'] - correct_score['score'] 168 | total_correct_score += correct_score['score'] 169 | total_best_score += scores[0]['score'] 170 | 171 | print(f"\nQuestion {qid}:") 172 | print(f"Correct answer rank: {correct_rank} out of {len(scores)}") 173 | print(f"Correct score: 
{correct_score['score']:.4f}") 174 | print(f"Best score: {scores[0]['score']:.4f}") 175 | print(f"Score difference: {scores[0]['score'] - correct_score['score']:.4f}") 176 | 177 | print("\nSummary Statistics:") 178 | print(f"Average rank of correct answer: {total_correct_rank/total_questions:.2f}") 179 | print(f"Times correct answer ranked first: {rank_1_count}/{total_questions}") 180 | print(f"Average score difference from best: {total_score_diff/total_questions:.4f}") 181 | print(f"Average correct answer score: {total_correct_score/total_questions:.4f}") 182 | print(f"Average best score: {total_best_score/total_questions:.4f}") 183 | 184 | # Save results 185 | output_file = f'truthfulqa_reward_results_{timestamp}.json' 186 | with open(output_file, 'w') as f: 187 | json.dump({ 188 | 'results_by_question': results_by_question, 189 | 'summary': { 190 | 'total_questions': total_questions, 191 | 'rank_1_count': rank_1_count, 192 | 'avg_correct_rank': total_correct_rank/total_questions, 193 | 'avg_score_diff': total_score_diff/total_questions, 194 | 'avg_correct_score': total_correct_score/total_questions, 195 | 'avg_best_score': total_best_score/total_questions 196 | } 197 | }, f, indent=2) 198 | print(f"\nDetailed results saved to {output_file}") 199 | 200 | except Exception as e: 201 | print(f"Error in main: {str(e)}") 202 | raise 203 | finally: 204 | if 'connector' in locals() and hasattr(connector, 'close'): 205 | await connector.close() 206 | 207 | if __name__ == "__main__": 208 | asyncio.run(main()) -------------------------------------------------------------------------------- /mcts/greedy_sample.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from openai import AsyncOpenAI 3 | from datasets import load_dataset 4 | from tqdm import tqdm 5 | import random 6 | import aiohttp 7 | from typing import Optional, Tuple 8 | 9 | # Configuration 10 | POLICY_URL = 'https://rawsh--vllm-qwq-distill-serve.modal.run/v1/' 11 | API_KEY = '9FF74944EED19865193F979942FB1' 12 | MODEL_NAME = 'q1-Qwen2.5-0.5b' 13 | # MODEL_NAME = 'q1-Qwen2.5-0.5b-Instruct' 14 | # MODEL_NAME = 'q1-Qwen2.5-Math-1.5B' 15 | MAX_CONCURRENT_REQUESTS = 128 16 | MAX_RETRIES = 2 17 | # REQUEST_TIMEOUT = 300 # seconds 18 | REQUEST_TIMEOUT = 3000 # seconds 19 | MAX_TOKENS = 10000 # Maximum tokens per response 20 | 21 | class TokenCounter: 22 | def __init__(self): 23 | self.total_tokens = 0 24 | self.total_possible = 0 25 | self.pbar = None 26 | 27 | def init_progress(self, total_problems): 28 | max_possible = total_problems * MAX_TOKENS 29 | self.pbar = tqdm(total=max_possible, desc="Token usage", unit="tokens", position=1, leave=True) 30 | self.total_possible = max_possible 31 | 32 | def update(self, tokens: int): 33 | self.total_tokens += tokens 34 | self.pbar.update(tokens) 35 | return tokens 36 | 37 | def get_stats(self): 38 | return f"{self.total_tokens}/{self.total_possible} tokens ({(self.total_tokens/self.total_possible)*100:.1f}%)" 39 | 40 | async def solve_problem(client, question: str, token_counter: TokenCounter, max_tokens: int = MAX_TOKENS) -> Tuple[Optional[str], int]: 41 | """Generate a solution for a given math problem with retries, timeout, and token tracking.""" 42 | for attempt in range(MAX_RETRIES + 1): 43 | try: 44 | async with asyncio.timeout(REQUEST_TIMEOUT): 45 | response = await client.chat.completions.create( 46 | model=MODEL_NAME, 47 | messages=[ 48 | {"role": "user", "content": question}, 49 | {"role": "assistant", "content": ""} 50 | ], 51 | 
max_tokens=max_tokens, 52 | stop=["<|endoftext|>", "<|im_end|>"], 53 | # temperature=0.0, 54 | # temperature=0.1, 55 | temperature=1.0, 56 | extra_body={ 57 | # "repetition_penalty": 1.05, 58 | # "repetition_penalty": 1.10, 59 | # "top_p": 0.8, 60 | # "top_k": 20, 61 | "top_p": 1.0, 62 | "top_k": -1, 63 | # "frequency_penalty": 0.05, 64 | # "presence_penalty": 0.05, 65 | # "frequency_penalty": 0.15, 66 | # "presence_penalty": 0.15, 67 | # "min_p": 0.05, 68 | "min_p": 0.00, 69 | }, 70 | stream=True 71 | ) 72 | 73 | full_response = "" 74 | total_tokens = 0 75 | 76 | # Print question at the start 77 | print(f"\033[K\nQ: {question}\nA: ", end="", flush=True) 78 | 79 | async for chunk in response: 80 | if chunk.choices[0].delta.content: 81 | content = chunk.choices[0].delta.content 82 | full_response += content 83 | # Print content as it arrives 84 | # print(content.replace("\n", "\\n"), end="", flush=True) 85 | token_counter.update(1) 86 | total_tokens += 1 87 | 88 | # Print newline after response 89 | print("\n") 90 | return full_response.strip(), total_tokens 91 | 92 | except asyncio.TimeoutError: 93 | if attempt < MAX_RETRIES: 94 | print(f"\nTimeout on attempt {attempt + 1}, retrying...", flush=True) 95 | await asyncio.sleep(1 * (attempt + 1)) 96 | else: 97 | print(f"\nTimeout after {MAX_RETRIES} retries for question: {question[:100]}...", flush=True) 98 | return "", 0 99 | except Exception as e: 100 | if attempt < MAX_RETRIES: 101 | print(f"\nError on attempt {attempt + 1}: {e}, retrying...", flush=True) 102 | await asyncio.sleep(1 * (attempt + 1)) 103 | else: 104 | print(f"\nError after {MAX_RETRIES} retries: {e}", flush=True) 105 | return "", 0 106 | return "", 0 107 | 108 | async def process_problem(client, question, answer, semaphore, pbar, results, correct, token_counter): 109 | """Process a single problem with semaphore control and token tracking.""" 110 | async with semaphore: 111 | solution, tokens_used = await solve_problem(client, question, token_counter) 112 | is_solved = is_correct(solution, answer) 113 | 114 | if is_solved: 115 | correct.value += 1 116 | 117 | result = { 118 | "question": question, 119 | "correct_answer": answer, 120 | "solution": solution, 121 | "is_correct": is_solved, 122 | "tokens_used": tokens_used 123 | } 124 | 125 | accuracy = (correct.value / (len(results) + 1)) * 100 126 | 127 | # Move cursor to bottom, update progress bar, then restore cursor 128 | print(f"\033[KTokens used: {tokens_used}") 129 | print("\033[K" + "-" * 40) 130 | 131 | # Update progress bar 132 | pbar.set_description(f"Solving problems [{accuracy:.1f}% correct] [Tokens: {token_counter.get_stats()}]") 133 | pbar.update(1) 134 | 135 | results.append(result) 136 | 137 | def is_correct(solution, correct_answer): 138 | """Check if the solution contains the correct answer.""" 139 | answer_segment = str(correct_answer).strip() 140 | sol_segment = solution.strip()[-300:] 141 | print("\nGROUND TRUTH", answer_segment, "\nLLM RESPONSE", sol_segment) 142 | print(f"{{{answer_segment}}}") 143 | return f"{{{answer_segment}}}" in sol_segment 144 | 145 | class Counter: 146 | def __init__(self): 147 | self.value = 0 148 | 149 | async def main(): 150 | random.seed(42) 151 | 152 | client = AsyncOpenAI(base_url=POLICY_URL, api_key=API_KEY) 153 | 154 | # Clear screen and move cursor to top 155 | print("\033[2J\033[H", end="") 156 | 157 | await warmup_api(client) 158 | 159 | token_counter = TokenCounter() 160 | 161 | def process_aime(example): 162 | example["answer"] = str(int(example["answer"].strip())) 
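        # (editor's note) AIME answers are integers in 000-999; the int() round trip
        # above drops leading zeros so the substring check in is_correct(), which
        # looks for "{answer}" (e.g. inside \boxed{73}), matches how models
        # actually write the final answer.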
163 | return example 164 | 165 | print("Loading dataset...") 166 | dataset = load_dataset("AI-MO/aimo-validation-aime", split="train") 167 | dataset = dataset.map(process_aime) 168 | problems = [(example["problem"], example["answer"]) for example in dataset] 169 | problems = problems[5:10] 170 | # problems = problems[70:80] 171 | 172 | semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS) 173 | results = [] 174 | correct = Counter() 175 | 176 | # Initialize progress bars at the bottom 177 | token_counter.init_progress(len(problems)) 178 | pbar = tqdm(total=len(problems), desc="Solving problems [0.0% correct]", position=0, leave=True) 179 | 180 | async with aiohttp.ClientSession() as session: 181 | tasks = [ 182 | process_problem(client, question, answer, semaphore, pbar, results, correct, token_counter) 183 | for question, answer in problems 184 | ] 185 | await asyncio.gather(*tasks) 186 | 187 | pbar.close() 188 | token_counter.pbar.close() 189 | 190 | # Print final results 191 | final_accuracy = (correct.value / len(problems)) * 100 192 | print(f"\nFinal Results:") 193 | print(f"Total Problems: {len(problems)}") 194 | print(f"Correct Solutions: {correct.value}") 195 | print(f"Final Accuracy: {final_accuracy:.2f}%") 196 | print(f"Token Usage: {token_counter.get_stats()}") 197 | 198 | with open("greedy_results.txt", "w") as f: 199 | f.write(f"Final Accuracy: {final_accuracy:.2f}%\n") 200 | f.write(f"Token Usage: {token_counter.get_stats()}\n\n") 201 | for result in results: 202 | f.write(f"Question: {result['question']}\n") 203 | f.write(f"Correct Answer: {result['correct_answer']}\n") 204 | f.write(f"Solution: {result['solution']}\n") 205 | f.write(f"Correct: {result['is_correct']}\n") 206 | f.write(f"Tokens Used: {result['tokens_used']}\n") 207 | f.write("-" * 80 + "\n") 208 | 209 | async def warmup_api(client): 210 | """Warm up the API with a simple query and retry logic.""" 211 | print("Warming up API...") 212 | for attempt in range(MAX_RETRIES + 1): 213 | try: 214 | async with asyncio.timeout(REQUEST_TIMEOUT): 215 | completion = await client.chat.completions.create( 216 | model=MODEL_NAME, 217 | messages=[ 218 | {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."}, 219 | {"role": "user", "content": "What is 5+45+4=?"} 220 | ], 221 | stop=["<|im_end|>"], 222 | max_tokens=10, 223 | stream=True 224 | ) 225 | 226 | response = "" 227 | total_tokens = 0 228 | async for chunk in completion: 229 | if chunk.choices[0].delta.content: 230 | response += chunk.choices[0].delta.content 231 | total_tokens += 1 232 | 233 | print("API warmup successful") 234 | return 235 | except (asyncio.TimeoutError, Exception) as e: 236 | if attempt < MAX_RETRIES: 237 | print(f"Warmup attempt {attempt + 1} failed: {e}, retrying...") 238 | await asyncio.sleep(1 * (attempt + 1)) 239 | else: 240 | print(f"Warning: API warmup failed after {MAX_RETRIES} retries: {e}") 241 | 242 | if __name__ == "__main__": 243 | asyncio.run(main()) -------------------------------------------------------------------------------- /mcts/train_policy_sft_qwq.py: -------------------------------------------------------------------------------- 1 | from unsloth import FastLanguageModel 2 | import torch 3 | import wandb 4 | from datasets import load_dataset 5 | from unsloth import is_bfloat16_supported 6 | from unsloth import UnslothTrainer, UnslothTrainingArguments 7 | from unsloth.chat_templates import get_chat_template 8 | 9 | # Constants 10 | max_seq_length = 32768 # Choose any! 
We auto support RoPE Scaling internally! 11 | dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+ 12 | load_in_4bit = False # Use 4bit quantization to reduce memory usage. Can be False. 13 | 14 | def extract_boxed_solution(text): 15 | """ 16 | Extract solution from \boxed{} notation and clean it. 17 | Returns None if no boxed solution is found. 18 | """ 19 | import re 20 | try: 21 | # Find content inside \boxed{} 22 | match = re.search(r'\\boxed{([^}]+)}', text) 23 | if match: 24 | # Extract and clean the solution 25 | solution = match.group(1) 26 | # Remove spaces, newlines and extra whitespace 27 | solution = re.sub(r'\s+', '', solution) 28 | return solution 29 | return None 30 | except Exception as e: 31 | print(f"Error extracting boxed solution: {str(e)}") 32 | return None 33 | 34 | def filter_dataset(example): 35 | """ 36 | Filter dataset based on solution matching between boxed answer and QwQ response. 37 | Returns True if example should be kept, False if it should be filtered out. 38 | 39 | Args: 40 | example: Dictionary containing dataset fields (problem, solution, qwq) 41 | 42 | Returns: 43 | bool: True if example meets criteria, False otherwise 44 | """ 45 | try: 46 | # Extract solution from the solution column 47 | if 'solution' not in example: 48 | return False 49 | 50 | boxed_solution = extract_boxed_solution(example['solution']) 51 | if not boxed_solution: 52 | return False 53 | 54 | # Clean the QwQ response for comparison 55 | qwq_clean = ''.join(example['qwq'].split()) 56 | 57 | # Check if the boxed solution appears in the QwQ response 58 | if boxed_solution not in qwq_clean: 59 | return False 60 | 61 | # Additional basic quality checks 62 | if len(example['qwq']) < 50: # Minimum response length 63 | return False 64 | 65 | # Additional basic quality checks 66 | if len(example['qwq']) > 10000: # Max response length 67 | return False 68 | 69 | return True 70 | 71 | except Exception as e: 72 | print(f"Error in filter validation: {str(e)}") 73 | return False 74 | 75 | def train_sft(): 76 | # Load base and instruct models 77 | model, tokenizer = FastLanguageModel.from_pretrained( 78 | model_name = "unsloth/Qwen2.5-0.5B", 79 | # model_name = "unsloth/Qwen2.5-Math-1.5B", 80 | max_seq_length = max_seq_length, 81 | dtype = dtype, 82 | load_in_4bit = load_in_4bit, 83 | ) 84 | model_instruct, tokenizer_instruct = FastLanguageModel.from_pretrained( 85 | model_name = "unsloth/Qwen2.5-0.5B-Instruct", 86 | # model_name = "unsloth/Qwen2.5-Math-1.5B-Instruct", 87 | max_seq_length = max_seq_length, 88 | dtype = dtype, 89 | load_in_4bit = load_in_4bit, 90 | ) 91 | # model, tokenizer = FastLanguageModel.from_pretrained( 92 | # model_name = "unsloth/Qwen2.5-0.5B-Instruct", 93 | # # model_name = "unsloth/Qwen2.5-Math-1.5B-Instruct", 94 | # max_seq_length = max_seq_length, 95 | # dtype = dtype, 96 | # load_in_4bit = load_in_4bit, 97 | # ) 98 | 99 | # TRANSFER = False 100 | TRANSFER = True 101 | if TRANSFER: 102 | # Transfer chat token embeddings from instruct to base model 103 | base_embeddings = model.get_input_embeddings() 104 | instruct_embeddings = model_instruct.get_input_embeddings() 105 | chat_tokens = ["<|im_start|>", "<|im_end|>", "system", "assistant", "user"] 106 | with torch.no_grad(): 107 | for token in chat_tokens: 108 | try: 109 | instruct_id = tokenizer_instruct.convert_tokens_to_ids(token) 110 | base_id = tokenizer.convert_tokens_to_ids(token) 111 | if instruct_id != tokenizer_instruct.unk_token_id and base_id != tokenizer.unk_token_id: 112 | 
base_embeddings.weight[base_id] = instruct_embeddings.weight[instruct_id].clone() 113 | print(f"Transferred embedding for token: {token}") 114 | else: 115 | print(f"Warning: Token {token} not found in one of the vocabularies") 116 | except Exception as e: 117 | print(f"Error transferring token {token}: {str(e)}") 118 | 119 | # Add LoRA adapters 120 | model = FastLanguageModel.get_peft_model( 121 | model, 122 | r = 128, # Choose any number > 0! Suggested 8, 16, 32, 64, 128 123 | target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", 124 | "gate_proj", "up_proj", "down_proj", 125 | "embed_tokens", "lm_head",], # Add for continual pretraining 126 | lora_alpha = 32, 127 | lora_dropout = 0, # Supports any, but = 0 is optimized 128 | bias = "none", # Supports any, but = "none" is optimized 129 | use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context 130 | random_state = 3407, 131 | use_rslora = True, # We support rank stabilized LoRA 132 | loftq_config = None, # And LoftQ 133 | ) 134 | 135 | # Set up tokenizer with chat template 136 | tokenizer = get_chat_template( 137 | tokenizer, 138 | chat_template = "qwen-2.5", 139 | ) 140 | tokenizer.eos_token = "<|im_end|>" 141 | print(tokenizer.eos_token) 142 | print(tokenizer.pad_token) 143 | 144 | # Load and process dataset 145 | dataset = load_dataset("qingy2024/QwQ-LongCoT-Verified-130K", "verified", split="train") 146 | 147 | # Apply filtering 148 | print("filtering") 149 | filtered_dataset = dataset.filter(filter_dataset) 150 | print(f"Original dataset size: {len(dataset)}") 151 | print(f"Filtered dataset size: {len(filtered_dataset)}") 152 | 153 | # Print some examples of filtered data 154 | print("\nExample of filtered data:") 155 | for idx in range(min(3, len(filtered_dataset))): 156 | print(f"\nExample {idx + 1}:") 157 | print("Problem:", filtered_dataset[idx]['problem'][:200], "...") 158 | print("Response:", filtered_dataset[idx]['qwq'][:200], "...") 159 | 160 | 161 | def formatting_prompts_func(examples): 162 | conversations = [] 163 | for query, response in zip(examples['problem'], examples['qwq']): 164 | # break 165 | conversation = [ 166 | {"role": "user", "content": query}, 167 | {"role": "assistant", "content": response} 168 | ] 169 | conversations.append(conversation) 170 | 171 | texts = [tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False) 172 | for convo in conversations] 173 | return {"text": texts} 174 | 175 | dataset = filtered_dataset.map(formatting_prompts_func, batched=True, remove_columns=dataset.column_names) 176 | print(len(dataset)) 177 | 178 | # Debug tokenizer output - show examples 179 | print("Example of tokenized output:") 180 | print(dataset[5]["text"]) 181 | print("\nAnother example:") 182 | print(dataset[100]["text"]) 183 | 184 | # Configure trainer 185 | trainer = UnslothTrainer( 186 | model = model, 187 | tokenizer = tokenizer, 188 | train_dataset = dataset, 189 | dataset_text_field = "text", 190 | max_seq_length = max_seq_length, 191 | dataset_num_proc = 8, 192 | 193 | args = UnslothTrainingArguments( 194 | # learning_rate = 5e-5, 195 | # embedding_learning_rate = 5e-6, 196 | # learning_rate = 3e-5, 197 | # embedding_learning_rate = 3e-6, 198 | learning_rate = 3e-6, 199 | embedding_learning_rate = 3e-7, 200 | # per_device_train_batch_size = 8, # With gradient_accumulation_steps=8 this gives effective batch size 64 201 | per_device_train_batch_size = 4, 202 | gradient_accumulation_steps = 8, 203 | lr_scheduler_type = "cosine", 204 | num_train_epochs = 3, 205 
| # num_train_epochs = 2,
206 |         # num_train_epochs = 1,
207 |         warmup_ratio = 0.1,
208 |         max_seq_length = max_seq_length,  # was hardcoded to 2048; match the 32768 constant above so long CoT traces are not truncated
209 |         fp16 = not is_bfloat16_supported(),
210 |         bf16 = is_bfloat16_supported(),
211 |         optim = "adamw_8bit",
212 |         weight_decay = 0.01,
213 |         logging_steps = 1,
214 |         seed = 3407,
215 |         output_dir = "outputs",
216 |         report_to = "wandb",
217 |         # run_name = "qwqdistill1.5",
218 |         run_name = "qwqdistill",
219 |         hub_strategy = "every_save",
220 |         save_strategy = "steps",
221 |         save_steps = 100,
222 |         hub_model_id = "rawsh/q1-Qwen2.5-0.5b"
223 |         # hub_model_id = "rawsh/q1-Qwen2.5-0.5b-Instruct"
224 |         # hub_model_id = "rawsh/q1-Qwen2.5-Math-1.5B"
225 |         ),
226 |     )
227 | 
228 |     # Set up wandb
229 |     # wandb.login(key="YOUR_WANDB_KEY")  # Replace with your key
230 |     # wandb.init(project='metamath')
231 | 
232 |     # Print initial GPU stats
233 |     gpu_stats = torch.cuda.get_device_properties(0)
234 |     start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
235 |     max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
236 |     print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
237 |     print(f"{start_gpu_memory} GB of memory reserved.")
238 | 
239 |     # Train
240 |     trainer_stats = trainer.train()
241 | 
242 |     # Show final memory and time stats
243 |     used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
244 |     used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
245 |     used_percentage = round(used_memory/max_memory*100, 3)
246 |     lora_percentage = round(used_memory_for_lora/max_memory*100, 3)
247 | 
248 |     print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
249 |     print(f"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.")
250 |     print(f"Peak reserved memory = {used_memory} GB.")
251 |     print(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
252 |     print(f"Peak reserved memory % of max memory = {used_percentage} %.")
253 |     print(f"Peak reserved memory for training % of max memory = {lora_percentage} %.")
254 | 
255 |     # Save model to HuggingFace Hub
256 |     model.push_to_hub_merged(
257 |         "rawsh/q1-Qwen2.5-0.5b",  # Replace with your username
258 |         # "rawsh/q1-Qwen2.5-0.5b-Instruct",
259 |         # "rawsh/q1-Qwen2.5-Math-1.5B",
260 |         tokenizer,
261 |         save_method="merged_16bit",
262 |     )
263 | 
264 | if __name__ == "__main__":
265 |     train_sft()
--------------------------------------------------------------------------------
/mcts/simple_sample.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import aiohttp
3 | from openai import AsyncOpenAI
4 | import random
5 | from datasets import load_dataset
6 | from tqdm.asyncio import tqdm
7 | from typing import List, Tuple, Dict
8 | import json
9 | from asyncio import Semaphore
10 | from collections import Counter
11 | from functools import wraps
12 | from collections import OrderedDict
13 | import math  # currently unused: this script defines no evaluate_step
14 | 
15 | # Configuration
16 | POLICY_URL = 'https://rawsh--vllm-qwen-metamath-serve.modal.run/v1/'
17 | PRM_URL = 'https://rawsh--vllm-qwen-prm-serve.modal.run/v1/'
18 | API_KEY = '9FF74944EED19865193F979942FB1'
19 | BATCH_SIZE = 100  # Reduced batch size since we're doing multiple requests per question
20 | MAX_RETRIES = 5
21 | TIMEOUT = 20
22 | MAX_CONCURRENT = 100
23 | SAMPLES_PER_QUESTION = 1  # Default to single sample mode, override with CLI arg
24 | 
25 | # Cache decorator for PRM scores
26 | def async_lru_cache(maxsize=10000):
27 |     cache
= OrderedDict() 28 | def decorator(func): 29 | @wraps(func) 30 | async def wrapper(*args, **kwargs): 31 | key = str(args) + str(kwargs) 32 | if key not in cache: 33 | if len(cache) >= maxsize: 34 | cache.popitem(last=False) 35 | cache[key] = await func(*args, **kwargs) 36 | return cache[key] 37 | return wrapper 38 | return decorator 39 | 40 | class BatchProgress: 41 | def __init__(self, total_questions: int, samples_per_question: int): 42 | self.total = total_questions 43 | self.samples = samples_per_question 44 | self.correct_any = 0 45 | self.correct_best = 0 46 | self.correct_sc = 0 47 | self.processed = 0 48 | self.pbar = tqdm(total=total_questions, desc=self.get_description()) 49 | 50 | def get_description(self) -> str: 51 | if self.processed == 0: 52 | return "Starting..." 53 | 54 | any_acc = (self.correct_any / self.processed) * 100 55 | if self.samples > 1: 56 | best_acc = (self.correct_best / self.processed) * 100 57 | sc_acc = (self.correct_sc / self.processed) * 100 58 | return f"Processed: {self.processed}/{self.total} | Any: {any_acc:.1f}% | Best: {best_acc:.1f}% | SC: {sc_acc:.1f}%" 59 | else: 60 | return f"Processed: {self.processed}/{self.total} | Accuracy: {any_acc:.1f}%" 61 | 62 | def update(self, any_correct: bool, best_correct: bool = None, sc_correct: bool = None): 63 | self.processed += 1 64 | if any_correct: 65 | self.correct_any += 1 66 | if best_correct is not None: 67 | if best_correct: 68 | self.correct_best += 1 69 | if sc_correct is not None: 70 | if sc_correct: 71 | self.correct_sc += 1 72 | self.pbar.update(1) 73 | self.pbar.set_description(self.get_description()) 74 | 75 | def close(self): 76 | self.pbar.close() 77 | if self.processed > 0: 78 | any_acc = (self.correct_any / self.processed) * 100 79 | print(f"\nFinal Results:") 80 | print(f"Total Questions: {self.processed}") 81 | print(f"Single Sample Accuracy: {any_acc:.2f}%") 82 | 83 | if self.samples > 1: 84 | best_acc = (self.correct_best / self.processed) * 100 85 | sc_acc = (self.correct_sc / self.processed) * 100 86 | print(f"Best-of-{self.samples} Accuracy: {best_acc:.2f}%") 87 | print(f"Self-Consistency Accuracy: {sc_acc:.2f}%") 88 | 89 | async def retry_with_exponential_backoff(func, *args, **kwargs): 90 | for attempt in range(MAX_RETRIES): 91 | try: 92 | return await asyncio.wait_for(func(*args, **kwargs), timeout=TIMEOUT) 93 | except Exception as e: 94 | if attempt == MAX_RETRIES - 1: 95 | raise 96 | delay = min(1.5 ** attempt + random.random(), 10) 97 | await asyncio.sleep(delay) 98 | 99 | @async_lru_cache(maxsize=10000) 100 | async def get_prm_score(completion: str, session: aiohttp.ClientSession) -> float: 101 | """Get the PRM score for a completion.""" 102 | async with session.post(PRM_URL, json={"prompt": completion}) as response: 103 | result = await response.json() 104 | return float(result['score']) 105 | 106 | async def generate_completion( 107 | question: str, 108 | client: AsyncOpenAI, 109 | semaphore: Semaphore 110 | ) -> str: 111 | """Generate a single completion using chat-based API.""" 112 | async with semaphore: 113 | messages = [ 114 | {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. 
You are a helpful assistant."}, 115 | {"role": "user", "content": question} 116 | ] 117 | response = await client.chat.completions.create( 118 | timeout=TIMEOUT, 119 | model="MetaMath-Qwen2.5-0.5b", # Ensure this is the correct model name 120 | messages=messages, 121 | max_tokens=1500, 122 | # temperature=1.2, 123 | temperature=0.0, 124 | stop=["<|endoftext|>", "<|im_end|>"], 125 | extra_body={ 126 | "repetition_penalty": 1.05, 127 | "top_p": 0.8, 128 | "top_k": 20, 129 | "frequency_penalty": 0.05, 130 | "presence_penalty": 0.05, 131 | } 132 | ) 133 | # print(response) 134 | return response.choices[0].message.content.strip() 135 | 136 | async def evaluate_question( 137 | question: str, 138 | answer: str, 139 | client: AsyncOpenAI, 140 | session: aiohttp.ClientSession, 141 | semaphore: Semaphore, 142 | samples_per_question: int 143 | ) -> Dict: 144 | """Evaluate a question with single or multiple samples.""" 145 | try: 146 | # Generate completions 147 | completions = [] 148 | for _ in range(samples_per_question): 149 | completion = await retry_with_exponential_backoff( 150 | generate_completion, question, client, semaphore 151 | ) 152 | completions.append(completion) 153 | 154 | # For single sample mode, return simpler result 155 | if samples_per_question == 1: 156 | # is_correct = fr"\boxed{{{answer}}}" in completions[0] 157 | completion = completions[0].split("\n\n")[-1] 158 | is_correct = answer in completion 159 | print(completions[0].split("\n\n")[-1]) 160 | print(f"ANSWER: {completion} {answer} ({is_correct})") 161 | return { 162 | "question": question, 163 | "expected_answer": answer, 164 | "completion": completions[0], 165 | "correct": is_correct 166 | } 167 | 168 | # For multi-sample mode, evaluate with PRM 169 | scores = [] 170 | for completion in completions: 171 | score = await retry_with_exponential_backoff( 172 | get_prm_score, completion, session 173 | ) 174 | scores.append(score) 175 | 176 | # Evaluate correctness and extract answers 177 | is_correct = [] 178 | extracted_answers = [] 179 | for completion in completions: 180 | # correct = fr"\boxed{{{answer}}}" in completion 181 | extracted = completion.split("\n\n")[-1].split("The answer is: ")[-1].strip() 182 | is_correct.append(answer in extracted) 183 | 184 | # Extract answer for self-consistency 185 | # if r"\boxed{" in completion: 186 | # extracted = completion.split(r"\boxed{")[1].split("}")[0] 187 | print(extracted) 188 | extracted_answers.append(extracted) 189 | 190 | # Find best completion by PRM score 191 | best_idx = max(range(len(scores)), key=lambda i: scores[i]) 192 | 193 | # Calculate self-consistency 194 | answer_counts = Counter(extracted_answers) 195 | most_common_answer = answer_counts.most_common(1)[0][0] if answer_counts else None 196 | is_sc_correct = most_common_answer == answer if most_common_answer else False 197 | 198 | return { 199 | "question": question, 200 | "expected_answer": answer, 201 | "completions": [ 202 | { 203 | "text": compl, 204 | "score": score, 205 | "correct": corr 206 | } 207 | for compl, score, corr in zip(completions, scores, is_correct) 208 | ], 209 | "best_completion": { 210 | "text": completions[best_idx], 211 | "score": scores[best_idx], 212 | "correct": is_correct[best_idx] 213 | }, 214 | "statistics": { 215 | "any_correct": any(is_correct), 216 | "best_correct": is_correct[best_idx], 217 | "self_consistency_correct": is_sc_correct, 218 | "unique_answers": len(answer_counts), 219 | "most_common_answer": most_common_answer, 220 | "most_common_count": 
answer_counts.most_common(1)[0][1] if answer_counts else 0 221 | } 222 | } 223 | 224 | except Exception as e: 225 | return { 226 | "question": question, 227 | "expected_answer": answer, 228 | "error": str(e) 229 | } 230 | 231 | async def process_batch( 232 | batch: List[Tuple[str, str]], 233 | client: AsyncOpenAI, 234 | session: aiohttp.ClientSession, 235 | progress: BatchProgress, 236 | semaphore: Semaphore, 237 | samples_per_question: int 238 | ) -> List[dict]: 239 | """Process a batch of questions concurrently.""" 240 | tasks = [] 241 | for question, answer in batch: 242 | tasks.append( 243 | evaluate_question( 244 | question, answer, client, session, semaphore, samples_per_question 245 | ) 246 | ) 247 | 248 | results = await asyncio.gather(*tasks) 249 | 250 | # Update progress based on mode 251 | for result in results: 252 | if "error" not in result: 253 | if samples_per_question == 1: 254 | progress.update(result["correct"]) 255 | else: 256 | progress.update( 257 | result["statistics"]["any_correct"], 258 | result["statistics"]["best_correct"], 259 | result["statistics"]["self_consistency_correct"] 260 | ) 261 | 262 | return results 263 | 264 | # greedy: 57% 265 | 266 | async def main(): 267 | import argparse 268 | parser = argparse.ArgumentParser() 269 | parser.add_argument("--samples", type=int, default=1, 270 | help="Number of samples per question (default: 1)") 271 | parser.add_argument("--num-questions", type=int, default=100, 272 | help="Number of questions to evaluate (default: 200)") 273 | args = parser.parse_args() 274 | 275 | # Set random seed for reproducibility 276 | random.seed(0) 277 | 278 | # Load and preprocess dataset 279 | gsm8k = load_dataset("openai/gsm8k", "main", split="test").shuffle(seed=42) 280 | questions = [(ex["question"], ex["answer"].split("\n#### ")[-1].strip()) 281 | for ex in gsm8k] 282 | questions = random.sample(questions, args.num_questions) 283 | 284 | # Initialize API client and semaphore 285 | client = AsyncOpenAI(base_url=POLICY_URL, api_key=API_KEY) 286 | semaphore = Semaphore(MAX_CONCURRENT) 287 | 288 | # Initialize progress tracker 289 | progress = BatchProgress(len(questions), args.samples) 290 | 291 | # Process in batches 292 | all_results = [] 293 | 294 | # Create session only if needed (multi-sample mode) 295 | if args.samples > 1: 296 | async with aiohttp.ClientSession() as session: 297 | for i in range(0, len(questions), BATCH_SIZE): 298 | batch = questions[i:i + BATCH_SIZE] 299 | results = await process_batch( 300 | batch, client, session, progress, semaphore, args.samples 301 | ) 302 | all_results.extend(results) 303 | else: 304 | # Use a dummy session since PRM is not needed in single-sample mode 305 | async with aiohttp.ClientSession() as session: 306 | for i in range(0, len(questions), BATCH_SIZE): 307 | batch = questions[i:i + BATCH_SIZE] 308 | results = await process_batch( 309 | batch, client, session, progress, semaphore, args.samples 310 | ) 311 | all_results.extend(results) 312 | 313 | # Save results 314 | suffix = f"{args.samples}samples" if args.samples > 1 else "single" 315 | filename = f"sampling_results_{suffix}.jsonl" 316 | with open(filename, "w") as f: 317 | for result in all_results: 318 | f.write(json.dumps(result) + "\n") 319 | 320 | progress.close() 321 | 322 | if __name__ == "__main__": 323 | asyncio.run(main()) 324 | -------------------------------------------------------------------------------- /mcts/tree_search_light_mathrm.py: -------------------------------------------------------------------------------- 1 
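# For reference: `_best_uct_child` below implements the standard UCT rule
#     score(child) = V_child / n_child + c * sqrt(ln(n_parent + 1) / n_child)
# with c = UCT_CONSTANT = 1.41 ~ sqrt(2). A minimal standalone sketch of that
# rule follows; `uct_score` is an illustrative name, not a function defined in
# this repo.

import math as _math

def uct_score(total_value: float, visits: int, parent_visits: int,
              c: float = 1.41) -> float:
    if visits == 0:
        return float("inf")  # always expand unvisited children first
    exploit = total_value / visits  # average backed-up PRM value
    explore = c * _math.sqrt(_math.log(parent_visits + 1) / visits)
    return exploit + explore

# best child = max over children of uct_score(child.total_value, child.visits, node.visits)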
1 | import asyncio
2 | import math
3 | import aiohttp
4 | from openai import AsyncOpenAI
5 | from datasets import load_dataset
6 | import random
7 | from typing import List, Tuple, Dict, Set, Optional
8 | 
9 | # Configuration
10 | class Config:
11 |     POLICY_MODEL_NAME = 'MetaMath-Qwen2.5-0.5b'
12 |     POLICY_URL = 'https://rawsh--vllm-qwen-metamath-serve.modal.run/v1/'
13 |     PRM_URL = 'https://rawsh--vllm-qwen-prm-serve.modal.run/v1/'
14 |     PRM_MODEL_NAME = 'MetaMath-Qwen2.5-0.5b-PRM'
15 |     API_KEY = '9FF74944EED19865193F979942FB1'
16 |     UCT_CONSTANT = 1.41
17 |     TIMEOUT = 60
18 | 
19 | class Node:
20 |     def __init__(self, state: str, parent: Optional['Node'] = None):
21 |         self.state = state
22 |         self.parent = parent
23 |         self.children: Dict[str, Node] = {}
24 |         self.visits = 0
25 |         self.total_value = 0.0
26 |         self.prm_value: Optional[float] = None
27 | 
28 | class MCTSWorker:
29 |     def __init__(self, root_state: str, correct_answer: str):
30 |         self.root = Node(root_state)
31 |         self.correct_answer = correct_answer
32 |         self.terminal_nodes: Set[Node] = set()
33 | 
34 |     async def select_next(self) -> Node:
35 |         """Returns leaf node that should be expanded next."""
36 |         print("\nSELECT_NEXT -----")
37 |         node = self.root
38 |         depth = 0
39 |         path = [node]
40 | 
41 |         while node.children:
42 |             # If node is not fully expanded, return it
43 |             if len(node.children) < 3:  # Assuming 3 possible actions
44 |                 print(f"Found unexpanded node at depth {depth}")
45 |                 print(f"Current children: {list(node.children.keys())}")
46 |                 return node
47 |             node = self._best_uct_child(node)
48 |             path.append(node)
49 |             depth += 1
50 | 
51 |         print(f"Reached leaf node at depth {depth}")
52 |         print(f"Path taken: {' -> '.join(str(n.visits) for n in path)}")
53 |         return node
54 | 
55 |     def _best_uct_child(self, node: Node) -> Node:
56 |         best_value = float('-inf')
57 |         best_child = None
58 |         parent_visits = node.visits
59 |         ln_parent = math.log(parent_visits + 1)
60 | 
61 |         print("\nUCT calculation for children:")
62 |         for action, child in node.children.items():
63 |             if child.visits == 0:
64 |                 print(f"Found unvisited child for action: {action}")
65 |                 return child
66 | 
67 |             exploit = child.total_value / child.visits
68 |             explore = Config.UCT_CONSTANT * math.sqrt(ln_parent / child.visits)
69 |             uct_value = exploit + explore
70 | 
71 |             print(f"Action: {action[:20]}...")
72 |             print(f"Visits: {child.visits}, Total value: {child.total_value:.3f}")
73 |             print(f"UCT = {exploit:.3f} (exploit) + {explore:.3f} (explore) = {uct_value:.3f}")
74 | 
75 |             if uct_value > best_value:
76 |                 best_value = uct_value
77 |                 best_child = child
78 | 
79 |         print(f"Selected best child with UCT value: {best_value:.3f}")
80 |         return best_child
81 | 
82 | class MCTSManager:
83 |     def __init__(self):
84 |         self.policy_client = AsyncOpenAI(
85 |             base_url=Config.POLICY_URL,
86 |             api_key=Config.API_KEY
87 |         )
88 |         self.prm_client = AsyncOpenAI(
89 |             base_url=Config.PRM_URL,
90 |             api_key=Config.API_KEY
91 |         )
92 | 
93 |     async def get_next_action(self, state: str) -> Tuple[str, bool]:
94 |         """Get next action from policy model."""
95 |         print("\nGET_NEXT_ACTION -----")
96 |         print(f"Input state:\n{state}")
97 |         steps = state.split("\n\n")
98 |         question = steps[0]
99 |         answer = "\n\n".join(steps[1:]) if len(steps) > 1 else None
100 | 
101 |         messages = [
102 |             {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
103 |             {"role": "user", "content": question}
104 |         ]
105 | 
106 |         if answer:
107 |             messages.append({"role": "assistant", "content": answer + "\n\n"})
108 |         else:
109 |             messages.append({"role": "assistant", "content": ""})
110 | 
111 |         try:
112 |             print(f"Calling policy model with {len(messages)} messages")
113 |             response = await self.policy_client.chat.completions.create(
114 |                 timeout=Config.TIMEOUT,
115 |                 model=Config.POLICY_MODEL_NAME,
116 |                 messages=messages,
117 |                 max_tokens=150,
118 |                 stop=["<|endoftext|>", "<|im_end|>", "\n\n"],
119 |                 temperature=1.0,
120 |                 extra_body={
121 |                     "repetition_penalty": 1.05,
122 |                     "top_p": 0.8,
123 |                     "top_k": 20,
124 |                     "frequency_penalty": 0.05,
125 |                     "presence_penalty": 0.05,
126 |                 }
127 |             )
128 |             content = response.choices[0].message.content.strip()
129 |             is_term = (response.choices[0].finish_reason == 'stop' and
130 |                        response.choices[0].stop_reason != '\n\n')
131 | 
132 |             print(f"Generated content: {content}")
133 |             print(f"Is terminal: {is_term}")
134 |             if is_term:
135 |                 print(f"TERMINAL STATE REACHED")
136 |                 print(f"Q: {question}")
137 |                 print(f"A: {content}")
138 |             return content, is_term
139 | 
140 |         except Exception as e:
141 |             print(f"Error getting next action: {e}")
142 |             print(f"Full state that caused error: {state}")
143 |             return "", True
144 | 
145 |     async def evaluate_state(self, state: str, session: aiohttp.ClientSession) -> float:
146 |         """Evaluate state using PRM model."""
147 |         print("\nEVALUATE_STATE -----")
148 |         print(f"Input state:\n{state}")
149 | 
150 |         steps = state.split("\n\n")
151 |         if len(steps) < 2:
152 |             print("Warning: state has less than 2 steps, returning 0.0")
153 |             return 0.0
154 | 
155 |         question = steps[0]
156 |         curr_step = steps[-1]
157 | 
158 |         messages = []
159 |         if len(steps) == 2:
160 |             messages = [{"role": "user", "content": f"{question} Step 1: {curr_step}"}]
161 |         else:
162 |             messages = [{"role": "user", "content": f"{question} Step 1: {steps[1]}"}]
163 |             for i, step in enumerate(steps[2:-1], start=2):
164 |                 messages.extend([
165 |                     {"role": "assistant", "content": "+"},
166 |                     {"role": "user", "content": f"Step {i}: {step}"}
167 |                 ])
168 |             messages.extend([
169 |                 {"role": "assistant", "content": "+"},
170 |                 {"role": "user", "content": f"Step {len(steps)-1}: {curr_step}"}
171 |             ])
172 | 
173 |         try:
174 |             print(f"Calling PRM model with {len(messages)} messages")
175 |             print(f"Messages: {messages}")
176 | 
177 |             response = await self.prm_client.chat.completions.create(
178 |                 timeout=Config.TIMEOUT,
179 |                 model=Config.PRM_MODEL_NAME,
180 |                 messages=messages,
181 |                 max_tokens=1,
182 |                 temperature=0.0,
183 |                 logprobs=True,
184 |                 top_logprobs=20,
185 |                 extra_body={
186 |                     "repetition_penalty": 1.05,
187 |                     "top_p": 0.8,
188 |                     "top_k": 20,
189 |                     "frequency_penalty": 0.05,
190 |                     "presence_penalty": 0.05,
191 |                     "add_generation_prompt": True,
192 |                 }
193 |             )
194 |             logprobs = response.choices[0].logprobs.content[0].top_logprobs
195 |             print(f"Got logprobs: {logprobs}")
196 | 
197 |             prob_plus = next(
198 |                 (math.exp(lp.logprob) for lp in logprobs if lp.token == "+"),
199 |                 1e-10
200 |             )
201 |             print(f"Probability for '+' token: {prob_plus}")
202 |             return prob_plus
203 | 
204 |         except Exception as e:
205 |             print(f"Error evaluating state: {e}")
206 |             print(f"Full state that caused error: {state}")
207 |             return 0.0
208 | 
209 |     async def find_best_leaf(self, node: Node, session: aiohttp.ClientSession) -> Node:
210 |         """Find the best leaf node using PRM values."""
211 |         leaf_nodes: List[Node] = []
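        # The two helpers below collect every leaf in this tree, score each
        # leaf's full state with the PRM concurrently (asyncio.gather), cache
        # the result on node.prm_value, and the best leaf is the max of those scores.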
212 |         self._collect_leaf_nodes(node, leaf_nodes)
213 | 
214 |         # Evaluate all leaves in parallel
215 |         await asyncio.gather(*[
216 |             self._evaluate_and_update_node(leaf, session)
217 |             for leaf in leaf_nodes
218 |         ])
219 | 
220 |         return max(leaf_nodes, key=lambda leaf: leaf.prm_value or 0.0)
221 | 
222 |     def _collect_leaf_nodes(self, node: Node, leaf_nodes: List[Node]) -> None:
223 |         """Collect all leaf nodes in the tree."""
224 |         if not node.children:
225 |             leaf_nodes.append(node)
226 |         else:
227 |             for child in node.children.values():
228 |                 self._collect_leaf_nodes(child, leaf_nodes)
229 | 
230 |     async def _evaluate_and_update_node(
231 |         self,
232 |         node: Node,
233 |         session: aiohttp.ClientSession
234 |     ) -> None:
235 |         """Evaluate a node and update its PRM value."""
236 |         node.prm_value = await self.evaluate_state(node.state, session)
237 | 
238 | async def run_batch_mcts(
239 |     initial_states: List[Tuple[str, str]],
240 |     num_trees_per_q: int = 1,
241 |     iters_per_tree: int = 50
242 | ) -> List[Dict]:
243 |     """Run batch MCTS on multiple initial states."""
244 |     mcts_manager = MCTSManager()
245 | 
246 |     # Create workers for each question
247 |     all_workers = [
248 |         ([MCTSWorker(state, answer) for _ in range(num_trees_per_q)], answer)
249 |         for state, answer in initial_states
250 |     ]
251 | 
252 |     async with aiohttp.ClientSession() as session:
253 |         # Run iterations
254 |         for iteration in range(iters_per_tree):
255 |             # Phase 1: Select leaves from all trees
256 |             all_leaves = []
257 |             for workers, _ in all_workers:
258 |                 leaves = await asyncio.gather(*[w.select_next() for w in workers])
259 |                 all_leaves.extend(leaves)
260 | 
261 |             # Phase 2: Generate actions for all leaves
262 |             actions = await asyncio.gather(*[
263 |                 mcts_manager.get_next_action(leaf.state) for leaf in all_leaves
264 |             ])
265 | 
266 |             # Phase 3: Apply actions & evaluate
267 |             new_states = [
268 |                 (leaf, action, f"{leaf.state}\n\n{action}", is_term)
269 |                 for leaf, (action, is_term) in zip(all_leaves, actions)
270 |             ]
271 | 
272 |             values = await asyncio.gather(*[
273 |                 mcts_manager.evaluate_state(state, session)
274 |                 for _, _, state, _ in new_states
275 |             ])
276 | 
277 |             # Phase 4: Update all trees
278 |             for (leaf, action, new_state, is_term), value in zip(new_states, values):
279 |                 if not action:  # Skip empty actions
280 |                     continue
281 | 
282 |                 child = Node(new_state, parent=leaf)
283 |                 leaf.children[action] = child
284 | 
285 |                 if is_term:
286 |                     tree_root = leaf  # walk up to the root of this leaf's tree
287 |                     while tree_root.parent: tree_root = tree_root.parent
288 |                     owners = [w for workers, _ in all_workers for w in workers if w.root is tree_root]
289 |                     for w in owners: w.terminal_nodes.add(child)
290 | 
291 |                 # Backpropagate
292 |                 node = child
293 |                 while node:
294 |                     node.visits += 1
295 |                     node.total_value += value
296 |                     node = node.parent
297 | 
298 |         # Get results
299 |         results = []
300 |         for workers, correct_answer in all_workers:
301 |             best_results = await asyncio.gather(*[
302 |                 mcts_manager.find_best_leaf(w.root, session) for w in workers
303 |             ])
304 |             best_leaf = max(best_results, key=lambda x: x.prm_value or 0.0)
305 |             results.append({
306 |                 "state": best_leaf.state,
307 |                 "score": best_leaf.prm_value,
308 |                 "correct": correct_answer in best_leaf.state
309 |             })
310 | 
311 |     return results
312 | 
313 | async def main():
314 |     # Example dataset loading and processing
315 |     aime = load_dataset("AI-MO/aimo-validation-aime", split="train")
316 |     initial_states = [
317 |         (ex["problem"], str(int(ex["answer"])))
318 |         for ex in aime
319 |     ]
320 |     initial_states = random.sample(initial_states, 10)
321 | 
322 |     results = await run_batch_mcts(initial_states)
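    # Note: each result's "correct" flag (set in run_batch_mcts) is a substring
    # check of the gold answer against the best leaf's full state, so very short
    # answers (e.g. "2") can match spuriously; parsing and exact-matching the
    # final answer would be a stricter metric.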
323 |     accuracy = sum(r['correct'] for r in results) / len(results)
324 |     print(f"Accuracy: {accuracy:.2f}")
325 | 
326 | if __name__ == "__main__":
327 |     asyncio.run(main())
--------------------------------------------------------------------------------
/best_of_n.py:
--------------------------------------------------------------------------------
1 | import aiohttp
2 | import asyncio
3 | import json
4 | import os
5 | import re
6 | from tqdm.asyncio import tqdm_asyncio
7 | from tqdm import tqdm
8 | from typing import List, Dict, Any, Tuple
9 | import gc
10 | from datetime import datetime
11 | from datasets import load_dataset
12 | 
13 | # 1 sample
14 | # Overall Statistics:
15 | # Total correct answers: 58/100
16 | # Accuracy: 0.58
17 | # Average reward score: 0.0929
18 | 
19 | # best of 2
20 | # Overall Statistics:
21 | # Total correct answers: 63/100
22 | # Accuracy: 0.63
23 | # Average reward score: 0.1148
24 | 
25 | # best of 4
26 | # Overall Statistics:
27 | # Total correct answers: 64/100
28 | # Accuracy: 0.64
29 | # Average reward score: 0.1257
30 | 
31 | # best of 8
32 | # Overall Statistics:
33 | # Total correct answers: 63/100
34 | # Accuracy: 0.63
35 | # Average reward score: 0.1307
36 | 
37 | # best of 16
38 | # Overall Statistics:
39 | # Total correct answers: 70/100
40 | # Accuracy: 0.70
41 | # Average reward score: 0.1345
42 | 
43 | # best of 32
44 | # Overall Statistics:
45 | # Total correct answers: 67/100
46 | # Accuracy: 0.67
47 | # Average reward score: 0.1380
48 | 
49 | # best of 64
50 | # Overall Statistics:
51 | # Total correct answers: 67/100
52 | # Accuracy: 0.67
53 | # Average reward score: 0.1461
54 | 
55 | 
56 | 
57 | # Configuration
58 | FIREWORKS_API_KEY = os.getenv('FIREWORKS_API_KEY')
59 | if not FIREWORKS_API_KEY:
60 |     raise ValueError("FIREWORKS_API_KEY environment variable must be set")
61 | 
62 | FIREWORKS_API_ENDPOINT = "https://api.fireworks.ai/inference/v1/chat/completions"
63 | REWARD_MODEL_ENDPOINT = "https://rawsh--reward-api-model-score.modal.run"
64 | 
65 | # Separate rate limits for different APIs
66 | LLM_MAX_CONCURRENT = 100
67 | REWARD_MAX_CONCURRENT = 32
68 | BATCH_SIZE = 10
69 | 
70 | BEST_OF_N = 2
71 | MODEL_NAME = "accounts/fireworks/models/llama-v3p1-8b-instruct"
72 | 
73 | # Timeout configurations
74 | REWARD_MODEL_TIMEOUT = 20
75 | LLM_TIMEOUT = 10  # NOTE: tight for max_tokens=4096 completions; raise if requests keep timing out
76 | REWARD_MODEL_MAX_RETRIES = 3
77 | 
78 | class APIError(Exception):
79 |     """Custom exception for API-related errors"""
80 |     pass
81 | 
82 | async def with_retry(func, max_retries=3, base_delay=1):
83 |     """Generic retry wrapper with exponential backoff"""
84 |     for i in range(max_retries):
85 |         try:
86 |             return await func()
87 |         except Exception as e:
88 |             if i == max_retries - 1:
89 |                 raise
90 |             delay = base_delay * (2 ** i)
91 |             print(f"Attempt {i+1} failed: {str(e)}. Retrying in {delay}s...")
92 |             await asyncio.sleep(delay)
93 | 
94 | def extract_answer(completion: str) -> Tuple[str, str]:
95 |     """Extract the final answer from the completion."""
96 |     match = re.search(r"Answer:\s*([A-Z])", completion, re.IGNORECASE)
97 |     if match:
98 |         return completion.strip(), match.group(1).upper()
99 |     # Fallback: take the last letter found anywhere in the completion (crude; may grab an arbitrary character)
100 |     letters = re.findall(r'[A-Z]', completion, re.IGNORECASE)
101 |     return completion.strip(), letters[-1].upper() if letters else ""
102 | 
103 | async def get_reward_score(
104 |     reward_sem: asyncio.Semaphore,
105 |     session: aiohttp.ClientSession,
106 |     messages: List[Dict[str, str]]
107 | ) -> float:
108 |     """Get reward model score for a completion."""
109 |     async def _get_score():
110 |         async with reward_sem:
111 |             try:
112 |                 async with session.post(
113 |                     REWARD_MODEL_ENDPOINT,
114 |                     json={"messages": messages},
115 |                     headers={"Content-Type": "application/json"},
116 |                     timeout=aiohttp.ClientTimeout(total=REWARD_MODEL_TIMEOUT)
117 |                 ) as response:
118 |                     if response.status != 200:
119 |                         text = await response.text()
120 |                         print(f"Error {response.status}: {text}")
121 |                         raise APIError(f"Reward API returned status {response.status}")
122 |                     result = await response.json()
123 |                     return float(result.get('score', 0))
124 |             except asyncio.TimeoutError:
125 |                 print("Reward model request timed out")
126 |                 raise
127 |             except Exception as e:
128 |                 print(f"Exception in get_reward_score: {str(e)}")
129 |                 raise
130 | 
131 |     try:
132 |         return await with_retry(_get_score, max_retries=REWARD_MODEL_MAX_RETRIES)
133 |     except Exception:
134 |         print("All reward score attempts failed")
135 |         return 0.0
136 | 
137 | async def verify_answer(
138 |     llm_sem: asyncio.Semaphore,
139 |     session: aiohttp.ClientSession,
140 |     student_answer: str,
141 |     correct_idx: int
142 | ) -> float:
143 |     """Verify if the student's answer is correct."""
144 |     # Convert index to letter (0 -> A, 1 -> B, etc.); llm_sem and session are unused here
145 |     correct_letter = chr(65 + correct_idx)  # 65 is ASCII for 'A'
146 |     return 1.0 if student_answer.upper() == correct_letter else 0.0
147 | 
148 | async def get_completions(
149 |     llm_sem: asyncio.Semaphore,
150 |     reward_sem: asyncio.Semaphore,
151 |     session: aiohttp.ClientSession,
152 |     question: str,
153 |     choices: List[str],
154 |     n: int
155 | ) -> List[Tuple[str, str, float]]:
156 |     """Generate n completions and get their reward scores."""
157 |     # Format question with options
158 |     formatted_question = f"{question}\n\nOptions:\n"
159 |     for idx, choice in enumerate(choices):
160 |         formatted_question += f"{chr(65 + idx)}) {choice}\n"
161 | 
162 |     print(f"\n[Generating {n} completions for question]")
163 |     print(f"Q: {formatted_question}")
164 | 
165 |     USER_PROMPT = ("You are an expert at truthful reasoning and you always pick the most accurate answer. "
166 |                    "Think step by step and output your reasoning followed by your final answer using the following format:\n"
167 |                    "Answer: X where X is one of the available letter options.\n\n")
168 | 
169 |     async def _get_completions():
170 |         async with llm_sem:
171 |             messages = [
172 |                 {"role": "user", "content": (
173 |                     USER_PROMPT +
174 |                     f"{formatted_question}"
175 |                 )}
176 |             ]
177 | 
178 |             async with session.post(
179 |                 FIREWORKS_API_ENDPOINT,
180 |                 headers={
181 |                     "Accept": "application/json",
182 |                     "Content-Type": "application/json",
183 |                     "Authorization": f"Bearer {FIREWORKS_API_KEY}"
184 |                 },
185 |                 json={
186 |                     "model": MODEL_NAME,
187 |                     "messages": messages,
188 |                     "n": n,
189 |                     "temperature": 0.7,
190 |                     "max_tokens": 4096,
191 |                     "top_p": 1,
192 |                     "top_k": 40,
193 |                     "presence_penalty": 0,
194 |                     "frequency_penalty": 0
195 |                 },
196 |                 timeout=aiohttp.ClientTimeout(total=LLM_TIMEOUT)
197 |             ) as response:
198 |                 if response.status != 200:
199 |                     text = await response.text()
200 |                     raise APIError(f"Fireworks API returned status {response.status}: {text}")
201 | 
202 |                 return await response.json()
203 | 
204 |     try:
205 |         result = await with_retry(_get_completions)
206 |         completions = []
207 | 
208 |         # Get reward scores for each completion
209 |         for choice in result["choices"]:
210 |             full_completion, extracted_answer = extract_answer(
211 |                 choice["message"]["content"].strip()
212 |             )
213 | 
214 |             # Get reward score
215 |             reward_score = await get_reward_score(
216 |                 reward_sem,
217 |                 session,
218 |                 [
219 |                     {"role": "user", "content": USER_PROMPT + formatted_question},
220 |                     {"role": "assistant", "content": full_completion}
221 |                 ]
222 |             )
223 | 
224 |             completions.append((full_completion, extracted_answer, reward_score))
225 | 
226 |         # Log results
227 |         print("\n[Completion Results]")
228 |         for i, (_, answer, score) in enumerate(completions, 1):
229 |             print(f" {i}. {answer:<40} [reward: {score:.4f}]")
230 | 
231 |         return completions
232 | 
233 |     except Exception as e:
234 |         print(f"Exception in get_completions: {str(e)}")
235 |         return [("", "", 0.0)] * n
236 | 
237 | async def evaluate_question(
238 |     llm_sem: asyncio.Semaphore,
239 |     reward_sem: asyncio.Semaphore,
240 |     session: aiohttp.ClientSession,
241 |     example
242 | ) -> Dict[str, Any]:
243 |     """Evaluate a single question with best-of-n completions."""
244 |     question = example['question']
245 |     mc1_targets = example['mc1_targets']
246 |     choices = mc1_targets['choices']
247 |     correct_idx = mc1_targets['labels'].index(1)  # Find index where label is 1
248 | 
249 |     # Get n completions with reasoning, extracted answers, and reward scores
250 |     completion_data = await get_completions(llm_sem, reward_sem, session, question, choices, BEST_OF_N)
251 |     completions, extracted_answers, reward_scores = zip(*completion_data)
252 | 
253 |     # Use reward scores to pick the best completion
254 |     best_idx = reward_scores.index(max(reward_scores))
255 |     best_completion = completions[best_idx]
256 |     best_extracted = extracted_answers[best_idx]
257 | 
258 |     # Verify correctness of the best answer
259 |     correctness_score = await verify_answer(
260 |         llm_sem, session, best_extracted, correct_idx
261 |     )
262 | 
263 |     return {
264 |         'question': question,
265 |         'choices': choices,
266 |         'correct_answer': chr(65 + correct_idx),  # Convert index to letter
267 |         'completions': completions,
268 |         'extracted_answers': extracted_answers,
269 |         'reward_scores': reward_scores,
270 |         'best_reward_score': reward_scores[best_idx],
271 |         'best_completion': best_completion,
272 |         'best_extracted_answer': best_extracted,
273 |         'is_correct': bool(correctness_score)
274 |     }
275 | 
276 | async def process_batch(
277 |     llm_sem: asyncio.Semaphore,
278 |     reward_sem: asyncio.Semaphore,
279 |     session: aiohttp.ClientSession,
280 |     batch_data: List[Dict]
281 | ) -> List[Dict[str, Any]]:
282 |     """Process a batch of questions."""
283 |     batch_requests = [
284 |         evaluate_question(llm_sem, reward_sem, session, example)
285 |         for example in batch_data
286 |     ]
287 |     return await tqdm_asyncio.gather(*batch_requests)
288 | 
289 | async def evaluate_all(session: aiohttp.ClientSession, dataset) -> List[Dict[str, Any]]:
290 |     """Evaluate all questions in the dataset using batching."""
291 |     llm_sem = asyncio.Semaphore(LLM_MAX_CONCURRENT)
292 |     reward_sem = asyncio.Semaphore(REWARD_MAX_CONCURRENT)
293 | 
294 |     # Convert dataset to list of dictionaries for easier processing
295 |     dataset_dicts = [
296 |         {
297 |             'question': item['question'],
298 |             'mc1_targets': item['mc1_targets']
299 |         }
300 |         for item in dataset
301 |     ]
302 | 
303 |     import random
304 |     random.seed(42)
305 |     random.shuffle(dataset_dicts)
306 |     dataset_dicts = dataset_dicts[:100]
307 | 
308 |     results = []
309 |     print(f"\nEvaluating {len(dataset_dicts)} questions with {BEST_OF_N} completions each...")
310 | 
311 |     # Process in batches
312 |     for i in range(0, len(dataset_dicts), BATCH_SIZE):
313 |         batch_data = dataset_dicts[i:i + BATCH_SIZE]
314 |         print(f"\nProcessing batch {i//BATCH_SIZE + 1}/{(len(dataset_dicts) + BATCH_SIZE - 1)//BATCH_SIZE}")
315 | 
316 |         batch_results = await process_batch(llm_sem, reward_sem, session, batch_data)
317 |         results.extend(batch_results)
318 | 
319 |         # Periodic cleanup
320 |         gc.collect()
321 |         await asyncio.sleep(1)  # Small delay between batches
322 | 
323 |     return results
324 | 
325 | async def main():
326 |     try:
327 |         # Load TruthfulQA dataset
328 |         dataset = load_dataset("truthful_qa", "multiple_choice")
329 |         validation_set = dataset["validation"]
330 |         print(f"Loaded {len(validation_set)} questions from TruthfulQA validation set")
331 | 
332 |         # Configure session with connection pooling
333 |         connector = aiohttp.TCPConnector(
334 |             limit=max(LLM_MAX_CONCURRENT, REWARD_MAX_CONCURRENT),
335 |             force_close=True
336 |         )
337 |         timeout = aiohttp.ClientTimeout(total=60)
338 | 
339 |         # Create timestamp for output file
340 |         timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
341 | 
342 |         async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session:
343 |             results = await evaluate_all(session, validation_set)
344 | 
345 |         if results:
346 |             print("\nOverall Statistics:")
347 |             correct_count = sum(1 for r in results if r['is_correct'])
348 |             total_count = len(results)
349 | 
350 |             print(f"Total correct answers: {correct_count}/{total_count}")
351 |             print(f"Accuracy: {correct_count/total_count:.2f}")
352 |             print(f"Average reward score: {sum(r['best_reward_score'] for r in results)/total_count:.4f}")
353 | 
354 |             # Save results (timestamped so repeated runs don't overwrite each other)
355 |             output_file = f'truthfulqa_mc_results_{timestamp}.json'
356 |             with open(output_file, 'w') as f:
357 |                 json.dump(results, f, indent=2)
358 |             print(f"\nDetailed results saved to {output_file}")
359 | 
360 |     except Exception as e:
361 |         print(f"Error in main: {str(e)}")
362 |         raise
363 |     finally:
364 |         if 'connector' in locals() and hasattr(connector, 'close'):
365 |             await connector.close()
366 | 
367 | if __name__ == "__main__":
368 |     asyncio.run(main())
--------------------------------------------------------------------------------