├── .env.example
├── .gitignore
├── .python-version
├── LICENSE
├── README.md
├── data
│   └── puzzles.json
├── lib
│   ├── __init__.py
│   ├── chat_completions.py
│   ├── grpo.py
│   ├── inference_early_stop.py
│   ├── models.py
│   ├── pack.py
│   ├── recipe.py
│   ├── stream.py
│   ├── tasks.py
│   ├── temporal_clue.py
│   ├── tokenize.py
│   ├── tqdm.py
│   ├── tune.py
│   ├── types.py
│   ├── utils.py
│   └── vllm.py
├── pyproject.toml
├── train.ipynb
├── train.py
└── uv.lock
/.env.example:
--------------------------------------------------------------------------------
1 | # Your Weights & Biases API key for experiment tracking
2 | WANDB_API_KEY=your_wandb_api_key_here
3 |
4 | # The name of your Weights & Biases project
5 | WANDB_PROJECT=your_project_name
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .env
2 | .venv
3 | __pycache__/
4 | /logs
5 | /models
6 | /wandb
7 |
--------------------------------------------------------------------------------
/.python-version:
--------------------------------------------------------------------------------
1 | 3.12
2 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 OpenPipe inc.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Deductive Reasoning
2 |
3 | 
4 |
5 | Train your own frontier-level deductive reasoning model with reinforcement learning.
6 |
7 | ## Overview
8 |
9 | This repository contains the training recipe for creating frontier-level deductive reasoning models using reinforcement learning. Our research demonstrates that smaller, open-weight language models can be trained to solve complex logical deduction tasks at a frontier level, matching or exceeding proprietary models at a fraction of the cost.
10 |
11 | We used the Temporal Clue puzzle dataset to train Qwen 14B and 32B models, improving their deductive reasoning capabilities significantly through Group Relative Policy Optimization (GRPO). Our trained models approach the performance of leading proprietary models like Claude 3.7 Sonnet while maintaining cost efficiency.
12 |
13 | ## Resources
14 |
15 | - **Training Recipe**: This repository (recipe for RL training)
16 | - **Training Dataset**: [Temporal Clue Puzzles](https://github.com/bradhilton/temporal-clue)
17 | - **RL Experiments**: [OpenPipe RL Experiments](https://github.com/openpipe/rl-experiments)
18 | - **Model Weights**:
19 | - [Deductive Reasoning Qwen 14B](https://huggingface.co/OpenPipe/Deductive-Reasoning-Qwen-14B)
20 | - [Deductive Reasoning Qwen 32B](https://huggingface.co/OpenPipe/Deductive-Reasoning-Qwen-32B)
21 | - **Blog Post**: [Using GRPO to Beat o1, o3-mini and R1 at "Temporal Clue"](https://openpipe.ai/blog/using-grpo-to-beat-o1-o3-mini-and-r1-on-temporal-clue)
22 |
23 | ## Getting Started
24 |
25 | Follow these steps to run the training recipe:
26 |
27 | ### Prerequisites
28 |
29 | - Sufficient NVIDIA GPUs for your chosen model:
30 | - Qwen 14B requires at least 2 GPUs
31 | - Qwen 32B requires at least 4 GPUs
32 | - [uv](https://github.com/astral-sh/uv) package manager
33 | - [Weights & Biases](https://wandb.ai) account
34 |
35 | ### Installation
36 |
37 | 1. Clone this repository:
38 |
39 | ```bash
40 | git clone https://github.com/bradhilton/deductive-reasoning.git
41 | cd deductive-reasoning
42 | ```
43 |
44 | 2. Install dependencies using uv:
45 |
46 | ```bash
47 | uv sync
48 | ```
49 |
50 | 3. Reinstall torchtune due to an executable naming conflict with Ray Tune 🙈
51 |
52 | ```bash
53 | uv remove torchtune
54 | uv add torchtune
55 | ```
56 |
57 | 4. (Optional) Configure environment variables:
58 |
59 | ```bash
60 | cp .env.example .env
61 | ```
62 |
63 | Edit the `.env` file to add your Weights & Biases API key and project name:
64 |
65 | ```
66 | WANDB_API_KEY=your_wandb_api_key_here
67 | WANDB_PROJECT=your_project_name
68 | ```
69 |
70 | ### Running the Training
71 |
72 | 1. Open the `train.ipynb` notebook or `train.py` script and configure the training parameters:
73 |
74 | - Set a unique `run_name` for your experiment
75 | - Choose the model (e.g., `models.qwen_14b()` or `models.qwen_32b()`)
76 | - Adjust other parameters as needed (learning rate, number of iterations, etc.)
77 |
78 | 2. Run the training:
79 |
80 | - If using the notebook: Execute all cells in `train.ipynb`
81 | - If using the script: Run `uv run train.py`
82 |
83 | 3. Monitor training progress in Weights & Biases.
84 |
85 | The training process will save the latest and/or best checkpoints in your output directory, allowing you to resume training if interrupted.
86 |
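For illustration, the parameters configured in step 1 look roughly like the following sketch. Only `models.qwen_14b()` / `models.qwen_32b()` and the values listed under Methodology below come from this repository; the variable names are illustrative, so check `train.py` for the actual interface.

```python
# Illustrative configuration sketch; see train.py / train.ipynb for the real names.
from lib import models

run_name = "qwen-14b-temporal-clue"  # unique name for this experiment
model = models.qwen_14b()            # or models.qwen_32b() (needs at least 4 GPUs)
learning_rate = 6e-6                 # learning rate used in our experiments
tasks_per_iteration = 32             # puzzles sampled per iteration
samples_per_task = 50                # completions sampled per puzzle
```
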
87 | ## Methodology
88 |
89 | Our training approach used reinforcement learning to incrementally improve models' deductive reasoning capabilities:
90 |
91 | 1. **Environment**: Temporal Clue puzzles (inspired by the board game Clue/Cluedo) with verifiable solutions
92 | 2. **Algorithm**: Group Relative Policy Optimization (GRPO) without KL divergence penalty
93 | 3. **Training Loop**:
94 | - Generate model responses to puzzle tasks
95 | - Grade responses and estimate advantages for each group of completions
96 | - Fine-tune the model using clipped policy gradients
97 | - Repeat with new puzzles until peak performance
98 |
99 | We used the torchtune library for efficient training and vLLM for inference, with the following key parameters:
100 |
101 | - Models: Qwen 2.5 Instruct 14B & 32B
102 | - Tasks per Iteration: 32
103 | - Samples per Task per Iteration: 50
104 | - Learning Rate: 6e-6
105 |
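The key update is the clipped policy-gradient surrogate implemented in `lib/grpo.py`. A simplified per-token sketch (the actual implementation works on masked, weighted, chunked batches):

```python
import torch

def clipped_policy_loss(
    new_logprobs: torch.Tensor,  # log-probs of sampled tokens under the updated policy
    old_logprobs: torch.Tensor,  # log-probs recorded when the tokens were sampled
    advantages: torch.Tensor,    # group-relative advantage assigned to each token
    clip_epsilon: float = 0.2,
) -> torch.Tensor:
    # Probability ratio between the updated policy and the sampling policy
    ratio = torch.exp(new_logprobs - old_logprobs)
    # Clipped ratio bounds how far a single update can move the policy
    clipped = torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon)
    # Pessimistic (minimum) surrogate objective, negated to form a loss
    return -torch.min(ratio * advantages, clipped * advantages).mean()
```

Clipping bounds how far each update can move the policy away from the one that generated the samples, which helps keep training stable even without the KL penalty.
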
106 | ## Results
107 |
108 | Our training produced impressive performance gains, demonstrating that open-weight models can achieve frontier-level reasoning capabilities.
109 |
110 | 
111 |
112 | We dramatically improved the cost-accuracy tradeoff compared to proprietary models:
113 |
114 | 
115 |
116 | Notably, we discovered that meaningful performance improvements (10-15%) can be achieved with as few as 16 training examples, making this approach accessible even with limited data.
117 |
118 | ## License
119 |
120 | This training recipe is freely available under the MIT license.
121 |
--------------------------------------------------------------------------------
/lib/__init__.py:
--------------------------------------------------------------------------------
1 | from dotenv import load_dotenv
2 |
3 | load_dotenv()
4 |
--------------------------------------------------------------------------------
/lib/chat_completions.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from datetime import datetime
3 | from openai import AsyncOpenAI
4 | from openai.types.chat.chat_completion import ChatCompletion
5 | from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
6 | import os
7 | from typing import Callable, Unpack
8 |
9 |
10 | from .stream import consume_chat_completion_stream
11 | from .types import CreateParams
12 | from .utils import timeout
13 |
14 | MAX_INT = 2**31 - 1
15 | unlimited_semaphore = asyncio.Semaphore(MAX_INT)
16 |
17 |
18 | async def get_chat_completion(
19 | client: AsyncOpenAI,
20 | log_dir: str | None = None,
21 | log_results: bool = True,
22 | on_chunk: Callable[[ChatCompletionChunk, ChatCompletion], None] | None = None,
23 | semaphore: asyncio.Semaphore | None = None,
24 | **create_params: Unpack[CreateParams],
25 | ) -> ChatCompletion:
26 | """
27 | Given a client and arguments to openai.chat.completions.create, this function will return a chat completion with some additional features:
28 | - Logging of results to a file
29 | - Streaming of results to the callback function
30 | - Support for capping concurrent requests with a semaphore
31 |
32 | Args:
33 | client (AsyncOpenAI): An AsyncOpenAI client
34 | log_dir (str | None): The directory to log the results of the chat completion
35 | log_results (bool): Whether to log the results of the chat completion
36 | on_chunk (Callable[[ChatCompletionChunk, ChatCompletion], None]): A callback function that will be called with each chunk of the chat completion
37 | semaphore (asyncio.Semaphore): A semaphore to limit the number of concurrent requests
38 |
39 | Returns:
40 | ChatCompletion: A chat completion
41 | """
42 | async with semaphore or unlimited_semaphore:
43 | on_chunk = _create_on_chunk_callback(
44 | create_params, log_dir, log_results, on_chunk
45 | )
46 | if on_chunk:
47 | return await consume_chat_completion_stream(
48 | await client.chat.completions.create(**create_params, stream=True),
49 | on_chunk=on_chunk,
50 | )
51 | else:
52 | return await client.chat.completions.create(**create_params)
53 |
54 |
55 | def _create_on_chunk_callback(
56 | create_params: CreateParams,
57 | log_dir: str | None,
58 | log_results: bool,
59 | on_chunk: Callable[[ChatCompletionChunk, ChatCompletion], None] | None = None,
60 | ) -> Callable[[ChatCompletionChunk, ChatCompletion], None] | None:
61 | """Create a callback function for handling streaming chunks.
62 |
63 | This function sets up logging and wraps the user's callback function
64 | if provided, or returns None if no logging or callbacks are needed.
65 |
66 | Args:
67 | create_params: Parameters for the completion (used for conversation history)
68 | log_dir: Optional custom directory for logging
69 | log_results: Whether to log results
70 | on_chunk: Optional user callback for streaming
71 |
72 | Returns:
73 | A callback function or None if no callback is needed
74 | """
75 | if not (log_results or on_chunk):
76 | return None
77 |
78 | # Set up logging
79 | log_dir = log_dir or "./logs/chat-completions"
80 | os.makedirs(log_dir, exist_ok=True)
81 | log_file = os.path.join(log_dir, f"{datetime.now().isoformat()}.log")
82 |
83 | # Write conversation history to the log file
84 | if log_results:
85 | with open(log_file, "w") as f:
86 | f.write(
87 | "".join(
88 | f"{message['role'].capitalize()}:\n{message.get('content', '')}\n\n"
89 | for message in create_params["messages"]
90 | )
91 | + "Assistant:\n"
92 | )
93 |
94 | # Create a callback function that handles both user callbacks and logging
95 | def callback(chunk: ChatCompletionChunk, completion: ChatCompletion) -> None:
96 | # Call user's callback if provided
97 | if on_chunk:
98 | on_chunk(chunk, completion)
99 |
100 | # Log chunk content if enabled
101 | if log_results and chunk.choices:
102 | try:
103 | with timeout():
104 | with open(log_file, "a") as f:
105 | f.write(chunk.choices[0].delta.content or "")
106 | except TimeoutError:
107 | pass # Skip writing this chunk if it times out
108 |
109 | return callback
110 |
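# Example usage (illustrative sketch; the exact create_params fields come from
# lib/types.CreateParams, and the endpoint/model below are placeholders):
#
#     client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
#     completion = await get_chat_completion(
#         client,
#         model="Qwen/Qwen2.5-14B-Instruct",
#         messages=[{"role": "user", "content": "Who committed the crime?"}],
#         semaphore=asyncio.Semaphore(8),  # cap concurrent requests
#     )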
--------------------------------------------------------------------------------
/lib/grpo.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field, fields
2 | import torch
3 | from typing import Iterable, Optional, Union
4 |
5 | ignore_labels_cache: dict[
6 | tuple[torch.Size, Union[int, float], torch.dtype, torch.device], torch.Tensor
7 | ] = {}
8 |
9 |
10 | def shift_tensor(
11 | labels: torch.Tensor, ignore_label: Optional[Union[int, float]] = None
12 | ) -> torch.Tensor:
13 | if ignore_label is None:
14 | ignore_label = (
15 | -100
16 | if labels.dtype in (torch.int32, torch.int64, torch.int16, torch.int8)
17 | else float("nan")
18 | )
19 |
20 | # Create a tensor of ignore labels every time if we are compiling, otherwise cache it
21 | if torch.compiler.is_compiling():
22 | ignore_labels = torch.full(
23 | (labels.shape[0], 1), ignore_label, dtype=labels.dtype, device=labels.device
24 | )
25 | else:
26 | key = (labels.shape[:1], ignore_label, labels.dtype, labels.device)  # keyed on batch size: the cached tensor has shape (batch_size, 1)
27 | if key not in ignore_labels_cache:
28 | ignore_labels_cache[key] = torch.full(
29 | (labels.shape[0], 1),
30 | ignore_label,
31 | dtype=labels.dtype,
32 | device=labels.device,
33 | )
34 | ignore_labels = ignore_labels_cache[key]
35 |
36 | # Shift labels to compute loss
37 | return torch.cat((labels[..., 1:], ignore_labels), dim=1)
38 |
39 |
40 | @dataclass
41 | class GRPOResult:
42 | num_tokens: torch.Tensor = field(default_factory=lambda: torch.tensor(0))
43 | policy_loss: torch.Tensor = field(default_factory=lambda: torch.tensor(0.0))
44 | entropy: torch.Tensor = field(default_factory=lambda: torch.tensor(0.0))
45 | kl_div: torch.Tensor = field(default_factory=lambda: torch.tensor(0.0))
46 | entropy_weight: torch.Tensor = field(default_factory=lambda: torch.tensor(0.0))
47 | kl_weight: torch.Tensor = field(default_factory=lambda: torch.tensor(0.0))
48 |
49 | def named_tensors(self) -> Iterable[tuple[str, torch.Tensor]]:
50 | for field in fields(self):
51 | yield field.name, getattr(self, field.name)
52 |
53 | def per_token(self) -> "GRPOResult":
54 | return GRPOResult(
55 | **{name: tensor / self.num_tokens for name, tensor in self.named_tensors()}
56 | )
57 |
58 | def tensors(self) -> Iterable[torch.Tensor]:
59 | return (tensor for _, tensor in self.named_tensors())
60 |
61 | def to(self, target: Union[torch.device, torch.dtype]) -> "GRPOResult":
62 | return GRPOResult(
63 | **{name: tensor.to(target) for name, tensor in self.named_tensors()}
64 | )
65 |
66 | def __iadd__(self, other: "GRPOResult") -> "GRPOResult":
67 | for tensor, other_tensor in zip(self.tensors(), other.tensors()):
68 | tensor += other_tensor.to(tensor.device)
69 | return self
70 |
71 | @property
72 | def total_loss(self) -> torch.Tensor:
73 | return (
74 | self.policy_loss
75 | - self.entropy * self.entropy_weight / self.num_tokens
76 | + torch.nan_to_num(self.kl_div, 0.0) * self.kl_weight / self.num_tokens
77 | )
78 |
79 |
80 | class GRPO(torch.nn.Module):
81 | def __init__(
82 | self, clip_epsilon: float = 0.2, entropy_coef: float = 0.0, kl_coef: float = 0.0
83 | ) -> None:
84 | """
85 | Initialize the GRPO loss.
86 |
87 | Args:
88 | clip_epsilon (float): The epsilon value for clipping the policy ratio.
89 | entropy_coef (float): The coefficient for the entropy bonus.
90 | kl_coef (float): The coefficient for the KL divergence penalty.
91 | """
92 | super().__init__()
93 | self.clip_epsilon = clip_epsilon
94 | self.entropy_coef = entropy_coef
95 | self.kl_coef = kl_coef
96 |
97 | def forward(
98 | self,
99 | logits: Union[torch.Tensor, list[torch.Tensor]],
100 | tokens: torch.Tensor,
101 | advantages: torch.Tensor,
102 | logprobs: torch.Tensor,
103 | reference_logprobs: torch.Tensor | None,
104 | mask: torch.Tensor,
105 | weights: torch.Tensor,
106 | bos_id: int,
107 | ) -> GRPOResult:
108 | """
109 | Computes the GRPO loss for sequence data, supporting both regular and chunked inputs.
110 |
111 | Args:
112 | logits (Union[Tensor, List[Tensor]]):
113 | Either a single tensor of shape (batch_size, sequence_length, vocab_size)
114 | or a list of chunked tensors, each of shape
115 | (batch_size, sequence_length/num_chunks, vocab_size).
116 | tokens (Tensor):
117 | Shape: (batch_size, sequence_length)
118 | Token indices.
119 | advantages (Tensor):
120 | Shape: (batch_size, sequence_length)
121 | Token advantages.
122 | logprobs (Tensor):
123 | Shape: (batch_size, sequence_length)
124 | Token log probabilities.
125 | reference_logprobs (Tensor | None):
126 | Shape: (batch_size, sequence_length)
127 | Reference token log probabilities.
128 | mask (Tensor):
129 | Shape: (batch_size, sequence_length)
130 | Boolean mask specifying positions where loss should be computed.
131 | weights (Tensor):
132 | Shape: (batch_size, sequence_length)
133 | Weights for each token.
134 | bos_id (int):
135 | Index of the beginning of sequence token in the vocabulary.
136 |
137 | Returns:
138 | GRPOResult: The combined loss results across all chunks.
139 | """
140 | if isinstance(logits, list):
141 | result = GRPOResult().to(logits[0].device)
142 | num_chunks = len(logits)
143 | for chunked_args in zip(
144 | logits,
145 | tokens.chunk(num_chunks, dim=1),
146 | advantages.chunk(num_chunks, dim=1),
147 | logprobs.chunk(num_chunks, dim=1),
148 | (
149 | reference_logprobs.chunk(num_chunks, dim=1)
150 | if reference_logprobs is not None
151 | else [None] * num_chunks
152 | ),
153 | mask.chunk(num_chunks, dim=1),
154 | weights.chunk(num_chunks, dim=1),
155 | ):
156 | result += self._forward_chunk(*chunked_args, bos_id=bos_id)
157 | return result
158 |
159 | return self._forward_chunk(
160 | logits,
161 | tokens,
162 | advantages,
163 | logprobs,
164 | reference_logprobs,
165 | mask,
166 | weights,
167 | bos_id,
168 | )
169 |
170 | def _forward_chunk(
171 | self,
172 | logits: torch.Tensor,
173 | tokens: torch.Tensor,
174 | advantages: torch.Tensor,
175 | logprobs: torch.Tensor,
176 | reference_logprobs: torch.Tensor | None,
177 | mask: torch.Tensor,
178 | weights: torch.Tensor,
179 | bos_id: int,
180 | ) -> GRPOResult:
181 | """
182 | Processes a single chunk of the GRPO loss computation.
183 | """
184 | # Flatten logits tensor to shape (batch_size * sequence_length, vocab_size)
185 | logits = logits.view(-1, logits.size(-1))
186 | tokens = shift_tensor(tokens, bos_id).view(
187 | -1
188 | ) # (batch_size * sequence_length,)
189 | advantages = shift_tensor(advantages, 0).view(
190 | -1
191 | ) # (batch_size * sequence_length,)
192 | logprobs = shift_tensor(logprobs, 0).view(-1) # (batch_size * sequence_length,)
193 | if reference_logprobs is not None:
194 | reference_logprobs = shift_tensor(reference_logprobs, 0).view(-1)
195 | mask = shift_tensor(mask, False).view(-1) # (batch_size * sequence_length,)
196 | weights = shift_tensor(weights, 0).view(-1) # (batch_size * sequence_length,)
197 | num_tokens = mask.sum()
198 | dist = torch.distributions.Categorical(logits=logits)
199 | entropy = dist.entropy()[mask]
200 | new_logprobs = dist.log_prob(tokens)[mask]
201 | logprobs = logprobs[mask]
202 | logprobs = torch.where(torch.isnan(logprobs), new_logprobs, logprobs)
203 | if reference_logprobs is not None:
204 | reference_logprobs = reference_logprobs[mask]
205 | advantages = advantages[mask]
206 | diff = new_logprobs - logprobs
207 | prob_ratio = torch.exp(diff)
208 | policy_loss = -torch.min(
209 | prob_ratio * advantages,
210 | torch.clip(prob_ratio, 1 - self.clip_epsilon, 1 + self.clip_epsilon)
211 | * advantages,
212 | )
213 | if reference_logprobs is not None:
214 | kl_div = torch.nn.functional.kl_div(
215 | new_logprobs,
216 | reference_logprobs,
217 | reduction="none",
218 | log_target=True,
219 | )
220 | else:
221 | kl_div = torch.tensor(torch.nan, device=logits.device)
222 | weights = weights[mask]
223 | return GRPOResult(
224 | num_tokens=num_tokens,
225 | policy_loss=policy_loss.mul(weights).sum(),
226 | entropy=entropy.mul(weights).sum(),
227 | kl_div=kl_div.mul(weights).sum(),
228 | entropy_weight=self.entropy_coef * num_tokens,
229 | kl_weight=self.kl_coef * num_tokens,
230 | )
231 |
--------------------------------------------------------------------------------
/lib/inference_early_stop.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | import math
3 | from openai.types.chat.chat_completion import ChatCompletion
4 | from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
5 |
6 |
7 | @dataclass
8 | class InferenceEarlyStop:
9 | """
10 | Utility for stopping inference early if token log probabilities are too low.
11 |
12 | Args:
13 | alpha: The smoothing factor for the exponential weighted moving average.
14 | threshold: The log probability threshold to stop inference below.
15 | log_early_stops: Whether to log early stops.
16 | log_last_n_characters: The number of characters to log from the end of the stopped completion.
17 | """
18 |
19 | alpha: float = 0.992
20 | threshold: float = -3
21 | log_early_stops: bool = False
22 | log_last_n_characters: int = 64
23 | ewm_logprobs: dict[str, float] = field(default_factory=dict)
24 |
25 | def __call__(self, chunk: ChatCompletionChunk, completion: ChatCompletion) -> None:
26 | # TODO: handle multiple choices and refusal logprobs
27 | if (
28 | not chunk.choices
29 | or not chunk.choices[0].logprobs
30 | or not chunk.choices[0].logprobs.content
31 | ):
32 | return
33 | for token_logprob in chunk.choices[0].logprobs.content:
34 | if token_logprob.logprob is None or math.isnan(token_logprob.logprob):
35 | raise StopIteration()
36 | ewm_logprob = (
37 | self.alpha * self.ewm_logprobs.get(completion.id, 0)
38 | + (1 - self.alpha) * token_logprob.logprob
39 | )
40 | if ewm_logprob < self.threshold:
41 | if self.log_early_stops:
42 | print(
43 | f"Early stopping - ewm_logprob: {ewm_logprob} completion_tokens: {len(completion.choices[0].logprobs.content)}" # type: ignore
44 | )
45 | if self.log_last_n_characters:
46 | print(
47 | f"Last {self.log_last_n_characters} characters: {completion.choices[0].message.content[-self.log_last_n_characters :]}" # type: ignore
48 | )
49 | setattr(completion.choices[0], "early_stop", True)
50 | raise StopIteration()
51 | self.ewm_logprobs[completion.id] = ewm_logprob
52 |
--------------------------------------------------------------------------------
/lib/models.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | import torch
3 | from torchtune.models.qwen2_5 import (
4 | qwen2_5_7b_base,
5 | qwen2_5_14b_base,
6 | qwen2_5_14b_instruct,
7 | qwen2_5_32b_base,
8 | qwen2_5_32b_instruct,
9 | qwen2_5_72b_instruct,
10 | )
11 | from torchtune.models.llama3_1 import llama3_1_8b, llama3_1_70b
12 | from torchtune.modules import TransformerDecoder
13 | from typing import Callable
14 |
15 |
16 | @dataclass
17 | class Model:
18 | """Basic language model configuration"""
19 |
20 | base_model: str
21 | min_gpus: int
22 | tune_model_type: str
23 | tune_model: Callable[[], TransformerDecoder]
24 | tune_num_output_chunks: int
25 |
26 | def __post_init__(self) -> None:
27 | assert (
28 | torch.cuda.device_count() >= self.min_gpus
29 | ), f"{self.base_model} requires at least {self.min_gpus} GPUs"
30 |
31 |
32 | def distilled_qwen_7b() -> Model:
33 | """deepseek-ai/DeepSeek-R1-Distill-Qwen-7B model config."""
34 | return Model(
35 | base_model="deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
36 | min_gpus=1,
37 | tune_model_type="QWEN2",
38 | tune_model=qwen2_5_7b_base,
39 | tune_num_output_chunks=8,
40 | )
41 |
42 |
43 | def theta_8b() -> Model:
44 | """NousResearch/Hermes-2-Theta-Llama-3-8B model config."""
45 | return Model(
46 | base_model="NousResearch/Hermes-2-Theta-Llama-3-8B",
47 | min_gpus=1,
48 | tune_model_type="LLAMA3",
49 | tune_model=llama3_1_8b,
50 | tune_num_output_chunks=8,
51 | )
52 |
53 |
54 | def qwen_14b() -> Model:
55 | """Qwen/Qwen2.5-14B-Instruct model config."""
56 | return Model(
57 | base_model="Qwen/Qwen2.5-14B-Instruct",
58 | min_gpus=2,
59 | tune_model_type="QWEN2",
60 | tune_model=qwen2_5_14b_instruct,
61 | tune_num_output_chunks=2,
62 | )
63 |
64 |
65 | def distilled_qwen_14b() -> Model:
66 | """deepseek-ai/DeepSeek-R1-Distill-Qwen-14B model config."""
67 | return Model(
68 | base_model="deepseek-ai/DeepSeek-R1-Distill-Qwen-14B",
69 | min_gpus=2,
70 | tune_model_type="QWEN2",
71 | tune_model=qwen2_5_14b_base,
72 | tune_num_output_chunks=2,
73 | )
74 |
75 |
76 | def qwen_32b() -> Model:
77 | """Qwen/Qwen2.5-32B-Instruct model config."""
78 | return Model(
79 | base_model="Qwen/Qwen2.5-32B-Instruct",
80 | min_gpus=4,
81 | tune_model_type="QWEN2",
82 | tune_model=qwen2_5_32b_instruct,
83 | tune_num_output_chunks=2,
84 | )
85 |
86 |
87 | def distilled_qwen_32b() -> Model:
88 | """deepseek-ai/DeepSeek-R1-Distill-Qwen-32B model config."""
89 | return Model(
90 | base_model="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
91 | min_gpus=4,
92 | tune_model_type="QWEN2",
93 | tune_model=qwen2_5_32b_base,
94 | tune_num_output_chunks=2,
95 | )
96 |
97 |
98 | def llama_70b() -> Model:
99 | """unsloth/Llama-3.3-70B-Instruct model config."""
100 | return Model(
101 | base_model="unsloth/Llama-3.3-70B-Instruct",
102 | min_gpus=8,
103 | tune_model_type="LLAMA3",
104 | tune_model=llama3_1_70b,
105 | tune_num_output_chunks=2,
106 | )
107 |
108 |
109 | def distilled_llama_70b() -> Model:
110 | """deepseek-ai/DeepSeek-R1-Distill-Llama-70B model config."""
111 | return Model(
112 | base_model="deepseek-ai/DeepSeek-R1-Distill-Llama-70B",
113 | min_gpus=8,
114 | tune_model_type="LLAMA3",
115 | tune_model=llama3_1_70b,
116 | tune_num_output_chunks=8,
117 | )
118 |
119 |
120 | def qwen_72b() -> Model:
121 | """Qwen/Qwen2.5-72B-Instruct model config."""
122 | return Model(
123 | base_model="Qwen/Qwen2.5-72B-Instruct",
124 | min_gpus=8,
125 | tune_model_type="QWEN2",
126 | tune_model=qwen2_5_72b_instruct,
127 | tune_num_output_chunks=2,
128 | )
129 |
--------------------------------------------------------------------------------
/lib/pack.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import os
3 | import random
4 | import seaborn as sns
5 | import torch
6 | from torch.utils.data import Dataset
7 | from typing import TypedDict, Unpack
8 |
9 | from .tokenize import TokenizedResult
10 |
11 |
12 | class PackedTensors(TypedDict):
13 | tokens: torch.Tensor
14 | group_ids: torch.Tensor
15 | parent_ids: torch.Tensor
16 | input_pos: torch.Tensor
17 | assistant_mask: torch.Tensor
18 | logprobs: torch.Tensor
19 | advantages: torch.Tensor
20 | weights: torch.Tensor
21 |
22 |
23 | class DiskPackedTensors(TypedDict):
24 | dir: str
25 | num_sequences: int
26 | sequence_length: int
27 |
28 |
29 | class PackedDataset(Dataset[PackedTensors]):
30 | def __init__(self, **kwargs: Unpack[DiskPackedTensors]) -> None:
31 | self.tensors = packed_tensors_from_dir(**kwargs)
32 |
33 | def __len__(self) -> int:
34 | return self.tensors["tokens"].shape[0]
35 |
36 | def __getitem__(self, index: int) -> PackedTensors:
37 | return {key: tensor[index] for key, tensor in self.tensors.items()} # type: ignore
38 |
39 |
40 | def packed_tensors_from_tokenized_results(
41 | tokenized_results: list[TokenizedResult],
42 | seq_len: int,
43 | pad_token_id: int = -100,
44 | truncate_long_results: bool = True,
45 | ) -> PackedTensors:
46 | token_ids: list[list[int]] = [[]]
47 | group_ids: list[list[int]] = [[]]
48 | parent_ids: list[list[int]] = [[]]
49 | input_pos: list[list[int]] = [[]]
50 | assistant_mask: list[list[int]] = [[]]
51 | logprobs: list[list[float]] = [[]]
52 | advantages: list[list[float]] = [[]]
53 | weights: list[list[float]] = [[]]
54 |
55 | for result in tokenized_results:
56 | if len(result.token_ids) > seq_len and not truncate_long_results:
57 | print("Result is too long, skipping")
58 | continue
59 | result_without_prompt = result.without_prompt()
60 | if sum(result.assistant_mask) == 0:
61 | print("Result has no unique completion tokens, skipping")
62 | continue
63 | if (
64 | len(token_ids[-1])
65 | + (
66 | len(result_without_prompt.token_ids)
67 | if result.prompt_id in group_ids[-1]
68 | else len(result.token_ids)
69 | )
70 | > seq_len
71 | ):
72 | token_ids.append([])
73 | group_ids.append([])
74 | parent_ids.append([])
75 | input_pos.append([])
76 | assistant_mask.append([])
77 | logprobs.append([])
78 | advantages.append([])
79 | weights.append([])
80 | group_id = random.randint(-(2**63), 2**63 - 1)
81 | if result.prompt_id in group_ids[-1]:
82 | result = result_without_prompt
83 | token_ids[-1].extend(result.token_ids)
84 | group_ids[-1].extend(
85 | [result.prompt_id] * result.prompt_length
86 | + [group_id] * (len(result.token_ids) - result.prompt_length)
87 | )
88 | parent_ids[-1].extend([result.prompt_id] * len(result.token_ids))
89 | input_pos[-1].extend(result.input_pos)
90 | assistant_mask[-1].extend(result.assistant_mask)
91 | offset = len(logprobs[-1])
92 | logprobs[-1].extend([float("nan")] * len(result.token_ids))
93 | if result.token_logprobs:
94 | assistant_indices = [
95 | i for i, mask in enumerate(result.assistant_mask) if mask
96 | ]
97 | assert len(assistant_indices) <= len(result.token_logprobs)
98 | for idx, token_logprob in zip(
99 | assistant_indices,
100 | result.token_logprobs[
101 | len(result.token_logprobs) - len(assistant_indices) :
102 | ],
103 | ):
104 | logprobs[-1][idx + offset] = token_logprob.logprob or float("nan")
105 | advantages[-1].extend([result.advantage] * len(result.token_ids))
106 | # prevent the model unlearning when to stop
107 | # advantages[-1][-1] = max(0, advantages[-1][-1])
108 | advantages[-1][-1] = 0
109 | weights[-1].extend([1 / sum(result.assistant_mask)] * len(result.token_ids))
110 | if truncate_long_results:
111 | token_ids[-1] = token_ids[-1][:seq_len]
112 | group_ids[-1] = group_ids[-1][:seq_len]
113 | parent_ids[-1] = parent_ids[-1][:seq_len]
114 | input_pos[-1] = input_pos[-1][:seq_len]
115 | assistant_mask[-1] = assistant_mask[-1][:seq_len]
116 | logprobs[-1] = logprobs[-1][:seq_len]
117 | advantages[-1] = advantages[-1][:seq_len]
118 | weights[-1] = weights[-1][:seq_len]
119 |
120 | def pad(values: list[list], pad_value) -> list[list]:
121 | max_len = seq_len
122 | for value in values:
123 | value.extend([pad_value] * (max_len - len(value)))
124 | return values
125 |
126 | assistant_mask_tensor = torch.tensor(pad(assistant_mask, 0), dtype=torch.bool)
127 | weights_tensor = torch.tensor(pad(weights, 0.0))
128 | weights_tensor = torch.where(
129 | assistant_mask_tensor, weights_tensor, torch.zeros_like(weights_tensor)
130 | )
131 | weights_tensor[assistant_mask_tensor] /= weights_tensor[
132 | assistant_mask_tensor
133 | ].mean()
134 |
135 | return {
136 | "tokens": torch.tensor(pad(token_ids, pad_token_id)),
137 | "group_ids": torch.tensor(pad(group_ids, -1)),
138 | "parent_ids": torch.tensor(pad(parent_ids, -1)),
139 | "input_pos": torch.tensor(pad(input_pos, 0)),
140 | "assistant_mask": assistant_mask_tensor,
141 | "logprobs": torch.tensor(pad(logprobs, float("nan"))),
142 | "advantages": torch.tensor(pad(advantages, 0.0)),
143 | "weights": weights_tensor,
144 | }
145 |
146 |
147 | def packed_tensors_from_dir(**kwargs: Unpack[DiskPackedTensors]) -> PackedTensors:
148 | os.makedirs(kwargs["dir"], exist_ok=True)
149 | return {
150 | key: torch.from_file(
151 | f"{kwargs["dir"]}/{key}.pt",
152 | shared=True,
153 | size=kwargs["num_sequences"]
154 | * kwargs["sequence_length"]
155 | * (kwargs["sequence_length"] if key == "mask" else 1),
156 | dtype=dtype,
157 | )
158 | .view(kwargs["num_sequences"], kwargs["sequence_length"], -1)
159 | .squeeze()
160 | for key, dtype in {
161 | "tokens": torch.long,
162 | "group_ids": torch.long,
163 | "parent_ids": torch.long,
164 | "input_pos": torch.long,
165 | "assistant_mask": torch.bool,
166 | "logprobs": torch.float32,
167 | "advantages": torch.float32,
168 | "weights": torch.float32,
169 | }.items()
170 | } # type: ignore
171 |
172 |
173 | def packed_tensors_to_dir(tensors: PackedTensors, dir: str) -> DiskPackedTensors:
174 | os.makedirs(dir, exist_ok=True)
175 | disk_packed_tensors: DiskPackedTensors = {
176 | "dir": dir,
177 | "num_sequences": tensors["tokens"].shape[0],
178 | "sequence_length": tensors["tokens"].shape[1],
179 | }
180 | for key, tensor in packed_tensors_from_dir(**disk_packed_tensors).items():
181 | tensor.copy_(tensors[key]) # type: ignore
182 | return disk_packed_tensors
183 |
184 |
185 | def plot_packed_tensors(packed_tensors: PackedTensors) -> None:
186 | plt.figure(figsize=(15, 24))
187 |
188 | for tensor, label, title, subplot_idx in (
189 | (packed_tensors["tokens"], "Token IDs", "Token IDs", 1),
190 | (packed_tensors["logprobs"], "Log Probabilities", "Token Log Probs", 2),
191 | (packed_tensors["group_ids"], "Group IDs", "Token Groups", 3),
192 | (packed_tensors["parent_ids"], "Parent IDs", "Parent IDs", 4),
193 | (packed_tensors["input_pos"], "Position", "Input Position", 5),
194 | (packed_tensors["assistant_mask"], "Assistant Mask", "Assistant Mask", 6),
195 | (packed_tensors["advantages"], "Advantages", "Token Advantages", 7),
196 | (packed_tensors["weights"], "Weights", "Token Weights", 8),
197 | ):
198 | plt.subplot(4, 2, subplot_idx)
199 | sns.heatmap(
200 | tensor.numpy(), cmap="viridis", cbar_kws={"label": label}, xticklabels=False # type: ignore
201 | )
202 | plt.title(title)
203 | plt.xlabel("Sequence Position")
204 | plt.ylabel("Batch")
205 |
206 | plt.tight_layout()
207 | plt.show()
208 |
--------------------------------------------------------------------------------
/lib/recipe.py:
--------------------------------------------------------------------------------
1 | # This is probably most similar to the recipe found in
2 | # https://github.com/pytorch/torchtune/blob/main/recipes/full_finetune_distributed.py
3 |
4 | from functools import partial
5 | import os
6 | from omegaconf import DictConfig, ListConfig
7 | import sys
8 | import time
9 | import torch
10 | import torch.distributed
11 | from torch.optim.optimizer import Optimizer
12 | from torch.utils.data import DataLoader, Dataset, DistributedSampler
13 | from torchtune import config, modules, training, utils
14 | from torchtune.modules import TransformerDecoder
15 | from torchtune.recipe_interfaces import FTRecipeInterface
16 | from torchtune.training import DummyProfiler, PROFILER_KEY
17 | from torchtune.training.activations import apply_selective_activation_checkpointing
18 | from torchtune.training.checkpointing import Checkpointer
19 | from torchtune.training.metric_logging import MetricLoggerInterface
20 | from tqdm import tqdm
21 | from typing import (
22 | Any,
23 | cast,
24 | Callable,
25 | Dict,
26 | Generic,
27 | Iterator,
28 | List,
29 | Mapping,
30 | Optional,
31 | overload,
32 | ParamSpec,
33 | Tuple,
34 | TypeVar,
35 | Union,
36 | )
37 | from warnings import warn
38 |
39 | from .pack import PackedTensors
40 | from .grpo import GRPO, GRPOResult, shift_tensor
41 |
42 | log = utils.get_logger("DEBUG")
43 |
44 | T = TypeVar("T", covariant=True)
45 | P = ParamSpec("P")
46 |
47 |
48 | class ComponentConfig(DictConfig, Generic[T]):
49 | @overload
50 | def __init__(
51 | self,
52 | _component_: Callable[P, T],
53 | *args: P.args,
54 | **kwargs: P.kwargs,
55 | ) -> None: ...
56 |
57 | @overload
58 | def __init__(self, _component_: str, *args: Any, **kwargs: Any) -> None: ...
59 |
60 | def __init__(
61 | self, _component_: Union[Callable, str], *args: Any, **kwargs: Any
62 | ) -> None:
63 | super().__init__({}, flags={"allow_objects": True})
64 | self._component_ = _component_
65 | if args:
66 | raise ValueError(
67 | "Positional arguments are not supported in ComponentConfig"
68 | )
69 | self.update(kwargs)
70 |
71 | def dict_config(self) -> DictConfig:
72 | return DictConfig(
73 | {
74 | "_component_": (
75 | self._component_
76 | if isinstance(self._component_, str)
77 | else f"{self._component_.__module__}.{self._component_.__name__}"
78 | ),
79 | **{k: v for k, v in self.items() if k != "_component_"},
80 | }
81 | )
82 |
83 |
84 | def instantiate_component(cfg: ComponentConfig[T], *args: Any, **kwargs: Any) -> T:
85 | if isinstance(cfg._component_, str):
86 | return config.instantiate(cfg, *args, **kwargs)
87 | _kwargs = {
88 | str(k): list(v) if isinstance(v, ListConfig) else v
89 | for k, v in cfg.items()
90 | if k != "_component_"
91 | }
92 | _kwargs.update(kwargs)
93 | return cfg._component_(*args, **_kwargs)
94 |
95 |
96 | PLACEHOLDER: Any = None
97 |
98 |
99 | class TuneRecipeConfig(DictConfig):
100 | def __init__(
101 | self,
102 | *,
103 | device: Optional[Union[str, torch.device]] = "cuda",
104 | dtype: Optional[Union[str, torch.dtype]] = "bf16",
105 | optimizer: ComponentConfig[Optimizer] = ComponentConfig(
106 | "torch.optim.AdamW", lr=2e-5, fused=True
107 | ),
108 | resume_from_checkpoint: bool = False,
109 | gradient_accumulation_steps: int = 1,
110 | checkpointer: ComponentConfig[Checkpointer] = PLACEHOLDER,
111 | seed: Optional[int] = None,
112 | epochs: int = 1,
113 | max_steps_per_epoch: Optional[int] = None,
114 | metric_logger: ComponentConfig[MetricLoggerInterface] = PLACEHOLDER,
115 | model: ComponentConfig[TransformerDecoder] = PLACEHOLDER,
116 | loss: ComponentConfig[GRPO] = ComponentConfig(GRPO),
117 | dataset: ComponentConfig[Dataset[PackedTensors]] = PLACEHOLDER,
118 | shuffle: bool = False,
119 | batch_size: int = 1,
120 | fsdp_cpu_offload: Optional[bool] = None,
121 | log_every_n_steps: Optional[int] = None,
122 | log_peak_memory_stats: Optional[bool] = None,
123 | log_grad_magnitude: Optional[bool] = None,
124 | optimizer_in_bwd: Optional[bool] = None,
125 | clip_grad_norm: Optional[Union[str, float]] = None,
126 | enable_activation_checkpointing: Optional[bool] = None,
127 | enable_activation_offloading: Optional[bool] = None,
128 | save_intermediate_checkpoints: Optional[bool] = None,
129 | reference_checkpointer: Optional[ComponentConfig[Checkpointer]] = None,
130 | compile: Optional[bool] = None,
131 | custom_sharded_layers: Optional[List[str]] = None,
132 | fsdp_reshard_after_forward: Optional[bool] = None,
133 | ac_mode: Optional[str] = None,
134 | ac_option: Optional[int] = None,
135 | num_output_chunks: Optional[int] = None,
136 | profiler: Optional[ComponentConfig] = None,
137 | ) -> None:
138 | super().__init__({})
139 | self.device = device
140 | self.dtype = dtype
141 | self.optimizer = optimizer
142 | self.resume_from_checkpoint = resume_from_checkpoint
143 | self.gradient_accumulation_steps = gradient_accumulation_steps
144 | self.checkpointer = checkpointer
145 | self.seed = seed
146 | self.epochs = epochs
147 | self.max_steps_per_epoch = max_steps_per_epoch
148 | self.metric_logger = metric_logger
149 | self.model = model
150 | self.loss = loss
151 | self.dataset = dataset
152 | self.shuffle = shuffle
153 | self.batch_size = batch_size
154 | if fsdp_cpu_offload is not None:
155 | self.fsdp_cpu_offload = fsdp_cpu_offload
156 | if log_every_n_steps is not None:
157 | self.log_every_n_steps = log_every_n_steps
158 | if log_peak_memory_stats is not None:
159 | self.log_peak_memory_stats = log_peak_memory_stats
160 | if log_grad_magnitude is not None:
161 | self.log_grad_magnitude = log_grad_magnitude
162 | if optimizer_in_bwd is not None:
163 | self.optimizer_in_bwd = optimizer_in_bwd
164 | if clip_grad_norm is not None:
165 | self.clip_grad_norm = clip_grad_norm
166 | if enable_activation_checkpointing is not None:
167 | self.enable_activation_checkpointing = enable_activation_checkpointing
168 | if enable_activation_offloading is not None:
169 | self.enable_activation_offloading = enable_activation_offloading
170 | if save_intermediate_checkpoints is not None:
171 | self.save_intermediate_checkpoints = save_intermediate_checkpoints
172 | if reference_checkpointer is not None:
173 | self.reference_checkpointer = reference_checkpointer
174 | if compile is not None:
175 | self.compile = compile
176 | if custom_sharded_layers is not None:
177 | self.custom_sharded_layers = custom_sharded_layers
178 | if fsdp_reshard_after_forward is not None:
179 | self.fsdp_reshard_after_forward = fsdp_reshard_after_forward
180 | if ac_mode is not None:
181 | self.ac_mode = ac_mode
182 | if ac_option is not None:
183 | self.ac_option = ac_option
184 | if num_output_chunks is not None:
185 | self.num_output_chunks = num_output_chunks
186 | if profiler is not None:
187 | self.profiler = profiler
188 |
189 | def dict_config(self) -> DictConfig:
190 | config = DictConfig({})
191 | for k, v in self.items():
192 | if isinstance(v, DictConfig) and "_component_" in v:
193 | v = v.copy()
194 | v["_component_"] = (
195 | v["_component_"]
196 | if isinstance(v["_component_"], str)
197 | else f"{v['_component_'].__module__}.{v['_component_'].__name__}"
198 | )
199 | config[k] = v
200 | elif isinstance(v, torch.device):
201 | config[k] = str(v)
202 | elif isinstance(v, torch.dtype):
203 | config[k] = str(v)
204 | else:
205 | config[k] = v
206 | return config
207 |
208 |
209 | class TypedDataLoader(DataLoader[T]):
210 | def __iter__(self) -> Iterator[T]:
211 | return super().__iter__()
212 |
213 |
214 | class TuneRecipe(FTRecipeInterface):
215 | """
216 | Full finetuning recipe for dense transformer-based LLMs such as Llama2. This recipe supports
217 | distributed training and can be run on a single node (1 to 8 GPUs).
218 |
219 | Features:
220 | - FSDP. Supported using PyTorch's FSDP APIs. CPU offload of parameters, gradients, and optimizer states
221 | is supported via ``fsdp_cpu_offload``. Resharding of parameters after the forward pass is
222 | done by default (corresponding to FULL_SHARD sharding strategy), but can be disabled by setting the config
223 | ``fsdp_reshard_after_forward`` to False (this corresponds to SHARD_GRAD_OP sharding strategy).
224 | DDP is currently not supported. Training on CPU is not supported.
225 |
226 | - Activation Checkpointing. This can be controlled using the ``enable_activation_checkpointing``
227 | flag. Activation checkpointing helps reduce the memory footprint since we no longer keep
228 | activations in memory and instead recompute them during the backward pass. This is especially
229 | helpful for larger batch sizes when you're memory constrained. But these savings in memory
230 | come at the cost of training performance. In most cases training can slow down quite a bit as
231 | a result of this activation recomputation.
232 |
233 | - Activation Offloading. This can be controlled using the ``enable_activation_offloading``
234 | flag. Activation offloading is a technique similar to activations checkpointing that helps
235 | reduce the memory footprint to prevent OOMs on CUDA and enable bigger batches. Where activations
236 | checkpointing drops the activation in the forward to recompute it later in the backward,
237 | activations offloading will drop the activation in the forward to the CPU and bring it
238 | back during the backward pass. As always, there is a tradeoff--these savings in memory can
239 | come at the cost of training performance and CPU resources. To recover some runtime cost,
240 | we've added an option to enable offloading on a different stream to permit overlapping with
241 | the computation. This option is currently only available on PyTorch 2.5 or later and will
242 | be enabled by default if an acceptable torch version is found. Activation offloading can be
243 | used in conjunction with activation checkpointing.
244 |
245 | - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype``
246 | flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In
247 | most cases this should halve the memory footprint of full precision (fp32) training, without
248 | loss in model quality (will depend on the model, training data and other settings). For
249 | GPUs which do not support bfloat16, we fall back to fp32. Mixed precision training and fp16
250 | precision are currently not supported.
251 |
252 | - Gradient Accumulation. You can simulate larger batch sizes by accumulating gradients. This is
253 | controlled using the ``gradient_accumulation_steps`` flag.
254 |
255 | Total Batch Size = batch_size * number of GPUs * gradient accumulation steps.
256 |
257 | For example: with batch_size=1, nproc_per_node=2 and gradient_accumulation_steps=32 we get a
258 | total batch size of 64.
259 |
260 | Gradient accumulation is especially useful when you are memory constrained. In this case,
261 | accumulating gradients might give you better training speed than enabling activation
262 | checkpointing.
263 |
264 | - Checkpointing. Model weights are checkpointed both at the end of each epoch and at the end of
265 | training. Optimizer state and recipe state (seed, total_epochs, number of epochs run etc) are
266 | only saved at the end of a given epoch and used in case of resuming training.
267 |
268 | Resuming training is controlled by the ``resume_from_checkpoint`` flag. Mid-epoch checkpointing is
269 | currently not supported.
270 |
271 | For more details on the checkpointer, please take a look at
272 | our checkpointer deepdive (https://pytorch.org/torchtune/main/deep_dives/checkpointer.html).
273 |
274 | - Logging. Terminal, Disk, WandB and TensorBoard are all supported.
275 |
276 | - Gradient Clipping. Gradient clipping is supported using the ``clip_grad_norm`` flag. By default,
277 | ``clip_grad_norm`` is set to ``None``. If you only want to log the grad norm, you can set
278 | ``clip_grad_norm='inf'``.
279 |
280 | For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config
281 | has example commands for how to kick-off training.
282 |
283 | Args:
284 | cfg (DictConfig): OmegaConf object parsed from yaml file
285 |
286 | Raises:
287 | ValueError: If ``dtype`` is set to fp16.
288 | RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16.
289 | RuntimeError: If ``left_pad_sequence`` is set as the data collator.
290 | RuntimeError: If ``enable_activation_offloading`` is True and device is not CUDA.
291 | RuntimeError: If ``enable_activation_offloading`` is True and ``enable_activation_checkpointing`` is False.
292 | """
293 |
294 | def __init__(self, cfg: TuneRecipeConfig) -> None:
295 | self._device = (
296 | cfg.device
297 | if isinstance(cfg.device, torch.device)
298 | else utils.get_device(device=cfg.device)
299 | )
300 | self._dtype = (
301 | cfg.dtype
302 | if isinstance(cfg.dtype, torch.dtype)
303 | else training.get_dtype(cfg.dtype, device=self._device)
304 | )
305 |
306 | if self._dtype == torch.float16:
307 | raise ValueError(
308 | "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead."
309 | )
310 |
311 | if (
312 | cfg.get("fsdp_cpu_offload", False)
313 | and cfg.optimizer.get("fused", False)
314 | and not utils.torch_version_ge("2.4.0")
315 | ):
316 | raise RuntimeError(
317 | "Using fused optimizer on CPU is only supported in PyTorch nightly."
318 | )
319 |
320 | # logging attributes
321 | self._log_every_n_steps: int = cfg.get("log_every_n_steps", 1)
322 | self._log_peak_memory_stats: bool = cfg.get("log_peak_memory_stats", False)
323 | self._log_grad_magnitude: bool = cfg.get("log_grad_magnitude", False)
324 |
325 | if self._log_peak_memory_stats and self._device.type != "cuda":
326 | log.info(
327 | "log_peak_memory_stats was set to True, however, training does not use cuda. Setting log_peak_memory_stats=False."
328 | )
329 | self._log_peak_memory_stats = False
330 |
331 | # _is_rank_zero is used primarily for logging. In the future, the logger
332 | # should directly take care of this
333 | _, rank = training.get_world_size_and_rank()
334 | self._is_rank_zero = rank == 0
335 |
336 | # Training cfg
337 | self._resume_from_checkpoint = cfg.resume_from_checkpoint
338 | self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
339 | self._optimizer_in_bwd: bool = cfg.get("optimizer_in_bwd", False)
340 | self._clip_grad_norm: Optional[Union[str, float]] = cfg.get(
341 | "clip_grad_norm", None
342 | )
343 |
344 | # Optimizer in backward is not compatible with gradient accumulation or gradient clipping
345 | if self._optimizer_in_bwd:
346 | if self._clip_grad_norm is not None:
347 | raise RuntimeError(
348 | "Gradient clipping is not supported with optimizer in bwd."
349 | "Please set clip_grad_norm=None, or optimizer_in_bwd=False."
350 | )
351 | if self._gradient_accumulation_steps > 1:
352 | raise RuntimeError(
353 | "Gradient accumulation is not supported with optimizer in bwd."
354 | "Please set gradient_accumulation_steps=1, or optimizer_in_bwd=False."
355 | )
356 |
357 | # activation checkpointing/offloading
358 | self._enable_activation_checkpointing: bool = cfg.get(
359 | "enable_activation_checkpointing", False
360 | )
361 | self._enable_activation_offloading: bool = cfg.get(
362 | "enable_activation_offloading", False
363 | )
364 | if self._enable_activation_offloading:
365 | if self._device.type != "cuda":
366 | raise RuntimeError(
367 | "enable_activation_offloading should only be True when training on CUDA"
368 | )
369 | if not self._enable_activation_checkpointing:
370 | raise RuntimeError(
371 | "enable_activation_offloading should only be True when enable_activation_checkpointing is True"
372 | )
373 | elif (
374 | self._enable_activation_checkpointing
375 | and cfg.checkpointer.model_type # TODO: `model_type` type is not defined
376 | != "LLAMA3_VISION"
377 | ):
378 | utils.log_rank_zero(
379 | log,
380 | "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. "
381 | "Enabling activation offloading should reduce memory further.",
382 | )
383 |
384 | # These are public properties which are updated by the checkpoint loader
385 | # when ``resume_from_checkpoint`` is `True` or validated in tests
386 | self.seed = training.set_seed(seed=cfg.seed)
387 | self.epochs_run = 0
388 | self.total_epochs = cfg.epochs
389 | self.max_steps_per_epoch = cfg.max_steps_per_epoch
390 | self.global_step = 0
391 | self._save_intermediate_checkpoints = cfg.get(
392 | "save_intermediate_checkpoints", False
393 | )
394 |
395 | def load_checkpoint(
396 | self, cfg_checkpointer: ComponentConfig[Checkpointer]
397 | ) -> Dict[str, Any]:
398 | """
399 | Extract the checkpoint state from file and validate. If resume_from_checkpoint
400 | is True, this also includes the recipe state.
401 | """
402 | self._checkpointer = instantiate_component(
403 | cfg_checkpointer,
404 | resume_from_checkpoint=self._resume_from_checkpoint,
405 | )
406 | checkpoint_dict = self._checkpointer.load_checkpoint()
407 |
408 | if self._resume_from_checkpoint:
409 | self._update_recipe_state(checkpoint_dict)
410 | return checkpoint_dict
411 |
412 | def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None:
413 | """
414 | Updates the recipe state from checkpoint.
415 | """
416 | try:
417 | self.epochs_run = ckpt_dict[training.EPOCHS_KEY]
418 |
419 | # on mismatch, warn the user and prevent the override
420 | if self.seed != ckpt_dict[training.SEED_KEY]:
421 | warn(
422 | message=(
423 | "Config value for seed does not match the checkpoint value, "
424 | f"using the checkpoint value: {ckpt_dict[training.SEED_KEY]}"
425 | )
426 | )
427 | self.seed = ckpt_dict[training.SEED_KEY]
428 | if self.max_steps_per_epoch != ckpt_dict[training.MAX_STEPS_KEY]:
429 | warn(
430 | message=(
431 | "Config value for max_steps_per_epoch does not match the checkpoint value, "
432 | f"using the checkpoint value: {ckpt_dict[training.MAX_STEPS_KEY]}"
433 | )
434 | )
435 | self.max_steps_per_epoch = ckpt_dict[training.MAX_STEPS_KEY]
436 |
437 | # on mismatch, warn the user but allow the override
438 | if self.total_epochs != ckpt_dict[training.TOTAL_EPOCHS_KEY]:
439 | warn(
440 | message=(
441 | "Config value for total_epochs does not match the checkpoint value, "
442 | f"using the config value: {self.total_epochs}"
443 | )
444 | )
445 |
446 | except KeyError as e:
447 | raise KeyError(
448 | "Checkpoint does not contain the required keys needed for updating recipe state. "
449 | "Are you sure you passed in the right recipe checkpoint?"
450 | ) from e
451 |
452 | def setup(self, cfg: TuneRecipeConfig) -> None:
453 | """
454 | Setup the recipe. This includes training state (if resume_from_checkpoint is True),
455 | model, tokenizer, loss, optimizer, sampler, and dataloader.
456 | """
457 | if self._is_rank_zero:
458 | self._metric_logger = instantiate_component(cfg.metric_logger)
459 |
460 | # log config with parameter override
461 | self._metric_logger.log_config(cfg)
462 |
463 | checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)
464 | if reference_checkpointer_cfg := cfg.get("reference_checkpointer", None):
465 | self.reference_model_state_dict = instantiate_component(
466 | reference_checkpointer_cfg
467 | ).load_checkpoint()[training.MODEL_KEY]
468 | else:
469 | self.reference_model_state_dict = None
470 |
471 | self._compile: bool = cfg.get("compile", False)
472 | if self._compile:
473 | torch.empty(1, device=self._device, requires_grad=True).backward()
474 | self._model = self._setup_model(
475 | cfg_model=cfg.model,
476 | enable_activation_checkpointing=self._enable_activation_checkpointing,
477 | enable_activation_offloading=self._enable_activation_offloading,
478 | custom_sharded_layers=cfg.get("custom_sharded_layers", None),
479 | fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False),
480 | reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True),
481 | model_state_dict=checkpoint_dict[training.MODEL_KEY],
482 | reference_model_state_dict=self.reference_model_state_dict,
483 | ac_mode=cfg.get("ac_mode", None),
484 | ac_option=cfg.get("ac_option", None),
485 | )
486 | self._model.output_hidden_states = [len(self._model.layers) - 1]
487 |
488 | if self.reference_model_state_dict:
489 | # pin reference model state
490 | for value in self.reference_model_state_dict.values():
491 | if not isinstance(value, torch.distributed._tensor.DTensor): # type: ignore
492 | value.pin_memory()
493 |
494 | self._optimizer = self._setup_optimizer(
495 | cfg_optimizer=cfg.optimizer,
496 | optimizer_in_bwd=self._optimizer_in_bwd,
497 | opt_state_dict=(
498 | checkpoint_dict[training.OPT_KEY]
499 | if self._resume_from_checkpoint
500 | else None
501 | ),
502 | )
503 |
504 | # initialize loss
505 | self._loss_fn = instantiate_component(cfg.loss)
506 |
507 | if self._compile:
508 | if self._is_rank_zero:
509 | log.info("Compiling loss with torch.compile...")
510 | self._loss_fn._forward_chunk = torch.compile(
511 | self._loss_fn._forward_chunk,
512 | backend=os.environ.get("TORCH_COMPILE_BACKEND", "inductor"),
513 | )
514 |
515 | if cfg.get("num_output_chunks", None) is not None:
516 | # set num_output_chunks for model
517 | self._model.set_num_output_chunks(cfg.num_output_chunks)
518 |
519 | if self._is_rank_zero:
520 | log.info("Loss is initialized.")
521 |
522 | # sampler and dataloader depend on the tokenizer and loss_fn and should be
523 | # setup after both of these are initialized
524 | self._sampler, self._dataloader = self._setup_data(
525 | cfg_dataset=cfg.dataset,
526 | shuffle=cfg.shuffle,
527 | batch_size=cfg.batch_size,
528 | )
529 |
530 | # Finally update the recipe state which can only be correctly set after all of the
531 | # other components have been initialized and updated.
532 | #
533 | # Number of training steps in each epoch depends on the number of batches produced
534 | # by the dataloader, the max_steps_per_epoch param set by the user and the
535 | # gradient_accumulation_steps param. This value is used for logging and tracking
536 | # training state. The computation should happen after the dataloader has been setup
537 | self._steps_per_epoch = (
538 | len(self._dataloader) // self._gradient_accumulation_steps
539 | )
540 | if (
541 | self.max_steps_per_epoch is not None
542 | and self.max_steps_per_epoch < self._steps_per_epoch
543 | ):
544 | self._steps_per_epoch = self.max_steps_per_epoch
545 | self.global_step = self.epochs_run * self._steps_per_epoch
546 |
547 | # Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method)
548 | # if cfg is missing profiler key or if `cfg.profiler.enabled = False`
549 | self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None))
550 |
551 | def _setup_profiler(
552 | self, cfg_profiler: Optional[DictConfig] = None
553 | ) -> Union[torch.profiler.profile, DummyProfiler]:
554 | """
555 | Parses the `profiler` section of top-level `cfg` and sets up profiler
556 |
557 | Args:
558 | cfg_profiler (Optional[DictConfig]): ``profiler`` section of the top-level ``cfg`` (the main config passed to
559 | `recipe.main`). Default None.
560 |
561 | Returns:
562 | profiler: Union[torch.profiler.profile, DummyProfiler] - DummyProfiler is a nullcontext with no-op methods
563 |             for `start`, `stop`, and `step` that can be used in place of `torch.profiler.profile` if the profiler is not enabled,
564 |             so that the instrumented training loop does not need to be changed when profiling is disabled.
565 |
566 | The profiler config can be provided in configs under the `profiler` key with the following layout:
567 |
568 | .. code-block:: yaml
569 | profiler:
570 | enabled: bool
571 |
572 | #Output directory of trace artifacts
573 | output_dir: str
574 |
575 | #`torch.profiler.ProfilerActivity` types to trace
576 | cpu: bool
577 | cuda: bool
578 |
579 | #Trace options
580 | profile_memory: bool
581 | with_stack: bool
582 | record_shapes: bool
583 | with_flops: bool
584 |
585 | # `torch.profiler.schedule` options:
586 | # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
587 | wait_steps: int
588 | warmup_steps: int
589 | active_steps: int
590 | num_cycles: int
591 | """
592 | # Missing profiler section in config, assume disabled
593 | if cfg_profiler is None:
594 | cfg_profiler = DictConfig({"enabled": False})
595 |
596 | # Check that component is included and set correctly
597 | if cfg_profiler.get("_component_", None) is None:
598 | cfg_profiler["_component_"] = "torchtune.training.setup_torch_profiler"
599 | else:
600 | assert (
601 | cfg_profiler.get("_component_")
602 | == "torchtune.training.setup_torch_profiler"
603 | ), "Only torch profiler supported currently: component must be `torchtune.training.setup_torch_profiler`"
604 |
605 | profiler, profiler_cfg = config.instantiate(cfg_profiler)
606 |
607 | if self._is_rank_zero:
608 | log.info(f" Profiler config after instantiation: {profiler_cfg}")
609 |
610 | self.profiler_profile_memory = profiler_cfg.get("profile_memory", False)
611 | if profiler_cfg["enabled"]:
612 | self.profiler_wait_steps = profiler_cfg["wait_steps"]
613 | self.profiler_warmup_steps = profiler_cfg["warmup_steps"]
614 | self.profiler_active_steps = profiler_cfg["active_steps"]
615 |
616 | return profiler
617 |
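    # Illustrative `profiler` config section (example values, not recipe defaults)
    # matching the layout documented in `_setup_profiler` above:
    #
    #   profiler:
    #     _component_: torchtune.training.setup_torch_profiler
    #     enabled: True
    #     output_dir: ./logs/profiler
    #     cpu: True
    #     cuda: True
    #     profile_memory: False
    #     with_stack: False
    #     record_shapes: True
    #     with_flops: False
    #     wait_steps: 5
    #     warmup_steps: 3
    #     active_steps: 2
    #     num_cycles: 1
    #
    # With `enabled: False`, or with no `profiler` section at all, a no-op
    # DummyProfiler is returned and the training loop runs unchanged.
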
618 | def _setup_model(
619 | self,
620 | cfg_model: ComponentConfig[TransformerDecoder],
621 | enable_activation_checkpointing: bool,
622 | enable_activation_offloading: bool,
623 | fsdp_cpu_offload: bool,
624 | reshard_after_forward: bool,
625 | model_state_dict: Dict[str, Any],
626 | reference_model_state_dict: Optional[Dict[str, Any]] = None,
627 | custom_sharded_layers: Optional[List[str]] = None,
628 | ac_mode: Optional[str] = None,
629 | ac_option: Optional[int] = None,
630 | ) -> TransformerDecoder:
631 | """
632 | Model initialization has some important considerations:
633 | a. To minimize GPU peak memory, we initialize the model on meta device with
634 | the right dtype
635 |            b. All ranks call ``load_state_dict`` without spiking CPU RAM since
636 | full state dicts are loaded with ``torch.load(mmap=True)``
637 | """
638 |
639 | if self._is_rank_zero:
640 | log.info(
641 | "FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ..."
642 | )
643 | init_start = time.perf_counter()
644 | else:
645 | init_start = 0.0
646 |
647 | with (
648 | training.set_default_dtype(self._dtype),
649 | torch.device("meta") if training.is_distributed() else self._device,
650 | ):
651 | model = instantiate_component(cfg_model)
652 |
653 | if self._compile:
654 | training.compile_model(model, verbose=self._is_rank_zero)
655 |
656 | # We currently have two versions of activation checkpointing in this recipe
657 | # for testing and BC purposes. ``enable_activation_checkpointing`` controls
658 | # the older version of AC and this behavior is unchanged
659 | # ac_mode and ac_option together control selective AC. This is only enabled
660 | # when these are set AND ``enable_activation_checkpointing`` is set to False
661 | # We'll clean this up as soon as testing of AC is complete
662 | if (not enable_activation_checkpointing) and (ac_mode is not None):
663 | apply_selective_activation_checkpointing(
664 | model,
665 | ac_mode,
666 | ac_option,
667 | )
668 |
669 | # original activation checkpointing (full) - flip the condition above
670 | if enable_activation_checkpointing and ac_mode is None:
671 | training.set_activation_checkpointing(
672 | model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
673 | )
674 |
675 | if training.is_distributed():
676 | # For FSDP sharding
677 | fsdp_shard_conditions = [
678 | partial(
679 | training.get_shard_conditions,
680 | names_to_match=custom_sharded_layers,
681 | )
682 | ]
683 | training.shard_model(
684 | model=model,
685 | shard_conditions=fsdp_shard_conditions,
686 | cpu_offload=fsdp_cpu_offload,
687 | reshard_after_forward=reshard_after_forward,
688 | )
689 |
690 | with training.set_default_dtype(self._dtype), self._device:
691 | for m in model.modules():
692 | # RoPE is not covered in state dict
693 | if hasattr(m, "rope_init"):
694 | m.rope_init() # type: ignore
695 |
696 | # This method will convert the full model state dict into a sharded state
697 | # dict and load into the model
698 | training.load_from_full_model_state_dict(
699 | model,
700 | model_state_dict,
701 | self._device,
702 | self._is_rank_zero,
703 | strict=True,
704 | cpu_offload=fsdp_cpu_offload,
705 | )
706 |
707 | if reference_model_state_dict:
708 | # Temporarily patch model.load_state_dict to capture the sharded parameter tensors
709 | # for this rank when loading the reference model. This allows us to maintain a
710 | # reference copy of the sharded parameters that matches the FSDP sharding pattern,
711 | # which is needed for weight swapping during training.
712 | load_state_dict = model.load_state_dict
713 |
714 | def patch(
715 | state_dict: Mapping[str, Any],
716 | strict: bool = True,
717 | assign: bool = False,
718 | ) -> Any:
719 | reference_model_state_dict.clear()
720 | reference_model_state_dict.update(state_dict)
721 |
722 | model.load_state_dict = patch
723 |
724 | training.load_from_full_model_state_dict(
725 | model,
726 | reference_model_state_dict,
727 | self._device,
728 | self._is_rank_zero,
729 | strict=True,
730 | cpu_offload=fsdp_cpu_offload,
731 | )
732 |
733 | model.load_state_dict = load_state_dict
734 | else:
735 | model.load_state_dict(model_state_dict)
736 |
737 | # Validate model was loaded in with the expected dtype.
738 | training.validate_expected_param_dtype(
739 | model.named_parameters(), dtype=self._dtype
740 | )
741 |
742 | # activation offloading
743 | self.activations_handling_ctx = training.get_act_offloading_ctx_manager(
744 | model, enable_activation_offloading
745 | )
746 |
747 | # Ensure no params and buffers are on meta device
748 | training.validate_no_params_on_meta_device(model)
749 |
750 | if self._is_rank_zero:
751 | log.info(
752 | f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs"
753 | )
754 | memory_stats = training.get_memory_stats(device=self._device)
755 | training.log_memory_stats(memory_stats)
756 |
757 | if training.is_distributed():
758 | # synchronize before training begins
759 | torch.distributed.barrier()
760 |
761 | return model
762 |
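    # The reference-model branch in `_setup_model` above relies on temporarily
    # monkey-patching `load_state_dict` so that the sharded tensors torchtune
    # would have loaded are captured into a dict instead of overwriting the live
    # model. A standalone toy sketch of that pattern (plain nn.Linear, no FSDP):
    #
    #   toy = torch.nn.Linear(4, 4)
    #   captured: dict = {}
    #   original_load = toy.load_state_dict
    #   toy.load_state_dict = lambda sd, strict=True, assign=False: (
    #       captured.clear(),
    #       captured.update(sd),
    #   )
    #   toy.load_state_dict(torch.nn.Linear(4, 4).state_dict())  # fills `captured`
    #   toy.load_state_dict = original_load  # restore the real method
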
763 | def _setup_optimizer(
764 | self,
765 | cfg_optimizer: ComponentConfig[Optimizer],
766 | optimizer_in_bwd: bool = False,
767 | opt_state_dict: Optional[Dict[str, Any]] = None,
768 | ) -> Optional[Optimizer]:
769 | if optimizer_in_bwd:
770 | # Maintain a dict of optims for every parameter.
771 | optim_dict = {
772 | param: instantiate_component(cfg_optimizer, params=[param])
773 | for param in self._model.parameters()
774 | }
775 |
776 | # Register optimizer step hooks on the model to run optimizer in backward.
777 | training.register_optim_in_bwd_hooks(
778 | model=self._model, optim_dict=optim_dict
779 | )
780 | # Create a wrapper for checkpoint save/load of optimizer states when running in backward.
781 | self._optim_ckpt_wrapper = training.create_optim_in_bwd_wrapper(
782 | model=self._model, optim_dict=optim_dict
783 | )
784 | # Load optimizer states for each param. If optimizer states are being restored in an optimizer in
785 | # backward run, these need to have been saved with the same setting. Cannot restore from runs that
786 | # did not use optimizer in backward.
787 | if opt_state_dict is not None:
788 | for param in opt_state_dict.keys():
789 | try:
790 | training.load_from_full_optimizer_state_dict(
791 | self._optim_ckpt_wrapper.state_dict()[param],
792 | opt_state_dict[param],
793 | self._device,
794 | )
795 | except BaseException as e:
796 | raise RuntimeError(
797 |                             "Failed loading in-backward optimizer checkpoints. "
798 |                             "Please make sure the run being restored from used the in-backward optimizer."
799 | ) from e
800 | if self._is_rank_zero:
801 | log.info("In-backward optimizers are set up.")
802 | return None
803 | else:
804 | optimizer = instantiate_component(
805 | cfg_optimizer, params=self._model.parameters()
806 | )
807 | if opt_state_dict:
808 | training.load_from_full_optimizer_state_dict(
809 | optimizer,
810 | opt_state_dict,
811 | self._device,
812 | )
813 |
814 | if self._is_rank_zero:
815 | log.info("Optimizer is initialized.")
816 | return optimizer
817 |
818 | def _setup_data(
819 | self,
820 | cfg_dataset: ComponentConfig[Dataset[PackedTensors]],
821 | shuffle: bool,
822 | batch_size: int,
823 | ) -> Tuple[DistributedSampler, TypedDataLoader[PackedTensors]]:
824 | """
825 |         All data-related setup happens here. Currently this recipe only supports
826 |         ``DistributedSampler`` with map-style datasets that fit into memory. Other samplers,
827 |         iterable datasets, and streaming datasets are not supported.
828 | """
829 | world_size, rank = training.get_world_size_and_rank()
830 |
831 | ds = instantiate_component(cfg_dataset)
832 |
833 | sampler = DistributedSampler(
834 | ds, num_replicas=world_size, rank=rank, shuffle=shuffle, seed=self.seed or 0
835 | )
836 | dataloader = TypedDataLoader(
837 | dataset=ds,
838 | batch_size=batch_size,
839 | sampler=sampler,
840 | # dropping last avoids shape issues with compile + flex attention
841 | drop_last=True,
842 | )
843 |
844 | if self._is_rank_zero:
845 | log.info("Dataset and Sampler are initialized.")
846 |
847 | return sampler, dataloader
848 |
849 | def save_checkpoint(
850 | self,
851 | epoch: int,
852 | ) -> None:
853 | """
854 | Checkpoint the state of the recipe. The constructed checkpoint state dict
855 | contains the following information:
856 | - Model weights with key training.MODEL_KEY
857 | - Relevant recipe state if training is not complete
858 |
859 | Checkpointer will save the model weights and recipe state in
860 | different checkpoint files. To correctly resume training from an intermediate checkpoint,
861 | the model weights and recipe state must be provided.
862 | """
863 | # final dict passed onto the checkpointer
864 | checkpoint_dict = {}
865 |
866 | intermediate_checkpoint = epoch + 1 < self.total_epochs
867 |
868 | if intermediate_checkpoint and not self._save_intermediate_checkpoints:
869 | return
870 |
871 | if self._is_rank_zero:
872 | log.info(
873 | "Saving checkpoint. This may take some time. Retrieving full model state dict..."
874 | )
875 | start = time.perf_counter()
876 | else:
877 | start = 0.0
878 |
879 | # To prevent GPU memory from spiking during checkpoint save,
880 | # we consolidate the full model and optim state dicts on CPU for rank 0
881 | model_state_dict = (
882 | training.gather_cpu_state_dict(
883 | self._model.state_dict(),
884 | self._is_rank_zero,
885 | device=self._device,
886 | )
887 | if training.is_distributed()
888 | else self._model.state_dict()
889 | )
890 |
891 | if self._is_rank_zero:
892 | log.info(
893 | f"Getting full model state dict took {time.perf_counter() - start:.2f} secs"
894 | )
895 |
896 | if intermediate_checkpoint:
897 | start = time.perf_counter()
898 | if self._is_rank_zero:
899 | log.info("Getting optimizer state dict...")
900 | if self._optimizer:
901 | opt_state_dict = (
902 | training.get_full_optimizer_state_dict(
903 | self._optimizer,
904 | self._is_rank_zero,
905 | device=self._device,
906 | )
907 | if training.is_distributed()
908 | else self._optimizer.state_dict()
909 | )
910 | else:
911 | opt_state_dict = {}
912 | for param, opt in self._optim_ckpt_wrapper.optim_map.items():
913 | opt_state_dict[param] = (
914 | training.get_full_optimizer_state_dict(
915 | opt, self._is_rank_zero, device=self._device
916 | )
917 | if training.is_distributed()
918 | else opt.state_dict()
919 | )
920 | if self._is_rank_zero:
921 | log.info(
922 | f"Getting optimizer state dict took {time.perf_counter() - start:.2f} secs"
923 | )
924 | else:
925 | opt_state_dict = None
926 |
927 | # Now that we have the model and opt state dict, create the actual checkpoint dict
928 | # to be sent to the checkpointer and ultimately written to file
929 |
930 | if self._is_rank_zero:
931 | start = time.perf_counter()
932 | checkpoint_dict.update({training.MODEL_KEY: model_state_dict})
933 |
934 | # if training is in-progress, checkpoint the optimizer state and recipe state
935 | # as well.
936 | if intermediate_checkpoint:
937 | checkpoint_dict.update(
938 | {
939 | training.OPT_KEY: opt_state_dict,
940 | training.SEED_KEY: self.seed,
941 | training.EPOCHS_KEY: self.epochs_run,
942 | training.TOTAL_EPOCHS_KEY: self.total_epochs,
943 | training.MAX_STEPS_KEY: self.max_steps_per_epoch,
944 | }
945 | )
946 |
947 | self._checkpointer.save_checkpoint(
948 | checkpoint_dict,
949 | epoch=epoch,
950 | intermediate_checkpoint=intermediate_checkpoint,
951 | )
952 | log.info(f"Saving checkpoint took {time.perf_counter() - start:.2f} secs")
953 |
954 | if training.is_distributed():
955 | torch.distributed.barrier()
956 |
957 | def train(self) -> None:
958 | """
959 | The core training loop.
960 | """
961 | # clean up before training begins
962 | training.cleanup_before_training()
963 | torch.autograd.set_detect_anomaly(True)
964 |
965 | world_size, rank = training.get_world_size_and_rank()
966 |
967 | # zero out the gradients before starting training
968 | if self._optimizer:
969 | self._optimizer.zero_grad()
970 | else:
971 | for opt in self._optim_ckpt_wrapper.optim_map.values():
972 | opt.zero_grad()
973 |
974 | # Initialize tokens count and running loss (for grad accumulation)
975 | t0 = time.perf_counter()
976 | running_result = GRPOResult().to(self._device)
977 |
978 | self._profiler.start()
979 | # self.epochs_run should be non-zero when we're resuming from a checkpoint
980 | for curr_epoch in range(self.epochs_run, self.total_epochs):
981 | # Update the sampler to ensure data is correctly shuffled across epochs
982 | # in case shuffle is True
983 | self._sampler.set_epoch(curr_epoch)
984 |
985 | pbar = tqdm(total=self._steps_per_epoch, disable=not (rank == 0))
986 | grad_norm: torch.Tensor | None = None
987 | for idx, batch in enumerate(self._dataloader):
988 | if (
989 | self.max_steps_per_epoch is not None
990 | and (idx // self._gradient_accumulation_steps)
991 | == self.max_steps_per_epoch
992 | ):
993 | break
994 |
995 | # Start tracking CUDA memory for active steps for just the first epoch
996 | if (
997 | self._is_rank_zero
998 | and curr_epoch == 0
999 | and self.profiler_profile_memory
1000 | and idx == self.profiler_wait_steps + self.profiler_warmup_steps
1001 | ):
1002 | torch.cuda.memory._record_memory_history()
1003 |
1004 | utils.batch_to_device(batch, self._device) # type: ignore - `batch_to_device` expects a `dict`, not a `TypedDict`, but this should be fine
1005 |
1006 | # Assume the first token in the batch is the bos token
1007 | bos_id = int(batch["tokens"].view(-1)[0].item())
1008 |
1009 | # Create grouped causal mask
1010 | batch_size, seq_len = batch["tokens"].size()
1011 | causal_mask = (
1012 | torch.tril(
1013 | torch.ones(
1014 | seq_len, seq_len, dtype=torch.bool, device=self._device
1015 | )
1016 | )
1017 | .unsqueeze(0)
1018 | .expand(batch_size, seq_len, seq_len)
1019 | )
1020 | group_mask = batch["group_ids"].unsqueeze(2) == batch[
1021 | "group_ids"
1022 | ].unsqueeze(1)
1023 | parent_mask = batch["parent_ids"].unsqueeze(2) == batch[
1024 | "group_ids"
1025 | ].unsqueeze(1)
1026 | mask = causal_mask & (group_mask | parent_mask)
1027 |
1028 | if self.reference_model_state_dict:
1029 | # Save current weights and load reference weights
1030 | model_state_dict = self._swap_state(self.reference_model_state_dict)
1031 |
1032 | # Run reference model forward pass without affecting autograd
1033 | with torch.no_grad(), self.activations_handling_ctx:
1034 | hidden_states, logits = self._model(
1035 | tokens=batch["tokens"],
1036 | mask=mask,
1037 | input_pos=batch["input_pos"],
1038 | )
1039 | del hidden_states
1040 | if isinstance(logits, list):
1041 | reference_logprobs = torch.cat(
1042 | [
1043 | torch.distributions.Categorical(
1044 | logits=logits_chunk
1045 | ).log_prob(
1046 | shift_tensor(tokens, ignore_label=bos_id)
1047 | )
1048 | for logits_chunk, tokens in zip(
1049 | logits,
1050 | batch["tokens"].chunk(len(logits), dim=1),
1051 | )
1052 | ],
1053 | dim=-1,
1054 | )
1055 | else:
1056 | reference_logprobs = cast(
1057 | torch.Tensor,
1058 | torch.distributions.Categorical(logits=logits).log_prob(
1059 | shift_tensor(batch["tokens"], ignore_label=bos_id)
1060 | ),
1061 | )
1062 | del logits
1063 |
1064 | # Restore original weights
1065 | self._swap_state(model_state_dict)
1066 | else:
1067 | reference_logprobs = None
1068 |
1069 | with self.activations_handling_ctx:
1070 | hidden_states, logits = self._model(
1071 | tokens=batch["tokens"],
1072 | mask=mask,
1073 | input_pos=batch["input_pos"],
1074 | )
1075 | del mask, batch["input_pos"], hidden_states # type: ignore
1076 |
1077 | # Compute loss
1078 | current_result = self._loss_fn.forward(
1079 | logits=logits,
1080 | tokens=batch["tokens"],
1081 | advantages=batch["advantages"],
1082 | logprobs=batch["logprobs"],
1083 | reference_logprobs=reference_logprobs,
1084 | mask=batch["assistant_mask"],
1085 | weights=batch["weights"],
1086 | bos_id=bos_id,
1087 | )
1088 | del logits, batch
1089 |
1090 | running_result += current_result
1091 |
1092 | # For optimizer in backward, we need to normalize before calling backward
1093 | # This case and gradient accumulation are mutually exclusive
1094 | if self._optimizer_in_bwd:
1095 | if training.is_distributed():
1096 | for tensor in running_result.tensors():
1097 | torch.distributed.all_reduce(tensor)
1098 | current_loss = current_result.total_loss / current_result.num_tokens
1099 | else:
1100 | current_loss = current_result.total_loss
1101 |
1102 | current_loss.backward()
1103 | del current_loss
1104 |
1105 | # Step with optimizer
1106 | if (idx + 1) % self._gradient_accumulation_steps == 0:
1107 | if self._optimizer:
1108 | if training.is_distributed():
1109 | for tensor in running_result.tensors():
1110 | torch.distributed.all_reduce(tensor)
1111 | # Manually scale the gradients from unnormalized loss by total # of tokens
1112 | training.scale_grads(self._model, 1 / running_result.num_tokens)
1113 |
1114 | # Calculate gradient magnitude (L2 norm of all gradients)
1115 | grad_magnitude = None
1116 | if self._log_grad_magnitude:
1117 | grad_magnitude = torch.norm(
1118 | torch.stack(
1119 | [
1120 | torch.norm(p.grad.detach())
1121 | for p in self._model.parameters()
1122 | if p.grad is not None
1123 | ]
1124 | )
1125 | )
1126 |
1127 | if self._clip_grad_norm is not None:
1128 | grad_norm = torch.nn.utils.clip_grad_norm_(
1129 | self._model.parameters(),
1130 | max_norm=float(self._clip_grad_norm),
1131 | )
1132 | self._optimizer.step()
1133 | self._optimizer.zero_grad(set_to_none=True)
1134 |
1135 | # Update the number of steps when the weights are updated
1136 | self.global_step += 1
1137 |
1138 | per_token_result = running_result.per_token()
1139 | loss_to_log = per_token_result.total_loss.item()
1140 | policy_loss_to_log = per_token_result.policy_loss.item()
1141 | entropy_to_log = per_token_result.entropy.item()
1142 | kl_div_to_log = per_token_result.kl_div.item()
1143 | pbar.update(1)
1144 | pbar.set_description(
1145 | f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log:.4f}"
1146 | )
1147 | postfix = {
1148 | "loss": loss_to_log,
1149 | "policy": policy_loss_to_log,
1150 | "entropy": entropy_to_log,
1151 | "kl_div": kl_div_to_log,
1152 | }
1153 | if self._log_grad_magnitude and grad_magnitude is not None:
1154 | postfix["grad_magnitude"] = grad_magnitude.item()
1155 | pbar.set_postfix(postfix)
1156 |
1157 | # Log per-step metrics
1158 | if (
1159 | self.global_step % self._log_every_n_steps == 0
1160 | and self._is_rank_zero
1161 | ):
1162 | time_per_step = time.perf_counter() - t0
1163 | log_dict = {
1164 | "loss": loss_to_log,
1165 | "policy": policy_loss_to_log,
1166 | "entropy": entropy_to_log,
1167 | "kl_div": kl_div_to_log,
1168 | # "lr": get_lr(self._optimizer or self._optim_ckpt_wrapper),
1169 | "tokens_per_second_per_gpu": running_result.num_tokens
1170 | / (time_per_step * world_size),
1171 | }
1172 | if self._log_grad_magnitude and grad_magnitude is not None:
1173 | log_dict["grad_magnitude"] = grad_magnitude.item()
1174 | if self._log_peak_memory_stats:
1175 | log_dict.update(
1176 | training.get_memory_stats(device=self._device)
1177 | )
1178 | if self._clip_grad_norm is not None:
1179 | log_dict.update({"grad_norm": grad_norm})
1180 | self._metric_logger.log_dict(
1181 | log_dict,
1182 | step=self.global_step,
1183 | )
1184 |
1185 | # Reset running stats for the next step
1186 | del running_result
1187 | running_result = GRPOResult().to(self._device)
1188 | t0 = time.perf_counter()
1189 |
1190 | # Stop tracking CUDA memory now that active steps are complete
1191 | if (
1192 | self._is_rank_zero
1193 | and curr_epoch == 0
1194 | and self.profiler_profile_memory
1195 | and idx
1196 | == self.profiler_wait_steps
1197 | + self.profiler_warmup_steps
1198 | + self.profiler_active_steps
1199 | ):
1200 | torch.cuda.memory._record_memory_history(
1201 | # Pylance infers the type of `enabled` as `str` though the function accepts `Literal[None, "state", "all"]`
1202 | enabled=None # type: ignore
1203 | )
1204 |
1205 | # Step profiler
1206 | # Note that this is called within gradient accumulation block, hence
1207 | # will include multiple forward / backward passes if gradient accumulation > 1
1208 | self._profiler.step()
1209 |
1210 | self.epochs_run += 1
1211 | self.save_checkpoint(epoch=curr_epoch)
1212 |
1213 | self._profiler.stop()
1214 |
1215 | def cleanup(self) -> None:
1216 | if self._is_rank_zero:
1217 | self._metric_logger.close()
1218 | if training.is_distributed():
1219 | torch.distributed.destroy_process_group()
1220 | training.cleanup_before_training()
1221 |
1222 | def _swap_state(
1223 | self, state_dict: Dict[str, Any], assign: bool = False
1224 | ) -> Dict[str, Any]:
1225 | """
1226 | Swaps the current model state with the provided state dict.
1227 | Manages GPU memory by moving states to CPU/device in the right order.
1228 |
1229 | Args:
1230 | state_dict: Dictionary of state to load into model
1231 |
1232 | Returns:
1233 | Original model state dict (moved to CPU)
1234 | """
1235 | # Save current model state and move to CPU
1236 | current_state = {
1237 | k: v.to("cpu", non_blocking=True)
1238 | for k, v in self._model.state_dict().items()
1239 | }
1240 |
1241 | # Move input state to device and load
1242 | device_state = {
1243 | k: v.to(self._device, non_blocking=True) for k, v in state_dict.items()
1244 | }
1245 | self._model.load_state_dict(device_state, assign=assign)
1246 |
1247 | return current_state
1248 |
1249 |
1250 | def recipe_main(cfg: TuneRecipeConfig) -> None:
1251 | """
1252 | Entry point for the recipe.
1253 |
1254 | Configurable parameters are read in the following order:
1255 | - Parameters specified in config (see available configs through ``tune ls``)
1256 | - Overwritten by arguments from the command-line
1257 | """
1258 | if not training.is_distributed():
1259 | log.debug(
1260 | "Training is not distributed. If you want to train on multiple GPUs and are using the tune CLI, specify --nnodes 1 and --nproc_per_node [num_gpus]"
1261 | )
1262 | elif not torch.distributed.is_initialized():
1263 | torch.distributed.init_process_group(backend="cuda:nccl,cpu:gloo")
1264 |
1265 | if cfg.get("fsdp_cpu_offload", False):
1266 | # Utilize all available CPU cores for intra-op parallelism. This provides ~2x
1267 | # speed up when benchmarking fused AdamW on CPU
1268 | training.set_torch_num_threads()
1269 |
1270 | config.log_config(
1271 | recipe_name="FullFinetuneRecipe",
1272 | cfg=cfg.dict_config() if isinstance(cfg, TuneRecipeConfig) else cfg,
1273 | )
1274 |
1275 | recipe = TuneRecipe(cfg=cfg)
1276 | recipe.setup(cfg=cfg)
1277 | recipe.train()
1278 | recipe.cleanup()
1279 |
1280 |
1281 | if __name__ == "__main__":
1282 | sys.exit(config.parse(recipe_main)()) # type: ignore
1283 |
--------------------------------------------------------------------------------
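The grouped causal mask assembled in `train()` above (`causal_mask & (group_mask | parent_mask)`)
is what lets packed completions attend only to tokens in their own group or in a shared parent
prefix. A minimal standalone sketch of the same construction, with toy `group_ids`/`parent_ids`
values that are purely illustrative:

import torch

# One packed sequence of 6 tokens: tokens 0-2 are a shared prompt (group 1),
# tokens 3-5 are a sampled completion (group 2) whose parent is group 1.
group_ids = torch.tensor([[1, 1, 1, 2, 2, 2]])
parent_ids = torch.tensor([[1, 1, 1, 1, 1, 1]])

batch_size, seq_len = group_ids.shape
causal_mask = (
    torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    .unsqueeze(0)
    .expand(batch_size, seq_len, seq_len)
)
group_mask = group_ids.unsqueeze(2) == group_ids.unsqueeze(1)
parent_mask = parent_ids.unsqueeze(2) == group_ids.unsqueeze(1)
mask = causal_mask & (group_mask | parent_mask)
print(mask[0].int())  # completion tokens see the prompt plus their own (causal) prefix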
/lib/stream.py:
--------------------------------------------------------------------------------
1 | # https://gist.github.com/bradhilton/ec2450881abb0d0ef6a1bea7f8ca824f
2 | from openai import AsyncStream
3 | from openai.types.chat.chat_completion import ChatCompletion, Choice, ChoiceLogprobs
4 | from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
5 | from openai.types.chat.chat_completion_message import (
6 | ChatCompletionMessage,
7 | FunctionCall,
8 | )
9 | from openai.types.chat.chat_completion_message_tool_call import (
10 | ChatCompletionMessageToolCall,
11 | Function,
12 | )
13 | from typing import Any, Callable
14 |
15 |
16 | async def consume_chat_completion_stream(
17 | stream: AsyncStream[ChatCompletionChunk],
18 | on_chunk: Callable[[ChatCompletionChunk, ChatCompletion], Any] | None = None,
19 | ) -> ChatCompletion:
20 | """Consume a chat completion stream and build a complete ChatCompletion object.
21 |
22 | This function processes a stream of ChatCompletionChunks, constructing a complete
23 | ChatCompletion object as if it was returned from a non-streaming API call.
24 | Works with any OpenAI-compatible API implementation.
25 |
26 | Args:
27 | stream: An AsyncStream of ChatCompletionChunk objects.
28 | on_chunk: Optional callback that receives each chunk and the current state of the
29 | ChatCompletion. If the callback raises StopIteration, the stream will close early.
30 |
31 | Returns:
32 | A complete ChatCompletion object built from the streamed chunks.
33 |
34 | Raises:
35 | AssertionError: If no chat completion object could be created.
36 | """
37 | chat_completion: ChatCompletion | None = None
38 | async for chunk in stream:
39 | if chat_completion is None:
40 | chat_completion = ChatCompletion(
41 | id=chunk.id,
42 | choices=[
43 | Choice(
44 | finish_reason="stop",
45 | index=choice.index,
46 | logprobs=(ChoiceLogprobs() if choice.logprobs else None),
47 | message=ChatCompletionMessage(role="assistant"),
48 | )
49 | for choice in chunk.choices
50 | ],
51 | created=chunk.created,
52 | model=chunk.model,
53 | object="chat.completion",
54 | )
55 | for choice, chunk_choice in zip(chat_completion.choices, chunk.choices):
56 | choice.finish_reason = chunk_choice.finish_reason or "stop"
57 | if chunk_choice.logprobs:
58 | if choice.logprobs is None:
59 | choice.logprobs = ChoiceLogprobs()
60 | if chunk_choice.logprobs.content:
61 | if choice.logprobs.content is None:
62 | choice.logprobs.content = []
63 | choice.logprobs.content.extend(chunk_choice.logprobs.content)
64 | if chunk_choice.logprobs.refusal:
65 | if choice.logprobs.refusal is None:
66 | choice.logprobs.refusal = []
67 | choice.logprobs.refusal.extend(chunk_choice.logprobs.refusal)
68 | if chunk_choice.delta.content:
69 | if choice.message.content is None:
70 | choice.message.content = ""
71 | choice.message.content += chunk_choice.delta.content
72 | if chunk_choice.delta.refusal:
73 | if choice.message.refusal is None:
74 | choice.message.refusal = ""
75 | choice.message.refusal += chunk_choice.delta.refusal
76 | if chunk_choice.delta.function_call:
77 | if choice.message.function_call is None:
78 | choice.message.function_call = FunctionCall(arguments="", name="")
79 | choice.message.function_call.name += (
80 | chunk_choice.delta.function_call.name or ""
81 | )
82 | choice.message.function_call.arguments += (
83 | chunk_choice.delta.function_call.arguments or ""
84 | )
85 | if chunk_choice.delta.tool_calls:
86 | if choice.message.tool_calls is None:
87 | choice.message.tool_calls = []
88 | for tool_call in chunk_choice.delta.tool_calls:
89 |                     if tool_call.index not in range(len(choice.message.tool_calls)):
90 | choice.message.tool_calls.append(
91 | ChatCompletionMessageToolCall(
92 | id="",
93 | function=Function(arguments="", name=""),
94 | type="function",
95 | )
96 | )
97 | choice.message.tool_calls[tool_call.index].id += tool_call.id or ""
98 | choice.message.tool_calls[tool_call.index].function.name += (
99 | tool_call.function.name or "" if tool_call.function else ""
100 | )
101 | choice.message.tool_calls[tool_call.index].function.arguments += (
102 | tool_call.function.arguments or "" if tool_call.function else ""
103 | )
104 | if getattr(chunk_choice.delta, "reasoning", None):
105 | if not hasattr(choice.message, "reasoning"):
106 | setattr(choice.message, "reasoning", "")
107 | setattr(
108 | choice.message,
109 | "reasoning",
110 | getattr(choice.message, "reasoning")
111 | + getattr(chunk_choice.delta, "reasoning"),
112 | )
113 | chat_completion.service_tier = chunk.service_tier
114 | chat_completion.system_fingerprint = chunk.system_fingerprint
115 | chat_completion.usage = chunk.usage
116 | if on_chunk:
117 | try:
118 | on_chunk(chunk, chat_completion)
119 | except StopIteration:
120 | await stream.close()
121 | break
122 | assert chat_completion is not None
123 | return chat_completion
124 |
--------------------------------------------------------------------------------
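A small usage sketch for `consume_chat_completion_stream` above; the base URL, API key, and
model name are placeholders, not values taken from this repository:

import asyncio

from openai import AsyncOpenAI

from lib.stream import consume_chat_completion_stream


async def main() -> None:
    client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    stream = await client.chat.completions.create(
        model="placeholder-model",
        messages=[{"role": "user", "content": "Say hello."}],
        stream=True,
    )
    # Rebuild a full ChatCompletion from the streamed chunks; the optional
    # callback sees every chunk plus the partially assembled completion.
    completion = await consume_chat_completion_stream(
        stream, on_chunk=lambda chunk, partial: None
    )
    print(completion.choices[0].message.content)


asyncio.run(main())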
/lib/tasks.py:
--------------------------------------------------------------------------------
1 | from aioitertools.helpers import maybe_await
2 | import asyncio
3 | from dataclasses import dataclass, field
4 | import numpy as np
5 | from openai import AsyncOpenAI
6 | from openai.types.chat.chat_completion import ChatCompletion, Choice
7 | from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
8 | from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
9 | from openai.types.completion_usage import CompletionUsage
10 | import random
11 | from typing import Awaitable, Callable, TypeVar
12 |
13 | from .chat_completions import get_chat_completion
14 | from .types import ChatCompletionParams
15 | from .tqdm import tqdm
16 |
17 | # A grader function returns a floating-point reward,
18 | # and optionally a dictionary of metrics as the second return value
19 | Grade = float | tuple[float, dict[str, float]]
20 | Grader = Callable[[Choice], Grade | Awaitable[Grade]]
21 |
22 |
23 | @dataclass
24 | class Task:
25 | """
26 | A minimal task definition.
27 |
28 | Args:
29 | messages (list[ChatCompletionMessageParam]): OpenAI API compatible chat messages for prompting the model
30 | grader (Grader): A grader function to score the model's responses
31 | """
32 |
33 | messages: list[ChatCompletionMessageParam]
34 | grader: Grader
35 |
36 |
37 | @dataclass
38 | class TaskResult:
39 | """
40 | A single task result.
41 |
42 | Args:
43 | task (Task): The task that was graded
44 | chat_completions (list[ChatCompletion]): The chat completions generated for the task
45 | rewards (dict[tuple[str, int], float]): Rewards for each chat completion and choice index
46 | metrics (dict[tuple[str, int], dict[str, float]]): Metrics for each chat completion and choice index
47 | advantages (dict[tuple[str, int], float]): GRPO advantages for each chat completion and choice index
48 | exceptions (list[Exception]): Exceptions that occurred while getting the task result
49 | """
50 |
51 | task: Task
52 | chat_completions: list[ChatCompletion]
53 | rewards: dict[tuple[str, int], float]
54 | metrics: dict[tuple[str, int], dict[str, float]]
55 | advantages: dict[tuple[str, int], float]
56 | exceptions: list[Exception]
57 |
58 |
59 | @dataclass
60 | class TaskResultStats:
61 | """
62 | Statistics for task results.
63 |
64 | Args:
65 | pbar (tqdm.tqdm): The progress bar
66 | prices (tuple[float, float] | None): Prices for input/output tokens
67 | completion_tokens (int): Total completion tokens
68 | exceptions (list[Exception]): Exceptions that occurred while getting the task results
69 | grades (int): Number of grades
70 | new_completion_ids (set[str]): Set of new completion IDs
71 | new_completion_tokens (int): Total new completion tokens
72 | new_prompt_tokens (int): Total new prompt tokens
73 | prompt_tokens (int): Total prompt tokens
74 |         token_logprobs (int): Total count of streamed token logprobs
75 | total_metrics (dict[str, float]): Total metrics
76 | total_reward (float): Total reward
77 | usages (int): Total usages
78 | """
79 |
80 | pbar: tqdm.tqdm
81 | prices: tuple[float, float] | None
82 | completion_tokens: int = 0
83 | exceptions: list[Exception] = field(default_factory=list)
84 | grades: int = 0
85 | new_completion_ids: set[str] = field(default_factory=set)
86 | new_completion_tokens: int = 0
87 | new_prompt_tokens: int = 0
88 | prompt_tokens: int = 0
89 | token_logprobs: int = 0
90 | total_metrics: dict[str, float] = field(default_factory=dict)
91 | total_reward: float = 0
92 | usages: int = 0
93 |
94 | def __del__(self) -> None:
95 | self.pbar.close()
96 |
97 | def update(
98 | self,
99 | *,
100 | id: str | None,
101 | chunk: ChatCompletionChunk | None = None,
102 | usage: CompletionUsage | None = None,
103 | reward: float | None = None,
104 | metrics: dict[str, float] | None = None,
105 | exception: Exception | None = None,
106 | ) -> None:
107 | if chunk:
108 | if id is not None:
109 | self.new_completion_ids.add(id)
110 | self.token_logprobs += sum(
111 | len(choice.logprobs.content or choice.logprobs.refusal or [])
112 | for choice in chunk.choices
113 | if choice.logprobs
114 | )
115 | elif usage:
116 | self.completion_tokens += usage.completion_tokens
117 | self.prompt_tokens += usage.prompt_tokens
118 | self.usages += 1
119 | if id in self.new_completion_ids:
120 | self.new_completion_tokens += usage.completion_tokens
121 | self.new_prompt_tokens += usage.prompt_tokens
122 | elif reward is not None:
123 | self.grades += 1
124 | self.total_reward += reward
125 | self.pbar.update()
126 | if metrics:
127 | for key, value in metrics.items():
128 | if key not in self.total_metrics:
129 | self.total_metrics[key] = 0
130 | self.total_metrics[key] += value
131 | elif exception:
132 | self.exceptions.append(exception)
133 | postfix = {
134 | "completion_tokens": round(self.completion_tokens / max(self.usages, 1)),
135 | "prompt_tokens": round(self.prompt_tokens / max(self.usages, 1)),
136 | "reward": self.total_reward / max(self.grades, 1),
137 | }
138 | for key, value in self.total_metrics.items():
139 | postfix[key] = value / max(self.grades, 1)
140 | if self.prices:
141 | postfix["spend"] = (
142 | f"${(
143 | self.new_prompt_tokens * self.prices[0]
144 | + (self.token_logprobs or self.new_completion_tokens) * self.prices[1]
145 | ) / 1_000_000:.2f}"
146 | )
147 | if self.token_logprobs:
148 | postfix["token_logprobs"] = self.token_logprobs
149 | if self.exceptions:
150 | postfix["exceptions"] = len(self.exceptions)
151 | self.pbar.set_postfix(postfix)
152 |
153 |
154 | T = TypeVar("T")
155 |
156 |
157 | class TaskResults(list[T]):
158 | stats: TaskResultStats
159 |
160 |
161 | async def get_task_results(
162 | tasks: list[Task],
163 | client: AsyncOpenAI,
164 | model: str,
165 | clear_pbar: bool = False,
166 | print_pbar: bool = True,
167 | log_dir: str | None = None,
168 | log_results: bool | float | int = True,
169 | log_token_logprobs: bool = True,
170 | n: int = 1,
171 | on_chunk: Callable[[ChatCompletionChunk, ChatCompletion], None] | None = None,
172 | params: ChatCompletionParams | None = None,
173 | pbar_desc: str | None = None,
174 | pbar_position: int | None = None,
175 | prices: tuple[float, float] | None = None,
176 | semaphore: asyncio.Semaphore | None = None,
177 | transform: Callable[[TaskResult], T | Awaitable[T]] = lambda x: x,
178 | ) -> TaskResults[T]:
179 | """
180 | Returns results for tasks using an AsyncOpenAI client for a given model. Includes support for caching, rate limiting, and logging. Results may be optionally transformed.
181 |
182 | Args:
183 | tasks (list[Task]): List of Task objects, each containing messages to send to the LLM and a grader function
184 | client (AsyncOpenAI): Any valid AsyncOpenAI client that supports creating chat completions (may be pointed at API providers other than OpenAI or a local inference engine like vLLM)
185 | model (str): Name of the chat completion model to use
186 | clear_pbar (bool): Whether to clear the progress bar after completion
187 | print_pbar (bool): Whether to print the progress bar summary after completion
188 | log_dir (str | None): Directory to save completion logs to. If None, will use the default chat completion log directory
189 | log_results (bool | float | int): Controls which task results to log. Can be a boolean, float (fraction), or int (count)
190 | log_token_logprobs (bool): Whether to stream token log probabilities count to the progress bar
191 | n (int): Number of chat completions to sample per task
192 | on_chunk (Callable[[ChatCompletionChunk, ChatCompletion], None] | None): Optional callback function for processing completion chunks
193 | params (ChatCompletionParams | None): Additional parameters to pass to the chat completion API
194 | pbar_desc (str | None): Description to display on the progress bar
195 | pbar_position (int | None): Position of the progress bar
196 | prices (tuple[float, float] | None): Tuple of (input_price, output_price) per million tokens, for cost tracking
197 | semaphore (asyncio.Semaphore | None): Optional semaphore for limiting concurrent API calls
198 | transform (Callable[[TaskResult], T | Awaitable[T]]): Function to transform TaskResult objects before returning
199 |
200 | Returns:
201 | TaskResults[T]: Processed results and statistics
202 |
203 | Process: Runs model inference → evaluates with graders → computes rewards/advantages →
204 | tracks metrics → transforms results
205 | """
206 | num_completions = len(tasks) * n
207 | pbar = tqdm.tqdm(total=num_completions, desc=pbar_desc, position=pbar_position)
208 | stats = TaskResultStats(pbar=pbar, prices=prices)
209 | results = TaskResults(
210 | await asyncio.gather(
211 | *(
212 | _get_task_result(
213 | task=task,
214 | client=client,
215 | model=model,
216 | log_results=log_results,
217 | n=n,
218 | log_dir=log_dir,
219 | on_chunk=_create_on_chunk_callback(
220 | log_token_logprobs, on_chunk, stats
221 | ),
222 | semaphore=semaphore,
223 | params=params,
224 | stats=stats,
225 | transform=transform,
226 | )
227 | for task, log_results in zip(
228 | tasks, _get_log_results_flags(log_results, len(tasks))
229 | )
230 | )
231 | )
232 | )
233 | results.stats = stats
234 | pbar.close()
235 | if getattr(pbar, "container", None) and clear_pbar:
236 | pbar.container.close()
237 | if getattr(pbar, "container", None) and print_pbar:
238 | print(pbar.container.__repr__(pretty=True))
239 | return results
240 |
241 |
242 | def _create_on_chunk_callback(
243 | log_token_logprobs: bool,
244 | on_chunk: Callable[[ChatCompletionChunk, ChatCompletion], None] | None,
245 | stats: TaskResultStats,
246 | ) -> Callable[[ChatCompletionChunk, ChatCompletion], None] | None:
247 | "Create a callback function that logs token logprobs and/or calls the user's on_chunk callback if provided."
248 | if not log_token_logprobs and not on_chunk:
249 | return None
250 |
251 | def _on_chunk(chunk: ChatCompletionChunk, completion: ChatCompletion) -> None:
252 | if log_token_logprobs:
253 | stats.update(id=chunk.id, chunk=chunk)
254 | if on_chunk:
255 | on_chunk(chunk, completion)
256 |
257 | return _on_chunk
258 |
259 |
260 | def _get_log_results_flags(log_results: bool | float | int, n: int) -> list[bool]:
261 | "Return a list of flags indicating whether to log results for each task."
262 |     if isinstance(log_results, int) and not isinstance(log_results, bool) and log_results >= 1:
263 | result = [True] * log_results + [False] * max(n - log_results, 0)
264 | elif isinstance(log_results, float):
265 | result = [True] * int(log_results * n) + [False] * (n - int(log_results * n))
266 | else:
267 | result = [bool(log_results)] * n
268 | random.shuffle(result)
269 | return result
270 |
271 |
272 | async def _get_task_result(
273 | task: Task,
274 | client: AsyncOpenAI,
275 | model: str,
276 | log_results: bool,
277 | n: int,
278 | log_dir: str | None,
279 | on_chunk: Callable[[ChatCompletionChunk, ChatCompletion], None] | None,
280 | semaphore: asyncio.Semaphore | None,
281 | params: ChatCompletionParams | None,
282 | stats: TaskResultStats,
283 | transform: Callable[[TaskResult], T | Awaitable[T]],
284 | ) -> T:
285 | """
286 | Processes a single task by generating chat completions, grading responses, and calculating rewards.
287 |
288 | This is a helper function called by get_task_results for each task in the list. It:
289 | 1. Makes n API calls to generate completions for the task
290 | 2. Collects all completions and passes them to the task's grader function
291 | 3. Tracks usage statistics and handles exceptions
292 | 4. Calculates GRPO advantages based on reward distribution
293 | 5. Applies the transform function to the TaskResult before returning
294 |
295 | See get_task_results for parameter descriptions.
296 | """
297 | # always request logprobs, unless explicitly disabled
298 | _params = (params or {}).copy()
299 | if "logprobs" not in _params:
300 | _params["logprobs"] = True
301 | elif _params["logprobs"] is None:
302 | del _params["logprobs"]
303 | chat_completions: list[ChatCompletion] = []
304 | rewards: dict[tuple[str, int], float] = {}
305 | metrics: dict[tuple[str, int], dict[str, float]] = {}
306 | exceptions: list[Exception] = []
307 | for chat_completion_future in asyncio.as_completed(
308 | get_chat_completion(
309 | client,
310 | log_dir=log_dir,
311 | log_results=log_results and i == 0,
312 | on_chunk=on_chunk,
313 | semaphore=semaphore,
314 | messages=task.messages,
315 | model=model,
316 | **_params, # type: ignore
317 | )
318 | for i in range(n)
319 | ):
320 | try:
321 | chat_completion = await chat_completion_future
322 | chat_completions.append(chat_completion)
323 | stats.update(id=chat_completion.id, usage=chat_completion.usage)
324 |
325 | async def _grade(choice: Choice, grader: Grader) -> tuple[int, Grade]:
326 | return choice.index, await maybe_await(grader(choice))
327 |
328 | for grade_future in asyncio.as_completed(
329 | _grade(choice, task.grader) for choice in chat_completion.choices
330 | ):
331 | try:
332 | choice_index, grade = await grade_future
333 | reward, _metrics = (
334 | grade if isinstance(grade, tuple) else (grade, {})
335 | )
336 | stats.update(id=chat_completion.id, reward=reward, metrics=_metrics)
337 | rewards[chat_completion.id, choice_index] = reward
338 | metrics[chat_completion.id, choice_index] = _metrics
339 | except Exception as e:
340 | exceptions.append(e)
341 | stats.update(id=chat_completion.id, exception=e)
342 | continue
343 | except Exception as e:
344 | exceptions.append(e)
345 | stats.update(id=None, exception=e)
346 | continue
347 | if rewards:
348 | reward_mean = np.mean(list(rewards.values()))
349 | reward_std = np.std(list(rewards.values()))
350 | # calculate GRPO advantages
351 | advantages = {
352 | key: float((reward - reward_mean) / (reward_std + 1e-6))
353 | for key, reward in rewards.items()
354 | }
355 | else:
356 | advantages = {key: 0.0 for key in rewards.keys()}
357 | return await maybe_await(
358 | transform(
359 | TaskResult(
360 | task=task,
361 | chat_completions=chat_completions,
362 | rewards=rewards,
363 | metrics=metrics,
364 | advantages=advantages,
365 | exceptions=exceptions,
366 | )
367 | )
368 | )
369 |
--------------------------------------------------------------------------------
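A sketch of the `Task` / `get_task_results` flow above with a toy grader; the endpoint and
model name are placeholders:

import asyncio

from openai import AsyncOpenAI
from openai.types.chat.chat_completion import Choice

from lib.tasks import Task, get_task_results


def toy_grader(choice: Choice) -> float:
    # Reward 1.0 if the reply mentions "4", else 0.0 (illustrative only).
    return 1.0 if "4" in (choice.message.content or "") else 0.0


tasks = [Task(messages=[{"role": "user", "content": "What is 2 + 2?"}], grader=toy_grader)]


async def main() -> None:
    client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    results = await get_task_results(tasks, client, model="placeholder-model", n=4)
    for result in results:
        # Rewards and GRPO advantages are keyed by (completion id, choice index).
        print(result.rewards, result.advantages)


asyncio.run(main())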
/lib/temporal_clue.py:
--------------------------------------------------------------------------------
1 | import json
2 | from openai.types.chat.chat_completion import Choice
3 | import re
4 | from typing import Iterable, TypedDict
5 |
6 | from .tasks import Task
7 |
8 |
9 | class TemporalCluePuzzle(TypedDict):
10 | num_clues: int
11 | prompt: str
12 | solution: dict[str, str]
13 |
14 |
15 | def get_temporal_clue_puzzles() -> list[TemporalCluePuzzle]:
16 | return json.load(open("./data/puzzles.json"))
17 |
18 |
19 | def get_temporal_clue_tasks() -> Iterable[Task]:
20 | for puzzle in get_temporal_clue_puzzles():
21 |
22 | def grader(choice: Choice, puzzle: TemporalCluePuzzle = puzzle) -> float:
23 | content = choice.message.content
24 | assert isinstance(content, str)
25 | num_correct = 0
26 | for key, value in puzzle["solution"].items():
27 | if matches := re.findall(rf"{key}\. ([A-Za-z \.:-]+)", content):
28 | match = matches[-1]
29 | if match.strip().lower() == value.lower():
30 | num_correct += 1
31 | return num_correct / len(puzzle["solution"])
32 |
33 | yield Task(
34 | messages=[
35 | {
36 | "role": "user",
37 | "content": puzzle["prompt"],
38 | }
39 | ],
40 | grader=grader,
41 | )
42 |
--------------------------------------------------------------------------------
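To illustrate the grader's contract, here is a small sketch that loads the first puzzle task and
scores a fabricated assistant reply (the reply text is invented purely to exercise the regex;
real answers come from the model being evaluated):

from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion_message import ChatCompletionMessage

from lib.temporal_clue import get_temporal_clue_tasks

task = next(iter(get_temporal_clue_tasks()))
fake_choice = Choice(
    finish_reason="stop",
    index=0,
    message=ChatCompletionMessage(
        role="assistant", content="A. Colonel Mustard\nB. 10:00 PM"
    ),
)
# The grader returns the fraction of solution keys whose last "<key>. <answer>"
# line in the reply matches the puzzle's solution.
print(task.grader(fake_choice))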
/lib/tokenize.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from itertools import takewhile
3 | from openai.types.chat.chat_completion import Choice
4 | from openai.types.chat.chat_completion_token_logprob import ChatCompletionTokenLogprob
5 | import random
6 | from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
7 | from typing import cast
8 |
9 |
10 | from .tasks import TaskResult
11 |
12 |
13 | @dataclass
14 | class TokenizedResult:
15 | conversation: list
16 | reward: float
17 | advantage: float
18 | chat_template: str
19 | chat: str
20 | tokens: list[str]
21 | token_ids: list[int]
22 | input_pos: list[int]
23 | assistant_mask: list[int]
24 | token_logprobs: list[ChatCompletionTokenLogprob] | None
25 | prompt_id: int = 0
26 | prompt_length: int = 0
27 |
28 | def without_prompt(self) -> "TokenizedResult":
29 | assistant_mask = self.assistant_mask[self.prompt_length :]
30 | return TokenizedResult(
31 | conversation=self.conversation,
32 | advantage=self.advantage,
33 | reward=self.reward,
34 | chat_template=self.chat_template,
35 | chat=self.chat,
36 | tokens=self.tokens[self.prompt_length :],
37 | token_ids=self.token_ids[self.prompt_length :],
38 | input_pos=self.input_pos[self.prompt_length :],
39 | assistant_mask=assistant_mask,
40 | token_logprobs=(
41 | self.token_logprobs[len(self.token_logprobs) - sum(assistant_mask) :]
42 | if self.token_logprobs is not None
43 | else None
44 | ),
45 | prompt_id=self.prompt_id,
46 | prompt_length=0,
47 | )
48 |
49 |
50 | class TaskResultTokenizer:
51 | def __init__(
52 | self,
53 | pretrained_tokenizer_or_model_name_or_path: (
54 | PreTrainedTokenizer | PreTrainedTokenizerFast | str
55 | ),
56 | ) -> None:
57 | self.tokenizer = (
58 | AutoTokenizer.from_pretrained(pretrained_tokenizer_or_model_name_or_path)
59 | if isinstance(pretrained_tokenizer_or_model_name_or_path, str)
60 | else pretrained_tokenizer_or_model_name_or_path
61 | )
62 |
63 | def __call__(self, task_result: TaskResult) -> list[TokenizedResult]:
64 | chat_completions = task_result.chat_completions.copy()
65 | random.shuffle(chat_completions)
66 | tokenized_results = [
67 | self._tokenized_result(
68 | task_result,
69 | choice,
70 | task_result.rewards.get((chat_completion.id, choice.index), 0),
71 | task_result.advantages.get((chat_completion.id, choice.index), 0),
72 | )
73 | for chat_completion in chat_completions
74 | for choice in chat_completion.choices
75 | ]
76 | prompt_id = random.randint(-(2**63), 2**63 - 1)
77 | prompt_length = len(
78 | list(
79 | takewhile(
80 | lambda x: len(set(x)) == 1,
81 | zip(*(r.token_ids for r in tokenized_results)),
82 | )
83 | )
84 | )
85 | for result in tokenized_results:
86 | result.prompt_id = prompt_id
87 | result.prompt_length = prompt_length
88 | # zero out assistant prompt tokens
89 | result.assistant_mask[:prompt_length] = [0] * prompt_length
90 | return tokenized_results
91 |
92 | def _tokenized_result(
93 | self, task_result: TaskResult, choice: Choice, reward: float, advantage: float
94 | ) -> TokenizedResult:
95 | conversation: list = task_result.task.messages + [
96 | {
97 | "role": "assistant",
98 | "content": choice.message.content,
99 | }
100 | ]
101 | chat_template = update_chat_template(self.tokenizer.get_chat_template())
102 | chat = cast(
103 | str,
104 | self.tokenizer.apply_chat_template(
105 | conversation, chat_template=chat_template, tokenize=False
106 | ),
107 | )
108 | tokenized_result = cast(
109 | dict[str, list[int]],
110 | self.tokenizer.apply_chat_template(
111 | conversation,
112 | chat_template=chat_template,
113 | return_dict=True,
114 | return_assistant_tokens_mask=True,
115 | ),
116 | )
117 | if (
118 | choice.logprobs
119 | and choice.logprobs.content
120 | and choice.logprobs.content[0].token.startswith("token_id:")
121 | ):
122 | start = tokenized_result["assistant_masks"].index(1)
123 | try:
124 | end = start + tokenized_result["assistant_masks"][start:].index(0)
125 | except ValueError:
126 | end = len(tokenized_result["assistant_masks"])
127 | tokenized_result["input_ids"][start:end] = [
128 | int(token_logprob.token.split(":")[1])
129 | for token_logprob in choice.logprobs.content
130 | ]
131 | tokenized_result["assistant_masks"][start:end] = [
132 | 1 for _ in choice.logprobs.content
133 | ]
134 | token_logprobs = choice.logprobs.content
135 | else:
136 | token_logprobs = None
137 | tokens = [
138 | self.tokenizer.decode(token_id)
139 | for token_id in tokenized_result["input_ids"]
140 | ]
141 | if token_logprobs is None:
142 | token_logprobs = self.get_token_logprobs(
143 | choice,
144 | [
145 | token
146 | for token, mask in zip(tokens, tokenized_result["assistant_masks"])
147 | if mask
148 | ],
149 | )
150 | return TokenizedResult(
151 | conversation=conversation,
152 | reward=reward,
153 | advantage=advantage,
154 | chat_template=chat_template,
155 | chat=chat,
156 | tokens=tokens,
157 | token_ids=tokenized_result["input_ids"],
158 | input_pos=list(range(len(tokens))),
159 | assistant_mask=tokenized_result["assistant_masks"],
160 | token_logprobs=token_logprobs,
161 | )
162 |
163 | def get_token_logprobs(
164 | self,
165 | choice: Choice,
166 | assistant_tokens: list[str],
167 | ) -> list[ChatCompletionTokenLogprob] | None:
168 | if not choice.logprobs:
169 | return None
170 | if not choice.logprobs.content:
171 | return None
172 | result_token_logprobs = choice.logprobs.content.copy()
173 | if "".join(assistant_tokens) != "".join(
174 | token_logprob.token for token_logprob in result_token_logprobs
175 | ) and len(assistant_tokens) != len(result_token_logprobs):
176 | print("Assistant tokens are not equal, skipping token logprobs")
177 | return None
178 | elif assistant_tokens == [
179 | token_logprob.token for token_logprob in result_token_logprobs
180 | ]:
181 | return result_token_logprobs
182 | else:
183 | completion = ""
184 | result_completion = ""
185 | token_logprobs = []
186 | try:
187 | while True:
188 | if completion == result_completion:
189 | token = assistant_tokens.pop(0)
190 | result_token_logprob = result_token_logprobs.pop(0)
191 | result_token = result_token_logprob.token
192 | if token == result_token:
193 | token_logprobs.append(result_token_logprob)
194 | else:
195 | token_logprobs.append(
196 | ChatCompletionTokenLogprob(
197 | token=token,
198 | logprob=float("nan"),
199 | top_logprobs=[],
200 | )
201 | )
202 | completion += token
203 | result_completion += result_token
204 | elif len(completion) < len(result_completion):
205 | token = assistant_tokens.pop(0)
206 | token_logprobs.append(
207 | ChatCompletionTokenLogprob(
208 | token=token,
209 | logprob=float("nan"),
210 | top_logprobs=[],
211 | )
212 | )
213 | completion += token
214 | elif len(completion) > len(result_completion):
215 | result_completion += result_token_logprobs.pop(0).token
216 | else:
217 | print("Warning: Completions are not equal")
218 | print(f"Completion: {completion}")
219 | print(f"Result completion: {result_completion}")
220 | token_logprobs = None
221 | break
222 | except IndexError:
223 | pass
224 | return token_logprobs
225 |
226 |
227 | def update_chat_template(chat_template: str) -> str:
228 | return (
229 | chat_template
230 | # Remove template logic that strips reasoning content from the chat messages
231 | .replace(
232 | "{% if '' in content %}{% set content = content.split('')[-1] %}{% endif %}",
233 | "",
234 | )
235 | # Add generation tags for assistant token masking
236 | .replace(
237 | "{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}",
238 | "{{'<|Assistant|>'}}{% generation %}{{ content }}{% endgeneration %}{{'<|end▁of▁sentence|>'}}",
239 | )
240 | # Add generation tags for assistant token masking (for Hermes 2 Theta)
241 | .replace(
242 | "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}",
243 | "{{'<|im_start|>' + message['role'] + '\n'}}{% if message['role'] == 'assistant' %}{% generation %}{{ message['content'] }}{% endgeneration %}{% else %}{{ message['content'] }}{% endif %}{{'<|im_end|>' + '\n'}}",
244 | )
245 | # Add generation tags for assistant token masking (for Qwen 2.5 Instruct)
246 | .replace(
247 | """
248 | {%- if (message.role == "user") or (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) %}
249 | {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}
250 | """.strip(),
251 | """
252 | {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
253 | {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}
254 | {%- elif message.role == "assistant" and not message.tool_calls %}
255 | {{- '<|im_start|>' + message.role + '\\n' }}{% generation %}{{ message.content }}{% endgeneration %}{{ '<|im_end|>' + '\\n' }}
256 | """.strip(),
257 | ).replace(
258 | """
259 | {%- elif message.role == "assistant" %}
260 | {{- '<|im_start|>' + message.role }}
261 | {%- if message.content %}
262 | {{- '\\n' + message.content }}
263 | {%- endif %}""".strip(),
264 | """
265 | {%- elif message.role == "assistant" %}
266 | {{- '<|im_start|>' + message.role }}
267 | {%- if message.content %}
268 | {{- '\\n' }}{% generation %}{{ message.content }}{% endgeneration %}
269 | {%- endif %}""".strip(),
270 | )
271 | # Add generation tags for assistant token masking (for Llama 3.3 70B)
272 | .replace(
273 | "{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}",
274 | "{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n' }}{%- if message['role'] == 'assistant' %}{% generation %}{{ message['content'] | trim + '<|eot_id|>' }}{% endgeneration %}{% else %}{{ message['content'] | trim + '<|eot_id|>' }}{% endif %}",
275 | )
276 | )
277 |
--------------------------------------------------------------------------------
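`TaskResultTokenizer` maps a `TaskResult` to a list of `TokenizedResult`s, so one natural way to
use it is as the `transform` argument of `get_task_results`. A sketch (the endpoint, model name,
and tokenizer checkpoint below are placeholders; the Qwen checkpoint is only an example):

import asyncio

from openai import AsyncOpenAI

from lib.tasks import get_task_results
from lib.temporal_clue import get_temporal_clue_tasks
from lib.tokenize import TaskResultTokenizer


async def main() -> None:
    client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    results = await get_task_results(
        list(get_temporal_clue_tasks())[:4],
        client,
        model="placeholder-model",
        n=8,
        transform=TaskResultTokenizer("Qwen/Qwen2.5-14B-Instruct"),
    )
    # Each element is a list[TokenizedResult]; flatten before packing into tensors.
    tokenized = [tok for group in results for tok in group]
    print(len(tokenized), sum(tokenized[0].assistant_mask))


asyncio.run(main())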
/lib/tqdm.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | try:
4 | sys.modules["IPython"].get_ipython
5 | from tqdm import notebook as tqdm
6 | except Exception:  # not running under IPython (or notebook tqdm unavailable)
7 | from tqdm import std as tqdm
8 |
--------------------------------------------------------------------------------
/lib/tune.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import glob
3 | from lib.pack import PackedDataset, PackedTensors, packed_tensors_to_dir
4 | from lib.recipe import ComponentConfig, recipe_main, TuneRecipeConfig
5 | from omegaconf import OmegaConf
6 | import os
7 | import re
8 | import shutil
9 | import sys
10 | import torch
11 | from torchtune.modules import TransformerDecoder
12 | from torchtune.training import cleanup_before_training, FullModelHFCheckpointer
13 | from torchtune.training.metric_logging import DiskLogger
14 | import tqdm
15 | from typing import Any, Callable, Literal, IO
16 |
17 |
18 | Verbosity = Literal[0, 1, 2]
19 |
20 |
21 | def clear_iteration_dirs(output_dir: str, excluding: list[int]) -> None:
22 | for dir in os.listdir(output_dir):
23 | if (
24 | os.path.isdir(os.path.join(output_dir, dir))
25 | and dir.isdigit()
26 | and int(dir) not in excluding
27 | ):
28 | iteration_dir = os.path.join(output_dir, dir)
29 | shutil.rmtree(iteration_dir)
30 | print(f"Deleted iteration directory {iteration_dir}")
31 |
32 |
33 | def get_iteration(output_dir: str) -> int:
34 | os.makedirs(output_dir, exist_ok=True)
35 | return max(
36 | (
37 | int(subdir)
38 | for subdir in os.listdir(output_dir)
39 | if os.path.isdir(os.path.join(output_dir, subdir)) and subdir.isdigit()
40 | ),
41 | default=0,
42 | )
43 |
44 |
45 | def get_last_iteration_dir(output_dir: str) -> str | None:
46 | last_iteration_dir = os.path.join(output_dir, f"{get_iteration(output_dir):04d}")
47 | return last_iteration_dir if os.path.exists(last_iteration_dir) else None
48 |
49 |
50 | def last_tune_log(output_dir: str) -> list[dict[str, float]]:
51 | sorted_logs = sorted(glob.glob(f"{output_dir}/logs/*"))
52 | contents = open(sorted_logs[-1]).read()
53 | lines = contents.strip().splitlines()
54 | parsed_logs = []
55 | for line in lines:
56 | step_part, metrics_part = line.split(" | ")
57 | step = int(step_part.split()[1])
58 | metrics = {}
59 | for metric in metrics_part.split():
60 | key, value = metric.split(":")
61 | metrics[key] = float(value)
62 | parsed_logs.append({"step": step, **metrics})
63 | return parsed_logs
64 |
65 |
66 | async def tune(
67 | base_model: str,
68 | output_dir: str,
69 | packed_tensors: PackedTensors,
70 | model: Callable[[], TransformerDecoder],
71 | model_type: str,
72 | config: TuneRecipeConfig = TuneRecipeConfig(),
73 | in_process: bool = False,
74 | verbosity: Verbosity = 2,
75 | ) -> str:
76 | if os.path.isdir(base_model):
77 | base_checkpoint_dir = base_model
78 | else:
79 | process = await asyncio.create_subprocess_shell(
80 | f"HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli download {base_model}",
81 | stdout=asyncio.subprocess.PIPE,
82 | stderr=asyncio.subprocess.PIPE,
83 | )
84 |
85 | base_stdout = []
86 |
87 | async def read_stream(stream, is_stderr=False):
88 | while True:
89 | line = await stream.readline()
90 | if not line:
91 | break
92 | text = line.decode().rstrip()
93 | if is_stderr:
94 | print(text)
95 | else:
96 | base_stdout.append(text)
97 |
98 | await asyncio.gather(
99 | read_stream(process.stdout, is_stderr=False),
100 | read_stream(process.stderr, is_stderr=True),
101 | )
102 | await process.wait()
103 | base_checkpoint_dir = "\n".join(base_stdout).strip()
104 |
105 | config.checkpointer = _get_checkpointer_config(
106 | checkpoint_dir=max(
107 | (d for d in glob.glob(f"{output_dir}/*") if d.split("/")[-1].isdigit()),
108 | key=lambda x: int(x.split("/")[-1]),
109 | default=base_checkpoint_dir,
110 | ),
111 | output_dir=output_dir,
112 | tune_model_type=model_type,
113 | )
114 | if config.loss.kl_coef > 0:
115 | print("Using reference checkpointer")
116 | config.reference_checkpointer = _get_checkpointer_config(
117 | checkpoint_dir=base_checkpoint_dir,
118 | output_dir=output_dir,
119 | tune_model_type=model_type,
120 | )
121 | if config.metric_logger is None:
122 | config.metric_logger = ComponentConfig(DiskLogger, log_dir=f"{output_dir}/logs")
123 | config.model = ComponentConfig(model)
124 | disk_packed_tensors = packed_tensors_to_dir(packed_tensors, f"{output_dir}/tensors")
125 | config.dataset = ComponentConfig(
126 | PackedDataset,
127 | **disk_packed_tensors,
128 | )
129 | config.seed = 42
130 | dict_config = config.dict_config()
131 | OmegaConf.save(dict_config, f"{output_dir}/config.yaml")
132 | if in_process:
133 | cleanup_before_training()
134 | recipe_main(config)
135 | else:
136 | await _tune_run(
137 | config_path=f"{output_dir}/config.yaml",
138 | total=disk_packed_tensors["num_sequences"],
139 | verbosity=verbosity,
140 | torchrun_kwargs={"nproc_per_node": torch.cuda.device_count()},
141 | # tune_run_env={"CUDA_LAUNCH_BLOCKING": "1"},
142 | )
143 | epoch_dirs = lambda: glob.glob(f"{output_dir}/epoch_*")
144 | epoch_dir = max(
145 | epoch_dirs(),
146 | key=lambda x: int(x.split("_")[-1]),
147 | default=None,
148 | )
149 | assert (
150 | epoch_dir is not None
151 | ), f"No epoch directory found in output directory {output_dir}"
152 | iteration_dir = f"{output_dir}/{get_iteration(output_dir) + 1:04d}"
153 | os.rename(epoch_dir, iteration_dir)
154 | for epoch_dir in epoch_dirs():
155 | os.rmdir(epoch_dir)
156 | return iteration_dir
157 |
158 |
159 | def _get_checkpointer_config(
160 | checkpoint_dir: str,
161 | output_dir: str,
162 | tune_model_type: str,
163 | checkpoint_files: list[str] | None = None,
164 | output_subdir: str = "",
165 | ) -> ComponentConfig[FullModelHFCheckpointer]:
166 | return ComponentConfig(
167 | FullModelHFCheckpointer,
168 | checkpoint_dir=checkpoint_dir,
169 | checkpoint_files=checkpoint_files
170 | or [
171 | os.path.basename(file)
172 | for ext in ["safetensors", "pt", "ckpt", "bin", "pth"]
173 | for file in glob.glob(f"{checkpoint_dir}/*.{ext}")
174 | ],
175 | recipe_checkpoint=None,
176 | output_dir=output_dir + output_subdir,
177 | model_type=tune_model_type,
178 | )
179 |
180 |
181 | async def _tune_run(
182 | config_path: str,
183 | total: int,
184 | verbosity: Verbosity = 2,
185 | torchrun_kwargs: dict[str, Any] | None = None,
186 | tune_run_env: dict[str, str] | None = None,
187 | ) -> None:
188 | args = [
189 | "tune",
190 | "run",
191 | *[
192 | f"--{key.replace('_', '-')}{f'={value}' if value is not True else ''}"
193 | for key, value in (torchrun_kwargs or {}).items()
194 | ],
195 | "lib.recipe.TuneRecipe",
196 | "--config",
197 | config_path,
198 | ]
199 | if verbosity > 0:
200 | print(f"$ {' '.join(args)}")
201 | process = await asyncio.create_subprocess_exec(
202 | *args,
203 | stdout=asyncio.subprocess.PIPE,
204 | stderr=asyncio.subprocess.PIPE,
205 | env={**os.environ, **(tune_run_env or {})},
206 | )
207 | if verbosity == 1:
208 | pbar = tqdm.tqdm(total=total)
209 | else:
210 | pbar = None
211 |
212 | async def log_output(stream: asyncio.StreamReader, io: IO[str]) -> None:
213 | output = ""
214 | while True:
215 | try:
216 | chunk = await stream.read(4096)
217 | if not chunk:
218 | break
219 | output += chunk.decode()
220 | if verbosity > 1:
221 | io.write(output)
222 | io.flush()
223 | output = ""
224 | elif verbosity == 1:
225 | output = output.split("\n")[-1]
226 | if pbar:
227 | pbar_start = re.compile(r"(\d+)\|(\d+)\|Loss: ([\d.]+):")
228 | if match := pbar_start.search(output):
229 | epoch, step, loss = match.groups()
230 | pbar.update(int(step) - pbar.n)
231 | pbar.set_description(f"{epoch}|{step}|Loss: {loss}")
232 | metrics = {
233 | key: value
234 | for key, value in re.findall(r"(\w+)=([\d.-]+)", output)
235 | }
236 | if metrics:
237 | pbar.set_postfix(**metrics)
238 | output = ""
239 | else:
240 | pbar_regex = re.compile(
241 | r"\[(?:\d+:)?\d+:\d+<(?:\d+:)?\d+:\d+.*\]"
242 | )
243 | if pbar_regex.search(output):
244 | io.write(output)
245 | io.flush()
246 | output = ""
247 | except Exception:
248 | break
249 |
250 | tasks = []
251 | if process.stdout:
252 | tasks.append(asyncio.create_task(log_output(process.stdout, sys.stdout)))
253 | if process.stderr:
254 | tasks.append(asyncio.create_task(log_output(process.stderr, sys.stderr)))
255 | try:
256 | _ = await asyncio.gather(*tasks)
257 | except asyncio.CancelledError:
258 | process.kill()
259 | if pbar:
260 | pbar.close()
261 |
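
Note: tune() writes each fine-tuning pass into a zero-padded, numbered iteration directory under output_dir; the small helpers at the top of the file inspect and prune those directories. A minimal sketch of how they fit together, with a hypothetical run directory:

from lib.tune import clear_iteration_dirs, get_iteration, get_last_iteration_dir, last_tune_log

output_dir = "./models/example-run"  # hypothetical run directory

current = get_iteration(output_dir)  # highest numeric subdirectory, or 0 if none exist yet
latest_dir = get_last_iteration_dir(output_dir)  # e.g. "./models/example-run/0003", or None

# Delete every numbered checkpoint directory except the current one to save disk space.
clear_iteration_dirs(output_dir, excluding=[current])

# Metrics parsed from the most recent torchtune DiskLogger file, one dict per step.
for row in last_tune_log(output_dir):
    print(row["step"], row.get("loss"))
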
--------------------------------------------------------------------------------
/lib/types.py:
--------------------------------------------------------------------------------
1 | from openai._types import Body, Headers, Query
2 | from openai.types.chat.completion_create_params import CompletionCreateParamsBase
3 | from typing import Never
4 |
5 |
6 | class CreateParams(CompletionCreateParamsBase, total=False):
7 | """Parameters for chat completion creation with additional fields."""
8 |
9 | extra_headers: Headers
10 | extra_query: Query
11 | extra_body: Body
12 |
13 |
14 | class ChatCompletionParams(CreateParams, total=False):
15 | """Parameters for chat completion with restricted fields."""
16 |
17 | messages: Never # type: ignore
18 | model: Never # type: ignore
19 |
--------------------------------------------------------------------------------
/lib/utils.py:
--------------------------------------------------------------------------------
1 | from contextlib import contextmanager
2 | import signal
3 | from typing import Generator
4 |
5 |
6 | @contextmanager
7 | def timeout(seconds: int = 1) -> Generator[None, None, None]:
8 | def timeout_handler(signum: object, frame: object) -> None:
9 | raise TimeoutError()
10 |
11 | original_handler = signal.signal(signal.SIGALRM, timeout_handler)
12 | try:
13 | signal.alarm(seconds)
14 | yield
15 | finally:
16 | signal.alarm(0)
17 | signal.signal(signal.SIGALRM, original_handler)
18 |
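
Note: a minimal usage sketch for the SIGALRM-based timeout context manager (Unix-only, main thread only); the sleep is a placeholder for any blocking call:

import time

from lib.utils import timeout

try:
    with timeout(seconds=1):
        time.sleep(5)  # stands in for a potentially slow, blocking call
except TimeoutError:
    print("timed out after 1 second")
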
--------------------------------------------------------------------------------
/lib/vllm.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from dataclasses import dataclass
3 | import httpx
4 | from openai import AsyncOpenAI, DefaultAsyncHttpxClient
5 | import os
6 | import socket
7 | import subprocess
8 | import sys
9 | import re
10 | from typing import Any, IO, Optional
11 |
12 |
13 | @dataclass
14 | class vLLM:
15 | client: AsyncOpenAI
16 | max_concurrent_tokens: int
17 | model: str
18 | process: asyncio.subprocess.Process
19 |
20 |
21 | async def start_vllm(
22 | model: str,
23 | env: Optional[dict[str, str]] = None,
24 | log_file: str = "./logs/vllm.log",
25 | max_concurrent_requests: int = 128,
26 | named_arguments: dict[str, Any] = {},
27 | timeout: float = 120.0,
28 | verbosity: int = 2,
29 | ) -> vLLM:
30 | kill_vllm_workers()
31 | if os.path.exists(os.path.abspath(model)):
32 | named_arguments.setdefault("served_model_name", model)
33 | model = os.path.abspath(model)
34 | port = named_arguments.get("port") or 8000
35 | while True:
36 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
37 | try:
38 | sock.bind((named_arguments.get("host") or "0.0.0.0", port))
39 | break
40 | except socket.error:
41 | if "port" in named_arguments and named_arguments["port"] == port:
42 | raise RuntimeError(f"Port {port} is already in use")
43 | port += 1
44 | finally:
45 | sock.close()
46 | named_arguments["port"] = port
47 | args = [
48 | "vllm",
49 | "serve",
50 | model,
51 | *[
52 | f"--{key.replace('_', '-')}{f'={value}' if value is not True else ''}"
53 | for key, value in named_arguments.items()
54 | ],
55 | "--api-key=default",
56 | ]
57 | # os.system("lsof -ti :8000 | xargs kill -9 2>/dev/null || true")
58 | process = await asyncio.create_subprocess_exec(
59 | *args,
60 | stdout=asyncio.subprocess.PIPE,
61 | stderr=asyncio.subprocess.PIPE,
62 | env={
63 | **os.environ,
64 | **(env or {}),
65 | },
66 | )
67 | if verbosity > 0:
68 | print(f"$ {' '.join(args)}")
69 | os.makedirs(os.path.dirname(log_file), exist_ok=True)
70 | log = open(log_file, "w")
71 | logging = verbosity > 1
72 | max_concurrent_tokens: Optional[int] = None
73 |
74 | async def log_output(stream: asyncio.StreamReader, io: IO[str]) -> None:
75 | while True:
76 | line = await stream.readline()
77 | if not line:
78 | break
79 | decoded_line = line.decode()
80 | if logging:
81 | io.write(decoded_line)
82 | io.flush()
83 | log.write(decoded_line)
84 | log.flush()
85 | nonlocal max_concurrent_tokens
86 | if not max_concurrent_tokens:
87 | match = re.search(
88 | r"Maximum concurrency for (\d+) tokens per request: ([\d.]+)x",
89 | decoded_line,
90 | )
91 | if match:
92 | max_concurrent_tokens = int(
93 | int(match.group(1)) * float(match.group(2))
94 | )
95 | log.close()
96 |
97 | if process.stdout:
98 | asyncio.create_task(log_output(process.stdout, sys.stdout))
99 | if process.stderr:
100 | asyncio.create_task(log_output(process.stderr, sys.stderr))
101 | client = AsyncOpenAI(
102 | api_key="default",
103 | base_url=f"http://{named_arguments.get('host', '0.0.0.0')}:{named_arguments['port']}/v1",
104 | max_retries=6,
105 | http_client=DefaultAsyncHttpxClient(
106 | limits=httpx.Limits(
107 | max_connections=max_concurrent_requests,
108 | max_keepalive_connections=max_concurrent_requests,
109 | ),
110 | timeout=httpx.Timeout(timeout=1_200, connect=10.0),
111 | ),
112 | )
113 | start = asyncio.get_event_loop().time()
114 | while True:
115 | try:
116 | await client.chat.completions.create(
117 | messages=[{"role": "user", "content": "Hello"}],
118 | model=named_arguments.get("served_model_name", model),
119 | max_tokens=1,
120 | )
121 | break
122 | except Exception:
123 | if asyncio.get_event_loop().time() - start > timeout:
124 | process.terminate()
125 | kill_vllm_workers()
126 | raise TimeoutError("vLLM server did not start in time")
127 | continue
128 | if logging:
129 | print(f"vLLM server started succesfully. Logs can be found at {log_file}")
130 | logging = False
131 | if max_concurrent_tokens is None:
132 | process.terminate()
133 | kill_vllm_workers()
134 | raise RuntimeError(
135 | "Max concurrent requests for the maximum model length not logged"
136 | )
137 | return vLLM(
138 | client,
139 | max_concurrent_tokens,
140 | named_arguments.get("served_model_name", model),
141 | process,
142 | )
143 |
144 |
145 | def kill_vllm_workers() -> None:
146 | result = subprocess.run(["ps", "aux"], capture_output=True, text=True)
147 | pids = [
148 | line.split()[1]
149 | for line in result.stdout.splitlines()
150 | if "from multiprocessing.spawn import spawn_main; spawn_main(tracker_fd="
151 | in line
152 | ]
153 | for pid in pids:
154 | subprocess.run(["sudo", "kill", "-9", pid], check=True)
155 |
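
Note: a minimal sketch of spinning up and tearing down an inference server with the helpers above. It assumes a GPU machine with the vllm CLI installed; the model id and timeout are illustrative choices, not prescribed by the repo:

import asyncio
import torch
from lib.vllm import start_vllm, kill_vllm_workers


async def main() -> None:
    vllm = await start_vllm(
        "Qwen/Qwen2.5-14B-Instruct",  # assumed model id for illustration
        named_arguments=dict(
            max_model_len=16384,
            tensor_parallel_size=torch.cuda.device_count(),
        ),
        timeout=600.0,  # generous startup timeout for large checkpoints
    )
    response = await vllm.client.chat.completions.create(
        model=vllm.model,
        messages=[{"role": "user", "content": "Say hello."}],
        max_tokens=16,
    )
    print(response.choices[0].message.content)
    vllm.process.terminate()
    kill_vllm_workers()


asyncio.run(main())
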
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "deductive-reasoning"
3 | version = "0.1.0"
4 | description = "Add your description here"
5 | readme = "README.md"
6 | requires-python = ">=3.12"
7 | dependencies = [
8 | "aioitertools>=0.12.0",
9 | "matplotlib>=3.10.1",
10 | "polars>=1.24.0",
11 | "seaborn>=0.13.2",
12 | "torch>=2.5.1",
13 | "torchao>=0.8.0",
14 | "torchtune>=0.5.0",
15 | "vllm>=0.7.0",
16 | "wandb>=0.19.8",
17 | ]
18 |
19 | [tool.uv]
20 | dev-dependencies = [
21 | "ipykernel>=6.29.5",
22 | "ipywidgets>=8.1.5",
23 | "nbconvert>=7.16.6",
24 | ]
25 |
--------------------------------------------------------------------------------
/train.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%load_ext autoreload\n",
10 | "%autoreload 2"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": 2,
16 | "metadata": {},
17 | "outputs": [
18 | {
19 | "data": {
20 | "text/html": [
21 | "\n"
30 | ],
31 | "text/plain": [
32 | ""
33 | ]
34 | },
35 | "metadata": {},
36 | "output_type": "display_data"
37 | }
38 | ],
39 | "source": [
40 | "%%html\n",
41 | ""
50 | ]
51 | },
52 | {
53 | "cell_type": "code",
54 | "execution_count": null,
55 | "metadata": {},
56 | "outputs": [
57 | {
58 | "name": "stdout",
59 | "output_type": "stream",
60 | "text": [
61 | "import error: No module named 'triton'\n"
62 | ]
63 | },
64 | {
65 | "name": "stderr",
66 | "output_type": "stream",
67 | "text": [
68 | "\u001b[34m\u001b[1mwandb\u001b[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.\n",
69 | "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mbradhilton\u001b[0m to \u001b[32mhttps://api.wandb.ai\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
70 | ]
71 | },
72 | {
73 | "data": {
74 | "text/html": [
75 | "Tracking run with wandb version 0.19.8"
76 | ],
77 | "text/plain": [
78 | ""
79 | ]
80 | },
81 | "metadata": {},
82 | "output_type": "display_data"
83 | },
84 | {
85 | "data": {
86 | "text/html": [
87 | "Run data is saved locally in /Users/brad/github/bradhilton/deductive-reasoning/wandb/run-20250306_075102-
"
88 | ],
89 | "text/plain": [
90 | ""
91 | ]
92 | },
93 | "metadata": {},
94 | "output_type": "display_data"
95 | },
96 | {
97 | "data": {
98 | "text/html": [
99 | "Syncing run to Weights & Biases (docs)
"
100 | ],
101 | "text/plain": [
102 | ""
103 | ]
104 | },
105 | "metadata": {},
106 | "output_type": "display_data"
107 | },
108 | {
109 | "data": {
110 | "text/html": [
111 | " View project at https://wandb.ai/bradhilton/grpo-tests"
112 | ],
113 | "text/plain": [
114 | ""
115 | ]
116 | },
117 | "metadata": {},
118 | "output_type": "display_data"
119 | },
120 | {
121 | "data": {
122 | "text/html": [
123 | " View run at https://wandb.ai/bradhilton/grpo-tests/runs/%3CYOUR-RUN-NAME%3E"
124 | ],
125 | "text/plain": [
126 | ""
127 | ]
128 | },
129 | "metadata": {},
130 | "output_type": "display_data"
131 | },
132 | {
133 | "data": {
134 | "text/plain": [
135 | "(64, 64, 2860)"
136 | ]
137 | },
138 | "execution_count": 3,
139 | "metadata": {},
140 | "output_type": "execute_result"
141 | }
142 | ],
143 | "source": [
144 | "import asyncio\n",
145 | "from itertools import cycle, islice\n",
146 | "from lib import models\n",
147 | "from lib.grpo import GRPO\n",
148 | "from lib.inference_early_stop import InferenceEarlyStop\n",
149 | "from lib.pack import packed_tensors_from_tokenized_results, plot_packed_tensors\n",
150 | "from lib.recipe import ComponentConfig, TuneRecipeConfig\n",
151 | "from lib.tasks import ChatCompletionParams, get_task_results\n",
152 | "from lib.temporal_clue import get_temporal_clue_tasks\n",
153 | "from lib.tokenize import TaskResultTokenizer\n",
154 | "from lib.tune import (\n",
155 | " clear_iteration_dirs,\n",
156 | " get_iteration,\n",
157 | " get_last_iteration_dir,\n",
158 | " last_tune_log,\n",
159 | " tune,\n",
160 | " Verbosity,\n",
161 | ")\n",
162 | "from lib.vllm import start_vllm, kill_vllm_workers\n",
163 | "import polars as pl\n",
164 | "import random\n",
165 | "import torch\n",
166 | "from transformers import AutoTokenizer\n",
167 | "import wandb\n",
168 | "\n",
169 | "run_name = \"\"\n",
170 | "assert run_name != \"\", \"Don't forget to choose a run name\"\n",
171 | "run = wandb.init(name=run_name, id=run_name, resume=\"allow\")\n",
172 | "\n",
173 | "tasks = list(get_temporal_clue_tasks())\n",
174 | "val_tasks = tasks[:64]\n",
175 | "test_tasks = tasks[64:128]\n",
176 | "train_tasks = tasks[128:]\n",
177 | "random.seed(42)\n",
178 | "random.shuffle(train_tasks)\n",
179 | "len(val_tasks), len(test_tasks), len(train_tasks)"
180 | ]
181 | },
182 | {
183 | "cell_type": "code",
184 | "execution_count": null,
185 | "metadata": {},
186 | "outputs": [],
187 | "source": [
188 | "# GRPO params\n",
189 | "wandb.config[\"clip_epsilon\"] = clip_epsilon = 0.2\n",
190 | "wandb.config[\"entropy_coef\"] = entropy_coef = 0.0\n",
191 | "wandb.config[\"kl_coef\"] = kl_coef = 0.0\n",
192 | "wandb.config[\"tanh\"] = tanh = False\n",
193 | "\n",
194 | "# Model params\n",
195 | "model = models.qwen_32b()\n",
196 | "wandb.config[\"model\"] = model.base_model\n",
197 | "tokenizer = AutoTokenizer.from_pretrained(model.base_model)\n",
198 | "wandb.config[\"seq_len\"] = seq_len = 16384\n",
199 | "\n",
200 | "# Optimizer params\n",
201 | "wandb.config[\"lr\"] = lr = 6e-6\n",
202 | "wandb.config[\"betas\"] = betas = (0.9, 0.99)\n",
203 | "wandb.config[\"weight_decay\"] = weight_decay = 0.1\n",
204 | "\n",
205 | "# Training params\n",
206 | "num_iterations = 1_000\n",
207 | "wandb.config[\"samples_per_task\"] = samples_per_task = 50\n",
208 | "wandb.config[\"tasks_per_iter\"] = tasks_per_iter = 32\n",
209 | "wandb.config[\"stride\"] = stride = 32\n",
210 | "output_dir = f\"./models/{run_name}\"\n",
211 | "\n",
212 | "# Inference params\n",
213 | "expected_tokens = 1000 # Initial expected completion tokens per task sample\n",
214 | "inference_early_stop = InferenceEarlyStop(alpha=0.992, threshold=-3.0)\n",
215 | "\n",
216 | "# Logging params\n",
217 | "plot_tensors = True\n",
218 | "verbosity: Verbosity = 2"
219 | ]
220 | },
221 | {
222 | "cell_type": "code",
223 | "execution_count": null,
224 | "metadata": {},
225 | "outputs": [],
226 | "source": [
227 | "# Start from the latest iteration if it exists, otherwise start from the base model\n",
228 | "model_name = get_last_iteration_dir(output_dir) or model.base_model\n",
229 | "\n",
230 | "# Loop from the current iteration to the target number of iterations\n",
231 | "for i in range(get_iteration(output_dir), num_iterations):\n",
232 | " # Start vLLM server\n",
233 | " vllm = await start_vllm(\n",
234 | " model_name,\n",
235 | " max_concurrent_requests=4096,\n",
236 | " env={\"VLLM_ALLOW_LONG_MAX_MODEL_LEN\": \"1\"},\n",
237 | " named_arguments=dict(\n",
238 | " block_size=32,\n",
239 | " disable_log_requests=True,\n",
240 | " enable_prefix_caching=True,\n",
241 | " enforce_eager=True,\n",
242 | " gpu_memory_utilization=0.95,\n",
243 | " max_model_len=16384,\n",
244 | " max_num_seqs=4096,\n",
245 | " max_num_batched_tokens=16384,\n",
246 | " num_scheduler_steps=16,\n",
247 | " preemption_mode=\"swap\",\n",
248 | " return_tokens_as_token_ids=True,\n",
249 | " swap_space=80,\n",
250 | " tensor_parallel_size=torch.cuda.device_count(),\n",
251 | " ),\n",
252 | " timeout=360 + 15 * torch.cuda.device_count(),\n",
253 | " verbosity=verbosity,\n",
254 | " )\n",
255 | "\n",
256 | " # Create semaphore for rate limiting\n",
257 | " semaphore = asyncio.Semaphore(\n",
258 | " int(\n",
259 | " 1.3\n",
260 | " * (torch.cuda.device_count() / model.min_gpus)\n",
261 | " * (vllm.max_concurrent_tokens / expected_tokens)\n",
262 | " )\n",
263 | " )\n",
264 | "\n",
265 | " # Get results for logging validation performance and for tuning with train results\n",
266 | " offset = i * stride\n",
267 | " val_results, train_results = await asyncio.gather(\n",
268 | " get_task_results(\n",
269 | " tasks=val_tasks,\n",
270 | " client=vllm.client,\n",
271 | " model=vllm.model,\n",
272 | " log_results=8,\n",
273 | " n=2,\n",
274 | " on_chunk=inference_early_stop,\n",
275 | " params=ChatCompletionParams(\n",
276 | " stream_options={\n",
277 | " \"include_usage\": True,\n",
278 | " },\n",
279 | " max_completion_tokens=8192,\n",
280 | " ),\n",
281 | " pbar_desc=\"val\",\n",
282 | " semaphore=semaphore,\n",
283 | " ),\n",
284 | " get_task_results(\n",
285 | " tasks=list(islice(cycle(train_tasks), offset, offset + tasks_per_iter)),\n",
286 | " client=vllm.client,\n",
287 | " model=vllm.model,\n",
288 | " log_results=False,\n",
289 | " n=samples_per_task,\n",
290 | " on_chunk=inference_early_stop,\n",
291 | " params=ChatCompletionParams(\n",
292 | " stream_options={\n",
293 | " \"include_usage\": True,\n",
294 | " },\n",
295 | " max_completion_tokens=8192,\n",
296 | " ),\n",
297 | " pbar_desc=\"train\",\n",
298 | " semaphore=semaphore,\n",
299 | " transform=TaskResultTokenizer(tokenizer),\n",
300 | " ),\n",
301 | " )\n",
302 | "\n",
303 | " # Stop vLLM workers\n",
304 | " vllm.process.terminate()\n",
305 | " kill_vllm_workers()\n",
306 | "\n",
307 | " # Log results to Weights & Biases\n",
308 | " val_stats = val_results.stats\n",
309 | " assert val_stats.grades > 0\n",
310 | " assert val_stats.usages > 0\n",
311 | " wandb_data = {\n",
312 | " \"iteration\": i,\n",
313 | " \"exceptions\": val_stats.exceptions + train_results.stats.exceptions,\n",
314 | " \"reward\": val_stats.total_reward / val_stats.grades,\n",
315 | " \"tokens\": round(val_stats.completion_tokens / val_stats.usages),\n",
316 | " }\n",
317 | " for metric in val_stats.total_metrics:\n",
318 | " wandb_data[metric] = val_stats.total_metrics[metric] / val_stats.grades\n",
319 | " try:\n",
320 | " wandb_data.update(\n",
321 | " pl.DataFrame(last_tune_log(output_dir)).drop(\"step\").mean().to_dicts()[0]\n",
322 | " )\n",
323 | " except Exception:\n",
324 | " pass\n",
325 | " wandb.log(wandb_data)\n",
326 | "\n",
327 | " # Update expected tokens\n",
328 | " expected_tokens = wandb_data[\"tokens\"]\n",
329 | "\n",
330 | " # Clean up output directory to save space\n",
331 | " try:\n",
332 | " best_iteration = (\n",
333 | " wandb.Api()\n",
334 | " .run(f\"{run.entity}/{run.project}/{run.id}\")\n",
335 | " .history()\n",
336 | " .sort_values(by=\"reward\")[\"iteration\"]\n",
337 | " .iloc[-1]\n",
338 | " )\n",
339 | " # Clear all but the best and current iterations\n",
340 | " clear_iteration_dirs(output_dir, excluding=[best_iteration, i])\n",
341 | " except Exception:\n",
342 | " pass\n",
343 | "\n",
344 | " # Pack the tokenized results into tensors\n",
345 | " tokenized_results = [\n",
346 | " result\n",
347 | " for results in train_results\n",
348 | " for result in results\n",
349 | " if result.advantage != 0\n",
350 | " ]\n",
351 | " packed_tensors = packed_tensors_from_tokenized_results(\n",
352 | " tokenized_results,\n",
353 | " seq_len=seq_len,\n",
354 | " pad_token_id=tokenizer.pad_token_id, # type: ignore\n",
355 | " )\n",
356 | " if plot_tensors:\n",
357 | " plot_packed_tensors(packed_tensors)\n",
358 | " elif verbosity > 0:\n",
359 | " print(f\"Packed tensors into {packed_tensors[\"tokens\"].size()} shape\")\n",
360 | "\n",
361 | " # Tune the model\n",
362 | " model_name = await tune(\n",
363 | " base_model=model.base_model if kl_coef > 0 else model_name,\n",
364 | " output_dir=output_dir,\n",
365 | " packed_tensors=packed_tensors,\n",
366 | " model=model.tune_model,\n",
367 | " model_type=model.tune_model_type,\n",
368 | " config=TuneRecipeConfig(\n",
369 | " optimizer=ComponentConfig(\n",
370 | " \"torch.optim.AdamW\",\n",
371 | " lr=lr,\n",
372 | " betas=betas,\n",
373 | " weight_decay=weight_decay,\n",
374 | " fused=True,\n",
375 | " ),\n",
376 | " loss=ComponentConfig(\n",
377 | " GRPO,\n",
378 | " clip_epsilon=clip_epsilon,\n",
379 | " entropy_coef=entropy_coef,\n",
380 | " kl_coef=kl_coef,\n",
381 | " ),\n",
382 | " shuffle=True,\n",
383 | " batch_size=32768 // seq_len,\n",
384 | " fsdp_cpu_offload=True,\n",
385 | " enable_activation_checkpointing=True,\n",
386 | " enable_activation_offloading=True,\n",
387 | " custom_sharded_layers=[\"tok_embeddings\", \"output\"],\n",
388 | " num_output_chunks=model.tune_num_output_chunks,\n",
389 | " compile=True,\n",
390 | " ),\n",
391 | " verbosity=verbosity,\n",
392 | " )"
393 | ]
394 | }
395 | ],
396 | "metadata": {
397 | "kernelspec": {
398 | "display_name": ".venv",
399 | "language": "python",
400 | "name": "python3"
401 | },
402 | "language_info": {
403 | "codemirror_mode": {
404 | "name": "ipython",
405 | "version": 3
406 | },
407 | "file_extension": ".py",
408 | "mimetype": "text/x-python",
409 | "name": "python",
410 | "nbconvert_exporter": "python",
411 | "pygments_lexer": "ipython3",
412 | "version": "3.12.5"
413 | }
414 | },
415 | "nbformat": 4,
416 | "nbformat_minor": 2
417 | }
418 |
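
Note: the notebook above (and train.py below) size the inference semaphore as 1.3 * (available GPUs / model.min_gpus) * (vLLM's reported max concurrent tokens / expected completion tokens per sample). A worked example with assumed numbers:

# Worked example of the semaphore sizing used in the training loop, with assumed values:
# 8 GPUs, a model that needs at least 4 GPUs, ~400,000 concurrent tokens reported by
# vLLM, and ~1,000 expected completion tokens per task sample.
gpu_count = 8
min_gpus = 4
max_concurrent_tokens = 400_000
expected_tokens = 1_000

limit = int(1.3 * (gpu_count / min_gpus) * (max_concurrent_tokens / expected_tokens))
print(limit)  # 1040 task samples allowed in flight at once
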
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from itertools import cycle, islice
3 | from lib import models
4 | from lib.grpo import GRPO
5 | from lib.inference_early_stop import InferenceEarlyStop
6 | from lib.pack import packed_tensors_from_tokenized_results
7 | from lib.recipe import ComponentConfig, TuneRecipeConfig
8 | from lib.tasks import ChatCompletionParams, get_task_results
9 | from lib.temporal_clue import get_temporal_clue_tasks
10 | from lib.tokenize import TaskResultTokenizer
11 | from lib.tune import (
12 | clear_iteration_dirs,
13 | get_iteration,
14 | get_last_iteration_dir,
15 | last_tune_log,
16 | tune,
17 | Verbosity,
18 | )
19 | from lib.vllm import start_vllm, kill_vllm_workers
20 | import polars as pl
21 | import random
22 | import torch
23 | from transformers import AutoTokenizer
24 | import wandb
25 |
26 | run_name = ""
27 | assert run_name != "", "Don't forget to choose a run name"
28 | run = wandb.init(name=run_name, id=run_name, resume="allow")
29 |
30 | # Get tasks
31 | tasks = list(get_temporal_clue_tasks())
32 | val_tasks = tasks[:64]
33 | test_tasks = tasks[64:128]
34 | train_tasks = tasks[128:]
35 | random.seed(42)
36 | random.shuffle(train_tasks)
37 |
38 | # GRPO params
39 | wandb.config["clip_epsilon"] = clip_epsilon = 0.2
40 | wandb.config["entropy_coef"] = entropy_coef = 0.0
41 | wandb.config["kl_coef"] = kl_coef = 0.0
42 | wandb.config["tanh"] = tanh = False
43 |
44 | # Model params
45 | model = models.qwen_32b()
46 | wandb.config["model"] = model.base_model
47 | tokenizer = AutoTokenizer.from_pretrained(model.base_model)
48 | wandb.config["seq_len"] = seq_len = 16384
49 |
50 | # Optimizer params
51 | wandb.config["lr"] = lr = 6e-6
52 | wandb.config["betas"] = betas = (0.9, 0.99)
53 | wandb.config["weight_decay"] = weight_decay = 0.1
54 |
55 | # Training params
56 | num_iterations = 1_000
57 | wandb.config["samples_per_task"] = samples_per_task = 50
58 | wandb.config["tasks_per_iter"] = tasks_per_iter = 32
59 | wandb.config["stride"] = stride = 32
60 | output_dir = f"./models/{run_name}"
61 |
62 | # Inference params
63 | expected_tokens = 1000 # Initial expected completion tokens per task sample
64 | inference_early_stop = InferenceEarlyStop(alpha=0.992, threshold=-3.0)
65 |
66 | # Logging params
67 | verbosity: Verbosity = 2
68 |
69 | # Start from the latest iteration if it exists, otherwise start from the base model
70 | model_name = get_last_iteration_dir(output_dir) or model.base_model
71 |
72 |
73 | async def train() -> None:
74 | global expected_tokens, model_name
75 | # Loop from the current iteration to the target number of iterations
76 | for i in range(get_iteration(output_dir), num_iterations):
77 | # Start vLLM server
78 | vllm = await start_vllm(
79 | model_name,
80 | max_concurrent_requests=4096,
81 | env={"VLLM_ALLOW_LONG_MAX_MODEL_LEN": "1"},
82 | named_arguments=dict(
83 | block_size=32,
84 | disable_log_requests=True,
85 | enable_prefix_caching=True,
86 | enforce_eager=True,
87 | gpu_memory_utilization=0.95,
88 | max_model_len=16384,
89 | max_num_seqs=4096,
90 | max_num_batched_tokens=16384,
91 | num_scheduler_steps=16,
92 | preemption_mode="swap",
93 | return_tokens_as_token_ids=True,
94 | swap_space=80,
95 | tensor_parallel_size=torch.cuda.device_count(),
96 | ),
97 | timeout=360 + 15 * torch.cuda.device_count(),
98 | verbosity=verbosity,
99 | )
100 |
101 | # Create semaphore for rate limiting
102 | semaphore = asyncio.Semaphore(
103 | int(
104 | 1.3
105 | * (torch.cuda.device_count() / model.min_gpus)
106 | * (vllm.max_concurrent_tokens / expected_tokens)
107 | )
108 | )
109 |
110 | # Gather validation results (for logging) and training results (for tuning)
111 | offset = i * stride
112 | val_results, train_results = await asyncio.gather(
113 | get_task_results(
114 | tasks=val_tasks,
115 | client=vllm.client,
116 | model=vllm.model,
117 | log_results=8,
118 | n=2,
119 | on_chunk=inference_early_stop,
120 | params=ChatCompletionParams(
121 | stream_options={
122 | "include_usage": True,
123 | },
124 | max_completion_tokens=8192,
125 | ),
126 | pbar_desc="val",
127 | pbar_position=0,
128 | semaphore=semaphore,
129 | ),
130 | get_task_results(
131 | tasks=list(islice(cycle(train_tasks), offset, offset + tasks_per_iter)),
132 | client=vllm.client,
133 | model=vllm.model,
134 | log_results=False,
135 | n=samples_per_task,
136 | on_chunk=inference_early_stop,
137 | params=ChatCompletionParams(
138 | stream_options={
139 | "include_usage": True,
140 | },
141 | max_completion_tokens=8192,
142 | ),
143 | pbar_desc="train",
144 | pbar_position=1,
145 | semaphore=semaphore,
146 | transform=TaskResultTokenizer(tokenizer),
147 | ),
148 | )
149 |
150 | # Stop vLLM workers
151 | vllm.process.terminate()
152 | kill_vllm_workers()
153 |
154 | # Log results to Weights & Biases
155 | val_stats = val_results.stats
156 | assert val_stats.grades > 0
157 | assert val_stats.usages > 0
158 | wandb_data = {
159 | "iteration": i,
160 | "exceptions": val_stats.exceptions + train_results.stats.exceptions,
161 | "reward": val_stats.total_reward / val_stats.grades,
162 | "tokens": round(val_stats.completion_tokens / val_stats.usages),
163 | }
164 | for metric in val_stats.total_metrics:
165 | wandb_data[metric] = val_stats.total_metrics[metric] / val_stats.grades
166 | try:
167 | wandb_data.update(
168 | pl.DataFrame(last_tune_log(output_dir))
169 | .drop("step")
170 | .mean()
171 | .to_dicts()[0]
172 | )
173 | except Exception:
174 | pass
175 | wandb.log(wandb_data)
176 |
177 | # Update expected tokens
178 | expected_tokens = wandb_data["tokens"]
179 |
180 | # Clean up output directory to save space
181 | try:
182 | best_iteration = (
183 | wandb.Api()
184 | .run(f"{run.entity}/{run.project}/{run.id}")
185 | .history()
186 | .sort_values(by="reward")["iteration"]
187 | .iloc[-1]
188 | )
189 | # Clear all but the best and current iterations
190 | clear_iteration_dirs(output_dir, excluding=[best_iteration, i])
191 | except Exception:
192 | pass
193 |
194 | # Pack the tokenized results into tensors
195 | tokenized_results = [
196 | result
197 | for results in train_results
198 | for result in results
199 | if result.advantage != 0
200 | ]
201 | packed_tensors = packed_tensors_from_tokenized_results(
202 | tokenized_results,
203 | seq_len=seq_len,
204 | pad_token_id=tokenizer.pad_token_id, # type: ignore
205 | )
206 | if verbosity > 0:
207 | print(f"Packed tensors into {packed_tensors["tokens"].size()} shape")
208 |
209 | # Tune the model
210 | model_name = await tune(
211 | base_model=model.base_model if kl_coef > 0 else model_name,
212 | output_dir=output_dir,
213 | packed_tensors=packed_tensors,
214 | model=model.tune_model,
215 | model_type=model.tune_model_type,
216 | config=TuneRecipeConfig(
217 | optimizer=ComponentConfig(
218 | "torch.optim.AdamW",
219 | lr=lr,
220 | betas=betas,
221 | weight_decay=weight_decay,
222 | fused=True,
223 | ),
224 | loss=ComponentConfig(
225 | GRPO,
226 | clip_epsilon=clip_epsilon,
227 | entropy_coef=entropy_coef,
228 | kl_coef=kl_coef,
229 | ),
230 | shuffle=True,
231 | batch_size=32768 // seq_len,
232 | fsdp_cpu_offload=True,
233 | enable_activation_checkpointing=True,
234 | enable_activation_offloading=True,
235 | custom_sharded_layers=["tok_embeddings", "output"],
236 | num_output_chunks=model.tune_num_output_chunks,
237 | compile=True,
238 | ),
239 | verbosity=verbosity,
240 | )
241 |
242 |
243 | asyncio.run(train())
244 |
--------------------------------------------------------------------------------