├── .gitnore ├── __init__.py ├── core ├── __init__.py └── state_manager.py ├── trainers ├── __init__.py ├── gpu_manager.py └── trl_trainer.py ├── managers ├── __init__.py ├── round_controller.py └── question_manager.py ├── collectors ├── __init__.py ├── data_normalizer.py └── trajectory_collector.py ├── processors ├── __init__.py ├── question_enhancer.py ├── reward_calculator.py └── solver_data_processor.py ├── requirements.txt ├── .env.example ├── config └── .env.example ├── pyproject.toml ├── README.md └── tools └── utils └── status_checker.py /.gitnore: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Core Module 3 | 4 | This module contains core components for state management and system operations. 5 | """ 6 | 7 | from .state_manager import StateManager 8 | 9 | __all__ = [ 10 | 'StateManager' 11 | ] -------------------------------------------------------------------------------- /trainers/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Training Module 3 | 4 | This module contains components for model training and GPU management. 5 | """ 6 | 7 | from .trl_trainer import TRLTrainer 8 | from .gpu_manager import GPUManager 9 | 10 | __all__ = [ 11 | 'TRLTrainer', 12 | 'GPUManager' 13 | ] -------------------------------------------------------------------------------- /managers/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Management Module 3 | 4 | This module contains components for round and question management. 
5 | """ 6 | 7 | from .round_controller import RoundController 8 | from .question_manager import QuestionManager 9 | 10 | __all__ = [ 11 | 'RoundController', 12 | 'QuestionManager' 13 | ] -------------------------------------------------------------------------------- /collectors/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Data Collection Module 3 | 4 | This module contains components for collecting training data and trajectories. 5 | """ 6 | 7 | from .trajectory_collector import TrajectoryCollector 8 | from .data_normalizer import DataNormalizer 9 | 10 | __all__ = [ 11 | 'TrajectoryCollector', 12 | 'DataNormalizer' 13 | ] -------------------------------------------------------------------------------- /processors/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Data Processing Module 3 | 4 | This module contains components for processing and analyzing training data. 5 | """ 6 | 7 | from .reward_calculator import RewardCalculator 8 | from .question_enhancer import QuestionEnhancer 9 | from .solver_data_processor import SolverDataProcessor 10 | 11 | __all__ = [ 12 | 'RewardCalculator', 13 | 'QuestionEnhancer', 14 | 'SolverDataProcessor' 15 | ] -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # ProSetting Progressive Training System Dependencies 2 | 3 | # Core deep learning frameworks 4 | torch>=2.0.0 5 | transformers>=4.35.0 6 | tokenizers>=0.14.0 7 | datasets>=2.14.0 8 | 9 | # TRL Training Framework (Alternative to VERL) 10 | trl>=0.7.0 11 | accelerate>=0.24.0 12 | peft>=0.6.0 13 | 14 | # Data Processing 15 | pandas>=2.0.0 16 | numpy>=1.24.0 17 | pyarrow>=13.0.0 # parquet support 18 | fastparquet>=2023.8.0 19 | 20 | # Distributed Training and Optimization 21 | deepspeed>=0.12.0 22 | bitsandbytes>=0.41.0 23 | 
class DataNormalizer:
    """
    Converts raw trajectory records into one uniform schema
    """

    @staticmethod
    def normalize_trajectories(trajectories: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Rebuild every trajectory with a fixed set of keys

        Args:
            trajectories: Raw trajectory dictionaries; any of the expected
                keys may be missing

        Returns:
            One new dictionary per input record holding "question",
            "attempt", "response" and "reasoning_steps" (defaulted when
            absent) plus a "metadata" entry that lists the raw record's keys
        """

        def _normalize_one(raw: Dict[str, Any]) -> Dict[str, Any]:
            # Missing fields fall back to "" (or 1 for the attempt counter)
            return {
                "question": raw.get("question", ""),
                "attempt": raw.get("attempt", 1),
                "response": raw.get("response", ""),
                "reasoning_steps": raw.get("reasoning_steps", ""),
                "metadata": {
                    "original_keys": list(raw.keys()),
                    "normalized": True
                }
            }

        result = [_normalize_one(raw) for raw in trajectories]

        logger.info(f"Normalized {len(result)} trajectories")
        return result
class QuestionEnhancer:
    """
    Rewrites failed questions once a teacher model client is attached
    """

    def __init__(self):
        # No teacher is attached until set_teacher_client() is called
        self.teacher_client = None

    def set_teacher_client(self, teacher_client):
        """Attach the teacher model client used for enhancement"""
        self.teacher_client = teacher_client

    def enhance_questions(self, failed_questions: List[str]) -> List[str]:
        """
        Produce enhanced versions of the questions that failed

        Args:
            failed_questions: Questions the solver did not answer correctly

        Returns:
            A new list of enhanced questions, or the input list itself
            (same object) when no teacher client has been set
        """
        if self.teacher_client:
            # Placeholder enhancement logic: the teacher is not queried yet
            rewritten = [f"Enhanced: {item}" for item in failed_questions]
            logger.info(f"Enhanced {len(rewritten)} questions")
            return rewritten

        logger.warning("Teacher client not set, returning original questions")
        return failed_questions
class RewardCalculator:
    """
    Assigns rewards to solver trajectories
    """

    def __init__(self):
        # No teacher is attached until set_teacher_client() is called
        self.teacher_client = None

    def set_teacher_client(self, teacher_client):
        """Attach the teacher model client used for evaluation"""
        self.teacher_client = teacher_client

    def compute_solver_reward_local(self, trajectories: List[Dict[str, Any]]) -> Tuple[List[float], List[Dict[str, Any]]]:
        """
        Score every trajectory and build its judge record

        Args:
            trajectories: Trajectory dictionaries; "question" and "response"
                are read and default to "" when missing

        Returns:
            Tuple of (rewards, judge_results), index-aligned with the input
        """
        if not self.teacher_client:
            logger.warning("Teacher client not set, using placeholder rewards")

        has_teacher = self.teacher_client is not None
        # Placeholder reward calculation: every trajectory gets the same
        # neutral score and its response is filed under incorrect answers
        neutral_reward = 0.5

        judge_results = [
            {
                "question": item.get("question", ""),
                "response": item.get("response", ""),
                "reward": neutral_reward,
                "correct_answers": [],
                "incorrect_answers": [item.get("response", "")],
                "evaluation_details": {
                    "teacher_used": has_teacher,
                    "placeholder": True
                }
            }
            for item in trajectories
        ]
        rewards = [entry["reward"] for entry in judge_results]

        logger.info(f"Computed rewards for {len(trajectories)} trajectories")
        return rewards, judge_results
class RoundController:
    """
    Training round controller

    Knows how many rounds a run has and which of them produce checkpoints.
    """

    # Rounds that are checkpointed when the caller does not specify any
    DEFAULT_SAVE_ROUNDS = (3, 4, 5, 6, 7, 8, 9, 10)

    def __init__(self, max_rounds: int = 10, save_rounds: List[int] = None):
        """
        Initialize round controller

        Args:
            max_rounds: Maximum number of training rounds
            save_rounds: Rounds to save checkpoints for. ``None`` selects
                DEFAULT_SAVE_ROUNDS; an empty list means "never save".
        """
        self.max_rounds = max_rounds
        # Compare against None instead of relying on truthiness: an explicit
        # empty list is a valid request and must not be replaced by defaults.
        if save_rounds is None:
            save_rounds = self.DEFAULT_SAVE_ROUNDS
        # Store a private copy so the caller's list is never shared
        self.save_rounds = list(save_rounds)
        self.current_round = 1

    def should_save_checkpoint(self, round_num: int) -> bool:
        """
        Check if checkpoint should be saved for this round

        Args:
            round_num: Current round number

        Returns:
            True if checkpoint should be saved
        """
        return round_num in self.save_rounds

    def is_final_round(self, round_num: int) -> bool:
        """
        Check if this is the final round

        Args:
            round_num: Current round number

        Returns:
            True if round_num has reached (or passed) max_rounds
        """
        return round_num >= self.max_rounds

    def get_round_info(self, round_num: int) -> Dict[str, Any]:
        """
        Get information about a specific round

        Args:
            round_num: Round number

        Returns:
            Dictionary with "round_num", "max_rounds", "should_save",
            "is_final" and "progress" (round_num / max_rounds, or 1.0 when
            max_rounds is not positive)
        """
        # A non-positive max_rounds leaves nothing to progress through;
        # report completion rather than dividing by zero.
        if self.max_rounds > 0:
            progress = round_num / self.max_rounds
        else:
            progress = 1.0

        return {
            "round_num": round_num,
            "max_rounds": self.max_rounds,
            "should_save": self.should_save_checkpoint(round_num),
            "is_final": self.is_final_round(round_num),
            "progress": progress
        }
class SolverDataProcessor:
    """
    Persists judge results of each training round under the workspace
    """

    def __init__(self):
        # Workspace root comes from WORKSPACE_DIR, with a fixed fallback
        root = os.getenv("WORKSPACE_DIR", "/workspace/prosetting")
        self.workspace_dir = Path(root)

    def save_judge_results(self, judge_results: List[Dict[str, Any]], round_num: int) -> str:
        """
        Write the judge results of one round to a JSON file

        Args:
            judge_results: List of judge result dictionaries
            round_num: Training round number, used in the file name

        Returns:
            Path of the written file, as a string
        """
        target = self.workspace_dir / "data" / f"round_{round_num}_judge_results.json"
        # The data directory is created on demand
        target.parent.mkdir(parents=True, exist_ok=True)

        with target.open('w', encoding='utf-8') as handle:
            json.dump(judge_results, handle, indent=2, ensure_ascii=False)

        logger.info(f"Saved judge results to: {target}")
        return str(target)

    def load_judge_results(self, round_num: int) -> List[Dict[str, Any]]:
        """
        Read back the judge results of one round

        Args:
            round_num: Training round number, used in the file name

        Returns:
            The decoded JSON content, or an empty list when the file for
            that round does not exist
        """
        source = self.workspace_dir / "data" / f"round_{round_num}_judge_results.json"

        if source.exists():
            with source.open('r', encoding='utf-8') as handle:
                loaded = json.load(handle)

            logger.info(f"Loaded {len(loaded)} judge results from: {source}")
            return loaded

        logger.warning(f"Judge results file not found: {source}")
        return []
Environment Variables Configuration 2 | # Copy this file to .env and modify according to your setup 3 | 4 | # ==================== Model Path Configuration ==================== 5 | SOLVER_MODEL_PATH=/path/to/your/solver/model 6 | GENERATOR_MODEL_PATH=/path/to/your/generator/model 7 | 8 | # ==================== Data File Configuration ==================== 9 | QUESTIONS_FILE=/path/to/your/questions.json 10 | VALIDATION_DATA_FILE=./validation_data.json 11 | 12 | # ==================== Workspace Configuration ==================== 13 | WORKSPACE_DIR=/workspace/prosetting 14 | TRAINING_DATA_DIR=/workspace/prosetting/training_data 15 | TRAJECTORY_DATA_DIR=/workspace/prosetting/trajectories 16 | CHECKPOINT_DIR=/workspace/prosetting/checkpoints 17 | LOG_DIR=/workspace/prosetting/logs 18 | RESULTS_DIR=/workspace/prosetting/results 19 | 20 | # ==================== Training Parameters Configuration ==================== 21 | TOTAL_ROUNDS=10 22 | SAVE_ROUNDS=3,4,5,6,7,8,9,10 23 | ATTEMPTS_PER_QUESTION=8 24 | 25 | # ==================== GPU Configuration ==================== 26 | PHYSICAL_SOLVER_GPU=4 27 | PHYSICAL_DPO_GPU=0,1,2,3,4,5,6,7 28 | 29 | # ==================== Teacher Model Configuration ==================== 30 | TEACHER_BASE_URL=http://localhost:8000 31 | TEACHER_API_KEY= 32 | TEACHER_MODEL_NAME=/path/to/teacher/model 33 | TEACHER_CONCURRENT_WORKERS=32 34 | TEACHER_TIMEOUT=300 35 | TEACHER_MAX_RETRIES=3 36 | TEACHER_RETRY_DELAY=1.0 37 | 38 | # ==================== TRL Training Configuration ==================== 39 | # TRL DPO training parameters (replaces VERL) 40 | TRL_NUM_PROCESSES=8 41 | TRL_MIXED_PRECISION=bf16 42 | TRL_MAX_STEPS=100 43 | TRL_PER_DEVICE_BATCH_SIZE=2 44 | TRL_MAX_LENGTH=2048 45 | TRL_LEARNING_RATE=5e-6 46 | TRL_WARMUP_STEPS=10 47 | TRL_GRADIENT_ACCUMULATION_STEPS=8 48 | 49 | # Model output configuration 50 | MODEL_OUTPUT_DIR=/workspace/prosetting/models 51 | 52 | # ==================== Automated Training Configuration ==================== 53 | # 
Auto retry configuration 54 | AUTO_RETRY_ENABLED=true 55 | AUTO_CONTINUE_ON_FAILURE=false 56 | CHECKPOINT_INTERVAL=1 57 | 58 | # ==================== Other Configuration ==================== 59 | DEBUG_MODE=false 60 | LOG_LEVEL=INFO 61 | KEEP_TEMP_FILES=true 62 | 63 | # ==================== Cache Configuration ==================== 64 | HF_CACHE_DIR=/workspace/prosetting/hf_cache 65 | HF_HOME=/workspace/prosetting/hf_home 66 | -------------------------------------------------------------------------------- /config/.env.example: -------------------------------------------------------------------------------- 1 | # ProSetting Training System Environment Variables Configuration 2 | # Copy this file to .env and modify according to your setup 3 | 4 | # ==================== Model Path Configuration ==================== 5 | SOLVER_MODEL_PATH=/path/to/your/solver/model 6 | GENERATOR_MODEL_PATH=/path/to/your/generator/model 7 | 8 | # ==================== Data File Configuration ==================== 9 | QUESTIONS_FILE=/path/to/your/questions.json 10 | VALIDATION_DATA_FILE=./validation_data.json 11 | 12 | # ==================== Workspace Configuration ==================== 13 | WORKSPACE_DIR=/workspace/prosetting 14 | TRAINING_DATA_DIR=/workspace/prosetting/training_data 15 | TRAJECTORY_DATA_DIR=/workspace/prosetting/trajectories 16 | CHECKPOINT_DIR=/workspace/prosetting/checkpoints 17 | LOG_DIR=/workspace/prosetting/logs 18 | RESULTS_DIR=/workspace/prosetting/results 19 | 20 | # ==================== Training Parameters Configuration ==================== 21 | TOTAL_ROUNDS=10 22 | SAVE_ROUNDS=3,4,5,6,7,8,9,10 23 | ATTEMPTS_PER_QUESTION=8 24 | 25 | # ==================== GPU Configuration ==================== 26 | PHYSICAL_SOLVER_GPU=4 27 | PHYSICAL_DPO_GPU=0,1,2,3,4,5,6,7 28 | 29 | # ==================== Teacher Model Configuration ==================== 30 | TEACHER_BASE_URL=http://localhost:8000 31 | TEACHER_API_KEY= 32 | TEACHER_MODEL_NAME=/path/to/teacher/model 33 | 
TEACHER_CONCURRENT_WORKERS=32 34 | TEACHER_TIMEOUT=300 35 | TEACHER_MAX_RETRIES=3 36 | TEACHER_RETRY_DELAY=1.0 37 | 38 | # ==================== TRL Training Configuration ==================== 39 | # TRL DPO training parameters (replaces VERL) 40 | TRL_NUM_PROCESSES=8 41 | TRL_MIXED_PRECISION=bf16 42 | TRL_MAX_STEPS=100 43 | TRL_PER_DEVICE_BATCH_SIZE=2 44 | TRL_MAX_LENGTH=2048 45 | TRL_LEARNING_RATE=5e-6 46 | TRL_WARMUP_STEPS=10 47 | TRL_GRADIENT_ACCUMULATION_STEPS=8 48 | 49 | # Model output configuration 50 | MODEL_OUTPUT_DIR=/workspace/prosetting/models 51 | 52 | # ==================== Automated Training Configuration ==================== 53 | # Auto retry configuration 54 | AUTO_RETRY_ENABLED=true 55 | AUTO_CONTINUE_ON_FAILURE=false 56 | CHECKPOINT_INTERVAL=1 57 | 58 | # ==================== Other Configuration ==================== 59 | DEBUG_MODE=false 60 | LOG_LEVEL=INFO 61 | KEEP_TEMP_FILES=true 62 | 63 | # ==================== Cache Configuration ==================== 64 | HF_CACHE_DIR=/workspace/prosetting/hf_cache 65 | HF_HOME=/workspace/prosetting/hf_home 66 | -------------------------------------------------------------------------------- /trainers/gpu_manager.py: -------------------------------------------------------------------------------- 1 | """ 2 | GPU Manager 3 | 4 | Manages GPU resources and environment for training. 
class GPUManager:
    """
    GPU resource manager for training operations

    All methods are static and best-effort: failures are logged (or
    reported in the returned dictionary) instead of being raised.
    """

    # Upper bound, in seconds, for a single nvidia-smi invocation so that an
    # unresponsive tool cannot block the caller forever
    NVIDIA_SMI_TIMEOUT = 30

    @staticmethod
    def setup_gpu_environment():
        """
        Setup GPU environment for training

        Sets CUDA_DEVICE_ORDER to "PCI_BUS_ID" and probes nvidia-smi,
        logging whether the probe succeeded.
        """
        try:
            # Set CUDA device order
            os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

            # Check GPU availability; a timeout surfaces as
            # subprocess.TimeoutExpired and is handled below like any
            # other probe failure
            result = subprocess.run(
                ['nvidia-smi'],
                capture_output=True,
                text=True,
                timeout=GPUManager.NVIDIA_SMI_TIMEOUT,
            )
            if result.returncode == 0:
                logger.info("GPU environment setup completed")
            else:
                logger.warning("nvidia-smi not available, GPU may not be accessible")

        except Exception as e:
            # Best-effort: a missing or hanging nvidia-smi must not abort setup
            logger.warning(f"GPU environment setup failed: {e}")

    @staticmethod
    def set_gpu_environment(gpu_ids: str):
        """
        Set GPU environment for specific GPUs

        Args:
            gpu_ids: Comma-separated GPU IDs (e.g., "0,1,2,3"), written
                unchanged to CUDA_VISIBLE_DEVICES
        """
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu_ids
        logger.info(f"Set CUDA_VISIBLE_DEVICES to: {gpu_ids}")

    @staticmethod
    def clear_gpu_memory():
        """
        Clear GPU memory cache

        Currently a placeholder: it only writes a log line.
        """
        try:
            # Placeholder for GPU memory clearing logic
            logger.info("GPU memory cleared")
        except Exception as e:
            logger.warning(f"Failed to clear GPU memory: {e}")

    @staticmethod
    def cleanup_and_release_models():
        """
        Cleanup and release GPU models

        Currently a placeholder: it only writes a log line.
        """
        try:
            # Placeholder for model cleanup logic
            logger.info("GPU models cleaned up and released")
        except Exception as e:
            logger.warning(f"Failed to cleanup GPU models: {e}")

    @staticmethod
    def get_gpu_info() -> dict:
        """
        Get GPU information

        Returns:
            {"available": True, "details": <stripped nvidia-smi CSV output>}
            on success, otherwise {"available": False, "error": <message>}
        """
        query = [
            'nvidia-smi',
            '--query-gpu=name,memory.total,memory.used',
            '--format=csv,noheader,nounits',
        ]
        try:
            result = subprocess.run(
                query,
                capture_output=True,
                text=True,
                timeout=GPUManager.NVIDIA_SMI_TIMEOUT,
            )
            if result.returncode == 0:
                gpu_info = {"available": True, "details": result.stdout.strip()}
            else:
                gpu_info = {"available": False, "error": "nvidia-smi failed"}
        except Exception as e:
            # Covers a missing binary as well as a timed-out query
            gpu_info = {"available": False, "error": str(e)}

        return gpu_info
for i, question in enumerate(questions): 66 | if not isinstance(question, dict): 67 | logger.error(f"Question {i} is not a dictionary") 68 | return False 69 | 70 | for field in required_fields: 71 | if field not in question: 72 | logger.error(f"Question {i} missing required field: {field}") 73 | return False 74 | 75 | logger.info(f"Validated {len(questions)} questions successfully") 76 | return True 77 | 78 | @staticmethod 79 | def filter_questions(questions: List[Dict[str, Any]], 80 | criteria: Dict[str, Any] = None) -> List[Dict[str, Any]]: 81 | """ 82 | Filter questions based on criteria 83 | 84 | Args: 85 | questions: List of question dictionaries 86 | criteria: Filtering criteria 87 | 88 | Returns: 89 | Filtered list of questions 90 | """ 91 | if not criteria: 92 | return questions 93 | 94 | filtered = [] 95 | for question in questions: 96 | # Placeholder filtering logic 97 | # In real implementation, would apply actual filtering criteria 98 | filtered.append(question) 99 | 100 | logger.info(f"Filtered {len(questions)} questions to {len(filtered)}") 101 | return filtered -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "prosetting" 7 | version = "1.0.0" 8 | description = "Progressive Training System for Math Problem Solving" 9 | readme = "docs/README.md" 10 | license = {text = "MIT"} 11 | authors = [ 12 | {name = "ProSetting Team", email = "prosetting@example.com"} 13 | ] 14 | classifiers = [ 15 | "Development Status :: 4 - Beta", 16 | "Intended Audience :: Developers", 17 | "Intended Audience :: Science/Research", 18 | "License :: OSI Approved :: MIT License", 19 | "Programming Language :: Python :: 3", 20 | "Programming Language :: Python :: 3.8", 21 | "Programming Language :: Python :: 3.9", 22 
| "Programming Language :: Python :: 3.10", 23 | "Programming Language :: Python :: 3.11", 24 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 25 | "Topic :: Software Development :: Libraries :: Python Modules", 26 | ] 27 | requires-python = ">=3.8" 28 | dependencies = [ 29 | "torch>=2.0.0", 30 | "transformers>=4.30.0", 31 | "accelerate>=0.20.0", 32 | "trl>=0.7.0", 33 | "vllm>=0.2.0", 34 | "pandas>=1.5.0", 35 | "pyarrow>=12.0.0", 36 | "numpy>=1.21.0", 37 | "requests>=2.28.0", 38 | "python-dotenv>=1.0.0", 39 | "tqdm>=4.64.0", 40 | "datasets>=2.12.0", 41 | "peft>=0.4.0", 42 | ] 43 | 44 | [project.optional-dependencies] 45 | dev = [ 46 | "pytest>=7.0.0", 47 | "pytest-cov>=4.0.0", 48 | "black>=23.0.0", 49 | "isort>=5.12.0", 50 | "flake8>=6.0.0", 51 | "mypy>=1.0.0", 52 | ] 53 | docs = [ 54 | "sphinx>=6.0.0", 55 | "sphinx-rtd-theme>=1.2.0", 56 | "myst-parser>=1.0.0", 57 | ] 58 | 59 | [project.urls] 60 | Homepage = "https://github.com/prosetting/prosetting" 61 | Repository = "https://github.com/prosetting/prosetting" 62 | Documentation = "https://prosetting.readthedocs.io" 63 | "Bug Tracker" = "https://github.com/prosetting/prosetting/issues" 64 | 65 | [project.scripts] 66 | prosetting-train = "scripts.auto_trainer:main" 67 | prosetting-eval = "scripts.evaluate_mean_at_k:main" 68 | prosetting-status = "tools.utils.status_checker:main" 69 | 70 | [tool.setuptools.packages.find] 71 | where = ["."] 72 | include = ["src*", "scripts*", "tools*"] 73 | 74 | [tool.setuptools.package-data] 75 | "*" = ["*.json", "*.yaml", "*.yml", "*.txt", "*.md"] 76 | 77 | [tool.black] 78 | line-length = 100 79 | target-version = ['py38', 'py39', 'py310', 'py311'] 80 | include = '\.pyi?$' 81 | extend-exclude = ''' 82 | /( 83 | # directories 84 | \.eggs 85 | | \.git 86 | | \.hg 87 | | \.mypy_cache 88 | | \.tox 89 | | \.venv 90 | | build 91 | | dist 92 | )/ 93 | ''' 94 | 95 | [tool.isort] 96 | profile = "black" 97 | line_length = 100 98 | multi_line_output = 3 99 | 
include_trailing_comma = true 100 | force_grid_wrap = 0 101 | use_parentheses = true 102 | ensure_newline_before_comments = true 103 | 104 | [tool.mypy] 105 | python_version = "3.8" 106 | warn_return_any = true 107 | warn_unused_configs = true 108 | disallow_untyped_defs = true 109 | disallow_incomplete_defs = true 110 | check_untyped_defs = true 111 | disallow_untyped_decorators = true 112 | no_implicit_optional = true 113 | warn_redundant_casts = true 114 | warn_unused_ignores = true 115 | warn_no_return = true 116 | warn_unreachable = true 117 | strict_equality = true 118 | 119 | [tool.pytest.ini_options] 120 | testpaths = ["tests"] 121 | python_files = ["test_*.py", "*_test.py"] 122 | python_classes = ["Test*"] 123 | python_functions = ["test_*"] 124 | addopts = [ 125 | "--strict-markers", 126 | "--strict-config", 127 | "--cov=src", 128 | "--cov-report=term-missing", 129 | "--cov-report=html", 130 | "--cov-report=xml", 131 | ] -------------------------------------------------------------------------------- /collectors/trajectory_collector.py: -------------------------------------------------------------------------------- 1 | """ 2 | Trajectory Collector 3 | 4 | Collects solver model trajectories for training data generation. 
class TrajectoryCollector:
    """
    Trajectory collector for solver model inference

    Model loading and inference are placeholders: no model is actually
    loaded and the collected responses are synthetic sample strings.
    """

    def __init__(self, physical_gpus: str = "0"):
        """
        Initialize trajectory collector

        Args:
            physical_gpus: GPU IDs to use for inference (stored only)
        """
        self.physical_gpus = physical_gpus
        self._model_loaded = False
        self.solver_model = None

    def load_solver_model(self, model_path: str, force_load: bool = False) -> bool:
        """
        Load solver model for inference

        Args:
            model_path: Path to solver model
            force_load: Force reload model even if already loaded

        Returns:
            True if model loaded successfully, False otherwise
        """
        try:
            if self._model_loaded and not force_load:
                logger.info("Solver model already loaded, skipping")
                return True

            logger.info(f"Loading solver model from: {model_path}")

            # Model loading logic would go here
            # This is a placeholder for the actual model loading implementation

            self._model_loaded = True
            logger.info("Solver model loaded successfully")
            return True

        except Exception as e:
            logger.error(f"Failed to load solver model: {e}")
            return False

    @staticmethod
    def _question_text(question: Any) -> str:
        """Return the question text of a dict entry or of a plain string"""
        if isinstance(question, str):
            return question
        return question.get("question", "")

    def collect_trajectories(self, questions: List[Dict[str, Any]],
                             attempts_per_question: int = 8) -> List[Dict[str, Any]]:
        """
        Collect trajectories for given questions

        Args:
            questions: Questions to process. Each entry is either a
                dictionary with a "question" key or a plain question string.
            attempts_per_question: Number of attempts per question

        Returns:
            List of collected trajectories, grouped by question and ordered
            by attempt; empty when the solver model has not been loaded
        """
        if not self._model_loaded:
            logger.error("Solver model not loaded")
            return []

        trajectories = []
        total = len(questions)

        for i, question in enumerate(questions):
            logger.info(f"Processing question {i+1}/{total}")

            # Plain strings are accepted as well as dictionaries so that a
            # List[str] of questions can be fed in without an AttributeError
            text = self._question_text(question)

            # Placeholder for actual inference logic
            trajectories.extend(
                {
                    "question": text,
                    "attempt": attempt + 1,
                    "response": f"Sample response for attempt {attempt + 1}",
                    "reasoning_steps": f"Sample reasoning for question {i+1}, attempt {attempt + 1}"
                }
                for attempt in range(attempts_per_question)
            )

        logger.info(f"Collected {len(trajectories)} trajectories")
        return trajectories

    def release_model(self):
        """Release loaded model to free memory"""
        if self._model_loaded:
            logger.info("Releasing solver model")
            self.solver_model = None
            self._model_loaded = False
class TRLTrainer:
    """
    TRL DPO trainer using accelerate for distributed training

    The training step itself is a placeholder: the configuration is read
    from the environment and logged, and a small marker file is written for
    rounds that are configured to be saved.
    """

    def __init__(self, save_rounds: Optional[List[int]] = None):
        """
        Initialize TRL trainer

        Args:
            save_rounds: Rounds to save model checkpoints for. ``None``
                selects rounds 3-10; an empty list means "never save".
        """
        # Compare against None instead of relying on truthiness: an explicit
        # empty list is a valid request and must not be replaced by defaults.
        if save_rounds is None:
            save_rounds = [3, 4, 5, 6, 7, 8, 9, 10]
        self.save_rounds = list(save_rounds)
        self.workspace_dir = Path(os.getenv("WORKSPACE_DIR", "/workspace/prosetting"))
        self.model_output_dir = Path(os.getenv("MODEL_OUTPUT_DIR", "/workspace/prosetting/models"))

    def run_trl_training(self, dataset_dir: str, round_num: int) -> bool:
        """
        Run TRL DPO training

        Args:
            dataset_dir: Directory containing training dataset
            round_num: Current training round number

        Returns:
            True if training successful, False otherwise (including when a
            TRL_* environment variable cannot be parsed)
        """
        try:
            logger.info(f"Starting TRL DPO training for round {round_num}")
            logger.info(f"Dataset directory: {dataset_dir}")

            # Placeholder for actual TRL training implementation
            # In real implementation, would use TRL's DPOTrainer with accelerate

            # Training configuration, read from the TRL_* environment variables
            training_config = {
                "num_processes": int(os.getenv("TRL_NUM_PROCESSES", "8")),
                "mixed_precision": os.getenv("TRL_MIXED_PRECISION", "bf16"),
                "max_steps": int(os.getenv("TRL_MAX_STEPS", "100")),
                "per_device_batch_size": int(os.getenv("TRL_PER_DEVICE_BATCH_SIZE", "2")),
                # TRL_MAX_LENGTH is part of the documented .env settings
                "max_length": int(os.getenv("TRL_MAX_LENGTH", "2048")),
                "learning_rate": float(os.getenv("TRL_LEARNING_RATE", "5e-6")),
                "warmup_steps": int(os.getenv("TRL_WARMUP_STEPS", "10")),
                "gradient_accumulation_steps": int(os.getenv("TRL_GRADIENT_ACCUMULATION_STEPS", "8"))
            }

            logger.info(f"Training configuration: {training_config}")

            # Simulate training process
            logger.info("Executing TRL DPO training...")

            # Save model if this round should be saved
            if round_num in self.save_rounds:
                output_path = self._get_output_path(round_num)
                logger.info(f"Saving model to: {output_path}")
                self._save_model_placeholder(output_path)

            logger.info(f"TRL training completed successfully for round {round_num}")
            return True

        except Exception as e:
            logger.error(f"TRL training failed for round {round_num}: {e}")
            return False

    def get_model_path_for_round(self, round_num: int) -> str:
        """
        Get model path for a specific round

        Args:
            round_num: Round number

        Returns:
            Path of the most recent saved checkpoint from an earlier round
            that exists on disk, or SOLVER_MODEL_PATH when there is none or
            round_num is 1 or 2
        """
        original_model = os.getenv("SOLVER_MODEL_PATH", "/path/to/original/model")

        if round_num <= 2:
            # First 2 rounds use original model
            return original_model

        # Find the latest saved model before this round
        for check_round in range(round_num - 1, 0, -1):
            if check_round in self.save_rounds:
                model_path = self._get_output_path(check_round)
                if Path(model_path).exists():
                    return model_path

        # Fallback to original model
        return original_model

    def _get_output_path(self, round_num: int) -> str:
        """Get output path for model checkpoint"""
        return str(self.model_output_dir / f"round_{round_num}_model")

    def _save_model_placeholder(self, output_path: str):
        """Placeholder for model saving logic: writes only model_info.json"""
        # Imported locally because the module does not import json at top level
        import json

        output_dir = Path(output_path)
        output_dir.mkdir(parents=True, exist_ok=True)

        # Create placeholder model file
        model_info = {
            "model_type": "TRL_DPO_trained",
            "training_framework": "TRL",
            "saved_at": "placeholder_timestamp"
        }

        # Explicit encoding keeps the file independent of the platform default
        with open(output_dir / "model_info.json", 'w', encoding='utf-8') as f:
            json.dump(model_info, f, indent=2)

        logger.info(f"Model placeholder saved to: {output_path}")
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # Socratic-Zero: Bootstrapping Reasoning via Data-Free Agent Co-evolution [![arXiv](https://img.shields.io/badge/arXiv-2509.24726-b31b1b.svg)](http://arxiv.org/abs/2509.24726)[![Python 3.10.13](https://img.shields.io/badge/python-3.10.13-blue.svg)](https://www.python.org/downloads/release/python-31013/)[![License](https://img.shields.io/badge/License-Research-green.svg)](#license) 4 | 5 | 6 | ## Overview 7 | 8 | Socratic-Zero is a fully autonomous framework that generates high-quality training data for mathematical reasoning from minimal seed examples through the co-evolution of three agents: the *Solver*, the *Teacher*, and the *Generator*. Starting from only 100 seed questions, our approach achieves significant improvements without relying on massive external datasets. 9 |
10 | 11 | 12 | 13 |
14 | 15 | 16 |
17 | Socratic-Zero Framework 18 |
19 | 20 | --- 21 | 22 |

23 | Socratic-Zero Pipeline 24 |
25 | The Socratic-Zero Framework Pipeline 26 |

27 | 28 | ## Key Results 29 | 30 |
31 | Performance Comparison 32 |
33 | Performance comparison across mathematical reasoning benchmarks 34 |
35 | 36 | - **Socratic-Solver-8B**: +20.2 percentage points average improvement across seven mathematical reasoning benchmarks 37 | - **Socratic-Generator-32B**: Produces synthetic data enabling student models to outperform commercial LLMs 38 | - **Cross-Architecture**: Consistent improvements on Qwen3 and GLM4 model families 39 | 40 | ## Installation 41 | 42 | ### Prerequisites 43 | 44 | ```bash 45 | git clone https://github.com/Frostlinx/Socratic-Zero.git 46 | cd Socratic-Zero 47 | pip install -r requirements.txt 48 | ``` 49 | 50 | ### Environment Setup 51 | 52 | ```bash 53 | cp .env.example .env 54 | ``` 55 | 56 | Edit `.env` with your configuration: 57 | 58 | ```bash 59 | # Model Paths 60 | SOLVER_MODEL_PATH=/path/to/solver/model 61 | GENERATOR_MODEL_PATH=/path/to/generator/model 62 | TEACHER_BASE_URL=http://your-teacher-api-endpoint 63 | 64 | # Training Configuration 65 | WORKSPACE_DIR=/path/to/workspace 66 | QUESTIONS_FILE=/path/to/seed_questions.json 67 | TRL_NUM_PROCESSES=8 68 | 69 | # GPU Configuration 70 | PHYSICAL_SOLVER_GPU=4 71 | PHYSICAL_TRAINING_GPUS=0,1,2,3,4,5,6,7 72 | ``` 73 | 74 | ## Quick Start 75 | 76 | **Current Execution Workflow** 77 | 78 | The Socratic-Zero framework is composed of several independent, robust modules (e.g., Solver, Teacher, Generator). While each component is fully functional on its own, the process of orchestrating the end-to-end training pipeline currently requires significant manual intervention. 79 | 80 | This involves manually triggering each stage and frequently adjusting configuration paths between rounds. This high operational overhead is a known limitation we are actively working to eliminate. 81 | 82 | **Future Automation** 83 | 84 | We are prioritizing the development of a unified launcher to fully automate the training workflow. A seamless, single-command execution is planned for an upcoming release. We appreciate your patience. 
85 | 86 | 87 | ## Training Configuration 88 | 89 | ### Default Settings 90 | ```python 91 | { 92 | "max_rounds": 5, # Total training rounds 93 | "save_rounds": [3, 4, 5], # Checkpoint save rounds 94 | "attempts_per_question": 8, # Solution attempts per question 95 | "training_framework": "TRL_DPO", # Training framework 96 | "trl_mixed_precision": "bf16" # Mixed precision training 97 | } 98 | ``` 99 | 100 | 101 | 102 | ## Evaluation 103 | 104 | ### Solver Evaluation 105 | ```bash 106 | python scripts/evaluate_mean_at_k.py \ 107 | --model_path ./models/socratic-solver-8b \ 108 | --benchmarks AMC23,AIME24,AIME25,MATH-500,GSM8K,Minerva,Olympiad \ 109 | --num_samples 32 \ 110 | --temperature 0.7 111 | ``` 112 | 113 | ### Generator Quality Assessment 114 | ```bash 115 | python src/evaluation/evaluator.py \ 116 | --generator_model ./models/socratic-generator-32b \ 117 | --test_seeds 1000 \ 118 | --output_dir ./evaluation_results 119 | ``` 120 | 121 | ## Results 122 | 123 | ### Solver Performance 124 | 125 | | Model | AMC23 | AIME24 | AIME25 | MATH-500 | GSM8K | Minerva | Olympiad | Average | 126 | |-------|-------|--------|--------|----------|-------|---------|----------|---------| 127 | | Baseline | 45.8% | 12.3% | 11.4% | 62.7% | 74.6% | 41.9% | 35.9% | 40.7% | 128 | | **Socratic-Zero** | **63.7%** | **28.4%** | **24.6%** | **81.2%** | **87.3%** | **52.4%** | **55.1%** | **56.1%** | 129 | | **Improvement** | **+17.9** | **+16.1** | **+13.2** | **+18.5** | **+12.7** | **+10.5** | **+19.2** | **+15.4** | 130 | 131 | ### Generator Downstream Effectiveness 132 | 133 | | Generator | AIME-24 | AIME-25 | AMC-23 | GSM8K | MATH-500 | Minerva | Olympiad | Average | 134 | |-----------|---------|---------|--------|-------|----------|---------|----------|---------| 135 | | Qwen3-32B | 9.2% | 10.0% | 44.4% | 75.7% | 55.7% | 15.1% | 24.5% | 34.97% | 136 | | Qwen3-235B-A22B | 12.5% | 12.5% | 47.5% | 76.1% | 57.8% | 16.4% | 23.6% | 37.13% | 137 | | Gemini-2.5-Pro | 10.0% | 15.0% | 46.9% 
| 78.1% | 57.2% | 16.0% | 25.4% | 37.20% | 138 | | GPT5-global | 12.5% | 13.3% | 45.0% | 76.8% | 56.6% | 15.5% | 25.9% | 36.62% | 139 | | Claude-4.1-Opus | 13.3% | 13.8% | 46.5% | 77.3% | 57.5% | 16.7% | 24.3% | 37.63% | 140 | | **Socratic-Generator-32B** | **12.5%** | **13.3%** | **48.1%** | **77.6%** | **57.8%** | **18.4%** | **24.6%** | **37.72%** | 141 | 142 | ## Troubleshooting 143 | 144 | ### Common Issues 145 | 146 | **Environment Setup** 147 | ```bash 148 | # Check Python version 149 | python --version # Should be 3.10.13 150 | 151 | # Verify GPU availability 152 | python -c "import torch; print(torch.cuda.is_available())" 153 | 154 | # Check model paths 155 | ls $SOLVER_MODEL_PATH 156 | ``` 157 | 158 | **Training Issues** 159 | ```bash 160 | # Resume interrupted training 161 | python scripts/auto_trainer.py --resume 162 | 163 | # Check training status 164 | python tools/utils/status_checker.py 165 | 166 | # Clear GPU memory 167 | python -c "import torch; torch.cuda.empty_cache()" 168 | ``` 169 | 170 | **Memory Management** 171 | ```bash 172 | # Monitor GPU usage 173 | nvidia-smi 174 | 175 | # Check disk space 176 | df -h $WORKSPACE_DIR 177 | ``` 178 | 179 | ### Log Files 180 | 181 | - **Training Logs**: `{WORKSPACE_DIR}/logs/` 182 | - **Checkpoints**: `{WORKSPACE_DIR}/checkpoints/` 183 | - **Training State**: `{WORKSPACE_DIR}/training_state.json` 184 | - **Results**: `{WORKSPACE_DIR}/training_results/` 185 | 186 | ## Hardware Requirements 187 | 188 | - **Training**: 8×NVIDIA H20 GPUs (96GB HBM3 each) 189 | - **Teacher Inference**: 16×AMD MI308X GPUs (192GB HBM3 each) 190 | - **Storage**: ~1TB for training data and checkpoints 191 | - **Memory**: ~768GB total training memory 192 | 193 | ## Citation 194 | 195 | ```bibtex 196 | @misc{wang2025socraticzerobootstrappingreasoning, 197 | title={Socratic-Zero : Bootstrapping Reasoning via Data-Free Agent Co-evolution}, 198 | author={Shaobo Wang and Zhengbo Jiao and Zifan Zhang and Yilang Peng and Xu Ze and Boyu 
# README.md citation, continued from the previous line of this dump:
#   ... Yang and Wei Wang and Hu Wei and Linfeng Zhang},
#   year={2025},
#   eprint={2509.24726},
#   archivePrefix={arXiv},
#   primaryClass={cs.CL},
#   url={https://arxiv.org/abs/2509.24726},
# }

# --------------------------------------------------------------------------------
# /core/state_manager.py
# --------------------------------------------------------------------------------
"""
State Manager

Manages training state, progress, and checkpoint data.
"""

import os
import json
import logging
from pathlib import Path
from typing import Dict, Any, Optional, List
import datetime

logger = logging.getLogger(__name__)


class StateManager:
    """
    Training state manager for checkpoint and progress tracking

    All data is stored as JSON files under the workspace directory:
    ``training_state.json``, ``training_config.json``,
    ``round_<n>_progress.json`` and ``data/round_<n>_<type>.json``.
    """

    def __init__(self, workspace_dir: Optional[str] = None):
        """
        Initialize state manager

        Creates the workspace directory if it does not exist yet.

        Args:
            workspace_dir: Workspace directory path. Falls back to the
                ``WORKSPACE_DIR`` environment variable, then to
                ``/workspace/prosetting``.
        """
        self.workspace_dir = Path(workspace_dir or os.getenv("WORKSPACE_DIR", "/workspace/prosetting"))
        self.workspace_dir.mkdir(parents=True, exist_ok=True)

        self.state_file = self.workspace_dir / "training_state.json"
        self.config_file = self.workspace_dir / "training_config.json"

    @staticmethod
    def _read_json(path: Path) -> Any:
        """Parse and return the JSON document stored at ``path``."""
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)

    @staticmethod
    def _write_json(path: Path, payload: Any) -> None:
        """Write ``payload`` as JSON to ``path`` without clobbering it on failure.

        The document is serialized to a sibling ``.tmp`` file and then moved
        over the target, so a failed dump (e.g. a non-serializable value)
        leaves the previous file intact instead of truncated.
        """
        tmp_path = path.with_name(path.name + ".tmp")
        try:
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(payload, f, indent=2, ensure_ascii=False)
            os.replace(tmp_path, path)
        except Exception:
            tmp_path.unlink(missing_ok=True)
            raise

    def load_training_state(self) -> Optional[Dict[str, Any]]:
        """
        Load training state from file

        Returns:
            Training state dictionary or None if not found or unreadable
        """
        if not self.state_file.exists():
            return None

        try:
            state = self._read_json(self.state_file)
        except Exception as e:
            # Deliberate best-effort: a bad state file is reported, not raised.
            logger.error(f"Failed to load training state: {e}")
            return None

        logger.info("Training state loaded successfully")
        return state

    def save_training_state(self, state: Dict[str, Any]):
        """
        Save training state to file

        A ``last_updated`` ISO timestamp is added to the stored document.
        The caller's dictionary is not modified. Failures are logged, not
        raised.

        Args:
            state: Training state dictionary
        """
        try:
            payload = dict(state)  # copy so the caller's dict stays untouched
            payload["last_updated"] = datetime.datetime.now().isoformat()
            self._write_json(self.state_file, payload)
            logger.info("Training state saved successfully")
        except Exception as e:
            logger.error(f"Failed to save training state: {e}")

    def load_training_config(self) -> Optional[Dict[str, Any]]:
        """
        Load training configuration

        Returns:
            Training configuration dictionary or None if not found or unreadable
        """
        if not self.config_file.exists():
            return None

        try:
            config = self._read_json(self.config_file)
        except Exception as e:
            logger.error(f"Failed to load training configuration: {e}")
            return None

        logger.info("Training configuration loaded successfully")
        return config

    def save_training_config(self, config: Dict[str, Any]):
        """
        Save training configuration

        Failures are logged, not raised.

        Args:
            config: Training configuration dictionary
        """
        try:
            self._write_json(self.config_file, config)
            logger.info("Training configuration saved successfully")
        except Exception as e:
            logger.error(f"Failed to save training configuration: {e}")

    def get_current_round(self) -> int:
        """
        Get current training round

        Returns:
            Current round number (1-based); 1 when no state is stored
        """
        state = self.load_training_state()
        if state:
            return state.get("current_round", 1)
        return 1

    def get_completed_rounds(self) -> List[int]:
        """
        Get list of completed rounds

        Returns:
            List of completed round numbers; empty when no state is stored
        """
        state = self.load_training_state()
        if state:
            return state.get("completed_rounds", [])
        return []

    def mark_round_completed(self, round_num: int, success: bool = True):
        """
        Mark a round as completed

        The round is appended to ``completed_rounds`` and recorded as
        ``last_completed_round`` only when ``success`` is True.
        ``current_round`` is advanced to ``round_num + 1`` in both cases.

        Args:
            round_num: Round number to mark as completed
            success: Whether the round completed successfully
        """
        state = self.load_training_state() or {}

        completed_rounds = state.get("completed_rounds", [])
        if success and round_num not in completed_rounds:
            completed_rounds.append(round_num)

        # NOTE(review): current_round moves on even when success is False, so
        # get_round_status() will later report a failed round as "completed"
        # (round_num < current_round). Confirm with the round controller
        # whether failed rounds are meant to be skipped.
        state.update({
            "completed_rounds": completed_rounds,
            "current_round": round_num + 1,
            "last_completed_round": round_num if success else state.get("last_completed_round", 0)
        })

        self.save_training_state(state)

    def get_round_status(self, round_num: int) -> Dict[str, Any]:
        """
        Get status of a specific round

        Args:
            round_num: Round number

        Returns:
            Dictionary with round status information. ``status`` is one of
            "completed", "in_progress" or "pending"; ``completed_stages`` is
            always empty because per-stage tracking is not implemented.
        """
        completed_rounds = self.get_completed_rounds()
        current_round = self.get_current_round()

        if round_num in completed_rounds or round_num < current_round:
            status, fully_completed = "completed", True
        elif round_num == current_round:
            status, fully_completed = "in_progress", False
        else:
            status, fully_completed = "pending", False

        return {
            "round_num": round_num,
            "status": status,
            "fully_completed": fully_completed,
            "completed_stages": [],
            "next_stage": "data_collection" if status == "pending" else None
        }

    def is_stage_completed(self, round_num: int, stage: str) -> bool:
        """
        Check if a specific stage is completed for a round

        Args:
            round_num: Round number
            stage: Stage name

        Returns:
            True if stage is completed (currently always False)
        """
        # Placeholder implementation
        # In real implementation, would check detailed stage progress
        return False

    def save_round_progress(self, round_num: int, stage: str, data: Dict[str, Any]):
        """
        Save progress for a specific round and stage

        Existing stages in ``round_<n>_progress.json`` are kept. A progress
        file that is not valid JSON, or that lacks a ``stages`` mapping, is
        replaced by a fresh document instead of aborting the save.

        Args:
            round_num: Round number
            stage: Stage name
            data: Progress data
        """
        progress_file = self.workspace_dir / f"round_{round_num}_progress.json"

        # Load existing progress
        progress = None
        if progress_file.exists():
            try:
                progress = self._read_json(progress_file)
            except ValueError as e:  # json.JSONDecodeError is a ValueError
                logger.warning(f"Ignoring unreadable progress file {progress_file}: {e}")

        if not isinstance(progress, dict):
            progress = {"round": round_num, "stages": {}}
        if not isinstance(progress.get("stages"), dict):
            progress["stages"] = {}

        # Update stage progress
        progress["stages"][stage] = data
        progress["last_updated"] = datetime.datetime.now().isoformat()

        # Save progress
        self._write_json(progress_file, progress)

        logger.info(f"Saved progress for round {round_num}, stage {stage}")

    def load_stage_data(self, round_num: int, stage: str) -> Optional[Dict[str, Any]]:
        """
        Load data for a specific round and stage

        Args:
            round_num: Round number
            stage: Stage name

        Returns:
            Stage data or None if not found or unreadable
        """
        progress_file = self.workspace_dir / f"round_{round_num}_progress.json"

        if not progress_file.exists():
            return None

        try:
            progress = self._read_json(progress_file)
            return progress.get("stages", {}).get(stage)
        except Exception as e:
            logger.error(f"Failed to load stage data: {e}")
            return None

    def save_round_data(self, round_num: int, data_type: str, data: Dict[str, Any]):
        """
        Save round-specific data

        Written to ``data/round_<n>_<data_type>.json`` inside the workspace.
        Unlike the state/config savers, errors propagate to the caller.

        Args:
            round_num: Round number
            data_type: Type of data
            data: Data to save
        """
        data_dir = self.workspace_dir / "data"
        data_dir.mkdir(parents=True, exist_ok=True)

        self._write_json(data_dir / f"round_{round_num}_{data_type}.json", data)

        logger.info(f"Saved round {round_num} {data_type} data")

    def load_round_data(self, round_num: int, data_type: str) -> Optional[Dict[str, Any]]:
        """
        Load round-specific data

        Args:
            round_num: Round number
            data_type: Type of data

        Returns:
            Round data or None if not found or unreadable
        """
        data_file = self.workspace_dir / "data" / f"round_{round_num}_{data_type}.json"

        if not data_file.exists():
            return None

        try:
            return self._read_json(data_file)
        except Exception as e:
            logger.error(f"Failed to load round data: {e}")
            return None

    def get_workspace_summary(self) -> Dict[str, Any]:
        """
        Get workspace summary information

        Returns:
            Dictionary with workspace summary
        """
        data_dir = self.workspace_dir / "data"
        total_data_files = sum(1 for _ in data_dir.glob("*.json")) if data_dir.exists() else 0

        return {
            "workspace_dir": str(self.workspace_dir),
            "state_exists": self.state_file.exists(),
            "config_exists": self.config_file.exists(),
            "total_data_files": total_data_files,
            "last_updated": datetime.datetime.now().isoformat()
        }


# --------------------------------------------------------------------------------
# /tools/utils/status_checker.py
# --------------------------------------------------------------------------------
#!/usr/bin/env python3
"""
ProSetting Status Checker Tool
Used to check current system status, configuration and data integrity
"""

import os
import sys
import json
import logging
from pathlib import Path
from typing import Dict, Any, List, Optional
import datetime

# Setup project paths
project_root = Path(__file__).parent.parent.parent.absolute()
prosetting_root = project_root / "ProSetting"

if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))
if str(prosetting_root) not in sys.path:
    sys.path.insert(0, str(prosetting_root))

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


class ProSettingStatusChecker:
    """ProSetting system status checker

    Each ``check_*`` method prints its findings, stores them under its own
    key in ``self.status_report`` and returns them. Entries carrying a
    ``"status"`` of "✅" count as passed in the overall score.
    """

    def __init__(self, workspace_dir: Optional[str] = None):
        """
        Args:
            workspace_dir: Workspace to inspect. Falls back to the
                ``WORKSPACE_DIR`` environment variable, then to
                ``/tmp/prosetting_workspace``.
        """
        self.workspace_dir = workspace_dir or os.getenv("WORKSPACE_DIR", "/tmp/prosetting_workspace")
        self.status_report = {}

    def check_environment(self):
        """Check environment configuration

        Returns:
            Mapping of each key environment variable to its value and status
        """
        print("🔧 ========== Environment Configuration Check ==========")

        env_status = {}

        # Check key environment variables
        key_env_vars = (
            "SOLVER_MODEL_PATH",
            "QUESTIONS_FILE",
            "WORKSPACE_DIR",
            "TOTAL_ROUNDS",
            "SAVE_ROUNDS",
            "TEACHER_BASE_URL",
            "TRL_NUM_PROCESSES",
            "TRL_MIXED_PRECISION",
        )

        for var in key_env_vars:
            value = os.getenv(var)
            if value:
                env_status[var] = {"value": value, "status": "✅"}
                print(f" ✅ {var}: {value}")
            else:
                # Unset and empty variables are both reported as missing.
                env_status[var] = {"value": None, "status": "❌"}
                print(f" ❌ {var}: Not set")

        self.status_report["environment"] = env_status
        return env_status

    def check_file_paths(self):
        """Check file paths

        Verifies the solver model path and the questions file (only when
        their environment variables are set) and the workspace directory.
        """
        print("\n📁 ========== File Path Check ==========")

        file_status = {}

        # Check model path
        solver_model_path = os.getenv("SOLVER_MODEL_PATH")
        if solver_model_path:
            if Path(solver_model_path).exists():
                file_status["solver_model"] = {"path": solver_model_path, "exists": True, "status": "✅"}
                print(f" ✅ Solver model: {solver_model_path}")
            else:
                file_status["solver_model"] = {"path": solver_model_path, "exists": False, "status": "❌"}
                print(f" ❌ Solver model not found: {solver_model_path}")

        # Check questions file
        questions_file = os.getenv("QUESTIONS_FILE")
        if questions_file:
            if Path(questions_file).exists():
                try:
                    with open(questions_file, 'r', encoding='utf-8') as f:
                        questions_data = json.load(f)
                    question_count = len(questions_data)
                    file_status["questions_file"] = {
                        "path": questions_file,
                        "exists": True,
                        "count": question_count,
                        "status": "✅"
                    }
                    print(f" ✅ Questions file: {questions_file} ({question_count} questions)")
                except Exception as e:
                    # The file exists but cannot be read/parsed: warn, don't fail.
                    file_status["questions_file"] = {
                        "path": questions_file,
                        "exists": True,
                        "error": str(e),
                        "status": "⚠️"
                    }
                    print(f" ⚠️ Questions file format error: {e}")
            else:
                file_status["questions_file"] = {"path": questions_file, "exists": False, "status": "❌"}
                print(f" ❌ Questions file not found: {questions_file}")

        # Check workspace
        if Path(self.workspace_dir).exists():
            file_status["workspace"] = {"path": self.workspace_dir, "exists": True, "status": "✅"}
            print(f" ✅ Workspace: {self.workspace_dir}")
        else:
            file_status["workspace"] = {"path": self.workspace_dir, "exists": False, "status": "❌"}
            print(f" ❌ Workspace not found: {self.workspace_dir}")

        self.status_report["file_paths"] = file_status
        return file_status

    def check_training_state(self):
        """Check training state

        Loads the stored configuration and state through ``StateManager``
        and lists the status of up to the first five rounds. Any failure
        (including a failed import of ``core``) is recorded under "error".
        """
        print("\n📊 ========== Training Status Check ==========")

        training_status = {}

        try:
            from core import StateManager

            # NOTE: StateManager creates the workspace directory if missing.
            state_manager = StateManager(self.workspace_dir)

            # Check training configuration
            config = state_manager.load_training_config()
            if config:
                training_status["config"] = {"exists": True, "data": config, "status": "✅"}
                print(f" ✅ Training configuration loaded")
                print(f" - Total rounds: {config.get('max_rounds', 'N/A')}")
                print(f" - Save rounds: {config.get('save_rounds', 'N/A')}")
                print(f" - Training framework: {config.get('training_framework', 'N/A')}")
            else:
                training_status["config"] = {"exists": False, "status": "❌"}
                print(f" ❌ Training configuration not found")

            # Check training state
            state = state_manager.load_training_state()
            if state:
                current_round = state_manager.get_current_round()
                completed_rounds = state_manager.get_completed_rounds()

                training_status["state"] = {
                    "exists": True,
                    "current_round": current_round,
                    "completed_rounds": completed_rounds,
                    "data": state,
                    "status": "✅"
                }

                print(f" ✅ Training state loaded")
                print(f" - Current round: {current_round}")
                print(f" - Completed rounds: {completed_rounds}")
                print(f" - Last updated: {state.get('last_updated', 'N/A')}")
            else:
                training_status["state"] = {"exists": False, "status": "❌"}
                print(f" ❌ Training state not found")

            # Check round detailed status
            if config:
                # Fall back to TOTAL_ROUNDS when the key is absent or null.
                max_rounds = config.get("max_rounds")
                if max_rounds is None:
                    max_rounds = int(os.getenv("TOTAL_ROUNDS", "10"))
                round_details = {}

                for round_num in range(1, min(max_rounds + 1, 6)):  # Check at most 5 rounds
                    round_status = state_manager.get_round_status(round_num)
                    round_details[f"round_{round_num}"] = round_status

                    if round_status["fully_completed"]:
                        status_emoji = "✅"
                    elif round_status["status"] == "in_progress":
                        status_emoji = "🔄"
                    else:
                        status_emoji = "⏸️"

                    print(f" {status_emoji} Round {round_num}: {round_status['status']}")
                    if round_status["completed_stages"]:
                        print(f" Completed: {', '.join(round_status['completed_stages'])}")
                    if round_status["next_stage"]:
                        print(f" Next step: {round_status['next_stage']}")

                training_status["round_details"] = round_details

        except Exception as e:
            training_status["error"] = str(e)
            print(f" ❌ Training status check failed: {e}")

        self.status_report["training_state"] = training_status
        return training_status

    def check_data_integrity(self):
        """Check data integrity

        Reports which expected workspace sub-directories exist, lists the
        ``data/round_*_*.json`` files and checks each ``datasets/round_*``
        directory for its train/validation parquet files.
        """
        print("\n💾 ========== Data Integrity Check ==========")

        data_status = {}

        workspace_path = Path(self.workspace_dir)

        if not workspace_path.exists():
            data_status["workspace_missing"] = True
            print(f" ❌ Workspace not found: {self.workspace_dir}")
            self.status_report["data_integrity"] = data_status
            return data_status

        # Check data directory structure
        expected_dirs = ("data", "datasets", "training_data", "checkpoints", "logs", "results")

        for dir_name in expected_dirs:
            dir_path = workspace_path / dir_name
            if dir_path.exists():
                file_count = sum(1 for _ in dir_path.glob("*"))
                data_status[dir_name] = {"exists": True, "file_count": file_count, "status": "✅"}
                print(f" ✅ {dir_name} directory: {file_count} files")
            else:
                data_status[dir_name] = {"exists": False, "status": "⚠️"}
                print(f" ⚠️ {dir_name} directory not found")

        # Check round data files
        data_dir = workspace_path / "data"
        if data_dir.exists():
            round_files = {}

            for round_file in data_dir.glob("round_*_*.json"):
                file_stat = round_file.stat()  # one stat call per file
                round_files[round_file.name] = {
                    "path": str(round_file),
                    "size": file_stat.st_size,
                    "modified": datetime.datetime.fromtimestamp(file_stat.st_mtime).isoformat()
                }
                print(f" 📄 {round_file.name}: {file_stat.st_size} bytes")

            data_status["round_files"] = round_files

        # Check TRL datasets
        datasets_dir = workspace_path / "datasets"
        if datasets_dir.exists():
            dataset_rounds = {}

            for round_dir in datasets_dir.glob("round_*"):
                if not round_dir.is_dir():
                    continue

                train_exists = (round_dir / "train.parquet").exists()
                val_exists = (round_dir / "validation.parquet").exists()
                info_exists = (round_dir / "dataset_info.json").exists()

                # dataset_info.json is recorded but not required for a pass.
                status = "✅" if train_exists and val_exists else "⚠️"

                dataset_rounds[round_dir.name] = {
                    "train_exists": train_exists,
                    "validation_exists": val_exists,
                    "info_exists": info_exists,
                    "status": status
                }
                print(f" {status} {round_dir.name}: train={train_exists}, val={val_exists}")

            data_status["datasets"] = dataset_rounds

        self.status_report["data_integrity"] = data_status
        return data_status

    def check_module_imports(self):
        """Check module imports

        Tries to import the project packages and the third-party training
        dependencies and records which ones fail.
        """
        print("\n📦 ========== Module Import Check ==========")

        import_status = {}

        # Check core modules
        core_modules = [
            ("collectors", "Data collection modules"),
            ("processors", "Data processing modules"),
            ("datasets", "Dataset modules"),
            ("trainers", "Trainer modules"),
            ("managers", "Manager modules"),
            ("core", "Core modules")
        ]

        # Check TRL related modules
        trl_modules = [
            ("trl", "TRL training framework"),
            ("accelerate", "Accelerate distributed training"),
            ("pandas", "Data processing"),
            ("pyarrow", "Parquet support")
        ]

        for module_name, description in core_modules + trl_modules:
            try:
                __import__(module_name)
                import_status[module_name] = {"status": "✅", "description": description}
                print(f" ✅ {description}: {module_name}")
            except ImportError as e:
                import_status[module_name] = {"status": "❌", "error": str(e), "description": description}
                print(f" ❌ {description}: {e}")

        self.status_report["module_imports"] = import_status
        return import_status

    def generate_report(self):
        """Generate complete status report

        Runs every check, writes ``status_report.json`` into the workspace
        (creating the directory if needed) and prints the summary.

        Returns:
            The report dictionary that was written to disk
        """
        print("\n📋 ========== Generate Status Report ==========")

        # Run all checks
        self.check_environment()
        self.check_file_paths()
        self.check_training_state()
        self.check_data_integrity()
        self.check_module_imports()

        # Generate report summary
        report_summary = {
            "check_time": datetime.datetime.now().isoformat(),
            "workspace_dir": self.workspace_dir,
            "overall_status": self._calculate_overall_status(),
            "details": self.status_report
        }

        # Save report to file
        report_file = Path(self.workspace_dir) / "status_report.json"
        report_file.parent.mkdir(parents=True, exist_ok=True)

        with open(report_file, 'w', encoding='utf-8') as f:
            json.dump(report_summary, f, ensure_ascii=False, indent=2)

        print(f"📄 Status report saved: {report_file}")

        # Print summary
        self._print_summary()

        return report_summary

    def _calculate_overall_status(self):
        """Calculate overall status

        Counts the entries one level below each category that carry a
        ``"status"`` key; an entry passes when its status is "✅".

        Returns:
            Dict with ``status``, ``score``, ``passed`` and ``total``. With
            no countable entries the status is "unknown" and all numbers
            are 0, so the summary can always be formatted.
        """
        total_checks = 0
        passed_checks = 0

        for details in self.status_report.values():
            if not isinstance(details, dict):
                continue
            for entry in details.values():
                if isinstance(entry, dict) and "status" in entry:
                    total_checks += 1
                    if entry["status"] == "✅":
                        passed_checks += 1

        if total_checks == 0:
            return {"status": "unknown", "score": 0, "passed": 0, "total": 0}

        score = passed_checks / total_checks

        if score >= 0.9:
            status = "excellent"
        elif score >= 0.7:
            status = "good"
        elif score >= 0.5:
            status = "fair"
        else:
            status = "poor"

        return {
            "status": status,
            "score": score,
            "passed": passed_checks,
            "total": total_checks
        }

    def _print_summary(self):
        """Print status summary"""
        overall = self._calculate_overall_status()

        print(f"\n🎯 ========== Status Summary ==========")
        print(f"Overall status: {overall['status'].upper()}")
        print(f"Check pass rate: {overall['passed']}/{overall['total']} ({overall['score']:.1%})")

        if overall['score'] >= 0.9:
            print("🎉 System status excellent, ready for training!")
        elif overall['score'] >= 0.7:
            print("✅ System status good, recommend checking warning items")
        elif overall['score'] >= 0.5:
            print("⚠️ System status fair, recommend fixing failed items before running")
        else:
            print("❌ System status poor, please fix critical issues before use")


def main():
    """Main function

    ``--quick`` runs only the environment, file-path and import checks and
    writes no report file; otherwise the complete report is generated.
    """
    import argparse

    parser = argparse.ArgumentParser(description="ProSetting system status checker tool")
    parser.add_argument("--workspace", "-w", help="Specify workspace directory")
    parser.add_argument("--quick", "-q", action="store_true", help="Quick check (skip detailed data checks)")

    args = parser.parse_args()

    checker = ProSettingStatusChecker(args.workspace)

    if args.quick:
        print("🚀 Quick status check...")
        checker.check_environment()
        checker.check_file_paths()
        checker.check_module_imports()
    else:
        print("🔍 Complete status check...")
        checker.generate_report()


if __name__ == "__main__":
    main()