├── Herald ├── pipeline │ ├── __init__.py │ └── run.py ├── service │ ├── __init__.py │ ├── handler │ │ ├── __init__.py │ │ ├── prover_handlerr.py │ │ ├── back_http_handler.py │ │ ├── back_handler.py │ │ └── tran_handler.py │ ├── parallel_http_service.py │ ├── pipeline_service.py │ └── parallel_service.py ├── conf │ ├── __init__.py │ └── config.py ├── data │ ├── example_result │ │ ├── back_trans.json │ │ └── translate.json │ └── example │ │ └── simp_10.jsonl ├── util │ ├── __init__.py │ ├── profiler.py │ ├── http_util.py │ ├── string_util.py │ └── common_util.py └── README.md ├── Realprover ├── experiment │ ├── __init__.py │ ├── score.py │ ├── examples │ │ └── default.toml │ ├── resume.py │ └── run.py ├── herald │ ├── __init__.py │ ├── run.py │ └── pipeline_prover.py ├── manager │ ├── __init__.py │ ├── struct │ │ ├── __init__.py │ │ └── structs.py │ ├── thirdparty │ │ ├── tests │ │ │ └── test_interactive.py │ │ ├── __init__.py │ │ ├── critic.py │ │ ├── lean_search.py │ │ ├── claude.py │ │ ├── verifier.py │ │ ├── generator.py │ │ └── interactive.py │ ├── search │ │ ├── __init__.py │ │ ├── exception.py │ │ ├── beam_search.py │ │ ├── best_first.py │ │ └── mcts_search.py │ ├── manage │ │ ├── __init__.py │ │ ├── model_manage.py │ │ ├── prompt_manage.py │ │ └── proof_parse_manage.py │ └── service │ │ ├── __init__.py │ │ ├── base_service.py │ │ ├── pipeline_main_service.py │ │ └── batch_main_service.py ├── util │ ├── debug_util.py │ ├── __init__.py │ ├── profiler.py │ ├── http_util.py │ ├── log_util.py │ ├── string_util.py │ └── common_util.py ├── conf │ ├── __init__.py │ └── config.py ├── server.py └── README.md ├── LeanSearch-PS-inference ├── tests │ ├── __init__.py │ ├── test_local_run.py │ └── test_request.py ├── conf │ ├── __init__.py │ └── config.py ├── util │ ├── __init__.py │ └── request_util.py ├── worker │ ├── __init__.py │ └── premise_selector.py ├── server.py └── README.md ├── .gitignore ├── LeanSearch-PS ├── examples │ └── train.sh ├── README.md └── build_training_data.py ├── LICENSE └── README.md /Herald/pipeline/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /Realprover/experiment/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /Realprover/herald/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /Realprover/manager/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /LeanSearch-PS-inference/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /Realprover/util/debug_util.py: -------------------------------------------------------------------------------- 1 | from manager.struct import Node 2 | -------------------------------------------------------------------------------- /Realprover/manager/struct/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from .structs import * 4 | -------------------------------------------------------------------------------- 
/Realprover/manager/thirdparty/tests/test_interactive.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | -------------------------------------------------------------------------------- /LeanSearch-PS-inference/conf/__init__.py: -------------------------------------------------------------------------------- 1 | """ This module provides configurations """ -------------------------------------------------------------------------------- /LeanSearch-PS-inference/util/__init__.py: -------------------------------------------------------------------------------- 1 | from .request_util import RequestUtil 2 | -------------------------------------------------------------------------------- /LeanSearch-PS-inference/worker/__init__.py: -------------------------------------------------------------------------------- 1 | from .premise_selector import PremiseSelector 2 | -------------------------------------------------------------------------------- /Realprover/manager/search/__init__.py: -------------------------------------------------------------------------------- 1 | from .beam_search import BeamSearch 2 | from .best_first import BestFirstSearch 3 | from .mcts_search import MCTSSearch 4 | -------------------------------------------------------------------------------- /Realprover/manager/manage/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from .prompt_manage import PromptManage 4 | from .model_manage import ModelManage 5 | from .proof_parse_manage import ProofParseManage 6 | -------------------------------------------------------------------------------- /Herald/service/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .pipeline_service import PipelineService 3 | from .parallel_service import ParallelService 4 | from .parallel_http_service import ParallelHttpService 5 | -------------------------------------------------------------------------------- /Realprover/manager/service/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_service import BaseService 2 | from .batch_main_service import BatchMainService 3 | from .pipeline_main_service import PipelineMainService 4 | -------------------------------------------------------------------------------- /Herald/service/handler/__init__.py: -------------------------------------------------------------------------------- 1 | from .tran_handler import TranHandler 2 | from .back_handler import BackHandler 3 | from .prover_handlerr import ProverHandler 4 | from .back_http_handler import BackHttpHandler 5 | 6 | 7 | -------------------------------------------------------------------------------- /Realprover/manager/thirdparty/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .interactive import Interactive 3 | from .lean_search import LeanSearch 4 | from .claude import Claude 5 | from .generator import TacticGenerator 6 | from .critic import Critic 7 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.DS_Store 2 | __pycache__ 3 | .vscode 4 | .idea 5 | generated/** 6 | experiment/logs/** 7 | outputs/** 8 | figs/** 9 | config.test.json 10 | kto_data/** 11 | **.tar.gz 12 | oooops/** 13 | md/* 14 | temp/* 15 | statistic/* 16 | print_tree.py 17 | 
alg-test-temp.jsonl -------------------------------------------------------------------------------- /Herald/conf/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ################################################################################ 3 | # 4 | # Copyright (c) 2023 Baidu.com, Inc. All Rights Reserved 5 | # 6 | ################################################################################ 7 | """ 8 | This module provide ... 9 | 10 | """ -------------------------------------------------------------------------------- /Realprover/experiment/score.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from manager.manage import ProofParseManage 3 | 4 | result_dir = sys.argv[1] 5 | ProofParseManage.get_stats(result_dir) 6 | ProofParseManage.visualize_all_proof_trees(result_dir, keep_false=False) 7 | ProofParseManage.get_all_correct_proofs(result_dir) 8 | # ProofParseManage.get_demo_data(result_dir) -------------------------------------------------------------------------------- /Realprover/conf/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ################################################################################ 3 | # 4 | # Copyright (c) 2023 Baidu.com, Inc. All Rights Reserved 5 | # 6 | ################################################################################ 7 | """ 8 | This module provide ... 9 | 10 | """ -------------------------------------------------------------------------------- /Herald/data/example_result/back_trans.json: -------------------------------------------------------------------------------- 1 | { 2 | "informal_statement": "Prove that the following is a group, \n\n Given group $(G,*)$. We define $(G^\\mathrm{op},\\circ)$ as follows: \n\n \n\n $\\bullet$ Elements are the set $G$. \n\n\n\n $\\bullet$ Multiplication $\\circ$ given by $g\\circ h:=h*g$. \n\n\n\n $\\bullet$ Identity and inverse same as in $G$.", 3 | "translate_list": [], 4 | "back_trans_list": [] 5 | } -------------------------------------------------------------------------------- /Herald/data/example_result/translate.json: -------------------------------------------------------------------------------- 1 | { 2 | "informal_statement": "Prove that the following is a group, \n\n Given group $(G,*)$. We define $(G^\\mathrm{op},\\circ)$ as follows: \n\n \n\n $\\bullet$ Elements are the set $G$. \n\n\n\n $\\bullet$ Multiplication $\\circ$ given by $g\\circ h:=h*g$. 
\n\n\n\n $\\bullet$ Identity and inverse same as in $G$.", 3 | "translate_list": [], 4 | "back_trans_list": [] 5 | } -------------------------------------------------------------------------------- /LeanSearch-PS-inference/util/request_util.py: -------------------------------------------------------------------------------- 1 | class RequestUtil(object): 2 | @staticmethod 3 | def gen_success_data(data): 4 | return { 5 | 'error': False, 6 | 'msg': '', 7 | 'data': data 8 | } 9 | @staticmethod 10 | def gen_fail_data(message): 11 | return { 12 | 'error': True, 13 | 'msg': message, 14 | 'data': {} 15 | } 16 | -------------------------------------------------------------------------------- /Realprover/manager/manage/model_manage.py: -------------------------------------------------------------------------------- 1 | import conf.config 2 | 3 | 4 | class ModelManage(object): 5 | 6 | @staticmethod 7 | def contain_local(model_list): 8 | """ 9 | 10 | """ 11 | return conf.config.MODEL_TYPE_LOCAL in model_list 12 | 13 | @staticmethod 14 | def contain_gemini(model_list): 15 | """ 16 | 17 | """ 18 | return conf.config.MODEL_TYPE_GEMINI in model_list 19 | -------------------------------------------------------------------------------- /Herald/util/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ################################################################################ 3 | # 4 | # Copyright (c) 2023 Baidu.com, Inc. All Rights Reserved 5 | # 6 | ################################################################################ 7 | """ 8 | This module provide ... 9 | 10 | """ 11 | 12 | from .common_util import CommonUtil 13 | from .http_util import HttpUtil 14 | from .string_util import StringUtil 15 | from .profiler import profiler 16 | -------------------------------------------------------------------------------- /LeanSearch-PS-inference/tests/test_local_run.py: -------------------------------------------------------------------------------- 1 | import conf.config 2 | from worker import PremiseSelector 3 | 4 | def local_run(): 5 | premise_selector = PremiseSelector() 6 | related_theorems = premise_selector.retrieve(conf.config.TEST_QUERY, num=5) 7 | for theorems in related_theorems: 8 | for theorem in theorems: 9 | print(theorem) 10 | print("-" * 40) 11 | 12 | if __name__ == '__main__': 13 | local_run() 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | -------------------------------------------------------------------------------- /LeanSearch-PS-inference/conf/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module provide configurations. 3 | """ 4 | INDEX_PATH = "path/to/faiss/index" 5 | TOKENIZER_PATH = "intfloat/e5-mistral-7b-instruct" 6 | MODEL_PATH = "path/to/embedding/model" 7 | ANSWER_PATH = "path/to/answer/data" 8 | 9 | 10 | EXPIRE_TIME = 5 * 60 11 | 12 | TEST_QUERY = '''G : Type u_1 13 | inst✝ : Group G 14 | a b c : G 15 | ⊢ (fun x => a * x * b = c) (a⁻¹ * c * b⁻¹) ∧ ∀ (y : G), (fun x => a * x * b = c) y → y = a⁻¹ * c * b⁻¹''' 16 | 17 | 18 | -------------------------------------------------------------------------------- /Realprover/util/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ################################################################################ 3 | # 4 | # Copyright (c) 2023 Baidu.com, Inc. 
All Rights Reserved 5 | # 6 | ################################################################################ 7 | """ 8 | This module provides ... 9 | 10 | """ 11 | 12 | from .common_util import CommonUtil 13 | from .profiler import profiler 14 | from .log_util import LogUtil 15 | from .string_util import StringUtil 16 | from .http_util import HttpUtil 17 | -------------------------------------------------------------------------------- /Realprover/manager/thirdparty/critic.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from manager.thirdparty import Claude 3 | 4 | class Critic: 5 | def __init__(self, model='claude'): 6 | if model == 'claude': 7 | self.client = Claude() 8 | else: 9 | raise NotImplementedError 10 | 11 | def get_critics(self, sass: list[tuple[str, str, str]]) -> list[bool]: 12 | results = asyncio.run(self.client.get_claude_critics(sass)) 13 | return [r != "FALSE" for r in results] 14 | 15 | 16 | def get_critic(self, s1: str, t: str, s2: str) -> bool: 17 | return self.get_critics([(s1, t, s2)])[-1] -------------------------------------------------------------------------------- /LeanSearch-PS-inference/tests/test_request.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import conf.config 3 | 4 | REQUEST_URL = 'http://localhost:8080/retrieve_premises' 5 | 6 | def get_param(): 7 | params = { 8 | 'query': conf.config.TEST_QUERY, 9 | 'num': 20 10 | } 11 | return params 12 | 13 | def get_from_http(): 14 | data = get_param() 15 | response = requests.post(REQUEST_URL, json=data) 16 | if response.status_code == 200: 17 | print('Response JSON:', response.json()) 18 | result = response.json()['data'] 19 | print(f"size = {len(result[0])}") 20 | else: 21 | print('Error:', response.json()) 22 | 23 | 24 | if __name__ == '__main__': 25 | get_from_http() 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /Herald/pipeline/run.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from util import profiler, CommonUtil 3 | from service import ParallelHttpService 4 | import argparse 5 | def run_http_parallel(args): 6 | """ 7 | Run in parallel on multiple GPUs; translation goes through the HTTP API 8 | """ 9 | 10 | source_file = args.source_file 11 | result_dir = args.result_dir 12 | 13 | profiler.start("pipeline_http_parallel") 14 | parallel_service = ParallelHttpService(source_file=source_file, result_dir=result_dir, trans_gpus=[0]) 15 | parallel_service.run() 16 | profiler.stop("pipeline_http_parallel") 17 | 18 | if __name__ == '__main__': 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('--source_file', default=None, required=True) 21 | parser.add_argument('--result_dir', default=None, required=True) 22 | args = parser.parse_args() 23 | run_http_parallel(args) -------------------------------------------------------------------------------- /Realprover/experiment/examples/default.toml: -------------------------------------------------------------------------------- 1 | info = "bf_minif2f_test_pass64" 2 | 3 | [model] 4 | prover_model_id = "realprover" 5 | prover_model_path = "/path/to/your/model" 6 | prover_model_params = { temperature = 1.5, max_tokens = 256, top_p = 0.9, logprobs = 1 } 7 | is_incontext = true 8 | template = 'qwen' 9 | use_retrieval = true 10 | 11 | [data] 12 | data_id = "minif2f_test" 13 | data_path = "data/minif2f_test.jsonl" 14 | 15 | [search] 16 | num_samples 
= 64 17 | max_depth = 128 18 | max_calls = 1024 19 | max_retries = 64 20 | max_nodes = 1024 21 | abandon_if_contain = ["sorry", "admit", "apply?"] 22 | 23 | [beam_search_params] 24 | use_beam_search = false 25 | beam_width = 3 26 | 27 | [mcts_params] 28 | use_mcts_search = false 29 | sim_depth = 0 30 | c_puct = 10.0 31 | c_score = 1.0 32 | c_expansion_fail_penalty = 30.0 33 | max_root_expansion = 5 34 | -------------------------------------------------------------------------------- /LeanSearch-PS-inference/server.py: -------------------------------------------------------------------------------- 1 | from flask import Flask, request, jsonify 2 | from worker import PremiseSelector 3 | from util import RequestUtil 4 | 5 | app = Flask(__name__) 6 | 7 | premise_selector = PremiseSelector() 8 | premise_selector.init_model() 9 | 10 | @app.route('/retrieve_premises', methods=['POST']) 11 | def retrieve_premises(): 12 | data = request.get_json() 13 | if not data or 'query' not in data: 14 | return jsonify(RequestUtil.gen_fail_data('Missing param \'query\'')), 400 15 | num = data.get('num', 5) 16 | related_theorems = premise_selector.retrieve(data['query'], num=num) 17 | return jsonify(RequestUtil.gen_success_data(related_theorems)), 200 18 | 19 | # Start the HTTP server 20 | if __name__ == '__main__': 21 | print('start server on 8080') 22 | app.run(host='0.0.0.0', port=8080, debug=False, use_reloader=False) 23 | -------------------------------------------------------------------------------- /Realprover/herald/run.py: -------------------------------------------------------------------------------- 1 | """ 2 | Continue with the step prover on data produced by the Herald pipeline 3 | """ 4 | import sys 5 | import torch 6 | from manager.service import PipelineMainService 7 | import argparse 8 | 9 | def run_herald_stepprover(args): 10 | source_file = args.source_file 11 | result_dir = args.result_dir 12 | gpus = torch.cuda.device_count() 13 | assert gpus >= 1 14 | # Use all available GPUs; the first GPU is reserved for translation by default 15 | gpu_list = list(range(1, gpus)) 16 | print(gpu_list) 17 | pipeline_service = PipelineMainService(source_file=source_file, result_dir=result_dir, gpus_list=gpu_list) 18 | # gpu_list = list(range(torch.cuda.device_count())) 19 | 20 | pipeline_service.run_pipeline_prover() 21 | 22 | if __name__ == '__main__': 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('--source_file', default=None, required=True) 25 | parser.add_argument('--result_dir', default=None, required=True) 26 | args = parser.parse_args() 27 | run_herald_stepprover(args) -------------------------------------------------------------------------------- /Realprover/herald/pipeline_prover.py: -------------------------------------------------------------------------------- 1 | """ 2 | Continue with the step prover on data produced by the Herald pipeline 3 | """ 4 | import torch 5 | 6 | from manager.service import PipelineMainService 7 | from util import CommonUtil, profiler 8 | 9 | 10 | FILE_DIR = "data" 11 | # source_file = f"{FILE_DIR}/example/simp.jsonl" 12 | # result_dir = f"{FILE_DIR}/example_result/simp" 13 | 14 | source_file = f"{FILE_DIR}/example/simp_10.jsonl" 15 | result_dir = f"{FILE_DIR}/example_result/simp_10" 16 | 17 | 18 | 19 | def run_herald_stepprover(): 20 | gpus = torch.cuda.device_count() 21 | assert gpus >= 1 22 | # Use all available GPUs by default; adjust to your own setup 23 | gpu_list = list(range(gpus)) 24 | print(gpu_list) 25 | 26 | pipeline_service = PipelineMainService(source_file=source_file, result_dir=result_dir, gpus_list=gpu_list) 27 | # gpu_list = list(range(torch.cuda.device_count())) 28 | 29 | 
pipeline_service.run_pipeline_prover() 30 | 31 | 32 | if __name__ == '__main__': 33 | profiler.start("run_prover") 34 | run_herald_stepprover() 35 | profiler.stop("run_prover") 36 | -------------------------------------------------------------------------------- /LeanSearch-PS/examples/train.sh: -------------------------------------------------------------------------------- 1 | deepspeed --include localhost:0,1,2,3 --master_port 60000 --module tevatron.retriever.driver.train \ 2 | --deepspeed ./deepspeed/ds_zero3_config.json \ 3 | --output_dir ./checkpoints/output_dir \ 4 | --model_name_or_path ./models/E5-Mistral-7B-Instruct \ 5 | --lora \ 6 | --lora_target_modules q_proj,k_proj,v_proj,o_proj,down_proj,up_proj,gate_proj \ 7 | --save_steps 100 \ 8 | --dataset_name ./datasets/dataset_name \ 9 | --query_prefix 'Given a lean4 context with two states, retrieve Lean4 theorems useful to transfer the state from "state before" and "state after": ' \ 10 | --passage_prefix "Lean4 theorem: " \ 11 | --bf16 \ 12 | --pooling eos \ 13 | --append_eos_token \ 14 | --normalize \ 15 | --temperature 0.01 \ 16 | --per_device_train_batch_size 8 \ 17 | --gradient_checkpointing \ 18 | --train_group_size 16 \ 19 | --learning_rate 2e-5 \ 20 | --query_max_len 128 \ 21 | --passage_max_len 256 \ 22 | --num_train_epochs 1 \ 23 | --logging_steps 10 \ 24 | --overwrite_output_dir \ 25 | --gradient_accumulation_steps 4 26 | -------------------------------------------------------------------------------- /Realprover/manager/thirdparty/lean_search.py: -------------------------------------------------------------------------------- 1 | import time 2 | import conf.config 3 | from util import StringUtil, HttpUtil 4 | 5 | REQUEST_URL = conf.config.API_CONFIG['lean_search'] 6 | 7 | 8 | class LeanSearch: 9 | 10 | @staticmethod 11 | def get_related_theorem(query: str): 12 | data = LeanSearch.get_param(query) 13 | result = HttpUtil.post(url=REQUEST_URL, json=data) 14 | return result['data'][0] # type: ignore 15 | 16 | @staticmethod 17 | def get_param(query: str, num: int=conf.config.NUM_QUERYS): 18 | params = { 19 | 'timestamp': time.time(), 20 | 'query': query, 21 | 'num': num 22 | } 23 | params['sign'] = StringUtil.gen_sign(params) 24 | return params 25 | 26 | @staticmethod 27 | def get_related_theorem_batch(queries: list[str]): 28 | data = [LeanSearch.get_param(q) for q in queries] 29 | result = HttpUtil.post(url=REQUEST_URL, json=data) 30 | return result['data'] # type: ignore 31 | 32 | if __name__ == '__main__': 33 | # TODO: write tests here! 34 | pass 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 FrenzyMath 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Herald/conf/config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ################################################################################ 3 | 4 | ################################################################################ 5 | """ 6 | This module provides configuration. 7 | """ 8 | 9 | AIFORMATH_PATH = "/AI4M" # path to the mounted disk 10 | 11 | MODEL_CONFIG = { 12 | 'trans': "FrenzyMath/Herald_translator", # informal -> formal 13 | 'back_trans': "deepseek-ai/DeepSeek-V3", # formal -> informal 14 | 'compare': "deepseek-ai/DeepSeek-V3" 15 | } 16 | 17 | LEAN_TEST_PATH = "/path/to/lean_test" 18 | DEFAULT_LAKE_PATH = '/path/to/lake' 19 | 20 | 21 | NIM_CONFIG = { 22 | 'url': "https://api.deepseek.com/v1", 23 | 'key': "your api key" 24 | } 25 | 26 | # API_CONFIG = { 27 | # "step_prover": 'Your realprover address' # not used now 28 | # } 29 | 30 | 31 | 32 | # Thread-count limits 33 | THREAD_CONFIG = { 34 | "lean_build": 10, 35 | "same_check": 10, 36 | "proof": 5 37 | } 38 | 39 | TRAN_CONFIG = { 40 | 'sampling_params': dict( 41 | n=8, 42 | max_tokens=1024, 43 | temperature=0.99, 44 | top_p=0.99, 45 | ) 46 | } 47 | 48 | SALT = "ai-for-math" 49 | EXPIRE_TIME = 5 * 60 50 | -------------------------------------------------------------------------------- /LeanSearch-PS-inference/README.md: -------------------------------------------------------------------------------- 1 | # LeanSearch-PS Server 2 | 3 | ## Quickstart 4 | 5 | 1. Install dependencies 6 | 7 | ```bash 8 | pip install transformers==4.52.4 faiss-gpu==1.12 flask 9 | ``` 10 | 11 | 2. Download the model and FAISS index 12 | 13 | - LeanSearch-PS model: 14 | [https://huggingface.co/FrenzyMath/LeanSearch-PS](https://huggingface.co/FrenzyMath/LeanSearch-PS) 15 | 16 | - FAISS index: 17 | [https://huggingface.co/FrenzyMath/LeanSearch-PS-faiss](https://huggingface.co/FrenzyMath/LeanSearch-PS-faiss) 18 | 19 | 3. Configure the server 20 | 21 | Update the configuration file at [`conf/config.py`](conf/config.py): 22 | ```python 23 | INDEX_PATH = "path/to/faiss/index" 24 | TOKENIZER_PATH = "intfloat/e5-mistral-7b-instruct" 25 | MODEL_PATH = "FrenzyMath/LeanSearch-PS" 26 | ANSWER_PATH = "path/to/answer/data" # answer.json in repo https://huggingface.co/FrenzyMath/LeanSearch-PS-faiss 27 | ``` 28 | 29 | 4. Start the LeanSearch-PS server 30 | 31 | ```bash 32 | python server.py 33 | ``` 34 | 35 | 5. 
Test server availability 36 | 37 | ```bash 38 | python -m tests.test_request 39 | ``` 40 | 41 | ## Run locally 42 | 43 | ```bash 44 | python -m tests.test_local_run 45 | ``` 46 | -------------------------------------------------------------------------------- /Realprover/server.py: -------------------------------------------------------------------------------- 1 | from flask import Flask, request, jsonify 2 | from flask_cors import CORS 3 | 4 | from manager.service import BaseService 5 | 6 | from util import StringUtil, LogUtil 7 | 8 | app = Flask(__name__) 9 | log_util = LogUtil() 10 | 11 | base_service = BaseService() 12 | 13 | 14 | @app.route('/step-prover', methods=['POST']) 15 | def prover(): 16 | data = request.get_json() 17 | # check_result, message = StringUtil.check_param_valid(params=data) 18 | # log_util.info(f"check_result = {check_result}, message = {message}") 19 | # if not check_result: 20 | # return jsonify(StringUtil.gen_fail_data(message)), 400 21 | 22 | # Check that the required field was provided 23 | if not data or 'formal_statement' not in data: 24 | log_util.info("Missing param formal_statement") 25 | return jsonify(StringUtil.gen_fail_data('Missing param \'formal_statement\'')), 400 26 | 27 | ret = base_service.single_run(formal_statement=data['formal_statement']) 28 | 29 | return jsonify(StringUtil.gen_success_data(data=ret)), 200 30 | 31 | 32 | @app.route('/test', methods=['GET']) 33 | def test(): 34 | return StringUtil.gen_success_data({'data': 'stepprover is ok'}), 200 35 | 36 | 37 | CORS(app) 38 | # Start the web server 39 | if __name__ == '__main__': 40 | print('start server on 8080') 41 | app.run(host='0.0.0.0', port=8080, debug=False, use_reloader=False) 42 | -------------------------------------------------------------------------------- /LeanSearch-PS/README.md: -------------------------------------------------------------------------------- 1 | # LeanSearch-PS 2 | 3 | We leverage [Tevatron V2.0](https://github.com/texttron/tevatron) for training. 4 | 5 | ## Two-stage Training Pipeline 6 | 7 | ### 1. Building pairwise dataset 8 | 9 | This process constructs a dataset of $(s, t_{\text{pos}})$ pairs, extracted from Mathlib using the Jixia tool. In this context, $s$ refers to a Lean proof state, and $t_{\text{pos}}$ refers to its corresponding theorem. 10 | 11 | Note that in the Tevatron V2.0 training pipeline, the negative samples should be set to 64 random theorems from the dataset. 12 | 13 | ### 2. Initial training 14 | 15 | ``` 16 | sh examples/train.sh 17 | ``` 18 | 19 | ### 3. Building triplets with hard negative data 20 | 21 | This process produces triplets of the form $(s, t_{\text{pos}}, t_{\text{hard-neg}})$, where $t_{\text{hard-neg}}$ refers to a hard negative premise. For the hard negative examples, we first embed all statements and theorems with the initially trained embedding model, and then, for each query, randomly select one passage from its top 30 to top 100 most similar ones as its hard negative premise. Specifically, 22 | - (1) build query data and corpus data; 23 | - (2) embed query data and corpus data; 24 | - (3) search the corpus embeddings with the query embeddings; 25 | - (4) build training data. 26 | 27 | ``` 28 | python build_training_data.py 29 | ``` 30 | 31 | ### 4. 
Hard Negative Enhanced Training 32 | 33 | ``` 34 | sh examples/train.sh 35 | ``` 36 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # REAL-Prover 2 | 3 | ## Introduction 4 | 5 | This repository contains the codebase for **REAL-Prover**, a retrieval-augmented stepwise theorem prover built on Lean 4. 6 | 7 | ## Content 8 | 9 | ### LeanSearch-PS 10 | See the detailed documentation in [LeanSearch-PS/README.md](LeanSearch-PS/README.md). 11 | 12 | 13 | 14 | ### Jixia-interactive 15 | **Jixia**: 16 | A foundational library used in REAL-Prover. 17 | Repository: [https://github.com/frenzymath/jixia](https://github.com/frenzymath/jixia) 18 | 19 | ```bash 20 | git clone https://github.com/frenzymath/jixia 21 | cd jixia 22 | lake build 23 | ``` 24 | 25 | **Interactive**: 26 | Provides interactive tactics and proof state. 27 | Repository: [https://github.com/frenzymath/interactive](https://github.com/frenzymath/interactive) 28 | 29 | ```bash 30 | git clone https://github.com/frenzymath/interactive 31 | cd interactive 32 | lake build 33 | ``` 34 | 35 | 36 | ### REAL-Prover 37 | See [Realprover/README.md](Realprover/README.md) for usage and implementation details. 38 | The trained model is available on Hugging Face: [REAL-Prover](https://huggingface.co/FrenzyMath/REAL-Prover) 39 | 40 | ### FATE-M Dataset 41 | The FATE-M dataset is located at [`Realprover/data/fate_m.jsonl`](Realprover/data/fate_m.jsonl). 42 | 43 | 44 | ### Data 45 | We collected around 50k state–tactic pairs, available in this dataset: 46 | [https://huggingface.co/datasets/FrenzyMath/state_tactic_pairs](https://huggingface.co/datasets/FrenzyMath/state_tactic_pairs) 47 | -------------------------------------------------------------------------------- /Herald/README.md: -------------------------------------------------------------------------------- 1 | # HERALD-AF 2 | ## Configuration 3 | To get started, modify the configuration file at [`conf/config.py`](conf/config.py): 4 | ```python 5 | LEAN_TEST_PATH = "/path/to/lean_test" # Path to the Lean repo: https://github.com/frenzymath/lean_test_v4160 6 | DEFAULT_LAKE_PATH = '/path/to/lake' # Path to Lean's lake directory 7 | 8 | 9 | NIM_CONFIG = { 10 | 'url': "https://api.deepseek.com/v1", 11 | 'key': "your api key" # Your DeepSeek API key 12 | } 13 | ``` 14 | 15 | ## Data format 16 | Each input data should be a JSONL line with two keys: `id` and `informal_statement`. 17 | You can refer to `data/example/simp_10.jsonl` as an example: 18 | ```json 19 | {"id": 5, "informal_statement": "Prove that product of two groups is a group, \n\n Given groups $G,H$. We define $G\\times H$ as follows: \n\n \n\n $\\bullet$ Elements are the set $G\\times H$. \n\n\n\n $\\bullet$ Multiplication given by $(g_1,h_1)(g_2h_2)=(g_1g_2,h_1h_2)$. \n\n\n\n $\\bullet$ Identity given by $(e_G,e_H)$, where $e_G,e_H$ are the identities of $G,H$. \n\n\n\n $\\bullet$ Inverse given by $(g,h)^{-1}=(g^{-1},h^{-1})$."} 20 | {"id": 6, "informal_statement": "Prove that the following is a group, \n\n Given group $(G,*)$. We define $(G^\\mathrm{op},\\circ)$ as follows: \n\n \n\n $\\bullet$ Elements are the set $G$. \n\n\n\n $\\bullet$ Multiplication $\\circ$ given by $g\\circ h:=h*g$. \n\n\n\n $\\bullet$ Identity and inverse same as in $G$."} 21 | ``` 22 | 23 | ## Running HERALD 24 | HERALD uses the DeepSeek API to back-translate formal statements, so only one GPU is required. 
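For reference, per-statement results follow the shape of the JSON files shipped under `data/example_result/` (`translate.json` and `back_trans.json`): the original `informal_statement` plus a `translate_list` of candidate formal statements and a `back_trans_list` of their back-translations. The shipped examples carry empty lists; the entries below are illustrative placeholders only: ```json { "informal_statement": "Prove that the following is a group, ...", "translate_list": ["<candidate formal statements>"], "back_trans_list": ["<back-translated informal statements>"] } ```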
25 | To run: 26 | ```bash 27 | cd Herald 28 | python -m pipeline.run --source_file /path/to/your/data --result_dir /path/to/save/result 29 | ``` -------------------------------------------------------------------------------- /Realprover/manager/thirdparty/claude.py: -------------------------------------------------------------------------------- 1 | from anthropic import AsyncAnthropic 2 | import asyncio 3 | 4 | import conf.config 5 | from manager.manage import PromptManage 6 | 7 | claude_config = conf.config.CLAUDE_CONFIG 8 | 9 | 10 | class Claude: 11 | """ 12 | Generate completions by calling the Claude API 13 | """ 14 | def __init__(self): 15 | self.async_client = AsyncAnthropic( 16 | base_url=claude_config['base_url'], 17 | api_key=claude_config['api_key'], 18 | ) 19 | 20 | @staticmethod 21 | async def get_single_response(async_client, claude_prompt): 22 | response = await async_client.messages.create( 23 | model="claude-3-5-sonnet-latest", 24 | max_tokens=100, 25 | temperature=0.9, 26 | messages=[{"role": "user", "content": claude_prompt}] 27 | ) 28 | return response.content[0].text if isinstance(response.content, list) else response.content # type: ignore 29 | 30 | 31 | async def get_claude_tactics(self, state, related_theorems, num_samples=16): 32 | claude_prompt = PromptManage.build_claude_prompt_str(state, related_theorems) 33 | tasks = [Claude.get_single_response(self.async_client, claude_prompt) 34 | for _ in range(num_samples)] 35 | responses = await asyncio.gather(*tasks) 36 | return responses 37 | 38 | 39 | 40 | async def get_claude_critics(self, state_tactic_states: list[tuple[str, str, str]]): 41 | tasks = [Claude.get_single_response(self.async_client, PromptManage.build_claude_critic_str(s1, t, s2)) 42 | for (s1, t, s2) in state_tactic_states] 43 | responses = await asyncio.gather(*tasks) 44 | return responses 45 | -------------------------------------------------------------------------------- /Herald/util/profiler.py: -------------------------------------------------------------------------------- 1 | 2 | import time 3 | import threading 4 | from collections import defaultdict 5 | 6 | 7 | class Profile: 8 | def __init__(self): 9 | self.execution_times = defaultdict(list) 10 | self.local = threading.local() 11 | 12 | def start(self, name): 13 | """Start timing a code block.""" 14 | if not hasattr(self.local, 'stack'): 15 | self.local.stack = [] 16 | 17 | start_time = time.time() 18 | self.local.stack.append((name, start_time)) 19 | 20 | def stop(self, name): 21 | """Stop timing a code block.""" 22 | end_time = time.time() 23 | 24 | if not hasattr(self.local, 'stack') or not self.local.stack: 25 | raise ValueError(f"No active timing block named '{name}' to stop.") 26 | 27 | last_name, start_time = self.local.stack.pop() 28 | if last_name != name: 29 | raise ValueError(f"Mismatched timing blocks: expected '{last_name}' but got '{name}'.") 30 | 31 | elapsed_time = end_time - start_time 32 | self.execution_times[name].append(elapsed_time) 33 | print(f'stop {name} with {elapsed_time}', flush=True) 34 | 35 | def get_execution_times(self): 36 | """Get the total execution times for each named block.""" 37 | times = {name: sum(records) for name, records in self.execution_times.items()} 38 | return times 39 | 40 | def print_execution_times(self): 41 | """Print the execution times for all recorded blocks.""" 42 | for name, records in self.execution_times.items(): 43 | total_time = sum(records) 44 | print(f"Block '{name}' executed in total: {total_time:.6f} seconds over {len(records)} runs.") 45 | 46 | 47 | profiler = 
Profile() 48 | -------------------------------------------------------------------------------- /Realprover/util/profiler.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | 4 | import threading 5 | from collections import defaultdict 6 | 7 | 8 | class Profile: 9 | def __init__(self): 10 | self.execution_times = defaultdict(list) 11 | self.local = threading.local() 12 | 13 | def start(self, name): 14 | """Start timing a code block.""" 15 | if not hasattr(self.local, 'stack'): 16 | self.local.stack = [] 17 | 18 | start_time = time.time() 19 | self.local.stack.append((name, start_time)) 20 | 21 | def stop(self, name): 22 | """Stop timing a code block.""" 23 | end_time = time.time() 24 | 25 | if not hasattr(self.local, 'stack') or not self.local.stack: 26 | raise ValueError(f"No active timing block named '{name}' to stop.") 27 | 28 | last_name, start_time = self.local.stack.pop() 29 | if last_name != name: 30 | raise ValueError(f"Mismatched timing blocks: expected '{last_name}' but got '{name}'.") 31 | 32 | elapsed_time = end_time - start_time 33 | self.execution_times[name].append(elapsed_time) 34 | print(f'stop {name} with {elapsed_time}', flush=True) 35 | 36 | def get_execution_times(self): 37 | """Get the total execution times for each named block.""" 38 | times = {name: sum(records) for name, records in self.execution_times.items()} 39 | return times 40 | 41 | def print_execution_times(self): 42 | """Print the execution times for all recorded blocks.""" 43 | for name, records in self.execution_times.items(): 44 | total_time = sum(records) 45 | print(f"Block '{name}' executed in total: {total_time:.6f} seconds over {len(records)} runs.") 46 | 47 | 48 | profiler = Profile() 49 | -------------------------------------------------------------------------------- /Herald/util/http_util.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ################################################################################ 3 | 4 | ################################################################################ 5 | import requests 6 | 7 | 8 | class HttpUtil(object): 9 | 10 | @staticmethod 11 | def get(url, params=None, headers=None): 12 | """ 13 | Send a GET request 14 | 15 | :param url: request URL 16 | :param params: query parameters (dict) 17 | :param headers: request headers (dict) 18 | :return: parsed JSON if the response body is JSON, otherwise the response object 19 | """ 20 | try: 21 | response = requests.get(url, params=params, headers=headers) 22 | response.raise_for_status() # raise on HTTP error status 23 | return response.json() if response.headers.get('Content-Type') == 'application/json' else response 24 | except requests.exceptions.HTTPError as err: 25 | print(f"HTTP error occurred: {err}") 26 | except Exception as err: 27 | print(f"An error occurred: {err}") 28 | 29 | @staticmethod 30 | def post(url, data=None, json=None, headers=None): 31 | """ 32 | Send a POST request 33 | 34 | :param url: request URL 35 | :param data: form data (dict) 36 | :param json: JSON payload (dict) 37 | :param headers: request headers (dict) 38 | :return: parsed JSON if the response body is JSON, otherwise the response object 39 | """ 40 | try: 41 | response = requests.post(url, data=data, json=json, headers=headers, timeout=30) 42 | response.raise_for_status() # raise on HTTP error status 43 | return response.json() if response.headers.get('Content-Type') == 'application/json' else response 44 | except requests.exceptions.HTTPError as err: 45 | print(f"HTTP error occurred: {err}") 46 | except Exception as err: 47 | print(f"An error occurred: {err}") 48 | 49 | 50 | 
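51 | 52 | # Usage sketch: the URL and payload below are illustrative placeholders 53 | # (no such endpoint is assumed to exist); post() returns parsed JSON, 54 | # the raw response object, or None if the request failed. 55 | if __name__ == '__main__': 56 |     demo = HttpUtil.post('http://localhost:8080/test', json={'ping': 'pong'}) 57 |     print(demo)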
-------------------------------------------------------------------------------- /Realprover/util/http_util.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ################################################################################ 3 | 4 | ################################################################################ 5 | import requests 6 | 7 | 8 | class HttpUtil(object): 9 | 10 | @staticmethod 11 | def get(url, params=None, headers=None): 12 | """ 13 | Send a GET request 14 | 15 | :param url: request URL 16 | :param params: query parameters (dict) 17 | :param headers: request headers (dict) 18 | :return: parsed JSON if the response body is JSON, otherwise the response object 19 | """ 20 | try: 21 | response = requests.get(url, params=params, headers=headers) 22 | response.raise_for_status() # raise on HTTP error status 23 | return response.json() if response.headers.get('Content-Type') == 'application/json' else response 24 | except requests.exceptions.HTTPError as err: 25 | print(f"HTTP error occurred: {err}") 26 | except Exception as err: 27 | print(f"An error occurred: {err}") 28 | 29 | @staticmethod 30 | def post(url, data=None, json=None, headers=None): 31 | """ 32 | Send a POST request 33 | 34 | :param url: request URL 35 | :param data: form data (dict) 36 | :param json: JSON payload (dict) 37 | :param headers: request headers (dict) 38 | :return: parsed JSON if the response body is JSON, otherwise the response object 39 | """ 40 | try: 41 | response = requests.post(url, data=data, json=json, headers=headers, timeout=30) 42 | response.raise_for_status() # raise on HTTP error status 43 | return response.json() if response.headers.get('Content-Type') == 'application/json' else response 44 | except requests.exceptions.HTTPError as err: 45 | print(f"HTTP error occurred: {err}") 46 | except Exception as err: 47 | print(f"An error occurred: {err}") 48 | 49 | 50 | -------------------------------------------------------------------------------- /Realprover/util/log_util.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | 5 | 6 | 7 | 8 | class LogUtil: 9 | _instance = None # class attribute holding the singleton instance 10 | 11 | def __new__(cls, *args, **kwargs): 12 | """Ensure only a single instance is created.""" 13 | if not cls._instance: 14 | # create the unique instance 15 | cls._instance = super(LogUtil, cls).__new__(cls) 16 | return cls._instance 17 | 18 | def __init__(self, log_file='logs/app.log', log_level=logging.DEBUG, log_format=None): 19 | """Initialize the logging utility.""" 20 | if hasattr(self, '_initialized'): # avoid re-initialization 21 | return 22 | 23 | # mark initialization as complete 24 | self._initialized = True 25 | 26 | # default log format 27 | if log_format is None: 28 | log_format = '%(asctime)s - %(levelname)s - %(message)s' 29 | 30 | self.logger = logging.getLogger(__name__) 31 | self.logger.setLevel(log_level) 32 | 33 | # build the log output formatter 34 | formatter = logging.Formatter(log_format) 35 | 36 | # create the file handler 37 | if not os.path.exists(os.path.dirname(log_file)): 38 | os.makedirs(os.path.dirname(log_file)) 39 | 40 | file_handler = logging.FileHandler(log_file, mode='a', encoding='utf-8') 41 | file_handler.setFormatter(formatter) 42 | 43 | # create the console handler 44 | console_handler = logging.StreamHandler() 45 | console_handler.setFormatter(formatter) 46 | 47 | # attach the handlers to the logger 48 | self.logger.addHandler(file_handler) 49 | self.logger.addHandler(console_handler) 50 | 51 | def debug(self, message): 52 | """Log a DEBUG-level message.""" 53 | self.logger.debug(message) 54 | 55 | def info(self, message): 56 | """Log an INFO-level message.""" 57 | self.logger.info(message) 58 | 59 | def warning(self, message): 60 | """Log a WARNING-level message.""" 61 | self.logger.warning(message) 62 | 63 | def error(self, 
message): 64 | """Log an ERROR-level message.""" 65 | self.logger.error(message) 66 | 67 | def critical(self, message): 68 | """Log a CRITICAL-level message.""" 69 | self.logger.critical(message) 70 | 71 | def exception(self, message): 72 | """Log an exception with its traceback.""" 73 | self.logger.exception(message) 74 | 75 | 76 | -------------------------------------------------------------------------------- /Realprover/conf/config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ################################################################################ 3 | 4 | ################################################################################ 5 | """ 6 | This module provides configuration. 7 | """ 8 | 9 | 10 | # Supported model types 11 | MODEL_TYPE_LOCAL = 'local' 12 | MODEL_TYPE_GEMINI = 'gemini' 13 | MODEL_TYPE_CLAUDE = 'claude' 14 | 15 | # Default list of models to use 16 | DEFAULT_MODEL_LIST = [MODEL_TYPE_LOCAL] 17 | # DEFAULT_MODEL_LIST = [MODEL_TYPE_GEMINI] 18 | 19 | AIFORMATH_PATH = "/volume/ai4math" 20 | 21 | PROVER_MODEL_ID = 'realprover' 22 | PROVER_MODEL_PATH = '/path/to/your/model' 23 | LEAN_TEST_PATH = '/path/to/lean/environment' # Path to repo https://github.com/frenzymath/lean_test_v4160 24 | LEAN_ENV_PATH = "/path/to/lean/dir" # path to bin, e.g. ~/.elan/bin 25 | 26 | # NUM_SAMPLES = 64 27 | # MAX_DEPTH = 128 28 | # MAX_NODES = 1024 29 | # MAX_CALLS = 1024 30 | 31 | NUM_SAMPLES = 16 32 | MAX_DEPTH = 32 33 | MAX_NODES = 64 34 | MAX_CALLS = 512 35 | 36 | MAX_RETRIES = 1 37 | USE_BEAM_SEARCH = False 38 | USE_MCTS_SEARCH = False 39 | BEAM_WIDTH = 3 40 | IS_INCONTEXT = False 41 | USE_RETRIEVAL = True 42 | 43 | SIM_DEPTH = 5 44 | C_PUCT = 1.0 45 | C_SCORE = 1.0 46 | C_EXPANSION_FAIL_PENALTY = 30.0 47 | MAX_ROOT_EXPANSION = 5 48 | 49 | # params for the model, other than n 50 | PROVER_MODEL_PARAMS = { 51 | "temperature": 1.5, 52 | "max_tokens": 256, 53 | "top_p": 0.9, 54 | "logprobs": 1 55 | } 56 | 57 | API_CONFIG = { 58 | "lean_search": 'address to leansearch-ps' 59 | } 60 | NUM_QUERYS = 10 61 | 62 | 63 | CLAUDE_CONFIG = { 64 | 'base_url': 'url path', 65 | 'api_key': 'your api key', 66 | } 67 | 68 | OTHER_MODELS = ["gemini", "claude"] 69 | 70 | # Invalid tactics to filter out 71 | ABANDON_IF_CONTAIN = ["sorry", "admit", "apply?"] 72 | 73 | # Configuration used by interactive 74 | from pathlib import Path 75 | base_path = Path(__file__).parent.parent 76 | analyzer_path = Path(base_path, "../jixia/.lake/build/bin/jixia") 77 | interactive_path = Path(base_path, "../interactive") 78 | build_path = Path(base_path, "../interactive/.lake/build") 79 | 80 | 81 | 82 | # API authentication parameters 83 | SALT = "ai-for-math" 84 | EXPIRE_TIME = 5 * 60 85 | 86 | # DATA_ID = 'alg-test-v1-20' 87 | # DATA_PATH = 'data/example/alg-test-v1.jsonl' 88 | -------------------------------------------------------------------------------- /Realprover/README.md: -------------------------------------------------------------------------------- 1 | # REAL-Prover 2 | 3 | ## Dependencies and Usage 4 | 5 | Across all our projects, we use `leanprover/lean:v4.16.0`. To run this project, you need to have `interactive` and `jixia` installed and built. You also need a Lean 4 project as a working environment (referred to as a `space`). You can obtain these dependencies as follows: 6 | 7 | - jixia: [jixia](https://github.com/frenzymath/jixia). Clone this repo in the parental directory and run `lake build`. 8 | - interactive: [interactive](https://github.com/frenzymath/interactive). Clone this repo in the parental directory and run `lake build`. 
9 | - space: [lean_test_v4160](https://github.com/frenzymath/lean_test_v4160). Clone this repo in the parental directory and run `lake build`. 10 | 11 | The recommended directory organization is 12 | ``` 13 | - root/ 14 | - jixia/ 15 | - ... 16 | - interactive/ 17 | - .lake/build/ 18 | - ... 19 | - lean_test_v4160/ 20 | - .lake/build/ 21 | - ... 22 | - Realprover/ 23 | ``` 24 | 25 | ## Configuration 26 | All system configuration now lies in `conf/config.py`. 27 | 28 | ## Running REAL-Prover 29 | 30 | ### 1. Running the HERALD Pipeline 31 | To process data generated by HERALD, run: 32 | 33 | ```bash 34 | cd Realprover 35 | python -m herald.run --source_file /path/to/herald/informal/data --result_dir /path/to/herald/result/dir 36 | ``` 37 | 38 | - `--source_file`: Path to the input JSONL file (informal statements). 39 | 40 | - `--result_dir`: Directory to save the processed output. 41 | 42 | 43 | > **Note:** Ensure these paths are consistent with those used in `Herald/pipeline/run.py`. 44 | 45 | --- 46 | 47 | ### 2. Running Experiments 48 | 49 | #### Input Data Format 50 | Each input record should be a JSONL line with two keys: `id` and `formal_statement`. 51 | See `data/minif2f_test.jsonl` for examples. 52 | 53 | #### Start the LeanSearch-PS-inference server 54 | See the detailed documentation in [LeanSearch-PS-inference/README.md](../LeanSearch-PS-inference/README.md). 55 | 56 | 57 | #### Running 58 | 59 | Create a runtime configuration file (in TOML format) in `./experiment/examples/`, then run: 60 | 61 | ```bash 62 | cd Realprover 63 | python -m experiment.run /path/to/config.toml 64 | ``` 65 | 66 | --- 67 | -------------------------------------------------------------------------------- /Herald/service/handler/prover_handlerr.py: -------------------------------------------------------------------------------- 1 | import time 2 | import aiohttp 3 | import asyncio 4 | import conf.config 5 | from util import HttpUtil, StringUtil 6 | 7 | 8 | class ProverHandler(object): 9 | """ 10 | Step prover: invoked through the HTTP API 11 | """ 12 | 13 | def __init__(self): 14 | """ 15 | Initialize the request URL from the configuration. 16 | """ 17 | self.request_url = conf.config.API_CONFIG['step_prover'] 18 | 19 | def batch_gen_proof(self, formal_statement_list): 20 | return asyncio.run(self._batch_run_request(formal_statement_list)) 21 | 22 | async def _batch_run_request(self, formal_statement_list): 23 | """ 24 | Send all requests concurrently and collect the results. 25 | """ 26 | ret = {} 27 | # Create an aiohttp client session 28 | 29 | max_workers = min(len(formal_statement_list), conf.config.THREAD_CONFIG['proof']) 30 | max_workers = 1 # NOTE: overrides the computed value above 31 | connector = aiohttp.TCPConnector(limit=max_workers) # cap the number of concurrent connections 32 | async with aiohttp.ClientSession(connector=connector) as session: 33 | # Build the task list 34 | tasks = [self._send_post_request(session, data) for data in formal_statement_list] 35 | 36 | # Run all tasks concurrently and collect the results 37 | results = await asyncio.gather(*tasks) 38 | 39 | # Report the results 40 | for i, result in enumerate(results): 41 | print(f"Request {i + 1} result: {result}") 42 | ret[i] = result['data'] if isinstance(result, dict) else result 43 | return ret 44 | 45 | async def _send_post_request(self, session, data): 46 | try: 47 | json_data = self._gen_request_param(data) 48 | async with session.post(self.request_url, json=json_data) as response: 49 | return await response.json() 50 | except Exception as e: 51 | return f"Error: {e}" 52 | 53 | def get_one_prover_result(self, formal_statement): 54 | """ 55 | POST a single formal statement to the prover service. 56 | """ 57 | data = self._gen_request_param(formal_statement) 58 | res_json = HttpUtil.post(url=self.request_url, json=data) 59 | 60 | return res_json 61 | 62 
| def _gen_request_param(self, formal_statement): 63 | params = { 64 | 'timestamp': time.time(), 65 | 'formal_statement': formal_statement, 66 | } 67 | params['sign'] = StringUtil.gen_sign(params) 68 | return params 69 | -------------------------------------------------------------------------------- /Realprover/manager/search/exception.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | import os 4 | from manager.thirdparty import Interactive, TacticGenerator 5 | from util import CommonUtil, profiler 6 | import conf.config 7 | 8 | 9 | class SearchError(Exception): 10 | def __init__(self, message="An error occurred", error_data=None, error_type=None): 11 | super().__init__(message) 12 | self.error_data = error_data 13 | self.error_type = error_type 14 | 15 | def error_logging(logging_path: Path, id: str, statement: str, tactic_sid_record: list[dict]): 16 | record = {"statement": statement, 17 | "tactic_sid_record": tactic_sid_record} 18 | CommonUtil.write_to_json_file(logging_path, record) 19 | 20 | 21 | # Run from experiment/run.py 22 | def single_error_test(record_path: Path): 23 | lean_env = Path("/root/.elan/bin") 24 | os.environ["PATH"] += ":" + str(lean_env) 25 | root = Path("../lean_test_v4130") 26 | interactive = Interactive(root, Path("Header.lean"), conf.config.ABANDON_IF_CONTAIN) 27 | test_file = root / "TestOne.lean" 28 | record = CommonUtil.load_json(record_path) 29 | statement = record["statement"] 30 | tactic_list = record["tactic_sid_record"] 31 | with test_file.open("w") as fp: 32 | fp.write(statement) 33 | interactive.open_file(test_file, [None]) 34 | decl = interactive.get_next_problem() 35 | sid = 0 36 | for tactic_sid in tactic_list: 37 | try: 38 | sid = interactive.run_tactic(tactic_sid["sid"], tactic_sid["tactic"]) 39 | except RuntimeError: 40 | print("PARENT_SID:", tactic_sid["sid"], "NEW_SID:", sid, "TACTIC:", tactic_sid["tactic"], "fail") 41 | pass 42 | except Exception as e: 43 | print("ERROR_SID:", tactic_sid["sid"], "ERROR_TACTIC:", tactic_sid["tactic"]) 44 | raise e 45 | else: 46 | print("PARENT_SID:", tactic_sid["sid"], "NEW_SID:", sid, "TACTIC:", tactic_sid["tactic"], "success") 47 | tactic_sid["new_sid"] = sid 48 | state = interactive.get_state(sid) 49 | if not state: 50 | print("success") 51 | break 52 | for goal in state: 53 | print(goal.pretty) 54 | print("----------------------------------------------------") 55 | 56 | sid = interactive.give_up(0) 57 | interactive.commit(sid) 58 | decl = interactive.get_next_problem() -------------------------------------------------------------------------------- /Herald/data/example/simp_10.jsonl: -------------------------------------------------------------------------------- 1 | {"id": 5, "informal_statement": "Prove that product of two groups is a group, \n\n Given groups $G,H$. We define $G\\times H$ as follows: \n\n \n\n $\\bullet$ Elements are the set $G\\times H$. \n\n\n\n $\\bullet$ Multiplication given by $(g_1,h_1)(g_2h_2)=(g_1g_2,h_1h_2)$. \n\n\n\n $\\bullet$ Identity given by $(e_G,e_H)$, where $e_G,e_H$ are the identities of $G,H$. \n\n\n\n $\\bullet$ Inverse given by $(g,h)^{-1}=(g^{-1},h^{-1})$."} 2 | {"id": 6, "informal_statement": "Prove that the following is a group, \n\n Given group $(G,*)$. We define $(G^\\mathrm{op},\\circ)$ as follows: \n\n \n\n $\\bullet$ Elements are the set $G$. \n\n\n\n $\\bullet$ Multiplication $\\circ$ given by $g\\circ h:=h*g$. 
\n\n\n\n $\\bullet$ Identity and inverse same as in $G$."} 3 | {"id": 38, "informal_statement": "Use induction to show that in the group $\\mathbb Z$ under addition, the $n$-th power of an element $k\\in\\mathbb Z$ is the usual multiplication $nk$."} 4 | {"id": 45, "informal_statement": "Basic property of order of a group element. \n\n\n\nSuppose $g\\in G$ with an positive integer $k_0\\in\\mathbb Z_+$, such that $g^{k_0}=e$ and $k_0$ is {\\bf the smallest positive integer satisfying this equation}. Prove that if positive integer $k\\in \\mathbb Z_+$ such that $g^k=e$. Then $k$ is a multiple of $k_0$. \n\n\n\nWe denote this smallest number $k_0:=\\operatorname{ord}(g)$."} 5 | {"id": 47, "informal_statement": "Prove that in the multiplicative group $\\mathbb Q^\\times$, for $g\\in\\mathbb Q^\\times\\backslash\\{-1,1\\}$, $g$ has infinite order, in other words, positive integer $k$, we have $g^k\\ne 0$."} 6 | {"id": 56, "informal_statement": "Let $a,b$ be elements of a group $G$. If $a$ has finite order, show that $\\operatorname{ord}(ab)=\\operatorname{ord}(ba)$. [HINT: if\n\n\t\\[\n\n\t(b a)^{n}=\\underbrace{b a b a \\cdots b a}_{x} a=e\n\n\t\\]\n\n\tthen $a$ is the inverse of $x$. Thus, $a x=e$.\n\n\n\n Another prove: use $\\operatorname{ord}(ab)=\\operatorname{ord}(b\\cdot ab\\cdot b^{-1})$. \n\n ]"} 7 | {"id": 65, "informal_statement": "Given integer $n\\in \\mathbb{Z}_{+}$. Prove that a group $G$ of order $n$ is a cyclic group if and only if $G$ contains an element of order $n$."} 8 | {"id": 69, "informal_statement": "Let $G$ and $H$ be groups. Prove that if $G \\times H$ is a cyclic group, then $G$ and $H$ are both cyclic."} 9 | {"id": 70, "informal_statement": "Let $G$ and $H$ be groups. Prove that if $G \\times H$ is a cyclic group, then the order of $G$ and $H$ are coprime."} 10 | {"id": 88, "informal_statement": "Let $G$ be a group, $H$ be a subset of $G$. If for any $a, b \\in H$, $ab^{-1} \\in H$, then $H$ is a subgroup of $G$."} -------------------------------------------------------------------------------- /Realprover/manager/thirdparty/verifier.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import tempfile 4 | import subprocess 5 | import random 6 | import datetime 7 | import os 8 | import pathlib 9 | import shutil 10 | 11 | DEFAULT_LAKE_PATH = "~/.elan/bin/lake" 12 | DEFAULT_LEAN_WORKSPACE = "../lean_test_4160" 13 | 14 | def verify_proof(full_code: str, 15 | lake_path: str = DEFAULT_LAKE_PATH, 16 | lean_workspace: str = DEFAULT_LEAN_WORKSPACE, 17 | timeout: int = 300) -> bool: 18 | """Verify if the proof code compiles successfully in Lean. 
 19 | 
 20 |     Args:
 21 |         full_code: The complete proof code (statement plus tactics) to verify
 22 |         lake_path: Path to lake executable
 23 |         lean_workspace: Path to Lean workspace
 24 |         timeout: Maximum verification time in seconds
 25 | 
 26 |     Returns:
 27 |         True if verification succeeds, False otherwise
 28 |     """
 29 | 
 30 |     # full_code = formal_statement.strip() + code
 31 | 
 32 |     command = {"cmd": full_code, "allTactics": False, "ast": False,
 33 |                "tactics": False, "premises": False}
 34 |     message_str = json.dumps(command, ensure_ascii=False)
 35 | 
 36 |     try:
 37 |         with tempfile.TemporaryFile(mode="w+", encoding="utf-8") as temp_file:
 38 |             temp_file.write(message_str + "\r\n\r\n")
 39 |             temp_file.seek(0)
 40 |             outputs = subprocess.run(
 41 |                 [lake_path, "exe", "repl"],
 42 |                 stdin=temp_file,
 43 |                 capture_output=True,
 44 |                 text=True,
 45 |                 cwd=lean_workspace,
 46 |                 timeout=timeout,
 47 |             )
 48 |         result = json.loads(outputs.stdout)
 49 |         result = {
 50 |             "sorries": result.get("sorries", []),
 51 |             "errors": [m for m in result.get("messages", [])
 52 |                        if m["severity"] == "error" or "sorry" in m["data"]]
 53 |         }
 54 |         return not result["errors"] and not result["sorries"]
 55 |     except Exception as e:
 56 |         logging.error(f"Verification failed: {str(e)}")
 57 |         return False
 58 | 
 59 | if __name__ == "__main__":
 60 |     json_path = "/volume/ai4math/users/szj/realprover/stepprover-v2/experiment/logs/2025-05-07/bf2_alge_annot_qwen_nrag_sft_T_1_5_pass64-alg_test_v2-64-1024-0853/generated/1110/1110_43.json"
 61 |     with open(json_path, 'r') as f:
 62 |         data = json.load(f)
 63 |     proof_data = data["formal_proof"]
 64 |     print(proof_data)
 65 |     res = verify_proof(proof_data)
 66 |     print(res)
--------------------------------------------------------------------------------
/Herald/util/string_util.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | ################################################################################
 3 | """
 4 | This module provides string processing services.
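A minimal signing round-trip sketch (SALT and EXPIRE_TIME are read from
conf.config; the values involved are hypothetical):

    >>> params = {'timestamp': time.time()}
    >>> params['sign'] = StringUtil.gen_sign(params)
    >>> StringUtil.check_param_valid(params)
    (True, '')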
5 | 6 | """ 7 | import hashlib 8 | import random 9 | import time 10 | 11 | import conf.config 12 | 13 | 14 | class StringUtil(object): 15 | """ 16 | 字符串处理有关 17 | """ 18 | 19 | @staticmethod 20 | def check_param_valid(params): 21 | """ 22 | 参数校验 23 | :param params: 24 | :return: 25 | """ 26 | if not params or 'timestamp' not in params: 27 | return False, "Missing param 'timestamp'" 28 | 29 | if not StringUtil.check_timestamp_valid(params['timestamp']): 30 | return False, "timestamp timeout" 31 | 32 | if not StringUtil.check_api_token(params=params): 33 | return False, "sign check fail" 34 | 35 | return True, '' 36 | 37 | 38 | 39 | @staticmethod 40 | def check_api_token(params): 41 | """ 42 | 参数校验 43 | :param params: 44 | :return: 45 | """ 46 | check_str = f"{conf.config.SALT}-{params['timestamp']}" 47 | return StringUtil.get_str_md5(check_str) == params['sign'] 48 | 49 | @staticmethod 50 | def check_timestamp_valid(timestamp): 51 | """ 52 | 校验timestamp是否超时 53 | :param timestamp: 54 | :return: 55 | """ 56 | return (time.time() - int(timestamp)) < conf.config.EXPIRE_TIME 57 | 58 | @staticmethod 59 | def gen_sign(params): 60 | """ 61 | 参数校验 62 | :param params: 63 | :return: 64 | """ 65 | check_str = f"{conf.config.SALT}-{params['timestamp']}" 66 | return StringUtil.get_str_md5(check_str) 67 | 68 | @staticmethod 69 | def get_str_md5(ustr): 70 | """ 71 | 获取字符串md5 72 | :param ustr: 73 | :return: 74 | """ 75 | m2 = hashlib.md5() 76 | m2.update(ustr.encode('utf-8')) 77 | return m2.hexdigest() 78 | 79 | @staticmethod 80 | def generate_shortcut(): 81 | """ 82 | 83 | :return: 84 | """ 85 | temp_str = StringUtil.gen_random_str(16) 86 | temp_md5 = StringUtil.get_str_md5(temp_str) 87 | return temp_md5[0:8] 88 | 89 | @staticmethod 90 | def gen_random_str(str_len=6): 91 | """ 92 | 生成随机字符串,最长不超过62位 93 | :return: 94 | """ 95 | if str_len > 62: 96 | str_len = 62 97 | return ''.join(random.sample('abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789', str_len)) 98 | 99 | @staticmethod 100 | def gen_success_data(data): 101 | return { 102 | 'error': False, 103 | 'msg': '', 104 | 'data': data 105 | } 106 | 107 | @staticmethod 108 | def gen_fail_data(message): 109 | return { 110 | 'error': True, 111 | 'msg': message, 112 | 'data': {} 113 | } 114 | -------------------------------------------------------------------------------- /Realprover/util/string_util.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ################################################################################ 3 | # 4 | # Copyright (c) 2018 Baidu.com, Inc. All Rights Reserved 5 | # 6 | ################################################################################ 7 | """ 8 | This module provide string process manager. 
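For example, get_str_md5 returns the hex MD5 digest of a UTF-8 string
(a small sketch; this is the well-known MD5 digest of 'abc'):

    >>> StringUtil.get_str_md5('abc')
    '900150983cd24fb0d6963f7d28e17f72'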
9 | 10 | """ 11 | import os 12 | import hashlib 13 | import random 14 | import time 15 | 16 | import conf.config 17 | 18 | 19 | class StringUtil(object): 20 | """ 21 | 字符串处理有关 22 | """ 23 | 24 | @staticmethod 25 | def check_param_valid(params): 26 | """ 27 | 参数校验 28 | :param params: 29 | :return: 30 | """ 31 | if not params or 'timestamp' not in params: 32 | return False, "Missing param 'timestamp'" 33 | 34 | if not StringUtil.check_timestamp_valid(params['timestamp']): 35 | return False, "timestamp timeout" 36 | 37 | if not StringUtil.check_api_token(params=params): 38 | return False, "sign check fail" 39 | 40 | return True, '' 41 | 42 | 43 | 44 | @staticmethod 45 | def check_api_token(params): 46 | """ 47 | 参数校验 48 | :param params: 49 | :return: 50 | """ 51 | check_str = f"{conf.config.SALT}-{params['timestamp']}" 52 | return StringUtil.get_str_md5(check_str) == params['sign'] 53 | 54 | @staticmethod 55 | def check_timestamp_valid(timestamp): 56 | """ 57 | 校验timestamp是否超时 58 | :param timestamp: 59 | :return: 60 | """ 61 | return (time.time() - int(timestamp)) < conf.config.EXPIRE_TIME 62 | 63 | @staticmethod 64 | def gen_sign(params): 65 | """ 66 | 参数校验 67 | :param params: 68 | :return: 69 | """ 70 | check_str = f"{conf.config.SALT}-{params['timestamp']}" 71 | return StringUtil.get_str_md5(check_str) 72 | 73 | @staticmethod 74 | def get_str_md5(ustr): 75 | """ 76 | 获取字符串md5 77 | :param ustr: 78 | :return: 79 | """ 80 | m2 = hashlib.md5() 81 | m2.update(ustr.encode('utf-8')) 82 | return m2.hexdigest() 83 | 84 | @staticmethod 85 | def generate_shortcut(): 86 | """ 87 | 88 | :return: 89 | """ 90 | temp_str = StringUtil.gen_random_str(16) 91 | temp_md5 = StringUtil.get_str_md5(temp_str) 92 | return temp_md5[0:8] 93 | 94 | @staticmethod 95 | def gen_random_str(str_len=6): 96 | """ 97 | 生成随机字符串,最长不超过62位 98 | :return: 99 | """ 100 | if str_len > 62: 101 | str_len = 62 102 | return ''.join(random.sample('abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789', str_len)) 103 | 104 | @staticmethod 105 | def gen_success_data(data): 106 | return { 107 | 'error': False, 108 | 'msg': '', 109 | 'data': data 110 | } 111 | 112 | @staticmethod 113 | def gen_fail_data(message): 114 | return { 115 | 'error': True, 116 | 'msg': message, 117 | 'data': {} 118 | } 119 | -------------------------------------------------------------------------------- /Herald/service/parallel_http_service.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import multiprocessing as mp 3 | 4 | from service.handler import TranHandler, BackHttpHandler 5 | from service import ParallelService 6 | 7 | 8 | class ParallelHttpService(ParallelService): 9 | """ 10 | 翻译使用http接口请求 11 | """ 12 | 13 | def __init__(self, source_file, result_dir, re_run=False, 14 | trans_gpus=None, back_process_count=8): 15 | super().__init__(source_file=source_file, result_dir=result_dir, re_run=re_run, trans_gpus=trans_gpus, 16 | back_gpus=[]) 17 | 18 | self.back_process_count = back_process_count 19 | self.back_handler = {} 20 | 21 | def _init_handler(self): 22 | for i in self.trans_gpus: 23 | self.handler[i] = TranHandler() 24 | for j in range(self.back_process_count): 25 | self.back_handler[j] = BackHttpHandler() 26 | 27 | def run(self): 28 | self._init_handler() 29 | self._init_trans_queue() 30 | print(f"data_list_size = {len(self.source_list)}") 31 | 32 | trans_process_list = [] 33 | back_process_list = [] 34 | for i in self.trans_gpus: 35 | # 创建trans进程 36 | process1 = 
mp.Process(target=self._run_translate, args=(i,)) 37 | process1.start() 38 | trans_process_list.append(process1) 39 | 40 | # 创建back进程 41 | for j in range(self.back_process_count): 42 | process2 = mp.Process(target=self._run_back_trans, args=(j,)) 43 | process2.start() 44 | back_process_list.append(process2) 45 | 46 | for i in range(len(trans_process_list)): 47 | trans_process_list[i].join() 48 | 49 | # 发送结束信号到第二步, 因为步骤2有多个进程,每个进程都需要单独收到一个 None 才能正确退出. 50 | for _ in range(self.back_process_count): 51 | self.back_queue.put(None) 52 | 53 | for j in range(len(back_process_list)): 54 | back_process_list[j].join() 55 | 56 | print("All data processed.") 57 | 58 | def _run_back_trans(self, process_index=0): 59 | """ 60 | 反翻译 & 比对 61 | """ 62 | this_handler = self.back_handler[process_index] 63 | while True: 64 | item = self.back_queue.get() 65 | if item is None: # 检测结束信号 66 | break 67 | data_list = [{ 68 | 'informal_statement': item['informal_statement'], 69 | 'formal_statement': i 70 | } for i in item['translate_list']] 71 | print("start_back_compare_filter") 72 | loop = asyncio.new_event_loop() 73 | asyncio.set_event_loop(loop) 74 | 75 | valid_list = [] 76 | try: 77 | valid_list = loop.run_until_complete(this_handler.back_compare_filter(data_list)) 78 | except RuntimeError as e: 79 | if "Event loop is closed" not in str(e): 80 | raise # 只忽略特定的 RuntimeError,其他的仍然抛出 81 | finally: 82 | loop.close() 83 | 84 | valid_formal_list = [i['formal_statement'] for i in valid_list] 85 | print(f"valid_formal_list.size = {len(valid_formal_list)}") 86 | 87 | self._set_dict_data(item['unique_key'], 'back_trans_list', valid_formal_list) 88 | self._save_back_trans_data(item) 89 | -------------------------------------------------------------------------------- /Realprover/experiment/resume.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import datetime 3 | try: import tomllib 4 | except ModuleNotFoundError: import pip._vendor.tomli as tomllib 5 | import torch 6 | from manager.service import BatchMainService 7 | from manager.manage import ProofParseManage 8 | from util import profiler 9 | from manager.search.exception import single_error_test 10 | import shutil 11 | import os 12 | 13 | 14 | def main(result_dir: str): 15 | # with open(f'experiment/examples/{config_path}', 'rb') as fp: 16 | # config = tomllib.load(fp) 17 | config_path = os.path.join(result_dir,'config.toml') 18 | with open(config_path, 'rb') as fp: 19 | config = tomllib.load(fp) 20 | profiler.start("run_batch") 21 | # 默认使用可见的全部gpu, 也可以自己配置 22 | #gpus = 1 23 | gpus = torch.cuda.device_count() 24 | assert gpus >= 1 25 | print(f"gpus = {gpus}") 26 | gpus_list = list(range(gpus)) 27 | 28 | main_service = BatchMainService(source_file=config['data']['data_path'], 29 | result_dir=result_dir, 30 | gpus_list=gpus_list, 31 | max_nodes=config['search']['max_nodes'], 32 | max_depth=config['search']['max_depth'], 33 | num_samples=config['search']['num_samples'], 34 | use_beam_search=config['beam_search_params']['use_beam_search'], 35 | use_mcts_search=config['mcts_params']['use_mcts_search'], 36 | beam_width=config['beam_search_params']['beam_width'], 37 | local_model_path=config['model']['prover_model_path'], 38 | sampling_params=config['model']['prover_model_params'], 39 | max_retries=config['search']['max_retries'], 40 | simulation_depth=config['mcts_params']['sim_depth'], 41 | c_puct=config['mcts_params']['c_puct'], 42 | c_score=config['mcts_params']['c_score'], 43 | 
c_expansion_fail_penalty=config['mcts_params']['c_expansion_fail_penalty'], 44 | max_root_expansion=config['mcts_params']['max_root_expansion'], 45 | max_calls=config['search']['max_calls'], 46 | abandon_if_contain=config['search']['abandon_if_contain'], 47 | is_incontext=config['model'].get('is_incontext', False), 48 | template=config['model'].get('template', 'deepseek') 49 | ) 50 | main_service.batch_run() 51 | print(main_service.info) 52 | profiler.stop("run_batch") 53 | ProofParseManage.get_stats(result_dir, main_service.info) 54 | # shutil.copy(f'experiment/examples/{config_path}', f'{result_dir}/{config_path}') 55 | # ProofParseManage.get_stats("/AI4M/users/ytwang/auto-proof/stepprover-v2/experiment/logs/2025-01-21/alg-test-v2-32-256-2156") 56 | # single_error_test("experiment/logs/2025-02-16/test_new-example_20_index-8-64-2037/error/8.json") 57 | # ProofParseManage.get_stats("experiment/logs/2025-02-14/mcts_test-example_20_index-6-128-0209") 58 | 59 | 60 | if __name__ == '__main__': 61 | main(sys.argv[1]) 62 | -------------------------------------------------------------------------------- /Realprover/experiment/run.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import datetime 3 | try: import tomllib 4 | except ModuleNotFoundError: import pip._vendor.tomli as tomllib 5 | import torch 6 | from manager.service import BatchMainService 7 | from manager.manage import ProofParseManage 8 | from util import profiler 9 | from manager.search.exception import single_error_test 10 | import shutil 11 | import os 12 | 13 | 14 | def main(config_path: str): 15 | # with open(f'experiment/examples/{config_path}', 'rb') as fp: 16 | # config = tomllib.load(fp) 17 | with open(config_path, 'rb') as fp: 18 | config = tomllib.load(fp) 19 | result_dir = 'experiment/logs/{timestamp:%Y-%m-%d}/{info}-{source_id}-{num_samples}-{max_calls}-{timestamp:%H%M}'.format( 20 | timestamp=datetime.datetime.now(), 21 | info = config["info"], 22 | source_id=config['data']['data_id'], 23 | num_samples=config['search']['num_samples'], 24 | max_calls=config['search']['max_calls']) # 结果文件目录 25 | if not os.path.exists(result_dir): 26 | os.makedirs(result_dir) 27 | shutil.copy(config_path, os.path.join(result_dir,'config.toml')) 28 | profiler.start("run_batch") 29 | # 默认使用可见的全部gpu, 也可以自己配置 30 | #gpus = 1 31 | gpus = torch.cuda.device_count() 32 | assert gpus >= 1 33 | print(f"gpus = {gpus}") 34 | gpus_list = list(range(gpus)) 35 | 36 | main_service = BatchMainService(source_file=config['data']['data_path'], 37 | result_dir=result_dir, 38 | gpus_list=gpus_list, 39 | max_nodes=config['search']['max_nodes'], 40 | max_depth=config['search']['max_depth'], 41 | num_samples=config['search']['num_samples'], 42 | use_beam_search=config['beam_search_params']['use_beam_search'], 43 | use_mcts_search=config['mcts_params']['use_mcts_search'], 44 | beam_width=config['beam_search_params']['beam_width'], 45 | local_model_path=config['model']['prover_model_path'], 46 | sampling_params=config['model']['prover_model_params'], 47 | max_retries=config['search']['max_retries'], 48 | simulation_depth=config['mcts_params']['sim_depth'], 49 | c_puct=config['mcts_params']['c_puct'], 50 | c_score=config['mcts_params']['c_score'], 51 | c_expansion_fail_penalty=config['mcts_params']['c_expansion_fail_penalty'], 52 | max_root_expansion=config['mcts_params']['max_root_expansion'], 53 | max_calls=config['search']['max_calls'], 54 | abandon_if_contain=config['search']['abandon_if_contain'], 55 | 
is_incontext=config['model'].get('is_incontext', False), 56 | template=config['model'].get('template', 'deepseek'), 57 | use_retrieval = config['model'].get('use_retrieval', True) 58 | ) 59 | main_service.batch_run() 60 | print(main_service.info) 61 | profiler.stop("run_batch") 62 | ProofParseManage.get_stats(result_dir, main_service.info) 63 | # shutil.copy(f'experiment/examples/{config_path}', f'{result_dir}/{config_path}') 64 | # ProofParseManage.get_stats("/AI4M/users/ytwang/auto-proof/stepprover-v2/experiment/logs/2025-01-21/alg-test-v2-32-256-2156") 65 | # single_error_test("experiment/logs/2025-02-16/test_new-example_20_index-8-64-2037/error/8.json") 66 | # ProofParseManage.get_stats("experiment/logs/2025-02-14/mcts_test-example_20_index-6-128-0209") 67 | 68 | 69 | if __name__ == '__main__': 70 | main(sys.argv[1]) 71 | -------------------------------------------------------------------------------- /Herald/service/handler/back_http_handler.py: -------------------------------------------------------------------------------- 1 | import re 2 | import json 3 | from openai import AsyncOpenAI 4 | import conf.config 5 | 6 | 7 | class BackHttpHandler(object): 8 | """ 9 | lean 语言 => 自然语言 && 比对informal_statement 10 | """ 11 | 12 | def __init__(self): 13 | """ 14 | 15 | """ 16 | # self.model_name = 'deepseek-ai/DeepSeek-V2.5' 17 | self.model_name = conf.config.MODEL_CONFIG['compare'] # 翻译和反翻译都使用此模型 18 | 19 | async def back_compare_filter(self, data_list): 20 | """ 21 | 22 | """ 23 | if len(data_list) == 0: 24 | return [] 25 | 26 | for index, row_data in enumerate(data_list): 27 | print(f"back_compare_filter_index = {index}") 28 | row_data['back_translate'] = await self.get_back_tran(row_data) 29 | same_str = await self.compare(row_data) 30 | row_data['same_result'] = same_str 31 | 32 | result_list = [row for row in data_list if row['same_result'] == 'same'] 33 | return result_list 34 | 35 | async def get_back_tran(self, data): 36 | sys_prompt = 'Convert the formal statement into natural language: ' 37 | prompt = data['formal_statement'] 38 | message_str = json.dumps([ 39 | {"role": "system", "content": sys_prompt}, 40 | {"role": "user", "content": prompt} 41 | ]) 42 | messages = json.loads(message_str) 43 | # back_tran_str = self.execute_request(messages=messages) 44 | back_tran_str = await self.request_model(messages=messages) 45 | # TODO 生成的数据是否需要处理呢 46 | data['back_translate'] = back_tran_str 47 | return back_tran_str 48 | 49 | def get_query_nil_apichat(self, problem): 50 | sys_prompt = 'Please check the following two math problems are the same or different? Please consider each statement in the two problems; they are different if any statement is different. Please point out any differences you found. Please reply ||same|| or ||different|| in the final sentence with "||" format.' 
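        # The comparison model is instructed to end its reply with ||same|| or
        # ||different||; extract_bold_text() below recovers that verdict, e.g.
        # for a hypothetical reply:
        #   re.search(r'\|\|(.*?)\|\|', 'All statements agree. ||same||').group(1)
        #   # -> 'same'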
51 | problem_origin = problem['informal_statement'] 52 | problem_back = problem['back_translate'] 53 | prompt = 'Problem 1:\n' + problem_origin + '\nProblem 2:\n' + problem_back 54 | problem["prompt"] = json.dumps([ 55 | {"role": "system", "content": sys_prompt}, 56 | {"role": "user", "content": prompt} 57 | ]) 58 | return problem 59 | 60 | async def compare(self, data): 61 | """ 62 | 返回值为 same or different 63 | """ 64 | data = self.get_query_nil_apichat(data) 65 | messages = json.loads(data['prompt']) 66 | generate_data = await self.request_model(messages=messages) 67 | ret = self.extract_bold_text(generate_data) 68 | print("***check-same-ret***: %s" % ret) 69 | data['same_result'] = ret 70 | return ret 71 | 72 | async def request_model(self, messages): 73 | """ 74 | 请求大模型 75 | """ 76 | nil_client = AsyncOpenAI( 77 | base_url=conf.config.NIM_CONFIG['url'], 78 | api_key=conf.config.NIM_CONFIG['key'], 79 | timeout=600 80 | ) 81 | response = await nil_client.chat.completions.create( 82 | model=self.model_name, 83 | messages=messages, 84 | max_tokens=1024, 85 | temperature=0.01, 86 | top_p=0.7, 87 | extra_body={'repetition_penalty': 1}, 88 | stream=False 89 | ) 90 | if not response or not response.choices: 91 | return 'null' 92 | result = response.choices[0].message.content 93 | print("*** request_model_response result ***") 94 | print(result) 95 | return result 96 | 97 | def extract_bold_text(self, output): 98 | # 使用正则表达式提取**之间的内容 99 | match = re.search(r'\|\|(.*?)\|\|', output) 100 | if match: 101 | return match.group(1) 102 | return 'null' 103 | -------------------------------------------------------------------------------- /LeanSearch-PS-inference/worker/premise_selector.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import Tensor 4 | from transformers import AutoTokenizer, AutoModel 5 | import faiss 6 | import numpy as np 7 | import json 8 | import conf.config 9 | 10 | np.random.seed(42) # make reproducible 11 | 12 | class PremiseSelector(object): 13 | def __init__(self): 14 | self.device = 'cuda' 15 | 16 | self.index = faiss.read_index(conf.config.INDEX_PATH) 17 | self.answer_path = conf.config.ANSWER_PATH 18 | self.tokenizer = None 19 | self.model = None 20 | self.formals = [] 21 | 22 | self.init_formals() 23 | self.init_data() 24 | 25 | def init_formals(self): 26 | formals = [] 27 | with open(self.answer_path, "r") as fd: 28 | data = json.load(fd) 29 | for i in data: 30 | formals.append(data[i]['Formal name'] + '\n' + data[i]['Formal statement']) 31 | self.formals = formals 32 | 33 | def init_data(self): 34 | datas = [] 35 | with open(self.answer_path, "r") as fd: 36 | data = json.load(fd) 37 | for i in data: 38 | datas.append(data[i]) 39 | self.datas = datas 40 | 41 | def init_model(self): 42 | if self.model is None: 43 | self.tokenizer = AutoTokenizer.from_pretrained(conf.config.TOKENIZER_PATH) 44 | self.model = AutoModel.from_pretrained(conf.config.MODEL_PATH).half() 45 | self.model.to(self.device) 46 | 47 | def release_model(self): 48 | self.tokenizer = None 49 | self.model = None 50 | 51 | def retrieve(self, queries, num=10): 52 | self.init_model() 53 | task = "Given a lean4 context, retrieve Lean4 theorems useful to solve it: " 54 | if isinstance(queries, str): 55 | queries = [queries] 56 | 57 | query_vecs = self.get_embeddings(task, queries).cpu().numpy().astype(np.float32) 58 | 59 | # faiss for searching 60 | D, I = self.index.search(query_vecs, num) 61 | 62 | 
all_related_theorems = [] 63 | for query_results in I: 64 | related_theorems = [] 65 | for docid in query_results: 66 | related_theorems.append(self.datas[docid]) 67 | all_related_theorems.append(related_theorems[:num]) 68 | 69 | return all_related_theorems 70 | 71 | def get_embeddings(self, task, input_texts): 72 | with torch.no_grad(): 73 | max_length = 4096 74 | detailed_input_texts = [self.get_detailed_instruct(task, input_text) for input_text in input_texts] 75 | # Tokenize the input texts 76 | batch_dict = self.tokenizer(detailed_input_texts, max_length=max_length - 1, return_attention_mask=False, 77 | padding=False, truncation=True) 78 | batch_dict['input_ids'] = [input_ids + [self.tokenizer.eos_token_id] for input_ids in 79 | batch_dict['input_ids']] 80 | batch_dict = self.tokenizer.pad(batch_dict, padding=True, return_attention_mask=True, 81 | return_tensors='pt').to( 82 | self.device) 83 | 84 | outputs = self.model(**batch_dict) 85 | embeddings = self.last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask']) 86 | 87 | embeddings = F.normalize(embeddings, p=2, dim=1) 88 | 89 | # Release cuda memory 90 | del batch_dict 91 | del outputs 92 | 93 | # Release cuda memory 94 | torch.cuda.empty_cache() 95 | 96 | return embeddings 97 | 98 | def last_token_pool(self, last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor: 99 | left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0]) 100 | if left_padding: 101 | return last_hidden_states[:, -1] 102 | else: 103 | sequence_lengths = attention_mask.sum(dim=1) - 1 104 | batch_size = last_hidden_states.shape[0] 105 | return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths] 106 | 107 | def get_detailed_instruct(self, task_description: str, query: str) -> str: 108 | return f'{task_description}\n{query}' 109 | -------------------------------------------------------------------------------- /Herald/util/common_util.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ################################################################################ 3 | 4 | ################################################################################ 5 | """ 6 | This module provide string process service. 
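For example, split_list chunks a list into fixed-size pieces (a small sketch):

    >>> CommonUtil.split_list([1, 2, 3, 4, 5], 2)
    [[1, 2], [3, 4], [5]]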
7 | 8 | """ 9 | import json 10 | import os 11 | from datetime import datetime 12 | 13 | 14 | class CommonUtil(object): 15 | """ 16 | 通用处理类 17 | """ 18 | @staticmethod 19 | def load_json(config_file): 20 | """load json""" 21 | params = None 22 | if os.path.exists(config_file): 23 | f = open(config_file) 24 | params = json.load(f) 25 | else: 26 | raise Exception('Config Error') 27 | return params 28 | 29 | @staticmethod 30 | def write_to_json_file(file_path, json_data, ensure_ascii=False): 31 | """ 32 | 33 | """ 34 | with open(file_path, 'w', encoding='utf-8') as dump_f: 35 | json.dump(json_data, dump_f, ensure_ascii=ensure_ascii) 36 | 37 | @staticmethod 38 | def write_json_list_to_file(file_path, data_list): 39 | """ 40 | 41 | :param file_path: 42 | :param data_list: 43 | :return: 44 | """ 45 | with open(file_path, mode='w', encoding='utf-8') as f: 46 | for index, data in enumerate(data_list): 47 | json.dump(data, f, ensure_ascii=False) 48 | if index < len(data_list) - 1: 49 | f.write('\n') 50 | 51 | @staticmethod 52 | def read_json_list(file_path): 53 | """ 54 | 55 | :param file_path: 56 | :return: 57 | """ 58 | data_list = [] 59 | with open(file_path, mode='r') as f: 60 | for line in f: 61 | if line: 62 | data_list.append(json.loads(line)) 63 | return data_list 64 | 65 | @staticmethod 66 | def read_list(file_path, skip_first=True): 67 | """ 68 | 69 | :param skip_first: 70 | :param file_path: 71 | :return: 72 | """ 73 | data_list = [] 74 | with open(file_path, mode='r') as fr: 75 | if skip_first: 76 | next(fr) 77 | for line in fr: 78 | if line: 79 | data_list.append(line) 80 | return data_list 81 | 82 | @staticmethod 83 | def write_list(file_path, lines_list): 84 | """ 85 | 86 | :param file_path: 87 | :param lines_list: 88 | :return: 89 | """ 90 | with open(file_path, 'w', encoding='utf-8') as f: 91 | f.writelines(lines_list) 92 | 93 | @staticmethod 94 | def build_key_to_data(data_list, key_name, value_key=''): 95 | """ 96 | 97 | :param data_list: 98 | :param key_name: 99 | :param value_key: 100 | :return: 101 | """ 102 | ret = {} 103 | for row_data in data_list: 104 | row_value = row_data[key_name] 105 | if row_value not in ret: 106 | ret[row_value] = row_data[value_key] if value_key else row_data 107 | return ret 108 | 109 | @staticmethod 110 | def build_key_to_list(data_list, key_name): 111 | """ 112 | 113 | :param data_list: 114 | :param key_name: 115 | :return: 116 | """ 117 | ret = {} 118 | for row_data in data_list: 119 | row_value = row_data[key_name] 120 | if row_value not in ret: 121 | ret[row_value] = [] 122 | ret[row_value].append(row_data) 123 | return ret 124 | 125 | @staticmethod 126 | def get_date_time(): 127 | """ 128 | 129 | :return: 130 | """ 131 | return datetime.now().strftime("%Y-%m-%d-%H-%M-%S") 132 | 133 | @staticmethod 134 | def file_exist(file_path): 135 | """ 136 | 137 | :param file_path: 138 | :return: 139 | """ 140 | return os.path.isfile(file_path) 141 | 142 | @staticmethod 143 | def split_list(lst, chunk_size): 144 | """ 145 | 将列表拆分成多个小列表,每个小列表包含chunk_size个元素。 146 | 147 | :param lst: 待拆分的列表 148 | :param chunk_size: 每个小列表包含的元素数量 149 | :return: 包含多个小列表的列表 150 | """ 151 | return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)] 152 | 153 | @staticmethod 154 | def print(string): 155 | print(f"{CommonUtil.get_date_time()}: {string}") 156 | 157 | 158 | 159 | -------------------------------------------------------------------------------- /Realprover/util/common_util.py: -------------------------------------------------------------------------------- 1 
| # -*- coding: utf-8 -*- 2 | ################################################################################ 3 | 4 | ################################################################################ 5 | """ 6 | This module provide string process manager. 7 | 8 | """ 9 | import json 10 | import os 11 | from datetime import datetime 12 | 13 | 14 | class CommonUtil(object): 15 | """ 16 | 通用处理类 17 | """ 18 | @staticmethod 19 | def load_json(config_file): 20 | """load json""" 21 | params = None 22 | if os.path.exists(config_file): 23 | f = open(config_file) 24 | params = json.load(f) 25 | else: 26 | raise Exception('Config Error') 27 | return params 28 | 29 | @staticmethod 30 | def write_to_json_file(file_path, json_data, ensure_ascii=False): 31 | """ 32 | 33 | """ 34 | with open(file_path, 'w', encoding='utf-8') as dump_f: 35 | json.dump(json_data, dump_f, ensure_ascii=ensure_ascii) 36 | 37 | @staticmethod 38 | def write_json_list_to_file(file_path, data_list): 39 | """ 40 | 41 | :param file_path: 42 | :param data_list: 43 | :return: 44 | """ 45 | with open(file_path, mode='w', encoding='utf-8') as f: 46 | for index, data in enumerate(data_list): 47 | json.dump(data, f, ensure_ascii=False) 48 | if index < len(data_list) - 1: 49 | f.write('\n') 50 | 51 | @staticmethod 52 | def read_json_list(file_path): 53 | """ 54 | 55 | :param file_path: 56 | :return: 57 | """ 58 | data_list = [] 59 | with open(file_path, mode='r') as f: 60 | for line in f: 61 | if line: 62 | data_list.append(json.loads(line)) 63 | return data_list 64 | 65 | @staticmethod 66 | def read_list(file_path, skip_first=True): 67 | """ 68 | 69 | :param skip_first: 70 | :param file_path: 71 | :return: 72 | """ 73 | data_list = [] 74 | with open(file_path, mode='r') as fr: 75 | if skip_first: 76 | next(fr) 77 | for line in fr: 78 | if line: 79 | data_list.append(line) 80 | return data_list 81 | 82 | @staticmethod 83 | def write_list(file_path, lines_list): 84 | """ 85 | 86 | :param file_path: 87 | :param lines_list: 88 | :return: 89 | """ 90 | with open(file_path, 'w', encoding='utf-8') as f: 91 | f.writelines(lines_list) 92 | 93 | @staticmethod 94 | def build_key_to_data(data_list, key_name, value_key=''): 95 | """ 96 | 97 | :param data_list: 98 | :param key_name: 99 | :param value_key: 100 | :return: 101 | """ 102 | ret = {} 103 | for row_data in data_list: 104 | row_value = row_data[key_name] 105 | if row_value not in ret: 106 | ret[row_value] = row_data[value_key] if value_key else row_data 107 | return ret 108 | 109 | @staticmethod 110 | def build_key_to_list(data_list, key_name): 111 | """ 112 | 113 | :param data_list: 114 | :param key_name: 115 | :return: 116 | """ 117 | ret = {} 118 | for row_data in data_list: 119 | row_value = row_data[key_name] 120 | if row_value not in ret: 121 | ret[row_value] = [] 122 | ret[row_value].append(row_data) 123 | return ret 124 | 125 | @staticmethod 126 | def get_date_time(): 127 | """ 128 | 129 | :return: 130 | """ 131 | return datetime.now().strftime("%Y-%m-%d-%H-%M-%S") 132 | 133 | @staticmethod 134 | def file_exist(file_path): 135 | """ 136 | 137 | :param file_path: 138 | :return: 139 | """ 140 | return os.path.isfile(file_path) 141 | 142 | @staticmethod 143 | def split_list(lst, chunk_size): 144 | """ 145 | 将列表拆分成多个小列表,每个小列表包含chunk_size个元素。 146 | 147 | :param lst: 待拆分的列表 148 | :param chunk_size: 每个小列表包含的元素数量 149 | :return: 包含多个小列表的列表 150 | """ 151 | return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)] 152 | 153 | @staticmethod 154 | def print(string): 155 | 
print(f"{CommonUtil.get_date_time()}: {string}") 156 | 157 | 158 | 159 | -------------------------------------------------------------------------------- /Realprover/manager/search/beam_search.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | from heapdict import heapdict 3 | from manager.struct import Node, state_repr_dedup 4 | from manager.thirdparty import Interactive, TacticGenerator 5 | import conf.config 6 | from manager.search.exception import SearchError 7 | # import logging 8 | 9 | 10 | class BeamSearch: 11 | def __init__(self, 12 | beam_width: int = conf.config.BEAM_WIDTH, 13 | num_samples: int = conf.config.NUM_SAMPLES, 14 | max_nodes: int = conf.config.MAX_NODES, 15 | max_depth: int = conf.config.MAX_DEPTH, 16 | abandon_if_contain: list[str] = conf.config.ABANDON_IF_CONTAIN 17 | ): 18 | self.found = False 19 | self.nodes = {} 20 | self.score = heapdict() 21 | self.depth = 0 22 | self.beam_width = beam_width 23 | self.num_samples = num_samples 24 | self.max_depth = max_depth 25 | self.max_nodes = max_nodes 26 | self.tactic_sid_record = [] 27 | self.abandon_if_contain = abandon_if_contain 28 | 29 | def insert(self, node: Node): 30 | if not node.state: 31 | self.found = True 32 | 33 | deduped_state_str = state_repr_dedup(node.state) 34 | if deduped_state_str not in self.nodes: 35 | self.nodes[deduped_state_str] = node 36 | self.score[deduped_state_str] = -node.score 37 | self.depth = max(self.depth, node.depth) 38 | # print(f'{len(self.nodes)} Explored.') 39 | else: 40 | try: 41 | self.score[deduped_state_str] -= node.score 42 | except KeyError: 43 | pass 44 | 45 | def get(self) -> list[Node]: 46 | beam = [] 47 | for _ in range(min(len(self.score), self.beam_width)): 48 | k, _ = self.score.popitem() 49 | beam.append(self.nodes[k]) 50 | # self.process_record.append(self.nodes[k]) 51 | return beam 52 | 53 | def going(self) -> bool: 54 | return not self.found and len(self.nodes) < self.max_nodes and self.depth < self.max_depth 55 | 56 | def tactic_filter(self, tactic: str) -> bool: 57 | for forbidden_tactic in self.abandon_if_contain: 58 | if forbidden_tactic in tactic: 59 | return False 60 | return True 61 | 62 | def search_proof(self, 63 | generator: TacticGenerator, 64 | interactive: Interactive): 65 | assert len(self.nodes) == 1 66 | assert generator.calls == [] 67 | while self.going() and generator.has_quota(): 68 | beam = self.get() 69 | if not beam: 70 | break 71 | for node in beam: 72 | tactics, scores = generator.from_state(node.state, self.num_samples) 73 | for (tactic, num_reps), _ in zip(Counter(tactics).items(), scores): 74 | if not self.tactic_filter(tactic): 75 | continue 76 | self.tactic_sid_record.append({"tactic":tactic, "sid":node.sid}) 77 | try: 78 | sid = interactive.run_tactic(node.sid, tactic) 79 | except RuntimeError: 80 | pass 81 | except Exception as e: 82 | #目前仅在run_tactic加入记录error-logging功能, 因为根据以往经验在get_state/giveup 加入try block可能会导致broken pipe error 83 | #如果确定问题所在可以手动添加 84 | raise SearchError("An error occurred at run_tactic", 85 | error_data = self.tactic_sid_record, 86 | error_type = e) 87 | else: 88 | state = interactive.get_state(sid) 89 | new_node = Node(sid, node.sid, tactic, state, node.depth + 1, num_reps, node) 90 | self.insert(new_node) 91 | if not state: 92 | interactive.commit(sid) 93 | self.found = True 94 | return 95 | if not self.found: 96 | sid = interactive.give_up(0) 97 | interactive.commit(sid) 98 | 99 | @property 100 | def info(self): 101 | return dict( 102 | 
use_beam_search=True,
103 |             beam_width=self.beam_width,
104 |             num_samples=self.num_samples,
105 |             max_nodes=self.max_nodes,
106 |             max_depth=self.max_depth)
107 | 
--------------------------------------------------------------------------------
/Herald/service/pipeline_service.py:
--------------------------------------------------------------------------------
  1 | import uuid
  2 | import os
  3 | import gc
  4 | import torch
  5 | 
  6 | from service.handler import TranHandler, BackHandler, ProverHandler
  7 | from util import CommonUtil
  8 | 
  9 | 
 10 | class PipelineService(object):
 11 |     """
 12 |     Pipeline service
 13 |     """
 14 | 
 15 |     def __init__(self, result_dir):
 16 |         """
 17 | 
 18 |         """
 19 |         self.result_dir = result_dir
 20 |         self.tran_handler = None
 21 |         self.back_handler = None
 22 |         self.proof_handler = None
 23 |         if result_dir is None:
 24 |             self.result_dir = f"./data_result/{CommonUtil.get_date_time()}_{uuid.uuid4()}"
 25 |         self._init_dir()
 26 | 
 27 |         self.informal_statement = ""
 28 |         self.formal_list_after_validate = []  # generated and compile-checked
 29 |         self.formal_list_after_compare = []  # passed the back-translation comparison
 30 |         self.prove_detail_dict = {}  # proof details
 31 |         self.result_list = []  # (statement, proof) pairs
 32 | 
 33 |     def _init_dir(self):
 34 |         self.tran_file_path = f"{self.result_dir}/translate.json"
 35 |         self.prove_file_path = f"{self.result_dir}/prove_info.json"
 36 | 
 37 |         if not os.path.exists(self.result_dir):
 38 |             os.makedirs(self.result_dir)
 39 | 
 40 |     def _init_handler(self):
 41 |         if self.tran_handler is None:
 42 |             self.tran_handler = TranHandler()
 43 |         if self.back_handler is None:
 44 |             self.back_handler = BackHandler()
 45 |         if self.proof_handler is None:
 46 |             self.proof_handler = ProverHandler()
 47 | 
 48 |     def run(self, informal_statement, with_proof=True):
 49 |         self.informal_statement = informal_statement
 50 |         self._init_handler()
 51 | 
 52 |         if CommonUtil.file_exist(self.tran_file_path):
 53 |             saved_data = CommonUtil.load_json(self.tran_file_path)
 54 |             self.formal_list_after_validate = saved_data['formal_list_after_validate']
 55 |             self.formal_list_after_compare = saved_data['formal_list_after_compare']
 56 |         else:
 57 | 
 58 |             self._run_translate()
 59 |             self._run_back_trans()
 60 |             self._save_tran_data()
 61 | 
 62 |         if with_proof:
 63 |             self._run_proof()
 64 |             self._save_proof_data()
 65 | 
 66 |     def _run_translate(self):
 67 |         """
 68 |         Translate and compile-check
 69 |         """
 70 |         generate_list = self.tran_handler.generate_and_check(self.informal_statement)
 71 |         self.tran_handler.release_model()
 72 |         self._gc_collect()
 73 | 
 74 |         print('generate finished: data_list_size = %s' % len(generate_list))
 75 |         self.formal_list_after_validate = generate_list
 76 |         print(f"formal_list_after_validate_size = {len(generate_list)}")
 77 | 
 78 |     def _run_back_trans(self):
 79 |         """
 80 |         Back-translate & compare
 81 |         """
 82 |         data_list = [{
 83 |             'informal_statement': self.informal_statement,
 84 |             'formal_statement': i
 85 |         } for i in self.formal_list_after_validate]
 86 |         valid_list = self.back_handler.back_compare_filter(data_list)
 87 |         self.back_handler.release_model()
 88 |         self._gc_collect()
 89 |         self.formal_list_after_compare = [i['formal_statement'] for i in valid_list]
 90 |         print(f"formal_list_after_compare_size = {len(valid_list)}")
 91 | 
 92 |     def _run_proof(self):
 93 |         self.prove_detail_dict = self.proof_handler.batch_gen_proof(self.formal_list_after_compare)
 94 |         for index, prove_detail in self.prove_detail_dict.items():
 95 |             if "formal_proof" in prove_detail:
 96 |                 self.result_list.append({
 97 |                     'formal_statement': self.formal_list_after_compare[index],
 98 |                     'formal_proof': prove_detail['formal_proof']
 99 |                 })
100 | 
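    # A minimal end-to-end driver sketch (assumes the GPUs and model paths in
    # conf.config are available; result_dir=None creates a fresh output dir):
    #
    #   service = PipelineService(result_dir=None)
    #   service.run("Prove that the product of two groups is a group.",
    #               with_proof=True)
    #   print(service.result_list)  # [{'formal_statement': ..., 'formal_proof': ...}]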
101 | 102 | def _gc_collect(self): 103 | gc.collect() # 调用垃圾回收 104 | torch.cuda.empty_cache() # 清理 GPU 缓存 105 | 106 | def _save_tran_data(self): 107 | CommonUtil.write_to_json_file(self.tran_file_path, self._build_save_data()) 108 | 109 | def _save_proof_data(self): 110 | CommonUtil.write_to_json_file(self.prove_file_path, self._build_save_data(True)) 111 | 112 | def _build_save_data(self, with_proof=False): 113 | """ 114 | 115 | """ 116 | save_data = { 117 | 'informal_statement': self.informal_statement, 118 | 'formal_list_after_validate': self.formal_list_after_validate, 119 | 'formal_list_after_compare': self.formal_list_after_compare 120 | } 121 | if with_proof: 122 | save_data['prove_detail_dict'] = self.prove_detail_dict 123 | save_data['result_list'] = self.result_list 124 | return save_data 125 | -------------------------------------------------------------------------------- /Herald/service/handler/back_handler.py: -------------------------------------------------------------------------------- 1 | from vllm import LLM, SamplingParams 2 | from transformers import AutoTokenizer 3 | import re 4 | import json 5 | from openai import AsyncOpenAI 6 | import asyncio 7 | import concurrent.futures 8 | import conf.config 9 | 10 | 11 | class BackHandler(object): 12 | """ 13 | lean 语言 => 自然语言 && 比对informal_statement 14 | """ 15 | def __init__(self): 16 | """ 17 | 18 | """ 19 | self.model = None 20 | self.bt_tokenizer = None 21 | self.sampling_params = None 22 | self.tp_size = 1 23 | 24 | # self._init_model() 25 | 26 | 27 | def _init_model(self): 28 | if self.model is None: 29 | self.bt_tokenizer = AutoTokenizer.from_pretrained(conf.config.MODEL_CONFIG['back_trans'], use_fast=False, 30 | trust_remote_code=True) 31 | self.model = LLM( 32 | model=conf.config.MODEL_CONFIG['back_trans'], tensor_parallel_size=self.tp_size, trust_remote_code=True, 33 | dtype="bfloat16") 34 | 35 | def release_model(self): 36 | self.model = None 37 | self.bt_tokenizer = None 38 | 39 | def _init_sampling_params(self): 40 | if self.sampling_params is None: 41 | self.sampling_params = SamplingParams( 42 | temperature=0.1, 43 | max_tokens=1024, 44 | stop=['[UNUSED_TOKEN_146]', '[UNUSED_TOKEN_145]', '<|im_end|>']) 45 | 46 | def get_query_backtrans_intern(self, problem): 47 | output = problem['formal_statement'] 48 | output = '[UNUSED_TOKEN_146]user\nConvert the formal statement into natural language:\n```lean\n' + output + '\n```[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n' 49 | problem['prompt_backtranslate'] = output 50 | return problem 51 | 52 | def back_compare_filter_old(self, data_list): 53 | """ 54 | 反翻译并比对,失败的过滤掉 55 | """ 56 | result_list = [] 57 | for index, row_data in enumerate(data_list): 58 | print('back_compare: index = %s' % index) 59 | row_data['back_translate'] = self.get_back_tran(row_data) 60 | same_str = asyncio.run(self.compare(row_data)) 61 | row_data['same_result'] = same_str 62 | if same_str == "same": 63 | result_list.append(row_data) 64 | return result_list 65 | 66 | def back_compare_filter(self, data_list): 67 | """ 68 | 反翻译并比对,失败的过滤掉 69 | """ 70 | if len(data_list) == 0: 71 | return [] 72 | self._init_model() 73 | for row_data in data_list: 74 | row_data['back_translate'] = self.get_back_tran(row_data) 75 | 76 | def process_row(row_data): 77 | print(f"back_compare: index = {data_list.index(row_data)}") 78 | # row_data['back_translate'] = self.get_back_tran(row_data) 79 | same_str = asyncio.run(self.compare(row_data)) 80 | row_data['same_result'] = same_str 81 | return row_data if same_str 
== "same" else None 82 | 83 | max_workers = min(len(data_list), conf.config.THREAD_CONFIG['same_check']) 84 | with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: 85 | results = list(executor.map(process_row, data_list)) 86 | 87 | # 过滤掉 None 值,仅保留匹配的行 88 | result_list = [row for row in results if row is not None] 89 | return result_list 90 | 91 | def get_back_tran(self, data): 92 | self._init_sampling_params() 93 | data = self.get_query_backtrans_intern(data) 94 | outputs = self.model.generate(data['prompt_backtranslate'], sampling_params=self.sampling_params) 95 | print('***outputs***') 96 | print(outputs) 97 | result = outputs[0].outputs[0].text 98 | return result 99 | 100 | 101 | 102 | def get_query_nil_apichat(self, problem): 103 | sys_prompt = 'Please check the following two math problems are the same or different? Please consider each statement in the two problems; they are different if any statement is different. Please point out any differences you found. Please reply ||same|| or ||different|| in the final sentence with "||" format.' 104 | problem_origin = problem['informal_statement'] 105 | problem_back = problem['back_translate'] 106 | prompt = 'Problem 1:\n' + problem_origin + '\nProblem 2:\n' + problem_back 107 | problem["prompt"] = json.dumps([ 108 | {"role": "system", "content": sys_prompt}, 109 | {"role": "user", "content": prompt} 110 | ]) 111 | return problem 112 | 113 | async def compare(self, data): 114 | """ 115 | 返回值为 same or different 116 | """ 117 | data = self.get_query_nil_apichat(data) 118 | messages = json.loads(data['prompt']) 119 | nil_client = AsyncOpenAI( 120 | base_url=conf.config.NIM_CONFIG['url'], 121 | api_key=conf.config.NIM_CONFIG['key'], 122 | timeout=600 123 | ) 124 | response = await nil_client.chat.completions.create( 125 | model=conf.config.MODEL_CONFIG['compare'], 126 | messages=messages, 127 | max_tokens=1024, 128 | temperature=0.01, 129 | top_p=0.7, 130 | extra_body={'repetition_penalty': 1}, 131 | stream=False 132 | ) 133 | if not response or not response.choices: 134 | return 'null' 135 | result = response.choices[0].message.content 136 | print("***check-same-response result***") 137 | print(result) 138 | 139 | ret = self.extract_bold_text(result) 140 | print("***check-same-ret***: %s" % ret) 141 | return ret 142 | 143 | 144 | def extract_bold_text(self, output): 145 | # 使用正则表达式提取**之间的内容 146 | match = re.search(r'\|\|(.*?)\|\|', output) 147 | if match: 148 | return match.group(1) 149 | return 'null' 150 | -------------------------------------------------------------------------------- /Realprover/manager/search/best_first.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | from heapdict import heapdict 3 | from manager.struct import Node, state_repr_dedup 4 | from manager.thirdparty import Interactive, TacticGenerator 5 | from manager.thirdparty.verifier import verify_proof 6 | import conf.config 7 | from manager.search.exception import SearchError 8 | import traceback 9 | 10 | def hard_stop_criterion(node: Node, window_size=5) -> bool: 11 | return all('have' in n.tactic for n in node.current_path[:-window_size-1:-1]) 12 | 13 | class BestFirstSearch: 14 | def __init__( 15 | self, 16 | num_samples: int = conf.config.NUM_SAMPLES, 17 | max_nodes: int = conf.config.MAX_NODES, 18 | max_depth: int = conf.config.MAX_DEPTH, 19 | abandon_if_contain: list[str] = conf.config.ABANDON_IF_CONTAIN, 20 | is_incontext: bool = conf.config.IS_INCONTEXT, 21 | 
template: str = 'deepseek', 22 | use_retrieval: bool = conf.config.USE_RETRIEVAL, 23 | alpha: float = 0.5 24 | ): 25 | self.found = False 26 | self.nodes = {} 27 | self.score = heapdict() 28 | self.depth = 0 29 | self.max_nodes = max_nodes 30 | self.max_depth = max_depth 31 | self.num_samples = num_samples 32 | self.tactic_sid_record = [] 33 | self.abandon_if_contain = abandon_if_contain 34 | self.is_incontext = is_incontext 35 | self.template = template 36 | self.use_retrieval = use_retrieval 37 | self.alpha = alpha 38 | 39 | def insert(self, node: Node, formal_statement: str=None): 40 | if not node.state: 41 | try: 42 | print(f"lake repl check!!!",flush=True) 43 | full_proof = self.get_incontext(node, formal_statement) 44 | repl_res = verify_proof(full_proof) 45 | except Exception as e: 46 | print(traceback.format_exc()) 47 | self.found = False 48 | else: 49 | self.found = repl_res 50 | if not self.found: 51 | return 52 | # self.found = True 53 | deduped_state_str = state_repr_dedup(node.state) 54 | if deduped_state_str not in self.nodes: 55 | self.nodes[deduped_state_str] = node 56 | if node.depth >0: 57 | self.score[deduped_state_str] = -node.score/((node.depth)**self.alpha) 58 | else: 59 | self.score[deduped_state_str] = 0.0 60 | self.depth = max(self.depth, node.depth) 61 | 62 | def get(self) -> Node: 63 | k, _ = self.score.popitem() 64 | return self.nodes[k] 65 | 66 | def going(self) -> bool: 67 | return not self.found and len(self.nodes) < self.max_nodes 68 | 69 | def tactic_filter(self, tactic: str) -> bool: 70 | for forbidden_tactic in self.abandon_if_contain: 71 | if forbidden_tactic in tactic: 72 | return False 73 | return True 74 | 75 | def get_incontext(self, node: Node, formal_statement: str): 76 | path = [] 77 | current = node 78 | while current: 79 | path.append(current.tactic) 80 | current = current.parent 81 | path.reverse() 82 | res = formal_statement.replace("sorry", "").replace("by", "").strip() + " by\n" 83 | res += "\n".join(path) 84 | return res 85 | 86 | def search_proof(self, generator: TacticGenerator, interactive: Interactive): 87 | while self.going() and generator.has_quota(): 88 | try: 89 | node = self.get() 90 | except IndexError: 91 | break 92 | if self.is_incontext: 93 | incontext = self.get_incontext(node, generator.formal_statement) 94 | else: 95 | incontext = None 96 | tactics, logprobs = generator.from_state(node.state, self.num_samples, incontext, self.template, self.use_retrieval) 97 | tactics_logprob = [] 98 | for tactic, logprob in zip(tactics,logprobs): 99 | tactics_logprob.append((tactic,logprob)) 100 | for tactic_logprob, num_reps in Counter(tactics_logprob).items(): 101 | tactic,logprob = tactic_logprob 102 | if not self.tactic_filter(tactic): 103 | continue 104 | self.tactic_sid_record.append({"tactic":tactic, "sid":node.sid}) 105 | try: 106 | sid = interactive.run_tactic(node.sid, tactic) 107 | except RuntimeError as e: 108 | # error_info = traceback.format_exc() 109 | # print(error_info,flush=True) 110 | pass 111 | except Exception as e: 112 | #目前仅在run_tactic加入记录error-logging功能, 因为根据以往经验在get_state/giveup 加入try block可能会导致broken pipe error 113 | #如果确定问题所在可以手动添加 114 | raise SearchError("An error occurred at run_tactic", 115 | error_data = self.tactic_sid_record, 116 | error_type = e) 117 | else: 118 | state = interactive.get_state(sid) 119 | # print(f"{tactic},logprob:{logprob}",flush=True) 120 | # print(state,flush=True) 121 | new_node = Node(sid, node.sid, tactic, state, node.depth + 1, logprob+node.score, node) 122 | if 
hard_stop_criterion(new_node): 123 | continue 124 | self.insert(new_node,generator.formal_statement) 125 | if not state and self.found: 126 | interactive.commit(sid) 127 | break 128 | if not self.found: 129 | sid = interactive.give_up(0) 130 | interactive.commit(sid) 131 | 132 | @property 133 | def info(self): 134 | return dict( 135 | use_beam_search=False, 136 | beam_width=None, 137 | num_samples=self.num_samples, 138 | max_nodes=self.max_nodes, 139 | max_depth=self.max_depth) 140 | -------------------------------------------------------------------------------- /Realprover/manager/thirdparty/generator.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | from typing import Any 4 | from vllm import LLM, SamplingParams 5 | from vllm.lora.request import LoRARequest 6 | 7 | from manager.manage import ModelManage, PromptManage 8 | from manager.struct import Goal, state_repr 9 | from manager.thirdparty import LeanSearch, Claude 10 | import conf.config 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | class TacticGenerator: 16 | def __init__(self, 17 | model_list: list[str], 18 | gpu_id: int, 19 | local_model_path: str=conf.config.PROVER_MODEL_PATH, 20 | sampling_params: dict['str', Any]=conf.config.PROVER_MODEL_PARAMS, # n excluded 21 | max_calls: int=conf.config.MAX_CALLS): 22 | """ 23 | model_list: 支持多个model 24 | """ 25 | self.gpu_id = gpu_id 26 | self.model_list = model_list 27 | self.calls = [] 28 | self.llm = None 29 | self.model_path = local_model_path 30 | self.sampling_params = sampling_params 31 | self.max_calls = max_calls 32 | 33 | def from_state_str(self, state_str: str, num_samples: int, incontext: str=None, template: str = 'deepseek', use_retrieval: bool=True) -> tuple[list[str], list[float]]: 34 | try: 35 | tactics, logprobs, prompt = self.get_lean_tactics(state_str, num_samples=num_samples, incontext=incontext, template=template, use_retrieval=use_retrieval) 36 | self.calls.append((state_str, tactics, logprobs, prompt)) 37 | except Exception: 38 | logger.exception("message") 39 | 40 | tactics, logprobs = [], [] 41 | return tactics, logprobs 42 | 43 | def from_state(self, state: list[Goal], num_samples: int, incontext: str=None, template: str = 'deepseek', use_retrieval: bool=True) -> tuple[list[str], list[float]]: 44 | return self.from_state_str(state_repr(state), num_samples, incontext, template, use_retrieval) 45 | 46 | def from_state_batch(self, states: list[list[Goal]], num_samples: int, incontext: list[str]=None) -> list[tuple[list[str], list[float]]]: 47 | state_strs = [state_repr(s) for s in states] 48 | related_theorems = LeanSearch.get_related_theorem_batch(state_strs) 49 | if incontext is None: 50 | prompts = [PromptManage.build_local_prompt_str(s, r) 51 | for s, r in zip(state_strs, related_theorems)] 52 | else: 53 | prompts = [PromptManage.build_local_prompt_str(s, r, c) 54 | for s, r, c in zip(state_strs, related_theorems, incontext)] 55 | self._init_model() # 懒加载,用到时才加载模型 56 | assert self.llm is not None 57 | sampling_params = SamplingParams(n=num_samples, **self.sampling_params) 58 | outputs = self.llm.generate(prompts, sampling_params) 59 | results = [] 60 | 61 | for state, prompt, output in zip(states, prompts, outputs): 62 | response = [ot.text.strip() for ot in output.outputs] 63 | # 为完整记录,在此处不过滤不合法tactic 64 | # response = [i for i in response if not "sorry" in i] 65 | logprob = [output.cumulative_logprob / max(len(output.token_ids), 1) # type: ignore 66 | for output in 
output.outputs] 67 | results.append((response, logprob)) 68 | self.calls.append((state, prompt, (response, logprob))) 69 | return results 70 | 71 | 72 | def get_lean_tactics(self, state: str, num_samples: int, incontext: str=None, template: str = 'deepseek', use_retrieval: bool=True) -> tuple[list[str], list[float], str]: 73 | # 获取相关定理 74 | if use_retrieval: 75 | related_theorems = LeanSearch.get_related_theorem(state) 76 | else: 77 | related_theorems = None 78 | if incontext is None: 79 | prompt = PromptManage.build_local_prompt_str(state, related_theorems) 80 | else: 81 | prompt = PromptManage.build_local_incontext_prompt_str(incontext, state, related_theorems, template) 82 | responses = [] 83 | logprobs = [] 84 | sampling_params = SamplingParams(n=num_samples, **self.sampling_params) 85 | 86 | # Get tactics from Claude if requested 87 | if ModelManage.contain_gemini(self.model_list): 88 | claude_responses = asyncio.run(Claude().get_claude_tactics(state, related_theorems, num_samples)) 89 | responses.extend(claude_responses) 90 | 91 | # Get tactics from local model if requested 92 | if ModelManage.contain_local(self.model_list): 93 | self._init_model() # 懒加载,用到时才加载模型 94 | assert self.llm is not None 95 | sampling_params = SamplingParams(n=num_samples, **self.sampling_params) 96 | outputs = self.llm.generate([prompt], sampling_params, use_tqdm=False) 97 | # lora = LoRARequest("new_data", self.gpu_id, "/AI4M/users/nhhuang/LLaMA-Factory/ds_stepprover_algebra_together") 98 | # outputs = self.llm.generate([prompt], sampling_params, use_tqdm=False, lora_request=lora) 99 | local_responses = [output.text.strip() for output in outputs[0].outputs] 100 | local_responses = [i for i in local_responses if not "sorry" in i] 101 | local_logprobs = [output.cumulative_logprob / max(len(output.token_ids), 1) # type: ignore 102 | for output in outputs[0].outputs] 103 | responses.extend(local_responses) 104 | logprobs.extend(local_logprobs) 105 | 106 | return responses, logprobs, prompt 107 | 108 | def has_quota(self) -> bool: 109 | return len(self.calls) < self.max_calls 110 | 111 | def reset_calls(self, formal_statement:str = None): 112 | """ 113 | 单进程只实例化一次generator, 每次执行时需重置calls 114 | """ 115 | self.calls = [] 116 | self.formal_statement = formal_statement 117 | 118 | def _init_model(self): 119 | if self.llm is None: 120 | self.llm = LLM(model=self.model_path) # Lora 121 | 122 | @property 123 | def info(self): 124 | return dict( 125 | model_path=self.model_path, 126 | max_calls=self.max_calls, 127 | sampling_params=self.sampling_params) -------------------------------------------------------------------------------- /Realprover/manager/thirdparty/interactive.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import json 3 | import logging 4 | import subprocess 5 | from io import TextIOWrapper 6 | from pathlib import Path 7 | 8 | import conf.config as config 9 | from manager.struct.structs import from_json, Goal, ProofGoal, ProofVariable 10 | import os 11 | logging.basicConfig(level=logging.INFO, format='%(levelname)s %(asctime)s [%(name)s] %(message)s') 12 | logger = logging.getLogger(os.path.basename(__file__)) 13 | IS_BUILD = False 14 | 15 | def get_project_toolchain(root: Path) -> str: 16 | with (root / "lean-toolchain").open() as fp: 17 | return fp.read() 18 | 19 | 20 | def build_with_toolchain(root: Path, toolchain: str) -> Path: 21 | logger.info("Building interactive for %s", toolchain) 22 | if IS_BUILD: 23 | with (root / 
"lean-toolchain").open("w") as fp: 24 | fp.write(toolchain) 25 | # logger.info("Building interactive for %s", toolchain) 26 | if subprocess.call( 27 | ["lake", "build"], 28 | stdout=subprocess.DEVNULL, 29 | cwd=root) != 0: 30 | raise RuntimeError(f"Failed to build: {root} with {toolchain}") 31 | return root / ".lake" / "build" / "bin" 32 | 33 | 34 | def build_interactive(toolchain: str) -> Path: 35 | bin_path = build_with_toolchain(config.interactive_path, toolchain) 36 | return bin_path / "interactive" 37 | 38 | 39 | class Interactive: 40 | def __init__(self, root: Path, path: Path): 41 | toolchain = get_project_toolchain(root) 42 | interactive_bin = build_interactive(toolchain) 43 | args = ["lake", "env", interactive_bin, "-i", path] 44 | self.proc = subprocess.Popen( 45 | args, 46 | stdin=subprocess.PIPE, 47 | stdout=subprocess.PIPE, 48 | cwd=root, 49 | ) 50 | self.read_from = self.proc.stdout 51 | self.write_to = TextIOWrapper(self.proc.stdin, line_buffering=True) 52 | self.id = 0 53 | self.tactic_mode = False 54 | 55 | def read(self) -> dict: 56 | line = self.read_from.readline().strip() 57 | logger.debug("<-" + repr(line)) 58 | return json.loads(line) 59 | 60 | def write(self, data: dict): 61 | data = json.dumps(data, separators=(',', ':'), ensure_ascii=False) 62 | logger.debug("->" + data) 63 | print(data, file=self.write_to) 64 | 65 | def open_file(self, path: Path, selectors: list[str | int | None]): 66 | assert not self.tactic_mode 67 | self.write({"filename": str(path), "selectors": selectors}) 68 | 69 | def get_next_problem(self) -> str | None: 70 | assert not self.tactic_mode 71 | logger.debug("get_next_problem") 72 | response = self.read() 73 | decl_name = response.get("declName") 74 | self.tactic_mode = decl_name is not None 75 | return decl_name 76 | 77 | def request(self, method: str, params: dict): 78 | assert self.tactic_mode 79 | self.id += 1 80 | self.write({'id': self.id, 'method': method, 'params': params}) 81 | response = self.read() 82 | try: 83 | return response['result'] 84 | except KeyError: 85 | raise RuntimeError(response['error']) 86 | 87 | def run_tactic(self, sid: int, tactic: str, heartbeats: int = 200000000) -> int: 88 | return self.request('runTactic', {'sid': sid, 'tactic': tactic, 'heartbeats': heartbeats}) 89 | 90 | def get_state(self, sid: int) -> list[Goal]: 91 | res = self.request('getState', {'sid': sid}) 92 | return from_json(list[Goal], res) 93 | 94 | def get_messages(self, sid: int) -> list[str]: 95 | res = self.request('getMessages', {'sid': sid}) 96 | return from_json(list[str], res) 97 | 98 | def resolve_name(self, sid: int, name: str) -> list[tuple[str, list[str]]]: 99 | return self.request('resolveName', {'sid': sid, 'name': name}) 100 | 101 | def unify(self, sid: int, s1: str, s2: str) -> list[tuple[str, str | None]] | None: 102 | return self.request('unify', {'sid': sid, 's1': s1, 's2': s2}) 103 | 104 | def new_state(self, state: list[ProofGoal]) -> int: 105 | return self.request('newState', {'state': [dataclasses.asdict(g) for g in state]}) 106 | 107 | def get_position(self) -> dict: 108 | return self.request('getPosition', {}) 109 | 110 | def give_up(self, sid: int) -> int: 111 | return self.request('giveUp', {'sid': sid}) 112 | 113 | def commit(self, sid: int): 114 | self.request('commit', {'sid': sid}) 115 | self.tactic_mode = False 116 | 117 | 118 | if __name__ == '__main__': 119 | logging.basicConfig(level=logging.DEBUG) 120 | interactive = Interactive(Path(__file__).parent, Path("Example.lean")) 121 | sid = 0 122 | 
interactive.open_file(Path("Example.lean"), ["neg_is_some_none", "eq_trans_sym"])
123 | 
124 |     while True:
125 |         problem = interactive.get_next_problem()
126 |         if problem is None:
127 |             break
128 | 
129 |         print("Solving problem", problem)
130 |         sid = 0
131 | 
132 |         while True:
133 |             state = interactive.get_state(sid)
134 |             if not state:
135 |                 # proof complete
136 |                 interactive.commit(sid)
137 |                 break
138 |             print("current state id:", sid)
139 |             print(f"{len(state)} goals:")
140 |             for goal in state:
141 |                 print(goal.pretty)
142 | 
143 |             command = input()
144 |             try:
145 |                 if command.startswith(":s"):
146 |                     sid = int(command.split()[1])
147 |                 elif command.startswith(":u"):
148 |                     params = command.split()[1:]
149 |                     result = interactive.unify(sid, params[0], params[1])
150 |                     print(result)
151 |                 elif command.startswith(":n"):
152 |                     sid = interactive.new_state([
153 |                         ProofGoal(context=[ProofVariable("a", "Nat")], type="Nat"),
154 |                         ProofGoal(context=[ProofVariable("b", "Nat")], type="Nat")
155 |                     ])
156 |                 else:
157 |                     sid = interactive.run_tactic(sid, command)
158 |                 messages = interactive.get_messages(sid)
159 |                 if messages:
160 |                     print("Messages:", messages)
161 |             except RuntimeError as e:
162 |                 err = e.args[0]
163 |                 print(f"Error {err['code']}: {err['message']}")
164 |                 print(err['data'])
--------------------------------------------------------------------------------
/Realprover/manager/manage/prompt_manage.py:
--------------------------------------------------------------------------------
 1 | class PromptManage(object):
 2 |     """
 3 |     Prompt construction helpers.
 4 |     """
 5 |     @staticmethod
 6 |     def chat_template_to_prompt(messages, template):
 7 |         result = ""
 8 |         if template=='qwen':
 9 |             result += "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
10 |         # elif template=='deepseek':
11 |         #     result += "<|begin▁of▁sentence|>"
12 |         total_step = len(messages)
13 |         for i, message in enumerate(messages):
14 |             if template == 'internlm':
15 |                 result += ('<|im_start|>' + message['role'] +
16 |                            '\n' + message['content'])
17 |                 if i+1 != total_step:
18 |                     result += '<|im_end|>\n'
19 |                 elif message['role'] == 'user':
20 |                     result += '<|im_end|>\n<|im_start|>assistant\n'
21 |             elif template=='deepseek':
22 |                 if message['role']=='user':
23 |                     result += 'User: ' + message['content'] + '\n\n'
24 |                 elif message['role']=='assistant':
25 |                     result += 'Assistant:' + message['content'] + '<|end▁of▁sentence|>'
26 |                 elif message['role'] == 'system':
27 |                     result += message['content'] + '\n\n'
28 |                 if i+1 == total_step and message['role'] == 'user':
29 |                     result += 'Assistant:'
30 |             elif template=='qwen':
31 |                 if message['role'] == 'user':
32 |                     result += f"<|im_start|>user\n{message['content']}<|im_end|>\n"
33 |                 elif message['role'] == 'assistant':
34 |                     result += f"<|im_start|>assistant\n{message['content']}<|im_end|>\n"
35 |                 if i+1 == total_step and message['role'] == 'user':
36 |                     result += "<|im_start|>assistant\n"
37 |             elif template=='deepseek3':
38 |                 if message['role'] == 'user':
39 |                     result += f"<|User|>{message['content']}"
40 |                 elif message['role'] == 'assistant':
41 |                     result += f"<|Assistant|>{message['content']}<|end▁of▁sentence|>"
42 |                 if i+1 == total_step and message['role'] == 'user':
43 |                     result += "<|Assistant|>"
44 |             else:
45 |                 raise NotImplementedError
46 |         return result
47 | 
48 |     @staticmethod
49 |     def build_local_prompt_str(state: str, related_theorems: list[dict['str', 'str']]) -> str:
50 |         theorems_str = PromptManage.build_theorems_str(related_theorems)
51 |         prompt = f"Please generate a tactic in lean4 to solve the state.\nHere're some theorems that may be helpful:\n{theorems_str}\nSTATE:\n{state}\nTACTIC:\n"
52 |         prompt = PromptManage.chat_template_to_prompt(
53 |             [{'role': 'user', 'content': prompt}], 'deepseek'
54 |         )
55 |         return prompt
56 | 
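    # For reference, a sketch of the string build_local_prompt_str produces via the
    # 'deepseek' branch above, assuming one hypothetical retrieved theorem:
    #
    #   User: Please generate a tactic in lean4 to solve the state.
    #   Here're some theorems that may be helpful:
    #   ID:0
    #   Formal name: Nat.add_comm
    #   Informal name: Commutativity of natural number addition
    #   Formal statement: theorem Nat.add_comm (n m : Nat) : n + m = m + n
    #   STATE:
    #   n : Nat
    #   ⊢ n + 0 = n
    #   TACTIC:
    #
    #
    #   Assistant: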
 57 |     @staticmethod
 58 |     def build_local_incontext_prompt_str(incontext: str, state: str, related_theorems: list[dict['str', 'str']], template: str = 'deepseek'):
 59 |         if related_theorems is not None:
 60 |             theorems_str = PromptManage.build_theorems_str(related_theorems)
 61 |         prompt = """In Lean, a formal proof is a fully constructed proof term that is type-checked and verified by the kernel. It represents a complete and correct derivation of a proposition.
 62 | 
 63 | The state after tactics refers to the intermediate proof state during tactic-based proof construction. It includes the list of remaining goals and the local context at that point.
 64 | 
 65 | Relationship:
 66 | 
 67 | - Tactics are procedural tools used to incrementally construct a formal proof.
 68 | - Each tactic transforms the current proof state by solving or reducing goals.
 69 | - The state after a tactic reflects the goals that still need to be proven after that tactic has been applied.
 70 | - Once all goals are solved, Lean assembles the underlying proof terms generated by the tactics into a complete formal proof.
 71 | - This final term is then type-checked by the kernel to ensure correctness.
 72 | 
 73 | In essence, the state after tactics shows where you are in the process of building a formal proof — it's a snapshot of what's left to do before the proof is complete.
 74 | 
 75 | Here is the FORMAL PROOF before the current state:
 76 | """
 77 |         prompt += incontext
 78 |         prompt += "\nHere is the current STATE:\n"
 79 |         prompt += state
 80 |         prompt += "\n\n**Please generate a TACTIC in lean4 to solve the state.**"
 81 |         if related_theorems is not None:
 82 |             prompt += "\n\nAnd here're some theorems that may be helpful:\n"
 83 |             prompt += theorems_str
 84 |         prompt = PromptManage.chat_template_to_prompt(
 85 |             [{'role': 'user', 'content': prompt}], template
 86 |         )
 87 |         return prompt
 88 | 
 89 |     @staticmethod
 90 |     def build_claude_prompt_str(state, related_theorems):
 91 |         theorems_str = PromptManage.build_theorems_str(related_theorems)
 92 |         claude_prompt = f"""You are a Lean4 theorem prover assistant. Given the state and some potentially helpful theorems,
 93 | generate ONE SINGLE tactic step that could help prove the goal. Output ONLY the tactic, nothing else.
 94 | 
 95 | Relevant theorems:
 96 | {theorems_str}
 97 | 
 98 | Current state:
 99 | {state}
100 | 
101 | Output the tactic:"""
102 |         return claude_prompt
103 | 
104 |     @staticmethod
105 |     def build_theorems_str(related_theorems):
106 |         return "\n\n".join([
107 |             f"ID:{index}\nFormal name: {i['Formal name']}\nInformal name: {i['Informal name']}\nFormal statement: {i['Formal statement']}"
108 |             for index, i in enumerate(related_theorems[:6])
109 |         ])
110 | 
111 |     @staticmethod
112 |     def build_claude_critic_str(state1, tactic, state2):
113 |         return f"""You are a Lean4 theorem prover assistant. Given the state before the tactic and the current tactic, tell whether the current tactic is trying to make repetitive or useless propositions. If so, output FALSE to reject it; otherwise output TRUE. Please output ONLY TRUE or FALSE, nothing else.
114 | 
115 | State before current tactic:
116 | {state1}
117 | 
118 | Current tactic:
119 | {tactic}
120 | 
121 | Your output:"""
122 | 
123 | if __name__=='__main__':
124 |     state = "case:\nn : Nat"
125 |     theorems = [{
126 |         'Formal name': 'Test.test_thm_aux_1',
127 |         'Informal name': 'Test\'s First Theorem',
128 |         'Formal statement': 'theorem_test_1'}, {
129 |         'Formal name': 'Test.test_thm_aux_2',
130 |         'Informal name': 'Test\'s Second Theorem',
131 |         'Formal statement': 'theorem_test_2'}]
132 |     print(PromptManage.build_local_prompt_str(state, theorems))
--------------------------------------------------------------------------------
/Herald/service/parallel_service.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import torch
 3 | import multiprocessing as mp
 4 | 
 5 | from service.handler import TranHandler, BackHandler
 6 | from util import CommonUtil
 7 | 
 8 | 
 9 | class ParallelService(object):
10 |     """
11 |     Multi-GPU parallel service: translation runs on a local model.
12 |     """
13 | 
14 |     def __init__(self, source_file, result_dir, re_run=False,
15 |                  trans_gpus=None, back_gpus=None):
16 |         """
17 |         source_file: path to the input jsonl file
18 |         result_dir: output directory
19 |         re_run: when a result file already exists, True: run again, False: skip
20 |         trans_gpus: GPUs used for translation
21 |         back_gpus: GPUs used for back-translation
22 |         """
23 |         if trans_gpus is None:
24 |             trans_gpus = [0]
25 |         if back_gpus is None:
26 |             back_gpus = [1]
27 | 
28 |         self.source_file = source_file
29 |         self.result_dir = result_dir
30 |         self.re_run = re_run
31 |         self.trans_gpus = trans_gpus
32 |         self.back_gpus = back_gpus
33 | 
34 |         # Initialize the source data
35 |         self._init_source_list()
36 | 
37 |         self.manager = mp.Manager()
38 |         self.shared_dict = self.manager.dict()  # create a dict shared across processes
39 |         self._init_one_share_data()
40 | 
41 |         self.trans_queue = mp.Queue()
42 |         self.back_queue = mp.Queue()
43 |         self.lock = mp.Lock()
44 | 
45 |         # Handlers keyed by gpu_id
46 |         self.handler = self.manager.dict()
47 | 
48 |     def _init_handler(self):
49 |         for i in self.trans_gpus:
50 |             self.handler[i] = TranHandler()
51 |         for j in self.back_gpus:
52 |             self.handler[j] = BackHandler()
53 | 
54 |     def _init_source_list(self):
55 |         source_list = CommonUtil.read_json_list(self.source_file)
56 |         # unique_key serves as the unique index and as each item's result directory name
57 |         for index, item in enumerate(source_list):
58 |             item['unique_key'] = f"index_{index}"
59 |         self.source_list = source_list
60 | 
61 |     def _init_one_share_data(self):
62 |         """
63 |         Initialize the shared result structures.
64 |         """
65 |         for item in self.source_list:
66 |             temp_dict = {
67 |                 'unique_key': item['unique_key'],
68 |                 'informal_statement': item['informal_statement'],
69 |                 'translate_list': [],
70 |                 'back_trans_list': [],
71 |             }
72 |             temp_dict.update(item)
73 |             self.shared_dict[item['unique_key']] = self.manager.dict(temp_dict)
74 | 
75 |     def _set_dict_data(self, key_name, type_name, data_list):
76 |         try:
77 |             with self.lock:  # lock to keep operations on shared_dict thread-safe
78 |                 self.shared_dict[key_name][type_name] = data_list
79 |         except BrokenPipeError as e:
80 |             print(f"Error: {e}")
81 | 
82 |     def _init_trans_queue(self):
83 |         for item in self.source_list:
84 |             if not self.re_run:
85 |                 # When not re-running, skip items whose result file already exists
86 |                 back_file_path = f"{self.result_dir}/{item['unique_key']}/back_trans.json"
87 |                 if CommonUtil.file_exist(back_file_path):
88 |                     continue
89 |             self.trans_queue.put(item)
90 | 
91 |         for _ in self.trans_gpus:
92 |             self.trans_queue.put(None)
93 | 
94 |     def run(self):
95 |         gpus = torch.cuda.device_count()
96 |         assert gpus >= 2
97 |         self._init_handler()
98 |         self._init_trans_queue()
99 |         print(f"data_list_size = {len(self.source_list)}")
100 | 
101 |         trans_process_list = []
102 |         back_process_list = []
103 |         for i in self.trans_gpus:
104 |             # Spawn the translation processes
105 |             process1 = mp.Process(target=self._run_translate, args=(i,))
106 |             process1.start()
107 |             trans_process_list.append(process1)
108 |         for j in self.back_gpus:
109 |             # Spawn the back-translation processes
110 |             process2 = mp.Process(target=self._run_back_trans, args=(j,))
111 |             process2.start()
112 |             back_process_list.append(process2)
113 | 
114 |         for i in range(len(trans_process_list)):
115 |             trans_process_list[i].join()
116 | 
117 |         # Send the stop signal to stage 2; it runs several processes and each one must receive its own None to exit cleanly.
118 |         for _ in self.back_gpus:
119 |             self.back_queue.put(None)
120 | 
121 |         for j in range(len(back_process_list)):
122 |             back_process_list[j].join()
123 | 
124 |         print("All data processed.")
125 | 
126 |     def _run_translate(self, gpu_id):
127 |         """
128 |         Translate and compile-check.
129 |         """
130 |         # Pin this process to its GPU
131 |         os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
132 |         this_handler = self.handler[gpu_id]
133 | 
134 |         while True:
135 |             item = self.trans_queue.get()
136 |             if item is None:  # stop signal
137 |                 break
138 |             self._init_dir(item['unique_key'])
139 |             generate_list = this_handler.generate_and_check(item['informal_statement'])
140 |             item['translate_list'] = generate_list
141 |             print(f"translate_list.size = {len(item['translate_list'])}")
142 | 
143 |             self._set_dict_data(item['unique_key'], 'translate_list', generate_list)
144 |             self._save_trans_data(item)
145 |             self.back_queue.put(item)  # hand the result to the back-translation stage
146 | 
147 |         # self.back_queue.put(None)  # stop signal (sent from run() instead)
148 | 
149 |     def _run_back_trans(self, gpu_id):
150 |         """
151 |         Back-translate & compare.
152 |         """
153 |         os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
154 |         this_handler = self.handler[gpu_id]
155 |         while True:
156 |             item = self.back_queue.get()
157 |             if item is None:  # stop signal
158 |                 break
159 |             data_list = [{
160 |                 'informal_statement': item['informal_statement'],
161 |                 'formal_statement': i
162 |             } for i in item['translate_list']]
163 |             valid_list = this_handler.back_compare_filter(data_list)
164 |             valid_formal_list = [i['formal_statement'] for i in valid_list]
165 |             print(f"valid_formal_list.size = {len(valid_formal_list)}")
166 | 
167 |             self._set_dict_data(item['unique_key'], 'back_trans_list', valid_formal_list)
168 |             self._save_back_trans_data(item)
169 | 
170 |     def _init_dir(self, file_name):
171 |         temp_dir = f"{self.result_dir}/{file_name}"
172 |         if not os.path.exists(temp_dir):
173 |             os.makedirs(temp_dir)
174 | 
175 |     def _save_trans_data(self, item):
176 |         CommonUtil.write_to_json_file(self._gen_file_path(item), dict(self.shared_dict[item['unique_key']]))
177 | 
178 |     def _save_back_trans_data(self, item):
179 |         CommonUtil.write_to_json_file(self._gen_file_path(item, 'back'), dict(self.shared_dict[item['unique_key']]))
180 | 
181 |     def _gen_file_path(self, item, type_name='trans'):
182 |         return f"{self.result_dir}/{item['unique_key']}/translate.json" if type_name == "trans" \
183 |             else f"{self.result_dir}/{item['unique_key']}/back_trans.json"
184 | 
--------------------------------------------------------------------------------
/Realprover/manager/struct/structs.py:
--------------------------------------------------------------------------------
 1 | import dataclasses
 2 | import types
 3 | import typing
 4 | from functools import cached_property
 5 | from dataclasses import dataclass, field
 6 | from typing import Optional
 7 | from operator import attrgetter
 8 | 
 9 | Name = list[str | int]
10 | ImportInfoRaw = list[Name]
11 | 
12 | @dataclass
13 | class StringRange:
14 |     """Byte position within a file"""
15 |     start: int
16 |     stop: 
int 17 | 18 | def as_slice(self) -> slice: 19 | return slice(self.start, self.stop) 20 | 21 | @dataclass 22 | class Param: 23 | bi: str 24 | type: Optional[StringRange] 25 | 26 | 27 | @dataclass 28 | class Modifiers: 29 | visibility: str 30 | is_noncomputable: bool 31 | rec_kind: str 32 | is_unsafe: bool 33 | doc_string: Optional[str] = field(default=None) 34 | 35 | 36 | @dataclass 37 | class Syntax: 38 | original: bool 39 | range: Optional[StringRange] 40 | 41 | 42 | @dataclass 43 | class DeclarationInfoRaw: 44 | kind: str 45 | id: Optional[Syntax] 46 | name: Name 47 | fullname: Name 48 | modifiers: Modifiers 49 | params: Optional[list[Param]] 50 | type: Optional[Syntax] 51 | value: Optional[Syntax] 52 | tactics: list[Syntax] 53 | ref: Optional[Syntax] 54 | 55 | 56 | @dataclass 57 | class SymbolInfoRaw: 58 | name: Name 59 | type: Optional[str] 60 | kind: str 61 | typeReferences: Optional[list[Name]] 62 | valueReferences: Optional[list[Name]] 63 | isProp: bool 64 | 65 | 66 | @dataclass 67 | class Module: 68 | pass 69 | 70 | 71 | @dataclass 72 | class Variable: 73 | name: Name 74 | type: str 75 | is_prop: bool 76 | binder_info: str = field(default="default") 77 | value: Optional[str] = field(default=None) 78 | 79 | @cached_property 80 | def pretty(self) -> str: 81 | s = f"{pretty_name(self.name)} : {self.type}" 82 | if self.value is not None: 83 | s += " := " + self.value 84 | return s 85 | 86 | @cached_property 87 | def as_param(self) -> str: 88 | if self.value is not None: 89 | raise ValueError("Let-bindings should not be used as parameters") 90 | if self.binder_info == "instImplicit": 91 | return f"[{self.type}]" 92 | if self.binder_info == "default": 93 | l, r = "(", ")" 94 | elif self.binder_info == "implicit": 95 | l, r = "{", "}" 96 | elif self.binder_info == "strictImplicit": 97 | l, r = "{{", "}}" 98 | else: 99 | raise RuntimeError("Unexpected binder_info") 100 | return f"{l}{self.pretty}{r}" 101 | 102 | 103 | @dataclass 104 | class Goal: 105 | context: list[Variable] 106 | type: str 107 | is_prop: bool 108 | 109 | @cached_property 110 | def pretty(self) -> str: 111 | return "\n".join(v.pretty for v in self.context) + "\n⊢ " + self.type 112 | 113 | @cached_property 114 | def as_signature(self) -> str: 115 | return " ".join(v.as_param for v in self.context if v.value is None) + " : " + self.type 116 | 117 | def state_repr(state: Goal | list[Goal]) -> str: 118 | if not state: 119 | return "no goals" 120 | if isinstance(state, Goal): 121 | return state.pretty 122 | if isinstance(state, list): 123 | if len(state) == 1: 124 | return state[0].pretty 125 | state_str = f"case:\n{state[0].pretty}" 126 | for i in state[1:]: 127 | state_str += f"\n\ncase:\n{i.pretty}" 128 | return state_str 129 | raise TypeError 130 | 131 | 132 | def state_repr_dedup(state: list[Goal]) -> str: 133 | k = [] 134 | for goal in state: 135 | prop = [v for v in goal.context if v.is_prop] 136 | prop.sort(key=attrgetter("type")) 137 | dedup_prop = [] 138 | for p in prop: 139 | if not dedup_prop or dedup_prop[-1].type != p.type: 140 | dedup_prop.append(p) 141 | non_prop = [v for v in goal.context if not v.is_prop] 142 | k.append("\n".join(v.pretty for v in dedup_prop + non_prop) + "\n⊢ " + goal.type) 143 | k = "\n\n".join(k) 144 | # print(k) 145 | # import ipdb;ipdb.set_trace() 146 | return k 147 | 148 | @dataclass 149 | class ProofVariable: 150 | name: str 151 | type: str 152 | 153 | 154 | @dataclass 155 | class ProofGoal: 156 | context: list[ProofVariable] 157 | type: str 158 | 159 | 160 | @dataclass 161 | class 
TacticInfo:
162 |     kind: Name
163 |     original: bool
164 |     range: Optional[StringRange] = field(default=None)
165 | 
166 | 
167 | @dataclass
168 | class TacticElabInfo:
169 |     tactic: TacticInfo
170 |     references: Optional[list[Name]]
171 |     before: list[Goal]
172 |     after: list[Goal]
173 | 
174 | 
175 | # TODO: add «» quoting when necessary
176 | def pretty_name(name: Name):
177 |     return ".".join(str(c) for c in name)
178 | 
179 | 
180 | # TODO: use pydantic
181 | def snake_to_camel(s: str) -> str:
182 |     words = s.split("_")
183 |     return words[0] + "".join(w.capitalize() for w in words[1:])
184 | 
185 | 
186 | def extract_field(data: dict, f: dataclasses.Field):
187 |     k = snake_to_camel(f.name)
188 |     if f.default is not dataclasses.MISSING:
189 |         return data.get(k, f.default)
190 |     elif f.default_factory is not dataclasses.MISSING:
191 |         return data.get(k, f.default_factory())
192 |     else:
193 |         return data[k]
194 | 
195 | @dataclass
196 | class Node:
197 |     sid: int
198 |     parent_sid: int
199 |     tactic: str
200 |     state: list[Goal]
201 |     depth: int = 0
202 |     score: float = 0
203 |     parent: Optional["Node"] = None
204 | 
205 |     @property
206 |     def current_path(self):
207 |         path = []
208 |         current = self
209 |         while current:
210 |             path.append(current)
211 |             current = current.parent
212 |         path.reverse()  # Reverse the path to get it from root to current node
213 |         return path
214 | 
215 | def from_json(tp: type, data):
216 |     origin = typing.get_origin(tp)
217 |     if origin is typing.Union or origin is types.UnionType:
218 |         args = typing.get_args(tp)
219 |         for arg in args:  # try each union member in turn; fall through on failure
220 |             try:
221 |                 return from_json(arg, data)
222 |             except:
223 |                 pass
224 |         raise TypeError(f"union {tp}")
225 |     elif origin is list:
226 |         (arg,) = typing.get_args(tp)
227 |         # print(data)
228 |         if not isinstance(data, list):
229 |             raise TypeError(f"list {tp}")
230 |         return [from_json(arg, x) for x in data]
231 |     elif origin is dict:
232 |         (kt, vt) = typing.get_args(tp)
233 |         if not isinstance(data, dict):
234 |             raise TypeError(f"dict {tp}")
235 |         assert kt is str
236 |         return {k: from_json(vt, v) for k, v in data.items()}
237 |     elif tp is StringRange:
238 |         if not isinstance(data, list):
239 |             raise TypeError(f"list {tp}")
240 |         return StringRange(start=data[0], stop=data[1])
241 |     elif dataclasses.is_dataclass(tp):
242 |         tp: dataclass
243 |         fields = dataclasses.fields(tp)
244 |         return tp(**{f.name: from_json(f.type, extract_field(data, f)) for f in fields})
245 |     else:
246 |         if not isinstance(data, tp):
247 |             # print(data, tp)
248 |             raise TypeError(tp)
249 |         return data
--------------------------------------------------------------------------------
/Realprover/manager/service/base_service.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import copy
 3 | import platform
 4 | from pathlib import Path
 5 | 
 6 | import conf.config
 7 | from manager.thirdparty import Interactive, TacticGenerator
 8 | from manager.struct import Node
 9 | from manager.search import BestFirstSearch, BeamSearch, MCTSSearch
10 | from manager.manage import ProofParseManage
11 | 
12 | 
13 | class BaseService(object):
14 |     """
15 |     Shared logic for running proof search over a single formal statement.
16 |     """
17 | 
18 |     def __init__(self,
19 |                  num_samples: int = conf.config.NUM_SAMPLES,
20 |                  max_nodes: int = conf.config.MAX_NODES,
21 |                  max_depth: int = conf.config.MAX_DEPTH,
22 |                  use_beam_search: bool = conf.config.USE_BEAM_SEARCH,
23 |                  use_mcts_search: bool = conf.config.USE_MCTS_SEARCH,
24 |                  simulation_depth: int = conf.config.SIM_DEPTH,
25 |                  c_puct: float = conf.config.C_PUCT,
26 |                  beam_width: int = conf.config.BEAM_WIDTH,
27 |                  root: Path = Path(conf.config.LEAN_TEST_PATH),
28 |                  lean_env: str = conf.config.LEAN_ENV_PATH,
29 |                  model_list: list = conf.config.DEFAULT_MODEL_LIST,
30 |                  local_model_path: str = conf.config.PROVER_MODEL_PATH,
31 |                  sampling_params: dict = conf.config.PROVER_MODEL_PARAMS,
32 |                  max_calls: int = conf.config.MAX_CALLS,
33 |                  max_root_expansion: int = conf.config.MAX_ROOT_EXPANSION,
34 |                  c_score: float = conf.config.C_SCORE,
35 |                  c_expansion_fail_penalty: float = conf.config.C_EXPANSION_FAIL_PENALTY,
36 |                  abandon_if_contain: list[str] = conf.config.ABANDON_IF_CONTAIN,
37 |                  is_incontext: bool = conf.config.IS_INCONTEXT,
38 |                  template: str = 'deepseek',
39 |                  use_retrieval: bool = conf.config.USE_RETRIEVAL
40 |                  ):
41 |         """
42 |         Search and runtime configuration; all defaults come from conf.config.
43 |         """
44 |         self.num_samples = num_samples
45 |         self.max_nodes = max_nodes
46 |         self.max_depth = max_depth
47 |         self.use_beam_search = use_beam_search
48 |         self.use_mcts_search = use_mcts_search
49 |         self.simulation_depth = simulation_depth
50 |         self.c_puct = c_puct
51 |         self.c_score = c_score
52 |         self.c_expansion_fail_penalty = c_expansion_fail_penalty
53 |         self.max_root_expansion = max_root_expansion
54 |         self.beam_width = beam_width
55 |         self.lean_env = lean_env
56 |         self.root = root
57 |         self.model_list = model_list
58 |         self.local_model_path = local_model_path
59 |         self.sampling_params = sampling_params
60 |         self.max_calls = max_calls
61 |         self.single_generator = None  # kept as an attribute so a deployed server does not reload the model on every request
62 |         self.abandon_if_contain = abandon_if_contain
63 |         self.info: dict = {}
64 |         self.is_incontext = is_incontext
65 |         self.template = template
66 |         self.use_retrieval = use_retrieval
67 | 
68 |     def process_one(self,
69 |                     source: str,
70 |                     generator: TacticGenerator) -> list[tuple[str, BestFirstSearch, list]]:
71 |         os.environ["PATH"] += ":" + str(self.lean_env)
72 |         interactive = Interactive(self.root, Path("Header.lean"))
73 | 
74 | 
75 |         test_file = self.root / f"TestOne_{platform.node()}_{os.getpid()}.lean"  # include hostname and pid so concurrent processes never write to the same file
76 |         with test_file.open("w") as fp:
77 |             fp.write(source)
78 |         interactive.open_file(test_file, [None])
79 | 
80 |         results = []
81 |         while True:
82 |             generator.reset_calls(source)
83 |             decl = interactive.get_next_problem()
84 |             if decl is None:
85 |                 break
86 |             if self.use_beam_search:
87 |                 search = BeamSearch(max_nodes=self.max_nodes,
88 |                                     max_depth=self.max_depth,
89 |                                     beam_width=self.beam_width,
90 |                                     num_samples=self.num_samples,
91 |                                     abandon_if_contain = self.abandon_if_contain)
92 |             elif self.use_mcts_search:
93 |                 search = MCTSSearch(max_nodes=self.max_nodes,
94 |                                     max_depth=self.max_depth,
95 |                                     num_samples=self.num_samples,
96 |                                     simulation_depth=self.simulation_depth,
97 |                                     c_puct=self.c_puct,
98 |                                     c_score=self.c_score,
99 |                                     c_expansion_fail_penalty=self.c_expansion_fail_penalty,
100 |                                     max_root_expansion=self.max_root_expansion,
101 |                                     max_calls=self.max_calls,
102 |                                     abandon_if_contain = self.abandon_if_contain)
103 |             else:
104 |                 search = BestFirstSearch(max_nodes=self.max_nodes,
105 |                                          max_depth=self.max_depth,
106 |                                          num_samples=self.num_samples,
107 |                                          abandon_if_contain = self.abandon_if_contain,
108 |                                          is_incontext=self.is_incontext,
109 |                                          template=self.template,
110 |                                          use_retrieval=self.use_retrieval)
111 |             if self.info == {}:
112 |                 self.info.update(generator.info)
113 |                 self.info.update(search.info)
114 | 
115 |             state = interactive.get_state(0)
116 |             search.insert(Node(0, 0, "", state))  # type: ignore
117 |             search.search_proof(generator, interactive)
118 |             results.append((decl, search, copy.copy(generator)))
119 | 
120 |         test_file.unlink()
121 |         return results
122 | 
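    # Illustrative end-to-end usage (a sketch, not part of the original file);
    # single_run below wires process_one and parse_result together:
    #
    #   service = BaseService()
    #   result = service.single_run("theorem foo : 1 + 1 = 2 := by sorry", gpu_id=0)
    #   result['collect_results'][0]['success']  # whether a proof was found
    #   result.get('formal_proof')               # the assembled proof on success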
123 |     @staticmethod
124 |     def collect_info(decl: str, search, generator) -> dict:
125 |         """ This function collects info for the final results
126 |         """
127 |         nodes = [{
128 |             "id": node.sid,
129 |             "parent": node.parent_sid,
130 |             "depth": node.depth,
131 |             "tactic": node.tactic,
132 |             "state": [goal.pretty for goal in node.state],
133 |         } for node in search.nodes.values()]
134 |         return {
135 |             "declaration": decl,
136 |             "success": search.found,
137 |             "calls": generator.calls,
138 |             "nodes": nodes,
139 |             "stop_cause": {
140 |                 "nodes": len(search.nodes) >= search.max_nodes,
141 |                 "depth": search.depth >= search.max_depth,
142 |                 "calls": not generator.has_quota()
143 |             }
144 |         }
145 | 
146 |     def parse_result(self, formal_statement, results):
147 |         """
148 |         Purpose: assemble the overall output information.
149 |         results: what process_one() produced.
150 |         """
151 |         collect_results = [BaseService.collect_info(decl, search, generator) for decl, search, generator in results]
152 |         ret = {
153 |             'formal_statement': formal_statement,
154 |             'collect_results': collect_results,
155 |         }
156 |         if collect_results and collect_results[0]['success']:
157 |             # formal_proof is only emitted on success
158 |             ret['formal_proof'] = ProofParseManage.get_correct_proof(ret)
159 |         return ret
160 | 
161 |     def single_run(self, formal_statement, gpu_id=0):
162 |         """
163 |         Run a single statement end to end and parse the return value.
164 |         """
165 |         os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
166 |         # this_generator = TacticGenerator(self.model_list, gpu_id, self.local_model_path)
167 |         self._init_single_generator(gpu_id)
168 |         results = self.process_one(formal_statement, self.single_generator)  # type: ignore
169 |         ret = self.parse_result(formal_statement, results)
170 |         return ret
171 | 
172 |     def _init_single_generator(self, gpu_id=0):
173 |         if self.single_generator is None:
174 |             self.single_generator = TacticGenerator(
175 |                 model_list=self.model_list,
176 |                 gpu_id=gpu_id,
177 |                 local_model_path=self.local_model_path,
178 |                 sampling_params=self.sampling_params,
179 |                 max_calls=self.max_calls)
--------------------------------------------------------------------------------
/Realprover/manager/service/pipeline_main_service.py:
--------------------------------------------------------------------------------
 1 | import threading
 2 | import torch.multiprocessing as mp
 3 | import os
 4 | import time
 5 | import sys
 6 | from pathlib import Path
 7 | import traceback
 8 | from manager.service import BaseService
 9 | from manager.thirdparty import TacticGenerator
10 | from util import CommonUtil #, profiler
11 | import conf.config
12 | from manager.thirdparty.verifier import verify_proof
13 | 
14 | 
15 | def data_producer(queue, counter, total_count, result_dir, file_prefix, re_run, num_workers=2):
16 |     """
17 |     Poll for the back_trans data produced by Herald and push each item onto the queue.
18 |     """
19 |     exist_keys = set()
20 | 
21 |     while True:
22 |         print("run_data_producer")
23 |         for index in range(total_count):
24 |             unique_key = f"{file_prefix}{index}"
25 |             tran_path = f"{result_dir}/{unique_key}/back_trans.json"
26 |             prove_path = f"{result_dir}/{unique_key}/prove_info.json"
27 |             if CommonUtil.file_exist(tran_path):
28 |                 if CommonUtil.file_exist(prove_path) and not re_run:
29 |                     # The result file already exists and no re-run was requested
30 |                     exist_keys.add(unique_key)
31 |                 else:
32 |                     if unique_key not in exist_keys:
33 |                         json_data = CommonUtil.load_json(tran_path)
34 |                         json_data['prove_path'] = prove_path
35 | 
36 |                         exist_keys.add(unique_key)
37 |                         queue.put(json_data)
38 |                         counter.value += 1
39 | 
40 |         print(f"exist_keys_size = {len(exist_keys)}")
41 |         print(f"queue_size: {counter.value}")
42 |         if len(exist_keys) >= total_count:
43 |             print(f"producer_finished_count = {len(exist_keys)}")
44 |             for _i in range(num_workers):
45 |                 # On completion, send one stop signal (None) per consumer
46 |                 queue.put(None)
47 |             break
48 | 
49 |         sys.stdout.flush()  # make sure this thread's log output is flushed
50 |         # Poll again every 20 seconds
51 |         time.sleep(20)
52 | 
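# Expected on-disk layout that data_producer polls (one directory per source item):
#
#   {result_dir}/
#       index_0/
#           back_trans.json   # written by the Herald stage, consumed here
#           prove_info.json   # written by the prover workers below when done
#       index_1/
#           ...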
 53 | 
 54 | class PipelineMainService(BaseService):
 55 |     """
 56 |     Processes the formal statements produced by the Herald pipeline.
 57 |     Handled separately because of its specific record structure: informal_statement, translate_list, back_trans_list.
 58 | 
 59 |     Overall approach:
 60 |     A polling thread checks each Herald-generated "back_trans.json" file and enqueues the item once the file exists; the poller stops after everything has been enqueued.
 61 |     Worker processes take pending tasks off the queue and, after finishing each item, save the result file to "prove_info.json".
 62 |     """
 63 | 
 64 |     def __init__(self,
 65 |                  source_file: str,
 66 |                  result_dir: str,
 67 |                  gpus_list=None,
 68 |                  num_samples: int = conf.config.NUM_SAMPLES,
 69 |                  max_calls: int = conf.config.MAX_NODES,
 70 |                  root: Path = Path(conf.config.LEAN_TEST_PATH),
 71 |                  lean_env: str = conf.config.LEAN_ENV_PATH,
 72 |                  model_list: list = conf.config.DEFAULT_MODEL_LIST,
 73 |                  local_model_path: str = conf.config.PROVER_MODEL_PATH,
 74 |                  re_run=False
 75 |                  ):
 76 |         """
 77 |         See BaseService for the search parameters; re_run controls whether existing prove_info.json results are recomputed.
 78 |         """
 79 |         super().__init__(num_samples=num_samples, max_nodes=max_calls, root=root, lean_env=lean_env,
 80 |                          model_list=model_list, local_model_path=local_model_path)
 81 |         self.source_file = source_file
 82 |         self.result_dir = result_dir
 83 |         if gpus_list is None:
 84 |             gpus_list = [0]
 85 |         self.gpus_list = gpus_list
 86 | 
 87 |         self.source_list = CommonUtil.read_json_list(self.source_file)
 88 |         self.length = len(self.source_list)
 89 |         self.back_data_list = []
 90 |         self.file_prefix = f"index_"
 91 | 
 92 |         self.manager = mp.Manager()
 93 | 
 94 |         self.queue = mp.Queue()
 95 |         self.counter = mp.Value('i', 0)  # counter tracking how many items sit in the queue
 96 |         self.generator = self.manager.dict()
 97 |         self.re_run = re_run
 98 | 
 99 |     def _check_add_task_to_queue(self):
100 |         """Start the background producer thread that polls for new back_trans data."""
101 | 
102 |         producer_thread = threading.Thread(target=data_producer, args=(self.queue, self.counter, len(self.source_list),
103 |                                                                        self.result_dir, self.file_prefix,
104 |                                                                        self.re_run, len(self.gpus_list)),
105 |                                            daemon=True)
106 |         producer_thread.start()
107 | 
108 |     def _add_exist_to_queue(self):
109 |         """
110 |         Enqueue the items whose back-translation has already finished.
111 |         """
112 |         print("_add_exist_to_queue")
113 |         add_index_list = []
114 |         for index in range(len(self.source_list)):
115 |             unique_key = f"{self.file_prefix}{index}"
116 |             tran_path = f"{self.result_dir}/{unique_key}/back_trans.json"
117 |             prove_path = f"{self.result_dir}/{unique_key}/prove_info.json"
118 |             if not CommonUtil.file_exist(tran_path):
119 |                 continue
120 |             if CommonUtil.file_exist(prove_path) and not self.re_run:
121 |                 continue
122 |             json_data = CommonUtil.load_json(tran_path)
123 |             json_data['prove_path'] = prove_path
124 | 
125 |             add_index_list.append(index)
126 |             self.queue.put(json_data)
127 |             self.counter.value += 1
128 |         for _ in self.gpus_list:
129 |             self.queue.put(None)
130 |         print(
131 |             f"add_count = {len(add_index_list)}, min_index = {min(add_index_list)}, max_index = {max(add_index_list)}")
132 | 
133 |     def _int_generator(self):
134 |         for i in self.gpus_list:
135 |             self.generator[i] = TacticGenerator(self.model_list, i, self.local_model_path)
136 | 
137 |     def _process_run(self, gpu_id: int):
138 |         os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
139 |         this_generator = self.generator[gpu_id]
140 |         while True:
141 |             item = self.queue.get()
142 |             if item is None:  # stop signal
143 |                 break
144 |             self.counter.value -= 1
145 |             prove_detail_dict = {}
146 |             result_list = []  # collects matched pairs of formal_statement and formal_proof
147 | 
148 |             # profiler.start(f"run_time_{item['unique_key']}")
149 |             print(f"run_time_start_{item['unique_key']}: back_size = {len(item['back_trans_list'])}: gpu_id = {str(gpu_id)}",flush=True)
150 |             for inner_index, formal_statement in enumerate(item['back_trans_list']):
151 |                 try:
152 |                     temp_results = self.process_one(source=formal_statement, generator=this_generator)
153 |                     temp_result = self.parse_result(formal_statement, temp_results)
154 |                     # verify the found proof with the Lean REPL
155 |                     if 'formal_proof' in temp_result:
156 |                         try:
157 |                             repl_res = verify_proof(temp_result['formal_proof'],os.path.join(conf.config.LEAN_ENV_PATH,'lake'),conf.config.LEAN_TEST_PATH)
158 |                         except Exception as e:
159 |                             print(traceback.format_exc())
160 |                             repl_res = False
161 |                     else:
162 |                         repl_res = False
163 |                     prove_detail_dict[inner_index] = temp_result
164 | 
165 |                     if repl_res:  # the proof verified successfully
166 |                         result_list.append({
167 |                             'formal_statement': formal_statement,
168 |                             'formal_proof': temp_result['formal_proof']
169 |                         })
170 |                 except Exception as e:
171 |                     print(f"Error occurred: {e}")
172 |                     print(f"run_error_unique_key = {item['unique_key']}")
173 | 
174 |             item['prove_detail_dict'] = prove_detail_dict
175 |             item['result_list'] = result_list
176 |             CommonUtil.write_to_json_file(item['prove_path'], item)
177 | 
178 |             # profiler.stop(f"run_time_{item['unique_key']}")
179 |             print(f"run_time_finished_{item['unique_key']}")
180 | 
181 |     def run_pipeline_prover(self):
182 |         """
183 |         Launch the producer and one prover worker per GPU, then wait for completion.
184 |         """
185 |         process_list = []
186 |         self._check_add_task_to_queue()  # use this while Herald is still producing data
187 |         # self._add_exist_to_queue()  # use this when the Herald data is already complete
188 | 
189 |         self._int_generator()
190 | 
191 |         for i in self.gpus_list:
192 |             # Spawn one worker process per GPU
193 |             process = mp.Process(target=self._process_run, args=(i,))
194 |             process.start()
195 |             process_list.append(process)
196 | 
197 |         for j in range(len(process_list)):
198 |             process_list[j].join()
199 |         print("All data processed.")
200 | 
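# Illustrative driver (a sketch; the paths and GPU ids are hypothetical):
#
#   service = PipelineMainService(
#       source_file="data/example/simp_10.jsonl",
#       result_dir="outputs/pipeline_run",
#       gpus_list=[0, 1],
#   )
#   service.run_pipeline_prover()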
--------------------------------------------------------------------------------
/Herald/service/handler/tran_handler.py:
--------------------------------------------------------------------------------
  1 | from vllm import LLM, SamplingParams
  2 | import re
  3 | import json
  4 | import tempfile
  5 | import traceback
  6 | import subprocess
  7 | from typing import Optional
  8 | import concurrent.futures
  9 | 
 10 | import conf.config
 11 | from util import profiler, CommonUtil
 12 | 
 13 | 
 14 | class TranHandler(object):
 15 |     """
 16 |     Natural language => Lean code.
 17 |     Also compile-checks the generated Lean code.
 18 |     """
 19 | 
 20 |     def __init__(self, gpus=1):
 21 |         """
 22 |         gpus: tensor-parallel size used when loading the local model.
 23 |         """
 24 |         self.model = None
 25 |         self.name = 'textbook_exercise'
 26 |         self.model_id = 'Herald'
 27 |         self.lean_path = conf.config.LEAN_TEST_PATH
 28 | 
 29 |         self.gpus = gpus
 30 | 
 31 |         # self._init_model()
 32 | 
 33 |     def _init_model(self):
 34 |         if self.model is None:
 35 |             # self.model = LLM(model=conf.config.TRAN_MODEL_PATH, max_num_batched_tokens=8192, seed=1,
 36 |             #                  trust_remote_code=True)
 37 |             self.model = LLM(
 38 |                 model=conf.config.MODEL_CONFIG['trans'],
 39 |                 tensor_parallel_size=self.gpus,
 40 |                 trust_remote_code=True,
 41 |                 dtype='bfloat16',
 42 |                 # gpu_memory_utilization=0.9,
 43 |             )
 44 | 
 45 |     def release_model(self):
 46 |         self.model = None
 47 | 
 48 | 
 49 |     def generate_and_check(self, informal_statement):
 50 |         """
 51 |         Generate candidates, then validate them with a thread pool.
 52 |         """
 53 |         statement_list = self.generate(informal_statement)
 54 |         # return statement_list
 55 |         return self.batch_validate_item(statement_list)
 56 | 
 57 |     def generate(self, informal_statement):
 58 |         if self.model is None:
 59 |             self._init_model()
 60 | 
 61 |         prompt = self.get_query(informal_statement, self.name, self.model_id)
 62 |         output = self.model.generate(prompt, sampling_params=self._build_sampling_param(
 63 |             conf.config.TRAN_CONFIG['sampling_params']))
 64 |         outputs = output[0].outputs
 65 |         generated = [output.text for output in outputs]
 66 |         generated = self.process(generated, self.model_id)
 67 |         print('*** generated *** size = %s' % len(generated))
 68 |         # print(json.dumps({'generated': generated}, indent=4))
 69 |         return generated
 70 | 
 71 | 
 72 |     def batch_generate(self, data_list, sampling_params=None):
 73 |         """
 74 |         Batch-generate formal statements for a list of items.
 75 |         """
 76 |         if self.model is None:
 77 |             profiler.start('init_model')
 78 |             self._init_model()
 79 |             profiler.stop('init_model')
 80 |         prompt_list = [self.get_query(i['informal_statement'], self.name, self.model_id) for i in data_list]
 81 |         if not sampling_params:
 82 |             sampling_params = conf.config.TRAN_CONFIG['sampling_params']
 83 | 
 84 |         profiler.start(f"generate_{len(prompt_list)}")
 85 |         gen_result_list = self.model.generate(prompt_list, sampling_params=self._build_sampling_param(sampling_params))
 86 |         profiler.stop(f"generate_{len(prompt_list)}")
 87 |         CommonUtil.print(f"gen_result_list.size = {len(gen_result_list)}")
 88 |         for index, data in enumerate(data_list):
 89 |             data['generate_informal_statement_list'] = [self.process(output.text) for output in gen_result_list[index].outputs]
 90 |         return data_list
 91 | 
 92 |     def _build_sampling_param(self, sampling_params):
 93 |         """
 94 |         Build vLLM SamplingParams from a config dict.
 95 |         """
 96 |         return SamplingParams(
 97 |             n=sampling_params['n'],
 98 |             max_tokens=sampling_params['max_tokens'],
 99 |             temperature=sampling_params['temperature'],
100 |             top_p=sampling_params['top_p'],
101 |         )
102 | 
103 |     def get_query(self, informal_statement: str, name: str, model_id='Herald') -> str:
104 |         template = """Please translate the natural language statement to Lean4 code with the header. Do not generate any notations.
105 | **Name**
106 | {name}
107 | **Informal statement**
108 | {informal_statement}
109 | """
110 |         msgs = [
111 |             {'role': 'system', 'content': 'You are an expert at Lean 4 and Mathematics.'},
112 |             {'role': 'user', 'content': template.format(
113 |                 name=name,
114 |                 informal_statement=informal_statement)}
115 |         ]
116 |         if model_id == 'Herald':
117 |             return self.chat_template_to_prompt(msgs, 'deepseek')
118 |         elif model_id == 'InternLM':
119 |             return self.chat_template_to_prompt(msgs, 'internlm')
120 |         elif model_id == 'TheoremLlama':
121 |             return self.chat_template_to_prompt(msgs, 'thmllm')
122 |         else:
123 |             raise NotImplementedError
124 | 
125 |     def process(self, generated: str, model_id='Herald') -> str:
126 |         if model_id == 'Herald':
127 |             return generated
128 |         elif model_id in ['InternLM', 'TheoremLlama']:
129 |             new_output = re.sub(r'^\s*-- .*$', '', generated, flags=re.MULTILINE)
130 |             lean_code_pattern = r'```lean\n(.*?)(?:\n```|$)'  # Match to the end if not finished
131 |             matches = re.findall(lean_code_pattern, new_output, re.DOTALL)
132 |             new_output = '\n'.join(matches)
133 |             new_output = re.sub(r'\n+', '', new_output).strip()
134 |             new_output = re.sub(r'-+', '', new_output).strip()
135 |             new_output = re.sub(r':=.*', ':= sorry', new_output)
136 |             return new_output
137 |         else:
138 |             raise NotImplementedError
139 | 
140 |     def chat_template_to_prompt(self, prompt_list, model='default'):
141 |         result = ""
142 |         total_step = len(prompt_list)
143 |         for i, message in enumerate(prompt_list):
144 |             if model == 'internlm':
145 |                 result += ('<|im_start|>' + message['role'] +
146 |                            '\n' + message['content'])
147 |                 if i + 1 != total_step:
148 |                     result += '<|im_end|>\n'
149 |                 elif message['role'] == 'user':
150 |                     result += '<|im_end|>\n<|im_start|>assistant\n'
151 | 
152 |             elif model == 'deepseek':
153 |                 if message['role'] == 'user':
154 |                     result += 'User:' + message['content'] + '\n\n'
155 |                 elif message['role'] == 'assistant':
156 |                     result += 'Assistant:' + message['content'] + '<|end▁of▁sentence|>'
157 |                 elif message['role'] == 'system':
158 |                     result += message['content'] + '\n\n'
159 |                 if i + 1 == total_step and message['role'] == 'user':
160 |                     result += 'Assistant:'
161 | 
162 |             elif model == 'thmllm':
163 |                 result += ('<|start_header_id|>' + message['role'] + '<|end_header_id|>' +
164 |                            message['content'] + '<|eot_id|>')
165 |                 if i + 1 == total_step and message['role'] == 'user':
166 |                     result += '<|start_header_id|>assistant<|end_header_id|>'
167 |             else:
168 |                 raise NotImplementedError
169 |         return result
170 | 
171 | 
172 |     def validate(self, code_string, header='', timeout=120) -> tuple[Optional[bool], str]:
173 |         validation = True
174 |         try:
175 |             result = self.validate_one_lean_codestring(code_string, header, timeout)
176 |             result_json = json.loads(result)
177 |             if result_json.get("messages"):
178 |                 for msg in result_json.get("messages"):
179 |                     if msg.get("severity") == "error":
180 |                         validation = False
181 |         except Exception as e:
182 |             print(e)
183 |             validation, result = False, str(e)
184 |         return validation, result
185 | 
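    # For reference, validate() drives one JSON-per-line REPL round trip (a sketch;
    # field names beyond "cmd", "messages" and "severity" are illustrative):
    #   stdin : {"cmd": "<header>\n<code_string>"}
    #   stdout: {"messages": [{"severity": "error", "data": "unknown identifier ..."}]}
    # Any message with severity == "error" marks the candidate as invalid.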
186 |     def validate_one_lean_codestring(self, code_string, header, timeout=300):
187 |         command = dict(cmd=header + '\n' + code_string)
188 |         print('command = %s' % command)
189 |         message_str = json.dumps(command, ensure_ascii=False)
190 |         lean_path = self.lean_path
191 |         try:
192 |             with tempfile.TemporaryFile(mode='w+', encoding='utf-8') as temp_file:
193 |                 temp_file.write(message_str + "\r\n\r\n")
194 |                 temp_file.seek(0)
195 |                 outputs = subprocess.run(
196 |                     [conf.config.DEFAULT_LAKE_PATH, "exe", 'repl'],
197 |                     stdin=temp_file,
198 |                     capture_output=True,
199 |                     text=True,
200 |                     cwd=lean_path,
201 |                     timeout=timeout,
202 |                     encoding='utf-8'
203 |                 )
204 |         except Exception as _:
205 |             CommonUtil.print("validate_one_lean_code error.....")
206 |             error_info = traceback.format_exc()
207 |             print(error_info)
208 |         else:
209 |             return outputs.stdout
210 | 
211 |     def batch_validate_item(self, statement_list):
212 |         def validate_item(item):
213 |             validation, result = self.validate(item)
214 |             CommonUtil.print('validation = %s' % validation)
215 |             # print('result = %s' % result)
216 |             return item if validation else None
217 | 
218 |         max_workers = min(len(statement_list), conf.config.THREAD_CONFIG['lean_build'])
219 |         CommonUtil.print('validate_max_workers = %s' % max_workers)
220 |         with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
221 |             results = list(executor.map(validate_item, statement_list))
222 | 
223 |         # Keep only the candidates that validated
224 |         return [item for item in results if item is not None]
225 | 
--------------------------------------------------------------------------------
/Realprover/manager/service/batch_main_service.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | from pathlib import Path
 3 | import multiprocessing as mp
 4 | import traceback
 5 | from manager.thirdparty import TacticGenerator
 6 | from manager.service import BaseService
 7 | from util import CommonUtil, profiler
 8 | import conf.config
 9 | from manager.search.exception import SearchError, error_logging
10 | import logging
11 | from manager.thirdparty.verifier import verify_proof
12 | import json
13 | 
14 | def search_check(path, max_retries):
15 |     cnt = 0  # counts completed attempts that failed without a connection error
16 |     for root, dirs, files in os.walk(path):
17 |         for file in files:
18 |             connect_error = False
19 |             if '.json' in file and os.path.samefile(path,root):
20 |                 file_path = os.path.join(root, file)
21 |                 with open(file_path,'r') as f:
22 |                     try:
23 |                         file_json = json.load(f)
24 |                         for res in file_json["collect_results"]:
25 |                             if res['success']:
26 |                                 return True
27 |                             if len(res['calls'])==0:
28 |                                 connect_error=True
29 |                     except:
30 |                         return False
31 |                 if not connect_error:
32 |                     cnt += 1
33 |     if cnt == max_retries:
34 |         return True
35 |     else:
36 |         return False
37 | 
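# Illustrative semantics of search_check (hypothetical paths): with results stored
# as outputs/generated/42/42_0.json, 42_1.json, ..., it returns True as soon as one
# attempt recorded "success": true, or once max_retries attempts failed without a
# connection error (an empty "calls" list), so that item is not scheduled again.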
 38 | class BatchMainService(BaseService):
 39 |     """
 40 |     Runs the prover over a batch of statements across multiple processes.
 41 |     """
 42 | 
 43 |     def __init__(self,
 44 |                  source_file: str,
 45 |                  result_dir: str,
 46 |                  gpus_list=None,
 47 |                  num_samples: int = conf.config.NUM_SAMPLES,
 48 |                  max_nodes: int = conf.config.MAX_NODES,
 49 |                  max_depth: int = conf.config.MAX_DEPTH,
 50 |                  use_beam_search: bool = conf.config.USE_BEAM_SEARCH,
 51 |                  use_mcts_search: bool = conf.config.USE_MCTS_SEARCH,
 52 |                  simulation_depth: int = conf.config.SIM_DEPTH,
 53 |                  c_puct: float = conf.config.C_PUCT,
 54 |                  beam_width: int = conf.config.BEAM_WIDTH,
 55 |                  root: Path = Path(conf.config.LEAN_TEST_PATH),
 56 |                  lean_env: str = conf.config.LEAN_ENV_PATH,
 57 |                  model_list: list = conf.config.DEFAULT_MODEL_LIST,
 58 |                  local_model_path: str = conf.config.PROVER_MODEL_PATH,
 59 |                  sampling_params: dict = conf.config.PROVER_MODEL_PARAMS,
 60 |                  max_retries: int = conf.config.MAX_RETRIES,
 61 |                  c_score: float = conf.config.C_SCORE,
 62 |                  c_expansion_fail_penalty: float = conf.config.C_EXPANSION_FAIL_PENALTY,
 63 |                  max_root_expansion: int = conf.config.MAX_ROOT_EXPANSION,
 64 |                  max_calls: int = conf.config.MAX_CALLS,
 65 |                  abandon_if_contain: list[str] = conf.config.ABANDON_IF_CONTAIN,
 66 |                  is_incontext: bool = conf.config.IS_INCONTEXT,
 67 |                  template: str = 'deepseek',
 68 |                  use_retrieval: bool = conf.config.USE_RETRIEVAL
 69 |                  ):
 70 |         """
 71 |         See BaseService for the meaning of the search parameters.
 72 |         """
 73 |         # Parameters are passed through explicitly because callers may supply their own configuration
 74 |         super().__init__(num_samples=num_samples,
 75 |                          max_nodes=max_nodes,
 76 |                          max_depth=max_depth,
 77 |                          use_beam_search=use_beam_search,
 78 |                          use_mcts_search=use_mcts_search,
 79 |                          simulation_depth=simulation_depth,
 80 |                          c_puct=c_puct,
 81 |                          beam_width=beam_width,
 82 |                          root=root,
 83 |                          lean_env=lean_env,
 84 |                          model_list=model_list,
 85 |                          local_model_path=local_model_path,
 86 |                          sampling_params=sampling_params,
 87 |                          max_calls=max_calls,
 88 |                          max_root_expansion=max_root_expansion,
 89 |                          c_score=c_score,
 90 |                          c_expansion_fail_penalty=c_expansion_fail_penalty,
 91 |                          abandon_if_contain = abandon_if_contain,
 92 |                          is_incontext = is_incontext,
 93 |                          template=template,
 94 |                          use_retrieval=use_retrieval)
 95 |         self.source_file = source_file
 96 |         self.result_dir = result_dir
 97 |         self.max_retries = max_retries
 98 |         if not os.path.exists(self.result_dir+'/generated'):
 99 |             os.makedirs(self.result_dir+'/generated')
100 |         if not os.path.exists(self.result_dir+'/error'):
101 |             os.makedirs(self.result_dir+'/error')
102 |         logging.basicConfig(
103 |             filename=f"{self.result_dir}/error.log",  # File where logs will be saved
104 |             level=logging.ERROR,  # Log level
105 |             format='%(asctime)s - %(levelname)s - %(message)s',
106 |             datefmt='%Y-%m-%d %H:%M:%S'
107 |         )
108 |         if gpus_list is None:
109 |             gpus_list = [0]
110 |         self.gpus_list = gpus_list
111 | 
112 |         self.manager = mp.Manager()
113 |         self.generator = self.manager.dict()
114 |         self.queue = mp.Queue()
115 | 
116 |         self.source_list = []
117 | 
118 | 
119 | 
120 |     def _init_source_list(self):
121 |         data_list = CommonUtil.read_json_list(self.source_file)
122 |         print(f"data_list = {len(data_list)}")
123 |         for source_index, item in enumerate(data_list):
124 |             # record the item's index in the original source file
125 |             item['source_index'] = source_index
126 |         self.source_list = data_list
127 | 
128 |     def _int_generator(self):
129 |         for i in self.gpus_list:
130 |             self.generator[i] = TacticGenerator(
131 |                 model_list=self.model_list,
132 |                 gpu_id=i,
133 |                 local_model_path=self.local_model_path,
134 |                 sampling_params=self.sampling_params,
135 |                 max_calls=self.max_calls)
136 | 
137 |     def _int_queue(self):
138 |         for item in self.source_list:
139 |             self.queue.put(item)
140 |         for _ in self.gpus_list:
141 |             self.queue.put(None)
142 | 
143 |     def _process_run(self, gpu_id: int):
144 |         os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
145 |         this_generator = self.generator[gpu_id]
146 |         while True:
147 |             item = self.queue.get()
148 |             if item is None:  # stop signal
149 |                 break
150 |             item_path = os.path.join(self.result_dir,'generated',f"{item['id']}")
151 |             if not os.path.exists(item_path):
152 |                 os.makedirs(item_path)
153 |             if search_check(item_path, self.max_retries):
154 |                 continue
155 |             for idx in range(self.max_retries):
156 |                 try:
157 |                     print(f"processing {item['id']}")
158 |                     profiler.start(f"run_index_{item['id']}")
159 |                     results = self.process_one(source=item["formal_statement"], generator=this_generator)
160 |                     profiler.stop(f"run_index_{item['id']}")
161 |                 except SearchError as e:
162 |                     error_log_path = f"{self.result_dir+'/error'}/{item['id']}.json"
163 |                     error_logging(error_log_path, item["id"], item["formal_statement"], e.error_data)
164 |                     # CommonUtil.write_to_json_file(error_log_path, e.error_data)
165 |                     logging.error([e.error_type,item['id'], "Searching Error"] ,exc_info=True)
166 |                 except Exception as e:
167 |                     print(traceback.format_exc())
168 |                     print(f"Error occurred: {e}")
169 |                     print(f"run_error_index = {item['id']}")
170 |                     logging.error([e,item['id'], "Non-Searching Error"] ,exc_info=True)
171 |                 else:
172 |                     save_data = self.parse_result(item["formal_statement"], results)
173 |                     if 'formal_proof' in save_data:
174 |                         try:
175 |                             print(f"lake repl check:{item['id']}",flush=True)
176 |                             repl_res = verify_proof(save_data['formal_proof'],os.path.join(conf.config.LEAN_ENV_PATH,'lake'),conf.config.LEAN_TEST_PATH)
177 |                         except Exception as e:
178 |                             print(traceback.format_exc())
179 |                             save_data['collect_results'][0]['success'] = False
180 |                         else:
181 |                             save_data['collect_results'][0]['success'] = repl_res
182 |                     result_file = os.path.join(item_path,f"{item['id']}_{idx}.json")
183 |                     CommonUtil.write_to_json_file(result_file, save_data)
184 |                     print(f"finish_index_{item['id']}",flush=True)
185 |                     if save_data['collect_results'] and save_data['collect_results'][0]['success']:
186 |                         break
187 |                     else:
188 |                         continue
189 |         self.info.update(self.get_info())
190 |         print(self.info)
191 | 
192 |     def batch_run(self):
193 |         """
194 |         Load the source data, then run one prover worker per GPU until the queue is drained.
195 |         """
196 |         process_list = []
197 |         self._init_source_list()
198 |         self._int_generator()
199 |         self._int_queue()
200 |         for i in self.gpus_list:
201 |             # Spawn one worker process per GPU
202 |             process = mp.Process(target=self._process_run, args=(i,))
203 |             process.start()
204 |             process_list.append(process)
205 | 
206 |         for j in range(len(process_list)):
207 |             process_list[j].join()
208 |         print("All data processed.")
209 | 
210 |     def get_info(self):
211 |         return dict(
212 |             max_retries=self.max_retries,
213 |             source_file=self.source_file,
214 |             result_dir=self.result_dir
215 |         )
--------------------------------------------------------------------------------
/LeanSearch-PS/build_training_data.py:
--------------------------------------------------------------------------------
 1 | ### build training data for embedding
 2 | # The format of the training data is a json file as follows:
 3 | # {
 4 | #     "query_id": "",
 5 | #     "query": "",
 6 | #     "positive_passages": [
 7 | #         {"docid": "", "title": "", "text": ""},
 8 | #         ...
 9 | #     ],
10 | #     "negative_passages": [
11 | #         {"docid": "", "title": "", "text": ""},
12 | #         ...
13 | #     ]
14 | # }
15 | #
16 | # Process for building:
17 | # 1. build query data and corpus data
18 | # 2. embed query data and corpus data
19 | # 3. search corpus embedding within query embedding
20 | # 4. build training data
21 | 
22 | 
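# Example of a single training record in the format above (values are illustrative):
#
#   {"query_id": "17",
#    "query": "The sum of two even numbers is even.",
#    "positive_passages": [{"docid": "17", "title": "Even.add",
#                           "text": "Formal statement of Even.add"}],
#    "negative_passages": [{"docid": "903", "title": "Odd.add_even",
#                           "text": "Formal statement of Odd.add_even"}]}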
23 | import argparse
24 | import json
25 | import os
26 | import random
27 | import glob
28 | import pickle
29 | from tqdm import tqdm
30 | import subprocess
31 | import numpy as np
32 | 
33 | def build_query_and_corpus_data(args):
34 |     print('Building query and corpus data')
35 |     # Step 1: build query data and corpus data
36 |     if os.path.exists(os.path.join(args.save_dir, 'query_data/')) and os.path.exists(os.path.join(args.save_dir, 'corpus_data/')):
37 |         print("query and corpus data already exist")
38 |         return
39 |     os.makedirs(os.path.join(args.save_dir, 'query_data/'))
40 |     os.makedirs(os.path.join(args.save_dir, 'corpus_data/'))
41 |     with open('./answers.json', 'r') as f:
42 |         answers = json.load(f)
43 |     for key, value in tqdm(answers.items()):
44 |         corpus_data, query_data = {}, {}
45 |         query_data['query_id'] = key
46 |         query_data['query'] = value['Informal statement']
47 |         corpus_data['docid'] = key
48 |         corpus_data['title'] = ""
49 |         corpus_data['text'] = f"Formal name: {value['Formal name']} Formal statement: {value['Formal statement']}"
50 |         # Save the query data to a file
51 |         with open(os.path.join(args.save_dir, 'query_data/query_data.json'), 'a') as f:
52 |             json.dump(query_data, f, ensure_ascii=False)
53 |             f.write('\n')
54 |         # Save the corpus data to a file
55 |         with open(os.path.join(args.save_dir, 'corpus_data/corpus_data.json'), 'a') as f:
56 |             json.dump(corpus_data, f, ensure_ascii=False)
57 |             f.write('\n')
58 | 
59 | def embed_query_and_corpus_data(args):
60 |     print('Embedding query and corpus data')
61 |     # Step 2: embed query data and corpus data
62 |     if not os.path.exists(os.path.join(args.save_dir, 'query_embedding/')):
63 |         os.makedirs(os.path.join(args.save_dir, 'query_embedding/'))
64 |     if not os.path.exists(os.path.join(args.save_dir, 'corpus_embedding/')):
65 |         os.makedirs(os.path.join(args.save_dir, 'corpus_embedding/'))
66 |     # query embedding
67 |     print("----embed query data")
68 |     processes = []
69 |     for s in args.gpu_ids:
70 |         env = os.environ.copy()
71 |         env['CUDA_VISIBLE_DEVICES'] = str(s)
72 |         command = [
73 |             'python', '-m', 'tevatron.retriever.driver.encode',
74 |             '--output_dir=temp',
75 |             '--model_name_or_path', args.model_name_or_path,
76 |             '--lora', '--lora_name_or_path', args.lora_name_or_path,
77 |             '--query_prefix', 'Given a math statement, retrieve Lean4 code mathematically equivalent to it: ',
78 |             '--passage_prefix', 'Lean4 code: ',
79 |             '--bf16',
80 |             '--pooling', 'eos',
81 |             '--append_eos_token',
82 |             '--normalize',
83 |             '--encode_is_query',
84 |             '--per_device_eval_batch_size', '128',
85 |             '--query_max_len', '128',
86 |             '--passage_max_len', '256',
87 |             '--dataset_name', os.path.join(args.save_dir, 'query_data/'),
88 |             '--dataset_number_of_shards', str(len(args.gpu_ids)),
89 |             '--dataset_shard_index', str(s),
90 |             '--encode_output_path', os.path.join(args.save_dir, 'query_embedding/query.{}.pkl'.format(s)),
91 |         ]
92 |         # if args.lora:
93 |         #     command = command[:6] + ['--lora', '--lora_name_or_path', args.lora_name_or_path] + command[6:]
94 |         process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, env=env)
95 |         processes.append(process)
96 |     for p in processes:
97 |         p.wait()
98 |     # corpus embedding
99 |     print("----embed corpus data")
100 |     processes = []
101 |     for s in args.gpu_ids:
102 |         env = os.environ.copy()
103 |         env['CUDA_VISIBLE_DEVICES'] = str(s)
104 |         command = [
105 |             'python', '-m', 'tevatron.retriever.driver.encode',
106 |             '--output_dir=temp',
107 |             '--model_name_or_path', args.model_name_or_path,
108 |             '--lora', '--lora_name_or_path', args.lora_name_or_path,
109 |             '--query_prefix', 'Given a math statement, retrieve Lean4 code mathematically equivalent to it: ',
110 |             '--passage_prefix', 'Lean4 code: ',
111 |             '--bf16',
112 |             '--pooling', 'eos',
113 |             '--append_eos_token',
114 |             '--normalize',
115 |             '--per_device_eval_batch_size', '128',
116 |             '--query_max_len', '128',
117 |             '--passage_max_len', '256',
118 |             '--dataset_name', os.path.join(args.save_dir, 'corpus_data/'),
119 |             '--dataset_number_of_shards', str(len(args.gpu_ids)),
120 |             '--dataset_shard_index', str(s),
121 |             '--encode_output_path', os.path.join(args.save_dir, 'corpus_embedding/corpus.{}.pkl'.format(s)),
122 |         ]
123 |         # if args.lora:
124 |         #     command = command[:6] + ['--lora', '--lora_name_or_path', args.lora_name_or_path] + command[6:]
125 |         process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, env=env)
126 |         processes.append(process)
127 |     for p in processes:
128 |         p.wait()
129 | 
130 | def search(args):
131 |     print('Searching corpus embedding within query embedding')
132 |     # Step 3: search corpus embedding within query embedding
133 |     index_files = glob.glob(os.path.join(args.save_dir, 'query_embedding/query.*.pkl'))
134 |     reps_list, lookup_list = [], []
135 |     for path in index_files:
136 |         with open(path, 'rb') as f:
137 |             reps, lookup = pickle.load(f)
138 |             print(reps.shape, type(lookup))
139 |             reps_list.append(reps)
140 |             lookup_list += lookup
141 |     reps_array = np.vstack(reps_list)
142 |     print(reps_array.shape)
143 |     with open(os.path.join(args.save_dir, 'query_embedding/query.pkl'), 'wb') as f:
144 |         pickle.dump((reps_array, lookup_list), f)
145 |     command = [
146 |         'set', '-f', '&&', 'python', '-m', 'tevatron.retriever.driver.search',
147 |         '--query_reps', os.path.join(args.save_dir, 'query_embedding/query.pkl'),
148 |         '--passage_reps', os.path.join(args.save_dir, 'corpus_embedding/corpus.*.pkl'),
149 |         '--depth', str(args.top_k),
150 |         '--batch_size', '64',
151 |         '--save_text',
152 |         '--save_ranking_to', os.path.join(args.save_dir, 'retrieval.txt'),
153 |     ]
154 |     subprocess.run(command[3:])  # the 'set -f &&' prefix is sliced off; only the python invocation runs
155 | 
156 | def build_training_data(args):
157 |     print('Building training data')
158 |     # Step 4: build training data
159 |     with open('./answers.json', 'r') as f:
160 |         answers = json.load(f)
161 |     training_dataset = {}
162 |     with open(os.path.join(args.save_dir, 'retrieval.txt'), 'r', encoding='utf-8') as f:
163 |         content = f.read().strip()
164 |     for line in tqdm(content.split(sep='\n')):
165 |         parts = line.split(sep='\t')
166 |         query_id, docid = int(parts[0]), parts[1]
167 |         try:
168 |             training_dataset[query_id].append(docid)
169 |         except:
170 |             training_dataset[query_id] = [docid]
171 |     # sort the training_dataset dict by query_id
172 |     training_dataset = sorted(training_dataset.items(), key=lambda x: x[0])
173 |     with open(os.path.join(args.save_dir, 'training_data.json'), 'w', encoding='utf-8') as f:
174 |         random.seed(args.seed)
175 |         for query_id, data in tqdm(training_dataset):
176 |             if args.is_choice:
177 |                 docids = [data[random.randrange(args.bottom_k, args.top_k)]]
178 |             else:
179 |                 docids = data[args.bottom_k:args.top_k]
180 |             train_data = {
181 |                 "query_id": query_id,
182 |                 "query": answers[str(query_id)]['Informal statement'],
183 |                 "positive_passages": [{
184 |                     "doc_id": query_id,
185 |                     "title": answers[str(query_id)]['Formal name'],
186 |                     "text": answers[str(query_id)]['Formal statement']
187 |                 }],
188 |                 "negative_passages": [{
"doc_id": int(docid), 190 | "title": answers[docid]['Formal name'], 191 | "text": answers[docid]['Formal statement'] 192 | } for docid in docids], 193 | } 194 | json.dump(train_data, f, ensure_ascii=False) 195 | f.write('\n') 196 | 197 | if __name__ == "__main__": 198 | # build parser 199 | parser = argparse.ArgumentParser() 200 | parser.add_argument('--model_name_or_path', type=str, default='./models/model_name_or_path') 201 | parser.add_argument('--lora', action='store_true') 202 | parser.add_argument('--lora_name_or_path', type=str, default='./checkpoints/lora_name_or_path') 203 | parser.add_argument('--top_k', type=int, default=100) 204 | parser.add_argument('--bottom_k', type=int, default=30) 205 | parser.add_argument('--save_dir', type=str, default='./datasets/training_data') 206 | parser.add_argument('--gpu_ids', type=list, default=[0,1,2,3]) 207 | parser.add_argument('--is_choice', action='store_true') 208 | parser.add_argument('--seed', type=int, default=42) 209 | args = parser.parse_args() 210 | 211 | if not os.path.exists(args.save_dir): 212 | os.makedirs(args.save_dir) 213 | # Step 1: build query data and corpus data 214 | build_query_and_corpus_data(args) 215 | # Step 2: embed query data and corpus data 216 | embed_query_and_corpus_data(args) 217 | # Step 3: search corpus embedding within query embedding 218 | search(args) 219 | # Step 4: build training data 220 | build_training_data(args) 221 | -------------------------------------------------------------------------------- /Realprover/manager/manage/proof_parse_manage.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import json 3 | from pprint import pp 4 | from matplotlib import pyplot as plt 5 | import networkx as nx 6 | from conf import config 7 | 8 | from manager.struct import Goal 9 | import os 10 | 11 | class ProofParseManage(object): 12 | """ 13 | 此处代码从源代码 stats.py 文件中移植 14 | """ 15 | 16 | @staticmethod 17 | def get_correct_proof(data: dict) -> str: 18 | proof = data['formal_statement'] 19 | for p in data['collect_results']: 20 | G, path = ProofParseManage.get_proof_tree(p) 21 | edge_labels = nx.get_edge_attributes(G, "label") 22 | proof = proof.replace("sorry", "").replace("by", "").strip() + " by" 23 | for i in range(len(path) - 1): 24 | tac = edge_labels[(path[i], path[i + 1])] 25 | proof += "\n " + tac 26 | return proof 27 | 28 | @staticmethod 29 | def get_proof_tree(data: dict) -> tuple[nx.DiGraph, list[int]]: 30 | # Create a directed graph 31 | nodes = data["nodes"] 32 | G = nx.DiGraph() 33 | correct_path = [] 34 | 35 | # Add nodes and edges 36 | for node in nodes: 37 | G.add_node( 38 | node["id"], 39 | parent=node["parent"], 40 | id=node["id"], 41 | tactic=node["tactic"], 42 | state=node["state"] 43 | ) 44 | if node["id"] != node["parent"]: 45 | G.add_edge(node["parent"], node["id"], label=node["tactic"]) 46 | # Check if this step ends with an empty goal 47 | if not node["state"]: 48 | correct_path.append(node["id"]) 49 | 50 | # Backtrack the correct path 51 | path = [] 52 | if correct_path: 53 | # Assume the first empty goal is the correct one 54 | node = correct_path[0] 55 | while node in G: 56 | path.append(node) 57 | # Stop if we reach the root 58 | if node == G.nodes[node]["parent"]: 59 | break 60 | node = G.nodes[node]["parent"] 61 | path = path[::-1] # Reverse the path to start from the root 62 | return G, path 63 | 64 | @staticmethod 65 | def pp_state(state: list[Goal]) -> str: 66 | return "\n\n".join([goal.pretty for goal in state]) 
--------------------------------------------------------------------------------
/Realprover/manager/manage/proof_parse_manage.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import json
3 | from pprint import pp
4 | from matplotlib import pyplot as plt
5 | import networkx as nx
6 | from conf import config
7 | 
8 | from manager.struct import Goal
9 | import os
10 | 
11 | class ProofParseManage(object):
12 |     """
13 |     Code ported from the original stats.py file.
14 |     """
15 | 
16 |     @staticmethod
17 |     def get_correct_proof(data: dict) -> str:
18 |         proof = data['formal_statement']
19 |         for p in data['collect_results']:
20 |             G, path = ProofParseManage.get_proof_tree(p)
21 |             edge_labels = nx.get_edge_attributes(G, "label")
22 |             proof = proof.replace("sorry", "").replace("by", "").strip() + " by"
23 |             for i in range(len(path) - 1):
24 |                 tac = edge_labels[(path[i], path[i + 1])]
25 |                 proof += "\n " + tac
26 |         return proof
27 | 
28 |     @staticmethod
29 |     def get_proof_tree(data: dict) -> tuple[nx.DiGraph, list[int]]:
30 |         # Create a directed graph
31 |         nodes = data["nodes"]
32 |         G = nx.DiGraph()
33 |         correct_path = []
34 | 
35 |         # Add nodes and edges
36 |         for node in nodes:
37 |             G.add_node(
38 |                 node["id"],
39 |                 parent=node["parent"],
40 |                 id=node["id"],
41 |                 tactic=node["tactic"],
42 |                 state=node["state"]
43 |             )
44 |             if node["id"] != node["parent"]:
45 |                 G.add_edge(node["parent"], node["id"], label=node["tactic"])
46 |             # Check if this step ends with an empty goal
47 |             if not node["state"]:
48 |                 correct_path.append(node["id"])
49 | 
50 |         # Backtrack the correct path
51 |         path = []
52 |         if correct_path:
53 |             # Assume the first empty goal is the correct one
54 |             node = correct_path[0]
55 |             while node in G:
56 |                 path.append(node)
57 |                 # Stop if we reach the root
58 |                 if node == G.nodes[node]["parent"]:
59 |                     break
60 |                 node = G.nodes[node]["parent"]
61 |             path = path[::-1]  # Reverse the path to start from the root
62 |         return G, path
63 | 
64 |     @staticmethod
65 |     def pp_state(state: list[Goal]) -> str:
66 |         return "\n\n".join([goal.pretty for goal in state])
67 | 
68 |     @staticmethod
69 |     def collect_error(nid: int, decl: str, state: str, tactic: str, error: str, result_dir: Path) -> None:
70 |         with open(result_dir, 'a') as fp:
71 |             fp.write(
72 |                 f'# Error on node {nid}\n\n## Informal statement:\n{decl}\n\n## State:\n{state}\n\n## Tactic:\n{tactic}\n\n## Detail:\n\n{error}\n\n'
73 |             )
74 | 
75 |     @staticmethod
76 |     def concat_proof(d: dict) -> str:
77 |         nodes = d['nodes'][1:]  # ignore the first node
78 |         concatenated = "\n".join([" " + n['tactic'] for n in nodes])
79 |         return concatenated
80 | 
81 |     @staticmethod
82 |     def pretty(data: list[dict], result_dir: Path) -> None:
83 |         output_dir = Path("outputs") / (result_dir.name)
84 |         output_dir.mkdir(parents=True, exist_ok=True)
85 |         for d in data:
86 |             formal_statement = d['formal_proof']
87 |             informal_statement = d['informal_stmt']
88 |             result_file = result_dir / f"{d['answer_id']}.json"
89 |             with open(result_file) as fp:
90 |                 result_dict = json.load(fp)
91 |             if not result_dict:
92 |                 continue
93 |             if not result_dict[0]['success']:
94 |                 continue
95 |             formal_proof = ProofParseManage.concat_proof(result_dict[0])
96 | 
97 |             with open(output_dir / f"{d['answer_id']}.lean", "w") as fp:
98 |                 fp.write(f"/- {informal_statement} -/\n")
99 |                 fp.write(formal_statement + formal_proof)
100 |             # print(formal_statement + formal_proof)
101 | 
102 |     @staticmethod
103 |     def visualize_proof_tree(G: nx.DiGraph, path: list[int], output: Path = Path("figs/test_plot.jpg")) -> None:
104 |         # Draw the graph
105 |         pos = nx.drawing.nx_agraph.graphviz_layout(
106 |             G, prog="dot")  # Layout for visualization
107 |         colors = ["red" if node in path else "blue" for node in G.nodes]
108 |         labels = {node: f"{G.nodes[node]['id']}" for node in G.nodes}
109 |         _, ax = plt.subplots(figsize=(10, 8))
110 |         nx.draw(
111 |             G,
112 |             pos,
113 |             ax,
114 |             with_labels=True,
115 |             labels=labels,
116 |             node_color=colors,
117 |             node_size=500,
118 |             font_size=10,
119 |             font_color="white",
120 |             edge_color="gray",
121 |         )
122 |         edge_labels = nx.get_edge_attributes(G, "label")
123 |         nx.draw_networkx_edge_labels(
124 |             G,
125 |             pos,
126 |             edge_labels=edge_labels,
127 |             font_size=8,
128 |             rotate=False
129 |         )
130 |         plt.title("Proof Search Tree")
131 |         plt.savefig(output)  # save before show(): show() may clear the current figure
132 |         plt.show()
133 | 
134 |     @staticmethod
135 |     def get_prompt(data, node):
136 |         for i, n in enumerate(data[-1]["nodes"]):
137 |             if n["id"] == node:
138 |                 pp(data[-1]['calls'][i][2])
139 |                 return data[-1]["calls"][i][2]
140 | 
141 |     @staticmethod
142 |     def get_one_kto_data(data: dict):
143 |         G, path = ProofParseManage.get_proof_tree(data)
144 |         correct_edges = [(path[i], path[i + 1]) for i in range(len(path) - 1)]
145 |         results = []
146 |         for edge in G.edges:
147 |             node = G.nodes[edge[1]]
148 |             try:
149 |                 prompt = ProofParseManage.get_prompt(data, node["id"])
150 |             except IndexError:
151 |                 prompt = ""
152 |             d = {
153 |                 "id": node["id"],
154 |                 "parent": node["parent"],
155 |                 "tactic": node["tactic"],
156 |                 "state": node["state"],
157 |                 "prompt": prompt
158 |             }
159 |             if edge in correct_edges:
160 |                 d.update({"label": True})
161 |             else:
162 |                 d.update({"label": False})
163 |             results.append(d)
164 |         return results
165 | 
166 |     @staticmethod
167 |     def get_kto_data(result_dir: Path):
168 |         kto_data = []
169 |         for file in result_dir.iterdir():
170 |             with file.open() as fp:
171 |                 data = json.load(fp)
172 |             try:
173 |                 one_kto_data = ProofParseManage.get_one_kto_data(data)
174 |             except Exception as e:
175 |                 one_kto_data = []
176 |                 print(e)
177 |             kto_data.extend(one_kto_data)
178 |         print(len(kto_data))
open(Path("kto_data") / (result_dir.name + ".json"), 'w') as fp: 180 | json.dump(kto_data, fp, ensure_ascii=False, indent=2) 181 | 182 | @staticmethod 183 | def get_stats(result_dir: str, info: dict = {}): 184 | results = dict(total=0, success=0, accuracy=0.0) 185 | generated_dir = os.path.join(result_dir, 'generated') 186 | for dir_path in os.listdir(generated_dir): 187 | success_flag = False 188 | full_path = os.path.join(generated_dir,dir_path) 189 | for root, dir, files in os.walk(full_path): 190 | for file in files: 191 | if '.json' in file and os.path.samefile(full_path,root): 192 | file_path = os.path.join(root, file) 193 | with open(file_path,'r') as f: 194 | data = json.load(f) 195 | for problem in data['collect_results']: 196 | if problem["success"]: 197 | results['success'] += 1 198 | success_flag = True 199 | break 200 | if success_flag: 201 | break 202 | results['total'] += 1 203 | results['accuracy'] = results["success"] / results['total'] 204 | results.update(info) 205 | with open(Path(result_dir, 'result.txt'), 'w', encoding='utf-8') as fp: 206 | for k, v in results.items(): 207 | fp.write(f'{k}: {v}\n') 208 | pp(results) 209 | 210 | @staticmethod 211 | def visualize_all_proof_trees(result_dir: str, keep_false: bool = True): 212 | output_dir = Path(result_dir, 'figs') 213 | output_dir.mkdir(exist_ok=True) 214 | for file in Path(result_dir, 'generated').iterdir(): 215 | with file.open() as fp: 216 | data = json.load(fp) 217 | if not data['collect_results']: 218 | continue 219 | if (not keep_false) and (not data['collect_results'][-1]['success']): 220 | continue 221 | for p in data['collect_results']: 222 | G, path = ProofParseManage.get_proof_tree(p) 223 | output_path = output_dir / file.name.replace('.json', '.png') 224 | ProofParseManage.visualize_proof_tree(G, path, output_path) 225 | 226 | @staticmethod 227 | def get_all_correct_proofs(result_dir: str) -> None: 228 | output_dir = Path(result_dir, 'proofs') 229 | output_dir.mkdir(exist_ok=True) 230 | for file in Path(result_dir, 'generated').iterdir(): 231 | with file.open() as fp: 232 | data = json.load(fp) 233 | if not data['collect_results'] or not data['collect_results'][-1]['success']: 234 | continue 235 | proof = ProofParseManage.get_correct_proof(data) 236 | with open(Path(output_dir, file.name.replace('.json', '.lean')), 'w') as fp: 237 | fp.write(proof) 238 | 239 | @staticmethod 240 | def get_demo_data(result_dir: str) -> None: 241 | output_dir = Path(result_dir, 'demos') 242 | output_dir.mkdir(exist_ok=True) 243 | for file in Path(result_dir, 'generated').iterdir(): 244 | with file.open() as fp: 245 | data = json.load(fp) 246 | if 'formal_proof' not in data.keys(): 247 | continue 248 | data['_formal_proof'] = data.pop('formal_proof') 249 | data['informal_statement'] = '' 250 | if not data['collect_results']: 251 | continue 252 | if not data['collect_results'][-1]['success']: 253 | continue 254 | data['nodes'] = data['collect_results'][-1]['nodes'] 255 | _, path = ProofParseManage.get_proof_tree(data) 256 | data.pop('collect_results') 257 | for n in data['nodes']: 258 | n['informal_tactic'] = '' 259 | n['short_informal_tactic'] = '' 260 | n['informal_state'] = [''] * len(n['state']) 261 | n['in_right_path'] = n['id'] in path 262 | with open(output_dir / file.name, 'w') as fp: 263 | json.dump(data, fp, ensure_ascii=False, indent=4) 264 | 265 | 266 | @staticmethod 267 | def get_length_distribution(result_dir: str): 268 | from collections import Counter 269 | len_list = [] 270 | generated_dir = Path(result_dir, 
"generated") 271 | for dir_path in os.listdir(generated_dir): 272 | success_flag = False 273 | full_path = os.path.join(generated_dir,dir_path) 274 | for root, dir, files in os.walk(full_path): 275 | for file in files: 276 | if '.json' in file and os.path.samefile(full_path,root): 277 | file_path = os.path.join(root, file) 278 | with open(file_path,'r') as f: 279 | data = json.load(f) 280 | if data['collect_results'][-1]['success']: 281 | G, path = ProofParseManage.get_proof_tree(data['collect_results'][-1]) 282 | len_list.append(len(path)-1) 283 | counter = Counter(len_list) 284 | 285 | numbers = list(counter.keys()) 286 | frequencies = list(counter.values()) 287 | 288 | output = Path(result_dir, "len_distribution.png") 289 | 290 | plt.bar(numbers, frequencies) 291 | plt.xlabel('Length') 292 | plt.ylabel('Frequency') 293 | plt.title('Proof Length Distribution') 294 | plt.savefig(output, dpi=300, bbox_inches='tight') 295 | 296 | -------------------------------------------------------------------------------- /Realprover/manager/search/mcts_search.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | import math 3 | from typing import Optional, Dict, Tuple 4 | 5 | from manager.struct import Node, Goal, state_repr_dedup 6 | from manager.thirdparty import Interactive, TacticGenerator 7 | import conf.config 8 | from manager.search.exception import SearchError 9 | EPSILON = 1e-3 10 | 11 | class MCTSNode(Node): 12 | def __init__(self, 13 | sid: int, 14 | parent_sid: int, 15 | tactic: str, 16 | state: list[Goal], 17 | depth: int = 0, 18 | score: float = 0, #现版本MCTS代码并未用到node的score 19 | parent: Optional["Node"] = None): 20 | super().__init__(sid, parent_sid, tactic, state, depth, score, parent) 21 | self.visits = 0 # 访问次数 22 | self.value = 0.0 # 价值估计 23 | self.children: Dict[str, "MCTSNode"] = {} # str是字符串化的tacitc 24 | 25 | class MCTSSearch: 26 | def __init__( 27 | self, 28 | num_samples: int = conf.config.NUM_SAMPLES, 29 | max_nodes: int = conf.config.MAX_NODES, 30 | max_depth: int = conf.config.MAX_DEPTH, 31 | max_calls: int = conf.config.MAX_CALLS, 32 | simulation_depth: int = conf.config.SIM_DEPTH, # MCTS模拟的最大深度 33 | max_root_expansion: int = conf.config.MAX_ROOT_EXPANSION, # 根节点最多扩展的次数 34 | c_puct: float = conf.config.C_PUCT, # PUCT公式的探索参数 35 | c_score: float = conf.config.C_SCORE, # UCB中score的权重 36 | c_expansion_fail_penalty: float = conf.config.C_EXPANSION_FAIL_PENALTY, # 扩展失败时反向传播的value,是正数 37 | abandon_if_contain: list[str] = conf.config.ABANDON_IF_CONTAIN 38 | ): 39 | self.found = False 40 | self.nodes: Dict[str, MCTSNode] = {} # 由于是继承式字段扩增,所以外面当作Node类的调用不会出问题 41 | self.root: MCTSNode = None 42 | self.score = dict() #(deduped_state_str, score), 用于UCB的计算,当作基础值。注意并非heapdict,所以都是正值,跟beamsearch不同 43 | self.depth = 0 44 | self.max_nodes = max_nodes 45 | self.max_depth = max_depth 46 | self.max_calls = max_calls 47 | self.num_samples = num_samples 48 | self.simulation_depth = simulation_depth 49 | self.c_puct = c_puct 50 | self.max_root_expansion = max_root_expansion 51 | self.c_score = c_score 52 | self.c_expansion_fail_penalty = c_expansion_fail_penalty 53 | self.call_cnt = 0 54 | self.tactic_sid_record = [] 55 | self.abandon_if_contain = abandon_if_contain 56 | 57 | def tactic_filter(self, tactic: str) -> bool: 58 | for forbidden_tactic in self.abandon_if_contain: 59 | if forbidden_tactic in tactic: 60 | return False 61 | return True 62 | 63 | def insert(self, node): 64 | if not isinstance(node, MCTSNode): #处理外界当作node的使用 
65 |             node = MCTSNode(
66 |                 sid=node.sid,
67 |                 parent_sid=node.parent_sid,
68 |                 tactic=node.tactic,
69 |                 state=node.state,
70 |                 depth=node.depth,
71 |                 score=node.score,
72 |                 parent=node.parent
73 |             )
74 |         if not node.state:
75 |             # self.interactive.commit(node.sid)
76 |             self.found = True
77 |             self.success_sid = node.sid
78 |         deduped_state_str = state_repr_dedup(node.state)
79 |         if deduped_state_str not in self.nodes:
80 |             self.nodes[deduped_state_str] = node
81 |             self.depth = max(self.depth, node.depth)
82 |             self.score[deduped_state_str] = 0
83 | 
84 |     def _delete_node(self, mcts_node: MCTSNode):
85 |         if mcts_node.sid == 0:  # never delete the root
86 |             return
87 | 
88 |         # recursively delete all child nodes
89 |         for child in list(mcts_node.children.values()):  # copy into a list to avoid mutation during iteration
90 |             self._delete_node(child)
91 | 
92 |         # delete the current node
93 |         parent = mcts_node.parent
94 |         del parent.children[mcts_node.tactic]
95 |         del self.nodes[state_repr_dedup(mcts_node.state)]
96 |         del self.score[state_repr_dedup(mcts_node.state)]
97 | 
98 |         # clear references
99 |         mcts_node.parent = None
100 |         mcts_node.children.clear()
101 | 
102 |     def _get_root(self) -> MCTSNode:
103 |         assert self.nodes
104 |         # pick the node with sid == 0 from self.nodes; raise if there is none or more than one
105 |         root_nodes = [node for node in self.nodes.values() if node.sid == 0]
106 |         if len(root_nodes) == 0:
107 |             raise ValueError("Root node not found")
108 |         elif len(root_nodes) > 1:
109 |             raise ValueError("Multiple sid = 0 root nodes found")
110 |         return root_nodes[0]
111 | 
112 |     def _build_children(self, mcts_node: MCTSNode, generator: TacticGenerator, interactive: Interactive) -> bool:
113 |         if mcts_node.children:
114 |             return True
115 |         tactics, scores = generator.from_state(mcts_node.state, self.num_samples)
116 |         self.call_cnt += 1
117 |         has_valid_tactics = False
118 |         for (tactic, num_reps), _ in zip(Counter(tactics).items(), scores):  # note: scores are zipped positionally against deduplicated tactics and then discarded
119 |             if not self.tactic_filter(tactic):
120 |                 continue
121 |             self.tactic_sid_record.append({"tactic": tactic, "sid": mcts_node.sid})
122 |             try:
123 |                 sid = interactive.run_tactic(mcts_node.sid, tactic)
124 |             except RuntimeError:
125 |                 pass
126 |             except Exception as e:
127 |                 # past experience suggests wrapping get_state/give_up in a try block can cause a broken-pipe error;
128 |                 # add one manually once the root cause is confirmed
129 |                 raise SearchError("An error occurred at run_tactic",
130 |                                   error_data=self.tactic_sid_record,
131 |                                   error_type=e)
132 |             else:
133 |                 has_valid_tactics = True
134 |                 try:
135 |                     # according to Guoxiong this may still be buggy; drop the try block after debugging
136 |                     state = interactive.get_state(sid)
137 |                 except Exception as e:
138 |                     raise SearchError("An error occurred at get_state",
139 |                                       error_data=self.tactic_sid_record,
140 |                                       error_type=e)
141 |                 if state_repr_dedup(state) not in self.nodes:
142 |                     new_node = MCTSNode(sid, mcts_node.sid, tactic, state,
143 |                                         mcts_node.depth + 1, 1, mcts_node)
144 |                     mcts_node.children[tactic] = new_node
145 |                     self.insert(new_node)
146 |                     if not state:
147 |                         interactive.commit(sid)
148 |                         self.found = True
149 |                         return True
150 |                 self.score[state_repr_dedup(state)] += num_reps
151 |         if not has_valid_tactics:
152 |             self._delete_node(mcts_node)
153 |             return False
154 |         return True
155 | 
156 |     def _select(self, mcts_node: MCTSNode) -> MCTSNode:
157 |         """Selection phase: pick the best child using the PUCT formula.
158 |         Args:
159 |             mcts_node: the current MCTS node
160 |         Returns:
161 |             the selected child node
162 |         """
163 |         if not mcts_node.children:
164 |             return mcts_node
165 | 
166 |         # select a child node with the PUCT formula
167 |         best_value = float('-inf')
168 |         best_child = None
169 |         total_visits = sum(child.visits for child in mcts_node.children.values())
170 | 
171 |         temp_scores = []
172 | 
173 |         for child in mcts_node.children.values():
174 |             # UCB formula: Q + c_puct * sqrt(N) / (1 + n)
175 |             exploit = self.c_score * self.score[state_repr_dedup(child.state)] + child.value / (child.visits + EPSILON)  # EPSILON avoids division by zero
176 |             explore = self.c_puct * math.sqrt(total_visits) / (1 + child.visits)
177 |             value = exploit + explore
178 |             temp_scores.append((self.score[state_repr_dedup(child.state)], child.value / (child.visits + EPSILON), explore, child.tactic))
179 |             if value > best_value:
180 |                 best_value = value
181 |                 best_child = child
182 | 
183 |         return best_child  # type: ignore
184 | 
185 |     def _expand(self, mcts_node: MCTSNode, generator: TacticGenerator, interactive: Interactive) -> Optional[MCTSNode]:
186 |         """Expansion phase: expand a node whose children have already been built.
187 |         Returns:
188 |             the newly expanded node, or None if expansion is impossible
189 |         """
190 |         flag = self._build_children(mcts_node, generator, interactive)
191 |         if not flag:
192 |             return None
193 |         else:
194 |             return mcts_node  # note: the expansion step of textbook MCTS is folded into _select here; this only wraps up that step
195 | 
196 |     def _simulate(self, mcts_node: MCTSNode, generator: TacticGenerator, interactive: Interactive) -> float:
197 |         """Simulation phase: roll out a few steps and evaluate the terminal state.
198 |         Only one roll-out is performed, trusting the accuracy of the value model,
199 |         which should return a "probability of success" as a float between 0 and 1.
200 |         Returns:
201 |             the value estimate obtained from the simulation
202 |         """
203 |         current_node = mcts_node
204 |         depth = 0
205 |         while depth < self.simulation_depth:
206 |             if not current_node.state:
207 |                 return 1.0  # proof found
208 |             tactics, _ = generator.from_state(current_node.state, 1)  # sample one action only; call_cnt is deliberately not incremented, to keep the two cases apart
209 |             # TODO: sample up to the sample budget, then continue with the highest-scoring tactic
210 |             if not tactics:
211 |                 break
212 |             if not self.tactic_filter(tactics[0]):
213 |                 break  # treat a filtered tactic like an empty sample; retrying here could loop without bound since depth never advances
214 |             self.tactic_sid_record.append({"tactic": tactics[0], "sid": current_node.sid})
215 |             try:
216 |                 sid = interactive.run_tactic(current_node.sid, tactics[0])
217 |             except RuntimeError:
218 |                 print(f"------- \n tactic: {tactics[0]} \n node_num: {len(self.nodes)}")
219 |                 break
220 |             except Exception as e:
221 |                 # past experience suggests wrapping get_state/give_up in a try block can cause a broken-pipe error;
222 |                 # add one manually once the root cause is confirmed
223 |                 raise SearchError("An error occurred at run_tactic",
224 |                                   error_data=self.tactic_sid_record,
225 |                                   error_type=e)
226 |             else:
227 |                 # according to Guoxiong this may still be buggy; drop the try block after debugging
228 |                 try:
229 |                     state = interactive.get_state(sid)
230 |                 except Exception as e:
231 |                     raise SearchError("An error occurred at get_state",
232 |                                       error_data=self.tactic_sid_record,
233 |                                       error_type=e)
234 |                 current_node = Node(sid, current_node.sid, tactics[0], state,
235 |                                     current_node.depth + 1, 1, current_node)
236 |                 depth += 1
237 | 
238 |         # TODO: evaluate the terminal state with a value network
239 |         if not current_node.state:
240 |             return 1.0
241 |         return self.score[state_repr_dedup(mcts_node.state)]
242 | 
243 |     def _backpropagate(self, mcts_node: MCTSNode, value: float):
244 |         """Backpropagation phase: update node statistics.
245 |         Args:
246 |             mcts_node: the current MCTS node
247 |             value: the value obtained from the simulation
248 |         """
249 |         current = mcts_node
250 |         while current:
251 |             current.visits += 1
252 |             current.value += value
253 |             # fetch the parent node
254 |             parent_state = state_repr_dedup(current.parent.state) if current.parent else None  # TODO: override parent as MCTSNode in the subclass
255 |             current = self.nodes.get(parent_state) if parent_state else None
256 | 
257 |     def going(self) -> bool:
258 |         return not self.found and len(self.nodes) < self.max_nodes and self.depth < self.max_depth and self.call_cnt < self.max_calls
259 | 
260 |     def search_proof(self, generator: TacticGenerator, interactive: Interactive):
261 |         """Run the MCTS search.
262 |         Args:
263 |             generator: the tactic generator
264 |             interactive: the interactive proof environment
265 |         """
266 |         self.root = self._get_root()
267 |         cnt = 0
268 |         while (cnt < self.max_root_expansion and not self.root.children):
269 |             self._build_children(self.root, generator, interactive)
270 |             cnt += 1
271 |         if not self.root.children:
272 |             print(f"------- \n No valid tactics found ON ROOT: \ndepth {self.root.depth} \n state: {self.root.state}\n node_num: {len(self.nodes)}\n")
273 |             sid = interactive.give_up(0)
274 |             interactive.commit(sid)
275 |             return
276 |         while self.going() and generator.has_quota():
277 |             current = self.root
278 |             # selection
279 |             while current.children:
280 |                 current = self._select(current)
281 |             parent = current.parent
282 |             # expansion
283 |             new_node = self._expand(current, generator, interactive)
284 |             if self.found:
285 |                 break
286 |             if new_node:
287 |                 current = new_node
288 |                 # simulation
289 |                 value = self._simulate(current, generator, interactive)
290 |                 # backpropagation
291 |                 self._backpropagate(current, value)
292 |             else:
293 |                 value = -self.c_expansion_fail_penalty
294 |                 if parent:
295 |                     self._backpropagate(parent, value)  # the model producing no valid tactic is treated as a bad outcome, so a bad value is back-propagated
296 |             # check whether a proof has been found
297 |             # if self.found:
298 |             #     interactive.commit(self.success_sid)
299 |             #     break
300 | 
301 |         if not self.found:
302 |             sid = interactive.give_up(0)
303 |             interactive.commit(sid)
304 | 
305 |     @property
306 |     def info(self):
307 |         """Return the searcher's configuration."""
308 |         return dict(
309 |             use_beam_search=False,
310 |             use_mcts=True,
311 |             beam_width=None,
312 |             num_samples=self.num_samples,
313 |             max_nodes=self.max_nodes,
314 |             max_depth=self.max_depth,
315 |             simulation_depth=self.simulation_depth,
316 |             c_puct=self.c_puct
317 |         )
--------------------------------------------------------------------------------
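The selection rule in _select above combines a frequency-based state score, a mean back-propagated value, and a visit-count exploration bonus. The following self-contained sketch recomputes that value for two children; c_score, c_puct, and EPSILON mirror the fields above, and all statistics are invented:

import math

EPSILON, c_score, c_puct = 1e-3, 0.1, 1.0  # invented values for illustration

def puct_value(score: float, value: float, visits: int, total_visits: int) -> float:
    # exploit: weighted base score plus mean back-propagated value
    exploit = c_score * score + value / (visits + EPSILON)
    # explore: grows with sibling visits, decays as this child is revisited
    explore = c_puct * math.sqrt(total_visits) / (1 + visits)
    return exploit + explore

children = {"simp": (3.0, 1.2, 4), "ring": (1.0, 0.3, 1)}  # tactic -> (score, value, visits)
total = sum(v for (_, _, v) in children.values())
best = max(children, key=lambda t: puct_value(*children[t], total))
print(best)  # "ring": the rarely visited child wins on the exploration term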