├── asset
│   └── Tiny_po.png
├── set.sh
├── LICENSE
├── README.md
├── test.py
├── gsm8k_data.py
├── train.py
├── GRPO.ipynb
└── grpo.py
/asset/Tiny_po.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fangyuan-ksgk/Tiny-GRPO/HEAD/asset/Tiny_po.png
--------------------------------------------------------------------------------
/set.sh:
--------------------------------------------------------------------------------
1 | pip install tf-keras
2 | pip install flash-attn
3 | pip install wandb
4 | pip install 'accelerate>=0.26.0'
5 | pip install transformers
6 | pip install datasets
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Fangyuan Yu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ![Tiny-GRPO](asset/Tiny_po.png)
4 |
5 |
6 |
7 | Minimal implementation of **Group Relative Policy Optimization (GRPO)** (DeepSeek) from scratch. No complicated file structure, just a **simple, hackable implementation** with a few scripts for a better understanding of the algorithm.
8 |
9 | Inspired by the implementation by [@aburkov](https://github.com/aburkov). This implementation optimizes **memory usage** during training by:
10 | - Using **chunk-wise softmax operations**
11 | - Leveraging **mixed precision training**
12 | Together, these techniques reduce memory usage by **50%**, enabling GRPO to run on a single GPU while achieving strong results on **math datasets**.
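
The chunk-wise trick boils down to never materializing log-probabilities over the full `(batch, seq_len, vocab)` tensor at once: the sequence is processed in small slices and only the log-probabilities of the sampled tokens are kept, while mixed precision simply wraps the forward and generation passes in `torch.amp.autocast` with bfloat16. A condensed sketch of `selective_log_softmax` from `grpo.py`:

```python
import torch
import torch.nn.functional as F

def chunked_token_log_probs(logits, input_ids, chunk_size=64):
    """Per-token log-probs of input_ids, computed chunk by chunk to cap peak memory."""
    log_probs = torch.zeros(input_ids.shape, device=logits.device)
    for start in range(0, logits.size(1), chunk_size):
        end = start + chunk_size
        chunk = F.log_softmax(logits[:, start:end, :], dim=-1)  # (batch, chunk, vocab)
        log_probs[:, start:end] = chunk.gather(
            -1, input_ids[:, start:end].unsqueeze(-1)).squeeze(-1)
    return log_probs
```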
13 |
14 | Set up the environment:
15 | ```bash
16 | bash set.sh
17 | ```
18 |
19 | Train GRPO on the GSM8K dataset (Qwen2.5-1.5B-Instruct by default):
20 | ```bash
21 | python train.py
22 | ```
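
All GRPO hyperparameters are exposed as command-line flags in `train.py` (the values below are the defaults). Training logs to Weights & Biases, so set `WANDB_PROJECT` (and `WANDB_API_KEY` if needed) before launching:

```bash
export WANDB_PROJECT=your-wandb-project  # placeholder project name
python train.py \
  --model_name Qwen/Qwen2.5-1.5B-Instruct \
  --num_iterations 2 \
  --num_steps 400 \
  --batch_size 6 \
  --num_generations 8 \
  --max_completion_length 512 \
  --learning_rate 5e-6
```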
23 |
24 | Test the fine-tuned model's output:
25 | ```bash
26 | python test.py
27 | ```
28 |
29 | ## 🤝 Contributing
30 | Feel free to submit issues, PRs, or suggestions to improve the implementation!
31 |
32 | ## ⚡ Acknowledgments
33 | Inspired by @aburkov's work in The LM Book (https://github.com/aburkov/theLMbook).
34 |
35 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoTokenizer, AutoModelForCausalLM
2 | import torch
3 | import os
4 | from gsm8k_data import SYSTEM_PROMPT, build_prompt, extract_answer_from_model_output
5 |
6 | def main():
7 | """
8 | Main function to load the fine-tuned model and test it on example math problems.
9 |
10 | Explanation:
11 | 1. Determines the device (GPU if available, otherwise CPU).
12 | 2. Loads the fine-tuned model and tokenizer from the saved path.
13 | 3. Tests the model on predefined math problems.
14 | 4. Formats the prompt using the same SYSTEM_PROMPT and build_prompt function as training.
15 | 5. Generates and displays responses for each test prompt.
16 | """
17 | # Determine the device: use GPU if available, else fallback to CPU.
18 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19 | print(f"Using device: {device}")
20 |
21 | # Load the saved model and tokenizer
22 | saved_model_path = "grpo_finetuned_model"
23 |
24 |
25 | # Load the model
26 | loaded_model = AutoModelForCausalLM.from_pretrained(
27 | saved_model_path,
28 | torch_dtype=torch.bfloat16,
29 | device_map="auto"
30 | )
31 |
32 | loaded_tokenizer = AutoTokenizer.from_pretrained(saved_model_path)
33 | loaded_tokenizer.pad_token = loaded_tokenizer.eos_token
34 |
35 | # Define test prompts
36 | prompts_to_test = [
37 | "How much is 1+1?",
38 | "I have 3 apples, my friend eats one and I give 2 to my sister, how many apples do I have now?",
39 | "Solve the equation 6x + 4 = 40"
40 | ]
41 |
42 | # Test each prompt
43 | for prompt in prompts_to_test:
44 | # Prepare the prompt using the same format as during training
45 | test_messages = [
46 | {"role": "system", "content": SYSTEM_PROMPT},
47 | {"role": "user", "content": prompt}
48 | ]
49 | test_prompt = build_prompt(test_messages)
50 |
51 | # Tokenize the prompt and generate a response
52 | test_input_ids = loaded_tokenizer.encode(test_prompt, return_tensors="pt").to(device)
53 |
54 | # Generate response with similar parameters to those used in training
55 | with torch.no_grad():
56 | test_output_ids = loaded_model.generate(
57 | test_input_ids,
58 | max_new_tokens=400,
59 | temperature=0.7,
60 | num_return_sequences=1,
61 | pad_token_id=loaded_tokenizer.pad_token_id,
62 | eos_token_id=loaded_tokenizer.eos_token_id,
63 | do_sample=True,
64 | early_stopping=False
65 | )
66 |
67 | test_response = loaded_tokenizer.decode(test_output_ids[0], skip_special_tokens=True)
68 |
69 | # Print the test prompt and the model's response
70 | print("\nTest Prompt:")
71 | print(test_prompt)
72 | print("\nModel Response:")
73 | print(test_response)
74 |
75 | # Extract and display the answer part for easier evaluation
76 | try:
77 | extracted_answer = extract_answer_from_model_output(test_response)
78 | print("\nExtracted Answer:")
79 | print(extracted_answer)
80 | print("-" * 50)
81 | except Exception as e:
82 | print(f"\nFailed to extract answer: {e}")
83 | print("-" * 50)
84 |
85 | if __name__ == "__main__":
86 | main()
--------------------------------------------------------------------------------
/gsm8k_data.py:
--------------------------------------------------------------------------------
1 | # util for data preparation (GSM8K) dataset
2 |
3 | import re
4 | from datasets import load_dataset
5 |
6 | SYSTEM_PROMPT = """
7 | Respond in the following format:
8 | <reasoning>
9 | ...
10 | </reasoning>
11 | <answer>
12 | ...
13 | </answer>
14 | """
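# With this prompt, a well-formatted completion looks like:
#   <reasoning> ...step-by-step working... </reasoning>
#   <answer> 42 </answer>
# which is what extract_gsm8k_answer() parses and structural_reward_fn() scores below.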
15 |
16 | def extract_gsm8k_answer(text):
17 | """
18 | Extracts the value from the last <answer>...</answer> tag in the text.
19 | - strict extractor: only the content of the final <answer> block is returned
20 | """
21 | # Split on <answer> and take everything after the last occurrence
22 | parts = text.split("<answer>")
23 | if len(parts) < 2: # No <answer> tag found
24 | return None
25 | last_part = parts[-1]
26 |
27 | # Extract content up to the closing </answer>
28 | if "</answer>" not in last_part:
29 | return None
30 | answer = last_part.split("</answer>")[0].strip()
31 | return None if answer == "..." else answer
32 |
33 |
34 | def extract_gsm8k_answer_from_dataset(text):
35 | """
36 | Extracts the answer from the GSM8K dataset examples.
37 | - the GSM8K dataset gives the reference answer after the '####' marker
38 | """
39 | if "####" not in text:
40 | return None
41 | return text.split("####")[1].strip()
42 |
43 |
44 | def build_prompt(messages):
45 | """simple change line combination of all response without any identifier whatsoever"""
46 | return "\n".join([msg["content"].strip() for msg in messages])
47 |
48 |
49 | def prepare_dataset(split="train"):
50 | """Load and prepare the GSM8K dataset for training with string prompts."""
51 | data = load_dataset('openai/gsm8k', 'main')[split]
52 | formatted_data = []
53 | for example in data:
54 | # Convert list of messages to a single string prompt.
55 | prompt_str = build_prompt([
56 | {"role": "system", "content": SYSTEM_PROMPT},
57 | {"role": "user", "content": example["question"]}
58 | ])
59 | formatted_example = {
60 | "prompt": prompt_str, # Now a string rather than a list.
61 | "answer": extract_gsm8k_answer_from_dataset(example["answer"])
62 | }
63 | formatted_data.append(formatted_example)
64 | return formatted_data
65 |
66 |
67 | def extract_last_number(text):
68 | text = text.replace('$', '').replace('%', '')
69 | pattern = r'(?:^|\s|=)\s*(-?\d*\.?\d+)\s*$'
70 | match = re.search(pattern, text)
71 | return float(match.group(1)) if match else None
72 |
73 |
74 | def extract_single_number(text):
75 | numbers = re.findall(r'-?\d*\.?\d+', text)
76 | return float(numbers[0]) if len(numbers) == 1 else None
77 |
78 |
79 | def gsm8k_metric(predicted: str, expected: str) -> tuple[bool, float]:
80 | if predicted == expected: # Exact match
81 | is_correct = True
82 | reward = 2.0
83 | else:
84 | # Try single number matching
85 | pred_num = extract_single_number(str(predicted))
86 | exp_num = extract_single_number(str(expected))
87 | if pred_num is not None and exp_num is not None and pred_num == exp_num:
88 | is_correct = True
89 | reward = 1.5
90 | else:
91 | # the way I view this, it's just a metric-based evaluation functional
92 | # Try last number matching
93 | pred_num = extract_last_number(str(predicted))
94 | exp_num = extract_last_number(str(expected))
95 | is_correct = (pred_num is not None and exp_num is not None and
96 | pred_num == exp_num)
97 | reward = 0.0
98 | return is_correct, reward
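# Reward tiers (illustrative), e.g.:
#   gsm8k_metric("42", "42")     -> (True, 2.0)  exact string match
#   gsm8k_metric("$42.00", "42") -> (True, 1.5)  single-number match
#   gsm8k_metric("x = 7", "9")   -> (False, 0.0) no match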
99 |
100 |
101 | def functional_reward_fn(completions, answers):
102 | responses = [completion[0]['content'] for completion in completions]
103 | extracted = [extract_gsm8k_answer(response) for response in responses]
104 | rewards = []
105 | for pred, exp in zip(extracted, answers):
106 | is_correct, reward = gsm8k_metric(pred, exp)
107 | rewards.append(reward)
108 | # count the number of words (not tokens, not characters) in each response; currently computed but unused
109 | completion_lengths = [len(response.split()) for response in responses]
110 | return rewards
111 |
112 |
113 | def structural_reward_fn(completions):
114 | responses = [completion[0]['content'] for completion in completions]
115 | rewards = []
116 | format_scores = []
117 | for response in responses:
118 | score = 0.0
119 | if "" in response: score += 0.2
120 | if "" in response: score += 0.2
121 | if "" in response: score += 0.2
122 | if "" in response: score += 0.2
123 | rewards.append(score)
124 | format_scores.append(score)
125 | return rewards
126 |
127 |
128 | # trial and error here, for better weightage between the reward components ...
129 | def reward_fn(completions, answers):
130 | functional_rewards = functional_reward_fn(completions, answers)
131 | structural_rewards = structural_reward_fn(completions)
132 |
133 | combined_rewards = []
134 | for f_score, s_score in zip(functional_rewards, structural_rewards):
135 | # Correctness score range: 0.0 to 2.0
136 | # Format score range: 0.0 to 0.8
137 | # Total range: 0.0 to 2.8
138 | combined_rewards.append(f_score + s_score)
139 |
140 | return combined_rewards
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from gsm8k_data import extract_gsm8k_answer, gsm8k_metric
3 | import os
4 | from contextlib import nullcontext
5 |
6 |
7 | def setup_training_environment(device_name, dtype = "bfloat16"):
8 | """set up mixed precision (for memory optimization)"""
9 |
10 | # Allow TF32 matmuls for faster training on Ampere+ GPUs
11 | torch.backends.cuda.matmul.allow_tf32 = True
12 | torch.backends.cudnn.allow_tf32 = True
13 |
14 | # Set up mixed precision
15 | device_type = 'cuda' if 'cuda' in device_name else 'cpu'
16 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
17 | ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
18 |
19 | return {
20 | 'device': device_name,
21 | 'ctx': ctx,
22 | 'device_type': device_type,
23 | }
24 |
25 | env = setup_training_environment("cuda" if torch.cuda.is_available() else "cpu")
26 |
27 | def evaluate_model(model, tokenizer, eval_examples, device, env, max_completion_length):
28 | model.eval()
29 | correct = 0
30 | total = len(eval_examples)
31 | print("\n" + "="*50)
32 | print("EVALUATION ON", total, "EXAMPLES")
33 | print("="*50)
34 |
35 | for example in eval_examples:
36 | # Get the prompt and expected answer
37 | full_prompt = example["prompt"]
38 | expected = example["answer"]
39 |
40 | # Tokenize and generate response
41 | inputs = tokenizer.encode(full_prompt, return_tensors="pt").to(device)
42 | with torch.no_grad():
43 | with env['ctx']:
44 | outputs = model.generate(
45 | inputs,
46 | max_new_tokens=max_completion_length,
47 | temperature=0.7,
48 | num_return_sequences=1,
49 | pad_token_id=tokenizer.pad_token_id,
50 | eos_token_id=tokenizer.eos_token_id,
51 | forced_eos_token_id=tokenizer.eos_token_id,
52 | early_stopping=False,
53 | ) # 'forward generation --> RL on reward' | can we do the same with pre-training ?
54 | response = tokenizer.decode(outputs[0], skip_special_tokens=True)
55 |
56 | try:
57 | # 1. extract functional
58 | predicted = extract_gsm8k_answer(response)
59 | # 2. metric-based reward fn
60 | is_correct, _ = gsm8k_metric(predicted, expected)
61 |
62 | # Update counter for correct answers
63 | if is_correct:
64 | correct += 1
65 |
66 | # Print evaluation details
67 | print("\nPrompt:")
68 | print(full_prompt)
69 | print("\nExpected Answer:")
70 | print(expected)
71 | print("\nExtracted Answer:")
72 | print(predicted)
73 | print("\nFull Generated Response:")
74 | print(response)
75 | print("\nCorrect:", "✓" if is_correct else "✗")
76 | print("-"*50)
77 |
78 | except Exception as e:
79 | print("\nFailed to parse model output for prompt:")
80 | print(full_prompt)
81 | print("Error:", e)
82 | print("-"*50)
83 |
84 | # Calculate and print final accuracy
85 | accuracy = (correct / total) * 100
86 | print(f"\nAccuracy: {accuracy:.2f}% ({correct}/{total})")
87 | print("="*50)
88 |
89 | # Return model to training mode
90 | model.train()
91 | return accuracy
92 |
93 |
94 | from transformers import AutoModelForCausalLM, AutoTokenizer
95 | from gsm8k_data import prepare_dataset, reward_fn
96 | import random
97 | from grpo import train_with_grpo, optimize_model_memory
98 | import os
99 | import wandb
100 |
101 |
102 | def main(args):
103 | device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu"))
104 | print(f"Using primary device: {device}")
105 |
106 | model_name = args.model_name
107 | output_dir = args.output_dir
108 |
109 | print("Downloading model...")
110 | model = AutoModelForCausalLM.from_pretrained(
111 | model_name,
112 | torch_dtype=torch.bfloat16,
113 | device_map="auto"
114 | )
115 | print("Model downloaded")
116 |
117 | tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
118 | tokenizer.pad_token = tokenizer.eos_token
119 | model.config.pad_token_id = tokenizer.eos_token_id
120 | model.config.eos_token_id = tokenizer.eos_token_id
121 |
122 | all_data = prepare_dataset("train")
123 | random.shuffle(all_data)
124 | size_of_eval_data = 30 # change to a smaller value to save time or to a larger number for a more reliable estimate
125 | eval_data = all_data[:size_of_eval_data]
126 | train_data = all_data[size_of_eval_data:]
127 |
128 | model = optimize_model_memory(model)
129 |
130 | print("\nInitial model evaluation before finetuning:")
131 | pre_grpo_accuracy = evaluate_model(model, tokenizer, eval_data, device, env, args.max_completion_length)
132 | print(f"Pre-GRPO Accuracy: {pre_grpo_accuracy:.2f}%")
133 |
134 | wandb.init(project=os.environ["WANDB_PROJECT"], reinit=True)
135 | print("Weights & Biases initialized.")
136 |
137 | model = train_with_grpo(
138 | model=model,
139 | tokenizer=tokenizer,
140 | train_data=train_data,
141 | reward_function=reward_fn,
142 | num_iterations=args.num_iterations,
143 | num_steps=args.num_steps,
144 | batch_size=args.batch_size,
145 | num_generations=args.num_generations,
146 | max_completion_length=args.max_completion_length,
147 | beta=args.beta,
148 | learning_rate=args.learning_rate,
149 | mu=args.mu,
150 | epsilon=args.epsilon,
151 | env=env
152 | )
153 |
154 | wandb.finish()
155 | print("Training completed and wandb run finished.")
156 |
157 | print("\nFinal model evaluation after GRPO RL fine-tuning:")
158 | post_grpo_accuracy = evaluate_model(model, tokenizer, eval_data, device, env, args.max_completion_length)
159 | print(f"Post-GRPO Accuracy: {post_grpo_accuracy:.2f}%")
160 |
161 | print("\nSaving GRPO fine-tuned model...")
162 | model.save_pretrained("grpo_finetuned_model")
163 | tokenizer.save_pretrained("grpo_finetuned_model")
164 |
165 |
166 | if __name__ == "__main__":
167 | import argparse
168 |
169 | parser = argparse.ArgumentParser(description="Train with GRPO on GSM8K")
170 | parser.add_argument("--model_name", type=str, default="Qwen/Qwen2.5-1.5B-Instruct",
171 | help="Model name or path")
172 | parser.add_argument("--output_dir", type=str, default="./output",
173 | help="Directory to save the model")
174 | parser.add_argument("--num_iterations", type=int, default=2,
175 | help="Number of iterations")
176 | parser.add_argument("--num_steps", type=int, default=400,
177 | help="Number of steps per iteration")
178 | parser.add_argument("--batch_size", type=int, default=6,
179 | help="Batch size")
180 | parser.add_argument("--num_generations", type=int, default=8,
181 | help="Number of generations per example")
182 | parser.add_argument("--max_completion_length", type=int, default=512,
183 | help="Maximum completion length")
184 | parser.add_argument("--beta", type=float, default=0.1,
185 | help="Beta parameter for GRPO")
186 | parser.add_argument("--learning_rate", type=float, default=5e-6,
187 | help="Learning rate")
188 | parser.add_argument("--mu", type=float, default=6,
189 | help="Mu parameter for GRPO")
190 | parser.add_argument("--epsilon", type=float, default=0.1,
191 | help="Epsilon parameter for GRPO")
192 |
193 | args = parser.parse_args()
194 | main(args)
--------------------------------------------------------------------------------
/GRPO.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 2,
6 | "id": "ec5d0e27-1bf4-4526-b0f1-8239c05525e7",
7 | "metadata": {
8 | "collapsed": true,
9 | "jupyter": {
10 | "outputs_hidden": true
11 | }
12 | },
13 | "outputs": [],
14 | "source": [
15 | "import torch, os\n",
16 | "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
17 | "from datasets import load_dataset\n",
18 | "from grpo import set_random_seed\n",
19 | "\n",
20 | "# Call the function to set random seed for reproducibility\n",
21 | "set_random_seed(42)\n",
22 | "\n",
23 | "# Set environment variables for Weights & Biases (wandb) logging\n",
24 | "os.environ[\"WANDB_API_KEY\"] = \"YOUR WANDB API KEY\"\n",
25 | "os.environ[\"WANDB_PROJECT\"] = \"YOUR WANDB PROJECT NAME\""
26 | ]
27 | },
28 | {
29 | "cell_type": "code",
30 | "execution_count": 3,
31 | "id": "3e35173a-b617-4dde-ba8c-d9ad20ea51ea",
32 | "metadata": {},
33 | "outputs": [],
34 | "source": [
35 | "from gsm8k_data import gsm8k_metric, extract_gsm8k_answer\n",
36 | "\n",
37 | "def evaluate_model(model, tokenizer, eval_examples, device):\n",
38 | "\n",
39 | " model.eval()\n",
40 | " correct = 0\n",
41 | " total = len(eval_examples)\n",
42 | " print(\"\\n\" + \"=\"*50)\n",
43 | " print(\"EVALUATION ON\", total, \"EXAMPLES\")\n",
44 | " print(\"=\"*50)\n",
45 | "\n",
46 | " for example in eval_examples:\n",
47 | " # Get the prompt and expected answer\n",
48 | " full_prompt = example[\"prompt\"]\n",
49 | " expected = example[\"answer\"]\n",
50 | "\n",
51 | " # Tokenize and generate response\n",
52 | " inputs = tokenizer.encode(full_prompt, return_tensors=\"pt\").to(device)\n",
53 | " with torch.no_grad():\n",
54 | " outputs = model.generate(\n",
55 | " inputs,\n",
56 | " max_new_tokens=512,\n",
57 | " temperature=0.7,\n",
58 | " num_return_sequences=1,\n",
59 | " pad_token_id=tokenizer.pad_token_id,\n",
60 | " eos_token_id=tokenizer.eos_token_id,\n",
61 | " forced_eos_token_id=tokenizer.eos_token_id,\n",
62 | " early_stopping=False,\n",
63 | " ) # 'forward generation --> RL on reward' | can we do the same with pre-training ?\n",
64 | " response = tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
65 | "\n",
66 | " try: \n",
67 | " # 1. extract functional \n",
68 | " predicted = extract_gsm8k_answer(response)\n",
69 | " # 2. metric-based reward fn\n",
70 | " is_correct = gsm8k_metric(predicted, expected)\n",
71 | "\n",
72 | " # Update counter for correct answers\n",
73 | " if is_correct:\n",
74 | " correct += 1\n",
75 | "\n",
76 | " # Print evaluation details\n",
77 | " print(\"\\nPrompt:\")\n",
78 | " print(full_prompt)\n",
79 | " print(\"\\nExpected Answer:\")\n",
80 | " print(expected)\n",
81 | " print(\"\\nExtracted Answer:\")\n",
82 | " print(predicted)\n",
83 | " print(\"\\nFull Generated Response:\")\n",
84 | " print(response)\n",
85 | " print(\"\\nCorrect:\", \"✓\" if is_correct else \"✗\")\n",
86 | " print(\"-\"*50)\n",
87 | "\n",
88 | " except Exception as e:\n",
89 | " print(\"\\nFailed to parse model output for prompt:\")\n",
90 | " print(full_prompt)\n",
91 | " print(\"Error:\", e)\n",
92 | " print(\"-\"*50)\n",
93 | "\n",
94 | " # Calculate and print final accuracy\n",
95 | " accuracy = (correct / total) * 100\n",
96 | " print(f\"\\nAccuracy: {accuracy:.2f}% ({correct}/{total})\")\n",
97 | " print(\"=\"*50)\n",
98 | "\n",
99 | " # Return model to training mode\n",
100 | " model.train()\n",
101 | " return accuracy"
102 | ]
103 | },
104 | {
105 | "cell_type": "code",
106 | "execution_count": null,
107 | "id": "41d10965-282c-42c6-a005-487c1ba6a8d8",
108 | "metadata": {
109 | "collapsed": true,
110 | "jupyter": {
111 | "outputs_hidden": true
112 | }
113 | },
114 | "outputs": [],
115 | "source": [
116 | "# Main execution\n",
117 | "from gsm8k_data import prepare_dataset, reward_fn\n",
118 | "from grpo import *\n",
119 | "\n",
120 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
121 | "print(f\"Using primary device: {device}\")\n",
122 | "\n",
123 | "model_name = \"Qwen/Qwen2.5-1.5B-Instruct\"\n",
124 | "output_dir = \"math_solver_model\"\n",
125 | "\n",
126 | "print(\"Downloading model...\")\n",
127 | "model = AutoModelForCausalLM.from_pretrained(\n",
128 | " model_name,\n",
129 | " torch_dtype=torch.bfloat16,\n",
130 | " device_map=\"auto\"\n",
131 | ")\n",
132 | "print(\"Model downloaded\")\n",
133 | "\n",
134 | "tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side=\"left\")\n",
135 | "tokenizer.pad_token = tokenizer.eos_token\n",
136 | "model.config.pad_token_id = tokenizer.eos_token_id\n",
137 | "model.config.eos_token_id = tokenizer.eos_token_id\n",
138 | "\n",
139 | "num_gpus = torch.cuda.device_count()\n",
140 | "print(f\"Detected {num_gpus} GPUs\")\n",
141 | "device_ids = list(range(num_gpus)) if num_gpus > 1 else None\n",
142 | "\n",
143 | "all_data = prepare_dataset(\"train\")\n",
144 | "random.shuffle(all_data)\n",
145 | "size_of_eval_data = 30 # change to a smaller value to save time or to a larger number for a more reliable estimate\n",
146 | "eval_data = all_data[:size_of_eval_data]\n",
147 | "train_data = all_data[size_of_eval_data:]\n",
148 | "\n",
149 | "print(\"\\nInitial model evaluation before finetuning:\")\n",
150 | "pre_grpo_accuracy = evaluate_model(model, tokenizer, eval_data, device)\n",
151 | "print(f\"Pre-GRPO Accuracy: {pre_grpo_accuracy:.2f}%\")\n",
152 | "\n",
153 | "model = optimize_model_memory(model)\n",
154 | "\n",
155 | "print(\"\\nStarting RL fine-tuning using GRPO...\")\n",
156 | "# This config was tested on a 8xA100 node, where each A100 is has 80GB of VRAM\n",
157 | "training_config = {\n",
158 | " 'num_iterations': 1,\n",
159 | " 'num_steps': 500,\n",
160 | " 'batch_size': 7, # reduce if you have fewer GPUs\n",
161 | " 'num_generations': 12, # reduce if you have GPUs with less VRAM\n",
162 | " 'max_completion_length': 400, # reduce if you have GPUs with less VRAM\n",
163 | " 'beta': 0.04,\n",
164 | " 'learning_rate': 5e-6,\n",
165 | " 'mu': 1,\n",
166 | " 'epsilon': 0.1\n",
167 | "}\n",
168 | "\n",
169 | "# Initialize Weights & Biases\n",
170 | "wandb.init(project=os.environ[\"WANDB_PROJECT\"], reinit=True)\n",
171 | "print(\"Weights & Biases initialized.\")\n",
172 | "\n",
173 | "model = train_with_grpo(\n",
174 | " model=model,\n",
175 | " tokenizer=tokenizer,\n",
176 | " train_data=train_data,\n",
177 | " reward_function=reward_fn,\n",
178 | " device_ids=device_ids,\n",
179 | " **training_config\n",
180 | ")\n",
181 | "\n",
182 | "wandb.finish()\n",
183 | "print(\"Training completed and wandb run finished.\")\n",
184 | "\n",
185 | "print(\"\\nFinal model evaluation after GRPO RL fine-tuning:\")\n",
186 | "post_grpo_accuracy = evaluate_model(model, tokenizer, eval_data, device)\n",
187 | "print(f\"Post-GRPO Accuracy: {post_grpo_accuracy:.2f}%\")\n",
188 | "\n",
189 | "print(\"\\nSaving GRPO fine-tuned model...\")\n",
190 | "model.save_pretrained(\"grpo_finetuned_model\")\n",
191 | "tokenizer.save_pretrained(\"grpo_finetuned_model\")"
192 | ]
193 | }
194 | ],
195 | "metadata": {
196 | "kernelspec": {
197 | "display_name": "base",
198 | "language": "python",
199 | "name": "python3"
200 | },
201 | "language_info": {
202 | "codemirror_mode": {
203 | "name": "ipython",
204 | "version": 3
205 | },
206 | "file_extension": ".py",
207 | "mimetype": "text/x-python",
208 | "name": "python",
209 | "nbconvert_exporter": "python",
210 | "pygments_lexer": "ipython3",
211 | "version": "3.11.5"
212 | }
213 | },
214 | "nbformat": 4,
215 | "nbformat_minor": 5
216 | }
217 |
--------------------------------------------------------------------------------
/grpo.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import wandb
3 | import torch
4 | import random
5 | import numpy as np
6 | import torch.nn as nn
7 | from contextlib import nullcontext
8 |
9 | def setup_training_environment(device_name, dtype = "bfloat16"):
10 | """set up mixed precision (for memory optimization)"""
11 |
12 | # Allow TF32 matmuls for faster training on Ampere+ GPUs
13 | torch.backends.cuda.matmul.allow_tf32 = True
14 | torch.backends.cudnn.allow_tf32 = True
15 |
16 | # Set up mixed precision
17 | device_type = 'cuda' if 'cuda' in device_name else 'cpu'
18 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
19 | ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
20 |
21 | return {
22 | 'device': device_name,
23 | 'ctx': ctx,
24 | 'device_type': device_type,
25 | }
26 |
27 | def get_memory_usage():
28 | """Get current GPU memory usage in a human-readable format."""
29 | if torch.cuda.is_available():
30 | allocated = torch.cuda.memory_allocated() / (1024 ** 3) # Convert to GB
31 | reserved = torch.cuda.memory_reserved() / (1024 ** 3) # Convert to GB
32 | return f"GPU Memory: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved"
33 | return "CUDA not available"
34 |
35 | def set_random_seed(seed: int = 42):
36 | """Set the random seed for reproducibility across Python, NumPy, and PyTorch."""
37 | # Set the seed for Python's built-in random module
38 | random.seed(seed)
39 | # Set the seed for NumPy
40 | np.random.seed(seed)
41 | # Set the seed for PyTorch
42 | torch.manual_seed(seed)
43 | if torch.cuda.is_available():
44 | torch.cuda.manual_seed_all(seed)
45 | # Ensure deterministic behavior in cuDNN (may impact performance)
46 | torch.backends.cudnn.deterministic = True
47 | torch.backends.cudnn.benchmark = False
48 |
49 | def selective_log_softmax(logits, input_ids, chunk_size=64):
50 | """Process in chunks to reduce peak memory"""
51 | device = logits.device
52 | batch_size, seq_len, vocab_size = logits.shape
53 | log_probs = torch.zeros(batch_size, seq_len, device=device)
54 |
55 | for i in range(0, seq_len, chunk_size):
56 | end_idx = min(i + chunk_size, seq_len)
57 | chunk_logits = logits[:, i:end_idx, :]
58 | chunk_ids = input_ids[:, i:end_idx]
59 | chunk_log_probs = nn.functional.log_softmax(chunk_logits, dim=-1)
60 | # print(" - chunkwise softmax computation GPU memory: ", get_memory_usage())
61 | log_probs[:, i:end_idx] = chunk_log_probs.gather(
62 | dim=-1, index=chunk_ids.unsqueeze(-1)).squeeze(-1)
63 | del chunk_logits, chunk_log_probs
64 | torch.cuda.empty_cache()
65 | return log_probs
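# A full log_softmax would allocate a (batch, seq_len, vocab) float tensor at once;
# chunking keeps the live slice at (batch, chunk_size, vocab), which is what caps peak memory.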
66 |
67 | def compute_log_probs(model, input_ids, attention_mask, logits_to_keep, env, chunk_size=64):
68 | with env['ctx']:
69 | logits = model(input_ids=input_ids, attention_mask=attention_mask).logits[:, :-1, :] # drop the last position: logits at step t predict token t+1
70 | input_ids = input_ids[:, -logits_to_keep:]
71 | logits = logits[:, -logits_to_keep:, :]
72 | return selective_log_softmax(logits, input_ids, chunk_size)
73 |
74 | def create_completion_mask(completion_ids, eos_token_id):
75 |
76 | # ----- TBD: replace this with a less hacky solution -----
77 | is_eos = completion_ids == eos_token_id
78 | # shape: (batch_size,) value: max length of completion
79 | eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=completion_ids.device)
80 | mask_exists = is_eos.any(dim=1) # operate on sequence with eos_token
81 | eos_idx[mask_exists] = is_eos.int().argmax(dim=1)[mask_exists] # first eos token index | hacky: argmax returns first index of largest value
82 | # ----- TBD: replace above with a less hacky solution -----
83 |
84 | # create indices of tokens, then build non-end mask by comparing with eos_token index (first appearance indicate termination)
85 | sequence_indices = torch.arange(is_eos.size(1), device=completion_ids.device).expand(is_eos.size(0), -1) # (batch_size, max_length)
86 | return (sequence_indices < eos_idx.unsqueeze(1)).int()
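# Example: with eos_token_id=2, completion_ids [[5, 7, 2, 9]] gives the mask [[1, 1, 0, 0]];
# the first EOS and everything after it are excluded from the loss.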
87 |
88 | def generate_completions(model, tokenizer, prompts, device, num_generations=4, max_completion_length=32, env=None):
89 | """Generate multiple completions for each prompt, record completion mask (end-of-sequence)"""
90 | inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left")
91 | prompt_ids = inputs["input_ids"].to(device)
92 | prompt_mask = inputs["attention_mask"].to(device)
93 | prompt_length = prompt_ids.size(1) # Question: sequences within batch should have different length? This leads to wrong completion mask?
94 | prompt_ids = prompt_ids.repeat_interleave(num_generations, dim=0)
95 | prompt_mask = prompt_mask.repeat_interleave(num_generations, dim=0)
96 | with env['ctx']:
97 | outputs = model.generate(
98 | prompt_ids,
99 | attention_mask=prompt_mask,
100 | max_new_tokens=max_completion_length,
101 | do_sample=True,
102 | temperature=1.0,
103 | pad_token_id=tokenizer.pad_token_id,
104 | eos_token_id=tokenizer.eos_token_id,
105 | early_stopping=False
106 | ) # Important: same length outputs from this generate function
107 | completion_ids = outputs[:, prompt_length:]
108 | completion_mask = create_completion_mask(completion_ids, tokenizer.eos_token_id) # completion mask excludes eos_token and suffix tokens
109 | return prompt_ids, prompt_mask, completion_ids, completion_mask
110 |
111 |
112 | def generate_rollout_data(model, ref_model, tokenizer, batch_samples, device, num_generations, max_completion_length, env, chunk=64):
113 | """Generate responses and calculate log-probabilities of each response under two model"""
114 | prompts = [sample["prompt"] if isinstance(sample, dict) else sample[0] for sample in batch_samples]
115 | answers = [sample["answer"] if isinstance(sample, dict) else sample[1] for sample in batch_samples]
116 | with torch.no_grad():
117 | prompt_ids, prompt_mask, completion_ids, completion_mask = generate_completions(
118 | model, tokenizer, prompts, device, num_generations, max_completion_length, env
119 | )
120 | input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
121 | attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
122 | logits_to_keep = completion_ids.size(1) # same question: whether dim=1 is same across in-batch sequences
123 | # AR perplexity based RL (Issue #1) --> can we use 'skippy loss'
124 | # hidden-space RL (Idea #1) --> rollout on latent space?
125 |
126 | old_log_probs = compute_log_probs(model, input_ids, attention_mask, logits_to_keep, env, chunk)
127 | ref_log_probs = compute_log_probs(ref_model, input_ids, attention_mask, logits_to_keep, env, chunk) # ref is base model ? (also used in KL regularization I recall)
128 |
129 | # one-turn completion & reward assignment (Issue #2)
130 | formatted_completions = [[{'content': tokenizer.decode(ids, skip_special_tokens=True)}] for ids in completion_ids]
131 | repeated_prompts = [p for p in prompts for _ in range(num_generations)]
132 | repeated_answers = [a for a in answers for _ in range(num_generations)]
133 | return {
134 | "input_ids": input_ids,
135 | "attention_mask": attention_mask,
136 | "completion_mask": completion_mask,
137 | "old_log_probs": old_log_probs,
138 | "ref_log_probs": ref_log_probs,
139 | "formatted_completions": formatted_completions,
140 | "repeated_prompts": repeated_prompts,
141 | "repeated_answers": repeated_answers,
142 | "logits_to_keep": logits_to_keep,
143 | "batch_size": len(prompts),
144 | "num_generations": num_generations
145 | }
146 |
147 | # Issue #4. do we need to keep all 3 models? prev, curr, base?
148 | def grpo_loss(model, ref_model, rollout_data, tokenizer, reward_function,
149 | device, beta=0.01, epsilon=0.2, env=None, chunk=64):
150 | """
151 | GRPO loss function:
152 | - group normalized reward
153 | - conservative advantage clipping
154 | - kl regularization
155 | """
156 | input_ids = rollout_data["input_ids"]
157 | attention_mask = rollout_data["attention_mask"]
158 | completion_mask = rollout_data["completion_mask"]
159 | logits_to_keep = rollout_data["logits_to_keep"]
160 | old_log_probs = rollout_data["old_log_probs"]
161 | ref_log_probs = rollout_data["ref_log_probs"]
162 |
163 | new_log_probs = compute_log_probs(model, input_ids, attention_mask, logits_to_keep, env, chunk)
164 |
165 | # ratio: between new and old
166 | ratio = torch.exp(new_log_probs - old_log_probs)
167 | rewards = torch.tensor(
168 | reward_function(completions=rollout_data["formatted_completions"], answers=rollout_data["repeated_answers"]),
169 | dtype=torch.float32,
170 | device=device
171 | )
172 |
173 |
174 | batch_size = rollout_data["batch_size"]
175 | num_generations = rollout_data["num_generations"]
176 |
177 | # 'group' refers to the 'num_generations' completions sampled per prompt; the group-relative advantage is calculated here
178 | rewards = rewards.view(batch_size, num_generations)
179 | avg_reward = rewards.mean().item()
180 | mean_rewards = rewards.mean(dim=1).repeat_interleave(num_generations)
181 | std_rewards = rewards.std(dim=1).repeat_interleave(num_generations)
182 | advantages = ((rewards.view(-1) - mean_rewards) / (std_rewards + 1e-4)).unsqueeze(1)
183 | surrogate_loss = torch.min(ratio * advantages, torch.clamp(ratio, 1-epsilon, 1+epsilon) * advantages)
184 | kl = torch.exp(ref_log_probs - new_log_probs) - (ref_log_probs - new_log_probs) - 1
185 | per_token_loss = surrogate_loss - beta * kl
186 | loss = - ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
187 | del kl, surrogate_loss, per_token_loss
188 |
189 | return loss, avg_reward
190 |
191 |
192 | def train_with_grpo(model, tokenizer, train_data,
193 | num_iterations=1, num_steps=500,
194 | batch_size=4, num_generations=4, max_completion_length=128,
195 | beta=0.1, learning_rate=5e-6, mu=3, epsilon=0.2, reward_function=None,
196 | env=None, gradient_accumulation_steps=1):
197 |
198 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
199 |
200 | for iteration in range(num_iterations):
201 | print(f"\nIteration {iteration+1}/{num_iterations}")
202 |
203 | ref_model = copy.deepcopy(model)
204 | ref_model.eval()
205 | for param in ref_model.parameters():
206 | param.requires_grad = False
207 | print("Reference model created")
208 |
209 | # re-initialize optimizer
210 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
211 | model.train()
212 |
213 | for step in range(num_steps):
214 | print(f"\nStep {step+1}/{num_steps}")
215 | batch_samples = random.sample(train_data, batch_size)
216 |
217 | with torch.no_grad():
218 | rollout_data = generate_rollout_data(
219 | model,
220 | ref_model,
221 | tokenizer,
222 | batch_samples,
223 | device,
224 | num_generations,
225 | max_completion_length,
226 | env
227 | )
228 | print("\n\n------------------------ \n Example response: \n", rollout_data['formatted_completions'][0][0]['content'])
229 | # Clear cache after generating rollouts
230 | if torch.cuda.is_available():
231 | torch.cuda.empty_cache()
232 |
233 | for grpo_iter in range(mu):
234 | print(f"GRPO inner loop {grpo_iter+1}/{mu}")
235 | loss, avg_reward = grpo_loss(
236 | model,
237 | ref_model,
238 | rollout_data,
239 | tokenizer,
240 | reward_function,
241 | device=device,
242 | beta=beta,
243 | epsilon=epsilon,
244 | env=env
245 | )
246 | optimizer.zero_grad()
247 | loss.backward()
248 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1)
249 | optimizer.step()
250 |
251 | # Clear cache after each iteration
252 | if torch.cuda.is_available():
253 | torch.cuda.empty_cache()
254 |
255 | # log to wandb
256 | wandb.log({
257 | "loss": loss.item(),
258 | "average_reward": avg_reward,
259 | "step": step + 1,
260 | "grpo_iter": grpo_iter + 1,
261 | })
262 |
263 |
264 |
265 | print(f"Iteration {iteration+1}/{num_iterations}, Step {step+1}/{num_steps}, "
266 | f"GRPO iter {grpo_iter+1}/{mu}, loss: {loss.item():.4f}, reward: {avg_reward}")
267 |
268 | del loss # explicitly delete loss to free memory
269 |
270 | return model
271 |
272 |
273 | def optimize_model_memory(model):
274 | model.config.use_cache = False # no KV cache needed during training; saves memory
275 | model = torch.compile(model)
276 | model.gradient_checkpointing_enable() # recompute activations in the backward pass to reduce peak memory
277 | return model
--------------------------------------------------------------------------------