├── .gitnore ├── __init__.py ├── core ├── __init__.py └── state_manager.py ├── trainers ├── __init__.py ├── gpu_manager.py └── trl_trainer.py ├── managers ├── __init__.py ├── round_controller.py └── question_manager.py ├── collectors ├── __init__.py ├── data_normalizer.py └── trajectory_collector.py ├── processors ├── __init__.py ├── question_enhancer.py ├── reward_calculator.py └── solver_data_processor.py ├── requirements.txt ├── .env.example ├── config └── .env.example ├── pyproject.toml ├── README.md └── tools └── utils └── status_checker.py /.gitnore: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Core Module 3 | 4 | This module contains core components for state management and system operations. 5 | """ 6 | 7 | from .state_manager import StateManager 8 | 9 | __all__ = [ 10 | 'StateManager' 11 | ] -------------------------------------------------------------------------------- /trainers/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Training Module 3 | 4 | This module contains components for model training and GPU management. 5 | """ 6 | 7 | from .trl_trainer import TRLTrainer 8 | from .gpu_manager import GPUManager 9 | 10 | __all__ = [ 11 | 'TRLTrainer', 12 | 'GPUManager' 13 | ] -------------------------------------------------------------------------------- /managers/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Management Module 3 | 4 | This module contains components for round and question management. 
5 | """ 6 | 7 | from .round_controller import RoundController 8 | from .question_manager import QuestionManager 9 | 10 | __all__ = [ 11 | 'RoundController', 12 | 'QuestionManager' 13 | ] -------------------------------------------------------------------------------- /collectors/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Data Collection Module 3 | 4 | This module contains components for collecting training data and trajectories. 5 | """ 6 | 7 | from .trajectory_collector import TrajectoryCollector 8 | from .data_normalizer import DataNormalizer 9 | 10 | __all__ = [ 11 | 'TrajectoryCollector', 12 | 'DataNormalizer' 13 | ] -------------------------------------------------------------------------------- /processors/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Data Processing Module 3 | 4 | This module contains components for processing and analyzing training data. 5 | """ 6 | 7 | from .reward_calculator import RewardCalculator 8 | from .question_enhancer import QuestionEnhancer 9 | from .solver_data_processor import SolverDataProcessor 10 | 11 | __all__ = [ 12 | 'RewardCalculator', 13 | 'QuestionEnhancer', 14 | 'SolverDataProcessor' 15 | ] -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # ProSetting Progressive Training System Dependencies 2 | 3 | # Core deep learning frameworks 4 | torch>=2.0.0 5 | transformers>=4.35.0 6 | tokenizers>=0.14.0 7 | datasets>=2.14.0 8 | 9 | # TRL Training Framework (Alternative to VERL) 10 | trl>=0.7.0 11 | accelerate>=0.24.0 12 | peft>=0.6.0 13 | 14 | # Data Processing 15 | pandas>=2.0.0 16 | numpy>=1.24.0 17 | pyarrow>=13.0.0 # parquet support 18 | fastparquet>=2023.8.0 19 | 20 | # Distributed Training and Optimization 21 | deepspeed>=0.12.0 22 | bitsandbytes>=0.41.0 23 | 
"""
Data Normalizer

Normalizes collected trajectory data for consistent processing.
"""

import logging
from typing import List, Dict, Any

logger = logging.getLogger(__name__)


class DataNormalizer:
    """
    Stateless helper that maps raw trajectory dicts onto one fixed schema
    """

    @staticmethod
    def normalize_trajectories(trajectories: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Build a schema-stable copy of every raw trajectory

        Every output record carries the keys ``question``, ``attempt``,
        ``response`` and ``reasoning_steps`` (missing ones are filled with
        ``""``, ``1``, ``""`` and ``""``) plus a ``metadata`` dict listing
        the keys the raw record had. Keys outside the schema are not copied;
        they only appear in ``metadata["original_keys"]``.

        Args:
            trajectories: Raw trajectory data

        Returns:
            Normalized trajectory data, one record per input, in input order
        """
        # (key, fallback) pairs; the tuple order fixes each record's key order.
        schema = (
            ("question", ""),
            ("attempt", 1),
            ("response", ""),
            ("reasoning_steps", ""),
        )

        records = [
            {
                **{key: raw.get(key, fallback) for key, fallback in schema},
                "metadata": {
                    "original_keys": list(raw.keys()),
                    "normalized": True,
                },
            }
            for raw in trajectories
        ]

        logger.info(f"Normalized {len(records)} trajectories")
        return records
"""
Question Enhancer

Enhances questions based on error analysis using teacher model.
"""

import logging
from typing import List, Dict, Any, Optional

logger = logging.getLogger(__name__)


class QuestionEnhancer:
    """
    Question enhancer using teacher model analysis

    The enhancement itself is still a placeholder: once a teacher client is
    configured, every question is simply prefixed with ``"Enhanced: "``.
    """

    def __init__(self) -> None:
        # No client until set_teacher_client() is called.
        self.teacher_client: Optional[Any] = None

    def set_teacher_client(self, teacher_client: Any) -> None:
        """
        Set teacher model client for enhancement

        Args:
            teacher_client: Client object to store; it is not called by the
                current placeholder logic
        """
        self.teacher_client = teacher_client

    def enhance_questions(self, failed_questions: List[str]) -> List[str]:
        """
        Enhance questions based on failure analysis

        Args:
            failed_questions: List of questions that failed

        Returns:
            List of enhanced questions. When no teacher client has been set,
            the input list object itself is returned unchanged.
        """
        # Compare against None explicitly: a configured client whose
        # __bool__/__len__ happens to be falsy must still count as "set".
        if self.teacher_client is None:
            logger.warning("Teacher client not set, returning original questions")
            return failed_questions

        # Placeholder enhancement logic
        enhanced = [f"Enhanced: {question}" for question in failed_questions]

        logger.info(f"Enhanced {len(enhanced)} questions")
        return enhanced
import logging
from typing import List, Dict, Any, Tuple, Optional

logger = logging.getLogger(__name__)


class RewardCalculator:
    """
    Reward calculator for solver trajectories

    The scoring is still a placeholder: every trajectory receives the same
    neutral reward and the teacher client is never called.
    """

    # Reward assigned to every trajectory by the placeholder logic.
    PLACEHOLDER_REWARD = 0.5

    def __init__(self) -> None:
        # No client until set_teacher_client() is called.
        self.teacher_client: Optional[Any] = None

    def set_teacher_client(self, teacher_client: Any) -> None:
        """
        Set teacher model client for evaluation

        Args:
            teacher_client: Client object to store; only its presence is
                recorded by the current placeholder logic
        """
        self.teacher_client = teacher_client

    def compute_solver_reward_local(self, trajectories: List[Dict[str, Any]]) -> Tuple[List[float], List[Dict[str, Any]]]:
        """
        Compute rewards for solver trajectories

        Args:
            trajectories: List of trajectory data

        Returns:
            Tuple of (rewards, judge_results); both lists have one entry per
            input trajectory, in input order
        """
        # One definition of "teacher configured", shared by the warning and
        # the teacher_used flag, so the two can never disagree for a client
        # object that evaluates as falsy.
        teacher_used = self.teacher_client is not None
        if not teacher_used:
            logger.warning("Teacher client not set, using placeholder rewards")

        rewards: List[float] = []
        judge_results: List[Dict[str, Any]] = []

        for trajectory in trajectories:
            # Placeholder reward calculation: default neutral reward
            reward = self.PLACEHOLDER_REWARD
            response = trajectory.get("response", "")

            # NOTE: the placeholder files every response under
            # incorrect_answers even though the reward is neutral.
            judge_results.append({
                "question": trajectory.get("question", ""),
                "response": response,
                "reward": reward,
                "correct_answers": [],
                "incorrect_answers": [response],
                "evaluation_details": {
                    "teacher_used": teacher_used,
                    "placeholder": True
                }
            })
            rewards.append(reward)

        logger.info(f"Computed rewards for {len(trajectories)} trajectories")
        return rewards, judge_results
import logging
from typing import List, Dict, Any, Optional

logger = logging.getLogger(__name__)


class RoundController:
    """
    Training round controller

    Holds the round budget and the set of rounds after which a checkpoint
    should be written, and answers simple questions about a given round.
    """

    def __init__(self, max_rounds: int = 10, save_rounds: Optional[List[int]] = None) -> None:
        """
        Initialize round controller

        Args:
            max_rounds: Maximum number of training rounds
            save_rounds: List of rounds to save checkpoints. ``None`` and an
                empty list both fall back to rounds 3 through 10.
        """
        self.max_rounds = max_rounds
        # Copy so that later mutation of the caller's list cannot silently
        # change which rounds get checkpointed.
        self.save_rounds = list(save_rounds) if save_rounds else [3, 4, 5, 6, 7, 8, 9, 10]
        self.current_round = 1

    def should_save_checkpoint(self, round_num: int) -> bool:
        """
        Check if checkpoint should be saved for this round

        Args:
            round_num: Current round number

        Returns:
            True if checkpoint should be saved
        """
        return round_num in self.save_rounds

    def is_final_round(self, round_num: int) -> bool:
        """
        Check if this is the final round

        Args:
            round_num: Current round number

        Returns:
            True if this round is at or beyond ``max_rounds``
        """
        return round_num >= self.max_rounds

    def get_round_info(self, round_num: int) -> Dict[str, Any]:
        """
        Get information about a specific round

        Args:
            round_num: Round number

        Returns:
            Dictionary with the keys ``round_num``, ``max_rounds``,
            ``should_save``, ``is_final`` and ``progress``. ``progress`` is
            ``round_num / max_rounds`` and is not clamped.
        """
        if self.max_rounds > 0:
            progress = round_num / self.max_rounds
        else:
            # A non-positive budget would divide by zero (or give a negative
            # ratio); report the schedule as complete instead.
            progress = 1.0

        return {
            "round_num": round_num,
            "max_rounds": self.max_rounds,
            "should_save": self.should_save_checkpoint(round_num),
            "is_final": self.is_final_round(round_num),
            "progress": progress
        }
import os
import json
import logging
from pathlib import Path
from typing import List, Dict, Any

logger = logging.getLogger(__name__)


class SolverDataProcessor:
    """
    Solver data processor for saving and managing training data

    Judge results are stored as one JSON file per round under
    ``<WORKSPACE_DIR>/data/round_<N>_judge_results.json``.
    """

    def __init__(self) -> None:
        # Resolved once at construction; later changes to WORKSPACE_DIR do
        # not affect an existing instance.
        self.workspace_dir = Path(os.getenv("WORKSPACE_DIR", "/workspace/prosetting"))

    def _judge_results_path(self, round_num: int) -> Path:
        """Return the judge-results file path for a round."""
        return self.workspace_dir / "data" / f"round_{round_num}_judge_results.json"

    def save_judge_results(self, judge_results: List[Dict[str, Any]], round_num: int) -> str:
        """
        Save judge results to file

        The data is written to a temporary sibling file and then moved into
        place, so a failure while serializing never leaves a truncated file
        behind or destroys results saved earlier for the same round.

        Args:
            judge_results: List of judge result dictionaries
            round_num: Training round number

        Returns:
            Path to saved file

        Raises:
            TypeError: If ``judge_results`` contains values that are not
                JSON serializable
            OSError: If the directory or file cannot be written
        """
        judge_file = self._judge_results_path(round_num)
        judge_file.parent.mkdir(parents=True, exist_ok=True)

        tmp_file = judge_file.with_name(judge_file.name + ".tmp")
        try:
            with open(tmp_file, 'w', encoding='utf-8') as f:
                json.dump(judge_results, f, indent=2, ensure_ascii=False)
            os.replace(tmp_file, judge_file)
        finally:
            # After a successful replace the temp file is already gone; on
            # failure this removes the partial output.
            tmp_file.unlink(missing_ok=True)

        logger.info(f"Saved judge results to: {judge_file}")
        return str(judge_file)

    def load_judge_results(self, round_num: int) -> List[Dict[str, Any]]:
        """
        Load judge results from file

        Args:
            round_num: Training round number

        Returns:
            List of judge result dictionaries, or an empty list when no file
            exists for the round

        Raises:
            json.JSONDecodeError: If the file exists but is not valid JSON
        """
        judge_file = self._judge_results_path(round_num)

        # Open first and handle absence, rather than checking exists() and
        # racing with a concurrent delete.
        try:
            with open(judge_file, 'r', encoding='utf-8') as f:
                judge_results = json.load(f)
        except FileNotFoundError:
            logger.warning(f"Judge results file not found: {judge_file}")
            return []

        logger.info(f"Loaded {len(judge_results)} judge results from: {judge_file}")
        return judge_results
Environment Variables Configuration 2 | # Copy this file to .env and modify according to your setup 3 | 4 | # ==================== Model Path Configuration ==================== 5 | SOLVER_MODEL_PATH=/path/to/your/solver/model 6 | GENERATOR_MODEL_PATH=/path/to/your/generator/model 7 | 8 | # ==================== Data File Configuration ==================== 9 | QUESTIONS_FILE=/path/to/your/questions.json 10 | VALIDATION_DATA_FILE=./validation_data.json 11 | 12 | # ==================== Workspace Configuration ==================== 13 | WORKSPACE_DIR=/workspace/prosetting 14 | TRAINING_DATA_DIR=/workspace/prosetting/training_data 15 | TRAJECTORY_DATA_DIR=/workspace/prosetting/trajectories 16 | CHECKPOINT_DIR=/workspace/prosetting/checkpoints 17 | LOG_DIR=/workspace/prosetting/logs 18 | RESULTS_DIR=/workspace/prosetting/results 19 | 20 | # ==================== Training Parameters Configuration ==================== 21 | TOTAL_ROUNDS=10 22 | SAVE_ROUNDS=3,4,5,6,7,8,9,10 23 | ATTEMPTS_PER_QUESTION=8 24 | 25 | # ==================== GPU Configuration ==================== 26 | PHYSICAL_SOLVER_GPU=4 27 | PHYSICAL_DPO_GPU=0,1,2,3,4,5,6,7 28 | 29 | # ==================== Teacher Model Configuration ==================== 30 | TEACHER_BASE_URL=http://localhost:8000 31 | TEACHER_API_KEY= 32 | TEACHER_MODEL_NAME=/path/to/teacher/model 33 | TEACHER_CONCURRENT_WORKERS=32 34 | TEACHER_TIMEOUT=300 35 | TEACHER_MAX_RETRIES=3 36 | TEACHER_RETRY_DELAY=1.0 37 | 38 | # ==================== TRL Training Configuration ==================== 39 | # TRL DPO training parameters (replaces VERL) 40 | TRL_NUM_PROCESSES=8 41 | TRL_MIXED_PRECISION=bf16 42 | TRL_MAX_STEPS=100 43 | TRL_PER_DEVICE_BATCH_SIZE=2 44 | TRL_MAX_LENGTH=2048 45 | TRL_LEARNING_RATE=5e-6 46 | TRL_WARMUP_STEPS=10 47 | TRL_GRADIENT_ACCUMULATION_STEPS=8 48 | 49 | # Model output configuration 50 | MODEL_OUTPUT_DIR=/workspace/prosetting/models 51 | 52 | # ==================== Automated Training Configuration ==================== 53 | # 
Auto retry configuration 54 | AUTO_RETRY_ENABLED=true 55 | AUTO_CONTINUE_ON_FAILURE=false 56 | CHECKPOINT_INTERVAL=1 57 | 58 | # ==================== Other Configuration ==================== 59 | DEBUG_MODE=false 60 | LOG_LEVEL=INFO 61 | KEEP_TEMP_FILES=true 62 | 63 | # ==================== Cache Configuration ==================== 64 | HF_CACHE_DIR=/workspace/prosetting/hf_cache 65 | HF_HOME=/workspace/prosetting/hf_home 66 | -------------------------------------------------------------------------------- /config/.env.example: -------------------------------------------------------------------------------- 1 | # ProSetting Training System Environment Variables Configuration 2 | # Copy this file to .env and modify according to your setup 3 | 4 | # ==================== Model Path Configuration ==================== 5 | SOLVER_MODEL_PATH=/path/to/your/solver/model 6 | GENERATOR_MODEL_PATH=/path/to/your/generator/model 7 | 8 | # ==================== Data File Configuration ==================== 9 | QUESTIONS_FILE=/path/to/your/questions.json 10 | VALIDATION_DATA_FILE=./validation_data.json 11 | 12 | # ==================== Workspace Configuration ==================== 13 | WORKSPACE_DIR=/workspace/prosetting 14 | TRAINING_DATA_DIR=/workspace/prosetting/training_data 15 | TRAJECTORY_DATA_DIR=/workspace/prosetting/trajectories 16 | CHECKPOINT_DIR=/workspace/prosetting/checkpoints 17 | LOG_DIR=/workspace/prosetting/logs 18 | RESULTS_DIR=/workspace/prosetting/results 19 | 20 | # ==================== Training Parameters Configuration ==================== 21 | TOTAL_ROUNDS=10 22 | SAVE_ROUNDS=3,4,5,6,7,8,9,10 23 | ATTEMPTS_PER_QUESTION=8 24 | 25 | # ==================== GPU Configuration ==================== 26 | PHYSICAL_SOLVER_GPU=4 27 | PHYSICAL_DPO_GPU=0,1,2,3,4,5,6,7 28 | 29 | # ==================== Teacher Model Configuration ==================== 30 | TEACHER_BASE_URL=http://localhost:8000 31 | TEACHER_API_KEY= 32 | TEACHER_MODEL_NAME=/path/to/teacher/model 33 | 
TEACHER_CONCURRENT_WORKERS=32 34 | TEACHER_TIMEOUT=300 35 | TEACHER_MAX_RETRIES=3 36 | TEACHER_RETRY_DELAY=1.0 37 | 38 | # ==================== TRL Training Configuration ==================== 39 | # TRL DPO training parameters (replaces VERL) 40 | TRL_NUM_PROCESSES=8 41 | TRL_MIXED_PRECISION=bf16 42 | TRL_MAX_STEPS=100 43 | TRL_PER_DEVICE_BATCH_SIZE=2 44 | TRL_MAX_LENGTH=2048 45 | TRL_LEARNING_RATE=5e-6 46 | TRL_WARMUP_STEPS=10 47 | TRL_GRADIENT_ACCUMULATION_STEPS=8 48 | 49 | # Model output configuration 50 | MODEL_OUTPUT_DIR=/workspace/prosetting/models 51 | 52 | # ==================== Automated Training Configuration ==================== 53 | # Auto retry configuration 54 | AUTO_RETRY_ENABLED=true 55 | AUTO_CONTINUE_ON_FAILURE=false 56 | CHECKPOINT_INTERVAL=1 57 | 58 | # ==================== Other Configuration ==================== 59 | DEBUG_MODE=false 60 | LOG_LEVEL=INFO 61 | KEEP_TEMP_FILES=true 62 | 63 | # ==================== Cache Configuration ==================== 64 | HF_CACHE_DIR=/workspace/prosetting/hf_cache 65 | HF_HOME=/workspace/prosetting/hf_home 66 | -------------------------------------------------------------------------------- /trainers/gpu_manager.py: -------------------------------------------------------------------------------- 1 | """ 2 | GPU Manager 3 | 4 | Manages GPU resources and environment for training. 
import os
import json
import logging
import subprocess
from pathlib import Path
from typing import List, Dict, Any, Optional

logger = logging.getLogger(__name__)


# ---- trainers/gpu_manager.py ----

class GPUManager:
    """
    GPU resource manager for training operations
    """

    # Upper bound in seconds for one nvidia-smi call, so a wedged driver
    # cannot block the training process indefinitely.
    NVIDIA_SMI_TIMEOUT = 30

    @staticmethod
    def _run_nvidia_smi(*args: str) -> "subprocess.CompletedProcess[str]":
        """Run nvidia-smi with the given arguments and capture its text output."""
        return subprocess.run(
            ['nvidia-smi', *args],
            capture_output=True,
            text=True,
            timeout=GPUManager.NVIDIA_SMI_TIMEOUT,
        )

    @staticmethod
    def setup_gpu_environment() -> None:
        """
        Setup GPU environment for training

        Sets ``CUDA_DEVICE_ORDER`` to ``PCI_BUS_ID`` and probes ``nvidia-smi``
        once. Probe failures are logged as warnings and never raised.
        """
        # Set CUDA device order
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

        # Check GPU availability
        try:
            result = GPUManager._run_nvidia_smi()
        except (OSError, subprocess.SubprocessError, ValueError) as e:
            # Missing binary, timeout, or undecodable output.
            logger.warning(f"GPU environment setup failed: {e}")
            return

        if result.returncode == 0:
            logger.info("GPU environment setup completed")
        else:
            logger.warning("nvidia-smi not available, GPU may not be accessible")

    @staticmethod
    def set_gpu_environment(gpu_ids: str) -> None:
        """
        Set GPU environment for specific GPUs

        Args:
            gpu_ids: Comma-separated GPU IDs (e.g., "0,1,2,3"). Whitespace
                around each entry is stripped; a bare integer is accepted too.
        """
        # Normalize "0, 1" to "0,1" rather than depending on how the CUDA
        # runtime treats embedded whitespace.
        normalized = ",".join(part.strip() for part in str(gpu_ids).split(","))
        os.environ["CUDA_VISIBLE_DEVICES"] = normalized
        logger.info(f"Set CUDA_VISIBLE_DEVICES to: {normalized}")

    @staticmethod
    def clear_gpu_memory() -> None:
        """Clear GPU memory cache (placeholder: only logs, frees nothing)"""
        # Placeholder for GPU memory clearing logic
        logger.info("GPU memory cleared")

    @staticmethod
    def cleanup_and_release_models() -> None:
        """Cleanup and release GPU models (placeholder: only logs)"""
        # Placeholder for model cleanup logic
        logger.info("GPU models cleaned up and released")

    @staticmethod
    def get_gpu_info() -> dict:
        """
        Get GPU information

        Returns:
            ``{"available": True, "details": <csv text>}`` when nvidia-smi
            succeeds, where the text holds name, total memory and used memory
            per GPU as printed by nvidia-smi; otherwise
            ``{"available": False, "error": <message>}``
        """
        try:
            result = GPUManager._run_nvidia_smi(
                '--query-gpu=name,memory.total,memory.used',
                '--format=csv,noheader,nounits',
            )
        except (OSError, subprocess.SubprocessError, ValueError) as e:
            return {"available": False, "error": str(e)}

        if result.returncode == 0:
            return {"available": True, "details": result.stdout.strip()}
        return {"available": False, "error": "nvidia-smi failed"}


# ---- managers/question_manager.py ----

class QuestionManager:
    """
    Question manager for training data
    """

    @staticmethod
    def load_questions_from_file(file_path: str) -> List[Dict[str, Any]]:
        """
        Load questions from JSON file

        Args:
            file_path: Path to questions file

        Returns:
            List of question dictionaries. An empty list is returned, and an
            error logged, when the file is missing, unreadable, not valid
            JSON, or its top-level value is not a list.
        """
        path = Path(file_path)

        try:
            with open(path, 'r', encoding='utf-8') as f:
                questions = json.load(f)
        except FileNotFoundError:
            logger.error(f"Questions file not found: {path}")
            return []
        except (OSError, ValueError, RecursionError) as e:
            # ValueError covers both JSON and text-decoding errors.
            logger.error(f"Failed to load questions from {path}: {e}")
            return []

        # Any other JSON document would break the declared return type and
        # make the count logged below misleading.
        if not isinstance(questions, list):
            logger.error(
                f"Failed to load questions from {path}: "
                f"expected a JSON list, got {type(questions).__name__}"
            )
            return []

        logger.info(f"Loaded {len(questions)} questions from {path}")
        return questions

    @staticmethod
    def validate_questions(questions: List[Dict[str, Any]]) -> bool:
        """
        Validate question format

        Args:
            questions: List of question dictionaries

        Returns:
            True if the list is non-empty and every entry is a dict that
            contains the key ``"question"``
        """
        if not questions:
            logger.warning("No questions provided for validation")
            return False

        required_fields = ["question"]

        for i, question in enumerate(questions):
            if not isinstance(question, dict):
                logger.error(f"Question {i} is not a dictionary")
                return False

            for field in required_fields:
                if field not in question:
                    logger.error(f"Question {i} missing required field: {field}")
                    return False

        logger.info(f"Validated {len(questions)} questions successfully")
        return True

    @staticmethod
    def filter_questions(questions: List[Dict[str, Any]],
                         criteria: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
        """
        Filter questions based on criteria

        Args:
            questions: List of question dictionaries
            criteria: Filtering criteria (not interpreted yet)

        Returns:
            The input list itself when ``criteria`` is empty or None;
            otherwise a new list that currently still holds every question
        """
        if not criteria:
            return questions

        # Placeholder filtering logic: real criteria are not applied yet, so
        # every question passes.
        filtered = list(questions)

        logger.info(f"Filtered {len(questions)} questions to {len(filtered)}")
        return filtered
| "Programming Language :: Python :: 3.10", 23 | "Programming Language :: Python :: 3.11", 24 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 25 | "Topic :: Software Development :: Libraries :: Python Modules", 26 | ] 27 | requires-python = ">=3.8" 28 | dependencies = [ 29 | "torch>=2.0.0", 30 | "transformers>=4.30.0", 31 | "accelerate>=0.20.0", 32 | "trl>=0.7.0", 33 | "vllm>=0.2.0", 34 | "pandas>=1.5.0", 35 | "pyarrow>=12.0.0", 36 | "numpy>=1.21.0", 37 | "requests>=2.28.0", 38 | "python-dotenv>=1.0.0", 39 | "tqdm>=4.64.0", 40 | "datasets>=2.12.0", 41 | "peft>=0.4.0", 42 | ] 43 | 44 | [project.optional-dependencies] 45 | dev = [ 46 | "pytest>=7.0.0", 47 | "pytest-cov>=4.0.0", 48 | "black>=23.0.0", 49 | "isort>=5.12.0", 50 | "flake8>=6.0.0", 51 | "mypy>=1.0.0", 52 | ] 53 | docs = [ 54 | "sphinx>=6.0.0", 55 | "sphinx-rtd-theme>=1.2.0", 56 | "myst-parser>=1.0.0", 57 | ] 58 | 59 | [project.urls] 60 | Homepage = "https://github.com/prosetting/prosetting" 61 | Repository = "https://github.com/prosetting/prosetting" 62 | Documentation = "https://prosetting.readthedocs.io" 63 | "Bug Tracker" = "https://github.com/prosetting/prosetting/issues" 64 | 65 | [project.scripts] 66 | prosetting-train = "scripts.auto_trainer:main" 67 | prosetting-eval = "scripts.evaluate_mean_at_k:main" 68 | prosetting-status = "tools.utils.status_checker:main" 69 | 70 | [tool.setuptools.packages.find] 71 | where = ["."] 72 | include = ["src*", "scripts*", "tools*"] 73 | 74 | [tool.setuptools.package-data] 75 | "*" = ["*.json", "*.yaml", "*.yml", "*.txt", "*.md"] 76 | 77 | [tool.black] 78 | line-length = 100 79 | target-version = ['py38', 'py39', 'py310', 'py311'] 80 | include = '\.pyi?$' 81 | extend-exclude = ''' 82 | /( 83 | # directories 84 | \.eggs 85 | | \.git 86 | | \.hg 87 | | \.mypy_cache 88 | | \.tox 89 | | \.venv 90 | | build 91 | | dist 92 | )/ 93 | ''' 94 | 95 | [tool.isort] 96 | profile = "black" 97 | line_length = 100 98 | multi_line_output = 3 99 | 
include_trailing_comma = true 100 | force_grid_wrap = 0 101 | use_parentheses = true 102 | ensure_newline_before_comments = true 103 | 104 | [tool.mypy] 105 | python_version = "3.8" 106 | warn_return_any = true 107 | warn_unused_configs = true 108 | disallow_untyped_defs = true 109 | disallow_incomplete_defs = true 110 | check_untyped_defs = true 111 | disallow_untyped_decorators = true 112 | no_implicit_optional = true 113 | warn_redundant_casts = true 114 | warn_unused_ignores = true 115 | warn_no_return = true 116 | warn_unreachable = true 117 | strict_equality = true 118 | 119 | [tool.pytest.ini_options] 120 | testpaths = ["tests"] 121 | python_files = ["test_*.py", "*_test.py"] 122 | python_classes = ["Test*"] 123 | python_functions = ["test_*"] 124 | addopts = [ 125 | "--strict-markers", 126 | "--strict-config", 127 | "--cov=src", 128 | "--cov-report=term-missing", 129 | "--cov-report=html", 130 | "--cov-report=xml", 131 | ] -------------------------------------------------------------------------------- /collectors/trajectory_collector.py: -------------------------------------------------------------------------------- 1 | """ 2 | Trajectory Collector 3 | 4 | Collects solver model trajectories for training data generation. 
import os
import logging
from typing import List, Dict, Any, Optional
from pathlib import Path

logger = logging.getLogger(__name__)


class TrajectoryCollector:
    """
    Collects solver-model trajectories; inference is still a placeholder
    """

    def __init__(self, physical_gpus: str = "0"):
        """
        Create a collector with no model loaded

        Args:
            physical_gpus: GPU IDs to use for inference (stored only)
        """
        self.solver_model = None
        self._model_loaded = False
        self.physical_gpus = physical_gpus

    def load_solver_model(self, model_path: str, force_load: bool = False) -> bool:
        """
        Mark the solver model as loaded

        Args:
            model_path: Path to solver model
            force_load: Reload even when a model is already marked as loaded

        Returns:
            True on success (including the already-loaded case), False if an
            exception was raised while loading
        """
        try:
            if force_load or not self._model_loaded:
                logger.info(f"Loading solver model from: {model_path}")

                # Placeholder: the real model loading would happen here.

                self._model_loaded = True
                logger.info("Solver model loaded successfully")
            else:
                logger.info("Solver model already loaded, skipping")
            return True

        except Exception as e:
            logger.error(f"Failed to load solver model: {e}")
            return False

    def collect_trajectories(self, questions: List[Dict[str, Any]],
                             attempts_per_question: int = 8) -> List[Dict[str, Any]]:
        """
        Produce ``attempts_per_question`` trajectories for each question

        Args:
            questions: List of questions to process
            attempts_per_question: Number of attempts per question

        Returns:
            Flat list of trajectories, grouped by question in input order and
            numbered from attempt 1; empty if no model has been loaded
        """
        if not self._model_loaded:
            logger.error("Solver model not loaded")
            return []

        collected = []

        for position, question in enumerate(questions, start=1):
            logger.info(f"Processing question {position}/{len(questions)}")

            # Placeholder for actual inference logic
            collected.extend(
                {
                    "question": question.get("question", ""),
                    "attempt": attempt_no,
                    "response": f"Sample response for attempt {attempt_no}",
                    "reasoning_steps": f"Sample reasoning for question {position}, attempt {attempt_no}"
                }
                for attempt_no in range(1, attempts_per_question + 1)
            )

        logger.info(f"Collected {len(collected)} trajectories")
        return collected

    def release_model(self):
        """Drop the model reference and reset the loaded flag, if loaded"""
        if not self._model_loaded:
            return

        logger.info("Releasing solver model")
        self.solver_model = None
        self._model_loaded = False
import os
import json
import logging
from pathlib import Path
from typing import List, Dict, Any, Optional

logger = logging.getLogger(__name__)


class TRLTrainer:
    """
    TRL DPO trainer using accelerate for distributed training

    Training itself is still a placeholder: the configuration is read and
    logged, and for rounds in ``save_rounds`` a small ``model_info.json``
    marker is written instead of real weights.
    """

    def __init__(self, save_rounds: Optional[List[int]] = None) -> None:
        """
        Initialize TRL trainer

        Args:
            save_rounds: List of rounds to save model checkpoints. ``None``
                and an empty list both fall back to rounds 3 through 10.
        """
        # Copy so that later mutation of the caller's list cannot change
        # which rounds are saved.
        self.save_rounds = list(save_rounds) if save_rounds else [3, 4, 5, 6, 7, 8, 9, 10]
        self.workspace_dir = Path(os.getenv("WORKSPACE_DIR", "/workspace/prosetting"))
        self.model_output_dir = Path(os.getenv("MODEL_OUTPUT_DIR", "/workspace/prosetting/models"))

    @staticmethod
    def _build_training_config() -> Dict[str, Any]:
        """Read the TRL_* environment variables into a typed config dict."""
        return {
            "num_processes": int(os.getenv("TRL_NUM_PROCESSES", "8")),
            "mixed_precision": os.getenv("TRL_MIXED_PRECISION", "bf16"),
            "max_steps": int(os.getenv("TRL_MAX_STEPS", "100")),
            "per_device_batch_size": int(os.getenv("TRL_PER_DEVICE_BATCH_SIZE", "2")),
            # TRL_MAX_LENGTH is part of the documented .env settings but was
            # previously never read.
            "max_length": int(os.getenv("TRL_MAX_LENGTH", "2048")),
            "learning_rate": float(os.getenv("TRL_LEARNING_RATE", "5e-6")),
            "warmup_steps": int(os.getenv("TRL_WARMUP_STEPS", "10")),
            "gradient_accumulation_steps": int(os.getenv("TRL_GRADIENT_ACCUMULATION_STEPS", "8"))
        }

    @staticmethod
    def _base_model_path() -> str:
        """Return the original solver model path from the environment."""
        return os.getenv("SOLVER_MODEL_PATH", "/path/to/original/model")

    def run_trl_training(self, dataset_dir: str, round_num: int) -> bool:
        """
        Run TRL DPO training

        Args:
            dataset_dir: Directory containing training dataset
            round_num: Current training round number

        Returns:
            True if training successful, False otherwise. Any exception,
            including a malformed TRL_* environment value, is logged and
            reported as False.
        """
        try:
            logger.info(f"Starting TRL DPO training for round {round_num}")
            logger.info(f"Dataset directory: {dataset_dir}")

            # Placeholder for actual TRL training implementation
            # In real implementation, would use TRL's DPOTrainer with accelerate
            training_config = self._build_training_config()
            logger.info(f"Training configuration: {training_config}")

            # Simulate training process
            logger.info("Executing TRL DPO training...")

            # Save model if this round should be saved
            if round_num in self.save_rounds:
                output_path = self._get_output_path(round_num)
                logger.info(f"Saving model to: {output_path}")
                self._save_model_placeholder(output_path)

            logger.info(f"TRL training completed successfully for round {round_num}")
            return True

        except Exception as e:
            # Round-level boundary: callers rely on a bool, so log and report.
            logger.error(f"TRL training failed for round {round_num}: {e}")
            return False

    def get_model_path_for_round(self, round_num: int) -> str:
        """
        Get model path for a specific round

        Args:
            round_num: Round number

        Returns:
            Path to model for the round: the original solver model for rounds
            1 and 2, otherwise the most recent earlier checkpoint that is in
            ``save_rounds`` and exists on disk, falling back to the original
            model when there is none
        """
        if round_num <= 2:
            # First 2 rounds use original model
            return self._base_model_path()

        # Find the latest saved model before this round
        for check_round in range(round_num - 1, 0, -1):
            if check_round not in self.save_rounds:
                continue
            model_path = self._get_output_path(check_round)
            if Path(model_path).exists():
                return model_path

        # Fallback to original model
        return self._base_model_path()

    def _get_output_path(self, round_num: int) -> str:
        """Get output path for model checkpoint"""
        return str(self.model_output_dir / f"round_{round_num}_model")

    def _save_model_placeholder(self, output_path: str) -> None:
        """
        Placeholder for model saving logic

        Creates ``output_path`` as a directory and writes a ``model_info.json``
        marker into it, which makes the round discoverable by
        get_model_path_for_round().
        """
        output_dir = Path(output_path)
        output_dir.mkdir(parents=True, exist_ok=True)

        # Create placeholder model file
        model_info = {
            "model_type": "TRL_DPO_trained",
            "training_framework": "TRL",
            "saved_at": "placeholder_timestamp"
        }

        with open(output_dir / "model_info.json", 'w', encoding='utf-8') as f:
            json.dump(model_info, f, indent=2)

        logger.info(f"Model placeholder saved to: {output_path}")
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # Socratic-Zero: Bootstrapping Reasoning via Data-Free Agent Co-evolution [arXiv:2509.24726](http://arxiv.org/abs/2509.24726) [Python 3.10.13](https://www.python.org/downloads/release/python-31013/) [License](#license) 4 | 5 | 6 | ## Overview 7 | 8 | Socratic-Zero is a fully autonomous framework that generates high-quality training data for mathematical reasoning from minimal seed examples through the co-evolution of three agents: the *Solver*, the *Teacher*, and the *Generator*. Starting from only 100 seed questions, our approach achieves significant improvements without relying on massive external datasets. 9 |
18 |
23 |
24 |
25 | The Socratic-Zero Framework Pipeline
26 |
32 |