├── asset └── Tiny_po.png ├── set.sh ├── LICENSE ├── README.md ├── test.py ├── gsm8k_data.py ├── train.py ├── GRPO.ipynb └── grpo.py /asset/Tiny_po.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fangyuan-ksgk/Tiny-GRPO/HEAD/asset/Tiny_po.png -------------------------------------------------------------------------------- /set.sh: -------------------------------------------------------------------------------- 1 | pip install tf-keras 2 | pip install flash-attn 3 | pip install wandb 4 | pip install 'accelerate>=0.26.0' 5 | pip install transformers 6 | pip install datasets -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Fangyuan Yu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | ![Tiny_po](https://github.com/user-attachments/assets/82a6488e-9434-4192-a97c-0d4af4823f8d) 4 | 5 |
6 | 
7 | Minimal implementation of **Group Relative Policy Optimization (GRPO)** (DeepSeek) from scratch. No complicated file structure—just a **simple, hackable implementation** with a few scripts for a better understanding of the algorithm.
8 | 
9 | Inspired by the implementation by [@aburkov](https://github.com/aburkov). This implementation optimizes **memory usage** during training by:
10 | - Using **chunk-wise softmax operations**
11 | - Leveraging **mixed precision training**
12 | Together, these techniques reduce memory usage by **50%**, enabling GRPO to run on a single GPU while achieving strong results on **math datasets**.
13 | 
14 | Set up the environment:
15 | ```bash
16 | bash set.sh
17 | ```
18 | 
19 | Train GRPO on the GSM8K dataset (Qwen2.5-1.5B-Instruct):
20 | ```bash
21 | python train.py
22 | ```
23 | 
24 | Test the model output:
25 | ```bash
26 | python test.py
27 | ```
28 | 
29 | 🤝 Contributing
30 | Feel free to submit issues, PRs, or suggestions to improve the implementation!
31 | 
32 | ⚡ Acknowledgments
33 | Inspired by @aburkov's work in The LM Book (https://github.com/aburkov/theLMbook).
34 | 
35 | 
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoTokenizer, AutoModelForCausalLM
2 | import torch
3 | import os
4 | from gsm8k_data import SYSTEM_PROMPT, build_prompt, extract_gsm8k_answer
5 | 
6 | def main():
7 |     """
8 |     Main function to load the fine-tuned model and test it on example math problems.
9 | 
10 |     Explanation:
11 |     1. Determines the device (GPU if available, otherwise CPU).
12 |     2. Loads the fine-tuned model and tokenizer from the saved path.
13 |     3. Tests the model on predefined math problems.
14 |     4. Formats the prompt using the same SYSTEM_PROMPT and build_prompt function as training.
15 |     5. Generates and displays responses for each test prompt.
16 |     """
17 |     # Determine the device: use GPU if available, else fall back to CPU.
18 |     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19 |     print(f"Using device: {device}")
20 | 
21 |     # Load the saved model and tokenizer
22 |     saved_model_path = "grpo_finetuned_model"
23 | 
24 | 
25 |     # Load the model
26 |     loaded_model = AutoModelForCausalLM.from_pretrained(
27 |         saved_model_path,
28 |         torch_dtype=torch.bfloat16,
29 |         device_map="auto"
30 |     )
31 | 
32 |     loaded_tokenizer = AutoTokenizer.from_pretrained(saved_model_path)
33 |     loaded_tokenizer.pad_token = loaded_tokenizer.eos_token
34 | 
35 |     # Define test prompts
36 |     prompts_to_test = [
37 |         "How much is 1+1?",
38 |         "I have 3 apples, my friend eats one and I give 2 to my sister, how many apples do I have now?",
39 |         "Solve the equation 6x + 4 = 40"
40 |     ]
41 | 
42 |     # Test each prompt
43 |     for prompt in prompts_to_test:
44 |         # Prepare the prompt using the same format as during training
45 |         test_messages = [
46 |             {"role": "system", "content": SYSTEM_PROMPT},
47 |             {"role": "user", "content": prompt}
48 |         ]
49 |         test_prompt = build_prompt(test_messages)
50 | 
51 |         # Tokenize the prompt and generate a response
52 |         test_input_ids = loaded_tokenizer.encode(test_prompt, return_tensors="pt").to(device)
53 | 
54 |         # Generate response with similar parameters to those used in training
55 |         with torch.no_grad():
56 |             test_output_ids = loaded_model.generate(
57 |                 test_input_ids,
58 |                 max_new_tokens=400,
59 |                 temperature=0.7,
60 |                 num_return_sequences=1,
61 |                 pad_token_id=loaded_tokenizer.pad_token_id,
62 |                 eos_token_id=loaded_tokenizer.eos_token_id,
63 |                 do_sample=True,
64 |                 early_stopping=False
65 |             )
66 | 
67 |         test_response = loaded_tokenizer.decode(test_output_ids[0], skip_special_tokens=True)
68 | 
69 |         # Print the test prompt and the model's response
70 |         print("\nTest Prompt:")
71 |         print(test_prompt)
72 |         print("\nModel Response:")
73 |         print(test_response)
74 | 
75 |         # Extract and display the answer part for easier evaluation
76 |         try:
77 |             extracted_answer = extract_gsm8k_answer(test_response)
78 |             print("\nExtracted Answer:")
79 |             print(extracted_answer)
80 |             print("-" * 50)
81 |         except Exception as e:
82 |             print(f"\nFailed to extract answer: {e}")
83 |             print("-" * 50)
84 | 
85 | if __name__ == "__main__":
86 |     main()
--------------------------------------------------------------------------------
/gsm8k_data.py:
--------------------------------------------------------------------------------
1 | # util for data preparation on the GSM8K dataset
2 | 
3 | import re
4 | from datasets import load_dataset
5 | 
6 | SYSTEM_PROMPT = """
7 | Respond in the following format:
8 | <reasoning>
9 | ...
10 | </reasoning>
11 | <answer>
12 | ...
13 | </answer>
14 | """
15 | 
16 | def extract_gsm8k_answer(text):
17 |     """
18 |     Extracts the value from the last <answer>...</answer> tag in the text.
19 |     - deliberately strict: only this exact tag format is accepted by the extractor
20 |     """
21 |     # Split on <answer> and take everything after the last occurrence
22 |     parts = text.split("<answer>")
23 |     if len(parts) < 2:  # No <answer> tag found
24 |         return None
25 |     last_part = parts[-1]
26 | 
27 |     # Extract content up to </answer>
28 |     if "</answer>" not in last_part:
29 |         return None
30 |     answer = last_part.split("</answer>")[0].strip()
31 |     return None if answer == "..." else answer
32 | 
33 | 
34 | def extract_gsm8k_answer_from_dataset(text):
35 |     """
36 |     Extracts the answer from the GSM8K dataset examples.
37 |     - GSM8K reference answers give the final result after the '####' marker
38 |     """
39 |     if "####" not in text:
40 |         return None
41 |     return text.split("####")[1].strip()
42 | 
43 | 
44 | def build_prompt(messages):
45 |     """Simple newline concatenation of all message contents, without any role identifiers."""
46 |     return "\n".join([msg["content"].strip() for msg in messages])
47 | 
48 | 
49 | def prepare_dataset(split="train"):
50 |     """Load and prepare the GSM8K dataset for training with string prompts."""
51 |     data = load_dataset('openai/gsm8k', 'main')[split]
52 |     formatted_data = []
53 |     for example in data:
54 |         # Convert list of messages to a single string prompt.
55 |         prompt_str = build_prompt([
56 |             {"role": "system", "content": SYSTEM_PROMPT},
57 |             {"role": "user", "content": example["question"]}
58 |         ])
59 |         formatted_example = {
60 |             "prompt": prompt_str,  # Now a string rather than a list.
61 |             "answer": extract_gsm8k_answer_from_dataset(example["answer"])
62 |         }
63 |         formatted_data.append(formatted_example)
64 |     return formatted_data
65 | 
66 | 
67 | def extract_last_number(text):
68 |     text = text.replace('$', '').replace('%', '')
69 |     pattern = r'(?:^|\s|=)\s*(-?\d*\.?\d+)\s*$'
70 |     match = re.search(pattern, text)
71 |     return float(match.group(1)) if match else None
72 | 
73 | 
74 | def extract_single_number(text):
75 |     numbers = re.findall(r'-?\d*\.?\d+', text)
76 |     return float(numbers[0]) if len(numbers) == 1 else None
77 | 
78 | 
79 | def gsm8k_metric(predicted: str, expected: str) -> tuple[bool, float]:
80 |     if predicted == expected:  # Exact match
81 |         is_correct = True
82 |         reward = 2.0
83 |     else:
84 |         # Try single number matching
85 |         pred_num = extract_single_number(str(predicted))
86 |         exp_num = extract_single_number(str(expected))
87 |         if pred_num is not None and exp_num is not None and pred_num == exp_num:
88 |             is_correct = True
89 |             reward = 1.5
90 |         else:
91 |             # the way I view this, it's just a metric-based evaluation function
92 |             # Try last number matching
93 |             pred_num = extract_last_number(str(predicted))
94 |             exp_num = extract_last_number(str(expected))
95 |             is_correct = (pred_num is not None and exp_num is not None and
96 |                           pred_num == exp_num)
97 |             reward = 0.0
98 |     return is_correct, reward
99 | 
100 | 
101 | def functional_reward_fn(completions, answers):
102 |     responses = [completion[0]['content'] for completion in completions]
103 |     extracted = [extract_gsm8k_answer(response) for response in responses]
104 |     rewards = []
105 |     for pred, exp in zip(extracted, answers):
106 |         is_correct, reward = gsm8k_metric(pred, exp)
107 |         rewards.append(reward)
108 |     # count number of words in the response (not tokens, not characters, words) (interesting...)
109 |     completion_lengths = [len(response.split()) for response in responses]
110 |     return rewards
111 | 
112 | 
113 | def structural_reward_fn(completions):
114 |     responses = [completion[0]['content'] for completion in completions]
115 |     rewards = []
116 |     format_scores = []
117 |     for response in responses:
118 |         score = 0.0
119 |         if "<reasoning>" in response: score += 0.2
120 |         if "</reasoning>" in response: score += 0.2
121 |         if "<answer>" in response: score += 0.2
122 |         if "</answer>" in response: score += 0.2
123 |         rewards.append(score)
124 |         format_scores.append(score)
125 |     return rewards
126 | 
127 | 
128 | # trial and error here, for better weighting between the reward components ...
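# How the two components above combine (see reward_fn below):
#   functional_reward_fn gives 2.0 for an exact match of the extracted <answer>
#   content, 1.5 for a numerically equivalent answer, and 0.0 otherwise;
#   structural_reward_fn adds 0.2 per tag, up to 0.8 when <reasoning>,
#   </reasoning>, <answer>, and </answer> are all present.
#   e.g. a correct, fully formatted completion scores 2.0 + 0.8 = 2.8, while a
#   well-formatted but wrong answer scores 0.0 + 0.8 = 0.8.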
129 | def reward_fn(completions, answers): 130 | functional_rewards = functional_reward_fn(completions, answers) 131 | structural_rewards = structural_reward_fn(completions) 132 | 133 | combined_rewards = [] 134 | for f_score, s_score in zip(functional_rewards, structural_rewards): 135 | # Correctness score range: 0.0 to 2.0 136 | # Format score range: 0.0 to 0.8 137 | # Total range: 0.0 to 2.8 138 | combined_rewards.append(f_score + s_score) 139 | 140 | return combined_rewards -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from gsm8k_data import extract_gsm8k_answer, gsm8k_metric 3 | import os 4 | from contextlib import nullcontext 5 | 6 | 7 | def setup_training_environment(device_name, dtype = "bfloat16"): 8 | """set up mixed precision (for memory optimization)""" 9 | 10 | # Set up random seed 11 | torch.backends.cuda.matmul.allow_tf32 = True 12 | torch.backends.cudnn.allow_tf32 = True 13 | 14 | # Set up mixed precision 15 | device_type = 'cuda' if 'cuda' in device_name else 'cpu' 16 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] 17 | ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype) 18 | 19 | return { 20 | 'device': device_name, 21 | 'ctx': ctx, 22 | 'device_type': device_type, 23 | } 24 | 25 | env = setup_training_environment("cuda" if torch.cuda.is_available() else "cpu") 26 | 27 | def evaluate_model(model, tokenizer, eval_examples, device, env, max_completion_length): 28 | model.eval() 29 | correct = 0 30 | total = len(eval_examples) 31 | print("\n" + "="*50) 32 | print("EVALUATION ON", total, "EXAMPLES") 33 | print("="*50) 34 | 35 | for example in eval_examples: 36 | # Get the prompt and expected answer 37 | full_prompt = example["prompt"] 38 | expected = example["answer"] 39 | 40 | # Tokenize and generate response 41 | inputs = tokenizer.encode(full_prompt, return_tensors="pt").to(device) 42 | with torch.no_grad(): 43 | with env['ctx']: 44 | outputs = model.generate( 45 | inputs, 46 | max_new_tokens=max_completion_length, 47 | temperature=0.7, 48 | num_return_sequences=1, 49 | pad_token_id=tokenizer.pad_token_id, 50 | eos_token_id=tokenizer.eos_token_id, 51 | forced_eos_token_id=tokenizer.eos_token_id, 52 | early_stopping=False, 53 | ) # 'forward generation --> RL on reward' | can we do the same with pre-training ? 54 | response = tokenizer.decode(outputs[0], skip_special_tokens=True) 55 | 56 | try: 57 | # 1. extract functional 58 | predicted = extract_gsm8k_answer(response) 59 | # 2. 
metric-based reward fn 60 | is_correct, _ = gsm8k_metric(predicted, expected) 61 | 62 | # Update counter for correct answers 63 | if is_correct: 64 | correct += 1 65 | 66 | # Print evaluation details 67 | print("\nPrompt:") 68 | print(full_prompt) 69 | print("\nExpected Answer:") 70 | print(expected) 71 | print("\nExtracted Answer:") 72 | print(predicted) 73 | print("\nFull Generated Response:") 74 | print(response) 75 | print("\nCorrect:", "✓" if is_correct else "✗") 76 | print("-"*50) 77 | 78 | except Exception as e: 79 | print("\nFailed to parse model output for prompt:") 80 | print(full_prompt) 81 | print("Error:", e) 82 | print("-"*50) 83 | 84 | # Calculate and print final accuracy 85 | accuracy = (correct / total) * 100 86 | print(f"\nAccuracy: {accuracy:.2f}% ({correct}/{total})") 87 | print("="*50) 88 | 89 | # Return model to training mode 90 | model.train() 91 | return accuracy 92 | 93 | 94 | from transformers import AutoModelForCausalLM, AutoTokenizer 95 | from gsm8k_data import prepare_dataset, reward_fn 96 | import random 97 | from grpo import train_with_grpo, optimize_model_memory 98 | import os 99 | import wandb 100 | 101 | 102 | def main(args): 103 | device = torch.device("cuda:0" if torch.cuda.is_available() else "mps") 104 | print(f"Using primary device: {device}") 105 | 106 | model_name = args.model_name 107 | output_dir = args.output_dir 108 | 109 | print("Downloading model...") 110 | model = AutoModelForCausalLM.from_pretrained( 111 | model_name, 112 | torch_dtype=torch.bfloat16, 113 | device_map="auto" 114 | ) 115 | print("Model downloaded") 116 | 117 | tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left") 118 | tokenizer.pad_token = tokenizer.eos_token 119 | model.config.pad_token_id = tokenizer.eos_token_id 120 | model.config.eos_token_id = tokenizer.eos_token_id 121 | 122 | all_data = prepare_dataset("train") 123 | random.shuffle(all_data) 124 | size_of_eval_data = 30 # change to a smaller value to save time or to a larger number for a more reliable estimate 125 | eval_data = all_data[:size_of_eval_data] 126 | train_data = all_data[size_of_eval_data:] 127 | 128 | model = optimize_model_memory(model) 129 | 130 | print("\nInitial model evaluation before finetuning:") 131 | pre_grpo_accuracy = evaluate_model(model, tokenizer, eval_data, device, env, args.max_completion_length) 132 | print(f"Pre-GRPO Accuracy: {pre_grpo_accuracy:.2f}%") 133 | 134 | wandb.init(project=os.environ["WANDB_PROJECT"], reinit=True) 135 | print("Weights & Biases initialized.") 136 | 137 | model = train_with_grpo( 138 | model=model, 139 | tokenizer=tokenizer, 140 | train_data=train_data, 141 | reward_function=reward_fn, 142 | num_iterations=args.num_iterations, 143 | num_steps=args.num_steps, 144 | batch_size=args.batch_size, 145 | num_generations=args.num_generations, 146 | max_completion_length=args.max_completion_length, 147 | beta=args.beta, 148 | learning_rate=args.learning_rate, 149 | mu=args.mu, 150 | epsilon=args.epsilon, 151 | env=env 152 | ) 153 | 154 | wandb.finish() 155 | print("Training completed and wandb run finished.") 156 | 157 | print("\nFinal model evaluation after GRPO RL fine-tuning:") 158 | post_grpo_accuracy = evaluate_model(model, tokenizer, eval_data, device, env, args.max_completion_length) 159 | print(f"Post-GRPO Accuracy: {post_grpo_accuracy:.2f}%") 160 | 161 | print("\nSaving GRPO fine-tuned model...") 162 | model.save_pretrained("grpo_finetuned_model") 163 | tokenizer.save_pretrained("grpo_finetuned_model") 164 | 165 | 166 | if __name__ == 
"__main__": 167 | import argparse 168 | 169 | parser = argparse.ArgumentParser(description="Train with GRPO on GSM8K") 170 | parser.add_argument("--model_name", type=str, default="Qwen/Qwen2.5-1.5B-Instruct", 171 | help="Model name or path") 172 | parser.add_argument("--output_dir", type=str, default="./output", 173 | help="Directory to save the model") 174 | parser.add_argument("--num_iterations", type=int, default=2, 175 | help="Number of iterations") 176 | parser.add_argument("--num_steps", type=int, default=400, 177 | help="Number of steps per iteration") 178 | parser.add_argument("--batch_size", type=int, default=6, 179 | help="Batch size") 180 | parser.add_argument("--num_generations", type=int, default=8, 181 | help="Number of generations per example") 182 | parser.add_argument("--max_completion_length", type=int, default=512, 183 | help="Maximum completion length") 184 | parser.add_argument("--beta", type=float, default=0.1, 185 | help="Beta parameter for GRPO") 186 | parser.add_argument("--learning_rate", type=float, default=5e-6, 187 | help="Learning rate") 188 | parser.add_argument("--mu", type=float, default=6, 189 | help="Mu parameter for GRPO") 190 | parser.add_argument("--epsilon", type=float, default=0.1, 191 | help="Epsilon parameter for GRPO") 192 | 193 | args = parser.parse_args() 194 | main(args) -------------------------------------------------------------------------------- /GRPO.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "id": "ec5d0e27-1bf4-4526-b0f1-8239c05525e7", 7 | "metadata": { 8 | "collapsed": true, 9 | "jupyter": { 10 | "outputs_hidden": true 11 | } 12 | }, 13 | "outputs": [], 14 | "source": [ 15 | "import torch, os\n", 16 | "from transformers import AutoModelForCausalLM, AutoTokenizer\n", 17 | "from datasets import load_dataset\n", 18 | "from grpo import set_random_seed\n", 19 | "\n", 20 | "# Call the function to set random seed for reproducibility\n", 21 | "set_random_seed(42)\n", 22 | "\n", 23 | "# Set environment variables for Weights & Biases (wandb) logging\n", 24 | "os.environ[\"WANDB_API_KEY\"] = \"YOUR WANDB API KEY\"\n", 25 | "os.environ[\"WANDB_PROJECT\"] = \"YOUR WANDB PROJECT NAME\"" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 3, 31 | "id": "3e35173a-b617-4dde-ba8c-d9ad20ea51ea", 32 | "metadata": {}, 33 | "outputs": [], 34 | "source": [ 35 | "from gsm8k_data import gsm8k_metric, extract_gsm8k_answer\n", 36 | "\n", 37 | "def evaluate_model(model, tokenizer, eval_examples, device):\n", 38 | "\n", 39 | " model.eval()\n", 40 | " correct = 0\n", 41 | " total = len(eval_examples)\n", 42 | " print(\"\\n\" + \"=\"*50)\n", 43 | " print(\"EVALUATION ON\", total, \"EXAMPLES\")\n", 44 | " print(\"=\"*50)\n", 45 | "\n", 46 | " for example in eval_examples:\n", 47 | " # Get the prompt and expected answer\n", 48 | " full_prompt = example[\"prompt\"]\n", 49 | " expected = example[\"answer\"]\n", 50 | "\n", 51 | " # Tokenize and generate response\n", 52 | " inputs = tokenizer.encode(full_prompt, return_tensors=\"pt\").to(device)\n", 53 | " with torch.no_grad():\n", 54 | " outputs = model.generate(\n", 55 | " inputs,\n", 56 | " max_new_tokens=512,\n", 57 | " temperature=0.7,\n", 58 | " num_return_sequences=1,\n", 59 | " pad_token_id=tokenizer.pad_token_id,\n", 60 | " eos_token_id=tokenizer.eos_token_id,\n", 61 | " forced_eos_token_id=tokenizer.eos_token_id,\n", 62 | " early_stopping=False,\n", 63 | " ) # 
'forward generation --> RL on reward' | can we do the same with pre-training ?\n", 64 | " response = tokenizer.decode(outputs[0], skip_special_tokens=True)\n", 65 | "\n", 66 | " try: \n", 67 | " # 1. extract functional \n", 68 | " predicted = extract_gsm8k_answer(response)\n", 69 | " # 2. metric-based reward fn\n", 70 | " is_correct = gsm8k_metric(predicted, expected)\n", 71 | "\n", 72 | " # Update counter for correct answers\n", 73 | " if is_correct:\n", 74 | " correct += 1\n", 75 | "\n", 76 | " # Print evaluation details\n", 77 | " print(\"\\nPrompt:\")\n", 78 | " print(full_prompt)\n", 79 | " print(\"\\nExpected Answer:\")\n", 80 | " print(expected)\n", 81 | " print(\"\\nExtracted Answer:\")\n", 82 | " print(predicted)\n", 83 | " print(\"\\nFull Generated Response:\")\n", 84 | " print(response)\n", 85 | " print(\"\\nCorrect:\", \"✓\" if is_correct else \"✗\")\n", 86 | " print(\"-\"*50)\n", 87 | "\n", 88 | " except Exception as e:\n", 89 | " print(\"\\nFailed to parse model output for prompt:\")\n", 90 | " print(full_prompt)\n", 91 | " print(\"Error:\", e)\n", 92 | " print(\"-\"*50)\n", 93 | "\n", 94 | " # Calculate and print final accuracy\n", 95 | " accuracy = (correct / total) * 100\n", 96 | " print(f\"\\nAccuracy: {accuracy:.2f}% ({correct}/{total})\")\n", 97 | " print(\"=\"*50)\n", 98 | "\n", 99 | " # Return model to training mode\n", 100 | " model.train()\n", 101 | " return accuracy" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": null, 107 | "id": "41d10965-282c-42c6-a005-487c1ba6a8d8", 108 | "metadata": { 109 | "collapsed": true, 110 | "jupyter": { 111 | "outputs_hidden": true 112 | } 113 | }, 114 | "outputs": [], 115 | "source": [ 116 | "# Main execution\n", 117 | "from gsm8k_data import prepare_dataset, reward_fn\n", 118 | "from grpo import *\n", 119 | "\n", 120 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", 121 | "print(f\"Using primary device: {device}\")\n", 122 | "\n", 123 | "model_name = \"Qwen/Qwen2.5-1.5B-Instruct\"\n", 124 | "output_dir = \"math_solver_model\"\n", 125 | "\n", 126 | "print(\"Downloading model...\")\n", 127 | "model = AutoModelForCausalLM.from_pretrained(\n", 128 | " model_name,\n", 129 | " torch_dtype=torch.bfloat16,\n", 130 | " device_map=\"auto\"\n", 131 | ")\n", 132 | "print(\"Model downloaded\")\n", 133 | "\n", 134 | "tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side=\"left\")\n", 135 | "tokenizer.pad_token = tokenizer.eos_token\n", 136 | "model.config.pad_token_id = tokenizer.eos_token_id\n", 137 | "model.config.eos_token_id = tokenizer.eos_token_id\n", 138 | "\n", 139 | "num_gpus = torch.cuda.device_count()\n", 140 | "print(f\"Detected {num_gpus} GPUs\")\n", 141 | "device_ids = list(range(num_gpus)) if num_gpus > 1 else None\n", 142 | "\n", 143 | "all_data = prepare_dataset(\"train\")\n", 144 | "random.shuffle(all_data)\n", 145 | "size_of_eval_data = 30 # change to a smaller value to save time or to a larger number for a more reliable estimate\n", 146 | "eval_data = all_data[:size_of_eval_data]\n", 147 | "train_data = all_data[size_of_eval_data:]\n", 148 | "\n", 149 | "print(\"\\nInitial model evaluation before finetuning:\")\n", 150 | "pre_grpo_accuracy = evaluate_model(model, tokenizer, eval_data, device)\n", 151 | "print(f\"Pre-GRPO Accuracy: {pre_grpo_accuracy:.2f}%\")\n", 152 | "\n", 153 | "model = optimize_model_memory(model)\n", 154 | "\n", 155 | "print(\"\\nStarting RL fine-tuning using GRPO...\")\n", 156 | "# This config was tested on a 8xA100 node, where 
each A100 is has 80GB of VRAM\n", 157 | "training_config = {\n", 158 | " 'num_iterations': 1,\n", 159 | " 'num_steps': 500,\n", 160 | " 'batch_size': 7, # reduce if you have fewer GPUs\n", 161 | " 'num_generations': 12, # reduce if you have GPUs with less VRAM\n", 162 | " 'max_completion_length': 400, # reduce if you have GPUs with less VRAM\n", 163 | " 'beta': 0.04,\n", 164 | " 'learning_rate': 5e-6,\n", 165 | " 'mu': 1,\n", 166 | " 'epsilon': 0.1\n", 167 | "}\n", 168 | "\n", 169 | "# Initialize Weights & Biases\n", 170 | "wandb.init(project=os.environ[\"WANDB_PROJECT\"], reinit=True)\n", 171 | "print(\"Weights & Biases initialized.\")\n", 172 | "\n", 173 | "model = train_with_grpo(\n", 174 | " model=model,\n", 175 | " tokenizer=tokenizer,\n", 176 | " train_data=train_data,\n", 177 | " reward_function=reward_fn,\n", 178 | " device_ids=device_ids,\n", 179 | " **training_config\n", 180 | ")\n", 181 | "\n", 182 | "wandb.finish()\n", 183 | "print(\"Training completed and wandb run finished.\")\n", 184 | "\n", 185 | "print(\"\\nFinal model evaluation after GRPO RL fine-tuning:\")\n", 186 | "post_grpo_accuracy = evaluate_model(model, tokenizer, eval_data, device)\n", 187 | "print(f\"Post-GRPO Accuracy: {post_grpo_accuracy:.2f}%\")\n", 188 | "\n", 189 | "print(\"\\nSaving GRPO fine-tuned model...\")\n", 190 | "model.save_pretrained(\"grpo_finetuned_model\")\n", 191 | "tokenizer.save_pretrained(\"grpo_finetuned_model\")" 192 | ] 193 | } 194 | ], 195 | "metadata": { 196 | "kernelspec": { 197 | "display_name": "base", 198 | "language": "python", 199 | "name": "python3" 200 | }, 201 | "language_info": { 202 | "codemirror_mode": { 203 | "name": "ipython", 204 | "version": 3 205 | }, 206 | "file_extension": ".py", 207 | "mimetype": "text/x-python", 208 | "name": "python", 209 | "nbconvert_exporter": "python", 210 | "pygments_lexer": "ipython3", 211 | "version": "3.11.5" 212 | } 213 | }, 214 | "nbformat": 4, 215 | "nbformat_minor": 5 216 | } 217 | -------------------------------------------------------------------------------- /grpo.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import wandb 3 | import torch 4 | import random 5 | import numpy as np 6 | import torch.nn as nn 7 | from contextlib import nullcontext 8 | 9 | def setup_training_environment(device_name, dtype = "bfloat16"): 10 | """set up mixed precision (for memory optimization)""" 11 | 12 | # Set up random seed 13 | torch.backends.cuda.matmul.allow_tf32 = True 14 | torch.backends.cudnn.allow_tf32 = True 15 | 16 | # Set up mixed precision 17 | device_type = 'cuda' if 'cuda' in device_name else 'cpu' 18 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] 19 | ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype) 20 | 21 | return { 22 | 'device': device_name, 23 | 'ctx': ctx, 24 | 'device_type': device_type, 25 | } 26 | 27 | def get_memory_usage(): 28 | """Get current GPU memory usage in a human-readable format.""" 29 | if torch.cuda.is_available(): 30 | allocated = torch.cuda.memory_allocated() / (1024 ** 3) # Convert to GB 31 | reserved = torch.cuda.memory_reserved() / (1024 ** 3) # Convert to GB 32 | return f"GPU Memory: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved" 33 | return "CUDA not available" 34 | 35 | def set_random_seed(seed: int = 42): 36 | """Set the random seed for reproducibility across Python, NumPy, and PyTorch.""" 37 | # Set the seed for Python's built-in 
random module 38 | random.seed(seed) 39 | # Set the seed for NumPy 40 | np.random.seed(seed) 41 | # Set the seed for PyTorch 42 | torch.manual_seed(seed) 43 | if torch.cuda.is_available(): 44 | torch.cuda.manual_seed_all(seed) 45 | # Ensure deterministic behavior in cuDNN (may impact performance) 46 | torch.backends.cudnn.deterministic = True 47 | torch.backends.cudnn.benchmark = False 48 | 49 | def selective_log_softmax(logits, input_ids, chunk_size=64): 50 | """Process in chunks to reduce peak memory""" 51 | device = logits.device 52 | batch_size, seq_len, vocab_size = logits.shape 53 | log_probs = torch.zeros(batch_size, seq_len, device=device) 54 | 55 | for i in range(0, seq_len, chunk_size): 56 | end_idx = min(i + chunk_size, seq_len) 57 | chunk_logits = logits[:, i:end_idx, :] 58 | chunk_ids = input_ids[:, i:end_idx] 59 | chunk_log_probs = nn.functional.log_softmax(chunk_logits, dim=-1) 60 | # print(" - chunkwise softmax computation GPU memory: ", get_memory_usage()) 61 | log_probs[:, i:end_idx] = chunk_log_probs.gather( 62 | dim=-1, index=chunk_ids.unsqueeze(-1)).squeeze(-1) 63 | del chunk_logits, chunk_log_probs 64 | torch.cuda.empty_cache() 65 | return log_probs 66 | 67 | def compute_log_probs(model, input_ids, attention_mask, logits_to_keep, env, chunk_size=64): 68 | with env['ctx']: 69 | logits = model(input_ids=input_ids, attention_mask=attention_mask).logits[:, :-1, :] 70 | input_ids = input_ids[:, -logits_to_keep:] 71 | logits = logits[:, -logits_to_keep:, :] 72 | return selective_log_softmax(logits, input_ids, chunk_size) 73 | 74 | def create_completion_mask(completion_ids, eos_token_id): 75 | 76 | # ----- TBD: replace this with a less hacky solution ----- 77 | is_eos = completion_ids == eos_token_id 78 | # shape: (batch_size,) value: max length of completion 79 | eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=completion_ids.device) 80 | mask_exists = is_eos.any(dim=1) # operate on sequence with eos_token 81 | eos_idx[mask_exists] = is_eos.int().argmax(dim=1)[mask_exists] # first eos token index | hacky: argmax returns first index of largest value 82 | # ----- TBD: replace above with a less hacky solution ----- 83 | 84 | # create indices of tokens, then build non-end mask by comparing with eos_token index (first appearance indicate termination) 85 | sequence_indices = torch.arange(is_eos.size(1), device=completion_ids.device).expand(is_eos.size(0), -1) # (batch_size, max_length) 86 | return (sequence_indices < eos_idx.unsqueeze(1)).int() 87 | 88 | def generate_completions(model, tokenizer, prompts, device, num_generations=4, max_completion_length=32, env=None): 89 | """Generate multiple completions for each prompt, record completion mask (end-of-sequence)""" 90 | inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left") 91 | prompt_ids = inputs["input_ids"].to(device) 92 | prompt_mask = inputs["attention_mask"].to(device) 93 | prompt_length = prompt_ids.size(1) # Question: sequences within batch should have different length? This leads to wrong completion mask? 
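    # Note on the question above: tokenizer(...) is called with padding=True and
    # padding_side="left", so every prompt in the batch is left-padded to the same
    # length; slicing the generated outputs at prompt_length below should therefore
    # keep only completion tokens for each sequence.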
94 | prompt_ids = prompt_ids.repeat_interleave(num_generations, dim=0) 95 | prompt_mask = prompt_mask.repeat_interleave(num_generations, dim=0) 96 | with env['ctx']: 97 | outputs = model.generate( 98 | prompt_ids, 99 | attention_mask=prompt_mask, 100 | max_new_tokens=max_completion_length, 101 | do_sample=True, 102 | temperature=1.0, 103 | pad_token_id=tokenizer.pad_token_id, 104 | eos_token_id=tokenizer.eos_token_id, 105 | early_stopping=False 106 | ) # Important: same length outputs from this generate function 107 | completion_ids = outputs[:, prompt_length:] 108 | completion_mask = create_completion_mask(completion_ids, tokenizer.eos_token_id) # completion mask excludes eos_token and suffix tokens 109 | return prompt_ids, prompt_mask, completion_ids, completion_mask 110 | 111 | 112 | def generate_rollout_data(model, ref_model, tokenizer, batch_samples, device, num_generations, max_completion_length, env, chunk=64): 113 | """Generate responses and calculate log-probabilities of each response under two model""" 114 | prompts = [sample["prompt"] if isinstance(sample, dict) else sample[0] for sample in batch_samples] 115 | answers = [sample["answer"] if isinstance(sample, dict) else sample[1] for sample in batch_samples] 116 | with torch.no_grad(): 117 | prompt_ids, prompt_mask, completion_ids, completion_mask = generate_completions( 118 | model, tokenizer, prompts, device, num_generations, max_completion_length, env 119 | ) 120 | input_ids = torch.cat([prompt_ids, completion_ids], dim=1) 121 | attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) 122 | logits_to_keep = completion_ids.size(1) # same question: whether dim=1 is same across in-batch sequences 123 | # AR perplexity based RL (Issue #1) --> can we use 'skippy loss' 124 | # hidden-space RL (Idea #1) --> rollout on latent space? 125 | 126 | old_log_probs = compute_log_probs(model, input_ids, attention_mask, logits_to_keep, env, chunk) 127 | ref_log_probs = compute_log_probs(ref_model, input_ids, attention_mask, logits_to_keep, env, chunk) # ref is base model ? (also used in KL regularization I recall) 128 | 129 | # one-turn completion & reward assignment (Issue #2) 130 | formatted_completions = [[{'content': tokenizer.decode(ids, skip_special_tokens=True)}] for ids in completion_ids] 131 | repeated_prompts = [p for p in prompts for _ in range(num_generations)] 132 | repeated_answers = [a for a in answers for _ in range(num_generations)] 133 | return { 134 | "input_ids": input_ids, 135 | "attention_mask": attention_mask, 136 | "completion_mask": completion_mask, 137 | "old_log_probs": old_log_probs, 138 | "ref_log_probs": ref_log_probs, 139 | "formatted_completions": formatted_completions, 140 | "repeated_prompts": repeated_prompts, 141 | "repeated_answers": repeated_answers, 142 | "logits_to_keep": logits_to_keep, 143 | "batch_size": len(prompts), 144 | "num_generations": num_generations 145 | } 146 | 147 | # Issue #4. do we need to keep all 3 model? prev, curr, base? 
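# The loss below, in formula form (derived from the code that follows):
#   group-normalized advantage over the num_generations completions of each prompt:
#       A_i = (r_i - mean(r_group)) / (std(r_group) + 1e-4)
#       e.g. group rewards [2.8, 0.4, 0.4, 0.4] -> mean 1.0, std 1.2,
#            advantages ~ [+1.5, -0.5, -0.5, -0.5]
#   per-token clipped surrogate with ratio rho = exp(new_log_prob - old_log_prob):
#       min(rho * A, clip(rho, 1 - epsilon, 1 + epsilon) * A)
#   per-token KL penalty against the frozen reference model (the "k3" estimator):
#       exp(ref_lp - new_lp) - (ref_lp - new_lp) - 1
#   final loss: negative mean over completions of the completion-masked
#   token average of (surrogate - beta * KL).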
148 | def grpo_loss(model, ref_model, rollout_data, tokenizer, reward_function, 149 | device, beta=0.01, epsilon=0.2, env=None, chunk=64): 150 | """ 151 | GRPO loss function: 152 | - group normalized reward 153 | - conservative advantage clipping 154 | - kl regularization 155 | """ 156 | input_ids = rollout_data["input_ids"] 157 | attention_mask = rollout_data["attention_mask"] 158 | completion_mask = rollout_data["completion_mask"] 159 | logits_to_keep = rollout_data["logits_to_keep"] 160 | old_log_probs = rollout_data["old_log_probs"] 161 | ref_log_probs = rollout_data["ref_log_probs"] 162 | 163 | new_log_probs = compute_log_probs(model, input_ids, attention_mask, logits_to_keep, env, chunk) 164 | 165 | # ratio: between new and old 166 | ratio = torch.exp(new_log_probs - old_log_probs) 167 | rewards = torch.tensor( 168 | reward_function(completions=rollout_data["formatted_completions"], answers=rollout_data["repeated_answers"]), 169 | dtype=torch.float32, 170 | device=device 171 | ) 172 | 173 | 174 | batch_size = rollout_data["batch_size"] 175 | num_generations = rollout_data["num_generations"] 176 | 177 | # group refers to 'num_genereations' over each prompt, literally 'best-of-N', relative advantage is calculated here 178 | rewards = rewards.view(batch_size, num_generations) 179 | avg_reward = rewards.mean().item() 180 | mean_rewards = rewards.mean(dim=1).repeat_interleave(num_generations) 181 | std_rewards = rewards.std(dim=1).repeat_interleave(num_generations) 182 | advantages = ((rewards.view(-1) - mean_rewards) / (std_rewards + 1e-4)).unsqueeze(1) 183 | surrogate_loss = torch.min(ratio * advantages, torch.clamp(ratio, 1-epsilon, 1+epsilon) * advantages) 184 | kl = torch.exp(ref_log_probs - new_log_probs) - (ref_log_probs - new_log_probs) - 1 185 | per_token_loss = surrogate_loss - beta * kl 186 | loss = - ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() 187 | del kl, surrogate_loss, per_token_loss 188 | 189 | return loss, avg_reward 190 | 191 | 192 | def train_with_grpo(model, tokenizer, train_data, 193 | num_iterations=1, num_steps=500, 194 | batch_size=4, num_generations=4, max_completion_length=128, 195 | beta=0.1, learning_rate=5e-6, mu=3, epsilon=0.2, reward_function=None, 196 | env=None, gradient_accumulation_steps=1): 197 | 198 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 199 | 200 | for iteration in range(num_iterations): 201 | print(f"\nIteration {iteration+1}/{num_iterations}") 202 | 203 | ref_model = copy.deepcopy(model) 204 | ref_model.eval() 205 | for param in ref_model.parameters(): 206 | param.requires_grad = False 207 | print("Reference model created") 208 | 209 | # re-initialize optimizer 210 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) 211 | model.train() 212 | 213 | for step in range(num_steps): 214 | print(f"\nStep {step+1}/{num_steps}") 215 | batch_samples = random.sample(train_data, batch_size) 216 | 217 | with torch.no_grad(): 218 | rollout_data = generate_rollout_data( 219 | model, 220 | ref_model, 221 | tokenizer, 222 | batch_samples, 223 | device, 224 | num_generations, 225 | max_completion_length, 226 | env 227 | ) 228 | print("\n\n------------------------ \n Example response: \n", rollout_data['formatted_completions'][0][0]['content']) 229 | # Clear cache after generating rollouts 230 | if torch.cuda.is_available(): 231 | torch.cuda.empty_cache() 232 | 233 | for grpo_iter in range(mu): 234 | print(f"GRPO inner loop {grpo_iter+1}/{mu}") 235 | loss, avg_reward = 
grpo_loss( 236 | model, 237 | ref_model, 238 | rollout_data, 239 | tokenizer, 240 | reward_function, 241 | device=device, 242 | beta=beta, 243 | epsilon=epsilon, 244 | env=env 245 | ) 246 | optimizer.zero_grad() 247 | loss.backward() 248 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1) 249 | optimizer.step() 250 | 251 | # Clear cache after each iteration 252 | if torch.cuda.is_available(): 253 | torch.cuda.empty_cache() 254 | 255 | # log to wandb 256 | wandb.log({ 257 | "loss": loss.item(), 258 | "average_reward": avg_reward, 259 | "step": step + 1, 260 | "grpo_iter": grpo_iter + 1, 261 | }) 262 | 263 | 264 | 265 | print(f"Iteration {iteration+1}/{num_iterations}, Step {step+1}/{num_steps}, " 266 | f"GRPO iter {grpo_iter+1}/{mu}, loss: {loss.item():.4f}, reward: {avg_reward}") 267 | 268 | del loss # explicitly delete loss to free memory 269 | 270 | return model 271 | 272 | 273 | def optimize_model_memory(model): 274 | model.config.use_cache = False 275 | model = torch.compile(model) 276 | model.gradient_checkpointing_enable() 277 | return model --------------------------------------------------------------------------------
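The memory savings described in the README come mainly from `selective_log_softmax` in `grpo.py`, which computes the log-softmax over chunks of the sequence instead of materializing a second full `(batch, seq_len, vocab)` log-probability tensor at once. A minimal sanity-check sketch (not part of the repo; it assumes the `set.sh` dependencies are installed and `grpo.py` is importable) showing that the chunked computation matches a plain `log_softmax` followed by `gather`:

```python
# Standalone sanity check: the chunked log-softmax in grpo.selective_log_softmax
# should match a full log_softmax followed by gather.
import torch
import torch.nn as nn
from grpo import selective_log_softmax  # requires the set.sh dependencies (e.g. wandb)

batch_size, seq_len, vocab_size = 2, 100, 512
logits = torch.randn(batch_size, seq_len, vocab_size)
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))

# Reference: materialize the full (batch, seq_len, vocab) log-probability tensor.
full = nn.functional.log_softmax(logits, dim=-1).gather(
    dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)

# Chunked version: peak extra memory scales with chunk_size * vocab_size
# rather than seq_len * vocab_size.
chunked = selective_log_softmax(logits, input_ids, chunk_size=16)

print(torch.allclose(full, chunked, atol=1e-5))  # expected: True
```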