├── README.md
├── __pycache__
│   └── train.cpython-310.pyc
├── image.webp
├── o1-paper.pdf
├── o1_model.pth
├── test.py
└── train.py

/README.md:
--------------------------------------------------------------------------------
# O1 Nano

![O1-nano Logo](image.webp)

## Overview

This project implements a simplified version of the O1 model, inspired by OpenAI's research. The O1 model is an advanced language model that integrates chain-of-thought reasoning with reinforcement learning during both training and inference. This implementation, called O1-nano, focuses on arithmetic problem-solving as a demonstration of the model's capabilities. Based on [this](https://youtu.be/sf7Ntg72qCI) video on YouTube by Siraj Raval.

## Key Features

1. **Chain-of-Thought Reasoning**: The model generates both completion tokens and internal reasoning tokens, simulating a thought process.
2. **Reinforcement Learning**: Uses Proximal Policy Optimization (PPO) for training.
3. **Multiple Reasoning Paths**: Explores multiple paths and selects the best one during generation.
4. **Subtask Generation**: Capable of breaking down complex problems into smaller subtasks.
5. **Adaptive Reasoning**: Includes mechanisms for revising reasoning during the generation process.
6. **Large Context Window**: Supports a context window of up to 128,000 tokens.
7. **Internal Reasoning Tokens**: Implements discardable internal tokens for reasoning.

## Files

- `train.py`: Contains the model architecture, training loop, and utility functions.
- `test.py`: Provides a simple interface for interacting with a trained model.

## Model Architecture

The O1Model class in `train.py` defines the model architecture:

- Embedding layer
- Positional encoding
- Multiple transformer layers
- Separate decoders for completion and reasoning
- Value head for reinforcement learning
- Subtask generation head

## Training Process

The training process combines supervised learning and reinforcement learning:

1. **Data Generation**: Arithmetic problems are generated on-the-fly.
2. **Supervised Learning**: The model is trained to predict correct solutions and reasoning steps.
3. **Reinforcement Learning**: PPO is used to optimize the model's policy based on rewards.
4. **Dynamic Curriculum**: Problem difficulty is adjusted based on the training progress.

## Usage

### Training

To train the model, run:

```bash
python train.py
```

This will train the model for 500 epochs and save it as `o1_model.pth`.

### Testing

To interact with a trained model, run:

```bash
python test.py
```

This will load the trained model and allow you to input arithmetic problems for the model to solve.

## Requirements

- Python 3.7+
- PyTorch 1.8+

## Model Parameters

- Embedding dimension: 128
- Number of attention heads: 8
- Number of transformer layers: 4
- Dropout rate: 0.1

## Vocabulary

The model uses a custom vocabulary tailored for arithmetic operations, including special tokens for subtasks and reasoning steps.

## Evaluation

The model is evaluated based on its ability to correctly solve arithmetic problems. The evaluation metrics include average reward and the number of valid samples processed.
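For a quick programmatic check outside of `test.py`, a minimal sketch along these lines can call the evaluation helpers exported by `train.py` directly. It assumes a trained checkpoint `o1_model.pth` exists and that the default hyperparameters listed above were used; the batch size here is illustrative.

```python
# Illustrative evaluation sketch; O1Model, vocab_size and evaluate_model are
# defined in train.py, and the hyperparameters mirror the defaults listed above.
import torch
from train import O1Model, vocab_size, evaluate_model

model = O1Model(vocab_size, d_model=128, nhead=8, num_layers=4)
model.load_state_dict(torch.load("o1_model.pth"), strict=False)

metrics = evaluate_model(model, batch_size=32)
print(metrics)  # {"average_reward": ..., "valid_samples": ...}
```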
## Limitations and Future Work

1. **Scale**: This implementation is smaller than the actual O1 model described by OpenAI.
2. **Task Diversity**: Currently focused on arithmetic; could be expanded to more diverse tasks.
3. **Self-Correction**: The self-correction mechanism could be made more sophisticated.
4. **Dynamic Curriculum**: The difficulty adjustment could be more adaptive to the model's performance.

## Contributing

Contributions to improve the model or expand its capabilities are welcome. Please submit pull requests or open issues for any bugs or feature requests.

## Acknowledgements

This implementation is inspired by OpenAI's research on the O1 model. It is a simplified version of that work.
--------------------------------------------------------------------------------
/__pycache__/train.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/O1-nano/a01c1a83b2c6ed945639d4b4813bd0eb7c0dabd9/__pycache__/train.cpython-310.pyc
--------------------------------------------------------------------------------
/image.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/O1-nano/a01c1a83b2c6ed945639d4b4813bd0eb7c0dabd9/image.webp
--------------------------------------------------------------------------------
/o1-paper.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/O1-nano/a01c1a83b2c6ed945639d4b4813bd0eb7c0dabd9/o1-paper.pdf
--------------------------------------------------------------------------------
/o1_model.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/O1-nano/a01c1a83b2c6ed945639d4b4813bd0eb7c0dabd9/o1_model.pth
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import torch
from train import O1Model, vocab, tokenize, detokenize, vocab_size

def load_model(model_path):
    # Load the state dict
    state_dict = torch.load(model_path)

    # Infer model parameters from the state dict
    d_model = state_dict['embed.weight'].shape[1]
    num_layers = max([int(key.split('.')[1]) for key in state_dict.keys() if key.startswith('transformer_layers.')]) + 1
    nhead = state_dict['transformer_layers.0.self_attn.in_proj_weight'].shape[0] // (3 * d_model)

    print(f"Inferred model parameters: d_model={d_model}, num_layers={num_layers}, nhead={nhead}")

    # Create the model with inferred parameters
    model = O1Model(vocab_size, d_model, nhead, num_layers)

    # Load the state dict
    model.load_state_dict(state_dict, strict=False)
    model.eval()
    return model

def chat_with_model(model):
    print("Welcome to the O1 Model Arithmetic Solver!")
    print("You can ask arithmetic questions like:")
    print("- Calculate the sum of 5 and 7")
    print("- Calculate the difference between 15 and 8")
    print("- Calculate the product of 6 and 4")
    print("- Calculate the quotient of 20 and 5")
    print("Type 'quit' to exit.")

    while True:
        user_input = input("\nEnter your question: ")
        if user_input.lower() == 'quit':
            break

        input_ids = torch.tensor([tokenize(user_input)])
        completion_tokens, reasoning_tokens, subtasks = model.generate_completion(input_ids, max_new_tokens=50)

        print("\nModel's thought process:")
        print("Reasoning:", detokenize(reasoning_tokens))
        print("Subtasks:")
        for i, subtask in enumerate(subtasks, 1):
            print(f"  {i}. {detokenize(subtask)}")

        print("\nModel's response:")
        print(detokenize(completion_tokens))

if __name__ == "__main__":
    model_path = "o1_model.pth"  # Make sure this path is correct
    try:
        model = load_model(model_path)
        print(f"Model loaded successfully. Number of layers: {len(model.transformer_layers)}")
        chat_with_model(model)
    except FileNotFoundError:
        print(f"Error: Model file '{model_path}' not found.")
        print("Make sure you have trained the model and saved it with the correct filename.")
    except Exception as e:
        print(f"An error occurred: {e}")
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
import math
import random

# Constants
CONTEXT_WINDOW_SIZE = 128000
MAX_OUTPUT_TOKENS_PREVIEW = 32768
MAX_OUTPUT_TOKENS_MINI = 65536

# Set random seeds for reproducibility
torch.manual_seed(0)
random.seed(0)

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

class TransformerBlock(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
        super(TransformerBlock, self).__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Linear(dim_feedforward, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Ensure x has the correct shape (batch_size, seq_len, d_model)
        if x.dim() == 2:
            x = x.unsqueeze(0)  # Add batch dimension if missing
        elif x.dim() == 4:
            x = x.squeeze(2)  # Remove extra dimension if present

        attn_output, _ = self.self_attn(x, x, x)
        x = x + self.dropout(attn_output)
        x = self.norm1(x)
        ff_output = self.feed_forward(x)
        x = x + self.dropout(ff_output)
        x = self.norm2(x)
        return x

class O1Model(nn.Module):
    def __init__(self, vocab_size, d_model, nhead, num_layers, is_mini=False):
        super(O1Model, self).__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        self.transformer_layers = nn.ModuleList([TransformerBlock(d_model, nhead) for _ in range(num_layers)])
        self.completion_decoder = nn.Linear(d_model, vocab_size)
        self.reasoning_decoder = nn.Linear(d_model, vocab_size)
        self.value_head = nn.Linear(d_model, 1)
        self.subtask_head = nn.Linear(d_model, 1)
        self.is_mini = is_mini
        self.max_reasoning_tokens = 1000

    def forward(self, src, reasoning_tokens=None, generate_reasoning=True):
        if src.dim() == 1:
            src = src.unsqueeze(0)
        elif src.dim() == 3:
            src = src.squeeze(1)

        if src.size(1) == 0:
            print(f"Warning: Empty input tensor in forward pass. Shape: {src.shape}")
            batch_size = src.size(0)
            return torch.zeros(batch_size, 1, self.vocab_size), torch.zeros(batch_size, 1, self.vocab_size), torch.zeros(batch_size, 1)

        src = self.embed(src)
        if reasoning_tokens is not None:
            reasoning_embeddings = self.embed(reasoning_tokens)
            if reasoning_embeddings.dim() == 2:
                # The reasoning tokens are kept as a 1-D buffer; add a batch dimension before concatenation
                reasoning_embeddings = reasoning_embeddings.unsqueeze(0)
            src = torch.cat([src, reasoning_embeddings], dim=1)

        src = self.pos_encoder(src)

        for layer in self.transformer_layers:
            src = layer(src)

        completion_logits = self.completion_decoder(src)
        values = self.value_head(src).squeeze(-1)

        if generate_reasoning:
            reasoning_logits = self.reasoning_decoder(src)
            return completion_logits, reasoning_logits, values
        else:
            return completion_logits, values

    def generate_completion(self, input_ids, max_new_tokens, num_paths=3):
        max_tokens = MAX_OUTPUT_TOKENS_MINI if self.is_mini else MAX_OUTPUT_TOKENS_PREVIEW
        max_new_tokens = min(max_new_tokens, max_tokens)

        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)
        elif input_ids.dim() == 3:
            input_ids = input_ids.squeeze(1)

        paths = []
        for _ in range(num_paths):
            generated = input_ids.clone()
            reasoning_tokens = torch.tensor([], dtype=torch.long, device=input_ids.device)
            completion_tokens = []
            subtasks = []

            for _ in range(max_new_tokens):
                if generated.size(1) + reasoning_tokens.size(0) >= CONTEXT_WINDOW_SIZE:
                    break

                completion_logits, reasoning_logits, values = self(generated, reasoning_tokens)

                if completion_logits.numel() == 0:
                    print(f"Warning: completion_logits is empty. Input shape: {generated.shape}")
                    break

                next_token_logits = completion_logits[:, -1, :]
                next_token = self.sample_token(next_token_logits)

                reasoning_token = self.sample_token(reasoning_logits[:, -1, :])
                reasoning_tokens = torch.cat([reasoning_tokens, reasoning_token])

                if reasoning_tokens.size(0) > self.max_reasoning_tokens:
                    reasoning_tokens = reasoning_tokens[-self.max_reasoning_tokens:]

                last_hidden = self.embed(generated[:, -1])
                subtask_prob = torch.sigmoid(self.subtask_head(last_hidden))
                if subtask_prob > 0.5:
                    subtask = self.generate_subtask(generated, reasoning_tokens)
                    subtasks.append(subtask)
                    generated = torch.cat([generated, torch.tensor([[vocab['<SUBTASK>']]]).to(generated.device)], dim=1)
                else:
                    generated = torch.cat([generated, next_token.unsqueeze(1)], dim=1)
                    completion_tokens.append(next_token.item())

                if self.should_revise_reasoning():
                    generated, reasoning_tokens = self.revise_reasoning(generated, reasoning_tokens)

                if next_token.item() == vocab['<EOS>']:
                    break

            paths.append((completion_tokens, reasoning_tokens.tolist(), subtasks))

        if not paths:
            print("Warning: No valid paths generated")
            return [], [], []

        rewards = [self.compute_reward(p[0], p[1], p[2]) for p in paths]
        best_path = paths[rewards.index(max(rewards))]

        return best_path[0], best_path[1], best_path[2]

    def sample_token(self, logits, temperature=0.7):
        probs = F.softmax(logits / temperature, dim=-1)
        return torch.multinomial(probs, 1).squeeze(-1)

    def add_reasoning_token(self, token):
        self.reasoning_buffer.append(token)
        if len(self.reasoning_buffer) > self.max_reasoning_tokens:
            self.reasoning_buffer.pop(0)

    def should_revise_reasoning(self):
        # Implement logic to decide if reasoning should be revised
        return random.random() < 0.1  # 10% chance of revision for demonstration

    def revise_reasoning(self, generated, reasoning_tokens):
        # Implement logic to revise reasoning
        # For demonstration, we'll just remove the last few tokens from both
        return generated[:, :-5], reasoning_tokens[:-5]

    def generate_subtask(self, context, reasoning_tokens):
        subtask_tokens = []
        for _ in range(20):  # Max subtask length
            logits, _, _ = self(context, reasoning_tokens)
            next_token = torch.argmax(logits[:, -1, :], dim=-1)
            subtask_tokens.append(next_token.item())
            context = torch.cat([context, next_token.unsqueeze(1)], dim=1)
            if next_token.item() == vocab['<EOS>']:
                break
        return subtask_tokens

    def compute_reward(self, completion_tokens, reasoning_tokens, subtasks):
        completion_reward = len(completion_tokens) * 0.1
        reasoning_reward = len(set(reasoning_tokens)) * 0.2
        subtask_reward = len(subtasks) * 0.5
        coherence_reward = self.compute_coherence(completion_tokens)
        process_reward = self.compute_process_reward(reasoning_tokens)
        return completion_reward + reasoning_reward + subtask_reward + coherence_reward + process_reward

    def compute_coherence(self, tokens):
        # Simple coherence check (can be made more sophisticated)
        return sum(1 for i in range(len(tokens)-1) if tokens[i] + 1 == tokens[i+1]) * 0.1

    def compute_process_reward(self, reasoning_tokens):
        # Implement a more sophisticated process reward
        unique_tokens = len(set(reasoning_tokens))
        return unique_tokens * 0.1  # Reward diverse reasoning

class PPO:
    def __init__(self, model, optimizer, clip_epsilon=0.2, value_coef=0.5, entropy_coef=0.01):
        self.model = model
        self.optimizer = optimizer
        self.clip_epsilon = clip_epsilon
        self.value_coef = value_coef
        self.entropy_coef = entropy_coef

    def compute_advantages(self, rewards, values, gamma=0.99, lambda_=0.95):
        advantages = torch.zeros_like(rewards)
        last_advantage = 0

        # Make sure to only iterate through the valid range
        for t in reversed(range(len(rewards))):
            if t + 1 < len(values):
                delta = rewards[t] + gamma * values[t + 1] - values[t]
            else:
                delta = rewards[t] - values[t]

            advantages[t] = delta + gamma * lambda_ * last_advantage
            last_advantage = advantages[t]

        returns = advantages + values[:len(advantages)]
        return advantages, returns

    def update(self, states, actions, old_log_probs, rewards, old_values):
        # Reshape states if necessary
        if states.dim() == 2:
            batch_size, seq_len = states.shape
            states = states.unsqueeze(0)  # Add a dimension to make it [1, batch_size, seq_len]
        else:
            num_steps, batch_size, seq_len = states.shape

        # Flatten other tensors
        actions_flat = actions.view(-1)
        old_log_probs_flat = old_log_probs.view(-1)
        advantages, returns = self.compute_advantages(rewards, old_values)
        advantages_flat = advantages.view(-1)
        returns_flat = returns.view(-1)

        for _ in range(5):  # PPO epochs
            logits, _, values = self.model(states.view(-1, seq_len))

            # Focus on the logits of the last token in the sequence
            next_token_logits = logits[:, -1, :]
            new_probs = F.softmax(next_token_logits, dim=-1)
            dist = Categorical(new_probs)

            # Ensure actions_flat matches the shape of new_probs
            actions_flat_truncated = actions_flat[:new_probs.size(0)]
            old_log_probs_flat_truncated = old_log_probs_flat[:new_probs.size(0)]
            advantages_flat_truncated = advantages_flat[:new_probs.size(0)]
            returns_flat_truncated = returns_flat[:new_probs.size(0)]

            # Calculate new log probabilities
            new_log_probs = dist.log_prob(actions_flat_truncated)

            # Calculate probability ratio
            ratio = torch.exp(new_log_probs - old_log_probs_flat_truncated)
            surr1 = ratio * advantages_flat_truncated
            surr2 = torch.clamp(ratio, 1 - self.clip_epsilon, 1 + self.clip_epsilon) * advantages_flat_truncated

            # Compute losses
            actor_loss = -torch.min(surr1, surr2).mean()

            # Extract the value of the last token in each sequence
            values_last = values[:, -1].view(-1)
            critic_loss = nn.MSELoss()(values_last, returns_flat_truncated)

            entropy = dist.entropy().mean()

            # Total loss
            loss = actor_loss + self.value_coef * critic_loss - self.entropy_coef * entropy

            # Backpropagation
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

# Enhanced vocabulary
vocab = {
    '<PAD>': 0, '<EOS>': 1, '<UNK>': 2, 'Step:': 3, '+': 4, '-': 5, '*': 6, '/': 7, '=': 8,
    '0': 9, '1': 10, '2': 11, '3': 12, '4': 13, '5': 14, '6': 15, '7': 16, '8': 17, '9': 18,
    'if': 19, 'then': 20, 'else': 21, 'greater': 22, 'less': 23, 'equal': 24,
    'Calculate': 25, 'the': 26, 'sum': 27, 'of': 28, 'and': 29,
    'difference': 30, 'between': 31, 'product': 32, 'quotient': 33,
    'First,': 34, 'Next,': 35, 'Finally,': 36, 'result': 37, 'is': 38,
    '<SUBTASK>': 39  # New token for subtask generation
}
vocab_size = len(vocab)
inv_vocab = {v: k for k, v in vocab.items()}

def tokenize(text):
    return [vocab.get(token, vocab['<UNK>']) for token in text.strip().split()]

def detokenize(indices):
    return ' '.join([inv_vocab.get(idx, ' ') for idx in indices])

# Update the compute_reward function
def compute_reward(state, target_result):
    generated_tokens = state.cpu().numpy()
    rewards = []
    for tokens in generated_tokens:
        try:
            generated_text = detokenize(tokens)
            if "result is" in generated_text:
                result_str = generated_text.split("result is")[-1].strip()
                result = int(result_str) if result_str.isdigit() else float(result_str)
                if abs(result - target_result) < 1e-6:  # Allow for small floating-point differences
                    rewards.append(1.0)
                elif abs(result - target_result) < 5:  # Close answer
                    rewards.append(0.5)
                elif abs(result - target_result) < 10:  # Somewhat close answer
                    rewards.append(0.2)
                else:
                    rewards.append(-0.2)
            else:
                rewards.append(0.0)  # Neutral reward for incomplete answers
        except:
            rewards.append(-0.5)  # Penalize malformed outputs
    return torch.tensor(rewards)

# Generate arithmetic problems
def generate_arithmetic_problem():
    operations = ['+', '-', '*', '/']
    op = random.choice(operations)

    while True:
        if op in ['+', '-']:
            a, b = random.randint(1, 100), random.randint(1, 100)
        else:
            a, b = random.randint(1, 10), random.randint(1, 10)

        if op == '+':
            result = a + b
            problem = f"Calculate the sum of {a} and {b}"
        elif op == '-':
            result = a - b
            problem = f"Calculate the difference between {a} and {b}"
        elif op == '*':
            result = a * b
            problem = f"Calculate the product of {a} and {b}"
        else:
            if b != 0:  # Avoid division by zero
                result = a // b
                problem = f"Calculate the quotient of {a} and {b}"
            else:
                continue  # Try again if b is zero

        if problem and result is not None:
            return problem, result

# Generate reasoning chain
def generate_reasoning_chain(problem, result):
    words = problem.split()
    operation = words[2]  # "sum", "difference", "product", or "quotient"

    if operation == "sum":
        a, b = map(int, words[-3::2])
        chain = f"Step: First, we identify the numbers: {a} and {b}. "
        chain += f"Next, we add these numbers: {a} + {b}. "
        chain += f"Finally, we get the result: The sum is {result}."
    elif operation == "difference":
        a, b = map(int, words[-3::2])
        chain = f"Step: First, we identify the numbers: {a} and {b}. "
        chain += f"Next, we subtract the second number from the first: {a} - {b}. "
        chain += f"Finally, we get the result: The difference is {result}."
    elif operation == "product":
        a, b = map(int, words[-3::2])
        chain = f"Step: First, we identify the numbers: {a} and {b}. "
        chain += f"Next, we multiply these numbers: {a} * {b}. "
        chain += f"Finally, we get the result: The product is {result}."
    else:  # quotient
        a, b = map(int, words[-3::2])
        chain = f"Step: First, we identify the numbers: {a} and {b}. "
        chain += f"Next, we divide the first number by the second: {a} / {b}. "
        chain += f"Finally, we get the result: The quotient is {result}."

    return chain

# Modify collect_trajectories to use arithmetic problems
def collect_trajectories(model, batch_size):
    states = []
    actions = []
    rewards = []
    log_probs = []
    values = []

    max_state_length = 40

    for _ in range(batch_size):
        problem, result = generate_arithmetic_problem()
        reasoning_chain = generate_reasoning_chain(problem, result)

        input_ids = torch.tensor([tokenize(problem)])
        target_ids = torch.tensor([tokenize(reasoning_chain)])

        state = input_ids
        action_sequence = torch.full((1, max_state_length), vocab['<PAD>'], dtype=torch.long)

        for t in range(max_state_length):
            if state.size(1) > max_state_length:
                state = state[:, :max_state_length]
            elif state.size(1) < max_state_length:
                padding = torch.full((1, max_state_length - state.size(1)), vocab['<PAD>'], dtype=state.dtype)
                state = torch.cat([state, padding], dim=1)

            with torch.no_grad():
                logits, _, value = model(state)
                probs = F.softmax(logits[:, -1, :], dim=-1)
                dist = Categorical(probs)
                action = dist.sample()
                log_prob = dist.log_prob(action)

            action_sequence[0, t] = action.item()
            log_probs.append(log_prob)
            values.append(value[:, -1])

            state = torch.cat([state[:, :-1], action.unsqueeze(1)], dim=1)

            reward = compute_reward(state, result)
            rewards.append(reward)

            if action.item() == vocab['<EOS>']:
                break

        states.append(state)
        actions.append(action_sequence)

    states = torch.cat(states, dim=0)
    actions = torch.cat(actions, dim=0)
    rewards = torch.cat(rewards, dim=0)
    log_probs = torch.cat(log_probs, dim=0)
    values = torch.cat(values, dim=0)

    return states, actions, rewards, log_probs, values

# Update the training function
def train_o1_model(model, optimizer, num_epochs, batch_size):
    ppo = PPO(model, optimizer)

    for epoch in range(num_epochs):
        # Generate a batch of arithmetic problems
        states, actions, rewards, old_log_probs, values = collect_trajectories(model, batch_size)

        # Supervised learning step
        sl_loss = supervised_finetuning_loss(model, (states, actions))
        optimizer.zero_grad()
        sl_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Reinforcement learning step
        ppo.update(states, actions, old_log_probs, rewards, values)

        # Evaluation and logging
        if epoch % 10 == 0:
            metrics = evaluate_model(model, batch_size)
            log_metrics(metrics, epoch)

        print(f'Epoch {epoch} completed')

        # Dynamic curriculum learning
        if epoch % 50 == 0:
            adjust_problem_difficulty(epoch)

def log_metrics(metrics, epoch):
    print(f"Epoch {epoch} Metrics: {metrics}")

def supervised_finetuning_loss(model, batch):
    states, actions = batch
    logits, _ = model(states, generate_reasoning=False)

    # Reshape logits to [batch_size * sequence_length, vocab_size]
    batch_size, seq_length, vocab_size = logits.shape
    logits = logits.view(-1, vocab_size)

    # Reshape actions to [batch_size * sequence_length]
    target_ids = actions.view(-1)
    # Ensure logits and target_ids have the same length
    min_length = min(logits.size(0), target_ids.size(0))
    logits = logits[:min_length]
    target_ids = target_ids[:min_length]

    # Compute loss only on non-padded tokens
    non_pad_mask = target_ids != vocab['<PAD>']
    logits = logits[non_pad_mask]
    target_ids = target_ids[non_pad_mask]

    loss = F.cross_entropy(logits, target_ids)
    return loss

# Update evaluation function
def evaluate_model(model, batch_size):
    model.eval()
    total_reward = 0
    valid_samples = 0
    with torch.no_grad():
        for _ in range(batch_size):
            try:
                problem, result = generate_arithmetic_problem()
                input_ids = torch.tensor([tokenize(problem)])
                if input_ids.numel() == 0:
                    print(f"Warning: Empty input tensor for problem: {problem}")
                    continue
                completion_tokens, reasoning_tokens, subtasks = model.generate_completion(input_ids, max_new_tokens=50)
                if completion_tokens:
                    reward = compute_reward(torch.tensor([completion_tokens]), result)
                    total_reward += reward.item()
                    valid_samples += 1
                else:
                    print(f"Warning: Empty output for problem: {problem}")
            except Exception as e:
                print(f"Error during evaluation: {e}")
    model.train()
    avg_reward = total_reward / valid_samples if valid_samples > 0 else 0
    return {"average_reward": avg_reward, "valid_samples": valid_samples}

def adjust_problem_difficulty(epoch):
    # Implement dynamic difficulty adjustment based on model performance
    global problem_difficulty
    if epoch < 100:
        problem_difficulty = "easy"
    elif epoch < 300:
        problem_difficulty = "medium"
    else:
        problem_difficulty = "hard"

if __name__ == "__main__":
    # Model parameters
    d_model = 128
    nhead = 8
    num_layers = 4
    dropout = 0.1

    # Initialize the model
    model = O1Model(vocab_size, d_model, nhead, num_layers)
    optimizer = optim.Adam(model.parameters(), lr=5e-4)

    # Training parameters
    num_epochs = 500
    batch_size = 64

    # Train the model
    train_o1_model(model, optimizer, num_epochs, batch_size)

    # Save the model
    torch.save(model.state_dict(), "o1_model.pth")
--------------------------------------------------------------------------------