├── LICENSE ├── README.md ├── config └── default.json ├── docs └── api.md ├── examples ├── azure_bot_integration.py ├── gradio_integration.py ├── langchain_integration.py ├── llama_cpp_integration.py ├── notebooks │ ├── calculus_examples.ipynb │ ├── geometry_examples.ipynb │ ├── integration_examples.ipynb │ ├── linear_algebra_examples.ipynb │ ├── model_comparison.ipynb │ └── statistics_examples.ipynb ├── rasa_integration.py └── streamlit_integration.py ├── rStar-Math Paper.pdf ├── requirements.txt ├── src ├── __init__.py ├── api │ └── main.py ├── core │ ├── mcts.py │ └── ppm.py ├── dashboard │ └── app.py ├── models │ ├── gemini_model.py │ ├── groq_model.py │ ├── mistral_model.py │ └── model_interface.py └── utils │ └── helpers.py ├── tests ├── test_api.py ├── test_core.py ├── test_integration.py ├── test_models.py └── test_new_models.py └── tools └── benchmark.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 AI in PM 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # rStar-Math Demonstrator 2 | 3 | An AI Agent that demonstrates the principles and performance of the rStar-Math framework, with capabilities to generate integration code for other chatbots and AI agents. 4 | 5 | The development of this GitHub Repository was inspired by the "rStar-Math: Small LLMs Can Master Math Reasoning with Self-Evolved Deep Thinking" paper. 
To read the full paper, visit https://arxiv.org/pdf/2501.04519 6 | 7 | ## Features 8 | 9 | - **Core Components** 10 | - Monte Carlo Tree Search (MCTS) for step-by-step reasoning 11 | - Process Preference Model (PPM) for evaluating solution quality 12 | - Flexible model interface supporting multiple LLMs 13 | 14 | - **Model Support** 15 | - OpenAI (GPT-4, GPT-3.5) 16 | - Anthropic (Claude) 17 | - Mistral AI 18 | - Groq 19 | - Google Gemini 20 | - Local models via llama.cpp 21 | 22 | - **Integration Templates** 23 | - Rasa chatbot framework 24 | - LangChain 25 | - Azure Bot Framework 26 | - Streamlit 27 | - Gradio 28 | 29 | - **Example Notebooks** 30 | - Calculus with visualizations 31 | - Geometry and proofs 32 | - Linear algebra operations 33 | - Statistics and probability 34 | - Model comparison studies 35 | 36 | - **Development Tools** 37 | - Comprehensive test suite 38 | - Performance benchmarking 39 | - Visualization components 40 | - API documentation 41 | 42 | ## Installation 43 | 44 | ### Option 1: Install from PyPI 45 | 46 | ```bash 47 | pip install rstar-math 48 | ``` 49 | 50 | ### Option 2: Install from Source 51 | 52 | 1. Clone the repository: 53 | ```bash 54 | git clone https://github.com/yourusername/rStar-Math.git 55 | cd rStar-Math 56 | ``` 57 | 58 | 2. Create a virtual environment: 59 | ```bash 60 | python -m venv venv 61 | 62 | # Windows 63 | venv\Scripts\activate 64 | 65 | # Unix/MacOS 66 | source venv/bin/activate 67 | ``` 68 | 69 | 3. Install dependencies: 70 | ```bash 71 | pip install -r requirements.txt 72 | ``` 73 | 74 | 4. Install in development mode: 75 | ```bash 76 | pip install -e . 77 | ``` 78 | 79 | ### Setting up API Keys 80 | 81 | Create a `.env` file in the project root: 82 | ```bash 83 | OPENAI_API_KEY=your_openai_key 84 | ANTHROPIC_API_KEY=your_anthropic_key 85 | MISTRAL_API_KEY=your_mistral_key 86 | GROQ_API_KEY=your_groq_key 87 | GEMINI_API_KEY=your_gemini_key 88 | ``` 89 | 90 | ## Running the Project 91 | 92 | ### 1. Run Interactive Demos 93 | 94 | #### Gradio Interface 95 | ```bash 96 | python examples/gradio_integration.py 97 | ``` 98 | 99 | #### Streamlit Dashboard 100 | ```bash 101 | streamlit run examples/streamlit_integration.py 102 | ``` 103 | 104 | ### 2. Run Example Notebooks 105 | 106 | ```bash 107 | # Start Jupyter server 108 | jupyter lab 109 | 110 | # Navigate to examples/notebooks/ 111 | # Open any of: 112 | # - calculus_examples.ipynb 113 | # - geometry_examples.ipynb 114 | # - linear_algebra_examples.ipynb 115 | # - statistics_examples.ipynb 116 | ``` 117 | 118 | ### 3. Run Tests 119 | 120 | ```bash 121 | # Run all tests 122 | pytest tests/ 123 | 124 | # Run specific test suite 125 | pytest tests/test_new_models.py 126 | 127 | # Run with coverage report 128 | pytest --cov=src tests/ 129 | ``` 130 | 131 | ### 4. Run Benchmarks 132 | 133 | ```bash 134 | # Run full benchmark suite 135 | python tools/benchmark.py 136 | 137 | # View results in browser 138 | python -m http.server 8000 139 | # Open http://localhost:8000/benchmark_results/ 140 | ``` 141 | 142 | ### 5. 
Framework Integrations 143 | 144 | #### Rasa Integration 145 | ```bash 146 | # In your Rasa project 147 | pip install rstar-math 148 | cp examples/rasa_integration.py actions/ 149 | ``` 150 | 151 | #### LangChain Integration 152 | ```python 153 | from examples.langchain_integration import RStarMathChain 154 | chain = RStarMathChain() 155 | ``` 156 | 157 | #### Azure Bot Integration 158 | ```bash 159 | # In your Azure Bot project 160 | pip install rstar-math 161 | cp examples/azure_bot_integration.py bot/ 162 | ``` 163 | 164 | ### 6. Local Model Setup 165 | 166 | 1. Download a compatible model: 167 | ```bash 168 | # Example: Download LLaMA model 169 | wget https://huggingface.co/models/llama-7b/resolve/main/model.bin -O models/llama-7b.bin 170 | ``` 171 | 172 | 2. Run with local model: 173 | ```python 174 | from examples.llama_cpp_integration import LlamaCppModel 175 | model = LlamaCppModel("models/llama-7b.bin") 176 | ``` 177 | 178 | ## Quick Start 179 | 180 | ```python 181 | from rstar_math.core import MCTS, PPM 182 | from rstar_math.models import ModelFactory 183 | 184 | # Initialize components 185 | mcts = MCTS.from_config_file('config/default.json') 186 | ppm = ProcessPreferenceModel.from_config_file('config/default.json') 187 | model = ModelFactory.create_model('openai', 'YOUR_API_KEY', 'config/default.json') 188 | 189 | # Solve a problem 190 | problem = "What is the derivative of f(x) = x^2 + 3x?" 191 | action, trajectory = mcts.search(problem) 192 | 193 | # Print solution steps with confidence scores 194 | for step in trajectory: 195 | confidence = ppm.evaluate_step(step['state'], model) 196 | print(f"Step: {step['state']}") 197 | print(f"Confidence: {confidence:.2f}\n") 198 | ``` 199 | 200 | ## Example Applications 201 | 202 | ### 1. Interactive Web Interface 203 | 204 | ```python 205 | from examples.gradio_integration import RStarMathGradio 206 | 207 | # Launch Gradio interface 208 | demo = RStarMathGradio() 209 | demo.launch() 210 | ``` 211 | 212 | ### 2. Chatbot Integration 213 | 214 | ```python 215 | from examples.rasa_integration import RStarMathAction 216 | 217 | # Use in Rasa custom action 218 | action = RStarMathAction() 219 | await action.run(dispatcher, tracker, domain) 220 | ``` 221 | 222 | ### 3. Local Model Inference 223 | 224 | ```python 225 | from examples.llama_cpp_integration import LlamaCppModel 226 | 227 | # Initialize local model 228 | model = LlamaCppModel("path/to/model.bin") 229 | response = model.generate_response("What is 2 + 2?") 230 | ``` 231 | 232 | ## Documentation 233 | 234 | - [API Reference](docs/api.md) 235 | - [Model Integration Guide](docs/model_integration.md) 236 | - [Example Notebooks](examples/notebooks/) 237 | - [Benchmark Results](docs/benchmarks.md) 238 | 239 | ## Benchmarking 240 | 241 | Run performance benchmarks: 242 | 243 | ```bash 244 | python tools/benchmark.py 245 | ``` 246 | 247 | This will generate: 248 | - Execution time comparisons 249 | - Memory usage analysis 250 | - Token count statistics 251 | - Confidence score trends 252 | 253 | ## Contributing 254 | 255 | 1. Fork the repository 256 | 2. Create your feature branch 257 | 3. Run tests: `pytest tests/` 258 | 4. 
Submit a pull request 259 | 260 | ## License 261 | 262 | MIT License - see LICENSE file for details 263 | 264 | ## Citation 265 | 266 | If you use rStar-Math in your research, please cite: 267 | 268 | ```bibtex 269 | @article{rstar2024, 270 | title={rStar-Math: Small LLMs Can Master Math Reasoning with Self-Evolved Deep Thinking}, 271 | author={Original Authors}, 272 | journal={arXiv preprint}, 273 | year={2024} 274 | } 275 | ``` 276 | 277 | ## Acknowledgments 278 | 279 | This project is inspired by the paper "rStar-Math: Small LLMs Can Master Math Reasoning with Self-Evolved Deep Thinking" (https://arxiv.org/pdf/2501.04519). 280 | 281 | ## Project Structure 282 | 283 | ``` 284 | rStar-Math/ 285 | ├── src/ # Source code 286 | │ ├── core/ # Core components (MCTS, PPM) 287 | │ ├── models/ # Model implementations 288 | │ └── utils/ # Utility functions 289 | ├── tests/ # Test suites 290 | ├── examples/ # Example integrations 291 | │ ├── notebooks/ # Jupyter notebooks 292 | │ └── frameworks/ # Framework integrations 293 | ├── docs/ # Documentation 294 | ├── tools/ # Development tools 295 | └── config/ # Configuration files 296 | ``` 297 | 298 | ## Development Workflow 299 | 300 | 1. Create a new feature branch: 301 | ```bash 302 | git checkout -b feature/your-feature-name 303 | ``` 304 | 305 | 2. Make changes and run tests: 306 | ```bash 307 | # Format code 308 | black src/ tests/ 309 | 310 | # Run linter 311 | flake8 src/ tests/ 312 | 313 | # Run tests 314 | pytest tests/ 315 | ``` 316 | 317 | 3. Submit a pull request: 318 | ```bash 319 | git add . 320 | git commit -m "feat: your feature description" 321 | git push origin feature/your-feature-name 322 | ``` 323 | 324 | ## Troubleshooting 325 | 326 | ### Common Issues 327 | 328 | 1. API Key Issues: 329 | ```bash 330 | # Check if keys are loaded 331 | python -c "import os; print(os.getenv('OPENAI_API_KEY'))" 332 | ``` 333 | 334 | 2. Model Loading Issues: 335 | ```bash 336 | # Verify model files 337 | ls models/ 338 | ``` 339 | 340 | 3. CUDA Issues: 341 | ```bash 342 | # Check CUDA availability 343 | python -c "import torch; print(torch.cuda.is_available())" 344 | ``` 345 | 346 | For more issues, check the [troubleshooting guide](docs/troubleshooting.md). 
347 | -------------------------------------------------------------------------------- /config/default.json: -------------------------------------------------------------------------------- 1 | { 2 | "mcts": { 3 | "exploration_weight": 1.0, 4 | "max_simulations": 1000, 5 | "max_depth": 10 6 | }, 7 | "ppm": { 8 | "input_dim": 768, 9 | "hidden_dim": 256, 10 | "learning_rate": 0.001, 11 | "batch_size": 32 12 | }, 13 | "models": { 14 | "openai": { 15 | "model": "gpt-4", 16 | "temperature": 0.7, 17 | "max_tokens": 1000 18 | }, 19 | "anthropic": { 20 | "model": "claude-2", 21 | "temperature": 0.7, 22 | "max_tokens": 1000 23 | }, 24 | "mistral": { 25 | "model": "mistral-large", 26 | "temperature": 0.7, 27 | "max_tokens": 1000 28 | }, 29 | "groq": { 30 | "model": "mixtral-8x7b", 31 | "temperature": 0.7, 32 | "max_tokens": 1000 33 | }, 34 | "gemini": { 35 | "model": "gemini-pro", 36 | "temperature": 0.7, 37 | "max_tokens": 1000 38 | }, 39 | "cohere": { 40 | "model": "command", 41 | "temperature": 0.7, 42 | "max_tokens": 1000 43 | } 44 | }, 45 | "api": { 46 | "host": "0.0.0.0", 47 | "port": 8000, 48 | "debug": true 49 | }, 50 | "dashboard": { 51 | "host": "0.0.0.0", 52 | "port": 8050, 53 | "debug": true 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /docs/api.md: -------------------------------------------------------------------------------- 1 | # rStar-Math Demonstrator API Documentation 2 | 3 | ## Overview 4 | 5 | The rStar-Math Demonstrator API provides endpoints for solving mathematical problems using various Language Models (LLMs) enhanced with the rStar-Math framework. The API supports model comparison, code generation for integration, and detailed performance analysis. 6 | 7 | ## Base URL 8 | 9 | ``` 10 | http://localhost:8000 11 | ``` 12 | 13 | ## Authentication 14 | 15 | All API endpoints require an API key to be passed in the request headers: 16 | 17 | ``` 18 | Authorization: Bearer YOUR_API_KEY 19 | ``` 20 | 21 | ## Endpoints 22 | 23 | ### 1. Solve Problem 24 | 25 | Solve a mathematical problem using rStar-Math enhanced reasoning. 26 | 27 | **Endpoint:** `/solve` 28 | **Method:** `POST` 29 | 30 | **Request Body:** 31 | ```json 32 | { 33 | "problem_text": "What is 2 + 2?", 34 | "model_name": "gpt-4", 35 | "use_rstar": true, 36 | "mcts_simulations": 1000, 37 | "temperature": 0.7 38 | } 39 | ``` 40 | 41 | **Parameters:** 42 | - `problem_text` (string, required): The mathematical problem to solve 43 | - `model_name` (string, required): Name of the LLM to use (e.g., "gpt-4", "claude-2") 44 | - `use_rstar` (boolean, optional): Whether to use rStar-Math enhancement (default: true) 45 | - `mcts_simulations` (integer, optional): Number of MCTS simulations (default: 1000) 46 | - `temperature` (float, optional): Model temperature (default: 0.7) 47 | 48 | **Response:** 49 | ```json 50 | { 51 | "solution_steps": [ 52 | "First, we identify this is an addition problem", 53 | "Then, we add 2 and 2 together", 54 | "Therefore, 2 + 2 = 4" 55 | ], 56 | "confidence_score": 0.95, 57 | "reasoning_path": [ 58 | { 59 | "state": "Initial problem analysis", 60 | "action": "Identify operation", 61 | "value": 0.8, 62 | "visits": 10 63 | } 64 | ], 65 | "execution_time": 1.23 66 | } 67 | ``` 68 | 69 | ### 2. Compare Models 70 | 71 | Compare problem-solving performance across different LLMs. 
72 | 73 | **Endpoint:** `/compare-models` 74 | **Method:** `POST` 75 | 76 | **Request Body:** 77 | ```json 78 | { 79 | "problem_text": "What is 2 + 2?", 80 | "model_name": "all", 81 | "use_rstar": true, 82 | "mcts_simulations": 1000, 83 | "temperature": 0.7 84 | } 85 | ``` 86 | 87 | **Parameters:** 88 | - Same as `/solve` endpoint 89 | - `model_name` can be "all" to compare all available models 90 | 91 | **Response:** 92 | ```json 93 | { 94 | "openai": { 95 | "solution": "4", 96 | "score": 0.95, 97 | "execution_time": 1.23 98 | }, 99 | "anthropic": { 100 | "solution": "4", 101 | "score": 0.92, 102 | "execution_time": 1.45 103 | } 104 | } 105 | ``` 106 | 107 | ### 3. Generate Integration Code 108 | 109 | Generate code for integrating rStar-Math into other frameworks. 110 | 111 | **Endpoint:** `/generate-integration-code` 112 | **Method:** `POST` 113 | 114 | **Request Body:** 115 | ```json 116 | { 117 | "framework": "rasa", 118 | "config": { 119 | "model": "gpt-4", 120 | "use_rstar": true 121 | } 122 | } 123 | ``` 124 | 125 | **Parameters:** 126 | - `framework` (string, required): Target framework ("rasa", "azure", "langchain") 127 | - `config` (object, required): Framework-specific configuration 128 | 129 | **Response:** 130 | ```json 131 | { 132 | "code": "# Generated integration code...", 133 | "dependencies": [ 134 | "rstar-math>=1.0.0", 135 | "rasa>=3.0.0" 136 | ], 137 | "instructions": "Setup and usage instructions..." 138 | } 139 | ``` 140 | 141 | ## Error Handling 142 | 143 | The API uses standard HTTP status codes: 144 | 145 | - 200: Success 146 | - 400: Bad Request 147 | - 401: Unauthorized 148 | - 422: Validation Error 149 | - 500: Internal Server Error 150 | 151 | Error responses include a detail message: 152 | 153 | ```json 154 | { 155 | "detail": "Error message describing what went wrong" 156 | } 157 | ``` 158 | 159 | ## Rate Limiting 160 | 161 | - 100 requests per minute per API key 162 | - 1000 requests per day per API key 163 | 164 | ## Examples 165 | 166 | ### Python Example 167 | 168 | ```python 169 | import requests 170 | 171 | api_key = "YOUR_API_KEY" 172 | headers = {"Authorization": f"Bearer {api_key}"} 173 | 174 | # Solve a problem 175 | response = requests.post( 176 | "http://localhost:8000/solve", 177 | headers=headers, 178 | json={ 179 | "problem_text": "What is 2 + 2?", 180 | "model_name": "gpt-4", 181 | "use_rstar": true 182 | } 183 | ) 184 | 185 | print(response.json()) 186 | 187 | # Compare models 188 | response = requests.post( 189 | "http://localhost:8000/compare-models", 190 | headers=headers, 191 | json={ 192 | "problem_text": "What is 2 + 2?", 193 | "model_name": "all" 194 | } 195 | ) 196 | 197 | print(response.json()) 198 | ``` 199 | 200 | ## Support 201 | 202 | For issues, feature requests, or questions, please: 203 | 1. Check the [GitHub Issues](https://github.com/your-repo/rstar-math/issues) 204 | 2. Create a new issue if needed 205 | 3. 
Contact support at support@rstar-math.com 206 | -------------------------------------------------------------------------------- /examples/azure_bot_integration.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example integration with Azure Bot Framework 3 | """ 4 | from botbuilder.core import TurnContext, ActivityHandler 5 | from botbuilder.schema import Activity 6 | import asyncio 7 | from typing import List, Dict, Any 8 | 9 | from src.core.mcts import MCTS 10 | from src.core.ppm import ProcessPreferenceModel 11 | from src.models.model_interface import ModelFactory 12 | 13 | class RStarMathBot(ActivityHandler): 14 | def __init__(self, api_key: str = "YOUR_API_KEY", config_path: str = "config/default.json"): 15 | super().__init__() 16 | 17 | # Initialize rStar-Math components 18 | self.mcts = MCTS.from_config_file(config_path) 19 | self.ppm = ProcessPreferenceModel.from_config_file(config_path) 20 | self.model = ModelFactory.create_model( 21 | "openai", 22 | api_key, 23 | config_path 24 | ) 25 | 26 | async def on_message_activity(self, turn_context: TurnContext): 27 | """Handle incoming messages.""" 28 | # Get the math problem from user message 29 | problem = turn_context.activity.text 30 | 31 | try: 32 | # Solve using rStar-Math 33 | action, trajectory = self.mcts.search(problem) 34 | 35 | # Format solution steps 36 | steps = [] 37 | for step in trajectory: 38 | step_text = step["state"] 39 | step_score = self.ppm.evaluate_step(step_text, self.model) 40 | steps.append(f"{step_text} (confidence: {step_score:.2f})") 41 | 42 | # Send response 43 | await turn_context.send_activity( 44 | f"Here's how I solved it:\n" + "\n".join(steps) 45 | ) 46 | 47 | except Exception as e: 48 | await turn_context.send_activity( 49 | f"I encountered an error while solving the problem: {str(e)}" 50 | ) 51 | 52 | async def on_members_added_activity( 53 | self, 54 | members_added: List[ChannelAccount], 55 | turn_context: TurnContext 56 | ): 57 | """Welcome new users.""" 58 | for member in members_added: 59 | if member.id != turn_context.activity.recipient.id: 60 | await turn_context.send_activity( 61 | "Welcome! I'm an AI math tutor powered by rStar-Math. " 62 | "Send me any math problem, and I'll solve it step by step!" 
63 | ) 64 | 65 | # Example usage: 66 | """ 67 | from botbuilder.core import BotFrameworkAdapter, BotFrameworkAdapterSettings 68 | 69 | # Initialize bot 70 | SETTINGS = BotFrameworkAdapterSettings("APP_ID", "APP_PASSWORD") 71 | ADAPTER = BotFrameworkAdapter(SETTINGS) 72 | BOT = RStarMathBot(api_key="YOUR_API_KEY") 73 | 74 | # Error handler 75 | async def on_error(context: TurnContext, error: Exception): 76 | await context.send_activity("Sorry, I encountered an error!") 77 | 78 | ADAPTER.on_turn_error = on_error 79 | 80 | # Message handler 81 | async def message_handler(req, res): 82 | async def callback(turn_context): 83 | await BOT.on_turn(turn_context) 84 | 85 | await ADAPTER.process_activity(req, res, callback) 86 | """ 87 | -------------------------------------------------------------------------------- /examples/gradio_integration.py: -------------------------------------------------------------------------------- 1 | """ 2 | Gradio integration for rStar-Math 3 | """ 4 | import os 5 | import gradio as gr 6 | import numpy as np 7 | import pandas as pd 8 | import plotly.express as px 9 | from src.core.mcts import MCTS 10 | from src.core.ppm import ProcessPreferenceModel 11 | from src.models.model_interface import ModelFactory 12 | 13 | class RStarMathGradio: 14 | def __init__(self): 15 | """Initialize rStar-Math components.""" 16 | self.mcts = MCTS.from_config_file('config/default.json') 17 | self.ppm = ProcessPreferenceModel.from_config_file('config/default.json') 18 | 19 | # Initialize models 20 | self.models = {} 21 | model_keys = { 22 | 'openai': 'OPENAI_API_KEY', 23 | 'anthropic': 'ANTHROPIC_API_KEY', 24 | 'mistral': 'MISTRAL_API_KEY', 25 | 'groq': 'GROQ_API_KEY', 26 | 'gemini': 'GEMINI_API_KEY' 27 | } 28 | 29 | for name, key in model_keys.items(): 30 | api_key = os.getenv(key) 31 | if api_key: 32 | self.models[name] = ModelFactory.create_model( 33 | name, api_key, 'config/default.json' 34 | ) 35 | 36 | def solve_problem(self, 37 | problem: str, 38 | model_name: str, 39 | use_rstar: bool = True, 40 | show_visualization: bool = True): 41 | """Solve math problem and return solution with visualizations.""" 42 | model = self.models[model_name] 43 | 44 | if use_rstar: 45 | action, trajectory = self.mcts.search(problem) 46 | solution_steps = [] 47 | confidence_scores = [] 48 | 49 | for step in trajectory: 50 | confidence = self.ppm.evaluate_step(step['state'], model) 51 | solution_steps.append(step['state']) 52 | confidence_scores.append(confidence) 53 | else: 54 | solution = model.generate_response(problem) 55 | confidence = model.evaluate_reasoning(problem, [solution]) 56 | solution_steps = [solution] 57 | confidence_scores = [confidence] 58 | 59 | # Format output 60 | output = "Solution Steps:\\n" 61 | for i, (step, conf) in enumerate(zip(solution_steps, confidence_scores), 1): 62 | output += f"Step {i}: {step}\\nConfidence: {conf:.2f}\\n\\n" 63 | 64 | # Create visualization 65 | if show_visualization and len(confidence_scores) > 1: 66 | df = pd.DataFrame({ 67 | 'Step': range(1, len(confidence_scores) + 1), 68 | 'Confidence': confidence_scores 69 | }) 70 | fig = px.line(df, x='Step', y='Confidence', 71 | title='Solution Confidence Trend') 72 | 73 | return output, fig 74 | 75 | return output, None 76 | 77 | def create_examples(): 78 | """Create example problems for the demo.""" 79 | return [ 80 | ["What is 15 × 27?", "openai", True, True], 81 | ["Solve for x: 2x + 5 = 13", "openai", True, True], 82 | ["Find the derivative of f(x) = x² + 3x", "openai", True, True], 83 | ["Find the area 
of a circle with radius 5", "openai", True, True] 84 | ] 85 | 86 | def main(): 87 | """Create and launch Gradio interface.""" 88 | rstar = RStarMathGradio() 89 | 90 | # Create interface 91 | with gr.Blocks(title="rStar-Math Demonstrator") as demo: 92 | gr.Markdown("# rStar-Math Problem Solver") 93 | 94 | with gr.Row(): 95 | with gr.Column(): 96 | problem_input = gr.Textbox( 97 | label="Enter your math problem", 98 | placeholder="e.g., What is 2 + 2?" 99 | ) 100 | model_select = gr.Dropdown( 101 | choices=list(rstar.models.keys()), 102 | value=list(rstar.models.keys())[0], 103 | label="Select Model" 104 | ) 105 | use_rstar = gr.Checkbox( 106 | label="Use rStar-Math Enhancement", 107 | value=True 108 | ) 109 | show_viz = gr.Checkbox( 110 | label="Show Visualization", 111 | value=True 112 | ) 113 | solve_btn = gr.Button("Solve") 114 | 115 | with gr.Column(): 116 | solution_output = gr.Textbox( 117 | label="Solution", 118 | lines=10 119 | ) 120 | plot_output = gr.Plot(label="Confidence Trend") 121 | 122 | # Add examples 123 | gr.Examples( 124 | examples=create_examples(), 125 | inputs=[problem_input, model_select, use_rstar, show_viz] 126 | ) 127 | 128 | # Connect components 129 | solve_btn.click( 130 | fn=rstar.solve_problem, 131 | inputs=[problem_input, model_select, use_rstar, show_viz], 132 | outputs=[solution_output, plot_output] 133 | ) 134 | 135 | # Add documentation 136 | with gr.Accordion("About"): 137 | gr.Markdown(""" 138 | ## rStar-Math Demonstrator 139 | 140 | This interface demonstrates the capabilities of rStar-Math, an AI framework 141 | that enhances mathematical reasoning using Monte Carlo Tree Search and 142 | Process Preference Models. 143 | 144 | ### Features: 145 | - Multiple LLM support (OpenAI, Anthropic, Mistral, etc.) 146 | - Step-by-step solution breakdown 147 | - Confidence scoring for each step 148 | - Visual confidence tracking 149 | 150 | ### Usage: 151 | 1. Enter your math problem 152 | 2. Select a model 153 | 3. Choose whether to use rStar-Math enhancement 154 | 4. 
Click "Solve" to see the solution 155 | """) 156 | 157 | # Launch interface 158 | demo.launch(share=False) 159 | 160 | if __name__ == "__main__": 161 | main() 162 | -------------------------------------------------------------------------------- /examples/langchain_integration.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example integration with LangChain framework 3 | """ 4 | from typing import Any, Dict, List 5 | from langchain.chains import LLMChain 6 | from langchain.prompts import PromptTemplate 7 | from langchain.llms.base import BaseLLM 8 | 9 | from src.core.mcts import MCTS 10 | from src.core.ppm import ProcessPreferenceModel 11 | from src.models.model_interface import ModelFactory 12 | 13 | class RStarMathChain(LLMChain): 14 | def __init__( 15 | self, 16 | llm: BaseLLM, 17 | prompt: PromptTemplate = None, 18 | api_key: str = "YOUR_API_KEY", 19 | config_path: str = "config/default.json" 20 | ): 21 | if prompt is None: 22 | prompt = PromptTemplate( 23 | input_variables=["problem"], 24 | template="Solve this math problem: {problem}" 25 | ) 26 | 27 | super().__init__(llm=llm, prompt=prompt) 28 | 29 | # Initialize rStar-Math components 30 | self.mcts = MCTS.from_config_file(config_path) 31 | self.ppm = ProcessPreferenceModel.from_config_file(config_path) 32 | self.model = ModelFactory.create_model( 33 | "openai", 34 | api_key, 35 | config_path 36 | ) 37 | 38 | def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: 39 | """Process the input using rStar-Math enhanced reasoning.""" 40 | problem = inputs["problem"] 41 | 42 | # Get direct solution from LLM (System 1) 43 | direct_solution = super()._call(inputs) 44 | 45 | # Get enhanced solution using rStar-Math (System 2) 46 | action, trajectory = self.mcts.search(problem) 47 | 48 | # Evaluate and compare solutions 49 | direct_score = self.model.evaluate_reasoning( 50 | problem, 51 | [direct_solution["text"]] 52 | ) 53 | 54 | enhanced_steps = [] 55 | total_score = 0.0 56 | for step in trajectory: 57 | step_text = step["state"] 58 | step_score = self.ppm.evaluate_step(step_text, self.model) 59 | enhanced_steps.append(f"{step_text} (confidence: {step_score:.2f})") 60 | total_score += step_score 61 | 62 | enhanced_score = total_score / len(trajectory) if trajectory else 0.0 63 | 64 | return { 65 | "direct_solution": direct_solution["text"], 66 | "direct_score": direct_score, 67 | "enhanced_solution": "\n".join(enhanced_steps), 68 | "enhanced_score": enhanced_score, 69 | "improvement": enhanced_score - direct_score 70 | } 71 | 72 | # Example usage: 73 | """ 74 | from langchain.llms import OpenAI 75 | 76 | llm = OpenAI(temperature=0.7) 77 | chain = RStarMathChain(llm=llm, api_key="YOUR_API_KEY") 78 | 79 | result = chain.run("What is 2 + 2?") 80 | print(f"Direct solution (score: {result['direct_score']:.2f}):") 81 | print(result['direct_solution']) 82 | print("\nEnhanced solution (score: {result['enhanced_score']:.2f}):") 83 | print(result['enhanced_solution']) 84 | print(f"\nImprovement: {result['improvement']:.2f}") 85 | """ 86 | -------------------------------------------------------------------------------- /examples/llama_cpp_integration.py: -------------------------------------------------------------------------------- 1 | """ 2 | Integration with llama.cpp for local model inference 3 | """ 4 | from typing import List, Dict, Any, Optional 5 | import json 6 | import os 7 | from llama_cpp import Llama 8 | from src.core.mcts import MCTS 9 | from src.core.ppm import ProcessPreferenceModel 
10 | from src.models.model_interface import LLMInterface, ModelConfig 11 | 12 | class LlamaCppModel(LLMInterface): 13 | def __init__(self, model_path: str, config: Optional[ModelConfig] = None): 14 | """Initialize Llama model.""" 15 | self.model_path = model_path 16 | self.config = config or ModelConfig( 17 | model="llama2", 18 | temperature=0.7, 19 | max_tokens=100 20 | ) 21 | self.llm = Llama( 22 | model_path=model_path, 23 | n_ctx=2048, 24 | n_threads=os.cpu_count() 25 | ) 26 | 27 | @classmethod 28 | def from_config_file(cls, config_path: str, model_path: str) -> 'LlamaCppModel': 29 | """Create Llama model instance from config file.""" 30 | with open(config_path, 'r') as f: 31 | config_data = json.load(f) 32 | config = ModelConfig(**config_data['models']['llama']) 33 | return cls(model_path, config) 34 | 35 | def generate_response(self, 36 | prompt: str, 37 | temperature: Optional[float] = None, 38 | max_tokens: Optional[int] = None) -> str: 39 | """Generate response using local Llama model.""" 40 | response = self.llm( 41 | prompt, 42 | max_tokens=max_tokens or self.config.max_tokens, 43 | temperature=temperature or self.config.temperature, 44 | echo=False 45 | ) 46 | return response['choices'][0]['text'].strip() 47 | 48 | def evaluate_reasoning(self, 49 | problem: str, 50 | solution_steps: List[str]) -> float: 51 | """Evaluate reasoning steps using Llama.""" 52 | prompt = f""" 53 | Problem: {problem} 54 | Solution Steps: 55 | {chr(10).join(f'{i+1}. {step}' for i, step in enumerate(solution_steps))} 56 | 57 | Rate the quality of these solution steps from 0 to 1, where: 58 | 0 = completely incorrect or invalid reasoning 59 | 1 = perfect, clear, and mathematically sound reasoning 60 | 61 | Provide only the numerical rating. 62 | """ 63 | 64 | response = self.generate_response(prompt) 65 | try: 66 | rating = float(response.strip()) 67 | return max(0.0, min(1.0, rating)) 68 | except ValueError: 69 | return 0.0 70 | 71 | def embed_text(self, text: str) -> List[float]: 72 | """Generate embeddings using Llama.""" 73 | embeddings = self.llm.embed(text) 74 | return embeddings.tolist() 75 | 76 | def main(): 77 | """Example usage of Llama integration.""" 78 | # Initialize components 79 | model_path = "path/to/llama/model.bin" # Update with actual path 80 | model = LlamaCppModel(model_path) 81 | mcts = MCTS.from_config_file('config/default.json') 82 | ppm = ProcessPreferenceModel.from_config_file('config/default.json') 83 | 84 | # Test problem 85 | problem = "What is the derivative of x^2?" 
86 | 87 | print(f"Problem: {problem}\n") 88 | 89 | # Generate solution with rStar-Math 90 | action, trajectory = mcts.search(problem) 91 | 92 | print("Solution Steps:") 93 | for step in trajectory: 94 | confidence = ppm.evaluate_step(step['state'], model) 95 | print(f"- {step['state']}") 96 | print(f" Confidence: {confidence:.2f}\n") 97 | 98 | # Direct solution comparison 99 | print("Direct Solution:") 100 | direct_solution = model.generate_response(problem) 101 | direct_confidence = model.evaluate_reasoning(problem, [direct_solution]) 102 | print(direct_solution) 103 | print(f"Confidence: {direct_confidence:.2f}") 104 | 105 | if __name__ == "__main__": 106 | main() 107 | -------------------------------------------------------------------------------- /examples/notebooks/calculus_examples.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# rStar-Math: Advanced Calculus Examples\n", 8 | "\n", 9 | "This notebook demonstrates how rStar-Math handles complex calculus problems." 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": {}, 16 | "source": [ 17 | "import os\n", 18 | "import numpy as np\n", 19 | "import matplotlib.pyplot as plt\n", 20 | "from src.core.mcts import MCTS\n", 21 | "from src.core.ppm import ProcessPreferenceModel\n", 22 | "from src.models.model_interface import ModelFactory" 23 | ] 24 | }, 25 | { 26 | "cell_type": "markdown", 27 | "metadata": {}, 28 | "source": [ 29 | "## Setup Components" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "metadata": {}, 36 | "source": [ 37 | "mcts = MCTS.from_config_file('config/default.json')\n", 38 | "ppm = ProcessPreferenceModel.from_config_file('config/default.json')\n", 39 | "model = ModelFactory.create_model('openai', os.getenv('OPENAI_API_KEY'), 'config/default.json')" 40 | ] 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "metadata": {}, 45 | "source": [ 46 | "## 1. Derivatives and Integration" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": null, 52 | "metadata": {}, 53 | "source": [ 54 | "calculus_problems = [\n", 55 | " \"Find the derivative of f(x) = sin(x)cos(x)\",\n", 56 | " \"Integrate ∫x²e^x dx\",\n", 57 | " \"Find the second derivative of f(x) = ln(x)\",\n", 58 | " \"Solve the differential equation dy/dx = x + y\"\n", 59 | "]\n", 60 | "\n", 61 | "for problem in calculus_problems:\n", 62 | " print(f\"Problem: {problem}\\n\")\n", 63 | " action, trajectory = mcts.search(problem)\n", 64 | " \n", 65 | " print(\"Solution Steps:\")\n", 66 | " for step in trajectory:\n", 67 | " confidence = ppm.evaluate_step(step['state'], model)\n", 68 | " print(f\"- {step['state']}\")\n", 69 | " print(f\" Confidence: {confidence:.2f}\\n\")\n", 70 | " print(\"-\" * 50 + \"\\n\")" 71 | ] 72 | }, 73 | { 74 | "cell_type": "markdown", 75 | "metadata": {}, 76 | "source": [ 77 | "## 2. 
Limits and Series" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": null, 83 | "metadata": {}, 84 | "source": [ 85 | "limit_problems = [\n", 86 | " \"Find the limit of (1 + 1/n)^n as n approaches infinity\",\n", 87 | " \"Find the sum of the infinite series 1 + 1/2 + 1/4 + 1/8 + ...\",\n", 88 | " \"Determine if the series Σ(1/n) converges\"\n", 89 | "]\n", 90 | "\n", 91 | "for problem in limit_problems:\n", 92 | " print(f\"Problem: {problem}\\n\")\n", 93 | " action, trajectory = mcts.search(problem)\n", 94 | " \n", 95 | " print(\"Solution Steps:\")\n", 96 | " for step in trajectory:\n", 97 | " confidence = ppm.evaluate_step(step['state'], model)\n", 98 | " print(f\"- {step['state']}\")\n", 99 | " print(f\" Confidence: {confidence:.2f}\\n\")\n", 100 | " print(\"-\" * 50 + \"\\n\")" 101 | ] 102 | }, 103 | { 104 | "cell_type": "markdown", 105 | "metadata": {}, 106 | "source": [ 107 | "## 3. Visualization of Solutions\n", 108 | "\n", 109 | "Let's visualize some of the calculus concepts." 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "metadata": {}, 116 | "source": [ 117 | "def plot_function_and_derivative(f_str: str):\n", 118 | " \"\"\"Plot a function and its derivative.\"\"\"\n", 119 | " problem = f\"Find the derivative of f(x) = {f_str}\"\n", 120 | " action, trajectory = mcts.search(problem)\n", 121 | " \n", 122 | " # Extract derivative from solution\n", 123 | " derivative_str = trajectory[-1]['state']\n", 124 | " \n", 125 | " # Create plot\n", 126 | " x = np.linspace(-5, 5, 100)\n", 127 | " \n", 128 | " # Original function\n", 129 | " f = eval(f\"lambda x: {f_str}\")\n", 130 | " plt.plot(x, f(x), label=f'f(x) = {f_str}')\n", 131 | " \n", 132 | " # Derivative\n", 133 | " try:\n", 134 | " f_prime = eval(f\"lambda x: {derivative_str}\")\n", 135 | " plt.plot(x, f_prime(x), '--', label=f\"f'(x) = {derivative_str}\")\n", 136 | " except:\n", 137 | " print(\"Could not plot derivative\")\n", 138 | " \n", 139 | " plt.grid(True)\n", 140 | " plt.legend()\n", 141 | " plt.title(f\"Function and its Derivative\")\n", 142 | " plt.show()\n", 143 | "\n", 144 | "# Example plots\n", 145 | "functions = [\n", 146 | " \"x**2\",\n", 147 | " \"np.sin(x)\",\n", 148 | " \"np.exp(x)\"\n", 149 | "]\n", 150 | "\n", 151 | "for f_str in functions:\n", 152 | " plot_function_and_derivative(f_str)" 153 | ] 154 | } 155 | ], 156 | "metadata": { 157 | "kernelspec": { 158 | "display_name": "Python 3", 159 | "language": "python", 160 | "name": "python3" 161 | } 162 | }, 163 | "nbformat": 4, 164 | "nbformat_minor": 4 165 | } 166 | -------------------------------------------------------------------------------- /examples/notebooks/geometry_examples.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# rStar-Math: Geometry and Proofs\n", 8 | "\n", 9 | "This notebook demonstrates how rStar-Math handles geometric problems and mathematical proofs." 
10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": {}, 16 | "source": [ 17 | "import os\n", 18 | "import numpy as np\n", 19 | "import matplotlib.pyplot as plt\n", 20 | "from matplotlib.patches import Circle, Rectangle, Polygon\n", 21 | "from src.core.mcts import MCTS\n", 22 | "from src.core.ppm import ProcessPreferenceModel\n", 23 | "from src.models.model_interface import ModelFactory" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": {}, 29 | "source": [ 30 | "## Setup Components" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "metadata": {}, 37 | "source": [ 38 | "mcts = MCTS.from_config_file('config/default.json')\n", 39 | "ppm = ProcessPreferenceModel.from_config_file('config/default.json')\n", 40 | "model = ModelFactory.create_model('openai', os.getenv('OPENAI_API_KEY'), 'config/default.json')" 41 | ] 42 | }, 43 | { 44 | "cell_type": "markdown", 45 | "metadata": {}, 46 | "source": [ 47 | "## 1. Basic Geometry Problems" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": null, 53 | "metadata": {}, 54 | "source": [ 55 | "geometry_problems = [\n", 56 | " \"Find the area of a circle with radius 5\",\n", 57 | " \"Calculate the volume of a sphere with radius 3\",\n", 58 | " \"In a right triangle with legs 3 and 4, find the hypotenuse\",\n", 59 | " \"Find the area of a regular hexagon with side length 2\"\n", 60 | "]\n", 61 | "\n", 62 | "for problem in geometry_problems:\n", 63 | " print(f\"Problem: {problem}\\n\")\n", 64 | " action, trajectory = mcts.search(problem)\n", 65 | " \n", 66 | " print(\"Solution Steps:\")\n", 67 | " for step in trajectory:\n", 68 | " confidence = ppm.evaluate_step(step['state'], model)\n", 69 | " print(f\"- {step['state']}\")\n", 70 | " print(f\" Confidence: {confidence:.2f}\\n\")\n", 71 | " print(\"-\" * 50 + \"\\n\")" 72 | ] 73 | }, 74 | { 75 | "cell_type": "markdown", 76 | "metadata": {}, 77 | "source": [ 78 | "## 2. Geometric Proofs" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": null, 84 | "metadata": {}, 85 | "source": [ 86 | "proof_problems = [\n", 87 | " \"Prove that the sum of angles in a triangle is 180 degrees\",\n", 88 | " \"Prove the Pythagorean theorem\",\n", 89 | " \"Prove that the diagonals of a rectangle are equal\"\n", 90 | "]\n", 91 | "\n", 92 | "for problem in proof_problems:\n", 93 | " print(f\"Problem: {problem}\\n\")\n", 94 | " action, trajectory = mcts.search(problem)\n", 95 | " \n", 96 | " print(\"Proof Steps:\")\n", 97 | " for step in trajectory:\n", 98 | " confidence = ppm.evaluate_step(step['state'], model)\n", 99 | " print(f\"- {step['state']}\")\n", 100 | " print(f\" Confidence: {confidence:.2f}\\n\")\n", 101 | " print(\"-\" * 50 + \"\\n\")" 102 | ] 103 | }, 104 | { 105 | "cell_type": "markdown", 106 | "metadata": {}, 107 | "source": [ 108 | "## 3. 
Visualization of Geometric Concepts" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": null, 114 | "metadata": {}, 115 | "source": [ 116 | "def plot_right_triangle():\n", 117 | " \"\"\"Visualize Pythagorean theorem.\"\"\"\n", 118 | " fig, ax = plt.subplots(figsize=(8, 8))\n", 119 | " \n", 120 | " # Draw triangle\n", 121 | " ax.plot([0, 3, 0, 0], [0, 0, 4, 0], 'b-', linewidth=2)\n", 122 | " \n", 123 | " # Add squares on each side\n", 124 | " ax.add_patch(Rectangle((0, 0), 3, 3, fill=False, color='red'))\n", 125 | " ax.add_patch(Rectangle((3, 0), 4, 4, fill=False, color='green', angle=90))\n", 126 | " ax.add_patch(Polygon([[0, 0], [0, 4], [-5, 4], [-5, 0]], fill=False, color='blue'))\n", 127 | " \n", 128 | " # Add labels\n", 129 | " ax.text(1.5, -0.5, '3', ha='center')\n", 130 | " ax.text(-0.5, 2, '4', va='center')\n", 131 | " ax.text(1, 2, '5', ha='center', va='center', rotation=-45)\n", 132 | " \n", 133 | " ax.set_aspect('equal')\n", 134 | " ax.grid(True)\n", 135 | " plt.title(\"Pythagorean Theorem: a² + b² = c²\")\n", 136 | " plt.show()\n", 137 | "\n", 138 | "def plot_circle_properties():\n", 139 | " \"\"\"Visualize circle properties.\"\"\"\n", 140 | " fig, ax = plt.subplots(figsize=(8, 8))\n", 141 | " \n", 142 | " # Draw circle\n", 143 | " circle = Circle((0, 0), 5, fill=False)\n", 144 | " ax.add_patch(circle)\n", 145 | " \n", 146 | " # Draw radius\n", 147 | " ax.plot([0, 5], [0, 0], 'r-', label='Radius')\n", 148 | " \n", 149 | " # Draw diameter\n", 150 | " ax.plot([-5, 5], [0, 0], 'g--', label='Diameter')\n", 151 | " \n", 152 | " ax.set_aspect('equal')\n", 153 | " ax.grid(True)\n", 154 | " plt.legend()\n", 155 | " plt.title(\"Circle Properties\")\n", 156 | " plt.show()\n", 157 | "\n", 158 | "# Generate visualizations\n", 159 | "plot_right_triangle()\n", 160 | "plot_circle_properties()" 161 | ] 162 | } 163 | ], 164 | "metadata": { 165 | "kernelspec": { 166 | "display_name": "Python 3", 167 | "language": "python", 168 | "name": "python3" 169 | } 170 | }, 171 | "nbformat": 4, 172 | "nbformat_minor": 4 173 | } 174 | -------------------------------------------------------------------------------- /examples/notebooks/integration_examples.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# rStar-Math Integration Examples\n", 8 | "\n", 9 | "This notebook demonstrates how to integrate rStar-Math with various frameworks and platforms." 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": {}, 16 | "source": [ 17 | "import os\n", 18 | "from typing import Dict, List\n", 19 | "from src.core.mcts import MCTS\n", 20 | "from src.core.ppm import ProcessPreferenceModel\n", 21 | "from src.models.model_interface import ModelFactory" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "## 1. Basic Integration\n", 29 | "\n", 30 | "First, let's see how to use rStar-Math directly." 
31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "metadata": {}, 37 | "source": [ 38 | "# Initialize components\n", 39 | "mcts = MCTS.from_config_file('config/default.json')\n", 40 | "ppm = ProcessPreferenceModel.from_config_file('config/default.json')\n", 41 | "model = ModelFactory.create_model(\n", 42 | " 'openai',\n", 43 | " os.getenv('OPENAI_API_KEY'),\n", 44 | " 'config/default.json'\n", 45 | ")\n", 46 | "\n", 47 | "# Solve a problem\n", 48 | "problem = \"What is the derivative of f(x) = x^2 + 3x?\"\n", 49 | "action, trajectory = mcts.search(problem)\n", 50 | "\n", 51 | "# Print solution steps with confidence scores\n", 52 | "for step in trajectory:\n", 53 | " confidence = ppm.evaluate_step(step['state'], model)\n", 54 | " print(f\"Step: {step['state']}\")\n", 55 | " print(f\"Confidence: {confidence:.2f}\\n\")" 56 | ] 57 | }, 58 | { 59 | "cell_type": "markdown", 60 | "metadata": {}, 61 | "source": [ 62 | "## 2. LangChain Integration\n", 63 | "\n", 64 | "Here's how to use rStar-Math with LangChain." 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": null, 70 | "metadata": {}, 71 | "source": [ 72 | "from langchain.llms import OpenAI\n", 73 | "from examples.langchain_integration import RStarMathChain\n", 74 | "\n", 75 | "# Initialize LangChain components\n", 76 | "llm = OpenAI(temperature=0.7)\n", 77 | "chain = RStarMathChain(\n", 78 | " llm=llm,\n", 79 | " api_key=os.getenv('OPENAI_API_KEY')\n", 80 | ")\n", 81 | "\n", 82 | "# Solve problem\n", 83 | "result = chain.run(\"What is 2 + 2?\")\n", 84 | "\n", 85 | "print(\"Direct Solution:\")\n", 86 | "print(result['direct_solution'])\n", 87 | "print(f\"Score: {result['direct_score']:.2f}\\n\")\n", 88 | "\n", 89 | "print(\"Enhanced Solution:\")\n", 90 | "print(result['enhanced_solution'])\n", 91 | "print(f\"Score: {result['enhanced_score']:.2f}\\n\")\n", 92 | "\n", 93 | "print(f\"Improvement: {result['improvement']:.2%}\")" 94 | ] 95 | }, 96 | { 97 | "cell_type": "markdown", 98 | "metadata": {}, 99 | "source": [ 100 | "## 3. Rasa Integration\n", 101 | "\n", 102 | "Example of using rStar-Math in a Rasa chatbot." 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": null, 108 | "metadata": {}, 109 | "source": [ 110 | "from rasa_sdk import Action\n", 111 | "from examples.rasa_integration import RStarMathAction\n", 112 | "\n", 113 | "# Create mock Rasa components for demonstration\n", 114 | "class MockDispatcher:\n", 115 | " def utter_message(self, text: str):\n", 116 | " print(f\"Bot: {text}\")\n", 117 | "\n", 118 | "class MockTracker:\n", 119 | " def __init__(self, text: str):\n", 120 | " self.latest_message = {\"text\": text}\n", 121 | "\n", 122 | "# Initialize action\n", 123 | "action = RStarMathAction()\n", 124 | "\n", 125 | "# Simulate conversation\n", 126 | "async def simulate_conversation():\n", 127 | " dispatcher = MockDispatcher()\n", 128 | " tracker = MockTracker(\"What is the derivative of x^2?\")\n", 129 | " \n", 130 | " print(f\"User: {tracker.latest_message['text']}\")\n", 131 | " await action.run(dispatcher, tracker, {})\n", 132 | "\n", 133 | "# Run simulation\n", 134 | "import asyncio\n", 135 | "asyncio.run(simulate_conversation())" 136 | ] 137 | }, 138 | { 139 | "cell_type": "markdown", 140 | "metadata": {}, 141 | "source": [ 142 | "## 4. Custom Integration\n", 143 | "\n", 144 | "Example of creating a custom integration with rStar-Math." 
145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": null, 150 | "metadata": {}, 151 | "source": [ 152 | "class MathTutor:\n", 153 | " def __init__(self, api_key: str):\n", 154 | " self.mcts = MCTS.from_config_file('config/default.json')\n", 155 | " self.ppm = ProcessPreferenceModel.from_config_file('config/default.json')\n", 156 | " self.model = ModelFactory.create_model('openai', api_key, 'config/default.json')\n", 157 | " \n", 158 | " def solve_with_explanation(self, problem: str) -> Dict:\n", 159 | " # Get solution using rStar-Math\n", 160 | " action, trajectory = self.mcts.search(problem)\n", 161 | " \n", 162 | " # Format steps with confidence scores\n", 163 | " steps = []\n", 164 | " total_confidence = 0.0\n", 165 | " \n", 166 | " for step in trajectory:\n", 167 | " confidence = self.ppm.evaluate_step(step['state'], self.model)\n", 168 | " steps.append({\n", 169 | " 'explanation': step['state'],\n", 170 | " 'confidence': confidence\n", 171 | " })\n", 172 | " total_confidence += confidence\n", 173 | " \n", 174 | " return {\n", 175 | " 'steps': steps,\n", 176 | " 'average_confidence': total_confidence / len(steps) if steps else 0.0\n", 177 | " }\n", 178 | "\n", 179 | "# Test the custom integration\n", 180 | "tutor = MathTutor(os.getenv('OPENAI_API_KEY'))\n", 181 | "result = tutor.solve_with_explanation(\"Solve for x: 2x + 3 = 7\")\n", 182 | "\n", 183 | "print(\"Solution Steps:\")\n", 184 | "for i, step in enumerate(result['steps'], 1):\n", 185 | " print(f\"\\nStep {i}:\")\n", 186 | " print(f\"Explanation: {step['explanation']}\")\n", 187 | " print(f\"Confidence: {step['confidence']:.2f}\")\n", 188 | "\n", 189 | "print(f\"\\nOverall Confidence: {result['average_confidence']:.2f}\")" 190 | ] 191 | } 192 | ], 193 | "metadata": { 194 | "kernelspec": { 195 | "display_name": "Python 3", 196 | "language": "python", 197 | "name": "python3" 198 | }, 199 | "language_info": { 200 | "codemirror_mode": { 201 | "name": "ipython", 202 | "version": 3 203 | }, 204 | "file_extension": ".py", 205 | "mimetype": "text/x-python", 206 | "name": "python", 207 | "nbconvert_exporter": "python", 208 | "pygments_lexer": "ipython3", 209 | "version": "3.8.0" 210 | } 211 | }, 212 | "nbformat": 4, 213 | "nbformat_minor": 4 214 | } 215 | -------------------------------------------------------------------------------- /examples/notebooks/linear_algebra_examples.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# rStar-Math: Linear Algebra Examples\n", 8 | "\n", 9 | "This notebook demonstrates how rStar-Math handles linear algebra problems with visualizations." 
10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": {}, 16 | "source": [ 17 | "import os\n", 18 | "import numpy as np\n", 19 | "import matplotlib.pyplot as plt\n", 20 | "from mpl_toolkits.mplot3d import Axes3D\n", 21 | "from src.core.mcts import MCTS\n", 22 | "from src.core.ppm import ProcessPreferenceModel\n", 23 | "from src.models.model_interface import ModelFactory" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": null, 29 | "metadata": {}, 30 | "source": [ 31 | "# Initialize components\n", 32 | "mcts = MCTS.from_config_file('config/default.json')\n", 33 | "ppm = ProcessPreferenceModel.from_config_file('config/default.json')\n", 34 | "model = ModelFactory.create_model('openai', os.getenv('OPENAI_API_KEY'), 'config/default.json')" 35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "metadata": {}, 40 | "source": [ 41 | "## 1. Matrix Operations" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "metadata": {}, 48 | "source": [ 49 | "matrix_problems = [\n", 50 | " \"Find the determinant of [[1, 2], [3, 4]]\",\n", 51 | " \"Solve the system of equations: 2x + y = 5, x - y = 1\",\n", 52 | " \"Find the eigenvalues of [[2, 1], [1, 2]]\",\n", 53 | " \"Calculate the inverse of [[1, 2], [3, 4]]\"\n", 54 | "]\n", 55 | "\n", 56 | "def visualize_matrix(matrix_str: str):\n", 57 | " \"\"\"Visualize matrix as a heatmap.\"\"\"\n", 58 | " matrix = np.array(eval(matrix_str))\n", 59 | " plt.figure(figsize=(8, 6))\n", 60 | " plt.imshow(matrix, cmap='viridis')\n", 61 | " plt.colorbar()\n", 62 | " for i in range(matrix.shape[0]):\n", 63 | " for j in range(matrix.shape[1]):\n", 64 | " plt.text(j, i, f'{matrix[i,j]:.2f}', ha='center', va='center')\n", 65 | " plt.title('Matrix Visualization')\n", 66 | " plt.show()\n", 67 | "\n", 68 | "for problem in matrix_problems:\n", 69 | " print(f\"Problem: {problem}\\n\")\n", 70 | " action, trajectory = mcts.search(problem)\n", 71 | " \n", 72 | " print(\"Solution Steps:\")\n", 73 | " for step in trajectory:\n", 74 | " confidence = ppm.evaluate_step(step['state'], model)\n", 75 | " print(f\"- {step['state']}\")\n", 76 | " print(f\" Confidence: {confidence:.2f}\\n\")\n", 77 | " \n", 78 | " # Visualize matrix if present in problem\n", 79 | " if '[[' in problem:\n", 80 | " matrix_str = problem[problem.find('[['):problem.find(']]')+2]\n", 81 | " visualize_matrix(matrix_str)\n", 82 | " print(\"-\" * 50 + \"\\n\")" 83 | ] 84 | }, 85 | { 86 | "cell_type": "markdown", 87 | "metadata": {}, 88 | "source": [ 89 | "## 2. 
Vector Spaces and Transformations" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": null, 95 | "metadata": {}, 96 | "source": [ 97 | "def plot_vector_transformation(matrix: np.ndarray):\n", 98 | " \"\"\"Visualize linear transformation.\"\"\"\n", 99 | " fig = plt.figure(figsize=(12, 5))\n", 100 | " \n", 101 | " # Original vectors\n", 102 | " ax1 = fig.add_subplot(121)\n", 103 | " vectors = np.array([[1, 0], [0, 1]])\n", 104 | " ax1.quiver([0, 0], [0, 0], vectors[:, 0], vectors[:, 1],\n", 105 | " angles='xy', scale_units='xy', scale=1)\n", 106 | " ax1.set_xlim(-2, 2)\n", 107 | " ax1.set_ylim(-2, 2)\n", 108 | " ax1.grid(True)\n", 109 | " ax1.set_title('Original Vectors')\n", 110 | " \n", 111 | " # Transformed vectors\n", 112 | " ax2 = fig.add_subplot(122)\n", 113 | " transformed = np.dot(vectors, matrix)\n", 114 | " ax2.quiver([0, 0], [0, 0], transformed[:, 0], transformed[:, 1],\n", 115 | " angles='xy', scale_units='xy', scale=1)\n", 116 | " ax2.set_xlim(-2, 2)\n", 117 | " ax2.set_ylim(-2, 2)\n", 118 | " ax2.grid(True)\n", 119 | " ax2.set_title('Transformed Vectors')\n", 120 | " \n", 121 | " plt.show()\n", 122 | "\n", 123 | "# Example transformations\n", 124 | "transformations = [\n", 125 | " np.array([[2, 0], [0, 2]]), # Scaling\n", 126 | " np.array([[0, -1], [1, 0]]), # Rotation\n", 127 | " np.array([[1, 1], [0, 1]]) # Shear\n", 128 | "]\n", 129 | "\n", 130 | "for matrix in transformations:\n", 131 | " print(f\"Transformation Matrix:\\n{matrix}\\n\")\n", 132 | " plot_vector_transformation(matrix)" 133 | ] 134 | }, 135 | { 136 | "cell_type": "markdown", 137 | "metadata": {}, 138 | "source": [ 139 | "## 3. Eigenvalues and Eigenvectors" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": null, 145 | "metadata": {}, 146 | "source": [ 147 | "def plot_eigenvectors(matrix: np.ndarray):\n", 148 | " \"\"\"Visualize eigenvectors and their transformations.\"\"\"\n", 149 | " eigenvals, eigenvecs = np.linalg.eig(matrix)\n", 150 | " \n", 151 | " plt.figure(figsize=(8, 8))\n", 152 | " \n", 153 | " # Plot original vectors\n", 154 | " for i, vec in enumerate(eigenvecs.T):\n", 155 | " plt.quiver(0, 0, vec[0], vec[1], angles='xy', scale_units='xy',\n", 156 | " scale=1, color='blue', label=f'Eigenvector {i+1}')\n", 157 | " \n", 158 | " # Plot transformed vectors\n", 159 | " transformed = np.dot(matrix, eigenvecs)\n", 160 | " for i, vec in enumerate(transformed.T):\n", 161 | " plt.quiver(0, 0, vec[0], vec[1], angles='xy', scale_units='xy',\n", 162 | " scale=1, color='red', label=f'Transformed {i+1}')\n", 163 | " \n", 164 | " plt.xlim(-2, 2)\n", 165 | " plt.ylim(-2, 2)\n", 166 | " plt.grid(True)\n", 167 | " plt.legend()\n", 168 | " plt.title('Eigenvectors and Their Transformations')\n", 169 | " plt.show()\n", 170 | " \n", 171 | " print(\"Eigenvalues:\")\n", 172 | " for i, val in enumerate(eigenvals):\n", 173 | " print(f\"λ{i+1} = {val:.2f}\")\n", 174 | "\n", 175 | "# Example matrices\n", 176 | "matrices = [\n", 177 | " np.array([[2, 1], [1, 2]]), # Symmetric matrix\n", 178 | " np.array([[0, -1], [1, 0]]), # Rotation matrix\n", 179 | " np.array([[3, 1], [0, 2]]) # Upper triangular matrix\n", 180 | "]\n", 181 | "\n", 182 | "for matrix in matrices:\n", 183 | " print(f\"\\nMatrix:\\n{matrix}\")\n", 184 | " plot_eigenvectors(matrix)" 185 | ] 186 | } 187 | ], 188 | "metadata": { 189 | "kernelspec": { 190 | "display_name": "Python 3", 191 | "language": "python", 192 | "name": "python3" 193 | } 194 | }, 195 | "nbformat": 4, 196 | "nbformat_minor": 4 197 | } 198 | 
-------------------------------------------------------------------------------- /examples/notebooks/model_comparison.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# rStar-Math Model Comparison\n", 8 | "\n", 9 | "This notebook demonstrates the performance comparison between different Language Models (LLMs) with and without rStar-Math enhancement." 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": {}, 16 | "source": [ 17 | "import os\n", 18 | "import json\n", 19 | "from typing import Dict, List\n", 20 | "import pandas as pd\n", 21 | "import plotly.express as px\n", 22 | "from src.core.mcts import MCTS\n", 23 | "from src.core.ppm import ProcessPreferenceModel\n", 24 | "from src.models.model_interface import ModelFactory" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "metadata": {}, 30 | "source": [ 31 | "## Setup Models\n", 32 | "\n", 33 | "First, let's set up our LLMs with API keys." 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "metadata": {}, 40 | "source": [ 41 | "# Load API keys from environment variables\n", 42 | "api_keys = {\n", 43 | " 'openai': os.getenv('OPENAI_API_KEY'),\n", 44 | " 'anthropic': os.getenv('ANTHROPIC_API_KEY'),\n", 45 | " 'mistral': os.getenv('MISTRAL_API_KEY'),\n", 46 | " 'groq': os.getenv('GROQ_API_KEY'),\n", 47 | " 'gemini': os.getenv('GEMINI_API_KEY')\n", 48 | "}\n", 49 | "\n", 50 | "# Initialize models\n", 51 | "models = {}\n", 52 | "for model_name, api_key in api_keys.items():\n", 53 | " if api_key:\n", 54 | " models[model_name] = ModelFactory.create_model(\n", 55 | " model_name,\n", 56 | " api_key,\n", 57 | " 'config/default.json'\n", 58 | " )" 59 | ] 60 | }, 61 | { 62 | "cell_type": "markdown", 63 | "metadata": {}, 64 | "source": [ 65 | "## Initialize rStar-Math Components" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "metadata": {}, 72 | "source": [ 73 | "mcts = MCTS.from_config_file('config/default.json')\n", 74 | "ppm = ProcessPreferenceModel.from_config_file('config/default.json')" 75 | ] 76 | }, 77 | { 78 | "cell_type": "markdown", 79 | "metadata": {}, 80 | "source": [ 81 | "## Test Problems\n", 82 | "\n", 83 | "Let's define some test problems of varying difficulty." 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": null, 89 | "metadata": {}, 90 | "source": [ 91 | "test_problems = [\n", 92 | " \"What is 2 + 2?\", # Simple arithmetic\n", 93 | " \"Solve for x: 2x + 3 = 7\", # Basic algebra\n", 94 | " \"Find the derivative of f(x) = x^2 + 3x\", # Calculus\n", 95 | " \"In a group of 30 people, 40% are men. 
How many women are there?\", # Word problem\n", 96 | " \"Prove that the square root of 2 is irrational\" # Mathematical proof\n", 97 | "]" 98 | ] 99 | }, 100 | { 101 | "cell_type": "markdown", 102 | "metadata": {}, 103 | "source": [ 104 | "## Compare Model Performance" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": null, 110 | "metadata": {}, 111 | "source": [ 112 | "def solve_problem(model, problem: str, use_rstar: bool = False) -> Dict:\n", 113 | " \"\"\"Solve a problem with or without rStar-Math.\"\"\"\n", 114 | " if use_rstar:\n", 115 | " action, trajectory = mcts.search(problem)\n", 116 | " solution_steps = [step['state'] for step in trajectory]\n", 117 | " score = sum(ppm.evaluate_step(step, model) for step in solution_steps) / len(solution_steps)\n", 118 | " else:\n", 119 | " solution = model.generate_response(problem)\n", 120 | " solution_steps = [solution]\n", 121 | " score = model.evaluate_reasoning(problem, solution_steps)\n", 122 | " \n", 123 | " return {\n", 124 | " 'solution': '\\n'.join(solution_steps),\n", 125 | " 'score': score\n", 126 | " }\n", 127 | "\n", 128 | "# Collect results\n", 129 | "results = []\n", 130 | "for problem in test_problems:\n", 131 | " for model_name, model in models.items():\n", 132 | " # Without rStar-Math\n", 133 | " direct_result = solve_problem(model, problem, use_rstar=False)\n", 134 | " results.append({\n", 135 | " 'problem': problem,\n", 136 | " 'model': model_name,\n", 137 | " 'method': 'Direct',\n", 138 | " 'score': direct_result['score']\n", 139 | " })\n", 140 | " \n", 141 | " # With rStar-Math\n", 142 | " rstar_result = solve_problem(model, problem, use_rstar=True)\n", 143 | " results.append({\n", 144 | " 'problem': problem,\n", 145 | " 'model': model_name,\n", 146 | " 'method': 'rStar-Math',\n", 147 | " 'score': rstar_result['score']\n", 148 | " })\n", 149 | "\n", 150 | "# Create DataFrame\n", 151 | "df = pd.DataFrame(results)" 152 | ] 153 | }, 154 | { 155 | "cell_type": "markdown", 156 | "metadata": {}, 157 | "source": [ 158 | "## Visualize Results" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": null, 164 | "metadata": {}, 165 | "source": [ 166 | "# Overall comparison\n", 167 | "fig = px.box(df, x='model', y='score', color='method',\n", 168 | " title='Model Performance Comparison')\n", 169 | "fig.show()\n", 170 | "\n", 171 | "# Problem-specific comparison\n", 172 | "fig = px.bar(df, x='model', y='score', color='method',\n", 173 | " facet_row='problem', barmode='group',\n", 174 | " title='Model Performance by Problem Type')\n", 175 | "fig.show()\n", 176 | "\n", 177 | "# Calculate improvement statistics\n", 178 | "improvements = df[df['method'] == 'rStar-Math']['score'].mean() - \\\n", 179 | " df[df['method'] == 'Direct']['score'].mean()\n", 180 | "print(f\"Average improvement with rStar-Math: {improvements:.2%}\")" 181 | ] 182 | }, 183 | { 184 | "cell_type": "markdown", 185 | "metadata": {}, 186 | "source": [ 187 | "## Analyze Solution Steps\n", 188 | "\n", 189 | "Let's look at detailed solution steps for a specific problem." 
190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "execution_count": null, 195 | "metadata": {}, 196 | "source": [ 197 | "def analyze_solution(model, problem: str) -> None:\n", 198 | " \"\"\"Compare direct vs rStar-Math solutions.\"\"\"\n", 199 | " print(f\"Problem: {problem}\\n\")\n", 200 | " \n", 201 | " # Direct solution\n", 202 | " print(\"Direct Solution:\")\n", 203 | " direct_result = solve_problem(model, problem, use_rstar=False)\n", 204 | " print(direct_result['solution'])\n", 205 | " print(f\"Confidence Score: {direct_result['score']:.2f}\\n\")\n", 206 | " \n", 207 | " # rStar-Math solution\n", 208 | " print(\"rStar-Math Enhanced Solution:\")\n", 209 | " rstar_result = solve_problem(model, problem, use_rstar=True)\n", 210 | " print(rstar_result['solution'])\n", 211 | " print(f\"Confidence Score: {rstar_result['score']:.2f}\")\n", 212 | "\n", 213 | "# Analyze a complex problem\n", 214 | "complex_problem = test_problems[-1] # Proof problem\n", 215 | "model = models['openai'] # Use GPT-4 for demonstration\n", 216 | "analyze_solution(model, complex_problem)" 217 | ] 218 | } 219 | ], 220 | "metadata": { 221 | "kernelspec": { 222 | "display_name": "Python 3", 223 | "language": "python", 224 | "name": "python3" 225 | }, 226 | "language_info": { 227 | "codemirror_mode": { 228 | "name": "ipython", 229 | "version": 3 230 | }, 231 | "file_extension": ".py", 232 | "mimetype": "text/x-python", 233 | "name": "python", 234 | "nbconvert_exporter": "python", 235 | "pygments_lexer": "ipython3", 236 | "version": "3.8.0" 237 | } 238 | }, 239 | "nbformat": 4, 240 | "nbformat_minor": 4 241 | } 242 | -------------------------------------------------------------------------------- /examples/notebooks/statistics_examples.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# rStar-Math: Statistics and Probability Examples\n", 8 | "\n", 9 | "This notebook demonstrates statistical analysis and probability problems with visualizations." 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": {}, 16 | "source": [ 17 | "import os\n", 18 | "import numpy as np\n", 19 | "import pandas as pd\n", 20 | "import seaborn as sns\n", 21 | "import matplotlib.pyplot as plt\n", 22 | "from scipy import stats\n", 23 | "from src.core.mcts import MCTS\n", 24 | "from src.core.ppm import ProcessPreferenceModel\n", 25 | "from src.models.model_interface import ModelFactory" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": null, 31 | "metadata": {}, 32 | "source": [ 33 | "# Initialize components\n", 34 | "mcts = MCTS.from_config_file('config/default.json')\n", 35 | "ppm = ProcessPreferenceModel.from_config_file('config/default.json')\n", 36 | "model = ModelFactory.create_model('openai', os.getenv('OPENAI_API_KEY'), 'config/default.json')" 37 | ] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "metadata": {}, 42 | "source": [ 43 | "## 1. 
Descriptive Statistics" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "metadata": {}, 50 | "source": [ 51 | "stats_problems = [\n", 52 | " \"Find the mean, median, and mode of [2, 3, 3, 4, 4, 4, 5, 5, 6]\",\n", 53 | " \"Calculate the standard deviation of [10, 12, 15, 18, 20]\",\n", 54 | " \"Find the quartiles and IQR of [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\"\n", 55 | "]\n", 56 | "\n", 57 | "def visualize_distribution(data: list):\n", 58 | " \"\"\"Visualize data distribution.\"\"\"\n", 59 | " plt.figure(figsize=(12, 4))\n", 60 | " \n", 61 | " # Histogram\n", 62 | " plt.subplot(131)\n", 63 | " plt.hist(data, bins='auto', alpha=0.7)\n", 64 | " plt.title('Histogram')\n", 65 | " \n", 66 | " # Box plot\n", 67 | " plt.subplot(132)\n", 68 | " plt.boxplot(data)\n", 69 | " plt.title('Box Plot')\n", 70 | " \n", 71 | " # KDE plot\n", 72 | " plt.subplot(133)\n", 73 | " sns.kdeplot(data=data)\n", 74 | " plt.title('Density Plot')\n", 75 | " \n", 76 | " plt.tight_layout()\n", 77 | " plt.show()\n", 78 | "\n", 79 | "for problem in stats_problems:\n", 80 | " print(f\"Problem: {problem}\\n\")\n", 81 | " action, trajectory = mcts.search(problem)\n", 82 | " \n", 83 | " print(\"Solution Steps:\")\n", 84 | " for step in trajectory:\n", 85 | " confidence = ppm.evaluate_step(step['state'], model)\n", 86 | " print(f\"- {step['state']}\")\n", 87 | " print(f\" Confidence: {confidence:.2f}\\n\")\n", 88 | " \n", 89 | " # Visualize data if present\n", 90 | " if '[' in problem:\n", 91 | " data = eval(problem[problem.find('['):problem.find(']')+1])\n", 92 | " visualize_distribution(data)\n", 93 | " print(\"-\" * 50 + \"\\n\")" 94 | ] 95 | }, 96 | { 97 | "cell_type": "markdown", 98 | "metadata": {}, 99 | "source": [ 100 | "## 2. Probability Distributions" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": null, 106 | "metadata": {}, 107 | "source": [ 108 | "def plot_probability_distribution(dist_type: str, params: dict):\n", 109 | " \"\"\"Plot various probability distributions.\"\"\"\n", 110 | " plt.figure(figsize=(10, 6))\n", 111 | " \n", 112 | " if dist_type == 'normal':\n", 113 | " x = np.linspace(params['mu'] - 4*params['sigma'],\n", 114 | " params['mu'] + 4*params['sigma'], 100)\n", 115 | " y = stats.norm.pdf(x, params['mu'], params['sigma'])\n", 116 | " plt.plot(x, y, label=f'μ={params[\"mu\"]}, σ={params[\"sigma\"]}')\n", 117 | " plt.title('Normal Distribution')\n", 118 | " \n", 119 | " elif dist_type == 'binomial':\n", 120 | " x = np.arange(0, params['n']+1)\n", 121 | " y = stats.binom.pmf(x, params['n'], params['p'])\n", 122 | " plt.bar(x, y, alpha=0.8, label=f'n={params[\"n\"]}, p={params[\"p\"]}')\n", 123 | " plt.title('Binomial Distribution')\n", 124 | " \n", 125 | " elif dist_type == 'poisson':\n", 126 | " x = np.arange(0, params['lambda']*3)\n", 127 | " y = stats.poisson.pmf(x, params['lambda'])\n", 128 | " plt.bar(x, y, alpha=0.8, label=f'λ={params[\"lambda\"]}')\n", 129 | " plt.title('Poisson Distribution')\n", 130 | " \n", 131 | " plt.grid(True, alpha=0.3)\n", 132 | " plt.legend()\n", 133 | " plt.show()\n", 134 | "\n", 135 | "# Example distributions\n", 136 | "distributions = [\n", 137 | " ('normal', {'mu': 0, 'sigma': 1}),\n", 138 | " ('normal', {'mu': 0, 'sigma': 2}),\n", 139 | " ('binomial', {'n': 20, 'p': 0.3}),\n", 140 | " ('poisson', {'lambda': 3})\n", 141 | "]\n", 142 | "\n", 143 | "for dist_type, params in distributions:\n", 144 | " plot_probability_distribution(dist_type, params)" 145 | ] 146 | }, 147 | { 148 | "cell_type": "markdown", 149 
| "metadata": {}, 150 | "source": [ 151 | "## 3. Hypothesis Testing" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": null, 157 | "metadata": {}, 158 | "source": [ 159 | "def visualize_hypothesis_test(sample1: np.ndarray, sample2: np.ndarray, test_type: str):\n", 160 | " \"\"\"Visualize hypothesis test results.\"\"\"\n", 161 | " plt.figure(figsize=(12, 5))\n", 162 | " \n", 163 | " # Data distribution\n", 164 | " plt.subplot(121)\n", 165 | " plt.boxplot([sample1, sample2], labels=['Sample 1', 'Sample 2'])\n", 166 | " plt.title('Sample Distributions')\n", 167 | " \n", 168 | " # Test results\n", 169 | " if test_type == 't-test':\n", 170 | " t_stat, p_val = stats.ttest_ind(sample1, sample2)\n", 171 | " test_name = \"Student's t-test\"\n", 172 | " elif test_type == 'wilcoxon':\n", 173 | " t_stat, p_val = stats.wilcoxon(sample1, sample2)\n", 174 | " test_name = \"Wilcoxon signed-rank test\"\n", 175 | " \n", 176 | " plt.subplot(122)\n", 177 | " plt.text(0.5, 0.5,\n", 178 | " f\"Test: {test_name}\\n\" +\n", 179 | " f\"Statistic: {t_stat:.4f}\\n\" +\n", 180 | " f\"p-value: {p_val:.4f}\\n\" +\n", 181 | " f\"Significant: {p_val < 0.05}\",\n", 182 | " ha='center', va='center')\n", 183 | " plt.axis('off')\n", 184 | " \n", 185 | " plt.tight_layout()\n", 186 | " plt.show()\n", 187 | "\n", 188 | "# Example hypothesis tests\n", 189 | "np.random.seed(42)\n", 190 | "sample1 = np.random.normal(0, 1, 100)\n", 191 | "sample2 = np.random.normal(0.5, 1, 100)\n", 192 | "\n", 193 | "print(\"Comparing two samples with different means:\")\n", 194 | "visualize_hypothesis_test(sample1, sample2, 't-test')\n", 195 | "\n", 196 | "print(\"\\nNon-parametric test:\")\n", 197 | "visualize_hypothesis_test(sample1, sample2, 'wilcoxon')" 198 | ] 199 | } 200 | ], 201 | "metadata": { 202 | "kernelspec": { 203 | "display_name": "Python 3", 204 | "language": "python", 205 | "name": "python3" 206 | } 207 | }, 208 | "nbformat": 4, 209 | "nbformat_minor": 4 210 | } 211 | -------------------------------------------------------------------------------- /examples/rasa_integration.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example integration with Rasa chatbot framework 3 | """ 4 | from typing import Any, Text, Dict, List 5 | from rasa_sdk import Action, Tracker 6 | from rasa_sdk.executor import CollectingDispatcher 7 | 8 | from src.core.mcts import MCTS 9 | from src.core.ppm import ProcessPreferenceModel 10 | from src.models.model_interface import ModelFactory 11 | 12 | class RStarMathAction(Action): 13 | def __init__(self): 14 | super().__init__() 15 | self.mcts = MCTS.from_config_file("config/default.json") 16 | self.ppm = ProcessPreferenceModel.from_config_file("config/default.json") 17 | self.model = ModelFactory.create_model( 18 | "openai", 19 | "YOUR_API_KEY", 20 | "config/default.json" 21 | ) 22 | 23 | def name(self) -> Text: 24 | return "action_solve_math" 25 | 26 | async def run( 27 | self, 28 | dispatcher: CollectingDispatcher, 29 | tracker: Tracker, 30 | domain: Dict[Text, Any] 31 | ) -> List[Dict[Text, Any]]: 32 | # Get the math problem from user message 33 | problem = tracker.latest_message.get("text") 34 | 35 | try: 36 | # Solve using rStar-Math 37 | action, trajectory = self.mcts.search(problem) 38 | 39 | # Format solution steps 40 | steps = [] 41 | for step in trajectory: 42 | step_text = step["state"] 43 | step_score = self.ppm.evaluate_step(step_text, self.model) 44 | steps.append(f"{step_text} (confidence: {step_score:.2f})") 45 | 46 | # 
Send response 47 | dispatcher.utter_message( 48 | text=f"Here's how I solved it:\n" + "\n".join(steps) 49 | ) 50 | 51 | except Exception as e: 52 | dispatcher.utter_message( 53 | text=f"I encountered an error while solving the problem: {str(e)}" 54 | ) 55 | 56 | return [] 57 | -------------------------------------------------------------------------------- /examples/streamlit_integration.py: -------------------------------------------------------------------------------- 1 | """ 2 | Streamlit integration for rStar-Math 3 | """ 4 | import os 5 | import streamlit as st 6 | import plotly.express as px 7 | import pandas as pd 8 | from src.core.mcts import MCTS 9 | from src.core.ppm import ProcessPreferenceModel 10 | from src.models.model_interface import ModelFactory 11 | 12 | def initialize_components(): 13 | """Initialize rStar-Math components.""" 14 | st.session_state.mcts = MCTS.from_config_file('config/default.json') 15 | st.session_state.ppm = ProcessPreferenceModel.from_config_file('config/default.json') 16 | 17 | # Initialize models 18 | api_keys = { 19 | 'openai': os.getenv('OPENAI_API_KEY'), 20 | 'anthropic': os.getenv('ANTHROPIC_API_KEY'), 21 | 'mistral': os.getenv('MISTRAL_API_KEY'), 22 | 'groq': os.getenv('GROQ_API_KEY'), 23 | 'gemini': os.getenv('GEMINI_API_KEY') 24 | } 25 | 26 | st.session_state.models = {} 27 | for name, key in api_keys.items(): 28 | if key: 29 | st.session_state.models[name] = ModelFactory.create_model( 30 | name, key, 'config/default.json' 31 | ) 32 | 33 | def solve_problem(problem: str, model_name: str, use_rstar: bool = True): 34 | """Solve math problem with selected model.""" 35 | model = st.session_state.models[model_name] 36 | 37 | if use_rstar: 38 | action, trajectory = st.session_state.mcts.search(problem) 39 | solution_steps = [] 40 | confidence_scores = [] 41 | 42 | for step in trajectory: 43 | confidence = st.session_state.ppm.evaluate_step(step['state'], model) 44 | solution_steps.append(step['state']) 45 | confidence_scores.append(confidence) 46 | 47 | return solution_steps, confidence_scores 48 | else: 49 | solution = model.generate_response(problem) 50 | confidence = model.evaluate_reasoning(problem, [solution]) 51 | return [solution], [confidence] 52 | 53 | def main(): 54 | st.title("rStar-Math Demonstrator") 55 | 56 | # Initialize components if not done 57 | if 'mcts' not in st.session_state: 58 | initialize_components() 59 | 60 | # Sidebar settings 61 | st.sidebar.header("Settings") 62 | selected_model = st.sidebar.selectbox( 63 | "Select Model", 64 | options=list(st.session_state.models.keys()) 65 | ) 66 | 67 | use_rstar = st.sidebar.checkbox("Use rStar-Math", value=True) 68 | 69 | # Main interface 70 | st.header("Math Problem Solver") 71 | 72 | # Input section 73 | problem = st.text_area("Enter your math problem:") 74 | 75 | if st.button("Solve"): 76 | if problem: 77 | with st.spinner("Solving..."): 78 | solution_steps, confidence_scores = solve_problem( 79 | problem, selected_model, use_rstar 80 | ) 81 | 82 | # Display solution 83 | st.header("Solution") 84 | for i, (step, confidence) in enumerate(zip(solution_steps, confidence_scores), 1): 85 | st.markdown(f"**Step {i}:** {step}") 86 | st.progress(confidence) 87 | st.markdown(f"Confidence: {confidence:.2f}") 88 | 89 | # Plot confidence trend 90 | if len(confidence_scores) > 1: 91 | df = pd.DataFrame({ 92 | 'Step': range(1, len(confidence_scores) + 1), 93 | 'Confidence': confidence_scores 94 | }) 95 | fig = px.line(df, x='Step', y='Confidence', 96 | title='Solution Confidence Trend') 97 | 
st.plotly_chart(fig) 98 | 99 | # Overall statistics 100 | st.header("Solution Statistics") 101 | col1, col2, col3 = st.columns(3) 102 | col1.metric("Steps", len(solution_steps)) 103 | col2.metric("Average Confidence", f"{sum(confidence_scores)/len(confidence_scores):.2f}") 104 | col3.metric("Min Confidence", f"{min(confidence_scores):.2f}") 105 | else: 106 | st.error("Please enter a math problem") 107 | 108 | # Example problems 109 | st.header("Example Problems") 110 | examples = { 111 | "Basic Arithmetic": "What is 15 × 27?", 112 | "Algebra": "Solve for x: 2x + 5 = 13", 113 | "Calculus": "Find the derivative of f(x) = x² + 3x", 114 | "Geometry": "Find the area of a circle with radius 5", 115 | "Word Problem": "If a train travels 60 mph for 2.5 hours, how far does it go?" 116 | } 117 | 118 | if st.button("Try an Example"): 119 | example = st.selectbox("Select an example:", list(examples.keys())) 120 | st.text_area("Problem", examples[example], key="example_problem") 121 | 122 | if __name__ == "__main__": 123 | main() 124 | -------------------------------------------------------------------------------- /rStar-Math Paper.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai-in-pm/rStar-Math/0a20d31f970b9f90e671e0e7af18b4d738915e8d/rStar-Math Paper.pdf -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.24.0 2 | pandas>=2.0.0 3 | torch>=2.0.0 4 | transformers>=4.30.0 5 | openai>=1.0.0 6 | anthropic>=0.3.0 7 | langchain>=0.1.0 8 | fastapi>=0.100.0 9 | uvicorn>=0.22.0 10 | python-dotenv>=1.0.0 11 | matplotlib>=3.7.0 12 | plotly>=5.14.0 13 | dash>=2.10.0 14 | pytest>=7.4.0 15 | black>=23.3.0 16 | isort>=5.12.0 17 | mypy>=1.3.0 18 | ruff>=0.0.270 19 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | rStar-Math Demonstrator 3 | A framework for enhanced mathematical reasoning using various LLMs 4 | """ 5 | 6 | __version__ = "0.1.0" 7 | -------------------------------------------------------------------------------- /src/api/main.py: -------------------------------------------------------------------------------- 1 | """ 2 | FastAPI backend for rStar-Math Demonstrator 3 | """ 4 | from fastapi import FastAPI, HTTPException 5 | from pydantic import BaseModel 6 | from typing import List, Dict, Any, Optional 7 | import uvicorn 8 | 9 | from src.core.mcts import MCTS 10 | from src.core.ppm import ProcessPreferenceModel 11 | from src.models.model_interface import OpenAIModel, AnthropicModel 12 | 13 | app = FastAPI(title="rStar-Math Demonstrator API") 14 | 15 | class MathProblem(BaseModel): 16 | problem_text: str 17 | model_name: str 18 | use_rstar: bool = True 19 | mcts_simulations: int = 1000 20 | temperature: float = 0.7 21 | 22 | class SolutionResponse(BaseModel): 23 | solution_steps: List[str] 24 | confidence_score: float 25 | reasoning_path: List[Dict[str, Any]] 26 | execution_time: float 27 | 28 | @app.post("/solve", response_model=SolutionResponse) 29 | async def solve_problem(problem: MathProblem) -> SolutionResponse: 30 | """ 31 | Solve a mathematical problem using specified LLM with or without rStar-Math enhancement. 
32 | """ 33 | try: 34 | # Implementation details to be added 35 | pass 36 | except Exception as e: 37 | raise HTTPException(status_code=500, detail=str(e)) 38 | 39 | @app.post("/compare-models") 40 | async def compare_models(problem: MathProblem) -> Dict[str, Any]: 41 | """ 42 | Compare solutions across different LLMs. 43 | """ 44 | try: 45 | # Implementation details to be added 46 | pass 47 | except Exception as e: 48 | raise HTTPException(status_code=500, detail=str(e)) 49 | 50 | @app.post("/generate-integration-code") 51 | async def generate_integration_code( 52 | framework: str, 53 | config: Dict[str, Any] 54 | ) -> Dict[str, str]: 55 | """ 56 | Generate integration code for specified framework. 57 | """ 58 | try: 59 | # Implementation details to be added 60 | pass 61 | except Exception as e: 62 | raise HTTPException(status_code=500, detail=str(e)) 63 | 64 | if __name__ == "__main__": 65 | uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True) 66 | -------------------------------------------------------------------------------- /src/core/mcts.py: -------------------------------------------------------------------------------- 1 | """ 2 | Monte Carlo Tree Search (MCTS) implementation for mathematical reasoning 3 | """ 4 | from typing import List, Dict, Any, Optional, Tuple 5 | import numpy as np 6 | from dataclasses import dataclass 7 | import json 8 | from pathlib import Path 9 | 10 | @dataclass 11 | class MCTSConfig: 12 | exploration_weight: float = 1.0 13 | max_simulations: int = 1000 14 | max_depth: int = 10 15 | 16 | class MCTSNode: 17 | def __init__(self, state: str, parent: Optional['MCTSNode'] = None, action: Optional[str] = None): 18 | self.state = state 19 | self.parent = parent 20 | self.action = action 21 | self.children: List[MCTSNode] = [] 22 | self.visits = 0 23 | self.value = 0.0 24 | self.untried_actions: List[str] = [] 25 | 26 | def add_child(self, action: str, state: str) -> 'MCTSNode': 27 | """Add a child node with the given action and state.""" 28 | child = MCTSNode(state=state, parent=self, action=action) 29 | self.children.append(child) 30 | if action in self.untried_actions: 31 | self.untried_actions.remove(action) 32 | return child 33 | 34 | def update(self, reward: float) -> None: 35 | """Update node statistics with new reward.""" 36 | self.visits += 1 37 | self.value += (reward - self.value) / self.visits 38 | 39 | def get_ucb_score(self, exploration_weight: float) -> float: 40 | """Calculate UCB1 score for this node.""" 41 | if self.visits == 0: 42 | return float('inf') 43 | exploitation = self.value / self.visits 44 | exploration = exploration_weight * np.sqrt(2 * np.log(self.parent.visits) / self.visits) 45 | return exploitation + exploration 46 | 47 | def is_terminal(self) -> bool: 48 | """Check if this node represents a terminal state.""" 49 | # Implementation depends on problem domain 50 | return False 51 | 52 | def get_possible_actions(self) -> List[str]: 53 | """Get list of possible actions from this state.""" 54 | # Implementation depends on problem domain 55 | return [] 56 | 57 | class MCTS: 58 | def __init__(self, config: Optional[MCTSConfig] = None): 59 | self.config = config or MCTSConfig() 60 | 61 | @classmethod 62 | def from_config_file(cls, config_path: str) -> 'MCTS': 63 | """Create MCTS instance from config file.""" 64 | with open(config_path, 'r') as f: 65 | config_data = json.load(f) 66 | config = MCTSConfig(**config_data['mcts']) 67 | return cls(config) 68 | 69 | def select_action(self, node: MCTSNode) -> Tuple[MCTSNode, str]: 70 
| """Select the best child node using UCB1.""" 71 | if not node.children: 72 | return node, "" 73 | 74 | ucb_scores = [ 75 | child.get_ucb_score(self.config.exploration_weight) 76 | for child in node.children 77 | ] 78 | selected_child = node.children[np.argmax(ucb_scores)] 79 | return selected_child, selected_child.action 80 | 81 | def expand(self, node: MCTSNode) -> Tuple[MCTSNode, str]: 82 | """Expand the current node with a new child.""" 83 | if not node.untried_actions: 84 | node.untried_actions = node.get_possible_actions() 85 | 86 | if not node.untried_actions: 87 | return node, "" 88 | 89 | action = np.random.choice(node.untried_actions) 90 | new_state = self.apply_action(node.state, action) 91 | child = node.add_child(action, new_state) 92 | return child, action 93 | 94 | def simulate(self, state: str, depth: int = 0) -> float: 95 | """Run a simulation from the current state.""" 96 | if depth >= self.config.max_depth or self.is_terminal_state(state): 97 | return self.evaluate_state(state) 98 | 99 | actions = self.get_possible_actions(state) 100 | if not actions: 101 | return self.evaluate_state(state) 102 | 103 | action = np.random.choice(actions) 104 | new_state = self.apply_action(state, action) 105 | return self.simulate(new_state, depth + 1) 106 | 107 | def backpropagate(self, node: MCTSNode, reward: float) -> None: 108 | """Update the values up the tree.""" 109 | while node is not None: 110 | node.update(reward) 111 | node = node.parent 112 | 113 | def search(self, root_state: str) -> Tuple[str, List[Dict[str, Any]]]: 114 | """Perform MCTS search to find the best action sequence.""" 115 | root = MCTSNode(state=root_state) 116 | trajectory = [] 117 | 118 | for _ in range(self.config.max_simulations): 119 | node = root 120 | 121 | # Selection 122 | while node.children and not node.untried_actions: 123 | node, action = self.select_action(node) 124 | trajectory.append({ 125 | "state": node.state, 126 | "action": action, 127 | "value": node.value, 128 | "visits": node.visits 129 | }) 130 | 131 | # Expansion 132 | if not node.is_terminal(): 133 | node, action = self.expand(node) 134 | trajectory.append({ 135 | "state": node.state, 136 | "action": action, 137 | "value": node.value, 138 | "visits": node.visits 139 | }) 140 | 141 | # Simulation 142 | reward = self.simulate(node.state) 143 | 144 | # Backpropagation 145 | self.backpropagate(node, reward) 146 | 147 | # Return best action sequence 148 | best_child = max(root.children, key=lambda c: c.visits) 149 | return best_child.action, trajectory 150 | 151 | def apply_action(self, state: str, action: str) -> str: 152 | """Apply an action to a state to get the next state.""" 153 | # Implementation depends on problem domain 154 | pass 155 | 156 | def evaluate_state(self, state: str) -> float: 157 | """Evaluate the value of a terminal state.""" 158 | # Implementation depends on problem domain 159 | pass 160 | 161 | def is_terminal_state(self, state: str) -> bool: 162 | """Check if a state is terminal.""" 163 | # Implementation depends on problem domain 164 | pass 165 | 166 | def get_possible_actions(self, state: str) -> List[str]: 167 | """Get possible actions for a state.""" 168 | # Implementation depends on problem domain 169 | pass 170 | -------------------------------------------------------------------------------- /src/core/ppm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Process Preference Model (PPM) for evaluating reasoning steps 3 | """ 4 | from typing import List, Dict, Any, 
Optional, Tuple 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from dataclasses import dataclass 9 | import json 10 | from pathlib import Path 11 | 12 | @dataclass 13 | class PPMConfig: 14 | input_dim: int = 768 15 | hidden_dim: int = 256 16 | learning_rate: float = 0.001 17 | batch_size: int = 32 18 | 19 | class StepEncoder(nn.Module): 20 | """Encodes reasoning steps into fixed-size vectors.""" 21 | def __init__(self, input_dim: int, hidden_dim: int): 22 | super().__init__() 23 | self.encoder = nn.Sequential( 24 | nn.Linear(input_dim, hidden_dim), 25 | nn.LayerNorm(hidden_dim), 26 | nn.ReLU(), 27 | nn.Linear(hidden_dim, hidden_dim), 28 | nn.LayerNorm(hidden_dim), 29 | nn.ReLU() 30 | ) 31 | 32 | def forward(self, x: torch.Tensor) -> torch.Tensor: 33 | return self.encoder(x) 34 | 35 | class ProcessPreferenceModel(nn.Module): 36 | def __init__(self, config: Optional[PPMConfig] = None): 37 | super().__init__() 38 | self.config = config or PPMConfig() 39 | 40 | # Step encoder 41 | self.step_encoder = StepEncoder( 42 | self.config.input_dim, 43 | self.config.hidden_dim 44 | ) 45 | 46 | # Value head 47 | self.value_head = nn.Sequential( 48 | nn.Linear(self.config.hidden_dim, self.config.hidden_dim // 2), 49 | nn.LayerNorm(self.config.hidden_dim // 2), 50 | nn.ReLU(), 51 | nn.Linear(self.config.hidden_dim // 2, 1) 52 | ) 53 | 54 | # Initialize optimizer 55 | self.optimizer = torch.optim.Adam( 56 | self.parameters(), 57 | lr=self.config.learning_rate 58 | ) 59 | 60 | @classmethod 61 | def from_config_file(cls, config_path: str) -> 'ProcessPreferenceModel': 62 | """Create PPM instance from config file.""" 63 | with open(config_path, 'r') as f: 64 | config_data = json.load(f) 65 | config = PPMConfig(**config_data['ppm']) 66 | return cls(config) 67 | 68 | def forward(self, x: torch.Tensor) -> torch.Tensor: 69 | """Forward pass through the model.""" 70 | encoded = self.step_encoder(x) 71 | value = self.value_head(encoded) 72 | return value 73 | 74 | def evaluate_step(self, step: str, embedder: Any) -> float: 75 | """Evaluate a single reasoning step.""" 76 | self.eval() 77 | with torch.no_grad(): 78 | # Convert step to embedding using provided embedder 79 | embedding = embedder.encode(step) 80 | embedding_tensor = torch.FloatTensor(embedding).unsqueeze(0) 81 | 82 | # Get value prediction 83 | value = self(embedding_tensor) 84 | return value.item() 85 | 86 | def train_step(self, 87 | preferred_steps: List[str], 88 | non_preferred_steps: List[str], 89 | embedder: Any) -> float: 90 | """Train the model on a batch of preferred vs non-preferred steps.""" 91 | self.train() 92 | 93 | # Convert steps to embeddings 94 | preferred_embeddings = torch.FloatTensor([ 95 | embedder.encode(step) for step in preferred_steps 96 | ]) 97 | non_preferred_embeddings = torch.FloatTensor([ 98 | embedder.encode(step) for step in non_preferred_steps 99 | ]) 100 | 101 | # Get value predictions 102 | preferred_values = self(preferred_embeddings) 103 | non_preferred_values = self(non_preferred_embeddings) 104 | 105 | # Compute preference loss (preferred steps should have higher values) 106 | loss = F.relu(non_preferred_values - preferred_values + 1.0).mean() 107 | 108 | # Backpropagation 109 | self.optimizer.zero_grad() 110 | loss.backward() 111 | self.optimizer.step() 112 | 113 | return loss.item() 114 | 115 | def save_model(self, path: str) -> None: 116 | """Save model state to file.""" 117 | torch.save({ 118 | 'model_state_dict': self.state_dict(), 119 | 'optimizer_state_dict': 
self.optimizer.state_dict(), 120 | 'config': self.config 121 | }, path) 122 | 123 | def load_model(self, path: str) -> None: 124 | """Load model state from file.""" 125 | checkpoint = torch.load(path) 126 | self.load_state_dict(checkpoint['model_state_dict']) 127 | self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 128 | self.config = checkpoint['config'] 129 | 130 | class PPMTrainer: 131 | def __init__(self, model: ProcessPreferenceModel): 132 | self.model = model 133 | 134 | def train(self, 135 | training_data: List[Dict[str, Any]], 136 | embedder: Any, 137 | num_epochs: int = 100, 138 | validation_data: Optional[List[Dict[str, Any]]] = None) -> Dict[str, List[float]]: 139 | """Train the PPM model on a dataset of preferred reasoning trajectories.""" 140 | train_losses = [] 141 | val_losses = [] 142 | 143 | for epoch in range(num_epochs): 144 | # Training 145 | epoch_loss = 0.0 146 | for batch in self._create_batches(training_data, self.model.config.batch_size): 147 | loss = self.model.train_step( 148 | batch['preferred'], 149 | batch['non_preferred'], 150 | embedder 151 | ) 152 | epoch_loss += loss 153 | train_losses.append(epoch_loss / len(training_data)) 154 | 155 | # Validation 156 | if validation_data: 157 | val_loss = self._validate(validation_data, embedder) 158 | val_losses.append(val_loss) 159 | 160 | return { 161 | 'train_losses': train_losses, 162 | 'val_losses': val_losses if validation_data else [] 163 | } 164 | 165 | def _create_batches(self, 166 | data: List[Dict[str, Any]], 167 | batch_size: int) -> List[Dict[str, List[str]]]: 168 | """Create batches from training data.""" 169 | batches = [] 170 | for i in range(0, len(data), batch_size): 171 | batch_data = data[i:i + batch_size] 172 | batch = { 173 | 'preferred': [d['preferred'] for d in batch_data], 174 | 'non_preferred': [d['non_preferred'] for d in batch_data] 175 | } 176 | batches.append(batch) 177 | return batches 178 | 179 | def _validate(self, 180 | validation_data: List[Dict[str, Any]], 181 | embedder: Any) -> float: 182 | """Compute validation loss.""" 183 | self.model.eval() 184 | total_loss = 0.0 185 | 186 | with torch.no_grad(): 187 | for batch in self._create_batches(validation_data, self.model.config.batch_size): 188 | preferred_embeddings = torch.FloatTensor([ 189 | embedder.encode(step) for step in batch['preferred'] 190 | ]) 191 | non_preferred_embeddings = torch.FloatTensor([ 192 | embedder.encode(step) for step in batch['non_preferred'] 193 | ]) 194 | 195 | preferred_values = self.model(preferred_embeddings) 196 | non_preferred_values = self.model(non_preferred_embeddings) 197 | 198 | loss = F.relu(non_preferred_values - preferred_values + 1.0).mean() 199 | total_loss += loss.item() 200 | 201 | return total_loss / len(validation_data) 202 | -------------------------------------------------------------------------------- /src/dashboard/app.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dash application for visualizing rStar-Math performance 3 | """ 4 | import dash 5 | from dash import dcc, html 6 | from dash.dependencies import Input, Output, State 7 | import plotly.graph_objs as go 8 | import plotly.express as px 9 | from typing import List, Dict, Any 10 | 11 | app = dash.Dash(__name__) 12 | 13 | app.layout = html.Div([ 14 | html.H1("rStar-Math Demonstrator Dashboard"), 15 | 16 | # Problem Input Section 17 | html.Div([ 18 | html.H2("Problem Input"), 19 | dcc.Textarea( 20 | id='problem-input', 21 | placeholder='Enter your mathematical 
problem here...', 22 | style={'width': '100%', 'height': 100} 23 | ), 24 | dcc.Dropdown( 25 | id='model-selector', 26 | options=[ 27 | {'label': 'OpenAI GPT-4', 'value': 'gpt4'}, 28 | {'label': 'Anthropic Claude', 'value': 'claude'}, 29 | {'label': 'Mistral', 'value': 'mistral'}, 30 | {'label': 'Groq', 'value': 'groq'}, 31 | {'label': 'Gemini', 'value': 'gemini'}, 32 | {'label': 'Cohere', 'value': 'cohere'}, 33 | {'label': 'Emergence', 'value': 'emergence'} 34 | ], 35 | multi=True, 36 | placeholder="Select models to compare" 37 | ), 38 | html.Button('Solve', id='solve-button', n_clicks=0) 39 | ]), 40 | 41 | # Results Section 42 | html.Div([ 43 | html.H2("Results"), 44 | dcc.Graph(id='performance-comparison'), 45 | html.Div(id='solution-steps') 46 | ]), 47 | 48 | # Code Generation Section 49 | html.Div([ 50 | html.H2("Integration Code Generator"), 51 | dcc.Dropdown( 52 | id='framework-selector', 53 | options=[ 54 | {'label': 'Rasa', 'value': 'rasa'}, 55 | {'label': 'Azure Bot Framework', 'value': 'azure'}, 56 | {'label': 'LangChain', 'value': 'langchain'} 57 | ], 58 | placeholder="Select framework for integration" 59 | ), 60 | html.Button('Generate Code', id='generate-code-button', n_clicks=0), 61 | dcc.Markdown(id='generated-code') 62 | ]) 63 | ]) 64 | 65 | @app.callback( 66 | [Output('performance-comparison', 'figure'), 67 | Output('solution-steps', 'children')], 68 | [Input('solve-button', 'n_clicks')], 69 | [State('problem-input', 'value'), 70 | State('model-selector', 'value')] 71 | ) 72 | def update_results(n_clicks: int, 73 | problem: str, 74 | selected_models: List[str]) -> tuple: 75 | """Update dashboard with problem solution results.""" 76 | if n_clicks == 0: 77 | return {}, [] 78 | 79 | # Implementation details to be added 80 | pass 81 | 82 | @app.callback( 83 | Output('generated-code', 'children'), 84 | [Input('generate-code-button', 'n_clicks')], 85 | [State('framework-selector', 'value')] 86 | ) 87 | def generate_code(n_clicks: int, framework: str) -> str: 88 | """Generate integration code for selected framework.""" 89 | if n_clicks == 0: 90 | return "" 91 | 92 | # Implementation details to be added 93 | pass 94 | 95 | if __name__ == '__main__': 96 | app.run_server(debug=True) 97 | -------------------------------------------------------------------------------- /src/models/gemini_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Google Gemini model implementation 3 | """ 4 | from typing import List, Dict, Any, Optional 5 | import google.generativeai as genai 6 | from .model_interface import LLMInterface, ModelConfig 7 | 8 | class GeminiModel(LLMInterface): 9 | def __init__(self, api_key: str, config: Optional[ModelConfig] = None): 10 | self.api_key = api_key 11 | genai.configure(api_key=api_key) 12 | self.config = config or ModelConfig(model="gemini-pro") 13 | self.model = genai.GenerativeModel(self.config.model) 14 | 15 | @classmethod 16 | def from_config_file(cls, config_path: str, api_key: str) -> 'GeminiModel': 17 | """Create Gemini model instance from config file.""" 18 | with open(config_path, 'r') as f: 19 | config_data = json.load(f) 20 | config = ModelConfig(**config_data['models']['gemini']) 21 | return cls(api_key, config) 22 | 23 | def generate_response(self, 24 | prompt: str, 25 | temperature: Optional[float] = None, 26 | max_tokens: Optional[int] = None) -> str: 27 | """Generate a response using Gemini API.""" 28 | response = self.model.generate_content( 29 | prompt, 30 | 
generation_config=genai.types.GenerationConfig( 31 | temperature=temperature or self.config.temperature, 32 | max_output_tokens=max_tokens or self.config.max_tokens 33 | ) 34 | ) 35 | 36 | return response.text 37 | 38 | def evaluate_reasoning(self, 39 | problem: str, 40 | solution_steps: List[str]) -> float: 41 | """Evaluate reasoning steps using Gemini.""" 42 | prompt = f""" 43 | Problem: {problem} 44 | Solution Steps: 45 | {chr(10).join(f'{i+1}. {step}' for i, step in enumerate(solution_steps))} 46 | 47 | Rate the quality of these solution steps from 0 to 1, where: 48 | 0 = completely incorrect or invalid reasoning 49 | 1 = perfect, clear, and mathematically sound reasoning 50 | 51 | Provide only the numerical rating. 52 | """ 53 | 54 | response = self.generate_response(prompt) 55 | try: 56 | rating = float(response.strip()) 57 | return max(0.0, min(1.0, rating)) 58 | except ValueError: 59 | return 0.0 60 | 61 | def embed_text(self, text: str) -> List[float]: 62 | """Generate embeddings using Gemini API.""" 63 | model = genai.GenerativeModel('embedding-001') 64 | result = model.embed_content(text=text) 65 | return result.embedding 66 | -------------------------------------------------------------------------------- /src/models/groq_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Groq model implementation 3 | """ 4 | from typing import List, Dict, Any, Optional 5 | import groq 6 | from .model_interface import LLMInterface, ModelConfig 7 | 8 | class GroqModel(LLMInterface): 9 | def __init__(self, api_key: str, config: Optional[ModelConfig] = None): 10 | self.api_key = api_key 11 | self.client = groq.Client(api_key=api_key) 12 | self.config = config or ModelConfig(model="mixtral-8x7b-32768") 13 | 14 | @classmethod 15 | def from_config_file(cls, config_path: str, api_key: str) -> 'GroqModel': 16 | """Create Groq model instance from config file.""" 17 | with open(config_path, 'r') as f: 18 | config_data = json.load(f) 19 | config = ModelConfig(**config_data['models']['groq']) 20 | return cls(api_key, config) 21 | 22 | def generate_response(self, 23 | prompt: str, 24 | temperature: Optional[float] = None, 25 | max_tokens: Optional[int] = None) -> str: 26 | """Generate a response using Groq API.""" 27 | response = self.client.chat.completions.create( 28 | model=self.config.model, 29 | messages=[{"role": "user", "content": prompt}], 30 | temperature=temperature or self.config.temperature, 31 | max_tokens=max_tokens or self.config.max_tokens 32 | ) 33 | 34 | return response.choices[0].message.content 35 | 36 | def evaluate_reasoning(self, 37 | problem: str, 38 | solution_steps: List[str]) -> float: 39 | """Evaluate reasoning steps using Groq.""" 40 | prompt = f""" 41 | Problem: {problem} 42 | Solution Steps: 43 | {chr(10).join(f'{i+1}. {step}' for i, step in enumerate(solution_steps))} 44 | 45 | Rate the quality of these solution steps from 0 to 1, where: 46 | 0 = completely incorrect or invalid reasoning 47 | 1 = perfect, clear, and mathematically sound reasoning 48 | 49 | Provide only the numerical rating. 
50 | """ 51 | 52 | response = self.generate_response(prompt) 53 | try: 54 | rating = float(response.strip()) 55 | return max(0.0, min(1.0, rating)) 56 | except ValueError: 57 | return 0.0 58 | 59 | def embed_text(self, text: str) -> List[float]: 60 | """Generate embeddings using Groq API.""" 61 | # Note: Groq doesn't provide embedding API yet 62 | # For now, we'll use OpenAI's embedding API as a fallback 63 | from .model_interface import OpenAIModel 64 | openai_model = OpenAIModel(self.api_key) 65 | return openai_model.embed_text(text) 66 | -------------------------------------------------------------------------------- /src/models/mistral_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Mistral AI model implementation 3 | """ 4 | from typing import List, Dict, Any, Optional 5 | import mistralai 6 | from mistralai.client import MistralClient 7 | from mistralai.models.chat_completion import ChatMessage 8 | from .model_interface import LLMInterface, ModelConfig 9 | 10 | class MistralModel(LLMInterface): 11 | def __init__(self, api_key: str, config: Optional[ModelConfig] = None): 12 | self.api_key = api_key 13 | self.client = MistralClient(api_key=api_key) 14 | self.config = config or ModelConfig(model="mistral-large-latest") 15 | 16 | @classmethod 17 | def from_config_file(cls, config_path: str, api_key: str) -> 'MistralModel': 18 | """Create Mistral model instance from config file.""" 19 | with open(config_path, 'r') as f: 20 | config_data = json.load(f) 21 | config = ModelConfig(**config_data['models']['mistral']) 22 | return cls(api_key, config) 23 | 24 | def generate_response(self, 25 | prompt: str, 26 | temperature: Optional[float] = None, 27 | max_tokens: Optional[int] = None) -> str: 28 | """Generate a response using Mistral API.""" 29 | messages = [ 30 | ChatMessage(role="user", content=prompt) 31 | ] 32 | 33 | response = self.client.chat( 34 | model=self.config.model, 35 | messages=messages, 36 | temperature=temperature or self.config.temperature, 37 | max_tokens=max_tokens or self.config.max_tokens 38 | ) 39 | 40 | return response.choices[0].message.content 41 | 42 | def evaluate_reasoning(self, 43 | problem: str, 44 | solution_steps: List[str]) -> float: 45 | """Evaluate reasoning steps using Mistral.""" 46 | prompt = f""" 47 | Problem: {problem} 48 | Solution Steps: 49 | {chr(10).join(f'{i+1}. {step}' for i, step in enumerate(solution_steps))} 50 | 51 | Rate the quality of these solution steps from 0 to 1, where: 52 | 0 = completely incorrect or invalid reasoning 53 | 1 = perfect, clear, and mathematically sound reasoning 54 | 55 | Provide only the numerical rating. 
56 | """ 57 | 58 | response = self.generate_response(prompt) 59 | try: 60 | rating = float(response.strip()) 61 | return max(0.0, min(1.0, rating)) 62 | except ValueError: 63 | return 0.0 64 | 65 | def embed_text(self, text: str) -> List[float]: 66 | """Generate embeddings using Mistral API.""" 67 | response = self.client.embeddings( 68 | model="mistral-embed", 69 | input=text 70 | ) 71 | return response.data[0].embedding 72 | -------------------------------------------------------------------------------- /src/models/model_interface.py: -------------------------------------------------------------------------------- 1 | """ 2 | Interface for different LLM implementations 3 | """ 4 | from abc import ABC, abstractmethod 5 | from typing import List, Dict, Any, Optional 6 | import json 7 | from pathlib import Path 8 | import openai 9 | import anthropic 10 | from dataclasses import dataclass 11 | 12 | @dataclass 13 | class ModelConfig: 14 | model: str 15 | temperature: float = 0.7 16 | max_tokens: int = 1000 17 | 18 | class LLMInterface(ABC): 19 | @abstractmethod 20 | def generate_response(self, 21 | prompt: str, 22 | temperature: float = 0.7, 23 | max_tokens: int = 1000) -> str: 24 | """Generate a response from the LLM.""" 25 | pass 26 | 27 | @abstractmethod 28 | def evaluate_reasoning(self, 29 | problem: str, 30 | solution_steps: List[str]) -> float: 31 | """Evaluate the quality of reasoning steps.""" 32 | pass 33 | 34 | @abstractmethod 35 | def embed_text(self, text: str) -> List[float]: 36 | """Generate embeddings for text.""" 37 | pass 38 | 39 | class OpenAIModel(LLMInterface): 40 | def __init__(self, api_key: str, config: Optional[ModelConfig] = None): 41 | self.api_key = api_key 42 | openai.api_key = api_key 43 | self.config = config or ModelConfig(model="gpt-4") 44 | 45 | @classmethod 46 | def from_config_file(cls, config_path: str, api_key: str) -> 'OpenAIModel': 47 | """Create OpenAI model instance from config file.""" 48 | with open(config_path, 'r') as f: 49 | config_data = json.load(f) 50 | config = ModelConfig(**config_data['models']['openai']) 51 | return cls(api_key, config) 52 | 53 | def generate_response(self, 54 | prompt: str, 55 | temperature: Optional[float] = None, 56 | max_tokens: Optional[int] = None) -> str: 57 | """Generate a response using OpenAI API.""" 58 | response = openai.ChatCompletion.create( 59 | model=self.config.model, 60 | messages=[{"role": "user", "content": prompt}], 61 | temperature=temperature or self.config.temperature, 62 | max_tokens=max_tokens or self.config.max_tokens 63 | ) 64 | return response.choices[0].message.content 65 | 66 | def evaluate_reasoning(self, 67 | problem: str, 68 | solution_steps: List[str]) -> float: 69 | """Evaluate reasoning steps using OpenAI.""" 70 | prompt = f""" 71 | Problem: {problem} 72 | Solution Steps: 73 | {chr(10).join(f'{i+1}. {step}' for i, step in enumerate(solution_steps))} 74 | 75 | Rate the quality of these solution steps from 0 to 1, where: 76 | 0 = completely incorrect or invalid reasoning 77 | 1 = perfect, clear, and mathematically sound reasoning 78 | 79 | Provide only the numerical rating. 
80 | """ 81 | 82 | response = self.generate_response(prompt) 83 | try: 84 | rating = float(response.strip()) 85 | return max(0.0, min(1.0, rating)) 86 | except ValueError: 87 | return 0.0 88 | 89 | def embed_text(self, text: str) -> List[float]: 90 | """Generate embeddings using OpenAI API.""" 91 | response = openai.Embedding.create( 92 | model="text-embedding-ada-002", 93 | input=text 94 | ) 95 | return response.data[0].embedding 96 | 97 | class AnthropicModel(LLMInterface): 98 | def __init__(self, api_key: str, config: Optional[ModelConfig] = None): 99 | self.api_key = api_key 100 | self.client = anthropic.Client(api_key=api_key) 101 | self.config = config or ModelConfig(model="claude-2") 102 | 103 | @classmethod 104 | def from_config_file(cls, config_path: str, api_key: str) -> 'AnthropicModel': 105 | """Create Anthropic model instance from config file.""" 106 | with open(config_path, 'r') as f: 107 | config_data = json.load(f) 108 | config = ModelConfig(**config_data['models']['anthropic']) 109 | return cls(api_key, config) 110 | 111 | def generate_response(self, 112 | prompt: str, 113 | temperature: Optional[float] = None, 114 | max_tokens: Optional[int] = None) -> str: 115 | """Generate a response using Anthropic API.""" 116 | response = self.client.messages.create( 117 | model=self.config.model, 118 | messages=[{"role": "user", "content": prompt}], 119 | temperature=temperature or self.config.temperature, 120 | max_tokens=max_tokens or self.config.max_tokens 121 | ) 122 | return response.content 123 | 124 | def evaluate_reasoning(self, 125 | problem: str, 126 | solution_steps: List[str]) -> float: 127 | """Evaluate reasoning steps using Anthropic.""" 128 | prompt = f""" 129 | Problem: {problem} 130 | Solution Steps: 131 | {chr(10).join(f'{i+1}. {step}' for i, step in enumerate(solution_steps))} 132 | 133 | Rate the quality of these solution steps from 0 to 1, where: 134 | 0 = completely incorrect or invalid reasoning 135 | 1 = perfect, clear, and mathematically sound reasoning 136 | 137 | Provide only the numerical rating. 
138 | """ 139 | 140 | response = self.generate_response(prompt) 141 | try: 142 | rating = float(response.strip()) 143 | return max(0.0, min(1.0, rating)) 144 | except ValueError: 145 | return 0.0 146 | 147 | def embed_text(self, text: str) -> List[float]: 148 | """Generate embeddings using Anthropic API.""" 149 | # Note: Anthropic doesn't provide embedding API yet 150 | # For now, we'll use OpenAI's embedding API as a fallback 151 | openai_model = OpenAIModel(self.api_key) 152 | return openai_model.embed_text(text) 153 | 154 | # Additional model implementations (Mistral, Groq, Gemini, Cohere, Emergence) 155 | # would follow similar pattern 156 | 157 | class ModelFactory: 158 | """Factory for creating LLM instances.""" 159 | 160 | @staticmethod 161 | def create_model(model_type: str, 162 | api_key: str, 163 | config_path: Optional[str] = None) -> LLMInterface: 164 | """Create a model instance based on type.""" 165 | if config_path: 166 | if model_type.lower() == "openai": 167 | return OpenAIModel.from_config_file(config_path, api_key) 168 | elif model_type.lower() == "anthropic": 169 | return AnthropicModel.from_config_file(config_path, api_key) 170 | # Add other model types here 171 | else: 172 | if model_type.lower() == "openai": 173 | return OpenAIModel(api_key) 174 | elif model_type.lower() == "anthropic": 175 | return AnthropicModel(api_key) 176 | # Add other model types here 177 | 178 | raise ValueError(f"Unknown model type: {model_type}") 179 | -------------------------------------------------------------------------------- /src/utils/helpers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility functions for rStar-Math Demonstrator 3 | """ 4 | from typing import List, Dict, Any, Optional 5 | import json 6 | import time 7 | from pathlib import Path 8 | 9 | def load_config(config_path: str) -> Dict[str, Any]: 10 | """Load configuration from JSON file.""" 11 | with open(config_path, 'r') as f: 12 | return json.load(f) 13 | 14 | def format_solution_steps(steps: List[str]) -> str: 15 | """Format solution steps for display.""" 16 | return "\n".join(f"{i+1}. 
{step}" for i, step in enumerate(steps)) 17 | 18 | def calculate_metrics(predictions: List[float], 19 | targets: List[float]) -> Dict[str, float]: 20 | """Calculate performance metrics.""" 21 | if not predictions or not targets: 22 | return {} 23 | 24 | accuracy = sum(1 for p, t in zip(predictions, targets) if abs(p - t) < 1e-6) / len(predictions) 25 | mse = sum((p - t) ** 2 for p, t in zip(predictions, targets)) / len(predictions) 26 | 27 | return { 28 | "accuracy": accuracy, 29 | "mse": mse 30 | } 31 | 32 | def generate_framework_template(framework: str, 33 | config: Dict[str, Any]) -> str: 34 | """Generate integration code template for specified framework.""" 35 | templates = { 36 | "rasa": """ 37 | from rasa_sdk import Action 38 | from src.core.mcts import MCTS 39 | from src.core.ppm import ProcessPreferenceModel 40 | 41 | class RStarMathAction(Action): 42 | def name(self) -> str: 43 | return "action_solve_math" 44 | 45 | def run(self, dispatcher, tracker, domain): 46 | # Implementation details 47 | pass 48 | """, 49 | "azure": """ 50 | from botbuilder.core import TurnContext 51 | from src.core.mcts import MCTS 52 | from src.core.ppm import ProcessPreferenceModel 53 | 54 | class RStarMathBot: 55 | async def on_message_activity(self, turn_context: TurnContext): 56 | # Implementation details 57 | pass 58 | """, 59 | "langchain": """ 60 | from langchain import LLMChain 61 | from src.core.mcts import MCTS 62 | from src.core.ppm import ProcessPreferenceModel 63 | 64 | class RStarMathChain(LLMChain): 65 | def __init__(self): 66 | # Implementation details 67 | pass 68 | 69 | def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]: 70 | # Implementation details 71 | pass 72 | """ 73 | } 74 | 75 | return templates.get(framework, "Framework template not found") 76 | 77 | def log_execution(func_name: str, 78 | start_time: float, 79 | end_time: float, 80 | success: bool, 81 | error: Optional[str] = None) -> None: 82 | """Log execution details for monitoring.""" 83 | execution_time = end_time - start_time 84 | log_entry = { 85 | "function": func_name, 86 | "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), 87 | "execution_time": execution_time, 88 | "success": success, 89 | "error": error 90 | } 91 | 92 | # Implementation details for logging to be added 93 | pass 94 | -------------------------------------------------------------------------------- /tests/test_api.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for API endpoints 3 | """ 4 | import pytest 5 | from fastapi.testclient import TestClient 6 | from unittest.mock import Mock, patch 7 | 8 | from src.api.main import app 9 | from src.core.mcts import MCTS 10 | from src.core.ppm import ProcessPreferenceModel 11 | from src.models.model_interface import OpenAIModel 12 | 13 | client = TestClient(app) 14 | 15 | @pytest.fixture 16 | def mock_components(): 17 | with patch('src.core.mcts.MCTS') as mock_mcts, \ 18 | patch('src.core.ppm.ProcessPreferenceModel') as mock_ppm, \ 19 | patch('src.models.model_interface.OpenAIModel') as mock_model: 20 | 21 | # Configure MCTS mock 22 | mock_mcts_instance = Mock() 23 | mock_mcts_instance.search.return_value = ("test_action", [{"state": "test", "value": 1.0}]) 24 | mock_mcts.return_value = mock_mcts_instance 25 | 26 | # Configure PPM mock 27 | mock_ppm_instance = Mock() 28 | mock_ppm_instance.evaluate_step.return_value = 0.8 29 | mock_ppm.return_value = mock_ppm_instance 30 | 31 | # Configure Model mock 32 | mock_model_instance = Mock() 33 | 
mock_model_instance.generate_response.return_value = "4" 34 | mock_model_instance.evaluate_reasoning.return_value = 0.9 35 | mock_model.return_value = mock_model_instance 36 | 37 | yield mock_mcts_instance, mock_ppm_instance, mock_model_instance 38 | 39 | def test_solve_endpoint(mock_components): 40 | mcts_mock, ppm_mock, model_mock = mock_components 41 | 42 | response = client.post( 43 | "/solve", 44 | json={ 45 | "problem_text": "What is 2 + 2?", 46 | "model_name": "gpt-4", 47 | "use_rstar": True, 48 | "mcts_simulations": 100, 49 | "temperature": 0.7 50 | } 51 | ) 52 | 53 | assert response.status_code == 200 54 | data = response.json() 55 | assert "solution_steps" in data 56 | assert "confidence_score" in data 57 | assert "reasoning_path" in data 58 | assert "execution_time" in data 59 | 60 | def test_compare_models_endpoint(mock_components): 61 | mcts_mock, ppm_mock, model_mock = mock_components 62 | 63 | response = client.post( 64 | "/compare-models", 65 | json={ 66 | "problem_text": "What is 2 + 2?", 67 | "model_name": "all", 68 | "use_rstar": True, 69 | "mcts_simulations": 100, 70 | "temperature": 0.7 71 | } 72 | ) 73 | 74 | assert response.status_code == 200 75 | data = response.json() 76 | assert isinstance(data, dict) 77 | assert "openai" in data 78 | assert "anthropic" in data 79 | 80 | for model_results in data.values(): 81 | assert "solution" in model_results 82 | assert "score" in model_results 83 | assert "execution_time" in model_results 84 | 85 | def test_generate_integration_code_endpoint(mock_components): 86 | response = client.post( 87 | "/generate-integration-code", 88 | json={ 89 | "framework": "rasa", 90 | "config": { 91 | "model": "gpt-4", 92 | "use_rstar": True 93 | } 94 | } 95 | ) 96 | 97 | assert response.status_code == 200 98 | data = response.json() 99 | assert isinstance(data, dict) 100 | assert "code" in data 101 | assert "rstar_math" in data["code"].lower() 102 | 103 | def test_invalid_input(): 104 | response = client.post( 105 | "/solve", 106 | json={ 107 | "invalid_field": "test" 108 | } 109 | ) 110 | assert response.status_code == 422 111 | 112 | def test_error_handling(mock_components): 113 | mcts_mock, ppm_mock, model_mock = mock_components 114 | model_mock.generate_response.side_effect = Exception("API Error") 115 | 116 | response = client.post( 117 | "/solve", 118 | json={ 119 | "problem_text": "What is 2 + 2?", 120 | "model_name": "gpt-4", 121 | "use_rstar": True 122 | } 123 | ) 124 | 125 | assert response.status_code == 500 126 | data = response.json() 127 | assert "detail" in data 128 | -------------------------------------------------------------------------------- /tests/test_core.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for core rStar-Math functionality 3 | """ 4 | import pytest 5 | from src.core.mcts import MCTS, MCTSNode 6 | from src.core.ppm import ProcessPreferenceModel 7 | 8 | def test_mcts_node_creation(): 9 | node = MCTSNode(state="initial") 10 | assert node.state == "initial" 11 | assert node.parent is None 12 | assert len(node.children) == 0 13 | assert node.visits == 0 14 | assert node.value == 0.0 15 | 16 | def test_mcts_selection(): 17 | mcts = MCTS(exploration_weight=1.0) 18 | root = MCTSNode(state="root") 19 | child = root.add_child(action="test", state="child") 20 | 21 | # Update values 22 | root.visits = 10 23 | child.visits = 5 24 | child.value = 2.5 25 | 26 | selected = mcts.select_action(root) 27 | assert selected == child 28 | 29 | def 
test_process_preference_model(): 30 | model = ProcessPreferenceModel(input_dim=10) 31 | # Add more specific tests based on implementation 32 | pass 33 | 34 | def test_math_problem_solving(): 35 | """Integration test for solving a simple math problem.""" 36 | problem = "What is 2 + 2?" 37 | mcts = MCTS() 38 | ppm = ProcessPreferenceModel(input_dim=10) 39 | 40 | # Test implementation to be added 41 | pass 42 | 43 | if __name__ == "__main__": 44 | pytest.main([__file__]) 45 | -------------------------------------------------------------------------------- /tests/test_integration.py: -------------------------------------------------------------------------------- 1 | """ 2 | Integration tests for rStar-Math components 3 | """ 4 | import pytest 5 | from unittest.mock import Mock, patch 6 | import torch 7 | import numpy as np 8 | 9 | from src.core.mcts import MCTS, MCTSConfig 10 | from src.core.ppm import ProcessPreferenceModel, PPMConfig 11 | from src.models.model_interface import OpenAIModel 12 | 13 | class TestProblemSolver: 14 | @pytest.fixture 15 | def setup_components(self): 16 | mcts_config = MCTSConfig( 17 | exploration_weight=1.0, 18 | max_simulations=100, 19 | max_depth=5 20 | ) 21 | ppm_config = PPMConfig( 22 | input_dim=768, 23 | hidden_dim=256, 24 | learning_rate=0.001, 25 | batch_size=16 26 | ) 27 | 28 | mcts = MCTS(mcts_config) 29 | ppm = ProcessPreferenceModel(ppm_config) 30 | 31 | with patch('openai.ChatCompletion.create') as mock_create: 32 | mock_create.return_value = Mock( 33 | choices=[Mock(message=Mock(content="Test response"))] 34 | ) 35 | model = OpenAIModel("test-key") 36 | 37 | return mcts, ppm, model 38 | 39 | def test_solve_simple_problem(self, setup_components): 40 | mcts, ppm, model = setup_components 41 | problem = "What is 2 + 2?" 42 | 43 | # Mock necessary methods 44 | mcts.get_possible_actions = Mock(return_value=["Add the numbers"]) 45 | mcts.apply_action = Mock(return_value="4") 46 | mcts.evaluate_state = Mock(return_value=1.0) 47 | 48 | action, trajectory = mcts.search(problem) 49 | assert action is not None 50 | assert isinstance(trajectory, list) 51 | 52 | def test_evaluate_solution(self, setup_components): 53 | mcts, ppm, model = setup_components 54 | problem = "What is 2 + 2?" 55 | solution_steps = [ 56 | "First, we identify this is an addition problem", 57 | "Then, we add 2 and 2 together", 58 | "Therefore, 2 + 2 = 4" 59 | ] 60 | 61 | # Mock embedding generation 62 | embeddings = torch.randn(len(solution_steps), 768) 63 | model.embed_text = Mock(side_effect=[e.numpy().tolist() for e in embeddings]) 64 | 65 | # Evaluate solution using PPM 66 | for step in solution_steps: 67 | score = ppm.evaluate_step(step, model) 68 | assert isinstance(score, float) 69 | assert 0 <= score <= 1 70 | 71 | def test_self_evolution(self, setup_components): 72 | mcts, ppm, model = setup_components 73 | problem = "What is 2 + 2?" 
74 | 75 | # Generate initial solution 76 | mcts.get_possible_actions = Mock(return_value=["Add the numbers"]) 77 | mcts.apply_action = Mock(return_value="4") 78 | mcts.evaluate_state = Mock(return_value=1.0) 79 | 80 | initial_action, initial_trajectory = mcts.search(problem) 81 | 82 | # Evolve solution 83 | evolved_action, evolved_trajectory = mcts.search(problem) 84 | 85 | assert len(evolved_trajectory) > 0 86 | 87 | # Compare trajectories 88 | if len(initial_trajectory) > 0 and len(evolved_trajectory) > 0: 89 | initial_value = initial_trajectory[-1]["value"] 90 | evolved_value = evolved_trajectory[-1]["value"] 91 | assert evolved_value >= initial_value 92 | 93 | class TestModelComparison: 94 | @pytest.fixture 95 | def setup_comparison(self): 96 | problem = "What is 2 + 2?" 97 | solution_steps = [ 98 | "First, we identify this is an addition problem", 99 | "Then, we add 2 and 2 together", 100 | "Therefore, 2 + 2 = 4" 101 | ] 102 | 103 | with patch('openai.ChatCompletion.create') as mock_create: 104 | mock_create.return_value = Mock( 105 | choices=[Mock(message=Mock(content="0.8"))] 106 | ) 107 | openai_model = OpenAIModel("test-key") 108 | 109 | return problem, solution_steps, openai_model 110 | 111 | def test_model_comparison(self, setup_comparison): 112 | problem, solution_steps, model = setup_comparison 113 | 114 | # Test direct reasoning (System 1) 115 | direct_response = model.generate_response(problem) 116 | assert isinstance(direct_response, str) 117 | 118 | # Test rStar-Math enhanced reasoning (System 2) 119 | mcts = MCTS() 120 | ppm = ProcessPreferenceModel(input_dim=768) 121 | 122 | mcts.get_possible_actions = Mock(return_value=solution_steps) 123 | mcts.apply_action = Mock(return_value="4") 124 | mcts.evaluate_state = Mock(return_value=1.0) 125 | 126 | enhanced_action, enhanced_trajectory = mcts.search(problem) 127 | assert enhanced_action is not None 128 | assert len(enhanced_trajectory) > 0 129 | 130 | # Compare solutions 131 | direct_score = model.evaluate_reasoning(problem, [direct_response]) 132 | enhanced_score = model.evaluate_reasoning(problem, solution_steps) 133 | 134 | assert isinstance(direct_score, float) 135 | assert isinstance(enhanced_score, float) 136 | assert 0 <= direct_score <= 1 137 | assert 0 <= enhanced_score <= 1 138 | -------------------------------------------------------------------------------- /tests/test_models.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for LLM model implementations 3 | """ 4 | import pytest 5 | from unittest.mock import Mock, patch 6 | from src.models.model_interface import ( 7 | ModelConfig, 8 | OpenAIModel, 9 | AnthropicModel, 10 | ModelFactory 11 | ) 12 | 13 | @pytest.fixture 14 | def mock_openai_response(): 15 | return Mock( 16 | choices=[ 17 | Mock( 18 | message=Mock( 19 | content="Test response" 20 | ) 21 | ) 22 | ], 23 | data=[ 24 | Mock( 25 | embedding=[0.1, 0.2, 0.3] 26 | ) 27 | ] 28 | ) 29 | 30 | @pytest.fixture 31 | def mock_anthropic_response(): 32 | return Mock(content="Test response") 33 | 34 | def test_model_config(): 35 | config = ModelConfig(model="test-model", temperature=0.5, max_tokens=100) 36 | assert config.model == "test-model" 37 | assert config.temperature == 0.5 38 | assert config.max_tokens == 100 39 | 40 | @patch('openai.ChatCompletion.create') 41 | def test_openai_generate_response(mock_create, mock_openai_response): 42 | mock_create.return_value = mock_openai_response 43 | model = OpenAIModel("test-key") 44 | 45 | response = 
model.generate_response("test prompt") 46 | assert response == "Test response" 47 | 48 | mock_create.assert_called_once() 49 | call_args = mock_create.call_args[1] 50 | assert call_args['messages'][0]['content'] == "test prompt" 51 | 52 | @patch('openai.Embedding.create') 53 | def test_openai_embed_text(mock_create, mock_openai_response): 54 | mock_create.return_value = mock_openai_response 55 | model = OpenAIModel("test-key") 56 | 57 | embedding = model.embed_text("test text") 58 | assert embedding == [0.1, 0.2, 0.3] 59 | 60 | mock_create.assert_called_once_with( 61 | model="text-embedding-ada-002", 62 | input="test text" 63 | ) 64 | 65 | @patch('anthropic.Client') 66 | def test_anthropic_generate_response(mock_client, mock_anthropic_response): 67 | mock_instance = Mock() 68 | mock_instance.messages.create.return_value = mock_anthropic_response 69 | mock_client.return_value = mock_instance 70 | 71 | model = AnthropicModel("test-key") 72 | response = model.generate_response("test prompt") 73 | 74 | assert response == "Test response" 75 | mock_instance.messages.create.assert_called_once() 76 | 77 | def test_model_factory(): 78 | with pytest.raises(ValueError): 79 | ModelFactory.create_model("unknown-model", "test-key") 80 | 81 | model = ModelFactory.create_model("openai", "test-key") 82 | assert isinstance(model, OpenAIModel) 83 | 84 | model = ModelFactory.create_model("anthropic", "test-key") 85 | assert isinstance(model, AnthropicModel) 86 | 87 | def test_evaluate_reasoning(): 88 | test_problem = "What is 2 + 2?" 89 | test_steps = ["First, we identify that this is an addition problem", 90 | "Then, we add 2 and 2 together", 91 | "Therefore, 2 + 2 = 4"] 92 | 93 | with patch('openai.ChatCompletion.create') as mock_create: 94 | mock_create.return_value = Mock( 95 | choices=[Mock(message=Mock(content="0.8"))] 96 | ) 97 | 98 | model = OpenAIModel("test-key") 99 | score = model.evaluate_reasoning(test_problem, test_steps) 100 | 101 | assert 0 <= score <= 1 102 | mock_create.assert_called_once() 103 | -------------------------------------------------------------------------------- /tests/test_new_models.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for new model implementations (Mistral, Groq, Gemini) 3 | """ 4 | import pytest 5 | from unittest.mock import Mock, patch 6 | from src.models.mistral_model import MistralModel 7 | from src.models.groq_model import GroqModel 8 | from src.models.gemini_model import GeminiModel 9 | from src.models.model_interface import ModelConfig 10 | 11 | @pytest.fixture 12 | def mock_config(): 13 | return ModelConfig( 14 | model="test-model", 15 | temperature=0.7, 16 | max_tokens=100 17 | ) 18 | 19 | @pytest.fixture 20 | def mock_api_response(): 21 | return Mock( 22 | choices=[Mock(message=Mock(content="Test response"))], 23 | data=[Mock(embedding=[0.1, 0.2, 0.3])] 24 | ) 25 | 26 | class TestMistralModel: 27 | @pytest.fixture 28 | def model(self): 29 | return MistralModel("test-key", mock_config()) 30 | 31 | @patch('mistralai.client.MistralClient') 32 | def test_generate_response(self, mock_client, model): 33 | mock_client.return_value.chat.return_value = mock_api_response 34 | response = model.generate_response("Test prompt") 35 | assert isinstance(response, str) 36 | assert len(response) > 0 37 | 38 | @patch('mistralai.client.MistralClient') 39 | def test_evaluate_reasoning(self, mock_client, model): 40 | mock_client.return_value.chat.return_value = mock_api_response 41 | score = model.evaluate_reasoning( 42 | 
"What is 2+2?", 43 | ["First step", "Second step"] 44 | ) 45 | assert isinstance(score, float) 46 | assert 0 <= score <= 1 47 | 48 | @patch('mistralai.client.MistralClient') 49 | def test_embed_text(self, mock_client, model): 50 | mock_client.return_value.embeddings.return_value = mock_api_response 51 | embedding = model.embed_text("Test text") 52 | assert isinstance(embedding, list) 53 | assert all(isinstance(x, float) for x in embedding) 54 | 55 | class TestGroqModel: 56 | @pytest.fixture 57 | def model(self): 58 | return GroqModel("test-key", mock_config()) 59 | 60 | @patch('groq.Client') 61 | def test_generate_response(self, mock_client, model): 62 | mock_client.return_value.chat.completions.create.return_value = mock_api_response 63 | response = model.generate_response("Test prompt") 64 | assert isinstance(response, str) 65 | assert len(response) > 0 66 | 67 | @patch('groq.Client') 68 | def test_evaluate_reasoning(self, mock_client, model): 69 | mock_client.return_value.chat.completions.create.return_value = mock_api_response 70 | score = model.evaluate_reasoning( 71 | "What is 2+2?", 72 | ["First step", "Second step"] 73 | ) 74 | assert isinstance(score, float) 75 | assert 0 <= score <= 1 76 | 77 | @patch('groq.Client') 78 | def test_embed_text(self, mock_client, model): 79 | # Test fallback to OpenAI embeddings 80 | with patch('src.models.model_interface.OpenAIModel') as mock_openai: 81 | mock_openai.return_value.embed_text.return_value = [0.1, 0.2, 0.3] 82 | embedding = model.embed_text("Test text") 83 | assert isinstance(embedding, list) 84 | assert all(isinstance(x, float) for x in embedding) 85 | 86 | class TestGeminiModel: 87 | @pytest.fixture 88 | def model(self): 89 | return GeminiModel("test-key", mock_config()) 90 | 91 | @patch('google.generativeai.GenerativeModel') 92 | def test_generate_response(self, mock_model, model): 93 | mock_model.return_value.generate_content.return_value = Mock(text="Test response") 94 | response = model.generate_response("Test prompt") 95 | assert isinstance(response, str) 96 | assert len(response) > 0 97 | 98 | @patch('google.generativeai.GenerativeModel') 99 | def test_evaluate_reasoning(self, mock_model, model): 100 | mock_model.return_value.generate_content.return_value = Mock(text="0.8") 101 | score = model.evaluate_reasoning( 102 | "What is 2+2?", 103 | ["First step", "Second step"] 104 | ) 105 | assert isinstance(score, float) 106 | assert 0 <= score <= 1 107 | 108 | @patch('google.generativeai.GenerativeModel') 109 | def test_embed_text(self, mock_model, model): 110 | mock_model.return_value.embed_content.return_value = Mock( 111 | embedding=[0.1, 0.2, 0.3] 112 | ) 113 | embedding = model.embed_text("Test text") 114 | assert isinstance(embedding, list) 115 | assert all(isinstance(x, float) for x in embedding) 116 | 117 | # Integration tests 118 | @pytest.mark.integration 119 | class TestModelIntegration: 120 | @pytest.mark.parametrize("model_class", [ 121 | MistralModel, 122 | GroqModel, 123 | GeminiModel 124 | ]) 125 | def test_model_chain(self, model_class): 126 | """Test complete chain of operations.""" 127 | model = model_class("test-key", mock_config()) 128 | 129 | # Generate response 130 | response = model.generate_response("What is 2+2?") 131 | assert isinstance(response, str) 132 | 133 | # Evaluate reasoning 134 | score = model.evaluate_reasoning( 135 | "What is 2+2?", 136 | ["First, identify this is addition", "Then, add 2 and 2", "Therefore, 2+2=4"] 137 | ) 138 | assert isinstance(score, float) 139 | 140 | # Generate embeddings 
141 | embedding = model.embed_text("Test text") 142 | assert isinstance(embedding, list) 143 | 144 | @pytest.mark.parametrize("model_class", [ 145 | MistralModel, 146 | GroqModel, 147 | GeminiModel 148 | ]) 149 | def test_error_handling(self, model_class, mock_config): 150 | """Test error handling in models.""" 151 | model = model_class("invalid-key", mock_config) 152 | 153 | # Test invalid API key 154 | with pytest.raises(Exception): 155 | model.generate_response("Test prompt") 156 | 157 | # Test invalid prompt 158 | with pytest.raises(Exception): 159 | model.generate_response("") 160 | 161 | # Test invalid reasoning steps 162 | score = model.evaluate_reasoning("Test", []) 163 | assert score == 0.0 164 | -------------------------------------------------------------------------------- /tools/benchmark.py: -------------------------------------------------------------------------------- 1 | """ 2 | Performance benchmarking tools for rStar-Math 3 | """ 4 | import os, time 5 | import json 6 | import numpy as np 7 | import pandas as pd 8 | import plotly.express as px 9 | from typing import Dict, List, Any, Optional 10 | from dataclasses import dataclass 11 | from src.core.mcts import MCTS 12 | from src.core.ppm import ProcessPreferenceModel 13 | from src.models.model_interface import ModelFactory 14 | 15 | @dataclass 16 | class BenchmarkResult: 17 | """Container for benchmark results.""" 18 | model_name: str 19 | problem_type: str 20 | use_rstar: bool 21 | execution_time: float 22 | confidence_score: float 23 | token_count: int 24 | memory_usage: float 25 | 26 | class RStarMathBenchmark: 27 | def __init__(self, config_path: str = 'config/default.json'): 28 | """Initialize benchmark components.""" 29 | self.mcts = MCTS.from_config_file(config_path) 30 | self.ppm = ProcessPreferenceModel.from_config_file(config_path) 31 | self.results: List[BenchmarkResult] = [] 32 | 33 | def load_test_problems(self, problem_set: str) -> List[Dict[str, str]]: 34 | """Load test problems from JSON file.""" 35 | with open(f'tests/problems/{problem_set}.json', 'r') as f: 36 | return json.load(f) 37 | 38 | def measure_execution(self, func, *args, **kwargs) -> Dict[str, float]: 39 | """Measure execution time and memory usage.""" 40 | import psutil 41 | import os 42 | 43 | process = psutil.Process(os.getpid()) 44 | start_memory = process.memory_info().rss / 1024 / 1024 # MB 45 | 46 | start_time = time.time() 47 | result = func(*args, **kwargs) 48 | execution_time = time.time() - start_time 49 | 50 | end_memory = process.memory_info().rss / 1024 / 1024 51 | memory_used = end_memory - start_memory 52 | 53 | return { 54 | 'result': result, 55 | 'execution_time': execution_time, 56 | 'memory_usage': memory_used 57 | } 58 | 59 | def count_tokens(self, text: str) -> int: 60 | """Estimate token count.""" 61 | # Simple estimation: words * 1.3 62 | return int(len(text.split()) * 1.3) 63 | 64 | def run_benchmark(self, 65 | model: Any, 66 | problem: Dict[str, str], 67 | use_rstar: bool = True) -> BenchmarkResult: 68 | """Run benchmark for a single problem.""" 69 | if use_rstar: 70 | # Measure MCTS search 71 | search_result = self.measure_execution( 72 | self.mcts.search, 73 | problem['text'] 74 | ) 75 | action, trajectory = search_result['result'] 76 | 77 | # Measure confidence scoring 78 | confidence_scores = [] 79 | total_tokens = 0 80 | 81 | for step in trajectory: 82 | eval_result = self.measure_execution( 83 | self.ppm.evaluate_step, 84 | step['state'], 85 | model 86 | ) 87 | confidence_scores.append(eval_result['result']) 88 | total_tokens
+= self.count_tokens(step['state']) 89 | 90 | avg_confidence = np.mean(confidence_scores) 91 | execution_time = search_result['execution_time'] 92 | memory_usage = search_result['memory_usage'] 93 | 94 | else: 95 | # Direct model solution 96 | generate_result = self.measure_execution( 97 | model.generate_response, 98 | problem['text'] 99 | ) 100 | solution = generate_result['result'] 101 | 102 | eval_result = self.measure_execution( 103 | model.evaluate_reasoning, 104 | problem['text'], 105 | [solution] 106 | ) 107 | 108 | avg_confidence = eval_result['result'] 109 | execution_time = generate_result['execution_time'] 110 | memory_usage = generate_result['memory_usage'] 111 | total_tokens = self.count_tokens(solution) 112 | 113 | return BenchmarkResult( 114 | model_name=model.__class__.__name__, 115 | problem_type=problem['type'], 116 | use_rstar=use_rstar, 117 | execution_time=execution_time, 118 | confidence_score=avg_confidence, 119 | token_count=total_tokens, 120 | memory_usage=memory_usage 121 | ) 122 | 123 | def run_full_benchmark(self, 124 | problem_sets: List[str], 125 | models: Dict[str, Any], 126 | iterations: int = 3): 127 | """Run complete benchmark suite.""" 128 | for problem_set in problem_sets: 129 | problems = self.load_test_problems(problem_set) 130 | 131 | for model_name, model in models.items(): 132 | for problem in problems: 133 | for use_rstar in [True, False]: 134 | for _ in range(iterations): 135 | result = self.run_benchmark(model, problem, use_rstar) 136 | self.results.append(result) 137 | 138 | def generate_report(self, output_path: str = 'benchmark_results'): 139 | """Generate benchmark report with visualizations.""" 140 | # Convert results to DataFrame 141 | df = pd.DataFrame([vars(r) for r in self.results]) 142 | 143 | # Create visualizations 144 | plots = [] 145 | 146 | # Execution time comparison 147 | fig1 = px.box(df, x='model_name', y='execution_time', 148 | color='use_rstar', facet_col='problem_type', 149 | title='Execution Time by Model and Problem Type') 150 | plots.append(('execution_time.html', fig1)) 151 | 152 | # Confidence score comparison 153 | fig2 = px.box(df, x='model_name', y='confidence_score', 154 | color='use_rstar', facet_col='problem_type', 155 | title='Confidence Scores by Model and Problem Type') 156 | plots.append(('confidence_scores.html', fig2)) 157 | 158 | # Token usage 159 | fig3 = px.bar(df.groupby(['model_name', 'use_rstar'])['token_count'].mean().reset_index(), 160 | x='model_name', y='token_count', color='use_rstar', 161 | title='Average Token Usage by Model') 162 | plots.append(('token_usage.html', fig3)) 163 | 164 | # Memory usage 165 | fig4 = px.line(df.groupby(['model_name', 'use_rstar'])['memory_usage'].mean().reset_index(), 166 | x='model_name', y='memory_usage', color='use_rstar', 167 | title='Memory Usage by Model') 168 | plots.append(('memory_usage.html', fig4)) 169 | 170 | # Save plots 171 | os.makedirs(output_path, exist_ok=True) 172 | for filename, fig in plots: 173 | fig.write_html(f'{output_path}/{filename}') 174 | 175 | # Generate summary statistics 176 | summary = df.groupby(['model_name', 'use_rstar']).agg({ 177 | 'execution_time': ['mean', 'std'], 178 | 'confidence_score': ['mean', 'std'], 179 | 'token_count': 'mean', 180 | 'memory_usage': 'mean' 181 | }).round(3) 182 | 183 | summary.to_csv(f'{output_path}/summary_stats.csv') 184 | 185 | # Generate markdown report 186 | with open(f'{output_path}/report.md', 'w') as f: 187 | f.write("# rStar-Math Benchmark Report\n\n") 188 | f.write(f"Generated on: 
{time.strftime('%Y-%m-%d %H:%M:%S')}\n\n") 189 | 190 | f.write("## Summary Statistics\n") 191 | f.write(summary.to_markdown()) 192 | 193 | f.write("\n\n## Visualizations\n") 194 | for filename, _ in plots: 195 | f.write(f"- [{filename}]({filename})\n") 196 | 197 | def main(): 198 | """Run benchmark suite.""" 199 | # Initialize models 200 | models = { 201 | 'openai': ModelFactory.create_model( 202 | 'openai', 203 | os.getenv('OPENAI_API_KEY'), 204 | 'config/default.json' 205 | ), 206 | 'anthropic': ModelFactory.create_model( 207 | 'anthropic', 208 | os.getenv('ANTHROPIC_API_KEY'), 209 | 'config/default.json' 210 | ) 211 | } 212 | 213 | # Initialize benchmark 214 | benchmark = RStarMathBenchmark() 215 | 216 | # Run benchmarks 217 | problem_sets = ['arithmetic', 'algebra', 'calculus'] 218 | benchmark.run_full_benchmark(problem_sets, models) 219 | 220 | # Generate report 221 | benchmark.generate_report() 222 | 223 | if __name__ == "__main__": 224 | main() 225 | --------------------------------------------------------------------------------