├── .gitignore ├── LICENSE ├── README.md ├── evaluator.py ├── llms.py ├── main.py ├── plots ├── eval_score.png └── train_score.png ├── plotter.py ├── requirements.txt ├── rldatasets.py ├── run.sh ├── training_score.png └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # PyPI configuration file 171 | .pypirc 172 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Brendan Hogan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # DeepSeek R1 Implementation 3 | 4 | ## Motivation 5 | I wanted to recreate DeepSeek R1's results at a smaller scale, focusing on understanding the core mechanics by implementing everything from scratch. This repo trains Qwen2.5-1.5B-Instruct on the [grade school math dataset](https://github.com/openai/grade-school-math). 6 | 7 | This implementation heavily borrows from [Will Brown's work](https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb) ([@willccbb](https://x.com/willccbb)), but restructures the code into a format optimized for learning and experimentation. 8 | 9 | The key difference in my implementation is computing the GRPO loss function directly rather than using external RL libraries, and reorganizing the code into a multi-script repo.
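As a taste of what "directly" means, here is a minimal sketch of the group-relative advantage computation at the heart of GRPO (variable names are illustrative; see `score_completions` and `compute_loss` in `main.py` for the actual implementation):

```
import torch

def grpo_advantages(rewards: torch.Tensor, num_chains: int) -> torch.Tensor:
    # rewards holds one scalar per sampled completion; every num_chains
    # completions answer the same question and form one group.
    grouped = rewards.view(-1, num_chains)
    mean = grouped.mean(dim=1).repeat_interleave(num_chains)
    std = grouped.std(dim=1).repeat_interleave(num_chains)
    # Normalize each reward against its own group - no learned critic needed.
    return (rewards - mean) / (std + 1e-4)
```

Each completion is then reinforced in proportion to its advantage, with a per-token KL penalty against a frozen reference model to keep the policy from drifting.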
10 | 11 | I hope this might help other people understand things better, and maybe provide an easier way to try out smaller-scale ideas. 12 | 13 | ## Installation 14 | ``` 15 | pip install -r requirements.txt 16 | ``` 17 | 18 | Required environment variables: 19 | ``` 20 | export HUGGINGFACE_TOKEN="your-token-here" 21 | huggingface-cli login 22 | ``` 23 | 24 | ## Implementation Details 25 | 26 | The system consists of several key modules: 27 | 28 | ### main.py 29 | Contains the core training loop implementing GRPO (Group Relative Policy Optimization). Handles model training, evaluation, and metric tracking. 30 | 31 | ### llms.py 32 | Manages model loading and configuration, currently supporting LLaMA + Qwen models through Hugging Face's transformers library. Designed to be easily extensible to other model architectures. 33 | 34 | ### rldatasets.py 35 | Handles dataset loading and preprocessing, currently focused on GSM8K math problems. Implements custom data loaders for both training and evaluation. 36 | 37 | ### evaluator.py 38 | Contains evaluation metrics and reward functions, closely following DeepSeek's original implementation. 39 | 40 | ## Results 41 | Training was conducted on a single H100 GPU. After ~400 training steps: 42 | 43 | ![Training Results](plots/train_score.png) 44 | 45 | And results on the validation set - this shows a clearer sign of learning: 46 | ![Eval Results](plots/eval_score.png) 47 | 48 | ## Future Directions 49 | I'm really pleased to see how well the key mechanics work even in this simplified implementation. Building on this, I am very excited about several directions: 50 | 51 | 1. Adding self-play capabilities where agents compete and learn from each other using relative rewards. This would create a more dynamic training environment where the reward signal comes from agent interactions rather than fixed metrics. 52 | 53 | 2. Implementing soft reward structures, particularly for complex reasoning tasks. I've written a framework for AI debate that I'm excited to try out. 54 | 55 | 3. Expanding into vision-language models (VLMs) to improve world modeling capabilities. I have an idea about using R1-style training to enhance how VLMs build and maintain internal world models that I'm really excited to explore. (Really excited about this idea - if anyone else is interested I would love to talk.) 56 | 57 | 4. I'd like to do all this experimentation in this framework, so I need to make things faster and support multi-GPU training. 58 | 59 | 60 | 61 | -------------------------------------------------------------------------------- /evaluator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Abstract base class and implementations for reward computation in RL training. 3 | 4 | """ 5 | 6 | import re 7 | import torch 8 | from abc import ABC, abstractmethod 9 | from typing import List, Dict, Tuple, Any 10 | 11 | class RewardEvaluator(ABC): 12 | """ 13 | Abstract base class for reward computation in RL training. 14 | 15 | This class defines the interface for reward evaluators that can be used 16 | to score model completions during RL training. Implement this class to 17 | create custom reward functions for different tasks.
18 | 19 | The main methods that need to be implemented are: 20 | - compute_rewards: Computes rewards for a batch of completions 21 | - get_reward_breakdown: Converts raw reward scores to a labeled dictionary 22 | """ 23 | 24 | @abstractmethod 25 | def compute_rewards( 26 | self, 27 | prompts: List[List[Dict[str, str]]], 28 | completions: List[List[Dict[str, str]]], 29 | answer: Any, 30 | device: str 31 | ) -> Tuple[torch.Tensor, Dict[str, float]]: 32 | """ 33 | Compute rewards for a batch of completions. 34 | 35 | Args: 36 | prompts: List of prompt messages in chat format 37 | [{"role": "user", "content": "..."}, ...] 38 | completions: List of completion messages in chat format 39 | [{"role": "assistant", "content": "..."}, ...] 40 | answer: Ground truth answer(s) for the prompts 41 | device: Device to place tensors on ("cpu" or "cuda") 42 | 43 | Returns: 44 | rewards_per_func: Tensor of shape (num_completions, num_reward_functions) 45 | containing individual reward function scores 46 | metrics: Dictionary of aggregated metrics including mean rewards 47 | per function and total reward 48 | """ 49 | pass 50 | 51 | @abstractmethod 52 | def get_reward_breakdown(self, reward_scores: torch.Tensor) -> Dict[str, float]: 53 | """ 54 | Convert raw reward scores tensor to a labeled dictionary. 55 | 56 | Args: 57 | reward_scores: Tensor of raw scores from compute_rewards 58 | 59 | Returns: 60 | Dictionary mapping reward function names to their scores 61 | """ 62 | pass 63 | 64 | 65 | def get_evaluator(name: str) -> RewardEvaluator: 66 | """ 67 | Get the appropriate reward evaluator for a given task. 68 | 69 | Args: 70 | name: Name of the task/dataset to get evaluator for 71 | 72 | Returns: 73 | RewardEvaluator instance for the specified task 74 | 75 | Raises: 76 | NotImplementedError: If evaluator for given task is not implemented 77 | """ 78 | if name.lower() == "gsm8k": 79 | return GSM8kEvaluator() 80 | else: 81 | raise NotImplementedError(f"No evaluator implemented for {name}") 82 | 83 | 84 | 85 | class GSM8kEvaluator(RewardEvaluator): 86 | """ 87 | Reward evaluator for the GSM8K math problem dataset. 
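To make the reward targets concrete, a completion that earns full marks from every function below looks like this (question and numbers are illustrative):

<reasoning>
Each tray holds 12 muffins, so 3 trays hold 3 * 12 = 36 muffins.
</reasoning>
<answer>
36
</answer>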
88 | 89 | Implements reward functions for: 90 | - Answer correctness 91 | - Integer format validation 92 | - XML formatting (strict and soft) 93 | - XML tag counting 94 | """ 95 | 96 | def __init__(self): 97 | self.num_reward_functions = 5 98 | 99 | def _extract_xml_answer(self, text: str) -> str: 100 | """Extract answer from XML tags.""" 101 | answer = text.split("<answer>")[-1] 102 | answer = answer.split("</answer>")[0] 103 | return answer.strip() 104 | 105 | def _correctness_reward(self, prompts, completions, answer) -> List[float]: 106 | """Reward for correct answer.""" 107 | responses = [completion[0]['content'] for completion in completions] 108 | extracted = [self._extract_xml_answer(r) for r in responses] 109 | return [2.0 if r == a else 0.0 for r, a in zip(extracted, answer)] 110 | 111 | def _int_format_reward(self, completions) -> List[float]: 112 | """Reward for integer format.""" 113 | responses = [completion[0]['content'] for completion in completions] 114 | extracted = [self._extract_xml_answer(r) for r in responses] 115 | return [0.5 if r.isdigit() else 0.0 for r in extracted] 116 | 117 | def _strict_format_reward(self, completions) -> List[float]: 118 | """Reward for strict XML format.""" 119 | pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$" 120 | responses = [completion[0]["content"] for completion in completions] 121 | matches = [bool(re.match(pattern, r)) for r in responses] 122 | return [0.5 if m else 0.0 for m in matches] 123 | 124 | def _soft_format_reward(self, completions) -> List[float]: 125 | """Reward for relaxed XML format.""" 126 | pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>" 127 | responses = [completion[0]["content"] for completion in completions] 128 | matches = [bool(re.match(pattern, r)) for r in responses] 129 | return [0.5 if m else 0.0 for m in matches] 130 | 131 | def _xml_count_reward(self, completions) -> List[float]: 132 | """Reward for XML tag counting.""" 133 | def count_xml(text: str) -> float: 134 | count = 0.0 135 | if text.count("<reasoning>\n") == 1: count += 0.125 136 | if text.count("\n</reasoning>\n") == 1: count += 0.125 137 | if text.count("\n<answer>\n") == 1: 138 | count += 0.125 139 | count -= len(text.split("\n</answer>\n")[-1])*0.001 140 | if text.count("\n</answer>") == 1: 141 | count += 0.125 142 | count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001 143 | return count 144 | 145 | responses = [completion[0]["content"] for completion in completions] 146 | return [count_xml(r) for r in responses] 147 | 148 | def compute_rewards( 149 | self, 150 | prompts: List[List[Dict[str, str]]], 151 | completions: List[List[Dict[str, str]]], 152 | answer: Any, 153 | device: str 154 | ) -> Tuple[torch.Tensor, Dict[str, float]]: 155 | """Compute all rewards for the given completions.""" 156 | 157 | num_completions = len(completions) 158 | rewards_per_func = torch.zeros(num_completions, self.num_reward_functions, device=device) 159 | 160 | # Compute all reward functions 161 | all_scores = [ 162 | self._correctness_reward(prompts, completions, answer), 163 | self._int_format_reward(completions), 164 | self._strict_format_reward(completions), 165 | self._soft_format_reward(completions), 166 | self._xml_count_reward(completions) 167 | ] 168 | 169 | # Fill rewards tensor 170 | for i, scores in enumerate(all_scores): 171 | rewards_per_func[:, i] = torch.tensor(scores, dtype=torch.float32, device=device) 172 | 173 | # Compute metrics 174 | reward_per_func = rewards_per_func.mean(0) 175 | 176 | # Calculate accuracy (perfect correctness score) 177 | correctness_scores = rewards_per_func[:, 0] # First reward function is correctness 178 | num_perfect =
(correctness_scores == 2.0).sum().item() 179 | accuracy = num_perfect / num_completions 180 | 181 | metrics = { 182 | "rewards/correctness_reward_func": reward_per_func[0].item(), 183 | "rewards/int_reward_func": reward_per_func[1].item(), 184 | "rewards/strict_format_reward_func": reward_per_func[2].item(), 185 | "rewards/soft_format_reward_func": reward_per_func[3].item(), 186 | "rewards/xmlcount_reward_func": reward_per_func[4].item(), 187 | "reward": rewards_per_func.sum(dim=1).mean().item(), 188 | "accuracy": accuracy 189 | } 190 | 191 | return rewards_per_func, metrics 192 | 193 | def get_reward_breakdown(self, reward_scores: torch.Tensor) -> Dict[str, float]: 194 | """Convert reward scores tensor to labeled dictionary.""" 195 | return { 196 | 'correctness': reward_scores[0].item(), 197 | 'integer_format': reward_scores[1].item(), 198 | 'strict_format': reward_scores[2].item(), 199 | 'soft_format': reward_scores[3].item(), 200 | 'xml_count': reward_scores[4].item() 201 | } 202 | -------------------------------------------------------------------------------- /llms.py: -------------------------------------------------------------------------------- 1 | """ 2 | Module for loading LLMs and their tokenizers from huggingface. 3 | 4 | """ 5 | import torch 6 | from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizerBase 7 | 8 | 9 | def get_llm_tokenizer(model_name: str, device: str) -> tuple[PreTrainedModel, PreTrainedTokenizerBase]: 10 | """ 11 | Load and configure a language model and its tokenizer. 12 | 13 | Args: 14 | model_name: Name or path of the pretrained model to load 15 | device: Device to load the model on ('cpu' or 'cuda') 16 | 17 | Returns: 18 | tuple containing: 19 | - The loaded language model 20 | - The configured tokenizer for that model 21 | """ 22 | model = AutoModelForCausalLM.from_pretrained( 23 | model_name, 24 | torch_dtype=torch.bfloat16, 25 | attn_implementation="flash_attention_2", 26 | device_map=None, 27 | ).to(device) 28 | 29 | tokenizer = AutoTokenizer.from_pretrained(model_name) 30 | tokenizer.pad_token = tokenizer.eos_token 31 | model.config.pad_token_id = tokenizer.pad_token_id 32 | model.config.use_cache = False 33 | 34 | return model, tokenizer 35 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of GRPO, DeepSeek style training without external libraries 3 | """ 4 | import os 5 | import json 6 | import torch 7 | import argparse 8 | from tqdm import tqdm 9 | from collections import defaultdict 10 | from transformers import PreTrainedModel, PreTrainedTokenizerBase, GenerationConfig 11 | 12 | import llms 13 | import utils 14 | import evaluator 15 | import rldatasets 16 | 17 | def eval_on_test_set( 18 | model: PreTrainedModel, 19 | tokenizer: PreTrainedTokenizerBase, 20 | test_loader: rldatasets.DataLoader, 21 | eval_class: evaluator.RewardEvaluator, 22 | device: str, 23 | args: argparse.Namespace, 24 | round_num: int 25 | ) -> tuple[dict[str, float], float]: 26 | """ 27 | Evaluate model performance on test set. 
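For each question this samples args.num_chains completions (via generate_completions) and scores all of them, so the reported accuracy is the fraction of sampled completions whose extracted <answer> exactly matches the ground truth, averaged over questions and reported as a percentage.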
28 | 29 | Args: 30 | model: The model to evaluate 31 | tokenizer: Tokenizer for the model 32 | test_loader: DataLoader for test set 33 | eval_class: Evaluator for computing rewards 34 | device: Device to run on 35 | args: Training arguments 36 | round_num: Current training round number 37 | 38 | Returns: 39 | total_scores: Dictionary of average metrics 40 | accuracy: Accuracy on test set 41 | """ 42 | print("Running evaluation on test set...") 43 | 44 | # Track metrics across all test examples 45 | total_scores = defaultdict(float) 46 | num_examples = 0 47 | total_accuracy = 0.0 48 | 49 | # Create log file for this evaluation round 50 | log_file = os.path.join(args.output_dir, f'eval_metrics_{round_num}.txt') 51 | test_loader.reset() 52 | 53 | with open(log_file, 'w') as f: 54 | # Run through test set 55 | for question, answer in tqdm(test_loader, desc="Evaluating on test set"): 56 | # Generate completions using same function as training 57 | _, _, _, _, completions_text, _ = generate_completions( 58 | model, tokenizer, question, device, args 59 | ) 60 | 61 | # Score completions using evaluator 62 | mock_prompts = [[{'content': question}]] * len(completions_text) 63 | mock_completions = [[{'content': completion}] for completion in completions_text] 64 | # Make answer array same length as completions 65 | answers = [answer] * len(completions_text) 66 | rewards_per_func, metrics = eval_class.compute_rewards( 67 | prompts=mock_prompts, 68 | completions=mock_completions, 69 | answer=answers, 70 | device=device 71 | ) 72 | 73 | # Track accuracy and accumulate metrics 74 | total_accuracy += metrics['accuracy'] 75 | 76 | for k, v in metrics.items(): 77 | total_scores[k] += v 78 | num_examples += 1 79 | 80 | # Log this example 81 | f.write("\n" + "="*50 + "\n") 82 | f.write(f"Q# {num_examples}\n") 83 | f.write(f"Question: {question}\n") 84 | f.write(f"Response: {completions_text[0]}\n") # Log first completion 85 | f.write(f"Ground Truth: {answer}\n") 86 | f.write("Metrics:\n") 87 | for metric, value in metrics.items(): 88 | f.write(f"{metric}: {value}\n") 89 | f.write(f"Total Score: {rewards_per_func.sum().item()}\n") 90 | 91 | 92 | # Calculate averages 93 | avg_scores = {k: v/num_examples for k,v in total_scores.items()} 94 | accuracy = total_accuracy / num_examples * 100 95 | 96 | # Save metrics 97 | metrics_path = os.path.join(args.output_dir, f'eval_metrics_{round_num}.json') 98 | with open(metrics_path, 'w') as f: 99 | json.dump({**avg_scores, 'accuracy': accuracy}, f, indent=4) 100 | 101 | if args.verbose: 102 | print("\nEvaluation Results:") 103 | print("-" * 20) 104 | print(f"Accuracy: {accuracy:.2f}%") 105 | for metric, value in avg_scores.items(): 106 | print(f"{metric:15s}: {value:.4f}") 107 | print("-" * 20) 108 | 109 | return avg_scores, accuracy 110 | 111 | def generate_completions( 112 | model: PreTrainedModel, 113 | tokenizer: PreTrainedTokenizerBase, 114 | question: str, 115 | device: str, 116 | args: argparse.Namespace 117 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, list[str], str]: 118 | """ 119 | Generate multiple completion sequences for a given prompt using a language model. 
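The single question is repeated args.num_chains times and sampled in one batched generate call, so one prompt yields a whole group of completions. Tokens after the first EOS are masked out of the loss: for example, if a completion's tokens are [5, 9, EOS, 7, 7], its completion mask is [1, 1, 1, 0, 0] - the EOS itself stays in so the model is still trained to emit it.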
120 | 121 | Args: 122 | model: The language model to use for generation 123 | tokenizer: Tokenizer corresponding to the model 124 | question: The input question/prompt to generate completions for 125 | device: Device to run generation on ('cpu' or 'cuda') 126 | args: Namespace containing generation parameters 127 | 128 | Returns: 129 | prompt_completion_ids: Tensor containing the full sequence of prompt + completion token IDs 130 | prompt_ids: Tensor containing just the prompt token IDs 131 | completion_ids: Tensor containing just the completion token IDs 132 | attention_mask: Attention mask tensor for the full sequence 133 | completions_text: List of decoded completion texts 134 | prompt_text: The full formatted prompt text 135 | """ 136 | # 1. Prepare prompting 137 | prompt = [ 138 | {'role': 'system', 'content': train_loader.system_prompt}, 139 | {'role': 'user', 'content': question} 140 | ] 141 | prompt_text = tokenizer.apply_chat_template(prompt, tokenize=False) 142 | prompt_inputs = tokenizer(prompt_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False) 143 | prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"] 144 | 145 | # Truncate prompt to max length and repeat for number of generations 146 | prompt_ids = prompt_ids[:, -args.max_prompt_length:] 147 | prompt_mask = prompt_mask[:, -args.max_prompt_length:] 148 | 149 | # Repeat for number of chains/generations 150 | prompt_ids = prompt_ids.repeat(args.num_chains, 1) 151 | prompt_mask = prompt_mask.repeat(args.num_chains, 1) 152 | 153 | # Move tensors to device 154 | prompt_ids = prompt_ids.to(device) 155 | prompt_mask = prompt_mask.to(device) 156 | 157 | # Set up generation config 158 | generation_config = GenerationConfig( 159 | max_new_tokens=args.max_completion_length, 160 | do_sample=True, 161 | temperature=args.temperature, 162 | pad_token_id=tokenizer.pad_token_id 163 | ) 164 | 165 | # Generate completions 166 | prompt_completion_ids = model.generate( 167 | prompt_ids, 168 | attention_mask=prompt_mask, 169 | generation_config=generation_config 170 | ) 171 | 172 | # Extract completion ids 173 | prompt_length = prompt_ids.size(1) 174 | prompt_ids = prompt_completion_ids[:, :prompt_length] 175 | completion_ids = prompt_completion_ids[:, prompt_length:] 176 | 177 | # Do masking 178 | is_eos = completion_ids == tokenizer.eos_token_id 179 | eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device) 180 | eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] 181 | sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) 182 | completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int() 183 | 184 | attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) 185 | 186 | # Decode completions 187 | completions_text = tokenizer.batch_decode(completion_ids, skip_special_tokens=True) 188 | 189 | return prompt_completion_ids, prompt_ids, completion_ids, attention_mask, completions_text, prompt_text 190 | 191 | def score_completions( 192 | completions_text: list[str], 193 | question: str, 194 | answer: str, 195 | eval_class: evaluator.RewardEvaluator, 196 | device: str, 197 | args: argparse.Namespace 198 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict[str, float], dict]: 199 | """ 200 | Score model completions and compute advantages for training. 
201 | 202 | Args: 203 | completions_text: List of generated completion strings 204 | question: Original input question/prompt 205 | answer: Ground truth answer 206 | eval_class: Evaluator class for computing rewards 207 | device: Device to place tensors on 208 | args: Training arguments 209 | 210 | Returns: 211 | rewards: Raw reward scores for each completion 212 | advantages: Computed advantages for policy gradient 213 | rewards_per_func: Rewards broken down by individual reward functions 214 | metrics: Dictionary of aggregated metrics 215 | log_data: Dictionary containing detailed generation and scoring data 216 | """ 217 | # Build log data dictionary 218 | log_data = { 219 | 'prompt': { 220 | 'text': question, 221 | 'answer': answer 222 | }, 223 | 'generations': [] 224 | } 225 | 226 | # Format inputs as expected by evaluator 227 | mock_prompts = [[{'content': question}]] * len(completions_text) 228 | mock_completions = [[{'content': completion}] for completion in completions_text] 229 | answers = [answer] * len(completions_text) 230 | 231 | # Get rewards and metrics from evaluator 232 | rewards_per_func, metrics = eval_class.compute_rewards( 233 | prompts=mock_prompts, 234 | completions=mock_completions, 235 | answer=answers, 236 | device=device 237 | ) 238 | rewards = rewards_per_func.sum(dim=1) 239 | 240 | # Store generation data 241 | for i, (completion, reward_scores) in enumerate(zip(completions_text, rewards_per_func)): 242 | generation_data = { 243 | 'response': completion, 244 | 'scores': { 245 | **eval_class.get_reward_breakdown(reward_scores), 246 | 'total_reward': rewards[i].item() 247 | } 248 | } 249 | log_data['generations'].append(generation_data) 250 | 251 | # Compute advantages 252 | mean_grouped_rewards = rewards.view(-1, args.num_chains).mean(dim=1) 253 | std_grouped_rewards = rewards.view(-1, args.num_chains).std(dim=1) 254 | 255 | mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(args.num_chains, dim=0) 256 | std_grouped_rewards = std_grouped_rewards.repeat_interleave(args.num_chains, dim=0) 257 | 258 | advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4) 259 | metrics["reward_std"] = std_grouped_rewards.mean().item() 260 | 261 | # Store summary statistics 262 | log_data['summary_stats'] = { 263 | 'mean_rewards_per_group': mean_grouped_rewards.tolist(), 264 | 'std_rewards_per_group': std_grouped_rewards.tolist(), 265 | 'advantages': advantages.tolist() 266 | } 267 | 268 | return rewards, advantages, rewards_per_func, metrics, log_data 269 | 270 | def compute_loss( 271 | model: PreTrainedModel, 272 | base_model: PreTrainedModel, 273 | prompt_completion_ids: torch.Tensor, 274 | prompt_ids: torch.Tensor, 275 | completion_ids: torch.Tensor, 276 | attention_mask: torch.Tensor, 277 | completion_mask: torch.Tensor, 278 | advantages: torch.Tensor, 279 | args: argparse.Namespace 280 | ) -> tuple[torch.Tensor, dict[str, float]]: 281 | """ 282 | Compute the GRPO loss between current and base model. 
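Concretely, the body below estimates the per-token KL divergence from the reference model with the unbiased, non-negative k3 estimator

    kl = exp(ref_logp - logp) - (ref_logp - logp) - 1

and combines it with the group-relative advantages as

    per_token_loss = -( exp(logp - logp.detach()) * advantage - kl_weight_beta * kl )

where exp(logp - logp.detach()) evaluates to 1 in the forward pass but routes the policy gradient through logp.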
283 | 284 | Args: 285 | model: The current model being trained 286 | base_model: The reference model to compare against 287 | prompt_completion_ids: Combined prompt and completion token IDs 288 | prompt_ids: Token IDs for just the prompt 289 | completion_ids: Token IDs for just the completion 290 | attention_mask: Attention mask for the full sequence 291 | completion_mask: Mask indicating which tokens are from the completion 292 | advantages: Advantage values for each sequence 293 | args: Training arguments 294 | 295 | Returns: 296 | loss: The computed GRPO loss 297 | metrics: Dictionary containing additional metrics like KL divergence 298 | """ 299 | 300 | # Only need the generated tokens' logits 301 | logits_to_keep = completion_ids.size(1) 302 | 303 | # Get reference model logits 304 | with torch.inference_mode(): 305 | ref_per_token_logps = utils.get_per_token_logps(base_model, prompt_completion_ids, attention_mask, logits_to_keep) 306 | 307 | # Get training model logits 308 | input_ids = torch.cat([prompt_ids, completion_ids], dim=1) 309 | per_token_logps = utils.get_per_token_logps(model, input_ids, attention_mask, logits_to_keep) 310 | 311 | # Compute KL divergence 312 | per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1 313 | 314 | # Compute loss with advantages 315 | per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1) 316 | per_token_loss = -(per_token_loss - args.kl_weight_beta * per_token_kl) 317 | loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() 318 | 319 | # Additional metrics 320 | metrics = {} 321 | response_length = completion_mask.sum(1).float().mean().item() 322 | metrics["response_length"] = response_length 323 | mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() 324 | metrics["kl"] = mean_kl.item() 325 | 326 | return loss, metrics 327 | 328 | def grpo_loss( 329 | model: PreTrainedModel, 330 | base_model: PreTrainedModel, 331 | tokenizer: PreTrainedTokenizerBase, 332 | question: str, 333 | answer: str, 334 | eval_class: evaluator.RewardEvaluator, 335 | device: str, 336 | round_num: int, 337 | training_log_dir: str, 338 | args: argparse.Namespace 339 | ) -> tuple[torch.Tensor, dict[str, float]]: 340 | """ 341 | Run a full GRPO step for one question: generate completions, score them, and compute the loss against the reference model.
342 | 343 | Args: 344 | model: The current model being trained 345 | base_model: The reference model to compare against 346 | tokenizer: Tokenizer for the models 347 | question: Input question/prompt 348 | answer: Ground truth answer 349 | eval_class: Evaluator for computing rewards 350 | device: Device to run on ('cpu' or 'cuda') 351 | round_num: Current training round number 352 | training_log_dir: Directory to save training logs 353 | args: Training arguments 354 | 355 | Returns: 356 | loss: The computed GRPO loss 357 | metrics: Dictionary containing training metrics, including rewards, 358 | KL divergence, and response length 359 | """ 360 | # Generate completions 361 | prompt_completion_ids, prompt_ids, completion_ids, attention_mask, completions_text, prompt_text = generate_completions( 362 | model, tokenizer, question, device, args 363 | ) 364 | 365 | # Score completions 366 | rewards, advantages, rewards_per_func, metrics, log_data = score_completions( 367 | completions_text, question, answer, eval_class, device, args 368 | ) 369 | 370 | # Write log data 371 | log_file = os.path.join(training_log_dir, f'{round_num}_generations.txt') 372 | utils.write_generation_log(log_data, log_file) 373 | 374 | # Compute loss 375 | completion_mask = attention_mask[:, prompt_ids.size(1):] 376 | loss, loss_metrics = compute_loss( 377 | model, base_model, prompt_completion_ids, prompt_ids, completion_ids, 378 | attention_mask, completion_mask, advantages, args 379 | ) 380 | 381 | # Combine metrics 382 | metrics.update(loss_metrics) 383 | 384 | return loss, metrics 385 | 386 | 387 | def parse_args(): 388 | parser = argparse.ArgumentParser(description="GRPO training arguments") 389 | 390 | # Model configuration 391 | parser.add_argument("--model_name", type=str, default="Qwen/Qwen2.5-1.5B-Instruct", help="Name/path of base model") 392 | parser.add_argument("--dataset_name", type=str, default="gsm8k", help="Dataset to use for training") 393 | parser.add_argument("--evaluator", type=str, default="gsm8k", help="Evaluator to use for scoring") 394 | 395 | # Output and logging 396 | parser.add_argument("--output_dir", type=str, default="output", help="Directory to save outputs") 397 | parser.add_argument("--verbose", action="store_true", help="Enable verbose logging") 398 | parser.add_argument("--save_steps", type=int, default=100, help="Save model every N steps") 399 | parser.add_argument("--eval_iterations", type=int, default=20, help="Evaluate on the test set every N iterations") 400 | 401 | # Optimization hyperparameters 402 | parser.add_argument("--learning_rate", type=float, default=5e-6, help="Learning rate") 403 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="Adam beta1") 404 | parser.add_argument("--adam_beta2", type=float, default=0.99, help="Adam beta2") 405 | parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay") 406 | parser.add_argument("--max_grad_norm", type=float, default=0.1, help="Max gradient norm for clipping") 407 | parser.add_argument("--gradient_accumulation_steps", type=int, default=4, help="Number of gradient accumulation steps") 408 | parser.add_argument("--warmup_percent", type=float, default=0.18, help="Percentage of total steps for warmup") 409 | parser.add_argument("--update_ref_model", action="store_true", help="Whether to update reference model") 410 | parser.add_argument("--update_ref_model_freq", type=int, default=200, help="How often to update reference model") 411 | parser.add_argument("--ref_model_mixup_alpha", type=float, default=0.1,
help="Alpha parameter for reference model mixup") 412 | 413 | 414 | # Generation parameters 415 | parser.add_argument("--temperature", type=float, default=0.9, help="Sampling temperature") 416 | parser.add_argument("--num_chains", type=int, default=16, help="Number of parallel generation chains") 417 | parser.add_argument("--max_prompt_length", type=int, default=256, help="Maximum prompt length") 418 | parser.add_argument("--max_completion_length", type=int, default=786, help="Maximum completion length") 419 | 420 | # Training parameters 421 | parser.add_argument("--num_train_iters", type=int, default=1000, help="Number of training iterations") 422 | parser.add_argument("--kl_weight_beta", type=float, default=0.04, help="KL penalty weight") 423 | parser.add_argument("--seed", type=int, default=7111994, help="Random seed") 424 | 425 | args = parser.parse_args() 426 | return args 427 | 428 | if __name__ == "__main__": 429 | 430 | # Get all args 431 | args = parse_args() 432 | 433 | # Seed everything 434 | utils.seed_everything(args.seed) 435 | 436 | # Set device and enable bf16 437 | device = "cuda" if torch.cuda.is_available() else "cpu" 438 | torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True 439 | torch.set_float32_matmul_precision('high') 440 | 441 | ############################### 442 | ## Main Experiment settings ## 443 | ############################### 444 | 445 | ## Set which model to train 446 | model, tokenizer = llms.get_llm_tokenizer(args.model_name, device) 447 | base_model, _ = llms.get_llm_tokenizer(args.model_name, device) 448 | 449 | ## Set which data set 450 | train_loader, test_loader = rldatasets.get_dataloaders(args.dataset_name) 451 | 452 | ## Set which evaluation criteria to use 453 | eval_class = evaluator.get_evaluator(args.evaluator) 454 | 455 | ############################### 456 | 457 | 458 | # Setup logging 459 | os.makedirs(args.output_dir, exist_ok=True) 460 | args_dict = vars(args) 461 | args_path = os.path.join(args.output_dir, 'args.json') 462 | with open(args_path, 'w') as f: 463 | json.dump(args_dict, f, indent=4) 464 | eval_log_dir = os.path.join(args.output_dir, 'eval_logs') 465 | os.makedirs(eval_log_dir, exist_ok=True) 466 | train_log_dir = os.path.join(args.output_dir, 'training_logs') 467 | os.makedirs(train_log_dir, exist_ok=True) 468 | 469 | 470 | # Setup optimizer for trainer agent with GRPO config settings 471 | optimizer = torch.optim.AdamW( 472 | model.parameters(), 473 | lr=args.learning_rate, 474 | betas=(args.adam_beta1, args.adam_beta2), 475 | weight_decay=args.weight_decay, 476 | eps=1e-8 477 | ) 478 | 479 | # Add linear warmup learning rate scheduler 480 | warmup_steps = int(args.warmup_percent * args.num_train_iters) 481 | def get_lr(step): 482 | if step < warmup_steps: 483 | return (step / warmup_steps) 484 | return 1.0 485 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,lr_lambda=get_lr) 486 | 487 | 488 | # Begin training 489 | accumulated_loss = 0 490 | optimizer.zero_grad() 491 | train_metrics_total = {} 492 | for round_num in tqdm(range(args.num_train_iters), desc="Training Progress"): 493 | 494 | # Evaluate on test set every so often 495 | if round_num % args.eval_iterations == 0: 496 | eval_metrics, eval_accuracy = eval_on_test_set( 497 | model=model, 498 | tokenizer=tokenizer, 499 | test_loader=test_loader, 500 | eval_class=eval_class, 501 | device=device, 502 | args=args, 503 | round_num=round_num 504 | ) 505 | 506 | # Save metrics to eval log dir 507 | metrics_path = os.path.join(eval_log_dir, 
f'metrics_{round_num}.json') 508 | with open(metrics_path, 'w') as f: 509 | json.dump({ 510 | 'metrics': eval_metrics, 511 | 'accuracy': eval_accuracy 512 | }, f, indent=4) 513 | 514 | # Slowly update ref model 515 | if args.update_ref_model and (round_num+1) % args.update_ref_model_freq == 0: 516 | with torch.no_grad(): 517 | for param, ref_param in zip(model.parameters(), base_model.parameters()): 518 | ref_param.data = args.ref_model_mixup_alpha * param.data + (1 - args.ref_model_mixup_alpha) * ref_param.data 519 | 520 | # Get next question 521 | question, answer = next(train_loader) 522 | 523 | # Do GRPO - generate chains, score, compute advantage, compute loss 524 | total_loss, train_metrics = grpo_loss(model, base_model, tokenizer, question, answer, eval_class, device, round_num, train_log_dir, args) 525 | 526 | # Gradient accumulation 527 | total_loss = total_loss / args.gradient_accumulation_steps # scale so accumulated gradients average over the window 528 | total_loss.backward() 529 | accumulated_loss += total_loss.item() 530 | scheduler.step() 531 | 532 | # Step optimizer 533 | if (round_num + 1) % args.gradient_accumulation_steps == 0: 534 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 535 | optimizer.step() 536 | optimizer.zero_grad() 537 | 538 | # Logs 539 | train_metrics["learning_rate"] = scheduler.get_last_lr()[0] 540 | train_metrics["loss"] = total_loss.item() * args.gradient_accumulation_steps 541 | grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), float('inf')).item() 542 | train_metrics["grad_norm"] = grad_norm 543 | train_metrics_total[round_num] = train_metrics 544 | with open(os.path.join(train_log_dir, "train_logs.json"), "w") as f: 545 | json.dump(train_metrics_total, f, indent=4) 546 | 547 | -------------------------------------------------------------------------------- /plots/eval_score.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/brendanhogan/DeepSeekRL-Extended/09e312a07f634cbb6ee4aaf2d9f3372dca9e7b9b/plots/eval_score.png -------------------------------------------------------------------------------- /plots/train_score.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/brendanhogan/DeepSeekRL-Extended/09e312a07f634cbb6ee4aaf2d9f3372dca9e7b9b/plots/train_score.png -------------------------------------------------------------------------------- /plotter.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | import matplotlib.style as style 7 | from matplotlib.backends.backend_pdf import PdfPages 8 | 9 | def moving_average(data, window_size=5): 10 | """Calculate moving average with given window size""" 11 | weights = np.ones(window_size) / window_size 12 | return np.convolve(data, weights, mode='valid') 13 | 14 | def plot_metrics(output_dir): 15 | """ 16 | Plot training metrics from training_logs directory. 17 | Creates PDF with separate plots for each metric over training steps. 18 | Uses a modern, professional style with custom color palette.
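Expects the layout that main.py writes: <log_dir>/training_logs/train_logs.json for the training curves and <log_dir>/eval_logs/metrics_<step>.json for the evaluation points. Typical usage:

    python plotter.py --log_dir output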
19 | """ 20 | # Load training logs 21 | train_logs_path = os.path.join(output_dir, 'training_logs', 'train_logs.json') 22 | with open(train_logs_path, 'r') as f: 23 | train_logs = json.load(f) 24 | 25 | # Load evaluation logs 26 | eval_logs = {} 27 | eval_logs_dir = os.path.join(output_dir, 'eval_logs') 28 | for filename in os.listdir(eval_logs_dir): 29 | if filename.startswith('metrics_') and filename.endswith('.json'): 30 | step = int(filename.split('_')[1].split('.')[0]) 31 | with open(os.path.join(eval_logs_dir, filename), 'r') as f: 32 | eval_logs[step] = json.load(f) 33 | 34 | # Set style and color palette 35 | plt.style.use('bmh') # Using 'bmh' style which is a modern, clean style 36 | colors = ['#2ecc71', '#e74c3c', '#3498db', '#f1c40f', '#9b59b6', '#1abc9c', '#e67e22', '#34495e'] 37 | 38 | # Create PDF to save all plots 39 | pdf_path = os.path.join(output_dir, 'training_plots.pdf') 40 | with PdfPages(pdf_path) as pdf: 41 | 42 | # Plot reward metrics 43 | reward_metrics = [ 44 | 'rewards/correctness_reward_func', 45 | 'rewards/int_reward_func', 46 | 'rewards/strict_format_reward_func', 47 | 'rewards/soft_format_reward_func', 48 | 'rewards/xmlcount_reward_func', 49 | 'reward' 50 | ] 51 | 52 | for metric, color in zip(reward_metrics, colors): 53 | plt.figure(figsize=(12,7)) 54 | steps = [int(x) for x in train_logs.keys()] 55 | values = [metrics[metric] for metrics in train_logs.values()] 56 | 57 | # Plot raw data with low alpha 58 | plt.plot(steps, values, color=color, alpha=0.3, linewidth=1.5, label='Raw data') 59 | 60 | # Calculate and plot moving average if we have enough data points 61 | if len(values) > 5: 62 | ma_values = moving_average(values) 63 | ma_steps = steps[len(steps)-len(ma_values):] 64 | plt.plot(ma_steps, ma_values, color=color, linewidth=2.5, label='Moving average') 65 | 66 | plt.xlabel('Training Steps', fontsize=12) 67 | plt.ylabel(f'{metric.split("/")[-1].replace("_", " ").title()}', fontsize=12) 68 | plt.title(f'{metric.split("/")[-1].replace("_", " ").title()}', fontsize=14, pad=20) 69 | plt.grid(True, alpha=0.3) 70 | plt.legend() 71 | pdf.savefig(bbox_inches='tight') 72 | plt.close() 73 | 74 | # Plot learning rate 75 | plt.figure(figsize=(12,7)) 76 | steps = [int(x) for x in train_logs.keys()] 77 | lr_values = [metrics['learning_rate'] for metrics in train_logs.values()] 78 | 79 | plt.plot(steps, lr_values, color='#e74c3c', linewidth=2.0, label='Learning Rate') 80 | 81 | plt.xlabel('Training Steps', fontsize=12) 82 | plt.ylabel('Learning Rate', fontsize=12) 83 | plt.title('Learning Rate Schedule', fontsize=14, pad=20) 84 | plt.grid(True, alpha=0.3) 85 | plt.legend() 86 | pdf.savefig(bbox_inches='tight') 87 | plt.close() 88 | 89 | # Plot reward standard deviation 90 | plt.figure(figsize=(12,7)) 91 | reward_std = [metrics['reward_std'] for metrics in train_logs.values()] 92 | 93 | plt.plot(steps, reward_std, color='#3498db', alpha=0.3, linewidth=1.5, label='Reward Std (Raw)') 94 | if len(reward_std) > 5: 95 | ma_std = moving_average(reward_std) 96 | ma_steps = steps[len(steps)-len(ma_std):] 97 | plt.plot(ma_steps, ma_std, color='#3498db', linewidth=2.5, label='Reward Std (MA)') 98 | 99 | plt.xlabel('Training Steps', fontsize=12) 100 | plt.ylabel('Standard Deviation', fontsize=12) 101 | plt.title('Reward Standard Deviation', fontsize=14, pad=20) 102 | plt.grid(True, alpha=0.3) 103 | plt.legend() 104 | pdf.savefig(bbox_inches='tight') 105 | plt.close() 106 | 107 | # Plot loss 108 | plt.figure(figsize=(12,7)) 109 | loss_values = [metrics['loss'] for metrics in 
train_logs.values()] 110 | 111 | plt.plot(steps, loss_values, color='#e67e22', alpha=0.3, linewidth=1.5, label='Loss (Raw)') 112 | if len(loss_values) > 5: 113 | ma_loss = moving_average(loss_values) 114 | ma_steps = steps[len(steps)-len(ma_loss):] 115 | plt.plot(ma_steps, ma_loss, color='#e67e22', linewidth=2.5, label='Loss (MA)') 116 | 117 | plt.xlabel('Training Steps', fontsize=12) 118 | plt.ylabel('Loss', fontsize=12) 119 | plt.title('Training Loss', fontsize=14, pad=20) 120 | plt.grid(True, alpha=0.3) 121 | plt.legend() 122 | pdf.savefig(bbox_inches='tight') 123 | plt.close() 124 | 125 | # Plot KL divergence 126 | plt.figure(figsize=(12,7)) 127 | kl_values = [metrics['kl'] for metrics in train_logs.values()] 128 | 129 | plt.plot(steps, kl_values, color='#9b59b6', alpha=0.3, linewidth=1.5, label='KL Divergence (Raw)') 130 | if len(kl_values) > 5: 131 | ma_kl = moving_average(kl_values) 132 | ma_steps = steps[len(steps)-len(ma_kl):] 133 | plt.plot(ma_steps, ma_kl, color='#9b59b6', linewidth=2.5, label='KL Divergence (MA)') 134 | 135 | plt.xlabel('Training Steps', fontsize=12) 136 | plt.ylabel('KL Divergence', fontsize=12) 137 | plt.title('KL Divergence', fontsize=14, pad=20) 138 | plt.grid(True, alpha=0.3) 139 | plt.legend() 140 | pdf.savefig(bbox_inches='tight') 141 | plt.close() 142 | 143 | # Plot evaluation metrics 144 | if eval_logs: 145 | eval_steps = sorted(eval_logs.keys()) 146 | 147 | # Plot accuracy 148 | plt.figure(figsize=(12,7)) 149 | accuracy_values = [eval_logs[step]['accuracy'] for step in eval_steps] 150 | plt.plot(eval_steps, accuracy_values, color='#2ecc71', linewidth=2.0, label='Accuracy') 151 | plt.xlabel('Training Steps', fontsize=12) 152 | plt.ylabel('Accuracy (%)', fontsize=12) 153 | plt.title('Evaluation Accuracy', fontsize=14, pad=20) 154 | plt.grid(True, alpha=0.3) 155 | plt.legend() 156 | pdf.savefig(bbox_inches='tight') 157 | plt.close() 158 | 159 | # Plot evaluation reward metrics 160 | eval_metrics = [key for key in eval_logs[eval_steps[0]]['metrics'].keys()] 161 | for metric, color in zip(eval_metrics, colors): 162 | plt.figure(figsize=(12,7)) 163 | metric_values = [eval_logs[step]['metrics'][metric] for step in eval_steps] 164 | plt.plot(eval_steps, metric_values, color=color, linewidth=2.0, label=metric) 165 | plt.xlabel('Training Steps', fontsize=12) 166 | plt.ylabel(metric.replace('_', ' ').title(), fontsize=12) 167 | plt.title(f'Evaluation {metric.replace("_", " ").title()}', fontsize=14, pad=20) 168 | plt.grid(True, alpha=0.3) 169 | plt.legend() 170 | pdf.savefig(bbox_inches='tight') 171 | plt.close() 172 | 173 | if __name__ == "__main__": 174 | parser = argparse.ArgumentParser(description='Plot training metrics from logs directory') 175 | parser.add_argument('--log_dir', type=str, help='Directory containing training logs') 176 | args = parser.parse_args() 177 | plot_metrics(args.log_dir) 178 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.1.0 2 | accelerate==1.3.0 3 | aiohappyeyeballs==2.4.4 4 | aiohttp==3.11.11 5 | aiosignal==1.3.2 6 | annotated-types==0.7.0 7 | anyio==4.8.0 8 | appdirs==1.4.4 9 | argcomplete==1.8.1 10 | astunparse==1.6.3 11 | async-timeout==5.0.1 12 | attrs==21.2.0 13 | Automat==20.2.0 14 | Babel==2.8.0 15 | backcall==0.2.0 16 | bcrypt==3.2.0 17 | beautifulsoup4==4.10.0 18 | beniget==0.4.1 19 | bleach==4.1.0 20 | blinker==1.4 21 | bottle==0.12.19 22 | Brotli==1.0.9 23 | 
certifi==2020.6.20 24 | cffi==1.15.0 25 | chardet==4.0.0 26 | charset-normalizer==3.4.1 27 | click==8.0.3 28 | cloud-init==24.4 29 | colorama==0.4.4 30 | command-not-found==0.3 31 | commonmark==0.9.1 32 | configobj==5.0.6 33 | constantly==15.1.0 34 | cryptography==3.4.8 35 | ctop==1.0.0 36 | cycler==0.11.0 37 | datasets==3.2.0 38 | dbus-python==1.2.18 39 | decorator==4.4.2 40 | defusedxml==0.7.1 41 | dill==0.3.8 42 | distlib==0.3.4 43 | distro==1.7.0 44 | distro-info==1.1+ubuntu0.2 45 | docker==5.0.3 46 | docker-pycreds==0.4.0 47 | einops==0.8.0 48 | entrypoints==0.4 49 | exceptiongroup==1.2.2 50 | filelock==3.6.0 51 | flake8==4.0.1 52 | flash_attn==2.7.4.post1 53 | flatbuffers===1.12.1-git20200711.33e2d80-dfsg1-0.6 54 | fonttools==4.29.1 55 | fpdf==1.7.2 56 | frozenlist==1.5.0 57 | fs==2.4.12 58 | fsspec==2024.3.1 59 | future==0.18.2 60 | gast==0.5.2 61 | gitdb==4.0.12 62 | GitPython==3.1.44 63 | Glances==3.2.4.2 64 | google-pasta==0.2.0 65 | grpcio==1.30.2 66 | h11==0.14.0 67 | h5py==3.6.0 68 | h5py.-debian-h5py-serial==3.6.0 69 | html5lib==1.1 70 | httpcore==1.0.7 71 | httplib2==0.20.2 72 | httpx==0.28.1 73 | huggingface-hub==0.28.1 74 | hyperlink==21.0.0 75 | icdiff==2.0.4 76 | idna==3.3 77 | importlib-metadata==4.6.4 78 | incremental==21.3.0 79 | influxdb==5.3.1 80 | iotop==0.6 81 | ipykernel==6.7.0 82 | ipython==7.31.1 83 | ipython_genutils==0.2.0 84 | jax==0.4.30 85 | jaxlib==0.4.30 86 | jedi==0.18.0 87 | jeepney==0.7.1 88 | Jinja2==3.1.5 89 | jiter==0.8.2 90 | joblib==0.17.0 91 | jsonpatch==1.32 92 | jsonpointer==2.0 93 | jsonschema==3.2.0 94 | jupyter-client==7.1.2 95 | jupyter-core==4.9.1 96 | kaptan==0.5.12 97 | keras==3.6.0 98 | keyring==23.5.0 99 | kiwisolver==1.3.2 100 | launchpadlib==1.10.16 101 | lazr.restfulclient==0.14.4 102 | lazr.uri==1.0.6 103 | libtmux==0.10.1 104 | livereload==2.6.3 105 | lxml==4.8.0 106 | lz4==3.1.3+dfsg 107 | Markdown==3.3.6 108 | MarkupSafe==2.0.1 109 | matplotlib==3.5.1 110 | matplotlib-inline==0.1.3 111 | mccabe==0.6.1 112 | mkdocs==1.1.2 113 | ml-dtypes==0.5.0 114 | more-itertools==8.10.0 115 | mpmath==0.0.0 116 | msgpack==1.0.3 117 | multidict==6.1.0 118 | multiprocess==0.70.16 119 | namex==0.0.8 120 | nest-asyncio==1.5.4 121 | netifaces==0.11.0 122 | networkx==2.4 123 | numpy==1.21.5 124 | nvidia-ml-py==12.555.43 125 | oauthlib==3.2.0 126 | olefile==0.46 127 | openai==1.61.0 128 | opt-einsum==3.3.0 129 | optree==0.13.1 130 | packaging==21.3 131 | pandas==1.3.5 132 | parso==0.8.1 133 | peft==0.14.0 134 | pexpect==4.8.0 135 | pickleshare==0.7.5 136 | Pillow==9.0.1 137 | pipx==1.0.0 138 | platformdirs==2.5.1 139 | ply==3.11 140 | prompt-toolkit==3.0.28 141 | propcache==0.2.1 142 | protobuf==4.21.12 143 | psutil==5.9.0 144 | ptyprocess==0.7.0 145 | py==1.10.0 146 | pyarrow==19.0.0 147 | pyasn1==0.4.8 148 | pyasn1-modules==0.2.1 149 | pycodestyle==2.8.0 150 | pycparser==2.21 151 | pycryptodomex==3.11.0 152 | pydantic==2.10.6 153 | pydantic_core==2.27.2 154 | pyflakes==2.4.0 155 | Pygments==2.11.2 156 | PyGObject==3.42.1 157 | PyHamcrest==2.0.2 158 | pyinotify==0.9.6 159 | PyJWT==2.3.0 160 | pyOpenSSL==21.0.0 161 | pyparsing==2.4.7 162 | pyrsistent==0.18.1 163 | pyserial==3.5 164 | pysmi==0.3.2 165 | pysnmp==4.4.12 166 | pystache==0.6.0 167 | python-apt==2.4.0+ubuntu4 168 | python-dateutil==2.8.1 169 | python-magic==0.4.24 170 | pythran==0.10.0 171 | pytz==2022.1 172 | PyYAML==5.4.1 173 | pyzmq==22.3.0 174 | regex==2024.11.6 175 | requests==2.32.3 176 | rich==11.2.0 177 | safetensors==0.5.2 178 | scikit-learn==0.23.2 179 | scipy==1.8.0 180 | 
SecretStorage==3.3.1 181 | sentry-sdk==2.20.0 182 | service-identity==18.1.0 183 | setproctitle==1.3.4 184 | six==1.16.0 185 | smmap==5.0.2 186 | sniffio==1.3.1 187 | sos==4.7.2 188 | soupsieve==2.3.1 189 | ssh-import-id==5.11 190 | sympy==1.12 191 | tensorboard==2.18.0 192 | tensorflow==2.18.0 193 | termcolor==1.1.0 194 | tf_keras==2.18.0 195 | threadpoolctl==3.1.0 196 | tmuxp==1.9.2 197 | tokenizers==0.21.0 198 | torch==2.5.1 199 | torchvision==0.20.1 200 | tornado==6.1 201 | tqdm==4.67.1 202 | traitlets==5.1.1 203 | transformers==4.48.2 204 | triton==3.1.0 205 | trl==0.14.0 206 | Twisted==22.1.0 207 | typing_extensions==4.12.2 208 | ufoLib2==0.13.1 209 | ufw==0.36.1 210 | unattended-upgrades==0.1 211 | unicodedata2==14.0.0 212 | urllib3==2.3.0 213 | userpath==1.8.0 214 | virtualenv==20.13.0+ds 215 | wadllib==1.3.6 216 | wandb==0.19.5 217 | wcwidth==0.2.5 218 | webencodings==0.5.1 219 | websocket-client==1.2.3 220 | Werkzeug==2.0.2 221 | wrapt==1.13.3 222 | xxhash==3.5.0 223 | yarl==1.18.3 224 | zipp==1.0.0 225 | zope.interface==5.4.0 226 | -------------------------------------------------------------------------------- /rldatasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Hold all data sets 3 | 4 | """ 5 | 6 | import random 7 | import numpy as np 8 | from tqdm import tqdm 9 | from datasets import load_dataset, Dataset 10 | from abc import ABC, abstractmethod 11 | from typing import Tuple, Any 12 | 13 | 14 | 15 | class DataLoader(ABC): 16 | """ 17 | Abstract base class for data loaders. 18 | 19 | This class defines the interface that all dataset loaders should implement. 20 | Specific dataset loaders should inherit from this class and implement the 21 | required methods. 22 | 23 | Attributes: 24 | random (bool): If True, returns items randomly; if False, returns sequentially 25 | current_index (int): Current position for sequential access 26 | """ 27 | 28 | def __init__(self, random: bool = False) -> None: 29 | self.random = random 30 | self.current_index = 0 31 | 32 | @abstractmethod 33 | def __len__(self) -> int: 34 | """Return the total number of items in the dataset.""" 35 | pass 36 | 37 | @abstractmethod 38 | def __iter__(self) -> 'DataLoader': 39 | """Return self as iterator.""" 40 | return self 41 | 42 | @abstractmethod 43 | def __next__(self) -> Any: 44 | """Return the next item(s) in the dataset.""" 45 | pass 46 | 47 | 48 | def extract_hash_answer(text: str) -> str | None: 49 | if "####" not in text: 50 | return None 51 | return text.split("####")[1].strip() 52 | 53 | 54 | 55 | SYSTEM_PROMPT = """ 56 | Respond in the following format: 57 | <reasoning> 58 | ... 59 | </reasoning> 60 | <answer> 61 | ... 62 | </answer> 63 | """ 64 | 65 | 66 | 67 | class GSM8KLoader(DataLoader): 68 | """ 69 | A loader class that provides iteration over GSM8K math problems. 70 | 71 | This class implements both sequential and random access to math problems through 72 | standard Python iterator protocols. It can be used to iterate over problems either 73 | in order or randomly, making it suitable for both training and evaluation.
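Example with toy data (illustrative):

    loader = GSM8KLoader(["What is 2 + 2?"], ["4"])
    question, answer = next(loader)  # -> ("What is 2 + 2?", "4")
    loader.reset()                   # rewind sequential iteration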
74 | 75 | Attributes: 76 | questions (List[str]): List of math question strings 77 | answers (List[str]): List of corresponding answer strings 78 | random (bool): If True, returns problems randomly; if False, returns sequentially 79 | current_index (int): Current position in the lists for sequential access 80 | """ 81 | 82 | def __init__(self, questions: list[str], answers: list[str], random: bool = False) -> None: 83 | super().__init__(random) 84 | self.questions = questions 85 | self.answers = answers 86 | self.pre_prompt = """You will be given a question that involves reasoning. You should reason carefully about the question, then provide your answer. 87 | It is very important that you put your reasoning process inside <reasoning> tags and your final answer inside <answer> tags, like this: 88 | 89 | 90 | <reasoning> 91 | Your step-by-step reasoning process here 92 | </reasoning> 93 | <answer> 94 | Your final answer here 95 | </answer> 96 | 97 | All of your returned text should either be in the <reasoning> or <answer> tags - no text outside! Start each answer by immediately starting with <reasoning>. 98 | It is extremely important you answer in this way - do not put any information or text outside of these tags! 99 | 100 | Question: """ 101 | self.system_prompt = SYSTEM_PROMPT 102 | 103 | def __len__(self) -> int: 104 | return len(self.questions) 105 | 106 | def __iter__(self) -> 'GSM8KLoader': 107 | return self 108 | 109 | def __next__(self) -> tuple[str, str]: 110 | if self.current_index >= len(self.questions): 111 | raise StopIteration 112 | 113 | if self.random: 114 | idx = random.randint(0, len(self.questions) - 1) 115 | else: 116 | idx = self.current_index 117 | self.current_index += 1 118 | 119 | return self.questions[idx], self.answers[idx] 120 | 121 | def reset(self): 122 | self.current_index = 0 123 | 124 | 125 | def build_gsm8k_dataloaders() -> Tuple[GSM8KLoader, GSM8KLoader]: 126 | data = load_dataset('openai/gsm8k', 'main')["train"] 127 | 128 | questions = [] 129 | parsed_answers = [] 130 | for i in tqdm(range(len(data)), desc="Processing"): 131 | # Try to get answer - if it is None, don't use this sample 132 | ans = extract_hash_answer(data[i]['answer']) 133 | if ans is None: 134 | continue 135 | else: 136 | questions.append(data[i]['question']) 137 | parsed_answers.append(ans) 138 | 139 | # Randomly split into train/test sets 140 | total_samples = len(questions) 141 | test_size = int(total_samples * 0.01) # 1% for test set 142 | 143 | # Generate random indices for test set 144 | test_indices = random.sample(range(total_samples), test_size) 145 | test_indices_set = set(test_indices) 146 | 147 | # Convert to numpy arrays for easier indexing 148 | questions = np.array(questions) 149 | parsed_answers = np.array(parsed_answers) 150 | 151 | # Create boolean mask for test indices 152 | test_mask = np.zeros(total_samples, dtype=bool) 153 | test_mask[list(test_indices_set)] = True 154 | 155 | # Split using boolean indexing 156 | test_questions = questions[test_mask] 157 | test_answers = parsed_answers[test_mask] 158 | train_questions = questions[~test_mask] 159 | train_answers = parsed_answers[~test_mask] 160 | 161 | # Setup data loaders 162 | trainloader = GSM8KLoader(train_questions.tolist(), train_answers.tolist()) 163 | testloader = GSM8KLoader(test_questions.tolist(), test_answers.tolist()) 164 | 165 | return trainloader, testloader 166 | 167 | 168 | def get_dataloaders(dataset_name: str) -> Tuple[DataLoader, DataLoader]: 169 | """ 170 | Factory function to get train and test data loaders for a specified dataset.
171 | 172 | Args: 173 | dataset_name (str): Name of the dataset to load ('gsm8k' currently supported) 174 | 175 | Returns: 176 | Tuple[DataLoader, DataLoader]: Train and test data loaders 177 | 178 | Raises: 179 | ValueError: If dataset_name is not supported 180 | """ 181 | if dataset_name.lower() == 'gsm8k': 182 | return build_gsm8k_dataloaders() 183 | else: 184 | raise ValueError(f"Dataset {dataset_name} not supported. Currently only 'gsm8k' is available.") 185 | 186 | 187 | if __name__ == "__main__": 188 | trainloader, testloader = get_dataloaders('gsm8k') -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | # python main.py --output_dir "final1" --verbose 2 | python plotter.py --log_dir "final1" -------------------------------------------------------------------------------- /training_score.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/brendanhogan/DeepSeekRL-Extended/09e312a07f634cbb6ee4aaf2d9f3372dca9e7b9b/training_score.png -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import random 4 | import numpy as np 5 | import torch.nn.functional as F 6 | from typing import Any, Dict, Optional 7 | 8 | import re 9 | 10 | #################### 11 | ## MISC FUNCTIONS ## 12 | #################### 13 | 14 | def clean_spaces_preserve_newlines(text): 15 | # Replace multiple spaces with a single space, but preserve newlines 16 | lines = text.split("\n") # Split by newlines 17 | cleaned_lines = [" ".join(re.split(r"\s+", line)).strip() for line in lines] # Remove extra spaces in each line 18 | return "\n".join(cleaned_lines) # Join the lines back with newlines 19 | 20 | 21 | 22 | def seed_everything(seed: int) -> None: 23 | """ 24 | Set random seed for reproducibility across multiple libraries. 25 | 26 | This function sets consistent random seeds for Python's random module, 27 | NumPy, PyTorch (both CPU and CUDA), and configures CUDNN for deterministic 28 | operation. This ensures reproducible results across multiple runs. 29 | 30 | Args: 31 | seed: The random seed to use for all random number generators 32 | """ 33 | random.seed(seed) 34 | np.random.seed(seed) 35 | torch.manual_seed(seed) 36 | torch.cuda.manual_seed_all(seed) 37 | 38 | # Additional settings for reproducibility 39 | torch.backends.cudnn.deterministic = True 40 | torch.backends.cudnn.benchmark = False 41 | 42 | 43 | 44 | def write_generation_log(log_data: Dict[str, Any], log_file: str) -> None: 45 | """ 46 | Write generation log data to a text file. 
47 | 48 | Args: 49 | log_data: Dictionary containing prompt and generation data 50 | log_file: Path to output log file 51 | """ 52 | with open(log_file, 'w') as f: 53 | # Write prompt section 54 | f.write("###### ORIGINAL PROMPT #####\n\n") 55 | f.write(log_data['prompt']['text'] + "\n\n") 56 | f.write("#### ANS ####\n\n") 57 | f.write(str(log_data['prompt']['answer']) + "\n") 58 | 59 | # Write each generation 60 | for i, gen in enumerate(log_data['generations'], 1): 61 | f.write(f"#### GENERATION {i} RESPONSE ####\n\n") 62 | f.write(gen['response'] + "\n\n") 63 | f.write(f"#### GENERATION {i} SCORES ####\n") 64 | 65 | # Write individual scores 66 | f.write(f"Correctness: {gen['scores']['correctness']}\n") 67 | f.write(f"Integer format: {gen['scores']['integer_format']}\n") 68 | f.write(f"Strict format: {gen['scores']['strict_format']}\n") 69 | f.write(f"Soft format: {gen['scores']['soft_format']}\n") 70 | f.write(f"XML count: {gen['scores']['xml_count']}\n") 71 | f.write(f"Total reward: {gen['scores']['total_reward']}\n\n") 72 | 73 | 74 | #################################################################################### 75 | ## Copied Directly from TRL -> generate log probs per token ######## 76 | ## https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py ######## 77 | #################################################################################### 78 | 79 | def selective_log_softmax(logits, index): 80 | """ 81 | A memory-efficient implementation of the common `log_softmax -> gather` operation. 82 | 83 | This function is equivalent to the following naive implementation: 84 | ```python 85 | logps = torch.gather(logits.log_softmax(-1), dim=-1, index=index.unsqueeze(-1)).squeeze(-1) 86 | ``` 87 | 88 | Args: 89 | logits (`torch.Tensor`): 90 | Logits tensor of shape `(..., num_classes)`. 91 | index (`torch.Tensor`): 92 | Index tensor of shape `(...)`, specifying the positions to gather from the log-softmax output. 93 | 94 | Returns: 95 | `torch.Tensor`: 96 | Gathered log probabilities with the same shape as `index`. 97 | """ 98 | if logits.dtype in [torch.float32, torch.float64]: 99 | selected_logits = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1) 100 | # loop to reduce peak mem consumption 101 | logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits]) 102 | per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x) 103 | else: 104 | # logsumexp approach is unstable with bfloat16, fall back to slightly less efficient approach 105 | per_token_logps = [] 106 | for row_logits, row_labels in zip(logits, index): # loop to reduce peak mem consumption 107 | row_logps = F.log_softmax(row_logits, dim=-1) 108 | row_per_token_logps = row_logps.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1) 109 | per_token_logps.append(row_per_token_logps) 110 | per_token_logps = torch.stack(per_token_logps) 111 | return per_token_logps 112 | 113 | def get_per_token_logps(model, input_ids, attention_mask, logits_to_keep): 114 | # We add 1 to `logits_to_keep` because the last logit of the sequence is later excluded 115 | logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits 116 | logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred 117 | 118 | input_ids = input_ids[:, -logits_to_keep:] 119 | # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
120 | # See https://github.com/huggingface/trl/issues/2770 121 | logits = logits[:, -logits_to_keep:] 122 | return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens 123 | --------------------------------------------------------------------------------
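A minimal usage sketch for the GSM8K loaders in rldatasets.py above (not a file in this repo), assuming the repo root is on the import path and that `datasets` can download `openai/gsm8k`:

```python
# Hypothetical usage of rldatasets.py (illustration only, not repo code).
from rldatasets import get_dataloaders

train_loader, test_loader = get_dataloaders('gsm8k')

# Sequential access: __next__ raises StopIteration after one full pass,
# so call reset() before reusing a loader.
question, answer = next(train_loader)
full_prompt = train_loader.pre_prompt + question  # the prompt handed to the model
print(full_prompt[:120])
print('gold answer:', answer)

# Iterate the held-out ~1% test split once, then rewind it.
for q, a in test_loader:
    pass  # score model generations against the gold answer `a` here
test_loader.reset()
```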
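write_generation_log in utils.py consumes a nested dict whose schema is only implicit in its body. The example below reconstructs that schema from the function's reads; the question text and the score values are purely illustrative:

```python
# Reconstructed example of the log_data dict expected by utils.write_generation_log.
# Keys mirror the lookups in the function body; the numbers are illustrative only.
from utils import write_generation_log

log_data = {
    'prompt': {
        'text': 'Natalia sold clips to 48 of her friends in April ...',
        'answer': '72',
    },
    'generations': [
        {
            'response': '<reasoning>48 + 48/2 = 72</reasoning>\n<answer>72</answer>',
            'scores': {
                'correctness': 2.0,
                'integer_format': 0.5,
                'strict_format': 0.5,
                'soft_format': 0.5,
                'xml_count': 0.5,
                'total_reward': 4.0,
            },
        },
    ],
}

write_generation_log(log_data, 'generation_log.txt')
```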
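The docstring of selective_log_softmax claims equivalence to the naive `log_softmax -> gather`; a quick self-contained check of that claim (again, not repo code):

```python
# Sanity check: selective_log_softmax should match the naive implementation
# quoted in its own docstring, up to floating-point tolerance.
import torch
from utils import selective_log_softmax

logits = torch.randn(2, 7, 13)        # (batch, seq_len, vocab_size)
index = torch.randint(0, 13, (2, 7))  # token ids whose log-probs we want

naive = torch.gather(logits.log_softmax(-1), dim=-1,
                     index=index.unsqueeze(-1)).squeeze(-1)
fast = selective_log_softmax(logits, index)

print('match:', torch.allclose(naive, fast, atol=1e-6))
```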
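Finally, a sketch of get_per_token_logps on a single prompt/completion pair. The checkpoint name is an assumption (any Hugging Face causal LM should work), and because the function slices the logits itself, it behaves the same whether or not the installed transformers version honors the logits_to_keep kwarg:

```python
# Hypothetical call to utils.get_per_token_logps. The checkpoint below is an
# assumption for illustration; substitute whatever model the run actually uses.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils import get_per_token_logps

name = 'Qwen/Qwen2.5-1.5B-Instruct'  # illustrative checkpoint, not from this repo
tok = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name)

prompt_ids = tok('Question: 2 + 2 =', return_tensors='pt').input_ids
completion_ids = tok(' 4', return_tensors='pt', add_special_tokens=False).input_ids
input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():
    # One log-probability per completion token under the current weights.
    logps = get_per_token_logps(model, input_ids, attention_mask,
                                logits_to_keep=completion_ids.size(1))
print(logps.shape)  # torch.Size([1, num_completion_tokens])
```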