├── .gitignore
├── LICENSE
├── README.md
├── evaluator.py
├── llms.py
├── main.py
├── plots
│   ├── eval_score.png
│   └── train_score.png
├── plotter.py
├── requirements.txt
├── rldatasets.py
├── run.sh
├── training_score.png
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # UV
98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | #uv.lock
102 |
103 | # poetry
104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105 | # This is especially recommended for binary packages to ensure reproducibility, and is more
106 | # commonly ignored for libraries.
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108 | #poetry.lock
109 |
110 | # pdm
111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112 | #pdm.lock
113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114 | # in version control.
115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116 | .pdm.toml
117 | .pdm-python
118 | .pdm-build/
119 |
120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121 | __pypackages__/
122 |
123 | # Celery stuff
124 | celerybeat-schedule
125 | celerybeat.pid
126 |
127 | # SageMath parsed files
128 | *.sage.py
129 |
130 | # Environments
131 | .env
132 | .venv
133 | env/
134 | venv/
135 | ENV/
136 | env.bak/
137 | venv.bak/
138 |
139 | # Spyder project settings
140 | .spyderproject
141 | .spyproject
142 |
143 | # Rope project settings
144 | .ropeproject
145 |
146 | # mkdocs documentation
147 | /site
148 |
149 | # mypy
150 | .mypy_cache/
151 | .dmypy.json
152 | dmypy.json
153 |
154 | # Pyre type checker
155 | .pyre/
156 |
157 | # pytype static type analyzer
158 | .pytype/
159 |
160 | # Cython debug symbols
161 | cython_debug/
162 |
163 | # PyCharm
164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166 | # and can be added to the global gitignore or merged into this file. For a more nuclear
167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168 | #.idea/
169 |
170 | # PyPI configuration file
171 | .pypirc
172 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Brendan Hogan
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # DeepSeek R1 Implementation
3 |
4 | ## Motivation
5 | I wanted to recreate DeepSeek R1's results at a smaller scale, focusing on understanding the core mechanics by implementing everything from scratch. So this is a repo that trains Qwen2.5-1.5B-Instruct on the [grade school math dataset](https://github.com/openai/grade-school-math).
6 |
7 | This implementation heavily borrows from [Will Brown's work](https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb) ([@willccbb](https://x.com/willccbb)), but restructures the code into a format optimized for learning and experimentation.
8 |
9 | The key differences in my implementation are computing the GRPO loss directly rather than relying on external RL libraries, and reorganizing the code into a multi-script repo.
10 |
11 | I hope this helps other people understand the method better, and provides an easier way to try out smaller-scale ideas.
12 |
13 | ## Installation
14 | ```
15 | pip install -r requirements.txt
16 | ```
17 |
18 | Required environment variables:
19 | ```
20 | export HUGGINGFACE_TOKEN="your-token-here"
21 | huggingface-cli login
22 | ```
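
To train and then plot the logged metrics (mirroring `run.sh`; the output directory name here is just an example):
```
python main.py --output_dir output --verbose
python plotter.py --log_dir output
```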
23 |
24 | ## Implementation Details
25 |
26 | The system consists of several key modules:
27 |
28 | ### main.py
29 | Contains the core training loop implementing GRPO (Group Relative Policy Optimization). Handles model training, evaluation, and metric tracking.
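
At its core, each step samples a group of completions for one question, scores them, normalizes the rewards into group-relative advantages, and applies a policy-gradient loss with a KL penalty against a frozen reference model. Below is a minimal sketch of that computation with toy random tensors standing in for real model outputs (the actual loop lives in `generate_completions`, `score_completions`, and `compute_loss`):
```
import torch

# Toy shapes: 16 sampled completions ("chains") for one question, 12 completion tokens each.
num_chains, seq_len = 16, 12
rewards = torch.rand(num_chains)                         # summed reward per completion
per_token_logps = torch.randn(num_chains, seq_len)       # log-probs under the current policy
ref_per_token_logps = torch.randn(num_chains, seq_len)   # log-probs under the frozen reference model
completion_mask = torch.ones(num_chains, seq_len)        # 1 for generated tokens, 0 after EOS
kl_weight_beta = 0.04

# Group-relative advantage: each completion is scored against its group's mean/std.
advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-4)

# Per-token KL estimate between the current policy and the reference model.
per_token_kl = (
    torch.exp(ref_per_token_logps - per_token_logps)
    - (ref_per_token_logps - per_token_logps)
    - 1
)

# Policy-gradient term (the ratio is 1 in the forward pass but carries gradients through logps),
# penalized by the KL term and averaged over completion tokens.
per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
per_token_loss = -(per_token_loss - kl_weight_beta * per_token_kl)
loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
```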
30 |
31 | ### llms.py
32 | Manages model loading and configuration, currently supporting LLaMA + Qwen models through Hugging Face's transformers library. Designed to be easily extensible to other model architectures.
33 |
34 | ### rldatasets.py
35 | Handles dataset loading and preprocessing, currently focused on GSM8K math problems. Implements custom data loaders for both training and evaluation.
36 |
37 | ### evaluator.py
38 | Contains evaluation metrics and reward functions, closely following DeepSeek's original implementation.
39 |
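The modules plug together roughly like this (a simplified sketch of the setup in `main.py`; the model name matches the default `--model_name` argument):
```
import torch

import llms
import rldatasets
import evaluator

device = "cuda" if torch.cuda.is_available() else "cpu"

# Policy model being trained, plus a frozen copy used as the KL reference.
model, tokenizer = llms.get_llm_tokenizer("Qwen/Qwen2.5-1.5B-Instruct", device)
base_model, _ = llms.get_llm_tokenizer("Qwen/Qwen2.5-1.5B-Instruct", device)

# GSM8K question/answer iterators and the matching reward functions.
train_loader, test_loader = rldatasets.get_dataloaders("gsm8k")
eval_class = evaluator.get_evaluator("gsm8k")

# One (question, ground-truth answer) pair is consumed per training step.
question, answer = next(train_loader)
```
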
40 | ## Results
41 | Training was conducted on a single H100 GPU. After ~400 training steps:
42 |
43 | ![Training scores](plots/train_score.png)
44 |
45 | And results on the validation set, which show a clearer sign of learning:
46 | ![Evaluation scores](plots/eval_score.png)
47 |
48 | ## Future Directions
49 | I'm really pleased to see how well the key mechanics work even in this simplified implementation. Building on this, I am very excited about several directions:
50 |
51 | 1. Adding self-play capabilities where agents compete and learn from each other using relative rewards. This would create a more dynamic training environment where the reward signal comes from agent interactions rather than fixed metrics.
52 |
53 | 2. Implementing soft reward structures, particularly for complex reasoning tasks. I've written a framework for AI debate that I'm excited to try out.
54 |
55 | 3. Expanding into vision-language models (VLMs) to improve world modeling capabilities. I have an idea about using R1-style training to enhance how VLMs build and maintain internal world models that I'm really excited to explore. (Really excited about this idea - if anyone else is interested I would love to talk.)
56 |
57 | 4. I'd like to do all this experimentation in this framework, so I need to make things faster and add multi-GPU support.
58 |
59 |
60 |
61 |
--------------------------------------------------------------------------------
/evaluator.py:
--------------------------------------------------------------------------------
1 | """
2 | Abstract base class and implementations for reward computation in RL training.
3 |
4 | """
5 |
6 | import re
7 | import torch
8 | from abc import ABC, abstractmethod
9 | from typing import List, Dict, Tuple, Any
10 |
11 | class RewardEvaluator(ABC):
12 | """
13 | Abstract base class for reward computation in RL training.
14 |
15 | This class defines the interface for reward evaluators that can be used
16 | to score model completions during RL training. Implement this class to
17 | create custom reward functions for different tasks.
18 |
19 | The main methods that need to be implemented are:
20 | - compute_rewards: Computes rewards for a batch of completions
21 | - get_reward_breakdown: Converts raw reward scores to a labeled dictionary
22 | """
23 |
24 | @abstractmethod
25 | def compute_rewards(
26 | self,
27 | prompts: List[List[Dict[str, str]]],
28 | completions: List[List[Dict[str, str]]],
29 | answer: Any,
30 | device: str
31 | ) -> Tuple[torch.Tensor, Dict[str, float]]:
32 | """
33 | Compute rewards for a batch of completions.
34 |
35 | Args:
36 | prompts: List of prompt messages in chat format
37 | [{"role": "user", "content": "..."}, ...]
38 | completions: List of completion messages in chat format
39 | [{"role": "assistant", "content": "..."}, ...]
40 | answer: Ground truth answer(s) for the prompts
41 | device: Device to place tensors on ("cpu" or "cuda")
42 |
43 | Returns:
44 | rewards_per_func: Tensor of shape (num_completions, num_reward_functions)
45 | containing individual reward function scores
46 | metrics: Dictionary of aggregated metrics including mean rewards
47 | per function and total reward
48 | """
49 | pass
50 |
51 | @abstractmethod
52 | def get_reward_breakdown(self, reward_scores: torch.Tensor) -> Dict[str, float]:
53 | """
54 | Convert raw reward scores tensor to a labeled dictionary.
55 |
56 | Args:
57 | reward_scores: Tensor of raw scores from compute_rewards
58 |
59 | Returns:
60 | Dictionary mapping reward function names to their scores
61 | """
62 | pass
63 |
64 |
65 | def get_evaluator(name: str) -> RewardEvaluator:
66 | """
67 | Get the appropriate reward evaluator for a given task.
68 |
69 | Args:
70 | name: Name of the task/dataset to get evaluator for
71 |
72 | Returns:
73 | RewardEvaluator instance for the specified task
74 |
75 | Raises:
76 | NotImplementedError: If evaluator for given task is not implemented
77 | """
78 | if name.lower() == "gsm8k":
79 | return GSM8kEvaluator()
80 | else:
81 | raise NotImplementedError(f"No evaluator implemented for {name}")
82 |
83 |
84 |
85 | class GSM8kEvaluator(RewardEvaluator):
86 | """
87 | Reward evaluator for the GSM8K math problem dataset.
88 |
89 | Implements reward functions for:
90 | - Answer correctness
91 | - Integer format validation
92 | - XML formatting (strict and soft)
93 | - XML tag counting
94 | """
95 |
96 | def __init__(self):
97 | self.num_reward_functions = 5
98 |
99 | def _extract_xml_answer(self, text: str) -> str:
100 | """Extract answer from XML tags."""
101 |         answer = text.split("<answer>")[-1]
102 |         answer = answer.split("</answer>")[0]
103 | return answer.strip()
104 |
105 | def _correctness_reward(self, prompts, completions, answer) -> List[float]:
106 | """Reward for correct answer."""
107 | responses = [completion[0]['content'] for completion in completions]
108 | extracted = [self._extract_xml_answer(r) for r in responses]
109 | return [2.0 if r == a else 0.0 for r, a in zip(extracted, answer)]
110 |
111 | def _int_format_reward(self, completions) -> List[float]:
112 | """Reward for integer format."""
113 | responses = [completion[0]['content'] for completion in completions]
114 | extracted = [self._extract_xml_answer(r) for r in responses]
115 | return [0.5 if r.isdigit() else 0.0 for r in extracted]
116 |
117 | def _strict_format_reward(self, completions) -> List[float]:
118 | """Reward for strict XML format."""
119 |         pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
120 | responses = [completion[0]["content"] for completion in completions]
121 | matches = [bool(re.match(pattern, r)) for r in responses]
122 | return [0.5 if m else 0.0 for m in matches]
123 |
124 | def _soft_format_reward(self, completions) -> List[float]:
125 | """Reward for relaxed XML format."""
126 |         pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
127 | responses = [completion[0]["content"] for completion in completions]
128 | matches = [bool(re.match(pattern, r)) for r in responses]
129 | return [0.5 if m else 0.0 for m in matches]
130 |
131 | def _xml_count_reward(self, completions) -> List[float]:
132 | """Reward for XML tag counting."""
133 | def count_xml(text: str) -> float:
134 | count = 0.0
135 |             if text.count("<reasoning>\n") == 1: count += 0.125
136 |             if text.count("\n</reasoning>\n") == 1: count += 0.125
137 |             if text.count("\n<answer>\n") == 1:
138 |                 count += 0.125
139 |                 count -= len(text.split("\n</answer>\n")[-1])*0.001
140 |             if text.count("\n</answer>") == 1:
141 |                 count += 0.125
142 |                 count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
143 | return count
144 |
145 | responses = [completion[0]["content"] for completion in completions]
146 | return [count_xml(r) for r in responses]
147 |
148 | def compute_rewards(
149 | self,
150 | prompts: List[List[Dict[str, str]]],
151 | completions: List[List[Dict[str, str]]],
152 | answer: Any,
153 | device: str
154 | ) -> Tuple[torch.Tensor, Dict[str, float]]:
155 | """Compute all rewards for the given completions."""
156 |
157 | num_completions = len(completions)
158 | rewards_per_func = torch.zeros(num_completions, self.num_reward_functions, device=device)
159 |
160 | # Compute all reward functions
161 | all_scores = [
162 | self._correctness_reward(prompts, completions, answer),
163 | self._int_format_reward(completions),
164 | self._strict_format_reward(completions),
165 | self._soft_format_reward(completions),
166 | self._xml_count_reward(completions)
167 | ]
168 |
169 | # Fill rewards tensor
170 | for i, scores in enumerate(all_scores):
171 | rewards_per_func[:, i] = torch.tensor(scores, dtype=torch.float32, device=device)
172 |
173 | # Compute metrics
174 | reward_per_func = rewards_per_func.mean(0)
175 |
176 | # Calculate accuracy (perfect correctness score)
177 | correctness_scores = rewards_per_func[:, 0] # First reward function is correctness
178 | num_perfect = (correctness_scores == 2.0).sum().item()
179 | accuracy = num_perfect / num_completions
180 |
181 | metrics = {
182 | "rewards/correctness_reward_func": reward_per_func[0].item(),
183 | "rewards/int_reward_func": reward_per_func[1].item(),
184 | "rewards/strict_format_reward_func": reward_per_func[2].item(),
185 | "rewards/soft_format_reward_func": reward_per_func[3].item(),
186 | "rewards/xmlcount_reward_func": reward_per_func[4].item(),
187 | "reward": rewards_per_func.sum(dim=1).mean().item(),
188 | "accuracy": accuracy
189 | }
190 |
191 | return rewards_per_func, metrics
192 |
193 | def get_reward_breakdown(self, reward_scores: torch.Tensor) -> Dict[str, float]:
194 | """Convert reward scores tensor to labeled dictionary."""
195 | return {
196 | 'correctness': reward_scores[0].item(),
197 | 'integer_format': reward_scores[1].item(),
198 | 'strict_format': reward_scores[2].item(),
199 | 'soft_format': reward_scores[3].item(),
200 | 'xml_count': reward_scores[4].item()
201 | }
202 |
--------------------------------------------------------------------------------
/llms.py:
--------------------------------------------------------------------------------
1 | """
2 | Module for loading LLMs and their tokenizers from huggingface.
3 |
4 | """
5 | import torch
6 | from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizerBase
7 |
8 |
9 | def get_llm_tokenizer(model_name: str, device: str) -> tuple[PreTrainedModel, PreTrainedTokenizerBase]:
10 | """
11 | Load and configure a language model and its tokenizer.
12 |
13 | Args:
14 | model_name: Name or path of the pretrained model to load
15 | device: Device to load the model on ('cpu' or 'cuda')
16 |
17 | Returns:
18 | tuple containing:
19 | - The loaded language model
20 | - The configured tokenizer for that model
21 | """
22 | model = AutoModelForCausalLM.from_pretrained(
23 | model_name,
24 | torch_dtype=torch.bfloat16,
25 | attn_implementation="flash_attention_2",
26 | device_map=None,
27 | ).to(device)
28 |
29 | tokenizer = AutoTokenizer.from_pretrained(model_name)
30 | tokenizer.pad_token = tokenizer.eos_token
31 | model.config.pad_token_id = tokenizer.pad_token_id
32 | model.config.use_cache = False
33 |
34 | return model, tokenizer
35 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | """
2 | Implementation of GRPO (Group Relative Policy Optimization), DeepSeek-style training, without external RL libraries
3 | """
4 | import os
5 | import json
6 | import torch
7 | import argparse
8 | from tqdm import tqdm
9 | from collections import defaultdict
10 | from transformers import PreTrainedModel, PreTrainedTokenizerBase, GenerationConfig
11 |
12 | import llms
13 | import utils
14 | import evaluator
15 | import rldatasets
16 |
17 | def eval_on_test_set(
18 | model: PreTrainedModel,
19 | tokenizer: PreTrainedTokenizerBase,
20 | test_loader: rldatasets.DataLoader,
21 | eval_class: evaluator.RewardEvaluator,
22 | device: str,
23 | args: argparse.Namespace,
24 | round_num: int
25 | ) -> tuple[dict[str, float], float]:
26 | """
27 | Evaluate model performance on test set.
28 |
29 | Args:
30 | model: The model to evaluate
31 | tokenizer: Tokenizer for the model
32 | test_loader: DataLoader for test set
33 | eval_class: Evaluator for computing rewards
34 | device: Device to run on
35 | args: Training arguments
36 | round_num: Current training round number
37 |
38 | Returns:
39 | total_scores: Dictionary of average metrics
40 | accuracy: Accuracy on test set
41 | """
42 | print("Running evaluation on test set...")
43 |
44 | # Track metrics across all test examples
45 | total_scores = defaultdict(float)
46 | num_examples = 0
47 | total_accuracy = 0.0
48 |
49 | # Create log file for this evaluation round
50 | log_file = os.path.join(args.output_dir, f'eval_metrics_{round_num}.txt')
51 | test_loader.reset()
52 |
53 | with open(log_file, 'w') as f:
54 | # Run through test set
55 | for question, answer in tqdm(test_loader, desc="Evaluating on test set"):
56 | # Generate completions using same function as training
57 | _, _, _, _, completions_text, _ = generate_completions(
58 | model, tokenizer, question, device, args
59 | )
60 |
61 | # Score completions using evaluator
62 | mock_prompts = [[{'content': question}]] * len(completions_text)
63 | mock_completions = [[{'content': completion}] for completion in completions_text]
64 | # Make answer array same length as completions
65 | answers = [answer] * len(completions_text)
66 | rewards_per_func, metrics = eval_class.compute_rewards(
67 | prompts=mock_prompts,
68 | completions=mock_completions,
69 | answer=answers,
70 | device=device
71 | )
72 |
73 | # Track accuracy and accumulate metrics
74 | total_accuracy += metrics['accuracy']
75 |
76 | for k, v in metrics.items():
77 | total_scores[k] += v
78 | num_examples += 1
79 |
80 | # Log this example
81 | f.write("\n" + "="*50 + "\n")
82 | f.write(f"Q# {num_examples}\n")
83 | f.write(f"Question: {question}\n")
84 | f.write(f"Response: {completions_text[0]}\n") # Log first completion
85 | f.write(f"Ground Truth: {answer}\n")
86 | f.write("Metrics:\n")
87 | for metric, value in metrics.items():
88 | f.write(f"{metric}: {value}\n")
89 | f.write(f"Total Score: {rewards_per_func.sum().item()}\n")
90 |
91 |
92 | # Calculate averages
93 | avg_scores = {k: v/num_examples for k,v in total_scores.items()}
94 | accuracy = total_accuracy / num_examples * 100
95 |
96 | # Save metrics
97 | metrics_path = os.path.join(args.output_dir, f'eval_metrics_{round_num}.json')
98 | with open(metrics_path, 'w') as f:
99 | json.dump({**avg_scores, 'accuracy': accuracy}, f, indent=4)
100 |
101 | if args.verbose:
102 | print("\nEvaluation Results:")
103 | print("-" * 20)
104 | print(f"Accuracy: {accuracy:.2f}%")
105 | for metric, value in avg_scores.items():
106 | print(f"{metric:15s}: {value:.4f}")
107 | print("-" * 20)
108 |
109 | return avg_scores, accuracy
110 |
111 | def generate_completions(
112 | model: PreTrainedModel,
113 | tokenizer: PreTrainedTokenizerBase,
114 | question: str,
115 | device: str,
116 | args: argparse.Namespace
117 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, list[str], str]:
118 | """
119 | Generate multiple completion sequences for a given prompt using a language model.
120 |
121 | Args:
122 | model: The language model to use for generation
123 | tokenizer: Tokenizer corresponding to the model
124 | question: The input question/prompt to generate completions for
125 | device: Device to run generation on ('cpu' or 'cuda')
126 | args: Namespace containing generation parameters
127 |
128 | Returns:
129 | prompt_completion_ids: Tensor containing the full sequence of prompt + completion token IDs
130 | prompt_ids: Tensor containing just the prompt token IDs
131 | completion_ids: Tensor containing just the completion token IDs
132 | attention_mask: Attention mask tensor for the full sequence
133 | completions_text: List of decoded completion texts
134 | prompt_text: The full formatted prompt text
135 | """
136 | # 1. Prepare prompting
137 | prompt = [
138 |         {'role': 'system', 'content': train_loader.system_prompt},  # note: relies on the module-level train_loader created in __main__
139 | {'role': 'user', 'content': question}
140 | ]
141 | prompt_text = tokenizer.apply_chat_template(prompt, tokenize=False)
142 | prompt_inputs = tokenizer(prompt_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False)
143 | prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
144 |
145 | # Truncate prompt to max length and repeat for number of generations
146 | prompt_ids = prompt_ids[:, -args.max_prompt_length:]
147 | prompt_mask = prompt_mask[:, -args.max_prompt_length:]
148 |
149 | # Repeat for number of chains/generations
150 | prompt_ids = prompt_ids.repeat(args.num_chains, 1)
151 | prompt_mask = prompt_mask.repeat(args.num_chains, 1)
152 |
153 | # Move tensors to device
154 | prompt_ids = prompt_ids.to(device)
155 | prompt_mask = prompt_mask.to(device)
156 |
157 | # Set up generation config
158 | generation_config = GenerationConfig(
159 | max_new_tokens=args.max_completion_length,
160 | do_sample=True,
161 | temperature=args.temperature,
162 | pad_token_id=tokenizer.pad_token_id
163 | )
164 |
165 | # Generate completions
166 | prompt_completion_ids = model.generate(
167 | prompt_ids,
168 | attention_mask=prompt_mask,
169 | generation_config=generation_config
170 | )
171 |
172 | # Extract completion ids
173 | prompt_length = prompt_ids.size(1)
174 | prompt_ids = prompt_completion_ids[:, :prompt_length]
175 | completion_ids = prompt_completion_ids[:, prompt_length:]
176 |
177 | # Do masking
178 | is_eos = completion_ids == tokenizer.eos_token_id
179 | eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
180 | eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
181 | sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
182 | completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
183 |
184 | attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
185 |
186 | # Decode completions
187 | completions_text = tokenizer.batch_decode(completion_ids, skip_special_tokens=True)
188 |
189 | return prompt_completion_ids, prompt_ids, completion_ids, attention_mask, completions_text, prompt_text
190 |
191 | def score_completions(
192 | completions_text: list[str],
193 | question: str,
194 | answer: str,
195 | eval_class: evaluator.RewardEvaluator,
196 | device: str,
197 | args: argparse.Namespace
198 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict[str, float], dict]:
199 | """
200 | Score model completions and compute advantages for training.
201 |
202 | Args:
203 | completions_text: List of generated completion strings
204 | question: Original input question/prompt
205 | answer: Ground truth answer
206 | eval_class: Evaluator class for computing rewards
207 | device: Device to place tensors on
208 | args: Training arguments
209 |
210 | Returns:
211 | rewards: Raw reward scores for each completion
212 | advantages: Computed advantages for policy gradient
213 | rewards_per_func: Rewards broken down by individual reward functions
214 | metrics: Dictionary of aggregated metrics
215 | log_data: Dictionary containing detailed generation and scoring data
216 | """
217 | # Build log data dictionary
218 | log_data = {
219 | 'prompt': {
220 | 'text': question,
221 | 'answer': answer
222 | },
223 | 'generations': []
224 | }
225 |
226 | # Format inputs as expected by evaluator
227 | mock_prompts = [[{'content': question}]] * len(completions_text)
228 | mock_completions = [[{'content': completion}] for completion in completions_text]
229 | answers = [answer] * len(completions_text)
230 |
231 | # Get rewards and metrics from evaluator
232 | rewards_per_func, metrics = eval_class.compute_rewards(
233 | prompts=mock_prompts,
234 | completions=mock_completions,
235 | answer=answers,
236 | device=device
237 | )
238 | rewards = rewards_per_func.sum(dim=1)
239 |
240 | # Store generation data
241 | for i, (completion, reward_scores) in enumerate(zip(completions_text, rewards_per_func)):
242 | generation_data = {
243 | 'response': completion,
244 | 'scores': {
245 | **eval_class.get_reward_breakdown(reward_scores),
246 | 'total_reward': rewards[i].item()
247 | }
248 | }
249 | log_data['generations'].append(generation_data)
250 |
251 | # Compute advantages
252 | mean_grouped_rewards = rewards.view(-1, args.num_chains).mean(dim=1)
253 | std_grouped_rewards = rewards.view(-1, args.num_chains).std(dim=1)
254 |
255 | mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(args.num_chains, dim=0)
256 | std_grouped_rewards = std_grouped_rewards.repeat_interleave(args.num_chains, dim=0)
257 |
258 | advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
259 | metrics["reward_std"] = std_grouped_rewards.mean().item()
260 |
261 | # Store summary statistics
262 | log_data['summary_stats'] = {
263 | 'mean_rewards_per_group': mean_grouped_rewards.tolist(),
264 | 'std_rewards_per_group': std_grouped_rewards.tolist(),
265 | 'advantages': advantages.tolist()
266 | }
267 |
268 | return rewards, advantages, rewards_per_func, metrics, log_data
269 |
270 | def compute_loss(
271 | model: PreTrainedModel,
272 | base_model: PreTrainedModel,
273 | prompt_completion_ids: torch.Tensor,
274 | prompt_ids: torch.Tensor,
275 | completion_ids: torch.Tensor,
276 | attention_mask: torch.Tensor,
277 | completion_mask: torch.Tensor,
278 | advantages: torch.Tensor,
279 | args: argparse.Namespace
280 | ) -> tuple[torch.Tensor, dict[str, float]]:
281 | """
282 | Compute the GRPO loss between current and base model.
283 |
284 | Args:
285 | model: The current model being trained
286 | base_model: The reference model to compare against
287 | prompt_completion_ids: Combined prompt and completion token IDs
288 | prompt_ids: Token IDs for just the prompt
289 | completion_ids: Token IDs for just the completion
290 | attention_mask: Attention mask for the full sequence
291 | completion_mask: Mask indicating which tokens are from the completion
292 | advantages: Advantage values for each sequence
293 | args: Training arguments
294 |
295 | Returns:
296 | loss: The computed GRPO loss
297 | metrics: Dictionary containing additional metrics like KL divergence
298 | """
299 |
300 | # Only need the generated tokens' logits
301 | logits_to_keep = completion_ids.size(1)
302 |
303 | # Get reference model logits
304 | with torch.inference_mode():
305 | ref_per_token_logps = utils.get_per_token_logps(base_model, prompt_completion_ids, attention_mask, logits_to_keep)
306 |
307 | # Get training model logits
308 | input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
309 | per_token_logps = utils.get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
310 |
311 |     # Compute per-token KL divergence to the reference model (k3 estimator: exp(x) - x - 1 with x = ref_logp - logp, always >= 0)
312 | per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
313 |
314 | # Compute loss with advantages
315 | per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
316 | per_token_loss = -(per_token_loss - args.kl_weight_beta * per_token_kl)
317 | loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
318 |
319 | # Additional metrics
320 | metrics = {}
321 | response_length = completion_mask.sum(1).float().mean().item()
322 | metrics["response_length"] = response_length
323 | mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
324 | metrics["kl"] = mean_kl.item()
325 |
326 | return loss, metrics
327 |
328 | def grpo_loss(
329 | model: PreTrainedModel,
330 | base_model: PreTrainedModel,
331 | tokenizer: PreTrainedTokenizerBase,
332 | question: str,
333 | answer: str,
334 | eval_class: evaluator.RewardEvaluator,
335 | device: str,
336 | round_num: int,
337 | training_log_dir: str,
338 | args: argparse.Namespace
339 | ) -> tuple[torch.Tensor, dict[str, float]]:
340 | """
341 | Compute GRPO loss between the current model and base model.
342 |
343 | Args:
344 | model: The current model being trained
345 | base_model: The reference model to compare against
346 | tokenizer: Tokenizer for the models
347 | question: Input question/prompt
348 | answer: Ground truth answer
349 | eval_class: Evaluator for computing rewards
350 | device: Device to run on ('cpu' or 'cuda')
351 | round_num: Current training round number
352 | training_log_dir: Directory to save training logs
353 | args: Training arguments
354 |
355 | Returns:
356 | loss: The computed GRPO loss
357 | metrics: Dictionary containing training metrics
359 | """
360 | # Generate completions
361 | prompt_completion_ids, prompt_ids, completion_ids, attention_mask, completions_text, prompt_text = generate_completions(
362 | model, tokenizer, question, device, args
363 | )
364 |
365 | # Score completions
366 | rewards, advantages, rewards_per_func, metrics, log_data = score_completions(
367 | completions_text, question, answer, eval_class, device, args
368 | )
369 |
370 | # Write log data
371 | log_file = os.path.join(training_log_dir, f'{round_num}_generations.txt')
372 | utils.write_generation_log(log_data, log_file)
373 |
374 | # Compute loss
375 | completion_mask = attention_mask[:, prompt_ids.size(1):]
376 | loss, loss_metrics = compute_loss(
377 | model, base_model, prompt_completion_ids, prompt_ids, completion_ids,
378 | attention_mask, completion_mask, advantages, args
379 | )
380 |
381 | # Combine metrics
382 | metrics.update(loss_metrics)
383 |
384 | return loss, metrics
385 |
386 |
387 | def parse_args():
388 | parser = argparse.ArgumentParser(description="GRPO training arguments")
389 |
390 | # Model configuration
391 | parser.add_argument("--model_name", type=str, default="Qwen/Qwen2.5-1.5B-Instruct", help="Name/path of base model")
392 | parser.add_argument("--dataset_name", type=str, default="gsm8k", help="Dataset to use for training")
393 | parser.add_argument("--evaluator", type=str, default="gsm8k", help="Evaluator to use for scoring")
394 |
395 | # Output and logging
396 | parser.add_argument("--output_dir", type=str, default="output", help="Directory to save outputs")
397 | parser.add_argument("--verbose", action="store_true", help="Enable verbose logging")
398 | parser.add_argument("--save_steps", type=int, default=100, help="Save model every N steps")
399 | parser.add_argument("--eval_iterations", type=int, default=20, help="Number of iterations for evaluation")
400 |
401 | # Optimization hyperparameters
402 | parser.add_argument("--learning_rate", type=float, default=5e-6, help="Learning rate")
403 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="Adam beta1")
404 | parser.add_argument("--adam_beta2", type=float, default=0.99, help="Adam beta2")
405 | parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
406 | parser.add_argument("--max_grad_norm", type=float, default=0.1, help="Max gradient norm for clipping")
407 | parser.add_argument("--gradient_accumulation_steps", type=int, default=4, help="Number of gradient accumulation steps")
408 | parser.add_argument("--warmup_percent", type=float, default=0.18, help="Percentage of total steps for warmup")
409 | parser.add_argument("--update_ref_model", action="store_true", help="Whether to update reference model")
410 | parser.add_argument("--update_ref_model_freq", type=int, default=200, help="How often to update reference model")
411 | parser.add_argument("--ref_model_mixup_alpha", type=float, default=0.1, help="Alpha parameter for reference model mixup")
412 |
413 |
414 | # Generation parameters
415 | parser.add_argument("--temperature", type=float, default=0.9, help="Sampling temperature")
416 | parser.add_argument("--num_chains", type=int, default=16, help="Number of parallel generation chains")
417 | parser.add_argument("--max_prompt_length", type=int, default=256, help="Maximum prompt length")
418 | parser.add_argument("--max_completion_length", type=int, default=786, help="Maximum completion length")
419 |
420 | # Training parameters
421 | parser.add_argument("--num_train_iters", type=int, default=1000, help="Number of training iterations")
422 | parser.add_argument("--kl_weight_beta", type=float, default=0.04, help="KL penalty weight")
423 | parser.add_argument("--seed", type=int, default=7111994, help="Random seed")
424 |
425 | args = parser.parse_args()
426 | return args
427 |
428 | if __name__ == "__main__":
429 |
430 | # Get all args
431 | args = parse_args()
432 |
433 | # Seed everything
434 | utils.seed_everything(args.seed)
435 |
436 | # Set device and enable bf16
437 | device = "cuda" if torch.cuda.is_available() else "cpu"
438 | torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
439 | torch.set_float32_matmul_precision('high')
440 |
441 | ###############################
442 | ## Main Experiment settings ##
443 | ###############################
444 |
445 | ## Set which model to train
446 | model, tokenizer = llms.get_llm_tokenizer(args.model_name, device)
447 | base_model, _ = llms.get_llm_tokenizer(args.model_name, device)
448 |
449 | ## Set which data set
450 | train_loader, test_loader = rldatasets.get_dataloaders(args.dataset_name)
451 |
452 | ## Set which evaluation criteria to use
453 | eval_class = evaluator.get_evaluator(args.evaluator)
454 |
455 | ###############################
456 |
457 |
458 | # Setup logging
459 | os.makedirs(args.output_dir, exist_ok=True)
460 | args_dict = vars(args)
461 | args_path = os.path.join(args.output_dir, 'args.json')
462 | with open(args_path, 'w') as f:
463 | json.dump(args_dict, f, indent=4)
464 | eval_log_dir = os.path.join(args.output_dir, 'eval_logs')
465 | os.makedirs(eval_log_dir, exist_ok=True)
466 | train_log_dir = os.path.join(args.output_dir, 'training_logs')
467 | os.makedirs(train_log_dir, exist_ok=True)
468 |
469 |
470 | # Setup optimizer for trainer agent with GRPO config settings
471 | optimizer = torch.optim.AdamW(
472 | model.parameters(),
473 | lr=args.learning_rate,
474 | betas=(args.adam_beta1, args.adam_beta2),
475 | weight_decay=args.weight_decay,
476 | eps=1e-8
477 | )
478 |
479 | # Add linear warmup learning rate scheduler
480 | warmup_steps = int(args.warmup_percent * args.num_train_iters)
481 | def get_lr(step):
482 | if step < warmup_steps:
483 | return (step / warmup_steps)
484 | return 1.0
485 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,lr_lambda=get_lr)
486 |
487 |
488 | # Begin training
489 | accumulated_loss = 0
490 | optimizer.zero_grad()
491 | train_metrics_total = {}
492 | for round_num in tqdm(range(args.num_train_iters), desc="Training Progress"):
493 |
494 | # Evaluate on test set every so often
495 | if round_num % args.eval_iterations == 0:
496 | eval_metrics, eval_accuracy = eval_on_test_set(
497 | model=model,
498 | tokenizer=tokenizer,
499 | test_loader=test_loader,
500 | eval_class=eval_class,
501 | device=device,
502 | args=args,
503 | round_num=round_num
504 | )
505 |
506 | # Save metrics to eval log dir
507 | metrics_path = os.path.join(eval_log_dir, f'metrics_{round_num}.json')
508 | with open(metrics_path, 'w') as f:
509 | json.dump({
510 | 'metrics': eval_metrics,
511 | 'accuracy': eval_accuracy
512 | }, f, indent=4)
513 |
514 | # Slowly update ref model
515 | if args.update_ref_model and (round_num+1) % args.update_ref_model_freq == 0:
516 | with torch.no_grad():
517 | for param, ref_param in zip(model.parameters(), base_model.parameters()):
518 | ref_param.data = args.ref_model_mixup_alpha * param.data + (1 - args.ref_model_mixup_alpha) * ref_param.data
519 |
520 | # Get next question
521 | question, answer = next(train_loader)
522 |
523 | # Do GRPO - generate chains, score, compute advantage, compute loss
524 | total_loss, train_metrics = grpo_loss(model, base_model, tokenizer, question, answer, eval_class, device, round_num, train_log_dir, args)
525 |
526 | # Gradient accumulation
527 | total_loss = total_loss # / args.gradient_accumulation_steps
528 | total_loss.backward()
529 | accumulated_loss += total_loss.item()
530 | scheduler.step()
531 |
532 | # Step optimizer
533 | if (round_num + 1) % args.gradient_accumulation_steps == 0:
534 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
535 | optimizer.step()
536 | optimizer.zero_grad()
537 |
538 | # Logs
539 | train_metrics["learning_rate"] = scheduler.get_last_lr()[0]
540 | train_metrics["loss"] = total_loss.item() * args.gradient_accumulation_steps
541 | grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), float('inf')).item()
542 | train_metrics["grad_norm"] = grad_norm
543 | train_metrics_total[round_num] = train_metrics
544 | with open(os.path.join(train_log_dir, "train_logs.json"), "w") as f:
545 | json.dump(train_metrics_total, f, indent=4)
546 |
547 |
--------------------------------------------------------------------------------
/plots/eval_score.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/brendanhogan/DeepSeekRL-Extended/09e312a07f634cbb6ee4aaf2d9f3372dca9e7b9b/plots/eval_score.png
--------------------------------------------------------------------------------
/plots/train_score.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/brendanhogan/DeepSeekRL-Extended/09e312a07f634cbb6ee4aaf2d9f3372dca9e7b9b/plots/train_score.png
--------------------------------------------------------------------------------
/plotter.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 | import numpy as np
5 | import matplotlib.pyplot as plt
6 | import matplotlib.style as style
7 | from matplotlib.backends.backend_pdf import PdfPages
8 |
9 | def moving_average(data, window_size=5):
10 | """Calculate moving average with given window size"""
11 | weights = np.ones(window_size) / window_size
12 | return np.convolve(data, weights, mode='valid')
13 |
14 | def plot_metrics(output_dir):
15 | """
16 | Plot training metrics from training_logs directory.
17 | Creates PDF with separate plots for each metric over training steps.
18 | Uses a modern, professional style with custom color palette.
19 | """
20 | # Load training logs
21 | train_logs_path = os.path.join(output_dir, 'training_logs', 'train_logs.json')
22 | with open(train_logs_path, 'r') as f:
23 | train_logs = json.load(f)
24 |
25 | # Load evaluation logs
26 | eval_logs = {}
27 | eval_logs_dir = os.path.join(output_dir, 'eval_logs')
28 | for filename in os.listdir(eval_logs_dir):
29 | if filename.startswith('metrics_') and filename.endswith('.json'):
30 | step = int(filename.split('_')[1].split('.')[0])
31 | with open(os.path.join(eval_logs_dir, filename), 'r') as f:
32 | eval_logs[step] = json.load(f)
33 |
34 | # Set style and color palette
35 | plt.style.use('bmh') # Using 'bmh' style which is a modern, clean style
36 | colors = ['#2ecc71', '#e74c3c', '#3498db', '#f1c40f', '#9b59b6', '#1abc9c', '#e67e22', '#34495e']
37 |
38 | # Create PDF to save all plots
39 | pdf_path = os.path.join(output_dir, 'training_plots.pdf')
40 | with PdfPages(pdf_path) as pdf:
41 |
42 | # Plot reward metrics
43 | reward_metrics = [
44 | 'rewards/correctness_reward_func',
45 | 'rewards/int_reward_func',
46 | 'rewards/strict_format_reward_func',
47 | 'rewards/soft_format_reward_func',
48 | 'rewards/xmlcount_reward_func',
49 | 'reward'
50 | ]
51 |
52 | for metric, color in zip(reward_metrics, colors):
53 | plt.figure(figsize=(12,7))
54 | steps = [int(x) for x in train_logs.keys()]
55 | values = [metrics[metric] for metrics in train_logs.values()]
56 |
57 | # Plot raw data with low alpha
58 | plt.plot(steps, values, color=color, alpha=0.3, linewidth=1.5, label='Raw data')
59 |
60 | # Calculate and plot moving average if we have enough data points
61 | if len(values) > 5:
62 | ma_values = moving_average(values)
63 | ma_steps = steps[len(steps)-len(ma_values):]
64 | plt.plot(ma_steps, ma_values, color=color, linewidth=2.5, label='Moving average')
65 |
66 | plt.xlabel('Training Steps', fontsize=12)
67 | plt.ylabel(f'{metric.split("/")[-1].replace("_", " ").title()}', fontsize=12)
68 | plt.title(f'{metric.split("/")[-1].replace("_", " ").title()}', fontsize=14, pad=20)
69 | plt.grid(True, alpha=0.3)
70 | plt.legend()
71 | pdf.savefig(bbox_inches='tight')
72 | plt.close()
73 |
74 | # Plot learning rate
75 | plt.figure(figsize=(12,7))
76 | steps = [int(x) for x in train_logs.keys()]
77 | lr_values = [metrics['learning_rate'] for metrics in train_logs.values()]
78 |
79 | plt.plot(steps, lr_values, color='#e74c3c', linewidth=2.0, label='Learning Rate')
80 |
81 | plt.xlabel('Training Steps', fontsize=12)
82 | plt.ylabel('Learning Rate', fontsize=12)
83 | plt.title('Learning Rate Schedule', fontsize=14, pad=20)
84 | plt.grid(True, alpha=0.3)
85 | plt.legend()
86 | pdf.savefig(bbox_inches='tight')
87 | plt.close()
88 |
89 | # Plot reward standard deviation
90 | plt.figure(figsize=(12,7))
91 | reward_std = [metrics['reward_std'] for metrics in train_logs.values()]
92 |
93 | plt.plot(steps, reward_std, color='#3498db', alpha=0.3, linewidth=1.5, label='Reward Std (Raw)')
94 | if len(reward_std) > 5:
95 | ma_std = moving_average(reward_std)
96 | ma_steps = steps[len(steps)-len(ma_std):]
97 | plt.plot(ma_steps, ma_std, color='#3498db', linewidth=2.5, label='Reward Std (MA)')
98 |
99 | plt.xlabel('Training Steps', fontsize=12)
100 | plt.ylabel('Standard Deviation', fontsize=12)
101 | plt.title('Reward Standard Deviation', fontsize=14, pad=20)
102 | plt.grid(True, alpha=0.3)
103 | plt.legend()
104 | pdf.savefig(bbox_inches='tight')
105 | plt.close()
106 |
107 | # Plot loss
108 | plt.figure(figsize=(12,7))
109 | loss_values = [metrics['loss'] for metrics in train_logs.values()]
110 |
111 | plt.plot(steps, loss_values, color='#e67e22', alpha=0.3, linewidth=1.5, label='Loss (Raw)')
112 | if len(loss_values) > 5:
113 | ma_loss = moving_average(loss_values)
114 | ma_steps = steps[len(steps)-len(ma_loss):]
115 | plt.plot(ma_steps, ma_loss, color='#e67e22', linewidth=2.5, label='Loss (MA)')
116 |
117 | plt.xlabel('Training Steps', fontsize=12)
118 | plt.ylabel('Loss', fontsize=12)
119 | plt.title('Training Loss', fontsize=14, pad=20)
120 | plt.grid(True, alpha=0.3)
121 | plt.legend()
122 | pdf.savefig(bbox_inches='tight')
123 | plt.close()
124 |
125 | # Plot KL divergence
126 | plt.figure(figsize=(12,7))
127 | kl_values = [metrics['kl'] for metrics in train_logs.values()]
128 |
129 | plt.plot(steps, kl_values, color='#9b59b6', alpha=0.3, linewidth=1.5, label='KL Divergence (Raw)')
130 | if len(kl_values) > 5:
131 | ma_kl = moving_average(kl_values)
132 | ma_steps = steps[len(steps)-len(ma_kl):]
133 | plt.plot(ma_steps, ma_kl, color='#9b59b6', linewidth=2.5, label='KL Divergence (MA)')
134 |
135 | plt.xlabel('Training Steps', fontsize=12)
136 | plt.ylabel('KL Divergence', fontsize=12)
137 | plt.title('KL Divergence', fontsize=14, pad=20)
138 | plt.grid(True, alpha=0.3)
139 | plt.legend()
140 | pdf.savefig(bbox_inches='tight')
141 | plt.close()
142 |
143 | # Plot evaluation metrics
144 | if eval_logs:
145 | eval_steps = sorted(eval_logs.keys())
146 |
147 | # Plot accuracy
148 | plt.figure(figsize=(12,7))
149 | accuracy_values = [eval_logs[step]['accuracy'] for step in eval_steps]
150 | plt.plot(eval_steps, accuracy_values, color='#2ecc71', linewidth=2.0, label='Accuracy')
151 | plt.xlabel('Training Steps', fontsize=12)
152 | plt.ylabel('Accuracy (%)', fontsize=12)
153 | plt.title('Evaluation Accuracy', fontsize=14, pad=20)
154 | plt.grid(True, alpha=0.3)
155 | plt.legend()
156 | pdf.savefig(bbox_inches='tight')
157 | plt.close()
158 |
159 | # Plot evaluation reward metrics
160 | eval_metrics = [key for key in eval_logs[eval_steps[0]]['metrics'].keys()]
161 | for metric, color in zip(eval_metrics, colors):
162 | plt.figure(figsize=(12,7))
163 | metric_values = [eval_logs[step]['metrics'][metric] for step in eval_steps]
164 | plt.plot(eval_steps, metric_values, color=color, linewidth=2.0, label=metric)
165 | plt.xlabel('Training Steps', fontsize=12)
166 | plt.ylabel(metric.replace('_', ' ').title(), fontsize=12)
167 | plt.title(f'Evaluation {metric.replace("_", " ").title()}', fontsize=14, pad=20)
168 | plt.grid(True, alpha=0.3)
169 | plt.legend()
170 | pdf.savefig(bbox_inches='tight')
171 | plt.close()
172 |
173 | if __name__ == "__main__":
174 | parser = argparse.ArgumentParser(description='Plot training metrics from logs directory')
175 | parser.add_argument('--log_dir', type=str, help='Directory containing training logs')
176 | args = parser.parse_args()
177 | plot_metrics(args.log_dir)
178 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==2.1.0
2 | accelerate==1.3.0
3 | aiohappyeyeballs==2.4.4
4 | aiohttp==3.11.11
5 | aiosignal==1.3.2
6 | annotated-types==0.7.0
7 | anyio==4.8.0
8 | appdirs==1.4.4
9 | argcomplete==1.8.1
10 | astunparse==1.6.3
11 | async-timeout==5.0.1
12 | attrs==21.2.0
13 | Automat==20.2.0
14 | Babel==2.8.0
15 | backcall==0.2.0
16 | bcrypt==3.2.0
17 | beautifulsoup4==4.10.0
18 | beniget==0.4.1
19 | bleach==4.1.0
20 | blinker==1.4
21 | bottle==0.12.19
22 | Brotli==1.0.9
23 | certifi==2020.6.20
24 | cffi==1.15.0
25 | chardet==4.0.0
26 | charset-normalizer==3.4.1
27 | click==8.0.3
28 | cloud-init==24.4
29 | colorama==0.4.4
30 | command-not-found==0.3
31 | commonmark==0.9.1
32 | configobj==5.0.6
33 | constantly==15.1.0
34 | cryptography==3.4.8
35 | ctop==1.0.0
36 | cycler==0.11.0
37 | datasets==3.2.0
38 | dbus-python==1.2.18
39 | decorator==4.4.2
40 | defusedxml==0.7.1
41 | dill==0.3.8
42 | distlib==0.3.4
43 | distro==1.7.0
44 | distro-info==1.1+ubuntu0.2
45 | docker==5.0.3
46 | docker-pycreds==0.4.0
47 | einops==0.8.0
48 | entrypoints==0.4
49 | exceptiongroup==1.2.2
50 | filelock==3.6.0
51 | flake8==4.0.1
52 | flash_attn==2.7.4.post1
53 | flatbuffers===1.12.1-git20200711.33e2d80-dfsg1-0.6
54 | fonttools==4.29.1
55 | fpdf==1.7.2
56 | frozenlist==1.5.0
57 | fs==2.4.12
58 | fsspec==2024.3.1
59 | future==0.18.2
60 | gast==0.5.2
61 | gitdb==4.0.12
62 | GitPython==3.1.44
63 | Glances==3.2.4.2
64 | google-pasta==0.2.0
65 | grpcio==1.30.2
66 | h11==0.14.0
67 | h5py==3.6.0
68 | h5py.-debian-h5py-serial==3.6.0
69 | html5lib==1.1
70 | httpcore==1.0.7
71 | httplib2==0.20.2
72 | httpx==0.28.1
73 | huggingface-hub==0.28.1
74 | hyperlink==21.0.0
75 | icdiff==2.0.4
76 | idna==3.3
77 | importlib-metadata==4.6.4
78 | incremental==21.3.0
79 | influxdb==5.3.1
80 | iotop==0.6
81 | ipykernel==6.7.0
82 | ipython==7.31.1
83 | ipython_genutils==0.2.0
84 | jax==0.4.30
85 | jaxlib==0.4.30
86 | jedi==0.18.0
87 | jeepney==0.7.1
88 | Jinja2==3.1.5
89 | jiter==0.8.2
90 | joblib==0.17.0
91 | jsonpatch==1.32
92 | jsonpointer==2.0
93 | jsonschema==3.2.0
94 | jupyter-client==7.1.2
95 | jupyter-core==4.9.1
96 | kaptan==0.5.12
97 | keras==3.6.0
98 | keyring==23.5.0
99 | kiwisolver==1.3.2
100 | launchpadlib==1.10.16
101 | lazr.restfulclient==0.14.4
102 | lazr.uri==1.0.6
103 | libtmux==0.10.1
104 | livereload==2.6.3
105 | lxml==4.8.0
106 | lz4==3.1.3+dfsg
107 | Markdown==3.3.6
108 | MarkupSafe==2.0.1
109 | matplotlib==3.5.1
110 | matplotlib-inline==0.1.3
111 | mccabe==0.6.1
112 | mkdocs==1.1.2
113 | ml-dtypes==0.5.0
114 | more-itertools==8.10.0
115 | mpmath==0.0.0
116 | msgpack==1.0.3
117 | multidict==6.1.0
118 | multiprocess==0.70.16
119 | namex==0.0.8
120 | nest-asyncio==1.5.4
121 | netifaces==0.11.0
122 | networkx==2.4
123 | numpy==1.21.5
124 | nvidia-ml-py==12.555.43
125 | oauthlib==3.2.0
126 | olefile==0.46
127 | openai==1.61.0
128 | opt-einsum==3.3.0
129 | optree==0.13.1
130 | packaging==21.3
131 | pandas==1.3.5
132 | parso==0.8.1
133 | peft==0.14.0
134 | pexpect==4.8.0
135 | pickleshare==0.7.5
136 | Pillow==9.0.1
137 | pipx==1.0.0
138 | platformdirs==2.5.1
139 | ply==3.11
140 | prompt-toolkit==3.0.28
141 | propcache==0.2.1
142 | protobuf==4.21.12
143 | psutil==5.9.0
144 | ptyprocess==0.7.0
145 | py==1.10.0
146 | pyarrow==19.0.0
147 | pyasn1==0.4.8
148 | pyasn1-modules==0.2.1
149 | pycodestyle==2.8.0
150 | pycparser==2.21
151 | pycryptodomex==3.11.0
152 | pydantic==2.10.6
153 | pydantic_core==2.27.2
154 | pyflakes==2.4.0
155 | Pygments==2.11.2
156 | PyGObject==3.42.1
157 | PyHamcrest==2.0.2
158 | pyinotify==0.9.6
159 | PyJWT==2.3.0
160 | pyOpenSSL==21.0.0
161 | pyparsing==2.4.7
162 | pyrsistent==0.18.1
163 | pyserial==3.5
164 | pysmi==0.3.2
165 | pysnmp==4.4.12
166 | pystache==0.6.0
167 | python-apt==2.4.0+ubuntu4
168 | python-dateutil==2.8.1
169 | python-magic==0.4.24
170 | pythran==0.10.0
171 | pytz==2022.1
172 | PyYAML==5.4.1
173 | pyzmq==22.3.0
174 | regex==2024.11.6
175 | requests==2.32.3
176 | rich==11.2.0
177 | safetensors==0.5.2
178 | scikit-learn==0.23.2
179 | scipy==1.8.0
180 | SecretStorage==3.3.1
181 | sentry-sdk==2.20.0
182 | service-identity==18.1.0
183 | setproctitle==1.3.4
184 | six==1.16.0
185 | smmap==5.0.2
186 | sniffio==1.3.1
187 | sos==4.7.2
188 | soupsieve==2.3.1
189 | ssh-import-id==5.11
190 | sympy==1.12
191 | tensorboard==2.18.0
192 | tensorflow==2.18.0
193 | termcolor==1.1.0
194 | tf_keras==2.18.0
195 | threadpoolctl==3.1.0
196 | tmuxp==1.9.2
197 | tokenizers==0.21.0
198 | torch==2.5.1
199 | torchvision==0.20.1
200 | tornado==6.1
201 | tqdm==4.67.1
202 | traitlets==5.1.1
203 | transformers==4.48.2
204 | triton==3.1.0
205 | trl==0.14.0
206 | Twisted==22.1.0
207 | typing_extensions==4.12.2
208 | ufoLib2==0.13.1
209 | ufw==0.36.1
210 | unattended-upgrades==0.1
211 | unicodedata2==14.0.0
212 | urllib3==2.3.0
213 | userpath==1.8.0
214 | virtualenv==20.13.0+ds
215 | wadllib==1.3.6
216 | wandb==0.19.5
217 | wcwidth==0.2.5
218 | webencodings==0.5.1
219 | websocket-client==1.2.3
220 | Werkzeug==2.0.2
221 | wrapt==1.13.3
222 | xxhash==3.5.0
223 | yarl==1.18.3
224 | zipp==1.0.0
225 | zope.interface==5.4.0
226 |
--------------------------------------------------------------------------------
/rldatasets.py:
--------------------------------------------------------------------------------
1 | """
2 | Hold all data sets
3 |
4 | """
5 |
6 | import random
7 | import numpy as np
8 | from tqdm import tqdm
9 | from datasets import load_dataset, Dataset
10 | from abc import ABC, abstractmethod
11 | from typing import Tuple, Any
12 |
13 |
14 |
15 | class DataLoader(ABC):
16 | """
17 | Abstract base class for data loaders.
18 |
19 | This class defines the interface that all dataset loaders should implement.
20 | Specific dataset loaders should inherit from this class and implement the
21 | required methods.
22 |
23 | Attributes:
24 | random (bool): If True, returns items randomly; if False, returns sequentially
25 | current_index (int): Current position for sequential access
26 | """
27 |
28 | def __init__(self, random: bool = False) -> None:
29 | self.random = random
30 | self.current_index = 0
31 |
32 | @abstractmethod
33 | def __len__(self) -> int:
34 | """Return the total number of items in the dataset."""
35 | pass
36 |
37 | @abstractmethod
38 | def __iter__(self) -> 'DataLoader':
39 | """Return self as iterator."""
40 | return self
41 |
42 | @abstractmethod
43 | def __next__(self) -> Any:
44 | """Return the next item(s) in the dataset."""
45 | pass
46 |
47 |
48 | def extract_hash_answer(text: str) -> str | None:
49 | if "####" not in text:
50 | return None
51 | return text.split("####")[1].strip()
52 |
53 |
54 |
55 | SYSTEM_PROMPT = """
56 | Respond in the following format:
57 | <reasoning>
58 | ...
59 | </reasoning>
60 | <answer>
61 | ...
62 | </answer>
63 | """
64 |
65 |
66 |
67 | class GSM8KLoader(DataLoader):
68 | """
69 | A loader class that provides iteration over GSM8K math problems.
70 |
71 | This class implements both sequential and random access to math problems through
72 | standard Python iterator protocols. It can be used to iterate over problems either
73 | in order or randomly, making it suitable for both training and evaluation.
74 |
75 | Attributes:
76 | questions (List[str]): List of math question strings
77 | answers (List[str]): List of corresponding answer strings
78 | random (bool): If True, returns problems randomly; if False, returns sequentially
79 | current_index (int): Current position in the lists for sequential access
80 | """
81 |
82 | def __init__(self, questions: list[str], answers: list[str], random: bool = False) -> None:
83 | super().__init__(random)
84 | self.questions = questions
85 | self.answers = answers
86 | self.pre_prompt = """You will be given a question that involves reasoning. You should reason carefully about the question, then provide your answer.
87 | It is very important that you put your reasoning process inside <reasoning> tags and your final answer inside <answer> tags, like this:
88 | 
89 | 
90 | <reasoning>
91 | Your step-by-step reasoning process here
92 | </reasoning>
93 | <answer>
94 | Your final answer here
95 | </answer>
96 | 
97 | All of your returned text should either be in the <reasoning> or <answer> tags - no text outside! Start each answer by immediately starting with <reasoning>.
98 | It is extremely important you answer in this way - do not put any information or text outside of these tags!
99 |
100 | Question: """
101 | self.system_prompt = SYSTEM_PROMPT
102 |
103 | def __len__(self) -> int:
104 | return len(self.questions)
105 |
106 | def __iter__(self) -> 'GSM8KLoader':
107 | return self
108 |
109 | def __next__(self) -> tuple[str, str]:
110 | if self.current_index >= len(self.questions):
111 | raise StopIteration
112 |
113 | if self.random:
114 | idx = random.randint(0, len(self.questions) - 1)
115 | else:
116 | idx = self.current_index
117 | self.current_index += 1
118 |
119 | return self.questions[idx], self.answers[idx]
120 |
121 | def reset(self):
122 | self.current_index = 0
123 |
124 |
125 | def build_gsm8k_dataloaders() -> Tuple[GSM8KLoader, GSM8KLoader]:
126 | data = load_dataset('openai/gsm8k', 'main')["train"]
127 |
128 | questions = []
129 | parsed_answers = []
130 | for i in tqdm(range(len(data)), desc="Processing"):
131 | # Try to get the answer - if it is None, don't use this sample
132 | ans = extract_hash_answer(data[i]['answer'])
133 | if ans is None:
134 | continue
135 | else:
136 | questions.append(data[i]['question'])
137 | parsed_answers.append(ans)
138 |
139 | # Randomly split into train/test sets
140 | total_samples = len(questions)
141 | test_size = int(total_samples * 0.01)  # hold out 1% of samples for the test set
142 |
143 | # Generate random indices for test set
144 | test_indices = random.sample(range(total_samples), test_size)
145 | test_indices_set = set(test_indices)
146 |
147 | # Convert to numpy arrays for easier indexing
148 | questions = np.array(questions)
149 | parsed_answers = np.array(parsed_answers)
150 |
151 | # Create boolean mask for test indices
152 | test_mask = np.zeros(total_samples, dtype=bool)
153 | test_mask[list(test_indices_set)] = True
154 |
155 | # Split using boolean indexing
156 | test_questions = questions[test_mask]
157 | test_answers = parsed_answers[test_mask]
158 | train_questions = questions[~test_mask]
159 | train_answers = parsed_answers[~test_mask]
160 |
161 | # Setup data loaders
162 | trainloader = GSM8KLoader(train_questions.tolist(), train_answers.tolist())
163 | testloader = GSM8KLoader(test_questions.tolist(), test_answers.tolist())
164 |
165 | return trainloader, testloader
166 |
167 |
168 | def get_dataloaders(dataset_name: str) -> Tuple[DataLoader, DataLoader]:
169 | """
170 | Factory function to get train and test data loaders for a specified dataset.
171 |
172 | Args:
173 | dataset_name (str): Name of the dataset to load ('gsm8k' currently supported)
174 |
175 | Returns:
176 | Tuple[DataLoader, DataLoader]: Train and test data loaders
177 |
178 | Raises:
179 | ValueError: If dataset_name is not supported
180 | """
181 | if dataset_name.lower() == 'gsm8k':
182 | return build_gsm8k_dataloaders()
183 | else:
184 | raise ValueError(f"Dataset {dataset_name} not supported. Currently only 'gsm8k' is available.")
185 |
186 |
187 | if __name__ == "__main__":
188 | trainloader, testloader = get_dataloaders('gsm8k')
--------------------------------------------------------------------------------
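
Usage sketch (editor's addition): the loaders above are plain Python iterators, so a consumer only needs the pattern below. The chat-message layout and the `random = True` toggle are assumptions about how a script such as main.py might use these loaders, not code taken from this repo.

```python
# Hedged usage sketch for rldatasets.py (illustrative only, not part of the repo).
from rldatasets import get_dataloaders

train_loader, test_loader = get_dataloaders('gsm8k')
print(f"{len(train_loader)} train problems, {len(test_loader)} held-out problems")

# Training typically samples problems at random; evaluation walks the test set in order.
train_loader.random = True
question, answer = next(train_loader)

# Assemble a prompt from the loader's attributes; the chat-message layout is an assumption.
messages = [
    {"role": "system", "content": train_loader.system_prompt},
    {"role": "user", "content": train_loader.pre_prompt + question},
]
print(messages[1]["content"][:120], "...")
print("gold answer:", answer)
```
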
/run.sh:
--------------------------------------------------------------------------------
1 | # python main.py --output_dir "final1" --verbose
2 | python plotter.py --log_dir "final1"
--------------------------------------------------------------------------------
/training_score.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/brendanhogan/DeepSeekRL-Extended/09e312a07f634cbb6ee4aaf2d9f3372dca9e7b9b/training_score.png
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import random
4 | import numpy as np
5 | import torch.nn.functional as F
6 | from typing import Any, Dict, Optional
7 |
8 | import re
9 |
10 | ####################
11 | ## MISC FUNCTIONS ##
12 | ####################
13 |
14 | def clean_spaces_preserve_newlines(text):
15 | # Replace multiple spaces with a single space, but preserve newlines
16 | lines = text.split("\n") # Split by newlines
17 | cleaned_lines = [" ".join(re.split(r"\s+", line)).strip() for line in lines] # Remove extra spaces in each line
18 | return "\n".join(cleaned_lines) # Join the lines back with newlines
19 |
20 |
21 |
22 | def seed_everything(seed: int) -> None:
23 | """
24 | Set random seed for reproducibility across multiple libraries.
25 |
26 | This function sets consistent random seeds for Python's random module,
27 | NumPy, PyTorch (both CPU and CUDA), and configures CUDNN for deterministic
28 | operation. This ensures reproducible results across multiple runs.
29 |
30 | Args:
31 | seed: The random seed to use for all random number generators
32 | """
33 | random.seed(seed)
34 | np.random.seed(seed)
35 | torch.manual_seed(seed)
36 | torch.cuda.manual_seed_all(seed)
37 |
38 | # Additional settings for reproducibility
39 | torch.backends.cudnn.deterministic = True
40 | torch.backends.cudnn.benchmark = False
41 |
42 |
43 |
44 | def write_generation_log(log_data: Dict[str, Any], log_file: str) -> None:
45 | """
46 | Write generation log data to a text file.
47 |
48 | Args:
49 | log_data: Dictionary containing prompt and generation data
50 | log_file: Path to output log file
51 | """
52 | with open(log_file, 'w') as f:
53 | # Write prompt section
54 | f.write("###### ORIGINAL PROMPT #####\n\n")
55 | f.write(log_data['prompt']['text'] + "\n\n")
56 | f.write("#### ANS ####\n\n")
57 | f.write(str(log_data['prompt']['answer']) + "\n")
58 |
59 | # Write each generation
60 | for i, gen in enumerate(log_data['generations'], 1):
61 | f.write(f"#### GENERATION {i} RESPONSE ####\n\n")
62 | f.write(gen['response'] + "\n\n")
63 | f.write(f"#### GENERATION {i} SCORES ####\n")
64 |
65 | # Write individual scores
66 | f.write(f"Correctness: {gen['scores']['correctness']}\n")
67 | f.write(f"Integer format: {gen['scores']['integer_format']}\n")
68 | f.write(f"Strict format: {gen['scores']['strict_format']}\n")
69 | f.write(f"Soft format: {gen['scores']['soft_format']}\n")
70 | f.write(f"XML count: {gen['scores']['xml_count']}\n")
71 | f.write(f"Total reward: {gen['scores']['total_reward']}\n\n")
72 |
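# (Editor's sketch, not part of utils.py: the dict shape write_generation_log expects,
#  inferred from the writes above; the key names match the code, all values are made up.)
#
# example_log = {
#     "prompt": {"text": "What is 2 + 3?", "answer": "5"},
#     "generations": [{
#         "response": "<reasoning>2 + 3 = 5</reasoning>\n<answer>5</answer>",
#         "scores": {"correctness": 2.0, "integer_format": 0.5, "strict_format": 0.5,
#                    "soft_format": 0.5, "xml_count": 0.5, "total_reward": 4.0},
#     }],
# }
# write_generation_log(example_log, "generation_log.txt")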
73 |
74 | ####################################################################################
75 | ## Copied Directly from TRL -> generate log probs per token ########
76 | ## https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py ########
77 | ####################################################################################
78 |
79 | def selective_log_softmax(logits, index):
80 | """
81 | A memory-efficient implementation of the common `log_softmax -> gather` operation.
82 |
83 | This function is equivalent to the following naive implementation:
84 | ```python
85 | logps = torch.gather(logits.log_softmax(-1), dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
86 | ```
87 |
88 | Args:
89 | logits (`torch.Tensor`):
90 | Logits tensor of shape `(..., num_classes)`.
91 | index (`torch.Tensor`):
92 | Index tensor of shape `(...)`, specifying the positions to gather from the log-softmax output.
93 |
94 | Returns:
95 | `torch.Tensor`:
96 | Gathered log probabilities with the same shape as `index`.
97 | """
98 | if logits.dtype in [torch.float32, torch.float64]:
99 | selected_logits = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
100 | # loop to reduce peak mem consumption
101 | logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
102 | per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
103 | else:
104 | # logsumexp approach is unstable with bfloat16, fall back to a slightly less efficient approach
105 | per_token_logps = []
106 | for row_logits, row_labels in zip(logits, index): # loop to reduce peak mem consumption
107 | row_logps = F.log_softmax(row_logits, dim=-1)
108 | row_per_token_logps = row_logps.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1)
109 | per_token_logps.append(row_per_token_logps)
110 | per_token_logps = torch.stack(per_token_logps)
111 | return per_token_logps
112 |
113 | def get_per_token_logps(model, input_ids, attention_mask, logits_to_keep):
114 | # We add 1 to `logits_to_keep` because the last logit of the sequence is later excluded
115 | logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
116 | logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
117 |
118 | input_ids = input_ids[:, -logits_to_keep:]
119 | # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
120 | # See https://github.com/huggingface/trl/issues/2770
121 | logits = logits[:, -logits_to_keep:]
122 | return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens
123 |
--------------------------------------------------------------------------------
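
Usage sketch (editor's addition): a quick sanity check that `selective_log_softmax` matches the naive `log_softmax -> gather` path described in its docstring, followed by one plausible way to call `get_per_token_logps`. The second half assumes a causal LM whose forward accepts `logits_to_keep` (recent transformers releases) and uses a placeholder checkpoint name, not necessarily the model this repo trains.

```python
# Hedged usage sketch for utils.py (illustrative only, not part of the repo).
import torch
import torch.nn.functional as F
from utils import seed_everything, selective_log_softmax, get_per_token_logps

seed_everything(0)

# 1) selective_log_softmax should agree with the naive log_softmax -> gather path.
logits = torch.randn(2, 5, 11)          # (batch, seq, vocab)
index = torch.randint(0, 11, (2, 5))    # token ids to score
naive = torch.gather(F.log_softmax(logits, dim=-1), dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
assert torch.allclose(naive, selective_log_softmax(logits, index), atol=1e-5)

# 2) Per-token log-probs for the last few tokens of a sequence (the "completion").
#    Assumes a transformers version that supports `logits_to_keep`; the checkpoint is a placeholder.
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

enc = tok(["The final answer is 42"], return_tensors="pt")
completion_len = 3  # score log-probs for the last 3 tokens only
per_token_logps = get_per_token_logps(model, enc["input_ids"], enc["attention_mask"], completion_len)
print(per_token_logps.shape)  # torch.Size([1, 3])
```

In GRPO these per-token log-probs are compared against a reference policy's log-probs for the KL term and against the old policy's log-probs for the importance ratio; that wiring presumably lives in main.py and is not reproduced here.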