├── .env.example ├── .gitignore ├── .python-version ├── LICENSE ├── README.md ├── data └── puzzles.json ├── lib ├── __init__.py ├── chat_completions.py ├── grpo.py ├── inference_early_stop.py ├── models.py ├── pack.py ├── recipe.py ├── stream.py ├── tasks.py ├── temporal_clue.py ├── tokenize.py ├── tqdm.py ├── tune.py ├── types.py ├── utils.py └── vllm.py ├── pyproject.toml ├── train.ipynb ├── train.py └── uv.lock /.env.example: -------------------------------------------------------------------------------- 1 | # Your Weights & Biases API key for experiment tracking 2 | WANDB_API_KEY=your_wandb_api_key_here 3 | 4 | # The name of your Weights & Biases project 5 | WANDB_PROJECT=your_project_name -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .env 2 | .venv 3 | __pycache__/ 4 | /logs 5 | /models 6 | /wandb 7 | -------------------------------------------------------------------------------- /.python-version: -------------------------------------------------------------------------------- 1 | 3.12 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 OpenPipe inc. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deductive Reasoning 2 | 3 | ![image/png](https://cdn-uploads.huggingface.co/production/uploads/674a1d102c0f27a385772cfe/JauBmEQM0FpOdShBMSfst.png) 4 | 5 | Train your own frontier-level deductive reasoning model with reinforcement learning. 6 | 7 | ## Overview 8 | 9 | This repository contains the training recipe for creating frontier-level deductive reasoning models using reinforcement learning. Our research demonstrates how smaller, open-weight language models can be trained to perform complex logical deduction tasks at frontier-level performance, matching or exceeding proprietary models at a fraction of the cost. 10 | 11 | We used the Temporal Clue puzzle dataset to train Qwen 14B and 32B models, improving their deductive reasoning capabilities significantly through Group Relative Policy Optimization (GRPO). 
Our trained models approach the performance of leading proprietary models like Claude 3.7 Sonnet while maintaining cost efficiency. 12 | 13 | ## Resources 14 | 15 | - **Training Recipe**: This repository (recipe for RL training) 16 | - **Training Dataset**: [Temporal Clue Puzzles](https://github.com/bradhilton/temporal-clue) 17 | - **RL Experiments**: [OpenPipe RL Experiments](https://github.com/openpipe/rl-experiments) 18 | - **Model Weights**: 19 | - [Deductive Reasoning Qwen 14B](https://huggingface.co/OpenPipe/Deductive-Reasoning-Qwen-14B) 20 | - [Deductive Reasoning Qwen 32B](https://huggingface.co/OpenPipe/Deductive-Reasoning-Qwen-32B) 21 | - **Blog Post**: [Using GRPO to Beat o1, o3-mini and R1 at "Temporal Clue"](https://openpipe.ai/blog/using-grpo-to-beat-o1-o3-mini-and-r1-on-temporal-clue) 22 | 23 | ## Getting Started 24 | 25 | Follow these steps to run the training recipe: 26 | 27 | ### Prerequisites 28 | 29 | - Sufficient NVIDIA GPUs for your chosen model: 30 | - Qwen 14B requires at least 2 GPUs 31 | - Qwen 32B requires at least 4 GPUs 32 | - [uv](https://github.com/astral-sh/uv) package manager 33 | - [Weights & Biases](https://wandb.ai) account 34 | 35 | ### Installation 36 | 37 | 1. Clone this repository: 38 | 39 | ```bash 40 | git clone https://github.com/bradhilton/deductive-reasoning.git 41 | cd deductive-reasoning 42 | ``` 43 | 44 | 2. Install dependencies using uv: 45 | 46 | ```bash 47 | uv sync 48 | ``` 49 | 50 | 3. Reinstall torchtune due to an executable naming conflict with Ray Tune 🙈 51 | 52 | ```bash 53 | uv remove torchtune 54 | uv add torchtune 55 | ``` 56 | 57 | 4. (Optional) Configure environment variables: 58 | 59 | ```bash 60 | cp .env.example .env 61 | ``` 62 | 63 | Edit the `.env` file to add your Weights & Biases API key and project name: 64 | 65 | ``` 66 | WANDB_API_KEY=your_wandb_api_key_here 67 | WANDB_PROJECT=your_project_name 68 | ``` 69 | 70 | ### Running the Training 71 | 72 | 1. Open the `train.ipynb` notebook or `train.py` script and configure the training parameters: 73 | 74 | - Set a unique `run_name` for your experiment 75 | - Choose the model (e.g., `models.qwen_14b()` or `models.qwen_32b()`) 76 | - Adjust other parameters as needed (learning rate, number of iterations, etc.) 77 | 78 | 2. Run the training: 79 | 80 | - If using the notebook: Execute all cells in `train.ipynb` 81 | - If using the script: Run `uv run train.py` 82 | 83 | 3. Monitor training progress in Weights & Biases. 84 | 85 | The training process will save the latest and/or best checkpoints in your output directory, allowing you to resume training if interrupted. 86 | 87 | ## Methodology 88 | 89 | Our training approach used reinforcement learning to incrementally improve models' deductive reasoning capabilities: 90 | 91 | 1. **Environment**: Temporal Clue puzzles (inspired by the board game Clue/Cluedo) with verifiable solutions 92 | 2. **Algorithm**: Group Relative Policy Optimization (GRPO) without KL divergence penalty 93 | 3. 
**Training Loop**: 94 | - Generate model responses to puzzle tasks 95 | - Grade responses and estimate advantages for each group of completions 96 | - Fine-tune the model using clipped policy gradients 97 | - Repeat with new puzzles until peak performance 98 | 99 | We used the torchtune library for efficient training and vLLM for inference, with the following key parameters: 100 | 101 | - Models: Qwen 2.5 Instruct 14B & 32B 102 | - Tasks per Iteration: 32 103 | - Samples per Task per Iteration: 50 104 | - Learning Rate: 6e-6 105 | 106 | ## Results 107 | 108 | Our training produced impressive performance gains, demonstrating that open-weight models can achieve frontier-level reasoning capabilities. 109 | 110 | ![image](https://github.com/user-attachments/assets/c405846e-3f19-4b0e-a4ac-02f16c015c54) 111 | 112 | We dramatically improved the cost-accuracy tradeoff compared to proprietary models: 113 | 114 | ![image](https://github.com/user-attachments/assets/5889e53e-7d11-4742-900d-5386aadc1983) 115 | 116 | Notably, we discovered that meaningful performance improvements (10-15%) can be achieved with as few as 16 training examples, making this approach accessible even with limited data. 117 | 118 | ## License 119 | 120 | This training recipe is freely available under the MIT license. 121 | -------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- 1 | from dotenv import load_dotenv 2 | 3 | load_dotenv() 4 | -------------------------------------------------------------------------------- /lib/chat_completions.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from datetime import datetime 3 | from openai import AsyncOpenAI 4 | from openai.types.chat.chat_completion import ChatCompletion 5 | from openai.types.chat.chat_completion_chunk import ChatCompletionChunk 6 | import os 7 | from typing import Callable, Unpack 8 | 9 | 10 | from .stream import consume_chat_completion_stream 11 | from .types import CreateParams 12 | from .utils import timeout 13 | 14 | MAX_INT = 2**31 - 1 15 | unlimited_semaphore = asyncio.Semaphore(MAX_INT) 16 | 17 | 18 | async def get_chat_completion( 19 | client: AsyncOpenAI, 20 | log_dir: str | None = None, 21 | log_results: bool = True, 22 | on_chunk: Callable[[ChatCompletionChunk, ChatCompletion], None] | None = None, 23 | semaphore: asyncio.Semaphore | None = None, 24 | **create_params: Unpack[CreateParams], 25 | ) -> ChatCompletion: 26 | """ 27 | Given a client and arguments to openai.chat.completions.create, this function will return a chat completion with some additional features: 28 | - Logging of results to a file 29 | - Streaming of results to the callback function 30 | - Support for capping concurrent requests with a semaphore 31 | 32 | Args: 33 | client (AsyncOpenAI): An AsyncOpenAI client 34 | log_dir (str | None): The directory to log the results of the chat completion 35 | log_results (bool): Whether to log the results of the chat completion 36 | on_chunk (Callable[[ChatCompletionChunk, ChatCompletion], None]): A callback function that will be called with each chunk of the chat completion 37 | semaphore (asyncio.Semaphore): A semaphore to limit the number of concurrent requests 38 | 39 | Returns: 40 | ChatCompletion: A chat completion 41 | """ 42 | async with semaphore or unlimited_semaphore: 43 | on_chunk = _create_on_chunk_callback( 44 | create_params, log_dir, log_results, on_chunk 
45 | ) 46 | if on_chunk: 47 | return await consume_chat_completion_stream( 48 | await client.chat.completions.create(**create_params, stream=True), 49 | on_chunk=on_chunk, 50 | ) 51 | else: 52 | return await client.chat.completions.create(**create_params) 53 | 54 | 55 | def _create_on_chunk_callback( 56 | create_params: CreateParams, 57 | log_dir: str | None, 58 | log_results: bool, 59 | on_chunk: Callable[[ChatCompletionChunk, ChatCompletion], None] | None = None, 60 | ) -> Callable[[ChatCompletionChunk, ChatCompletion], None] | None: 61 | """Create a callback function for handling streaming chunks. 62 | 63 | This function sets up logging and wraps the user's callback function 64 | if provided, or returns None if no logging or callbacks are needed. 65 | 66 | Args: 67 | create_params: Parameters for the completion (used for conversation history) 68 | log_dir: Optional custom directory for logging 69 | log_results: Whether to log results 70 | on_chunk: Optional user callback for streaming 71 | 72 | Returns: 73 | A callback function or None if no callback is needed 74 | """ 75 | if not (log_results or on_chunk): 76 | return None 77 | 78 | # Set up logging 79 | log_dir = log_dir or "./logs/chat-completions" 80 | os.makedirs(log_dir, exist_ok=True) 81 | log_file = os.path.join(log_dir, f"{datetime.now().isoformat()}.log") 82 | 83 | # Write conversation history to the log file 84 | if log_results: 85 | with open(log_file, "w") as f: 86 | f.write( 87 | "".join( 88 | f"{message['role'].capitalize()}:\n{message.get('content', '')}\n\n" 89 | for message in create_params["messages"] 90 | ) 91 | + "Assistant:\n" 92 | ) 93 | 94 | # Create a callback function that handles both user callbacks and logging 95 | def callback(chunk: ChatCompletionChunk, completion: ChatCompletion) -> None: 96 | # Call user's callback if provided 97 | if on_chunk: 98 | on_chunk(chunk, completion) 99 | 100 | # Log chunk content if enabled 101 | if log_results and chunk.choices: 102 | try: 103 | with timeout(): 104 | with open(log_file, "a") as f: 105 | f.write(chunk.choices[0].delta.content or "") 106 | except TimeoutError: 107 | pass # Skip writing this chunk if it times out 108 | 109 | return callback 110 | -------------------------------------------------------------------------------- /lib/grpo.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field, fields 2 | import torch 3 | from typing import Iterable, Optional, Union 4 | 5 | ignore_labels_cache: dict[ 6 | tuple[torch.Size, Union[int, float], torch.dtype, torch.device], torch.Tensor 7 | ] = {} 8 | 9 | 10 | def shift_tensor( 11 | labels: torch.Tensor, ignore_label: Optional[Union[int, float]] = None 12 | ) -> torch.Tensor: 13 | if ignore_label is None: 14 | ignore_label = ( 15 | -100 16 | if labels.dtype in (torch.int32, torch.int64, torch.int16, torch.int8) 17 | else float("nan") 18 | ) 19 | 20 | # Create a tensor of ignore labels every time if we are compiling, otherwise cache it 21 | if torch.compiler.is_compiling(): 22 | ignore_labels = torch.full( 23 | (labels.shape[0], 1), ignore_label, dtype=labels.dtype, device=labels.device 24 | ) 25 | else: 26 | key = (labels.shape[-1:], ignore_label, labels.dtype, labels.device) 27 | if key not in ignore_labels_cache: 28 | ignore_labels_cache[key] = torch.full( 29 | (labels.shape[0], 1), 30 | ignore_label, 31 | dtype=labels.dtype, 32 | device=labels.device, 33 | ) 34 | ignore_labels = ignore_labels_cache[key] 35 | 36 | # Shift labels to compute loss 37 
| return torch.cat((labels[..., 1:], ignore_labels), dim=1) 38 | 39 | 40 | @dataclass 41 | class GRPOResult: 42 | num_tokens: torch.Tensor = field(default_factory=lambda: torch.tensor(0)) 43 | policy_loss: torch.Tensor = field(default_factory=lambda: torch.tensor(0.0)) 44 | entropy: torch.Tensor = field(default_factory=lambda: torch.tensor(0.0)) 45 | kl_div: torch.Tensor = field(default_factory=lambda: torch.tensor(0.0)) 46 | entropy_weight: torch.Tensor = field(default_factory=lambda: torch.tensor(0.0)) 47 | kl_weight: torch.Tensor = field(default_factory=lambda: torch.tensor(0.0)) 48 | 49 | def named_tensors(self) -> Iterable[tuple[str, torch.Tensor]]: 50 | for field in fields(self): 51 | yield field.name, getattr(self, field.name) 52 | 53 | def per_token(self) -> "GRPOResult": 54 | return GRPOResult( 55 | **{name: tensor / self.num_tokens for name, tensor in self.named_tensors()} 56 | ) 57 | 58 | def tensors(self) -> Iterable[torch.Tensor]: 59 | return (tensor for _, tensor in self.named_tensors()) 60 | 61 | def to(self, target: Union[torch.device, torch.dtype]) -> "GRPOResult": 62 | return GRPOResult( 63 | **{name: tensor.to(target) for name, tensor in self.named_tensors()} 64 | ) 65 | 66 | def __iadd__(self, other: "GRPOResult") -> "GRPOResult": 67 | for tensor, other_tensor in zip(self.tensors(), other.tensors()): 68 | tensor += other_tensor.to(tensor.device) 69 | return self 70 | 71 | @property 72 | def total_loss(self) -> torch.Tensor: 73 | return ( 74 | self.policy_loss 75 | - self.entropy * self.entropy_weight / self.num_tokens 76 | + torch.nan_to_num(self.kl_div, 0.0) * self.kl_weight / self.num_tokens 77 | ) 78 | 79 | 80 | class GRPO(torch.nn.Module): 81 | def __init__( 82 | self, clip_epsilon: float = 0.2, entropy_coef: float = 0.0, kl_coef: float = 0.0 83 | ) -> None: 84 | """ 85 | Initialize the GRPO loss. 86 | 87 | Args: 88 | clip_epsilon (float): The epsilon value for clipping the policy ratio. 89 | entropy_coef (float): The coefficient for the entropy bonus. 90 | kl_coef (float): The coefficient for the KL divergence penalty. 91 | """ 92 | super().__init__() 93 | self.clip_epsilon = clip_epsilon 94 | self.entropy_coef = entropy_coef 95 | self.kl_coef = kl_coef 96 | 97 | def forward( 98 | self, 99 | logits: Union[torch.Tensor, list[torch.Tensor]], 100 | tokens: torch.Tensor, 101 | advantages: torch.Tensor, 102 | logprobs: torch.Tensor, 103 | reference_logprobs: torch.Tensor | None, 104 | mask: torch.Tensor, 105 | weights: torch.Tensor, 106 | bos_id: int, 107 | ) -> GRPOResult: 108 | """ 109 | Computes the GRPO loss for sequence data, supporting both regular and chunked inputs. 110 | 111 | Args: 112 | logits (Union[Tensor, List[Tensor]]): 113 | Either a single tensor of shape (batch_size, sequence_length, vocab_size) 114 | or a list of chunked tensors, each of shape 115 | (batch_size, sequence_length/num_chunks, vocab_size). 116 | tokens (Tensor): 117 | Shape: (batch_size, sequence_length) 118 | Token indices. 119 | advantages (Tensor): 120 | Shape: (batch_size, sequence_length) 121 | Token advantages. 122 | logprobs (Tensor): 123 | Shape: (batch_size, sequence_length) 124 | Token log probabilities. 125 | reference_logprobs (Tensor | None): 126 | Shape: (batch_size, sequence_length) 127 | Reference token log probabilities. 128 | mask (Tensor): 129 | Shape: (batch_size, sequence_length) 130 | Boolean mask specifying positions where loss should be computed. 131 | weights (Tensor): 132 | Shape: (batch_size, sequence_length) 133 | Weights for each token. 
134 | bos_id (int): 135 | Index of the beginning of sequence token in the vocabulary. 136 | 137 | Returns: 138 | GRPOResult: The combined loss results across all chunks. 139 | """ 140 | if isinstance(logits, list): 141 | result = GRPOResult().to(logits[0].device) 142 | num_chunks = len(logits) 143 | for chunked_args in zip( 144 | logits, 145 | tokens.chunk(num_chunks, dim=1), 146 | advantages.chunk(num_chunks, dim=1), 147 | logprobs.chunk(num_chunks, dim=1), 148 | ( 149 | reference_logprobs.chunk(num_chunks, dim=1) 150 | if reference_logprobs is not None 151 | else [None] * num_chunks 152 | ), 153 | mask.chunk(num_chunks, dim=1), 154 | weights.chunk(num_chunks, dim=1), 155 | ): 156 | result += self._forward_chunk(*chunked_args, bos_id=bos_id) 157 | return result 158 | 159 | return self._forward_chunk( 160 | logits, 161 | tokens, 162 | advantages, 163 | logprobs, 164 | reference_logprobs, 165 | mask, 166 | weights, 167 | bos_id, 168 | ) 169 | 170 | def _forward_chunk( 171 | self, 172 | logits: torch.Tensor, 173 | tokens: torch.Tensor, 174 | advantages: torch.Tensor, 175 | logprobs: torch.Tensor, 176 | reference_logprobs: torch.Tensor | None, 177 | mask: torch.Tensor, 178 | weights: torch.Tensor, 179 | bos_id: int, 180 | ) -> GRPOResult: 181 | """ 182 | Processes a single chunk of the GRPO loss computation. 183 | """ 184 | # Flatten logits tensor to shape (batch_size * sequence_length, vocab_size) 185 | logits = logits.view(-1, logits.size(-1)) 186 | tokens = shift_tensor(tokens, bos_id).view( 187 | -1 188 | ) # (batch_size * sequence_length,) 189 | advantages = shift_tensor(advantages, 0).view( 190 | -1 191 | ) # (batch_size * sequence_length,) 192 | logprobs = shift_tensor(logprobs, 0).view(-1) # (batch_size * sequence_length,) 193 | if reference_logprobs is not None: 194 | reference_logprobs = shift_tensor(reference_logprobs, 0).view(-1) 195 | mask = shift_tensor(mask, False).view(-1) # (batch_size * sequence_length,) 196 | weights = shift_tensor(weights, 0).view(-1) # (batch_size * sequence_length,) 197 | num_tokens = mask.sum() 198 | dist = torch.distributions.Categorical(logits=logits) 199 | entropy = dist.entropy()[mask] 200 | new_logprobs = dist.log_prob(tokens)[mask] 201 | logprobs = logprobs[mask] 202 | logprobs = torch.where(torch.isnan(logprobs), new_logprobs, logprobs) 203 | if reference_logprobs is not None: 204 | reference_logprobs = reference_logprobs[mask] 205 | advantages = advantages[mask] 206 | diff = new_logprobs - logprobs 207 | prob_ratio = torch.exp(diff) 208 | policy_loss = -torch.min( 209 | prob_ratio * advantages, 210 | torch.clip(prob_ratio, 1 - self.clip_epsilon, 1 + self.clip_epsilon) 211 | * advantages, 212 | ) 213 | if reference_logprobs is not None: 214 | kl_div = torch.nn.functional.kl_div( 215 | new_logprobs, 216 | reference_logprobs, 217 | reduction="none", 218 | log_target=True, 219 | ) 220 | else: 221 | kl_div = torch.tensor(torch.nan, device=logits.device) 222 | weights = weights[mask] 223 | return GRPOResult( 224 | num_tokens=num_tokens, 225 | policy_loss=policy_loss.mul(weights).sum(), 226 | entropy=entropy.mul(weights).sum(), 227 | kl_div=kl_div.mul(weights).sum(), 228 | entropy_weight=self.entropy_coef * num_tokens, 229 | kl_weight=self.kl_coef * num_tokens, 230 | ) 231 | -------------------------------------------------------------------------------- /lib/inference_early_stop.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | import math 3 | from 
openai.types.chat.chat_completion import ChatCompletion 4 | from openai.types.chat.chat_completion_chunk import ChatCompletionChunk 5 | 6 | 7 | @dataclass 8 | class InferenceEarlyStop: 9 | """ 10 | Utility for stopping inference early if token log probabilities are too low. 11 | 12 | Args: 13 | alpha: The smoothing factor for the exponential weighted moving average. 14 | threshold: The log probability threshold to stop inference below. 15 | log_early_stops: Whether to log early stops. 16 | log_last_n_characters: The number of characters to log from the end of the stopped completion. 17 | """ 18 | 19 | alpha: float = 0.992 20 | threshold: float = -3 21 | log_early_stops: bool = False 22 | log_last_n_characters: int = 64 23 | ewm_logprobs: dict[str, float] = field(default_factory=dict) 24 | 25 | def __call__(self, chunk: ChatCompletionChunk, completion: ChatCompletion) -> None: 26 | # TODO: handle multiple choices and refusal logprobs 27 | if ( 28 | not chunk.choices 29 | or not chunk.choices[0].logprobs 30 | or not chunk.choices[0].logprobs.content 31 | ): 32 | return 33 | for token_logprob in chunk.choices[0].logprobs.content: 34 | if token_logprob.logprob is None or math.isnan(token_logprob.logprob): 35 | raise StopIteration() 36 | ewm_logprob = ( 37 | self.alpha * self.ewm_logprobs.get(completion.id, 0) 38 | + (1 - self.alpha) * token_logprob.logprob 39 | ) 40 | if ewm_logprob < self.threshold: 41 | if self.log_early_stops: 42 | print( 43 | f"Early stopping - ewm_logprob: {ewm_logprob} completion_tokens: {len(completion.choices[0].logprobs.content)}" # type: ignore 44 | ) 45 | if self.log_last_n_characters: 46 | print( 47 | f"Last {self.log_last_n_characters} characters: {completion.choices[0].message.content[-self.log_last_n_characters :]}" # type: ignore 48 | ) 49 | setattr(completion.choices[0], "early_stop", True) 50 | raise StopIteration() 51 | self.ewm_logprobs[completion.id] = ewm_logprob 52 | -------------------------------------------------------------------------------- /lib/models.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import torch 3 | from torchtune.models.qwen2_5 import ( 4 | qwen2_5_7b_base, 5 | qwen2_5_14b_base, 6 | qwen2_5_14b_instruct, 7 | qwen2_5_32b_base, 8 | qwen2_5_32b_instruct, 9 | qwen2_5_72b_instruct, 10 | ) 11 | from torchtune.models.llama3_1 import llama3_1_8b, llama3_1_70b 12 | from torchtune.modules import TransformerDecoder 13 | from typing import Callable 14 | 15 | 16 | @dataclass 17 | class Model: 18 | """Basic language model configuration""" 19 | 20 | base_model: str 21 | min_gpus: int 22 | tune_model_type: str 23 | tune_model: Callable[[], TransformerDecoder] 24 | tune_num_output_chunks: int 25 | 26 | def __post_init__(self) -> None: 27 | assert ( 28 | torch.cuda.device_count() >= self.min_gpus 29 | ), f"{self.base_model} requires at least {self.min_gpus} GPUs" 30 | 31 | 32 | def distilled_qwen_7b() -> Model: 33 | """deepseek-ai/DeepSeek-R1-Distill-Qwen-7B model config.""" 34 | return Model( 35 | base_model="deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", 36 | min_gpus=1, 37 | tune_model_type="QWEN2", 38 | tune_model=qwen2_5_7b_base, 39 | tune_num_output_chunks=8, 40 | ) 41 | 42 | 43 | def theta_8b() -> Model: 44 | """NousResearch/Hermes-2-Theta-Llama-3-8B model config.""" 45 | return Model( 46 | base_model="NousResearch/Hermes-2-Theta-Llama-3-8B", 47 | min_gpus=1, 48 | tune_model_type="LLAMA3", 49 | tune_model=llama3_1_8b, 50 | tune_num_output_chunks=8, 51 | ) 52 | 53 | 54 
| def qwen_14b() -> Model: 55 | """Qwen/Qwen2.5-14B-Instruct model config.""" 56 | return Model( 57 | base_model="Qwen/Qwen2.5-14B-Instruct", 58 | min_gpus=2, 59 | tune_model_type="QWEN2", 60 | tune_model=qwen2_5_14b_instruct, 61 | tune_num_output_chunks=2, 62 | ) 63 | 64 | 65 | def distilled_qwen_14b() -> Model: 66 | """deepseek-ai/DeepSeek-R1-Distill-Qwen-14B model config.""" 67 | return Model( 68 | base_model="deepseek-ai/DeepSeek-R1-Distill-Qwen-14B", 69 | min_gpus=2, 70 | tune_model_type="QWEN2", 71 | tune_model=qwen2_5_14b_base, 72 | tune_num_output_chunks=2, 73 | ) 74 | 75 | 76 | def qwen_32b() -> Model: 77 | """Qwen/Qwen2.5-32B-Instruct model config.""" 78 | return Model( 79 | base_model="Qwen/Qwen2.5-32B-Instruct", 80 | min_gpus=4, 81 | tune_model_type="QWEN2", 82 | tune_model=qwen2_5_32b_instruct, 83 | tune_num_output_chunks=2, 84 | ) 85 | 86 | 87 | def distilled_qwen_32b() -> Model: 88 | """deepseek-ai/DeepSeek-R1-Distill-Qwen-32B model config.""" 89 | return Model( 90 | base_model="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", 91 | min_gpus=4, 92 | tune_model_type="QWEN2", 93 | tune_model=qwen2_5_32b_base, 94 | tune_num_output_chunks=2, 95 | ) 96 | 97 | 98 | def llama_70b() -> Model: 99 | """unsloth/Llama-3.3-70B-Instruct model config.""" 100 | return Model( 101 | base_model="unsloth/Llama-3.3-70B-Instruct", 102 | min_gpus=8, 103 | tune_model_type="LLAMA3", 104 | tune_model=llama3_1_70b, 105 | tune_num_output_chunks=2, 106 | ) 107 | 108 | 109 | def distilled_llama_70b() -> Model: 110 | """deepseek-ai/DeepSeek-R1-Distill-Llama-70B model config.""" 111 | return Model( 112 | base_model="deepseek-ai/DeepSeek-R1-Distill-Llama-70B", 113 | min_gpus=8, 114 | tune_model_type="LLAMA3", 115 | tune_model=llama3_1_70b, 116 | tune_num_output_chunks=8, 117 | ) 118 | 119 | 120 | def qwen_72b() -> Model: 121 | """Qwen/Qwen2.5-72B-Instruct model config.""" 122 | return Model( 123 | base_model="Qwen/Qwen2.5-72B-Instruct", 124 | min_gpus=8, 125 | tune_model_type="QWEN2", 126 | tune_model=qwen2_5_72b_instruct, 127 | tune_num_output_chunks=2, 128 | ) 129 | -------------------------------------------------------------------------------- /lib/pack.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import os 3 | import random 4 | import seaborn as sns 5 | import torch 6 | from torch.utils.data import Dataset 7 | from typing import TypedDict, Unpack 8 | 9 | from .tokenize import TokenizedResult 10 | 11 | 12 | class PackedTensors(TypedDict): 13 | tokens: torch.Tensor 14 | group_ids: torch.Tensor 15 | parent_ids: torch.Tensor 16 | input_pos: torch.Tensor 17 | assistant_mask: torch.Tensor 18 | logprobs: torch.Tensor 19 | advantages: torch.Tensor 20 | weights: torch.Tensor 21 | 22 | 23 | class DiskPackedTensors(TypedDict): 24 | dir: str 25 | num_sequences: int 26 | sequence_length: int 27 | 28 | 29 | class PackedDataset(Dataset[PackedTensors]): 30 | def __init__(self, **kwargs: Unpack[DiskPackedTensors]) -> None: 31 | self.tensors = packed_tensors_from_dir(**kwargs) 32 | 33 | def __len__(self) -> int: 34 | return self.tensors["tokens"].shape[0] 35 | 36 | def __getitem__(self, index: int) -> PackedTensors: 37 | return {key: tensor[index] for key, tensor in self.tensors.items()} # type: ignore 38 | 39 | 40 | def packed_tensors_from_tokenized_results( 41 | tokenized_results: list[TokenizedResult], 42 | seq_len: int, 43 | pad_token_id: int = -100, 44 | truncate_long_results: bool = True, 45 | ) -> PackedTensors: 46 | token_ids: 
list[list[int]] = [[]] 47 | group_ids: list[list[int]] = [[]] 48 | parent_ids: list[list[int]] = [[]] 49 | input_pos: list[list[int]] = [[]] 50 | assistant_mask: list[list[int]] = [[]] 51 | logprobs: list[list[float]] = [[]] 52 | advantages: list[list[float]] = [[]] 53 | weights: list[list[float]] = [[]] 54 | 55 | for result in tokenized_results: 56 | if len(result.token_ids) > seq_len and not truncate_long_results: 57 | print("Result is too long, skipping") 58 | continue 59 | result_without_prompt = result.without_prompt() 60 | if sum(result.assistant_mask) == 0: 61 | print("Result has no unique completion tokens, skipping") 62 | continue 63 | if ( 64 | len(token_ids[-1]) 65 | + ( 66 | len(result_without_prompt.token_ids) 67 | if result.prompt_id in group_ids[-1] 68 | else len(result.token_ids) 69 | ) 70 | > seq_len 71 | ): 72 | token_ids.append([]) 73 | group_ids.append([]) 74 | parent_ids.append([]) 75 | input_pos.append([]) 76 | assistant_mask.append([]) 77 | logprobs.append([]) 78 | advantages.append([]) 79 | weights.append([]) 80 | group_id = random.randint(-(2**63), 2**63 - 1) 81 | if result.prompt_id in group_ids[-1]: 82 | result = result_without_prompt 83 | token_ids[-1].extend(result.token_ids) 84 | group_ids[-1].extend( 85 | [result.prompt_id] * result.prompt_length 86 | + [group_id] * (len(result.token_ids) - result.prompt_length) 87 | ) 88 | parent_ids[-1].extend([result.prompt_id] * len(result.token_ids)) 89 | input_pos[-1].extend(result.input_pos) 90 | assistant_mask[-1].extend(result.assistant_mask) 91 | offset = len(logprobs[-1]) 92 | logprobs[-1].extend([float("nan")] * len(result.token_ids)) 93 | if result.token_logprobs: 94 | assistant_indices = [ 95 | i for i, mask in enumerate(result.assistant_mask) if mask 96 | ] 97 | assert len(assistant_indices) <= len(result.token_logprobs) 98 | for idx, token_logprob in zip( 99 | assistant_indices, 100 | result.token_logprobs[ 101 | len(result.token_logprobs) - len(assistant_indices) : 102 | ], 103 | ): 104 | logprobs[-1][idx + offset] = token_logprob.logprob or float("nan") 105 | advantages[-1].extend([result.advantage] * len(result.token_ids)) 106 | # prevent the model unlearning when to stop 107 | # advantages[-1][-1] = max(0, advantages[-1][-1]) 108 | advantages[-1][-1] = 0 109 | weights[-1].extend([1 / sum(result.assistant_mask)] * len(result.token_ids)) 110 | if truncate_long_results: 111 | token_ids[-1] = token_ids[-1][:seq_len] 112 | group_ids[-1] = group_ids[-1][:seq_len] 113 | parent_ids[-1] = parent_ids[-1][:seq_len] 114 | input_pos[-1] = input_pos[-1][:seq_len] 115 | assistant_mask[-1] = assistant_mask[-1][:seq_len] 116 | logprobs[-1] = logprobs[-1][:seq_len] 117 | advantages[-1] = advantages[-1][:seq_len] 118 | weights[-1] = weights[-1][:seq_len] 119 | 120 | def pad(values: list[list], pad_value) -> list[list]: 121 | max_len = seq_len 122 | for value in values: 123 | value.extend([pad_value] * (max_len - len(value))) 124 | return values 125 | 126 | assistant_mask_tensor = torch.tensor(pad(assistant_mask, 0), dtype=torch.bool) 127 | weights_tensor = torch.tensor(pad(weights, 0.0)) 128 | weights_tensor = torch.where( 129 | assistant_mask_tensor, weights_tensor, torch.zeros_like(weights_tensor) 130 | ) 131 | weights_tensor[assistant_mask_tensor] /= weights_tensor[ 132 | assistant_mask_tensor 133 | ].mean() 134 | 135 | return { 136 | "tokens": torch.tensor(pad(token_ids, pad_token_id)), 137 | "group_ids": torch.tensor(pad(group_ids, -1)), 138 | "parent_ids": torch.tensor(pad(parent_ids, -1)), 139 | "input_pos": 
torch.tensor(pad(input_pos, 0)), 140 | "assistant_mask": assistant_mask_tensor, 141 | "logprobs": torch.tensor(pad(logprobs, float("nan"))), 142 | "advantages": torch.tensor(pad(advantages, 0.0)), 143 | "weights": weights_tensor, 144 | } 145 | 146 | 147 | def packed_tensors_from_dir(**kwargs: Unpack[DiskPackedTensors]) -> PackedTensors: 148 | os.makedirs(kwargs["dir"], exist_ok=True) 149 | return { 150 | key: torch.from_file( 151 | f"{kwargs["dir"]}/{key}.pt", 152 | shared=True, 153 | size=kwargs["num_sequences"] 154 | * kwargs["sequence_length"] 155 | * (kwargs["sequence_length"] if key == "mask" else 1), 156 | dtype=dtype, 157 | ) 158 | .view(kwargs["num_sequences"], kwargs["sequence_length"], -1) 159 | .squeeze() 160 | for key, dtype in { 161 | "tokens": torch.long, 162 | "group_ids": torch.long, 163 | "parent_ids": torch.long, 164 | "input_pos": torch.long, 165 | "assistant_mask": torch.bool, 166 | "logprobs": torch.float32, 167 | "advantages": torch.float32, 168 | "weights": torch.float32, 169 | }.items() 170 | } # type: ignore 171 | 172 | 173 | def packed_tensors_to_dir(tensors: PackedTensors, dir: str) -> DiskPackedTensors: 174 | os.makedirs(dir, exist_ok=True) 175 | disk_packed_tensors: DiskPackedTensors = { 176 | "dir": dir, 177 | "num_sequences": tensors["tokens"].shape[0], 178 | "sequence_length": tensors["tokens"].shape[1], 179 | } 180 | for key, tensor in packed_tensors_from_dir(**disk_packed_tensors).items(): 181 | tensor.copy_(tensors[key]) # type: ignore 182 | return disk_packed_tensors 183 | 184 | 185 | def plot_packed_tensors(packed_tensors: PackedTensors) -> None: 186 | plt.figure(figsize=(15, 24)) 187 | 188 | for tensor, label, title, subplot_idx in ( 189 | (packed_tensors["tokens"], "Token IDs", "Token IDs", 1), 190 | (packed_tensors["logprobs"], "Log Probabilities", "Token Log Probs", 2), 191 | (packed_tensors["group_ids"], "Group IDs", "Token Groups", 3), 192 | (packed_tensors["parent_ids"], "Parent IDs", "Parent IDs", 4), 193 | (packed_tensors["input_pos"], "Position", "Input Position", 5), 194 | (packed_tensors["assistant_mask"], "Assistant Mask", "Assistant Mask", 6), 195 | (packed_tensors["advantages"], "Advantages", "Token Advantages", 7), 196 | (packed_tensors["weights"], "Weights", "Token Weights", 8), 197 | ): 198 | plt.subplot(4, 2, subplot_idx) 199 | sns.heatmap( 200 | tensor.numpy(), cmap="viridis", cbar_kws={"label": label}, xticklabels=False # type: ignore 201 | ) 202 | plt.title(title) 203 | plt.xlabel("Sequence Position") 204 | plt.ylabel("Batch") 205 | 206 | plt.tight_layout() 207 | plt.show() 208 | -------------------------------------------------------------------------------- /lib/recipe.py: -------------------------------------------------------------------------------- 1 | # This is probably most similar to the recipe found in 2 | # https://github.com/pytorch/torchtune/blob/main/recipes/full_finetune_distributed.py 3 | 4 | from functools import partial 5 | import os 6 | from omegaconf import DictConfig, ListConfig 7 | import sys 8 | import time 9 | import torch 10 | import torch.distributed 11 | from torch.optim.optimizer import Optimizer 12 | from torch.utils.data import DataLoader, Dataset, DistributedSampler 13 | from torchtune import config, modules, training, utils 14 | from torchtune.modules import TransformerDecoder 15 | from torchtune.recipe_interfaces import FTRecipeInterface 16 | from torchtune.training import DummyProfiler, PROFILER_KEY 17 | from torchtune.training.activations import apply_selective_activation_checkpointing 18 | from 
torchtune.training.checkpointing import Checkpointer 19 | from torchtune.training.metric_logging import MetricLoggerInterface 20 | from tqdm import tqdm 21 | from typing import ( 22 | Any, 23 | cast, 24 | Callable, 25 | Dict, 26 | Generic, 27 | Iterator, 28 | List, 29 | Mapping, 30 | Optional, 31 | overload, 32 | ParamSpec, 33 | Tuple, 34 | TypeVar, 35 | Union, 36 | ) 37 | from warnings import warn 38 | 39 | from .pack import PackedTensors 40 | from .grpo import GRPO, GRPOResult, shift_tensor 41 | 42 | log = utils.get_logger("DEBUG") 43 | 44 | T = TypeVar("T", covariant=True) 45 | P = ParamSpec("P") 46 | 47 | 48 | class ComponentConfig(DictConfig, Generic[T]): 49 | @overload 50 | def __init__( 51 | self, 52 | _component_: Callable[P, T], 53 | *args: P.args, 54 | **kwargs: P.kwargs, 55 | ) -> None: ... 56 | 57 | @overload 58 | def __init__(self, _component_: str, *args: Any, **kwargs: Any) -> None: ... 59 | 60 | def __init__( 61 | self, _component_: Union[Callable, str], *args: Any, **kwargs: Any 62 | ) -> None: 63 | super().__init__({}, flags={"allow_objects": True}) 64 | self._component_ = _component_ 65 | if args: 66 | raise ValueError( 67 | "Positional arguments are not supported in ComponentConfig" 68 | ) 69 | self.update(kwargs) 70 | 71 | def dict_config(self) -> DictConfig: 72 | return DictConfig( 73 | { 74 | "_component_": ( 75 | self._component_ 76 | if isinstance(self._component_, str) 77 | else f"{self._component_.__module__}.{self._component_.__name__}" 78 | ), 79 | **{k: v for k, v in self.items() if k != "_component_"}, 80 | } 81 | ) 82 | 83 | 84 | def instantiate_component(cfg: ComponentConfig[T], *args: Any, **kwargs: Any) -> T: 85 | if isinstance(cfg._component_, str): 86 | return config.instantiate(cfg, *args, **kwargs) 87 | _kwargs = { 88 | str(k): list(v) if isinstance(v, ListConfig) else v 89 | for k, v in cfg.items() 90 | if k != "_component_" 91 | } 92 | _kwargs.update(kwargs) 93 | return cfg._component_(*args, **_kwargs) 94 | 95 | 96 | PLACEHOLDER: Any = None 97 | 98 | 99 | class TuneRecipeConfig(DictConfig): 100 | def __init__( 101 | self, 102 | *, 103 | device: Optional[Union[str, torch.device]] = "cuda", 104 | dtype: Optional[Union[str, torch.dtype]] = "bf16", 105 | optimizer: ComponentConfig[Optimizer] = ComponentConfig( 106 | "torch.optim.AdamW", lr=2e-5, fused=True 107 | ), 108 | resume_from_checkpoint: bool = False, 109 | gradient_accumulation_steps: int = 1, 110 | checkpointer: ComponentConfig[Checkpointer] = PLACEHOLDER, 111 | seed: Optional[int] = None, 112 | epochs: int = 1, 113 | max_steps_per_epoch: Optional[int] = None, 114 | metric_logger: ComponentConfig[MetricLoggerInterface] = PLACEHOLDER, 115 | model: ComponentConfig[TransformerDecoder] = PLACEHOLDER, 116 | loss: ComponentConfig[GRPO] = ComponentConfig(GRPO), 117 | dataset: ComponentConfig[Dataset[PackedTensors]] = PLACEHOLDER, 118 | shuffle: bool = False, 119 | batch_size: int = 1, 120 | fsdp_cpu_offload: Optional[bool] = None, 121 | log_every_n_steps: Optional[int] = None, 122 | log_peak_memory_stats: Optional[bool] = None, 123 | log_grad_magnitude: Optional[bool] = None, 124 | optimizer_in_bwd: Optional[bool] = None, 125 | clip_grad_norm: Optional[Union[str, float]] = None, 126 | enable_activation_checkpointing: Optional[bool] = None, 127 | enable_activation_offloading: Optional[bool] = None, 128 | save_intermediate_checkpoints: Optional[bool] = None, 129 | reference_checkpointer: Optional[ComponentConfig[Checkpointer]] = None, 130 | compile: Optional[bool] = None, 131 | custom_sharded_layers: 
Optional[List[str]] = None, 132 | fsdp_reshard_after_forward: Optional[bool] = None, 133 | ac_mode: Optional[str] = None, 134 | ac_option: Optional[int] = None, 135 | num_output_chunks: Optional[int] = None, 136 | profiler: Optional[ComponentConfig] = None, 137 | ) -> None: 138 | super().__init__({}) 139 | self.device = device 140 | self.dtype = dtype 141 | self.optimizer = optimizer 142 | self.resume_from_checkpoint = resume_from_checkpoint 143 | self.gradient_accumulation_steps = gradient_accumulation_steps 144 | self.checkpointer = checkpointer 145 | self.seed = seed 146 | self.epochs = epochs 147 | self.max_steps_per_epoch = max_steps_per_epoch 148 | self.metric_logger = metric_logger 149 | self.model = model 150 | self.loss = loss 151 | self.dataset = dataset 152 | self.shuffle = shuffle 153 | self.batch_size = batch_size 154 | if fsdp_cpu_offload is not None: 155 | self.fsdp_cpu_offload = fsdp_cpu_offload 156 | if log_every_n_steps is not None: 157 | self.log_every_n_steps = log_every_n_steps 158 | if log_peak_memory_stats is not None: 159 | self.log_peak_memory_stats = log_peak_memory_stats 160 | if log_grad_magnitude is not None: 161 | self.log_grad_magnitude = log_grad_magnitude 162 | if optimizer_in_bwd is not None: 163 | self.optimizer_in_bwd = optimizer_in_bwd 164 | if clip_grad_norm is not None: 165 | self.clip_grad_norm = clip_grad_norm 166 | if enable_activation_checkpointing is not None: 167 | self.enable_activation_checkpointing = enable_activation_checkpointing 168 | if enable_activation_offloading is not None: 169 | self.enable_activation_offloading = enable_activation_offloading 170 | if save_intermediate_checkpoints is not None: 171 | self.save_intermediate_checkpoints = save_intermediate_checkpoints 172 | if reference_checkpointer is not None: 173 | self.reference_checkpointer = reference_checkpointer 174 | if compile is not None: 175 | self.compile = compile 176 | if custom_sharded_layers is not None: 177 | self.custom_sharded_layers = custom_sharded_layers 178 | if fsdp_reshard_after_forward is not None: 179 | self.fsdp_reshard_after_forward = fsdp_reshard_after_forward 180 | if ac_mode is not None: 181 | self.ac_mode = ac_mode 182 | if ac_option is not None: 183 | self.ac_option = ac_option 184 | if num_output_chunks is not None: 185 | self.num_output_chunks = num_output_chunks 186 | if profiler is not None: 187 | self.profiler = profiler 188 | 189 | def dict_config(self) -> DictConfig: 190 | config = DictConfig({}) 191 | for k, v in self.items(): 192 | if isinstance(v, DictConfig) and "_component_" in v: 193 | v = v.copy() 194 | v["_component_"] = ( 195 | v["_component_"] 196 | if isinstance(v["_component_"], str) 197 | else f"{v['_component_'].__module__}.{v['_component_'].__name__}" 198 | ) 199 | config[k] = v 200 | elif isinstance(v, torch.device): 201 | config[k] = str(v) 202 | elif isinstance(v, torch.dtype): 203 | config[k] = str(v) 204 | else: 205 | config[k] = v 206 | return config 207 | 208 | 209 | class TypedDataLoader(DataLoader[T]): 210 | def __iter__(self) -> Iterator[T]: 211 | return super().__iter__() 212 | 213 | 214 | class TuneRecipe(FTRecipeInterface): 215 | """ 216 | Full finetuning recipe for dense transformer-based LLMs such as Llama2. This recipe supports 217 | distributed training and can be run on a single node (1 to 8 GPUs). 218 | 219 | Features: 220 | - FSDP. Supported using PyTorch's FSDP APIs. CPU offload of parameters, gradients, and optimizer states 221 | is supported via ``fsdp_cpu_offload``. 
Resharding of parameters after the forward pass is 222 | done by default (corresponding to FULL_SHARD sharding strategy), but can be disabled by setting the config 223 | ``fsdp_reshard_after_forward`` to False (this corresponds to SHARD_GRAD_OP sharding strategy). 224 | DDP is currently not supported. Training on CPU is not supported. 225 | 226 | - Activation Checkpointing. This can be controlled using the ``enable_activation_checkpointing`` 227 | flag. Activation checkpointing helps reduce the memory footprint since we no longer keep 228 | activations in memory and instead recompute them during the backward pass. This is especially 229 | helpful for larger batch sizes when you're memory constrained. But these savings in memory 230 | come at the cost of training performance. In most cases training can slow-down quite a bit as 231 | a result of this activation recomputation. 232 | 233 | - Activation Offloading. This can be controlled using the ``enable_activation_offloading`` 234 | flag. Activation offloading is a technique similar to activations checkpointing that helps 235 | reduce the memory footprint to prevent OOMs on CUDA and enable bigger batches. Where activations 236 | checkpointing drops the activation in the forward to recompute it later in the backward, 237 | activations offloading will drop the activation in the forward to the CPU and bring it 238 | back during the backward pass. As always, there is a tradeoff--these savings in memory can 239 | come at the cost of training performance and CPU resources. To recover some runtime cost, 240 | we've added an option to enable offloading on a different stream to permit overlapping with 241 | the computation. This option is currently only available on PyTorch 2.5 or later and will 242 | be enabled by default if an acceptable torch version is found. Activation offloading can be 243 | used in conjunction with activation checkpointing. 244 | 245 | - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype`` 246 | flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In 247 | most cases this should halve the memory footprint of full precision (fp32) training, without 248 | loss in model quality (will depend on the model, training data and other settings). For 249 | GPUs which do not support bfloat16, we fall back to fp32. Mixed precision training and fp16 250 | precision are currently not supported. 251 | 252 | - Gradient Accumulation. You can simulate larger batch sizes by accumulating gradients. This is 253 | controlled using the ``gradient_accumulation_steps`` flag. 254 | 255 | Total Batch Size = batch_size * number of GPUs * gradient accumulation steps. 256 | 257 | For example: with batch_size=1, nproc_per_node=2 and gradient_accumulation_steps=32 we get a 258 | total batch size of 64. 259 | 260 | Gradient accumulation is especially useful when you are memory constrained. In this case, 261 | accumulating gradients might give you better training speed than enabling activation 262 | checkpointing. 263 | 264 | - Checkpointing. Model weights are checkpointed both at the end of each epoch and at the end of 265 | training. Optimizer state and recipe state (seed, total_epochs, number of epochs run etc) are 266 | only saved at the end of a given epoch and used in case of resuming training. 267 | 268 | Resuming training is controlled by the ``resume_from_checkpoint`` flag. Mid-epoch checkpointing is 269 | currently not supported. 
270 | 271 | For more details on the checkpointer, please take a look at 272 | our checkpointer deepdive (https://pytorch.org/torchtune/main/deep_dives/checkpointer.html). 273 | 274 | - Logging. Terminal, Disk, WandB and TensorBoard are all supported. 275 | 276 | - Gradient Clipping. Gradient clipping is supported using the ``clip_grad_norm`` flag. By default, 277 | ``clip_grad_norm`` is set to ``None``. If you only want to log the grad norm, you can set 278 | ``clip_grad_norm='inf'``. 279 | 280 | For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config 281 | has example commands for how to kick-off training. 282 | 283 | Args: 284 | cfg (DictConfig): OmegaConf object parsed from yaml file 285 | 286 | Raises: 287 | ValueError: If ``dtype`` is set to fp16. 288 | RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16. 289 | RuntimeError: If ``left_pad_sequence`` is set as the data collator. 290 | RuntimeError: If ``enable_activation_offloading`` is True and device is not CUDA. 291 | RuntimeError: If ``enable_activation_offloading`` is True and ``enable_activation_checkpointing`` is False. 292 | """ 293 | 294 | def __init__(self, cfg: TuneRecipeConfig) -> None: 295 | self._device = ( 296 | cfg.device 297 | if isinstance(cfg.device, torch.device) 298 | else utils.get_device(device=cfg.device) 299 | ) 300 | self._dtype = ( 301 | cfg.dtype 302 | if isinstance(cfg.dtype, torch.dtype) 303 | else training.get_dtype(cfg.dtype, device=self._device) 304 | ) 305 | 306 | if self._dtype == torch.float16: 307 | raise ValueError( 308 | "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead." 309 | ) 310 | 311 | if ( 312 | cfg.get("fsdp_cpu_offload", False) 313 | and cfg.optimizer.get("fused", False) 314 | and not utils.torch_version_ge("2.4.0") 315 | ): 316 | raise RuntimeError( 317 | "Using fused optimizer on CPU is only supported in PyTorch nightly." 318 | ) 319 | 320 | # logging attributes 321 | self._log_every_n_steps: int = cfg.get("log_every_n_steps", 1) 322 | self._log_peak_memory_stats: bool = cfg.get("log_peak_memory_stats", False) 323 | self._log_grad_magnitude: bool = cfg.get("log_grad_magnitude", False) 324 | 325 | if self._log_peak_memory_stats and self._device.type != "cuda": 326 | log.info( 327 | "log_peak_memory_stats was set to True, however, training does not use cuda. Setting log_peak_memory_stats=False." 328 | ) 329 | self._log_peak_memory_stats = False 330 | 331 | # _is_rank_zero is used primarily for logging. In the future, the logger 332 | # should directly take care of this 333 | _, rank = training.get_world_size_and_rank() 334 | self._is_rank_zero = rank == 0 335 | 336 | # Training cfg 337 | self._resume_from_checkpoint = cfg.resume_from_checkpoint 338 | self._gradient_accumulation_steps = cfg.gradient_accumulation_steps 339 | self._optimizer_in_bwd: bool = cfg.get("optimizer_in_bwd", False) 340 | self._clip_grad_norm: Optional[Union[str, float]] = cfg.get( 341 | "clip_grad_norm", None 342 | ) 343 | 344 | # Optimizer in backward is not compatible with gradient accumulation or gradient clipping 345 | if self._optimizer_in_bwd: 346 | if self._clip_grad_norm is not None: 347 | raise RuntimeError( 348 | "Gradient clipping is not supported with optimizer in bwd." 349 | "Please set clip_grad_norm=None, or optimizer_in_bwd=False." 350 | ) 351 | if self._gradient_accumulation_steps > 1: 352 | raise RuntimeError( 353 | "Gradient accumulation is not supported with optimizer in bwd." 
354 | "Please set gradient_accumulation_steps=1, or optimizer_in_bwd=False." 355 | ) 356 | 357 | # activation checkpointing/offloading 358 | self._enable_activation_checkpointing: bool = cfg.get( 359 | "enable_activation_checkpointing", False 360 | ) 361 | self._enable_activation_offloading: bool = cfg.get( 362 | "enable_activation_offloading", False 363 | ) 364 | if self._enable_activation_offloading: 365 | if self._device.type != "cuda": 366 | raise RuntimeError( 367 | "enable_activation_offloading should only be True when training on CUDA" 368 | ) 369 | if not self._enable_activation_checkpointing: 370 | raise RuntimeError( 371 | "enable_activation_offloading should only be True when enable_activation_checkpointing is True" 372 | ) 373 | elif ( 374 | self._enable_activation_checkpointing 375 | and cfg.checkpointer.model_type # TODO: `model_type` type is not defined 376 | != "LLAMA3_VISION" 377 | ): 378 | utils.log_rank_zero( 379 | log, 380 | "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. " 381 | "Enabling activation offloading should reduce memory further.", 382 | ) 383 | 384 | # These are public properties which are updated by the checkpoint loader 385 | # when ``resume_from_checkpoint`` is `True` or validated in tests 386 | self.seed = training.set_seed(seed=cfg.seed) 387 | self.epochs_run = 0 388 | self.total_epochs = cfg.epochs 389 | self.max_steps_per_epoch = cfg.max_steps_per_epoch 390 | self.global_step = 0 391 | self._save_intermediate_checkpoints = cfg.get( 392 | "save_intermediate_checkpoints", False 393 | ) 394 | 395 | def load_checkpoint( 396 | self, cfg_checkpointer: ComponentConfig[Checkpointer] 397 | ) -> Dict[str, Any]: 398 | """ 399 | Extract the checkpoint state from file and validate. If resume_from_checkpoint 400 | is True, this also includes the recipe state. 401 | """ 402 | self._checkpointer = instantiate_component( 403 | cfg_checkpointer, 404 | resume_from_checkpoint=self._resume_from_checkpoint, 405 | ) 406 | checkpoint_dict = self._checkpointer.load_checkpoint() 407 | 408 | if self._resume_from_checkpoint: 409 | self._update_recipe_state(checkpoint_dict) 410 | return checkpoint_dict 411 | 412 | def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None: 413 | """ 414 | Updates the recipe state from checkpoint. 415 | """ 416 | try: 417 | self.epochs_run = ckpt_dict[training.EPOCHS_KEY] 418 | 419 | # on mismatch, warn the user and prevent the override 420 | if self.seed != ckpt_dict[training.SEED_KEY]: 421 | warn( 422 | message=( 423 | "Config value for seed does not match the checkpoint value, " 424 | f"using the checkpoint value: {ckpt_dict[training.SEED_KEY]}" 425 | ) 426 | ) 427 | self.seed = ckpt_dict[training.SEED_KEY] 428 | if self.max_steps_per_epoch != ckpt_dict[training.MAX_STEPS_KEY]: 429 | warn( 430 | message=( 431 | "Config value for max_steps_per_epoch does not match the checkpoint value, " 432 | f"using the checkpoint value: {ckpt_dict[training.MAX_STEPS_KEY]}" 433 | ) 434 | ) 435 | self.max_steps_per_epoch = ckpt_dict[training.MAX_STEPS_KEY] 436 | 437 | # on mismatch, warn the user but allow the override 438 | if self.total_epochs != ckpt_dict[training.TOTAL_EPOCHS_KEY]: 439 | warn( 440 | message=( 441 | "Config value for total_epochs does not match the checkpoint value, " 442 | f"using the config value: {self.total_epochs}" 443 | ) 444 | ) 445 | 446 | except KeyError as e: 447 | raise KeyError( 448 | "Checkpoint does not contain the required keys needed for updating recipe state. 
" 449 | "Are you sure you passed in the right recipe checkpoint?" 450 | ) from e 451 | 452 | def setup(self, cfg: TuneRecipeConfig) -> None: 453 | """ 454 | Setup the recipe. This includes training state (if resume_from_checkpoint is True), 455 | model, tokenizer, loss, optimizer, sampler, and dataloader. 456 | """ 457 | if self._is_rank_zero: 458 | self._metric_logger = instantiate_component(cfg.metric_logger) 459 | 460 | # log config with parameter override 461 | self._metric_logger.log_config(cfg) 462 | 463 | checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer) 464 | if reference_checkpointer_cfg := cfg.get("reference_checkpointer", None): 465 | self.reference_model_state_dict = instantiate_component( 466 | reference_checkpointer_cfg 467 | ).load_checkpoint()[training.MODEL_KEY] 468 | else: 469 | self.reference_model_state_dict = None 470 | 471 | self._compile: bool = cfg.get("compile", False) 472 | if self._compile: 473 | torch.empty(1, device=self._device, requires_grad=True).backward() 474 | self._model = self._setup_model( 475 | cfg_model=cfg.model, 476 | enable_activation_checkpointing=self._enable_activation_checkpointing, 477 | enable_activation_offloading=self._enable_activation_offloading, 478 | custom_sharded_layers=cfg.get("custom_sharded_layers", None), 479 | fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False), 480 | reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True), 481 | model_state_dict=checkpoint_dict[training.MODEL_KEY], 482 | reference_model_state_dict=self.reference_model_state_dict, 483 | ac_mode=cfg.get("ac_mode", None), 484 | ac_option=cfg.get("ac_option", None), 485 | ) 486 | self._model.output_hidden_states = [len(self._model.layers) - 1] 487 | 488 | if self.reference_model_state_dict: 489 | # pin reference model state 490 | for value in self.reference_model_state_dict.values(): 491 | if not isinstance(value, torch.distributed._tensor.DTensor): # type: ignore 492 | value.pin_memory() 493 | 494 | self._optimizer = self._setup_optimizer( 495 | cfg_optimizer=cfg.optimizer, 496 | optimizer_in_bwd=self._optimizer_in_bwd, 497 | opt_state_dict=( 498 | checkpoint_dict[training.OPT_KEY] 499 | if self._resume_from_checkpoint 500 | else None 501 | ), 502 | ) 503 | 504 | # initialize loss 505 | self._loss_fn = instantiate_component(cfg.loss) 506 | 507 | if self._compile: 508 | if self._is_rank_zero: 509 | log.info("Compiling loss with torch.compile...") 510 | self._loss_fn._forward_chunk = torch.compile( 511 | self._loss_fn._forward_chunk, 512 | backend=os.environ.get("TORCH_COMPILE_BACKEND", "inductor"), 513 | ) 514 | 515 | if cfg.get("num_output_chunks", None) is not None: 516 | # set num_output_chunks for model 517 | self._model.set_num_output_chunks(cfg.num_output_chunks) 518 | 519 | if self._is_rank_zero: 520 | log.info("Loss is initialized.") 521 | 522 | # sampler and dataloader depend on the tokenizer and loss_fn and should be 523 | # setup after both of these are initialized 524 | self._sampler, self._dataloader = self._setup_data( 525 | cfg_dataset=cfg.dataset, 526 | shuffle=cfg.shuffle, 527 | batch_size=cfg.batch_size, 528 | ) 529 | 530 | # Finally update the recipe state which can only be correctly set after all of the 531 | # other components have been initialized and updated. 532 | # 533 | # Number of training steps in each epoch depends on the number of batches produced 534 | # by the dataloader, the max_steps_per_epoch param set by the user and the 535 | # gradient_accumulation_steps param. 
This value is used for logging and tracking 536 | # training state. The computation should happen after the dataloader has been setup 537 | self._steps_per_epoch = ( 538 | len(self._dataloader) // self._gradient_accumulation_steps 539 | ) 540 | if ( 541 | self.max_steps_per_epoch is not None 542 | and self.max_steps_per_epoch < self._steps_per_epoch 543 | ): 544 | self._steps_per_epoch = self.max_steps_per_epoch 545 | self.global_step = self.epochs_run * self._steps_per_epoch 546 | 547 | # Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method) 548 | # if cfg is missing profiler key or if `cfg.profiler.enabled = False` 549 | self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None)) 550 | 551 | def _setup_profiler( 552 | self, cfg_profiler: Optional[DictConfig] = None 553 | ) -> Union[torch.profiler.profile, DummyProfiler]: 554 | """ 555 | Parses the `profiler` section of top-level `cfg` and sets up profiler 556 | 557 | Args: 558 | cfg_profiler (Optional[DictConfig]): ``profiler`` section of the top-level ``cfg`` (the main config passed to 559 | `recipe.main`). Default None. 560 | 561 | Returns: 562 | profiler: Union[torch.profiler.profile, DummyProfiler] - DummyProfiler is a nullcontext with no-op methods 563 | for `start`, `stop`, and `step` that can be used in place of `torch.profiler.profile` if profiler is not enabled such 564 | that the instrumented training loop does not need to be changed profiling is disabled. 565 | 566 | The profiler config can be provided in configs under the `profiler` key with the following layout: 567 | 568 | .. code-block:: yaml 569 | profiler: 570 | enabled: bool 571 | 572 | #Output directory of trace artifacts 573 | output_dir: str 574 | 575 | #`torch.profiler.ProfilerActivity` types to trace 576 | cpu: bool 577 | cuda: bool 578 | 579 | #Trace options 580 | profile_memory: bool 581 | with_stack: bool 582 | record_shapes: bool 583 | with_flops: bool 584 | 585 | # `torch.profiler.schedule` options: 586 | # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat 587 | wait_steps: int 588 | warmup_steps: int 589 | active_steps: int 590 | num_cycles: int 591 | """ 592 | # Missing profiler section in config, assume disabled 593 | if cfg_profiler is None: 594 | cfg_profiler = DictConfig({"enabled": False}) 595 | 596 | # Check that component is included and set correctly 597 | if cfg_profiler.get("_component_", None) is None: 598 | cfg_profiler["_component_"] = "torchtune.training.setup_torch_profiler" 599 | else: 600 | assert ( 601 | cfg_profiler.get("_component_") 602 | == "torchtune.training.setup_torch_profiler" 603 | ), "Only torch profiler supported currently: component must be `torchtune.training.setup_torch_profiler`" 604 | 605 | profiler, profiler_cfg = config.instantiate(cfg_profiler) 606 | 607 | if self._is_rank_zero: 608 | log.info(f" Profiler config after instantiation: {profiler_cfg}") 609 | 610 | self.profiler_profile_memory = profiler_cfg.get("profile_memory", False) 611 | if profiler_cfg["enabled"]: 612 | self.profiler_wait_steps = profiler_cfg["wait_steps"] 613 | self.profiler_warmup_steps = profiler_cfg["warmup_steps"] 614 | self.profiler_active_steps = profiler_cfg["active_steps"] 615 | 616 | return profiler 617 | 618 | def _setup_model( 619 | self, 620 | cfg_model: ComponentConfig[TransformerDecoder], 621 | enable_activation_checkpointing: bool, 622 | enable_activation_offloading: bool, 623 | fsdp_cpu_offload: bool, 624 | reshard_after_forward: bool, 625 | model_state_dict: 
Dict[str, Any], 626 | reference_model_state_dict: Optional[Dict[str, Any]] = None, 627 | custom_sharded_layers: Optional[List[str]] = None, 628 | ac_mode: Optional[str] = None, 629 | ac_option: Optional[int] = None, 630 | ) -> TransformerDecoder: 631 | """ 632 | Model initialization has some important considerations: 633 | a. To minimize GPU peak memory, we initialize the model on meta device with 634 | the right dtype 635 | b. All ranks calls ``load_state_dict`` without peaking CPU RAMs since 636 | full state dicts are loaded with ``torch.load(mmap=True)`` 637 | """ 638 | 639 | if self._is_rank_zero: 640 | log.info( 641 | "FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ..." 642 | ) 643 | init_start = time.perf_counter() 644 | else: 645 | init_start = 0.0 646 | 647 | with ( 648 | training.set_default_dtype(self._dtype), 649 | torch.device("meta") if training.is_distributed() else self._device, 650 | ): 651 | model = instantiate_component(cfg_model) 652 | 653 | if self._compile: 654 | training.compile_model(model, verbose=self._is_rank_zero) 655 | 656 | # We currently have two versions of activation checkpointing in this recipe 657 | # for testing and BC purposes. ``enable_activation_checkpointing`` controls 658 | # the older version of AC and this behavior is unchanged 659 | # ac_mode and ac_option together control selective AC. This is only enabled 660 | # when these are set AND ``enable_activation_checkpointing`` is set to False 661 | # We'll clean this up as soon as testing of AC is complete 662 | if (not enable_activation_checkpointing) and (ac_mode is not None): 663 | apply_selective_activation_checkpointing( 664 | model, 665 | ac_mode, 666 | ac_option, 667 | ) 668 | 669 | # original activation checkpointing (full) - flip the condition above 670 | if enable_activation_checkpointing and ac_mode is None: 671 | training.set_activation_checkpointing( 672 | model, auto_wrap_policy={modules.TransformerSelfAttentionLayer} 673 | ) 674 | 675 | if training.is_distributed(): 676 | # For FSDP sharding 677 | fsdp_shard_conditions = [ 678 | partial( 679 | training.get_shard_conditions, 680 | names_to_match=custom_sharded_layers, 681 | ) 682 | ] 683 | training.shard_model( 684 | model=model, 685 | shard_conditions=fsdp_shard_conditions, 686 | cpu_offload=fsdp_cpu_offload, 687 | reshard_after_forward=reshard_after_forward, 688 | ) 689 | 690 | with training.set_default_dtype(self._dtype), self._device: 691 | for m in model.modules(): 692 | # RoPE is not covered in state dict 693 | if hasattr(m, "rope_init"): 694 | m.rope_init() # type: ignore 695 | 696 | # This method will convert the full model state dict into a sharded state 697 | # dict and load into the model 698 | training.load_from_full_model_state_dict( 699 | model, 700 | model_state_dict, 701 | self._device, 702 | self._is_rank_zero, 703 | strict=True, 704 | cpu_offload=fsdp_cpu_offload, 705 | ) 706 | 707 | if reference_model_state_dict: 708 | # Temporarily patch model.load_state_dict to capture the sharded parameter tensors 709 | # for this rank when loading the reference model. This allows us to maintain a 710 | # reference copy of the sharded parameters that matches the FSDP sharding pattern, 711 | # which is needed for weight swapping during training. 
712 | load_state_dict = model.load_state_dict 713 | 714 | def patch( 715 | state_dict: Mapping[str, Any], 716 | strict: bool = True, 717 | assign: bool = False, 718 | ) -> Any: 719 | reference_model_state_dict.clear() 720 | reference_model_state_dict.update(state_dict) 721 | 722 | model.load_state_dict = patch 723 | 724 | training.load_from_full_model_state_dict( 725 | model, 726 | reference_model_state_dict, 727 | self._device, 728 | self._is_rank_zero, 729 | strict=True, 730 | cpu_offload=fsdp_cpu_offload, 731 | ) 732 | 733 | model.load_state_dict = load_state_dict 734 | else: 735 | model.load_state_dict(model_state_dict) 736 | 737 | # Validate model was loaded in with the expected dtype. 738 | training.validate_expected_param_dtype( 739 | model.named_parameters(), dtype=self._dtype 740 | ) 741 | 742 | # activation offloading 743 | self.activations_handling_ctx = training.get_act_offloading_ctx_manager( 744 | model, enable_activation_offloading 745 | ) 746 | 747 | # Ensure no params and buffers are on meta device 748 | training.validate_no_params_on_meta_device(model) 749 | 750 | if self._is_rank_zero: 751 | log.info( 752 | f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs" 753 | ) 754 | memory_stats = training.get_memory_stats(device=self._device) 755 | training.log_memory_stats(memory_stats) 756 | 757 | if training.is_distributed(): 758 | # synchronize before training begins 759 | torch.distributed.barrier() 760 | 761 | return model 762 | 763 | def _setup_optimizer( 764 | self, 765 | cfg_optimizer: ComponentConfig[Optimizer], 766 | optimizer_in_bwd: bool = False, 767 | opt_state_dict: Optional[Dict[str, Any]] = None, 768 | ) -> Optional[Optimizer]: 769 | if optimizer_in_bwd: 770 | # Maintain a dict of optims for every parameter. 771 | optim_dict = { 772 | param: instantiate_component(cfg_optimizer, params=[param]) 773 | for param in self._model.parameters() 774 | } 775 | 776 | # Register optimizer step hooks on the model to run optimizer in backward. 777 | training.register_optim_in_bwd_hooks( 778 | model=self._model, optim_dict=optim_dict 779 | ) 780 | # Create a wrapper for checkpoint save/load of optimizer states when running in backward. 781 | self._optim_ckpt_wrapper = training.create_optim_in_bwd_wrapper( 782 | model=self._model, optim_dict=optim_dict 783 | ) 784 | # Load optimizer states for each param. If optimizer states are being restored in an optimizer in 785 | # backward run, these need to have been saved with the same setting. Cannot restore from runs that 786 | # did not use optimizer in backward. 787 | if opt_state_dict is not None: 788 | for param in opt_state_dict.keys(): 789 | try: 790 | training.load_from_full_optimizer_state_dict( 791 | self._optim_ckpt_wrapper.state_dict()[param], 792 | opt_state_dict[param], 793 | self._device, 794 | ) 795 | except BaseException as e: 796 | raise RuntimeError( 797 | "Failed loading in-backward optimizer checkpoints." 798 | "Please make sure run being restored from was using in-backward optimizer." 
799 | ) from e 800 | if self._is_rank_zero: 801 | log.info("In-backward optimizers are set up.") 802 | return None 803 | else: 804 | optimizer = instantiate_component( 805 | cfg_optimizer, params=self._model.parameters() 806 | ) 807 | if opt_state_dict: 808 | training.load_from_full_optimizer_state_dict( 809 | optimizer, 810 | opt_state_dict, 811 | self._device, 812 | ) 813 | 814 | if self._is_rank_zero: 815 | log.info("Optimizer is initialized.") 816 | return optimizer 817 | 818 | def _setup_data( 819 | self, 820 | cfg_dataset: ComponentConfig[Dataset[PackedTensors]], 821 | shuffle: bool, 822 | batch_size: int, 823 | ) -> Tuple[DistributedSampler, TypedDataLoader[PackedTensors]]: 824 | """ 825 | All data related setup happens here. Currently this recipe only supports the 826 | DistributedSamplers with Map-style Datasets which fit into memory. Other samplers, 827 | iterable datasets and streaming datasets are not supported. 828 | """ 829 | world_size, rank = training.get_world_size_and_rank() 830 | 831 | ds = instantiate_component(cfg_dataset) 832 | 833 | sampler = DistributedSampler( 834 | ds, num_replicas=world_size, rank=rank, shuffle=shuffle, seed=self.seed or 0 835 | ) 836 | dataloader = TypedDataLoader( 837 | dataset=ds, 838 | batch_size=batch_size, 839 | sampler=sampler, 840 | # dropping last avoids shape issues with compile + flex attention 841 | drop_last=True, 842 | ) 843 | 844 | if self._is_rank_zero: 845 | log.info("Dataset and Sampler are initialized.") 846 | 847 | return sampler, dataloader 848 | 849 | def save_checkpoint( 850 | self, 851 | epoch: int, 852 | ) -> None: 853 | """ 854 | Checkpoint the state of the recipe. The constructed checkpoint state dict 855 | contains the following information: 856 | - Model weights with key training.MODEL_KEY 857 | - Relevant recipe state if training is not complete 858 | 859 | Checkpointer will save the model weights and recipe state in 860 | different checkpoint files. To correctly resume training from an intermediate checkpoint, 861 | the model weights and recipe state must be provided. 862 | """ 863 | # final dict passed onto the checkpointer 864 | checkpoint_dict = {} 865 | 866 | intermediate_checkpoint = epoch + 1 < self.total_epochs 867 | 868 | if intermediate_checkpoint and not self._save_intermediate_checkpoints: 869 | return 870 | 871 | if self._is_rank_zero: 872 | log.info( 873 | "Saving checkpoint. This may take some time. Retrieving full model state dict..." 
874 | ) 875 | start = time.perf_counter() 876 | else: 877 | start = 0.0 878 | 879 | # To prevent GPU memory from spiking during checkpoint save, 880 | # we consolidate the full model and optim state dicts on CPU for rank 0 881 | model_state_dict = ( 882 | training.gather_cpu_state_dict( 883 | self._model.state_dict(), 884 | self._is_rank_zero, 885 | device=self._device, 886 | ) 887 | if training.is_distributed() 888 | else self._model.state_dict() 889 | ) 890 | 891 | if self._is_rank_zero: 892 | log.info( 893 | f"Getting full model state dict took {time.perf_counter() - start:.2f} secs" 894 | ) 895 | 896 | if intermediate_checkpoint: 897 | start = time.perf_counter() 898 | if self._is_rank_zero: 899 | log.info("Getting optimizer state dict...") 900 | if self._optimizer: 901 | opt_state_dict = ( 902 | training.get_full_optimizer_state_dict( 903 | self._optimizer, 904 | self._is_rank_zero, 905 | device=self._device, 906 | ) 907 | if training.is_distributed() 908 | else self._optimizer.state_dict() 909 | ) 910 | else: 911 | opt_state_dict = {} 912 | for param, opt in self._optim_ckpt_wrapper.optim_map.items(): 913 | opt_state_dict[param] = ( 914 | training.get_full_optimizer_state_dict( 915 | opt, self._is_rank_zero, device=self._device 916 | ) 917 | if training.is_distributed() 918 | else opt.state_dict() 919 | ) 920 | if self._is_rank_zero: 921 | log.info( 922 | f"Getting optimizer state dict took {time.perf_counter() - start:.2f} secs" 923 | ) 924 | else: 925 | opt_state_dict = None 926 | 927 | # Now that we have the model and opt state dict, create the actual checkpoint dict 928 | # to be sent to the checkpointer and ultimately written to file 929 | 930 | if self._is_rank_zero: 931 | start = time.perf_counter() 932 | checkpoint_dict.update({training.MODEL_KEY: model_state_dict}) 933 | 934 | # if training is in-progress, checkpoint the optimizer state and recipe state 935 | # as well. 936 | if intermediate_checkpoint: 937 | checkpoint_dict.update( 938 | { 939 | training.OPT_KEY: opt_state_dict, 940 | training.SEED_KEY: self.seed, 941 | training.EPOCHS_KEY: self.epochs_run, 942 | training.TOTAL_EPOCHS_KEY: self.total_epochs, 943 | training.MAX_STEPS_KEY: self.max_steps_per_epoch, 944 | } 945 | ) 946 | 947 | self._checkpointer.save_checkpoint( 948 | checkpoint_dict, 949 | epoch=epoch, 950 | intermediate_checkpoint=intermediate_checkpoint, 951 | ) 952 | log.info(f"Saving checkpoint took {time.perf_counter() - start:.2f} secs") 953 | 954 | if training.is_distributed(): 955 | torch.distributed.barrier() 956 | 957 | def train(self) -> None: 958 | """ 959 | The core training loop. 
960 | """ 961 | # clean up before training begins 962 | training.cleanup_before_training() 963 | torch.autograd.set_detect_anomaly(True) 964 | 965 | world_size, rank = training.get_world_size_and_rank() 966 | 967 | # zero out the gradients before starting training 968 | if self._optimizer: 969 | self._optimizer.zero_grad() 970 | else: 971 | for opt in self._optim_ckpt_wrapper.optim_map.values(): 972 | opt.zero_grad() 973 | 974 | # Initialize tokens count and running loss (for grad accumulation) 975 | t0 = time.perf_counter() 976 | running_result = GRPOResult().to(self._device) 977 | 978 | self._profiler.start() 979 | # self.epochs_run should be non-zero when we're resuming from a checkpoint 980 | for curr_epoch in range(self.epochs_run, self.total_epochs): 981 | # Update the sampler to ensure data is correctly shuffled across epochs 982 | # in case shuffle is True 983 | self._sampler.set_epoch(curr_epoch) 984 | 985 | pbar = tqdm(total=self._steps_per_epoch, disable=not (rank == 0)) 986 | grad_norm: torch.Tensor | None = None 987 | for idx, batch in enumerate(self._dataloader): 988 | if ( 989 | self.max_steps_per_epoch is not None 990 | and (idx // self._gradient_accumulation_steps) 991 | == self.max_steps_per_epoch 992 | ): 993 | break 994 | 995 | # Start tracking CUDA memory for active steps for just the first epoch 996 | if ( 997 | self._is_rank_zero 998 | and curr_epoch == 0 999 | and self.profiler_profile_memory 1000 | and idx == self.profiler_wait_steps + self.profiler_warmup_steps 1001 | ): 1002 | torch.cuda.memory._record_memory_history() 1003 | 1004 | utils.batch_to_device(batch, self._device) # type: ignore - `batch_to_device` expects a `dict`, not a `TypedDict`, but this should be fine 1005 | 1006 | # Assume the first token in the batch is the bos token 1007 | bos_id = int(batch["tokens"].view(-1)[0].item()) 1008 | 1009 | # Create grouped causal mask 1010 | batch_size, seq_len = batch["tokens"].size() 1011 | causal_mask = ( 1012 | torch.tril( 1013 | torch.ones( 1014 | seq_len, seq_len, dtype=torch.bool, device=self._device 1015 | ) 1016 | ) 1017 | .unsqueeze(0) 1018 | .expand(batch_size, seq_len, seq_len) 1019 | ) 1020 | group_mask = batch["group_ids"].unsqueeze(2) == batch[ 1021 | "group_ids" 1022 | ].unsqueeze(1) 1023 | parent_mask = batch["parent_ids"].unsqueeze(2) == batch[ 1024 | "group_ids" 1025 | ].unsqueeze(1) 1026 | mask = causal_mask & (group_mask | parent_mask) 1027 | 1028 | if self.reference_model_state_dict: 1029 | # Save current weights and load reference weights 1030 | model_state_dict = self._swap_state(self.reference_model_state_dict) 1031 | 1032 | # Run reference model forward pass without affecting autograd 1033 | with torch.no_grad(), self.activations_handling_ctx: 1034 | hidden_states, logits = self._model( 1035 | tokens=batch["tokens"], 1036 | mask=mask, 1037 | input_pos=batch["input_pos"], 1038 | ) 1039 | del hidden_states 1040 | if isinstance(logits, list): 1041 | reference_logprobs = torch.cat( 1042 | [ 1043 | torch.distributions.Categorical( 1044 | logits=logits_chunk 1045 | ).log_prob( 1046 | shift_tensor(tokens, ignore_label=bos_id) 1047 | ) 1048 | for logits_chunk, tokens in zip( 1049 | logits, 1050 | batch["tokens"].chunk(len(logits), dim=1), 1051 | ) 1052 | ], 1053 | dim=-1, 1054 | ) 1055 | else: 1056 | reference_logprobs = cast( 1057 | torch.Tensor, 1058 | torch.distributions.Categorical(logits=logits).log_prob( 1059 | shift_tensor(batch["tokens"], ignore_label=bos_id) 1060 | ), 1061 | ) 1062 | del logits 1063 | 1064 | # Restore original weights 
1065 | self._swap_state(model_state_dict) 1066 | else: 1067 | reference_logprobs = None 1068 | 1069 | with self.activations_handling_ctx: 1070 | hidden_states, logits = self._model( 1071 | tokens=batch["tokens"], 1072 | mask=mask, 1073 | input_pos=batch["input_pos"], 1074 | ) 1075 | del mask, batch["input_pos"], hidden_states # type: ignore 1076 | 1077 | # Compute loss 1078 | current_result = self._loss_fn.forward( 1079 | logits=logits, 1080 | tokens=batch["tokens"], 1081 | advantages=batch["advantages"], 1082 | logprobs=batch["logprobs"], 1083 | reference_logprobs=reference_logprobs, 1084 | mask=batch["assistant_mask"], 1085 | weights=batch["weights"], 1086 | bos_id=bos_id, 1087 | ) 1088 | del logits, batch 1089 | 1090 | running_result += current_result 1091 | 1092 | # For optimizer in backward, we need to normalize before calling backward 1093 | # This case and gradient accumulation are mutually exclusive 1094 | if self._optimizer_in_bwd: 1095 | if training.is_distributed(): 1096 | for tensor in running_result.tensors(): 1097 | torch.distributed.all_reduce(tensor) 1098 | current_loss = current_result.total_loss / current_result.num_tokens 1099 | else: 1100 | current_loss = current_result.total_loss 1101 | 1102 | current_loss.backward() 1103 | del current_loss 1104 | 1105 | # Step with optimizer 1106 | if (idx + 1) % self._gradient_accumulation_steps == 0: 1107 | if self._optimizer: 1108 | if training.is_distributed(): 1109 | for tensor in running_result.tensors(): 1110 | torch.distributed.all_reduce(tensor) 1111 | # Manually scale the gradients from unnormalized loss by total # of tokens 1112 | training.scale_grads(self._model, 1 / running_result.num_tokens) 1113 | 1114 | # Calculate gradient magnitude (L2 norm of all gradients) 1115 | grad_magnitude = None 1116 | if self._log_grad_magnitude: 1117 | grad_magnitude = torch.norm( 1118 | torch.stack( 1119 | [ 1120 | torch.norm(p.grad.detach()) 1121 | for p in self._model.parameters() 1122 | if p.grad is not None 1123 | ] 1124 | ) 1125 | ) 1126 | 1127 | if self._clip_grad_norm is not None: 1128 | grad_norm = torch.nn.utils.clip_grad_norm_( 1129 | self._model.parameters(), 1130 | max_norm=float(self._clip_grad_norm), 1131 | ) 1132 | self._optimizer.step() 1133 | self._optimizer.zero_grad(set_to_none=True) 1134 | 1135 | # Update the number of steps when the weights are updated 1136 | self.global_step += 1 1137 | 1138 | per_token_result = running_result.per_token() 1139 | loss_to_log = per_token_result.total_loss.item() 1140 | policy_loss_to_log = per_token_result.policy_loss.item() 1141 | entropy_to_log = per_token_result.entropy.item() 1142 | kl_div_to_log = per_token_result.kl_div.item() 1143 | pbar.update(1) 1144 | pbar.set_description( 1145 | f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log:.4f}" 1146 | ) 1147 | postfix = { 1148 | "loss": loss_to_log, 1149 | "policy": policy_loss_to_log, 1150 | "entropy": entropy_to_log, 1151 | "kl_div": kl_div_to_log, 1152 | } 1153 | if self._log_grad_magnitude and grad_magnitude is not None: 1154 | postfix["grad_magnitude"] = grad_magnitude.item() 1155 | pbar.set_postfix(postfix) 1156 | 1157 | # Log per-step metrics 1158 | if ( 1159 | self.global_step % self._log_every_n_steps == 0 1160 | and self._is_rank_zero 1161 | ): 1162 | time_per_step = time.perf_counter() - t0 1163 | log_dict = { 1164 | "loss": loss_to_log, 1165 | "policy": policy_loss_to_log, 1166 | "entropy": entropy_to_log, 1167 | "kl_div": kl_div_to_log, 1168 | # "lr": get_lr(self._optimizer or self._optim_ckpt_wrapper), 1169 | 
"tokens_per_second_per_gpu": running_result.num_tokens 1170 | / (time_per_step * world_size), 1171 | } 1172 | if self._log_grad_magnitude and grad_magnitude is not None: 1173 | log_dict["grad_magnitude"] = grad_magnitude.item() 1174 | if self._log_peak_memory_stats: 1175 | log_dict.update( 1176 | training.get_memory_stats(device=self._device) 1177 | ) 1178 | if self._clip_grad_norm is not None: 1179 | log_dict.update({"grad_norm": grad_norm}) 1180 | self._metric_logger.log_dict( 1181 | log_dict, 1182 | step=self.global_step, 1183 | ) 1184 | 1185 | # Reset running stats for the next step 1186 | del running_result 1187 | running_result = GRPOResult().to(self._device) 1188 | t0 = time.perf_counter() 1189 | 1190 | # Stop tracking CUDA memory now that active steps are complete 1191 | if ( 1192 | self._is_rank_zero 1193 | and curr_epoch == 0 1194 | and self.profiler_profile_memory 1195 | and idx 1196 | == self.profiler_wait_steps 1197 | + self.profiler_warmup_steps 1198 | + self.profiler_active_steps 1199 | ): 1200 | torch.cuda.memory._record_memory_history( 1201 | # Pylance infers the type of `enabled` as `str` though the function accepts `Literal[None, "state", "all"]` 1202 | enabled=None # type: ignore 1203 | ) 1204 | 1205 | # Step profiler 1206 | # Note that this is called within gradient accumulation block, hence 1207 | # will include multiple forward / backward passes if gradient accumulation > 1 1208 | self._profiler.step() 1209 | 1210 | self.epochs_run += 1 1211 | self.save_checkpoint(epoch=curr_epoch) 1212 | 1213 | self._profiler.stop() 1214 | 1215 | def cleanup(self) -> None: 1216 | if self._is_rank_zero: 1217 | self._metric_logger.close() 1218 | if training.is_distributed(): 1219 | torch.distributed.destroy_process_group() 1220 | training.cleanup_before_training() 1221 | 1222 | def _swap_state( 1223 | self, state_dict: Dict[str, Any], assign: bool = False 1224 | ) -> Dict[str, Any]: 1225 | """ 1226 | Swaps the current model state with the provided state dict. 1227 | Manages GPU memory by moving states to CPU/device in the right order. 1228 | 1229 | Args: 1230 | state_dict: Dictionary of state to load into model 1231 | 1232 | Returns: 1233 | Original model state dict (moved to CPU) 1234 | """ 1235 | # Save current model state and move to CPU 1236 | current_state = { 1237 | k: v.to("cpu", non_blocking=True) 1238 | for k, v in self._model.state_dict().items() 1239 | } 1240 | 1241 | # Move input state to device and load 1242 | device_state = { 1243 | k: v.to(self._device, non_blocking=True) for k, v in state_dict.items() 1244 | } 1245 | self._model.load_state_dict(device_state, assign=assign) 1246 | 1247 | return current_state 1248 | 1249 | 1250 | def recipe_main(cfg: TuneRecipeConfig) -> None: 1251 | """ 1252 | Entry point for the recipe. 1253 | 1254 | Configurable parameters are read in the following order: 1255 | - Parameters specified in config (see available configs through ``tune ls``) 1256 | - Overwritten by arguments from the command-line 1257 | """ 1258 | if not training.is_distributed(): 1259 | log.debug( 1260 | "Training is not distributed. If you want to train on multiple GPUs and are using the tune CLI, specify --nnodes 1 and --nproc_per_node [num_gpus]" 1261 | ) 1262 | elif not torch.distributed.is_initialized(): 1263 | torch.distributed.init_process_group(backend="cuda:nccl,cpu:gloo") 1264 | 1265 | if cfg.get("fsdp_cpu_offload", False): 1266 | # Utilize all available CPU cores for intra-op parallelism. 
This provides ~2x 1267 | # speed up when benchmarking fused AdamW on CPU 1268 | training.set_torch_num_threads() 1269 | 1270 | config.log_config( 1271 | recipe_name="FullFinetuneRecipe", 1272 | cfg=cfg.dict_config() if isinstance(cfg, TuneRecipeConfig) else cfg, 1273 | ) 1274 | 1275 | recipe = TuneRecipe(cfg=cfg) 1276 | recipe.setup(cfg=cfg) 1277 | recipe.train() 1278 | recipe.cleanup() 1279 | 1280 | 1281 | if __name__ == "__main__": 1282 | sys.exit(config.parse(recipe_main)()) # type: ignore 1283 | -------------------------------------------------------------------------------- /lib/stream.py: -------------------------------------------------------------------------------- 1 | # https://gist.github.com/bradhilton/ec2450881abb0d0ef6a1bea7f8ca824f 2 | from openai import AsyncStream 3 | from openai.types.chat.chat_completion import ChatCompletion, Choice, ChoiceLogprobs 4 | from openai.types.chat.chat_completion_chunk import ChatCompletionChunk 5 | from openai.types.chat.chat_completion_message import ( 6 | ChatCompletionMessage, 7 | FunctionCall, 8 | ) 9 | from openai.types.chat.chat_completion_message_tool_call import ( 10 | ChatCompletionMessageToolCall, 11 | Function, 12 | ) 13 | from typing import Any, Callable 14 | 15 | 16 | async def consume_chat_completion_stream( 17 | stream: AsyncStream[ChatCompletionChunk], 18 | on_chunk: Callable[[ChatCompletionChunk, ChatCompletion], Any] | None = None, 19 | ) -> ChatCompletion: 20 | """Consume a chat completion stream and build a complete ChatCompletion object. 21 | 22 | This function processes a stream of ChatCompletionChunks, constructing a complete 23 | ChatCompletion object as if it was returned from a non-streaming API call. 24 | Works with any OpenAI-compatible API implementation. 25 | 26 | Args: 27 | stream: An AsyncStream of ChatCompletionChunk objects. 28 | on_chunk: Optional callback that receives each chunk and the current state of the 29 | ChatCompletion. If the callback raises StopIteration, the stream will close early. 30 | 31 | Returns: 32 | A complete ChatCompletion object built from the streamed chunks. 33 | 34 | Raises: 35 | AssertionError: If no chat completion object could be created. 
36 | """ 37 | chat_completion: ChatCompletion | None = None 38 | async for chunk in stream: 39 | if chat_completion is None: 40 | chat_completion = ChatCompletion( 41 | id=chunk.id, 42 | choices=[ 43 | Choice( 44 | finish_reason="stop", 45 | index=choice.index, 46 | logprobs=(ChoiceLogprobs() if choice.logprobs else None), 47 | message=ChatCompletionMessage(role="assistant"), 48 | ) 49 | for choice in chunk.choices 50 | ], 51 | created=chunk.created, 52 | model=chunk.model, 53 | object="chat.completion", 54 | ) 55 | for choice, chunk_choice in zip(chat_completion.choices, chunk.choices): 56 | choice.finish_reason = chunk_choice.finish_reason or "stop" 57 | if chunk_choice.logprobs: 58 | if choice.logprobs is None: 59 | choice.logprobs = ChoiceLogprobs() 60 | if chunk_choice.logprobs.content: 61 | if choice.logprobs.content is None: 62 | choice.logprobs.content = [] 63 | choice.logprobs.content.extend(chunk_choice.logprobs.content) 64 | if chunk_choice.logprobs.refusal: 65 | if choice.logprobs.refusal is None: 66 | choice.logprobs.refusal = [] 67 | choice.logprobs.refusal.extend(chunk_choice.logprobs.refusal) 68 | if chunk_choice.delta.content: 69 | if choice.message.content is None: 70 | choice.message.content = "" 71 | choice.message.content += chunk_choice.delta.content 72 | if chunk_choice.delta.refusal: 73 | if choice.message.refusal is None: 74 | choice.message.refusal = "" 75 | choice.message.refusal += chunk_choice.delta.refusal 76 | if chunk_choice.delta.function_call: 77 | if choice.message.function_call is None: 78 | choice.message.function_call = FunctionCall(arguments="", name="") 79 | choice.message.function_call.name += ( 80 | chunk_choice.delta.function_call.name or "" 81 | ) 82 | choice.message.function_call.arguments += ( 83 | chunk_choice.delta.function_call.arguments or "" 84 | ) 85 | if chunk_choice.delta.tool_calls: 86 | if choice.message.tool_calls is None: 87 | choice.message.tool_calls = [] 88 | for tool_call in chunk_choice.delta.tool_calls: 89 | if not tool_call.index in range(len(choice.message.tool_calls)): 90 | choice.message.tool_calls.append( 91 | ChatCompletionMessageToolCall( 92 | id="", 93 | function=Function(arguments="", name=""), 94 | type="function", 95 | ) 96 | ) 97 | choice.message.tool_calls[tool_call.index].id += tool_call.id or "" 98 | choice.message.tool_calls[tool_call.index].function.name += ( 99 | tool_call.function.name or "" if tool_call.function else "" 100 | ) 101 | choice.message.tool_calls[tool_call.index].function.arguments += ( 102 | tool_call.function.arguments or "" if tool_call.function else "" 103 | ) 104 | if getattr(chunk_choice.delta, "reasoning", None): 105 | if not hasattr(choice.message, "reasoning"): 106 | setattr(choice.message, "reasoning", "") 107 | setattr( 108 | choice.message, 109 | "reasoning", 110 | getattr(choice.message, "reasoning") 111 | + getattr(chunk_choice.delta, "reasoning"), 112 | ) 113 | chat_completion.service_tier = chunk.service_tier 114 | chat_completion.system_fingerprint = chunk.system_fingerprint 115 | chat_completion.usage = chunk.usage 116 | if on_chunk: 117 | try: 118 | on_chunk(chunk, chat_completion) 119 | except StopIteration: 120 | await stream.close() 121 | break 122 | assert chat_completion is not None 123 | return chat_completion 124 | -------------------------------------------------------------------------------- /lib/tasks.py: -------------------------------------------------------------------------------- 1 | from aioitertools.helpers import maybe_await 2 | import asyncio 3 | from 
dataclasses import dataclass, field 4 | import numpy as np 5 | from openai import AsyncOpenAI 6 | from openai.types.chat.chat_completion import ChatCompletion, Choice 7 | from openai.types.chat.chat_completion_chunk import ChatCompletionChunk 8 | from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam 9 | from openai.types.completion_usage import CompletionUsage 10 | import random 11 | from typing import Awaitable, Callable, TypeVar 12 | 13 | from .chat_completions import get_chat_completion 14 | from .types import ChatCompletionParams 15 | from .tqdm import tqdm 16 | 17 | # A grader function returns a floating-point reward, 18 | # and optionally a dictionary of metrics as the second return value 19 | Grade = float | tuple[float, dict[str, float]] 20 | Grader = Callable[[Choice], Grade | Awaitable[Grade]] 21 | 22 | 23 | @dataclass 24 | class Task: 25 | """ 26 | A minimal task definition. 27 | 28 | Args: 29 | messages (list[ChatCompletionMessageParam]): OpenAI API compatible chat messages for prompting the model 30 | grader (Grader): A grader function to score the model's responses 31 | """ 32 | 33 | messages: list[ChatCompletionMessageParam] 34 | grader: Grader 35 | 36 | 37 | @dataclass 38 | class TaskResult: 39 | """ 40 | A single task result. 41 | 42 | Args: 43 | task (Task): The task that was graded 44 | chat_completions (list[ChatCompletion]): The chat completions generated for the task 45 | rewards (dict[tuple[str, int], float]): Rewards for each chat completion and choice index 46 | metrics (dict[tuple[str, int], dict[str, float]]): Metrics for each chat completion and choice index 47 | advantages (dict[tuple[str, int], float]): GRPO advantages for each chat completion and choice index 48 | exceptions (list[Exception]): Exceptions that occurred while getting the task result 49 | """ 50 | 51 | task: Task 52 | chat_completions: list[ChatCompletion] 53 | rewards: dict[tuple[str, int], float] 54 | metrics: dict[tuple[str, int], dict[str, float]] 55 | advantages: dict[tuple[str, int], float] 56 | exceptions: list[Exception] 57 | 58 | 59 | @dataclass 60 | class TaskResultStats: 61 | """ 62 | Statistics for task results. 
63 | 64 | Args: 65 | pbar (tqdm.tqdm): The progress bar 66 | prices (tuple[float, float] | None): Prices for input/output tokens 67 | completion_tokens (int): Total completion tokens 68 | exceptions (list[Exception]): Exceptions that occurred while getting the task results 69 | grades (int): Number of grades 70 | new_completion_ids (set[str]): Set of new completion IDs 71 | new_completion_tokens (int): Total new completion tokens 72 | new_prompt_tokens (int): Total new prompt tokens 73 | prompt_tokens (int): Total prompt tokens 74 | token_logprobs (int): Total token log probabilities 75 | total_metrics (dict[str, float]): Total metrics 76 | total_reward (float): Total reward 77 | usages (int): Total usages 78 | """ 79 | 80 | pbar: tqdm.tqdm 81 | prices: tuple[float, float] | None 82 | completion_tokens: int = 0 83 | exceptions: list[Exception] = field(default_factory=list) 84 | grades: int = 0 85 | new_completion_ids: set[str] = field(default_factory=set) 86 | new_completion_tokens: int = 0 87 | new_prompt_tokens: int = 0 88 | prompt_tokens: int = 0 89 | token_logprobs: int = 0 90 | total_metrics: dict[str, float] = field(default_factory=dict) 91 | total_reward: float = 0 92 | usages: int = 0 93 | 94 | def __del__(self) -> None: 95 | self.pbar.close() 96 | 97 | def update( 98 | self, 99 | *, 100 | id: str | None, 101 | chunk: ChatCompletionChunk | None = None, 102 | usage: CompletionUsage | None = None, 103 | reward: float | None = None, 104 | metrics: dict[str, float] | None = None, 105 | exception: Exception | None = None, 106 | ) -> None: 107 | if chunk: 108 | if id is not None: 109 | self.new_completion_ids.add(id) 110 | self.token_logprobs += sum( 111 | len(choice.logprobs.content or choice.logprobs.refusal or []) 112 | for choice in chunk.choices 113 | if choice.logprobs 114 | ) 115 | elif usage: 116 | self.completion_tokens += usage.completion_tokens 117 | self.prompt_tokens += usage.prompt_tokens 118 | self.usages += 1 119 | if id in self.new_completion_ids: 120 | self.new_completion_tokens += usage.completion_tokens 121 | self.new_prompt_tokens += usage.prompt_tokens 122 | elif reward is not None: 123 | self.grades += 1 124 | self.total_reward += reward 125 | self.pbar.update() 126 | if metrics: 127 | for key, value in metrics.items(): 128 | if key not in self.total_metrics: 129 | self.total_metrics[key] = 0 130 | self.total_metrics[key] += value 131 | elif exception: 132 | self.exceptions.append(exception) 133 | postfix = { 134 | "completion_tokens": round(self.completion_tokens / max(self.usages, 1)), 135 | "prompt_tokens": round(self.prompt_tokens / max(self.usages, 1)), 136 | "reward": self.total_reward / max(self.grades, 1), 137 | } 138 | for key, value in self.total_metrics.items(): 139 | postfix[key] = value / max(self.grades, 1) 140 | if self.prices: 141 | postfix["spend"] = ( 142 | f"${( 143 | self.new_prompt_tokens * self.prices[0] 144 | + (self.token_logprobs or self.new_completion_tokens) * self.prices[1] 145 | ) / 1_000_000:.2f}" 146 | ) 147 | if self.token_logprobs: 148 | postfix["token_logprobs"] = self.token_logprobs 149 | if self.exceptions: 150 | postfix["exceptions"] = len(self.exceptions) 151 | self.pbar.set_postfix(postfix) 152 | 153 | 154 | T = TypeVar("T") 155 | 156 | 157 | class TaskResults(list[T]): 158 | stats: TaskResultStats 159 | 160 | 161 | async def get_task_results( 162 | tasks: list[Task], 163 | client: AsyncOpenAI, 164 | model: str, 165 | clear_pbar: bool = False, 166 | print_pbar: bool = True, 167 | log_dir: str | None = None, 168 | log_results: 
bool | float | int = True, 169 | log_token_logprobs: bool = True, 170 | n: int = 1, 171 | on_chunk: Callable[[ChatCompletionChunk, ChatCompletion], None] | None = None, 172 | params: ChatCompletionParams | None = None, 173 | pbar_desc: str | None = None, 174 | pbar_position: int | None = None, 175 | prices: tuple[float, float] | None = None, 176 | semaphore: asyncio.Semaphore | None = None, 177 | transform: Callable[[TaskResult], T | Awaitable[T]] = lambda x: x, 178 | ) -> TaskResults[T]: 179 | """ 180 | Returns results for tasks using an AsyncOpenAI client for a given model. Includes support for caching, rate limiting, and logging. Results may be optionally transformed. 181 | 182 | Args: 183 | tasks (list[Task]): List of Task objects, each containing messages to send to the LLM and a grader function 184 | client (AsyncOpenAI): Any valid AsyncOpenAI client that supports creating chat completions (may be pointed at API providers other than OpenAI or a local inference engine like vLLM) 185 | model (str): Name of the chat completion model to use 186 | clear_pbar (bool): Whether to clear the progress bar after completion 187 | print_pbar (bool): Whether to print the progress bar summary after completion 188 | log_dir (str | None): Directory to save completion logs to. If None, will use the default chat completion log directory 189 | log_results (bool | float | int): Controls which task results to log. Can be a boolean, float (fraction), or int (count) 190 | log_token_logprobs (bool): Whether to stream token log probabilities count to the progress bar 191 | n (int): Number of chat completions to sample per task 192 | on_chunk (Callable[[ChatCompletionChunk, ChatCompletion], None] | None): Optional callback function for processing completion chunks 193 | params (ChatCompletionParams | None): Additional parameters to pass to the chat completion API 194 | pbar_desc (str | None): Description to display on the progress bar 195 | pbar_position (int | None): Position of the progress bar 196 | prices (tuple[float, float] | None): Tuple of (input_price, output_price) per million tokens, for cost tracking 197 | semaphore (asyncio.Semaphore | None): Optional semaphore for limiting concurrent API calls 198 | transform (Callable[[TaskResult], T | Awaitable[T]]): Function to transform TaskResult objects before returning 199 | 200 | Returns: 201 | TaskResults[T]: Processed results and statistics 202 | 203 | Process: Runs model inference → evaluates with graders → computes rewards/advantages → 204 | tracks metrics → transforms results 205 | """ 206 | num_completions = len(tasks) * n 207 | pbar = tqdm.tqdm(total=num_completions, desc=pbar_desc, position=pbar_position) 208 | stats = TaskResultStats(pbar=pbar, prices=prices) 209 | results = TaskResults( 210 | await asyncio.gather( 211 | *( 212 | _get_task_result( 213 | task=task, 214 | client=client, 215 | model=model, 216 | log_results=log_results, 217 | n=n, 218 | log_dir=log_dir, 219 | on_chunk=_create_on_chunk_callback( 220 | log_token_logprobs, on_chunk, stats 221 | ), 222 | semaphore=semaphore, 223 | params=params, 224 | stats=stats, 225 | transform=transform, 226 | ) 227 | for task, log_results in zip( 228 | tasks, _get_log_results_flags(log_results, len(tasks)) 229 | ) 230 | ) 231 | ) 232 | ) 233 | results.stats = stats 234 | pbar.close() 235 | if getattr(pbar, "container", None) and clear_pbar: 236 | pbar.container.close() 237 | if getattr(pbar, "container", None) and print_pbar: 238 | print(pbar.container.__repr__(pretty=True)) 239 | return results 240 
| 241 | 242 | def _create_on_chunk_callback( 243 | log_token_logprobs: bool, 244 | on_chunk: Callable[[ChatCompletionChunk, ChatCompletion], None] | None, 245 | stats: TaskResultStats, 246 | ) -> Callable[[ChatCompletionChunk, ChatCompletion], None] | None: 247 | "Create a callback function that logs token logprobs and/or calls the user's on_chunk callback if provided." 248 | if not log_token_logprobs and not on_chunk: 249 | return None 250 | 251 | def _on_chunk(chunk: ChatCompletionChunk, completion: ChatCompletion) -> None: 252 | if log_token_logprobs: 253 | stats.update(id=chunk.id, chunk=chunk) 254 | if on_chunk: 255 | on_chunk(chunk, completion) 256 | 257 | return _on_chunk 258 | 259 | 260 | def _get_log_results_flags(log_results: bool | float | int, n: int) -> list[bool]: 261 | "Return a list of flags indicating whether to log results for each task." 262 | if isinstance(log_results, int) and log_results >= 1: 263 | result = [True] * log_results + [False] * max(n - log_results, 0) 264 | elif isinstance(log_results, float): 265 | result = [True] * int(log_results * n) + [False] * (n - int(log_results * n)) 266 | else: 267 | result = [bool(log_results)] * n 268 | random.shuffle(result) 269 | return result 270 | 271 | 272 | async def _get_task_result( 273 | task: Task, 274 | client: AsyncOpenAI, 275 | model: str, 276 | log_results: bool, 277 | n: int, 278 | log_dir: str | None, 279 | on_chunk: Callable[[ChatCompletionChunk, ChatCompletion], None] | None, 280 | semaphore: asyncio.Semaphore | None, 281 | params: ChatCompletionParams | None, 282 | stats: TaskResultStats, 283 | transform: Callable[[TaskResult], T | Awaitable[T]], 284 | ) -> T: 285 | """ 286 | Processes a single task by generating chat completions, grading responses, and calculating rewards. 287 | 288 | This is a helper function called by get_task_results for each task in the list. It: 289 | 1. Makes n API calls to generate completions for the task 290 | 2. Collects all completions and passes them to the task's grader function 291 | 3. Tracks usage statistics and handles exceptions 292 | 4. Calculates GRPO advantages based on reward distribution 293 | 5. Applies the transform function to the TaskResult before returning 294 | 295 | See get_task_results for parameter descriptions. 
296 | """ 297 | # always request logprobs, unless explicitly disabled 298 | _params = (params or {}).copy() 299 | if "logprobs" not in _params: 300 | _params["logprobs"] = True 301 | elif _params["logprobs"] is None: 302 | del _params["logprobs"] 303 | chat_completions: list[ChatCompletion] = [] 304 | rewards: dict[tuple[str, int], float] = {} 305 | metrics: dict[tuple[str, int], dict[str, float]] = {} 306 | exceptions: list[Exception] = [] 307 | for chat_completion_future in asyncio.as_completed( 308 | get_chat_completion( 309 | client, 310 | log_dir=log_dir, 311 | log_results=log_results and i == 0, 312 | on_chunk=on_chunk, 313 | semaphore=semaphore, 314 | messages=task.messages, 315 | model=model, 316 | **_params, # type: ignore 317 | ) 318 | for i in range(n) 319 | ): 320 | try: 321 | chat_completion = await chat_completion_future 322 | chat_completions.append(chat_completion) 323 | stats.update(id=chat_completion.id, usage=chat_completion.usage) 324 | 325 | async def _grade(choice: Choice, grader: Grader) -> tuple[int, Grade]: 326 | return choice.index, await maybe_await(grader(choice)) 327 | 328 | for grade_future in asyncio.as_completed( 329 | _grade(choice, task.grader) for choice in chat_completion.choices 330 | ): 331 | try: 332 | choice_index, grade = await grade_future 333 | reward, _metrics = ( 334 | grade if isinstance(grade, tuple) else (grade, {}) 335 | ) 336 | stats.update(id=chat_completion.id, reward=reward, metrics=_metrics) 337 | rewards[chat_completion.id, choice_index] = reward 338 | metrics[chat_completion.id, choice_index] = _metrics 339 | except Exception as e: 340 | exceptions.append(e) 341 | stats.update(id=chat_completion.id, exception=e) 342 | continue 343 | except Exception as e: 344 | exceptions.append(e) 345 | stats.update(id=None, exception=e) 346 | continue 347 | if rewards: 348 | reward_mean = np.mean(list(rewards.values())) 349 | reward_std = np.std(list(rewards.values())) 350 | # calculate GRPO advantages 351 | advantages = { 352 | key: float((reward - reward_mean) / (reward_std + 1e-6)) 353 | for key, reward in rewards.items() 354 | } 355 | else: 356 | advantages = {key: 0.0 for key in rewards.keys()} 357 | return await maybe_await( 358 | transform( 359 | TaskResult( 360 | task=task, 361 | chat_completions=chat_completions, 362 | rewards=rewards, 363 | metrics=metrics, 364 | advantages=advantages, 365 | exceptions=exceptions, 366 | ) 367 | ) 368 | ) 369 | -------------------------------------------------------------------------------- /lib/temporal_clue.py: -------------------------------------------------------------------------------- 1 | import json 2 | from openai.types.chat.chat_completion import Choice 3 | import re 4 | from typing import Iterable, TypedDict 5 | 6 | from .tasks import Task 7 | 8 | 9 | class TemporalCluePuzzle(TypedDict): 10 | num_clues: int 11 | prompt: str 12 | solution: dict[str, str] 13 | 14 | 15 | def get_temporal_clue_puzzles() -> list[TemporalCluePuzzle]: 16 | return json.load(open("./data/puzzles.json")) 17 | 18 | 19 | def get_temporal_clue_tasks() -> Iterable[Task]: 20 | for puzzle in get_temporal_clue_puzzles(): 21 | 22 | def grader(choice: Choice, puzzle: TemporalCluePuzzle = puzzle) -> float: 23 | content = choice.message.content 24 | assert isinstance(content, str) 25 | num_correct = 0 26 | for key, value in puzzle["solution"].items(): 27 | if matches := re.findall(rf"{key}\. 
([A-Za-z \.:-]+)", content): 28 | match = matches[-1] 29 | if match.strip().lower() == value.lower(): 30 | num_correct += 1 31 | return num_correct / len(puzzle["solution"]) 32 | 33 | yield Task( 34 | messages=[ 35 | { 36 | "role": "user", 37 | "content": puzzle["prompt"], 38 | } 39 | ], 40 | grader=grader, 41 | ) 42 | -------------------------------------------------------------------------------- /lib/tokenize.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from itertools import takewhile 3 | from openai.types.chat.chat_completion import Choice 4 | from openai.types.chat.chat_completion_token_logprob import ChatCompletionTokenLogprob 5 | import random 6 | from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast 7 | from typing import cast 8 | 9 | 10 | from .tasks import TaskResult 11 | 12 | 13 | @dataclass 14 | class TokenizedResult: 15 | conversation: list 16 | reward: float 17 | advantage: float 18 | chat_template: str 19 | chat: str 20 | tokens: list[str] 21 | token_ids: list[int] 22 | input_pos: list[int] 23 | assistant_mask: list[int] 24 | token_logprobs: list[ChatCompletionTokenLogprob] | None 25 | prompt_id: int = 0 26 | prompt_length: int = 0 27 | 28 | def without_prompt(self) -> "TokenizedResult": 29 | assistant_mask = self.assistant_mask[self.prompt_length :] 30 | return TokenizedResult( 31 | conversation=self.conversation, 32 | advantage=self.advantage, 33 | reward=self.reward, 34 | chat_template=self.chat_template, 35 | chat=self.chat, 36 | tokens=self.tokens[self.prompt_length :], 37 | token_ids=self.token_ids[self.prompt_length :], 38 | input_pos=self.input_pos[self.prompt_length :], 39 | assistant_mask=assistant_mask, 40 | token_logprobs=( 41 | self.token_logprobs[len(self.token_logprobs) - sum(assistant_mask) :] 42 | if self.token_logprobs is not None 43 | else None 44 | ), 45 | prompt_id=self.prompt_id, 46 | prompt_length=0, 47 | ) 48 | 49 | 50 | class TaskResultTokenizer: 51 | def __init__( 52 | self, 53 | pretrained_tokenizer_or_model_name_or_path: ( 54 | PreTrainedTokenizer | PreTrainedTokenizerFast | str 55 | ), 56 | ) -> None: 57 | self.tokenizer = ( 58 | AutoTokenizer.from_pretrained(pretrained_tokenizer_or_model_name_or_path) 59 | if isinstance(pretrained_tokenizer_or_model_name_or_path, str) 60 | else pretrained_tokenizer_or_model_name_or_path 61 | ) 62 | 63 | def __call__(self, task_result: TaskResult) -> list[TokenizedResult]: 64 | chat_completions = task_result.chat_completions.copy() 65 | random.shuffle(chat_completions) 66 | tokenized_results = [ 67 | self._tokenized_result( 68 | task_result, 69 | choice, 70 | task_result.rewards.get((chat_completion.id, choice.index), 0), 71 | task_result.advantages.get((chat_completion.id, choice.index), 0), 72 | ) 73 | for chat_completion in chat_completions 74 | for choice in chat_completion.choices 75 | ] 76 | prompt_id = random.randint(-(2**63), 2**63 - 1) 77 | prompt_length = len( 78 | list( 79 | takewhile( 80 | lambda x: len(set(x)) == 1, 81 | zip(*(r.token_ids for r in tokenized_results)), 82 | ) 83 | ) 84 | ) 85 | for result in tokenized_results: 86 | result.prompt_id = prompt_id 87 | result.prompt_length = prompt_length 88 | # zero out assistant prompt tokens 89 | result.assistant_mask[:prompt_length] = [0] * prompt_length 90 | return tokenized_results 91 | 92 | def _tokenized_result( 93 | self, task_result: TaskResult, choice: Choice, reward: float, advantage: float 94 | ) -> TokenizedResult: 95 | 
conversation: list = task_result.task.messages + [ 96 | { 97 | "role": "assistant", 98 | "content": choice.message.content, 99 | } 100 | ] 101 | chat_template = update_chat_template(self.tokenizer.get_chat_template()) 102 | chat = cast( 103 | str, 104 | self.tokenizer.apply_chat_template( 105 | conversation, chat_template=chat_template, tokenize=False 106 | ), 107 | ) 108 | tokenized_result = cast( 109 | dict[str, list[int]], 110 | self.tokenizer.apply_chat_template( 111 | conversation, 112 | chat_template=chat_template, 113 | return_dict=True, 114 | return_assistant_tokens_mask=True, 115 | ), 116 | ) 117 | if ( 118 | choice.logprobs 119 | and choice.logprobs.content 120 | and choice.logprobs.content[0].token.startswith("token_id:") 121 | ): 122 | start = tokenized_result["assistant_masks"].index(1) 123 | try: 124 | end = start + tokenized_result["assistant_masks"][start:].index(0) 125 | except ValueError: 126 | end = len(tokenized_result["assistant_masks"]) 127 | tokenized_result["input_ids"][start:end] = [ 128 | int(token_logprob.token.split(":")[1]) 129 | for token_logprob in choice.logprobs.content 130 | ] 131 | tokenized_result["assistant_masks"][start:end] = [ 132 | 1 for _ in choice.logprobs.content 133 | ] 134 | token_logprobs = choice.logprobs.content 135 | else: 136 | token_logprobs = None 137 | tokens = [ 138 | self.tokenizer.decode(token_id) 139 | for token_id in tokenized_result["input_ids"] 140 | ] 141 | if token_logprobs is None: 142 | token_logprobs = self.get_token_logprobs( 143 | choice, 144 | [ 145 | token 146 | for token, mask in zip(tokens, tokenized_result["assistant_masks"]) 147 | if mask 148 | ], 149 | ) 150 | return TokenizedResult( 151 | conversation=conversation, 152 | reward=reward, 153 | advantage=advantage, 154 | chat_template=chat_template, 155 | chat=chat, 156 | tokens=tokens, 157 | token_ids=tokenized_result["input_ids"], 158 | input_pos=list(range(len(tokens))), 159 | assistant_mask=tokenized_result["assistant_masks"], 160 | token_logprobs=token_logprobs, 161 | ) 162 | 163 | def get_token_logprobs( 164 | self, 165 | choice: Choice, 166 | assistant_tokens: list[str], 167 | ) -> list[ChatCompletionTokenLogprob] | None: 168 | if not choice.logprobs: 169 | return None 170 | if not choice.logprobs.content: 171 | return None 172 | result_token_logprobs = choice.logprobs.content.copy() 173 | if "".join(assistant_tokens) != "".join( 174 | token_logprob.token for token_logprob in result_token_logprobs 175 | ) and len(assistant_tokens) != len(result_token_logprobs): 176 | print("Assistant tokens are not equal, skipping token logprobs") 177 | return None 178 | elif assistant_tokens == [ 179 | token_logprob.token for token_logprob in result_token_logprobs 180 | ]: 181 | return result_token_logprobs 182 | else: 183 | completion = "" 184 | result_completion = "" 185 | token_logprobs = [] 186 | try: 187 | while True: 188 | if completion == result_completion: 189 | token = assistant_tokens.pop(0) 190 | result_token_logprob = result_token_logprobs.pop(0) 191 | result_token = result_token_logprob.token 192 | if token == result_token: 193 | token_logprobs.append(result_token_logprob) 194 | else: 195 | token_logprobs.append( 196 | ChatCompletionTokenLogprob( 197 | token=token, 198 | logprob=float("nan"), 199 | top_logprobs=[], 200 | ) 201 | ) 202 | completion += token 203 | result_completion += result_token 204 | elif len(completion) < len(result_completion): 205 | token = assistant_tokens.pop(0) 206 | token_logprobs.append( 207 | ChatCompletionTokenLogprob( 208 | 
token=token, 209 | logprob=float("nan"), 210 | top_logprobs=[], 211 | ) 212 | ) 213 | completion += token 214 | elif len(completion) > len(result_completion): 215 | result_completion += result_token_logprobs.pop(0).token 216 | else: 217 | print("Warning: Completions are not equal") 218 | print(f"Completion: {completion}") 219 | print(f"Result completion: {result_completion}") 220 | token_logprobs = None 221 | break 222 | except IndexError: 223 | pass 224 | return token_logprobs 225 | 226 | 227 | def update_chat_template(chat_template: str) -> str: 228 | return ( 229 | chat_template 230 | # Remove template logic that strips reasoning content from the chat messages 231 | .replace( 232 | "{% if '' in content %}{% set content = content.split('')[-1] %}{% endif %}", 233 | "", 234 | ) 235 | # Add generation tags for assistant token masking 236 | .replace( 237 | "{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}", 238 | "{{'<|Assistant|>'}}{% generation %}{{ content }}{% endgeneration %}{{'<|end▁of▁sentence|>'}}", 239 | ) 240 | # Add generation tags for assistant token masking (for Hermes 2 Theta) 241 | .replace( 242 | "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}", 243 | "{{'<|im_start|>' + message['role'] + '\n'}}{% if message['role'] == 'assistant' %}{% generation %}{{ message['content'] }}{% endgeneration %}{% else %}{{ message['content'] }}{% endif %}{{'<|im_end|>' + '\n'}}", 244 | ) 245 | # Add generation tags for assistant token masking (for Qwen 2.5 Instruct) 246 | .replace( 247 | """ 248 | {%- if (message.role == "user") or (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) %} 249 | {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }} 250 | """.strip(), 251 | """ 252 | {%- if (message.role == "user") or (message.role == "system" and not loop.first) %} 253 | {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }} 254 | {%- elif message.role == "assistant" and not message.tool_calls %} 255 | {{- '<|im_start|>' + message.role + '\\n' }}{% generation %}{{ message.content }}{% endgeneration %}{{ '<|im_end|>' + '\\n' }} 256 | """.strip(), 257 | ).replace( 258 | """ 259 | {%- elif message.role == "assistant" %} 260 | {{- '<|im_start|>' + message.role }} 261 | {%- if message.content %} 262 | {{- '\\n' + message.content }} 263 | {%- endif %}""".strip(), 264 | """ 265 | {%- elif message.role == "assistant" %} 266 | {{- '<|im_start|>' + message.role }} 267 | {%- if message.content %} 268 | {{- '\\n' }}{% generation %}{{ message.content }}{% endgeneration %} 269 | {%- endif %}""".strip(), 270 | ) 271 | # Add generation tags for assistant token masking (for Llama 3.3 70B) 272 | .replace( 273 | "{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}", 274 | "{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n' }}{%- if message['role'] == 'assistant' %}{% generation %}{{ message['content'] | trim + '<|eot_id|>' }}{% endgeneration %}{% else %}{{ message['content'] | trim + '<|eot_id|>' }}{% endif %}", 275 | ) 276 | ) 277 | -------------------------------------------------------------------------------- /lib/tqdm.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | try: 4 | sys.modules["IPython"].get_ipython 5 | from tqdm import notebook as tqdm 6 | except: 7 | from tqdm import std as tqdm 8 | 
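
The pieces above fit together as a rollout-and-grade pipeline: `get_temporal_clue_tasks` yields `Task` objects with puzzle prompts and graders, `get_task_results` samples `n` completions per task, grades them, and computes GRPO advantages, and `TaskResultTokenizer` converts each `TaskResult` into tokenized training examples. Below is a minimal sketch of that flow; the server URL, API key, and model name are illustrative assumptions (any OpenAI-compatible endpoint serving the chosen model should work), not values taken from this repository.

```python
# Minimal sketch: roll out, grade, and tokenize Temporal Clue tasks.
# Assumptions (not from this repo): an OpenAI-compatible server is already
# running at http://localhost:8000/v1 and serves the model named below.
import asyncio

from openai import AsyncOpenAI

from lib.tasks import get_task_results
from lib.temporal_clue import get_temporal_clue_tasks
from lib.tokenize import TaskResultTokenizer


async def main() -> None:
    client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    model = "Qwen/Qwen2.5-14B-Instruct"  # illustrative model name

    tasks = list(get_temporal_clue_tasks())[:8]  # small subset for a quick check
    tokenize = TaskResultTokenizer(model)  # loads the matching HF tokenizer

    # Sample 4 completions per puzzle, grade each one, compute GRPO
    # advantages within the group, and tokenize the results.
    results = await get_task_results(
        tasks,
        client,
        model=model,
        n=4,
        transform=tokenize,
    )
    print(sum(len(r) for r in results), "tokenized trajectories")
    print("mean reward:", results.stats.total_reward / max(results.stats.grades, 1))


if __name__ == "__main__":
    asyncio.run(main())
```

Because the tokenizer is passed as `transform`, each `TaskResult` is tokenized as soon as its task finishes grading, rather than in a separate pass over all results.
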
-------------------------------------------------------------------------------- /lib/tune.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import glob 3 | from lib.pack import PackedDataset, PackedTensors, packed_tensors_to_dir 4 | from lib.recipe import ComponentConfig, recipe_main, TuneRecipeConfig 5 | from omegaconf import OmegaConf 6 | import os 7 | import re 8 | import shutil 9 | import sys 10 | import torch 11 | from torchtune.modules import TransformerDecoder 12 | from torchtune.training import cleanup_before_training, FullModelHFCheckpointer 13 | from torchtune.training.metric_logging import DiskLogger 14 | import tqdm 15 | from typing import Any, Callable, Literal, IO 16 | 17 | 18 | Verbosity = Literal[0, 1, 2] 19 | 20 | 21 | def clear_iteration_dirs(output_dir: str, excluding: list[int]) -> None: 22 | for dir in os.listdir(output_dir): 23 | if ( 24 | os.path.isdir(os.path.join(output_dir, dir)) 25 | and dir.isdigit() 26 | and int(dir) not in excluding 27 | ): 28 | iteration_dir = os.path.join(output_dir, dir) 29 | shutil.rmtree(iteration_dir) 30 | print(f"Deleted iteration directory {iteration_dir}") 31 | 32 | 33 | def get_iteration(output_dir: str) -> int: 34 | os.makedirs(output_dir, exist_ok=True) 35 | return max( 36 | ( 37 | int(subdir) 38 | for subdir in os.listdir(output_dir) 39 | if os.path.isdir(os.path.join(output_dir, subdir)) and subdir.isdigit() 40 | ), 41 | default=0, 42 | ) 43 | 44 | 45 | def get_last_iteration_dir(output_dir: str) -> str | None: 46 | last_iteration_dir = os.path.join(output_dir, f"{get_iteration(output_dir):04d}") 47 | return last_iteration_dir if os.path.exists(last_iteration_dir) else None 48 | 49 | 50 | def last_tune_log(output_dir: str) -> list[dict[str, float]]: 51 | sorted_logs = sorted(glob.glob(f"{output_dir}/logs/*")) 52 | contents = open(sorted_logs[-1]).read() 53 | lines = contents.strip().splitlines() 54 | parsed_logs = [] 55 | for line in lines: 56 | step_part, metrics_part = line.split(" | ") 57 | step = int(step_part.split()[1]) 58 | metrics = {} 59 | for metric in metrics_part.split(): 60 | key, value = metric.split(":") 61 | metrics[key] = float(value) 62 | parsed_logs.append({"step": step, **metrics}) 63 | return parsed_logs 64 | 65 | 66 | async def tune( 67 | base_model: str, 68 | output_dir: str, 69 | packed_tensors: PackedTensors, 70 | model: Callable[[], TransformerDecoder], 71 | model_type: str, 72 | config: TuneRecipeConfig = TuneRecipeConfig(), 73 | in_process: bool = False, 74 | verbosity: Verbosity = 2, 75 | ) -> str: 76 | if os.path.isdir(base_model): 77 | base_checkpoint_dir = base_model 78 | else: 79 | process = await asyncio.create_subprocess_shell( 80 | f"HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli download {base_model}", 81 | stdout=asyncio.subprocess.PIPE, 82 | stderr=asyncio.subprocess.PIPE, 83 | ) 84 | 85 | base_stdout = [] 86 | 87 | async def read_stream(stream, is_stderr=False): 88 | while True: 89 | line = await stream.readline() 90 | if not line: 91 | break 92 | text = line.decode().rstrip() 93 | if is_stderr: 94 | print(text) 95 | else: 96 | base_stdout.append(text) 97 | 98 | await asyncio.gather( 99 | read_stream(process.stdout, is_stderr=False), 100 | read_stream(process.stderr, is_stderr=True), 101 | ) 102 | await process.wait() 103 | base_checkpoint_dir = "\n".join(base_stdout).strip() 104 | 105 | config.checkpointer = _get_checkpointer_config( 106 | checkpoint_dir=max( 107 | (d for d in glob.glob(f"{output_dir}/*") if 
d.split("/")[-1].isdigit()), 108 | key=lambda x: int(x.split("/")[-1]), 109 | default=base_checkpoint_dir, 110 | ), 111 | output_dir=output_dir, 112 | tune_model_type=model_type, 113 | ) 114 | if config.loss.kl_coef > 0: 115 | print("Using reference checkpointer") 116 | config.reference_checkpointer = _get_checkpointer_config( 117 | checkpoint_dir=base_checkpoint_dir, 118 | output_dir=output_dir, 119 | tune_model_type=model_type, 120 | ) 121 | if config.metric_logger is None: 122 | config.metric_logger = ComponentConfig(DiskLogger, log_dir=f"{output_dir}/logs") 123 | config.model = ComponentConfig(model) 124 | disk_packed_tensors = packed_tensors_to_dir(packed_tensors, f"{output_dir}/tensors") 125 | config.dataset = ComponentConfig( 126 | PackedDataset, 127 | **disk_packed_tensors, 128 | ) 129 | config.seed = 42 130 | dict_config = config.dict_config() 131 | OmegaConf.save(dict_config, f"{output_dir}/config.yaml") 132 | if in_process: 133 | cleanup_before_training() 134 | recipe_main(config) 135 | else: 136 | await _tune_run( 137 | config_path=f"{output_dir}/config.yaml", 138 | total=disk_packed_tensors["num_sequences"], 139 | verbosity=verbosity, 140 | torchrun_kwargs={"nproc_per_node": torch.cuda.device_count()}, 141 | # tune_run_env={"CUDA_LAUNCH_BLOCKING": "1"}, 142 | ) 143 | epoch_dirs = lambda: glob.glob(f"{output_dir}/epoch_*") 144 | epoch_dir = max( 145 | epoch_dirs(), 146 | key=lambda x: int(x.split("_")[-1]), 147 | default=None, 148 | ) 149 | assert ( 150 | epoch_dir is not None 151 | ), f"No epoch directory found in output directory {output_dir}" 152 | iteration_dir = f"{output_dir}/{get_iteration(output_dir) + 1:04d}" 153 | os.rename(epoch_dir, iteration_dir) 154 | for epoch_dir in epoch_dirs(): 155 | os.rmdir(epoch_dir) 156 | return iteration_dir 157 | 158 | 159 | def _get_checkpointer_config( 160 | checkpoint_dir: str, 161 | output_dir: str, 162 | tune_model_type: str, 163 | checkpoint_files: list[str] | None = None, 164 | output_subdir: str = "", 165 | ) -> ComponentConfig[FullModelHFCheckpointer]: 166 | return ComponentConfig( 167 | FullModelHFCheckpointer, 168 | checkpoint_dir=checkpoint_dir, 169 | checkpoint_files=checkpoint_files 170 | or [ 171 | os.path.basename(file) 172 | for ext in ["safetensors", "pt", "ckpt", "bin", "pth"] 173 | for file in glob.glob(f"{checkpoint_dir}/*.{ext}") 174 | ], 175 | recipe_checkpoint=None, 176 | output_dir=output_dir + output_subdir, 177 | model_type=tune_model_type, 178 | ) 179 | 180 | 181 | async def _tune_run( 182 | config_path: str, 183 | total: int, 184 | verbosity: Verbosity = 2, 185 | torchrun_kwargs: dict[str, Any] | None = None, 186 | tune_run_env: dict[str, str] | None = None, 187 | ) -> None: 188 | args = [ 189 | "tune", 190 | "run", 191 | *[ 192 | f"--{key.replace('_', '-')}{f'={value}' if value is not True else ''}" 193 | for key, value in (torchrun_kwargs or {}).items() 194 | ], 195 | "lib.recipe.TuneRecipe", 196 | "--config", 197 | config_path, 198 | ] 199 | if verbosity > 0: 200 | print(f"$ {' '.join(args)}") 201 | process = await asyncio.create_subprocess_exec( 202 | *args, 203 | stdout=asyncio.subprocess.PIPE, 204 | stderr=asyncio.subprocess.PIPE, 205 | env={**os.environ, **(tune_run_env or {})}, 206 | ) 207 | if verbosity == 1: 208 | pbar = tqdm.tqdm(total=total) 209 | else: 210 | pbar = None 211 | 212 | async def log_output(stream: asyncio.StreamReader, io: IO[str]) -> None: 213 | output = "" 214 | while True: 215 | try: 216 | chunk = await stream.read(4096) 217 | if not chunk: 218 | break 219 | output += 
chunk.decode() 220 | if verbosity > 1: 221 | io.write(output) 222 | io.flush() 223 | output = "" 224 | elif verbosity == 1: 225 | output = output.split("\n")[-1] 226 | if pbar: 227 | pbar_start = re.compile(r"(\d+)\|(\d+)\|Loss: ([\d.]+):") 228 | if match := pbar_start.search(output): 229 | epoch, step, loss = match.groups() 230 | pbar.update(int(step) - pbar.n) 231 | pbar.set_description(f"{epoch}|{step}|Loss: {loss}") 232 | metrics = { 233 | key: value 234 | for key, value in re.findall(r"(\w+)=([\d.-]+)", output) 235 | } 236 | if metrics: 237 | pbar.set_postfix(**metrics) 238 | output = "" 239 | else: 240 | pbar_regex = re.compile( 241 | r"\[(?:\d+:)?\d+:\d+<(?:\d+:)?\d+:\d+.*\]" 242 | ) 243 | if pbar_regex.search(output): 244 | io.write(output) 245 | io.flush() 246 | output = "" 247 | except Exception: 248 | break 249 | 250 | tasks = [] 251 | if process.stdout: 252 | tasks.append(asyncio.create_task(log_output(process.stdout, sys.stdout))) 253 | if process.stderr: 254 | tasks.append(asyncio.create_task(log_output(process.stderr, sys.stderr))) 255 | try: 256 | _ = await asyncio.gather(*tasks) 257 | except asyncio.CancelledError: 258 | process.kill() 259 | if pbar: 260 | pbar.close() 261 | -------------------------------------------------------------------------------- /lib/types.py: -------------------------------------------------------------------------------- 1 | from openai._types import Body, Headers, Query 2 | from openai.types.chat.completion_create_params import CompletionCreateParamsBase 3 | from typing import Never 4 | 5 | 6 | class CreateParams(CompletionCreateParamsBase, total=False): 7 | """Parameters for chat completion creation with additional fields.""" 8 | 9 | extra_headers: Headers 10 | extra_query: Query 11 | extra_body: Body 12 | 13 | 14 | class ChatCompletionParams(CreateParams, total=False): 15 | """Parameters for chat completion with restricted fields.""" 16 | 17 | messages: Never # type: ignore 18 | model: Never # type: ignore 19 | -------------------------------------------------------------------------------- /lib/utils.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | import signal 3 | from typing import Generator 4 | 5 | 6 | @contextmanager 7 | def timeout(seconds: int = 1) -> Generator[None, None, None]: 8 | def timeout_handler(signum: object, frame: object) -> None: 9 | raise TimeoutError() 10 | 11 | original_handler = signal.signal(signal.SIGALRM, timeout_handler) 12 | try: 13 | signal.alarm(seconds) 14 | yield 15 | finally: 16 | signal.alarm(0) 17 | signal.signal(signal.SIGALRM, original_handler) 18 | -------------------------------------------------------------------------------- /lib/vllm.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from dataclasses import dataclass 3 | import httpx 4 | from openai import AsyncOpenAI, DefaultAsyncHttpxClient 5 | import os 6 | import socket 7 | import subprocess 8 | import sys 9 | import re 10 | from typing import Any, IO, Optional 11 | 12 | 13 | @dataclass 14 | class vLLM: 15 | client: AsyncOpenAI 16 | max_concurrent_tokens: int 17 | model: str 18 | process: asyncio.subprocess.Process 19 | 20 | 21 | async def start_vllm( 22 | model: str, 23 | env: Optional[dict[str, str]] = None, 24 | log_file: str = "./logs/vllm.log", 25 | max_concurrent_requests: int = 128, 26 | named_arguments: dict[str, Any] = {}, 27 | timeout: float = 120.0, 28 | verbosity: int = 2, 29 | ) -> vLLM: 30 | 
kill_vllm_workers() 31 | if os.path.exists(os.path.abspath(model)): 32 | named_arguments.setdefault("served_model_name", model) 33 | model = os.path.abspath(model) 34 | port = named_arguments.get("port") or 8000 35 | while True: 36 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 37 | try: 38 | sock.bind((named_arguments.get("host") or "0.0.0.0", port)) 39 | break 40 | except socket.error: 41 | if "port" in named_arguments and named_arguments["port"] == port: 42 | raise RuntimeError(f"Port {port} is already in use") 43 | port += 1 44 | finally: 45 | sock.close() 46 | named_arguments["port"] = port 47 | args = [ 48 | "vllm", 49 | "serve", 50 | model, 51 | *[ 52 | f"--{key.replace('_', '-')}{f'={value}' if value is not True else ''}" 53 | for key, value in named_arguments.items() 54 | ], 55 | "--api-key=default", 56 | ] 57 | # os.system("lsof -ti :8000 | xargs kill -9 2>/dev/null || true") 58 | process = await asyncio.create_subprocess_exec( 59 | *args, 60 | stdout=asyncio.subprocess.PIPE, 61 | stderr=asyncio.subprocess.PIPE, 62 | env={ 63 | **os.environ, 64 | **(env or {}), 65 | }, 66 | ) 67 | if verbosity > 0: 68 | print(f"$ {' '.join(args)}") 69 | os.makedirs(os.path.dirname(log_file), exist_ok=True) 70 | log = open(log_file, "w") 71 | logging = verbosity > 1 72 | max_concurrent_tokens: Optional[int] = None 73 | 74 | async def log_output(stream: asyncio.StreamReader, io: IO[str]) -> None: 75 | while True: 76 | line = await stream.readline() 77 | if not line: 78 | break 79 | decoded_line = line.decode() 80 | if logging: 81 | io.write(decoded_line) 82 | io.flush() 83 | log.write(decoded_line) 84 | log.flush() 85 | nonlocal max_concurrent_tokens 86 | if not max_concurrent_tokens: 87 | match = re.search( 88 | r"Maximum concurrency for (\d+) tokens per request: ([\d.]+)x", 89 | decoded_line, 90 | ) 91 | if match: 92 | max_concurrent_tokens = int( 93 | int(match.group(1)) * float(match.group(2)) 94 | ) 95 | log.close() 96 | 97 | if process.stdout: 98 | asyncio.create_task(log_output(process.stdout, sys.stdout)) 99 | if process.stderr: 100 | asyncio.create_task(log_output(process.stderr, sys.stderr)) 101 | client = AsyncOpenAI( 102 | api_key="default", 103 | base_url=f"http://{named_arguments.get('host', '0.0.0.0')}:{named_arguments['port']}/v1", 104 | max_retries=6, 105 | http_client=DefaultAsyncHttpxClient( 106 | limits=httpx.Limits( 107 | max_connections=max_concurrent_requests, 108 | max_keepalive_connections=max_concurrent_requests, 109 | ), 110 | timeout=httpx.Timeout(timeout=1_200, connect=10.0), 111 | ), 112 | ) 113 | start = asyncio.get_event_loop().time() 114 | while True: 115 | try: 116 | await client.chat.completions.create( 117 | messages=[{"role": "user", "content": "Hello"}], 118 | model=named_arguments.get("served_model_name", model), 119 | max_tokens=1, 120 | ) 121 | break 122 | except Exception: 123 | if asyncio.get_event_loop().time() - start > timeout: 124 | process.terminate() 125 | kill_vllm_workers() 126 | raise TimeoutError("vLLM server did not start in time") 127 | continue 128 | if logging: 129 | print(f"vLLM server started successfully. 
Logs can be found at {log_file}") 130 | logging = False 131 | if max_concurrent_tokens is None: 132 | process.terminate() 133 | kill_vllm_workers() 134 | raise RuntimeError( 135 | "Max concurrent requests for the maximum model length not logged" 136 | ) 137 | return vLLM( 138 | client, 139 | max_concurrent_tokens, 140 | named_arguments.get("served_model_name", model), 141 | process, 142 | ) 143 | 144 | 145 | def kill_vllm_workers() -> None: 146 | result = subprocess.run(["ps", "aux"], capture_output=True, text=True) 147 | pids = [ 148 | line.split()[1] 149 | for line in result.stdout.splitlines() 150 | if "from multiprocessing.spawn import spawn_main; spawn_main(tracker_fd=" 151 | in line 152 | ] 153 | for pid in pids: 154 | subprocess.run(["sudo", "kill", "-9", pid], check=True) 155 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "deductive-reasoning" 3 | version = "0.1.0" 4 | description = "Add your description here" 5 | readme = "README.md" 6 | requires-python = ">=3.12" 7 | dependencies = [ 8 | "aioitertools>=0.12.0", 9 | "matplotlib>=3.10.1", 10 | "polars>=1.24.0", 11 | "seaborn>=0.13.2", 12 | "torch>=2.5.1", 13 | "torchao>=0.8.0", 14 | "torchtune>=0.5.0", 15 | "vllm>=0.7.0", 16 | "wandb>=0.19.8", 17 | ] 18 | 19 | [tool.uv] 20 | dev-dependencies = [ 21 | "ipykernel>=6.29.5", 22 | "ipywidgets>=8.1.5", 23 | "nbconvert>=7.16.6", 24 | ] 25 | -------------------------------------------------------------------------------- /train.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 2, 16 | "metadata": {}, 17 | "outputs": [ 18 | { 19 | "data": { 20 | "text/html": [ 21 | "\n" 30 | ], 31 | "text/plain": [ 32 | "" 33 | ] 34 | }, 35 | "metadata": {}, 36 | "output_type": "display_data" 37 | } 38 | ], 39 | "source": [ 40 | "%%html\n", 41 | "" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "metadata": {}, 56 | "outputs": [ 57 | { 58 | "name": "stdout", 59 | "output_type": "stream", 60 | "text": [ 61 | "import error: No module named 'triton'\n" 62 | ] 63 | }, 64 | { 65 | "name": "stderr", 66 | "output_type": "stream", 67 | "text": [ 68 | "\u001b[34m\u001b[1mwandb\u001b[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.\n", 69 | "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mbradhilton\u001b[0m to \u001b[32mhttps://api.wandb.ai\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" 70 | ] 71 | }, 72 | { 73 | "data": { 74 | "text/html": [ 75 | "Tracking run with wandb version 0.19.8" 76 | ], 77 | "text/plain": [ 78 | "" 79 | ] 80 | }, 81 | "metadata": {}, 82 | "output_type": "display_data" 83 | }, 84 | { 85 | "data": { 86 | "text/html": [ 87 | "Run data is saved locally in /Users/brad/github/bradhilton/deductive-reasoning/wandb/run-20250306_075102-" 88 | ], 89 | "text/plain": [ 90 | "" 91 | ] 92 | }, 93 | "metadata": {}, 94 | "output_type": "display_data" 95 | }, 96 | { 97 | "data": { 98 | "text/html": [ 99 | "Syncing run to Weights & Biases (docs)
" 100 | ], 101 | "text/plain": [ 102 | "" 103 | ] 104 | }, 105 | "metadata": {}, 106 | "output_type": "display_data" 107 | }, 108 | { 109 | "data": { 110 | "text/html": [ 111 | " View project at https://wandb.ai/bradhilton/grpo-tests" 112 | ], 113 | "text/plain": [ 114 | "" 115 | ] 116 | }, 117 | "metadata": {}, 118 | "output_type": "display_data" 119 | }, 120 | { 121 | "data": { 122 | "text/html": [ 123 | " View run at https://wandb.ai/bradhilton/grpo-tests/runs/%3CYOUR-RUN-NAME%3E" 124 | ], 125 | "text/plain": [ 126 | "" 127 | ] 128 | }, 129 | "metadata": {}, 130 | "output_type": "display_data" 131 | }, 132 | { 133 | "data": { 134 | "text/plain": [ 135 | "(64, 64, 2860)" 136 | ] 137 | }, 138 | "execution_count": 3, 139 | "metadata": {}, 140 | "output_type": "execute_result" 141 | } 142 | ], 143 | "source": [ 144 | "import asyncio\n", 145 | "from itertools import cycle, islice\n", 146 | "from lib import models\n", 147 | "from lib.grpo import GRPO\n", 148 | "from lib.inference_early_stop import InferenceEarlyStop\n", 149 | "from lib.pack import packed_tensors_from_tokenized_results, plot_packed_tensors\n", 150 | "from lib.recipe import ComponentConfig, TuneRecipeConfig\n", 151 | "from lib.tasks import ChatCompletionParams, get_task_results\n", 152 | "from lib.temporal_clue import get_temporal_clue_tasks\n", 153 | "from lib.tokenize import TaskResultTokenizer\n", 154 | "from lib.tune import (\n", 155 | " clear_iteration_dirs,\n", 156 | " get_iteration,\n", 157 | " get_last_iteration_dir,\n", 158 | " last_tune_log,\n", 159 | " tune,\n", 160 | " Verbosity,\n", 161 | ")\n", 162 | "from lib.vllm import start_vllm, kill_vllm_workers\n", 163 | "import polars as pl\n", 164 | "import random\n", 165 | "import torch\n", 166 | "from transformers import AutoTokenizer\n", 167 | "import wandb\n", 168 | "\n", 169 | "run_name = \"\"\n", 170 | "assert run_name != \"\", \"Don't forget to choose a run name\"\n", 171 | "run = wandb.init(name=run_name, id=run_name, resume=\"allow\")\n", 172 | "\n", 173 | "tasks = list(get_temporal_clue_tasks())\n", 174 | "val_tasks = tasks[:64]\n", 175 | "test_tasks = tasks[64:128]\n", 176 | "train_tasks = tasks[128:]\n", 177 | "random.seed(42)\n", 178 | "random.shuffle(train_tasks)\n", 179 | "len(val_tasks), len(test_tasks), len(train_tasks)" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": null, 185 | "metadata": {}, 186 | "outputs": [], 187 | "source": [ 188 | "# GRPO params\n", 189 | "wandb.config[\"clip_epsilon\"] = clip_epsilon = 0.2\n", 190 | "wandb.config[\"entropy_coef\"] = entropy_coef = 0.0\n", 191 | "wandb.config[\"kl_coef\"] = kl_coef = 0.0\n", 192 | "wandb.config[\"tanh\"] = tanh = False\n", 193 | "\n", 194 | "# Model params\n", 195 | "model = models.qwen_32b()\n", 196 | "wandb.config[\"model\"] = model.base_model\n", 197 | "tokenizer = AutoTokenizer.from_pretrained(model.base_model)\n", 198 | "wandb.config[\"seq_len\"] = seq_len = 16384\n", 199 | "\n", 200 | "# Optimizer params\n", 201 | "wandb.config[\"lr\"] = lr = 6e-6\n", 202 | "wandb.config[\"betas\"] = betas = (0.9, 0.99)\n", 203 | "wandb.config[\"weight_decay\"] = weight_decay = 0.1\n", 204 | "\n", 205 | "# Training params\n", 206 | "num_iterations = 1_000\n", 207 | "wandb.config[\"samples_per_task\"] = samples_per_task = 50\n", 208 | "wandb.config[\"tasks_per_iter\"] = tasks_per_iter = 32\n", 209 | "wandb.config[\"stride\"] = stride = 32\n", 210 | "output_dir = f\"./models/{run_name}\"\n", 211 | "\n", 212 | "# Inference params\n", 213 | "expected_tokens = 1000 # Initial expected 
completion tokens per task sample\n", 214 | "inference_early_stop = InferenceEarlyStop(alpha=0.992, threshold=-3.0)\n", 215 | "\n", 216 | "# Logging params\n", 217 | "plot_tensors = True\n", 218 | "verbosity: Verbosity = 2" 219 | ] 220 | }, 221 | { 222 | "cell_type": "code", 223 | "execution_count": null, 224 | "metadata": {}, 225 | "outputs": [], 226 | "source": [ 227 | "# Start from the latest iteration if it exists, otherwise start from the base model\n", 228 | "model_name = get_last_iteration_dir(output_dir) or model.base_model\n", 229 | "\n", 230 | "# Loop from the current iteration to the target number of iterations\n", 231 | "for i in range(get_iteration(output_dir), num_iterations):\n", 232 | " # Start vLLM server\n", 233 | " vllm = await start_vllm(\n", 234 | " model_name,\n", 235 | " max_concurrent_requests=4096,\n", 236 | " env={\"VLLM_ALLOW_LONG_MAX_MODEL_LEN\": \"1\"},\n", 237 | " named_arguments=dict(\n", 238 | " block_size=32,\n", 239 | " disable_log_requests=True,\n", 240 | " enable_prefix_caching=True,\n", 241 | " enforce_eager=True,\n", 242 | " gpu_memory_utilization=0.95,\n", 243 | " max_model_len=16384,\n", 244 | " max_num_seqs=4096,\n", 245 | " max_num_batched_tokens=16384,\n", 246 | " num_scheduler_steps=16,\n", 247 | " preemption_mode=\"swap\",\n", 248 | " return_tokens_as_token_ids=True,\n", 249 | " swap_space=80,\n", 250 | " tensor_parallel_size=torch.cuda.device_count(),\n", 251 | " ),\n", 252 | " timeout=360 + 15 * torch.cuda.device_count(),\n", 253 | " verbosity=verbosity,\n", 254 | " )\n", 255 | "\n", 256 | " # Create semaphore for rate limiting\n", 257 | " semaphore = asyncio.Semaphore(\n", 258 | " int(\n", 259 | " 1.3\n", 260 | " * (torch.cuda.device_count() / model.min_gpus)\n", 261 | " * (vllm.max_concurrent_tokens / expected_tokens)\n", 262 | " )\n", 263 | " )\n", 264 | "\n", 265 | " # Get results for logging validation performance and for tuning with train results\n", 266 | " offset = i * stride\n", 267 | " val_results, train_results = await asyncio.gather(\n", 268 | " get_task_results(\n", 269 | " tasks=val_tasks,\n", 270 | " client=vllm.client,\n", 271 | " model=vllm.model,\n", 272 | " log_results=8,\n", 273 | " n=2,\n", 274 | " on_chunk=inference_early_stop,\n", 275 | " params=ChatCompletionParams(\n", 276 | " stream_options={\n", 277 | " \"include_usage\": True,\n", 278 | " },\n", 279 | " max_completion_tokens=8192,\n", 280 | " ),\n", 281 | " pbar_desc=\"val\",\n", 282 | " semaphore=semaphore,\n", 283 | " ),\n", 284 | " get_task_results(\n", 285 | " tasks=list(islice(cycle(train_tasks), offset, offset + tasks_per_iter)),\n", 286 | " client=vllm.client,\n", 287 | " model=vllm.model,\n", 288 | " log_results=False,\n", 289 | " n=samples_per_task,\n", 290 | " on_chunk=inference_early_stop,\n", 291 | " params=ChatCompletionParams(\n", 292 | " stream_options={\n", 293 | " \"include_usage\": True,\n", 294 | " },\n", 295 | " max_completion_tokens=8192,\n", 296 | " ),\n", 297 | " pbar_desc=\"train\",\n", 298 | " semaphore=semaphore,\n", 299 | " transform=TaskResultTokenizer(tokenizer),\n", 300 | " ),\n", 301 | " )\n", 302 | "\n", 303 | " # Stop vLLM workers\n", 304 | " vllm.process.terminate()\n", 305 | " kill_vllm_workers()\n", 306 | "\n", 307 | " # Log results to Weights & Biases\n", 308 | " val_stats = val_results.stats\n", 309 | " assert val_stats.grades > 0\n", 310 | " assert val_stats.usages > 0\n", 311 | " wandb_data = {\n", 312 | " \"iteration\": i,\n", 313 | " \"exceptions\": val_stats.exceptions + train_results.stats.exceptions,\n", 314 | " 
\"reward\": val_stats.total_reward / val_stats.grades,\n", 315 | " \"tokens\": round(val_stats.completion_tokens / val_stats.usages),\n", 316 | " }\n", 317 | " for metric in val_stats.total_metrics:\n", 318 | " wandb_data[metric] = val_stats.total_metrics[metric] / val_stats.grades\n", 319 | " try:\n", 320 | " wandb_data.update(\n", 321 | " pl.DataFrame(last_tune_log(output_dir)).drop(\"step\").mean().to_dicts()[0]\n", 322 | " )\n", 323 | " except Exception:\n", 324 | " pass\n", 325 | " wandb.log(wandb_data)\n", 326 | "\n", 327 | " # Update expected tokens\n", 328 | " expected_tokens = wandb_data[\"tokens\"]\n", 329 | "\n", 330 | " # Clean up output directory to save space\n", 331 | " try:\n", 332 | " best_iteration = (\n", 333 | " wandb.Api()\n", 334 | " .run(f\"{run.entity}/{run.project}/{run.id}\")\n", 335 | " .history()\n", 336 | " .sort_values(by=\"reward\")[\"iteration\"]\n", 337 | " .iloc[-1]\n", 338 | " )\n", 339 | " # Clear all but the best and current iterations\n", 340 | " clear_iteration_dirs(output_dir, excluding=[best_iteration, i])\n", 341 | " except Exception:\n", 342 | " pass\n", 343 | "\n", 344 | " # Pack the tokenized results into tensors\n", 345 | " tokenized_results = [\n", 346 | " result\n", 347 | " for results in train_results\n", 348 | " for result in results\n", 349 | " if result.advantage != 0\n", 350 | " ]\n", 351 | " packed_tensors = packed_tensors_from_tokenized_results(\n", 352 | " tokenized_results,\n", 353 | " seq_len=seq_len,\n", 354 | " pad_token_id=tokenizer.pad_token_id, # type: ignore\n", 355 | " )\n", 356 | " if plot_tensors:\n", 357 | " plot_packed_tensors(packed_tensors)\n", 358 | " elif verbosity > 0:\n", 359 | " print(f\"Packed tensors into {packed_tensors[\"tokens\"].size()} shape\")\n", 360 | "\n", 361 | " # Tune the model\n", 362 | " model_name = await tune(\n", 363 | " base_model=model.base_model if kl_coef > 0 else model_name,\n", 364 | " output_dir=output_dir,\n", 365 | " packed_tensors=packed_tensors,\n", 366 | " model=model.tune_model,\n", 367 | " model_type=model.tune_model_type,\n", 368 | " config=TuneRecipeConfig(\n", 369 | " optimizer=ComponentConfig(\n", 370 | " \"torch.optim.AdamW\",\n", 371 | " lr=lr,\n", 372 | " betas=betas,\n", 373 | " weight_decay=weight_decay,\n", 374 | " fused=True,\n", 375 | " ),\n", 376 | " loss=ComponentConfig(\n", 377 | " GRPO,\n", 378 | " clip_epsilon=clip_epsilon,\n", 379 | " entropy_coef=entropy_coef,\n", 380 | " kl_coef=kl_coef,\n", 381 | " ),\n", 382 | " shuffle=True,\n", 383 | " batch_size=32768 // seq_len,\n", 384 | " fsdp_cpu_offload=True,\n", 385 | " enable_activation_checkpointing=True,\n", 386 | " enable_activation_offloading=True,\n", 387 | " custom_sharded_layers=[\"tok_embeddings\", \"output\"],\n", 388 | " num_output_chunks=model.tune_num_output_chunks,\n", 389 | " compile=True,\n", 390 | " ),\n", 391 | " verbosity=verbosity,\n", 392 | " )" 393 | ] 394 | } 395 | ], 396 | "metadata": { 397 | "kernelspec": { 398 | "display_name": ".venv", 399 | "language": "python", 400 | "name": "python3" 401 | }, 402 | "language_info": { 403 | "codemirror_mode": { 404 | "name": "ipython", 405 | "version": 3 406 | }, 407 | "file_extension": ".py", 408 | "mimetype": "text/x-python", 409 | "name": "python", 410 | "nbconvert_exporter": "python", 411 | "pygments_lexer": "ipython3", 412 | "version": "3.12.5" 413 | } 414 | }, 415 | "nbformat": 4, 416 | "nbformat_minor": 2 417 | } 418 | -------------------------------------------------------------------------------- /train.py: 
-------------------------------------------------------------------------------- 1 | import asyncio 2 | from itertools import cycle, islice 3 | from lib import models 4 | from lib.grpo import GRPO 5 | from lib.inference_early_stop import InferenceEarlyStop 6 | from lib.pack import packed_tensors_from_tokenized_results 7 | from lib.recipe import ComponentConfig, TuneRecipeConfig 8 | from lib.tasks import ChatCompletionParams, get_task_results 9 | from lib.temporal_clue import get_temporal_clue_tasks 10 | from lib.tokenize import TaskResultTokenizer 11 | from lib.tune import ( 12 | clear_iteration_dirs, 13 | get_iteration, 14 | get_last_iteration_dir, 15 | last_tune_log, 16 | tune, 17 | Verbosity, 18 | ) 19 | from lib.vllm import start_vllm, kill_vllm_workers 20 | import polars as pl 21 | import random 22 | import torch 23 | from transformers import AutoTokenizer 24 | import wandb 25 | 26 | run_name = "" 27 | assert run_name != "", "Don't forget to choose a run name" 28 | run = wandb.init(name=run_name, id=run_name, resume="allow") 29 | 30 | # Get tasks 31 | tasks = list(get_temporal_clue_tasks()) 32 | val_tasks = tasks[:64] 33 | test_tasks = tasks[64:128] 34 | train_tasks = tasks[128:] 35 | random.seed(42) 36 | random.shuffle(train_tasks) 37 | 38 | # GRPO params 39 | wandb.config["clip_epsilon"] = clip_epsilon = 0.2 40 | wandb.config["entropy_coef"] = entropy_coef = 0.0 41 | wandb.config["kl_coef"] = kl_coef = 0.0 42 | wandb.config["tanh"] = tanh = False 43 | 44 | # Model params 45 | model = models.qwen_32b() 46 | wandb.config["model"] = model.base_model 47 | tokenizer = AutoTokenizer.from_pretrained(model.base_model) 48 | wandb.config["seq_len"] = seq_len = 16384 49 | 50 | # Optimizer params 51 | wandb.config["lr"] = lr = 6e-6 52 | wandb.config["betas"] = betas = (0.9, 0.99) 53 | wandb.config["weight_decay"] = weight_decay = 0.1 54 | 55 | # Training params 56 | num_iterations = 1_000 57 | wandb.config["samples_per_task"] = samples_per_task = 50 58 | wandb.config["tasks_per_iter"] = tasks_per_iter = 32 59 | wandb.config["stride"] = stride = 32 60 | output_dir = f"./models/{run_name}" 61 | 62 | # Inference params 63 | expected_tokens = 1000 # Initial expected completion tokens per task sample 64 | inference_early_stop = InferenceEarlyStop(alpha=0.992, threshold=-3.0) 65 | 66 | # Logging params 67 | verbosity: Verbosity = 2 68 | 69 | # Start from the latest iteration if it exists, otherwise start from the base model 70 | model_name = get_last_iteration_dir(output_dir) or model.base_model 71 | 72 | 73 | async def train() -> None: 74 | global expected_tokens, model_name 75 | # Loop from the current iteration to the target number of iterations 76 | for i in range(get_iteration(output_dir), num_iterations): 77 | # Start vLLM server 78 | vllm = await start_vllm( 79 | model_name, 80 | max_concurrent_requests=4096, 81 | env={"VLLM_ALLOW_LONG_MAX_MODEL_LEN": "1"}, 82 | named_arguments=dict( 83 | block_size=32, 84 | disable_log_requests=True, 85 | enable_prefix_caching=True, 86 | enforce_eager=True, 87 | gpu_memory_utilization=0.95, 88 | max_model_len=16384, 89 | max_num_seqs=4096, 90 | max_num_batched_tokens=16384, 91 | num_scheduler_steps=16, 92 | preemption_mode="swap", 93 | return_tokens_as_token_ids=True, 94 | swap_space=80, 95 | tensor_parallel_size=torch.cuda.device_count(), 96 | ), 97 | timeout=360 + 15 * torch.cuda.device_count(), 98 | verbosity=verbosity, 99 | ) 100 | 101 | # Create semaphore for rate limiting 102 | semaphore = asyncio.Semaphore( 103 | int( 104 | 1.3 105 | * 
(torch.cuda.device_count() / model.min_gpus) 106 | * (vllm.max_concurrent_tokens / expected_tokens) 107 | ) 108 | ) 109 | 110 | # Get results for logging validation performance and for tuning with train results 111 | offset = i * stride 112 | val_results, train_results = await asyncio.gather( 113 | get_task_results( 114 | tasks=val_tasks, 115 | client=vllm.client, 116 | model=vllm.model, 117 | log_results=8, 118 | n=2, 119 | on_chunk=inference_early_stop, 120 | params=ChatCompletionParams( 121 | stream_options={ 122 | "include_usage": True, 123 | }, 124 | max_completion_tokens=8192, 125 | ), 126 | pbar_desc="val", 127 | pbar_position=0, 128 | semaphore=semaphore, 129 | ), 130 | get_task_results( 131 | tasks=list(islice(cycle(train_tasks), offset, offset + tasks_per_iter)), 132 | client=vllm.client, 133 | model=vllm.model, 134 | log_results=False, 135 | n=samples_per_task, 136 | on_chunk=inference_early_stop, 137 | params=ChatCompletionParams( 138 | stream_options={ 139 | "include_usage": True, 140 | }, 141 | max_completion_tokens=8192, 142 | ), 143 | pbar_desc="train", 144 | pbar_position=1, 145 | semaphore=semaphore, 146 | transform=TaskResultTokenizer(tokenizer), 147 | ), 148 | ) 149 | 150 | # Stop vLLM workers 151 | vllm.process.terminate() 152 | kill_vllm_workers() 153 | 154 | # Log results to Weights & Biases 155 | val_stats = val_results.stats 156 | assert val_stats.grades > 0 157 | assert val_stats.usages > 0 158 | wandb_data = { 159 | "iteration": i, 160 | "exceptions": val_stats.exceptions + train_results.stats.exceptions, 161 | "reward": val_stats.total_reward / val_stats.grades, 162 | "tokens": round(val_stats.completion_tokens / val_stats.usages), 163 | } 164 | for metric in val_stats.total_metrics: 165 | wandb_data[metric] = val_stats.total_metrics[metric] / val_stats.grades 166 | try: 167 | wandb_data.update( 168 | pl.DataFrame(last_tune_log(output_dir)) 169 | .drop("step") 170 | .mean() 171 | .to_dicts()[0] 172 | ) 173 | except Exception: 174 | pass 175 | wandb.log(wandb_data) 176 | 177 | # Update expected tokens 178 | expected_tokens = wandb_data["tokens"] 179 | 180 | # Clean up output directory to save space 181 | try: 182 | best_iteration = ( 183 | wandb.Api() 184 | .run(f"{run.entity}/{run.project}/{run.id}") 185 | .history() 186 | .sort_values(by="reward")["iteration"] 187 | .iloc[-1] 188 | ) 189 | # Clear all but the best and current iterations 190 | clear_iteration_dirs(output_dir, excluding=[best_iteration, i]) 191 | except Exception: 192 | pass 193 | 194 | # Pack the tokenized results into tensors 195 | tokenized_results = [ 196 | result 197 | for results in train_results 198 | for result in results 199 | if result.advantage != 0 200 | ] 201 | packed_tensors = packed_tensors_from_tokenized_results( 202 | tokenized_results, 203 | seq_len=seq_len, 204 | pad_token_id=tokenizer.pad_token_id, # type: ignore 205 | ) 206 | if verbosity > 0: 207 | print(f"Packed tensors into {packed_tensors["tokens"].size()} shape") 208 | 209 | # Tune the model 210 | model_name = await tune( 211 | base_model=model.base_model if kl_coef > 0 else model_name, 212 | output_dir=output_dir, 213 | packed_tensors=packed_tensors, 214 | model=model.tune_model, 215 | model_type=model.tune_model_type, 216 | config=TuneRecipeConfig( 217 | optimizer=ComponentConfig( 218 | "torch.optim.AdamW", 219 | lr=lr, 220 | betas=betas, 221 | weight_decay=weight_decay, 222 | fused=True, 223 | ), 224 | loss=ComponentConfig( 225 | GRPO, 226 | clip_epsilon=clip_epsilon, 227 | entropy_coef=entropy_coef, 228 | 
kl_coef=kl_coef, 229 | ), 230 | shuffle=True, 231 | batch_size=32768 // seq_len, 232 | fsdp_cpu_offload=True, 233 | enable_activation_checkpointing=True, 234 | enable_activation_offloading=True, 235 | custom_sharded_layers=["tok_embeddings", "output"], 236 | num_output_chunks=model.tune_num_output_chunks, 237 | compile=True, 238 | ), 239 | verbosity=verbosity, 240 | ) 241 | 242 | 243 | asyncio.run(train()) 244 | --------------------------------------------------------------------------------
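For reference, the loss that `train.py` and `train.ipynb` configure can be summarized by the following minimal sketch of a clipped GRPO objective. It illustrates the math only and is not the repository's `lib/grpo.py` implementation; the tensor shapes, the masking convention, and the assumption that each sample's advantage is its reward minus the mean reward of its task group (optionally divided by the group's standard deviation) are illustrative choices. With `entropy_coef` and `kl_coef` set to zero, as in the runs above, the clipped policy term is the entire loss; nonzero coefficients would add an entropy bonus and a KL penalty against the base model.

```python
# Minimal sketch of a clipped GRPO loss (illustrative only).
import torch


def grpo_loss(
    logprobs: torch.Tensor,      # per-token log-probs under the current policy, shape [T]
    old_logprobs: torch.Tensor,  # per-token log-probs under the sampling policy (vLLM), shape [T]
    advantages: torch.Tensor,    # group-relative advantage broadcast to each token, shape [T]
    mask: torch.Tensor,          # 1.0 for assistant tokens to train on, 0.0 elsewhere, shape [T]
    clip_epsilon: float = 0.2,
) -> torch.Tensor:
    ratio = torch.exp(logprobs - old_logprobs)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon) * advantages
    per_token = -torch.minimum(unclipped, clipped)  # pessimistic (clipped) surrogate
    return (per_token * mask).sum() / mask.sum().clamp(min=1.0)
```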