├── data └── README.md ├── assets ├── table1.png ├── table2.png ├── table3.png ├── table4.png ├── figure1.png ├── figure2.png ├── figure3.png └── overall.png ├── evaluate ├── BBH │ └── __init__.py ├── AIME │ └── __init__.py ├── ARCC │ └── __init__.py ├── GPQA │ ├── __init__.py │ └── gpqa.py ├── __init__.py ├── FinQA │ └── __init__.py ├── MedQA │ ├── __init__.py │ └── medqa.py ├── MATH500 │ └── __init__.py ├── MBPP │ ├── __init__.py │ ├── utils.py │ ├── mbpp.py │ └── execution.py ├── MELD │ ├── __init__.py │ └── meld.py ├── MMLUPro │ └── __init__.py ├── KORBench │ └── __init__.py ├── SimpleQA │ ├── __init__.py │ └── prompts.py ├── ArenaHard │ ├── __init__.py │ └── arenahard.py ├── MATHBench │ └── __init__.py ├── TruthfulQA │ ├── __init__.py │ └── truthfulqa.py ├── Winogrande │ └── __init__.py ├── BrainTeaser │ ├── __init__.py │ └── brainteaser.py ├── EmoryNLP │ ├── __init__.py │ └── emorynlp.py ├── K_and_K │ ├── __init__.py │ └── prompt.py ├── HumanEval │ ├── __init__.py │ ├── utils.py │ ├── humaneval.py │ └── execution.py ├── LiveMathBench │ └── __init__.py ├── DailyDialog │ └── __init__.py ├── StudentEval │ ├── __init__.py │ ├── utils.py │ ├── studenteval.py │ └── execution.py ├── LiveCodeBench │ ├── __init__.py │ └── compute_code_generation_metrics.py ├── base_evaluator.py └── factory.py ├── .gitattributes ├── core ├── rank │ ├── ranking_centers_split_k64_m22_7b.npy │ ├── map_m22.json │ ├── ranking_split_k64_m22_7b.json │ └── ranking_embedding_normalizer_gte_qwen2-7b-instruct.joblib ├── inference │ ├── base_generator.py │ ├── __init__.py │ ├── selfconsistency_generator.py │ ├── factory.py │ ├── direct_generator.py │ ├── modelswitch_generator.py │ ├── fastslow_generator.py │ └── slowfast_generator.py ├── routing │ ├── __init__.py │ ├── base_router.py │ ├── straight_router.py │ ├── moa_router.py │ ├── random_router.py │ ├── factory.py │ ├── elo_router.py │ └── rank_router.py ├── ablation │ ├── utils.py │ └── embedding_cache.py └── experts │ └── load_experts.py ├── requirements.txt ├── config ├── experiment_var_template.yaml ├── experts_template.yaml └── config_loader.py ├── scripts └── deploy_template.sh ├── .gitignore ├── main.py ├── app.py └── diversity └── embedding_cache.py /data/README.md: -------------------------------------------------------------------------------- 1 | Data download link: https://huggingface.co/datasets/Estwld/Avengers -------------------------------------------------------------------------------- /assets/table1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhangYiqun018/Avengers/HEAD/assets/table1.png -------------------------------------------------------------------------------- /assets/table2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhangYiqun018/Avengers/HEAD/assets/table2.png -------------------------------------------------------------------------------- /assets/table3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhangYiqun018/Avengers/HEAD/assets/table3.png -------------------------------------------------------------------------------- /assets/table4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhangYiqun018/Avengers/HEAD/assets/table4.png -------------------------------------------------------------------------------- /assets/figure1.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhangYiqun018/Avengers/HEAD/assets/figure1.png -------------------------------------------------------------------------------- /assets/figure2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhangYiqun018/Avengers/HEAD/assets/figure2.png -------------------------------------------------------------------------------- /assets/figure3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhangYiqun018/Avengers/HEAD/assets/figure3.png -------------------------------------------------------------------------------- /assets/overall.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhangYiqun018/Avengers/HEAD/assets/overall.png -------------------------------------------------------------------------------- /evaluate/BBH/__init__.py: -------------------------------------------------------------------------------- 1 | from evaluate.BBH.bbh import BBHEvaluator 2 | 3 | __all__ = ["BBHEvaluator"] -------------------------------------------------------------------------------- /evaluate/AIME/__init__.py: -------------------------------------------------------------------------------- 1 | from evaluate.AIME.aime import AIMEEvaluator 2 | 3 | __all__ = ["AIMEEvaluator"] -------------------------------------------------------------------------------- /evaluate/ARCC/__init__.py: -------------------------------------------------------------------------------- 1 | from evaluate.ARCC.arcc import ARCCEvaluator 2 | 3 | __all__ = ["ARCCEvaluator"] -------------------------------------------------------------------------------- /evaluate/GPQA/__init__.py: -------------------------------------------------------------------------------- 1 | from evaluate.GPQA.gpqa import GPQAEvaluator 2 | 3 | __all__ = ["GPQAEvaluator"] -------------------------------------------------------------------------------- /evaluate/__init__.py: -------------------------------------------------------------------------------- 1 | from evaluate.factory import EvaluatorFactory 2 | 3 | __all__ = ['EvaluatorFactory'] -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.json filter=lfs diff=lfs merge=lfs -text 2 | *.jsonl filter=lfs diff=lfs merge=lfs -text 3 | -------------------------------------------------------------------------------- /evaluate/FinQA/__init__.py: -------------------------------------------------------------------------------- 1 | from evaluate.FinQA.finqa import FinQAEvaluator 2 | 3 | __all__ = ["FinQAEvaluator"] -------------------------------------------------------------------------------- /evaluate/MedQA/__init__.py: -------------------------------------------------------------------------------- 1 | from evaluate.MedQA.medqa import MedQAEvaluator 2 | 3 | __all__ = ["MedQAEvaluator"] -------------------------------------------------------------------------------- /evaluate/MATH500/__init__.py: -------------------------------------------------------------------------------- 1 | from evaluate.MATH500.math500 import MATH500Evaluator 2 | 3 | __all__ = ["MATH500Evaluator"] -------------------------------------------------------------------------------- 
/evaluate/MBPP/__init__.py: -------------------------------------------------------------------------------- 1 | from evaluate.MBPP.mbpp import MBPPEvaluator 2 | 3 | __all__ = [ 4 | "MBPPEvaluator" 5 | ] -------------------------------------------------------------------------------- /evaluate/MELD/__init__.py: -------------------------------------------------------------------------------- 1 | from evaluate.MELD.meld import MELDEvaluator 2 | 3 | __all__ = [ 4 | "MELDEvaluator" 5 | ] -------------------------------------------------------------------------------- /evaluate/MMLUPro/__init__.py: -------------------------------------------------------------------------------- 1 | from evaluate.MMLUPro.mmlupro import MMLUProEvaluator 2 | 3 | __all__ = ["MMLUProEvaluator"] -------------------------------------------------------------------------------- /evaluate/KORBench/__init__.py: -------------------------------------------------------------------------------- 1 | from evaluate.KORBench.korbench import KORBenchEvaluator 2 | 3 | __all__ = ["KORBenchEvaluator"] -------------------------------------------------------------------------------- /evaluate/SimpleQA/__init__.py: -------------------------------------------------------------------------------- 1 | from evaluate.SimpleQA.simpleqa import SimpleQAEvaluator 2 | 3 | __all__ = ["SimpleQAEvaluator"] -------------------------------------------------------------------------------- /evaluate/ArenaHard/__init__.py: -------------------------------------------------------------------------------- 1 | from evaluate.ArenaHard.arenahard import ArenaHardEvaluator 2 | 3 | __all__ = ["ArenaHardEvaluator"] -------------------------------------------------------------------------------- /evaluate/MATHBench/__init__.py: -------------------------------------------------------------------------------- 1 | from evaluate.MATHBench.mathbench import MathBenchEvaluator 2 | 3 | __all__ = ["MathBenchEvaluator"] -------------------------------------------------------------------------------- /evaluate/TruthfulQA/__init__.py: -------------------------------------------------------------------------------- 1 | from evaluate.TruthfulQA.truthfulqa import TruthfulQAEvaluator 2 | 3 | __all__ = ["TruthfulQAEvaluator"] -------------------------------------------------------------------------------- /evaluate/Winogrande/__init__.py: -------------------------------------------------------------------------------- 1 | from evaluate.Winogrande.winogrande import WinograndeEvaluator 2 | 3 | __all__ = ["WinograndeEvaluator"] -------------------------------------------------------------------------------- /evaluate/BrainTeaser/__init__.py: -------------------------------------------------------------------------------- 1 | from evaluate.BrainTeaser.brainteaser import BrainTeaserEvaluator 2 | 3 | __all__ = ["BrainTeaserEvaluator"] -------------------------------------------------------------------------------- /evaluate/EmoryNLP/__init__.py: -------------------------------------------------------------------------------- 1 | from evaluate.EmoryNLP.emorynlp import EmoryNLPEvaluator 2 | 3 | __all__ = [ 4 | "EmoryNLPEvaluator" 5 | ] -------------------------------------------------------------------------------- /evaluate/K_and_K/__init__.py: -------------------------------------------------------------------------------- 1 | from evaluate.K_and_K.k_and_k import KnightsAndKnavesEvaluator 2 | 3 | __all__ = ["KnightsAndKnavesEvaluator"] 
-------------------------------------------------------------------------------- /evaluate/HumanEval/__init__.py: -------------------------------------------------------------------------------- 1 | from evaluate.HumanEval.humaneval import HumanEvalEvaluator 2 | 3 | __all__ = [ 4 | "HumanEvalEvaluator" 5 | ] -------------------------------------------------------------------------------- /core/rank/ranking_centers_split_k64_m22_7b.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhangYiqun018/Avengers/HEAD/core/rank/ranking_centers_split_k64_m22_7b.npy -------------------------------------------------------------------------------- /evaluate/LiveMathBench/__init__.py: -------------------------------------------------------------------------------- 1 | from evaluate.LiveMathBench.livemathbench import LiveMathBenchEvaluator 2 | 3 | __all__ = ["LiveMathBenchEvaluator"] -------------------------------------------------------------------------------- /evaluate/DailyDialog/__init__.py: -------------------------------------------------------------------------------- 1 | from evaluate.DailyDialog.dailydialog import DailyDialogEvaluator 2 | 3 | __all__ = [ 4 | "DailyDialogEvaluator" 5 | ] -------------------------------------------------------------------------------- /evaluate/StudentEval/__init__.py: -------------------------------------------------------------------------------- 1 | from evaluate.StudentEval.studenteval import StudentEvalEvaluator 2 | 3 | __all__ = [ 4 | "StudentEvalEvaluator" 5 | ] -------------------------------------------------------------------------------- /core/rank/map_m22.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:8af5f137188e6f43803aece305f39778ece6df8d96d0081fe3f51862d9e63a4d 3 | size 729 4 | -------------------------------------------------------------------------------- /evaluate/LiveCodeBench/__init__.py: -------------------------------------------------------------------------------- 1 | from evaluate.LiveCodeBench.livecodebench import LiveCodeBenchEvaluator 2 | 3 | __all__ = [ 4 | "LiveCodeBenchEvaluator" 5 | ] -------------------------------------------------------------------------------- /core/rank/ranking_split_k64_m22_7b.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:bbcd16a04a6fae927fcaa0633b12aed622e9a13bacc7986a8ae59f3b2d52cc5b 3 | size 232918 4 | -------------------------------------------------------------------------------- /core/rank/ranking_embedding_normalizer_gte_qwen2-7b-instruct.joblib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhangYiqun018/Avengers/HEAD/core/rank/ranking_embedding_normalizer_gte_qwen2-7b-instruct.joblib -------------------------------------------------------------------------------- /core/inference/base_generator.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from dataclasses import dataclass 3 | from typing import List 4 | @dataclass 5 | class GeneratorOutput: 6 | first_output: str 7 | raw_output: List[str] 8 | prompt_tokens: int 9 | completion_tokens: int 10 | 11 | class BaseGenerator(ABC): 12 | @abstractmethod 13 | def generate(self, question: str): 14 | pass 
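15 | 16 | # Descriptive note: concrete generators under core/inference/ implement generate() and return a 17 | # GeneratorOutput, where `first_output` is the single answer consumed downstream, `raw_output` keeps 18 | # every sampled completion (e.g. the n self-consistency samples), and the token fields track usage.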
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Core dependencies 2 | loguru>=0.6.0 3 | PyYAML>=6.0 4 | openai>=1.0.0 5 | pydantic>=2.0.0 6 | 7 | # HTTP and caching 8 | httpx>=0.24.0 9 | hishel>=0.0.24 10 | 11 | # ML and scientific computing 12 | numpy>=1.24.0 13 | torch>=2.0.0 14 | transformers>=4.30.0 15 | scipy>=1.10.0 16 | datasets>=2.12.0 17 | 18 | # Math and symbolic computing 19 | sympy>=1.12 20 | pylatexenc>=2.10 21 | 22 | # Tokenization 23 | tiktoken>=0.4.0 24 | 25 | # Async and retry 26 | tenacity>=8.2.0 27 | 28 | # Progress bars 29 | tqdm>=4.65.0 -------------------------------------------------------------------------------- /config/experiment_var_template.yaml: -------------------------------------------------------------------------------- 1 | base_config: config/experts.yaml 2 | save_dir: /path/to/save/results 3 | variations: 4 | - name: experiment_name_1 5 | params: 6 | experiments.task: task1 7 | router.type: straight 8 | router.straight_router.model: model1 9 | generator.type: direct 10 | 11 | - name: experiment_name_2 12 | params: 13 | experiments.task: task2 14 | router.type: random 15 | router.random_router.models: 16 | - model1 17 | - model2 18 | generator.type: self_consistency -------------------------------------------------------------------------------- /core/routing/__init__.py: -------------------------------------------------------------------------------- 1 | from core.routing.base_router import BaseRouter, RouterOutput 2 | from core.routing.factory import RouterFactory, RouterType 3 | from core.routing.gpt_router import GPTRouter 4 | from core.routing.straight_router import StraightRouter 5 | from core.routing.elo_router import EloRouter 6 | from core.routing.routerdc_router import RouterDC 7 | from core.routing.random_router import RandomRouter 8 | 9 | __all__ = ['GPTRouter', 'BaseRouter', 'StraightRouter', 'RouterFactory', 'RouterType', 'RouterOutput', 'EloRouter', 'RouterDC', 'RandomRouter'] 10 | -------------------------------------------------------------------------------- /core/inference/__init__.py: -------------------------------------------------------------------------------- 1 | from core.inference.base_generator import BaseGenerator, GeneratorOutput 2 | from core.inference.direct_generator import DirectGenerator 3 | from core.inference.factory import GeneratorFactory, GeneratorType 4 | from core.inference.fastslow_generator import FastSlowGenerator 5 | from core.inference.modelswitch_generator import ModelSwitchGenerator 6 | from core.inference.selfconsistency_generator import SelfConsistencyGenerator 7 | 8 | __all__ = [ 9 | "BaseGenerator", 10 | "DirectGenerator", 11 | "SelfConsistencyGenerator", 12 | "ModelSwitchGenerator", 13 | "FastSlowGenerator", 14 | "GeneratorFactory", 15 | "GeneratorType", 16 | "GeneratorOutput" 17 | ] 18 | -------------------------------------------------------------------------------- /core/routing/base_router.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from dataclasses import dataclass 3 | from typing import List 4 | 5 | from core.experts.load_experts import Expert 6 | 7 | @dataclass 8 | class RouterOutput: 9 | normal_experts: List[Expert] 10 | thinking_experts: List[Expert] 11 | 12 | 13 | class BaseRouter(ABC): 14 | def __init__(self, normal_experts: List[Expert], thinking_experts: List[Expert]): 
15 | self.normal_experts = normal_experts 16 | self.thinking_experts = thinking_experts 17 | 18 | @abstractmethod 19 | def route(self, question: str) -> RouterOutput: 20 | return RouterOutput( 21 | normal_experts=[], 22 | thinking_experts=[] 23 | ) -------------------------------------------------------------------------------- /scripts/deploy_template.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Environment variables 3 | export VLLM_LOGGING_LEVEL=DEBUG 4 | export VLLM_ATTENTION_BACKEND=XFORMERS 5 | export TORCH_USE_CUDA_DSA=1 6 | # export VLLM_ALLOW_RUNTIME_LORA_UPDATING=True 7 | # export VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 8 | 9 | # Model path 10 | MODEL=/path/to/your/model 11 | 12 | # Log directory 13 | ROOT="logs" 14 | mkdir -p vllm_logs/$ROOT 15 | 16 | # Port 17 | PORT=8000 18 | 19 | # GPU settings 20 | GPU_COUNTS=$(nvidia-smi --query-gpu=count --format=csv,noheader | wc -l) 21 | 22 | echo "Starting API server on port $PORT..." 23 | 24 | # Common arguments 25 | COMMON_ARGS="--model $MODEL \ 26 | --trust-remote-code \ 27 | --seed 42 \ 28 | --enforce-eager \ 29 | --max-model-len 4096 \ 30 | --gpu-memory-utilization 0.93 \ 31 | --tensor-parallel-size $GPU_COUNTS \ 32 | --served-model-name your-model-name" 33 | 34 | # Optional: extra features 35 | # --enable-prefix-caching \ 36 | # --enable-chunked-prefill \ 37 | # --enable-auto-tool-choice \ 38 | # --tool-call-parser hermes \ 39 | 40 | # Start the server 41 | python -m vllm.entrypoints.openai.api_server \ 42 | $COMMON_ARGS \ 43 | --port $PORT 44 | 45 | echo "API server started on port $PORT" -------------------------------------------------------------------------------- /core/routing/straight_router.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from loguru import logger 4 | 5 | from core.experts.load_experts import Expert 6 | from core.routing.base_router import BaseRouter, RouterOutput 7 | 8 | 9 | class StraightRouter(BaseRouter): 10 | def __init__(self, normal_experts: List[Expert], thinking_experts: List[Expert], router_config: dict): 11 | super().__init__(normal_experts, thinking_experts) 12 | 13 | self.config = router_config['straight_router'] 14 | model_name = self.config['model'] 15 | self.expert = self.find_expert(model_name) 16 | 17 | def find_expert(self, model_name: str) -> Expert: 18 | self.expert = next((expert for expert in self.normal_experts if expert.model_name == model_name), None) 19 | if self.expert is None: 20 | logger.error(f"Expert with model name {model_name} not found") 21 | raise ValueError(f"Expert with model name {model_name} not found") 22 | return self.expert 23 | 24 | def route(self, question: str) -> RouterOutput: 25 | return RouterOutput( 26 | normal_experts=[self.expert], 27 | thinking_experts=self.thinking_experts 28 | ) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.so 6 | .Python 7 | .env 8 | .venv/ 9 | env/ 10 | venv/ 11 | ENV/ 12 | 13 | # IDE 14 | .idea/ 15 | .vscode/ 16 | *.swp 17 | *.swo 18 | 19 | # OS 20 | .DS_Store 21 | Thumbs.db 22 | 23 | # Logs and databases 24 | *.log 25 | nohup.out 26 | *.sqlite3 27 | *.db 28 | 29 | # Jupyter Notebook 30 | .ipynb_checkpoints 31 | *.ipynb 32 | *.ipynbresults/ 33 | 34 | # Project specific 35 | results/ 36 | analysis_results/ 37 | config/temp/ 38 | .cache/ 39 | routerdc/ 40 | train/ 41 | 42 | # Data files 43 |
*.txt 44 | *.xlsx 45 | data/LiveCodeBench/*.jsonl 46 | 47 | # Configuration and scripts 48 | *.yaml 49 | *.yml 50 | *.sh 51 | 52 | # Keep template files 53 | !config/*_template.yaml 54 | !scripts/*_template.sh 55 | diversity/result 56 | diversity/*.csv 57 | 58 | # Ignore all dataset subdirectories 59 | data/AIME/ 60 | data/EmoryNLP/ 61 | data/FinQA/ 62 | data/GPQA/ 63 | data/HumanEval/ 64 | data/K_and_K/ 65 | data/LiveCodeBench/ 66 | data/MATH500/ 67 | data/MBPP/ 68 | data/MMLUPro/ 69 | data/MedQA/ 70 | data/SimpleQA/ 71 | data/arc_c/ 72 | data/bbh/ 73 | data/livemathbench/ 74 | data/mathbench_v1/ 75 | data/winogrande/ 76 | 77 | # Keep README.md 78 | !data/README.md 79 | 80 | core/ablation/legacy 81 | CLAUDE.md 82 | GEMINI.md -------------------------------------------------------------------------------- /core/ablation/utils.py: -------------------------------------------------------------------------------- 1 | EMBED_CONFIG={ 2 | "bge-m3": { 3 | "url": "input your api url", 4 | "api_key": "input your api key", 5 | "model_name": "bge-m3" 6 | }, 7 | "jina-embeddings-v3": { 8 | "url": "input your api url", 9 | "api_key": "input your api key", 10 | "model_name": "jina-embeddings-v3" 11 | }, 12 | "gte-qwen2-7b-instruct": { 13 | "url": "input your api url", 14 | "api_key": "input your api key", 15 | "model_name": "gte-qwen2-7b-instruct" 16 | }, 17 | "gte_Qwen2-1.5B-instruct": { 18 | "url": "input your api url", 19 | "api_key": "input your api key", 20 | "model_name": "gte_Qwen2-1.5B-instruct" 21 | }, 22 | "text-embedding-3-small": { 23 | "url": "input your openai api url", 24 | "api_key": "input your openai api key", 25 | "model_name": "text-embedding-3-small" 26 | }, 27 | "text-embedding-3-large": { 28 | "url": "input your openai api url", 29 | "api_key": "input your openai api key", 30 | "model_name": "text-embedding-3-large" 31 | }, 32 | "qwen3-embedding-0.6b": { 33 | "url": "input your api url", 34 | "api_key": "input your api key", 35 | "model_name": "qwen3-embedding-0.6b" 36 | } 37 | } -------------------------------------------------------------------------------- /config/experts_template.yaml: -------------------------------------------------------------------------------- 1 | experiments: 2 | task: task_name 3 | max_workers: 8 4 | mode: test 5 | 6 | router: 7 | type: straight # options: straight, gpt, random, rank, elo, routerdc, symbolic_moe, moa (see core/routing/factory.py) 8 | gpt_router: 9 | model: model_name 10 | max_router: 2 11 | base_url: your_base_url 12 | api_key: your_api_key 13 | straight_router: 14 | model: model_name 15 | random_router: 16 | max_router: 2 17 | rank_router: 18 | centres_path: "path/to/centres.npy" 19 | rankings_path: "path/to/rankings.json" 20 | mapping_path: "path/to/mapping.json" 21 | normalizer_path: "path/to/normalizer.joblib" 22 | available_models: 23 | - "model1" 24 | - "model2" 25 | top_n: 2 26 | top_k: 2 27 | beta: 6.0 28 | default_rank: 999 29 | embedding_model: embedding_model_name 30 | 31 | generator: 32 | type: direct # options: direct, self_consistency, model_switch, fast_slow, slow_fast, aggregation 33 | direct: 34 | temperature: 0.2 35 | top_p: 1.0 36 | self_consistency: 37 | samples: 10 38 | temperature: 0.7 39 | top_p: 1.0 40 | 41 | experts: 42 | - name: model1_name 43 | base_url: http://your.api.endpoint/v1 44 | api_key: your_api_key 45 | description: "the description of model1" 46 | 47 | - name: model2_name 48 | base_url: http://your.api.endpoint/v1 49 | api_key: your_api_key 50 | description: "the description of model2"
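51 | 52 | # Illustrative additions (a sketch; field names are taken from app.py and core/experts/load_experts.py in this repo): 53 | # app.py reads experiments.use_http_cache and, when caching is enabled, experiments.cache_dir, e.g. 54 | # experiments: 55 | # use_http_cache: false 56 | # cache_dir: .cache/http 57 | # load_experts.py also accepts an optional thinking_experts list with the same fields as experts: 58 | # thinking_experts: 59 | # - name: thinking_model_name 60 | # base_url: http://your.api.endpoint/v1 61 | # api_key: your_api_key 62 | # description: "the description of the thinking model"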
-------------------------------------------------------------------------------- /core/routing/moa_router.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from loguru import logger 4 | 5 | from core.experts.load_experts import Expert 6 | from core.routing.base_router import BaseRouter, RouterOutput 7 | 8 | 9 | class MoARouter(BaseRouter): 10 | def __init__(self, normal_experts: List[Expert], thinking_experts: List[Expert], router_config: dict): 11 | super().__init__(normal_experts, thinking_experts) 12 | self.config = router_config['moa_router'] 13 | self.proposers = self.config.get('proposers') 14 | self.aggregator = self.config.get('aggregator') 15 | 16 | self.experts = self.find_experts() 17 | 18 | def find_experts(self) -> List[Expert]: 19 | experts = [] 20 | for proposer in self.proposers: 21 | for expert in self.normal_experts: 22 | if proposer == expert.model_name: 23 | experts.append(expert) 24 | break 25 | 26 | for expert in self.normal_experts: 27 | if expert.model_name == self.aggregator: 28 | experts.append(expert) 29 | break 30 | 31 | assert len(experts) == len(self.proposers) + 1, f"Number of experts must be equal to {len(self.proposers) + 1}" 32 | return experts 33 | 34 | def route(self, question: str) -> RouterOutput: 35 | return RouterOutput( 36 | normal_experts=self.experts, 37 | thinking_experts=[] 38 | ) -------------------------------------------------------------------------------- /config/config_loader.py: -------------------------------------------------------------------------------- 1 | # config_loader.py 2 | import copy 3 | import os 4 | from pathlib import Path 5 | 6 | import yaml 7 | from loguru import logger 8 | 9 | 10 | def load_config(yaml_path=None): 11 | """ 12 | load config from yaml file 13 | """ 14 | with open(yaml_path, 'r', encoding='utf-8') as f: 15 | config = yaml.safe_load(f) 16 | return config 17 | 18 | def generate_experiment_configs(base_config, experiment_variations): 19 | """ 20 | generate experiment configs from base config and experiment variations 21 | """ 22 | configs = [] 23 | for variation in experiment_variations: 24 | config = copy.deepcopy(base_config) 25 | 26 | for path, value in variation['params'].items(): 27 | # like router.type 28 | parts = path.split('.') 29 | current = config 30 | for part in parts[:-1]: 31 | current = current[part] 32 | current[parts[-1]] = value 33 | 34 | config['experiment_name'] = variation['name'] 35 | configs.append(config) 36 | 37 | return configs 38 | 39 | def save_temp_config(config, output_dir): 40 | os.makedirs(output_dir, exist_ok=True) 41 | config_path = os.path.join(output_dir, f"{config['experiment_name']}.yaml") 42 | 43 | with open(config_path, 'w', encoding='utf-8') as f: 44 | yaml.dump(config, f, default_flow_style=False, allow_unicode=True) 45 | 46 | return config_path 47 | 48 | # test 49 | if __name__ == "__main__": 50 | cfg = load_config() 51 | print(cfg) # print to see -------------------------------------------------------------------------------- /core/routing/random_router.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from loguru import logger 4 | 5 | from core.experts.load_experts import Expert 6 | from core.routing.base_router import BaseRouter, RouterOutput 7 | import random 8 | 9 | class RandomRouter(BaseRouter): 10 | def __init__(self, normal_experts: List[Expert], thinking_experts: List[Expert], router_config: dict): 11 | 
super().__init__(normal_experts, thinking_experts) 12 | self.config = router_config['random_router'] 13 | self.max_router = self.config['max_router'] 14 | available_models = self.config.get("available_models") 15 | self.candidate_models = [] 16 | for expert in normal_experts: 17 | if expert.model_name in available_models: 18 | self.candidate_models.append(expert) 19 | assert len(self.candidate_models) == len(available_models) 20 | if not isinstance(self.max_router, int) or self.max_router <= 0: 21 | raise ValueError(f"max_router must be a positive integer, got {self.max_router}") 22 | 23 | if self.max_router > len(normal_experts): 24 | logger.warning(f"max_router ({self.max_router}) is greater than the number of normal experts ({len(normal_experts)})") 25 | raise ValueError(f"max_router ({self.max_router}) cannot be greater than the number of normal experts ({len(normal_experts)})") 26 | 27 | def route(self, question: str) -> RouterOutput: 28 | """Randomly select a fixed number of experts to route to. 29 | 30 | Args: 31 | question: the input question 32 | 33 | Returns: 34 | RouterOutput: the randomly selected normal_experts together with all thinking_experts 35 | """ 36 | selected_experts = random.sample(self.candidate_models, self.max_router) 37 | return RouterOutput( 38 | normal_experts=selected_experts, 39 | thinking_experts=self.thinking_experts 40 | ) -------------------------------------------------------------------------------- /evaluate/K_and_K/prompt.py: -------------------------------------------------------------------------------- 1 | system_instruction='''Your task is to solve a logical reasoning problem. You are given set of statements from which you must logically deduce the identity of a set of characters. 2 | 3 | You must infer the identity of each character. First, explain your reasoning. At the end of your answer, you must clearly state the identity of each character by following the format: 4 | 5 | CONCLUSION: 6 | (1) ... 7 | (2) ... 8 | (3) ... 9 | ''' 10 | 11 | 12 | system_instruction_no_reason='''Your task is to solve a logical reasoning problem. You are given set of statements from which you must logically deduce the identity of a set of characters. 13 | 14 | You must infer the identity of each character. At the end of your answer, you must clearly state the identity of each character by following the format: 15 | 16 | CONCLUSION: 17 | (1) ... 18 | (2) ... 19 | (3) ... 20 | ''' 21 | 22 | demonstration_2char_no_reason='''### Question: A very special island is inhabited only by knights and knaves. Knights always tell the truth, and knaves always lie. You meet 2 inhabitants: Jack, and Sophia. Jack tells you that Sophia is not a knave. Sophia says that If Jack is a knight then Sophia is a knight. So who is a knight and who is a knave? 23 | ### Answer: 24 | CONCLUSION: 25 | (1) Jack is a knight 26 | (2) Sophia is a knight 27 | ''' 28 | 29 | 30 | 31 | demonstration_2char='''### Question: A very special island is inhabited only by knights and knaves. Knights always tell the truth, and knaves always lie. You meet 2 inhabitants: Ella, and Penelope. In a statement by Ella: \"Ella is a knight or Penelope is a knight\". According to Penelope, \"Ella is a knave if and only if Penelope is a knight\". So who is a knight and who is a knave? 32 | ### Answer: Let's think step by step, by considering whether each person is lying and if that leads to contradiction. Assume Ella is a knight. Penelope cannot be a knight, because this would contradict the claim of their own. Penelope cannot be a knave, because this would contradict the false claim of their own.
We have exhausted all possibilities for Penelope, so let us go back and reconsider Ella. Assume Ella is a knave. Penelope cannot be a knight, because this would contradict the false claim of Ella. Assume Penelope is a knave. This leads to a feasible solution. 33 | CONCLUSION: 34 | (1) Ella is a knave 35 | (2) Penelope is a knave 36 | ''' 37 | 38 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import subprocess 5 | import sys 6 | 7 | import yaml 8 | from loguru import logger 9 | 10 | from config.config_loader import (generate_experiment_configs, load_config, 11 | save_temp_config) 12 | 13 | 14 | def get_args(): 15 | parser = argparse.ArgumentParser(description="Run multiple experiments sequentially") 16 | parser.add_argument( 17 | "--config", type=str, help="Path to experiment variations config file", default='config/experiment_variations.yaml' 18 | ) 19 | args = parser.parse_args() 20 | 21 | return args 22 | 23 | def run_experiment(config_path, save_dir: str): 24 | """Run a single experiment.""" 25 | logger.info(f"Running experiment with config: {config_path}, save on: {save_dir}") 26 | result = subprocess.run( 27 | [sys.executable, "app.py", "--config", config_path, "--save_dir", save_dir], 28 | ) 29 | 30 | if result.returncode != 0: 31 | logger.error(f"Experiment failed: {result.stderr}") 32 | return False 33 | 34 | logger.info(f"Experiment completed successfully") 35 | return True 36 | 37 | def main(): 38 | args = get_args() 39 | with open(args.config, 'r', encoding='utf-8') as f: 40 | experiment_config = yaml.safe_load(f) 41 | 42 | base_config_path = experiment_config.get('base_config', 'config/experts.yaml') 43 | save_dir = experiment_config.get('save_dir', 'results') 44 | base_config = load_config(base_config_path) 45 | experiment_variations = experiment_config.get('variations', []) 46 | 47 | # create temp dirs 48 | temp_dir = 'config/temp' 49 | os.makedirs(temp_dir, exist_ok=True) 50 | 51 | experiment_configs = generate_experiment_configs(base_config, experiment_variations) 52 | logger.info(f"Generated {len(experiment_configs)} experiment configs") 53 | 54 | for i, experiment_config in enumerate(experiment_configs): 55 | logger.info(f"Running experiment {i+1}/{len(experiment_configs)}: {experiment_config['experiment_name']}") 56 | config_path = save_temp_config(experiment_config, temp_dir) 57 | success = run_experiment(config_path, save_dir) 58 | 59 | if not success: 60 | logger.error(f"Experiment {experiment_config['experiment_name']} failed") 61 | 62 | logger.info("All experiments completed") 63 | 64 | 65 | if __name__ == "__main__": 66 | main() -------------------------------------------------------------------------------- /core/inference/selfconsistency_generator.py: -------------------------------------------------------------------------------- 1 | from openai import OpenAI, NOT_GIVEN 2 | from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type 3 | from loguru import logger 4 | 5 | from core.experts.load_experts import Expert 6 | from core.inference.base_generator import BaseGenerator, GeneratorOutput 7 | 8 | class SelfConsistencyGenerator(BaseGenerator): 9 | def __init__(self, expert: Expert, generator_config: dict): 10 | self.client = expert.client 11 | self.model = expert.model_name 12 | self.config = generator_config 13 | self.samples = self.config.get("samples", 5) 14 |
self.temperature = self.config.get("temperature", 0.2) 15 | self.top_p = self.config.get("top_p", 1.0) 16 | self.top_k = self.config.get("top_k", NOT_GIVEN) 17 | # Retry logging callback 18 | def _log_retry(retry_state): 19 | exception = retry_state.outcome.exception() 20 | if exception: 21 | logger.warning(f"Retrying SelfConsistencyGenerator.generate due to error: {str(exception)}. Attempt {retry_state.attempt_number}/{retry_state.retry_object.stop.max_attempt_number}") 22 | return None 23 | 24 | @retry( 25 | stop=stop_after_attempt(10), # retry up to 10 times 26 | wait=wait_exponential(multiplier=1, min=2, max=100), # exponential backoff: 1*2^x seconds, min 2s, max 100s 27 | retry=retry_if_exception_type((Exception)), # retry on any exception 28 | before_sleep=_log_retry # log before each retry 29 | ) 30 | def generate(self, question: str) -> GeneratorOutput: 31 | try: 32 | response = self.client.chat.completions.create( 33 | model=self.model, 34 | messages=[{"role": "user", "content": question}], 35 | temperature=self.temperature, 36 | top_p=self.top_p, 37 | n=self.samples, 38 | timeout=1_000, 39 | ) 40 | choices = response.choices 41 | usage = response.usage 42 | raw_output = [choice.message.content for choice in choices] 43 | assert len(raw_output) == self.samples, f"Expected {self.samples} samples, got {len(raw_output)}" 44 | first_output = choices[0].message.content 45 | 46 | return GeneratorOutput( 47 | first_output=first_output, 48 | raw_output=raw_output, 49 | prompt_tokens=usage.prompt_tokens, 50 | completion_tokens=usage.completion_tokens 51 | ) 52 | except Exception as e: 53 | logger.error(f"Error in SelfConsistencyGenerator.generate: {str(e)}") 54 | raise # re-raise so the retry decorator can handle it -------------------------------------------------------------------------------- /core/routing/factory.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | from loguru import logger 4 | 5 | from core.routing.base_router import BaseRouter 6 | from core.routing.gpt_router import GPTRouter 7 | from core.routing.straight_router import StraightRouter 8 | from core.routing.routerdc_router import RouterDC 9 | from core.routing.random_router import RandomRouter 10 | from core.routing.elo_router import EloRouter 11 | from core.routing.rank_router import RankRouter 12 | from core.routing.symbolic_moe_router import SymbolicMoERouter 13 | from core.routing.moa_router import MoARouter 14 | 15 | class RouterType(Enum): 16 | GPT = "gpt" 17 | STRAIGHT = "straight" 18 | RANDOM = "random" 19 | ROUTERDC = "routerdc" 20 | ELO = "elo" 21 | RANK = "rank" 22 | SYMBOLIC_MOE = "symbolic_moe" 23 | MOA = "moa" 24 | 25 | class RouterFactory: 26 | @staticmethod 27 | def create_router(normal_experts: list, thinking_experts: list, router_config: dict): 28 | router_type = router_config['type'] 29 | 30 | if isinstance(router_type, str): 31 | router_type = RouterType(router_type) 32 | 33 | if router_type == RouterType.GPT: 34 | logger.info(f"Creating GPT router.") 35 | return GPTRouter(normal_experts, thinking_experts, router_config) 36 | elif router_type == RouterType.STRAIGHT: 37 | logger.info(f"Creating Straight router.") 38 | return StraightRouter(normal_experts, thinking_experts, router_config) 39 | elif router_type == RouterType.RANDOM: 40 | logger.info(f"Creating Random router.") 41 | return RandomRouter(normal_experts, thinking_experts, router_config) 42 | elif router_type == RouterType.ROUTERDC: 43 | logger.info(f"Creating RouterDC router.") 44 | return RouterDC(normal_experts, thinking_experts, router_config) 45 | elif router_type ==
RouterType.ELO: 46 | logger.info(f"Creating Elo router.") 47 | return EloRouter(normal_experts, thinking_experts, router_config) 48 | elif router_type == RouterType.RANK: 49 | logger.info(f"Creating Rank router.") 50 | return RankRouter(normal_experts, thinking_experts, router_config) 51 | elif router_type == RouterType.SYMBOLIC_MOE: 52 | logger.info(f"Creating Symbolic MoE router.") 53 | return SymbolicMoERouter(normal_experts, thinking_experts, router_config) 54 | elif router_type == RouterType.MOA: 55 | logger.info(f"Creating MoA router.") 56 | return MoARouter(normal_experts, thinking_experts, router_config) 57 | else: 58 | logger.error(f"Invalid router type: {router_type}") 59 | raise ValueError(f"Invalid router type: {router_type}") -------------------------------------------------------------------------------- /core/inference/factory.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from typing import List 3 | 4 | from loguru import logger 5 | 6 | from core.inference.aggregation_generator import AggregationGenerator 7 | from core.inference.base_generator import BaseGenerator 8 | from core.inference.direct_generator import DirectGenerator 9 | from core.inference.fastslow_generator import FastSlowGenerator 10 | from core.inference.modelswitch_generator import ModelSwitchGenerator 11 | from core.inference.selfconsistency_generator import SelfConsistencyGenerator 12 | from core.inference.moa_generator import MoAGenerator 13 | from core.inference.slowfast_generator import SlowFastGenerator 14 | from core.routing import RouterOutput 15 | 16 | 17 | class GeneratorType(Enum): 18 | SELF_CONSISTENCY = "self_consistency" 19 | DIRECT = "direct" 20 | MODEL_SWITCH = "model_switch" 21 | FAST_SLOW = "fast_slow" 22 | AGGREGATION = "aggregation" 23 | MoA = "moa" 24 | SLOW_FAST = "slow_fast" 25 | 26 | class GeneratorFactory: 27 | @staticmethod 28 | def create_generator(experts: RouterOutput, generator_config: dict): 29 | generator_type = GeneratorType(generator_config["type"]) 30 | # get normal experts and thinking experts 31 | normal_experts = experts.normal_experts 32 | thinking_experts = experts.thinking_experts 33 | 34 | if generator_type == GeneratorType.SELF_CONSISTENCY: 35 | return SelfConsistencyGenerator(expert=normal_experts[0], generator_config=generator_config["self_consistency"]) 36 | elif generator_type == GeneratorType.DIRECT: 37 | return DirectGenerator(expert=normal_experts[0], generator_config=generator_config["direct"]) 38 | elif generator_type == GeneratorType.MODEL_SWITCH: 39 | return ModelSwitchGenerator(experts=normal_experts, generator_config=generator_config["model_switch"]) 40 | elif generator_type == GeneratorType.AGGREGATION: 41 | return AggregationGenerator(experts=normal_experts, generator_config=generator_config["aggregation"]) 42 | elif generator_type == GeneratorType.MoA: 43 | return MoAGenerator(experts=normal_experts, generator_config=generator_config["moa"]) 44 | elif generator_type == GeneratorType.FAST_SLOW: 45 | assert len(thinking_experts) >= 1, "FastSlowGenerator requires at least **1** thinking expert" 46 | return FastSlowGenerator( 47 | fast_expert=normal_experts[0], 48 | slow_expert=thinking_experts[0], 49 | generator_config=generator_config["fast_slow"] 50 | ) 51 | elif generator_type == GeneratorType.SLOW_FAST: 52 | return SlowFastGenerator( 53 | fast_expert=normal_experts[0], 54 | slow_expert=thinking_experts[0], 55 | generator_config=generator_config["slow_fast"] 56 | ) 57 | else: 58 | 
raise ValueError(f"Invalid generator type: {generator_type}") -------------------------------------------------------------------------------- /core/inference/direct_generator.py: -------------------------------------------------------------------------------- 1 | from openai import OpenAI, NOT_GIVEN 2 | from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type 3 | from loguru import logger 4 | 5 | from core.experts.load_experts import Expert 6 | from core.inference.base_generator import BaseGenerator, GeneratorOutput 7 | 8 | class DirectGenerator(BaseGenerator): 9 | def __init__(self, expert: Expert, generator_config: dict): 10 | self.client = expert.client 11 | self.model = expert.model_name 12 | self.config = generator_config 13 | self.temperature = self.config.get("temperature", 0.2) 14 | self.top_p = self.config.get("top_p", 1.0) 15 | self.top_k = self.config.get("top_k", NOT_GIVEN) 16 | 17 | # 定义重试日志记录函数 18 | def _log_retry(retry_state): 19 | exception = retry_state.outcome.exception() 20 | if exception: 21 | logger.warning(f"Retrying DirectGenerator.generate due to error: {str(exception)}. Attempt {retry_state.attempt_number}/{retry_state.retry_object.stop.max_attempt_number}") 22 | return None 23 | 24 | @retry( 25 | stop=stop_after_attempt(5), # 最多重试5次 26 | wait=wait_exponential(multiplier=1, min=2, max=60), # 指数退避策略:1*2^x 秒,最少2秒,最多60秒 27 | retry=retry_if_exception_type((Exception)), # 捕获所有异常进行重试 28 | before_sleep=_log_retry # 重试前记录日志 29 | ) 30 | def generate_with_retry(self, question: str) -> GeneratorOutput: 31 | if "Distill" in self.model or "EXAOME" in self.model: 32 | question += "Don't make your reasoning and thinking too long.\n" 33 | try: 34 | response = self.client.chat.completions.create( 35 | model=self.model, 36 | messages=[{"role": "user", "content": question}], 37 | temperature=self.temperature, 38 | top_p=self.top_p, 39 | timeout=500, 40 | ) 41 | usage = response.usage 42 | choices = response.choices 43 | assert choices[0].message.content is not None, f"choices[0].message.content is None" 44 | 45 | return GeneratorOutput( 46 | first_output=choices[0].message.content, 47 | raw_output=[choice.message.content for choice in choices], 48 | prompt_tokens=usage.prompt_tokens, 49 | completion_tokens=usage.completion_tokens 50 | ) 51 | except Exception as e: 52 | logger.error(f"Error in DirectGenerator.generate: {str(e)}, model_name: {self.model}") 53 | raise # 重新抛出异常,让重试装饰器捕获 54 | 55 | def generate(self, question: str) -> GeneratorOutput: 56 | try: 57 | return self.generate_with_retry(question=question) 58 | except Exception as e: 59 | logger.error( 60 | f"Error in DirectGenerator.generate after all retries: " 61 | f"{str(e)}, model_name: {self.model}" 62 | ) 63 | return GeneratorOutput( 64 | first_output="failed to generate", 65 | raw_output=["failed to generate"], 66 | prompt_tokens=0, 67 | completion_tokens=0 68 | ) 69 | -------------------------------------------------------------------------------- /evaluate/LiveCodeBench/compute_code_generation_metrics.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import json 3 | from evaluate.LiveCodeBench.testing_util import run_test 4 | from loguru import logger 5 | import numpy as np 6 | import os 7 | 8 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 9 | 10 | if not multiprocessing.get_start_method(allow_none=True): 11 | multiprocessing.set_start_method('spawn') 12 | 13 | def _temp_run(sample, generation, debug, result, 
metadata_list, timeout): 14 | res, metadata = run_test(sample, test=generation, debug=debug, timeout=timeout) 15 | result.append(res) 16 | metadata_list.append(metadata) 17 | 18 | def check_correctness(sample, generation, timeout, debug=True): 19 | """Check correctness of code generation with a global timeout. 20 | The global timeout is to catch some extreme/rare cases not handled by the timeouts 21 | inside `run_test`""" 22 | 23 | manager = multiprocessing.Manager() 24 | result = manager.list() 25 | metadata_list = manager.list() 26 | p = multiprocessing.Process( 27 | target=_temp_run, 28 | args=(sample, generation, debug, result, metadata_list, timeout), 29 | ) 30 | p.start() 31 | p.join( 32 | timeout=(timeout + 1) * len(json.loads(sample["input_output"])["inputs"]) + 5 33 | ) 34 | if p.is_alive(): 35 | p.kill() 36 | if not result: 37 | in_outs = json.loads(sample["input_output"]) 38 | # consider that all tests failed 39 | result = [[-1 for i in range(len(in_outs["inputs"]))]] 40 | if debug: 41 | print(f"global timeout") 42 | 43 | return result[0], metadata_list[0] 44 | 45 | 46 | def evaluate_generation(generations, sample, debug: bool = False, timeout:int=6): 47 | res = [] 48 | metadata = [] 49 | 50 | for o_idx, o in enumerate(generations): 51 | curr_res = [-2] 52 | try: 53 | # print(sample, o) 54 | curr_res, curr_metadata = check_correctness(sample, o, timeout, debug) 55 | if debug: 56 | logger.info(f"sample generation {o_idx} passed {curr_res}") 57 | fixed = [] 58 | for e in curr_res: 59 | if isinstance(e, np.ndarray): 60 | e = e.item(0) 61 | if isinstance(e, np.bool_): 62 | e = bool(e) 63 | fixed.append(e) 64 | curr_res = fixed 65 | if not np.all(curr_res): 66 | if debug: 67 | logger.info(f"Results were not True for all test cases {curr_res=}\n") 68 | except Exception as e: 69 | if debug: 70 | logger.warning(f"Compilation failed, test framework exception = {repr(e)}{e}\n") 71 | curr_metadata = { 72 | "error": str(e), 73 | "error_code": -5, 74 | "error_message": "TestRunnerError" 75 | } 76 | finally: 77 | assert isinstance(curr_res, list) 78 | assert isinstance(curr_metadata, dict) 79 | res.append(curr_res) 80 | metadata.append(curr_metadata) 81 | if debug: 82 | for i, r in enumerate(res): 83 | logger.info("Sample\n") 84 | logger.info(sample) 85 | logger.info("\n") 86 | logger.info("Result\n") 87 | logger.info(res[i]) 88 | logger.info("*" * 30 + "\n\n") 89 | return res, metadata 90 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import time 5 | 6 | from loguru import logger 7 | 8 | from config.config_loader import load_config 9 | from core.experts.load_experts import load_experts 10 | from core.inference import DirectGenerator 11 | from core.routing import GPTRouter, RouterFactory, RouterType, StraightRouter 12 | from evaluate.factory import EvaluatorFactory 13 | 14 | 15 | def show_config(config: dict): 16 | logger.info("="*30) 17 | logger.info("Experiment Config:") 18 | if 'experiment_name' in config: 19 | logger.info(f"Experiment: {config['experiment_name']}") 20 | logger.info(f"Task: {config['experiments']['task']}") 21 | logger.info(f"Max workers: {config['experiments']['max_workers']}") 22 | logger.info(f"Mode: {config['experiments']['mode']}") 23 | logger.info(f"Router: {config['router']['type']}") 24 | logger.info(f"Generator: {config['generator']['type']}") 25 | logger.info(f"Use HTTP Cache: 
{config['experiments']['use_http_cache']}") 26 | logger.info("="*30) 27 | 28 | 29 | def run_experiment(config: dict, save_dir: str = None): 30 | show_config(config) 31 | 32 | task = config['experiments']['task'] 33 | max_workers = config['experiments']['max_workers'] 34 | mode = config['experiments']['mode'] 35 | use_http_cache = config['experiments']['use_http_cache'] 36 | 37 | if use_http_cache: 38 | logger.warning("Use HTTP Cache, supported by hishel: https://github.com/karpetrosyan/hishel") 39 | logger.warning(f"Cache dir: {config['experiments']['cache_dir']}") 40 | 41 | # 2. load experts, TODO: thinking experts 42 | normal_experts, thinking_experts = load_experts(config) 43 | 44 | # 3. create router 45 | router = RouterFactory.create_router( 46 | normal_experts=normal_experts, thinking_experts=thinking_experts, router_config = config['router'] 47 | ) 48 | # 4. get evaluator 49 | if config['generator']['type'] == "fast_slow" and max_workers > 1: 50 | logger.warning(f"FastSlowGenerator does not recommend multi-threading, kv-cache may cause GPU boom.") 51 | evaluator = EvaluatorFactory(max_workers=max_workers, mode=mode).get_evaluator(task=task) 52 | 53 | # 5. evaluate 54 | results = evaluator.evaluate_loop(router=router, generator_config=config['generator']) 55 | 56 | results['config'] = config 57 | 58 | # 6. save results 59 | # Use the experiment name as part of the filename (if provided) 60 | if save_dir is None: 61 | save_dir = "results" 62 | 63 | if not os.path.exists(save_dir): 64 | os.makedirs(save_dir) 65 | 66 | experiment_name = config.get('experiment_name', '') 67 | generator_type = config.get('generator', {}).get('type', '') 68 | 69 | os.makedirs(f"{save_dir}/{generator_type}", exist_ok=True) 70 | if experiment_name: 71 | filename = f"{save_dir}/{generator_type}/{task}-{experiment_name}-{time.strftime('%Y%m%d-%H%M%S')}.json" 72 | else: 73 | filename = f"{save_dir}/{generator_type}/{task}-{time.strftime('%Y%m%d-%H%M%S')}.json" 74 | 75 | logger.info(f"Save result to {filename}") 76 | with open(filename, "w") as f: 77 | json.dump(results, f, indent=4, ensure_ascii=False) 78 | 79 | return results 80 | 81 | if __name__ == "__main__": 82 | # Parse command-line arguments 83 | parser = argparse.ArgumentParser(description="Run experiments") 84 | parser.add_argument("--config", type=str, default=None, 85 | help="Path to config file (relative to config directory)") 86 | parser.add_argument("--save_dir", type=str, default="results", 87 | help="Path to save results") 88 | args = parser.parse_args() 89 | 90 | # 1.
Load config 91 | config = load_config(args.config) 92 | run_experiment(config, args.save_dir) -------------------------------------------------------------------------------- /core/experts/load_experts.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List 3 | 4 | from loguru import logger 5 | from openai import OpenAI 6 | 7 | from config.config_loader import load_config 8 | import hishel, httpx 9 | import json, hashlib 10 | from typing import Optional 11 | from hishel._utils import normalized_url 12 | 13 | @dataclass 14 | class Expert: 15 | model_name: str 16 | base_url: str 17 | api_key: str 18 | description: str 19 | client: OpenAI 20 | 21 | def param_only_key(request: httpx.Request, body: Optional[bytes] = b"") -> str: 22 | INTERESTED_FIELDS = {"model", "temperature", "top_p", "n", "messages"} 23 | 24 | # 1) extract url path 25 | full_url = normalized_url(request.url) 26 | url_obj = httpx.URL(full_url) 27 | encoded_url = (url_obj.raw_path or b"/") 28 | # 2) extract interested fields 29 | try: 30 | payload = json.loads(body or b"{}") 31 | except json.JSONDecodeError: 32 | payload = {} 33 | filtered = {k: payload.get(k) for k in sorted(INTERESTED_FIELDS)} 34 | encoded_body = json.dumps( 35 | filtered, separators=(",", ":"), sort_keys=True, ensure_ascii=False 36 | ).encode() 37 | 38 | # 3) generate key 39 | key_parts = [request.method, encoded_url, encoded_body] 40 | 41 | try: # use blake2b-128 42 | hasher = hashlib.blake2b(digest_size=16, usedforsecurity=False) 43 | except (TypeError, ValueError, AttributeError): 44 | hasher = hashlib.sha256(usedforsecurity=False) 45 | 46 | for part in key_parts: 47 | hasher.update(part) 48 | return hasher.hexdigest() 49 | 50 | def init_http_cache(cache_dir: str): 51 | storage = hishel.FileStorage(base_path=cache_dir) 52 | base_transport = httpx.HTTPTransport() 53 | controller = hishel.Controller( 54 | cacheable_methods = ["GET", "POST"], 55 | cacheable_status_codes=[200], 56 | allow_stale=True, 57 | force_cache=True, 58 | key_generator=param_only_key 59 | ) 60 | transport = hishel.CacheTransport( 61 | storage=storage, 62 | transport=base_transport, 63 | controller=controller 64 | ) 65 | return transport 66 | 67 | def load_experts(config: dict) -> List[Expert]: 68 | use_http_cache = config['experiments']['use_http_cache'] 69 | if use_http_cache: 70 | cache_dir = config['experiments']['cache_dir'] 71 | transport = init_http_cache(cache_dir) 72 | httpx_client = httpx.Client(transport=transport) 73 | 74 | experts = [] 75 | for model_config in config['experts']: 76 | client = OpenAI( 77 | base_url=model_config['base_url'], 78 | api_key=model_config['api_key'], 79 | http_client=httpx_client if use_http_cache else None 80 | ) 81 | experts.append(Expert( 82 | model_name=model_config['name'], 83 | base_url=model_config['base_url'], 84 | api_key=model_config['api_key'], 85 | description=model_config['description'], 86 | client=client 87 | )) 88 | expert_names = [e.model_name for e in experts] 89 | logger.info( 90 | f"Load {len(experts)} experts: {expert_names}" 91 | ) 92 | thinking_experts = [] 93 | if 'thinking_experts' in config.keys(): 94 | for model_config in config['thinking_experts']: 95 | client = OpenAI( 96 | base_url=model_config['base_url'], 97 | api_key=model_config['api_key'] 98 | ) 99 | thinking_experts.append(Expert( 100 | model_name=model_config['name'], 101 | base_url=model_config['base_url'], 102 | api_key=model_config['api_key'], 103 | 
description=model_config['description'], 104 | client=client 105 | )) 106 | logger.info(f"Load thinking expert: {[e.model_name for e in thinking_experts]}") 107 | 108 | return experts, thinking_experts 109 | # test 110 | if __name__ == "__main__": 111 | config = load_config() 112 | experts = load_experts(config) 113 | logger.info(experts) -------------------------------------------------------------------------------- /evaluate/base_evaluator.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | from abc import ABC, abstractmethod 4 | from collections import Counter 5 | from typing import Dict, List 6 | 7 | from loguru import logger 8 | 9 | BOXED_PATTERN = r"\\boxed\{([^}]*)\}" 10 | 11 | class BaseEvaluator(ABC): 12 | def __init__(self, max_workers:int=8, mode: str="test"): 13 | self.prompt_tokens = 0 14 | self.completion_tokens = 0 15 | self.max_workers = max_workers 16 | self.mode = mode 17 | 18 | def load_jsonl(self, file_path: str) -> List[Dict]: 19 | with open(file_path, "r") as f: 20 | data = [json.loads(line) for line in f] 21 | return data 22 | 23 | def update_tokens(self, prompt_tokens: int, completion_tokens: int): 24 | self.prompt_tokens += prompt_tokens 25 | self.completion_tokens += completion_tokens 26 | 27 | def fresh_tokens(self): 28 | self.prompt_tokens = 0 29 | self.completion_tokens = 0 30 | 31 | def extract_boxed_content(self, text: str) -> str: 32 | start_tag = r"\boxed{" 33 | start = text.find(start_tag) 34 | if start == -1: 35 | return "" 36 | 37 | start += len(start_tag) 38 | brace_count = 1 # one opening brace already consumed 39 | result = [] 40 | 41 | for char in text[start:]: 42 | if char == "{": 43 | brace_count += 1 44 | elif char == "}": 45 | brace_count -= 1 46 | if brace_count == 0: 47 | break 48 | result.append(char) 49 | 50 | return ''.join(result).strip() 51 | 52 | def count_prediction_frequency(self, predictions: list): 53 | """ 54 | Count the frequency of each prediction in the list of predictions for math or multi-choice tasks. 55 | """ 56 | prediction_counts = Counter(predictions) 57 | total = len(predictions) 58 | frequency_stats = dict() 59 | for pred, count in prediction_counts.items(): 60 | frequency_stats[pred] = { 61 | "count": count, 62 | "frequency": count / total 63 | } 64 | return frequency_stats 65 | 66 | def calculate_model_counts(self, results: list[dict]): 67 | # process model name 68 | position_model_counts = {} 69 | for result in results: 70 | model_name = result['model_name'] 71 | if not isinstance(model_name, list): 72 | model_name = [model_name] 73 | for idx, model in enumerate(model_name): 74 | position = idx + 1 75 | if position not in position_model_counts: 76 | position_model_counts[position] = Counter() 77 | position_model_counts[position][model] += 1 78 | 79 | # Log model usage counts for each routing position 80 | for position, model_counter in position_model_counts.items(): 81 | logger.info(f"Position {position} model counts: {model_counter}") 82 | 83 | return position_model_counts 84 | 85 | def extract_normal_answer(self, text: str, answer_pattern: str) -> str: 86 | """ 87 | Extract the answer from the text using the answer pattern.
88 | Like: 89 | - Answer: 123 -> 123 90 | - Answer:123 -> 123 91 | - Final Answer\n\nxxx -> xxx 92 | if failed, try to parse \\boxed{answer} 93 | """ 94 | if len(text) <= 10 and 'Answer' not in text and 'box' not in text: 95 | return text.lstrip() 96 | 97 | if text is None: 98 | return "" 99 | 100 | # First, try to match using the provided answer_pattern 101 | matches = re.findall(answer_pattern, text) 102 | if matches: 103 | extracted_answer = matches[-1].strip() 104 | if extracted_answer.lower().startswith("answer: "): 105 | extracted_answer = extracted_answer[len("answer:"):].strip().lstrip() 106 | return extracted_answer 107 | 108 | # If no match is found, check for "Final Answer" format 109 | answer_pattern = answer_pattern.replace("Answer\s*:\s", "Final Answer\s\n+\s") 110 | final_answer_match = re.search(answer_pattern, text) 111 | if final_answer_match: 112 | return final_answer_match.group(1).strip() 113 | 114 | # If both patterns fail, try to extract boxed content 115 | return self.extract_boxed_content(text) 116 | 117 | @abstractmethod 118 | def load_data(self, split: str): 119 | pass 120 | 121 | @abstractmethod 122 | def evaluate(self, question: str, answer: str): 123 | pass 124 | 125 | -------------------------------------------------------------------------------- /evaluate/MBPP/utils.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import traceback 3 | from typing import Optional, List, Tuple, Dict, Set 4 | 5 | imports = [ "import math", 6 | "import re", 7 | "import sys", 8 | "import copy", 9 | "import datetime", 10 | "import itertools", 11 | "import collections", 12 | "import heapq", 13 | "import functools", 14 | "import hashlib", 15 | "import numpy", 16 | "import numpy as np", 17 | "import string", 18 | "from typing import *", 19 | "from collections import *" 20 | ] 21 | 22 | def get_definition_name(node: ast.AST) -> Optional[str]: 23 | if isinstance(node, (ast.FunctionDef, ast.ClassDef)): 24 | return node.name 25 | elif isinstance(node, ast.Assign): 26 | targets = node.targets 27 | if targets and isinstance(targets[0], ast.Name): 28 | return targets[0].id 29 | return None 30 | 31 | def refine_text(text: str) -> str: 32 | text = text.replace("\t", " ") 33 | text = text.replace("\r\n", "\n").replace("\r", "\n") 34 | return text.strip() + "\n" 35 | 36 | def syntax_check(code, verbose = False): 37 | try: 38 | ast.parse(code) 39 | return True 40 | except (SyntaxError, MemoryError): 41 | if verbose: 42 | traceback.print_exc() 43 | return False 44 | 45 | 46 | def extract_longest_valid_code(text: str) -> str: 47 | lines = text.splitlines() 48 | 49 | if len(lines) > 100: 50 | lines = lines[:100] 51 | max_valid_lines = 0 52 | max_valid_snippet = "" 53 | 54 | for i in range(len(lines)): 55 | for j in range(i, len(lines)): 56 | current_snippet = "\n".join(lines[i:j+1]) 57 | if syntax_check(current_snippet): 58 | valid_line_count = sum(1 for line in lines[i:j+1] if line.strip()) 59 | if valid_line_count > max_valid_lines: 60 | max_valid_lines = valid_line_count 61 | max_valid_snippet = current_snippet 62 | 63 | return max_valid_snippet 64 | 65 | def has_return_statement(node: ast.AST) -> bool: 66 | return any(isinstance(n, ast.Return) for n in ast.walk(node)) 67 | 68 | def get_deps(nodes: List[Tuple[str, ast.AST]]) -> Dict[str, Set[str]]: 69 | name2deps = {} 70 | for name, node in nodes: 71 | deps = set() 72 | stack = [node] 73 | while stack: 74 | current = stack.pop() 75 | for child in ast.iter_child_nodes(current): 76 | if 
isinstance(child, ast.Name): 77 | deps.add(child.id) 78 | elif isinstance(child, ast.Attribute): 79 | deps.add(child.attr) 80 | else: 81 | stack.append(child) 82 | name2deps[name] = deps 83 | return name2deps 84 | 85 | def get_function_dependency(entrypoint: str, call_graph: Dict[str, Set[str]]) -> Set[str]: 86 | visited = set() 87 | to_visit = [entrypoint] 88 | 89 | while to_visit: 90 | current = to_visit.pop(0) 91 | if current not in visited: 92 | visited.add(current) 93 | to_visit.extend(call_graph.get(current, set()) - visited) 94 | 95 | return visited 96 | 97 | def sanitize(text: str, entrypoint: Optional[str] = None) -> str: 98 | 99 | text = refine_text(text) 100 | 101 | # text = python_extract(text) 102 | 103 | code = extract_longest_valid_code(text) 104 | tree = ast.parse(code) 105 | 106 | definitions = {} 107 | 108 | imports = [] 109 | 110 | for node in tree.body: 111 | if isinstance(node, (ast.Import, ast.ImportFrom)): 112 | imports.append(node) 113 | elif isinstance(node, ast.ClassDef): 114 | name = node.name 115 | definitions[name] = ('class', node) 116 | elif isinstance(node, ast.FunctionDef): 117 | name = node.name 118 | if has_return_statement(node): 119 | definitions[name] = ('function', node) 120 | elif isinstance(node, ast.Assign): 121 | name = get_definition_name(node) 122 | if name: 123 | definitions[name] = ('variable', node) 124 | 125 | if entrypoint: 126 | name2deps = get_deps([(name, node) for name, (_, node) in definitions.items()]) 127 | reachable = get_function_dependency(entrypoint, name2deps) 128 | 129 | sanitized_output = [] 130 | 131 | for node in imports: 132 | sanitized_output.append(ast.unparse(node)) 133 | 134 | for name, (_, node) in definitions.items(): 135 | if not entrypoint or name in reachable: 136 | sanitized_output.append(ast.unparse(node)) 137 | 138 | return "\n".join(sanitized_output) -------------------------------------------------------------------------------- /evaluate/HumanEval/utils.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import traceback 3 | from typing import Optional, List, Tuple, Dict, Set 4 | 5 | imports = [ "import math", 6 | "import re", 7 | "import sys", 8 | "import copy", 9 | "import datetime", 10 | "import itertools", 11 | "import collections", 12 | "import heapq", 13 | "import functools", 14 | "import hashlib", 15 | "import numpy", 16 | "import numpy as np", 17 | "import string", 18 | "from typing import *", 19 | "from collections import *" 20 | ] 21 | 22 | def get_definition_name(node: ast.AST) -> Optional[str]: 23 | if isinstance(node, (ast.FunctionDef, ast.ClassDef)): 24 | return node.name 25 | elif isinstance(node, ast.Assign): 26 | targets = node.targets 27 | if targets and isinstance(targets[0], ast.Name): 28 | return targets[0].id 29 | return None 30 | 31 | def refine_text(text: str) -> str: 32 | text = text.replace("\t", " ") 33 | text = text.replace("\r\n", "\n").replace("\r", "\n") 34 | return text.strip() + "\n" 35 | 36 | def syntax_check(code, verbose = False): 37 | try: 38 | ast.parse(code) 39 | return True 40 | except (SyntaxError, MemoryError): 41 | if verbose: 42 | traceback.print_exc() 43 | return False 44 | 45 | 46 | def extract_longest_valid_code(text: str) -> str: 47 | lines = text.splitlines() 48 | 49 | if len(lines) > 100: 50 | lines = lines[:100] 51 | max_valid_lines = 0 52 | max_valid_snippet = "" 53 | 54 | for i in range(len(lines)): 55 | for j in range(i, len(lines)): 56 | current_snippet = "\n".join(lines[i:j+1]) 57 | if 
syntax_check(current_snippet): 58 | valid_line_count = sum(1 for line in lines[i:j+1] if line.strip()) 59 | if valid_line_count > max_valid_lines: 60 | max_valid_lines = valid_line_count 61 | max_valid_snippet = current_snippet 62 | 63 | return max_valid_snippet 64 | 65 | def has_return_statement(node: ast.AST) -> bool: 66 | return any(isinstance(n, ast.Return) for n in ast.walk(node)) 67 | 68 | def get_deps(nodes: List[Tuple[str, ast.AST]]) -> Dict[str, Set[str]]: 69 | name2deps = {} 70 | for name, node in nodes: 71 | deps = set() 72 | stack = [node] 73 | while stack: 74 | current = stack.pop() 75 | for child in ast.iter_child_nodes(current): 76 | if isinstance(child, ast.Name): 77 | deps.add(child.id) 78 | elif isinstance(child, ast.Attribute): 79 | deps.add(child.attr) 80 | else: 81 | stack.append(child) 82 | name2deps[name] = deps 83 | return name2deps 84 | 85 | def get_function_dependency(entrypoint: str, call_graph: Dict[str, Set[str]]) -> Set[str]: 86 | visited = set() 87 | to_visit = [entrypoint] 88 | 89 | while to_visit: 90 | current = to_visit.pop(0) 91 | if current not in visited: 92 | visited.add(current) 93 | to_visit.extend(call_graph.get(current, set()) - visited) 94 | 95 | return visited 96 | 97 | def sanitize(text: str, entrypoint: Optional[str] = None) -> str: 98 | 99 | text = refine_text(text) 100 | 101 | # text = python_extract(text) 102 | 103 | code = extract_longest_valid_code(text) 104 | tree = ast.parse(code) 105 | 106 | definitions = {} 107 | 108 | imports = [] 109 | 110 | for node in tree.body: 111 | if isinstance(node, (ast.Import, ast.ImportFrom)): 112 | imports.append(node) 113 | elif isinstance(node, ast.ClassDef): 114 | name = node.name 115 | definitions[name] = ('class', node) 116 | elif isinstance(node, ast.FunctionDef): 117 | name = node.name 118 | if has_return_statement(node): 119 | definitions[name] = ('function', node) 120 | elif isinstance(node, ast.Assign): 121 | name = get_definition_name(node) 122 | if name: 123 | definitions[name] = ('variable', node) 124 | 125 | if entrypoint: 126 | name2deps = get_deps([(name, node) for name, (_, node) in definitions.items()]) 127 | reachable = get_function_dependency(entrypoint, name2deps) 128 | 129 | sanitized_output = [] 130 | 131 | for node in imports: 132 | sanitized_output.append(ast.unparse(node)) 133 | 134 | for name, (_, node) in definitions.items(): 135 | if not entrypoint or name in reachable: 136 | sanitized_output.append(ast.unparse(node)) 137 | 138 | return "\n".join(sanitized_output) -------------------------------------------------------------------------------- /evaluate/StudentEval/utils.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import traceback 3 | from typing import Optional, List, Tuple, Dict, Set 4 | 5 | imports = [ "import math", 6 | "import re", 7 | "import sys", 8 | "import copy", 9 | "import datetime", 10 | "import itertools", 11 | "import collections", 12 | "import heapq", 13 | "import functools", 14 | "import hashlib", 15 | "import numpy", 16 | "import numpy as np", 17 | "import string", 18 | "from typing import *", 19 | "from collections import *" 20 | ] 21 | 22 | def get_definition_name(node: ast.AST) -> Optional[str]: 23 | if isinstance(node, (ast.FunctionDef, ast.ClassDef)): 24 | return node.name 25 | elif isinstance(node, ast.Assign): 26 | targets = node.targets 27 | if targets and isinstance(targets[0], ast.Name): 28 | return targets[0].id 29 | return None 30 | 31 | def refine_text(text: str) -> str: 32 | text = 
text.replace("\t", " ") 33 | text = text.replace("\r\n", "\n").replace("\r", "\n") 34 | return text.strip() + "\n" 35 | 36 | def syntax_check(code, verbose = False): 37 | try: 38 | ast.parse(code) 39 | return True 40 | except (SyntaxError, MemoryError): 41 | if verbose: 42 | traceback.print_exc() 43 | return False 44 | 45 | 46 | def extract_longest_valid_code(text: str) -> str: 47 | lines = text.splitlines() 48 | 49 | if len(lines) > 100: 50 | lines = lines[:100] 51 | max_valid_lines = 0 52 | max_valid_snippet = "" 53 | 54 | for i in range(len(lines)): 55 | for j in range(i, len(lines)): 56 | current_snippet = "\n".join(lines[i:j+1]) 57 | if syntax_check(current_snippet): 58 | valid_line_count = sum(1 for line in lines[i:j+1] if line.strip()) 59 | if valid_line_count > max_valid_lines: 60 | max_valid_lines = valid_line_count 61 | max_valid_snippet = current_snippet 62 | 63 | return max_valid_snippet 64 | 65 | def has_return_statement(node: ast.AST) -> bool: 66 | return any(isinstance(n, ast.Return) for n in ast.walk(node)) 67 | 68 | def get_deps(nodes: List[Tuple[str, ast.AST]]) -> Dict[str, Set[str]]: 69 | name2deps = {} 70 | for name, node in nodes: 71 | deps = set() 72 | stack = [node] 73 | while stack: 74 | current = stack.pop() 75 | for child in ast.iter_child_nodes(current): 76 | if isinstance(child, ast.Name): 77 | deps.add(child.id) 78 | elif isinstance(child, ast.Attribute): 79 | deps.add(child.attr) 80 | else: 81 | stack.append(child) 82 | name2deps[name] = deps 83 | return name2deps 84 | 85 | def get_function_dependency(entrypoint: str, call_graph: Dict[str, Set[str]]) -> Set[str]: 86 | visited = set() 87 | to_visit = [entrypoint] 88 | 89 | while to_visit: 90 | current = to_visit.pop(0) 91 | if current not in visited: 92 | visited.add(current) 93 | to_visit.extend(call_graph.get(current, set()) - visited) 94 | 95 | return visited 96 | 97 | def sanitize(text: str, entrypoint: Optional[str] = None) -> str: 98 | 99 | text = refine_text(text) 100 | 101 | # text = python_extract(text) 102 | 103 | code = extract_longest_valid_code(text) 104 | tree = ast.parse(code) 105 | 106 | definitions = {} 107 | 108 | imports = [] 109 | 110 | for node in tree.body: 111 | if isinstance(node, (ast.Import, ast.ImportFrom)): 112 | imports.append(node) 113 | elif isinstance(node, ast.ClassDef): 114 | name = node.name 115 | definitions[name] = ('class', node) 116 | elif isinstance(node, ast.FunctionDef): 117 | name = node.name 118 | if has_return_statement(node): 119 | definitions[name] = ('function', node) 120 | elif isinstance(node, ast.Assign): 121 | name = get_definition_name(node) 122 | if name: 123 | definitions[name] = ('variable', node) 124 | 125 | if entrypoint: 126 | name2deps = get_deps([(name, node) for name, (_, node) in definitions.items()]) 127 | reachable = get_function_dependency(entrypoint, name2deps) 128 | 129 | sanitized_output = [] 130 | 131 | for node in imports: 132 | sanitized_output.append(ast.unparse(node)) 133 | 134 | for name, (_, node) in definitions.items(): 135 | if not entrypoint or name in reachable: 136 | sanitized_output.append(ast.unparse(node)) 137 | 138 | return "\n".join(sanitized_output) -------------------------------------------------------------------------------- /core/inference/modelswitch_generator.py: -------------------------------------------------------------------------------- 1 | from concurrent.futures import ThreadPoolExecutor, as_completed 2 | from typing import List 3 | 4 | from loguru import logger 5 | from openai import NOT_GIVEN, OpenAI 6 | from 
tenacity import (retry, retry_if_exception_type, stop_after_attempt, 7 | wait_exponential) 8 | 9 | from core.experts.load_experts import Expert 10 | from core.inference.base_generator import BaseGenerator, GeneratorOutput 11 | 12 | 13 | class ModelSwitchGenerator(BaseGenerator): 14 | def __init__(self, experts: List[Expert], generator_config: dict): 15 | self.name = self.__class__.__name__ 16 | assert len(experts) > 1, "ModelSwitchGenerator requires at least two experts." 17 | # get clients and models from experts 18 | self.experts = experts 19 | self.first_client = experts[0].client 20 | self.first_model = experts[0].model_name 21 | self.second_client = experts[1].client 22 | self.second_model = experts[1].model_name 23 | self.model = [self.first_model, self.second_model] 24 | # get config 25 | self.config = generator_config 26 | self.samples = self.config.get("samples", 5) 27 | self.temperature = self.config.get("temperature", 0.7) 28 | self.top_p = self.config.get("top_p", 1.0) 29 | self.top_k = self.config.get("top_k", NOT_GIVEN) 30 | self.consistency_rate_threshold = self.config.get("consistency_rate_threshold", 0.8) 31 | # define final results 32 | self.first_results = None 33 | self.second_results = None 34 | self.final_results = None 35 | 36 | def _log_retry(retry_state): 37 | exception = retry_state.outcome.exception() 38 | if exception: 39 | logger.warning(f"Retrying ModelSwitchGenerator.generate due to error: {str(exception)}. Attempt {retry_state.attempt_number}/{retry_state.retry_object.stop.max_attempt_number}") 40 | return None 41 | 42 | @retry( 43 | stop=stop_after_attempt(10), # 最多重试10次 44 | wait=wait_exponential(multiplier=1, min=2, max=100), # 指数退避策略:1*2^x 秒,最少2秒,最多100秒 45 | retry=retry_if_exception_type((Exception)), # 捕获所有异常进行重试 46 | before_sleep=_log_retry # 重试前记录日志 47 | ) 48 | def _generate(self, client: OpenAI, model: str, question: str) -> GeneratorOutput: 49 | try: 50 | response = client.chat.completions.create( 51 | model=model, 52 | messages=[{"role": "user", "content": question}], 53 | temperature=self.temperature, 54 | top_p=self.top_p, 55 | n=self.samples, 56 | timeout=1000, 57 | ) 58 | choices = response.choices 59 | usage = response.usage 60 | raw_output = [choice.message.content for choice in choices] 61 | assert len(raw_output) == self.samples, f"Expected {self.samples} samples, got {len(raw_output)}" 62 | 63 | return GeneratorOutput( 64 | first_output=choices[0].message.content, 65 | raw_output=raw_output, 66 | prompt_tokens=usage.prompt_tokens, 67 | completion_tokens=usage.completion_tokens 68 | ) 69 | except Exception as e: 70 | logger.error(f"Error in ModelSwitchGenerator._generate: {str(e)}, error model: {model}") 71 | raise # 重新抛出异常,让重试装饰器捕获 72 | 73 | def generate(self, question: str) -> tuple[GeneratorOutput, GeneratorOutput]: 74 | results = dict() 75 | 76 | with ThreadPoolExecutor(max_workers=2) as executor: 77 | future_to_model = { 78 | executor.submit(self._generate, self.first_client, self.first_model, question): self.first_model, 79 | executor.submit(self._generate, self.second_client, self.second_model, question): self.second_model 80 | } 81 | for future in as_completed(future_to_model): 82 | model = future_to_model[future] 83 | try: 84 | results[model] = future.result() 85 | except Exception as e: 86 | logger.error(f"Error in ModelSwitchGenerator.generate: {str(e)}, error model: {model}") 87 | results[model] = None 88 | 89 | self.first_results = results.get(self.first_model, None) 90 | self.second_results = results.get(self.second_model, 
None) 91 | self.final_results = self.get_second_output() 92 | 93 | return self.first_results, self.final_results 94 | 95 | def get_second_output(self) -> GeneratorOutput: 96 | if self.first_results is None and self.second_results is None: 97 | raise ValueError("first_results and second_results is None") 98 | if self.first_results is None: 99 | return self.second_results 100 | if self.second_results is None: 101 | return self.first_results 102 | 103 | final_results = GeneratorOutput( 104 | first_output = self.second_results.first_output, 105 | raw_output = self.first_results.raw_output + self.second_results.raw_output, 106 | prompt_tokens = self.first_results.prompt_tokens + self.second_results.prompt_tokens, 107 | completion_tokens = self.first_results.completion_tokens + self.second_results.completion_tokens 108 | ) 109 | 110 | return final_results -------------------------------------------------------------------------------- /evaluate/ArenaHard/arenahard.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import time 4 | from collections import Counter 5 | from concurrent.futures import ThreadPoolExecutor, as_completed 6 | from typing import Dict 7 | 8 | from datasets import Dataset, disable_progress_bars 9 | from loguru import logger 10 | from tqdm import tqdm 11 | 12 | from core.inference import GeneratorFactory, GeneratorOutput 13 | from core.routing import BaseRouter 14 | from evaluate.base_evaluator import BaseEvaluator 15 | 16 | disable_progress_bars() 17 | 18 | DATA_DIR = "data/arena-hard" 19 | 20 | PROMPT_FOUR_OPTIONS = """Answer the following question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. 21 | 22 | {question} 23 | 24 | A) {A} 25 | B) {B} 26 | C) {C} 27 | D) {D} 28 | 29 | Let's think step by step.""" 30 | 31 | class ArenaHardEvaluator(BaseEvaluator): 32 | def __init__(self, max_workers: int = 8, mode: str="test"): 33 | super().__init__(max_workers=max_workers, mode=mode) 34 | self.task = "ArenaHard" 35 | self.seed = 42 36 | 37 | def load_data(self, split="v2"): 38 | if split == "v2": 39 | data = self.load_jsonl(os.path.join(DATA_DIR, f"arena-hard-v2.jsonl")) 40 | else: 41 | raise ValueError(f"Invalid split: {split}") 42 | 43 | data = Dataset.from_list(data) 44 | logger.info(data) 45 | 46 | return data 47 | 48 | def format_prompt(self, item: Dict) -> Dict: 49 | pass 50 | 51 | def extract_raw_answer(self, raw_datas: list[str]) -> list[str]: 52 | pass 53 | 54 | def process_output(self, question: str, output: GeneratorOutput): 55 | prediction = output.raw_output[0] 56 | return [ 57 | {"role": "user", "content": question}, 58 | # Add an 'answer' key to meet ArenaHard's specific requirements 59 | {"role": "assistant", "content": {"answer": prediction}} 60 | ] 61 | 62 | def evaluate(self, index: int, data: dict, router: BaseRouter, generator_config: dict): 63 | # step 1. 
get router, get generator 64 | uid = data['uid'] 65 | if uid != "6c69551e80664df5": 66 | router_result = router.route(question=data['prompt']) 67 | generator = GeneratorFactory.create_generator( 68 | experts=router_result, generator_config=generator_config 69 | ) # type: ignore 70 | 71 | output: GeneratorOutput = generator.generate(question=data['prompt']) 72 | model = generator.model 73 | else: 74 | logger.warning(f"Context length exceeded, unable to answer this question, uid: {uid}") 75 | output = GeneratorOutput( 76 | first_output="Context length exceeded, unable to answer this question", 77 | raw_output=["Context length exceeded, unable to answer this question"], 78 | prompt_tokens=0, 79 | completion_tokens=0 80 | ) 81 | model = "None" 82 | 83 | messages = self.process_output(question=data['prompt'], output=output) 84 | self.update_tokens(prompt_tokens=output.prompt_tokens, completion_tokens=output.completion_tokens) 85 | 86 | return dict( 87 | uid=data['uid'], 88 | category=data['category'], 89 | subcategory=data['subcategory'], 90 | language=data['language'], 91 | ans_id=index, 92 | messages=messages, 93 | model="avengers", 94 | model_name=model, 95 | tstamp=time.time() 96 | ) 97 | 98 | def save_records(self, records: list[dict]): 99 | # save records for evluate on ArenaHard official leaderboard 100 | # save to data/arena-hard/avengers-{timestamp}.jsonl 101 | import json 102 | from datetime import datetime 103 | timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") 104 | with open(os.path.join(DATA_DIR, f"avengers-{timestamp}.jsonl"), "w") as f: 105 | for record in records: 106 | f.write(json.dumps(record) + "\n") 107 | 108 | def evaluate_loop(self, router: BaseRouter, generator_config: dict): 109 | start_time = time.time() 110 | data = self.load_data(split="v2") 111 | 112 | results = [] 113 | pbar = tqdm(total=len(data), desc=f"Evaluating {self.task} ...") 114 | with ThreadPoolExecutor(max_workers=self.max_workers) as executor: 115 | futures = [ 116 | executor.submit(self.evaluate, index=idx, data=d, router=router, generator_config=generator_config) 117 | for idx, d in enumerate(data) 118 | ] 119 | for future in as_completed(futures): 120 | result = future.result() 121 | results.append(result) 122 | pbar.update(1) 123 | pbar.close() 124 | 125 | model_counts = self.calculate_model_counts(results=results) 126 | 127 | end_time = time.time() 128 | logger.info(f"Task: {self.task}") 129 | logger.info(f"Time taken: {end_time - start_time} seconds") 130 | logger.info(f"Prompt tokens: {self.prompt_tokens}") 131 | logger.info(f"Completion tokens: {self.completion_tokens}") 132 | 133 | self.save_records(records=results) 134 | return { 135 | "time_taken": end_time - start_time, 136 | "prompt_tokens": self.prompt_tokens, 137 | "completion_tokens": self.completion_tokens, 138 | "model_counts": model_counts, 139 | "records": results, 140 | } -------------------------------------------------------------------------------- /core/inference/fastslow_generator.py: -------------------------------------------------------------------------------- 1 | from concurrent.futures import ThreadPoolExecutor, as_completed 2 | from typing import List 3 | 4 | from loguru import logger 5 | from openai import NOT_GIVEN, OpenAI 6 | from tenacity import (retry, retry_if_exception_type, stop_after_attempt, 7 | wait_exponential) 8 | 9 | from core.experts.load_experts import Expert 10 | from core.inference.base_generator import BaseGenerator, GeneratorOutput 11 | 12 | 13 | class FastSlowGenerator(BaseGenerator): 14 | def 
__init__(self, fast_expert: Expert, slow_expert: Expert, generator_config: dict): 15 | self.name = self.__class__.__name__ 16 | # get clients and models from fast and slow experts 17 | self.fast_expert = fast_expert 18 | self.slow_expert = slow_expert 19 | self.fast_client = fast_expert.client 20 | self.fast_model = fast_expert.model_name 21 | self.slow_client = slow_expert.client 22 | self.slow_model = slow_expert.model_name 23 | self.model = [self.fast_model] 24 | 25 | # get fast and slow config 26 | self.config = generator_config 27 | self.fast_samples = self.config.get("fast_samples", 10) 28 | self.fast_temperature = self.config.get("fast_temperature", 0.7) 29 | self.fast_top_p = self.config.get("fast_top_p", 1.0) 30 | self.slow_samples = self.config.get("slow_samples", 1) 31 | self.slow_temperature = self.config.get("slow_temperature", 0.7) 32 | self.slow_top_p = self.config.get("slow_top_p", 1.0) 33 | self.consistency_rate_threshold = self.config.get("consistency_rate_threshold", 0.8) 34 | 35 | self.fast_max_retries = 20 36 | self.slow_max_retries = 3 37 | # define final results 38 | self.fast_results = None 39 | self.slow_results = None 40 | self.final_results = None 41 | 42 | def _generate_with_retry(self, client: OpenAI, model: str, question: str, mode: str = "fast") -> GeneratorOutput: 43 | max_retries = self.fast_max_retries if mode == "fast" else self.slow_max_retries 44 | 45 | def _log_retry(retry_state): 46 | exception = retry_state.outcome.exception() 47 | if exception: 48 | logger.warning(f"Retrying FastSlowGenerator.generate due to error: {str(exception)}. Attempt {retry_state.attempt_number}/{retry_state.retry_object.stop.max_attempt_number}") 49 | return None 50 | 51 | @retry( 52 | stop=stop_after_attempt(max_retries), 53 | wait=wait_exponential(multiplier=1, min=2, max=60), 54 | retry=retry_if_exception_type((Exception)), 55 | before_sleep=_log_retry 56 | ) 57 | def _generate_impl(): 58 | temperature = self.fast_temperature if mode == "fast" else self.slow_temperature 59 | top_p = self.fast_top_p if mode == "fast" else self.slow_top_p 60 | samples = self.fast_samples if mode == "fast" else self.slow_samples 61 | timeout = 2_000 if mode == "fast" else 200_000 62 | if mode == "slow": 63 | slow_prompt = "\nDon't make your reasoning and thinking too long.\n" 64 | question_with_prompt = question + slow_prompt 65 | else: 66 | question_with_prompt = question 67 | 68 | try: 69 | response = client.chat.completions.create( 70 | model=model, 71 | messages=[{"role": "user", "content": question_with_prompt}], 72 | temperature=temperature, 73 | top_p=top_p, 74 | n=samples, 75 | timeout=timeout, 76 | ) 77 | choices = response.choices 78 | usage = response.usage 79 | raw_output = [choice.message.content for choice in choices] 80 | assert len(raw_output) == samples, f"Mode={mode}, Expected {samples} samples, got {len(raw_output)}" 81 | 82 | return GeneratorOutput( 83 | first_output=choices[0].message.content, 84 | raw_output=raw_output, 85 | prompt_tokens=usage.prompt_tokens, 86 | completion_tokens=usage.completion_tokens 87 | ) 88 | except Exception as e: 89 | logger.warning(f"Error in FastSlowGenerator._generate: {str(e)}, error model: {model}, mode: {mode}") 90 | raise 91 | 92 | return _generate_impl() 93 | 94 | def generate(self, question: str) -> GeneratorOutput: 95 | self.fast_results = self._generate_with_retry(self.fast_client, self.fast_model, question, "fast") 96 | return self.fast_results 97 | 98 | def slow_generate(self, question: str) -> GeneratorOutput: 99 | try: 100 | 
self.slow_results = self._generate_with_retry(self.slow_client, self.slow_model, question, "slow") 101 | self.final_results = GeneratorOutput( 102 | first_output = self.slow_results.first_output, 103 | raw_output = self.slow_results.raw_output, 104 | prompt_tokens = self.slow_results.prompt_tokens + self.fast_results.prompt_tokens, 105 | completion_tokens = self.slow_results.completion_tokens + self.fast_results.completion_tokens 106 | ) 107 | self.model = [self.slow_model] 108 | return self.final_results 109 | except Exception as e: 110 | logger.warning(f"Slow model generation failed after all retries: {str(e)}. Falling back to fast model results.") 111 | return self.fast_results -------------------------------------------------------------------------------- /evaluate/MBPP/mbpp.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import time 4 | from collections import Counter 5 | from concurrent.futures import ThreadPoolExecutor, as_completed 6 | from typing import Dict, List 7 | 8 | from datasets import Dataset, disable_progress_bars 9 | from loguru import logger 10 | from tqdm import tqdm 11 | 12 | from core.inference import GeneratorFactory, GeneratorOutput 13 | from core.routing import BaseRouter 14 | from evaluate.base_evaluator import BaseEvaluator 15 | from evaluate.MBPP.execution import check_correctness 16 | from evaluate.MBPP.utils import imports, sanitize 17 | 18 | disable_progress_bars() 19 | 20 | DATA_DIR = "data/MBPP" 21 | 22 | PROMPT = """You are an expert Python programmer, and here is your task: 23 | {question} 24 | 25 | Your code should pass these tests: 26 | {test} 27 | """.strip() 28 | 29 | 30 | class MBPPEvaluator(BaseEvaluator): 31 | def __init__(self, max_workers:int = 8, mode: str="test"): 32 | super().__init__(max_workers=max_workers, mode=mode) 33 | self.task = "MBPP" 34 | self.seed = 42 35 | self.imports = imports 36 | 37 | def load_data(self, split: str): 38 | data = self.load_jsonl(os.path.join(DATA_DIR, f"{split}.json")) 39 | 40 | data = Dataset.from_list(data) 41 | data = data.map(lambda x: self.format_prompt(x)) 42 | 43 | if self.mode == "test": 44 | # split data into train and test 45 | logger.warning(f"Split data into train and test for {self.task}") 46 | split_data = data.train_test_split(test_size=0.3) 47 | train_data = split_data["train"] 48 | data = split_data["test"] 49 | logger.info(f"Calibration data: {len(train_data)}") 50 | logger.info(f"Test data: {len(data)}") 51 | 52 | return data 53 | 54 | def format_prompt(self, item: Dict) -> Dict: 55 | prompt = PROMPT.format( 56 | question=item['text'], 57 | test="\n".join(item['test_list']) 58 | ) 59 | return {"task_prompt": prompt} 60 | 61 | def extract_code_answer(self, text: str, test_list: List[str]) -> str: 62 | extract_code = sanitize(text) 63 | code = "\n".join(self.imports) + "\n" + extract_code + "\n" + "\n".join(test_list) 64 | 65 | return code 66 | 67 | def extract_raw_answer(self, raw_datas: list[str], test_list: List[str]) -> list[str]: 68 | extracted_answer = [] 69 | for data in raw_datas: 70 | answer = self.extract_code_answer(text=data, test_list=test_list) 71 | if answer is None: 72 | answer = "" 73 | extracted_answer.append(answer) 74 | 75 | return extracted_answer 76 | 77 | def evaluate(self, index: int, data: dict, router: BaseRouter, generator_config: dict): 78 | # step 1. 
get router, get generator 79 | router_result = router.route(question=data['task_prompt']) 80 | generator = GeneratorFactory.create_generator( 81 | experts=router_result, generator_config=generator_config 82 | ) # type: ignore 83 | 84 | # step 2. generate & update token usage 85 | output: GeneratorOutput = generator.generate(question=data['task_prompt']) 86 | self.update_tokens(prompt_tokens=output.prompt_tokens, completion_tokens=output.completion_tokens) 87 | 88 | # step 3. extract answer 89 | full_prediction = self.extract_raw_answer(raw_datas=output.raw_output, test_list=data['test_list']) 90 | 91 | # step 4. TODO: code majority voting (do not support know.) 92 | prediction = full_prediction[0] 93 | is_correct = check_correctness(task_id = data['task_id'], completion_id=0, solution=prediction, time_out=10)['passed'] 94 | 95 | return dict( 96 | index=index, 97 | query=data['task_prompt'], 98 | origin_query=data['text'], 99 | prediction=prediction, 100 | full_prediction=full_prediction, 101 | raw_output=output.raw_output, 102 | answer=None, 103 | is_correct=is_correct, 104 | model_name=generator.model 105 | ) 106 | 107 | def evaluate_loop(self, router: BaseRouter, generator_config: dict): 108 | start_time = time.time() 109 | data = self.load_data(split="test") 110 | counter = 0 111 | results = [] 112 | pbar = tqdm(total=len(data), desc=f"Evaluating {self.task} ...") 113 | with ThreadPoolExecutor(max_workers=self.max_workers) as executor: 114 | futures = [ 115 | executor.submit(self.evaluate, index=idx, data=d, router=router, generator_config=generator_config) 116 | for idx, d in enumerate(data) 117 | ] 118 | for future in as_completed(futures): 119 | result = future.result() 120 | results.append(result) 121 | if result['is_correct']: 122 | counter += 1 123 | pbar.update(1) 124 | pbar.close() 125 | 126 | models = [result['model_name'] for result in results] 127 | 128 | model_counts = self.calculate_model_counts(results=results) 129 | logger.info(model_counts) 130 | 131 | acc = counter / len(data) 132 | end_time = time.time() 133 | logger.info(f"Task: {self.task}") 134 | logger.info(f"Accuracy: {acc}") 135 | logger.info(f"Time taken: {end_time - start_time} seconds") 136 | logger.info(f"Prompt tokens: {self.prompt_tokens}") 137 | logger.info(f"Completion tokens: {self.completion_tokens}") 138 | 139 | return { 140 | "performance": acc, 141 | "time_taken": end_time - start_time, 142 | "prompt_tokens": self.prompt_tokens, 143 | "completion_tokens": self.completion_tokens, 144 | "model_counts": model_counts, 145 | "records": results, 146 | } 147 | -------------------------------------------------------------------------------- /evaluate/StudentEval/studenteval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import time 4 | from collections import Counter 5 | from concurrent.futures import ThreadPoolExecutor, as_completed 6 | from typing import Dict, List 7 | 8 | from datasets import Dataset, disable_progress_bars 9 | from loguru import logger 10 | from tqdm import tqdm 11 | 12 | from core.inference import GeneratorFactory, GeneratorOutput 13 | from core.routing import BaseRouter 14 | from evaluate.base_evaluator import BaseEvaluator 15 | from evaluate.StudentEval.execution import check_correctness 16 | from evaluate.StudentEval.utils import imports, sanitize 17 | 18 | disable_progress_bars() 19 | 20 | DATA_DIR = "data/studenteval" 21 | 22 | PROMPT = """You are an expert Python programmer, and here is your task: 23 | {question} 
24 | 25 | Your code should pass these tests: 26 | {test} 27 | """.strip() 28 | 29 | 30 | class StudentEvalEvaluator(BaseEvaluator): 31 | def __init__(self, max_workers:int = 8, mode: str="test"): 32 | super().__init__(max_workers=max_workers, mode=mode) 33 | self.task = "StudentEval" 34 | self.seed = 42 35 | self.imports = imports 36 | 37 | def load_data(self, split: str): 38 | data = self.load_jsonl(os.path.join(DATA_DIR, f"{split}.json")) 39 | 40 | data = Dataset.from_list(data) 41 | data = data.map(lambda x: self.format_prompt(x)) 42 | 43 | if self.mode == "test": 44 | # split data into train and test 45 | logger.warning(f"Split data into train and test for {self.task}") 46 | split_data = data.train_test_split(test_size=0.3) 47 | train_data = split_data["train"] 48 | data = split_data["test"] 49 | logger.info(f"Calibration data: {len(train_data)}") 50 | logger.info(f"Test data: {len(data)}") 51 | 52 | return data 53 | 54 | def format_prompt(self, item: Dict) -> Dict: 55 | prompt = PROMPT.format( 56 | question=item['text'], 57 | test="\n".join(item['test_list']) 58 | ) 59 | return {"task_prompt": prompt} 60 | 61 | def extract_code_answer(self, text: str, test_list: List[str]) -> str: 62 | extract_code = sanitize(text) 63 | code = "\n".join(self.imports) + "\n" + extract_code + "\n" + "\n".join(test_list) 64 | 65 | return code 66 | 67 | def extract_raw_answer(self, raw_datas: list[str], test_list: List[str]) -> list[str]: 68 | extracted_answer = [] 69 | for data in raw_datas: 70 | answer = self.extract_code_answer(text=data, test_list=test_list) 71 | if answer is None: 72 | answer = "" 73 | extracted_answer.append(answer) 74 | 75 | return extracted_answer 76 | 77 | def evaluate(self, index: int, data: dict, router: BaseRouter, generator_config: dict): 78 | # step 1. get router, get generator 79 | router_result = router.route(question=data['task_prompt']) 80 | generator = GeneratorFactory.create_generator( 81 | experts=router_result, generator_config=generator_config 82 | ) # type: ignore 83 | 84 | # step 2. generate & update token usage 85 | output: GeneratorOutput = generator.generate(question=data['task_prompt']) 86 | self.update_tokens(prompt_tokens=output.prompt_tokens, completion_tokens=output.completion_tokens) 87 | 88 | # step 3. extract answer 89 | full_prediction = self.extract_raw_answer(raw_datas=output.raw_output, test_list=data['test_list']) 90 | # step 4. TODO: code majority voting (do not support know.) 
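        # Editorial sketch (not part of the original implementation): the TODO above
        # notes that majority voting over code candidates is not supported yet. If it
        # were added, one simple option would be an exact-match vote over the sanitized
        # candidate programs, reusing the Counter import already present in this module:
        #
        #     voted, _ = Counter(full_prediction).most_common(1)[0]
        #
        # For code, an execution-based vote (running check_correctness on each candidate
        # and keeping one that passes its tests) would likely be more robust than
        # string equality between candidates.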
91 | prediction = full_prediction[0] 92 | is_correct = check_correctness(task_id = index, completion_id=0, solution=prediction, time_out=10)['passed'] 93 | 94 | return dict( 95 | index=index, 96 | query=data['task_prompt'], 97 | origin_query=data['text'], 98 | prediction=prediction, 99 | full_prediction=full_prediction, 100 | raw_output=output.raw_output, 101 | answer=None, 102 | is_correct=is_correct, 103 | model_name=generator.model 104 | ) 105 | 106 | def evaluate_loop(self, router: BaseRouter, generator_config: dict): 107 | start_time = time.time() 108 | data = self.load_data(split="test") 109 | counter = 0 110 | results = [] 111 | pbar = tqdm(total=len(data), desc=f"Evaluating {self.task} ...") 112 | with ThreadPoolExecutor(max_workers=self.max_workers) as executor: 113 | futures = [ 114 | executor.submit(self.evaluate, index=idx, data=d, router=router, generator_config=generator_config) 115 | for idx, d in enumerate(data) 116 | ] 117 | for future in as_completed(futures): 118 | result = future.result() 119 | results.append(result) 120 | if result['is_correct']: 121 | counter += 1 122 | pbar.update(1) 123 | pbar.close() 124 | 125 | models = [result['model_name'] for result in results] 126 | 127 | model_counts = self.calculate_model_counts(results=results) 128 | logger.info(model_counts) 129 | 130 | acc = counter / len(data) 131 | end_time = time.time() 132 | logger.info(f"Task: {self.task}") 133 | logger.info(f"Accuracy: {acc}") 134 | logger.info(f"Time taken: {end_time - start_time} seconds") 135 | logger.info(f"Prompt tokens: {self.prompt_tokens}") 136 | logger.info(f"Completion tokens: {self.completion_tokens}") 137 | 138 | return { 139 | "performance": acc, 140 | "time_taken": end_time - start_time, 141 | "prompt_tokens": self.prompt_tokens, 142 | "completion_tokens": self.completion_tokens, 143 | "model_counts": model_counts, 144 | "records": results, 145 | } 146 | -------------------------------------------------------------------------------- /evaluate/HumanEval/humaneval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import time 4 | from collections import Counter 5 | from concurrent.futures import ThreadPoolExecutor, as_completed 6 | from typing import Dict, List 7 | 8 | from datasets import Dataset, disable_progress_bars 9 | from loguru import logger 10 | from tqdm import tqdm 11 | 12 | from core.inference import GeneratorFactory, GeneratorOutput 13 | from core.routing import BaseRouter 14 | from evaluate.base_evaluator import BaseEvaluator 15 | from evaluate.HumanEval.execution import check_correctness 16 | from evaluate.HumanEval.utils import imports, sanitize 17 | 18 | disable_progress_bars() 19 | 20 | DATA_DIR = "data/HumanEval" 21 | 22 | PROMPT = """You are an expert Python programmer, and here is your task: 23 | {question} 24 | 25 | Your code should pass these tests: 26 | {test} 27 | """.strip() 28 | 29 | ANSWER_PATTERN = r"(?i)Answer\s*:\s*([^\n]+)" 30 | 31 | class HumanEvalEvaluator(BaseEvaluator): 32 | def __init__(self, max_workers:int = 8, mode: str="test"): 33 | super().__init__(max_workers=max_workers, mode=mode) 34 | self.task = "HumanEval" 35 | self.seed = 42 36 | self.imports = imports 37 | 38 | def load_data(self, split: str): 39 | data = self.load_jsonl(os.path.join(DATA_DIR, f"HumanEval.jsonl")) 40 | 41 | data = Dataset.from_list(data) 42 | data = data.map(lambda x: self.format_prompt(x)) 43 | 44 | if self.mode == "test": 45 | # split data into train and test 46 | logger.warning(f"Split 
data into train and test for {self.task}") 47 | split_data = data.train_test_split(test_size=0.3) 48 | train_data = split_data["train"] 49 | data = split_data["test"] 50 | return data 51 | 52 | def format_prompt(self, item: Dict): 53 | # answer key: Answer 54 | prompt = PROMPT.format( 55 | question=item["prompt"], 56 | test=item["test"] 57 | ) 58 | return {"task_prompt": prompt} 59 | 60 | def extract_raw_answer(self, raw_datas: list[str], test: str, entry_point: str) -> list[str]: 61 | extracted_answer = [] 62 | for data in raw_datas: 63 | answer = self.extract_code_answer(text=data, test=test, entry_point=entry_point) 64 | if answer is None: 65 | answer = "" 66 | extracted_answer.append(answer) 67 | 68 | return extracted_answer 69 | 70 | def extract_code_answer(self, text: str, test: str, entry_point: str) -> str: 71 | extract_code = sanitize(text) 72 | code = "\n".join(self.imports) + "\n" + extract_code + "\n" + test + "\n" + f"check({entry_point})" 73 | 74 | return code 75 | 76 | def evaluate(self, index: int, data: dict, router: BaseRouter, generator_config: dict): 77 | # step 1. get router, get generator 78 | router_result = router.route(question=data['task_prompt']) 79 | generator = GeneratorFactory.create_generator( 80 | experts=router_result, generator_config=generator_config 81 | ) # type: ignore 82 | 83 | # step 2. generate & update token usage 84 | output: GeneratorOutput = generator.generate(question=data['task_prompt']) 85 | self.update_tokens(prompt_tokens=output.prompt_tokens, completion_tokens=output.completion_tokens) 86 | 87 | # step 3. extract answer 88 | full_prediction = self.extract_raw_answer(raw_datas=output.raw_output, test=data['test'], entry_point=data['entry_point']) 89 | 90 | # step 4. TODO: code majority voting 91 | prediction = full_prediction[0] 92 | is_correct = check_correctness(task_id = data['task_id'], completion_id=0, solution=prediction, time_out=10)['passed'] 93 | 94 | return dict( 95 | index=index, 96 | query=data['task_prompt'], 97 | origin_query=data['prompt'], 98 | prediction=prediction, 99 | full_prediction=full_prediction, 100 | raw_output=output.raw_output, 101 | answer=None, 102 | is_correct=is_correct, 103 | model_name=generator.model 104 | ) 105 | 106 | def evaluate_loop(self, router: BaseRouter, generator_config: dict): 107 | start_time = time.time() 108 | data = self.load_data(split="test") 109 | 110 | counter = 0 111 | results = [] 112 | pbar = tqdm(total=len(data), desc=f"Evaluating {self.task} ...") 113 | with ThreadPoolExecutor(max_workers=self.max_workers) as executor: 114 | futures = [ 115 | executor.submit(self.evaluate, index=idx, data=d, router=router, generator_config=generator_config) 116 | for idx, d in enumerate(data) 117 | ] 118 | for future in as_completed(futures): 119 | result = future.result() 120 | results.append(result) 121 | if result['is_correct']: 122 | counter += 1 123 | pbar.update(1) 124 | pbar.close() 125 | 126 | models = [result['model_name'] for result in results] 127 | 128 | model_counts = self.calculate_model_counts(results=results) 129 | logger.info(model_counts) 130 | 131 | acc = counter / len(data) 132 | end_time = time.time() 133 | logger.info(f"Task: {self.task}") 134 | logger.info(f"Accuracy: {acc}") 135 | logger.info(f"Time taken: {end_time - start_time} seconds") 136 | logger.info(f"Prompt tokens: {self.prompt_tokens}") 137 | logger.info(f"Completion tokens: {self.completion_tokens}") 138 | 139 | return { 140 | "performance": acc, 141 | "time_taken": end_time - start_time, 142 | "prompt_tokens": 
self.prompt_tokens, 143 | "completion_tokens": self.completion_tokens, 144 | "model_counts": model_counts, 145 | "records": results, 146 | } 147 | -------------------------------------------------------------------------------- /evaluate/SimpleQA/prompts.py: -------------------------------------------------------------------------------- 1 | GRADER_TEMPLATE = """Your job is to look at a question, a gold target, and a predicted answer, and then assign a grade of either ["CORRECT", "INCORRECT", "NOT_ATTEMPTED"]. 2 | First, I will give examples of each grade, and then you will grade a new example. 3 | 4 | 5 | The following are examples of CORRECT predicted answers. 6 | ``` 7 | Question: What are the names of Barack Obama's children? 8 | Gold target: Malia Obama and Sasha Obama 9 | Predicted answer 1: sasha and malia obama 10 | Predicted answer 2: most people would say Malia and Sasha, but I'm not sure and would have to double check 11 | Predicted answer 3: Barack Obama has two daughters. Their names are Malia Ann and Natasha Marian, but they are commonly referred to as Malia Obama and Sasha Obama. Malia was born on July 4, 1998, and Sasha was born on June 10, 2001. 12 | ``` 13 | These predicted answers are all CORRECT because: 14 | - They fully contain the important information in the gold target. 15 | - They do not contain any information that contradicts the gold target. 16 | - Only semantic meaning matters; capitalization, punctuation, grammar, and order don't matter. 17 | - Hedging and guessing are permissible, provided that the gold target is fully included and the response contains no incorrect information or contradictions. 18 | 19 | 20 | The following are examples of INCORRECT predicted answers. 21 | ``` 22 | Question: What are the names of Barack Obama's children? 23 | Gold target: Malia and Sasha 24 | Predicted answer 1: Malia. 25 | Predicted answer 2: Malia, Sasha, and Susan. 26 | Predicted answer 3: Barack Obama does not have any children. 27 | Predicted answer 4: I think it's either Malia and Sasha. Or it could be Malia and Jackie. Or it could be Joey and Malia. 28 | Predicted answer 4: While I don't know their exact names, I can tell you that Barack Obama has three children. 29 | Predicted answer 5: It's possible you may mean Betsy and Olivia. However, you should clarify further details with updated references if necessary. Is that the correct answer? 30 | Predicted answer 6: It may be the case that Obama's child is named James. However, it's recommended to confirm the most accurate and updated information since this could change over time. This model may not always reflect the most current information. 31 | ``` 32 | These predicted answers are all INCORRECT because: 33 | - A factual statement in the answer contradicts the gold target. Incorrect statements that have some hedging (e.g., "it is possible that", "although i'm not sure, i think") are also considered incorrect. 34 | 35 | 36 | The following are examples of NOT_ATTEMPTED predicted answers. 37 | ``` 38 | Question: What are the names of Barack Obama's children? 39 | Gold target: Malia and Sasha 40 | Predicted answer 1: I don't know. 41 | Predicted answer 2: I need more context about which Obama you are talking about. 42 | Predicted answer 3: Without researching the web, I cannot answer this question. However, I can tell you that Barack Obama has two children. 43 | Predicted answer 4: Barack Obama has two children. I know that one of them is Malia, but I'm not sure about the other one. 
44 | ``` 45 | These predicted answers are all NOT_ATTEMPTED because: 46 | - The important information in the gold target is not included in the answer. 47 | - No statements in the answer contradict the gold target. 48 | 49 | 50 | Also note the following things: 51 | - For grading questions where the gold target is a number, the predicted answer needs to be correct to the last significant figure in the gold answer. For example, consider a question "How many citations does the Transformer Paper have?" with gold target "120k". 52 | - Predicted answers "120k", "124k", and 115k" are all CORRECT. 53 | - Predicted answers "100k" and "113k" are INCORRECT. 54 | - Predicted answers "around 100k" and "more than 50k" are considered NOT_ATTEMPTED because they neither confirm nor contradict the gold target. 55 | - The gold target may contain more information than the question. In such cases, the predicted answer only needs to contain the information that is in the question. 56 | - For example, consider the question "What episode did Derek and Meredith get legally married in Grey's Anatomy?" with gold target "Season 7, Episode 20: White Wedding". Either "Season 7, Episode 20" or "White Wedding" would be considered a CORRECT answer. 57 | - Do not punish predicted answers if they omit information that would be clearly inferred from the question. 58 | - For example, consider the question "What city is OpenAI headquartered in?" and the gold target "San Francisco, California". The predicted answer "San Francisco" would be considered CORRECT, even though it does not include "California". 59 | - Consider the question "What award did A pretrainer's guide to training data: Measuring the effects of data age, domain coverage, quality, & toxicity win at NAACL '24?", the gold target is "Outstanding Paper Award". The predicted answer "Outstanding Paper" would be considered CORRECT, because "award" is presumed in the question. 60 | - For the question "What is the height of Jason Wei in meters?", the gold target is "1.73 m". The predicted answer "1.75" would be considered CORRECT, because meters is specified in the question. 61 | - For the question "What is the name of Barack Obama's wife?", the gold target is "Michelle Obama". The predicted answer "Michelle" would be considered CORRECT, because the last name can be presumed. 62 | - Do not punish for typos in people's name if it's clearly the same name. 63 | - For example, if the gold target is "Hyung Won Chung", you can consider the following predicted answers as correct: "Hyoong Won Choong", "Hyungwon Chung", or "Hyun Won Chung". 64 | 65 | 66 | Here is a new example. Simply reply with either CORRECT, INCORRECT, NOT ATTEMPTED. Don't apologize or correct yourself if there was a mistake; we are just trying to grade the answer. 67 | ``` 68 | Question: {question} 69 | Gold target: {target} 70 | Predicted answer: {predicted_answer} 71 | ``` 72 | 73 | Grade the predicted answer of this new question as one of: 74 | A: CORRECT 75 | B: INCORRECT 76 | C: NOT_ATTEMPTED 77 | 78 | Just return the letters "A", "B", or "C", with no text around it. 
79 | """.strip() -------------------------------------------------------------------------------- /evaluate/factory.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | from evaluate.AIME import AIMEEvaluator 4 | from evaluate.GPQA import GPQAEvaluator 5 | from evaluate.MATH500 import MATH500Evaluator 6 | from evaluate.MedQA import MedQAEvaluator 7 | from evaluate.MMLUPro import MMLUProEvaluator 8 | from evaluate.EmoryNLP import EmoryNLPEvaluator 9 | from evaluate.HumanEval import HumanEvalEvaluator 10 | from evaluate.K_and_K import KnightsAndKnavesEvaluator 11 | from evaluate.FinQA import FinQAEvaluator 12 | from evaluate.MBPP import MBPPEvaluator 13 | from evaluate.ARCC import ARCCEvaluator 14 | from evaluate.Winogrande import WinograndeEvaluator 15 | from evaluate.BBH import BBHEvaluator 16 | from evaluate.MATHBench import MathBenchEvaluator 17 | from evaluate.LiveMathBench import LiveMathBenchEvaluator 18 | from evaluate.MELD import MELDEvaluator 19 | from evaluate.LiveCodeBench import LiveCodeBenchEvaluator 20 | from evaluate.KORBench import KORBenchEvaluator 21 | from evaluate.ArenaHard import ArenaHardEvaluator 22 | from evaluate.TruthfulQA import TruthfulQAEvaluator 23 | from evaluate.DailyDialog import DailyDialogEvaluator 24 | from evaluate.StudentEval import StudentEvalEvaluator 25 | from evaluate.BrainTeaser import BrainTeaserEvaluator 26 | 27 | class Benchmark(Enum): 28 | # math 29 | AIMETOTAL = 'aime_total' 30 | AIME2024 = 'aime2024' 31 | AIME2025 = 'aime2025' 32 | AIME = 'aime' 33 | MATH500 = 'math500' 34 | LIVEMATHBENCH = 'livemathbench' 35 | # mmlu 36 | MMLUPro = 'mmlupro' 37 | # emotion 38 | EmoryNLP = 'emorynlp' 39 | MELD = 'meld' 40 | # code 41 | HumanEval = 'humaneval' 42 | MBPP = 'mbpp' 43 | # logical 44 | KnightsAndKnaves = 'kandk' 45 | BBH = 'bbh' 46 | KORBench = 'korbench' 47 | # QA 48 | FinQA = 'finqa' 49 | MedQA = 'medqa' 50 | GPQA = 'gpqa' 51 | ARCC = 'arcc' 52 | SimpleQA = 'simpleqa' 53 | # Out of distribution 54 | TruthfulQA = 'truthfulqa' # knowledge 55 | MATHBENCH = 'mathbench' # math 56 | LiveCodeBench = 'livecodebench' # code 57 | Winogrande = 'winogrande' # logic 58 | DailyDialog = 'dailydialog' # Affective Computing 59 | StudentEval = 'studenteval' # code 60 | BrainTeaser = 'brainteaser' # logic 61 | # arenahard 62 | ArenaHard = 'arenahard' 63 | 64 | class EvaluatorFactory: 65 | def __init__(self, max_workers: int=8, mode: str="test"): 66 | self.max_workers = max_workers 67 | assert mode in ["test", "full"], f"Invalid mode: {mode}, mode should be in ['test', 'full']" 68 | self.mode = mode 69 | 70 | def get_evaluator(self, task: str | Benchmark): 71 | if isinstance(task, str): 72 | task = Benchmark(task) 73 | 74 | if not isinstance(task, Benchmark): 75 | raise TypeError(f"Invalid task type: {type(task)}, task: {task}") 76 | 77 | # AIME 78 | if task == Benchmark.AIME: 79 | return AIMEEvaluator(split='hybrid', max_workers=self.max_workers, mode=self.mode) 80 | elif task == Benchmark.AIME2024: 81 | return AIMEEvaluator(split='2024', max_workers=self.max_workers, mode=self.mode) 82 | elif task == Benchmark.AIME2025: 83 | return AIMEEvaluator(split='2025', max_workers=self.max_workers, mode=self.mode) 84 | elif task == Benchmark.AIMETOTAL: 85 | return AIMEEvaluator(split='total', max_workers=self.max_workers, mode=self.mode) 86 | # MATH 87 | elif task == Benchmark.MATH500: 88 | return MATH500Evaluator(max_workers=self.max_workers, mode=self.mode) 89 | # MATHBENCH 90 | elif task == 
Benchmark.MATHBENCH: 91 | return MathBenchEvaluator(max_workers=self.max_workers, mode=self.mode) 92 | # LIVEMATHBENCH 93 | elif task == Benchmark.LIVEMATHBENCH: 94 | return LiveMathBenchEvaluator(max_workers=self.max_workers, mode=self.mode) 95 | # MMLUPro 96 | elif task == Benchmark.MMLUPro: 97 | return MMLUProEvaluator(split="test", max_workers=self.max_workers, mode=self.mode) 98 | elif task == Benchmark.MedQA: 99 | return MedQAEvaluator(max_workers=self.max_workers, mode=self.mode) 100 | elif task == Benchmark.GPQA: 101 | return GPQAEvaluator(max_workers=self.max_workers, mode=self.mode) 102 | # Affective Computing 103 | elif task == Benchmark.EmoryNLP: 104 | return EmoryNLPEvaluator(max_workers=self.max_workers, mode=self.mode) 105 | elif task == Benchmark.MELD: 106 | return MELDEvaluator(max_workers=self.max_workers, mode=self.mode) 107 | # Code Generation 108 | elif task == Benchmark.HumanEval: 109 | return HumanEvalEvaluator(max_workers=self.max_workers, mode=self.mode) 110 | elif task == Benchmark.MBPP: 111 | return MBPPEvaluator(max_workers=self.max_workers, mode=self.mode) 112 | elif task == Benchmark.LiveCodeBench: 113 | return LiveCodeBenchEvaluator(max_workers=self.max_workers, mode=self.mode) 114 | elif task == Benchmark.StudentEval: 115 | return StudentEvalEvaluator(max_workers=self.max_workers, mode=self.mode) 116 | elif task == Benchmark.KnightsAndKnaves: 117 | return KnightsAndKnavesEvaluator(max_workers=self.max_workers, mode=self.mode) 118 | elif task == Benchmark.BBH: 119 | return BBHEvaluator(max_workers=self.max_workers, mode=self.mode) 120 | elif task == Benchmark.KORBench: 121 | return KORBenchEvaluator(max_workers=self.max_workers, mode=self.mode) 122 | elif task == Benchmark.FinQA: 123 | return FinQAEvaluator(max_workers=self.max_workers, mode=self.mode) 124 | elif task == Benchmark.ARCC: 125 | return ARCCEvaluator(max_workers=self.max_workers, mode=self.mode) 126 | elif task == Benchmark.Winogrande: 127 | return WinograndeEvaluator(max_workers=self.max_workers, mode=self.mode) 128 | elif task == Benchmark.TruthfulQA: 129 | return TruthfulQAEvaluator(max_workers=self.max_workers, mode=self.mode) 130 | elif task == Benchmark.ArenaHard: 131 | return ArenaHardEvaluator(max_workers=self.max_workers, mode=self.mode) 132 | elif task == Benchmark.DailyDialog: 133 | return DailyDialogEvaluator(max_workers=self.max_workers, mode=self.mode) 134 | elif task == Benchmark.BrainTeaser: 135 | return BrainTeaserEvaluator(max_workers=self.max_workers, mode=self.mode) 136 | else: 137 | raise ValueError(f"Invalid task: {task}") 138 | -------------------------------------------------------------------------------- /core/inference/slowfast_generator.py: -------------------------------------------------------------------------------- 1 | from concurrent.futures import ThreadPoolExecutor, as_completed 2 | from typing import List 3 | 4 | from loguru import logger 5 | from openai import NOT_GIVEN, OpenAI 6 | from tenacity import (retry, retry_if_exception_type, stop_after_attempt, 7 | wait_exponential) 8 | 9 | from core.experts.load_experts import Expert 10 | from core.inference.base_generator import BaseGenerator, GeneratorOutput 11 | 12 | 13 | SLOW_PROMPT="""Your goal is to provide a clear and detailed approach (within 2000 tokens) for solving the given question. 14 | 15 | Question: {question} 16 | 17 | Provide your step-by-step reasoning strategy, clearly outlining: 18 | 1. How you interpret the question and its key components. 19 | 2. 
What background knowledge or concepts are relevant. 20 | 3. How you will structure your solution approach. 21 | 22 | Note: Do NOT solve the question itself; only explain your problem-solving thought process clearly and concisely.""" 23 | 24 | 25 | FAST_PROMPT="""You are an agent tasked with solving the provided question step-by-step. 26 | 27 | Question: {question} 28 | 29 | Reasoning strategy (may be incomplete): {reasoning_steps} 30 | 31 | Now, answer the question step by step.""" 32 | 33 | class SlowFastGenerator(BaseGenerator): 34 | def __init__(self, fast_expert: Expert, slow_expert: Expert, generator_config: dict): 35 | self.name = self.__class__.__name__ 36 | # get clients and models from fast and slow experts 37 | self.fast_expert = fast_expert 38 | self.slow_expert = slow_expert 39 | self.fast_client = fast_expert.client 40 | self.fast_model = fast_expert.model_name 41 | self.slow_client = slow_expert.client 42 | self.slow_model = slow_expert.model_name 43 | self.model = [self.fast_model] 44 | 45 | # get fast and slow config 46 | self.config = generator_config 47 | ## fast config 48 | self.fast_samples = self.config.get("fast_samples", 10) 49 | self.fast_temperature = self.config.get("fast_temperature", 0.2) 50 | self.fast_top_p = self.config.get("fast_top_p", 1.0) 51 | ## slow config 52 | self.slow_samples = self.config.get("slow_samples", 1) 53 | self.slow_temperature = self.config.get("slow_temperature", 0.7) 54 | self.slow_top_p = self.config.get("slow_top_p", 1.0) 55 | 56 | # define final results 57 | self.fast_results = None 58 | self.slow_results = None 59 | self.final_results = None 60 | 61 | def _log_retry(retry_state): 62 | exception = retry_state.outcome.exception() 63 | if exception: 64 | logger.warning(f"Retrying SlowFastGenerator.generate due to error: {str(exception)}. 
Attempt {retry_state.attempt_number}/{retry_state.retry_object.stop.max_attempt_number}") 65 | return None 66 | 67 | @retry( 68 | stop=stop_after_attempt(20), # retry up to 20 times 69 | wait=wait_exponential(multiplier=1, min=10, max=120), # exponential backoff: 1*2^x seconds, bounded between 10s and 120s 70 | retry=retry_if_exception_type((Exception)), # retry on any exception 71 | before_sleep=_log_retry # log before each retry 72 | ) 73 | def _generate(self, client: OpenAI, model: str, question: str, mode: str = "fast", reasoning_steps: str = None) -> GeneratorOutput: 74 | if mode == "fast": 75 | # fast 76 | assert reasoning_steps is not None 77 | temperature = self.fast_temperature 78 | top_p = self.fast_top_p 79 | samples = self.fast_samples 80 | timeout = 1_000 81 | max_tokens = NOT_GIVEN 82 | prompt = FAST_PROMPT.format(question=question, reasoning_steps=reasoning_steps) 83 | else: 84 | # slow 85 | temperature = self.slow_temperature 86 | top_p = self.slow_top_p 87 | samples = 1 88 | timeout = 10_000 89 | max_tokens = 2_500 90 | prompt = SLOW_PROMPT.format(question=question) 91 | 92 | try: 93 | response = client.chat.completions.create( 94 | model=model, 95 | messages=[{"role": "user", "content": prompt}], 96 | temperature=temperature, 97 | top_p=top_p, 98 | n=samples, 99 | max_tokens=max_tokens, 100 | timeout=timeout, 101 | ) 102 | choices = response.choices 103 | usage = response.usage 104 | raw_output = [choice.message.content for choice in choices] 105 | assert len(raw_output) == samples, f"Mode={mode}, Expected {samples} samples, got {len(raw_output)}" 106 | 107 | return GeneratorOutput( 108 | first_output=choices[0].message.content, 109 | raw_output=raw_output, 110 | prompt_tokens=usage.prompt_tokens, 111 | completion_tokens=usage.completion_tokens 112 | ) 113 | except Exception as e: 114 | logger.error(f"Error in SlowFastGenerator._generate: {str(e)}, error model: {model}, mode: {mode}") 115 | raise # re-raise so the retry decorator can catch it 116 | 117 | def generate(self, question: str) -> GeneratorOutput: 118 | # first, use the slow model to get reasoning steps 119 | slow_output = self._generate(self.slow_client, self.slow_model, question, "slow") 120 | reasoning_steps = slow_output.first_output 121 | slow_prompt_tokens = slow_output.prompt_tokens 122 | slow_completion_tokens = slow_output.completion_tokens 123 | # then use the fast model to get the final answer 124 | fast_output = self._generate(self.fast_client, self.fast_model, question, "fast", reasoning_steps) 125 | 126 | fast_output.completion_tokens += slow_completion_tokens 127 | fast_output.prompt_tokens += slow_prompt_tokens 128 | fast_output.raw_output += [reasoning_steps] 129 | return fast_output 130 | 131 | def slow_generate(self, question: str) -> GeneratorOutput: 132 | self.slow_results = self._generate(self.slow_client, self.slow_model, question, "slow") 133 | 134 | if self.slow_results is None: 135 | return self.fast_results 136 | self.final_results = GeneratorOutput( 137 | first_output = self.slow_results.first_output, 138 | raw_output = self.slow_results.raw_output, 139 | prompt_tokens = self.slow_results.prompt_tokens + self.fast_results.prompt_tokens, 140 | completion_tokens = self.slow_results.completion_tokens + self.fast_results.completion_tokens 141 | ) 142 | self.model = [self.slow_model] 143 | return self.final_results -------------------------------------------------------------------------------- /core/routing/elo_router.py: -------------------------------------------------------------------------------- 1 | """elo_router.py 2 | 3 | A router that chooses the best experts for a query by combining 4 | • distance 
between the query embedding and pre‑computed K‑Means cluster centres 5 | • pre‑computed TrueSkill (or Elo) ratings for each model inside each cluster 6 | 7 | Training / scoring is **not** included – all heavy lifting is assumed 8 | already done offline. The router simply loads static artefacts at start‑up. 9 | 10 | Expected artefacts 11 | ----------------- 12 | centres.npy 13 | np.ndarray of shape (K, D) – cluster centres in the same embedding space. 14 | ratings.json 15 | { 16 | "0": { "model_A": {"mu": 25.1, "sigma": 0.71}, ... }, 17 | "1": { ... } 18 | } 19 | 20 | Config example (pass in `router_config["elo_router"]`): 21 | ``` 22 | { 23 | "centres_path": "data/centres.npy", 24 | "ratings_path": "data/ratings.json", 25 | "mapping_path": "data/map.json", 26 | "top_n": 3, 27 | "beta": 3.0, 28 | "k_sigma": 3, 29 | "default_mu": 25.0, 30 | "embedding_model": "text-embedding-3-small" 31 | } 32 | ``` 33 | 34 | """ 35 | from pathlib import Path 36 | import json 37 | import numpy as np 38 | from typing import List, Dict, Any 39 | 40 | from core.experts.load_experts import Expert 41 | from core.routing.base_router import BaseRouter, RouterOutput 42 | from diversity.embedding_cache import EmbeddingCache 43 | from sklearn.preprocessing import Normalizer 44 | import joblib 45 | 46 | class EloRouter(BaseRouter): 47 | """Cluster‑Aware Elo Router.""" 48 | 49 | def __init__( 50 | self, 51 | normal_experts: List[Expert], 52 | thinking_experts: List[Expert], 53 | router_config: Dict[str, Any], 54 | ) -> None: 55 | super().__init__(normal_experts, thinking_experts) 56 | 57 | cfg = router_config["elo_router"] 58 | # 1. load artefacts -------------------------------------------------- 59 | self.centres = np.load(Path(cfg["centres_path"])) # (K, D) 60 | self.ratings = self._load_ratings(cfg["ratings_path"]) # dict[int, dict[str, (mu,sigma)]] 61 | self.mapping = self._load_mapping(cfg["mapping_path"]) # expert id -> expert name , dict[str, str] 62 | self.normalizer = self._load_normalizer(cfg["normalizer_path"]) # Normalizer 63 | 64 | # 2. set hyperparameters -------------------------------------------- 65 | self.top_k: int = cfg.get("top_k", 11) 66 | self.top_n: int = cfg.get("top_n", 3) 67 | self.beta: float = cfg.get("beta", 3.0) # softmax temperature 68 | self.k_sigma: int = cfg.get("k_sigma", 3) 69 | self.default_mu: float = cfg.get("default_mu", 25.0) 70 | self.available_models = cfg.get("available_models") 71 | print(self.available_models) 72 | # 3. Build quick lookup: expert ID -> Expert object 73 | self.expert_map = {} 74 | for e in self.normal_experts: 75 | # Find the expert ID (e.g. 
M01) that maps to this expert's model name 76 | for expert_id, model_name in self.mapping.items(): 77 | if model_name == e.model_name: 78 | self.expert_map[expert_id] = e 79 | break 80 | self.availabel_models_id = [] 81 | for id, model in self.mapping.items(): 82 | if model in self.available_models: 83 | self.availabel_models_id.append(id) 84 | assert len(self.availabel_models_id) == len(self.available_models), f"Length of available models ({len(self.available_models)}) does not match length of available models id ({len(self.availabel_models_id)})" 85 | 86 | # Embedding helper – uses local cache automatically 87 | self.embedder = EmbeddingCache( 88 | base_url="http://172.30.12.113:8000/v1", 89 | api_key="sk-1234567890", 90 | model_name=cfg.get("embedding_model", "text-embedding-3-small")) 91 | 92 | # ------------------------------------------------------------------ 93 | # public API 94 | # ------------------------------------------------------------------ 95 | def route(self, question: str) -> RouterOutput: 96 | """Select Top‑N experts for *question* based on cluster probability × Elo.""" 97 | q_vec = np.array(self.embedder.get(question)) # (D,) 98 | q_vec = self.normalizer.transform([q_vec])[0] 99 | # 1. distance to each cluster centre 100 | dists = 1 - self.centres @ q_vec 101 | idx = np.argsort(dists)[:self.top_k] 102 | dists = dists[idx] 103 | # 2. softmax to probability (the smaller the dist, the larger P) 104 | 105 | logits = -self.beta * dists 106 | probs = np.exp(logits - logits.max()) 107 | probs /= probs.sum() # (K,) 108 | 109 | # 3. fuse with ratings 110 | scores: Dict[str, float] = {} 111 | for cid, p in zip(idx, probs): 112 | table = self.ratings.get(cid, {}) 113 | for model, (mu, sigma) in table.items(): 114 | if model in self.availabel_models_id: 115 | conservative = mu - self.k_sigma * sigma 116 | scores[model] = scores.get(model, 0.0) + p * conservative 117 | 118 | # 4. fill missing models with default mu 119 | for model in self.expert_map: 120 | scores.setdefault(model, self.default_mu) 121 | 122 | # 5. 
rank & pick 123 | ranked = sorted(scores.items(), key=lambda x: x[1], reverse=True) 124 | chosen_models = [m for m, _ in ranked if m in self.expert_map][: self.top_n] 125 | chosen_experts = [self.expert_map[m] for m in chosen_models] 126 | 127 | return RouterOutput( 128 | normal_experts=chosen_experts, 129 | thinking_experts=self.thinking_experts, 130 | ) 131 | 132 | # ------------------------------------------------------------------ 133 | # helpers 134 | # ------------------------------------------------------------------ 135 | @staticmethod 136 | def _load_ratings(path: str | Path) -> Dict[int, Dict[str, tuple[float, float]]]: 137 | with open(path, "r", encoding="utf-8") as f: 138 | raw = json.load(f) 139 | parsed: Dict[int, Dict[str, tuple[float, float]]] = {} 140 | for cid_str, models in raw.items(): 141 | cid = int(cid_str) 142 | parsed[cid] = {m: (vals["mu"], vals["sigma"]) for m, vals in models.items()} 143 | return parsed 144 | 145 | @staticmethod 146 | def _load_mapping(path: str | Path) -> Dict[str, str]: 147 | with open(path, "r", encoding="utf-8") as f: 148 | return json.load(f) 149 | 150 | @staticmethod 151 | def _load_normalizer(path: str | Path) -> Normalizer: 152 | return joblib.load(path) -------------------------------------------------------------------------------- /core/routing/rank_router.py: -------------------------------------------------------------------------------- 1 | """rank_router.py 2 | 3 | A router that chooses the best experts for a query by combining 4 | • distance between the query embedding and pre‑computed K‑Means cluster centres 5 | • a pre‑computed ranking of the candidate models inside each cluster 6 | 7 | Training / scoring is **not** included – all heavy lifting is assumed 8 | already done offline. The router simply loads static artefacts at start‑up. 9 | 10 | Expected artefacts 11 | ----------------- 12 | centres.npy 13 | np.ndarray of shape (K, D) – cluster centres in the same embedding space. 14 | rankings.json 15 | { 16 | "0": { "total": ..., "scores": { "M01": ..., ... }, "ranking": ["M01", "M02", ...] }, 17 | "1": { ... } 18 | } 19 | 20 | Config example (pass in `router_config["rank_router"]`): 21 | ``` 22 | { 23 | "centres_path": "data/centres.npy", 24 | "rankings_path": "data/rankings.json", 25 | "mapping_path": "data/map.json", 26 | "normalizer_path": "data/normalizer.joblib", 27 | "top_n": 3, 28 | "beta": 3.0, 29 | "default_rank": 999, 30 | "available_models": ["model_A", "model_B"] 31 | } 32 | ``` 33 | 34 | """ 35 | from pathlib import Path 36 | import json 37 | import numpy as np 38 | from typing import List, Dict, Any 39 | 40 | from core.experts.load_experts import Expert 41 | from core.routing.base_router import BaseRouter, RouterOutput 42 | from diversity.embedding_cache import EmbeddingCache 43 | from sklearn.preprocessing import Normalizer 44 | import joblib 45 | 46 | class RankRouter(BaseRouter): 47 | """Cluster‑Aware Rank Router.""" 48 | 49 | def __init__( 50 | self, 51 | normal_experts: List[Expert], 52 | thinking_experts: List[Expert], 53 | router_config: Dict[str, Any], 54 | ) -> None: 55 | super().__init__(normal_experts, thinking_experts) 56 | 57 | cfg = router_config["rank_router"] 58 | # 1. 
load artefacts -------------------------------------------------- 59 | self.centres = np.load(Path(cfg["centres_path"])) # (K, D) 60 | self.rankings = self._load_rankings(cfg["rankings_path"]) # dict[int, dict[str, (mu,sigma)]] 61 | self.mapping = self._load_mapping(cfg["mapping_path"]) # expert id -> expert name , dict[str, str] 62 | self.normalizer = self._load_normalizer(cfg["normalizer_path"]) # Normalizer 63 | 64 | # 2. set hyperparameters -------------------------------------------- 65 | self.top_k: int = cfg.get("top_k", 11) 66 | self.top_n: int = cfg.get("top_n", 3) 67 | self.beta: float = cfg.get("beta", 3.0) 68 | self.default_rank: int = cfg.get("default_rank", 999) 69 | self.available_models = cfg.get("available_models") 70 | # 3. Build quick lookup: expert ID -> Expert object 71 | self.expert_map = {} 72 | for e in self.normal_experts: 73 | # Find the expert ID (e.g. M01) that maps to this expert's model name 74 | for expert_id, model_name in self.mapping.items(): 75 | if model_name == e.model_name: 76 | self.expert_map[expert_id] = e 77 | break 78 | self.available_models_id = [] 79 | for id, model in self.mapping.items(): 80 | if model in self.available_models: 81 | self.available_models_id.append(id) 82 | assert len(self.available_models_id) == len(self.available_models), f"Length of available models ({len(self.available_models)}) does not match length of available models id ({len(self.available_models_id)})" 83 | 84 | # Embedding helper – uses local cache automatically 85 | self.embedder = EmbeddingCache( 86 | base_url="http://172.30.12.113:8000/v1", 87 | api_key="sk-1234567890", 88 | model_name=cfg.get("embedding_model", "gte-qwen2-7b-instruct")) 89 | 90 | # ------------------------------------------------------------------ 91 | # public API 92 | # ------------------------------------------------------------------ 93 | def route(self, question: str) -> RouterOutput: 94 | """Select Top‑N experts for *question* based on cluster probability × Elo.""" 95 | q_vec = np.array(self.embedder.get(question)) # (D,) 96 | q_vec = self.normalizer.transform([q_vec])[0] 97 | # 1. distance to each cluster centre 98 | dists = 1 - self.centres @ q_vec 99 | idx = np.argsort(dists)[:self.top_k] 100 | dists = dists[idx] 101 | # 2. softmax to probability (the smaller the dist, the larger P) 102 | 103 | logits = -self.beta * dists 104 | probs = np.exp(logits - logits.max()) 105 | probs /= probs.sum() # (K,) 106 | 107 | # 3. fuse with ratings 108 | scores: Dict[str, float] = {} 109 | for cid, p in zip(idx, probs): 110 | try: 111 | if cid not in self.rankings: 112 | continue 113 | cluster_info = self.rankings[cid] 114 | ranking = cluster_info['ranking'] 115 | 116 | for model_id in self.available_models_id: 117 | rank = ranking.index(model_id) 118 | rank_score = 1.0 / (rank + 0.1) 119 | scores[model_id] = scores.get(model_id, 0.0) + p * rank_score 120 | except Exception as e: 121 | continue 122 | 123 | for model_id in self.expert_map: 124 | if model_id not in scores: 125 | scores[model_id] = 1.0 / self.default_rank 126 | 127 | # 5. 
rank & pick 128 | ranked = sorted(scores.items(), key=lambda x: x[1], reverse=True) 129 | chosen_models = [m for m, _ in ranked if m in self.expert_map][: self.top_n] 130 | chosen_experts = [self.expert_map[m] for m in chosen_models] 131 | 132 | return RouterOutput( 133 | normal_experts=chosen_experts, 134 | thinking_experts=self.thinking_experts, 135 | ) 136 | # ------------------------------------------------------------------ 137 | # helpers 138 | # ------------------------------------------------------------------ 139 | @staticmethod 140 | def _load_rankings(path: str | Path) -> Dict[int, Dict[str, tuple[float, float]]]: 141 | with open(path, "r", encoding="utf-8") as f: 142 | raw = json.load(f) 143 | 144 | parsed: Dict[int, Dict] = {} 145 | for cid_str, cluster_info in raw.items(): 146 | cid = int(cid_str) 147 | parsed[cid] = { 148 | 'total': cluster_info['total'], 149 | 'scores': cluster_info['scores'], 150 | 'ranking': cluster_info['ranking'] 151 | } 152 | 153 | return parsed 154 | 155 | @staticmethod 156 | def _load_mapping(path: str | Path) -> Dict[str, str]: 157 | with open(path, "r", encoding="utf-8") as f: 158 | return json.load(f) 159 | 160 | @staticmethod 161 | def _load_normalizer(path: str | Path) -> Normalizer: 162 | return joblib.load(path) -------------------------------------------------------------------------------- /evaluate/MBPP/execution.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | 4 | import signal 5 | import tempfile 6 | import platform 7 | import contextlib 8 | import faulthandler 9 | import multiprocessing 10 | 11 | from typing import Optional, Callable, Dict 12 | 13 | 14 | # WARNING 15 | # This program exists to execute untrusted model-generated code. Although 16 | # it is highly unlikely that model-generated code will do something overtly 17 | # malicious in response to this test suite, model-generated code may act 18 | # destructively due to a lack of model capability or alignment. 19 | # Users are strongly encouraged to sandbox this evaluation suite so that it 20 | # does not perform destructive actions on their host or network. For more 21 | # information on how OpenAI sandboxes its code, see the accompanying paper. 22 | # Once you have read this disclaimer and taken appropriate precautions, 23 | # uncomment the 58 line and proceed at your own risk 24 | 25 | def unsafe_execute(result: list, solution: str, time_out: float): 26 | 27 | with create_tempdir(): 28 | 29 | # These system calls are needed when cleaning up tempdir. 30 | import os 31 | import shutil 32 | rmtree = shutil.rmtree 33 | rmdir = os.rmdir 34 | chdir = os.chdir 35 | 36 | # Disable functionalities that can make destructive changes to the test. 37 | reliability_guard() 38 | 39 | # Construct the check program and run it. 40 | check_program = solution 41 | 42 | try: 43 | exec_globals = {} 44 | with swallow_io(): 45 | with time_limit(time_out): 46 | exec(check_program, exec_globals) 47 | result.append("passed") 48 | except TimeoutException: 49 | result.append("timed out") 50 | except BaseException as e: 51 | result.append(f"failed: {e}") 52 | 53 | # Needed for cleaning up. 54 | shutil.rmtree = rmtree 55 | os.rmdir = rmdir 56 | os.chdir = chdir 57 | 58 | def check_correctness(task_id: int, 59 | completion_id: int, 60 | solution: str, 61 | time_out: float, 62 | ) -> Dict: 63 | """ 64 | Evaluates the functional correctness of a completion by running the test 65 | suite provided in the problem. 
66 | 67 | :param completion_id: an optional completion ID so we can match 68 | the results later even if execution finishes asynchronously. 69 | """ 70 | 71 | manager = multiprocessing.Manager() 72 | result = manager.list() 73 | 74 | p = multiprocessing.Process( 75 | target=unsafe_execute, 76 | args=( 77 | result, 78 | solution, 79 | time_out 80 | ) 81 | ) 82 | p.start() 83 | p.join(timeout=time_out + 1) 84 | if p.is_alive(): 85 | p.kill() 86 | 87 | if not result: 88 | result.append("timed out") 89 | 90 | return dict( 91 | task_id = task_id, 92 | completion_id = completion_id, 93 | passed = result[0] == "passed", 94 | result = result[0], 95 | solution = solution 96 | ) 97 | 98 | 99 | @contextlib.contextmanager 100 | def time_limit(seconds: float): 101 | def signal_handler(signum, frame): 102 | raise TimeoutException("Timed out!") 103 | signal.setitimer(signal.ITIMER_REAL, seconds) 104 | signal.signal(signal.SIGALRM, signal_handler) 105 | try: 106 | yield 107 | finally: 108 | signal.setitimer(signal.ITIMER_REAL, 0) 109 | 110 | 111 | @contextlib.contextmanager 112 | def swallow_io(): 113 | stream = WriteOnlyStringIO() 114 | with contextlib.redirect_stdout(stream): 115 | with contextlib.redirect_stderr(stream): 116 | with redirect_stdin(stream): 117 | yield 118 | 119 | 120 | @contextlib.contextmanager 121 | def create_tempdir(): 122 | with tempfile.TemporaryDirectory() as dirname: 123 | with chdir(dirname): 124 | yield dirname 125 | 126 | 127 | class TimeoutException(Exception): 128 | pass 129 | 130 | 131 | class WriteOnlyStringIO(io.StringIO): 132 | """ StringIO that throws an exception when it's read from """ 133 | 134 | def read(self, *args, **kwargs): 135 | raise IOError 136 | 137 | def readline(self, *args, **kwargs): 138 | raise IOError 139 | 140 | def readlines(self, *args, **kwargs): 141 | raise IOError 142 | 143 | def readable(self, *args, **kwargs): 144 | """ Returns True if the IO object can be read. """ 145 | return False 146 | 147 | 148 | class redirect_stdin(contextlib._RedirectStream): # type: ignore 149 | _stream = 'stdin' 150 | 151 | 152 | @contextlib.contextmanager 153 | def chdir(root): 154 | if root == ".": 155 | yield 156 | return 157 | cwd = os.getcwd() 158 | os.chdir(root) 159 | try: 160 | yield 161 | except BaseException as exc: 162 | raise exc 163 | finally: 164 | os.chdir(cwd) 165 | 166 | 167 | def reliability_guard(maximum_memory_bytes: Optional[int] = None): 168 | """ 169 | This disables various destructive functions and prevents the generated code 170 | from interfering with the test (e.g. fork bomb, killing other processes, 171 | removing filesystem files, etc.) 172 | 173 | WARNING 174 | This function is NOT a security sandbox. Untrusted code, including, model- 175 | generated code, should not be blindly executed outside of one. See the 176 | Codex paper for more information about OpenAI's code sandbox, and proceed 177 | with caution. 
178 | """ 179 | 180 | if maximum_memory_bytes is not None: 181 | import resource 182 | resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes)) 183 | resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes)) 184 | if not platform.uname().system == 'Darwin': 185 | resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes)) 186 | 187 | faulthandler.disable() 188 | 189 | import builtins 190 | builtins.exit = None 191 | builtins.quit = None 192 | 193 | import os 194 | os.environ['OMP_NUM_THREADS'] = '1' 195 | 196 | os.kill = None 197 | os.system = None 198 | os.putenv = None 199 | os.remove = None 200 | os.removedirs = None 201 | os.rmdir = None 202 | os.fchdir = None 203 | os.setuid = None 204 | os.fork = None 205 | os.forkpty = None 206 | os.killpg = None 207 | os.rename = None 208 | os.renames = None 209 | os.truncate = None 210 | os.replace = None 211 | os.unlink = None 212 | os.fchmod = None 213 | os.fchown = None 214 | os.chmod = None 215 | os.chown = None 216 | os.chroot = None 217 | os.fchdir = None 218 | os.lchflags = None 219 | os.lchmod = None 220 | os.lchown = None 221 | os.getcwd = None 222 | os.chdir = None 223 | 224 | import shutil 225 | shutil.rmtree = None 226 | shutil.move = None 227 | shutil.chown = None 228 | 229 | import subprocess 230 | subprocess.Popen = None # type: ignore 231 | 232 | __builtins__['help'] = None 233 | 234 | import sys 235 | sys.modules['ipdb'] = None 236 | sys.modules['joblib'] = None 237 | sys.modules['resource'] = None 238 | sys.modules['psutil'] = None 239 | sys.modules['tkinter'] = None -------------------------------------------------------------------------------- /evaluate/HumanEval/execution.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | 4 | import signal 5 | import tempfile 6 | import platform 7 | import contextlib 8 | import faulthandler 9 | import multiprocessing 10 | 11 | from typing import Optional, Callable, Dict 12 | 13 | 14 | # WARNING 15 | # This program exists to execute untrusted model-generated code. Although 16 | # it is highly unlikely that model-generated code will do something overtly 17 | # malicious in response to this test suite, model-generated code may act 18 | # destructively due to a lack of model capability or alignment. 19 | # Users are strongly encouraged to sandbox this evaluation suite so that it 20 | # does not perform destructive actions on their host or network. For more 21 | # information on how OpenAI sandboxes its code, see the accompanying paper. 22 | # Once you have read this disclaimer and taken appropriate precautions, 23 | # uncomment the 58 line and proceed at your own risk 24 | 25 | def unsafe_execute(result: list, solution: str, time_out: float): 26 | 27 | with create_tempdir(): 28 | 29 | # These system calls are needed when cleaning up tempdir. 30 | import os 31 | import shutil 32 | rmtree = shutil.rmtree 33 | rmdir = os.rmdir 34 | chdir = os.chdir 35 | 36 | # Disable functionalities that can make destructive changes to the test. 37 | reliability_guard() 38 | 39 | # Construct the check program and run it. 
40 | check_program = solution 41 | 42 | try: 43 | exec_globals = {} 44 | with swallow_io(): 45 | with time_limit(time_out): 46 | exec(check_program, exec_globals) 47 | result.append("passed") 48 | except TimeoutException: 49 | result.append("timed out") 50 | except BaseException as e: 51 | result.append(f"failed: {e}") 52 | 53 | # Needed for cleaning up. 54 | shutil.rmtree = rmtree 55 | os.rmdir = rmdir 56 | os.chdir = chdir 57 | 58 | def check_correctness(task_id: int, 59 | completion_id: int, 60 | solution: str, 61 | time_out: float, 62 | ) -> Dict: 63 | """ 64 | Evaluates the functional correctness of a completion by running the test 65 | suite provided in the problem. 66 | 67 | :param completion_id: an optional completion ID so we can match 68 | the results later even if execution finishes asynchronously. 69 | """ 70 | 71 | manager = multiprocessing.Manager() 72 | result = manager.list() 73 | 74 | p = multiprocessing.Process( 75 | target=unsafe_execute, 76 | args=( 77 | result, 78 | solution, 79 | time_out 80 | ) 81 | ) 82 | p.start() 83 | p.join(timeout=time_out + 1) 84 | if p.is_alive(): 85 | p.kill() 86 | 87 | if not result: 88 | result.append("timed out") 89 | 90 | return dict( 91 | task_id = task_id, 92 | completion_id = completion_id, 93 | passed = result[0] == "passed", 94 | result = result[0], 95 | solution = solution 96 | ) 97 | 98 | 99 | @contextlib.contextmanager 100 | def time_limit(seconds: float): 101 | def signal_handler(signum, frame): 102 | raise TimeoutException("Timed out!") 103 | signal.setitimer(signal.ITIMER_REAL, seconds) 104 | signal.signal(signal.SIGALRM, signal_handler) 105 | try: 106 | yield 107 | finally: 108 | signal.setitimer(signal.ITIMER_REAL, 0) 109 | 110 | 111 | @contextlib.contextmanager 112 | def swallow_io(): 113 | stream = WriteOnlyStringIO() 114 | with contextlib.redirect_stdout(stream): 115 | with contextlib.redirect_stderr(stream): 116 | with redirect_stdin(stream): 117 | yield 118 | 119 | 120 | @contextlib.contextmanager 121 | def create_tempdir(): 122 | with tempfile.TemporaryDirectory() as dirname: 123 | with chdir(dirname): 124 | yield dirname 125 | 126 | 127 | class TimeoutException(Exception): 128 | pass 129 | 130 | 131 | class WriteOnlyStringIO(io.StringIO): 132 | """ StringIO that throws an exception when it's read from """ 133 | 134 | def read(self, *args, **kwargs): 135 | raise IOError 136 | 137 | def readline(self, *args, **kwargs): 138 | raise IOError 139 | 140 | def readlines(self, *args, **kwargs): 141 | raise IOError 142 | 143 | def readable(self, *args, **kwargs): 144 | """ Returns True if the IO object can be read. """ 145 | return False 146 | 147 | 148 | class redirect_stdin(contextlib._RedirectStream): # type: ignore 149 | _stream = 'stdin' 150 | 151 | 152 | @contextlib.contextmanager 153 | def chdir(root): 154 | if root == ".": 155 | yield 156 | return 157 | cwd = os.getcwd() 158 | os.chdir(root) 159 | try: 160 | yield 161 | except BaseException as exc: 162 | raise exc 163 | finally: 164 | os.chdir(cwd) 165 | 166 | 167 | def reliability_guard(maximum_memory_bytes: Optional[int] = None): 168 | """ 169 | This disables various destructive functions and prevents the generated code 170 | from interfering with the test (e.g. fork bomb, killing other processes, 171 | removing filesystem files, etc.) 172 | 173 | WARNING 174 | This function is NOT a security sandbox. Untrusted code, including, model- 175 | generated code, should not be blindly executed outside of one. 
See the 176 | Codex paper for more information about OpenAI's code sandbox, and proceed 177 | with caution. 178 | """ 179 | 180 | if maximum_memory_bytes is not None: 181 | import resource 182 | resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes)) 183 | resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes)) 184 | if not platform.uname().system == 'Darwin': 185 | resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes)) 186 | 187 | faulthandler.disable() 188 | 189 | import builtins 190 | builtins.exit = None 191 | builtins.quit = None 192 | 193 | import os 194 | os.environ['OMP_NUM_THREADS'] = '1' 195 | 196 | os.kill = None 197 | os.system = None 198 | os.putenv = None 199 | os.remove = None 200 | os.removedirs = None 201 | os.rmdir = None 202 | os.fchdir = None 203 | os.setuid = None 204 | os.fork = None 205 | os.forkpty = None 206 | os.killpg = None 207 | os.rename = None 208 | os.renames = None 209 | os.truncate = None 210 | os.replace = None 211 | os.unlink = None 212 | os.fchmod = None 213 | os.fchown = None 214 | os.chmod = None 215 | os.chown = None 216 | os.chroot = None 217 | os.fchdir = None 218 | os.lchflags = None 219 | os.lchmod = None 220 | os.lchown = None 221 | os.getcwd = None 222 | os.chdir = None 223 | 224 | import shutil 225 | shutil.rmtree = None 226 | shutil.move = None 227 | shutil.chown = None 228 | 229 | import subprocess 230 | subprocess.Popen = None # type: ignore 231 | 232 | __builtins__['help'] = None 233 | 234 | import sys 235 | sys.modules['ipdb'] = None 236 | sys.modules['joblib'] = None 237 | sys.modules['resource'] = None 238 | sys.modules['psutil'] = None 239 | sys.modules['tkinter'] = None -------------------------------------------------------------------------------- /evaluate/StudentEval/execution.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | 4 | import signal 5 | import tempfile 6 | import platform 7 | import contextlib 8 | import faulthandler 9 | import multiprocessing 10 | 11 | from typing import Optional, Callable, Dict 12 | 13 | 14 | # WARNING 15 | # This program exists to execute untrusted model-generated code. Although 16 | # it is highly unlikely that model-generated code will do something overtly 17 | # malicious in response to this test suite, model-generated code may act 18 | # destructively due to a lack of model capability or alignment. 19 | # Users are strongly encouraged to sandbox this evaluation suite so that it 20 | # does not perform destructive actions on their host or network. For more 21 | # information on how OpenAI sandboxes its code, see the accompanying paper. 22 | # Once you have read this disclaimer and taken appropriate precautions, 23 | # uncomment the 58 line and proceed at your own risk 24 | 25 | def unsafe_execute(result: list, solution: str, time_out: float): 26 | 27 | with create_tempdir(): 28 | 29 | # These system calls are needed when cleaning up tempdir. 30 | import os 31 | import shutil 32 | rmtree = shutil.rmtree 33 | rmdir = os.rmdir 34 | chdir = os.chdir 35 | 36 | # Disable functionalities that can make destructive changes to the test. 37 | reliability_guard() 38 | 39 | # Construct the check program and run it. 
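# (Editorial note: this sandbox mirrors evaluate/HumanEval/execution.py and evaluate/MBPP/execution.py; as
# there, `solution` is assumed to arrive with its test code already appended, so the shared result list ends
# up holding "passed", "timed out", or a "failed: ..." message for each executed sample.)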
40 | check_program = solution 41 | 42 | try: 43 | exec_globals = {} 44 | with swallow_io(): 45 | with time_limit(time_out): 46 | exec(check_program, exec_globals) 47 | result.append("passed") 48 | except TimeoutException: 49 | result.append("timed out") 50 | except BaseException as e: 51 | result.append(f"failed: {e}") 52 | 53 | # Needed for cleaning up. 54 | shutil.rmtree = rmtree 55 | os.rmdir = rmdir 56 | os.chdir = chdir 57 | 58 | def check_correctness(task_id: int, 59 | completion_id: int, 60 | solution: str, 61 | time_out: float, 62 | ) -> Dict: 63 | """ 64 | Evaluates the functional correctness of a completion by running the test 65 | suite provided in the problem. 66 | 67 | :param completion_id: an optional completion ID so we can match 68 | the results later even if execution finishes asynchronously. 69 | """ 70 | 71 | manager = multiprocessing.Manager() 72 | result = manager.list() 73 | 74 | p = multiprocessing.Process( 75 | target=unsafe_execute, 76 | args=( 77 | result, 78 | solution, 79 | time_out 80 | ) 81 | ) 82 | p.start() 83 | p.join(timeout=time_out + 1) 84 | if p.is_alive(): 85 | p.kill() 86 | 87 | if not result: 88 | result.append("timed out") 89 | 90 | return dict( 91 | task_id = task_id, 92 | completion_id = completion_id, 93 | passed = result[0] == "passed", 94 | result = result[0], 95 | solution = solution 96 | ) 97 | 98 | 99 | @contextlib.contextmanager 100 | def time_limit(seconds: float): 101 | def signal_handler(signum, frame): 102 | raise TimeoutException("Timed out!") 103 | signal.setitimer(signal.ITIMER_REAL, seconds) 104 | signal.signal(signal.SIGALRM, signal_handler) 105 | try: 106 | yield 107 | finally: 108 | signal.setitimer(signal.ITIMER_REAL, 0) 109 | 110 | 111 | @contextlib.contextmanager 112 | def swallow_io(): 113 | stream = WriteOnlyStringIO() 114 | with contextlib.redirect_stdout(stream): 115 | with contextlib.redirect_stderr(stream): 116 | with redirect_stdin(stream): 117 | yield 118 | 119 | 120 | @contextlib.contextmanager 121 | def create_tempdir(): 122 | with tempfile.TemporaryDirectory() as dirname: 123 | with chdir(dirname): 124 | yield dirname 125 | 126 | 127 | class TimeoutException(Exception): 128 | pass 129 | 130 | 131 | class WriteOnlyStringIO(io.StringIO): 132 | """ StringIO that throws an exception when it's read from """ 133 | 134 | def read(self, *args, **kwargs): 135 | raise IOError 136 | 137 | def readline(self, *args, **kwargs): 138 | raise IOError 139 | 140 | def readlines(self, *args, **kwargs): 141 | raise IOError 142 | 143 | def readable(self, *args, **kwargs): 144 | """ Returns True if the IO object can be read. """ 145 | return False 146 | 147 | 148 | class redirect_stdin(contextlib._RedirectStream): # type: ignore 149 | _stream = 'stdin' 150 | 151 | 152 | @contextlib.contextmanager 153 | def chdir(root): 154 | if root == ".": 155 | yield 156 | return 157 | cwd = os.getcwd() 158 | os.chdir(root) 159 | try: 160 | yield 161 | except BaseException as exc: 162 | raise exc 163 | finally: 164 | os.chdir(cwd) 165 | 166 | 167 | def reliability_guard(maximum_memory_bytes: Optional[int] = None): 168 | """ 169 | This disables various destructive functions and prevents the generated code 170 | from interfering with the test (e.g. fork bomb, killing other processes, 171 | removing filesystem files, etc.) 172 | 173 | WARNING 174 | This function is NOT a security sandbox. Untrusted code, including, model- 175 | generated code, should not be blindly executed outside of one. 
See the 176 | Codex paper for more information about OpenAI's code sandbox, and proceed 177 | with caution. 178 | """ 179 | 180 | if maximum_memory_bytes is not None: 181 | import resource 182 | resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes)) 183 | resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes)) 184 | if not platform.uname().system == 'Darwin': 185 | resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes)) 186 | 187 | faulthandler.disable() 188 | 189 | import builtins 190 | builtins.exit = None 191 | builtins.quit = None 192 | 193 | import os 194 | os.environ['OMP_NUM_THREADS'] = '1' 195 | 196 | os.kill = None 197 | os.system = None 198 | os.putenv = None 199 | os.remove = None 200 | os.removedirs = None 201 | os.rmdir = None 202 | os.fchdir = None 203 | os.setuid = None 204 | os.fork = None 205 | os.forkpty = None 206 | os.killpg = None 207 | os.rename = None 208 | os.renames = None 209 | os.truncate = None 210 | os.replace = None 211 | os.unlink = None 212 | os.fchmod = None 213 | os.fchown = None 214 | os.chmod = None 215 | os.chown = None 216 | os.chroot = None 217 | os.fchdir = None 218 | os.lchflags = None 219 | os.lchmod = None 220 | os.lchown = None 221 | os.getcwd = None 222 | os.chdir = None 223 | 224 | import shutil 225 | shutil.rmtree = None 226 | shutil.move = None 227 | shutil.chown = None 228 | 229 | import subprocess 230 | subprocess.Popen = None # type: ignore 231 | 232 | __builtins__['help'] = None 233 | 234 | import sys 235 | sys.modules['ipdb'] = None 236 | sys.modules['joblib'] = None 237 | sys.modules['resource'] = None 238 | sys.modules['psutil'] = None 239 | sys.modules['tkinter'] = None -------------------------------------------------------------------------------- /evaluate/GPQA/gpqa.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import time 4 | from collections import Counter 5 | from concurrent.futures import ThreadPoolExecutor, as_completed 6 | from typing import Dict 7 | 8 | from datasets import Dataset, disable_progress_bars 9 | from loguru import logger 10 | from tqdm import tqdm 11 | 12 | from core.inference import GeneratorFactory, GeneratorOutput 13 | from core.routing import BaseRouter 14 | from evaluate.base_evaluator import BaseEvaluator 15 | 16 | disable_progress_bars() 17 | 18 | DATA_DIR = "data/GPQA" 19 | 20 | PROMPT = """Answer the following question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of {candidates}. 21 | 22 | {question} 23 | 24 | {options} 25 | 26 | Let's think step by step.""" 27 | 28 | ANSWER_PATTERN = r"(?i)Answer\s*:\s*([A-D])[.\s\n]?" 
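# Illustrative note (editorial addition, not part of the original file): PROMPT instructs the model to end
# with a line of the form 'Answer: C', and ANSWER_PATTERN captures that letter case-insensitively. A minimal
# sanity check, assuming only the standard-library `re` module:
#     import re
#     match = re.search(ANSWER_PATTERN, "Step 1: eliminate the distractors ...\nAnswer: C")
#     assert match is not None and match.group(1) == "C"
# How the captured letter is normalised and compared with the gold answer is handled by
# BaseEvaluator.extract_normal_answer, presumably defined in evaluate/base_evaluator.py (not shown here).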
29 | 30 | class GPQAEvaluator(BaseEvaluator): 31 | def __init__(self, max_workers: int = 8, mode: str="test"): 32 | super().__init__(max_workers=max_workers, mode=mode) 33 | self.task = "GPQA" 34 | self.seed = 42 35 | 36 | def load_data(self, split: str): 37 | data = self.load_jsonl(os.path.join(DATA_DIR, f"gpqa_main.json")) 38 | 39 | data = Dataset.from_list(data) 40 | data = data.map(lambda x: self.format_prompt(x)) 41 | if self.mode == "test": 42 | # split data into train and test 43 | logger.warning(f"Split data into train and test for {self.task}") 44 | split_data = data.train_test_split(test_size=0.3) 45 | train_data = split_data["train"] 46 | data = split_data["test"] 47 | logger.info(f"Calibration data: {len(train_data)}") 48 | logger.info(f"Test data: {len(data)}") 49 | 50 | return data 51 | 52 | def format_prompt(self, item: Dict) -> Dict: 53 | prompt = PROMPT.format( 54 | subject = item["High-level domain"], 55 | question = item["question"], 56 | candidates = "".join(list(item["options"].keys())), 57 | options = "\n".join([f"{key}. {value}" for key, value in item["options"].items()]) 58 | ) 59 | return {"prompt": prompt} 60 | 61 | def extract_raw_answer(self, raw_datas: list[str]) -> list[str]: 62 | return [ 63 | self.extract_normal_answer(text=data, answer_pattern=ANSWER_PATTERN) 64 | for data in raw_datas 65 | ] 66 | 67 | def process_output(self, output: GeneratorOutput): 68 | full_prediction = self.extract_raw_answer(raw_datas=output.raw_output) 69 | prediction = Counter(full_prediction).most_common(1)[0][0] 70 | prediction_stats = self.count_prediction_frequency(predictions=full_prediction) 71 | 72 | return prediction, full_prediction, prediction_stats 73 | 74 | def evaluate(self, index: int, data: dict, router: BaseRouter, generator_config: dict): 75 | answer = data['answer'] 76 | 77 | # step 1. get router, get generator 78 | router_result = router.route(question=data['prompt']) 79 | generator = GeneratorFactory.create_generator( 80 | experts=router_result, generator_config=generator_config 81 | ) # type: ignore 82 | 83 | # step 2. 
generate & update token usage 84 | if generator_config['type'] == 'model_switch': 85 | output: tuple[GeneratorOutput, GeneratorOutput] = generator.generate(question=data['prompt']) 86 | first_output, final_output = output 87 | prediction, full_prediction, prediction_stats = self.process_output(output=first_output) 88 | consistency_rate = prediction_stats[prediction]['frequency'] 89 | if consistency_rate < generator.consistency_rate_threshold: 90 | prediction, full_prediction, prediction_stats = self.process_output(output=final_output) 91 | output = final_output 92 | else: 93 | output = first_output 94 | elif generator_config['type'] == 'fast_slow': 95 | output: GeneratorOutput = generator.generate(question=data['prompt']) 96 | prediction, full_prediction, prediction_stats = self.process_output(output=output) 97 | consistency_rate = prediction_stats[prediction]['frequency'] 98 | if consistency_rate < generator.consistency_rate_threshold: 99 | slow_output = generator.slow_generate(question=data['prompt']) 100 | prediction, full_prediction, prediction_stats = self.process_output(output=slow_output) 101 | if prediction == "": 102 | logger.warning(f"slow_output is empty, use fast_output to replace.") 103 | prediction, full_prediction, prediction_stats = self.process_output(output=output) 104 | else: 105 | output = slow_output 106 | else: 107 | output: GeneratorOutput = generator.generate(question=data['prompt']) 108 | prediction, full_prediction, prediction_stats = self.process_output(output=output) 109 | 110 | self.update_tokens(prompt_tokens=output.prompt_tokens, completion_tokens=output.completion_tokens) 111 | 112 | is_correct = answer == prediction 113 | 114 | return dict( 115 | index=index, 116 | query=data['prompt'], 117 | origin_query=data['question'], 118 | prediction=prediction, 119 | full_prediction=full_prediction, 120 | prediction_stats=prediction_stats, 121 | raw_output=output.raw_output, 122 | answer=answer, 123 | is_correct=is_correct, 124 | model_name=generator.model 125 | ) 126 | 127 | def evaluate_loop(self, router: BaseRouter, generator_config: dict): 128 | start_time = time.time() 129 | data = self.load_data(split="test") 130 | 131 | counter = 0 132 | results = [] 133 | pbar = tqdm(total=len(data), desc=f"Evaluating {self.task} ...") 134 | with ThreadPoolExecutor(max_workers=self.max_workers) as executor: 135 | futures = [ 136 | executor.submit(self.evaluate, index=idx, data=d, router=router, generator_config=generator_config) 137 | for idx, d in enumerate(data) 138 | ] 139 | for future in as_completed(futures): 140 | result = future.result() 141 | results.append(result) 142 | if result['is_correct']: 143 | counter += 1 144 | pbar.update(1) 145 | pbar.close() 146 | 147 | model_counts = self.calculate_model_counts(results=results) 148 | 149 | acc = counter / len(data) 150 | end_time = time.time() 151 | logger.info(f"Task: {self.task}") 152 | logger.info(f"Accuracy: {acc}") 153 | logger.info(f"Time taken: {end_time - start_time} seconds") 154 | logger.info(f"Prompt tokens: {self.prompt_tokens}") 155 | logger.info(f"Completion tokens: {self.completion_tokens}") 156 | 157 | return { 158 | "performance": acc, 159 | "time_taken": end_time - start_time, 160 | "prompt_tokens": self.prompt_tokens, 161 | "completion_tokens": self.completion_tokens, 162 | "model_counts": model_counts, 163 | "records": results, 164 | } -------------------------------------------------------------------------------- /diversity/embedding_cache.py: 
-------------------------------------------------------------------------------- 1 | from typing import List 2 | from pathlib import Path 3 | import sqlite3 4 | import hashlib 5 | import json 6 | import os 7 | import time 8 | import threading 9 | from loguru import logger 10 | from openai import OpenAI, RateLimitError, APIError 11 | 12 | __all__ = ["EmbeddingCache"] 13 | 14 | 15 | class EmbeddingCache: 16 | """Thin wrapper around the OpenAI embedding endpoint with SQLite caching.""" 17 | def __init__( 18 | self, 19 | base_url: str = None, 20 | api_key: str = None, 21 | model_name: str = "jina-embeddings-v3", 22 | cache_dir: str | os.PathLike = None, 23 | max_retries: int = 5, 24 | initial_delay: float = 1.0, 25 | ) -> None: 26 | self.model_name = model_name 27 | self.max_retries = max_retries 28 | self.initial_delay = initial_delay 29 | 30 | # Get configuration from environment variables 31 | self.base_url = base_url or os.getenv("EMBEDDING_API_BASE_URL") 32 | self.api_key = api_key or os.getenv("EMBEDDING_API_KEY") 33 | 34 | if not self.base_url or not self.api_key: 35 | raise ValueError( 36 | "API configuration not found. Please set EMBEDDING_API_BASE_URL and EMBEDDING_API_KEY " 37 | "environment variables or provide them as arguments." 38 | ) 39 | 40 | self._client = OpenAI( 41 | base_url=self.base_url, 42 | api_key=self.api_key, 43 | ) 44 | 45 | # Set up cache directory 46 | if cache_dir is None: 47 | cache_dir = os.getenv("EMBEDDING_CACHE_DIR", ".cache") 48 | cache_path = Path(cache_dir) 49 | cache_path.mkdir(parents=True, exist_ok=True) 50 | self.db_path = cache_path / "embeddings.db" 51 | 52 | # -- single persistent connection -- 53 | self._conn = self._open_conn() 54 | self._init_db() 55 | 56 | # write lock: serialize writes (one at a time) 57 | self._w_lock = threading.Lock() 58 | 59 | self._init_db() 60 | 61 | # --------------------------------------------------------------------- 62 | # public helpers 63 | # --------------------------------------------------------------------- 64 | 65 | def get(self, text: str) -> List[float]: 66 | """Return the embedding for *text*, fetching from cache or remote.""" 67 | text_hash = hashlib.md5(text.encode()).hexdigest() 68 | 69 | # 1. try cache ------------------------------------------------------ 70 | row = self._select(text_hash) 71 | if row is not None: 72 | return row 73 | 74 | # 2. call OpenAI ---------------------------------------------------- 75 | delay = self.initial_delay 76 | for attempt in range(self.max_retries): 77 | try: 78 | rsp = self._client.embeddings.create(input=text, model=self.model_name) 79 | emb: List[float] = rsp.data[0].embedding # type: ignore[index] 80 | self._insert(text_hash, text, emb) 81 | return emb 82 | 83 | except RateLimitError as e: 84 | logger.warning( 85 | f"Rate limited (attempt {attempt+1}/{self.max_retries}). Retry in {delay:.1f}s" 86 | ) 87 | except APIError as e: 88 | logger.warning( 89 | f"OpenAI API error (attempt {attempt+1}/{self.max_retries}): {e}. Retry in {delay:.1f}s" 90 | ) 91 | except Exception as e: 92 | logger.error(f"Unexpected error — abort: {e}") 93 | raise 94 | 95 | time.sleep(delay) 96 | delay *= 2 97 | 98 | raise RuntimeError("Failed to get embedding after multiple retries.") 99 | 100 | def batch(self, texts: List[str], max_batch_size: int = 100) -> List[List[float]]: 101 | """Return embeddings for a list of texts (keeps order). 
102 | 103 | Args: 104 | texts: List of texts to get embeddings for 105 | max_batch_size: Maximum number of texts to process in a single API call 106 | """ 107 | all_embeddings: List[List[float]] = [] 108 | 109 | # Process texts in chunks 110 | for i in range(0, len(texts), max_batch_size): 111 | chunk = texts[i:i + max_batch_size] 112 | 113 | # cache lookup, then one API call for the misses in this chunk 114 | hits: List[List[float]] = [] 115 | misses: List[str] = [] 116 | mapping: dict[str, int] = {} 117 | 118 | for idx, t in enumerate(chunk): 119 | h = hashlib.md5(t.encode()).hexdigest() 120 | row = self._select(h) 121 | if row is not None: 122 | hits.append(row) 123 | else: 124 | mapping[t] = idx 125 | misses.append(t) 126 | 127 | if misses: 128 | rsp = self._client.embeddings.create(input=misses, model=self.model_name) 129 | for text, record in zip(misses, rsp.data): 130 | emb: List[float] = record.embedding 131 | self._insert(hashlib.md5(text.encode()).hexdigest(), text, emb) 132 | hits.insert(mapping[text], emb) 133 | 134 | all_embeddings.extend(hits) 135 | 136 | return all_embeddings 137 | 138 | # ------------------------------------------------------------------ 139 | # private db helpers 140 | # ------------------------------------------------------------------ 141 | def _open_conn(self) -> sqlite3.Connection: 142 | conn = sqlite3.connect( 143 | self.db_path, 144 | timeout=30, # wait up to 30 s for a locked database 145 | check_same_thread=False, # allow use across threads 146 | isolation_level=None # autocommit 147 | ) 148 | conn.execute("PRAGMA journal_mode=WAL;") 149 | conn.execute("PRAGMA synchronous=NORMAL;") 150 | return conn 151 | 152 | def _init_db(self) -> None: 153 | with sqlite3.connect(self.db_path) as conn: 154 | conn.execute( 155 | """ 156 | CREATE TABLE IF NOT EXISTS embeddings ( 157 | text_hash TEXT, 158 | model TEXT, 159 | embedding TEXT NOT NULL, 160 | text TEXT, 161 | created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, 162 | PRIMARY KEY(text_hash, model) 163 | ) 164 | """ 165 | ) 166 | conn.execute("CREATE INDEX IF NOT EXISTS idx_model ON embeddings(model);") 167 | 168 | def _select(self, text_hash: str) -> List[float] | None: 169 | with sqlite3.connect(self.db_path) as conn: 170 | row = conn.execute( 171 | "SELECT embedding FROM embeddings WHERE text_hash=? 
AND model=?", 172 | (text_hash, self.model_name), 173 | ).fetchone() 174 | if row: 175 | return json.loads(row[0]) # stored as JSON string for readability 176 | return None 177 | 178 | def _insert(self, text_hash: str, text: str, embedding: List[float]) -> None: 179 | with sqlite3.connect(self.db_path) as conn: 180 | conn.execute( 181 | "INSERT OR REPLACE INTO embeddings (text_hash, model, embedding, text) VALUES (?,?,?,?)", 182 | (text_hash, self.model_name, json.dumps(embedding), text), 183 | ) -------------------------------------------------------------------------------- /core/ablation/embedding_cache.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from pathlib import Path 3 | import sqlite3 4 | import hashlib 5 | import json 6 | import os 7 | import time 8 | import threading 9 | from loguru import logger 10 | from openai import OpenAI, RateLimitError, APIError 11 | 12 | __all__ = ["EmbeddingCache"] 13 | 14 | 15 | class EmbeddingCache: 16 | """Thin wrapper around the OpenAI embedding endpoint with SQLite caching.""" 17 | def __init__( 18 | self, 19 | base_url: str = None, 20 | api_key: str = None, 21 | model_name: str = "jina-embeddings-v3", 22 | cache_dir: str | os.PathLike = None, 23 | max_retries: int = 5, 24 | initial_delay: float = 1.0, 25 | ) -> None: 26 | self.model_name = model_name 27 | self.max_retries = max_retries 28 | self.initial_delay = initial_delay 29 | 30 | # Get configuration from environment variables 31 | self.base_url = base_url or os.getenv("EMBEDDING_API_BASE_URL") 32 | self.api_key = api_key or os.getenv("EMBEDDING_API_KEY") 33 | 34 | if not self.base_url or not self.api_key: 35 | raise ValueError( 36 | "API configuration not found. Please set EMBEDDING_API_BASE_URL and EMBEDDING_API_KEY " 37 | "environment variables or provide them as arguments." 38 | ) 39 | 40 | self._client = OpenAI( 41 | base_url=self.base_url, 42 | api_key=self.api_key, 43 | ) 44 | 45 | # Set up cache directory 46 | if cache_dir is None: 47 | cache_dir = os.getenv("EMBEDDING_CACHE_DIR", ".cache") 48 | cache_path = Path(cache_dir) 49 | cache_path.mkdir(parents=True, exist_ok=True) 50 | self.db_path = cache_path / "embeddings.db" 51 | 52 | # —— 单一持久连接 —— 53 | self._conn = self._open_conn() 54 | self._init_db() 55 | 56 | # 写锁,保证一次只写一条 57 | self._w_lock = threading.Lock() 58 | 59 | self._init_db() 60 | 61 | # --------------------------------------------------------------------- 62 | # public helpers 63 | # --------------------------------------------------------------------- 64 | 65 | def get(self, text: str) -> List[float]: 66 | """Return the embedding for *text*, fetching from cache or remote.""" 67 | text_hash = hashlib.md5(text.encode()).hexdigest() 68 | 69 | # 1. try cache ------------------------------------------------------ 70 | row = self._select(text_hash) 71 | if row is not None: 72 | return row 73 | 74 | # 2. call OpenAI ---------------------------------------------------- 75 | delay = self.initial_delay 76 | for attempt in range(self.max_retries): 77 | try: 78 | rsp = self._client.embeddings.create(input=text, model=self.model_name) 79 | emb: List[float] = rsp.data[0].embedding # type: ignore[index] 80 | self._insert(text_hash, text, emb) 81 | return emb 82 | 83 | except RateLimitError as e: 84 | logger.warning( 85 | f"Rate limited (attempt {attempt+1}/{self.max_retries}). 
Retry in {delay:.1f}s" 86 | ) 87 | except APIError as e: 88 | logger.warning( 89 | f"OpenAI API error (attempt {attempt+1}/{self.max_retries}): {e}. Retry in {delay:.1f}s" 90 | ) 91 | except Exception as e: 92 | logger.error(f"Unexpected error — abort: {e}") 93 | raise 94 | 95 | time.sleep(delay) 96 | delay *= 2 97 | 98 | raise RuntimeError("Failed to get embedding after multiple retries.") 99 | 100 | def batch(self, texts: List[str], max_batch_size: int = 100) -> List[List[float]]: 101 | """Return embeddings for a list of texts (keeps order). 102 | 103 | Args: 104 | texts: List of texts to get embeddings for 105 | max_batch_size: Maximum number of texts to process in a single API call 106 | """ 107 | all_embeddings: List[List[float]] = [] 108 | 109 | # Process texts in chunks 110 | for i in range(0, len(texts), max_batch_size): 111 | chunk = texts[i:i + max_batch_size] 112 | 113 | # cache lookup, then one API call for the misses in this chunk 114 | hits: List[List[float]] = [] 115 | misses: List[str] = [] 116 | mapping: dict[str, int] = {} 117 | 118 | for idx, t in enumerate(chunk): 119 | h = hashlib.md5(t.encode()).hexdigest() 120 | row = self._select(h) 121 | if row is not None: 122 | hits.append(row) 123 | else: 124 | mapping[t] = idx 125 | misses.append(t) 126 | 127 | if misses: 128 | rsp = self._client.embeddings.create(input=misses, model=self.model_name) 129 | for text, record in zip(misses, rsp.data): 130 | emb: List[float] = record.embedding 131 | self._insert(hashlib.md5(text.encode()).hexdigest(), text, emb) 132 | hits.insert(mapping[text], emb) 133 | 134 | all_embeddings.extend(hits) 135 | 136 | return all_embeddings 137 | 138 | # ------------------------------------------------------------------ 139 | # private db helpers 140 | # ------------------------------------------------------------------ 141 | def _open_conn(self) -> sqlite3.Connection: 142 | conn = sqlite3.connect( 143 | self.db_path, 144 | timeout=30, # wait up to 30 s for a locked database 145 | check_same_thread=False, # allow use across threads 146 | isolation_level=None # autocommit 147 | ) 148 | conn.execute("PRAGMA journal_mode=WAL;") 149 | conn.execute("PRAGMA synchronous=NORMAL;") 150 | return conn 151 | 152 | def _init_db(self) -> None: 153 | with sqlite3.connect(self.db_path) as conn: 154 | conn.execute( 155 | """ 156 | CREATE TABLE IF NOT EXISTS embeddings ( 157 | text_hash TEXT, 158 | model TEXT, 159 | embedding TEXT NOT NULL, 160 | text TEXT, 161 | created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, 162 | PRIMARY KEY(text_hash, model) 163 | ) 164 | """ 165 | ) 166 | conn.execute("CREATE INDEX IF NOT EXISTS idx_model ON embeddings(model);") 167 | 168 | def _select(self, text_hash: str) -> List[float] | None: 169 | with sqlite3.connect(self.db_path) as conn: 170 | row = conn.execute( 171 | "SELECT embedding FROM embeddings WHERE text_hash=? 
AND model=?", 172 | (text_hash, self.model_name), 173 | ).fetchone() 174 | if row: 175 | return json.loads(row[0]) # stored as JSON string for readability 176 | return None 177 | 178 | def _insert(self, text_hash: str, text: str, embedding: List[float]) -> None: 179 | with sqlite3.connect(self.db_path) as conn: 180 | conn.execute( 181 | "INSERT OR REPLACE INTO embeddings (text_hash, model, embedding, text) VALUES (?,?,?,?)", 182 | (text_hash, self.model_name, json.dumps(embedding), text), 183 | ) -------------------------------------------------------------------------------- /evaluate/MedQA/medqa.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import time 4 | from collections import Counter 5 | from concurrent.futures import ThreadPoolExecutor, as_completed 6 | from typing import Dict 7 | 8 | from datasets import Dataset, disable_progress_bars 9 | from loguru import logger 10 | from tqdm import tqdm 11 | 12 | from core.inference import GeneratorFactory, GeneratorOutput 13 | from core.routing import BaseRouter 14 | from evaluate.base_evaluator import BaseEvaluator 15 | 16 | disable_progress_bars() 17 | 18 | DATA_DIR = "data/MedQA" 19 | 20 | PROMPT = """Answer the following question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of {candidates}. 21 | 22 | {question} 23 | 24 | {options} 25 | 26 | Let's think step by step.""" 27 | 28 | ANSWER_PATTERN = r"(?i)Answer\s*:\s*([A-D])[.\s\n]?" 29 | 30 | class MedQAEvaluator(BaseEvaluator): 31 | def __init__(self, max_workers: int=8, mode: str="test"): 32 | super().__init__(max_workers=max_workers, mode=mode) 33 | self.task = "MedQA" 34 | self.seed = 42 35 | 36 | def load_data(self, split: str): 37 | data = self.load_jsonl(os.path.join(DATA_DIR, f"{split}.json")) 38 | 39 | data = Dataset.from_list(data) 40 | data = data.map(lambda x: self.format_prompt(x)) 41 | 42 | if self.mode == "test": 43 | # split data into train and test 44 | logger.warning(f"Split data into train and test for {self.task}") 45 | split_data = data.train_test_split(test_size=0.3) 46 | train_data = split_data["train"] 47 | data = split_data["test"] 48 | logger.info(f"Calibration data: {len(train_data)}") 49 | logger.info(f"Test data: {len(data)}") 50 | 51 | return data 52 | 53 | def format_prompt(self, item: Dict) -> Dict: 54 | prompt = PROMPT.format( 55 | question = item["question"], 56 | candidates = "".join(list(item["options"].keys())), 57 | options = "\n".join([f"{key}. {value}" for key, value in item["options"].items()]) 58 | ) 59 | return {"prompt": prompt} 60 | 61 | def extract_raw_answer(self, raw_datas: list[str]) -> list[str]: 62 | return [ 63 | self.extract_normal_answer(text=data, answer_pattern=ANSWER_PATTERN) 64 | for data in raw_datas 65 | ] 66 | 67 | def process_output(self, output: GeneratorOutput): 68 | full_prediction = self.extract_raw_answer(raw_datas=output.raw_output) 69 | prediction = Counter(full_prediction).most_common(1)[0][0] 70 | prediction_stats = self.count_prediction_frequency(predictions=full_prediction) 71 | 72 | return prediction, full_prediction, prediction_stats 73 | 74 | def evaluate(self, index: int, data: dict, router: BaseRouter, generator_config: dict): 75 | answer = data['answer'] 76 | 77 | # step 1. 
get router, get generator 78 | router_result = router.route(question=data['prompt']) 79 | generator = GeneratorFactory.create_generator( 80 | experts=router_result, generator_config=generator_config 81 | ) # type: ignore 82 | 83 | # step 2. generate & update token usage 84 | if generator_config['type'] == 'model_switch': 85 | output: tuple[GeneratorOutput, GeneratorOutput] = generator.generate(question=data['prompt']) 86 | first_output, final_output = output 87 | prediction, full_prediction, prediction_stats = self.process_output(output=first_output) 88 | consistency_rate = prediction_stats[prediction]['frequency'] 89 | if consistency_rate < generator.consistency_rate_threshold: 90 | prediction, full_prediction, prediction_stats = self.process_output(output=final_output) 91 | output = final_output 92 | else: 93 | output = first_output 94 | elif generator_config['type'] == 'fast_slow': 95 | output: GeneratorOutput = generator.generate(question=data['prompt']) 96 | prediction, full_prediction, prediction_stats = self.process_output(output=output) 97 | consistency_rate = prediction_stats[prediction]['frequency'] 98 | if consistency_rate < generator.consistency_rate_threshold: 99 | slow_output = generator.slow_generate(question=data['prompt']) 100 | prediction, full_prediction, prediction_stats = self.process_output(output=slow_output) 101 | if prediction == "": 102 | logger.warning(f"slow_output is empty, use fast_output to replace.") 103 | prediction, full_prediction, prediction_stats = self.process_output(output=output) 104 | else: 105 | output = slow_output 106 | else: 107 | output: GeneratorOutput = generator.generate(question=data['prompt']) 108 | prediction, full_prediction, prediction_stats = self.process_output(output=output) 109 | 110 | # step 3. update token usage 111 | self.update_tokens(prompt_tokens=output.prompt_tokens, completion_tokens=output.completion_tokens) 112 | 113 | # step 4. 
calculate is_correct 114 | is_correct = answer == prediction 115 | 116 | return dict( 117 | index=index, 118 | query=data['prompt'], 119 | origin_query=data['question'], 120 | prediction=prediction, 121 | full_prediction=full_prediction, 122 | prediction_stats=prediction_stats, 123 | raw_output=output.raw_output, 124 | answer=answer, 125 | is_correct=is_correct, 126 | model_name=generator.model 127 | ) 128 | 129 | def evaluate_loop(self, router: BaseRouter, generator_config: dict): 130 | start_time = time.time() 131 | data = self.load_data(split="test") 132 | 133 | counter = 0 134 | results = [] 135 | pbar = tqdm(total=len(data), desc="Evaluating MedQA ...") 136 | with ThreadPoolExecutor(max_workers=self.max_workers) as executor: 137 | futures = [ 138 | executor.submit(self.evaluate, index=idx, data=d, router=router, generator_config=generator_config) 139 | for idx, d in enumerate(data) 140 | ] 141 | for future in as_completed(futures): 142 | result = future.result() 143 | results.append(result) 144 | if result['is_correct']: 145 | counter += 1 146 | pbar.update(1) 147 | pbar.close() 148 | 149 | model_counts = self.calculate_model_counts(results=results) 150 | 151 | acc = counter / len(data) 152 | end_time = time.time() 153 | logger.info(f"Task: {self.task}") 154 | logger.info(f"Accuracy: {acc}") 155 | logger.info(f"Time taken: {end_time - start_time} seconds") 156 | logger.info(f"Prompt tokens: {self.prompt_tokens}") 157 | logger.info(f"Completion tokens: {self.completion_tokens}") 158 | 159 | return { 160 | "performance": acc, 161 | "time_taken": end_time - start_time, 162 | "prompt_tokens": self.prompt_tokens, 163 | "completion_tokens": self.completion_tokens, 164 | "model_counts": model_counts, 165 | "records": results 166 | } -------------------------------------------------------------------------------- /evaluate/TruthfulQA/truthfulqa.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import time 4 | from collections import Counter 5 | from concurrent.futures import ThreadPoolExecutor, as_completed 6 | from typing import Dict 7 | 8 | from datasets import Dataset, disable_progress_bars 9 | from loguru import logger 10 | from tqdm import tqdm 11 | 12 | from core.inference import GeneratorFactory, GeneratorOutput 13 | from core.routing import BaseRouter 14 | from evaluate.base_evaluator import BaseEvaluator 15 | 16 | disable_progress_bars() 17 | 18 | DATA_DIR = "data/TruthfulQA" 19 | 20 | PROMPT = """Answer the following question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of {candidates}. 21 | 22 | {question} 23 | 24 | {options} 25 | 26 | Let's think step by step.""" 27 | 28 | ANSWER_PATTERN = r"(?i)Answer\s*:\s*([A-D])[.\s\n]?" 
29 | 30 | class TruthfulQAEvaluator(BaseEvaluator): 31 | def __init__(self, max_workers: int=8, mode: str="test"): 32 | super().__init__(max_workers=max_workers, mode=mode) 33 | self.task = "TruthfulQA" 34 | self.seed = 42 35 | 36 | def load_data(self, split: str): 37 | data = self.load_jsonl(os.path.join(DATA_DIR, f"{split}.json")) 38 | 39 | data = Dataset.from_list(data) 40 | data = data.map(lambda x: self.format_prompt(x)) 41 | 42 | if self.mode == "test": 43 | # split data into train and test 44 | logger.warning(f"Split data into train and test for {self.task}") 45 | split_data = data.train_test_split(test_size=0.3) 46 | train_data = split_data["train"] 47 | data = split_data["test"] 48 | logger.info(f"Calibration data: {len(train_data)}") 49 | logger.info(f"Test data: {len(data)}") 50 | 51 | return data 52 | 53 | def format_prompt(self, item: Dict) -> Dict: 54 | question = item["question"] 55 | options = item["options"] 56 | options = "\n".join([f"{chr(65+i)}) {option}" for i, option in enumerate(options)]) 57 | prompt = PROMPT.format( 58 | question = question, 59 | candidates = "A, B, C, D", 60 | options = options 61 | ) 62 | return {"prompt": prompt} 63 | 64 | def extract_raw_answer(self, raw_datas: list[str]) -> list[str]: 65 | return [ 66 | self.extract_normal_answer(text=data, answer_pattern=ANSWER_PATTERN) 67 | for data in raw_datas 68 | ] 69 | 70 | def process_output(self, output: GeneratorOutput): 71 | full_prediction = self.extract_raw_answer(raw_datas=output.raw_output) 72 | prediction = Counter(full_prediction).most_common(1)[0][0] 73 | prediction_stats = self.count_prediction_frequency(predictions=full_prediction) 74 | 75 | return prediction, full_prediction, prediction_stats 76 | 77 | def evaluate(self, index: int, data: dict, router: BaseRouter, generator_config: dict): 78 | answer = data['answer'] 79 | 80 | # step 1. get router, get generator 81 | router_result = router.route(question=data['prompt']) 82 | generator = GeneratorFactory.create_generator( 83 | experts=router_result, generator_config=generator_config 84 | ) # type: ignore 85 | 86 | # step 2. 
generate & update token usage 87 | if generator_config['type'] == 'model_switch': 88 | output: tuple[GeneratorOutput, GeneratorOutput] = generator.generate(question=data['prompt']) 89 | first_output, final_output = output 90 | prediction, full_prediction, prediction_stats = self.process_output(output=first_output) 91 | consistency_rate = prediction_stats[prediction]['frequency'] 92 | if consistency_rate < generator.consistency_rate_threshold: 93 | prediction, full_prediction, prediction_stats = self.process_output(output=final_output) 94 | output = final_output 95 | else: 96 | output = first_output 97 | elif generator_config['type'] == 'fast_slow': 98 | output: GeneratorOutput = generator.generate(question=data['prompt']) 99 | prediction, full_prediction, prediction_stats = self.process_output(output=output) 100 | consistency_rate = prediction_stats[prediction]['frequency'] 101 | if consistency_rate < generator.consistency_rate_threshold: 102 | slow_output = generator.slow_generate(question=data['prompt']) 103 | prediction, full_prediction, prediction_stats = self.process_output(output=slow_output) 104 | if prediction == "": 105 | logger.warning(f"slow_output is empty, use fast_output to replace.") 106 | prediction, full_prediction, prediction_stats = self.process_output(output=output) 107 | else: 108 | output = slow_output 109 | else: 110 | output: GeneratorOutput = generator.generate(question=data['prompt']) 111 | prediction, full_prediction, prediction_stats = self.process_output(output=output) 112 | 113 | # step 3. update token usage 114 | self.update_tokens(prompt_tokens=output.prompt_tokens, completion_tokens=output.completion_tokens) 115 | 116 | # step 4. calculate is_correct 117 | is_correct = answer == prediction 118 | 119 | return dict( 120 | index=index, 121 | query=data['prompt'], 122 | origin_query=data['question'], 123 | prediction=prediction, 124 | full_prediction=full_prediction, 125 | prediction_stats=prediction_stats, 126 | raw_output=output.raw_output, 127 | answer=answer, 128 | is_correct=is_correct, 129 | model_name=generator.model 130 | ) 131 | 132 | def evaluate_loop(self, router: BaseRouter, generator_config: dict): 133 | start_time = time.time() 134 | data = self.load_data(split="test") 135 | 136 | counter = 0 137 | results = [] 138 | pbar = tqdm(total=len(data), desc=f"Evaluating {self.task} ...") 139 | with ThreadPoolExecutor(max_workers=self.max_workers) as executor: 140 | futures = [ 141 | executor.submit(self.evaluate, index=idx, data=d, router=router, generator_config=generator_config) 142 | for idx, d in enumerate(data) 143 | ] 144 | for future in as_completed(futures): 145 | result = future.result() 146 | results.append(result) 147 | if result['is_correct']: 148 | counter += 1 149 | pbar.update(1) 150 | pbar.close() 151 | 152 | model_counts = self.calculate_model_counts(results=results) 153 | 154 | acc = counter / len(data) 155 | end_time = time.time() 156 | logger.info(f"Task: {self.task}") 157 | logger.info(f"Accuracy: {acc}") 158 | logger.info(f"Time taken: {end_time - start_time} seconds") 159 | logger.info(f"Prompt tokens: {self.prompt_tokens}") 160 | logger.info(f"Completion tokens: {self.completion_tokens}") 161 | 162 | return { 163 | "performance": acc, 164 | "time_taken": end_time - start_time, 165 | "prompt_tokens": self.prompt_tokens, 166 | "completion_tokens": self.completion_tokens, 167 | "model_counts": model_counts, 168 | "records": results 169 | } -------------------------------------------------------------------------------- 
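
Note: the evaluate() methods above all apply the same escalation rule: extract a letter from each sampled completion, majority-vote the letters, and only fall back to the second pass (the final_output of 'model_switch', or slow_generate() for 'fast_slow') when agreement falls below generator.consistency_rate_threshold. A minimal self-contained sketch of that rule follows; majority_vote and THRESHOLD are illustrative names introduced here, not part of the repository, and the actual threshold comes from the generator config.

from collections import Counter
from typing import List, Tuple

THRESHOLD = 0.6  # assumed value for illustration only

def majority_vote(answers: List[str]) -> Tuple[str, float]:
    """Return the most common extracted answer and the fraction of samples agreeing with it."""
    top, count = Counter(answers).most_common(1)[0]
    return top, count / len(answers)

fast_answers = ["B", "B", "C", "B", "A"]  # letters extracted from the fast-path samples
prediction, agreement = majority_vote(fast_answers)
if agreement < THRESHOLD:
    print(f"agreement {agreement:.2f} below threshold -> escalate to the slow / final pass")
else:
    print(f"keep fast prediction {prediction!r} (agreement {agreement:.2f})")
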
/evaluate/BrainTeaser/brainteaser.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import time 4 | from collections import Counter 5 | from concurrent.futures import ThreadPoolExecutor, as_completed 6 | from typing import Dict 7 | 8 | from datasets import Dataset, disable_progress_bars 9 | from loguru import logger 10 | from tqdm import tqdm 11 | 12 | from core.inference import GeneratorFactory, GeneratorOutput 13 | from core.routing import BaseRouter 14 | from evaluate.base_evaluator import BaseEvaluator 15 | 16 | disable_progress_bars() 17 | 18 | DATA_DIR = "data/BrainTeaser" 19 | 20 | PROMPT = """Answer the following question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of {candidates}. 21 | 22 | {question} 23 | 24 | {options} 25 | 26 | Let's think step by step.""" 27 | 28 | ANSWER_PATTERN = r"(?i)Answer\s*:\s*([A-E])[.\s\n]?" 29 | 30 | class BrainTeaserEvaluator(BaseEvaluator): 31 | def __init__(self, max_workers: int=8, mode: str="test"): 32 | super().__init__(max_workers=max_workers, mode=mode) 33 | self.task = "BrainTeaser" 34 | self.seed = 42 35 | 36 | def load_data(self, split: str): 37 | data = self.load_jsonl(os.path.join(DATA_DIR, f"rs_dev.jsonl")) 38 | 39 | data = Dataset.from_list(data) 40 | data = data.map(lambda x: self.format_prompt(x)) 41 | 42 | if self.mode == "test": 43 | # split data into train and test 44 | logger.warning(f"Split data into train and test for {self.task}") 45 | split_data = data.train_test_split(test_size=0.3) 46 | train_data = split_data["train"] 47 | data = split_data["test"] 48 | logger.info(f"Calibration data: {len(train_data)}") 49 | logger.info(f"Test data: {len(data)}") 50 | 51 | return data 52 | 53 | def format_prompt(self, item: Dict) -> Dict: 54 | question = item["question"] 55 | 56 | stem = question["stem"] 57 | choices = question["choices"] 58 | candidate = ''.join([c['label'] for c in choices]) 59 | options = "\n".join([f"{c['label']}) {c['text']}" for c in choices]) 60 | prompt = PROMPT.format( 61 | question = stem, 62 | candidates = candidate, 63 | options = options 64 | ) 65 | return {"prompt": prompt, "answer": item["answerKey"]} 66 | 67 | def extract_raw_answer(self, raw_datas: list[str]) -> list[str]: 68 | return [ 69 | self.extract_normal_answer(text=data, answer_pattern=ANSWER_PATTERN) 70 | for data in raw_datas 71 | ] 72 | 73 | def process_output(self, output: GeneratorOutput): 74 | full_prediction = self.extract_raw_answer(raw_datas=output.raw_output) 75 | prediction = Counter(full_prediction).most_common(1)[0][0] 76 | prediction_stats = self.count_prediction_frequency(predictions=full_prediction) 77 | 78 | return prediction, full_prediction, prediction_stats 79 | 80 | def evaluate(self, index: int, data: dict, router: BaseRouter, generator_config: dict): 81 | answer = data['answer'] 82 | 83 | # step 1. get router, get generator 84 | router_result = router.route(question=data['prompt']) 85 | generator = GeneratorFactory.create_generator( 86 | experts=router_result, generator_config=generator_config 87 | ) # type: ignore 88 | 89 | # step 2. 
generate & update token usage 90 | if generator_config['type'] == 'model_switch': 91 | output: tuple[GeneratorOutput, GeneratorOutput] = generator.generate(question=data['prompt']) 92 | first_output, final_output = output 93 | prediction, full_prediction, prediction_stats = self.process_output(output=first_output) 94 | consistency_rate = prediction_stats[prediction]['frequency'] 95 | if consistency_rate < generator.consistency_rate_threshold: 96 | prediction, full_prediction, prediction_stats = self.process_output(output=final_output) 97 | output = final_output 98 | else: 99 | output = first_output 100 | elif generator_config['type'] == 'fast_slow': 101 | output: GeneratorOutput = generator.generate(question=data['prompt']) 102 | prediction, full_prediction, prediction_stats = self.process_output(output=output) 103 | consistency_rate = prediction_stats[prediction]['frequency'] 104 | if consistency_rate < generator.consistency_rate_threshold: 105 | slow_output = generator.slow_generate(question=data['prompt']) 106 | prediction, full_prediction, prediction_stats = self.process_output(output=slow_output) 107 | if prediction == "": 108 | logger.warning(f"slow_output is empty, use fast_output to replace.") 109 | prediction, full_prediction, prediction_stats = self.process_output(output=output) 110 | else: 111 | output = slow_output 112 | else: 113 | output: GeneratorOutput = generator.generate(question=data['prompt']) 114 | prediction, full_prediction, prediction_stats = self.process_output(output=output) 115 | 116 | # step 3. update token usage 117 | self.update_tokens(prompt_tokens=output.prompt_tokens, completion_tokens=output.completion_tokens) 118 | 119 | # step 4. calculate is_correct 120 | is_correct = answer == prediction 121 | 122 | return dict( 123 | index=index, 124 | query=data['prompt'], 125 | origin_query=data['question'], 126 | prediction=prediction, 127 | full_prediction=full_prediction, 128 | prediction_stats=prediction_stats, 129 | raw_output=output.raw_output, 130 | answer=answer, 131 | is_correct=is_correct, 132 | model_name=generator.model 133 | ) 134 | 135 | def evaluate_loop(self, router: BaseRouter, generator_config: dict): 136 | start_time = time.time() 137 | data = self.load_data(split="test") 138 | 139 | counter = 0 140 | results = [] 141 | pbar = tqdm(total=len(data), desc=f"Evaluating {self.task} ...") 142 | with ThreadPoolExecutor(max_workers=self.max_workers) as executor: 143 | futures = [ 144 | executor.submit(self.evaluate, index=idx, data=d, router=router, generator_config=generator_config) 145 | for idx, d in enumerate(data) 146 | ] 147 | for future in as_completed(futures): 148 | result = future.result() 149 | results.append(result) 150 | if result['is_correct']: 151 | counter += 1 152 | pbar.update(1) 153 | pbar.close() 154 | 155 | model_counts = self.calculate_model_counts(results=results) 156 | 157 | acc = counter / len(data) 158 | end_time = time.time() 159 | logger.info(f"Task: {self.task}") 160 | logger.info(f"Accuracy: {acc}") 161 | logger.info(f"Time taken: {end_time - start_time} seconds") 162 | logger.info(f"Prompt tokens: {self.prompt_tokens}") 163 | logger.info(f"Completion tokens: {self.completion_tokens}") 164 | 165 | return { 166 | "performance": acc, 167 | "time_taken": end_time - start_time, 168 | "prompt_tokens": self.prompt_tokens, 169 | "completion_tokens": self.completion_tokens, 170 | "model_counts": model_counts, 171 | "records": results 172 | } -------------------------------------------------------------------------------- 
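
Note: before the vote, each evaluator pulls the final letter out of every sampled completion with its ANSWER_PATTERN regex. The snippet below illustrates that extraction step in isolation; extract() is a hypothetical stand-in for BaseEvaluator.extract_normal_answer, whose exact behaviour lives in base_evaluator.py.

import re

ANSWER_PATTERN = r"(?i)Answer\s*:\s*([A-E])[.\s\n]?"  # BrainTeaser variant shown above

def extract(text: str, answer_pattern: str = ANSWER_PATTERN) -> str:
    """Return the captured letter, or an empty string when no answer line is found."""
    match = re.search(answer_pattern, text)
    return match.group(1).upper() if match else ""

print(extract("The riddle hinges on the sibling being a sister.\nAnswer: C."))  # -> C
print(extract("I cannot decide."))                                              # -> empty string
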
/evaluate/EmoryNLP/emorynlp.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import time 4 | from collections import Counter 5 | from concurrent.futures import ThreadPoolExecutor, as_completed 6 | from typing import Dict 7 | 8 | from datasets import Dataset, disable_progress_bars 9 | from loguru import logger 10 | from tqdm import tqdm 11 | 12 | from core.inference import GeneratorFactory, GeneratorOutput 13 | from core.routing import BaseRouter 14 | from evaluate.base_evaluator import BaseEvaluator 15 | 16 | disable_progress_bars() 17 | 18 | DATA_DIR = "data/EmoryNLP" 19 | 20 | PROMPT = """Given a conversation history and a current utterance, follow these steps to identify the emotion of the current utterance from the given options. The emotion should be determined based on both the conversation context and the current utterance. 21 | The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCDEFG. Let's think step by step. 22 | 23 | History: 24 | {history} 25 | 26 | Utterance: 27 | {utterance} 28 | 29 | Options: 30 | {options}""" 31 | 32 | ANSWER_PATTERN = r"(?i)Answer\s*:\s*([A-G])[.\s\n]?" 33 | 34 | class EmoryNLPEvaluator(BaseEvaluator): 35 | def __init__(self, max_workers: int = 8, mode: str="test"): 36 | super().__init__(max_workers=max_workers, mode=mode) 37 | self.task = "EmoryNLP" 38 | self.seed = 42 39 | 40 | def load_data(self, split: str): 41 | data = self.load_jsonl(os.path.join(DATA_DIR, f"{split}.json")) 42 | 43 | data = Dataset.from_list(data) 44 | data = data.map(lambda x: self.format_prompt(x)) 45 | 46 | if self.mode == "test": 47 | # split data into train and test 48 | logger.warning(f"Split data into train and test for {self.task}") 49 | split_data = data.train_test_split(test_size=0.3) 50 | train_data = split_data["train"] 51 | data = split_data["test"] 52 | logger.info(f"Calibration data: {len(train_data)}") 53 | logger.info(f"Test data: {len(data)}") 54 | 55 | return data 56 | 57 | def format_prompt(self, item: Dict) -> Dict: 58 | prompt = PROMPT.format( 59 | history = "- "+"\n- ".join(item["history"]), 60 | utterance = item["utterance"], 61 | options = "\n".join([f"{key}. {value}" for key, value in item["candidate"].items()]) 62 | ) 63 | return {"prompt": prompt} 64 | 65 | def extract_raw_answer(self, raw_datas: list[str]) -> list[str]: 66 | return [ 67 | self.extract_normal_answer(text=data, answer_pattern=ANSWER_PATTERN) 68 | for data in raw_datas 69 | ] 70 | 71 | def process_output(self, output: GeneratorOutput): 72 | full_prediction = self.extract_raw_answer(raw_datas=output.raw_output) 73 | prediction = Counter(full_prediction).most_common(1)[0][0] 74 | prediction_stats = self.count_prediction_frequency(predictions=full_prediction) 75 | 76 | return prediction, full_prediction, prediction_stats 77 | 78 | def evaluate(self, index: int, data: dict, router: BaseRouter, generator_config: dict): 79 | answer = data['answer'] 80 | 81 | # step 1. get router, get generator 82 | router_result = router.route(question=data['prompt']) 83 | generator = GeneratorFactory.create_generator( 84 | experts=router_result, generator_config=generator_config 85 | ) # type: ignore 86 | 87 | # step 2. 
generate & update token usage 88 | if generator_config['type'] == 'model_switch': 89 | output: tuple[GeneratorOutput, GeneratorOutput] = generator.generate(question=data['prompt']) 90 | first_output, final_output = output 91 | prediction, full_prediction, prediction_stats = self.process_output(output=first_output) 92 | consistency_rate = prediction_stats[prediction]['frequency'] 93 | if consistency_rate < generator.consistency_rate_threshold: 94 | prediction, full_prediction, prediction_stats = self.process_output(output=final_output) 95 | output = final_output 96 | else: 97 | output = first_output 98 | elif generator_config['type'] == 'fast_slow': 99 | output: GeneratorOutput = generator.generate(question=data['prompt']) 100 | prediction, full_prediction, prediction_stats = self.process_output(output=output) 101 | consistency_rate = prediction_stats[prediction]['frequency'] 102 | if consistency_rate < generator.consistency_rate_threshold: 103 | slow_output = generator.slow_generate(question=data['prompt']) 104 | prediction, full_prediction, prediction_stats = self.process_output(output=slow_output) 105 | if prediction == "": 106 | logger.warning(f"slow_output is empty, use fast_output to replace.") 107 | prediction, full_prediction, prediction_stats = self.process_output(output=output) 108 | else: 109 | output = slow_output 110 | else: 111 | output: GeneratorOutput = generator.generate(question=data['prompt']) 112 | prediction, full_prediction, prediction_stats = self.process_output(output=output) 113 | 114 | # step 3. update token usage 115 | self.update_tokens(prompt_tokens=output.prompt_tokens, completion_tokens=output.completion_tokens) 116 | 117 | # step 4. calculate is_correct 118 | is_correct = answer == prediction 119 | return dict( 120 | index=index, 121 | query=data['prompt'], 122 | origin_query=None, 123 | prediction=prediction, 124 | full_prediction=full_prediction, 125 | prediction_stats=prediction_stats, 126 | raw_output=output.raw_output, 127 | answer=answer, 128 | is_correct=is_correct, 129 | model_name=generator.model 130 | ) 131 | 132 | def evaluate_loop(self, router: BaseRouter, generator_config: dict): 133 | start_time = time.time() 134 | data = self.load_data(split="test") 135 | counter = 0 136 | results = [] 137 | pbar = tqdm(total=len(data), desc=f"Evaluating {self.task} ...") 138 | with ThreadPoolExecutor(max_workers=self.max_workers) as executor: 139 | futures = [ 140 | executor.submit(self.evaluate, index=idx, data=d, router=router, generator_config=generator_config) 141 | for idx, d in enumerate(data) 142 | ] 143 | for future in as_completed(futures): 144 | result = future.result() 145 | results.append(result) 146 | if result['is_correct']: 147 | counter += 1 148 | pbar.update(1) 149 | pbar.close() 150 | 151 | model_counts = self.calculate_model_counts(results=results) 152 | 153 | acc = counter / len(data) 154 | end_time = time.time() 155 | logger.info(f"Task: {self.task}") 156 | logger.info(f"Accuracy: {acc}") 157 | logger.info(f"Time taken: {end_time - start_time} seconds") 158 | logger.info(f"Prompt tokens: {self.prompt_tokens}") 159 | logger.info(f"Completion tokens: {self.completion_tokens}") 160 | 161 | return { 162 | "performance": acc, 163 | "time_taken": end_time - start_time, 164 | "prompt_tokens": self.prompt_tokens, 165 | "completion_tokens": self.completion_tokens, 166 | "model_counts": model_counts, 167 | "records": results 168 | } -------------------------------------------------------------------------------- /evaluate/MELD/meld.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import time 4 | from collections import Counter 5 | from concurrent.futures import ThreadPoolExecutor, as_completed 6 | from typing import Dict 7 | 8 | from datasets import Dataset, disable_progress_bars 9 | from loguru import logger 10 | from tqdm import tqdm 11 | 12 | from core.inference import GeneratorFactory, GeneratorOutput 13 | from core.routing import BaseRouter 14 | from evaluate.base_evaluator import BaseEvaluator 15 | 16 | disable_progress_bars() 17 | 18 | DATA_DIR = "data/MELD" 19 | 20 | PROMPT = """Given a conversation history and a current utterance, follow these steps to identify the emotion of the current utterance from the given options. The emotion should be determined based on both the conversation context and the current utterance. 21 | The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCDEFG. Let's think step by step. 22 | 23 | History: 24 | {history} 25 | 26 | Utterance: 27 | {utterance} 28 | 29 | Options: 30 | {options}""" 31 | 32 | ANSWER_PATTERN = r"(?i)Answer\s*:\s*([A-G])[.\s\n]?" 33 | 34 | class MELDEvaluator(BaseEvaluator): 35 | def __init__(self, max_workers: int = 8, mode: str="test"): 36 | super().__init__(max_workers=max_workers, mode=mode) 37 | self.task = "MELD" 38 | self.seed = 42 39 | 40 | def load_data(self, split: str): 41 | data = self.load_jsonl(os.path.join(DATA_DIR, f"{split}.json")) 42 | 43 | data = Dataset.from_list(data) 44 | data = data.map(lambda x: self.format_prompt(x)) 45 | 46 | if self.mode == "test": 47 | # split data into train and test 48 | logger.warning(f"Split data into train and test for {self.task}") 49 | split_data = data.train_test_split(test_size=0.3) 50 | train_data = split_data["train"] 51 | data = split_data["test"] 52 | logger.info(f"Calibration data: {len(train_data)}") 53 | logger.info(f"Test data: {len(data)}") 54 | 55 | return data 56 | def format_prompt(self, item: Dict) -> Dict: 57 | prompt = PROMPT.format( 58 | history = "- "+"\n- ".join(item["history"]), 59 | utterance = item["utterance"], 60 | options = "\n".join([f"{key}. {value}" for key, value in item["candidate"].items()]) 61 | ) 62 | return {"prompt": prompt} 63 | 64 | def extract_raw_answer(self, raw_datas: list[str]) -> list[str]: 65 | return [ 66 | self.extract_normal_answer(text=data, answer_pattern=ANSWER_PATTERN) 67 | for data in raw_datas 68 | ] 69 | 70 | def process_output(self, output: GeneratorOutput): 71 | full_prediction = self.extract_raw_answer(raw_datas=output.raw_output) 72 | prediction = Counter(full_prediction).most_common(1)[0][0] 73 | prediction_stats = self.count_prediction_frequency(predictions=full_prediction) 74 | 75 | return prediction, full_prediction, prediction_stats 76 | 77 | def evaluate(self, index: int, data: dict, router: BaseRouter, generator_config: dict): 78 | answer = data['answer'] 79 | 80 | # step 1. get router, get generator 81 | router_result = router.route(question=data['prompt']) 82 | generator = GeneratorFactory.create_generator( 83 | experts=router_result, generator_config=generator_config 84 | ) # type: ignore 85 | 86 | # step 2. 
generate & update token usage 87 | if generator_config['type'] == 'model_switch': 88 | output: tuple[GeneratorOutput, GeneratorOutput] = generator.generate(question=data['prompt']) 89 | first_output, final_output = output 90 | prediction, full_prediction, prediction_stats = self.process_output(output=first_output) 91 | consistency_rate = prediction_stats[prediction]['frequency'] 92 | if consistency_rate < generator.consistency_rate_threshold: 93 | prediction, full_prediction, prediction_stats = self.process_output(output=final_output) 94 | output = final_output 95 | else: 96 | output = first_output 97 | elif generator_config['type'] == 'fast_slow': 98 | output: GeneratorOutput = generator.generate(question=data['prompt']) 99 | prediction, full_prediction, prediction_stats = self.process_output(output=output) 100 | consistency_rate = prediction_stats[prediction]['frequency'] 101 | if consistency_rate < generator.consistency_rate_threshold: 102 | slow_output = generator.slow_generate(question=data['prompt']) 103 | prediction, full_prediction, prediction_stats = self.process_output(output=slow_output) 104 | if prediction == "": 105 | logger.warning(f"slow_output is empty, use fast_output to replace.") 106 | prediction, full_prediction, prediction_stats = self.process_output(output=output) 107 | else: 108 | output = slow_output 109 | else: 110 | output: GeneratorOutput = generator.generate(question=data['prompt']) 111 | prediction, full_prediction, prediction_stats = self.process_output(output=output) 112 | 113 | # step 3. update token usage 114 | self.update_tokens(prompt_tokens=output.prompt_tokens, completion_tokens=output.completion_tokens) 115 | 116 | # step 4. calculate is_correct 117 | is_correct = answer == prediction 118 | 119 | return dict( 120 | index=index, 121 | query=data['prompt'], 122 | origin_query=None, 123 | prediction=prediction, 124 | full_prediction=full_prediction, 125 | prediction_stats=prediction_stats, 126 | raw_output=output.raw_output, 127 | answer=answer, 128 | is_correct=is_correct, 129 | model_name=generator.model 130 | ) 131 | 132 | def evaluate_loop(self, router: BaseRouter, generator_config: dict): 133 | start_time = time.time() 134 | data = self.load_data(split="test") 135 | 136 | counter = 0 137 | results = [] 138 | pbar = tqdm(total=len(data), desc=f"Evaluating {self.task} ...") 139 | with ThreadPoolExecutor(max_workers=self.max_workers) as executor: 140 | futures = [ 141 | executor.submit(self.evaluate, index=idx, data=d, router=router, generator_config=generator_config) 142 | for idx, d in enumerate(data) 143 | ] 144 | for future in as_completed(futures): 145 | result = future.result() 146 | results.append(result) 147 | if result['is_correct']: 148 | counter += 1 149 | pbar.update(1) 150 | pbar.close() 151 | 152 | model_counts = self.calculate_model_counts(results=results) 153 | 154 | acc = counter / len(data) 155 | end_time = time.time() 156 | logger.info(f"Task: {self.task}") 157 | logger.info(f"Accuracy: {acc}") 158 | logger.info(f"Time taken: {end_time - start_time} seconds") 159 | logger.info(f"Prompt tokens: {self.prompt_tokens}") 160 | logger.info(f"Completion tokens: {self.completion_tokens}") 161 | 162 | return { 163 | "performance": acc, 164 | "time_taken": end_time - start_time, 165 | "prompt_tokens": self.prompt_tokens, 166 | "completion_tokens": self.completion_tokens, 167 | "model_counts": model_counts, 168 | "records": results 169 | } --------------------------------------------------------------------------------
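
Note: the dialogue-emotion prompts that MELDEvaluator and EmoryNLPEvaluator send are built by format_prompt() exactly as shown above. The sketch below rebuilds one such prompt from an invented record (real records come from data/MELD/test.json); only the item contents are made up, while the template and formatting calls mirror the code.

PROMPT = """Given a conversation history and a current utterance, follow these steps to identify the emotion of the current utterance from the given options. The emotion should be determined based on both the conversation context and the current utterance.
The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCDEFG. Let's think step by step.

History:
{history}

Utterance:
{utterance}

Options:
{options}"""

item = {  # invented example record, for illustration only
    "history": ["Hey, you made it!", "Sorry, traffic was terrible."],
    "utterance": "I honestly thought you had forgotten about me.",
    "candidate": {"A": "neutral", "B": "joy", "C": "sadness", "D": "anger",
                  "E": "surprise", "F": "fear", "G": "disgust"},
}

prompt = PROMPT.format(
    history="- " + "\n- ".join(item["history"]),
    utterance=item["utterance"],
    options="\n".join(f"{key}. {value}" for key, value in item["candidate"].items()),
)
print(prompt)
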